train.py

    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function
    
    import argparse
    import errno
    import json
    import os
    import random
    import time
    
    import numpy as np
    import tensorflow as tf
    
    from video_prediction import datasets, models
    
    
    def add_tag_suffix(summary, tag_suffix):
        summary_proto = tf.Summary()
        summary_proto.ParseFromString(summary)
        summary = summary_proto
    
        for value in summary.value:
            tag_split = value.tag.split('/')
            value.tag = '/'.join([tag_split[0] + tag_suffix] + tag_split[1:])
        return summary.SerializeToString()
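    # Example (hypothetical tag): add_tag_suffix(summary, '_1') rewrites a value
    # tagged 'losses/g_loss' to 'losses_1/g_loss', so validation summaries can be
    # told apart from training summaries written to the same FileWriter.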
    
    
    def main():
        parser = argparse.ArgumentParser()
        parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
                                                                         "train, val, test, etc, or a directory containing "
                                                                         "the tfrecords")
        parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir")
        parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified")
        parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. "
                                                 "default is logs_dir/model_fname, where model_fname consists of "
                                                 "information from model and model_hparams")
        parser.add_argument("--output_dir_postfix", default="")
        parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
        parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.')
    
        parser.add_argument("--dataset", type=str, help="dataset class name")
        parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters")
        parser.add_argument("--dataset_hparams_dict", type=str, help="a json file of dataset hyperparameters")
        parser.add_argument("--model", type=str, help="model class name")
        parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
        parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")
    
        parser.add_argument("--summary_freq", type=int, default=1000, help="save frequency of summaries (except for image and eval summaries) for train/validation set")
        parser.add_argument("--image_summary_freq", type=int, default=5000, help="save frequency of image summaries for train/validation set")
        parser.add_argument("--eval_summary_freq", type=int, default=25000, help="save frequency of eval summaries for train/validation set")
        parser.add_argument("--accum_eval_summary_freq", type=int, default=100000, help="save frequency of accumulated eval summaries for validation set only")
        parser.add_argument("--progress_freq", type=int, default=100, help="display progress every progress_freq steps")
        parser.add_argument("--save_freq", type=int, default=5000, help="save frequence of model, 0 to disable")
    
        parser.add_argument("--aggregate_nccl", type=int, default=0, help="whether to use nccl or cpu for gradient aggregation in multi-gpu training")
        parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use")
        parser.add_argument("--seed", type=int)
    
        args = parser.parse_args()
    
        if args.seed is not None:
            tf.set_random_seed(args.seed)
            np.random.seed(args.seed)
            random.seed(args.seed)
    
        if args.output_dir is None:
            list_depth = 0
            model_fname = ''
            for t in ('model=%s,%s' % (args.model, args.model_hparams)):
                if t == '[':
                    list_depth += 1
                if t == ']':
                    list_depth -= 1
                if list_depth and t == ',':
                    t = '..'
                if t in '=,':
                    t = '.'
                if t in '[]':
                    t = ''
                model_fname += t
            args.output_dir = os.path.join(args.logs_dir, model_fname) + args.output_dir_postfix
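            # Example (hypothetical values): --model savp with
            # --model_hparams 'lr=0.001,beta1=[0.9,0.999]' gives the run name
            # 'model.savp.lr.0.001.beta1.0.9..0.999'.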
    
        if args.resume:
            if args.checkpoint:
                raise ValueError('resume and checkpoint cannot both be specified')
            args.checkpoint = args.output_dir
    
        dataset_hparams_dict = {}
        model_hparams_dict = {}
        if args.dataset_hparams_dict:
            with open(args.dataset_hparams_dict) as f:
                dataset_hparams_dict.update(json.loads(f.read()))
        if args.model_hparams_dict:
            with open(args.model_hparams_dict) as f:
                model_hparams_dict.update(json.loads(f.read()))
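        # A hyperparameter dict file is plain JSON, e.g. (hypothetical contents):
        # {"batch_size": 8, "lr": 0.001, "max_steps": 300000}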
        if args.checkpoint:
            checkpoint_dir = os.path.normpath(args.checkpoint)
            if not os.path.isdir(args.checkpoint):
                checkpoint_dir, _ = os.path.split(checkpoint_dir)
            if not os.path.exists(checkpoint_dir):
                raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
            with open(os.path.join(checkpoint_dir, "options.json")) as f:
                print("loading options from checkpoint %s" % args.checkpoint)
                options = json.loads(f.read())
                args.dataset = args.dataset or options['dataset']
                args.model = args.model or options['model']
            try:
                with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f:
                    dataset_hparams_dict.update(json.loads(f.read()))
            except FileNotFoundError:
                print("dataset_hparams.json was not loaded because it does not exist")
            try:
                with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
                    model_hparams_dict.update(json.loads(f.read()))
            except FileNotFoundError:
                print("model_hparams.json was not loaded because it does not exist")
    
        print('----------------------------------- Options ------------------------------------')
        for k, v in args._get_kwargs():
            print(k, "=", v)
        print('------------------------------------- End --------------------------------------')
    
        VideoDataset = datasets.get_dataset_class(args.dataset)
        train_dataset = VideoDataset(
            args.input_dir,
            mode='train',
            hparams_dict=dataset_hparams_dict,
            hparams=args.dataset_hparams)
        val_dataset = VideoDataset(
            args.val_input_dir or args.input_dir,
            mode='val',
            hparams_dict=dataset_hparams_dict,
            hparams=args.dataset_hparams)
        if val_dataset.hparams.long_sequence_length != val_dataset.hparams.sequence_length:
            # the longer dataset is only used for the accum_eval_metrics
            long_val_dataset = VideoDataset(
                args.val_input_dir or args.input_dir,
                mode='val',
                hparams_dict=dataset_hparams_dict,
                hparams=args.dataset_hparams)
            long_val_dataset.set_sequence_length(val_dataset.hparams.long_sequence_length)
        else:
            long_val_dataset = None
    
        variable_scope = tf.get_variable_scope()
        variable_scope.set_use_resource(True)
    
        VideoPredictionModel = models.get_model_class(args.model)
        hparams_dict = dict(model_hparams_dict)
        hparams_dict.update({
            'context_frames': train_dataset.hparams.context_frames,  # number of initial ground-truth frames the model is conditioned on
            'sequence_length': train_dataset.hparams.sequence_length,  # total number of frames per sample (context + predicted)
            'repeat': train_dataset.hparams.time_shift,
        })
        model = VideoPredictionModel(
            hparams_dict=hparams_dict,
            hparams=args.model_hparams,
            aggregate_nccl=args.aggregate_nccl)
    
        batch_size = model.hparams.batch_size
        train_tf_dataset = train_dataset.make_dataset(batch_size)  # Bing: adapt the meteorological data preparation here
        train_iterator = train_tf_dataset.make_one_shot_iterator()  # Bing: for ERA5, problems in sess.run(fetches) may originate here
        # The `Iterator.string_handle()` method returns a tensor that can be evaluated
        # and used to feed the `handle` placeholder.
        train_handle = train_iterator.string_handle()
        val_tf_dataset = val_dataset.make_dataset(batch_size)
        val_iterator = val_tf_dataset.make_one_shot_iterator()
        val_handle = val_iterator.string_handle()
        iterator = tf.data.Iterator.from_string_handle(
            train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
        inputs = iterator.get_next()
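        # With this feedable-iterator pattern, fetching `inputs` reads from the
        # training set by default; feeding {train_handle: val_handle_eval} (as in
        # the training loop below) redirects the same input ops to the validation
        # set without rebuilding the graph.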
        # Bing, debug: print the shape of the first two input batches. This uses a
        # temporary session; one-shot iterator state is per-session, so it should
        # not consume batches from the training session opened below.
        with tf.Session() as sess:
            for i in range(2):
                print(sess.run(tf.shape(inputs["images"])))
    
        # `inputs` comes from the training dataset by default, unless `train_handle` is fed with the validation handle
        model.build_graph(inputs)
    
        if long_val_dataset is not None:
            # separately build a model for the longer sequence.
            # this is needed because the model doesn't support dynamic shapes.
            long_hparams_dict = dict(hparams_dict)
            long_hparams_dict['sequence_length'] = long_val_dataset.hparams.sequence_length
            # use a smaller batch size for the longer model to prevent running out of memory
            long_hparams_dict['batch_size'] = model.hparams.batch_size // 2
            long_model = VideoPredictionModel(
                mode="test",  # to not build the losses and discriminators
                hparams_dict=long_hparams_dict,
                hparams=args.model_hparams,
                aggregate_nccl=args.aggregate_nccl)
            tf.get_variable_scope().reuse_variables()
            long_model.build_graph(long_val_dataset.make_batch(batch_size))
        else:
            long_model = None
    
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        with open(os.path.join(args.output_dir, "options.json"), "w") as f:
            f.write(json.dumps(vars(args), sort_keys=True, indent=4))
        with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f:
            f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
        with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f:
            f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
    
        with tf.name_scope("parameter_count"):
            # exclude trainable variables that are replicas (used in multi-gpu setting)
            trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables)
            parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in trainable_variables])
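            # Example: a (hypothetical) 3x3 conv kernel with 64 input and 64 output
            # channels has shape [3, 3, 64, 64] and contributes 3*3*64*64 = 36864
            # to parameter_count.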
    
        saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2)
    
        # None has the special meaning of evaluating at the end, so explicitly check for non-equality to zero
        if (args.summary_freq != 0 or args.image_summary_freq != 0 or
                args.eval_summary_freq != 0 or args.accum_eval_summary_freq != 0):
            summary_writer = tf.summary.FileWriter(args.output_dir)
    
        gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True)
        config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
        global_step = tf.train.get_or_create_global_step()
        max_steps = model.hparams.max_steps
        with tf.Session(config=config) as sess:
            print("parameter_count =", sess.run(parameter_count))
            sess.run(tf.global_variables_initializer())
            sess.run(tf.local_variables_initializer())
            #coord = tf.train.Coordinator()
            #threads = tf.train.start_queue_runners(sess = sess, coord = coord)
            print("Init done: {sess.run(tf.local_variables_initializer())}%")
            model.restore(sess, args.checkpoint)
            print("Restore processed finished")
            sess.run(model.post_init_ops)
            print("Model run started")
            val_handle_eval = sess.run(val_handle)
            print("val handle done")
            sess.graph.finalize()
            print("graph inalize done")
            start_step = sess.run(global_step)
            print("global step done")
    
            def should(step, freq):
                if freq is None:
                    return (step + 1) == (max_steps - start_step)
                else:
                    return freq and ((step + 1) % freq == 0 or (step + 1) in (0, max_steps - start_step))
    
            def should_eval(step, freq):
                # never run eval summaries at the beginning since it's expensive, unless it's the last iteration
                return should(step, freq) and (step >= 0 or (step + 1) == (max_steps - start_step))
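            # Example: with freq=1000, should(step, 1000) is True at the logging-only
            # pre-step (step == -1), at steps 999, 1999, ... (every 1000th iteration)
            # and at the final step; with freq=None it is True only at the final step.
            # should_eval additionally skips the pre-step, since eval summaries are
            # expensive.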
    
            # start at one step earlier to log everything without doing any training
            # step is relative to the start_step
            for step in range(-1, max_steps - start_step):
                if step == 1:
                    # skip step -1 and 0 for timing purposes (for warmstarting)
                    start_time = time.time()
    
                fetches = {"global_step": global_step}
                if step >= 0:
                    fetches["train_op"] = model.train_op
                if should(step, args.progress_freq):
                    fetches['d_loss'] = model.d_loss
                    fetches['g_loss'] = model.g_loss
                    fetches['d_losses'] = model.d_losses
                    fetches['g_losses'] = model.g_losses
                    if isinstance(model.learning_rate, tf.Tensor):
                        fetches["learning_rate"] = model.learning_rate
                if should(step, args.summary_freq):
                    fetches["summary"] = model.summary_op
                if should(step, args.image_summary_freq):
                    fetches["image_summary"] = model.image_summary_op
                if should_eval(step, args.eval_summary_freq):
                    fetches["eval_summary"] = model.eval_summary_op
    
                run_start_time = time.time()
                results = sess.run(fetches)  # fetch the elements in the `fetches` dictionary
    
                run_elapsed_time = time.time() - run_start_time
                if run_elapsed_time > 1.5 and step > 0 and set(fetches.keys()) == {"global_step", "train_op"}:
                    print('running train_op took too long (%0.1fs)' % run_elapsed_time)
    
                if (should(step, args.summary_freq) or
                        should(step, args.image_summary_freq) or
                        should_eval(step, args.eval_summary_freq)):
                    val_fetches = {"global_step": global_step}
                    if should(step, args.summary_freq):
                        val_fetches["summary"] = model.summary_op
                    if should(step, args.image_summary_freq):
                        val_fetches["image_summary"] = model.image_summary_op
                    if should_eval(step, args.eval_summary_freq):
                        val_fetches["eval_summary"] = model.eval_summary_op
                    val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
                    for name, summary in val_results.items():
                        if name == 'global_step':
                            continue
                        val_results[name] = add_tag_suffix(summary, '_1')
    
                if should(step, args.summary_freq):
                    print("recording summary")
                    summary_writer.add_summary(results["summary"], results["global_step"])
                    summary_writer.add_summary(val_results["summary"], val_results["global_step"])
                    print("done")
                if should(step, args.image_summary_freq):
                    print("recording image summary")
                    summary_writer.add_summary(results["image_summary"], results["global_step"])
                    summary_writer.add_summary(val_results["image_summary"], val_results["global_step"])
                    print("done")
                if should_eval(step, args.eval_summary_freq):
                    print("recording eval summary")
                    summary_writer.add_summary(results["eval_summary"], results["global_step"])
                    summary_writer.add_summary(val_results["eval_summary"], val_results["global_step"])
                    print("done")
                if should_eval(step, args.accum_eval_summary_freq):
                    val_datasets = [val_dataset]
                    val_models = [model]
                    if long_model is not None:
                        val_datasets.append(long_val_dataset)
                        val_models.append(long_model)
                    for i, (val_dataset_, val_model) in enumerate(zip(val_datasets, val_models)):
                        sess.run(val_model.accum_eval_metrics_reset_op)
                        # traverse (up to rounding by the batch size) the whole validation dataset
                        accum_eval_summary_num_updates = val_dataset_.num_examples_per_epoch() // val_model.hparams.batch_size
                        val_fetches = {"global_step": global_step, "accum_eval_summary": val_model.accum_eval_summary_op}
                        for update_step in range(accum_eval_summary_num_updates):
                            print('evaluating %d / %d' % (update_step + 1, accum_eval_summary_num_updates))
                            val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
                        accum_eval_summary = add_tag_suffix(val_results["accum_eval_summary"], '_%d' % (i + 1))
                        print("recording accum eval summary")
                        summary_writer.add_summary(accum_eval_summary, val_results["global_step"])
                        print("done")
                if (should(step, args.summary_freq) or should(step, args.image_summary_freq) or
                        should_eval(step, args.eval_summary_freq) or should_eval(step, args.accum_eval_summary_freq)):
                    summary_writer.flush()
                if should(step, args.progress_freq):
                    # global_step will have the correct step count if we resume from a checkpoint
                    # global step is read before it's incremented
                    steps_per_epoch = train_dataset.num_examples_per_epoch() / batch_size
                    train_epoch = results["global_step"] / steps_per_epoch
                    print("progress  global step %d  epoch %0.1f" % (results["global_step"] + 1, train_epoch))
                    if step > 0:
                        elapsed_time = time.time() - start_time
                        average_time = elapsed_time / step
                        images_per_sec = batch_size / average_time
                        remaining_time = (max_steps - (start_step + step + 1)) * average_time
                        print("          image/sec %0.1f  remaining %dm (%0.1fh) (%0.1fd)" %
                              (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24))
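                        # Example (hypothetical numbers): batch_size=8 with an
                        # average_time of 0.5 s/step gives 16 images/sec; 100000
                        # remaining steps then take 100000 * 0.5 s ≈ 13.9 h.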
    
                    if results['d_losses']:
                        print("d_loss", results["d_loss"])
                        for name, loss in results['d_losses'].items():
                            print("  ", name, loss)
                    if results['g_losses']:
                        print("g_loss", results["g_loss"])
                        for name, loss in results['g_losses'].items():
                            print("  ", name, loss)
                    if isinstance(model.learning_rate, tf.Tensor):
                        print("learning_rate", results["learning_rate"])
    
                if should(step, args.save_freq):
                    print("saving model to", args.output_dir)
                    saver.save(sess, os.path.join(args.output_dir, "model"), global_step=global_step)
                    print("done")
    
    
    if __name__ == '__main__':
        main()