Commit 6a6f6eac authored by gong1

Merge branch 'bing_issue#020_add_cross_entropy_loss_function_to_convLSTM' into develop

parents 48827e48 e5d7db5b
@@ -42,4 +42,4 @@ dataset=moving_mnist
 model_hparams=../hparams/${dataset}/${model}/model_hparams.json
 # rund training
-srun python ../scripts/train_moving_mnist.py --input_dir ${source_dir}/tfrecords/ --dataset moving_mnist --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/${model}_bing_20200902/
+srun python ../scripts/train_dummy.py --input_dir ${source_dir}/tfrecords/ --dataset moving_mnist --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/${model}_bing_20200902/
@@ -232,6 +232,8 @@ def main():
     inputs, train_handle, val_handle = make_dataset_iterator(train_dataset, val_dataset, batch_size)
     #build model graph
+    #since era5 tfrecords include T_start, we need to remove it from the tfrecord when we train the model, otherwise the model will raise error
+    if args.dataset == "era5":
-    del inputs["T_start"]
+        del inputs["T_start"]
     model.build_graph(inputs)
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import argparse
import errno
import json
import os
import random
import time
import numpy as np
import tensorflow as tf
from video_prediction import datasets, models
import matplotlib
matplotlib.use('Agg')  # non-interactive backend so saving plots works on headless compute nodes
import matplotlib.pyplot as plt
from json import JSONEncoder
import pickle as pkl
class NumpyArrayEncoder(JSONEncoder):
def default(self, obj):
if isinstance(obj, np.ndarray):
return obj.tolist()
return JSONEncoder.default(self, obj)
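# Illustrative usage (not part of the original script): NumpyArrayEncoder lets
# json.dumps serialize numpy arrays by converting them to plain lists, e.g.
#   json.dumps({"loss": np.array([0.1, 0.2])}, cls=NumpyArrayEncoder)
# returns '{"loss": [0.1, 0.2]}'.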
def add_tag_suffix(summary, tag_suffix):
summary_proto = tf.Summary()
summary_proto.ParseFromString(summary)
summary = summary_proto
for value in summary.value:
tag_split = value.tag.split('/')
value.tag = '/'.join([tag_split[0] + tag_suffix] + tag_split[1:])
return summary.SerializeToString()
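# Example (illustrative, not in the original): a summary tagged "loss/total"
# with tag_suffix="_val" is rewritten to "loss_val/total", so training and
# validation summaries appear under separate tag groups in TensorBoard.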
def generate_output_dir(output_dir, model,model_hparams,logs_dir,output_dir_postfix):
if output_dir is None:
list_depth = 0
model_fname = ''
for t in ('model=%s,%s' % (model, model_hparams)):
if t == '[':
list_depth += 1
if t == ']':
list_depth -= 1
if list_depth and t == ',':
t = '..'
if t in '=,':
t = '.'
if t in '[]':
t = ''
model_fname += t
output_dir = os.path.join(logs_dir, model_fname) + output_dir_postfix
return output_dir
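# Example (illustrative): generate_output_dir(None, "convLSTM", "lr=0.001",
# "logs", "") maps '=' and ',' to '.', yielding "logs/model.convLSTM.lr.0.001".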
def get_model_hparams_dict(model_hparams_dict):
    """
    Load the model hyperparameters from a json file and return them as a dict
    """
    model_hparams_dict_load = {}
    if model_hparams_dict:
        with open(model_hparams_dict) as f:
            model_hparams_dict_load.update(json.loads(f.read()))
    return model_hparams_dict_load
def resume_checkpoint(resume,checkpoint,output_dir):
"""
Resume the existing model checkpoints and set checkpoint directory
"""
if resume:
if checkpoint:
raise ValueError('resume and checkpoint cannot both be specified')
checkpoint = output_dir
return checkpoint
def set_seed(seed):
if seed is not None:
tf.set_random_seed(seed)
np.random.seed(seed)
random.seed(seed)
def load_params_from_checkpoints_dir(model_hparams_dict,checkpoint,dataset,model):
model_hparams_dict_load = {}
if model_hparams_dict:
with open(model_hparams_dict) as f:
model_hparams_dict_load.update(json.loads(f.read()))
if checkpoint:
checkpoint_dir = os.path.normpath(checkpoint)
if not os.path.isdir(checkpoint):
checkpoint_dir, _ = os.path.split(checkpoint_dir)
if not os.path.exists(checkpoint_dir):
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir)
with open(os.path.join(checkpoint_dir, "options.json")) as f:
print("loading options from checkpoint %s" % args.checkpoint)
options = json.loads(f.read())
dataset = dataset or options['dataset']
model = model or options['model']
try:
with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f:
model_hparams_dict_load.update(json.loads(f.read()))
except FileNotFoundError:
print("model_hparams.json was not loaded because it does not exist")
return dataset, model, model_hparams_dict_load
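# Illustrative example (not in the original): with checkpoint
# "logs/convLSTM/model-200000", checkpoint_dir is reduced to "logs/convLSTM",
# and the dataset/model names plus model_hparams.json are restored from the
# options saved there by a previous training run.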
def setup_dataset(dataset,input_dir,val_input_dir):
VideoDataset = datasets.get_dataset_class(dataset)
train_dataset = VideoDataset(
input_dir,
mode='train')
val_dataset = VideoDataset(
val_input_dir or input_dir,
mode='val')
variable_scope = tf.get_variable_scope()
variable_scope.set_use_resource(True)
return train_dataset,val_dataset,variable_scope
def setup_model(model,model_hparams_dict,train_dataset,model_hparams):
"""
Set up model instance
"""
VideoPredictionModel = models.get_model_class(model)
hparams_dict = dict(model_hparams_dict)
hparams_dict.update({
'context_frames': train_dataset.hparams.context_frames,
'sequence_length': train_dataset.hparams.sequence_length,
'repeat': train_dataset.hparams.time_shift,
})
model = VideoPredictionModel(
hparams_dict=hparams_dict,
hparams=model_hparams)
return model
def save_dataset_model_params_to_checkpoint_dir(args,output_dir,train_dataset,model):
if not os.path.exists(output_dir):
os.makedirs(output_dir)
with open(os.path.join(output_dir, "options.json"), "w") as f:
f.write(json.dumps(vars(args), sort_keys=True, indent=4))
with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f:
f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4))
    with open(os.path.join(output_dir, "model_hparams.json"), "w") as f:
f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4))
return None
def make_dataset_iterator(train_dataset, val_dataset, batch_size ):
train_tf_dataset = train_dataset.make_dataset_v2(batch_size)
train_iterator = train_tf_dataset.make_one_shot_iterator()
# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
train_handle = train_iterator.string_handle()
val_tf_dataset = val_dataset.make_dataset_v2(batch_size)
val_iterator = val_tf_dataset.make_one_shot_iterator()
val_handle = val_iterator.string_handle()
#iterator = tf.data.Iterator.from_string_handle(
# train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
inputs = train_iterator.get_next()
val = val_iterator.get_next()
    return inputs, train_handle, val_handle
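# Note (illustrative sketch of the intended TF1 feedable-iterator pattern; the
# generic iterator is commented out above): string_handle() is meant to feed a
# handle placeholder so one graph can switch between train and val pipelines:
#   handle = tf.placeholder(tf.string, shape=[])
#   iterator = tf.data.Iterator.from_string_handle(
#       handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
#   inputs = iterator.get_next()
#   sess.run(fetches, feed_dict={handle: sess.run(val_handle)})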
def plot_train(train_losses,val_losses,output_dir):
iterations = list(range(len(train_losses)))
plt.plot(iterations, train_losses, 'g', label='Training loss')
plt.plot(iterations, val_losses, 'b', label='validation loss')
plt.title('Training and Validation loss')
plt.xlabel('Iterations')
plt.ylabel('Loss')
plt.legend()
plt.savefig(os.path.join(output_dir,'plot_train.png'))
def save_results_to_dict(results_dict,output_dir):
with open(os.path.join(output_dir,"results.json"),"w") as fp:
json.dump(results_dict,fp)
def save_results_to_pkl(train_losses,val_losses, output_dir):
with open(os.path.join(output_dir,"train_losses.pkl"),"wb") as f:
pkl.dump(train_losses,f)
with open(os.path.join(output_dir,"val_losses.pkl"),"wb") as f:
pkl.dump(val_losses,f)
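# Illustrative read-back (not in the original script): the pickled loss curves
# can be reloaded for offline analysis, e.g.
#   with open(os.path.join(output_dir, "train_losses.pkl"), "rb") as f:
#       train_losses = pkl.load(f)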
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories "
"train, val, test, etc, or a directory containing "
"the tfrecords")
parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir")
parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified")
parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. "
"default is logs_dir/model_fname, where model_fname consists of "
"information from model and model_hparams")
parser.add_argument("--output_dir_postfix", default="")
parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)")
parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.')
parser.add_argument("--dataset", type=str, help="dataset class name")
parser.add_argument("--model", type=str, help="model class name")
parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters")
parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters")
parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="fraction of gpu memory to use")
parser.add_argument("--seed",default=1234, type=int)
args = parser.parse_args()
#Set seed
set_seed(args.seed)
#setup output directory
args.output_dir = generate_output_dir(args.output_dir, args.model, args.model_hparams, args.logs_dir, args.output_dir_postfix)
    #resume from an existing checkpoint and point the checkpoint directory at the output directory
args.checkpoint = resume_checkpoint(args.resume,args.checkpoint,args.output_dir)
    #load the model hparams dict from the json file, together with the dataset and model
    #configuration stored in the checkpoint dir by the previous training run (when resuming)
args.dataset,args.model,model_hparams_dict = load_params_from_checkpoints_dir(args.model_hparams_dict,args.checkpoint,args.dataset,args.model)
print('----------------------------------- Options ------------------------------------')
for k, v in args._get_kwargs():
print(k, "=", v)
print('------------------------------------- End --------------------------------------')
    #set up the training and validation dataset instances
train_dataset,val_dataset,variable_scope = setup_dataset(args.dataset,args.input_dir,args.val_input_dir)
#setup model instance
model=setup_model(args.model,model_hparams_dict,train_dataset,args.model_hparams)
batch_size = model.hparams.batch_size
#Create input and val iterator
inputs, train_handle, val_handle = make_dataset_iterator(train_dataset, val_dataset, batch_size)
#build model graph
#del inputs["T_start"]
model.build_graph(inputs)
    #save all the model and dataset params to the output directory
save_dataset_model_params_to_checkpoint_dir(args,args.output_dir,train_dataset,model)
with tf.name_scope("parameter_count"):
# exclude trainable variables that are replicas (used in multi-gpu setting)
trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables)
parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in trainable_variables])
saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2)
    summary_writer = tf.summary.FileWriter(args.output_dir)
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True)
config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True)
max_epochs = model.hparams.max_epochs #the number of epochs
num_examples_per_epoch = train_dataset.num_examples_per_epoch()
print ("number of exmaples per epoch:",num_examples_per_epoch)
steps_per_epoch = int(num_examples_per_epoch/batch_size)
total_steps = steps_per_epoch * max_epochs
global_step = tf.train.get_or_create_global_step()
#mock total_steps only for fast debugging
#total_steps = 10
print ("Total steps for training:",total_steps)
results_dict = {}
with tf.Session(config=config) as sess:
print("parameter_count =", sess.run(parameter_count))
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
model.restore(sess, args.checkpoint)
sess.graph.finalize()
#start_step = sess.run(model.global_step)
start_step = sess.run(global_step)
print("start_step", start_step)
        # step is the absolute global step; when resuming from a checkpoint,
        # the loop continues from start_step
train_losses=[]
val_losses=[]
run_start_time = time.time()
for step in range(start_step,total_steps):
#global_step = sess.run(global_step)
# +++ Scarlet 20200813
timeit_start = time.time()
# --- Scarlet 20200813
print ("step:", step)
            val_handle_eval = sess.run(val_handle)
            #fetch the variables in the graph; train_op runs one optimization step
            fetches = {"train_op": model.train_op}
            fetches["summary"] = model.summary_op
            model_name = model.__class__.__name__
            if model_name in ("McNetVideoPredictionModel",
                              "VanillaConvLstmVideoPredictionModel",
                              "VanillaVAEVideoPredictionModel"):
                fetches["global_step"] = model.global_step
                fetches["total_loss"] = model.total_loss
            #fetch the specific loss terms only for mcnet
            if model_name == "McNetVideoPredictionModel":
                fetches["L_p"] = model.L_p
                fetches["L_gdl"] = model.L_gdl
                fetches["L_GAN"] = model.L_GAN
            if model_name == "VanillaVAEVideoPredictionModel":
                fetches["latent_loss"] = model.latent_loss
                fetches["recon_loss"] = model.recon_loss
            if model_name == "SAVPVideoPredictionModel":
                fetches["d_loss"] = model.d_loss
                fetches["g_loss"] = model.g_loss
                fetches["d_losses"] = model.d_losses
                fetches["g_losses"] = model.g_losses
            #run the training step exactly once per iteration
            results = sess.run(fetches)
            #for SAVP the total loss is the generator losses
            if model_name == "SAVPVideoPredictionModel":
                train_losses.append(results["g_losses"])
            else:
                train_losses.append(results["total_loss"])
            #fetch the losses for the validation data
            val_fetches = {}
            if model_name == "SAVPVideoPredictionModel":
                val_fetches["total_loss"] = model.g_losses
            else:
                val_fetches["total_loss"] = model.total_loss
            val_fetches["summary"] = model.summary_op
            val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval})
            val_losses.append(val_results["total_loss"])
summary_writer.add_summary(results["summary"])
summary_writer.add_summary(val_results["summary"])
summary_writer.flush()
            # global_step will have the correct step count if we resume from a checkpoint
            # the global step is read before it is incremented
train_epoch = step/steps_per_epoch
print("progress global step %d epoch %0.1f" % (step + 1, train_epoch))
if model.__class__.__name__ == "McNetVideoPredictionModel":
print("Total_loss:{}; L_p_loss:{}; L_gdl:{}; L_GAN: {}".format(results["total_loss"],results["L_p"],results["L_gdl"],results["L_GAN"]))
elif model.__class__.__name__ == "VanillaConvLstmVideoPredictionModel":
print ("Total_loss:{}".format(results["total_loss"]))
elif model.__class__.__name__ == "SAVPVideoPredictionModel":
print("Total_loss/g_losses:{}; d_losses:{}; g_loss:{}; d_loss: {}".format(results["g_losses"],results["d_losses"],results["g_loss"],results["d_loss"]))
elif model.__class__.__name__ == "VanillaVAEVideoPredictionModel":
print("Total_loss:{}; latent_losses:{}; reconst_loss:{}".format(results["total_loss"],results["latent_loss"],results["recon_loss"]))
else:
print ("The model name does not exist")
#print("saving model to", args.output_dir)
saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)
# +++ Scarlet 20200813
timeit_end = time.time()
# --- Scarlet 20200813
print("time needed for this step", timeit_end - timeit_start, ' s')
train_time = time.time() - run_start_time
results_dict = {"train_time":train_time,
"total_steps":total_steps}
save_results_to_dict(results_dict,args.output_dir)
save_results_to_pkl(train_losses, val_losses, args.output_dir)
print("train_losses:",train_losses)
print("val_losses:",val_losses)
plot_train(train_losses,val_losses,args.output_dir)
print("Done")
# +++ Scarlet 20200814
print("Total training time:", train_time/60., "min")
    # --- Scarlet 20200814
if __name__ == '__main__':
main()