diff --git a/src/helpers/__init__.py b/src/helpers/__init__.py index 4a428fd2fbbde81111305e9feca4bb4cbc1fb324..546713b3f18f2cb64c1527b57d1e9e2138e927aa 100644 --- a/src/helpers/__init__.py +++ b/src/helpers/__init__.py @@ -1,6 +1,6 @@ """Collection of different supporting functions and classes.""" -from .testing import PyTestRegex, PyTestAllEqual, xr_all_equal +from .testing import PyTestRegex, PyTestAllEqual from .time_tracking import TimeTracking, TimeTrackingWrapper from .logger import Logger from .helpers import remove_items, float_round, dict_to_xarray, to_list diff --git a/src/helpers/testing.py b/src/helpers/testing.py index 3eea56bd9e1b748d1adc4c7d87af6f968f437734..244eb69fdc46dcadaeb3ada5779f09d44aa83e2a 100644 --- a/src/helpers/testing.py +++ b/src/helpers/testing.py @@ -2,6 +2,7 @@ import re from typing import Union, Pattern, List +import numpy as np import xarray as xr @@ -44,6 +45,13 @@ class PyTestAllEqual: def __init__(self, check_list: List): """Construct class.""" self._list = check_list + self._test_function = None + + def _set_test_function(self): + if isinstance(self._list[0], np.ndarray): + self._test_function = np.testing.assert_array_equal + else: + self._test_function = xr.testing.assert_equal def _check_all_equal(self) -> bool: """ @@ -52,8 +60,9 @@ class PyTestAllEqual: :return boolean if elements are equal """ equal = True + self._set_test_function() for b in self._list: - equal *= xr.testing.assert_equal(self._list[0], b) is None + equal *= self._test_function(self._list[0], b) is None return bool(equal == 1) def is_true(self) -> bool: diff --git a/src/plotting/tracker_plot.py b/src/plotting/tracker_plot.py index e63e2988c5b3ed216cc73ddb359d9e7eb7618932..2d7b06cb6d7430be80eeae8ecedf811a3f2dc37c 100644 --- a/src/plotting/tracker_plot.py +++ b/src/plotting/tracker_plot.py @@ -272,7 +272,7 @@ class TrackPlot: """Draw rectangle with lower left at (x,y), size equal to width/height and label/color according to method.""" # draw rectangle color = {"get": "orange"}.get(method, "lightblue") - r = Rectangle((x, y), self.width, self.height, color=color, label=color) + r = Rectangle((x, y), self.width, self.height, color=color) self.ax.add_artist(r) # add label rx, ry = r.get_xy() diff --git a/test/test_plotting/test_tracker_plot.py b/test/test_plotting/test_tracker_plot.py index 37a53c1c2dc21e311ad9f59a885f066ca20a802f..9a92360a819c130c213d06b89a48a896e082adad 100644 --- a/test/test_plotting/test_tracker_plot.py +++ b/test/test_plotting/test_tracker_plot.py @@ -1,8 +1,15 @@ import pytest from collections import OrderedDict +import os +import shutil + +from matplotlib import pyplot as plt +import numpy as np + +from src.plotting.tracker_plot import TrackObject, TrackChain, TrackPlot +from src.helpers import PyTestAllEqual -from src.plotting.tracker_plot import TrackObject, TrackChain class TestTrackObject: @@ -284,3 +291,157 @@ class TestTrackChain: expected_control = {'another': {'general': None, 'general.sub': None}, 'first': {'general': tr2, 'general.sub': None}, } assert control == expected_control + + +class TestTrackPlot: + + @pytest.fixture + def track_plot_obj(self): + return object.__new__(TrackPlot) + + @pytest.fixture + def track_list(self): + return [{'Stage1': {'test': [{'method': 'set', 'scope': 'general.daytime'}, + {'method': 'set', 'scope': 'general'}, + {'method': 'get', 'scope': 'general'}, + {'method': 'get', 'scope': 'general'},], + 'another': [{'method': 'set', 'scope': 'general'}]}}, + {'Stage2': {'sunlight': [{'method': 'set', 'scope': 'general'}], + 'another': [{'method': 'get', 'scope': 'general.daytime'}, + {'method': 'set', 'scope': 'general'}, + {'method': 'set', 'scope': 'general.daytime'}, + {'method': 'get', 'scope': 'general.daytime.noon'}, + {'method': 'get', 'scope': 'general.nighttime'}, + {'method': 'get', 'scope': 'general.daytime.noon'}]}}, + {'RunEnvironment': {'another': [{'method': 'get', 'scope': 'general.daytime'}], + 'test': [{'method': 'get', 'scope': 'general'}], + 'moonlight': [{'method': 'set', 'scope': 'general.daytime'}]}}] + + @pytest.fixture + def scopes(self): + return {"another": ["general", "general.daytime", "general.daytime.noon", "general.nighttime"], + "moonlight": ["general", "general.daytime"], + "sunlight": ["general"], + "test": ["general", "general.daytime"]} + + @pytest.fixture + def dims(self): + return {"another": 4, "moonlight": 2, "sunlight": 1, "test": 2} + + @pytest.fixture + def track_chain_dict(self, track_list): + return TrackChain(track_list).create_track_chain() + + @pytest.fixture + def path(self): + p = os.path.join(os.path.dirname(__file__), "TestExperiment") + if not os.path.exists(p): + os.makedirs(p) + yield p + shutil.rmtree(p, ignore_errors=True) + + def test_init(self, path, track_list): + assert "tracking.pdf" not in os.listdir(path) + TrackPlot(track_list, plot_folder=path) + assert "tracking.pdf" in os.listdir(path) + + def test_plot(self): + pass + + def test_line(self, track_plot_obj): + h, w = 0.6, 0.65 + track_plot_obj.height = h + track_plot_obj.width = w + track_plot_obj.fig, track_plot_obj.ax = plt.subplots() + assert len(track_plot_obj.ax.lines) == 0 + track_plot_obj.line(start_x=5, end_x=6, y=2) + assert len(track_plot_obj.ax.lines) == 2 + pos_x, pos_y = np.array([5 + w, 6]), np.ones((2, )) * (2 + h / 2) + assert track_plot_obj.ax.lines[0]._color == "white" + assert track_plot_obj.ax.lines[0]._linewidth == 2.5 + assert track_plot_obj.ax.lines[1]._color == "darkgrey" + assert track_plot_obj.ax.lines[1]._linewidth == 1.4 + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]).is_true() + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]).is_true() + + def test_step(self, track_plot_obj): + x_int, h, w = 0.5, 0.6, 0.65 + track_plot_obj.space_intern_x = x_int + track_plot_obj.height = h + track_plot_obj.width = w + track_plot_obj.fig, track_plot_obj.ax = plt.subplots() + assert len(track_plot_obj.ax.lines) == 0 + track_plot_obj.step(start_x=5, end_x=6, start_y=2, end_y=3) + assert len(track_plot_obj.ax.lines) == 2 + pos_x = np.array([5 + w, 6 - x_int / 2, 6 - x_int / 2, 6]) + pos_y = np.array([2 + h / 2, 2 + h / 2, 3 + h / 2, 3 + h / 2]) + assert track_plot_obj.ax.lines[0]._color == "white" + assert track_plot_obj.ax.lines[0]._linewidth == 2.5 + assert track_plot_obj.ax.lines[1]._color == "black" + assert track_plot_obj.ax.lines[1]._linewidth == 1.4 + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]).is_true() + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]).is_true() + + def test_rect(self, track_plot_obj): + h, w = 0.5, 0.6 + track_plot_obj.height = h + track_plot_obj.width = w + track_plot_obj.fig, track_plot_obj.ax = plt.subplots() + assert len(track_plot_obj.ax.artists) == 0 + assert len(track_plot_obj.ax.texts) == 0 + track_plot_obj.rect(x=4, y=2) + assert len(track_plot_obj.ax.artists) == 1 + assert len(track_plot_obj.ax.texts) == 1 + track_plot_obj.ax.artists[0].xy == (4, 2) + track_plot_obj.ax.artists[0]._height == h + track_plot_obj.ax.artists[0]._width == w + track_plot_obj.ax.artists[0]._original_facecolor == "orange" + track_plot_obj.ax.texts[0].xy == (4 + w / 2, 2 + h / 2) + track_plot_obj.ax.texts[0]._color == "w" + track_plot_obj.ax.texts[0]._text == "get" + track_plot_obj.rect(x=4, y=2, method="set") + assert len(track_plot_obj.ax.artists) == 2 + assert len(track_plot_obj.ax.texts) == 2 + track_plot_obj.ax.artists[0]._original_facecolor == "lightblue" + track_plot_obj.ax.texts[0]._text == "set" + + + + def test_set_ypos_anchor(self, track_plot_obj, scopes, dims): + assert not hasattr(track_plot_obj, "y_pos") + assert not hasattr(track_plot_obj, "anchor") + y_int, y_ext, h = 0.5, 0.7, 0.6 + track_plot_obj.space_intern_y = y_int + track_plot_obj.height = h + track_plot_obj.space_extern_y = y_ext + track_plot_obj.set_ypos_anchor(scopes, dims) + d_y = 0 - sum([factor * (y_int + h) + y_ext - y_int for factor in dims.values()]) + expected_anchor = (d_y + sum(dims.values()), h + y_ext + sum(dims.values())) + assert np.testing.assert_array_almost_equal(track_plot_obj.anchor, expected_anchor) is None + assert track_plot_obj.y_pos["another"]["general"] == sum(dims.values()) + assert track_plot_obj.y_pos["another"]["general.daytime"] == sum(dims.values()) - (h + y_int) + assert track_plot_obj.y_pos["another"]["general.daytime.noon"] == sum(dims.values()) - 2 * (h + y_int) + + def test_plot_track_chain(self): + pass + + def test_add_variable_names(self): + pass + + def test_add_stages(self): + pass + + def test_create_track_chain_plot_run_env(self): + pass + + def test_set_lims(self, track_plot_obj): + track_plot_obj.x_max = 10 + track_plot_obj.space_intern_x = 0.5 + track_plot_obj.width = 0.4 + track_plot_obj.anchor = np.array((0.1, 12.5)) + track_plot_obj.fig, track_plot_obj.ax = plt.subplots() + assert track_plot_obj.ax.get_ylim() == (0, 1) # matplotlib default + assert track_plot_obj.ax.get_xlim() == (0, 1) # matplotlib default + track_plot_obj.set_lims() + assert track_plot_obj.ax.get_ylim() == (0.1, 12.5) + assert track_plot_obj.ax.get_xlim() == (0, 10+0.5+0.4) \ No newline at end of file