Skip to content
Snippets Groups Projects
Commit 3f4a88df authored by lukas leufen's avatar lukas leufen
Browse files

Merge branch 'lukas_issue193_bug_index_error' into 'lukas_issue079_feat_kz-filtered-data-handler'

Lukas issue193 bug index error

See merge request toar/mlair!168
parents 005f0993 234183fd
No related branches found
No related tags found
2 merge requests!168Lukas issue193 bug index error,!139Draft: Resolve "KZ filter"
Pipeline #48726 passed
......@@ -32,6 +32,11 @@ def dict_to_xarray(d: Dict, coordinate_name: str) -> xr.DataArray:
:return: combined xarray
"""
if len(d.keys()) == 1:
k = list(d.keys())
xarray: xr.DataArray = d[k[0]]
return xarray.expand_dims(dim={coordinate_name: k}, axis=0)
else:
xarray = None
for k, v in d.items():
if xarray is None:
......
......@@ -4,7 +4,7 @@ __date__ = '2019-11-15'
import argparse
import logging
import os
from typing import Union, Dict, Any, List
from typing import Union, Dict, Any, List, Callable
from mlair.configuration import path_config
from mlair import helpers
......@@ -279,7 +279,7 @@ class ExperimentSetup(RunEnvironment):
path_config.check_path_and_create(self.data_store.get("logging_path"))
# setup for data
self._set_param("stations", stations, default=DEFAULT_STATIONS)
self._set_param("stations", stations, default=DEFAULT_STATIONS, apply=helpers.to_list)
self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT)
self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var").keys()))
self._set_param("start", start, default=DEFAULT_START)
......@@ -355,10 +355,14 @@ class ExperimentSetup(RunEnvironment):
raise KeyError(f"Given argument {k} with value {v} cannot be set for this experiment due to a "
f"conflict with an existing entry with same naming: {k}={self.data_store.get(k)}")
def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None:
"""Set given parameter and log in debug."""
def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general",
apply: Callable = None) -> None:
"""Set given parameter and log in debug. Use apply parameter to adjust the stored value (e.g. to transform value
to a list use apply=helpers.to_list)."""
if value is None and default is not None:
value = default
if apply is not None:
value = apply(value)
self.data_store.set(param, value, scope)
logging.debug(f"set experiment attribute: {param}({scope})={value}")
......
......@@ -124,14 +124,22 @@ class TestPytestRegex:
class TestDictToXarray:
def test_dict_to_xarray(self):
array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]})
array2 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]})
array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]})
array2 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]})
d = {"number1": array1, "number2": array2}
res = dict_to_xarray(d, "merge_dim")
assert type(res) == xr.DataArray
assert sorted(list(res.coords)) == ["merge_dim", "x"]
assert sorted(list(res.coords)) == ["merge_dim", "x", "y"]
assert res.shape == (2, 2, 3)
def test_dict_to_xarray_single_entry(self):
array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20], 'y': [0, 10, 20]})
d = {"number1": array1}
res = dict_to_xarray(d, "merge_dim")
assert type(res) == xr.DataArray
assert sorted(list(res.coords)) == ["merge_dim", "x", "y"]
assert res.shape == (1, 2, 3)
class TestFloatRound:
......
......@@ -4,7 +4,7 @@ import os
import pytest
from mlair.helpers import TimeTracking
from mlair.helpers import TimeTracking, to_list
from mlair.configuration.path_config import prepare_host
from mlair.run_modules.experiment_setup import ExperimentSetup
......@@ -33,6 +33,16 @@ class TestExperimentSetup:
empty_obj._set_param("AnotherNoneTester", None)
assert empty_obj.data_store.get("AnotherNoneTester", "general") is None
def test_set_param_with_apply(self, caplog, empty_obj):
empty_obj._set_param("NoneTester", None, default="notNone", apply=None)
assert empty_obj.data_store.get("NoneTester") == "notNone"
empty_obj._set_param("NoneTester", None, default="notNone", apply=to_list)
assert empty_obj.data_store.get("NoneTester") == ["notNone"]
empty_obj._set_param("NoneTester", None, apply=to_list)
assert empty_obj.data_store.get("NoneTester") == [None]
empty_obj._set_param("NoneTester", 2.3, apply=int)
assert empty_obj.data_store.get("NoneTester") == 2
def test_init_default(self):
exp_setup = ExperimentSetup()
data_store = exp_setup.data_store
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment