Commit 5ceb5778 authored by BING GONG

update the hparameters for weatherbench model

parent 8c29a0ce
Pipeline #102331 passed
 {
 "batch_size": 4,
-"lr": 0.001,
+"lr": 0.0001,
 "max_epochs":20,
 "context_frames":12,
 "loss_fun":"mse",
 "opt_var": "0",
-"shuffle_on_val":true
+"shuffle_on_val":true,
+"filters": [64, 64, 64, 64, 2],
+"kernels": [5, 5, 5, 5, 5]
 }
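Note: the new "filters" and "kernels" lists pair up entry by entry, one convolutional layer per pair, with the final 2-channel entry acting as the output layer. A minimal sketch of such a stack, assuming "same" padding and ELU hidden activations (illustration only, written with plain Keras Conv2D rather than this repo's PeriodicConv2D):

import tensorflow as tf

def build_conv_stack(filters, kernels):
    # One Conv2D per (filters[i], kernels[i]) pair; "same" padding keeps the
    # lat/lon grid size, and the last (2-channel) layer is a linear output head.
    layers = []
    for i, (f, k) in enumerate(zip(filters, kernels)):
        activation = "elu" if i < len(filters) - 1 else None  # assumed activation choice
        layers.append(tf.keras.layers.Conv2D(f, k, padding="same", activation=activation))
    return tf.keras.Sequential(layers)

model = build_conv_stack([64, 64, 64, 64, 2], [5, 5, 5, 5, 5])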
......
@@ -24,14 +24,11 @@ class WeatherBenchModel(object):
         self.hparams = self.parse_hparams()
         self.learning_rate = self.hparams.lr
         self.filters = self.hparams.filters
-        self.kernels = self.hparams.kernes
+        self.kernels = self.hparams.kernels
         self.context_frames = self.hparams.context_frames
-        self.sequence_length = self.hparams.sequence_length
-        self.predict_frames = self.sequence_length - self.context_frames
         self.max_epochs = self.hparams.max_epochs
         self.loss_fun = self.hparams.loss_fun
         self.batch_size = self.hparams.batch_size
-        self.recon_weight = self.hparams.recon_weight

         self.outputs = {}
         self.total_loss = None
@@ -60,15 +57,13 @@ class WeatherBenchModel(object):
         """
         hparams = dict(
             context_frames=12,
-            sequence_length=24,
             max_epochs=20,
             batch_size=40,
             lr=0.001,
             loss_fun="mse",
             shuffle_on_val=True,
-            filter=4,
-            kernels=4,
+            filters=[64, 64, 64, 64, 2],
+            kernels=[5, 5, 5, 5, 5]
         )
         return hparams
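parse_hparams itself is not part of this diff. A minimal sketch of one way the defaults above could be merged with a JSON override file (such as the hparams file changed earlier) and exposed with attribute access like self.hparams.lr; the helper name and merging behaviour are assumptions, not this repo's implementation:

import json
from types import SimpleNamespace

def parse_hparams_sketch(defaults, json_path=None):
    # Start from the hard-coded defaults, overlay any keys found in the JSON
    # file, and return an object that supports hparams.lr-style access.
    merged = dict(defaults)
    if json_path is not None:
        with open(json_path) as f:
            merged.update(json.load(f))
    return SimpleNamespace(**merged)

# hparams = parse_hparams_sketch(dict(lr=0.001, batch_size=40), "model_hparams.json")  # hypothetical path
# hparams.lr  -> 0.0001 if the JSON overrides it, otherwise the 0.001 default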
@@ -85,8 +80,10 @@ class WeatherBenchModel(object):
         x_hat = self.build_model(x, self.filters, self.kernels, dr=0)
         # Loss
         self.total_loss = l1_loss(x[..., 0], x_hat[..., 0])

         # Optimizer
-        self.train_op = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.total_loss, var_list=self.gen_vars)
+        self.train_op = tf.train.AdamOptimizer(
+            learning_rate=self.learning_rate).minimize(self.total_loss, global_step=self.global_step)
+
         # outputs
         self.outputs["total_loss"] = self.total_loss
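The optimizer change swaps var_list=self.gen_vars for global_step=self.global_step: in the TF 1.x API, passing global_step to minimize() makes the training op increment that counter on every update, which step-based checkpointing and summaries rely on. A self-contained TF 1.x sketch of the pattern (toy graph, not this repo's actual build method):

import tensorflow as tf  # TF 1.x graph-mode API, matching tf.train.AdamOptimizer above

x = tf.placeholder(tf.float32, shape=[None, 32, 64, 1])
target = tf.placeholder(tf.float32, shape=[None, 32, 64, 1])
pred = tf.layers.conv2d(x, filters=1, kernel_size=5, padding="same")

loss = tf.reduce_mean(tf.abs(target - pred))        # L1-style loss, as in the hunk above
global_step = tf.train.get_or_create_global_step()  # counter incremented by the train op
train_op = tf.train.AdamOptimizer(learning_rate=1e-4).minimize(loss, global_step=global_step)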
@@ -111,7 +108,7 @@ class WeatherBenchModel(object):
 class PeriodicPadding2D(object):
-    def __init__(self, x, pad_width):
+    def __init__(self, pad_width):
         self.pad_width = pad_width
@@ -127,7 +124,6 @@ class PeriodicPadding2D(object):
         return inputs_padded

 class PeriodicConv2D(object):
     def __init__(self, filters, kernel_size, conv_kwargs={}):
......
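The body of PeriodicPadding2D is collapsed in this diff; the usual trick for global weather grids is wrap-around padding along the longitude (width) axis so that convolutions treat 0 and 360 degrees as neighbours. A minimal sketch, assuming NHWC tensors with longitude on axis 2 (illustration only, not the exact method body):

import tensorflow as tf

def periodic_pad_lon(inputs, pad_width):
    # Prepend the last pad_width longitude columns and append the first
    # pad_width columns, so the grid is periodic in longitude for the conv.
    left = inputs[:, :, -pad_width:, :]
    right = inputs[:, :, :pad_width, :]
    return tf.concat([left, inputs, right], axis=2)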