Skip to content
Snippets Groups Projects
Commit 48827e48 authored by gong1's avatar gong1
Browse files

Merge branch 'bing_issue#020_add_cross_entropy_loss_function_to_convLSTM' into develop

parents 653459a8 0bd530c8
No related branches found
No related tags found
No related merge requests found
......@@ -38,7 +38,8 @@ destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/model
# for choosing the model, convLSTM,savp, mcnet,vae
model=convLSTM
model_hparams=../hparams/era5/${model}/model_hparams.json
dataset=moving_mnist
model_hparams=../hparams/${dataset}/${model}/model_hparams.json
# rund training
srun python ../scripts/train_dummy_moving_mnist.py --input_dir ${source_dir}/tfrecords/ --dataset moving_mnist --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/${model}/
srun python ../scripts/train_moving_mnist.py --input_dir ${source_dir}/tfrecords/ --dataset moving_mnist --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/${model}_bing_20200902/
......@@ -4,7 +4,8 @@
"lr": 0.001,
"max_epochs":2,
"context_frames":10,
"sequence_length":20
"sequence_length":20,
"loss_fun":"rmse"
}
......
{
"batch_size": 10,
"lr": 0.001,
"max_epochs":2,
"context_frames":10,
"sequence_length":20,
"loss_fun":"cross_entropy"
}
......@@ -28,6 +28,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
self.sequence_length = self.hparams.sequence_length
self.predict_frames = self.sequence_length - self.context_frames
self.max_epochs = self.hparams.max_epochs
self.loss_fun = self.hparams.loss_fun
def get_default_hparams_dict(self):
"""
The keys of this dict define valid hyperparameters for instances of
......@@ -40,13 +42,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
batch_size: batch size for training.
lr: learning rate. if decay steps is non-zero, this is the
learning rate for steps <= decay_step.
max_steps: number of training steps.
context_frames: the number of ground-truth frames to pass :qin at
start. Must be specified during instantiation.
sequence_length: the number of frames in the video sequence,
including the context frames, so this model predicts
`sequence_length - context_frames` future frames. Must be
specified during instantiation.
max_epochs: number of training epochs, each epoch equal to sample_size/batch_size
loss_fun: string can be either "rmse" or "cross_entropy", loss function has to be set from the user
"""
default_hparams = super(VanillaConvLstmVideoPredictionModel, self).get_default_hparams_dict()
print ("default hparams",default_hparams)
......@@ -54,6 +51,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
batch_size=16,
lr=0.001,
max_epochs=3000,
loss_fun = None
)
return dict(itertools.chain(default_hparams.items(), hparams.items()))
......@@ -71,8 +69,17 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
# tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0]))
# This is the loss function (RMSE):
#This is loss function only for 1 channel (temperature RMSE)
if self.loss_fun == "rmse":
self.total_loss = tf.reduce_mean(
tf.square(self.x[:, self.context_frames:,:,:,0] - self.x_hat_predict_frames[:,:,:,:,0]))
elif self.loss_fun == "cross_entropy":
x_flatten = tf.reshape(self.x[:, self.context_frames:,:,:,0],[-1])
x_hat_predict_frames_flatten = tf.reshape(self.x_hat_predict_frames[:,:,:,:,0],[-1])
bce = tf.keras.losses.BinaryCrossentropy()
self.total_loss = bce(x_flatten,x_hat_predict_frames_flatten)
else:
raise ValueError("Loss function is not selected properly, you should chose either 'rmse' or 'cross_entropy'")
#This is the loss for only all the channels(temperature, geo500, pressure)
#self.total_loss = tf.reduce_mean(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment