diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index b16e33919d9f335e0d2b45ad3309ad901e568f57..5242c8848bd5c0fa4ee8927532a2c8e7bf15743d 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -426,6 +426,13 @@ class TrainModel(object):
             fetch_list = fetch_list + ["inputs", "total_loss"]
             self.saver_loss = fetch_list[-1]
             self.saver_loss_name = "Total loss"
+        if self.video_model.__class__.__name__ == "WeatherBenchModel":
+            fetch_list = fetch_list + ["total_loss"]
+            self.saver_loss = fetch_list[-1]
+            self.saver_loss_name = "Total loss"
+        elif getattr(self, "saver_loss", None) is None:  # fail only if no branch above configured the saver loss
+            raise ValueError("self.saver_loss is not set up for your video model class {}".format(self.video_model.__class__.__name__))
+
         self.fetches = self.generate_fetches(fetch_list)
 
 
@@ -491,7 +498,7 @@ class TrainModel(object):
         if self.video_model.__class__.__name__ == "McNetVideoPredictionModel":
             print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"], results["L_p"],
                                                                            results["L_gdl"],results["L_GAN"]))
-        elif self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
+        elif self.video_model.__class__.__name__ in ("VanillaConvLstmVideoPredictionModel", "WeatherBenchModel"):
             print ("Total_loss:{}".format(results["total_loss"]))
         elif self.video_model.__class__.__name__ == "SAVPVideoPredictionModel":
             print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}, gen_l1_loss: {}"
diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
index 07b5cbf26251e5d61520b07cc25e22d64338d195..77ceb43ba93be8c4eaa00ec068e1fc12d43ab382 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
@@ -23,11 +23,10 @@ class WeatherBenchModel(object):
         self.hparams_dict = hparams_dict
         self.hparams = self.parse_hparams()
         self.learning_rate = self.hparams.lr
+        print("hparams",self.hparams)
         self.filters = self.hparams.filters
         self.kernels = self.hparams.kernels
-        self.context_frames = self.hparams.context_frames
         self.max_epochs = self.hparams.max_epochs
-        self.loss_fun = self.hparams.loss_fun
         self.batch_size = self.hparams.batch_size
         self.outputs = {}
         self.total_loss = None
@@ -50,19 +49,19 @@ class WeatherBenchModel(object):
         Returns:
             A dict with the following hyperparameters.
             context_frames : the number of ground-truth frames to pass in at start.
-            sequence_length : the number of frames in the video sequence
             max_epochs : the number of epochs to train model
             lr : learning rate
             loss_fun : the loss function
+            filters : list contains the filters of each convolutional layer
+            kernels : list contains the kernels size for each convolutional layer
         """
         hparams = dict(
-            context_frames =12,
+            sequence_length =12,
             max_epochs = 20,
             batch_size = 40,
             lr = 0.001,
-            loss_fun = "mse",
             shuffle_on_val= True,
-            filter = [64, 64, 64, 64, 2],
+            filters = [64, 64, 64, 64, 2],
             kernels = [5, 5, 5, 5, 5]
         )
         return hparams
@@ -70,16 +69,15 @@ class WeatherBenchModel(object):
 
     def build_graph(self, x):
         self.is_build_graph = False
-        self.inputs = x
         self.x = x["images"]
         self.global_step = tf.train.get_or_create_global_step()
         original_global_variables = tf.global_variables()
 
         # Architecture
-        x_hat = self.build_model(x, self.filters, self.kernels)
+        x_hat = self.build_model(self.x[:,0,:, :,0:1], self.filters, self.kernels)
 
         # Loss
-        self.total_loss = l1_loss(x[, 0,:, :,0], x_hat[, 0,:, :,0])
+        self.total_loss = l1_loss(self.x[:,0,:, :,0:1], x_hat)
 
         # Optimizer
         self.train_op = tf.train.AdamOptimizer(
@@ -99,9 +97,10 @@ class WeatherBenchModel(object):
 
     def build_model(self, x, filters, kernels):
         """Fully convolutional network"""
-        for f, k in zip(filters[:-1], kernels[:-1]):
-            x = ld.conv_layer(x, kernel_size=k, stride=1, num_features=f, idx="Conv_layer", activate="leaky_relu")
-        output = ld.conv_layer(x, kernel_size=kernels[-1], stride=1, num_features=filters[-1], idx="Conv_layer", activate="linear")
+        # Each hidden layer gets a distinct idx string (presumably its variable-scope name in ld.conv_layer - confirm).
+        for idx, (f, k) in enumerate(zip(filters[:-1], kernels[:-1])):
+            x = ld.conv_layer(x, kernel_size=k, stride=1, num_features=f, idx="conv_layer_" + str(idx), activate="leaky_relu")
+        output = ld.conv_layer(x, kernel_size=kernels[-1], stride=1, num_features=filters[-1], idx="Conv_last_layer", activate="linear")
         return output
 
 