diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 08202678b896f43630ae2cecf88f34fe2fe67298..0b91b18174ca3416f4be2d9f96019bb8f646767e 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -469,12 +469,17 @@ class TrainModel(object): :return flag: True if model should be saved :return loss_avg: updated minimum loss """ + method = TrainModel.set_model_saver_flag.__name__ + save_flag = False if len(losses) <= niter_steps*2: loss_avg = old_min_loss return save_flag, loss_avg loss_avg = np.mean(losses[-niter_steps:]) + # print diagnosis + print("%{0}: Current loss: {1:.4f}, old minimum: {2:.4f}, model will be saved: {3}" + .format(method, loss_avg, old_min_loss, loss_avg < old_min_loss)) if loss_avg < old_min_loss: save_flag = True else: