diff --git a/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json b/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json new file mode 100644 index 0000000000000000000000000000000000000000..4f3a43f11a88e1172d4769bee98bbab8e0a7f59b --- /dev/null +++ b/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json @@ -0,0 +1,13 @@ + +{ + "batch_size": 4, + "lr": 0.001, + "max_epochs":20, + "context_frames":12, + "loss_fun":"mse", + "opt_var": "0", + "shuffle_on_val":true +} + + + diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py index 09107827ea439e2e9a25712ff45793e57b65c856..725e41388681093385bca7ea53905ad2e308e7b1 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py +++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py @@ -99,7 +99,7 @@ class WeatherBenchModel(object): return self.is_build_graph - def build_model(self, x,filters, kernels, dr=0): + def build_model(self, x, filters, kernels, dr=0): """Fully convolutional network""" for f, k in zip(filters[:-1], kernels[:-1]): x = PeriodicConv2D(x, f, k) @@ -109,7 +109,6 @@ class WeatherBenchModel(object): return output - class PeriodicPadding2D(object): def __init__(self, x, pad_width): @@ -118,6 +117,7 @@ class PeriodicPadding2D(object): def call(self, inputs, **kwargs): if self.pad_width == 0: return inputs + inputs_padded = tf.concat( [inputs[:, :, -self.pad_width:, :], inputs, inputs[:, :, :self.pad_width, :]], axis=2) @@ -126,6 +126,7 @@ class PeriodicPadding2D(object): return inputs_padded + class PeriodicConv2D(object): def __init__(self, filters, kernel_size, conv_kwargs={}):