Skip to content
Snippets Groups Projects
Commit 745e433b authored by stadtler1's avatar stadtler1
Browse files

Minor changes in order to get the code running, also added time tracking for...

Minor changes in order to get the code running, also added time tracking for each iteration in train_moving_mnist.py
parent d211a54e
Branches scarlet_issue#019_merge_loliver_mnist
Tags
No related merge requests found
Pipeline #44267 failed
...@@ -10,7 +10,7 @@ ...@@ -10,7 +10,7 @@
#SBATCH --gres=gpu:1 #SBATCH --gres=gpu:1
#SBATCH --partition=develgpus #SBATCH --partition=develgpus
#SBATCH --mail-type=ALL #SBATCH --mail-type=ALL
#SBATCH --mail-user=b.gong@fz-juelich.de #SBATCH --mail-user=s.stadtler@fz-juelich.de
##jutil env activate -p cjjsc42 ##jutil env activate -p cjjsc42
# Name of virtual environment # Name of virtual environment
......
...@@ -4,8 +4,8 @@ ...@@ -4,8 +4,8 @@
#SBATCH --ntasks=1 #SBATCH --ntasks=1
##SBATCH --ntasks-per-node=1 ##SBATCH --ntasks-per-node=1
#SBATCH --cpus-per-task=1 #SBATCH --cpus-per-task=1
#SBATCH --output=train_era5-out.%j #SBATCH --output=train_moving_mnist-out.%j
#SBATCH --error=train_era5-err.%j #SBATCH --error=train_moving_mnist-err.%j
#SBATCH --time=00:20:00 #SBATCH --time=00:20:00
#SBATCH --gres=gpu:1 #SBATCH --gres=gpu:1
#SBATCH --partition=develgpus #SBATCH --partition=develgpus
...@@ -36,9 +36,9 @@ fi ...@@ -36,9 +36,9 @@ fi
source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist source_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/moving_mnist
destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/moving_mnist destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/models/moving_mnist
# for choosing the model, convLSTM,savp, mcnet,vae,convLSTM_Loliver # for choosing the model, convLSTM,savp, mcnet,vae
model=convLSTM model=convLSTM
model_hparams=../hparams/era5/${model}/model_hparams.json model_hparams=../hparams/era5/${model}/model_hparams.json
# run training # run training
srun python ../scripts/train_moving_mnist.py --input_dir ${source_dir}/tfrecords/ --dataset moving_mnist --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/${model}/ --checkpoint ${destination_dir}/${model}/ srun python ../scripts/train_dummy_moving_mnist.py --input_dir ${source_dir}/tfrecords/ --dataset moving_mnist --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/${model}/
...@@ -276,8 +276,10 @@ def main(): ...@@ -276,8 +276,10 @@ def main():
val_losses=[] val_losses=[]
run_start_time = time.time() run_start_time = time.time()
for step in range(start_step,total_steps): for step in range(start_step,total_steps):
#global_step = sess.run(global_step):q #global_step = sess.run(global_step)
# +++ Scarlet 20200813
timeit_start = time.time()
# --- Scarlet 20200813
print ("step:", step) print ("step:", step)
val_handle_eval = sess.run(val_handle) val_handle_eval = sess.run(val_handle)
...@@ -342,7 +344,11 @@ def main(): ...@@ -342,7 +344,11 @@ def main():
print ("The model name does not exist") print ("The model name does not exist")
#print("saving model to", args.output_dir) #print("saving model to", args.output_dir)
saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)# saver.save(sess, os.path.join(args.output_dir, "model"), global_step=step)
# +++ Scarlet 20200813
timeit_end = time.time()
# --- Scarlet 20200813
print("time needed for this step", timeit_end - timeit_start, ' s')
train_time = time.time() - run_start_time train_time = time.time() - run_start_time
results_dict = {"train_time":train_time, results_dict = {"train_time":train_time,
"total_steps":total_steps} "total_steps":total_steps}
...@@ -352,6 +358,9 @@ def main(): ...@@ -352,6 +358,9 @@ def main():
print("val_losses:",val_losses) print("val_losses:",val_losses)
plot_train(train_losses,val_losses,args.output_dir) plot_train(train_losses,val_losses,args.output_dir)
print("Done") print("Done")
# +++ Scarlet 20200814
print("Total training time:", train_time/60., "min")
# --- Scarlet 20200814
if __name__ == '__main__': if __name__ == '__main__':
main() main()
...@@ -74,6 +74,8 @@ def conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.co ...@@ -74,6 +74,8 @@ def conv_layer(inputs, kernel_size, stride, num_features, idx, initializer=tf.co
conv_rect = tf.nn.elu(conv_biased, name = '{0}_conv'.format(idx)) conv_rect = tf.nn.elu(conv_biased, name = '{0}_conv'.format(idx))
elif activate == "leaky_relu": elif activate == "leaky_relu":
conv_rect = tf.nn.leaky_relu(conv_biased, name = '{0}_conv'.format(idx)) conv_rect = tf.nn.leaky_relu(conv_biased, name = '{0}_conv'.format(idx))
elif activate == "sigmoid":
conv_rect = tf.nn.sigmoid(conv_biased, name = '{0}_conv'.format(idx))
else: else:
raise ("activation function is not correct") raise ("activation function is not correct")
return conv_rect return conv_rect
......
...@@ -21,7 +21,6 @@ def get_model_class(model): ...@@ -21,7 +21,6 @@ def get_model_class(model):
'vae': 'VanillaVAEVideoPredictionModel', 'vae': 'VanillaVAEVideoPredictionModel',
'convLSTM': 'VanillaConvLstmVideoPredictionModel', 'convLSTM': 'VanillaConvLstmVideoPredictionModel',
'mcnet': 'McNetVideoPredictionModel', 'mcnet': 'McNetVideoPredictionModel',
'convLSTM_Loliver': "ConvLstmLoliverVideoPredictionModel"
} }
model_class = model_mappings.get(model, model) model_class = model_mappings.get(model, model)
model_class = globals().get(model_class) model_class = globals().get(model_class)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment