diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
index 01a4f7ce5d6430f19a1e4b99c4cba956b3f7682b..d3b3d4817faa10e6f5db5257fdf4cd526e6d01c7 100644
--- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
@@ -28,6 +28,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
         self.sequence_length = self.hparams.sequence_length
         self.predict_frames = self.sequence_length - self.context_frames
         self.max_epochs = self.hparams.max_epochs
+        self.loss_fun = self.hparams.loss_fun
+
    def get_default_hparams_dict(self):
        """
        The keys of this dict define valid hyperparameters for instances of
@@ -40,20 +42,16 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
            batch_size: batch size for training.
            lr: learning rate. if decay steps is non-zero, this is the
                learning rate for steps <= decay_step.
-            max_steps: number of training steps.
-            context_frames: the number of ground-truth frames to pass :qin at
-                start. Must be specified during instantiation.
-            sequence_length: the number of frames in the video sequence,
-                including the context frames, so this model predicts
-                `sequence_length - context_frames` future frames. Must be
-                specified during instantiation.
-        """
+            max_epochs: number of training epochs; each epoch corresponds to sample_size/batch_size training steps.
+            loss_fun: loss function, either "rmse" or "cross_entropy"; must be set explicitly by the user.
+        """
        default_hparams = super(VanillaConvLstmVideoPredictionModel, self).get_default_hparams_dict()
        print ("default hparams",default_hparams)
        hparams = dict(
            batch_size=16,
            lr=0.001,
            max_epochs=3000,
+            loss_fun=None,
        )
        return dict(itertools.chain(default_hparams.items(), hparams.items()))

@@ -71,9 +69,18 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
        #  tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0]))
        # This is the loss function (RMSE):
        #This is loss function only for 1 channel (temperature RMSE)
-        self.total_loss = tf.reduce_mean(
-            tf.square(self.x[:, self.context_frames:,:,:,0] - self.x_hat_predict_frames[:,:,:,:,0]))
-
+        if self.loss_fun == "rmse":
+            self.total_loss = tf.reduce_mean(
+                tf.square(self.x[:, self.context_frames:,:,:,0] - self.x_hat_predict_frames[:,:,:,:,0]))
+        elif self.loss_fun == "cross_entropy":
+            x_flatten = tf.reshape(self.x[:, self.context_frames:,:,:,0], [-1])
+            x_hat_predict_frames_flatten = tf.reshape(self.x_hat_predict_frames[:,:,:,:,0], [-1])
+            bce = tf.keras.losses.BinaryCrossentropy()
+            self.total_loss = bce(x_flatten, x_hat_predict_frames_flatten)
+        else:
+            raise ValueError("Invalid loss_fun: choose either 'rmse' or 'cross_entropy'")
+
+
        #This is the loss for only all the channels(temperature, geo500, pressure)
        #self.total_loss = tf.reduce_mean(
        #    tf.square(self.x[:, self.context_frames:,:,:,:] - self.x_hat_predict_frames[:,:,:,:,:]))
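
For context, here is a minimal standalone sketch of the loss selection this patch introduces. The helper name `select_total_loss` and the random test tensors are purely illustrative, not part of the repository. Note that, as written in the patch, the `"rmse"` branch computes a mean squared error without taking a square root, and the `"cross_entropy"` branch assumes the channel-0 values lie in [0, 1], as expected by `tf.keras.losses.BinaryCrossentropy`.

```python
import tensorflow as tf


def select_total_loss(x, x_hat_predict_frames, context_frames, loss_fun):
    """Sketch of the per-channel loss selection added above (channel 0 only).

    x:                    ground truth, shape [batch, sequence_length, h, w, channels]
    x_hat_predict_frames: predictions,  shape [batch, predict_frames, h, w, channels]
    """
    target = x[:, context_frames:, :, :, 0]
    prediction = x_hat_predict_frames[:, :, :, :, 0]
    if loss_fun == "rmse":
        # Mean squared error over the predicted frames (no square root applied).
        return tf.reduce_mean(tf.square(target - prediction))
    elif loss_fun == "cross_entropy":
        # Pixel-wise binary cross-entropy on the flattened tensors;
        # assumes the values are scaled to [0, 1].
        bce = tf.keras.losses.BinaryCrossentropy()
        return bce(tf.reshape(target, [-1]), tf.reshape(prediction, [-1]))
    else:
        raise ValueError("loss_fun must be either 'rmse' or 'cross_entropy'")


# Illustrative usage with random data: 20-frame sequences, 10 context frames.
x = tf.random.uniform([16, 20, 64, 64, 3])
x_hat = tf.random.uniform([16, 10, 64, 64, 3])
loss = select_total_loss(x, x_hat, context_frames=10, loss_fun="cross_entropy")
```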