diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index a23d59cd9f0b11a5bb5fda4dd2d1f95b97bf0927..e092a8558fbebaf6bea68aff9dc63c692a3a2c6b 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -378,7 +378,7 @@ class TrainModel(object): self.saver_loss = fetch_list[-1] self.saver_loss_name = "Total loss" if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel": - fetch_list = fetch_list + [inputs, "total_loss", "inputs"] + fetch_list = fetch_list + ["inputs", "total_loss"] self.saver_loss = fetch_list[-1] self.saver_loss_name = "Total loss"