Commit 5b3d7a79 authored by gong1

Merge branch 'bing_issue#188_restructure_ambs' into develop

parents 4ed09385 b445f890
Pipeline #121747 failed
Showing 1684 additions and 516 deletions
%% Cell type:code id:5e4295b1-17d3-49eb-b1cc-25cd2c3e38e3 tags:
``` python
#https://stackoverflow.com/questions/55429307/how-to-use-windows-created-by-the-dataset-window-method-in-tensorflow-2-0
import os
import xarray as xr
import numpy as np
import time
video_pred_folder = "/p/home/jusers/gong1/juwels/video_prediction_shared_folder/"
datadir = os.path.join(video_pred_folder, "test_data_roshni")
ds = xr.open_mfdataset(os.path.join(datadir, "*.nc"))
da = ds.to_array(dim="variables").squeeze()
dims = ["time", "lat", "lon"]
max_vars, min_vars = da.max(dim=dims).values, da.min(dim=dims).values
data_arr = np.squeeze(da.values)
```
%% Cell type:code id:92a20eb3-6358-410b-bf63-2d0cf8e38856 tags:
``` python
%%timeit
data_arr.shape
```
%% Cell type:code id:cea4aede-32db-4593-a5aa-d7a242bc960a tags:
``` python
data_arr.shape
data_arr = data_arr.reshape(17520, 3, 56, 92)
```
%% Cell type:code id:903034b5-4706-419d-915c-886790a9201f tags:
``` python
#data_arr = data_arr[:48]
```
%% Cell type:code id:f62c709f-7cf1-41a1-bb78-94dc7e064f3b tags:
``` python
data_arr.shape
```
%% Cell type:code id:2db0b465-92d6-4716-98c1-100df47f0041 tags:
``` python
data_arr[0, 0, 0, 0]
```
%% Cell type:code id:fafabe11-ed9f-40b6-b830-8d78f52dc239 tags:
``` python
data_arr[1, 0, 0, 0]
```
%% Output
280.05115
%% Cell type:code id:e4f5d0cb-56a2-4085-80cf-9c681dc02c5f tags:
``` python
data_arr[2, 0, 0, 0]
```
%% Output
279.88528
%% Cell type:code id:3f5382c4-b4d1-4e11-9401-7113c46e83a7 tags:
``` python
max_vars
```
%% Output
array([317.27255, 1. , 303.1935 ], dtype=float32)
%% Cell type:code id:220c117c-09dd-42fe-a4b4-30e5a7147e57 tags:
``` python
import tensorflow as tf
window_size=24
dataset = tf.data.Dataset.from_tensor_slices(data_arr).window(window_size,shift=1,drop_remainder=True)
dataset = dataset.flat_map(lambda window: window.batch(window_size))
dataset = dataset.batch(3)
```
%% Output
2022-03-17 15:07:39.060539: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
/tmp/ipykernel_15466/2295205246.py in <module>
1 import tensorflow as tf
2 window_size=24
----> 3 dataset = tf.data.Dataset.from_tensor_slices(data_arr).window(window_size,shift=1,drop_remainder=True)
4 dataset = dataset.flat_map(lambda window: window.batch(window_size))
5 dataset = dataset.batch(3)
NameError: name 'data_arr' is not defined
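%% Cell type:code id: tags:
``` python
# Minimal sketch of the sliding-window pattern attempted above, assuming only TensorFlow 2.x and
# NumPy. The toy array stands in for data_arr so this cell runs even when the netCDF data is not
# loaded (which is what triggered the NameError above).
import numpy as np
import tensorflow as tf

toy = np.arange(10, dtype=np.float32)  # 10 "time steps"
window_size = 4
ds_toy = tf.data.Dataset.from_tensor_slices(toy)
ds_toy = ds_toy.window(window_size, shift=1, drop_remainder=True)  # nested datasets of length 4
ds_toy = ds_toy.flat_map(lambda w: w.batch(window_size))  # flatten each window into one tensor
for w in ds_toy.take(3):
    print(w.numpy())  # [0. 1. 2. 3.], [1. 2. 3. 4.], [2. 3. 4. 5.]
```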
%% Cell type:code id:7973bc9c-6671-4499-8d51-9eea02302808 tags:
``` python
def benchmark(dataset, num_epochs=2):
    start_time = time.perf_counter()
    for epoch_num in range(num_epochs):
        for sample in dataset:
            # Performing a training step
            time.sleep(0.01)
    print("Execution time:", time.perf_counter() - start_time)
```
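%% Cell type:code id: tags:
``` python
# Short usage sketch for the benchmark() helper above; it assumes the windowed `dataset`
# from the earlier cell exists in this session. A single epoch keeps the timing run short.
benchmark(dataset, num_epochs=1)
```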
%% Cell type:code id:b33ce29b-e0af-4665-bb27-2c08aa706af4 tags:
``` python
```
%% Cell type:code id:18e04c1e-a01b-41bb-bd3a-84603a1409e1 tags:
``` python
%%timeit
# dataset = dataset.shuffle(9).batch(3)
for next_element in dataset.take(200):
    # time_s = time.time()
    # tf.print(next_element.shape)
    pass
    # print(next_element.numpy()[0,0,0,0,0])
    # print(next_element.numpy()[0,1,0,0,0])
    # print(next_element.numpy()[0,2,0,0,0])
    # print(next_element.numpy()[0,3,0,0,0])
    # print("++++++++")
    # print(next_element.numpy()[1,0,0,0,0])
    # print(next_element.numpy()[1,1,0,0,0])
    # print(next_element.numpy()[1,2,0,0,0])
    # print(next_element.numpy()[1,3,0,0,0])
    # print("-----------------")
    # print(time.time() - time_s)
```
%% Output
400 ms ± 28.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%% Cell type:code id:1d20b48a-0d77-442d-9efd-8e16acac0fd2 tags:
``` python
```
%% Cell type:code id:2dd96f63-d0ae-4515-809f-e3b3a05ca801 tags:
``` python
tf.random.set_seed(123)
# https://www.tensorflow.org/guide/data_performance
def parse_fn(x, min_value, max_value):
    return (x - min_value) / (max_value - min_value)

# min_vars/max_vars come from the first cell; reshape them if per-variable broadcasting
# over the channel axis is needed.
preprocessed_dataset = dataset.map(lambda x: parse_fn(x, min_vars, max_vars))
for row in preprocessed_dataset.take(2):
    print(row)
```
%% Output
tf.Tensor(
[[[[140.14406 140.20949 140.24464 ... 136.82472 136.78859 136.75441]
[140.188 140.1841 140.19875 ... 136.94777 136.85597 136.74953]
[140.0298 140.08058 140.10402 ... 136.93312 136.86574 136.77003]
...
[144.21242 144.55812 144.6089 ... 141.7339 138.86769 137.20656]
[144.19484 144.6714 144.79152 ... 141.88722 139.03078 137.73586]
[144.96925 145.10793 145.06593 ... 141.23781 138.5757 137.70265]]
[[140.04933 140.11769 140.1382 ... 136.84132 136.79054 136.73 ]
[140.14015 140.15578 140.15578 ... 136.93898 136.84523 136.74855]
[140.01418 140.08058 140.12453 ... 136.88918 136.85402 136.7925 ]
...
[143.89015 144.28078 144.42531 ... 141.77394 138.91847 137.21632]
[143.8091 144.4175 144.59425 ... 141.8755 139.0591 137.75539]
[144.74757 144.96535 144.95753 ... 141.37062 138.63039 137.75636]]
[[140.02666 140.0833 140.1038 ... 136.97197 136.9583 136.92607]
[140.13408 140.15166 140.15752 ... 136.99931 136.94463 136.88506]
[139.97295 140.06572 140.12431 ... 136.90068 136.89189 136.8665 ]
...
[143.63017 144.03642 144.16338 ... 141.80107 138.64287 136.84795]
[143.43486 144.10674 144.29033 ... 141.82744 138.76006 137.36552]
[144.4085 144.66924 144.70049 ... 141.00615 138.22197 137.35283]]]
[[[140.02557 140.10565 140.14667 ... 136.85077 136.88397 136.9035 ]
[140.04999 140.10956 140.14569 ... 136.88788 136.87714 136.86542]
[139.79999 139.93182 140.02753 ... 136.81561 136.83905 136.86053]
...
[143.4035 143.82635 143.96503 ... 141.76874 138.75507 136.71893]
[143.07245 143.8869 144.07538 ... 141.72089 138.591 137.19647]
[144.12909 144.40839 144.42987 ... 140.69745 138.03143 137.18378]]
[[139.96661 140.0545 140.10822 ... 136.75568 136.79181 136.8299 ]
[139.9129 139.95587 140.00177 ... 136.76154 136.7547 136.75958]
[139.6629 139.76056 139.83185 ... 136.69025 136.68048 136.70001]
...
[143.16681 143.60431 143.75568 ... 141.64532 138.63849 136.65314]
[142.867 143.66095 143.867 ... 141.55841 138.534 137.13458]
[143.88165 144.15509 144.18634 ... 140.5252 138.01349 137.13947]]
[[139.93112 140.0151 140.05026 ... 136.88815 136.90182 136.92526]
[139.86569 139.87155 139.88034 ... 136.89401 136.87839 136.86179]
[139.603 139.68015 139.72604 ... 136.89792 136.83054 136.79636]
...
[142.95358 143.42526 143.5737 ... 141.45749 138.35202 136.58054]
[142.6323 143.46921 143.686 ... 141.37448 138.26608 136.80515]
[143.68112 143.95847 143.9819 ... 140.15085 137.85007 136.94675]]]
[[[139.94264 139.97878 139.97292 ... 136.90944 136.886 136.88698]
[139.89186 139.85182 139.81471 ... 136.9905 136.89284 136.79128]
[139.59889 139.65846 139.68776 ... 137.04323 136.90358 136.78249]
...
[142.8108 143.27858 143.43092 ... 141.32253 138.28932 136.56569]
[142.51198 143.31178 143.54909 ... 141.22292 138.24147 136.7903 ]
[143.54811 143.82643 143.82545 ... 140.09303 137.95436 137.03249]]
[[140.0629 140.11563 140.08731 ... 136.9252 136.87051 136.85 ]
[139.97403 139.91934 139.85294 ... 137.10196 136.9672 136.81876]
[139.65958 139.69669 139.70645 ... 137.18008 137.00919 136.84512]
...
[142.68399 143.17227 143.30899 ... 141.20059 138.35587 136.82657]
[142.15762 143.18008 143.44278 ... 141.19278 138.70743 137.34512]
[143.37637 143.73184 143.71817 ... 140.65567 138.64493 137.85196]]
[[140.18867 140.24532 140.2209 ... 137.46504 137.36446 137.28242]
[140.07051 140.02461 139.96211 ... 137.70137 137.5168 137.32637]
[139.7375 139.78047 139.79317 ... 137.77461 137.57344 137.38203]
...
[142.60176 143.092 143.21309 ... 141.36153 139.64082 138.467 ]
[142.15352 143.10176 143.37715 ... 141.35957 140.28145 139.16426]
[143.27168 143.66231 143.68086 ... 141.87227 140.62422 139.9875 ]]]
[[[140.24962 140.31407 140.30724 ... 137.61388 137.49376 137.38536]
[140.09727 140.06993 140.0377 ... 137.86192 137.65099 137.43224]
[139.76134 139.82481 139.86095 ... 137.93419 137.692 137.46837]
...
[142.76524 143.07384 143.15392 ... 141.54747 140.30724 139.2838 ]
[143.0631 143.2584 143.38634 ... 141.61876 140.86095 139.89806]
[143.44005 143.64806 143.70274 ... 142.00352 141.234 140.70958]]
[[140.24228 140.2833 140.2911 ... 138.17587 138.0499 137.9288 ]
[140.17099 140.17197 140.1661 ... 138.3165 138.12216 137.93661]
[139.92392 139.99423 140.03818 ... 138.2833 138.08994 137.90927]
...
[143.39853 143.39658 143.28622 ... 141.97275 141.54306 141.02939]
[144.20615 143.7081 143.5831 ... 142.06943 141.8663 141.33994]
[143.78525 143.79794 143.84189 ... 142.46591 142.06064 141.61826]]
[[140.25374 140.3055 140.32698 ... 138.42952 138.27425 138.12776]
[140.19807 140.19319 140.19319 ... 138.5467 138.33284 138.12093]
[139.94319 139.99983 140.04768 ... 138.49495 138.26936 138.05745]
...
[143.96956 143.79573 143.51448 ... 142.1971 141.92854 141.56136]
[145.04573 144.16194 143.826 ... 142.31721 142.27034 141.85628]
[144.12093 143.89339 143.92757 ... 142.64925 142.44612 142.06917]]]], shape=(4, 3, 56, 92), dtype=float32)
tf.Tensor(
[[[[140.2667 140.30966 140.32138 ... 138.63876 138.4581 138.29794]
[140.20224 140.18466 140.17294 ... 138.74228 138.504 138.26376]
[139.9415 139.97763 140.004 ... 138.68857 138.43173 138.18954]
...
[144.53427 144.2872 143.89072 ... 142.38486 142.25693 141.98056]
[145.74228 144.46005 144.03818 ... 142.50009 142.59091 142.25107]
[144.37997 143.9415 143.96494 ... 142.87021 142.7579 142.41122]]
[[140.2666 140.30664 140.31152 ... 138.7832 138.58887 138.41699]
[140.19238 140.16602 140.14551 ... 138.86426 138.61035 138.36328]
[139.96875 139.99316 140.01172 ... 138.77148 138.5166 138.27441]
...
[144.9414 144.6211 144.24805 ... 142.54199 142.41699 142.17188]
[146.0791 144.5586 144.12793 ... 142.65234 142.77148 142.4668 ]
[144.60059 144.02148 144.00586 ... 143.02246 142.92285 142.60938]]
[[140.30289 140.3273 140.31851 ... 138.9689 138.84879 138.72867]
[140.25015 140.20914 140.17496 ... 139.03336 138.83218 138.6271 ]
[140.05484 140.07242 140.08023 ... 138.86246 138.68765 138.49625]
...
[145.11441 144.7609 144.32925 ... 142.69156 142.29996 141.95425]
[146.44254 144.61343 144.14175 ... 142.81754 142.6398 142.18668]
[144.84879 144.19937 144.10074 ... 142.96988 142.72379 142.32632]]]
[[[140.33517 140.35373 140.34103 ... 138.93967 138.8469 138.75217]
[140.29416 140.26291 140.2297 ... 138.9426 138.78537 138.63596]
[140.09885 140.1301 140.13596 ... 138.70236 138.59201 138.47092]
...
[145.23264 144.87814 144.36642 ... 142.76682 141.71701 141.61642]
[146.36057 144.81271 144.29709 ... 142.95139 141.85568 141.82248]
[144.9172 144.4338 144.28342 ... 143.05783 142.44357 142.06955]]
[[140.33109 140.37796 140.37796 ... 138.85745 138.7637 138.67484]
[140.29984 140.29105 140.27151 ... 138.78226 138.62015 138.48929]
[140.08011 140.14359 140.16214 ... 138.5098 138.39066 138.28714]
...
[145.26468 145.07718 144.6094 ... 142.84964 140.98343 140.79398]
[146.26273 145.03519 144.57425 ... 143.05472 141.1094 141.08206]
[144.94632 144.54105 144.40921 ... 142.51956 141.72952 141.4053 ]]
[[140.29703 140.36832 140.39175 ... 138.90543 138.84683 138.79507]
[140.2521 140.28629 140.29703 ... 138.85562 138.71695 138.60464]
[139.97476 140.1066 140.17398 ... 138.57828 138.465 138.38492]
...
[145.10855 145.22672 144.90738 ... 142.9064 139.83804 139.53433]
[145.49332 145.00992 144.73843 ... 143.10855 140.13882 139.86832]
[144.87808 144.59293 144.42398 ... 140.91226 140.26773 140.1564 ]]]
[[[140.2214 140.35715 140.40793 ... 138.83762 138.77316 138.7214 ]
[140.18039 140.22629 140.24875 ... 138.8132 138.6716 138.55832]
[139.83176 139.99875 140.09933 ... 138.56223 138.44797 138.36887]
...
[144.77805 145.05344 145.0007 ... 142.98996 139.88644 139.5134 ]
[145.0261 144.93039 144.84543 ... 143.11691 140.15793 139.77902]
[144.9011 144.71164 144.51535 ... 140.9968 139.34738 139.79465]]
[[140.16458 140.27884 140.33646 ... 138.79056 138.72415 138.66849]
[140.13333 140.16556 140.19193 ... 138.77493 138.64114 138.53372]
[139.78665 139.94193 140.03372 ... 138.5513 138.42728 138.34036]
...
[144.44876 144.78275 144.85599 ... 143.04837 139.89896 139.43607]
[144.61575 144.83841 144.89114 ... 143.11575 140.15677 139.71243]
[144.9263 144.82474 144.63333 ... 141.3804 139.51419 139.42532]]
[[140.03177 140.08841 140.13919 ... 138.8013 138.74075 138.69388]
[140.0972 140.11868 140.15872 ... 138.80716 138.67825 138.57474]
[139.82376 139.9595 140.03958 ... 138.60501 138.48782 138.40482]
...
[144.17532 144.52884 144.60794 ... 143.07181 139.80618 139.1177 ]
[144.43118 144.68997 144.83548 ... 143.07181 139.98685 139.39114]
[144.85892 144.90677 144.7388 ... 141.57962 139.48392 138.99173]]]
[[[139.9985 140.0571 140.08737 ... 138.86081 138.78073 138.72311]
[140.04147 140.06686 140.09225 ... 138.87839 138.71432 138.59323]
[139.89304 139.99655 140.05222 ... 138.67136 138.53952 138.44577]
...
[143.9653 144.30807 144.40866 ... 143.07272 139.80124 138.57956]
[144.07663 144.50241 144.69577 ... 142.97995 139.98093 138.89792]
[144.65671 144.87253 144.83444 ... 141.88425 139.64597 138.72995]]
[[140.03333 140.07434 140.0929 ... 138.79993 138.75305 138.71399]
[140.00598 140.04504 140.06458 ... 138.85364 138.72571 138.61047]
[139.90051 140.02356 140.0841 ... 138.70618 138.58313 138.48547]
...
[143.96497 144.26282 144.29895 ... 142.96594 139.62708 137.83508]
[144.14856 144.3927 144.56555 ... 142.81555 139.79797 138.33313]
[144.61047 144.80579 144.82141 ... 141.8927 139.54993 138.37415]]
[[140.02469 140.10086 140.12527 ... 138.86551 138.83719 138.80106]
[139.96902 140.02957 140.05301 ... 138.92703 138.8157 138.70438]
[139.8245 139.96902 140.04129 ... 138.7786 138.66727 138.57059]
...
[143.89774 144.1995 144.21902 ... 142.8948 139.5745 137.63895]
[144.05887 144.29422 144.4368 ... 142.6868 139.80496 138.22293]
[144.44266 144.63309 144.69168 ... 141.96317 139.60867 138.39285]]]], shape=(4, 3, 56, 92), dtype=float32)
2022-03-02 15:55:14.988517: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
%% Cell type:code id: tags:
``` python
import xarray as xr
import numpy as np
filenames_t850 = [
"data_t850/temperature_850hPa_1979_5.625deg.nc",
"data_t850/temperature_850hPa_1980_5.625deg.nc"
]
filenames_z500 = [
"data_z500/geopotential_500hPa_1979_5.625deg.nc",
"data_z500/geopotential_500hPa_1980_5.625deg.nc"
]
filenames = [*filenames_t850, *filenames_z500]
ds = xr.open_mfdataset(filenames, coords="minimal", compat="override")
ds = ds.drop_vars("level")
```
%% Cell type:code id: tags:
``` python
da = ds.to_array(dim="variables").squeeze()
dims = ["time", "lat", "lon", "variables"]
da = da.transpose(*dims)
def generator(iterable):
    iterator = iter(iterable)
    yield from iterator
da.shape[1:]
```
%% Output
(32, 64, 2)
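%% Cell type:code id: tags:
``` python
# Hedged sketch of how the generator above could feed tf.data (assumes TensorFlow >= 2.4 for
# output_signature). The sample shape comes from da.shape[1:] printed above; adjust the dtype
# if the underlying DataArray is not float32.
import tensorflow as tf

sample_shape = da.shape[1:]  # (lat, lon, variables), e.g. (32, 64, 2)
ds_gen = tf.data.Dataset.from_generator(
    lambda: generator(da.values),  # iterate over the leading (time) axis
    output_signature=tf.TensorSpec(shape=sample_shape, dtype=tf.float32),
)
ds_gen = ds_gen.batch(4)
```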
%% Cell type:code id: tags:
``` python
```
#!/bin/bash
# Name of virtual environment
#VIRT_ENV_NAME="vp_new_structure"
VIRT_ENV_NAME="env_hdfml"
VIRT_ENV_NAME="venv2_hdfml"
if [ -z ${VIRTUAL_ENV} ]; then
if [[ -f ../video_prediction_tools/${VIRT_ENV_NAME}/bin/activate ]]; then
@@ -21,6 +20,7 @@ fi
#python -m pytest test_prepare_era5_data.py
##Test for preprocess_step1
#python -m pytest test_process_netCDF_v2.py
#source ../video_prediction_tools/env_setup/modules_preprocess+extract.sh
source ../video_prediction_tools/env_setup/modules_train.sh
##Test for preprocess moving mnist
#python -m pytest test_prepare_moving_mnist_data.py
@@ -33,5 +33,5 @@ source ../video_prediction_tools/env_setup/modules_train.sh
#rm /p/project/deepacf/deeprain/video_prediction_shared_folder/models/test/*
#python -m pytest test_train_model_era5.py
#python -m pytest test_vanilla_vae_model.py
python -m pytest test_visualize_postprocess.py
python -m pytest test_gzprcp_data.py
#python -m pytest test_meta_postprocess.py
# Name of virtual environment
VIRT_ENV_NAME="venv_hdfml"
CONTAINER_IMG="../video_prediction_tools/HPC_scripts/tensorflow_21.09-tf1-py3.sif"
WRAPPER="./wrapper_container.sh"
# sanity checks
if [[ ! -f ${CONTAINER_IMG} ]]; then
echo "ERROR: Cannot find required TF1.15 container image '${CONTAINER_IMG}'."
exit 1
fi
if [[ ! -f ${WRAPPER} ]]; then
echo "ERROR: Cannot find wrapper-script '${WRAPPER}' for TF1.15 container image."
exit 1
fi
#source ../video_prediction_tools/env_setup/modules_preprocess+extract.sh
singularity exec --nv "${CONTAINER_IMG}" "${WRAPPER}" ${VIRT_ENV_NAME} python3 -m pytest test_era5_data.py
__email__ = "b.gong@fz-juelich.de"
__author__ = "Bing Gong, Scarlet Stadtler,Michael Langguth"
__author__ = "Bing Gong"
from video_prediction.datasets.era5_dataset import *
import pytest
import xarray as xr
import os
import tensorflow as tf
import numpy as np
import json
import datetime
input_dir = "/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/test"
datasplit_config = "/p/project/deepacf/deeprain/bing/ambs/video_prediction_tools/data_split/cv_test.json"
hparams_dict_config = "/p/project/deepacf/deeprain/bing/ambs/video_prediction_tools/hparams/era5/convLSTM/model_hparams.json"
sequences_per_file = 10
mode = "val"
input_dir = "/p/project/deepacf/deeprain/video_prediction_shared_folder/test_data_roshni"
datasplit_config = "/p/project/deepacf/deeprain/bing/ambs/video_prediction_tools/data_split/test/cv_test.json"
hparams_dict_config = "/p/project/deepacf/deeprain/bing/ambs/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json"
mode = "test"
@pytest.fixture(scope="module")
def era5_dataset_case2():
return ERA5Dataset(input_dir=input_dir,mode=mode,
datasplit_config=datasplit_config,hparams_dict_config=hparams_dict_config,seed=1234)
def test_init_era5_dataset(era5_dataset_case2):
assert era5_dataset_case2.hparams.max_epochs == 20
assert era5_dataset_case2.mode == mode
def era5_dataset_case1():
return ERA5Dataset(input_dir=input_dir, datasplit_config=datasplit_config, hparams_dict_config=hparams_dict_config,
mode="test", seed=1234, nsamples_ref=1000)
def test_init_era5_dataset(era5_dataset_case1):
era5_dataset_case1.get_hparams()
assert era5_dataset_case1.max_epochs == 20
assert era5_dataset_case1.mode == mode
assert era5_dataset_case1.batch_size == 4
def test_get_filenames_from_datasplit(era5_dataset_case1):
flname= os.path.join(era5_dataset_case1.input_dir, "era5_vars4ambs_201901.nc")
n_files = len(era5_dataset_case1.filenames)
check = flname in era5_dataset_case1.filenames
assert check == True
assert n_files == 12
def test_make_dataset(era5_dataset_case1):
# Get the data from nc files directly
data_arr = era5_dataset_case1.load_data_from_nc()
assert len(data_arr) !=0
ds = xr.open_mfdataset(era5_dataset_case1.filenames)
len_dt = len(ds["time"].values) # count number of images/samples in the test dataset
da = ds.to_array(dim = "variables").squeeze()
dims = ["time", "lat", "lon"]
data_arr = np.squeeze(da.values) #[vars,samples,lat,lon]
max_vars, min_vars = da.max(dim=dims).values, da.min(dim=dims).values #three dimension
print("data_arr shape",data_arr.shape)
#normalise the data for the first variable
def norm_var(x, min_value, max_value):
return (x - min_value) / (max_value - min_value)
assert np.max(data_arr[0]) == max_vars[0]
# manually calculate the normalization of the data
dt_norm = norm_var(data_arr[0],np.min(data_arr[0]), np.max(data_arr[0]))
def test_get_tfrecords_filesnames(era5_dataset_case2):
era5_dataset_case2.get_tfrecords_filesnames_base_datasplit()
assert era5_dataset_case2.filenames[0] == os.path.join(input_dir, "tfrecords", "sequence_Y_2017_M_2_0_to_9.tfrecords")
# def test_check_pkl_tfrecords_consistency(era5_dataset_case1):
print("dt_norm",dt_norm.shape)
s1 = dt_norm[0] #the first sample, first timestamp
s2 = dt_norm[23] #the first sample, last timestamp
s3 = dt_norm[1] # the second sample, first timestamp
s4 = dt_norm[24] # the second sample, last timestamp
# Get the data from make_dataset function
test_dataset = era5_dataset_case1.make_dataset()
test_iterator = test_dataset.make_one_shot_iterator()
# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
test_handle = test_iterator.string_handle()
iterator = tf.data.Iterator.from_string_handle(test_handle, test_dataset.output_types, test_dataset.output_shapes)
inputs = iterator.get_next()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
#get the batch size samples from dataset
dt = sess.run(inputs) #[batch_size,sequence_len,n_vars,lon,lat]
assert dt.shape[0] == 4
assert dt.shape[1] == 24
print("shape of dt",dt.shape)
s1t = dt[0,0,0]
s2t = dt[0,23,0]
def test_get_example_info(era5_dataset_case2):
era5_dataset_case2.get_tfrecords_filesnames_base_datasplit()
era5_dataset_case2.get_example_info()
assert era5_dataset_case2.image_shape[0] == 160
assert era5_dataset_case2.image_shape[1] == 128
assert era5_dataset_case2.image_shape[2] == 3
#get the second sample from dataset
s3t = dt[1,0,0]
s4t = dt[1,23,0]
#s2t = sess.run(inputs)[0,:,0]
assert np.sum(s1-s1t) < 0.0001
assert np.sum(s2-s2t) < 0.0001
assert np.sum(s3-s3t) < 0.0001
assert np.sum(s4 -s4t) < 0.0001
#compare the data from nc files and make_dataset
__email__ = "b.gong@fz-juelich.de"
from video_prediction.datasets.gzprcp_dataset import *
import pytest
import tensorflow as tf
import xarray as xr
input_dir = "/p/largedata/jjsc42/project/deeprain/project_data/10min_AWS_prcp"
datasplit_config = "/p/project/deepacf/deeprain/bing/ambs/video_prediction_tools/data_split/gzprcp/datasplit.json"
hparams_dict_config = "/p/project/deepacf/deeprain/bing/ambs/video_prediction_tools/hparams/gzprcp/convLSTM_gan/model_hparams_template.json"
sequences_per_file = 10
mode = "test"
@pytest.fixture(scope="module")
def gzprcp_dataset_case1():
dataset = GzprcpDataset(input_dir=input_dir, datasplit_config=datasplit_config, hparams_dict_config=hparams_dict_config,
mode="test", seed=1234, nsamples_ref=1000)
dataset.get_hparams()
dataset.get_filenames_from_datasplit()
dataset.load_data_from_nc()
return dataset
def test_init_gzprcp_dataset(gzprcp_dataset_case1):
# gzprcp_dataset_case1.get_hparams()
print('gzprcp_dataset_case1.max_epochs: {}'.format(gzprcp_dataset_case1.max_epochs))
print('gzprcp_dataset_case1.mode: {}'.format(gzprcp_dataset_case1.mode))
print('gzprcp_dataset_case1.batch_size: {}'.format(gzprcp_dataset_case1.batch_size))
print('gzprcp_dataset_case1.k: {}'.format(gzprcp_dataset_case1.k))
print('gzprcp_dataset_case1.filenames: {}'.format(gzprcp_dataset_case1.filenames))
assert gzprcp_dataset_case1.max_epochs == 8
assert gzprcp_dataset_case1.mode == mode
assert gzprcp_dataset_case1.batch_size == 32
assert gzprcp_dataset_case1.k == 0.01
# assert gzprcp_dataset_case1.filenames[0] == 'GZ_prcp_2019.nc'
def test_load_data_from_nc(gzprcp_dataset_case1):
train_tf_dataset = gzprcp_dataset_case1.make_dataset()
train_iterator = train_tf_dataset.make_one_shot_iterator()
# The `Iterator.string_handle()` method returns a tensor that can be evaluated
# and used to feed the `handle` placeholder.
train_handle = train_iterator.string_handle()
iterator = tf.data.Iterator.from_string_handle(train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
inputs = iterator.get_next()
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())
for step in range(2):
sess.run(inputs)
# df = xr.open_mfdataset(era5_dataset_case1.filenames)
# if __name__ == '__main__':
# dataset = ERA5Dataset(input_dir: str = None, datasplit_config: str = None, hparams_dict_config: str = None,
# mode: str = "train", seed: int = None, nsamples_ref: int = None)
# for next_element in dataset.take(2):
# # time_s = time.time()
# # tf.print(next_element.shape)
# pass
@@ -4,11 +4,8 @@ __author__ = "Bing Gong"
__date__ = "2021-03-03"
from data_preprocess.prepare_era5_data import *
import pytest
import numpy as np
import json
import os
year="2007"
@@ -23,8 +20,6 @@ def dataExtraction_case1(year=year,job_name=job_name,src_dir=src_dir,target_dir=
return ERA5DataExtraction(year,job_name,src_dir,target_dir,varslist_json)
def test_init(dataExtraction_case1):
assert dataExtraction_case1.job_name == 1
assert dataExtraction_case1.src_dir == src_dir
#!/bin/bash -x
## Controlling Batch-job
#SBATCH --account=<your_project>
#SBATCH --account=deepacf
#SBATCH --nodes=1
#SBATCH --ntasks=13
##SBATCH --ntasks-per-node=12
#SBATCH --cpus-per-task=1
#SBATCH --output=DataPreprocess_era5_step1-out.%j
#SBATCH --error=DataPreprocess_era5_step1-err.%j
#SBATCH --time=04:20:00
#SBATCH --gres=gpu:0
#SBATCH --output=log_out.%j
#SBATCH --error=log_err.%j
#SBATCH --time=00:10:00
#SBATCH --partition=batch
#SBATCH --mail-type=ALL
#SBATCH --mail-user=me@somewhere.com
######### Template identifier (don't remove) #########
echo "Do not run the template scripts"
exit 99
######### Template identifier (don't remove) #########
ml Stages/2022
ml GCCcore/.11.2.0
ml GCC/11.2.0
ml ParaStationMPI/5.5.0-1
ml Python/3.9.6
ml SciPy-bundle/2021.10
ml xarray/0.20.1
ml netcdf4-python/1.5.7
ml dask/2021.9.1
# Name of virtual environment
VIRT_ENV_NAME="my_venv"
# Activate virtual environment if needed (and possible)
"""
if [ -z ${VIRTUAL_ENV} ]; then
if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then
echo "Activating virtual environment..."
@@ -33,27 +39,25 @@ if [ -z ${VIRTUAL_ENV} ]; then
fi
# Loading modules
source ../env_setup/modules_preprocess+extract.sh
"""
source_dir=/p/scratch/deepacf/inbound_data/weatherbench
destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/weatherbench_test/extracted
data_extraction_dir=/p/project/deepacf/deeprain/grasse/ambs/video_prediction_tools/data_preprocess
variables='[{"name":"temperature","lvl":[850],"interpolation":"p"},{"name":"geopotential","lvl":[500],"interpolation":"p"}]'
years=("2013" "2014" "2015" "2016" "2017")
# select years and variables for dataset and define target domain
years=( "2015" )
variables=( "t2" "t2" "t2" )
sw_corner=( -999.9 -999.9)
nyx=( -999 -999 )
cd ${data_extraction_dir}
# set some paths
# note, that destination_dir is adjusted during runtime based on the data
source_dir=/my/path/to/extracted/data/
destination_dir=/my/path/to/pickle/files
# execute Python-scripts
for year in "${years[@]}"; do
echo "start preprocessing data for year ${year}"
srun python ../main_scripts/main_preprocess_data_step1.py \
--source_dir ${source_dir} --destination_dir ${destination_dir} --years "${year}" \
--vars "${variables[0]}" "${variables[1]}" "${variables[2]}" \
--sw_corner "${sw_corner[0]}" "${sw_corner[1]}" --nyx "${nyx[0]}" "${nyx[1]}"
done
# Name of virtual environment
venv_dir=".venv"
python -m venv --system-site-packages ${venv_dir}
. ${venv_dir}/bin/activate
#pip3 install --no-cache-dir pytz
#pip3 install --no-cache-dir python-dateutil
export PYTHONPATH=${data_extraction_dir}:$PYTHONPATH
export PYTHONPATH="${data_extraction_dir}/..":$PYTHONPATH
python3 ../main_scripts/main_data_extraction.py ${source_dir} ${destination_dir} ${years[@]} ${variables}
#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py --destination_dir ${destination_dir} --varnames T2 MSL gph500
rm -r ${venv_dir}
@@ -3,13 +3,13 @@
#SBATCH --account=<your_project>
#SBATCH --nodes=1
#SBATCH --ntasks=13
##SBATCH --ntasks-per-node=13
##SBATCH --ntasks-per-node=12
#SBATCH --cpus-per-task=1
#SBATCH --output=data_extraction_era5-out.%j
#SBATCH --error=data_extraction_era5-err.%j
#SBATCH --output=DataExtraction_era5_step1-out.%j
#SBATCH --error=DataExtraction_era5_step1-err.%j
#SBATCH --time=04:20:00
#SBATCH --partition=batch
#SBATCH --gres=gpu:0
#SBATCH --partition=batch
#SBATCH --mail-type=ALL
#SBATCH --mail-user=me@somewhere.com
@@ -22,7 +22,7 @@ exit 99
VIRT_ENV_NAME="my_venv"
# Activate virtual environment if needed (and possible)
if [ -z ${VIRTUAL_ENV} ]; then
if [ -z "${VIRTUAL_ENV}" ]; then
if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then
echo "Activating virtual environment..."
source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate
@@ -34,16 +34,21 @@ fi
# Loading modules
source ../env_setup/modules_preprocess+extract.sh
# Declare path-variables (dest_dir will be set and configured automatically via generate_runscript.py)
source_dir=/my/path/to/era5
# select years and variables for dataset and define target domain
years=( 2017 )
months=( "all" )
var_dict='{"2t": {"sf": ""}, "tcc": {"sf": ""}, "t": {"ml": "p85000."}}'
sw_corner=(38.4 0.0)
nyx=(56 92)
# set some paths
# note, that destination_dir is adjusted during runtime based on the data
source_dir=/my/path/to/era5/data
destination_dir=/my/path/to/extracted/data
varmap_file=/my/path/to/varmapping/file
years=( "2015" )
# execute Python-script
srun python ../main_scripts/main_era5_data_extraction.py -src_dir "${source_dir}" \
-dest_dir "${destination_dir}" -y "${years[@]}" -m "${months[@]}" \
-swc "${sw_corner[@]}" -nyx "${nyx[@]}" -v "${var_dict}"
# Run data extraction
for year in "${years[@]}"; do
echo "Perform ERA5-data extraction for year ${year}"
srun python ../main_scripts/main_data_extraction.py --source_dir ${source_dir} --target_dir ${destination_dir} \
--year ${year} --varslist_path ${varmap_file}
done
#!/bin/bash -x
#SBATCH --account=<your_project>
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --output=train_model_era5-out.%j
#SBATCH --error=train_model_era5-err.%j
#SBATCH --time=24:00:00
#SBATCH --gres=gpu:1
#SBATCH --partition=some_partition
#SBATCH --mail-type=ALL
#SBATCH --mail-user=me@somewhere.com
######### Template identifier (don't remove) #########
echo "Do not run the template scripts"
exit 99
######### Template identifier (don't remove) #########
# auxiliary variables
WORK_DIR="$(pwd)"
BASE_DIR=$(dirname "$WORK_DIR")
# Name of virtual environment
VIRT_ENV_NAME="my_venv"
# !!! ADAPT DEPENDING ON USAGE OF CONTAINER !!!
# For container usage, uncomment the following lines
# Name of container image (must be available in working directory)
CONTAINER_IMG="${WORK_DIR}/tensorflow_21.09-tf1-py3.sif"
WRAPPER="${BASE_DIR}/env_setup/wrapper_container.sh"
# sanity checks
if [[ ! -f ${CONTAINER_IMG} ]]; then
echo "ERROR: Cannot find required TF1.15 container image '${CONTAINER_IMG}'."
exit 1
fi
if [[ ! -f ${WRAPPER} ]]; then
echo "ERROR: Cannot find wrapper-script '${WRAPPER}' for TF1.15 container image."
exit 1
fi
# clean-up modules to avoid conflicts between host and container settings
module purge
# declare directory-variables which will be modified by generate_runscript.py
source_dir=/my/path/to/tfrecords/files
destination_dir=/my/model/output/path
# valid identifiers for model-argument are: convLSTM, savp, mcnet and vae
model=convLSTM
datasplit_dict=${destination_dir}/data_split.json
model_hparams=${destination_dir}/model_hparams.json
# run training in container
export CUDA_VISIBLE_DEVICES=0
## One node, single GPU
srun --mpi=pspmix --cpu-bind=none \
singularity exec --nv "${CONTAINER_IMG}" "${WRAPPER}" ${VIRT_ENV_NAME} \
python3 "${BASE_DIR}"/main_scripts/main_train_models.py --input_dir ${source_dir} --datasplit_dict ${datasplit_dict} \
--dataset weatherbench --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/
# WITHOUT container usage, uncomment the following lines (and comment out the lines above)
# Activate virtual environment if needed (and possible)
#if [ -z ${VIRTUAL_ENV} ]; then
# if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then
# echo "Activating virtual environment..."
# source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate
# else
# echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..."
# exit 1
# fi
#fi
#
# Loading modules
#module purge
#source ../env_setup/modules_train.sh
#export CUDA_VISIBLE_DEVICES=0
#
# srun python3 "${BASE_DIR}"/main_scripts/main_train_models.py --input_dir ${source_dir} --datasplit_dict ${datasplit_dict} \
# --dataset era5 --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/
\ No newline at end of file
# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC)
#
# SPDX-License-Identifier: MIT
known_datasets = {"era5", "weatherbench"}
\ No newline at end of file
import os, glob
import logging
from zipfile import ZipFile
from typing import Union
from pathlib import Path
import multiprocessing as mp
import itertools as it
import sys
import pandas as pd
import xarray as xr
from utils.dataset_utils import get_filename_template
logging.basicConfig(level=logging.DEBUG)
class ExtractWeatherbench:
max_years = list(range(1979, 2018))
def __init__(
self,
dirin: Path,
dirout: Path,
variables: list[dict],
years: Union[list[int], int],
months: list[int],
lat_range: tuple[float],
lon_range: tuple[float],
resolution: float,
):
"""
This script performs several sanity checks and sets the class attributes accordingly.
:param dirin: directory to the ERA5 reanalysis data
:param dirout: directory where the output data will be saved
:param variables: controlled dictionary for getting variables from ERA5-dataset, e.g. {"t": {"ml": "p850"}}
:param years: list of years to extract; [-1] to extract all available years
:param months: list of months to extract
:param lat_range: domain of the latitude axis to extract
:param lon_range: domain of the longitude axis to extract
:param resolution: spacing on both lat, lon axis
"""
self.dirin = dirin
self.dirout = dirout
if years[0] == -1:
self.years = ExtractWeatherbench.max_years
else:
self.years = years
self.months = months
# TODO handle special variables for resolution 5.625 (temperature_850, geopotential_500)
if resolution == 5.625:
for var in variables:
combined_name = f"{var['name']}_{var['lvl'][0]}"
if combined_name in {"temperature_850", "geopotential_500"}:
var["name"] = combined_name
self.variables = variables
self.lat_range = lat_range
self.lon_range = lon_range
self.resolution = resolution
def __call__(self):
"""
Run extraction.
:return: -
"""
logging.info("start extraction")
zip_files, data_files = self.get_data_files()
# extract archives => netcdf files (maybe use tempfiles ?)
args = [
(var_zip, file, self.dirout)
for var_zip, files in zip(zip_files, data_files)
for file in files
]
with mp.Pool(20) as p:
p.starmap(ExtractWeatherbench.extract_task, args)
logging.info("finished extraction")
# TODO: handle 3d data
# load data
files = [self.dirout / file for data_file in data_files for file in data_file]
ds = xr.open_mfdataset(files, coords="minimal", compat="override")
logging.info("opened dataset")
ds.drop_vars("level")
logging.info("data loaded")
# select months
ds = ds.isel(time=ds.time.dt.month.isin(self.months))
# select region
ds = ds.sel(lat=slice(*self.lat_range), lon=slice(*self.lon_range))
logging.info("selected region")
# split into monthly netcdf
year_month_idx = pd.MultiIndex.from_arrays(
[ds.time.dt.year.values, ds.time.dt.month.values]
)
ds.coords["year_month"] = ("time", year_month_idx)
logging.info("constructed splitting-index")
with mp.Pool(20) as p:
p.map(
ExtractWeatherbench.write_task,
zip(ds.groupby("year_month"), it.repeat(self.dirout)),
chunksize=5,
)
logging.info("wrote output")
@staticmethod
def extract_task(var_zip, file, dirout):
with ZipFile(var_zip, "r") as myzip:
myzip.extract(path=dirout, member=file)
@staticmethod
def write_task(args):
(year_month, monthly_ds), dirout = args
year, month = year_month
logging.debug(f"{year}.{month:02d}: dropping index")
monthly_ds = monthly_ds.drop_vars("year_month")
try:
logging.debug(f"{year}.{month:02d}: writing to netCDF")
monthly_ds.to_netcdf(path=dirout / get_filename_template("weatherbench").format(year=year, month=month))
except RuntimeError as e:
logging.error(f"runtime error for writing {year}.{month}\n{str(e)}")
logging.debug(f"{year}.{month:02d}: finished processing")
def get_data_files(self):
"""
Get path to zip files and names of the yearly files within.
:return lists paths to zips of variables
"""
data_files = []
zip_files = []
res_str = f"{self.resolution}deg"
years = self.years
for var in self.variables:
var_dir = self.dirin / res_str / var["name"]
if not var_dir.exists():
raise ValueError(
f"variable {var} is not available for resolution {res_str}"
)
zip_file = var_dir / f"{var['name']}_{res_str}.zip"
with ZipFile(zip_file, "r") as myzip:
names = myzip.namelist()
logging.debug(f"var:{var}\nyears:{years}\nnames:{names}")
if not all(any(str(year) in name for name in names) for year in years):
missing_years = list(filter(lambda year: not any(str(year) in name for name in names), years))
raise ValueError(
f"variable {var} is not available for years: {missing_years}"
)
names = filter(
lambda name: any(str(year) in name for year in years), names
)
data_files.append(list(names))
zip_files.append(zip_file)
return zip_files, data_files
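# --- Hedged usage sketch (not part of the class above) ---
# Illustrates how ExtractWeatherbench could be driven; the directories, variable list,
# domain and years below are placeholders, not values taken from the repository.
if __name__ == "__main__":
    extractor = ExtractWeatherbench(
        dirin=Path("/my/path/to/weatherbench"),
        dirout=Path("/my/path/to/extracted"),
        variables=[{"name": "temperature", "lvl": [850], "interpolation": "p"}],
        years=[2015],
        months=list(range(1, 13)),
        lat_range=(30.0, 60.0),
        lon_range=(0.0, 30.0),
        resolution=5.625,
    )
    extractor()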
# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC)
#
# SPDX-License-Identifier: MIT
"""
Class and functions required for preprocessing Moving mnist data from .npz to TFRecords
Class and functions required for preprocessing guizhou prcp data from .nc to TFRecords
"""
__email__ = "b.gong@fz-juelich.de"
__author__ = "Bing Gong, Karim Mache"
__date__ = "2021_05_04"
__email__ = "y.ji@fz-juelich.de"
__author__ = "Yan Ji, Bing Gong"
__date__ = "2021_05_09"
import datetime
import os
import numpy as np
import tensorflow as tf
import argparse
from model_modules.video_prediction.datasets.moving_mnist import MovingMnist
import netCDF4 as nc
from model_modules.video_prediction.datasets.gzprcp_data import GZprcp
class MovingMnist2Tfrecords(MovingMnist):
class GZprcp2Tfrecords(GZprcp):
def __init__(self, input_dir=None, dest_dir=None, sequences_per_file=128):
def __init__(self, input_dir=None, target_year=2019,dest_dir=None, sequences_per_file=10):
"""
This class is used for converting .npz files to tfrecords
This class is used for converting .nc files to tfrecords
:param input_dir: str, the path to the directory containing the input file
:param dest_dir: the output directory to save TFrecords.
:param sequence_length: int, default is 20, the sequence length per sample
:param sequence_length: int, default is 40, the sequence length per sample
:param sequences_per_file:int, how many sequences/samples per tfrecord to be saved
"""
self.input_dir = input_dir
self.output_dir = dest_dir
self.target_year = target_year
os.makedirs(self.output_dir, exist_ok = True)
self.sequences_per_file = sequences_per_file
self.write_sequence_file()
def __call__(self):
"""
steps to process npy file to tfrecords
steps to process nc file to tfrecords
:return: None
"""
self.read_npz_file()
self.save_npz_to_tfrecords()
self.read_nc_file()
self.save_nc_to_tfrecords()
def read_nc_file(self):
data_temp = nc.Dataset(os.path.join(self.input_dir,str(self.target_year),"rainy","guizhou_prcp.nc"))
prcp_temp = np.transpose(data_temp['prcp'],[3,2,1,0])
def read_npz_file(self):
self.data = np.load(os.path.join(self.input_dir, "mnist_test_seq.npy"))
print("data in minist_test_Seq shape", self.data.shape)
######### missing data
prcp_temp[np.isnan(prcp_temp)] = 0
self.data = prcp_temp
self.time = np.transpose(data_temp['time'],[2,1,0])
print("data in gzprcp_test_Seq shape", self.data.shape)
return None
def save_npz_to_tfrecords(self): # Bing: original 128
def save_nc_to_tfrecords(self):
"""
Read the moving_mnst data which is npz format, and save it to tfrecords files
The shape of dat_npz is [seq_length,number_samples,height,width]
Read the gzprcp data which is nc format, and save it to tfrecords files
The shape of data_nc is [number_samples,seq_length,height,width]
moving_mnist only has one channel
"""
idx = 0
num_samples = self.data.shape[1]
num_samples = self.data.shape[0]
if len(self.data.shape) == 4:
#add one dim to represent channel, then got [seq_length,num_samples,height,width,channel]
#add one dim to represent channel, then got [num_samples,seq_length,height,width,channel]
self.data = np.expand_dims(self.data, axis = 4)
elif len(self.data.shape) == 5:
pass
else:
#print('data shape nor match')
raise (f"The shape of input movning mnist npz file is {len(self.data.shape)} which is not either 4 or 5, please further check your data source!")
self.data = self.data.astype(np.float32)
self.data/= 255.0 # normalize RGB codes by dividing it to the max RGB value
# self.data/= 255.0 # normalize RGB codes by dividing it to the max RGB value
############# normalization ############
#k = 0.001
#self.data = np.log(self.data+k)-np.log(k) # log
#######################################
while idx < num_samples - self.sequences_per_file:
sequences = self.data[:, idx:idx+self.sequences_per_file, :, :, :]
output_fname = 'sequence_index_{}_to_{}.tfrecords'.format(idx, idx + self.sequences_per_file-1)
sequences = self.data[idx:idx+self.sequences_per_file, :, :, :, :]
# use the first sequence time
t_start = self.time[idx:idx+self.sequences_per_file,0,4]+self.time[idx:idx+self.sequences_per_file,0,3]*100+self.time[idx:idx+self.sequences_per_file,0,2]*10000+self.time[idx:idx+self.sequences_per_file,0,1]*1000000+self.time[idx:idx+self.sequences_per_file,0,0]*100000000
# t_start = self.time[idx:idx+self.sequences_per_file,:,:]
# print('self.target_year: ',self.target_year)
output_fname = 'sequence_Y_{}_index_{}_to_{}.tfrecords'.format(self.target_year, idx, idx + self.sequences_per_file-1)
output_fname = os.path.join(self.output_dir, output_fname)
MovingMnist2Tfrecords.save_tf_record(output_fname, sequences)
GZprcp2Tfrecords.save_tf_record(output_fname, sequences, t_start)
idx = idx + self.sequences_per_file
return None
@staticmethod
def save_tf_record(output_fname, sequences):
def save_tf_record(output_fname, sequences, t_start_points):
with tf.python_io.TFRecordWriter(output_fname) as writer:
for i in range(np.array(sequences).shape[1] - 1):
sequence = sequences[:, i, :, :, :]
for i in range(np.array(sequences).shape[0]):
sequence = sequences[i, :, :, :, :]
############### time class ##############
# t_start = datetime.datetime(int(t_start_points[i,19,0]),int(t_start_points[i,19,1]),int(t_start_points[i,19,2]),int(t_start_points[i,19,3]),int(t_start_points[i,19,4])).strftime("%Y%m%d%H%M")
t_start = int(t_start_points[i])
############### time class ##############
num_frames = len(sequence)
height, width = sequence[0, :, :, 0].shape
encoded_sequence = np.array([list(image) for image in sequence])
@@ -87,6 +112,7 @@ class MovingMnist2Tfrecords(MovingMnist):
'height': _int64_feature(height),
'width': _int64_feature(width),
'channels': _int64_feature(1),
't_start': _int64_feature(t_start),
'images/encoded': _floats_feature(encoded_sequence.flatten()),
})
example = tf.train.Example(features = features)
@@ -121,13 +147,16 @@ def _int64_feature(value):
def main():
parser = argparse.ArgumentParser()
parser.add_argument("-input_dir", type=str, help="The input directory that contains the movning mnnist npz file", default="/p/largedata/datasets/moving-mnist/mnist_test_seq.npy")
parser.add_argument("-output_dir", type=str)
parser.add_argument("-sequences_per_file", type=int, default=2)
parser.add_argument("-source_dir", type=str, help="The input directory that contains the gprcp_data nc file", default="/p/scratch/deepacf/ji4/extractedData/guizhou_prcpdata/prcp_squence/")
parser.add_argument("-target_year", type=int,default=2019)
parser.add_argument("-dest_dir", type=str,default="/p/scratch/deepacf/ji4/preprocessedData/gzprcp_data/tfrecords_seq_len")
parser.add_argument("-sequences_per_file", type=int, default=10)
args = parser.parse_args()
inst = MovingMnist2Tfrecords(args.input_dir, args.output_dir, args.sequence_per_file)
inst = GZprcp2Tfrecords(args.source_dir, args.target_year, args.dest_dir, args.sequences_per_file)
inst()
if __name__ == '__main__':
main()
# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC)
#
# SPDX-License-Identifier: MIT
def known_datasets():
"""
An auxiliary function
:return: dictionary of known datasets
"""
dataset_mappings = {
'google_robot': 'GoogleRobotVideoDataset',
'sv2p': 'SV2PVideoDataset',
'softmotion': 'SoftmotionVideoDataset',
'bair': 'SoftmotionVideoDataset', # alias of softmotion
'kth': 'KTHVideoDataset',
'ucf101': 'UCF101VideoDataset',
'cartgripper': 'CartgripperVideoDataset',
"era5": "ERA5Dataset",
"moving_mnist": "MovingMnist"
# "era5_anomaly":"ERA5Dataset_v2_anomaly",
}
return dataset_mappings
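# Hedged usage sketch: resolving a dataset class from its short name via the mapping above.
# The module path "model_modules.video_prediction.datasets" is inferred from imports used
# elsewhere in this repository and may need adjusting.
def get_dataset_class(dataset_name: str):
    """Look up the class named by known_datasets() for the given dataset identifier."""
    import importlib
    class_name = known_datasets()[dataset_name]  # e.g. "era5" -> "ERA5Dataset"
    datasets_module = importlib.import_module("model_modules.video_prediction.datasets")
    return getattr(datasets_module, class_name)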
{
"surface": ["2t", "tcc","msl","10u","10v"],
"multi":{
"t" : {
"pl": 85000
}
}
}
# NOTE: Please configure this JSON file according to your needs. Any line starting with # will be removed
# when editing is invoked from generate_runscript.py.
#
# Explanation: In the following, the mapping of known variable names from the ERA5-data (grib2-files) is defined
# The keys of the dictionary 'surface' (for 2D surface variables) denote the variable names
# in the target netCDF-file while the values denote the name of the variable in the ERA5 grib file.
# For the dictionary 'multi' (used for 3D variables), the keys denote both,
# the variable name in the target netCDF-file and in the ERA5 grib file.
# The value of the 'pl'-key denotes the pressure level (in Pa) onto which the data is interpolated
# !!! This file should be only adapted if you are familiar with the ERA5 grib files!!!
{
"surface":{
["2t", "tcc","msl","10u","10v"]
},
"multi":{
"t" : {
"pl": 85000
}
}
}
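# Hedged sketch (Python, not part of the JSON template above): since the template may contain
# lines starting with '#', a reader has to strip them before parsing. The file name below is a
# placeholder.
import json

def load_varmap(path="era5_varmapping_template.json"):
    """Load a variable-mapping template, ignoring comment lines that start with '#'."""
    with open(path) as f:
        cleaned = "".join(line for line in f if not line.lstrip().startswith("#"))
    return json.loads(cleaned)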
# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC)
#
# SPDX-License-Identifier: MIT
"""
Functions required for extracting ERA5 data.
"""
import os
import json
__email__ = "b.gong@fz-juelich.de"
__author__ = "Bing Gong,Michael Langguth,Yanji"
__update_date__ = "2022-02-15"
# specify source and target directories
class ERA5DataExtraction(object):
def __init__(self, year, job_name, src_dir, target_dir, varslist_json):
"""
Function to extract ERA5 data from slmet
args:
year : str, the target year to be processed, e.g. "2017"
job_name : int from 1 to 12, corresponding to the month to be processed
src_dir : str, upper level of the source directory (at year level)
target_dir : str, upper level of the target directory (at year level)
varslist_json: str, path to the JSON file listing the variables to be extracted from the original grib files
"""
self.year = year
self.job_name = job_name
self.src_dir = src_dir
self.target_dir = target_dir
self.varslist_json = varslist_json
self.get_varslist()
def get_varslist(self):
"""
Function that read varslist_path json file and get variable list
"""
with open(self.varslist_json) as f:
self.varslist = json.load(f)
self.varslist_keys = list(self.varslist.keys())
if not ("surface" in self.varslist_keys and "multi" in self.varslist_keys):
raise ValueError("Thie file '{0}' should have two keys : surface and multi".format(self.varslist_json))
else:
self.varslist_surface = self.varslist["surface"]
self.varslist_multi = self.varslist["multi"]
self.varslist_multi_vars = self.varslist_multi.keys()
def prepare_era5_data_one_file(self, month, day, hour): # extract 2t,tcc,msl,t850,10u,10v
"""
Process one grib file from the source directory (extract variables and interpolate them) and save the result to the target directory
args:
month : str, the target month to be processed, e.g. "01","02",...,"12"
day : str, the target day to be processed, e.g. "01","02",...,"31"
hour : str, the target hour to be processed, e.g. "00","01",...,"23"
"""
temp_path = os.path.join(self.target_dir, self.year)
os.makedirs(temp_path, exist_ok=True)
temp_path = os.path.join(self.target_dir, self.year, month)
os.makedirs(temp_path, exist_ok=True)
for value in self.varslist_surface:
# surface variables
infile = os.path.join(self.src_dir, self.year, month, self.year+month+day+hour+'_sf.grb')
outfile_sf = os.path.join(self.target_dir, self.year, month, self.year+month+day+hour+'_'+value+'.nc')
os.system('cdo --eccodes -f nc copy -selname,%s %s %s' % (value, infile, outfile_sf))
# multi-level variables
for var, pl_dic in self.varslist_multi.items():
for pl, pl_value in pl_dic.items():
infile = os.path.join(self.src_dir, self.year, month, self.year+month+day+hour+'_ml.grb')
outfile_sf_temp = os.path.join(self.target_dir, self.year, month, self.year+month+day+hour+'_'+var +
str(pl_value) + '.nc')
outfile_sf = os.path.join(self.target_dir, self.year, month, self.year+month+day+hour+'_'+var +
str(int(pl_value/100.)) + '.nc')
os.system('cdo -f nc copy -selname,%s -ml2pl,%d %s %s' % (var,pl_value,infile,outfile_sf_temp))
os.system('cdo -chname,%s,%s %s %s' % (var, var+"_{0:d}".format(int(pl_value/100.)), outfile_sf_temp, outfile_sf))
os.system('rm %s' % (outfile_sf_temp))
# merge both variables
infile = os.path.join(self.target_dir, self.year, month, self.year+month+day+hour+'*.nc')
# change the output file name
outfile = os.path.join(self.target_dir, self.year, month, 'ecmwf_era5_'+self.year[2:]+month+day+hour+'.nc')
os.system('cdo merge %s %s' % (infile, outfile))
os.system('rm %s' % (infile))
def process_era5_in_dir(self):
"""
Function that extracts the data at the year level
"""
dates = list(range(1,32))
dates = ["{:02d}".format(d) for d in dates]
hours = list(range(0,24))
hours = ["{:02d}".format(h) for h in hours]
print ("job_name",self.job_name)
for d in dates:
for h in hours:
self.prepare_era5_data_one_file(self.job_name, d, h)
# definition of the worker status: 0 = success / -1 = fatal failure / +1 = non-fatal failure
worker_status = 0
return worker_status
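# --- Hedged usage sketch (paths below are placeholders, not repository defaults) ---
if __name__ == "__main__":
    extraction = ERA5DataExtraction(year="2017",
                                    job_name="01",  # month to process; used as a path component
                                    src_dir="/my/path/to/era5/grib",
                                    target_dir="/my/path/to/extracted/netcdf",
                                    varslist_json="/my/path/to/era5_varslist.json")
    extraction.process_era5_in_dir()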
# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC)
#
# SPDX-License-Identifier: MIT
"""
Class and functions required for preprocessing ERA5 data (preprocessing substep 2)
"""
__email__ = "b.gong@fz-juelich.de"
__author__ = "Bing Gong"
__date__ = "2020_12_29"
# import modules
import os
import glob
import pickle
import numpy as np
import pandas as pd
import json
import tensorflow as tf
from normalization import Norm_data
from metadata import MetaData
import datetime
from model_modules.video_prediction.datasets import ERA5Dataset
class ERA5Pkl2Tfrecords(ERA5Dataset):
def __init__(self, input_dir=None, dest_dir=None, sequence_length=20, sequences_per_file=128, norm="minmax"):
"""
This class is used for converting pkl files to tfrecords
args:
input_dir : str, the path to the PreprocessData directory which is parent directory of "Pickle"
and "tfrecords" files directiory.
sequence_length : int, default is 20, the sequence length per sample
sequences_per_file : int, how many sequences/samples per tfrecord to be saved
norm : str, normalization methods from Norm_data class ("minmax" or "znorm";
default: "minmax")
"""
self.input_dir = input_dir
self.output_dir = dest_dir
# if the output_dir does not exist, then create it
os.makedirs(self.output_dir, exist_ok=True)
# get metadata,includes the var_in, image height, width etc.
self.metadata_fl = os.path.join(os.path.dirname(self.input_dir.rstrip("/")), "metadata.json")
self.get_metadata(MetaData(json_file=self.metadata_fl))
# Get the data split information
self.sequence_length = sequence_length
if norm == "minmax" or norm == "znorm":
self.norm = norm
else:
raise ValueError("norm should be either 'minmax' or 'znorm'")
self.sequences_per_file = sequences_per_file
self.write_sequence_file()
def get_years_months(self):
"""
Get the months in the datasplit_config
Return :
three elements: the years found in the pickle directory, the set of months across all years, and the list of months per year
"""
self.months = []
self.years_months = []
# search for folder names in the pickle directory to get the years
self.years = [name for name in os.listdir(self.input_dir) if os.path.isdir(os.path.join(self.input_dir, name))]
# search for pickle files with the pattern 'X_{month}.pkl' to get the months
patt = "X_*.pkl"
for year in self.years:
months_pkl_list = glob.glob(os.path.join(self.input_dir, year, patt))
months_list = [int(m[-6:-4]) for m in months_pkl_list]
self.months.extend(months_list)
self.years_months.append(months_list)
return self.years, list(set(self.months)), self.years_months
def get_stats_file(self):
"""
Get the corresponding statistics file
"""
method = ERA5Pkl2Tfrecords.get_stats_file.__name__
stats_file = os.path.join(os.path.dirname(self.input_dir), "statistics.json")
print("Opening json-file: {0}".format(stats_file))
if os.path.isfile(stats_file):
with open(stats_file) as js_file:
self.stats = json.load(js_file)
else:
raise FileNotFoundError("%{0}: Could not find statistic file '{1}'".format(method, stats_file))
def get_metadata(self, md_instance):
"""
This function gets the meta data that has been generated in data_process_step1. Here, we aim to extract
the height and width information from it
vars_in : list(str), must be consistent with the list from DataPreprocessing_step1
height : int, the height of the image
width : int, the width of the image
"""
method = ERA5Pkl2Tfrecords.get_metadata.__name__
if not isinstance(md_instance, MetaData):
raise ValueError("%{0}: md_instance-argument must be a MetaData class instance".format(method))
if not hasattr(self, "metadata_fl"):
raise ValueError("%{0}: MetaData class instance passed, but attribute metadata_fl is still missing.".format(method))
try:
self.height, self.width = md_instance.ny, md_instance.nx
self.vars_in = md_instance.variables
except:
raise IOError("%{0}: Could not retrieve all required information from metadata-file '{0}'"
.format(method, self.metadata_fl))
@staticmethod
def save_tf_record(output_fname, sequences, t_start_points):
"""
Save the sequences, and the corresponding timestamp start point to tfrecords
args:
output_fname : str, the file name of the output
sequences : list or array, the sequences to be saved to tfrecords,
[sequences, seq_len, height, width, channels]
t_start_points : list of datetime objects, the first timestamp of each sequence;
its length equals the number of sequences
"""
method = ERA5Pkl2Tfrecords.save_tf_record.__name__
sequences = np.array(sequences)
# sanity checks
assert sequences.shape[0] == len(t_start_points), "%{0}: Lengths of sequence differs from length of t_start_points.".format(method)
assert isinstance(t_start_points[0], datetime.datetime), "%{0}: Elements of t_start_points must be datetime-objects.".format(method)
with tf.python_io.TFRecordWriter(output_fname) as writer:
for i in range(len(sequences)):
sequence = sequences[i]
t_start = t_start_points[i].strftime("%Y%m%d%H")
num_frames = len(sequence)
height, width, channels = sequence[0].shape
encoded_sequence = np.array([list(image) for image in sequence])
features = tf.train.Features(feature={
'sequence_length': _int64_feature(num_frames),
'height': _int64_feature(height),
'width': _int64_feature(width),
'channels': _int64_feature(channels),
't_start': _int64_feature(int(t_start)),
'images/encoded': _floats_feature(encoded_sequence.flatten()),
})
example = tf.train.Example(features=features)
writer.write(example.SerializeToString())
def init_norm_class(self):
"""
Get normalization data class
"""
method = ERA5Pkl2Tfrecords.init_norm_class.__name__
print("%{0}: Make use of default minmax-normalization.".format(method))
# init normalization-instance
self.norm_cls = Norm_data(self.vars_in)
self.nvars = len(self.vars_in)
# get statistics file
self.get_stats_file()
# open statistics file and feed it to norm-instance
self.norm_cls.check_and_set_norm(self.stats, self.norm)
def normalize_vars_per_seq(self, sequences):
"""
Normalize all the variables for the sequences
args:
sequences: list or array, is the sequences need to be saved to tfrecorcd.
The shape should be [sequences_per_file,seq_length,height,width,nvars]
Return:
the normalized sequences
"""
method = ERA5Pkl2Tfrecords.normalize_vars_per_seq.__name__
assert len(np.array(sequences).shape) == 5, "%{0}: Length of sequence array must be 5.".format(method)
# normalization should adopt the selected variables; here we used duplicated temperature channel variables
sequences = np.array(sequences)
# normalization
for i in range(self.nvars):
sequences[..., i] = self.norm_cls.norm_var(sequences[..., i], self.vars_in[i], self.norm)
return sequences
def read_pkl_and_save_tfrecords(self, year, month):
"""
Read pickle files based on month, to process and save to tfrecords,
args:
year : int, the target year to save to tfrecord
month : int, the target month to save to tfrecord
"""
method = ERA5Pkl2Tfrecords.read_pkl_and_save_tfrecords.__name__
# Define the input_file based on the year and month
self.input_file_year = os.path.join(self.input_dir, str(year))
input_file = os.path.join(self.input_file_year, 'X_{:02d}.pkl'.format(month))
temp_input_file = os.path.join(self.input_file_year, 'T_{:02d}.pkl'.format(month))
self.init_norm_class()
sequences = []
t_start_points = []
sequence_iter = 0
try:
with open(input_file, "rb") as data_file:
X_train = pickle.load(data_file)
except:
raise IOError("%{0}: Could not read data from pickle-file '{1}'".format(method, input_file))
try:
with open(temp_input_file, "rb") as temp_file:
T_train = pickle.load(temp_file)
except:
raise IOError("%{0}: Could not read data from pickle-file '{1}'".format(method, temp_input_file))
# check to make sure that X_train and T_train have the same length
assert (len(X_train) == len(T_train))
X_possible_starts = [i for i in range(len(X_train) - self.sequence_length)]
for X_start in X_possible_starts:
X_end = X_start + self.sequence_length
seq = X_train[X_start:X_end, ...]
# recording the start point of the timestamps (already datetime-objects)
t_start = ERA5Pkl2Tfrecords.ensure_datetime(T_train[X_start])
seq = list(np.array(seq).reshape((self.sequence_length, self.height, self.width, self.nvars)))
if not sequences:
last_start_sequence_iter = sequence_iter
sequences.append(seq)
t_start_points.append(t_start)
sequence_iter += 1
if len(sequences) == self.sequences_per_file:
# normalize variables in the sequences
sequences = ERA5Pkl2Tfrecords.normalize_vars_per_seq(self, sequences)
output_fname = 'sequence_Y_{}_M_{}_{}_to_{}.tfrecords'.format(year, month, last_start_sequence_iter,
sequence_iter - 1)
output_fname = os.path.join(self.output_dir, output_fname)
# write to tfrecord
ERA5Pkl2Tfrecords.write_seq_to_tfrecord(output_fname, sequences, t_start_points)
t_start_points = []
sequences = []
print("%{0}: Finished processing of input file '{1}'".format(method, input_file))
# except FileNotFoundError as fnf_error:
# print(fnf_error)
@staticmethod
def write_seq_to_tfrecord(output_fname, sequences, t_start_points):
"""
Function to check if the sequences has been processed.
If yes, the sequences are skipped, otherwise the sequences are saved to the output file
"""
method = ERA5Pkl2Tfrecords.write_seq_to_tfrecord.__name__
if os.path.isfile(output_fname):
print("%{0}: TFrecord-file {1} already exists. It is therefore skipped.".format(method, output_fname))
else:
ERA5Pkl2Tfrecords.save_tf_record(output_fname, list(sequences), t_start_points)
def write_sequence_file(self):
"""
Generate a txt file, with the numbers of sequences for each tfrecords file.
This is mainly used for calculating the number of samples for each epoch during training
"""
with open(os.path.join(self.output_dir, 'number_sequences.txt'), 'w') as seq_file:
seq_file.write("%d\n" % self.sequences_per_file)
@staticmethod
def ensure_datetime(date):
"""
Wrapper to return a datetime-object
"""
method = ERA5Pkl2Tfrecords.ensure_datetime.__name__
fmt = "%Y%m%d %H:%M"
if isinstance(date, datetime.datetime):
date_new = date
else:
try:
date_new=pd.to_datetime(date)
date_new=date_new.to_pydatetime()
except Exception as err:
print("%{0}: Could not handle input data {1} which is of type {2}.".format(method, date, type(date)))
raise err
return date_new
def _bytes_feature(value):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
def _bytes_list_feature(values):
return tf.train.Feature(bytes_list=tf.train.BytesList(value=values))
def _floats_feature(value):
return tf.train.Feature(float_list=tf.train.FloatList(value=value))
def _int64_feature(value):
return tf.train.Feature(int64_list=tf.train.Int64List(value=[value]))
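# Hedged read-side sketch: parsing one example written by save_tf_record above. It mirrors the
# feature keys defined in this file and uses the TF1-style API, consistent with the
# tf.python_io.TFRecordWriter used above; the flat image payload still needs to be reshaped with
# the decoded sequence_length/height/width/channels afterwards.
def parse_sequence_example(serialized_example):
    features = {
        'sequence_length': tf.FixedLenFeature([], tf.int64),
        'height': tf.FixedLenFeature([], tf.int64),
        'width': tf.FixedLenFeature([], tf.int64),
        'channels': tf.FixedLenFeature([], tf.int64),
        't_start': tf.FixedLenFeature([], tf.int64),
        'images/encoded': tf.VarLenFeature(tf.float32),
    }
    parsed = tf.parse_single_example(serialized_example, features)
    images_flat = tf.sparse.to_dense(parsed['images/encoded'])
    return parsed, images_flat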