diff --git a/test/run_pytest.sh b/test/run_pytest.sh index 83220d34a51379e93add931ae6e03e9491b5bce4..6aae33cf0312455efbaa95bbd491440e6d672b2e 100644 --- a/test/run_pytest.sh +++ b/test/run_pytest.sh @@ -2,7 +2,7 @@ # Name of virtual environment #VIRT_ENV_NAME="vp_new_structure" -VIRT_ENV_NAME="juwels_env" +VIRT_ENV_NAME="env_hdfml" if [ -z ${VIRTUAL_ENV} ]; then if [[ -f ../video_prediction_tools/${VIRT_ENV_NAME}/bin/activate ]]; then @@ -24,7 +24,7 @@ fi source ../video_prediction_tools/env_setup/modules_train.sh ##Test for preprocess moving mnist #python -m pytest test_prepare_moving_mnist_data.py -python -m pytest test_train_moving_mnist_data.py +#python -m pytest test_train_moving_mnist_data.py #Test for process step2 #python -m pytest test_data_preprocess_step2.py #python -m pytest test_era5_data.py @@ -33,5 +33,5 @@ python -m pytest test_train_moving_mnist_data.py #rm /p/project/deepacf/deeprain/video_prediction_shared_folder/models/test/* #python -m pytest test_train_model_era5.py #python -m pytest test_vanilla_vae_model.py -#python -m pytest test_visualize_postprocess.py +python -m pytest test_visualize_postprocess.py #python -m pytest test_meta_postprocess.py diff --git a/test/test_visualize_postprocess.py b/test/test_visualize_postprocess.py index 458430f632cf2d2ccf374ef1ed5e76dc3a7b8b50..288dad25cfe86b4c8ce03ea418f79bf76851bfeb 100644 --- a/test/test_visualize_postprocess.py +++ b/test/test_visualize_postprocess.py @@ -7,14 +7,13 @@ from main_scripts.main_visualize_postprocess import * import pytest import numpy as np import datetime - +from netCDF4 import Dataset, date2num ########Test case 1################ results_dir = "/p/project/deepacf/deeprain/video_prediction_shared_folder/results/era5-Y2007-2019M01to12-92x56-3840N0000E-2t_tcc_t_850/savp/20210324T120926_ji4_savp_cv12" checkpoint = "/p/project/deepacf/deeprain/video_prediction_shared_folder/models/era5-Y2007-2019M01to12-92x56-3840N0000E-2t_tcc_t_850/savp/20210324T120926_ji4_savp_cv12" mode = "test" batch_size = 2 -num_samples = 16 num_stochastic_samples = 2 gpu_mem_frac = 0.5 seed = 12345 @@ -32,17 +31,9 @@ args = MyClass(results_dir) @pytest.fixture(scope="module") def vis_case1(): return Postprocess(results_dir=results_dir,checkpoint=checkpoint, - mode=mode,batch_size=batch_size, num_samples=num_samples, + mode=mode,batch_size=batch_size, num_stochastic_samples=num_stochastic_samples, seed=seed,args=args,eval_metrics=eval_metrics) -######instance2 -num_samples2 = 200000 -@pytest.fixture(scope="module") -def vis_case2(): - return Postprocess(results_dir=results_dir, checkpoint=checkpoint, - mode=mode, batch_size=batch_size, num_samples=num_samples2, - num_stochastic_samples=num_stochastic_samples, - seed=seed, args=args,eval_metrics=eval_metrics) def test_load_jsons(vis_case1): assert vis_case1.dataset == "era5" @@ -60,28 +51,21 @@ def test_get_metadata(vis_case1): def test_setup_test_dataset(vis_case1): vis_case1.test_dataset.mode == mode -def test_setup_num_samples_per_epoch(vis_case1): - vis_case1.setup_test_dataset() - vis_case1.setup_num_samples_per_epoch() - assert vis_case1.num_samples_per_epoch == num_samples - def test_get_data_params(vis_case1): - vis_case1.get_data_params() assert vis_case1.context_frames == 12 assert vis_case1.future_length == 12 def test_run_deterministic(vis_case1): - vis_case1() vis_case1.init_session() vis_case1.restore(vis_case1.sess,vis_case1.checkpoint) vis_case1.sample_ind = 0 - vis_case1.input_results,vis_case1.input_images_denorm_all, vis_case1.t_starts = vis_case1.run_and_plot_inputs_per_batch() - assert len(vis_case1.t_starts_results) == batch_size + vis_case1.input_results,vis_case1.input_images_denorm_all, vis_case1.t_starts = vis_case1.get_input_data_per_batch(vis_case1.inputs) + assert len(vis_case1.t_starts) == batch_size ts_1 = vis_case1.t_starts[0][0] year = str(ts_1)[:4] month = str(ts_1)[4:6] filename = "ecmwf_era5_" + str(ts_1)[2:] + ".nc" - fl = os.path.join("/p/project/deepacf/deeprain/video_prediction_shared_folder/extractedData",year, month, filename) + fl = os.path.join("/p/scratch/deepacf/deeprain/ambs_era5/extractedData",year, month, filename) print("netCDF file name:",fl) with Dataset(fl,"r") as data_file: t2_var = data_file.variables["2t"][0,:,:]