diff --git a/video_prediction_savp/video_prediction/models/vanilla_vae_model.py b/video_prediction_savp/video_prediction/models/vanilla_vae_model.py index a00411077efa2ea86ee5b67db543b4000b0ea19a..3f5cd795b5161a894a6256a9619f71f668477527 100644 --- a/video_prediction_savp/video_prediction/models/vanilla_vae_model.py +++ b/video_prediction_savp/video_prediction/models/vanilla_vae_model.py @@ -74,7 +74,8 @@ class VanillaVAEVideoPredictionModel(BaseVideoPredictionModel): ) return dict(itertools.chain(default_hparams.items(), hparams.items())) - def build_graph(self,x): + + def build_graph(self,x): self.x = x["images"] #self.global_step = tf.Variable(0, name = 'global_step', trainable = False) self.global_step = tf.train.get_or_create_global_step()