From 6f3b4588698a98b11075ba9879e7b7abcea879c9 Mon Sep 17 00:00:00 2001
From: gong1 <b.gong@fz-juelich.de>
Date: Tue, 14 Jun 2022 10:24:45 +0200
Subject: [PATCH] update WeatherBench

---
 .../models/weatherBench3DCNN.py               | 29 ++++++++++++++-----
 1 file changed, 22 insertions(+), 7 deletions(-)

diff --git a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
index 77ceb43b..50bd39e9 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/weatherBench3DCNN.py
@@ -57,11 +57,12 @@ class WeatherBenchModel(object):
             """
         hparams = dict(
             sequence_length =12,
+            context_frames =1,
             max_epochs = 20,
             batch_size = 40,
             lr = 0.001,
             shuffle_on_val= True,
-            filters = [64, 64, 64, 64, 2],
+            filters = [64, 64, 64, 64, 3],
             kernels = [5, 5, 5, 5, 5]
         )
         return hparams
@@ -75,9 +76,10 @@ class WeatherBenchModel(object):
         original_global_variables = tf.global_variables()
 
         # Architecture
-        x_hat = self.build_model(self.x[:,0,:, :,0:1], self.filters, self.kernels)
+        x_hat = self.build_model(self.x[:,0,:, :, :],self.filters, self.kernels)
         # Loss
-        self.total_loss = l1_loss(self.x[:,0,:, :,0:1], x_hat)
+        
+        self.total_loss = l1_loss(self.x[:,0,:, :,0], x_hat[:,:,:,0])
 
         # Optimizer
         self.train_op = tf.train.AdamOptimizer(
@@ -85,7 +87,11 @@ class WeatherBenchModel(object):
 
         # outputs
         self.outputs["total_loss"] = self.total_loss
+       
+        # inferences
 
+
+        self.outputs["gen_images"] = self.forecast(self.x[:,0,:, :,0:1], 12, self.filters, self.kernels)
         # Summary op
         tf.summary.scalar("total_loss", self.total_loss)
         self.summary_op = tf.summary.merge_all()
@@ -100,13 +106,22 @@ class WeatherBenchModel(object):
         idx = 0 
         for f, k in zip(filters[:-1], kernels[:-1]):
             print("1",x)
-            x = ld.conv_layer(x, kernel_size=k, stride=1, num_features=f, idx="conv_layer_"+str(idx) , activate="leaky_relu")
+            with tf.variable_scope("conv_layer_"+str(idx),reuse=tf.AUTO_REUSE):
+                x = ld.conv_layer(x, kernel_size=k, stride=1, num_features=f, idx="conv_layer_"+str(idx) , activate="leaky_relu")
             print("2",x)
             idx += 1
-        output = ld.conv_layer(x, kernel_size=kernels[-1], stride=1, num_features=filters[-1], idx="Conv_last_layer", activate="linear")
+        with tf.variable_scope("Conv_last_layer",reuse=tf.AUTO_REUSE):
+            output = ld.conv_layer(x, kernel_size=kernels[-1], stride=1, num_features=filters[-1], idx="Conv_last_layer", activate="linear")
+            print("output dimension", output)
         return output
 
 
+    def forecast(self, inputs, forecast_time, filters, kernels):
+        x_hat = []
+        for i in range(forecast_time):
+            x_pred = self.build_model(self.x[:,i,:, :,:],filters,kernels)
+            x_hat.append(x_pred)
 
-
-
+        x_hat = tf.stack(x_hat)
+        x_hat = tf.transpose(x_hat, [1, 0, 2, 3, 4])
+        return x_hat
-- 
GitLab