From 517835a99f1b752da99aca0a47540c216ff8490a Mon Sep 17 00:00:00 2001 From: Bing Gong <b.gong@fz-juelich.de> Date: Sun, 15 Nov 2020 20:35:02 +0000 Subject: [PATCH] Update main_train_models.py --- .../main_scripts/main_train_models.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index dcdcbe88..60ab2cb9 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -210,8 +210,8 @@ class TrainModel(object): def calculate_samples_and_epochs(self): """ - Clculate the number of samples for train/val/testing dataset. The samples are used for training model for each epoch. - Clculate the iterations (samples multiple by max_epochs) for training. + Calculate the number of samples for train dataset, which is used for each epoch training + Calculate the iterations (samples multiple by max_epochs) for training. """ batch_size = self.video_model.hparams.batch_size max_epochs = self.video_model.hparams.max_epochs #the number of epochs @@ -220,6 +220,9 @@ class TrainModel(object): self.total_steps = self.steps_per_epoch * max_epochs def restore(self,sess, checkpoints, restore_to_checkpoint_mapping=None): + """ + Restore the models checkpoints if the checkpoints is given + """ if checkpoints: var_list = self.video_model.saveable_variables # possibly restore from multiple checkpoints. useful if subset of weights @@ -240,7 +243,7 @@ class TrainModel(object): def restore_train_val_losses(self): """ - Restore the train and validation losses in the pickle file + Restore the train and validation losses in the pickle file if checkpoint is given """ if self.start_step == 0: train_losses = [] @@ -302,7 +305,7 @@ class TrainModel(object): def create_fetches_for_train(self): """ - Fetch variables in the graph, this can be custermized based on models and based on the needs of users + Fetch variables in the graph, this can be custermized based on models and also the needs of users """ #This is the base fetch that for all the models self.fetches = {"train_op": self.video_model.train_op} @@ -317,14 +320,14 @@ class TrainModel(object): def fetches_for_train_convLSTM(self): """ - Fetch variables in the graph for convLSTM model, this can be custermized based on models and based on the needs of users + Fetch variables in the graph for convLSTM model, this can be custermized based on models and the needs of users """ pass def fetches_for_train_savp(self): """ - Fetch variables in the graph for savp model, this can be custermized based on models and based on the needs of users + Fetch variables in the graph for savp model, this can be custermized based on models and the needs of users """ self.fetches["g_losses"] = self.video_model.g_losses self.fetches["d_losses"] = self.video_model.d_losses @@ -333,7 +336,7 @@ class TrainModel(object): def fetches_for_train_mcnet(self): """ - Fetch variables in the graph for mcnet model, this can be custermized based on models and based on the needs of users + Fetch variables in the graph for mcnet model, this can be custermized based on models and the needs of users """ self.fetches["L_p"] = self.video_model.L_p self.fetches["L_gdl"] = self.video_model.L_gdl @@ -349,7 +352,7 @@ class TrainModel(object): def create_fetches_for_val(self): """ - Fetch variables in the graph for validation dataset, this can be custermized based on models and based on the needs of users + Fetch variables in the graph for validation dataset, this can be custermized based on models and the needs of users """ self.val_fetches = {"total_loss": self.video_model.total_loss} self.val_fetches["summary"] = self.video_model.summary_op @@ -383,7 +386,7 @@ class TrainModel(object): """ Function to plot training losses for train and val datasets against steps params: - train_losses/val_losses :list, train losses, which length should be equal to the number of training steps + train_losses/val_losses : list, train losses, which length should be equal to the number of training steps step : int, current training step output_dir : str, the path to save the plot """ -- GitLab