diff --git a/test/run_pytest.sh b/test/run_pytest.sh index b440bebf7026eb777f0bd3aa82b7894f5c0e540c..83220d34a51379e93add931ae6e03e9491b5bce4 100644 --- a/test/run_pytest.sh +++ b/test/run_pytest.sh @@ -2,7 +2,7 @@ # Name of virtual environment #VIRT_ENV_NAME="vp_new_structure" -VIRT_ENV_NAME="env_hdfml" +VIRT_ENV_NAME="juwels_env" if [ -z ${VIRTUAL_ENV} ]; then if [[ -f ../video_prediction_tools/${VIRT_ENV_NAME}/bin/activate ]]; then @@ -21,8 +21,10 @@ fi #python -m pytest test_prepare_era5_data.py ##Test for preprocess_step1 #python -m pytest test_process_netCDF_v2.py - source ../video_prediction_tools/env_setup/modules_train.sh +##Test for preprocess moving mnist +#python -m pytest test_prepare_moving_mnist_data.py +python -m pytest test_train_moving_mnist_data.py #Test for process step2 #python -m pytest test_data_preprocess_step2.py #python -m pytest test_era5_data.py @@ -31,5 +33,5 @@ source ../video_prediction_tools/env_setup/modules_train.sh #rm /p/project/deepacf/deeprain/video_prediction_shared_folder/models/test/* #python -m pytest test_train_model_era5.py #python -m pytest test_vanilla_vae_model.py -python -m pytest test_visualize_postprocess.py +#python -m pytest test_visualize_postprocess.py #python -m pytest test_meta_postprocess.py diff --git a/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh b/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh index 24d189278868d67853c74192793f9c13c3c7fb94..ed35fd8b68d2d8593c4e9ff411fd0c142b360204 100644 --- a/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh +++ b/video_prediction_tools/HPC_scripts/visualize_postprocess_era5_template.sh @@ -4,11 +4,11 @@ #SBATCH --ntasks=1 ##SBATCH --ntasks-per-node=1 #SBATCH --cpus-per-task=1 -#SBATCH --output=generate_era5-out.%j -#SBATCH --error=generate_era5-err.%j -#SBATCH --time=00:20:00 +#SBATCH --output=postprocess_era5-out.%j +#SBATCH --error=postprocess_era5-err.%j +#SBATCH --time=01:00:00 #SBATCH --gres=gpu:1 -#SBATCH --partition=develgpus +#SBATCH --partition=gpus #SBATCH --mail-type=ALL #SBATCH --mail-user=b.gong@fz-juelich.de ##jutil env activate -p cjjsc42 @@ -47,4 +47,4 @@ model=convLSTM srun python -u ../main_scripts/main_visualize_postprocess.py --checkpoint ${checkpoint_dir} --mode test \ --results_dir ${results_dir} --batch_size 4 \ --num_stochastic_samples 1 \ - > generate_era5-out.out + > postprocess_era5-out_all.${SLURM_JOB_ID} diff --git a/video_prediction_tools/data_preprocess/dataset_options.py b/video_prediction_tools/data_preprocess/dataset_options.py index 28dffb6c8879bd934c6a8f7169ee0a6bcf679999..5e9729d693720e0e1380170a436980fdbeb900e7 100644 --- a/video_prediction_tools/data_preprocess/dataset_options.py +++ b/video_prediction_tools/data_preprocess/dataset_options.py @@ -16,4 +16,4 @@ def known_datasets(): # "era5_anomaly":"ERA5Dataset_v2_anomaly", } - return dataset_mappings \ No newline at end of file + return dataset_mappings diff --git a/video_prediction_tools/data_preprocess/prepare_moving_mnist_data.py b/video_prediction_tools/data_preprocess/prepare_moving_mnist_data.py new file mode 100644 index 0000000000000000000000000000000000000000..444a6e0bdb0c11b25d19984236a71a3cefb9a2fa --- /dev/null +++ b/video_prediction_tools/data_preprocess/prepare_moving_mnist_data.py @@ -0,0 +1,129 @@ +""" +Class and functions required for preprocessing Moving mnist data from .npz to TFRecords +""" +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong, Karim Mache" +__date__ = "2021_05_04" + + +import os 
+import numpy as np
+import tensorflow as tf
+import argparse
+from model_modules.video_prediction.datasets.moving_mnist import MovingMnist
+
+
+class MovingMnist2Tfrecords(MovingMnist):
+
+    def __init__(self, input_dir=None, dest_dir=None, sequences_per_file=128):
+        """
+        This class is used for converting the .npz file to TFRecords
+
+        :param input_dir: str, the directory containing the Moving MNIST .npy file
+        :param dest_dir: the output directory in which the TFRecords are saved
+        :param sequences_per_file: int, how many sequences/samples are saved per TFRecord file
+        """
+        self.input_dir = input_dir
+        self.output_dir = dest_dir
+        os.makedirs(self.output_dir, exist_ok = True)
+        self.sequences_per_file = sequences_per_file
+        self.write_sequence_file()
+
+
+    def __call__(self):
+        """
+        Steps to process the .npy file to TFRecords
+        :return: None
+        """
+        self.read_npz_file()
+        self.save_npz_to_tfrecords()
+
+    def read_npz_file(self):
+        self.data = np.load(os.path.join(self.input_dir, "mnist_test_seq.npy"))
+        print("data in mnist_test_seq shape", self.data.shape)
+        return None
+
+    def save_npz_to_tfrecords(self):  # Bing: original 128
+        """
+        Read the Moving MNIST data, which is in npz format, and save it to TFRecord files.
+        The shape of the data is [seq_length, number_samples, height, width];
+        Moving MNIST only has one channel.
+        """
+        idx = 0
+        num_samples = self.data.shape[1]
+        if len(self.data.shape) == 4:
+            # add one dim to represent the channel, then get [seq_length, num_samples, height, width, channel]
+            self.data = np.expand_dims(self.data, axis = 4)
+        elif len(self.data.shape) == 5:
+            pass
+        else:
+            raise ValueError(f"The input Moving MNIST npz file has {len(self.data.shape)} dimensions, but either 4 or 5 are expected; please further check your data source!")
+
+        self.data = self.data.astype(np.float32)
+        self.data /= 255.0  # normalize the grayscale values by dividing by the maximum value (255)
+        while idx < num_samples - self.sequences_per_file:
+            sequences = self.data[:, idx:idx+self.sequences_per_file, :, :, :]
+            output_fname = 'sequence_index_{}_to_{}.tfrecords'.format(idx, idx + self.sequences_per_file-1)
+            output_fname = os.path.join(self.output_dir, output_fname)
+            MovingMnist2Tfrecords.save_tf_record(output_fname, sequences)
+            idx = idx + self.sequences_per_file
+        return None
+
+    @staticmethod
+    def save_tf_record(output_fname, sequences):
+        with tf.python_io.TFRecordWriter(output_fname) as writer:
+            for i in range(np.array(sequences).shape[1] - 1):
+                sequence = sequences[:, i, :, :, :]
+                num_frames = len(sequence)
+                height, width = sequence[0, :, :, 0].shape
+                encoded_sequence = np.array([list(image) for image in sequence])
+                features = tf.train.Features(feature = {
+                    'sequence_length': _int64_feature(num_frames),
+                    'height': _int64_feature(height),
+                    'width': _int64_feature(width),
+                    'channels': _int64_feature(1),
+                    'images/encoded': _floats_feature(encoded_sequence.flatten()),
+                })
+                example = tf.train.Example(features = features)
+                writer.write(example.SerializeToString())
+
+    def write_sequence_file(self):
+        """
+        Generate a txt file with the number of sequences stored in each TFRecord file.
+        This is mainly used for calculating the number of samples per epoch during training.
+        """
+
+        with open(os.path.join(self.output_dir, 'number_sequences.txt'), 'w') as seq_file:
+            seq_file.write("%d\n" % self.sequences_per_file)
+
+
+
+
+def _bytes_feature(value):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
+
+
+def _bytes_list_feature(values):
+    return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
+
+def _floats_feature(value):
+    return tf.train.Feature(float_list=tf.train.FloatList(value=value))
+
+def _int64_feature(value):
+    return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
+
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("-input_dir", type=str, help="The input directory that contains the Moving MNIST npy file", default="/p/largedata/datasets/moving-mnist")
+    parser.add_argument("-output_dir", type=str)
+    parser.add_argument("-sequences_per_file", type=int, default=2)
+    args = parser.parse_args()
+    inst = MovingMnist2Tfrecords(args.input_dir, args.output_dir, args.sequences_per_file)
+    inst()
+
+
+if __name__ == '__main__':
+    main()
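For reference, a minimal sketch of how the converter added above can be driven from Python and how one of the resulting files can be inspected; it is not part of the patch, the paths are placeholders, and the import path simply follows the repository layout:

from data_preprocess.prepare_moving_mnist_data import MovingMnist2Tfrecords
import tensorflow as tf

# convert the raw .npy file (directories below are assumptions)
converter = MovingMnist2Tfrecords(input_dir="/path/to/moving-mnist",   # directory holding mnist_test_seq.npy
                                  dest_dir="/path/to/tfrecords",
                                  sequences_per_file=128)
converter()                                                            # writes sequence_index_*_to_*.tfrecords

# inspect the first record of the first output file with the same TF1.x API used above
record = next(tf.python_io.tf_record_iterator("/path/to/tfrecords/sequence_index_0_to_127.tfrecords"))
example = tf.train.Example.FromString(record)
feat = example.features.feature
print(feat["sequence_length"].int64_list.value[0],   # e.g. 20 frames per sequence
      feat["height"].int64_list.value[0],            # 64
      feat["width"].int64_list.value[0],             # 64
      feat["channels"].int64_list.value[0])          # 1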
diff --git a/video_prediction_tools/data_split/moving_mnist/datasplit.json b/video_prediction_tools/data_split/moving_mnist/datasplit.json
index 217b285d8e105debbe7841735eb50786762ace19..0c199e18b6685404b1e137a139985f1b511bc4c4 100644
--- a/video_prediction_tools/data_split/moving_mnist/datasplit.json
+++ b/video_prediction_tools/data_split/moving_mnist/datasplit.json
@@ -1,11 +1,10 @@
 {
     "train":{
-        "index1":[0,100],
-        "index2":[150,200]
+        "index1":[0,99]
     },
     "val":
     {
-        "index1":[110,149]
+        "index1":[100,149]
     },
     "test":
     {
diff --git a/video_prediction_tools/data_split/moving_mnist/datasplit_template.json b/video_prediction_tools/data_split/moving_mnist/datasplit_template.json
index 11407a0439e7bd3d1397d6dfce9cce660786a866..890b7e4599d429a0ee91fd2ebf79ecf345168dda 100644
--- a/video_prediction_tools/data_split/moving_mnist/datasplit_template.json
+++ b/video_prediction_tools/data_split/moving_mnist/datasplit_template.json
@@ -7,12 +7,11 @@
 # Be aware that this is a prue data file, i.e. do not make use of any Python-functions such as np.range or similar here!
 {
     "train":{
-        "index1":[0,100],
-        "index2":[150,200]
+        "index1":[0,100]
     },
     "val":
     {
-        "index1":[110,149]
+        "index1":[100,149]
     },
     "test":
     {
diff --git a/video_prediction_tools/env_setup/create_env.sh b/video_prediction_tools/env_setup/create_env.sh
index 1a5ce05448b3fa2b0b44ce002165e0ebe4b1c94f..31f2065e9c7475762f654bcb878fcd5de7421ba4 100755
--- a/video_prediction_tools/env_setup/create_env.sh
+++ b/video_prediction_tools/env_setup/create_env.sh
@@ -63,27 +63,27 @@
 else
   ENV_EXIST=0
 fi
 
-# add personal email-address to Batch-scripts
-if [[ "${HOST_NAME}" == hdfml* || "${HOST_NAME}" == *juwels* ]]; then
-  if [[ "${HOST_NAME}" == jwlogin2[1-4]* ]]; then
-    # on Juwels Booster, we are in a container environment -> loading modules is not possible
-    echo "***** Note for Juwels Booster! *****"
-    echo "Already checked the required modules?"
-    echo "To do so, run 'source modules_train.sh' after exiting the singularity."
-    echo "***** Note for Juwels Booster! *****"
-  else
-    # load modules and check for their availability
-    echo "***** Checking modules required during the workflow...
*****" - source ${ENV_SETUP_DIR}/modules_preprocess.sh purge - source ${ENV_SETUP_DIR}/modules_train.sh purge - source ${ENV_SETUP_DIR}/modules_postprocess.sh - fi -else - # unset PYTHONPATH on every other machine that is not a known HPC-system - unset PYTHONPATH -fi - +# Create fresh virtual environment or just activate the existing one if [[ "$ENV_EXIST" == 0 ]]; then + # Check modules first + if [[ "${HOST_NAME}" == hdfml* || "${HOST_NAME}" == *juwels* ]]; then + if [[ "${HOST_NAME}" == jwlogin2[1-4]* ]]; then + # on Juwels Booster, we are in a container environment -> loading modules is not possible + echo "***** Note for Juwels Booster! *****" + echo "Already checked the required modules?" + echo "To do so, run 'source modules_train.sh' after exiting the singularity." + echo "***** Note for Juwels Booster! *****" + else + # load modules and check for their availability + echo "***** Checking modules required during the workflow... *****" + source ${ENV_SETUP_DIR}/modules_preprocess.sh purge + source ${ENV_SETUP_DIR}/modules_train.sh purge + source ${ENV_SETUP_DIR}/modules_postprocess.sh + fi + else + # unset PYTHONPATH on every other machine that is not a known HPC-system + unset PYTHONPATH + fi # Activate virtual environment and install additional Python packages. echo "Configuring and activating virtual environment on ${HOST_NAME}" @@ -143,7 +143,8 @@ if [[ "$ENV_EXIST" == 0 ]]; then fi info_str="Virtual environment ${ENV_DIR} has been set up successfully." elif [[ "$ENV_EXIST" == 1 ]]; then - # activating virtual env is suifficient + # loading modules of postprocessing and activating virtual env are suifficient + source ${ENV_SETUP_DIR}/modules_postprocess.sh source ${ENV_DIR}/bin/activate info_str="Virtual environment ${ENV_DIR} has been activated successfully." fi @@ -154,10 +155,10 @@ source "${WORKING_DIR}"/utils/runscript_generator/setup_runscript_templates.sh echo "******************************************** NOTE ********************************************" echo "${info_str}" -echo "Make use of config_runscript.py to generate customized runscripts of the workflow steps." +echo "Make use of generate_runscript.py to generate customized runscripts of the workflow steps." 
echo "******************************************** NOTE ********************************************" # finally clean up loaded modules (if we are not on Juwels) -if [[ "${HOST_NAME}" == *hdfml* || "${HOST_NAME}" == *juwels* ]] && [[ "${HOST_NAME}" != jwlogin2[1-4]* ]]; then - module --force purge -fi +#if [[ "${HOST_NAME}" == *hdfml* || "${HOST_NAME}" == *juwels* ]] && [[ "${HOST_NAME}" != jwlogin2[1-4]* ]]; then +# module --force purge +#fi diff --git a/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json b/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json index 17783b5f68c1974ecb12c43948c14fcba77acd8e..878c29a0553ddb74a563299a7d3ec5683469194c 100644 --- a/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json @@ -4,10 +4,8 @@ "lr": 0.001, "max_epochs":20, "context_frames":10, - "sequence_length":20, "loss_fun":"rmse", "shuffle_on_val":false - } diff --git a/video_prediction_tools/hparams/era5/convLSTM_gan/model_hparams_template.json b/video_prediction_tools/hparams/era5/convLSTM_gan/model_hparams_template.json new file mode 100644 index 0000000000000000000000000000000000000000..a2b9be547d450d49ef230e37821f503838a5dcee --- /dev/null +++ b/video_prediction_tools/hparams/era5/convLSTM_gan/model_hparams_template.json @@ -0,0 +1,14 @@ + +{ + "batch_size": 4, + "lr": 0.001, + "max_epochs":20, + "context_frames":12, + "loss_fun":"rmse", + "shuffle_on_val":false, + "recon_weight":0.6 + +} + + + diff --git a/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json b/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json index c2edaad9f9ac158f6e7b8d94bb81db16d55d05e8..0b3788d726fc91e3d5c1aec98166259a2e0012e9 100644 --- a/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/mcnet/model_hparams_template.json @@ -3,9 +3,7 @@ "batch_size": 10, "lr": 0.001, "max_epochs":2, - "context_frames":10, - "sequence_length":20 - + "context_frames":10 } diff --git a/video_prediction_tools/hparams/era5/ours_gan/model_hparams_template.json b/video_prediction_tools/hparams/era5/ours_gan/model_hparams_template.json index c19ecf6d3b565268efb5f74e14943f32a19519b4..0ccf44e6370f765857204317f172c866865b4b35 100644 --- a/video_prediction_tools/hparams/era5/ours_gan/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/ours_gan/model_hparams_template.json @@ -13,6 +13,5 @@ "state_weight": 0.0, "nz": 32, "max_epochs":2, - "context_frames":12, - "sequence_length":24 + "context_frames":12 } diff --git a/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json b/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json index 0acefc42a13583d13ac1c263ea44593a6d5e17d0..8e96727e95f761efc170365abd4e8af89696c168 100644 --- a/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/ours_vae_l1/model_hparams_template.json @@ -11,7 +11,5 @@ "state_weight": 0.0, "nz": 32, "max_epochs":2, - "context_frames":10, - "sequence_length":20 - + "context_frames":10 } diff --git a/video_prediction_tools/hparams/era5/savp/model_hparams_template.json b/video_prediction_tools/hparams/era5/savp/model_hparams_template.json index d7058c6b2534d46cd6e08672d33a76bfcc4c7a35..d182658a2161a0405d9fb92d9677fc34bd39251f 100644 --- a/video_prediction_tools/hparams/era5/savp/model_hparams_template.json +++ 
b/video_prediction_tools/hparams/era5/savp/model_hparams_template.json @@ -13,8 +13,7 @@ "state_weight": 0.0, "nz": 16, "max_epochs":2, - "context_frames":10, - "sequence_length":20 + "context_frames":10 } diff --git a/video_prediction_tools/hparams/era5/vae/model_hparams_template.json b/video_prediction_tools/hparams/era5/vae/model_hparams_template.json index 1afb4d4391421b34111c252f4775d448be7675d0..2dcecd346b9b4adc4f3179020d0ee83b8512c6a0 100644 --- a/video_prediction_tools/hparams/era5/vae/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/vae/model_hparams_template.json @@ -5,7 +5,6 @@ "nz":16, "max_epochs":2, "context_frames":10, - "sequence_length":20, "weight_recon":1, "loss_fun": "rmse", "shuffle_on_val": false diff --git a/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams.json b/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams.json index b59f6cb2ee96162b2eb6014d7ca6bd37f54d4218..6cda5552d437b9b283a40bdde77eb1d3b3497b36 100644 --- a/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams.json +++ b/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams.json @@ -4,7 +4,6 @@ "lr": 0.001, "max_epochs":20, "context_frames":10, - "sequence_length":20, "loss_fun":"cross_entropy" } diff --git a/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams_template.json b/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams_template.json new file mode 100644 index 0000000000000000000000000000000000000000..6cda5552d437b9b283a40bdde77eb1d3b3497b36 --- /dev/null +++ b/video_prediction_tools/hparams/moving_mnist/convLSTM/model_hparams_template.json @@ -0,0 +1,11 @@ + +{ + "batch_size": 10, + "lr": 0.001, + "max_epochs":20, + "context_frames":10, + "loss_fun":"cross_entropy" +} + + + diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 053d7d7060529b87a290501e86e1a16d5e9cc4a6..f6fd6e78db9c49f433d679c03c97463a654192d7 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -136,23 +136,22 @@ class TrainModel(object): def setup_dataset(self): """ - Setup train and val dataset instance with the corresponding data split configuration + Setup train and val dataset instance with the corresponding data split configuration. + Simultaneously, sequence_length is attached to the hyperparameter dictionary. 
""" VideoDataset = datasets.get_dataset_class(self.dataset) self.train_dataset = VideoDataset(input_dir=self.input_dir,mode='train',datasplit_config=self.datasplit_dict) self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val',datasplit_config=self.datasplit_dict) - #self.variable_scope = tf.get_variable_scope() - #self.variable_scope.set_use_resource(True) - + + self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) def setup_model(self): """ Set up model instance for the given model names """ VideoPredictionModel = models.get_model_class(self.model) - self.video_model = VideoPredictionModel( - hparams_dict=self.model_hparams_dict_load, - ) + self.video_model = VideoPredictionModel(hparams_dict=self.model_hparams_dict_load) + def setup_graph(self): """ build model graph @@ -178,8 +177,8 @@ class TrainModel(object): self.inputs = self.iterator.get_next() #since era5 tfrecords include T_start, we need to remove it from the tfrecord when we train the model, # otherwise the model will raise error - if self.dataset == "era5" and self.model == "savp": - del self.inputs["T_start"] + #if self.dataset == "era5" and self.model == "savp": + # del self.inputs["T_start"] @@ -232,6 +231,7 @@ class TrainModel(object): self.num_examples = self.train_dataset.num_examples_per_epoch() self.steps_per_epoch = int(self.num_examples/batch_size) self.total_steps = self.steps_per_epoch * max_epochs + print("Batch size is {} ; max_epochs is {}; num_samples per epoch is {}; steps_per_epoch is {}, total steps is {}".format(batch_size,max_epochs, self.num_examples,self.steps_per_epoch,self.total_steps)) def restore(self,sess, checkpoints, restore_to_checkpoint_mapping=None): """ @@ -294,11 +294,15 @@ class TrainModel(object): self.create_fetches_for_train() # In addition to the loss, we fetch the optimizer self.results = sess.run(self.fetches) # ...and run it here! 
train_losses.append(self.results["total_loss"]) + print("t_start for training",self.results["inputs"]["T_start"]) + print("len of t_start per iteration",len(self.results["inputs"]["T_start"])) #Run and fetch losses for validation data val_handle_eval = sess.run(self.val_handle) self.create_fetches_for_val() self.val_results = sess.run(self.val_fetches,feed_dict={self.train_handle: val_handle_eval}) val_losses.append(self.val_results["total_loss"]) + print("t_start for validation",self.val_results["inputs"]["T_start"]) + print("len of t_start per iteration",len(self.val_results["inputs"]["T_start"])) self.write_to_summary() self.print_results(step,self.results) timeit_end = time.time() @@ -335,6 +339,8 @@ class TrainModel(object): if self.video_model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel": self.fetches_for_train_convLSTM() if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": self.fetches_for_train_savp() if self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": self.fetches_for_train_vae() + if self.video_model.__class__.__name__ == "VanillaGANVideoPredictionModel":self.fetches_for_train_gan() + if self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel":self.fetches_for_train_convLSTM() return self.fetches def fetches_for_train_convLSTM(self): @@ -342,8 +348,7 @@ class TrainModel(object): Fetch variables in the graph for convLSTM model, this can be custermized based on models and the needs of users """ self.fetches["total_loss"] = self.video_model.total_loss - - + self.fetches["inputs"] = self.video_model.inputs def fetches_for_train_savp(self): @@ -355,7 +360,7 @@ class TrainModel(object): self.fetches["d_loss"] = self.video_model.d_loss self.fetches["g_loss"] = self.video_model.g_loss self.fetches["total_loss"] = self.video_model.g_loss - + self.fetches["inputs"] = self.video_model.inputs def fetches_for_train_mcnet(self): @@ -374,15 +379,19 @@ class TrainModel(object): self.fetches["recon_loss"] = self.video_model.recon_loss self.fetches["total_loss"] = self.video_model.total_loss + def fetches_for_train_gan(self): + self.fetches["total_loss"] = self.video_model.total_loss + def create_fetches_for_val(self): """ Fetch variables in the graph for validation dataset, this can be custermized based on models and the needs of users """ if self.video_model.__class__.__name__ == "SAVPVideoPredictionModel": self.val_fetches = {"total_loss": self.video_model.g_loss} + self.val_fetches["inputs"] = self.video_model.inputs else: self.val_fetches = {"total_loss": self.video_model.total_loss} - + self.val_fetches["inputs"] = self.video_model.inputs self.val_fetches["summary"] = self.video_model.summary_op def write_to_summary(self): diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py index 196cfee6508ec298dbb390941d787cb6957022c9..60dfef0032a7746d57734285a7b86328e0c74f50 100644 --- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py +++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py @@ -422,19 +422,20 @@ class Postprocess(TrainModel): # feed and run the trained model; returned array has the shape [batchsize, seq_len, lat, lon, channel] feed_dict = {input_ph: input_results[name] for name, input_ph in self.inputs.items()} gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict) + # sanity check on length of forecast sequence assert gen_images.shape[1] == 
self.sequence_length - 1, \ "%{0}: Sequence length of prediction must be smaller by one than total sequence length.".format(method) # denormalize forecast sequence (self.norm_cls is already set in get_input_data_per_batch-method) gen_images_denorm = self.denorm_images_all_channels(gen_images, self.vars_in, self.norm_cls, norm_method="minmax") - # store data into datset + # store data into datset and get number of samples (may differ from batch_size at the end of the test dataset) times_0, init_times = self.get_init_time(t_starts) batch_ds = self.create_dataset(input_images_denorm, gen_images_denorm, init_times) - # auxilary list of forecast dimensions - dims_fcst = list(batch_ds["{0}_ref".format(self.vars_in[0])].dims) + nbs = np.minimum(self.batch_size, self.num_samples_per_epoch - sample_ind) + batch_ds = batch_ds.isel(init_time=slice(0, nbs)) - for i in np.arange(self.batch_size): + for i in np.arange(nbs): # work-around to make use of get_persistence_forecast_per_sample-method times_seq = (pd.date_range(times_0[i], periods=int(self.sequence_length), freq="h")).to_pydatetime() # get persistence forecast for sequences at hand and write to dataset @@ -541,9 +542,10 @@ class Postprocess(TrainModel): .format(method, ", ".join(misses))) varname_ref = "{0}_ref".format(varname) - # reset init-time coordinate of metric_ds in place + # reset init-time coordinate of metric_ds in place and get indices for slicing + ind_end = np.minimum(ind_start + self.batch_size, self.num_samples_per_epoch) init_times_metric = metric_ds["init_time"].values - init_times_metric[ind_start:ind_start+self.batch_size] = data_ds["init_time"] + init_times_metric[ind_start:ind_end] = data_ds["init_time"] metric_ds = metric_ds.assign_coords(init_time=init_times_metric) # populate metric_ds for fcst_prod in self.fcst_products.keys(): diff --git a/video_prediction_tools/model_modules/model_architectures.py b/video_prediction_tools/model_modules/model_architectures.py index ca602a954a107c9217942e7f01e4eae4c68d58bb..5836ab9fce48692252a4dbc44415b4a4f9e2c2c3 100644 --- a/video_prediction_tools/model_modules/model_architectures.py +++ b/video_prediction_tools/model_modules/model_architectures.py @@ -14,6 +14,8 @@ def known_models(): 'vae': 'VanillaVAEVideoPredictionModel', 'convLSTM': 'VanillaConvLstmVideoPredictionModel', 'mcnet': 'McNetVideoPredictionModel', + 'gan': "VanillaGANVideoPredictionModel", + 'convLSTM_gan': "ConvLstmGANVideoPredictionModel", 'ours_vae_l1': 'SAVPVideoPredictionModel', 'ours_gan': 'SAVPVideoPredictionModel', } diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py index 7a70351e7808103e9a3e02e65654f151213c45ec..cd0ec2b230169016cc10aee5ee2ff3d7e4fc611b 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py @@ -18,13 +18,12 @@ def get_dataset_class(dataset): if dataset_class is None: raise ValueError('Invalid dataset %s' % dataset) else: - # ERA5Dataset does not inherit anything from VarLenFeatureVideoDataset-class, so it is the only dataset which - # does not need to be a subclass of BaseVideoDataset - if not dataset_class == "ERA5Dataset": - dataset_class = globals().get(dataset_class) - if not issubclass(dataset_class,BaseVideoDataset): - raise ValueError('Dataset {0} is not a valid dataset'.format(dataset_class)) - else: - dataset_class = globals().get(dataset_class) + # 
ERA5Dataset movning_mnist does not inherit anything from VarLenFeatureVideoDataset-class, so it is the only dataset which does not need to be a subclass of BaseVideoDataset + #if not dataset_class == "ERA5Dataset" or not dataset_class == "MovingMnist": + # dataset_class = globals().get(dataset_class) + # if not issubclass(dataset_class,BaseVideoDataset): + # raise ValueError('Dataset {0} is not a valid dataset'.format(dataset_class)) + #else: + dataset_class = globals().get(dataset_class) return dataset_class diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py index 9bd0362541b858b7ceb8650057d72f17e0188e8c..ce62965a2c92432ffbf739e933279f91b69e355c 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py @@ -29,6 +29,7 @@ class ERA5Dataset(object): self.datasplit_config = datasplit_config self.mode = mode self.seed = seed + self.sequence_length = None # will be set in get_example_info if self.mode not in ('train', 'val', 'test'): raise ValueError('Invalid mode %s' % self.mode) if not os.path.exists(self.input_dir): @@ -61,14 +62,12 @@ class ERA5Dataset(object): Returns: A dict with the following hyperparameters. context_frames : the number of ground-truth frames to pass in at start. - sequence_length : the number of frames in the video sequence max_epochs : the number of epochs to train model lr : learning rate loss_fun : the loss function """ hparams = dict( context_frames=10, - sequence_length=20, max_epochs = 20, batch_size = 40, lr = 0.001, @@ -116,26 +115,28 @@ class ERA5Dataset(object): def get_example_info(self): """ - Get the data information from tfrecord file + Get the data information from an example tfrecord file """ example = next(tf.python_io.tf_record_iterator(self.filenames[0])) dict_message = MessageToDict(tf.train.Example.FromString(example)) feature = dict_message['features']['feature'] print("features in dataset:",feature.keys()) - self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels']) - self.image_shape = self.video_shape[1:] + video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length', 'height', + 'width', 'channels']) + self.sequence_length = video_shape[0] + self.image_shape = video_shape[1:] def num_examples_per_epoch(self): """ Calculate how many tfrecords samples in the train/val/test """ - #count how many tfrecords files for train/val/testing + # count how many tfrecords files for train/val/testing len_fnames = len(self.filenames) - seq_len_file = os.path.join(self.input_dir, 'number_sequences.txt') - with open(seq_len_file, 'r') as sequence_lengths_file: - sequence_lengths = sequence_lengths_file.readlines() - sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths] - self.num_examples_per_epoch = len_fnames * sequence_lengths[0] + num_seq_file = os.path.join(self.input_dir, 'number_sequences.txt') + with open(num_seq_file, 'r') as dfile: + num_seqs = dfile.readlines() + num_sequences = [int(num_seq.strip()) for num_seq in num_seqs] + self.num_examples_per_epoch = len_fnames * num_sequences[0] return self.num_examples_per_epoch @@ -163,9 +164,10 @@ class ERA5Dataset(object): parsed_features = tf.parse_single_example(serialized_example, keys_to_features) seq = 
tf.sparse_tensor_to_dense(parsed_features["images/encoded"]) T_start = tf.sparse_tensor_to_dense(parsed_features["t_start"]) - images = [] - print("Image shape {}, {},{},{}".format(self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2])) - images = tf.reshape(seq, [self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]], name = "reshape_new") + print("Image shape {}, {},{},{}".format(self.sequence_length, self.image_shape[0], self.image_shape[1], + self.image_shape[2])) + images = tf.reshape(seq, [self.sequence_length,self.image_shape[0],self.image_shape[1], + self.image_shape[2]], name="reshape_new") seqs["images"] = images seqs["T_start"] = T_start return seqs @@ -179,7 +181,9 @@ class ERA5Dataset(object): dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs)) else: dataset = dataset.repeat(self.num_epochs) - if self.mode == "val": dataset = dataset.repeat(20) + + if self.mode == "val": dataset = dataset.repeat(20) + num_parallel_calls = None if shuffle else 1 dataset = dataset.apply(tf.contrib.data.map_and_batch( parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py index b1136172b203218b5bbc3b052e0d454c2c5bd60f..40df33aaaf82d4764d8b0d55c8a1bed55131e963 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py @@ -8,41 +8,117 @@ import re import tensorflow as tf import numpy as np import skimage.io -from model_modules.video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset +from collections import OrderedDict +from tensorflow.contrib.training import HParams +from google.protobuf.json_format import MessageToDict + + +class KTHVideoDataset(object): + def __init__(self,input_dir=None,datasplit_config=None,hparams_dict_config=None, mode='train',seed=None): + """ + This class is used for preparing data for training/validation and test models + args: + input_dir : the path of tfrecords files + datasplit_config : the path pointing to the datasplit_config json file + hparams_dict_config : the path to the dict that contains hparameters, + mode : string, "train","val" or "test" + seed : int, the seed for dataset + """ + self.input_dir = input_dir + self.datasplit_config = datasplit_config + self.mode = mode + self.seed = seed + if self.mode not in ('train', 'val', 'test'): + raise ValueError('Invalid mode %s' % self.mode) + if not os.path.exists(self.input_dir): + raise FileNotFoundError("input_dir %s does not exist" % self.input_dir) + self.datasplit_dict_path = datasplit_config + self.data_dict = self.get_datasplit() + self.hparams_dict_config = hparams_dict_config + self.hparams_dict = self.get_model_hparams_dict() + self.hparams = self.parse_hparams() + self.get_tfrecords_filesnames_base_datasplit() + self.get_example_info() + + + + def get_default_hparams(self): + return HParams(**self.get_default_hparams_dict()) -class KTHVideoDataset(VarLenFeatureVideoDataset): - def __init__(self, *args, **kwargs): - super(KTHVideoDataset, self).__init__(*args, **kwargs) - from google.protobuf.json_format import MessageToDict - example = next(tf.python_io.tf_record_iterator(self.filenames[0])) - dict_message = MessageToDict(tf.train.Example.FromString(example)) - feature = 
dict_message['features']['feature'] - image_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['height', 'width', 'channels']) - - self.state_like_names_and_shapes['images'] = 'images/encoded', image_shape - def get_default_hparams_dict(self): - default_hparams = super(KTHVideoDataset, self).get_default_hparams_dict() + """ + The function that contains default hparams + Returns: + A dict with the following hyperparameters. + context_frames : the number of ground-truth frames to pass in at start. + sequence_length : the number of frames in the video sequence + max_epochs : the number of epochs to train model + lr : learning rate + loss_fun : the loss function + """ hparams = dict( context_frames=10, sequence_length=20, - long_sequence_length=40, - force_time_shift=True, - shuffle_on_val=True, - use_state=False, + max_epochs = 20, + batch_size = 40, + lr = 0.001, + loss_fun = "rmse", + shuffle_on_val= True, ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) - - @property - def jpeg_encoding(self): - return False + return hparams + + + + + def get_datasplit(self): + """ + Get the datasplit json file + """ + + with open(self.datasplit_dict_path) as f: + self.d = json.load(f) + return self.d + + def parse_hparams(self): + """ + Parse the hparams setting to ovoerride the default ones + """ + parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) + return parsed_hparams + + + def get_tfrecords_filesnames_base_datasplit(self): + """ + Get absolute .tfrecord path names based on the data splits patterns + """ + self.filenames = [] + self.data_mode = self.data_dict[self.mode] + self.tf_names = [] + for year, months in self.data_mode.items(): + for month in months: + tf_files = "sequence_Y_{}_M_{}_*_to_*.tfrecord*".format(year,month) + self.tf_names.append(tf_files) + # look for tfrecords in input_dir and input_dir/mode directories + for files in self.tf_names: + self.filenames.extend(glob.glob(os.path.join(self.input_dir, files))) + if self.filenames: + self.filenames = sorted(self.filenames) # ensures order is the same across systems + if not self.filenames: + raise FileNotFoundError('No tfrecords were found in %s' % self.input_dir) def num_examples_per_epoch(self): - with open(os.path.join(self.input_dir, 'number_sequences.txt'), 'r') as sequence_lengths_file: - sequence_lengths = sequence_lengths_file.readlines() + """ + Calculate how many tfrecords samples in the train/val/test + """ + #count how many tfrecords files for train/val/testing + len_fnames = len(self.filenames) + seq_len_file = os.path.join(self.input_dir, 'number_sequences.txt') + with open(seq_len_file, 'r') as sequence_lengths_file: + sequence_lengths = sequence_lengths_file.readlines() sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths] - return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length) + self.num_examples_per_epoch = len_fnames * sequence_lengths[0] + return self.num_examples_per_epoch def _bytes_feature(value): @@ -62,17 +138,12 @@ def partition_data(input_dir): fnames = glob.glob(os.path.join(input_dir, '*/*')) fnames = [fname for fname in fnames if os.path.isdir(fname)] print("frames",fnames[0]) - persons = [re.match('person(\d+)_\w+_\w+', os.path.split(fname)[1]).group(1) for fname in fnames] persons = np.array([int(person) for person in persons]) - train_mask = persons <= 16 - train_fnames = [fnames[i] for i in np.where(train_mask)[0]] test_fnames = [fnames[i] for i in 
np.where(~train_mask)[0]] - random.shuffle(train_fnames) - pivot = int(0.95 * len(train_fnames)) train_fnames, val_fnames = train_fnames[:pivot], train_fnames[pivot:] return train_fnames, val_fnames, test_fnames @@ -96,41 +167,42 @@ def save_tf_record(output_fname, sequences): writer.write(example.SerializeToString()) -def read_frames_and_save_tf_records(output_dir, video_dirs, image_size, sequences_per_file=128): - partition_name = os.path.split(output_dir)[1] #Get the folder name train, val or test - sequences = [] - sequence_iter = 0 - sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w') - for video_iter, video_dir in enumerate(video_dirs): #Interate group (e.g. walking) each person - meta_partition_name = partition_name if partition_name == 'test' else 'train' - meta_fname = os.path.join(os.path.split(video_dir)[0], '%s_meta%dx%d.pkl' % - (meta_partition_name, image_size, image_size)) - with open(meta_fname, "rb") as f: - data = pickle.load(f) # The data has 62 items, each item is a dict, with three keys. "vid","n", and "files", Each file has 4 channels, each channel has n sequence images with 64*64 png - - vid = os.path.split(video_dir)[1] - (d,) = [d for d in data if d['vid'] == vid] - for frame_fnames_iter, frame_fnames in enumerate(d['files']): - frame_fnames = [os.path.join(video_dir, frame_fname) for frame_fname in frame_fnames] - frames = skimage.io.imread_collection(frame_fnames) - # they are grayscale images, so just keep one of the channels - frames = [frame[..., 0:1] for frame in frames] - - if not sequences: #The length of the sequence in sequences could be different - last_start_sequence_iter = sequence_iter - print("reading sequences starting at sequence %d" % sequence_iter) - - sequences.append(frames) - sequence_iter += 1 - sequence_lengths_file.write("%d\n" % len(frames)) - - if (len(sequences) == sequences_per_file or - (video_iter == (len(video_dirs) - 1) and frame_fnames_iter == (len(d['files']) - 1))): - output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1) - output_fname = os.path.join(output_dir, output_fname) - save_tf_record(output_fname, sequences) - sequences[:] = [] - sequence_lengths_file.close() + + def read_frames_and_save_tf_records(output_dir, video_dirs, image_size, sequences_per_file=128): + partition_name = os.path.split(output_dir)[1] #Get the folder name train, val or test + sequences = [] + sequence_iter = 0 + sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w') + for video_iter, video_dir in enumerate(video_dirs): #Interate group (e.g. walking) each person + meta_partition_name = partition_name if partition_name == 'test' else 'train' + meta_fname = os.path.join(os.path.split(video_dir)[0], '%s_meta%dx%d.pkl' % + (meta_partition_name, image_size, image_size)) + with open(meta_fname, "rb") as f: + data = pickle.load(f) # The data has 62 items, each item is a dict, with three keys. 
"vid","n", and "files", Each file has 4 channels, each channel has n sequence images with 64*64 png + + vid = os.path.split(video_dir)[1] + (d,) = [d for d in data if d['vid'] == vid] + for frame_fnames_iter, frame_fnames in enumerate(d['files']): + frame_fnames = [os.path.join(video_dir, frame_fname) for frame_fname in frame_fnames] + frames = skimage.io.imread_collection(frame_fnames) + # they are grayscale images, so just keep one of the channels + frames = [frame[..., 0:1] for frame in frames] + + if not sequences: #The length of the sequence in sequences could be different + last_start_sequence_iter = sequence_iter + print("reading sequences starting at sequence %d" % sequence_iter) + + sequences.append(frames) + sequence_iter += 1 + sequence_lengths_file.write("%d\n" % len(frames)) + + if (len(sequences) == sequences_per_file or + (video_iter == (len(video_dirs) - 1) and frame_fnames_iter == (len(d['files']) - 1))): + output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1) + output_fname = os.path.join(output_dir, output_fname) + save_tf_record(output_fname, sequences) + sequences[:] = [] + sequence_lengths_file.close() def main(): @@ -141,12 +213,10 @@ def main(): parser.add_argument("output_dir", type=str) parser.add_argument("image_size", type=int) args = parser.parse_args() - partition_names = ['train', 'val', 'test'] print("input dir", args.input_dir) partition_fnames = partition_data(args.input_dir) print("partiotion_fnames[0]", partition_fnames[0]) - for partition_name, partition_fnames in zip(partition_names, partition_fnames): partition_dir = os.path.join(args.output_dir, partition_name) if not os.path.exists(partition_dir): diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py b/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py index 5ef54d379dc796a786a52c1fc535432f079a4b43..adfdf06539174a78be375fa1f2416dade078a1c5 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py @@ -1,118 +1,209 @@ -import argparse -import sys + +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong, Karim" +__date__ = "2021-05-03" + + + import glob -import itertools import os -import pickle import random -import re -import numpy as np import json import tensorflow as tf from tensorflow.contrib.training import HParams -from mpi4py import MPI from collections import OrderedDict -import matplotlib.pyplot as plt -import matplotlib.gridspec as gridspec -from model_modules.video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset -import data_preprocess.process_netCDF_v2 -from general_utils import get_unique_vars -from statistics import Calc_data_stat -from metadata import MetaData - -class MovingMnist(VarLenFeatureVideoDataset): - def __init__(self, *args, **kwargs): - super(MovingMnist, self).__init__(*args, **kwargs) - from google.protobuf.json_format import MessageToDict - example = next(tf.python_io.tf_record_iterator(self.filenames[0])) - dict_message = MessageToDict(tf.train.Example.FromString(example)) - feature = dict_message['features']['feature'] - print("features in dataset:",feature.keys()) - self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels']) - self.image_shape = self.video_shape[1:] - self.state_like_names_and_shapes['images'] = 'images/encoded', self.image_shape 
+from google.protobuf.json_format import MessageToDict + + +class MovingMnist(object): + def __init__(self, input_dir=None, datasplit_config=None, hparams_dict_config=None, mode="train",seed=None): + """ + This class is used for preparing the data for moving mnist, and split the data to train/val/testing + :params input_dir: the path of tfrecords files + :params datasplit_config: the path pointing to the datasplit_config json file + :params hparams_dict_config: the path to the dict that contains hparameters + :params mode: string, "train","val" or "test" + :params seed:int, the seed for dataset + :return None + """ + self.input_dir = input_dir + self.mode = mode + self.seed = seed + self.sequence_length = None # will be set in get_example_info + if self.mode not in ('train', 'val', 'test'): + raise ValueError('Invalid mode %s' % self.mode) + if not os.path.exists(self.input_dir): + raise FileNotFoundError("input_dir %s does not exist" % self.input_dir) + self.datasplit_dict_path = datasplit_config + self.data_dict = self.get_datasplit() + self.hparams_dict_config = hparams_dict_config + self.hparams_dict = self.get_model_hparams_dict() + self.hparams = self.parse_hparams() + self.get_tfrecords_filename_base_datasplit() + self.get_example_info() + + def get_datasplit(self): + """ + Get the datasplit json file + """ + with open(self.datasplit_dict_path) as f: + self.d = json.load(f) + return self.d + + def get_model_hparams_dict(self): + """ + Get model_hparams_dict from json file + """ + self.model_hparams_dict_load = {} + if self.hparams_dict_config: + with open(self.hparams_dict_config) as f: + self.model_hparams_dict_load.update(json.loads(f.read())) + return self.model_hparams_dict_load + + + def parse_hparams(self): + """ + Parse the hparams setting to ovoerride the default ones + """ + parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) + return parsed_hparams + + def get_default_hparams(self): + return HParams(**self.get_default_hparams_dict()) + def get_default_hparams_dict(self): - default_hparams = super(MovingMnist, self).get_default_hparams_dict() + + """ + The function that contains default hparams + Returns: + A dict with the following hyperparameters. + context_frames : the number of ground-truth frames to pass in at start. 
+ sequence_length : the number of frames in the video sequence + max_epochs : the number of epochs to train model + lr : learning rate + loss_fun : the loss function + :return: + """ hparams = dict( - context_frames=10,#Bing: Todo oriignal is 10 - sequence_length=20,#bing: TODO original is 20, - shuffle_on_val=True, + context_frames=10, + sequence_length=20, + max_epochs = 20, + batch_size = 40, + lr = 0.001, + loss_fun = "rmse", + shuffle_on_val= True, ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) - - - @property - def jpeg_encoding(self): - return False - + return hparams + + + def get_tfrecords_filename_base_datasplit(self): + """ + Get obsoluate .tfrecords names based on the data splits patterns + """ + self.filenames = [] + self.data_mode = self.data_dict[self.mode] + self.all_filenames = glob.glob(os.path.join(self.input_dir,"*.tfrecords")) + print("self.all_files",self.all_filenames) + for indice_group, index in self.data_mode.items(): + fs = [MovingMnist.string_filter(max_value=index[1], min_value=index[0], string=s) for s in self.all_filenames] + print("fs:",fs) + self.tf_names = [self.all_filenames[fs_index] for fs_index in range(len(fs)) if fs[fs_index]==True] + print("tf_names,",self.tf_names) + # look for tfrecords in input_dir and input_dir/mode directories + for files in self.tf_names: + self.filenames.extend(glob.glob(os.path.join(self.input_dir, files))) + if self.filenames: + self.filenames = sorted(self.filenames) # ensures order is the same across systems + if not self.filenames: + raise FileNotFoundError('No tfrecords were found in %s' % self.input_dir) + + + @staticmethod + def string_filter(max_value=None, min_value=None, string="input_directory/sequence_index_0_index_10.tfrecords"): + a = os.path.split(string)[-1].split("_") + if not len(a) == 5: + raise ("The tfrecords pattern does not match the expected pattern, for instanct: 'sequence_index_0_to_10.tfrecords'") + min_index = int(a[2]) + max_index = int(a[4].split(".")[0]) + if min_index >= min_value and max_index <= max_value: + return True + else: + return False + def get_example_info(self): + """ + Get the data information from tfrecord file + """ + example = next(tf.python_io.tf_record_iterator(self.filenames[0])) + dict_message = MessageToDict(tf.train.Example.FromString(example)) + feature = dict_message['features']['feature'] + print("features in dataset:",feature.keys()) + video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', + 'width', 'channels']) + self.sequence_length = video_shape[0] + self.image_shape = video_shape[1:] def num_examples_per_epoch(self): - with open(os.path.join(self.input_dir, 'number_squences.txt'), 'r') as sequence_lengths_file: - sequence_lengths = sequence_lengths_file.readlines() - sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths] - return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length) - - def filter(self, serialized_example): - return tf.convert_to_tensor(True) - - - def make_dataset_v2(self, batch_size): + """ + Calculate how many tfrecords samples in the train/val/test + """ + # count how many tfrecords files for train/val/testing + len_fnames = len(self.filenames) + num_seq_file = os.path.join(self.input_dir, 'number_sequences.txt') + with open(num_seq_file, 'r') as dfile: + num_seqs = dfile.readlines() + num_sequences = [int(num_seq.strip()) for num_seq in num_seqs] + self.num_examples_per_epoch = len_fnames * num_sequences[0] + + return 
self.num_examples_per_epoch + + + def make_dataset(self, batch_size): + """ + Prepare batch_size dataset fed into to the models. + If the data are from training dataset,then the data is shuffled; + If the data are from val dataset, the shuffle var will be decided by the hparams.shuffled_on_val; + if the data are from test dataset, the data will not be shuffled + args: + batch_size: int, the size of samples fed into the models per iteration + """ + self.num_epochs = self.hparams.max_epochs def parser(serialized_example): seqs = OrderedDict() keys_to_features = { - 'width': tf.FixedLenFeature([], tf.int64), - 'height': tf.FixedLenFeature([], tf.int64), - 'sequence_length': tf.FixedLenFeature([], tf.int64), - 'channels': tf.FixedLenFeature([], tf.int64), - 'images/encoded': tf.VarLenFeature(tf.float32) - } - - # for i in range(20): - # keys_to_features["frames/{:04d}".format(i)] = tf.FixedLenFeature((), tf.string) + 'width': tf.FixedLenFeature([], tf.int64), + 'height': tf.FixedLenFeature([], tf.int64), + 'sequence_length': tf.FixedLenFeature([], tf.int64), + 'channels': tf.FixedLenFeature([],tf.int64), + 'images/encoded': tf.VarLenFeature(tf.float32) + } parsed_features = tf.parse_single_example(serialized_example, keys_to_features) - print ("Parse features", parsed_features) seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"]) - #width = tf.sparse_tensor_to_dense(parsed_features["width"]) - # height = tf.sparse_tensor_to_dense(parsed_features["height"]) - # channels = tf.sparse_tensor_to_dense(parsed_features["channels"]) - # sequence_length = tf.sparse_tensor_to_dense(parsed_features["sequence_length"]) - images = [] - print("Image shape {}, {},{},{}".format(self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2])) - images = tf.reshape(seq, [self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]], name = "reshape_new") + print("Image shape {}, {},{},{}".format(self.sequence_length,self.image_shape[0],self.image_shape[1], + self.image_shape[2])) + images = tf.reshape(seq, [self.sequence_length,self.image_shape[0],self.image_shape[1], + self.image_shape[2]], name = "reshape_new") seqs["images"] = images return seqs filenames = self.filenames - print ("FILENAMES",filenames) - #TODO: - #temporal_filenames = self.temporal_filenames shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val) if shuffle: random.shuffle(filenames) - dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8* 1024 * 1024) # todo: what is buffer_size - print("files", self.filenames) - print("mode", self.mode) - dataset = dataset.filter(self.filter) + dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8* 1024 * 1024) if shuffle: - dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count = self.num_epochs)) + dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size =1024, count=self.num_epochs)) else: dataset = dataset.repeat(self.num_epochs) - + if self.mode == "val": dataset = dataset.repeat(20) num_parallel_calls = None if shuffle else 1 dataset = dataset.apply(tf.contrib.data.map_and_batch( parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) - #dataset = dataset.map(parser) - # num_parallel_calls = None if shuffle else 1 # for reproducibility (e.g. 
sampled subclips from the test set) - # dataset = dataset.apply(tf.contrib.data.map_and_batch( - # _parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) # Bing: Parallel data mapping, num_parallel_calls normally depends on the hardware, however, normally should be equal to be the usalbe number of CPUs - dataset = dataset.prefetch(batch_size) # Bing: Take the data to buffer inorder to save the waiting time for GPU + dataset = dataset.prefetch(batch_size) return dataset - - def make_batch(self, batch_size): - dataset = self.make_dataset_v2(batch_size) + dataset = self.make_dataset(batch_size) iterator = dataset.make_one_shot_iterator() return iterator.get_next() @@ -129,108 +220,8 @@ def _floats_feature(value): def _int64_feature(value): return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) -def save_tf_record(output_fname, sequences): - with tf.python_io.TFRecordWriter(output_fname) as writer: - for i in range(len(sequences)): - sequence = sequences[:,i,:,:,:] - num_frames = len(sequence) - height, width = sequence[0,:,:,0].shape - encoded_sequence = np.array([list(image) for image in sequence]) - features = tf.train.Features(feature={ - 'sequence_length': _int64_feature(num_frames), - 'height': _int64_feature(height), - 'width': _int64_feature(width), - 'channels': _int64_feature(1), - 'images/encoded': _floats_feature(encoded_sequence.flatten()), - }) - example = tf.train.Example(features=features) - writer.write(example.SerializeToString()) - -def read_frames_and_save_tf_records(output_dir,dat_npz, seq_length=20, sequences_per_file=128, height=64, width=64):#Bing: original 128 - """ - Read the moving_mnst data which is npz format, and save it to tfrecords files - The shape of dat_npz is [seq_length,number_samples,height,width] - moving_mnst only has one channel - - """ - os.makedirs(output_dir,exist_ok=True) - idx = 0 - num_samples = dat_npz.shape[1] - dat_npz = np.expand_dims(dat_npz, axis=4) #add one dim to represent channel, then got [seq_length,num_samples,height,width,channel] - print("data_npz_shape",dat_npz.shape) - dat_npz = dat_npz.astype(np.float32) - dat_npz /= 255.0 #normalize RGB codes by dividing it to the max RGB value - while idx < num_samples - sequences_per_file: - sequences = dat_npz[:,idx:idx+sequences_per_file,:,:,:] - output_fname = 'sequence_{}_{}.tfrecords'.format(idx,idx+sequences_per_file) - output_fname = os.path.join(output_dir, output_fname) - save_tf_record(output_fname, sequences) - idx = idx + sequences_per_file - return None - - -def write_sequence_file(output_dir,seq_length,sequences_per_file): - partition_names = ["train","val","test"] - for partition_name in partition_names: - save_output_dir = os.path.join(output_dir,partition_name) - tfCounter = len(glob.glob1(save_output_dir,"*.tfrecords")) - print("Partition_name: {}, number of tfrecords: {}".format(partition_name,tfCounter)) - sequence_lengths_file = open(os.path.join(save_output_dir, 'sequence_lengths.txt'), 'w') - for i in range(tfCounter*sequences_per_file): - sequence_lengths_file.write("%d\n" % seq_length) - sequence_lengths_file.close() - - -def plot_seq_imgs(imgs,output_png_dir,idx,label="Ground Truth"): - """ - Plot the seq images - """ - - if len(np.array(imgs).shape)!=3:raise("img dims should be three: (seq_len,lat,lon)") - img_len = imgs.shape[0] - fig = plt.figure(figsize=(18,6)) - gs = gridspec.GridSpec(1, 10) - gs.update(wspace = 0., hspace = 0.) 
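As shown in the rewritten MovingMnist class above, the dataset now resolves its own TFRecord files and sequence length; a short usage sketch (not part of the patch, paths are placeholders) illustrates how a batch would be drawn:

from model_modules.video_prediction.datasets.moving_mnist import MovingMnist

dataset = MovingMnist(input_dir="/path/to/moving_mnist/tfrecords",   # placeholder
                      datasplit_config="datasplit.json",             # placeholder
                      hparams_dict_config="model_hparams.json",      # placeholder
                      mode="train")
print(dataset.sequence_length, dataset.image_shape)                  # e.g. 20 and (64, 64, 1)
batch = dataset.make_batch(batch_size=4)                             # dict of tensors; batch["images"] has
                                                                     # shape [4, sequence_length, 64, 64, 1]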
- for i in range(img_len): - ax1 = plt.subplot(gs[i]) - plt.imshow(imgs[i] ,cmap = 'jet') - plt.setp([ax1], xticks = [], xticklabels = [], yticks = [], yticklabels = []) - plt.savefig(os.path.join(output_png_dir, label + "_" + str(idx) + ".jpg")) - print("images_saved") - plt.clf() + - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("input_dir", type=str, help="directory containing the processed directories ""boxing, handclapping, handwaving, ""jogging, running, walking") - parser.add_argument("output_dir", type=str) - parser.add_argument("-sequences_per_file",type=int,default=2) - args = parser.parse_args() - current_path = os.getcwd() - data = np.load(os.path.join(args.input_dir,"mnist_test_seq.npy")) - print("data in minist_test_Seq shape",data.shape) - seq_length = data.shape[0] - height = data.shape[2] - width = data.shape[3] - num_samples = data.shape[1] - max_npz = np.max(data) - min_npz = np.min(data) - print("max_npz,",max_npz) - print("min_npz",min_npz) - #Todo need to discuss how to split the data, since we have totally 10000 samples, the origin paper convLSTM used 10000 as training, 2000 as validation and 3000 for testing - dat_train = data[:,:6000,:,:] - dat_val = data[:,6000:7000,:,:] - dat_test = data[:,7000:,:] - #plot_seq_imgs(dat_test[10:,0,:,:],output_png_dir="/p/project/deepacf/deeprain/video_prediction_shared_folder/results/moving_mnist/convLSTM",idx=1,label="Ground Truth from npz") - #save train - #read_frames_and_save_tf_records(os.path.join(args.output_dir,"train"),dat_train, seq_length=20, sequences_per_file=40, height=height, width=width) - #save val - #read_frames_and_save_tf_records(os.path.join(args.output_dir,"val"),dat_val, seq_length=20, sequences_per_file=40, height=height, width=width) - #save test - #read_frames_and_save_tf_records(os.path.join(args.output_dir,"test"),dat_test, seq_length=20, sequences_per_file=40, height=height, width=width) - #write_sequence_file(output_dir=args.output_dir,seq_length=20,sequences_per_file=40) -if __name__ == '__main__': - main() diff --git a/video_prediction_tools/model_modules/video_prediction/models/__init__.py b/video_prediction_tools/model_modules/video_prediction/models/__init__.py index 960f608deed07e715190cdecb38efeb2eb4c5ace..2053aeed83a3606804af959e1c422d5cb39723a7 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/models/__init__.py @@ -12,6 +12,10 @@ from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel from .mcnet_model import McNetVideoPredictionModel from .test_model import TestModelVideoPredictionModel from model_modules.model_architectures import known_models +from .vanilla_GAN_model import VanillaGANVideoPredictionModel +from .convLSTM_GAN_model import ConvLstmGANVideoPredictionModel + + def get_model_class(model): model_mappings = known_models() diff --git a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py new file mode 100644 index 0000000000000000000000000000000000000000..092cf81a0db1adeb3c3bac65cb3bc92e54e3e4c5 --- /dev/null +++ b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py @@ -0,0 +1,341 @@ +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong,Yanji" +__date__ = "2021-04-13" + +from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames +import tensorflow as tf +from 
model_modules.video_prediction.layers import layer_def as ld
+from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
+from tensorflow.contrib.training import HParams
+from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel
+
+class batch_norm(object):
+    def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"):
+        with tf.variable_scope(name):
+            self.epsilon = epsilon
+            self.momentum = momentum
+            self.name = name
+
+    def __call__(self, x, train=True):
+        return tf.contrib.layers.batch_norm(x,
+                                            decay=self.momentum,
+                                            updates_collections=None,
+                                            epsilon=self.epsilon,
+                                            scale=True,
+                                            is_training=train,
+                                            scope=self.name)
+
+class ConvLstmGANVideoPredictionModel(object):
+    def __init__(self, mode='train', hparams_dict=None):
+        """
+        This class builds the convLSTM-GAN architecture from the given hyperparameters.
+        args:
+            mode: str, "train" or "val"; mode is not used by the plain convLSTM, but it is a useful argument for the GAN-based model
+            hparams_dict: dict, dictionary containing the hyperparameter names and values
+        """
+        self.mode = mode
+        self.hparams_dict = hparams_dict
+        self.hparams = self.parse_hparams()
+        self.learning_rate = self.hparams.lr
+        self.total_loss = None
+        self.context_frames = self.hparams.context_frames
+        self.sequence_length = self.hparams.sequence_length
+        self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames)
+        self.max_epochs = self.hparams.max_epochs
+        self.loss_fun = self.hparams.loss_fun
+        self.batch_size = self.hparams.batch_size
+        self.recon_weight = self.hparams.recon_weight
+        self.bd1 = batch_norm(name = "dis1")
+        self.bd2 = batch_norm(name = "dis2")
+        self.bd3 = batch_norm(name = "dis3")
+
+    def get_default_hparams(self):
+        return HParams(**self.get_default_hparams_dict())
+
+    def parse_hparams(self):
+        """
+        Parse the hparams settings to override the default ones
+        """
+        parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {})
+        return parsed_hparams
+
+    def get_default_hparams_dict(self):
+        """
+        The function that contains the default hparams
+        Returns:
+            A dict with the following hyperparameters.
+            context_frames : the number of ground-truth frames to pass in at start.
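# A small sketch (assumed, hypothetical values) of how parse_hparams() above merges
# user-supplied settings over these defaults via HParams.override_from_dict:
from tensorflow.contrib.training import HParams

defaults = HParams(context_frames=12, sequence_length=24, lr=0.001, recon_weight=0.99)
merged = defaults.override_from_dict({"lr": 0.0005, "recon_weight": 0.9})
print(merged.lr, merged.recon_weight, merged.sequence_length)   # 0.0005 0.9 24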
+ sequence_length : the number of frames in the video sequence + max_epochs : the number of epochs to train model + lr : learning rate + loss_fun : the loss function + recon_wegiht : the weight for reconstrution loss + """ + hparams = dict( + context_frames=12, + sequence_length=24, + max_epochs = 20, + batch_size = 40, + lr = 0.001, + loss_fun = "cross_entropy", + shuffle_on_val= True, + recon_weight=0.99, + + ) + return hparams + + + def build_graph(self, x): + self.is_build_graph = False + self.inputs = x + self.x = x["images"] + self.width = self.x.shape.as_list()[3] + self.height = self.x.shape.as_list()[2] + self.channels = self.x.shape.as_list()[4] + self.global_step = tf.train.get_or_create_global_step() + original_global_variables = tf.global_variables() + # Architecture + self.define_gan() + #This is the loss function (RMSE): + #This is loss function only for 1 channel (temperature RMSE) + #generator los + self.total_loss = (1-self.recon_weight) * self.G_loss + self.recon_weight*self.recon_loss + self.D_loss = (1-self.recon_weight) * self.D_loss + if self.mode == "train": + if self.recon_weight == 1: + print("Only train generator- convLSTM") + self.train_op = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.total_loss, var_list=self.gen_vars) + else: + print("Training distriminator") + self.D_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.D_loss, var_list=self.disc_vars) + with tf.control_dependencies([self.D_solver]): + print("Training generator....") + self.G_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.total_loss, var_list=self.gen_vars) + with tf.control_dependencies([self.G_solver]): + self.train_op = tf.assign_add(self.global_step,1) + else: + self.train_op = None + + self.outputs = {} + self.outputs["gen_images"] = self.gen_images + self.outputs["total_loss"] = self.total_loss + # Summary op + tf.summary.scalar("total_loss", self.total_loss) + tf.summary.scalar("D_loss", self.D_loss) + tf.summary.scalar("G_loss", self.G_loss) + tf.summary.scalar("D_loss_fake", self.D_loss_fake) + tf.summary.scalar("D_loss_real", self.D_loss_real) + tf.summary.scalar("recon_loss",self.recon_loss) + self.summary_op = tf.summary.merge_all() + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + self.is_build_graph = True + return self.is_build_graph + + def get_noise(self): + """ + Function for creating noise: Given the dimensions (n_batch,n_seq, n_height, n_width, channel) + """ + self.noise = tf.random.uniform(minval=-1., maxval=1., shape=[self.batch_size, self.sequence_length, self.height, self.width, self.channels]) + return self.noise + + @staticmethod + def lrelu(x, leak=0.2, name="lrelu"): + return tf.maximum(x, leak*x) + + @staticmethod + def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): + shape = input_.get_shape().as_list() + + with tf.variable_scope(scope or "Linear"): + matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, + tf.random_normal_initializer(stddev=stddev)) + bias = tf.get_variable("bias", [output_size], + initializer=tf.constant_initializer(bias_start)) + if with_w: + return tf.matmul(input_, matrix) + bias, matrix, bias + else: + return tf.matmul(input_, matrix) + bias + + @staticmethod + def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="conv2d"): + with tf.variable_scope(name): 
+ w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], + initializer=tf.truncated_normal_initializer(stddev=stddev)) + conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME') + + biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) + conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) + + return conv + + @staticmethod + def bn(x, scope): + return tf.contrib.layers.batch_norm(x, + decay=0.9, + updates_collections=None, + epsilon=1e-5, + scale=True, + scope=scope) + + def generator(self): + """ + Function to build up the generator architecture + args: + input images: a input tensor with dimension (n_batch,sequence_length,height,width,channel) + """ + with tf.variable_scope("generator",reuse=tf.AUTO_REUSE): + layer_gen = self.convLSTM_network(self.x) + layer_gen_pred = layer_gen[:,self.context_frames-1:,:,:,:] + return layer_gen + + + def discriminator(self,vid): + """ + Function that get discriminator architecture + """ + with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): + conv1 = tf.layers.conv3d(vid,64,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis1") + conv1 = ConvLstmGANVideoPredictionModel.lrelu(conv1) + conv2 = tf.layers.conv3d(conv1,128,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis2") + conv2 = ConvLstmGANVideoPredictionModel.lrelu(self.bd1(conv2)) + conv3 = tf.layers.conv3d(conv2,256,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis3") + conv3 = ConvLstmGANVideoPredictionModel.lrelu(self.bd2(conv3)) + conv4 = tf.layers.conv3d(conv3,512,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis4") + conv4 = ConvLstmGANVideoPredictionModel.lrelu(self.bd3(conv4)) + conv5 = tf.layers.conv3d(conv4,1,kernel_size=[2,4,4],strides=[1,1,1],padding="SAME",name="dis5") + conv5 = tf.reshape(conv5, [-1,1]) + conv5sigmoid = tf.nn.sigmoid(conv5) + return conv5sigmoid,conv5 + + def discriminator0(self,image): + """ + Function that get discriminator architecture + """ + with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): + layer_disc = self.convLSTM_network(image) + layer_disc = layer_disc[:,self.context_frames-1:self.context_frames,:,:, 0:1] + return layer_disc + + def discriminator1(self,sequence): + """ + https://github.com/hwalsuklee/tensorflow-generative-model-collections/blob/master/GAN.py + Function that give the possibility of a sequence of frames is ture of false + the input squence shape is like [batch_size,time_seq_length,height,width,channel] (e.g., self.x[:,:self.context_frames,:,:,:]) + """ + with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): + print(sequence.shape) + x = sequence[:,:,:,:,0:1] # extract targeted variable + x = tf.transpose(x, [0,2,3,1,4]) # sequence shape is like: [batch_size,height,width,time_seq_length] + x = tf.reshape(x,[x.shape[0],x.shape[1],x.shape[2],x.shape[3]]) + print(x.shape) + net = ConvLstmGANVideoPredictionModel.lrelu(ConvLstmGANVideoPredictionModel.conv2d(x, 64, 4, 4, 2, 2, name='d_conv1')) + net = ConvLstmGANVideoPredictionModel.lrelu(ConvLstmGANVideoPredictionModel.bn(ConvLstmGANVideoPredictionModel.conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'),scope='d_bn2')) + net = tf.reshape(net, [self.batch_size, -1]) + net = ConvLstmGANVideoPredictionModel.lrelu(ConvLstmGANVideoPredictionModel.bn(ConvLstmGANVideoPredictionModel.linear(net, 1024, scope='d_fc3'),scope='d_bn3')) + out_logit = ConvLstmGANVideoPredictionModel.linear(net, 1, scope='d_fc4') + out = tf.nn.sigmoid(out_logit) + print(out.shape) + 
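# A small helper (a sketch; assumes "SAME" padding and the default hparams, i.e.
# 12 predicted frames of 64x64 pixels) to track how the strided conv3d stack in
# discriminator() above shrinks the [time, height, width] extent before the final layer:
import math

def same_padding_output_shape(dims, strides):
    return [int(math.ceil(d / float(s))) for d, s in zip(dims, strides)]

shape = [12, 64, 64]
for _ in range(4):                  # dis1 ... dis4 all use strides of [2, 2, 2]
    shape = same_padding_output_shape(shape, [2, 2, 2])
print(shape)                        # -> [1, 4, 4]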
return out, out_logit
+
+    def get_disc_loss(self):
+        """
+        Return the loss of the discriminator given its inputs
+        """
+        real_labels = tf.ones_like(self.D_real)
+        gen_labels = tf.zeros_like(self.D_fake)
+        self.D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_real_logits, labels=real_labels))
+        self.D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits, labels=gen_labels))
+        self.D_loss = self.D_loss_real + self.D_loss_fake
+        return self.D_loss
+
+    def get_gen_loss(self):
+        """
+        Param:
+            num_images: the number of images the generator should produce, which is also the length of the real images
+            z_dim : the dimension of the noise vector, a scalar
+        Return the loss of the generator given its inputs
+        """
+        real_labels = tf.ones_like(self.D_fake)
+        self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake_logits, labels=real_labels))
+        return self.G_loss
+
+    def get_vars(self):
+        """
+        Get the trainable variables of the discriminator and the generator
+        """
+        print("trainable_variables", len(tf.trainable_variables()))
+        self.disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
+        self.gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
+        print("self.disc_vars", self.disc_vars)
+        print("self.gen_vars", self.gen_vars)
+
+    def define_gan(self):
+        """
+        Define the GAN architecture
+        """
+        self.noise = self.get_noise()
+        self.gen_images = self.generator()
+        # Note: the input of the discriminator has to be adapted when a different discriminator variant is used
+        self.D_real, self.D_real_logits = self.discriminator(self.x[:, self.context_frames:, :, :, :])
+        self.D_fake, self.D_fake_logits = self.discriminator(self.gen_images[:, self.context_frames-1:, :, :, :])
+        self.get_gen_loss()
+        self.get_disc_loss()
+        self.get_vars()
+        if self.loss_fun == "rmse":
+            self.recon_loss = tf.reduce_mean(tf.square(self.x[:, self.context_frames:, :, :, 0] - self.gen_images[:, self.context_frames-1:, :, :, 0]))
+        elif self.loss_fun == "cross_entropy":
+            x_flatten = tf.reshape(self.x[:, self.context_frames:, :, :, 0], [-1])
+            x_hat_predict_frames_flatten = tf.reshape(self.gen_images[:, self.context_frames-1:, :, :, 0], [-1])
+            bce = tf.keras.losses.BinaryCrossentropy()
+            self.recon_loss = bce(x_flatten, x_hat_predict_frames_flatten)
+        else:
+            raise ValueError("Loss function is not selected properly, you should choose either 'rmse' or 'cross_entropy'")
+
+    @staticmethod
+    def convLSTM_cell(inputs, hidden):
+        y_0 = inputs  # we only use patch 1, but the original paper uses patch 4 for the moving mnist case and patch 2 for the Radar Echo dataset
+        channels = inputs.get_shape()[-1]
+        # conv lstm cell
+        cell_shape = y_0.get_shape().as_list()
+        channels = cell_shape[-1]
+        with tf.variable_scope('conv_lstm', initializer = tf.random_uniform_initializer(-.01, 0.1)):
+            cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size=5, num_features=64)
+            if hidden is None:
+                hidden = cell.zero_state(y_0, tf.float32)
+            output, hidden = cell(y_0, hidden)
+        output_shape = output.get_shape().as_list()
+        z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]])
+        # we feed the learned representation into a 1 × 1 convolutional layer to generate the final prediction
+        x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid")
+        print('x_hat shape is: ', x_hat.shape)
+        return x_hat, hidden
+
+    def convLSTM_network(self, x):
+        network_template = 
tf.make_template('network',VanillaConvLstmVideoPredictionModel.convLSTM_cell) # make the template to share the variables + # create network + x_hat = [] + + #This is for training (optimization of convLSTM layer) + hidden_g = None + for i in range(self.sequence_length-1): + if i < self.context_frames: + x_1_g, hidden_g = network_template(x[:, i, :, :, :], hidden_g) + else: + x_1_g, hidden_g = network_template(x_1_g, hidden_g) + x_hat.append(x_1_g) + + # pack them all together + x_hat = tf.stack(x_hat) + self.x_hat= tf.transpose(x_hat, [1, 0, 2, 3, 4]) # change first dim with sec dim ???? yan: why? + print('self.x_hat shape is: ',self.x_hat.shape) + return self.x_hat + + + diff --git a/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py b/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py index a946bd555a603fd9be14306929e0a8e722a24673..2c755417c27bca22b2a20dc6816add43f160c5a6 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py @@ -3,22 +3,13 @@ __author__ = "Bing Gong" __date__ = "2020-08-22" -import collections -import functools import itertools -from collections import OrderedDict import numpy as np import tensorflow as tf -from tensorflow.python.util import nest -from model_modules.video_prediction import ops, flow_ops + +from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames from model_modules.video_prediction.models import BaseVideoPredictionModel -from model_modules.video_prediction.models import networks from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat -from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell -from model_modules.video_prediction.utils import tf_utils -from datetime import datetime -from pathlib import Path -from model_modules.video_prediction.layers import layer_def as ld from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell from model_modules.video_prediction.layers.mcnet_ops import * from model_modules.video_prediction.utils.mcnet_utils import * @@ -32,7 +23,7 @@ class McNetVideoPredictionModel(BaseVideoPredictionModel): self.lr = self.hparams.lr self.context_frames = self.hparams.context_frames self.sequence_length = self.hparams.sequence_length - self.predict_frames = self.sequence_length - self.context_frames + self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) self.df_dim = self.hparams.df_dim self.gf_dim = self.hparams.gf_dim self.alpha = self.hparams.alpha diff --git a/video_prediction_tools/model_modules/video_prediction/models/model_helpers.py b/video_prediction_tools/model_modules/video_prediction/models/model_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..a81f4631e3edde6efcfc807403ed9b061762e4c3 --- /dev/null +++ b/video_prediction_tools/model_modules/video_prediction/models/model_helpers.py @@ -0,0 +1,28 @@ +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong, Michael Langguth" +__date__ = "2021-20-05" + +""" +Some auxiliary functions that can be used by any video prediction model +""" + + +def set_and_check_pred_frames(seq_length, context_frames): + """ + Checks if sequence length and context_frames are set properly and returns number of frames to be predicted. 
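# A quick behavioural sketch (hypothetical values) of the helper defined here:
from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames

print(set_and_check_pred_frames(24, 12))   # -> 12 frames to be predicted
try:
    set_and_check_pred_frames(12, 24)      # context frames exceed the sequence length
except ValueError as err:
    print(err)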
+ :param seq_length: number of frames/images per sequences + :param context_frames: number of context frames/images + :return: number of predicted frames + """ + + method = set_and_check_pred_frames.__name__ + + # sanity checks + assert isinstance(seq_length, int), "%{0}: Sequence length (seq_length) must be an integer".format(method) + assert isinstance(context_frames, int), "%{0}: Number of context frames must be an integer".format(method) + + if seq_length > context_frames: + return seq_length - context_frames + else: + raise ValueError("%{0}: Sequence length ({1}) must be larger than context frames ({2})." + .format(method, seq_length, context_frames)) \ No newline at end of file diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py new file mode 100644 index 0000000000000000000000000000000000000000..112eaf31e1d9961e402ddbac1a91b37a6b7b9b90 --- /dev/null +++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py @@ -0,0 +1,230 @@ +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong" +__date__ = "2021=01-05" + + + +""" +This code implement take the following as references: +1) https://stackabuse.com/introduction-to-gans-with-python-and-tensorflow/ +2) cousera GAN courses +3) https://github.com/hwalsuklee/tensorflow-generative-model-collections/blob/master/GAN.py +""" + +import tensorflow as tf + +from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames +from model_modules.video_prediction.layers import layer_def as ld +from tensorflow.contrib.training import HParams + +class VanillaGANVideoPredictionModel(object): + def __init__(self, mode='train', hparams_dict=None): + """ + This is class for building vanilla GAN architecture by using updated hparameters + args: + mode :str, "train" or "val", side note: mode may not be used in the convLSTM, but this will be a useful argument for the GAN-based model + hparams_dict: dict, the dictionary contains the hparaemters names and values + """ + self.mode = mode + self.hparams_dict = hparams_dict + self.hparams = self.parse_hparams() + self.learning_rate = self.hparams.lr + self.total_loss = None + self.context_frames = self.hparams.context_frames + self.sequence_length = self.hparams.sequence_length + self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) + self.max_epochs = self.hparams.max_epochs + self.loss_fun = self.hparams.loss_fun + self.batch_size = self.hparams.batch_size + self.z_dim = self.hparams.z_dim # dim of noise-vector + + def get_default_hparams(self): + return HParams(**self.get_default_hparams_dict()) + + def parse_hparams(self): + """ + Parse the hparams setting to ovoerride the default ones + """ + + parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) + return parsed_hparams + + + def get_default_hparams_dict(self): + """ + The function that contains default hparams + Returns: + A dict with the following hyperparameters. + context_frames : the number of ground-truth frames to pass in at start. 
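# Sketch of the per-frame flattening that build_graph() further below applies to the
# input (assumed shapes: batch_size=40, sequence_length=24, 64x64 single-channel frames):
import numpy as np

x = np.zeros((40, 24, 64, 64, 1), dtype=np.float32)
n_samples = x.shape[0] * x.shape[1]      # 960: every frame is treated as an independent image
x_flat = x.reshape(-1, 64, 64, 1)        # mirrors tf.reshape(self.x, [-1, height, width, channels])
print(n_samples, x_flat.shape)           # 960 (960, 64, 64, 1)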
+ sequence_length : the number of frames in the video sequence + max_epochs : the number of epochs to train model + lr : learning rate + loss_fun : the loss function + """ + hparams = dict( + context_frames=12, + sequence_length=24, + max_epochs = 20, + batch_size = 40, + lr = 0.001, + loss_fun = "cross_entropy", + shuffle_on_val= True, + z_dim = 32, + ) + return hparams + + + def build_graph(self, x): + self.is_build_graph = False + self.x = x["images"] + self.width = self.x.shape.as_list()[3] + self.height = self.x.shape.as_list()[2] + self.channels = self.x.shape.as_list()[4] + self.n_samples = self.x.shape.as_list()[0] * self.x.shape.as_list()[1] + self.x = tf.reshape(self.x, [-1, self.height,self.width,self.channels]) + self.global_step = tf.train.get_or_create_global_step() + original_global_variables = tf.global_variables() + # Architecture + self.define_gan() + #This is the loss function (RMSE): + #This is loss function only for 1 channel (temperature RMSE) + if self.mode == "train": + self.D_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.D_loss, var_list=self.disc_vars) + with tf.control_dependencies([self.D_solver]): + self.G_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.G_loss, var_list=self.gen_vars) + with tf.control_dependencies([self.G_solver]): + self.train_op = tf.assign_add(self.global_step,1) + else: + self.train_op = None + self.total_loss = self.G_loss + self.D_loss + self.outputs = {} + self.outputs["gen_images"] = self.gen_images + self.outputs["total_loss"] = self.total_loss + # Summary op + self.loss_summary = tf.summary.scalar("total_loss", self.G_loss + self.D_loss) + self.summary_op = tf.summary.merge_all() + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + self.is_build_graph = True + return self.is_build_graph + + def get_noise(self): + """ + Function for creating noise: Given the dimensions (n_samples,z_dim) + """ + self.noise = tf.random.uniform(minval=-1., maxval=1., shape=[self.n_samples, self.height, self.width, self.channels]) + return self.noise + + def get_generator_block(self,inputs,output_dim,idx): + + """ + Generator Block + Function for return a neural network of the generator given input and output dimensions + args: + inputs : the input vector + output_dim: the dimeniosn of output vector + return: + a generator neural network layer, with a convolutional layers followed by batch normalization and a relu activation + + """ + output1 = ld.conv_layer(inputs,kernel_size=2,stride=1,num_features=output_dim,idx=idx,activate="linear") + output2 = ld.bn_layers(output1,idx,is_training=False) + output3 = tf.nn.relu(output2) + return output3 + + + def generator(self,hidden_dim): + """ + Function to build up the generator architecture + args: + noise: a noise tensor with dimension (n_samples,height,width,channel) + hidden_dim: the inner dimension + """ + with tf.variable_scope("generator",reuse=tf.AUTO_REUSE): + layer1 = self.get_generator_block(self.noise,hidden_dim,1) + layer2 = self.get_generator_block(layer1,hidden_dim*2,2) + layer3 = self.get_generator_block(layer2,hidden_dim*4,3) + layer4 = self.get_generator_block(layer3,hidden_dim*8,4) + layer5 = ld.conv_layer(layer4,kernel_size=2,stride=1,num_features=self.channels,idx=5,activate="linear") + layer6 = tf.nn.sigmoid(layer5,name="6_conv") + print("layer6",layer6) + return layer6 + + + + def 
get_discriminator_block(self,inputs,output_dim,idx): + + """ + Distriminator block + Function for ruturn a neural network of a descriminator given input and output dimensions + + args: + inputs : the dimension of input vector + output_dim: the dimension of output dim + idx: : the index for the namespace of this block + Return: + a distriminator neural network layer with a convolutional layers followed by a leakyRelu function + """ + output1 = ld.conv_layer(inputs,2,stride=1,num_features=output_dim,idx=idx,activate="linear") + output2 = tf.nn.leaky_relu(output1) + return output2 + + + def discriminator(self,image,hidden_dim): + """ + Function that get discriminator architecture + """ + with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): + layer1 = self.get_discriminator_block(image,hidden_dim,idx=1) + layer2 = self.get_discriminator_block(layer1,hidden_dim*4,idx=2) + layer3 = self.get_discriminator_block(layer2,hidden_dim*2,idx=3) + layer4 = self.get_discriminator_block(layer3, self.channels,idx=4) + layer5 = tf.nn.sigmoid(layer4) + return layer5 + + + def get_disc_loss(self): + """ + Return the loss of discriminator given inputs + """ + + real_labels = tf.ones_like(self.D_real) + gen_labels = tf.zeros_like(self.D_fake) + D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_real, labels=real_labels)) + D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake, labels=gen_labels)) + self.D_loss = D_loss_real + D_loss_fake + return self.D_loss + + + def get_gen_loss(self): + """ + Param: + num_images: the number of images the generator should produce, which is also the lenght of the real image + z_dim : the dimension of the noise vector, a scalar + Return the loss of generator given inputs + """ + real_labels = tf.ones_like(self.gen_images) + self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake, labels=real_labels)) + return self.G_loss + + def get_vars(self): + """ + Get trainable variables from discriminator and generator + """ + self.disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")] + self.gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("generator")] + + + + def define_gan(self): + """ + Define gan architectures + """ + self.noise = self.get_noise() + self.gen_images = self.generator(hidden_dim=8) + self.D_real = self.discriminator(self.x,hidden_dim=8) + self.D_fake = self.discriminator(self.gen_images,hidden_dim=8) + self.get_gen_loss() + self.get_disc_loss() + self.get_vars() + diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py index e78f973b2aef760e1fccc24d296b31ee67f2e0c8..570ece3e8e752a0520bf1593bc5e0b922cece3a0 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py @@ -2,25 +2,14 @@ __email__ = "b.gong@fz-juelich.de" __author__ = "Bing Gong, Scarlet Stadtler,Michael Langguth" __date__ = "2020-11-05" -import collections -import functools -import itertools -from collections import OrderedDict -import numpy as np +from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames import tensorflow as tf -from tensorflow.python.util import nest -from model_modules.video_prediction import ops, flow_ops -from 
model_modules.video_prediction.models import BaseVideoPredictionModel -from model_modules.video_prediction.models import networks -from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat -from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell -from model_modules.video_prediction.utils import tf_utils -from datetime import datetime -from pathlib import Path from model_modules.video_prediction.layers import layer_def as ld from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell from tensorflow.contrib.training import HParams + + class VanillaConvLstmVideoPredictionModel(object): def __init__(self, mode='train', hparams_dict=None): """ @@ -36,7 +25,7 @@ class VanillaConvLstmVideoPredictionModel(object): self.total_loss = None self.context_frames = self.hparams.context_frames self.sequence_length = self.hparams.sequence_length - self.predict_frames = self.sequence_length - self.context_frames + self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) self.max_epochs = self.hparams.max_epochs self.loss_fun = self.hparams.loss_fun @@ -67,17 +56,18 @@ class VanillaConvLstmVideoPredictionModel(object): hparams = dict( context_frames=10, sequence_length=20, - max_epochs = 20, - batch_size = 40, - lr = 0.001, - loss_fun = "cross_entropy", - shuffle_on_val= True, + max_epochs=20, + batch_size=40, + lr=0.001, + loss_fun="cross_entropy", + shuffle_on_val=True, ) return hparams def build_graph(self, x): self.is_build_graph = False + self.inputs = x self.x = x["images"] self.global_step = tf.train.get_or_create_global_step() original_global_variables = tf.global_variables() @@ -112,24 +102,6 @@ class VanillaConvLstmVideoPredictionModel(object): self.is_build_graph = True return self.is_build_graph - @staticmethod - def convLSTM_cell(inputs, hidden): - y_0 = inputs #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset - channels = inputs.get_shape()[-1] - # conv lstm cell - cell_shape = y_0.get_shape().as_list() - channels = cell_shape[-1] - with tf.variable_scope('conv_lstm', initializer = tf.random_uniform_initializer(-.01, 0.1)): - cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size=5, num_features=64) - if hidden is None: - hidden = cell.zero_state(y_0, tf.float32) - output, hidden = cell(y_0, hidden) - output_shape = output.get_shape().as_list() - z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]]) - #we feed the learn representation into a 1 × 1 convolutional layer to generate the final prediction - x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid") - return x_hat, hidden - def convLSTM_network(self): network_template = tf.make_template('network', VanillaConvLstmVideoPredictionModel.convLSTM_cell) # make the template to share the variables @@ -150,3 +122,41 @@ class VanillaConvLstmVideoPredictionModel(object): self.x_hat= tf.transpose(x_hat, [1, 0, 2, 3, 4]) # change first dim with sec dim self.x_hat_predict_frames = self.x_hat[:,self.context_frames-1:,:,:,:] + @staticmethod + def convLSTM_cell(inputs, hidden): + y_0 = inputs #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset + channels = inputs.get_shape()[-1] + # conv lstm cell + cell_shape = y_0.get_shape().as_list() + channels = cell_shape[-1] + with tf.variable_scope('conv_lstm', initializer = 
tf.random_uniform_initializer(-.01, 0.1)): + cell = BasicConvLSTMCell(shape = [cell_shape[1], cell_shape[2]], filter_size=5, num_features=64) + if hidden is None: + hidden = cell.zero_state(y_0, tf.float32) + output, hidden = cell(y_0, hidden) + output_shape = output.get_shape().as_list() + z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]]) + #we feed the learn representation into a 1 × 1 convolutional layer to generate the final prediction + x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid") + return x_hat, hidden + + @staticmethod + def set_and_check_pred_frames(seq_length, context_frames): + """ + Checks if sequence length and context_frames are set properly and returns number of frames to be predicted. + :param seq_length: number of frames/images per sequences + :param context_frames: number of context frames/images + :return: number of predicted frames + """ + + method = VanillaConvLstmVideoPredictionModel.set_and_check_pred_frames.__name__ + + # sanity checks + assert isinstance(seq_length, int), "%{0}: Sequence length (seq_length) must be an integer".format(method) + assert isinstance(context_frames, int), "%{0}: Number of context frames must be an integer".format(method) + + if seq_length > context_frames: + return seq_length-context_frames + else: + raise ValueError("%{0}: Sequence length ({1}) must be larger than context frames ({2})." + .format(method, seq_length, context_frames)) diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py index 986e3626fe0746c2c714ee9fa9a76ad873044415..3e74b23b9d4969544ef0e560e30cd19e4a983f6b 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py @@ -3,21 +3,8 @@ __email__ = "b.gong@fz-juelich.de" __author__ = "Bing Gong" __date__ = "2020-09-01" -import collections -import functools -import itertools -from collections import OrderedDict -import numpy as np +from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames import tensorflow as tf -from tensorflow.python.util import nest -from model_modules.video_prediction import ops, flow_ops -from model_modules.video_prediction.models import BaseVideoPredictionModel -from model_modules.video_prediction.models import networks -from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat -from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell -from model_modules.video_prediction.utils import tf_utils -from datetime import datetime -from pathlib import Path from model_modules.video_prediction.layers import layer_def as ld from tensorflow.contrib.training import HParams @@ -37,7 +24,7 @@ class VanillaVAEVideoPredictionModel(object): self.total_loss = None self.context_frames = self.hparams.context_frames self.sequence_length = self.hparams.sequence_length - self.predict_frames = self.sequence_length - self.context_frames + self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) self.max_epochs = self.hparams.max_epochs self.nz = self.hparams.nz self.loss_fun = self.hparams.loss_fun diff --git a/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py b/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py index 
a1aa0dba9c5ca1e230c254a4d0a7b5439db0d5e3..d420ed1be8460ade5fae3302960de8b761f49f72 100755 --- a/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py +++ b/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py @@ -20,7 +20,7 @@ class Config_Preprocess1(Config_runscript_base): cls_name = "Config_Preprocess1"#.__name__ - nvars = 3 # number of variables required for training + nvars_default = 3 # number of variables required for training def __init__(self, venv_name, lhpc): super().__init__(venv_name, lhpc) @@ -35,7 +35,7 @@ class Config_Preprocess1(Config_runscript_base): # initialize additional runscript-specific attributes to be set via keyboard interaction self.destination_dir = None self.years = None - self.variables = [None] * self.nvars + self.variables = [] self.sw_corner = [-999., -999.] # [np.nan, np.nan] self.nyx = [-999., -999.] # [np.nan, np.nan] # list of variables to be written to runscript @@ -54,8 +54,18 @@ class Config_Preprocess1(Config_runscript_base): """ method_name = Config_Preprocess1.run_preprocess1.__name__ - # get source_dir (no user interaction needed when directory tree is fixed) - self.source_dir = Config_Preprocess1.handle_source_dir(self, "extractedData") + src_dir_req_str = "Enter path to directory where netCDF-files of the ERA5 dataset are located " + \ + "(in yearly directories.). Just press enter if the default should be used." + sorurce_dir_err = NotADirectoryError("Passed directory does not exist.") + source_dir_str = Config_Preprocess1.keyboard_interaction(src_dir_req_str, Config_Preprocess1.src_dir_check, + sorurce_dir_err, ntries=3) + if not source_dir_str: + # standard source_dir + self.source_dir = Config_Preprocess1.handle_source_dir(self, "extractedData") + print("%{0}: The following standard base-directory obtained from runscript template was set: '{1}'".format(method_name, self.source_dir)) + else: + self.source_dir = source_dir_str + Config_Preprocess1.get_subdir_list(self.source_dir) # get years for preprocessing step 1 years_req_str = "Enter a comma-separated sequence of years from list above:" @@ -86,8 +96,9 @@ class Config_Preprocess1(Config_runscript_base): vars_err, ntries=2) vars_list = vars_str.split(",") + vars_list = [var.strip().lower() for var in vars_list] if len(vars_list) == 1: - self.variables = vars_list * Config_Preprocess1.nvars + self.variables = vars_list * Config_Preprocess1.nvars_default else: self.variables = [var.strip() for var in vars_list] @@ -169,6 +180,29 @@ class Config_Preprocess1(Config_runscript_base): str(year))) # auxiliary functions for keyboard interaction + @staticmethod + def src_dir_check(srcdir, silent=False): + """ + Checks if source directory exists. Also allows for empty strings. In this case, a default of the source + directory must be applied. 
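# Behavioural sketch of the directory check described here (hypothetical paths):
# an empty answer is accepted so that the default source directory can be applied,
# an existing directory passes, and a missing one fails.
import os

for candidate in ["", "/tmp", "/does/not/exist"]:
    status = True if not candidate else os.path.isdir(candidate)
    print(repr(candidate), "->", status)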
+ :param srcdir: directory path under which ERA5 netCDF-data is stored + :param silent: flag if print-statement are executed + :return: status with True confirming success + """ + method = Config_Preprocess1.src_dir_check.__name__ + + status = False + if srcdir: + if os.path.isdir(srcdir): + status = True + else: + if not silent: + print("%{0}: '{1}' does not exist.".format(method, srcdir)) + else: + status = True + + return status + @staticmethod def check_data_indir(indir, silent=False, recursive=True): """ @@ -214,9 +248,10 @@ class Config_Preprocess1(Config_runscript_base): check_years = [year.strip().isnumeric() for year in years_list] status = all(check_years) if not status: - inds_bad = [i for i, e in enumerate(check_years) if e] #np.where(~np.array(check_years))[0] + inds_bad = [i for i, e in enumerate(check_years) if not e] #np.where(~np.array(check_years))[0] if not silent: - print("%{0}: The following comma-separated elements could not be interpreted as valid years:".format(method)) + print("%{0}: The following comma-separated elements could not be interpreted as valid years:" + .format(method)) for ind in inds_bad: print(years_list[ind]) return status @@ -245,15 +280,15 @@ class Config_Preprocess1(Config_runscript_base): check_vars = [var.strip().lower() in known_vars for var in vars_list] status = all(check_vars) if not status: - inds_bad = [i for i, e in enumerate(check_vars) if e] # np.where(~np.array(check_vars))[0] + inds_bad = [i for i, e in enumerate(check_vars) if e] # np.where(~np.array(check_vars))[0] if not silent: - print("%{0}: The following comma-separated elements are unknown variables:".format(method_name)) + print("%{0}: The following comma-separated elements are unknown variables:".format(method)) for ind in inds_bad: print(vars_list[ind]) return status - if not (len(check_vars) == Config_Preprocess1.nvars or len(check_vars) == 1): - if not silent: print("%{0}: Unexpected number of variables passed.".method(method)) + if not len(check_vars) >= 1: + if not silent: print("%{0}: Pass at least one input variable".format(method)) status = False return status @@ -321,9 +356,11 @@ class Config_Preprocess1(Config_runscript_base): else: if not silent: if not check_nyx[0]: - print("%{0}: Number of grid points in meridional direction must be smaller than {1:d}".format(method, ny_max)) + print("%{0}: Number of grid points in meridional direction must be smaller than {1:d}" + .format(method, ny_max)) if not check_nyx[1]: - print("%{0}: Number of grid points in zonal direction must be smaller than {1:d}".format(method, nx_max)) + print("%{0}: Number of grid points in zonal direction must be smaller than {1:d}" + .format(method, nx_max)) else: if not silent: print("%{0}: Number of grid points must be integers.".format(method))
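# A short sketch of the comma-separated years check performed above
# (input string is hypothetical; mirrors the corrected "inds_bad" indexing):
years_str = "2007, 2010, 20x7"
years_list = years_str.split(",")
check_years = [year.strip().isnumeric() for year in years_list]
if not all(check_years):
    inds_bad = [i for i, ok in enumerate(check_years) if not ok]
    for ind in inds_bad:
        print("Not a valid year:", years_list[ind].strip())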