Commit 530b8e9e authored by lukas leufen

Merge branch 'lukas_issue457_feat_set-config-paths-as-parameter' into 'develop'

Resolve "set config paths as parameter"

See merge request !520
Parents: 66df645d, 59f09e79
Related merge requests: !522 (filter can now combine obs, forecast, and apriori for first iteration. Further...), !521 (Resolve "release v2.4.0"), !520 (Resolve "set config paths as parameter")
Pipeline #144451 passed
@@ -3,7 +3,7 @@
 from typing import Tuple
 
-def era5_settings(sampling="daily") -> Tuple[str, str]:
+def era5_settings(sampling="daily", era5_data_path=None, era5_file_names=None) -> Tuple[str, str]:
     """
     Check for sampling as only hourly resolution is supported by era5 and return path on HPC systems.
@@ -12,8 +12,8 @@ def era5_settings(sampling="daily") -> Tuple[str, str]:
     :return: HPC path
     """
     if sampling == "hourly":  # pragma: no branch
-        ERA5_DATA_PATH = "."
-        FILE_NAMES = "*.nc"
+        ERA5_DATA_PATH = era5_data_path or "."
+        FILE_NAMES = era5_file_names or "*.nc"
     else:
         raise NameError(f"Given sampling {sampling} is not supported, only hourly sampling can be used.")
     return ERA5_DATA_PATH, FILE_NAMES
@@ -3,7 +3,7 @@
 from typing import Tuple
 
-def ifs_settings(sampling="daily") -> Tuple[str, str]:
+def ifs_settings(sampling="daily", ifs_data_path=None, ifs_file_names=None) -> Tuple[str, str]:
     """
     Check for sampling as only hourly resolution is supported by ifs and return path on HPC systems.
@@ -12,8 +12,8 @@ def ifs_settings(sampling="daily") -> Tuple[str, str]:
     :return: HPC path
     """
     if sampling == "hourly":  # pragma: no branch
-        IFS_DATA_PATH = "."
-        FILE_NAMES = "*.nc"
+        IFS_DATA_PATH = ifs_data_path or "."
+        FILE_NAMES = ifs_file_names or "*.nc"
     else:
         raise NameError(f"Given sampling {sampling} is not supported, only hourly sampling can be used.")
     return IFS_DATA_PATH, FILE_NAMES
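
For illustration only (not part of the commit): omitting the new keywords, or passing None, keeps the previous hard-coded defaults because the fallback uses the `or` short-circuit. The path and glob pattern below are made-up placeholders.

    # Hedged usage sketch; values are placeholders, not taken from the commit.
    data_path, file_names = era5_settings(sampling="hourly")  # -> (".", "*.nc"), the old behaviour
    data_path, file_names = era5_settings(sampling="hourly",
                                          era5_data_path="/data/era5",   # hypothetical directory
                                          era5_file_names="era5_*.nc")   # hypothetical glob pattern
    # ifs_settings accepts ifs_data_path / ifs_file_names in exactly the same way.

One side effect of the `or` fallback: an empty string is falsy, so era5_data_path="" silently reverts to "." rather than raising; callers who need strict validation must check the value themselves.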
@@ -71,7 +71,8 @@ class DataHandlerSingleStation(AbstractDataHandler):
                  interpolation_method: Union[str, Tuple[str]] = DEFAULT_INTERPOLATION_METHOD,
                  overwrite_local_data: bool = False, transformation=None, store_data_locally: bool = True,
                  min_length: int = 0, start=None, end=None, variables=None, data_origin: Dict = None,
-                 lazy_preprocessing: bool = False, overwrite_lazy_data=False, **kwargs):
+                 lazy_preprocessing: bool = False, overwrite_lazy_data=False, era5_data_path=None, era5_file_names=None,
+                 ifs_data_path=None, ifs_file_names=None, **kwargs):
         super().__init__()
         self.station = helpers.to_list(station)
         self.path = self.setup_data_path(data_path, sampling)
@@ -115,6 +116,11 @@ class DataHandlerSingleStation(AbstractDataHandler):
         self.label = None
         self.observation = None
 
+        self._era5_data_path = era5_data_path
+        self._era5_file_names = era5_file_names
+        self._ifs_data_path = ifs_data_path
+        self._ifs_file_names = ifs_file_names
+
         # create samples
         self.setup_samples()
         self.clean_up()
@@ -343,7 +349,11 @@ class DataHandlerSingleStation(AbstractDataHandler):
             data, meta = data_sources.download_data(file_name, meta_file, station, statistics_per_var, sampling,
                                                     store_data_locally=store_data_locally, data_origin=data_origin,
                                                     time_dim=self.time_dim, target_dim=self.target_dim,
-                                                    iter_dim=self.iter_dim, window_dim=self.window_dim)
+                                                    iter_dim=self.iter_dim, window_dim=self.window_dim,
+                                                    era5_data_path=self._era5_data_path,
+                                                    era5_file_names=self._era5_file_names,
+                                                    ifs_data_path=self._ifs_data_path,
+                                                    ifs_file_names=self._ifs_file_names)
             logging.debug(f"{self.station[0]}: loaded new data")
         else:
             try:
@@ -358,7 +368,10 @@ class DataHandlerSingleStation(AbstractDataHandler):
                 data, meta = data_sources.download_data(file_name, meta_file, station, statistics_per_var, sampling,
                                                         store_data_locally=store_data_locally, data_origin=data_origin,
                                                         time_dim=self.time_dim, target_dim=self.target_dim,
-                                                        iter_dim=self.iter_dim)
+                                                        iter_dim=self.iter_dim, era5_data_path=self._era5_data_path,
+                                                        era5_file_names=self._era5_file_names,
+                                                        ifs_data_path=self._ifs_data_path,
+                                                        ifs_file_names=self._ifs_file_names)
                 logging.debug(f"{self.station[0]}: loading finished")
                 # create slices and check for negative concentration.
                 data = self._slice_prep(data, start=start, end=end)
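
A hypothetical construction sketch for orientation (module path, station id, and variable mappings are assumptions, not taken from the commit; any required arguments not shown are assumed to have workable defaults):

    from mlair.data_handler.data_handler_single_station import DataHandlerSingleStation  # path assumed

    dh = DataHandlerSingleStation(
        station="DEBW107",                              # hypothetical station id
        data_path="/tmp/mlair_data",                    # hypothetical local cache dir
        statistics_per_var={"temp": "average_values"},  # hypothetical variable/statistics map
        sampling="hourly",
        data_origin={"temp": "era5"},                   # route "temp" to the local era5 reader
        era5_data_path="/data/era5",                    # new in this MR: custom era5 location
        era5_file_names="era5_*.nc",                    # new in this MR: custom file pattern
    )

Storing the four values on private attributes in the constructor keeps both download_data call sites consistent without threading extra arguments through the intermediate methods.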
@@ -22,7 +22,9 @@ DEFAULT_WINDOW_DIM = "window"
 
 def download_data(file_name: str, meta_file: str, station, statistics_per_var, sampling,
                   store_data_locally=True, data_origin: Dict = None, time_dim=DEFAULT_TIME_DIM,
-                  target_dim=DEFAULT_TARGET_DIM, iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM) -> [xr.DataArray, pd.DataFrame]:
+                  target_dim=DEFAULT_TARGET_DIM, iter_dim=DEFAULT_ITER_DIM, window_dim=DEFAULT_WINDOW_DIM,
+                  era5_data_path=None, era5_file_names=None, ifs_data_path=None, ifs_file_names=None) -> \
+        [xr.DataArray, pd.DataFrame]:
     """
     Download data from TOAR database using the JOIN interface or load local era5 data.
@@ -54,12 +56,14 @@ def download_data(file_name: str, meta_file: str, station, statistics_per_var, s
         # load era5 data
         df_era5_local, meta_era5_local = data_sources.era5.load_era5(
             station_name=station, stat_var=era5_local_stats, sampling=sampling, data_origin=era5_local_origin,
-            window_dim=window_dim, time_dim=time_dim, target_dim=target_dim)
+            window_dim=window_dim, time_dim=time_dim, target_dim=target_dim, era5_data_path=era5_data_path,
+            era5_file_names=era5_file_names)
     if ifs_local_origin is not None and len(ifs_local_stats) > 0:
         # load ifs data
         df_ifs_local, meta_ifs_local = data_sources.ifs.load_ifs(
             station_name=station, stat_var=ifs_local_stats, sampling=sampling, data_origin=ifs_local_origin,
-            lead_time_dim=window_dim, initial_time_dim=time_dim, target_dim=target_dim)
+            lead_time_dim=window_dim, initial_time_dim=time_dim, target_dim=target_dim, ifs_data_path=ifs_data_path,
+            ifs_file_names=ifs_file_names)
     if toar_origin is None or len(toar_stats) > 0:
         # load combined data from toar-data (v2 & v1)
         df_toar, meta_toar = data_sources.toar_data.download_toar(station=station, toar_stats=toar_stats,
@@ -17,13 +17,14 @@ from mlair.helpers.data_sources.data_loader import EmptyQueryResult
 from mlair.helpers.meteo import relative_humidity_from_dewpoint
 
-def load_era5(station_name, stat_var, sampling, data_origin, time_dim, window_dim, target_dim):
+def load_era5(station_name, stat_var, sampling, data_origin, time_dim, window_dim, target_dim, era5_data_path=None,
+              era5_file_names=None):
 
     # make sure station_name parameter is a list
     station_name = helpers.to_list(station_name)
 
     # get data path
-    data_path, file_names = era5_settings(sampling)
+    data_path, file_names = era5_settings(sampling, era5_data_path=era5_data_path, era5_file_names=era5_file_names)
 
     # correct stat_var values if data is not aggregated (hourly)
     if sampling == "hourly":
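
A sketch of calling load_era5 directly with the new overrides (the import path follows this repository's data_sources package, as used by download_data above; station id, statistics mapping, origin, and dimension names are illustrative placeholders):

    from mlair.helpers.data_sources import era5

    data, meta = era5.load_era5(
        station_name="DEBW107",               # hypothetical station id
        stat_var={"temp": "average_values"},  # hypothetical statistics-per-variable map
        sampling="hourly",                    # the only sampling era5_settings accepts
        data_origin={"temp": "era5"},         # hypothetical origin map
        time_dim="datetime",
        window_dim="window",
        target_dim="variables",
        era5_data_path="/data/era5",          # forwarded to era5_settings
        era5_file_names="era5_*.nc",
    )

load_ifs mirrors this with ifs_data_path / ifs_file_names and the lead_time_dim / initial_time_dim dimension names.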
@@ -20,13 +20,14 @@ from mlair.helpers.data_sources.data_loader import EmptyQueryResult
 from mlair.helpers.meteo import relative_humidity_from_dewpoint
 
-def load_ifs(station_name, stat_var, sampling, data_origin, lead_time_dim, initial_time_dim, target_dim):
+def load_ifs(station_name, stat_var, sampling, data_origin, lead_time_dim, initial_time_dim, target_dim,
+             ifs_data_path=None, ifs_file_names=None):
 
     # make sure station_name parameter is a list
     station_name = helpers.to_list(station_name)
 
     # get data path
-    data_path, file_names = ifs_settings(sampling)
+    data_path, file_names = ifs_settings(sampling, ifs_data_path=ifs_data_path, ifs_file_names=ifs_file_names)
 
     # correct stat_var values if data is not aggregated (hourly)
     if sampling == "hourly":
@@ -461,7 +461,8 @@ class PreProcessing(RunEnvironment):
                            "neighbors", "plot_list", "plot_path", "regularizer", "restore_best_model_weights",
                            "snapshot_load_path", "snapshot_path", "stations", "tmp_path", "train_model",
                            "transformation", "use_multiprocessing", "cams_data_path", "cams_interp_method",
-                           "do_bias_free_evaluation", "apriori_file", "model_path", "model_load_path"]
+                           "do_bias_free_evaluation", "apriori_file", "model_path", "model_load_path", "era5_data_path",
+                           "era5_file_names", "ifs_data_path", "ifs_file_names"]
         data_handler = self.data_store.get("data_handler")
         model_class = self.data_store.get("model_class")
         excluded_params = list(set(excluded_params + data_handler.store_attributes() + model_class.requirements()))
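
Adding the four names to excluded_params keeps them out of the data-handler parameter check alongside the other path-like experiment options (cams_data_path, model_path, ...), which suggests they are meant to be set at experiment level. A hypothetical end-to-end invocation; whether the top-level entry point forwards these keywords to the data handler is an assumption based on how the other *_data_path options behave, and all values are placeholders:

    import mlair

    mlair.run(
        sampling="hourly",
        era5_data_path="/p/project/era5",   # hypothetical local ERA5 store
        era5_file_names="*.nc",
        ifs_data_path="/p/project/ifs",     # hypothetical local IFS store
        ifs_file_names="*.nc",
    )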