From 6f3b4588698a98b11075ba9879e7b7abcea879c9 Mon Sep 17 00:00:00 2001
From: gong1 <b.gong@fz-juelich.de>
Date: Tue, 14 Jun 2022 10:24:45 +0200
Subject: [PATCH] update WeatherBench

---
 .../models/weatherBench3DCNN.py               | 29 ++++++++++++++-----
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
index 77ceb43b..50bd39e9 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
@@ -57,11 +57,12 @@ class WeatherBenchModel(object):
         """
         hparams = dict(
             sequence_length =12,
+            context_frames =1,
             max_epochs = 20,
             batch_size = 40,
             lr = 0.001,
             shuffle_on_val= True,
-            filters = [64, 64, 64, 64, 2],
+            filters = [64, 64, 64, 64, 3],
             kernels = [5, 5, 5, 5, 5]
         )
         return hparams
@@ -75,9 +76,10 @@ class WeatherBenchModel(object):
         original_global_variables = tf.global_variables()
 
         # Architecture
-        x_hat = self.build_model(self.x[:,0,:, :,0:1], self.filters, self.kernels)
+        x_hat = self.build_model(self.x[:,0,:, :, :],self.filters, self.kernels)
         # Loss
-        self.total_loss = l1_loss(self.x[:,0,:, :,0:1], x_hat)
+
+        self.total_loss = l1_loss(self.x[:,0,:, :,0], x_hat[:,:,:,0])
 
         # Optimizer
         self.train_op = tf.train.AdamOptimizer(
@@ -85,7 +87,11 @@
 
         # outputs
         self.outputs["total_loss"] = self.total_loss
+
+        # inferences
+
+        self.outputs["gen_images"] = self.forecast(self.x[:,0,:, :,0:1], 12, self.filters, self.kernels)
         # Summary op
         tf.summary.scalar("total_loss", self.total_loss)
         self.summary_op = tf.summary.merge_all()
 
@@ -100,13 +106,22 @@
         idx = 0
         for f, k in zip(filters[:-1], kernels[:-1]):
             print("1",x)
-            x = ld.conv_layer(x, kernel_size=k, stride=1, num_features=f, idx="conv_layer_"+str(idx) , activate="leaky_relu")
+            with tf.variable_scope("conv_layer_"+str(idx),reuse=tf.AUTO_REUSE):
+                x = ld.conv_layer(x, kernel_size=k, stride=1, num_features=f, idx="conv_layer_"+str(idx) , activate="leaky_relu")
             print("2",x)
             idx += 1
-        output = ld.conv_layer(x, kernel_size=kernels[-1], stride=1, num_features=filters[-1], idx="Conv_last_layer", activate="linear")
+        with tf.variable_scope("Conv_last_layer",reuse=tf.AUTO_REUSE):
+            output = ld.conv_layer(x, kernel_size=kernels[-1], stride=1, num_features=filters[-1], idx="Conv_last_layer", activate="linear")
+        print("output dimension", output)
         return output
 
+    def forecast(self, inputs, forecast_time, filters, kernels):
+        x_hat = []
+        for i in range(forecast_time):
+            x_pred = self.build_model(self.x[:,i,:, :,:],filters,kernels)
+            x_hat.append(x_pred)
-
-
+        x_hat = tf.stack(x_hat)
+        x_hat = tf.transpose(x_hat, [1, 0, 2, 3, 4])
+        return x_hat
 
 
-- 
GitLab