Skip to content
Snippets Groups Projects
Commit 80a43db8 authored by masak1112's avatar masak1112
Browse files

update weatherBenchModel

parent 6f3b4588
Branches bing_issue#005_clean_and_update_CICD_file
No related tags found
No related merge requests found
Pipeline #103140 passed
...@@ -13,7 +13,7 @@ from model_modules.video_prediction.losses import * ...@@ -13,7 +13,7 @@ from model_modules.video_prediction.losses import *
class WeatherBenchModel(object): class WeatherBenchModel(object):
def __init__(self, hparams_dict=None,**kwargs): def __init__(self, hparams_dict=None, mode="train",**kwargs):
""" """
This is class for building weahterBench architecture by using updated hparameters This is class for building weahterBench architecture by using updated hparameters
args: args:
...@@ -21,9 +21,9 @@ class WeatherBenchModel(object): ...@@ -21,9 +21,9 @@ class WeatherBenchModel(object):
hparams_dict: dict, the dictionary contains the hparaemters names and values hparams_dict: dict, the dictionary contains the hparaemters names and values
""" """
self.hparams_dict = hparams_dict self.hparams_dict = hparams_dict
self.mode = mode
self.hparams = self.parse_hparams() self.hparams = self.parse_hparams()
self.learning_rate = self.hparams.lr self.learning_rate = self.hparams.lr
print("hparams",self.hparams)
self.filters = self.hparams.filters self.filters = self.hparams.filters
self.kernels = self.hparams.kernels self.kernels = self.hparams.kernels
self.max_epochs = self.hparams.max_epochs self.max_epochs = self.hparams.max_epochs
...@@ -56,7 +56,7 @@ class WeatherBenchModel(object): ...@@ -56,7 +56,7 @@ class WeatherBenchModel(object):
kernels : list contains the kernels size for each convolutional layer kernels : list contains the kernels size for each convolutional layer
""" """
hparams = dict( hparams = dict(
sequence_length =12, sequence_length =13,
context_frames =1, context_frames =1,
max_epochs = 20, max_epochs = 20,
batch_size = 40, batch_size = 40,
...@@ -79,7 +79,7 @@ class WeatherBenchModel(object): ...@@ -79,7 +79,7 @@ class WeatherBenchModel(object):
x_hat = self.build_model(self.x[:,0,:, :, :],self.filters, self.kernels) x_hat = self.build_model(self.x[:,0,:, :, :],self.filters, self.kernels)
# Loss # Loss
self.total_loss = l1_loss(self.x[:,0,:, :,0], x_hat[:,:,:,0]) self.total_loss = l1_loss(self.x[:,1,:, :,:], x_hat[:,:,:,:])
# Optimizer # Optimizer
self.train_op = tf.train.AdamOptimizer( self.train_op = tf.train.AdamOptimizer(
...@@ -89,9 +89,11 @@ class WeatherBenchModel(object): ...@@ -89,9 +89,11 @@ class WeatherBenchModel(object):
self.outputs["total_loss"] = self.total_loss self.outputs["total_loss"] = self.total_loss
# inferences # inferences
if self.mode == "test":
self.outputs["gen_images"] = self.forecast(self.x, 12, self.filters, self.kernels)
else:
self.outputs["gen_images"] = x_hat
self.outputs["gen_images"] = self.forecast(self.x[:,0,:, :,0:1], 12, self.filters, self.kernels)
# Summary op # Summary op
tf.summary.scalar("total_loss", self.total_loss) tf.summary.scalar("total_loss", self.total_loss)
self.summary_op = tf.summary.merge_all() self.summary_op = tf.summary.merge_all()
...@@ -105,21 +107,22 @@ class WeatherBenchModel(object): ...@@ -105,21 +107,22 @@ class WeatherBenchModel(object):
"""Fully convolutional network""" """Fully convolutional network"""
idx = 0 idx = 0
for f, k in zip(filters[:-1], kernels[:-1]): for f, k in zip(filters[:-1], kernels[:-1]):
print("1",x)
with tf.variable_scope("conv_layer_"+str(idx),reuse=tf.AUTO_REUSE): with tf.variable_scope("conv_layer_"+str(idx),reuse=tf.AUTO_REUSE):
x = ld.conv_layer(x, kernel_size=k, stride=1, num_features=f, idx="conv_layer_"+str(idx) , activate="leaky_relu") x = ld.conv_layer(x, kernel_size=k, stride=1, num_features=f, idx="conv_layer_"+str(idx) , activate="leaky_relu")
print("2",x)
idx += 1 idx += 1
with tf.variable_scope("Conv_last_layer",reuse=tf.AUTO_REUSE): with tf.variable_scope("Conv_last_layer",reuse=tf.AUTO_REUSE):
output = ld.conv_layer(x, kernel_size=kernels[-1], stride=1, num_features=filters[-1], idx="Conv_last_layer", activate="linear") output = ld.conv_layer(x, kernel_size=kernels[-1], stride=1, num_features=filters[-1], idx="Conv_last_layer", activate="linear")
print("output dimension", output)
return output return output
def forecast(self, inputs, forecast_time, filters, kernels): def forecast(self, x, forecast_time, filters, kernels):
x_hat = [] x_hat = []
for i in range(forecast_time): for i in range(forecast_time):
x_pred = self.build_model(self.x[:,i,:, :,:],filters,kernels) if i == 0:
x_pred = self.build_model(x[:,i,:, :,:],filters,kernels)
else:
x_pred = self.build_model(x_pred,filters,kernels)
x_hat.append(x_pred) x_hat.append(x_pred)
x_hat = tf.stack(x_hat) x_hat = tf.stack(x_hat)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment