diff --git a/.gitignore b/.gitignore index 3554b33c8a9239ac1c7db4958c0dd80cc731fb1c..e109cec7c7622cee9d9a635cc458b9c662fc4761 100644 --- a/.gitignore +++ b/.gitignore @@ -73,4 +73,8 @@ report.html # secret variables # #################### -/src/join_settings.py +/src/configuration/join_settings.py + +# ignore locally build documentation # +###################################### +/docs/_build \ No newline at end of file diff --git a/.gitlab-ci.yml b/.gitlab-ci.yml index 484348180c28c8c0e86024be95ff52039466221c..dec6a0fb2ab9eb51fef737c50d5beab4270c8942 100644 --- a/.gitlab-ci.yml +++ b/.gitlab-ci.yml @@ -1,6 +1,7 @@ stages: - init - test + - docs - pages ### Static Badges ### @@ -131,6 +132,31 @@ coverage: - badges/ - coverage/ +#### Documentation #### +sphinx docs: + stage: docs + tags: + - machinelearningtools + - zam347 + before_script: + - chmod +x ./CI/update_badge.sh + - ./CI/update_badge.sh > /dev/null + script: + - pip install -r requirements.txt + - pip install -r docs/requirements_docs.txt + - chmod +x ./CI/create_documentation.sh + - ./CI/create_documentation.sh + after_script: + - ./CI/update_badge.sh > /dev/null + when: always + artifacts: + name: pages + when: always + paths: + - badges/ + - webpage/ + + #### Pages #### pages: stage: pages @@ -138,15 +164,23 @@ pages: - zam347 - base script: + # badges - mkdir -p public/badges/ - cp -af badges/badge_*.svg public/badges/ - ls public/badges/ + # coverage - mkdir -p public/coverage - cp -af coverage/. public/coverage - ls public/coverage + # test - mkdir -p public/test - cp -af test_results/. public/test - ls public/test + # docs + - mkdir -p public/docs + - cp -af webpage/. public/docs + - ls public/docs + # summary - ls public when: always artifacts: @@ -157,6 +191,7 @@ pages: - badges/ - coverage/ - test_results/ + - webpage/ cache: key: old-pages paths: diff --git a/CI/create_documentation.sh b/CI/create_documentation.sh new file mode 100644 index 0000000000000000000000000000000000000000..6f5aa16e3561cdae43a11220ea35f6f4c8126ae2 --- /dev/null +++ b/CI/create_documentation.sh @@ -0,0 +1,46 @@ +#!/bin/bash + +# reset status +echo "failure" > status.txt + +# create webpage folder +BRANCH_NAME=$( echo -e "${CI_COMMIT_REF_NAME////_}") +mkdir -p webpage/ +mkdir -p webpage/recent +#for w in master develop +#do +# if [[ "${CI_COMMIT_REF_NAME}" == "$w" ]]; then +# mkdir -p "webpage/${BRANCH_NAME}" +# fi +#done +mkdir -p "webpage/${BRANCH_NAME}" + +cd docs || { + echo "no docs to build available"; + echo "incomplete" > status.txt; + echo "no docs to build avail" > incomplete.txt; + exit 0; } + +echo "${CI_COMMIT_TAG}" +make clean +make html +IS_FAILED=$? + +# copy results +cp -r ./_build/html/* "../webpage/${BRANCH_NAME}/." +cp -r ./_build/html/* ../webpage/recent/. +if [[ "${CI_COMMIT_REF_NAME}" = "master" ]]; then + cp -r ./_build/html/* ../webpage/. +fi +cd .. + +# report if job was successful +if [[ ${IS_FAILED} == 0 ]]; then + echo "success" + echo "success" > status.txt + echo "build" > success.txt + exit 0 +else + echo "failed" + exit 1 +fi diff --git a/CI/update_badge.sh b/CI/update_badge.sh index c8b11015d27f509faeb4b26b5d88ec7df5a4e675..45e50f49377ed34350ff4d15fc03ca5e2eae6164 100644 --- a/CI/update_badge.sh +++ b/CI/update_badge.sh @@ -71,10 +71,10 @@ printf "%s\n" "${SHIELDS_IO_NAME//\#/%23}" SHIELDS_IO_NAME="$( echo -e "${SHIELDS_IO_NAME//\_/__}" )" SHIELDS_IO_NAME="$( echo -e "${SHIELDS_IO_NAME//\#/%23}")" -curl "https://img.shields.io/badge/${SHIELDS_IO_NAME}" > ${BADGE_FILENAME} +curl "https://img.shields.io/badge/${SHIELDS_IO_NAME}" > "${BADGE_FILENAME}" echo "https://img.shields.io/badge/${SHIELDS_IO_NAME}" SHIELDS_IO_NAME_RECENT="RECENT:${SHIELDS_IO_NAME}" -curl "https://img.shields.io/badge/${SHIELDS_IO_NAME_RECENT}" > ${RECENT_BADGE_FILENAME} +curl "https://img.shields.io/badge/${SHIELDS_IO_NAME_RECENT}" > "${RECENT_BADGE_FILENAME}" echo "${SHIELDS_IO_NAME_RECENT}" > testRecentName.txt # @@ -82,10 +82,10 @@ if [[ ! -d ./badges ]]; then # Control will enter here if $DIRECTORY doesn't exist. mkdir badges/ fi -mv ${BADGE_FILENAME} ./badges/. +mv "${BADGE_FILENAME}" ./badges/. # replace outdated recent badge by new badge -mv ${RECENT_BADGE_FILENAME} ./badges/${RECENT_BADGE_FILENAME} +mv "${RECENT_BADGE_FILENAME}" "./badges/${RECENT_BADGE_FILENAME}" # set status to failed, this will be overwritten if job ended with exitcode 0 echo "failed" > status.txt diff --git a/German_background_stations.json b/German_background_stations.json index 2997eefbaa9a72f4e94b940b6d0ebb7f6a34370d..9e3b89cd06df62442d582758062815ac2ab8bc7c 100755 --- a/German_background_stations.json +++ b/German_background_stations.json @@ -1 +1,334 @@ -["DENW094", "DEBW029", "DENI052", "DENI063", "DEBY109", "DEUB022", "DESN001", "DEUB013", "DETH016", "DEBY002", "DEBY005", "DEBY099", "DEUB038", "DEBE051", "DEBE056", "DEBE062", "DEBE032", "DEBE034", "DEBE010", "DEHE046", "DEST031", "DEBY122", "DERP022", "DEBY079", "DEBW102", "DEBW076", "DEBW045", "DESH016", "DESN004", "DEHE032", "DEBB050", "DEBW042", "DEBW046", "DENW067", "DESL019", "DEST014", "DENW062", "DEHE033", "DENW081", "DESH008", "DEBB055", "DENI011", "DEHB001", "DEHB004", "DEHB002", "DEHB003", "DEHB005", "DEST039", "DEUB003", "DEBW072", "DEST002", "DEBB001", "DEHE039", "DEBW035", "DESN005", "DEBW047", "DENW004", "DESN011", "DESN076", "DEBB064", "DEBB006", "DEHE001", "DESN012", "DEST030", "DESL003", "DEST104", "DENW050", "DENW008", "DETH026", "DESN085", "DESN014", "DESN092", "DENW071", "DEBW004", "DENI028", "DETH013", "DENI059", "DEBB007", "DEBW049", "DENI043", "DETH020", "DEBY017", "DEBY113", "DENW247", "DENW028", "DEBW025", "DEUB039", "DEBB009", "DEHE027", "DEBB042", "DEHE008", "DESN017", "DEBW084", "DEBW037", "DEHE058", "DEHE028", "DEBW112", "DEBY081", "DEBY082", "DEST032", "DETH009", "DEHE010", "DESN019", "DEHE023", "DETH036", "DETH040", "DEMV017", "DEBW028", "DENI042", "DEMV004", "DEMV019", "DEST044", "DEST050", "DEST072", "DEST022", "DEHH049", "DEHH047", "DEHH033", "DEHH050", "DEHH008", "DEHH021", "DENI054", "DEST070", "DEBB053", "DENW029", "DEBW050", "DEUB034", "DENW018", "DEST052", "DEBY020", "DENW063", "DESN050", "DETH061", "DERP014", "DETH024", "DEBW094", "DENI031", "DETH041", "DERP019", "DEBW081", "DEHE013", "DEBW021", "DEHE060", "DEBY031", "DESH021", "DESH033", "DEHE052", "DEBY004", "DESN024", "DEBW052", "DENW042", "DEBY032", "DENW053", "DENW059", "DEBB082", "DEBB031", "DEHE025", "DEBW053", "DEHE048", "DENW051", "DEBY034", "DEUB035", "DEUB032", "DESN028", "DESN059", "DEMV024", "DENW079", "DEHE044", "DEHE042", "DEBB043", "DEBB036", "DEBW024", "DERP001", "DEMV012", "DESH005", "DESH023", "DEUB031", "DENI062", "DENW006", "DEBB065", "DEST077", "DEST005", "DERP007", "DEBW006", "DEBW007", "DEHE030", "DENW015", "DEBY013", "DETH025", "DEUB033", "DEST025", "DEHE045", "DESN057", "DENW036", "DEBW044", "DEUB036", "DENW096", "DETH095", "DENW038", "DEBY089", "DEBY039", "DENW095", "DEBY047", "DEBB067", "DEBB040", "DEST078", "DENW065", "DENW066", "DEBY052", "DEUB030", "DETH027", "DEBB048", "DENW047", "DEBY049", "DERP021", "DEHE034", "DESN079", "DESL008", "DETH018", "DEBW103", "DEHE017", "DEBW111", "DENI016", "DENI038", "DENI058", "DENI029", "DEBY118", "DEBW032", "DEBW110", "DERP017", "DESN036", "DEBW026", "DETH042", "DEBB075", "DEBB052", "DEBB021", "DEBB038", "DESN051", "DEUB041", "DEBW020", "DEBW113", "DENW078", "DEHE018", "DEBW065", "DEBY062", "DEBW027", "DEBW041", "DEHE043", "DEMV007", "DEMV021", "DEBW054", "DETH005", "DESL012", "DESL011", "DEST069", "DEST071", "DEUB004", "DESH006", "DEUB029", "DEUB040", "DESN074", "DEBW031", "DENW013", "DENW179", "DEBW056", "DEBW087", "DEST061", "DEMV001", "DEBB024", "DEBW057", "DENW064", "DENW068", "DENW080", "DENI019", "DENI077", "DEHE026", "DEBB066", "DEBB083", "DEST063", "DEBW013", "DETH086", "DESL018", "DETH096", "DEBW059", "DEBY072", "DEBY088", "DEBW060", "DEBW107", "DEBW036", "DEUB026", "DEBW019", "DENW010", "DEST098", "DEHE019", "DEBW039", "DESL017", "DEBW034", "DEUB005", "DEBB051", "DEHE051", "DEBW023", "DEBY092", "DEBW008", "DEBW030", "DENI060", "DEST011", "DENW030", "DENI041", "DERP015", "DEUB001", "DERP016", "DERP028", "DERP013", "DEHE022", "DEUB021", "DEBW010", "DEST066", "DEBB063", "DEBB028", "DEHE024", "DENI020", "DENI051", "DERP025", "DEBY077", "DEMV018", "DEST089", "DEST028", "DETH060", "DEHE050", "DEUB028", "DESN045", "DEUB042"] +[ + "DENW094", + "DEBW029", + "DENI052", + "DENI063", + "DEBY109", + "DEUB022", + "DESN001", + "DEUB013", + "DETH016", + "DEBY002", + "DEBY005", + "DEBY099", + "DEUB038", + "DEBE051", + "DEBE056", + "DEBE062", + "DEBE032", + "DEBE034", + "DEBE010", + "DEHE046", + "DEST031", + "DEBY122", + "DERP022", + "DEBY079", + "DEBW102", + "DEBW076", + "DEBW045", + "DESH016", + "DESN004", + "DEHE032", + "DEBB050", + "DEBW042", + "DEBW046", + "DENW067", + "DESL019", + "DEST014", + "DENW062", + "DEHE033", + "DENW081", + "DESH008", + "DEBB055", + "DENI011", + "DEHB001", + "DEHB004", + "DEHB002", + "DEHB003", + "DEHB005", + "DEST039", + "DEUB003", + "DEBW072", + "DEST002", + "DEBB001", + "DEHE039", + "DEBW035", + "DESN005", + "DEBW047", + "DENW004", + "DESN011", + "DESN076", + "DEBB064", + "DEBB006", + "DEHE001", + "DESN012", + "DEST030", + "DESL003", + "DEST104", + "DENW050", + "DENW008", + "DETH026", + "DESN085", + "DESN014", + "DESN092", + "DENW071", + "DEBW004", + "DENI028", + "DETH013", + "DENI059", + "DEBB007", + "DEBW049", + "DENI043", + "DETH020", + "DEBY017", + "DEBY113", + "DENW247", + "DENW028", + "DEBW025", + "DEUB039", + "DEBB009", + "DEHE027", + "DEBB042", + "DEHE008", + "DESN017", + "DEBW084", + "DEBW037", + "DEHE058", + "DEHE028", + "DEBW112", + "DEBY081", + "DEBY082", + "DEST032", + "DETH009", + "DEHE010", + "DESN019", + "DEHE023", + "DETH036", + "DETH040", + "DEMV017", + "DEBW028", + "DENI042", + "DEMV004", + "DEMV019", + "DEST044", + "DEST050", + "DEST072", + "DEST022", + "DEHH049", + "DEHH047", + "DEHH033", + "DEHH050", + "DEHH008", + "DEHH021", + "DENI054", + "DEST070", + "DEBB053", + "DENW029", + "DEBW050", + "DEUB034", + "DENW018", + "DEST052", + "DEBY020", + "DENW063", + "DESN050", + "DETH061", + "DERP014", + "DETH024", + "DEBW094", + "DENI031", + "DETH041", + "DERP019", + "DEBW081", + "DEHE013", + "DEBW021", + "DEHE060", + "DEBY031", + "DESH021", + "DESH033", + "DEHE052", + "DEBY004", + "DESN024", + "DEBW052", + "DENW042", + "DEBY032", + "DENW053", + "DENW059", + "DEBB082", + "DEBB031", + "DEHE025", + "DEBW053", + "DEHE048", + "DENW051", + "DEBY034", + "DEUB035", + "DEUB032", + "DESN028", + "DESN059", + "DEMV024", + "DENW079", + "DEHE044", + "DEHE042", + "DEBB043", + "DEBB036", + "DEBW024", + "DERP001", + "DEMV012", + "DESH005", + "DESH023", + "DEUB031", + "DENI062", + "DENW006", + "DEBB065", + "DEST077", + "DEST005", + "DERP007", + "DEBW006", + "DEBW007", + "DEHE030", + "DENW015", + "DEBY013", + "DETH025", + "DEUB033", + "DEST025", + "DEHE045", + "DESN057", + "DENW036", + "DEBW044", + "DEUB036", + "DENW096", + "DETH095", + "DENW038", + "DEBY089", + "DEBY039", + "DENW095", + "DEBY047", + "DEBB067", + "DEBB040", + "DEST078", + "DENW065", + "DENW066", + "DEBY052", + "DEUB030", + "DETH027", + "DEBB048", + "DENW047", + "DEBY049", + "DERP021", + "DEHE034", + "DESN079", + "DESL008", + "DETH018", + "DEBW103", + "DEHE017", + "DEBW111", + "DENI016", + "DENI038", + "DENI058", + "DENI029", + "DEBY118", + "DEBW032", + "DEBW110", + "DERP017", + "DESN036", + "DEBW026", + "DETH042", + "DEBB075", + "DEBB052", + "DEBB021", + "DEBB038", + "DESN051", + "DEUB041", + "DEBW020", + "DEBW113", + "DENW078", + "DEHE018", + "DEBW065", + "DEBY062", + "DEBW027", + "DEBW041", + "DEHE043", + "DEMV007", + "DEMV021", + "DEBW054", + "DETH005", + "DESL012", + "DESL011", + "DEST069", + "DEST071", + "DEUB004", + "DESH006", + "DEUB029", + "DEUB040", + "DESN074", + "DEBW031", + "DENW013", + "DENW179", + "DEBW056", + "DEBW087", + "DEST061", + "DEMV001", + "DEBB024", + "DEBW057", + "DENW064", + "DENW068", + "DENW080", + "DENI019", + "DENI077", + "DEHE026", + "DEBB066", + "DEBB083", + "DEST063", + "DEBW013", + "DETH086", + "DESL018", + "DETH096", + "DEBW059", + "DEBY072", + "DEBY088", + "DEBW060", + "DEBW107", + "DEBW036", + "DEUB026", + "DEBW019", + "DENW010", + "DEST098", + "DEHE019", + "DEBW039", + "DESL017", + "DEBW034", + "DEUB005", + "DEBB051", + "DEHE051", + "DEBW023", + "DEBY092", + "DEBW008", + "DEBW030", + "DENI060", + "DEST011", + "DENW030", + "DENI041", + "DERP015", + "DEUB001", + "DERP016", + "DERP028", + "DERP013", + "DEHE022", + "DEUB021", + "DEBW010", + "DEST066", + "DEBB063", + "DEBB028", + "DEHE024", + "DENI020", + "DENI051", + "DERP025", + "DEBY077", + "DEMV018", + "DEST089", + "DEST028", + "DETH060", + "DEHE050", + "DEUB028", + "DESN045", + "DEUB042" +] diff --git a/conftest.py b/conftest.py index 92d2159c3b3a3efd7d0c0bfb5bf6bb058697d79c..0726ea7cf9dbd259913c22cb87f83cb47ad5f40c 100644 --- a/conftest.py +++ b/conftest.py @@ -1,4 +1,5 @@ import os +import re import shutil @@ -20,6 +21,25 @@ def pytest_runtest_teardown(item, nextitem): shutil.rmtree(os.path.join(path, "data"), ignore_errors=True) if "TestExperiment" in list_dir: shutil.rmtree(os.path.join(path, "TestExperiment"), ignore_errors=True) + # remove all tracking json + remove_files_from_regex(list_dir, path, re.compile(r"tracking_\d*\.json")) + # remove all tracking pdf + remove_files_from_regex(list_dir, path, re.compile(r"tracking\.pdf")) + # remove all tracking json + remove_files_from_regex(list_dir, path, re.compile(r"logging_\d*\.log")) else: pass # nothing to do if next test is from same test class + +def remove_files_from_regex(list_dir, path, regex): + r = list(filter(regex.search, list_dir)) + if len(r) > 0: + for e in r: + del_path = os.path.join(path, e) + try: + if os.path.isfile(del_path): + os.remove(del_path) + else: + shutil.rmtree(os.path.join(path, e), ignore_errors=True) + except: + pass diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000000000000000000000000000000000000..81b6117c9414e4857954b0867364a514752deaa3 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = _source +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/_source/_api/machinelearningtools.rst b/docs/_source/_api/machinelearningtools.rst new file mode 100644 index 0000000000000000000000000000000000000000..cd6885f52bedfa295139251c641c5bba8e2a30e9 --- /dev/null +++ b/docs/_source/_api/machinelearningtools.rst @@ -0,0 +1,10 @@ +machinelearningtools package +============================ + +.. automodule:: src + :members: + :undoc-members: + :show-inheritance: + +Submodules +---------- diff --git a/docs/_source/_plots/conditional_quantiles_cali-ref_plot.png b/docs/_source/_plots/conditional_quantiles_cali-ref_plot.png new file mode 100644 index 0000000000000000000000000000000000000000..94373ab2b71a2a719fbeac84a5e6b5230f93909c Binary files /dev/null and b/docs/_source/_plots/conditional_quantiles_cali-ref_plot.png differ diff --git a/docs/_source/_plots/conditional_quantiles_like-bas_plot.png b/docs/_source/_plots/conditional_quantiles_like-bas_plot.png new file mode 100644 index 0000000000000000000000000000000000000000..1641a12f678028c96646b2daabbf06c599cfb86a Binary files /dev/null and b/docs/_source/_plots/conditional_quantiles_like-bas_plot.png differ diff --git a/docs/_source/_plots/data_availability.png b/docs/_source/_plots/data_availability.png new file mode 100644 index 0000000000000000000000000000000000000000..a2350c4f57befb65b5d90721b9ae51257b59c4a5 Binary files /dev/null and b/docs/_source/_plots/data_availability.png differ diff --git a/docs/_source/_plots/data_availability_combined.png b/docs/_source/_plots/data_availability_combined.png new file mode 100644 index 0000000000000000000000000000000000000000..ae8fa5c034b3694171ec348cdc20fa3f73795691 Binary files /dev/null and b/docs/_source/_plots/data_availability_combined.png differ diff --git a/docs/_source/_plots/data_availability_summary.png b/docs/_source/_plots/data_availability_summary.png new file mode 100644 index 0000000000000000000000000000000000000000..db88b4d1ea4b5d22b8c04143da0824beef41eff9 Binary files /dev/null and b/docs/_source/_plots/data_availability_summary.png differ diff --git a/docs/_source/_plots/monthly_summary_box_plot.png b/docs/_source/_plots/monthly_summary_box_plot.png new file mode 100644 index 0000000000000000000000000000000000000000..f7447d8283adeb62d43d322769bf08925c0e2d89 Binary files /dev/null and b/docs/_source/_plots/monthly_summary_box_plot.png differ diff --git a/docs/_source/_plots/skill_score_bootstrap.png b/docs/_source/_plots/skill_score_bootstrap.png new file mode 100644 index 0000000000000000000000000000000000000000..844bf7f48cd32d588363b75623c7b7d5691a9988 Binary files /dev/null and b/docs/_source/_plots/skill_score_bootstrap.png differ diff --git a/docs/_source/_plots/skill_score_clim_CNN.png b/docs/_source/_plots/skill_score_clim_CNN.png new file mode 100644 index 0000000000000000000000000000000000000000..28a66b5c43b71c39a57d81123dfca7e3158dd8ce Binary files /dev/null and b/docs/_source/_plots/skill_score_clim_CNN.png differ diff --git a/docs/_source/_plots/skill_score_clim_all_terms_CNN.png b/docs/_source/_plots/skill_score_clim_all_terms_CNN.png new file mode 100644 index 0000000000000000000000000000000000000000..000b942154dbe9dde9f48f64ab1b967a6811907d Binary files /dev/null and b/docs/_source/_plots/skill_score_clim_all_terms_CNN.png differ diff --git a/docs/_source/_plots/skill_score_competitive.png b/docs/_source/_plots/skill_score_competitive.png new file mode 100644 index 0000000000000000000000000000000000000000..6b5342c31579c9c6c59ebacded8a92d02cb7c1f4 Binary files /dev/null and b/docs/_source/_plots/skill_score_competitive.png differ diff --git a/docs/_source/_plots/station_map.png b/docs/_source/_plots/station_map.png new file mode 100644 index 0000000000000000000000000000000000000000..181440f4003a65cdacfae66309fb981f3bb420b8 Binary files /dev/null and b/docs/_source/_plots/station_map.png differ diff --git a/docs/_source/_plots/testrun_network_daily_history_learning_rate-1.png b/docs/_source/_plots/testrun_network_daily_history_learning_rate-1.png new file mode 100644 index 0000000000000000000000000000000000000000..c433a6431fb84322ca0097cb5b567aec1d063661 Binary files /dev/null and b/docs/_source/_plots/testrun_network_daily_history_learning_rate-1.png differ diff --git a/docs/_source/_plots/testrun_network_daily_history_loss-1.png b/docs/_source/_plots/testrun_network_daily_history_loss-1.png new file mode 100644 index 0000000000000000000000000000000000000000..3a2234e4b39036f843396f2538ebbe5d4ec8ed5b Binary files /dev/null and b/docs/_source/_plots/testrun_network_daily_history_loss-1.png differ diff --git a/docs/_source/_plots/testrun_network_daily_history_main_mse-1.png b/docs/_source/_plots/testrun_network_daily_history_main_mse-1.png new file mode 100644 index 0000000000000000000000000000000000000000..71f2f2cea3e55d5c3cd404187d95e3255aea4e63 Binary files /dev/null and b/docs/_source/_plots/testrun_network_daily_history_main_mse-1.png differ diff --git a/docs/_source/api.rst b/docs/_source/api.rst new file mode 100644 index 0000000000000000000000000000000000000000..63db2308c7ac34ddaa7e498e62356066d4c2c811 --- /dev/null +++ b/docs/_source/api.rst @@ -0,0 +1,9 @@ +Package Reference +================= + +Information on specific functions, classes, and methods. + +.. toctree:: + :glob: + + _api/* \ No newline at end of file diff --git a/docs/_source/conf.py b/docs/_source/conf.py new file mode 100644 index 0000000000000000000000000000000000000000..6363f57eb45e686f6f2ef8ab07806e4feba0fe2d --- /dev/null +++ b/docs/_source/conf.py @@ -0,0 +1,133 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. +# +import os +import sys + +sys.path.insert(0, os.path.abspath('../..')) + +# -- Project information ----------------------------------------------------- + +project = 'machinelearningtools' +copyright = '2020, Lukas H Leufen, Felix Kleinert' +author = 'Lukas H Leufen, Felix Kleinert' + +# The short X.Y version +version = 'v0.9.0' +# The full version, including alpha/beta/rc tags +release = 'v0.9.0' + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.intersphinx', + 'sphinx.ext.todo', + 'sphinx.ext.coverage', + 'sphinx.ext.imgmath', + 'sphinx.ext.ifconfig', + # 'sphinx.ext.viewcode', + 'sphinx.ext.autosummary', + 'autoapi.extension', + 'sphinx.ext.napoleon', + 'sphinx_rtd_theme', + 'sphinx.ext.githubpages', + 'recommonmark', + 'sphinx.ext.autosectionlabel', + 'sphinx_autodoc_typehints', # must be loaded after napoleon +] + +# 2020-02-19 Begin +# following instruction based on +# https://stackoverflow.com/questions/2701998/sphinx-autodoc-is-not-automatic-enough +autosummary_generate = True + +autoapi_type = 'python' +autoapi_dirs = ['../../src/.'] + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# add asource file parser for markdown +source_parsers = { + '.md': 'recommonmark.parser.CommonMarkParser', +} + +# The suffix(es) of source filenames. +# You can specify multiple suffix as a list of string: +# +source_suffix = ['.rst', '.md'] + +# The master toctree document. +master_doc = 'index' + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +# html_theme = 'alabaster' +# html_theme = 'bizstyle' +# html_theme = 'classic' +html_theme = 'sphinx_rtd_theme' + + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# -- Options for LaTeX output ------------------------------------------------ + +latex_elements = { + # The paper size ('letterpaper' or 'a4paper'). + # + # 'papersize': 'letterpaper', + + # The font size ('10pt', '11pt' or '12pt'). + # + # 'pointsize': '10pt', + + # Additional stuff for the LaTeX preamble. + # + # 'preamble': '', + + # Latex figure (float) alignment + # + # 'figure_align': 'htbp', +} + +# Grouping the document tree into LaTeX files. List of tuples +# (source start file, target name, title, +# author, documentclass [howto, manual, or own class]). +latex_documents = [ + (master_doc, 'machinelearningtools.tex', 'MachineLearningTools Documentation', + author, 'manual'), +] + +# -- Options for intersphinx extension --------------------------------------- + +# Example configuration for intersphinx: refer to the Python standard library. +intersphinx_mapping = { + 'python': ('https://docs.python.org/3', None), + 'pandas': ('http://pandas.pydata.org/pandas-docs/stable/', None), + 'numpy': ('https://docs.scipy.org/doc/numpy/', None), + 'matplotlib': ('https://matplotlib.org/', None) +} diff --git a/docs/_source/get-started.rst b/docs/_source/get-started.rst new file mode 100644 index 0000000000000000000000000000000000000000..e5a82fdcf1d16ca2188a04e3dce76dc7ba9d477a --- /dev/null +++ b/docs/_source/get-started.rst @@ -0,0 +1,16 @@ +Get started with MachineLearningTools +===================================== + +<what is machinelearningtools?> + +MLT Module and Funtion Documentation +------------------------------------ + +Install MachineLearningTools +---------------------------- + +Dependencies +~~~~~~~~~~~~ + +Data +~~~~ diff --git a/docs/_source/index.rst b/docs/_source/index.rst new file mode 100644 index 0000000000000000000000000000000000000000..341ac58acd62ccc5bcf786580fff1bc193170d62 --- /dev/null +++ b/docs/_source/index.rst @@ -0,0 +1,22 @@ +.. machinelearningtools documentation master file, created by + sphinx-quickstart on Wed Apr 15 14:27:29 2020. + You can adapt this file completely to your liking, but it should at least + contain the root `toctree` directive. + +Welcome to machinelearningtools's documentation! +================================================ + +.. toctree:: + :maxdepth: 2 + :caption: Contents: + + get-started + api + + +Indices and tables +================== + +* :ref:`genindex` +* :ref:`modindex` +* :ref:`search` diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000000000000000000000000000000000000..36aafbd3727749c032ec16ed5cffe09359391cb7 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=_source +set BUILDDIR=_build + +if "%1" == "" goto help + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.http://sphinx-doc.org/ + exit /b 1 +) + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/requirements_docs.txt b/docs/requirements_docs.txt new file mode 100644 index 0000000000000000000000000000000000000000..a1294e314d9d04402ba7c063754a56b49deab602 --- /dev/null +++ b/docs/requirements_docs.txt @@ -0,0 +1,5 @@ +sphinx==3.0.3 +sphinx-autoapi==1.3.0 +sphinx-autodoc-typehints==1.10.3 +sphinx-rtd-theme==0.4.3 +recommonmark==0.6.0 \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index b46f44416cf6560ecc0b62f8d22dd7d547a036c6..71bb1338effff38092510982d4a2c1f37f7b026a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -59,6 +59,7 @@ tensorflow-estimator==1.13.0 tensorflow==1.13.1 termcolor==1.1.0 toolz==0.10.0 +typing-extensions urllib3==1.25.8 wcwidth==0.1.8 Werkzeug==1.0.0 diff --git a/requirements_gpu.txt b/requirements_gpu.txt index 6ce4df8fe164408024e21db5ea94a692fb5dbf26..5ddb56acc71e0a51abb99b9447f871ddcb715a5d 100644 --- a/requirements_gpu.txt +++ b/requirements_gpu.txt @@ -59,6 +59,7 @@ tensorflow-estimator==1.13.0 tensorflow-gpu==1.13.1 termcolor==1.1.0 toolz==0.10.0 +typing-extensions urllib3==1.25.8 wcwidth==0.1.8 Werkzeug==1.0.0 diff --git a/run.py b/run.py index bfe1c7ed62d9b2e8a707c117e252ff3f931f339a..572e59e2f75c4068bc63ecb6bb54a687bafebf4b 100644 --- a/run.py +++ b/run.py @@ -1,7 +1,6 @@ __author__ = "Lukas Leufen" __date__ = '2019-11-14' - import argparse from src.run_modules.experiment_setup import ExperimentSetup @@ -14,7 +13,6 @@ from src.run_modules.training import Training def main(parser_args): - with RunEnvironment(): ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'], station_type='background', trainable=False, create_new_model=True, window_history_size=6, diff --git a/run_hourly.py b/run_hourly.py index 3c3135c46df9875633499bd17b237a23cdf6be55..559bf1a1056928f55f9ff3527805da121091d830 100644 --- a/run_hourly.py +++ b/run_hourly.py @@ -1,9 +1,7 @@ __author__ = "Lukas Leufen" __date__ = '2019-11-14' - import argparse -import logging from src.run_modules.experiment_setup import ExperimentSetup from src.run_modules.model_setup import ModelSetup @@ -14,7 +12,6 @@ from src.run_modules.training import Training def main(parser_args): - with RunEnvironment(): ExperimentSetup(parser_args, stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001'], station_type='background', trainable=True, sampling="hourly", window_history_size=48) @@ -28,7 +25,6 @@ def main(parser_args): if __name__ == "__main__": - parser = argparse.ArgumentParser() parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None, help="set experiment date as string") diff --git a/run_zam347.py b/run_zam347.py index 1e140f48188a6df7207e04d048f38d9701c69d4b..d95067bb84a91230b0877f7a2b3d0cac5dc495e1 100644 --- a/run_zam347.py +++ b/run_zam347.py @@ -1,7 +1,6 @@ __author__ = "Lukas Leufen" __date__ = '2019-11-14' - import argparse import json import logging @@ -15,7 +14,6 @@ from src.run_modules.training import Training def load_stations(): - try: filename = 'German_background_stations.json' with open(filename, 'r') as jfile: @@ -31,9 +29,7 @@ def load_stations(): def main(parser_args): - with RunEnvironment(): - ExperimentSetup(parser_args, stations=load_stations(), station_type='background', trainable=False, create_new_model=True) PreProcessing() @@ -46,7 +42,6 @@ def main(parser_args): if __name__ == "__main__": - parser = argparse.ArgumentParser() parser.add_argument('--experiment_date', metavar='--exp_date', type=str, default=None, help="set experiment date as string") diff --git a/src/__init__.py b/src/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..452d0ed8b95a6300a2a47b65be78a5ddf4e968d6 100644 --- a/src/__init__.py +++ b/src/__init__.py @@ -0,0 +1,5 @@ +""" +Test string + +This is all about machine learning tools +""" \ No newline at end of file diff --git a/src/.gitignore b/src/configuration/.gitignore similarity index 100% rename from src/.gitignore rename to src/configuration/.gitignore diff --git a/src/configuration/__init__.py b/src/configuration/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4b461174af6a62f05364ebdff45141cd8ccc05c8 --- /dev/null +++ b/src/configuration/__init__.py @@ -0,0 +1,2 @@ +"""Collection of configuration functions, paths and classes.""" +from .path_config import ROOT_PATH, prepare_host, set_experiment_name, set_bootstrap_path, check_path_and_create \ No newline at end of file diff --git a/src/configuration/join_settings.py b/src/configuration/join_settings.py new file mode 100644 index 0000000000000000000000000000000000000000..22d8b813d6b01c300e37c9d8a0dd4eb343cc87df --- /dev/null +++ b/src/configuration/join_settings.py @@ -0,0 +1,24 @@ +"""Settings to access not public join data.""" +from typing import Tuple, Dict + + +def join_settings(sampling="daily") -> Tuple[str, Dict]: + """ + Set url for join and required headers. + + Headers information is not required for daily resolution. For hourly data "Authorization": "<yourtoken>" is required + to retrieve any data at all. + + :param sampling: temporal resolution to access. Hourly data requires authorisation. + + :return: Service url and optional headers + """ + if sampling == "daily": # pragma: no branch + TOAR_SERVICE_URL = 'https://join.fz-juelich.de/services/rest/surfacedata/' + headers = {} + elif sampling == "hourly": + TOAR_SERVICE_URL = 'https://join.fz-juelich.de/services/rest/surfacedata/' + headers = {"Authorization": "Token 12345"} + else: + raise NameError(f"Given sampling {sampling} is not supported, choose from either daily or hourly sampling.") + return TOAR_SERVICE_URL, headers diff --git a/src/configuration/path_config.py b/src/configuration/path_config.py new file mode 100644 index 0000000000000000000000000000000000000000..289c15821db587b0866eb4808981cfae640cb9a0 --- /dev/null +++ b/src/configuration/path_config.py @@ -0,0 +1,117 @@ +"""Functions related to path and os name setting.""" +import logging +import os +import re +import socket +from typing import Tuple + +ROOT_PATH = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")) + + +def prepare_host(create_new=True, data_path=None, sampling="daily") -> str: + """ + Set up host path. + + INFO: This functions is designed to handle known hosts. For proper working, please add your hostname hardcoded here. + Otherwise parse your custom data_path in kwargs. If data_path is provided, hardcoded paths for known hosts will be + ignored! + + :param create_new: Create new path if enabled + :param data_path: Parse your custom path (and therefore ignore preset paths fitting to known hosts) + :param sampling: sampling rate to separate data physically by temporal resolution + + :return: full path to data + """ + hostname = socket.gethostname() + runner_regex = re.compile(r"runner-.*-project-2411-concurrent-\d+") + try: + user = os.getlogin() + except OSError: + user = "default" + if data_path is None: + if hostname == "ZAM144": + path = f"/home/{user}/Data/toar_{sampling}/" + elif hostname == "zam347": + path = f"/home/{user}/Data/toar_{sampling}/" + elif hostname == "linux-aa9b": + path = f"/home/{user}/machinelearningtools/data/toar_{sampling}/" + elif (len(hostname) > 2) and (hostname[:2] == "jr"): + path = f"/p/project/cjjsc42/{user}/DATA/toar_{sampling}/" + elif (len(hostname) > 2) and (hostname[:2] == "jw"): + path = f"/p/home/jusers/{user}/juwels/intelliaq/DATA/toar_{sampling}/" + elif runner_regex.match(hostname) is not None: + path = f"/home/{user}/machinelearningtools/data/toar_{sampling}/" + else: + raise OSError(f"unknown host '{hostname}'") + else: + path = os.path.abspath(data_path) + if not os.path.exists(path): + try: + if create_new: + check_path_and_create(path) + return path + else: + raise PermissionError + except PermissionError: + raise NotADirectoryError(f"path '{path}' does not exist for host '{hostname}'.") + else: + logging.debug(f"set path to: {path}") + return path + + +def set_experiment_name(experiment_name=None, experiment_path=None, sampling=None) -> Tuple[str, str]: + """ + Set name of experiment and its path. + + * Experiment name is set to `TestExperiment` if not provided in kwargs. If a name is given, this string is expanded + by suffix `_network`. Experiment name is always expanded by `_<sampling>` as ending suffix if sampling is given. + * Experiment path is set to `ROOT_PATH/<exp_name>` if not provided or otherwise use `<experiment_path>/<exp_name>` + + :param experiment_name: custom experiment name + :param experiment_path: custom experiment path + :param sampling: sampling rate as string to add to experiment name + + :return: experiment name and full experiment path + """ + if experiment_name is None: + experiment_name = "TestExperiment" + else: + experiment_name = f"{experiment_name}_network" + if sampling is not None: + experiment_name += f"_{sampling}" + if experiment_path is None: + experiment_path = os.path.abspath(os.path.join(ROOT_PATH, experiment_name)) + else: + experiment_path = os.path.join(os.path.abspath(experiment_path), experiment_name) + return experiment_name, experiment_path + + +def set_bootstrap_path(bootstrap_path: str, data_path: str, sampling: str) -> str: + """ + Set path for bootstrap input data. + + Either use given bootstrap_path or create additional folder in same directory like data path. + + :param bootstrap_path: custom path to store bootstrap data + :param data_path: path of data for default bootstrap path + :param sampling: sampling rate to add, if path is set to default + + :return: full bootstrap path + """ + if bootstrap_path is None: + bootstrap_path = os.path.join(data_path, "..", f"bootstrap_{sampling}") + check_path_and_create(bootstrap_path) + return os.path.abspath(bootstrap_path) + + +def check_path_and_create(path: str) -> None: + """ + Check a given path and create if not existing. + + :param path: path to check and create + """ + try: + os.makedirs(path) + logging.debug(f"Created path: {path}") + except FileExistsError: + logging.debug(f"Path already exists: {path}") \ No newline at end of file diff --git a/src/data_handling/__init__.py b/src/data_handling/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..5139d13da91025dad15db04fdfea34309c4e28ff 100644 --- a/src/data_handling/__init__.py +++ b/src/data_handling/__init__.py @@ -0,0 +1,15 @@ +""" +Data Handling. + +The module data_handling contains all methods and classes that are somehow related to data preprocessing, +postprocessing, loading, and distribution for training. +""" + +__author__ = 'Lukas Leufen, Felix Kleinert' +__date__ = '2020-04-17' + + +from .bootstraps import BootStraps +from .data_preparation import DataPrep +from .data_generator import DataGenerator +from .data_distributor import Distributor diff --git a/src/data_handling/bootstraps.py b/src/data_handling/bootstraps.py index 46fa7c2be39d3dadb1922a1b710065aa42d9e2d2..f50775900c053cbef0c94e6a3e2743c9a017bf88 100644 --- a/src/data_handling/bootstraps.py +++ b/src/data_handling/bootstraps.py @@ -1,41 +1,71 @@ +""" +Collections of bootstrap methods and classes. + +How to use +---------- + +test + +""" + __author__ = 'Felix Kleinert, Lukas Leufen' __date__ = '2020-02-07' -from src.data_handling.data_generator import DataGenerator -import numpy as np import logging -import keras -import dask.array as da -import xarray as xr import os import re -from src import helpers from typing import List, Union, Pattern, Tuple +import dask.array as da +import keras +import numpy as np +import xarray as xr + +from src import helpers +from src.data_handling.data_generator import DataGenerator + class BootStrapGenerator(keras.utils.Sequence): """ + Generator that returns bootstrapped history objects for given boot index while iteration. + generator for bootstraps as keras sequence inheritance. Initialise with number of boots, the original history, the shuffled data, all used variables and the current shuffled variable. While iterating over this generator, it returns the bootstrapped history for given boot index (this is the iterator index) in the same format like the original history ready to use. Note, that in some cases some samples can contain nan values (in these cases the entire data row is null, not only single entries). """ + def __init__(self, number_of_boots: int, history: xr.DataArray, shuffled: xr.DataArray, variables: List[str], shuffled_variable: str): + """ + Set up the generator. + + :param number_of_boots: number of bootstrap realisations + :param history: original history (the ground truth) + :param shuffled: the shuffled history + :param variables: list with all variables of interest + :param shuffled_variable: name of the variable that shall be bootstrapped + """ self.number_of_boots = number_of_boots self.variables = variables self.history_orig = history - self.history = history.sel(variables=helpers.list_pop(self.variables, shuffled_variable)) + self.history = history.sel(variables=helpers.remove_items(self.variables, shuffled_variable)) self.shuffled = shuffled.sel(variables=shuffled_variable) def __len__(self) -> int: + """ + Return number of bootstraps. + + :return: number of bootstraps + """ return self.number_of_boots def __getitem__(self, index: int) -> xr.DataArray: """ - return bootstrapped history for given bootstrap index in same index structure like the original history object + Return bootstrapped history for given bootstrap index in same index structure like the original history object. + :param index: boot index e [0, nboots-1] :return: bootstrapped history ready to use """ @@ -46,7 +76,8 @@ class BootStrapGenerator(keras.utils.Sequence): def __get_shuffled(self, index: int) -> xr.DataArray: """ - returns shuffled data for given boot index from shuffled attribute + Return shuffled data for given boot index from shuffled attribute. + :param index: boot index e [0, nboots-1] :return: shuffled data """ @@ -56,10 +87,20 @@ class BootStrapGenerator(keras.utils.Sequence): class CreateShuffledData: """ - Verify and create shuffled data for all data contained in given data generator class. Starts automatically on - initialisation, no further calls are required. Check and new creations are all performed inside bootstrap_path. + Verify and create shuffled data for all data contained in given data generator class. + + Starts automatically on initialisation, no further calls are required. Check and new creations are all performed + inside bootstrap_path. """ + def __init__(self, data: DataGenerator, number_of_bootstraps: int, bootstrap_path: str): + """ + Shuffled data is automatically created in initialisation. + + :param data: data to shuffle + :param number_of_bootstraps: + :param bootstrap_path: Path to find and store the bootstraps + """ self.data = data self.number_of_bootstraps = number_of_bootstraps self.bootstrap_path = bootstrap_path @@ -67,9 +108,11 @@ class CreateShuffledData: def create_shuffled_data(self) -> None: """ - Create shuffled data. Use original test data, add dimension 'boots' with length number of bootstraps and insert - randomly selected variables. If there is a suitable local file for requested window size and number of - bootstraps, no additional file will be created inside this function. + Create shuffled data. + + Use original test data, add dimension 'boots' with length number of bootstraps and insert randomly selected + variables. If there is a suitable local file for requested window size and number of bootstraps, no additional + file will be created inside this function. """ logging.info("create / check shuffled bootstrap data") variables_str = '_'.join(sorted(self.data.variables)) @@ -92,8 +135,11 @@ class CreateShuffledData: def _set_file_path(self, station: str, variables: str, window: int, nboots: int) -> str: """ + Set file name. + Set file name following naming convention <station>_<var1>_<var2>_..._hist<window>_nboots<nboots>_shuffled.nc - and creates joined path using bootstrap_path attribute set on initialisation. + and create joined path using bootstrap_path attribute set on initialisation. + :param station: station name :param variables: variables already preprocessed as single string with all variables seperated by underscore :param window: window length @@ -105,13 +151,15 @@ class CreateShuffledData: def valid_bootstrap_file(self, station: str, variables: str, window: int) -> [bool, Union[None, int]]: """ - Compare local bootstrap file with given settings for station, variables, window and number of bootstraps. If a - match was found, this method returns a tuple (True, None). In any other case, it returns (False, max_nboot), - where max_nboot is the highest boot number found in the local storage. A match is defined so that the window - length is ge than given window size form args and the number of boots is also ge than the given number of boots - from this class. Furthermore, this functions deletes local files, if the match the station pattern but don't fit - the window and bootstrap condition. This is performed, because it is assumed, that the corresponding file will - be created with a longer or at the least same window size and numbers of bootstraps. + Compare local bootstrap file with given settings for station, variables, window and number of bootstraps. + + If a match was found, this method returns a tuple (True, None). In any other case, it returns (False, + max_nboot), where max_nboot is the highest boot number found in the local storage. A match is defined so that + the window length is ge than given window size form args and the number of boots is also ge than the given + number of boots from this class. Furthermore, this functions deletes local files, if the match the station + pattern but don't fit the window and bootstrap condition. This is performed, because it is assumed, that the + corresponding file will be created with a longer or at the least same window size and numbers of bootstraps. + :param station: name of the station to validate :param variables: all variables already merged in single string seperated by underscore :param window: required window size @@ -136,21 +184,25 @@ class CreateShuffledData: @staticmethod def shuffle(data: da.array, chunks: Tuple) -> da.core.Array: """ - Shuffle randomly from given data (draw elements with replacement) + Shuffle randomly from given data (draw elements with replacement). + :param data: data to shuffle :param chunks: chunk size for dask :return: shuffled data as dask core array (not computed yet) """ size = data.shape - return da.random.choice(data.reshape(-1,), size=size, chunks=chunks) + return da.random.choice(data.reshape(-1, ), size=size, chunks=chunks) class BootStraps: """ - Main class to perform bootstrap operations. This class requires a DataGenerator object and a path, where to find and - store all data related to the bootstrap operation. In initialisation, this class will automatically call the class - CreateShuffleData to set up the shuffled data sets. How to use BootStraps: - * call .get_generator(<station>, <variable>) to get a generator for given station and variable combination that + Main class to perform bootstrap operations. + + This class requires a DataGenerator object and a path, where to find and store all data related to the bootstrap + operation. In initialisation, this class will automatically call the class CreateShuffleData to set up the shuffled + data sets. How to use BootStraps: + + * call .get_generator(<station>, <variable>) to get a generator for given station and variable combination that \ iterates over all bootstrap realisations (as keras sequence) * call .get_labels(<station>) to get the measured observations in the same format as bootstrap predictions * call .get_bootstrap_predictions(<station>, <variable>) to get the bootstrapped predictions @@ -158,6 +210,13 @@ class BootStraps: """ def __init__(self, data: DataGenerator, bootstrap_path: str, number_of_bootstraps: int = 10): + """ + Automatically check and create (if needed) shuffled data on initialisation. + + :param data: a data generator object to get data / history + :param bootstrap_path: path to find and store the bootstrap data + :param number_of_bootstraps: the number of bootstrap realisations + """ self.data = data self.number_of_bootstraps = number_of_bootstraps self.bootstrap_path = bootstrap_path @@ -165,20 +224,38 @@ class BootStraps: @property def stations(self) -> List[str]: + """ + Station property inherits directly from data generator object. + + :return: list with all stations + """ return self.data.stations @property def variables(self) -> List[str]: + """ + Variables property inherits directly from data generator object. + + :return: list with all variables + """ return self.data.variables @property def window_history_size(self) -> int: + """ + Window history size property inherits directly from data generator object. + + :return: the window history size + """ return self.data.window_history_size def get_generator(self, station: str, variable: str) -> BootStrapGenerator: """ - Returns the actual generator to use for the bootstrap evaluation. The generator requires information on station - and bootstrapped variable. There is only a loop on the bootstrap realisation and not on stations or variables. + Return the actual generator to use for the bootstrap evaluation. + + The generator requires information on station and bootstrapped variable. There is only a loop on the bootstrap + realisation and not on stations or variables. + :param station: name of the station :param variable: name of the variable to bootstrap :return: BootStrapGenerator class ready to use. @@ -189,7 +266,8 @@ class BootStraps: def get_labels(self, station: str) -> np.ndarray: """ - Repeats labels for given key by the number of boots and returns as single array. + Repeat labels for given key by the number of boots and returns as single array. + :param station: name of station :return: repeated labels as single array """ @@ -198,7 +276,8 @@ class BootStraps: def get_orig_prediction(self, path: str, file_name: str, prediction_name: str = "CNN") -> np.ndarray: """ - Repeats predictions from given file(_name) in path by the number of boots. + Repeat predictions from given file(_name) in path by the number of boots. + :param path: path to file :param file_name: file name :param prediction_name: name of the prediction to select from loaded file (default CNN) @@ -211,9 +290,11 @@ class BootStraps: def _load_shuffled_data(self, station: str, variables: List[str]) -> xr.DataArray: """ - Load shuffled data from bootstrap path. Data is stored as - '<station>_<var1>_<var2>_..._hist<histsize>_nboots<nboots>_shuffled.nc', e.g. + Load shuffled data from bootstrap path. + + Data is stored as '<station>_<var1>_<var2>_..._hist<histsize>_nboots<nboots>_shuffled.nc', e.g. 'DEBW107_cloudcover_no_no2_temp_u_v_hist13_nboots20_shuffled.nc' + :param station: name of station :param variables: list of variables :return: shuffled data as xarray @@ -224,7 +305,8 @@ class BootStraps: def _get_shuffled_data_file(self, station: str, variables: List[str]) -> str: """ - Looks for data file using regular expressions and returns found file or raise FileNotFoundError + Look for data file using regular expressions and returns found file or raise FileNotFoundError. + :param station: name of station :param variables: name of variables :return: found file with complete path @@ -240,8 +322,11 @@ class BootStraps: @staticmethod def _create_file_regex(station: str, variables: List[str]) -> Pattern: """ - Creates regex for given station and variables to look for shuffled data with pattern: + Create regex for given station and variables. + + With this regex, it is possible to look for shuffled data with pattern: `<station>(_<var>)*_hist(<hist>)_nboots(<nboots>)_shuffled.nc` + :param station: station name to use as prefix :param variables: variables to add after station :return: compiled regular expression @@ -253,10 +338,13 @@ class BootStraps: @staticmethod def _filter_files(regex: Pattern, files: List[str], window: int, nboot: int) -> Union[str, None]: """ - Filter list of files by regex. Regex has to be structured to match the following string structure + Filter list of files by regex. + + Regex has to be structured to match the following string structure `<station>(_<var>)*_hist(<hist>)_nboots(<nboots>)_shuffled.nc`. Hist and nboots values have to be included as group. All matches are compared to given window and nboot parameters. A valid file must have the same value (or larger) than these parameters and contain all variables. + :param regex: compiled regular expression pattern following the style from method description :param files: list of file names to filter :param window: minimum length of window to look for @@ -267,7 +355,7 @@ class BootStraps: match = regex.match(f) if match: last = match.lastindex - if (int(match.group(last-1)) >= window) and (int(match.group(last)) >= nboot): + if (int(match.group(last - 1)) >= window) and (int(match.group(last)) >= nboot): return f diff --git a/src/data_handling/data_distributor.py b/src/data_handling/data_distributor.py index e8c6044280799ded080ab4bff3627aeb9ffde2db..2600afcbd8948c26a2b4cf37329b424cac69f40a 100644 --- a/src/data_handling/data_distributor.py +++ b/src/data_handling/data_distributor.py @@ -1,3 +1,24 @@ +""" +Data Distribution Module. + +How to use +---------- + +Create distributor object from a generator object and parse it to the fit generator method. Provide the number of +steps per epoch with distributor's length method. + +.. code-block:: python + + model = YourKerasModel() + data_generator = DataGenerator(*args, **kwargs) + data_distributor = Distributor(data_generator, model, **kwargs) + history = model.fit_generator(generator=data_distributor.distribute_on_batches(), + steps_per_epoch=len(data_distributor), + epochs=10,) + +Additionally, a validation data set can be parsed using the length and distribute methods. +""" + from __future__ import generator_stop __author__ = "Lukas Leufen, Felix Kleinert" @@ -12,9 +33,20 @@ from src.data_handling.data_generator import DataGenerator class Distributor(keras.utils.Sequence): + """Distribute data generator elements according to mini batch size.""" def __init__(self, generator: DataGenerator, model: keras.models, batch_size: int = 256, permute_data: bool = False, upsampling: bool = False): + """ + Set up distributor. + + :param generator: The generator object must be iterable and return inputs and targets on each iteration + :param model: a keras model with one or more output branches + :param batch_size: batch size to use + :param permute_data: data is randomly permuted if enabled on each train step + :param upsampling: upsample data with upsample extremes data from generator object and shuffle data or use only + the standard input data. + """ self.generator = generator self.model = model self.batch_size = batch_size @@ -38,7 +70,11 @@ class Distributor(keras.utils.Sequence): def _permute_data(self, x, y): """ - Permute inputs x and labels y + Permute inputs x and labels y if permutation is enabled in instance. + + :param x: inputs + :param y: labels + :return: permuted or original data """ if self.do_data_permutation: p = np.random.permutation(len(x)) # equiv to .shape[0] @@ -47,6 +83,17 @@ class Distributor(keras.utils.Sequence): return x, y def distribute_on_batches(self, fit_call=True): + """ + Create generator object to distribute mini batches. + + Split data from given generator object (usually for single station) according to the given batch size. Also + perform upsampling if enabled and random shuffling (either if data permutation is enabled or if upsampling is + enabled). Lastly multiply targets if provided model has multiple output branches. + + :param fit_call: switch to exit while loop after first iteration. This is used to determine the length of all + distributed mini batches. For default, fit_call is True to obtain infinite loop for training. + :return: yields next mini batch + """ while True: for k, v in enumerate(self.generator): # get rank of output @@ -65,15 +112,20 @@ class Distributor(keras.utils.Sequence): num_mini_batches = self._get_number_of_mini_batches(x_total) # permute order for mini-batches x_total, y_total = self._permute_data(x_total, y_total) - for prev, curr in enumerate(range(1, num_mini_batches+1)): - x = x_total[prev*self.batch_size:curr*self.batch_size, ...] - y = [y_total[prev*self.batch_size:curr*self.batch_size, ...] for _ in range(mod_rank)] + for prev, curr in enumerate(range(1, num_mini_batches + 1)): + x = x_total[prev * self.batch_size:curr * self.batch_size, ...] + y = [y_total[prev * self.batch_size:curr * self.batch_size, ...] for _ in range(mod_rank)] if x is not None: # pragma: no branch - yield (x, y) + yield x, y if (k + 1) == len(self.generator) and curr == num_mini_batches and not fit_call: return - def __len__(self): + def __len__(self) -> int: + """ + Total number of distributed mini batches. + + :return: the length of the distribute on batches object + """ num_batch = 0 for _ in self.distribute_on_batches(fit_call=False): num_batch += 1 diff --git a/src/data_handling/data_generator.py b/src/data_handling/data_generator.py index 8d10b3e438e185b9fd158259a6ba49a5612737be..6747e82e0da2d1a68c99a09d75c76cdcd53a05ba 100644 --- a/src/data_handling/data_generator.py +++ b/src/data_handling/data_generator.py @@ -1,36 +1,86 @@ +"""Data Generator class to handle large arrays for machine learning.""" + __author__ = 'Felix Kleinert, Lukas Leufen' __date__ = '2019-11-07' +import logging import os +import pickle from typing import Union, List, Tuple, Any, Dict import dask.array as da import keras import xarray as xr -import pickle -import logging from src import helpers from src.data_handling.data_preparation import DataPrep -from src.join import EmptyQueryResult +from src.helpers.join import EmptyQueryResult number = Union[float, int] num_or_list = Union[number, List[number]] +data_or_none = Union[xr.DataArray, None] class DataGenerator(keras.utils.Sequence): """ - This class is a generator to handle large arrays for machine learning. This class can be used with keras' - fit_generator and predict_generator. Individual stations are the iterables. This class uses class Dataprep and - returns X, y when an item is called. - Item can be called manually by position (integer) or station id (string). Methods also accept lists with exactly - one entry of integer or string + This class is a generator to handle large arrays for machine learning. + + .. code-block:: python + + data_generator = DataGenerator(**args, **kwargs) + + Data generator item can be called manually by position (integer) or station id (string). Methods also accept lists + with exactly one entry of integer or string. + + .. code-block:: + + # select generator elements by position index + first_element = data_generator.get_data_generator([0]) # 1st element + n_element = data_generator.get_data_generator([4]) # 5th element + + # select by name + station_xy = data_generator.get_data_generator(["station_xy"]) # will raise KeyError if not available + + If used as iterator or directly called by get item method, the data generator class returns transposed labels and + history object from underlying data preparation class DataPrep. + + .. code-block:: python + + # select history and label by position + hist, labels = data_generator[0] + # by name + hist, labels = data_generator["station_xy"] + # as iterator + for (hist, labels) in data_generator: + pass + + This class can also be used with keras' fit_generator and predict_generator. Individual stations are the iterables. """ def __init__(self, data_path: str, network: str, stations: Union[str, List[str]], variables: List[str], interpolate_dim: str, target_dim: str, target_var: str, station_type: str = None, interpolate_method: str = "linear", limit_nan_fill: int = 1, window_history_size: int = 7, window_lead_time: int = 4, transformation: Dict = None, extreme_values: num_or_list = None, **kwargs): + """ + Set up data generator. + + :param data_path: path to data + :param network: the observational network, the data should come from + :param stations: list with all stations to include + :param variables: list with all used variables + :param interpolate_dim: dimension along which interpolation is applied + :param target_dim: dimension of target variable + :param target_var: name of target variable + :param station_type: TOAR station type classification (background, traffic) + :param interpolate_method: method of interpolation + :param limit_nan_fill: maximum gab in data to fill by interpolation + :param window_history_size: length of the history window + :param window_lead_time: lenght of the label window + :param transformation: transformation method to apply on data + :param extreme_values: set up the extreme value upsampling + :param kwargs: additional kwargs that are used in either DataPrep (transformation, start / stop period, ...) + or extreme values + """ self.data_path = os.path.abspath(data_path) self.data_path_tmp = os.path.join(os.path.abspath(data_path), "tmp") if not os.path.exists(self.data_path_tmp): @@ -51,34 +101,30 @@ class DataGenerator(keras.utils.Sequence): self.transformation = self.setup_transformation(transformation) def __repr__(self): - """ - display all class attributes - """ + """Display all class attributes.""" return f"DataGenerator(path='{self.data_path}', network='{self.network}', stations={self.stations}, " \ f"variables={self.variables}, station_type={self.station_type}, " \ f"interpolate_dim='{self.interpolate_dim}', target_dim='{self.target_dim}', " \ f"target_var='{self.target_var}', **{self.kwargs})" def __len__(self): - """ - display the number of stations - """ + """Return the number of stations.""" return len(self.stations) def __iter__(self) -> "DataGenerator": """ - Define the __iter__ part of the iterator protocol to iterate through this generator. Sets the private attribute - `_iterator` to 0. - :return: + Define the __iter__ part of the iterator protocol to iterate through this generator. + + Sets the private attribute `_iterator` to 0. """ self._iterator = 0 return self def __next__(self) -> Tuple[xr.DataArray, xr.DataArray]: """ - This is the implementation of the __next__ method of the iterator protocol. Get the data generator, and return - the history and label data of this generator. - :return: + Get the data generator, and return the history and label data of this generator. + + This is the implementation of the __next__ method of the iterator protocol. """ if self._iterator < self.__len__(): data = self.get_data_generator() @@ -92,14 +138,37 @@ class DataGenerator(keras.utils.Sequence): def __getitem__(self, item: Union[str, int]) -> Tuple[xr.DataArray, xr.DataArray]: """ - Defines the get item method for this generator. Retrieve data from generator and return history and labels. + Define the get item method for this generator. + + Retrieve data from generator and return history and labels. + :param item: station key to choose the data generator. :return: The generator's time series of history data and its labels """ data = self.get_data_generator(key=item) return data.get_transposed_history(), data.get_transposed_label() - def setup_transformation(self, transformation): + def setup_transformation(self, transformation: Dict): + """ + Set up transformation by extracting all relevant information. + + Extract all information from transformation dictionary. Possible keys are scope. method, mean, and std. Scope + can either be station or data. Station scope means, that data transformation is performed for each station + independently (somehow like batch normalisation), whereas data scope means a transformation applied on the + entire data set. + + * If using data scope, mean and standard deviation (each only if required by transformation method) can either + be calculated accurate or as an estimate (faster implementation). This must be set in dictionary either + as "mean": "accurate" or "mean": "estimate". In both cases, the required statistics are calculated and saved. + After this calculations, the mean key is overwritten by the actual values to use. + * If using station scope, no additional information is required. + * If a transformation should be applied on base of existing values, these need to be provided in the respective + keys "mean" and "std" (again only if required for given method). + + :param transformation: the transformation dictionary as described above. + + :return: updated transformation dictionary + """ if transformation is None: return transformation = transformation.copy() @@ -125,7 +194,17 @@ class DataGenerator(keras.utils.Sequence): transformation["std"] = std return transformation - def calculate_accurate_transformation(self, method): + def calculate_accurate_transformation(self, method: str) -> Tuple[data_or_none, data_or_none]: + """ + Calculate accurate transformation statistics. + + Use all stations of this generator and calculate mean and standard deviation on entire data set using dask. + Because there can be much data, this can take a while. + + :param method: name of transformation method + + :return: accurate calculated mean and std (depending on transformation) + """ tmp = [] mean = None std = None @@ -149,7 +228,22 @@ class DataGenerator(keras.utils.Sequence): return mean, std def calculate_estimated_transformation(self, method): - data = [[]]*len(self.variables) + """ + Calculate estimated transformation statistics. + + Use all stations of this generator and calculate mean and standard deviation first for each station separately. + Afterwards, calculate the average mean and standard devation as estimated statistics. Because this method does + not consider the length of each data set, the estimated mean distinguishes from the real data mean. Furthermore, + the estimated standard deviation is assumed to be the mean (also not weighted) of all deviations. But this is + mathematically not true, but still a rough and faster estimation of the true standard deviation. Do not use this + method for further statistical calculation. However, in the scope of data preparation for machine learning, this + approach is decent ("it is just scaling"). + + :param method: name of transformation method + + :return: accurate calculated mean and std (depending on transformation) + """ + data = [[]] * len(self.variables) coords = {"variables": self.variables, "Stations": range(0)} mean = xr.DataArray(data, coords=coords, dims=["variables", "Stations"]) std = xr.DataArray(data, coords=coords, dims=["variables", "Stations"]) @@ -168,12 +262,23 @@ class DataGenerator(keras.utils.Sequence): def get_data_generator(self, key: Union[str, int] = None, load_local_tmp_storage: bool = True, save_local_tmp_storage: bool = True) -> DataPrep: """ - Select data for given key, create a DataPrep object and interpolate, transform, make history and labels and - remove nans. + Create DataPrep object and preprocess data for given key. + + Select data for given key, create a DataPrep object and + * apply transformation (optional) + * interpolate + * make history, labels, and observation + * remove nans + * upsample extremes (optional). + Processed data can be stored locally in a .pickle file. If load local tmp storage is enabled, the get data + generator tries first to load data from local pickle file and only creates a new DataPrep object if it couldn't + load this data from disk. + :param key: station key to choose the data generator. :param load_local_tmp_storage: say if data should be processed from scratch or loaded as already processed data from tmp pickle file to save computational time (but of course more disk space required). :param save_local_tmp_storage: save processed data as temporal file locally (default True) + :return: preprocessed data as a DataPrep instance """ station = self.get_station_key(key) @@ -186,13 +291,13 @@ class DataGenerator(keras.utils.Sequence): data = DataPrep(self.data_path, self.network, station, self.variables, station_type=self.station_type, **self.kwargs) if self.transformation is not None: - data.transform("datetime", **helpers.dict_pop(self.transformation, "scope")) + data.transform("datetime", **helpers.remove_items(self.transformation, "scope")) data.interpolate(self.interpolate_dim, method=self.interpolate_method, limit=self.limit_nan_fill) data.make_history_window(self.target_dim, self.window_history_size, self.interpolate_dim) data.make_labels(self.target_dim, self.target_var, self.interpolate_dim, self.window_lead_time) data.make_observation(self.target_dim, self.target_var, self.interpolate_dim) data.remove_nan(self.interpolate_dim) - if self.extreme_values: + if self.extreme_values is not None: kwargs = {"extremes_on_right_tail_only": self.kwargs.get("extremes_on_right_tail_only", False)} data.multiply_extremes(self.extreme_values, **kwargs) if save_local_tmp_storage: @@ -201,7 +306,8 @@ class DataGenerator(keras.utils.Sequence): def _save_pickle_data(self, data: Any): """ - Save given data locally as .pickle in self.data_path_tmp with name '<station>_<var1>_<var2>_..._<varX>.pickle' + Save given data locally as .pickle in self.data_path_tmp with name '<station>_<var1>_<var2>_..._<varX>.pickle'. + :param data: any data, that should be saved """ date = f"{self.kwargs.get('start')}_{self.kwargs.get('end')}" @@ -215,6 +321,7 @@ class DataGenerator(keras.utils.Sequence): def _load_pickle_data(self, station: Union[str, List[str]], variables: List[str]) -> Any: """ Load locally saved data from self.data_path_tmp and name '<station>_<var1>_<var2>_..._<varX>.pickle'. + :param station: station to load :param variables: list of variables to load :return: loaded data @@ -230,7 +337,8 @@ class DataGenerator(keras.utils.Sequence): def get_station_key(self, key: Union[None, str, int, List[Union[None, str, int]]]) -> str: """ - Return a valid station key or raise KeyError if this wasn't possible + Return a valid station key or raise KeyError if this wasn't possible. + :param key: station key to choose the data generator. :return: station key (id from database) """ diff --git a/src/data_handling/data_preparation.py b/src/data_handling/data_preparation.py index 5628394271918dc5631182d7de610db4ad335b7f..bb5254572e400b89a219ec674f408f09350f849c 100644 --- a/src/data_handling/data_preparation.py +++ b/src/data_handling/data_preparation.py @@ -1,45 +1,52 @@ +"""Data Preparation class to handle data processing for machine learning.""" + __author__ = 'Felix Kleinert, Lukas Leufen' __date__ = '2019-10-16' import datetime as dt -from functools import reduce import logging import os +from functools import reduce from typing import Union, List, Iterable, Tuple import numpy as np import pandas as pd import xarray as xr -from src import join, helpers -from src import statistics +from src.configuration import check_path_and_create +from src import helpers +from src.helpers import join, statistics # define a more general date type for type hinting date = Union[dt.date, dt.datetime] str_or_list = Union[str, List[str]] number = Union[float, int] num_or_list = Union[number, List[number]] +data_or_none = Union[xr.DataArray, None] class DataPrep(object): """ - This class prepares data to be used in neural networks. The instance searches for local stored data, that meet the - given demands. If no local data is found, the DataPrep instance will load data from TOAR database and store this - data locally to use the next time. For the moment, there is only support for daily aggregated time series. The - aggregation can be set manually and differ for each variable. + This class prepares data to be used in neural networks. + + The instance searches for local stored data, that meet the given demands. If no local data is found, the DataPrep + instance will load data from TOAR database and store this data locally to use the next time. For the moment, there + is only support for daily aggregated time series. The aggregation can be set manually and differ for each variable. After data loading, different data pre-processing steps can be executed to prepare the data for further applications. Especially the following methods can be used for the pre-processing step: + - interpolate: interpolate between data points by using xarray's interpolation method - - standardise: standardise data to mean=1 and std=1, centralise to mean=0, additional methods like normalise on - interval [0, 1] are not implemented yet. + - standardise: standardise data to mean=1 and std=1, centralise to mean=0, additional methods like normalise on \ + interval [0, 1] are not implemented yet. - make window history: represent the history (time steps before) for training/ testing; X - make labels: create target vector with given leading time steps for training/ testing; y - - remove Nans jointly from desired input and output, only keeps time steps where no NaNs are present in X AND y. Use - this method after the creation of the window history and labels to clean up the data cube. + - remove Nans jointly from desired input and output, only keeps time steps where no NaNs are present in X AND y. \ + Use this method after the creation of the window history and labels to clean up the data cube. To create a DataPrep instance, it is needed to specify the stations by id (e.g. "DEBW107"), its network (e.g. UBA, "Umweltbundesamt") and the variables to use. Further options can be set in the instance. + * `statistics_per_var`: define a specific statistic to extract from the TOAR database for each variable. * `start`: define a start date for the data cube creation. Default: Use the first entry in time series * `end`: set the end date for the data cube. Default: Use last date in time series. @@ -50,18 +57,19 @@ class DataPrep(object): def __init__(self, path: str, network: str, station: Union[str, List[str]], variables: List[str], station_type: str = None, **kwargs): + """Construct instance.""" self.path = os.path.abspath(path) self.network = network self.station = helpers.to_list(station) self.variables = variables self.station_type = station_type - self.mean = None - self.std = None - self.history = None - self.label = None - self.observation = None - self.extremes_history = None - self.extremes_label = None + self.mean: data_or_none = None + self.std: data_or_none = None + self.history: data_or_none = None + self.label: data_or_none = None + self.observation: data_or_none = None + self.extremes_history: data_or_none = None + self.extremes_label: data_or_none = None self.kwargs = kwargs self.data = None self.meta = None @@ -75,10 +83,13 @@ class DataPrep(object): def load_data(self): """ - Load data and meta data either from local disk (preferred) or download new data from TOAR database if no local - data is available. The latter case, store downloaded data locally if wished (default yes). + Load data and meta data either from local disk (preferred) or download new data from TOAR database. + + Data is either downloaded, if no local data is available or parameter overwrite_local_data is true. In both + cases, downloaded data is only stored locally if store_data_locally is not disabled. If this parameter is not + set, it is assumed, that data should be saved locally. """ - helpers.check_path_and_create(self.path) + check_path_and_create(self.path) file_name = self._set_file_name() meta_file = self._set_meta_file_name() if self.kwargs.get('overwrite_local_data', False): @@ -104,14 +115,25 @@ class DataPrep(object): logging.debug("loaded new data from JOIN") def download_data(self, file_name, meta_file): + """ + Download data from join, create slices and check for negative concentration. + + Handle sequence of required operation on new data downloads. First, download data using class method + download_data_from_join. Second, slice data using _slice_prep and lastly check for negative concentrations in + data with check_for_negative_concentrations. Finally, data is stored in instance attribute data. + + :param file_name: name of file to save data to (containing full path) + :param meta_file: name of the meta data file (also containing full path) + """ data, self.meta = self.download_data_from_join(file_name, meta_file) data = self._slice_prep(data) self.data = self.check_for_negative_concentrations(data) def check_station_meta(self): """ - Search for the entries in meta data and compare the value with the requested values. Raise a FileNotFoundError - if the values mismatch. + Search for the entries in meta data and compare the value with the requested values. + + Will raise a FileNotFoundError if the values mismatch. """ check_dict = {"station_type": self.station_type, "network_name": self.network} for (k, v) in check_dict.items(): @@ -124,9 +146,14 @@ class DataPrep(object): def download_data_from_join(self, file_name: str, meta_file: str) -> [xr.DataArray, pd.DataFrame]: """ Download data from TOAR database using the JOIN interface. - :param file_name: - :param meta_file: - :return: + + Data is transformed to a xarray dataset. If class attribute store_data_locally is true, data is additionally + stored locally using given names for file and meta file. + + :param file_name: name of file to save data to (containing full path) + :param meta_file: name of the meta data file (also containing full path) + + :return: downloaded data and its meta data """ df_all = {} df, meta = join.download_join(station_name=self.station, stat_var=self.statistics_per_var, @@ -150,15 +177,17 @@ class DataPrep(object): return os.path.join(self.path, f"{''.join(self.station)}_{'_'.join(all_vars)}_meta.csv") def __repr__(self): + """Represent class attributes.""" return f"Dataprep(path='{self.path}', network='{self.network}', station={self.station}, " \ f"variables={self.variables}, station_type={self.station_type}, **{self.kwargs})" def interpolate(self, dim: str, method: str = 'linear', limit: int = None, use_coordinate: Union[bool, str] = True, **kwargs): """ - (Copy paste from dataarray.interpolate_na) Interpolate values according to different methods. + (Copy paste from dataarray.interpolate_na) + :param dim: Specifies the dimension along which to interpolate. :param method: @@ -187,14 +216,24 @@ class DataPrep(object): used. If use_coordinate is a string, it specifies the name of a coordinate variariable to use as the index. :param kwargs: + :return: xarray.DataArray """ - self.data = self.data.interpolate_na(dim=dim, method=method, limit=limit, use_coordinate=use_coordinate, **kwargs) @staticmethod - def check_inverse_transform_params(mean, std, method) -> None: + def check_inverse_transform_params(mean: data_or_none, std: data_or_none, method: str) -> None: + """ + Support inverse_transformation method. + + Validate if all required statistics are available for given method. E.g. centering requires mean only, whereas + normalisation requires mean and standard deviation. Will raise an AttributeError on missing requirements. + + :param mean: data with all mean values + :param std: data with all standard deviation values + :param method: name of transformation method + """ msg = "" if method in ['standardise', 'centre'] and mean is None: msg += "mean, " @@ -205,8 +244,12 @@ class DataPrep(object): def inverse_transform(self) -> None: """ - Perform inverse transformation - :return: + Perform inverse transformation. + + Will raise an AssertionError, if no transformation was performed before. Checks first, if all required + statistics are available for inverse transformation. Class attributes data, mean and std are overwritten by + new data afterwards. Thereby, mean, std, and the private transform method are set to None to indicate, that the + current data is not transformed. """ def f_inverse(data, mean, std, method_inverse): @@ -225,8 +268,11 @@ class DataPrep(object): self.data, self.mean, self.std = f_inverse(self.data, self.mean, self.std, self._transform_method) self._transform_method = None - def transform(self, dim: Union[str, int] = 0, method: str = 'standardise', inverse: bool = False, mean = None, std=None) -> None: + def transform(self, dim: Union[str, int] = 0, method: str = 'standardise', inverse: bool = False, mean=None, + std=None) -> None: """ + Transform data according to given transformation settings. + This function transforms a xarray.dataarray (along dim) or pandas.DataFrame (along axis) either with mean=0 and std=1 (`method=standardise`) or centers the data with mean=0 and no change in data scale (`method=centre`). Furthermore, this sets an internal instance attribute for later inverse transformation. This @@ -239,6 +285,7 @@ class DataPrep(object): :param method: Choose the transformation method from 'standardise' and 'centre'. 'normalise' is not implemented yet. This param is not used for inverse transformation. :param inverse: Switch between transformation and inverse transformation. + :return: xarray.DataArrays or pandas.DataFrames: #. mean: Mean of data #. std: Standard deviation of data @@ -273,7 +320,18 @@ class DataPrep(object): else: self.inverse_transform() - def get_transformation_information(self, variable): + def get_transformation_information(self, variable: str) -> Tuple[data_or_none, data_or_none, str]: + """ + Extract transformation statistics and method. + + Get mean and standard deviation for given variable and the transformation method if set. If a transformation + depends only on particular statistics (e.g. only mean is required for centering), the remaining statistics are + returned with None as fill value. + + :param variable: Variable for which the information on transformation is requested. + + :return: mean, standard deviation and transformation method + """ try: mean = self.mean.sel({'variables': variable}).values except AttributeError: @@ -286,8 +344,10 @@ class DataPrep(object): def make_history_window(self, dim_name_of_inputs: str, window: int, dim_name_of_shift: str) -> None: """ - This function uses shifts the data window+1 times and returns a xarray which has a new dimension 'window' - containing the shifted data. This is used to represent history in the data. Results are stored in self.history . + Create a xr.DataArray containing history data. + + Shift the data window+1 times and return a xarray which has a new dimension 'window' containing the shifted + data. This is used to represent history in the data. Results are stored in history attribute. :param dim_name_of_inputs: Name of dimension which contains the input variables :param window: number of time steps to look back in history @@ -301,11 +361,12 @@ class DataPrep(object): def shift(self, dim: str, window: int) -> xr.DataArray: """ - This function uses xarray's shift function multiple times to represent history (if window <= 0) - or lead time (if window > 0) + Shift data multiple times to represent history (if window <= 0) or lead time (if window > 0). + :param dim: dimension along shift is applied :param window: number of steps to shift (corresponds to the window length) - :return: + + :return: shifted data """ start = 1 end = 1 @@ -320,9 +381,13 @@ class DataPrep(object): res = xr.concat(res, dim=window_array) return res - def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str, window: int) -> None: + def make_labels(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str, + window: int) -> None: """ - This function creates a xarray.DataArray containing labels + Create a xr.DataArray containing labels. + + Labels are defined as the consecutive target values (t+1, ...t+n) following the current time step t. Set label + attribute. :param dim_name_of_target: Name of dimension which contains the target variable :param target_var: Name of target variable in 'dimension' @@ -334,28 +399,31 @@ class DataPrep(object): def make_observation(self, dim_name_of_target: str, target_var: str_or_list, dim_name_of_shift: str) -> None: """ - This function creates a xarray.DataArray containing labels + Create a xr.DataArray containing observations. - :param dim_name_of_target: Name of dimension which contains the target variable - :param target_var: Name of target variable(s) in 'dimension' + Observations are defined as value of the current time step t. Set observation attribute. + + :param dim_name_of_target: Name of dimension which contains the observation variable + :param target_var: Name of observation variable(s) in 'dimension' :param dim_name_of_shift: Name of dimension on which xarray.DataArray.shift will be applied """ self.observation = self.shift(dim_name_of_shift, 0).sel({dim_name_of_target: target_var}) def remove_nan(self, dim: str) -> None: """ - All NAs slices in dim which contain nans in self.history or self.label are removed in both data sets. - This is done to present only a full matrix to keras.fit. + Remove all NAs slices along dim which contain nans in history, label and observation. - :param dim: - :return: + This is done to present only a full matrix to keras.fit. Update history, label, and observation attribute. + + :param dim: dimension along the remove is performed. """ intersect = [] if (self.history is not None) and (self.label is not None): non_nan_history = self.history.dropna(dim=dim) non_nan_label = self.label.dropna(dim=dim) non_nan_observation = self.observation.dropna(dim=dim) - intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values, non_nan_observation.coords[dim].values)) + intersect = reduce(np.intersect1d, (non_nan_history.coords[dim].values, non_nan_label.coords[dim].values, + non_nan_observation.coords[dim].values)) min_length = self.kwargs.get("min_length", 0) if len(intersect) < max(min_length, 1): @@ -370,11 +438,12 @@ class DataPrep(object): @staticmethod def create_index_array(index_name: str, index_value: Iterable[int]) -> xr.DataArray: """ - This Function crates a 1D xarray.DataArray with given index name and value + Create an 1D xr.DataArray with given index name and value. + + :param index_name: name of dimension + :param index_value: values of this dimension - :param index_name: - :param index_value: - :return: + :return: this array """ ind = pd.DataFrame({'val': index_value}, index=index_value) res = xr.Dataset.from_dataframe(ind).to_array().rename({'index': index_name}).squeeze(dim='variable', drop=True) @@ -383,10 +452,12 @@ class DataPrep(object): def _slice_prep(self, data: xr.DataArray, coord: str = 'datetime') -> xr.DataArray: """ - This function prepares all settings for slicing and executes _slice - :param data: + Set start and end date for slicing and execute self._slice(). + + :param data: data to slice :param coord: name of axis to slice - :return: + + :return: sliced data """ start = self.kwargs.get('start', data.coords[coord][0].values) end = self.kwargs.get('end', data.coords[coord][-1].values) @@ -395,22 +466,29 @@ class DataPrep(object): @staticmethod def _slice(data: xr.DataArray, start: Union[date, str], end: Union[date, str], coord: str) -> xr.DataArray: """ - This function slices through a given data_item (for example select only values of 2011) - :param data: - :param start: - :param end: + Slice through a given data_item (for example select only values of 2011). + + :param data: data to slice + :param start: start date of slice + :param end: end date of slice :param coord: name of axis to slice - :return: + + :return: sliced data """ return data.loc[{coord: slice(str(start), str(end))}] def check_for_negative_concentrations(self, data: xr.DataArray, minimum: int = 0) -> xr.DataArray: """ - This function sets all negative concentrations to zero. Names of all concentrations are extracted from - https://join.fz-juelich.de/services/rest/surfacedata/ #2.1 Parameters - :param data: - :param minimum: - :return: + Set all negative concentrations to zero. + + Names of all concentrations are extracted from https://join.fz-juelich.de/services/rest/surfacedata/ + #2.1 Parameters. Currently, this check is applied on "benzene", "ch4", "co", "ethane", "no", "no2", "nox", + "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", and "toluene". + + :param data: data array containing variables to check + :param minimum: minimum value, by default this should be 0 + + :return: corrected data """ chem_vars = ["benzene", "ch4", "co", "ethane", "no", "no2", "nox", "o3", "ox", "pm1", "pm10", "pm2p5", "propane", "so2", "toluene"] @@ -419,20 +497,38 @@ class DataPrep(object): return data def get_transposed_history(self) -> xr.DataArray: + """Return history. + + :return: history with dimensions datetime, window, Stations, variables. + """ return self.history.transpose("datetime", "window", "Stations", "variables").copy() def get_transposed_label(self) -> xr.DataArray: + """Return label. + + :return: label with dimensions datetime, window, Stations, variables. + """ return self.label.squeeze("Stations").transpose("datetime", "window").copy() def get_extremes_history(self) -> xr.DataArray: + """Return extremes history. + + :return: extremes history with dimensions datetime, window, Stations, variables. + """ return self.extremes_history.transpose("datetime", "window", "Stations", "variables").copy() - def get_extremes_label(self): + def get_extremes_label(self) -> xr.DataArray: + """Return extremes label. + + :return: extremes label with dimensions datetime, window, Stations, variables. + """ return self.extremes_label.squeeze("Stations").transpose("datetime", "window").copy() def multiply_extremes(self, extreme_values: num_or_list = 1., extremes_on_right_tail_only: bool = False, timedelta: Tuple[int, str] = (1, 'm')): """ + Multiply extremes. + This method extracts extreme values from self.labels which are defined in the argument extreme_values. One can also decide only to extract extremes on the right tail of the distribution. When extreme_values is a list of floats/ints all values larger (and smaller than negative extreme_values; extraction is performed in standardised @@ -447,7 +543,6 @@ class DataPrep(object): if True only extract values larger than extreme_values :param timedelta: used as arguments for np.timedelta in order to mark extreme values on datetime """ - # check if labels or history is None if (self.label is None) or (self.history is None): logging.debug(f"{self.station} has `None' labels, skip multiply extremes") @@ -465,7 +560,7 @@ class DataPrep(object): if (self.extremes_label is None) or (self.extremes_history is None): # extract extremes based on occurance in labels if extremes_on_right_tail_only: - extreme_label_idx = (self.label > extr_val).any(axis=0).values.reshape(-1,) + extreme_label_idx = (self.label > extr_val).any(axis=0).values.reshape(-1, ) else: extreme_label_idx = np.concatenate(((self.label < -extr_val).any(axis=0).values.reshape(-1, 1), (self.label > extr_val).any(axis=0).values.reshape(-1, 1)), @@ -474,15 +569,16 @@ class DataPrep(object): extremes_history = self.history[..., extreme_label_idx, :] extremes_label.datetime.values += np.timedelta64(*timedelta) extremes_history.datetime.values += np.timedelta64(*timedelta) - self.extremes_label = extremes_label#.squeeze('Stations').transpose('datetime', 'window') - self.extremes_history = extremes_history#.transpose('datetime', 'window', 'Stations', 'variables') + self.extremes_label = extremes_label # .squeeze('Stations').transpose('datetime', 'window') + self.extremes_history = extremes_history # .transpose('datetime', 'window', 'Stations', 'variables') else: # one extr value iteration is done already: self.extremes_label is NOT None... if extremes_on_right_tail_only: extreme_label_idx = (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, ) else: - extreme_label_idx = np.concatenate(((self.extremes_label < -extr_val).any(axis=0).values.reshape(-1, 1), - (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, 1) - ), axis=1).any(axis=1) + extreme_label_idx = np.concatenate( + ((self.extremes_label < -extr_val).any(axis=0).values.reshape(-1, 1), + (self.extremes_label > extr_val).any(axis=0).values.reshape(-1, 1) + ), axis=1).any(axis=1) # check on existing extracted extremes to minimise computational costs for comparison extremes_label = self.extremes_label[..., extreme_label_idx] extremes_history = self.extremes_history[..., extreme_label_idx, :] diff --git a/src/helpers/__init__.py b/src/helpers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..546713b3f18f2cb64c1527b57d1e9e2138e927aa --- /dev/null +++ b/src/helpers/__init__.py @@ -0,0 +1,6 @@ +"""Collection of different supporting functions and classes.""" + +from .testing import PyTestRegex, PyTestAllEqual +from .time_tracking import TimeTracking, TimeTrackingWrapper +from .logger import Logger +from .helpers import remove_items, float_round, dict_to_xarray, to_list diff --git a/src/datastore.py b/src/helpers/datastore.py similarity index 60% rename from src/datastore.py rename to src/helpers/datastore.py index fb1650808a72f2a4d8b6afc10940cd9d14f894ba..b4615216000d887f16e6ed30d97215a261e12c6d 100644 --- a/src/datastore.py +++ b/src/helpers/datastore.py @@ -1,48 +1,60 @@ +"""Implementation of experiment's data store.""" + +__all__ = ['DataStoreByVariable', 'DataStoreByScope', 'NameNotFoundInDataStore', 'NameNotFoundInScope', 'EmptyScope', + 'AbstractDataStore'] __author__ = 'Lukas Leufen' __date__ = '2019-11-22' - -from abc import ABC -from functools import wraps import inspect +import logging import types +from abc import ABC +from functools import wraps from typing import Any, List, Tuple, Dict class NameNotFoundInDataStore(Exception): - """ - Exception that get raised if given name is not found in the entire data store. - """ + """Exception that get raised if given name is not found in the entire data store.""" + pass class NameNotFoundInScope(Exception): - """ - Exception that get raised if given name is not found in the provided scope, but can be found in other scopes. - """ + """Exception that get raised if given name is not found in the provided scope, but can be found in other scopes.""" + pass class EmptyScope(Exception): - """ - Exception that get raised if given scope is not part of the data store. - """ + """Exception that get raised if given scope is not part of the data store.""" + pass class CorrectScope: """ - This class is used as decorator for all class methods, that have scope in parameters. After decoration, the scope - argument is not required on method call anymore. If no scope parameter is given, this decorator automatically adds - the default scope=`general` to the arguments. Furthermore, calls like `scope=general.sub` are obsolete, because this - decorator adds the prefix `general.` if not provided. Therefore, a call like `scope=sub` will actually become - `scope=general.sub` after passing this decorator. + This class is used as decorator for all class methods, that have scope in parameters. + + After decoration, the scope argument is not required on method call anymore. If no scope parameter is given, this + decorator automatically adds the default scope=`general` to the arguments. Furthermore, calls like + `scope=general.sub` are obsolete, because this decorator adds the prefix `general.` if not provided. Therefore, a + call like `scope=sub` will actually become `scope=general.sub` after passing this decorator. """ def __init__(self, func): + """Construct decorator.""" + setattr(self, "wrapper", func) + if hasattr(func, "__wrapped__"): + func = func.__wrapped__ wraps(func)(self) def __call__(self, *args, **kwargs): + """ + Call method of decorator. + + Update tuple if scope argument does not start with `general` or slot `scope=general` into args if not provided + in neither args nor kwargs. + """ f_arg = inspect.getfullargspec(self.__wrapped__) pos_scope = f_arg.args.index("scope") if len(args) < (len(f_arg.args) - len(f_arg.defaults or "")): @@ -50,16 +62,19 @@ class CorrectScope: args = self.update_tuple(args, new_arg, pos_scope) else: args = self.update_tuple(args, args[pos_scope], pos_scope, update=True) - return self.__wrapped__(*args, **kwargs) + return self.wrapper(*args, **kwargs) def __get__(self, instance, cls): + """Create bound method object and supply self argument to the decorated method.""" return types.MethodType(self, instance) @staticmethod def correct(arg: str): """ - adds leading general prefix + Add leading general prefix. + :param arg: string argument of scope to add prefix general if necessary + :return: corrected string """ if not arg.startswith("general"): @@ -68,51 +83,119 @@ class CorrectScope: def update_tuple(self, t: Tuple, new: Any, ind: int, update: bool = False): """ - Either updates a entry in given tuple t (<old1>, <old2>, <old3>) --(ind=1)--> (<old1>, <new>, <old3>) or slots + Update single entry n given tuple or slot entry into given position. + + Either update a entry in given tuple t (<old1>, <old2>, <old3>) --(ind=1)--> (<old1>, <new>, <old3>) or slot entry into given position (<old1>, <old2>, <old3>) --(ind=1,update=True)--> (<old1>, <new>, <old2>, <old3>). In the latter case, length of returned tuple is increased by 1 in comparison to given tuple. + :param t: tuple to update :param new: new element to add to tuple :param ind: position to add or slot in :param update: updates entry if true, otherwise slot in (default: False) + :return: updated tuple """ t_new = (*t[:ind], self.correct(new), *t[ind + update:]) return t_new -class AbstractDataStore(ABC): +class TrackParameter: + + def __init__(self, func): + """Construct decorator.""" + wraps(func)(self) + + def __call__(self, *args, **kwargs): + """ + Call method of decorator. + """ + self.track(*args) + return self.__wrapped__(*args, **kwargs) + + def __get__(self, instance, cls): + """Create bound method object and supply self argument to the decorated method.""" + return types.MethodType(self, instance) + + def track(self, tracker_obj, *args): + name, obj, scope = self._decrypt_args(*args) + logging.debug(f"{self.__wrapped__.__name__}: {name}({scope})={obj}") + tracker = tracker_obj.tracker[-1] + new_entry = {"method": self.__wrapped__.__name__, "scope": scope} + if name in tracker: + tracker[name].append(new_entry) + else: + tracker[name] = [new_entry] + @staticmethod + def _decrypt_args(*args): + if len(args) == 2: + return args[0], None, args[1] + else: + return args + + +class AbstractDataStore(ABC): """ - Data store for all settings for the experiment workflow to save experiment parameters for the proceeding run_modules - and predefine parameters loaded during the experiment setup phase. The data store is hierarchically structured, so - that global settings can be overwritten by local adjustments. + Abstract data store for all settings for the experiment workflow. + + Save experiment parameters for the proceeding run_modules and predefine parameters loaded during the experiment + setup phase. The data store is hierarchically structured, so that global settings can be overwritten by local + adjustments. """ + + tracker = [{}] + def __init__(self): - # empty initialise the data-store variables + """Initialise by creating empty data store.""" self._store: Dict = {} - def set(self, name: str, obj: Any, scope: str) -> None: + def set(self, name: str, obj: Any, scope: str, log: bool = False) -> None: """ - Abstract method to add an object to the data store + Abstract method to add an object to the data store. + :param name: Name of object to store :param obj: The object itself to be stored :param scope: the scope / context of the object, under that the object is valid + :param log: log which objects are stored if enabled (default false) """ pass def get(self, name: str, scope: str) -> None: """ - Abstract method to get an object from the data store + Abstract method to get an object from the data store. + :param name: Name to look for :param scope: scope to search the name for :return: the stored object """ pass + @CorrectScope + def get_default(self, name: str, scope: str, default: Any) -> Any: + """ + Retrieve an object with `name` from `scope` and return given default if object wasn't found. + + Same functionality like the standard get method. But this method adds a default argument that is returned if no + data was stored in the data store. Use this function with care, because it will not report any errors and just + return the given default value. Currently, there is no statement that reports, if the returned value comes from + the data store or the default value. + + :param name: Name to look for + :param scope: scope to search the name for + :param default: default value that is return, if no data was found for given name and scope + + :return: the stored object or the default value + """ + try: + return self.get(name, scope) + except (NameNotFoundInDataStore, NameNotFoundInScope): + return default + def search_name(self, name: str) -> None: """ Abstract method to search for all occurrences of given `name` in the entire data store. + :param name: Name to look for :return: search result """ @@ -120,7 +203,8 @@ class AbstractDataStore(ABC): def search_scope(self, scope: str) -> None: """ - Abstract method to search for all object names that are stored for given scope + Abstract method to search for all object names that are stored for given scope. + :param scope: scope to look for :return: search result """ @@ -128,22 +212,44 @@ class AbstractDataStore(ABC): def list_all_scopes(self) -> None: """ - Abstract method to list all scopes in data store + Abstract method to list all scopes in data store. + :return: all found scopes """ pass def list_all_names(self) -> None: """ - List all names available in the data store. + Abstract method to list all names available in the data store. + :return: all names """ pass def clear_data_store(self) -> None: + """ + Reset entire data store. + + Warning: This will remove all entries of the data store without any exception. + """ self._store = {} - def create_args_dict(self, arg_list: List[str], scope: str = "general") -> Dict: + @CorrectScope + def create_args_dict(self, arg_list: List[str], scope: str) -> Dict: + """ + Create dictionary from given argument list (as keys) and the stored data inside data store (as values). + + Try to load all stored elements for `arg_list` and create an entry in return dictionary for each valid key + value pair. Not existing keys from arg_list are skipped. This method works on a single scope only and cannot + create a dictionary with values from different scopes. Depending on the implementation of the __get__ method, + all superior scopes are included in the parameter search, if no element is found for the given subscope. + + :param arg_list: list with all elements to look for + :param scope: the scope to search in + + :return: dictionary with all valid elements from given arg_list as key and the corresponding stored object as + value. + """ args = {} for arg in arg_list: try: @@ -152,71 +258,81 @@ class AbstractDataStore(ABC): pass return args - def set_args_from_dict(self, arg_dict: Dict, scope: str = "general") -> None: + @CorrectScope + def set_from_dict(self, arg_dict: Dict, scope: str, log: bool = False) -> None: + """ + Store multiple objects from dictionary under same `scope`. + + Each object needs to be parsed as key value pair inside the given dictionary. All new entries are stored under + the same scope. + + :param arg_dict: updates for the data store, provided as key value pairs + :param scope: scope to store updates + :param log: log which objects are stored if enabled (default false) + """ for (k, v) in arg_dict.items(): - self.set(k, v, scope) + self.set(k, v, scope, log=log) class DataStoreByVariable(AbstractDataStore): - """ - Data store for all settings for the experiment workflow to save experiment parameters for the proceeding run_modules - and predefine parameters loaded during the experiment setup phase. The data store is hierarchically structured, so - that global settings can be overwritten by local adjustments. + Data store for all settings for the experiment workflow. + + Save experiment parameters for the proceeding run_modules and predefine parameters loaded during the experiment + setup phase. The data store is hierarchically structured, so that global settings can be overwritten by local + adjustments. This implementation stores data as - <variable1> - <scope1>: value - <scope2>: value - <variable2> - <scope1>: value - <scope3>: value + + .. code-block:: + + <variable1> + <scope1>: value + <scope2>: value + <variable2> + <scope1>: value + <scope3>: value + """ @CorrectScope - def set(self, name: str, obj: Any, scope: str) -> None: + @TrackParameter + def set(self, name: str, obj: Any, scope: str, log: bool = False) -> None: """ - Store an object `obj` with given `name` under `scope`. In the current implementation, existing entries are - overwritten. + Store an object `obj` with given `name` under `scope`. + + In the current implementation, existing entries are overwritten. + :param name: Name of object to store :param obj: The object itself to be stored :param scope: the scope / context of the object, under that the object is valid + :param log: log which objects are stored if enabled (default false) """ # open new variable related store with `name` as key if not existing if name not in self._store.keys(): self._store[name] = {} self._store[name][scope] = obj + if log: + logging.debug(f"set: {name}({scope})={obj}") @CorrectScope + @TrackParameter def get(self, name: str, scope: str) -> Any: """ - Retrieve an object with `name` from `scope`. If no object can be found in the exact scope, take an iterative - look on the levels above. Raises a NameNotFoundInDataStore error, if no object with given name can be found in - the entire data store. Raises a NameNotFoundInScope error, if the object is in the data store but not in the - given scope and its levels above (could be either included in another scope or a more detailed sub-scope). + Retrieve an object with `name` from `scope`. + + If no object can be found in the exact scope, take an iterative look on the levels above. Raise a + NameNotFoundInDataStore error, if no object with given name can be found in the entire data store. Raise a + NameNotFoundInScope error, if the object is in the data store but not in the given scope and its levels above + (could be either included in another scope or a more detailed sub-scope). + :param name: Name to look for :param scope: scope to search the name for + :return: the stored object """ return self._stride_through_scopes(name, scope)[2] - @CorrectScope - def get_default(self, name: str, scope: str, default: Any) -> Any: - """ - Same functionality like the standard get method. But this method adds a default argument that is returned if no - data was stored in the data store. Use this function with care, because it will not report any errors and just - return the given default value. Currently, there is no statement that reports, if the returned value comes from - the data store or the default value. - :param name: Name to look for - :param scope: scope to search the name for - :param default: default value that is return, if no data was found for given name and scope - :return: the stored object or the default value - """ - try: - return self._stride_through_scopes(name, scope)[2] - except (NameNotFoundInDataStore, NameNotFoundInScope): - return default - @CorrectScope def _stride_through_scopes(self, name, scope, depth=0): if depth <= scope.count("."): @@ -236,7 +352,9 @@ class DataStoreByVariable(AbstractDataStore): def search_name(self, name: str) -> List[str]: """ Search for all occurrences of given `name` in the entire data store. + :param name: Name to look for + :return: list with all scopes and sub-scopes containing an object stored as `name` """ return sorted(self._store[name] if name in self._store.keys() else []) @@ -244,12 +362,16 @@ class DataStoreByVariable(AbstractDataStore): @CorrectScope def search_scope(self, scope: str, current_scope_only=True, return_all=False) -> List[str or Tuple]: """ - Search for given `scope` and list all object names stored under this scope. To look also for all superior scopes - set `current_scope_only=False`. To return the scope and the object's value too, set `return_all=True`. + Search for given `scope` and list all object names stored under this scope. + + For an expanded search in all superior scopes, set `current_scope_only=False`. To return the scope and the + object's value too, set `return_all=True`. + :param scope: scope to look for :param current_scope_only: look only for all names for given scope if true, else search for names from superior scopes too. :param return_all: return name, definition scope and value if True, else just the name + :return: list with all object names (if `return_all=False`) or list with tuple of object name, object scope and object value ordered by name (if `return_all=True`) """ @@ -284,7 +406,8 @@ class DataStoreByVariable(AbstractDataStore): def list_all_scopes(self) -> List[str]: """ - List all available scopes in data store + List all available scopes in data store. + :return: names of all stored objects """ scopes = [] @@ -297,70 +420,70 @@ class DataStoreByVariable(AbstractDataStore): def list_all_names(self) -> List[str]: """ List all names available in the data store. + :return: all names """ return sorted(self._store.keys()) class DataStoreByScope(AbstractDataStore): - """ - Data store for all settings for the experiment workflow to save experiment parameters for the proceeding run_modules - and predefine parameters loaded during the experiment setup phase. The data store is hierarchically structured, so - that global settings can be overwritten by local adjustments. + Data store for all settings for the experiment workflow. + + Save experiment parameters for the proceeding run_modules and predefine parameters loaded during the experiment + setup phase. The data store is hierarchically structured, so that global settings can be overwritten by local + adjustments. This implementation stores data as - <scope1> - <variable1>: value - <variable2>: value - <scope2> - <variable1>: value - <variable3>: value + + .. code-block:: + + <scope1> + <variable1>: value + <variable2>: value + <scope2> + <variable1>: value + <variable3>: value + """ @CorrectScope - def set(self, name: str, obj: Any, scope: str) -> None: + @TrackParameter + def set(self, name: str, obj: Any, scope: str, log: bool = False) -> None: """ - Store an object `obj` with given `name` under `scope`. In the current implementation, existing entries are - overwritten. + Store an object `obj` with given `name` under `scope`. + + In the current implementation, existing entries are overwritten. + :param name: Name of object to store :param obj: The object itself to be stored :param scope: the scope / context of the object, under that the object is valid + :param log: log which objects are stored if enabled (default false) """ if scope not in self._store.keys(): self._store[scope] = {} self._store[scope][name] = obj + if log: + logging.debug(f"set: {name}({scope})={obj}") @CorrectScope + @TrackParameter def get(self, name: str, scope: str) -> Any: """ - Retrieve an object with `name` from `scope`. If no object can be found in the exact scope, take an iterative - look on the levels above. Raises a NameNotFoundInDataStore error, if no object with given name can be found in - the entire data store. Raises a NameNotFoundInScope error, if the object is in the data store but not in the - given scope and its levels above (could be either included in another scope or a more detailed sub-scope). + Retrieve an object with `name` from `scope`. + + If no object can be found in the exact scope, take an iterative look on the levels above. Raise a + NameNotFoundInDataStore error, if no object with given name can be found in the entire data store. Raise a + NameNotFoundInScope error, if the object is in the data store but not in the given scope and its levels above + (could be either included in another scope or a more detailed sub-scope). + :param name: Name to look for :param scope: scope to search the name for + :return: the stored object """ return self._stride_through_scopes(name, scope)[2] - @CorrectScope - def get_default(self, name: str, scope: str, default: Any) -> Any: - """ - Same functionality like the standard get method. But this method adds a default argument that is returned if no - data was stored in the data store. Use this function with care, because it will not report any errors and just - return the given default value. Currently, there is no statement that reports, if the returned value comes from - the data store or the default value. - :param name: Name to look for - :param scope: scope to search the name for - :param default: default value that is return, if no data was found for given name and scope - :return: the stored object or the default value - """ - try: - return self._stride_through_scopes(name, scope)[2] - except (NameNotFoundInDataStore, NameNotFoundInScope): - return default - @CorrectScope def _stride_through_scopes(self, name, scope, depth=0): if depth <= scope.count("."): @@ -380,7 +503,9 @@ class DataStoreByScope(AbstractDataStore): def search_name(self, name: str) -> List[str]: """ Search for all occurrences of given `name` in the entire data store. + :param name: Name to look for + :return: list with all scopes and sub-scopes containing an object stored as `name` """ keys = [] @@ -392,12 +517,16 @@ class DataStoreByScope(AbstractDataStore): @CorrectScope def search_scope(self, scope: str, current_scope_only: bool = True, return_all: bool = False) -> List[str or Tuple]: """ - Search for given `scope` and list all object names stored under this scope. To look also for all superior scopes - set `current_scope_only=False`. To return the scope and the object's value too, set `return_all=True`. + Search for given `scope` and list all object names stored under this scope. + + For an expanded search in all superior scopes, set `current_scope_only=False`. To return the scope and the + object's value too, set `return_all=True`. + :param scope: scope to look for :param current_scope_only: look only for all names for given scope if true, else search for names from superior scopes too. :param return_all: return name, definition scope and value if True, else just the name + :return: list with all object names (if `return_all=False`) or list with tuple of object name, object scope and object value ordered by name (if `return_all=True`) """ @@ -428,7 +557,8 @@ class DataStoreByScope(AbstractDataStore): def list_all_scopes(self) -> List[str]: """ - List all available scopes in data store + List all available scopes in data store. + :return: names of all stored objects """ return sorted(self._store.keys()) @@ -436,6 +566,7 @@ class DataStoreByScope(AbstractDataStore): def list_all_names(self) -> List[str]: """ List all names available in the data store. + :return: all names """ names = [] diff --git a/src/helpers/helpers.py b/src/helpers/helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..968ee5385f5a44cdbbce5653a864875011874150 --- /dev/null +++ b/src/helpers/helpers.py @@ -0,0 +1,94 @@ +"""Collection of different help functions.""" +__author__ = 'Lukas Leufen, Felix Kleinert' +__date__ = '2019-10-21' + +import inspect +import math + +import xarray as xr + +from typing import Dict, Callable, Union, List, Any + + +def to_list(obj: Any) -> List: + """ + Transform given object to list if obj is not already a list. + + :param obj: object to transform to list + + :return: list containing obj, or obj itself (if obj was already a list) + """ + if not isinstance(obj, list): + obj = [obj] + return obj + + +def dict_to_xarray(d: Dict, coordinate_name: str) -> xr.DataArray: + """ + Convert a dictionary of 2D-xarrays to single 3D-xarray. The name of new coordinate axis follows <coordinate_name>. + + :param d: dictionary with 2D-xarrays + :param coordinate_name: name of the new created axis (2D -> 3D) + + :return: combined xarray + """ + xarray = None + for k, v in d.items(): + if xarray is None: + xarray = v + xarray.coords[coordinate_name] = k + else: + tmp_xarray = v + tmp_xarray.coords[coordinate_name] = k + xarray = xr.concat([xarray, tmp_xarray], coordinate_name) + return xarray + + +def float_round(number: float, decimals: int = 0, round_type: Callable = math.ceil) -> float: + """ + Perform given rounding operation on number with the precision of decimals. + + :param number: the number to round + :param decimals: numbers of decimals of the rounding operations (default 0 -> round to next integer value) + :param round_type: the actual rounding operation. Can be any callable function like math.ceil, math.floor or python + built-in round operation. + + :return: rounded number with desired precision + """ + multiplier = 10. ** decimals + return round_type(number * multiplier) / multiplier + + +def remove_items(obj: Union[List, Dict], items: Any): + """ + Remove item(s) from either list or dictionary. + + :param obj: object to remove items from (either dictionary or list) + :param items: elements to remove from obj. Can either be a list or single entry / key + + :return: object without items + """ + + def remove_from_list(list_obj, item_list): + """Remove implementation for lists.""" + if len(items) > 1: + return [e for e in list_obj if e not in item_list] + else: + list_obj = list_obj.copy() + try: + list_obj.remove(item_list[0]) + except ValueError: + pass + return list_obj + + def remove_from_dict(dict_obj, key_list): + """Remove implementation for dictionaries.""" + return {k: v for k, v in dict_obj.items() if k not in key_list} + + items = to_list(items) + if isinstance(obj, list): + return remove_from_list(obj, items) + elif isinstance(obj, dict): + return remove_from_dict(obj, items) + else: + raise TypeError(f"{inspect.stack()[0][3]} does not support type {type(obj)}.") diff --git a/src/join.py b/src/helpers/join.py similarity index 86% rename from src/join.py rename to src/helpers/join.py index 351060f7bf4949801f94b04c13e3881f008389b6..1b2abb6c8fe9d0db2dd45636f230cc9a2e232f7c 100644 --- a/src/join.py +++ b/src/helpers/join.py @@ -1,7 +1,7 @@ +"""Functions to access join database.""" __author__ = 'Felix Kleinert, Lukas Leufen' __date__ = '2019-10-16' - import datetime as dt import logging from typing import Iterator, Union, List, Dict @@ -10,32 +10,30 @@ import pandas as pd import requests from src import helpers -from src.join_settings import join_settings +from src.configuration.join_settings import join_settings # join_url_base = 'https://join.fz-juelich.de/services/rest/surfacedata/' str_or_none = Union[str, None] class EmptyQueryResult(Exception): - """ - Exception that get raised if a query to JOIN returns empty results. - """ + """Exception that get raised if a query to JOIN returns empty results.""" + pass def download_join(station_name: Union[str, List[str]], stat_var: dict, station_type: str = None, network_name: str = None, sampling: str = "daily") -> [pd.DataFrame, pd.DataFrame]: - """ - read data from JOIN/TOAR + Read data from JOIN/TOAR. + :param station_name: Station name e.g. DEBY122 :param stat_var: key as variable like 'O3', values as statistics on keys like 'mean' :param station_type: set the station type like "traffic" or "background", can be none :param network_name: set the measurement network like "UBA" or "AIRBASE", can be none :param sampling: sampling rate of the downloaded data, either set to daily or hourly (default daily) - :returns: - - df - data frame with all variables and statistics - - meta - data frame with all meta information + + :returns: data frame with all variables and statistics and meta data frame with all meta information """ # make sure station_name parameter is a list station_name = helpers.to_list(station_name) @@ -89,10 +87,13 @@ def download_join(station_name: Union[str, List[str]], stat_var: dict, station_t def correct_data_format(data): """ - Transform to the standard data format. For some cases (e.g. hourly data), the data is returned as list instead of - a dictionary with keys datetime, values and metadata. This functions addresses this issue and transforms the data - into the dictionary version. + Transform to the standard data format. + + For some cases (e.g. hourly data), the data is returned as list instead of a dictionary with keys datetime, values + and metadata. This functions addresses this issue and transforms the data into the dictionary version. + :param data: data in hourly format + :return: the same data but formatted to fit with aggregated format """ formatted = {"datetime": [], @@ -106,10 +107,13 @@ def correct_data_format(data): def get_data(opts: Dict, headers: Dict) -> Union[Dict, List]: """ - Download join data using requests framework. Data is returned as json like structure. Depending on the response - structure, this can lead to a list or dictionary. + Download join data using requests framework. + + Data is returned as json like structure. Depending on the response structure, this can lead to a list or dictionary. + :param opts: options to create the request url :param headers: additional headers information like authorization, can be empty + :return: requested data (either as list or dictionary) """ url = create_url(**opts) @@ -121,6 +125,7 @@ def load_series_information(station_name: List[str], station_type: str_or_none, join_url_base: str, headers: Dict) -> Dict: """ List all series ids that are available for given station id and network name. + :param station_name: Station name e.g. DEBW107 :param station_type: station type like "traffic" or "background" :param network_name: measurement network of the station like "UBA" or "AIRBASE" @@ -138,11 +143,15 @@ def load_series_information(station_name: List[str], station_type: str_or_none, def _save_to_pandas(df: Union[pd.DataFrame, None], data: dict, stat: str, var: str) -> pd.DataFrame: """ - Save given data in data frame. If given data frame is not empty, the data is appened as new column. + Save given data in data frame. + + If given data frame is not empty, the data is appened as new column. + :param df: data frame to append the new data, can be none :param data: new data to append or format as data frame containing the keys 'datetime' and '<stat>' :param stat: extracted statistic to get values from data (e.g. 'mean', 'dma8eu') :param var: variable the data is from (e.g. 'o3') + :return: new created or concatenated data frame """ if len(data["datetime"][0]) == 19: @@ -159,9 +168,12 @@ def _save_to_pandas(df: Union[pd.DataFrame, None], data: dict, stat: str, var: s def _correct_stat_name(stat: str) -> str: """ - Map given statistic name to new namespace defined by mapping dict. Return given name stat if not element of mapping - namespace. + Map given statistic name to new namespace defined by mapping dict. + + Return given name stat if not element of mapping namespace. + :param stat: namespace from JOIN server + :return: stat mapped to local namespace """ mapping = {'average_values': 'mean', 'maximum': 'max', 'minimum': 'min'} @@ -170,8 +182,10 @@ def _correct_stat_name(stat: str) -> str: def _lower_list(args: List[str]) -> Iterator[str]: """ - lower all elements of given list + Lower all elements of given list. + :param args: list with string entries to lower + :return: iterator that lowers all list entries """ for string in args: @@ -180,10 +194,12 @@ def _lower_list(args: List[str]) -> Iterator[str]: def create_url(base: str, service: str, **kwargs: Union[str, int, float, None]) -> str: """ - create a request url with given base url, service type and arbitrarily many additional keyword arguments + Create a request url with given base url, service type and arbitrarily many additional keyword arguments. + :param base: basic url of the rest service :param service: service type, e.g. series, stats :param kwargs: keyword pairs for optional request specifications, e.g. 'statistics=maximum' + :return: combined url as string """ if not base.endswith("/"): diff --git a/src/helpers/logger.py b/src/helpers/logger.py new file mode 100644 index 0000000000000000000000000000000000000000..51ecde41192cb3a2838e443c3c338c5ac4e29b4d --- /dev/null +++ b/src/helpers/logger.py @@ -0,0 +1,70 @@ +"""Logger class.""" +import logging +import os +import time +from ..configuration import ROOT_PATH + + +class Logger: + """ + Basic logger class to unify all logging outputs. + + Logs are saved in local file and returned to std output. In default settings, logging level of file logger is DEBUG, + logging level of stream logger is INFO. Class must be imported and initialised in starting script, all subscripts + should log with logging.info(), debug, ... + """ + + def __init__(self, log_path=None, level_file=logging.DEBUG, level_stream=logging.INFO): + """Construct logger.""" + # define shared logger format + self.formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]' + + # set log path + self.log_file = self.setup_logging_path(log_path) + # set root logger as file handler + logging.basicConfig(level=level_file, + format=self.formatter, + filename=self.log_file, + filemode='a') + # add stream handler to the root logger + logging.getLogger('').addHandler(self.logger_console(level_stream)) + # print logger path + logging.info(f"File logger: {self.log_file}") + + @staticmethod + def setup_logging_path(path: str = None): + """ + Check if given path exists and creates if not. + + If path is None, use path from main. The logging file is named like `logging_<runtime>.log` where + runtime=`%Y-%m-%d_%H-%M-%S` of current run. + + :param path: path to logfile + + :return: path of logfile + """ + if not path: # set default path + path = os.path.join(ROOT_PATH, "logging") + if not os.path.exists(path): + os.makedirs(path) + runtime = time.strftime("%Y-%m-%d_%H-%M-%S", time.localtime()) + log_file = os.path.join(path, f'logging_{runtime}.log') + return log_file + + def logger_console(self, level: int): + """ + Define a stream handler which writes messages of given level or higher to std out. + + :param level: logging level as integer, e.g. logging.DEBUG or 10 + + :return: defines stream handler + """ + # define Handler + console = logging.StreamHandler() + # set level of Handler + console.setLevel(level) + # set a format which is simpler for console use + formatter = logging.Formatter(self.formatter) + # tell the handler to use this format + console.setFormatter(formatter) + return console \ No newline at end of file diff --git a/src/statistics.py b/src/helpers/statistics.py similarity index 52% rename from src/statistics.py rename to src/helpers/statistics.py index 6510097fc3c31645bc0fa053a5ade05c3e4d908d..056f92bec25b8d5216988f4dacb8fcd1e5257ab5 100644 --- a/src/statistics.py +++ b/src/helpers/statistics.py @@ -1,6 +1,6 @@ -from scipy import stats +"""Collection of stastical methods: Transformation and Skill Scores.""" -from src.run_modules.run_environment import RunEnvironment +from scipy import stats __author__ = 'Lukas Leufen, Felix Kleinert' __date__ = '2019-10-23' @@ -8,13 +8,22 @@ __date__ = '2019-10-23' import numpy as np import xarray as xr import pandas as pd -from typing import Union, Tuple - +from typing import Union, Tuple, Dict Data = Union[xr.DataArray, pd.DataFrame] -def apply_inverse_transformation(data, mean, std=None, method="standardise"): +def apply_inverse_transformation(data: Data, mean: Data, std: Data = None, method: str = "standardise") -> Data: + """ + Apply inverse transformation for given statistics. + + :param data: transform this data back + :param mean: mean of transformation + :param std: standard deviation of transformation (optional) + :param method: transformation method + + :return: inverse transformed data + """ if method == 'standardise': # pragma: no branch return standardise_inverse(data, mean, std) elif method == 'centre': # pragma: no branch @@ -28,87 +37,134 @@ def apply_inverse_transformation(data, mean, std=None, method="standardise"): def standardise(data: Data, dim: Union[str, int]) -> Tuple[Data, Data, Data]: """ - This function standardises a xarray.dataarray (along dim) or pandas.DataFrame (along axis) with mean=0 and std=1 - :param data: - :param string/int dim: - | for xarray.DataArray as string: name of dimension which should be standardised - | for pandas.DataFrame as int: axis of dimension which should be standardised - :return: xarray.DataArrays or pandas.DataFrames: - #. mean: Mean of data - #. std: Standard deviation of data - #. data: Standardised data + Standardise a xarray.dataarray (along dim) or pandas.DataFrame (along axis) with mean=0 and std=1. + + :param data: data to standardise + :param dim: name (xarray) or axis (pandas) of dimension which should be standardised + :return: mean, standard deviation and standardised data """ return data.mean(dim), data.std(dim), (data - data.mean(dim)) / data.std(dim) def standardise_inverse(data: Data, mean: Data, std: Data) -> Data: """ - This is the inverse function of `standardise` and therefore vanishes the standardising. - :param data: - :param mean: - :param std: - :return: + Apply inverse function of `standardise` on data and therefore vanishes the standardising. + + :param data: standardised data + :param mean: mean of standardisation + :param std: standard deviation of transformation + + :return: inverse standardised data """ return data * std + mean def standardise_apply(data: Data, mean: Data, std: Data) -> Data: """ - This applies `standardise` on data using given mean and std. - :param data: - :param mean: - :param std: - :return: + Apply `standardise` on data using given mean and std. + + :param data: data to transform + :param mean: mean to use for transformation + :param std: standard deviation for transformation + + :return: transformed data """ return (data - mean) / std def centre(data: Data, dim: Union[str, int]) -> Tuple[Data, None, Data]: """ - This function centres a xarray.dataarray (along dim) or pandas.DataFrame (along axis) to mean=0 - :param data: - :param string/int dim: - | for xarray.DataArray as string: name of dimension which should be standardised - | for pandas.DataFrame as int: axis of dimension which should be standardised - :return: xarray.DataArrays or pandas.DataFrames: - #. mean: Mean of data - #. std: Standard deviation of data - #. data: Standardised data + Centre a xarray.dataarray (along dim) or pandas.DataFrame (along axis) to mean=0. + + :param data: data to centre + :param dim: name (xarray) or axis (pandas) of dimension which should be centred + + :return: mean, None placeholder and centred data """ return data.mean(dim), None, data - data.mean(dim) def centre_inverse(data: Data, mean: Data) -> Data: """ - This function is the inverse function of `centre` and therefore adds the given values of mean to the data. - :param data: - :param mean: - :return: + Apply inverse function of `centre` and therefore add given values of mean to data. + + :param data: data to apply inverse centering + :param mean: mean to use for inverse transformation + + :return: inverted centering transformation data """ return data + mean def centre_apply(data: Data, mean: Data) -> Data: """ - This applies `centre` on data using given mean and std. - :param data: - :param mean: - :param std: - :return: + Apply `centre` on data using given mean. + + :param data: data to transform + :param mean: mean to use for transformation + + :return: transformed data """ return data - mean def mean_squared_error(a, b): + """Calculate mean squared error.""" return np.square(a - b).mean() class SkillScores: + r""" + Calculate different kinds of skill scores. + + Skill score on MSE: + Calculate skill score based on MSE for given forecast, reference and observations. + + .. math:: - def __init__(self, internal_data): + \text{SkillScore} = 1 - \frac{\text{MSE(obs, for)}}{\text{MSE(obs, ref)}} + + To run: + + .. code-block:: python + + skill_scores = SkillScores(None).general_skill_score(data, observation_name, forecast_name, reference_name) + + Competitive skill score: + Calculate skill scores to highlight differences between forecasts. This skill score is also based on the MSE. + Currently required forecasts are CNN, OLS and persi, as well as the observation obs. + + .. code-block:: python + + skill_scores_class = SkillScores(internal_data) # must contain columns CNN, OLS, persi and obs. + skill_scores = skill_scores_class.skill_scores(window_lead_time=3) + + Skill score according to Murphy: + Follow climatological skill score definition of Murphy (1988). External data is data from another time period + than the internal data set on initialisation. In other terms, this should be the train and validation data + whereas the internal data is the test data. This sounds perhaps counter-intuitive, but if a skill score is + evaluated to a model to another, this must be performend test data set. Therefore, for this case the foreign + data is train and val data. + + .. code-block:: python + + skill_scores_class = SkillScores(internal_data) # must contain columns obs and CNN. + skill_scores_clim = skill_scores_class.climatological_skill_scores(external_data, window_lead_time=3) + + """ + + def __init__(self, internal_data: Data): + """Set internal data.""" self.internal_data = internal_data - def skill_scores(self, window_lead_time): + def skill_scores(self, window_lead_time: int) -> pd.DataFrame: + """ + Calculate skill scores for all combinations of CNN, persistence and OLS. + + :param window_lead_time: length of forecast steps + + :return: skill score for each comparison and forecast step + """ ahead_names = list(range(1, window_lead_time + 1)) skill_score = pd.DataFrame(index=['cnn-persi', 'ols-persi', 'cnn-ols']) for iahead in ahead_names: @@ -118,7 +174,18 @@ class SkillScores: self.general_skill_score(data, forecast_name="CNN", reference_name="OLS")] return skill_score - def climatological_skill_scores(self, external_data, window_lead_time): + def climatological_skill_scores(self, external_data: Data, window_lead_time: int) -> xr.DataArray: + """ + Calculate climatological skill scores according to Murphy (1988). + + Calculate all CASES I - IV and terms [ABC][I-IV]. Internal data has to be set by initialisation, external data + is part of parameters. + + :param external_data: external data + :param window_lead_time: interested time step of forecast horizon to select data + + :return: all CASES as well as all terms + """ ahead_names = list(range(1, window_lead_time + 1)) all_terms = ['AI', 'AII', 'AIII', 'AIV', 'BI', 'BII', 'BIV', 'CI', 'CIV', 'CASE I', 'CASE II', 'CASE III', @@ -147,12 +214,24 @@ class SkillScores: return skill_score - def _climatological_skill_score(self, data, mu_type=1, observation_name="obs", forecast_name="CNN", external_data=None): + def _climatological_skill_score(self, data, mu_type=1, observation_name="obs", forecast_name="CNN", + external_data=None): kwargs = {"external_data": external_data} if external_data is not None else {} return self.__getattribute__(f"skill_score_mu_case_{mu_type}")(data, observation_name, forecast_name, **kwargs) @staticmethod - def general_skill_score(data, observation_name="obs", forecast_name="CNN", reference_name="persi"): + def general_skill_score(data: Data, observation_name: str = "obs", forecast_name: str = "CNN", + reference_name: str = "persi") -> np.ndarray: + r""" + Calculate general skill score based on mean squared error. + + :param data: internal data containing data for observation, forecast and reference + :param observation_name: name of observation + :param forecast_name: name of forecast + :param reference_name: name of reference + + :return: skill score of forecast + """ data = data.dropna("index") observation = data.sel(type=observation_name) forecast = data.sel(type=forecast_name) @@ -162,14 +241,28 @@ class SkillScores: return skill_score.values @staticmethod - def skill_score_pre_calculations(data, observation_name, forecast_name): - + def skill_score_pre_calculations(data: Data, observation_name: str, forecast_name: str) -> Tuple[np.ndarray, + np.ndarray, + np.ndarray, + Data, + Dict[str, Data]]: + """ + Calculate terms AI, BI, and CI, mean, variance and pearson's correlation and clean up data. + + The additional information on mean, variance and pearson's correlation (and the p-value) are returned as + dictionary with the corresponding keys mean, sigma, r and p. + + :param data: internal data to use for calculations + :param observation_name: name of observation + :param forecast_name: name of forecast + + :returns: Terms AI, BI, and CI, internal data without nans and mean, variance, correlation and its p-value + """ data = data.loc[..., [observation_name, forecast_name]].drop("ahead") data = data.dropna("index") mean = data.mean("index") sigma = np.sqrt(data.var("index")) - # r, p = stats.spearmanr(data.loc[..., [forecast_name, observation_name]]) r, p = stats.pearsonr(data.loc[..., forecast_name], data.loc[..., observation_name]) AI = np.array(r ** 2) @@ -181,17 +274,18 @@ class SkillScores: return AI, BI, CI, data, suffix def skill_score_mu_case_1(self, data, observation_name="obs", forecast_name="CNN"): + """Calculate CASE I.""" AI, BI, CI, data, _ = self.skill_score_pre_calculations(data, observation_name, forecast_name) skill_score = np.array(AI - BI - CI) return pd.DataFrame({"skill_score": [skill_score], "AI": [AI], "BI": [BI], "CI": [CI]}).to_xarray().to_array() def skill_score_mu_case_2(self, data, observation_name="obs", forecast_name="CNN"): + """Calculate CASE II.""" AI, BI, CI, data, suffix = self.skill_score_pre_calculations(data, observation_name, forecast_name) monthly_mean = self.create_monthly_mean_from_daily_data(data) data = xr.concat([data, monthly_mean], dim="type") sigma = suffix["sigma"] sigma_monthly = np.sqrt(monthly_mean.var()) - # r, p = stats.spearmanr(data.loc[..., [observation_name, observation_name + "X"]]) r, p = stats.pearsonr(data.loc[..., observation_name], data.loc[..., observation_name + "X"]) AII = np.array(r ** 2) BII = ((r - sigma_monthly / sigma.loc[observation_name]) ** 2).values @@ -199,15 +293,18 @@ class SkillScores: return pd.DataFrame({"skill_score": [skill_score], "AII": [AII], "BII": [BII]}).to_xarray().to_array() def skill_score_mu_case_3(self, data, observation_name="obs", forecast_name="CNN", external_data=None): + """Calculate CASE III.""" AI, BI, CI, data, suffix = self.skill_score_pre_calculations(data, observation_name, forecast_name) mean, sigma = suffix["mean"], suffix["sigma"] - AIII = (((external_data.mean().values - mean.loc[observation_name]) / sigma.loc[observation_name])**2).values + AIII = (((external_data.mean().values - mean.loc[observation_name]) / sigma.loc[observation_name]) ** 2).values skill_score = np.array((AI - BI - CI + AIII) / 1 + AIII) return pd.DataFrame({"skill_score": [skill_score], "AIII": [AIII]}).to_xarray().to_array() def skill_score_mu_case_4(self, data, observation_name="obs", forecast_name="CNN", external_data=None): + """Calculate CASE IV.""" AI, BI, CI, data, suffix = self.skill_score_pre_calculations(data, observation_name, forecast_name) - monthly_mean_external = self.create_monthly_mean_from_daily_data(external_data, columns=data.type.values, index=data.index) + monthly_mean_external = self.create_monthly_mean_from_daily_data(external_data, columns=data.type.values, + index=data.index) data = xr.concat([data, monthly_mean_external], dim="type") mean, sigma = suffix["mean"], suffix["sigma"] monthly_mean_external = self.create_monthly_mean_from_daily_data(external_data, columns=data.type.values) @@ -217,14 +314,24 @@ class SkillScores: # r_mu, p_mu = stats.spearmanr(data.loc[..., [observation_name, observation_name+'X']]) r_mu, p_mu = stats.pearsonr(data.loc[..., observation_name], data.loc[..., observation_name + "X"]) - AIV = np.array(r_mu**2) - BIV = ((r_mu - sigma_external / sigma.loc[observation_name])**2).values - CIV = (((mean_external - mean.loc[observation_name]) / sigma.loc[observation_name])**2).values + AIV = np.array(r_mu ** 2) + BIV = ((r_mu - sigma_external / sigma.loc[observation_name]) ** 2).values + CIV = (((mean_external - mean.loc[observation_name]) / sigma.loc[observation_name]) ** 2).values skill_score = np.array((AI - BI - CI - AIV + BIV + CIV) / (1 - AIV + BIV + CIV)) - return pd.DataFrame({"skill_score": [skill_score], "AIV": [AIV], "BIV": [BIV], "CIV": CIV}).to_xarray().to_array() + return pd.DataFrame( + {"skill_score": [skill_score], "AIV": [AIV], "BIV": [BIV], "CIV": CIV}).to_xarray().to_array() @staticmethod def create_monthly_mean_from_daily_data(data, columns=None, index=None): + """ + Calculate average for each month and save as daily values with flag 'X'. + + :param data: data to average + :param columns: columns to work on (all columns from given data are used if empty) + :param index: index of returned data (index of given data is used if empty) + + :return: data containing monthly means in daily resolution + """ if columns is None: columns = data.type.values if index is None: diff --git a/src/helpers/testing.py b/src/helpers/testing.py new file mode 100644 index 0000000000000000000000000000000000000000..244eb69fdc46dcadaeb3ada5779f09d44aa83e2a --- /dev/null +++ b/src/helpers/testing.py @@ -0,0 +1,88 @@ +"""Helper functions that are used to simplify testing.""" +import re +from typing import Union, Pattern, List + +import numpy as np +import xarray as xr + + +class PyTestRegex: + r""" + Assert that a given string meets some expectations. + + Use like + + >>> PyTestRegex(r"TestString\d+") == "TestString" + False + >>> PyTestRegex(r"TestString\d+") == "TestString2" + True + + + :param pattern: pattern or string to use for regular expresssion + :param flags: python re flags + """ + + def __init__(self, pattern: Union[str, Pattern], flags: int = 0): + """Construct PyTestRegex.""" + self._regex = re.compile(pattern, flags) + + def __eq__(self, actual: str) -> bool: + """Return whether regex matches given string actual or not.""" + return bool(self._regex.match(actual)) + + def __repr__(self) -> str: + """Show regex pattern.""" + return self._regex.pattern + + +class PyTestAllEqual: + """ + Check if all elements in list are the same. + + :param check_list: list with elements to check + """ + + def __init__(self, check_list: List): + """Construct class.""" + self._list = check_list + self._test_function = None + + def _set_test_function(self): + if isinstance(self._list[0], np.ndarray): + self._test_function = np.testing.assert_array_equal + else: + self._test_function = xr.testing.assert_equal + + def _check_all_equal(self) -> bool: + """ + Check if all elements are equal. + + :return boolean if elements are equal + """ + equal = True + self._set_test_function() + for b in self._list: + equal *= self._test_function(self._list[0], b) is None + return bool(equal == 1) + + def is_true(self) -> bool: + """ + Start equality check. + + :return: true if equality test is passed, false otherwise + """ + return self._check_all_equal() + + +def xr_all_equal(check_list: List) -> bool: + """ + Check if all given elements (preferably xarray's) in list are equal. + + :param check_list: list with elements to check + + :return: boolean if all elements are the same or not + """ + equal = True + for b in check_list: + equal *= xr.testing.assert_equal(check_list[0], b) is None + return equal == 1 \ No newline at end of file diff --git a/src/helpers/time_tracking.py b/src/helpers/time_tracking.py new file mode 100644 index 0000000000000000000000000000000000000000..c85a6a047943a589a9d076584ae40186634db767 --- /dev/null +++ b/src/helpers/time_tracking.py @@ -0,0 +1,131 @@ +"""Track time either as decorator or explicit.""" +import datetime as dt +import logging +import math +import time +import types +from functools import wraps +from typing import Optional + + +class TimeTrackingWrapper: + r""" + Wrapper implementation of TimeTracking class. + + Use this implementation easily as decorator for functions, classes and class methods. Implement a custom function + and decorate it for automatic time measure. + + .. code-block:: python + + @TimeTrackingWrapper + def sleeper(): + print("start") + time.sleep(1) + print("end") + + >>> sleeper() + start + end + INFO: foo finished after 00:00:01 (hh:mm:ss) + + """ + + def __init__(self, func): + """Construct.""" + wraps(func)(self) + + def __call__(self, *args, **kwargs): + """Start time tracking.""" + with TimeTracking(name=self.__wrapped__.__name__): + return self.__wrapped__(*args, **kwargs) + + def __get__(self, instance, cls): + """Create bound method object and supply self argument to the decorated method.""" + return types.MethodType(self, instance) + + +class TimeTracking(object): + """ + Track time to measure execution time. + + Time tracking automatically starts on initialisation and ends by calling stop method. Duration can always be shown + by printing the time tracking object or calling get_current_duration. It is possible to start and stop time tracking + by hand like + + .. code-block:: python + + time = TimeTracking(start=True) # start=True is default and not required to set + do_something() + time.stop(get_duration=True) + + A more comfortable way is to use TimeTracking in a with statement like: + + .. code-block:: python + + with TimeTracking(): + do_something() + + The only disadvantage of the latter implementation is, that the duration is logged but not returned. + """ + + def __init__(self, start=True, name="undefined job"): + """Construct time tracking and start if enabled.""" + self.start = None + self.end = None + self._name = name + if start: + self._start() + + def _start(self) -> None: + """Start time tracking.""" + self.start = time.time() + self.end = None + + def _end(self) -> None: + """Stop time tracking.""" + self.end = time.time() + + def _duration(self) -> float: + """Get duration in seconds.""" + if self.end: + return self.end - self.start + else: + return time.time() - self.start + + def __repr__(self) -> str: + """Display current passed time.""" + return f"{dt.timedelta(seconds=math.ceil(self._duration()))} (hh:mm:ss)" + + def run(self) -> None: + """Start time tracking.""" + self._start() + + def stop(self, get_duration=False) -> Optional[float]: + """ + Stop time tracking. + + Will raise an error if time tracking was already stopped. + :param get_duration: return passed time if enabled. + + :return: duration if enabled or None + """ + if self.end is None: + self._end() + else: + msg = f"Time was already stopped {time.time() - self.end}s ago." + raise AssertionError(msg) + if get_duration: + return self.duration() + + def duration(self) -> float: + """Return duration in seconds.""" + return self._duration() + + def __enter__(self): + """Context manager.""" + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + """Stop time tracking on exit and log info about passed time.""" + self.stop() + logging.info(f"{self._name} finished after {self}") \ No newline at end of file diff --git a/src/join_settings.py b/src/join_settings.py deleted file mode 100644 index 365e8f39d25b28375eadf3b0dbda374feb5b158e..0000000000000000000000000000000000000000 --- a/src/join_settings.py +++ /dev/null @@ -1,11 +0,0 @@ - -def join_settings(sampling="daily"): - if sampling == "daily": # pragma: no branch - TOAR_SERVICE_URL = 'https://join.fz-juelich.de/services/rest/surfacedata/' - headers = {} - elif sampling == "hourly": - TOAR_SERVICE_URL = 'https://join.fz-juelich.de/services/rest/surfacedata/' - headers = {} - else: - raise NameError(f"Given sampling {sampling} is not supported, choose from either daily or hourly sampling.") - return TOAR_SERVICE_URL, headers diff --git a/src/model_modules/__init__.py b/src/model_modules/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..35f4060886036d3f51c24b4480738566ff80a445 100644 --- a/src/model_modules/__init__.py +++ b/src/model_modules/__init__.py @@ -0,0 +1 @@ +"""Collection of all modules that are related to a model.""" diff --git a/src/model_modules/advanced_paddings.py b/src/model_modules/advanced_paddings.py index ea16e5b8a7c6a01456e286a2afaab4d5a88c96cc..f2fd4de91e84b1407f54c5ea156ad34f2d46acff 100644 --- a/src/model_modules/advanced_paddings.py +++ b/src/model_modules/advanced_paddings.py @@ -1,30 +1,35 @@ +"""Collection of customised padding layers.""" + __author__ = 'Felix Kleinert' __date__ = '2020-03-02' -import tensorflow as tf -import numpy as np -import keras.backend as K -from keras.layers.convolutional import _ZeroPadding +from typing import Union, Tuple + +import numpy as np +import tensorflow as tf +from keras.backend.common import normalize_data_format from keras.layers import ZeroPadding2D +from keras.layers.convolutional import _ZeroPadding from keras.legacy import interfaces from keras.utils import conv_utils from keras.utils.generic_utils import transpose_shape -from keras.backend.common import normalize_data_format class PadUtils: - """ - Helper class for advanced paddings - """ + """Helper class for advanced padding.""" @staticmethod - def get_padding_for_same(kernel_size, strides=1): + def get_padding_for_same(kernel_size: Tuple[int], strides: int = 1) -> Tuple[int]: """ - This methods calculates the padding size to keep input and output dimensions equal for a given kernel size - (STRIDES HAVE TO BE EQUAL TO ONE!) - :param kernel_size: - :return: + Calculate padding size to keep input and output dimensions equal for a given kernel size. + + .. hint:: custom paddings are currently only implemented for strides = 1 + + :param kernel_size: size of padding kernel size + :param strides: number of strides (default 1, currently only strides=1 supported) + + :return: padding size """ if strides != 1: raise NotImplementedError("Strides other than 1 not implemented!") @@ -40,15 +45,15 @@ class PadUtils: if all(k % 2 == 1 for k in ks): # (d & 0x1 for d in ks): pad = ((ks - 1) / 2).astype(np.int64) # convert numpy int to base int - pad = [np.asscalar(v) for v in pad] + pad = [int(v.item()) for v in pad] return tuple(pad) - # return tuple(PadUtils.check_padding_format(pad)) else: raise NotImplementedError(f"even kernel size not implemented. Got {kernel_size}") @staticmethod def spatial_2d_padding(padding=((1, 1), (1, 1)), data_format=None): - """Pads the 2nd and 3rd dimensions of a 4D tensor. + """ + Pad the 2nd and 3rd dimensions of a 4D tensor. # Arguments x: Tensor or variable. @@ -75,6 +80,7 @@ class PadUtils: @staticmethod def check_padding_format(padding): + """Check padding format (int, 1D or 2D, >0).""" if isinstance(padding, int): normalized_padding = ((padding, padding), (padding, padding)) elif hasattr(padding, '__len__'): @@ -89,16 +95,18 @@ class PadUtils: raise ValueError(f'`padding[{idx_pad}]` should have one or two elements. ' f'Found: {padding[idx_pad]}') if not all(isinstance(sub_k, int) for sub_k in padding[idx_pad]): - raise ValueError(f'`padding[{idx_pad}]` should have one or two elements of type int. ' + raise ValueError(f'`padding[{idx_pad}]` should have one or two elements of type int. ' f"Found:{padding[idx_pad]} of type {[type(sub_k) for sub_k in padding[idx_pad]]}") height_padding = conv_utils.normalize_tuple(padding[0], 2, '1st entry of padding') if not all(k >= 0 for k in height_padding): - raise ValueError(f"The `1st entry of padding` argument must be >= 0. Received: {padding[0]} of type {type(padding[0])}") + raise ValueError( + f"The `1st entry of padding` argument must be >= 0. Received: {padding[0]} of type {type(padding[0])}") width_padding = conv_utils.normalize_tuple(padding[1], 2, '2nd entry of padding') if not all(k >= 0 for k in width_padding): - raise ValueError(f"The `2nd entry of padding` argument must be >= 0. Received: {padding[1]} of type {type(padding[1])}") + raise ValueError( + f"The `2nd entry of padding` argument must be >= 0. Received: {padding[1]} of type {type(padding[1])}") normalized_padding = (height_padding, width_padding) else: raise ValueError('`padding` should be either an int, ' @@ -112,9 +120,10 @@ class PadUtils: class ReflectionPadding2D(_ZeroPadding): """ - Reflection padding layer for 2D input. This custum padding layer is built on keras' zero padding layers. Doc is copy - pasted from the original functions/methods: + Reflection padding layer for 2D input. + This custom padding layer is built on keras' zero padding layers. Doc is copy and pasted from the original + functions/methods: This layer can add rows and columns of reflected values at the top, bottom, left and right side of an image like tensor. @@ -129,7 +138,7 @@ class ReflectionPadding2D(_ZeroPadding): - '# Arguments + # Arguments padding: int, or tuple of 2 ints, or tuple of 2 tuples of 2 ints. - If int: the same symmetric padding is applied to height and width. @@ -172,21 +181,24 @@ class ReflectionPadding2D(_ZeroPadding): padding=(1, 1), data_format=None, **kwargs): + """Initialise ReflectionPadding2D.""" normalized_padding = PadUtils.check_padding_format(padding=padding) super(ReflectionPadding2D, self).__init__(normalized_padding, data_format, **kwargs) def call(self, inputs, mask=None): + """Call ReflectionPadding2D.""" pattern = PadUtils.spatial_2d_padding(padding=self.padding, data_format=self.data_format) return tf.pad(inputs, pattern, 'REFLECT') class SymmetricPadding2D(_ZeroPadding): """ - Symmetric padding layer for 2D input. This custom padding layer is built on keras' zero padding layers. Doc is copy - pasted from the original functions/methods: + Symmetric padding layer for 2D input. + This custom padding layer is built on keras' zero padding layers. Doc is copy pasted from the original + functions/methods: This layer can add rows and columns of symmetric values at the top, bottom, left and right side of an image like tensor. @@ -243,12 +255,14 @@ class SymmetricPadding2D(_ZeroPadding): padding=(1, 1), data_format=None, **kwargs): + """Initialise SymmetricPadding2D.""" normalized_padding = PadUtils.check_padding_format(padding=padding) super(SymmetricPadding2D, self).__init__(normalized_padding, data_format, **kwargs) def call(self, inputs, mask=None): + """Call SymmetricPadding2D.""" pattern = PadUtils.spatial_2d_padding(padding=self.padding, data_format=self.data_format) return tf.pad(inputs, pattern, 'SYMMETRIC') @@ -278,18 +292,20 @@ class Padding2D: **dict.fromkeys(("SymPad2D", "SymmetricPadding2D"), SymmetricPadding2D), **dict.fromkeys(("ZeroPad2D", "ZeroPadding2D"), ZeroPadding2D) } + padding_type = Union[ReflectionPadding2D, SymmetricPadding2D, ZeroPadding2D] - def __init__(self, padding_type): + def __init__(self, padding_type: Union[str, padding_type]): + """Set padding type.""" self.padding_type = padding_type def _check_and_get_padding(self): if isinstance(self.padding_type, str): try: pad2d = self.allowed_paddings[self.padding_type] - except KeyError as einfo: + except KeyError as e: raise NotImplementedError( - f"`{einfo}' is not implemented as padding. " - "Use one of those: i) `RefPad2D', ii) `SymPad2D', iii) `ZeroPad2D'") + f"`{e}' is not implemented as padding. Use one of those: i) `RefPad2D', ii) `SymPad2D', " + f"iii) `ZeroPad2D'") else: if self.padding_type in self.allowed_paddings.values(): pad2d = self.padding_type @@ -300,6 +316,7 @@ class Padding2D: return pad2d def __call__(self, *args, **kwargs): + """Call padding.""" return self._check_and_get_padding()(*args, **kwargs) @@ -332,5 +349,3 @@ if __name__ == '__main__': model.compile('adam', loss='mse') model.summary() model.fit(x, y, epochs=10) - - diff --git a/src/model_modules/inception_model.py b/src/model_modules/inception_model.py index 15739556d7d28d9e7e6ecc454615d82fb81a2754..74cd4d806f706a70d554adae468e7fa8c5de153e 100644 --- a/src/model_modules/inception_model.py +++ b/src/model_modules/inception_model.py @@ -5,7 +5,8 @@ import logging import keras import keras.layers as layers -from src.model_modules.advanced_paddings import PadUtils, ReflectionPadding2D, SymmetricPadding2D, Padding2D + +from src.model_modules.advanced_paddings import PadUtils, ReflectionPadding2D, Padding2D class InceptionModelBase: @@ -22,6 +23,7 @@ class InceptionModelBase: def block_part_name(self): """ Use unicode due to some issues of keras with normal strings + :return: """ return chr(self.ord_base + self.part_of_block) @@ -41,6 +43,7 @@ class InceptionModelBase: """ This function creates a "convolution tower block" containing a 1x1 convolution to reduce filter size followed by convolution with given filter and kernel size + :param input_x: Input to network part :param reduction_filter: Number of filters used in 1x1 convolution to reduce overall filter size before conv. :param tower_filter: Number of filters for n x m convolution @@ -111,6 +114,7 @@ class InceptionModelBase: def create_pool_tower(self, input_x, pool_kernel, tower_filter, activation='relu', max_pooling=True, **kwargs): """ This function creates a "MaxPooling tower block" + :param input_x: Input to network part :param pool_kernel: size of pooling kernel :param tower_filter: Number of filters used in 1x1 convolution to reduce filter size @@ -133,11 +137,11 @@ class InceptionModelBase: block_type = "AvgPool" pooling = layers.AveragePooling2D - tower = Padding2D(padding)(padding=padding_size, name=block_name+'Pad')(input_x) - tower = pooling(pool_kernel, strides=(1, 1), padding='valid', name=block_name+block_type)(tower) + tower = Padding2D(padding)(padding=padding_size, name=block_name + 'Pad')(input_x) + tower = pooling(pool_kernel, strides=(1, 1), padding='valid', name=block_name + block_type)(tower) # convolution block - tower = layers.Conv2D(tower_filter, (1, 1), padding='valid', name=block_name+"1x1")(tower) + tower = layers.Conv2D(tower_filter, (1, 1), padding='valid', name=block_name + "1x1")(tower) tower = self.act(tower, activation, **act_settings) return tower @@ -145,6 +149,7 @@ class InceptionModelBase: def inception_block(self, input_x, tower_conv_parts, tower_pool_parts, **kwargs): """ Crate a inception block + :param input_x: Input to block :param tower_conv_parts: dict containing settings for parts of inception block; Example: tower_conv_parts = {'tower_1': {'reduction_filter': 32, @@ -184,7 +189,7 @@ class InceptionModelBase: tower_build['avgpool'] = self.create_pool_tower(input_x, **tower_pool_parts, **kwargs, max_pooling=False) block = keras.layers.concatenate(list(tower_build.values()), axis=3, - name=block_name+"_Co") + name=block_name + "_Co") return block @@ -202,7 +207,7 @@ if __name__ == '__main__': conv_settings_dict = {'tower_1': {'reduction_filter': 64, 'tower_filter': 64, 'tower_kernel': (3, 3), - 'activation': LeakyReLU,}, + 'activation': LeakyReLU, }, 'tower_2': {'reduction_filter': 64, 'tower_filter': 64, 'tower_kernel': (5, 5), @@ -239,12 +244,10 @@ if __name__ == '__main__': # compile epochs = 1 lrate = 0.01 - decay = lrate/epochs + decay = lrate / epochs sgd = SGD(lr=lrate, momentum=0.9, decay=decay, nesterov=False) model.compile(loss='categorical_crossentropy', optimizer=sgd, metrics=['accuracy']) print(X_train.shape) keras.utils.plot_model(model, to_file='model.pdf', show_shapes=True, show_layer_names=True) # model.fit(X_train, y_train, epochs=epochs, validation_data=(X_test, y_test)) print('test') - - diff --git a/src/model_modules/keras_extensions.py b/src/model_modules/keras_extensions.py index 180e324602da25e1df8fb218c1d3bba180004ac8..479913811a668d8330a389b2876360f096f57dbf 100644 --- a/src/model_modules/keras_extensions.py +++ b/src/model_modules/keras_extensions.py @@ -1,25 +1,31 @@ +"""Collection of different extensions to keras framework.""" + __author__ = 'Lukas Leufen, Felix Kleinert' __date__ = '2020-01-31' import logging import math import pickle -from typing import Union +from typing import Union, List +from typing_extensions import TypedDict import numpy as np from keras import backend as K -from keras.callbacks import History, ModelCheckpoint +from keras.callbacks import History, ModelCheckpoint, Callback from src import helpers class HistoryAdvanced(History): """ - This is almost an identical clone of the original History class. The only difference is that attributes epoch and - history are instantiated during the init phase and not during on_train_begin. This is required to resume an already - started but disrupted training from an saved state. This HistoryAdvanced callback needs to be added separately as - additional callback. To get the full history use this object for further steps instead of the default return of - training methods like fit_generator(). + This is almost an identical clone of the original History class. + + The only difference is that attributes epoch and history are instantiated during the init phase and not during + on_train_begin. This is required to resume an already started but disrupted training from an saved state. This + HistoryAdvanced callback needs to be added separately as additional callback. To get the full history use this + object for further steps instead of the default return of training methods like fit_generator(). + + .. code-block:: python hist = HistoryAdvanced() history = model.fit_generator(generator=.... , callbacks=[hist]) @@ -29,21 +35,30 @@ class HistoryAdvanced(History): """ def __init__(self): + """Set up HistoryAdvanced.""" self.epoch = [] self.history = {} super().__init__() def on_train_begin(self, logs=None): + """Overload on_train_begin method to do nothing instead of resetting epoch and history.""" pass class LearningRateDecay(History): """ - Decay learning rate during model training. Start with a base learning rate and lower this rate after every - n(=epochs_drop) epochs by drop value (0, 1], drop value = 1 means no decay in learning rate. + Decay learning rate during model training. + + Start with a base learning rate and lower this rate after every n(=epochs_drop) epochs by drop value (0, 1], drop + value = 1 means no decay in learning rate. + + :param base_lr: base learning rate to start with + :param drop: ratio to drop after epochs_drop + :param epochs_drop: number of epochs after that drop takes place """ def __init__(self, base_lr: float = 0.01, drop: float = 0.96, epochs_drop: int = 8): + """Set up LearningRateDecay.""" super().__init__() self.lr = {'lr': []} self.base_lr = self.check_param(base_lr, 'base_lr') @@ -55,13 +70,16 @@ class LearningRateDecay(History): @staticmethod def check_param(value: float, name: str, lower: Union[float, None] = 0, upper: Union[float, None] = 1): """ - Check if given value is in interval. The left (lower) endpoint is open, right (upper) endpoint is closed. To - only one side of the interval, set the other endpoint to None. If both ends are set to None, just return the - value without any check. + Check if given value is in interval. + + The left (lower) endpoint is open, right (upper) endpoint is closed. To use only one side of the interval, set + the other endpoint to None. If both ends are set to None, just return the value without any check. + :param value: value to check :param name: name of the variable to display in error message :param lower: left (lower) endpoint of interval, opened :param upper: right (upper) endpoint of interval, closed + :return: unchanged value or raise ValueError """ if lower is None: @@ -75,11 +93,13 @@ class LearningRateDecay(History): f"{name}={value}") def on_train_begin(self, logs=None): + """Overload on_train_begin method to do nothing instead of resetting epoch and history.""" pass def on_epoch_begin(self, epoch: int, logs=None): """ Lower learning rate every epochs_drop epochs by factor drop. + :param epoch: current epoch :param logs: ? :return: update keras learning rate @@ -93,46 +113,66 @@ class LearningRateDecay(History): class ModelCheckpointAdvanced(ModelCheckpoint): """ - Enhance the standard ModelCheckpoint class by additional saves of given callbacks. Specify this callbacks as follow: + Enhance the standard ModelCheckpoint class by additional saves of given callbacks. + + **We recommend to use CallbackHandler instead of ModelCheckpointAdvanced.** CallbackHandler will handler all your + callbacks and the ModelCheckpointAdvanced and prevent you from pitfalls like wrong ordering of callbacks. Actually, + CallbackHandler makes use of ModelCheckpointAdvanced. + + However, if you want to use the ModelCheckpointAdvanced explicitly, follow these instructions: + .. code-block:: python + + # load your callbacks lr = CustomLearningRate() hist = CustomHistory() + + # set your callbacks with a list dictionary structure callbacks_name = "your_custom_path_%s.pickle" callbacks = [{"callback": lr, "path": callbacks_name % "lr"}, - {"callback": hist, "path": callbacks_name % "hist"}] + {"callback": hist, "path": callbacks_name % "hist"}] + # initialise ModelCheckpointAdvanced like the normal ModelCheckpoint (see keras callbacks) ckpt_callbacks = ModelCheckpointAdvanced(filepath=.... , callbacks=callbacks) - Add this ckpt_callbacks as all other additional callbacks to the callback list. IMPORTANT: Always add ckpt_callbacks - as last callback to properly update all tracked callbacks, e.g. + Add ModelCheckpointAdvanced as all other additional callbacks to the callback list. IMPORTANT: Always add + ModelCheckpointAdvanced as last callback to properly update all tracked callbacks, e.g. + + .. code-block:: python + # always add ModelCheckpointAdvanced as last element fit_generator(.... , callbacks=[lr, hist, ckpt_callbacks]) """ + def __init__(self, *args, **kwargs): + """Initialise ModelCheckpointAdvanced and set callbacks attribute.""" self.callbacks = kwargs.pop("callbacks") super().__init__(*args, **kwargs) def update_best(self, hist): """ - Update internal best on resuming a training process. Otherwise best is set to +/- inf depending on the - performance metric and the first trained model (first of the resuming training process) will always saved as - best model because its performance will be better than infinity. To prevent this behaviour and compare the - performance with the best model performance, call this method before resuming the training process. + Update internal best on resuming a training process. + + If no best object is available, best is set to +/- inf depending on the performance metric and the first trained + model (first of the resuming training process) will always saved as best model because its performance will be + better than infinity. To prevent this behaviour and compare the performance with the best model performance, + call this method before resuming the training process. + :param hist: The History object from the previous (interrupted) training. """ self.best = hist.history.get(self.monitor)[-1] def update_callbacks(self, callbacks): """ - Update all stored callback objects. The argument callbacks needs to follow the same convention like described - in the class description (list of dictionaries). Must be run before resuming a training process. + Update all stored callback objects. + + The argument callbacks needs to follow the same convention like described in the class description (list of + dictionaries). Must be run before resuming a training process. """ self.callbacks = helpers.to_list(callbacks) def on_epoch_end(self, epoch, logs=None): - """ - Save model as usual (see ModelCheckpoint class), but also save additional callbacks. - """ + """Save model as usual (see ModelCheckpoint class), but also save additional callbacks.""" super().on_epoch_end(epoch, logs) for callback in self.callbacks: @@ -152,10 +192,67 @@ class ModelCheckpointAdvanced(ModelCheckpoint): pickle.dump(callback["callback"], f) +clbk_type = TypedDict("clbk_type", {"name": str, str: Callback, "path": str}) + + class CallbackHandler: + r"""Use the CallbackHandler for better controlling of custom callbacks. + + The callback handler will always keep your callbacks in the right order and adds a model checkpoint at last position + if required. You can add an arbitrary number of callbacks to the handler. First, add all callbacks and finally + create the model checkpoint. Callbacks that have been added after checkpoint create wouldn't be part if it. + Therefore, the handler blocks adding of new callbacks after creation of model checkpoint. + + .. code-block:: python + + # init callbacks handler + callbacks = CallbackHandler() + + # set history object (add further elements like this example) + hist = keras.callbacks.History() + callbacks.add_callback(hist, "callbacks-hist.pickle", "hist") + + # create advanced checkpoint (details see ModelCheckpointAdvanced) + ckpt_name = "model-best.h5" + callbacks.create_model_checkpoint(filepath=ckpt_name, verbose=1, ...) + + # get checkpoint + ckpt = callbacks.get_checkpoint() + + # fit already compiled model and add callbacks, it is important to call get_callbacks with as_dict=False + history = model.fit(..., callbacks=self.callbacks.get_callbacks(as_dict=False)) + + If you want to continue a training, you can use the callback handler to load already stored callbacks. First you + need to reload all callbacks. Make sure, that all callbacks are available from previous training. If the callback + handler was set up like in the former code example, this will work. + + .. code-block:: python + + # load callbacks and update checkpoint + callbacks.load_callbacks() + callbacks.update_checkpoint() + + # optional: load your model using checkpoint path + model = keras.models.load_model(ckpt.filepath) + + # extract history object and set starting epoch + hist = callbacks.get_callback_by_name("hist") + initial_epoch = max(hist.epoch) + 1 + + # resume training (including initial_epoch) and use callback handler's history object + _ = self.model.fit(..., callbacks=self.callbacks.get_callbacks(as_dict=False), initial_epoch=initial_epoch) + history = hist + + Important notes: Do not use the returned history object of model.fit, but use the history object from callback + handler. The fit history will only contain the new history, whereas callback handler's history contains the full + history including the resumed and new history. For a correct epoch counting, you need to add the initial epoch to + the fit method too. + + """ def __init__(self): - self.__callbacks = [] + """Initialise CallbackHandler.""" + self.__callbacks: List[clbk_type] = [] self._checkpoint = None self.editable = True @@ -168,46 +265,79 @@ class CallbackHandler: name, callback, callback_path = value self.__callbacks.append({"name": name, name: callback, "path": callback_path}) - def _update_callback(self, pos, value): + def _update_callback(self, pos: int, value: Callback) -> None: + """Update callback entry with given value.""" name = self.__callbacks[pos]["name"] self.__callbacks[pos][name] = value - def add_callback(self, callback, callback_path, name="callback"): + def add_callback(self, callback: Callback, callback_path: str, name: str = "callback") -> None: + """ + Add given callback on last position if CallbackHandler is editable. + + Save callback with given name. Will raise a PermissionError, if editable is False. + + :param callback: callback object to store + :param callback_path: path to callback + :param name: name of the callback + """ if self.editable: self._callbacks = (name, callback, callback_path) else: raise PermissionError(f"{__class__.__name__} is protected and cannot be edited.") - def get_callbacks(self, as_dict=True): + def get_callbacks(self, as_dict=True) -> Union[List[clbk_type], List[Callback]]: + """ + Get all callbacks including checkpoint on last position. + + :param as_dict: set return format, either clbk_type with dictionary structure (as_dict=True, default) or list + + :return: all callbacks either as callback dictionary structure (embedded in a list) or as raw objects in a list + """ if as_dict: return self._get_callbacks() else: return [clb["callback"] for clb in self._get_callbacks()] - def get_callback_by_name(self, obj_name): + def get_callback_by_name(self, obj_name: str) -> Union[Callback, History]: + """ + Get single callback by its name. + + :param obj_name: name of callback to look for + + :return: requested callback object + """ if obj_name != "callback": return [clbk[clbk["name"]] for clbk in self.__callbacks if clbk["name"] == obj_name][0] - def _get_callbacks(self): + def _get_callbacks(self) -> List[clbk_type]: + """Return all callbacks and append checkpoint if available on last position.""" clbks = self._callbacks if self._checkpoint is not None: clbks += [{"callback": self._checkpoint, "path": self._checkpoint.filepath}] return clbks - def get_checkpoint(self): + def get_checkpoint(self) -> ModelCheckpointAdvanced: + """Return current checkpoint if available.""" if self._checkpoint is not None: return self._checkpoint def create_model_checkpoint(self, **kwargs): + """Create a model checkpoint and enable edit.""" self._checkpoint = ModelCheckpointAdvanced(callbacks=self._callbacks, **kwargs) self.editable = False - def load_callbacks(self): + def load_callbacks(self) -> None: + """Load callbacks from path and save in callback attribute.""" for pos, callback in enumerate(self.__callbacks): path = callback["path"] clb = pickle.load(open(path, "rb")) self._update_callback(pos, clb) - def update_checkpoint(self, history_name="hist"): + def update_checkpoint(self, history_name: str = "hist") -> None: + """ + Update callbacks and history's best elements. + + :param history_name: name of history object + """ self._checkpoint.update_callbacks(self._callbacks) self._checkpoint.update_best(self.get_callback_by_name(history_name)) diff --git a/src/model_modules/linear_model.py b/src/model_modules/linear_model.py index 933a108c1b06e1786f75e7f4ebd9b220fbe812dd..e556f0358a2a5e5247f7b6cc7d416af25a8a664d 100644 --- a/src/model_modules/linear_model.py +++ b/src/model_modules/linear_model.py @@ -1,25 +1,47 @@ +"""Calculate ordinary least squared model.""" + __author__ = "Felix Kleinert, Lukas Leufen" __date__ = '2019-12-11' - import numpy as np import statsmodels.api as sm class OrdinaryLeastSquaredModel: + """ + Implementation of an ordinary least squared model (OLS). + + Inputs and outputs are retrieved from a generator. This generator needs to return in xarray format and has to be + iterable. OLS is calculated on initialisation using statsmodels package. Train your personal OLS using: + + .. code-block:: python + + # next(train_data) should be return (x, y) + my_ols_model = OrdinaryLeastSquaredModel(train_data) + + After calculation, use your OLS model with + + .. code-block:: python + + # input_data needs to be structured like train data + result_ols = my_ols_model.predict(input_data) + + :param generator: generator object returning a tuple containing inputs and outputs as xarrays + """ def __init__(self, generator): + """Set up OLS model.""" self.x = [] self.y = [] self.generator = generator - self.model = self.train_ols_model_from_generator() + self.model = self._train_ols_model_from_generator() - def train_ols_model_from_generator(self): - self.set_x_y_from_generator() + def _train_ols_model_from_generator(self): + self._set_x_y_from_generator() self.x = sm.add_constant(self.x) return self.ordinary_least_squared_model(self.x, self.y) - def set_x_y_from_generator(self): + def _set_x_y_from_generator(self): data_x = None data_y = None for item in self.generator: @@ -31,16 +53,19 @@ class OrdinaryLeastSquaredModel: self.y = data_y def predict(self, data): + """Apply OLS model on data.""" data = sm.add_constant(self.reshape_xarray_to_numpy(data), has_constant="add") return np.atleast_2d(self.model.predict(data)) @staticmethod def reshape_xarray_to_numpy(data): + """Reshape xarray data to numpy data and flatten.""" shape = data.values.shape res = data.values.reshape(shape[0], shape[1] * shape[3]) return res @staticmethod def ordinary_least_squared_model(x, y): + """Calculate ols model using statsmodels.""" ols_model = sm.OLS(y, x) return ols_model.fit() diff --git a/src/model_modules/loss.py b/src/model_modules/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..bcb85282d0fa15f18ebd65a89e4020c2a0170224 --- /dev/null +++ b/src/model_modules/loss.py @@ -0,0 +1,22 @@ +"""Collection of different customised loss functions.""" + +from keras import backend as K + +from typing import Callable + + +def l_p_loss(power: int) -> Callable: + """ + Calculate the L<p> loss for given power p. + + L1 (p=1) is equal to mean absolute error (MAE), L2 (p=2) is to mean squared error (MSE), ... + + :param power: set the power of the error calculus + + :return: loss for given power + """ + + def loss(y_true, y_pred): + return K.mean(K.pow(K.abs(y_pred - y_true), power), axis=-1) + + return loss diff --git a/src/model_modules/model_class.py b/src/model_modules/model_class.py index b46213e591798861fea4f0da13c9bab824200b4b..ced01e9ad25b0654097d6fc1b5b7d00166328c80 100644 --- a/src/model_modules/model_class.py +++ b/src/model_modules/model_class.py @@ -1,10 +1,119 @@ +""" +Module for neural models to use during experiment. + +To work properly, each customised model needs to inherit from AbstractModelClass and needs an implementation of the +set_model and set_loss method. + +In this module, you can find some exemplary model classes that have been build and were running in a experiment. + +* `MyLittleModel`: small model implementation with a single 1x1 Conv, and 4 Dense layers (64, 32, 16, window_lead_time). +* `MyBranchedModel`: a model with single 1x1 Conv, and 4 Dense layers (64, 32, 16, window_lead_time), it has three + output branches from different layers of the model. +* `MyTowerModel`: a more complex model with inception blocks (called towers) +* `MyPaperModel`: A model used for the publication: <Add Publication Title / Citation> + +In addition, a short introduction how to create your own model is given hereinafter. + +How to create a customised model? +################################# + +* Create a new class: + + .. code-block:: python + + class MyCustomisedModel(AbstractModelClass): + + def __init__(self, window_history_size, window_lead_time, channels): + super.__init__() + # settings + self.window_history_size = window_history_size + self.window_lead_time = window_lead_time + self.channels = channels + self.dropout_rate = 0.1 + + # apply to model + self.set_model() + self.set_loss() + self.set_custom_objects(loss=self.loss) + +* Make sure to add the `super().__init__()` and at least `set_model()` and `set_loss()` to your custom init method. +* If you have custom objects in your model, that are not part of keras, you need to add them to custom objects. To do + this, call `set_custom_objects` with arbitrarily kwargs. In the shown example, the loss has been added, because it + wasn't a standard loss. Apart from this, we always encourage you to add the loss as custom object, to prevent + potential errors when loading an already created model instead of training a new one. +* Build your model inside `set_model()`, e.g. + + .. code-block:: python + + class MyCustomisedModel(AbstractModelClass): + + def set_model(self): + x_input = keras.layers.Input(shape=(self.window_history_size + 1, 1, self.channels)) + x_in = keras.layers.Conv2D(32, (1, 1), padding='same', name='{}_Conv_1x1'.format("major"))(x_input) + x_in = self.activation(name='{}_conv_act'.format("major"))(x_in) + x_in = keras.layers.Flatten(name='{}'.format("major"))(x_in) + x_in = keras.layers.Dropout(self.dropout_rate, name='{}_Dropout_1'.format("major"))(x_in) + x_in = keras.layers.Dense(16, name='{}_Dense_16'.format("major"))(x_in) + x_in = self.activation()(x_in) + x_in = keras.layers.Dense(self.window_lead_time, name='{}_Dense'.format("major"))(x_in) + out_main = self.activation()(x_in) + self.model = keras.Model(inputs=x_input, outputs=[out_main]) + +* Your are free, how to design your model. Just make sure to save it in the class attribute model. +* Finally, set your custom loss. + + .. code-block:: python + + class MyCustomisedModel(AbstractModelClass): + + def set_loss(self): + self.loss = keras.losses.mean_squared_error + +* If you have a branched model with multiple outputs, you need either set only a single loss for all branch outputs or + to provide the same number of loss functions considering the right order. E.g. + + .. code-block:: python + + class MyCustomisedModel(AbstractModelClass): + + def set_model(self): + ... + self.model = keras.Model(inputs=x_input, outputs=[out_minor_1, out_minor_2, out_main]) + + def set_loss(self): + self.loss = [keras.losses.mean_absolute_error] + # for out_minor_1 + [keras.losses.mean_squared_error] + # for out_minor_2 + [keras.losses.mean_squared_error] # for out_main + + +How to access my customised model? +################################## + +If the customised model is created, you can easily access the model with + +>>> MyCustomisedModel().model +<your custom model> + +The loss is accessible via + +>>> MyCustomisedModel().loss +<your custom loss> + +You can treat the instance of your model as instance but also as the model itself. If you call a method, that refers to +the model instead of the model instance, you can directly apply the command on the instance instead of adding the model +parameter call. + +>>> MyCustomisedModel().model.compile(**kwargs) == MyCustomisedModel().compile(**kwargs) +True + +""" + import src.model_modules.keras_extensions __author__ = "Lukas Leufen, Felix Kleinert" # __date__ = '2019-12-12' __date__ = '2020-05-12' - from abc import ABC from typing import Any, Callable, Dict @@ -17,20 +126,16 @@ from src.model_modules.advanced_paddings import PadUtils, Padding2D class AbstractModelClass(ABC): - """ - The AbstractModelClass provides a unified skeleton for any model provided to the machine learning workflow. The - model can always be accessed by calling ModelClass.model or directly by an model method without parsing the model - attribute name (e.g. ModelClass.model.compile -> ModelClass.compile). Beside the model, this class provides the - corresponding loss function. + The AbstractModelClass provides a unified skeleton for any model provided to the machine learning workflow. + + The model can always be accessed by calling ModelClass.model or directly by an model method without parsing the + model attribute name (e.g. ModelClass.model.compile -> ModelClass.compile). Beside the model, this class provides + the corresponding loss function. """ def __init__(self) -> None: - - """ - Predefine internal attributes for model and loss. - """ - + """Predefine internal attributes for model and loss.""" self.__model = None self.model_name = self.__class__.__name__ self.__custom_objects = {} @@ -45,27 +150,28 @@ class AbstractModelClass(ABC): self.__compile_options = self.__allowed_compile_options def __getattr__(self, name: str) -> Any: - """ - Is called if __getattribute__ is not able to find requested attribute. Normally, the model class is saved into - a variable like `model = ModelClass()`. To bypass a call like `model.model` to access the _model attribute, - this method tries to search for the named attribute in the self.model namespace and returns this attribute if - available. Therefore, following expression is true: `ModelClass().compile == ModelClass().model.compile` as long - the called attribute/method is not part if the ModelClass itself. + Is called if __getattribute__ is not able to find requested attribute. + + Normally, the model class is saved into a variable like `model = ModelClass()`. To bypass a call like + `model.model` to access the _model attribute, this method tries to search for the named attribute in the + self.model namespace and returns this attribute if available. Therefore, following expression is true: + `ModelClass().compile == ModelClass().model.compile` as long the called attribute/method is not part if the + ModelClass itself. + :param name: name of the attribute or method to call + :return: attribute or method from self.model namespace """ - return self.model.__getattribute__(name) @property def model(self) -> keras.Model: - """ The model property containing a keras.Model instance. + :return: the keras model """ - return self.__model @model.setter @@ -75,9 +181,11 @@ class AbstractModelClass(ABC): @property def custom_objects(self) -> Dict: """ - The custom objects property collects all non-keras utilities that are used in the model class. To load such a - customised and already compiled model (e.g. from local disk), this information is required. - :return: the custom objects in a dictionary + The custom objects property collects all non-keras utilities that are used in the model class. + + To load such a customised and already compiled model (e.g. from local disk), this information is required. + + :return: custom objects in a dictionary """ return self.__custom_objects @@ -179,12 +287,14 @@ class AbstractModelClass(ABC): def get_settings(self) -> Dict: """ Get all class attributes that are not protected in the AbstractModelClass as dictionary. + :return: all class attributes """ return dict((k, v) for (k, v) in self.__dict__.items() if not k.startswith("_AbstractModelClass__")) def set_model(self): - pass + """Abstract method to set model.""" + raise NotImplementedError def set_compile_options(self): """ @@ -201,14 +311,16 @@ class AbstractModelClass(ABC): :return: """ - pass + raise NotImplementedError def set_custom_objects(self, **kwargs) -> None: """ - Set custom objects that are not part of keras framework. These custom objects are needed if an already compiled - model is loaded from disk. There is a special treatment for the Padding2D class, which is a base class for - different padding types. For a correct behaviour, all supported subclasses are added as custom objects in - addition to the given ones. + Set custom objects that are not part of keras framework. + + These custom objects are needed if an already compiled model is loaded from disk. There is a special treatment + for the Padding2D class, which is a base class for different padding types. For a correct behaviour, all + supported subclasses are added as custom objects in addition to the given ones. + :param kwargs: all custom objects, that should be saved """ if "Padding2D" in kwargs.keys(): @@ -217,7 +329,6 @@ class AbstractModelClass(ABC): class MyLittleModel(AbstractModelClass): - """ A customised model with a 1x1 Conv, and 4 Dense layers (64, 32, 16, window_lead_time), where the last layer is the output layer depending on the window_lead_time parameter. Dropout is used between the Convolution and the first @@ -225,9 +336,9 @@ class MyLittleModel(AbstractModelClass): """ def __init__(self, window_history_size, window_lead_time, channels): - """ Sets model and loss depending on the given arguments. + :param activation: activation function :param window_history_size: number of historical time steps included in the input data :param channels: number of variables used in input data @@ -254,9 +365,9 @@ class MyLittleModel(AbstractModelClass): self.set_custom_objects(loss=self.compile_options['loss']) def set_model(self): - """ Build the model. + :param activation: activation function :param window_history_size: number of historical time steps included in the input data :param channels: number of variables used in input data @@ -290,20 +401,18 @@ class MyLittleModel(AbstractModelClass): class MyBranchedModel(AbstractModelClass): - """ A customised model - with a 1x1 Conv, and 4 Dense layers (64, 32, 16, window_lead_time), where the last layer is the output layer depending on the window_lead_time parameter. Dropout is used between the Convolution and the first Dense layer. """ def __init__(self, window_history_size, window_lead_time, channels): - """ Sets model and loss depending on the given arguments. + :param activation: activation function :param window_history_size: number of historical time steps included in the input data :param channels: number of variables used in input data @@ -330,9 +439,9 @@ class MyBranchedModel(AbstractModelClass): self.set_custom_objects(loss=self.compile_options["loss"]) def set_model(self): - """ Build the model. + :param activation: activation function :param window_history_size: number of historical time steps included in the input data :param channels: number of variables used in input data @@ -373,9 +482,9 @@ class MyBranchedModel(AbstractModelClass): class MyTowerModel(AbstractModelClass): def __init__(self, window_history_size, window_lead_time, channels): - """ Sets model and loss depending on the given arguments. + :param activation: activation function :param window_history_size: number of historical time steps included in the input data :param channels: number of variables used in input data @@ -393,9 +502,10 @@ class MyTowerModel(AbstractModelClass): self.dropout_rate = 1e-2 self.regularizer = keras.regularizers.l2(0.1) self.initial_lr = 1e-2 - self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10) + self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94, + epochs_drop=10) self.epochs = 20 - self.batch_size = int(256*4) + self.batch_size = int(256 * 4) self.activation = keras.layers.PReLU # apply to model @@ -404,9 +514,9 @@ class MyTowerModel(AbstractModelClass): self.set_custom_objects(loss=self.compile_options["loss"]) def set_model(self): - """ Build the model. + :param activation: activation function :param window_history_size: number of historical time steps included in the input data :param channels: number of variables used in input data @@ -430,7 +540,7 @@ class MyTowerModel(AbstractModelClass): 'activation': activation}, 'tower_3': {'reduction_filter': 8 * 2, 'tower_filter': 16 * 2 * 2, 'tower_kernel': (1, 1), 'activation': activation}, - } + } pool_settings_dict2 = {'pool_kernel': (3, 1), 'tower_filter': 16, 'activation': activation} conv_settings_dict3 = {'tower_1': {'reduction_filter': 16 * 4, 'tower_filter': 32 * 2, 'tower_kernel': (3, 1), @@ -447,7 +557,8 @@ class MyTowerModel(AbstractModelClass): inception_model = InceptionModelBase() X_input = keras.layers.Input( - shape=(self.window_history_size + 1, 1, self.channels)) # add 1 to window_size to include current time step t0 + shape=( + self.window_history_size + 1, 1, self.channels)) # add 1 to window_size to include current time step t0 X_in = inception_model.inception_block(X_input, conv_settings_dict1, pool_settings_dict1, regularizer=self.regularizer, @@ -455,12 +566,14 @@ class MyTowerModel(AbstractModelClass): X_in = keras.layers.Dropout(self.dropout_rate)(X_in) - X_in = inception_model.inception_block(X_in, conv_settings_dict2, pool_settings_dict2, regularizer=self.regularizer, + X_in = inception_model.inception_block(X_in, conv_settings_dict2, pool_settings_dict2, + regularizer=self.regularizer, batch_normalisation=True) X_in = keras.layers.Dropout(self.dropout_rate)(X_in) - X_in = inception_model.inception_block(X_in, conv_settings_dict3, pool_settings_dict3, regularizer=self.regularizer, + X_in = inception_model.inception_block(X_in, conv_settings_dict3, pool_settings_dict3, + regularizer=self.regularizer, batch_normalisation=True) ############################################# @@ -483,9 +596,9 @@ class MyTowerModel(AbstractModelClass): class MyPaperModel(AbstractModelClass): def __init__(self, window_history_size, window_lead_time, channels): - """ Sets model and loss depending on the given arguments. + :param activation: activation function :param window_history_size: number of historical time steps included in the input data :param channels: number of variables used in input data @@ -503,7 +616,8 @@ class MyPaperModel(AbstractModelClass): self.dropout_rate = .3 self.regularizer = keras.regularizers.l2(0.001) self.initial_lr = 1e-3 - self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94, epochs_drop=10) + self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94, + epochs_drop=10) self.epochs = 150 self.batch_size = int(256 * 2) self.activation = keras.layers.ELU @@ -515,9 +629,9 @@ class MyPaperModel(AbstractModelClass): self.set_custom_objects(loss=self.compile_options["loss"], Padding2D=Padding2D) def set_model(self): - """ Build the model. + :param activation: activation function :param window_history_size: number of historical time steps included in the input data :param channels: number of variables used in input data @@ -526,7 +640,7 @@ class MyPaperModel(AbstractModelClass): :return: built keras model """ activation = self.activation - first_kernel = (3,1) + first_kernel = (3, 1) first_filters = 16 conv_settings_dict1 = { @@ -566,7 +680,8 @@ class MyPaperModel(AbstractModelClass): inception_model = InceptionModelBase() X_input = keras.layers.Input( - shape=(self.window_history_size + 1, 1, self.channels)) # add 1 to window_size to include current time step t0 + shape=( + self.window_history_size + 1, 1, self.channels)) # add 1 to window_size to include current time step t0 pad_size = PadUtils.get_padding_for_same(first_kernel) # X_in = adv_pad.SymmetricPadding2D(padding=pad_size)(X_input) @@ -578,7 +693,6 @@ class MyPaperModel(AbstractModelClass): name="First_conv_{}x{}".format(first_kernel[0], first_kernel[1]))(X_in) X_in = self.activation(name='FirstAct')(X_in) - X_in = inception_model.inception_block(X_in, conv_settings_dict1, pool_settings_dict1, regularizer=self.regularizer, batch_normalisation=True, @@ -593,7 +707,8 @@ class MyPaperModel(AbstractModelClass): X_in = keras.layers.Dropout(self.dropout_rate)(X_in) - X_in = inception_model.inception_block(X_in, conv_settings_dict2, pool_settings_dict2, regularizer=self.regularizer, + X_in = inception_model.inception_block(X_in, conv_settings_dict2, pool_settings_dict2, + regularizer=self.regularizer, batch_normalisation=True, padding=self.padding) # X_in = keras.layers.Dropout(self.dropout_rate)(X_in) diff --git a/src/plotting/__init__.py b/src/plotting/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..cc92014bb42fcf43b983d576fe6d88aeb2dd797b 100644 --- a/src/plotting/__init__.py +++ b/src/plotting/__init__.py @@ -0,0 +1 @@ +"""Collection of all plots that can be used during experiment for monitoring and evaluation.""" diff --git a/src/plotting/postprocessing_plotting.py b/src/plotting/postprocessing_plotting.py index 4fcb1f49d828f47c36a5597341585896a19fcc9a..541b5976b0cf58f865a3748431f4d74c12812a68 100644 --- a/src/plotting/postprocessing_plotting.py +++ b/src/plotting/postprocessing_plotting.py @@ -1,3 +1,4 @@ +"""Collection of plots to evaluate a model, create overviews on data or forecasts.""" __author__ = "Lukas Leufen, Felix Kleinert" __date__ = '2019-12-17' @@ -9,35 +10,78 @@ from typing import Dict, List, Tuple import matplotlib +import matplotlib.patches as mpatches import matplotlib.pyplot as plt import numpy as np import pandas as pd import seaborn as sns import xarray as xr from matplotlib.backends.backend_pdf import PdfPages -import matplotlib.patches as mpatches from src import helpers -from src.helpers import TimeTracking, TimeTrackingWrapper -from src.data_handling.data_generator import DataGenerator +from src.data_handling import DataGenerator +from src.helpers import TimeTrackingWrapper logging.getLogger('matplotlib').setLevel(logging.WARNING) class AbstractPlotClass: + """ + Abstract class for all plotting routines to unify plot workflow. + + Each inheritance requires a _plot method. Create a plot class like: + + .. code-block:: python + + class MyCustomPlot(AbstractPlotClass): + + def __init__(self, plot_folder, *args, **kwargs): + super().__init__(plot_folder, "custom_plot_name") + self._data = self._prepare_data(*args, **kwargs) + self._plot(*args, **kwargs) + self._save() + + def _prepare_data(*args, **kwargs): + <your custom data preparation> + return data + + def _plot(*args, **kwargs): + <your custom plotting without saving> + + The save method is already implemented in the AbstractPlotClass. If special saving is required (e.g. if you are + using pdfpages), you need to overwrite it. Plots are saved as .pdf with a resolution of 500dpi per default (can be + set in super class initialisation). + + Methods like the shown _prepare_data() are optional. The only method required to implement is _plot. + + If you want to add a time tracking module, just add the TimeTrackingWrapper as decorator around your custom plot + class. It will log the spent time if you call your plotting without saving the returned object. + + .. code-block:: python + + @TimeTrackingWrapper + class MyCustomPlot(AbstractPlotClass): + pass + + Let's assume it takes a while to create this very special plot. + + >>> MyCustomPlot() + INFO: MyCustomPlot finished after 00:00:11 (hh:mm:ss) + + """ def __init__(self, plot_folder, plot_name, resolution=500): + """Set up plot folder and name, and plot resolution (default 500dpi).""" self.plot_folder = plot_folder self.plot_name = plot_name self.resolution = resolution def _plot(self, *args): + """Abstract plot class needs to be implemented in inheritance.""" raise NotImplementedError def _save(self, **kwargs): - """ - Standard save method to store plot locally. Name of and path to plot need to be set on initialisation - """ + """Store plot locally. Name of and path to plot need to be set on initialisation.""" plot_name = os.path.join(os.path.abspath(self.plot_folder), f"{self.plot_name}.pdf") logging.debug(f"... save plot to {plot_name}") plt.savefig(plot_name, dpi=self.resolution, **kwargs) @@ -47,21 +91,26 @@ class AbstractPlotClass: @TimeTrackingWrapper class PlotMonthlySummary(AbstractPlotClass): """ - Show a monthly summary over all stations for each lead time ("ahead") as box and whiskers plot. The plot is saved - in data_path with name monthly_summary_box_plot.pdf and 500dpi resolution. + Show a monthly summary over all stations for each lead time ("ahead") as box and whiskers plot. + + The plot is saved in data_path with name monthly_summary_box_plot.pdf and 500dpi resolution. + + .. image:: ../../../../../_source/_plots/monthly_summary_box_plot.png + :width: 400 + + :param stations: all stations to plot + :param data_path: path, where the data is located + :param name: full name of the local files with a % as placeholder for the station name + :param target_var: display name of the target variable on plot's axis + :param window_lead_time: lead time to plot, if window_lead_time is higher than the available lead time or not given + the maximum lead time from data is used. (default None -> use maximum lead time from data). + :param plot_folder: path to save the plot (default: current directory) + """ + def __init__(self, stations: List, data_path: str, name: str, target_var: str, window_lead_time: int = None, plot_folder: str = "."): - """ - Sets attributes and create plot - :param stations: all stations to plot - :param data_path: path, where the data is located - :param name: full name of the local files with a % as placeholder for the station name - :param target_var: display name of the target variable on plot's axis - :param window_lead_time: lead time to plot, if window_lead_time is higher than the available lead time or not given - the maximum lead time from data is used. (default None -> use maximum lead time from data). - :param plot_folder: path to save the plot (default: current directory) - """ + """Set attributes and create plot.""" super().__init__(plot_folder, "monthly_summary_box_plot") self._data_path = data_path self._data_name = name @@ -72,8 +121,11 @@ class PlotMonthlySummary(AbstractPlotClass): def _prepare_data(self, stations: List) -> xr.DataArray: """ - Pre-process data required to plot. For each station, load locally saved predictions, extract the CNN prediction - and the observation and group them into monthly bins (no aggregation, only sorting them). + Pre.process data required to plot. + + For each station, load locally saved predictions, extract the CNN prediction and the observation and group them + into monthly bins (no aggregation, only sorting them). + :param stations: all stations to plot :return: The entire data set, flagged with the corresponding month. """ @@ -101,9 +153,12 @@ class PlotMonthlySummary(AbstractPlotClass): def _get_window_lead_time(self, window_lead_time: int): """ - Extract the lead time from data and arguments. If window_lead_time is not given, extract this information from - data itself by the number of ahead dimensions. If given, check if data supports the give length. If the number - of ahead dimensions in data is lower than the given lead time, data's lead time is used. + Extract the lead time from data and arguments. + + If window_lead_time is not given, extract this information from data itself by the number of ahead dimensions. + If given, check if data supports the give length. If the number of ahead dimensions in data is lower than the + given lead time, data's lead time is used. + :param window_lead_time: lead time from arguments to validate :return: validated lead time, comes either from given argument or from data itself """ @@ -114,13 +169,14 @@ class PlotMonthlySummary(AbstractPlotClass): def _plot(self, target_var: str): """ - Main plot function that creates a monthly grouped box plot over all stations but with separate boxes for each - lead time step. + Create a monthly grouped box plot over all stations but with separate boxes for each lead time step. + :param target_var: display name of the target variable on plot's axis """ data = self._data.to_dataset(name='values').to_dask_dataframe() logging.debug("... start plotting") - color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", self._window_lead_time).as_hex() + color_palette = [matplotlib.colors.cnames["green"]] + sns.color_palette("Blues_d", + self._window_lead_time).as_hex() ax = sns.boxplot(x='index', y='values', hue='ahead', data=data.compute(), whis=1., palette=color_palette, flierprops={'marker': '.', 'markersize': 1}, showmeans=True, meanprops={'markersize': 1, 'markeredgecolor': 'k'}) @@ -131,15 +187,20 @@ class PlotMonthlySummary(AbstractPlotClass): @TimeTrackingWrapper class PlotStationMap(AbstractPlotClass): """ - Plot geographical overview of all used stations as squares. Different data sets can be colorised by its key in the - input dictionary generators. The key represents the color to plot on the map. Currently, there is only a white - background, but this can be adjusted by loading locally stored topography data (not implemented yet). The plot is - saved under plot_path with the name station_map.pdf + Plot geographical overview of all used stations as squares. + + Different data sets can be colorised by its key in the input dictionary generators. The key represents the color to + plot on the map. Currently, there is only a white background, but this can be adjusted by loading locally stored + topography data (not implemented yet). The plot is saved under plot_path with the name station_map.pdf + + .. image:: ../../../../../_source/_plots/station_map.png + :width: 400 """ def __init__(self, generators: Dict, plot_folder: str = "."): """ - Sets attributes and create plot + Set attributes and create plot. + :param generators: dictionary with the plot color of each data set as key and the generator containing all stations as value. :param plot_folder: path to save the plot (default: current directory) @@ -150,9 +211,7 @@ class PlotStationMap(AbstractPlotClass): self._save() def _draw_background(self): - """ - Draw coastline, lakes, ocean, rivers and country borders as background on the map. - """ + """Draw coastline, lakes, ocean, rivers and country borders as background on the map.""" import cartopy.feature as cfeature self._ax.add_feature(cfeature.COASTLINE.with_scale("50m"), edgecolor='black') @@ -163,8 +222,10 @@ class PlotStationMap(AbstractPlotClass): def _plot_stations(self, generators): """ - The actual plot function. Loops over all keys in generators dict and its containing stations and plots a square - and the stations's position on the map regarding the given color. + Loop over all keys in generators dict and its containing stations and plot the stations's position. + + Position is highlighted by a square on the map regarding the given color. + :param generators: dictionary with the plot color of each data set as key and the generator containing all stations as value. """ @@ -181,7 +242,10 @@ class PlotStationMap(AbstractPlotClass): def _plot(self, generators: Dict): """ - Main plot function to create the station map plot. Sets figure and calls all required sub-methods. + Create the station map plot. + + Set figure and call all required sub-methods. + :param generators: dictionary with the plot color of each data set as key and the generator containing all stations as value. """ @@ -197,10 +261,29 @@ class PlotStationMap(AbstractPlotClass): @TimeTrackingWrapper class PlotConditionalQuantiles(AbstractPlotClass): """ - This class creates cond.quantile plots as originally proposed by Murphy, Brown and Chen (1989) [But in log scale] + Create cond.quantile plots as originally proposed by Murphy, Brown and Chen (1989) [But in log scale]. Link to paper: https://journals.ametsoc.org/doi/pdf/10.1175/1520-0434%281989%29004%3C0485%3ADVOTF%3E2.0.CO%3B2 + + .. image:: ../../../../../_source/_plots/conditional_quantiles_cali-ref_plot.png + :width: 400 + + .. image:: ../../../../../_source/_plots/conditional_quantiles_like-bas_plot.png + :width: 400 + + For each time step ahead a separate plot is created. If parameter plot_per_season is true, data is split by season + and conditional quantiles are plotted for each season in addition. + + :param stations: all stations to plot + :param data_pred_path: path to dir which contains the forecasts as .nc files + :param plot_folder: path where the plots are stored + :param plot_per_seasons: if `True' create cond. quantile plots for _seasons (DJF, MAM, JJA, SON) individually + :param rolling_window: smoothing of quantiles (3 is used by Murphy et al.) + :param model_mame: name of the model prediction as stored in netCDF file (for example "CNN") + :param obs_name: name of observation as stored in netCDF file (for example "obs") + :param kwargs: Some further arguments which are listed in self._opts """ + # ignore warnings if nans appear in quantile grouping warnings.filterwarnings("ignore", message="All-NaN slice encountered") # ignore warnings if mean is calculated on nans @@ -210,46 +293,33 @@ class PlotConditionalQuantiles(AbstractPlotClass): def __init__(self, stations: List, data_pred_path: str, plot_folder: str = ".", plot_per_seasons=True, rolling_window: int = 3, model_mame: str = "CNN", obs_name: str = "obs", **kwargs): - """ - - :param stations: all stations to plot - :param data_pred_path: path to dir which contains the forecasts as .nc files - :param plot_folder: path where the plots are stored - :param plot_per_seasons: if `True' create cond. quantile plots for seasons (DJF, MAM, JJA, SON) individually - :param rolling_window: smoothing of quantiles (3 is used by Murphy et al.) - :param model_mame: name of the model prediction as stored in netCDF file (for example "CNN") - :param obs_name: name of observation as stored in netCDF file (for example "obs") - :param kwargs: Some further arguments which are listed in self._opts - """ + """Initialise.""" super().__init__(plot_folder, "conditional_quantiles") - self._data_pred_path = data_pred_path self._stations = stations self._rolling_window = rolling_window self._model_name = model_mame self._obs_name = obs_name - - self._opts = {"q": kwargs.get("q", [.1, .25, .5, .75, .9]), - "linetype": kwargs.get("linetype", [':', '-.', '--', '-.', ':']), - "legend": kwargs.get("legend", - ['.10th and .90th quantile', '.25th and .75th quantile', '.50th quantile', - 'reference 1:1']), - "data_unit": kwargs.get("data_unit", "ppb"), - } - if plot_per_seasons is True: - self.seasons = ['DJF', 'MAM', 'JJA', 'SON'] - else: - self.seasons = "" + self._opts = self._get_opts(kwargs) + self._seasons = ['DJF', 'MAM', 'JJA', 'SON'] if plot_per_seasons is True else "" self._data = self._load_data() self._bins = self._get_bins_from_rage_of_data() - self._plot() - def _load_data(self): + @staticmethod + def _get_opts(kwargs): + """Extract options from kwargs.""" + return {"q": kwargs.get("q", [.1, .25, .5, .75, .9]), + "linetype": kwargs.get("linetype", [':', '-.', '--', '-.', ':']), + "legend": kwargs.get("legend", ['.10th and .90th quantile', '.25th and .75th quantile', + '.50th quantile', 'reference 1:1']), + "data_unit": kwargs.get("data_unit", "ppb"), } + + def _load_data(self) -> xr.DataArray: """ - This method loads forcast data + Load plot data. - :return: + :return: plot data """ logging.debug("... load data") data_collector = [] @@ -260,13 +330,14 @@ class PlotConditionalQuantiles(AbstractPlotClass): res = xr.concat(data_collector, dim='station').transpose('index', 'type', 'ahead', 'station') return res - def _segment_data(self, data, x_model): + def _segment_data(self, data: xr.DataArray, x_model: str) -> xr.DataArray: """ - This method creates segmented data which is used for cond. quantile plots + Segment data into bins. - :param data: - :param x_model: - :return: + :param data: data to segment + :param x_model: name of x dimension + + :return: segmented data """ logging.debug("... segment data") # combine index and station to multi index @@ -275,17 +346,18 @@ class PlotConditionalQuantiles(AbstractPlotClass): data.coords['z'] = range(len(data.coords['z'])) # segment data of x_model into bins data.loc[x_model, ...] = data.loc[x_model, ...].to_pandas().T.apply(pd.cut, bins=self._bins, - labels=self._bins[1:]).T.values + labels=self._bins[1:]).T.values return data @staticmethod - def _labels(plot_type, data_unit="ppb"): + def _labels(plot_type: str, data_unit: str = "ppb") -> Tuple[str, str]: """ - Helper method to correctly assign (x,y) labels to plots, depending on like-base or cali-ref factorization + Assign (x,y) labels to plots correctly, depending on like-base or cali-ref factorization. - :param plot_type: - :param data_unit: - :return: + :param plot_type: type of plot, either `obs` or a model name + :param data_unit: unit of data to add to labels (default ppb) + + :return: tuple with y and x labels """ names = (f"forecast concentration (in {data_unit})", f"observed concentration (in {data_unit})") if plot_type == "obs": @@ -293,22 +365,23 @@ class PlotConditionalQuantiles(AbstractPlotClass): else: return names[::-1] - def _get_bins_from_rage_of_data(self): + def _get_bins_from_rage_of_data(self) -> np.ndarray: """ - Get array of bins to use for quantiles + Get array of bins to use for quantiles. - :return: + :return: range from 0 to data's maximum + 1 (rounded down) """ return np.arange(0, math.ceil(self._data.max().max()) + 1, 1).astype(int) - def _create_quantile_panel(self, data, x_model, y_model): + def _create_quantile_panel(self, data: xr.DataArray, x_model: str, y_model: str) -> xr.DataArray: """ - Clculate quantiles + Calculate quantiles. - :param data: - :param x_model: - :param y_model: - :return: + :param data: data to calculate quantiles + :param x_model: name of x dimension + :param y_model: name of y dimension + + :return: quantile panel with binned data """ logging.debug("... create quantile panel") # create empty xarray with dims: time steps ahead, quantiles, bin index (numbers create in previous step) @@ -320,83 +393,69 @@ class PlotConditionalQuantiles(AbstractPlotClass): # calculate for each bin of the pred_name data the quantiles of the ref_name data for bin in self._bins[1:]: mask = (data.loc[x_model, ...] == bin) - quantile_panel.loc[..., bin] = data.loc[y_model, ...].where(mask).quantile(self._opts["q"], - dim=['z']).T + quantile_panel.loc[..., bin] = data.loc[y_model, ...].where(mask).quantile(self._opts["q"], dim=['z']).T return quantile_panel @staticmethod - def add_affix(x): + def add_affix(affix: str) -> str: """ - Helper method to add additional information on plot name + Add additional information to plot name with leading underscore or add empty string if affix is empty. - :param x: - :return: + :param affix: string to add + + :return: affix with leading underscore or empty string. """ - return f"_{x}" if len(x) > 0 else "" + return f"_{affix}" if len(affix) > 0 else "" - def _prepare_plots(self, data, x_model, y_model): + def _prepare_plots(self, data: xr.DataArray, x_model: str, y_model: str) -> Tuple[xr.DataArray, xr.DataArray]: """ - Get segmented_data and quantile_panel + Get segmented data and quantile panel. - :param data: - :param x_model: - :param y_model: - :return: + :param data: plot data + :param x_model: name of x dimension + :param y_model: name of y dimension + + :return: segmented data and quantile panel """ segmented_data = self._segment_data(data, x_model) quantile_panel = self._create_quantile_panel(segmented_data, x_model, y_model) return segmented_data, quantile_panel def _plot(self): - """ - Main plot call + """Start plotting routines: overall plot and seasonal (if enabled).""" + logging.info(f"start plotting {self.__class__.__name__}, scheduled number of plots: {(len(self._seasons) + 1) * 2}") - :return: - """ - logging.info(f"start plotting {self.__class__.__name__}, scheduled number of plots: {(len(self.seasons) + 1) * 2}") - - if len(self.seasons) > 0: + if len(self._seasons) > 0: self._plot_seasons() self._plot_all() def _plot_seasons(self): - """ - Seasonal plot call - - :return: - """ - for season in self.seasons: + """Create seasonal plots.""" + for season in self._seasons: self._plot_base(data=self._data.where(self._data['index.season'] == season), x_model=self._model_name, y_model=self._obs_name, plot_name_affix="cali-ref", season=season) self._plot_base(data=self._data.where(self._data['index.season'] == season), x_model=self._obs_name, y_model=self._model_name, plot_name_affix="like-base", season=season) def _plot_all(self): - """ - Full plot call - - :return: - """ + """Plot overall conditional quantiles on full data.""" self._plot_base(data=self._data, x_model=self._model_name, y_model=self._obs_name, plot_name_affix="cali-ref") self._plot_base(data=self._data, x_model=self._obs_name, y_model=self._model_name, plot_name_affix="like-base") @TimeTrackingWrapper - def _plot_base(self, data, x_model, y_model, plot_name_affix, season=""): + def _plot_base(self, data: xr.DataArray, x_model: str, y_model: str, plot_name_affix: str, season: str = ""): """ - Base method to create cond. quantile plots. Is called from _plot_all and _plot_seasonal + Create conditional quantile plots. :param data: data which is used to create cond. quantile plot :param x_model: name of model on x axis (can also be obs) :param y_model: name of model on y axis (can also be obs) :param plot_name_affix: should be `cali-ref' or `like-base' - :param season: List of seasons to use - :return: + :param season: List of _seasons to use """ - segmented_data, quantile_panel = self._prepare_plots(data, x_model, y_model) ylabel, xlabel = self._labels(x_model, self._opts["data_unit"]) plot_name = f"{self.plot_name}{self.add_affix(season)}{self.add_affix(plot_name_affix)}_plot.pdf" - #f"{base_name}{add_affix(season)}{add_affix(plot_name_affix)}_plot.pdf" plot_path = os.path.join(os.path.abspath(self.plot_folder), plot_name) pdf_pages = matplotlib.backends.backend_pdf.PdfPages(plot_path) logging.debug(f"... plot path is {plot_path}") @@ -445,22 +504,30 @@ class PlotConditionalQuantiles(AbstractPlotClass): @TimeTrackingWrapper class PlotClimatologicalSkillScore(AbstractPlotClass): """ - Create plot of climatological skill score after Murphy (1988) as box plot over all stations. A forecast time step - (called "ahead") is separately shown to highlight the differences for each prediction time step. Either each single - term is plotted (score_only=False) or only the resulting scores CASE I to IV are displayed (score_only=True, - default). Y-axis is adjusted following the data and not hard coded. The plot is saved under plot_folder path with - name skill_score_clim_{extra_name_tag}{model_setup}.pdf and resolution of 500dpi. + Create plot of climatological skill score after Murphy (1988) as box plot over all stations. + + A forecast time step (called "ahead") is separately shown to highlight the differences for each prediction time + step. Either each single term is plotted (score_only=False) or only the resulting scores CASE I to IV are displayed + (score_only=True, default). Y-axis is adjusted following the data and not hard coded. The plot is saved under + plot_folder path with name skill_score_clim_{extra_name_tag}{model_setup}.pdf and resolution of 500dpi. + + .. image:: ../../../../../_source/_plots/skill_score_clim_all_terms_CNN.png + :width: 400 + + .. image:: ../../../../../_source/_plots/skill_score_clim_CNN.png + :width: 400 + + :param data: dictionary with station names as keys and 2D xarrays as values, consist on axis ahead and terms. + :param plot_folder: path to save the plot (default: current directory) + :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms (default True) + :param extra_name_tag: additional tag that can be included in the plot name (default "") + :param model_setup: architecture type to specify plot name (default "CNN") + """ + def __init__(self, data: Dict, plot_folder: str = ".", score_only: bool = True, extra_name_tag: str = "", model_setup: str = ""): - """ - Sets attributes and create plot - :param data: dictionary with station names as keys and 2D xarrays as values, consist on axis ahead and terms. - :param plot_folder: path to save the plot (default: current directory) - :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms (default True) - :param extra_name_tag: additional tag that can be included in the plot name (default "") - :param model_setup: architecture type to specify plot name (default "CNN") - """ + """Initialise.""" super().__init__(plot_folder, f"skill_score_clim_{extra_name_tag}{model_setup}") self._labels = None self._data = self._prepare_data(data, score_only) @@ -469,8 +536,11 @@ class PlotClimatologicalSkillScore(AbstractPlotClass): def _prepare_data(self, data: Dict, score_only: bool) -> pd.DataFrame: """ - Shrink given data, if only scores are relevant. In any case, transform data to a plot friendly format. Also set - plot labels depending on the lead time dimensions. + Shrink given data, if only scores are relevant. + + In any case, transform data to a plot friendly format. Also set plot labels depending on the lead time + dimensions. + :param data: dictionary with station names as keys and 2D xarrays as values :param score_only: if true only scores of CASE I to IV are relevant :return: pre-processed data set @@ -483,7 +553,8 @@ class PlotClimatologicalSkillScore(AbstractPlotClass): def _label_add(self, score_only: bool): """ - Adds the phrase "terms and " if score_only is disabled or empty string (if score_only=True). + Add the phrase "terms and " if score_only is disabled or empty string (if score_only=True). + :param score_only: if false all terms are relevant, otherwise only CASE I to IV :return: additional label """ @@ -491,7 +562,8 @@ class PlotClimatologicalSkillScore(AbstractPlotClass): def _plot(self, score_only): """ - Main plot function to plot climatological skill score. + Plot climatological skill score. + :param score_only: if true plot only scores of CASE I to IV, otherwise plot all single terms """ fig, ax = plt.subplots() @@ -509,17 +581,24 @@ class PlotClimatologicalSkillScore(AbstractPlotClass): @TimeTrackingWrapper class PlotCompetitiveSkillScore(AbstractPlotClass): """ - Create competitive skill score for the given model setup and the reference models ordinary least squared ("ols") and - the persistence forecast ("persi") for all lead times ("ahead"). The plot is saved under plot_folder with the name + Create competitive skill score plot. + + Create this plot for the given model setup and the reference models ordinary least squared ("ols") and the + persistence forecast ("persi") for all lead times ("ahead"). The plot is saved under plot_folder with the name skill_score_competitive_{model_setup}.pdf and resolution of 500dpi. + + .. image:: ../../../../../_source/_plots/skill_score_competitive.png + :width: 400 + + :param data: data frame with index=['cnn-persi', 'ols-persi', 'cnn-ols'] and columns "ahead" containing the pre- + calculated comparisons for cnn, persistence and ols. + :param plot_folder: path to save the plot (default: current directory) + :param model_setup: architecture type (default "CNN") + """ + def __init__(self, data: pd.DataFrame, plot_folder=".", model_setup="CNN"): - """ - :param data: data frame with index=['cnn-persi', 'ols-persi', 'cnn-ols'] and columns "ahead" containing the pre- - calculated comparisons for cnn, persistence and ols. - :param plot_folder: path to save the plot (default: current directory) - :param model_setup: architecture type (default "CNN") - """ + """Initialise.""" super().__init__(plot_folder, f"skill_score_competitive_{model_setup}") self._labels = None self._data = self._prepare_data(data) @@ -528,7 +607,8 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): def _prepare_data(self, data: pd.DataFrame) -> pd.DataFrame: """ - Reformat given data and create plot labels. Introduces the dimensions stations and comparison + Reformat given data and create plot labels and introduce the dimensions stations and comparison. + :param data: data frame with index=['cnn-persi', 'ols-persi', 'cnn-ols'] and columns "ahead" containing the pre- calculated comparisons for cnn, persistence and ols. :return: processed data @@ -542,9 +622,7 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): return data.stack(level=0).reset_index(level=2, drop=True).reset_index(name="data") def _plot(self): - """ - Main plot function to plot skill scores of the comparisons cnn-persi, ols-persi and cnn-ols. - """ + """Plot skill scores of the comparisons cnn-persi, ols-persi and cnn-ols.""" fig, ax = plt.subplots() sns.boxplot(x="comparison", y="data", hue="ahead", data=self._data, whis=1., ax=ax, palette="Blues_d", showmeans=True, meanprops={"markersize": 3, "markeredgecolor": "k"}, flierprops={"marker": "."}, @@ -558,8 +636,11 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): def _ylim(self) -> Tuple[float, float]: """ - Calculate y-axis limits from data. Lower is the minimum of either 0 or data's minimum (reduced by small - subtrahend) and upper limit is data's maximum (increased by a small addend). + Calculate y-axis limits from data. + + Lower limit is the minimum of 0 and data's minimum (reduced by small subtrahend) and upper limit is data's + maximum (increased by a small addend). + :return: """ lower = np.min([0, helpers.float_round(self._data.min()[2], 2) - 0.1]) @@ -570,16 +651,22 @@ class PlotCompetitiveSkillScore(AbstractPlotClass): @TimeTrackingWrapper class PlotBootstrapSkillScore(AbstractPlotClass): """ - Create plot of climatological skill score after Murphy (1988) as box plot over all stations. A forecast time step - (called "ahead") is separately shown to highlight the differences for each prediction time step. Either each single - term is plotted (score_only=False) or only the resulting scores CASE I to IV are displayed (score_only=True, - default). Y-axis is adjusted following the data and not hard coded. The plot is saved under plot_folder path with - name skill_score_clim_{extra_name_tag}{model_setup}.pdf and resolution of 500dpi. + Create plot of climatological skill score after Murphy (1988) as box plot over all stations. + + A forecast time step (called "ahead") is separately shown to highlight the differences for each prediction time + step. Either each single term is plotted (score_only=False) or only the resulting scores CASE I to IV are displayed + (score_only=True, default). Y-axis is adjusted following the data and not hard coded. The plot is saved under + plot_folder path with name skill_score_clim_{extra_name_tag}{model_setup}.pdf and resolution of 500dpi. + + .. image:: ../../../../../_source/_plots/skill_score_bootstrap.png + :width: 400 + """ def __init__(self, data: Dict, plot_folder: str = ".", model_setup: str = ""): """ - Sets attributes and create plot + Set attributes and create plot. + :param data: dictionary with station names as keys and 2D xarrays as values, consist on axis ahead and terms. :param plot_folder: path to save the plot (default: current directory) :param model_setup: architecture type to specify plot name (default "CNN") @@ -593,8 +680,11 @@ class PlotBootstrapSkillScore(AbstractPlotClass): def _prepare_data(self, data: Dict) -> pd.DataFrame: """ - Shrink given data, if only scores are relevant. In any case, transform data to a plot friendly format. Also set - plot labels depending on the lead time dimensions. + Shrink given data, if only scores are relevant. + + In any case, transform data to a plot friendly format. Also set plot labels depending on the lead time + dimensions. + :param data: dictionary with station names as keys and 2D xarrays as values :return: pre-processed data set """ @@ -604,16 +694,15 @@ class PlotBootstrapSkillScore(AbstractPlotClass): def _label_add(self, score_only: bool): """ - Adds the phrase "terms and " if score_only is disabled or empty string (if score_only=True). + Add the phrase "terms and " if score_only is disabled or empty string (if score_only=True). + :param score_only: if false all terms are relevant, otherwise only CASE I to IV :return: additional label """ return "" if score_only else "terms and " def _plot(self): - """ - Main plot function to plot climatological skill score. - """ + """Plot climatological skill score.""" fig, ax = plt.subplots() sns.boxplot(x=self._x_name, y="data", hue="ahead", data=self._data, ax=ax, whis=1., palette="Blues_d", showmeans=True, meanprops={"markersize": 1, "markeredgecolor": "k"}, flierprops={"marker": "."}) @@ -626,9 +715,15 @@ class PlotBootstrapSkillScore(AbstractPlotClass): @TimeTrackingWrapper class PlotTimeSeries: + """ + Create time series plot. + + Currently, plots are under development and not well designed for any use in public. + """ def __init__(self, stations: List, data_path: str, name: str, window_lead_time: int = None, plot_folder: str = ".", sampling="daily"): + """Initialise.""" self._data_path = data_path self._data_name = name self._stations = stations @@ -645,9 +740,12 @@ class PlotTimeSeries: def _get_window_lead_time(self, window_lead_time: int): """ - Extract the lead time from data and arguments. If window_lead_time is not given, extract this information from - data itself by the number of ahead dimensions. If given, check if data supports the give length. If the number - of ahead dimensions in data is lower than the given lead time, data's lead time is used. + Extract the lead time from data and arguments. + + If window_lead_time is not given, extract this information from data itself by the number of ahead dimensions. + If given, check if data supports the give length. If the number of ahead dimensions in data is lower than the + given lead time, data's lead time is used. + :param window_lead_time: lead time from arguments to validate :return: validated lead time, comes either from given argument or from data itself """ @@ -716,7 +814,7 @@ class PlotTimeSeries: for ahead in data.coords["ahead"].values: plot_data = data.sel(type="CNN", ahead=ahead).drop(["type", "ahead"]).squeeze().shift(index=ahead) label = f"{ahead}{self._sampling}" - ax.plot(plot_data, color=color[ahead-1], label=label) + ax.plot(plot_data, color=color[ahead - 1], label=label) def _plot_obs(self, ax, data): ahead = 1 @@ -732,9 +830,10 @@ class PlotTimeSeries: return f(data, min), f(data, max) @staticmethod - def _create_pdf_pages(plot_folder): + def _create_pdf_pages(plot_folder: str): """ - Standard save method to store plot locally. The name of this plot is static. + Store plot locally. + :param plot_folder: path to save the plot """ plot_name = os.path.join(os.path.abspath(plot_folder), 'timeseries_plot.pdf') @@ -744,25 +843,55 @@ class PlotTimeSeries: @TimeTrackingWrapper class PlotAvailability(AbstractPlotClass): + """ + Create data availablility plot similar to Gantt plot. + + Each entry of given generator, will result in a new line in the plot. Data is summarised for given temporal + resolution and checked whether data is available or not for each time step. This is afterwards highlighted as a + colored bar or a blank space. + + You can set different colors to highlight subsets for example by providing different generators for the same index + using different keys in the input dictionary. + + Note: each bar is surrounded by a small white box to highlight gabs in between. This can result in too long gabs + in display, if a gab is only very short. Also this appears on a (fluent) transition from one to another subset. + + Calling this class will create three versions fo the availability plot. + + 1) Data availability for each element + 1) Data availability as summary over all elements (is there at least a single elemnt for each time step) + 1) Combination of single and overall availability + + .. image:: ../../../../../_source/_plots/data_availability.png + :width: 400 + + .. image:: ../../../../../_source/_plots/data_availability_summary.png + :width: 400 + + .. image:: ../../../../../_source/_plots/data_availability_combined.png + :width: 400 + + """ def __init__(self, generators: Dict[str, DataGenerator], plot_folder: str = ".", sampling="daily", summary_name="data availability"): + """Initialise.""" # create standard Gantt plot for all stations (currently in single pdf file with single page) super().__init__(plot_folder, "data_availability") self.sampling = self._get_sampling(sampling) plot_dict = self._prepare_data(generators) lgd = self._plot(plot_dict) - self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight") + self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") # create summary Gantt plot (is data in at least one station available) self.plot_name += "_summary" plot_dict_summary = self._summarise_data(generators, summary_name) lgd = self._plot(plot_dict_summary) - self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight") + self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") # combination of station and summary plot, last element is summary broken bar self.plot_name = "data_availability_combined" plot_dict_summary.update(plot_dict) lgd = self._plot(plot_dict_summary) - self._save(bbox_extra_artists=(lgd, ), bbox_inches="tight") + self._save(bbox_extra_artists=(lgd,), bbox_inches="tight") @staticmethod def _get_sampling(sampling): @@ -780,7 +909,8 @@ class PlotAvailability(AbstractPlotClass): labels = station_data.get_transposed_label().resample(datetime=self.sampling, skipna=True).mean() labels_bool = labels.sel(window=1).notnull() group = (labels_bool != labels_bool.shift(datetime=1)).cumsum() - plot_data = pd.DataFrame({"avail": labels_bool.values, "group": group.values}, index=labels.datetime.values) + plot_data = pd.DataFrame({"avail": labels_bool.values, "group": group.values}, + index=labels.datetime.values) t = plot_data.groupby("group").apply(lambda x: (x["avail"].head(1)[0], x.index[0], x.shape[0])) t2 = [i[1:] for i in t if i[0]] @@ -803,7 +933,8 @@ class PlotAvailability(AbstractPlotClass): all_data = labels_bool else: tmp = all_data.combine_first(labels_bool) # expand dims to merged datetime coords - all_data = np.logical_or(tmp, labels_bool).combine_first(all_data) # apply logical on merge and fill missing with all_data + all_data = np.logical_or(tmp, labels_bool).combine_first( + all_data) # apply logical on merge and fill missing with all_data group = (all_data != all_data.shift(datetime=1)).cumsum() plot_data = pd.DataFrame({"avail": all_data.values, "group": group.values}, index=all_data.datetime.values) @@ -823,7 +954,7 @@ class PlotAvailability(AbstractPlotClass): height = 0.8 # should be <= 1 yticklabels = [] number_of_stations = len(plt_dict.keys()) - fig, ax = plt.subplots(figsize=(10, number_of_stations/3)) + fig, ax = plt.subplots(figsize=(10, number_of_stations / 3)) for station, d in sorted(plt_dict.items(), reverse=True): pos += 1 for subset, color in colors.items(): @@ -834,7 +965,7 @@ class PlotAvailability(AbstractPlotClass): yticklabels.append(station) ax.set_ylim([height, number_of_stations + 1]) - ax.set_yticks(np.arange(len(plt_dict.keys()))+1+height/2) + ax.set_yticks(np.arange(len(plt_dict.keys())) + 1 + height / 2) ax.set_yticklabels(yticklabels) handles = [mpatches.Patch(color=c, label=k) for k, c in colors.items()] lgd = plt.legend(handles=handles, bbox_to_anchor=(0, 1, 1, 0.2), loc="lower center", ncol=len(handles)) diff --git a/src/plotting/tracker_plot.py b/src/plotting/tracker_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..20db5d9d9f22df548b1d499c4e8e0faa3fbfa1ee --- /dev/null +++ b/src/plotting/tracker_plot.py @@ -0,0 +1,379 @@ +from collections import OrderedDict + +import numpy as np +import os +from typing import Union, List, Optional, Dict + +from src.helpers import to_list + +from matplotlib import pyplot as plt, lines as mlines, ticker as ticker +from matplotlib.patches import Rectangle + + +class TrackObject: + + """ + A TrackObject can be used to create simple chains of objects. + + :param name: string or list of strings with a name describing the track object + :param stage: additional meta information (can be used to highlight different blocks inside a chain) + """ + + def __init__(self, name: Union[List[str], str], stage: str): + self.name = to_list(name) + self.stage = stage + self.precursor: Optional[List[TrackObject]] = None + self.successor: Optional[List[TrackObject]] = None + self.x: Optional[float] = None + self.y: Optional[float] = None + + def __repr__(self): + return str("/".join(self.name)) + + @property + def x(self): + """Get x value.""" + return self._x + + @x.setter + def x(self, value: float): + """Set x value.""" + self._x = value + + @property + def y(self): + """Get y value.""" + return self._y + + @y.setter + def y(self, value: float): + """Set y value.""" + self._y = value + + def add_precursor(self, precursor: "TrackObject"): + """Add a precursory track object.""" + if self.precursor is None: + self.precursor = [precursor] + else: + if precursor not in self.precursor: + self.precursor.append(precursor) + else: + return + precursor.add_successor(self) + + def add_successor(self, successor: "TrackObject"): + """Add a successive track object.""" + if self.successor is None: + self.successor = [successor] + else: + if successor not in self.successor: + self.successor.append(successor) + else: + return + successor.add_precursor(self) + + +class TrackChain: + + def __init__(self, track_list): + self.track_list = track_list + self.scopes = self.get_all_scopes(self.track_list) + self.dims = self.get_all_dims(self.scopes) + + def get_all_scopes(self, track_list) -> Dict: + """Return dictionary with all distinct variables as keys and its unique scopes as values.""" + dims = {} + for track_dict in track_list: # all stages + for track in track_dict.values(): # single stage, all variables + for k, v in track.items(): # single variable + scopes = self.get_unique_scopes(v) + if dims.get(k) is None: + dims[k] = scopes + else: + dims[k] = np.unique(scopes + dims[k]).tolist() + return OrderedDict(sorted(dims.items())) + + @staticmethod + def get_all_dims(scopes): + dims = {} + for k, v in scopes.items(): + dims[k] = len(v) + return dims + + def create_track_chain(self): + control = self.control_dict(self.scopes) + track_chain_dict = OrderedDict() + for track_dict in self.track_list: + stage, stage_track = list(track_dict.items())[0] + track_chain, control = self._create_track_chain(control, OrderedDict(sorted(stage_track.items())), stage) + control = self.clean_control(control) + track_chain_dict[stage] = track_chain + return track_chain_dict + + def _create_track_chain(self, control, sorted_track_dict, stage): + track_objects = [] + for variable, all_variable_tracks in sorted_track_dict.items(): + for track_details in all_variable_tracks: + method, scope = track_details["method"], track_details["scope"] + tr = TrackObject([variable, method, scope], stage) + control_obj = control[variable][scope] + if method == "set": + track_objects = self._add_set_object(track_objects, tr, control_obj) + elif method == "get": + track_objects, skip_control_update = self._add_get_object(track_objects, tr, control_obj, + control, scope, variable) + if skip_control_update is True: + continue + self._update_control(control, variable, scope, tr) + return track_objects, control + + @staticmethod + def _update_control(control, variable, scope, tr_obj): + control[variable][scope] = tr_obj + + @staticmethod + def _add_track_object(track_objects, tr_obj, prev_obj): + if tr_obj.stage != prev_obj.stage: + track_objects.append(prev_obj) + return track_objects + + def _add_precursor(self, track_objects, tr_obj, prev_obj): + tr_obj.add_precursor(prev_obj) + return self._add_track_object(track_objects, tr_obj, prev_obj) + + def _add_set_object(self, track_objects, tr_obj, control_obj): + if control_obj is not None: + track_objects = self._add_precursor(track_objects, tr_obj, control_obj) + else: + track_objects.append(tr_obj) + return track_objects + + def _recursive_decent(self, scope, control_obj_var): + scope = scope.rsplit(".", 1) + if len(scope) > 1: + scope = scope[0] + control_obj = control_obj_var[scope] + if control_obj is not None: + pre, candidate = control_obj, control_obj + while pre.precursor is not None and pre.name[1] != "set": + # change candidate on stage border + if pre.name[2] != pre.precursor[0].name[2]: + candidate = pre + pre = pre.precursor[0] + # correct pre if candidate is from same scope + if candidate.name[2] == pre.name[2]: + pre = candidate + return pre + else: + return self._recursive_decent(scope, control_obj_var) + + def _add_get_object(self, track_objects, tr_obj, control_obj, control, scope, variable): + skip_control_update = False + if control_obj is not None: + track_objects = self._add_precursor(track_objects, tr_obj, control_obj) + else: + pre = self._recursive_decent(scope, control[variable]) + if pre is not None: + track_objects = self._add_precursor(track_objects, tr_obj, pre) + else: + skip_control_update = True + return track_objects, skip_control_update + + @staticmethod + def control_dict(scopes): + """Create empty control dictionary with variables and scopes as keys and None as default for all values.""" + control = {} + for variable, scope_names in scopes.items(): + control[variable] = {} + for s in scope_names: + update = {s: None} + if len(control[variable].keys()) == 0: + control[variable] = update + else: + control[variable].update(update) + return control + + @staticmethod + def clean_control(control): + for k, v in control.items(): # var. scopes + for kv, vv in v.items(): # scope tr_obj + try: + if vv.precursor[0].name[2] != vv.name[2]: + control[k][kv] = None + except (TypeError, AttributeError): + pass + return control + + @staticmethod + def get_unique_scopes(track_list: List[Dict]) -> List[str]: + """Get list with all unique elements from input including general scope if missing.""" + scopes = [e["scope"] for e in track_list] + ["general"] + return np.unique(scopes).tolist() + + +class TrackPlot: + + def __init__(self, tracker_list, sparse_conn_mode=True, plot_folder: str = ".", skip_run_env=True, plot_name=None): + + self.width = 0.6 + self.height = 0.5 + self.space_intern_y = 0.2 + self.space_extern_y = 1 + self.space_intern_x = 0.4 + self.space_extern_x = 0.6 + self.y_pos = None + self.anchor = None + self.x_max = None + + track_chain_obj = TrackChain(tracker_list) + track_chain_dict = track_chain_obj.create_track_chain() + self.set_ypos_anchor(track_chain_obj.scopes, track_chain_obj.dims) + self.fig, self.ax = plt.subplots(figsize=(len(tracker_list) * 2, (self.anchor.max() - self.anchor.min()) / 3)) + self._plot(track_chain_dict, sparse_conn_mode, skip_run_env, plot_folder, plot_name) + + def _plot(self, track_chain_dict, sparse_conn_mode, skip_run_env, plot_folder, plot_name=None): + stages, v_lines = self.create_track_chain_plot(track_chain_dict, sparse_conn_mode=sparse_conn_mode, + skip_run_env=skip_run_env) + self.set_lims() + self.add_variable_names() + self.add_stages(v_lines, stages) + plt.tight_layout() + plot_name = "tracking.pdf" if plot_name is None else plot_name + plot_name = os.path.join(os.path.abspath(plot_folder), plot_name) + plt.savefig(plot_name, dpi=600) + + def line(self, start_x, end_x, y, color="darkgrey"): + """Draw grey horizontal connection line from start_x to end_x on y-pos.""" + # draw white border line + l = mlines.Line2D([start_x + self.width, end_x], [y + self.height / 2, y + self.height / 2], color="white", + linewidth=2.5) + self.ax.add_line(l) + # draw grey line + l = mlines.Line2D([start_x + self.width, end_x], [y + self.height / 2, y + self.height / 2], color=color, + linewidth=1.4) + self.ax.add_line(l) + + def step(self, start_x, end_x, start_y, end_y, color="black"): + """Draw black connection step line from start_xy to end_xy. Step is taken shortly before end position.""" + # adjust start and end by width height + start_x += self.width + start_y += self.height / 2 + end_y += self.height / 2 + step_x = end_x - (self.space_intern_x) / 2 # step is taken shortly before end + pos_x = [start_x, step_x, step_x, end_x] + pos_y = [start_y, start_y, end_y, end_y] + # draw white border line + l = mlines.Line2D(pos_x, pos_y, color="white", linewidth=2.5) + self.ax.add_line(l) + # draw black line + l = mlines.Line2D(pos_x, pos_y, color=color, linewidth=1.4) + self.ax.add_line(l) + + def rect(self, x, y, method="get"): + """Draw rectangle with lower left at (x,y), size equal to width/height and label/color according to method.""" + # draw rectangle + color = {"get": "orange"}.get(method, "lightblue") + r = Rectangle((x, y), self.width, self.height, color=color) + self.ax.add_artist(r) + # add label + rx, ry = r.get_xy() + cx = rx + r.get_width() / 2.0 + cy = ry + r.get_height() / 2.0 + self.ax.annotate(method, (cx, cy), color='w', weight='bold', fontsize=6, ha='center', va='center') + + def set_ypos_anchor(self, scopes, dims): + anchor = sum(dims.values()) + pos_dict = {} + d_y = 0 + for k, v in scopes.items(): + pos_dict[k] = {} + for e in v: + update = {e: anchor + d_y} + if len(pos_dict[k].keys()) == 0: + pos_dict[k] = update + else: + pos_dict[k].update(update) + d_y -= (self.space_intern_y + self.height) + d_y -= (self.space_extern_y - self.space_intern_y) + self.y_pos = pos_dict + self.anchor = np.array((d_y, self.height + self.space_extern_y)) + anchor + + def plot_track_chain(self, chain, y_pos, x_pos=0, prev=None, stage=None, sparse_conn_mode=False): + if (chain.successor is None) or (chain.stage == stage): + var, method, scope = chain.name + x, y = x_pos, y_pos[var][scope] + self.rect(x, y, method=method) + chain.x, chain.y = x, y + if prev is not None and prev[0] is not None: + if (sparse_conn_mode is True) and (method == "set"): + pass + else: + if y == prev[1]: + self.line(prev[0], x, prev[1]) + else: + self.step(prev[0], x, prev[1], y) + else: + x, y = chain.x, chain.y + + x_max = None + if chain.successor is not None: + for e in chain.successor: + if e.stage == stage: + shift = self.width + self.space_intern_x if chain.stage == e.stage else 0 + x_tmp = self.plot_track_chain(e, y_pos, x_pos + shift, prev=(x, y), + stage=stage, sparse_conn_mode=sparse_conn_mode) + x_max = np.nanmax(np.array([x_tmp, x_max], dtype=np.float64)) + else: + x_max = np.nanmax(np.array([x, x_max, x_pos], dtype=np.float64)) + else: + x_max = x + + return x_max + + def add_variable_names(self): + labels = [] + pos = [] + labels_major = [] + pos_major = [] + for k, v in self.y_pos.items(): + for kv, vv in v.items(): + if kv == "general": + labels_major.append(k) + pos_major.append(vv + self.height / 2) + else: + labels.append(kv.split(".", 1)[1]) + pos.append(vv + self.height / 2) + self.ax.tick_params(axis="y", which="major", labelsize="large") + self.ax.yaxis.set_major_locator(ticker.FixedLocator(pos_major)) + self.ax.yaxis.set_major_formatter(ticker.FixedFormatter(labels_major)) + self.ax.yaxis.set_minor_locator(ticker.FixedLocator(pos)) + self.ax.yaxis.set_minor_formatter(ticker.FixedFormatter(labels)) + + def add_stages(self, vlines, stages): + x_max = self.x_max + self.space_intern_x + self.width + for l in vlines: + self.ax.vlines(l, *self.anchor, "black", "dashed") + vlines = [0] + vlines + [x_max] + pos = [(vlines[i] + vlines[i+1]) / 2 for i in range(len(vlines)-1)] + self.ax.xaxis.set_major_locator(ticker.FixedLocator(pos)) + self.ax.xaxis.set_major_formatter(ticker.FixedFormatter(stages)) + + def create_track_chain_plot(self, track_chain_dict, sparse_conn_mode=True, skip_run_env=True): + x, x_max = 0, 0 + v_lines, stages = [], [] + for stage, track_chain in track_chain_dict.items(): + if stage == "RunEnvironment" and skip_run_env is True: + continue + if x > 0: + v_lines.append(x - self.space_extern_x / 2) + for e in track_chain: + x_max = max(x_max, self.plot_track_chain(e, self.y_pos, x_pos=x, stage=stage, sparse_conn_mode=sparse_conn_mode)) + x = x_max + self.space_extern_x + self.width + stages.append(stage) + self.x_max = x_max + return stages, v_lines + + def set_lims(self): + x_max = self.x_max + self.space_intern_x + self.width + self.ax.set_xlim((0, x_max)) + self.ax.set_ylim(self.anchor) diff --git a/src/plotting/training_monitoring.py b/src/plotting/training_monitoring.py index 7e656895c5eecdabe1ef26869b68fb9494ed4c8c..473b966ce52ee7e2885bc14beef2e68b8835b15e 100644 --- a/src/plotting/training_monitoring.py +++ b/src/plotting/training_monitoring.py @@ -1,7 +1,8 @@ +"""Plots to monitor training.""" + __author__ = 'Felix Kleinert, Lukas Leufen' __date__ = '2019-12-11' - from typing import Union, Dict, List import keras @@ -18,15 +19,18 @@ lr_object = Union[Dict, LearningRateDecay] class PlotModelHistory: """ - Plots history of all plot_metrics (default: loss) for a training event. For default plot_metric and val_plot_metric - are plotted. If further metrics are provided (name must somehow include the word `<plot_metric>`), this additional - information is added to the plot with an separate y-axis scale on the right side (shared for all additional - metrics). The plot is saved locally. For a proper saving behaviour, the parameter filename must include the absolute - path for the plot. + Plot history of all plot_metrics (default: loss) for a training event. + + For default plot_metric and val_plot_metric are plotted. If further metrics are provided (name must somehow include + the word `<plot_metric>`), this additional information is added to the plot with an separate y-axis scale on the + right side (shared for all additional metrics). The plot is saved locally. For a proper saving behaviour, the + parameter filename must include the absolute path for the plot. """ + def __init__(self, filename: str, history: history_object, plot_metric: str = "loss", main_branch: bool = False): """ - Sets attributes and create plot + Set attributes and create plot. + :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a format ending like .pdf or .png to work. :param history: the history object (or a dict with at least 'loss' and 'val_loss' as keys) to plot loss from @@ -47,16 +51,20 @@ class PlotModelHistory: plot_metric = "mean_squared_error" elif plot_metric.lower() == "mae": plot_metric = "mean_absolute_error" - available_keys = [k for k in history.keys() if plot_metric in k and ("main" in k.lower() if main_branch else True)] + available_keys = [k for k in history.keys() if + plot_metric in k and ("main" in k.lower() if main_branch else True)] available_keys.sort(key=len) return available_keys[0] def _filter_columns(self, history: Dict) -> List[str]: """ - Select only columns named like %<plot_metric>%. The default metrics '<plot_metric>' and 'val_<plot_metric>' are - also removed. + Select only columns named like %<plot_metric>%. + + The default metrics '<plot_metric>' and 'val_<plot_metric>' are removed too. + :param history: a dict with at least '<plot_metric>' and 'val_<plot_metric>' as keys (can be derived from keras History.history) + :return: filtered columns including all plot_metric variations except <plot_metric> and val_<plot_metric>. """ cols = list(filter(lambda x: self._plot_metric in x, history.keys())) @@ -69,8 +77,11 @@ class PlotModelHistory: def _plot(self, filename: str) -> None: """ - Actual plot routine. Plots <plot_metric> and val_<plot_metric> as default. If more plot_metrics are provided, - they will be added with an additional yaxis on the right side. The plot is saved in filename. + Create plot. + + Plots <plot_metric> and val_<plot_metric> as default. If more plot_metrics are provided, they will be added with + an additional yaxis on the right side. The plot is saved in filename. + :param filename: name (including total path) of the plot to save. """ ax = self._data[[self._plot_metric, f"val_{self._plot_metric}"]].plot(linewidth=0.7) @@ -86,12 +97,16 @@ class PlotModelHistory: class PlotModelLearningRate: """ - Plots the behaviour of the learning rate in dependence of the number of epochs. The plot is saved locally as pdf. - For a proper saving behaviour, the parameter filename must include the absolute path for the plot. + Plot the behaviour of the learning rate in dependence of the number of epochs. + + The plot is saved locally as pdf. For a proper saving behaviour, the parameter filename must include the absolute + path for the plot. """ + def __init__(self, filename: str, lr_sc: lr_object): """ - Sets attributes and create plot + Set attributes and create plot. + :param filename: saving name of the plot to create (preferably absolute path if possible), the filename needs a format ending like .pdf or .png to work. :param lr_sc: the learning rate object (or a dict with `lr` as key) to plot from @@ -103,7 +118,10 @@ class PlotModelLearningRate: def _plot(self, filename: str) -> None: """ - Actual plot routine. Plots the learning rate in dependence of epoch. + Create plot. + + Plot the learning rate in dependence of epoch. + :param filename: name (including total path) of the plot to save. """ ax = self._data.plot(linewidth=0.7) diff --git a/src/run_modules/__init__.py b/src/run_modules/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..f06d627f6ff482e11c6d1c520fa59197feb831cd 100644 --- a/src/run_modules/__init__.py +++ b/src/run_modules/__init__.py @@ -0,0 +1,6 @@ +from src.run_modules.experiment_setup import ExperimentSetup +from src.run_modules.model_setup import ModelSetup +from src.run_modules.post_processing import PostProcessing +from src.run_modules.pre_processing import PreProcessing +from src.run_modules.run_environment import RunEnvironment +from src.run_modules.training import Training diff --git a/src/run_modules/experiment_setup.py b/src/run_modules/experiment_setup.py index b04c0e2ac2a2262c92c5f7149014206d0d390e18..112c3ff4473a844a488ca462fd06400cfdfa3c5b 100644 --- a/src/run_modules/experiment_setup.py +++ b/src/run_modules/experiment_setup.py @@ -1,14 +1,14 @@ __author__ = "Lukas Leufen, Felix Kleinert" __date__ = '2019-11-15' - import argparse import logging import os -from typing import Union, Dict, Any +from typing import Union, Dict, Any, List import socket +from src.configuration import path_config from src import helpers from src.run_modules.run_environment import RunEnvironment @@ -31,35 +31,231 @@ DEFAULT_HPC_HOST_LIST = ["jw", "hdfmlc"] # first part of node names for Juwels class ExperimentSetup(RunEnvironment): """ - params: - trainable: Train new model if true, otherwise try to load existing model + Set up the model. + + Schedule of experiment setup: + * set up experiment path + * set up data path (according to host system) + * set up forecast, bootstrap and plot path (inside experiment path) + * set all parameters given in args (or use default values) + * check target variable + * check `variables` and `statistics_per_var` parameter for consistency + + Sets + * `data_path` [.] + * `create_new_model` [.] + * `bootstrap_path` [.] + * `trainable` [.] + * `fraction_of_training` [.] + * `extreme_values` [train] + * `extremes_on_right_tail_only` [train] + * `upsampling` [train] + * `permute_data` [train] + * `experiment_name` [.] + * `experiment_path` [.] + * `plot_path` [.] + * `forecast_path` [.] + * `stations` [.] + * `network` [.] + * `station_type` [.] + * `statistics_per_var` [.] + * `variables` [.] + * `start` [.] + * `end` [.] + * `window_history_size` [.] + * `overwrite_local_data` [preprocessing] + * `sampling` [.] + * `transformation` [., preprocessing] + * `target_var` [.] + * `target_dim` [.] + * `window_lead_time` [.] + + # interpolation + self._set_param("dimensions", dimensions, default={'new_index': ['datetime', 'Stations']}) + self._set_param("interpolate_dim", interpolate_dim, default='datetime') + self._set_param("interpolate_method", interpolate_method, default='linear') + self._set_param("limit_nan_fill", limit_nan_fill, default=1) + + # train set parameters + self._set_param("start", train_start, default="1997-01-01", scope="train") + self._set_param("end", train_end, default="2007-12-31", scope="train") + self._set_param("min_length", train_min_length, default=90, scope="train") + + # validation set parameters + self._set_param("start", val_start, default="2008-01-01", scope="val") + self._set_param("end", val_end, default="2009-12-31", scope="val") + self._set_param("min_length", val_min_length, default=90, scope="val") + + # test set parameters + self._set_param("start", test_start, default="2010-01-01", scope="test") + self._set_param("end", test_end, default="2017-12-31", scope="test") + self._set_param("min_length", test_min_length, default=90, scope="test") + + # train_val set parameters + self._set_param("start", self.data_store.get("start", "train"), scope="train_val") + self._set_param("end", self.data_store.get("end", "val"), scope="train_val") + train_val_min_length = sum([self.data_store.get("min_length", s) for s in ["train", "val"]]) + self._set_param("min_length", train_val_min_length, default=180, scope="train_val") + + # use all stations on all data sets (train, val, test) + self._set_param("use_all_stations_on_all_data_sets", use_all_stations_on_all_data_sets, default=True) + + # set post-processing instructions + self._set_param("evaluate_bootstraps", evaluate_bootstraps, scope="general.postprocessing") + create_new_bootstraps = max([self.data_store.get("trainable", "general"), create_new_bootstraps or False]) + self._set_param("create_new_bootstraps", create_new_bootstraps, scope="general.postprocessing") + self._set_param("number_of_bootstraps", number_of_bootstraps, default=20, scope="general.postprocessing") + self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing") + + # check variables, statistics and target variable + self._check_target_var() + self._compare_variables_and_statistics() + + + + + + + + + + Creates + * plot of model architecture in `<model_name>.pdf` + + :param parser_args: argument parser, currently only accepting ``experiment_date argument`` to be used for + experiment's name and path creation. Final experiment's name is derived from given name and the time series + sampling as `<name>_network_<sampling>/` . All interim and final results, logging, plots, ... of this run are + stored in this directory if not explicitly provided in kwargs. Only the data itself and data for bootstrap + investigations are stored outside this structure. + :param stations: list of stations or single station to use in experiment. If not provided, stations are set to + :py:const:`default stations <DEFAULT_STATIONS>`. + :param network: name of network to restrict to use only stations from this measurement network. Default is + `AIRBASE` . + :param station_type: restrict network type to one of TOAR's categories (background, traffic, industrial). Default is + `None` to use all categories. + :param variables: list of all variables to use. Valid names can be found in + `Section 2.1 Parameters <https://join.fz-juelich.de/services/rest/surfacedata/>`_. If not provided, this + parameter is filled with keys from ``statistics_per_var``. + :param statistics_per_var: dictionary with statistics to use for variables (if data is daily and loaded from JOIN). + If not provided, :py:const:`default statistics <DEFAULT_VAR_ALL_DICT>` is applied. ``statistics_per_var`` is + compared with given ``variables`` and unused variables are removed. Therefore, statistics at least need to + provide all variables from ``variables``. For more details on available statistics, we refer to + `Section 3.3 List of statistics/metrics for stats service <https://join.fz-juelich.de/services/rest/surfacedata/>`_ + in the JOIN documentation. Valid parameter names can be found in + `Section 2.1 Parameters <https://join.fz-juelich.de/services/rest/surfacedata/>`_. + :param start: start date of overall data (default `"1997-01-01"`) + :param end: end date of overall data (default `"2017-12-31"`) + :param window_history_size: number of time steps to use for input data (default 13). Time steps `t_0 - w` to `t_0` + are used as input data (therefore actual data size is `w+1`). + :param target_var: target variable to predict by model, currently only a single target variable is supported. + Because this framework was originally designed to predict ozone, default is `"o3"`. + :param target_dim: dimension of target variable (default `"variables"`). + :param window_lead_time: number of time steps to predict by model (default 3). Time steps `t_0+1` to `t_0+w` are + predicted. + :param dimensions: + :param interpolate_dim: + :param interpolate_method: + :param limit_nan_fill: + :param train_start: + :param train_end: + :param val_start: + :param val_end: + :param test_start: + :param test_end: + :param use_all_stations_on_all_data_sets: + :param trainable: train a new model from scratch or resume training with existing model if `True` (default) or + freeze loaded model and do not perform any modification on it. ``trainable`` is set to `True` if + ``create_new_model`` is `True`. + :param fraction_of_train: given value is used to split between test data and train data (including validation data). + The value of ``fraction_of_train`` must be in `(0, 1)` but is recommended to be in the interval `[0.6, 0.9]`. + Default value is `0.8`. Split between train and validation is fixed to 80% - 20% and currently not changeable. + :param experiment_path: + :param plot_path: path to save all plots. If left blank, this will be included in the experiment path (recommended). + Otherwise customise the location to save all plots. + :param forecast_path: path to save all forecasts in files. It is recommended to leave this parameter blank, all + forecasts will be the directory `forecasts` inside the experiment path (default). For customisation, add your + path here. + :param overwrite_local_data: Reload input and target data from web and replace local data if `True` (default + `False`). + :param sampling: set temporal sampling rate of data. You can choose from daily (default), monthly, seasonal, + vegseason, summer and annual for aggregated values and hourly for the actual values. Note, that hourly values on + JOIN are currently not accessible from outside. To access this data, you need to add your personal token in + :py:mod:`join settings <src.configuration.join_settings>` and make sure to untrack this file! + :param create_new_model: determine whether a new model will be created (`True`, default) or not (`False`). If this + parameter is set to `False`, make sure, that a suitable model already exists in the experiment path. This model + must fit in terms of input and output dimensions as well as ``window_history_size`` and ``window_lead_time`` and + must be implemented as a :py:mod:`model class <src.model_modules.model_class>` and imported in + :py:mod:`model setup <src.run_modules.model_setup>`. If ``create_new_model`` is `True`, parameter ``trainable`` + is automatically set to `True` too. + :param bootstrap_path: + :param permute_data_on_training: shuffle train data individually for each station if `True`. This is performed each + iteration for new, so that each sample very likely differs from epoch to epoch. Train data permutation is + disabled (`False`) per default. If the case of extreme value manifolding, data permutation is enabled anyway. + :param transformation: set transformation options in dictionary style. All information about transformation options + can be found in :py:meth:`setup transformation <src.data_handling.data_generator.DataGenerator.setup_transformation>`. + If no transformation is provided, all options are set to :py:const:`default transformation <DEFAULT_TRANSFORMATION>`. + :param train_min_length: + :param val_min_length: + :param test_min_length: + :param extreme_values: augment target samples with values of lower occurrences indicated by its normalised + deviation from mean by manifolding. These extreme values need to be indicated by a list of thresholds. For + each entry in this list, all values outside an +/- interval will be added in the training (and only the + training) set for a second time to the sample. If multiple valus are given, a sample is added for each + exceedence once. E.g. a sample with `value=2.5` occurs twice in the training set for given + `extreme_values=[2, 3]`, whereas a sample with `value=5` occurs three times in the training set. For default, + upsampling of extreme values is disabled (`None`). Upsamling can be modified to manifold only values that are + actually larger than given values from ``extreme_values`` (apply only on right side of distribution) by using + ``extremes_on_right_tail_only``. This can be useful for positive skew variables. + :param extremes_on_right_tail_only: applies only if ``extreme_values`` are given. If ``extremes_on_right_tail_only`` + is `True`, only manifold values that are larger than given extremes (apply upsampling only on right side of + distribution). In default mode, this is set to `False` to manifold extremes on both sides. + :param evaluate_bootstraps: + :param plot_list: + :param number_of_bootstraps: + :param create_new_bootstraps: + :param data_path: path to find and store meteorological and environmental / air quality data. Leave this parameter + empty, if your host system is known and a suitable path was already hardcoded in the program (see + :py:func:`prepare host <src.configuration.path_config.prepare_host>`). + """ - def __init__(self, parser_args=None, stations=None, network=None, station_type=None, variables=None, - statistics_per_var=None, start=None, end=None, window_history_size=None, target_var="o3", target_dim=None, - window_lead_time=None, dimensions=None, interpolate_dim=None, interpolate_method=None, + def __init__(self, + parser_args=None, + stations: Union[str, List[str]] = None, + network: str = None, + station_type: str = None, + variables: Union[str, List[str]] = None, + statistics_per_var: Dict = None, + start: str = None, + end: str = None, + window_history_size: int = None, + target_var="o3", + target_dim=None, + window_lead_time: int = None, + dimensions=None, + interpolate_dim=None, + interpolate_method=None, limit_nan_fill=None, train_start=None, train_end=None, val_start=None, val_end=None, test_start=None, - test_end=None, use_all_stations_on_all_data_sets=True, trainable=None, fraction_of_train=None, - experiment_path=None, plot_path=None, forecast_path=None, overwrite_local_data=None, sampling="daily", - create_new_model=None, bootstrap_path=None, permute_data_on_training=False, transformation=None, - train_min_length=None, val_min_length=None, test_min_length=None, extreme_values=None, - extremes_on_right_tail_only=None, evaluate_bootstraps=True, plot_list=None, number_of_bootstraps=None, - create_new_bootstraps=None, data_path=None, login_nodes=None, hpc_hosts=None): + test_end=None, use_all_stations_on_all_data_sets=True, trainable: bool = None, fraction_of_train: float = None, + experiment_path=None, plot_path: str = None, forecast_path: str = None, overwrite_local_data: bool = None, sampling: str = "daily", + create_new_model: bool = None, bootstrap_path=None, permute_data_on_training: bool = None, transformation=None, + train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None, + extremes_on_right_tail_only: bool = None, evaluate_bootstraps=True, plot_list=None, number_of_bootstraps=None, + create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None): # create run framework super().__init__() # experiment setup - self._set_param("data_path", data_path, default=helpers.prepare_host(sampling=sampling)) + self._set_param("data_path", path_config.prepare_host(data_path=data_path, sampling=sampling)) self._set_param("hostname", helpers.get_host()) - # self._set_param("hostname", "jwc0123") self._set_param("hpc_hosts", hpc_hosts, default=DEFAULT_HPC_HOST_LIST + DEFAULT_HPC_LOGIN_LIST) self._set_param("login_nodes", login_nodes, default=DEFAULT_HPC_LOGIN_LIST) self._set_param("create_new_model", create_new_model, default=True) if self.data_store.get("create_new_model"): trainable = True data_path = self.data_store.get("data_path") - bootstrap_path = helpers.set_bootstrap_path(bootstrap_path, data_path, sampling) + bootstrap_path = path_config.set_bootstrap_path(bootstrap_path, data_path, sampling) self._set_param("bootstrap_path", bootstrap_path) self._set_param("trainable", trainable, default=True) self._set_param("fraction_of_training", fraction_of_train, default=0.8) @@ -67,25 +263,26 @@ class ExperimentSetup(RunEnvironment): self._set_param("extremes_on_right_tail_only", extremes_on_right_tail_only, default=False, scope="train") self._set_param("upsampling", extreme_values is not None, scope="train") upsampling = self.data_store.get("upsampling", "train") - self._set_param("permute_data", max([permute_data_on_training, upsampling]), scope="train") + permute_data = False if permute_data_on_training is None else permute_data_on_training + self._set_param("permute_data", permute_data or upsampling, scope="train") # set experiment name exp_date = self._get_parser_args(parser_args).get("experiment_date") - exp_name, exp_path = helpers.set_experiment_name(experiment_date=exp_date, experiment_path=experiment_path, - sampling=sampling) + exp_name, exp_path = path_config.set_experiment_name(experiment_name=exp_date, experiment_path=experiment_path, + sampling=sampling) self._set_param("experiment_name", exp_name) self._set_param("experiment_path", exp_path) - helpers.check_path_and_create(self.data_store.get("experiment_path")) + path_config.check_path_and_create(self.data_store.get("experiment_path")) # set plot path default_plot_path = os.path.join(exp_path, "plots") self._set_param("plot_path", plot_path, default=default_plot_path) - helpers.check_path_and_create(self.data_store.get("plot_path")) + path_config.check_path_and_create(self.data_store.get("plot_path")) # set results path default_forecast_path = os.path.join(exp_path, "forecasts") self._set_param("forecast_path", forecast_path, default_forecast_path) - helpers.check_path_and_create(self.data_store.get("forecast_path")) + path_config.check_path_and_create(self.data_store.get("forecast_path")) # setup for data self._set_param("stations", stations, default=DEFAULT_STATIONS) @@ -93,7 +290,6 @@ class ExperimentSetup(RunEnvironment): self._set_param("station_type", station_type, default=None) self._set_param("statistics_per_var", statistics_per_var, default=DEFAULT_VAR_ALL_DICT) self._set_param("variables", variables, default=list(self.data_store.get("statistics_per_var").keys())) - self._compare_variables_and_statistics() self._set_param("start", start, default="1997-01-01") self._set_param("end", end, default="2017-12-31") self._set_param("window_history_size", window_history_size, default=13) @@ -104,7 +300,6 @@ class ExperimentSetup(RunEnvironment): # target self._set_param("target_var", target_var, default="o3") - self._check_target_var() self._set_param("target_dim", target_dim, default='variables') self._set_param("window_lead_time", window_lead_time, default=3) @@ -145,7 +340,12 @@ class ExperimentSetup(RunEnvironment): self._set_param("number_of_bootstraps", number_of_bootstraps, default=20, scope="general.postprocessing") self._set_param("plot_list", plot_list, default=DEFAULT_PLOT_LIST, scope="general.postprocessing") + # check variables, statistics and target variable + self._check_target_var() + self._compare_variables_and_statistics() + def _set_param(self, param: str, value: Any, default: Any = None, scope: str = "general") -> None: + """Set given parameter and log in debug.""" if value is None and default is not None: value = default self.data_store.set(param, value, scope) @@ -154,8 +354,10 @@ class ExperimentSetup(RunEnvironment): @staticmethod def _get_parser_args(args: Union[Dict, argparse.Namespace]) -> Dict: """ - Transform args to dict if given as argparse.Namespace + Transform args to dict if given as argparse.Namespace. + :param args: either a dictionary or an argument parser instance + :return: dictionary with all arguments """ if isinstance(args, argparse.Namespace): @@ -166,30 +368,38 @@ class ExperimentSetup(RunEnvironment): return {} def _compare_variables_and_statistics(self): + """ + Compare variables and statistics. + + * raise error, if a variable is missing. + * remove unused variables from statistics. + """ logging.debug("check if all variables are included in statistics_per_var") stat = self.data_store.get("statistics_per_var") var = self.data_store.get("variables") + # too less entries, raise error if not set(var).issubset(stat.keys()): missing = set(var).difference(stat.keys()) raise ValueError(f"Comparison of given variables and statistics_per_var show that not all requested " f"variables are part of statistics_per_var. Please add also information on the missing " f"statistics for the variables: {missing}") + # too much entries, remove unused + target_var = helpers.to_list(self.data_store.get("target_var")) + unused_vars = set(stat.keys()).difference(set(var).union(target_var)) + if len(unused_vars) > 0: + logging.info(f"There are unused keys in statistics_per_var. Therefore remove keys: {unused_vars}") + stat_new = helpers.remove_items(stat, list(unused_vars)) + self._set_param("statistics_per_var", stat_new) def _check_target_var(self): + """Check if target variable is in statistics_per_var dictionary.""" target_var = helpers.to_list(self.data_store.get("target_var")) stat = self.data_store.get("statistics_per_var") var = self.data_store.get("variables") if not set(target_var).issubset(stat.keys()): raise ValueError(f"Could not find target variable {target_var} in statistics_per_var.") - unused_vars = set(stat.keys()).difference(set(var).union(target_var)) - if len(unused_vars) > 0: - logging.info(f"There are unused keys in statistics_per_var. Therefore remove keys: {unused_vars}") - stat_new = helpers.dict_pop(stat, list(unused_vars)) - self._set_param("statistics_per_var", stat_new) - if __name__ == "__main__": - formatter = '%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]' logging.basicConfig(format=formatter, level=logging.DEBUG) diff --git a/src/run_modules/model_setup.py b/src/run_modules/model_setup.py index e8259b2847ea4ede1b365f49778f019c004fa7f1..804c4b9403e3b61ca61522cfa9a56588b4776be5 100644 --- a/src/run_modules/model_setup.py +++ b/src/run_modules/model_setup.py @@ -1,7 +1,8 @@ +"""Model setup module.""" + __author__ = "Lukas Leufen, Felix Kleinert" __date__ = '2019-12-02' - import logging import os @@ -17,10 +18,44 @@ from src.run_modules.run_environment import RunEnvironment class ModelSetup(RunEnvironment): + """ + Set up the model. + + Schedule of model setup: + #. set channels (from variables dimension) + #. build imported model + #. plot model architecture + #. load weights if enabled (e.g. to resume a training) + #. set callbacks and checkpoint + #. compile model + + Required objects [scope] from data store: + * `experiment_path` [.] + * `experiment_name` [.] + * `trainable` [.] + * `create_new_model` [.] + * `generator` [train] + * `window_lead_time` [.] + * `window_history_size` [.] + + Optional objects + * `lr_decay` [model] + + Sets + * `channels` [model] + * `model` [model] + * `hist` [model] + * `callbacks` [model] + * `model_name` [model] + * all settings from model class like `dropout_rate`, `initial_lr`, `batch_size`, and `optimizer` [model] + + Creates + * plot of model architecture `<model_name>.pdf` + + """ def __init__(self): - - # create run framework + """Initialise and run model setup.""" super().__init__() self.model = None path = self.data_store.get("experiment_path") @@ -56,6 +91,7 @@ class ModelSetup(RunEnvironment): self.compile_model() def _set_channels(self): + """Set channels as number of variables of train generator.""" channels = self.data_store.get("generator", "train")[0][0].shape[-1] self.data_store.set("channels", channels, self.scope) @@ -70,14 +106,15 @@ class ModelSetup(RunEnvironment): def _set_callbacks(self): """ - Set all callbacks for the training phase. Add all callbacks with the .add_callback statement. Finally, the - advanced model checkpoint is added. + Set all callbacks for the training phase. + + Add all callbacks with the .add_callback statement. Finally, the advanced model checkpoint is added. """ - lr = self.data_store.get_default("lr_decay", scope="model", default=None) + lr = self.data_store.get_default("lr_decay", scope=self.scope, default=None) hist = HistoryAdvanced() self.data_store.set("hist", hist, scope="model") callbacks = CallbackHandler() - if lr: + if lr is not None: callbacks.add_callback(lr, self.callbacks_name % "lr", "lr") callbacks.add_callback(hist, self.callbacks_name % "hist", "hist") callbacks.create_model_checkpoint(filepath=self.checkpoint_name, verbose=1, monitor='val_loss', @@ -85,6 +122,7 @@ class ModelSetup(RunEnvironment): self.data_store.set("callbacks", callbacks, self.scope) def load_weights(self): + """Try to load weights from existing model or skip if not possible.""" try: self.model.load_weights(self.model_name) logging.info(f"reload weights from model {self.model_name} ...") @@ -92,18 +130,21 @@ class ModelSetup(RunEnvironment): logging.info('no weights to reload...') def build_model(self): + """Build model using window_history_size, window_lead_time and channels from data store.""" args_list = ["window_history_size", "window_lead_time", "channels"] args = self.data_store.create_args_dict(args_list, self.scope) self.model = MyModel(**args) self.get_model_settings() def get_model_settings(self): + """Load all model settings and store in data store.""" model_settings = self.model.get_settings() - self.data_store.set_args_from_dict(model_settings, self.scope) + self.data_store.set_from_dict(model_settings, self.scope, log=True) self.model_name = self.model_name % self.data_store.get_default("model_name", self.scope, "my_model") self.data_store.set("model_name", self.model_name, self.scope) def plot_model(self): # pragma: no cover + """Plot model architecture as `<model_name>.pdf`.""" with tf.device("/cpu:0"): file_name = f"{self.model_name.rsplit('.', 1)[0]}.pdf" keras.utils.plot_model(self.model, to_file=file_name, show_shapes=True, show_layer_names=True) diff --git a/src/run_modules/post_processing.py b/src/run_modules/post_processing.py index bc3cdf2653aed86a00a139d963a26d826131b5b6..143e2908b6398412211494b02bf84c3c741c18b2 100644 --- a/src/run_modules/post_processing.py +++ b/src/run_modules/post_processing.py @@ -1,35 +1,66 @@ +"""Post-processing module.""" + __author__ = "Lukas Leufen, Felix Kleinert" __date__ = '2019-12-11' - import inspect import logging import os +from typing import Dict, Tuple, Union, List import keras import numpy as np import pandas as pd import xarray as xr -from src import statistics -from src.data_handling.data_distributor import Distributor -from src.data_handling.data_generator import DataGenerator -from src.data_handling.bootstraps import BootStraps -from src.datastore import NameNotFoundInDataStore -from src.helpers import TimeTracking +from src.data_handling import BootStraps, Distributor, DataGenerator, DataPrep +from src.helpers.datastore import NameNotFoundInDataStore +from src.helpers import TimeTracking, statistics from src.model_modules.linear_model import OrdinaryLeastSquaredModel from src.model_modules.model_class import AbstractModelClass from src.plotting.postprocessing_plotting import PlotMonthlySummary, PlotStationMap, PlotClimatologicalSkillScore, \ PlotCompetitiveSkillScore, PlotTimeSeries, PlotBootstrapSkillScore, PlotAvailability, PlotConditionalQuantiles -# from src.plotting.postprocessing_plotting import plot_conditional_quantiles from src.run_modules.run_environment import RunEnvironment -from typing import Dict - class PostProcessing(RunEnvironment): + """ + Perform post-processing for performance evaluation. + + Schedule of post-processing: + #. train a ordinary least squared model (ols) for reference + #. create forecasts for nn, ols, and persistence + #. evaluate feature importance with bootstrapped predictions + #. calculate skill scores + #. create plots + + Required objects [scope] from data store: + * `best_model` [.] or locally saved model plus `model_name` [model] and `model` [model] + * `generator` [train, val, test, train_val] + * `forecast_path` [.] + * `plot_path` [postprocessing] + * `experiment_path` [.] + * `target_var` [.] + * `sampling` [.] + * `window_lead_time` [.] + * `evaluate_bootstraps` [postprocessing] and if enabled: + + * `create_new_bootstraps` [postprocessing] + * `bootstrap_path` [postprocessing] + * `number_of_bootstraps` [postprocessing] + + Optional objects + * `batch_size` [model] + + Creates + * forecasts in `forecast_path` if enabled + * bootstraps in `bootstrap_path` if enabled + * plots in `plot_path` + + """ def __init__(self): + """Initialise and run post-processing.""" super().__init__() self.model: keras.Model = self._load_model() self.ols_model = None @@ -47,19 +78,23 @@ class PostProcessing(RunEnvironment): self._run() def _run(self): + # ols model with TimeTracking(): self.train_ols_model() logging.info("take a look on the next reported time measure. If this increases a lot, one should think to " "skip train_ols_model() whenever it is possible to save time.") + + # forecasts with TimeTracking(): self.make_prediction() logging.info("take a look on the next reported time measure. If this increases a lot, one should think to " "skip make_prediction() whenever it is possible to save time.") + self.calculate_test_score() # bootstraps - if self.data_store.get("evaluate_bootstraps", "general.postprocessing"): + if self.data_store.get("evaluate_bootstraps", "postprocessing"): with TimeTracking(name="calculate bootstraps"): - create_new_bootstraps = self.data_store.get("create_new_bootstraps", "general.postprocessing") + create_new_bootstraps = self.data_store.get("create_new_bootstraps", "postprocessing") self.bootstrap_postprocessing(create_new_bootstraps) # skill scores @@ -71,8 +106,13 @@ class PostProcessing(RunEnvironment): def bootstrap_postprocessing(self, create_new_bootstraps: bool, _iter: int = 0) -> None: """ - Create skill scores of bootstrapped data. Also creates these bootstraps if create_new_bootstraps is true or a - failure occurred during skill score calculation. Sets class attribute bootstrap_skill_scores. + Calculate skill scores of bootstrapped data. + + Create bootstrapped data if create_new_bootstraps is true or a failure occurred during skill score calculation + (this will happen by default, if no bootstrapped data is available locally). Set class attribute + bootstrap_skill_scores. This method is implemented in a recursive fashion, but is only allowed to call itself + once. + :param create_new_bootstraps: calculate all bootstrap predictions and overwrite already available predictions :param _iter: internal counter to reduce unnecessary recursive calls (maximum number is 2, otherwise something went wrong). @@ -90,15 +130,17 @@ class PostProcessing(RunEnvironment): def create_bootstrap_forecast(self) -> None: """ - Creates the bootstrapped predictions for all stations and variables. These forecasts are saved in bootstrap_path - with the names `bootstraps_{var}_{station}.nc` and `bootstraps_labels_{station}.nc`. + Create bootstrapped predictions for all stations and variables. + + These forecasts are saved in bootstrap_path with the names `bootstraps_{var}_{station}.nc` and + `bootstraps_labels_{station}.nc`. """ # forecast with TimeTracking(name=inspect.stack()[0].function): # extract all requirements from data store bootstrap_path = self.data_store.get("bootstrap_path") forecast_path = self.data_store.get("forecast_path") - number_of_bootstraps = self.data_store.get("number_of_bootstraps", "general.postprocessing") + number_of_bootstraps = self.data_store.get("number_of_bootstraps", "postprocessing") # set bootstrap class bootstraps = BootStraps(self.test_data, bootstrap_path, number_of_bootstraps) @@ -132,12 +174,14 @@ class PostProcessing(RunEnvironment): def calculate_bootstrap_skill_scores(self) -> Dict[str, xr.DataArray]: """ + Calculate skill score of bootstrapped variables. + Use already created bootstrap predictions and the original predictions (the not-bootstrapped ones) and calculate skill scores for the bootstraps. The result is saved as a xarray DataArray in a dictionary structure separated for each station (keys of dictionary). + :return: The result dictionary with station-wise skill scores """ - with TimeTracking(name=inspect.stack()[0].function): # extract all requirements from data store bootstrap_path = self.data_store.get("bootstrap_path") @@ -157,7 +201,7 @@ class PostProcessing(RunEnvironment): shape = labels.shape # get original forecasts - orig = bootstraps.get_orig_prediction(forecast_path, f"forecasts_norm_{station}_test.nc").reshape(shape) + orig = bootstraps.get_orig_prediction(forecast_path, f"forecasts_norm_{station}_test.nc").reshape(shape) coords = (range(shape[0]), range(1, shape[1] + 1), ["orig"]) orig = xr.DataArray(orig, coords=coords, dims=["index", "ahead", "type"]) @@ -170,24 +214,47 @@ class PostProcessing(RunEnvironment): boot_scores = [] for ahead in range(1, window_lead_time + 1): data = boot_data.sel(ahead=ahead) - boot_scores.append(skill_scores.general_skill_score(data, forecast_name=boot, reference_name="orig")) + boot_scores.append( + skill_scores.general_skill_score(data, forecast_name=boot, reference_name="orig")) skill.loc[boot] = np.array(boot_scores) # collect all results in single dictionary score[station] = xr.DataArray(skill, dims=["boot_var", "ahead"]) return score - def _load_model(self): + def _load_model(self) -> keras.models: + """ + Load NN model either from data store or from local path. + + :return: the model + """ try: model = self.data_store.get("best_model") except NameNotFoundInDataStore: - logging.info("no model saved in data store. trying to load model from experiment path") + logging.info("No model was saved in data store. Try to load model from experiment path.") model_name = self.data_store.get("model_name", "model") model_class: AbstractModelClass = self.data_store.get("model", "model") model = keras.models.load_model(model_name, custom_objects=model_class.custom_objects) return model def plot(self): + """ + Create all plots. + + Plots are defined in experiment set up by `plot_list`. As default, all (following) plots are enabled: + + * :py:class:`PlotBootstrapSkillScore <src.plotting.postprocessing_plotting.PlotBootstrapSkillScore>` + * :py:class:`PlotConditionalQuantiles <src.plotting.postprocessing_plotting.PlotConditionalQuantiles>` + * :py:class:`PlotStationMap <src.plotting.postprocessing_plotting.PlotStationMap>` + * :py:class:`PlotMonthlySummary <src.plotting.postprocessing_plotting.PlotMonthlySummary>` + * :py:class:`PlotClimatologicalSkillScore <src.plotting.postprocessing_plotting.PlotClimatologicalSkillScore>` + * :py:class:`PlotCompetitiveSkillScore <src.plotting.postprocessing_plotting.PlotCompetitiveSkillScore>` + * :py:class:`PlotTimeSeries <src.plotting.postprocessing_plotting.PlotTimeSeries>` + * :py:class:`PlotAvailability <src.plotting.postprocessing_plotting.PlotAvailability>` + + .. note:: Bootstrap plots are only created if bootstraps are evaluated. + + """ logging.debug("Run plotting routines...") path = self.data_store.get("forecast_path") @@ -222,21 +289,27 @@ class PostProcessing(RunEnvironment): PlotAvailability(avail_data, plot_folder=self.plot_path) def calculate_test_score(self): + """Evaluate test score of model and save locally.""" test_score = self.model.evaluate_generator(generator=self.test_data_distributed.distribute_on_batches(), use_multiprocessing=False, verbose=0, steps=1) - logging.info(f"test score = {test_score}") - self._save_test_score(test_score) - - def _save_test_score(self, score): path = self.data_store.get("experiment_path") - with open(os.path.join(path, "test_scores.txt")) as f: - for index, item in enumerate(score): - f.write(f"{self.model.metrics[index]}, {item}\n") + with open(os.path.join(path, "test_scores.txt"), "a") as f: + for index, item in enumerate(test_score): + logging.info(f"{self.model.metrics_names[index]}, {item}") + f.write(f"{self.model.metrics_names[index]}, {item}\n") def train_ols_model(self): + """Train ordinary least squared model on train data.""" self.ols_model = OrdinaryLeastSquaredModel(self.train_data) def make_prediction(self): + """ + Create predictions for NN, OLS, and persistence and add true observation as reference. + + Predictions are filled in an array with full index range. Therefore, predictions can have missing values. All + predictions for a single station are stored locally under `<forecast/forecast_norm>_<station>_test.nc` and can + be found inside `forecast_path`. + """ logging.debug("start make_prediction") for i, _ in enumerate(self.test_data): data = self.test_data.get_data_generator(i) @@ -247,17 +320,20 @@ class PostProcessing(RunEnvironment): for normalised in [True, False]: # create empty arrays - nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays(data, count=4) + nn_prediction, persistence_prediction, ols_prediction, observation = self._create_empty_prediction_arrays( + data, count=4) # nn forecast - nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method, normalised) + nn_prediction = self._create_nn_forecast(input_data, nn_prediction, mean, std, transformation_method, + normalised) # persistence persistence_prediction = self._create_persistence_forecast(data, persistence_prediction, mean, std, transformation_method, normalised) # ols - ols_prediction = self._create_ols_forecast(input_data, ols_prediction, mean, std, transformation_method, normalised) + ols_prediction = self._create_ols_forecast(input_data, ols_prediction, mean, std, transformation_method, + normalised) # observation observation = self._create_observation(data, observation, mean, std, transformation_method, normalised) @@ -276,17 +352,48 @@ class PostProcessing(RunEnvironment): file = os.path.join(path, f"{prefix}_{data.station[0]}_test.nc") all_predictions.to_netcdf(file) - def _get_frequency(self): + def _get_frequency(self) -> str: + """Get frequency abbreviation.""" getter = {"daily": "1D", "hourly": "1H"} return getter.get(self._sampling, None) - def _create_observation(self, data, _, mean, std, transformation_method, normalised): + @staticmethod + def _create_observation(data: DataPrep, _, mean: xr.DataArray, std: xr.DataArray, transformation_method: str, + normalised: bool) -> xr.DataArray: + """ + Create observation as ground truth from given data. + + Inverse transformation is applied to the ground truth to get the output in the original space. + + :param data: transposed observation from DataPrep + :param mean: mean of target value transformation + :param std: standard deviation of target value transformation + :param transformation_method: target values transformation method + :param normalised: transform ground truth in original space if false, or use normalised predictions if true + + :return: filled data array with observation + """ obs = data.label.copy() if not normalised: obs = statistics.apply_inverse_transformation(obs, mean, std, transformation_method) return obs - def _create_ols_forecast(self, input_data, ols_prediction, mean, std, transformation_method, normalised): + def _create_ols_forecast(self, input_data: xr.DataArray, ols_prediction: xr.DataArray, mean: xr.DataArray, + std: xr.DataArray, transformation_method: str, normalised: bool) -> xr.DataArray: + """ + Create ordinary least square model forecast with given input data. + + Inverse transformation is applied to the forecast to get the output in the original space. + + :param data: transposed history from DataPrep + :param ols_prediction: empty array in right shape to fill with data + :param mean: mean of target value transformation + :param std: standard deviation of target value transformation + :param transformation_method: target values transformation method + :param normalised: transform prediction in original space if false, or use normalised predictions if true + + :return: filled data array with ols predictions + """ tmp_ols = self.ols_model.predict(input_data) if not normalised: tmp_ols = statistics.apply_inverse_transformation(tmp_ols, mean, std, transformation_method) @@ -295,7 +402,23 @@ class PostProcessing(RunEnvironment): ols_prediction.values = np.swapaxes(tmp_ols, 2, 0) if target_shape != tmp_ols.shape else tmp_ols return ols_prediction - def _create_persistence_forecast(self, data, persistence_prediction, mean, std, transformation_method, normalised): + def _create_persistence_forecast(self, data: DataPrep, persistence_prediction: xr.DataArray, mean: xr.DataArray, + std: xr.DataArray, transformation_method: str, normalised: bool) -> xr.DataArray: + """ + Create persistence forecast with given data. + + Persistence is deviated from the value at t=0 and applied to all following time steps (t+1, ..., t+window). + Inverse transformation is applied to the forecast to get the output in the original space. + + :param data: DataPrep + :param persistence_prediction: empty array in right shape to fill with data + :param mean: mean of target value transformation + :param std: standard deviation of target value transformation + :param transformation_method: target values transformation method + :param normalised: transform prediction in original space if false, or use normalised predictions if true + + :return: filled data array with persistence predictions + """ tmp_persi = data.observation.copy().sel({'window': 0}) if not normalised: tmp_persi = statistics.apply_inverse_transformation(tmp_persi, mean, std, transformation_method) @@ -304,17 +427,23 @@ class PostProcessing(RunEnvironment): axis=1) return persistence_prediction - def _create_nn_forecast(self, input_data, nn_prediction, mean, std, transformation_method, normalised): + def _create_nn_forecast(self, input_data: xr.DataArray, nn_prediction: xr.DataArray, mean: xr.DataArray, + std: xr.DataArray, transformation_method: str, normalised: bool) -> xr.DataArray: """ - create the nn forecast for given input data. Inverse transformation is applied to the forecast to get the output - in the original space. Furthermore, only the output of the main branch is returned (not all minor branches, if - the network has multiple output branches). The main branch is defined to be the last entry of all outputs. - :param input_data: - :param nn_prediction: - :param mean: - :param std: - :param transformation_method: - :return: + Create NN forecast for given input data. + + Inverse transformation is applied to the forecast to get the output in the original space. Furthermore, only the + output of the main branch is returned (not all minor branches, if the network has multiple output branches). The + main branch is defined to be the last entry of all outputs. + + :param input_data: transposed history from DataPrep + :param nn_prediction: empty array in right shape to fill with data + :param mean: mean of target value transformation + :param std: standard deviation of target value transformation + :param transformation_method: target values transformation method + :param normalised: transform prediction in original space if false, or use normalised predictions if true + + :return: filled data array with nn predictions """ tmp_nn = self.model.predict(input_data) if not normalised: @@ -334,11 +463,15 @@ class PostProcessing(RunEnvironment): return [generator.label.copy() for _ in range(count)] @staticmethod - def create_fullindex(df, freq): - # Diese Funkton erstellt ein leeres df, mit Index der Frequenz frequ zwischen dem ersten und dem letzten Datum in df - # param: df as pandas dataframe - # param: freq as string - # return: index as pandas dataframe + def create_fullindex(df: Union[xr.DataArray, pd.DataFrame, pd.DatetimeIndex], freq: str) -> pd.DataFrame: + """ + Create full index from first and last date inside df and resample with given frequency. + + :param df: use time range of this data set + :param freq: frequency of full index + + :return: empty data frame with full index. + """ if isinstance(df, pd.DataFrame): earliest = df.index[0] latest = df.index[-1] @@ -355,13 +488,14 @@ class PostProcessing(RunEnvironment): return index @staticmethod - def create_forecast_arrays(index, ahead_names, **kwargs): + def create_forecast_arrays(index: pd.DataFrame, ahead_names: List[Union[str, int]], **kwargs): """ - This function combines different forecast types into one xarray. + Combine different forecast types into single xarray. - :param index: as index; index for forecasts (e.g. time) - :param ahead_names: as list of str/int: names of ahead values (e.g. hours or days) + :param index: index for forecasts (e.g. time) + :param ahead_names: names of ahead values (e.g. hours or days) :param kwargs: as xarrays; data of forecasts + :return: xarray of dimension 3: index, ahead_names, # predictions """ @@ -377,7 +511,15 @@ class PostProcessing(RunEnvironment): res.loc[match_index, :, k] = v.sel({'datetime': match_index}).squeeze('Stations').transpose() return res - def _get_external_data(self, station): + def _get_external_data(self, station: str) -> Union[xr.DataArray, None]: + """ + Get external data for given station. + + External data is defined as data that is not part of the observed period. From an evaluation perspective, this + refers to data, that is no test data, and therefore to train and val data. + + :param station: name of station to load external data. + """ try: data = self.train_val_data.get_data_generator(station) mean, std, transformation_method = data.get_transformation_information(variable=self.target_var) @@ -387,7 +529,16 @@ class PostProcessing(RunEnvironment): except KeyError: return None - def calculate_skill_scores(self): + def calculate_skill_scores(self) -> Tuple[Dict, Dict]: + """ + Calculate skill scores of CNN forecast. + + The competitive skill score compares the CNN prediction with persistence and ordinary least squares forecasts. + Whereas, the climatological skill scores evaluates the CNN prediction in terms of meaningfulness in comparison + to different climatological references. + + :return: competitive and climatological skill scores + """ path = self.data_store.get("forecast_path") window_lead_time = self.data_store.get("window_lead_time") skill_score_competitive = {} diff --git a/src/run_modules/pre_processing.py b/src/run_modules/pre_processing.py index 551ea599a3114b7b97f5bcb146cf6e131e324eb5..b4b36a20bf9ed7827a1ac151141c13122f517e33 100644 --- a/src/run_modules/pre_processing.py +++ b/src/run_modules/pre_processing.py @@ -1,7 +1,8 @@ +"""Pre-processing module.""" + __author__ = "Lukas Leufen, Felix Kleinert" __date__ = '2019-11-25' - import logging import os from typing import Tuple, Dict, List @@ -9,9 +10,10 @@ from typing import Tuple, Dict, List import numpy as np import pandas as pd -from src.data_handling.data_generator import DataGenerator -from src.helpers import TimeTracking, check_path_and_create -from src.join import EmptyQueryResult +from src.data_handling import DataGenerator +from src.helpers import TimeTracking +from src.configuration import path_config +from src.helpers.join import EmptyQueryResult from src.run_modules.run_environment import RunEnvironment DEFAULT_ARGS_LIST = ["data_path", "network", "stations", "variables", "interpolate_dim", "target_dim", "target_var"] @@ -21,20 +23,39 @@ DEFAULT_KWARGS_LIST = ["limit_nan_fill", "window_history_size", "window_lead_tim class PreProcessing(RunEnvironment): - """ - Pre-process your data by using this class. It includes time tracking and uses the experiment setup to look for data - and stores it if not already in local disk. Further, it provides this data as a generator and checks for valid - stations (in this context: valid=data available). Finally, it splits the data into valid training, validation and - testing subsets. + Pre-process your data by using this class. + + Schedule of pre-processing: + #. load and check valid stations (either download or load from disk) + #. split subsets (train, val, test, train & val) + #. create small report on data metrics + + Required objects [scope] from data store: + * all elements from `DEFAULT_ARGS_LIST` in scope preprocessing for general data loading + * all elements from `DEFAULT_ARGS_LIST` in scopes [train, val, test, train_val] for custom subset settings + * `fraction_of_training` [.] + * `experiment_path` [.] + * `use_all_stations_on_all_data_sets` [.] + + Optional objects + * all elements from `DEFAULT_KWARGS_LIST` in scope preprocessing for general data loading + * all elements from `DEFAULT_KWARGS_LIST` in scopes [train, val, test, train_val] for custom subset settings + + Sets + * `stations` in [., train, val, test, train_val] + * `generator` in [train, val, test, train_val] + * `transformation` [.] + + Creates + * all input and output data in `data_path` + * latex reports in `experiment_path/latex_report` + """ def __init__(self): - - # create run framework + """Set up and run pre-processing.""" super().__init__() - - # self._run() def _run(self): @@ -47,6 +68,7 @@ class PreProcessing(RunEnvironment): self.report_pre_processing() def report_pre_processing(self): + """Log some metrics on data and create latex report.""" logging.debug(20 * '##') n_train = len(self.data_store.get('generator', 'train')) n_val = len(self.data_store.get('generator', 'val')) @@ -62,34 +84,42 @@ class PreProcessing(RunEnvironment): def create_latex_report(self): """ - This function creates tables with information on the station meta data and a summary on subset sample sizes. + Create tables with information on the station meta data and a summary on subset sample sizes. - * station_sample_size.md: see table below - * station_sample_size.tex: same as table below, but as latex table + * station_sample_size.md: see table below as markdown + * station_sample_size.tex: same as table below as latex table * station_sample_size_short.tex: reduced size table without any meta data besides station ID, as latex table All tables are stored inside experiment_path inside the folder latex_report. The table format (e.g. which meta data is highlighted) is currently hardcoded to have a stable table style. If further styles are needed, it is better to add an additional style than modifying the existing table styles. + +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ | stat. ID | station_name | station_lon | station_lat | station_alt | train | val | test | - |------------|-------------------------------------------|---------------|---------------|---------------|---------|-------|--------| + +============+===========================================+===============+===============+===============+=========+=======+========+ | DEBW013 | Stuttgart Bad Cannstatt | 9.2297 | 48.8088 | 235 | 1434 | 712 | 1080 | + +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ | DEBW076 | Baden-Baden | 8.2202 | 48.7731 | 148 | 3037 | 722 | 710 | + +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ | DEBW087 | Schwäbische_Alb | 9.2076 | 48.3458 | 798 | 3044 | 714 | 1087 | + +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ | DEBW107 | Tübingen | 9.0512 | 48.5077 | 325 | 1803 | 715 | 1087 | + +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ | DEBY081 | Garmisch-Partenkirchen/Kreuzeckbahnstraße | 11.0631 | 47.4764 | 735 | 2935 | 525 | 714 | + +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ | # Stations | nan | nan | nan | nan | 6 | 6 | 6 | + +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ | # Samples | nan | nan | nan | nan | 12253 | 3388 | 4678 | + +------------+-------------------------------------------+---------------+---------------+---------------+---------+-------+--------+ """ meta_data = ['station_name', 'station_lon', 'station_lat', 'station_alt'] meta_round = ["station_lon", "station_lat", "station_alt"] precision = 4 path = os.path.join(self.data_store.get("experiment_path"), "latex_report") - check_path_and_create(path) + path_config.check_path_and_create(path) set_names = ["train", "val", "test"] - df = pd.DataFrame(columns=meta_data+set_names) + df = pd.DataFrame(columns=meta_data + set_names) for set_name in set_names: data: DataGenerator = self.data_store.get("generator", set_name) for station in data.stations: @@ -102,24 +132,28 @@ class PreProcessing(RunEnvironment): df.sort_index(inplace=True) df = df.reindex(df.index.drop(["# Stations", "# Samples"]).to_list() + ["# Stations", "# Samples"], ) df.index.name = 'stat. ID' - column_format = np.repeat('c', df.shape[1]+1) + column_format = np.repeat('c', df.shape[1] + 1) column_format[0] = 'l' column_format[-1] = 'r' column_format = ''.join(column_format.tolist()) df.to_latex(os.path.join(path, "station_sample_size.tex"), na_rep='---', column_format=column_format) - df.to_markdown(open(os.path.join(path, "station_sample_size.md"), mode="w", encoding='utf-8'), tablefmt="github") + df.to_markdown(open(os.path.join(path, "station_sample_size.md"), mode="w", encoding='utf-8'), + tablefmt="github") df.drop(meta_data, axis=1).to_latex(os.path.join(path, "station_sample_size_short.tex"), na_rep='---', column_format=column_format) def split_train_val_test(self) -> None: """ - Splits all subsets. Currently: train, val, test and train_val (actually this is only the merge of train and val, - but as an separate generator). IMPORTANT: Do not change to order of the execution of create_set_split. The train - subset needs always to be executed at first, to set a proper transformation. + Split data into subsets. + + Currently: train, val, test and train_val (actually this is only the merge of train and val, but as an separate + generator). IMPORTANT: Do not change to order of the execution of create_set_split. The train subset needs + always to be executed at first, to set a proper transformation. """ fraction_of_training = self.data_store.get("fraction_of_training") stations = self.data_store.get("stations") - train_index, val_index, test_index, train_val_index = self.split_set_indices(len(stations), fraction_of_training) + train_index, val_index, test_index, train_val_index = self.split_set_indices(len(stations), + fraction_of_training) subset_names = ["train", "val", "test", "train_val"] if subset_names[0] != "train": # pragma: no cover raise AssertionError(f"Make sure, that the train subset is always at first execution position! Given subset" @@ -130,12 +164,16 @@ class PreProcessing(RunEnvironment): @staticmethod def split_set_indices(total_length: int, fraction: float) -> Tuple[slice, slice, slice, slice]: """ - create the training, validation and test subset slice indices for given total_length. The test data consists on - (1-fraction) of total_length (fraction*len:end). Train and validation data therefore are made from fraction of - total_length (0:fraction*len). Train and validation data is split by the factor 0.8 for train and 0.2 for - validation. In addition, split_set_indices returns also the combination of training and validation subset. + Create the training, validation and test subset slice indices for given total_length. + + The test data consists on (1-fraction) of total_length (fraction*len:end). Train and validation data therefore + are made from fraction of total_length (0:fraction*len). Train and validation data is split by the factor 0.8 + for train and 0.2 for validation. In addition, split_set_indices returns also the combination of training and + validation subset. + :param total_length: list with all objects to split :param fraction: ratio between test and union of train/val data + :return: slices for each subset in the order: train, val, test, train_val """ pos_test_split = int(total_length * fraction) @@ -145,12 +183,15 @@ class PreProcessing(RunEnvironment): train_val_index = slice(0, pos_test_split) return train_index, val_index, test_index, train_val_index - def create_set_split(self, index_list: slice, set_name) -> None: + def create_set_split(self, index_list: slice, set_name: str) -> None: """ + Create subsets and store in data store. + Create the subset for given split index and stores the DataGenerator with given set name in data store as - `generator`. Checks for all valid stations using the default (kw)args for given scope and creates the - DataGenerator for all valid stations. Also sets all transformation information, if subset is training set. Make + `generator`. Check for all valid stations using the default (kw)args for given scope and create the + DataGenerator for all valid stations. Also set all transformation information, if subset is training set. Make sure, that the train set is executed first, and all other subsets afterwards. + :param index_list: list of all stations to use for the set. If attribute use_all_stations_on_all_data_sets=True, this list is ignored. :param set_name: name to load/save all information from/to data store. @@ -158,30 +199,38 @@ class PreProcessing(RunEnvironment): args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope=set_name) kwargs = self.data_store.create_args_dict(DEFAULT_KWARGS_LIST, scope=set_name) stations = args["stations"] - if self.data_store.get("use_all_stations_on_all_data_sets", scope=set_name): + if self.data_store.get("use_all_stations_on_all_data_sets"): set_stations = stations else: set_stations = stations[index_list] logging.debug(f"{set_name.capitalize()} stations (len={len(set_stations)}): {set_stations}") + # validate set set_stations = self.check_valid_stations(args, kwargs, set_stations, load_tmp=False, name=set_name) self.data_store.set("stations", set_stations, scope=set_name) + # create set generator and store set_args = self.data_store.create_args_dict(DEFAULT_ARGS_LIST, scope=set_name) data_set = DataGenerator(**set_args, **kwargs) self.data_store.set("generator", data_set, scope=set_name) + # extract transformation from train set if set_name == "train": self.data_store.set("transformation", data_set.transformation) @staticmethod - def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True, save_tmp=True, name=None): + def check_valid_stations(args: Dict, kwargs: Dict, all_stations: List[str], load_tmp=True, save_tmp=True, + name=None): """ - Check if all given stations in `all_stations` are valid. Valid means, that there is data available for the given - time range (is included in `kwargs`). The shape and the loading time are logged in debug mode. + Check if all given stations in `all_stations` are valid. + + Valid means, that there is data available for the given time range (is included in `kwargs`). The shape and the + loading time are logged in debug mode. + :param args: Dictionary with required parameters for DataGenerator class (`data_path`, `network`, `stations`, `variables`, `interpolate_dim`, `target_dim`, `target_var`). :param kwargs: positional parameters for the DataGenerator class (e.g. `start`, `interpolate_method`, `window_lead_time`). :param all_stations: All stations to check. :param name: name to display in the logging info message + :return: Corrected list containing only valid station IDs. """ t_outer = TimeTracking() @@ -200,7 +249,8 @@ class PreProcessing(RunEnvironment): if data.history is None: raise AttributeError valid_stations.append(station) - logging.debug(f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}') + logging.debug( + f'{station}: history_shape = {data.history.transpose("datetime", "window", "Stations", "variables").shape}') logging.debug(f"{station}: loading time = {t_inner}") except (AttributeError, EmptyQueryResult): continue diff --git a/src/run_modules/run_environment.py b/src/run_modules/run_environment.py index 7bd5027788934322d704192e1dff2995539fe245..a0e619f364a060b3ed44639c6057046db197d84b 100644 --- a/src/run_modules/run_environment.py +++ b/src/run_modules/run_environment.py @@ -1,68 +1,166 @@ +"""Implementation of run environment.""" + __author__ = "Lukas Leufen" __date__ = '2019-11-25' +import json import logging import os import shutil import time +from src.helpers.datastore import DataStoreByScope as DataStoreObject +from src.helpers.datastore import NameNotFoundInDataStore from src.helpers import Logger -from src.datastore import DataStoreByScope as DataStoreObject -from src.datastore import NameNotFoundInDataStore from src.helpers import TimeTracking +from src.plotting.tracker_plot import TrackPlot class RunEnvironment(object): """ - basic run class to measure execution time. Either call this class calling it by 'with' or delete the class instance - after finishing the measurement. The duration result is logged. + Basic run class to measure execution time. + + Either call this class by 'with' statement or delete the class instance after finishing the measurement. The + duration result is logged. + + .. code-block:: python + + >>> with RunEnvironment(): + <your code> + INFO: RunEnvironment started + ... + INFO: RunEnvironment finished after 00:00:04 (hh:mm:ss) + + If you want to embed your custom module in a RunEnvironment, you can easily call it inside the with statement. If + you want to exchange between different modules in addition, create your module as inheritance of the RunEnvironment + and call it after you initialised the RunEnvironment itself. + + .. code-block:: python + + class CustomClass(RunEnvironment): + + def __init__(self): + super().__init__() + ... + ... + + + >>> with RunEnvironment(): + CustomClass() + INFO: RunEnvironment started + INFO: CustomClass started + INFO: CustomClass finished after 00:00:04 (hh:mm:ss) + INFO: RunEnvironment finished after 00:00:04 (hh:mm:ss) + + All data that is stored in the data store will be available for all other modules that inherit from RunEnvironment + as long the RunEnvironemnt base class is running. If the base class is deleted either by hand or on exit of the with + statement, this storage is cleared. + + .. code-block:: python + + class CustomClassA(RunEnvironment): + + def __init__(self): + super().__init__() + self.data_store.set("testVar", 12) + + + class CustomClassB(RunEnvironment): + + def __init__(self): + super().__init__() + self.test_var = self.data_store.get("testVar") + logging.info(f"testVar = {self.test_var}") + + + >>> with RunEnvironment(): + CustomClassA() + CustomClassB() + INFO: RunEnvironment started + INFO: CustomClassA started + INFO: CustomClassA finished after 00:00:01 (hh:mm:ss) + INFO: CustomClassB started + INFO: testVar = 12 + INFO: CustomClassB finished after 00:00:02 (hh:mm:ss) + INFO: RunEnvironment finished after 00:00:03 (hh:mm:ss) + """ + # set data store and logger (both are mutable!) del_by_exit = False data_store = DataStoreObject() logger = Logger() + tracker_list = [] def __init__(self): - """ - Starts time tracking automatically and logs as info. - """ + """Start time tracking automatically and logs as info.""" self.time = TimeTracking() logging.info(f"{self.__class__.__name__} started") + # atexit.register(self.__del__) + self.data_store.tracker.append({}) + self.tracker_list.extend([{self.__class__.__name__: self.data_store.tracker[-1]}]) def __del__(self): """ - This is the class finalizer. The code is not executed if already called by exit method to prevent duplicated - logging (__exit__ is always executed before __del__) it this class was used in a with statement. + Finalise class. + + Only stop time tracking, if not already called by exit method to prevent duplicated logging (__exit__ is always + executed before __del__) it this class was used in a with statement. If instance is called as base class and + not as inheritance from this class, log file is copied and data store is cleared. """ if not self.del_by_exit: self.time.stop() logging.info(f"{self.__class__.__name__} finished after {self.time}") self.del_by_exit = True - if self.__class__.__name__ == "RunEnvironment": - self.__copy_log_file() - self.data_store.clear_data_store() + # copy log file and clear data store only if called as base class and not as super class + if self.__class__.__name__ == "RunEnvironment": + try: + self.__plot_tracking() + self.__save_tracking() + self.__copy_log_file() + except FileNotFoundError: + pass + self.data_store.clear_data_store() def __enter__(self): + """Enter run environment.""" return self def __exit__(self, exc_type, exc_val, exc_tb): + """Exit run environment.""" if exc_type: logging.error(exc_val, exc_info=(exc_type, exc_val, exc_tb)) self.__del__() def __copy_log_file(self): try: - counter = 0 - filename_pattern = os.path.join(self.data_store.get("experiment_path"), "logging_%03i.log") - new_file = filename_pattern % counter - while os.path.exists(new_file): - counter += 1 - new_file = filename_pattern % counter + new_file = self.__find_file_pattern("logging_%03i.log") logging.info(f"Copy log file to {new_file}") shutil.copyfile(self.logger.log_file, new_file) except (NameNotFoundInDataStore, FileNotFoundError): pass + def __save_tracking(self): + tracker = self.data_store.tracker + new_file = self.__find_file_pattern("tracking_%03i.json") + logging.info(f"Copy tracker file to {new_file}") + with open(new_file, "w") as f: + json.dump(tracker, f) + + def __plot_tracking(self): + plot_folder, plot_name = os.path.split(self.__find_file_pattern("tracking_%03i.pdf")) + TrackPlot(self.tracker_list, sparse_conn_mode=True, plot_folder=plot_folder, plot_name=plot_name) + + def __find_file_pattern(self, name): + counter = 0 + filename_pattern = os.path.join(self.data_store.get_default("experiment_path", os.path.realpath(".")), name) + new_file = filename_pattern % counter + while os.path.exists(new_file): + counter += 1 + new_file = filename_pattern % counter + return new_file + @staticmethod def do_stuff(length=2): + """Just a placeholder method for testing without any sense.""" time.sleep(length) diff --git a/src/run_modules/training.py b/src/run_modules/training.py index 2d949af8c68f244c0a0da2bad6580c616695da8d..8cb4726fdc84ad10e62106c1d2bcbf899457e31d 100644 --- a/src/run_modules/training.py +++ b/src/run_modules/training.py @@ -1,24 +1,67 @@ +"""Training module.""" + __author__ = "Lukas Leufen, Felix Kleinert" __date__ = '2019-12-05' import json import logging import os -import pickle +from typing import Union import keras +from keras.callbacks import Callback, History -from src.data_handling.data_distributor import Distributor -from src.model_modules.keras_extensions import LearningRateDecay, CallbackHandler +from src.data_handling import Distributor +from src.model_modules.keras_extensions import CallbackHandler from src.plotting.training_monitoring import PlotModelHistory, PlotModelLearningRate from src.run_modules.run_environment import RunEnvironment -from typing import Union - class Training(RunEnvironment): + """ + Train your model with this module. + + This module isn't required to run, if only a fresh post-processing is preformed. Either remove training call from + your run script or set create_new_model and trainable both to false. + + Schedule of training: + #. set_generators(): set generators for training, validation and testing and distribute according to batch size + #. make_predict_function(): create predict function before distribution on multiple nodes (detailed information + in method description) + #. train(): start or resume training of model and save callbacks + #. save_model(): save best model from training as final model + + Required objects [scope] from data store: + * `model` [model] + * `batch_size` [model] + * `epochs` [model] + * `callbacks` [model] + * `model_name` [model] + * `experiment_name` [.] + * `experiment_path` [.] + * `trainable` [.] + * `create_new_model` [.] + * `generator` [train, val, test] + * `plot_path` [.] + + Optional objects + * `permute_data` [train, val, test] + * `upsampling` [train, val, test] + + Sets + * `best_model` [.] + + Creates + * `<exp_name>_model-best.h5` + * `<exp_name>_model-best-callbacks-<name>.h5` (all callbacks from CallbackHandler) + * `history.json` + * `history_lr.json` (optional) + * `<exp_name>_history_<name>.pdf` (different monitoring plots depending on loss metrics and callbacks) + + """ def __init__(self): + """Set up and run training.""" super().__init__() self.model: keras.Model = self.data_store.get("model", "model") self.train_set: Union[Distributor, None] = None @@ -33,17 +76,7 @@ class Training(RunEnvironment): self._run() def _run(self) -> None: - """ - Perform training - 1) set_generators(): - set generators for training, validation and testing and distribute according to batch size - 2) make_predict_function(): - create predict function before distribution on multiple nodes (detailed information in method description) - 3) train(): - start or resume training of model and save callbacks - 4) save_model(): - save best model from training as final model - """ + """Run training. Details in class description.""" self.set_generators() self.make_predict_function() if self._trainable: @@ -54,39 +87,44 @@ class Training(RunEnvironment): def make_predict_function(self) -> None: """ - Creates the predict function. Must be called before distributing. This is necessary, because tf will compile - the predict function just in the moment it is used the first time. This can cause problems, if the model is - distributed on different workers. To prevent this, the function is pre-compiled. See discussion @ + Create predict function. + + Must be called before distributing. This is necessary, because tf will compile the predict function just in + the moment it is used the first time. This can cause problems, if the model is distributed on different + workers. To prevent this, the function is pre-compiled. See discussion @ https://stackoverflow.com/questions/40850089/is-keras-thread-safe/43393252#43393252 """ self.model._make_predict_function() def _set_gen(self, mode: str) -> None: """ - Set and distribute the generators for given mode regarding batch size + Set and distribute the generators for given mode regarding batch size. + :param mode: name of set, should be from ["train", "val", "test"] """ gen = self.data_store.get("generator", mode) - # permute_data = self.data_store.get_default("permute_data", mode, default=False) kwargs = self.data_store.create_args_dict(["permute_data", "upsampling"], scope=mode) setattr(self, f"{mode}_set", Distributor(gen, self.model, self.batch_size, **kwargs)) def set_generators(self) -> None: """ - Set all generators for training, validation, and testing subsets. The called sub-method will automatically - distribute the data according to the batch size. The subsets can be accessed as class variables train_set, - val_set, and test_set . + Set all generators for training, validation, and testing subsets. + + The called sub-method will automatically distribute the data according to the batch size. The subsets can be + accessed as class variables train_set, val_set, and test_set. """ for mode in ["train", "val", "test"]: self._set_gen(mode) def train(self) -> None: """ - Perform training using keras fit_generator(). Callbacks are stored locally in the experiment directory. Best - model from training is saved for class variable model. If the file path of checkpoint is not empty, this method - assumes, that this is not a new training starting from the very beginning, but a resumption from a previous - started but interrupted training (or a stopped and now continued training). Train will automatically load the - locally stored information and the corresponding model and proceed with the already started training. + Perform training using keras fit_generator(). + + Callbacks are stored locally in the experiment directory. Best model from training is saved for class + variable model. If the file path of checkpoint is not empty, this method assumes, that this is not a new + training starting from the very beginning, but a resumption from a previous started but interrupted training + (or a stopped and now continued training). Train will automatically load the locally stored information and the + corresponding model and proceed with the already started training. """ logging.info(f"Train with {len(self.train_set)} mini batches.") logging.info(f"Train with option upsampling={self.train_set.upsampling}.") @@ -106,7 +144,7 @@ class Training(RunEnvironment): self.callbacks.load_callbacks() self.callbacks.update_checkpoint() self.model = keras.models.load_model(checkpoint.filepath) - hist = self.callbacks.get_callback_by_name("hist") + hist: History = self.callbacks.get_callback_by_name("hist") initial_epoch = max(hist.epoch) + 1 _ = self.model.fit_generator(generator=self.train_set.distribute_on_batches(), steps_per_epoch=len(self.train_set), @@ -126,9 +164,7 @@ class Training(RunEnvironment): self.create_monitoring_plots(history, lr) def save_model(self) -> None: - """ - save model in local experiment directory. Model is named as <experiment_name>_<custom_model_name>.h5 . - """ + """Save model in local experiment directory. Model is named as `<experiment_name>_<custom_model_name>.h5`.""" model_name = self.data_store.get("model_name", "model") logging.debug(f"save best model to {model_name}") self.model.save(model_name) @@ -137,6 +173,7 @@ class Training(RunEnvironment): def load_best_model(self, name: str) -> None: """ Load model weights for model with name. Skip if no weights are available. + :param name: name of the model to load weights for """ logging.debug(f"load best model: {name}") @@ -146,12 +183,15 @@ class Training(RunEnvironment): except OSError: logging.info('no weights to reload...') - def save_callbacks_as_json(self, history: keras.callbacks.History, lr_sc: keras.callbacks) -> None: + def save_callbacks_as_json(self, history: Callback, lr_sc: Callback) -> None: """ Save callbacks (history, learning rate) of training. + * history.history -> history.json * lr_sc.lr -> history_lr.json + :param history: history object of training + :param lr_sc: learning rate object """ logging.debug("saving callbacks") path = self.data_store.get("experiment_path") @@ -161,12 +201,14 @@ class Training(RunEnvironment): with open(os.path.join(path, "history_lr.json"), "w") as f: json.dump(lr_sc.lr, f) - def create_monitoring_plots(self, history: keras.callbacks.History, lr_sc: LearningRateDecay) -> None: + def create_monitoring_plots(self, history: Callback, lr_sc: Callback) -> None: """ - Creates the history and learning rate plot in dependence of the number of epochs. The plots are saved in the - experiment's plot_path. History plot is named '<exp_name>_history_loss_val_loss.pdf', the learning rate with - '<exp_name>_history_learning_rate.pdf'. - :param history: keras history object with losses to plot (must include 'loss' and 'val_loss') + Create plot of history and learning rate in dependence of the number of epochs. + + The plots are saved in the experiment's plot_path. History plot is named `<exp_name>_history_loss_val_loss.pdf`, + the learning rate with `<exp_name>_history_learning_rate.pdf`. + + :param history: keras history object with losses to plot (must at least include `loss` and `val_loss`) :param lr_sc: learning rate decay object with 'lr' attribute """ path = self.data_store.get("plot_path") diff --git a/test/test_configuration/test_init.py b/test/test_configuration/test_init.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/test/test_configuration/test_path_config.py b/test/test_configuration/test_path_config.py new file mode 100644 index 0000000000000000000000000000000000000000..cb40d5835da737594943504aaca38c3adfbc2936 --- /dev/null +++ b/test/test_configuration/test_path_config.py @@ -0,0 +1,116 @@ +import logging +import os + +import mock +import pytest + +from src.configuration import prepare_host, set_experiment_name, set_bootstrap_path, check_path_and_create +from src.helpers import PyTestRegex + + +class TestPrepareHost: + + @mock.patch("socket.gethostname", side_effect=["linux-aa9b", "ZAM144", "zam347", "jrtest", "jwtest", + "runner-6HmDp9Qd-project-2411-concurrent-01"]) + @mock.patch("os.getlogin", return_value="testUser") + @mock.patch("os.path.exists", return_value=True) + def test_prepare_host(self, mock_host, mock_user, mock_path): + assert prepare_host() == "/home/testUser/machinelearningtools/data/toar_daily/" + assert prepare_host() == "/home/testUser/Data/toar_daily/" + assert prepare_host() == "/home/testUser/Data/toar_daily/" + assert prepare_host() == "/p/project/cjjsc42/testUser/DATA/toar_daily/" + assert prepare_host() == "/p/home/jusers/testUser/juwels/intelliaq/DATA/toar_daily/" + assert prepare_host() == '/home/testUser/machinelearningtools/data/toar_daily/' + + @mock.patch("socket.gethostname", return_value="NotExistingHostName") + @mock.patch("os.getlogin", return_value="zombie21") + def test_error_handling_unknown_host(self, mock_user, mock_host): + with pytest.raises(OSError) as e: + prepare_host() + assert "unknown host 'NotExistingHostName'" in e.value.args[0] + + @mock.patch("os.getlogin", return_value="zombie21") + @mock.patch("src.configuration.path_config.check_path_and_create", side_effect=PermissionError) + def test_error_handling(self, mock_cpath, mock_user): + # if "runner-6HmDp9Qd-project-2411-concurrent" not in platform.node(): + # mock_host.return_value = "linux-aa9b" + with pytest.raises(NotADirectoryError) as e: + prepare_host() + assert PyTestRegex(r"path '.*' does not exist for host '.*'\.") == e.value.args[0] + with pytest.raises(NotADirectoryError) as e: + prepare_host(False) + # assert "does not exist for host 'linux-aa9b'" in e.value.args[0] + assert PyTestRegex(r"path '.*' does not exist for host '.*'\.") == e.value.args[0] + + @mock.patch("socket.gethostname", side_effect=["linux-aa9b", "ZAM144", "zam347", "jrtest", "jwtest", + "runner-6HmDp9Qd-project-2411-concurrent-01"]) + @mock.patch("os.getlogin", side_effect=OSError) + @mock.patch("os.path.exists", return_value=True) + def test_os_error(self, mock_path, mock_user, mock_host): + path = prepare_host() + assert path == "/home/default/machinelearningtools/data/toar_daily/" + path = prepare_host() + assert path == "/home/default/Data/toar_daily/" + path = prepare_host() + assert path == "/home/default/Data/toar_daily/" + path = prepare_host() + assert path == "/p/project/cjjsc42/default/DATA/toar_daily/" + path = prepare_host() + assert path == "/p/home/jusers/default/juwels/intelliaq/DATA/toar_daily/" + path = prepare_host() + assert path == '/home/default/machinelearningtools/data/toar_daily/' + + @mock.patch("socket.gethostname", side_effect=["linux-aa9b"]) + @mock.patch("os.getlogin", return_value="testUser") + @mock.patch("os.path.exists", return_value=False) + @mock.patch("os.makedirs", side_effect=None) + def test_os_path_exists(self, mock_host, mock_user, mock_path, mock_check): + path = prepare_host() + assert path == "/home/testUser/machinelearningtools/data/toar_daily/" + + +class TestSetExperimentName: + + def test_set_experiment(self): + exp_name, exp_path = set_experiment_name() + assert exp_name == "TestExperiment" + assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "TestExperiment")) + exp_name, exp_path = set_experiment_name(experiment_name="2019-11-14", experiment_path="./test2") + assert exp_name == "2019-11-14_network" + assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "test2", exp_name)) + + def test_set_experiment_from_sys(self): + exp_name, _ = set_experiment_name(experiment_name="2019-11-14") + assert exp_name == "2019-11-14_network" + + def test_set_experiment_hourly(self): + exp_name, exp_path = set_experiment_name(sampling="hourly") + assert exp_name == "TestExperiment_hourly" + assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "TestExperiment_hourly")) + + +class TestSetBootstrapPath: + + @mock.patch("os.makedirs", side_effect=None) + def test_bootstrap_path_is_none(self, mock_makedir): + bootstrap_path = set_bootstrap_path(None, 'TestDataPath/', 'daily') + assert bootstrap_path == os.path.abspath('TestDataPath/../bootstrap_daily') + + @mock.patch("os.makedirs", side_effect=None) + def test_bootstap_path_is_given(self, mock_makedir): + bootstrap_path = set_bootstrap_path('Test/path/to/boots', None, None) + assert bootstrap_path == os.path.abspath('./Test/path/to/boots') + + +class TestCheckPath: + + def test_check_path_and_create(self, caplog): + caplog.set_level(logging.DEBUG) + path = 'data/test' + assert not os.path.exists('data/test') + check_path_and_create(path) + assert os.path.exists('data/test') + assert caplog.messages[0] == "Created path: data/test" + check_path_and_create(path) + assert caplog.messages[1] == "Path already exists: data/test" + os.rmdir('data/test') \ No newline at end of file diff --git a/test/test_data_handling/test_bootstraps.py b/test/test_data_handling/test_bootstraps.py index c2b814b7bf173b61b4967c83611cdd3de08ed91b..650c232314a351c148dcf906718e16e3f454a277 100644 --- a/test/test_data_handling/test_bootstraps.py +++ b/test/test_data_handling/test_bootstraps.py @@ -1,18 +1,15 @@ - -from src.data_handling.bootstraps import BootStraps, CreateShuffledData, BootStrapGenerator -from src.data_handling.data_generator import DataGenerator -from src.helpers import PyTestAllEqual, xr_all_equal - import logging -import mock import os -import pytest import shutil -import typing +import mock import numpy as np +import pytest import xarray as xr +from src.data_handling.bootstraps import BootStraps, CreateShuffledData, BootStrapGenerator +from src.data_handling.data_generator import DataGenerator + @pytest.fixture def orig_generator(data_path): @@ -44,7 +41,8 @@ class TestBootStrapGenerator: assert boot_gen.variables == ["o3", "temp"] assert xr.testing.assert_equal(boot_gen.history_orig, hist) is None assert xr.testing.assert_equal(boot_gen.history, hist.sel(variables=["temp"])) is None - assert xr.testing.assert_allclose(boot_gen.shuffled - 1, hist.sel(variables="o3").expand_dims({"boots": [0]})) is None + assert xr.testing.assert_allclose(boot_gen.shuffled - 1, + hist.sel(variables="o3").expand_dims({"boots": [0]})) is None def test_len(self, boot_gen): assert len(boot_gen) == 20 @@ -290,4 +288,3 @@ class TestBootStraps: assert f(regex, test_list, 10, 10) is None assert f(regex, test_list, 9, 10) == "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc" assert f(regex, test_list, 9, 20) == "DEBW108_h2o_o3_temp_hist9_nboots20_shuffled.nc" - diff --git a/test/test_data_handling/test_data_distributor.py b/test/test_data_handling/test_data_distributor.py index 15344fd808a4aa9ee5774ad8ba647bf5ce06d015..9e2242fed67599cdd53ddd29e4fbd4304425b77e 100644 --- a/test/test_data_handling/test_data_distributor.py +++ b/test/test_data_handling/test_data_distributor.py @@ -49,7 +49,7 @@ class TestDistributor: values = np.zeros((2311, 19)) assert distributor._get_number_of_mini_batches(values) == math.ceil(2311 / distributor.batch_size) - def test_distribute_on_batches_single_loop(self, generator_two_stations, model): + def test_distribute_on_batches_single_loop(self, generator_two_stations, model): d = Distributor(generator_two_stations, model) for e in d.distribute_on_batches(fit_call=False): assert e[0].shape[0] <= d.batch_size @@ -60,7 +60,7 @@ class TestDistributor: for i, e in enumerate(d.distribute_on_batches()): if i < len(d): elements.append(e[0]) - elif i == 2*len(d): # check if all elements are repeated + elif i == 2 * len(d): # check if all elements are repeated assert np.testing.assert_array_equal(e[0], elements[i - len(d)]) is None else: # break when 3rd iteration starts (is called as infinite loop) break @@ -98,7 +98,7 @@ class TestDistributor: assert np.testing.assert_equal(x, x_perm) is None assert np.testing.assert_equal(y, y_perm) is None - def test_distribute_on_batches_upsampling_no_extremes_given(self, generator, model): + def test_distribute_on_batches_upsampling_no_extremes_given(self, generator, model): d = Distributor(generator, model, upsampling=True) gen_len = d.generator.get_data_generator(0, load_local_tmp_storage=False).get_transposed_label().shape[0] num_mini_batches = math.ceil(gen_len / d.batch_size) diff --git a/test/test_data_handling/test_data_generator.py b/test/test_data_handling/test_data_generator.py index 939f93cc9ee01c76a282e755aca14b39c6fc4ac9..754728ba403fbda25c021c2f576a1bc89d26f83f 100644 --- a/test/test_data_handling/test_data_generator.py +++ b/test/test_data_handling/test_data_generator.py @@ -1,15 +1,14 @@ -import os - import operator as op -import pytest +import os +import pickle -import shutil import numpy as np +import pytest import xarray as xr -import pickle + from src.data_handling.data_generator import DataGenerator from src.data_handling.data_preparation import DataPrep -from src.join import EmptyQueryResult +from src.helpers.join import EmptyQueryResult class TestDataGenerator: @@ -99,10 +98,10 @@ class TestDataGenerator: def test_repr(self, gen): path = os.path.join(os.path.dirname(__file__), 'data') - assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='AIRBASE', stations=['DEBW107'], "\ + assert gen.__repr__().rstrip() == f"DataGenerator(path='{path}', network='AIRBASE', stations=['DEBW107'], " \ f"variables=['o3', 'temp'], station_type=None, interpolate_dim='datetime', " \ - f"target_dim='variables', target_var='o3', **{{'start': 2010, 'end': 2014}})"\ - .rstrip() + f"target_dim='variables', target_var='o3', **{{'start': 2010, 'end': 2014}})" \ + .rstrip() def test_len(self, gen): assert len(gen) == 1 diff --git a/test/test_data_handling/test_data_preparation.py b/test/test_data_handling/test_data_preparation.py index 747b3734f565d3206696998de10f5986b7c94bf0..a8ca555c9748f7656fefc007922ee0d7df1992fa 100644 --- a/test/test_data_handling/test_data_preparation.py +++ b/test/test_data_handling/test_data_preparation.py @@ -1,7 +1,7 @@ import datetime as dt +import logging import os from operator import itemgetter, lt, gt -import logging import numpy as np import pandas as pd @@ -9,7 +9,7 @@ import pytest import xarray as xr from src.data_handling.data_preparation import DataPrep -from src.join import EmptyQueryResult +from src.helpers.join import EmptyQueryResult class TestDataPrep: @@ -190,7 +190,7 @@ class TestDataPrep: assert data._transform_method is None assert data.mean is None assert data.std is None - data_std_orig = data.data.std('datetime'). variable.values + data_std_orig = data.data.std('datetime').variable.values data.transform('datetime', 'centre') assert data._transform_method == 'centre' assert np.testing.assert_almost_equal(data.data.mean('datetime').variable.values, np.array([[0, 0]])) is None @@ -299,11 +299,11 @@ class TestDataPrep: index_array = data.create_index_array('window', range(1, 4)) assert np.testing.assert_array_equal(index_array.data, [1, 2, 3]) is None assert index_array.name == 'window' - assert index_array.coords.dims == ('window', ) + assert index_array.coords.dims == ('window',) index_array = data.create_index_array('window', range(0, 1)) assert np.testing.assert_array_equal(index_array.data, [0]) is None assert index_array.name == 'window' - assert index_array.coords.dims == ('window', ) + assert index_array.coords.dims == ('window',) @staticmethod def extract_window_data(res, orig, w): @@ -311,7 +311,7 @@ class TestDataPrep: window = res.sel(slice).data.flatten() if w <= 0: delta = w - w = abs(w)+1 + w = abs(w) + 1 else: delta = 1 slice = {'variables': ['temp'], 'Stations': 'DEBW107', @@ -421,10 +421,13 @@ class TestDataPrep: orig = data.label data.multiply_extremes([1, 1.5, 2, 3]) upsampled = data.extremes_label + def f(d, op, n): return op(d, n).any(dim="window").sum() + assert f(upsampled, gt, 1) == sum([f(orig, gt, 1), f(orig, gt, 1.5), f(orig, gt, 2) * 2, f(orig, gt, 3) * 4]) - assert f(upsampled, lt, -1) == sum([f(orig, lt, -1), f(orig, lt, -1.5), f(orig, lt, -2) * 2, f(orig, lt, -3) * 4]) + assert f(upsampled, lt, -1) == sum( + [f(orig, lt, -1), f(orig, lt, -1.5), f(orig, lt, -2) * 2, f(orig, lt, -3) * 4]) def test_multiply_extremes_wrong_extremes(self, data): data.transform("datetime") @@ -442,8 +445,10 @@ class TestDataPrep: orig = data.label data.multiply_extremes([1, 2], extremes_on_right_tail_only=True) upsampled = data.extremes_label + def f(d, op, n): return op(d, n).any(dim="window").sum() + assert f(upsampled, gt, 1) == sum([f(orig, gt, 1), f(orig, gt, 2)]) assert upsampled.shape[2] == sum([f(orig, gt, 1), f(orig, gt, 2)]) assert f(upsampled, lt, -1) == 0 @@ -454,13 +459,13 @@ class TestDataPrep: data.label = None assert data.multiply_extremes([1], extremes_on_right_tail_only=False) is None - def test_multiply_extremes_none_history(self,data ): + def test_multiply_extremes_none_history(self, data): data.transform("datetime") data.history = None data.make_labels("variables", "o3", "datetime", 2) assert data.multiply_extremes([1], extremes_on_right_tail_only=False) is None - def test_multiply_extremes_none_label_history(self,data ): + def test_multiply_extremes_none_label_history(self, data): data.history = None data.label = None assert data.multiply_extremes([1], extremes_on_right_tail_only=False) is None diff --git a/test/test_datastore.py b/test/test_datastore.py index 5b6cd17a00271a17b8fe5c30ca26665b42e56141..9aca1eef35927242df0b5f659eece716f81f6c13 100644 --- a/test/test_datastore.py +++ b/test/test_datastore.py @@ -1,11 +1,10 @@ __author__ = 'Lukas Leufen' __date__ = '2019-11-22' - import pytest -from src.datastore import AbstractDataStore, DataStoreByVariable, DataStoreByScope, CorrectScope -from src.datastore import NameNotFoundInDataStore, NameNotFoundInScope, EmptyScope +from src.helpers.datastore import AbstractDataStore, DataStoreByVariable, DataStoreByScope, CorrectScope +from src.helpers.datastore import NameNotFoundInDataStore, NameNotFoundInScope, EmptyScope class TestAbstractDataStore: @@ -80,7 +79,8 @@ class TestDataStoreByVariable: ds.set("number", 11, "general.sub") with pytest.raises(NameNotFoundInScope) as e: ds.get("number", "general.sub2") - assert "Couldn't find number in scope general.sub2 . number is only defined in ['general.sub']" in e.value.args[0] + assert "Couldn't find number in scope general.sub2 . number is only defined in ['general.sub']" in e.value.args[ + 0] def test_list_all_scopes(self, ds): ds.set("number", 22, "general2") @@ -135,9 +135,9 @@ class TestDataStoreByVariable: ds.set("number2", 3, "general.sub.sub") ds.set("number", "ABC", "general.sub.sub") assert ds.search_scope("general.sub", current_scope_only=False, return_all=True) == \ - [("number", "general.sub", 11), ("number1", "general.sub", 22)] + [("number", "general.sub", 11), ("number1", "general.sub", 22)] assert ds.search_scope("general.sub.sub", current_scope_only=False, return_all=True) == \ - [("number", "general.sub.sub", "ABC"), ("number1", "general.sub", 22), ("number2", "general.sub.sub", 3)] + [("number", "general.sub.sub", "ABC"), ("number1", "general.sub", 22), ("number2", "general.sub.sub", 3)] def test_create_args_dict_default_scope(self, ds_with_content): args = ["tester1", "tester2", "tester3", "tester4"] @@ -153,11 +153,11 @@ class TestDataStoreByVariable: assert ds_with_content.create_args_dict(args) == {"tester1": 1} def test_set_args_from_dict(self, ds): - ds.set_args_from_dict({"tester1": 1, "tester2": 10, "tester3": 21}) + ds.set_from_dict({"tester1": 1, "tester2": 10, "tester3": 21}) assert ds.get("tester1", "general") == 1 assert ds.get("tester2", "general") == 10 assert ds.get("tester3", "general") == 21 - ds.set_args_from_dict({"tester1": 111}, "general.sub") + ds.set_from_dict({"tester1": 111}, "general.sub") assert ds.get("tester1", "general.sub") == 111 assert ds.get("tester3", "general.sub") == 21 @@ -231,7 +231,8 @@ class TestDataStoreByScope: ds.set("number", 11, "general.sub") with pytest.raises(NameNotFoundInScope) as e: ds.get("number", "general.sub2") - assert "Couldn't find number in scope general.sub2 . number is only defined in ['general.sub']" in e.value.args[0] + assert "Couldn't find number in scope general.sub2 . number is only defined in ['general.sub']" in e.value.args[ + 0] def test_list_all_scopes(self, ds): ds.set("number", 22, "general2") @@ -286,9 +287,9 @@ class TestDataStoreByScope: ds.set("number2", 3, "general.sub.sub") ds.set("number", "ABC", "general.sub.sub") assert ds.search_scope("general.sub", current_scope_only=False, return_all=True) == \ - [("number", "general.sub", 11), ("number1", "general.sub", 22)] + [("number", "general.sub", 11), ("number1", "general.sub", 22)] assert ds.search_scope("general.sub.sub", current_scope_only=False, return_all=True) == \ - [("number", "general.sub.sub", "ABC"), ("number1", "general.sub", 22), ("number2", "general.sub.sub", 3)] + [("number", "general.sub.sub", "ABC"), ("number1", "general.sub", 22), ("number2", "general.sub.sub", 3)] def test_create_args_dict_default_scope(self, ds_with_content): args = ["tester1", "tester2", "tester3", "tester4"] @@ -304,11 +305,11 @@ class TestDataStoreByScope: assert ds_with_content.create_args_dict(args) == {"tester1": 1} def test_set_args_from_dict(self, ds): - ds.set_args_from_dict({"tester1": 1, "tester2": 10, "tester3": 21}) + ds.set_from_dict({"tester1": 1, "tester2": 10, "tester3": 21}) assert ds.get("tester1", "general") == 1 assert ds.get("tester2", "general") == 10 assert ds.get("tester3", "general") == 21 - ds.set_args_from_dict({"tester1": 111}, "general.sub") + ds.set_from_dict({"tester1": 111}, "general.sub") assert ds.get("tester1", "general.sub") == 111 assert ds.get("tester3", "general.sub") == 21 diff --git a/test/test_helpers.py b/test/test_helpers.py index 0065a94b7b18d88c2e86e60df5633d47ba15f42a..7fd65f909887b8a71bcacd74b7de201ec057824d 100644 --- a/test/test_helpers.py +++ b/test/test_helpers.py @@ -12,147 +12,18 @@ import re from src.helpers import * -class TestToList: - - def test_to_list(self): - assert to_list('a') == ['a'] - assert to_list('abcd') == ['abcd'] - assert to_list([1, 2, 3]) == [1, 2, 3] - assert to_list([45]) == [45] - - -class TestCheckPath: - - def test_check_path_and_create(self, caplog): - caplog.set_level(logging.DEBUG) - path = 'data/test' - assert not os.path.exists('data/test') - check_path_and_create(path) - assert os.path.exists('data/test') - assert caplog.messages[0] == "Created path: data/test" - check_path_and_create(path) - assert caplog.messages[1] == "Path already exists: data/test" - os.rmdir('data/test') - - -class TestLoss: - - def test_l_p_loss(self): - model = keras.Sequential() - model.add(keras.layers.Lambda(lambda x: x, input_shape=(None,))) - model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2)) - hist = model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1) - assert hist.history['loss'][0] == 1.25 - model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(3)) - hist = model.fit(np.array([1, 0, -2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1) - assert hist.history['loss'][0] == 2.25 - - -class TestTimeTracking: - - def test_init(self): - t = TimeTracking() - assert t.start is not None - assert t.start < time.time() - assert t.end is None - t2 = TimeTracking(start=False) - assert t2.start is None - - def test__start(self): - t = TimeTracking(start=False) - t._start() - assert t.start < time.time() - - def test__end(self): - t = TimeTracking() - t._end() - assert t.end > t.start - - def test__duration(self): - t = TimeTracking() - d1 = t._duration() - assert d1 > 0 - d2 = t._duration() - assert d2 > d1 - t._end() - d3 = t._duration() - assert d3 > d2 - assert d3 == t._duration() - - def test_repr(self): - t = TimeTracking() - t._end() - duration = t._duration() - assert t.__repr__().rstrip() == f"{dt.timedelta(seconds=math.ceil(duration))} (hh:mm:ss)".rstrip() - - def test_run(self): - t = TimeTracking(start=False) - assert t.start is None - t.run() - assert t.start is not None - - def test_stop(self): - t = TimeTracking() - assert t.end is None - duration = t.stop(get_duration=True) - assert duration == t._duration() - with pytest.raises(AssertionError) as e: - t.stop() - assert "Time was already stopped" in e.value.args[0] - t.run() - assert t.end is None - assert t.stop() is None - assert t.end is not None - - def test_duration(self): - t = TimeTracking() - duration = t - assert duration is not None - duration = t.stop(get_duration=True) - assert duration == t.duration() - - def test_enter_exit(self, caplog): - caplog.set_level(logging.INFO) - with TimeTracking() as t: - assert t.start is not None - assert t.end is None - expression = PyTestRegex(r"undefined job finished after \d+:\d+:\d+ \(hh:mm:ss\)") - assert caplog.record_tuples[-1] == ('root', 20, expression) - - def test_name_enter_exit(self, caplog): - caplog.set_level(logging.INFO) - with TimeTracking(name="my job") as t: - assert t.start is not None - assert t.end is None - expression = PyTestRegex(r"my job finished after \d+:\d+:\d+ \(hh:mm:ss\)") - assert caplog.record_tuples[-1] == ('root', 20, expression) - - -class TestGetHost: - - @mock.patch("socket.gethostname", side_effect=["linux-aa9b", "ZAM144", "zam347", "jrtest", "jwtest", - "runner-6HmDp9Qd-project-2411-concurrent"]) - def test_get_host(self, mock_host): - assert get_host() == "linux-aa9b" - assert get_host() == "ZAM144" - assert get_host() == "zam347" - assert get_host() == "jrtest" - assert get_host() == "jwtest" - assert get_host() == "runner-6HmDp9Qd-project-2411-concurrent" - - class TestPrepareHost: @mock.patch("socket.gethostname", side_effect=["linux-aa9b", "ZAM144", "zam347", "jrtest", "jwtest", "runner-6HmDp9Qd-project-2411-concurrent-01"]) - @mock.patch("getpass.getuser", return_value="testUser") + @mock.patch("os.getlogin", return_value="testUser") @mock.patch("os.path.exists", return_value=True) def test_prepare_host(self, mock_host, mock_user, mock_path): assert prepare_host() == "/home/testUser/machinelearningtools/data/toar_daily/" assert prepare_host() == "/home/testUser/Data/toar_daily/" assert prepare_host() == "/home/testUser/Data/toar_daily/" assert prepare_host() == "/p/project/cjjsc42/testUser/DATA/toar_daily/" - assert prepare_host() == "/p/project/deepacf/intelliaq/testUser/DATA/toar_daily/" + assert prepare_host() == "/p/home/jusers/testUser/juwels/intelliaq/DATA/toar_daily/" assert prepare_host() == '/home/testUser/machinelearningtools/data/toar_daily/' @mock.patch("socket.gethostname", return_value="NotExistingHostName") @@ -162,8 +33,7 @@ class TestPrepareHost: prepare_host() assert "unknown host 'NotExistingHostName'" in e.value.args[0] - #@mock.patch("os.getlogin", return_value="zombie21") - @mock.patch("getpass.getuser", return_value="zombie21") + @mock.patch("os.getlogin", return_value="zombie21") @mock.patch("src.helpers.check_path_and_create", side_effect=PermissionError) def test_error_handling(self, mock_cpath, mock_user): # if "runner-6HmDp9Qd-project-2411-concurrent" not in platform.node(): @@ -176,215 +46,29 @@ class TestPrepareHost: # assert "does not exist for host 'linux-aa9b'" in e.value.args[0] assert PyTestRegex(r"path '.*' does not exist for host '.*'\.") == e.value.args[0] + @mock.patch("socket.gethostname", side_effect=["linux-aa9b", "ZAM144", "zam347", "jrtest", "jwtest", + "runner-6HmDp9Qd-project-2411-concurrent-01"]) + @mock.patch("os.getlogin", side_effect=OSError) + @mock.patch("os.path.exists", return_value=True) + def test_os_error(self, mock_path, mock_user, mock_host): + path = prepare_host() + assert path == "/home/default/machinelearningtools/data/toar_daily/" + path = prepare_host() + assert path == "/home/default/Data/toar_daily/" + path = prepare_host() + assert path == "/home/default/Data/toar_daily/" + path = prepare_host() + assert path == "/p/project/cjjsc42/default/DATA/toar_daily/" + path = prepare_host() + assert path == "/p/home/jusers/default/juwels/intelliaq/DATA/toar_daily/" + path = prepare_host() + assert path == '/home/default/machinelearningtools/data/toar_daily/' @mock.patch("socket.gethostname", side_effect=["linux-aa9b"]) - @mock.patch("getpass.getuser", return_value="testUser") + @mock.patch("os.getlogin", return_value="testUser") @mock.patch("os.path.exists", return_value=False) @mock.patch("os.makedirs", side_effect=None) def test_os_path_exists(self, mock_host, mock_user, mock_path, mock_check): path = prepare_host() assert path == "/home/testUser/machinelearningtools/data/toar_daily/" - -class TestSetExperimentName: - - def test_set_experiment(self): - exp_name, exp_path = set_experiment_name() - assert exp_name == "TestExperiment" - assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "TestExperiment")) - exp_name, exp_path = set_experiment_name(experiment_date="2019-11-14", experiment_path="./test2") - assert exp_name == "2019-11-14_network" - assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "test2", exp_name)) - - def test_set_experiment_from_sys(self): - exp_name, _ = set_experiment_name(experiment_date="2019-11-14") - assert exp_name == "2019-11-14_network" - - def test_set_expperiment_hourly(self): - exp_name, exp_path = set_experiment_name(sampling="hourly") - assert exp_name == "TestExperiment_hourly" - assert exp_path == os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "TestExperiment_hourly")) - - -class TestSetBootstrapPath: - - def test_bootstrap_path_is_none(self): - bootstrap_path = set_bootstrap_path(None, 'TestDataPath/', 'daily') - assert bootstrap_path == 'TestDataPath/../bootstrap_daily' - - @mock.patch("os.makedirs", side_effect=None) - def test_bootstap_path_is_given(self, mock_makedir): - bootstrap_path = set_bootstrap_path('Test/path/to/boots', None, None) - assert bootstrap_path == 'Test/path/to/boots' - - -class TestPytestRegex: - - @pytest.fixture - def regex(self): - return PyTestRegex("teststring") - - def test_pytest_regex_init(self, regex): - assert regex._regex.pattern == "teststring" - - def test_pytest_regex_eq(self, regex): - assert regex == "teststringabcd" - assert regex != "teststgabcd" - - def test_pytest_regex_repr(self, regex): - assert regex.__repr__() == "teststring" - - -class TestDictToXarray: - - def test_dict_to_xarray(self): - array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]}) - array2 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]}) - d = {"number1": array1, "number2": array2} - res = dict_to_xarray(d, "merge_dim") - assert type(res) == xr.DataArray - assert sorted(list(res.coords)) == ["merge_dim", "x"] - assert res.shape == (2, 2, 3) - - -class TestFloatRound: - - def test_float_round_ceil(self): - assert float_round(4.6) == 5 - assert float_round(239.3992) == 240 - - def test_float_round_decimals(self): - assert float_round(23.0091, 2) == 23.01 - assert float_round(23.1091, 3) == 23.11 - - def test_float_round_type(self): - assert float_round(34.9221, 2, math.floor) == 34.92 - assert float_round(34.9221, 0, math.floor) == 34. - assert float_round(34.9221, 2, round) == 34.92 - assert float_round(34.9221, 0, round) == 35. - - def test_float_round_negative(self): - assert float_round(-34.9221, 2, math.floor) == -34.93 - assert float_round(-34.9221, 0, math.floor) == -35. - assert float_round(-34.9221, 2) == -34.92 - assert float_round(-34.9221, 0) == -34. - - -class TestDictPop: - - @pytest.fixture - def custom_dict(self): - return {'a': 1, 'b': 2, 2: 'ab'} - - def test_dict_pop_single(self, custom_dict): - # one out as list - d_pop = dict_pop(custom_dict, [4]) - assert d_pop == custom_dict - # one out as str - d_pop = dict_pop(custom_dict, '4') - assert d_pop == custom_dict - # one in as str - d_pop = dict_pop(custom_dict, 'b') - assert d_pop == {'a': 1, 2: 'ab'} - # one in as list - d_pop = dict_pop(custom_dict, ['b']) - assert d_pop == {'a': 1, 2: 'ab'} - - def test_dict_pop_multiple(self, custom_dict): - # all out (list) - d_pop = dict_pop(custom_dict, [4, 'mykey']) - assert d_pop == custom_dict - # all in (list) - d_pop = dict_pop(custom_dict, ['a', 2]) - assert d_pop == {'b': 2} - # one in one out (list) - d_pop = dict_pop(custom_dict, [2, '10']) - assert d_pop == {'a': 1, 'b': 2} - - def test_dict_pop_missing_argument(self, custom_dict): - with pytest.raises(TypeError) as e: - dict_pop() - assert "dict_pop() missing 2 required positional arguments: 'dict_orig' and 'pop_keys'" in e.value.args[0] - with pytest.raises(TypeError) as e: - dict_pop(custom_dict) - assert "dict_pop() missing 1 required positional argument: 'pop_keys'" in e.value.args[0] - - -class TestListPop: - - @pytest.fixture - def custom_list(self): - return [1, 2, 3, 'a', 'bc'] - - def test_list_pop_single(self, custom_list): - l_pop = list_pop(custom_list, 1) - assert l_pop == [2, 3, 'a', 'bc'] - l_pop = list_pop(custom_list, 'bc') - assert l_pop == [1, 2, 3, 'a'] - l_pop = list_pop(custom_list, 5) - assert l_pop == custom_list - - def test_list_pop_multiple(self, custom_list): - # all in list - l_pop = list_pop(custom_list, [2, 'a']) - assert l_pop == [1, 3, 'bc'] - # one in one out - l_pop = list_pop(custom_list, ['bc', 10]) - assert l_pop == [1, 2, 3, 'a'] - # all out - l_pop = list_pop(custom_list, [10, 'aa']) - assert l_pop == custom_list - - def test_list_pop_missing_argument(self, custom_list): - with pytest.raises(TypeError) as e: - list_pop() - assert "list_pop() missing 2 required positional arguments: 'list_full' and 'pop_items'" in e.value.args[0] - with pytest.raises(TypeError) as e: - list_pop(custom_list) - assert "list_pop() missing 1 required positional argument: 'pop_items'" in e.value.args[0] - - -class TestLogger: - - @pytest.fixture - def logger(self): - return Logger() - - def test_init_default(self): - log = Logger() - assert log.formatter == "%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]" - assert log.log_file == Logger.setup_logging_path() - # assert PyTestRegex( - # ".*machinelearningtools/src/\.{2}/logging/logging_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\.log") == log.log_file - - def test_setup_logging_path_none(self): - log_file = Logger.setup_logging_path(None) - assert PyTestRegex( - ".*machinelearningtools/src/\.{2}/logging/logging_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\.log") == log_file - - @mock.patch("os.makedirs", side_effect=None) - def test_setup_logging_path_given(self, mock_makedirs): - path = "my/test/path" - log_path = Logger.setup_logging_path(path) - assert PyTestRegex("my/test/path/logging_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\.log") == log_path - - def test_logger_console_level0(self, logger): - consol = logger.logger_console(0) - assert isinstance(consol, logging.StreamHandler) - assert consol.level == 0 - formatter = logging.Formatter(logger.formatter) - assert isinstance(formatter, logging.Formatter) - - def test_logger_console_level1(self, logger): - consol = logger.logger_console(1) - assert isinstance(consol, logging.StreamHandler) - assert consol.level == 1 - formatter = logging.Formatter(logger.formatter) - assert isinstance(formatter, logging.Formatter) - - def test_logger_console_level_wrong_type(self, logger): - with pytest.raises(TypeError) as e: - logger.logger_console(1.5) - assert "Level not an integer or a valid string: 1.5" == e.value.args[0] - - diff --git a/test/test_helpers/test_helpers.py b/test/test_helpers/test_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..28a8bf6e421d62d58d76e7a32906f8a594f16ed7 --- /dev/null +++ b/test/test_helpers/test_helpers.py @@ -0,0 +1,265 @@ +import numpy as np +import xarray as xr + +import datetime as dt +import logging +import math +import time + +import mock +import pytest + +from src.helpers import to_list, dict_to_xarray, float_round, remove_items +from src.helpers import PyTestRegex +from src.helpers import Logger, TimeTracking + + +class TestToList: + + def test_to_list(self): + assert to_list('a') == ['a'] + assert to_list('abcd') == ['abcd'] + assert to_list([1, 2, 3]) == [1, 2, 3] + assert to_list([45]) == [45] + + +class TestTimeTracking: + + def test_init(self): + t = TimeTracking() + assert t.start is not None + assert t.start < time.time() + assert t.end is None + t2 = TimeTracking(start=False) + assert t2.start is None + + def test__start(self): + t = TimeTracking(start=False) + t._start() + assert t.start < time.time() + + def test__end(self): + t = TimeTracking() + t._end() + assert t.end > t.start + + def test__duration(self): + t = TimeTracking() + d1 = t._duration() + assert d1 > 0 + d2 = t._duration() + assert d2 > d1 + t._end() + d3 = t._duration() + assert d3 > d2 + assert d3 == t._duration() + + def test_repr(self): + t = TimeTracking() + t._end() + duration = t._duration() + assert t.__repr__().rstrip() == f"{dt.timedelta(seconds=math.ceil(duration))} (hh:mm:ss)".rstrip() + + def test_run(self): + t = TimeTracking(start=False) + assert t.start is None + t.run() + assert t.start is not None + + def test_stop(self): + t = TimeTracking() + assert t.end is None + duration = t.stop(get_duration=True) + assert duration == t._duration() + with pytest.raises(AssertionError) as e: + t.stop() + assert "Time was already stopped" in e.value.args[0] + t.run() + assert t.end is None + assert t.stop() is None + assert t.end is not None + + def test_duration(self): + t = TimeTracking() + duration = t + assert duration is not None + duration = t.stop(get_duration=True) + assert duration == t.duration() + + def test_enter_exit(self, caplog): + caplog.set_level(logging.INFO) + with TimeTracking() as t: + assert t.start is not None + assert t.end is None + expression = PyTestRegex(r"undefined job finished after \d+:\d+:\d+ \(hh:mm:ss\)") + assert caplog.record_tuples[-1] == ('root', 20, expression) + + def test_name_enter_exit(self, caplog): + caplog.set_level(logging.INFO) + with TimeTracking(name="my job") as t: + assert t.start is not None + assert t.end is None + expression = PyTestRegex(r"my job finished after \d+:\d+:\d+ \(hh:mm:ss\)") + assert caplog.record_tuples[-1] == ('root', 20, expression) + + +class TestPytestRegex: + + @pytest.fixture + def regex(self): + return PyTestRegex("teststring") + + def test_pytest_regex_init(self, regex): + assert regex._regex.pattern == "teststring" + + def test_pytest_regex_eq(self, regex): + assert regex == "teststringabcd" + assert regex != "teststgabcd" + + def test_pytest_regex_repr(self, regex): + assert regex.__repr__() == "teststring" + + +class TestDictToXarray: + + def test_dict_to_xarray(self): + array1 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]}) + array2 = xr.DataArray(np.random.randn(2, 3), dims=('x', 'y'), coords={'x': [10, 20]}) + d = {"number1": array1, "number2": array2} + res = dict_to_xarray(d, "merge_dim") + assert type(res) == xr.DataArray + assert sorted(list(res.coords)) == ["merge_dim", "x"] + assert res.shape == (2, 2, 3) + + +class TestFloatRound: + + def test_float_round_ceil(self): + assert float_round(4.6) == 5 + assert float_round(239.3992) == 240 + + def test_float_round_decimals(self): + assert float_round(23.0091, 2) == 23.01 + assert float_round(23.1091, 3) == 23.11 + + def test_float_round_type(self): + assert float_round(34.9221, 2, math.floor) == 34.92 + assert float_round(34.9221, 0, math.floor) == 34. + assert float_round(34.9221, 2, round) == 34.92 + assert float_round(34.9221, 0, round) == 35. + + def test_float_round_negative(self): + assert float_round(-34.9221, 2, math.floor) == -34.93 + assert float_round(-34.9221, 0, math.floor) == -35. + assert float_round(-34.9221, 2) == -34.92 + assert float_round(-34.9221, 0) == -34. + + +class TestRemoveItems: + + @pytest.fixture + def custom_list(self): + return [1, 2, 3, 'a', 'bc'] + + @pytest.fixture + def custom_dict(self): + return {'a': 1, 'b': 2, 2: 'ab'} + + def test_dict_remove_single(self, custom_dict): + # one out as list + d_pop = remove_items(custom_dict, [4]) + assert d_pop == custom_dict + # one out as str + d_pop = remove_items(custom_dict, '4') + assert d_pop == custom_dict + # one in as str + d_pop = remove_items(custom_dict, 'b') + assert d_pop == {'a': 1, 2: 'ab'} + # one in as list + d_pop = remove_items(custom_dict, ['b']) + assert d_pop == {'a': 1, 2: 'ab'} + + def test_dict_remove_multiple(self, custom_dict): + # all out (list) + d_pop = remove_items(custom_dict, [4, 'mykey']) + assert d_pop == custom_dict + # all in (list) + d_pop = remove_items(custom_dict, ['a', 2]) + assert d_pop == {'b': 2} + # one in one out (list) + d_pop = remove_items(custom_dict, [2, '10']) + assert d_pop == {'a': 1, 'b': 2} + + def test_list_remove_single(self, custom_list): + l_pop = remove_items(custom_list, 1) + assert l_pop == [2, 3, 'a', 'bc'] + l_pop = remove_items(custom_list, 'bc') + assert l_pop == [1, 2, 3, 'a'] + l_pop = remove_items(custom_list, 5) + assert l_pop == custom_list + + def test_list_remove_multiple(self, custom_list): + # all in list + l_pop = remove_items(custom_list, [2, 'a']) + assert l_pop == [1, 3, 'bc'] + # one in one out + l_pop = remove_items(custom_list, ['bc', 10]) + assert l_pop == [1, 2, 3, 'a'] + # all out + l_pop = remove_items(custom_list, [10, 'aa']) + assert l_pop == custom_list + + def test_remove_missing_argument(self, custom_dict, custom_list): + with pytest.raises(TypeError) as e: + remove_items() + assert "remove_items() missing 2 required positional arguments: 'obj' and 'items'" in e.value.args[0] + with pytest.raises(TypeError) as e: + remove_items(custom_dict) + assert "remove_items() missing 1 required positional argument: 'items'" in e.value.args[0] + with pytest.raises(TypeError) as e: + remove_items(custom_list) + assert "remove_items() missing 1 required positional argument: 'items'" in e.value.args[0] + + +class TestLogger: + + @pytest.fixture + def logger(self): + return Logger() + + def test_init_default(self): + log = Logger() + assert log.formatter == "%(asctime)s - %(levelname)s: %(message)s [%(filename)s:%(funcName)s:%(lineno)s]" + assert log.log_file == Logger.setup_logging_path() + # assert PyTestRegex( + # ".*machinelearningtools/src/\.{2}/logging/logging_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\.log") == log.log_file + + def test_setup_logging_path_none(self): + log_file = Logger.setup_logging_path(None) + assert PyTestRegex( + ".*machinelearningtools/logging/logging_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\.log") == log_file + + @mock.patch("os.makedirs", side_effect=None) + def test_setup_logging_path_given(self, mock_makedirs): + path = "my/test/path" + log_path = Logger.setup_logging_path(path) + assert PyTestRegex("my/test/path/logging_\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2}\.log") == log_path + + def test_logger_console_level0(self, logger): + consol = logger.logger_console(0) + assert isinstance(consol, logging.StreamHandler) + assert consol.level == 0 + formatter = logging.Formatter(logger.formatter) + assert isinstance(formatter, logging.Formatter) + + def test_logger_console_level1(self, logger): + consol = logger.logger_console(1) + assert isinstance(consol, logging.StreamHandler) + assert consol.level == 1 + formatter = logging.Formatter(logger.formatter) + assert isinstance(formatter, logging.Formatter) + + def test_logger_console_level_wrong_type(self, logger): + with pytest.raises(TypeError) as e: + logger.logger_console(1.5) + assert "Level not an integer or a valid string: 1.5" == e.value.args[0] diff --git a/test/test_join.py b/test/test_join.py index fe3d33d6296c16bfc72675bc1737aad12ee3c8b9..5adc013cfbd446c4feaf4a2b344f07d6f170077d 100644 --- a/test/test_join.py +++ b/test/test_join.py @@ -2,9 +2,9 @@ from typing import Iterable import pytest -from src.join import * -from src.join import _save_to_pandas, _correct_stat_name, _lower_list -from src.join_settings import join_settings +from src.helpers.join import * +from src.helpers.join import _save_to_pandas, _correct_stat_name, _lower_list +from src.configuration.join_settings import join_settings class TestJoinUrlBase: @@ -53,7 +53,8 @@ class TestLoadSeriesInformation: def test_standard_query(self): expected_subset = {'o3': 23031, 'no2': 39002, 'temp--lubw': 17059, 'wspeed': 17060} - assert expected_subset.items() <= load_series_information(['DEBW107'], None, None, join_settings()[0], {}).items() + assert expected_subset.items() <= load_series_information(['DEBW107'], None, None, join_settings()[0], + {}).items() def test_empty_result(self): assert load_series_information(['DEBW107'], "traffic", None, join_settings()[0], {}) == {} @@ -137,4 +138,3 @@ class TestCreateUrl: def test_none_kwargs(self): url = create_url("www.base2.edu/", "testingservice", mood="sad", happiness=None, stress_factor=100) assert url == "www.base2.edu/testingservice/?mood=sad&stress_factor=100" - diff --git a/test/test_model_modules/test_advanced_paddings.py b/test/test_model_modules/test_advanced_paddings.py index bbeaf1c745a63b3607062b0c4052088c9af06b92..8c7cae91ad12cc2b06ec82ba64f91c792a620756 100644 --- a/test/test_model_modules/test_advanced_paddings.py +++ b/test/test_model_modules/test_advanced_paddings.py @@ -69,10 +69,10 @@ class TestPadUtils: ################################################################################## def test_check_padding_format_negative_pads(self): - with pytest.raises(ValueError) as einfo: PadUtils.check_padding_format((-2, 1)) - assert "The `1st entry of padding` argument must be >= 0. Received: -2 of type <class 'int'>" in str(einfo.value) + assert "The `1st entry of padding` argument must be >= 0. Received: -2 of type <class 'int'>" in str( + einfo.value) with pytest.raises(ValueError) as einfo: PadUtils.check_padding_format((1, -1)) @@ -198,15 +198,18 @@ class TestReflectionPadding2D: def test_init_tuple_of_negative_int(self): with pytest.raises(ValueError) as einfo: ReflectionPadding2D(padding=(-1, 1)) - assert "The `1st entry of padding` argument must be >= 0. Received: -1 of type <class 'int'>" in str(einfo.value) + assert "The `1st entry of padding` argument must be >= 0. Received: -1 of type <class 'int'>" in str( + einfo.value) with pytest.raises(ValueError) as einfo: ReflectionPadding2D(padding=(1, -2)) - assert "The `2nd entry of padding` argument must be >= 0. Received: -2 of type <class 'int'>" in str(einfo.value) + assert "The `2nd entry of padding` argument must be >= 0. Received: -2 of type <class 'int'>" in str( + einfo.value) with pytest.raises(ValueError) as einfo: ReflectionPadding2D(padding=(-1, -2)) - assert "The `1st entry of padding` argument must be >= 0. Received: -1 of type <class 'int'>" in str(einfo.value) + assert "The `1st entry of padding` argument must be >= 0. Received: -1 of type <class 'int'>" in str( + einfo.value) def test_init_tuple_of_invalid_format_float(self): with pytest.raises(ValueError) as einfo: @@ -434,7 +437,6 @@ class TestPadding2D: 'ZeroPad2D': ZeroPadding2D, 'ZeroPadding2D': ZeroPadding2D } - def test_check_and_get_padding_zero_padding(self): assert Padding2D('ZeroPad2D')._check_and_get_padding() == ZeroPadding2D assert Padding2D('ZeroPadding2D')._check_and_get_padding() == ZeroPadding2D @@ -450,14 +452,14 @@ class TestPadding2D: assert Padding2D('ReflectionPadding2D')._check_and_get_padding() == ReflectionPadding2D assert Padding2D(ReflectionPadding2D)._check_and_get_padding() == ReflectionPadding2D - def test_check_and_get_padding_raises(self,): + def test_check_and_get_padding_raises(self, ): with pytest.raises(NotImplementedError) as einfo: Padding2D('FalsePadding2D')._check_and_get_padding() assert "`'FalsePadding2D'' is not implemented as padding. " \ "Use one of those: i) `RefPad2D', ii) `SymPad2D', iii) `ZeroPad2D'" in str(einfo.value) with pytest.raises(TypeError) as einfo: Padding2D(keras.layers.Conv2D)._check_and_get_padding() - assert "`Conv2D' is not a valid padding layer type. Use one of those: "\ + assert "`Conv2D' is not a valid padding layer type. Use one of those: " \ "i) ReflectionPadding2D, ii) SymmetricPadding2D, iii) ZeroPadding2D" in str(einfo.value) @pytest.mark.parametrize("pad_type", ["SymPad2D", "SymmetricPadding2D", SymmetricPadding2D, @@ -469,9 +471,8 @@ class TestPadding2D: layer_name = pad_type.__name__ else: layer_name = pad_type - pd_ap = pd(padding=(1,2), name=f"{layer_name}_layer")(input_x) + pd_ap = pd(padding=(1, 2), name=f"{layer_name}_layer")(input_x) assert pd_ap._keras_history[0].input_shape == (None, 32, 32, 3) assert pd_ap._keras_history[0].output_shape == (None, 34, 36, 3) assert pd_ap._keras_history[0].padding == ((1, 1), (2, 2)) assert pd_ap._keras_history[0].name == f"{layer_name}_layer" - diff --git a/test/test_model_modules/test_inception_model.py b/test/test_model_modules/test_inception_model.py index e5e92158425a73c5af1c6d1623d970e1037bbd80..ca0126a44fa0f8ccd2ed2a7ea79c872c4731fea1 100644 --- a/test/test_model_modules/test_inception_model.py +++ b/test/test_model_modules/test_inception_model.py @@ -1,9 +1,9 @@ import keras import pytest -from src.model_modules.inception_model import InceptionModelBase -from src.model_modules.advanced_paddings import ReflectionPadding2D, SymmetricPadding2D from src.helpers import PyTestRegex +from src.model_modules.advanced_paddings import ReflectionPadding2D, SymmetricPadding2D +from src.model_modules.inception_model import InceptionModelBase class TestInceptionModelBase: diff --git a/test/test_model_modules/test_keras_extensions.py b/test/test_model_modules/test_keras_extensions.py index 35188933c476157a6ba8c244d3647f7d6f8bdc59..56c60ec43173e9fdd438214862603caba632bc65 100644 --- a/test/test_model_modules/test_keras_extensions.py +++ b/test/test_model_modules/test_keras_extensions.py @@ -1,10 +1,10 @@ +import os + import keras -import numpy as np -import pytest import mock -import os +import pytest -from src.helpers import l_p_loss +from src.model_modules.loss import l_p_loss from src.model_modules.keras_extensions import * @@ -70,12 +70,13 @@ class TestModelCheckpointAdvanced: def callbacks(self): callbacks_name = os.path.join(os.path.dirname(__file__), "callback_%s") return [{"callback": LearningRateDecay(), "path": callbacks_name % "lr"}, - {"callback": HistoryAdvanced(), "path": callbacks_name % "hist"}] + {"callback": HistoryAdvanced(), "path": callbacks_name % "hist"}] @pytest.fixture def ckpt(self, callbacks): ckpt_name = "ckpt.test" - return ModelCheckpointAdvanced(filepath=ckpt_name, monitor='val_loss', save_best_only=True, callbacks=callbacks, verbose=1) + return ModelCheckpointAdvanced(filepath=ckpt_name, monitor='val_loss', save_best_only=True, callbacks=callbacks, + verbose=1) def test_init(self, ckpt, callbacks): assert ckpt.callbacks == callbacks @@ -185,7 +186,6 @@ class TestCallbackHandler: clbk_handler.add_callback("callback_new_instance", "this_path") assert 'CallbackHandler is protected and cannot be edited.' in str(einfo.value) - def test_get_callbacks_as_dict(self, clbk_handler_with_dummies): clbk = clbk_handler_with_dummies assert clbk.get_callbacks() == [{"callback": "callback_new_instance", "path": "this_path"}, diff --git a/test/test_model_modules/test_linear_model.py b/test/test_model_modules/test_linear_model.py index e4e10e9db04ba041d61d6ebcf5de3a23380c8ebe..0fab7ae30472c6ae966331f30598b73c1ec48117 100644 --- a/test/test_model_modules/test_linear_model.py +++ b/test/test_model_modules/test_linear_model.py @@ -1,7 +1,3 @@ - -from src.model_modules.linear_model import OrdinaryLeastSquaredModel - - class TestOrdinaryLeastSquareModel: def test_constant_input_variable(self): diff --git a/test/test_model_modules/test_loss.py b/test/test_model_modules/test_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..c47f3f188a4b360bda08470fb00fd1d88a9f754c --- /dev/null +++ b/test/test_model_modules/test_loss.py @@ -0,0 +1,17 @@ +import keras +import numpy as np + +from src.model_modules.loss import l_p_loss + + +class TestLoss: + + def test_l_p_loss(self): + model = keras.Sequential() + model.add(keras.layers.Lambda(lambda x: x, input_shape=(None,))) + model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(2)) + hist = model.fit(np.array([1, 0, 2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1) + assert hist.history['loss'][0] == 1.25 + model.compile(optimizer=keras.optimizers.Adam(), loss=l_p_loss(3)) + hist = model.fit(np.array([1, 0, -2, 0.5]), np.array([1, 1, 0, 0.5]), epochs=1) + assert hist.history['loss'][0] == 2.25 \ No newline at end of file diff --git a/test/test_model_modules/test_model_class.py b/test/test_model_modules/test_model_class.py index a8df3fe7213eef476b2ef7dbeac29d84b698f05a..0ee2eb7e5d439c76888f1f05e238bb5507db6a7a 100644 --- a/test/test_model_modules/test_model_class.py +++ b/test/test_model_modules/test_model_class.py @@ -2,7 +2,7 @@ import keras import pytest from src.model_modules.model_class import AbstractModelClass -from src.model_modules.model_class import MyPaperModel, MyTowerModel, MyLittleModel, MyBranchedModel +from src.model_modules.model_class import MyPaperModel class Paddings: @@ -229,4 +229,3 @@ class TestMyPaperModel: def test_set_compile_options(self, mpm): assert callable(mpm.compile_options["loss"]) or (len(mpm.compile_options["loss"]) > 0) - diff --git a/test/test_modules/test_experiment_setup.py b/test/test_modules/test_experiment_setup.py index a3a83acf84e286d1f5da9b5caffa256fc0ca3327..e06ba6c0ce5b9abb169e20016342b2a0dfb47d0f 100644 --- a/test/test_modules/test_experiment_setup.py +++ b/test/test_modules/test_experiment_setup.py @@ -4,7 +4,8 @@ import os import pytest -from src.helpers import TimeTracking, prepare_host +from src.helpers import TimeTracking +from src.configuration.path_config import prepare_host from src.run_modules.experiment_setup import ExperimentSetup @@ -51,8 +52,8 @@ class TestExperimentSetup: assert data_store.get("create_new_model", "general") is True assert data_store.get("fraction_of_training", "general") == 0.8 # set experiment name - assert data_store.get("experiment_name", "general") == "TestExperiment" - path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "TestExperiment")) + assert data_store.get("experiment_name", "general") == "TestExperiment_daily" + path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "TestExperiment_daily")) assert data_store.get("experiment_path", "general") == path default_statistics_per_var = {'o3': 'dma8eu', 'relhum': 'average_values', 'temp': 'maximum', 'u': 'average_values', 'v': 'average_values', 'no': 'dma8eu', 'no2': 'dma8eu', @@ -120,9 +121,9 @@ class TestExperimentSetup: assert data_store.get("create_new_model", "general") is True assert data_store.get("fraction_of_training", "general") == 0.5 # set experiment name - assert data_store.get("experiment_name", "general") == "TODAY_network" + assert data_store.get("experiment_name", "general") == "TODAY_network_daily" path = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "data", "testExperimentFolder", - "TODAY_network")) + "TODAY_network_daily")) assert data_store.get("experiment_path", "general") == path # setup for data assert data_store.get("stations", "general") == ['DEBY053', 'DEBW059', 'DEBW027'] diff --git a/test/test_modules/test_model_setup.py b/test/test_modules/test_model_setup.py index 9ff7494ff0540c9c96c1343b4f44fece08bfe4ce..24e89772cd2e9ddc4418617e2a622fe325d94b2f 100644 --- a/test/test_modules/test_model_setup.py +++ b/test/test_modules/test_model_setup.py @@ -3,7 +3,7 @@ import os import pytest from src.data_handling.data_generator import DataGenerator -from src.datastore import EmptyScope +from src.helpers.datastore import EmptyScope from src.model_modules.keras_extensions import CallbackHandler from src.model_modules.model_class import AbstractModelClass from src.run_modules.model_setup import ModelSetup @@ -16,7 +16,7 @@ class TestModelSetup: def setup(self): obj = object.__new__(ModelSetup) super(ModelSetup, obj).__init__() - obj.scope = "general.modeltest" + obj.scope = "general.model" obj.model = None obj.callbacks_name = "placeholder_%s_str.pickle" obj.data_store.set("lr_decay", "dummy_str", "general.model") @@ -58,24 +58,25 @@ class TestModelSetup: return set(model_cls.data_store.search_scope(model_cls.scope, current_scope_only=True)) def test_set_callbacks(self, setup): - assert "general.modeltest" not in setup.data_store.search_name("callbacks") + assert "general.model" not in setup.data_store.search_name("callbacks") setup.checkpoint_name = "TestName" setup._set_callbacks() - assert "general.modeltest" in setup.data_store.search_name("callbacks") - callbacks = setup.data_store.get("callbacks", "general.modeltest") + assert "general.model" in setup.data_store.search_name("callbacks") + callbacks = setup.data_store.get("callbacks", "general.model") assert len(callbacks.get_callbacks()) == 3 def test_set_callbacks_no_lr_decay(self, setup): setup.data_store.set("lr_decay", None, "general.model") - assert "general.modeltest" not in setup.data_store.search_name("callbacks") + assert "general.model" not in setup.data_store.search_name("callbacks") setup.checkpoint_name = "TestName" setup._set_callbacks() - callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.modeltest") + callbacks: CallbackHandler = setup.data_store.get("callbacks", "general.model") assert len(callbacks.get_callbacks()) == 2 with pytest.raises(IndexError): callbacks.get_callback_by_name("lr_decay") def test_get_model_settings(self, setup_with_model): + setup_with_model.scope = "model_test" with pytest.raises(EmptyScope): self.current_scope_as_set(setup_with_model) # will fail because scope is not created setup_with_model.get_model_settings() # this saves now the parameters epochs and batch_size into scope @@ -105,4 +106,3 @@ class TestModelSetup: def test_init(self): pass - diff --git a/test/test_modules/test_pre_processing.py b/test/test_modules/test_pre_processing.py index b29ed1e21480a869e4c118332c18b6edd8ac23a5..6abc722273613a1f4d6727396b114939b4d6a552 100644 --- a/test/test_modules/test_pre_processing.py +++ b/test/test_modules/test_pre_processing.py @@ -3,7 +3,7 @@ import logging import pytest from src.data_handling.data_generator import DataGenerator -from src.datastore import NameNotFoundInScope +from src.helpers.datastore import NameNotFoundInScope from src.helpers import PyTestRegex from src.run_modules.experiment_setup import ExperimentSetup from src.run_modules.pre_processing import PreProcessing, DEFAULT_ARGS_LIST, DEFAULT_KWARGS_LIST @@ -63,9 +63,9 @@ class TestPreProcessing: def test_create_set_split_not_all_stations(self, caplog, obj_with_exp_setup): caplog.set_level(logging.DEBUG) - obj_with_exp_setup.data_store.set("use_all_stations_on_all_data_sets", False, "general.awesome") + obj_with_exp_setup.data_store.set("use_all_stations_on_all_data_sets", False, "general") obj_with_exp_setup.create_set_split(slice(0, 2), "awesome") - assert caplog.record_tuples[0] == ('root', 10, "Awesome stations (len=2): ['DEBW107', 'DEBY081']") + assert ('root', 10, "Awesome stations (len=2): ['DEBW107', 'DEBY081']") in caplog.record_tuples data_store = obj_with_exp_setup.data_store assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator) with pytest.raises(NameNotFoundInScope): @@ -75,8 +75,8 @@ class TestPreProcessing: def test_create_set_split_all_stations(self, caplog, obj_with_exp_setup): caplog.set_level(logging.DEBUG) obj_with_exp_setup.create_set_split(slice(0, 2), "awesome") - assert caplog.record_tuples[0] == ('root', 10, "Awesome stations (len=6): ['DEBW107', 'DEBY081', 'DEBW013', " - "'DEBW076', 'DEBW087', 'DEBW001']") + message = "Awesome stations (len=6): ['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW001']" + assert ('root', 10, message) in caplog.record_tuples data_store = obj_with_exp_setup.data_store assert isinstance(data_store.get("generator", "general.awesome"), DataGenerator) with pytest.raises(NameNotFoundInScope): diff --git a/test/test_modules/test_run_environment.py b/test/test_modules/test_run_environment.py index d82675b57ea6feb4f83c99dab6f648c2846e4137..59bb8535c4dab44e646bd6bc4aa83a8553be4d26 100644 --- a/test/test_modules/test_run_environment.py +++ b/test/test_modules/test_run_environment.py @@ -17,7 +17,7 @@ class TestRunEnvironment: with RunEnvironment() as r: r.do_stuff(0.1) expression = PyTestRegex(r"RunEnvironment finished after \d+:\d+:\d+ \(hh:mm:ss\)") - assert caplog.record_tuples[-1] == ('root', 20, expression) + assert ('root', 20, expression) in caplog.record_tuples[-3:] def test_init(self, caplog): caplog.set_level(logging.INFO) @@ -30,4 +30,4 @@ class TestRunEnvironment: r.do_stuff(0.2) del r expression = PyTestRegex(r"RunEnvironment finished after \d+:\d+:\d+ \(hh:mm:ss\)") - assert caplog.record_tuples[-1] == ('root', 20, expression) + assert ('root', 20, expression) in caplog.record_tuples[-3:] diff --git a/test/test_modules/test_training.py b/test/test_modules/test_training.py index d3127de1afe0c1691b72dca0408e428fb5944bf4..33f9ddf62bd91c870643727de4d146ce332fbe07 100644 --- a/test/test_modules/test_training.py +++ b/test/test_modules/test_training.py @@ -7,7 +7,7 @@ import shutil import keras import mock import pytest -from keras.callbacks import ModelCheckpoint, History +from keras.callbacks import History from src.data_handling.data_distributor import Distributor from src.data_handling.data_generator import DataGenerator @@ -186,7 +186,8 @@ class TestTraining: assert all([getattr(init_without_run, f"{obj}_set") is None for obj in sets]) init_without_run.set_generators() assert not all([getattr(init_without_run, f"{obj}_set") is None for obj in sets]) - assert all([getattr(init_without_run, f"{obj}_set").generator.return_value == f"mock_{obj}_gen" for obj in sets]) + assert all( + [getattr(init_without_run, f"{obj}_set").generator.return_value == f"mock_{obj}_gen" for obj in sets]) def test_train(self, ready_to_train, path): assert not hasattr(ready_to_train.model, "history") @@ -201,7 +202,8 @@ class TestTraining: model_name = "test_model.h5" assert model_name not in os.listdir(path) init_without_run.save_model() - assert caplog.record_tuples[0] == ("root", 10, PyTestRegex(f"save best model to {os.path.join(path, model_name)}")) + message = PyTestRegex(f"save best model to {os.path.join(path, model_name)}") + assert caplog.record_tuples[1] == ("root", 10, message) assert model_name in os.listdir(path) def test_load_best_model_no_weights(self, init_without_run, caplog): diff --git a/test/test_plotting/test_tracker_plot.py b/test/test_plotting/test_tracker_plot.py new file mode 100644 index 0000000000000000000000000000000000000000..9a92360a819c130c213d06b89a48a896e082adad --- /dev/null +++ b/test/test_plotting/test_tracker_plot.py @@ -0,0 +1,447 @@ +import pytest + +from collections import OrderedDict +import os +import shutil + +from matplotlib import pyplot as plt +import numpy as np + +from src.plotting.tracker_plot import TrackObject, TrackChain, TrackPlot +from src.helpers import PyTestAllEqual + + +class TestTrackObject: + + @pytest.fixture + def track_obj(self): + return TrackObject("custom_name", "your_stage") + + def test_init(self, track_obj): + assert track_obj.name == ["custom_name"] + assert track_obj.stage == "your_stage" + assert all(track_obj.__getattribute__(obj) is None for obj in ["precursor", "successor", "x", "y"]) + + def test_repr(self, track_obj): + track_obj.name = ["custom", "name"] + assert repr(track_obj) == "custom/name" + + def test_x_property(self, track_obj): + assert track_obj.x is None + track_obj.x = 23 + assert track_obj.x == 23 + + def test_y_property(self, track_obj): + assert track_obj.y is None + track_obj.y = 21 + assert track_obj.y == 21 + + def test_add_precursor(self, track_obj): + assert track_obj.precursor is None + another_track_obj = TrackObject(["another", "track"], "your_stage") + track_obj.add_precursor(another_track_obj) + assert isinstance(track_obj.precursor, list) + assert track_obj.precursor[-1] == another_track_obj + assert len(track_obj.precursor) == 1 + assert another_track_obj.successor is not None + track_obj.add_precursor(another_track_obj) + assert len(track_obj.precursor) == 1 + track_obj.add_precursor(TrackObject(["third", "track"], "your_stage")) + assert len(track_obj.precursor) == 2 + + def test_add_successor(self, track_obj): + assert track_obj.successor is None + another_track_obj = TrackObject(["another", "track"], "your_stage") + track_obj.add_successor(another_track_obj) + assert isinstance(track_obj.successor, list) + assert track_obj.successor[-1] == another_track_obj + assert len(track_obj.successor) == 1 + assert another_track_obj.precursor is not None + track_obj.add_successor(another_track_obj) + assert len(track_obj.successor) == 1 + track_obj.add_successor(TrackObject(["third", "track"], "your_stage")) + assert len(track_obj.successor) == 2 + + +class TestTrackChain: + + @pytest.fixture + def track_list(self): + return [{'Stage1': {'test': [{'method': 'set', 'scope': 'general.daytime'}, + {'method': 'set', 'scope': 'general'}, + {'method': 'get', 'scope': 'general'}, + {'method': 'get', 'scope': 'general'},], + 'another': [{'method': 'set', 'scope': 'general'}]}}, + {'Stage2': {'sunlight': [{'method': 'set', 'scope': 'general'}], + 'another': [{'method': 'get', 'scope': 'general.daytime'}, + {'method': 'set', 'scope': 'general'}, + {'method': 'set', 'scope': 'general.daytime'}, + {'method': 'get', 'scope': 'general.daytime.noon'}, + {'method': 'get', 'scope': 'general.nighttime'}, + {'method': 'get', 'scope': 'general.daytime.noon'}]}}, + {'Stage3': {'another': [{'method': 'get', 'scope': 'general.daytime'}], + 'test': [{'method': 'get', 'scope': 'general'}], + 'moonlight': [{'method': 'set', 'scope': 'general.daytime'}]}}] + + @pytest.fixture + def track_chain(self, track_list): + return TrackChain(track_list) + + @pytest.fixture + def track_chain_object(self): + return object.__new__(TrackChain) + + def test_init(self, track_list): + chain = TrackChain(track_list) + assert chain.track_list == track_list + + def test_get_all_scopes(self, track_chain, track_list): + scopes = track_chain.get_all_scopes(track_list) + expected_scopes = {"another": ["general", "general.daytime", "general.daytime.noon", "general.nighttime"], + "moonlight": ["general", "general.daytime"], + "sunlight": ["general"], + "test": ["general", "general.daytime"]} + assert scopes == expected_scopes + + def test_get_unique_scopes(self, track_chain_object): + variable_calls = [{'method': 'get', 'scope': 'general.daytime'}, + {'method': 'set', 'scope': 'general'}, + {'method': 'set', 'scope': 'general.daytime'}, + {'method': 'get', 'scope': 'general.daytime.noon'}, + {'method': 'get', 'scope': 'general.nighttime'}, ] + + unique_scopes = track_chain_object.get_unique_scopes(variable_calls) + assert sorted(unique_scopes) == sorted(["general", "general.daytime", "general.daytime.noon", + "general.nighttime"]) + + def test_get_unique_scopes_no_general(self, track_chain_object): + variable_calls = [{'method': 'get', 'scope': 'general.daytime'}, + {'method': 'get', 'scope': 'general.nighttime'}, ] + unique_scopes = track_chain_object.get_unique_scopes(variable_calls) + assert sorted(unique_scopes) == sorted(["general", "general.daytime", "general.nighttime"]) + + def test_get_all_dims(self, track_chain_object): + scopes = {"another": ["general", "general.daytime", "general.daytime.noon", "general.nighttime"], + "moonlight": ["general", "general.daytime"], + "sunlight": ["general"], + "test": ["general", "general.daytime"]} + dims = track_chain_object.get_all_dims(scopes) + expected_dims = {"another": 4, "moonlight": 2, "sunlight": 1, "test": 2} + assert dims == expected_dims + + def test_create_track_chain(self, track_chain): + train_chain_dict = track_chain.create_track_chain() + assert list(train_chain_dict.keys()) == ["Stage1", "Stage2", "Stage3"] + assert len(train_chain_dict["Stage1"]) == 3 + assert len(train_chain_dict["Stage2"]) == 3 + assert len(train_chain_dict["Stage3"]) == 3 + + def test_control_dict(self, track_chain_object): + scopes = {"another": ["general", "general.daytime", "general.daytime.noon", "general.nighttime"], + "moonlight": ["general", "general.daytime"], + "sunlight": ["general"], + "test": ["general", "general.daytime"]} + control = track_chain_object.control_dict(scopes) + expected_control = {"another": {"general": None, "general.daytime": None, "general.daytime.noon": None, + "general.nighttime": None}, + "moonlight": {"general": None, "general.daytime": None}, + "sunlight": {"general": None}, + "test": {"general": None, "general.daytime": None}} + assert control == expected_control + + def test__create_track_chain(self, track_chain_object): + control = {'another': {'general': None, 'general.sub': None}, + 'first': {'general': None, 'general.sub': None}, + 'skip': {'general': None, 'general.sub': None}} + sorted_track_dict = OrderedDict([("another", [{"method": "set", "scope": "general"}, + {"method": "get", "scope": "general"}, + {"method": "get", "scope": "general.sub"}]), + ("first", [{"method": "set", "scope": "general.sub"}, + {"method": "get", "scope": "general.sub"}]), + ("skip", [{"method": "get", "scope": "general.sub"}]),]) + stage = "Stage1" + track_objects, control = track_chain_object._create_track_chain(control, sorted_track_dict, stage) + assert len(track_objects) == 2 + assert control["another"]["general"] is not None + assert control["first"]["general"] is None + assert control["skip"]["general.sub"] is None + + def test_add_precursor(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage1") + prev_obj = TrackObject(["first", "set", "general"], "Stage1") + assert len(track_chain_object._add_precursor(track_objects, tr_obj, prev_obj)) == 0 + assert tr_obj.precursor[0] == prev_obj + + def test_add_track_object_same_stage(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage1") + prev_obj = TrackObject(["first", "set", "general"], "Stage1") + assert len(track_chain_object._add_track_object(track_objects, tr_obj, prev_obj)) == 0 + + def test_add_track_object_different_stage(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage2") + prev_obj = TrackObject(["first", "set", "general"], "Stage1") + assert len(track_chain_object._add_track_object(track_objects, tr_obj, prev_obj)) == 1 + tr_obj = TrackObject(["first", "get", "general.sub"], "Stage2") + assert len(track_chain_object._add_track_object(track_objects, tr_obj, prev_obj)) == 2 + + def test_update_control(self, track_chain_object): + control = {'another': {'general': None, 'general.sub': None}, + 'first': {'general': None, 'general.sub': None}, } + variable, scope, tr_obj = "first", "general", 23 + track_chain_object._update_control(control, variable, scope, tr_obj) + assert control[variable][scope] == tr_obj + + def test_add_set_object(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "set", "general"], "Stage1") + control_obj = TrackObject(["first", "set", "general"], "Stage1") + assert len(track_chain_object._add_set_object(track_objects, tr_obj, control_obj)) == 0 + assert len(tr_obj.precursor) == 1 + control_obj = TrackObject(["first", "set", "general"], "Stage0") + assert len(track_chain_object._add_set_object(track_objects, tr_obj, control_obj)) == 1 + assert len(tr_obj.precursor) == 2 + + def test_add_set_object_no_control_obj(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "set", "general"], "Stage1") + assert len(track_chain_object._add_set_object(track_objects, tr_obj, None)) == 1 + assert tr_obj.precursor is None + + def test_add_get_object_no_new_track_obj(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage1") + pre = TrackObject(["first", "set", "general"], "Stage1") + control = {"testVar": {"general": pre, "general.sub": None}} + scope, variable = "general", "testVar" + res = track_chain_object._add_get_object(track_objects, tr_obj, pre, control, scope, variable) + assert res == ([], False) + assert pre.successor[0] == tr_obj + + def test_add_get_object_no_control_obj(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage1") + pre = TrackObject(["first", "set", "general"], "Stage1") + control = {"testVar": {"general": pre, "general.sub": None}} + scope, variable = "general.sub", "testVar" + res = track_chain_object._add_get_object(track_objects, tr_obj, None, control, scope, variable) + assert res == ([], False) + assert pre.successor[0] == tr_obj + + def test_add_get_object_skip_update(self, track_chain_object): + track_objects = [] + tr_obj = TrackObject(["first", "get", "general"], "Stage1") + control = {"testVar": {"general": None, "general.sub": None}} + scope, variable = "general.sub", "testVar" + res = track_chain_object._add_get_object(track_objects, tr_obj, None, control, scope, variable) + assert res == ([], True) + + def test_recursive_decent_avail_in_1_up(self, track_chain_object): + scope = "general.sub" + expected_pre = TrackObject(["first", "set", "general"], "Stage1") + control_obj_var = {"general": expected_pre} + pre = track_chain_object._recursive_decent(scope, control_obj_var) + assert pre == expected_pre + + def test_recursive_decent_avail_in_2_up(self, track_chain_object): + scope = "general.sub.sub" + expected_pre = TrackObject(["first", "set", "general"], "Stage1") + control_obj_var = {"general": expected_pre, "general.sub": None} + pre = track_chain_object._recursive_decent(scope, control_obj_var) + assert pre == expected_pre + + def test_recursive_decent_avail_from_chain(self, track_chain_object): + scope = "general.sub.sub" + expected_pre = TrackObject(["first", "set", "general"], "Stage1") + expected_pre.add_successor(TrackObject(["first", "get", "general.sub"], "Stage1")) + control_obj_var = {"general": expected_pre, "general.sub": expected_pre.successor[0]} + pre = track_chain_object._recursive_decent(scope, control_obj_var) + assert pre == expected_pre + + def test_recursive_decent_avail_from_chain_get(self, track_chain_object): + scope = "general.sub.sub" + expected_pre = TrackObject(["first", "get", "general"], "Stage1") + expected_pre.add_precursor(TrackObject(["first", "set", "general"], "Stage1")) + control_obj_var = {"general": expected_pre, "general.sub": None} + pre = track_chain_object._recursive_decent(scope, control_obj_var) + assert pre == expected_pre + + def test_recursive_decent_avail_from_chain_multiple_get(self, track_chain_object): + scope = "general.sub.sub" + expected_pre = TrackObject(["first", "get", "general"], "Stage1") + start_obj = TrackObject(["first", "set", "general"], "Stage1") + mid_obj = TrackObject(["first", "get", "general"], "Stage1") + expected_pre.add_precursor(mid_obj) + mid_obj.add_precursor(start_obj) + control_obj_var = {"general": expected_pre, "general.sub": None} + pre = track_chain_object._recursive_decent(scope, control_obj_var) + assert pre == expected_pre + + def test_clean_control(self, track_chain_object): + tr1 = TrackObject(["first", "get", "general"], "Stage1") + tr2 = TrackObject(["first", "set", "general"], "Stage1") + tr2.add_precursor(tr1) + tr3 = TrackObject(["first", "get", "general/sub"], "Stage1") + tr3.add_precursor(tr1) + control = {'another': {'general': None, 'general.sub': None}, + 'first': {'general': tr2, 'general.sub': tr3}, } + control = track_chain_object.clean_control(control) + expected_control = {'another': {'general': None, 'general.sub': None}, + 'first': {'general': tr2, 'general.sub': None}, } + assert control == expected_control + + +class TestTrackPlot: + + @pytest.fixture + def track_plot_obj(self): + return object.__new__(TrackPlot) + + @pytest.fixture + def track_list(self): + return [{'Stage1': {'test': [{'method': 'set', 'scope': 'general.daytime'}, + {'method': 'set', 'scope': 'general'}, + {'method': 'get', 'scope': 'general'}, + {'method': 'get', 'scope': 'general'},], + 'another': [{'method': 'set', 'scope': 'general'}]}}, + {'Stage2': {'sunlight': [{'method': 'set', 'scope': 'general'}], + 'another': [{'method': 'get', 'scope': 'general.daytime'}, + {'method': 'set', 'scope': 'general'}, + {'method': 'set', 'scope': 'general.daytime'}, + {'method': 'get', 'scope': 'general.daytime.noon'}, + {'method': 'get', 'scope': 'general.nighttime'}, + {'method': 'get', 'scope': 'general.daytime.noon'}]}}, + {'RunEnvironment': {'another': [{'method': 'get', 'scope': 'general.daytime'}], + 'test': [{'method': 'get', 'scope': 'general'}], + 'moonlight': [{'method': 'set', 'scope': 'general.daytime'}]}}] + + @pytest.fixture + def scopes(self): + return {"another": ["general", "general.daytime", "general.daytime.noon", "general.nighttime"], + "moonlight": ["general", "general.daytime"], + "sunlight": ["general"], + "test": ["general", "general.daytime"]} + + @pytest.fixture + def dims(self): + return {"another": 4, "moonlight": 2, "sunlight": 1, "test": 2} + + @pytest.fixture + def track_chain_dict(self, track_list): + return TrackChain(track_list).create_track_chain() + + @pytest.fixture + def path(self): + p = os.path.join(os.path.dirname(__file__), "TestExperiment") + if not os.path.exists(p): + os.makedirs(p) + yield p + shutil.rmtree(p, ignore_errors=True) + + def test_init(self, path, track_list): + assert "tracking.pdf" not in os.listdir(path) + TrackPlot(track_list, plot_folder=path) + assert "tracking.pdf" in os.listdir(path) + + def test_plot(self): + pass + + def test_line(self, track_plot_obj): + h, w = 0.6, 0.65 + track_plot_obj.height = h + track_plot_obj.width = w + track_plot_obj.fig, track_plot_obj.ax = plt.subplots() + assert len(track_plot_obj.ax.lines) == 0 + track_plot_obj.line(start_x=5, end_x=6, y=2) + assert len(track_plot_obj.ax.lines) == 2 + pos_x, pos_y = np.array([5 + w, 6]), np.ones((2, )) * (2 + h / 2) + assert track_plot_obj.ax.lines[0]._color == "white" + assert track_plot_obj.ax.lines[0]._linewidth == 2.5 + assert track_plot_obj.ax.lines[1]._color == "darkgrey" + assert track_plot_obj.ax.lines[1]._linewidth == 1.4 + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]).is_true() + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]).is_true() + + def test_step(self, track_plot_obj): + x_int, h, w = 0.5, 0.6, 0.65 + track_plot_obj.space_intern_x = x_int + track_plot_obj.height = h + track_plot_obj.width = w + track_plot_obj.fig, track_plot_obj.ax = plt.subplots() + assert len(track_plot_obj.ax.lines) == 0 + track_plot_obj.step(start_x=5, end_x=6, start_y=2, end_y=3) + assert len(track_plot_obj.ax.lines) == 2 + pos_x = np.array([5 + w, 6 - x_int / 2, 6 - x_int / 2, 6]) + pos_y = np.array([2 + h / 2, 2 + h / 2, 3 + h / 2, 3 + h / 2]) + assert track_plot_obj.ax.lines[0]._color == "white" + assert track_plot_obj.ax.lines[0]._linewidth == 2.5 + assert track_plot_obj.ax.lines[1]._color == "black" + assert track_plot_obj.ax.lines[1]._linewidth == 1.4 + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._x, track_plot_obj.ax.lines[1]._x, pos_x]).is_true() + assert PyTestAllEqual([track_plot_obj.ax.lines[0]._y, track_plot_obj.ax.lines[1]._y, pos_y]).is_true() + + def test_rect(self, track_plot_obj): + h, w = 0.5, 0.6 + track_plot_obj.height = h + track_plot_obj.width = w + track_plot_obj.fig, track_plot_obj.ax = plt.subplots() + assert len(track_plot_obj.ax.artists) == 0 + assert len(track_plot_obj.ax.texts) == 0 + track_plot_obj.rect(x=4, y=2) + assert len(track_plot_obj.ax.artists) == 1 + assert len(track_plot_obj.ax.texts) == 1 + track_plot_obj.ax.artists[0].xy == (4, 2) + track_plot_obj.ax.artists[0]._height == h + track_plot_obj.ax.artists[0]._width == w + track_plot_obj.ax.artists[0]._original_facecolor == "orange" + track_plot_obj.ax.texts[0].xy == (4 + w / 2, 2 + h / 2) + track_plot_obj.ax.texts[0]._color == "w" + track_plot_obj.ax.texts[0]._text == "get" + track_plot_obj.rect(x=4, y=2, method="set") + assert len(track_plot_obj.ax.artists) == 2 + assert len(track_plot_obj.ax.texts) == 2 + track_plot_obj.ax.artists[0]._original_facecolor == "lightblue" + track_plot_obj.ax.texts[0]._text == "set" + + + + def test_set_ypos_anchor(self, track_plot_obj, scopes, dims): + assert not hasattr(track_plot_obj, "y_pos") + assert not hasattr(track_plot_obj, "anchor") + y_int, y_ext, h = 0.5, 0.7, 0.6 + track_plot_obj.space_intern_y = y_int + track_plot_obj.height = h + track_plot_obj.space_extern_y = y_ext + track_plot_obj.set_ypos_anchor(scopes, dims) + d_y = 0 - sum([factor * (y_int + h) + y_ext - y_int for factor in dims.values()]) + expected_anchor = (d_y + sum(dims.values()), h + y_ext + sum(dims.values())) + assert np.testing.assert_array_almost_equal(track_plot_obj.anchor, expected_anchor) is None + assert track_plot_obj.y_pos["another"]["general"] == sum(dims.values()) + assert track_plot_obj.y_pos["another"]["general.daytime"] == sum(dims.values()) - (h + y_int) + assert track_plot_obj.y_pos["another"]["general.daytime.noon"] == sum(dims.values()) - 2 * (h + y_int) + + def test_plot_track_chain(self): + pass + + def test_add_variable_names(self): + pass + + def test_add_stages(self): + pass + + def test_create_track_chain_plot_run_env(self): + pass + + def test_set_lims(self, track_plot_obj): + track_plot_obj.x_max = 10 + track_plot_obj.space_intern_x = 0.5 + track_plot_obj.width = 0.4 + track_plot_obj.anchor = np.array((0.1, 12.5)) + track_plot_obj.fig, track_plot_obj.ax = plt.subplots() + assert track_plot_obj.ax.get_ylim() == (0, 1) # matplotlib default + assert track_plot_obj.ax.get_xlim() == (0, 1) # matplotlib default + track_plot_obj.set_lims() + assert track_plot_obj.ax.get_ylim() == (0.1, 12.5) + assert track_plot_obj.ax.get_xlim() == (0, 10+0.5+0.4) \ No newline at end of file diff --git a/test/test_plotting/test_training_monitoring.py b/test/test_plotting/test_training_monitoring.py index 7e4e21c1a28b35bef4aa6e613756378fe41611b5..6e5e0abbc5da0978e200f19019700c4dedd14ad0 100644 --- a/test/test_plotting/test_training_monitoring.py +++ b/test/test_plotting/test_training_monitoring.py @@ -94,7 +94,6 @@ class TestPlotModelHistory: assert "hist_additional.pdf" in os.listdir(path) - class TestPlotModelLearningRate: @pytest.fixture diff --git a/test/test_statistics.py b/test/test_statistics.py index cad915564aac675cadda0f625dca1a073b2c8959..3da7a47871f6d92472de268d165d788c343ce394 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -3,7 +3,7 @@ import pandas as pd import pytest import xarray as xr -from src.statistics import standardise, standardise_inverse, standardise_apply, centre, centre_inverse, centre_apply,\ +from src.helpers.statistics import standardise, standardise_inverse, standardise_apply, centre, centre_inverse, centre_apply, \ apply_inverse_transformation lazy = pytest.lazy_fixture