diff --git a/mlair/plotting/postprocessing_plotting.py b/mlair/plotting/postprocessing_plotting.py
index 2484d5ddeac883243f1920c9904f07a7c93bb2cb..77c12c43cf20b3176d9c0e5f07feb186eb368150 100644
--- a/mlair/plotting/postprocessing_plotting.py
+++ b/mlair/plotting/postprocessing_plotting.py
@@ -25,6 +25,7 @@ from mlair.helpers import TimeTrackingWrapper
 from mlair.plotting.abstract_plot_class import AbstractPlotClass
 from mlair.helpers.statistics import mann_whitney_u_test, represent_p_values_as_asteriks
 
+
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 
 
@@ -1234,50 +1235,31 @@ class PlotSampleUncertaintyFromBootstrap(AbstractPlotClass):  # pragma: no cover
         return ax
 
 
+@TimeTrackingWrapper
 class PlotTimeEvolutionMetric(AbstractPlotClass):
 
-    # mse_pandas = block_mse_per_station.sel(ahead=1, type="nn").to_pandas()
-    #
-    # mse_pandas.columns = mse_pandas.columns.strftime("%b %Y")
-    # years = mse_pandas.columns.strftime("%Y").to_list()
-    # months = mse_pandas.columns.strftime("%b").to_list()
-    #
-    # fig, ax = plt.subplots()
-    # sns.heatmap(mse_pandas, linewidths=1, square=True, cmap="coolwarm", ax=ax)
-    # locs = ax.get_xticks(minor=False).tolist()
-    # # ax.xaxis.set_major_locator(NullLocator())
-    # # ax.xaxis.set_minor_locator(FixedLocator(locs))
-    #
-    # ax.set_xticks(locs, minor=True)
-    # ax.set_xticklabels([m[0] for m in months], minor=True, rotation=0)
-    #
-    # locs_major = []
-    # labels_major = []
-    # for l, major, minor in zip(locs, years, months):
-    #     if minor == "Jan":
-    #         locs_major.append(l+0.001)
-    #         labels_major.append(major)
-    #
-    # ax.set_xticks(locs_major)
-    # ax.set_xticklabels(labels_major, minor=False, rotation=0)
-    # ax.tick_params(axis="x", which="major", pad=15)
-
-
-    def __init__(self, data: xr.DataArray, ahead_dim="ahead", model_type_dim="type"):
-
-        data = data.sel(ahead=1, type="nn").to_pandas()
-        years = data.columns.strftime("%Y").to_list()
-        months = data.columns.strftime("%b").to_list()
-        data.columns = data.columns.strftime("%b %Y")
-        self._plot(data)
-
-        pass
-
-    def prepare_data(self, data, ahead_dim, model_type_dim):
-        pass
-
-    def _set_ticks(self, ax, years, months):
-        locs = ax.get_xticks(minor=False).tolist()
+    def __init__(self, data: xr.DataArray, ahead_dim="ahead", model_type_dim="type", plot_folder=".",
+                 error_measure: str = "mse", error_unit: str = None,):
+        super().__init__(plot_folder, "time_evolution_mse")
+        self.title = error_measure + f" (in {error_unit})" if error_unit is not None else ""
+        plot_name = self.plot_name
+        vmin = int(data.quantile(0.05))
+        vmax = int(data.quantile(0.95))
+        for t in data[model_type_dim]:
+            # note: could be expanded to create plot per ahead step
+            print(data.sel({model_type_dim: t}).mean((ahead_dim, "index")))
+            plot_data = data.sel({model_type_dim: t}).mean(ahead_dim).to_pandas()
+            years = plot_data.columns.strftime("%Y").to_list()
+            months = plot_data.columns.strftime("%b").to_list()
+            plot_data.columns = plot_data.columns.strftime("%b %Y")
+            self.plot_name = f"{plot_name}_{t.values}"
+            self._plot(plot_data, years, months, vmin, vmax)
+
+    @staticmethod
+    def _set_ticks(ax, years, months):
+        from matplotlib.ticker import IndexLocator
+        ax.xaxis.set_major_locator(IndexLocator(1, 0.5))
+        locs = ax.get_xticks(minor=False).tolist()[:len(months)]
         ax.set_xticks(locs, minor=True)
         ax.set_xticklabels([m[0] for m in months], minor=True, rotation=0)
         locs_major = []
@@ -1293,11 +1275,17 @@ class PlotTimeEvolutionMetric(AbstractPlotClass):
         ax.set_xticklabels(labels_major, minor=False, rotation=0)
         ax.tick_params(axis="x", which="major", pad=15)
 
-    def _plot(self, data, years, months):
-        fig, ax = plt.subplots()
-        sns.heatmap(data, linewidths=1, square=True, cmap="coolwarm", ax=ax, robust=True)
+    def _plot(self, data, years, months, vmin=None, vmax=None):
+        fig, ax = plt.subplots(figsize=(data.shape[1] / 5, data.shape[0] / 2.8))
+        sns.heatmap(data, linewidths=1, cmap="coolwarm", ax=ax, vmin=vmin, vmax=vmax, cbar_kws={"aspect": 10})
         # or cmap="Spectral_r", cmap="RdYlBu_r", cmap="coolwarm",
+        # square=True
         self._set_ticks(ax, years, months)
+        ax.set_xlabel(None)
+        ax.set_ylabel(None)
+        plt.title(self.title)
+        plt.tight_layout()
+        self._save()
 
 
 if __name__ == "__main__":
diff --git a/mlair/run_modules/post_processing.py b/mlair/run_modules/post_processing.py
index ca0a1143e80cf9b47ac6795add62d58c6dc9f69c..ce08fbdaa8a2ef89108dfa37d0af51e714e5712b 100644
--- a/mlair/run_modules/post_processing.py
+++ b/mlair/run_modules/post_processing.py
@@ -655,9 +655,9 @@ class PostProcessing(RunEnvironment):
 
         try:
             if "PlotTimeEvolutionMetric" in plot_list:
-                PlotTimeEvolutionMetric(self.block_mse_per_station, model_type_dim=self.model_type_dim, ahead_dim=self.ahead_dim)
-                #plot_folder=self.plot_path,
-
+                PlotTimeEvolutionMetric(self.block_mse_per_station, plot_folder=self.plot_path,
+                                        model_type_dim=self.model_type_dim, ahead_dim=self.ahead_dim,
+                                        error_measure="Mean Squared Error", error_unit=r"ppb$^2$")
         except Exception as e:
             logging.error(f"Could not create plot PlotTimeEvolutionMetric due to the following error: {e}"
                           f"\n{sys.exc_info()[0]}\n{sys.exc_info()[1]}\n{sys.exc_info()[2]}")