From 517835a99f1b752da99aca0a47540c216ff8490a Mon Sep 17 00:00:00 2001
From: Bing Gong <b.gong@fz-juelich.de>
Date: Sun, 15 Nov 2020 20:35:02 +0000
Subject: [PATCH] Update main_train_models.py

---
 .../main_scripts/main_train_models.py         | 21 +++++++++++--------
 1 file changed, 12 insertions(+), 9 deletions(-)

diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py
index dcdcbe88..60ab2cb9 100644
--- a/video_prediction_tools/main_scripts/main_train_models.py
+++ b/video_prediction_tools/main_scripts/main_train_models.py
@@ -210,8 +210,8 @@ class TrainModel(object):
 
     def calculate_samples_and_epochs(self):
         """
-        Clculate the number of samples for train/val/testing dataset. The samples are used for training model for each epoch. 
-        Clculate the iterations (samples multiple by max_epochs) for training.
+        Calculate the number of samples for train dataset, which is used for each epoch training
+        Calculate the iterations (samples multiple by max_epochs) for training.
         """
         batch_size = self.video_model.hparams.batch_size
         max_epochs = self.video_model.hparams.max_epochs #the number of epochs
@@ -220,6 +220,9 @@ class TrainModel(object):
         self.total_steps = self.steps_per_epoch * max_epochs
 
     def restore(self,sess, checkpoints, restore_to_checkpoint_mapping=None):
+        """
+        Restore the models checkpoints if the checkpoints is given
+        """
         if checkpoints:
            var_list = self.video_model.saveable_variables
            # possibly restore from multiple checkpoints. useful if subset of weights
@@ -240,7 +243,7 @@ class TrainModel(object):
     
     def restore_train_val_losses(self):
         """
-        Restore the train and validation losses in the pickle file
+        Restore the train and validation losses in the pickle file if checkpoint is given 
         """
         if self.start_step == 0:
             train_losses = []
@@ -302,7 +305,7 @@ class TrainModel(object):
  
     def create_fetches_for_train(self):
        """
-       Fetch variables in the graph, this can be custermized based on models and based on the needs of users
+       Fetch variables in the graph, this can be custermized based on models and also the needs of users
        """
        #This is the base fetch that for all the  models
        self.fetches = {"train_op": self.video_model.train_op}
@@ -317,14 +320,14 @@ class TrainModel(object):
     
     def fetches_for_train_convLSTM(self):
         """
-        Fetch variables in the graph for convLSTM model, this can be custermized based on models and based on the needs of users
+        Fetch variables in the graph for convLSTM model, this can be custermized based on models and the needs of users
         """
         pass
 
  
     def fetches_for_train_savp(self):
         """
-        Fetch variables in the graph for savp model, this can be custermized based on models and based on the needs of users
+        Fetch variables in the graph for savp model, this can be custermized based on models and the needs of users
         """
         self.fetches["g_losses"] = self.video_model.g_losses
         self.fetches["d_losses"] = self.video_model.d_losses
@@ -333,7 +336,7 @@ class TrainModel(object):
 
     def fetches_for_train_mcnet(self):
         """
-        Fetch variables in the graph for mcnet model, this can be custermized based on models and based on the needs of users
+        Fetch variables in the graph for mcnet model, this can be custermized based on models and  the needs of users
         """
         self.fetches["L_p"] = self.video_model.L_p
         self.fetches["L_gdl"] = self.video_model.L_gdl
@@ -349,7 +352,7 @@ class TrainModel(object):
 
     def create_fetches_for_val(self):
         """
-        Fetch variables in the graph for validation dataset, this can be custermized based on models and based on the needs of users
+        Fetch variables in the graph for validation dataset, this can be custermized based on models and the needs of users
         """
         self.val_fetches = {"total_loss": self.video_model.total_loss}
         self.val_fetches["summary"] = self.video_model.summary_op
@@ -383,7 +386,7 @@ class TrainModel(object):
         """
         Function to plot training losses for train and val datasets against steps
         params:
-            train_losses/val_losses       :list, train losses, which length should be equal to the number of training steps
+            train_losses/val_losses       : list, train losses, which length should be equal to the number of training steps
             step                          : int, current training step
             output_dir                    : str,  the path to save the plot
         """ 
-- 
GitLab