Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
AMBS
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Container registry
Model registry
Operate
Environments
Monitor
Incidents
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
esde
machine-learning
AMBS
Commits
5c0f3e79
Commit
5c0f3e79
authored
3 years ago
by
Michael Langguth
Browse files
Options
Downloads
Patches
Plain Diff
Corrected handling of t_start_points in preprocess_data_step2.py and source-code style changes.
parent
ff258b1d
No related branches found
No related tags found
No related merge requests found
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
video_prediction_tools/data_preprocess/preprocess_data_step2.py
+45
-36
45 additions, 36 deletions
...prediction_tools/data_preprocess/preprocess_data_step2.py
with
45 additions
and
36 deletions
video_prediction_tools/data_preprocess/preprocess_data_step2.py
+
45
−
36
View file @
5c0f3e79
...
...
@@ -11,6 +11,7 @@ import os
import
glob
import
pickle
import
numpy
as
np
import
pandas
as
pd
import
json
import
tensorflow
as
tf
from
normalization
import
Norm_data
...
...
@@ -32,10 +33,8 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
default:
"
minmax
"
)
"""
self
.
input_dir
=
input_dir
# ML: No hidden path-extensions (rather managed in generate_runscript.py)
# self.input_dir_pkl = os.path.join(input_dir,"pickle")
self
.
output_dir
=
dest_dir
# if the output_dir
i
s not exist, then create it
# if the output_dir
doe
s not exist, then create it
os
.
makedirs
(
self
.
output_dir
,
exist_ok
=
True
)
# get metadata,includes the var_in, image height, width etc.
self
.
metadata_fl
=
os
.
path
.
join
(
os
.
path
.
dirname
(
self
.
input_dir
.
rstrip
(
"
/
"
)),
"
metadata.json
"
)
...
...
@@ -62,9 +61,7 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
# search for folder names from pickle folder to get years
patt
=
"
X_*.pkl
"
for
year
in
self
.
years
:
print
(
"
pahtL:
"
,
os
.
path
.
join
(
self
.
input_dir
,
year
,
patt
))
months_pkl_list
=
glob
.
glob
(
os
.
path
.
join
(
self
.
input_dir
,
year
,
patt
))
print
(
"
months_pkl_list
"
,
months_pkl_list
)
months_list
=
[
int
(
m
[
-
6
:
-
4
])
for
m
in
months_pkl_list
]
self
.
months
.
extend
(
months_list
)
self
.
years_months
.
append
(
months_list
)
...
...
@@ -74,13 +71,15 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
"""
Get the corresponding statistics file
"""
self
.
stats_file
=
os
.
path
.
join
(
os
.
path
.
dirname
(
self
.
input_dir
),
"
statistics.json
"
)
print
(
"
Opening json-file: {0}
"
.
format
(
self
.
stats_file
))
if
os
.
path
.
isfile
(
self
.
stats_file
):
with
open
(
self
.
stats_file
)
as
js_file
:
method
=
ERA5Pkl2Tfrecords
.
get_stat_file
.
__name__
stats_file
=
os
.
path
.
join
(
os
.
path
.
dirname
(
self
.
input_dir
),
"
statistics.json
"
)
print
(
"
Opening json-file: {0}
"
.
format
(
stats_file
))
if
os
.
path
.
isfile
(
stats_file
):
with
open
(
stats_file
)
as
js_file
:
self
.
stats
=
json
.
load
(
js_file
)
else
:
raise
FileNotFoundError
(
"
Statistic file does not exist
"
)
raise
FileNotFoundError
(
"
%{0}: Could not find statistic file
'
{1}
'"
.
format
(
method
,
stats_file
)
)
def
get_metadata
(
self
,
md_instance
):
"""
...
...
@@ -90,18 +89,20 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
height : int, the height of the image
width : int, the width of the image
"""
method
=
ERA5Pkl2Tfrecords
.
get_metadata
.
__name__
if
not
isinstance
(
md_instance
,
MetaData
):
raise
ValueError
(
"
md_instance-argument must be a MetaData class instance
"
)
raise
ValueError
(
"
%{0}:
md_instance-argument must be a MetaData class instance
"
.
format
(
method
)
)
if
not
hasattr
(
self
,
"
metadata_fl
"
):
raise
ValueError
(
"
MetaData class instance passed, but attribute metadata_fl is still missing.
"
)
raise
ValueError
(
"
%{0}:
MetaData class instance passed, but attribute metadata_fl is still missing.
"
.
format
(
method
)
)
try
:
self
.
height
,
self
.
width
=
md_instance
.
ny
,
md_instance
.
nx
self
.
vars_in
=
md_instance
.
variables
except
:
raise
IOError
(
"
Could not retrieve all required information from metadata-file
'
{0}
'"
.
format
(
self
.
metadata_fl
))
raise
IOError
(
"
%{0}:
Could not retrieve all required information from metadata-file
'
{0}
'"
.
format
(
method
,
self
.
metadata_fl
))
@staticmethod
def
save_tf_record
(
output_fname
,
sequences
,
t_start_points
):
...
...
@@ -114,12 +115,12 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
t_start_points : datetime type in the list, the first timestamp for each sequence
[seq_len,height,width, channel], the len of t_start_points is the same as sequences
"""
method
=
ERA5Pkl2Tfrecords
.
save_tf_record
.
__name__
sequences
=
np
.
array
(
sequences
)
# sanity checks
print
(
t_start_points
[
0
])
print
(
type
(
t_start_points
[
0
]))
assert
sequences
.
shape
[
0
]
==
len
(
t_start_points
)
assert
type
(
t_start_points
)
==
datetime
.
datetime
,
"
What
'
s that: {0} (type {1})
"
.
format
(
t_start_points
[
0
],
type
(
t_start_points
[
0
]))
assert
sequences
.
shape
[
0
]
==
len
(
t_start_points
),
"
%{0}: Lengths of sequence differs from length of t_start_points.
"
.
format
(
method
)
assert
type
(
t_start_points
[
0
])
==
datetime
.
datetime
,
"
%{0}: Elements of t_start_points must be datetime-objects.
"
.
format
(
method
)
with
tf
.
python_io
.
TFRecordWriter
(
output_fname
)
as
writer
:
for
i
in
range
(
len
(
sequences
)):
...
...
@@ -144,7 +145,9 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
"""
Get normalization data class
"""
print
(
"
Make use of default minmax-normalization...
"
)
method
=
ERA5Pkl2Tfrecords
.
init_norm_class
.
__name__
print
(
"
%{0}: Make use of default minmax-normalization.
"
.
format
(
method
))
# init normalization-instance
self
.
norm_cls
=
Norm_data
(
self
.
vars_in
)
self
.
nvars
=
len
(
self
.
vars_in
)
...
...
@@ -162,7 +165,9 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
Return:
the normalized sequences
"""
assert
len
(
np
.
array
(
sequences
).
shape
)
==
5
method
=
ERA5Pkl2Tfrecords
.
normalize_vars_per_seq
.
__name__
assert
len
(
np
.
array
(
sequences
).
shape
)
==
5
,
"
%{0}: Length of sequence array must be 5.
"
.
format
(
method
)
# normalization should adopt the selected variables, here we used duplicated channel temperature variables
sequences
=
np
.
array
(
sequences
)
# normalization
...
...
@@ -177,6 +182,8 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
year : int, the target year to save to tfrecord
month : int, the target month to save to tfrecord
"""
method
=
ERA5Pkl2Tfrecords
.
read_pkl_and_save_tfrecords
.
__name__
# Define the input_file based on the year and month
self
.
input_file_year
=
os
.
path
.
join
(
self
.
input_dir
,
str
(
year
))
input_file
=
os
.
path
.
join
(
self
.
input_file_year
,
'
X_{:02d}.pkl
'
.
format
(
month
))
...
...
@@ -187,11 +194,17 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
t_start_points
=
[]
sequence_iter
=
0
#
try:
try
:
with
open
(
input_file
,
"
rb
"
)
as
data_file
:
X_train
=
pickle
.
load
(
data_file
)
except
:
raise
IOError
(
"
%{0}: Could not read data from pickle-file
'
{1}
'"
.
format
(
method
,
input_file
))
try
:
with
open
(
temp_input_file
,
"
rb
"
)
as
temp_file
:
T_train
=
pickle
.
load
(
temp_file
)
except
:
raise
IOError
(
"
%{0}: Could not read data from pickle-file
'
{1}
'"
.
format
(
method
,
temp_input_file
))
# check to make sure that X_train and T_train have the same length
assert
(
len
(
X_train
)
==
len
(
T_train
))
...
...
@@ -202,8 +215,6 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
seq
=
X_train
[
X_start
:
X_end
,
...]
# recording the start point of the timestamps (already datetime-objects)
t_start
=
ERA5Pkl2Tfrecords
.
ensure_datetime
(
T_train
[
X_start
][
0
])
print
(
"
t_start,
"
,
t_start
)
print
(
"
type of t_starty
"
,
type
(
t_start
))
seq
=
list
(
np
.
array
(
seq
).
reshape
((
self
.
sequence_length
,
self
.
height
,
self
.
width
,
self
.
nvars
)))
if
not
sequences
:
last_start_sequence_iter
=
sequence_iter
...
...
@@ -221,7 +232,7 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
ERA5Pkl2Tfrecords
.
write_seq_to_tfrecord
(
output_fname
,
sequences
,
t_start_points
)
t_start_points
=
[]
sequences
=
[]
print
(
"
Finished
for input file
"
,
input_file
)
print
(
"
%{0}:
Finished
processing of input file
'
{1}
'"
.
format
(
method
,
input_file
)
)
# except FileNotFoundError as fnf_error:
# print(fnf_error)
...
...
@@ -232,8 +243,10 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
Function to check if the sequences have been processed.
If yes, the sequences are skipped, otherwise the sequences are saved to the output file
"""
method
=
ERA5Pkl2Tfrecords
.
write_seq_to_tfrecord
.
__name__
if
os
.
path
.
isfile
(
output_fname
):
print
(
output_fname
,
'
already exists, skip it
'
)
print
(
"
%{0}: TFrecord-file {1} already exists. It is therefore skipped.
"
.
format
(
method
,
output_fname
)
)
else
:
ERA5Pkl2Tfrecords
.
save_tf_record
(
output_fname
,
list
(
sequences
),
t_start_points
)
...
...
@@ -252,24 +265,20 @@ class ERA5Pkl2Tfrecords(ERA5Dataset):
"""
Wrapper to return a datetime-object
"""
method
=
ERA5Pkl2Tfrecords
.
ensure_datetime
.
__name__
fmt
=
"
%Y%m%d %H:%M
"
if
isinstance
(
date
,
datetime
.
datetime
):
date_new
=
date
else
:
try
:
date_new
=
pd
.
to_datetime
(
date
)
date_new
=
date
time
.
datetime
(
date_new
.
strptime
(
fmt
),
fmt
)
date_new
=
date
_new
.
to_pydatetime
(
)
except
Exception
as
err
:
print
(
"
%{0}: Could not handle input data {1} which is of type {2}.
"
.
format
(
method
,
date
,
type
(
date
)))
raise
err
return
date_new
# def num_examples_per_epoch(self):
# with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file:
# sequence_lengths = sequence_lengths_file.readlines()
# sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths]
# return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length)
def _bytes_feature(value):
    """
    Wrap a single value into a tf.train.Feature that carries a one-element BytesList.
    """
    payload = tf.train.BytesList(value=[value])
    return tf.train.Feature(bytes_list=payload)
...
...
This diff is collapsed.
Click to expand it.
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment