diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 92d33264d3dd831638a4afb33bc0fc1f18e41d79..08202678b896f43630ae2cecf88f34fe2fe67298 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -7,13 +7,14 @@ We took the code implementation from https://github.com/alexlee-gk/video_predict """ __email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong" +__author__ = "Bing Gong, Michael Langguth" __date__ = "2020-10-22" import argparse import errno import json import os +from typing import Union, List import random import time import numpy as np @@ -22,30 +23,30 @@ from model_modules.video_prediction import datasets, models import matplotlib.pyplot as plt import pickle as pkl from model_modules.video_prediction.utils import tf_utils +from general_utils import * class TrainModel(object): - def __init__(self, input_dir=None, output_dir=None, datasplit_dict=None, - model_hparams_dict=None, model=None, checkpoint=None, dataset=None, - gpu_mem_frac=None, seed=None, args=None, save_diag_intv=20, save_model_intv = 1000): - + def __init__(self, input_dir: str = None, output_dir: str = None, datasplit_dict: str = None, + model_hparams_dict: str = None, model: str = None, checkpoint: str = None, dataset: str = None, + gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01): + """ + Class instance for training the models + :param input_dir: parent directory under which "pickle" and "tfrecords" files directiory are located + :param output_dir: directory where all the output is saved (e.g. model, JSON-files, training curves etc.) + :param datasplit_dict: JSON-file for defining data splitting + :param model_hparams_dict: JSON-file of model hyperparameters + :param model: model class name + :param checkpoint: checkpoint directory (pre-trained models) + :param dataset: dataset class name + :param gpu_mem_frac: fraction of GPU memory to be preallocated + :param seed: seed of the randomizers + :param args: list of arguments passed + :param diag_intv_frac: interval for diagnozing and saving model; the fraction with respect to the number of + steps per epoch is denoted here, e.g. 0.01 with 1000 iteration steps per epoch results + into a diagnozing intreval of 10 iteration steps (= interval over which validation loss + is averaged to identify best model performance) """ - This class aims to train the models - args: - input_dir : str, the path to the PreprocessData directory which is parent directory of "Pickle" and "tfrecords" files directiory. - output_dir : str, directory where json files, summary, model, gifs, etc are saved. " - "default is logs_dir/model_fname, where model_fname consists of " - "information from model and model_hparams - datasplit_dict : str, the path pointing to the datasplit_config json file - model_hparams_dict : str, a json file of model hyperparameters - checkpoint : str, directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000) - dataset : str, dataset class name - model : str, model class name - gpu_mem_frac : float, fraction of gpu memory to use - save_diag_intv : int, interval of iteration steps for which the loss is saved and for which a new - loss curve is plotted - save_model_intv : int, interval of iteration for which the model is checkpointed - """ self.input_dir = os.path.normpath(input_dir) self.output_dir = os.path.normpath(output_dir) self.datasplit_dict = datasplit_dict @@ -56,9 +57,12 @@ class TrainModel(object): self.gpu_mem_frac = gpu_mem_frac self.seed = seed self.args = args + self.diag_intv_frac = diag_intv_frac # for diagnozing and saving the model during training - self.save_diag_intv = save_diag_intv - self.save_model_intv = save_model_intv + self.saver_loss = None # set in create_fetches_for_train-method + self.saver_loss_name = None # set in create_fetches_for_train-method + self.saver_loss_dict = None # set in create_fetches_for_train-method if loss of interest is nested + self.diag_intv_step = None # set in calculate_samples_and_epochs-method def setup(self): self.set_seed() @@ -87,13 +91,14 @@ class TrainModel(object): """ Checks if output directory is existing. """ + method = TrainModel.check_output_dir.__name__ + if self.output_dir is None: - raise ValueError("Output_dir-argument is empty. Please define a proper output_dir") + raise ValueError("%{0}: Output_dir-argument is empty. Please define a proper output_dir".format(method)) elif not os.path.isdir(self.output_dir): - raise NotADirectoryError("Base output_dir {0} does not exist. Please pass a proper output_dir and "+\ - "make use of config_train.py.") + raise NotADirectoryError("Base output_dir {0} does not exist. Pass a proper output_dir".format(method) + + " and make use of env_setup/generate_runscript.py.") - def get_model_hparams_dict(self): """ Get and read model_hparams_dict from json file to dictionary @@ -104,7 +109,6 @@ class TrainModel(object): self.model_hparams_dict_load.update(json.loads(f.read())) return self.model_hparams_dict_load - def load_params_from_checkpoints_dir(self): """ If checkpoint is none, load and read the json files of datasplit_config, and hparam_config, @@ -112,6 +116,8 @@ class TrainModel(object): If the checkpoint is given, the configuration of dataset, model and options in the checkpoint dir will be restored and used for continue training. """ + method = TrainModel.load_params_from_checkpoints_dir.__name__ + if self.checkpoint: self.checkpoint_dir = os.path.normpath(self.checkpoint) if not os.path.isdir(self.checkpoint): @@ -121,18 +127,18 @@ class TrainModel(object): # read and overwrite dataset and model from checkpoint try: with open(os.path.join(self.checkpoint_dir, "options.json")) as f: - print("loading options from checkpoint %s" % self.checkpoint) + print("%{0}: Loading options from checkpoint '{1}'".format(method, self.checkpoint)) self.options = json.loads(f.read()) self.dataset = self.dataset or self.options['dataset'] self.model = self.model or self.options['model'] except FileNotFoundError: - print("options.json was not loaded because it does not exist in {0}".format(self.checkpoint_dir)) + print("%{0}: options.json does not exist in {1}".format(method, self.checkpoint_dir)) # loading hyperparameters from checkpoint try: with open(os.path.join(self.checkpoint_dir, "model_hparams.json")) as f: self.model_hparams_dict_load.update(json.loads(f.read())) except FileNotFoundError: - print("model_hparams.json was not loaded because it does not exist in {0}".format(self.checkpoint_dir)) + print("%{0}: model_hparams.json does not exist in {1}".format(method, self.checkpoint_dir)) def setup_dataset(self): """ @@ -144,7 +150,7 @@ class TrainModel(object): hparams_dict_config=self.model_hparams_dict) self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val', datasplit_config=self.datasplit_dict, hparams_dict_config=self.model_hparams_dict) - # ML/BG 2021-06-15: Is the following needed? + # Retrieve sequence length from dataset self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) def setup_model(self, mode="train"): @@ -160,30 +166,27 @@ class TrainModel(object): build model graph """ self.video_model.build_graph(self.inputs) - def make_dataset_iterator(self): """ Prepare the dataset interator for training and validation """ self.batch_size = self.model_hparams_dict_load["batch_size"] - self.train_tf_dataset = self.train_dataset.make_dataset(self.batch_size) - self.train_iterator = self.train_tf_dataset.make_one_shot_iterator() + train_tf_dataset = self.train_dataset.make_dataset(self.batch_size) + train_iterator = train_tf_dataset.make_one_shot_iterator() # The `Iterator.string_handle()` method returns a tensor that can be evaluated # and used to feed the `handle` placeholder. - self.train_handle = self.train_iterator.string_handle() - self.val_tf_dataset = self.val_dataset.make_dataset(self.batch_size) - self.val_iterator = self.val_tf_dataset.make_one_shot_iterator() - self.val_handle = self.val_iterator.string_handle() + self.train_handle = train_iterator.string_handle() + val_tf_dataset = self.val_dataset.make_dataset(self.batch_size) + val_iterator = val_tf_dataset.make_one_shot_iterator() + self.val_handle = val_iterator.string_handle() self.iterator = tf.data.Iterator.from_string_handle( - self.train_handle, self.train_tf_dataset.output_types, self.train_tf_dataset.output_shapes) + self.train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes) self.inputs = self.iterator.get_next() - #since era5 tfrecords include T_start, we need to remove it from the tfrecord when we train the model, - # otherwise the model will raise error - + # since era5 tfrecords include T_start, we need to remove it from the tfrecord when we train SAVP + # Otherwise an error will be risen by SAVP if self.dataset == "era5" and self.model == "savp": - del self.inputs["T_start"] - + del self.inputs["T_start"] def save_dataset_model_params_to_checkpoint_dir(self, dataset, video_model): """ @@ -195,20 +198,18 @@ class TrainModel(object): f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4)) with open(os.path.join(self.output_dir, "model_hparams.json"), "w") as f: f.write(json.dumps(video_model.hparams.values(), sort_keys=True, indent=4)) - with open(os.path.join(self.output_dir, "data_dict.json"), "w") as f: - f.write(json.dumps(dataset.data_dict, sort_keys=True, indent=4)) - + #with open(os.path.join(self.output_dir, "data_dict.json"), "w") as f: + # f.write(json.dumps(dataset.data_dict, sort_keys=True, indent=4)) def count_parameters(self): """ Count the paramteres of the model """ with tf.name_scope("parameter_count"): - # exclude trainable variables that are replicas (used in multi-gpu setting) + # exclude trainable variables that are replicates (used in multi-gpu setting) self.trainable_variables = set(tf.trainable_variables()) & set(self.video_model.saveable_variables) self.parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in self.trainable_variables]) - def create_saver_and_writer(self): """ Create saver to save the models latest checkpoints, and a summery writer to store the train/val metrics @@ -223,45 +224,49 @@ class TrainModel(object): self.gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=self.gpu_mem_frac, allow_growth=True) self.config = tf.ConfigProto(gpu_options=self.gpu_options, allow_soft_placement=True) - def calculate_samples_and_epochs(self): """ Calculate the number of samples for train dataset, which is used for each epoch training Calculate the iterations (samples multiple by max_epochs) for training. """ + method = TrainModel.calculate_samples_and_epochs.__name__ + batch_size = self.video_model.hparams.batch_size - max_epochs = self.video_model.hparams.max_epochs #the number of epochs + max_epochs = self.video_model.hparams.max_epochs # the number of epochs self.num_examples = self.train_dataset.num_examples_per_epoch() self.steps_per_epoch = int(self.num_examples/batch_size) + self.diag_intv_step = int(self.diag_intv_frac*self.steps_per_epoch) self.total_steps = self.steps_per_epoch * max_epochs - print("Batch size is {} ; max_epochs is {}; num_samples per epoch is {}; steps_per_epoch is {}, total steps is {}".format(batch_size,max_epochs, self.num_examples,self.steps_per_epoch,self.total_steps)) + print("%{}: Batch size: {}; max_epochs: {}; num_samples per epoch: {}; steps_per_epoch: {}, total steps: {}" + .format(method, batch_size,max_epochs, self.num_examples,self.steps_per_epoch,self.total_steps)) def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None): """ Restore the models checkpoints if the checkpoints is given """ - + method = TrainModel.restore.__name__ + if checkpoints is None: - print ("Checkpoint is empty!!") - elif os.path.isdir(checkpoints) and (not os.path.exists(os.path.join(checkpoints,"checkpoint"))): - print("There is not checkpoints in the dir {}".format(checkpoints)) + print("%{0}: Checkpoint is None!".format(method)) + elif os.path.isdir(checkpoints) and (not os.path.exists(os.path.join(checkpoints, "checkpoint"))): + print("%{0}: There are no checkpoints in the dir {1}".format(method, checkpoints)) else: - var_list = self.video_model.saveable_variables - # possibly restore from multiple checkpoints. useful if subset of weights - # (e.g. generator or discriminator) are on different checkpoints. - if not isinstance(checkpoints, (list, tuple)): - checkpoints = [checkpoints] - # automatically skip global_step if more than one checkpoint is provided - skip_global_step = len(checkpoints) > 1 - savers = [] - for checkpoint in checkpoints: - print("creating restore saver from checkpoint %s" % checkpoint) - saver, _ = tf_utils.get_checkpoint_restore_saver( - checkpoint, var_list, skip_global_step=skip_global_step, - restore_to_checkpoint_mapping=restore_to_checkpoint_mapping) - savers.append(saver) - restore_op = [saver.saver_def.restore_op_name for saver in savers] - sess.run(restore_op) + var_list = self.video_model.saveable_variables + # possibly restore from multiple checkpoints. useful if subset of weights + # (e.g. generator or discriminator) are on different checkpoints. + if not isinstance(checkpoints, (list, tuple)): + checkpoints = [checkpoints] + # automatically skip global_step if more than one checkpoint is provided + skip_global_step = len(checkpoints) > 1 + savers = [] + for checkpoint in checkpoints: + print("%{0}: creating restore saver from checkpoint {1}".format(method, checkpoint)) + saver, _ = tf_utils.get_checkpoint_restore_saver(checkpoint, var_list, + skip_global_step=skip_global_step, + restore_to_checkpoint_mapping=restore_to_checkpoint_mapping) + savers.append(saver) + restore_op = [saver.saver_def.restore_op_name for saver in savers] + sess.run(restore_op) def restore_train_val_losses(self): """ @@ -269,161 +274,216 @@ class TrainModel(object): """ if self.checkpoint is None: train_losses, val_losses = [], [] - elif os.path.isdir(self.checkpoint) and (not os.path.exists(os.path.join(self.output_dir,"checkpoint"))): + elif os.path.isdir(self.checkpoint) and (not os.path.exists(os.path.join(self.output_dir, "checkpoint"))): train_losses,val_losses = [], [] else: - with open(os.path.join(self.output_dir,"train_losses.pkl"),"rb") as f: + with open(os.path.join(self.output_dir, "train_losses.pkl"), "rb") as f: train_losses = pkl.load(f) - with open(os.path.join(self.output_dir,"val_losses.pkl"),"rb") as f: + with open(os.path.join(self.output_dir, "val_losses.pkl"), "rb") as f: val_losses = pkl.load(f) return train_losses,val_losses def train_model(self): """ - Start session and train the model + Start session and train the model by looping over all iteration steps """ + method = TrainModel.train_model.__name__ + self.global_step = tf.train.get_or_create_global_step() with tf.Session(config=self.config) as sess: print("parameter_count =", sess.run(self.parameter_count)) sess.run(tf.global_variables_initializer()) sess.run(tf.local_variables_initializer()) self.restore(sess, self.checkpoint) - #sess.graph.finalize() - self.start_step = sess.run(self.global_step) - print("start_step", self.start_step) + start_step = sess.run(self.global_step) + print("%{0}: Iteration starts at step {1}".format(method, start_step)) # start at one step earlier to log everything without doing any training # step is relative to the start_step train_losses, val_losses = self.restore_train_val_losses() + # initialize auxiliary variables time_per_iteration = [] run_start_time = time.time() - for step in range(self.start_step,self.total_steps): + val_loss_min = 999. + # perform iteration + for step in range(start_step, self.total_steps): timeit_start = time.time() - #run for training dataset + # Run training data self.create_fetches_for_train() # In addition to the loss, we fetch the optimizer self.results = sess.run(self.fetches) # ...and run it here! - train_losses.append(self.results["total_loss"]) - #Run and fetch losses for validation data + # Note: For SAVP, the obtained loss is a list where the first element is of interest, for convLSTM, + # it's just a number. Thus, with list(<losses>)[0], we can handle both + train_losses.append(list(self.results[self.saver_loss])[0]) + # run and fetch losses for validation data val_handle_eval = sess.run(self.val_handle) self.create_fetches_for_val() - self.val_results = sess.run(self.val_fetches,feed_dict={self.train_handle: val_handle_eval}) - val_losses.append(self.val_results["total_loss"]) + self.val_results = sess.run(self.val_fetches, feed_dict={self.train_handle: val_handle_eval}) + val_losses.append(list(self.val_results[self.saver_loss])[0]) self.write_to_summary() - self.print_results(step,self.results) - timeit_end = time.time() - time_per_iteration.append(timeit_end - timeit_start) - print("time needed for this step", timeit_end - timeit_start, ' s') - if step % self.save_model_intv == 0: - self.saver.save(sess, os.path.join(self.output_dir, "model"), global_step=step) - if step % self.save_diag_intv == 0: - # I save the pickle file and plot here inside the loop in case the training process cannot finished after job is done. - TrainModel.save_results_to_pkl(train_losses,val_losses,self.output_dir) - TrainModel.plot_train(train_losses,val_losses,step,self.output_dir) - - #Totally train time over all the iterations + self.print_results(step, self.results) + # track iteration time + time_iter = time.time() - timeit_start + time_per_iteration.append(time_iter) + print("%{0}: time needed for this step {1:.3f}s".format(method, time_iter)) + if step > self.diag_intv_step and (step % self.diag_intv_step == 0 or step == self.total_steps - 1): + lsave, val_loss_min = TrainModel.set_model_saver_flag(val_losses, val_loss_min, self.diag_intv_step) + # save best and final model state + if lsave or step == self.total_steps - 1: + self.saver.save(sess, os.path.join(self.output_dir, "model_best" if lsave else "model_last"), + global_step=step) + # pickle file and plots are always created + TrainModel.save_results_to_pkl(train_losses, val_losses, self.output_dir) + TrainModel.plot_train(train_losses, val_losses, self.saver_loss_name, self.output_dir) + + # Final diagnostics + # track time (save to pickle-files) train_time = time.time() - run_start_time - results_dict = {"train_time":train_time, - "total_steps":self.total_steps} + results_dict = {"train_time": train_time, + "total_steps": self.total_steps} TrainModel.save_results_to_dict(results_dict,self.output_dir) - print("train_losses:",train_losses) - print("val_losses:",val_losses) - print("Done") - print("Total training time:", train_time/60., "min") + print("%{0}: Training loss decreased from {1:.6f} to {2:.6f}:" + .format(method, np.mean(train_losses[0:10]), np.mean(train_losses[-self.diag_intv_step:]))) + print("%{0}: Validation loss decreased from {1:.6f} to {2:.6f}:" + .format(method, np.mean(val_losses[0:10]), np.mean(val_losses[-self.diag_intv_step:]))) + print("%{0}: Training finsished".format(method)) + print("%{0}: Total training time: {1:.2f} min".format(method, train_time/60.)) + return train_time, time_per_iteration - def create_fetches_for_train(self): - """ - Fetch variables in the graph, this can be custermized based on models and also the needs of users - """ - #This is the base fetch that for all the models - self.fetches = {"train_op": self.video_model.train_op} # fetching the optimizer! - self.fetches["summary"] = self.video_model.summary_op - self.fetches["global_step"] = self.global_step - if self.video_model.__class__.__name__ == "McNetVideoPredictionModel": self.fetches_for_train_mcnet() - if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": self.fetches_for_train_convLSTM() - if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": self.fetches_for_train_savp() - if self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": self.fetches_for_train_vae() - if self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel":self.fetches_for_train_gan() - if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel":self.fetches_for_train_convLSTM() - return self.fetches - - def fetches_for_train_convLSTM(self): """ - Fetch variables in the graph for convLSTM model, this can be custermized based on models and the needs of users + Fetch variables in the graph, this can be custermized based on models and also the needs of users """ - self.fetches["total_loss"] = self.video_model.total_loss - self.fetches["inputs"] = self.video_model.inputs + # This is the basic fetch for all the models + fetch_list = ["train_op", "summary_op", "global_step"] - - def fetches_for_train_savp(self): + # Append fetches depending on model to be trained + if self.video_model.__class__.__name__ == "McNetVideoPredictionModel": + fetch_list = fetch_list + ["L_p", "L_gdl", "L_GAN"] + self.saver_loss = fetch_list[-3] # ML: Is this a reasonable choice? + self.saver_loss_name = "Loss" + if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": + fetch_list = fetch_list + ["inputs", "total_loss"] + self.saver_loss = fetch_list[-1] + self.saver_loss_name = "Total loss" + if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": + fetch_list = fetch_list + ["g_losses", "d_losses", "d_loss", "g_loss", ("g_losses", "gen_l1_loss")] + # Add loss that is tracked + self.saver_loss = fetch_list[-1][1] + self.saver_loss_dict = fetch_list[-1][0] + self.saver_loss_name = "Generator L1 loss" + if self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": + fetch_list = fetch_list + ["latent_loss", "recon_loss", "total_loss"] + self.saver_loss = fetch_list[-2] + self.saver_loss_name = "Reconstruction loss" + if self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel": + fetch_list = fetch_list + ["inputs", "total_loss"] + self.saver_loss = fetch_list[-1] + self.saver_loss_name = "Total loss" + if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel": + fetch_list = fetch_list + ["inputs", "total_loss"] + self.saver_loss = fetch_list[-1] + self.saver_loss_name = "Total loss" + + self.fetches = self.generate_fetches(fetch_list) + + return self.fetches + + def create_fetches_for_val(self): """ - Fetch variables in the graph for savp model, this can be custermized based on models and the needs of users + Fetch variables in the graph for validation dataset, customized depending on models and users' needs """ - self.fetches["g_losses"] = self.video_model.g_losses - self.fetches["d_losses"] = self.video_model.d_losses - self.fetches["d_loss"] = self.video_model.d_loss - self.fetches["g_loss"] = self.video_model.g_loss - self.fetches["total_loss"] = self.video_model.g_loss - self.fetches["inputs"] = self.video_model.inputs + method = TrainModel.create_fetches_for_val.__name__ + if not self.saver_loss: + raise AttributeError("%{0}: saver_loss is still not set. create_fetches_for_train must be run in advance." + .format(method)) + + if self.saver_loss_dict: + fetch_list = ["summary_op", (self.saver_loss_dict, self.saver_loss)] + else: + fetch_list= ["summary_op", self.saver_loss] - def fetches_for_train_mcnet(self): - """ - Fetch variables in the graph for mcnet model, this can be custermized based on models and the needs of users - """ - self.fetches["L_p"] = self.video_model.L_p - self.fetches["L_gdl"] = self.video_model.L_gdl - self.fetches["L_GAN"] = self.video_model.L_GAN + self.val_fetches = self.generate_fetches(fetch_list) - def fetches_for_train_vae(self): + return self.val_fetches + + def generate_fetches(self, fetch_list): """ - Fetch variables in the graph for savp model, this can be custermized based on models and based on the needs of users + Generates dictionary of fetches from video model instance + :param fetch_list: list of attributes of video model instance that are of particular interest; + can also handle tuples as list-elements to get attributes nested in a dictionary + :return: dictionary of fetches with keys from fetch_list and values from video model instance """ - self.fetches["latent_loss"] = self.video_model.latent_loss - self.fetches["recon_loss"] = self.video_model.recon_loss - self.fetches["total_loss"] = self.video_model.total_loss + method = TrainModel.generate_fetches.__name__ - def fetches_for_train_gan(self): - self.fetches["total_loss"] = self.video_model.total_loss + if not self.video_model: + raise AttributeError("%{0}: video_model is still not set. setup_model must be run in advance." + .format(method)) - def create_fetches_for_val(self): - """ - Fetch variables in the graph for validation dataset, this can be custermized based on models and the needs of users - """ - if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": - self.val_fetches = {"total_loss": self.video_model.g_loss} - self.val_fetches["inputs"] = self.video_model.inputs - else: - self.val_fetches = {"total_loss": self.video_model.total_loss} - self.val_fetches["inputs"] = self.video_model.inputs - self.val_fetches["summary"] = self.video_model.summary_op + fetches = {} + for fetch_req in fetch_list: + try: + if isinstance(fetch_req, tuple): + fetches[fetch_req[1]] = getattr(self.video_model, fetch_req[0])[fetch_req[1]] + else: + fetches[fetch_req] = getattr(self.video_model, fetch_req) + except Exception as err: + print("%{0}: Failed to retrieve {1} from video_model-attribute.".format(method, fetch_req)) + raise err + + return fetches def write_to_summary(self): - self.summary_writer.add_summary(self.results["summary"],self.results["global_step"]) - self.summary_writer.add_summary(self.val_results["summary"],self.results["global_step"]) + self.summary_writer.add_summary(self.results["summary_op"], self.results["global_step"]) + self.summary_writer.add_summary(self.val_results["summary_op"], self.results["global_step"]) self.summary_writer.flush() - - def print_results(self,step,results): + def print_results(self, step, results): """ Print the training results /validation results from the training step. """ + method = TrainModel.print_results.__name__ + train_epoch = step/self.steps_per_epoch - print("progress global step %d epoch %0.1f" % (step + 1, train_epoch)) + print("%{0}: Progress global step {1:d} epoch {2:.1f}".format(method, step + 1, train_epoch)) if self.video_model.__class__.__name__ == "McNetVideoPredictionModel": - print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"],results["L_p"],results["L_gdl"],results["L_GAN"])) + print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"], results["L_p"], + results["L_gdl"],results["L_GAN"])) elif self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": print ("Total_loss:{}".format(results["total_loss"])) elif self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": - print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}".format(results["g_losses"],results["d_losses"],results["g_loss"],results["d_loss"])) + print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}, gen_l1_loss: {}" + .format(results["g_losses"],results["d_losses"],results["g_loss"],results["d_loss"],results["gen_l1_loss"])) elif self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": print("Total_loss:{}; latent_losses:{}; reconst_loss:{}".format(results["total_loss"],results["latent_loss"],results["recon_loss"])) else: - print ("The model name does not exist") + print("%{0}: Printing results of the model {1} is not implemented yet".format(method, self.video_model.__class__.__name__)) - @staticmethod - def plot_train(train_losses,val_losses,step,output_dir): + def set_model_saver_flag(losses: List, old_min_loss: float, niter_steps: int = 100): + """ + Sets flag to save the model given that a new minimum in the loss is readched + :param losses: list of losses over iteration steps + :param old_min_loss: previous loss + :param niter_steps: number of iteration steps over which the loss is averaged + :return flag: True if model should be saved + :return loss_avg: updated minimum loss + """ + save_flag = False + if len(losses) <= niter_steps*2: + loss_avg = old_min_loss + return save_flag, loss_avg + + loss_avg = np.mean(losses[-niter_steps:]) + if loss_avg < old_min_loss: + save_flag = True + else: + loss_avg = old_min_loss + + return save_flag, loss_avg + + @staticmethod + def plot_train(train_losses, val_losses, loss_name, output_dir): """ Function to plot training losses for train and val datasets against steps params: @@ -441,7 +501,7 @@ class TrainModel(object): plt.yscale("log") plt.title('Training and Validation loss') plt.xlabel('Iterations') - plt.ylabel('Loss') + plt.ylabel(loss_name) plt.legend() plt.savefig(os.path.join(output_dir,'plot_train.png')) plt.close() @@ -477,19 +537,17 @@ class TrainModel(object): def main(): parser = argparse.ArgumentParser() - parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories " - "train, val, test, etc, or a directory containing " - "the tfrecords") - parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. " - "default is logs_dir/model_fname, where model_fname consists of " - "information from model and model_hparams") - parser.add_argument("--datasplit_dict", help="json file that contains the datasplit configuration") - parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)") - parser.add_argument("--dataset", type=str, help="dataset class name") - parser.add_argument("--model", type=str, help="model class name") - parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters") - parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="fraction of gpu memory to use") - parser.add_argument("--seed",default=1234, type=int) + parser.add_argument("--input_dir", type=str, required=True, + help="Directory where input data as TFRecord-files are stored.") + parser.add_argument("--output_dir", help="Output directory where JSON-files, summary, model, plots etc. are saved.") + parser.add_argument("--datasplit_dict", help="JSON-file that contains the datasplit configuration") + parser.add_argument("--checkpoint", help="Checkpoint directory or checkpoint name (e.g. <my_dir>/model-200000)") + parser.add_argument("--dataset", type=str, help="Dataset class name") + parser.add_argument("--model", type=str, help="Model class name") + parser.add_argument("--model_hparams_dict", type=str, help="JSON-file of model hyperparameters") + parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="Fraction of gpu memory to use") + parser.add_argument("--seed", default=1234, type=int) + args = parser.parse_args() # start timing for the whole run timeit_start_total_time = time.time() diff --git a/video_prediction_tools/model_modules/video_prediction/models/base_model.py b/video_prediction_tools/model_modules/video_prediction/models/base_model.py index 508f766d428d402c5119afbfd811f1eb73cbb5f7..41f4bb3d1eb59c7e51624c03bfb8a582cb38b1e2 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/base_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/base_model.py @@ -69,6 +69,8 @@ class BaseVideoPredictionModel(object): self.accum_eval_metrics = None self.saveable_variables = None self.post_init_ops = None + # ML 2021-06-23: Do not hide global step in self.saveable_variables + self.global_step = None def get_default_hparams_dict(self): """ @@ -471,6 +473,8 @@ class VideoPredictionModel(BaseVideoPredictionModel): BaseVideoPredictionModel.build_graph(self, inputs) global_step = tf.train.get_or_create_global_step() + # ML 2021-06-23: Do not hide global step in self.saveable_variables + self.global_step = global_step # Capture the variables created from here until the train_op for the # saveable_variables. Note that if variables are being reused (e.g. # they were created by a previously built model), those variables won't diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py index 0c079c4c087d919490b1ade80f0bd73368f008c7..5c949c96a4f849d82143e80878bb59d8f6661965 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py @@ -15,7 +15,6 @@ class VanillaConvLstmVideoPredictionModel(object): """ This is class for building convLSTM architecture by using updated hparameters args: - mode :str, "train" or "val", side note: mode may not be used in the convLSTM, but this will be a useful argument for the GAN-based model hparams_dict : dict, the dictionary contains the hparaemters names and values """ self.hparams_dict = hparams_dict diff --git a/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py b/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py index 7a1da880defb61dbd018c6f11ee14c34cf0ce43e..415275e8f909560cab725ecae38bb37a01809244 100644 --- a/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py +++ b/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py @@ -526,10 +526,14 @@ def reduce_tensors(structures, shallow=False): def get_checkpoint_restore_saver(checkpoint, var_list=None, skip_global_step=False, restore_to_checkpoint_mapping=None): + method = get_checkpoint_restore_saver.__name__ if os.path.isdir(checkpoint): # latest_checkpoint doesn't work when the path has special characters checkpoint = tf.train.latest_checkpoint(checkpoint) + # print name of checkpoint-file for verbosity + print("%{0}: The follwoing checkpoint is used for restoring the model: '{1}'".format(method, checkpoint)) + # Start processing the checkpoint checkpoint_reader = tf.pywrap_tensorflow.NewCheckpointReader(checkpoint) checkpoint_var_names = checkpoint_reader.get_variable_to_shape_map().keys() restore_to_checkpoint_mapping = restore_to_checkpoint_mapping or (lambda name, _: name.split(':')[0]) diff --git a/video_prediction_tools/utils/general_utils.py b/video_prediction_tools/utils/general_utils.py index 65a25c43e803abc4240edd2e6c4184420a5b0722..83ed526bac9c960229a30585be61606cbf527317 100644 --- a/video_prediction_tools/utils/general_utils.py +++ b/video_prediction_tools/utils/general_utils.py @@ -11,12 +11,15 @@ Provides: * get_unique_vars """ # import modules +from typing import List, Union import os import numpy as np -#import xarray as xr +str_or_List = Union[List, str] # routines -def get_unique_vars(varnames): + + +def get_unique_vars(varnames: List[str]): """ :param varnames: list of variable names (or any other list of strings) :return: list with unique elements of inputted varnames list @@ -27,7 +30,7 @@ def get_unique_vars(varnames): return vars_uni, varsind, nvars_uni -def add_str_to_path(path_in, add_str): +def add_str_to_path(path_in: str, add_str: str): """ :param path_in: input path which is a candidate to be extended by add_str (see below) :param add_str: String to be added to path_in if not already done @@ -92,12 +95,15 @@ def isw(value, interval): return status -def check_str_in_list(list_in, str2check, labort=True): +def check_str_in_list(list_in: List, str2check: str_or_List, labort: bool = True, return_ind: bool = False): """ Checks if all strings are found in list :param list_in: input list :param str2check: string or list of strings to be checked if they are part of list_in - :return: True if existence of all strings was confirmed + :param labort: Flag if error will be risen in case of missing string in list + :param return_ind: Flag if index for each string found in list will be returned + :return: True if existence of all strings was confirmed, if return_ind is True, the index of each string in list is + returned as well """ method = check_str_in_list.__name__ @@ -112,20 +118,25 @@ def check_str_in_list(list_in, str2check, labort=True): stat_element = [True if str1 in list_in else False for str1 in str2check] - if not np.all(stat_element): + if np.all(stat_element): + stat = True + else: print("%{0}: The following elements are not part of the input list:".format(method)) - inds_miss = np.where(stat_element)[0] + inds_miss = np.where(list(~np.array(stat_element)))[0] for i in inds_miss: print("* index {0:d}: {1}".format(i, str2check[i])) if labort: raise ValueError("%{0}: Could not find all expected strings in list.".format(method)) + # return + if stat and not return_ind: + return stat + elif stat: + return stat, [list_in.index(str_curr) for str_curr in str2check] else: - stat = True - - return stat + return stat, [] -def check_dir(path2dir: str, lcreate=False): +def check_dir(path2dir: str, lcreate: bool = False): """ Checks if path2dir exists and create it if desired :param path2dir: @@ -174,7 +185,31 @@ def reduce_dict(dict_in: dict, dict_ref: dict): return dict_reduced -def provide_default(dict_in, keyname, default=None, required=False): +def find_key(dict_in: dict, key: str): + """ + Searchs through nested dictionaries for key. + :param dict_in: input dictionary (cas also be an OrderedDictionary) + :param key: key to be retrieved + :return: value of the key in dict_in + """ + method = find_key.__name__ + # sanity check + if not isinstance(dict_in, dict): + raise TypeError("%{0}: dict_in must be a dictionary instance, but is of type '{1}'" + .format(method, type(dict_in))) + # check for key + if key in dict_in: + return dict_in[key] + for k, v in dict_in.items(): + if isinstance(v,dict): + item = find_key(v, key) + if item is not None: + return item + + raise ValueError("%{0}: {1} could not be found in dict_in".format(method, key)) + + +def provide_default(dict_in: dict, keyname: str, default=None, required: bool = False): """ Returns values of key from input dictionary or alternatively its default diff --git a/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py b/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py index d420ed1be8460ade5fae3302960de8b761f49f72..025a203b5bd7bc038239a95afea003938b811734 100755 --- a/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py +++ b/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py @@ -280,11 +280,11 @@ class Config_Preprocess1(Config_runscript_base): check_vars = [var.strip().lower() in known_vars for var in vars_list] status = all(check_vars) if not status: - inds_bad = [i for i, e in enumerate(check_vars) if e] # np.where(~np.array(check_vars))[0] + inds_bad = [i for i, e in enumerate(check_vars) if not e] # np.where(~np.array(check_vars))[0] if not silent: print("%{0}: The following comma-separated elements are unknown variables:".format(method)) for ind in inds_bad: - print(vars_list[ind]) + print("* {0}".format(vars_list[ind])) return status if not len(check_vars) >= 1: