Commit 309adfcb authored by Michael Langguth

Merge branch 'bing_issue#138_create_checkpoints_folders' into develop

parents 3c069964 873a8821
Pipeline #76965 passed
@@ -24,12 +24,13 @@ import matplotlib.pyplot as plt
import pickle as pkl
from model_modules.video_prediction.utils import tf_utils
from general_utils import *
import math
class TrainModel(object):
def __init__(self, input_dir: str = None, output_dir: str = None, datasplit_dict: str = None,
model_hparams_dict: str = None, model: str = None, checkpoint: str = None, dataset: str = None,
gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01):
gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.01, frac_save_model_start: float = None, prob_save_model: float = None):
"""
Class instance for training the models
:param input_dir: parent directory under which the "pickle" and "tfrecords" file directories are located
@@ -46,6 +47,8 @@ class TrainModel(object):
steps per epoch is denoted here, e.g. 0.01 with 1000 iteration steps per epoch results
in a diagnosing interval of 10 iteration steps (= interval over which the validation loss
is averaged to identify the best model performance)
:param frac_save_model_start: fraction of the total iteration steps after which checkpoint saving starts
:param prob_save_model: fraction of the total iteration steps between checkpoint saves (controls how frequently the model is saved)
"""
self.input_dir = os.path.normpath(input_dir)
self.output_dir = os.path.normpath(output_dir)
@@ -58,6 +61,8 @@ class TrainModel(object):
self.seed = seed
self.args = args
self.diag_intv_frac = diag_intv_frac
self.frac_save_model_start = frac_save_model_start
self.prob_save_model = prob_save_model
# for diagnosing and saving the model during training
self.saver_loss = None # set in create_fetches_for_train-method
self.saver_loss_name = None # set in create_fetches_for_train-method
@@ -77,6 +82,8 @@ class TrainModel(object):
self.create_saver_and_writer()
self.setup_gpu_config()
self.calculate_samples_and_epochs()
self.calculate_checkpoint_saver_conf()
def set_seed(self):
"""
@@ -214,7 +221,7 @@ class TrainModel(object):
"""
Create a saver to save the model's latest checkpoints, and a summary writer to store the train/val metrics
"""
self.saver = tf.train.Saver(var_list=self.video_model.saveable_variables, max_to_keep=2)
self.saver = tf.train.Saver(var_list=self.video_model.saveable_variables, max_to_keep=None)
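# Note: with max_to_keep=None the tf.train.Saver keeps all written checkpoints instead of only the
# two most recent ones, which matches the per-step checkpoint folders introduced further below.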
self.summary_writer = tf.summary.FileWriter(self.output_dir)
def setup_gpu_config(self):
@@ -240,6 +247,26 @@ class TrainModel(object):
print("%{}: Batch size: {}; max_epochs: {}; num_samples per epoch: {}; steps_per_epoch: {}, total steps: {}"
.format(method, batch_size,max_epochs, self.num_examples,self.steps_per_epoch,self.total_steps))
def calculate_checkpoint_saver_conf(self):
"""
Calculate the step at which checkpoint saving starts and the interval (in steps) at which the model is saved
"""
method = TrainModel.calculate_checkpoint_saver_conf.__name__
if not hasattr(self, "total_steps"):
raise RuntimeError("%{0}: 'calculate_samples_and_epochs' must be called to set total_steps before calling {0}".format(method))
if self.prob_save_model > 1 or self.prob_save_model < 0:
raise ValueError("prob_save_model must be between 0 and 1")
if self.frac_save_model_start > 1 or self.frac_save_model_start < 0:
raise ValueError("frac_save_model_start must be between 0 and 1")
self.start_checkpoint_step = int(math.ceil(self.total_steps * self.frac_save_model_start))
self.saver_interval_step = int(math.ceil(self.total_steps * self.prob_save_model))
print("The model will be saved starting from step {} with {} interval step ".format(str(self.start_checkpoint_step),self.saver_interval_step))
def restore(self, sess, checkpoints, restore_to_checkpoint_mapping=None):
"""
Restore the model's checkpoints if a checkpoint is given
@@ -283,6 +310,22 @@ class TrainModel(object):
val_losses = pkl.load(f)
return train_losses,val_losses
def create_checkpoints_folder(self, step:int=None):
"""
Create a folder to store the checkpoint at a given step
:param step: the step at which the checkpoint is saved
:return: path to the directory where the model is saved
"""
dir_name = "checkpoint_" + str(step)
full_dir_name = os.path.join(self.output_dir,dir_name)
if os.path.isfile(os.path.join(full_dir_name,"checkpoints")):
print("The checkpoint at step {} exists".format(step))
else:
os.mkdir(full_dir_name)
return full_dir_name
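# Usage sketch (hypothetical output_dir, for illustration only): with self.output_dir = "./results",
# create_checkpoints_folder(step=600) creates and returns "./results/checkpoint_600",
# into which the saver subsequently writes the checkpoint files for that step.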
def train_model(self):
"""
Start session and train the model by looping over all iteration steps
@@ -303,7 +346,7 @@ class TrainModel(object):
# initialize auxiliary variables
time_per_iteration = []
run_start_time = time.time()
val_loss_min = 999.
# perform iteration
for step in range(start_step, self.total_steps):
timeit_start = time.time()
@@ -324,12 +367,12 @@ class TrainModel(object):
time_iter = time.time() - timeit_start
time_per_iteration.append(time_iter)
print("%{0}: time needed for this step {1:.3f}s".format(method, time_iter))
if step > self.diag_intv_step and (step % self.diag_intv_step == 0 or step == self.total_steps - 1):
lsave, val_loss_min = TrainModel.set_model_saver_flag(val_losses, val_loss_min, self.diag_intv_step)
# save best and final model state
if lsave or step == self.total_steps - 1:
self.saver.save(sess, os.path.join(self.output_dir, "model_best" if lsave else "model_last"),
global_step=step)
if step > self.start_checkpoint_step and (step % self.saver_interval_step == 0 or step == self.total_steps - 1):
# create a checkpoint folder for this step
full_dir_name = self.create_checkpoints_folder(step=step)
self.saver.save(sess, os.path.join(full_dir_name, "model_"), global_step=step)
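# Note (expected behaviour of the standard tf.train.Saver.save call above, step number hypothetical):
# files prefixed "model_-<step>" (e.g. model_-610.index, model_-610.meta, model_-610.data-00000-of-00001)
# are written into the per-step folder created by create_checkpoints_folder.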
# pickle file and plots are always created
TrainModel.save_results_to_pkl(train_losses, val_losses, self.output_dir)
TrainModel.plot_train(train_losses, val_losses, self.saver_loss_name, self.output_dir)
@@ -551,6 +594,8 @@ def main():
parser.add_argument("--model", type=str, help="Model class name")
parser.add_argument("--model_hparams_dict", type=str, help="JSON-file of model hyperparameters")
parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="Fraction of gpu memory to use")
parser.add_argument("--frac_save_model_start", type=float,default=0.6,help="fraction of the start step for saving checkpoint")
parser.add_argument("--prob_save_model", type = float, default = 0.01, help = "probabability that model are saved to checkpoint (control the frequences of saving model")
parser.add_argument("--seed", default=1234, type=int)
args = parser.parse_args()
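# Example invocation (script name, paths and model name are hypothetical and only illustrate the new flags;
# the remaining required arguments are omitted here):
# python main_train_models.py --model convLSTM --model_hparams_dict hparams.json \
#     --gpu_mem_frac 0.9 --frac_save_model_start 0.6 --prob_save_model 0.01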
@@ -559,7 +604,8 @@ def main():
#create a training instance
train_case = TrainModel(input_dir=args.input_dir,output_dir=args.output_dir,datasplit_dict=args.datasplit_dict,
model_hparams_dict=args.model_hparams_dict,model=args.model,checkpoint=args.checkpoint, dataset=args.dataset,
gpu_mem_frac=args.gpu_mem_frac,seed=args.seed,args=args)
gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args, frac_save_model_start=args.frac_save_model_start,
prob_save_model=args.prob_save_model)
print('----------------------------------- Options ------------------------------------')
for k, v in args._get_kwargs():
......