diff --git a/Jupyter_Notebooks/Data_Preprocess_toy.ipynb b/Jupyter_Notebooks/Data_Preprocess_toy.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..86126f9cee0e1ca1373ab813b35baf5fbac3a761 --- /dev/null +++ b/Jupyter_Notebooks/Data_Preprocess_toy.ipynb @@ -0,0 +1,495 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "5e4295b1-17d3-49eb-b1cc-25cd2c3e38e3", + "metadata": {}, + "outputs": [], + "source": [ + "#https://stackoverflow.com/questions/55429307/how-to-use-windows-created-by-the-dataset-window-method-in-tensorflow-2-0\n", + "import os\n", + "import xarray as xr\n", + "import numpy as np\n", + "import time\n", + "\n", + "video_pred_folder = \"/p/home/jusers/gong1/juwels/video_prediction_shared_folder/\"\n", + "datadir = os.path.join(video_pred_folder, \"test_data_roshni\")\n", + "ds = xr.open_mfdataset(os.path.join(datadir, \"*.nc\"))\n", + "da = ds.to_array(dim=\"variables\").squeeze()\n", + "dims = [\"time\", \"lat\", \"lon\"]\n", + "max_vars, min_vars = da.max(dim=dims).values, da.min(dim=dims).values\n", + "data_arr = np.squeeze(da.values)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "92a20eb3-6358-410b-bf63-2d0cf8e38856", + "metadata": {}, + "outputs": [], + "source": [ + "%%timeit\n", + "data_arr.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cea4aede-32db-4593-a5aa-d7a242bc960a", + "metadata": {}, + "outputs": [], + "source": [ + "data_arr.shape\n", + "data_arr = data_arr.reshape(17520, 3, 56, 92)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "903034b5-4706-419d-915c-886790a9201f", + "metadata": {}, + "outputs": [], + "source": [ + "#data_arr = data_arr[:48]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f62c709f-7cf1-41a1-bb78-94dc7e064f3b", + "metadata": {}, + "outputs": [], + "source": [ + "data_arr.shape" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2db0b465-92d6-4716-98c1-100df47f0041", + "metadata": {}, + "outputs": [], + "source": [ + "data_arr [0,0,0,0]" + ] + }, + { + "cell_type": "code", + "execution_count": 187, + "id": "fafabe11-ed9f-40b6-b830-8d78f52dc239", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "280.05115" + ] + }, + "execution_count": 187, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_arr [1,0,0,0]" + ] + }, + { + "cell_type": "code", + "execution_count": 197, + "id": "e4f5d0cb-56a2-4085-80cf-9c681dc02c5f", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "279.88528" + ] + }, + "execution_count": 197, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "data_arr [2,0,0,0]" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "id": "3f5382c4-b4d1-4e11-9401-7113c46e83a7", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([317.27255, 1. , 303.1935 ], dtype=float32)" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "max_vars" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "220c117c-09dd-42fe-a4b4-30e5a7147e57", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2022-03-17 15:07:39.060539: I tensorflow/stream_executor/platform/default/dso_loader.cc:53] Successfully opened dynamic library libcudart.so.11.0\n" + ] + }, + { + "ename": "NameError", + "evalue": "name 'data_arr' is not defined", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m/tmp/ipykernel_15466/2295205246.py\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtensorflow\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mwindow_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m24\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mDataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfrom_tensor_slices\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata_arr\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mwindow\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwindow_size\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mshift\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mdrop_remainder\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mflat_map\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;32mlambda\u001b[0m \u001b[0mwindow\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mwindow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mwindow_size\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mdataset\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mdataset\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m3\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mNameError\u001b[0m: name 'data_arr' is not defined" + ] + } + ], + "source": [ + "import tensorflow as tf\n", + "window_size=24\n", + "dataset = tf.data.Dataset.from_tensor_slices(data_arr).window(window_size,shift=1,drop_remainder=True)\n", + "dataset = dataset.flat_map(lambda window: window.batch(window_size))\n", + "dataset = dataset.batch(3)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "7973bc9c-6671-4499-8d51-9eea02302808", + "metadata": {}, + "outputs": [], + "source": [ + "def benchmark(dataset, num_epochs=2):\n", + " start_time = time.perf_counter()\n", + " for epoch_num in range(num_epochs):\n", + " for sample in dataset:\n", + " # Performing a training step\n", + " time.sleep(0.01)\n", + " print(\"Execution time:\", time.perf_counter() - start_time)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b33ce29b-e0af-4665-bb27-2c08aa706af4", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": 229, + "id": "18e04c1e-a01b-41bb-bd3a-84603a1409e1", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "400 ms ± 28.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)\n" + ] + } + ], + "source": [ + "#dataset = dataset.shuffle(9).batch(3)\n", + "%%timeit\n", + "for next_element in dataset.take(200):\n", + " #time_s = time.time()\n", + " #tf.print(next_element.shape)\n", + " pass\n", + " # print(next_element.numpy()[0,0,0,0,0])\n", + " # print(next_element.numpy()[0,1,0,0,0])\n", + " # print(next_element.numpy()[0,2,0,0,0])\n", + " # print(next_element.numpy()[0,3,0,0,0])\n", + " # print(\"++++++++\")\n", + " # print(next_element.numpy()[1,0,0,0,0])\n", + " # print(next_element.numpy()[1,1,0,0,0])\n", + " # print(next_element.numpy()[1,2,0,0,0])\n", + " # print(next_element.numpy()[1,3,0,0,0])\n", + " # print(\"-----------------\")\n", + " #print(time.time - time_s)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1d20b48a-0d77-442d-9efd-8e16acac0fd2", + "metadata": {}, + "outputs": [], + "source": [] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2dd96f63-d0ae-4515-809f-e3b3a05ca801", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "tf.Tensor(\n", + "[[[[140.14406 140.20949 140.24464 ... 136.82472 136.78859 136.75441]\n", + " [140.188 140.1841 140.19875 ... 136.94777 136.85597 136.74953]\n", + " [140.0298 140.08058 140.10402 ... 136.93312 136.86574 136.77003]\n", + " ...\n", + " [144.21242 144.55812 144.6089 ... 141.7339 138.86769 137.20656]\n", + " [144.19484 144.6714 144.79152 ... 141.88722 139.03078 137.73586]\n", + " [144.96925 145.10793 145.06593 ... 141.23781 138.5757 137.70265]]\n", + "\n", + " [[140.04933 140.11769 140.1382 ... 136.84132 136.79054 136.73 ]\n", + " [140.14015 140.15578 140.15578 ... 136.93898 136.84523 136.74855]\n", + " [140.01418 140.08058 140.12453 ... 136.88918 136.85402 136.7925 ]\n", + " ...\n", + " [143.89015 144.28078 144.42531 ... 141.77394 138.91847 137.21632]\n", + " [143.8091 144.4175 144.59425 ... 141.8755 139.0591 137.75539]\n", + " [144.74757 144.96535 144.95753 ... 141.37062 138.63039 137.75636]]\n", + "\n", + " [[140.02666 140.0833 140.1038 ... 136.97197 136.9583 136.92607]\n", + " [140.13408 140.15166 140.15752 ... 136.99931 136.94463 136.88506]\n", + " [139.97295 140.06572 140.12431 ... 136.90068 136.89189 136.8665 ]\n", + " ...\n", + " [143.63017 144.03642 144.16338 ... 141.80107 138.64287 136.84795]\n", + " [143.43486 144.10674 144.29033 ... 141.82744 138.76006 137.36552]\n", + " [144.4085 144.66924 144.70049 ... 141.00615 138.22197 137.35283]]]\n", + "\n", + "\n", + " [[[140.02557 140.10565 140.14667 ... 136.85077 136.88397 136.9035 ]\n", + " [140.04999 140.10956 140.14569 ... 136.88788 136.87714 136.86542]\n", + " [139.79999 139.93182 140.02753 ... 136.81561 136.83905 136.86053]\n", + " ...\n", + " [143.4035 143.82635 143.96503 ... 141.76874 138.75507 136.71893]\n", + " [143.07245 143.8869 144.07538 ... 141.72089 138.591 137.19647]\n", + " [144.12909 144.40839 144.42987 ... 140.69745 138.03143 137.18378]]\n", + "\n", + " [[139.96661 140.0545 140.10822 ... 136.75568 136.79181 136.8299 ]\n", + " [139.9129 139.95587 140.00177 ... 136.76154 136.7547 136.75958]\n", + " [139.6629 139.76056 139.83185 ... 136.69025 136.68048 136.70001]\n", + " ...\n", + " [143.16681 143.60431 143.75568 ... 141.64532 138.63849 136.65314]\n", + " [142.867 143.66095 143.867 ... 141.55841 138.534 137.13458]\n", + " [143.88165 144.15509 144.18634 ... 140.5252 138.01349 137.13947]]\n", + "\n", + " [[139.93112 140.0151 140.05026 ... 136.88815 136.90182 136.92526]\n", + " [139.86569 139.87155 139.88034 ... 136.89401 136.87839 136.86179]\n", + " [139.603 139.68015 139.72604 ... 136.89792 136.83054 136.79636]\n", + " ...\n", + " [142.95358 143.42526 143.5737 ... 141.45749 138.35202 136.58054]\n", + " [142.6323 143.46921 143.686 ... 141.37448 138.26608 136.80515]\n", + " [143.68112 143.95847 143.9819 ... 140.15085 137.85007 136.94675]]]\n", + "\n", + "\n", + " [[[139.94264 139.97878 139.97292 ... 136.90944 136.886 136.88698]\n", + " [139.89186 139.85182 139.81471 ... 136.9905 136.89284 136.79128]\n", + " [139.59889 139.65846 139.68776 ... 137.04323 136.90358 136.78249]\n", + " ...\n", + " [142.8108 143.27858 143.43092 ... 141.32253 138.28932 136.56569]\n", + " [142.51198 143.31178 143.54909 ... 141.22292 138.24147 136.7903 ]\n", + " [143.54811 143.82643 143.82545 ... 140.09303 137.95436 137.03249]]\n", + "\n", + " [[140.0629 140.11563 140.08731 ... 136.9252 136.87051 136.85 ]\n", + " [139.97403 139.91934 139.85294 ... 137.10196 136.9672 136.81876]\n", + " [139.65958 139.69669 139.70645 ... 137.18008 137.00919 136.84512]\n", + " ...\n", + " [142.68399 143.17227 143.30899 ... 141.20059 138.35587 136.82657]\n", + " [142.15762 143.18008 143.44278 ... 141.19278 138.70743 137.34512]\n", + " [143.37637 143.73184 143.71817 ... 140.65567 138.64493 137.85196]]\n", + "\n", + " [[140.18867 140.24532 140.2209 ... 137.46504 137.36446 137.28242]\n", + " [140.07051 140.02461 139.96211 ... 137.70137 137.5168 137.32637]\n", + " [139.7375 139.78047 139.79317 ... 137.77461 137.57344 137.38203]\n", + " ...\n", + " [142.60176 143.092 143.21309 ... 141.36153 139.64082 138.467 ]\n", + " [142.15352 143.10176 143.37715 ... 141.35957 140.28145 139.16426]\n", + " [143.27168 143.66231 143.68086 ... 141.87227 140.62422 139.9875 ]]]\n", + "\n", + "\n", + " [[[140.24962 140.31407 140.30724 ... 137.61388 137.49376 137.38536]\n", + " [140.09727 140.06993 140.0377 ... 137.86192 137.65099 137.43224]\n", + " [139.76134 139.82481 139.86095 ... 137.93419 137.692 137.46837]\n", + " ...\n", + " [142.76524 143.07384 143.15392 ... 141.54747 140.30724 139.2838 ]\n", + " [143.0631 143.2584 143.38634 ... 141.61876 140.86095 139.89806]\n", + " [143.44005 143.64806 143.70274 ... 142.00352 141.234 140.70958]]\n", + "\n", + " [[140.24228 140.2833 140.2911 ... 138.17587 138.0499 137.9288 ]\n", + " [140.17099 140.17197 140.1661 ... 138.3165 138.12216 137.93661]\n", + " [139.92392 139.99423 140.03818 ... 138.2833 138.08994 137.90927]\n", + " ...\n", + " [143.39853 143.39658 143.28622 ... 141.97275 141.54306 141.02939]\n", + " [144.20615 143.7081 143.5831 ... 142.06943 141.8663 141.33994]\n", + " [143.78525 143.79794 143.84189 ... 142.46591 142.06064 141.61826]]\n", + "\n", + " [[140.25374 140.3055 140.32698 ... 138.42952 138.27425 138.12776]\n", + " [140.19807 140.19319 140.19319 ... 138.5467 138.33284 138.12093]\n", + " [139.94319 139.99983 140.04768 ... 138.49495 138.26936 138.05745]\n", + " ...\n", + " [143.96956 143.79573 143.51448 ... 142.1971 141.92854 141.56136]\n", + " [145.04573 144.16194 143.826 ... 142.31721 142.27034 141.85628]\n", + " [144.12093 143.89339 143.92757 ... 142.64925 142.44612 142.06917]]]], shape=(4, 3, 56, 92), dtype=float32)\n", + "tf.Tensor(\n", + "[[[[140.2667 140.30966 140.32138 ... 138.63876 138.4581 138.29794]\n", + " [140.20224 140.18466 140.17294 ... 138.74228 138.504 138.26376]\n", + " [139.9415 139.97763 140.004 ... 138.68857 138.43173 138.18954]\n", + " ...\n", + " [144.53427 144.2872 143.89072 ... 142.38486 142.25693 141.98056]\n", + " [145.74228 144.46005 144.03818 ... 142.50009 142.59091 142.25107]\n", + " [144.37997 143.9415 143.96494 ... 142.87021 142.7579 142.41122]]\n", + "\n", + " [[140.2666 140.30664 140.31152 ... 138.7832 138.58887 138.41699]\n", + " [140.19238 140.16602 140.14551 ... 138.86426 138.61035 138.36328]\n", + " [139.96875 139.99316 140.01172 ... 138.77148 138.5166 138.27441]\n", + " ...\n", + " [144.9414 144.6211 144.24805 ... 142.54199 142.41699 142.17188]\n", + " [146.0791 144.5586 144.12793 ... 142.65234 142.77148 142.4668 ]\n", + " [144.60059 144.02148 144.00586 ... 143.02246 142.92285 142.60938]]\n", + "\n", + " [[140.30289 140.3273 140.31851 ... 138.9689 138.84879 138.72867]\n", + " [140.25015 140.20914 140.17496 ... 139.03336 138.83218 138.6271 ]\n", + " [140.05484 140.07242 140.08023 ... 138.86246 138.68765 138.49625]\n", + " ...\n", + " [145.11441 144.7609 144.32925 ... 142.69156 142.29996 141.95425]\n", + " [146.44254 144.61343 144.14175 ... 142.81754 142.6398 142.18668]\n", + " [144.84879 144.19937 144.10074 ... 142.96988 142.72379 142.32632]]]\n", + "\n", + "\n", + " [[[140.33517 140.35373 140.34103 ... 138.93967 138.8469 138.75217]\n", + " [140.29416 140.26291 140.2297 ... 138.9426 138.78537 138.63596]\n", + " [140.09885 140.1301 140.13596 ... 138.70236 138.59201 138.47092]\n", + " ...\n", + " [145.23264 144.87814 144.36642 ... 142.76682 141.71701 141.61642]\n", + " [146.36057 144.81271 144.29709 ... 142.95139 141.85568 141.82248]\n", + " [144.9172 144.4338 144.28342 ... 143.05783 142.44357 142.06955]]\n", + "\n", + " [[140.33109 140.37796 140.37796 ... 138.85745 138.7637 138.67484]\n", + " [140.29984 140.29105 140.27151 ... 138.78226 138.62015 138.48929]\n", + " [140.08011 140.14359 140.16214 ... 138.5098 138.39066 138.28714]\n", + " ...\n", + " [145.26468 145.07718 144.6094 ... 142.84964 140.98343 140.79398]\n", + " [146.26273 145.03519 144.57425 ... 143.05472 141.1094 141.08206]\n", + " [144.94632 144.54105 144.40921 ... 142.51956 141.72952 141.4053 ]]\n", + "\n", + " [[140.29703 140.36832 140.39175 ... 138.90543 138.84683 138.79507]\n", + " [140.2521 140.28629 140.29703 ... 138.85562 138.71695 138.60464]\n", + " [139.97476 140.1066 140.17398 ... 138.57828 138.465 138.38492]\n", + " ...\n", + " [145.10855 145.22672 144.90738 ... 142.9064 139.83804 139.53433]\n", + " [145.49332 145.00992 144.73843 ... 143.10855 140.13882 139.86832]\n", + " [144.87808 144.59293 144.42398 ... 140.91226 140.26773 140.1564 ]]]\n", + "\n", + "\n", + " [[[140.2214 140.35715 140.40793 ... 138.83762 138.77316 138.7214 ]\n", + " [140.18039 140.22629 140.24875 ... 138.8132 138.6716 138.55832]\n", + " [139.83176 139.99875 140.09933 ... 138.56223 138.44797 138.36887]\n", + " ...\n", + " [144.77805 145.05344 145.0007 ... 142.98996 139.88644 139.5134 ]\n", + " [145.0261 144.93039 144.84543 ... 143.11691 140.15793 139.77902]\n", + " [144.9011 144.71164 144.51535 ... 140.9968 139.34738 139.79465]]\n", + "\n", + " [[140.16458 140.27884 140.33646 ... 138.79056 138.72415 138.66849]\n", + " [140.13333 140.16556 140.19193 ... 138.77493 138.64114 138.53372]\n", + " [139.78665 139.94193 140.03372 ... 138.5513 138.42728 138.34036]\n", + " ...\n", + " [144.44876 144.78275 144.85599 ... 143.04837 139.89896 139.43607]\n", + " [144.61575 144.83841 144.89114 ... 143.11575 140.15677 139.71243]\n", + " [144.9263 144.82474 144.63333 ... 141.3804 139.51419 139.42532]]\n", + "\n", + " [[140.03177 140.08841 140.13919 ... 138.8013 138.74075 138.69388]\n", + " [140.0972 140.11868 140.15872 ... 138.80716 138.67825 138.57474]\n", + " [139.82376 139.9595 140.03958 ... 138.60501 138.48782 138.40482]\n", + " ...\n", + " [144.17532 144.52884 144.60794 ... 143.07181 139.80618 139.1177 ]\n", + " [144.43118 144.68997 144.83548 ... 143.07181 139.98685 139.39114]\n", + " [144.85892 144.90677 144.7388 ... 141.57962 139.48392 138.99173]]]\n", + "\n", + "\n", + " [[[139.9985 140.0571 140.08737 ... 138.86081 138.78073 138.72311]\n", + " [140.04147 140.06686 140.09225 ... 138.87839 138.71432 138.59323]\n", + " [139.89304 139.99655 140.05222 ... 138.67136 138.53952 138.44577]\n", + " ...\n", + " [143.9653 144.30807 144.40866 ... 143.07272 139.80124 138.57956]\n", + " [144.07663 144.50241 144.69577 ... 142.97995 139.98093 138.89792]\n", + " [144.65671 144.87253 144.83444 ... 141.88425 139.64597 138.72995]]\n", + "\n", + " [[140.03333 140.07434 140.0929 ... 138.79993 138.75305 138.71399]\n", + " [140.00598 140.04504 140.06458 ... 138.85364 138.72571 138.61047]\n", + " [139.90051 140.02356 140.0841 ... 138.70618 138.58313 138.48547]\n", + " ...\n", + " [143.96497 144.26282 144.29895 ... 142.96594 139.62708 137.83508]\n", + " [144.14856 144.3927 144.56555 ... 142.81555 139.79797 138.33313]\n", + " [144.61047 144.80579 144.82141 ... 141.8927 139.54993 138.37415]]\n", + "\n", + " [[140.02469 140.10086 140.12527 ... 138.86551 138.83719 138.80106]\n", + " [139.96902 140.02957 140.05301 ... 138.92703 138.8157 138.70438]\n", + " [139.8245 139.96902 140.04129 ... 138.7786 138.66727 138.57059]\n", + " ...\n", + " [143.89774 144.1995 144.21902 ... 142.8948 139.5745 137.63895]\n", + " [144.05887 144.29422 144.4368 ... 142.6868 139.80496 138.22293]\n", + " [144.44266 144.63309 144.69168 ... 141.96317 139.60867 138.39285]]]], shape=(4, 3, 56, 92), dtype=float32)\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2022-03-02 15:55:14.988517: W tensorflow/core/kernels/data/cache_dataset_ops.cc:768] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.\n" + ] + } + ], + "source": [ + "tf.random.set_seed(\n", + " 123\n", + ")\n", + "#https://www.tensorflow.org/guide/data_performance\n", + "\n", + "\n", + "def parse_fn(x, min_value, max_value):\n", + " return (x-min_value)/(max_value - min_value)\n", + "\n", + "preprocessed_dataset = dataset.map(map_func=parse_fn(x, min_value,max_value))\n", + "\n", + "for row in preprocessed_dataset.take(2):\n", + " print(row)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "PyDeepLearning-1.1", + "language": "python", + "name": "pydeeplearning" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.5" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/foo.ipynb b/foo.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..87581ea65f680722adea94b7daed27ab12c0e7c1 --- /dev/null +++ b/foo.ipynb @@ -0,0 +1,89 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "import xarray as xr\n", + "import numpy as np\n", + "\n", + "filenames_t850 = [\n", + " \"data_t850/temperature_850hPa_1979_5.625deg.nc\",\n", + " \"data_t850/temperature_850hPa_1980_5.625deg.nc\"\n", + "]\n", + "filenames_z500 = [\n", + " \"data_z500/geopotential_500hPa_1979_5.625deg.nc\",\n", + " \"data_z500/geopotential_500hPa_1980_5.625deg.nc\"\n", + "]\n", + "filenames = [*filenames_t850, *filenames_z500]\n", + "ds = xr.open_mfdataset(filenames, coords=\"minimal\", compat=\"override\")\n", + "ds = ds.drop_vars(\"level\")" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "(32, 64, 2)" + ] + }, + "execution_count": 37, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "da = ds.to_array(dim=\"variables\").squeeze()\n", + "\n", + "dims = [\"time\", \"lat\", \"lon\", \"variables\"]\n", + "da = da.transpose(*dims)\n", + "\n", + "def generator(iterable):\n", + " iterator = iter(iterable)\n", + " yield from iterator\n", + "\n", + "da.shape[1:]" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ambs", + "language": "python", + "name": "ambs" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.4" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "341ba53bbba0a6f1cf5ae0d50bab29c5266302a4d2a8950e418cc5f54c6f95ff" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/test/run_pytest.sh b/test/run_pytest.sh index 6aae33cf0312455efbaa95bbd491440e6d672b2e..6f8dbb694b4760af7410f1d037804102622ef5db 100644 --- a/test/run_pytest.sh +++ b/test/run_pytest.sh @@ -1,8 +1,7 @@ #!#bin/bash -# Name of virtual environment -#VIRT_ENV_NAME="vp_new_structure" -VIRT_ENV_NAME="env_hdfml" +# Name of virtual environment +VIRT_ENV_NAME="venv2_hdfml" if [ -z ${VIRTUAL_ENV} ]; then if [[ -f ../video_prediction_tools/${VIRT_ENV_NAME}/bin/activate ]]; then @@ -21,6 +20,7 @@ fi #python -m pytest test_prepare_era5_data.py ##Test for preprocess_step1 #python -m pytest test_process_netCDF_v2.py +#source ../video_prediction_tools/env_setup/modules_preprocess+extract.sh source ../video_prediction_tools/env_setup/modules_train.sh ##Test for preprocess moving mnist #python -m pytest test_prepare_moving_mnist_data.py @@ -33,5 +33,5 @@ source ../video_prediction_tools/env_setup/modules_train.sh #rm /p/project/deepacf/deeprain/video_prediction_shared_folder/models/test/* #python -m pytest test_train_model_era5.py #python -m pytest test_vanilla_vae_model.py -python -m pytest test_visualize_postprocess.py +python -m pytest test_gzprcp_data.py #python -m pytest test_meta_postprocess.py diff --git a/test/run_pytest_era5_data_preprocess.sh b/test/run_pytest_era5_data_preprocess.sh new file mode 100755 index 0000000000000000000000000000000000000000..2111c750e6fab426551360d5048d8f8a0a2d7eba --- /dev/null +++ b/test/run_pytest_era5_data_preprocess.sh @@ -0,0 +1,23 @@ + + +# Name of virtual environment +VIRT_ENV_NAME="venv_hdfml" + + +CONTAINER_IMG="../video_prediction_tools/HPC_scripts/tensorflow_21.09-tf1-py3.sif" +WRAPPER="./wrapper_container.sh" + +# sanity checks +if [[ ! -f ${CONTAINER_IMG} ]]; then + echo "ERROR: Cannot find required TF1.15 container image '${CONTAINER_IMG}'." + exit 1 +fi + +if [[ ! -f ${WRAPPER} ]]; then + echo "ERROR: Cannot find wrapper-script '${WRAPPER}' for TF1.15 container image." + exit 1 +fi + +#source ../video_prediction_tools/env_setup/modules_preprocess+extract.sh +singularity exec --nv "${CONTAINER_IMG}" "${WRAPPER}" ${VIRT_ENV_NAME} python3 -m pytest test_era5_data.py + diff --git a/test/test_era5_data.py b/test/test_era5_data.py index bb7e2ad6e382a38023d18dd84ad70dd5990b8850..605d6629595185b7003fd0d7caa045069bd0e512 100644 --- a/test/test_era5_data.py +++ b/test/test_era5_data.py @@ -1,43 +1,92 @@ - __email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong, Scarlet Stadtler,Michael Langguth" - - +__author__ = "Bing Gong" from video_prediction.datasets.era5_dataset import * import pytest +import xarray as xr +import os +import tensorflow as tf import numpy as np -import json -import datetime - -input_dir = "/p/project/deepacf/deeprain/video_prediction_shared_folder/preprocessedData/test" -datasplit_config = "/p/project/deepacf/deeprain/bing/ambs/video_prediction_tools/data_split/cv_test.json" -hparams_dict_config = "/p/project/deepacf/deeprain/bing/ambs/video_prediction_tools/hparams/era5/convLSTM/model_hparams.json" -sequences_per_file = 10 -mode = "val" +input_dir = "/p/project/deepacf/deeprain/video_prediction_shared_folder/test_data_roshni" +datasplit_config = "/p/project/deepacf/deeprain/bing/ambs/video_prediction_tools/data_split/test/cv_test.json" +hparams_dict_config = "/p/project/deepacf/deeprain/bing/ambs/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json" +mode = "test" @pytest.fixture(scope="module") -def era5_dataset_case2(): - return ERA5Dataset(input_dir=input_dir,mode=mode, - datasplit_config=datasplit_config,hparams_dict_config=hparams_dict_config,seed=1234) -def test_init_era5_dataset(era5_dataset_case2): - assert era5_dataset_case2.hparams.max_epochs == 20 - assert era5_dataset_case2.mode == mode +def era5_dataset_case1(): + return ERA5Dataset(input_dir=input_dir, datasplit_config=datasplit_config, hparams_dict_config=hparams_dict_config, + mode="test", seed=1234, nsamples_ref=1000) -def test_get_tfrecords_filesnames(era5_dataset_case2): - era5_dataset_case2.get_tfrecords_filesnames_base_datasplit() - assert era5_dataset_case2.filenames[0] == os.path.join(input_dir,"tfrecords","sequence_Y_2017_M_2_0_to_9.tfrecords")# def test_check_pkl_tfrecords_consistency(era5_dataset_case1): - -def test_get_example_info(era5_dataset_case2): - era5_dataset_case2.get_tfrecords_filesnames_base_datasplit() - era5_dataset_case2.get_example_info() - assert era5_dataset_case2.image_shape[0] == 160 - assert era5_dataset_case2.image_shape[1] == 128 - assert era5_dataset_case2.image_shape[2] == 3 +def test_init_era5_dataset(era5_dataset_case1): + era5_dataset_case1.get_hparams() + assert era5_dataset_case1.max_epochs == 20 + assert era5_dataset_case1.mode == mode + assert era5_dataset_case1.batch_size == 4 +def test_get_filenames_from_datasplit(era5_dataset_case1): + flname= os.path.join(era5_dataset_case1.input_dir, "era5_vars4ambs_201901.nc") + n_files = len(era5_dataset_case1.filenames) + check = flname in era5_dataset_case1.filenames + assert check == True + assert n_files == 12 +def test_make_dataset(era5_dataset_case1): + # Get the data from nc files directly + data_arr = era5_dataset_case1.load_data_from_nc() + assert len(data_arr) !=0 + ds = xr.open_mfdataset(era5_dataset_case1.filenames) + len_dt = len(ds["time"].values) # count number of images/samples in the test dataset + da = ds.to_array(dim = "variables").squeeze() + dims = ["time", "lat", "lon"] + data_arr = np.squeeze(da.values) #[vars,samples,lat,lon] + max_vars, min_vars = da.max(dim=dims).values, da.min(dim=dims).values #three dimension + print("data_arr shape",data_arr.shape) + #normalise the data for the first variable + def norm_var(x, min_value, max_value): + return (x - min_value) / (max_value - min_value) + assert np.max(data_arr[0]) == max_vars[0] + #mannualy calculate the normalization of the data + dt_norm = norm_var(data_arr[0],np.min(data_arr[0]), np.max(data_arr[0])) + + print("dt_norm",dt_norm.shape) + s1 = dt_norm[0] #the first sample, first timestamp + s2 = dt_norm[23] #the first sample, last timestamp + s3 = dt_norm[1] # the second sample, first timestamp + s4 = dt_norm[24] # the second sample, last timestamp + # Get the data from make_dataset function + test_dataset = era5_dataset_case1.make_dataset() + test_iterator = test_dataset.make_one_shot_iterator() + # The `Iterator.string_handle()` method returns a tensor that can be evaluated + # and used to feed the `handle` placeholder. + test_handle = test_iterator.string_handle() + iterator = tf.data.Iterator.from_string_handle(test_handle, test_dataset.output_types, test_dataset.output_shapes) + inputs = iterator.get_next() + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + #get the batch size samples from dataset + dt = sess.run(inputs) #[batch_size,sequence_len,n_vars,lon,lat] + dt.shape[0] == 4 + dt.shape[1] == 24 + print("shape of dt",dt.shape) + s1t = dt[0,0,0] + s2t = dt[0,23,0] + + #get the second sample from dataset + s3t = dt[1,0,0] + s4t = dt[1,23,0] + + #s2t = sess.run(inputs)[0,:,0] + assert np.sum(s1-s1t) < 0.0001 + assert np.sum(s2-s2t) < 0.0001 + assert np.sum(s3-s3t) < 0.0001 + assert np.sum(s4 -s4t) < 0.0001 + #compare the data from nc files and make_dataset + + + diff --git a/test/test_gzprcp_data.py b/test/test_gzprcp_data.py new file mode 100644 index 0000000000000000000000000000000000000000..4903690e4d6cfb829677de0cd0e9e86e61f6e44b --- /dev/null +++ b/test/test_gzprcp_data.py @@ -0,0 +1,75 @@ + +__email__ = "b.gong@fz-juelich.de" + +from video_prediction.datasets.gzprcp_dataset import * +import pytest +import tensorflow as tf +import xarray as xr + +input_dir = "/p/largedata/jjsc42/project/deeprain/project_data/10min_AWS_prcp" +datasplit_config = "/p/project/deepacf/deeprain/bing/ambs/video_prediction_tools/data_split/gzprcp/datasplit.json" +hparams_dict_config = "/p/project/deepacf/deeprain/bing/ambs/video_prediction_tools/hparams/gzprcp/convLSTM_gan/model_hparams_template.json" +sequences_per_file = 10 +mode = "test" + + +@pytest.fixture(scope="module") +def gzprcp_dataset_case1(): + dataset = GzprcpDataset(input_dir=input_dir, datasplit_config=datasplit_config, hparams_dict_config=hparams_dict_config, + mode="test", seed=1234, nsamples_ref=1000) + dataset.get_hparams() + dataset.get_filenames_from_datasplit() + dataset.load_data_from_nc() + return dataset + +def test_init_gzprcp_dataset(gzprcp_dataset_case1): + # gzprcp_dataset_case1.get_hparams() + print('gzprcp_dataset_case1.max_epochs: {}'.format(gzprcp_dataset_case1.max_epochs)) + print('gzprcp_dataset_case1.mode: {}'.format(gzprcp_dataset_case1.mode)) + print('gzprcp_dataset_case1.batch_size: {}'.format(gzprcp_dataset_case1.batch_size)) + print('gzprcp_dataset_case1.k: {}'.format(gzprcp_dataset_case1.k)) + print('gzprcp_dataset_case1.filenames: {}'.format(gzprcp_dataset_case1.filenames)) + + assert gzprcp_dataset_case1.max_epochs == 8 + assert gzprcp_dataset_case1.mode == mode + assert gzprcp_dataset_case1.batch_size == 32 + assert gzprcp_dataset_case1.k == 0.01 + # assert gzprcp_dataset_case1.filenames[0] == 'GZ_prcp_2019.nc' + +def test_load_data_from_nc(gzprcp_dataset_case1): + train_tf_dataset = gzprcp_dataset_case1.make_dataset() + train_iterator = train_tf_dataset.make_one_shot_iterator() + # The `Iterator.string_handle()` method returns a tensor that can be evaluated + # and used to feed the `handle` placeholder. + train_handle = train_iterator.string_handle() + iterator = tf.data.Iterator.from_string_handle(train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes) + inputs = iterator.get_next() + + with tf.Session() as sess: + sess.run(tf.global_variables_initializer()) + sess.run(tf.local_variables_initializer()) + for step in range(2): + sess.run(inputs) + + # df = xr.open_mfdataset(era5_dataset_case1.filenames) + +# if __name__ == '__main__': +# dataset = ERA5Dataset(input_dir: str = None, datasplit_config: str = None, hparams_dict_config: str = None, +# mode: str = "train", seed: int = None, nsamples_ref: int = None) +# for next_element in dataset.take(2): +# # time_s = time.time() +# # tf.print(next_element.shape) +# pass + + + + + + + + + + + + + diff --git a/test/test_prepare_era5_data.py b/test/test_prepare_era5_data.py index 833df4faae350395de43cc93990e136138f1a25e..bead9d1517960b5170bea7a091de69402a6c94d9 100644 --- a/test/test_prepare_era5_data.py +++ b/test/test_prepare_era5_data.py @@ -4,11 +4,8 @@ __author__ = "Bing Gong" __date__ = "2021-03-03" - from data_preprocess.prepare_era5_data import * import pytest -import numpy as np -import json import os year="2007" @@ -23,8 +20,6 @@ def dataExtraction_case1(year=year,job_name=job_name,src_dir=src_dir,target_dir= return ERA5DataExtraction(year,job_name,src_dir,target_dir,varslist_json) - - def test_init(dataExtraction_case1): assert dataExtraction_case1.job_name == 1 assert dataExtraction_case1.src_dir == src_dir diff --git a/video_prediction_tools/HPC_scripts/data_extraction_weatherbench_template.sh b/video_prediction_tools/HPC_scripts/data_extraction_weatherbench_template.sh new file mode 100755 index 0000000000000000000000000000000000000000..ceae63dc9bbd409269aa09576f9e88d98ef084c9 --- /dev/null +++ b/video_prediction_tools/HPC_scripts/data_extraction_weatherbench_template.sh @@ -0,0 +1,63 @@ +#!/bin/bash -x +#SBATCH --account=deepacf +#SBATCH --nodes=1 +#SBATCH --gres=gpu:0 +#SBATCH --output=log_out.%j +#SBATCH --error=log_err.%j +#SBATCH --time=00:10:00 +#SBATCH --partition=batch + +######### Template identifier (don't remove) ######### +echo "Do not run the template scripts" +exit 99 +######### Template identifier (don't remove) ######### + +ml Stages/2022 +ml GCCcore/.11.2.0 +ml GCC/11.2.0 +ml ParaStationMPI/5.5.0-1 + +ml Python/3.9.6 +ml SciPy-bundle/2021.10 +ml xarray/0.20.1 +ml netcdf4-python/1.5.7 +ml dask/2021.9.1 + +# Name of virtual environment +VIRT_ENV_NAME="my_venv" + +# Activate virtual environment if needed (and possible) +""" +if [ -z ${VIRTUAL_ENV} ]; then + if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then + echo "Activating virtual environment..." + source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate + else + echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." + exit 1 + fi +fi +# Loading modules +source ../env_setup/modules_preprocess+extract.sh +""" + +source_dir=/p/scratch/deepacf/inbound_data/weatherbench +destination_dir=/p/project/deepacf/deeprain/video_prediction_shared_folder/weatherbench_test/extracted +data_extraction_dir=/p/project/deepacf/deeprain/grasse/ambs/video_prediction_tools/data_preprocess +variables='[{"name":"temperature","lvl":[850],"interpolation":"p"},{"name":"geopotential","lvl":[500],"interpolation":"p"}]' +years=("2013" "2014" "2015" "2016" "2017") + +cd ${data_extraction_dir} + +# Name of virtual environment +venv_dir=".venv" +python -m venv --system-site-packages ${venv_dir} +. ${venv_dir}/bin/activate +#pip3 install --no-cache-dir pytz +#pip3 install --no-cache-dir python-dateutil +export PYTHONPATH=${data_extraction_dir}:$PYTHONPATH +export PYTHONPATH="${data_extraction_dir}/..":$PYTHONPATH + +python3 ../main_scripts/main_data_extraction.py ${source_dir} ${dest_dir} ${years[@]} ${variables} + +rm -r ${venv_dir} diff --git a/video_prediction_tools/HPC_scripts/data_extraction_era5_template.sh b/video_prediction_tools/HPC_scripts/era5_data_extraction_template.sh similarity index 56% rename from video_prediction_tools/HPC_scripts/data_extraction_era5_template.sh rename to video_prediction_tools/HPC_scripts/era5_data_extraction_template.sh index d4741e796c2762d8056204bc36e21655af1e70dd..255e1651cf50abdcb2b455a69b6ae4b4545ac7c4 100644 --- a/video_prediction_tools/HPC_scripts/data_extraction_era5_template.sh +++ b/video_prediction_tools/HPC_scripts/era5_data_extraction_template.sh @@ -3,13 +3,13 @@ #SBATCH --account=<your_project> #SBATCH --nodes=1 #SBATCH --ntasks=13 -##SBATCH --ntasks-per-node=13 +##SBATCH --ntasks-per-node=12 #SBATCH --cpus-per-task=1 -#SBATCH --output=data_extraction_era5-out.%j -#SBATCH --error=data_extraction_era5-err.%j +#SBATCH --output=DataExtraction_era5_step1-out.%j +#SBATCH --error=DataExtraction_era5_step1-err.%j #SBATCH --time=04:20:00 -#SBATCH --partition=batch #SBATCH --gres=gpu:0 +#SBATCH --partition=batch #SBATCH --mail-type=ALL #SBATCH --mail-user=me@somewhere.com @@ -22,7 +22,7 @@ exit 99 VIRT_ENV_NAME="my_venv" # Activate virtual environment if needed (and possible) -if [ -z ${VIRTUAL_ENV} ]; then +if [ -z "${VIRTUAL_ENV}" ]; then if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then echo "Activating virtual environment..." source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate @@ -34,16 +34,21 @@ fi # Loading modules source ../env_setup/modules_preprocess+extract.sh -# Declare path-variables (dest_dir will be set and configured automatically via generate_runscript.py) -source_dir=/my/path/to/era5 + +# select years and variables for dataset and define target domain +years=( 2017 ) +months=( "all" ) +var_dict='{"2t": {"sf": ""}, "tcc": {"sf": ""}, "t": {"ml": "p85000."}}' +sw_corner=(38.4 0.0) +nyx=(56 92) + +# set some paths +# note, that destination_dir is adjusted during runtime based on the data +source_dir=/my/path/to/era5/data destination_dir=/my/path/to/extracted/data -varmap_file=/my/path/to/varmapping/file -years=( "2015" ) +# execute Python-script +srun python ../main_scripts/main_era5_data_extraction.py -src_dir "${source_dir}" \ + -dest_dir "${destination_dir}" -y "${years[@]}" -m "${months[@]}" \ + -swc "${sw_corner[@]}" -nyx "${nyx[@]}" -v "${var_dict}" -# Run data extraction -for year in "${years[@]}"; do - echo "Perform ERA5-data extraction for year ${year}" - srun python ../main_scripts/main_data_extraction.py --source_dir ${source_dir} --target_dir ${destination_dir} \ - --year ${year} --varslist_path ${varmap_file} -done diff --git a/video_prediction_tools/HPC_scripts/preprocess_data_era5_step1_template.sh b/video_prediction_tools/HPC_scripts/preprocess_data_era5_step1_template.sh deleted file mode 100644 index 990095c1f5fa00e4058e08d355830e0fa620b0f3..0000000000000000000000000000000000000000 --- a/video_prediction_tools/HPC_scripts/preprocess_data_era5_step1_template.sh +++ /dev/null @@ -1,59 +0,0 @@ -#!/bin/bash -x -## Controlling Batch-job -#SBATCH --account=<your_project> -#SBATCH --nodes=1 -#SBATCH --ntasks=13 -##SBATCH --ntasks-per-node=12 -#SBATCH --cpus-per-task=1 -#SBATCH --output=DataPreprocess_era5_step1-out.%j -#SBATCH --error=DataPreprocess_era5_step1-err.%j -#SBATCH --time=04:20:00 -#SBATCH --gres=gpu:0 -#SBATCH --partition=batch -#SBATCH --mail-type=ALL -#SBATCH --mail-user=me@somewhere.com - -######### Template identifier (don't remove) ######### -echo "Do not run the template scripts" -exit 99 -######### Template identifier (don't remove) ######### - -# Name of virtual environment -VIRT_ENV_NAME="my_venv" - -# Activate virtual environment if needed (and possible) -if [ -z ${VIRTUAL_ENV} ]; then - if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then - echo "Activating virtual environment..." - source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate - else - echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." - exit 1 - fi -fi -# Loading modules -source ../env_setup/modules_preprocess+extract.sh - - -# select years and variables for dataset and define target domain -years=( "2015" ) -variables=( "t2" "t2" "t2" ) -sw_corner=( -999.9 -999.9) -nyx=( -999 -999 ) - -# set some paths -# note, that destination_dir is adjusted during runtime based on the data -source_dir=/my/path/to/extracted/data/ -destination_dir=/my/path/to/pickle/files - -# execute Python-scripts -for year in "${years[@]}"; do - echo "start preprocessing data for year ${year}" - srun python ../main_scripts/main_preprocess_data_step1.py \ - --source_dir ${source_dir} --destination_dir ${destination_dir} --years "${year}" \ - --vars "${variables[0]}" "${variables[1]}" "${variables[2]}" \ - --sw_corner "${sw_corner[0]}" "${sw_corner[1]}" --nyx "${nyx[0]}" "${nyx[1]}" -done - - -#srun python ../../workflow_parallel_frame_prediction/DataPreprocess/mpi_split_data_multi_years.py --destination_dir ${destination_dir} --varnames T2 MSL gph500 diff --git a/video_prediction_tools/HPC_scripts/train_model_weatherbench_template.sh b/video_prediction_tools/HPC_scripts/train_model_weatherbench_template.sh new file mode 100644 index 0000000000000000000000000000000000000000..44ccf018d2896553ad360d5c5dbd0c398b7b54d8 --- /dev/null +++ b/video_prediction_tools/HPC_scripts/train_model_weatherbench_template.sh @@ -0,0 +1,78 @@ +#!/bin/bash -x +#SBATCH --account=<your_project> +#SBATCH --nodes=1 +#SBATCH --ntasks=1 +#SBATCH --output=train_model_era5-out.%j +#SBATCH --error=train_model_era5-err.%j +#SBATCH --time=24:00:00 +#SBATCH --gres=gpu:1 +#SBATCH --partition=some_partition +#SBATCH --mail-type=ALL +#SBATCH --mail-user=me@somewhere.com + +######### Template identifier (don't remove) ######### +echo "Do not run the template scripts" +exit 99 +######### Template identifier (don't remove) ######### + +# auxiliary variables +WORK_DIR="$(pwd)" +BASE_DIR=$(dirname "$WORK_DIR") +# Name of virtual environment +VIRT_ENV_NAME="my_venv" +# !!! ADAPAT DEPENDING ON USAGE OF CONTAINER !!! +# For container usage, comment in the follwoing lines +# Name of container image (must be available in working directory) +CONTAINER_IMG="${WORK_DIR}/tensorflow_21.09-tf1-py3.sif" +WRAPPER="${BASE_DIR}/env_setup/wrapper_container.sh" + +# sanity checks +if [[ ! -f ${CONTAINER_IMG} ]]; then + echo "ERROR: Cannot find required TF1.15 container image '${CONTAINER_IMG}'." + exit 1 +fi + +if [[ ! -f ${WRAPPER} ]]; then + echo "ERROR: Cannot find wrapper-script '${WRAPPER}' for TF1.15 container image." + exit 1 +fi + +# clean-up modules to avoid conflicts between host and container settings +module purge + +# declare directory-variables which will be modified by generate_runscript.py +source_dir=/my/path/to/tfrecords/files +destination_dir=/my/model/output/path + +# valid identifiers for model-argument are: convLSTM, savp, mcnet and vae +model=convLSTM +datasplit_dict=${destination_dir}/data_split.json +model_hparams=${destination_dir}/model_hparams.json + +# run training in container +export CUDA_VISIBLE_DEVICES=0 +## One node, single GPU +srun --mpi=pspmix --cpu-bind=none \ + singularity exec --nv "${CONTAINER_IMG}" "${WRAPPER}" ${VIRT_ENV_NAME} \ + python3 "${BASE_DIR}"/main_scripts/main_train_models.py --input_dir ${source_dir} --datasplit_dict ${datasplit_dict} \ + --dataset weatherbench --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/ + +# WITHOUT container usage, comment in the follwoing lines (and uncomment the lines above) +# Activate virtual environment if needed (and possible) +#if [ -z ${VIRTUAL_ENV} ]; then +# if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then +# echo "Activating virtual environment..." +# source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate +# else +# echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." +# exit 1 +# fi +#fi +# +# Loading modules +#module purge +#source ../env_setup/modules_train.sh +#export CUDA_VISIBLE_DEVICES=0 +# +# srun python3 "${BASE_DIR}"/main_scripts/main_train_models.py --input_dir ${source_dir} --datasplit_dict ${datasplit_dict} \ +# --dataset era5 --model ${model} --model_hparams_dict ${model_hparams} --output_dir ${destination_dir}/ \ No newline at end of file diff --git a/video_prediction_tools/data_preprocess/calc_climatology.py b/video_prediction_tools/data_extraction/calc_climatology.py similarity index 100% rename from video_prediction_tools/data_preprocess/calc_climatology.py rename to video_prediction_tools/data_extraction/calc_climatology.py diff --git a/video_prediction_tools/data_extraction/dataset_options.py b/video_prediction_tools/data_extraction/dataset_options.py new file mode 100644 index 0000000000000000000000000000000000000000..81160d4cf1808efa76ce13e81727df5ce3ef3955 --- /dev/null +++ b/video_prediction_tools/data_extraction/dataset_options.py @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) +# +# SPDX-License-Identifier: MIT + +known_datasets = {"era5", "weatherbench"} \ No newline at end of file diff --git a/video_prediction_tools/data_extraction/extract_era5_data.py b/video_prediction_tools/data_extraction/extract_era5_data.py new file mode 100644 index 0000000000000000000000000000000000000000..d17a2f82ab75a2257b65cfa16622ddf3a015e757 --- /dev/null +++ b/video_prediction_tools/data_extraction/extract_era5_data.py @@ -0,0 +1,425 @@ +""" +Class and functions required for preprocessing raw ERA5 data. +Calling the class-object will... +* select variables of interest +* slice data to target region +* interpolate multi-level data onto pressure levels +Data will be stored monthly in netCDF-files, whereas the original ERA5-data is expected to be hourly data in grib-files. +""" +__email__ = "m.langguth@fz-juelich.de" +__author__ = "Michael Langguth" +__date__ = "2022-04-07" + +import os, glob +from typing import List, Union, get_args +import shutil +import subprocess as sp +import numpy as np +import pandas as pd +import logging +from general_utils import check_str_in_list, get_path_component, isw +from pystager_utils import PyStager + +# for typing +str_or_List = Union[List, str] + + +class Extract_ERA5_data(object): + + cls_name = "Extract_ERA5_data" + + def __init__(self, dirin: str, dirout: str, var_req: dict, coord_sw: List, nyx: List, years: List, + months: str_or_List = "all", lon_intv: List = (0., 360.), lat_intv: List = (-90., 90.), + dx: float = 0.3): + """ + This script performs several sanity checks and sets the class attributes accordingly. + :param dirin: directory to the ERA5 reanalysis data + :param dirout: directory where the output data will be saved + :param var_req: controlled dictionary for getting variables from ERA5-dataset, e.g. {"t": {"ml": "p850"}} + :param coord_sw: latitude [°N] and longitude [°E] of south-western corner defining target domain + :param nyx: number of grid points in latitude and longitude direction for target domain + :param lon_intv: Allowed interval for slicing along longitude axis for ERA5 data (adapt if required) + :param lat_intv: Allowed interval for slicing along latitude axis for ERA5 data (adapt if required) + :param dx: grid spacing of regular, spherical grid onto which the ERA5 data is provided (adapt if required) + """ + method = Extract_ERA5_data.__init__.__name__ + + self.dirout = dirout if os.path.isdir(dirout) else None + if not self.dirout: + raise NotADirectoryError("%{0}: Output directory does not exist.".format(method)) + + self.months = self.get_months(months) + self.dirin, self.years = self.check_dirin(dirin, years, months) + self.var_requests = self.check_varnames(var_req) + # some basic grid information + self.lat_intv, self.lon_intv = list(lat_intv), list(lon_intv) + self.lon_intv[1] -= np.abs(dx) + self.dx = dx + # check provided data for regional slicing + self.lat_bounds, self.lon_bounds = self.check_coords(coord_sw, nyx) + + # initialize PyStager + self.era5_pystager = PyStager(self.worker_func, "year_month_list") + + def __call__(self): + """ + Set-up and run Pystager to preprocess the data. + :return: - + """ + self.era5_pystager.setup(self.years, self.months) + self.era5_pystager.run(self.dirin, self.dirout, self.var_requests, self.lat_bounds, self.lon_bounds) + + def check_varnames(self, var_req: dict): + """ + Check if all variables can be found in an exemplary datafile stored under datadir + :param var_req: nested dictionary where first-level keys carry request variable name. The values of these keys + are dictionaries whose keys denote the type of variable (e.g. "sf" for surface) + and whose values control (optional) vertical interpolation + :return: the approved dictionary (same as var_req) + """ + method = Extract_ERA5_data.check_varnames.__name__ + + allowed_vartypes = ["ml", "sf"] + + assert isinstance(var_req, dict), "%{0}: var_req must be a (controlled) dictionary. See doc-string."\ + .format(method) + + if not all(isinstance(vardict, dict) for vardict in var_req.values()): + raise ValueError("%{0}: Values of var_req-dictionary must be dictionary, i.e. pass a nested dictionary." + .format(method)) + + varnames = list(var_req.keys()) + vartypes = [list(var_req[varname].keys())[0] for varname in varnames] + + # first check for availability of variables in grib-files + for vartype in allowed_vartypes: + inds = [i for i, vtype in enumerate(vartypes) if vtype == vartype] + vars2check = list(map(varnames.__getitem__, inds)) + if not vars2check: continue # skip the following if no variable to check + + # construct path to exemplary datafile + yy_str, mm_str = str(self.years[0]), "{0:02d}".format(self.months[0]) + f2check = os.path.join(self.dirin, yy_str, mm_str, "{0}{1}0100_{2}.grb".format(yy_str, mm_str, vartype)) + + _ = Extract_ERA5_data.check_var_in_grib(f2check, vars2check, labort=True) + + # second check for content of nested dictionaries + print("%{0}: Start checking consistency of nested dictionaries for each requested variable.".format(method)) + for varname in varnames: + # check vartypes + vartype = list(var_req[varname].keys())[0] + lvl_info = list(var_req[varname].values())[0] + if not vartype in allowed_vartypes: + raise ValueError("%{0}: Key of variable dict for '{1}' must be one of the following types: {2}" + .format(method, vartype, ", ".join(allowed_vartypes))) + # check level types + if vartype == "sf" and lvl_info != "": + print("%{0}: lvl_info for surface variable '{1}' is not empty and thus will be ignored." + .format(method, varname)) + elif vartype == "ml" and not lvl_info.startswith("p"): + raise ValueError("%{0}: Variable '' on model levels requires target pressure level".format(method) + + "for interpolation. Thus, provide 'pX' where X denotes a pressure level in hPa.") + + print("%{0}: Check for consistency of nested dictionaries approved.".format(method)) + + return var_req + + def check_coords(self, coords_sw: List, nyx: List): + """ + Check arguments for defining target domain and return bounding coordinates. + :param coords_sw: latitude [°N] and longitude [°E] of south-western corner defining target domain + :param nyx: number of grid points in latitude and longitude direction for target domain + :return: two tuples of latitude and longitude boundaries of target domain, respectively + + TO-DO: Support slicing accross the zero meridian (i.e. handling data queries accross 0°E). + """ + method = Extract_ERA5_data.check_coords.__name__ + + if self.dx < 0: + print("%{0}: Grid spacing should be positive. Change negative value to positive one.".format(method)) + self.dx = np.abs(self.dx) + + # convert type of input data if required + coords_sw = [float(coord) for coord in coords_sw] + nyx = [int(n) for n in nyx] + + if not isw(coords_sw[0], self.lat_intv): + raise ValueError("%{0}: Latitude of south-western domain corner {1:.2f}".format(method, coords_sw[0]) + + "is not within expected [{0:.1f}°N, {1:.f}°N]".format(self.lat_intv[0], self.lat_intv[1])) + + if not isw(coords_sw[1], self.lon_intv): + raise ValueError("%{0}: Longitude of south-western domain corner {1:.2f}".format(method, coords_sw[1]) + + "is not within expected [{0:.1f}°E, {1:.f}°E]".format(self.lon_intv[0], self.lon_intv[1])) + + coords_ne = coords_sw[0] + (nyx[0] - 1)*self.dx, coords_sw[1] + (nyx[1] - 1)*self.dx + + if not isw(coords_ne[0], self.lat_intv): + raise ValueError("%{0}: Latitude of north-eastern domain corner {1:.2f}".format(method, coords_ne[0]) + + "is not within expected [{0:.1f}°N, {1:.f}°N]. Adapt nyx-argument." + "".format(self.lat_intv[0], self.lat_intv[1])) + + if not isw(coords_ne[1], self.lon_intv): + raise ValueError("%{0}: Longitude of north-eastern domain corner {1:.2f}".format(method, coords_ne[1]) + + "is not within expected [{0:.1f}°E, {1:.f}°E]. Adapt nyx-argument." + .format(self.lon_intv[0], self.lon_intv[1])) + + return [coords_sw[0], coords_ne[0]], [coords_sw[1], coords_ne[1]] + + @staticmethod + def worker_func(year_months: list, dirin: str, dirout: str, var_req: dict, lat_bounds, lon_bounds, + logger: logging.Logger, nmax_warn: int = 1): + """ + Handle grib-files from dirin to extract the variables of interest incl. lateral slicing and optional interpol. + The resulting netCDF-files contain monthly data. + :param year_months: List of months of years (format: YYYY-MM) to be processed (list elements: datetime-objects) + :param dirin: base input directory where grib-files are sorted in yearly and monthly sub-directories + :param dirout: base output-directory; netCDF-files will be colledted in yearly subdirectories + :param var_req: dictionary for querying variables and interpolation from grib-files, + e.g. {"t" : {"ml": "85000."}} for getting 850hPa temperature + :param lat_bounds: lateral boundaries of target domain in meridional (latitude) direction + :param lon_bounds: lateral boundaries of target domain in zonal (longitude) direction + :param logger: logger instance + :param nmax_warn: maximum number of allowed warnings/errors during runtime of worker + :return: netCDF-file with monthly data under <dirout>/YYYY/ + + TO-DO: Support interpolation on multiple pressure levels + """ + method = Extract_ERA5_data.worker_func.__name__ + + # sanity check + assert isinstance(logger, logging.Logger), "%{0}: logger-argument must be a logging.Logger instance" \ + .format(method) + + varnames = list(var_req.keys()) + vartypes = [list(var_req[varname].keys())[0] for varname in varnames] + + # initilaize warn counter and start iterating over year_month-list + nwarns = 0 + + for year_month in year_months: + year, month = int(year_month.strftime("%Y")), int(year_month.strftime("%m")) + year_str, month_str = str(year), "{0:02d}".format(int(month)) + + dirin_now, dirout_now = os.path.join(dirin, year_str, month_str), os.path.join(dirout, year_str) + dirout_tmp = os.path.join(dirout_now, "{0}_tmp".format(month_str)) + dest_file = os.path.join(dirout, "ambs_era5_{0}{1}.nc".format(year_str, month_str)) + # create output- and temp-directory (store intermediate netCDF-files merged with -mergetime operator later) + os.makedirs(dirout_now, exist_ok=True) + os.makedirs(dirout_tmp, exist_ok=True) + + for vartype in np.unique(vartypes): + logger.info("Start processing variable type '{1}'".format(method, vartype)) + vars4type = [varname for c, varname in enumerate(varnames) if vartypes[c] == vartype] + # ensure that lnsp (logarithmic surface pressure) and z (geopotential) are in list for p-interpolation + if vartype == "ml": + vars4type_aux = set(vars4type + ["lnsp", "z"]) + # this only allows handling of a single pressure level for all variables! + p_lvl = int(float(var_req[vars4type[0]].get("ml").lstrip("p"))) + cmd_add, vars4type = Extract_ERA5_data.get_cdo_op4ml(p_lvl, vars4type) + else: + cmd_add = None + vars4type_aux = vars4type + + search_patt = os.path.join(dirin_now, "{0}{1}*_{2}.grb".format(year_str, month_str, vartype)) + + logger.info("%{0}: Serach for grib-files under '{1}' for year {2} and month {3}" + .format(method, dirin_now, year_str, month_str)) + grb_files = glob.glob(search_patt) + + nfiles, nfiles_exp = len(grb_files), pd.Period("{0}-{1}".format(year_str, month_str)).days_in_month*24 + + if not nfiles == nfiles_exp: + err = "%{0}: Found {1:d} grib-files with search pattern '{2}'".format(method, nfiles, search_patt) \ + + ", but {0:d} files found. Check data directory...".format(nfiles) + logger.critical(err) + raise FileNotFoundError(err) + + logger.info("%{0}: Start converting and slicing of data from {1:d} files found with pattern {2}..." + .format(method, nfiles, search_patt)) + + # process each file individually and store resulting netCDF-file in temp-directory + for grb_file in grb_files: + tmp_file = os.path.join(dirout_tmp, + "preprocess_{0}".format(os.path.basename(grb_file).replace("grb", "nc"))) + + cmd = "cdo -v --eccodes -f nc copy -selname,{0} -sellonlatbox,{1},{2},{3},{4} {5} {6}" \ + .format(",".join(vars4type_aux), *lon_bounds, *lat_bounds, grb_file, tmp_file) + + if cmd_add is not None: # append cdo-command if required + cmd = cmd.replace("copy", "copy{0}".format(cmd_add)) + + nwarns = Extract_ERA5_data.run_cmd(cmd, logger, nwarns) + # check if nwarns was exceeded + if nwarns >= nmax_warn: return -1 + + # Merge all files in temp-directory of current vartype to yield monthly data files + logger.info("%{0}: Start merging all files from vartype '{1}', i.e. '*{1}.nc'.".format(method, vartype)) + tmp_file = os.path.join(dirout_tmp, "preprocess_merged_{0}.nc".format(vartype)) + cmd = "cdo -v -mergetime {0} {1}".format(os.path.join(dirout_tmp, "*{0}.nc".format(vartype)), tmp_file) + nwarns = Extract_ERA5_data.run_cmd(cmd, logger, nwarns, labort=True) + # check if previous merging failed + if nwarns == -1: return nwarns + # remove auxiliary data from file if pressure interpolation was performed + if vartype == "ml": + cmd = "cdo -v -selname,{0} {1} {2}".format(",".join(vars4type), tmp_file, + tmp_file.replace("ml", "ml_reduced")) + nwarns = Extract_ERA5_data.run_cmd(cmd, logger, nwarns, labort=True) + # check if previous variable selection failed + if nwarns == -1: return nwarns + + # Now, merge the temp-files to get the final output file + cmd = "cdo -v -merge {0} {1}".format(os.path.join(dirout_tmp, "preprocess_merged_*.nc"), dest_file) + logger.info("%{0}: Final merge of temp-files to '{1}'.".format(method, dest_file)) + nwarns = Extract_ERA5_data.run_cmd(cmd, logger, nwarns, labort=True) + # clean temp-directory + if nwarns == -1: + logger.info("%{0}: Clean temp-directory '{1}'.".format(method, dirout_tmp)) + shutil.rmtree(dirout_tmp) + + return nwarns + + @staticmethod + def get_cdo_op4ml(p_lvl, variables): + """ + Generate CDO-command snippet for performing pressure interpolation and renaming variables in data file. + Also adjusts the list of variables subject to renaming. + :param p_lvl: pressure level onto which data is interpolated [Pa] + :param variables: list of variable names subject to interpolation + :return: cdo-command snippet for pressure interpolation and updated list of variable names + """ + varrename = ["{0},{0}_{1:d}".format(var, int(p_lvl/100.)) for var in variables] + cdo_add = " -chname,{0} -ml2pl,{1:d} ".format(",".join(varrename), p_lvl) + # also adjust variable names incl. vars4type for latter variable selection + vars4type = ["{0}_{1:d}".format(var, int(p_lvl/100.)) for var in variables] + + return cdo_add, vars4type + + @staticmethod + def run_cmd(cmd: str, logger: logging.Logger, nwarns: int, labort: bool = False): + """ + Run command in separated shell. + :param cmd: command to run + :param logger: logger instance + :param nwarns: number of current warnings + :param labort: Boolean if abortion will be triggered (i.e. return nwarns = -1) + :return: updated nwarns + """ + method = Extract_ERA5_data.run_cmd.__name__ + + nwarns_loc = nwarns + try: + _ = sp.check_output(cmd, stderr=sp.STDOUT, shell=True) + logger.info("%{0}: Command '{1}' ran successfully...".format(method, cmd)) + except sp.CalledProcessError as exc: + logger.critical("%{0}: Failed to run the following command: {1}".format(method, cmd)) + logger.critical("%{0}: Return code: {1}, error message: {2}".format(method, exc.returncode, + exc.output)) + if labort: + nwarns_loc = -1 + else: + nwarns_loc += 1 + + return nwarns_loc + + @staticmethod + def check_dirin(dirin: str, years: str_or_List, months: List): + """ + Checks if data directories for all years exist + :param dirin: path to basic data directory under which files are located + :param years: years for which data is requested + :param months: months for which data is requested + :return: status + """ + method = Extract_ERA5_data.check_dirin.__name__ + + years = list(years) + # basic sanity checks + assert isinstance(dirin, str), "%{0}: Parsed dirin must be a string, but is of type '{1}'".format(method, + type(dirin)) + if not all(isinstance(yr, int) for yr in years): + raise ValueError("%{0}: Passed years must be a list of integers.".format(method)) + + if not os.path.isdir(dirin): + raise NotADirectoryError("%{0}: Input directory for ERA%-data '{1}' does not exist.".format(method, dirin)) + + # check if at least one ERA5-datafile is present + for year in years: + for month in months: + yr_str, mm_str = str(year), "{0:02d}".format(int(month)) + dirin_now = os.path.join(dirin, yr_str, mm_str) + print("dirin_base: {0}, dirin_now: {1}".format(dirin, dirin_now)) + f = glob.iglob(os.path.join(dirin_now, "{0}{1}*.grb".format(yr_str, mm_str))) + + _exhausted = object() + if next(f, _exhausted) is _exhausted: + raise FileNotFoundError("%{0}: Could not find any ERA5 file for {1}/{2} under '{3}'" + .format(method, yr_str, mm_str, dirin_now)) + return dirin, years + + @staticmethod + def get_months(months): + + method = Extract_ERA5_data.get_months.__name__ + + assert isinstance(months, get_args(str_or_List)), \ + "%{0}: months must be either a list of months or a (known) string.".format(method) + + if isinstance(months, list): + month_list = [int(x) for x in months] + if not np.all([isw(x, [1,12]) for x in month_list]): + for i, x in enumerate(month_list): + print(" Value {0:d}: {1}".format(i, x)) + raise ValueError("%{0}: Not all elements of months can serve as month-integers".format(method) + + "(only values between 1 and 12 can be passed).") + elif months == "DJF": + month_list = [1, 2, 12] + elif months == "MAM": + month_list = [3, 4, 5] + elif months == "JJA": + month_list = [6, 7, 8] + elif months == "SON": + month_list = [9, 10, 11] + elif months == "all": + month_list = list(np.range(1, 13)) + else: + raise ValueError("%{0}: months-argument cannot be converted to list of months (see doc-string)" + .format(method)) + + return month_list + + @staticmethod + def check_var_in_grib(gribfile: str, varnames: str_or_List, labort: bool = False): + """ + Checks if the desired varname exists in gribfile. Requires grib_ls. + :param gribfile: name of gribfile to be checked. + :param varnames: name of variable or list of variable names (must be the shortName!) + :param labort: flag if script breaks in case that variable is not found in gribfile + :return: status of check (True if all variables are found in gribfile) + """ + method = Extract_ERA5_data.check_var_in_grib.__name__ + + if not (os.path.isfile(gribfile) and gribfile.endswith("grb")): + raise FileNotFoundError("%{0}: File '{1}' does not exist or is not a grib-file.".format(method, gribfile)) + + if shutil.which("grib_ls") is None: + raise NotImplementedError("%{0}: Program 'grib_ls' is not available".format(method)) + + cmd = "grib_ls -p shortName:s {0} | tail -n +3 | head -n -3".format(gribfile) + varlist = str(sp.check_output(cmd, stderr=sp.STDOUT, shell=True)).lstrip("b'").rstrip("'").replace(" ", "") + varlist = varlist.split("\\n") + + stat = check_str_in_list(varlist, varnames, labort=labort) + + return stat + + + + + + + + + + + + diff --git a/video_prediction_tools/data_extraction/extract_weatherbench.py b/video_prediction_tools/data_extraction/extract_weatherbench.py new file mode 100644 index 0000000000000000000000000000000000000000..a5798c188d62b69d93ef22efa0253a80bf2b031a --- /dev/null +++ b/video_prediction_tools/data_extraction/extract_weatherbench.py @@ -0,0 +1,168 @@ +import os, glob +import logging + +from zipfile import ZipFile +from typing import Union +from pathlib import Path +import multiprocessing as mp +import itertools as it +import sys + +import pandas as pd +import xarray as xr + +from utils.dataset_utils import get_filename_template + +logging.basicConfig(level=logging.DEBUG) + +class ExtractWeatherbench: + max_years = list(range(1979, 2018)) + + def __init__( + self, + dirin: Path, + dirout: Path, + variables: list[dict], + years: Union[list[int], int], + months: list[int], + lat_range: tuple[float], + lon_range: tuple[float], + resolution: float, + ): + """ + This script performs several sanity checks and sets the class attributes accordingly. + :param dirin: directory to the ERA5 reanalysis data + :param dirout: directory where the output data will be saved + :param variables: controlled dictionary for getting variables from ERA5-dataset, e.g. {"t": {"ml": "p850"}} + :param years: list of year to to extract, -1 if all + :param months: list of months to extract + :param lat_range: domain of the latitude axis to extract + :param lon_range: domain of the longitude axis to extract + :param resolution: spacing on both lat, lon axis + """ + + self.dirin = dirin + self.dirout = dirout + + if years[0] == -1: + self.years = ExtractWeatherbench.max_years + else: + self.years = years + self.months = months + + # TODO handle special variables for resolution 5.625 (temperature_850, geopotential_500) + if resolution == 5.625: + for var in variables: + combined_name = f"{var['name']}_{var['lvl'][0]}" + if combined_name in {"temperature_850", "geopotential_500"}: + var["name"] = combined_name + + self.variables = variables + + self.lat_range = lat_range + self.lon_range = lon_range + + self.resolution = resolution + + + def __call__(self): + """ + Run extraction. + :return: - + """ + logging.info("start extraction") + + zip_files, data_files = self.get_data_files() + + # extract archives => netcdf files (maybe use tempfiles ?) + args = [ + (var_zip, file, self.dirout) + for var_zip, files in zip(zip_files, data_files) + for file in files + ] + with mp.Pool(20) as p: + p.starmap(ExtractWeatherbench.extract_task, args) + logging.info("finished extraction") + + # TODO: handle 3d data + + # load data + files = [self.dirout / file for data_file in data_files for file in data_file] + ds = xr.open_mfdataset(files, coords="minimal", compat="override") + logging.info("opened dataset") + ds.drop_vars("level") + logging.info("data loaded") + + # select months + ds = ds.isel(time=ds.time.dt.month.isin(self.months)) + + # select region + ds = ds.sel(lat=slice(*self.lat_range), lon=slice(*self.lon_range)) + logging.info("selected region") + + # split into monthly netcdf + year_month_idx = pd.MultiIndex.from_arrays( + [ds.time.dt.year.values, ds.time.dt.month.values] + ) + ds.coords["year_month"] = ("time", year_month_idx) + logging.info("constructed splitting-index") + + with mp.Pool(20) as p: + p.map( + ExtractWeatherbench.write_task, + zip(ds.groupby("year_month"), it.repeat(self.dirout)), + chunksize=5, + ) + logging.info("wrote output") + + @staticmethod + def extract_task(var_zip, file, dirout): + with ZipFile(var_zip, "r") as myzip: + myzip.extract(path=dirout, member=file) + + @staticmethod + def write_task(args): + (year_month, monthly_ds), dirout = args + year, month = year_month + logging.debug(f"{year}.{month:02d}: dropping index") + monthly_ds = monthly_ds.drop_vars("year_month") + try: + logging.debug(f"{year}.{month:02d}: writing to netCDF") + monthly_ds.to_netcdf(path=dirout / get_filename_template("weatherbench").format(year=year, month=month)) + except RuntimeError as e: + logging.error(f"runtime error for writing {year}.{month}\n{str(e)}") + logging.debug(f"{year}.{month:02d}: finished processing") + + def get_data_files(self): + """ + Get path to zip files and names of the yearly files within. + :return lists paths to zips of variables + """ + data_files = [] + zip_files = [] + res_str = f"{self.resolution}deg" + years = self.years + for var in self.variables: + var_dir = self.dirin / res_str / var["name"] + if not var_dir.exists(): + raise ValueError( + f"variable {var} is not available for resolution {res_str}" + ) + + zip_file = var_dir / f"{var['name']}_{res_str}.zip" + with ZipFile(zip_file, "r") as myzip: + names = myzip.namelist() + logging.debug(f"var:{var}\nyears:{years}\nnames:{names}") + if not all(any(str(year) in name for name in names) for year in years): + missing_years = list(filter(lambda year: any(str(year) in name for name in names), years)) + raise ValueError( + f"variable {var} is not available for years: {missing_years}" + ) + names = filter( + lambda name: any(str(year) in name for year in years), names + ) + + data_files.append(list(names)) + zip_files.append(zip_file) + + return zip_files, data_files diff --git a/video_prediction_tools/data_extraction/prepare_gzprcp_data.py b/video_prediction_tools/data_extraction/prepare_gzprcp_data.py new file mode 100644 index 0000000000000000000000000000000000000000..675713ece670f3ca8bdb46c9e25175dca0648183 --- /dev/null +++ b/video_prediction_tools/data_extraction/prepare_gzprcp_data.py @@ -0,0 +1,162 @@ +""" +Class and functions required for preprocessing guizhou prcp data from .nc to TFRecords +""" +__email__ = "y.ji@fz-juelich.de" +__author__ = "Yan Ji, Bing Gong" +__date__ = "2021_05_09" + +import datetime +import os +import numpy as np +import tensorflow as tf +import argparse +import netCDF4 as nc +from model_modules.video_prediction.datasets.gzprcp_data import GZprcp + + +class GZprcp2Tfrecords(GZprcp): + + def __init__(self, input_dir=None, target_year=2019,dest_dir=None, sequences_per_file=10): + """ + This class is used for converting .nc files to tfrecords + + :param input_dir: str, the path direcotry to the file of npz + :param dest_dir: the output directory to save TFrecords. + :param sequence_length: int, default is 40, the sequence length per sample + :param sequences_per_file:int, how many sequences/samples per tfrecord to be saved + """ + self.input_dir = input_dir + self.output_dir = dest_dir + self.target_year = target_year + os.makedirs(self.output_dir, exist_ok = True) + self.sequences_per_file = sequences_per_file + self.write_sequence_file() + + def __call__(self): + """ + steps to process nc file to tfrecords + :return: None + """ + self.read_nc_file() + self.save_nc_to_tfrecords() + + def read_nc_file(self): + data_temp = nc.Dataset(os.path.join(self.input_dir,str(self.target_year),"rainy","guizhou_prcp.nc")) + prcp_temp = np.transpose(data_temp['prcp'],[3,2,1,0]) + + ######### missing data + prcp_temp[np.isnan(prcp_temp)] = 0 + + self.data = prcp_temp + self.time = np.transpose(data_temp['time'],[2,1,0]) + print("data in gzprcp_test_Seq shape", self.data.shape) + return None + + def save_nc_to_tfrecords(self): + """ + Read the gzprcp data which is nc format, and save it to tfrecords files + The shape of data_nc is [number_samples,seq_length,height,width] + moving_mnst only has one channel + """ + idx = 0 + num_samples = self.data.shape[0] + if len(self.data.shape) == 4: + #add one dim to represent channel, then got [num_samples,seq_length,height,width,channel] + self.data = np.expand_dims(self.data, axis = 4) + elif len(self.data.shape) == 5: + pass + else: + #print('data shape nor match') + raise (f"The shape of input movning mnist npz file is {len(self.data.shape)} which is not either 4 or 5, please further check your data source!") + + self.data = self.data.astype(np.float32) + # self.data/= 255.0 # normalize RGB codes by dividing it to the max RGB value + + ############# normalization ############ + #k = 0.001 + #self.data = np.log(self.data+k)-np.log(k) # log + + ####################################### + + while idx < num_samples - self.sequences_per_file: + sequences = self.data[idx:idx+self.sequences_per_file, :, :, :, :] + + # use the first sequence time + t_start = self.time[idx:idx+self.sequences_per_file,0,4]+self.time[idx:idx+self.sequences_per_file,0,3]*100+self.time[idx:idx+self.sequences_per_file,0,2]*10000+self.time[idx:idx+self.sequences_per_file,0,1]*1000000+self.time[idx:idx+self.sequences_per_file,0,0]*100000000 + + # t_start = self.time[idx:idx+self.sequences_per_file,:,:] + # print('self.target_year: ',self.target_year) + output_fname = 'sequence_Y_{}_index_{}_to_{}.tfrecords'.format(self.target_year, idx, idx + self.sequences_per_file-1) + output_fname = os.path.join(self.output_dir, output_fname) + GZprcp2Tfrecords.save_tf_record(output_fname, sequences, t_start) + idx = idx + self.sequences_per_file + return None + + @staticmethod + def save_tf_record(output_fname, sequences, t_start_points): + with tf.python_io.TFRecordWriter(output_fname) as writer: + for i in range(np.array(sequences).shape[0]): + sequence = sequences[i, :, :, :, :] + + ############### time class ############## + # t_start = datetime.datetime(int(t_start_points[i,19,0]),int(t_start_points[i,19,1]),int(t_start_points[i,19,2]),int(t_start_points[i,19,3]),int(t_start_points[i,19,4])).strftime("%Y%m%d%H%M") + + t_start = int(t_start_points[i]) + ############### time class ############## + + num_frames = len(sequence) + height, width = sequence[0, :, :, 0].shape + encoded_sequence = np.array([list(image) for image in sequence]) + features = tf.train.Features(feature = { + 'sequence_length': _int64_feature(num_frames), + 'height': _int64_feature(height), + 'width': _int64_feature(width), + 'channels': _int64_feature(1), + 't_start': _int64_feature(t_start), + 'images/encoded': _floats_feature(encoded_sequence.flatten()), + }) + example = tf.train.Example(features = features) + writer.write(example.SerializeToString()) + + def write_sequence_file(self): + """ + Generate a txt file, with the numbers of sequences for each tfrecords file. + This is mainly used for calculting the number of samples for each epoch during training epoch + """ + + with open(os.path.join(self.output_dir, 'number_sequences.txt'), 'w') as seq_file: + seq_file.write("%d\n" % self.sequences_per_file) + + + + +def _bytes_feature(value): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + + +def _bytes_list_feature(values): + return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) + +def _floats_feature(value): + return tf.train.Feature(float_list=tf.train.FloatList(value=value)) + +def _int64_feature(value): + return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) + + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("-source_dir", type=str, help="The input directory that contains the gprcp_data nc file", default="/p/scratch/deepacf/ji4/extractedData/guizhou_prcpdata/prcp_squence/") + parser.add_argument("-target_year", type=int,default=2019) + parser.add_argument("-dest_dir", type=str,default="/p/scratch/deepacf/ji4/preprocessedData/gzprcp_data/tfrecords_seq_len") + parser.add_argument("-sequences_per_file", type=int, default=10) + args = parser.parse_args() + inst = GZprcp2Tfrecords(args.source_dir, args.target_year, args.dest_dir, args.sequences_per_file) + inst() + + +if __name__ == '__main__': + main() + + diff --git a/video_prediction_tools/data_preprocess/dataset_options.py b/video_prediction_tools/data_preprocess/dataset_options.py deleted file mode 100644 index 4a9141b979db610a528dcf0d200c859e75ae9a6f..0000000000000000000000000000000000000000 --- a/video_prediction_tools/data_preprocess/dataset_options.py +++ /dev/null @@ -1,23 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -def known_datasets(): - """ - An auxilary function - :return: dictionary of known datasets - """ - dataset_mappings = { - 'google_robot': 'GoogleRobotVideoDataset', - 'sv2p': 'SV2PVideoDataset', - 'softmotion': 'SoftmotionVideoDataset', - 'bair': 'SoftmotionVideoDataset', # alias of softmotion - 'kth': 'KTHVideoDataset', - 'ucf101': 'UCF101VideoDataset', - 'cartgripper': 'CartgripperVideoDataset', - "era5": "ERA5Dataset", - "moving_mnist": "MovingMnist" - # "era5_anomaly":"ERA5Dataset_v2_anomaly", - } - - return dataset_mappings diff --git a/video_prediction_tools/data_preprocess/era5_varmapping.json b/video_prediction_tools/data_preprocess/era5_varmapping.json deleted file mode 100644 index bc32d9060e915a597e596730ed6275754c1a2260..0000000000000000000000000000000000000000 --- a/video_prediction_tools/data_preprocess/era5_varmapping.json +++ /dev/null @@ -1,9 +0,0 @@ -{ -"surface": ["2t", "tcc","msl","10u","10v"], -"multi":{ - "t" : { - "pl": 85000 - } - - } -} diff --git a/video_prediction_tools/data_preprocess/era5_varmapping_template.json b/video_prediction_tools/data_preprocess/era5_varmapping_template.json deleted file mode 100644 index e62aba65c85f46cb85a81adbd07b3e31e9501361..0000000000000000000000000000000000000000 --- a/video_prediction_tools/data_preprocess/era5_varmapping_template.json +++ /dev/null @@ -1,22 +0,0 @@ -# NOTE: Please configure this JSON-files according your needs. Any line starting with # will be removed -# when editing is invoked from generate_runscript.py. -# -# Explanation: In the following, the mapping of known variable names from the ERA5-data (grib2-files) is defined -# The keys of the dictionary 'surface' (for 2D surface varibales) denote the variable names -# in the target netCDF-file while the values denote the name of the variable in the ERA5 grib file. -# For the dictionary 'multi' (used for 3D variables), the keys denote both, -# the variable name in the target netCDF-file and in the ERA5 grib file. -# The value of the 'pl'-key denotes the pressure level (in Pa) onto which the data is interpolated -# !!! This file should be only adapted if you are familiar with the ERA5 grib files!!! -{ -"surface":{ - ["2t", "tcc","msl","10u","10v"] - }, - -"multi":{ - "t" : { - "pl": 85000 - } - - } -} diff --git a/video_prediction_tools/data_preprocess/prepare_era5_data.py b/video_prediction_tools/data_preprocess/prepare_era5_data.py deleted file mode 100644 index e34bf8fe3da5f378142005a5ac18dc1c131112ba..0000000000000000000000000000000000000000 --- a/video_prediction_tools/data_preprocess/prepare_era5_data.py +++ /dev/null @@ -1,109 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -""" -Functions required for extracting ERA5 data. -""" -import os -import json -__email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong,Michael Langguth,Yanji" -__update_date__ = "2022-02-15" -# specify source and target directories - - -class ERA5DataExtraction(object): - - def __init__(self, year, job_name, src_dir, target_dir, varslist_json): - """ - Function to extract ERA5 data from slmet - args: - year : str, the target year to be processed "2017" - job_name :int from 1 to 12 correspoding to month - scr_dir :str, upper level of directory at year level - target_dir : str, upper level of directory at year level - varslist_json: str, the path to the varibale list that to be extracted from original grib file - """ - self.year = year - self.job_name = job_name - self.src_dir = src_dir - self.target_dir = target_dir - self.varslist_json = varslist_json - self.get_varslist() - - def get_varslist(self): - """ - Function that read varslist_path json file and get variable list - """ - with open(self.varslist_json) as f: - self.varslist = json.load(f) - - self.varslist_keys = list(self.varslist.keys()) - if not ("surface" in self.varslist_keys and "multi" in self.varslist_keys): - raise ValueError("Thie file '{0}' should have two keys : surface and multi".format(self.varslist_json)) - else: - self.varslist_surface = self.varslist["surface"] - self.varslist_multi = self.varslist["multi"] - self.varslist_multi_vars = self.varslist_multi.keys() - - - def prepare_era5_data_one_file(self, month, day, hour): # extract 2t,tcc,msl,t850,10u,10v - """ - Process one grib file from source directory (extract variables and interplolate variable) and save to output_directory - args: - month : str, the target month to be processed, e.g."01","02","03" ...,"12" - date : str, the target date to be processed e.g "01","02","03",..."31" - hour : str, the target hour to be processed e.g. "00","01",...,"23" - varslist_path: str, the path to variable list json file - output_path : str, the path to output directory - - """ - temp_path = os.path.join(self.target_dir, self.year) - os.makedirs(temp_path, exist_ok=True) - temp_path = os.path.join(self.target_dir, self.year, month) - os.makedirs(temp_path, exist_ok=True) - - for value in self.varslist_surface: - # surface variables - infile = os.path.join(self.src_dir, self.year, month, self.year+month+day+hour+'_sf.grb') - outfile_sf = os.path.join(self.target_dir, self.year, month, self.year+month+day+hour+'_'+value+'.nc') - os.system('cdo --eccodes -f nc copy -selname,%s %s %s' % (value, infile, outfile_sf)) - - - # multi-level variables - for var, pl_dic in self.varslist_multi.items(): - for pl, pl_value in pl_dic.items(): - infile = os.path.join(self.src_dir, self.year, month, self.year+month+day+hour+'_ml.grb') - outfile_sf_temp = os.path.join(self.target_dir, self.year, month, self.year+month+day+hour+'_'+var + - str(pl_value) + '.nc') - outfile_sf = os.path.join(self.target_dir, self.year, month, self.year+month+day+hour+'_'+var + - str(int(pl_value/100.)) + '.nc') - os.system('cdo -f nc copy -selname,%s -ml2pl,%d %s %s' % (var,pl_value,infile,outfile_sf_temp)) - os.system('cdo -chname,%s,%s %s %s' % (var, var+"_{0:d}".format(int(pl_value/100.)), outfile_sf_temp, outfile_sf)) - os.system('rm %s' % (outfile_sf_temp)) - # merge both variables - infile = os.path.join(self.target_dir, self.year, month, self.year+month+day+hour+'*.nc') - # change the output file name - outfile = os.path.join(self.target_dir, self.year, month, 'ecmwf_era5_'+self.year[2:]+month+day+hour+'.nc') - os.system('cdo merge %s %s' % (infile, outfile)) - os.system('rm %s' % (infile)) - - def process_era5_in_dir(self): - """ - Function that extract data at year level - """ - - dates = list(range(1,32)) - dates = ["{:02d}".format(d) for d in dates] - - hours = list(range(0,24)) - hours = ["{:02d}".format(h) for h in hours] - - print ("job_name",self.job_name) - for d in dates: - for h in hours: - self.prepare_era5_data_one_file(self.job_name, d, h) - # here the defeinition of the failure, success is placed 0=success / -1= fatal-failure / +1 = non-fatal -failure - worker_status = 0 - return worker_status diff --git a/video_prediction_tools/data_preprocess/prepare_moving_mnist_data.py b/video_prediction_tools/data_preprocess/prepare_moving_mnist_data.py deleted file mode 100644 index d58f705083c7735c7cfd45af524cf0fef9821deb..0000000000000000000000000000000000000000 --- a/video_prediction_tools/data_preprocess/prepare_moving_mnist_data.py +++ /dev/null @@ -1,133 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -""" -Class and functions required for preprocessing Moving mnist data from .npz to TFRecords -""" -__email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong, Karim Mache" -__date__ = "2021_05_04" - - -import os -import numpy as np -import tensorflow as tf -import argparse -from model_modules.video_prediction.datasets.moving_mnist import MovingMnist - - -class MovingMnist2Tfrecords(MovingMnist): - - def __init__(self, input_dir=None, dest_dir=None, sequences_per_file=128): - """ - This class is used for converting .npz files to tfrecords - - :param input_dir: str, the path direcotry to the file of npz - :param dest_dir: the output directory to save TFrecords. - :param sequence_length: int, default is 20, the sequence length per sample - :param sequences_per_file:int, how many sequences/samples per tfrecord to be saved - """ - self.input_dir = input_dir - self.output_dir = dest_dir - os.makedirs(self.output_dir, exist_ok = True) - self.sequences_per_file = sequences_per_file - self.write_sequence_file() - - - def __call__(self): - """ - steps to process npy file to tfrecords - :return: None - """ - self.read_npz_file() - self.save_npz_to_tfrecords() - - def read_npz_file(self): - self.data = np.load(os.path.join(self.input_dir, "mnist_test_seq.npy")) - print("data in minist_test_Seq shape", self.data.shape) - return None - - def save_npz_to_tfrecords(self): # Bing: original 128 - """ - Read the moving_mnst data which is npz format, and save it to tfrecords files - The shape of dat_npz is [seq_length,number_samples,height,width] - moving_mnst only has one channel - """ - idx = 0 - num_samples = self.data.shape[1] - if len(self.data.shape) == 4: - #add one dim to represent channel, then got [seq_length,num_samples,height,width,channel] - self.data = np.expand_dims(self.data, axis = 4) - elif len(self.data.shape) == 5: - pass - else: - raise (f"The shape of input movning mnist npz file is {len(self.data.shape)} which is not either 4 or 5, please further check your data source!") - - self.data = self.data.astype(np.float32) - self.data/= 255.0 # normalize RGB codes by dividing it to the max RGB value - while idx < num_samples - self.sequences_per_file: - sequences = self.data[:, idx:idx+self.sequences_per_file, :, :, :] - output_fname = 'sequence_index_{}_to_{}.tfrecords'.format(idx, idx + self.sequences_per_file-1) - output_fname = os.path.join(self.output_dir, output_fname) - MovingMnist2Tfrecords.save_tf_record(output_fname, sequences) - idx = idx + self.sequences_per_file - return None - - @staticmethod - def save_tf_record(output_fname, sequences): - with tf.python_io.TFRecordWriter(output_fname) as writer: - for i in range(np.array(sequences).shape[1] - 1): - sequence = sequences[:, i, :, :, :] - num_frames = len(sequence) - height, width = sequence[0, :, :, 0].shape - encoded_sequence = np.array([list(image) for image in sequence]) - features = tf.train.Features(feature = { - 'sequence_length': _int64_feature(num_frames), - 'height': _int64_feature(height), - 'width': _int64_feature(width), - 'channels': _int64_feature(1), - 'images/encoded': _floats_feature(encoded_sequence.flatten()), - }) - example = tf.train.Example(features = features) - writer.write(example.SerializeToString()) - - def write_sequence_file(self): - """ - Generate a txt file, with the numbers of sequences for each tfrecords file. - This is mainly used for calculting the number of samples for each epoch during training epoch - """ - - with open(os.path.join(self.output_dir, 'number_sequences.txt'), 'w') as seq_file: - seq_file.write("%d\n" % self.sequences_per_file) - - - - -def _bytes_feature(value): - return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) - - -def _bytes_list_feature(values): - return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) - -def _floats_feature(value): - return tf.train.Feature(float_list=tf.train.FloatList(value=value)) - -def _int64_feature(value): - return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) - - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("-input_dir", type=str, help="The input directory that contains the movning mnnist npz file", default="/p/largedata/datasets/moving-mnist/mnist_test_seq.npy") - parser.add_argument("-output_dir", type=str) - parser.add_argument("-sequences_per_file", type=int, default=2) - args = parser.parse_args() - inst = MovingMnist2Tfrecords(args.input_dir, args.output_dir, args.sequence_per_file) - inst() - - -if __name__ == '__main__': - main() diff --git a/video_prediction_tools/data_preprocess/preprocess_data_step2.py b/video_prediction_tools/data_preprocess/preprocess_data_step2.py deleted file mode 100644 index a197471b22f28cc1c3bae9fe29bd7279d2015cde..0000000000000000000000000000000000000000 --- a/video_prediction_tools/data_preprocess/preprocess_data_step2.py +++ /dev/null @@ -1,301 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -""" -Class and functions required for preprocessing ERA5 data (preprocessing substep 2) -""" -__email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong" -__date__ = "2020_12_29" - - -# import modules -import os -import glob -import pickle -import numpy as np -import pandas as pd -import json -import tensorflow as tf -from normalization import Norm_data -from metadata import MetaData -import datetime -from model_modules.video_prediction.datasets import ERA5Dataset - - -class ERA5Pkl2Tfrecords(ERA5Dataset): - def __init__(self, input_dir=None, dest_dir=None, sequence_length=20, sequences_per_file=128, norm="minmax"): - """ - This class is used for converting pkl files to tfrecords - args: - input_dir : str, the path to the PreprocessData directory which is parent directory of "Pickle" - and "tfrecords" files directiory. - sequence_length : int, default is 20, the sequen length per sample - sequences_per_file : int, how many sequences/samples per tfrecord to be saved - norm : str, normalization methods from Norm_data class ("minmax" or "znorm"; - default: "minmax") - """ - self.input_dir = input_dir - self.output_dir = dest_dir - # if the output_dir does not exist, then create it - os.makedirs(self.output_dir, exist_ok=True) - # get metadata,includes the var_in, image height, width etc. - self.metadata_fl = os.path.join(os.path.dirname(self.input_dir.rstrip("/")), "metadata.json") - self.get_metadata(MetaData(json_file=self.metadata_fl)) - # Get the data split informaiton - self.sequence_length = sequence_length - if norm == "minmax" or norm == "znorm": - self.norm = norm - else: - raise ValueError("norm should be either 'minmax' or 'znorm'") - self.sequences_per_file = sequences_per_file - self.write_sequence_file() - - def get_years_months(self): - """ - Get the months in the datasplit_config - Return : - two elements: each contains 1-dim array with the months set from data_split_config json file - """ - self.months = [] - self.years_months = [] - # search for pickle names with pattern 'X_{}.pkl'for months - self.years = [name for name in os.listdir(self.input_dir) if os.path.isdir(os.path.join(self.input_dir, name))] - # search for folder names from pickle folder to get years - patt = "X_*.pkl" - for year in self.years: - months_pkl_list = glob.glob(os.path.join(self.input_dir, year, patt)) - months_list = [int(m[-6:-4]) for m in months_pkl_list] - self.months.extend(months_list) - self.years_months.append(months_list) - return self.years, list(set(self.months)), self.years_months - - def get_stats_file(self): - """ - Get the corresponding statistics file - """ - method = ERA5Pkl2Tfrecords.get_stats_file.__name__ - - stats_file = os.path.join(os.path.dirname(self.input_dir), "statistics.json") - print("Opening json-file: {0}".format(stats_file)) - if os.path.isfile(stats_file): - with open(stats_file) as js_file: - self.stats = json.load(js_file) - else: - raise FileNotFoundError("%{0}: Could not find statistic file '{1}'".format(method, stats_file)) - - def get_metadata(self, md_instance): - """ - This function gets the meta data that has been generated in data_process_step1. Here, we aim to extract - the height and width information from it - vars_in : list(str), must be consistent with the list from DataPreprocessing_step1 - height : int, the height of the image - width : int, the width of the image - """ - method = ERA5Pkl2Tfrecords.get_metadata.__name__ - - if not isinstance(md_instance, MetaData): - raise ValueError("%{0}: md_instance-argument must be a MetaData class instance".format(method)) - - if not hasattr(self, "metadata_fl"): - raise ValueError("%{0}: MetaData class instance passed, but attribute metadata_fl is still missing.".format(method)) - - try: - self.height, self.width = md_instance.ny, md_instance.nx - self.vars_in = md_instance.variables - except: - raise IOError("%{0}: Could not retrieve all required information from metadata-file '{0}'" - .format(method, self.metadata_fl)) - - @staticmethod - def save_tf_record(output_fname, sequences, t_start_points): - """ - Save the sequences, and the corresponding timestamp start point to tfrecords - args: - output_frames : str, the file names of the output - sequences : list or array, the sequences want to be saved to tfrecords, - [sequences,seq_len,height,width,channels] - t_start_points : datetime type in the list, the first timestamp for each sequence - [seq_len,height,width, channel], the len of t_start_points is the same as sequences - """ - method = ERA5Pkl2Tfrecords.save_tf_record.__name__ - - sequences = np.array(sequences) - # sanity checks - assert sequences.shape[0] == len(t_start_points), "%{0}: Lengths of sequence differs from length of t_start_points.".format(method) - assert isinstance(t_start_points[0], datetime.datetime), "%{0}: Elements of t_start_points must be datetime-objects.".format(method) - - with tf.python_io.TFRecordWriter(output_fname) as writer: - for i in range(len(sequences)): - sequence = sequences[i] - - t_start = t_start_points[i].strftime("%Y%m%d%H") - num_frames = len(sequence) - height, width, channels = sequence[0].shape - encoded_sequence = np.array([list(image) for image in sequence]) - features = tf.train.Features(feature={ - 'sequence_length': _int64_feature(num_frames), - 'height': _int64_feature(height), - 'width': _int64_feature(width), - 'channels': _int64_feature(channels), - 't_start': _int64_feature(int(t_start)), - 'images/encoded': _floats_feature(encoded_sequence.flatten()), - }) - example = tf.train.Example(features=features) - writer.write(example.SerializeToString()) - - def init_norm_class(self): - """ - Get normalization data class - """ - method = ERA5Pkl2Tfrecords.init_norm_class.__name__ - - print("%{0}: Make use of default minmax-normalization.".format(method)) - # init normalization-instance - self.norm_cls = Norm_data(self.vars_in) - self.nvars = len(self.vars_in) - # get statistics file - self.get_stats_file() - # open statistics file and feed it to norm-instance - self.norm_cls.check_and_set_norm(self.stats, self.norm) - - def normalize_vars_per_seq(self, sequences): - """ - Normalize all the variables for the sequences - args: - sequences: list or array, is the sequences need to be saved to tfrecorcd. - The shape should be [sequences_per_file,seq_length,height,width,nvars] - Return: - the normalized sequences - """ - method = ERA5Pkl2Tfrecords.normalize_vars_per_seq.__name__ - - assert len(np.array(sequences).shape) == 5, "%{0}: Length of sequence array must be 5.".format(method) - # normalization should adpot the selected variables, here we used duplicated channel temperature variables - sequences = np.array(sequences) - # normalization - for i in range(self.nvars): - sequences[..., i] = self.norm_cls.norm_var(sequences[..., i], self.vars_in[i], self.norm) - return sequences - - def read_pkl_and_save_tfrecords(self, year, month): - """ - Read pickle files based on month, to process and save to tfrecords, - args: - year : int, the target year to save to tfrecord - month : int, the target month to save to tfrecord - """ - method = ERA5Pkl2Tfrecords.read_pkl_and_save_tfrecords.__name__ - - # Define the input_file based on the year and month - self.input_file_year = os.path.join(self.input_dir, str(year)) - input_file = os.path.join(self.input_file_year, 'X_{:02d}.pkl'.format(month)) - temp_input_file = os.path.join(self.input_file_year, 'T_{:02d}.pkl'.format(month)) - - self.init_norm_class() - sequences = [] - t_start_points = [] - sequence_iter = 0 - - try: - with open(input_file, "rb") as data_file: - X_train = pickle.load(data_file) - except: - raise IOError("%{0}: Could not read data from pickle-file '{1}'".format(method, input_file)) - - try: - with open(temp_input_file, "rb") as temp_file: - T_train = pickle.load(temp_file) - except: - raise IOError("%{0}: Could not read data from pickle-file '{1}'".format(method, temp_input_file)) - - # check to make sure that X_train and T_train have the same length - assert (len(X_train) == len(T_train)) - - X_possible_starts = [i for i in range(len(X_train) - self.sequence_length)] - for X_start in X_possible_starts: - X_end = X_start + self.sequence_length - seq = X_train[X_start:X_end, ...] - # recording the start point of the timestamps (already datetime-objects) - - t_start = ERA5Pkl2Tfrecords.ensure_datetime(T_train[X_start]) - seq = list(np.array(seq).reshape((self.sequence_length, self.height, self.width, self.nvars))) - if not sequences: - last_start_sequence_iter = sequence_iter - sequences.append(seq) - t_start_points.append(t_start) - sequence_iter += 1 - - if len(sequences) == self.sequences_per_file: - # normalize variables in the sequences - sequences = ERA5Pkl2Tfrecords.normalize_vars_per_seq(self, sequences) - output_fname = 'sequence_Y_{}_M_{}_{}_to_{}.tfrecords'.format(year, month, last_start_sequence_iter, - sequence_iter - 1) - output_fname = os.path.join(self.output_dir, output_fname) - # write to tfrecord - ERA5Pkl2Tfrecords.write_seq_to_tfrecord(output_fname, sequences, t_start_points) - t_start_points = [] - sequences = [] - print("%{0}: Finished processing of input file '{1}'".format(method, input_file)) - -# except FileNotFoundError as fnf_error: -# print(fnf_error) - - @staticmethod - def write_seq_to_tfrecord(output_fname, sequences, t_start_points): - """ - Function to check if the sequences has been processed. - If yes, the sequences are skipped, otherwise the sequences are saved to the output file - """ - method = ERA5Pkl2Tfrecords.write_seq_to_tfrecord.__name__ - - if os.path.isfile(output_fname): - print("%{0}: TFrecord-file {1} already exists. It is therefore skipped.".format(method, output_fname)) - else: - ERA5Pkl2Tfrecords.save_tf_record(output_fname, list(sequences), t_start_points) - - def write_sequence_file(self): - """ - Generate a txt file, with the numbers of sequences for each tfrecords file. - This is mainly used for calculting the number of samples for each epoch during training epoch - """ - - with open(os.path.join(self.output_dir, 'number_sequences.txt'), 'w') as seq_file: - seq_file.write("%d\n" % self.sequences_per_file) - - - @staticmethod - def ensure_datetime(date): - """ - Wrapper to return a datetime-object - """ - method = ERA5Pkl2Tfrecords.ensure_datetime.__name__ - - fmt = "%Y%m%d %H:%M" - if isinstance(date, datetime.datetime): - date_new = date - else: - try: - date_new=pd.to_datetime(date) - date_new=date_new.to_pydatetime() - except Exception as err: - print("%{0}: Could not handle input data {1} which is of type {2}.".format(method, date, type(date))) - raise err - - return date_new - -def _bytes_feature(value): - return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) - - -def _bytes_list_feature(values): - return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) - - -def _floats_feature(value): - return tf.train.Feature(float_list=tf.train.FloatList(value=value)) - - -def _int64_feature(value): - return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) diff --git a/video_prediction_tools/data_preprocess/process_netCDF_v2.py b/video_prediction_tools/data_preprocess/process_netCDF_v2.py deleted file mode 100644 index ae8c142134da53ee1c37c846d63f2ef8ea9340cc..0000000000000000000000000000000000000000 --- a/video_prediction_tools/data_preprocess/process_netCDF_v2.py +++ /dev/null @@ -1,243 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -""" -Code for processing staged ERA5 data, this is used for the DataPreprocessing step 1 of the workflow - -reviewed by Michael Langguth: 2021-03-21 -""" - -__email__ = "b.gong@fz-juelich.de" -__author__ = "Michael Langguth, Bing Gong, Scarlet Stadtler" - -import sys, os -import fnmatch -import pickle -import numpy as np -import xarray as xr -import datetime as dt -from netcdf_datahandling import GeoSubdomain -from statistics import Calc_data_stat - -class PreprocessNcToPkl(object): - - def __init__(self, src_dir, target_dir, year, job_id, target_dom, variables=("2t", "msl", "t_850")): - """ - Function to process data from netCDF file to pickle file - args: - src_dir : string, directory based on year where netCDF-files are stored to be processed - target_dir : base-directory where data is stored (files are stored under [target_dir]/pickle/[year]/) - job_id : job_id with range "01"-"12" (organized by PyStager) job_name also corresponds to the month - year : year of data to be processed - target_dom : class instance of GeoSubdomain which defines target domain - vars : variables to be processed - """ - # directory_to_process is month-based directory - self.directory_to_process=os.path.join(src_dir,str(year), str(job_id)) - # sanity checks - if int(job_id) > 12 or int(job_id) < 1 or not isinstance(job_id, str): - raise ValueError("job_name should be int type between 1 to 12") - - if not os.path.exists(self.directory_to_process): - raise NotADirectoryError("The directory_to_process '"+self.directory_to_process+"' does not exist") - - if not isinstance(target_dom, GeoSubdomain): - raise ValueError("target_dom must be a {0}-instance.".format(GeoSubdomain.__name__)) - - self.target_dir = os.path.join(target_dir, "pickle", str(year)) # preprocessed data to pickle-subdirectory - if not os.path.exists(self.target_dir): - os.mkdir(self.target_dir) - self.job_id = job_id - self.tar_dom = target_dom - # target file name needs to be saved - self.target_file = os.path.join(self.target_dir, 'X_' + str(self.job_id) + '.pkl') - self.vars = list(variables) - self.nvars = len(variables) - # attributes to set during call of class instance - self.imageList = None - self.stat_obj = None - self.data = None - - def __call__(self): - """ - Process the necCDF files in the month_base folder, store the variables of the images into list, - store temporal information to list and save them to pickle file - """ - if os.path.exists(self.target_file): - print(self.target_file, " file exists in the directory ", self.target_dir) - else: - print ("==========Processing files in directory {} =============== ".format(self.directory_to_process)) - self.imageList = self.get_images_list() - self.stat_obj = self.init_stat() - self.data = self.process_era5_data() - self.save_data_to_pickle() - self.save_stat_info() - # ------------------------------------------------------------------------------------------------------------------ - - def get_images_list(self, patt="ecmwf_era5_[0-9][0-9][0-9][0-9][0-9][0-9][0-9][0-9].nc"): - """ - Get the images list from the directory_to_process and sort them by date names - :param patt: The string pattern to filter for (otional) - :return filelist_filt: filtered list of files whose names match patt - """ - method = "{0} of class {1}".format(PreprocessNcToPkl.get_images_list.__name__, PreprocessNcToPkl.__name__) - - filelist_all = list(os.walk(self.directory_to_process, topdown = False))[-1][-1] - filelist_filt = fnmatch.filter(filelist_all, patt) - filelist_filt = sorted(filelist_filt) - # sanity check - if len(filelist_filt) == 0: - raise FileNotFoundError("%{0}: Could not find ERA5 netCDf-files under '{1}'" - .format(method, self.directory_to_process)) - - return filelist_filt - - # ------------------------------------------------------------------------------------------------------------------ - - def init_stat(self): - """ - Initializes the statistics instance - """ - method = "{0} of class {1}".format(PreprocessNcToPkl.init_stat.__name__, PreprocessNcToPkl.__name__) - # sanity check - if self.nvars <= 0: - raise AttributeError("%{0}: At least one variable must be tracked from the statistic object." - .format(method)) - - stat_obj = Calc_data_stat(self.nvars) - - return stat_obj - - # ------------------------------------------------------------------------------------------------------------------ - - def process_era5_data(self): - """ - Get the selected variables from netCDF file, and concanate all the variables from all the images in the - directiory_to_process into a list EU_stack_list - EU_stack_list dimension should be [numer_of_images,height, width,number_of_variables] - temporal_list is 1-dim list with timestamp data type, contains all the timestamps of netCDF files. - """ - method = "{0} of class {1}".format(PreprocessNcToPkl.process_era5_data.__name__, - PreprocessNcToPkl.__name__) - - tar_dom = self.tar_dom - for j, nc_fname in enumerate(self.imageList): - nc_fname_full = os.path.join(self.directory_to_process, nc_fname) - try: - data_curr = tar_dom.get_data_dom(nc_fname_full, self.vars) - if j == 0: - data_all = data_curr.copy(deep=True) - else: - data_all = xr.concat([data_all, data_curr], dim="time") - # feed statistics-instance (ML, 2021-03-21: This is kind of slow and could be optimized by using - # the data_all-dataset directly at the end. However, we keep the former approach for now.) - for i, var in enumerate(self.vars): - self.stat_obj.acc_stat_loc(i, np.squeeze(data_curr[var].values)) - except Exception as err: - print("%{0}: ERROR in job {1}: Could not handle data from netCDf-file '{2}'".format(method, self.job_id, - nc_fname_full)) - #print("%{0}: The related error is: {1}".format(method, str(err))) - raise err # would better catched by Pystager - - return data_all - - # ------------------------------------------------------------------------------------------------------------------ - - def save_data_to_pickle(self): - method = "{0} of class {1}".format(PreprocessNcToPkl.save_data_to_pickle.__name__, - PreprocessNcToPkl.__name__) - # saity check - if self.data is None: - raise AttributeError("%{0}: Class instance does not contain any data".format(method)) - - # construct pickle filenames - tar_fdata = os.path.join(self.target_dir, "X_{0}.pkl".format(self.job_id)) - tar_ftimes = os.path.join(self.target_dir, "T_{0}.pkl".format(self.job_id)) - - # write data to pickle-file - data_arr = self.convert_ds_to_np() - try: - with open(tar_fdata, "wb") as pkl_file: - pickle.dump(data_arr, pkl_file) - except Exception as err: - print("%{0}: ERROR in job {1}: could not write data to pickle-file '{2}'".format(method, self.job_id, - tar_fdata)) - # print("%{0}: The related error is: {1}".format(method, str(err))) - raise err # would better catched by Pystager - - # write times to pickle-file incl. conversion to datetime-object - try: - time = self.data.coords["time"] - time = np.array([dt.datetime.strptime(np.datetime_as_string(date, "m"), "%Y-%m-%dT%H:%M") for date in time]) - with open(tar_ftimes, "wb") as tpkl_file: - pickle.dump(time, tpkl_file) - except Exception as err: - print("%{0}: ERROR in job {1}: could not write times to pickle-file '{2}'".format(method, self.job_id, - tar_ftimes)) - # print("%{0}: The related error is: {1}".format(method, str(err))) - raise err # would better catched by Pystager - - # ------------------------------------------------------------------------------------------------------------------ - - def convert_ds_to_np(self): - """ - Converts given dataset to numpy-array that is ready to be pickled - :return data_arr: The numpy-array which can be pickled - """ - method = "{0} of class {1}".format(PreprocessNcToPkl.convert_ds_to_np.__name__, - PreprocessNcToPkl.__name__) - - if self.data is None: - raise AttributeError("%{0}: Class instance still does not contain any data.".format(method)) - - # write some attributes to local variables for conveninece - data = self.data - tar_dom = self.tar_dom - # roll data if domain crosses zero-meridian (to get spatially coherent data-arrays) - if tar_dom.lon_slices[0] > tar_dom.lon_slices[1]: - nroll_lon = tar_dom.nlon - tar_dom.lon_slices[0] - data = data.roll(lon=nroll_lon, roll_coords=True) - - # init resulting numpy-array... - print("data[self.vars[0]] shape is: ",data[self.vars[0]].shape) - dshape = list(np.shape(np.squeeze(data[self.vars[0]]))) + [self.nvars] - print("dshape is: ",dshape) - data_arr = np.full(dshape, np.nan) - print("data_arr shape is: ",data_arr.shape) - # ... and populate the data in it - for ivar, var in enumerate(self.vars): - data_arr[..., ivar] = np.squeeze(data[var].values) - - return data_arr - - - def save_stat_info(self): - """ - save the stat information to the target dir - """ - self.stat_obj.finalize_stat_loc(self.vars) - self.stat_obj.write_stat_json(self.target_dir, file_id=self.job_id) - - # -------------------------------------------- end of class -------------------------------------------------------- - - - - - - - - - - - - - - - - - - - - - diff --git a/video_prediction_tools/data_split/cv_test.json b/video_prediction_tools/data_split/cv_test.json deleted file mode 100644 index ccde8645ce129aac1ded0ed87a9d6feaefcbbd11..0000000000000000000000000000000000000000 --- a/video_prediction_tools/data_split/cv_test.json +++ /dev/null @@ -1,20 +0,0 @@ - - -{ - "train":{ - "2010":[1,2,3,4,5,6,7,8,9,10,11,12], - "2013":[1,2,3,4,5,6,7,8,9,10,11,12], - "2015":[1,2,3,4,5,6,7,8,9,10,11,12], - "2019":[1,2,3,4,5,6,7,8,9,10,11,12] - }, - "val": - { - "2017":[1,2,3,4,5,6,7,8,9,10,11,12] - }, - "test": - { - "2016":[1,2,3,4,5,6,7,8,9,10,11,12] - - } - } - diff --git a/video_prediction_tools/data_split/era5/datasplit.json b/video_prediction_tools/data_split/era5/datasplit.json index 5dafd53b7143c064beabf67b02c723805c4b52ef..a1d3fceb7dd7c3d754b9febce8aa2e6e60a0a1a9 100644 --- a/video_prediction_tools/data_split/era5/datasplit.json +++ b/video_prediction_tools/data_split/era5/datasplit.json @@ -1,14 +1,14 @@ { "train":{ - "2017":[1] + "2018":[1,2,3,4,5,6] }, "val": { - "2017":[2] + "2018":[7,8,9,10,11,12] }, "test": { - "2017":[3] + "2019":[1,2,3,4,5,6,7,8,9,10,11,12] } } diff --git a/video_prediction_tools/data_split/era5/datasplit_template.json b/video_prediction_tools/data_split/era5/datasplit_template.json index 2f38f782da2f64cf809ae9b0bd54ee05802049fb..3a039ce7b95befabcc5f71c1ddacd3669474e493 100644 --- a/video_prediction_tools/data_split/era5/datasplit_template.json +++ b/video_prediction_tools/data_split/era5/datasplit_template.json @@ -7,7 +7,6 @@ # such as np.range or similar here! { "train":{ - "2007":[1,2,3,4,5,6,7,8,9,10,11,12], "2008":[1,2,3,4,5,6,7,8,9,10,11,12], "2009":[1,2,3,4,5,6,7,8,9,10,11,12], "2010":[1,2,3,4,5,6,7,8,9,10,11,12], @@ -16,12 +15,12 @@ "2013":[1,2,3,4,5,6,7,8,9,10,11,12], "2014":[1,2,3,4,5,6,7,8,9,10,11,12], "2015":[1,2,3,4,5,6,7,8,9,10,11,12], - "2017":[1,2,3,4,5,6,7,8,9,10,11,12], - "2018":[1,2,3,4,5,6,7,8,9,10,11,12] + "2016":[1,2,3,4,5,6,7,8,9,10,11,12], + "2017":[1,2,3,4,5,6,7,8,9,10,11,12] }, "val": { - "2016":[1,2,3,4,5,6,7,8,9,10,11,12] + "2018":[1,2,3,4,5,6,7,8,9,10,11,12] }, "test": { diff --git a/video_prediction_tools/data_split/gzprcp/datasplit.json b/video_prediction_tools/data_split/gzprcp/datasplit.json new file mode 100644 index 0000000000000000000000000000000000000000..04b3137dd5a8473e695b859d8721cf7607f4cb81 --- /dev/null +++ b/video_prediction_tools/data_split/gzprcp/datasplit.json @@ -0,0 +1,5 @@ +{ +"train":[2015,2016,2017], +"val": [2018], +"test":[2019] + } diff --git a/video_prediction_tools/data_split/gzprcp_data/datasplit_template.json b/video_prediction_tools/data_split/gzprcp_data/datasplit_template.json new file mode 100644 index 0000000000000000000000000000000000000000..189051a88036f1c957f9d4d16d30c90b282f71bb --- /dev/null +++ b/video_prediction_tools/data_split/gzprcp_data/datasplit_template.json @@ -0,0 +1,20 @@ +# NOTE: This json-file should not be processed and simply serves as an exemplary file to configure the datasplit for kth human action dataset. +# If you would like to generate your own datasplit config file, you may copy this template and modify it to your personal needs. +# However, remember to remove any comment lines (starting with #) from your config-file then!!! +# +# Explanation: In the following, the data is splitted based on the index, each index has a list with two elements which are the start and end indices of the +# raw dataset +# Be aware that this is a prue data file, i.e. do not make use of any Python-functions such as np.range or similar here! +{ + "train":{ + "year":[2017] + }, + "val": + { + "year":[2018] + }, + "test": + { + "year":[2019] + } + } diff --git a/video_prediction_tools/data_split/cv1.json b/video_prediction_tools/data_split/test/cv1.json similarity index 100% rename from video_prediction_tools/data_split/cv1.json rename to video_prediction_tools/data_split/test/cv1.json diff --git a/video_prediction_tools/data_split/test/cv_test.json b/video_prediction_tools/data_split/test/cv_test.json new file mode 100644 index 0000000000000000000000000000000000000000..6d6a94cab8ac08474ed7765127550ccb3c3f1e0d --- /dev/null +++ b/video_prediction_tools/data_split/test/cv_test.json @@ -0,0 +1,17 @@ + + +{ + "train":{ + "2018":[1,2,3,4,5,6,7,8,9,10,11,12] + }, + "val": + { + "2018":[1,2,3,4,5,6,7,8,9,10,11,12] + }, + "test": + { + "2019":[1,2,3,4,5,6,7,8,9,10,11,12] + + } + } + diff --git a/video_prediction_tools/data_split/weatherbench/datasplit_template.json b/video_prediction_tools/data_split/weatherbench/datasplit_template.json new file mode 100644 index 0000000000000000000000000000000000000000..92dc669c5d90cc855aa1aaa367bbd0ef35d0b979 --- /dev/null +++ b/video_prediction_tools/data_split/weatherbench/datasplit_template.json @@ -0,0 +1,27 @@ +# NOTE: Please configure this JSON-files according your needs. Any line starting with # will be removed +# when editing is invoked from generate_runscript.py. +# +# Explanation: In the following, the data of the whole year 2015 and the first half of 2016 is used for training, +# while the data of 2017 and 2018 are used for validation and testing, respectively. +# Be aware that this is a pure data file, i.e. do not make use of any Python-functions +# such as np.range or similar here! +{ + "train":{ + "2008":[1,2,3,4,5,6,7,8,9,10,11,12], + "2009":[1,2,3,4,5,6,7,8,9,10,11,12], + "2010":[1,2,3,4,5,6,7,8,9,10,11,12], + "2011":[1,2,3,4,5,6,7,8,9,10,11,12], + "2012":[1,2,3,4,5,6,7,8,9,10,11,12], + "2013":[1,2,3,4,5,6,7,8,9,10,11,12], + "2014":[1,2,3,4,5,6,7,8,9,10,11,12], + "2015":[1,2,3,4,5,6,7,8,9,10,11,12] + }, + "val": + { + "2016":[1,2,3,4,5,6,7,8,9,10,11,12] + }, + "test": + { + "2017":[1,2,3,4,5,6,7,8,9,10,11,12] + } + } diff --git a/video_prediction_tools/data_split/weatherbench/weatherbench.json b/video_prediction_tools/data_split/weatherbench/weatherbench.json new file mode 100644 index 0000000000000000000000000000000000000000..b63794937d80257200c38eac74c6105a7b4d45ac --- /dev/null +++ b/video_prediction_tools/data_split/weatherbench/weatherbench.json @@ -0,0 +1,19 @@ +{ + "normalize": "ZScore", + "variables": [ + {"name": "temperature", "lvl": [850], "interpolation":"p"}, + {"name": "geopotential", "lvl": [500], "interpolation":"p"} + ], + "resolution": [ + {"nx": 32, "ny": 64}, + {"nx": 64, "ny": 128}, + {"nx": 128, "ny": 256} + ], + "years": [ + 1979,1980, + 1981,1982,1983,1984,1985,1986,1987,1988,1989,1990, + 1991,1992,1993,1994,1995,1996,1997,1998,1999,2000, + 2001,2002,2003,2004,2005,2006,2007,2008,2009,2010, + 2011,2012,2013,2014,2015,2016,2017,2018 + ] +} \ No newline at end of file diff --git a/video_prediction_tools/deprecated/create_env_zam347.sh b/video_prediction_tools/deprecated/create_env_zam347.sh deleted file mode 100755 index 711810a1d0ca17c6491722b8588753c007c3b7f8..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/create_env_zam347.sh +++ /dev/null @@ -1,32 +0,0 @@ -#!/usr/bin/env bash - - -if [[ ! -n "$1" ]]; then - echo "Provide the env name, which will be taken as folder name" - exit 1 -fi - -ENV_NAME=$1 -WORKING_DIR=/home/$USER/ambs/video_prediction_savp -ENV_SETUP_DIR=${WORKING_DIR}/env_setup -ENV_DIR=${WORKING_DIR}/${ENV_NAME} -unset PYTHONPATH -# Install additional Python packages. -python3 -m venv $ENV_DIR -source ${ENV_DIR}/bin/activate -pip3 install --upgrade pip -pip3 install -r ${ENV_SETUP_DIR}/requirements.txt -pip3 install mpi4py -pip3 install netCDF4 -pip3 install numpy -pip3 install h5py -pip3 install tensorflow-gpu==1.13.1 - -#export PYTHONPATH=/home/$USER/miniconda3/pkgs:$PYTHONPATH -export PYTHONPATH=${WORKING_DIR}/external_package/hickle/lib/python3.6/site-packages:$PYTHONPATH -export PYTHONPATH=${WORKING_DIR}:$PYTHONPATH -#export PYTHONPATH=${ENV_DIR}/lib/python3.6/site-packages:$PYTHONPATH -#export PYTHONPATH=/p/home/jusers/${USER}/juwels/.local/bin:$PYTHONPATH -export PYTHONPATH=${WORKING_DIR}/external_package/lpips-tensorflow:$PYTHONPATH - - diff --git a/video_prediction_tools/deprecated/datasets/Download_ERA5_Variable.py b/video_prediction_tools/deprecated/datasets/Download_ERA5_Variable.py deleted file mode 100644 index 419960b81d5bda6811e63621f760f847e78c71b8..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/datasets/Download_ERA5_Variable.py +++ /dev/null @@ -1,169 +0,0 @@ -#!/usr/bin/env python - -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -import cdsapi -import argparse - -''' -code example: -python /mnt/jiyan/ERA5/Download_ERA5_geopotential.py --year_select 2003 --month_select 01 02 03 ---date_select 01 02 31 --hour_select 00:00 01:00 23:00 --lon_start 70 --lon_end 140 --lat_start 15 ---lat_end 60 --data_format 'netcdf' --variable 'temperature' --pressure_level 500 850 --output_path /tmp/ -''' - -''' -copy the .cdsapirc to ~ -then pip install cdsapi -''' - -''' -year_select should be in - ['1979', '1980', '1981', - '1982', '1983', '1984', - '1985', '1986', '1987', - '1988', '1989', '1990', - '1991', '1992', '1993', - '1994', '1995', '1996', - '1997', '1998', '1999', - '2000', '2001', '2002', - '2003', '2004', '2005', - '2006', '2007', '2008', - '2009', '2010', '2011', - '2012', '2013', '2014', - '2015', '2016', '2017', - '2018', '2019', '2020',] -''' -year_select = [ '2010', '2011', '2012', '2013', '2014'] - -''' -month_select should be in - '01', '02', '03', - '04', '05', '06', - '07', '08', '09', - '10', '11', '12', -''' -month_select = ['01', '02', '03'] - -''' -date_select should be in - '01', '02', '03', - '04', '05', '06', - '07', '08', '09', - '10', '11', '12', - '13', '14', '15', - '16', '17', '18', - '19', '20', '21', - '22', '23', '24', - '25', '26', '27', - '28', '29', '30', - '31', -''' -date_select = [ '01', '02', '03'] - -''' -hour_select should be in - '00:00', '01:00', '02:00', - '03:00', '04:00', '05:00', - '06:00', '07:00', '08:00', - '09:00', '10:00', '11:00', - '12:00', '13:00', '14:00', - '15:00', '16:00', '17:00', - '18:00', '19:00', '20:00', - '21:00', '22:00', '23:00', -''' -hour_select = [ '00:00', '01:00', '02:00'] - -''' -[north(0~90), west(-180~0),south(0~-90), east(0~180)] -''' -area_select = [60, 70, 15, 140,] - -# 'grib' or 'netcdf' -data_format = 'netcdf' - -''' -variable: should be a single variable in - 'divergence', 'fraction_of_cloud_cover', 'geopotential', - 'ozone_mass_mixing_ratio', 'potential_vorticity', 'relative_humidity', - 'specific_cloud_ice_water_content', 'specific_cloud_liquid_water_content', 'specific_humidity', - 'specific_rain_water_content', 'specific_snow_water_content', 'temperature', - 'u_component_of_wind', 'v_component_of_wind', 'vertical_velocity', - 'vorticity' -''' -variable = 'temperature' # single varibale - -''' -pressure_level should be a single pressure level in - '1', '2', '3', - '5', '7', '10', - '20', '30', '50', - '70', '100', '125', - '150', '175', '200', - '225', '250', '300', - '350', '400', '450', - '500', '550', '600', - '650', '700', '750', - '775', '800', '825', - '850', '875', '900', - '925', '950', '975', - '1000', -''' -pressure_level = '500' - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--year_select", nargs='+', type=int, default=['1979', '1980', '1981','1982', '1983', '1984', - '1985', '1986', '1987', '1988', '1989', '1990', '1991', '1992', '1993', '1994', '1995', '1996', '1997', '1998', '1999', - '2000', '2001', '2002', '2003', '2004', '2005', '2006', '2007', '2008', '2009', '2010', '2011', '2012', '2013', '2014', - '2015', '2016', '2017', '2018', '2019', '2020'], help="the year list to be downloaded") - parser.add_argument("--month_select", nargs='+', type=int, default=['01', '02', '03', '04', '05', '06', - '07', '08', '09', '10', '11', '12'], help="the month list to be downloaded") - parser.add_argument("--date_select", nargs='+', type=int, default=['01', '02', '03', '04', '05', '06', - '07', '08', '09', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', - '25', '26', '27', '28', '29', '30', '31'], help="the date list to be downloaded") - parser.add_argument("--hour_select", nargs='+', type=str, default=['00:00', '01:00', '02:00', '03:00', '04:00', '05:00', - '06:00', '07:00', '08:00', '09:00', '10:00', '11:00', '12:00', '13:00', '14:00', '15:00', '16:00', '17:00', - '18:00', '19:00', '20:00', '21:00', '22:00', '23:00'], help="the hour list to be downloaded") - parser.add_argument("--lon_start", type=float, default=-180, help="the minimum longitude of the area") - parser.add_argument("--lon_end", type=float, default=180, help="the minimum longitude of the area") - parser.add_argument("--lat_start", type=float, default=-90, help="the maximum latitude of the area") - parser.add_argument("--lat_end", type=float, default=90, help="the minimum latitude of the area") - parser.add_argument("--output_path", type=str, required=True, help="the path to be saved") - parser.add_argument("--data_format", type=str, default='netcdf', help="the data format") - parser.add_argument("--variable", type=str, required=True, help="the variable to be downloaded") - parser.add_argument("--pressure_level", nargs='+', type=int, required=True, help="the variable to be downloaded") - args = parser.parse_args() - download_hourly_reanalysis_era5_pl_variables(year_select=args.year_select,month_select=args.month_select, - date_select=args.date_select,hour_select=args.hour_select,lon_start=args.lon_start,lon_end=args.lon_end, - lat_start=args.lat_start,lat_end=args.lat_end,data_format=args.data_format,variable=args.variable, - pressure_level=args.pressure_level,output_path=args.output_path) - - -def download_hourly_reanalysis_era5_pl_variables(year_select,month_select,date_select,hour_select, - lon_start,lon_end,lat_start,lat_end,data_format,variable,pressure_level,output_path): - if data_format=='netcdf': - fp = '.nc' - elif data_format=='grib': - fp = '.grib' - c = cdsapi.Client() - for iyear in year_select: - c.retrieve( - 'reanalysis-era5-pressure-levels', - { - 'product_type': 'reanalysis', - 'format': data_format, - 'variable': variable, - 'pressure_level': pressure_level, - 'year': str(iyear), - 'month': month_select, - 'day': date_select, - 'time': hour_select, - 'area': [lat_end, lon_start, lat_start, lon_end,], - }, - output_path+'ERA5_'+variable+'_'+str(iyear)+fp) - -if __name__ == '__main__': - main() diff --git a/video_prediction_tools/deprecated/datasets/extract_data/era5_dataset_v2_anomaly.py b/video_prediction_tools/deprecated/datasets/extract_data/era5_dataset_v2_anomaly.py deleted file mode 100644 index 2a30c9ce59040b1f8ce9b8efa66508fd18e9be1b..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/datasets/extract_data/era5_dataset_v2_anomaly.py +++ /dev/null @@ -1,278 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -import argparse -import glob -import itertools -import os -import pickle -import random -import re -import netCDF4 -import hickle as hkl -import numpy as np -import tensorflow as tf -import pandas as pd -from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset -from collections import OrderedDict -from tensorflow.contrib.training import HParams - -units = "hours since 2000-01-01 00:00:00" -calendar = "gregorian" - -class ERA5Dataset_v2_anomaly(VarLenFeatureVideoDataset): - def __init__(self, *args, **kwargs): - super(ERA5Dataset_v2_anomaly, self).__init__(*args, **kwargs) - from google.protobuf.json_format import MessageToDict - example = next(tf.python_io.tf_record_iterator(self.filenames[0])) - dict_message = MessageToDict(tf.train.Example.FromString(example)) - feature = dict_message['features']['feature'] - image_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['height', 'width', 'channels']) - self.state_like_names_and_shapes['images'] = 'images/encoded', image_shape - - def get_default_hparams_dict(self): - default_hparams = super(ERA5Dataset_v2_anomaly, self).get_default_hparams_dict() - hparams = dict( - context_frames=10, - sequence_length=20, - long_sequence_length=40, - force_time_shift=True, - shuffle_on_val=True, - use_state=False, - ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) - @property - def jpeg_encoding(self): - return False - - - def num_examples_per_epoch(self): - with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file: - sequence_lengths = sequence_lengths_file.readlines() - sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths] - return np.sum(np.array(sequence_lengths) >= self.hparams.sequence_length) - - - def filter(self, serialized_example): - return tf.convert_to_tensor(True) - - - - def make_dataset_v2(self, batch_size): - def parser(serialized_example): - seqs = OrderedDict() - keys_to_features = { - # 'width': tf.FixedLenFeature([], tf.int64), - # 'height': tf.FixedLenFeature([], tf.int64), - 'sequence_length': tf.FixedLenFeature([], tf.int64), - # 'channels': tf.FixedLenFeature([],tf.int64), - # 'images/encoded': tf.FixedLenFeature([], tf.string) - 'images/encoded': tf.VarLenFeature(tf.float32) - } - # for i in range(20): - # keys_to_features["frames/{:04d}".format(i)] = tf.FixedLenFeature((), tf.string) - parsed_features = tf.parse_single_example(serialized_example, keys_to_features) - seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"]) - images = [] - # for i in range(20): - # images.append(parsed_features["images/encoded"].values[i]) - # images = parsed_features["images/encoded"] - # images = tf.map_fn(lambda i: tf.image.decode_jpeg(parsed_features["images/encoded"].values[i]),offsets) - # seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"], '') - # Parse the string into an array of pixels corresponding to the image - # images = tf.decode_raw(parsed_features["images/encoded"],tf.int32) - - # images = seq - images = tf.reshape(seq, [20, 64, 64, 1], name = "reshape_new") - seqs["images"] = images - return seqs - filenames = self.filenames - filenames_mean = self.filenames_mean - shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val) - if shuffle: - random.shuffle(filenames) - dataset = tf.data.TFRecordDataset(filenames, buffer_size = 8 * 1024 * 1024) # todo: what is buffer_size - dataset = dataset.filter(self.filter) - #Bing: for Anomaly - dataset_mean = tf.data.TFRecordDataset(filenames_mean, buffer_size = 8 * 1024 * 1024) - dataset_mean = dataset_mean.filter(self.filter) - if shuffle: - dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size = 1024, count = self.num_epochs)) - dataset_mean = dataset_mean.apply(tf.contrib.data.shuffle_and_repeat(buffer_size = 1024, count = self.num_epochs)) - else: - dataset = dataset.repeat(self.num_epochs) - dataset_mean = dataset_mean.repeat(self.num_epochs) - - num_parallel_calls = None if shuffle else 1 - dataset = dataset.apply(tf.contrib.data.map_and_batch( - parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) - dataset_mean = dataset_mean.apply(tf.contrib.data.map_and_batch( - parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) - #dataset = dataset.map(parser) - # num_parallel_calls = None if shuffle else 1 # for reproducibility (e.g. sampled subclips from the test set) - # dataset = dataset.apply(tf.contrib.data.map_and_batch( - # _parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) # Bing: Parallel data mapping, num_parallel_calls normally depends on the hardware, however, normally should be equal to be the usalbe number of CPUs - dataset = dataset.prefetch(batch_size) # Bing: Take the data to buffer inorder to save the waiting time for GPU - dataset_mean = dataset_mean.prefetch(batch_size) - return dataset, dataset_mean - - def make_batch_v2(self, batch_size): - dataset, dataset_mean = self.make_dataset_v2(batch_size) - iterator = dataset.make_one_shot_iterator() - interator2 = dataset_mean.make_one_shot_iterator() - return iterator.get_next(), interator2.get_next() - - - def make_data_mean(self,batch_size): - pass - - - -def _bytes_feature(value): - return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) - - -def _bytes_list_feature(values): - return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) - -def _floats_feature(value): - return tf.train.Feature(float_list=tf.train.FloatList(value=value)) - -def _int64_feature(value): - return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) - - - -def save_tf_record(output_fname, sequences): - print('saving sequences to %s' % output_fname) - with tf.python_io.TFRecordWriter(output_fname) as writer: - for sequence in sequences: - num_frames = len(sequence) - height, width, channels = sequence[0].shape - encoded_sequence = np.array([list(image) for image in sequence]) - - features = tf.train.Features(feature={ - 'sequence_length': _int64_feature(num_frames), - 'height': _int64_feature(height), - 'width': _int64_feature(width), - 'channels': _int64_feature(channels), - 'images/encoded': _floats_feature(encoded_sequence.flatten()), - }) - example = tf.train.Example(features=features) - writer.write(example.SerializeToString()) - - -def extract_anomaly_one_pixel(X, X_timestamps,pixel): - print("Processing Pixel {}, {}".format(pixel[0],pixel[1])) - dates = [x.date() for x in X_timestamps] - df = pd.DataFrame(data = X[:, pixel[0], pixel[1]], index = dates) - df_mean = df.groupby(df.index).mean() - df2 = pd.merge(df, df_mean, left_index = True, right_index = True) - df2.columns = ["Real","Daily_mean"] - df2["Anomaly"] = df2["Real"] - df2["Daily_mean"] - daily_mean = df2["Daily_mean"].values - anomaly = df2["Anomaly"].values - return daily_mean, anomaly - -def extract_anomaly_all_pixels(X, X_timestamps): - #daily_mean, anomaly = extract_anomaly_one_pixel(X, X_timestamps, pixel = [0, 0]) - daily_mean_pixels = np.zeros((X.shape[0], X.shape[1], X.shape[2])) - anomaly_pixels = np.zeros((X.shape[0], X.shape[1], X.shape[2])) - #daily_mean_all_pixels = [extract_anomaly_one_pixel(X, X_timestamps, pixel = [i, j])[0] for i in range(X.shape[1]) for j in range(X.shape[2])] - #anomaly_all_pixels = [extract_anomaly_one_pixel(X, X_timestamps, pixel = [i, j])[1] for i in range(X.shape[1]) for j in range(X.shape[2])] - for i in range(X.shape[1]): - for j in range(X.shape[2]): - daily_mean, anomaly = extract_anomaly_one_pixel(X, X_timestamps, pixel = [i, j]) - daily_mean_pixels[:,i,j] = daily_mean - anomaly_pixels[:,i,j] = anomaly - return daily_mean_pixels, anomaly_pixels - - -def read_frames_and_save_tf_records(output_dir, input_dir, partition_name, N_seq, sequences_per_file=128):#Bing: original 128 - output_orig_dir = os.path.join(output_dir,partition_name + "_orig") - output_time_dir = os.path.join(output_dir,partition_name + "_time") - output_mean_dir = os.path.join(output_dir,partition_name + "_mean") - output_anomaly_dir = os.path.join(output_dir, partition_name ) - - - if not os.path.exists(output_orig_dir): os.mkdir(output_orig_dir) - if not os.path.exists(output_time_dir): os.mkdir(output_time_dir) - if not os.path.exists(output_mean_dir): os.mkdir(output_mean_dir) - if not os.path.exists(output_anomaly_dir): os.mkdir(output_anomaly_dir) - sequences = [] - sequences_time = [] - sequences_mean = [] - sequences_anomaly = [] - - sequence_iter = 0 - sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w') - X_train = hkl.load(os.path.join(input_dir, "X_" + partition_name + ".hkl")) - X_time = hkl.load(os.path.join(input_dir, "Time_time_" + partition_name + ".hkl")) - print ("X shape", X_train.shape) - X_timestamps = [netCDF4.num2date(x, units = units, calendar = calendar) for x in X_time] - - print("X_time example", X_time[:10]) - print("X_time after to date", X_timestamps[:10]) - daily_mean_all_pixels, anomaly_all_pixels = extract_anomaly_all_pixels(X_train, X_timestamps) - - X_possible_starts = [i for i in range(len(X_train) - N_seq)] - for X_start in X_possible_starts: - print("Interation", sequence_iter) - X_end = X_start + N_seq - #seq = X_train[X_start:X_end, :, :,:] - seq = X_train[X_start:X_end,:,:] - seq_time = X_time[X_start:X_end] - seq_mean = daily_mean_all_pixels[X_start:X_end,:,:] - seq_anomaly = anomaly_all_pixels[X_start:X_end,:,:] - #print("*****len of seq ***.{}".format(len(seq))) - seq = list(np.array(seq).reshape((len(seq), 64, 64, 1))) - seq_time = list(np.array(seq_time)) - seq_mean = list(np.array(seq_mean).reshape((len(seq_mean), 64, 64, 1))) - seq_anomaly = list(np.array(seq_anomaly).reshape((len(seq_anomaly), 64, 64, 1))) - if not sequences: - last_start_sequence_iter = sequence_iter - print("reading sequences starting at sequence %d" % sequence_iter) - sequences.append(seq) - sequences_time.append(seq_time) - sequences_mean.append(seq_mean) - sequences_anomaly.append(seq_anomaly) - sequence_iter += 1 - sequence_lengths_file.write("%d\n" % len(seq)) - - if len(sequences) == sequences_per_file: - output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1) - output_orig_fname = os.path.join(output_orig_dir, output_fname) - output_time_fname = os.path.join(output_time_dir,'sequence_{0}_to_{1}.hkl'.format(last_start_sequence_iter, sequence_iter - 1)) - output_mean_fname = os.path.join(output_mean_dir, output_fname) - output_anomaly_fname = os.path.join(output_anomaly_dir, output_fname) - - save_tf_record(output_orig_fname, sequences) - hkl.dump(sequences_time,output_time_fname ) - #save_tf_record(output_time_fname,sequences_time) - save_tf_record(output_mean_fname, sequences_mean) - save_tf_record(output_anomaly_fname, sequences_anomaly) - sequences[:] = [] - sequences_time[:] = [] - sequences_mean[:] = [] - sequences_anomaly[:] = [] - sequence_lengths_file.close() - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("input_dir", type=str, help="directory containing the processed directories ""boxing, handclapping, handwaving, ""jogging, running, walking") - parser.add_argument("output_dir", type=str) - # parser.add_argument("image_size_h", type=int) - # parser.add_argument("image_size_v", type = int) - args = parser.parse_args() - current_path = os.getcwd() - #input_dir = "/Users/gongbing/PycharmProjects/video_prediction/splits" - #output_dir = "/Users/gongbing/PycharmProjects/video_prediction/data/era5" - partition_names = ['train', 'val', 'test'] - for partition_name in partition_names: - read_frames_and_save_tf_records(output_dir=args.output_dir,input_dir=args.input_dir,partition_name=partition_name, N_seq=20) #Bing: Todo need check the N_seq - #ead_frames_and_save_tf_records(output_dir = output_dir, input_dir = input_dir,partition_name = partition_name, N_seq=20) #Bing: TODO: first try for N_seq is 10, but it met loading data issue. let's try 5 - -if __name__ == '__main__': - main() - diff --git a/video_prediction_tools/deprecated/datasets/extract_data/extract_era5.py b/video_prediction_tools/deprecated/datasets/extract_data/extract_era5.py deleted file mode 100755 index 527645626823a3d542e518eddf4e7fd4cf02c0b5..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/datasets/extract_data/extract_era5.py +++ /dev/null @@ -1,108 +0,0 @@ -#!/usr/bin/env python - -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -# -*- coding: utf-8 -*- -""" -Spyder Editor - -Author: YanJI -Email: y.ji@fz-juelich.de -Date: 25 Feb 2021 -""" - -import os - -os.system('module load GCC/9.3.0; module load ParaStationMPI/5.4.7-1; module load CDO/1.9.8') - -def mergetime_oneday(input_path,year,month,date,variable,pressure_level,output_path): - temp_path = os.path.join(output_path,year) - if not os.path.exists(temp_path): - os.mkdir(temp_path) - temp_path = os.path.join(output_path,year,month) - if not os.path.exists(temp_path): - os.mkdir(temp_path) - if variable == 'T2': - infilelist = year+month+date+'*_sf.grb' - outfile = year+month+date+'_'+variable+'.grb' - ori_path = os.path.join(input_path,year,month,infilelist) - out_path = os.path.join(output_path,year,month,outfile) - os.system('cdo expr,"var167" -mergetime %s %s' % (ori_path,out_path)) - if variable == 'MSL': - infilelist = year+month+date+'*_sf.grb' - outfile = year+month+date+'_'+variable+'.grb' - ori_path = os.path.join(input_path,year,month,infilelist) - out_path = os.path.join(output_path,year,month,outfile) - os.system('cdo expr,"var151" -mergetime %s %s' % (ori_path,out_path)) - if variable == 'gph500': - infilelist = year+month+date+'*_ml.grb' - outfile = year+month+date+'_'+variable+'.grb' - ori_path = os.path.join(input_path,year,month,infilelist) - out_path = os.path.join(output_path,year,month,outfile) - os.system('cdo expr,"z" -sellevel,%d -mergetime %s %s' % (pressure_level,ori_path,out_path)) # something wrong -- the variable 'z' only has one level - - -def mergevars_oneday(input_path,year,month,date,var1,var2,var3): - ori_path = os.path.join(input_path,year,month) - varfile1 = os.path.join(ori_path,year+month+date+'_'+var1+'.grb') - varfile2 = os.path.join(ori_path,year+month+date+'_'+var2+'.grb') - varfile3 = os.path.join(ori_path,year+month+date+'_'+var3+'.grb') - varfile1_nc = os.path.join(ori_path,year+month+date+'_'+var1+'.nc') - varfile2_nc = os.path.join(ori_path,year+month+date+'_'+var2+'.nc') - varfile3_nc = os.path.join(ori_path,year+month+date+'_'+var3+'.nc') - outfile = os.path.join(ori_path,year+month+date+'_'+var1+'_'+var2+'_'+var3+'.nc') - os.system('cdo -f nc copy %s %s' % (varfile1,varfile1_nc)) - os.system('cdo -f nc copy %s %s' % (varfile2,varfile2_nc)) - os.system('cdo -f nc copy %s %s' % (varfile3,varfile3_nc)) - os.system('cdo merge %s %s %s %s' % (varfile1_nc,varfile2_nc,varfile3_nc,outfile)) - - -def mergevars_onehour(input_path,year,month,date,hour,output_path): #extract t2,tcc,msl,t850,u10,v10 - temp_path = os.path.join(output_path,year) - if not os.path.exists(temp_path): - os.mkdir(temp_path) - temp_path = os.path.join(output_path,year,month) - if not os.path.exists(temp_path): - os.mkdir(temp_path) - # surface variables - infile = os.path.join(input_path,year,month,year+month+date+hour+'_sf.grb') - outfile = os.path.join(output_path,year,month,year+month+date+hour+'_sfvar.grb') - outfile_sf = os.path.join(output_path,year,month,year+month+date+hour+'_sfvar.nc') - os.system('cdo -merge -selname,"var167","var151" %s %s' % (infile,outfile)) # change the select vatiables - os.system('cdo -f nc copy %s %s' % (outfile,outfile_sf)) - os.system('rm %s' % outfile) - # multi-level variables - infile = os.path.join(input_path,year,month,year+month+date+hour+'_ml.grb') - outfile = os.path.join(output_path,year,month,year+month+date+hour+'_mlvar.grb') - outfile_ml = os.path.join(output_path,year,month,year+month+date+hour+'_mlvar.nc') - pl1 = 95; pl2 = 96 # seclect the levels - os.system('cdo -merge -selname,"t" -sellevel,%d,%d %s %s' % (pl1,pl2,infile,outfile)) # change the select vatiables and levels - os.system('cdo -f nc copy %s %s' % (outfile,outfile_ml)) - os.system('rm %s' % outfile) - # merge both variables - outfile = os.path.join(output_path,year,month,year+month+date+hour+'_era5.nc') # change the output file name - os.system('cdo merge %s %s %s' % (outfile_sf,outfile_ml,outfile)) - os.system('rm %s %s' % (outfile_sf,outfile_ml)) - -input_path='/p/fastdata/slmet/slmet111/met_data/ecmwf/era5/grib' -output_path='/p/home/jusers/ji4/juwels/ambs/era5_extra' - -#### test for one day -#y = '2009' -#m = '01' -#d = '01' -#mergetime_oneday(input_path,y,m,d,'T2',1,output_path) -#mergetime_oneday(input_path,y,m,d,'MSL',1,output_path) -#mergetime_oneday(input_path,y,m,d,'gph500',1,output_path) # something wrong -- the variable 'z' only has one level -#merge_oneday(output_path,y,m,d,'T2','MSL','gph500') - -#### looping -for y in ['2007', '2008']:#, '2009', '2010', '2011', '2012', '2013', '2014', '2015', '2016', '2017', '2018']: - for m in ['01', '02']:#, '03', '04', '05', '06', '07', '08', '09', '10', '11', '12']: - for d in ['01', '02']:#, '03', '04', '05', '06', '07', '08', '09', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31']: - for h in ['01', '02']:#, '03', '04', '05', '06', '07', '08', '09', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23']: - mergevars_onehour(input_path,y,m,d,h,output_path) - - diff --git a/video_prediction_tools/deprecated/helper/helper.py b/video_prediction_tools/deprecated/helper/helper.py deleted file mode 100644 index e85a4df8a8ad9811b2634983db5eac8bf0204b47..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/helper/helper.py +++ /dev/null @@ -1,53 +0,0 @@ -import logging -import time -from functools import wraps - -def logDecorator(fn,verbose=False): - @wraps(fn) - def wrapper(*args,**kwargs): - print("inside wrapper of log decorator function") - logger = logging.getLogger(fn.__name__) - # create a file handler - handler = logging.FileHandler("log.log") - logger.setLevel(logging.DEBUG if verbose else logging.INFO) - #create a console handler - ch = logging.StreamHandler() - logger.setLevel(logging.DEBUG if verbose else logging.INFO) - # create a logging format - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') - handler.setFormatter(formatter) - ch.setFormatter(formatter) - logger.addHandler(handler) - logger.addHandler(ch) - logger.info("Logging 1") - start = time.time() - results = fn(*args,**kwargs) - end = time.time() - logger.info("{} ran in {}s".format(fn.__name__, round(end - start, 2))) - return results - return wrapper - - -#logger = logging.getLogger(__name__) -# def set_logger(verbose=False): -# # Remove all handlers associated with the root logger object. -# for handler in logging.root.handlers[:]: -# logging.root.removeHandler(handler) -# logger = logging.getLogger(__name__) -# logger.propagate = False -# -# -# if not logger.handlers: -# logger.setLevel(logging.DEBUG if verbose else logging.INFO) -# formatter = "%(asctime)s - %(name)s - %(levelname)s - %(message)s" -# -# -# -# #再创建一个handler,用于输出到控制台 -# console_handler = logging.StreamHandler() -# console_handler.setLevel(logging.DEBUG if verbose else logging.INFO) -# console_handler.setFormatter(formatter) -# logger.handlers = [] -# logger.addHandler(console_handler) -# -# return logger \ No newline at end of file diff --git a/video_prediction_tools/deprecated/model_modules/sna_model.py b/video_prediction_tools/deprecated/model_modules/sna_model.py deleted file mode 100644 index 033f2de90a123f6cda6c2616e5115825182f5386..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/model_modules/sna_model.py +++ /dev/null @@ -1,667 +0,0 @@ -# Copyright 2016 The TensorFlow Authors All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Model architecture for predictive model, including CDNA, DNA, and STP.""" - -import itertools - -import numpy as np -import tensorflow as tf -import tensorflow.contrib.slim as slim -from tensorflow.contrib.layers.python import layers as tf_layers -from tensorflow.contrib.slim import add_arg_scope -from tensorflow.contrib.slim import layers - -from model_modules.video_prediction.models import VideoPredictionModel - - -# Amount to use when lower bounding tensors -RELU_SHIFT = 1e-12 - - -@add_arg_scope -def basic_conv_lstm_cell(inputs, - state, - num_channels, - filter_size=5, - forget_bias=1.0, - scope=None, - reuse=None, - ): - """Basic LSTM recurrent network cell, with 2D convolution connctions. - We add forget_bias (default: 1) to the biases of the forget gate in order to - reduce the scale of forgetting in the beginning of the training. - It does not allow cell clipping, a projection layer, and does not - use peep-hole connections: it is the basic baseline. - Args: - inputs: input Tensor, 4D, batch x height x width x channels. - state: state Tensor, 4D, batch x height x width x channels. - num_channels: the number of output channels in the layer. - filter_size: the shape of the each convolution filter. - forget_bias: the initial value of the forget biases. - scope: Optional scope for variable_scope. - reuse: whether or not the layer and the variables should be reused. - Returns: - a tuple of tensors representing output and the new state. - """ - if state is None: - state = tf.zeros(inputs.get_shape().as_list()[:3] + [2 * num_channels], name='init_state') - - with tf.variable_scope(scope, - 'BasicConvLstmCell', - [inputs, state], - reuse=reuse): - - inputs.get_shape().assert_has_rank(4) - state.get_shape().assert_has_rank(4) - c, h = tf.split(axis=3, num_or_size_splits=2, value=state) - inputs_h = tf.concat(values=[inputs, h], axis=3) - # Parameters of gates are concatenated into one conv for efficiency. - i_j_f_o = layers.conv2d(inputs_h, - 4 * num_channels, [filter_size, filter_size], - stride=1, - activation_fn=None, - scope='Gates', - ) - - # i = input_gate, j = new_input, f = forget_gate, o = output_gate - i, j, f, o = tf.split(value=i_j_f_o, num_or_size_splits=4, axis=3) - - new_c = c * tf.sigmoid(f + forget_bias) + tf.sigmoid(i) * tf.tanh(j) - new_h = tf.tanh(new_c) * tf.sigmoid(o) - - return new_h, tf.concat(values=[new_c, new_h], axis=3) - - -class Prediction_Model(object): - - def __init__(self, - images, - actions=None, - states=None, - iter_num=-1.0, - pix_distributions1=None, - pix_distributions2=None, - conf=None): - - self.pix_distributions1 = pix_distributions1 - self.pix_distributions2 = pix_distributions2 - self.actions = actions - self.iter_num = iter_num - self.conf = conf - self.images = images - - self.cdna, self.stp, self.dna = False, False, False - if self.conf['model'] == 'CDNA': - self.cdna = True - elif self.conf['model'] == 'DNA': - self.dna = True - elif self.conf['model'] == 'STP': - self.stp = True - if self.stp + self.cdna + self.dna != 1: - raise ValueError("More than one option selected!") - - self.k = conf['schedsamp_k'] - self.use_state = conf['use_state'] - self.num_masks = conf['num_masks'] - self.context_frames = conf['context_frames'] - - self.batch_size, self.img_height, self.img_width, self.color_channels = [int(i) for i in - images[0].get_shape()[0:4]] - self.lstm_func = basic_conv_lstm_cell - - # Generated robot states and images. - self.gen_states = [] - self.gen_images = [] - self.gen_masks = [] - - self.moved_images = [] - - self.moved_pix_distrib1 = [] - self.moved_pix_distrib2 = [] - - self.states = states - self.gen_distrib1 = [] - self.gen_distrib2 = [] - - self.trafos = [] - - def build(self): - - if 'kern_size' in self.conf.keys(): - KERN_SIZE = self.conf['kern_size'] - else: - KERN_SIZE = 5 - - batch_size, img_height, img_width, color_channels = self.images[0].get_shape()[0:4] - lstm_func = basic_conv_lstm_cell - - - if self.states != None: - current_state = self.states[0] - else: - current_state = None - - if self.actions == None: - self.actions = [None for _ in self.images] - - if self.k == -1: - feedself = True - else: - # Scheduled sampling: - # Calculate number of ground-truth frames to pass in. - num_ground_truth = tf.to_int32( - tf.round(tf.to_float(batch_size) * (self.k / (self.k + tf.exp(self.iter_num / self.k))))) - feedself = False - - # LSTM state sizes and states. - - if 'lstm_size' in self.conf: - lstm_size = self.conf['lstm_size'] - print('using lstm size', lstm_size) - else: - ngf = self.conf['ngf'] - lstm_size = np.int32(np.array([ngf, ngf * 2, ngf * 4, ngf * 2, ngf])) - - - lstm_state1, lstm_state2, lstm_state3, lstm_state4 = None, None, None, None - lstm_state5, lstm_state6, lstm_state7 = None, None, None - - for t, action in enumerate(self.actions): - print(t) - # Reuse variables after the first timestep. - reuse = bool(self.gen_images) - - done_warm_start = len(self.gen_images) > self.context_frames - 1 - with slim.arg_scope( - [lstm_func, slim.layers.conv2d, slim.layers.fully_connected, - tf_layers.layer_norm, slim.layers.conv2d_transpose], - reuse=reuse): - - if feedself and done_warm_start: - # Feed in generated image. - prev_image = self.gen_images[-1] # 64x64x6 - if self.pix_distributions1 != None: - prev_pix_distrib1 = self.gen_distrib1[-1] - if 'ndesig' in self.conf: - prev_pix_distrib2 = self.gen_distrib2[-1] - elif done_warm_start: - # Scheduled sampling - prev_image = scheduled_sample(self.images[t], self.gen_images[-1], batch_size, - num_ground_truth) - else: - # Always feed in ground_truth - prev_image = self.images[t] - if self.pix_distributions1 != None: - prev_pix_distrib1 = self.pix_distributions1[t] - if 'ndesig' in self.conf: - prev_pix_distrib2 = self.pix_distributions2[t] - if len(prev_pix_distrib1.get_shape()) == 3: - prev_pix_distrib1 = tf.expand_dims(prev_pix_distrib1, -1) - if 'ndesig' in self.conf: - prev_pix_distrib2 = tf.expand_dims(prev_pix_distrib2, -1) - - if 'refeed_firstimage' in self.conf: - assert self.conf['model']=='STP' - if t > 1: - input_image = self.images[1] - print('refeed with image 1') - else: - input_image = prev_image - else: - input_image = prev_image - - # Predicted state is always fed back in - if not 'ignore_state_action' in self.conf: - state_action = tf.concat(axis=1, values=[action, current_state]) - - enc0 = slim.layers.conv2d( #32x32x32 - input_image, - 32, [5, 5], - stride=2, - scope='scale1_conv1', - normalizer_fn=tf_layers.layer_norm, - normalizer_params={'scope': 'layer_norm1'}) - - hidden1, lstm_state1 = lstm_func( # 32x32x16 - enc0, lstm_state1, lstm_size[0], scope='state1') - hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2') - - enc1 = slim.layers.conv2d( # 16x16x16 - hidden1, hidden1.get_shape()[3], [3, 3], stride=2, scope='conv2') - - hidden3, lstm_state3 = lstm_func( #16x16x32 - enc1, lstm_state3, lstm_size[1], scope='state3') - hidden3 = tf_layers.layer_norm(hidden3, scope='layer_norm4') - - enc2 = slim.layers.conv2d( # 8x8x32 - hidden3, hidden3.get_shape()[3], [3, 3], stride=2, scope='conv3') - - if not 'ignore_state_action' in self.conf: - # Pass in state and action. - if 'ignore_state' in self.conf: - lowdim = action - print('ignoring state') - else: - lowdim = state_action - - smear = tf.reshape( - lowdim, - [int(batch_size), 1, 1, int(lowdim.get_shape()[1])]) - smear = tf.tile( - smear, [1, int(enc2.get_shape()[1]), int(enc2.get_shape()[2]), 1]) - - enc2 = tf.concat(axis=3, values=[enc2, smear]) - else: - print('ignoring states and actions') - - enc3 = slim.layers.conv2d( #8x8x32 - enc2, hidden3.get_shape()[3], [1, 1], stride=1, scope='conv4') - - hidden5, lstm_state5 = lstm_func( #8x8x64 - enc3, lstm_state5, lstm_size[2], scope='state5') - hidden5 = tf_layers.layer_norm(hidden5, scope='layer_norm6') - enc4 = slim.layers.conv2d_transpose( #16x16x64 - hidden5, hidden5.get_shape()[3], 3, stride=2, scope='convt1') - - hidden6, lstm_state6 = lstm_func( #16x16x32 - enc4, lstm_state6, lstm_size[3], scope='state6') - hidden6 = tf_layers.layer_norm(hidden6, scope='layer_norm7') - - if 'noskip' not in self.conf: - # Skip connection. - hidden6 = tf.concat(axis=3, values=[hidden6, enc1]) # both 16x16 - - enc5 = slim.layers.conv2d_transpose( #32x32x32 - hidden6, hidden6.get_shape()[3], 3, stride=2, scope='convt2') - hidden7, lstm_state7 = lstm_func( # 32x32x16 - enc5, lstm_state7, lstm_size[4], scope='state7') - hidden7 = tf_layers.layer_norm(hidden7, scope='layer_norm8') - - if not 'noskip' in self.conf: - # Skip connection. - hidden7 = tf.concat(axis=3, values=[hidden7, enc0]) # both 32x32 - - enc6 = slim.layers.conv2d_transpose( # 64x64x16 - hidden7, - hidden7.get_shape()[3], 3, stride=2, scope='convt3', - normalizer_fn=tf_layers.layer_norm, - normalizer_params={'scope': 'layer_norm9'}) - - if 'transform_from_firstimage' in self.conf: - prev_image = self.images[1] - if self.pix_distributions1 != None: - prev_pix_distrib1 = self.pix_distributions1[1] - prev_pix_distrib1 = tf.expand_dims(prev_pix_distrib1, -1) - print('transform from image 1') - - if self.conf['model'] == 'DNA': - # Using largest hidden state for predicting untied conv kernels. - trafo_input = slim.layers.conv2d_transpose( - enc6, KERN_SIZE ** 2, 1, stride=1, scope='convt4_cam2') - - transformed_l = [self.dna_transformation(prev_image, trafo_input, self.conf['kern_size'])] - if self.pix_distributions1 != None: - transf_distrib_ndesig1 = [self.dna_transformation(prev_pix_distrib1, trafo_input, KERN_SIZE)] - if 'ndesig' in self.conf: - transf_distrib_ndesig2 = [ - self.dna_transformation(prev_pix_distrib2, trafo_input, KERN_SIZE)] - - - extra_masks = 1 ## extra_masks = 2 is needed for running singleview_shifted!! - # print('using extra masks 2 because of single view shifted!!') - # extra_masks = 2 - - if self.conf['model'] == 'CDNA': - if 'gen_pix' in self.conf: - # Using largest hidden state for predicting a new image layer. - enc7 = slim.layers.conv2d_transpose( - enc6, color_channels, 1, stride=1, scope='convt4', activation_fn=None) - # This allows the network to also generate one image from scratch, - # which is useful when regions of the image become unoccluded. - transformed_l = [tf.nn.sigmoid(enc7)] - extra_masks = 2 - else: - transformed_l = [] - extra_masks = 1 - - cdna_input = tf.reshape(hidden5, [int(batch_size), -1]) - new_transformed, _ = self.cdna_transformation(prev_image, - cdna_input, - reuse_sc=reuse) - transformed_l += new_transformed - self.moved_images.append(transformed_l) - - if self.pix_distributions1 != None: - transf_distrib_ndesig1, _ = self.cdna_transformation(prev_pix_distrib1, - cdna_input, - reuse_sc=True) - self.moved_pix_distrib1.append(transf_distrib_ndesig1) - if 'ndesig' in self.conf: - transf_distrib_ndesig2, _ = self.cdna_transformation( - prev_pix_distrib2, - cdna_input, - reuse_sc=True) - - self.moved_pix_distrib2.append(transf_distrib_ndesig2) - - if self.conf['model'] == 'STP': - enc7 = slim.layers.conv2d_transpose(enc6, color_channels, 1, stride=1, scope='convt5', activation_fn= None) - # This allows the network to also generate one image from scratch, - # which is useful when regions of the image become unoccluded. - if 'gen_pix' in self.conf: - transformed_l = [tf.nn.sigmoid(enc7)] - extra_masks = 2 - else: - transformed_l = [] - extra_masks = 1 - - enc_stp = tf.reshape(hidden5, [int(batch_size), -1]) - stp_input = slim.layers.fully_connected( - enc_stp, 200, scope='fc_stp_cam2') - - # disabling capability to generete pixels - reuse_stp = None - if reuse: - reuse_stp = reuse - - # enable the generation of pixels: - transformed, trafo = self.stp_transformation(prev_image, stp_input, self.num_masks, reuse_stp, suffix='cam2') - transformed_l += transformed - - self.trafos.append(trafo) - self.moved_images.append(transformed_l) - - if self.pix_distributions1 != None: - transf_distrib_ndesig1, _ = self.stp_transformation(prev_pix_distrib1, stp_input, suffix='cam2', reuse=True) - self.moved_pix_distrib1.append(transf_distrib_ndesig1) - - if '1stimg_bckgd' in self.conf: - background = self.images[0] - print('using background from first image..') - else: background = prev_image - output, mask_list = self.fuse_trafos(enc6, background, - transformed_l, - scope='convt7_cam2', - extra_masks= extra_masks) - self.gen_images.append(output) - self.gen_masks.append(mask_list) - - if self.pix_distributions1!=None: - pix_distrib_output = self.fuse_pix_distrib(extra_masks, - mask_list, - self.pix_distributions1, - prev_pix_distrib1, - transf_distrib_ndesig1) - - self.gen_distrib1.append(pix_distrib_output) - if 'ndesig' in self.conf: - pix_distrib_output = self.fuse_pix_distrib(extra_masks, - mask_list, - self.pix_distributions2, - prev_pix_distrib2, - transf_distrib_ndesig2) - - self.gen_distrib2.append(pix_distrib_output) - - if int(current_state.get_shape()[1]) == 0: - current_state = tf.zeros_like(state_action) - else: - current_state = slim.layers.fully_connected( - state_action, - int(current_state.get_shape()[1]), - scope='state_pred', - activation_fn=None) - - self.gen_states.append(current_state) - - def fuse_trafos(self, enc6, background_image, transformed, scope, extra_masks): - masks = slim.layers.conv2d_transpose( - enc6, (self.conf['num_masks']+ extra_masks), 1, stride=1, activation_fn=None, scope=scope) - - img_height = 64 - img_width = 64 - num_masks = self.conf['num_masks'] - - if self.conf['model']=='DNA': - if num_masks != 1: - raise ValueError('Only one mask is supported for DNA model.') - - # the total number of masks is num_masks +extra_masks because of background and generated pixels! - masks = tf.reshape( - tf.nn.softmax(tf.reshape(masks, [-1, num_masks +extra_masks])), - [int(self.batch_size), int(img_height), int(img_width), num_masks +extra_masks]) - mask_list = tf.split(axis=3, num_or_size_splits=num_masks +extra_masks, value=masks) - output = mask_list[0] * background_image - - assert len(transformed) == len(mask_list[1:]) - for layer, mask in zip(transformed, mask_list[1:]): - output += layer * mask - - return output, mask_list - - def fuse_pix_distrib(self, extra_masks, mask_list, pix_distributions, prev_pix_distrib, - transf_distrib): - - if '1stimg_bckgd' in self.conf: - background_pix = pix_distributions[0] - if len(background_pix.get_shape()) == 3: - background_pix = tf.expand_dims(background_pix, -1) - print('using pix_distrib-background from first image..') - else: - background_pix = prev_pix_distrib - pix_distrib_output = mask_list[0] * background_pix - if 'gen_pix' in self.conf: - pix_distrib_output += mask_list[1] * prev_pix_distrib # assume pixels don't when image is generated from scratch - for i in range(self.num_masks): - pix_distrib_output += transf_distrib[i] * mask_list[i + extra_masks] - pix_distrib_output /= tf.reduce_sum(pix_distrib_output, axis=(1, 2), keepdims=True) - return pix_distrib_output - - ## Utility functions - def stp_transformation(self, prev_image, stp_input, num_masks, reuse= None, suffix = None): - """Apply spatial transformer predictor (STP) to previous image. - - Args: - prev_image: previous image to be transformed. - stp_input: hidden layer to be used for computing STN parameters. - num_masks: number of masks and hence the number of STP transformations. - Returns: - List of images transformed by the predicted STP parameters. - """ - # Only import spatial transformer if needed. - from spatial_transformer import transformer - - identity_params = tf.convert_to_tensor( - np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], np.float32)) - transformed = [] - trafos = [] - for i in range(num_masks): - params = slim.layers.fully_connected( - stp_input, 6, scope='stp_params' + str(i) + suffix, - activation_fn=None, - reuse= reuse) + identity_params - outsize = (prev_image.get_shape()[1], prev_image.get_shape()[2]) - transformed.append(transformer(prev_image, params, outsize)) - trafos.append(params) - - return transformed, trafos - - def dna_transformation(self, prev_image, dna_input, DNA_KERN_SIZE): - """Apply dynamic neural advection to previous image. - - Args: - prev_image: previous image to be transformed. - dna_input: hidden lyaer to be used for computing DNA transformation. - Returns: - List of images transformed by the predicted CDNA kernels. - """ - # Construct translated images. - pad_len = int(np.floor(DNA_KERN_SIZE / 2)) - prev_image_pad = tf.pad(prev_image, [[0, 0], [pad_len, pad_len], [pad_len, pad_len], [0, 0]]) - image_height = int(prev_image.get_shape()[1]) - image_width = int(prev_image.get_shape()[2]) - - inputs = [] - for xkern in range(DNA_KERN_SIZE): - for ykern in range(DNA_KERN_SIZE): - inputs.append( - tf.expand_dims( - tf.slice(prev_image_pad, [0, xkern, ykern, 0], - [-1, image_height, image_width, -1]), [3])) - inputs = tf.concat(axis=3, values=inputs) - - # Normalize channels to 1. - kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT - kernel = tf.expand_dims( - kernel / tf.reduce_sum( - kernel, [3], keepdims=True), [4]) - - return tf.reduce_sum(kernel * inputs, [3], keepdims=False) - - def cdna_transformation(self, prev_image, cdna_input, reuse_sc=None): - """Apply convolutional dynamic neural advection to previous image. - - Args: - prev_image: previous image to be transformed. - cdna_input: hidden lyaer to be used for computing CDNA kernels. - num_masks: the number of masks and hence the number of CDNA transformations. - color_channels: the number of color channels in the images. - Returns: - List of images transformed by the predicted CDNA kernels. - """ - batch_size = int(cdna_input.get_shape()[0]) - height = int(prev_image.get_shape()[1]) - width = int(prev_image.get_shape()[2]) - - DNA_KERN_SIZE = self.conf['kern_size'] - num_masks = self.conf['num_masks'] - color_channels = int(prev_image.get_shape()[3]) - - # Predict kernels using linear function of last hidden layer. - cdna_kerns = slim.layers.fully_connected( - cdna_input, - DNA_KERN_SIZE * DNA_KERN_SIZE * num_masks, - scope='cdna_params', - activation_fn=None, - reuse = reuse_sc) - - # Reshape and normalize. - cdna_kerns = tf.reshape( - cdna_kerns, [batch_size, DNA_KERN_SIZE, DNA_KERN_SIZE, 1, num_masks]) - cdna_kerns = tf.nn.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT - norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keepdims=True) - cdna_kerns /= norm_factor - cdna_kerns_summary = cdna_kerns - - # Transpose and reshape. - cdna_kerns = tf.transpose(cdna_kerns, [1, 2, 0, 4, 3]) - cdna_kerns = tf.reshape(cdna_kerns, [DNA_KERN_SIZE, DNA_KERN_SIZE, batch_size, num_masks]) - prev_image = tf.transpose(prev_image, [3, 1, 2, 0]) - - transformed = tf.nn.depthwise_conv2d(prev_image, cdna_kerns, [1, 1, 1, 1], 'SAME') - - # Transpose and reshape. - transformed = tf.reshape(transformed, [color_channels, height, width, batch_size, num_masks]) - transformed = tf.transpose(transformed, [3, 1, 2, 0, 4]) - transformed = tf.unstack(value=transformed, axis=-1) - - return transformed, cdna_kerns_summary - - -def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth): - """Sample batch with specified mix of ground truth and generated data_files points. - - Args: - ground_truth_x: tensor of ground-truth data_files points. - generated_x: tensor of generated data_files points. - batch_size: batch size - num_ground_truth: number of ground-truth examples to include in batch. - Returns: - New batch with num_ground_truth sampled from ground_truth_x and the rest - from generated_x. - """ - idx = tf.random_shuffle(tf.range(int(batch_size))) - ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth)) - generated_idx = tf.gather(idx, tf.range(num_ground_truth, int(batch_size))) - - ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx) - generated_examps = tf.gather(generated_x, generated_idx) - return tf.dynamic_stitch([ground_truth_idx, generated_idx], - [ground_truth_examps, generated_examps]) - - -def generator_fn(inputs, mode, hparams): - images = tf.unstack(inputs['images'], axis=0) - actions = tf.unstack(inputs['actions'], axis=0) - states = tf.unstack(inputs['states'], axis=0) - pix_distributions1 = tf.unstack(inputs['pix_distribs'], axis=0) if 'pix_distribs' in inputs else None - iter_num = tf.to_float(tf.train.get_or_create_global_step()) - - if isinstance(hparams.kernel_size, (tuple, list)): - kernel_height, kernel_width = hparams.kernel_size - assert kernel_height == kernel_width - kern_size = kernel_height - else: - kern_size = hparams.kernel_size - - schedule_sampling_k = hparams.schedule_sampling_k if mode == 'train' else -1 - conf = { - 'context_frames': hparams.context_frames, # of frames before predictions.' , - 'use_state': 1, # 'Whether or not to give the state+action to the model' , - 'ngf': hparams.ngf, - 'model': hparams.transformation.upper(), # 'model architecture to use - CDNA, DNA, or STP' , - 'num_masks': hparams.num_masks, # 'number of masks, usually 1 for DNA, 10 for CDNA, STN.' , - 'schedsamp_k': schedule_sampling_k, # 'The k hyperparameter for scheduled sampling -1 for no scheduled sampling.' , - 'kern_size': kern_size, # size of DNA kerns - } - if hparams.first_image_background: - conf['1stimg_bckgd'] = '' - if hparams.generate_scratch_image: - conf['gen_pix'] = '' - - m = Prediction_Model(images, actions, states, - pix_distributions1=pix_distributions1, - iter_num=iter_num, conf=conf) - m.build() - outputs = { - 'gen_images': tf.stack(m.gen_images, axis=0), - 'gen_states': tf.stack(m.gen_states, axis=0), - } - if 'pix_distribs' in inputs: - outputs['gen_pix_distribs'] = tf.stack(m.gen_distrib1, axis=0) - return outputs - - -class SNAVideoPredictionModel(VideoPredictionModel): - def __init__(self, *args, **kwargs): - super(SNAVideoPredictionModel, self).__init__( - generator_fn, *args, **kwargs) - - def get_default_hparams_dict(self): - default_hparams = super(SNAVideoPredictionModel, self).get_default_hparams_dict() - hparams = dict( - batch_size=32, - l1_weight=0.0, - l2_weight=1.0, - ngf=16, - transformation='cdna', - kernel_size=(5, 5), - num_masks=10, - first_image_background=True, - generate_scratch_image=True, - schedule_sampling_k=900.0, - ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) diff --git a/video_prediction_tools/deprecated/model_modules/sv2p_model.py b/video_prediction_tools/deprecated/model_modules/sv2p_model.py deleted file mode 100644 index f0ddd99cecec43348ef00f87162d6dbf51ed95aa..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/model_modules/sv2p_model.py +++ /dev/null @@ -1,677 +0,0 @@ -# Copyright 2016 The TensorFlow Authors All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -"""Model architecture for predictive model, including CDNA, DNA, and STP.""" - -import itertools -import numpy as np -import tensorflow as tf -import tensorflow.contrib.slim as slim -from tensorflow.contrib.layers.python import layers as tf_layers -from tensorflow.contrib.slim import add_arg_scope -from tensorflow.contrib.slim import layers -from model_modules.video_prediction.models import VideoPredictionModel - - -# Amount to use when lower bounding tensors -RELU_SHIFT = 1e-12 - -# kernel size for DNA and CDNA. -DNA_KERN_SIZE = 5 - - -def init_state(inputs, - state_shape, - state_initializer=tf.zeros_initializer(), - dtype=tf.float32): - """Helper function to create an initial state given inputs. - Args: - inputs: input Tensor, at least 2D, the first dimension being batch_size - state_shape: the shape of the state. - state_initializer: Initializer(shape, dtype) for state Tensor. - dtype: Optional dtype, needed when inputs is None. - Returns: - A tensors representing the initial state. - """ - if inputs is not None: - # Handle both the dynamic shape as well as the inferred shape. - inferred_batch_size = inputs.get_shape().with_rank_at_least(1)[0] - dtype = inputs.dtype - else: - inferred_batch_size = 0 - initial_state = state_initializer( - [inferred_batch_size] + state_shape, dtype=dtype) - return initial_state - - -@add_arg_scope -def basic_conv_lstm_cell(inputs, - state, - num_channels, - filter_size=5, - forget_bias=1.0, - scope=None, - reuse=None): - """Basic LSTM recurrent network cell, with 2D convolution connctions. - We add forget_bias (default: 1) to the biases of the forget gate in order to - reduce the scale of forgetting in the beginning of the training. - It does not allow cell clipping, a projection layer, and does not - use peep-hole connections: it is the basic baseline. - Args: - inputs: input Tensor, 4D, batch x height x width x channels. - state: state Tensor, 4D, batch x height x width x channels. - num_channels: the number of output channels in the layer. - filter_size: the shape of the each convolution filter. - forget_bias: the initial value of the forget biases. - scope: Optional scope for variable_scope. - reuse: whether or not the layer and the variables should be reused. - Returns: - a tuple of tensors representing output and the new state. - """ - spatial_size = inputs.get_shape()[1:3] - if state is None: - state = init_state(inputs, list(spatial_size) + [2 * num_channels]) - with tf.variable_scope(scope, - 'BasicConvLstmCell', - [inputs, state], - reuse=reuse): - inputs.get_shape().assert_has_rank(4) - state.get_shape().assert_has_rank(4) - c, h = tf.split(axis=3, num_or_size_splits=2, value=state) - inputs_h = tf.concat(axis=3, values=[inputs, h]) - # Parameters of gates are concatenated into one conv for efficiency. - i_j_f_o = layers.conv2d(inputs_h, - 4 * num_channels, [filter_size, filter_size], - stride=1, - activation_fn=None, - scope='Gates') - - # i = input_gate, j = new_input, f = forget_gate, o = output_gate - i, j, f, o = tf.split(axis=3, num_or_size_splits=4, value=i_j_f_o) - - new_c = c * tf.sigmoid(f + forget_bias) + tf.sigmoid(i) * tf.tanh(j) - new_h = tf.tanh(new_c) * tf.sigmoid(o) - - return new_h, tf.concat(axis=3, values=[new_c, new_h]) - - -def kl_divergence(mu, log_sigma): - """KL divergence of diagonal gaussian N(mu,exp(log_sigma)) and N(0,1). - - Args: - mu: mu parameter of the distribution. - log_sigma: log(sigma) parameter of the distribution. - Returns: - the KL loss. - """ - - return -.5 * tf.reduce_sum(1. + log_sigma - tf.square(mu) - tf.exp(log_sigma), - axis=1) - - -def construct_latent_tower(images, hparams): - """Builds convolutional latent tower for stochastic model. - - At training time this tower generates a latent distribution (mean and std) - conditioned on the entire video. This latent variable will be fed to the - main tower as an extra variable to be used for future frames prediction. - At inference time, the tower is disabled and only returns latents sampled - from N(0,1). - If the multi_latent flag is on, a different latent for every timestep would - be generated. - - Args: - images: tensor of ground truth image sequences - Returns: - latent_mean: predicted latent mean - latent_std: predicted latent standard deviation - latent_loss: loss of the latent twoer - samples: random samples sampled from standard guassian - """ - - with slim.arg_scope([slim.conv2d], reuse=False): - stacked_images = tf.concat(images, 3) - - latent_enc1 = slim.conv2d( - stacked_images, - 32, [3, 3], - stride=2, - scope='latent_conv1', - normalizer_fn=tf_layers.layer_norm, - normalizer_params={'scope': 'latent_norm1'}) - - latent_enc2 = slim.conv2d( - latent_enc1, - 64, [3, 3], - stride=2, - scope='latent_conv2', - normalizer_fn=tf_layers.layer_norm, - normalizer_params={'scope': 'latent_norm2'}) - - latent_enc3 = slim.conv2d( - latent_enc2, - 64, [3, 3], - stride=1, - scope='latent_conv3', - normalizer_fn=tf_layers.layer_norm, - normalizer_params={'scope': 'latent_norm3'}) - - latent_mean = slim.conv2d( - latent_enc3, - hparams.latent_channels, [3, 3], - stride=2, - activation_fn=None, - scope='latent_mean', - normalizer_fn=tf_layers.layer_norm, - normalizer_params={'scope': 'latent_norm_mean'}) - - latent_std = slim.conv2d( - latent_enc3, - hparams.latent_channels, [3, 3], - stride=2, - scope='latent_std', - normalizer_fn=tf_layers.layer_norm, - normalizer_params={'scope': 'latent_std_norm'}) - - latent_std += hparams.latent_std_min - - return latent_mean, latent_std - - -def encoder_fn(inputs, hparams): - images = tf.unstack(inputs['images'], axis=0) - latent_mean, latent_std = construct_latent_tower(images, hparams) - outputs = {'zs_mu_enc': latent_mean, 'zs_log_sigma_sq_enc': latent_std} - return outputs - - -def construct_model(images, - actions=None, - states=None, - outputs_enc=None, - iter_num=-1.0, - k=-1, - use_state=True, - num_masks=10, - stp=False, - cdna=True, - dna=False, - context_frames=2, - hparams=None): - """Build convolutional lstm video predictor using STP, CDNA, or DNA. - - Args: - images: tensor of ground truth image sequences - actions: tensor of action sequences - states: tensor of ground truth state sequences - iter_num: tensor of the current training iteration (for sched. sampling) - k: constant used for scheduled sampling. -1 to feed in own prediction. - use_state: True to include state and action in prediction - num_masks: the number of different pixel motion predictions (and - the number of masks for each of those predictions) - stp: True to use Spatial Transformer Predictor (STP) - cdna: True to use Convoluational Dynamic Neural Advection (CDNA) - dna: True to use Dynamic Neural Advection (DNA) - context_frames: number of ground truth frames to pass in before - feeding in own predictions - Returns: - gen_images: predicted future image frames - gen_states: predicted future states - - Raises: - ValueError: if more than one network option specified or more than 1 mask - specified for DNA model. - """ - # Each image is being used twice, in latent tower and main tower. - # This is to make sure we are using the *same* image for both, ... - # ... given how TF queues work. - images = [tf.identity(image) for image in images] - - if stp + cdna + dna != 1: - raise ValueError('More than one, or no network option specified.') - batch_size, img_height, img_width, color_channels = images[0].shape.as_list() - lstm_func = basic_conv_lstm_cell - - # Generated robot states and images. - gen_states, gen_images = [], [] - current_state = states[0] - - if k == -1: - feedself = True - else: - # Scheduled sampling: - # Calculate number of ground-truth frames to pass in. - num_ground_truth = tf.to_int32( - tf.round(tf.to_float(batch_size) * (k / (k + tf.exp(iter_num / k))))) - feedself = False - - # LSTM state sizes and states. - lstm_size = np.int32(np.array([32, 32, 64, 64, 128, 64, 32])) - lstm_state1, lstm_state2, lstm_state3, lstm_state4 = None, None, None, None - lstm_state5, lstm_state6, lstm_state7 = None, None, None - - # Latent tower - if hparams.stochastic_model: - latent_shape = [batch_size, img_height // 8, img_width // 8, hparams.latent_channels] - if outputs_enc is None: # equivalent to inference_time - latent_mean, latent_std = None, None - else: - latent_mean, latent_std = outputs_enc['zs_mu_enc'], outputs_enc['zs_log_sigma_sq_enc'] - assert latent_mean.shape.as_list() == latent_shape - - if hparams.multi_latent: - # timestep x batch_size x latent_size - samples = tf.random_normal( - [hparams.sequence_length - 1] + latent_shape, 0, 1, - dtype=tf.float32) - else: - # batch_size x latent_size - samples = tf.random_normal(latent_shape, 0, 1, dtype=tf.float32) - - # Main tower - for t in range(hparams.sequence_length - 1): - action = actions[t] - # Reuse variables after the first timestep. - reuse = bool(gen_images) - - done_warm_start = len(gen_images) > context_frames - 1 - with slim.arg_scope( - [lstm_func, slim.layers.conv2d, slim.layers.fully_connected, - tf_layers.layer_norm, slim.layers.conv2d_transpose], - reuse=reuse): - - if feedself and done_warm_start: - # Feed in generated image. - prev_image = gen_images[-1] - elif done_warm_start: - # Scheduled sampling - prev_image = scheduled_sample(images[t], gen_images[-1], batch_size, - num_ground_truth) - else: - # Always feed in ground_truth - prev_image = images[t] - - # Predicted state is always fed back in - state_action = tf.concat(axis=1, values=[action, current_state]) - - enc0 = slim.layers.conv2d( - prev_image, - 32, [5, 5], - stride=2, - scope='scale1_conv1', - normalizer_fn=tf_layers.layer_norm, - normalizer_params={'scope': 'layer_norm1'}) - - hidden1, lstm_state1 = lstm_func( - enc0, lstm_state1, lstm_size[0], scope='state1') - hidden1 = tf_layers.layer_norm(hidden1, scope='layer_norm2') - hidden2, lstm_state2 = lstm_func( - hidden1, lstm_state2, lstm_size[1], scope='state2') - hidden2 = tf_layers.layer_norm(hidden2, scope='layer_norm3') - enc1 = slim.layers.conv2d( - hidden2, hidden2.get_shape()[3], [3, 3], stride=2, scope='conv2') - - hidden3, lstm_state3 = lstm_func( - enc1, lstm_state3, lstm_size[2], scope='state3') - hidden3 = tf_layers.layer_norm(hidden3, scope='layer_norm4') - hidden4, lstm_state4 = lstm_func( - hidden3, lstm_state4, lstm_size[3], scope='state4') - hidden4 = tf_layers.layer_norm(hidden4, scope='layer_norm5') - enc2 = slim.layers.conv2d( - hidden4, hidden4.get_shape()[3], [3, 3], stride=2, scope='conv3') - - # Pass in state and action. - smear = tf.reshape( - state_action, - [int(batch_size), 1, 1, int(state_action.get_shape()[1])]) - smear = tf.tile( - smear, [1, int(enc2.get_shape()[1]), int(enc2.get_shape()[2]), 1]) - if use_state: - enc2 = tf.concat(axis=3, values=[enc2, smear]) - # Setup latent - if hparams.stochastic_model: - latent = samples - if hparams.multi_latent: - latent = samples[t] - if outputs_enc is not None: # equivalent to not inference_time - latent = tf.cond(iter_num < hparams.num_iterations_1st_stage, - lambda: tf.identity(latent), - lambda: latent_mean + tf.exp(latent_std / 2.0) * latent) - with tf.control_dependencies([latent]): - enc2 = tf.concat([enc2, latent], 3) - - enc3 = slim.layers.conv2d( - enc2, hidden4.get_shape()[3], [1, 1], stride=1, scope='conv4') - - hidden5, lstm_state5 = lstm_func( - enc3, lstm_state5, lstm_size[4], scope='state5') # last 8x8 - hidden5 = tf_layers.layer_norm(hidden5, scope='layer_norm6') - enc4 = slim.layers.conv2d_transpose( - hidden5, hidden5.get_shape()[3], 3, stride=2, scope='convt1') - - hidden6, lstm_state6 = lstm_func( - enc4, lstm_state6, lstm_size[5], scope='state6') # 16x16 - hidden6 = tf_layers.layer_norm(hidden6, scope='layer_norm7') - # Skip connection. - hidden6 = tf.concat(axis=3, values=[hidden6, enc1]) # both 16x16 - - enc5 = slim.layers.conv2d_transpose( - hidden6, hidden6.get_shape()[3], 3, stride=2, scope='convt2') - hidden7, lstm_state7 = lstm_func( - enc5, lstm_state7, lstm_size[6], scope='state7') # 32x32 - hidden7 = tf_layers.layer_norm(hidden7, scope='layer_norm8') - - # Skip connection. - hidden7 = tf.concat(axis=3, values=[hidden7, enc0]) # both 32x32 - - enc6 = slim.layers.conv2d_transpose( - hidden7, - hidden7.get_shape()[3], 3, stride=2, scope='convt3', activation_fn=None, - normalizer_fn=tf_layers.layer_norm, - normalizer_params={'scope': 'layer_norm9'}) - - if dna: - # Using largest hidden state for predicting untied conv kernels. - enc7 = slim.layers.conv2d_transpose( - enc6, DNA_KERN_SIZE ** 2, 1, stride=1, scope='convt4', activation_fn=None) - else: - # Using largest hidden state for predicting a new image layer. - enc7 = slim.layers.conv2d_transpose( - enc6, color_channels, 1, stride=1, scope='convt4', activation_fn=None) - # This allows the network to also generate one image from scratch, - # which is useful when regions of the image become unoccluded. - transformed = [tf.nn.sigmoid(enc7)] - - if stp: - stp_input0 = tf.reshape(hidden5, [int(batch_size), -1]) - stp_input1 = slim.layers.fully_connected( - stp_input0, 100, scope='fc_stp') - transformed += stp_transformation(prev_image, stp_input1, num_masks) - elif cdna: - cdna_input = tf.reshape(hidden5, [int(batch_size), -1]) - transformed += cdna_transformation(prev_image, cdna_input, num_masks, - int(color_channels)) - elif dna: - # Only one mask is supported (more should be unnecessary). - if num_masks != 1: - raise ValueError('Only one mask is supported for DNA model.') - transformed = [dna_transformation(prev_image, enc7)] - - masks = slim.layers.conv2d_transpose( - enc6, num_masks + 1, 1, stride=1, scope='convt7', activation_fn=None) - masks = tf.reshape( - tf.nn.softmax(tf.reshape(masks, [-1, num_masks + 1])), - [int(batch_size), int(img_height), int(img_width), num_masks + 1]) - mask_list = tf.split(axis=3, num_or_size_splits=num_masks + 1, value=masks) - output = mask_list[0] * prev_image - for layer, mask in zip(transformed, mask_list[1:]): - output += layer * mask - gen_images.append(output) - - current_state = slim.layers.fully_connected( - state_action, - int(current_state.get_shape()[1]), - scope='state_pred', - activation_fn=None) - gen_states.append(current_state) - - return gen_images, gen_states - - -## Utility functions -def stp_transformation(prev_image, stp_input, num_masks): - """Apply spatial transformer predictor (STP) to previous image. - - Args: - prev_image: previous image to be transformed. - stp_input: hidden layer to be used for computing STN parameters. - num_masks: number of masks and hence the number of STP transformations. - Returns: - List of images transformed by the predicted STP parameters. - """ - # Only import spatial transformer if needed. - from spatial_transformer import transformer - - identity_params = tf.convert_to_tensor( - np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0], np.float32)) - transformed = [] - for i in range(num_masks - 1): - params = slim.layers.fully_connected( - stp_input, 6, scope='stp_params' + str(i), - activation_fn=None) + identity_params - transformed.append(transformer(prev_image, params)) - - return transformed - - -def cdna_transformation(prev_image, cdna_input, num_masks, color_channels): - """Apply convolutional dynamic neural advection to previous image. - - Args: - prev_image: previous image to be transformed. - cdna_input: hidden lyaer to be used for computing CDNA kernels. - num_masks: the number of masks and hence the number of CDNA transformations. - color_channels: the number of color channels in the images. - Returns: - List of images transformed by the predicted CDNA kernels. - """ - batch_size = int(cdna_input.get_shape()[0]) - height = int(prev_image.get_shape()[1]) - width = int(prev_image.get_shape()[2]) - - # Predict kernels using linear function of last hidden layer. - cdna_kerns = slim.layers.fully_connected( - cdna_input, - DNA_KERN_SIZE * DNA_KERN_SIZE * num_masks, - scope='cdna_params', - activation_fn=None) - - # Reshape and normalize. - cdna_kerns = tf.reshape( - cdna_kerns, [batch_size, DNA_KERN_SIZE, DNA_KERN_SIZE, 1, num_masks]) - cdna_kerns = tf.nn.relu(cdna_kerns - RELU_SHIFT) + RELU_SHIFT - norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keepdims=True) - cdna_kerns /= norm_factor - - # Treat the color channel dimension as the batch dimension since the same - # transformation is applied to each color channel. - # Treat the batch dimension as the channel dimension so that - # depthwise_conv2d can apply a different transformation to each sample. - cdna_kerns = tf.transpose(cdna_kerns, [1, 2, 0, 4, 3]) - cdna_kerns = tf.reshape(cdna_kerns, [DNA_KERN_SIZE, DNA_KERN_SIZE, batch_size, num_masks]) - # Swap the batch and channel dimensions. - prev_image = tf.transpose(prev_image, [3, 1, 2, 0]) - - # Transform image. - transformed = tf.nn.depthwise_conv2d(prev_image, cdna_kerns, [1, 1, 1, 1], 'SAME') - - # Transpose the dimensions to where they belong. - transformed = tf.reshape(transformed, [color_channels, height, width, batch_size, num_masks]) - transformed = tf.transpose(transformed, [3, 1, 2, 0, 4]) - transformed = tf.unstack(transformed, axis=-1) - return transformed - - -def dna_transformation(prev_image, dna_input): - """Apply dynamic neural advection to previous image. - - Args: - prev_image: previous image to be transformed. - dna_input: hidden lyaer to be used for computing DNA transformation. - Returns: - List of images transformed by the predicted CDNA kernels. - """ - # Construct translated images. - prev_image_pad = tf.pad(prev_image, [[0, 0], [2, 2], [2, 2], [0, 0]]) - image_height = int(prev_image.get_shape()[1]) - image_width = int(prev_image.get_shape()[2]) - - inputs = [] - for xkern in range(DNA_KERN_SIZE): - for ykern in range(DNA_KERN_SIZE): - inputs.append( - tf.expand_dims( - tf.slice(prev_image_pad, [0, xkern, ykern, 0], - [-1, image_height, image_width, -1]), [3])) - inputs = tf.concat(axis=3, values=inputs) - - # Normalize channels to 1. - kernel = tf.nn.relu(dna_input - RELU_SHIFT) + RELU_SHIFT - kernel = tf.expand_dims( - kernel / tf.reduce_sum( - kernel, [3], keepdims=True), [4]) - return tf.reduce_sum(kernel * inputs, [3], keepdims=False) - - -def scheduled_sample(ground_truth_x, generated_x, batch_size, num_ground_truth): - """Sample batch with specified mix of ground truth and generated data points. - - Args: - ground_truth_x: tensor of ground-truth data points. - generated_x: tensor of generated data points. - batch_size: batch size - num_ground_truth: number of ground-truth examples to include in batch. - Returns: - New batch with num_ground_truth sampled from ground_truth_x and the rest - from generated_x. - """ - idx = tf.random_shuffle(tf.range(int(batch_size))) - ground_truth_idx = tf.gather(idx, tf.range(num_ground_truth)) - generated_idx = tf.gather(idx, tf.range(num_ground_truth, int(batch_size))) - - ground_truth_examps = tf.gather(ground_truth_x, ground_truth_idx) - generated_examps = tf.gather(generated_x, generated_idx) - return tf.dynamic_stitch([ground_truth_idx, generated_idx], - [ground_truth_examps, generated_examps]) - - -def generator_fn(inputs, mode, hparams): - images = tf.unstack(inputs['images'], axis=0) - batch_size = images[0].shape[0].value - action_dim, state_dim = 4, 3 - - # if not use_state, use zero actions and states to match reference implementation. - actions = inputs.get('actions', tf.zeros([hparams.sequence_length - 1, batch_size, action_dim])) - actions = tf.unstack(actions, axis=0) - states = inputs.get('states', tf.zeros([hparams.sequence_length, batch_size, state_dim])) - states = tf.unstack(states, axis=0) - iter_num = tf.to_float(tf.train.get_or_create_global_step()) - - schedule_sampling_k = hparams.schedule_sampling_k if mode == 'train' else -1 - gen_images, gen_states = \ - construct_model(images, - actions, - states, - outputs_enc=None, - iter_num=iter_num, - k=schedule_sampling_k, - use_state='actions' in inputs, - num_masks=hparams.num_masks, - cdna=hparams.transformation == 'cdna', - dna=hparams.transformation == 'dna', - stp=hparams.transformation == 'stp', - context_frames=hparams.context_frames, - hparams=hparams) - outputs = { - 'gen_images': tf.stack(gen_images, axis=0), - 'gen_states': tf.stack(gen_states, axis=0), - } - - if mode == 'train': - outputs_enc = encoder_fn(inputs, hparams) - tf.get_variable_scope().reuse_variables() - gen_images_enc, gen_states_enc = \ - construct_model(images, - actions, - states, - outputs_enc=outputs_enc, - iter_num=iter_num, - k=schedule_sampling_k, - use_state='actions' in inputs, - num_masks=hparams.num_masks, - cdna=hparams.transformation == 'cdna', - dna=hparams.transformation == 'dna', - stp=hparams.transformation == 'stp', - context_frames=hparams.context_frames, - hparams=hparams) - outputs.update({ - 'gen_images_enc': tf.stack(gen_images_enc, axis=0), - 'gen_states_enc': tf.stack(gen_states_enc, axis=0), - 'zs_mu_enc': outputs_enc['zs_mu_enc'], - 'zs_log_sigma_sq_enc': outputs_enc['zs_log_sigma_sq_enc'], - }) - return outputs - - -class SV2PVideoPredictionModel(VideoPredictionModel): - """ - Stochastic Variational Video Prediction - https://arxiv.org/abs/1710.11252 - - Reference implementation: - https://github.com/mbz/models/tree/master/research/video_prediction - """ - def __init__(self, *args, **kwargs): - super(SV2PVideoPredictionModel, self).__init__( - generator_fn, *args, ** kwargs) - self.deterministic = not self.hparams.stochastic_model - - def get_default_hparams_dict(self): - default_hparams = super(SV2PVideoPredictionModel, self).get_default_hparams_dict() - hparams = dict( - batch_size=32, - l1_weight=0.0, - l2_weight=1.0, - kl_weight=1e-3 * 10 * 8, # equivalent to latent_loss_multiplier up to a factor (see below) - transformation='cdna', - num_masks=10, - schedule_sampling_k=900.0, - stochastic_model=True, - multi_latent=False, - latent_std_min=-5.0, - latent_channels=1, - num_iterations_1st_stage=50000, - kl_anneal_steps=(100000, 120000), - max_steps=200000, - decay_steps=(0, 0), # do not decay the learning rate (doing so produces blurrier images) - ) - # Notes on equivalence with reference implementation: - # kl_weight is equivalent to latent_loss_multiplier * time_factor * factor, where - # time_factor = (sequence_length - context_frames) since the reference implementation - # doesn't normalize the kl divergence over time, and factor = (width // 8) / latent_channels - # since the reference implementation's kl_divergence sums over axis=1 instead of axis=-1. - # The paper and the reference implementation differs in the annealing of the kl_weight. - # Based on Figure 4 and the Appendix, it seems that in the 3rd stage, the kl_weight is - # linearly increased for the first 20k iterations of this stage. - return dict(itertools.chain(default_hparams.items(), hparams.items())) - - def parse_hparams(self, hparams_dict, hparams): - # backwards compatibility - deprecated_hparams_keys = [ - 'num_gpus', - 'acvideo_gan_weight', - 'acvideo_vae_gan_weight', - 'image_gan_weight', - 'image_vae_gan_weight', - 'tuple_gan_weight', - 'tuple_vae_gan_weight', - 'gan_weight', - 'vae_gan_weight', - 'video_gan_weight', - 'video_vae_gan_weight', - ] - for deprecated_hparams_key in deprecated_hparams_keys: - hparams_dict.pop(deprecated_hparams_key, None) - return super(SV2PVideoPredictionModel, self).parse_hparams(hparams_dict, hparams) diff --git a/video_prediction_tools/deprecated/model_modules/vanilla_GAN_model.py b/video_prediction_tools/deprecated/model_modules/vanilla_GAN_model.py deleted file mode 100644 index 74f3f16bf93a44981afc52155a8a07ab2ce61d92..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/model_modules/vanilla_GAN_model.py +++ /dev/null @@ -1,227 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - - -__email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong" -__date__ = "2021=01-05" - - -import tensorflow as tf - -from model_modules.video_prediction.models.model_helpers import set_and_check_pred_frames -from model_modules.video_prediction.layers import layer_def as ld -from tensorflow.contrib.training import HParams - -class VanillaGANVideoPredictionModel(object): - def __init__(self, mode='train', hparams_dict=None): - """ - This is class for building vanilla GAN architecture by using updated hparameters - args: - mode :str, "train" or "val", side note: mode may not be used in the convLSTM, but this will be a useful argument for the GAN-based model - hparams_dict: dict, the dictionary contains the hparaemters names and values - """ - self.mode = mode - self.hparams_dict = hparams_dict - self.hparams = self.parse_hparams() - self.learning_rate = self.hparams.lr - self.total_loss = None - self.context_frames = self.hparams.context_frames - self.sequence_length = self.hparams.sequence_length - self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) - self.max_epochs = self.hparams.max_epochs - self.loss_fun = self.hparams.loss_fun - self.batch_size = self.hparams.batch_size - self.z_dim = self.hparams.z_dim # dim of noise-vector - - def get_default_hparams(self): - return HParams(**self.get_default_hparams_dict()) - - def parse_hparams(self): - """ - Parse the hparams setting to ovoerride the default ones - """ - - parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) - return parsed_hparams - - - def get_default_hparams_dict(self): - """ - The function that contains default hparams - Returns: - A dict with the following hyperparameters. - context_frames : the number of ground-truth frames to pass in at start. - sequence_length : the number of frames in the video sequence - max_epochs : the number of epochs to train model - lr : learning rate - loss_fun : the loss function - """ - hparams = dict( - context_frames=12, - sequence_length=24, - max_epochs = 20, - batch_size = 40, - lr = 0.001, - loss_fun = "cross_entropy", - shuffle_on_val= True, - z_dim = 32, - ) - return hparams - - - def build_graph(self, x): - self.is_build_graph = False - self.x = x["images"] - self.width = self.x.shape.as_list()[3] - self.height = self.x.shape.as_list()[2] - self.channels = self.x.shape.as_list()[4] - self.n_samples = self.x.shape.as_list()[0] * self.x.shape.as_list()[1] - self.x = tf.reshape(self.x, [-1, self.height,self.width,self.channels]) - self.global_step = tf.train.get_or_create_global_step() - original_global_variables = tf.global_variables() - # Architecture - self.define_gan() - #This is the loss function (RMSE): - #This is loss function only for 1 channel (temperature RMSE) - if self.mode == "train": - self.D_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.D_loss, var_list=self.disc_vars) - with tf.control_dependencies([self.D_solver]): - self.G_solver = tf.train.AdamOptimizer(learning_rate = self.learning_rate).minimize(self.G_loss, var_list=self.gen_vars) - with tf.control_dependencies([self.G_solver]): - self.train_op = tf.assign_add(self.global_step,1) - else: - self.train_op = None - self.total_loss = self.G_loss + self.D_loss - self.outputs = {} - self.outputs["gen_images"] = self.gen_images - self.outputs["total_loss"] = self.total_loss - # Summary op - self.loss_summary = tf.summary.scalar("total_loss", self.G_loss + self.D_loss) - self.summary_op = tf.summary.merge_all() - global_variables = [var for var in tf.global_variables() if var not in original_global_variables] - self.saveable_variables = [self.global_step] + global_variables - self.is_build_graph = True - return self.is_build_graph - - def get_noise(self): - """ - Function for creating noise: Given the dimensions (n_samples,z_dim) - """ - self.noise = tf.random.uniform(minval=-1., maxval=1., shape=[self.n_samples, self.height, self.width, self.channels]) - return self.noise - - def get_generator_block(self,inputs,output_dim,idx): - - """ - Generator Block - Function for return a neural network of the generator given input and output dimensions - args: - inputs : the input vector - output_dim: the dimeniosn of output vector - return: - a generator neural network layer, with a convolutional layers followed by batch normalization and a relu activation - - """ - output1 = ld.conv_layer(inputs,kernel_size=2,stride=1,num_features=output_dim,idx=idx,activate="linear") - output2 = ld.bn_layers(output1,idx,is_training=False) - output3 = tf.nn.relu(output2) - return output3 - - - def generator(self,hidden_dim): - """ - Function to build up the generator architecture - args: - noise: a noise tensor with dimension (n_samples,height,width,channel) - hidden_dim: the inner dimension - """ - with tf.variable_scope("generator",reuse=tf.AUTO_REUSE): - layer1 = self.get_generator_block(self.noise,hidden_dim,1) - layer2 = self.get_generator_block(layer1,hidden_dim*2,2) - layer3 = self.get_generator_block(layer2,hidden_dim*4,3) - layer4 = self.get_generator_block(layer3,hidden_dim*8,4) - layer5 = ld.conv_layer(layer4,kernel_size=2,stride=1,num_features=self.channels,idx=5,activate="linear") - layer6 = tf.nn.sigmoid(layer5,name="6_conv") - print("layer6",layer6) - return layer6 - - - - def get_discriminator_block(self,inputs,output_dim,idx): - - """ - Distriminator block - Function for ruturn a neural network of a descriminator given input and output dimensions - - args: - inputs : the dimension of input vector - output_dim: the dimension of output dim - idx: : the index for the namespace of this block - Return: - a distriminator neural network layer with a convolutional layers followed by a leakyRelu function - """ - output1 = ld.conv_layer(inputs,2,stride=1,num_features=output_dim,idx=idx,activate="linear") - output2 = tf.nn.leaky_relu(output1) - return output2 - - - def discriminator(self,image,hidden_dim): - """ - Function that get discriminator architecture - """ - with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): - layer1 = self.get_discriminator_block(image,hidden_dim,idx=1) - layer2 = self.get_discriminator_block(layer1,hidden_dim*4,idx=2) - layer3 = self.get_discriminator_block(layer2,hidden_dim*2,idx=3) - layer4 = self.get_discriminator_block(layer3, self.channels,idx=4) - layer5 = tf.nn.sigmoid(layer4) - return layer5 - - - def get_disc_loss(self): - """ - Return the loss of discriminator given inputs - """ - - real_labels = tf.ones_like(self.D_real) - gen_labels = tf.zeros_like(self.D_fake) - D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_real, labels=real_labels)) - D_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake, labels=gen_labels)) - self.D_loss = D_loss_real + D_loss_fake - return self.D_loss - - - def get_gen_loss(self): - """ - Param: - num_images: the number of images the generator should produce, which is also the lenght of the real image - z_dim : the dimension of the noise vector, a scalar - Return the loss of generator given inputs - """ - real_labels = tf.ones_like(self.gen_images) - self.G_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_fake, labels=real_labels)) - return self.G_loss - - def get_vars(self): - """ - Get trainable variables from discriminator and generator - """ - self.disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")] - self.gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("generator")] - - - - def define_gan(self): - """ - Define gan architectures - """ - self.noise = self.get_noise() - self.gen_images = self.generator(hidden_dim=8) - self.D_real = self.discriminator(self.x,hidden_dim=8) - self.D_fake = self.discriminator(self.gen_images,hidden_dim=8) - self.get_gen_loss() - self.get_disc_loss() - self.get_vars() - diff --git a/video_prediction_tools/deprecated/modules_postprocess.sh b/video_prediction_tools/deprecated/modules_postprocess.sh deleted file mode 100755 index e424b4cf6b77db7c2b9e573e55977e35818a8b38..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/modules_postprocess.sh +++ /dev/null @@ -1,36 +0,0 @@ -#!/usr/bin/env bash - -# __author__ = Bing Gong, Michael Langguth -# __date__ = '2021_01_04' - -# This script loads the required modules for the postprocessing workflow step of AMBS on Juwels and HDF-ML. -# Note that some other packages have to be installed into the virtual environment since not all Python-packages -# are available via the software stack (see create_env.sh and requirements.txt). - -HOST_NAME=`hostname` - -echo "Start loading modules on ${HOST_NAME}..." -echo "modules_postprocess.sh is subject to: " -echo "* visualize_postprocess_era5_<exp_id>.sh" - -module purge -module use $OTHERSTAGES -ml Stages/2019a -ml GCC/8.3.0 -ml GCCcore/.8.3.0 -ml ParaStationMPI/5.2.2-1 -ml mpi4py/3.0.1-Python-3.6.8 -# serialized version of HDF5 is used since only this version is compatible with TensorFlow/1.13.1-GPU-Python-3.6.8 -ml h5py/2.9.0-serial-Python-3.6.8 -ml TensorFlow/1.13.1-GPU-Python-3.6.8 -ml cuDNN/7.5.1.10-CUDA-10.1.105 -ml SciPy-Stack/2019a-Python-3.6.8 -ml scikit/2019a-Python-3.6.8 -ml netcdf4-python/1.5.0.1-Python-3.6.8 -ml basemap/1.2.0-Python-3.6.8 - -# clean up if triggered via script argument -if [[ $1 == purge ]]; then - echo "Purge all modules after loading them..." - module --force purge -fi diff --git a/video_prediction_tools/deprecated/modules_train.sh b/video_prediction_tools/deprecated/modules_train.sh deleted file mode 100755 index babc7489e4613254e9fbe7fa05ba0e755a702577..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/modules_train.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/usr/bin/env bash - -# __author__ = Bing Gong, Michael Langguth -# __date__ = '2021_01_15' - -# This script loads the required modules for the training workflow step of AMBS on Juwels, Juwels Booster and HDF-ML. -# Note that some other packages have to be installed into the virtual environment since not all Python-packages -# are available via the software stack (see create_env.sh and requirements.txt). - -HOST_NAME=`hostname` - -echo "Start loading modules on ${HOST_NAME}..." -echo "modules_train.sh is subject to: " -echo "* preprocess_data_era5_step2.sh" -echo "* train_model_era5_[booster_]<exp_id>.sh" - -module use $OTHERSTAGES -if [[ "${HOST_NAME}" == jwlogin2[1-4]* || "${HOST_NAME}" == jwb* ]]; then - ml Stages/2020 - ml UCX/1.8.1 - ml GCC/9.3.0 - ml OpenMPI/4.1.0rc1 -else - ml Stages/2019a - ml GCC/8.3.0 - ml ParaStationMPI/5.4.4-1 - ml mpi4py/3.0.1-Python-3.6.8 - ml h5py/2.9.0-serial-Python-3.6.8 - ml TensorFlow/1.13.1-GPU-Python-3.6.8 - ml cuDNN/7.5.1.10-CUDA-10.1.105 - ml SciPy-Stack/2019a-Python-3.6.8 - ml scikit/2019a-Python-3.6.8 - ml netcdf4-python/1.5.0.1-Python-3.6.8 - # Horovod is excluded as long as parallelization does not work properly - # Note: Horovod/0.16.2 requires MVAPICH2 which is incomaptible with netcdf4-python - #ml MVAPICH2/2.3.3-GDR # - #ml Horovod/0.16.2-GPU-Python-3.6.8 -fi - -# clean up if triggered via script argument -if [[ $1 == purge ]]; then - echo "Purge all modules after loading them..." - module --force purge -fi diff --git a/video_prediction_tools/deprecated/pretrained_models/download_model.sh b/video_prediction_tools/deprecated/pretrained_models/download_model.sh deleted file mode 100644 index fdffd762b334f709edc9369b1d7c69268b2c43b4..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/pretrained_models/download_model.sh +++ /dev/null @@ -1,73 +0,0 @@ -#!/usr/bin/env bash - -# exit if any command fails -set -e - -if [ "$#" -ne 2 ]; then - echo "Usage: $0 DATASET_NAME MODEL_NAME" >&2 - exit 1 -fi -DATASET_NAME=$1 -MODEL_NAME=$2 - -declare -A model_name_to_fname -if [ ${DATASET_NAME} = "bair_action_free" ]; then - model_name_to_fname=( - [ours_deterministic]=${DATASET_NAME}_ours_deterministic_l1 - [ours_deterministic_l1]=${DATASET_NAME}_ours_deterministic_l1 - [ours_deterministic_l2]=${DATASET_NAME}_ours_deterministic_l2 - [ours_gan]=${DATASET_NAME}_ours_gan - [ours_savp]=${DATASET_NAME}_ours_savp - [ours_vae]=${DATASET_NAME}_ours_vae_l1 - [ours_vae_l1]=${DATASET_NAME}_ours_vae_l1 - [ours_vae_l2]=${DATASET_NAME}_ours_vae_l2 - [sv2p_time_invariant]=${DATASET_NAME}_sv2p_time_invariant - ) -elif [ ${DATASET_NAME} = "kth" ]; then - model_name_to_fname=( - [ours_deterministic]=${DATASET_NAME}_ours_deterministic_l1 - [ours_deterministic_l1]=${DATASET_NAME}_ours_deterministic_l1 - [ours_deterministic_l2]=${DATASET_NAME}_ours_deterministic_l2 - [ours_gan]=${DATASET_NAME}_ours_gan - [ours_savp]=${DATASET_NAME}_ours_savp - [ours_vae]=${DATASET_NAME}_ours_vae_l1 - [ours_vae_l1]=${DATASET_NAME}_ours_vae_l1 - [sv2p_time_invariant]=${DATASET_NAME}_sv2p_time_invariant - [sv2p_time_variant]=${DATASET_NAME}_sv2p_time_variant - ) -elif [ ${DATASET_NAME} = "bair" ]; then - model_name_to_fname=( - [ours_deterministic]=${DATASET_NAME}_ours_deterministic_l1 - [ours_deterministic_l1]=${DATASET_NAME}_ours_deterministic_l1 - [ours_deterministic_l2]=${DATASET_NAME}_ours_deterministic_l2 - [ours_gan]=${DATASET_NAME}_ours_gan - [ours_savp]=${DATASET_NAME}_ours_savp - [ours_vae]=${DATASET_NAME}_ours_vae_l1 - [ours_vae_l1]=${DATASET_NAME}_ours_vae_l1 - [ours_vae_l2]=${DATASET_NAME}_ours_vae_l2 - [sna_l1]=${DATASET_NAME}_sna_l1 - [sna_l2]=${DATASET_NAME}_sna_l2 - [sv2p_time_variant]=${DATASET_NAME}_sv2p_time_variant - ) -else - echo "Invalid dataset name: '${DATASET_NAME}' (choose from 'bair_action_free', 'kth', 'bair)" >&2 - exit 1 -fi - -if ! [[ ${model_name_to_fname[${MODEL_NAME}]} ]]; then - echo "Invalid model name '${MODEL_NAME}' when dataset name is '${DATASET_NAME}'. Valid mode names are:" >&2 - for model_name in "${!model_name_to_fname[@]}"; do - echo "'${model_name}'" >&2 - done - exit 1 -fi -TARGET_DIR=./pretrained_models/${DATASET_NAME}/${MODEL_NAME} -mkdir -p ${TARGET_DIR} -TAR_FNAME=${model_name_to_fname[${MODEL_NAME}]}.tar.gz -URL=http://rail.eecs.berkeley.edu/models/savp/pretrained_models/${TAR_FNAME} -echo "Downloading '${TAR_FNAME}'" -wget ${URL} -O ${TARGET_DIR}/${TAR_FNAME} -tar -xvf ${TARGET_DIR}/${TAR_FNAME} -C ${TARGET_DIR} -rm ${TARGET_DIR}/${TAR_FNAME} - -echo "Succesfully finished downloading pretrained model '${MODEL_NAME}' on dataset '${DATASET_NAME}' into directory ${TARGET_DIR}" diff --git a/video_prediction_tools/deprecated/scripts/combine_results.py b/video_prediction_tools/deprecated/scripts/combine_results.py deleted file mode 100644 index ce6b00f420876a334b0af358e915318d03fb097f..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/scripts/combine_results.py +++ /dev/null @@ -1,262 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import glob -import itertools -import os - -import cv2 -import numpy as np - -from video_prediction.utils import html -from video_prediction.utils.ffmpeg_gif import save_gif as ffmpeg_save_gif - - -def load_metrics(prefix_fname): - import csv - with open('%s.csv' % prefix_fname, newline='') as csvfile: - reader = csv.reader(csvfile, delimiter='\t', quotechar='|') - rows = list(reader) - # skip header (first row), indices (first column), and means (last column) - metrics = np.array(rows)[1:, 1:-1].astype(np.float32) - return metrics - - -def load_images(image_fnames): - images = [] - for image_fname in image_fnames: - image = cv2.imread(image_fname) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - images.append(image) - return images - - -def save_images(image_fnames, images): - head, tail = os.path.split(image_fnames[0]) - if head and not os.path.exists(head): - os.makedirs(head) - for image_fname, image in zip(image_fnames, images): - image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) - cv2.imwrite(image_fname, image) - - -def save_gif(gif_fname, images, fps=4): - import moviepy.editor as mpy - head, tail = os.path.split(gif_fname) - if head and not os.path.exists(head): - os.makedirs(head) - clip = mpy.ImageSequenceClip(list(images), fps=fps) - clip.write_gif(gif_fname) - - -def concat_images(all_images): - """ - all_images is a list of lists of images - """ - min_height, min_width = None, None - for all_image in all_images: - for image in all_image: - if min_height is None or min_width is None: - min_height, min_width = image.shape[:2] - else: - min_height = min(min_height, image.shape[0]) - min_width = min(min_width, image.shape[1]) - - def maybe_resize(image): - if image.shape[:2] != (min_height, min_width): - image = cv2.resize(image, (min_height, min_width)) - return image - - resized_all_images = [] - for all_image in all_images: - resized_all_image = [maybe_resize(image) for image in all_image] - resized_all_images.append(resized_all_image) - all_images = resized_all_images - all_images = [np.concatenate(all_image, axis=1) for all_image in zip(*all_images)] - return all_images - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("results_dir", type=str) - parser.add_argument("--method_dirs", type=str, nargs='+', help='directories in results_dir (all of them by default)') - parser.add_argument("--method_names", type=str, nargs='+', help='method names for the header') - parser.add_argument("--web_dir", type=str, help='default is results_dir/web') - parser.add_argument("--sort_by", type=str, nargs=2, help='task and metric name to sort by, e.g. prediction mse') - parser.add_argument("--no_ffmpeg", action='store_true') - parser.add_argument("--batch_size", type=int, default=1, help="number of samples in batch") - parser.add_argument("--num_samples", type=int, help="number of samples for the table of sequence (all of them by default)") - parser.add_argument("--show_se", action='store_true', help="show standard error in the table metrics") - parser.add_argument("--only_metrics", action='store_true') - args = parser.parse_args() - - if args.web_dir is None: - args.web_dir = os.path.join(args.results_dir, 'web') - webpage = html.HTML(args.web_dir, 'Experiment name = %s' % os.path.normpath(args.results_dir), reflesh=1) - webpage.add_header1(os.path.normpath(args.results_dir)) - - if args.method_dirs is None: - unsorted_method_dirs = os.listdir(args.results_dir) - # exclude web_dir and all directories that starts with web - if args.web_dir in unsorted_method_dirs: - unsorted_method_dirs.remove(args.web_dir) - unsorted_method_dirs = [method_dir for method_dir in unsorted_method_dirs if not os.path.basename(method_dir).startswith('web')] - # put ground_truth and repeat in the front (if any) - method_dirs = [] - for first_method_dir in ['ground_truth', 'repeat']: - if first_method_dir in unsorted_method_dirs: - unsorted_method_dirs.remove(first_method_dir) - method_dirs.append(first_method_dir) - method_dirs.extend(sorted(unsorted_method_dirs)) - else: - method_dirs = list(args.method_dirs) - if args.method_names is None: - method_names = list(method_dirs) - else: - method_names = list(args.method_names) - method_dirs = [os.path.join(args.results_dir, method_dir) for method_dir in method_dirs] - - if args.sort_by: - task_name, metric_name = args.sort_by - sort_criterion = [] - for method_id, (method_name, method_dir) in enumerate(zip(method_names, method_dirs)): - metric = load_metrics(os.path.join(method_dir, task_name, 'metrics', metric_name)) - sort_criterion.append(np.mean(metric)) - sort_criterion, method_ids, method_names, method_dirs = \ - zip(*sorted(zip(sort_criterion, range(len(method_names)), method_names, method_dirs))) - webpage.add_header3('sorted by %s, %s' % tuple(args.sort_by)) - else: - method_ids = range(len(method_names)) - - # infer task and metric names from first method - metric_fnames = sorted(glob.glob('%s/*/metrics/*.csv' % glob.escape(method_dirs[0]))) - task_names = [] - metric_names = [] - for metric_fname in metric_fnames: - head, tail = os.path.split(metric_fname) - task_name = head.split('/')[-2] - metric_name, _ = os.path.splitext(tail) - task_names.append(task_name) - metric_names.append(metric_name) - - # save metrics - webpage.add_table() - header_txts = [''] - header_colspans = [2] - for task_name in task_names: - if task_name != header_txts[-1]: - header_txts.append(task_name) - header_colspans.append(2 if args.show_se else 1) # mean and standard error for each task - else: - # group consecutive task names that are the same - header_colspans[-1] += 2 if args.show_se else 1 - webpage.add_row(header_txts, header_colspans) - subheader_txts = ['id', 'method'] - for task_name, metric_name in zip(task_names, metric_names): - subheader_txts.append('%s (mean)' % metric_name) - if args.show_se: - subheader_txts.append('%s (se)' % metric_name) - webpage.add_row(subheader_txts) - all_metric_means = [] - for method_id, method_name, method_dir in zip(method_ids, method_names, method_dirs): - metric_txts = [method_id, method_name] - metric_means = [] - for task_name, metric_name in zip(task_names, metric_names): - metric = load_metrics(os.path.join(method_dir, task_name, 'metrics', metric_name)) - metric_mean = np.mean(metric) - num_samples = len(metric) - metric_se = np.std(metric) / np.sqrt(num_samples) - metric_txts.append('%.4f' % metric_mean) - if args.show_se: - metric_txts.append('%.4f' % metric_se) - metric_means.append(metric_mean) - webpage.add_row(metric_txts) - all_metric_means.append(metric_means) - webpage.save() - - if args.only_metrics: - return - - # infer task names from first method - outputs_dirs = sorted(glob.glob('%s/*/outputs' % glob.escape(method_dirs[0]))) - task_names = [outputs_dir.split('/')[-2] for outputs_dir in outputs_dirs] - - # save image sequences - image_dir = os.path.join(args.web_dir, 'images') - webpage.add_table() - header_txts = [''] - subheader_txts = ['id'] - methods_subheader_txts = [''] - header_colspans = [1] - subheader_colspans = [1] - methods_subheader_colspans = [1] - num_samples = args.num_samples or num_samples - for sample_ind in range(num_samples): - if sample_ind % args.batch_size == 0: - print("saving samples from %d to %d" % (sample_ind, sample_ind + args.batch_size)) - ims = [None] - txts = [sample_ind] - links = [None] - colspans = [1] - for task_name in task_names: - # load input images from first method - input_fnames = sorted(glob.glob('%s/inputs/*_%05d_??.png' % - (glob.escape(os.path.join(method_dirs[0], task_name)), sample_ind))) - input_images = load_images(input_fnames) - # save input images as image sequence - input_fnames = [os.path.join(task_name, 'inputs', os.path.basename(input_fname)) for input_fname in input_fnames] - save_images([os.path.join(image_dir, input_fname) for input_fname in input_fnames], input_images) - # infer output names from first method - output_fnames = sorted(glob.glob('%s/outputs/*_%05d_??.png' % - (glob.escape(os.path.join(method_dirs[0], task_name)), sample_ind))) - output_names = sorted(set(os.path.splitext(os.path.basename(output_fname))[0][:-9] - for output_fname in output_fnames)) # remove _?????_??.png - # load output images - all_output_images = [] - for output_name in output_names: - for method_name, method_dir in zip(method_names, method_dirs): - output_fnames = sorted(glob.glob('%s/outputs/%s_%05d_??.png' % - (glob.escape(os.path.join(method_dir, task_name)), - output_name, sample_ind))) - output_images = load_images(output_fnames) - all_output_images.append(output_images) - # concatenate output images of all the methods - all_output_images = concat_images(all_output_images) - # save output images as image sequence or as gif clip - output_fname = os.path.join(task_name, 'outputs', '%s_%05d.gif' % ('_'.join(output_names), sample_ind)) - if args.no_ffmpeg: - save_gif(os.path.join(image_dir, output_fname), all_output_images, fps=4) - else: - ffmpeg_save_gif(os.path.join(image_dir, output_fname), all_output_images, fps=4) - - if sample_ind == 0: - header_txts.append(task_name) - subheader_txts.extend(['inputs', 'outputs']) - header_colspans.append(len(input_fnames) + len(method_ids) * len(output_names)) - subheader_colspans.extend([len(input_fnames), len(method_ids) * len(output_names)]) - method_id_strs = ['%02d' % method_id for method_id in method_ids] - methods_subheader_txts.extend([''] + list(itertools.chain(*[method_id_strs] * len(output_names)))) - methods_subheader_colspans.extend([len(input_fnames)] + [1] * (len(method_ids) * len(output_names))) - ims.extend(input_fnames + [output_fname]) - txts.extend([None] * (len(input_fnames) + 1)) - links.extend(input_fnames + [output_fname]) - colspans.extend([1] * len(input_fnames) + [len(method_ids) * len(output_names)]) - - if sample_ind == 0: - webpage.add_row(header_txts, header_colspans) - webpage.add_row(subheader_txts, subheader_colspans) - webpage.add_row(methods_subheader_txts, methods_subheader_colspans) - webpage.add_images(ims, txts, links, colspans, height=64, width=None) - if (sample_ind + 1) % args.batch_size == 0: - webpage.save() - webpage.save() - - -if __name__ == '__main__': - main() diff --git a/video_prediction_tools/deprecated/scripts/evaluate.py b/video_prediction_tools/deprecated/scripts/evaluate.py deleted file mode 100644 index 11792bb196bf77a2aef3188ae599f8074c6d8da3..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/scripts/evaluate.py +++ /dev/null @@ -1,322 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import re -import argparse -import csv -import errno -import json -import os -import random - -import numpy as np -import tensorflow as tf - -from video_prediction import datasets, models - - -def save_image_sequence(prefix_fname, images, time_start_ind=0): - import cv2 - head, tail = os.path.split(prefix_fname) - if head and not os.path.exists(head): - os.makedirs(head) - for t, image in enumerate(images): - image_fname = '%s_%02d.png' % (prefix_fname, time_start_ind + t) - image = (image * 255.0).astype(np.uint8) - if image.shape[-1] == 1: - image = np.tile(image, (1, 1, 3)) - else: - image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) - cv2.imwrite(image_fname, image) - - -def save_image_sequences(prefix_fname, images, sample_start_ind=0, time_start_ind=0): - head, tail = os.path.split(prefix_fname) - if head and not os.path.exists(head): - os.makedirs(head) - for i, images_ in enumerate(images): - images_fname = '%s_%05d' % (prefix_fname, sample_start_ind + i) - save_image_sequence(images_fname, images_, time_start_ind=time_start_ind) - - -def save_metrics(prefix_fname, metrics, sample_start_ind=0): - head, tail = os.path.split(prefix_fname) - if head and not os.path.exists(head): - os.makedirs(head) - assert metrics.ndim == 2 - file_mode = 'w' if sample_start_ind == 0 else 'a' - with open('%s.csv' % prefix_fname, file_mode, newline='') as csvfile: - writer = csv.writer(csvfile, delimiter='\t', quotechar='|', quoting=csv.QUOTE_MINIMAL) - if sample_start_ind == 0: - writer.writerow(map(str, ['sample_ind'] + list(range(metrics.shape[1])) + ['mean'])) - for i, metrics_row in enumerate(metrics): - writer.writerow(map(str, [sample_start_ind + i] + list(metrics_row) + [np.mean(metrics_row)])) - - -def load_metrics(prefix_fname): - with open('%s.csv' % prefix_fname, newline='') as csvfile: - reader = csv.reader(csvfile, delimiter='\t', quotechar='|') - rows = list(reader) - # skip header (first row), indices (first column), and means (last column) - metrics = np.array(rows)[1:, 1:-1].astype(np.float32) - return metrics - - -def merge_hparams(hparams0, hparams1): - hparams0 = hparams0 or [] - hparams1 = hparams1 or [] - if not isinstance(hparams0, (list, tuple)): - hparams0 = [hparams0] - if not isinstance(hparams1, (list, tuple)): - hparams1 = [hparams1] - hparams = list(hparams0) + list(hparams1) - # simplify into the content if possible - if len(hparams) == 1: - hparams, = hparams - return hparams - - -def save_prediction_eval_results(task_dir, results, model_hparams, sample_start_ind=0, only_metrics=False, subtasks=None): - sequence_length = model_hparams.sequence_length - context_frames = model_hparams.context_frames - future_length = sequence_length - context_frames - - context_images = results['images'][:, :context_frames] - - if 'eval_diversity' in results: - metric = results['eval_diversity'] - metric_name = 'diversity' - subtask_dir = task_dir + '_%s' % metric_name - save_metrics(os.path.join(subtask_dir, 'metrics', metric_name), - metric, sample_start_ind=sample_start_ind) - - subtasks = subtasks or ['max'] - for subtask in subtasks: - metric_names = [] - for k in results.keys(): - if re.match('eval_(\w+)/%s' % subtask, k) and not re.match('eval_gen_images_(\w+)/%s' % subtask, k): - m = re.match('eval_(\w+)/%s' % subtask, k) - metric_names.append(m.group(1)) - for metric_name in metric_names: - subtask_dir = task_dir + '_%s_%s' % (metric_name, subtask) - gen_images = results.get('eval_gen_images_%s/%s' % (metric_name, subtask), results.get('eval_gen_images')) - # only keep the future frames - gen_images = gen_images[:, -future_length:] - metric = results['eval_%s/%s' % (metric_name, subtask)] - save_metrics(os.path.join(subtask_dir, 'metrics', metric_name), - metric, sample_start_ind=sample_start_ind) - if only_metrics: - continue - - save_image_sequences(os.path.join(subtask_dir, 'inputs', 'context_image'), - context_images, sample_start_ind=sample_start_ind) - save_image_sequences(os.path.join(subtask_dir, 'outputs', 'gen_image'), - gen_images, sample_start_ind=sample_start_ind) - - -def main(): - """ - results_dir - ├── output_dir # condition / method - │ ├── prediction_eval_lpips_max # task: best sample in terms of LPIPS similarity - │ │ ├── inputs - │ │ │ ├── context_image_00000_00.png # indexed by sample index and time step - │ │ │ └── ... - │ │ ├── outputs - │ │ │ ├── gen_image_00000_00.png # predicted images (only the future ones) - │ │ │ └── ... - │ │ └── metrics - │ │ └── lpips.csv - │ ├── prediction_eval_ssim_max # task: best sample in terms of SSIM - │ │ ├── inputs - │ │ │ ├── context_image_00000_00.png # indexed by sample index and time step - │ │ │ └── ... - │ │ ├── outputs - │ │ │ ├── gen_image_00000_00.png # predicted images (only the future ones) - │ │ │ └── ... - │ │ └── metrics - │ │ └── ssim.csv - │ └── ... - └── ... - """ - parser = argparse.ArgumentParser() - parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories " - "train, val, test, etc, or a directory containing " - "the tfrecords") - parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_dir is specified") - parser.add_argument("--output_dir", help="output directory where results are saved. default is results_dir/model_fname, " - "where model_fname is the directory name of checkpoint") - parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)") - - parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.') - - parser.add_argument("--dataset", type=str, help="dataset class name") - parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters") - parser.add_argument("--model", type=str, help="model class name") - parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") - - parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch") - parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)") - parser.add_argument("--num_epochs", type=int, default=1) - - parser.add_argument("--eval_substasks", type=str, nargs='+', default=['max', 'avg', 'min'], help='subtasks to evaluate (e.g. max, avg, min)') - parser.add_argument("--only_metrics", action='store_true') - parser.add_argument("--num_stochastic_samples", type=int, default=100) - - parser.add_argument("--gt_inputs_dir", type=str, help="directory containing input ground truth images for ismple dataset") - parser.add_argument("--gt_outputs_dir", type=str, help="directory containing output ground truth images for ismple dataset") - - parser.add_argument("--eval_parallel_iterations", type=int, default=10) - parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use") - parser.add_argument("--seed", type=int, default=7) - - args = parser.parse_args() - - if args.seed is not None: - tf.set_random_seed(args.seed) - np.random.seed(args.seed) - random.seed(args.seed) - - dataset_hparams_dict = {} - model_hparams_dict = {} - if args.checkpoint: - checkpoint_dir = os.path.normpath(args.checkpoint) - if not os.path.isdir(args.checkpoint): - checkpoint_dir, _ = os.path.split(checkpoint_dir) - if not os.path.exists(checkpoint_dir): - raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) - with open(os.path.join(checkpoint_dir, "options.json")) as f: - print("loading options from checkpoint %s" % args.checkpoint) - options = json.loads(f.read()) - args.dataset = args.dataset or options['dataset'] - args.model = args.model or options['model'] - try: - with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f: - dataset_hparams_dict = json.loads(f.read()) - except FileNotFoundError: - print("dataset_hparams.json was not loaded because it does not exist") - try: - with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: - model_hparams_dict = json.loads(f.read()) - except FileNotFoundError: - print("model_hparams.json was not loaded because it does not exist") - args.output_dir = args.output_dir or os.path.join(args.results_dir, os.path.split(checkpoint_dir)[1]) - else: - if not args.dataset: - raise ValueError('dataset is required when checkpoint is not specified') - if not args.model: - raise ValueError('model is required when checkpoint is not specified') - args.output_dir = args.output_dir or os.path.join(args.results_dir, 'model.%s' % args.model) - - print('----------------------------------- Options ------------------------------------') - for k, v in args._get_kwargs(): - print(k, "=", v) - print('------------------------------------- End --------------------------------------') - - VideoDataset = datasets.get_dataset_class(args.dataset) - dataset = VideoDataset( - args.input_dir, - mode=args.mode, - num_epochs=args.num_epochs, - seed=args.seed, - hparams_dict=dataset_hparams_dict, - hparams=args.dataset_hparams) - - VideoPredictionModel = models.get_model_class(args.model) - hparams_dict = dict(model_hparams_dict) - hparams_dict.update({ - 'context_frames': dataset.hparams.context_frames, - 'sequence_length': dataset.hparams.sequence_length, - 'repeat': dataset.hparams.time_shift, - }) - model = VideoPredictionModel( - mode=args.mode, - hparams_dict=hparams_dict, - hparams=args.model_hparams, - eval_num_samples=args.num_stochastic_samples, - eval_parallel_iterations=args.eval_parallel_iterations) - - if args.num_samples: - if args.num_samples > dataset.num_examples_per_epoch(): - raise ValueError('num_samples cannot be larger than the dataset') - num_examples_per_epoch = args.num_samples - else: - num_examples_per_epoch = dataset.num_examples_per_epoch() - if num_examples_per_epoch % args.batch_size != 0: - #bing0 - #raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch) - pass - #Bing if it is era 5 data we used dataset.make_batch_v2 - #inputs = dataset.make_batch(args.batch_size) - inputs = dataset.make_batch_v2(args.batch_size) - input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()} - with tf.variable_scope(''): - model.build_graph(input_phs) - - output_dir = args.output_dir - if not os.path.exists(output_dir): - os.makedirs(output_dir) - with open(os.path.join(output_dir, "options.json"), "w") as f: - f.write(json.dumps(vars(args), sort_keys=True, indent=4)) - with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f: - f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4)) - with open(os.path.join(output_dir, "model_hparams.json"), "w") as f: - f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4)) - - gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac) - config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) - sess = tf.Session(config=config) - sess.graph.as_default() - - model.restore(sess, args.checkpoint) - - sample_ind = 0 - while True: - if args.num_samples and sample_ind >= args.num_samples: - break - try: - input_results = sess.run(inputs) - except tf.errors.OutOfRangeError: - break - print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size)) - - feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} - # compute "best" metrics using the computation graph - fetches = {'images': model.inputs['images']} - fetches.update(model.eval_outputs.items()) - fetches.update(model.eval_metrics.items()) - results = sess.run(fetches, feed_dict=feed_dict) - save_prediction_eval_results(os.path.join(output_dir, 'prediction_eval'), - results, model.hparams, sample_ind, args.only_metrics, args.eval_substasks) - sample_ind += args.batch_size - - metric_fnames = [] - metric_names = ['psnr', 'ssim', 'lpips'] - subtasks = ['max'] - for metric_name in metric_names: - for subtask in subtasks: - metric_fnames.append( - os.path.join(output_dir, 'prediction_eval_%s_%s' % (metric_name, subtask), 'metrics', metric_name)) - - for metric_fname in metric_fnames: - task_name, _, metric_name = metric_fname.split('/')[-3:] - metric = load_metrics(metric_fname) - print('=' * 31) - print(task_name, metric_name) - print('-' * 31) - metric_header_format = '{:>10} {:>20}' - metric_row_format = '{:>10} {:>10.4f} ({:>7.4f})' - print(metric_header_format.format('time step', os.path.split(metric_fname)[1])) - for t, (metric_mean, metric_std) in enumerate(zip(metric.mean(axis=0), metric.std(axis=0))): - print(metric_row_format.format(t, metric_mean, metric_std)) - print(metric_row_format.format('mean (std)', metric.mean(), metric.std())) - print('=' * 31) - - -if __name__ == '__main__': - main() diff --git a/video_prediction_tools/deprecated/scripts/evaluate_all.sh b/video_prediction_tools/deprecated/scripts/evaluate_all.sh deleted file mode 100644 index c57f5c895da22f167a0a1bb4204a225965c8a48c..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/scripts/evaluate_all.sh +++ /dev/null @@ -1,44 +0,0 @@ -# BAIR action-free robot pushing dataset -dataset=bair_action_free -for method_dir in \ - ours_vae_gan \ - ours_gan \ - ours_vae_l1 \ - ours_vae_l2 \ - ours_deterministic_l1 \ - ours_deterministic_l2 \ - sv2p_time_invariant \ -; do - CUDA_VISIBLE_DEVICES=0 python scripts/evaluate.py --input_dir data/bair --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test --results_dir results_test/${dataset} --batch_size 8 -done - -# KTH human actions dataset -# use batch_size=1 to ensure reproducibility when sampling subclips within a sequence -dataset=kth -for method_dir in \ - ours_vae_gan \ - ours_gan \ - ours_vae_l1 \ - ours_deterministic_l1 \ - ours_deterministic_l2 \ - sv2p_time_variant \ - sv2p_time_invariant \ -; do - CUDA_VISIBLE_DEVICES=0 python scripts/evaluate.py --input_dir data/kth --dataset_hparams sequence_length=40 --checkpoint models/${dataset}/${method_dir} --mode test --results_dir results_test/${dataset} --batch_size 1 -done - -# BAIR action-conditioned robot pushing dataset -dataset=bair -for method_dir in \ - ours_vae_gan \ - ours_gan \ - ours_vae_l1 \ - ours_vae_l2 \ - ours_deterministic_l1 \ - ours_deterministic_l2 \ - sna_l1 \ - sna_l2 \ - sv2p_time_variant \ -; do - CUDA_VISIBLE_DEVICES=1 python scripts/evaluate.py --input_dir data/bair --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test --results_dir results_test/${dataset} --batch_size 8 -done diff --git a/video_prediction_tools/deprecated/scripts/generate_all.sh b/video_prediction_tools/deprecated/scripts/generate_all.sh deleted file mode 100644 index c3736b36df1840cbc46b89526cef0c908b500760..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/scripts/generate_all.sh +++ /dev/null @@ -1,55 +0,0 @@ -# BAIR action-free robot pushing dataset -dataset=bair_action_free -CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair --dataset bair \ - --dataset_hparams sequence_length=30 --model ground_truth --mode test \ - --output_gif_dir results_test_2afc/${dataset}/ground_truth \ - --output_png_dir results_test_samples/${dataset}/ground_truth --gif_length 10 -for method_dir in \ - ours_vae_gan \ - ours_gan \ - ours_vae_l1 \ - sv2p_time_invariant \ -; do - CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair \ - --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test \ - --results_gif_dir results_test_2afc/${dataset} \ - --results_png_dir results_test_samples/${dataset} --gif_length 10 -done - -# KTH human actions dataset -# use batch_size=1 to ensure reproducibility when sampling subclips within a sequence -dataset=kth -CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/kth --dataset kth \ - --dataset_hparams sequence_length=40 --model ground_truth --mode test \ - --output_gif_dir results_test_2afc/${dataset}/ground_truth \ - --output_png_dir results_test_samples/${dataset}/ground_truth --gif_length 10 --batch_size 1 -for method_dir in \ - ours_vae_gan \ - ours_gan \ - ours_vae_l1 \ - sv2p_time_invariant \ - sv2p_time_variant \ -; do - CUDA_VISIBLE_DEVICES=1 python scripts/generate.py --input_dir data/kth \ - --dataset_hparams sequence_length=40 --checkpoint models/${dataset}/${method_dir} --mode test \ - --results_gif_dir results_test_2afc/${dataset} \ - --results_png_dir results_test_samples/${dataset} --gif_length 10 --batch_size 1 -done - -# BAIR action-conditioned robot pushing dataset -dataset=bair -CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair --dataset bair \ - --dataset_hparams sequence_length=30 --model ground_truth --mode test \ - --output_gif_dir results_test_2afc/${dataset}/ground_truth \ - --output_png_dir results_test_samples/${dataset}/ground_truth --gif_length 10 -for method_dir in \ - ours_vae_gan \ - ours_gan \ - ours_vae_l1 \ - sv2p_time_variant \ -; do - CUDA_VISIBLE_DEVICES=0 python scripts/generate.py --input_dir data/bair \ - --dataset_hparams sequence_length=30 --checkpoint models/${dataset}/${method_dir} --mode test \ - --results_gif_dir results_test_2afc/${dataset} \ - --results_png_dir results_test_samples/${dataset} --gif_length 10 -done diff --git a/video_prediction_tools/deprecated/scripts/generate_orig.py b/video_prediction_tools/deprecated/scripts/generate_orig.py deleted file mode 100644 index 52f2a1ea9a4c7119cba5b6a063fa911392800fe1..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/scripts/generate_orig.py +++ /dev/null @@ -1,197 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import errno -import json -import os -import random - -import cv2 -import numpy as np -import tensorflow as tf - -from video_prediction import datasets, models -from video_prediction.utils.ffmpeg_gif import save_gif - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories " - "train, val, test, etc, or a directory containing " - "the tfrecords") - parser.add_argument("--results_dir", type=str, default='results', help="ignored if output_gif_dir is specified") - parser.add_argument("--results_gif_dir", type=str, help="default is results_dir. ignored if output_gif_dir is specified") - parser.add_argument("--results_png_dir", type=str, help="default is results_dir. ignored if output_png_dir is specified") - parser.add_argument("--output_gif_dir", help="output directory where samples are saved as gifs. default is " - "results_gif_dir/model_fname") - parser.add_argument("--output_png_dir", help="output directory where samples are saved as pngs. default is " - "results_png_dir/model_fname") - parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)") - - parser.add_argument("--mode", type=str, choices=['val', 'test'], default='val', help='mode for dataset, val or test.') - - parser.add_argument("--dataset", type=str, help="dataset class name") - parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters") - parser.add_argument("--model", type=str, help="model class name") - parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") - - parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch") - parser.add_argument("--num_samples", type=int, help="number of samples in total (all of them by default)") - parser.add_argument("--num_epochs", type=int, default=1) - - parser.add_argument("--num_stochastic_samples", type=int, default=5) - parser.add_argument("--gif_length", type=int, help="default is sequence_length") - parser.add_argument("--fps", type=int, default=4) - - parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use") - parser.add_argument("--seed", type=int, default=7) - - args = parser.parse_args() - - if args.seed is not None: - tf.set_random_seed(args.seed) - np.random.seed(args.seed) - random.seed(args.seed) - - args.results_gif_dir = args.results_gif_dir or args.results_dir - args.results_png_dir = args.results_png_dir or args.results_dir - dataset_hparams_dict = {} - model_hparams_dict = {} - if args.checkpoint: - checkpoint_dir = os.path.normpath(args.checkpoint) - if not os.path.isdir(args.checkpoint): - checkpoint_dir, _ = os.path.split(checkpoint_dir) - if not os.path.exists(checkpoint_dir): - raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) - with open(os.path.join(checkpoint_dir, "options.json")) as f: - print("loading options from checkpoint %s" % args.checkpoint) - options = json.loads(f.read()) - args.dataset = args.dataset or options['dataset'] - args.model = args.model or options['model'] - try: - with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f: - dataset_hparams_dict = json.loads(f.read()) - except FileNotFoundError: - print("dataset_hparams.json was not loaded because it does not exist") - try: - with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: - model_hparams_dict = json.loads(f.read()) - except FileNotFoundError: - print("model_hparams.json was not loaded because it does not exist") - args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, os.path.split(checkpoint_dir)[1]) - args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, os.path.split(checkpoint_dir)[1]) - else: - if not args.dataset: - raise ValueError('dataset is required when checkpoint is not specified') - if not args.model: - raise ValueError('model is required when checkpoint is not specified') - args.output_gif_dir = args.output_gif_dir or os.path.join(args.results_gif_dir, 'model.%s' % args.model) - args.output_png_dir = args.output_png_dir or os.path.join(args.results_png_dir, 'model.%s' % args.model) - - print('----------------------------------- Options ------------------------------------') - for k, v in args._get_kwargs(): - print(k, "=", v) - print('------------------------------------- End --------------------------------------') - - VideoDataset = datasets.get_dataset_class(args.dataset) - dataset = VideoDataset( - args.input_dir, - mode=args.mode, - num_epochs=args.num_epochs, - seed=args.seed, - hparams_dict=dataset_hparams_dict, - hparams=args.dataset_hparams) - - VideoPredictionModel = models.get_model_class(args.model) - hparams_dict = dict(model_hparams_dict) - hparams_dict.update({ - 'context_frames': dataset.hparams.context_frames, - 'sequence_length': dataset.hparams.sequence_length, - 'repeat': dataset.hparams.time_shift, - }) - model = VideoPredictionModel( - mode=args.mode, - hparams_dict=hparams_dict, - hparams=args.model_hparams) - - sequence_length = model.hparams.sequence_length - context_frames = model.hparams.context_frames - future_length = sequence_length - context_frames - - if args.num_samples: - if args.num_samples > dataset.num_examples_per_epoch(): - raise ValueError('num_samples cannot be larger than the dataset') - num_examples_per_epoch = args.num_samples - else: - num_examples_per_epoch = dataset.num_examples_per_epoch() - if num_examples_per_epoch % args.batch_size != 0: - raise ValueError('batch_size should evenly divide the dataset size %d' % num_examples_per_epoch) - - inputs = dataset.make_batch(args.batch_size) - input_phs = {k: tf.placeholder(v.dtype, v.shape, '%s_ph' % k) for k, v in inputs.items()} - with tf.variable_scope(''): - model.build_graph(input_phs) - - for output_dir in (args.output_gif_dir, args.output_png_dir): - if not os.path.exists(output_dir): - os.makedirs(output_dir) - with open(os.path.join(output_dir, "options.json"), "w") as f: - f.write(json.dumps(vars(args), sort_keys=True, indent=4)) - with open(os.path.join(output_dir, "dataset_hparams.json"), "w") as f: - f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4)) - with open(os.path.join(output_dir, "model_hparams.json"), "w") as f: - f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4)) - - gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac) - config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) - sess = tf.Session(config=config) - sess.graph.as_default() - model.restore(sess, args.checkpoint) - - sample_ind = 0 - while True: - if args.num_samples and sample_ind >= args.num_samples: - break - try: - input_results = sess.run(inputs) - except tf.errors.OutOfRangeError: - break - print("evaluation samples from %d to %d" % (sample_ind, sample_ind + args.batch_size)) - - feed_dict = {input_ph: input_results[name] for name, input_ph in input_phs.items()} - for stochastic_sample_ind in range(args.num_stochastic_samples): - gen_images = sess.run(model.outputs['gen_images'], feed_dict=feed_dict) - # only keep the future frames - gen_images = gen_images[:, -future_length:] - for i, gen_images_ in enumerate(gen_images): - #context_images_ = (input_results['images'][i] * 255.0).astype(np.uint8) - #gen_images_ = (gen_images_ * 255.0).astype(np.uint8) - context_images_ = (input_results['images'][i]) - gen_images_ = (gen_images_) - - gen_images_fname = 'gen_image_%05d_%02d.gif' % (sample_ind + i, stochastic_sample_ind) - context_and_gen_images = list(context_images_[:context_frames]) + list(gen_images_) - if args.gif_length: - context_and_gen_images = context_and_gen_images[:args.gif_length] - save_gif(os.path.join(args.output_gif_dir, gen_images_fname), - context_and_gen_images, fps=args.fps) - gen_image_fname_pattern = 'gen_image_%%05d_%%02d_%%0%dd.png' % max(2, len(str(len(gen_images_) - 1))) - for t, gen_image in enumerate(gen_images_): - gen_image_fname = gen_image_fname_pattern % (sample_ind + i, stochastic_sample_ind, t) - if gen_image.shape[-1] == 1: - gen_image = np.tile(gen_image, (1, 1, 3)) - else: - gen_image = cv2.cvtColor(gen_image, cv2.COLOR_RGB2BGR) - cv2.imwrite(os.path.join(args.output_png_dir, gen_image_fname), gen_image) - - sample_ind += args.batch_size - - -if __name__ == '__main__': - main() \ No newline at end of file diff --git a/video_prediction_tools/deprecated/scripts/plot_results.py b/video_prediction_tools/deprecated/scripts/plot_results.py deleted file mode 100644 index 92019754ca03df9295a325fda4b511d137afd1d6..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/scripts/plot_results.py +++ /dev/null @@ -1,258 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import glob -import os - -import numpy as np - - -def load_metrics(prefix_fname): - import csv - with open('%s.csv' % prefix_fname, newline='') as csvfile: - reader = csv.reader(csvfile, delimiter='\t', quotechar='|') - rows = list(reader) - # skip header (first row), indices (first column), and means (last column) - metrics = np.array(rows)[1:, 1:-1].astype(np.float32) - return metrics - - -def plot_metric(metric, start_x=0, color=None, label=None, zorder=None): - import matplotlib.pyplot as plt - metric_mean = np.mean(metric, axis=0) - metric_se = np.std(metric, axis=0) / np.sqrt(len(metric)) - kwargs = {} - if color: - kwargs['color'] = color - if zorder: - kwargs['zorder'] = zorder - plt.errorbar(np.arange(len(metric_mean)) + start_x, - metric_mean, yerr=metric_se, linewidth=2, - label=label, **kwargs) - # metric_std = np.std(metric, axis=0) - # plt.plot(np.arange(len(metric_mean)) + start_x, metric_mean, - # linewidth=2, color=color, label=label) - # plt.fill_between(np.arange(len(metric_mean)) + start_x, - # metric_mean - metric_std, metric_mean + metric_std, - # color=color, alpha=0.5) - - -def get_color(method_name): - import matplotlib.pyplot as plt - color_mapping = { - 'ours_vae_gan': plt.cm.Vega20(0), - 'ours_gan': plt.cm.Vega20(2), - 'ours_vae': plt.cm.Vega20(4), - 'ours_vae_l1': plt.cm.Vega20(4), - 'ours_vae_l2': plt.cm.Vega20(14), - 'ours_deterministic': plt.cm.Vega20(6), - 'ours_deterministic_l1': plt.cm.Vega20(6), - 'ours_deterministic_l2': plt.cm.Vega20(10), - 'sna_l1': plt.cm.Vega20(8), - 'sna_l2': plt.cm.Vega20(9), - 'sv2p_time_variant': plt.cm.Vega20(16), - 'sv2p_time_invariant': plt.cm.Vega20(16), - 'svg_lp': plt.cm.Vega20(18), - 'svg_fp': plt.cm.Vega20(18), - 'svg_fp_resized_data_loader': plt.cm.Vega20(18), - 'mathieu': plt.cm.Vega20(8), - 'mcnet': plt.cm.Vega20(8), - 'repeat': 'k', - } - if method_name in color_mapping: - color = color_mapping[method_name] - else: - color = None - for k, v in color_mapping.items(): - if method_name.startswith(k): - color = v - break - return color - - -def get_method_name(method_name): - method_name_mapping = { - 'ours_vae_gan': 'Ours, SAVP', - 'ours_gan': 'Ours, GAN-only', - 'ours_vae': 'Ours, VAE-only', - 'ours_vae_l1': 'Ours, VAE-only, $\mathcal{L}_1$', - 'ours_vae_l2': 'Ours, VAE-only, $\mathcal{L}_2$', - 'ours_deterministic': 'Ours, deterministic', - 'ours_deterministic_l1': 'Ours, deterministic, $\mathcal{L}_1$', - 'ours_deterministic_l2': 'Ours, deterministic, $\mathcal{L}_2$', - 'sna_l1': 'SNA, $\mathcal{L}_1$ (Ebert et al.)', - 'sna_l2': 'SNA, $\mathcal{L}_2$ (Ebert et al.)', - 'sv2p_time_variant': 'SV2P time-variant (Babaeizadeh et al.)', - 'sv2p_time_invariant': 'SV2P time-invariant (Babaeizadeh et al.)', - 'mathieu': 'Mathieu et al.', - 'mcnet': 'MCnet (Villegas et al.)', - 'repeat': 'Copy last frame', - } - return method_name_mapping.get(method_name, method_name) - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("results_dir", type=str) - parser.add_argument("--dataset_name", type=str) - parser.add_argument("--method_dirs", type=str, nargs='+', help='directories in results_dir (all of them by default)') - parser.add_argument("--method_names", type=str, nargs='+', help='method names for the header') - parser.add_argument("--web_dir", type=str, help='default is results_dir/web') - parser.add_argument("--plot_fname", type=str, default='metrics.pdf') - parser.add_argument('--usetex', '--use_tex', action='store_true') - parser.add_argument('--save', action='store_true') - parser.add_argument('--mode', choices=['paper', 'rebuttal'], default='paper') - parser.add_argument("--plot_metric_names", type=str, nargs='+') - args = parser.parse_args() - - if args.save: - import matplotlib - matplotlib.use('Agg') # Must be before importing matplotlib.pyplot or pylab! - import matplotlib.pyplot as plt - - if args.usetex: - plt.rc('text', usetex=True) - plt.rc('text.latex', preview=True) - plt.rc('font', family='serif') - - if args.web_dir is None: - args.web_dir = os.path.join(args.results_dir, 'web') - - if args.method_dirs is None: - unsorted_method_dirs = os.listdir(args.results_dir) - # exclude web_dir and all directories that starts with web - if args.web_dir in unsorted_method_dirs: - unsorted_method_dirs.remove(args.web_dir) - unsorted_method_dirs = [method_dir for method_dir in unsorted_method_dirs if not os.path.basename(method_dir).startswith('web')] - # put ground_truth and repeat in the front (if any) - method_dirs = [] - for first_method_dir in ['ground_truth', 'repeat']: - if first_method_dir in unsorted_method_dirs: - unsorted_method_dirs.remove(first_method_dir) - method_dirs.append(first_method_dir) - method_dirs.extend(sorted(unsorted_method_dirs)) - else: - method_dirs = list(args.method_dirs) - if args.method_names is None: - method_names = [get_method_name(method_dir) for method_dir in method_dirs] - else: - method_names = list(args.method_names) - if args.usetex: - method_names = [method_name.replace('kl_weight', r'$\lambda_{\textsc{kl}}$') for method_name in method_names] - method_dirs = [os.path.join(args.results_dir, method_dir) for method_dir in method_dirs] - - # infer task and metric names from first method - metric_fnames = sorted(glob.glob('%s/*_max/metrics/*.csv' % glob.escape(method_dirs[0]))) - task_names = [] - metric_names = [] # all the metric names inferred from file names - for metric_fname in metric_fnames: - head, tail = os.path.split(metric_fname) - task_name = head.split('/')[-2] - metric_name, _ = os.path.splitext(tail) - task_names.append(task_name) - metric_names.append(metric_name) - - # save plots - dataset_name = args.dataset_name or os.path.split(os.path.normpath(args.results_dir))[1] - plots_dir = os.path.join(args.web_dir, 'plots') - if not os.path.exists(plots_dir): - os.makedirs(plots_dir) - - if dataset_name in ('bair', 'bair_action_free'): - context_frames = 2 - training_sequence_length = 12 - plot_metric_names = ('psnr', 'ssim_finn', 'vgg_csim') - elif dataset_name == 'kth': - context_frames = 10 - training_sequence_length = 20 - plot_metric_names = ('psnr', 'ssim_scikit', 'vgg_csim') - elif dataset_name == 'ucf101': - context_frames = 4 - training_sequence_length = 8 - plot_metric_names = ('psnr', 'ssim_mcnet', 'vgg_csim') - else: - raise NotImplementedError - plot_metric_names = args.plot_metric_names or plot_metric_names # metric names to plot - - if args.mode == 'paper': - fig = plt.figure(figsize=(4 * len(plot_metric_names), 5)) - elif args.mode == 'rebuttal': - fig = plt.figure(figsize=(4, 3 * len(plot_metric_names))) - else: - raise ValueError - i_task = 0 - for task_name, metric_name in zip(task_names, metric_names): - if not task_name.endswith('max'): - continue - if metric_name not in plot_metric_names: - continue - - if args.mode == 'paper': - plt.subplot(1, len(plot_metric_names), i_task + 1) - elif args.mode == 'rebuttal': - plt.subplot(len(plot_metric_names), 1, i_task + 1) - - for method_name, method_dir in zip(method_names, method_dirs): - metric_fname = os.path.join(method_dir, task_name, 'metrics', metric_name) - if not os.path.isfile('%s.csv' % metric_fname): - print('Skipping', metric_fname) - continue - metric = load_metrics(metric_fname) - plot_metric(metric, context_frames + 1, color=get_color(os.path.basename(method_dir)), label=method_name) - - plt.grid(axis='y') - plt.axvline(x=training_sequence_length, linewidth=1, color='k') - fontsize = 12 if args.mode == 'rebuttal' else 15 - legend_fontsize = 10 if args.mode == 'rebuttal' else 15 - labelsize = 10 - if args.mode == 'paper': - plt.xlabel('Time Step', fontsize=fontsize) - plt.ylabel({ - 'psnr': 'Average PSNR', - 'ssim': 'Average SSIM', - 'ssim_scikit': 'Average SSIM', - 'ssim_finn': 'Average SSIM', - 'ssim_mcnet': 'Average SSIM', - 'vgg_csim': 'Average VGG cosine similarity', - }[metric_name], fontsize=fontsize) - plt.xlim((context_frames + 1, metric.shape[1] + context_frames)) - plt.tick_params(labelsize=labelsize) - - if args.mode == 'paper': - if i_task == 1: - # plt.title({ - # 'bair': 'Action-conditioned BAIR Dataset', - # 'bair_action_free': 'Action-free BAIR Dataset', - # 'kth': 'KTH Dataset', - # }[dataset_name], fontsize=16) - if len(method_names) <= 4 and sum([len(method_name) for method_name in method_names]) < 90: - ncol = len(method_names) - else: - ncol = (len(method_names) + 1) // 2 - # ncol = 2 - plt.legend(bbox_to_anchor=(0.5, -0.12), loc='upper center', ncol=ncol, fontsize=legend_fontsize) - elif args.mode == 'rebuttal': - if i_task == 0: - # plt.legend(fontsize=legend_fontsize) - plt.legend(bbox_to_anchor=(0.4, -0.12), loc='upper center', fontsize=legend_fontsize) - plt.ylim(ymin=0.8) - plt.xlim((context_frames + 1, metric.shape[1] + context_frames)) - i_task += 1 - fig.tight_layout(rect=(0, 0.1, 1, 1)) - - if args.save: - plt.show(block=False) - print("Saving to", os.path.join(plots_dir, args.plot_fname)) - plt.savefig(os.path.join(plots_dir, args.plot_fname), bbox_inches='tight') - else: - plt.show() - - -if __name__ == '__main__': - main() diff --git a/video_prediction_tools/deprecated/scripts/plot_results_all.sh b/video_prediction_tools/deprecated/scripts/plot_results_all.sh deleted file mode 100644 index 11045ca4831cea4c01ef355dd2c333418d83daeb..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/scripts/plot_results_all.sh +++ /dev/null @@ -1,80 +0,0 @@ -python scripts/plot_results.py results_test/bair_action_free --method_dirs \ - ours_vae_gan \ - ours_gan \ - ours_vae_l1 \ - ours_vae_l2 \ - ours_deterministic_l1 \ - ours_deterministic_l2 \ - sv2p_time_invariant \ - svg_lp \ - --save --use_tex --plot_fname metrics_all.pdf - -python scripts/plot_results.py results_test/bair --method_dirs \ - ours_vae_gan \ - ours_gan \ - ours_vae_l1 \ - ours_vae_l2 \ - ours_deterministic_l1 \ - ours_deterministic_l2 \ - sna_l1 \ - sna_l2 \ - sv2p_time_variant \ - --save --use_tex --plot_fname metrics_all.pdf - -python scripts/plot_results.py results_test/kth --method_dirs \ - ours_vae_gan \ - ours_gan \ - ours_vae_l1 \ - ours_deterministic_l1 \ - ours_deterministic_l2 \ - sv2p_time_variant \ - sv2p_time_invariant \ - svg_fp_resized_data_loader \ - --save --use_tex --plot_fname metrics_all.pdf - - -python scripts/plot_results.py results_test/bair_action_free --method_dirs \ - sv2p_time_invariant \ - svg_lp \ - ours_vae_gan \ - --save --use_tex --plot_fname metrics.pdf; \ -python scripts/plot_results.py results_test/bair_action_free --method_dirs \ - ours_deterministic \ - ours_vae \ - ours_gan \ - ours_vae_gan \ - --save --use_tex --plot_fname metrics_ablation.pdf; \ -python scripts/plot_results.py results_test/bair_action_free --method_dirs \ - ours_deterministic_l1 \ - ours_deterministic_l2 \ - ours_vae_l1 \ - ours_vae_l2 \ - --save --use_tex --plot_fname metrics_ablation_l1_l2.pdf; \ -python scripts/plot_results.py results_test/kth --method_dirs \ - sv2p_time_variant \ - svg_fp_resized_data_loader \ - ours_vae_gan \ - --save --use_tex --plot_fname metrics.pdf; \ -python scripts/plot_results.py results_test/kth --method_dirs \ - ours_deterministic \ - ours_vae \ - ours_gan \ - ours_vae_gan \ - --save --use_tex --plot_fname metrics_ablation.pdf; \ -python scripts/plot_results.py results_test/bair --method_dirs \ - sv2p_time_variant \ - ours_deterministic \ - ours_vae \ - ours_gan \ - ours_vae_gan \ - --save --use_tex --plot_fname metrics.pdf; \ -python scripts/plot_results.py results_test/bair -- - -method_dirs \ - sna_l1 \ - sna_l2 \ - ours_deterministic_l1 \ - ours_deterministic_l2 \ - ours_vae_l1 \ - ours_vae_l2 \ - --save --use_tex --plot_fname metrics_ablation_l1_l2.pdf diff --git a/video_prediction_tools/deprecated/scripts/train.py b/video_prediction_tools/deprecated/scripts/train.py deleted file mode 100644 index ee76b2066f20e575727a28f05cac7a9d8b2dffb6..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/scripts/train.py +++ /dev/null @@ -1,371 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function - -import argparse -import errno -import json -import os -import random -import time - -import numpy as np -import tensorflow as tf - -from video_prediction import datasets, models - - -def add_tag_suffix(summary, tag_suffix): - summary_proto = tf.Summary() - summary_proto.ParseFromString(summary) - summary = summary_proto - - for value in summary.value: - tag_split = value.tag.split('/') - value.tag = '/'.join([tag_split[0] + tag_suffix] + tag_split[1:]) - return summary.SerializeToString() - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("--input_dir", type=str, required=True, help="either a directory containing subdirectories " - "train, val, test, etc, or a directory containing " - "the tfrecords") - parser.add_argument("--val_input_dir", type=str, help="directories containing the tfrecords. default: input_dir") - parser.add_argument("--logs_dir", default='logs', help="ignored if output_dir is specified") - parser.add_argument("--output_dir", help="output directory where json files, summary, model, gifs, etc are saved. " - "default is logs_dir/model_fname, where model_fname consists of " - "information from model and model_hparams") - parser.add_argument("--output_dir_postfix", default="") - parser.add_argument("--checkpoint", help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)") - parser.add_argument("--resume", action='store_true', help='resume from lastest checkpoint in output_dir.') - - parser.add_argument("--dataset", type=str, help="dataset class name") - parser.add_argument("--dataset_hparams", type=str, help="a string of comma separated list of dataset hyperparameters") - parser.add_argument("--dataset_hparams_dict", type=str, help="a json file of dataset hyperparameters") - parser.add_argument("--model", type=str, help="model class name") - parser.add_argument("--model_hparams", type=str, help="a string of comma separated list of model hyperparameters") - parser.add_argument("--model_hparams_dict", type=str, help="a json file of model hyperparameters") - - parser.add_argument("--summary_freq", type=int, default=1000, help="save frequency of summaries (except for image and eval summaries) for train/validation set") - parser.add_argument("--image_summary_freq", type=int, default=5000, help="save frequency of image summaries for train/validation set") - parser.add_argument("--eval_summary_freq", type=int, default=25000, help="save frequency of eval summaries for train/validation set") - parser.add_argument("--accum_eval_summary_freq", type=int, default=100000, help="save frequency of accumulated eval summaries for validation set only") - parser.add_argument("--progress_freq", type=int, default=100, help="display progress every progress_freq steps") - parser.add_argument("--save_freq", type=int, default=5000, help="save frequence of model, 0 to disable") - - parser.add_argument("--aggregate_nccl", type=int, default=0, help="whether to use nccl or cpu for gradient aggregation in multi-gpu training") - parser.add_argument("--gpu_mem_frac", type=float, default=0, help="fraction of gpu memory to use") - parser.add_argument("--seed", type=int) - - args = parser.parse_args() - - if args.seed is not None: - tf.set_random_seed(args.seed) - np.random.seed(args.seed) - random.seed(args.seed) - - if args.output_dir is None: - list_depth = 0 - model_fname = '' - for t in ('model=%s,%s' % (args.model, args.model_hparams)): - if t == '[': - list_depth += 1 - if t == ']': - list_depth -= 1 - if list_depth and t == ',': - t = '..' - if t in '=,': - t = '.' - if t in '[]': - t = '' - model_fname += t - args.output_dir = os.path.join(args.logs_dir, model_fname) + args.output_dir_postfix - - if args.resume: - if args.checkpoint: - raise ValueError('resume and checkpoint cannot both be specified') - args.checkpoint = args.output_dir - - dataset_hparams_dict = {} - model_hparams_dict = {} - if args.dataset_hparams_dict: - with open(args.dataset_hparams_dict) as f: - dataset_hparams_dict.update(json.loads(f.read())) - if args.model_hparams_dict: - with open(args.model_hparams_dict) as f: - model_hparams_dict.update(json.loads(f.read())) - if args.checkpoint: - checkpoint_dir = os.path.normpath(args.checkpoint) - if not os.path.isdir(args.checkpoint): - checkpoint_dir, _ = os.path.split(checkpoint_dir) - if not os.path.exists(checkpoint_dir): - raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), checkpoint_dir) - with open(os.path.join(checkpoint_dir, "options.json")) as f: - print("loading options from checkpoint %s" % args.checkpoint) - options = json.loads(f.read()) - args.dataset = args.dataset or options['dataset'] - args.model = args.model or options['model'] - try: - with open(os.path.join(checkpoint_dir, "dataset_hparams.json")) as f: - dataset_hparams_dict.update(json.loads(f.read())) - except FileNotFoundError: - print("dataset_hparams.json was not loaded because it does not exist") - try: - with open(os.path.join(checkpoint_dir, "model_hparams.json")) as f: - model_hparams_dict.update(json.loads(f.read())) - except FileNotFoundError: - print("model_hparams.json was not loaded because it does not exist") - - print('----------------------------------- Options ------------------------------------') - for k, v in args._get_kwargs(): - print(k, "=", v) - print('------------------------------------- End --------------------------------------') - - VideoDataset = datasets.get_dataset_class(args.dataset) - train_dataset = VideoDataset( - args.input_dir, - mode='train', - hparams_dict=dataset_hparams_dict, - hparams=args.dataset_hparams) - val_dataset = VideoDataset( - args.val_input_dir or args.input_dir, - mode='val', - hparams_dict=dataset_hparams_dict, - hparams=args.dataset_hparams) - if val_dataset.hparams.long_sequence_length != val_dataset.hparams.sequence_length: - # the longer dataset is only used for the accum_eval_metrics - long_val_dataset = VideoDataset( - args.val_input_dir or args.input_dir, - mode='val', - hparams_dict=dataset_hparams_dict, - hparams=args.dataset_hparams) - long_val_dataset.set_sequence_length(val_dataset.hparams.long_sequence_length) - else: - long_val_dataset = None - - variable_scope = tf.get_variable_scope() - variable_scope.set_use_resource(True) - - VideoPredictionModel = models.get_model_class(args.model) - hparams_dict = dict(model_hparams_dict) - hparams_dict.update({ - 'context_frames': train_dataset.hparams.context_frames,#Bing: TODO what is context_frames? - 'sequence_length': train_dataset.hparams.sequence_length,#Bing: TODO what is sequence_frames - 'repeat': train_dataset.hparams.time_shift, - }) - model = VideoPredictionModel( - hparams_dict=hparams_dict, - hparams=args.model_hparams, - aggregate_nccl=args.aggregate_nccl) - - batch_size = model.hparams.batch_size - train_tf_dataset = train_dataset.make_dataset(batch_size)#Bing: adopt the meteo data prepartion here - train_iterator = train_tf_dataset.make_one_shot_iterator()#Bing:for era5, the problem happen in sess.run(feches) should come from here - # The `Iterator.string_handle()` method returns a tensor that can be evaluated - # and used to feed the `handle` placeholder. - train_handle = train_iterator.string_handle() - val_tf_dataset = val_dataset.make_dataset(batch_size) - val_iterator = val_tf_dataset.make_one_shot_iterator() - val_handle = val_iterator.string_handle() - iterator = tf.data.Iterator.from_string_handle( - train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes) - inputs = iterator.get_next() - #Bing for debug - with tf.Session() as sess: - for i in range(2): - print(sess.run(tf.shape(inputs["images"]))) - - # inputs comes from the training dataset by default, unless train_handle is remapped to the val_handles - model.build_graph(inputs) - - if long_val_dataset is not None: - # separately build a model for the longer sequence. - # this is needed because the model doesn't support dynamic shapes. - long_hparams_dict = dict(hparams_dict) - long_hparams_dict['sequence_length'] = long_val_dataset.hparams.sequence_length - # use smaller batch size for longer model to prevenet running out of memory - long_hparams_dict['batch_size'] = model.hparams.batch_size // 2 - long_model = VideoPredictionModel( - mode="test", # to not build the losses and discriminators - hparams_dict=long_hparams_dict, - hparams=args.model_hparams, - aggregate_nccl=args.aggregate_nccl) - tf.get_variable_scope().reuse_variables() - long_model.build_graph(long_val_dataset.make_batch(batch_size)) - else: - long_model = None - - if not os.path.exists(args.output_dir): - os.makedirs(args.output_dir) - with open(os.path.join(args.output_dir, "options.json"), "w") as f: - f.write(json.dumps(vars(args), sort_keys=True, indent=4)) - with open(os.path.join(args.output_dir, "dataset_hparams.json"), "w") as f: - f.write(json.dumps(train_dataset.hparams.values(), sort_keys=True, indent=4)) - with open(os.path.join(args.output_dir, "model_hparams.json"), "w") as f: - f.write(json.dumps(model.hparams.values(), sort_keys=True, indent=4)) - - with tf.name_scope("parameter_count"): - # exclude trainable variables that are replicas (used in multi-gpu setting) - trainable_variables = set(tf.trainable_variables()) & set(model.saveable_variables) - parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in trainable_variables]) - - saver = tf.train.Saver(var_list=model.saveable_variables, max_to_keep=2) - - # None has the special meaning of evaluating at the end, so explicitly check for non-equality to zero - if (args.summary_freq != 0 or args.image_summary_freq != 0 or - args.eval_summary_freq != 0 or args.accum_eval_summary_freq != 0): - summary_writer = tf.summary.FileWriter(args.output_dir) - - gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=args.gpu_mem_frac, allow_growth=True) - config = tf.ConfigProto(gpu_options=gpu_options, allow_soft_placement=True) - global_step = tf.train.get_or_create_global_step() - max_steps = model.hparams.max_steps - with tf.Session(config=config) as sess: - print("parameter_count =", sess.run(parameter_count)) - sess.run(tf.global_variables_initializer()) - sess.run(tf.local_variables_initializer()) - #coord = tf.train.Coordinator() - #threads = tf.train.start_queue_runners(sess = sess, coord = coord) - print("Init done: {sess.run(tf.local_variables_initializer())}%") - model.restore(sess, args.checkpoint) - print("Restore processed finished") - sess.run(model.post_init_ops) - print("Model run started") - val_handle_eval = sess.run(val_handle) - print("val handle done") - sess.graph.finalize() - print("graph inalize done") - start_step = sess.run(global_step) - print("global step done") - - def should(step, freq): - if freq is None: - return (step + 1) == (max_steps - start_step) - else: - return freq and ((step + 1) % freq == 0 or (step + 1) in (0, max_steps - start_step)) - - def should_eval(step, freq): - # never run eval summaries at the beginning since it's expensive, unless it's the last iteration - return should(step, freq) and (step >= 0 or (step + 1) == (max_steps - start_step)) - - # start at one step earlier to log everything without doing any training - # step is relative to the start_step - for step in range(-1, max_steps - start_step): - if step == 1: - # skip step -1 and 0 for timing purposes (for warmstarting) - start_time = time.time() - - fetches = {"global_step": global_step} - if step >= 0: - fetches["train_op"] = model.train_op - if should(step, args.progress_freq): - fetches['d_loss'] = model.d_loss - fetches['g_loss'] = model.g_loss - fetches['d_losses'] = model.d_losses - fetches['g_losses'] = model.g_losses - if isinstance(model.learning_rate, tf.Tensor): - fetches["learning_rate"] = model.learning_rate - if should(step, args.summary_freq): - fetches["summary"] = model.summary_op - if should(step, args.image_summary_freq): - fetches["image_summary"] = model.image_summary_op - if should_eval(step, args.eval_summary_freq): - fetches["eval_summary"] = model.eval_summary_op - - run_start_time = time.time() - results = sess.run(fetches) #fetch the elements in dictinoary fetch - - run_elapsed_time = time.time() - run_start_time - if run_elapsed_time > 1.5 and step > 0 and set(fetches.keys()) == {"global_step", "train_op"}: - print('running train_op took too long (%0.1fs)' % run_elapsed_time) - - if (should(step, args.summary_freq) or - should(step, args.image_summary_freq) or - should_eval(step, args.eval_summary_freq)): - val_fetches = {"global_step": global_step} - if should(step, args.summary_freq): - val_fetches["summary"] = model.summary_op - if should(step, args.image_summary_freq): - val_fetches["image_summary"] = model.image_summary_op - if should_eval(step, args.eval_summary_freq): - val_fetches["eval_summary"] = model.eval_summary_op - val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval}) - for name, summary in val_results.items(): - if name == 'global_step': - continue - val_results[name] = add_tag_suffix(summary, '_1') - - if should(step, args.summary_freq): - print("recording summary") - summary_writer.add_summary(results["summary"], results["global_step"]) - summary_writer.add_summary(val_results["summary"], val_results["global_step"]) - print("done") - if should(step, args.image_summary_freq): - print("recording image summary") - summary_writer.add_summary(results["image_summary"], results["global_step"]) - summary_writer.add_summary(val_results["image_summary"], val_results["global_step"]) - print("done") - if should_eval(step, args.eval_summary_freq): - print("recording eval summary") - summary_writer.add_summary(results["eval_summary"], results["global_step"]) - summary_writer.add_summary(val_results["eval_summary"], val_results["global_step"]) - print("done") - if should_eval(step, args.accum_eval_summary_freq): - val_datasets = [val_dataset] - val_models = [model] - if long_model is not None: - val_datasets.append(long_val_dataset) - val_models.append(long_model) - for i, (val_dataset_, val_model) in enumerate(zip(val_datasets, val_models)): - sess.run(val_model.accum_eval_metrics_reset_op) - # traverse (roughly up to rounding based on the batch size) all the validation dataset - accum_eval_summary_num_updates = val_dataset_.num_examples_per_epoch() // val_model.hparams.batch_size - val_fetches = {"global_step": global_step, "accum_eval_summary": val_model.accum_eval_summary_op} - for update_step in range(accum_eval_summary_num_updates): - print('evaluating %d / %d' % (update_step + 1, accum_eval_summary_num_updates)) - val_results = sess.run(val_fetches, feed_dict={train_handle: val_handle_eval}) - accum_eval_summary = add_tag_suffix(val_results["accum_eval_summary"], '_%d' % (i + 1)) - print("recording accum eval summary") - summary_writer.add_summary(accum_eval_summary, val_results["global_step"]) - print("done") - if (should(step, args.summary_freq) or should(step, args.image_summary_freq) or - should_eval(step, args.eval_summary_freq) or should_eval(step, args.accum_eval_summary_freq)): - summary_writer.flush() - if should(step, args.progress_freq): - # global_step will have the correct step count if we resume from a checkpoint - # global step is read before it's incremented - steps_per_epoch = train_dataset.num_examples_per_epoch() / batch_size - train_epoch = results["global_step"] / steps_per_epoch - print("progress global step %d epoch %0.1f" % (results["global_step"] + 1, train_epoch)) - if step > 0: - elapsed_time = time.time() - start_time - average_time = elapsed_time / step - images_per_sec = batch_size / average_time - remaining_time = (max_steps - (start_step + step + 1)) * average_time - print(" image/sec %0.1f remaining %dm (%0.1fh) (%0.1fd)" % - (images_per_sec, remaining_time / 60, remaining_time / 60 / 60, remaining_time / 60 / 60 / 24)) - - if results['d_losses']: - print("d_loss", results["d_loss"]) - for name, loss in results['d_losses'].items(): - print(" ", name, loss) - if results['g_losses']: - print("g_loss", results["g_loss"]) - for name, loss in results['g_losses'].items(): - print(" ", name, loss) - if isinstance(model.learning_rate, tf.Tensor): - print("learning_rate", results["learning_rate"]) - - if should(step, args.save_freq): - print("saving model to", args.output_dir) - saver.save(sess, os.path.join(args.output_dir, "model"), global_step=global_step) - print("done") - - -if __name__ == '__main__': - main() diff --git a/video_prediction_tools/deprecated/scripts/train_all.sh b/video_prediction_tools/deprecated/scripts/train_all.sh deleted file mode 100644 index c695a8b2453956dcac54c5440c0058e5598fa03d..0000000000000000000000000000000000000000 --- a/video_prediction_tools/deprecated/scripts/train_all.sh +++ /dev/null @@ -1,40 +0,0 @@ -# BAIR action-free robot pushing dataset -for model in \ - ours_deterministic_l1 \ - ours_deterministic_l2 \ - ours_vae_l1 \ - ours_vae_l2 \ - ours_gan \ - ours_savp \ -; do - CUDA_VISIBLE_DEVICES=0 python scripts/train.py --input_dir data/bair --dataset bair --model savp --model_hparams_dict hparams/bair_action_free/${model}/model_hparams.json --output_dir logs/bair_action_free/${model} -done - -# KTH human actions dataset -for model in \ - ours_deterministic_l1 \ - ours_deterministic_l2 \ - ours_vae_l1 \ - ours_gan \ - ours_savp \ -; do - CUDA_VISIBLE_DEVICES=0 python scripts/train.py --input_dir data/kth --dataset kth --model savp --model_hparams_dict hparams/kth/${model}/model_hparams.json --output_dir logs/kth/${model} -done - -# BAIR action-conditioned robot pushing dataset -for model in \ - ours_deterministic_l1 \ - ours_deterministic_l2 \ - ours_vae_l1 \ - ours_vae_l2 \ - ours_gan \ - ours_savp \ -; do - CUDA_VISIBLE_DEVICES=0 python scripts/train.py --input_dir data/bair --dataset bair --dataset_hparams use_state=True --model savp --model_hparams_dict hparams/bair/${model}/model_hparams.json --output_dir logs/bair/${model} -done -for model in \ - sna_l1 \ - sna_l2 \ -; do - CUDA_VISIBLE_DEVICES=0 python scripts/train.py --input_dir data/bair --dataset bair --dataset_hparams use_state=True --model sna --model_hparams_dict hparams/bair/${model}/model_hparams.json --output_dir logs/bair/${model} -done diff --git a/video_prediction_tools/docs/discussion/20201112_AMBS_report_to_Martin.pptx b/video_prediction_tools/docs/discussion/20201112_AMBS_report_to_Martin.pptx deleted file mode 100644 index 844d31e7a358a970a2567b6826b7151a52a6971c..0000000000000000000000000000000000000000 Binary files a/video_prediction_tools/docs/discussion/20201112_AMBS_report_to_Martin.pptx and /dev/null differ diff --git a/video_prediction_tools/docs/discussion/discussion.md b/video_prediction_tools/docs/discussion/discussion.md deleted file mode 100644 index 6eddf0502ef95759769da0c37ca82cace5bc4d8b..0000000000000000000000000000000000000000 --- a/video_prediction_tools/docs/discussion/discussion.md +++ /dev/null @@ -1,50 +0,0 @@ -This is the list of last-mins files for VP group - -## 2020-03-01 - 2020-04-15 AMBS internal meeting - -- https://docs.google.com/document/d/1cQUEWrenIlW1zebZwSSHpfka2Bhb8u63kPM3x7nya_o/edit#heading=h.yjmq51s4fxnm - -## 2020-08-31 - 2020-11-04 AMBS internal meeting - -- https://docs.google.com/document/d/1mHKey_lcy6-UluVm-nrpOBoNgochOxWnnwZ4XaJ-d2c/edit?usp=sharing - - -## 2021-01-01 -- 2021-02-28 AMBS internal meeting - -- merge branches discussion https://docs.google.com/document/d/1nXRGz0vTrEVj9kTGOj1jeblc-ix_WLdBZsb7r5zu52Q/edit?usp=sharing - - -## 2020-11-12 AMBS update with Martin - -- https://docs.google.com/document/d/1rc-hImd_A0rdOTSem461vZCY_8-GZY1zvwCvl5J8BQ8/edit?usp=sharing -- Presentation: https://gitlab.version.fz-juelich.de/toar/ambs/-/blob/bing_%2337_organize_last_mins_meeting/video_prediction_tools/docs/discussion/20201112_AMBS_report_to_Martin.pptx - - -## 2020-09-11 - 2021-01-20 JUWELS Booster Early Access Program -- Instruction: How to submit jobs in container on Booster. https://docs.google.com/document/d/1t2cmjTDbNtzEYBQSfeJn11T5w-wLjgMaJm9vwlgtCMA/edit?usp=sharing -- EA program Profile (German): ea-application-DeepACF-de_stadtler+michael.docx -- EA program Profile (English):ea-application-DeepACF_Bing.docx -- convLSTM training visu animation: https://fz-juelich.sciebo.de/s/2cSpnnEzPlqZufL -- EA experiments results update: https://docs.google.com/presentation/d/1Y5uEsIAcK22c-6J9uk3UNmC3CxZaVd07qj8O7mORYYM/edit?usp=sharing - - -## Helmholtz AI consultant -- Voucher discription and objectives :https://docs.google.com/document/d/1QaXD6G4UU1zQTZo4fDfl965CZsupBzrESOO4cEmePa8/edit?usp=sharing -- Parallel training performance update: https://docs.google.com/presentation/d/1F88bQRfK9pi1VUlZppvJ6z8i7VZIgrU4078xK1Ro2PI/edit?usp=sharing -- Tensorboard and IO memory issue update: https://docs.google.com/presentation/d/1R8LRtjwQN4kHSuKsFCce-wVLXoonba2XaSv6qBN6V90/edit?usp=sharing - -## 2021-01-20 - 2021-03-31 -- Project plan, working packages, and scientific questions https://docs.google.com/document/d/1sKAHw3zuomvBj59OI-DVwxgO7Y-kwxPQ64BPhuaKWao/edit?usp=sharing -- ERA5 requirement from Olaf : https://docs.google.com/document/d/1IW2GDDFhX941JZGO2__CMTormijdbfnrsNesSNxyZgk/edit?usp=sharing -- Experiments track: https://docs.google.com/spreadsheets/d/1fHfo6upNylrjGyvqZwnmRPg_FkN9UFgbHLS7a8nCc38/edit?usp=sharing - - -## 2021-04-01 - 2021-06-30 - -- 2021 -05 -17 Bing and Michael meeting to discuss about the tasks and priority:https://gitlab.version.fz-juelich.de/hedgedoc/s/2AXz8887T - -# AMBS manuscript -1. GMD1 paper - 2m temperature by deep learning https://ifftex.fz-juelich.de/1137516136yywsrnjmqgqp -2. GMB2 paper - workflow paper https://ifftex.fz-juelich.de/8281829766nmyhfmkkqntt -3. ECCV 2019 paper : https://ifftex.fz-juelich.de/9254475755brrnmgcbvvds - diff --git a/video_prediction_tools/docs/discussion/ea-application-DeepACF-de_stadtler+michael.docx b/video_prediction_tools/docs/discussion/ea-application-DeepACF-de_stadtler+michael.docx deleted file mode 100644 index 81ed732038ce0d9a0a2308b7e7aaa68fe84aa1e4..0000000000000000000000000000000000000000 Binary files a/video_prediction_tools/docs/discussion/ea-application-DeepACF-de_stadtler+michael.docx and /dev/null differ diff --git a/video_prediction_tools/docs/discussion/ea-application-DeepACF_Bing.DOCX b/video_prediction_tools/docs/discussion/ea-application-DeepACF_Bing.DOCX deleted file mode 100644 index d14a0368f10f6c8aabf7db61c0492d6abd23147e..0000000000000000000000000000000000000000 Binary files a/video_prediction_tools/docs/discussion/ea-application-DeepACF_Bing.DOCX and /dev/null differ diff --git a/video_prediction_tools/docs/structure_name_convention.md b/video_prediction_tools/docs/structure_name_convention.md deleted file mode 100644 index 4a2679c83ea8d99b9562ef775ed2ac1190f5d7fb..0000000000000000000000000000000000000000 --- a/video_prediction_tools/docs/structure_name_convention.md +++ /dev/null @@ -1,108 +0,0 @@ -This is the output folder structure and name convention - -## Shared folder structure - -``` -├── ExtractedData -│ ├── [Year] -│ │ ├── [Month] -│ │ │ ├── **/*.netCDF -├── PreprocessedData -│ ├── [Data_name_convention] -│ │ ├── hickle -│ │ │ ├── train -│ │ │ ├── val -│ │ │ ├── test -│ │ ├── tfrecords -│ │ │ ├── train -│ │ │ ├── val -│ │ │ ├── test -├── Models -│ ├── [Data_name_convention] -│ │ ├── [model_name] -│ │ ├── [model_name] -├── Results -│ ├── [Data_name_convention] -│ │ ├── [training_mode] -│ │ │ ├── [source_data_name_convention] -│ │ │ │ ├── [model_name] - -``` - -| Arguments | Value | -|--- |--- | -| [Year] | 2005;2005;2007 ...| -| [Month] | 01;02;03 ...,12| -|[Data_name_convention]|Y[yyyy]to[yyyy]M[mm]to[mm]-[nx]_[ny]-[nn.nn]N[ee.ee]E-[var1]_[var2]_[var3]| -|[model_name]| Ours_savp; ours_gan; ours_vae; prednet| -|[training_mode]|end_to_end; transfer_learning| - - -## Data name convention - -`Y[yyyy]to[yyyy]M[mm]to[mm]-[nx]_[ny]-[nn.nn]N[ee.ee]E-[var1]_[var2]_[var3]` - - - Y[yyyy]to[yyyy]M[mm]to[mm] - - [nx]_[ny] : the size of images,e.g 64_64 means 64*64 pixels - - [nn.nn]N[ee.ee]E :the geolocation of selected regions with two decimal points. e.g : 0.00N11.50E - - [var1]_[var2]_[var3] : [Use the abbrevation of selected variables](#variable-abbrevaition-and-the-corresponding-full-names) - -### `Y[yyyy]to[yyyy]M[mm]to[mm]` - -| Examples | Name abbrevation | -|--- |--- | -|all data from March to June of the years 2005-2015 | Y2005toY2015M03to06 | -|data from February to May of years 2005-2008 + data from March to June of year 2015| Y2005to2008M02to05_Y2015M03to06 | -|Data from February to May, and October to December of 2005 | Y2005M02to05_Y2015M10to12 | -|operational’ data base: whole year 2016 | Y2016M01to12 | -|add new whole year data of 2017 on the operational data base |Y2016to2017M01to12 | -| Note: Y2016to2017M01to12 = Y2016M01to12_Y2017M01to12| - - -### variable abbrevaition and the corresponding full names - -| var | full names | -|--- |--- | -|T|2m temperature| -|gph500|500 hPa geopotential| -|msl|meansealevelpressure| - - - -### Example - -``` -├── ExtractedData -│ ├── 2016 -│ │ ├── 01 -│ │ │ ├── *.netCDF -│ │ ├── 02 -│ │ ├── 03 -│ │ ├── … -│ ├── 2017 -│ │ ├── 01 -│ │ ├── … -├── PreprocessedData -│ ├── 2016to2017M01to12-64_64-50.00N11.50E-T_T_T -│ │ ├── hickle -│ │ │ ├── train -│ │ │ ├── val -│ │ │ ├── test -│ │ ├── tfrecords -│ │ │ ├── train -│ │ │ ├── val -│ │ │ ├── test -├── Models -│ ├── 2016to2017M01to12-64_64-50.00N11.50E-T_T_T -│ │ ├── outs_savp -│ │ ├── outs_gan -├── Results -│ ├── 2016to2017M01to12-64_64-50.00N11.50E-T_T_T -│ │ ├── end_to_end -│ │ │ ├── ours_savp -│ │ │ ├── ours_gan -│ │ ├── transfer_learning -│ │ │ ├── 2018M01to12-64_64-50.00N11.50E-T_T_T -│ │ │ │ ├── ours_savp -``` - diff --git a/video_prediction_tools/docs/template_bug b/video_prediction_tools/docs/template_bug deleted file mode 100644 index 82a17546e44a3d18518604afd6e2911814fc6551..0000000000000000000000000000000000000000 --- a/video_prediction_tools/docs/template_bug +++ /dev/null @@ -1,37 +0,0 @@ -<!-- These are comments that will not be shown in the ticket. --> -<!-- Irrelevant sections may be deleted. --> - -## Summary - -<!-- Summarize the bug encountered concisely. --> - -## Steps to reproduce - -<!--How one can reproduce the issue - this is very important. --> - -### Environment - -<!-- How did you setup the environment (lock-file)? --> - -### What is the current bug behavior? - -<!-- What actually happens--> - -### What is the expected correct behavior? - -<!-- What you should see instead. --> - -### Can you provide a minimal working example? - -## Relevant logs and/or screenshots - -<!-- Paste any relevant logs - please use code blocks to format console output, -logs, and code as it's very hard to read otherwise. --> - - -## Possible fixes - -<!-- If you can, link to the line of code that might be responsible for the problem. --> - - -/label ~"type::bug" diff --git a/video_prediction_tools/docs/template_userstory b/video_prediction_tools/docs/template_userstory deleted file mode 100644 index 4d6628f4663b9c06b0aedecce581eba80bad7a3b..0000000000000000000000000000000000000000 --- a/video_prediction_tools/docs/template_userstory +++ /dev/null @@ -1,28 +0,0 @@ -User story -<!-- These are comments that will not be shown in the ticket. --> -<!-- Irrelevant sections may be deleted. --> - -### Problem to solve - -<!-- What problem do we solve? --> - -### Further details - -<!-- Include use cases, benefits, and/or goals (contributes to our vision?) --> - -### Proposal - -<!-- How are we going to solve the problem? --> - -### Testing - -<!-- What risks does this change pose? How might it affect the quality of the product? What additional test coverage or changes to tests will be needed? --> - -### What does success look like, and how can we measure that? - -<!-- Define both the success metrics and acceptance criteria. Note that success metrics indicate the desired business outcomes, while acceptance criteria indicate when the solution is working correctly. If there is no way to measure success, link to an issue that will implement a way to measure this. --> - -### Links / references - -/label ~"type:story" - diff --git a/video_prediction_tools/env_setup/requirements.txt b/video_prediction_tools/env_setup/requirements.txt index 9c188138ea805d0f05203938a910e6247d4dd8ac..dd5a43273f5077834916343019bf80e3f476e43a 100755 --- a/video_prediction_tools/env_setup/requirements.txt +++ b/video_prediction_tools/env_setup/requirements.txt @@ -10,4 +10,5 @@ netcdf4==1.5.8 #metadata==0.2 normalization==0.4 utils==1.0.1 - +pytest==7.1.1 +dask==2021.7.2 diff --git a/video_prediction_tools/external_package/hickle/bin/f2py b/video_prediction_tools/external_package/hickle/bin/f2py deleted file mode 100755 index fcc774fba52f3705ff41babc8dbb21dae36d2c29..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/bin/f2py +++ /dev/null @@ -1,4 +0,0 @@ -#!/usr/local/software/jureca/Stages/2018b/software/Python/3.6.6-GCCcore-7.3.0/bin/python -# EASY-INSTALL-SCRIPT: 'numpy==1.15.2','f2py' -__requires__ = 'numpy==1.15.2' -__import__('pkg_resources').run_script('numpy==1.15.2', 'f2py') diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/easy-install.pth b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/easy-install.pth deleted file mode 100644 index 09ac282550d7bba3d89ef3a91ea75877f66f0384..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/easy-install.pth +++ /dev/null @@ -1,4 +0,0 @@ -./hickle-3.4.3-py3.6.egg -/usr/local/software/jureca/Stages/2018b/software/h5py/2.8.0-ipsmpi-2018b-Python-3.6.6/lib/python3.6/site-packages/h5py-2.8.0-py3.6-linux-x86_64.egg -/usr/local/software/jureca/Stages/2018b/software/SciPy-Stack/2018b-gcccoremkl-7.3.0-2019.0.117-Python-3.6.6/lib/python3.6/site-packages/numpy-1.15.2-py3.6-linux-x86_64.egg -/usr/local/software/jureca/Stages/2018b/software/Python/3.6.6-GCCcore-7.3.0/lib/python3.6/site-packages/six-1.11.0-py3.6.egg diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/PKG-INFO b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/PKG-INFO deleted file mode 100644 index 5f8214504c72f2cfb7307cf8259de678fba12236..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/PKG-INFO +++ /dev/null @@ -1,207 +0,0 @@ -Metadata-Version: 2.1 -Name: hickle -Version: 3.4.3 -Summary: Hickle - a HDF5 based version of pickle -Home-page: http://github.com/telegraphic/hickle -Author: Danny Price -Author-email: dan@thetelegraphic.com -License: UNKNOWN -Download-URL: https://github.com/telegraphic/hickle/archive/3.4.3.tar.gz -Description: [](https://travis-ci.org/telegraphic/hickle) - [](http://joss.theoj.org/papers/0c6638f84a1a574913ed7c6dd1051847) - - - Hickle - ====== - - Hickle is a [HDF5](https://www.hdfgroup.org/solutions/hdf5/) based clone of `pickle`, with a twist: instead of serializing to a pickle file, - Hickle dumps to a HDF5 file (Hierarchical Data Format). It is designed to be a "drop-in" replacement for pickle (for common data objects), but is - really an amalgam of `h5py` and `dill`/`pickle` with extended functionality. - - That is: `hickle` is a neat little way of dumping python variables to HDF5 files that can be read in most programming - languages, not just Python. Hickle is fast, and allows for transparent compression of your data (LZF / GZIP). - - Why use Hickle? - --------------- - - While `hickle` is designed to be a drop-in replacement for `pickle` (or something like `json`), it works very differently. - Instead of serializing / json-izing, it instead stores the data using the excellent [h5py](https://www.h5py.org/) module. - - The main reasons to use hickle are: - - 1. It's faster than pickle and cPickle. - 2. It stores data in HDF5. - 3. You can easily compress your data. - - The main reasons not to use hickle are: - - 1. You don't want to store your data in HDF5. While hickle can serialize arbitrary python objects, this functionality is provided only for convenience, and you're probably better off just using the pickle module. - 2. You want to convert your data in human-readable JSON/YAML, in which case, you should do that instead. - - So, if you want your data in HDF5, or if your pickling is taking too long, give hickle a try. - Hickle is particularly good at storing large numpy arrays, thanks to `h5py` running under the hood. - - Documentation - ------------- - - Documentation for hickle can be found at [telegraphic.github.io/hickle/](http://telegraphic.github.io/hickle/). - - - Usage example - ------------- - - Hickle is nice and easy to use, and should look very familiar to those of you who have pickled before. - - In short, `hickle` provides two methods: a [hickle.load](http://telegraphic.github.io/hickle/toc.html#hickle.load) - method, for loading hickle files, and a [hickle.dump](http://telegraphic.github.io/hickle/toc.html#hickle.dump) - method, for dumping data into HDF5. Here's a complete example: - - ```python - import os - import hickle as hkl - import numpy as np - - # Create a numpy array of data - array_obj = np.ones(32768, dtype='float32') - - # Dump to file - hkl.dump(array_obj, 'test.hkl', mode='w') - - # Dump data, with compression - hkl.dump(array_obj, 'test_gzip.hkl', mode='w', compression='gzip') - - # Compare filesizes - print('uncompressed: %i bytes' % os.path.getsize('test.hkl')) - print('compressed: %i bytes' % os.path.getsize('test_gzip.hkl')) - - # Load data - array_hkl = hkl.load('test_gzip.hkl') - - # Check the two are the same file - assert array_hkl.dtype == array_obj.dtype - assert np.all((array_hkl, array_obj)) - ``` - - ### HDF5 compression options - - A major benefit of `hickle` over `pickle` is that it allows fancy HDF5 features to - be applied, by passing on keyword arguments on to `h5py`. So, you can do things like: - ```python - hkl.dump(array_obj, 'test_lzf.hkl', mode='w', compression='lzf', scaleoffset=0, - chunks=(100, 100), shuffle=True, fletcher32=True) - ``` - A detailed explanation of these keywords is given at http://docs.h5py.org/en/latest/high/dataset.html, - but we give a quick rundown below. - - In HDF5, datasets are stored as B-trees, a tree data structure that has speed benefits over contiguous - blocks of data. In the B-tree, data are split into [chunks](http://docs.h5py.org/en/latest/high/dataset.html#chunked-storage), - which is leveraged to allow [dataset resizing](http://docs.h5py.org/en/latest/high/dataset.html#resizable-datasets) and - compression via [filter pipelines](http://docs.h5py.org/en/latest/high/dataset.html#filter-pipeline). Filters such as - `shuffle` and `scaleoffset` move your data around to improve compression ratios, and `fletcher32` computes a checksum. - These file-level options are abstracted away from the data model. - - Recent changes - -------------- - - * December 2018: Accepted to Journal of Open-Source Software (JOSS). - * June 2018: Major refactor and support for Python 3. - * Aug 2016: Added support for scipy sparse matrices `bsr_matrix`, `csr_matrix` and `csc_matrix`. - - Performance comparison - ---------------------- - - Hickle runs a lot faster than pickle with its default settings, and a little faster than pickle with `protocol=2` set: - - ```Python - In [1]: import numpy as np - - In [2]: x = np.random.random((2000, 2000)) - - In [3]: import pickle - - In [4]: f = open('foo.pkl', 'w') - - In [5]: %time pickle.dump(x, f) # slow by default - CPU times: user 2 s, sys: 274 ms, total: 2.27 s - Wall time: 2.74 s - - In [6]: f = open('foo.pkl', 'w') - - In [7]: %time pickle.dump(x, f, protocol=2) # actually very fast - CPU times: user 18.8 ms, sys: 36 ms, total: 54.8 ms - Wall time: 55.6 ms - - In [8]: import hickle - - In [9]: f = open('foo.hkl', 'w') - - In [10]: %time hickle.dump(x, f) # a bit faster - dumping <type 'numpy.ndarray'> to file <HDF5 file "foo.hkl" (mode r+)> - CPU times: user 764 us, sys: 35.6 ms, total: 36.4 ms - Wall time: 36.2 ms - ``` - - So if you do continue to use pickle, add the `protocol=2` keyword (thanks @mrocklin for pointing this out). - - For storing python dictionaries of lists, hickle beats the python json encoder, but is slower than uJson. For a dictionary with 64 entries, each containing a 4096 length list of random numbers, the times are: - - - json took 2633.263 ms - uJson took 138.482 ms - hickle took 232.181 ms - - - It should be noted that these comparisons are of course not fair: storing in HDF5 will not help you convert something into JSON, nor will it help you serialize a string. But for quick storage of the contents of a python variable, it's a pretty good option. - - Installation guidelines (for Linux and Mac OS). - ----------------------------------------------- - - ### Easy method - Install with `pip` by running `pip install hickle` from the command line. - - ### Manual install - - 1. You should have Python 2.7 and above installed - - 2. Install h5py - (Official page: http://docs.h5py.org/en/latest/build.html) - - 3. Install hdf5 - (Official page: http://www.hdfgroup.org/ftp/HDF5/current/src/unpacked/release_docs/INSTALL) - - 4. Download `hickle`: - via terminal: git clone https://github.com/telegraphic/hickle.git - via manual download: Go to https://github.com/telegraphic/hickle and on right hand side you will find `Download ZIP` file - - 5. cd to your downloaded `hickle` directory - - 6. Then run the following command in the `hickle` directory: - `python setup.py install` - - ### Testing - - Once installed from source, run `python setup.py test` to check it's all working. - - - Bugs & contributing - -------------------- - - Contributions and bugfixes are very welcome. Please check out our [contribution guidelines](https://github.com/telegraphic/hickle/blob/master/CONTRIBUTING.md) - for more details on how to contribute to development. - - - Referencing hickle - ------------------ - - If you use `hickle` in academic research, we would be grateful if you could reference [our paper](http://joss.theoj.org/papers/0c6638f84a1a574913ed7c6dd1051847) in the [Journal of Open-Source Software (JOSS)](http://joss.theoj.org/about). - - ``` - Price et al., (2018). Hickle: A HDF5-based python pickle replacement. Journal of Open Source Software, 3(32), 1115, https://doi.org/10.21105/joss.01115 - ``` - -Keywords: pickle,hdf5,data storage,data export -Platform: Cross platform (Linux -Platform: Mac OSX -Platform: Windows) -Requires-Python: >=2.7 -Description-Content-Type: text/markdown diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/SOURCES.txt b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/SOURCES.txt deleted file mode 100644 index bf56f059f14d80d641efba6de75e401b4410786f..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/SOURCES.txt +++ /dev/null @@ -1,52 +0,0 @@ -.gitignore -.nojekyll -.pylintrc -.travis.yml -CODE_OF_CONDUCT.md -CONTRIBUTING.md -LICENSE -README.md -_config.yml -paper.bib -paper.md -requirements.txt -setup.cfg -setup.py -docs/Makefile -docs/make_docs.sh -docs/source/conf.py -docs/source/index.md -docs/source/toc.rst -docs/source/_static/empty.txt -docs/source/_templates/empty.txt -hickle/__init__.py -hickle/helpers.py -hickle/hickle.py -hickle/hickle_legacy.py -hickle/hickle_legacy2.py -hickle/lookup.py -hickle.egg-info/PKG-INFO -hickle.egg-info/SOURCES.txt -hickle.egg-info/dependency_links.txt -hickle.egg-info/not-zip-safe -hickle.egg-info/requires.txt -hickle.egg-info/top_level.txt -hickle/loaders/__init__.py -hickle/loaders/load_astropy.py -hickle/loaders/load_numpy.py -hickle/loaders/load_pandas.py -hickle/loaders/load_python.py -hickle/loaders/load_python3.py -hickle/loaders/load_scipy.py -tests/__init__.py -tests/test_astropy.py -tests/test_hickle.py -tests/test_hickle_helpers.py -tests/test_legacy_load.py -tests/test_scipy.py -tests/legacy_hkls/generate_test_hickle.py -tests/legacy_hkls/hickle_1_1_0.hkl -tests/legacy_hkls/hickle_1_3_2.hkl -tests/legacy_hkls/hickle_1_4_0.hkl -tests/legacy_hkls/hickle_2_0_5.hkl -tests/legacy_hkls/hickle_2_1_0.hkl \ No newline at end of file diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/dependency_links.txt b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/dependency_links.txt deleted file mode 100644 index 8b137891791fe96927ad78e64b0aad7bded08bdc..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/not-zip-safe b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/not-zip-safe deleted file mode 100644 index 8b137891791fe96927ad78e64b0aad7bded08bdc..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/not-zip-safe +++ /dev/null @@ -1 +0,0 @@ - diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/requires.txt b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/requires.txt deleted file mode 100644 index 8ccd55587b619ea766f8d1a76bc06739e176f552..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/requires.txt +++ /dev/null @@ -1,2 +0,0 @@ -numpy -h5py diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/top_level.txt b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/top_level.txt deleted file mode 100644 index ce3b9fb874814125f842378fab0204ff0e9184a3..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/EGG-INFO/top_level.txt +++ /dev/null @@ -1,2 +0,0 @@ -hickle -tests diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/__init__.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/__init__.py deleted file mode 100644 index 46e2ea2c6d0f5578529b3e40e060b1a244420772..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .hickle import dump, load -from .hickle import __version__ - - diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/helpers.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/helpers.py deleted file mode 100644 index 6c3d7f9f3853101723380f4658487978605f0cf3..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/helpers.py +++ /dev/null @@ -1,113 +0,0 @@ -import re -import six - -def get_type_and_data(h_node): - """ Helper function to return the py_type and data block for a HDF node """ - py_type = h_node.attrs["type"][0] - data = h_node[()] -# if h_node.shape == (): -# data = h_node.value -# else: -# data = h_node[:] - return py_type, data - -def get_type(h_node): - """ Helper function to return the py_type for a HDF node """ - py_type = h_node.attrs["type"][0] - return py_type - -def sort_keys(key_list): - """ Take a list of strings and sort it by integer value within string - - Args: - key_list (list): List of keys - - Returns: - key_list_sorted (list): List of keys, sorted by integer - """ - - # Py3 h5py returns an irritating KeysView object - # Py3 also complains about bytes and strings, convert all keys to bytes - if six.PY3: - key_list2 = [] - for key in key_list: - if isinstance(key, str): - key = bytes(key, 'ascii') - key_list2.append(key) - key_list = key_list2 - - # Check which keys contain a number - numbered_keys = [re.search(b'\d+', key) for key in key_list] - - # Sort the keys on number if they have it, or normally if not - if(len(key_list) and not numbered_keys.count(None)): - to_int = lambda x: int(re.search(b'\d+', x).group(0)) - return(sorted(key_list, key=to_int)) - else: - return(sorted(key_list)) - - -def check_is_iterable(py_obj): - """ Check whether a python object is iterable. - - Note: this treats unicode and string as NON ITERABLE - - Args: - py_obj: python object to test - - Returns: - iter_ok (bool): True if item is iterable, False is item is not - """ - if six.PY2: - string_types = (str, unicode) - else: - string_types = (str, bytes, bytearray) - if isinstance(py_obj, string_types): - return False - try: - iter(py_obj) - return True - except TypeError: - return False - - -def check_is_hashable(py_obj): - """ Check if a python object is hashable - - Note: this function is currently not used, but is useful for future - development. - - Args: - py_obj: python object to test - """ - - try: - py_obj.__hash__() - return True - except TypeError: - return False - - -def check_iterable_item_type(iter_obj): - """ Check if all items within an iterable are the same type. - - Args: - iter_obj: iterable object - - Returns: - iter_type: type of item contained within the iterable. If - the iterable has many types, a boolean False is returned instead. - - References: - http://stackoverflow.com/questions/13252333/python-check-if-all-elements-of-a-list-are-the-same-type - """ - iseq = iter(iter_obj) - - try: - first_type = type(next(iseq)) - except StopIteration: - return False - except Exception as ex: - return False - else: - return first_type if all((type(x) is first_type) for x in iseq) else False diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/hickle.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/hickle.py deleted file mode 100644 index 24b38c3e1283618c9ce2c4d97b6960334cc08530..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/hickle.py +++ /dev/null @@ -1,611 +0,0 @@ -# encoding: utf-8 -""" -# hickle.py - -Created by Danny Price 2016-02-03. - -Hickle is a HDF5 based clone of Pickle. Instead of serializing to a pickle -file, Hickle dumps to a HDF5 file. It is designed to be as similar to pickle in -usage as possible, providing a load() and dump() function. - -## Notes - -Hickle has two main advantages over Pickle: -1) LARGE PICKLE HANDLING. Unpickling a large pickle is slow, as the Unpickler -reads the entire pickle thing and loads it into memory. In comparison, HDF5 -files are designed for large datasets. Things are only loaded when accessed. - -2) CROSS PLATFORM SUPPORT. Attempting to unpickle a pickle pickled on Windows -on Linux and vice versa is likely to fail with errors like "Insecure string -pickle". HDF5 files will load fine, as long as both machines have -h5py installed. - -""" - -from __future__ import absolute_import, division, print_function -import sys -import os -from pkg_resources import get_distribution, DistributionNotFound -from ast import literal_eval - -import numpy as np -import h5py as h5 - - -from .helpers import get_type, sort_keys, check_is_iterable, check_iterable_item_type -from .lookup import types_dict, hkl_types_dict, types_not_to_sort, \ - container_types_dict, container_key_types_dict -from .lookup import check_is_ndarray_like - - -try: - from exceptions import Exception - from types import NoneType -except ImportError: - pass # above imports will fail in python3 - -from six import PY2, PY3, string_types, integer_types -import io - -# Make several aliases for Python2/Python3 compatibility -if PY3: - file = io.TextIOWrapper - -# Import a default 'pickler' -# Not the nicest import code, but should work on Py2/Py3 -try: - import dill as pickle -except ImportError: - try: - import cPickle as pickle - except ImportError: - import pickle - -import warnings - -try: - __version__ = get_distribution('hickle').version -except DistributionNotFound: - __version__ = '0.0.0 - please install via pip/setup.py' - -################## -# Error handling # -################## - -class FileError(Exception): - """ An exception raised if the file is fishy """ - def __init__(self): - return - - def __str__(self): - return ("Cannot open file. Please pass either a filename " - "string, a file object, or a h5py.File") - - -class ClosedFileError(Exception): - """ An exception raised if the file is fishy """ - def __init__(self): - return - - def __str__(self): - return ("HDF5 file has been closed. Please pass either " - "a filename string, a file object, or an open h5py.File") - - -class NoMatchError(Exception): - """ An exception raised if the object type is not understood (or - supported)""" - def __init__(self): - return - - def __str__(self): - return ("Error: this type of python object cannot be converted into a " - "hickle.") - - -class ToDoError(Exception): - """ An exception raised for non-implemented functionality""" - def __init__(self): - return - - def __str__(self): - return "Error: this functionality hasn't been implemented yet." - - -class SerializedWarning(UserWarning): - """ An object type was not understood - - The data will be serialized using pickle. - """ - pass - - -###################### -# H5PY file wrappers # -###################### - -class H5GroupWrapper(h5.Group): - """ Group wrapper that provides a track_times kwarg. - - track_times is a boolean flag that can be set to False, so that two - files created at different times will have identical MD5 hashes. - """ - def create_dataset(self, *args, **kwargs): - kwargs['track_times'] = getattr(self, 'track_times', True) - return super(H5GroupWrapper, self).create_dataset(*args, **kwargs) - - def create_group(self, *args, **kwargs): - group = super(H5GroupWrapper, self).create_group(*args, **kwargs) - group.__class__ = H5GroupWrapper - group.track_times = getattr(self, 'track_times', True) - return group - - -class H5FileWrapper(h5.File): - """ Wrapper for h5py File that provides a track_times kwarg. - - track_times is a boolean flag that can be set to False, so that two - files created at different times will have identical MD5 hashes. - """ - def create_dataset(self, *args, **kwargs): - kwargs['track_times'] = getattr(self, 'track_times', True) - return super(H5FileWrapper, self).create_dataset(*args, **kwargs) - - def create_group(self, *args, **kwargs): - group = super(H5FileWrapper, self).create_group(*args, **kwargs) - group.__class__ = H5GroupWrapper - group.track_times = getattr(self, 'track_times', True) - return group - - -def file_opener(f, mode='r', track_times=True): - """ A file opener helper function with some error handling. This can open - files through a file object, a h5py file, or just the filename. - - Args: - f (file, h5py.File, or string): File-identifier, e.g. filename or file object. - mode (str): File open mode. Only required if opening by filename string. - track_times (bool): Track time in HDF5; turn off if you want hickling at - different times to produce identical files (e.g. for MD5 hash check). - - """ - - # Assume that we will have to close the file after dump or load - close_flag = True - - # Were we handed a file object or just a file name string? - if isinstance(f, (file, io.TextIOWrapper)): - filename, mode = f.name, f.mode - f.close() - h5f = h5.File(filename, mode) - elif isinstance(f, string_types): - filename = f - h5f = h5.File(filename, mode) - elif isinstance(f, (H5FileWrapper, h5._hl.files.File)): - try: - filename = f.filename - except ValueError: - raise ClosedFileError - h5f = f - # Since this file was already open, do not close the file afterward - close_flag = False - else: - print(f.__class__) - raise FileError - - h5f.__class__ = H5FileWrapper - h5f.track_times = track_times - return(h5f, close_flag) - - -########### -# DUMPERS # -########### - - -def _dump(py_obj, h_group, call_id=0, **kwargs): - """ Dump a python object to a group within a HDF5 file. - - This function is called recursively by the main dump() function. - - Args: - py_obj: python object to dump. - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - - # Get list of dumpable dtypes - dumpable_dtypes = [] - for lst in [[bool, complex, bytes, float], string_types, integer_types]: - dumpable_dtypes.extend(lst) - - # Firstly, check if item is a numpy array. If so, just dump it. - if check_is_ndarray_like(py_obj): - create_hkl_dataset(py_obj, h_group, call_id, **kwargs) - - # Next, check if item is a dict - elif isinstance(py_obj, dict): - create_hkl_dataset(py_obj, h_group, call_id, **kwargs) - - # If not, check if item is iterable - elif check_is_iterable(py_obj): - item_type = check_iterable_item_type(py_obj) - - # item_type == False implies multiple types. Create a dataset - if item_type is False: - h_subgroup = create_hkl_group(py_obj, h_group, call_id) - for ii, py_subobj in enumerate(py_obj): - _dump(py_subobj, h_subgroup, call_id=ii, **kwargs) - - # otherwise, subitems have same type. Check if subtype is an iterable - # (e.g. list of lists), or not (e.g. list of ints, which should be treated - # as a single dataset). - else: - if item_type in dumpable_dtypes: - create_hkl_dataset(py_obj, h_group, call_id, **kwargs) - else: - h_subgroup = create_hkl_group(py_obj, h_group, call_id) - for ii, py_subobj in enumerate(py_obj): - _dump(py_subobj, h_subgroup, call_id=ii, **kwargs) - - # item is not iterable, so create a dataset for it - else: - create_hkl_dataset(py_obj, h_group, call_id, **kwargs) - - -def dump(py_obj, file_obj, mode='w', track_times=True, path='/', **kwargs): - """ Write a pickled representation of obj to the open file object file. - - Args: - obj (object): python object o store in a Hickle - file: file object, filename string, or h5py.File object - file in which to store the object. A h5py.File or a filename is also - acceptable. - mode (str): optional argument, 'r' (read only), 'w' (write) or 'a' (append). - Ignored if file is a file object. - compression (str): optional argument. Applies compression to dataset. Options: None, gzip, - lzf (+ szip, if installed) - track_times (bool): optional argument. If set to False, repeated hickling will produce - identical files. - path (str): path within hdf5 file to save data to. Defaults to root / - """ - - # Make sure that file is not closed unless modified - # This is to avoid trying to close a file that was never opened - close_flag = False - - try: - # Open the file - h5f, close_flag = file_opener(file_obj, mode, track_times) - h5f.attrs["CLASS"] = b'hickle' - h5f.attrs["VERSION"] = get_distribution('hickle').version - h5f.attrs["type"] = [b'hickle'] - # Log which version of python was used to generate the hickle file - pv = sys.version_info - py_ver = "%i.%i.%i" % (pv[0], pv[1], pv[2]) - h5f.attrs["PYTHON_VERSION"] = py_ver - - h_root_group = h5f.get(path) - - if h_root_group is None: - h_root_group = h5f.create_group(path) - h_root_group.attrs["type"] = [b'hickle'] - - _dump(py_obj, h_root_group, **kwargs) - except NoMatchError: - fname = h5f.filename - h5f.close() - try: - os.remove(fname) - except OSError: - warnings.warn("Dump failed. Could not remove %s" % fname) - finally: - raise NoMatchError - finally: - # Close the file if requested. - # Closing a file twice will not cause any problems - if close_flag: - h5f.close() - - -def create_dataset_lookup(py_obj): - """ What type of object are we trying to pickle? This is a python - dictionary based equivalent of a case statement. It returns the correct - helper function for a given data type. - - Args: - py_obj: python object to look-up what function to use to dump to disk - - Returns: - match: function that should be used to dump data to a new dataset - """ - t = type(py_obj) - types_lookup = {dict: create_dict_dataset} - types_lookup.update(types_dict) - - match = types_lookup.get(t, no_match) - - return match - - - -def create_hkl_dataset(py_obj, h_group, call_id=0, **kwargs): - """ Create a dataset within the hickle HDF5 file - - Args: - py_obj: python object to dump. - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - - """ - #lookup dataset creator type based on python object type - create_dataset = create_dataset_lookup(py_obj) - - # do the creation - create_dataset(py_obj, h_group, call_id, **kwargs) - - -def create_hkl_group(py_obj, h_group, call_id=0): - """ Create a new group within the hickle file - - Args: - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - - """ - h_subgroup = h_group.create_group('data_%i' % call_id) - h_subgroup.attrs['type'] = [str(type(py_obj)).encode('ascii', 'ignore')] - return h_subgroup - - -def create_dict_dataset(py_obj, h_group, call_id=0, **kwargs): - """ Creates a data group for each key in dictionary - - Notes: - This is a very important function which uses the recursive _dump - method to build up hierarchical data models stored in the HDF5 file. - As this is critical to functioning, it is kept in the main hickle.py - file instead of in the loaders/ directory. - - Args: - py_obj: python object to dump; should be dictionary - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - h_dictgroup = h_group.create_group('data_%i' % call_id) - h_dictgroup.attrs['type'] = [str(type(py_obj)).encode('ascii', 'ignore')] - - for key, py_subobj in py_obj.items(): - if isinstance(key, string_types): - h_subgroup = h_dictgroup.create_group("%r" % (key)) - else: - h_subgroup = h_dictgroup.create_group(str(key)) - h_subgroup.attrs["type"] = [b'dict_item'] - - h_subgroup.attrs["key_type"] = [str(type(key)).encode('ascii', 'ignore')] - - _dump(py_subobj, h_subgroup, call_id=0, **kwargs) - - -def no_match(py_obj, h_group, call_id=0, **kwargs): - """ If no match is made, raise an exception - - Args: - py_obj: python object to dump; default if item is not matched. - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - pickled_obj = pickle.dumps(py_obj) - d = h_group.create_dataset('data_%i' % call_id, data=[pickled_obj]) - d.attrs["type"] = [b'pickle'] - - warnings.warn("%s type not understood, data have been serialized" % type(py_obj), - SerializedWarning) - - - -############# -## LOADERS ## -############# - -class PyContainer(list): - """ A group-like object into which to load datasets. - - In order to build up a tree-like structure, we need to be able - to load datasets into a container with an append() method. - Python tuples and sets do not allow this. This class provides - a list-like object that be converted into a list, tuple, set or dict. - """ - def __init__(self): - super(PyContainer, self).__init__() - self.container_type = None - self.name = None - self.key_type = None - - def convert(self): - """ Convert from PyContainer to python core data type. - - Returns: self, either as a list, tuple, set or dict - (or other type specified in lookup.py) - """ - - if self.container_type in container_types_dict.keys(): - convert_fn = container_types_dict[self.container_type] - return convert_fn(self) - if self.container_type == str(dict).encode('ascii', 'ignore'): - keys = [] - for item in self: - key = item.name.split('/')[-1] - key_type = item.key_type[0] - if key_type in container_key_types_dict.keys(): - to_type_fn = container_key_types_dict[key_type] - key = to_type_fn(key) - keys.append(key) - - items = [item[0] for item in self] - return dict(zip(keys, items)) - else: - return self - -def no_match_load(key): - """ If no match is made when loading, need to raise an exception - """ - raise RuntimeError("Cannot load %s data type" % key) - #pass - -def load_dataset_lookup(key): - """ What type of object are we trying to unpickle? This is a python - dictionary based equivalent of a case statement. It returns the type - a given 'type' keyword in the hickle file. - - Args: - py_obj: python object to look-up what function to use to dump to disk - - Returns: - match: function that should be used to dump data to a new dataset - """ - - match = hkl_types_dict.get(key, no_match_load) - - return match - -def load(fileobj, path='/', safe=True): - """ Load a hickle file and reconstruct a python object - - Args: - fileobj: file object, h5py.File, or filename string - safe (bool): Disable automatic depickling of arbitrary python objects. - DO NOT set this to False unless the file is from a trusted source. - (see http://www.cs.jhu.edu/~s/musings/pickle.html for an explanation) - - path (str): path within hdf5 file to save data to. Defaults to root / - """ - - # Make sure that the file is not closed unless modified - # This is to avoid trying to close a file that was never opened - close_flag = False - - try: - h5f, close_flag = file_opener(fileobj) - h_root_group = h5f.get(path) - try: - assert 'CLASS' in h5f.attrs.keys() - assert 'VERSION' in h5f.attrs.keys() - VER = h5f.attrs['VERSION'] - try: - VER_MAJOR = int(VER) - except ValueError: - VER_MAJOR = int(VER[0]) - if VER_MAJOR == 1: - if PY2: - warnings.warn("Hickle file versioned as V1, attempting legacy loading...") - from . import hickle_legacy - return hickle_legacy.load(fileobj, safe) - else: - raise RuntimeError("Cannot open file. This file was likely" - " created with Python 2 and an old hickle version.") - elif VER_MAJOR == 2: - if PY2: - warnings.warn("Hickle file appears to be old version (v2), attempting " - "legacy loading...") - from . import hickle_legacy2 - return hickle_legacy2.load(fileobj, path=path, safe=safe) - else: - raise RuntimeError("Cannot open file. This file was likely" - " created with Python 2 and an old hickle version.") - # There is an unfortunate period of time where hickle 2.1.0 claims VERSION = int(3) - # For backward compatibility we really need to catch this. - # Actual hickle v3 files are versioned as A.B.C (e.g. 3.1.0) - elif VER_MAJOR == 3 and VER == VER_MAJOR: - if PY2: - warnings.warn("Hickle file appears to be old version (v2.1.0), attempting " - "legacy loading...") - from . import hickle_legacy2 - return hickle_legacy2.load(fileobj, path=path, safe=safe) - else: - raise RuntimeError("Cannot open file. This file was likely" - " created with Python 2 and an old hickle version.") - elif VER_MAJOR >= 3: - py_container = PyContainer() - py_container.container_type = 'hickle' - py_container = _load(py_container, h_root_group) - return py_container[0][0] - - except AssertionError: - if PY2: - warnings.warn("Hickle file is not versioned, attempting legacy loading...") - from . import hickle_legacy - return hickle_legacy.load(fileobj, safe) - else: - raise RuntimeError("Cannot open file. This file was likely" - " created with Python 2 and an old hickle version.") - finally: - # Close the file if requested. - # Closing a file twice will not cause any problems - if close_flag: - h5f.close() - -def load_dataset(h_node): - """ Load a dataset, converting into its correct python type - - Args: - h_node (h5py dataset): h5py dataset object to read - - Returns: - data: reconstructed python object from loaded data - """ - py_type = get_type(h_node) - - try: - load_fn = load_dataset_lookup(py_type) - return load_fn(h_node) - except: - raise - #raise RuntimeError("Hickle type %s not understood." % py_type) - -def _load(py_container, h_group): - """ Load a hickle file - - Recursive funnction to load hdf5 data into a PyContainer() - - Args: - py_container (PyContainer): Python container to load data into - h_group (h5 group or dataset): h5py object, group or dataset, to spider - and load all datasets. - """ - - group_dtype = h5._hl.group.Group - dataset_dtype = h5._hl.dataset.Dataset - - #either a file, group, or dataset - if isinstance(h_group, (H5FileWrapper, group_dtype)): - - py_subcontainer = PyContainer() - try: - py_subcontainer.container_type = bytes(h_group.attrs['type'][0]) - except KeyError: - raise - #py_subcontainer.container_type = '' - py_subcontainer.name = h_group.name - - if py_subcontainer.container_type == b'dict_item': - py_subcontainer.key_type = h_group.attrs['key_type'] - - if py_subcontainer.container_type not in types_not_to_sort: - h_keys = sort_keys(h_group.keys()) - else: - h_keys = h_group.keys() - - for h_name in h_keys: - h_node = h_group[h_name] - py_subcontainer = _load(py_subcontainer, h_node) - - sub_data = py_subcontainer.convert() - py_container.append(sub_data) - - else: - # must be a dataset - subdata = load_dataset(h_group) - py_container.append(subdata) - - return py_container diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/hickle_legacy.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/hickle_legacy.py deleted file mode 100644 index 61a171fde3d39304d78d1ddede9656dd7ad50940..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/hickle_legacy.py +++ /dev/null @@ -1,535 +0,0 @@ -# encoding: utf-8 -""" -# hickle_legacy.py - -Created by Danny Price 2012-05-28. - -Hickle is a HDF5 based clone of Pickle. Instead of serializing to a -pickle file, Hickle dumps to a HDF5 file. It is designed to be as similar -to pickle in usage as possible. - -## Notes - -This is a legacy handler, for hickle v1 files. -If V2 reading fails, this will be called as a fail-over. - -""" - -import os -import sys -import numpy as np -import h5py as h5 - -if sys.version_info.major == 3: - NoneType = type(None) -else: - from types import NoneType - -__version__ = "1.3.0" -__author__ = "Danny Price" - -#################### -## Error handling ## -#################### - - -class FileError(Exception): - """ An exception raised if the file is fishy""" - - def __init__(self): - return - - def __str__(self): - print("Error: cannot open file. Please pass either a filename string, a file object, " - "or a h5py.File") - - -class NoMatchError(Exception): - """ An exception raised if the object type is not understood (or supported)""" - - def __init__(self): - return - - def __str__(self): - print("Error: this type of python object cannot be converted into a hickle.") - - -class ToDoError(Exception): - """ An exception raised for non-implemented functionality""" - - def __init__(self): - return - - def __str__(self): - print("Error: this functionality hasn't been implemented yet.") - - -class H5GroupWrapper(h5.Group): - def create_dataset(self, *args, **kwargs): - kwargs['track_times'] = getattr(self, 'track_times', True) - return super(H5GroupWrapper, self).create_dataset(*args, **kwargs) - - def create_group(self, *args, **kwargs): - group = super(H5GroupWrapper, self).create_group(*args, **kwargs) - group.__class__ = H5GroupWrapper - group.track_times = getattr(self, 'track_times', True) - return group - - -class H5FileWrapper(h5.File): - def create_dataset(self, *args, **kwargs): - kwargs['track_times'] = getattr(self, 'track_times', True) - return super(H5FileWrapper, self).create_dataset(*args, **kwargs) - - def create_group(self, *args, **kwargs): - group = super(H5FileWrapper, self).create_group(*args, **kwargs) - group.__class__ = H5GroupWrapper - group.track_times = getattr(self, 'track_times', True) - return group - - -def file_opener(f, mode='r', track_times=True): - """ A file opener helper function with some error handling. - - This can open files through a file object, a h5py file, or just the filename. - """ - # Were we handed a file object or just a file name string? - if isinstance(f, file): - filename, mode = f.name, f.mode - f.close() - h5f = h5.File(filename, mode) - - elif isinstance(f, h5._hl.files.File): - h5f = f - elif isinstance(f, str): - filename = f - h5f = h5.File(filename, mode) - else: - raise FileError - - h5f.__class__ = H5FileWrapper - h5f.track_times = track_times - return h5f - - -############# -## dumpers ## -############# - -def dump_ndarray(obj, h5f, **kwargs): - """ dumps an ndarray object to h5py file""" - h5f.create_dataset('data', data=obj, **kwargs) - h5f.create_dataset('type', data=['ndarray']) - - -def dump_np_dtype(obj, h5f, **kwargs): - """ dumps an np dtype object to h5py file""" - h5f.create_dataset('data', data=obj) - h5f.create_dataset('type', data=['np_dtype']) - - -def dump_np_dtype_dict(obj, h5f, **kwargs): - """ dumps an np dtype object within a group""" - h5f.create_dataset('data', data=obj) - h5f.create_dataset('_data', data=['np_dtype']) - - -def dump_masked(obj, h5f, **kwargs): - """ dumps an ndarray object to h5py file""" - h5f.create_dataset('data', data=obj, **kwargs) - h5f.create_dataset('mask', data=obj.mask, **kwargs) - h5f.create_dataset('type', data=['masked']) - - -def dump_list(obj, h5f, **kwargs): - """ dumps a list object to h5py file""" - - # Check if there are any numpy arrays in the list - contains_numpy = any(isinstance(el, np.ndarray) for el in obj) - - if contains_numpy: - _dump_list_np(obj, h5f, **kwargs) - else: - h5f.create_dataset('data', data=obj, **kwargs) - h5f.create_dataset('type', data=['list']) - - -def _dump_list_np(obj, h5f, **kwargs): - """ Dump a list of numpy objects to file """ - - np_group = h5f.create_group('data') - h5f.create_dataset('type', data=['np_list']) - - ii = 0 - for np_item in obj: - np_group.create_dataset("%s" % ii, data=np_item, **kwargs) - ii += 1 - - -def dump_tuple(obj, h5f, **kwargs): - """ dumps a list object to h5py file""" - - # Check if there are any numpy arrays in the list - contains_numpy = any(isinstance(el, np.ndarray) for el in obj) - - if contains_numpy: - _dump_tuple_np(obj, h5f, **kwargs) - else: - h5f.create_dataset('data', data=obj, **kwargs) - h5f.create_dataset('type', data=['tuple']) - - -def _dump_tuple_np(obj, h5f, **kwargs): - """ Dump a tuple of numpy objects to file """ - - np_group = h5f.create_group('data') - h5f.create_dataset('type', data=['np_tuple']) - - ii = 0 - for np_item in obj: - np_group.create_dataset("%s" % ii, data=np_item, **kwargs) - ii += 1 - - -def dump_set(obj, h5f, **kwargs): - """ dumps a set object to h5py file""" - obj = list(obj) - h5f.create_dataset('data', data=obj, **kwargs) - h5f.create_dataset('type', data=['set']) - - -def dump_string(obj, h5f, **kwargs): - """ dumps a list object to h5py file""" - h5f.create_dataset('data', data=[obj], **kwargs) - h5f.create_dataset('type', data=['string']) - - -def dump_none(obj, h5f, **kwargs): - """ Dump None type to file """ - h5f.create_dataset('data', data=[0], **kwargs) - h5f.create_dataset('type', data=['none']) - - -def dump_unicode(obj, h5f, **kwargs): - """ dumps a list object to h5py file""" - dt = h5.special_dtype(vlen=unicode) - ll = len(obj) - dset = h5f.create_dataset('data', shape=(ll, ), dtype=dt, **kwargs) - dset[:ll] = obj - h5f.create_dataset('type', data=['unicode']) - - -def _dump_dict(dd, hgroup, **kwargs): - for key in dd: - if type(dd[key]) in (str, int, float, unicode, bool): - # Figure out type to be stored - types = {str: 'str', int: 'int', float: 'float', - unicode: 'unicode', bool: 'bool', NoneType: 'none'} - _key = types.get(type(dd[key])) - - # Store along with dtype info - if _key == 'unicode': - dd[key] = str(dd[key]) - - hgroup.create_dataset("%s" % key, data=[dd[key]], **kwargs) - hgroup.create_dataset("_%s" % key, data=[_key]) - - elif type(dd[key]) in (type(np.array([1])), type(np.ma.array([1]))): - - if hasattr(dd[key], 'mask'): - hgroup.create_dataset("_%s" % key, data=["masked"]) - hgroup.create_dataset("%s" % key, data=dd[key].data, **kwargs) - hgroup.create_dataset("_%s_mask" % key, data=dd[key].mask, **kwargs) - else: - hgroup.create_dataset("_%s" % key, data=["ndarray"]) - hgroup.create_dataset("%s" % key, data=dd[key], **kwargs) - - elif type(dd[key]) is list: - hgroup.create_dataset("%s" % key, data=dd[key], **kwargs) - hgroup.create_dataset("_%s" % key, data=["list"]) - - elif type(dd[key]) is tuple: - hgroup.create_dataset("%s" % key, data=dd[key], **kwargs) - hgroup.create_dataset("_%s" % key, data=["tuple"]) - - elif type(dd[key]) is set: - hgroup.create_dataset("%s" % key, data=list(dd[key]), **kwargs) - hgroup.create_dataset("_%s" % key, data=["set"]) - - elif isinstance(dd[key], dict): - new_group = hgroup.create_group("%s" % key) - _dump_dict(dd[key], new_group, **kwargs) - - elif type(dd[key]) is NoneType: - hgroup.create_dataset("%s" % key, data=[0], **kwargs) - hgroup.create_dataset("_%s" % key, data=["none"]) - - else: - if type(dd[key]).__module__ == np.__name__: - #print type(dd[key]) - hgroup.create_dataset("%s" % key, data=dd[key]) - hgroup.create_dataset("_%s" % key, data=["np_dtype"]) - #new_group = hgroup.create_group("%s" % key) - #dump_np_dtype_dict(dd[key], new_group) - else: - raise NoMatchError - - -def dump_dict(obj, h5f='', **kwargs): - """ dumps a dictionary to h5py file """ - h5f.create_dataset('type', data=['dict']) - hgroup = h5f.create_group('data') - _dump_dict(obj, hgroup, **kwargs) - - -def no_match(obj, h5f, *args, **kwargs): - """ If no match is made, raise an exception """ - try: - import dill as cPickle - except ImportError: - import cPickle - - pickled_obj = cPickle.dumps(obj) - h5f.create_dataset('type', data=['pickle']) - h5f.create_dataset('data', data=[pickled_obj]) - - print("Warning: %s type not understood, data have been serialized" % type(obj)) - #raise NoMatchError - - -def dumper_lookup(obj): - """ What type of object are we trying to pickle? - - This is a python dictionary based equivalent of a case statement. - It returns the correct helper function for a given data type. - """ - t = type(obj) - - types = { - list: dump_list, - tuple: dump_tuple, - set: dump_set, - dict: dump_dict, - str: dump_string, - unicode: dump_unicode, - NoneType: dump_none, - np.ndarray: dump_ndarray, - np.ma.core.MaskedArray: dump_masked, - np.float16: dump_np_dtype, - np.float32: dump_np_dtype, - np.float64: dump_np_dtype, - np.int8: dump_np_dtype, - np.int16: dump_np_dtype, - np.int32: dump_np_dtype, - np.int64: dump_np_dtype, - np.uint8: dump_np_dtype, - np.uint16: dump_np_dtype, - np.uint32: dump_np_dtype, - np.uint64: dump_np_dtype, - np.complex64: dump_np_dtype, - np.complex128: dump_np_dtype, - } - - match = types.get(t, no_match) - return match - - -def dump(obj, file, mode='w', track_times=True, **kwargs): - """ Write a pickled representation of obj to the open file object file. - - Parameters - ---------- - obj: object - python object o store in a Hickle - file: file object, filename string, or h5py.File object - file in which to store the object. A h5py.File or a filename is also acceptable. - mode: string - optional argument, 'r' (read only), 'w' (write) or 'a' (append). Ignored if file - is a file object. - compression: str - optional argument. Applies compression to dataset. Options: None, gzip, lzf (+ szip, - if installed) - track_times: bool - optional argument. If set to False, repeated hickling will produce identical files. - """ - - try: - # See what kind of object to dump - dumper = dumper_lookup(obj) - # Open the file - h5f = file_opener(file, mode, track_times) - print("dumping %s to file %s" % (type(obj), repr(h5f))) - dumper(obj, h5f, **kwargs) - h5f.close() - except NoMatchError: - fname = h5f.filename - h5f.close() - try: - os.remove(fname) - except: - print("Warning: dump failed. Could not remove %s" % fname) - finally: - raise NoMatchError - - -############# -## loaders ## -############# - -def load(file, safe=True): - """ Load a hickle file and reconstruct a python object - - Parameters - ---------- - file: file object, h5py.File, or filename string - - safe (bool): Disable automatic depickling of arbitrary python objects. - DO NOT set this to False unless the file is from a trusted source. - (see http://www.cs.jhu.edu/~s/musings/pickle.html for an explanation) - """ - - try: - h5f = file_opener(file) - dtype = h5f["type"][0] - - if dtype == 'dict': - group = h5f["data"] - data = load_dict(group) - elif dtype == 'pickle': - data = load_pickle(h5f, safe) - elif dtype == 'np_list': - group = h5f["data"] - data = load_np_list(group) - elif dtype == 'np_tuple': - group = h5f["data"] - data = load_np_tuple(group) - elif dtype == 'masked': - data = np.ma.array(h5f["data"][:], mask=h5f["mask"][:]) - elif dtype == 'none': - data = None - else: - if dtype in ('string', 'unicode'): - data = h5f["data"][0] - else: - try: - data = h5f["data"][:] - except ValueError: - data = h5f["data"] - types = { - 'list': list, - 'set': set, - 'unicode': unicode, - 'string': str, - 'ndarray': load_ndarray, - 'np_dtype': load_np_dtype - } - - mod = types.get(dtype, no_match) - data = mod(data) - finally: - if 'h5f' in locals(): - h5f.close() - return data - - -def load_pickle(h5f, safe=True): - """ Deserialize and load a pickled object within a hickle file - - WARNING: Pickle has - - Parameters - ---------- - h5f: h5py.File object - - safe (bool): Disable automatic depickling of arbitrary python objects. - DO NOT set this to False unless the file is from a trusted source. - (see http://www.cs.jhu.edu/~s/musings/pickle.html for an explanation) - """ - - if not safe: - try: - import dill as cPickle - except ImportError: - import cPickle - - data = h5f["data"][:] - data = cPickle.loads(data[0]) - return data - else: - print("\nWarning: Object is of an unknown type, and has not been loaded") - print(" for security reasons (it could be malicious code). If") - print(" you wish to continue, manually set safe=False\n") - - -def load_np_list(group): - """ load a numpy list """ - np_list = [] - for key in sorted(group.keys()): - data = group[key][:] - np_list.append(data) - return np_list - - -def load_np_tuple(group): - """ load a tuple containing numpy arrays """ - return tuple(load_np_list(group)) - - -def load_ndarray(arr): - """ Load a numpy array """ - # Nothing to be done! - return arr - - -def load_np_dtype(arr): - """ Load a numpy array """ - # Just return first value - return arr.value - - -def load_dict(group): - """ Load dictionary """ - - dd = {} - for key in group.keys(): - if isinstance(group[key], h5._hl.group.Group): - new_group = group[key] - dd[key] = load_dict(new_group) - elif not key.startswith("_"): - _key = "_%s" % key - - if group[_key][0] == 'np_dtype': - dd[key] = group[key].value - elif group[_key][0] in ('str', 'int', 'float', 'unicode', 'bool'): - dd[key] = group[key][0] - elif group[_key][0] == 'masked': - key_ma = "_%s_mask" % key - dd[key] = np.ma.array(group[key][:], mask=group[key_ma]) - else: - dd[key] = group[key][:] - - # Convert numpy constructs back to string - dtype = group[_key][0] - types = {'str': str, 'int': int, 'float': float, - 'unicode': unicode, 'bool': bool, 'list': list, 'none' : NoneType} - try: - mod = types.get(dtype) - if dtype == 'none': - dd[key] = None - else: - dd[key] = mod(dd[key]) - except: - pass - return dd - - -def load_large(file): - """ Load a large hickle file (returns the h5py object not the data) - - Parameters - ---------- - file: file object, h5py.File, or filename string - """ - - h5f = file_opener(file) - return h5f diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/hickle_legacy2.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/hickle_legacy2.py deleted file mode 100644 index 4d018fde9a161713213b00190267439257cb876d..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/hickle_legacy2.py +++ /dev/null @@ -1,672 +0,0 @@ -# encoding: utf-8 -""" -# hickle_legacy2.py - -Created by Danny Price 2016-02-03. - -This is a legacy handler, for hickle v2 files. -If V3 reading fails, this will be called as a fail-over. - -""" - -import os -import numpy as np -import h5py as h5 -import re - -try: - from exceptions import Exception - from types import NoneType -except ImportError: - pass # above imports will fail in python3 - -import warnings -__version__ = "2.0.4" -__author__ = "Danny Price" - - -################## -# Error handling # -################## - -class FileError(Exception): - """ An exception raised if the file is fishy """ - def __init__(self): - return - - def __str__(self): - return ("Cannot open file. Please pass either a filename " - "string, a file object, or a h5py.File") - - -class ClosedFileError(Exception): - """ An exception raised if the file is fishy """ - def __init__(self): - return - - def __str__(self): - return ("HDF5 file has been closed. Please pass either " - "a filename string, a file object, or an open h5py.File") - - -class NoMatchError(Exception): - """ An exception raised if the object type is not understood (or - supported)""" - def __init__(self): - return - - def __str__(self): - return ("Error: this type of python object cannot be converted into a " - "hickle.") - - -class ToDoError(Exception): - """ An exception raised for non-implemented functionality""" - def __init__(self): - return - - def __str__(self): - return "Error: this functionality hasn't been implemented yet." - - -###################### -# H5PY file wrappers # -###################### - -class H5GroupWrapper(h5.Group): - """ Group wrapper that provides a track_times kwarg. - - track_times is a boolean flag that can be set to False, so that two - files created at different times will have identical MD5 hashes. - """ - def create_dataset(self, *args, **kwargs): - kwargs['track_times'] = getattr(self, 'track_times', True) - return super(H5GroupWrapper, self).create_dataset(*args, **kwargs) - - def create_group(self, *args, **kwargs): - group = super(H5GroupWrapper, self).create_group(*args, **kwargs) - group.__class__ = H5GroupWrapper - group.track_times = getattr(self, 'track_times', True) - return group - - -class H5FileWrapper(h5.File): - """ Wrapper for h5py File that provides a track_times kwarg. - - track_times is a boolean flag that can be set to False, so that two - files created at different times will have identical MD5 hashes. - """ - def create_dataset(self, *args, **kwargs): - kwargs['track_times'] = getattr(self, 'track_times', True) - return super(H5FileWrapper, self).create_dataset(*args, **kwargs) - - def create_group(self, *args, **kwargs): - group = super(H5FileWrapper, self).create_group(*args, **kwargs) - group.__class__ = H5GroupWrapper - group.track_times = getattr(self, 'track_times', True) - return group - - -def file_opener(f, mode='r', track_times=True): - """ A file opener helper function with some error handling. This can open - files through a file object, a h5py file, or just the filename. - - Args: - f (file, h5py.File, or string): File-identifier, e.g. filename or file object. - mode (str): File open mode. Only required if opening by filename string. - track_times (bool): Track time in HDF5; turn off if you want hickling at - different times to produce identical files (e.g. for MD5 hash check). - - """ - # Were we handed a file object or just a file name string? - if isinstance(f, file): - filename, mode = f.name, f.mode - f.close() - h5f = h5.File(filename, mode) - elif isinstance(f, str) or isinstance(f, unicode): - filename = f - h5f = h5.File(filename, mode) - elif isinstance(f, H5FileWrapper) or isinstance(f, h5._hl.files.File): - try: - filename = f.filename - except ValueError: - raise ClosedFileError() - h5f = f - else: - print(type(f)) - raise FileError - - h5f.__class__ = H5FileWrapper - h5f.track_times = track_times - return h5f - - -########### -# DUMPERS # -########### - -def check_is_iterable(py_obj): - """ Check whether a python object is iterable. - - Note: this treats unicode and string as NON ITERABLE - - Args: - py_obj: python object to test - - Returns: - iter_ok (bool): True if item is iterable, False is item is not - """ - if type(py_obj) in (str, unicode): - return False - try: - iter(py_obj) - return True - except TypeError: - return False - - -def check_iterable_item_type(iter_obj): - """ Check if all items within an iterable are the same type. - - Args: - iter_obj: iterable object - - Returns: - iter_type: type of item contained within the iterable. If - the iterable has many types, a boolean False is returned instead. - - References: - http://stackoverflow.com/questions/13252333/python-check-if-all-elements-of-a-list-are-the-same-type - """ - iseq = iter(iter_obj) - first_type = type(next(iseq)) - return first_type if all((type(x) is first_type) for x in iseq) else False - - -def check_is_numpy_array(py_obj): - """ Check if a python object is a numpy array (masked or regular) - - Args: - py_obj: python object to check whether it is a numpy array - - Returns - is_numpy (bool): Returns True if it is a numpy array, else False if it isn't - """ - - is_numpy = type(py_obj) in (type(np.array([1])), type(np.ma.array([1]))) - - return is_numpy - - -def _dump(py_obj, h_group, call_id=0, **kwargs): - """ Dump a python object to a group within a HDF5 file. - - This function is called recursively by the main dump() function. - - Args: - py_obj: python object to dump. - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - - dumpable_dtypes = set([bool, int, float, long, complex, str, unicode]) - - # Firstly, check if item is a numpy array. If so, just dump it. - if check_is_numpy_array(py_obj): - create_hkl_dataset(py_obj, h_group, call_id, **kwargs) - - # next, check if item is iterable - elif check_is_iterable(py_obj): - item_type = check_iterable_item_type(py_obj) - - # item_type == False implies multiple types. Create a dataset - if item_type is False: - h_subgroup = create_hkl_group(py_obj, h_group, call_id) - for ii, py_subobj in enumerate(py_obj): - _dump(py_subobj, h_subgroup, call_id=ii, **kwargs) - - # otherwise, subitems have same type. Check if subtype is an iterable - # (e.g. list of lists), or not (e.g. list of ints, which should be treated - # as a single dataset). - else: - if item_type in dumpable_dtypes: - create_hkl_dataset(py_obj, h_group, call_id, **kwargs) - else: - h_subgroup = create_hkl_group(py_obj, h_group, call_id) - for ii, py_subobj in enumerate(py_obj): - #print py_subobj, h_subgroup, ii - _dump(py_subobj, h_subgroup, call_id=ii, **kwargs) - - # item is not iterable, so create a dataset for it - else: - create_hkl_dataset(py_obj, h_group, call_id, **kwargs) - - -def dump(py_obj, file_obj, mode='w', track_times=True, path='/', **kwargs): - """ Write a pickled representation of obj to the open file object file. - - Args: - obj (object): python object o store in a Hickle - file: file object, filename string, or h5py.File object - file in which to store the object. A h5py.File or a filename is also - acceptable. - mode (str): optional argument, 'r' (read only), 'w' (write) or 'a' (append). - Ignored if file is a file object. - compression (str): optional argument. Applies compression to dataset. Options: None, gzip, - lzf (+ szip, if installed) - track_times (bool): optional argument. If set to False, repeated hickling will produce - identical files. - path (str): path within hdf5 file to save data to. Defaults to root / - """ - - try: - # Open the file - h5f = file_opener(file_obj, mode, track_times) - h5f.attrs["CLASS"] = 'hickle' - h5f.attrs["VERSION"] = 2 - h5f.attrs["type"] = ['hickle'] - - h_root_group = h5f.get(path) - - if h_root_group is None: - h_root_group = h5f.create_group(path) - h_root_group.attrs["type"] = ['hickle'] - - _dump(py_obj, h_root_group, **kwargs) - h5f.close() - except NoMatchError: - fname = h5f.filename - h5f.close() - try: - os.remove(fname) - except OSError: - warnings.warn("Dump failed. Could not remove %s" % fname) - finally: - raise NoMatchError - - -def create_dataset_lookup(py_obj): - """ What type of object are we trying to pickle? This is a python - dictionary based equivalent of a case statement. It returns the correct - helper function for a given data type. - - Args: - py_obj: python object to look-up what function to use to dump to disk - - Returns: - match: function that should be used to dump data to a new dataset - """ - t = type(py_obj) - - types = { - dict: create_dict_dataset, - list: create_listlike_dataset, - tuple: create_listlike_dataset, - set: create_listlike_dataset, - str: create_stringlike_dataset, - unicode: create_stringlike_dataset, - int: create_python_dtype_dataset, - float: create_python_dtype_dataset, - long: create_python_dtype_dataset, - bool: create_python_dtype_dataset, - complex: create_python_dtype_dataset, - NoneType: create_none_dataset, - np.ndarray: create_np_array_dataset, - np.ma.core.MaskedArray: create_np_array_dataset, - np.float16: create_np_dtype_dataset, - np.float32: create_np_dtype_dataset, - np.float64: create_np_dtype_dataset, - np.int8: create_np_dtype_dataset, - np.int16: create_np_dtype_dataset, - np.int32: create_np_dtype_dataset, - np.int64: create_np_dtype_dataset, - np.uint8: create_np_dtype_dataset, - np.uint16: create_np_dtype_dataset, - np.uint32: create_np_dtype_dataset, - np.uint64: create_np_dtype_dataset, - np.complex64: create_np_dtype_dataset, - np.complex128: create_np_dtype_dataset - } - - match = types.get(t, no_match) - return match - - -def create_hkl_dataset(py_obj, h_group, call_id=0, **kwargs): - """ Create a dataset within the hickle HDF5 file - - Args: - py_obj: python object to dump. - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - - """ - #lookup dataset creator type based on python object type - create_dataset = create_dataset_lookup(py_obj) - - # do the creation - create_dataset(py_obj, h_group, call_id, **kwargs) - - -def create_hkl_group(py_obj, h_group, call_id=0): - """ Create a new group within the hickle file - - Args: - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - - """ - h_subgroup = h_group.create_group('data_%i' % call_id) - h_subgroup.attrs["type"] = [str(type(py_obj))] - return h_subgroup - - -def create_listlike_dataset(py_obj, h_group, call_id=0, **kwargs): - """ Dumper for list, set, tuple - - Args: - py_obj: python object to dump; should be list-like - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - dtype = str(type(py_obj)) - obj = list(py_obj) - d = h_group.create_dataset('data_%i' % call_id, data=obj, **kwargs) - d.attrs["type"] = [dtype] - - -def create_np_dtype_dataset(py_obj, h_group, call_id=0, **kwargs): - """ dumps an np dtype object to h5py file - - Args: - py_obj: python object to dump; should be a numpy scalar, e.g. np.float16(1) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - d = h_group.create_dataset('data_%i' % call_id, data=py_obj, **kwargs) - d.attrs["type"] = ['np_dtype'] - d.attrs["np_dtype"] = str(d.dtype) - - -def create_python_dtype_dataset(py_obj, h_group, call_id=0, **kwargs): - """ dumps a python dtype object to h5py file - - Args: - py_obj: python object to dump; should be a python type (int, float, bool etc) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - d = h_group.create_dataset('data_%i' % call_id, data=py_obj, - dtype=type(py_obj), **kwargs) - d.attrs["type"] = ['python_dtype'] - d.attrs['python_subdtype'] = str(type(py_obj)) - - -def create_dict_dataset(py_obj, h_group, call_id=0, **kwargs): - """ Creates a data group for each key in dictionary - - Args: - py_obj: python object to dump; should be dictionary - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - h_dictgroup = h_group.create_group('data_%i' % call_id) - h_dictgroup.attrs["type"] = ['dict'] - for key, py_subobj in py_obj.items(): - h_subgroup = h_dictgroup.create_group(key) - h_subgroup.attrs["type"] = ['dict_item'] - _dump(py_subobj, h_subgroup, call_id=0, **kwargs) - - -def create_np_array_dataset(py_obj, h_group, call_id=0, **kwargs): - """ dumps an ndarray object to h5py file - - Args: - py_obj: python object to dump; should be a numpy array or np.ma.array (masked) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - if isinstance(py_obj, type(np.ma.array([1]))): - d = h_group.create_dataset('data_%i' % call_id, data=py_obj, **kwargs) - #m = h_group.create_dataset('mask_%i' % call_id, data=py_obj.mask, **kwargs) - m = h_group.create_dataset('data_%i_mask' % call_id, data=py_obj.mask, **kwargs) - d.attrs["type"] = ['ndarray_masked_data'] - m.attrs["type"] = ['ndarray_masked_mask'] - else: - d = h_group.create_dataset('data_%i' % call_id, data=py_obj, **kwargs) - d.attrs["type"] = ['ndarray'] - - -def create_stringlike_dataset(py_obj, h_group, call_id=0, **kwargs): - """ dumps a list object to h5py file - - Args: - py_obj: python object to dump; should be string-like (unicode or string) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - if isinstance(py_obj, str): - d = h_group.create_dataset('data_%i' % call_id, data=[py_obj], **kwargs) - d.attrs["type"] = ['string'] - else: - dt = h5.special_dtype(vlen=unicode) - dset = h_group.create_dataset('data_%i' % call_id, shape=(1, ), dtype=dt, **kwargs) - dset[0] = py_obj - dset.attrs['type'] = ['unicode'] - - -def create_none_dataset(py_obj, h_group, call_id=0, **kwargs): - """ Dump None type to file - - Args: - py_obj: python object to dump; must be None object - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - d = h_group.create_dataset('data_%i' % call_id, data=[0], **kwargs) - d.attrs["type"] = ['none'] - - -def no_match(py_obj, h_group, call_id=0, **kwargs): - """ If no match is made, raise an exception - - Args: - py_obj: python object to dump; default if item is not matched. - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - try: - import dill as cPickle - except ImportError: - import cPickle - - pickled_obj = cPickle.dumps(py_obj) - d = h_group.create_dataset('data_%i' % call_id, data=[pickled_obj]) - d.attrs["type"] = ['pickle'] - - warnings.warn("%s type not understood, data have been " - "serialized" % type(py_obj)) - - -############# -## LOADERS ## -############# - -class PyContainer(list): - """ A group-like object into which to load datasets. - - In order to build up a tree-like structure, we need to be able - to load datasets into a container with an append() method. - Python tuples and sets do not allow this. This class provides - a list-like object that be converted into a list, tuple, set or dict. - """ - def __init__(self): - super(PyContainer, self).__init__() - self.container_type = None - self.name = None - - def convert(self): - """ Convert from PyContainer to python core data type. - - Returns: self, either as a list, tuple, set or dict - """ - if self.container_type == "<type 'list'>": - return list(self) - if self.container_type == "<type 'tuple'>": - return tuple(self) - if self.container_type == "<type 'set'>": - return set(self) - if self.container_type == "dict": - keys = [str(item.name.split('/')[-1]) for item in self] - items = [item[0] for item in self] - return dict(zip(keys, items)) - else: - return self - - -def load(fileobj, path='/', safe=True): - """ Load a hickle file and reconstruct a python object - - Args: - fileobj: file object, h5py.File, or filename string - safe (bool): Disable automatic depickling of arbitrary python objects. - DO NOT set this to False unless the file is from a trusted source. - (see http://www.cs.jhu.edu/~s/musings/pickle.html for an explanation) - - path (str): path within hdf5 file to save data to. Defaults to root / - """ - - try: - h5f = file_opener(fileobj) - h_root_group = h5f.get(path) - - try: - assert 'CLASS' in h5f.attrs.keys() - assert 'VERSION' in h5f.attrs.keys() - py_container = PyContainer() - py_container.container_type = 'hickle' - py_container = _load(py_container, h_root_group) - return py_container[0][0] - except AssertionError: - import hickle_legacy - return hickle_legacy.load(fileobj, safe) - finally: - if 'h5f' in locals(): - h5f.close() - - -def load_dataset(h_node): - """ Load a dataset, converting into its correct python type - - Args: - h_node (h5py dataset): h5py dataset object to read - - Returns: - data: reconstructed python object from loaded data - """ - py_type = h_node.attrs["type"][0] - - if h_node.shape == (): - data = h_node.value - else: - data = h_node[:] - - if py_type == "<type 'list'>": - #print self.name - return list(data) - elif py_type == "<type 'tuple'>": - return tuple(data) - elif py_type == "<type 'set'>": - return set(data) - elif py_type == "np_dtype": - subtype = h_node.attrs["np_dtype"] - data = np.array(data, dtype=subtype) - return data - elif py_type == 'ndarray': - return np.array(data) - elif py_type == 'ndarray_masked_data': - try: - mask_path = h_node.name + "_mask" - h_root = h_node.parent - mask = h_root.get(mask_path)[:] - except IndexError: - mask = h_root.get(mask_path) - except ValueError: - mask = h_root.get(mask_path) - data = np.ma.array(data, mask=mask) - return data - elif py_type == 'python_dtype': - subtype = h_node.attrs["python_subdtype"] - type_dict = { - "<type 'int'>": int, - "<type 'float'>": float, - "<type 'long'>": long, - "<type 'bool'>": bool, - "<type 'complex'>": complex - } - tcast = type_dict.get(subtype) - return tcast(data) - elif py_type == 'string': - return str(data[0]) - elif py_type == 'unicode': - return unicode(data[0]) - elif py_type == 'none': - return None - else: - print(h_node.name, py_type, h_node.attrs.keys()) - return data - - -def sort_keys(key_list): - """ Take a list of strings and sort it by integer value within string - - Args: - key_list (list): List of keys - - Returns: - key_list_sorted (list): List of keys, sorted by integer - """ - to_int = lambda x: int(re.search('\d+', x).group(0)) - keys_by_int = sorted([(to_int(key), key) for key in key_list]) - return [ii[1] for ii in keys_by_int] - - -def _load(py_container, h_group): - """ Load a hickle file - - Recursive funnction to load hdf5 data into a PyContainer() - - Args: - py_container (PyContainer): Python container to load data into - h_group (h5 group or dataset): h5py object, group or dataset, to spider - and load all datasets. - """ - - group_dtype = h5._hl.group.Group - dataset_dtype = h5._hl.dataset.Dataset - - #either a file, group, or dataset - if isinstance(h_group, H5FileWrapper) or isinstance(h_group, group_dtype): - py_subcontainer = PyContainer() - py_subcontainer.container_type = h_group.attrs['type'][0] - py_subcontainer.name = h_group.name - - if py_subcontainer.container_type != 'dict': - h_keys = sort_keys(h_group.keys()) - else: - h_keys = h_group.keys() - - for h_name in h_keys: - h_node = h_group[h_name] - py_subcontainer = _load(py_subcontainer, h_node) - - sub_data = py_subcontainer.convert() - py_container.append(sub_data) - - else: - # must be a dataset - subdata = load_dataset(h_group) - py_container.append(subdata) - - #print h_group.name, py_container - return py_container diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/__init__.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/__init__.py deleted file mode 100644 index 3be6bd298581fb3086bb5a261de72a56970faddf..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from __future__ import absolute_import \ No newline at end of file diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_astropy.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_astropy.py deleted file mode 100644 index dd8efce655c2223262b42868cbb1d9ba5c580acb..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_astropy.py +++ /dev/null @@ -1,237 +0,0 @@ -import numpy as np -from astropy.units import Quantity -from astropy.coordinates import Angle, SkyCoord -from astropy.constants import Constant, EMConstant -from astropy.table import Table -from astropy.time import Time - -from hickle.helpers import get_type_and_data -import six - -def create_astropy_quantity(py_obj, h_group, call_id=0, **kwargs): - """ dumps an astropy quantity - - Args: - py_obj: python object to dump; should be a python type (int, float, bool etc) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - # kwarg compression etc does not work on scalars - d = h_group.create_dataset('data_%i' % call_id, data=py_obj.value, - dtype='float64') #, **kwargs) - d.attrs["type"] = [b'astropy_quantity'] - if six.PY3: - unit = bytes(str(py_obj.unit), 'ascii') - else: - unit = str(py_obj.unit) - d.attrs['unit'] = [unit] - -def create_astropy_angle(py_obj, h_group, call_id=0, **kwargs): - """ dumps an astropy quantity - - Args: - py_obj: python object to dump; should be a python type (int, float, bool etc) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - # kwarg compression etc does not work on scalars - d = h_group.create_dataset('data_%i' % call_id, data=py_obj.value, - dtype='float64') #, **kwargs) - d.attrs["type"] = [b'astropy_angle'] - if six.PY3: - unit = str(py_obj.unit).encode('ascii') - else: - unit = str(py_obj.unit) - d.attrs['unit'] = [unit] - -def create_astropy_skycoord(py_obj, h_group, call_id=0, **kwargs): - """ dumps an astropy quantity - - Args: - py_obj: python object to dump; should be a python type (int, float, bool etc) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - # kwarg compression etc does not work on scalars - lat = py_obj.data.lat.value - lon = py_obj.data.lon.value - dd = np.column_stack((lon, lat)) - - d = h_group.create_dataset('data_%i' % call_id, data=dd, - dtype='float64') #, **kwargs) - d.attrs["type"] = [b'astropy_skycoord'] - if six.PY3: - lon_unit = str(py_obj.data.lon.unit).encode('ascii') - lat_unit = str(py_obj.data.lat.unit).encode('ascii') - else: - lon_unit = str(py_obj.data.lon.unit) - lat_unit = str(py_obj.data.lat.unit) - d.attrs['lon_unit'] = [lon_unit] - d.attrs['lat_unit'] = [lat_unit] - -def create_astropy_time(py_obj, h_group, call_id=0, **kwargs): - """ dumps an astropy Time object - - Args: - py_obj: python object to dump; should be a python type (int, float, bool etc) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - - # kwarg compression etc does not work on scalars - data = py_obj.value - dtype = str(py_obj.value.dtype) - - # Need to catch string times - if '<U' in dtype: - dtype = dtype.replace('<U', '|S') - print(dtype) - data = [] - for item in py_obj.value: - data.append(str(item).encode('ascii')) - - d = h_group.create_dataset('data_%i' % call_id, data=data, dtype=dtype) #, **kwargs) - d.attrs["type"] = [b'astropy_time'] - if six.PY2: - fmt = str(py_obj.format) - scale = str(py_obj.scale) - else: - fmt = str(py_obj.format).encode('ascii') - scale = str(py_obj.scale).encode('ascii') - d.attrs['format'] = [fmt] - d.attrs['scale'] = [scale] - -def create_astropy_constant(py_obj, h_group, call_id=0, **kwargs): - """ dumps an astropy constant - - Args: - py_obj: python object to dump; should be a python type (int, float, bool etc) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - # kwarg compression etc does not work on scalars - d = h_group.create_dataset('data_%i' % call_id, data=py_obj.value, - dtype='float64') #, **kwargs) - d.attrs["type"] = [b'astropy_constant'] - d.attrs["unit"] = [str(py_obj.unit)] - d.attrs["abbrev"] = [str(py_obj.abbrev)] - d.attrs["name"] = [str(py_obj.name)] - d.attrs["reference"] = [str(py_obj.reference)] - d.attrs["uncertainty"] = [py_obj.uncertainty] - - if py_obj.system: - d.attrs["system"] = [py_obj.system] - - -def create_astropy_table(py_obj, h_group, call_id=0, **kwargs): - """ Dump an astropy Table - - Args: - py_obj: python object to dump; should be a python type (int, float, bool etc) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - data = py_obj.as_array() - d = h_group.create_dataset('data_%i' % call_id, data=data, dtype=data.dtype, **kwargs) - d.attrs['type'] = [b'astropy_table'] - - if six.PY3: - colnames = [bytes(cn, 'ascii') for cn in py_obj.colnames] - else: - colnames = py_obj.colnames - d.attrs['colnames'] = colnames - for key, value in py_obj.meta.items(): - d.attrs[key] = value - - -def load_astropy_quantity_dataset(h_node): - py_type, data = get_type_and_data(h_node) - unit = h_node.attrs["unit"][0] - q = Quantity(data, unit) - return q - -def load_astropy_time_dataset(h_node): - py_type, data = get_type_and_data(h_node) - if six.PY3: - fmt = h_node.attrs["format"][0].decode('ascii') - scale = h_node.attrs["scale"][0].decode('ascii') - else: - fmt = h_node.attrs["format"][0] - scale = h_node.attrs["scale"][0] - q = Time(data, format=fmt, scale=scale) - return q - -def load_astropy_angle_dataset(h_node): - py_type, data = get_type_and_data(h_node) - unit = h_node.attrs["unit"][0] - q = Angle(data, unit) - return q - -def load_astropy_skycoord_dataset(h_node): - py_type, data = get_type_and_data(h_node) - lon_unit = h_node.attrs["lon_unit"][0] - lat_unit = h_node.attrs["lat_unit"][0] - q = SkyCoord(data[:,0], data[:, 1], unit=(lon_unit, lat_unit)) - return q - -def load_astropy_constant_dataset(h_node): - py_type, data = get_type_and_data(h_node) - unit = h_node.attrs["unit"][0] - abbrev = h_node.attrs["abbrev"][0] - name = h_node.attrs["name"][0] - ref = h_node.attrs["reference"][0] - unc = h_node.attrs["uncertainty"][0] - - system = None - if "system" in h_node.attrs.keys(): - system = h_node.attrs["system"][0] - - c = Constant(abbrev, name, data, unit, unc, ref, system) - return c - -def load_astropy_table(h_node): - py_type, data = get_type_and_data(h_node) - metadata = dict(h_node.attrs.items()) - metadata.pop('type') - metadata.pop('colnames') - - if six.PY3: - colnames = [cn.decode('ascii') for cn in h_node.attrs["colnames"]] - else: - colnames = h_node.attrs["colnames"] - - t = Table(data, names=colnames, meta=metadata) - return t - -def check_is_astropy_table(py_obj): - return isinstance(py_obj, Table) - -def check_is_astropy_quantity_array(py_obj): - if isinstance(py_obj, Quantity) or isinstance(py_obj, Time) or \ - isinstance(py_obj, Angle) or isinstance(py_obj, SkyCoord): - if py_obj.isscalar: - return False - else: - return True - else: - return False - - -##################### -# Lookup dictionary # -##################### - -class_register = [ - [Quantity, b'astropy_quantity', create_astropy_quantity, load_astropy_quantity_dataset, - True, check_is_astropy_quantity_array], - [Time, b'astropy_time', create_astropy_time, load_astropy_time_dataset, - True, check_is_astropy_quantity_array], - [Angle, b'astropy_angle', create_astropy_angle, load_astropy_angle_dataset, - True, check_is_astropy_quantity_array], - [SkyCoord, b'astropy_skycoord', create_astropy_skycoord, load_astropy_skycoord_dataset, - True, check_is_astropy_quantity_array], - [Constant, b'astropy_constant', create_astropy_constant, load_astropy_constant_dataset, - True, None], - [Table, b'astropy_table', create_astropy_table, load_astropy_table, - True, check_is_astropy_table] -] diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_numpy.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_numpy.py deleted file mode 100644 index 7a31b12e235b07cccb6b1f0045ca9ccbfb874454..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_numpy.py +++ /dev/null @@ -1,145 +0,0 @@ -# encoding: utf-8 -""" -# load_numpy.py - -Utilities and dump / load handlers for handling numpy and scipy arrays - -""" -import six -import numpy as np - - -from hickle.helpers import get_type_and_data - - -def check_is_numpy_array(py_obj): - """ Check if a python object is a numpy array (masked or regular) - - Args: - py_obj: python object to check whether it is a numpy array - - Returns - is_numpy (bool): Returns True if it is a numpy array, else False if it isn't - """ - - is_numpy = type(py_obj) in (type(np.array([1])), type(np.ma.array([1]))) - - return is_numpy - - -def create_np_scalar_dataset(py_obj, h_group, call_id=0, **kwargs): - """ dumps an np dtype object to h5py file - - Args: - py_obj: python object to dump; should be a numpy scalar, e.g. np.float16(1) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - - # DO NOT PASS KWARGS TO SCALAR DATASETS! - d = h_group.create_dataset('data_%i' % call_id, data=py_obj) # **kwargs) - d.attrs["type"] = [b'np_scalar'] - - if six.PY2: - d.attrs["np_dtype"] = str(d.dtype) - else: - d.attrs["np_dtype"] = bytes(str(d.dtype), 'ascii') - - -def create_np_dtype(py_obj, h_group, call_id=0, **kwargs): - """ dumps an np dtype object to h5py file - - Args: - py_obj: python object to dump; should be a numpy scalar, e.g. np.float16(1) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - d = h_group.create_dataset('data_%i' % call_id, data=[str(py_obj)]) - d.attrs["type"] = [b'np_dtype'] - - -def create_np_array_dataset(py_obj, h_group, call_id=0, **kwargs): - """ dumps an ndarray object to h5py file - - Args: - py_obj: python object to dump; should be a numpy array or np.ma.array (masked) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - if isinstance(py_obj, type(np.ma.array([1]))): - d = h_group.create_dataset('data_%i' % call_id, data=py_obj, **kwargs) - #m = h_group.create_dataset('mask_%i' % call_id, data=py_obj.mask, **kwargs) - m = h_group.create_dataset('data_%i_mask' % call_id, data=py_obj.mask, **kwargs) - d.attrs["type"] = [b'ndarray_masked_data'] - m.attrs["type"] = [b'ndarray_masked_mask'] - else: - d = h_group.create_dataset('data_%i' % call_id, data=py_obj, **kwargs) - d.attrs["type"] = [b'ndarray'] - - - - -####################### -## Lookup dictionary ## -####################### - -types_dict = { - np.ndarray: create_np_array_dataset, - np.ma.core.MaskedArray: create_np_array_dataset, - np.float16: create_np_scalar_dataset, - np.float32: create_np_scalar_dataset, - np.float64: create_np_scalar_dataset, - np.int8: create_np_scalar_dataset, - np.int16: create_np_scalar_dataset, - np.int32: create_np_scalar_dataset, - np.int64: create_np_scalar_dataset, - np.uint8: create_np_scalar_dataset, - np.uint16: create_np_scalar_dataset, - np.uint32: create_np_scalar_dataset, - np.uint64: create_np_scalar_dataset, - np.complex64: create_np_scalar_dataset, - np.complex128: create_np_scalar_dataset, - np.dtype: create_np_dtype -} - -def load_np_dtype_dataset(h_node): - py_type, data = get_type_and_data(h_node) - data = np.dtype(data[0]) - return data - -def load_np_scalar_dataset(h_node): - py_type, data = get_type_and_data(h_node) - subtype = h_node.attrs["np_dtype"] - data = np.array([data], dtype=subtype)[0] - return data - -def load_ndarray_dataset(h_node): - py_type, data = get_type_and_data(h_node) - return np.array(data, copy=False) - -def load_ndarray_masked_dataset(h_node): - py_type, data = get_type_and_data(h_node) - try: - mask_path = h_node.name + "_mask" - h_root = h_node.parent - mask = h_root.get(mask_path)[:] - except IndexError: - mask = h_root.get(mask_path) - except ValueError: - mask = h_root.get(mask_path) - data = np.ma.array(data, mask=mask) - return data - -def load_nothing(h_hode): - pass - -hkl_types_dict = { - b"np_dtype" : load_np_dtype_dataset, - b"np_scalar" : load_np_scalar_dataset, - b"ndarray" : load_ndarray_dataset, - b"numpy.ndarray" : load_ndarray_dataset, - b"ndarray_masked_data" : load_ndarray_masked_dataset, - b"ndarray_masked_mask" : load_nothing # Loaded autormatically -} - - diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_pandas.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_pandas.py deleted file mode 100644 index 0b5185533dafe9d2f8b2c45405967d7489ce7caf..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_pandas.py +++ /dev/null @@ -1,4 +0,0 @@ -import pandas as pd - -# TODO: populate with classes to load -class_register = [] \ No newline at end of file diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_python.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_python.py deleted file mode 100644 index 58de921ed13e2e9b0c57ad724e94fa2ac9a3268f..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_python.py +++ /dev/null @@ -1,141 +0,0 @@ -# encoding: utf-8 -""" -# load_python.py - -Handlers for dumping and loading built-in python types. -NB: As these are for built-in types, they are critical to the functioning of hickle. - -""" - -from hickle.helpers import get_type_and_data - -import sys -if sys.version_info.major == 3: - unicode = type(str) - str = type(bytes) - long = type(int) - NoneType = type(None) -else: - from types import NoneType - -import h5py as h5 - -def create_listlike_dataset(py_obj, h_group, call_id=0, **kwargs): - """ Dumper for list, set, tuple - - Args: - py_obj: python object to dump; should be list-like - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - dtype = str(type(py_obj)) - obj = list(py_obj) - d = h_group.create_dataset('data_%i' % call_id, data=obj, **kwargs) - d.attrs["type"] = [dtype] - - -def create_python_dtype_dataset(py_obj, h_group, call_id=0, **kwargs): - """ dumps a python dtype object to h5py file - - Args: - py_obj: python object to dump; should be a python type (int, float, bool etc) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - # kwarg compression etc does not work on scalars - d = h_group.create_dataset('data_%i' % call_id, data=py_obj, - dtype=type(py_obj)) #, **kwargs) - d.attrs["type"] = ['python_dtype'] - d.attrs['python_subdtype'] = str(type(py_obj)) - - -def create_stringlike_dataset(py_obj, h_group, call_id=0, **kwargs): - """ dumps a list object to h5py file - - Args: - py_obj: python object to dump; should be string-like (unicode or string) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - if isinstance(py_obj, str): - d = h_group.create_dataset('data_%i' % call_id, data=[py_obj], **kwargs) - d.attrs["type"] = ['string'] - else: - dt = h5.special_dtype(vlen=unicode) - dset = h_group.create_dataset('data_%i' % call_id, shape=(1, ), dtype=dt, **kwargs) - dset[0] = py_obj - dset.attrs['type'] = ['unicode'] - - -def create_none_dataset(py_obj, h_group, call_id=0, **kwargs): - """ Dump None type to file - - Args: - py_obj: python object to dump; must be None object - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - d = h_group.create_dataset('data_%i' % call_id, data=[0], **kwargs) - d.attrs["type"] = ['none'] - - -def load_list_dataset(h_node): - py_type, data = get_type_and_data(h_node) - return list(data) - -def load_tuple_dataset(h_node): - py_type, data = get_type_and_data(h_node) - return tuple(data) - -def load_set_dataset(h_node): - py_type, data = get_type_and_data(h_node) - return set(data) - -def load_string_dataset(h_node): - py_type, data = get_type_and_data(h_node) - return str(data[0]) - -def load_unicode_dataset(h_node): - py_type, data = get_type_and_data(h_node) - return unicode(data[0]) - -def load_none_dataset(h_node): - return None - -def load_python_dtype_dataset(h_node): - py_type, data = get_type_and_data(h_node) - subtype = h_node.attrs["python_subdtype"] - type_dict = { - "<type 'int'>": int, - "<type 'float'>": float, - "<type 'long'>": long, - "<type 'bool'>": bool, - "<type 'complex'>": complex - } - tcast = type_dict.get(subtype) - return tcast(data) - -types_dict = { - list: create_listlike_dataset, - tuple: create_listlike_dataset, - set: create_listlike_dataset, - str: create_stringlike_dataset, - unicode: create_stringlike_dataset, - int: create_python_dtype_dataset, - float: create_python_dtype_dataset, - long: create_python_dtype_dataset, - bool: create_python_dtype_dataset, - complex: create_python_dtype_dataset, - NoneType: create_none_dataset, -} - -hkl_types_dict = { - "<type 'list'>" : load_list_dataset, - "<type 'tuple'>" : load_tuple_dataset, - "<type 'set'>" : load_set_dataset, - "python_dtype" : load_python_dtype_dataset, - "string" : load_string_dataset, - "unicode" : load_unicode_dataset, - "none" : load_none_dataset -} - diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_python3.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_python3.py deleted file mode 100644 index c6b173fd07af42735dd05dd7acb9c42e1c651e38..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_python3.py +++ /dev/null @@ -1,201 +0,0 @@ -# encoding: utf-8 -""" -# load_python.py - -Handlers for dumping and loading built-in python types. -NB: As these are for built-in types, they are critical to the functioning of hickle. - -""" - -import six -from hickle.helpers import get_type_and_data - -try: - from exceptions import Exception -except ImportError: - pass # above imports will fail in python3 - -try: - ModuleNotFoundError # This fails on Py3.5 and below -except NameError: - ModuleNotFoundError = ImportError - -import h5py as h5 - - -def get_py3_string_type(h_node): - """ Helper function to return the python string type for items in a list. - - Notes: - Py3 string handling is a bit funky and doesn't play too nicely with HDF5. - We needed to add metadata to say if the strings in a list started off as - bytes, string, etc. This helper loads - - """ - try: - py_type = h_node.attrs["py3_string_type"][0] - return py_type - except: - return None - -def create_listlike_dataset(py_obj, h_group, call_id=0, **kwargs): - """ Dumper for list, set, tuple - - Args: - py_obj: python object to dump; should be list-like - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - dtype = str(type(py_obj)) - obj = list(py_obj) - - # h5py does not handle Py3 'str' objects well. Need to catch this - # Only need to check first element as this method - # is only called if all elements have same dtype - py3_str_type = None - if type(obj[0]) in (str, bytes): - py3_str_type = bytes(str(type(obj[0])), 'ascii') - - if type(obj[0]) is str: - #print(py3_str_type) - #print(obj, "HERE") - obj = [bytes(oo, 'utf8') for oo in obj] - #print(obj, "HERE") - - - d = h_group.create_dataset('data_%i' % call_id, data=obj, **kwargs) - d.attrs["type"] = [bytes(dtype, 'ascii')] - - # Need to add some metadata to aid in unpickling if it's a string type - if py3_str_type is not None: - d.attrs["py3_string_type"] = [py3_str_type] - - - -def create_python_dtype_dataset(py_obj, h_group, call_id=0, **kwargs): - """ dumps a python dtype object to h5py file - - Args: - py_obj: python object to dump; should be a python type (int, float, bool etc) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - # kwarg compression etc does not work on scalars - d = h_group.create_dataset('data_%i' % call_id, data=py_obj, - dtype=type(py_obj)) #, **kwargs) - d.attrs["type"] = [b'python_dtype'] - d.attrs['python_subdtype'] = bytes(str(type(py_obj)), 'ascii') - - -def create_stringlike_dataset(py_obj, h_group, call_id=0, **kwargs): - """ dumps a list object to h5py file - - Args: - py_obj: python object to dump; should be string-like (unicode or string) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - if isinstance(py_obj, bytes): - d = h_group.create_dataset('data_%i' % call_id, data=[py_obj], **kwargs) - d.attrs["type"] = [b'bytes'] - elif isinstance(py_obj, str): - dt = h5.special_dtype(vlen=str) - dset = h_group.create_dataset('data_%i' % call_id, shape=(1, ), dtype=dt, **kwargs) - dset[0] = py_obj - dset.attrs['type'] = [b'string'] - -def create_none_dataset(py_obj, h_group, call_id=0, **kwargs): - """ Dump None type to file - - Args: - py_obj: python object to dump; must be None object - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - d = h_group.create_dataset('data_%i' % call_id, data=[0], **kwargs) - d.attrs["type"] = [b'none'] - - -def load_list_dataset(h_node): - py_type, data = get_type_and_data(h_node) - py3_str_type = get_py3_string_type(h_node) - - if py3_str_type == b"<class 'bytes'>": - # Yuck. Convert numpy._bytes -> str -> bytes - return [bytes(str(item, 'utf8'), 'utf8') for item in data] - if py3_str_type == b"<class 'str'>": - return [str(item, 'utf8') for item in data] - else: - return list(data) - -def load_tuple_dataset(h_node): - data = load_list_dataset(h_node) - return tuple(data) - -def load_set_dataset(h_node): - data = load_list_dataset(h_node) - return set(data) - -def load_bytes_dataset(h_node): - py_type, data = get_type_and_data(h_node) - return bytes(data[0]) - -def load_string_dataset(h_node): - py_type, data = get_type_and_data(h_node) - return str(data[0]) - -def load_unicode_dataset(h_node): - py_type, data = get_type_and_data(h_node) - return unicode(data[0]) - -def load_none_dataset(h_node): - return None - -def load_pickled_data(h_node): - py_type, data = get_type_and_data(h_node) - try: - import cPickle as pickle - except ModuleNotFoundError: - import pickle - return pickle.loads(data[0]) - - -def load_python_dtype_dataset(h_node): - py_type, data = get_type_and_data(h_node) - subtype = h_node.attrs["python_subdtype"] - type_dict = { - b"<class 'int'>": int, - b"<class 'float'>": float, - b"<class 'bool'>": bool, - b"<class 'complex'>": complex - } - - tcast = type_dict.get(subtype) - return tcast(data) - - - -types_dict = { - list: create_listlike_dataset, - tuple: create_listlike_dataset, - set: create_listlike_dataset, - bytes: create_stringlike_dataset, - str: create_stringlike_dataset, - #bytearray: create_stringlike_dataset, - int: create_python_dtype_dataset, - float: create_python_dtype_dataset, - bool: create_python_dtype_dataset, - complex: create_python_dtype_dataset, - type(None): create_none_dataset, -} - -hkl_types_dict = { - b"<class 'list'>" : load_list_dataset, - b"<class 'tuple'>" : load_tuple_dataset, - b"<class 'set'>" : load_set_dataset, - b"bytes" : load_bytes_dataset, - b"python_dtype" : load_python_dtype_dataset, - b"string" : load_string_dataset, - b"pickle" : load_pickled_data, - b"none" : load_none_dataset, -} diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_scipy.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_scipy.py deleted file mode 100644 index ab09fe23c69ea791371e4b6a808b553c84195289..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/loaders/load_scipy.py +++ /dev/null @@ -1,92 +0,0 @@ -import six -import scipy -from scipy import sparse - -from hickle.helpers import get_type_and_data - -def check_is_scipy_sparse_array(py_obj): - """ Check if a python object is a scipy sparse array - - Args: - py_obj: python object to check whether it is a sparse array - - Returns - is_numpy (bool): Returns True if it is a sparse array, else False if it isn't - """ - t_csr = type(scipy.sparse.csr_matrix([0])) - t_csc = type(scipy.sparse.csc_matrix([0])) - t_bsr = type(scipy.sparse.bsr_matrix([0])) - is_sparse = type(py_obj) in (t_csr, t_csc, t_bsr) - - return is_sparse - - -def create_sparse_dataset(py_obj, h_group, call_id=0, **kwargs): - """ dumps an sparse array to h5py file - - Args: - py_obj: python object to dump; should be a numpy array or np.ma.array (masked) - h_group (h5.File.group): group to dump data into. - call_id (int): index to identify object's relative location in the iterable. - """ - h_sparsegroup = h_group.create_group('data_%i' % call_id) - data = h_sparsegroup.create_dataset('data', data=py_obj.data, **kwargs) - indices = h_sparsegroup.create_dataset('indices', data=py_obj.indices, **kwargs) - indptr = h_sparsegroup.create_dataset('indptr', data=py_obj.indptr, **kwargs) - shape = h_sparsegroup.create_dataset('shape', data=py_obj.shape, **kwargs) - - if isinstance(py_obj, type(sparse.csr_matrix([0]))): - type_str = 'csr' - elif isinstance(py_obj, type(sparse.csc_matrix([0]))): - type_str = 'csc' - elif isinstance(py_obj, type(sparse.bsr_matrix([0]))): - type_str = 'bsr' - - if six.PY2: - h_sparsegroup.attrs["type"] = [b'%s_matrix' % type_str] - data.attrs["type"] = [b"%s_matrix_data" % type_str] - indices.attrs["type"] = [b"%s_matrix_indices" % type_str] - indptr.attrs["type"] = [b"%s_matrix_indptr" % type_str] - shape.attrs["type"] = [b"%s_matrix_shape" % type_str] - else: - h_sparsegroup.attrs["type"] = [bytes(str('%s_matrix' % type_str), 'ascii')] - data.attrs["type"] = [bytes(str("%s_matrix_data" % type_str), 'ascii')] - indices.attrs["type"] = [bytes(str("%s_matrix_indices" % type_str), 'ascii')] - indptr.attrs["type"] = [bytes(str("%s_matrix_indptr" % type_str), 'ascii')] - shape.attrs["type"] = [bytes(str("%s_matrix_shape" % type_str), 'ascii')] - -def load_sparse_matrix_data(h_node): - - py_type, data = get_type_and_data(h_node) - h_root = h_node.parent - indices = h_root.get('indices')[:] - indptr = h_root.get('indptr')[:] - shape = h_root.get('shape')[:] - - if py_type == b'csc_matrix_data': - smat = sparse.csc_matrix((data, indices, indptr), dtype=data.dtype, shape=shape) - elif py_type == b'csr_matrix_data': - smat = sparse.csr_matrix((data, indices, indptr), dtype=data.dtype, shape=shape) - elif py_type == b'bsr_matrix_data': - smat = sparse.bsr_matrix((data, indices, indptr), dtype=data.dtype, shape=shape) - return smat - - - - - -class_register = [ - [scipy.sparse.csr_matrix, b'csr_matrix_data', create_sparse_dataset, load_sparse_matrix_data, False, check_is_scipy_sparse_array], - [scipy.sparse.csc_matrix, b'csc_matrix_data', create_sparse_dataset, load_sparse_matrix_data, False, check_is_scipy_sparse_array], - [scipy.sparse.bsr_matrix, b'bsr_matrix_data', create_sparse_dataset, load_sparse_matrix_data, False, check_is_scipy_sparse_array], -] - -exclude_register = [] - -# Need to ignore things like csc_matrix_indices which are loaded automatically -for mat_type in ('csr', 'csc', 'bsr'): - for attrib in ('indices', 'indptr', 'shape'): - hkl_key = "%s_matrix_%s" % (mat_type, attrib) - if not six.PY2: - hkl_key = hkl_key.encode('ascii') - exclude_register.append(hkl_key) diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/lookup.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/lookup.py deleted file mode 100644 index 99d13df9315be642540e46efc44d8e3d293de708..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/hickle/lookup.py +++ /dev/null @@ -1,238 +0,0 @@ -""" -#lookup.py - -This file contains all the mappings between hickle/HDF5 metadata and python types. -There are four dictionaries and one set that are populated here: - -1) types_dict -types_dict: mapping between python types and dataset creation functions, e.g. - types_dict = { - list: create_listlike_dataset, - int: create_python_dtype_dataset, - np.ndarray: create_np_array_dataset - } - -2) hkl_types_dict -hkl_types_dict: mapping between hickle metadata and dataset loading functions, e.g. - hkl_types_dict = { - "<type 'list'>" : load_list_dataset, - "<type 'tuple'>" : load_tuple_dataset - } - -3) container_types_dict -container_types_dict: mapping required to convert the PyContainer object in hickle.py - back into the required native type. PyContainer is required as - some iterable types are immutable (do not have an append() function). - Here is an example: - container_types_dict = { - "<type 'list'>": list, - "<type 'tuple'>": tuple - } - -4) container_key_types_dict -container_key_types_dict: mapping specifically for converting hickled dict data back into - a dictionary with the same key type. While python dictionary keys - can be any hashable object, in HDF5 a unicode/string is required - for a dataset name. Example: - container_key_types_dict = { - "<type 'str'>": str, - "<type 'unicode'>": unicode - } - -5) types_not_to_sort -type_not_to_sort is a list of hickle type attributes that may be hierarchical, -but don't require sorting by integer index. - -## Extending hickle to add support for other classes and types - -The process to add new load/dump capabilities is as follows: - -1) Create a file called load_[newstuff].py in loaders/ -2) In the load_[newstuff].py file, define your create_dataset and load_dataset functions, - along with all required mapping dictionaries. -3) Add an import call here, and populate the lookup dictionaries with update() calls: - # Add loaders for [newstuff] - try: - from .loaders.load_[newstuff[ import types_dict as ns_types_dict - from .loaders.load_[newstuff[ import hkl_types_dict as ns_hkl_types_dict - types_dict.update(ns_types_dict) - hkl_types_dict.update(ns_hkl_types_dict) - ... (Add container_types_dict etc if required) - except ImportError: - raise -""" - -import six -from ast import literal_eval - -def return_first(x): - """ Return first element of a list """ - return x[0] - -def load_nothing(h_hode): - pass - -types_dict = {} - -hkl_types_dict = {} - -types_not_to_sort = [b'dict', b'csr_matrix', b'csc_matrix', b'bsr_matrix'] - -container_types_dict = { - b"<type 'list'>": list, - b"<type 'tuple'>": tuple, - b"<type 'set'>": set, - b"<class 'list'>": list, - b"<class 'tuple'>": tuple, - b"<class 'set'>": set, - b"csr_matrix": return_first, - b"csc_matrix": return_first, - b"bsr_matrix": return_first - } - -# Technically, any hashable object can be used, for now sticking with built-in types -container_key_types_dict = { - b"<type 'str'>": literal_eval, - b"<type 'float'>": float, - b"<type 'bool'>": bool, - b"<type 'int'>": int, - b"<type 'complex'>": complex, - b"<type 'tuple'>": literal_eval, - b"<class 'str'>": literal_eval, - b"<class 'float'>": float, - b"<class 'bool'>": bool, - b"<class 'int'>": int, - b"<class 'complex'>": complex, - b"<class 'tuple'>": literal_eval - } - -if six.PY2: - container_key_types_dict[b"<type 'unicode'>"] = literal_eval - container_key_types_dict[b"<type 'long'>"] = long - -# Add loaders for built-in python types -if six.PY2: - from .loaders.load_python import types_dict as py_types_dict - from .loaders.load_python import hkl_types_dict as py_hkl_types_dict -else: - from .loaders.load_python3 import types_dict as py_types_dict - from .loaders.load_python3 import hkl_types_dict as py_hkl_types_dict - -types_dict.update(py_types_dict) -hkl_types_dict.update(py_hkl_types_dict) - -# Add loaders for numpy types -from .loaders.load_numpy import types_dict as np_types_dict -from .loaders.load_numpy import hkl_types_dict as np_hkl_types_dict -from .loaders.load_numpy import check_is_numpy_array -types_dict.update(np_types_dict) -hkl_types_dict.update(np_hkl_types_dict) - -####################### -## ND-ARRAY checking ## -####################### - -ndarray_like_check_fns = [ - check_is_numpy_array -] - -def check_is_ndarray_like(py_obj): - is_ndarray_like = False - for ii, check_fn in enumerate(ndarray_like_check_fns): - is_ndarray_like = check_fn(py_obj) - if is_ndarray_like: - break - return is_ndarray_like - - - - -####################### -## loading optional ## -####################### - -def register_class(myclass_type, hkl_str, dump_function, load_function, - to_sort=True, ndarray_check_fn=None): - """ Register a new hickle class. - - Args: - myclass_type type(class): type of class - dump_function (function def): function to write data to HDF5 - load_function (function def): function to load data from HDF5 - is_iterable (bool): Is the item iterable? - hkl_str (str): String to write to HDF5 file to describe class - to_sort (bool): If the item is iterable, does it require sorting? - ndarray_check_fn (function def): function to use to check if - - """ - types_dict.update({myclass_type: dump_function}) - hkl_types_dict.update({hkl_str: load_function}) - if to_sort == False: - types_not_to_sort.append(hkl_str) - if ndarray_check_fn is not None: - ndarray_like_check_fns.append(ndarray_check_fn) - -def register_class_list(class_list): - """ Register multiple classes in a list - - Args: - class_list (list): A list, where each item is an argument to - the register_class() function. - - Notes: This just runs the code: - for item in mylist: - register_class(*item) - """ - for class_item in class_list: - register_class(*class_item) - -def register_class_exclude(hkl_str_to_ignore): - """ Tell loading funciton to ignore any HDF5 dataset with attribute 'type=XYZ' - - Args: - hkl_str_to_ignore (str): attribute type=string to ignore and exclude from loading. - """ - hkl_types_dict[hkl_str_to_ignore] = load_nothing - -def register_exclude_list(exclude_list): - """ Ignore HDF5 datasets with attribute type='XYZ' from loading - - ArgsL - exclude_list (list): List of strings, which correspond to hdf5/hickle - type= attributes not to load. - """ - for hkl_str in exclude_list: - register_class_exclude(hkl_str) - -######################## -## Scipy sparse array ## -######################## - -try: - from .loaders.load_scipy import class_register, exclude_register - register_class_list(class_register) - register_exclude_list(exclude_register) -except ImportError: - pass -except NameError: - pass - -#################### -## Astropy stuff ## -#################### - -try: - from .loaders.load_astropy import class_register - register_class_list(class_register) -except ImportError: - pass - -################## -## Pandas stuff ## -################## - -try: - from .loaders.load_pandas import class_register - register_class_list(class_register) -except ImportError: - pass diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/tests/__init__.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/tests/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/tests/test_astropy.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/tests/test_astropy.py deleted file mode 100644 index 2086ec37456b2bbcde77fbed2d5370b67ee89381..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/tests/test_astropy.py +++ /dev/null @@ -1,133 +0,0 @@ -import hickle as hkl -from astropy.units import Quantity -from astropy.time import Time -from astropy.coordinates import Angle, SkyCoord -from astropy.constants import Constant, EMConstant, G -from astropy.table import Table -import numpy as np -from py.path import local - -# Set the current working directory to the temporary directory -local.get_temproot().chdir() - -def test_astropy_quantity(): - - for uu in ['m^3', 'm^3 / s', 'kg/pc']: - a = Quantity(7, unit=uu) - - hkl.dump(a, "test_ap.h5") - b = hkl.load("test_ap.h5") - - assert a == b - assert a.unit == b.unit - - a *= a - hkl.dump(a, "test_ap.h5") - b = hkl.load("test_ap.h5") - assert a == b - assert a.unit == b.unit - -def TODO_test_astropy_constant(): - hkl.dump(G, "test_ap.h5") - gg = hkl.load("test_ap.h5") - - print(G) - print(gg) - -def test_astropy_table(): - t = Table([[1, 2], [3, 4]], names=('a', 'b'), meta={'name': 'test_thing'}) - - hkl.dump({'a': t}, "test_ap.h5") - t2 = hkl.load("test_ap.h5")['a'] - - print(t) - print(t.meta) - print(t2) - print(t2.meta) - - print(t.dtype, t2.dtype) - assert t.meta == t2.meta - assert t.dtype == t2.dtype - - assert np.allclose(t['a'].astype('float32'), t2['a'].astype('float32')) - assert np.allclose(t['b'].astype('float32'), t2['b'].astype('float32')) - -def test_astropy_quantity_array(): - a = Quantity([1,2,3], unit='m') - - hkl.dump(a, "test_ap.h5") - b = hkl.load("test_ap.h5") - - assert np.allclose(a.value, b.value) - assert a.unit == b.unit - -def test_astropy_time_array(): - times = ['1999-01-01T00:00:00.123456789', '2010-01-01T00:00:00'] - t1 = Time(times, format='isot', scale='utc') - hkl.dump(t1, "test_ap2.h5") - t2 = hkl.load("test_ap2.h5") - - print(t1) - print(t2) - assert t1.value.shape == t2.value.shape - for ii in range(len(t1)): - assert t1.value[ii] == t2.value[ii] - assert t1.format == t2.format - assert t1.scale == t2.scale - - times = [58264, 58265, 58266] - t1 = Time(times, format='mjd', scale='utc') - hkl.dump(t1, "test_ap2.h5") - t2 = hkl.load("test_ap2.h5") - - print(t1) - print(t2) - assert t1.value.shape == t2.value.shape - assert np.allclose(t1.value, t2.value) - assert t1.format == t2.format - assert t1.scale == t2.scale - -def test_astropy_angle(): - for uu in ['radian', 'degree']: - a = Angle(1.02, unit=uu) - - hkl.dump(a, "test_ap.h5") - b = hkl.load("test_ap.h5") - assert a == b - assert a.unit == b.unit - -def test_astropy_angle_array(): - a = Angle([1,2,3], unit='degree') - - hkl.dump(a, "test_ap.h5") - b = hkl.load("test_ap.h5") - - assert np.allclose(a.value, b.value) - assert a.unit == b.unit - -def test_astropy_skycoord(): - ra = Angle(['1d20m', '1d21m'], unit='degree') - dec = Angle(['33d0m0s', '33d01m'], unit='degree') - radec = SkyCoord(ra, dec) - hkl.dump(radec, "test_ap.h5") - radec2 = hkl.load("test_ap.h5") - assert np.allclose(radec.ra.value, radec2.ra.value) - assert np.allclose(radec.dec.value, radec2.dec.value) - - ra = Angle(['1d20m', '1d21m'], unit='hourangle') - dec = Angle(['33d0m0s', '33d01m'], unit='degree') - radec = SkyCoord(ra, dec) - hkl.dump(radec, "test_ap.h5") - radec2 = hkl.load("test_ap.h5") - assert np.allclose(radec.ra.value, radec2.ra.value) - assert np.allclose(radec.dec.value, radec2.dec.value) - -if __name__ == "__main__": - test_astropy_quantity() - #test_astropy_constant() - test_astropy_table() - test_astropy_quantity_array() - test_astropy_time_array() - test_astropy_angle() - test_astropy_angle_array() - test_astropy_skycoord() diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/tests/test_hickle.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/tests/test_hickle.py deleted file mode 100644 index 5491054239372a3b5d42c9e6f07b6fc5701ed933..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/tests/test_hickle.py +++ /dev/null @@ -1,826 +0,0 @@ -#! /usr/bin/env python -# encoding: utf-8 -""" -# test_hickle.py - -Unit tests for hickle module. - -""" - -import h5py -import hashlib -import numpy as np -import os -import six -import time -from pprint import pprint - -from py.path import local - -import hickle -from hickle.hickle import * - - -# Set current working directory to the temporary directory -local.get_temproot().chdir() - -NESTED_DICT = { - "level1_1": { - "level2_1": [1, 2, 3], - "level2_2": [4, 5, 6] - }, - "level1_2": { - "level2_1": [1, 2, 3], - "level2_2": [4, 5, 6] - }, - "level1_3": { - "level2_1": { - "level3_1": [1, 2, 3], - "level3_2": [4, 5, 6] - }, - "level2_2": [4, 5, 6] - } -} - -DUMP_CACHE = [] # Used in test_track_times() - - -def test_string(): - """ Dumping and loading a string """ - if six.PY2: - filename, mode = 'test.h5', 'w' - string_obj = "The quick brown fox jumps over the lazy dog" - dump(string_obj, filename, mode) - string_hkl = load(filename) - #print "Initial list: %s"%list_obj - #print "Unhickled data: %s"%list_hkl - assert type(string_obj) == type(string_hkl) == str - assert string_obj == string_hkl - else: - pass - - -def test_unicode(): - """ Dumping and loading a unicode string """ - if six.PY2: - filename, mode = 'test.h5', 'w' - u = unichr(233) + unichr(0x0bf2) + unichr(3972) + unichr(6000) - dump(u, filename, mode) - u_hkl = load(filename) - - assert type(u) == type(u_hkl) == unicode - assert u == u_hkl - # For those interested, uncomment below to see what those codes are: - # for i, c in enumerate(u_hkl): - # print i, '%04x' % ord(c), unicodedata.category(c), - # print unicodedata.name(c) - else: - pass - - -def test_unicode2(): - if six.PY2: - a = u"unicode test" - dump(a, 'test.hkl', mode='w') - - z = load('test.hkl') - assert a == z - assert type(a) == type(z) == unicode - pprint(z) - else: - pass - -def test_list(): - """ Dumping and loading a list """ - filename, mode = 'test_list.h5', 'w' - list_obj = [1, 2, 3, 4, 5] - dump(list_obj, filename, mode=mode) - list_hkl = load(filename) - #print(f'Initial list: {list_obj}') - #print(f'Unhickled data: {list_hkl}') - try: - assert type(list_obj) == type(list_hkl) == list - assert list_obj == list_hkl - import h5py - a = h5py.File(filename) - a.close() - - except AssertionError: - print("ERR:", list_obj, list_hkl) - import h5py - - raise() - - -def test_set(): - """ Dumping and loading a list """ - filename, mode = 'test_set.h5', 'w' - list_obj = set([1, 0, 3, 4.5, 11.2]) - dump(list_obj, filename, mode) - list_hkl = load(filename) - #print "Initial list: %s"%list_obj - #print "Unhickled data: %s"%list_hkl - try: - assert type(list_obj) == type(list_hkl) == set - assert list_obj == list_hkl - except AssertionError: - print(type(list_obj)) - print(type(list_hkl)) - #os.remove(filename) - raise - - -def test_numpy(): - """ Dumping and loading numpy array """ - filename, mode = 'test.h5', 'w' - dtypes = ['float32', 'float64', 'complex64', 'complex128'] - - for dt in dtypes: - array_obj = np.ones(8, dtype=dt) - dump(array_obj, filename, mode) - array_hkl = load(filename) - try: - assert array_hkl.dtype == array_obj.dtype - assert np.all((array_hkl, array_obj)) - except AssertionError: - print(array_hkl) - print(array_obj) - raise - - -def test_masked(): - """ Test masked numpy array """ - filename, mode = 'test.h5', 'w' - a = np.ma.array([1,2,3,4], dtype='float32', mask=[0,1,0,0]) - - dump(a, filename, mode) - a_hkl = load(filename) - - try: - assert a_hkl.dtype == a.dtype - assert np.all((a_hkl, a)) - except AssertionError: - print(a_hkl) - print(a) - raise - - -def test_dict(): - """ Test dictionary dumping and loading """ - filename, mode = 'test.h5', 'w' - - dd = { - 'name' : b'Danny', - 'age' : 28, - 'height' : 6.1, - 'dork' : True, - 'nums' : [1, 2, 3], - 'narr' : np.array([1,2,3]), - #'unic' : u'dan[at]thetelegraphic.com' - } - - - dump(dd, filename, mode) - dd_hkl = load(filename) - - for k in dd.keys(): - try: - assert k in dd_hkl.keys() - - if type(dd[k]) is type(np.array([1])): - assert np.all((dd[k], dd_hkl[k])) - else: - #assert dd_hkl[k] == dd[k] - pass - assert type(dd_hkl[k]) == type(dd[k]) - except AssertionError: - print(k) - print(dd_hkl[k]) - print(dd[k]) - print(type(dd_hkl[k]), type(dd[k])) - raise - - -def test_empty_dict(): - """ Test empty dictionary dumping and loading """ - filename, mode = 'test.h5', 'w' - - dump({}, filename, mode) - assert load(filename) == {} - - -def test_compression(): - """ Test compression on datasets""" - - filename, mode = 'test.h5', 'w' - dtypes = ['int32', 'float32', 'float64', 'complex64', 'complex128'] - - comps = [None, 'gzip', 'lzf'] - - for dt in dtypes: - for cc in comps: - array_obj = np.ones(32768, dtype=dt) - dump(array_obj, filename, mode, compression=cc) - print(cc, os.path.getsize(filename)) - array_hkl = load(filename) - try: - assert array_hkl.dtype == array_obj.dtype - assert np.all((array_hkl, array_obj)) - except AssertionError: - print(array_hkl) - print(array_obj) - raise - - -def test_dict_int_key(): - """ Test for dictionaries with integer keys """ - filename, mode = 'test.h5', 'w' - - dd = { - 0: "test", - 1: "test2" - } - - dump(dd, filename, mode) - dd_hkl = load(filename) - - -def test_dict_nested(): - """ Test for dictionaries with integer keys """ - filename, mode = 'test.h5', 'w' - - dd = NESTED_DICT - - dump(dd, filename, mode) - dd_hkl = load(filename) - - ll_hkl = dd_hkl["level1_3"]["level2_1"]["level3_1"] - ll = dd["level1_3"]["level2_1"]["level3_1"] - assert ll == ll_hkl - - -def test_masked_dict(): - """ Test dictionaries with masked arrays """ - - filename, mode = 'test.h5', 'w' - - dd = { - "data" : np.ma.array([1,2,3], mask=[True, False, False]), - "data2" : np.array([1,2,3,4,5]) - } - - dump(dd, filename, mode) - dd_hkl = load(filename) - - for k in dd.keys(): - try: - assert k in dd_hkl.keys() - if type(dd[k]) is type(np.array([1])): - assert np.all((dd[k], dd_hkl[k])) - elif type(dd[k]) is type(np.ma.array([1])): - print(dd[k].data) - print(dd_hkl[k].data) - assert np.allclose(dd[k].data, dd_hkl[k].data) - assert np.allclose(dd[k].mask, dd_hkl[k].mask) - - assert type(dd_hkl[k]) == type(dd[k]) - - except AssertionError: - print(k) - print(dd_hkl[k]) - print(dd[k]) - print(type(dd_hkl[k]), type(dd[k])) - raise - - -def test_np_float(): - """ Test for singular np dtypes """ - filename, mode = 'np_float.h5', 'w' - - dtype_list = (np.float16, np.float32, np.float64, - np.complex64, np.complex128, - np.int8, np.int16, np.int32, np.int64, - np.uint8, np.uint16, np.uint32, np.uint64) - - for dt in dtype_list: - - dd = dt(1) - dump(dd, filename, mode) - dd_hkl = load(filename) - assert dd == dd_hkl - assert dd.dtype == dd_hkl.dtype - - dd = {} - for dt in dtype_list: - dd[str(dt)] = dt(1.0) - dump(dd, filename, mode) - dd_hkl = load(filename) - - print(dd) - for dt in dtype_list: - assert dd[str(dt)] == dd_hkl[str(dt)] - - -def md5sum(filename, blocksize=65536): - """ Compute MD5 sum for a given file """ - hash = hashlib.md5() - - with open(filename, "r+b") as f: - for block in iter(lambda: f.read(blocksize), ""): - hash.update(block) - return hash.hexdigest() - - -def caching_dump(obj, filename, *args, **kwargs): - """ Save arguments of all dump calls """ - DUMP_CACHE.append((obj, filename, args, kwargs)) - return hickle_dump(obj, filename, *args, **kwargs) - - -def test_track_times(): - """ Verify that track_times = False produces identical files """ - hashes = [] - for obj, filename, mode, kwargs in DUMP_CACHE: - if isinstance(filename, hickle.H5FileWrapper): - filename = str(filename.file_name) - kwargs['track_times'] = False - caching_dump(obj, filename, mode, **kwargs) - hashes.append(md5sum(filename)) - - time.sleep(1) - - for hash1, (obj, filename, mode, kwargs) in zip(hashes, DUMP_CACHE): - if isinstance(filename, hickle.H5FileWrapper): - filename = str(filename.file_name) - caching_dump(obj, filename, mode, **kwargs) - hash2 = md5sum(filename) - print(hash1, hash2) - assert hash1 == hash2 - - -def test_comp_kwargs(): - """ Test compression with some kwargs for shuffle and chunking """ - - filename, mode = 'test.h5', 'w' - dtypes = ['int32', 'float32', 'float64', 'complex64', 'complex128'] - - comps = [None, 'gzip', 'lzf'] - chunks = [(100, 100), (250, 250)] - shuffles = [True, False] - scaleoffsets = [0, 1, 2] - - for dt in dtypes: - for cc in comps: - for ch in chunks: - for sh in shuffles: - for so in scaleoffsets: - kwargs = { - 'compression' : cc, - 'dtype': dt, - 'chunks': ch, - 'shuffle': sh, - 'scaleoffset': so - } - #array_obj = np.random.random_integers(low=-8192, high=8192, size=(1000, 1000)).astype(dt) - array_obj = NESTED_DICT - dump(array_obj, filename, mode, compression=cc) - print(kwargs, os.path.getsize(filename)) - array_hkl = load(filename) - - -def test_list_numpy(): - """ Test converting a list of numpy arrays """ - - filename, mode = 'test.h5', 'w' - - a = np.ones(1024) - b = np.zeros(1000) - c = [a, b] - - dump(c, filename, mode) - dd_hkl = load(filename) - - print(dd_hkl) - - assert isinstance(dd_hkl, list) - assert isinstance(dd_hkl[0], np.ndarray) - - -def test_tuple_numpy(): - """ Test converting a list of numpy arrays """ - - filename, mode = 'test.h5', 'w' - - a = np.ones(1024) - b = np.zeros(1000) - c = (a, b, a) - - dump(c, filename, mode) - dd_hkl = load(filename) - - print(dd_hkl) - - assert isinstance(dd_hkl, tuple) - assert isinstance(dd_hkl[0], np.ndarray) - - -def test_none(): - """ Test None type hickling """ - - filename, mode = 'test.h5', 'w' - - a = None - - dump(a, filename, mode) - dd_hkl = load(filename) - print(a) - print(dd_hkl) - - assert isinstance(dd_hkl, type(None)) - - -def test_dict_none(): - """ Test None type hickling """ - - filename, mode = 'test.h5', 'w' - - a = {'a': 1, 'b' : None} - - dump(a, filename, mode) - dd_hkl = load(filename) - print(a) - print(dd_hkl) - - assert isinstance(a['b'], type(None)) - - -def test_file_open_close(): - """ https://github.com/telegraphic/hickle/issues/20 """ - import h5py - f = h5py.File('test.hdf', 'w') - a = np.arange(5) - - dump(a, 'test.hkl') - dump(a, 'test.hkl') - - dump(a, f, mode='w') - f.close() - try: - dump(a, f, mode='w') - except hickle.hickle.ClosedFileError: - print("Tests: Closed file exception caught") - - -def test_list_order(): - """ https://github.com/telegraphic/hickle/issues/26 """ - d = [np.arange(n + 1) for n in range(20)] - hickle.dump(d, 'test.h5') - d_hkl = hickle.load('test.h5') - - try: - for ii, xx in enumerate(d): - assert d[ii].shape == d_hkl[ii].shape - for ii, xx in enumerate(d): - assert np.allclose(d[ii], d_hkl[ii]) - except AssertionError: - print(d[ii], d_hkl[ii]) - raise - - -def test_embedded_array(): - """ See https://github.com/telegraphic/hickle/issues/24 """ - - d_orig = [[np.array([10., 20.]), np.array([10, 20, 30])], [np.array([10, 2]), np.array([1.])]] - hickle.dump(d_orig, 'test.h5') - d_hkl = hickle.load('test.h5') - - for ii, xx in enumerate(d_orig): - for jj, yy in enumerate(xx): - assert np.allclose(d_orig[ii][jj], d_hkl[ii][jj]) - - print(d_hkl) - print(d_orig) - - -################ -## NEW TESTS ## -################ - - -def generate_nested(): - a = [1, 2, 3] - b = [a, a, a] - c = [a, b, 's'] - d = [a, b, c, c, a] - e = [d, d, d, d, 1] - f = {'a' : a, 'b' : b, 'e' : e} - g = {'f' : f, 'a' : e, 'd': d} - h = {'h': g, 'g' : f} - z = [f, a, b, c, d, e, f, g, h, g, h] - a = np.array([1, 2, 3, 4]) - b = set([1, 2, 3, 4, 5]) - c = (1, 2, 3, 4, 5) - d = np.ma.array([1, 2, 3, 4, 5, 6, 7, 8]) - z = {'a': a, 'b': b, 'c': c, 'd': d, 'z': z} - return z - - -def test_is_iterable(): - a = [1, 2, 3] - b = 1 - - assert check_is_iterable(a) == True - assert check_is_iterable(b) == False - - -def test_check_iterable_item_type(): - - a = [1, 2, 3] - b = [a, a, a] - c = [a, b, 's'] - - type_a = check_iterable_item_type(a) - type_b = check_iterable_item_type(b) - type_c = check_iterable_item_type(c) - - assert type_a is int - assert type_b is list - assert type_c == False - - -def test_dump_nested(): - """ Dump a complicated nested object to HDF5 - """ - z = generate_nested() - dump(z, 'test.hkl', mode='w') - - -def test_with_dump(): - lst = [1] - tpl = (1) - dct = {1: 1} - arr = np.array([1]) - - with h5py.File('test.hkl') as file: - dump(lst, file, path='/lst') - dump(tpl, file, path='/tpl') - dump(dct, file, path='/dct') - dump(arr, file, path='/arr') - - -def test_with_load(): - lst = [1] - tpl = (1) - dct = {1: 1} - arr = np.array([1]) - - with h5py.File('test.hkl') as file: - assert load(file, '/lst') == lst - assert load(file, '/tpl') == tpl - assert load(file, '/dct') == dct - assert load(file, '/arr') == arr - - -def test_load(): - - a = set([1, 2, 3, 4]) - b = set([5, 6, 7, 8]) - c = set([9, 10, 11, 12]) - z = (a, b, c) - z = [z, z] - z = (z, z, z, z, z) - - print("Original:") - pprint(z) - dump(z, 'test.hkl', mode='w') - - print("\nReconstructed:") - z = load('test.hkl') - pprint(z) - - -def test_sort_keys(): - keys = [b'data_0', b'data_1', b'data_2', b'data_3', b'data_10'] - keys_sorted = [b'data_0', b'data_1', b'data_2', b'data_3', b'data_10'] - - print(keys) - print(keys_sorted) - assert sort_keys(keys) == keys_sorted - - -def test_ndarray(): - - a = np.array([1,2,3]) - b = np.array([2,3,4]) - z = (a, b) - - print("Original:") - pprint(z) - dump(z, 'test.hkl', mode='w') - - print("\nReconstructed:") - z = load('test.hkl') - pprint(z) - - -def test_ndarray_masked(): - - a = np.ma.array([1,2,3]) - b = np.ma.array([2,3,4], mask=[True, False, True]) - z = (a, b) - - print("Original:") - pprint(z) - dump(z, 'test.hkl', mode='w') - - print("\nReconstructed:") - z = load('test.hkl') - pprint(z) - - -def test_simple_dict(): - a = {'key1': 1, 'key2': 2} - - dump(a, 'test.hkl') - z = load('test.hkl') - - pprint(a) - pprint(z) - - -def test_complex_dict(): - a = {'akey': 1, 'akey2': 2} - if six.PY2: - # NO LONG TYPE IN PY3! - b = {'bkey': 2.0, 'bkey3': long(3.0)} - else: - b = a - c = {'ckey': "hello", "ckey2": "hi there"} - z = {'zkey1': a, 'zkey2': b, 'zkey3': c} - - print("Original:") - pprint(z) - dump(z, 'test.hkl', mode='w') - - print("\nReconstructed:") - z = load('test.hkl') - pprint(z) - -def test_multi_hickle(): - a = {'a': 123, 'b': [1, 2, 4]} - - if os.path.exists("test.hkl"): - os.remove("test.hkl") - dump(a, "test.hkl", path="/test", mode="w") - dump(a, "test.hkl", path="/test2", mode="r+") - dump(a, "test.hkl", path="/test3", mode="r+") - dump(a, "test.hkl", path="/test4", mode="r+") - - a = load("test.hkl", path="/test") - b = load("test.hkl", path="/test2") - c = load("test.hkl", path="/test3") - d = load("test.hkl", path="/test4") - -def test_complex(): - """ Test complex value dtype is handled correctly - - https://github.com/telegraphic/hickle/issues/29 """ - - data = {"A":1.5, "B":1.5 + 1j, "C":np.linspace(0,1,4) + 2j} - dump(data, "test.hkl") - data2 = load("test.hkl") - for key in data.keys(): - assert type(data[key]) == type(data2[key]) - -def test_nonstring_keys(): - """ Test that keys are reconstructed back to their original datatypes - https://github.com/telegraphic/hickle/issues/36 - """ - if six.PY2: - u = unichr(233) + unichr(0x0bf2) + unichr(3972) + unichr(6000) - - data = {u'test': 123, - 'def': 456, - 'hik' : np.array([1,2,3]), - u: u, - 0: 0, - True: 'hi', - 1.1 : 'hey', - #2L : 'omg', - 1j: 'complex_hashable', - (1, 2): 'boo', - ('A', 17.4, 42): [1, 7, 'A'], - (): '1313e was here', - '0': 0 - } - #data = {'0': 123, 'def': 456} - print(data) - dump(data, "test.hkl") - data2 = load("test.hkl") - print(data2) - - for key in data.keys(): - assert key in data2.keys() - - print(data2) - else: - pass - -def test_scalar_compression(): - """ Test bug where compression causes a crash on scalar datasets - - (Scalars are incompressible!) - https://github.com/telegraphic/hickle/issues/37 - """ - data = {'a' : 0, 'b' : np.float(2), 'c' : True} - - dump(data, "test.hkl", compression='gzip') - data2 = load("test.hkl") - - print(data2) - for key in data.keys(): - assert type(data[key]) == type(data2[key]) - -def test_bytes(): - """ Dumping and loading a string. PYTHON3 ONLY """ - if six.PY3: - filename, mode = 'test.h5', 'w' - string_obj = b"The quick brown fox jumps over the lazy dog" - dump(string_obj, filename, mode) - string_hkl = load(filename) - #print "Initial list: %s"%list_obj - #print "Unhickled data: %s"%list_hkl - print(type(string_obj)) - print(type(string_hkl)) - assert type(string_obj) == type(string_hkl) == bytes - assert string_obj == string_hkl - else: - pass - -def test_np_scalar(): - """ Numpy scalar datatype - - https://github.com/telegraphic/hickle/issues/50 - """ - - fid='test.h5py' - r0={'test': np.float64(10.)} - s = dump(r0, fid) - r = load(fid) - print(r) - assert type(r0['test']) == type(r['test']) - -if __name__ == '__main__': - """ Some tests and examples """ - test_sort_keys() - - test_np_scalar() - test_scalar_compression() - test_complex() - test_file_open_close() - test_dict_none() - test_none() - test_masked_dict() - test_list() - test_set() - test_numpy() - test_dict() - test_empty_dict() - test_compression() - test_masked() - test_dict_nested() - test_comp_kwargs() - test_list_numpy() - test_tuple_numpy() - test_track_times() - test_list_order() - test_embedded_array() - test_np_float() - - if six.PY2: - test_unicode() - test_unicode2() - test_string() - test_nonstring_keys() - - if six.PY3: - test_bytes() - - - # NEW TESTS - test_is_iterable() - test_check_iterable_item_type() - test_dump_nested() - test_with_dump() - test_with_load() - test_load() - test_sort_keys() - test_ndarray() - test_ndarray_masked() - test_simple_dict() - test_complex_dict() - test_multi_hickle() - test_dict_int_key() - - # Cleanup - print("ALL TESTS PASSED!") \ No newline at end of file diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/tests/test_hickle_helpers.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/tests/test_hickle_helpers.py deleted file mode 100644 index 253839e97c96e484b7a66ad9d174648d281d1c66..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/tests/test_hickle_helpers.py +++ /dev/null @@ -1,63 +0,0 @@ -#! /usr/bin/env python -# encoding: utf-8 -""" -# test_hickle_helpers.py - -Unit tests for hickle module -- helper functions. - -""" - -import numpy as np -try: - import scipy - from scipy import sparse - _has_scipy = True -except ImportError: - _has_scipy = False - -from hickle.helpers import check_is_hashable, check_is_iterable, check_iterable_item_type - -from hickle.loaders.load_numpy import check_is_numpy_array -if _has_scipy: - from hickle.loaders.load_scipy import check_is_scipy_sparse_array - - - -def test_check_is_iterable(): - assert check_is_iterable([1,2,3]) is True - assert check_is_iterable(1) is False - - -def test_check_is_hashable(): - assert check_is_hashable(1) is True - assert check_is_hashable([1,2,3]) is False - - -def test_check_iterable_item_type(): - assert check_iterable_item_type([1,2,3]) is int - assert check_iterable_item_type([int(1), float(1)]) is False - assert check_iterable_item_type([]) is False - - -def test_check_is_numpy_array(): - assert check_is_numpy_array(np.array([1,2,3])) is True - assert check_is_numpy_array(np.ma.array([1,2,3])) is True - assert check_is_numpy_array([1,2]) is False - - -def test_check_is_scipy_sparse_array(): - t_csr = scipy.sparse.csr_matrix([0]) - t_csc = scipy.sparse.csc_matrix([0]) - t_bsr = scipy.sparse.bsr_matrix([0]) - assert check_is_scipy_sparse_array(t_csr) is True - assert check_is_scipy_sparse_array(t_csc) is True - assert check_is_scipy_sparse_array(t_bsr) is True - assert check_is_scipy_sparse_array(np.array([1])) is False - -if __name__ == "__main__": - test_check_is_hashable() - test_check_is_iterable() - test_check_is_numpy_array() - test_check_iterable_item_type() - if _has_scipy: - test_check_is_scipy_sparse_array() \ No newline at end of file diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/tests/test_legacy_load.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/tests/test_legacy_load.py deleted file mode 100644 index e849bcf6594c7139357659f8cf0721ef777da3b0..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/tests/test_legacy_load.py +++ /dev/null @@ -1,30 +0,0 @@ -import glob -import warnings -import hickle as hkl -import h5py -import six - -def test_legacy_load(): - if six.PY2: - filelist = sorted(glob.glob('legacy_hkls/*.hkl')) - - # Make all warnings show - warnings.simplefilter("always") - - for filename in filelist: - try: - print(filename) - a = hkl.load(filename) - except: - with h5py.File(filename) as a: - print(a.attrs.items()) - print(a.items()) - for key, item in a.items(): - print(item.attrs.items()) - raise - else: - print("Legacy loading only works in Py2. Sorry.") - pass - -if __name__ == "__main__": - test_legacy_load() \ No newline at end of file diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/tests/test_scipy.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/tests/test_scipy.py deleted file mode 100644 index ab78311d3eb543f4d3515b6aef2eba4e5ea2a175..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/hickle-3.4.3-py3.6.egg/tests/test_scipy.py +++ /dev/null @@ -1,57 +0,0 @@ -import numpy as np -from scipy.sparse import csr_matrix, csc_matrix, bsr_matrix - -import hickle -from hickle.loaders.load_scipy import check_is_scipy_sparse_array - -from py.path import local - -# Set the current working directory to the temporary directory -local.get_temproot().chdir() - - -def test_is_sparse(): - sm0 = csr_matrix((3, 4), dtype=np.int8) - sm1 = csc_matrix((1, 2)) - - assert check_is_scipy_sparse_array(sm0) - assert check_is_scipy_sparse_array(sm1) - - -def test_sparse_matrix(): - sm0 = csr_matrix((3, 4), dtype=np.int8).toarray() - - row = np.array([0, 0, 1, 2, 2, 2]) - col = np.array([0, 2, 2, 0, 1, 2]) - data = np.array([1, 2, 3, 4, 5, 6]) - sm1 = csr_matrix((data, (row, col)), shape=(3, 3)) - sm2 = csc_matrix((data, (row, col)), shape=(3, 3)) - - indptr = np.array([0, 2, 3, 6]) - indices = np.array([0, 2, 2, 0, 1, 2]) - data = np.array([1, 2, 3, 4, 5, 6]).repeat(4).reshape(6, 2, 2) - sm3 = bsr_matrix((data,indices, indptr), shape=(6, 6)) - - hickle.dump(sm1, 'test_sp.h5') - sm1_h = hickle.load('test_sp.h5') - hickle.dump(sm2, 'test_sp2.h5') - sm2_h = hickle.load('test_sp2.h5') - hickle.dump(sm3, 'test_sp3.h5') - sm3_h = hickle.load('test_sp3.h5') - - assert isinstance(sm1_h, csr_matrix) - assert isinstance(sm2_h, csc_matrix) - assert isinstance(sm3_h, bsr_matrix) - - assert np.allclose(sm1_h.data, sm1.data) - assert np.allclose(sm2_h.data, sm2.data) - assert np.allclose(sm3_h.data, sm3.data) - - assert sm1_h. shape == sm1.shape - assert sm2_h. shape == sm2.shape - assert sm3_h. shape == sm3.shape - - -if __name__ == "__main__": - test_sparse_matrix() - test_is_sparse() \ No newline at end of file diff --git a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/site.py b/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/site.py deleted file mode 100644 index 0d2d2ff8da3960ecdaa6591fcee836c186fb8c91..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/hickle/lib/python3.6/site-packages/site.py +++ /dev/null @@ -1,74 +0,0 @@ -def __boot(): - import sys - import os - PYTHONPATH = os.environ.get('PYTHONPATH') - if PYTHONPATH is None or (sys.platform == 'win32' and not PYTHONPATH): - PYTHONPATH = [] - else: - PYTHONPATH = PYTHONPATH.split(os.pathsep) - - pic = getattr(sys, 'path_importer_cache', {}) - stdpath = sys.path[len(PYTHONPATH):] - mydir = os.path.dirname(__file__) - - for item in stdpath: - if item == mydir or not item: - continue # skip if current dir. on Windows, or my own directory - importer = pic.get(item) - if importer is not None: - loader = importer.find_module('site') - if loader is not None: - # This should actually reload the current module - loader.load_module('site') - break - else: - try: - import imp # Avoid import loop in Python >= 3.3 - stream, path, descr = imp.find_module('site', [item]) - except ImportError: - continue - if stream is None: - continue - try: - # This should actually reload the current module - imp.load_module('site', stream, path, descr) - finally: - stream.close() - break - else: - raise ImportError("Couldn't find the real 'site' module") - - known_paths = dict([(makepath(item)[1], 1) for item in sys.path]) # 2.2 comp - - oldpos = getattr(sys, '__egginsert', 0) # save old insertion position - sys.__egginsert = 0 # and reset the current one - - for item in PYTHONPATH: - addsitedir(item) - - sys.__egginsert += oldpos # restore effective old position - - d, nd = makepath(stdpath[0]) - insert_at = None - new_path = [] - - for item in sys.path: - p, np = makepath(item) - - if np == nd and insert_at is None: - # We've hit the first 'system' path entry, so added entries go here - insert_at = len(new_path) - - if np in known_paths or insert_at is None: - new_path.append(item) - else: - # new path after the insert point, back-insert it - new_path.insert(insert_at, item) - insert_at += 1 - - sys.path[:] = new_path - - -if __name__ == 'site': - __boot() - del __boot diff --git a/video_prediction_tools/external_package/lpips-tensorflow/.gitignore b/video_prediction_tools/external_package/lpips-tensorflow/.gitignore deleted file mode 100644 index 894a44cc066a027465cd26d634948d56d13af9af..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/lpips-tensorflow/.gitignore +++ /dev/null @@ -1,104 +0,0 @@ -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. -*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -.hypothesis/ -.pytest_cache/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# pyenv -.python-version - -# celery beat schedule file -celerybeat-schedule - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ diff --git a/video_prediction_tools/external_package/lpips-tensorflow/.gitmodules b/video_prediction_tools/external_package/lpips-tensorflow/.gitmodules deleted file mode 100644 index 085c5852ff85afe688333807a8a26392d20e8ed3..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/lpips-tensorflow/.gitmodules +++ /dev/null @@ -1,3 +0,0 @@ -[submodule "PerceptualSimilarity"] - path = PerceptualSimilarity - url = https://github.com/alexlee-gk/PerceptualSimilarity.git diff --git a/video_prediction_tools/external_package/lpips-tensorflow/README.md b/video_prediction_tools/external_package/lpips-tensorflow/README.md deleted file mode 100644 index 760f5a028e2aae7187135c267edca89db464db5f..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/lpips-tensorflow/README.md +++ /dev/null @@ -1,57 +0,0 @@ -# lpips-tensorflow -Tensorflow port for the [PyTorch](https://github.com/richzhang/PerceptualSimilarity) implementation of the [Learned Perceptual Image Patch Similarity (LPIPS)](http://richzhang.github.io/PerceptualSimilarity/) metric. -This is done by exporting the model from PyTorch to ONNX and then to TensorFlow. - -## Getting started -### Installation -- Clone this repo. -```bash -git clone https://github.com/alexlee-gk/lpips-tensorflow.git -cd lpips-tensorflow -``` -- Install TensorFlow and dependencies from http://tensorflow.org/ -- Install other dependencies. -```bash -pip install -r requirements.txt -``` - -### Using the LPIPS metric -The `lpips` TensorFlow function works with individual images or batches of images. -It also works with images of any spatial dimensions (but the dimensions should be at least the size of the network's receptive field). -This example computes the LPIPS distance between batches of images. -```python -import numpy as np -import tensorflow as tf -import lpips_tf - -batch_size = 32 -image_shape = (batch_size, 64, 64, 3) -image0 = np.random.random(image_shape) -image1 = np.random.random(image_shape) -image0_ph = tf.placeholder(tf.float32) -image1_ph = tf.placeholder(tf.float32) - -distance_t = lpips_tf.lpips(image0_ph, image1_ph, model='net-lin', net='alex') - -with tf.Session() as session: - distance = session.run(distance_t, feed_dict={image0_ph: image0, image1_ph: image1}) -``` - -## Exporting additional models -### Export PyTorch model to TensorFlow through ONNX -- Clone the PerceptualSimilarity submodule and add it to the PYTHONPATH. -```bash -git submodule update --init --recursive -export PYTHONPATH=PerceptualSimilarity:$PYTHONPATH -``` -- Install more dependencies. -```bash -pip install -r requirements-dev.txt -``` -- Export the model to ONNX *.onnx and TensorFlow *.pb files in the `models` directory. -```bash -python export_to_tensorflow.py --model net-lin --net alex -``` - -### Known issues -- The SqueezeNet model cannot be exported since ONNX cannot export one of the operators. diff --git a/video_prediction_tools/external_package/lpips-tensorflow/export_to_tensorflow.py b/video_prediction_tools/external_package/lpips-tensorflow/export_to_tensorflow.py deleted file mode 100644 index 091c2273176f09a686fbef015da9132e5f58c2d8..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/lpips-tensorflow/export_to_tensorflow.py +++ /dev/null @@ -1,62 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: BSD-2-Clause - -import argparse -import os - -import onnx -import torch -import torch.onnx - -from models import dist_model as dm - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--model', choices=['net-lin', 'net'], default='net-lin', help='net-lin or net') - parser.add_argument('--net', choices=['squeeze', 'alex', 'vgg'], default='alex', help='squeeze, alex, or vgg') - parser.add_argument('--version', type=str, default='0.1') - parser.add_argument('--image_height', type=int, default=64) - parser.add_argument('--image_width', type=int, default=64) - args = parser.parse_args() - - model = dm.DistModel() - model.initialize(model=args.model, net=args.net, use_gpu=False, version=args.version) - print('Model [%s] initialized' % model.name()) - - dummy_im0 = torch.Tensor(1, 3, args.image_height, args.image_width) # image should be RGB, normalized to [-1, 1] - dummy_im1 = torch.Tensor(1, 3, args.image_height, args.image_width) - - cache_dir = os.path.expanduser('~/.lpips') - os.makedirs(cache_dir, exist_ok=True) - onnx_fname = os.path.join(cache_dir, '%s_%s_v%s.onnx' % (args.model, args.net, args.version)) - - # export model to onnx format - torch.onnx.export(model.net, (dummy_im0, dummy_im1), onnx_fname, verbose=True) - - # load and change dimensions to be dynamic - model = onnx.load(onnx_fname) - for dim in (0, 2, 3): - model.graph.input[0].type.tensor_type.shape.dim[dim].dim_param = '?' - model.graph.input[1].type.tensor_type.shape.dim[dim].dim_param = '?' - - # needs to be imported after all the pytorch stuff, otherwise this causes a segfault - from onnx_tf.backend import prepare - tf_rep = prepare(model, device='CPU') - producer_version = tf_rep.graph.graph_def_versions.producer - pb_fname = os.path.join(cache_dir, '%s_%s_v%s_%d.pb' % (args.model, args.net, args.version, producer_version)) - tf_rep.export_graph(pb_fname) - input0_name, input1_name = [tf_rep.tensor_dict[input_name].name for input_name in tf_rep.inputs] - (output_name,) = [tf_rep.tensor_dict[output_name].name for output_name in tf_rep.outputs] - - # ensure these are the names of the 2 inputs, since that will be assumed when loading the pb file - assert input0_name == '0:0' - assert input1_name == '1:0' - # ensure that the only output is the output of the last op in the graph, since that will be assumed later - (last_output_name,) = [output.name for output in tf_rep.graph.get_operations()[-1].outputs] - assert output_name == last_output_name - - -if __name__ == '__main__': - main() diff --git a/video_prediction_tools/external_package/lpips-tensorflow/lpips_tf.py b/video_prediction_tools/external_package/lpips-tensorflow/lpips_tf.py deleted file mode 100644 index da8773a94f46821a22906de32a2a81627f06506f..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/lpips-tensorflow/lpips_tf.py +++ /dev/null @@ -1,94 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: BSD-2-Clause - -import os -import sys - -import tensorflow as tf -from six.moves import urllib - -_URL = 'http://rail.eecs.berkeley.edu/models/lpips' - - -def _download(url, output_dir): - """Downloads the `url` file into `output_dir`. - - Modified from https://github.com/tensorflow/models/blob/master/research/slim/datasets/dataset_utils.py - """ - filename = url.split('/')[-1] - filepath = os.path.join(output_dir, filename) - - def _progress(count, block_size, total_size): - sys.stdout.write('\r>> Downloading %s %.1f%%' % ( - filename, float(count * block_size) / float(total_size) * 100.0)) - sys.stdout.flush() - - filepath, _ = urllib.request.urlretrieve(url, filepath, _progress) - print() - statinfo = os.stat(filepath) - print('Successfully downloaded', filename, statinfo.st_size, 'bytes.') - - -def lpips(input0, input1, model='net-lin', net='alex', version=0.1): - """ - Learned Perceptual Image Patch Similarity (LPIPS) metric. - - Args: - input0: An image tensor of shape `[..., height, width, channels]`, - with values in [0, 1]. - input1: An image tensor of shape `[..., height, width, channels]`, - with values in [0, 1]. - - Returns: - The Learned Perceptual Image Patch Similarity (LPIPS) distance. - - Reference: - Richard Zhang, Phillip Isola, Alexei A. Efros, Eli Shechtman, Oliver Wang. - The Unreasonable Effectiveness of Deep Features as a Perceptual Metric. - In CVPR, 2018. - """ - # flatten the leading dimensions - batch_shape = tf.shape(input0)[:-3] - input0 = tf.reshape(input0, tf.concat([[-1], tf.shape(input0)[-3:]], axis=0)) - input1 = tf.reshape(input1, tf.concat([[-1], tf.shape(input1)[-3:]], axis=0)) - # NHWC to NCHW - input0 = tf.transpose(input0, [0, 3, 1, 2]) - input1 = tf.transpose(input1, [0, 3, 1, 2]) - # normalize to [-1, 1] - input0 = input0 * 2.0 - 1.0 - input1 = input1 * 2.0 - 1.0 - - input0_name, input1_name = '0:0', '1:0' - - default_graph = tf.get_default_graph() - producer_version = default_graph.graph_def_versions.producer - - cache_dir = os.path.expanduser('~/.lpips') - os.makedirs(cache_dir, exist_ok=True) - # files to try. try a specific producer version, but fallback to the version-less version (latest). - pb_fnames = [ - '%s_%s_v%s_%d.pb' % (model, net, version, producer_version), - '%s_%s_v%s.pb' % (model, net, version), - ] - for pb_fname in pb_fnames: - if not os.path.isfile(os.path.join(cache_dir, pb_fname)): - try: - _download(os.path.join(_URL, pb_fname), cache_dir) - except urllib.error.HTTPError: - pass - if os.path.isfile(os.path.join(cache_dir, pb_fname)): - break - - with open(os.path.join(cache_dir, pb_fname), 'rb') as f: - graph_def = tf.GraphDef() - graph_def.ParseFromString(f.read()) - _ = tf.import_graph_def(graph_def, - input_map={input0_name: input0, input1_name: input1}) - distance, = default_graph.get_operations()[-1].outputs - - if distance.shape.ndims == 4: - distance = tf.squeeze(distance, axis=[-3, -2, -1]) - # reshape the leading dimensions - distance = tf.reshape(distance, batch_shape) - return distance diff --git a/video_prediction_tools/external_package/lpips-tensorflow/requirements-dev.txt b/video_prediction_tools/external_package/lpips-tensorflow/requirements-dev.txt deleted file mode 100644 index df36766f1d4cc3b378d7aba0164f3fdfab3be1d2..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/lpips-tensorflow/requirements-dev.txt +++ /dev/null @@ -1,4 +0,0 @@ -torch>=0.4.0 -torchvision>=0.2.1 -onnx -onnx-tf diff --git a/video_prediction_tools/external_package/lpips-tensorflow/requirements.txt b/video_prediction_tools/external_package/lpips-tensorflow/requirements.txt deleted file mode 100644 index bc2cbbd1ca22aca38c3fd05d94d07fbf60ed4f6d..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/lpips-tensorflow/requirements.txt +++ /dev/null @@ -1,2 +0,0 @@ -numpy -six diff --git a/video_prediction_tools/external_package/lpips-tensorflow/setup.py b/video_prediction_tools/external_package/lpips-tensorflow/setup.py deleted file mode 100644 index 9fc8d1d35b0b9249c7c0e24dd14efc707b873d92..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/lpips-tensorflow/setup.py +++ /dev/null @@ -1,11 +0,0 @@ -#!/usr/bin/env python - -from distutils.core import setup - -setup( - name='lpips-tf', - description='Tensorflow port for the Learned Perceptual Image Patch Similarity (LPIPS) metric', - author='Alex Lee', - url='https://github.com/alexlee-gk/lpips-tensorflow/', - py_modules=['lpips_tf'] -) diff --git a/video_prediction_tools/external_package/lpips-tensorflow/test_network.py b/video_prediction_tools/external_package/lpips-tensorflow/test_network.py deleted file mode 100644 index 19df3bfa7414c40ae51018279835152a2e6f6cc1..0000000000000000000000000000000000000000 --- a/video_prediction_tools/external_package/lpips-tensorflow/test_network.py +++ /dev/null @@ -1,46 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: BSD-2-Clause - -import argparse - -import cv2 -import numpy as np -import tensorflow as tf - -import lpips_tf - - -def load_image(fname): - image = cv2.imread(fname) - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) - return image.astype(np.float32) / 255.0 - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument('--model', choices=['net-lin', 'net'], default='net-lin', help='net-lin or net') - parser.add_argument('--net', choices=['squeeze', 'alex', 'vgg'], default='alex', help='squeeze, alex, or vgg') - parser.add_argument('--version', type=str, default='0.1') - args = parser.parse_args() - - ex_ref = load_image('./PerceptualSimilarity/imgs/ex_ref.png') - ex_p0 = load_image('./PerceptualSimilarity/imgs/ex_p0.png') - ex_p1 = load_image('./PerceptualSimilarity/imgs/ex_p1.png') - - session = tf.Session() - - image0_ph = tf.placeholder(tf.float32) - image1_ph = tf.placeholder(tf.float32) - lpips_fn = session.make_callable( - lpips_tf.lpips(image0_ph, image1_ph, model=args.model, net=args.net, version=args.version), - [image0_ph, image1_ph]) - - ex_d0 = lpips_fn(ex_ref, ex_p0) - ex_d1 = lpips_fn(ex_ref, ex_p1) - - print('Distances: (%.3f, %.3f)' % (ex_d0, ex_d1)) - - -if __name__ == '__main__': - main() diff --git a/video_prediction_tools/hparams/bair_action_free/ours_deterministic_l1/model_hparams.json b/video_prediction_tools/hparams/bair_action_free/ours_deterministic_l1/model_hparams.json deleted file mode 100644 index 4a1b23edcb68b57dadee82b1c13366afac50a52a..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/bair_action_free/ours_deterministic_l1/model_hparams.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "batch_size": 32, - "lr": 0.001, - "beta1": 0.9, - "beta2": 0.999, - "l1_weight": 1.0, - "l2_weight": 0.0, - "kl_weight": 0.0, - "video_sn_vae_gan_weight": 0.0, - "video_sn_gan_weight": 0.0, - "state_weight": 0.0, - "nz": 0 -} \ No newline at end of file diff --git a/video_prediction_tools/hparams/bair_action_free/ours_deterministic_l2/model_hparams.json b/video_prediction_tools/hparams/bair_action_free/ours_deterministic_l2/model_hparams.json deleted file mode 100644 index 31e7152ae15df5ee33b264f11c88c76c50592185..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/bair_action_free/ours_deterministic_l2/model_hparams.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "batch_size": 32, - "lr": 0.001, - "beta1": 0.9, - "beta2": 0.999, - "l1_weight": 0.0, - "l2_weight": 1.0, - "kl_weight": 0.0, - "video_sn_vae_gan_weight": 0.0, - "video_sn_gan_weight": 0.0, - "state_weight": 0.0, - "nz": 0 -} \ No newline at end of file diff --git a/video_prediction_tools/hparams/bair_action_free/ours_gan/model_hparams.json b/video_prediction_tools/hparams/bair_action_free/ours_gan/model_hparams.json deleted file mode 100644 index 38837822c90f38c6209dfa27019a90ccdf8ea43a..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/bair_action_free/ours_gan/model_hparams.json +++ /dev/null @@ -1,14 +0,0 @@ -{ - "batch_size": 16, - "lr": 0.0002, - "beta1": 0.5, - "beta2": 0.999, - "l1_weight": 100.0, - "l2_weight": 0.0, - "kl_weight": 0.0, - "video_sn_vae_gan_weight": 0.0, - "video_sn_gan_weight": 0.1, - "vae_gan_feature_cdist_weight": 0.0, - "gan_feature_cdist_weight": 10.0, - "state_weight": 0.0 -} \ No newline at end of file diff --git a/video_prediction_tools/hparams/bair_action_free/ours_vae_l1/model_hparams.json b/video_prediction_tools/hparams/bair_action_free/ours_vae_l1/model_hparams.json deleted file mode 100644 index 827757e11b75e720d236417be449b7a301a005ec..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/bair_action_free/ours_vae_l1/model_hparams.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "batch_size": 32, - "lr": 0.001, - "beta1": 0.9, - "beta2": 0.999, - "l1_weight": 1.0, - "l2_weight": 0.0, - "kl_weight": 0.001, - "video_sn_vae_gan_weight": 0.0, - "video_sn_gan_weight": 0.0, - "state_weight": 0.0 -} \ No newline at end of file diff --git a/video_prediction_tools/hparams/bair_action_free/sv2p_time_invariant/model_hparams.json b/video_prediction_tools/hparams/bair_action_free/sv2p_time_invariant/model_hparams.json deleted file mode 100644 index 4fddf0eef1d45dbfa16f098e5e42b12f594132e3..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/bair_action_free/sv2p_time_invariant/model_hparams.json +++ /dev/null @@ -1,12 +0,0 @@ -{ - "batch_size": 32, - "lr": 0.001, - "beta1": 0.9, - "beta2": 0.999, - "l1_weight": 0.0, - "l2_weight": 1.0, - "kl_weight": 0.001, - "video_sn_vae_gan_weight": 0.0, - "video_sn_gan_weight": 0.0, - "state_weight": 0.0 -} \ No newline at end of file diff --git a/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json b/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json index 4f3a43f11a88e1172d4769bee98bbab8e0a7f59b..218a008b1a1e62e4d87b37e68a592196ed1fd473 100644 --- a/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/convLSTM/model_hparams_template.json @@ -6,7 +6,9 @@ "context_frames":12, "loss_fun":"mse", "opt_var": "0", - "shuffle_on_val":true + "shuffle_on_val":true, + "sequence_length": 24, + "shift": 1 } diff --git a/video_prediction_tools/hparams/era5/convLSTM_gan/model_hparams_template.json b/video_prediction_tools/hparams/era5/convLSTM_gan/model_hparams_template.json index a2b9be547d450d49ef230e37821f503838a5dcee..6f8ea27bc86257e7a7b5f8eecdadced4a12ba109 100644 --- a/video_prediction_tools/hparams/era5/convLSTM_gan/model_hparams_template.json +++ b/video_prediction_tools/hparams/era5/convLSTM_gan/model_hparams_template.json @@ -7,7 +7,6 @@ "loss_fun":"rmse", "shuffle_on_val":false, "recon_weight":0.6 - } diff --git a/video_prediction_tools/hparams/gzprcp/convLSTM/model_hparams_template.json b/video_prediction_tools/hparams/gzprcp/convLSTM/model_hparams_template.json new file mode 100644 index 0000000000000000000000000000000000000000..e6177e68848d6c3e282f41793fc2236781dc0b6f --- /dev/null +++ b/video_prediction_tools/hparams/gzprcp/convLSTM/model_hparams_template.json @@ -0,0 +1,12 @@ + +{ + "batch_size": 4, + "lr": 0.001, + "max_epochs":2, + "context_frames":20, + "sequence_length":40, + "loss_fun":"rmse" +} + + + diff --git a/video_prediction_tools/hparams/gzprcp/convLSTM_gan/model_hparams_template.json b/video_prediction_tools/hparams/gzprcp/convLSTM_gan/model_hparams_template.json new file mode 100644 index 0000000000000000000000000000000000000000..3b5c56ad412ed032f8ab27fef3ab19adc855babf --- /dev/null +++ b/video_prediction_tools/hparams/gzprcp/convLSTM_gan/model_hparams_template.json @@ -0,0 +1,14 @@ + +{ + "batch_size": 32, + "lr": 0.001, + "max_epochs":8, + "context_frames":12, + "sequence_length":12, + "loss_fun":"rmse", + "shuffle_on_val":true, + "k":0.01 +} + + + diff --git a/video_prediction_tools/hparams/kth/ours_savp/model_hparams.json b/video_prediction_tools/hparams/gzprcp/savp/model_hparams_template.json similarity index 74% rename from video_prediction_tools/hparams/kth/ours_savp/model_hparams.json rename to video_prediction_tools/hparams/gzprcp/savp/model_hparams_template.json index 66b41f87e3c0f417b492314060121a0bfd01c8f9..7c1ab72eea7ad1b341a66a76c4a88d1524450417 100644 --- a/video_prediction_tools/hparams/kth/ours_savp/model_hparams.json +++ b/video_prediction_tools/hparams/gzprcp/savp/model_hparams_template.json @@ -1,5 +1,5 @@ { - "batch_size": 8, + "batch_size": 16, "lr": 0.0002, "beta1": 0.5, "beta2": 0.999, @@ -11,8 +11,10 @@ "vae_gan_feature_cdist_weight": 10.0, "gan_feature_cdist_weight": 0.0, "state_weight": 0.0, - "nz": 32, - "max_steps":20 + "nz": 16, + "max_epochs":2, + "context_frames":10, + "sequence_length":30 } diff --git a/video_prediction_tools/hparams/kth/ours_deterministic_l1/model_hparams.json b/video_prediction_tools/hparams/kth/ours_deterministic_l1/model_hparams.json deleted file mode 100644 index 4a1b23edcb68b57dadee82b1c13366afac50a52a..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/kth/ours_deterministic_l1/model_hparams.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "batch_size": 32, - "lr": 0.001, - "beta1": 0.9, - "beta2": 0.999, - "l1_weight": 1.0, - "l2_weight": 0.0, - "kl_weight": 0.0, - "video_sn_vae_gan_weight": 0.0, - "video_sn_gan_weight": 0.0, - "state_weight": 0.0, - "nz": 0 -} \ No newline at end of file diff --git a/video_prediction_tools/hparams/kth/ours_deterministic_l2/model_hparams.json b/video_prediction_tools/hparams/kth/ours_deterministic_l2/model_hparams.json deleted file mode 100644 index 31e7152ae15df5ee33b264f11c88c76c50592185..0000000000000000000000000000000000000000 --- a/video_prediction_tools/hparams/kth/ours_deterministic_l2/model_hparams.json +++ /dev/null @@ -1,13 +0,0 @@ -{ - "batch_size": 32, - "lr": 0.001, - "beta1": 0.9, - "beta2": 0.999, - "l1_weight": 0.0, - "l2_weight": 1.0, - "kl_weight": 0.0, - "video_sn_vae_gan_weight": 0.0, - "video_sn_gan_weight": 0.0, - "state_weight": 0.0, - "nz": 0 -} \ No newline at end of file diff --git a/video_prediction_tools/hparams/weatherbench/convLSTM/model_hparams_template.json b/video_prediction_tools/hparams/weatherbench/convLSTM/model_hparams_template.json new file mode 100644 index 0000000000000000000000000000000000000000..218a008b1a1e62e4d87b37e68a592196ed1fd473 --- /dev/null +++ b/video_prediction_tools/hparams/weatherbench/convLSTM/model_hparams_template.json @@ -0,0 +1,15 @@ + +{ + "batch_size": 4, + "lr": 0.001, + "max_epochs":20, + "context_frames":12, + "loss_fun":"mse", + "opt_var": "0", + "shuffle_on_val":true, + "sequence_length": 24, + "shift": 1 +} + + + diff --git a/video_prediction_tools/hparams/weatherbench/convLSTM_gan/model_hparams_template.json b/video_prediction_tools/hparams/weatherbench/convLSTM_gan/model_hparams_template.json new file mode 100644 index 0000000000000000000000000000000000000000..6f8ea27bc86257e7a7b5f8eecdadced4a12ba109 --- /dev/null +++ b/video_prediction_tools/hparams/weatherbench/convLSTM_gan/model_hparams_template.json @@ -0,0 +1,13 @@ + +{ + "batch_size": 4, + "lr": 0.001, + "max_epochs":20, + "context_frames":12, + "loss_fun":"rmse", + "shuffle_on_val":false, + "recon_weight":0.6 +} + + + diff --git a/video_prediction_tools/hparams/weatherbench/mcnet/model_hparams_template.json b/video_prediction_tools/hparams/weatherbench/mcnet/model_hparams_template.json new file mode 100644 index 0000000000000000000000000000000000000000..bc5f8983a5aa6b0b2ba3d560bc4c2391995794a4 --- /dev/null +++ b/video_prediction_tools/hparams/weatherbench/mcnet/model_hparams_template.json @@ -0,0 +1,10 @@ + +{ + "batch_size": 10, + "lr": 0.001, + "max_epochs": 2, + "context_frames": 12 +} + + + diff --git a/video_prediction_tools/hparams/kth/ours_gan/model_hparams.json b/video_prediction_tools/hparams/weatherbench/ours_gan/model_hparams_template.json similarity index 84% rename from video_prediction_tools/hparams/kth/ours_gan/model_hparams.json rename to video_prediction_tools/hparams/weatherbench/ours_gan/model_hparams_template.json index 3d14b63edbf14efca2cefe4703453d899b3fb0fd..0ccf44e6370f765857204317f172c866865b4b35 100644 --- a/video_prediction_tools/hparams/kth/ours_gan/model_hparams.json +++ b/video_prediction_tools/hparams/weatherbench/ours_gan/model_hparams_template.json @@ -11,5 +11,7 @@ "vae_gan_feature_cdist_weight": 0.0, "gan_feature_cdist_weight": 10.0, "state_weight": 0.0, - "nz": 32 -} \ No newline at end of file + "nz": 32, + "max_epochs":2, + "context_frames":12 +} diff --git a/video_prediction_tools/hparams/kth/ours_vae_l1/model_hparams.json b/video_prediction_tools/hparams/weatherbench/ours_vae_l1/model_hparams_template.json similarity index 80% rename from video_prediction_tools/hparams/kth/ours_vae_l1/model_hparams.json rename to video_prediction_tools/hparams/weatherbench/ours_vae_l1/model_hparams_template.json index dee3ce9f8e431d7f7cb46042936cfae3dcfbc6e4..770f9ff516a630ff031b94bb2c8a2b41c1686eec 100644 --- a/video_prediction_tools/hparams/kth/ours_vae_l1/model_hparams.json +++ b/video_prediction_tools/hparams/weatherbench/ours_vae_l1/model_hparams_template.json @@ -9,5 +9,7 @@ "video_sn_vae_gan_weight": 0.0, "video_sn_gan_weight": 0.0, "state_weight": 0.0, - "nz": 32 -} \ No newline at end of file + "nz": 32, + "max_epochs":2, + "context_frames":12 +} diff --git a/video_prediction_tools/hparams/bair_action_free/ours_savp/model_hparams.json b/video_prediction_tools/hparams/weatherbench/savp/model_hparams_template.json similarity index 54% rename from video_prediction_tools/hparams/bair_action_free/ours_savp/model_hparams.json rename to video_prediction_tools/hparams/weatherbench/savp/model_hparams_template.json index a6eea83a19505d374e9d614f48ef8bc72443c0f2..f36e1c0b44279ad2e4f9e741c7bfade0a5aa0a05 100644 --- a/video_prediction_tools/hparams/bair_action_free/ours_savp/model_hparams.json +++ b/video_prediction_tools/hparams/weatherbench/savp/model_hparams_template.json @@ -1,14 +1,22 @@ { - "batch_size": 16, + "batch_size": 32, "lr": 0.0002, "beta1": 0.5, "beta2": 0.999, "l1_weight": 100.0, "l2_weight": 0.0, - "kl_weight": 1.0, + "kl_weight": 0.01, "video_sn_vae_gan_weight": 0.1, "video_sn_gan_weight": 0.1, "vae_gan_feature_cdist_weight": 10.0, "gan_feature_cdist_weight": 0.0, - "state_weight": 0.0 -} \ No newline at end of file + "state_weight": 0.0, + "nz": 16, + "max_epochs":4, + "context_frames": 12, + "opt_var": "0", + "decay_steps":[3000,9000], + "end_lr": 0.00000008 +} + + diff --git a/video_prediction_tools/hparams/weatherbench/vae/model_hparams_template.json b/video_prediction_tools/hparams/weatherbench/vae/model_hparams_template.json new file mode 100644 index 0000000000000000000000000000000000000000..1306627e24bec0888600fb88fcaa937e5f01dbd7 --- /dev/null +++ b/video_prediction_tools/hparams/weatherbench/vae/model_hparams_template.json @@ -0,0 +1,14 @@ + +{ + "batch_size": 10, + "lr": 0.001, + "nz":16, + "max_epochs":2, + "context_frames":12, + "weight_recon":1, + "loss_fun": "rmse", + "shuffle_on_val": true +} + + + diff --git a/video_prediction_tools/main_scripts/main_data_extraction.py b/video_prediction_tools/main_scripts/main_data_extraction.py index 494540eedf9553271ff3a3e9e901a34c0d1236fc..7f38a1a88b41053107d758e38893f9860f0ebe16 100644 --- a/video_prediction_tools/main_scripts/main_data_extraction.py +++ b/video_prediction_tools/main_scripts/main_data_extraction.py @@ -1,234 +1,277 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -__email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong, Amirpasha Mozaffari" -__date__ = "2020-11-10" - -from mpi4py import MPI -import sys -import subprocess -import logging -import time -from utils.external_function import directory_scanner -from utils.external_function import load_distributor -from data_preprocess.prepare_era5_data import * -# How to Run it! -# mpirun -np 6 python mpi_stager_v2.py +import json as js import os -import shutil -from pathlib import Path import argparse +import itertools as it +from pathlib import Path +from typing import Union, get_args +import zipfile as zf +import multiprocessing as mp +import sys +import json + +from data_preprocess.extract_weatherbench import ExtractWeatherbench +from utils.dataset_utils import DATASETS, get_dataset_info + + +# IDEA: type conversion (generic) => params_obj => bounds_checking (ds-specific)/ semantic checking + +def dataset(name): + if name not in DATASETS: + raise ValueError(f"'dataset' must be one of {DATASETS}.") + return name + +def source_dir(directory: str) -> Path: + dir = Path(directory) + if not dir.exists(): + raise ValueError(f"Input directory {dir.absolute()} does not exist") + return dir + + +def destination_dir(directory: str) -> Path: + dir = Path(directory) + if not dir.exists(): + raise ValueError(f"Output directory: {dir.absolute()} does not exist.") + return dir + + +def years(years: str) -> Union[list[int], int]: + try: + year_list = [int(x) for x in years] + except ValueError as e: + if not years == "all": + raise ValueError( + f"years must be either a list of years or 'all', not {months}." + ) + year_list = -1 + + return year_list + + +def months(months: str) -> list[int]: + try: + month_list = [int(x) for x in months] + except ValueError as e: + if months == "DJF": + month_list = [1, 2, 12] + elif months == "MAM": + month_list = [3, 4, 5] + elif months == "JJA": + month_list = [6, 7, 8] + elif months == "SON": + month_list = [9, 10, 11] + elif months == "all": + month_list = list(range(1, 13)) + else: + raise ValueError( + f"months-string '{months}' cannot be converted to list of months" + ) + + if not all(1 <= m <= 12 for m in month_list): + errors = filter(lambda m: not 1 <= m <= 12, month_list) + raise ValueError( + f"all month integers must be within 1, ..., 12 not {list(errors)}" + ) + + return month_list + + +def variables(variables: str) -> list[dict]: + var_list = json.loads(variables) + + attributes = {"name", "lvl", "interpolation"} + interpolations = {"p", "z"} + + for var in var_list: + if not var.keys() == attributes: + raise ValueError(f"each variable should have the attributes {attributes}") + if not type(var["name"]) == str: + raise ValueError(f"'name' should be of type string not {type(var['name'])}") + if not type(var["lvl"]) == list: + raise ValueError(f"'lvl' should be of type list not {type(var['lvl'])}") + if not var["interpolation"] in interpolations: + raise ValueError(f"value of 'interpolation' should be one of {interpolations} not {var['interpolation']}") + if len(var["lvl"]) == 0: + raise ValueError(f"'lvl' should have at least one entry") + if not all(type(lvl) == int for lvl in var["lvl"]): + raise ValueError(f"all entries of 'lvl' should be of type int") + + return var_list + + +def get_data_files(variables: list, years, resolution, dirin: Path): + """ + Get path to zip files and names of the yearly files within. + :param variables: list of variables + :param years: list of years + :param months: list of months + :param resolution: + :param dirin: input directory + :return lists paths to zips of variables + """ + data_files = [] + zip_files = [] + res_str = f"{resolution}deg" + for var in variables: + var_dir = dirin / res_str / var + if not var_dir.exists(): + raise ValueError( + f"variable {var} is not available for resolution {res_str}" + ) + + zip_file = var_dir / f"{var}_{res_str}.zip" + with zf.ZipFile(zip_file, "r") as myzip: + names = myzip.namelist() + if not all(any(str(year) in name for name in names) for year in years): + raise ValueError( + f"variable {var} is not available for all years: {years}" + ) + names = filter(lambda name: any(str(year) in name for year in years), names) + + data_files.append(list(names)) + zip_files.append(zip_file) + + return zip_files, data_files + + +def nyx(nyx): + try: + nyx = [int(n) for n in nyx] + except ValueError as e: + raise ValueError(f"number of grid points should be integers not {nyx}") + if not all(n > 0 for n in nyx): + raise ValueError(f"number of grid points should be > 0") + + return nyx + + +def coords(coords_sw): + try: + coords = [float(c) for c in coords_sw] + except ValueError as e: + raise ValueError(f"coordinates should be floats not {coords}") + if not -90 <= coords[0] <= 90: + raise ValueError( + f"latitude of sw-corner is {coords[0]} but should be >= -90, <= 90" + ) + if not 0 <= coords[1] <= 360: + raise ValueError( + f"latitude of sw-corner is {coords[0]} but should be >= 0, <= 360" + ) + return coords + def main(): - current_path = os.getcwd() - parser=argparse.ArgumentParser() - parser.add_argument("--source_dir",type=str,default="//home/a.mozaffari/data_era5/2017/") - parser.add_argument("--target_dir",type=str,default="/home/a.mozaffari/data_dest") - parser.add_argument("--logs_path",type=str,default=current_path) - parser.add_argument("--year",type=str,default="2007") - parser.add_argument("--varslist_path",type=str) - args = parser.parse_args() - # for the local machine test - current_path = os.getcwd() - src_top = args.source_dir - source_dir = os.path.join(args.source_dir,args.year) + "/" - target_dir = args.target_dir - destination_dir = os.path.join(target_dir,args.year) - logs_path = args.logs_path - year = args.year - varslist_path = args.varslist_path - os.chdir(current_path) - # ini. MPI - comm = MPI.COMM_WORLD - my_rank = comm.Get_rank() # rank of the node - p = comm.Get_size() # number of assigned nods + # TODO consult Bing for defaults + parser = argparse.ArgumentParser() + parser.add_argument( + "dataset", + type=dataset, + help="Name of the dataset" + ) + parser.add_argument( + "source_dir", + type=source_dir, + help="Top-level directory where ERA5 grib-files are located under <year>/<month>.", + ) + parser.add_argument( + "destination_dir", + type=destination_dir, + help="Destination directory where the netCDF-files will be stored", + ) + parser.add_argument( + "years", nargs="+", type=int, help="Years of data to be processed." + ) + parser.add_argument( + "variables", + help="list of variables to extract", + type=variables, + ) + parser.add_argument( + "--resolution", + "-r", + choices=[1.40625, 2.8125, 5.625], + default=5.625, + ) + parser.add_argument( + "--months", + "-m", + nargs="+", + dest="months", + default="all", + type=months, + help="Months of data. Can also be 'all' or season-strings, e.g. 'DJF'.", + ) + parser.add_argument( + "--sw_corner", + "-swc", + dest="sw_corner", + nargs="+", + type=coords, + default=(0.0, 0.0), + help="Defines south-west corner of target domain (lat, lon)=(-90..90, 0..360)", + ) + parser.add_argument( + "--nyx", + "-nyx", + dest="nyx", + nargs="+", + type=nyx, + default=(10, 20), + help="Number of grid points in zonal and meridional direction.", + ) - # ============ configuration for data preprocessing =================== # + args = parser.parse_args() - # ==================================== Master Logging ==================================================== # - # DEBUG: Detailed information, typically of interest only when diagnosing problems. - # INFO: Confirmation that things are working as expected. - # WARNING: An indication that something unexpected happened, or indicative of some problem in the near - # ERROR: Due to a more serious problem, the software has not been able to perform some function. - # CRITICAL: A serious error, indicating that the program itself may be unable to continue running. + # check if north-east corner is valid + ne_corner = [ + coord + n * args.resolution for coord, n in zip(args.sw_corner, args.nyx) + ] + if not (-90 <= ne_corner[0] <= 90 and 0 <= ne_corner[1] <= 360): + raise ValueError( + f"number of grid points {args.nyx} will result in a invalid north-east corner: {ne_corner}" + ) + + # check if arguments can be provided by dataset + dataset_info = get_dataset_info(args.dataset) + + years_not_avail = [year not in dataset_info["years"] for year in args.years] + vars_avail_map = {var["name"]: var for var in dataset_info["variables"]} + + for variable in args.variables: + try: + var = vars_avail_map[variable["name"]] + except KeyError as e: + raise ValueError(f"variable {variable['name']} is not available for dataset {args.dataset}.") + + lvl_not_avail = list(filter(lambda l: not l in var["lvl"], variable["lvl"])) + if len(lvl_not_avail) > 0: + raise ValueError(f"variable {variable['name']} at lvl {lvl_not_avail} is not available for dataset {args.dataset}.") + - if my_rank == 0: # node is master - logs_path = logs_path + '/logs/' - if not os.path.exists(logs_path): - os.mkdir(logs_path) + # get extraction instance + if args.dataset == "weatherbench": + extraction = ExtractWeatherbench( + args.source_dir, + args.destination_dir, + args.variables, + args.years, + args.months, + (args.sw_corner[0], ne_corner[0]), + (args.sw_corner[1], ne_corner[1]), + args.resolution, + ) + elif args.dataset == "era5": + extraction = NewEra5Extraction() + else: + raise ValueError("no other extractor.") + + print("initialized extraction") - logger_path_main = logs_path + 'Main_log.log' - if os.path.exists(logger_path_main): - print("Logger Exists -> Logger Deleted") - os.remove(logger_path_main) - - logging.basicConfig(filename=logger_path_main, level=logging.DEBUG, - format='%(asctime)s:%(levelname)s:%(message)s') - logger = logging.getLogger(__file__) - logger.addHandler(logging.StreamHandler(sys.stdout)) - #start = time.time() # start of the MPI - logger.info(' === PyStager is started === ') - - # ================================== ALL Nodes: Read-in parameters ====================================== # - - # check the existence of teh folders : - - if not os.path.exists(source_dir): # check if the source dir. is existing - if my_rank == 0: - logger.critical('The source does not exist') - logger.info('exit status : 1') - - sys.exit(1) - - if not os.path.exists(destination_dir): # check if the Destination dir. is existing - if my_rank == 0: - logger.critical('The Destination does not exist') - logger.info('Create a Destination dir') - if not os.path.exists(destination_dir): os.makedirs(destination_dir) - - if os.path.exists(destination_dir): - if my_rank == 0: - os.makedirs(destination_dir, exist_ok=True) - shutil.rmtree(destination_dir) - logger.critical('The destination exist -> Remove and Re-Create') - - - if my_rank == 0: # node is master - - # ==================================== Master : Directory scanner ================================= # - - print(" # ============== Directory scanner : start ==================# ") - - ret_dir_scanner = directory_scanner(source_dir) - print(ret_dir_scanner) - - dir_detail_list = ret_dir_scanner[0] - sub_dir_list = ret_dir_scanner[1] - total_size_source = ret_dir_scanner[2] - total_num_files = ret_dir_scanner[3] - total_num_dir = ret_dir_scanner[4] - - # =================================== Master : Load Distribution ========================== # - - print(" # ============== Load Distrbution : start ==================# ") - #def load_distributor(dir_detail_list, sub_dir_list, total_size_source, total_num_files, total_num_directories, p): - ret_load_balancer = load_distributor(dir_detail_list, sub_dir_list, total_size_source, total_num_files, total_num_dir, p) - transfer_dict = ret_load_balancer - - - print(ret_load_balancer) - - # ===================================== Main : Send / Receive =============================== # - print(" # ============== Communication : start ==================# ") - - # Send : the list of the directories to the nodes - for nodes in range(1, p): - broadcast_list = transfer_dict[nodes] - comm.send(broadcast_list, dest=nodes) - - # All Receive - message_counter = 1 - while message_counter < p: # non-blocking receive function - message_in = comm.recv() - Worker_status = message_in[0:5] - worker_number = message_in[5:7] - # Idle check Worker_status - if Worker_status == "IDLEE": - status = ' An Idle worker is detected, worker number is: {worker_number}'.format(worker_number=worker_number) - logger.info(status) - # Success - elif Worker_status == "PASSS": - status =' A job process is finished by worker: {worker_number}'.format(worker_number=worker_number) - logger.info(status) - - # Non-Fatal Error - elif Worker_status == "NEROR": - status =' A non-fatal error is triggered by worker: {worker_number}'.format(worker_number=worker_number) - logger.warning(status) - logger.warning("System will continue") - - # Fatal Error - elif Worker_status == "FEROR": - status =' A fatal error is triggered by worker: {worker_number}'.format(worker_number=worker_number) - logger.critical(status) - logger.critical("System is going to terminate") - sys.exit(1) - - # System fail to recogonise the meesage - else: - status =' A message from {worker_number} is not readable by main'.format(worker_number=worker_number) - logger.critical(status) - logger.critical("System is going to terminate") - sys.exit(1) - - message_counter = message_counter + 1 - - logger.info(' Main is finished the job and it will terminate the task') - sys.exit(0) - - else: # node is slave - - # ============================================= worker: Send / Receive ============================================ # - # communication works as a break to stop worker before master is ready - message_in = comm.recv() - - # worker logger file - worker_log = logs_path + '/logs/' + 'Worker_log_{my_rank}.log'.format(my_rank=my_rank) - if os.path.exists(worker_log): - os.remove(worker_log) - - logging.basicConfig(filename=worker_log, level=logging.DEBUG, - format='%(asctime)s:%(levelname)s:%(message)s') - logger = logging.getLogger(__file__) - logger.addHandler(logging.StreamHandler(sys.stdout)) - logger.info('Woker logger is activated') - - # Receive message - if message_in is None: # in case more than number of the dir. processor is assigned todo Tag it! - message_out = ("IDLEE{worker_rank}: is IDLE ".format(worker_rank=my_rank)) - logger.info('Worker {worker_rank} is idle'.format(worker_rank=my_rank)) - logger.info('Worker {worker_rank} is terminated'.format(worker_rank=my_rank)) - comm.send(message_out, dest=0) - sys.exit(0) - - else: # if the Worker node has joblist to do - job_list = message_in.split(';') - logger.info('Worker {worker_rank} to do list is : {to_do_list}'.format(worker_rank=my_rank,to_do_list=job_list)) - - for job_count in range(0, len(job_list)): - job = job_list[job_count] # job is the name of the directory(ies) assigned to worker - logger.info('Worker {worker_rank} next job to do is : {job}'.format(worker_rank=my_rank,job=job)) - - logger.debug('Worker {worker_rank} is starting the ERA5-preproc. on dir.: {job}'.format(worker_rank=my_rank,job=job)) - - era5_case = ERA5DataExtraction(year,job,src_top,target_dir,varslist_path) - worker_status = era5_case.process_era5_in_dir() - - logger.debug('worker status is: {worker_status}'.format(worker_status=worker_status)) - - if worker_status == -1: - message_out = ("FEROR{worker_rank}:Failed is triggered ".format(worker_rank=my_rank)) - logger.critical('progress is unsuccessful. fatal-error is observed. Worker is terminating and communicating the termination of the job to main.') - comm.send(message_out, dest=0) - sys.exit(1) - - if worker_status == 0: - logger.debug('progress is successful') - message_out = ("PASSS{worker_rank}:is finished".format(worker_rank=my_rank)) - logger.info('Worker {worker_rank} finished a task'.format(worker_rank=my_rank)) - - if worker_status == +1: - logger.debug('progress is not successful, but not-fatal') - message_out = ("NEROR{worker_rank}:Failed is triggered ".format(worker_rank=my_rank)) - logger.warning('Worker {worker_rank} has non-fatal failure,but it is continued'.format(worker_rank=my_rank)) - - comm.send(message_out, dest=0) - sys.exit(0) - - MPI.Finalize() + extraction() + if __name__ == "__main__": + print("start script") + mp.set_start_method("spawn") # fix cuda initalization issue main() diff --git a/video_prediction_tools/main_scripts/main_era5_data_extraction.py b/video_prediction_tools/main_scripts/main_era5_data_extraction.py new file mode 100644 index 0000000000000000000000000000000000000000..82a5a314bf14ad3a547e4d4bd1f9a39d7c15443f --- /dev/null +++ b/video_prediction_tools/main_scripts/main_era5_data_extraction.py @@ -0,0 +1,44 @@ +""" +Driver for preprocessing step 1 which parses the input arguments from the runscript +and performs parallelization with PyStager. +""" + +__email__ = "m.langguth@fz-juelich.de" +__author__ = "Michael Langguth" + +import json as js +import argparse +from data_preprocess.extract_era5_data import Extract_ERA5_data + + +def main(): + + parser = argparse.ArgumentParser() + parser.add_argument("--source_dir", "-src_dir", dest="source_dir", type=str, required=True, + help="Top-level directory where ERA5 grib-files are located under <year>/<month>.") + parser.add_argument("--destination_dir", "-dest_dir", dest="destination_dir", type=str, required=True, + help="Destination directory where the netCDF-files will be stored") + parser.add_argument("--years", "-y", nargs="+", dest="years", type=int, required=True, + help="Years of data to be processed.") + parser.add_argument("--months", "-m", nargs="+", dest="months", default="all", + help="Months of data. Can also be 'all' or season-strings, e.g. 'DJF'.") + parser.add_argument("--variables", "-v", dest="vars_dict", type=js.loads, default='{"2t": {"sfc": ""}}', + help="Dictionary-like string to parse variable names (keys) together with " + + "variable types (values).") + parser.add_argument("--sw_corner", "-swc", dest="sw_corner", nargs="+", default=(38.4, 0.), + help="Defines south-west corner of target domain (lat, lon)=(-90..90, 0..360)") + parser.add_argument("--nyx", "-nyx", dest="nyx", nargs="+", type=int, default=(56, 92), + help="Number of grid points in zonal and meridional direction.") + + args = parser.parse_args() + + # initialize preprocessing instance... + era5_extract = Extract_ERA5_data(args.source_dir, args.destination_dir, args.vars_dict, args.sw_corner, + args.nyx, args.years, args.months) + + # ...and run it + era5_extract() + + +if __name__ == "__main__": + main() diff --git a/video_prediction_tools/main_scripts/main_preprocess_data_step1.py b/video_prediction_tools/main_scripts/main_preprocess_data_step1.py deleted file mode 100755 index 0c675df6b6295b053aa9ebb97bd62f768b31bcfe..0000000000000000000000000000000000000000 --- a/video_prediction_tools/main_scripts/main_preprocess_data_step1.py +++ /dev/null @@ -1,214 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -""" -Driver for preprocessing step 1 which parses the input arguments from the runscript -and performs parallelization with PyStager. -""" - -__email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong, Scarlet Stadtler,Michael Langguth" - -from mpi4py import MPI -import os, sys, glob -import logging -import time -import argparse -from utils.external_function import directory_scanner -from utils.external_function import load_distributor -from data_preprocess.process_netCDF_v2 import * -from metadata import MetaData -from netcdf_datahandling import GeoSubdomain -import json - - -def main(): - - parser = argparse.ArgumentParser() - parser.add_argument("--source_dir", "-src_dir", dest="source_dir", type=str, - help="Directory where input netCDF-files are located.") - parser.add_argument("--destination_dir", "-dest_dir", dest="destination_dir", type=str, - help="Destination directory where pickle-files be saved. Note that the complete path is auto-" - "completed during runtime.") - parser.add_argument("--years", "-y", dest="years", help="Year of data to be processed.") - parser.add_argument("--rsync_status", type=int, default=1) - parser.add_argument("--vars", nargs="+", default=["2t", "2t", "2t"], help="Variables to be processed.") - parser.add_argument("--sw_corner", "-swc", dest="sw_corner", nargs="+", help="Defines south-west corner of target domain " + - "(lat, lon)=(-90..90, 0..360)") - parser.add_argument("--nyx", "-nyx", dest="nyx", nargs="+", help="Number of grid points in zonal and meridional direction.") - parser.add_argument("--experimental_id", "-exp_id", dest="exp_id", type=str, default="dummy", - help="Experimental identifier helping to distinguish between different experiments.") - args = parser.parse_args() - - current_path = os.getcwd() - years = args.years - source_dir = args.source_dir - source_dir_full = os.path.join(source_dir, str(years))+"/" - destination_dir = args.destination_dir - rsync_status = args.rsync_status - - vars1 = args.vars - sw_c = [float(f) for f in args.sw_corner] - nyx = [int(i) for i in args.nyx] - print("Selected variables", vars1) - - exp_id = args.exp_id - - os.chdir(current_path) - time.sleep(0) - - # ini. MPI - comm = MPI.COMM_WORLD - my_rank = comm.Get_rank() # rank of the node - p = comm.Get_size() # number of assigned nods - - # ============ configuration for data preprocessing =================== # - # ==================================== Master Logging ==================================================== # - # DEBUG: Detailed information, typically of interest only when diagnosing problems. - # INFO: Confirmation that things are working as expected. - # WARNING: An indication that something unexpected happened, or indicative of some problem in the near - # ERROR: Due to a more serious problem, the software has not been able to perform some function. - # CRITICAL: A serious error, indicating that the program itself may be unable to continue running. - - if my_rank == 0: # node is master - logging.basicConfig(filename='stager.log', level=logging.DEBUG, - format='%(asctime)s:%(levelname)s:%(message)s') - start = time.time() # start of the MPI - logging.debug(' === PyStager is started === ') - print('PyStager is Running .... ') - # ================================== ALL Nodes: Read-in parameters ====================================== # - - # check the existence of teh folders : - if not os.path.exists(source_dir_full): # check if the source dir. is existing - if my_rank == 0: - logging.critical('The source does not exist') - logging.info('exit status : 1') - print('Critical : The source does not exist') - - sys.exit(1) - - # Expand destination_dir-variable by searching for netCDF-files in source_dir - # and processing the file from the first list element to obtain all relevant (meta-)data. - data_files_list = glob.iglob(source_dir_full+"/**/*.nc", recursive=True) - try: - data_file = next(data_files_list) - except StopIteration: - raise FileNotFoundError("Could not find any data to be processed in '{0}'".format(source_dir_full)) - - tar_dom = GeoSubdomain(sw_c, nyx, data_file) - - if my_rank == 0: - md = MetaData(suffix_indir=destination_dir, exp_id=exp_id, data_filename=data_file, tar_dom=tar_dom, - variables=vars1) - - if md.status == "old": # meta-data file already exists and is ok - # check for temp.json in working directory (required by slave nodes) - tmp_file = os.path.join(current_path, "temp.json") - if os.path.isfile(tmp_file): - os.remove(tmp_file) - mess_tmp_file = "Auxiliary file '"+tmp_file+"' already exists, but is cleaned up to be updated" + \ - " for safety reasons." - logging.info(mess_tmp_file) - - # ML 2020/06/08: Dirty workaround as long as data-splitting is done with a seperate Python-script - # called from the same parent Shell-/Batch-script - # -> work with temproary json-file in working directory - # create or update temp.json, respectively - md.write_destdir_jsontmp(os.path.join(md.expdir, md.expname), tmp_dir=current_path) - - # expand destination directory by pickle-subfolder and... - destination_dir = os.path.join(md.expdir, md.expname, "pickle", years) - - # ...create directory if necessary - if not os.path.exists(destination_dir): # check if the Destination dir. is existing - logging.critical('The Destination does not exist') - logging.info('Create new destination dir') - os.makedirs(destination_dir, exist_ok=True) - - with open(os.path.join(md.expdir, md.expname, "options.json"), "w") as f: - f.write(json.dumps(vars(args), sort_keys=True, indent=4)) - - if my_rank == 0: # node is master: - # ==================================== Master : Directory scanner ================================= # - - print(" # ============== Directory scanner : start ==================# ") - - ret_dir_scanner = directory_scanner(source_dir_full) - print(ret_dir_scanner) - dir_detail_list = ret_dir_scanner[0] - sub_dir_list = ret_dir_scanner[1] - total_size_source = ret_dir_scanner[2] - total_num_files = ret_dir_scanner[3] - total_num_dir = ret_dir_scanner[4] - - # =================================== Master : Load Distribution ========================== # - - print(" # ============== Load Distrbution : start ==================# ") - - ret_load_balancer = load_distributor(dir_detail_list, sub_dir_list, total_size_source, total_num_files, - total_num_dir, p) - transfer_dict = ret_load_balancer - - print(ret_load_balancer) - # ===================================== Master : Send / Receive =============================== # - print(" # ============== Communication : start ==================# ") - - # Send : the list of the directories to the nodes - for nodes in range(1, p): - broadcast_list = transfer_dict[nodes] - comm.send(broadcast_list, dest=nodes) - - # Receive : will wait for a certain time to see if it will receive any critical error from the slaves nodes - idle_counter = p - len(sub_dir_list) - while idle_counter > 1: # non-blocking receive function - message_in = comm.recv() - logging.warning(message_in) - # print('Warning:', message_in) - idle_counter = idle_counter - 1 - - # Receive : Message from slave nodes confirming the sync - message_counter = 1 - while message_counter <= len(sub_dir_list): # non-blocking receive function - message_in = comm.recv() - logging.info(message_in) - message_counter = message_counter + 1 - - # stamp the end of the runtime - end = time.time() - logging.debug(end - start) - logging.info('== PyStager is done ==') - logging.info('exit status : 0') - print('PyStager is finished ') - sys.exit(0) - - else: # node is slave - - # ========================================== Slave : Send / Receive ========================================= # - message_in = comm.recv() - - if message_in is None: # in case more than number of the dir. processor is assigned todo Tag it! - message_out = ('Node', str(my_rank), 'is idle') - comm.send(message_out, dest=0) - - else: # if the Slave node has joblist to do - job_list = message_in.split(';') - - for job_count in range(0, len(job_list)): - job = job_list[job_count] # job is the name of the directory(ies) assigned to slave_node - # grib_2_netcdf(rot_grid,source_dir, destination_dir, job) - if rsync_status == 1: - # ML 2020/06/09: workaround to get correct destination_dir obtained by the master node - destination_dir = MetaData.get_destdir_jsontmp(tmp_dir=current_path) - process_data = PreprocessNcToPkl(source_dir, destination_dir, years, job, tar_dom, vars1) - process_data() - - # Send : the finish of the sync message back to master node - message_out = ('Node:', str(my_rank), 'finished :', "", '\r\n') - comm.send(message_out, dest=0) - - MPI.Finalize() - - -if __name__ == "__main__": - main() diff --git a/video_prediction_tools/main_scripts/main_preprocess_data_step2.py b/video_prediction_tools/main_scripts/main_preprocess_data_step2.py deleted file mode 100644 index e71b5b58174324a6875da7f49f3d301233e00ecb..0000000000000000000000000000000000000000 --- a/video_prediction_tools/main_scripts/main_preprocess_data_step2.py +++ /dev/null @@ -1,117 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -""" -Driver for preprocessing step 2 which parses the input arguments from the runscript -and performs parallelization with OpenMPI. -""" -__email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong" - -# import modules -import os -import argparse -from mpi4py import MPI -from general_utils import get_unique_vars -from statistics import Calc_data_stat -from data_preprocess.preprocess_data_step2 import * -import warnings - - -def main(): - - method="main_preprocess_data_step2" - - parser = argparse.ArgumentParser() - parser.add_argument("-source_dir", type=str) - parser.add_argument("-dest_dir", type=str) - parser.add_argument("-sequence_length", type=int, default=20) - parser.add_argument("-sequences_per_file", type=int, default=20) - args = parser.parse_args() - input_dir = args.source_dir - ins = ERA5Pkl2Tfrecords(input_dir=input_dir, - dest_dir=args.dest_dir, - sequence_length = args.sequence_length, - sequences_per_file=args.sequences_per_file) - - years, months,years_months = ins.get_years_months() - # ini. MPI - comm = MPI.COMM_WORLD - my_rank = comm.Get_rank() # rank of the node - p = comm.Get_size() # number of assigned nodes - if p < 2: - raise ValueError("%{0}: Preprocessing step 2 must be assigned to at least two tasks.".format(method)) - - if my_rank == 0: - # retrieve final statistics first (not parallelized!) - # some preparatory steps - stat_dir = os.path.dirname(input_dir) - varnames = ins.vars_in - - vars_uni, varsind, nvars = get_unique_vars(varnames) - stat_obj = Calc_data_stat(nvars) # init statistic-instance - - # loop over whole data set (training, dev and test set) to collect the intermediate statistics - print("%{0}: Start collecting statistics from the whole dataset to be processed...".format(method)) - - for year in years: - file_dir = os.path.join(input_dir, year) - for month in months: - if os.path.isfile(os.path.join(file_dir, "stat_" + '{0:02}'.format(month) + ".json")): - # process stat-file: - stat_obj.acc_stat_master(file_dir, int(month)) # process monthly statistic-file - else: - warnings.warn("%{0}: The statistic file for year {1}, month {2} does not exist".format(method, year, month)) - # finalize statistics and write to json-file - stat_obj.finalize_stat_master(vars_uni) - stat_obj.write_stat_json(stat_dir) - - # organize parallelized partioning - real_years_months = [] - for i in range(len(years)): - year = years[i] - for month in years_months[i]: - year_month = "Y_{}_M_{}".format(year, month) - real_years_months.append(year_month) - - broadcast_lists = [list(years), real_years_months] - - for nodes in range(1, p): - comm.send(broadcast_lists, dest=nodes) - - message_counter = 1 - while message_counter <= p-1: - message_in = comm.recv() - message_counter = message_counter + 1 - print("%{0}: Message in from worker: {1} ".format(method, message_in)) - - else: - message_in = comm.recv() - print("%{0}: Message from master to rank {1}: {2} ".format(method, my_rank, message_in)) - - years = list(message_in[0]) - real_years_months = message_in[1] - - for year in years: - year_rank = "Y_{}_M_{}".format(year, my_rank) - if year_rank in real_years_months: - # Initilial instance - ins2 = ERA5Pkl2Tfrecords(input_dir=input_dir, - dest_dir=args.dest_dir, - sequence_length = args.sequence_length, - sequences_per_file=args.sequences_per_file) - # create the tfrecords-files - ins2.read_pkl_and_save_tfrecords(year=year, month=my_rank) - print("%{0}: Year {1} finished".format(method, year)) - else: - print("%{0}: {1} is not in the datasplit_dic, will skip the process".format(method, year_rank)) - message_out = ("Node:", str(my_rank), "finished", "", "\r\n") - print("%{0}: Message out for worker: {1}".format(method, message_out)) - comm.send(message_out, dest=0) - - MPI.Finalize() - - -if __name__ == '__main__': - main() diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index 5242c8848bd5c0fa4ee8927532a2c8e7bf15743d..cf9fc63f01483cc9d0d9fe6925e10fdc2607cd4a 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -1,3 +1,4 @@ +# coding=utf-8 # SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) # SPDX-FileCopyrightText: 2018 Alex X. Lee # @@ -21,17 +22,19 @@ import time import numpy as np import xarray as xr import tensorflow as tf -from model_modules.video_prediction import datasets, models +from model_modules.video_prediction import models +from model_modules.video_prediction.datasets import get_dataset import matplotlib.pyplot as plt import pickle as pkl from model_modules.video_prediction.utils import tf_utils from general_utils import * import math import shutil +from pathlib import Path class TrainModel(object): def __init__(self, input_dir: str = None, output_dir: str = None, datasplit_dict: str = None, - model_hparams_dict: str = None, model: str = None, checkpoint: str = None, dataset: str = None, + model_hparams_dict: str = None, model: str = None, checkpoint: str = None, dataset_name: str = None, gpu_mem_frac: float = 1., seed: int = None, args=None, diag_intv_frac: float = 0.001, frac_start_save: float = None, frac_intv_save: float = None): """ @@ -50,12 +53,12 @@ class TrainModel(object): :param frac_start_save: fraction of total iterations steps to start checkpointing the model :param frac_intv_save: fraction of total iterations steps for checkpointing the model """ - self.input_dir = os.path.normpath(input_dir) - self.output_dir = os.path.normpath(output_dir) + self.input_dir = Path(input_dir).resolve(strict=False) + self.output_dir = Path(output_dir).resolve(strict=False) self.datasplit_dict = datasplit_dict self.model_hparams_dict = model_hparams_dict self.checkpoint = checkpoint - self.dataset = dataset + self.dataset_name = dataset_name self.model = model self.gpu_mem_frac = gpu_mem_frac self.seed = seed @@ -77,7 +80,7 @@ class TrainModel(object): self.make_dataset_iterator() self.setup_model() self.setup_graph() - self.save_dataset_model_params_to_checkpoint_dir(dataset=self.train_dataset,video_model=self.video_model) + self.save_dataset_model_params_to_checkpoint_dir(dataset=self.dataset, video_model=self.video_model) # TODO: resolve potetial incompatibility self.count_parameters() self.create_saver_and_writer() self.setup_gpu_config() @@ -108,10 +111,12 @@ class TrainModel(object): """ Get and read model_hparams_dict from json file to dictionary """ - self.model_hparams_dict_load = {} if self.model_hparams_dict: - with open(self.model_hparams_dict) as f: - self.model_hparams_dict_load.update(json.loads(f.read())) + with open(self.model_hparams_dict, 'r') as f: + self.model_hparams_dict_load = json.loads(f.read()) + else: + raise FileNotFoundError("hparam directory doesn't exist! please check {}!".format(self.model_hparams_dict)) + return self.model_hparams_dict_load def load_params_from_checkpoints_dir(self): @@ -134,7 +139,7 @@ class TrainModel(object): with open(os.path.join(self.checkpoint_dir, "options.json")) as f: print("%{0}: Loading options from checkpoint '{1}'".format(method, self.checkpoint)) self.options = json.loads(f.read()) - self.dataset = self.dataset or self.options['dataset'] + self.dataset_name = self.dataset_name or self.options['dataset'] self.model = self.model or self.options['model'] except FileNotFoundError: print("%{0}: options.json does not exist in {1}".format(method, self.checkpoint_dir)) @@ -154,16 +159,11 @@ class TrainModel(object): self.batch_size = self.model_hparams_dict_load["batch_size"] self.max_epochs = self.model_hparams_dict_load["max_epochs"] # create dataset instance - VideoDataset = datasets.get_dataset_class(self.dataset) - self.train_dataset = VideoDataset(input_dir=self.input_dir, mode='train', datasplit_config=self.datasplit_dict, - hparams_dict_config=self.model_hparams_dict) + + self.dataset = get_dataset(self.dataset_name, input_dir=self.input_dir, output_dir=self.output_dir, datasplit_path=self.datasplit_dict, hparams_path=self.model_hparams_dict, seed=self.seed) + self.calculate_samples_and_epochs() - self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) - # set-up validation dataset and calculate number of batches for calculating validation loss - self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val', datasplit_config=self.datasplit_dict, - hparams_dict_config=self.model_hparams_dict, nsamples_ref=self.num_examples) - # Retrieve sequence length from dataset - self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) + self.model_hparams_dict_load.update({"sequence_length": self.dataset.sequence_length}) def setup_model(self, mode="train"): """ @@ -171,7 +171,7 @@ class TrainModel(object): :param mode: "train" used the model graph in train process; "test" for postprocessing step """ VideoPredictionModel = models.get_model_class(self.model) - self.video_model = VideoPredictionModel(hparams_dict=self.model_hparams_dict_load, mode=mode) + self.video_model = VideoPredictionModel(hparams_dict=self.model_hparams_dict, mode=mode) def setup_graph(self): """ @@ -184,12 +184,12 @@ class TrainModel(object): Prepare the dataset interator for training and validation """ self.batch_size = self.model_hparams_dict_load["batch_size"] - train_tf_dataset = self.train_dataset.make_dataset(self.batch_size) + train_tf_dataset = self.dataset.make_training() train_iterator = train_tf_dataset.make_one_shot_iterator() # The `Iterator.string_handle()` method returns a tensor that can be evaluated # and used to feed the `handle` placeholder. self.train_handle = train_iterator.string_handle() - val_tf_dataset = self.val_dataset.make_dataset(self.batch_size) + val_tf_dataset = self.dataset.make_validation() val_iterator = val_tf_dataset.make_one_shot_iterator() self.val_handle = val_iterator.string_handle() self.iterator = tf.data.Iterator.from_string_handle( @@ -197,7 +197,7 @@ class TrainModel(object): self.inputs = self.iterator.get_next() # since era5 tfrecords include T_start, we need to remove it from the tfrecord when we train SAVP # Otherwise an error will be risen by SAVP - if self.dataset == "era5" and self.model == "savp": + if self.dataset_name == "era5" and self.model == "savp": del self.inputs["T_start"] def save_dataset_model_params_to_checkpoint_dir(self, dataset, video_model): @@ -207,9 +207,9 @@ class TrainModel(object): with open(os.path.join(self.output_dir, "options.json"), "w") as f: f.write(json.dumps(vars(self.args), sort_keys=True, indent=4)) with open(os.path.join(self.output_dir, "dataset_hparams.json"), "w") as f: - f.write(json.dumps(dataset.hparams.values(), sort_keys=True, indent=4)) + f.write(json.dumps(dataset.hparams, sort_keys=True, indent=4)) with open(os.path.join(self.output_dir, "model_hparams.json"), "w") as f: - f.write(json.dumps(video_model.hparams.values(), sort_keys=True, indent=4)) + f.write(json.dumps(video_model.hparams, sort_keys=True, indent=4)) #with open(os.path.join(self.output_dir, "data_dict.json"), "w") as f: # f.write(json.dumps(dataset.data_dict, sort_keys=True, indent=4)) @@ -243,10 +243,11 @@ class TrainModel(object): """ method = TrainModel.calculate_samples_and_epochs.__name__ - self.num_examples = self.train_dataset.num_examples_per_epoch() + self.num_examples = self.dataset.num_training_samples self.steps_per_epoch = int(self.num_examples/self.batch_size) self.total_steps = self.steps_per_epoch * self.max_epochs self.diag_intv_step = int(self.diag_intv_frac*self.total_steps) + if self.diag_intv_step == 0: self.diag_intv_step = 1 else: @@ -255,6 +256,8 @@ class TrainModel(object): .format(method, self.batch_size, self.max_epochs, self.num_examples, self.steps_per_epoch, self.total_steps)) + + def calculate_checkpoint_saver_conf(self): """ Calculate the start step for saving the checkpoint, and the frequences steps to save model @@ -380,16 +383,23 @@ class TrainModel(object): # Final diagnostics: training track time and save to pickle-files) train_time = time.time() - run_start_time - results_dict = {"train_time": train_time, "total_steps": self.total_steps} + + avg_time_first_epoch = np.mean(time_per_iteration[:self.steps_per_epoch]) + avg_time_non_first_epoch = np.mean(time_per_iteration[self.steps_per_epoch:]) + results_dict = {"train_time": train_time, "total_steps": self.total_steps, + "avg_time_first_epoch": avg_time_first_epoch, + "avg_time_non_first_epoch":avg_time_non_first_epoch} + TrainModel.save_results_to_dict(results_dict, self.output_dir) print("%{0}: Training loss decreased from {1:.6f} to {2:.6f}:" .format(method, np.mean(train_losses[0:10]), np.mean(train_losses[-self.diag_intv_step:]))) print("%{0}: Validation loss decreased from {1:.6f} to {2:.6f}:" .format(method, np.mean(val_losses[0:10]), np.mean(val_losses[-self.diag_intv_step:]))) - print("%{0}: Training finsished".format(method)) + print("%{0}: Training finished".format(method)) print("%{0}: Total training time: {1:.2f} min".format(method, train_time/60.)) - + print("%{0}: The average of training time for the first epoch: {1:.2f} sec".format(method, avg_time_first_epoch)) + print("%{0}: The average of training time for after first epoch: {1:.2f} sec".format(method,avg_time_non_first_epoch)) return train_time, time_per_iteration def create_fetches_for_train(self): @@ -507,6 +517,9 @@ class TrainModel(object): elif self.video_model.__class__.__name__ == "VanillaVAEVideoPredictionModel": print("Total_loss:{}; latent_losses:{}; reconst_loss:{}" .format(results["total_loss"], results["latent_loss"], results["recon_loss"])) + elif self.video_model.__class__.__name__ == "ConvLstmGANVideoPredictionModel": + print("Total_loss:{}" + .format(results["total_loss"])) else: print("%{0}: Printing results of model '{1}' is not implemented yet".format(method, self.video_model.__class__.__name__)) @@ -729,7 +742,7 @@ def main(): parser.add_argument("--output_dir", help="Output directory where JSON-files, summary, model, plots etc. are saved.") parser.add_argument("--datasplit_dict", help="JSON-file that contains the datasplit configuration") parser.add_argument("--checkpoint", help="Checkpoint directory or checkpoint name (e.g. <my_dir>/model-200000)") - parser.add_argument("--dataset", type=str, help="Dataset class name") + parser.add_argument("--dataset", type=str, help="Dataset name") # as in dataset_utils.DATASETS parser.add_argument("--model", type=str, help="Model class name") parser.add_argument("--model_hparams_dict", type=str, help="JSON-file of model hyperparameters") parser.add_argument("--gpu_mem_frac", type=float, default=0.99, help="Fraction of gpu memory to use") @@ -743,10 +756,15 @@ def main(): args = parser.parse_args() # start timing for the whole run + + # list pip environment + import os + print(os.system("pip3 list")) + timeit_start = time.time() # create a training instance train_case = TrainModel(input_dir=args.input_dir,output_dir=args.output_dir,datasplit_dict=args.datasplit_dict, - model_hparams_dict=args.model_hparams_dict,model=args.model,checkpoint=args.checkpoint, dataset=args.dataset, + model_hparams_dict=args.model_hparams_dict,model=args.model,checkpoint=args.checkpoint, dataset_name=args.dataset, gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, args=args, frac_start_save=args.frac_start_save, frac_intv_save=args.frac_intv_save) diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess.py b/video_prediction_tools/main_scripts/main_visualize_postprocess.py index 34ed8da4970c0f30b0357e8d72ca128e601c009d..fc55de930ca70eb8624ba85bbe9462f7735a59c9 100644 --- a/video_prediction_tools/main_scripts/main_visualize_postprocess.py +++ b/video_prediction_tools/main_scripts/main_visualize_postprocess.py @@ -729,7 +729,6 @@ class Postprocess(TrainModel): gen_images_denorm = self.denorm_images_all_channels( gen_images, self.vars_in, self.norm_cls, norm_method="minmax" ) - # store data into datset & get number of samples (may differ from batch_size at the end of the test dataset) times_0, init_times = self.get_init_time(t_starts) batch_ds = self.create_dataset( input_images_denorm, gen_images_denorm, init_times @@ -815,7 +814,7 @@ class Postprocess(TrainModel): self.sess.run(tf.global_variables_initializer()) self.sess.run(tf.local_variables_initializer()) - def get_input_data_per_batch(self, input_iter, norm_method="minmax"): + def get_input_data_per_batch(self, input_iter, norm_method="cbnorm"): """ Get the input sequence from the dataset iterator object stored in self.inputs and denormalize the data :param input_iter: the iterator object built by make_test_dataset_iterator-method @@ -1298,8 +1297,7 @@ class Postprocess(TrainModel): @staticmethod def denorm_images_all_channels( - image_sequence, varnames, norm, norm_method="minmax" - ): + image_sequence, varnames, norm, norm_method="minmax") """ Denormalize data of all image channels :param image_sequence: list/array [batch, seq, lat, lon, channel] of images diff --git a/video_prediction_tools/main_scripts/main_visualize_postprocess_gzprcp.py b/video_prediction_tools/main_scripts/main_visualize_postprocess_gzprcp.py new file mode 100644 index 0000000000000000000000000000000000000000..c2fdd4f1b20751768ce76c138124a215e45f0233 --- /dev/null +++ b/video_prediction_tools/main_scripts/main_visualize_postprocess_gzprcp.py @@ -0,0 +1,1175 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong, Yan Ji, Michael Langguth" +__date__ = "2020-11-10" + +import argparse +import os +import shutil +import numpy as np +import xarray as xr +import pandas as pd +import tensorflow as tf +import pickle +import datetime as dt +import json +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +from mpl_toolkits.basemap import Basemap +from normalization import Norm_data +from metadata import MetaData as MetaData +from main_scripts.main_train_models import * +from data_preprocess.preprocess_data_step2 import * +from model_modules.video_prediction import datasets, models, metrics +from postprocess.statistical_evaluation import perform_block_bootstrap_metric, avg_metrics, Scores + + +class Postprocess(TrainModel): + def __init__(self, results_dir=None, checkpoint=None, mode="test", batch_size=None, num_stochastic_samples=1, + stochastic_plot_id=0, gpu_mem_frac=None, seed=None, args=None, run_mode="deterministic"): + """ + The function for inference, generate results and images + results_dir :str, The output directory to save results + checkpoint :str, The directory point to the checkpoints + mode :str, Default is test, could be "train","val", and "test" + batch_size :int, The batch size used for generating test samples for each iteration + num_stochastic_samples: int, for the stochastic models such as SAVP, VAE, it is used for generate a number of + ensemble for each prediction. + For deterministic model such as convLSTM, it is default setup to 1 + stochastic_plot_id :int, the index for stochastically generated images to plot + gpu_mem_frac :int, GPU memory fraction to be used + seed :seed for control test samples + run_mode :str, if "deterministic" then the model running for deterministic forecasting, other string values, it will go for stochastic forecasting + + Side notes : other important varialbes in the class: + self.ts : list, contains the sequence_length timestamps + self.gen_images_ : the length of generate images by model is sequence_length - 1 + self.persistent_image : the length of persistent images is sequence_length - 1 + self.input_images : the length of inputs images is sequence length + + """ + + # initialize input directories (to be retrieved by load_jsons) + self.input_dir = None + self.input_dir_tfr = None + self.input_dir_pkl = None + # forecast products and evaluation metrics to be handled in postprocessing + self.eval_metrics = ["mse", "psnr", "ssim"] + self.fcst_products = {"persistence": "pfcst", "model": "mfcst"} + # initialize dataset to track evaluation metrics and configure bootstrapping procedure + self.eval_metrics_ds = None + self.nboots_block = 1000 + self.block_length = 7 * 24 # this corresponds to a block length of 7 days when forecasts are produced every hour + # other attributes + self.stat_fl = None + self.norm_cls = None # placeholder for normalization instance + self.channel = 0 # index of channel/input variable to evaluate + self.num_samples_per_epoch = None + # set further attributes from parsed arguments + self.results_dir = self.output_dir = os.path.normpath(results_dir) + if not os.path.exists(self.results_dir): + os.makedirs(self.results_dir) + self.batch_size = batch_size + self.gpu_mem_frac = gpu_mem_frac + self.seed = seed + self.num_stochastic_samples = num_stochastic_samples + self.stochastic_plot_id = stochastic_plot_id + self.args = args + self.checkpoint = checkpoint + self.run_mode = run_mode + self.mode = mode + if self.checkpoint is None: + raise ValueError("The directory point to checkpoint is empty, must be provided for postprocess step") + + if not os.path.isdir(self.checkpoint): + raise NotADirectoryError("The checkpoint-directory '{0}' does not exist".format(self.checkpoint)) + + def __call__(self): + self.set_seed() + self.save_args_to_option_json() + self.copy_data_model_json() + self.load_jsons() + self.get_metadata() + self.setup_test_dataset() + self.setup_model(self.mode) + self.get_data_params() + self.setup_num_samples_per_epoch() + self.get_stat_file() + self.make_test_dataset_iterator() + self.check_stochastic_samples_ind_based_on_model() + self.setup_graph() + self.setup_gpu_config() + + # methods that are executed with __call__ + def save_args_to_option_json(self): + """ + Save the argments defined by user to the results dir + """ + with open(os.path.join(self.results_dir, "options.json"), "w") as f: + f.write(json.dumps(vars(self.args), sort_keys=True, indent=4)) + + def copy_data_model_json(self): + """ + Copy relevant JSON-files from checkpoints directory to results_dir + """ + method_name = Postprocess.copy_data_model_json.__name__ + + # correctness of self.checkpoint and self.results_dir is already checked in __init__ + model_opt_js = os.path.join(self.checkpoint, "options.json") + model_ds_js = os.path.join(self.checkpoint, "dataset_hparams.json") + model_hp_js = os.path.join(self.checkpoint, "model_hparams.json") + model_dd_js = os.path.join(self.checkpoint, "data_dict.json") + + if os.path.isfile(model_opt_js): + shutil.copy(model_opt_js, os.path.join(self.results_dir, "options_checkpoints.json")) + else: + raise FileNotFoundError("%{0}: The file {1} does not exist".format(method_name, model_opt_js)) + + if os.path.isfile(model_ds_js): + shutil.copy(model_ds_js, os.path.join(self.results_dir, "dataset_hparams.json")) + else: + raise FileNotFoundError("%{0}: the file {1} does not exist".format(method_name, model_ds_js)) + + if os.path.isfile(model_hp_js): + shutil.copy(model_hp_js, os.path.join(self.results_dir, "model_hparams.json")) + else: + raise FileNotFoundError("%{0}: The file {1} does not exist".format(method_name, model_hp_js)) + + if os.path.isfile(model_dd_js): + shutil.copy(model_dd_js, os.path.join(self.results_dir, "data_dict.json")) + else: + raise FileNotFoundError("%{0}: The file {1} does not exist".format(method_name, model_dd_js)) + + def load_jsons(self): + """ + Set attributes pointing to JSON-files which track essential information and also load some information + to store it to attributes of the class instance + """ + method_name = Postprocess.load_jsons.__name__ + + self.datasplit_dict = os.path.join(self.results_dir, "data_dict.json") + self.model_hparams_dict = os.path.join(self.results_dir, "model_hparams.json") + checkpoint_opt_dict = os.path.join(self.results_dir, "options_checkpoints.json") + + # sanity checks on the JSON-files + if not os.path.isfile(self.datasplit_dict): + raise FileNotFoundError("%{0}: The file data_dict.json is missing in {1}".format(method_name, + self.results_dir)) + + if not os.path.isfile(self.model_hparams_dict): + raise FileNotFoundError("%{0}: The file model_hparams.json is missing in {1}".format(method_name, + self.results_dir)) + + if not os.path.isfile(checkpoint_opt_dict): + raise FileNotFoundError("%{0}: The file options_checkpoints.json is missing in {1}" + .format(method_name, self.results_dir)) + # retrieve some data from options_checkpoints.json + try: + with open(checkpoint_opt_dict) as f: + options_checkpoint = json.loads(f.read()) + self.dataset = options_checkpoint["dataset"] + self.model = options_checkpoint["model"] + self.input_dir_tfr = options_checkpoint["input_dir"] + self.input_dir = os.path.dirname(self.input_dir_tfr.rstrip("/")) + self.input_dir_pkl = os.path.join(self.input_dir, "pickle") + # update self.fcst_products + if "model" in self.fcst_products.keys(): + self.fcst_products[self.model] = self.fcst_products.pop("model") + except Exception as err: + print("%{0}: Something went wrong when reading the checkpoint-file '{1}'".format(method_name, + checkpoint_opt_dict)) + raise err + + self.model_hparams_dict_load = self.get_model_hparams_dict() + + def get_metadata(self): + + method_name = Postprocess.get_metadata.__name__ + + # some sanity checks + if self.input_dir is None: + raise AttributeError("%{0}: input_dir-attribute is still None".format(method_name)) + + metadata_fl = os.path.join(self.input_dir, "metadata.json") + + if not os.path.isfile(metadata_fl): + raise FileNotFoundError("%{0}: Could not find metadata JSON-file under '{1}'".format(method_name, + self.input_dir)) + + try: + md_instance = MetaData(json_file=metadata_fl) + except Exception as err: + print("%{0}: Something went wrong when getting metadata from file '{1}'".format(method_name, metadata_fl)) + raise err + + # when the metadat is loaded without problems, the follwoing will work + self.height, self.width = md_instance.ny, md_instance.nx + self.vars_in = md_instance.variables + + self.lats = xr.DataArray(md_instance.lat, coords={"lat": md_instance.lat}, dims="lat", + attrs={"units": "degrees_east"}) + self.lons = xr.DataArray(md_instance.lon, coords={"lon": md_instance.lon}, dims="lon", + attrs={"units": "degrees_north"}) + + def setup_test_dataset(self): + """ + setup the test dataset instance + """ + VideoDataset = datasets.get_dataset_class(self.dataset) + self.test_dataset = VideoDataset(input_dir=self.input_dir_tfr, mode=self.mode, + datasplit_config=self.datasplit_dict) + + def setup_num_samples_per_epoch(self): + """ + For generating images, the user can define the examples used, and will be taken as num_examples_per_epoch + For testing we only use exactly one epoch, but to be consistent with the training, we keep the name '_per_epoch' + """ + method = Postprocess.setup_num_samples_per_epoch.__name__ + + self.num_samples_per_epoch = self.test_dataset.num_examples_per_epoch() + + return self.num_samples_per_epoch + + def get_data_params(self): + """ + Get the context_frames, future_frames and total frames from hparamters settings. + Note that future_frames_length is the number of predicted frames. + """ + self.context_frames = self.model_hparams_dict_load["context_frames"] + self.sequence_length = self.model_hparams_dict_load["sequence_length"] + self.future_length = self.sequence_length - self.context_frames + + def get_stat_file(self): + """ + Load the statistics from statistic file from the input directory + """ + self.stat_fl = os.path.join(self.input_dir, "statistics.json") + + def make_test_dataset_iterator(self): + """ + Make the dataset iterator + """ + test_tf_dataset = self.test_dataset.make_dataset(self.batch_size) + test_iterator = test_tf_dataset.make_one_shot_iterator() + # The `Iterator.string_handle()` method returns a tensor that can be evaluated + # and used to feed the `handle` placeholder. + test_handle = test_iterator.string_handle() + dataset_iterator = tf.data.Iterator.from_string_handle(test_handle, test_tf_dataset.output_types, + test_tf_dataset.output_shapes) + self.inputs = dataset_iterator.get_next() + self.input_ts = self.inputs["T_start"] + # if self.dataset == "era5" and self.model == "savp": + # del self.inputs["T_start"] + + def check_stochastic_samples_ind_based_on_model(self): + """ + stochastic forecasting only suitable for the geneerate models such as SAVP, vae. + For convLSTM, McNet only do determinstic forecasting + """ + if self.model == "convLSTM" or self.model == "test_model" or self.model == 'mcnet': + if self.num_stochastic_samples > 1: + print("Number of samples for deterministic model cannot be larger than 1. Higher values are ignored.") + self.num_stochastic_samples = 1 + + def init_session(self): + self.sess = tf.Session(config=self.config) + self.sess.graph.as_default() + self.sess.run(tf.global_variables_initializer()) + self.sess.run(tf.local_variables_initializer()) + + # the run-factory + def run(self): + if self.model == "convLSTM" or self.model == "test_model" or self.model == 'mcnet': + self.run_deterministic() + elif self.run_mode == "deterministic": + self.run_deterministic() + else: + self.run_stochastic() + + def run_stochastic(self): + """ + Run session, save results to netcdf, plot input images, generate images and persistent images + """ + method = Postprocess.run_stochastic.__name__ + raise ValueError("ML: %{0} is not runnable now".format(method)) + + self.init_session() + self.restore(self.sess, self.checkpoint) + # Loop for samples + self.sample_ind = 0 + self.prst_metric_all = [] # store evaluation metrics of persistence forecast (shape [future_len]) + self.fcst_metric_all = [] # store evaluation metric of stochastic forecasts (shape [nstoch, batch, future_len]) + while self.sample_ind < self.num_samples_per_epoch: + if self.num_samples_per_epoch < self.sample_ind: + break + else: + # run the inputs and plot each sequence images + self.input_results, self.input_images_denorm_all, self.t_starts = self.get_input_data_per_batch() + + feed_dict = {input_ph: self.input_results[name] for name, input_ph in self.inputs.items()} + gen_loss_stochastic_batch = [] # [stochastic_ind,future_length] + gen_images_stochastic = [] # [stochastic_ind,batch_size,seq_len,lat,lon,channels] + # Loop for stochastics + for stochastic_sample_ind in range(self.num_stochastic_samples): + print("stochastic_sample_ind:", stochastic_sample_ind) + # return [batchsize,seq_len,lat,lon,channel] + gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict) + # The generate images seq_len should be sequence_len -1, since the last one is + # not used for comparing with groud truth + assert gen_images.shape[1] == self.sequence_length - 1 + gen_images_per_batch = [] + if stochastic_sample_ind == 0: + persistent_images_per_batch = [] # [batch_size,seq_len,lat,lon,channel] + ts_batch = [] + for i in range(self.batch_size): + # generate time stamps for sequences only once, since they are the same for all ensemble members + if stochastic_sample_ind == 0: + self.ts = Postprocess.generate_seq_timestamps(self.t_starts[i], len_seq=self.sequence_length) + init_date_str = self.ts[0].strftime("%Y%m%d%H") + ts_batch.append(init_date_str) + # get persistence_images + self.persistence_images, self.ts_persistence = Postprocess.get_persistence(self.ts, + self.input_dir_pkl) + persistent_images_per_batch.append(self.persistence_images) + assert len(np.array(persistent_images_per_batch).shape) == 5 + self.plot_persistence_images() + + # Denormalized data for generate + gen_images_ = gen_images[i] + self.gen_images_denorm = Postprocess.denorm_images_all_channels(self.stat_fl, gen_images_, + self.vars_in) + gen_images_per_batch.append(self.gen_images_denorm) + assert len(np.array(gen_images_per_batch).shape) == 5 + # only plot when the first stochastic ind otherwise too many plots would be created + # only plot the stochastic results of user-defined ind + self.plot_generate_images(stochastic_sample_ind, self.stochastic_plot_id) + # calculate the persistnet error per batch + if stochastic_sample_ind == 0: + persistent_loss_per_batch = Postprocess.calculate_metrics_by_batch(self.input_images_denorm_all, + persistent_images_per_batch, + self.future_length, + self.context_frames, + matric="mse", channel=0) + self.prst_metric_all.append(persistent_loss_per_batch) + + # calculate the gen_images_per_batch error + gen_loss_per_batch = Postprocess.calculate_metrics_by_batch(self.input_images_denorm_all, + gen_images_per_batch, self.future_length, + self.context_frames, + matric="mse", channel=0) + gen_loss_stochastic_batch.append( + gen_loss_per_batch) # self.gen_images_stochastic[stochastic,future_length] + print("gen_images_per_batch shape:", np.array(gen_images_per_batch).shape) + gen_images_stochastic.append( + gen_images_per_batch) # [stochastic,batch_size, seq_len, lat, lon, channel] + + # Switch the 0 and 1 position + print("before transpose:", np.array(gen_images_stochastic).shape) + gen_images_stochastic = np.transpose(np.array(gen_images_stochastic), ( + 1, 0, 2, 3, 4, 5)) # [batch_size, stochastic, seq_len, lat, lon, chanel] + Postprocess.check_gen_images_stochastic_shape(gen_images_stochastic) + assert len(gen_images_stochastic.shape) == 6 + assert np.array(gen_images_stochastic).shape[1] == self.num_stochastic_samples + + self.fcst_metric_all.append( + gen_loss_stochastic_batch) # [samples/batch_size,stochastic,future_length] + # save input and stochastic generate images to netcdf file + # For each prediction (either deterministic or ensemble) we create one netCDF file. + for batch_id in range(self.batch_size): + self.save_to_netcdf_for_stochastic_generate_images(self.input_images_denorm_all[batch_id], + persistent_images_per_batch[batch_id], + np.array(gen_images_stochastic)[batch_id], + fl_name="vfp_date_{}_sample_ind_{}.nc" + .format(ts_batch[batch_id], + self.sample_ind + batch_id)) + + self.sample_ind += self.batch_size + + self.persistent_loss_all_batches = np.mean(np.array(self.persistent_loss_all_batches), axis=0) + self.stochastic_loss_all_batches = np.mean(np.array(self.stochastic_loss_all_batches), axis=0) + assert len(np.array(self.persistent_loss_all_batches).shape) == 1 + assert np.array(self.persistent_loss_all_batches).shape[0] == self.future_length + + assert len(np.array(self.stochastic_loss_all_batches).shape) == 2 + assert np.array(self.stochastic_loss_all_batches).shape[0] == self.num_stochastic_samples + + def run_deterministic(self): + """ + Revised and vectorized version of run_deterministic + Loops over the training data, generates forecasts and calculates basic evaluation metrics on-the-fly + """ + method = Postprocess.run_deterministic.__name__ + + # init the session and restore the trained model + self.init_session() + self.restore(self.sess, self.checkpoint) + + # init sample index for looping and acculmulators for evaulation metrics + sample_ind = 0 + nsamples = self.num_samples_per_epoch + # initialize datasets + eval_metric_ds = Postprocess.init_metric_ds(self.fcst_products, self.eval_metrics, self.vars_in[self.channel], + nsamples, self.future_length) + + while sample_ind < self.num_samples_per_epoch: + # get normalized and denormalized input data + input_results, input_images_denorm, t_starts = self.get_input_data_per_batch(self.inputs) + # feed and run the trained model; returned array has the shape [batchsize, seq_len, lat, lon, channel] + feed_dict = {input_ph: input_results[name] for name, input_ph in self.inputs.items()} + gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict) + + # sanity check on length of forecast sequence + assert gen_images.shape[1] == self.sequence_length - 1, \ + "%{0}: Sequence length of prediction must be smaller by one than total sequence length.".format(method) + # denormalize forecast sequence (self.norm_cls is already set in get_input_data_per_batch-method) + gen_images_denorm = self.denorm_images_all_channels(gen_images, self.vars_in, self.norm_cls, + norm_method="cbnorm") + # store data into datset and get number of samples (may differ from batch_size at the end of the test dataset) + times_0, init_times = self.get_init_time(t_starts) + batch_ds = self.create_dataset(input_images_denorm, gen_images_denorm, init_times) + nbs = np.minimum(self.batch_size, self.num_samples_per_epoch - sample_ind) + batch_ds = batch_ds.isel(init_time=slice(0, nbs)) + + for i in np.arange(nbs): + # work-around to make use of get_persistence_forecast_per_sample-method + #times_seq = (pd.date_range(times_0[i], periods=int(self.sequence_length), freq="h")).to_pydatetime() + # get persistence forecast for sequences at hand and write to dataset + #persistence_seq, _ = Postprocess.get_persistence(times_seq, self.input_dir_pkl) + #for ivar, var in enumerate(self.vars_in): + #batch_ds["{0}_persistence_fcst".format(var)].loc[dict(init_time=init_times[i])] = \ + # persistence_seq[self.context_frames-1:, :, :, ivar] + + # save sequences to netcdf-file and track initial time + nc_fname = os.path.join(self.results_dir, "vfp_date_{0}_sample_ind_{1:d}.nc" + .format(pd.to_datetime(init_times[i]).strftime("%Y%m%d%H%M"), sample_ind + i)) + self.save_ds_to_netcdf(batch_ds.isel(init_time=i), nc_fname) + # end of batch-loop + # write evaluation metric to corresponding dataset... + eval_metric_ds = self.populate_eval_metric_ds(eval_metric_ds, batch_ds, sample_ind, + self.vars_in[self.channel]) + # ... and increment sample_ind + sample_ind += self.batch_size + # end of while-loop for samples + # safe dataset with evaluation metrics for later use + self.eval_metrics_ds = eval_metric_ds + #self.add_ensemble_dim() + + # all methods of the run factory + def get_input_data_per_batch(self, input_iter, norm_method="cbnorm"): + """ + Get the input sequence from the dataset iterator object stored in self.inputs and denormalize the data + :param input_iter: the iterator object built by make_test_dataset_iterator-method + :param norm_method: normalization method applicable to the data + :return input_results: the normalized input data + :return input_images_denorm: the denormalized input data + :return t_starts: the initial time of the sequences + """ + method = Postprocess.get_input_data_per_batch.__name__ + + input_results = self.sess.run(input_iter) + input_images = input_results["images"] + t_starts = input_results["T_start"] + if self.norm_cls is None: + if self.stat_fl is None: + raise AttributeError("%{0}: Attribute stat_fl is not initialized yet.".format(method)) + self.norm_cls = Postprocess.get_norm(self.vars_in, self.stat_fl, norm_method) + + # sanity check on input sequence + assert np.ndim(input_images) == 5, "%{0}: Input sequence of mini-batch does not have five dimensions."\ + .format(method) + + input_images_denorm = Postprocess.denorm_images_all_channels(input_images, self.vars_in, self.norm_cls, + norm_method=norm_method) + + return input_results, input_images_denorm, t_starts + + def get_init_time(self, t_starts): + """ + Retrieves initial dates of forecast sequences from start time of whole inpt sequence + :param t_starts: list/array of start times of input sequence + :return: list of initial dates of forecast as numpy.datetime64 instances + """ + method = Postprocess.get_init_time.__name__ + + t_starts = np.squeeze(np.asarray(t_starts)) + if not np.ndim(t_starts) == 1: + raise ValueError("%{0}: Inputted t_starts must be a 1D list/array of date-strings with format %Y%m%d%H" + .format(method)) + for i, t_start in enumerate(t_starts): + try: + #seq_ts = pd.date_range(dt.datetime.strptime(str(t_start), "%Y%m%d%H%M"), periods=self.context_frames, + # freq="10min") + print('t_start: ',t_start) + t0 = pd.date_range(dt.datetime.strptime(str(t_start), "%Y%m%d%H%M"), periods=3, + freq="-10min") + t1 = pd.date_range(dt.datetime.strptime(str(t_start), "%Y%m%d%H%M"),periods=self.context_frames-2, + freq="10min") + seq_ts = t0.append(t1)[1:] + print('seq_ts: ',seq_ts) + except Exception as err: + print("%{0}: Could not convert {1} to datetime object. Ensure that the date-string format is 'Y%m%d%H'". + format(method, str(t_start))) + raise err + if i == 0: + ts_all = np.expand_dims(seq_ts, axis=0) + else: + ts_all = np.vstack((ts_all, seq_ts)) + + init_times = ts_all[:, -1] + times0 = ts_all[:, 0] + + return times0, init_times + + def populate_eval_metric_ds(self, metric_ds, data_ds, ind_start, varname): + """ + Populates evaluation metric dataset with values + :param metric_ds: the evaluation metric dataset with variables such as 'mfcst_mse' (MSE of model forecast) + :param data_ds: dataset holding the data from one mini-batch (see create_dataset-method) + :param ind_start: start index of dimension init_time (part of metric_ds) + :param varname: variable of interest (must be part of self.vars_in) + :return: metric_ds + """ + method = Postprocess.populate_eval_metric_ds.__name__ + + # dictionary of implemented evaluation metrics + dims = ["lat", "lon"] + known_eval_metrics = {"mse": Scores("mse", dims), "psnr": Scores("psnr", dims),"ssim": Scores("ssim",dims)} + + # generate list of functions that calculate requested evaluation metrics + if set(self.eval_metrics).issubset(known_eval_metrics): + eval_metrics_func = [known_eval_metrics[metric].score_func for metric in self.eval_metrics] + else: + misses = list(set(self.eval_metrics) - known_eval_metrics.keys()) + raise NotImplementedError("%{0}: The following requested evaluation metrics are not implemented yet: " + .format(method, ", ".join(misses))) + + varname_ref = "{0}_ref".format(varname) + # reset init-time coordinate of metric_ds in place and get indices for slicing + ind_end = np.minimum(ind_start + self.batch_size, self.num_samples_per_epoch) + init_times_metric = metric_ds["init_time"].values + init_times_metric[ind_start:ind_end] = data_ds["init_time"] + metric_ds = metric_ds.assign_coords(init_time=init_times_metric) + # populate metric_ds + for fcst_prod in self.fcst_products.keys(): + for imetric, eval_metric in enumerate(self.eval_metrics): + metric_name = "{0}_{1}_{2}".format(varname, fcst_prod, eval_metric) + varname_fcst = "{0}_{1}_fcst".format(varname, fcst_prod) + dict_ind = dict(init_time=data_ds["init_time"]) + metric_ds[metric_name].loc[dict_ind] = eval_metrics_func[imetric](data_ds[varname_fcst], + data_ds[varname_ref]) + # end of metric-loop + # end of forecast product-loop + + return metric_ds + + def add_ensemble_dim(self): + """ + Expands dimensions of loss-arrays by dummy ensemble-dimension (used for deterministic forecasts only) + :return: + """ + self.stochastic_loss_all_batches = np.expand_dims(self.fcst_mse_avg_batches, axis=0) # [1,future_lenght] + self.stochastic_loss_all_batches_psnr = np.expand_dims(self.fcst_psnr_avg_batches, axis=0) # [1,future_lenght] + + def create_dataset(self, input_seq, fcst_seq, ts_ini): + """ + Put input and forecast sequences into a xarray dataset. The latter also involves the persistence forecast + which is just initialized, but unpopulated at this stage. + The input data sequence is split into (effective) input sequence used for the forecast and into reference part. + :param input_seq: sequence of input images [batch ,seq, lat, lon, channel] + :param fcst_seq: sequence of forecast images [batch ,seq-1, lat, lon, channel] + :param ts_ini: initial time of forecast (=last time step of effective input sequence) + :return data_ds: above mentioned data in a nicely formatted dataset + """ + + method = Postprocess.create_dataset.__name__ + + # auxiliary variables for temporal dimensions + seq_hours = np.arange(self.sequence_length) - (self.context_frames-1) + # some sanity checks + assert np.shape(ts_ini)[0] == self.batch_size,\ + "%{0}: Inconsistent number of sequence start times ({1:d}) and batch size ({2:d})"\ + .format(method, np.shape(ts_ini)[0], self.batch_size) + + # turn input and forecast sequences to Data Arrays to ease indexing + try: + input_seq = xr.DataArray(input_seq, coords={"init_time": ts_ini, "fcst_hour": seq_hours, + "lat": self.lats, "lon": self.lons, "varname": self.vars_in}, + dims=["init_time", "fcst_hour", "lat", "lon", "varname"]) + except Exception as err: + print("%{0}: Could not create Data Array for input sequence.".format(method)) + raise err + + try: + fcst_seq = xr.DataArray(fcst_seq, coords={"init_time": ts_ini, "fcst_hour": seq_hours[1::], + "lat": self.lats, "lon": self.lons, "varname": self.vars_in}, + dims=["init_time", "fcst_hour", "lat", "lon", "varname"]) + except Exception as err: + print("%{0}: Could not create Data Array for forecast sequence.".format(method)) + raise err + + # Now create the dataset where the input sequence is splitted into input that served for creating the + # forecast and into the the reference sequences (which can be compared to the forecast) + # as where the persistence forecast is containing NaNs (must be generated later) + data_in_dict = dict([("{0}_in".format(var), input_seq.isel(fcst_hour=slice(None, self.context_frames), + varname=ivar) \ + .rename({"fcst_hour": "in_hour"}) + .reset_coords(names="varname", drop=True)) + for ivar, var in enumerate(self.vars_in)]) + + # get shape of forecast data (one variable) -> required to initialize persistence forecast data + shape_fcst = np.shape(fcst_seq.isel(fcst_hour=slice(self.context_frames-1, None), varname=0) + .reset_coords(names="varname", drop=True)) + data_ref_dict = dict([("{0}_ref".format(var), input_seq.isel(fcst_hour=slice(self.context_frames, None), + varname=ivar) + .reset_coords(names="varname", drop=True)) + for ivar, var in enumerate(self.vars_in)]) + + data_mfcst_dict = dict([("{0}_{1}_fcst".format(var, self.model), + fcst_seq.isel(fcst_hour=slice(self.context_frames-1, None), varname=ivar) + .reset_coords(names="varname", drop=True)) + for ivar, var in enumerate(self.vars_in)]) + + # fill persistence forecast variables with dummy data (to be populated later) + data_pfcst_dict = dict([("{0}_persistence_fcst".format(var), (["init_time", "fcst_hour", "lat", "lon"], + np.full(shape_fcst, np.nan))) + for ivar, var in enumerate(self.vars_in)]) + + # create the dataset + data_ds = xr.Dataset({**data_in_dict, **data_ref_dict, **data_mfcst_dict, **data_pfcst_dict}) + + return data_ds + + def handle_eval_metrics(self): + """ + Plots error-metrics averaged over all predictions to file. + :return: a bunch of plots as png-files + """ + method = Postprocess.handle_eval_metrics.__name__ + + if self.eval_metrics_ds is None: + raise AttributeError("%{0}: Attribute with dataset of evaluation metrics is still None.".format(method)) + + # perform bootstrapping on metric dataset + eval_metric_boot_ds = perform_block_bootstrap_metric(self.eval_metrics_ds, "init_time", self.block_length, + self.nboots_block) + # ... and merge into existing metric dataset + self.eval_metrics_ds = xr.merge([self.eval_metrics_ds, eval_metric_boot_ds]) + + # calculate (unbootstrapped) averaged metrics + eval_metric_avg_ds = avg_metrics(self.eval_metrics_ds, "init_time") + # ... and merge into existing metric dataset + self.eval_metrics_ds = xr.merge([self.eval_metrics_ds, eval_metric_avg_ds]) + + # save evaluation metrics to file + nc_fname = os.path.join(self.results_dir, "evaluation_metrics.nc") + Postprocess.save_ds_to_netcdf(self.eval_metrics_ds, nc_fname) + + # also save averaged metrics to JSON-file and plot it for diagnosis + _ = Postprocess.plot_avg_eval_metrics(self.eval_metrics_ds, self.eval_metrics, self.fcst_products, + self.vars_in[self.channel], self.results_dir) + + # auxiliary methods (not necessarily bound to class instance) + @staticmethod + def get_norm(varnames, stat_fl, norm_method): + """ + Retrieves normalization instance + :param varnames: list of variabe names + :param stat_fl: statistics JSON-file + :param norm_method: normalization method + :return: normalization instance which can be used to normalize images according to norm_method + """ + method = Postprocess.get_norm.__name__ + + if not isinstance(varnames, list): + raise ValueError("%{0}: varnames must be a list of variable names.".format(method)) + + norm_cls = Norm_data(varnames) + try: + with open(stat_fl) as js_file: + norm_cls.check_and_set_norm(json.load(js_file), norm_method) + norm_cls = norm_cls + except Exception as err: + print("%{0}: Could not handle statistics json-file '{1}'.".format(method, stat_fl)) + raise err + return norm_cls + + @staticmethod + def denorm_images_all_channels(image_sequence, varnames, norm, norm_method="cbnorm"): + """ + Denormalize data of all image channels + :param image_sequence: list/array [batch, seq, lat, lon, channel] of images + :param varnames: list of variable names whose order matches channel indices + :param norm: normalization instance + :param norm_method: normalization-method (default: 'minmax') + :return: denormalized image data + """ + method = Postprocess.denorm_images_all_channels.__name__ + + nvars = len(varnames) + image_sequence = np.array(image_sequence) + # sanity checks + if not isinstance(norm, Norm_data): + raise ValueError("%{0}: norm must be a normalization instance.".format(method)) + + if nvars != np.shape(image_sequence)[-1]: + raise ValueError("%{0}: Number of passed variable names ({1:d}) does not match number of channels ({2:d})" + .format(method, nvars, np.shape(image_sequence)[-1])) + + input_images_all_channles_denorm = [Postprocess.denorm_images(image_sequence, norm, {varname: c}, + norm_method=norm_method) + for c, varname in enumerate(varnames)] + + input_images_denorm = np.stack(input_images_all_channles_denorm, axis=-1) + return input_images_denorm + + @staticmethod + def denorm_images(input_images, norm, var_dict, norm_method="cbnorm"): + """ + Denormalize one channel of images + :param input_images: list/array [batch, seq, lat, lon, channel] + :param norm: normalization instance + :param var_dict: dictionary with one key only mapping variable name to channel index, e.g. {"2_t": 0} + :param norm_method: normalization method (default: minmax-normalization) + :return: denormalized image data + """ + method = Postprocess.denorm_images.__name__ + # sanity checks + if not isinstance(var_dict, dict): + raise ValueError("%{0}: var_dict is not a dictionary.".format(method)) + else: + if len(var_dict.keys()) > 1: + raise ValueError("%{0}: var_dict must contain one key only.".format(method)) + varname, channel = *var_dict.keys(), *var_dict.values() + + if not isinstance(norm, Norm_data): + raise ValueError("%{0}: norm must be a normalization instance.".format(method)) + + try: + input_images_denorm = norm.denorm_var(input_images[..., channel], varname, norm_method) + except Exception as err: + print("%{0}: Something went wrong when denormalizing image sequence. Inspect error-message!".format(method)) + raise err + + return input_images_denorm + + @staticmethod + def check_gen_images_stochastic_shape(gen_images_stochastic): + """ + For models with deterministic forecasts, one dimension would be lacking. Therefore, here the array + dimension is expanded by one. + """ + if len(np.array(gen_images_stochastic).shape) == 6: + pass + elif len(np.array(gen_images_stochastic).shape) == 5: + gen_images_stochastic = np.expand_dims(gen_images_stochastic, axis=0) + else: + raise ValueError("Passed gen_images_stochastic is not of the right shape") + return gen_images_stochastic + + @staticmethod + def get_persistence(ts, input_dir_pkl): + """ + This function gets the persistence forecast. + 'Today's weather will be like yesterday's weather.' + :param ts: list dontaining datetime objects from get_init_times + :param input_dir_pkl: input directory to pickle files + :return time_persistence: list containing the dates and times of the persistence forecast. + :return var_peristence: sequence of images corresponding to these times + """ + ts_persistence = [] + year_origin = ts[0].year + for t in range(len(ts)): # Scarlet: this certainly can be made nicer with list comprehension + ts_temp = ts[t] - dt.timedelta(days=1) + ts_persistence.append(ts_temp) + t_persistence_start = ts_persistence[0] + t_persistence_end = ts_persistence[-1] + year_start = t_persistence_start.year + month_start = t_persistence_start.month + month_end = t_persistence_end.month + print("start year:", year_start) + # only one pickle file is needed (all hours during the same month) + if month_start == month_end: + # Open files to search for the indizes of the corresponding time + time_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'T')) + # Open file to search for the correspoding meteorological fields + var_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'X')) + + if year_origin != year_start: + time_origin_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 12, 'T')) + var_origin_pickle = list(Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 12, 'X')) + time_pickle.extend(time_origin_pickle) + var_pickle.extend(var_origin_pickle) + + # Retrieve starting index + ind = list(time_pickle).index(np.array(ts_persistence[0])) + + var_persistence = np.array(var_pickle)[ind:ind + len(ts_persistence)] + time_persistence = np.array(time_pickle)[ind:ind + len(ts_persistence)].ravel() + # case that we need to derive the data from two pickle files (changing month during the forecast periode) + else: + t_persistence_first_m = [] # should hold dates of the first month + t_persistence_second_m = [] # should hold dates of the second month + + for t in range(len(ts)): + m = ts_persistence[t].month + if m == month_start: + t_persistence_first_m.append(ts_persistence[t]) + if m == month_end: + t_persistence_second_m.append(ts_persistence[t]) + if year_origin == year_start: + # Open files to search for the indizes of the corresponding time + time_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'T') + time_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_end, 'T') + + # Open file to search for the correspoding meteorological fields + var_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_start, 'X') + var_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, month_end, 'X') + + if year_origin != year_start: + # Open files to search for the indizes of the corresponding time + time_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 1, 'T') + time_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, 12, 'T') + + # Open file to search for the correspoding meteorological fields + var_pickle_second = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_origin, 1, 'X') + var_pickle_first = Postprocess.load_pickle_for_persistence(input_dir_pkl, year_start, 12, 'X') + + # Retrieve starting index + ind_first_m = list(time_pickle_first).index(np.array(t_persistence_first_m[0])) + #print("time_pickle_second:", time_pickle_second) + ind_second_m = list(time_pickle_second).index(np.array(t_persistence_second_m[0])) + + # append the sequence of the second month to the first month + var_persistence = np.concatenate((var_pickle_first[ind_first_m:ind_first_m + len(t_persistence_first_m)], + var_pickle_second[ + ind_second_m:ind_second_m + len(t_persistence_second_m)]), + axis=0) + time_persistence = np.concatenate((time_pickle_first[ind_first_m:ind_first_m + len(t_persistence_first_m)], + time_pickle_second[ + ind_second_m:ind_second_m + len(t_persistence_second_m)]), + axis=0).ravel() + # Note: ravel is needed to eliminate the unnecessary dimension (20,1) becomes (20,) + + if len(time_persistence.tolist()) == 0: + raise ValueError("The time_persistent is empty!") + if len(var_persistence) == 0: + raise ValueError("The var persistence is empty!") + + var_persistence = var_persistence[1:] + time_persistence = time_persistence[1:] + + return var_persistence, time_persistence.tolist() + + @staticmethod + def load_pickle_for_persistence(input_dir_pkl, year_start, month_start, pkl_type): + """ + There are two types in our workflow: T_[month].pkl where the timestamp is stored, + X_[month].pkl where the variables are stored, e.g. temperature, geopotential and pressure. + This helper function constructs the directory, opens the file to read it, returns the variable. + :param input_dir_pkl: directory where input pickle files are stored + :param year_start: The year for which data is requested as integer + :param month_start: The year for which data is requested as integer + :param pkl_type: Either "X" or "T" + """ + path_to_pickle = os.path.join(input_dir_pkl, str(year_start), pkl_type + "_{:02}.pkl".format(month_start)) + with open(path_to_pickle, "rb") as pkl_file: + var = pickle.load(pkl_file) + return var + + @staticmethod + def save_ds_to_netcdf(ds, nc_fname, comp_level=5): + """ + Writes xarray dataset into netCDF-file + :param ds: The dataset to be written + :param nc_fname: Path and name of the target netCDF-file + :param comp_level: compression level, must be an integer between 1 and 9 (defualt: 5) + :return: - + """ + method = Postprocess.save_ds_to_netcdf.__name__ + + # sanity checks + if not isinstance(ds, xr.Dataset): + raise ValueError("%{0}: Argument 'ds' must be a xarray dataset.".format(method)) + + if not isinstance(comp_level, int): + raise ValueError("%{0}: Argument 'comp_level' must be an integer.".format(method)) + else: + if comp_level < 1 or comp_level > 9: + raise ValueError("%{0}: Argument 'comp_level' must be an integer between 1 and 9.".format(method)) + + if not os.path.isdir(os.path.dirname(nc_fname)): + raise NotADirectoryError("%{0}: The directory to store the netCDf-file does not exist.".format(method)) + + encode_nc = {key: {"zlib": True, "complevel": comp_level} for key in ds.keys()} + + # populate data in netCDF-file (take care for the mode!) + try: + ds.to_netcdf(nc_fname, encoding=encode_nc) + print("%{0}: netCDF-file '{1}' was created successfully.".format(method, nc_fname)) + except Exception as err: + print("%{0}: Something unexpected happened when creating netCDF-file '1'".format(method, nc_fname)) + raise err + + def plot_example_forecasts(self, metric="mse", channel=0): + """ + Plots example forecasts. The forecasts are chosen from the complete pool of the test dataset and are chosen + according to the accuracy in terms of the chosen metric. In add ition, to the best and worst forecast, + every decil of the chosen metric is retrieved to cover the whole bandwith of forecasts. + :param metric: The metric which is used for measuring accuracy + :param channel: The channel index of the forecasted variable to plot (correspondong to self.vars_in) + :return: 11 exemplary forecast plots are created + """ + method = Postprocess.plot_example_forecasts.__name__ + + metric_name = "{0}_{1}_{2}".format(self.vars_in[channel], self.model, metric) + if not metric_name in self.eval_metrics_ds: + raise ValueError("%{0}: Cannot find requested evaluation metric '{1}'".format(method, metric_name) + + " onto which selection of plotted forecast is done.") + # average metric of interest and obtain quantiles incl. indices + metric_mean = self.eval_metrics_ds[metric_name].mean(dim="fcst_hour") + quantiles = np.arange(0., 1.01, .1) + quantiles_val = metric_mean.quantile(quantiles, interpolation="nearest") + quantiles_inds = self.get_matching_indices(metric_mean.values, quantiles_val) + print(metric_mean.coords["init_time"]) + for i, ifcst in enumerate(quantiles_inds): + date_init = pd.to_datetime(metric_mean.coords["init_time"][ifcst].data) + nc_fname = os.path.join(self.results_dir, "vfp_date_{0}_sample_ind_{1:d}.nc" + .format(date_init.strftime("%Y%m%d%H"), ifcst)) + if not os.path.isfile(nc_fname): + raise FileNotFoundError("%{0}: Could not find requested file '{1}'".format(method, nc_fname)) + else: + # get the data + varname = self.vars_in[channel] + with xr.open_dataset(nc_fname) as dfile: + data_fcst = dfile["{0}_{1}_fcst".format(varname, self.model)] + data_ref = dfile["{0}_ref".format(varname)] + + data_diff = data_fcst - data_ref + # name of plot + plt_fname_base = os.path.join(self.output_dir, "forecast_{0}_{1}_{2}_{3:d}percentile.png" + .format(varname, date_init.strftime("%Y%m%dT%H00"), metric, + int(quantiles[i]*100.))) + + Postprocess.create_plot(data_fcst, data_diff, varname, plt_fname_base) + + @staticmethod + def init_metric_ds(fcst_products, eval_metrics, varname, nsamples, nlead_steps): + """ + Initializes dataset for storing evaluation metrics + :param fcst_products: list of forecast products to be evaluated + :param eval_metrics: list of forecast metrics to be calculated + :param varname: name of the variable for which metrics are calculated + :param nsamples: total number of forecast samples + :param nlead_steps: number of forecast steps + :return: eval_metric_ds + """ + eval_metric_dict = dict([("{0}_{1}_{2}".format(varname, *(fcst_prod, eval_met)), (["init_time", "fcst_hour"], + np.full((nsamples, nlead_steps), np.nan))) + for eval_met in eval_metrics for fcst_prod in fcst_products]) + + init_time_dummy = pd.date_range("1900-01-01 00:00", freq="s", periods=nsamples) + eval_metric_ds = xr.Dataset(eval_metric_dict, coords={"init_time": init_time_dummy, # just a placeholder + "fcst_hour": np.arange(1, nlead_steps+1)}) + + return eval_metric_ds + + + @staticmethod + def get_matching_indices(big_array, subset): + """ + Returns the indices where element values match the values in an array + :param big_array: the array to dig through + :param subset: array of values contained in big_array + :return: the desired indices + """ + + sorted_keys = np.argsort(big_array) + indexes = sorted_keys[np.searchsorted(big_array, subset, sorter=sorted_keys)] + + return indexes + + @staticmethod + def plot_avg_eval_metrics(eval_ds, eval_metrics, fcst_prod_dict, varname, out_dir): + """ + Plots error-metrics averaged over all predictions to file incl. 90%-confidence interval that is estimated by + block bootstrapping. + :param eval_ds: The dataset storing all evaluation metrics for each forecast (produced by init_metric_ds-method) + :param eval_metrics: list of evaluation metrics + :param fcst_prod_dict: dictionary of forecast products, e.g. {"persistence": "pfcst"} + :param varname: the variable name for which the evaluation metrics are available + :param out_dir: output directory to save the lots + :return: a bunch of plots as png-files + """ + method = Postprocess.plot_avg_eval_metrics.__name__ + + # settings for block bootstrapping + # sanity checks + if not isinstance(eval_ds, xr.Dataset): + raise ValueError("%{0}: Argument 'eval_ds' must be a xarray dataset.".format(method)) + + if not isinstance(fcst_prod_dict, dict): + raise ValueError("%{0}: Argument 'fcst_prod_dict' must be dictionary with short names of forecast product" + + "as key and long names as value.".format(method)) + + try: + nhours = np.shape(eval_ds.coords["fcst_hour"])[0] + except Exception as err: + print("%{0}: Input argument 'eval_ds' appears to be unproper.".format(method)) + raise err + + nmodels = len(fcst_prod_dict.values()) + colors = ["blue", "red", "black", "grey"] + for metric in eval_metrics: + # create a new figure object + fig = plt.figure(figsize=(6, 4)) + ax = plt.axes([0.1, 0.15, 0.75, 0.75]) + hours = np.arange(1, nhours+1) + + for ifcst, fcst_prod in enumerate(fcst_prod_dict.keys()): + metric_name = "{0}_{1}_{2}".format(varname, fcst_prod, metric) + try: + metric2plt = eval_ds[metric_name+"_avg"] + metric_boot = eval_ds[metric_name+"_bootstrapped"] + except Exception as err: + print("%{0}: Could not retrieve {1} and/or {2} from evaluation metric dataset." + .format(method, metric_name, metric_name+"_boot")) + raise err + # plot the data + metric2plt_min = metric_boot.quantile(0.05, dim="iboot") + metric2plt_max = metric_boot.quantile(0.95, dim="iboot") + plt.plot(hours, metric2plt, label=fcst_prod, color=colors[ifcst], marker="o") + plt.fill_between(hours, metric2plt_min, metric2plt_max, facecolor=colors[ifcst], alpha=0.3) + # configure plot + plt.xticks(hours) + # automatic y-limits for PSNR wich can be negative and positive + if metric != "psnr": ax.set_ylim(0., None) + legend = ax.legend(loc="upper right", bbox_to_anchor=(1.15, 1)) + ax.set_xlabel("Lead time [hours]") + ax.set_ylabel(metric.upper()) + plt_fname = os.path.join(out_dir, "evaluation_{0}".format(metric)) + print("Saving basic evaluation plot in terms of {1} to '{2}'".format(method, metric, plt_fname)) + plt.savefig(plt_fname) + + plt.close() + + return True + + @staticmethod + def create_plot(data, data_diff, varname, plt_fname): + """ + Creates filled contour plot of forecast data and also draws contours for differences. + ML: So far, only plotting of the 2m temperature is supported (with 12 predicted hours/frames) + :param data: the forecasted data array to be plotted + :param data_diff: the reference data ('ground truth') + :param varname: the name of the variable + :param plt_fname: the filename to the store the plot + :return: - + """ + method = Postprocess.create_plot.__name__ + + try: + coords = data.coords + # handle coordinates and forecast times + lat, lon = coords["lat"], coords["lon"] + date0 = pd.to_datetime(coords["init_time"].data) + fhhs = coords["fcst_hour"].data + except Exception as err: + print("%{0}: Could not retrieve expected coordinates lat, lon and time_forecast from data.".format(method)) + raise err + + lons, lats = np.meshgrid(lon, lat) + + date0_str = date0.strftime("%Y-%m-%d %H:%M UTC") + + # check data to be plotted since programme is not generic so far + if np.shape(fhhs)[0] != 12: + raise ValueError("%{0}: Currently, only 12 hour forecast can be handled properly.".format(method)) + + if varname != "2t": + raise ValueError("%{0}: Currently, only 2m temperature is plotted nicely properly.".format(method)) + + # define levels + clevs = np.arange(-10., 40., 1.) + clevs_diff = np.arange(0.5, 10.5, 2.) + clevs_diff2 = np.arange(-10.5, -0.5, 2.) + + # create fig and subplot axes + fig, axes = plt.subplots(2, 6, sharex=True, sharey=True, figsize=(12, 6)) + axes = axes.flatten() + + # create all subplots + for t, fhh in enumerate(fhhs): + m = Basemap(projection='cyl', llcrnrlat=np.min(lat), urcrnrlat=np.max(lat), + llcrnrlon=np.min(lon), urcrnrlon=np.max(lon), resolution='l', ax=axes[t]) + m.drawcoastlines() + x, y = m(lons, lats) + if t%6 == 0: + lat_lab = [1, 0, 0, 0] + axes[t].set_ylabel(u'Latitude', labelpad=30) + else: + lat_lab = list(np.zeros(4)) + if t/6 >= 1: + lon_lab = [0, 0, 0, 1] + axes[t].set_xlabel(u'Longitude', labelpad=15) + else: + lon_lab = list(np.zeros(4)) + m.drawmapboundary() + m.drawparallels(np.arange(0, 90, 5),labels=lat_lab, xoffset=1.) + m.drawmeridians(np.arange(5, 355, 10),labels=lon_lab, yoffset=1.) + cs = m.contourf(x, y, data.isel(fcst_hour=t)-273.15, clevs, cmap=plt.get_cmap("jet"), ax=axes[t], + extend="both") + cs_c_pos = m.contour(x, y, data_diff.isel(fcst_hour=t), clevs_diff, linewidths=0.5, ax=axes[t], + colors="black") + cs_c_neg = m.contour(x, y, data_diff.isel(fcst_hour=t), clevs_diff2, linewidths=1, linestyles="dotted", + ax=axes[t], colors="black") + axes[t].set_title("{0} +{1:02d}:00".format(date0_str, int(fhh)), fontsize=7.5, pad=4) + + fig.subplots_adjust(top=0.92, bottom=0.08, left=0.10, right=0.95, hspace=-0.7, + wspace=0.05) + # add colorbar. + cbar_ax = fig.add_axes([0.3, 0.22, 0.4, 0.02]) + cbar = fig.colorbar(cs, cax=cbar_ax, orientation="horizontal") + cbar.set_label('°C') + # save to disk + plt.savefig(plt_fname, bbox_inches="tight") + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("--results_dir", type=str, default='results', + help="ignored if output_gif_dir is specified") + parser.add_argument("--checkpoint", + help="directory with checkpoint or checkpoint name (e.g. checkpoint_dir/model-200000)") + parser.add_argument("--mode", type=str, choices=['train', 'val', 'test'], default='test', + help='mode for dataset, val or test.') + parser.add_argument("--batch_size", type=int, default=8, help="number of samples in batch") + parser.add_argument("--num_stochastic_samples", type=int, default=1) + parser.add_argument("--stochastic_plot_id", type=int, default=0, + help="The stochastic generate images index to plot") + parser.add_argument("--gpu_mem_frac", type=float, default=0.95, help="fraction of gpu memory to use") + parser.add_argument("--seed", type=int, default=7) + args = parser.parse_args() + + print('----------------------------------- Options ------------------------------------') + for k, v in args._get_kwargs(): + print(k, "=", v) + print('------------------------------------- End --------------------------------------') + + # ML: test_instance is a bit misleading here + test_instance = Postprocess(results_dir=args.results_dir, checkpoint=args.checkpoint, mode="test", + batch_size=args.batch_size, num_stochastic_samples=args.num_stochastic_samples, + gpu_mem_frac=args.gpu_mem_frac, seed=args.seed, + stochastic_plot_id=args.stochastic_plot_id, args=args) + + test_instance() + test_instance.run() + test_instance.handle_eval_metrics() + test_instance.plot_example_forecasts(metric="mse") + + +if __name__ == '__main__': + main() diff --git a/video_prediction_tools/model_modules/model_architectures.py b/video_prediction_tools/model_modules/model_architectures.py index 79c4b5c67e8e5bd01fba57ec6a43cc70a06f107c..9ab53f1918b2dd57d43e77ee8f4dd5a7556b2d5a 100644 --- a/video_prediction_tools/model_modules/model_architectures.py +++ b/video_prediction_tools/model_modules/model_architectures.py @@ -14,6 +14,7 @@ def known_models(): 'ours_vae_l1': 'SAVPVideoPredictionModel', 'ours_gan': 'SAVPVideoPredictionModel', "weatherBench": "WeatherBenchModel" - } + 'precrnn_v2': 'PredRNNv2VideoPredictionModel' + } return model_mappings diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py index cd0ec2b230169016cc10aee5ee2ff3d7e4fc611b..3b7afcc929e4affcf6f7aa14da37808d1e1faf78 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py @@ -1,29 +1,24 @@ -from .base_dataset import BaseVideoDataset -from .base_dataset import VideoDataset, SequenceExampleVideoDataset, VarLenFeatureVideoDataset -from .google_robot_dataset import GoogleRobotVideoDataset -from .sv2p_dataset import SV2PVideoDataset -from .softmotion_dataset import SoftmotionVideoDataset -from .kth_dataset import KTHVideoDataset -from .ucf101_dataset import UCF101VideoDataset -from .cartgripper_dataset import CartgripperVideoDataset -from .era5_dataset import ERA5Dataset -from .moving_mnist import MovingMnist -from data_preprocess.dataset_options import known_datasets -#from .era5_dataset_v2_anomaly import ERA5Dataset_v2_anomaly +#from .base_dataset import BaseVideoDataset +#from .era5_dataset import ERA5Dataset +#from .gzprcp_dataset import GzprcpDataset +#from .moving_mnist import MovingMnist +#from data_preprocess.dataset_options import known_datasets +from .stats import MinMax, ZScore +from .dataset import Dataset +import dask +from dask.base import tokenize +from utils.dataset_utils import DATASETS, get_dataset_info, get_filename_template -def get_dataset_class(dataset): - dataset_mappings = known_datasets() - dataset_class = dataset_mappings.get(dataset, dataset) - print("datset_class",dataset_class) - if dataset_class is None: - raise ValueError('Invalid dataset %s' % dataset) - else: - # ERA5Dataset movning_mnist does not inherit anything from VarLenFeatureVideoDataset-class, so it is the only dataset which does not need to be a subclass of BaseVideoDataset - #if not dataset_class == "ERA5Dataset" or not dataset_class == "MovingMnist": - # dataset_class = globals().get(dataset_class) - # if not issubclass(dataset_class,BaseVideoDataset): - # raise ValueError('Dataset {0} is not a valid dataset'.format(dataset_class)) - #else: - dataset_class = globals().get(dataset_class) +normalise = {"MinMax": MinMax, + "ZScore": ZScore} - return dataset_class +def get_dataset(name: str, *args, **kwargs): + try: + ds_info = get_dataset_info(name) + except ValueError as e: + raise ValueError(f"unknown dataset: {name}") + + return Dataset(*args, **kwargs, + normalize=normalise[ds_info["normalize"]], + filename_template=get_filename_template(name) + ) diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/base_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/base_dataset.py deleted file mode 100644 index 99d7ac163883cbea2da2ab2ad1da156ebc2b5ff1..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/datasets/base_dataset.py +++ /dev/null @@ -1,510 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -import glob -import os -import random -import re -from collections import OrderedDict -import numpy as np -import tensorflow as tf -from tensorflow.contrib.training import HParams - - -class BaseVideoDataset(object): - def __init__(self, input_dir: str, mode: str = "train", num_epochs: int = None, seed: int = None, - hparams_dict=None, hparams=None): - """ - This class is used for preparing data for training/validation and test models. - :param input_dir: the path of tfrecords files - :param mode: "train","val" or "test" - :param num_epochs: number of epochs - :param seed: the seed for dataset - :param hparams_dict: a dict of `name=value` pairs, where `name` must be defined in `self.get_default_hparams()`. - :param hparams: a dict of `name=value` pairs where `name` must be defined in `self.get_default_hparams()`. - These values overrides any values in hparams_dict (if any). - """ - method = self.__class__.__name__ - - self.input_dir = os.path.normpath(os.path.expanduser(input_dir)) - self.mode = mode - self.num_epochs = num_epochs - self.seed = seed - self.shuffled = False # will be set properly in make_dataset-method - # sanity checks - if self.mode not in ('train', 'val', 'test'): - raise ValueError('%{0}: Invalid mode {1}'.format(method, self.mode)) - if not os.path.exists(self.input_dir): - raise FileNotFoundError("%{0} input_dir '{1}' does not exist".format(method, self.input_dir)) - self.filenames = None - # look for tfrecords in input_dir and input_dir/mode directories - for input_dir in [self.input_dir, os.path.join(self.input_dir, self.mode)]: - filenames = glob.glob(os.path.join(input_dir, '*.tfrecord*')) - if filenames: - self.input_dir = input_dir - self.filenames = sorted(filenames) # ensures order is the same across systems - break - if not self.filenames: - raise FileNotFoundError('No tfrecords were found in %s.' % self.input_dir) - self.dataset_name = os.path.basename(os.path.split(self.input_dir)[0]) - - self.state_like_names_and_shapes = OrderedDict() - self.action_like_names_and_shapes = OrderedDict() - - self.hparams = self.parse_hparams(hparams_dict, hparams) - - def get_default_hparams_dict(self): - """ - Returns: - A dict with the following hyperparameters. - - crop_size: crop image into a square with sides of this length. - scale_size: resize image to this size after it has been cropped. - context_frames: the number of ground-truth frames to pass in at - start. - sequence_length: the number of frames in the video sequence, so - state-like sequences are of length sequence_length and - action-like sequences are of length sequence_length - 1. - This number includes the context frames. - long_sequence_length: the number of frames for the long version. - The default is the same as sequence_length. - frame_skip: number of frames to skip in between outputted frames, - so frame_skip=0 denotes no skipping. - time_shift: shift in time by multiples of this, so time_shift=1 - denotes all possible shifts. time_shift=0 denotes no shifting. - It is ignored (equiv. to time_shift=0) when mode != 'train'. - force_time_shift: whether to do the shift in time regardless of - mode. - shuffle_on_val: whether to shuffle the samples regardless if mode - is 'train' or 'val'. Shuffle never happens when mode is 'test'. - use_state: whether to load and return state and actions. - """ - hparams = dict( - crop_size=0, - scale_size=0, - context_frames=1, - sequence_length=0, - long_sequence_length=0, - frame_skip=0, - time_shift=1, - force_time_shift=False, - shuffle_on_val=False, - use_state=False, - ) - return hparams - - def get_default_hparams(self): - return HParams(**self.get_default_hparams_dict()) - - def parse_hparams(self, hparams_dict, hparams): - parsed_hparams = self.get_default_hparams().override_from_dict(hparams_dict or {}) - if hparams: - if not isinstance(hparams, (list, tuple)): - hparams = [hparams] - for hparam in hparams: - parsed_hparams.parse(hparam) - if parsed_hparams.long_sequence_length == 0: - parsed_hparams.long_sequence_length = parsed_hparams.sequence_length - return parsed_hparams - - @property - def jpeg_encoding(self): - raise NotImplementedError - - def set_sequence_length(self, sequence_length): - self.hparams.sequence_length = sequence_length - - def filter(self, serialized_example): - return tf.convert_to_tensor(True) - - def parser(self, serialized_example): - """ - Parses a single tf.train.Example or tf.train.SequenceExample into - images, states, actions, etc tensors. - """ - raise NotImplementedError - - def make_dataset(self, batch_size): - filenames = self.filenames - shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val) - if shuffle: - self.shuffled = True - random.shuffle(filenames) - - dataset = tf.data.TFRecordDataset(filenames, buffer_size= 8 * 1024 * 1024) #todo: what is buffer_size - dataset = dataset.filter(self.filter) - if shuffle: - dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=1024, count=self.num_epochs)) - else: - dataset = dataset.repeat(self.num_epochs) - - def _parser(serialized_example): - state_like_seqs, action_like_seqs = self.parser(serialized_example) - seqs = OrderedDict(list(state_like_seqs.items()) + list(action_like_seqs.items())) - return seqs - - num_parallel_calls = None if shuffle else 1 # for reproducibility (e.g. sampled subclips from the test set) - dataset = dataset.apply(tf.contrib.data.map_and_batch( - _parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) # Bing: Parallel data mapping, num_parallel_calls normally depends on the hardware, however, normally should be equal to be the usalbe number of CPUs - dataset = dataset.prefetch(batch_size) #Bing: Take the data to buffer inorder to save the waiting time for GPU - return dataset - - def make_batch(self, batch_size): - dataset = self.make_dataset(batch_size) - iterator = dataset.make_one_shot_iterator() - return iterator.get_next() - - def decode_and_preprocess_images(self, image_buffers, image_shape): - def decode_and_preprocess_image(image_buffer): - print("image buffer", tf.shape(image_buffer)) - - image_buffer = tf.reshape(image_buffer,[],name="reshape_1") - - if self.jpeg_encoding: - image = tf.image.decode_jpeg(image_buffer) - print("14********image decode_jpeg********", image) - else: - image = tf.decode_raw(image_buffer, tf.uint8) - print("15 ********image decode_raw********", tf.shape(image)) - print("16 ******** image shape", image_shape) - - image = tf.reshape(image, image_shape, name="reshape_4") ##Bing:the bug #issue 1 is here - crop_size = self.hparams.crop_size - scale_size = self.hparams.scale_size - if crop_size or scale_size: - if not crop_size: - crop_size = min(image_shape[0], image_shape[1]) - image = tf.image.resize_image_with_crop_or_pad(image, crop_size, crop_size) - image = tf.reshape(image, [crop_size, crop_size, 3],"reshape_3") - if scale_size: - # upsample with bilinear interpolation but downsample with area interpolation - if crop_size < scale_size: - image = tf.image.resize_images(image, [scale_size, scale_size], - method=tf.image.ResizeMethod.BILINEAR) - elif crop_size > scale_size: - image = tf.image.resize_images(image, [scale_size, scale_size], - method=tf.image.ResizeMethod.AREA) - else: - # image remains unchanged - pass - return image - - if not isinstance(image_buffers, (list, tuple)): - image_buffers = tf.unstack(image_buffers) - print("17 **************image buffer", image_buffers[0]) - images = [decode_and_preprocess_image(image_buffer) for image_buffer in image_buffers] - images = tf.image.convert_image_dtype(images, dtype=tf.float32) - return images - - def slice_sequences(self, state_like_seqs, action_like_seqs, example_sequence_length): - """ - Slices sequences of length `example_sequence_length` into subsequences - of length `sequence_length`. The dicts of sequences are updated - in-place and the same dicts are returned. - """ - # handle random shifting and frame skip - sequence_length = self.hparams.sequence_length # desired sequence length - frame_skip = self.hparams.frame_skip - time_shift = self.hparams.time_shift - print("22***********example sequence_length",example_sequence_length) - if (time_shift and self.mode == 'train') or self.hparams.force_time_shift: - print("23***********I am here") - assert time_shift > 0 and isinstance(time_shift, int) - if isinstance(example_sequence_length, tf.Tensor): - example_sequence_length = tf.cast(example_sequence_length, tf.int32) - num_shifts = ((example_sequence_length - 1) - (sequence_length - 1) * (frame_skip + 1)) // time_shift - assert_message = ('example_sequence_length has to be at least %d when ' - 'sequence_length=%d, frame_skip=%d.' % - ((sequence_length - 1) * (frame_skip + 1) + 1, - sequence_length, frame_skip)) - with tf.control_dependencies([tf.assert_greater_equal(num_shifts, 0, - data=[example_sequence_length, num_shifts], message=assert_message)]): - t_start = tf.random_uniform([], 0, num_shifts + 1, dtype=tf.int32, seed=self.seed) * time_shift - else: - t_start = 0 - print("20:**********************sequence_len: {}, t_start:{}, frame_skip:{}".format(sequence_length,tf.shape(t_start),frame_skip)) - state_like_t_slice = slice(t_start, t_start + (sequence_length - 1) * (frame_skip + 1) + 1, frame_skip + 1) - action_like_t_slice = slice(t_start, t_start + (sequence_length - 1) * (frame_skip + 1)) - - for example_name, seq in state_like_seqs.items(): - print("21*****************seq*******",seq) - seq = tf.convert_to_tensor(seq)[state_like_t_slice] - print("25**************ses.shape", [self.hparams.sequence_length] + seq.shape.as_list()[1:]) - seq.set_shape([sequence_length] + seq.shape.as_list()[1:]) - state_like_seqs[example_name] = seq - for example_name, seq in action_like_seqs.items(): - seq = tf.convert_to_tensor(seq)[action_like_t_slice] - seq.set_shape([(sequence_length - 1) * (frame_skip + 1)] + seq.shape.as_list()[1:]) - # concatenate actions of skipped frames into single macro actions - seq = tf.reshape(seq, [sequence_length - 1, -1]) - action_like_seqs[example_name] = seq - return state_like_seqs, action_like_seqs - - def num_examples_per_epoch(self): - raise NotImplementedError - - -class VideoDataset(BaseVideoDataset): - """ - This class supports reading tfrecords where a sequence is stored as - multiple tf.train.Example and each of them is stored under a different - feature name (which is indexed by the time step). - """ - def __init__(self, *args, **kwargs): - super(VideoDataset, self).__init__(*args, **kwargs) - self._max_sequence_length = None - self._dict_message = None - - def _check_or_infer_shapes(self): - """ - Should be called after state_like_names_and_shapes and - action_like_names_and_shapes have been finalized. - """ - state_like_names_and_shapes = OrderedDict([(k, list(v)) for k, v in self.state_like_names_and_shapes.items()]) - action_like_names_and_shapes = OrderedDict([(k, list(v)) for k, v in self.action_like_names_and_shapes.items()]) - from google.protobuf.json_format import MessageToDict - example = next(tf.python_io.tf_record_iterator(self.filenames[0])) - self._dict_message = MessageToDict(tf.train.Example.FromString(example)) - for example_name, name_and_shape in (list(state_like_names_and_shapes.items()) + - list(action_like_names_and_shapes.items())): - name, shape = name_and_shape - feature = self._dict_message['features']['feature'] - names = [name_ for name_ in feature.keys() if re.search(name.replace('%d', '\d+'), name_) is not None] - if not names: - raise ValueError('Could not found any feature with name pattern %s.' % name) - if example_name in self.state_like_names_and_shapes: - sequence_length = len(names) - else: - sequence_length = len(names) + 1 - if self._max_sequence_length is None: - self._max_sequence_length = sequence_length - else: - self._max_sequence_length = min(sequence_length, self._max_sequence_length) - name = names[0] - feature = feature[name] - list_type, = feature.keys() - if list_type == 'floatList': - inferred_shape = (len(feature[list_type]['value']),) - if shape is None: - name_and_shape[1] = inferred_shape - else: - if inferred_shape != shape: - raise ValueError('Inferred shape for feature %s is %r but instead got shape %r.' % - (name, inferred_shape, shape)) - elif list_type == 'bytesList': - image_str, = feature[list_type]['value'] - # try to infer image shape - inferred_shape = None - if not self.jpeg_encoding: - spatial_size = len(image_str) // 4 - height = width = int(np.sqrt(spatial_size)) # assume square image - if len(image_str) == (height * width * 4): - inferred_shape = (height, width, 3) - if shape is None: - if inferred_shape is not None: - name_and_shape[1] = inferred_shape - else: - raise ValueError('Unable to infer shape for feature %s of size %d.' % (name, len(image_str))) - else: - if inferred_shape is not None and inferred_shape != shape: - raise ValueError('Inferred shape for feature %s is %r but instead got shape %r.' % - (name, inferred_shape, shape)) - else: - raise NotImplementedError - self.state_like_names_and_shapes = OrderedDict([(k, tuple(v)) for k, v in state_like_names_and_shapes.items()]) - self.action_like_names_and_shapes = OrderedDict([(k, tuple(v)) for k, v in action_like_names_and_shapes.items()]) - - # set sequence_length to the longest possible if it is not specified - if not self.hparams.sequence_length: - self.hparams.sequence_length = (self._max_sequence_length - 1) // (self.hparams.frame_skip + 1) + 1 - - def set_sequence_length(self, sequence_length): - if not sequence_length: - sequence_length = (self._max_sequence_length - 1) // (self.hparams.frame_skip + 1) + 1 - self.hparams.sequence_length = sequence_length - - def parser(self, serialized_example): - """ - Parses a single tf.train.Example into images, states, actions, etc tensors. - """ - features = dict() - for i in range(self._max_sequence_length): - for example_name, (name, shape) in self.state_like_names_and_shapes.items(): - if example_name == 'images': # special handling for image - features[name % i] = tf.FixedLenFeature([1], tf.string) - else: - features[name % i] = tf.FixedLenFeature(shape, tf.float32) - for i in range(self._max_sequence_length - 1): - for example_name, (name, shape) in self.action_like_names_and_shapes.items(): - features[name % i] = tf.FixedLenFeature(shape, tf.float32) - - # check that the features are in the tfrecord - for name in features.keys(): - if name not in self._dict_message['features']['feature']: - raise ValueError('Feature with name %s not found in tfrecord. Possible feature names are:\n%s' % - (name, '\n'.join(sorted(self._dict_message['features']['feature'].keys())))) - - # parse all the features of all time steps together - features = tf.parse_single_example(serialized_example, features=features) - - - state_like_seqs = OrderedDict([(example_name, []) for example_name in self.state_like_names_and_shapes]) - action_like_seqs = OrderedDict([(example_name, []) for example_name in self.action_like_names_and_shapes]) - for i in range(self._max_sequence_length): - for example_name, (name, shape) in self.state_like_names_and_shapes.items(): - state_like_seqs[example_name].append(features[name % i]) - for i in range(self._max_sequence_length - 1): - for example_name, (name, shape) in self.action_like_names_and_shapes.items(): - action_like_seqs[example_name].append(features[name % i]) - - # for this class, it's much faster to decode and preprocess the entire sequence before sampling a slice - _, image_shape = self.state_like_names_and_shapes['images'] - state_like_seqs['images'] = self.decode_and_preprocess_images(state_like_seqs['images'], image_shape) - - state_like_seqs, action_like_seqs = \ - self.slice_sequences(state_like_seqs, action_like_seqs, self._max_sequence_length) - return state_like_seqs, action_like_seqs - - -class SequenceExampleVideoDataset(BaseVideoDataset): - """ - This class supports reading tfrecords where an entire sequence is stored as - a single tf.train.SequenceExample. - """ - def parser(self, serialized_example): - """ - Parses a single tf.train.SequenceExample into images, states, actions, etc tensors. - """ - sequence_features = dict() - for example_name, (name, shape) in self.state_like_names_and_shapes.items(): - if example_name == 'images': # special handling for image - sequence_features[name] = tf.FixedLenSequenceFeature([1], tf.string) - else: - sequence_features[name] = tf.FixedLenSequenceFeature(shape, tf.float32) - for example_name, (name, shape) in self.action_like_names_and_shapes.items(): - sequence_features[name] = tf.FixedLenSequenceFeature(shape, tf.float32) - - _, sequence_features = tf.parse_single_sequence_example( - serialized_example, sequence_features=sequence_features) - - state_like_seqs = OrderedDict() - action_like_seqs = OrderedDict() - for example_name, (name, shape) in self.state_like_names_and_shapes.items(): - state_like_seqs[example_name] = sequence_features[name] - for example_name, (name, shape) in self.action_like_names_and_shapes.items(): - action_like_seqs[example_name] = sequence_features[name] - - # the sequence_length of this example is determined by the shortest sequence - example_sequence_length = [] - for example_name, seq in state_like_seqs.items(): - example_sequence_length.append(tf.shape(seq)[0]) - for example_name, seq in action_like_seqs.items(): - example_sequence_length.append(tf.shape(seq)[0] + 1) - example_sequence_length = tf.reduce_min(example_sequence_length) - #bing - state_like_seqs, action_like_seqs = \ - self.slice_sequences(state_like_seqs, action_like_seqs, example_sequence_length) - - # decode and preprocess images on the sampled slice only - _, image_shape = self.state_like_names_and_shapes['images'] - state_like_seqs['images'] = self.decode_and_preprocess_images(state_like_seqs['images'], image_shape) - return state_like_seqs, action_like_seqs - - -class VarLenFeatureVideoDataset(BaseVideoDataset): - """ - This class supports reading tfrecords where an entire sequence is stored as - a single tf.train.Example. - - https://github.com/tensorflow/tensorflow/issues/15977 - """ - def filter(self, serialized_example): - features = dict() - features['sequence_length'] = tf.FixedLenFeature((), tf.int64) - features = tf.parse_single_example(serialized_example, features=features) - example_sequence_length = features['sequence_length'] - return tf.greater_equal(example_sequence_length, self.hparams.sequence_length) - - def parser(self, serialized_example): - """ - Parses a single tf.train.SequenceExample into images, states, actions, etc tensors. - """ - print("1.***parser function from class VarLenFeatureVideoDatase") - features = dict() - features['sequence_length'] = tf.FixedLenFeature((), tf.int64) - for example_name, (name, shape) in self.state_like_names_and_shapes.items(): - if example_name == 'images': - #Bing - #features[name] = tf.FixedLenFeature([1], tf.string) - features[name] = tf.VarLenFeature(tf.string) - else: - features[name] = tf.VarLenFeature(tf.float32) - for example_name, (name, shape) in self.action_like_names_and_shapes.items(): - features[name] = tf.VarLenFeature(tf.float32) - - features = tf.parse_single_example(serialized_example, features=features) - example_sequence_length = features['sequence_length'] - - state_like_seqs = OrderedDict() - action_like_seqs = OrderedDict() - for example_name, (name, shape) in self.state_like_names_and_shapes.items(): - if example_name == 'images': - seq = tf.sparse_tensor_to_dense(features[name], '') - else: - seq = tf.sparse_tensor_to_dense(features[name]) - seq = tf.reshape(seq, [example_sequence_length] + list(shape)) - - state_like_seqs[example_name] = seq - - - for example_name, (name, shape) in self.action_like_names_and_shapes.items(): - - seq = tf.sparse_tensor_to_dense(features[name]) - seq = tf.reshape(seq, [example_sequence_length - 1] + list(shape)) - action_like_seqs[example_name] = seq - - #Bing: I replce the self.slice_sequence to the following three lines , the program works, but I need to figure it out what happend inside this function - state_like_seqs, action_like_seqs = \ - self.slice_sequences(state_like_seqs, action_like_seqs, example_sequence_length) - # seq = tf.convert_to_tensor(seq) - # print("25**************ses.shape",[self.hparams.sequence_length] + seq.shape.as_list()[1:]) - # seq.set_shape([self.hparams.sequence_length] + seq.shape.as_list()[1:]) - # state_like_seqs[example_name] = seq - #print("11**********Slide sequences**************** ", action_like_seqs) - # decode and preprocess images on the sampled slice only - _, image_shape = self.state_like_names_and_shapes['images'] - - state_like_seqs['images'] = self.decode_and_preprocess_images(state_like_seqs['images'], image_shape) - return state_like_seqs, action_like_seqs - - -if __name__ == '__main__': - import cv2 - from video_prediction import datasets - - datasets = [ - datasets.SV2PVideoDataset('data/shape', mode='val'), - datasets.SV2PVideoDataset('data/humans', mode='val'), - datasets.SoftmotionVideoDataset('data/bair', mode='val'), - datasets.KTHVideoDataset('data/kth', mode='val'), - datasets.KTHVideoDataset('data/kth_128', mode='val'), - datasets.UCF101VideoDataset('data/ucf101', mode='val'), - ] - batch_size = 4 - - sess = tf.Session() - - for dataset in datasets: - inputs = dataset.make_batch(batch_size) - images = inputs['images'] - images = tf.reshape(images, [-1] + images.get_shape().as_list()[2:]) - images = sess.run(images) - images = (images * 255).astype(np.uint8) - for image in images: - if image.shape[-1] == 1: - image = np.tile(image, [1, 1, 3]) - else: - image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) - cv2.imshow(dataset.input_dir, image) - cv2.waitKey(50) diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/cartgripper_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/cartgripper_dataset.py deleted file mode 100644 index 3a4fc6b5ddbd5768b51a0bc2f7418609291916de..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/datasets/cartgripper_dataset.py +++ /dev/null @@ -1,28 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -import itertools - -from .base_dataset import VideoDataset -from .softmotion_dataset import SoftmotionVideoDataset - - -class CartgripperVideoDataset(SoftmotionVideoDataset): - def __init__(self, *args, **kwargs): - VideoDataset.__init__(self, *args, **kwargs) - self.state_like_names_and_shapes['images'] = '%d/image_view0/encoded', (48, 64, 3) - if self.hparams.use_state: - self.state_like_names_and_shapes['states'] = '%d/endeffector_pos', (6,) - self.action_like_names_and_shapes['actions'] = '%d/action', (3,) - self._check_or_infer_shapes() - - def get_default_hparams_dict(self): - default_hparams = super(CartgripperVideoDataset, self).get_default_hparams_dict() - hparams = dict( - context_frames=2, - sequence_length=15, - time_shift=3, - use_state=True, - ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..b42702a03842ba53ac3b0bcf8db14b7f83b98bc1 --- /dev/null +++ b/video_prediction_tools/model_modules/video_prediction/datasets/dataset.py @@ -0,0 +1,245 @@ +__author__ = "Bing Gong" +__date__ = "2022-03-17" +__email__ = "b.gong@fz-juelich.de" + +import json +import os +from typing import List +from dataclasses import dataclass +from pathlib import Path + +import xarray as xr +import tensorflow as tf + +from hparams_utils import * +from model_modules.video_prediction.datasets.stats import DatasetStats, Normalize + + +class Dataset: + modes = ["train", "val", "test"] + dims = ["time", "lat", "lon", "variables"] + hparams = [ + "context_frames", + "max_epochs", + "batch_size", + "shuffle_on_val", + "sequence_length", + "shift" + ] + + def __init__( + self, + input_dir: Path, + output_dir: Path, + datasplit_path: str, + hparams_path: str, + normalize, + filename_template: str, + seed: int = None, + nsamples_ref: int = None # TODO: implemment ? + ): + """ + This class is used for preparing data for training/validation and test models + :param input_dir: the path of tfrecords files + :param datasplit_path: the path pointing to the datasplit_config json file + :param hparams_path: the path to the dict that contains hparameters, + :param mode: string, "train","val" or "test" + :param seed: int, the seed for shuffeling the dataset + :param nsamples_ref: number of reference samples which can be used to control repetition factor for dataset + for ensuring adopted size of dataset iterator (used for validation data during training) + Example: Let nsamples_ref be 1000 while the current datset consists 100 samples, then + the repetition-factor will be 10 (i.e. nsamples*rep_fac = nsamples_ref) + :param normalize: class of the desired normalization method + """ + self.input_dir = input_dir + self.output_dir = output_dir + self.seed = seed # used for shuffeling + self.nsamples_ref = nsamples_ref + self.normalize = normalize + self.filename_template = filename_template + + # sanity checks + if not os.path.exists(self.input_dir): + raise FileNotFoundError("input_dir '{self.input_dir}' does not exist") + + # get configuration parameters from datasplit- and model parameters-files + with open(datasplit_path, "r") as f: # TODO:maybe sanity check + self.datasplit = json.loads(f.read()) + + with open(hparams_path, "r") as f: # TODO:maybe sanity check + hparams = dotdict(json.loads(f.read())) + + try: + self.context_frames = hparams["context_frames"] + self.max_epochs = hparams["max_epochs"] + self.batch_size = hparams["batch_size"] + self.shuffle_on_val = hparams["shuffle_on_val"] + self.sequence_length = hparams["sequence_length"] + self.shift = hparams["shift"] + except KeyError as e: + raise ValueError(f"missing hyperparameter: {e.args[0]}") + + self._stats_lookup = { + "train": None, + "val": None, + "test": None, + } + + def load_data(self, files): + """ + load DataSet from files and transform to DataArray (n_samples, lat, lon, channels). + """ + + ds = xr.open_mfdataset(files).load() + da = ds.to_array(dim="variables").squeeze() + return da.transpose(*Dataset.dims) + + def filenames(self, mode): + """ + Get the filenames for training, validation and testing dataset. + + :param mode: differentiate datasets, should be "train", "val" or "test" + """ + time_window = self.datasplit[mode] + files = [] + # {"2008":[1,2,3,4,...], "2009":[1,2,3,4,...]} + for year, months in time_window.items(): + for month in months: + files.append(self.input_dir / self.filename_template.format(year=year, month=month)) + + return files + + def _get_data(self, mode): + """ + Load data from files into memory and calculate statistics. + + :param mode: indicator to differentiate between training, validation and test data + """ + files = self.filenames(mode) + if not len(files) > 0: + raise Exception( + f"no files for dataset {mode} found, check data_split dictionary" + ) + da = self.load_data(files) + + stats = DatasetStats( + da.mean(dim=Dataset.dims[:3]).values, + da.std(dim=Dataset.dims[:3]).values, + da.max(dim=Dataset.dims[:3]).values, + da.min(dim=Dataset.dims[:3]).values, + da.sizes["time"] + ) + + self._stats_lookup[mode] = stats + return da + + def make_dataset(self, mode, use_training_stats=True): + """ + Prepare Tensorflow dataset, load data and do all nessecary preprocessing. + """ + if mode not in Dataset.modes: + raise ValueError(f"Invalid mode {mode}") + + shuffle = mode == "train" or (mode == "val" and self.shuffle_on_val) + + # get data array + da = self._get_data(mode).load() # load everything into memory + print(f"xarray info: {da.shape}") + + def data_generator(iterable): + iterator = iter(iterable) + yield from iterator + + # create tf dataset + dataset = tf.data.Dataset.from_generator( + data_generator, + args = [da], + output_types=tf.float32, + output_shapes=da.shape[1:], + ) + + # create training sequences + dataset = dataset.window( + self.sequence_length, shift=self.shift, drop_remainder=True + ) + dataset = dataset.flat_map(lambda window: window.batch(self.sequence_length)) + + # shuffle + if shuffle: + dataset = dataset.apply( + tf.contrib.data.shuffle_and_repeat( + buffer_size=1024, count=self.max_epochs, seed=self.seed + ) + ) # TODO: check, self.seed + else: + dataset = dataset.repeat(self.max_epochs) + + # create batches + dataset = dataset.batch(self.batch_size) + + # normalize + if use_training_stats: # use training stats for normalization + stats = self.training_stats + print(f"used training stats: {stats}") + else: # use corresponding stats for normalization + stats = self._stats_lookup[mode] + + normalize = self.normalize(stats) + stats.to_json(self.output_dir / "normalization_stats.json") + + dataset = dataset.map(normalize.normalize_vars) + + return dataset + + @property + def num_training_samples(self): + """ + obtain the number of samples per each epoch + :return: int + """ + stats = self._get_stats("train") + return int((stats.n + self.shift) / (self.sequence_length + self.shift)) + + @property + def num_test_samples(self): + """ + obtain the number of samples per each epoch + :return: int + """ + stats = self._get_stats("test") + return int((stats.n + self.shift) / (self.sequence_length + self.shift)) + + @property + def num_validation_samples(self): + """ + obtain the number of samples per each epoch + :return: int + """ + stats = self._get_stats("val") + return int((stats.n + self.shift) / (self.sequence_length + self.shift)) + + def make_training(self): + return self.make_dataset(mode="train") + + def make_test(self): + return self.make_dataset(mode="test") + + def make_validation(self): + return self.make_dataset(mode="val") + + def _get_stats(self, mode): + if self._stats_lookup[mode] is None: + self._get_data(mode) + return self._stats_lookup[mode] + + @property + def training_stats(self): + return self._get_stats("train") + + @property + def validation_stats(self): + return self._get_stats("val") + + @property + def test_stats(self): + return self._get_stats("test") diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/google_robot_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/google_robot_dataset.py deleted file mode 100644 index 518e0736c7a4714a13f60226647d134636e5cb9e..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/datasets/google_robot_dataset.py +++ /dev/null @@ -1,44 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -import itertools -import os - -from .base_dataset import VideoDataset - - -class GoogleRobotVideoDataset(VideoDataset): - """ - https://sites.google.com/view/sna-visual-mpc - """ - def __init__(self, *args, **kwargs): - super(GoogleRobotVideoDataset, self).__init__(*args, **kwargs) - self.state_like_names_and_shapes['images'] = 'move/%d/image/encoded', (512, 640, 3) - if self.hparams.use_state: - self.state_like_names_and_shapes['states'] = 'move/%d/endeffector/vec_pitch_yaw', (5,) - self.action_like_names_and_shapes['actions'] = 'move/%d/commanded_pose/vec_pitch_yaw', (5,) - self._check_or_infer_shapes() - - def get_default_hparams_dict(self): - default_hparams = super(GoogleRobotVideoDataset, self).get_default_hparams_dict() - hparams = dict( - context_frames=2, - sequence_length=15, - ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) - - def num_examples_per_epoch(self): - if os.path.basename(self.input_dir) == 'push_train': - count = 51615 - elif os.path.basename(self.input_dir) == 'push_testseen': - count = 1038 - elif os.path.basename(self.input_dir) == 'push_testnovel': - count = 995 - else: - raise NotImplementedError - return count - - @property - def jpeg_encoding(self): - return True diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py deleted file mode 100644 index 1c29fe5fa11c406f8de8c60fcd29d4bc8de60e10..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/datasets/kth_dataset.py +++ /dev/null @@ -1,226 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -import argparse -import glob -import itertools -import os -import pickle -import random -import re -import tensorflow as tf -import numpy as np -import skimage.io -from collections import OrderedDict -from tensorflow.contrib.training import HParams -from google.protobuf.json_format import MessageToDict - - -class KTHVideoDataset(object): - def __init__(self,input_dir=None,datasplit_config=None,hparams_dict_config=None, mode='train',seed=None): - """ - This class is used for preparing data for training/validation and test models - args: - input_dir : the path of tfrecords files - datasplit_config : the path pointing to the datasplit_config json file - hparams_dict_config : the path to the dict that contains hparameters, - mode : string, "train","val" or "test" - seed : int, the seed for dataset - """ - self.input_dir = input_dir - self.datasplit_config = datasplit_config - self.mode = mode - self.seed = seed - if self.mode not in ('train', 'val', 'test'): - raise ValueError('Invalid mode %s' % self.mode) - if not os.path.exists(self.input_dir): - raise FileNotFoundError("input_dir %s does not exist" % self.input_dir) - self.datasplit_dict_path = datasplit_config - self.data_dict = self.get_datasplit() - self.hparams_dict_config = hparams_dict_config - self.hparams_dict = self.get_model_hparams_dict() - self.hparams = self.parse_hparams() - self.get_tfrecords_filesnames_base_datasplit() - self.get_example_info() - - - def get_default_hparams(self): - return HParams(**self.get_default_hparams_dict()) - - def get_default_hparams_dict(self): - """ - The function that contains default hparams - Returns: - A dict with the following hyperparameters. - context_frames : the number of ground-truth frames to pass in at start. - sequence_length : the number of frames in the video sequence - max_epochs : the number of epochs to train model - lr : learning rate - loss_fun : the loss function - """ - hparams = dict( - context_frames=10, - sequence_length=20, - max_epochs = 20, - batch_size = 40, - lr = 0.001, - loss_fun = "rmse", - shuffle_on_val= True, - ) - return hparams - - def get_datasplit(self): - """ - Get the datasplit json file - """ - - with open(self.datasplit_dict_path) as f: - self.d = json.load(f) - return self.d - - def parse_hparams(self): - """ - Parse the hparams setting to ovoerride the default ones - """ - parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) - return parsed_hparams - - - def get_tfrecords_filesnames_base_datasplit(self): - """ - Get absolute .tfrecord path names based on the data splits patterns - """ - self.filenames = [] - self.data_mode = self.data_dict[self.mode] - self.tf_names = [] - for year, months in self.data_mode.items(): - for month in months: - tf_files = "sequence_Y_{}_M_{}_*_to_*.tfrecord*".format(year,month) - self.tf_names.append(tf_files) - # look for tfrecords in input_dir and input_dir/mode directories - for files in self.tf_names: - self.filenames.extend(glob.glob(os.path.join(self.input_dir, files))) - if self.filenames: - self.filenames = sorted(self.filenames) # ensures order is the same across systems - if not self.filenames: - raise FileNotFoundError('No tfrecords were found in %s' % self.input_dir) - - def num_examples_per_epoch(self): - """ - Calculate how many tfrecords samples in the train/val/test - """ - #count how many tfrecords files for train/val/testing - len_fnames = len(self.filenames) - seq_len_file = os.path.join(self.input_dir, 'number_sequences.txt') - with open(seq_len_file, 'r') as sequence_lengths_file: - sequence_lengths = sequence_lengths_file.readlines() - sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths] - self.num_examples_per_epoch = len_fnames * sequence_lengths[0] - return self.num_examples_per_epoch - - -def _bytes_feature(value): - return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) - - -def _bytes_list_feature(values): - return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) - - -def _int64_feature(value): - return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) - - -def partition_data(input_dir): - # List files and corresponding person IDs - fnames = glob.glob(os.path.join(input_dir, '*/*')) - fnames = [fname for fname in fnames if os.path.isdir(fname)] - print("frames",fnames[0]) - persons = [re.match('person(\d+)_\w+_\w+', os.path.split(fname)[1]).group(1) for fname in fnames] - persons = np.array([int(person) for person in persons]) - train_mask = persons <= 16 - train_fnames = [fnames[i] for i in np.where(train_mask)[0]] - test_fnames = [fnames[i] for i in np.where(~train_mask)[0]] - random.shuffle(train_fnames) - pivot = int(0.95 * len(train_fnames)) - train_fnames, val_fnames = train_fnames[:pivot], train_fnames[pivot:] - return train_fnames, val_fnames, test_fnames - - -def save_tf_record(output_fname, sequences): - print('saving sequences to %s' % output_fname) - with tf.python_io.TFRecordWriter(output_fname) as writer: - for sequence in sequences: - num_frames = len(sequence) - height, width, channels = sequence[0].shape - encoded_sequence = [image.tostring() for image in sequence] - features = tf.train.Features(feature={ - 'sequence_length': _int64_feature(num_frames), - 'height': _int64_feature(height), - 'width': _int64_feature(width), - 'channels': _int64_feature(channels), - 'images/encoded': _bytes_list_feature(encoded_sequence), - }) - example = tf.train.Example(features=features) - writer.write(example.SerializeToString()) - - - def read_frames_and_save_tf_records(output_dir, video_dirs, image_size, sequences_per_file=128): - partition_name = os.path.split(output_dir)[1] #Get the folder name train, val or test - sequences = [] - sequence_iter = 0 - sequence_lengths_file = open(os.path.join(output_dir, 'sequence_lengths.txt'), 'w') - for video_iter, video_dir in enumerate(video_dirs): #Interate group (e.g. walking) each person - meta_partition_name = partition_name if partition_name == 'test' else 'train' - meta_fname = os.path.join(os.path.split(video_dir)[0], '%s_meta%dx%d.pkl' % - (meta_partition_name, image_size, image_size)) - with open(meta_fname, "rb") as f: - data = pickle.load(f) # The data has 62 items, each item is a dict, with three keys. "vid","n", and "files", Each file has 4 channels, each channel has n sequence images with 64*64 png - - vid = os.path.split(video_dir)[1] - (d,) = [d for d in data if d['vid'] == vid] - for frame_fnames_iter, frame_fnames in enumerate(d['files']): - frame_fnames = [os.path.join(video_dir, frame_fname) for frame_fname in frame_fnames] - frames = skimage.io.imread_collection(frame_fnames) - # they are grayscale images, so just keep one of the channels - frames = [frame[..., 0:1] for frame in frames] - - if not sequences: #The length of the sequence in sequences could be different - last_start_sequence_iter = sequence_iter - print("reading sequences starting at sequence %d" % sequence_iter) - - sequences.append(frames) - sequence_iter += 1 - sequence_lengths_file.write("%d\n" % len(frames)) - - if (len(sequences) == sequences_per_file or - (video_iter == (len(video_dirs) - 1) and frame_fnames_iter == (len(d['files']) - 1))): - output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter - 1) - output_fname = os.path.join(output_dir, output_fname) - save_tf_record(output_fname, sequences) - sequences[:] = [] - sequence_lengths_file.close() - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("input_dir", type=str, help="directory containing the processed directories " - "boxing, handclapping, handwaving, " - "jogging, running, walking") - parser.add_argument("output_dir", type=str) - parser.add_argument("image_size", type=int) - args = parser.parse_args() - partition_names = ['train', 'val', 'test'] - print("input dir", args.input_dir) - partition_fnames = partition_data(args.input_dir) - print("partiotion_fnames[0]", partition_fnames[0]) - for partition_name, partition_fnames in zip(partition_names, partition_fnames): - partition_dir = os.path.join(args.output_dir, partition_name) - if not os.path.exists(partition_dir): - os.makedirs(partition_dir) - read_frames_and_save_tf_records(partition_dir, partition_fnames, args.image_size) - - -if __name__ == '__main__': - main() diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py b/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py deleted file mode 100644 index 45a51248592e5a94ff951e00a143a9fcd6abc482..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py +++ /dev/null @@ -1,246 +0,0 @@ -# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) -# -# SPDX-License-Identifier: MIT - -__email__ = "b.gong@fz-juelich.de" -__author__ = "Bing Gong, Karim" -__date__ = "2021-05-03" - -import glob -import os -import random -import json -import numpy as np -import tensorflow as tf -from tensorflow.contrib.training import HParams -from collections import OrderedDict -from google.protobuf.json_format import MessageToDict - - -class MovingMnist(object): - def __init__(self, input_dir: str = None, datasplit_config: str = None, hparams_dict_config: str = None, - mode: str = "train", seed: int = None, nsamples_ref: int = None): - """ - This class is used for preparing data for training/validation and test models - :param input_dir: the path of tfrecords files - :param datasplit_config: the path pointing to the datasplit_config json file - :param hparams_dict_config: the path to the dict that contains hparameters, - :param mode: string, "train","val" or "test" - :param seed: int, the seed for dataset - :param nsamples_ref: number of reference samples whch can be used to control repetition factor for dataset - for ensuring adopted size of dataset iterator (used for validation data during training) - Example: Let nsamples_ref be 1000 while the current datset consists 100 samples, then - the repetition-factor will be 10 (i.e. nsamples*rep_fac = nsamples_ref) - """ - method = self.__class__.__name__ - - self.input_dir = input_dir - self.mode = mode - self.seed = seed - self.sequence_length = None # will be set in get_example_info - self.shuffled = False # will be set properly in make_dataset-method - # sanity checks - if self.mode not in ('train', 'val', 'test'): - raise ValueError('%{0}: Invalid mode {1}'.format(method, self.mode)) - if not os.path.exists(self.input_dir): - raise FileNotFoundError("%{0} input_dir '{1}' does not exist".format(method, self.input_dir)) - if nsamples_ref is not None: - self.nsamples_ref = nsamples_ref - self.datasplit_dict_path = datasplit_config - self.data_dict = self.get_datasplit() - self.hparams_dict_config = hparams_dict_config - self.hparams_dict = self.get_model_hparams_dict() - self.hparams = self.parse_hparams() - self.get_tfrecords_filename_base_datasplit() - self.get_example_info() - - def get_datasplit(self): - """ - Get the datasplit json file - """ - with open(self.datasplit_dict_path) as f: - datasplit_dict = json.load(f) - return datasplit_dict - - def get_model_hparams_dict(self): - """ - Get model_hparams_dict from json file - """ - self.model_hparams_dict_load = {} - if self.hparams_dict_config: - with open(self.hparams_dict_config) as f: - self.model_hparams_dict_load.update(json.loads(f.read())) - return self.model_hparams_dict_load - - def parse_hparams(self): - """ - Parse the hparams setting to ovoerride the default ones - """ - parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) - return parsed_hparams - - def get_default_hparams(self): - return HParams(**self.get_default_hparams_dict()) - - def get_default_hparams_dict(self): - """ - The function that contains default hparams - Returns: - A dict with the following hyperparameters. - context_frames : the number of ground-truth frames to pass in at start. - sequence_length : the number of frames in the video sequence - max_epochs : the number of epochs to train model - lr : learning rate - loss_fun : the loss function - :return: - """ - hparams = dict( - context_frames=10, - sequence_length=20, - max_epochs=20, - batch_size=40, - lr=0.001, - loss_fun="rmse", - shuffle_on_val=True, - ) - return hparams - - def get_tfrecords_filename_base_datasplit(self): - """ - Get obsoluate .tfrecords names based on the data splits patterns - """ - self.filenames = [] - self.data_mode = self.data_dict[self.mode] - self.all_filenames = glob.glob(os.path.join(self.input_dir,"*.tfrecords")) - print("self.all_files",self.all_filenames) - for indice_group, index in self.data_mode.items(): - fs = [MovingMnist.string_filter(max_value=index[1], min_value=index[0], string=s) for s in self.all_filenames] - print("fs:",fs) - self.tf_names = [self.all_filenames[fs_index] for fs_index in range(len(fs)) if fs[fs_index]==True] - print("tf_names,",self.tf_names) - # look for tfrecords in input_dir and input_dir/mode directories - for files in self.tf_names: - self.filenames.extend(glob.glob(os.path.join(self.input_dir, files))) - if self.filenames: - self.filenames = sorted(self.filenames) # ensures order is the same across systems - if not self.filenames: - raise FileNotFoundError('No tfrecords were found in %s' % self.input_dir) - - @staticmethod - def string_filter(max_value=None, min_value=None, string="input_directory/sequence_index_0_index_10.tfrecords"): - a = os.path.split(string)[-1].split("_") - if not len(a) == 5: - raise ("The tfrecords pattern does not match the expected pattern, for instance: 'sequence_index_0_to_10.tfrecords'") - min_index = int(a[2]) - max_index = int(a[4].split(".")[0]) - if min_index >= min_value and max_index <= max_value: - return True - else: - return False - - def get_example_info(self): - """ - Get the data information from tfrecord file - """ - example = next(tf.python_io.tf_record_iterator(self.filenames[0])) - dict_message = MessageToDict(tf.train.Example.FromString(example)) - feature = dict_message['features']['feature'] - print("features in dataset:",feature.keys()) - video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', - 'width', 'channels']) - self.sequence_length = video_shape[0] - self.image_shape = video_shape[1:] - - def num_examples_per_epoch(self): - """ - Calculate how many tfrecords samples in the train/val/test - """ - # count how many tfrecords files for train/val/testing - len_fnames = len(self.filenames) - num_seq_file = os.path.join(self.input_dir, 'number_sequences.txt') - with open(num_seq_file, 'r') as dfile: - num_seqs = dfile.readlines() - num_sequences = [int(num_seq.strip()) for num_seq in num_seqs] - num_examples_per_epoch = len_fnames * num_sequences[0] - - return num_examples_per_epoch - - def make_dataset(self, batch_size): - """ - Prepare batch_size dataset fed into to the models. - If the data are from training dataset,then the data is shuffled; - If the data are from val dataset, the shuffle var will be decided by the hparams.shuffled_on_val; - if the data are from test dataset, the data will not be shuffled - args: - batch_size: int, the size of samples fed into the models per iteration - """ - method = MovingMnist.make_dataset.__name__ - - self.num_epochs = self.hparams.max_epochs - - def parser(serialized_example): - seqs = OrderedDict() - keys_to_features = { - 'width': tf.FixedLenFeature([], tf.int64), - 'height': tf.FixedLenFeature([], tf.int64), - 'sequence_length': tf.FixedLenFeature([], tf.int64), - 'channels': tf.FixedLenFeature([],tf.int64), - 'images/encoded': tf.VarLenFeature(tf.float32) - } - parsed_features = tf.parse_single_example(serialized_example, keys_to_features) - seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"]) - print("Image shape {}, {},{},{}".format(self.sequence_length,self.image_shape[0],self.image_shape[1], - self.image_shape[2])) - images = tf.reshape(seq, [self.sequence_length,self.image_shape[0],self.image_shape[1], - self.image_shape[2]], name = "reshape_new") - seqs["images"] = images - return seqs - filenames = self.filenames - shuffle = self.mode == 'train' or (self.mode == 'val' and self.hparams.shuffle_on_val) - if shuffle: - self.shuffled = True - random.shuffle(filenames) - dataset = tf.data.TFRecordDataset(filenames, buffer_size=8*1024*1024) - # set-up dataset iterator - nrepeat = self.num_epochs - if self.nsamples_ref: - num_samples = self.num_examples_per_epoch() - nrepeat = int(nrepeat*max(int(np.ceil(self.nsamples_ref/num_samples)), 1)) - - if shuffle: - dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=1024, count=nrepeat)) - else: - dataset = dataset.repeat(nrepeat) - num_parallel_calls = None if shuffle else 1 - dataset = dataset.apply(tf.contrib.data.map_and_batch( - parser, batch_size, drop_remainder=True, num_parallel_calls=num_parallel_calls)) - dataset = dataset.prefetch(batch_size) - return dataset - - def make_batch(self, batch_size): - dataset = self.make_dataset(batch_size) - iterator = dataset.make_one_shot_iterator() - return iterator.get_next() - - -# further auxiliary methods -def _bytes_feature(value): - return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) - - -def _bytes_list_feature(values): - return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) - - -def _floats_feature(value): - return tf.train.Feature(float_list=tf.train.FloatList(value=value)) - - -def _int64_feature(value): - return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) - - - - - - diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/softmotion_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/softmotion_dataset.py deleted file mode 100644 index 4e0a7375b32c56cd4ebfff801277d2460f2d6c44..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/datasets/softmotion_dataset.py +++ /dev/null @@ -1,84 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -import itertools -import os -import re -import tensorflow as tf -from model_modules.video_prediction.utils import tf_utils -from .base_dataset import VideoDataset - - -class SoftmotionVideoDataset(VideoDataset): - """ - https://sites.google.com/view/sna-visual-mpc - """ - def __init__(self, *args, **kwargs): - super(SoftmotionVideoDataset, self).__init__(*args, **kwargs) - # infer name of image feature and check if object_pos feature is present - from google.protobuf.json_format import MessageToDict - example = next(tf.python_io.tf_record_iterator(self.filenames[0])) - dict_message = MessageToDict(tf.train.Example.FromString(example)) - feature = dict_message['features']['feature'] - image_names = set() - for name in feature.keys(): - m = re.search('\d+/(\w+)/encoded', name) - if m: - image_names.add(m.group(1)) - # look for image_aux1 and image_view0 in that order of priority - image_name = None - for name in ['image_aux1', 'image_view0']: - if name in image_names: - image_name = name - break - if not image_name: - if len(image_names) == 1: - image_name = image_names.pop() - else: - raise ValueError('The examples have images under more than one name.') - self.state_like_names_and_shapes['images'] = '%%d/%s/encoded' % image_name, None - if self.hparams.use_state: - self.state_like_names_and_shapes['states'] = '%d/endeffector_pos', (3,) - self.action_like_names_and_shapes['actions'] = '%d/action', (4,) - if any([re.search('\d+/object_pos', name) for name in feature.keys()]): - self.state_like_names_and_shapes['object_pos'] = '%d/object_pos', None # shape is (2 * num_designated_pixels) - self._check_or_infer_shapes() - - def get_default_hparams_dict(self): - default_hparams = super(SoftmotionVideoDataset, self).get_default_hparams_dict() - hparams = dict( - context_frames=2, - sequence_length=12, - long_sequence_length=30, - time_shift=2, - ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) - - @property - def jpeg_encoding(self): - return False - - def parser(self, serialized_example): - state_like_seqs, action_like_seqs = super(SoftmotionVideoDataset, self).parser(serialized_example) - if 'object_pos' in state_like_seqs: - object_pos = state_like_seqs['object_pos'] - height, width, _ = self.state_like_names_and_shapes['images'][1] - object_pos = tf.reshape(object_pos, [object_pos.shape[0].value, -1, 2]) - pix_distribs = tf.stack([tf_utils.pixel_distribution(object_pos_, height, width) - for object_pos_ in tf.unstack(object_pos, axis=1)], axis=-1) - state_like_seqs['pix_distribs'] = pix_distribs - return state_like_seqs, action_like_seqs - - def num_examples_per_epoch(self): - # extract information from filename to count the number of trajectories in the dataset - count = 0 - for filename in self.filenames: - match = re.search('traj_(\d+)_to_(\d+).tfrecords', os.path.basename(filename)) - start_traj_iter = int(match.group(1)) - end_traj_iter = int(match.group(2)) - count += end_traj_iter - start_traj_iter + 1 - - # alternatively, the dataset size can be determined like this, but it's very slow - # count = sum(sum(1 for _ in tf.python_io.tf_record_iterator(filename)) for filename in filenames) - return count diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/stats.py b/video_prediction_tools/model_modules/video_prediction/datasets/stats.py new file mode 100644 index 0000000000000000000000000000000000000000..a1934e6290712f715691fca02554bc012eda7dd6 --- /dev/null +++ b/video_prediction_tools/model_modules/video_prediction/datasets/stats.py @@ -0,0 +1,117 @@ +from abc import ABC, abstractmethod +import dataclasses as dc +import json +from pathlib import Path + +import numpy as np +import tensorflow as tf + +@dc.dataclass +class VarStats: + mean: float + std: np.ndarray + maximum: float + minimum: float + + +@dc.dataclass +class DatasetStats: + mean: np.ndarray + std: np.ndarray + maximum: np.ndarray + minimum: np.ndarray + n: int + + def as_array(self): + return np.vstack([self.mean, self.std, self.max, self.min]) + + def var_stats(self, index): + """extract specific stats for variable at position 'index.'""" + + return VarStats(self.mean[index], self.std[index], self.maximum[index], self.minimum[index]) + + @staticmethod + def from_json(path): + with open(path, "r") as f: + in_dict = json.read(f) + + return DatasetStats( + **{key: np.array(in_dict[key]) if key != "n" else in_dict[key] + for key in in_dict}) + + def to_json(self, path): + out_dict = dc.asdict(self) + with open(path, "w") as f: + json.dump( + {key: list(out_dict[key].astype(float)) if key != "n" else out_dict[key] + for key in out_dict}, f) + + + +class Normalize(ABC): + """ + Provide normalization and denormalization for different normalization approaches. + """ + def __init__(self, stats: DatasetStats): + self.stats: DatasetStats = stats + + @staticmethod + def _apply_over_vars(fun, x, stats): + # x = x.copy() # assure no inplace operation + + def inner_fun(i): + var_stats = stats.var_stats(i) + return fun(x[:, :, :, :, i], var_stats) + + + print(f"overall shape: {x.shape}") + # normalize each variable seperatly + x = tf.stack([fun(x[:, :, :, :, i], stats.var_stats(i)) for i in range(x.shape[-1])], axis=-1) + + print(f"normalized shape: {x.shape}") + return x + + def normalize_vars(self, x): + """ + Normalize each variable seperatly, using normalize_fun. + """ + return Normalize._apply_over_vars(self.normalize_fun, x, self.stats) + + def denormalize_vars(self, x): + """ + Denormalize each variable seperatly, using denormalize_fun. + """ + return Normalize._apply_over_vars(self.denormalize_fun, x, self.stats) + + @abstractmethod + def normalize_fun(self, x, stats: VarStats): + """ + Normalization for data of shape (batch_size, sequence_len, lat, lon). + """ + pass + + @abstractmethod + def denormalize_fun(self, x, stats: VarStats): + """ + Normalization for data of shape (batch_size, sequence_len, lat, lon). + """ + + + +class MinMax(Normalize): + def normalize_fun(self, x, stats: VarStats): + return (x - stats.minimum) / (stats.maximum - stats.minimum) + + def denormalize_fun(self, x, stats: VarStats): + return x * (stats.maximum - stats.minimum) + stats.minimum + + +class ZScore(Normalize): + """ + Implement ZScore (De)Normalization. + """ + def normalize_fun(self, x, stats: VarStats): + return (x - stats.mean) / stats.std + + def denormalize_fun(self, x, stats: VarStats): + return x * stats.std + stats.mean \ No newline at end of file diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/sv2p_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/sv2p_dataset.py deleted file mode 100644 index 78f7f9fad3b719120b4c5bb502d75234976eba28..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/datasets/sv2p_dataset.py +++ /dev/null @@ -1,69 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -import itertools -import os - -from .base_dataset import VideoDataset - - -class SV2PVideoDataset(VideoDataset): - def __init__(self, *args, **kwargs): - super(SV2PVideoDataset, self).__init__(*args, **kwargs) - self.dataset_name = os.path.basename(os.path.split(self.input_dir)[0]) - self.state_like_names_and_shapes['images'] = 'image_%d', (64, 64, 3) - if self.dataset_name == 'shape': - if self.hparams.use_state: - self.state_like_names_and_shapes['states'] = 'state_%d', (2,) - self.action_like_names_and_shapes['actions'] = 'action_%d', (2,) - elif self.dataset_name == 'humans': - if self.hparams.use_state: - raise ValueError('SV2PVideoDataset does not have states, use_state should be False') - else: - raise NotImplementedError - self._check_or_infer_shapes() - - def get_default_hparams_dict(self): - default_hparams = super(SV2PVideoDataset, self).get_default_hparams_dict() - if self.dataset_name == 'shape': - hparams = dict( - context_frames=1, - sequence_length=6, - time_shift=0, - use_state=False, - ) - elif self.dataset_name == 'humans': - hparams = dict( - context_frames=10, - sequence_length=20, - use_state=False, - ) - else: - raise NotImplementedError - return dict(itertools.chain(default_hparams.items(), hparams.items())) - - def num_examples_per_epoch(self): - if self.dataset_name == 'shape': - if os.path.basename(self.input_dir) == 'train': - count = 43415 - elif os.path.basename(self.input_dir) == 'val': - count = 2898 - else: # shape dataset doesn't have a test set - raise NotImplementedError - elif self.dataset_name == 'humans': - if os.path.basename(self.input_dir) == 'train': - count = 23910 - elif os.path.basename(self.input_dir) == 'val': - count = 10472 - elif os.path.basename(self.input_dir) == 'test': - count = 7722 - else: - raise NotImplementedError - else: - raise NotImplementedError - return count - - @property - def jpeg_encoding(self): - return True diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/ucf101_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/ucf101_dataset.py deleted file mode 100644 index b728692fb3080cebe10b3ae8d8798484bf00a339..0000000000000000000000000000000000000000 --- a/video_prediction_tools/model_modules/video_prediction/datasets/ucf101_dataset.py +++ /dev/null @@ -1,215 +0,0 @@ -# SPDX-FileCopyrightText: 2018, alexlee-gk -# -# SPDX-License-Identifier: MIT - -import argparse -import glob -import itertools -import os -import random -import re -from multiprocessing import Pool -import cv2 -import tensorflow as tf -from model_modules.video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset - - -class UCF101VideoDataset(VarLenFeatureVideoDataset): - def __init__(self, *args, **kwargs): - super(UCF101VideoDataset, self).__init__(*args, **kwargs) - self.state_like_names_and_shapes['images'] = 'images/encoded', (240, 320, 3) - - def get_default_hparams_dict(self): - default_hparams = super(UCF101VideoDataset, self).get_default_hparams_dict() - hparams = dict( - context_frames=4, - sequence_length=8, - random_crop_size=0, - use_state=False, - ) - return dict(itertools.chain(default_hparams.items(), hparams.items())) - - @property - def jpeg_encoding(self): - return True - - def decode_and_preprocess_images(self, image_buffers, image_shape): - if self.hparams.crop_size: - raise NotImplementedError - if self.hparams.scale_size: - raise NotImplementedError - image_buffers = tf.reshape(image_buffers, [-1]) - if not isinstance(image_buffers, (list, tuple)): - image_buffers = tf.unstack(image_buffers) - image_size = tf.image.extract_jpeg_shape(image_buffers[0])[:2] # should be the same as image_shape[:2] - if self.hparams.random_crop_size: - random_crop_size = [self.hparams.random_crop_size] * 2 - crop_y = tf.random_uniform([], minval=0, maxval=image_size[0] - random_crop_size[0], dtype=tf.int32) - crop_x = tf.random_uniform([], minval=0, maxval=image_size[1] - random_crop_size[1], dtype=tf.int32) - crop_window = [crop_y, crop_x] + random_crop_size - images = [tf.image.decode_and_crop_jpeg(image_buffer, crop_window) for image_buffer in image_buffers] - images = tf.image.convert_image_dtype(images, dtype=tf.float32) - images.set_shape([None] + random_crop_size + [image_shape[-1]]) - else: - images = [tf.image.decode_jpeg(image_buffer) for image_buffer in image_buffers] - images = tf.image.convert_image_dtype(images, dtype=tf.float32) - images.set_shape([None] + list(image_shape)) - # TODO: only random crop for training - return images - - def num_examples_per_epoch(self): - # extract information from filename to count the number of trajectories in the dataset - count = 0 - for filename in self.filenames: - match = re.search('sequence_(\d+)_to_(\d+).tfrecords', os.path.basename(filename)) - start_traj_iter = int(match.group(1)) - end_traj_iter = int(match.group(2)) - count += end_traj_iter - start_traj_iter + 1 - - # alternatively, the dataset size can be determined like this, but it's very slow - # count = sum(sum(1 for _ in tf.python_io.tf_record_iterator(filename)) for filename in filenames) - return count - - -def _bytes_feature(value): - return tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) - - -def _bytes_list_feature(values): - return tf.train.Feature(bytes_list=tf.train.BytesList(value=values)) - - -def _int64_feature(value): - return tf.train.Feature(int64_list=tf.train.Int64List(value=[value])) - - -def partition_data(input_dir, train_test_list_dir): - train_list_fnames = glob.glob(os.path.join(train_test_list_dir, 'trainlist*.txt')) - test_list_fnames = glob.glob(os.path.join(train_test_list_dir, 'testlist*.txt')) - test_list_fnames_mathieu = [os.path.join(train_test_list_dir, 'testlist01.txt')] - - def read_fnames(list_fnames): - fnames = [] - for list_fname in sorted(list_fnames): - with open(list_fname, 'r') as f: - while True: - fname = f.readline() - if not fname: - break - fnames.append(fname.split('\n')[0].split(' ')[0]) - return fnames - - train_fnames = read_fnames(train_list_fnames) - test_fnames = read_fnames(test_list_fnames) - test_fnames_mathieu = read_fnames(test_list_fnames_mathieu) - - train_fnames = [os.path.join(input_dir, train_fname) for train_fname in train_fnames] - test_fnames = [os.path.join(input_dir, test_fname) for test_fname in test_fnames] - test_fnames_mathieu = [os.path.join(input_dir, test_fname) for test_fname in test_fnames_mathieu] - # only use every 10 videos as in Mathieu et al. - test_fnames_mathieu = test_fnames_mathieu[::10] - - random.shuffle(train_fnames) - - pivot = int(0.95 * len(train_fnames)) - train_fnames, val_fnames = train_fnames[:pivot], train_fnames[pivot:] - return train_fnames, val_fnames, test_fnames, test_fnames_mathieu - - -def read_video(fname): - if not os.path.isfile(fname): - raise FileNotFoundError - vidcap = cv2.VideoCapture(fname) - frames, (success, image) = [], vidcap.read() - while success: - frames.append(image) - success, image = vidcap.read() - return frames - - -def save_tf_record(output_fname, sequences, preprocess_image): - print('saving sequences to %s' % output_fname) - with tf.python_io.TFRecordWriter(output_fname) as writer: - for sequence in sequences: - num_frames = len(sequence) - height, width, channels = sequence[0].shape - encoded_sequence = [preprocess_image(image) for image in sequence] - features = tf.train.Features(feature={ - 'sequence_length': _int64_feature(num_frames), - 'height': _int64_feature(height), - 'width': _int64_feature(width), - 'channels': _int64_feature(channels), - 'images/encoded': _bytes_list_feature(encoded_sequence), - }) - example = tf.train.Example(features=features) - writer.write(example.SerializeToString()) - - -def read_videos_and_save_tf_records(output_dir, fnames, start_sequence_iter=None, - end_sequence_iter=None, sequences_per_file=128): - print('started process with PID:', os.getpid()) - - if not os.path.exists(output_dir): - os.makedirs(output_dir) - - if start_sequence_iter is None: - start_sequence_iter = 0 - if end_sequence_iter is None: - end_sequence_iter = len(fnames) - - def preprocess_image(image): - if image.shape != (240, 320, 3): - image = cv2.resize(image, (320, 240), interpolation=cv2.INTER_LINEAR) - return tf.compat.as_bytes(cv2.imencode(".jpg", image)[1].tobytes()) - - print('reading and saving sequences {0} to {1}'.format(start_sequence_iter, end_sequence_iter)) - - sequences = [] - for sequence_iter in range(start_sequence_iter, end_sequence_iter): - if not sequences: - last_start_sequence_iter = sequence_iter - print("reading sequences starting at sequence %d" % sequence_iter) - - sequences.append(read_video(fnames[sequence_iter])) - - if len(sequences) == sequences_per_file or sequence_iter == (end_sequence_iter - 1): - output_fname = 'sequence_{0}_to_{1}.tfrecords'.format(last_start_sequence_iter, sequence_iter) - output_fname = os.path.join(output_dir, output_fname) - save_tf_record(output_fname, sequences, preprocess_image) - sequences[:] = [] - - -def main(): - parser = argparse.ArgumentParser() - parser.add_argument("input_dir", type=str, help="directory containing the directories of " - "classes, each of which contains avi files.") - parser.add_argument("train_test_list_dir", type=str, help='directory containing trainlist*.txt' - 'and testlist*.txt files.') - parser.add_argument("output_dir", type=str) - parser.add_argument('--num_workers', type=int, default=1, help='number of parallel workers') - args = parser.parse_args() - - partition_names = ['train', 'val', 'test', 'test_mathieu'] - partition_fnames = partition_data(args.input_dir, args.train_test_list_dir) - - for partition_name, partition_fnames in zip(partition_names, partition_fnames): - partition_dir = os.path.join(args.output_dir, partition_name) - if not os.path.exists(partition_dir): - os.makedirs(partition_dir) - - if args.num_workers > 1: - num_seqs_per_worker = len(partition_fnames) // args.num_workers - start_seq_iters = [num_seqs_per_worker * i for i in range(args.num_workers)] - end_seq_iters = [num_seqs_per_worker * (i + 1) - 1 for i in range(args.num_workers)] - end_seq_iters[-1] = len(partition_fnames) - - p = Pool(args.num_workers) - p.starmap(read_videos_and_save_tf_records, zip([partition_dir] * args.num_workers, - [partition_fnames] * args.num_workers, - start_seq_iters, end_seq_iters)) - else: - read_videos_and_save_tf_records(partition_dir, partition_fnames) - - -if __name__ == '__main__': - main() diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/weatherbench.json b/video_prediction_tools/model_modules/video_prediction/datasets/weatherbench.json new file mode 100644 index 0000000000000000000000000000000000000000..91288dd19ff3dafc8a77a43e37e7367d92cb8720 --- /dev/null +++ b/video_prediction_tools/model_modules/video_prediction/datasets/weatherbench.json @@ -0,0 +1,19 @@ +{ + "normalize": "ZScore", + "variables": [ + {"name": "temperature", "lvl": [850], "interpolation":"p"}, + {"name": "geopotential", "lvl": [500], "interpolation":"p"} + ], + "resolution": [ + {"deg": 5.625, "nx": 32, "ny": 64}, + {"deg": 2.8125, "nx": 64, "ny": 128}, + {"deg": 1.40625, "nx": 128, "ny": 256} + ], + "years": [ + 1979,1980, + 1981,1982,1983,1984,1985,1986,1987,1988,1989,1990, + 1991,1992,1993,1994,1995,1996,1997,1998,1999,2000, + 2001,2002,2003,2004,2005,2006,2007,2008,2009,2010, + 2011,2012,2013,2014,2015,2016,2017,2018 + ] +} \ No newline at end of file diff --git a/video_prediction_tools/model_modules/video_prediction/models/__init__.py b/video_prediction_tools/model_modules/video_prediction/models/__init__.py index 290def9f7c934f871fedd8c1703bbc114c822dcd..1bb913f18ccc7b9af2053b446584594bc6875c1b 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/__init__.py +++ b/video_prediction_tools/model_modules/video_prediction/models/__init__.py @@ -15,6 +15,8 @@ from .test_model import TestModelVideoPredictionModel from model_modules.model_architectures import known_models from .convLSTM_GAN_model import ConvLstmGANVideoPredictionModel from .weatherBench3DCNN import WeatherBenchModel +#from .vanilla_predrnnv2 import PredRNNv2VideoPredictionModel + def get_model_class(model): model_mappings = known_models() diff --git a/video_prediction_tools/model_modules/video_prediction/models/base_model.py b/video_prediction_tools/model_modules/video_prediction/models/base_model.py index 1857f8b915d62646dff9a73d63f16e78656ddc57..8ac05c98732eeca53c0356845faa3c8db59cb527 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/base_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/base_model.py @@ -9,6 +9,8 @@ import re from collections import OrderedDict import numpy as np import tensorflow as tf +print('tensorflow version: {}'.format(tf.__version__)) +print(tf.contrib.training.HParams) from tensorflow.contrib.training import HParams from tensorflow.python.util import nest import model_modules.video_prediction as vp diff --git a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py index daab8d9b72f8340941b78b2b9f4237f11ffcd238..c4035963dc37b840e55cdacb2c7288b10e458cfd 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py @@ -10,103 +10,59 @@ from model_modules.video_prediction.models.model_helpers import set_and_check_pr import tensorflow as tf from model_modules.video_prediction.layers import layer_def as ld from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell -from tensorflow.contrib.training import HParams -from .vanilla_convLSTM_model import VanillaConvLstmVideoPredictionModel +from .our_base_model import BaseModels -class batch_norm(object): - def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"): - with tf.variable_scope(name): - self.epsilon = epsilon - self.momentum = momentum - self.name = name - - def __call__(self, x, train=True): - return tf.contrib.layers.batch_norm(x, - decay=self.momentum, - updates_collections=None, - epsilon=self.epsilon, - scale=True, - is_training=train, - scope=self.name) - -class ConvLstmGANVideoPredictionModel(object): - def __init__(self, mode='train', hparams_dict=None): +class ConvLstmGANVideoPredictionModel(BaseModels): + def __init__(self, hparams_dict=None, mode='train', **kwargs): """ This is class for building convLSTM_GAN architecture by using updated hparameters args: - mode :str, "train" or "val", side note: mode may not be used in the convLSTM, but this will be a useful argument for the GAN-based model + mode :str, "train" or "val", side note: mode may not be used in the convLSTM, but this will be a useful argument for the GAN-based model hparams_dict: dict, the dictionary contains the hparaemters names and values """ + super().__init__(hparams_dict) + self.hparams = self.get_hparams() self.mode = mode - self.hparams_dict = hparams_dict - self.hparams = self.parse_hparams() - self.learning_rate = self.hparams.lr - self.total_loss = None - self.context_frames = self.hparams.context_frames - self.sequence_length = self.hparams.sequence_length - self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) - self.max_epochs = self.hparams.max_epochs - self.loss_fun = self.hparams.loss_fun - self.batch_size = self.hparams.batch_size - self.recon_weight = self.hparams.recon_weight - self.bd1 = batch_norm(name = "dis1") - self.bd2 = batch_norm(name = "dis2") - self.bd3 = batch_norm(name = "dis3") - - def get_default_hparams(self): - return HParams(**self.get_default_hparams_dict()) - - def parse_hparams(self): + self.bd1 = batch_norm(name = "bd1") + self.bd2 = batch_norm(name = "bd2") + + def get_hparams(self): """ - Parse the hparams setting to ovoerride the default ones + obtain the hparams from the dict to the class variables """ - - parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) - return parsed_hparams + method = BaseModels.get_hparams.__name__ + try: + self.context_frames = self.hparams.context_frames + self.max_epochs = self.hparams.max_epochs + self.batch_size = self.hparams.batch_size + self.shuffle_on_val = self.hparams.shuffle_on_val + self.loss_fun = self.hparams.loss_fun + self.recon_weight = self.hparams.recon_weight + self.learning_rate = self.hparams.lr + self.sequence_length = self.hparams.sequence_length + self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) + self.ngf = self.hparams.ngf + self.ndf = self.hparams.ndf + + + except Exception as error: + print("Method %{}: error: {}".format(method,error)) + raise("Method %{}: the hparameter dictionary must include parameters above".format(method)) + + + def build_graph(self, x: tf.Tensor): - def get_default_hparams_dict(self): - """ - The function that contains default hparams - Returns: - A dict with the following hyperparameters. - context_frames : the number of ground-truth frames to pass in at start. - sequence_length : the number of frames in the video sequence - max_epochs : the number of epochs to train model - lr : learning rate - loss_fun : the loss function - recon_wegiht : the weight for reconstrution loss - """ - hparams = dict( - context_frames=12, - sequence_length=24, - max_epochs = 20, - batch_size = 40, - lr = 0.001, - loss_fun = "cross_entropy", - shuffle_on_val= True, - recon_weight=0.99, - - ) - return hparams - - - def build_graph(self, x): self.is_build_graph = False self.inputs = x - self.x = x["images"] - self.width = self.x.shape.as_list()[3] - self.height = self.x.shape.as_list()[2] - self.channels = self.x.shape.as_list()[4] self.global_step = tf.train.get_or_create_global_step() original_global_variables = tf.global_variables() # Architecture - self.define_gan() - #This is the loss function (RMSE): - #This is loss function only for 1 channel (temperature RMSE) - #generator los + self.build_model() + # define loss function self.total_loss = (1-self.recon_weight) * self.G_loss + self.recon_weight*self.recon_loss self.D_loss = (1-self.recon_weight) * self.D_loss + if self.mode == "train": if self.recon_weight == 1: print("Only train generator- convLSTM") @@ -122,7 +78,6 @@ class ConvLstmGANVideoPredictionModel(object): else: self.train_op = None - self.outputs = {} self.outputs["gen_images"] = self.gen_images self.outputs["total_loss"] = self.total_loss # Summary op @@ -137,118 +92,122 @@ class ConvLstmGANVideoPredictionModel(object): self.saveable_variables = [self.global_step] + global_variables self.is_build_graph = True return self.is_build_graph - - def get_noise(self): - """ - Function for creating noise: Given the dimensions (n_batch,n_seq, n_height, n_width, channel) - """ - self.noise = tf.random.uniform(minval=-1., maxval=1., shape=[self.batch_size, self.sequence_length, self.height, self.width, self.channels]) - return self.noise - + @staticmethod - def lrelu(x, leak=0.2, name="lrelu"): - return tf.maximum(x, leak*x) + def Unet_ConvLSTM_cell(x: tf.Tensor, ngf: int, hidden: tf.Tensor): + """ + Build up a Unet ConvLSTM cell for each time stamp i + params: x: the input at timestamp i + params: ngf: the numnber of filters for convoluational layers + params: hidden: the hidden state from the previous timestamp t-1 + return: + outputs: the predict frame at timestamp i + hidden: the hidden state at current timestamp i + """ + input_shape = x.get_shape().as_list() + num_channels = input_shape[3] + with tf.variable_scope("down_scale", reuse = tf.AUTO_REUSE): + conv1f = ld.conv_layer(x, 3 , 1, ngf, 1, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") + conv1s = ld.conv_layer(conv1f, 3, 1, ngf, 2, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") + pool1 = tf.layers.max_pooling2d(conv1s, pool_size=(2, 2), strides=(2, 2)) + print('pool1 shape: ',pool1.shape) - @staticmethod - def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): - shape = input_.get_shape().as_list() + conv2f = ld.conv_layer(pool1, 3, 1, ngf * 2, 3, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") + conv2s = ld.conv_layer(conv2f, 3, 1, ngf * 2, 4, initializer = tf.contrib.layers.xavier_initializer(), activate = "relu") + pool2 = tf.layers.max_pooling2d(conv2s, pool_size=(2, 2), strides=(2, 2)) + print('pool2 shape: ',pool2.shape) - with tf.variable_scope(scope or "Linear"): - matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, - tf.random_normal_initializer(stddev=stddev)) - bias = tf.get_variable("bias", [output_size], - initializer=tf.constant_initializer(bias_start)) - if with_w: - return tf.matmul(input_, matrix) + bias, matrix, bias - else: - return tf.matmul(input_, matrix) + bias - - @staticmethod - def conv2d(input_, output_dim, k_h=5, k_w=5, d_h=2, d_w=2, stddev=0.02, name="conv2d"): - with tf.variable_scope(name): - w = tf.get_variable('w', [k_h, k_w, input_.get_shape()[-1], output_dim], - initializer=tf.truncated_normal_initializer(stddev=stddev)) - conv = tf.nn.conv2d(input_, w, strides=[1, d_h, d_w, 1], padding='SAME') + conv3f = ld.conv_layer(pool2, 3, 1, ngf * 4, 5, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") + conv3s = ld.conv_layer(conv3f, 3, 1, ngf * 4, 6, initializer = tf.contrib.layers.xavier_initializer(), activate = "relu") + pool3 = tf.layers.max_pooling2d(conv3s, pool_size=(2, 2), strides=(2, 2)) + print('pool3 shape: ',pool3.shape) - biases = tf.get_variable('biases', [output_dim], initializer=tf.constant_initializer(0.0)) - conv = tf.reshape(tf.nn.bias_add(conv, biases), conv.get_shape()) + convLSTM_input = pool3 + #convLSTM_input = tf.layers.dropout(pool2, 0.8) - return conv + convLSTM4, hidden = ConvLstmGANVideoPredictionModel.convLSTM_cell(convLSTM_input, hidden) + print('convLSTM4 shape: ',convLSTM4.shape) + + with tf.variable_scope("upscale", reuse = tf.AUTO_REUSE): + deconv5 = ld.transpose_conv_layer(convLSTM4, 2, 2, ngf * 4, 1, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") + print('deconv5 shape: ',deconv5.shape) + up5 = tf.concat([deconv5, conv3s], axis=3) + print('up5 shape: ',up5.shape) - @staticmethod - def bn(x, scope): - return tf.contrib.layers.batch_norm(x, - decay=0.9, - updates_collections=None, - epsilon=1e-5, - scale=True, - scope=scope) + conv5f = ld.conv_layer(up5, 3, 1, ngf * 4, 2, initializer = tf.contrib.layers.xavier_initializer(), activate="relu") + conv5s = ld.conv_layer(conv5f, 3, 1, ngf * 4, 3, initializer = tf.contrib.layers.xavier_initializer(), activate="relu") + print('conv5s shape:',conv5s.shape) + + deconv6 = ld.transpose_conv_layer(conv5s, 2, 2, ngf * 2, 4, initializer=tf.contrib.layers.xavier_initializer(), activate="relu") + print('deconv6 shape: ',deconv6.shape) + up6 = tf.concat([deconv6, conv2s], axis=3) + print('up6 shape: ',up6.shape) + + conv6f = ld.conv_layer(up6, 3, 1, ngf * 2, 5, initializer = tf.contrib.layers.xavier_initializer(), activate="relu") + conv6s = ld.conv_layer(conv6f, 3, 1, ngf * 2, 6, initializer = tf.contrib.layers.xavier_initializer(), activate="relu") + print('conv6s shape:',conv6s.shape) + + deconv7 = ld.transpose_conv_layer(conv6s, 2, 2, ngf, 7, initializer = tf.contrib.layers.xavier_initializer(), activate="relu") + print('deconv7 shape: ',deconv7.shape) + up7 = tf.concat([deconv7, conv1s], axis=3) + print('up7 shape: ',up7.shape) - def generator(self): + conv7f = ld.conv_layer(up7, 3, 1, ngf, 8, initializer = tf.contrib.layers.xavier_initializer(), activate="relu") + conv7s = ld.conv_layer(conv7f, 3, 1, ngf, 9, initializer = tf.contrib.layers.xavier_initializer(),activate= "relu") + print('conv7s shape:',conv7s.shape) + + conv7t = ld.conv_layer(conv7s, 3, 1, num_channels, 10, initializer = tf.contrib.layers.xavier_initializer(),activate="relu") + outputs = ld.conv_layer(conv7t, 1, 1, num_channels, 11, initializer = tf.contrib.layers.xavier_initializer(),activate="linear") + print('outputs shape: ',outputs.shape) + + return outputs, hidden + + def generator(self, x: tf.Tensor): """ - Function to build up the generator architecture + Function to build up the generator architecture, here we take Unet_ConvLSTM as generator args: input images: a input tensor with dimension (n_batch,sequence_length,height,width,channel) + output images: (n_batch,forecast_length,height,width,channel) """ - with tf.variable_scope("generator",reuse=tf.AUTO_REUSE): - layer_gen = self.convLSTM_network(self.x) - layer_gen_pred = layer_gen[:,self.context_frames-1:,:,:,:] - return layer_gen - + network_template = tf.make_template('network', ConvLstmGANVideoPredictionModel.Unet_ConvLSTM_cell) + with tf.variable_scope("generator", reuse = tf.AUTO_REUSE): + # create network + x_hat = [] + #This is for training (optimization of convLSTM layer) + hidden_g = None + for i in range(self.sequence_length-1): + print('i: ',i) + if i < self.context_frames: + x_1_g, hidden_g = network_template(x[:, i, :, :, :], self.ngf, hidden_g) + else: + x_1_g, hidden_g = network_template(x_1_g, self.ngf, hidden_g) + x_hat.append(x_1_g) + # pack them all together + x_hat = tf.stack(x_hat) + self.x_hat= tf.transpose(x_hat, [1, 0, 2, 3, 4]) + print('self.x_hat shape is: ',self.x_hat.shape) + return self.x_hat - def discriminator(self,vid): + def discriminator(self, x): """ Function that get discriminator architecture """ with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): - conv1 = tf.layers.conv3d(vid,64,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis1") + conv1 = tf.layers.conv3d(x, 4, kernel_size=[4,4,4], strides=[1,2,2], padding="SAME", name="dis1") conv1 = ConvLstmGANVideoPredictionModel.lrelu(conv1) - conv2 = tf.layers.conv3d(conv1,128,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis2") - conv2 = ConvLstmGANVideoPredictionModel.lrelu(self.bd1(conv2)) - conv3 = tf.layers.conv3d(conv2,256,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis3") - conv3 = ConvLstmGANVideoPredictionModel.lrelu(self.bd2(conv3)) - conv4 = tf.layers.conv3d(conv3,512,kernel_size=[4,4,4],strides=[2,2,2],padding="SAME",name="dis4") - conv4 = ConvLstmGANVideoPredictionModel.lrelu(self.bd3(conv4)) - conv5 = tf.layers.conv3d(conv4,1,kernel_size=[2,4,4],strides=[1,1,1],padding="SAME",name="dis5") - conv5 = tf.reshape(conv5, [-1,1]) - conv5sigmoid = tf.nn.sigmoid(conv5) - return conv5sigmoid,conv5 - - def discriminator0(self,image): - """ - Function that get discriminator architecture - """ - with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): - layer_disc = self.convLSTM_network(image) - layer_disc = layer_disc[:,self.context_frames-1:self.context_frames,:,:, 0:1] - return layer_disc - - def discriminator1(self,sequence): - """ - https://github.com/hwalsuklee/tensorflow-generative-model-collections/blob/master/GAN.py - Function that give the possibility of a sequence of frames is ture of false - the input squence shape is like [batch_size,time_seq_length,height,width,channel] (e.g., self.x[:,:self.context_frames,:,:,:]) - """ - with tf.variable_scope("discriminator",reuse=tf.AUTO_REUSE): - print(sequence.shape) - x = sequence[:,:,:,:,0:1] # extract targeted variable - x = tf.transpose(x, [0,2,3,1,4]) # sequence shape is like: [batch_size,height,width,time_seq_length] - x = tf.reshape(x,[x.shape[0],x.shape[1],x.shape[2],x.shape[3]]) - print(x.shape) - net = ConvLstmGANVideoPredictionModel.lrelu(ConvLstmGANVideoPredictionModel.conv2d(x, 64, 4, 4, 2, 2, name='d_conv1')) - net = ConvLstmGANVideoPredictionModel.lrelu(ConvLstmGANVideoPredictionModel.bn(ConvLstmGANVideoPredictionModel.conv2d(net, 128, 4, 4, 2, 2, name='d_conv2'),scope='d_bn2')) - net = tf.reshape(net, [self.batch_size, -1]) - net = ConvLstmGANVideoPredictionModel.lrelu(ConvLstmGANVideoPredictionModel.bn(ConvLstmGANVideoPredictionModel.linear(net, 1024, scope='d_fc3'),scope='d_bn3')) - out_logit = ConvLstmGANVideoPredictionModel.linear(net, 1, scope='d_fc4') + #conv2 = tf.layers.conv3d(conv1, 1, kernel_size=[4,4,4], strides=[1,2,2], padding="SAME", name="dis2") + conv2 = tf.reshape(conv1, [-1,1]) + #fc1 = ConvLstmGANVideoPredictionModel.lrelu(self.bd1(ConvLstmGANVideoPredictionModel.linear(conv2, output_size=256, scope='d_fc1'))) + fc2 = ConvLstmGANVideoPredictionModel.lrelu(self.bd2(ConvLstmGANVideoPredictionModel.linear(conv2, output_size=64, scope='d_fc2'))) + out_logit = ConvLstmGANVideoPredictionModel.linear(fc2, 1, scope='d_fc3') out = tf.nn.sigmoid(out_logit) - print(out.shape) - return out, out_logit + #out,out_logit = self.Conv3Dnet(x,self.ndf) + return out, out_logit def get_disc_loss(self): """ Return the loss of discriminator given inputs """ - real_labels = tf.ones_like(self.D_real) gen_labels = tf.zeros_like(self.D_fake) self.D_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_real_logits, labels=real_labels)) @@ -256,10 +215,11 @@ class ConvLstmGANVideoPredictionModel(object): self.D_loss = self.D_loss_real + self.D_loss_fake return self.D_loss - def get_gen_loss(self): """ Param: + num_images : the number of images the generator should produce, which is also the lenght of the real image + z_dim : the dimension of the noise vector, a scalar Return the loss of generator given inputs """ real_labels = tf.ones_like(self.D_fake) @@ -270,40 +230,36 @@ class ConvLstmGANVideoPredictionModel(object): """ Get trainable variables from discriminator and generator """ - print("trinable_varialbes", len(tf.trainable_variables())) self.disc_vars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")] self.gen_vars = [var for var in tf.trainable_variables() if var.name.startswith("generator")] - print("self.disc_vars",self.disc_vars) - print("self.gen_vars",self.gen_vars) - - def define_gan(self): + def build_model(self): """ Define gan architectures """ - self.noise = self.get_noise() - self.gen_images = self.generator() - #!!!! the input of discriminator should be changed when use different discriminators - self.D_real, self.D_real_logits = self.discriminator(self.x[:,self.context_frames:,:,:,:]) - self.D_fake, self.D_fake_logits = self.discriminator(self.gen_images[:,self.context_frames-1:,:,:,:]) + self.gen_images = self.generator(self.inputs) + self.D_real, self.D_real_logits = self.discriminator(self.inputs[:,self.context_frames:, :, :, 0:1]) # use the first varibale as targeted + #self.D_fake, self.D_fake_logits = self.discriminator(self.gen_images[:,:,:,:,0:1]) #0:1 + self.D_fake, self.D_fake_logits = self.discriminator(self.gen_images[:,self.context_frames-1:, :, :, 0:1]) #0:1 + self.get_gen_loss() self.get_disc_loss() self.get_vars() if self.loss_fun == "rmse": - self.recon_loss = tf.reduce_mean(tf.square(self.x[:, self.context_frames:,:,:,0] - self.gen_images[:,self.context_frames-1:,:,:,0])) + #self.recon_loss = tf.reduce_mean(tf.square(self.inputs[:, self.context_frames:,:,:,0] - self.gen_images[:,:,:,:,0])) + self.recon_loss = tf.reduce_mean(tf.square(self.inputs[:, self.context_frames:, :, :, 0] - self.gen_images[:, self.context_frames-1:, :, :, 0])) elif self.loss_fun == "cross_entropy": - x_flatten = tf.reshape(self.x[:, self.context_frames:,:,:,0],[-1]) - x_hat_predict_frames_flatten = tf.reshape(self.gen_images[:,self.context_frames-1:,:,:,0],[-1]) + x_flatten = tf.reshape(self.inputs[:, self.context_frames:,:,:,0],[-1]) + #x_hat_predict_frames_flatten = tf.reshape(self.gen_images[:,:,:,:,0],[-1]) + x_hat_predict_frames_flatten = tf.reshape(self.gen_images[:,self.context_frames-1:, :, :, 0], [-1]) bce = tf.keras.losses.BinaryCrossentropy() - self.recon_loss = bce(x_flatten,x_hat_predict_frames_flatten) + self.recon_loss = bce(x_flatten, x_hat_predict_frames_flatten) else: - raise ValueError("Loss function is not selected properly, you should chose either 'rmse' or 'cross_entropy'") - + raise ValueError("Loss function is not selected properly, you should chose either 'rmse' or 'cross_entropy'") @staticmethod def convLSTM_cell(inputs, hidden): y_0 = inputs #we only usd patch 1, but the original paper use patch 4 for the moving mnist case, but use 2 for Radar Echo Dataset - channels = inputs.get_shape()[-1] # conv lstm cell cell_shape = y_0.get_shape().as_list() channels = cell_shape[-1] @@ -312,32 +268,107 @@ class ConvLstmGANVideoPredictionModel(object): if hidden is None: hidden = cell.zero_state(y_0, tf.float32) output, hidden = cell(y_0, hidden) - output_shape = output.get_shape().as_list() - z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]]) - #we feed the learn representation into a 1 × 1 convolutional layer to generate the final prediction - x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid") - print('x_hat shape is: ',x_hat.shape) - return x_hat, hidden - - def convLSTM_network(self,x): - network_template = tf.make_template('network',VanillaConvLstmVideoPredictionModel.convLSTM_cell) # make the template to share the variables - # create network - x_hat = [] - - #This is for training (optimization of convLSTM layer) - hidden_g = None - for i in range(self.sequence_length-1): - if i < self.context_frames: - x_1_g, hidden_g = network_template(x[:, i, :, :, :], hidden_g) + return output, hidden + #output_shape = output.get_shape().as_list() + #z3 = tf.reshape(output, [-1, output_shape[1], output_shape[2], output_shape[3]]) + ###we feed the learn representation into a 1 × 1 convolutional layer to generate the final prediction + #x_hat = ld.conv_layer(z3, 1, 1, channels, "decode_1", activate="sigmoid") + #print('x_hat shape is: ',x_hat.shape) + #return x_hat, hidden + + def get_noise(self, x, sigma=0.2): + """ + Function for creating noise: Given the dimensions (n_batch,n_seq, n_height, n_width, channel) + """ + x_shape = x.get_shape().as_list() + noise = sigma * tf.random.uniform(minval=-1., maxval=1., shape=x_shape) + x = x + noise + return x + + def Conv3Dnet_v1(self, x, ndf): + conv1 = tf.layers.conv3d(x, ndf, kernel_size = [4, 4, 4], strides = [1, 2, 2], padding = "SAME", name = 'conv1') + conv1 = self.lrelu(conv1) + # conv2 = tf.layers.conv3d(conv1,ndf*2,kernel_size=[4,4,4],strides=[1,2,2],padding="SAME",name='conv2') + # conv2 = self.lrelu(conv2) + conv3 = tf.layers.conv3d(conv1, 1, kernel_size = [4, 4, 4], strides = [1, 1, 1], padding = "SAME", name = 'conv3') + fl = tf.reshape(conv3, [-1, 1]) + print('fl shape: ', fl.shape) + fc1 = self.lrelu(self.bd1(self.linear(fl, 256, scope = 'fc1'))) + print('fc1 shape: ', fc1.shape) + fc2 = self.lrelu(self.bd2(self.linear(fc1, 64, scope = 'fc2'))) + print('fc2 shape: ', fc2.shape) + out_logit = self.linear(fc2, 1, scope = 'out') + out = tf.nn.sigmoid(out_logit) + return out, out_logit + + + def Conv3Dnet_v2(self, x, ndf): + """ + args: + input images: a input tensor with dimension (n_batch,forecast_length,height,width,channel) + output images: + """ + conv1 = Conv3D(ndf, 4, strides = (1, 2, 2), padding = 'same', kernel_initializer = 'he_normal')(x) + bn1 = BatchNormalization()(conv1) + bn1 = LeakyReLU(0.2)(bn1) + pool1 = MaxPooling3D(pool_size = (1, 2, 2), padding = 'same')(bn1) + noise1 = self.get_noise(pool1) + + conv2 = Conv3D(ndf * 2, 4, strides = (1, 2, 2), padding = 'same', kernel_initializer = 'he_normal')(noise1) + bn2 = BatchNormalization()(conv2) + bn2 = LeakyReLU(0.2)(bn2) + pool2 = MaxPooling3D(pool_size = (1, 2, 2), padding = 'same')(bn2) + noise2 = self.get_noise(pool2) + + conv3 = Conv3D(ndf * 4, 4, strides = (1, 2, 2), padding = 'same', kernel_initializer = 'he_normal')(noise2) + bn3 = BatchNormalization()(conv3) + bn3 = LeakyReLU(0.2)(bn3) + pool3 = MaxPooling3D(pool_size = (1, 2, 2), padding = 'same')(bn3) + + conv4 = Conv3D(1, 4, 1, padding = 'same')(pool3) + + fl = tf.reshape(conv4, [-1, 1]) + drop1 = Dropout(0.3)(fl) + fc1 = Dense(1024, activation = 'relu')(drop1) + drop2 = Dropout(0.3)(fc1) + fc2 = Dense(512, activation = 'relu')(drop2) + out_logit = Dense(1, activation = 'linear')(fc2) + out = tf.nn.sigmoid(out_logit) + return out, out_logit + + + @staticmethod + def lrelu(x, leak=0.2, name='lrelu'): + return tf.maximum(x, leak * x) + + + @staticmethod + def linear(input_, output_size, scope=None, stddev=0.02, bias_start=0.0, with_w=False): + shape = input_.get_shape().as_list() + + with tf.variable_scope(scope or "Linear"): + matrix = tf.get_variable("Matrix", [shape[1], output_size], tf.float32, + tf.random_normal_initializer(stddev = stddev)) + bias = tf.get_variable("bias", [output_size], + initializer = tf.constant_initializer(bias_start)) + if with_w: + return tf.matmul(input_, matrix) + bias, matrix, bias else: - x_1_g, hidden_g = network_template(x_1_g, hidden_g) - x_hat.append(x_1_g) + return tf.matmul(input_, matrix) + bias + +class batch_norm(object): + def __init__(self, epsilon=1e-5, momentum = 0.9, name="batch_norm"): + with tf.variable_scope(name): + self.epsilon = epsilon + self.momentum = momentum + self.name = name + + def __call__(self, x, train=True): + return tf.contrib.layers.batch_norm(x, + decay=self.momentum, + updates_collections=None, + epsilon=self.epsilon, + scale=True, + is_training=train, + scope=self.name) - # pack them all together - x_hat = tf.stack(x_hat) - self.x_hat= tf.transpose(x_hat, [1, 0, 2, 3, 4]) # change first dim with sec dim ???? yan: why? - print('self.x_hat shape is: ',self.x_hat.shape) - return self.x_hat - - - diff --git a/video_prediction_tools/model_modules/video_prediction/models/linear_regression_model.py b/video_prediction_tools/model_modules/video_prediction/models/linear_regression_model.py new file mode 100644 index 0000000000000000000000000000000000000000..296ae16a1f70fcd5047481ab46cdd6e14abe1e83 --- /dev/null +++ b/video_prediction_tools/model_modules/video_prediction/models/linear_regression_model.py @@ -0,0 +1,88 @@ +# SPDX-FileCopyrightText: 2018, alexlee-gk +# +# SPDX-License-Identifier: MIT + +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong" +__date__ = "2022-04-13" + +from .our_base_model import BaseModels +import tensorflow as tf + + +class VanillaConvLstmVideoPredictionModel(BaseModels): + + def __init__(self, hparams_dict=None, **kwargs): + """ + This is class for building convLSTM architecture by using updated hparameters + args: + hparams_dict : dict, the dictionary contains the hparaemters names and values + """ + super().__init__(hparams_dict) + self.get_hparams() + + def get_hparams(self): + """ + obtain the hparams from the dict to the class variables + """ + method = BaseModels.get_hparams.__name__ + + try: + self.context_frames = self.hparams.context_frames + self.sequence_length = self.hparams.sequence_length + self.max_epochs = self.hparams.max_epochs + self.batch_size = self.hparams.batch_size + self.shuffle_on_val = self.hparams.shuffle_on_val + self.opt_var = self.hparams.opt_var + self.learning_rate = self.hparams.lr + + print("The model hparams have been parsed successfully! ") + except Exception as error: + print("Method %{}: error: {}".format(method, error)) + raise("Method %{}: the hparameter dictionary must include the params defined above!".format(method)) + + def build_graph(self, x: tf.Tensor): + + self.is_build_graph = False + self.inputs = x + self.global_step = tf.train.get_or_create_global_step() + original_global_variables = tf.global_variables() + + self.build_model() + + + # This is the loss function (MSE): + # Optimize all target variables/channels + if self.opt_var == "all": + x = self.inputs[:, self.context_frames:, :, :, :] + x_hat = self.x_hat_predict_frames[:, :, :, :, :] + print("The model is optimzied on all the variables in the loss function") + elif self.opt_var != "all" and isinstance(self.opt_var, str): + self.opt_var = int(self.opt_var) + print("The model is optimized on the {} variable in the loss function".format(self.opt_var)) + x = self.inputs[:, self.context_frames:, :, :, self.opt_var] + x_hat = self.x_hat_predict_frames[:, :, :, :, self.opt_var] + else: + raise ValueError( + "The opt var in the hyper-parameters setup should be '0','1','2' indicate the index of target variable to be optimised or 'all' indicating optimize all the variables") + + #loss function is mean squre error + self.total_loss = tf.reduce_mean(tf.square(x - x_hat)) + + self.train_op = tf.train.AdamOptimizer( + learning_rate = self.learning_rate).minimize(self.total_loss, global_step = self.global_step) + + self.outputs["gen_images"] = self.x_hat + + # Summary op + self.loss_summary = tf.summary.scalar("total_loss", self.total_loss) + self.summary_op = tf.summary.merge_all() + global_variables = [var for var in tf.global_variables() if var not in original_global_variables] + self.saveable_variables = [self.global_step] + global_variables + self.is_build_graph = True + return self.is_build_graph + + + + def build_model(self): + pass diff --git a/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py b/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py new file mode 100644 index 0000000000000000000000000000000000000000..3219bda2de305cf36ad178ebcc192ac9a5a37b78 --- /dev/null +++ b/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py @@ -0,0 +1,82 @@ +# SPDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) +# +# SPDX-License-Identifier: MIT + +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong" +__date__ = "2022-04-13" + +from hparams_utils import * +import json +from abc import ABC, abstractmethod +import tensorflow as tf + + +class BaseModels(ABC): + + def __init__(self, hparams_dict_config=None): + self.hparams_dict_config = hparams_dict_config + self.hparams_dict = self.get_model_hparams_dict() + self.hparams = self.parse_hparams() + # Attributes set during runtime + self.total_loss = None + self.loss_summary = None + self.total_loss = None + self.outputs = {} + self.train_op = None + self.summary_op = None + self.inputs = None + self.global_step = None + self.saveable_variables = None + self.is_build_graph = None + self.x_hat = None + self.x_hat_predict_frames = None + + + def get_model_hparams_dict(self): + """ + Get model_hparams_dict from json file + """ + if self.hparams_dict_config: + with open(self.hparams_dict_config, 'r') as f: + hparams_dict = json.loads(f.read()) + else: + raise FileNotFoundError("hyper-parameter directory doesn't exist! please check {}!".format(self.hparams_dict_config)) + + return hparams_dict + + def parse_hparams(self): + """ + Obtain the parameters from directory + """ + + hparams = dotdict(self.hparams_dict) + return hparams + + @abstractmethod + def get_hparams(self): + """ + obtain the hparams from the dict to the class variables + """ + method = BaseModels.get_hparams.__name__ + + try: + self.context_frames = self.hparams.context_frames + self.max_epochs = self.hparams.max_epochs + self.batch_size = self.hparams.batch_size + self.shuffle_on_val = self.hparams.shuffle_on_val + self.loss_fun = self.hparams.loss_fun + + except Exception as error: + print("Method %{}: error: {}".format(method,error)) + raise("Method %{}: the hparameter dictionary must include " + "'context_frames','max_epochs','batch_size','shuffle_on_val' 'loss_fun'".format(method)) + + @abstractmethod + def build_graph(self, x: tf.Tensor): + pass + + + @abstractmethod + def build_model(self): + pass diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py index bd3795ebecbf750bd115a430ee1adce64074708a..dc58cba3638309cd6b7054b91e45363f2aa42fce 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py @@ -10,97 +10,61 @@ from model_modules.video_prediction.models.model_helpers import set_and_check_pr import tensorflow as tf from model_modules.video_prediction.layers import layer_def as ld from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell -from tensorflow.contrib.training import HParams +from .our_base_model import BaseModels - - -class VanillaConvLstmVideoPredictionModel(object): +class VanillaConvLstmVideoPredictionModel(BaseModels): def __init__(self, hparams_dict=None, **kwargs): """ This is class for building convLSTM architecture by using updated hparameters args: hparams_dict : dict, the dictionary contains the hparaemters names and values """ - self.hparams_dict = hparams_dict - self.hparams = self.parse_hparams() - self.learning_rate = self.hparams.lr - self.context_frames = self.hparams.context_frames - self.sequence_length = self.hparams.sequence_length - self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) - self.max_epochs = self.hparams.max_epochs - self.loss_fun = self.hparams.loss_fun - self.opt_var = self.hparams.opt_var - # Attributes set during runtime - self.loss_summary = None - self.total_loss = None - self.outputs = {} - self.train_op = None - self.summary_op = None - self.x = None - self.inputs = None - self.global_step = None - self.saveable_variables = None - self.is_build_graph = None - self.x_hat = None - self.x_hat_predict_frames = None - - - def get_default_hparams(self): - return HParams(**self.get_default_hparams_dict()) - - def parse_hparams(self): - """ - Parse the hparams setting to ovoerride the default ones - """ - - parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) - return parsed_hparams + super().__init__(hparams_dict) + self.get_hparams() - def get_default_hparams_dict(self): + def get_hparams(self): """ - The function that contains default hparams - Returns: - A dict with the following hyperparameters. - context_frames : the number of ground-truth frames to pass in at start. - sequence_length : the number of frames in the video sequence - max_epochs : the number of epochs to train model - lr : learning rate - loss_fun : the loss function - opt_var : the target vars/channel to be optimize, string: "0","1",..."n", or "all", if "all" means optimize all the variables/channels + obtain the hparams from the dict to the class variables """ - hparams = dict( - context_frames=10, - sequence_length=20, - max_epochs=20, - batch_size=40, - lr=0.001, - loss_fun="cross_entropy", - shuffle_on_val=True, - opt_var="0", - ) - return hparams + method = BaseModels.get_hparams.__name__ + + try: + self.context_frames = self.hparams.context_frames + self.sequence_length = self.hparams.sequence_length + self.max_epochs = self.hparams.max_epochs + self.batch_size = self.hparams.batch_size + self.shuffle_on_val = self.hparams.shuffle_on_val + self.loss_fun = self.hparams.loss_fun + self.opt_var = self.hparams.opt_var + self.learning_rate = self.hparams.lr + self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) + print("The model hparams have been parsed successfully! ") + except Exception as error: + print("Method %{}: error: {}".format(method, error)) + raise("Method %{}: the hparameter dictionary must include " + "'context_frames','max_epochs','batch_size','shuffle_on_val' 'loss_fun'," + "'opt_var', 'lr', 'opt_var'".format(method)) def build_graph(self, x): self.is_build_graph = False self.inputs = x - self.x = x["images"] self.global_step = tf.train.get_or_create_global_step() original_global_variables = tf.global_variables() - self.convLSTM_network() + self.build_model() #This is the loss function (MSE): #Optimize all target variables/channels if self.opt_var == "all": - x = self.x[:, self.context_frames:, :, :, :] + x = self.inputs[:, self.context_frames:, :, :, :] x_hat = self.x_hat_predict_frames[:, :, :, :, :] print ("The model is optimzied on all the variables in the loss function") elif self.opt_var != "all" and isinstance(self.opt_var, str): self.opt_var = int(self.opt_var) print ("The model is optimized on the {} variable in the loss function".format(self.opt_var)) - x = self.x[:, self.context_frames:, :, :, self.opt_var] + x = self.inputs[:, self.context_frames:, :, :, self.opt_var] x_hat = self.x_hat_predict_frames[:, :, :, :, self.opt_var] else: raise ValueError("The opt var in the hyperparameters setup should be '0','1','2' indicate the index of target variable to be optimised or 'all' indicating optimize all the variables") @@ -129,9 +93,9 @@ class VanillaConvLstmVideoPredictionModel(object): global_variables = [var for var in tf.global_variables() if var not in original_global_variables] self.saveable_variables = [self.global_step] + global_variables self.is_build_graph = True - return self.is_build_graph + return self.is_build_graph - def convLSTM_network(self): + def build_model(self): network_template = tf.make_template('network', VanillaConvLstmVideoPredictionModel.convLSTM_cell) # make the template to share the variables # create network @@ -141,7 +105,7 @@ class VanillaConvLstmVideoPredictionModel(object): hidden_g = None for i in range(self.sequence_length-1): if i < self.context_frames: - x_1_g, hidden_g = network_template(self.x[:, i, :, :, :], hidden_g) + x_1_g, hidden_g = network_template(self.inputs[:, i, :, :, :], hidden_g) else: x_1_g, hidden_g = network_template(x_1_g, hidden_g) x_hat.append(x_1_g) diff --git a/video_prediction_tools/no_HPC_scripts/data_extraction_era5_template.sh b/video_prediction_tools/no_HPC_scripts/data_extraction_era5_template.sh deleted file mode 100644 index 4e3c4d3a96f4ffbe1f0c238dc650eba7760151ff..0000000000000000000000000000000000000000 --- a/video_prediction_tools/no_HPC_scripts/data_extraction_era5_template.sh +++ /dev/null @@ -1,33 +0,0 @@ -#!/bin/bash -x - -######### Template identifier (don't remove) ######### -echo "Do not run the template scripts" -exit 99 -######### Template identifier (don't remove) ######### - -# Name of virtual environment -VIRT_ENV_NAME=venv_test - -echo "Activating virtual environment..." -source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate - -# Declare path-variables (dest_dir will be set and configured automatically via generate_runscript.py) -source_dir=/my/path/to/era5 -destination_dir=/my/path/to/extracted/data -varmap_file=/my/path/to/varmapping/file - -years=( "2007" ) - -#The number of nodes should be equal to the number of 1 preprocessed folder plus 1 -n_nodes=3 - -# Run data extraction -for year in "${years[@]}"; do - echo "Perform ERA5-data extraction for year ${year}" - python ../main_scripts/main_data_extraction.py --source_dir ${source_dir} --target_dir ${destination_dir} \ - --year ${year} --varslist_path ${varmap_file} -done - - - - diff --git a/video_prediction_tools/no_HPC_scripts/era5_data_extraction_template.sh b/video_prediction_tools/no_HPC_scripts/era5_data_extraction_template.sh new file mode 100644 index 0000000000000000000000000000000000000000..c0045040c2825f3a954e938179a799b7514bf440 --- /dev/null +++ b/video_prediction_tools/no_HPC_scripts/era5_data_extraction_template.sh @@ -0,0 +1,45 @@ +#!/bin/bash -x + +######### Template identifier (don't remove) ######### +echo "Do not run the template scripts" +exit 99 +######### Template identifier (don't remove) ######### + +# Name of virtual environment +VIRT_ENV_NAME=venv_test + +if [ -z "${VIRTUAL_ENV}" ]; then + if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then + echo "Activating virtual environment..." + source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate + else + echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." + exit 1 + fi +fi + +# select years and variables for dataset and define target domain +years=( 2017 ) +months=( "all" ) +var_dict='{"2t": {"sf": ""}, "tcc": {"sf": ""}, "t": {"ml": "p85000."}}' +sw_corner=(38.4 0.0) +nyx=(56 92) + +# set some paths +# note, that destination_dir is adjusted during runtime based on the data +source_dir=/my/path/to/era5/data +destination_dir=/my/path/to/extracted/data + +# Must be at least 2 +n_nodes=3 + +# execute Python-script +mpirun -n ${n_nodes} python3 ../main_scripts/main_era5_data_extraction.py -src_dir "${source_dir}" \ + -dest_dir "${destination_dir}" -y "${years[@]}" -m "${months[@]}" \ + -swc "${sw_corner[@]}" -nyx "${nyx[@]}" -v "${var_dict}" + + + + + + diff --git a/video_prediction_tools/no_HPC_scripts/preprocess_data_era5_step1_template.sh b/video_prediction_tools/no_HPC_scripts/preprocess_data_era5_step1_template.sh deleted file mode 100644 index 8519569a32d6a86c98363c85eb707ffa75e4960f..0000000000000000000000000000000000000000 --- a/video_prediction_tools/no_HPC_scripts/preprocess_data_era5_step1_template.sh +++ /dev/null @@ -1,44 +0,0 @@ -#!/bin/bash -x - -######### Template identifier (don't remove) ######### -echo "Do not run the template scripts" -exit 99 -######### Template identifier (don't remove) ######### - -# Name of virtual environment -VIRT_ENV_NAME=venv_test - -if [ -z ${VIRTUAL_ENV} ]; then - if [[ -f ../virtual_envs/${VIRT_ENV_NAME}/bin/activate ]]; then - echo "Activating virtual environment..." - source ../virtual_envs/${VIRT_ENV_NAME}/bin/activate - else - echo "ERROR: Requested virtual environment ${VIRT_ENV_NAME} not found..." - exit 1 - fi -fi - -#select years and variables for dataset and define target domain -years=( "2007" ) -variables=( "2t" ) -sw_corner=( 10 20) -nyx=( 40 40 ) - -#your source dir and target dir -source_dir=/home/b.gong/data_era5 -destination_dir=/home/b.gong/preprocessed_data - -#The number of nodes should be equal to the number of 1 preprocessed folders plus 1 -n_nodes=3 - -for year in "${years[@]}"; do - echo "start preprocessing data for year ${year}" - mpirun -n ${n_nodes} python ../main_scripts/main_preprocess_data_step1.py \ - --source_dir ${source_dir} --destination_dir ${destination_dir} --years "${year}" \ - --vars "${variables[0]}" \ - --sw_corner "${sw_corner[0]}" "${sw_corner[1]}" --nyx "${nyx[0]}" "${nyx[1]}" -done - - - - diff --git a/video_prediction_tools/utils/dataset_utils.py b/video_prediction_tools/utils/dataset_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..f45674a4380fe6fc4ce4bca4e941c77d92d33fa5 --- /dev/null +++ b/video_prediction_tools/utils/dataset_utils.py @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: 2022 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) +# +# SPDX-License-Identifier: MIT + +""" +functions providing info about available options +Provides: * DATASET_META_LOCATION + * DATASETS + * get_dataset_info +""" + +import json +from pathlib import Path +from typing import Dict, Any, List +#from dataclasses import dataclass #TODO use dataclass in python 3.7+ + +DATASET_META_LOCATION = Path(__file__).parent.parent / "data_split" +DATASETS = [path.name for path in DATASET_META_LOCATION.iterdir() if path.is_dir()] + +DATE_TEMPLATE = "{year}-{month:02d}" + +def get_filename_template(name: str) -> str: + return f"{name}_{DATE_TEMPLATE}.nc" + +def get_dataset_info(name: str) -> Dict[str,Any]: + """Extract metainformation about dataset from corresponding JSON file.""" + file = DATASET_META_LOCATION / f"{name}/{name}.json" + try: + with open(file, "r") as f: + return json.load(f) # TODO: input validation => specify schema + except FileNotFoundError as e: + raise ValueError("Information on dataset '{dataset}' doesnt exist.") diff --git a/video_prediction_tools/utils/general_utils.py b/video_prediction_tools/utils/general_utils.py index f4031cee83fc27bd14beaedd455f32e9b86a42fd..bf3a47c42ada09e89a18e1d00eb0004434973375 100644 --- a/video_prediction_tools/utils/general_utils.py +++ b/video_prediction_tools/utils/general_utils.py @@ -6,6 +6,7 @@ Some auxilary routines which may are used throughout the project. Provides: * get_unique_vars * add_str_to_path + * get_path_component * is_integer * ensure_list * isw @@ -60,6 +61,23 @@ def add_str_to_path(path_in: str, add_str: str): return line_str +def get_path_component(path: str, ind: int): + """ + Get the ind-component of path, e.g. get_path_component("/my/dir/is", 1) yields "dir". + :param path: the full path + :param ind: the index of the component to retrieve + :return: + """ + method = get_path_component.__name__ + + assert isinstance(path, str), "%{0}: Passed path must be string, but is of type '{1}'".format(method, type(path)) + assert isinstance(ind, int), "%{0}: Passed ind must be an integer, but is of type '{1}'".format(method, type(ind)) + + path_comps = os.path.normpath(path).split(os.path.sep)[1:] + + return path_comps[ind] + + def is_integer(n): """ :param n: input string @@ -259,3 +277,136 @@ def provide_default(dict_in: dict, keyname: str, default=None, required: bool = return dict_in[keyname] +def depth2intensity(depth, interval=600): + """ + Function for convertion rainfall depth (in mm) to + rainfall intensity (mm/h) + + Args: + depth: float + float or array of float + rainfall depth (mm) + + interval : number + time interval (in sec) which is correspondend to depth values + + Returns: + intensity: float + float or array of float + rainfall intensity (mm/h) + """ + return depth * 3600 / interval + + +def intensity2depth(intensity, interval=600): + """ + Function for convertion rainfall intensity (mm/h) to + rainfall depth (in mm) + + Args: + intensity: float + float or array of float + rainfall intensity (mm/h) + + interval : number + time interval (in sec) which is correspondend to depth values + + Returns: + depth: float + float or array of float + rainfall depth (mm) + """ + return intensity * interval / 3600 + + +def RYScaler(X_mm): + ''' + Scale RY data from mm (in float64) to brightness (in uint8). + + Args: + X (numpy.ndarray): RY radar image + + Returns: + numpy.ndarray(uint8): brightness integer values from 0 to 255 + for corresponding input rainfall intensity + float: c1, scaling coefficient + float: c2, scaling coefficient + + ''' + def mmh2rfl(r, a=256., b=1.42): + ''' + .. based on wradlib.zr.r2z function + + .. r --> z + ''' + return a * r ** b + + def rfl2dbz(z): + ''' + .. based on wradlib.trafo.decibel function + + .. z --> d + ''' + return 10. * np.log10(z) + + # mm to mm/h + X_mmh = depth2intensity(X_mm) + # mm/h to reflectivity + X_rfl = mmh2rfl(X_mmh) + # remove zero reflectivity + # then log10(0.1) = -1 not inf (numpy warning arised) + X_rfl[X_rfl == 0] = 0.1 + # reflectivity to dBz + X_dbz = rfl2dbz(X_rfl) + # remove all -inf + X_dbz[X_dbz < 0] = 0 + + # MinMaxScaling + c1 = X_dbz.min() + c2 = X_dbz.max() + + return ((X_dbz - c1) / (c2 - c1) * 255).astype(np.uint8), c1, c2 + + +def inv_RYScaler(X_scl, c1, c2): + ''' + Transfer brightness (in uint8) to RY data (in mm). + Function which is inverse to Scaler() function. + + Args: + X_scl (numpy.ndarray): array of brightness integers obtained + from Scaler() function. + c1: first scaling coefficient obtained from Scaler() function. + c2: second scaling coefficient obtained from Scaler() function. + + Returns: + numpy.ndarray(float): RY radar image + + ''' + def dbz2rfl(d): + ''' + .. based on wradlib.trafo.idecibel function + + .. d --> z + ''' + return 10. ** (d / 10.) + + def rfl2mmh(z, a=256., b=1.42): + ''' + .. based on wradlib.zr.z2r function + + .. z --> r + ''' + return (z / a) ** (1. / b) + + # decibels to reflectivity + X_rfl = dbz2rfl((X_scl / 255)*(c2 - c1) + c1) + # 0 dBz are 0 reflectivity, not 1 + X_rfl[X_rfl == 1] = 0 + # reflectivity to rainfall in mm/h + X_mmh = rfl2mmh(X_rfl) + # intensity in mm/h to depth in mm + X_mm = intensity2depth(X_mmh) + + return X_mm + diff --git a/video_prediction_tools/utils/hparams_utils.py b/video_prediction_tools/utils/hparams_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..54663942bdf9372874be5d791b528018b1d34ce5 --- /dev/null +++ b/video_prediction_tools/utils/hparams_utils.py @@ -0,0 +1,41 @@ +#PDX-FileCopyrightText: 2021 Earth System Data Exploration (ESDE), Jülich Supercomputing Center (JSC) +# +# SPDX-License-Identifier: MIT +__email__ = "b.gong@fz-juelich.de" +__author__ = "Bing Gong" +__date__ = "2022-03-17" + + + +class dotdict(dict): + """dot.notation access to dictionary attributes""" + __getattr__ = dict.get + __setattr__ = dict.__setitem__ + __delattr__ = dict.__delitem__ + + +#auxiliary functions used for parsing the hyerparameters from hparams_dict +def reduce_dict(dict_in: dict, dict_ref: dict): + """ + Reduces input dictionary to keys from reference dictionary. If the input dictionary lacks some keys, these are + copied over from the reference dictionary, i.e. the reference dictionary provides the defaults + :param dict_in: input dictionary + :param dict_ref: reference dictionary + :return: reduced form of input dictionary (with keys complemented from dict_ref if necessary) + """ + method = reduce_dict.__name__ + + # sanity checks + assert isinstance(dict_in, dict), "%{0}: dict_in must be a dictionary, but is of type {1}"\ + .format(method, type(dict_in)) + assert isinstance(dict_ref, dict), "%{0}: dict_ref must be a dictionary, but is of type {1}"\ + .format(method, type(dict_ref)) + + dict_merged = {**dict_ref, **dict_in} + dict_reduced = {key: dict_merged[key] for key in dict_ref} + + return dict_reduced + + + + diff --git a/video_prediction_tools/utils/normalization.py b/video_prediction_tools/utils/normalization.py index 250b93056aec8ad3154bcaf59b5379cf75b2b7b1..f9d9f362ddc24c47b53d8a3d6d34b156dcc8b5b3 100644 --- a/video_prediction_tools/utils/normalization.py +++ b/video_prediction_tools/utils/normalization.py @@ -17,6 +17,8 @@ class Norm_data: known_norms = {} known_norms["minmax"] = ["min", "max"] known_norms["znorm"] = ["avg", "sigma"] + known_norms["cbnorm"] = [] + known_norms["lognorm"] = [] def __init__(self, varnames): """Initialize the instance by setting the variable names to be handled and the status (for sanity checks only) as attributes.""" @@ -76,6 +78,11 @@ class Norm_data: getattr(self, varname + "max") - getattr(self, varname + "min"))) elif norm == "znorm": return ((data[...] - getattr(self, varname + "avg")) / getattr(self, varname + "sigma") ** 2) + elif norm == "cbnorm": + return data[...] + # return (data[...] ** (1./3)) + elif norm == "lognorm": + return (np.log(data[...]+0.001)-np.log(0.001)) def denorm_var(self, data, varname, norm): """ @@ -97,4 +104,12 @@ class Norm_data: return (data[...] * (getattr(self, varname + "max") - getattr(self, varname + "min")) + getattr(self, varname + "min")) elif norm == "znorm": - return (data[...] * getattr(self, varname + "sigma") ** 2 + getattr(self, varname + "avg")) \ No newline at end of file + return (data[...] * getattr(self, varname + "sigma") ** 2 + getattr(self, varname + "avg")) + + elif norm == "cbnorm": + return data[...] + # return (data[...] ** 3) + + elif norm == "lognorm": + return (np.round(np.exp(data[...]+np.log(0.001))-0.001,4)) + diff --git a/video_prediction_tools/utils/pystager_utils.py b/video_prediction_tools/utils/pystager_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..239c82f128993171b8ba2bc0b4a09c773273b736 --- /dev/null +++ b/video_prediction_tools/utils/pystager_utils.py @@ -0,0 +1,491 @@ +# ********** Info ********** +# @Creation: 2021-07-25 +# @Update: 2021-07-27 +# @Author: Michael Langguth, based on work by Amirpasha Mozaffari +# @Site: Juelich supercomputing Centre (JSC) @ FZJ +# @File: pystager_utils.py +# ********** Info ********** + +import sys, os +# The script must be executed with the mpi4py-module to ensure that the job gets aborted when an error is risen +# see https://mpi4py.readthedocs.io/en/stable/mpi4py.run.html +if "mpi4py" not in sys.modules: + raise ModuleNotFoundError("Python-script must be called with the mpi4py module, i.e. 'python -m mpi4py <...>.py") +import multiprocessing +import subprocess +import inspect +from typing import Union, List +from mpi4py import MPI +import logging +import numpy as np +import pandas as pd +import datetime as dt +import platform + + +class Distributor(object): + """ + Class for defining (customized) distributors. The distributor selected by the distributor_engine must provide + the dynamical arguments for a parallelized job run by PyStager (see below) which inherits from this class. + """ + class_name = "Distributor" + + def __init__(self, distributor_name): + self.distributor_name = distributor_name + + def distributor_engine(self, distributor_name: str): + """ + Sets up distributor for organinzing parallelization. + :param distributor_name: Name of distributor + :return distributor: selected callable distributor + """ + method = "{0}->{1}".format(Distributor.class_name, Distributor.distributor_engine.__name__) + + if distributor_name.lower() == "date": + distributor = self.distributor_date + elif distributor_name.lower() == "year_month_list": + distributor = self.distributor_year_month + else: + raise ValueError("%{0}: The distributor named {1} is not implemented yet.".format(method, distributor_name)) + + return distributor + + def distributor_date(self, date_start: dt.datetime, date_end: dt.datetime, freq: str = "1D"): + """ + Creates a transfer dictionary whose elements are lists for individual start and end dates for each processor + param date_start: first date to convert + param date_end: last date to convert + return: transfer_dictionary allowing date-based parallelization + """ + method = "{0}->{1}".format(Distributor.class_name, Distributor.distributor_date.__name__) + + # sanity checks + if not (isinstance(date_start, dt.datetime) and isinstance(date_end, dt.datetime)): + raise ValueError("%{0}: date_start and date_end have to datetime objects!".format(method)) + + if not (date_start.strftime("%H") == "00" and date_end.strftime("%H") == "00"): + raise ValueError("%{0}: date_start and date_end must be valid at 00 UTC.".format(method)) + + if not int((date_end - date_start).days) >= 1: + raise ValueError("%{0}: date_end must be at least one day after date_start.".format(method)) + + if not hasattr(self, "num_processes"): + print("%{0}: WARNING: Attribute num_processes is not set and thus no parallelization will take place.") + num_processes = 2 + else: + num_processes = self.num_processes + + # init transfer dictionary (function create_transfer_dict_from_list does not work here since transfer dictionary + # consists of tuple with start and end-date instead of a number of elements) + transfer_dict = dict.fromkeys(list(range(1, num_processes))) + + dates_req_all = pd.date_range(date_start, date_end, freq=freq) + ndates = len(dates_req_all) + ndates_per_node = int(np.ceil(float(ndates)/(num_processes-1))) + + for node in np.arange(num_processes): + ind_max = np.minimum((node+1)*ndates_per_node-1, ndates-1) + transfer_dict[node+1] = [dates_req_all[node*ndates_per_node], + dates_req_all[ind_max]] + if ndates-1 == ind_max: + break + + return transfer_dict + + def distributor_year_month(self, years, months): + + method = "{0}->{1}".format(Distributor.class_name, Distributor.distributor_year_month.__name__) + + list_or_int = Union[List, int] + + assert isinstance(years, list_or_int.__args__), "%{0}: Input years must be list of years or an integer."\ + .format(method) + assert isinstance(months, list_or_int.__args__), "%{0}: Input months must be list of months or an integer."\ + .format(method) + + if not hasattr(self, "num_processes"): + print("%{0}: WARNING: Attribute num_processes is not set and thus no parallelization will take place.") + num_processes = 2 + else: + num_processes = self.num_processes + + if isinstance(years, int): years = [years] + if isinstance(months, int): months = [months] + + all_years_months = [] + for year in years: + for month in months: + all_years_months.append(dt.datetime.strptime("{0:d}-{1:02d}".format(int(year), int(month)), "%Y-%m")) + + transfer_dict = Distributor.create_transfer_dict_from_list(all_years_months, num_processes) + + return transfer_dict + + @staticmethod + def create_transfer_dict_from_list(in_list: List, num_procs: int): + """ + Splits up list to transfer dictionary for PyStager. + :param in_list: list of elements that can be distributed to workers + :param num_procs: Number of workers that are avaialable to operate on elements of list + :return: transfer dictionary for PyStager + """ + + method = Distributor.create_transfer_dict_from_list.__name__ + + assert isinstance(in_list, list), "%{0} Input argument in_list must be a list, but is of type '{1}'."\ + .format(method, type(in_list)) + + assert int(num_procs) >= 2, "%{0}: Number of processes must be at least two.".format(method) + + nelements = len(in_list) + nelements_per_node = int(np.ceil(float(nelements)/(num_procs-1))) + + transfer_dict = dict.fromkeys(list(range(1, num_procs))) + + for i, element in enumerate(in_list): + ind = i % (num_procs-1) + 1 + if i < num_procs -1: + transfer_dict[ind] = [element] + else: + transfer_dict[ind].append(element) + + print("%{0}: Generated tarnsfer_dict:".format(method)) + print(transfer_dict) + + return transfer_dict + + +class PyStager(Distributor): + """ + Organizes parallelized execution of a job. + The job must be wrapped into a function-object that can be fed with dynamical arguments provided by an available + distributor and static arguments (see below). + Running PyStager constitutes a three-step approach. First PyStager must be instanciated, then it must be set-up by + calling the setup-method and finally, the job gets executed in a parallelized manner. + Example: Let the function 'process_data' be capable to process hourly data files between date_start and date_end. + Thus, parallelization can be organized with distributor_date which only must be fed with a start and end + date (the freq-argument is optional and defaults to "1D" -> daily frequency (see pandas)). + With the data being stored under <data_dir>, PyStager can be called in a Python-script by: + pystager_obj = PyStager(process_data, "date") + pystager_obj.setup(<start_datetime_obj>, <end_datetime_obj>, freq="1H") + pystager_obj.run(<static_arguments>) + By splitting up the setup-method from the execution, multiple job executions becomes possible. + """ + + class_name = "PyStager" + + def __init__(self, job_func: callable, distributor_name: str, nmax_warn: int = 3, logdir: str = None): + """ + Initialize PyStager. + :param job_func: Function whose execution is meant to be parallelized. This function must accept arguments + dynamical arguments provided by the distributor (see distributo_engine-method) and + static arguments (see run-method) in the order mentioned here. Additionally, it must accept + a logger instance. The argument 'nmax_warn' is optional. + :param distributor_name: Name of distributor which takes care for the paralelization (see distributo_engine + -method) + :param nmax_warn: Maximal number of accepted warnings during job execution (default: 3) + :param logdir: directory where logfile are stored (current working directory becomes the default if not set) + """ + super().__init__(distributor_name) + method = PyStager.__init__.__name__ + + self.cpu_name = platform.processor() + self.num_cpus_max = multiprocessing.cpu_count() + self.distributor = self.distributor_engine(distributor_name) + self.logdir = PyStager.set_and_check_logdir(logdir, distributor_name) + self.nmax_warn = int(nmax_warn) + self.job = job_func + self.transfer_dict = None + self.comm = MPI.COMM_WORLD + self.my_rank = self.comm.Get_rank() + self.num_processes = self.comm.Get_size() + + if not callable(self.job): + raise ValueError("%{0}: Passed function to be parallelized must be a callable function for {1}." + .format(method, PyStager.class_name)) + + if self.nmax_warn <= 0: + raise ValueError("%{0}: nmax_warn must be larger than zero for {1}, but is set to {2:d}" + .format(method, PyStager.class_name, self.nmax_warn)) + + if self.num_processes < 2: + raise ValueError("%{0}: Number of assigned MPI processes must be at least two for {1}." + .format(method, PyStager.class_name)) + + def setup(self, *args): + """ + Simply passes arguments to initialized distributor. + *args : Tuple of arguments suitable for distributor (self.distributor) + """ + method = PyStager.setup.__name__ + + if self.my_rank == 0: + try: + self.transfer_dict = self.distributor(*args) + except Exception as err: + print("%{0}: Failed to set up transfer dictionary of PyStager (see raised error below)".format(method)) + raise err + else: + pass + + # def run(self, data_dir, *args, job_name="dummy"): + def run(self, *args, job_name="dummy"): + """ + Run PyStager. + """ + method = "{0}->{1}".format(PyStager.class_name, PyStager.run.__name__) + + if self.my_rank == 0 and self.transfer_dict is None: + raise AttributeError("%{0}: transfer_dict is still None. Call setup beforehand!".format(method)) + + # if not os.path.isdir(data_dir): + # raise NotADirectoryError("%{0}: The passed data directory '{1}' does not exist.".format(method, data_dir)) + + if self.my_rank == 0: + logger_main = os.path.join(self.logdir, "{0}_job_main.log".format(job_name)) + if os.path.exists(logger_main): + print("%{0}: Main logger file '{1}' already existed and was deleted.".format(method, logger_main)) + os.remove(logger_main) + + logging.basicConfig(filename=logger_main, level=logging.DEBUG, + format="%(asctime)s:%(levelname)s:%(message)s") + logger = logging.getLogger(__file__) + logger.addHandler(logging.StreamHandler(sys.stdout)) + + logger.info("PyStager is started at {0}".format(dt.datetime.now().strftime("%Y-%m%-d %H:%M:%S UTC"))) + + # distribute work to worker processes + for proc in range(1, self.num_processes): + broadcast_list = self.transfer_dict[proc] + self.comm.send(broadcast_list, dest=proc) + + stat_mpi = self.manage_recv_mess(logger) + + if stat_mpi: + logger.info("Job has been executed successfully on {0:d} worker processes. Job exists normally at {1}" + .format(self.num_processes, dt.datetime.now().strftime("%Y-%m%-d %H:%M:%S UTC"))) + else: + # worker logger file + logger_worker = os.path.join(self.logdir, "{0}_job_worker_{1}.log".format(job_name, self.my_rank)) + if os.path.exists(logger_worker): + os.remove(logger_worker) + + logging.basicConfig(filename=logger_worker, level=logging.DEBUG, + format='%(asctime)s:%(levelname)s:%(message)s') + logger = logging.getLogger(__file__) + logger.addHandler(logging.StreamHandler(sys.stdout)) + logger.info("==============Worker logger is activated ==============") + logger.info("Start receiving message from master...") + + stat_worker = self.manage_worker_jobs(logger, *args) + + MPI.Finalize() + + def manage_recv_mess(self, logger): + """ + Manages received messages from worker processes. Also accumulates warnings and aborts job if maximum number is + exceeded + :param logger: logger instance to add logs according to received message from worker + :return stat: True if ok, else False + """ + method = "{0}->{1}".format(PyStager.class_name, PyStager.manage_recv_mess.__name__) + + assert isinstance(self.comm, MPI.Intracomm), "%{0}: comm must be a MPI Intracomm-instance, but is type '{1}'"\ + .format(method, type(self.comm)) + + assert isinstance(logger, logging.Logger), "%{0}: logger must be a Logger-instance, but is of type '{1}'"\ + .format(method, type(logger)) + + message_counter = 1 + warn_counter = 0 + lexit = False + while message_counter < self.num_processes: + mess_in = self.comm.recv() + worker_stat = mess_in[0][:5] + worker_num = mess_in[0][5:7] + worker_str = "Worker with ID {0}".format(worker_num) + # check worker status + if worker_stat == "IDLEE": + logger.info("{0} is idle. Nothing to do.".format(worker_str)) + elif worker_stat == "PASSS": + logger.info("{0} has finished the job successfully".format(worker_str)) + elif worker_stat == "WARNN": + warn_counter += int(mess_in[1]) + logger.warning("{0} has seen a non-fatal error/warning. Number of marnings is now {1:d}" + .format(worker_str, warn_counter)) + elif worker_stat == "ERROR": + logger.critical("{0} met a fatal error. System will be terminated".format(worker_str)) + lexit = True + else: + logger.critical("{0} has sent an unknown message: '{1}'. System will be terminated." + .format(method, worker_stat)) + lexit = True + # sum of warnings exceeds allowed maximum + if warn_counter > self.nmax_warn: + logger.critical("Number of allowed warnings exceeded. Job will be terminated...") + lexit = True + + if lexit: + logger.critical("Job is shut down now.") + raise RuntimeError("%{0}: Distributed jobs could not be run properly.".format(method)) + + message_counter += 1 + + return True + + def manage_worker_jobs(self, logger, *args): + """ + Manages worker processes and runs job with passed arguments. + Receives from master process and returns a tuple of a return-message and a worker status. + :param logger: logger instance to add logs according to received message from master and from parallelized job + :param args: the arguments passed to parallelized job (see self.job in __init__) + :return stat: True if ok, else False + """ + method = "{0}->{1}".format(PyStager.class_name, PyStager.manage_worker_jobs.__name__) + + worker_stat_fail = 9999 + + # sanity checks + assert isinstance(self.comm, MPI.Intracomm), "%{0}: comm must be a MPI Intracomm-instance, but is type '{1}'"\ + .format(method, type(self.comm)) + + assert isinstance(logger, logging.Logger), "%{0}: logger must be a Logger-instance, but is of type '{1}'"\ + .format(method, type(logger)) + + mess_in = self.comm.recv() + + if mess_in is None: + mess_out = ("IDLEE{0}: Worker {1} is idle".format(self.my_rank, self.my_rank), 0) + logger.info(mess_out) + logger.info("Thus, nothing to do. Job is terminated locally on rank {0}".format(self.my_rank)) + self.comm.send(mess_out, dest=0) + return True + else: + logger.info("Worker {0} received input message: {1}".format(self.my_rank, mess_in[0])) + if "nmax_warn" in inspect.getfullargspec(self.job).args: + worker_stat = self.job(mess_in, *args, logger, nmax_warn=self.nmax_warn) + else: + worker_stat = self.job(mess_in, *args, logger) + + + err_mess = None + if worker_stat == -1: + mess_out = ("ERROR{0}: Failure was triggered.".format(self.my_rank), worker_stat_fail) + logger.critical("Progress was unsuccessful due to a fatal error observed." + + " Worker{0} triggers termination of all jobs.".format(self.my_rank)) + err_mess = "Worker{0} met a fatal error. Cannot continue...".format(self.my_rank) + elif worker_stat == 0: + logger.debug('Progress was successful') + mess_out = ("PASSS{0}:was finished".format(self.my_rank), worker_stat) + logger.info('Worker {0} finished a task'.format(self.my_rank)) + elif worker_stat > 0: + logger.debug("Progress faced {0:d} warnings which is still acceptable,".format(int(worker_stat)) + + " but requires investigation of the processed dataset.") + mess_out = ("WARNN{0}: Several warnings ({1:d}) have been triggered " + .format(self.my_rank, worker_stat), worker_stat) + logger.warning("Worker {0} has met relevant warnings, but still could continue...".format(self.my_rank)) + else: + mess_out = ("ERROR{0}: Unknown worker status ({1:d}) received ".format(self.my_rank, worker_stat), + worker_stat_fail) + err_mess = "Worker {0} has produced unknown worker status and triggers termination of all jobs."\ + .format(self.my_rank) + logger.critical(err_mess) + # communicate to master process + self.comm.send(mess_out, dest=0) + + if err_mess: + return False + else: + return True + + @staticmethod + def set_and_check_logdir(logdir, distributor_name): + """ + Sets and checks logging directory. + :param logdir: parent directory where log-files will be stored + :param distributor_name: name of distributor-method (used for naming actual log-directory) + :return logdir: adjusted log-directory + """ + method = PyStager.set_and_check_logdir.__name__ + + if logdir is None: + logdir = os.path.join(os.getcwd(), "pystager_log_{0}".format(distributor_name)) + os.makedirs(logdir, exist_ok=True) + print("%{0}: Default log directory '{1}' is used.".format(method, logdir)) + else: + if not os.path.isdir(logdir): + try: + os.mkdir(logdir) + print("%{0}: Created log directory '{1}'".format(method, logdir)) + except Exception as err: + print("%{0}: Failed to create desired log directory '{1}'".format(method, logdir)) + raise Exception + else: + print("%{0}: Directory '{1}' is used as log directory.".format(method, logdir)) + + return logdir + + @staticmethod + def directory_scanner(source_path, lprint=True): # is used at all ? + """ + Scans through directory and returns a couple of information. + NOTE: Subdirectories under source_path are not recursively scanned + :param source_path: Input idrectory to scan + :param lprint: Boolean if info should be printed (default: True) + :return dir_info: dictionary containing info on scanned directory with the following keys + "dir_detail_list": overview on number of files and required memory + "sub_dir_list": list of subsirectories + "total_size_source": total meory under source_path + "total_num_files": total number of files under source_path + "total_num_directories": total number of directories under source_path + """ + method = PyStager.directory_scanner.__name__ + + dir_detail_list = [] # directories details + sub_dir_list = [] + total_size_source = 0 + total_num_files = 0 + + if not os.path.isdir(source_path): + raise NotADirectoryError("%{0}: The directory '{1}' does not exist.".format(method, source_path)) + + list_directories = os.listdir(source_path) + + for d in list_directories: + path = os.path.join(source_path, d) + if os.path.isdir(path): + sub_dir_list.append(d) + sub_dir_list.sort() + # size of the files and subdirectories + size_dir = subprocess.check_output(['du', '-sc', path]) + splitted = size_dir.split() # fist item is the size of the folder + size = (splitted[0]) + num_files = len([f for f in os.listdir(path) if os.path.isfile(os.path.join(path, f))]) + dir_detail_list.extend([d, size, num_files]) + total_num_files = total_num_files + int(num_files) + total_size_source = total_size_source + int(size) + else: + raise NotADirectoryError("%{0}: '{1}' does not exist".format(method, path)) + + total_num_directories = int(len(list_directories)) + total_size_source = float(total_size_source / 1000000) + + if lprint: + print("===== Info from %{0}: =====".format(method)) + print("* Total memory size of the source directory: {0:.2f}Gb.".format(total_size_source)) + print("Total number of the files in the source directory: {0:d} ".format(num_files)) + print("Total number of the directories in the source directory: {0:d} ".format(total_num_directories)) + + dir_info = {"dir_detail_list": dir_detail_list, "sub_dir_list": sub_dir_list, + "total_size_source": total_size_source, "total_num_files": total_num_files, + "total_num_directories": total_num_directories} + + return dir_info + +#1) create dict for load balancing (keys: rank, vals: list of year-months str) +#2) broadcast lists to processes +#3) process task +#4) obtain nwarns +#5) report status based on nwarns +#6) collect worker status +#7) terminate work if critical error/unknown message is encountered, max warnings exeeded \ No newline at end of file diff --git a/video_prediction_tools/utils/runscript_generator/config_extraction.py b/video_prediction_tools/utils/runscript_generator/config_extraction.py index 6d322e29b9e96b7ce6697381a5ce242bcdea9594..dbcd4eeab7d93c526500ea334f92c4f723fba915 100755 --- a/video_prediction_tools/utils/runscript_generator/config_extraction.py +++ b/video_prediction_tools/utils/runscript_generator/config_extraction.py @@ -12,7 +12,10 @@ __date__ = "2021-01-27" import os, glob import subprocess as sp import time +from pathlib import Path + from runscript_generator.config_utils import Config_runscript_base # import parent class +from dataset_utils import DATASETS, get_dataset_info class Config_Extraction(Config_runscript_base): @@ -34,25 +37,100 @@ class Config_Extraction(Config_runscript_base): self.destination_dir = None # list of variables to be written to runscript self.list_batch_vars = ["VIRT_ENV_NAME", "source_dir", "varmap_file", "years", "destination_dir"] - # copy over method for keyboard interaction - self.run_config = Config_Extraction.run_extraction # # ----------------------------------------------------------------------------------- # - def run_extraction(self): + def run(self): """ Runs the keyboard interaction for data extraction step :return: all attributes of class Data_Extraction are set """ + + dataset = Config_Extraction.keyboard_interaction( + "Choose dataset:" + "".join([f"\n* {name}" for name in DATASETS]) + "\n", + lambda name, silent=False: name in (*DATASETS, "era5"), + ValueError("Cannot find dataset for given name"), + ntries=3 + ) + + if dataset == "weatherbench": # TODO: change once more datasets are added + self.extract(dataset) + else: + self.extract_old() + - method_name = Config_Extraction.run_extraction.__name__ - + def extract(self, dataset): + """ + Special method for handling of weatherbench data. + + (in order to not break previous code) + """ + available_years = get_dataset_info("weatherbench")["years"] + + # override in order to accomdate changes in the structure quick/dirty + self.runscript_template = self.rscrpt_tmpl_prefix + dataset + self.suffix_template + self.runscript_target = self.rscrpt_tmpl_prefix + dataset + ".sh" + self.list_batch_vars = ["VIRT_ENV_NAME", "source_dir", "years", "destination_dir"] + + def check_input_dir(path, silent=False): + try: + return Path(path).is_dir() + except PermissionError as e: + print(f"Could not access {path} because of missing permissions.") + return False + return False + + def check_years(years, silent=False): + years = [year.strip() for year in years.split(",")] + notyears = list(filter(lambda year: not str.isnumeric(year), years)) + if len(notyears) > 0: + print("The following elements could not be interpreted as valid years:") + for elem in notyears: + print(elem, end=", ") + return False + + return all([int(year) > 1970 and int(year) in available_years for year in years]) + + # get input folder + source_dir = Config_Extraction.keyboard_interaction( + "Enter path to weatherbench data root folder", + check_input_dir, + ValueError("Could not access directory under given path"), + ntries=3 + ) + + # get output folder + destination_dir = Config_Extraction.keyboard_interaction( + "Enter path for destination directory", + check_input_dir, + ValueError("Could not access directory under given path"), + ntries=3 + ) + + # get years + years = Config_Extraction.keyboard_interaction( + "Enter a comma-separated sequence of years for which data extraction should be performed:", + check_years, + ValueError("Cannot get years for preprocessing."), + ntries=3 + ) + years = [year.strip() for year in years.split(",")] + + # set parameters to be written to file + self.years = years + self.source_dir = source_dir + self.destination_dir = destination_dir + + + def extract_old(self): + method_name = Config_Extraction.extract.__name__ + dataset_req_str = "Enter the path where the original ERA5 grib-files are located (standard on JUST: '{0}'):"\ - .format(self.era5dir_just) + .format(self.era5dir_just) dataset_err = FileNotFoundError("Cannot retrieve input data from passed path.") self.source_dir = Config_Extraction.keyboard_interaction(dataset_req_str, Config_Extraction.check_data_indir, - dataset_err, ntries=3) + dataset_err, ntries=3) # maybe use locale variable ? # get years for preprcessing step 1 years_req_str = "Enter a comma-separated sequence of years for which data extraction should be performed:" @@ -94,7 +172,7 @@ class Config_Extraction(Config_runscript_base): base_dir = Config_Extraction.get_var_from_runscript(os.path.join(self.runscript_dir, self.runscript_template), "destination_dir") self.destination_dir = os.path.join(base_dir, "extractedData") - + # # ----------------------------------------------------------------------------------- # diff --git a/video_prediction_tools/utils/runscript_generator/config_postprocess.py b/video_prediction_tools/utils/runscript_generator/config_postprocess.py index 0a0e3316dd242084b316cc435b608127605e7798..3a258272787a4a7e3b0b32d11e5c3bf4f8a3973a 100755 --- a/video_prediction_tools/utils/runscript_generator/config_postprocess.py +++ b/video_prediction_tools/utils/runscript_generator/config_postprocess.py @@ -21,7 +21,7 @@ class Config_Postprocess(Config_runscript_base): # !!! Important note !!! # As long as we don't have runscript templates for all the datasets listed in known_datasets # or a generic template runscript, we need the following manual list - allowed_datasets = ["era5","moving_mnist"] # known_datasets().keys + allowed_datasets = ["era5","moving_mnist","gzprcp_data"] # known_datasets().keys def __init__(self, venv_name, lhpc): super().__init__(venv_name, lhpc) diff --git a/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py b/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py index 195530bc1679af365c7761dd74bd0f3737316058..7b5d831345048a0f1fa3fcc2d03da4969a42dc98 100755 --- a/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py +++ b/video_prediction_tools/utils/runscript_generator/config_preprocess_step1.py @@ -45,13 +45,11 @@ class Config_Preprocess1(Config_runscript_base): # list of variables to be written to runscript self.list_batch_vars = ["VIRT_ENV_NAME", "source_dir", "destination_dir", "years", "variables", "sw_corner", "nyx"] - # copy over method for keyboard interaction - self.run_config = Config_Preprocess1.run_preprocess1 # # ----------------------------------------------------------------------------------- # - def run_preprocess1(self): + def run(self): """ Runs the keyboard interaction for Preprocessing step 1 :return: all attributes of class Config_Preprocess1 are set diff --git a/video_prediction_tools/utils/runscript_generator/config_preprocess_step2.py b/video_prediction_tools/utils/runscript_generator/config_preprocess_step2.py index 144ac1d9398fe6acdf3f4089eedd2a988070fc8e..bf433cb1957fd98877328d5900d6deb4e97aa6a2 100755 --- a/video_prediction_tools/utils/runscript_generator/config_preprocess_step2.py +++ b/video_prediction_tools/utils/runscript_generator/config_preprocess_step2.py @@ -33,12 +33,10 @@ class Config_Preprocess2(Config_runscript_base): self.sequence_length = None # only needed for ERA5 # list of variables to be written to runscript self.list_batch_vars = ["VIRT_ENV_NAME", "source_dir", "destination_dir"] # appended for ERA5 dataset - # copy over method for keyboard interaction - self.run_config = Config_Preprocess2.run_preprocess2 # # ----------------------------------------------------------------------------------- # - def run_preprocess2(self): + def run(self): """ Runs the keyboard interaction for Preprocessing step 2 :return: all attributes of class Config_Preprocess2 set diff --git a/video_prediction_tools/utils/runscript_generator/config_training.py b/video_prediction_tools/utils/runscript_generator/config_training.py index b382e00a7b3662a4350fe10fe972289d28f1c14e..38d3ade765f1ccce84b241d824b7d63c13eb6447 100755 --- a/video_prediction_tools/utils/runscript_generator/config_training.py +++ b/video_prediction_tools/utils/runscript_generator/config_training.py @@ -14,7 +14,9 @@ import re import time import datetime as dt import subprocess as sp +from pathlib import Path from model_modules.model_architectures import known_models + from data_preprocess.dataset_options import known_datasets from runscript_generator.config_utils import Config_runscript_base # import parent class @@ -26,7 +28,7 @@ class Config_Train(Config_runscript_base): # !!! Important note !!! # As long as we don't have runscript templates for all the datasets listed in known_datasets # or a generic template runscript, we need the following manual list - allowed_datasets = ["era5", "moving_mnist"] # known_datasets().keys + allowed_datasets = known_datasets basename_tfdirs = "tfrecords_seq_len_" @@ -46,17 +48,15 @@ class Config_Train(Config_runscript_base): self.model_hparams = None # list of variables to be written to runscript self.list_batch_vars = ["VIRT_ENV_NAME", "source_dir", "model", "destination_dir"] - # copy over method for keyboard interaction - self.run_config = Config_Train.run_training # # ----------------------------------------------------------------------------------- # - def run_training(self): + def run_tfrecords(self): """ Runs the keyboard interaction for Training :return: all attributes of class training are set """ - method_name = Config_Train.run_training.__name__ + method_name = Config_Train.run.__name__ # decide which dataset is used dset_type_req_str = "Enter the name of the dataset on which you want to train:" @@ -73,10 +73,11 @@ class Config_Train(Config_runscript_base): expdir_req_str = "Choose a subdirectory listed above where the preprocessed TFrecords are located:" expdir_err = FileNotFoundError("Could not find any tfrecords.") - self.source_dir = Config_Train.keyboard_interaction(expdir_req_str, Config_Train.check_expdir, + self.source_dir = Config_Train.keyboard_interaction(expdir_req_str, Config_Train.check_expdir_2, expdir_err, ntries=3, prefix2arg=source_dir_base+"/") # expand source_dir by tfrecords-subdirectory - tf_dirs = Config_Train.list_tf_dirs(self) + # tf_dirs = Config_Train.list_tf_dirs(self) + tf_dirs = self.list_dirs() ntf_dirs = len(tf_dirs) if ntf_dirs == 1: @@ -88,19 +89,23 @@ class Config_Train(Config_runscript_base): # Note, how the check_expdir-method is recycled by simply adding a properly defined suffix to it # Note that due to the suffix, the returned value already corresponds to the final path for source_dir - self.source_dir = Config_Train.keyboard_interaction(seq_req_str, Config_Train.check_expdir, + self.source_dir = Config_Train.keyboard_interaction(seq_req_str, Config_Train.check_expdir_2, seq_err, ntries=2, prefix2arg=os.path.join(self.source_dir, Config_Train.basename_tfdirs)) # split up directory path in order to retrieve exp_dir used for setting up the destination directory exp_dir_split = Config_Train.path_rec_split(self.source_dir) + print(exp_dir_split) + print(self.dataset) index = [idx for idx, s in enumerate(exp_dir_split) if self.dataset in s] if not index: raise ValueError( "%{0}: tfrecords found under '{1}', but directory does not seem to reflect naming convention." .format(method_name, self.source_dir)) exp_dir = exp_dir_split[index[0]] + #exp_dir = exp_dir_split[-2] + print(exp_dir) # get the model to train model_req_str = "Enter the name of the model you want to train:" @@ -167,6 +172,115 @@ class Config_Train(Config_runscript_base): # # ----------------------------------------------------------------------------------- # + + def run(self): + """ + Runs the keyboard interaction for Training + :return: all attributes of class training are set + """ + method_name = Config_Train.run.__name__ + + # decide which dataset is used + dset_type_req_str = "Enter the name of the dataset on which you want to train:" + dset_err = ValueError("Please select a dataset from the ones listed above.") + + self.dataset = Config_Train.keyboard_interaction(dset_type_req_str, Config_Train.check_dataset, + dset_err, ntries=2) + + # get source dir (relative to base_dir_source!) + self.runscript_template = os.path.join(self.runscript_dir, "train_model_{0}{1}" + .format(self.dataset, self.suffix_template)) + source_dir_base = Config_Train.handle_source_dir(self, "preprocessedData") + + expdir_req_str = "Choose a subdirectory listed above where the preprocessed TFrecords are located:" + expdir_err = FileNotFoundError("Could not find any tfrecords.") + + self.source_dir = Config_Train.keyboard_interaction(expdir_req_str, Config_Train.check_expdir_2, + expdir_err, ntries=3, prefix2arg=source_dir_base+"/") + # expand source_dir by tfrecords-subdirectory + # tf_dirs = Config_Train.list_tf_dirs(self) + tf_dirs = self.list_dirs() + ntf_dirs = len(tf_dirs) + + if ntf_dirs == 1: + self.source_dir = os.path.join(self.source_dir, tf_dirs[0]) + + # split up directory path in order to retrieve exp_dir used for setting up the destination directory + exp_dir_split = Config_Train.path_rec_split(self.source_dir) + print(exp_dir_split) + print(self.dataset) + index = [idx for idx, s in enumerate(exp_dir_split) if self.dataset in s] + if not index: + raise ValueError( + "%{0}: tfrecords found under '{1}', but directory does not seem to reflect naming convention." + .format(method_name, self.source_dir)) + exp_dir = exp_dir_split[index[0]] + #exp_dir = exp_dir_split[-2] + print(exp_dir) + + # get the model to train + model_req_str = "Enter the name of the model you want to train:" + model_err = ValueError("Please select a model from the ones listed above.") + + self.model = Config_Train.keyboard_interaction(model_req_str, Config_Train.check_model, model_err, ntries=2) + + # experimental ID + # No need to call keyboard_interaction here, because the user can pass whatever we wants + self.exp_id = input("*** Enter your desired experimental id (will be extended by timestamp and username):\n") + + # also get current timestamp and user-name... + timestamp = dt.datetime.now().strftime("%Y%m%dT%H%M%S") + user_name = os.environ["USER"] + # ... to construct final destination_dir and exp_dir_ext as well + self.exp_id = timestamp + "_" + user_name + "_" + self.exp_id # by convention, exp_id is extended by timestamp and username + + # now, we are also ready to set the correct name of the runscript template and the target + self.runscript_target = "{0}{1}_{2}.sh".format(self.rscrpt_tmpl_prefix, self.dataset, self.exp_id) + + base_dir = Config_Train.get_var_from_runscript(os.path.join(self.runscript_dir, self.runscript_template), + "destination_dir") + exp_dir_ext = os.path.join(exp_dir, self.model, self.exp_id) + self.destination_dir = os.path.join(base_dir, "models", exp_dir, self.model, self.exp_id) + + # sanity check (target_dir is unique): + if os.path.isdir(self.destination_dir): + raise IsADirectoryError("%{0}: {1} already exists! Make sure that it is unique." + .format(method_name, self.destination_dir)) + + # create destination directory... + os.makedirs(self.destination_dir) + + # Create json-file for data splitting + source_datasplit = os.path.join("..", "data_split", self.dataset, "datasplit_template.json") + self.datasplit_dict = os.path.join(self.destination_dir, "data_split.json") + # sanity check (default data_split json-file exists) + if not os.path.isfile(source_datasplit): + raise FileNotFoundError("%{0}: Could not find default data_split json-file '{1}'".format(method_name, + source_datasplit)) + # ...copy over json-file for data splitting... + os.system("cp "+source_datasplit+" "+self.datasplit_dict) + # ...and open vim after some delay + print("*** Please configure the data splitting:") + time.sleep(3) + cmd_vim = os.environ.get('EDITOR', 'vi') + ' ' + os.path.join(self.destination_dir,"data_split.json") + sp.call(cmd_vim, shell=True) + sp.call("sed -i '/^#/d' {0}".format(self.datasplit_dict), shell=True) + + # Create json-file for hyperparameters + source_hparams = os.path.join("..","hparams", self.dataset, self.model, "model_hparams_template.json") + self.model_hparams = os.path.join(self.destination_dir, "model_hparams.json") + # sanity check (default hyperparameter json-file exists) + if not os.path.isfile(source_hparams): + raise FileNotFoundError("%{0}: Could not find default hyperparameter json-file '%{1}'" + .format(method_name, source_hparams)) + # ...copy over json-file for hyperparamters... + os.system("cp "+source_hparams+" "+self.model_hparams) + # ...and open vim after some delay + print("*** Please configure the model hyperparameters:") + time.sleep(3) + cmd_vim = os.environ.get('EDITOR', 'vi') + ' ' + self.model_hparams + sp.call(cmd_vim, shell=True) + def list_tf_dirs(self): method = Config_Train.list_tf_dirs.__name__ @@ -189,6 +303,11 @@ class Config_Train(Config_runscript_base): print("* {0}".format(idir)) return tf_dirs + + + def list_dirs(self): + """Quick hack to negate reliance on tf-records.""" + return [path.name for path in Path(self.source_dir).iterdir() if path.is_dir()] @staticmethod @@ -236,6 +355,11 @@ class Config_Train(Config_runscript_base): # # ----------------------------------------------------------------------------------- # + def check_expdir_2(exp_dir, silent=False): + """Quick hack to negate reliance on tf-records.""" + return True + + @staticmethod def check_model(model_name, silent=False): """ diff --git a/video_prediction_tools/utils/runscript_generator/config_utils.py b/video_prediction_tools/utils/runscript_generator/config_utils.py index 27e8f155f7fa40e4863ff3fe3500ac0093213033..9430bf4fe477e0dc17d34d22a978256411b21179 100755 --- a/video_prediction_tools/utils/runscript_generator/config_utils.py +++ b/video_prediction_tools/utils/runscript_generator/config_utils.py @@ -45,25 +45,15 @@ class Config_runscript_base: self.list_batch_vars = None self.dataset = None self.source_dir = None - # attribute storing workflow-step dependant function for keyboard interaction - self.run_config = None + # # ----------------------------------------------------------------------------------- # def run(self): """ - Acts as generic wrapper: Checks if run_config is already set up as a callable - :return: Executes run_config + Procces keyboard interaction for step-specific runscript configuration. """ - method_name = "run" + " of Class " + Config_runscript_base.cls_name - if self.run_config is None: - raise ValueError("%{0}: run-method is still uninitialized.".format(method_name)) - - if not callable(self.run_config): - raise ValueError("%{0}: run-method is not callable".format(method_name)) - - # simply execute it - self.run_config(self) + raise NotImplementedError() # # ----------------------------------------------------------------------------------- # @@ -224,6 +214,7 @@ class Config_runscript_base: for line in runscript: if script_variable in line: var_value = (line.strip(script_variable)).replace("\n", "") + print(var_value) found = True break diff --git a/video_prediction_tools/utils/runscript_generator/setup_runscript_templates.sh b/video_prediction_tools/utils/runscript_generator/setup_runscript_templates.sh index b684069ad3df3cd02e836ae1c68ffa0ad5752760..75ba47765cebf647e12be76320820ee786c00254 100755 --- a/video_prediction_tools/utils/runscript_generator/setup_runscript_templates.sh +++ b/video_prediction_tools/utils/runscript_generator/setup_runscript_templates.sh @@ -16,7 +16,6 @@ # default value for base directory base_data_dir_default=/p/project/deepacf/deeprain/video_prediction_shared_folder/ -# base_data_dir_default=/p/scratch/deepacf/ji4/ # some further directory paths CURR_DIR_FULL="$( cd "$( dirname "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" # retrieves the location of this script BASE_DIR="$(dirname "$(dirname "${CURR_DIR_FULL}")")"