Skip to content
Snippets Groups Projects
Commit 3f4a88df authored by lukas leufen's avatar lukas leufen
Browse files

Merge branch 'lukas_issue193_bug_index_error' into 'lukas_issue079_feat_kz-filtered-data-handler'

Lukas issue193 bug index error

See merge request toar/mlair!168
parents 005f0993 234183fd
No related branches found
No related tags found
2 merge requests!168Lukas issue193 bug index error,!139Draft: Resolve "KZ filter"
Pipeline #48726 passed
...@@ -32,6 +32,11 @@ def dict_to_xarray(d: Dict, coordinate_name: str) -> xr.DataArray: ...@@ -32,6 +32,11 @@ def dict_to_xarray(d: Dict, coordinate_name: str) -> xr.DataArray:
:return: combined xarray :return: combined xarray
""" """
if len(d.keys()) == 1:
k = list(d.keys())
xarray: xr.DataArray = d[k[0]]
return xarray.expand_dims(dim={coordinate_name: k}, axis=0)
else:
xarray = None xarray = None
for k, v in d.items(): for k, v in d.items():
if xarray is None: if xarray is None:
......
...@@ -4,7 +4,7 @@ __date__ = '2019-11-15' ...@@ -4,7 +4,7 @@ __date__ = '2019-11-15'
import argparse import argparse
import logging import logging
import os import os
from typing import Union, Dict, Any, List from typing import Union, Dict, Any, List, Callable
from mlair.configuration import path_config from mlair.configuration import path_config
from mlair import helpers from mlair import helpers
...@@ -279,7 +279,7 @@ class ExperimentSetup(RunEnvironment): ...@@ -279,7 +279,7 @@ class ExperimentSetup(RunEnvironment):
path_config.check_path_and_create(self.data_store.get("logging_path")) path_config.check_path_and_create(self.data_store.get("logging_path"))
# setup for data # setup for data
self._set_param("stations", stations, default=DEFAULT_STATIONS) self._set_param("stations", stations, default=DEFAULT_STATIONS, apply=helpers.to_list)
self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT) self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT)
self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var").keys())) self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var").keys()))
self._set_param("start", start, default=DEFAULT_START) self._set_param("start", start, default=DEFAULT_START)
...@@ -355,10 +355,14 @@ class ExperimentSetup(RunEnvironment): ...@@ -355,10 +355,14 @@ class ExperimentSetup(RunEnvironment):
raise KeyError(f"Given argument {k} with value {v} cannot be set for this experiment due to a " raise KeyError(f"Given argument {k} with value {v} cannot be set for this experiment due to a "
f"conflict with an existing entry with same naming: {k}={self.data_store.get(k)}") f"conflict with an existing entry with same naming: {k}={self.data_store.get(k)}")
def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None: def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general",
"""Set given parameter and log in debug.""" apply: Callable = None) -> None:
"""Set given parameter and log in debug. Use apply parameter to adjust the stored value (e.g. to transform value
to a list use apply=helpers.to_list)."""
if value is None and default is not None: if value is None and default is not None:
value = default value = default
if apply is not None:
value = apply(value)
self.data_store.set(param, value, scope) self.data_store.set(param, value, scope)
logging.debug(f"set experiment attribute: {param}({scope})={value}") logging.debug(f"set experiment attribute: {param}({scope})={value}")
......
...@@ -124,14 +124,22 @@ class TestPytestRegex: ...@@ -124,14 +124,22 @@ class TestPytestRegex:
class TestDictToXarray: class TestDictToXarray:
def test_dict_to_xarray(self): def test_dict_to_xarray(self):
array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]}) array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]})
array2 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]}) array2 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]})
d = {"number1": array1, "number2": array2} d = {"number1": array1, "number2": array2}
res = dict_to_xarray(d, "merge_dim") res = dict_to_xarray(d, "merge_dim")
assert type(res) == xr.DataArray assert type(res) == xr.DataArray
assert sorted(list(res.coords)) == ["merge_dim", "x"] assert sorted(list(res.coords)) == ["merge_dim", "x", "y"]
assert res.shape == (2, 2, 3) assert res.shape == (2, 2, 3)
def test_dict_to_xarray_single_entry(self):
array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]})
d = {"number1": array1}
res = dict_to_xarray(d, "merge_dim")
assert type(res) == xr.DataArray
assert sorted(list(res.coords)) == ["merge_dim", "x", "y"]
assert res.shape == (1, 2, 3)
class TestFloatRound: class TestFloatRound:
......
...@@ -4,7 +4,7 @@ import os ...@@ -4,7 +4,7 @@ import os
import pytest import pytest
from mlair.helpers import TimeTracking from mlair.helpers import TimeTracking, to_list
from mlair.configuration.path_config import prepare_host from mlair.configuration.path_config import prepare_host
from mlair.run_modules.experiment_setup import ExperimentSetup from mlair.run_modules.experiment_setup import ExperimentSetup
...@@ -33,6 +33,16 @@ class TestExperimentSetup: ...@@ -33,6 +33,16 @@ class TestExperimentSetup:
empty_obj._set_param("AnotherNoneTester", None) empty_obj._set_param("AnotherNoneTester", None)
assert empty_obj.data_store.get("AnotherNoneTester", "general") is None assert empty_obj.data_store.get("AnotherNoneTester", "general") is None
def test_set_param_with_apply(self, caplog, empty_obj):
empty_obj._set_param("NoneTester", None, default="notNone", apply=None)
assert empty_obj.data_store.get("NoneTester") == "notNone"
empty_obj._set_param("NoneTester", None, default="notNone", apply=to_list)
assert empty_obj.data_store.get("NoneTester") == ["notNone"]
empty_obj._set_param("NoneTester", None, apply=to_list)
assert empty_obj.data_store.get("NoneTester") == [None]
empty_obj._set_param("NoneTester", 2.3, apply=int)
assert empty_obj.data_store.get("NoneTester") == 2
def test_init_default(self): def test_init_default(self):
exp_setup = ExperimentSetup() exp_setup = ExperimentSetup()
data_store = exp_setup.data_store data_store = exp_setup.data_store
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please to comment