Skip to content
Snippets Groups Projects
Commit e35720fd authored by lukas leufen's avatar lukas leufen
Browse files

finished track plot and chain testing, PyTestAllEqual can handle xr and np arrays

parent df9c84f4
No related branches found
No related tags found
4 merge requests!125Release v0.10.0,!124Update Master to new version v0.10.0,!98include track datastore,!91WIP: Resolve "create sphinx docu"
Pipeline #37052 passed
"""Collection of different supporting functions and classes."""
from .testing import PyTestRegex, PyTestAllEqual, xr_all_equal
from .testing import PyTestRegex, PyTestAllEqual
from .time_tracking import TimeTracking, TimeTrackingWrapper
from .logger import Logger
from .helpers import remove_items, float_round, dict_to_xarray, to_list
......@@ -2,6 +2,7 @@
import re
from typing import Union, Pattern, List
import numpy as np
import xarray as xr
......@@ -44,6 +45,13 @@ class PyTestAllEqual:
def __init__(self, check_list: List):
"""Construct class."""
self._list = check_list
self._test_function = None
def _set_test_function(self):
if isinstance(self._list[0], np.ndarray):
self._test_function = np.testing.assert_array_equal
else:
self._test_function = xr.testing.assert_equal
def _check_all_equal(self) -> bool:
"""
......@@ -52,8 +60,9 @@ class PyTestAllEqual:
:return boolean if elements are equal
"""
equal = True
self._set_test_function()
for b in self._list:
equal *= xr.testing.assert_equal(self._list[0], b) is None
equal *= self._test_function(self._list[0], b) is None
return bool(equal == 1)
def is_true(self) -> bool:
......
......@@ -272,7 +272,7 @@ class TrackPlot:
"""Draw rectangle with lower left at (x,y), size equal to width/height and label/color according to method."""
# draw rectangle
color = {"get": "orange"}.get(method, "lightblue")
r = Rectangle((x, y), self.width, self.height, color=color, label=color)
r = Rectangle((x, y), self.width, self.height, color=color)
self.ax.add_artist(r)
# add label
rx, ry = r.get_xy()
......
import pytest
from collections import OrderedDict
import os
import shutil
from matplotlib import pyplot as plt
import numpy as np
from src.plotting.tracker_plot import TrackObject, TrackChain, TrackPlot
from src.helpers import PyTestAllEqual
from src.plotting.tracker_plot import TrackObject, TrackChain
class TestTrackObject:
......@@ -284,3 +291,157 @@ class TestTrackChain:
expected_control = {'another': {'general': None, 'general.sub': None},
'first': {'general': tr2, 'general.sub': None}, }
assert control == expected_control
class TestTrackPlot:
@pytest.fixture
def track_plot_obj(self):
return object.__new__(TrackPlot)
@pytest.fixture
def track_list(self):
return [{'Stage1': {'test': [{'method': 'set', 'scope': 'general.daytime'},
{'method': 'set', 'scope': 'general'},
{'method': 'get', 'scope': 'general'},
{'method': 'get', 'scope': 'general'},],
'another': [{'method': 'set', 'scope': 'general'}]}},
{'Stage2': {'sunlight': [{'method': 'set', 'scope': 'general'}],
'another': [{'method': 'get', 'scope': 'general.daytime'},
{'method': 'set', 'scope': 'general'},
{'method': 'set', 'scope': 'general.daytime'},
{'method': 'get', 'scope': 'general.daytime.noon'},
{'method': 'get', 'scope': 'general.nighttime'},
{'method': 'get', 'scope': 'general.daytime.noon'}]}},
{'RunEnvironment': {'another': [{'method': 'get', 'scope': 'general.daytime'}],
'test': [{'method': 'get', 'scope': 'general'}],
'moonlight': [{'method': 'set', 'scope': 'general.daytime'}]}}]
@pytest.fixture
def scopes(self):
return {"another": ["general", "general.daytime", "general.daytime.noon", "general.nighttime"],
"moonlight": ["general", "general.daytime"],
"sunlight": ["general"],
"test": ["general", "general.daytime"]}
@pytest.fixture
def dims(self):
return {"another": 4, "moonlight": 2, "sunlight": 1, "test": 2}
@pytest.fixture
def track_chain_dict(self, track_list):
return TrackChain(track_list).create_track_chain()
@pytest.fixture
def path(self):
p = os.path.join(os.path.dirname(__file__), "TestExperiment")
if not os.path.exists(p):
os.makedirs(p)
yield p
shutil.rmtree(p, ignore_errors=True)
def test_init(self, path, track_list):
assert "tracking.pdf" not in os.listdir(path)
TrackPlot(track_list, plot_folder=path)
assert "tracking.pdf" in os.listdir(path)
def test_plot(self):
pass
def test_line(self, track_plot_obj):
h, w = 0.6, 0.65
track_plot_obj.height = h
track_plot_obj.width = w
track_plot_obj.fig, track_plot_obj.ax = plt.subplots()
assert len(track_plot_obj.ax.lines) == 0
track_plot_obj.line(start_x=5, end_x=6, y=2)
assert len(track_plot_obj.ax.lines) == 2
pos_x, pos_y = np.array([5 + w, 6]), np.ones((2, )) * (2 + h / 2)
assert track_plot_obj.ax.lines[0]._color == "white"
assert track_plot_obj.ax.lines[0]._linewidth == 2.5
assert track_plot_obj.ax.lines[1]._color == "darkgrey"
assert track_plot_obj.ax.lines[1]._linewidth == 1.4
assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]).is_true()
assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]).is_true()
def test_step(self, track_plot_obj):
x_int, h, w = 0.5, 0.6, 0.65
track_plot_obj.space_intern_x = x_int
track_plot_obj.height = h
track_plot_obj.width = w
track_plot_obj.fig, track_plot_obj.ax = plt.subplots()
assert len(track_plot_obj.ax.lines) == 0
track_plot_obj.step(start_x=5, end_x=6, start_y=2, end_y=3)
assert len(track_plot_obj.ax.lines) == 2
pos_x = np.array([5 + w, 6 - x_int / 2, 6 - x_int / 2, 6])
pos_y = np.array([2 + h / 2, 2 + h / 2, 3 + h / 2, 3 + h / 2])
assert track_plot_obj.ax.lines[0]._color == "white"
assert track_plot_obj.ax.lines[0]._linewidth == 2.5
assert track_plot_obj.ax.lines[1]._color == "black"
assert track_plot_obj.ax.lines[1]._linewidth == 1.4
assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]).is_true()
assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]).is_true()
def test_rect(self, track_plot_obj):
h, w = 0.5, 0.6
track_plot_obj.height = h
track_plot_obj.width = w
track_plot_obj.fig, track_plot_obj.ax = plt.subplots()
assert len(track_plot_obj.ax.artists) == 0
assert len(track_plot_obj.ax.texts) == 0
track_plot_obj.rect(x=4, y=2)
assert len(track_plot_obj.ax.artists) == 1
assert len(track_plot_obj.ax.texts) == 1
track_plot_obj.ax.artists[0].xy == (4, 2)
track_plot_obj.ax.artists[0]._height == h
track_plot_obj.ax.artists[0]._width == w
track_plot_obj.ax.artists[0]._original_facecolor == "orange"
track_plot_obj.ax.texts[0].xy == (4 + w / 2, 2 + h / 2)
track_plot_obj.ax.texts[0]._color == "w"
track_plot_obj.ax.texts[0]._text == "get"
track_plot_obj.rect(x=4, y=2, method="set")
assert len(track_plot_obj.ax.artists) == 2
assert len(track_plot_obj.ax.texts) == 2
track_plot_obj.ax.artists[0]._original_facecolor == "lightblue"
track_plot_obj.ax.texts[0]._text == "set"
def test_set_ypos_anchor(self, track_plot_obj, scopes, dims):
assert not hasattr(track_plot_obj, "y_pos")
assert not hasattr(track_plot_obj, "anchor")
y_int, y_ext, h = 0.5, 0.7, 0.6
track_plot_obj.space_intern_y = y_int
track_plot_obj.height = h
track_plot_obj.space_extern_y = y_ext
track_plot_obj.set_ypos_anchor(scopes, dims)
d_y = 0 - sum([factor * (y_int + h) + y_ext - y_int for factor in dims.values()])
expected_anchor = (d_y + sum(dims.values()), h + y_ext + sum(dims.values()))
assert np.testing.assert_array_almost_equal(track_plot_obj.anchor, expected_anchor) is None
assert track_plot_obj.y_pos["another"]["general"] == sum(dims.values())
assert track_plot_obj.y_pos["another"]["general.daytime"] == sum(dims.values()) - (h + y_int)
assert track_plot_obj.y_pos["another"]["general.daytime.noon"] == sum(dims.values()) - 2 * (h + y_int)
def test_plot_track_chain(self):
pass
def test_add_variable_names(self):
pass
def test_add_stages(self):
pass
def test_create_track_chain_plot_run_env(self):
pass
def test_set_lims(self, track_plot_obj):
track_plot_obj.x_max = 10
track_plot_obj.space_intern_x = 0.5
track_plot_obj.width = 0.4
track_plot_obj.anchor = np.array((0.1, 12.5))
track_plot_obj.fig, track_plot_obj.ax = plt.subplots()
assert track_plot_obj.ax.get_ylim() == (0, 1) # matplotlib default
assert track_plot_obj.ax.get_xlim() == (0, 1) # matplotlib default
track_plot_obj.set_lims()
assert track_plot_obj.ax.get_ylim() == (0.1, 12.5)
assert track_plot_obj.ax.get_xlim() == (0, 10+0.5+0.4)
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment