diff --git a/video_prediction_savp/scripts/train_dummy.py b/video_prediction_savp/scripts/train_dummy.py index 4dd111321c1584c029ef23f91df9b65e47d125cc..0cdaafb09ca725e4f38c9381f59ea0267f9dc345 100644 --- a/video_prediction_savp/scripts/train_dummy.py +++ b/video_prediction_savp/scripts/train_dummy.py @@ -192,9 +192,21 @@ def save_results_to_pkl(train_losses,val_losses, output_dir): pkl.dump(train_losses,f) with open(os.path.join(output_dir,"val_losses.pkl"),"wb") as f: pkl.dump(val_losses,f) - + +# +++ Scarlet 20200917 +def save_timing_to_pkl(total_time,training_time,time_per_iteration, output_dir): + with open(os.path.join(output_dir,"timing_total_time.pkl"),"wb") as f: + pkl.dump(total_time,f) + with open(os.path.join(output_dir,"timing_training_time.pkl"),"wb") as f: + pkl.dump(training_time,f) + with open(os.path.join(output_dir,"timing_per_iteration_time.pkl"),"wb") as f: + pkl.dump(time_per_iteration,f) +# --- Scarlet 20200917 def main(): + # +++ Scarlet 20200917 + timeit_start_total_time = time.time() + # --- Scarlet 20200917 parser = argparse.ArgumentParser() parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories " @@ -273,7 +285,15 @@ def main(): print ("number of exmaples per epoch:",num_examples_per_epoch) steps_per_epoch = int(num_examples_per_epoch/batch_size) #number of steps totally equal to the number of steps per each echo multiple by number of epochs - total_steps = steps_per_epoch * max_epochs + + # Please comment in again this line: + #total_steps = steps_per_epoch * max_epochs + + #+++++ Scarlet Booster testing ONLY! + total_steps = 1 + #----- Scarlet + + global_step = tf.train.get_or_create_global_step() #mock total_steps only for fast debugging #total_steps = 10 @@ -292,7 +312,10 @@ def main(): # step is relative to the start_step train_losses=[] val_losses=[] - run_start_time = time.time() + # +++ Scarlet 20200917 + time_per_iteration = [] + # --- Scarlet 20200917 + run_start_time = time.time() for step in range(start_step,total_steps): #global_step = sess.run(global_step) # +++ Scarlet 20200813 @@ -367,6 +390,7 @@ def main(): timeit_end = time.time() # --- Scarlet 20200813 print("time needed for this step", timeit_end - timeit_start, ' s') + time_per_iteration.append(timeit_end - timeit_start) if step % 20 == 0: # I save the pickle file and plot here inside the loop in case the training process cannot finished after job is done. save_results_to_pkl(train_losses,val_losses,args.output_dir) @@ -385,6 +409,11 @@ def main(): # +++ Scarlet 20200814 print("Total training time:", train_time/60., "min") # +++ Scarlet 20200814 + # +++ Scarlet 20200917 + total_run_time = time.time() - timeit_start_total_time + print("Total run time:", total_run_time/60., "min") + save_timing_to_pkl(total_run_time,train_time,time_per_iteration, args.output_dir) + # +++ Scarlet 20200917 if __name__ == '__main__': main()