diff --git a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py index 7560a225e7651728e2ca8d2107d7f32458106c86..be978eae491d875214bd67322419ec332ccb53d5 100644 --- a/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py +++ b/video_prediction_savp/video_prediction/models/vanilla_convLSTM_model.py @@ -69,6 +69,7 @@ class VanillaConvLstmVideoPredictionModel(BaseVideoPredictionModel): print("self.x_hat_context_frames,",self.x_hat_context_frames) #self.context_frames_loss = tf.reduce_mean( # tf.square(self.x[:, :self.context_frames, :, :, 0] - self.x_hat_context_frames[:, :, :, :, 0])) + # This is the loss function (RMSE): self.total_loss = tf.reduce_mean( tf.square(self.x[:, self.context_frames:, :, :, 0] - self.x_hat_context_frames[:, (self.context_frames-1):-1, :, :, 0]))