Skip to content
Snippets Groups Projects
Select Git revision
  • 63d232d56a8a95bf633fadc38e9251507f00daa4
  • master default protected
  • enxhi_issue460_remove_TOAR-I_access
  • michael_issue459_preprocess_german_stations
  • sh_pollutants
  • develop protected
  • release_v2.4.0
  • michael_issue450_feat_load-ifs-data
  • lukas_issue457_feat_set-config-paths-as-parameter
  • lukas_issue454_feat_use-toar-statistics-api-v2
  • lukas_issue453_refac_advanced-retry-strategy
  • lukas_issue452_bug_update-proj-version
  • lukas_issue449_refac_load-era5-data-from-toar-db
  • lukas_issue451_feat_robust-apriori-estimate-for-short-timeseries
  • lukas_issue448_feat_load-model-from-path
  • lukas_issue447_feat_store-and-load-local-clim-apriori-data
  • lukas_issue445_feat_data-insight-plot-monthly-distribution
  • lukas_issue442_feat_bias-free-evaluation
  • lukas_issue444_feat_choose-interp-method-cams
  • 414-include-crps-analysis-and-other-ens-verif-methods-or-plots
  • lukas_issue384_feat_aqw-data-handler
  • v2.4.0 protected
  • v2.3.0 protected
  • v2.2.0 protected
  • v2.1.0 protected
  • Kleinert_etal_2022_initial_submission
  • v2.0.0 protected
  • v1.5.0 protected
  • v1.4.0 protected
  • v1.3.0 protected
  • v1.2.1 protected
  • v1.2.0 protected
  • v1.1.0 protected
  • IntelliO3-ts-v1.0_R1-submit
  • v1.0.0 protected
  • v0.12.2 protected
  • v0.12.1 protected
  • v0.12.0 protected
  • v0.11.0 protected
  • v0.10.0 protected
  • IntelliO3-ts-v1.0_initial-submit
41 results

test_model_setup.py

Blame
  • train.py 15.09 KiB
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import argparse
    import errno
    import itertools
    import json
    import math
    import os
    import random
    import time
    from collections import OrderedDict
    
    import numpy as np
    import tensorflow as tf
    
    from video_prediction import datasets, models
    from video_prediction.utils import ffmpeg_gif, tf_utils
    
    
    def main():
        parser = argparse.ArgumentParser()
        parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
                                                                         "train, val, test, etc, or a directory containing "
                                                                         "the tfrecords")
        parser.add_argument("--val_input_dirs", type=str, nargs='+', help="directories containing the tfrecords. default: [input_dir]")
        parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified")
        parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. "
                                                 "default is logs_dir/model_fname, where model_fname consists of "
                                                 "information from model and model_hparams")
        parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
        parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.')
    
        parser.add_argument("--dataset", type=str, help="dataset class name")
        parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
        parser.add_argument("--dataset_hparams_dict", type=str, help="a json file of dataset hyperparameters")
        parser.add_argument("--model", type=str, help="model class name")
        parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
        parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")
    
        parser.add_argument("--summary_freq", type=int, default=1000, help="save summaries (except for image and eval summaries) every summary_freq steps")
        parser.add_argument("--image_summary_freq", type=int, default=5000, help="save image summaries every image_summary_freq steps")
        parser.add_argument("--eval_summary_freq", type=int, default=0, help="save eval summaries every eval_summary_freq steps")
        parser.add_argument("--progress_freq", type=int, default=100, help="display progress every progress_freq steps")
        parser.add_argument("--metrics_freq", type=int, default=0, help="run and display metrics every metrics_freq step")
        parser.add_argument("--gif_freq", type=int, default=0, help="save gifs of predicted frames every gif_freq steps")
        parser.add_argument("--save_freq", type=int, default=5000, help="save model every save_freq steps, 0 to disable")
    
        parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
        parser.add_argument("--seed", type=int)
    
        args = parser.parse_args()
    
        if args.seed is not None:
            tf.set_random_seed(args.seed)
            np.random.seed(args.seed)
            random.seed(args.seed)
    
        if args.output_dir is None:
            list_depth = 0
            model_fname = ''
            for t in ('model=%s,%s' % (args.model, args.model_hparams)):
                if t == '[':
                    list_depth += 1
                if t == ']':
                    list_depth -= 1
                if list_depth and t == ',':
                    t = '..'
                if t in '=,':
                    t = '.'
                if t in '[]':
                    t = ''
                model_fname += t
            args.output_dir = os.path.join(args.logs_dir, model_fname)
    
        if args.resume:
            if args.checkpoint:
                raise ValueError('resume and checkpoint cannot both be specified')
            args.checkpoint = args.output_dir
    
        dataset_hparams_dict = {}
        model_hparams_dict = {}
        if args.dataset_hparams_dict:
            with open(args.dataset_hparams_dict) as f:
                dataset_hparams_dict.update(json.loads(f.read()))
        if args.model_hparams_dict:
            with open(args.model_hparams_dict) as f:
                model_hparams_dict.update(json.loads(f.read()))
        if args.checkpoint:
            checkpoint_dir = os.path.normpath(args.checkpoint)
            if not os.path.exists(checkpoint_dir):
                raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
            if not os.path.isdir(args.checkpoint):
                checkpoint_dir, _ = os.path.split(checkpoint_dir)
            with open(os.path.join(checkpoint_dir, "options.json")) as f:
                print("loading options from checkpoint %s" % args.checkpoint)
                options = json.loads(f.read())
                args.dataset = args.dataset or options['dataset']
                args.model = args.model or options['model']
            try:
                with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                    dataset_hparams_dict.update(json.loads(f.read()))
            except FileNotFoundError:
                print("dataset_hparams.json was not loaded because it does not exist")
            try:
                with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                    model_hparams_dict.update(json.loads(f.read()))
                    model_hparams_dict.pop('num_gpus', None)  # backwards-compatibility
            except FileNotFoundError:
                print("model_hparams.json was not loaded because it does not exist")
    
        print('----------------------------------- Options ------------------------------------')
        for k, v in args._get_kwargs():
            print(k, "=", v)
        print('------------------------------------- End --------------------------------------')
    
        VideoDataset = datasets.get_dataset_class(args.dataset)
        train_dataset = VideoDataset(args.input_dir, mode='train', hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams)
        val_input_dirs = args.val_input_dirs or [args.input_dir]
        val_datasets = [VideoDataset(val_input_dir, mode='val', hparams_dict=dataset_hparams_dict, hparams=args.dataset_hparams)
                        for val_input_dir in val_input_dirs]
        if len(val_input_dirs) > 1:
            if isinstance(val_datasets[-1], datasets.KTHVideoDataset):
                val_datasets[-1].set_sequence_length(40)
            else:
                val_datasets[-1].set_sequence_length(30)
    
        def override_hparams_dict(dataset):
            hparams_dict = dict(model_hparams_dict)
            hparams_dict['context_frames'] = dataset.hparams.context_frames
            hparams_dict['sequence_length'] = dataset.hparams.sequence_length
            hparams_dict['repeat'] = dataset.hparams.time_shift
            return hparams_dict
    
        VideoPredictionModel = models.get_model_class(args.model)
        train_model = VideoPredictionModel(mode='train', hparams_dict=override_hparams_dict(train_dataset), hparams=args.model_hparams)
        val_models = [VideoPredictionModel(mode='val', hparams_dict=override_hparams_dict(val_dataset), hparams=args.model_hparams)
                      for val_dataset in val_datasets]
    
        batch_size = train_model.hparams.batch_size
        with tf.variable_scope('') as training_scope:
            train_model.build_graph(*train_dataset.make_batch(batch_size))
        for val_model, val_dataset in zip(val_models, val_datasets):
            with tf.variable_scope(training_scope, reuse=True):
                val_model.build_graph(*val_dataset.make_batch(batch_size))
    
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        with open(os.path.join(args.output_dir, "options.json"), "w") as f:
            f.write(json.dumps(vars(args), sort_keys=True, indent=4))
        with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
            f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
        with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
            f.write(json.dumps(train_model.hparams.values(), sort_keys=True, indent=4))
    
        if args.gif_freq:
            val_model = val_models[0]
            val_tensors = OrderedDict()
            context_images = val_model.inputs['images'][:, :val_model.hparams.context_frames]
            val_tensors['gen_images_vis'] = tf.concat([context_images, val_model.gen_images], axis=1)
            if val_model.gen_images_enc is not None:
                val_tensors['gen_images_enc_vis'] = tf.concat([context_images, val_model.gen_images_enc], axis=1)
            val_tensors.update({name: tensor for name, tensor in val_model.inputs.items() if tensor.shape.ndims >= 4})
            val_tensors['targets'] = val_model.targets
            val_tensors.update({name: tensor for name, tensor in val_model.outputs.items() if tensor.shape.ndims >= 4})
            val_tensor_clips = OrderedDict([(name, tf_utils.tensor_to_clip(output)) for name, output in val_tensors.items()])
    
        with tf.name_scope("parameter_count"):
            parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()])
    
        saver = tf.train.Saver(max_to_keep=3)
        summaries = tf.get_collection(tf.GraphKeys.SUMMARIES)
        image_summaries = set(tf.get_collection(tf_utils.IMAGE_SUMMARIES))
        eval_summaries = set(tf.get_collection(tf_utils.EVAL_SUMMARIES))
        eval_image_summaries = image_summaries & eval_summaries
        image_summaries -= eval_image_summaries
        eval_summaries -= eval_image_summaries
        if args.summary_freq:
            summary_op = tf.summary.merge(summaries)
        if args.image_summary_freq:
            image_summary_op = tf.summary.merge(list(image_summaries))
        if args.eval_summary_freq:
            eval_summary_op = tf.summary.merge(list(eval_summaries))
            eval_image_summary_op = tf.summary.merge(list(eval_image_summaries))
    
        if args.summary_freq or args.image_summary_freq or args.eval_summary_freq:
            summary_writer = tf.summary.FileWriter(args.output_dir)
    
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac)
        config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
        global_step = tf.train.get_or_create_global_step()
        max_steps = train_model.hparams.max_steps
        with tf.Session(config=config) as sess:
            print("parameter_count =", sess.run(parameter_count))
    
            sess.run(tf.global_variables_initializer())
            train_model.restore(sess, args.checkpoint)
    
            start_step = sess.run(global_step)
            # start at one step earlier to log everything without doing any training
            # step is relative to the start_step
            for step in range(-1, max_steps - start_step):
                if step == 0:
                    start = time.time()
    
                def should(freq):
                    return freq and ((step + 1) % freq == 0 or (step + 1) in (0, max_steps - start_step))
    
                fetches = {"global_step": global_step}
                if step >= 0:
                    fetches["train_op"] = train_model.train_op
    
                if should(args.progress_freq):
                    fetches['d_losses'] = train_model.d_losses
                    fetches['g_losses'] = train_model.g_losses
                    if isinstance(train_model.learning_rate, tf.Tensor):
                        fetches["learning_rate"] = train_model.learning_rate
                if should(args.metrics_freq):
                    fetches['metrics'] = train_model.metrics
                if should(args.summary_freq):
                    fetches["summary"] = summary_op
                if should(args.image_summary_freq):
                    fetches["image_summary"] = image_summary_op
                if should(args.eval_summary_freq):
                    fetches["eval_summary"] = eval_summary_op
                    fetches["eval_image_summary"] = eval_image_summary_op
    
                run_start_time = time.time()
                results = sess.run(fetches)
                run_elapsed_time = time.time() - run_start_time
                if run_elapsed_time > 1.5:
                    print('session.run took %0.1fs' % run_elapsed_time)
    
                if should(args.summary_freq):
                    print("recording summary")
                    summary_writer.add_summary(results["summary"], results["global_step"])
                    print("done")
                if should(args.image_summary_freq):
                    print("recording image summary")
                    summary_writer.add_summary(
                        tf_utils.convert_tensor_to_gif_summary(results["image_summary"]), results["global_step"])
                    print("done")
                if should(args.eval_summary_freq):
                    print("recording eval summary")
                    summary_writer.add_summary(results["eval_summary"], results["global_step"])
                    summary_writer.add_summary(
                        tf_utils.convert_tensor_to_gif_summary(results["eval_image_summary"]), results["global_step"])
                    print("done")
                if should(args.summary_freq) or should(args.image_summary_freq) or should(args.eval_summary_freq):
                    summary_writer.flush()
                if should(args.progress_freq):
                    # global_step will have the correct step count if we resume from a checkpoint
                    steps_per_epoch = math.ceil(train_dataset.num_examples_per_epoch() / batch_size)
                    train_epoch = math.ceil(results["global_step"] / steps_per_epoch)
                    train_step = (results["global_step"] - 1) % steps_per_epoch + 1
                    print("progress  global step %d  epoch %d  step %d" % (results["global_step"], train_epoch, train_step))
                    if step >= 0:
                        elapsed_time = time.time() - start
                        average_time = elapsed_time / (step + 1)
                        images_per_sec = batch_size / average_time
                        remaining_time = (max_steps - (start_step + step)) * average_time
                        print("          image/sec %0.1f  remaining %dm (%0.1fh) (%0.1fd)" %
                              (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
    
                    for name, loss in itertools.chain(results['d_losses'].items(), results['g_losses'].items()):
                        print(name, loss)
                    if isinstance(train_model.learning_rate, tf.Tensor):
                        print("learning_rate", results["learning_rate"])
                if should(args.metrics_freq):
                    for name, metric in results['metrics']:
                        print(name, metric)
    
                if should(args.save_freq):
                    print("saving model to", args.output_dir)
                    saver.save(sess, os.path.join(args.output_dir, "model"), global_step=global_step)
                    print("done")
    
                if should(args.gif_freq):
                    image_dir = os.path.join(args.output_dir, 'images')
                    if not os.path.exists(image_dir):
                        os.makedirs(image_dir)
    
                    gif_clips = sess.run(val_tensor_clips)
                    gif_step = results["global_step"]
                    for name, clip in gif_clips.items():
                        filename = "%08d-%s.gif" % (gif_step, name)
                        print("saving gif to", os.path.join(image_dir, filename))
                        ffmpeg_gif.save_gif(os.path.join(image_dir, filename), clip, fps=4)
                        print("done")
    
    
    if __name__ == '__main__':
        main()