diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py
index 74df5f6ac1c998e644fa7d89a688fc12dee21265..c6f38a6f0e70518956bcbbd51a6fdfc1a1e7849f 100644
--- a/src/data_handling/data_distributor.py
+++ b/src/data_handling/data_distributor.py
@@ -45,7 +45,7 @@ class Distributor(keras.utils.Sequence):
                 for prev, curr in enumerate(range(1, num_mini_batches+1)):
                     x = x_total[prev*self.batch_size:curr*self.batch_size, ...]
                     y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)]
-                    if x is not None:
+                    if x is not None:  # pragma: no branch
                         yield (x, y)
                         if (k + 1) == len(self.generator) and curr == num_mini_batches and not fit_call:
                             return
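Reviewer note (not part of the patch): the `data_generator.py` changes below cache each fully preprocessed `DataPrep` object as a pickle file under `<data_path>/tmp` and only fall back to the expensive interpolate/transform/history pipeline on a cache miss. A minimal, self-contained sketch of that try/except caching pattern follows; `load_or_create`, `cache_file` and `create` are illustrative names, not identifiers from this codebase:

```python
import pickle
import logging
from typing import Any, Callable


def load_or_create(cache_file: str, create: Callable[[], Any], use_cache: bool = True) -> Any:
    """Return the pickled object in cache_file; rebuild and re-pickle it on a cache miss."""
    try:
        if not use_cache:
            raise FileNotFoundError  # force the rebuild branch, mirroring local_tmp_storage=False
        with open(cache_file, "rb") as f:
            data = pickle.load(f)
        logging.debug(f"load pickle data from {cache_file}")
    except FileNotFoundError:
        data = create()  # the expensive preprocessing only runs on a miss
        with open(cache_file, "wb") as f:
            pickle.dump(data, f)
        logging.debug(f"save pickle data to {cache_file}")
    return data
```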
diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py
index 0e60c8916c62c10929fad72b2271b0e7f54dee4a..19a94fbb9dbbc8f382a225c852f34971a98395b8 100644
--- a/src/data_handling/data_generator.py
+++ b/src/data_handling/data_generator.py
@@ -2,10 +2,12 @@ __author__ = 'Felix Kleinert, Lukas Leufen'
 __date__ = '2019-11-07'
 
 import os
-from typing import Union, List, Tuple
+from typing import Union, List, Tuple, Any
 
 import keras
 import xarray as xr
+import pickle
+import logging
 
 from src import helpers
 from src.data_handling.data_preparation import DataPrep
@@ -25,6 +27,9 @@ class DataGenerator(keras.utils.Sequence):
                  interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7,
                  window_lead_time: int = 4, transform_method: str = "standardise", **kwargs):
         self.data_path = os.path.abspath(data_path)
+        self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp")
+        if not os.path.exists(self.data_path_tmp):
+            os.makedirs(self.data_path_tmp)
         self.network = network
         self.stations = helpers.to_list(stations)
         self.variables = variables
@@ -72,11 +77,11 @@ class DataGenerator(keras.utils.Sequence):
         if self._iterator < self.__len__():
             data = self.get_data_generator()
             self._iterator += 1
-            if data.history is not None and data.label is not None:
+            if data.history is not None and data.label is not None:  # pragma: no branch
                 return data.history.transpose("datetime", "window", "Stations", "variables"), \
                        data.label.squeeze("Stations").transpose("datetime", "window")
             else:
-                self.__next__()
+                self.__next__()  # pragma: no cover
         else:
             raise StopIteration
 
@@ -87,24 +92,61 @@ class DataGenerator(keras.utils.Sequence):
         :return: The generator's time series of history data and its labels
         """
         data = self.get_data_generator(key=item)
-        return data.history.transpose("datetime", "window", "Stations", "variables"), \
-               data.label.squeeze("Stations").transpose("datetime", "window")
+        return data.get_transposed_history(), data.label.squeeze("Stations").transpose("datetime", "window")
 
-    def get_data_generator(self, key: Union[str, int] = None) -> DataPrep:
+    def get_data_generator(self, key: Union[str, int] = None, local_tmp_storage: bool = True) -> DataPrep:
         """
         Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and
         remove nans.
         :param key: station key to choose the data generator.
+        :param local_tmp_storage: if True, load the already processed data from a local tmp pickle file instead of
+            processing it from scratch; this saves computation time but requires additional disk space.
         :return: preprocessed data as a DataPrep instance
         """
         station = self.get_station_key(key)
-        data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
-                        **self.kwargs)
-        data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
-        data.transform("datetime", method=self.transform_method)
-        data.make_history_window(self.interpolate_dim, self.window_history_size)
-        data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
-        data.history_label_nan_remove(self.interpolate_dim)
+        try:
+            if not local_tmp_storage:
+                raise FileNotFoundError
+            data = self._load_pickle_data(station, self.variables)
+        except FileNotFoundError:
+            logging.info(f"could not load pickle data for {station}, preprocess from scratch")
+            data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type,
+                            **self.kwargs)
+            data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill)
+            data.transform("datetime", method=self.transform_method)
+            data.make_history_window(self.interpolate_dim, self.window_history_size)
+            data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time)
+            data.history_label_nan_remove(self.interpolate_dim)
+            self._save_pickle_data(data)
+        return data
+
+    def _save_pickle_data(self, data: Any):
+        """
+        Save given data locally as .pickle in self.data_path_tmp with name '<station>_<var1>_..._<varX>_<start>_<end>_.pickle'.
+        :param data: any data that should be saved
+        """
+        date = f"{self.kwargs.get('start')}_{self.kwargs.get('end')}"
+        vars = '_'.join(sorted(data.variables))
+        station = ''.join(data.station)
+        file = os.path.join(self.data_path_tmp, f"{station}_{vars}_{date}_.pickle")
+        with open(file, "wb") as f:
+            pickle.dump(data, f)
+        logging.debug(f"save pickle data to {file}")
+
+    def _load_pickle_data(self, station: Union[str, List[str]], variables: List[str]) -> Any:
+        """
+        Load locally saved data from self.data_path_tmp with name '<station>_<var1>_..._<varX>_<start>_<end>_.pickle'.
+        :param station: station to load
+        :param variables: list of variables to load
+        :return: loaded data
+        """
+        date = f"{self.kwargs.get('start')}_{self.kwargs.get('end')}"
+        vars = '_'.join(sorted(variables))
+        station = ''.join(station)
+        file = os.path.join(self.data_path_tmp, f"{station}_{vars}_{date}_.pickle")
+        with open(file, "rb") as f:
+            data = pickle.load(f)
+        logging.debug(f"load pickle data from {file}")
         return data
 
     def get_station_key(self, key: Union[None, str, int, List[Union[None, str, int]]]) -> str:
diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py
index 1ae5f0ae8ec77c9e5e0b6776d4e17f1dff412286..5bca71f52c9f136b5910d4e080491e0ff86484ae 100644
--- a/src/data_handling/data_preparation.py
+++ b/src/data_handling/data_preparation.py
@@ -109,7 +109,7 @@ class DataPrep(object):
         check_dict = {"station_type": self.station_type, "network_name": self.network}
         for (k, v) in check_dict.items():
             if self.meta.at[k, self.station[0]] != v:
-                logging.debug(f"meta data does not agree which given request for {k}: {v} (requested) != "
+                logging.debug(f"meta data does not agree with given request for {k}: {v} (requested) != "
                               f"{self.meta.at[k, self.station[0]]} (local). Raise FileNotFoundError to trigger new "
                               f"grabbing from web.")
                 raise FileNotFoundError
@@ -386,6 +386,10 @@ class DataPrep(object):
             data.loc[..., used_chem_vars] = data.loc[..., used_chem_vars].clip(min=minimum)
         return data
 
+    def get_transposed_history(self):
+        if self.history is not None:
+            return self.history.transpose("datetime", "window", "Stations", "variables")
+
 
 if __name__ == "__main__":
     dp = DataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'], statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py
index 6ae8573d11375335edda48a1bf26f30d6125bd69..0e63125d921029d28a672a5e5e8d0ecd2995d050 100644
--- a/src/plotting/postprocessing_plotting.py
+++ b/src/plotting/postprocessing_plotting.py
@@ -141,11 +141,11 @@ class PlotStationMap(RunEnvironment):
         """
         Draw coastline, lakes, ocean, rivers and country borders as background on the map.
         """
-        self._ax.add_feature(cfeature.COASTLINE.with_scale("10m"), edgecolor='black')
+        self._ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black')
         self._ax.add_feature(cfeature.LAKES.with_scale("50m"))
         self._ax.add_feature(cfeature.OCEAN.with_scale("50m"))
-        self._ax.add_feature(cfeature.RIVERS.with_scale("10m"))
-        self._ax.add_feature(cfeature.BORDERS.with_scale("10m"), facecolor='none', edgecolor='black')
+        self._ax.add_feature(cfeature.RIVERS.with_scale("50m"))
+        self._ax.add_feature(cfeature.BORDERS.with_scale("50m"), facecolor='none', edgecolor='black')
 
     def _plot_stations(self, generators):
         """
@@ -477,3 +477,76 @@ class PlotCompetitiveSkillScore(RunEnvironment):
         logging.debug(f"... save plot to {plot_name}")
         plt.savefig(plot_name, dpi=500)
         plt.close()
+
+
+class PlotTimeSeries(RunEnvironment):
+
+    def __init__(self, stations: List, data_path: str, name: str, window_lead_time: int = None, plot_folder: str = "."):
+        super().__init__()
+        self._data_path = data_path
+        self._data_name = name
+        self._stations = stations
+        self._window_lead_time = self._get_window_lead_time(window_lead_time)
+        self._plot(plot_folder)
+
+    def _get_window_lead_time(self, window_lead_time: int):
+        """
+        Extract the lead time from data and arguments. If window_lead_time is not given, extract this information from
+        data itself by the number of ahead dimensions. If given, check if data supports the given length. If the number
+        of ahead dimensions in data is lower than the given lead time, data's lead time is used.
+        :param window_lead_time: lead time from arguments to validate
+        :return: validated lead time, comes either from given argument or from data itself
+        """
+        ahead_steps = len(self._load_data(self._stations[0]).ahead)
+        if window_lead_time is None:
+            window_lead_time = ahead_steps
+        return min(ahead_steps, window_lead_time)
+
+    def _load_data(self, station):
+        logging.debug(f"... preprocess station {station}")
+        file_name = os.path.join(self._data_path, self._data_name % station)
+        data = xr.open_dataarray(file_name)
+        return data.sel(type=["CNN", "orig"])
+
+    def _plot(self, plot_folder):
+        pdf_pages = self._save_pdf_pages(plot_folder)
+        start, end = self._get_time_range(self._load_data(self._stations[0]))
+        color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex()
+        for pos, station in enumerate(self._stations):
+            data = self._load_data(station)
+            f, axes = plt.subplots(end - start + 1, sharey=True, figsize=(40, 20))
+            nan_list = []
+            for i in range(end - start + 1):
+                data_year = data.sel(index=f"{start + i}")
+                orig_data = data_year.sel(type="orig", ahead=1).values
+                axes[i].plot(data_year.index + np.timedelta64(1, "D"), orig_data, color=color_palette[0], label="orig")
+                for ahead in data.coords["ahead"].values:
+                    plot_data = data_year.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze()
+                    axes[i].plot(plot_data.index + np.timedelta64(int(ahead), "D"), plot_data.values, color=color_palette[ahead], label=f"{ahead}d")
+                if np.isnan(orig_data).all():
+                    nan_list.append(i)
+            for i in reversed(nan_list):
+                f.delaxes(axes[i])
+
+            plt.suptitle(station)
+            plt.legend()
+            plt.tight_layout()
+            pdf_pages.savefig(dpi=500)
+        pdf_pages.close()
+        plt.close('all')
+
+    @staticmethod
+    def _get_time_range(data):
+        def f(x, f_x):
+            return pd.to_datetime(f_x(x.index.values)).year
+        return f(data, min), f(data, max)
+
+    @staticmethod
+    def _save_pdf_pages(plot_folder):
+        """
+        Create and return a PdfPages object to store the plot pages locally. The name of the plot file is static.
+        :param plot_folder: path to save the plot
+        """
+        plot_name = os.path.join(os.path.abspath(plot_folder), 'timeseries_plot.pdf')
+        logging.debug(f"... save plot to {plot_name}")
+        return matplotlib.backends.backend_pdf.PdfPages(plot_name)
diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py
index d1a22885325dae0483fae2a2e6493a391c4596b0..fdc691c33e40acd7f6b6c9ca9e80acaa33d9e055 100644
--- a/src/run_modules/post_processing.py
+++ b/src/run_modules/post_processing.py
@@ -17,7 +17,7 @@ from src.datastore import NameNotFoundInDataStore
 from src.helpers import TimeTracking
 from src.model_modules.linear_model import OrdinaryLeastSquaredModel
 from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \
-    PlotCompetitiveSkillScore
+    PlotCompetitiveSkillScore, PlotTimeSeries
 from src.plotting.postprocessing_plotting import plot_conditional_quantiles
 from src.run_modules.run_environment import RunEnvironment
 
@@ -44,7 +44,7 @@ class PostProcessing(RunEnvironment):
         logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
                      "skip make_prediction() whenever it is possible to save time.")
         with TimeTracking():
-            preds_for_all_stations = self.make_prediction()
+            self.make_prediction()
         logging.info("take a look on the next reported time measure. If this increases a lot, one should think to "
                      "skip make_prediction() whenever it is possible to save time.")
         self.skill_scores = self.calculate_skill_scores()
@@ -76,6 +76,7 @@ class PostProcessing(RunEnvironment):
             PlotClimatologicalSkillScore(self.skill_scores[1], plot_folder=self.plot_path, score_only=False,
                                          extra_name_tag="all_terms_", model_setup="CNN")
             PlotCompetitiveSkillScore(self.skill_scores[0], plot_folder=self.plot_path, model_setup="CNN")
+            PlotTimeSeries(self.test_data.stations, path, r"forecasts_%s_test.nc", plot_folder=self.plot_path)
 
     def calculate_test_score(self):
         test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(),
@@ -94,12 +95,11 @@ class PostProcessing(RunEnvironment):
 
     def make_prediction(self, freq="1D"):
         logging.debug("start make_prediction")
-        nn_prediction_all_stations = []
-        for i, v in enumerate(self.test_data):
+        for i, _ in enumerate(self.test_data):
             data = self.test_data.get_data_generator(i)
             nn_prediction, persistence_prediction, ols_prediction = self._create_empty_prediction_arrays(data, count=3)
-            input_data = self.test_data[i][0]
+            input_data = data.get_transposed_history()
 
             # get scaling parameters
             mean, std, transformation_method = data.get_transformation_information(variable=self.target_var)
@@ -130,10 +130,6 @@ class PostProcessing(RunEnvironment):
                 file = os.path.join(path, f"forecasts_{data.station[0]}_test.nc")
                 all_predictions.to_netcdf(file)
 
-            # save nn forecast to return variable
-            nn_prediction_all_stations.append(nn_prediction)
-        return nn_prediction_all_stations
-
     @staticmethod
     def _create_orig_forecast(data, _, mean, std, transformation_method):
         return statistics.apply_inverse_transformation(data.label.copy(), mean, std, transformation_method)
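Reviewer note (not part of the patch): the cache file written by `_save_pickle_data` above is named `<station>_<var1>_..._<varX>_<start>_<end>_.pickle`, where `start`/`end` come from `self.kwargs` and serialize as the literal string `None` when absent; the tests further down therefore expect a `_None_None_` suffix. A short sketch of that naming convention; `pickle_file_name` is a hypothetical helper that mirrors the patch's logic, not part of the codebase:

```python
import os
from typing import List


def pickle_file_name(tmp_path: str, station: str, variables: List[str], start=None, end=None) -> str:
    # mirrors f"{station}_{vars}_{date}_.pickle" from DataGenerator._save_pickle_data
    date = f"{start}_{end}"
    var_str = '_'.join(sorted(variables))
    return os.path.join(tmp_path, f"{station}_{var_str}_{date}_.pickle")


# variables are sorted, and a missing time range serializes as "None_None"
assert pickle_file_name("tmp", "DEBW107", ["temp", "o3"]) == os.path.join("tmp", "DEBW107_o3_temp_None_None_.pickle")
```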
diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py
index 9eef12f5dafbf2930ed22337c46281954de55b3c..4660a8116b6d0b860a7d0d50b92cee5e0deb77d8 100644
--- a/src/run_modules/pre_processing.py
+++ b/src/run_modules/pre_processing.py
@@ -35,7 +35,7 @@ class PreProcessing(RunEnvironment):
     def _run(self):
         args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope="general.preprocessing")
         kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope="general.preprocessing")
-        valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"))
+        valid_stations = self.check_valid_stations(args, kwargs, self.data_store.get("stations", "general"), load_tmp=False)
         self.data_store.set("stations", valid_stations, "general")
         self.split_train_val_test()
         self.report_pre_processing()
@@ -96,7 +96,7 @@ class PreProcessing(RunEnvironment):
         self.data_store.set("generator", data_set, scope)
 
     @staticmethod
-    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str]):
+    def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True):
         """
         Check if all given stations in `all_stations` are valid. Valid means, that there is data available for the
         given time range (is included in `kwargs`). The shape and the loading time are logged in debug mode.
@@ -117,9 +117,9 @@ class PreProcessing(RunEnvironment):
         for station in all_stations:
             t_inner.run()
             try:
-                (history, label) = data_gen[station]
+                data = data_gen.get_data_generator(key=station, local_tmp_storage=load_tmp)
                 valid_stations.append(station)
-                logging.debug(f"{station}: history_shape = {history.shape}")
+                logging.debug(f"{station}: history_shape = {data.get_transposed_history().shape}")
                 logging.debug(f"{station}: loading time = {t_inner}")
             except (AttributeError, EmptyQueryResult):
                 continue
diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py
index 2f741a1aca241767681f23f2e0912a28103b24b7..ad112549d833e0fda9eae3c238b106d4b0215e79 100644
--- a/test/test_data_handling/test_data_generator.py
+++ b/test/test_data_handling/test_data_generator.py
@@ -2,7 +2,11 @@ import os
 
 import pytest
+import shutil
+import numpy as np
+import pickle
 
 from src.data_handling.data_generator import DataGenerator
+from src.data_handling.data_preparation import DataPrep
 
 
 class TestDataGenerator:
@@ -18,6 +22,12 @@ class TestDataGenerator:
         return DataGenerator(os.path.join(os.path.dirname(__file__), 'data'), 'AIRBASE', 'DEBW107', ['o3', 'temp'],
                              'datetime', 'variables', 'o3')
 
+    class DummyDataPrep:
+        def __init__(self, data):
+            self.station = "DEBW107"
+            self.variables = ["o3", "temp"]
+            self.data = data
+
     def test_init(self, gen):
         assert gen.data_path == os.path.join(os.path.dirname(__file__), 'data')
         assert gen.network == 'AIRBASE'
@@ -45,15 +55,6 @@ class TestDataGenerator:
         gen.stations = ['station1', 'station2', 'station3']
         assert len(gen) == 3
 
-    def test_getitem(self, gen):
-        gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
-        station = gen["DEBW107"]
-        assert len(station) == 2
-        assert station[0].Stations.data == "DEBW107"
-        assert station[0].data.shape[1:] == (8, 1, 2)
-        assert station[1].data.shape[-1] == gen.window_lead_time
-        assert station[0].data.shape[1] == gen.window_history_size + 1
-
     def test_iter(self, gen):
         assert hasattr(gen, '_iterator') is False
         iter(gen)
@@ -65,6 +66,15 @@ class TestDataGenerator:
         for i, d in enumerate(gen, start=1):
             assert i == gen._iterator
 
+    def test_getitem(self, gen):
+        gen.kwargs = {'statistics_per_var': {'o3': 'dma8eu', 'temp': 'maximum'}}
+        station = gen["DEBW107"]
+        assert len(station) == 2
+        assert station[0].Stations.data == "DEBW107"
+        assert station[0].data.shape[1:] == (8, 1, 2)
+        assert station[1].data.shape[-1] == gen.window_lead_time
+        assert station[0].data.shape[1] == gen.window_history_size + 1
+
     def test_get_station_key(self, gen):
         gen.stations.append("DEBW108")
         f = gen.get_station_key
@@ -86,3 +96,38 @@ class TestDataGenerator:
         with pytest.raises(KeyError) as e:
             f(6.5)
         assert "key has to be from Union[str, int]. Given was 6.5 (float)" in e.value.args[0]
+
+    def test_get_data_generator(self, gen):
+        gen.kwargs = {"statistics_per_var": {'o3': 'dma8eu', 'temp': 'maximum'}}
+        file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}_None_None_.pickle")
+        if os.path.exists(file):
+            os.remove(file)
+        assert not os.path.exists(file)
+        assert isinstance(gen.get_data_generator("DEBW107", local_tmp_storage=False), DataPrep)
+        t = os.stat(file).st_ctime
+        assert os.path.exists(file)
+        assert isinstance(gen.get_data_generator("DEBW107"), DataPrep)
+        assert os.stat(file).st_mtime == t
+        os.remove(file)
+        assert isinstance(gen.get_data_generator("DEBW107"), DataPrep)
+        assert os.stat(file).st_ctime > t
+
+    def test_save_pickle_data(self, gen):
+        file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}_None_None_.pickle")
+        if os.path.exists(file):
+            os.remove(file)
+        assert not os.path.exists(file)
+        data = self.DummyDataPrep(np.ones((10, 2)))
+        gen._save_pickle_data(data)
+        assert os.path.exists(file)
+        os.remove(file)
+
+    def test_load_pickle_data(self, gen):
+        file = os.path.join(gen.data_path_tmp, f"DEBW107_{'_'.join(sorted(gen.variables))}_None_None_.pickle")
+        data = self.DummyDataPrep(np.ones((10, 2)))
+        with open(file, "wb") as f:
+            pickle.dump(data, f)
+        assert os.path.exists(file)
+        res = gen._load_pickle_data("DEBW107", ["o3", "temp"]).data
+        assert np.testing.assert_almost_equal(res, np.ones((10, 2))) is None
+        os.remove(file)
diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py
index 8b5968460a935638ba4d64d3e939eb32d6b49189..fd33cea96af5131e34ca17ae8c95bef9266f796e 100644
--- a/test/test_data_handling/test_data_preparation.py
+++ b/test/test_data_handling/test_data_preparation.py
@@ -1,6 +1,7 @@
 import datetime as dt
 import os
 from operator import itemgetter
+import logging
 
 import numpy as np
 import pandas as pd
@@ -19,6 +20,17 @@ class TestDataPrep:
                         station_type='background', test='testKWARGS',
                         statistics_per_var={'o3': 'dma8eu', 'temp': 'maximum'})
 
+    @pytest.fixture
+    def data_prep_no_init(self):
+        d = object.__new__(DataPrep)
+        d.path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
+        d.network = 'UBA'
+        d.station = ['DEBW107']
+        d.variables = ['o3', 'temp']
+        d.station_type = "background"
+        d.kwargs = None
+        return d
+
     def test_init(self, data):
         assert data.path == os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
         assert data.network == 'AIRBASE'
@@ -33,16 +45,79 @@ class TestDataPrep:
         with pytest.raises(NotImplementedError):
             DataPrep('data/', 'dummy', 'DEBW107', ['o3', 'temp'])
 
-    def test_repr(self):
-        d = object.__new__(DataPrep)
-        d.path = 'data/test'
-        d.network = 'dummy'
-        d.station = ['DEBW107']
-        d.variables = ['o3', 'temp']
-        d.station_type = "traffic"
-        d.kwargs = None
-        assert d.__repr__().rstrip() == "Dataprep(path='data/test', network='dummy', station=['DEBW107'], "\
-                                        "variables=['o3', 'temp'], station_type=traffic, **None)".rstrip()
+    def test_download_data(self, data_prep_no_init):
+        file_name = data_prep_no_init._set_file_name()
+        meta_file = data_prep_no_init._set_meta_file_name()
+        data_prep_no_init.kwargs = {"store_data_locally": False}
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        data_prep_no_init.download_data(file_name, meta_file)
+        assert isinstance(data_prep_no_init.data, xr.DataArray)
+
+    def test_download_data_from_join(self, data_prep_no_init):
+        file_name = data_prep_no_init._set_file_name()
+        meta_file = data_prep_no_init._set_meta_file_name()
+        data_prep_no_init.kwargs = {"store_data_locally": False}
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        xarr, meta = data_prep_no_init.download_data_from_join(file_name, meta_file)
+        assert isinstance(xarr, xr.DataArray)
+        assert isinstance(meta, pd.DataFrame)
+
+    def test_check_station_meta(self, caplog, data_prep_no_init):
+        caplog.set_level(logging.DEBUG)
+        file_name = data_prep_no_init._set_file_name()
+        meta_file = data_prep_no_init._set_meta_file_name()
+        data_prep_no_init.kwargs = {"store_data_locally": False}
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        data_prep_no_init.download_data(file_name, meta_file)
+        assert data_prep_no_init.check_station_meta() is None
+        data_prep_no_init.station_type = "traffic"
+        with pytest.raises(FileNotFoundError) as e:
+            data_prep_no_init.check_station_meta()
+        msg = "meta data does not agree with given request for station_type: traffic (requested) != background (local)"
+        assert caplog.record_tuples[-1][:-1] == ('root', 10)
+        assert msg in caplog.record_tuples[-1][-1]
+
+    def test_load_data_overwrite_local_data(self, data_prep_no_init):
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        file_path = data_prep_no_init._set_file_name()
+        meta_file_path = data_prep_no_init._set_meta_file_name()
+        os.remove(file_path)
+        os.remove(meta_file_path)
+        assert not os.path.exists(file_path)
+        assert not os.path.exists(meta_file_path)
+        data_prep_no_init.kwargs = {"overwrite_local_data": True}
+        data_prep_no_init.load_data()
+        assert os.path.exists(file_path)
+        assert os.path.exists(meta_file_path)
+        t = os.stat(file_path).st_ctime
+        tm = os.stat(meta_file_path).st_ctime
+        data_prep_no_init.load_data()
+        assert os.path.exists(file_path)
+        assert os.path.exists(meta_file_path)
+        assert os.stat(file_path).st_ctime > t
+        assert os.stat(meta_file_path).st_ctime > tm
+        assert isinstance(data_prep_no_init.data, xr.DataArray)
+        assert isinstance(data_prep_no_init.meta, pd.DataFrame)
+
+    def test_load_data_keep_local_data(self, data_prep_no_init):
+        data_prep_no_init.statistics_per_var = {'o3': 'dma8eu', 'temp': 'maximum'}
+        data_prep_no_init.station_type = None
+        data_prep_no_init.kwargs = {}
+        file_path = data_prep_no_init._set_file_name()
+        data_prep_no_init.load_data()
+        assert os.path.exists(file_path)
+        t = os.stat(file_path).st_ctime
+        data_prep_no_init.load_data()
+        assert os.path.exists(data_prep_no_init._set_file_name())
+        assert os.stat(file_path).st_ctime == t
+        assert isinstance(data_prep_no_init.data, xr.DataArray)
+        assert isinstance(data_prep_no_init.meta, pd.DataFrame)
+
+    def test_repr(self, data_prep_no_init):
+        path = os.path.join(os.path.abspath(os.path.dirname(__file__)), 'data')
+        assert data_prep_no_init.__repr__().rstrip() == f"Dataprep(path='{path}', network='UBA', " \
+                                                        f"station=['DEBW107'], variables=['o3', 'temp'], " \
+                                                        f"station_type=background, **None)".rstrip()
 
     def test_set_file_name_and_meta(self):
         d = object.__new__(DataPrep)
@@ -135,6 +210,16 @@ class TestDataPrep:
         with pytest.raises(NotImplementedError):
             data.inverse_transform()
 
+    def test_get_transformation_information(self, data):
+        assert (None, None, None) == data.get_transformation_information("o3")
+        mean_test = data.data.mean("datetime").sel(variables='o3').values
+        std_test = data.data.std("datetime").sel(variables='o3').values
+        data.transform('datetime')
+        mean, std, info = data.get_transformation_information("o3")
+        assert np.testing.assert_almost_equal(mean, mean_test) is None
+        assert np.testing.assert_almost_equal(std, std_test) is None
+        assert info == "standardise"
+
     def test_nan_remove_no_hist_or_label(self, data):
         assert data.history is None
         assert data.label is None
diff --git a/test/test_model_modules/test_keras_extensions.py b/test/test_model_modules/test_keras_extensions.py
index 8bcebec60efe46db48a403b72e61a2036a6c00b9..2f6565b4cabe295169047a6582d2b89cbf387062 100644
--- a/test/test_model_modules/test_keras_extensions.py
+++ b/test/test_model_modules/test_keras_extensions.py
@@ -6,6 +6,24 @@ from src.helpers import l_p_loss
 from src.model_modules.keras_extensions import *
 
 
+class TestHistoryAdvanced:
+
+    def test_init(self):
+        hist = HistoryAdvanced()
+        assert hist.validation_data is None
+        assert hist.model is None
+        assert isinstance(hist.epoch, list) and len(hist.epoch) == 0
+        assert isinstance(hist.history, dict) and len(hist.history.keys()) == 0
+
+    def test_on_train_begin(self):
+        hist = HistoryAdvanced()
+        hist.epoch = [1, 2, 3]
+        hist.history = {"mse": [10, 7, 4]}
+        hist.on_train_begin()
+        assert hist.epoch == [1, 2, 3]
+        assert hist.history == {"mse": [10, 7, 4]}
+
+
 class TestLearningRateDecay:
 
     def test_init(self):