Skip to content
Snippets Groups Projects
Select Git revision
  • dc8427b0d738c514c2cb6c4d622be4ed490afd52
  • master default protected
  • enxhi_issue460_remove_TOAR-I_access
  • michael_issue459_preprocess_german_stations
  • sh_pollutants
  • develop protected
  • release_v2.4.0
  • michael_issue450_feat_load-ifs-data
  • lukas_issue457_feat_set-config-paths-as-parameter
  • lukas_issue454_feat_use-toar-statistics-api-v2
  • lukas_issue453_refac_advanced-retry-strategy
  • lukas_issue452_bug_update-proj-version
  • lukas_issue449_refac_load-era5-data-from-toar-db
  • lukas_issue451_feat_robust-apriori-estimate-for-short-timeseries
  • lukas_issue448_feat_load-model-from-path
  • lukas_issue447_feat_store-and-load-local-clim-apriori-data
  • lukas_issue445_feat_data-insight-plot-monthly-distribution
  • lukas_issue442_feat_bias-free-evaluation
  • lukas_issue444_feat_choose-interp-method-cams
  • 414-include-crps-analysis-and-other-ens-verif-methods-or-plots
  • lukas_issue384_feat_aqw-data-handler
  • v2.4.0 protected
  • v2.3.0 protected
  • v2.2.0 protected
  • v2.1.0 protected
  • Kleinert_etal_2022_initial_submission
  • v2.0.0 protected
  • v1.5.0 protected
  • v1.4.0 protected
  • v1.3.0 protected
  • v1.2.1 protected
  • v1.2.0 protected
  • v1.1.0 protected
  • IntelliO3-ts-v1.0_R1-submit
  • v1.0.0 protected
  • v0.12.2 protected
  • v0.12.1 protected
  • v0.12.0 protected
  • v0.11.0 protected
  • v0.10.0 protected
  • IntelliO3-ts-v1.0_initial-submit
41 results

test_datastore.py

Blame
  • test_training_monitoring.py 3.28 KiB
    import os
    import shutil

    import keras
    import pytest
    
    from src.plotting.training_monitoring import PlotModelLearningRate, PlotModelHistory
    from src.helpers import LearningRateDecay
    
    
    @pytest.fixture
    def path():
        """Provide the ``TestExperiment`` output directory located next to this test file.

        The directory is created on demand before the test and removed again afterwards.
        The plot tests assert that their output file does NOT exist before plotting, so
        without this teardown every run after the first would fail on that precondition.
        """
        p = os.path.join(os.path.dirname(__file__), "TestExperiment")
        # exist_ok avoids the check-then-create race of a separate os.path.exists() call
        os.makedirs(p, exist_ok=True)
        yield p
        # remove generated plots so the "file not yet present" assertions hold on reruns
        shutil.rmtree(p, ignore_errors=True)
    
    
    class TestPlotModelHistory:
        """Tests for ``PlotModelHistory``: plotting from a History object or its dict, and loss filtering."""

        @staticmethod
        def _assert_plot_created(directory, file_name, source):
            """Plot *source* to ``directory/file_name`` and check that this call created the file."""
            assert file_name not in os.listdir(directory)
            PlotModelHistory(os.path.join(directory, file_name), source)
            assert file_name in os.listdir(directory)

        @pytest.fixture
        def history(self):
            """Two-epoch keras ``History`` with train/val loss, MSE and MAE entries."""
            val_err = [0.5586272982587484, 0.45712877659670287]
            train_err = [0.6795708956961347, 0.45963566494176616]
            record = keras.callbacks.History()
            record.epoch = [0, 1]
            # list(...) copies keep every key backed by its own list object, as in a literal dict
            record.history = {
                'val_loss': list(val_err),
                'val_mean_squared_error': list(val_err),
                'val_mean_absolute_error': [0.595368885413389, 0.530547587585537],
                'loss': list(train_err),
                'mean_squared_error': list(train_err),
                'mean_absolute_error': [0.6523177288928538, 0.5363963260296364],
            }
            return record

        @pytest.fixture
        def history_var(self):
            """Two-epoch keras ``History`` that carries an extra ``test_loss`` series."""
            train_err = [0.6795708956961347, 0.45963566494176616]
            record = keras.callbacks.History()
            record.epoch = [0, 1]
            record.history = {
                'val_loss': [0.5586272982587484, 0.45712877659670287],
                'test_loss': [0.595368885413389, 0.530547587585537],
                'loss': list(train_err),
                'mean_squared_error': list(train_err),
                'mean_absolute_error': [0.6523177288928538, 0.5363963260296364],
            }
            return record

        @pytest.fixture
        def no_init(self):
            """``PlotModelHistory`` instance allocated without running ``__init__`` (no plotting)."""
            bare_instance = object.__new__(PlotModelHistory)
            return bare_instance

        def test_plot_from_hist_obj(self, history, path):
            self._assert_plot_created(path, "hist_obj.pdf", history)

        def test_plot_from_hist_dict(self, history, path):
            self._assert_plot_created(path, "hist_dict.pdf", history.history)

        def test_plot_additional_loss(self, history_var, path):
            self._assert_plot_created(path, "hist_additional.pdf", history_var)

        def test_filter_list(self, no_init):
            # every value is None; only the key names matter to the filter
            columns = dict.fromkeys(['loss', 'another_loss', 'val_loss', 'wrong'])
            assert no_init._filter_columns(columns) == ['another_loss']
    
    
    class TestPlotModelLearningRate:
        """Tests for ``PlotModelLearningRate`` fed either a ``LearningRateDecay`` or its raw ``lr`` dict."""

        @staticmethod
        def _assert_plot_created(directory, file_name, source):
            """Plot *source* to ``directory/file_name`` and check that this call created the file."""
            assert file_name not in os.listdir(directory)
            PlotModelLearningRate(os.path.join(directory, file_name), source)
            assert file_name in os.listdir(directory)

        @pytest.fixture
        def learning_rate(self):
            """``LearningRateDecay`` whose ``lr`` record is replaced by a fixed 20-entry schedule."""
            # ten entries at 0.01 followed by ten entries at 0.0094
            schedule = [0.01] * 10 + [0.0094] * 10
            decay = LearningRateDecay()
            decay.lr = {"lr": schedule}
            return decay

        def test_plot_from_lr_obj(self, learning_rate, path):
            self._assert_plot_created(path, "lr_obj.pdf", learning_rate)

        def test_plot_from_lr_dict(self, learning_rate, path):
            self._assert_plot_created(path, "lr_dict.pdf", learning_rate.lr)