Commit 80a43db8 authored by masak1112

update weatherBenchModel

parent 6f3b4588
Pipeline #103140 passed
@@ -13,7 +13,7 @@ from model_modules.video_prediction.losses import *
class WeatherBenchModel(object):
def __init__(self, hparams_dict=None,**kwargs):
def __init__(self, hparams_dict=None, mode="train",**kwargs):
"""
This is the class for building the WeatherBench architecture from the updated hparams
args:
@@ -21,9 +21,9 @@ class WeatherBenchModel(object):
hparams_dict: dict, the dictionary containing the hyperparameter names and values
"""
self.hparams_dict = hparams_dict
self.mode = mode
self.hparams = self.parse_hparams()
self.learning_rate = self.hparams.lr
print("hparams",self.hparams)
self.filters = self.hparams.filters
self.kernels = self.hparams.kernels
self.max_epochs = self.hparams.max_epochs
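For orientation, the newly added mode argument defaults to "train"; a caller that wants the inference branch further down would construct the model roughly as follows (the exact call site is not part of this diff, and the hparams values are just the defaults shown below):

    hparams = {"sequence_length": 13, "context_frames": 1, "max_epochs": 20, "batch_size": 40}
    model = WeatherBenchModel(hparams_dict=hparams, mode="test")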
@@ -56,7 +56,7 @@ class WeatherBenchModel(object):
kernels : list contains the kernels size for each convolutional layer
"""
hparams = dict(
sequence_length =12,
sequence_length =13,
context_frames =1,
max_epochs = 20,
batch_size = 40,
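These defaults are merged with the user-supplied hparams_dict by parse_hparams, whose body is not included in this diff; a minimal sketch of the assumed merge behaviour, giving the attribute-style access (hparams.lr, hparams.filters, ...) used in __init__:

    from types import SimpleNamespace

    def parse_hparams(defaults, hparams_dict=None):
        # start from the defaults and let user-supplied values override them
        merged = dict(defaults)
        merged.update(hparams_dict or {})
        return SimpleNamespace(**merged)  # exposes hparams.lr, hparams.filters, ...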
@@ -79,7 +79,7 @@ class WeatherBenchModel(object):
x_hat = self.build_model(self.x[:,0,:, :, :],self.filters, self.kernels)
# Loss
self.total_loss = l1_loss(self.x[:,0,:, :,0], x_hat[:,:,:,0])
self.total_loss = l1_loss(self.x[:,1,:, :,:], x_hat[:,:,:,:])
# Optimizer
self.train_op = tf.train.AdamOptimizer(
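The changed loss line moves the target from frame 0, channel 0 to frame 1 over all channels, so the one-step network is now trained to predict the next time step rather than to reproduce its own input channel. l1_loss itself is imported from model_modules.video_prediction.losses and is not shown in this diff; a typical TensorFlow 1.x definition, assumed here, is the mean absolute error:

    import tensorflow as tf

    def l1_loss(targets, predictions):
        # mean absolute difference over all elements
        return tf.reduce_mean(tf.abs(targets - predictions))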
@@ -89,9 +89,11 @@ class WeatherBenchModel(object):
self.outputs["total_loss"] = self.total_loss
# inferences
if self.mode == "test":
self.outputs["gen_images"] = self.forecast(self.x, 12, self.filters, self.kernels)
else:
self.outputs["gen_images"] = x_hat
self.outputs["gen_images"] = self.forecast(self.x[:,0,:, :,0:1], 12, self.filters, self.kernels)
# Summary op
tf.summary.scalar("total_loss", self.total_loss)
self.summary_op = tf.summary.merge_all()
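total_loss, train_op and summary_op are the usual TensorFlow 1.x graph-mode handles; how they are driven is outside this diff, but a training loop over them would typically look like the sketch below (the graph construction call, the log directory, num_steps and the assumption that inputs come from a tf.data pipeline rather than a feed_dict are all placeholders, not part of this file):

    import tensorflow as tf

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        writer = tf.summary.FileWriter("logs", sess.graph)
        for step in range(num_steps):
            _, loss, summary = sess.run([model.train_op, model.total_loss, model.summary_op])
            writer.add_summary(summary, step)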
@@ -105,21 +107,22 @@ class WeatherBenchModel(object):
"""Fully convolutional network"""
idx = 0
for f, k in zip(filters[:-1], kernels[:-1]):
print("1",x)
with tf.variable_scope("conv_layer_"+str(idx),reuse=tf.AUTO_REUSE):
x = ld.conv_layer(x, kernel_size=k, stride=1, num_features=f, idx="conv_layer_"+str(idx) , activate="leaky_relu")
print("2",x)
idx += 1
with tf.variable_scope("Conv_last_layer",reuse=tf.AUTO_REUSE):
output = ld.conv_layer(x, kernel_size=kernels[-1], stride=1, num_features=filters[-1], idx="Conv_last_layer", activate="linear")
print("output dimension", output)
return output
def forecast(self, inputs, forecast_time, filters, kernels):
def forecast(self, x, forecast_time, filters, kernels):
x_hat = []
for i in range(forecast_time):
x_pred = self.build_model(self.x[:,i,:, :,:],filters,kernels)
if i == 0:
x_pred = self.build_model(x[:,i,:, :,:],filters,kernels)
else:
x_pred = self.build_model(x_pred,filters,kernels)
x_hat.append(x_pred)
x_hat = tf.stack(x_hat)
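The reworked forecast now feeds each prediction back in as the input for the next step instead of always reading frame i from the ground-truth sequence, i.e. an autoregressive rollout. Stripped of the class context, the pattern looks like the sketch below (tensor shapes and the step_fn name are assumptions; step_fn stands in for build_model with the configured filters and kernels):

    import tensorflow as tf

    def autoregressive_forecast(first_frame, forecast_time, step_fn):
        # first_frame: (batch, height, width, channels); step_fn: one pass through the network
        preds = []
        current = first_frame
        for _ in range(forecast_time):
            current = step_fn(current)  # the prediction becomes the next input
            preds.append(current)
        return tf.stack(preds)          # (forecast_time, batch, height, width, channels)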