Skip to content
Snippets Groups Projects
Commit c9200cc1 authored by gong1's avatar gong1
Browse files

Adopt unitest for main_visulize_postprocess.py

parent 609a7034
Branches
No related tags found
No related merge requests found
Pipeline #69403 passed
...@@ -2,7 +2,7 @@ ...@@ -2,7 +2,7 @@
# Name of virtual environment # Name of virtual environment
#VIRT_ENV_NAME="vp_new_structure" #VIRT_ENV_NAME="vp_new_structure"
VIRT_ENV_NAME="juwels_env" VIRT_ENV_NAME="env_hdfml"
if [ -z ${VIRTUAL_ENV} ]; then if [ -z ${VIRTUAL_ENV} ]; then
if [[ -f ../video_prediction_tools/${VIRT_ENV_NAME}/bin/activate ]]; then if [[ -f ../video_prediction_tools/${VIRT_ENV_NAME}/bin/activate ]]; then
...@@ -24,7 +24,7 @@ fi ...@@ -24,7 +24,7 @@ fi
source ../video_prediction_tools/env_setup/modules_train.sh source ../video_prediction_tools/env_setup/modules_train.sh
##Test for preprocess moving mnist ##Test for preprocess moving mnist
#python -m pytest test_prepare_moving_mnist_data.py #python -m pytest test_prepare_moving_mnist_data.py
python -m pytest test_train_moving_mnist_data.py #python -m pytest test_train_moving_mnist_data.py
#Test for process step2 #Test for process step2
#python -m pytest test_data_preprocess_step2.py #python -m pytest test_data_preprocess_step2.py
#python -m pytest test_era5_data.py #python -m pytest test_era5_data.py
...@@ -33,5 +33,5 @@ python -m pytest test_train_moving_mnist_data.py ...@@ -33,5 +33,5 @@ python -m pytest test_train_moving_mnist_data.py
#rm /p/project/deepacf/deeprain/video_prediction_shared_folder/models/test/* #rm /p/project/deepacf/deeprain/video_prediction_shared_folder/models/test/*
#python -m pytest test_train_model_era5.py #python -m pytest test_train_model_era5.py
#python -m pytest test_vanilla_vae_model.py #python -m pytest test_vanilla_vae_model.py
#python -m pytest test_visualize_postprocess.py python -m pytest test_visualize_postprocess.py
#python -m pytest test_meta_postprocess.py #python -m pytest test_meta_postprocess.py
...@@ -7,14 +7,13 @@ from main_scripts.main_visualize_postprocess import * ...@@ -7,14 +7,13 @@ from main_scripts.main_visualize_postprocess import *
import pytest import pytest
import numpy as np import numpy as np
import datetime import datetime
from netCDF4 import Dataset, date2num
########Test case 1################ ########Test case 1################
results_dir = "/p/project/deepacf/deeprain/video_prediction_shared_folder/results/era5-Y2007-2019M01to12-92x56-3840N0000E-2t_tcc_t_850/savp/20210324T120926_ji4_savp_cv12" results_dir = "/p/project/deepacf/deeprain/video_prediction_shared_folder/results/era5-Y2007-2019M01to12-92x56-3840N0000E-2t_tcc_t_850/savp/20210324T120926_ji4_savp_cv12"
checkpoint = "/p/project/deepacf/deeprain/video_prediction_shared_folder/models/era5-Y2007-2019M01to12-92x56-3840N0000E-2t_tcc_t_850/savp/20210324T120926_ji4_savp_cv12" checkpoint = "/p/project/deepacf/deeprain/video_prediction_shared_folder/models/era5-Y2007-2019M01to12-92x56-3840N0000E-2t_tcc_t_850/savp/20210324T120926_ji4_savp_cv12"
mode = "test" mode = "test"
batch_size = 2 batch_size = 2
num_samples = 16
num_stochastic_samples = 2 num_stochastic_samples = 2
gpu_mem_frac = 0.5 gpu_mem_frac = 0.5
seed = 12345 seed = 12345
...@@ -32,15 +31,7 @@ args = MyClass(results_dir) ...@@ -32,15 +31,7 @@ args = MyClass(results_dir)
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
def vis_case1(): def vis_case1():
return Postprocess(results_dir=results_dir,checkpoint=checkpoint, return Postprocess(results_dir=results_dir,checkpoint=checkpoint,
mode=mode,batch_size=batch_size, num_samples=num_samples, mode=mode,batch_size=batch_size,
num_stochastic_samples=num_stochastic_samples,
seed=seed,args=args,eval_metrics=eval_metrics)
######instance2
num_samples2 = 200000
@pytest.fixture(scope="module")
def vis_case2():
return Postprocess(results_dir=results_dir, checkpoint=checkpoint,
mode=mode, batch_size=batch_size, num_samples=num_samples2,
num_stochastic_samples=num_stochastic_samples, num_stochastic_samples=num_stochastic_samples,
seed=seed,args=args,eval_metrics=eval_metrics) seed=seed,args=args,eval_metrics=eval_metrics)
...@@ -60,28 +51,21 @@ def test_get_metadata(vis_case1): ...@@ -60,28 +51,21 @@ def test_get_metadata(vis_case1):
def test_setup_test_dataset(vis_case1): def test_setup_test_dataset(vis_case1):
vis_case1.test_dataset.mode == mode vis_case1.test_dataset.mode == mode
def test_setup_num_samples_per_epoch(vis_case1):
vis_case1.setup_test_dataset()
vis_case1.setup_num_samples_per_epoch()
assert vis_case1.num_samples_per_epoch == num_samples
def test_get_data_params(vis_case1): def test_get_data_params(vis_case1):
vis_case1.get_data_params()
assert vis_case1.context_frames == 12 assert vis_case1.context_frames == 12
assert vis_case1.future_length == 12 assert vis_case1.future_length == 12
def test_run_deterministic(vis_case1): def test_run_deterministic(vis_case1):
vis_case1()
vis_case1.init_session() vis_case1.init_session()
vis_case1.restore(vis_case1.sess,vis_case1.checkpoint) vis_case1.restore(vis_case1.sess,vis_case1.checkpoint)
vis_case1.sample_ind = 0 vis_case1.sample_ind = 0
vis_case1.input_results,vis_case1.input_images_denorm_all, vis_case1.t_starts = vis_case1.run_and_plot_inputs_per_batch() vis_case1.input_results,vis_case1.input_images_denorm_all, vis_case1.t_starts = vis_case1.get_input_data_per_batch(vis_case1.inputs)
assert len(vis_case1.t_starts_results) == batch_size assert len(vis_case1.t_starts) == batch_size
ts_1 = vis_case1.t_starts[0][0] ts_1 = vis_case1.t_starts[0][0]
year = str(ts_1)[:4] year = str(ts_1)[:4]
month = str(ts_1)[4:6] month = str(ts_1)[4:6]
filename = "ecmwf_era5_" + str(ts_1)[2:] + ".nc" filename = "ecmwf_era5_" + str(ts_1)[2:] + ".nc"
fl = os.path.join("/p/project/deepacf/deeprain/video_prediction_shared_folder/extractedData",year, month, filename) fl = os.path.join("/p/scratch/deepacf/deeprain/ambs_era5/extractedData",year, month, filename)
print("netCDF file name:",fl) print("netCDF file name:",fl)
with Dataset(fl,"r") as data_file: with Dataset(fl,"r") as data_file:
t2_var = data_file.variables["2t"][0,:,:] t2_var = data_file.variables["2t"][0,:,:]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment