Skip to content
Snippets Groups Projects
Commit c637fe94 authored by gong1's avatar gong1
Browse files

add cross entropy loss fun into the convLSTM module

parent 2d72f1b6
No related branches found
No related tags found
No related merge requests found
Pipeline #44405 failed
......@@ -28,6 +28,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
self.sequence_length = self.hparams.sequence_length
self.predict_frames = self.sequence_length - self.context_frames
self.max_epochs = self.hparams.max_epochs
self.loss_fun = self.hparams.loss_fun
def get_default_hparams_dict(self):
"""
The keys of this dict define valid hyperparameters for instances of
......@@ -40,13 +42,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
batch_size: batch size for training.
lr: learning rate. if decay steps is non-zero, this is the
learning rate for steps <= decay_step.
max_steps: number of training steps.
context_frames: the number of ground-truth frames to pass :qin at
start. Must be specified during instantiation.
sequence_length: the number of frames in the video sequence,
including the context frames, so this model predicts
`sequence_length - context_frames` future frames. Must be
specified during instantiation.
max_epochs: number of training epochs, each epoch equal to sample_size/batch_size
loss_fun: string can be either "rmse" or "cross_entropy", loss function has to be set from the user
"""
default_hparams = super(VanillaConvLstmVideoPredictionModel, self).get_default_hparams_dict()
print ("default hparams",default_hparams)
......@@ -54,6 +51,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
batch_size=16,
lr=0.001,
max_epochs=3000,
loss_fun = None
)
return dict(itertools.chain(default_hparams.items(), hparams.items()))
......@@ -71,8 +69,17 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
# tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0]))
# This is the loss function (RMSE):
#This is loss function only for 1 channel (temperature RMSE)
if self.loss_fun == "rmse":
self.total_loss = tf.reduce_mean(
tf.square(self.x[:, self.context_frames:,:,:,0] - self.x_hat_predict_frames[:,:,:,:,0]))
elif self.loss_fun == "cross_entropy":
x_flatten = tf.reshape(self.x[:, self.context_frames:,:,:,0],[-1])
x_hat_predict_frames_flatten = tf.reshape(self.x_hat_predict_frames[:,:,:,:,0],[-1])
bce = tf.keras.losses.BinaryCrossentropy()
self.total_loss = bce(x_flatten,x_hat_predict_frames_flatten)
else:
raise ValueError("Loss function is not selected properly, you should chose either 'rmse' or 'cross_entropy'")
#This is the loss for only all the channels(temperature, geo500, pressure)
#self.total_loss = tf.reduce_mean(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment