From 2668bacfd8423bc385e255d4c59e5700980b36af Mon Sep 17 00:00:00 2001
From: lukas leufen <l.leufen@fz-juelich.de>
Date: Mon, 11 May 2020 11:30:01 +0200
Subject: [PATCH] Introduced tracker plot (still on refac), moved get default
 to base class, added track parameter to datastore's get and set

---
 src/helpers/datastore.py     |  76 +++----
 src/plotting/tracker_plot.py | 428 +++++++++++++++++++++++++++++++++++
 2 files changed, 456 insertions(+), 48 deletions(-)
 create mode 100644 src/plotting/tracker_plot.py

diff --git a/src/helpers/datastore.py b/src/helpers/datastore.py
index cd852067..b4615216 100644
--- a/src/helpers/datastore.py
+++ b/src/helpers/datastore.py
@@ -120,12 +120,12 @@ class TrackParameter:
     def track(self, tracker_obj, *args):
         name, obj, scope = self._decrypt_args(*args)
         logging.debug(f"{self.__wrapped__.__name__}: {name}({scope})={obj}")
-        tracker = tracker_obj.tracker
-        new_entry = [(self.__wrapped__.__name__, scope, obj)]
-        if tracker.get(name):
+        tracker = tracker_obj.tracker[-1]
+        new_entry = {"method": self.__wrapped__.__name__, "scope": scope}
+        if name in tracker:
             tracker[name].append(new_entry)
         else:
-            tracker[name] = new_entry
+            tracker[name] = [new_entry]
 
     @staticmethod
     def _decrypt_args(*args):
@@ -144,7 +144,7 @@ class AbstractDataStore(ABC):
     adjustments.
     """
 
-    tracker = {}
+    tracker = [{}]
 
     def __init__(self):
         """Initialise by creating empty data store."""
@@ -171,6 +171,27 @@ class AbstractDataStore(ABC):
         """
         pass
 
+    @CorrectScope
+    def get_default(self, name: str, scope: str, default: Any) -> Any:
+        """
+        Retrieve an object with `name` from `scope` and return given default if object wasn't found.
+
+        Same functionality like the standard get method. But this method adds a default argument that is returned if no
+        data was stored in the data store. Use this function with care, because it will not report any errors and just
+        return the given default value. Currently, there is no statement that reports, if the returned value comes from
+        the data store or the default value.
+
+        :param name: Name to look for
+        :param scope: scope to search the name for
+        :param default: default value that is return, if no data was found for given name and scope
+
+        :return: the stored object or the default value
+        """
+        try:
+            return self.get(name, scope)
+        except (NameNotFoundInDataStore, NameNotFoundInScope):
+            return default
+
     def search_name(self, name: str) -> None:
         """
         Abstract method to search for all occurrences of given `name` in the entire data store.
@@ -312,28 +333,6 @@ class DataStoreByVariable(AbstractDataStore):
         """
         return self._stride_through_scopes(name, scope)[2]
 
-    @CorrectScope
-    def get_default(self, name: str, scope: str, default: Any) -> Any:
-        """
-
-        Retrieve an object with `name` from `scope` and return given default if object wasn't found.
-
-        Same functionality like the standard get method. But this method adds a default argument that is returned if no
-        data was stored in the data store. Use this function with care, because it will not report any errors and just
-        return the given default value. Currently, there is no statement that reports, if the returned value comes from
-        the data store or the default value.
-
-        :param name: Name to look for
-        :param scope: scope to search the name for
-        :param default: default value that is return, if no data was found for given name and scope
-
-        :return: the stored object or the default value
-        """
-        try:
-            return self._stride_through_scopes(name, scope)[2]
-        except (NameNotFoundInDataStore, NameNotFoundInScope):
-            return default
-
     @CorrectScope
     def _stride_through_scopes(self, name, scope, depth=0):
         if depth <= scope.count("."):
@@ -449,6 +448,7 @@ class DataStoreByScope(AbstractDataStore):
     """
 
     @CorrectScope
+    @TrackParameter
     def set(self, name: str, obj: Any, scope: str, log: bool = False) -> None:
         """
         Store an object `obj` with given `name` under `scope`.
@@ -467,6 +467,7 @@ class DataStoreByScope(AbstractDataStore):
             logging.debug(f"set: {name}({scope})={obj}")
 
     @CorrectScope
+    @TrackParameter
     def get(self, name: str, scope: str) -> Any:
         """
         Retrieve an object with `name` from `scope`.
@@ -483,27 +484,6 @@ class DataStoreByScope(AbstractDataStore):
         """
         return self._stride_through_scopes(name, scope)[2]
 
-    @CorrectScope
-    def get_default(self, name: str, scope: str, default: Any) -> Any:
-        """
-        Retrieve an object with `name` from `scope` and return given default if object wasn't found.
-
-        Same functionality like the standard get method. But this method adds a default argument that is returned if no
-        data was stored in the data store. Use this function with care, because it will not report any errors and just
-        return the given default value. Currently, there is no statement that reports, if the returned value comes from
-        the data store or the default value.
-
-        :param name: Name to look for
-        :param scope: scope to search the name for
-        :param default: default value that is return, if no data was found for given name and scope
-
-        :return: the stored object or the default value
-        """
-        try:
-            return self._stride_through_scopes(name, scope)[2]
-        except (NameNotFoundInDataStore, NameNotFoundInScope):
-            return default
-
     @CorrectScope
     def _stride_through_scopes(self, name, scope, depth=0):
         if depth <= scope.count("."):
diff --git a/src/plotting/tracker_plot.py b/src/plotting/tracker_plot.py
new file mode 100644
index 00000000..3f979b65
--- /dev/null
+++ b/src/plotting/tracker_plot.py
@@ -0,0 +1,428 @@
+from collections import OrderedDict
+
+import numpy as np
+import os
+from typing import Union, List, Optional, Dict
+
+from src.helpers import to_list
+
+from matplotlib import pyplot as plt, lines as mlines, ticker as ticker
+from matplotlib.patches import Rectangle
+
+
+class TrackObject:
+
+    """
+    A TrackObject can be used to create simple chains of objects.
+
+    :param name: string or list of strings with a name describing the track object
+    :param stage: additional meta information (can be used to highlight different blocks inside a chain)
+    """
+
+    def __init__(self, name: Union[List[str], str], stage: str):
+        self.name = to_list(name)
+        self.stage = stage
+        self.precursor: Optional[List[TrackObject]] = None
+        self.successor: Optional[List[TrackObject]] = None
+        self.x: Optional[float] = None
+        self.y: Optional[float] = None
+
+    def __repr__(self):
+        return str("/".join(self.name))
+
+    @property
+    def x(self):
+        """Get x value."""
+        return self._x
+
+    @x.setter
+    def x(self, value: float):
+        """Set x value."""
+        self._x = value
+
+    @property
+    def y(self):
+        """Get y value."""
+        return self._y
+
+    @y.setter
+    def y(self, value: float):
+        """Set y value."""
+        self._y = value
+
+    def add_precursor(self, precursor: "TrackObject"):
+        """Add a precursory track object."""
+        if self.precursor is None:
+            self.precursor = [precursor]
+        else:
+            if precursor not in self.precursor:
+                self.precursor.append(precursor)
+            else:
+                return
+        precursor.add_successor(self)
+
+    def add_successor(self, successor: "TrackObject"):
+        """Add a successive track object."""
+        if self.successor is None:
+            self.successor = [successor]
+        else:
+            if successor not in self.successor:
+                self.successor.append(successor)
+            else:
+                return
+        successor.add_precursor(self)
+
+
+class TrackChain:
+
+    def __init__(self, track_list):
+        self.track_list = track_list
+        self.scopes = self.get_all_scopes(self.track_list)
+        self.dims = self.get_all_dims(self.scopes)
+
+    def get_all_scopes(self, track_list) -> Dict:
+        """Return dictionary with all distinct variables as keys and its unique scopes as values."""
+        dims = {}
+        for track_dict in track_list:  # all stages
+            for track in track_dict.values():  # single stage, all variables
+                for k, v in track.items():  # single variable
+                    scopes = self.get_unique_scopes(v)
+                    if dims.get(k) is None:
+                        dims[k] = scopes
+                    else:
+                        dims[k] = np.unique(scopes + dims[k]).tolist()
+        return OrderedDict(sorted(dims.items()))
+
+    @staticmethod
+    def get_all_dims(scopes):
+        dims = {}
+        for k, v in scopes.items():
+            dims[k] = len(v)
+        return dims
+
+    def create_track_chain(self):
+        control = self.control_dict(self.scopes)
+        track_chain_dict = OrderedDict()
+        for track_dict in self.track_list:
+            stage, stage_track = list(track_dict.items())[0]
+            track_chain, control = self._create_track_chain(control, OrderedDict(sorted(stage_track.items())), stage)
+            control = self.clean_control(control)
+            track_chain_dict[stage] = track_chain
+        return track_chain_dict
+
+    def _create_track_chain(self, control, sorted_track_dict, stage):
+        track_objects = []
+        for k, v in sorted_track_dict.items():
+            for e in v:
+                tr = TrackObject([k, e["method"], e["scope"]], stage)
+                if e["method"] == "set":
+                    if control[k][e["scope"]] is not None:
+                        track_objects = self._add_precursor(track_objects, tr, control[k][e["scope"]])
+                        # tr.add_precursor(control[k][e["scope"]])
+                        # # if tr.stage != control[k][e["scope"]].stage:
+                        # #     track_objects.append(control[k][e["scope"]])
+                        # track_objects = self._add_track_object(track_objects, tr, control[k][e["scope"]])
+                    else:
+                        track_objects.append(tr)
+                    self._update_control(control, k, e["scope"], tr)
+                    # control[k][e["scope"]] = tr
+                elif e["method"] == "get":
+                    if control[k][e["scope"]] is not None:
+                        track_objects = self._add_precursor(track_objects, tr, control[k][e["scope"]])
+                        # tr.add_precursor(control[k][e["scope"]])
+                        # # if tr.stage != control[k][e["scope"]].stage:
+                        # #     track_objects.append(control[k][e["scope"]])
+                        # track_objects = self._add_track_object(track_objects, tr, control[k][e["scope"]])
+                        # control[k][e["scope"]] = tr
+                        self._update_control(control, k, e["scope"], tr)
+                    else:
+                        scope = e["scope"].rsplit(".", 1)
+                        while len(scope) > 1:
+                            scope = scope[0]
+                            if control[k][scope] is not None:
+                                pre = control[k][scope]
+                                while pre.precursor is not None and pre.stage == stage and pre.name[1] != "set":
+                                    pre = pre.precursor[0]
+                                # tr.add_precursor(pre)
+                                # # if tr.stage != pre.stage:
+                                # #     track_objects.append(pre)
+                                # track_objects = self._add_track_object(track_objects, tr, pre)
+                                track_objects = self._add_precursor(track_objects, tr, pre)
+                                break
+                            scope = scope.rsplit(".", 1)
+                        else:
+                            continue
+                        # control[k][e["scope"]] = tr
+                        self._update_control(control, k, e["scope"], tr)
+        return track_objects, control
+
+    @staticmethod
+    def _update_control(control, variable, scope, tr_obj):
+        control[variable][scope] = tr_obj
+
+    @staticmethod
+    def _add_track_object(track_objects, tr_obj, prev_obj):
+        if tr_obj.stage != prev_obj.stage:
+            track_objects.append(prev_obj)
+        return track_objects
+
+    def _add_precursor(self, track_objects, tr_obj, prev_obj):
+        tr_obj.add_precursor(prev_obj)
+        return self._add_track_object(track_objects, tr_obj, prev_obj)
+
+    @staticmethod
+    def control_dict(scopes):
+        """Create empty control dictionary with variables and scopes as keys and None as default for all values."""
+        control = {}
+        for variable, scope_names in scopes.items():
+            control[variable] = {}
+            for s in scope_names:
+                update = {s: None}
+                if len(control[variable].keys()) == 0:
+                    control[variable] = update
+                else:
+                    control[variable].update(update)
+        return control
+
+    @staticmethod
+    def clean_control(control):
+        for k, v in control.items():
+            for kv, vv in v.items():
+                try:
+                    if vv.precursor[0].name[2] != vv.name[2]:
+                        control[k][kv] = None
+                except (TypeError, AttributeError):
+                    pass
+        return control
+
+    @staticmethod
+    def get_unique_scopes(track_list: List[Dict]) -> List[str]:
+        """Get list with all unique elements from input including general scope if missing."""
+        scopes = [e["scope"] for e in track_list] + ["general"]
+        return np.unique(scopes).tolist()
+
+
+class TrackerPlot:
+
+    def __init__(self, tracker_list, sparse_conn_mode=True, plot_folder: str = ".", skip_run_env=True):
+
+        self.width = 0.6
+        self.height = 0.5
+        self.space_intern_y = 0.2
+        self.space_extern_y = 1
+        self.space_intern_x = 0.4
+        self.space_extern_x = 0.6
+        self.y_pos = None
+        self.anchor = None
+        self.x_max = None
+
+        track_chain_obj = TrackChain(tracker_list)
+        scopes = track_chain_obj.scopes
+        dims = track_chain_obj.dims
+        track_chain_dict = track_chain_obj.create_track_chain()
+        # scopes = self.get_scopes(tracker_list)
+        # dims = self.get_dims(scopes)
+        self.set_ypos_anchor(scopes, dims)
+        # track_chain_dict = self.create_track_chain(tracker_list, scopes)
+        self.fig, self.ax = plt.subplots(figsize=(len(tracker_list) * 2, (self.anchor.max() - self.anchor.min()) / 3))
+        stages, v_lines = self.create_track_chain_plot(track_chain_dict, sparse_conn_mode=sparse_conn_mode, skip_run_env=skip_run_env)
+        self.set_lims()
+        self.add_variable_names()
+        self.add_stages(v_lines, stages)
+        plt.tight_layout()
+        plot_name = os.path.join(os.path.abspath(plot_folder), "tracking.pdf")
+        plt.savefig(plot_name, dpi=600)
+
+    def line(self, start_x, end_x, y, color="darkgrey"):
+        l = mlines.Line2D([start_x + self.width, end_x], [y + self.height / 2, y + self.height / 2], color="white",
+                          linewidth=2.5)
+        self.ax.add_line(l)
+        l = mlines.Line2D([start_x + self.width, end_x], [y + self.height / 2, y + self.height / 2], color=color,
+                          linewidth=1.4)
+        self.ax.add_line(l)
+
+    def step(self, start_x, end_x, start_y, end_y, color="black"):
+        start_x += self.width
+        start_y += self.height / 2
+        end_y += self.height / 2
+        step_x = end_x - (self.space_intern_x) / 2
+        pos_x = [start_x, step_x, step_x, end_x]
+        # pos_x = [start_x, (start_x + end_x) / 2, (start_x + end_x) / 2, end_x]
+        pos_y = [start_y, start_y, end_y, end_y]
+        l = mlines.Line2D(pos_x, pos_y, color="white", linewidth=2.5)
+        self.ax.add_line(l)
+        l = mlines.Line2D(pos_x, pos_y, color=color, linewidth=1.4)
+        self.ax.add_line(l)
+
+    def rect(self, x, y, method="get"):
+        # r = Rectangle((x, y), self.width, self.height, color=color, label=color)
+        # self.ax.add_patch(r)
+
+        if method == "get":
+            color = "orange"
+        else:
+            color = "lightblue"
+        r = Rectangle((x, y), self.width, self.height, color=color, label=color)
+
+        self.ax.add_artist(r)
+        rx, ry = r.get_xy()
+        cx = rx + r.get_width() / 2.0
+        cy = ry + r.get_height() / 2.0
+        self.ax.annotate(method, (cx, cy), color='w', weight='bold',
+                    fontsize=6, ha='center', va='center')
+
+    def set_ypos_anchor(self, scopes, dims):
+        anchor = sum(dims.values())
+        pos_dict = {}
+        d_y = 0
+        for k, v in scopes.items():
+            pos_dict[k] = {}
+            for e in v:
+                update = {e: anchor + d_y}
+                if len(pos_dict[k].keys()) == 0:
+                    pos_dict[k] = update
+                else:
+                    pos_dict[k].update(update)
+                d_y -= (self.space_intern_y + self.height)
+            d_y -= (self.space_extern_y - self.space_intern_y)
+        self.y_pos = pos_dict
+        self.anchor = np.array((d_y, self.height + self.space_extern_y)) + anchor
+
+    def plot_track_chain(self, chain, y_pos, x_pos=0, prev=None, stage=None, sparse_conn_mode=False):
+        if (chain.successor is None) or (chain.stage == stage):
+            var, method, scope = chain.name
+            x, y = x_pos, y_pos[var][scope]
+            self.rect(x, y, method=method)
+            chain.x, chain.y = x, y
+            if prev is not None and prev[0] is not None:
+                if (sparse_conn_mode is True) and (method == "set"):
+                    pass
+                else:
+                    if y == prev[1]:
+                        self.line(prev[0], x, prev[1])
+                    else:
+                        self.step(prev[0], x, prev[1], y)
+        else:
+            x, y = chain.x, chain.y
+
+        x_max = None
+        if chain.successor is not None:
+            for e in chain.successor:
+                if e.stage == stage:
+                    shift = self.width + self.space_intern_x if chain.stage == e.stage else 0
+                    x_tmp = self.plot_track_chain(e, y_pos, x_pos + shift, prev=(x, y),
+                                                  stage=stage, sparse_conn_mode=sparse_conn_mode)
+                    x_max = np.nanmax(np.array([x_tmp, x_max], dtype=np.float64))
+                else:
+                    x_max = np.nanmax(np.array([x, x_max, x_pos], dtype=np.float64))
+        else:
+            x_max = x
+
+        return x_max
+
+    def add_variable_names(self):
+        labels = []
+        pos = []
+        labels_major = []
+        pos_major = []
+        for k, v in self.y_pos.items():
+            for kv, vv in v.items():
+                if kv == "general":
+                    labels_major.append(k)
+                    pos_major.append(vv + self.height / 2)
+                else:
+                    labels.append(kv.split(".", 1)[1])
+                    pos.append(vv + self.height / 2)
+        self.ax.tick_params(axis="y", which="major", labelsize="large")
+        self.ax.yaxis.set_major_locator(ticker.FixedLocator(pos_major))
+        self.ax.yaxis.set_major_formatter(ticker.FixedFormatter(labels_major))
+        self.ax.yaxis.set_minor_locator(ticker.FixedLocator(pos))
+        self.ax.yaxis.set_minor_formatter(ticker.FixedFormatter(labels))
+
+    def add_stages(self, vlines, stages):
+        x_max = self.x_max + self.space_intern_x + self.width
+        for l in vlines:
+            self.ax.vlines(l, *self.anchor, "black", "dashed")
+        vlines = [0] + vlines + [x_max]
+        pos = [(vlines[i] + vlines[i+1]) / 2 for i in range(len(vlines)-1)]
+        self.ax.xaxis.set_major_locator(ticker.FixedLocator(pos))
+        self.ax.xaxis.set_major_formatter(ticker.FixedFormatter(stages))
+
+    def create_track_chain(self, tracker_list, scopes):
+        control = self.control_dict(scopes)
+        track_chain_dict = OrderedDict()
+        for track_dict in tracker_list:
+            stage, stage_track = list(track_dict.items())[0]
+            track_chain, control = create_track_chain(control, OrderedDict(sorted(stage_track.items())), stage)
+            control = self.clean_control(control)
+            track_chain_dict[stage] = track_chain
+        return track_chain_dict
+
+    def create_track_chain_plot(self, track_chain_dict, sparse_conn_mode=True, skip_run_env=True):
+        x, x_max = 0, 0
+        v_lines, stages = [], []
+        for stage, track_chain in track_chain_dict.items():
+            if stage == "RunEnvironment" and skip_run_env is True:
+                continue
+            if x > 0:
+                v_lines.append(x - self.space_extern_x / 2)
+            for e in track_chain:
+                x_max = max(x_max, self.plot_track_chain(e, self.y_pos, x_pos=x, stage=stage, sparse_conn_mode=sparse_conn_mode))
+            x = x_max + self.space_extern_x + self.width
+            stages.append(stage)
+        self.x_max = x_max
+        return stages, v_lines
+
+    def set_lims(self):
+        x_max = self.x_max + self.space_intern_x + self.width
+        self.ax.set_xlim((0, x_max))
+        self.ax.set_ylim(self.anchor)
+
+    @staticmethod
+    def control_dict(scopes):
+        control = {}
+        for k, v in scopes.items():
+            control[k] = {}
+            for e in v:
+                update = {e: None}
+                if len(control[k].keys()) == 0:
+                    control[k] = update
+                else:
+                    control[k].update(update)
+        return control
+
+    @staticmethod
+    def clean_control(control):
+        for k, v in control.items():
+            for kv, vv in v.items():
+                try:
+                    if vv.precursor[0].name[2] != vv.name[2]:
+                        control[k][kv] = None
+                except (TypeError, AttributeError):
+                    pass
+        return control
+
+    @staticmethod
+    def get_scopes(track_list):
+        dims = {}
+        for track_dict in track_list:
+            for track in track_dict.values():
+                for k, v in track.items():
+                    scopes = get_dim_scope(v)
+                    if dims.get(k) is None:
+                        dims[k] = scopes
+                    else:
+                        dims[k] = np.unique(scopes + dims[k]).tolist()
+        return OrderedDict(sorted(dims.items()))
+
+    @staticmethod
+    def get_dims(scopes):
+        dims = {}
+        for k, v in scopes.items():
+            dims[k] = len(v)
+        return dims
+
+
+def get_dim_scope(track_list):
+    scopes = [e["scope"] for e in track_list] + ["general"]
+    return np.unique(scopes).tolist()
-- 
GitLab