Skip to content
Snippets Groups Projects
Commit 7da8cd6f authored by masak1112
Browse files

update save statistic information for era5_dataset.py

parent 2c2cb0b2
Branches
No related tags found
No related merge requests found
Pipeline #105444 failed
...@@ -21,6 +21,7 @@ class ERA5Dataset(BaseDataset): ...@@ -21,6 +21,7 @@ class ERA5Dataset(BaseDataset):
shuffled_on_val = shuffled_on_val, output_dir = output_dir) shuffled_on_val = shuffled_on_val, output_dir = output_dir)
self.shuffled_on_val = shuffled_on_val self.shuffled_on_val = shuffled_on_val
self.data_arr, self.init_times = self.load_data(self.filenames)
def specific_hparams(self)-> list: def specific_hparams(self)-> list:
s_hparams = ["context_frames", "sequence_length", "shift"] s_hparams = ["context_frames", "sequence_length", "shift"]
...@@ -82,7 +83,6 @@ class ERA5Dataset(BaseDataset): ...@@ -82,7 +83,6 @@ class ERA5Dataset(BaseDataset):
""" """
shuffle = self.mode == 'train' or (self.mode == 'val' and self.shuffle_on_val) shuffle = self.mode == 'train' or (self.mode == 'val' and self.shuffle_on_val)
data_arr, init_times = self.load_data(self.filenames)
def normalize_fn(x:tf.Tensor, min_value:float, max_value:float): def normalize_fn(x:tf.Tensor, min_value:float, max_value:float):
return tf.divide(tf.subtract(x, min_value), tf.subtract(max_value, min_value)) return tf.divide(tf.subtract(x, min_value), tf.subtract(max_value, min_value))
...@@ -121,7 +121,7 @@ class ERA5Dataset(BaseDataset): ...@@ -121,7 +121,7 @@ class ERA5Dataset(BaseDataset):
raise ("The filenames list is empty for {} dataset, please make sure your data_split dictionary is configured correctly".format(self.mode)) raise ("The filenames list is empty for {} dataset, please make sure your data_split dictionary is configured correctly".format(self.mode))
else: else:
#group the data into sequenceds #group the data into sequenceds
dataset = tf.data.Dataset.from_generator(lambda: iter(zip(data_arr, init_times)), dataset = tf.data.Dataset.from_generator(lambda: iter(zip(self.data_arr, self.init_times)),
output_types=(tf.float32,tf.string), output_shapes=((self.nlat,self.nlon,self.n_vars),())) output_types=(tf.float32,tf.string), output_shapes=((self.nlat,self.nlon,self.n_vars),()))
dataset = dataset.window(self.sequence_length, shift = self.shift, drop_remainder=True) dataset = dataset.window(self.sequence_length, shift = self.shift, drop_remainder=True)
dataset = dataset.flat_map(lambda x,y: tf.data.Dataset.zip((x.batch(self.sequence_length),y.batch(self.sequence_length)))) dataset = dataset.flat_map(lambda x,y: tf.data.Dataset.zip((x.batch(self.sequence_length),y.batch(self.sequence_length))))
......
0% Loading. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment