Skip to content
Snippets Groups Projects
Commit 48827e48 authored by gong1's avatar gong1
Browse files

Merge branch 'bing_issue#020_add_cross_entropy_loss_function_to_convLSTM' into develop

parents 653459a8 0bd530c8
No related branches found
No related tags found
No related merge requests found
...@@ -38,7 +38,8 @@ destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/model ...@@ -38,7 +38,8 @@ destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/model
# for choosing the model, convLSTM,savp, mcnet,vae # for choosing the model, convLSTM,savp, mcnet,vae
model=convLSTM model=convLSTM
model_hparams=../hparams/era5/${model}/model_hparams.json dataset=moving_mnist
model_hparams=../hparams/${dataset}/${model}/model_hparams.json
# run training # run training
srun python ../scripts/train_dummy_moving_mnist.py --input_dir ${source_dir}/tfrecords/ --dataset moving_mnist --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/${model}/ srun python ../scripts/train_moving_mnist.py --input_dir ${source_dir}/tfrecords/ --dataset moving_mnist --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/${model}_bing_20200902/
...@@ -4,7 +4,8 @@ ...@@ -4,7 +4,8 @@
"lr": 0.001, "lr": 0.001,
"max_epochs":2, "max_epochs":2,
"context_frames":10, "context_frames":10,
"sequence_length":20 "sequence_length":20,
"loss_fun":"rmse"
} }
......
{
"batch_size": 10,
"lr": 0.001,
"max_epochs":2,
"context_frames":10,
"sequence_length":20,
"loss_fun":"cross_entropy"
}
...@@ -28,6 +28,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): ...@@ -28,6 +28,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
self.sequence_length = self.hparams.sequence_length self.sequence_length = self.hparams.sequence_length
self.predict_frames = self.sequence_length - self.context_frames self.predict_frames = self.sequence_length - self.context_frames
self.max_epochs = self.hparams.max_epochs self.max_epochs = self.hparams.max_epochs
self.loss_fun = self.hparams.loss_fun
def get_default_hparams_dict(self): def get_default_hparams_dict(self):
""" """
The keys of this dict define valid hyperparameters for instances of The keys of this dict define valid hyperparameters for instances of
...@@ -40,13 +42,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): ...@@ -40,13 +42,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
batch_size: batch size for training. batch_size: batch size for training.
lr: learning rate. if decay steps is non-zero, this is the lr: learning rate. if decay steps is non-zero, this is the
learning rate for steps <= decay_step. learning rate for steps <= decay_step.
max_steps: number of training steps. max_epochs: number of training epochs; each epoch equals sample_size/batch_size training steps.
context_frames: the number of ground-truth frames to pass in at loss_fun: string, either "rmse" or "cross_entropy"; the loss function must be set by the user.
start. Must be specified during instantiation.
sequence_length: the number of frames in the video sequence,
including the context frames, so this model predicts
`sequence_length - context_frames` future frames. Must be
specified during instantiation.
""" """
default_hparams = super(VanillaConvLstmVideoPredictionModel, self).get_default_hparams_dict() default_hparams = super(VanillaConvLstmVideoPredictionModel, self).get_default_hparams_dict()
print ("default hparams",default_hparams) print ("default hparams",default_hparams)
...@@ -54,6 +51,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): ...@@ -54,6 +51,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
batch_size=16, batch_size=16,
lr=0.001, lr=0.001,
max_epochs=3000, max_epochs=3000,
loss_fun = None
) )
return dict(itertools.chain(default_hparams.items(), hparams.items())) return dict(itertools.chain(default_hparams.items(), hparams.items()))
...@@ -71,8 +69,17 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): ...@@ -71,8 +69,17 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
# tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0])) # tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0]))
# This is the loss function (RMSE): # This is the loss function (RMSE):
#This is loss function only for 1 channel (temperature RMSE) #This is loss function only for 1 channel (temperature RMSE)
if self.loss_fun == "rmse":
self.total_loss = tf.reduce_mean( self.total_loss = tf.reduce_mean(
tf.square(self.x[:, self.context_frames:,:,:,0] - self.x_hat_predict_frames[:,:,:,:,0])) tf.square(self.x[:, self.context_frames:,:,:,0] - self.x_hat_predict_frames[:,:,:,:,0]))
elif self.loss_fun == "cross_entropy":
x_flatten = tf.reshape(self.x[:, self.context_frames:,:,:,0],[-1])
x_hat_predict_frames_flatten = tf.reshape(self.x_hat_predict_frames[:,:,:,:,0],[-1])
bce = tf.keras.losses.BinaryCrossentropy()
self.total_loss = bce(x_flatten,x_hat_predict_frames_flatten)
else:
raise ValueError("Loss function is not selected properly, you should chose either 'rmse' or 'cross_entropy'")
#This is the loss for all the channels (temperature, geo500, pressure) #This is the loss for all the channels (temperature, geo500, pressure)
#self.total_loss = tf.reduce_mean( #self.total_loss = tf.reduce_mean(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment