Skip to content
Snippets Groups Projects
Commit 80a43db8 authored by masak1112's avatar masak1112
Browse files

update weatherBenchModel

parent 6f3b4588
Branches
Tags
No related merge requests found
Pipeline #103140 passed
......@@ -13,7 +13,7 @@ from model_modules.video_prediction.losses import *
class WeatherBenchModel(object):
def __init__(self, hparams_dict=None,**kwargs):
def __init__(self, hparams_dict=None, mode="train",**kwargs):
"""
This is class for building weahterBench architecture by using updated hparameters
args:
......@@ -21,9 +21,9 @@ class WeatherBenchModel(object):
hparams_dict: dict, the dictionary contains the hparaemters names and values
"""
self.hparams_dict = hparams_dict
self.mode = mode
self.hparams = self.parse_hparams()
self.learning_rate = self.hparams.lr
print("hparams",self.hparams)
self.filters = self.hparams.filters
self.kernels = self.hparams.kernels
self.max_epochs = self.hparams.max_epochs
......@@ -56,7 +56,7 @@ class WeatherBenchModel(object):
kernels : list contains the kernels size for each convolutional layer
"""
hparams = dict(
sequence_length =12,
sequence_length =13,
context_frames =1,
max_epochs = 20,
batch_size = 40,
......@@ -79,7 +79,7 @@ class WeatherBenchModel(object):
x_hat = self.build_model(self.x[:,0,:, :, :],self.filters, self.kernels)
# Loss
self.total_loss = l1_loss(self.x[:,0,:, :,0], x_hat[:,:,:,0])
self.total_loss = l1_loss(self.x[:,1,:, :,:], x_hat[:,:,:,:])
# Optimizer
self.train_op = tf.train.AdamOptimizer(
......@@ -89,9 +89,11 @@ class WeatherBenchModel(object):
self.outputs["total_loss"] = self.total_loss
# inferences
if self.mode == "test":
self.outputs["gen_images"] = self.forecast(self.x, 12, self.filters, self.kernels)
else:
self.outputs["gen_images"] = x_hat
self.outputs["gen_images"] = self.forecast(self.x[:,0,:, :,0:1], 12, self.filters, self.kernels)
# Summary op
tf.summary.scalar("total_loss", self.total_loss)
self.summary_op = tf.summary.merge_all()
......@@ -105,21 +107,22 @@ class WeatherBenchModel(object):
"""Fully convolutional network"""
idx = 0
for f, k in zip(filters[:-1], kernels[:-1]):
print("1",x)
with tf.variable_scope("conv_layer_"+str(idx),reuse=tf.AUTO_REUSE):
x = ld.conv_layer(x, kernel_size=k, stride=1, num_features=f, idx="conv_layer_"+str(idx) , activate="leaky_relu")
print("2",x)
idx += 1
with tf.variable_scope("Conv_last_layer",reuse=tf.AUTO_REUSE):
output = ld.conv_layer(x, kernel_size=kernels[-1], stride=1, num_features=filters[-1], idx="Conv_last_layer", activate="linear")
print("output dimension", output)
return output
def forecast(self, inputs, forecast_time, filters, kernels):
def forecast(self, x, forecast_time, filters, kernels):
x_hat = []
for i in range(forecast_time):
x_pred = self.build_model(self.x[:,i,:, :,:],filters,kernels)
if i == 0:
x_pred = self.build_model(x[:,i,:, :,:],filters,kernels)
else:
x_pred = self.build_model(x_pred,filters,kernels)
x_hat.append(x_pred)
x_hat = tf.stack(x_hat)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment