Skip to content
Snippets Groups Projects
Commit c637fe94 authored by gong1's avatar gong1
Browse files

add cross entropy loss fun into the convLSTM module

parent 2d72f1b6
Branches
Tags
No related merge requests found
Pipeline #44405 failed
...@@ -28,6 +28,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): ...@@ -28,6 +28,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
self.sequence_length = self.hparams.sequence_length self.sequence_length = self.hparams.sequence_length
self.predict_frames = self.sequence_length - self.context_frames self.predict_frames = self.sequence_length - self.context_frames
self.max_epochs = self.hparams.max_epochs self.max_epochs = self.hparams.max_epochs
self.loss_fun = self.hparams.loss_fun
def get_default_hparams_dict(self): def get_default_hparams_dict(self):
""" """
The keys of this dict define valid hyperparameters for instances of The keys of this dict define valid hyperparameters for instances of
...@@ -40,13 +42,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): ...@@ -40,13 +42,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
batch_size: batch size for training. batch_size: batch size for training.
lr: learning rate. if decay steps is non-zero, this is the lr: learning rate. if decay steps is non-zero, this is the
learning rate for steps <= decay_step. learning rate for steps <= decay_step.
max_steps: number of training steps. max_epochs: number of training epochs, each epoch equal to sample_size/batch_size
context_frames: the number of ground-truth frames to pass :qin at loss_fun: string can be either "rmse" or "cross_entropy", loss function has to be set from the user
start. Must be specified during instantiation.
sequence_length: the number of frames in the video sequence,
including the context frames, so this model predicts
`sequence_length - context_frames` future frames. Must be
specified during instantiation.
""" """
default_hparams = super(VanillaConvLstmVideoPredictionModel, self).get_default_hparams_dict() default_hparams = super(VanillaConvLstmVideoPredictionModel, self).get_default_hparams_dict()
print ("default hparams",default_hparams) print ("default hparams",default_hparams)
...@@ -54,6 +51,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): ...@@ -54,6 +51,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
batch_size=16, batch_size=16,
lr=0.001, lr=0.001,
max_epochs=3000, max_epochs=3000,
loss_fun = None
) )
return dict(itertools.chain(default_hparams.items(), hparams.items())) return dict(itertools.chain(default_hparams.items(), hparams.items()))
...@@ -71,8 +69,17 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): ...@@ -71,8 +69,17 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
# tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0])) # tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0]))
# This is the loss function (RMSE): # This is the loss function (RMSE):
#This is loss function only for 1 channel (temperature RMSE) #This is loss function only for 1 channel (temperature RMSE)
if self.loss_fun == "rmse":
self.total_loss = tf.reduce_mean( self.total_loss = tf.reduce_mean(
tf.square(self.x[:, self.context_frames:,:,:,0] - self.x_hat_predict_frames[:,:,:,:,0])) tf.square(self.x[:, self.context_frames:,:,:,0] - self.x_hat_predict_frames[:,:,:,:,0]))
elif self.loss_fun == "cross_entropy":
x_flatten = tf.reshape(self.x[:, self.context_frames:,:,:,0],[-1])
x_hat_predict_frames_flatten = tf.reshape(self.x_hat_predict_frames[:,:,:,:,0],[-1])
bce = tf.keras.losses.BinaryCrossentropy()
self.total_loss = bce(x_flatten,x_hat_predict_frames_flatten)
else:
raise ValueError("Loss function is not selected properly, you should chose either 'rmse' or 'cross_entropy'")
#This is the loss for only all the channels(temperature, geo500, pressure) #This is the loss for only all the channels(temperature, geo500, pressure)
#self.total_loss = tf.reduce_mean( #self.total_loss = tf.reduce_mean(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment