Skip to content
Snippets Groups Projects
Commit cd1c7255 authored by gong1's avatar gong1
Browse files

Update train_dummy to save the train/validation losses and training info

parent 9bf4eda3
No related branches found
No related tags found
No related merge requests found
Pipeline #40549 passed
......@@ -8,8 +8,8 @@
#SBATCH --cpus-per-task=1
#SBATCH --output=DataExtraction-out.%j
#SBATCH --error=DataExtraction-err.%j
#SBATCH --time=00:20:00
#SBATCH --partition=devel
#SBATCH --time=05:20:00
#SBATCH --partition=batch
#SBATCH --mail-type=ALL
#SBATCH --mail-user=b.gong@fz-juelich.de
......@@ -23,8 +23,8 @@ module load h5py/2.9.0-Python-3.6.8
module load mpi4py/3.0.1-Python-3.6.8
module load netcdf4-python/1.5.0.1-Python-3.6.8
#srun python ../../workflow_parallel_frame_prediction/DataExtraction/mpi_stager_v2.py --source_dir /p/fastdata/slmet/slmet111/met_data/ecmwf/era5/nc/2014/ --destination_dir ${SAVE_DIR}/extractedData/2014
year=2012
srun python ../../workflow_parallel_frame_prediction/DataExtraction/mpi_stager_v2.py --source_dir /p/fastdata/slmet/slmet111/met_data/ecmwf/era5/nc/${year}/ --destination_dir ${SAVE_DIR}/extractedData/${year}
# 2tier pystager
srun python ../../workflow_parallel_frame_prediction/DataExtraction/main_single_master.py --source_dir /p/fastdata/slmet/slmet111/met_data/ecmwf/era5/nc/2013/ --destination_dir ${SAVE_DIR}/extractedData/2013
#srun python ../../workflow_parallel_frame_prediction/DataExtraction/main_single_master.py --source_dir /p/fastdata/slmet/slmet111/met_data/ecmwf/era5/nc/${year}/ --destination_dir ${SAVE_DIR}/extractedData/${year}
{
"batch_size": 10,
"lr": 0.001,
"max_epochs":100,
"max_epochs":30,
"context_frames":10,
"sequence_length":20
......
......@@ -12,6 +12,14 @@ import numpy as np
import tensorflow as tf
from video_prediction import datasets, models
import matplotlib.pyplot as plt
from json import JSONEncoder
import pickle as pkl
class NumpyArrayEncoder(JSONEncoder):
    """JSON encoder that converts NumPy values to JSON-native Python types.

    Used with ``json.dump(..., cls=NumpyArrayEncoder)`` so that results
    dictionaries containing NumPy arrays or scalars can be serialized.
    """

    def default(self, obj):
        # Arrays become (nested) lists.
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # NumPy scalar types (e.g. np.int64, np.float32) are not handled by
        # the base encoder and would raise TypeError — unwrap them.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        # Anything else: defer to the base class, which raises TypeError.
        return super().default(obj)
def add_tag_suffix(summary, tag_suffix):
summary_proto = tf.Summary()
......@@ -159,7 +167,15 @@ def plot_train(train_losses,val_losses,output_dir):
plt.legend()
plt.savefig(os.path.join(output_dir,'plot_train.png'))
def save_results_to_dict(results_dict, output_dir):
    """Serialize *results_dict* as JSON to ``<output_dir>/results.json``."""
    target_path = os.path.join(output_dir, "results.json")
    with open(target_path, "w") as out_file:
        json.dump(results_dict, out_file)
def save_results_to_pkl(train_losses, val_losses, output_dir):
    """Pickle the training and validation loss histories into *output_dir*.

    Writes ``train_losses.pkl`` and ``val_losses.pkl`` alongside each other.
    """
    for losses, fname in ((train_losses, "train_losses.pkl"),
                          (val_losses, "val_losses.pkl")):
        with open(os.path.join(output_dir, fname), "wb") as out_file:
            pkl.dump(losses, out_file)
def main():
......@@ -237,7 +253,10 @@ def main():
print ("number of exmaples per epoch:",num_examples_per_epoch)
steps_per_epoch = int(num_examples_per_epoch/batch_size)
total_steps = steps_per_epoch * max_epochs
#mock total_steps only for fast debugging
#total_steps = 10
print ("Total steps for training:",total_steps)
results_dict = {}
with tf.Session(config=config) as sess:
print("parameter_count =", sess.run(parameter_count))
sess.run(tf.global_variables_initializer())
......@@ -249,7 +268,7 @@ def main():
# step is relative to the start_step
train_losses=[]
val_losses=[]
run_start_time = time.time()
for step in range(total_steps):
global_step = sess.run(model.global_step)
print ("global_step:", global_step)
......@@ -294,8 +313,13 @@ def main():
else:
print ("The model name does not exist")
print("saving model to", args.output_dir)
#print("saving model to", args.output_dir)
saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)#
train_time = time.time() - run_start_time
results_dict = {"train_time":train_time,
"total_steps":total_steps}
save_results_to_dict(results_dict,args.output_dir)
save_results_to_pkl(train_losses, val_losses, args.output_dir)
print("train_losses:",train_losses)
print("val_losses:",val_losses)
plot_train(train_losses,val_losses,args.output_dir)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment