diff --git a/video_prediction_tools/data_extraction/extract_weatherbench.py b/video_prediction_tools/data_extraction/extract_weatherbench.py index a5798c188d62b69d93ef22efa0253a80bf2b031a..ca126a8ccd301c9adb00f0cc54f71e3913bc6a0a 100644 --- a/video_prediction_tools/data_extraction/extract_weatherbench.py +++ b/video_prediction_tools/data_extraction/extract_weatherbench.py @@ -1,16 +1,12 @@ -import os, glob -import logging +import logging from zipfile import ZipFile from typing import Union from pathlib import Path import multiprocessing as mp import itertools as it -import sys - import pandas as pd import xarray as xr - from utils.dataset_utils import get_filename_template logging.basicConfig(level=logging.DEBUG) diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 7dad59d4fc70f0abb82dfbe6fa3be466be8c1f03..0ec4095e831b558e0c4bc7f8abf6432f948115d4 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -12,16 +12,16 @@ __email__ = "b.gong@fz-juelich.de" __author__ = "Bing Gong, Michael Langguth" __date__ = "2020-10-22" -import os, glob +import os, glob, sys import argparse import errno import json -from typing import Union, List import random import time import numpy as np import xarray as xr import tensorflow as tf +sys.path.append("../") from model_modules.video_prediction import models from model_modules.video_prediction.datasets import get_dataset import matplotlib.pyplot as plt @@ -31,27 +31,39 @@ from general_utils import * import math import shutil from pathlib import Path +from tensorflow.python.keras.utils.layer_utils import count_params +tf.config.experimental_run_functions_eagerly(True) + class TrainModel(object): - def __init__(self, input_dir: str = None, output_dir: str = None, datasplit_dict: str = None, - model_hparams_dict: str = None, model: str = None, checkpoint: str = None, dataset_name: str = None, - gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.001, - frac_start_save: float = None, frac_intv_save: float = None): + def __init__(self, input_dir: str = None, + output_dir: str = None, + datasplit_dict: str = None, + model_hparams_dict: str = None, + model: str = None, + checkpoint: str = None, + dataset_name: str = None, + gpu_mem_frac: float = 1., + seed: int = None, + args=None, + diag_intv_frac: float = 0.001, + frac_start_save: float = None, + frac_intv_save: float = None): """ Class instance for training the models - :param input_dir: parent directory under which "pickle" and "tfrecords" files directiory are located - :param output_dir: directory where all the output is saved (e.g. model, JSON-files, training curves etc.) 
-        :param datasplit_dict: JSON-file for defining data splitting
-        :param model_hparams_dict: JSON-file of model hyperparameters
-        :param model: model class name
-        :param checkpoint: checkpoint directory (pre-trained models)
-        :param dataset: dataset class name
-        :param gpu_mem_frac: fraction of GPU memory to be preallocated
-        :param seed: seed of the randomizers
-        :param args: list of arguments passed
-        :param diag_intv_frac: interval for diagnozing the model (create loss-curves and save pickle-file with losses)
-        :param frac_start_save: fraction of total iterations steps to start checkpointing the model
-        :param frac_intv_save: fraction of total iterations steps for checkpointing the model
+        :param input_dir : parent directory under which "pickle" and "tfrecords" files directory are located
+        :param output_dir : directory where all the output is saved (e.g. model, JSON-files, training curves etc.)
+        :param datasplit_dict : JSON-file for defining data splitting
+        :param model_hparams_dict : JSON-file of model hyperparameters
+        :param model : model class name
+        :param checkpoint : checkpoint directory (pre-trained models)
+        :param dataset_name : dataset class name
+        :param gpu_mem_frac : fraction of GPU memory to be preallocated
+        :param seed : seed of the randomizers
+        :param args : list of arguments passed
+        :param diag_intv_frac : interval for diagnosing the model (create loss-curves and save pickle-file with losses)
+        :param frac_start_save : fraction of total iteration steps to start checkpointing the model
+        :param frac_intv_save : fraction of total iteration steps for checkpointing the model
         """
         self.input_dir = Path(input_dir).resolve(strict=False)
         self.output_dir = Path(output_dir).resolve(strict=False)
@@ -66,7 +78,7 @@ class TrainModel(object):
         self.diag_intv_frac = diag_intv_frac
         self.frac_start_save = frac_start_save
         self.frac_intv_save = frac_intv_save
-        # for diagnozing and saving the model during training
+
         self.saver_loss = None  # set in create_fetches_for_train-method
         self.saver_loss_name = None  # set in create_fetches_for_train-method
         self.saver_loss_dict = None  # set in create_fetches_for_train-method if loss of interest is nested
@@ -77,21 +89,20 @@ class TrainModel(object):
         self.get_model_hparams_dict()
         self.load_params_from_checkpoints_dir()
         self.setup_datasets()
-        self.make_dataset_iterator()
         self.setup_model()
-        self.setup_graph()
-        self.save_dataset_model_params_to_checkpoint_dir(dataset=self.dataset, video_model=self.video_model) # TODO: resolve potetial incompatibility
-        self.count_parameters()
-        self.create_saver_and_writer()
-        self.setup_gpu_config()
-        self.calculate_checkpoint_saver_conf()
+        self.save_dataset_model_params_to_checkpoint_dir(dataset=self.dataset,
+                                                         video_model=self.video_model)
+        #self.count_parameters()
+        # self.create_saver_and_writer()
+        # self.setup_gpu_config()
+        # self.calculate_checkpoint_saver_conf()

     def set_seed(self):
         """
         Set seed to control the same train/val/testing dataset for the same seed
         """
         if self.seed is not None:
-            tf.set_random_seed(self.seed)
+            tf.random.set_seed(self.seed)
             np.random.seed(self.seed)
             random.seed(self.seed)

@@ -113,10 +124,10 @@ class TrainModel(object):
         """
         if self.model_hparams_dict:
             with open(self.model_hparams_dict, 'r') as f:
-                print("self.model_hparams_dict",self.model_hparams_dict)
                 self.model_hparams_dict_load = json.loads(f.read())
         else:
-            raise FileNotFoundError("hparam directory doesn't exist! please check {}!".format(self.model_hparams_dict))
+            raise FileNotFoundError("hparam directory doesn't exist! 
" + "please check {}!".format(self.model_hparams_dict)) return self.model_hparams_dict_load @@ -124,7 +135,8 @@ class TrainModel(object): """ If checkpoint is none, load and read the json files of datasplit_config, and hparam_config, and use the corresponding parameters. - If the checkpoint is given, the configuration of dataset, model and options in the checkpoint dir will be + If the checkpoint is given, the configuration of dataset, + model and options in the checkpoint dir will be restored and used for continue training. """ method = TrainModel.load_params_from_checkpoints_dir.__name__ @@ -144,7 +156,7 @@ class TrainModel(object): self.model = self.model or self.options['model'] except FileNotFoundError: print("%{0}: options.json does not exist in {1}".format(method, self.checkpoint_dir)) - # loading hyperparameters from checkpoint + try: with open(os.path.join(self.checkpoint_dir, "model_hparams.json")) as f: self.model_hparams_dict_load.update(json.loads(f.read())) @@ -154,15 +166,21 @@ class TrainModel(object): def setup_datasets(self): """ Setup train and val dataset instance with the corresponding data split configuration. - Simultaneously, sequence_length is attached to the hyperparameter dictionary. + Simultaneously, sequence_length is attached to the hyper-parameter dictionary. """ - # get some parameters from the model hyperparameters + # get some parameters from the model hyper-parameters self.batch_size = self.model_hparams_dict_load["batch_size"] self.max_epochs = self.model_hparams_dict_load["max_epochs"] # create dataset instance - - self.dataset = get_dataset(self.dataset_name, input_dir=self.input_dir, output_dir=self.output_dir, datasplit_path=self.datasplit_dict, hparams_path=self.model_hparams_dict, seed=self.seed) - + self.dataset = get_dataset(self.dataset_name, + input_dir=self.input_dir, + output_dir=self.output_dir, + datasplit_path=self.datasplit_dict, + hparams_path=self.model_hparams_dict, + seed=self.seed) + self.train_dataset = self.dataset.make_dataset(mode="train") + self.val_dataset = self.dataset.make_dataset(mode="val") + self.test_dataset = self.dataset.make_dataset(mode="test") self.calculate_samples_and_epochs() self.model_hparams_dict_load.update({"sequence_length": self.dataset.sequence_length}) @@ -172,34 +190,11 @@ class TrainModel(object): :param mode: "train" used the model graph in train process; "test" for postprocessing step """ VideoPredictionModel = models.get_model_class(self.model) - self.video_model = VideoPredictionModel(hparams_dict_config=self.model_hparams_dict, mode=mode) + self.video_model = VideoPredictionModel(hparams_dict_config= + self.model_hparams_dict, + mode=mode) + self.video_model.compile([24, 10, 21, 2]) - def setup_graph(self): - """ - build model graph - """ - self.video_model.build_graph(self.inputs) - - def make_dataset_iterator(self): - """ - Prepare the dataset interator for training and validation - """ - self.batch_size = self.model_hparams_dict_load["batch_size"] - train_tf_dataset = self.dataset.make_training() - train_iterator = train_tf_dataset.make_one_shot_iterator() - # The `Iterator.string_handle()` method returns a tensor that can be evaluated - # and used to feed the `handle` placeholder. 
-        self.train_handle = train_iterator.string_handle()
-        val_tf_dataset = self.dataset.make_validation()
-        val_iterator = val_tf_dataset.make_one_shot_iterator()
-        self.val_handle = val_iterator.string_handle()
-        self.iterator = tf.data.Iterator.from_string_handle(
-            self.train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
-        self.inputs = self.iterator.get_next()
-        # since era5 tfrecords include T_start, we need to remove it from the tfrecord when we train SAVP
-        # Otherwise an error will be risen by SAVP
-        if self.dataset_name == "era5" and self.model == "savp":
-            del self.inputs["T_start"]

     def save_dataset_model_params_to_checkpoint_dir(self, dataset, video_model):
         """
@@ -210,40 +205,40 @@ class TrainModel(object):
         with open(os.path.join(self.output_dir, "dataset_hparams.json"), "w") as f:
             f.write(json.dumps(dataset.hparams, sort_keys=True, indent=4))
         with open(os.path.join(self.output_dir, "model_hparams.json"), "w") as f:
-            print("video_model.get_hparams",video_model.get_hparams)
             f.write(json.dumps(video_model.get_hparams, sort_keys=True, indent=4))
-        #with open(os.path.join(self.output_dir, "data_dict.json"), "w") as f:
-        #    f.write(json.dumps(dataset.data_dict, sort_keys=True, indent=4))
+
     def count_parameters(self):
         """
         Count the parameters of the model
-        """
-        with tf.name_scope("parameter_count"):
-            # exclude trainable variables that are replicates (used in multi-gpu setting)
-            self.trainable_variables = set(tf.trainable_variables()) & set(self.video_model.saveable_variables)
-            self.parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in self.trainable_variables])
+        """
+        self.trainable_variables = count_params(self.video_model.model.trainable_weights)
+        self.non_trainable_count = count_params(self.video_model.model.non_trainable_weights)
+        self.parameter_count = self.trainable_variables + self.non_trainable_count

     def create_saver_and_writer(self):
         """
         Create a saver to save the model's latest checkpoints, and a summary writer to store the train/val metrics
         """
-        self.saver = tf.train.Saver(var_list=self.video_model.saveable_variables, max_to_keep=None)
+        self.saver = tf.train.Saver(var_list=self.video_model.saveable_variables,
+                                    max_to_keep=None)
         self.summary_writer = tf.summary.FileWriter(self.output_dir)

     def setup_gpu_config(self):
         """
         Setup GPU options
         """
-        self.gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=self.gpu_mem_frac, allow_growth=True)
-        self.config = tf.ConfigProto(gpu_options=self.gpu_options, allow_soft_placement=True)
+        self.gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=self.gpu_mem_frac,
+                                         allow_growth=True)
+        self.config = tf.ConfigProto(gpu_options=self.gpu_options,
+                                     allow_soft_placement=True)
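# [Editor's note] Minimal sketch of the step arithmetic implemented by
# calculate_samples_and_epochs / calculate_checkpoint_saver_conf below; all
# names are illustrative, not part of the repository:
import math

def checkpoint_schedule(num_examples, batch_size, max_epochs, frac_start_save, frac_intv_save):
    steps_per_epoch = int(num_examples / batch_size)
    total_steps = steps_per_epoch * max_epochs
    chp_start_step = int(math.ceil(total_steps * frac_start_save))   # first checkpointed step
    chp_intv_step = int(math.ceil(total_steps * frac_intv_save))     # checkpoint interval afterwards
    return total_steps, chp_start_step, chp_intv_step

# e.g. checkpoint_schedule(10000, 16, 20, 0.5, 0.1) -> (12500, 6250, 1250)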
""" - method = TrainModel.calculate_samples_and_epochs.__name__ + method = TrainModel.calculate_samples_and_epochs.__name__ self.num_examples = self.dataset.num_training_samples self.steps_per_epoch = int(self.num_examples/self.batch_size) @@ -255,11 +250,14 @@ class TrainModel(object): else: pass print("%{}: Batch size: {}; max_epochs: {}; num_samples per epoch: {}; steps_per_epoch: {}, total steps: {}" - .format(method, self.batch_size, self.max_epochs, self.num_examples, self.steps_per_epoch, + .format(method, + self.batch_size, + self.max_epochs, + self.num_examples, + self.steps_per_epoch, self.total_steps)) - def calculate_checkpoint_saver_conf(self): """ Calculate the start step for saving the checkpoint, and the frequences steps to save model @@ -268,17 +266,20 @@ class TrainModel(object): method = TrainModel.calculate_checkpoint_saver_conf.__name__ if not hasattr(self, "total_steps"): - raise RuntimeError("%{0} self.total_steps is still unset. Run calculate_samples_and_epochs beforehand" + raise RuntimeError("%{0} self.total_steps is still unset. " + "Run calculate_samples_and_epochs beforehand" .format(method)) - if self.frac_intv_save > 1 or self.frac_intv_save<0 : - raise ValueError("%{0}: frac_intv_save must be less than 1 and larger than 0".format(method)) + if self.frac_intv_save > 1 or self.frac_intv_save<0: + raise ValueError("%{0}: frac_intv_save must be less than 1 " + "and larger than 0".format(method)) if self.frac_start_save > 1 or self.frac_start_save < 0: - raise ValueError("%{0}: frac_start_save must be less than 1 and larger than 0".format(method)) + raise ValueError("%{0}: frac_start_save must be less than 1 " + "and larger than 0".format(method)) self.chp_start_step = int(math.ceil(self.total_steps * self.frac_start_save)) self.chp_intv_step = int(math.ceil(self.total_steps * self.frac_intv_save)) print("%{0}: Model will be saved after step {1:d} at each {2:d} interval step " - .format(method, self.chp_start_step,self.chp_intv_step)) + .format(method, self.chp_start_step, self.chp_intv_step)) def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None): """ @@ -292,11 +293,10 @@ class TrainModel(object): print("%{0}: There are no checkpoints in the dir {1}".format(method, checkpoints)) else: var_list = self.video_model.saveable_variables - # possibly restore from multiple checkpoints. useful if subset of weights - # (e.g. generator or discriminator) are on different checkpoints. 
+ if not isinstance(checkpoints, (list, tuple)): checkpoints = [checkpoints] - # automatically skip global_step if more than one checkpoint is provided + skip_global_step = len(checkpoints) > 1 savers = [] for checkpoint in checkpoints: @@ -314,14 +314,15 @@ class TrainModel(object): """ if self.checkpoint is None: train_losses, val_losses = [], [] - elif os.path.isdir(self.checkpoint) and (not os.path.exists(os.path.join(self.output_dir, "checkpoint"))): - train_losses,val_losses = [], [] + elif os.path.isdir(self.checkpoint) and (not os.path.exists(os.path.join(self.output_dir, + "checkpoint"))): + train_losses, val_losses = [], [] else: with open(os.path.join(self.output_dir, "train_losses.pkl"), "rb") as f: train_losses = pkl.load(f) with open(os.path.join(self.output_dir, "val_losses.pkl"), "rb") as f: val_losses = pkl.load(f) - return train_losses,val_losses + return train_losses, val_losses def create_checkpoints_folder(self, step:int=None): """ @@ -337,96 +338,17 @@ class TrainModel(object): """ Start session and train the model by looping over all iteration steps """ - method = TrainModel.train_model.__name__ - - self.global_step = tf.train.get_or_create_global_step() - with tf.Session(config=self.config) as sess: - sess.run(tf.global_variables_initializer()) - sess.run(tf.local_variables_initializer()) - self.restore(sess, self.checkpoint) - start_step = sess.run(self.global_step) - print("%{0}: Iteration starts at step {1}".format(method, start_step)) - # start at one step earlier to log everything without doing any training - # step is relative to the start_step - train_losses, val_losses = self.restore_train_val_losses() - # initialize auxiliary variables - time_per_iteration = [] - run_start_time = time.time() - # perform iteration - for step in range(start_step, self.total_steps): - timeit_start = time.time() - # Run training data - self.create_fetches_for_train() # In addition to the loss, we fetch the optimizer - self.results = sess.run(self.fetches) # ...and run it here! - # Note: For SAVP, the obtained loss is a list where the first element is of interest, for convLSTM, - # it's just a number. 
Thus, with ensure_list(<losses>)[0], we can handle both
-                train_losses.append(ensure_list(self.results[self.saver_loss])[0])
-                # run and fetch losses for validation data
-                val_handle_eval = sess.run(self.val_handle)
-                self.create_fetches_for_val()
-                self.val_results = sess.run(self.val_fetches, feed_dict={self.train_handle: val_handle_eval})
-                val_losses.append(ensure_list(self.val_results[self.saver_loss])[0])
-                self.write_to_summary()
-                self.print_results(step, self.results)
-                # track iteration time
-                time_iter = time.time() - timeit_start
-                time_per_iteration.append(time_iter)
-                print("%{0}: time needed for this step {1:.3f}s".format(method, time_iter))
-                if (step >= self.chp_start_step and (step-self.chp_start_step)%self.chp_intv_step == 0) or \
-                   step == self.total_steps - 1:
-                    #create a checkpoint folder for step
-                    full_dir_name = self.create_checkpoints_folder(step=step)
-                    self.saver.save(sess, os.path.join(full_dir_name, "model_"), global_step=step)
-
-                # pickle file and plots are always created
-                if step % self.diag_intv_step == 0 or step == self.total_steps - 1:
-                    TrainModel.save_results_to_pkl(train_losses, val_losses, self.output_dir)
-                    TrainModel.plot_train(train_losses, val_losses, self.saver_loss_name, self.output_dir)
-
-            # Final diagnostics: training track time and save to pickle-files)
-            train_time = time.time() - run_start_time
-
-            avg_time_first_epoch = np.mean(time_per_iteration[:self.steps_per_epoch])
-            avg_time_non_first_epoch = np.mean(time_per_iteration[self.steps_per_epoch:])
-            results_dict = {"train_time": train_time, "total_steps": self.total_steps,
-                            "avg_time_first_epoch": avg_time_first_epoch,
-                            "avg_time_non_first_epoch":avg_time_non_first_epoch}
-
-            TrainModel.save_results_to_dict(results_dict, self.output_dir)
-
-            print("%{0}: Training loss decreased from {1:.6f} to {2:.6f}:"
-                  .format(method, np.mean(train_losses[0:10]), np.mean(train_losses[-self.diag_intv_step:])))
-            print("%{0}: Validation loss decreased from {1:.6f} to {2:.6f}:"
-                  .format(method, np.mean(val_losses[0:10]), np.mean(val_losses[-self.diag_intv_step:])))
-            print("%{0}: Training finished".format(method))
-            print("%{0}: Total training time: {1:.2f} min".format(method, train_time/60.))
-            print("%{0}: The average of training time for the first epoch: {1:.2f} sec".format(method, avg_time_first_epoch))
-            print("%{0}: The average of training time for after first epoch: {1:.2f} sec".format(method,avg_time_non_first_epoch))
-            return train_time, time_per_iteration
-
-    def create_fetches_for_train(self):
-        """
-        Fetch variables in the graph, this can be custermized based on models and also the needs of users
-        """
-        # This is the basic fetch for all the models
-        fetch_list = ["train_op", "summary_op", "global_step","total_loss"]
+        global_step = 1
+        # initialize auxiliary variables
+        run_start_time = time.time()
+        for epoch in range(self.max_epochs):
+            for step, inputs in enumerate(self.train_dataset):
+                self.video_model.train_step(inputs, global_step)
+                print("loss", self.video_model.loss_value)
+                global_step += 1
+
+            timeit_start = time.time()
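# [Editor's note] Minimal sketch of the eager training loop the "+" block above
# aims at (model/train_dataset are placeholders, not the repository API; note
# that the counter must be incremented with "+=" -- the original
# "global_step = +1" would pin it to 1):
def eager_train(model, train_dataset, max_epochs):
    losses, global_step = [], 1
    for epoch in range(max_epochs):
        for inputs in train_dataset:
            model.train_step(inputs, global_step)     # updates weights, sets model.loss_value
            losses.append(float(model.loss_value))
            global_step += 1
    return losses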
-        # Append fetches depending on model to be trained
-        if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel":
-            self.saver_loss = fetch_list[-1]
-            self.saver_loss_name = "Loss"
-        if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
-            fetch_list = fetch_list + ["inputs"]
-            self.saver_loss = fetch_list[-1]
-            self.saver_loss_name = "Total loss"
-        if self.video_model.__class__.__name__ == "WeatherBenchModel":
-            fetch_list = fetch_list + ["total_loss"]
-            self.saver_loss = fetch_list[-1]
-            self.saver_loss_name = "Total loss"
-        else:
-            self.saver_loss = "total_loss"
-        self.fetches = self.generate_fetches(fetch_list)
-        return self.fetches

     def create_fetches_for_val(self):
         """
@@ -492,6 +414,7 @@ class TrainModel(object):
         else:
             print("Total_loss:{}"
                   .format(results["total_loss"]))
+
     @staticmethod
     def plot_train(train_losses, val_losses, loss_name, output_dir):
         """
@@ -522,10 +445,10 @@ class TrainModel(object):
             json.dump(results_dict,fp)

     @staticmethod
-    def save_results_to_pkl(train_losses,val_losses, output_dir):
+    def save_results_to_pkl(train_losses, val_losses, output_dir):
         with open(os.path.join(output_dir,"train_losses.pkl"),"wb") as f:
             pkl.dump(train_losses,f)
-        with open(os.path.join(output_dir,"val_losses.pkl"),"wb") as f:
+        with open(os.path.join(output_dir, "val_losses.pkl"),"wb") as f:
             pkl.dump(val_losses,f)

     @staticmethod
@@ -587,16 +510,22 @@ class BestModelSelector(object):
         :return: Populated self.checkpoints_eval_all where the average of the metric over all forecast hours is listed
         """
-        method = BestModelSelector.run.__name__
         from main_visualize_postprocess import Postprocess

         metric_avg_all = []
         for checkpoint in self.checkpoints_all:
             print("Start to evaluate checkpoint:", checkpoint)
             results_dir_eager = os.path.join(checkpoint, "results_eager")
-            eager_eval = Postprocess(results_dir=results_dir_eager, checkpoint=checkpoint, data_mode="val", batch_size=32,
-                                     seed=self.seed, eval_metrics=[eval_metric], channel=self.channel, frac_data=0.33,
-                                     lquick=True, ltest=self.ltest)
+            eager_eval = Postprocess(results_dir=results_dir_eager,
+                                     checkpoint=checkpoint,
+                                     data_mode="val",
+                                     batch_size=32,
+                                     seed=self.seed,
+                                     eval_metrics=[eval_metric],
+                                     channel=self.channel,
+                                     frac_data=0.33,
+                                     lquick=True,
+                                     ltest=self.ltest)
             eager_eval.run()
             eager_eval.handle_eval_metrics()
@@ -680,7 +609,6 @@ class BestModelSelector(object):
         if ncheckpoints == 0:
             raise FileExistsError("{0}: No checkpoint folders found under '{1}'".format(method, model_dir))
         else:
-            # glob.glob yiels unsorted directories, i.e. do the soring now
             checkpoints_all = sorted(checkpoints_all, key=lambda x: int(x.split("_")[-1].replace("/","")))
             print("%{0}: {1:d} checkpoints directories have been found.".format(method, ncheckpoints))
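# [Editor's note] What BestModelSelector.run boils down to: evaluate every
# checkpoint on the validation split and keep the one with the smallest metric
# averaged over all forecast hours. Sketch with an assumed evaluate_fn; not the
# repository API:
import numpy as np

def select_best_checkpoint(checkpoints_all, evaluate_fn):
    # evaluate_fn(checkpoint) -> 1D array of the metric per forecast hour
    metric_avg_all = [float(np.mean(evaluate_fn(ckpt))) for ckpt in checkpoints_all]
    return checkpoints_all[int(np.argmin(metric_avg_all))], metric_avg_all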
@@ -706,12 +634,12 @@ def main():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--input_dir", type=str, required=True, 
+    parser.add_argument("--input_dir", type=str, required=True,
                         help="Directory where input data as TFRecord-files are stored.")
     parser.add_argument("--output_dir", help="Output directory where JSON-files, summary, model, plots etc. are saved.")
     parser.add_argument("--datasplit_dict", help="JSON-file that contains the datasplit configuration")
     parser.add_argument("--checkpoint", help="Checkpoint directory or checkpoint name (e.g. <my_dir>/model-200000)")
-    parser.add_argument("--dataset", type=str, help="Dataset name") # as in dataset_utils.DATASETS
+    parser.add_argument("--dataset", type=str, help="Dataset name")
     parser.add_argument("--model", type=str, help="Model class name")
     parser.add_argument("--model_hparams_dict", type=str, help="JSON-file of model hyperparameters")
     parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="Fraction of gpu memory to use")
@@ -724,18 +652,23 @@ def main():
                         help="Test mode for postprocessing to allow bootstrapping on small datasets.")
     args = parser.parse_args()

-    # start timing for the whole run
-    # list pip environment
     import os
     print(os.system("pip3 list"))
     timeit_start = time.time()
     # create a training instance
-    train_case = TrainModel(input_dir=args.input_dir,output_dir=args.output_dir,datasplit_dict=args.datasplit_dict,
-                            model_hparams_dict=args.model_hparams_dict,model=args.model,checkpoint=args.checkpoint, dataset_name=args.dataset,
-                            gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args, frac_start_save=args.frac_start_save,
-                            frac_intv_save=args.frac_intv_save)
+    train_case = TrainModel(input_dir=args.input_dir,
+                            output_dir=args.output_dir,
+                            datasplit_dict=args.datasplit_dict,
+                            model_hparams_dict=args.model_hparams_dict,
+                            model=args.model,
+                            checkpoint=args.checkpoint,
+                            dataset_name=args.dataset,
+                            gpu_mem_frac=args.gpu_mem_frac,
+                            seed=args.seed, args=args,
+                            frac_start_save=args.frac_start_save,
+                            frac_intv_save=args.frac_intv_save)

     print('----------------------------------- Options ------------------------------------')
     for k, v in args._get_kwargs():
diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
index fc55de930ca70eb8624ba85bbe9462f7735a59c9..34ed8da4970c0f30b0357e8d72ca128e601c009d 100644
--- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py
+++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py
@@ -729,6 +729,7 @@ class Postprocess(TrainModel):
             gen_images_denorm = self.denorm_images_all_channels(
                 gen_images, self.vars_in, self.norm_cls, norm_method="minmax"
             )
+            # store data into dataset & get number of samples (may differ from batch_size at the end of the test dataset)
             times_0, init_times = self.get_init_time(t_starts)
             batch_ds = self.create_dataset(
                 input_images_denorm, gen_images_denorm, init_times
             )
@@ -814,7 +815,7 @@ class Postprocess(TrainModel):
         self.sess.run(tf.global_variables_initializer())
         self.sess.run(tf.local_variables_initializer())

-    def get_input_data_per_batch(self, input_iter, norm_method="cbnorm"):
+    def get_input_data_per_batch(self, input_iter, norm_method="minmax"):
         """
         Get the input sequence from the dataset iterator object stored in self.inputs and denormalize the data
         :param input_iter: the iterator object built by make_test_dataset_iterator-method
@@ -1297,7 +1298,8 @@ class Postprocess(TrainModel):
     @staticmethod
     def denorm_images_all_channels(
-        image_sequence, varnames, norm, norm_method="minmax")
+        image_sequence, varnames, norm, norm_method="minmax"
+    ):
         """
         Denormalize data of all image channels
         :param image_sequence: list/array [batch, seq, lat, lon, channel] of images
diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess_gzprcp.py b/video_prediction_tools/main_scripts/main_visualize_postprocess_gzprcp.py
index c2fdd4f1b20751768ce76c138124a215e45f0233..b758eca0efaa5740dcd6f852af4c36b8656962ab 100644
--- 
a/video_prediction_tools/main_scripts/main_visualize_postprocess_gzprcp.py +++ b/video_prediction_tools/main_scripts/main_visualize_postprocess_gzprcp.py @@ -283,120 +283,13 @@ class Postprocess(TrainModel): # the run-factory def run(self): - if self.model == "convLSTM" or self.model == "test_model" or self.model == 'mcnet': + if self.model == "convLSTM" or self.model == "test_model": self.run_deterministic() elif self.run_mode == "deterministic": self.run_deterministic() else: self.run_stochastic() - def run_stochastic(self): - """ - Run session, save results to netcdf, plot input images, generate images and persistent images - """ - method = Postprocess.run_stochastic.__name__ - raise ValueError("ML: %{0} is not runnable now".format(method)) - - self.init_session() - self.restore(self.sess, self.checkpoint) - # Loop for samples - self.sample_ind = 0 - self.prst_metric_all = [] # store evaluation metrics of persistence forecast (shape [future_len]) - self.fcst_metric_all = [] # store evaluation metric of stochastic forecasts (shape [nstoch, batch, future_len]) - while self.sample_ind < self.num_samples_per_epoch: - if self.num_samples_per_epoch < self.sample_ind: - break - else: - # run the inputs and plot each sequence images - self.input_results, self.input_images_denorm_all, self.t_starts = self.get_input_data_per_batch() - - feed_dict = {input_ph: self.input_results[name] for name, input_ph in self.inputs.items()} - gen_loss_stochastic_batch = [] # [stochastic_ind,future_length] - gen_images_stochastic = [] # [stochastic_ind,batch_size,seq_len,lat,lon,channels] - # Loop for stochastics - for stochastic_sample_ind in range(self.num_stochastic_samples): - print("stochastic_sample_ind:", stochastic_sample_ind) - # return [batchsize,seq_len,lat,lon,channel] - gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict) - # The generate images seq_len should be sequence_len -1, since the last one is - # not used for comparing with groud truth - assert gen_images.shape[1] == self.sequence_length - 1 - gen_images_per_batch = [] - if stochastic_sample_ind == 0: - persistent_images_per_batch = [] # [batch_size,seq_len,lat,lon,channel] - ts_batch = [] - for i in range(self.batch_size): - # generate time stamps for sequences only once, since they are the same for all ensemble members - if stochastic_sample_ind == 0: - self.ts = Postprocess.generate_seq_timestamps(self.t_starts[i], len_seq=self.sequence_length) - init_date_str = self.ts[0].strftime("%Y%m%d%H") - ts_batch.append(init_date_str) - # get persistence_images - self.persistence_images, self.ts_persistence = Postprocess.get_persistence(self.ts, - self.input_dir_pkl) - persistent_images_per_batch.append(self.persistence_images) - assert len(np.array(persistent_images_per_batch).shape) == 5 - self.plot_persistence_images() - - # Denormalized data for generate - gen_images_ = gen_images[i] - self.gen_images_denorm = Postprocess.denorm_images_all_channels(self.stat_fl, gen_images_, - self.vars_in) - gen_images_per_batch.append(self.gen_images_denorm) - assert len(np.array(gen_images_per_batch).shape) == 5 - # only plot when the first stochastic ind otherwise too many plots would be created - # only plot the stochastic results of user-defined ind - self.plot_generate_images(stochastic_sample_ind, self.stochastic_plot_id) - # calculate the persistnet error per batch - if stochastic_sample_ind == 0: - persistent_loss_per_batch = Postprocess.calculate_metrics_by_batch(self.input_images_denorm_all, - 
persistent_images_per_batch, - self.future_length, - self.context_frames, - matric="mse", channel=0) - self.prst_metric_all.append(persistent_loss_per_batch) - - # calculate the gen_images_per_batch error - gen_loss_per_batch = Postprocess.calculate_metrics_by_batch(self.input_images_denorm_all, - gen_images_per_batch, self.future_length, - self.context_frames, - matric="mse", channel=0) - gen_loss_stochastic_batch.append( - gen_loss_per_batch) # self.gen_images_stochastic[stochastic,future_length] - print("gen_images_per_batch shape:", np.array(gen_images_per_batch).shape) - gen_images_stochastic.append( - gen_images_per_batch) # [stochastic,batch_size, seq_len, lat, lon, channel] - - # Switch the 0 and 1 position - print("before transpose:", np.array(gen_images_stochastic).shape) - gen_images_stochastic = np.transpose(np.array(gen_images_stochastic), ( - 1, 0, 2, 3, 4, 5)) # [batch_size, stochastic, seq_len, lat, lon, chanel] - Postprocess.check_gen_images_stochastic_shape(gen_images_stochastic) - assert len(gen_images_stochastic.shape) == 6 - assert np.array(gen_images_stochastic).shape[1] == self.num_stochastic_samples - - self.fcst_metric_all.append( - gen_loss_stochastic_batch) # [samples/batch_size,stochastic,future_length] - # save input and stochastic generate images to netcdf file - # For each prediction (either deterministic or ensemble) we create one netCDF file. - for batch_id in range(self.batch_size): - self.save_to_netcdf_for_stochastic_generate_images(self.input_images_denorm_all[batch_id], - persistent_images_per_batch[batch_id], - np.array(gen_images_stochastic)[batch_id], - fl_name="vfp_date_{}_sample_ind_{}.nc" - .format(ts_batch[batch_id], - self.sample_ind + batch_id)) - - self.sample_ind += self.batch_size - - self.persistent_loss_all_batches = np.mean(np.array(self.persistent_loss_all_batches), axis=0) - self.stochastic_loss_all_batches = np.mean(np.array(self.stochastic_loss_all_batches), axis=0) - assert len(np.array(self.persistent_loss_all_batches).shape) == 1 - assert np.array(self.persistent_loss_all_batches).shape[0] == self.future_length - - assert len(np.array(self.stochastic_loss_all_batches).shape) == 2 - assert np.array(self.stochastic_loss_all_batches).shape[0] == self.num_stochastic_samples - def run_deterministic(self): """ Revised and vectorized version of run_deterministic diff --git a/video_prediction_tools/model_modules/video_prediction/__init__.py b/video_prediction_tools/model_modules/video_prediction/__init__.py index 3089b251bbbcbe086a03a4ea63571ce9427b2cb9..4ad018e9ea2d38ee103b49a1d2c847dfd2faa691 100644 --- a/video_prediction_tools/model_modules/video_prediction/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/__init__.py @@ -1,3 +1,2 @@ from . import losses from . import metrics -from . 
import ops
diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py
index 3b7afcc929e4affcf6f7aa14da37808d1e1faf78..c789f64f8a69a9a5734b638eec43ec2bd250aba1 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py
@@ -1,7 +1,6 @@
 #from .base_dataset import BaseVideoDataset
 #from .era5_dataset import ERA5Dataset
 #from .gzprcp_dataset import GzprcpDataset
-#from .moving_mnist import MovingMnist
 #from data_preprocess.dataset_options import known_datasets
 from .stats import MinMax, ZScore
 from .dataset import Dataset
diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/dataset.py
index b42702a03842ba53ac3b0bcf8db14b7f83b98bc1..767e82bdd9d91bbdd883d5553e8d53906816a8b6 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/dataset.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/dataset.py
@@ -3,16 +3,14 @@ __date__ = "2022-03-17"
 __email__ = "b.gong@fz-juelich.de"

 import json
-import os
-from typing import List
-from dataclasses import dataclass
 from pathlib import Path
-
 import xarray as xr
 import tensorflow as tf
-
+import sys, os
+sys.path.append("../../../utils")
 from hparams_utils import *
-from model_modules.video_prediction.datasets.stats import DatasetStats, Normalize
+sys.path.append("../../../")
+from model_modules.video_prediction.datasets.stats import DatasetStats, Normalize, ZScore


 class Dataset:
@@ -40,16 +38,16 @@ class Dataset:
     ):
         """
         This class is used for preparing data for training/validation and test models
-        :param input_dir: the path of tfrecords files
+        :param input_dir : the path of tfrecords files
         :param datasplit_path: the path pointing to the datasplit_config json file
-        :param hparams_path: the path to the dict that contains hparameters,
-        :param mode: string, "train","val" or "test"
-        :param seed: int, the seed for shuffeling the dataset
-        :param nsamples_ref: number of reference samples which can be used to control repetition factor for dataset
+        :param hparams_path : the path to the dict that contains hyperparameters,
+        :param mode : string, "train","val" or "test"
+        :param seed : int, the seed for shuffling the dataset
+        :param nsamples_ref : number of reference samples which can be used to control repetition factor for dataset
                               for ensuring adopted size of dataset iterator (used for validation data during training)
                               Example: Let nsamples_ref be 1000 while the current dataset consists of 100 samples, then
                               the repetition-factor will be 10 (i.e. nsamples*rep_fac = nsamples_ref)
-        :param normalize: class of the desired normalization method
+        :param normalize : class of the desired normalization method
         """
         self.input_dir = input_dir
         self.output_dir = output_dir
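# [Editor's note] How the data-split JSON and the filename template cooperate
# in filenames() below; dictionary, path and template values are examples only:
import os

datasplit_example = {"train": {"2008": [1, 2, 3], "2009": [1, 2]}}
template = "weatherbench_{year}-{month:02}.nc"
files = [os.path.join("/path/to/input", template.format(year=y, month=m))
         for y, months in datasplit_example["train"].items() for m in months]
# -> ['/path/to/input/weatherbench_2008-01.nc', ..., '/path/to/input/weatherbench_2009-02.nc']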
@@ -105,7 +103,7 @@ class Dataset:
         # {"2008":[1,2,3,4,...], "2009":[1,2,3,4,...]}
         for year, months in time_window.items():
             for month in months:
-                files.append(self.input_dir / self.filename_template.format(year=year, month=month))
+                files.append(os.path.join(self.input_dir, self.filename_template.format(year=year, month=month)))

         return files

@@ -116,6 +114,7 @@ class Dataset:
         :param mode: indicator to differentiate between training, validation and test data
         """
         files = self.filenames(mode)
+
         if not len(files) > 0:
             raise Exception(
                 f"no files for dataset {mode} found, check data_split dictionary"
@@ -167,7 +166,7 @@ class Dataset:
         # shuffle
         if shuffle:
             dataset = dataset.apply(
-                tf.contrib.data.shuffle_and_repeat(
+                tf.data.experimental.shuffle_and_repeat(
                     buffer_size=1024, count=self.max_epochs, seed=self.seed
                 )
             )  # TODO: check, self.seed
@@ -185,7 +184,7 @@ class Dataset:
         stats = self._stats_lookup[mode]
         normalize = self.normalize(stats)

-        stats.to_json(self.output_dir / "normalization_stats.json")
+        stats.to_json(os.path.join(self.output_dir, "normalization_stats.json"))

         dataset = dataset.map(normalize.normalize_vars)

@@ -208,7 +207,9 @@ class Dataset:
         """
         stats = self._get_stats("test")
         return int((stats.n + self.shift) / (self.sequence_length + self.shift))
-
+
+
+
     @property
     def num_validation_samples(self):
         """
@@ -243,3 +244,17 @@ class Dataset:
     @property
     def test_stats(self):
         return self._get_stats("test")
+
+
+if __name__ == '__main__':
+    source_dir = "/Users/gongbing/PycharmProjects/weatherbench"
+    destination_dir = "/Users/gongbing/PycharmProjects/20221201T100105_gong1_wb_convlstm_gan"
+    datasplit_path = os.path.join(destination_dir, "data_split.json")
+    hparams_path = os.path.join("/Users/gongbing/PycharmProjects/20221201T100105_gong1_wb_convlstm_gan", "model_hparams.json")
+    normalize = ZScore
+    template="weatherbench_{year}-{month:02}.nc"
+    ins = Dataset(input_dir=source_dir,output_dir=destination_dir,datasplit_path=datasplit_path,
+                  hparams_path=hparams_path,normalize=normalize, filename_template=template)
+    dt = ins.make_dataset(mode="train")
+    for i, data in enumerate(dt):
+        print(data)
\ No newline at end of file
diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/stats.py b/video_prediction_tools/model_modules/video_prediction/datasets/stats.py
index a1934e6290712f715691fca02554bc012eda7dd6..6b5fb9b428eb9d8dec75d881245fc08d12dce9ad 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/stats.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/stats.py
@@ -1,11 +1,11 @@
+
 from abc import ABC, abstractmethod
 import dataclasses as dc
 import json
-from pathlib import Path
-
 import numpy as np
 import tensorflow as tf

+
 @dc.dataclass
 class VarStats:
     mean: float
diff --git a/video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py b/video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py
index 919b0d5ac64f4f9cc081d8aba681974870e91f3f..e25a4f12d5c5566a787c1fecaf7bb3651a67beed 100644
--- a/video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py
+++ b/video_prediction_tools/model_modules/video_prediction/layers/BasicConvLSTMCell.py
@@ -4,7 +4,81 @@
 import tensorflow as tf
 from .layer_def import *
+from 
tensorflow.compat.v1.nn.rnn_cell import LSTMStateTuple +# class Zero_state(tf.Module): +# """Abstract object representing an Convolutional RNN cell. +# """ +# def __init__(self, shape, num_features, name=None): +# super().__init__(name=name) +# +# self.shape = shape +# self.num_features = num_features +# +# def __call__(self, x): +# """Return zero-filled state tensor(s). +# Args: +# batch_size: int, float, or unit Tensor representing the batch size. +# dtype : the data type to use for the state. +# Returns: +# tensor of shape '[batch_size x shape[0] x shape[1] x num_features] +# filled with zeros +# """ +# num_features = self.num_features +# zeros = tf.zeros([tf.shape(x)[0], self.shape[0], self.shape[1], num_features * 2]) +# return zeros +# +# +# class BasicConvLSTMCell(tf.Module): +# """Basic ConvLSTM recurrent network cell. The +# """ +# def __init__(self, shape, filter_size, num_features, forget_bias=1.0, +# state_is_tuple=False, activation=tf.nn.tanh, name=None): +# super().__init__(name=name) +# """Initialize the basic Conv LSTM cell. +# Args: +# shape : int tuple thats the height and width of the cell +# filter_size : int tuple thats the height and width of the filter +# num_features : int thats the depth of the cell +# forget_bias : float, The bias added to forget gates (see above). +# input_size : Deprecated and unused. +# state_is_tuple: If True, accepted and returned states are 2-tuples of +# the `c_state` and `m_state`. If False, they are concatenated +# along the column axis. The latter behavior will soon be deprecated. +# activation: Activation function of the inner states. +# """ +# +# self.shape = shape +# self.filter_size = filter_size +# self.num_features = num_features +# self._forget_bias = forget_bias +# self._state_is_tuple = state_is_tuple +# self._activation = activation +# +# """Long short-term memory cell (LSTM).""" +# +# def __call__(self, inputs, state): +# if self._state_is_tuple: +# c, h = state +# else: +# c, h = tf.split(axis = 3, num_or_size_splits = 2, value = state) +# +# input_h = [inputs, h] +# +# input_h_con = tf.concat(axis = 3, values = input_h) +# concat = conv_layer(input_h_con, self.filter_size, 1, self.num_features*4, +# "decode_1", activate="sigmoid") +# i, j, f, o = tf.split(axis = 3, num_or_size_splits = 4, value = concat) +# new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) * +# self._activation(j)) +# new_h = self._activation(new_c) * tf.nn.sigmoid(o) +# +# if self._state_is_tuple: +# new_state = LSTMStateTuple(new_c, new_h) +# else: +# new_state = tf.concat(axis = 3, values = [new_c, new_h]) +# +# return new_h, new_state class ConvRNNCell(object): """Abstract object representing an Convolutional RNN cell. """ @@ -25,7 +99,7 @@ class ConvRNNCell(object): """Integer or TensorShape: size of outputs produced by this cell.""" raise NotImplementedError("Abstract method") - def zero_state(self,input, dtype): + def zero_state(self, input): """Return zero-filled state tensor(s). Args: batch_size: int, float, or unit Tensor representing the batch size. 
@@ -35,11 +109,9 @@ class ConvRNNCell(object): filled with zeros """ - shape = self.shape + num_features = self.num_features - #x= tf.placeholder(tf.float32, shape=[input.shape[0], shape[0], shape[1], num_features * 2])#Bing: add this to - zeros = tf.zeros([tf.shape(input)[0], shape[0], shape[1], num_features * 2]) - #zeros = tf.zeros_like(x) + zeros = tf.zeros([input.shape[0], self.shape[0], self.shape[1], num_features * 2]) return zeros @@ -84,74 +156,26 @@ class BasicConvLSTMCell(ConvRNNCell): def __call__(self, inputs, state, scope=None,reuse=None): """Long short-term memory cell (LSTM).""" - with tf.variable_scope(scope or type(self).__name__,reuse=reuse): # "BasicLSTMCell" - # Parameters of gates are concatenated into one multiply for efficiency. - if self._state_is_tuple: - c, h = state - else: - c, h = tf.split(axis = 3, num_or_size_splits = 2, value = state) - - input_h = [inputs,h] - #Bing20200930#replace with non-linear convolutional layers - #concat = _conv_linear([inputs, h], self.filter_size, self.num_features * 4, True) - input_h_con = tf.concat(axis = 3, values = input_h) - concat = conv_layer(input_h_con, self.filter_size, 1, self.num_features*4, "decode_1", activate="sigmoid") - i, j, f, o = tf.split(axis = 3, num_or_size_splits = 4, value = concat) - new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) * - self._activation(j)) - new_h = self._activation(new_c) * tf.nn.sigmoid(o) - - if self._state_is_tuple: - new_state = LSTMStateTuple(new_c, new_h) - else: - new_state = tf.concat(axis = 3, values = [new_c, new_h]) - - return new_h, new_state - - -def _conv_linear(args, filter_size, num_features, bias, bias_start=0.0, scope=None): - """convolution: - Args: - args: a 4D Tensor or a list of 4D, batch x n, Tensors. - filter_size: int tuple of filter height and width. - num_features: int, number of features. - bias_start: starting value to initialize the bias; 0 by default. - scope: VariableScope for the created subgraph; defaults to "Linear". - Returns: - A 4D Tensor with shape [batch h w num_features] - Raises: - ValueError: if some of the arguments has unspecified or wrong shape. - """ - # Calculate the total size of arguments on dimension 1. - total_arg_size_depth = 0 - shapes = [a.get_shape().as_list() for a in args] - for shape in shapes: - if len(shape) != 4: - raise ValueError("Linear is expecting 4D arguments: %s" % str(shapes)) - if not shape[3]: - raise ValueError("Linear expects shape[4] of arguments: %s" % str(shapes)) + # Parameters of gates are concatenated into one multiply for efficiency. + if self._state_is_tuple: + c, h = state else: - total_arg_size_depth += shape[3] + c, h = tf.split(axis = 3, num_or_size_splits = 2, value = state) - dtype = [a.dtype for a in args][0] + input_h = [inputs,h] + input_h_con = tf.concat(axis = 3, values = input_h) + concat = conv_layer(input_h_con, self.filter_size, 1, self.num_features*4, "decode_1", activate="sigmoid") + i, j, f, o = tf.split(axis = 3, num_or_size_splits = 4, value = concat) + new_c = (c * tf.nn.sigmoid(f + self._forget_bias) + tf.nn.sigmoid(i) * + self._activation(j)) + new_h = self._activation(new_c) * tf.nn.sigmoid(o) - # Now the computation. 
- with tf.variable_scope(scope or "Conv"): - matrix = tf.get_variable( - "Matrix", [filter_size[0], filter_size[1], total_arg_size_depth, num_features], dtype = dtype) - if len(args) == 1: + if self._state_is_tuple: + new_state = LSTMStateTuple(new_c, new_h) + else: + new_state = tf.concat(axis = 3, values = [new_c, new_h]) - res = tf.nn.conv2d(args[0], matrix, strides = [1, 1, 1, 1], padding = 'SAME') + return new_h, new_state - else: - res = tf.nn.conv2d(tf.concat(axis = 3, values = args), matrix, strides = [1, 1, 1, 1], padding = 'SAME') - if not bias: - return res - bias_term = tf.get_variable( - "Bias", [num_features], - dtype = dtype, - initializer = tf.constant_initializer( - bias_start, dtype = dtype)) - return res + bias_term diff --git a/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py b/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py index 6f7c4f38b222afecb2b21d36bedc938b7813399a..33b2a9f3f84f6a82f4997e34f53e27cf4f34d2a9 100644 --- a/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py +++ b/video_prediction_tools/model_modules/video_prediction/layers/layer_def.py @@ -31,11 +31,11 @@ def _variable_on_gpu(name, shape, initializer): Variable Tensor """ with tf.device('/gpu:0'): - var = tf.get_variable(name, shape, initializer=initializer) + var = tf.Variable(initializer(shape=shape)) return var -def _variable_with_weight_decay(name, shape, stddev, wd,initializer=tf.contrib.layers.xavier_initializer()): +def _variable_with_weight_decay(name, shape, stddev, wd): """Helper to create an initialized Variable with weight decay. Note that the Variable is initialized with a truncated normal distribution. A weight decay is added only if one is specified. @@ -48,117 +48,113 @@ def _variable_with_weight_decay(name, shape, stddev, wd,initializer=tf.contrib.l Returns: Variable Tensor """ - #var = _variable_on_gpu(name, shape,tf.truncated_normal_initializer(stddev = stddev)) + initializer = tf.initializers.GlorotUniform() var = _variable_on_gpu(name, shape, initializer) if wd: weight_decay = tf.multiply(tf.nn.l2_loss(var), wd, name = 'weight_loss') weight_decay.set_shape([]) - tf.add_to_collection('losses', weight_decay) + tf.compat.v1.add_to_collection('losses', weight_decay) return var -def conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.contrib.layers.xavier_initializer() , activate="relu"): - - with tf.variable_scope('{0}_conv'.format(idx)) as scope: - - input_channels = inputs.get_shape()[-1] - weights = _variable_with_weight_decay('weights', shape = [kernel_size, kernel_size, - input_channels, num_features], - stddev = 0.01, wd = weight_decay) - biases = _variable_on_gpu('biases', [num_features], initializer) - conv = tf.nn.conv2d(inputs, weights, strides = [1, stride, stride, 1], padding='SAME') - conv_biased = tf.nn.bias_add(conv, biases) - if activate == "linear": - return conv_biased - elif activate == "relu": - conv_rect = tf.nn.relu(conv_biased, name = '{0}_conv'.format(idx)) - elif activate == "elu": - conv_rect = tf.nn.elu(conv_biased, name = '{0}_conv'.format(idx)) - elif activate == "leaky_relu": - conv_rect = tf.nn.leaky_relu(conv_biased, name = '{0}_conv'.format(idx)) - elif activate == "sigmoid": - conv_rect = tf.nn.sigmoid(conv_biased, name = '{0}_conv'.format(idx)) - else: - raise ("activation function is not correct") - return conv_rect - - -def transpose_conv_layer(inputs, kernel_size, stride, num_features, idx, 
initializer=tf.contrib.layers.xavier_initializer(),activate="relu"):
-    with tf.variable_scope('{0}_trans_conv'.format(idx)) as scope:
-        input_channels = inputs.get_shape()[3]
-        input_shape = inputs.get_shape().as_list()
-
-
-        weights = _variable_with_weight_decay('weights',
-                                              shape = [kernel_size, kernel_size, num_features, input_channels],
-                                              stddev = 0.1, wd = weight_decay)
-        biases = _variable_on_gpu('biases', [num_features],initializer)
-        batch_size = tf.shape(inputs)[0]
-
-        output_shape = tf.stack(
-            [tf.shape(inputs)[0], input_shape[1] * stride, input_shape[2] * stride, num_features])
-
-        conv = tf.nn.conv2d_transpose(inputs, weights, output_shape, strides = [1, stride, stride, 1], padding = 'SAME')
-        conv_biased = tf.nn.bias_add(conv, biases)
-        if activate == "linear":
-            return conv_biased
-        elif activate == "elu":
-            return tf.nn.elu(conv_biased, name = '{0}_transpose_conv'.format(idx))
-        elif activate == "relu":
-            return tf.nn.relu(conv_biased, name = '{0}_transpose_conv'.format(idx))
-        elif activate == "leaky_relu":
-            return tf.nn.leaky_relu(conv_biased, name = '{0}_transpose_conv'.format(idx))
-        elif activate == "sigmoid":
-            return tf.nn.sigmoid(conv_biased, name ='sigmoid')
-        else:
-            return conv_biased
+def conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.initializers.GlorotUniform(), activate="relu"):
+
+    input_channels = inputs.get_shape()[-1]
+    weights = _variable_with_weight_decay('weights', shape = [kernel_size, kernel_size,
+                                                              input_channels, num_features],
+                                          stddev = 0.01, wd = weight_decay)
+    biases = _variable_on_gpu('biases', [num_features], initializer)
+    conv = tf.nn.conv2d(inputs, weights, strides = [1, stride, stride, 1], padding='SAME')
+    conv_biased = tf.nn.bias_add(conv, biases)
+    if activate == "linear":
+        return conv_biased
+    elif activate == "relu":
+        conv_rect = tf.nn.relu(conv_biased, name = '{0}_conv'.format(idx))
+    elif activate == "elu":
+        conv_rect = tf.nn.elu(conv_biased, name = '{0}_conv'.format(idx))
+    elif activate == "leaky_relu":
+        conv_rect = tf.nn.leaky_relu(conv_biased, name = '{0}_conv'.format(idx))
+    elif activate == "sigmoid":
+        conv_rect = tf.nn.sigmoid(conv_biased, name = '{0}_conv'.format(idx))
+    else:
+        raise ValueError("activation function is not correct")
+    return conv_rect
+
+
+def transpose_conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.initializers.GlorotUniform(),activate="relu"):
+
+    input_channels = inputs.get_shape()[3]
+    input_shape = inputs.get_shape().as_list()
+
+
+    weights = _variable_with_weight_decay('weights',
+                                          shape = [kernel_size, kernel_size, num_features, input_channels],
+                                          stddev = 0.1, wd = weight_decay)
+    biases = _variable_on_gpu('biases', [num_features],initializer)
+
+    output_shape = tf.stack(
+        [tf.shape(inputs)[0], input_shape[1] * stride, input_shape[2] * stride, num_features])
+
+    conv = tf.nn.conv2d_transpose(inputs, weights, output_shape, strides = [1, stride, stride, 1], padding = 'SAME')
+    conv_biased = tf.nn.bias_add(conv, biases)
+    if activate == "linear":
+        return conv_biased
+    elif activate == "elu":
+        return tf.nn.elu(conv_biased, name = '{0}_transpose_conv'.format(idx))
+    elif activate == "relu":
+        return tf.nn.relu(conv_biased, name = '{0}_transpose_conv'.format(idx))
+    elif activate == "leaky_relu":
+        return tf.nn.leaky_relu(conv_biased, name = '{0}_transpose_conv'.format(idx))
+    elif activate == "sigmoid":
+        return tf.nn.sigmoid(conv_biased, name ='sigmoid')
+    else:
+        return conv_biased
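# [Editor's note] The rewritten conv_layer above keeps the hand-rolled TF1
# structure; in idiomatic TF2 the same building block is usually a Keras layer.
# A hedged equivalent of conv_layer(x, 3, 2, 64, idx=1, activate="leaky_relu"),
# with weight decay expressed as an L2 regularizer (wd value is an assumption):
import tensorflow as tf

def keras_conv_block(num_features, kernel_size, stride, wd=1e-4):
    return tf.keras.layers.Conv2D(
        filters=num_features, kernel_size=kernel_size, strides=stride,
        padding="same", activation=tf.nn.leaky_relu,
        kernel_initializer=tf.initializers.GlorotUniform(),
        kernel_regularizer=tf.keras.regularizers.l2(wd))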
activate="relu",weight_init=0.01,initializer=tf.contrib.layers.xavier_initializer()): - with tf.variable_scope('{0}_fc'.format(idx)) as scope: - input_shape = inputs.get_shape().as_list() - if flat: - dim = input_shape[1] * input_shape[2] * input_shape[3] - inputs_processed = tf.reshape(inputs, [-1, dim]) - else: - dim = input_shape[1] - inputs_processed = inputs - - weights = _variable_with_weight_decay('weights', shape = [dim, hiddens], stddev = weight_init, - wd = weight_decay) - biases = _variable_on_gpu('biases', [hiddens],initializer) - if activate == "linear": - return tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc') - elif activate == "sigmoid": - return tf.nn.sigmoid(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) - elif activate == "softmax": - return tf.nn.softmax(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) - elif activate == "relu": - return tf.nn.relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) - elif activate == "leaky_relu": - return tf.nn.leaky_relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) - else: - ip = tf.add(tf.matmul(inputs_processed, weights), biases) - return tf.nn.elu(ip, name = str(idx) + '_fc') - - -def bn_layers(inputs,idx,is_training=True,epsilon=1e-3,decay=0.99,reuse=None): - with tf.variable_scope('{0}_bn'.format(idx)) as scope: - #Calculate batch mean and variance - shape = inputs.get_shape().as_list() - scale = tf.get_variable("gamma", shape[-1], initializer=tf.constant_initializer(1.0), trainable=is_training) - beta = tf.get_variable("beta", shape[-1], initializer=tf.constant_initializer(0.0), trainable=is_training) - pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False) - pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False) - - if is_training: - batch_mean, batch_var = tf.nn.moments(inputs,[0]) - train_mean = tf.assign(pop_mean,pop_mean * decay + batch_mean * (1 - decay)) - train_var = tf.assign(pop_var, pop_var * decay + batch_var * (1 - decay)) - with tf.control_dependencies([train_mean,train_var]): - return tf.nn.batch_normalization(inputs,batch_mean,batch_var,beta,scale,epsilon) - else: - return tf.nn.batch_normalization(inputs,pop_mean,pop_var,beta,scale,epsilon) +def fc_layer(inputs, hiddens, idx, flat=False, activate="relu",weight_init=0.01,initializer=tf.initializers.GlorotUniform()): + + input_shape = inputs.get_shape().as_list() + if flat: + dim = input_shape[1] * input_shape[2] * input_shape[3] + inputs_processed = tf.reshape(inputs, [-1, dim]) + else: + dim = input_shape[1] + inputs_processed = inputs + + weights = _variable_with_weight_decay('weights', shape = [dim, hiddens], stddev = weight_init, + wd = weight_decay) + biases = _variable_on_gpu('biases', [hiddens],initializer) + if activate == "linear": + return tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc') + elif activate == "sigmoid": + return tf.nn.sigmoid(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) + elif activate == "softmax": + return tf.nn.softmax(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) + elif activate == "relu": + return tf.nn.relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) + elif activate == "leaky_relu": + return tf.nn.leaky_relu(tf.add(tf.matmul(inputs_processed, weights), biases, name = str(idx) + '_fc')) + else: + ip = 
+        return tf.nn.elu(ip, name = str(idx) + '_fc')
+
+
+def bn_layers(inputs,is_training=True, epsilon=1e-3,decay=0.99):
+
+    shape = inputs.get_shape().as_list()
+    scale = tf.compat.v1.get_variable("gamma", shape[-1], initializer=tf.constant_initializer(1.0), trainable=is_training)
+    beta = tf.compat.v1.get_variable("beta", shape[-1], initializer=tf.constant_initializer(0.0), trainable=is_training)
+    pop_mean = tf.Variable(tf.zeros([inputs.get_shape()[-1]]), trainable=False)
+    pop_var = tf.Variable(tf.ones([inputs.get_shape()[-1]]), trainable=False)
+
+    if is_training:
+        batch_mean, batch_var = tf.nn.moments(inputs,[0])
+        train_mean = tf.compat.v1.assign(pop_mean, pop_mean * decay + batch_mean * (1 - decay))
+        train_var = tf.compat.v1.assign(pop_var, pop_var * decay + batch_var * (1 - decay))
+        with tf.control_dependencies([train_mean,train_var]):
+            return tf.nn.batch_normalization(inputs,batch_mean,batch_var,beta,scale,epsilon)
+    else:
+        return tf.nn.batch_normalization(inputs,pop_mean,pop_var,beta,scale,epsilon)


 class batch_norm(object):
diff --git a/video_prediction_tools/model_modules/video_prediction/metrics.py b/video_prediction_tools/model_modules/video_prediction/metrics.py
index 570efd72d993808f143d632d6734eebffb9b9f7e..9e55bbc4683d5002ae6bb72e0c56d3cc91a1ff79 100644
--- a/video_prediction_tools/model_modules/video_prediction/metrics.py
+++ b/video_prediction_tools/model_modules/video_prediction/metrics.py
@@ -3,7 +3,6 @@
 # SPDX-License-Identifier: MIT

 import tensorflow as tf
-#import lpips_tf
 import math
 import numpy as np
 try:
@@ -15,8 +14,6 @@
 except:
     print("Could not get ssim-function from skimage. Please check installed skimage-package.")
     raise err
-
-
 def mse(a, b):
     return tf.reduce_mean(tf.squared_difference(a, b), [-3, -2, -1])
@@ -39,14 +36,6 @@
 def mse_imgs(image1,image2):
     mse = ((image1 - image2)**2).mean(axis=None)
     return mse
-# def lpips(input0, input1):
-#     if input0.shape[-1].value == 1:
-#         input0 = tf.tile(input0, [1] * (input0.shape.ndims - 1) + [3])
-#     if input1.shape[-1].value == 1:
-#         input1 = tf.tile(input1, [1] * (input1.shape.ndims - 1) + [3])
-#
-#     distance = lpips_tf.lpips(input0, input1)
-#     return -distance

 def ssim_images(image1, image2):
     """
diff --git a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py
index 5bad1430f9c7499965899fb948bd46bb8e7686c9..2a164529f4427876084256f424a483dd273ba154 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py
@@ -34,7 +34,8 @@ class ConvLstmGANVideoPredictionModel(BaseModels):
         self.learning_rate = self.hparams.lr
         self.sequence_length = self.hparams.sequence_length
         self.opt_var = self.hparams.opt_var
-        self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames)
+        self.predict_frames = set_and_check_pred_frames(self.sequence_length,
+                                                        self.context_frames)
         self.ngf = self.hparams.ngf
         self.ndf = self.hparams.ndf
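# [Editor's note] Reference pattern for the train_step rewrite in the next
# hunk: the standard TF2 two-tape GAN update. gen_vars/disc_vars and the two
# optimizers mirror _get_vars and the old G_solver/D_solver; this is a sketch
# under those assumptions, not the repository implementation:
import tensorflow as tf

def gan_train_step(x, model, g_opt, d_opt):
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
        x_hat = model.build_model(x)
        total_loss = model.get_loss(x, x_hat)      # also sets model.D_loss
    gen_grads = gen_tape.gradient(total_loss, model.gen_vars)
    disc_grads = disc_tape.gradient(model.D_loss, model.disc_vars)
    g_opt.apply_gradients(zip(gen_grads, model.gen_vars))
    d_opt.apply_gradients(zip(disc_grads, model.disc_vars))
    return total_loss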
diff --git a/video_prediction_tools/model_modules/video_prediction/metrics.py b/video_prediction_tools/model_modules/video_prediction/metrics.py
index 570efd72d993808f143d632d6734eebffb9b9f7e..9e55bbc4683d5002ae6bb72e0c56d3cc91a1ff79 100644
--- a/video_prediction_tools/model_modules/video_prediction/metrics.py
+++ b/video_prediction_tools/model_modules/video_prediction/metrics.py
@@ -3,7 +3,6 @@
 # SPDX-License-Identifier: MIT
 
 import tensorflow as tf
-#import lpips_tf
 import math
 import numpy as np
 try:
@@ -15,8 +14,6 @@ except:
     print("Could not get ssmi-function from skimage. Please check installed skimage-package.")
     raise err
 
-
-
 def mse(a, b):
     return tf.reduce_mean(tf.squared_difference(a, b), [-3, -2, -1])
 
@@ -39,14 +36,6 @@ def mse_imgs(image1,image2):
     mse = ((image1 - image2)**2).mean(axis=None)
     return mse
 
-# def lpips(input0, input1):
-#     if input0.shape[-1].value == 1:
-#         input0 = tf.tile(input0, [1] * (input0.shape.ndims - 1) + [3])
-#     if input1.shape[-1].value == 1:
-#         input1 = tf.tile(input1, [1] * (input1.shape.ndims - 1) + [3])
-#
-#     distance = lpips_tf.lpips(input0, input1)
-#     return -distance
 
 def ssim_images(image1, image2):
     """
diff --git a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py
index 5bad1430f9c7499965899fb948bd46bb8e7686c9..2a164529f4427876084256f424a483dd273ba154 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py
@@ -34,7 +34,8 @@ class ConvLstmGANVideoPredictionModel(BaseModels):
         self.learning_rate = self.hparams.lr
         self.sequence_length = self.hparams.sequence_length
         self.opt_var = self.hparams.opt_var
-        self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames)
+        self.predict_frames = set_and_check_pred_frames(self.sequence_length,
+                                                        self.context_frames)
         self.ngf = self.hparams.ngf
         self.ndf = self.hparams.ndf
@@ -43,38 +44,38 @@ class ConvLstmGANVideoPredictionModel(BaseModels):
             raise ValueError("Method %{}: the hyper-parameter dictionary must include parameters above")
 
-    def build_graph(self, x: tf.Tensor):
+    @tf.function
+    def train_step(self, x: tf.Tensor, step):
 
-        self.inputs = x
+        self.inputs = x  # build_model below still reads self.inputs
+        with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
+            x_hat = self.build_model(x)
 
-        self.global_step = tf.train.get_or_create_global_step()
-        original_global_variables = tf.global_variables()
+            self.total_loss = self.get_loss(x, x_hat)
 
-        #Build graph
-        x_hat = self.build_model(x)
+        generator_gradients = gen_tape.gradient(self.total_loss, self.gen_vars)
+        discriminator_gradients = disc_tape.gradient(self.D_loss, self.disc_vars)
 
-        #Get losses (reconstruciton loss, total loss and descriminator loss)
-        self.total_loss = self.get_loss(x, x_hat)
 
-        #Define optimizer
-        self.train_op = self.optimizer(self.total_loss)
+        self.G_solver.apply_gradients(zip(generator_gradients,
+                                          self.gen_vars))
+        self.D_solver.apply_gradients(zip(discriminator_gradients,
+                                          self.disc_vars))
 
-        #Save to outputs
-        self.outputs["gen_images"] = x_hat
-        self.outputs["total_loss"] = self.total_loss
-        # Summary op
-        sum_dict = {"total_loss": self.total_loss,
-                    "D_loss": self.D_loss,
-                    "G_loss": self.G_loss,
-                    "D_loss_fake": self.D_loss_fake,
-                    "D_loss_real": self.D_loss_real,
-                    "recon_loss": self.recon_loss}
+        #Save to outputs
+        self.outputs["gen_images"] = x_hat
+        self.outputs["total_loss"] = self.total_loss
+        # Summary op
+        sum_dict = {"total_loss": self.total_loss,
+                    "D_loss": self.D_loss,
+                    "G_loss": self.G_loss,
+                    "D_loss_fake": self.D_loss_fake,
+                    "D_loss_real": self.D_loss_real,
+                    "recon_loss": self.recon_loss}
+
+        self.summary_op = self.summary(**sum_dict)
 
-        self.summary_op = self.summary(**sum_dict)
-        global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
-        self.saveable_variables = [self.global_step] + global_variables
-        self.is_build_graph = True
-        return self.is_build_graph
 
     def get_loss(self, x: tf.Tensor, x_hat: tf.Tensor):
         """
@@ -83,7 +84,6 @@ class ConvLstmGANVideoPredictionModel(BaseModels):
         self.G_loss = self.get_gen_loss()
         self.D_loss = self.get_disc_loss()
         self._get_vars()
-        #self.recon_loss = self.get_loss(self, x, x_hat) #use the loss from vanilla convLSTM
 
         if self.opt_var == "all":
             x = x[:, self.context_frames:, :, :, :]
@@ -111,33 +111,14 @@ class ConvLstmGANVideoPredictionModel(BaseModels):
         return total_loss
 
     def optimizer(self, *args):
+        pass
 
-        if self.mode == "train":
-            if self.recon_weight == 1:
-                print("Only train generator- ConvLSTM")
-                train_op = tf.train.AdamOptimizer(learning_rate =
-                                                  self.learning_rate).\
-                    minimize(self.total_loss, var_list=self.gen_vars)
-            else:
-                print("Training discriminator")
-                self.D_solver = tf.train.AdamOptimizer(learning_rate =self.learning_rate).\
-                    minimize(self.D_loss, var_list=self.disc_vars)
-                with tf.control_dependencies([self.D_solver]):
-                    print("Training generator....")
-                    self.G_solver = tf.train.AdamOptimizer(learning_rate =self.learning_rate).\
-                        minimize(self.total_loss, var_list=self.gen_vars)
-                with tf.control_dependencies([self.G_solver]):
-                    train_op = tf.assign_add(self.global_step, 1)
-        else:
-            train_op = None
-        return train_op
 
     def build_model(self, x):
         """
         Define gan architectures
         """
-        #conditional GAN
         x_hat = self.generator(x)
 
         self.D_real, self.D_real_logits = self.discriminator(self.inputs[:, self.context_frames:, :, :, 0:1])
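[Review note] The eager train_step above applies gradients with self.G_solver/self.D_solver, but the deleted TF1 optimizer() was what used to build the solvers, and its replacement is now a no-op. One way to provide them (compile_gan is a hypothetical helper, not part of this patch):

    import tensorflow as tf

    def compile_gan(model):
        # One Adam optimizer per sub-network, mirroring the deleted
        # tf.train.AdamOptimizer calls; train_step then pairs generator
        # gradients with model.gen_vars and discriminator gradients with
        # model.disc_vars.
        model.G_solver = tf.keras.optimizers.Adam(learning_rate=model.learning_rate)
        model.D_solver = tf.keras.optimizers.Adam(learning_rate=model.learning_rate)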
tf.variable_scope("discriminator", reuse=tf.AUTO_REUSE): + conv1 = tf.layers.conv3d(vid, 64, kernel_size=[4,4,4], + strides=[2,2,2], + padding="SAME", + name="dis1") conv1 = self._lrelu(conv1) - conv2 = tf.layers.conv3d(conv1, 128, kernel_size=[4,4,4],strides=[2,2,2],padding="SAME", name="dis2") + conv2 = tf.layers.conv3d(conv1, 128, + kernel_size=[4, 4, 4], + strides=[2,2,2], + padding="SAME", + name="dis2") conv2 = self._lrelu(self.bd1(conv2)) - conv3 = tf.layers.conv3d(conv2, 256, kernel_size=[4,4,4],strides=[2,2,2],padding="SAME" ,name="dis3") + conv3 = tf.layers.conv3d(conv2, 256, + kernel_size=[4, 4, 4], + strides=[2,2,2], + padding="SAME", + name="dis3") conv3 = self._lrelu(self.bd2(conv3)) - conv4 = tf.layers.conv3d(conv3, 512, kernel_size=[4,4,4],strides=[2,2,2],padding="SAME", name="dis4") + conv4 = tf.layers.conv3d(conv3, 512, + kernel_size=[4, 4, 4], + strides=[2, 2, 2], + padding="SAME", + name="dis4") conv4 = self._lrelu(self.bd3(conv4)) - conv5 = tf.layers.conv3d(conv4, 1, kernel_size=[2,4,4],strides=[1,1,1],padding="SAME", name="dis5") - conv5 = tf.reshape(conv5, [-1,1]) + conv5 = tf.layers.conv3d(conv4, 1, + kernel_size=[2, 4, 4], + strides=[1, 1, 1], + padding="SAME", + name="dis5") + conv5 = tf.reshape(conv5, [-1, 1]) conv5sigmoid = tf.nn.sigmoid(conv5) return conv5sigmoid, conv5 @@ -187,13 +196,14 @@ class ConvLstmGANVideoPredictionModel(BaseModels): """ real_labels = tf.ones_like(self.D_real) gen_labels = tf.zeros_like(self.D_fake) - self.D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_real_logits, labels=real_labels)) - self.D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits, labels=gen_labels)) + self.D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( + logits=self.D_real_logits, labels=real_labels)) + self.D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits( + logits=self.D_fake_logits, labels=gen_labels)) D_loss = self.D_loss_real + self.D_loss_fake return D_loss - def get_gen_loss(self): """ Param: @@ -203,7 +213,7 @@ class ConvLstmGANVideoPredictionModel(BaseModels): """ real_labels = tf.ones_like(self.D_fake) G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits, - labels=real_labels)) + labels=real_labels)) return G_loss def _get_vars(self): @@ -214,7 +224,6 @@ class ConvLstmGANVideoPredictionModel(BaseModels): self.gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("generator")] - def _lrelu(self, x, leak=0.2): return tf.maximum(x, leak * x) diff --git a/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py b/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py index 589a6b32b7fe9c80646e53310d54272f696cb88f..f5dbead33b1540acd9322a93d239f93873004654 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py @@ -5,7 +5,8 @@ __email__ = "b.gong@fz-juelich.de" __author__ = "Bing Gong" __date__ = "2022-04-13" - +import sys +sys.path.append("../utils") from hparams_utils import * import json from abc import ABC, abstractmethod @@ -33,10 +34,11 @@ class BaseModels(ABC): self.outputs = {} self.loss_summary = None self.summary_op = None - self.global_step = tf.train.get_or_create_global_step() self.saveable_variables = None self._is_build_graph_set = False + + def hparams_options(self, hparams_dict_config:str): if 
diff --git a/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py b/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py
index 589a6b32b7fe9c80646e53310d54272f696cb88f..f5dbead33b1540acd9322a93d239f93873004654 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py
@@ -5,7 +5,8 @@
 __email__ = "b.gong@fz-juelich.de"
 __author__ = "Bing Gong"
 __date__ = "2022-04-13"
-
+import sys
+sys.path.append("../utils")
 from hparams_utils import *
 import json
 from abc import ABC, abstractmethod
@@ -33,10 +34,11 @@ class BaseModels(ABC):
         self.outputs = {}
         self.loss_summary = None
         self.summary_op = None
-        self.global_step = tf.train.get_or_create_global_step()
         self.saveable_variables = None
         self._is_build_graph_set = False
 
+
+
     def hparams_options(self, hparams_dict_config:str):
         if hparams_dict_config:
@@ -71,7 +73,7 @@ class BaseModels(ABC):
 
         return self.hparams
 
-    def build_graph(self, x: tf.Tensor)->bool:
+    def train_step(self, x: tf.Tensor)->bool:
         """
         This function is used for build the graph, and allow a optimiser to the graph by using tensorflow function.
@@ -90,17 +92,17 @@
 
         return self._is_build_graph_set
         """
-        self.inputs = x
-        original_global_variables = tf.global_variables()
-        x_hat = self.build_model(x)
-        self.total_loss = self.get_loss(x, x_hat)
-        self.train_op = self.optimizer(self.total_loss)
-        self.outputs["gen_images"] = x_hat
-        self.summary_op = self.summary(total_loss = self.total_loss)
-        global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
-        self.saveable_variables = [self.global_step] + global_variables
-        self._is_build_graph_set = True
-        return self._is_build_graph_set
+        # self.inputs = x
+        # original_global_variables = tf.global_variables()
+        # x_hat = self.build_model(x)
+        # self.total_loss = self.get_loss(x, x_hat)
+        # self.train_op = self.optimizer(self.total_loss)
+        # self.outputs["gen_images"] = x_hat
+        # self.summary_op = self.summary(total_loss = self.total_loss)
+        # global_variables = [var for var in tf.global_variables() if var not in original_global_variables]
+        # self.saveable_variables = [self.global_step] + global_variables
+        # self._is_build_graph_set = True
+
 
     def optimizer(self, total_loss):
@@ -141,14 +143,14 @@
 
 
-    @abstractmethod
-    def build_model(self, x)->tf.Tensor:
-        """
-        This function is used to create the network
-        Example: see example in vanilla_convLSTM_model.py, it must return prediction fnsrames and save it to the self.output
-        which is used for calculating the loss
-        """
-        pass
+    # @abstractmethod
+    # def build_model(self, x)->tf.Tensor:
+    #     """
+    #     This function is used to create the network
+    #     Example: see example in vanilla_convLSTM_model.py, it must return prediction frames and save it to the self.output
+    #     which is used for calculating the loss
+    #     """
+    #     pass
diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
index 70c050f53efa81f1c482fa06f3cd1a5318b6e987..4d6988c3c95768a3a7cb89716c1d17a48e6013a9 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py
@@ -4,7 +4,7 @@
 
 __email__ = "b.gong@fz-juelich.de"
 __author__ = "Bing Gong"
-__date__ = "2020-11-05"
+__date__ = "2022-01-05"
 
 from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames
 import tensorflow as tf
@@ -12,6 +12,7 @@
 from model_modules.video_prediction.layers import layer_def as ld
 from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
 from .our_base_model import BaseModels
 from hparams_utils import *
+tf.config.experimental_run_functions_eagerly(True)
 
 
 class VanillaConvLstmVideoPredictionModel(BaseModels):
@@ -36,27 +37,47 @@ class VanillaConvLstmVideoPredictionModel(BaseModels):
         self.batch_size = hparams.batch_size
         self.shuffle_on_val = hparams.shuffle_on_val
         self.loss_fun = hparams.loss_fun
-        self.opt_var = hparams.opt_var
+        self.opt_var = "all"  # was: hparams.opt_var
         self.lr = hparams.lr
         self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames)
         print("The model hparams have been parsed successfully! ")
") except Exception as e: - raise ValueError(f"missing hyperparameter: {e.args[0]}") + raise ValueError(f"missing hyper-parameter: {e.args[0]}") + def compile(self, input_shape): + self.model = ConvLSTM_network(input_shape, self.sequence_length, self.context_frames) + self.optimizer = tf.keras.optimizers.Adam(self.lr) + + @tf.function + def train_step(self, x, step): + print("x.shape", x.shape) + + # Open a GradientTape to record the operations run + # during the forward pass, which enables auto-differentiation. + with tf.GradientTape() as tape: + + # Run the forward pass of the layer. + # The operations that the layer applies + # to its inputs are going to be recorded + # on the GradientTape. + + x_hat = self.model(x) + + # Compute the loss value for this minibatch. + self.loss_value = self.get_loss(x, x_hat) + + # Use the gradient tape to automatically retrieve + # the gradients of the trainable variables with respect to the loss. + grads = tape.gradient(self.loss_value, self.model.network_template.trainable_variables) + + # Run one step of gradient descent by updating + # the value of the variables to minimize the loss. + self.optimizer.apply_gradients(zip(grads, self.model.network_template.trainable_variables)) + print("optimiasec") + #print("after training:", self.optimizer.iterations.numpy()) - def build_graph(self, x:tf.Tensor): - self.inputs = x - original_global_variables = tf.global_variables() - x_hat = self.build_model(x) - self.total_loss = self.get_loss(x, x_hat) - self.train_op = self.optimizer(self.total_loss) - self.outputs["gen_images"] = x_hat - self.summary_op = self.summary(total_loss = self.total_loss) - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - self.saveable_variables = [self.global_step] + global_variables - self._is_build_graph_set = True - return self._is_build_graph_set def get_loss(self, x:tf.Tensor, x_hat:tf.Tensor)->tf.Tensor: @@ -65,18 +86,17 @@ class VanillaConvLstmVideoPredictionModel(BaseModels): :param x_hat: Prediction/output tensors :return : the loss function """ - #This is the loss function (MSE): - #Optimize all target variables/channels + if self.opt_var == "all": x = x[:, self.context_frames:, :, :, :] - print("The model is optimzied on all the variables in the loss function") elif self.opt_var != "all" and isinstance(self.opt_var, str): self.opt_var = int(self.opt_var) - print("The model is optimized on the {} variable in the loss function".format(self.opt_var)) x = x[:, self.context_frames:, :, :, self.opt_var] x_hat = x_hat[:, :, :, :, self.opt_var] else: - raise ValueError("The opt var in the hyperparameters setup should be '0','1','2' indicate the index of target variable to be optimised or 'all' indicating optimize all the variables") + raise ValueError("The opt var in the hyperparameters setup should " + "be '0','1','2' indicate the index of target variable " + "to be optimised or 'all' indicating optimize all the variables") if self.loss_fun == "mse": total_loss = tf.reduce_mean(tf.square(x - x_hat)) @@ -86,72 +106,101 @@ class VanillaConvLstmVideoPredictionModel(BaseModels): bce = tf.keras.losses.BinaryCrossentropy() total_loss = bce(x_flatten, x_hat_predict_frames_flatten) else: - raise ValueError("Loss function is not selected properly, you should chose either 'mse' or 'cross_entropy'") + raise ValueError("Loss function is not selected properly, " + "you should chose either 'mse' or 'cross_entropy'") return total_loss - - def summary(self, total_loss)->None: + def summary(self, 
@@ -65,18 +86,17 @@
         """
         :param x: Input sequences
         :param x_hat: Prediction/output tensors
         :return : the loss function
         """
-        #This is the loss function (MSE):
-        #Optimize all target variables/channels
+
         if self.opt_var == "all":
             x = x[:, self.context_frames:, :, :, :]
-            print("The model is optimzied on all the variables in the loss function")
         elif self.opt_var != "all" and isinstance(self.opt_var, str):
             self.opt_var = int(self.opt_var)
-            print("The model is optimized on the {} variable in the loss function".format(self.opt_var))
             x = x[:, self.context_frames:, :, :, self.opt_var]
             x_hat = x_hat[:, :, :, :, self.opt_var]
         else:
-            raise ValueError("The opt var in the hyperparameters setup should be '0','1','2' indicate the index of target variable to be optimised or 'all' indicating optimize all the variables")
+            raise ValueError("The opt_var entry in the hyperparameter setup should "
+                             "be '0', '1' or '2' (the index of the target variable "
+                             "to be optimised) or 'all' (optimise all variables)")
 
         if self.loss_fun == "mse":
             total_loss = tf.reduce_mean(tf.square(x - x_hat))
@@ -86,72 +106,101 @@
             bce = tf.keras.losses.BinaryCrossentropy()
             total_loss = bce(x_flatten, x_hat_predict_frames_flatten)
         else:
-            raise ValueError("Loss function is not selected properly, you should chose either 'mse' or 'cross_entropy'")
+            raise ValueError("Loss function is not selected properly, "
+                             "you should choose either 'mse' or 'cross_entropy'")
         return total_loss
 
-
-    def summary(self, total_loss)->None:
+    def summary(self, total_loss, step)->None:
         """
         return the summary operation can be used for TensorBoard
         """
-        tf.summary.scalar("total_loss", total_loss)
+        tf.summary.scalar("total_loss", total_loss, step=step)
         summary_op = tf.summary.merge_all()
         return summary_op
 
-    def build_model(self, x: tf.Tensor):
-        network_template = tf.make_template('network',
-                                            VanillaConvLstmVideoPredictionModel.convLSTM_cell)  # make the template to share the variables
-
-        x_hat = VanillaConvLstmVideoPredictionModel.convLSTM_network(x,
-                                                                     self.sequence_length,
-                                                                     self.context_frames,
-                                                                     network_template)
-        return x_hat
-
-
-    @staticmethod
-    def convLSTM_network(x:tf.Tensor, sequence_length:int, context_frames:int, network_template:tf.make_template)->tf.Tensor:
+    # @staticmethod
+    # def convLSTM_network(x:tf.Tensor,
+    #                      sequence_length:int,
+    #                      context_frames:int,
+    #                      network_template:tf.compat.v1.make_template)->tf.Tensor:
+    #
+    #     # create network
+    #     x_hat = []
+    #
+    #     # This is for training (optimization of convLSTM layer)
+    #     hidden_g = None
+    #     for i in range(sequence_length - 1):
+    #         if i < context_frames:
+    #             print("i",i)
+    #             x_1_g, hidden_g = network_template(x[:, i, :, :, :], hidden_g)
+    #         else:
+    #             x_1_g, hidden_g = network_template(x_1_g, hidden_g)
+    #         x_hat.append(x_1_g)
+    #
+    #     # pack them all together
+    #     x_hat = tf.stack(x_hat)
+    #     x_hat = tf.transpose(x_hat, [1, 0, 2, 3, 4])  # change first dim with sec dim
+    #     x_hat = x_hat[:, context_frames - 1:, :, :, :]
+    #     return x_hat
+
+
+
+class ConvLSTM_network(tf.Module):
+
+    def __init__(self, shape:list, sequence_length:int, context_frames:int, name=None):
+        super(ConvLSTM_network, self).__init__(name=name)
+
+        self.shape = shape
+        self.sequence_length = sequence_length
+        self.context_frames = context_frames
+        self.network_template = tf.compat.v1.make_template('network',
+                                                           ConvLSTM_network.convLSTM_cell)
+
+    @tf.Module.with_name_scope
+    def __call__(self, x):
         # create network
         x_hat = []
-        # This is for training (optimization of convLSTM layer)
         hidden_g = None
-        for i in range(sequence_length - 1):
-            if i < context_frames:
-                x_1_g, hidden_g = network_template(x[:, i, :, :, :], hidden_g)
+        for i in range(self.sequence_length - 1):
+            if i < self.context_frames:
+                x_1_g, hidden_g = self.network_template(x[:, i, :, :, :], hidden_g)
             else:
-                x_1_g, hidden_g = network_template(x_1_g, hidden_g)
+                x_1_g, hidden_g = self.network_template(x_1_g, hidden_g)
             x_hat.append(x_1_g)
 
         # pack them all together
         x_hat = tf.stack(x_hat)
         x_hat = tf.transpose(x_hat, [1, 0, 2, 3, 4])  # change first dim with sec dim
-        x_hat = x_hat[:, context_frames - 1:, :, :, :]
+        x_hat = x_hat[:, self.context_frames - 1:, :, :, :]
         return x_hat
 
-    @staticmethod
-
-    def convLSTM_cell(inputs:tf.Tensor, hidden:tf.Tensor):
+    def convLSTM_cell(inputs: tf.Tensor, hidden: tf.Tensor):
         """
-        SPDX-FileCopyrightText: loliverhennigh
+         SPDX-FileCopyrightText: loliverhennigh
         SPDX-License-Identifier: Apache-2.0
-        The following function was revised based on the github https://github.com/loliverhennigh/Convolutional-LSTM-in-Tensorflow
+         The following function was revised based on the github https://github.com/loliverhennigh/Convolutional-LSTM-in-Tensorflow
         """
         y_0 = inputs  #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset
         # conv lstm cell
         cell_shape = y_0.get_shape().as_list()
         channels = cell_shape[-1]
-        with tf.variable_scope('conv_lstm', initializer = tf.random_uniform_initializer(-.01, 0.1)):
-            cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size=5, num_features=64)
-            if hidden is None:
-                hidden = cell.zero_state(y_0, tf.float32)
-            output, hidden = cell(y_0, hidden)
+
+        cell = BasicConvLSTMCell(shape=[cell_shape[1],
+                                        cell_shape[2]],
+                                 filter_size=5,
+                                 num_features=64)
+        if hidden is None:
+            hidden = cell.zero_state(y_0)
+        output, hidden = cell(y_0, hidden)
 
         output_shape = output.get_shape().as_list()
         z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]])
-        #we feed the learn representation into a 1 × 1 convolutional layer to generate the final prediction
+
         x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid")
         return x_hat, hidden
+
diff --git a/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py b/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py
index b2a1534e43012d56677259eba57bb60049bad6aa..6b1aeb5860e699c324793ffe7d91021c6d61fd06 100644
--- a/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py
+++ b/video_prediction_tools/model_modules/video_prediction/utils/tf_utils.py
@@ -1,15 +1,6 @@
-import itertools
-import os
-from collections import OrderedDict
-import numpy as np
-import six
+import os
 import tensorflow as tf
-import tensorflow.contrib.graph_editor as ge
-from tensorflow.core.framework import node_def_pb2
-from tensorflow.python.framework import device as pydev
-from tensorflow.python.training import device_setter
-from tensorflow.python.util import nest
diff --git a/video_prediction_tools/postprocess/plot_ambs_forecast.py b/video_prediction_tools/postprocess/plot_ambs_forecast.py
index 1ba2a7b2f4f3ffe0c5c591e4f0f8aecbcb614a35..5f0ce2fbe9645e2f8583ad0aaa9e783ac586d8be 100644
--- a/video_prediction_tools/postprocess/plot_ambs_forecast.py
+++ b/video_prediction_tools/postprocess/plot_ambs_forecast.py
@@ -21,7 +21,6 @@
 import matplotlib
 matplotlib.use('Agg')
 import matplotlib.pyplot as plt
-import cartopy
 from mpl_toolkits.basemap import Basemap
 
@@ -102,11 +101,8 @@ def create_plot(data: xr.DataArray, data_ref: xr.DataArray, varname: str, fcst_h
         cbar_ax = fig.add_axes([0.95, pos.y0+0.08*(t*2-1), 0.02, pos.y1 - pos.y0])
         cbar = fig.colorbar(cs, cax=cbar_ax, orientation="vertical", ticks=cbar_ticks)
         cbar.set_label(cbar_labs[t])
-    # save to disk
-    #plt.show()
-    #plt.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.95, hspace=-0.5, wspace=0.05)
+
     plt.subplots_adjust(hspace=-0.5)
-    #plt.tight_layout()
     plt.savefig(plt_fname, bbox_inches="tight")
     plt.close()
 
@@ -128,7 +124,7 @@ def main(args):
     fhh = args.forecast_hour
 
     if not os.path.isfile(filename):
-        raise FileNotFoundError("Could not find the indictaed netCDF-file '{0}'".format(filename))
+        raise FileNotFoundError("Could not find the indicated netCDF-file '{0}'".format(filename))
 
     with xr.open_dataset(filename) as dfile:
         t2m_fcst, t2m_ref = dfile["2t_fcst"], dfile["2t_ref"]