diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
index 01a4f7ce5d6430f19a1e4b99c4cba956b3f7682b..d3b3d4817faa10e6f5db5257fdf4cd526e6d01c7 100644
--- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py
@@ -28,6 +28,8 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
         self.sequence_length = self.hparams.sequence_length
         self.predict_frames = self.sequence_length - self.context_frames
         self.max_epochs = self.hparams.max_epochs
+        self.loss_fun = self.hparams.loss_fun
+
     def get_default_hparams_dict(self):
         """
         The keys of this dict define valid hyperparameters for instances of
@@ -40,20 +42,16 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
             batch_size: batch size for training.
             lr: learning rate. if decay steps is non-zero, this is the
                 learning rate for steps <= decay_step.
-            max_steps: number of training steps.
-            context_frames: the number of ground-truth frames to pass :qin at
-                start. Must be specified during instantiation.
-            sequence_length: the number of frames in the video sequence,
-                including the context frames, so this model predicts
-                `sequence_length - context_frames` future frames. Must be
-                specified during instantiation.
-        """
+            max_epochs: number of training epochs; one epoch corresponds to sample_size/batch_size training steps.
+            loss_fun: loss function, must be set by the user to either "rmse" or "cross_entropy".
+        """
         default_hparams = super(VanillaConvLstmVideoPredictionModel, self).get_default_hparams_dict()
         print ("default hparams",default_hparams)
         hparams = dict(
             batch_size=16,
             lr=0.001,
             max_epochs=3000,
+            loss_fun=None,
         )
 
         return dict(itertools.chain(default_hparams.items(), hparams.items()))
@@ -71,9 +69,18 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel):
         #    tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0]))
         # This is the loss function (RMSE):
         #This is loss function only for 1 channel (temperature RMSE)
-        self.total_loss = tf.reduce_mean(
-            tf.square(self.x[:, self.context_frames:,:,:,0] - self.x_hat_predict_frames[:,:,:,:,0]))
-            
+        if self.loss_fun == "rmse":
+            # root of the mean squared error over the predicted frames (channel 0: temperature)
+            self.total_loss = tf.sqrt(tf.reduce_mean(
+                tf.square(self.x[:, self.context_frames:, :, :, 0] - self.x_hat_predict_frames[:, :, :, :, 0])))
+        elif self.loss_fun == "cross_entropy":
+            # binary cross-entropy expects both tensors to be scaled to [0, 1]
+            x_flatten = tf.reshape(self.x[:, self.context_frames:, :, :, 0], [-1])
+            x_hat_predict_frames_flatten = tf.reshape(self.x_hat_predict_frames[:, :, :, :, 0], [-1])
+            bce = tf.keras.losses.BinaryCrossentropy()
+            self.total_loss = bce(x_flatten, x_hat_predict_frames_flatten)
+        else:
+            raise ValueError("Unknown loss function '%s': choose either 'rmse' or 'cross_entropy'" % self.loss_fun)
         #This is the loss for only all the channels(temperature, geo500, pressure)
         #self.total_loss = tf.reduce_mean(
         #    tf.square(self.x[:, self.context_frames:,:,:,:] - self.x_hat_predict_frames[:,:,:,:,:]))
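For reference, a minimal, self-contained sketch of what the two new loss branches compute. TensorFlow 2.x eager mode and toy tensor shapes are assumed here; in the model itself, `self.x` and `self.x_hat_predict_frames` are graph tensors built in `build_graph`.

```python
import numpy as np
import tensorflow as tf

# Toy stand-ins for self.x and self.x_hat_predict_frames with layout
# (batch, time, height, width, channels); values are assumed to lie in
# [0, 1], which the "cross_entropy" option requires.
context_frames = 2
x = np.random.uniform(0, 1, size=(4, 10, 8, 8, 1)).astype(np.float32)
x_hat = np.random.uniform(0, 1, size=(4, 8, 8, 8, 1)).astype(np.float32)

# "rmse": root mean squared error over channel 0 of the predicted frames.
rmse = tf.sqrt(tf.reduce_mean(
    tf.square(x[:, context_frames:, :, :, 0] - x_hat[:, :, :, :, 0])))

# "cross_entropy": flatten both tensors and call Keras BCE with the
# ground truth first and the prediction second.
bce = tf.keras.losses.BinaryCrossentropy()
cross_entropy = bce(tf.reshape(x[:, context_frames:, :, :, 0], [-1]),
                    tf.reshape(x_hat[:, :, :, :, 0], [-1]))

print(float(rmse), float(cross_entropy))
```

Note that `loss_fun` defaults to `None`, so it must be set explicitly in the model hyperparameters (e.g. a dict entry like `{"loss_fun": "rmse"}`); otherwise building the graph raises the `ValueError` above.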