Skip to content
GitLab
Menu
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
esde
machine-learning
MLAir
Commits
90b04251
Commit
90b04251
authored
Jun 30, 2022
by
lukas leufen
👻
Browse files
added trimm method as applied in
#384
parent
a58a6487
Pipeline
#104313
passed with stages
in 12 minutes and 15 seconds
Changes
1
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
mlair/run_modules/post_processing.py
View file @
90b04251
...
...
@@ -261,11 +261,17 @@ class PostProcessing(RunEnvironment):
"""Ensure time dimension to be equidistant. Sometimes dates if missing values have been dropped."""
start_data
=
data
.
coords
[
dim
].
values
[
0
]
freq
=
{
"daily"
:
"1D"
,
"hourly"
:
"1H"
}.
get
(
sampling
)
datetime_index
=
pd
.
DataFrame
(
index
=
pd
.
date_range
(
start
,
end
,
freq
=
freq
))
_ind
=
pd
.
date_range
(
start
,
end
,
freq
=
freq
)
# two steps required to include all hours of end interval
datetime_index
=
pd
.
DataFrame
(
index
=
pd
.
date_range
(
_ind
.
min
(),
_ind
.
max
()
+
dt
.
timedelta
(
days
=
1
),
closed
=
"left"
,
freq
=
freq
))
t
=
data
.
sel
({
dim
:
start_data
},
drop
=
True
)
res
=
xr
.
DataArray
(
coords
=
[
datetime_index
.
index
,
*
[
t
.
coords
[
c
]
for
c
in
t
.
coords
]],
dims
=
[
dim
,
*
t
.
coords
])
res
=
res
.
transpose
(
*
data
.
dims
)
res
.
loc
[
data
.
coords
]
=
data
if
data
.
shape
==
res
.
shape
:
res
.
loc
[
data
.
coords
]
=
data
else
:
_d
=
data
.
sel
({
dim
:
slice
(
start
,
end
)})
res
.
loc
[
_d
.
coords
]
=
_d
return
res
def
load_competitors
(
self
,
station_name
:
str
)
->
xr
.
DataArray
:
...
...
@@ -761,6 +767,7 @@ class PostProcessing(RunEnvironment):
indicated by `station_name`. The name of the competitor is set in the `type` axis as indicator. This method will
raise either a `FileNotFoundError` or `KeyError` if no competitor could be found for the given station. Either
there is no file provided in the expected path or no forecast for given `competitor_name` in the forecast file.
Forecast is trimmed on interval start and end of test subset.
:param station_name: name of the station to load data for
:param competitor_name: name of the model
...
...
@@ -769,10 +776,12 @@ class PostProcessing(RunEnvironment):
path
=
os
.
path
.
join
(
self
.
competitor_path
,
competitor_name
)
file
=
os
.
path
.
join
(
path
,
f
"forecasts_
{
station_name
}
_test.nc"
)
with
xr
.
open_dataarray
(
file
)
as
da
:
data
=
da
.
load
()
data
=
da
.
load
()
forecast
=
data
.
sel
(
type
=
[
self
.
forecast_indicator
])
forecast
.
coords
[
self
.
model_type_dim
]
=
[
competitor_name
]
return
forecast
# limit forecast to time range of test subset
start
,
end
=
self
.
data_store
.
get
(
"start"
,
"test"
),
self
.
data_store
.
get
(
"end"
,
"test"
)
return
self
.
create_full_time_dim
(
forecast
,
self
.
index_dim
,
self
.
_sampling
,
start
,
end
)
def
_create_observation
(
self
,
data
,
_
,
transformation_func
:
Callable
,
normalised
:
bool
)
->
xr
.
DataArray
:
"""
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment