diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py index 50bd39e9780132f174b1607dcee1d47802d31825..5c3563bc9fbe4709f4f87f4d5d6905269c949ba8 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py +++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py @@ -13,7 +13,7 @@ from model_modules.video_prediction.losses import * class WeatherBenchModel(object): - def __init__(self, hparams_dict=None,**kwargs): + def __init__(self, hparams_dict=None, mode="train",**kwargs): """ This is class for building weahterBench architecture by using updated hparameters args: @@ -21,9 +21,9 @@ class WeatherBenchModel(object): hparams_dict: dict, the dictionary contains the hparaemters names and values """ self.hparams_dict = hparams_dict + self.mode = mode self.hparams = self.parse_hparams() self.learning_rate = self.hparams.lr - print("hparams",self.hparams) self.filters = self.hparams.filters self.kernels = self.hparams.kernels self.max_epochs = self.hparams.max_epochs @@ -56,7 +56,7 @@ class WeatherBenchModel(object): kernels : list contains the kernels size for each convolutional layer """ hparams = dict( - sequence_length =12, + sequence_length =13, context_frames =1, max_epochs = 20, batch_size = 40, @@ -79,7 +79,7 @@ class WeatherBenchModel(object): x_hat = self.build_model(self.x[:,0,:, :, :],self.filters, self.kernels) # Loss - self.total_loss = l1_loss(self.x[:,0,:, :,0], x_hat[:,:,:,0]) + self.total_loss = l1_loss(self.x[:,1,:, :,:], x_hat[:,:,:,:]) # Optimizer self.train_op = tf.train.AdamOptimizer( @@ -89,9 +89,11 @@ class WeatherBenchModel(object): self.outputs["total_loss"] = self.total_loss # inferences + if self.mode == "test": + self.outputs["gen_images"] = self.forecast(self.x, 12, self.filters, self.kernels) + else: + self.outputs["gen_images"] = x_hat - - self.outputs["gen_images"] = self.forecast(self.x[:,0,:, :,0:1], 12, self.filters, self.kernels) # Summary op tf.summary.scalar("total_loss", self.total_loss) self.summary_op = tf.summary.merge_all() @@ -105,21 +107,22 @@ class WeatherBenchModel(object): """Fully convolutional network""" idx = 0 for f, k in zip(filters[:-1], kernels[:-1]): - print("1",x) with tf.variable_scope("conv_layer_"+str(idx),reuse=tf.AUTO_REUSE): x = ld.conv_layer(x, kernel_size=k, stride=1, num_features=f, idx="conv_layer_"+str(idx) , activate="leaky_relu") - print("2",x) idx += 1 with tf.variable_scope("Conv_last_layer",reuse=tf.AUTO_REUSE): output = ld.conv_layer(x, kernel_size=kernels[-1], stride=1, num_features=filters[-1], idx="Conv_last_layer", activate="linear") - print("output dimension", output) return output - def forecast(self, inputs, forecast_time, filters, kernels): + def forecast(self, x, forecast_time, filters, kernels): x_hat = [] + for i in range(forecast_time): - x_pred = self.build_model(self.x[:,i,:, :,:],filters,kernels) + if i == 0: + x_pred = self.build_model(x[:,i,:, :,:],filters,kernels) + else: + x_pred = self.build_model(x_pred,filters,kernels) x_hat.append(x_pred) x_hat = tf.stack(x_hat)