diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 98df062bf9919efa0657c8a63a8a842a54d2502b..08202678b896f43630ae2cecf88f34fe2fe67298 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -312,13 +312,12 @@ class TrainModel(object): self.results = sess.run(self.fetches) # ...and run it here! # Note: For SAVP, the obtained loss is a list where the first element is of interest, for convLSTM, # it's just a number. Thus, with list(<losses>)[0], we can handle both - train_losses.append(self.results[list(self.saver_loss)[0]]) + train_losses.append(list(self.results[self.saver_loss])[0]) # run and fetch losses for validation data val_handle_eval = sess.run(self.val_handle) self.create_fetches_for_val() self.val_results = sess.run(self.val_fetches, feed_dict={self.train_handle: val_handle_eval}) - val_losses.append(self.val_results[list(self.saver_loss)[0]]) - print(val_losses) + val_losses.append(list(self.val_results[self.saver_loss])[0]) self.write_to_summary() self.print_results(step, self.results) # track iteration time