diff --git a/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json b/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json index 4f3a43f11a88e1172d4769bee98bbab8e0a7f59b..219a4caf1567e0777c45e6f2b0d03b73d91a93cb 100644 --- a/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json @@ -1,12 +1,14 @@ { "batch_size": 4, - "lr": 0.001, + "lr": 0.0001, "max_epochs":20, "context_frames":12, "loss_fun":"mse", "opt_var": "0", - "shuffle_on_val":true + "shuffle_on_val":true, + "filters": [64, 64, 64, 64, 2], + "kernels": [5, 5, 5, 5, 5] } diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py index ca7b6f0b2ad86e2c62510fc21d9d0d66f8babf06..ac172074484648871f586b9aeed868ea6966b3fa 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py +++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py @@ -24,14 +24,11 @@ class WeatherBenchModel(object): self.hparams = self.parse_hparams() self.learning_rate = self.hparams.lr self.filters = self.hparams.filters - self.kernels = self.hparams.kernes + self.kernels = self.hparams.kernels self.context_frames = self.hparams.context_frames - self.sequence_length = self.hparams.sequence_length - self.predict_frames = self.sequence_length- self.context_frames self.max_epochs = self.hparams.max_epochs self.loss_fun = self.hparams.loss_fun self.batch_size = self.hparams.batch_size - self.recon_weight = self.hparams.recon_weight self.outputs = {} self.total_loss = None @@ -60,15 +57,13 @@ class WeatherBenchModel(object): """ hparams = dict( context_frames =12, - sequence_length =24, max_epochs = 20, batch_size = 40, lr = 0.001, loss_fun = "mse", shuffle_on_val= True, - filter = 4, - kernels = 4, - + filter = [64, 64, 64, 64, 2], + kernels = [5, 5, 5, 5, 5] ) return hparams @@ -85,8 +80,10 @@ class WeatherBenchModel(object): x_hat = self.build_model(x, self.filters, self.kernels, dr=0) # Loss self.total_loss = l1_loss(x[...,0], x_hat[...,0]) + # Optimizer - self.train_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.total_loss, var_list=self.gen_vars) + self.train_op = tf.train.AdamOptimizer( + learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step) # outputs self.outputs["total_loss"] = self.total_loss @@ -111,7 +108,7 @@ class WeatherBenchModel(object): class PeriodicPadding2D(object): - def __init__(self, x, pad_width): + def __init__(self, pad_width): self.pad_width = pad_width @@ -127,7 +124,6 @@ class PeriodicPadding2D(object): return inputs_padded - class PeriodicConv2D(object): def __init__(self, filters, kernel_size, conv_kwargs={}):