diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index b16e33919d9f335e0d2b45ad3309ad901e568f57..5242c8848bd5c0fa4ee8927532a2c8e7bf15743d 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -426,6 +426,14 @@ class TrainModel(object):
             fetch_list = fetch_list + ["inputs", "total_loss"]
             self.saver_loss = fetch_list[-1]
             self.saver_loss_name = "Total loss"
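+        # WeatherBench model: fetch the total loss only; unhandled model classes are rejected below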
+        elif self.video_model.__class__.__name__ == "WeatherBenchModel":
+            fetch_list = fetch_list + ["total_loss"]
+            self.saver_loss = fetch_list[-1]
+            self.saver_loss_name = "Total loss"
+        else:
+            raise ValueError("self.saver_loss is not set up for your video model class {}".format(self.video_model.__class__.__name__))
+
 
         self.fetches = self.generate_fetches(fetch_list)
 
@@ -491,7 +499,7 @@ class TrainModel(object):
         if self.video_model.__class__.__name__ == "McNetVideoPredictionModel":
             print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"], results["L_p"],
                                                                            results["L_gdl"],results["L_GAN"]))
-        elif self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
+        elif self.video_model.__class__.__name__ in ["VanillaConvLstmVideoPredictionModel", "WeatherBenchModel"]:
             print ("Total_loss:{}".format(results["total_loss"]))
         elif self.video_model.__class__.__name__ == "SAVPVideoPredictionModel":
             print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}, gen_l1_loss: {}"
diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
index 07b5cbf26251e5d61520b07cc25e22d64338d195..77ceb43ba93be8c4eaa00ec068e1fc12d43ab382 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
@@ -23,11 +23,10 @@ class WeatherBenchModel(object):
         self.hparams_dict = hparams_dict
         self.hparams = self.parse_hparams()
         self.learning_rate = self.hparams.lr
+        print("hparams",self.hparams)
         self.filters = self.hparams.filters
         self.kernels = self.hparams.kernels
-        self.context_frames = self.hparams.context_frames
         self.max_epochs = self.hparams.max_epochs
-        self.loss_fun = self.hparams.loss_fun
         self.batch_size = self.hparams.batch_size
         self.outputs = {}
         self.total_loss = None
@@ -50,19 +49,19 @@ class WeatherBenchModel(object):
         Returns:
             A dict with the following hyperparameters.
             context_frames  : the number of ground-truth frames to pass in at start.
-            sequence_length : the number of frames in the video sequence
             max_epochs      : the number of epochs to train model
             lr              : learning rate
             loss_fun        : the loss function
+            filters         : list containing the number of filters for each convolutional layer
+            kernels         : list containing the kernel size for each convolutional layer
             """
         hparams = dict(
-            context_frames =12,
+            sequence_length = 12,
             max_epochs = 20,
             batch_size = 40,
             lr = 0.001,
-            loss_fun = "mse",
             shuffle_on_val= True,
-            filter = [64, 64, 64, 64, 2],
+            filters = [64, 64, 64, 64, 2],
             kernels = [5, 5, 5, 5, 5]
         )
         return hparams
@@ -70,16 +69,16 @@
 
     def build_graph(self, x):
         self.is_build_graph = False
-        self.inputs = x
         self.x = x["images"]
 
         self.global_step = tf.train.get_or_create_global_step()
         original_global_variables = tf.global_variables()
 
         # Architecture
-        x_hat = self.build_model(x, self.filters, self.kernels)
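+        # feed only the first time step of the first input variable (channel 0) and train the network to reconstruct it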
+        x_hat = self.build_model(self.x[:, 0, :, :, 0:1], self.filters, self.kernels)
         # Loss
-        self.total_loss = l1_loss(x[, 0,:, :,0], x_hat[, 0,:, :,0])
+        self.total_loss = l1_loss(self.x[:, 0, :, :, 0:1], x_hat)
 
         # Optimizer
         self.train_op = tf.train.AdamOptimizer(
@@ -99,9 +98,13 @@
 
     def build_model(self, x, filters, kernels):
         """Fully convolutional network"""
+        idx = 0
         for f, k in zip(filters[:-1], kernels[:-1]):
-            x = ld.conv_layer(x, kernel_size=k, stride=1, num_features=f, idx="Conv_layer", activate="leaky_relu")
-        output = ld.conv_layer(x, kernel_size=kernels[-1], stride=1, num_features=filters[-1], idx="Conv_layer", activate="linear")
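+            # hidden convolutional layers with leaky-ReLU activation; idx gives each layer a unique scope name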
+            print("1",x)
+            x = ld.conv_layer(x, kernel_size=k, stride=1, num_features=f, idx="conv_layer_"+str(idx) , activate="leaky_relu")
+            print("2",x)
+            idx += 1
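+        # last convolutional layer with linear activation produces the network output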
+        output = ld.conv_layer(x, kernel_size=kernels[-1], stride=1, num_features=filters[-1], idx="Conv_last_layer", activate="linear")
         return output