From 5ceb57789cf632400a04d646a43776b854e01d4a Mon Sep 17 00:00:00 2001 From: BING GONG <b.gong@fz-juelich.de> Date: Thu, 9 Jun 2022 10:50:26 +0200 Subject: [PATCH] update the hparameters for weatherbench model --- .../weatherBench/model_hparams_template.json | 6 ++++-- .../models/weatherBench3DCNN.py | 18 +++++++----------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json b/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json index 4f3a43f1..219a4caf 100644 --- a/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/weatherBench/model_hparams_template.json @@ -1,12 +1,14 @@ { "batch_size": 4, - "lr": 0.001, + "lr": 0.0001, "max_epochs":20, "context_frames":12, "loss_fun":"mse", "opt_var": "0", - "shuffle_on_val":true + "shuffle_on_val":true, + "filters": [64, 64, 64, 64, 2], + "kernels": [5, 5, 5, 5, 5] } diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py index ca7b6f0b..ac172074 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py +++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py @@ -24,14 +24,11 @@ class WeatherBenchModel(object): self.hparams = self.parse_hparams() self.learning_rate = self.hparams.lr self.filters = self.hparams.filters - self.kernels = self.hparams.kernes + self.kernels = self.hparams.kernels self.context_frames = self.hparams.context_frames - self.sequence_length = self.hparams.sequence_length - self.predict_frames = self.sequence_length- self.context_frames self.max_epochs = self.hparams.max_epochs self.loss_fun = self.hparams.loss_fun self.batch_size = self.hparams.batch_size - self.recon_weight = self.hparams.recon_weight self.outputs = {} self.total_loss = None @@ -60,15 +57,13 @@ class WeatherBenchModel(object): """ hparams = dict( context_frames =12, - sequence_length =24, max_epochs = 20, batch_size = 40, lr = 0.001, loss_fun = "mse", shuffle_on_val= True, - filter = 4, - kernels = 4, - + filter = [64, 64, 64, 64, 2], + kernels = [5, 5, 5, 5, 5] ) return hparams @@ -85,8 +80,10 @@ class WeatherBenchModel(object): x_hat = self.build_model(x, self.filters, self.kernels, dr=0) # Loss self.total_loss = l1_loss(x[...,0], x_hat[...,0]) + # Optimizer - self.train_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.total_loss, var_list=self.gen_vars) + self.train_op = tf.train.AdamOptimizer( + learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step) # outputs self.outputs["total_loss"] = self.total_loss @@ -111,7 +108,7 @@ class WeatherBenchModel(object): class PeriodicPadding2D(object): - def __init__(self, x, pad_width): + def __init__(self, pad_width): self.pad_width = pad_width @@ -127,7 +124,6 @@ class PeriodicPadding2D(object): return inputs_padded - class PeriodicConv2D(object): def __init__(self, filters, kernel_size, conv_kwargs={}): -- GitLab