diff --git a/video_prediction_tools/data_extraction/data_info.py b/video_prediction_tools/data_extraction/data_info.py
new file mode 100644
index 0000000000000000000000000000000000000000..b62018861abed4668f725c8517ba2745807ed7ac
--- /dev/null
+++ b/video_prediction_tools/data_extraction/data_info.py
@@ -0,0 +1,129 @@
+from typing import List, Tuple, Union, Literal
+import json
+import logging
+
+from pydantic import BaseModel, validator, root_validator, conint, PositiveInt, conlist
+
+from utils.dataset_utils import DATASETS, get_dataset_info, get_vars
+
+
+logging.basicConfig(level=logging.DEBUG)
+
+
+class VariableInfo(BaseModel):
+    dataset: str
+    name: str
+    lvl: List[int]
+    interpolation: Literal["z", "p"]
+
+    @validator("name")
+    def check_variable_name(cls, name, values):
+        logging.debug(values)
+        if name not in get_vars(values["dataset"]):
+            raise ValueError(f"no variable '{name}' available for dataset {values['dataset']}")
+        return name
+
+    @root_validator(skip_on_failure=True)
+    def check_lvl_availability(cls, values):
+        variables = get_dataset_info(values["dataset"])["variables"]
+        variables = list(filter(
+            lambda v: v["name"] == values["name"] and set(values["lvl"]).issubset(v["lvl"]),
+            variables
+        ))
+        if len(variables) == 0:
+            raise ValueError(
+                f"variable {values['name']} at lvl {values['lvl']} is not available for dataset {values['dataset']}."
+            )
+        return values
+
+    def __str__(self):
+        return "_".join(f"{self.name}-{l}{self.interpolation}" for l in self.lvl)
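+
+# A minimal usage sketch (dataset and variable names are hypothetical and must
+# exist in the registry behind utils.dataset_utils for validation to pass):
+#
+#   VariableInfo(dataset="weatherbench", name="temperature", lvl=[850], interpolation="p")
+#
+# str() of the instance then yields "temperature-850p", which DomainInfo.__str__
+# below uses as a filename fragment.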
+
+
+class DomainInfo(BaseModel):
+    dataset: str
+    # conint(ge=-1, le=-1) admits exactly -1, a sentinel for "all years"
+    years: Union[conint(ge=-1, le=-1), List[conint(gt=1970)]]
+    months: List[conint(ge=1, le=12)]
+    variables: List[VariableInfo]
+    coords_sw: conlist(float, min_items=2, max_items=2)
+    nyx: conlist(PositiveInt, min_items=2, max_items=2)
+    resolution: float
+
+    @validator("dataset")
+    def check_dataset_name(cls, name):
+        if name not in DATASETS:
+            raise ValueError(f"'dataset' must be one of {DATASETS}.")
+        return name
+
+    @validator("years", pre=True)
+    def all_years(cls, years, values):
+        if years in ("all", ["all"]):
+            return get_dataset_info(values["dataset"])["years"]
+        return years
+
+    @validator("months", pre=True)
+    def months_strings(cls, months):
+        # argparse (nargs="+") wraps values in a list; unwrap a lone string so
+        # season names and "all" are recognized
+        if isinstance(months, list) and len(months) == 1:
+            months = months[0]
+        if months == "DJF":
+            month_list = [1, 2, 12]
+        elif months == "MAM":
+            month_list = [3, 4, 5]
+        elif months == "JJA":
+            month_list = [6, 7, 8]
+        elif months == "SON":
+            month_list = [9, 10, 11]
+        elif months == "all":
+            month_list = list(range(1, 13))
+        else:
+            # plain month numbers: hand them to the per-item int validation
+            month_list = months if isinstance(months, list) else [months]
+        return month_list
+
+    @validator("variables", pre=True)
+    def str_to_list(cls, variables):
+        if isinstance(variables, str):
+            return json.loads(variables)
+        return variables
+
+    @validator("variables", pre=True, each_item=True)
+    def vars(cls, variable, values):
+        logging.debug(values)
+        variable["dataset"] = values["dataset"]
+        logging.debug(variable)
+        return variable
+
+    @validator("coords_sw")
+    def coords(cls, coords):
+        if not -90 <= coords[0] <= 90:
+            raise ValueError(
+                f"latitude is {coords[0]} but should be >= -90, <= 90"
+            )
+        if not 0 <= coords[1] <= 360:
+            raise ValueError(
+                f"longitude is {coords[1]} but should be >= 0, <= 360"
+            )
+        return coords
+
+    @root_validator(skip_on_failure=True)
+    def valid_ne(cls, values):
+        ne = DomainInfo._nxy_to_coords(values["coords_sw"], values["nyx"], values["resolution"])
+        if not (-90 <= ne[0] <= 90 and 0 <= ne[1] <= 360):
+            raise ValueError(
+                f"grid of {values['nyx']} points yields an invalid north-east corner: {ne}"
+            )
+        return values
+
+    def __str__(self) -> str:
+        years_s = sorted(self.years)
+        months_s = sorted(self.months)
+        desc = (
+            f"{self.dataset}_Y{years_s[0]}-{years_s[-1]}_M{months_s[0]}-{months_s[-1]}_{self.nyx[0]}x{self.nyx[1]}_"
+            + "_".join(str(var) for var in self.variables)
+        )
+
+        return desc
+
+    @staticmethod
+    def _nxy_to_coords(coords_sw, nyx, resolution):
+        # north-east corner: sw-corner shifted by the grid extent
+        return tuple(
+            coord + n * resolution for coord, n in zip(coords_sw, nyx)
+        )
+
+    @property
+    def coords_ne(self) -> Tuple[float, float]:
+        return DomainInfo._nxy_to_coords(self.coords_sw, self.nyx, self.resolution)
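+
+# Minimal sketch of the intended construction (all values are illustrative;
+# "weatherbench" and "temperature" are assumptions about the dataset registry):
+#
+#   domain = DomainInfo(
+#       dataset="weatherbench",
+#       years=[2010, 2011],
+#       months="DJF",               # expanded to [1, 2, 12] by months_strings
+#       variables='[{"name": "temperature", "lvl": [850], "interpolation": "p"}]',
+#       coords_sw=(30.0, 6.0),
+#       nyx=(10, 20),
+#       resolution=5.625,
+#   )
+#   domain.coords_ne  # -> (86.25, 118.5), derived from sw-corner, nyx and resolution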
diff --git a/video_prediction_tools/env_setup/install_venv.sh b/video_prediction_tools/env_setup/install_venv.sh
index 3ff7e7b83046ab88a7a7624c1f15bdac324b1492..55e99bb513b82d4078eb6c309ab57d4400de3a60 100755
--- a/video_prediction_tools/env_setup/install_venv.sh
+++ b/video_prediction_tools/env_setup/install_venv.sh
@@ -71,6 +71,7 @@ export PYTHONPATH=""
 export PYTHONPATH=${WORKING_DIR}/virtual_envs/${VENV_NAME}/lib/python${PYTHON_VERSION}/site-packages:$PYTHONPATH
 export PYTHONPATH=${WORKING_DIR}:$PYTHONPATH
 export PYTHONPATH=${WORKING_DIR}/utils:$PYTHONPATH
+export PYTHONPATH=${WORKING_DIR}/data_extraction:$PYTHONPATH
 export PYTHONPATH=${WORKING_DIR}/model_modules:$PYTHONPATH
 export PYTHONPATH=${WORKING_DIR}/postprocess:$PYTHONPATH
 # ... also ensure that PYTHONPATH is appended when activating the virtual environment...
@@ -78,6 +79,7 @@ echo 'export PYTHONPATH='"" >> ${ACT_VENV}
 echo 'export PYTHONPATH='${WORKING_DIR}'/virtual_envs/'${VENV_NAME}'/lib/python'${PYTHON_VERSION}'/site-packages:$PYTHONPATH' >> ${ACT_VENV}
 echo 'export PYTHONPATH='${WORKING_DIR}':$PYTHONPATH' >> ${ACT_VENV}
 echo 'export PYTHONPATH='${WORKING_DIR}'/utils:$PYTHONPATH' >> ${ACT_VENV}
+echo 'export PYTHONPATH='${WORKING_DIR}'/data_extraction:$PYTHONPATH' >> ${ACT_VENV}
 echo 'export PYTHONPATH='${WORKING_DIR}'/model_modules:$PYTHONPATH' >> ${ACT_VENV}
 echo 'export PYTHONPATH='${WORKING_DIR}'/postprocess:$PYTHONPATH' >> ${ACT_VENV}
 # ... install requirements
diff --git a/video_prediction_tools/env_setup/install_venv_container.sh b/video_prediction_tools/env_setup/install_venv_container.sh
index 45065c48a9f88b7c96a965255ca0165f79800129..08a917dcc2bec8d0aa4d4b01d720c6c671f4f9a2 100755
--- a/video_prediction_tools/env_setup/install_venv_container.sh
+++ b/video_prediction_tools/env_setup/install_venv_container.sh
@@ -75,6 +75,7 @@ export PYTHONPATH=/usr/local/lib/python${PYTHON_VERSION}/dist-packages/:$PYTHONP
 export PYTHONPATH=${WORKING_DIR}/virtual_envs/${VENV_NAME}/lib/python${PYTHON_VERSION}/site-packages:$PYTHONPATH
 export PYTHONPATH=${WORKING_DIR}:$PYTHONPATH
 export PYTHONPATH=${WORKING_DIR}/utils:$PYTHONPATH
+export PYTHONPATH=${WORKING_DIR}/data_extraction:$PYTHONPATH
 export PYTHONPATH=${WORKING_DIR}/model_modules:$PYTHONPATH
 export PYTHONPATH=${WORKING_DIR}/postprocess:$PYTHONPATH
 # ... also ensure that PYTHONPATH is appended when activating the virtual environment...
@@ -82,6 +83,7 @@ echo 'export PYTHONPATH=/usr/local/lib/python3.8/dist-packages/:$PYTHONPATH' >>
 echo 'export PYTHONPATH='${WORKING_DIR}'/virtual_envs/'${VENV_NAME}'/lib/python3.8/site-packages:$PYTHONPATH' >> ${ACT_VENV}
 echo 'export PYTHONPATH='${WORKING_DIR}':$PYTHONPATH' >> ${ACT_VENV}
 echo 'export PYTHONPATH='${WORKING_DIR}'/utils:$PYTHONPATH' >> ${ACT_VENV}
+echo 'export PYTHONPATH='${WORKING_DIR}'/data_extraction:$PYTHONPATH' >> ${ACT_VENV}
 echo 'export PYTHONPATH='${WORKING_DIR}'/model_modules:$PYTHONPATH' >> ${ACT_VENV}
 echo 'export PYTHONPATH='${WORKING_DIR}'/postprocess:$PYTHONPATH' >> ${ACT_VENV}
 # ... install requirements
diff --git a/video_prediction_tools/env_setup/requirements.txt b/video_prediction_tools/env_setup/requirements.txt
index dd5a43273f5077834916343019bf80e3f476e43a..4ca1d07528647e6d8af0af30317b2a8e71a73e69 100755
--- a/video_prediction_tools/env_setup/requirements.txt
+++ b/video_prediction_tools/env_setup/requirements.txt
@@ -12,3 +12,4 @@ normalization==0.4
 utils==1.0.1
 pytest==7.1.1
 dask==2021.7.2
+pydantic<2  # data_info.py relies on the pydantic v1 validator API
diff --git a/video_prediction_tools/env_setup/requirements_nocontainer.txt b/video_prediction_tools/env_setup/requirements_nocontainer.txt
index dc8475e048298372f4ddcfe137a39a1fb16766b9..1ce67dd94221fe014d409b47865401ae3de763e2 100755
--- a/video_prediction_tools/env_setup/requirements_nocontainer.txt
+++ b/video_prediction_tools/env_setup/requirements_nocontainer.txt
@@ -11,4 +11,5 @@ netcdf4==1.5.8
 #metadata==0.2
 normalization==0.4
 utils==1.0.1
+pydantic<2  # data_info.py relies on the pydantic v1 validator API
diff --git a/video_prediction_tools/main_scripts/main_data_extraction.py b/video_prediction_tools/main_scripts/main_data_extraction.py
index 7f38a1a88b41053107d758e38893f9860f0ebe16..a09af10bca7a70b081ee402f8676012dc9f4e447 100644
--- a/video_prediction_tools/main_scripts/main_data_extraction.py
+++ b/video_prediction_tools/main_scripts/main_data_extraction.py
@@ -1,171 +1,44 @@
-import json as js
-import os
 import argparse
 import itertools as it
 from pathlib import Path
-from typing import Union, get_args
+from typing import Union, List, Tuple
 import zipfile as zf
 import multiprocessing as mp
 import sys
 import json
+import logging
 
-from data_preprocess.extract_weatherbench import ExtractWeatherbench
-from utils.dataset_utils import DATASETS, get_dataset_info
-
+from pydantic import ValidationError
 
-# IDEA: type conversion (generic) => params_obj => bounds_checking (ds-specific)/ semantic checking
+from data_extraction.weatherbench import ExtractWeatherbench
+from data_extraction.era5 import ExtractERA5
+from data_extraction.data_info import VariableInfo, DomainInfo
+from utils.dataset_utils import DATASETS, get_dataset_info
 
-def dataset(name):
-    if name not in DATASETS:
-        raise ValueError(f"'dataset' must be one of {DATASETS}.")
-    return name
+logging.basicConfig(level=logging.DEBUG)
 
 
 def source_dir(directory: str) -> Path:
-    dir = Path(directory)
-    if not dir.exists():
-        raise ValueError(f"Input directory {dir.absolute()} does not exist")
-    return dir
+    source = Path(directory)
+    if not source.exists():
+        raise ValueError(f"Input directory {source.absolute()} does not exist")
+    return source
 
 
 def destination_dir(directory: str) -> Path:
-    dir = Path(directory)
-    if not dir.exists():
-        raise ValueError(f"Output directory: {dir.absolute()} does not exist.")
-    return dir
-
-
-def years(years: str) -> Union[list[int], int]:
-    try:
-        year_list = [int(x) for x in years]
-    except ValueError as e:
-        if not years == "all":
-            raise ValueError(
-                f"years must be either a list of years or 'all', not {months}."
-            )
-        year_list = -1
-
-    return year_list
-
-
-def months(months: str) -> list[int]:
-    try:
-        month_list = [int(x) for x in months]
-    except ValueError as e:
-        if months == "DJF":
-            month_list = [1, 2, 12]
-        elif months == "MAM":
-            month_list = [3, 4, 5]
-        elif months == "JJA":
-            month_list = [6, 7, 8]
-        elif months == "SON":
-            month_list = [9, 10, 11]
-        elif months == "all":
-            month_list = list(range(1, 13))
-        else:
-            raise ValueError(
-                f"months-string '{months}' cannot be converted to list of months"
-            )
-
-    if not all(1 <= m <= 12 for m in month_list):
-        errors = filter(lambda m: not 1 <= m <= 12, month_list)
-        raise ValueError(
-            f"all month integers must be within 1, ..., 12 not {list(errors)}"
-        )
-
-    return month_list
-
-
-def variables(variables: str) -> list[dict]:
-    var_list = json.loads(variables)
-
-    attributes = {"name", "lvl", "interpolation"}
-    interpolations = {"p", "z"}
-
-    for var in var_list:
-        if not var.keys() == attributes:
-            raise ValueError(f"each variable should have the attributes {attributes}")
-        if not type(var["name"]) == str:
-            raise ValueError(f"'name' should be of type string not {type(var['name'])}")
-        if not type(var["lvl"]) == list:
-            raise ValueError(f"'lvl' should be of type list not {type(var['lvl'])}")
-        if not var["interpolation"] in interpolations:
-            raise ValueError(f"value of 'interpolation' should be one of {interpolations} not {var['interpolation']}")
-        if len(var["lvl"]) == 0:
-            raise ValueError(f"'lvl' should have at least one entry")
-        if not all(type(lvl) == int for lvl in var["lvl"]):
-            raise ValueError(f"all entries of 'lvl' should be of type int")
-
-    return var_list
-
-
-def get_data_files(variables: list, years, resolution, dirin: Path):
-    """
-    Get path to zip files and names of the yearly files within.
-
-    :param variables: list of variables
-    :param years: list of years
-    :param months: list of months
-    :param resolution:
-    :param dirin: input directory
-    :return lists paths to zips of variables
-    """
-    data_files = []
-    zip_files = []
-    res_str = f"{resolution}deg"
-    for var in variables:
-        var_dir = dirin / res_str / var
-        if not var_dir.exists():
-            raise ValueError(
-                f"variable {var} is not available for resolution {res_str}"
-            )
-
-        zip_file = var_dir / f"{var}_{res_str}.zip"
-        with zf.ZipFile(zip_file, "r") as myzip:
-            names = myzip.namelist()
-            if not all(any(str(year) in name for name in names) for year in years):
-                raise ValueError(
-                    f"variable {var} is not available for all years: {years}"
-                )
-            names = filter(lambda name: any(str(year) in name for year in years), names)
-
-        data_files.append(list(names))
-        zip_files.append(zip_file)
-
-    return zip_files, data_files
-
-
-def nyx(nyx):
-    try:
-        nyx = [int(n) for n in nyx]
-    except ValueError as e:
-        raise ValueError(f"number of grid points should be integers not {nyx}")
-    if not all(n > 0 for n in nyx):
-        raise ValueError(f"number of grid points should be > 0")
-
-    return nyx
-
-
-def coords(coords_sw):
-    try:
-        coords = [float(c) for c in coords_sw]
-    except ValueError as e:
-        raise ValueError(f"coordinates should be floats not {coords}")
-    if not -90 <= coords[0] <= 90:
-        raise ValueError(
-            f"latitude of sw-corner is {coords[0]} but should be >= -90, <= 90"
-        )
-    if not 0 <= coords[1] <= 360:
-        raise ValueError(
-            f"latitude of sw-corner is {coords[0]} but should be >= 0, <= 360"
-        )
-    return coords
-
+    destination = Path(directory)
+    if not destination.exists():
+        destination.mkdir(parents=True)  # TODO: monitor if good
+    return destination
 
 
 def main():
     # TODO consult Bing for defaults
+    for arg in sys.argv:
+        logging.debug(arg)
+
     parser = argparse.ArgumentParser()
     parser.add_argument(
         "dataset",
-        type=dataset,
        help="Name of the dataset"
     )
     parser.add_argument(
@@ -179,17 +52,16 @@ def main():
         help="Destination directory where the netCDF-files will be stored",
     )
     parser.add_argument(
-        "years", nargs="+", type=int, help="Years of data to be processed."
+        "years", nargs="+", help="Years of data to be processed."
     )
     parser.add_argument(
         "variables",
         help="list of variables to extract",
-        type=variables,
     )
     parser.add_argument(
         "--resolution",
         "-r",
-        choices=[1.40625, 2.8125, 5.625],
+        type=float,
         default=5.625,
     )
     parser.add_argument(
@@ -198,16 +70,14 @@ def main():
         "--months", "-m",
         nargs="+",
         dest="months",
         default="all",
-        type=months,
         help="Months of data. Can also be 'all' or season-strings, e.g. 'DJF'.",
     )
     parser.add_argument(
         "--sw_corner",
         "-swc",
-        dest="sw_corner",
+        dest="coords_sw",
         nargs="+",
-        type=coords,
         default=(0.0, 0.0),
         help="Defines south-west corner of target domain (lat, lon)=(-90..90, 0..360)",
     )
@@ -216,62 +86,40 @@
         "-nyx",
         dest="nyx",
         nargs="+",
-        type=nyx,
         default=(10, 20),
         help="Number of grid points in zonal and meridional direction.",
     )
-
-    args = parser.parse_args()
-
-    # check if north-east corner is valid
-    ne_corner = [
-        coord + n * args.resolution for coord, n in zip(args.sw_corner, args.nyx)
-    ]
-    if not (-90 <= ne_corner[0] <= 90 and 0 <= ne_corner[1] <= 360):
-        raise ValueError(
-            f"number of grid points {args.nyx} will result in a invalid north-east corner: {ne_corner}"
-        )
-
-    # check if arguments can be provided by dataset
-    dataset_info = get_dataset_info(args.dataset)
-    years_not_avail = [year not in dataset_info["years"] for year in args.years]
-    vars_avail_map = {var["name"]: var for var in dataset_info["variables"]}
-
-    for variable in args.variables:
-        try:
-            var = vars_avail_map[variable["name"]]
-        except KeyError as e:
-            raise ValueError(f"variable {variable['name']} is not available for dataset {args.dataset}.")
-
-        lvl_not_avail = list(filter(lambda l: not l in var["lvl"], variable["lvl"]))
-        if len(lvl_not_avail) > 0:
-            raise ValueError(f"variable {variable['name']} at lvl {lvl_not_avail} is not available for dataset {args.dataset}.")
-
-
+    args = parser.parse_args()
+
+    # pydantic input validation / parsing
+    logging.debug(args)
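+    # NOTE: vars(args) also carries source_dir and destination_dir, which are
+    # not DomainInfo fields; pydantic v1 ignores unknown keys by default
+    # (Extra.ignore), so the whole argparse namespace can be passed through.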
+    try:
+        logging.debug(vars(args))
+        domain = DomainInfo(**vars(args))
+        logging.debug(domain)
+    except ValidationError as e:
+        logging.exception(e)
+        sys.exit(1)
+
     # get extraction instance
-    if args.dataset == "weatherbench":
+    if domain.dataset == "weatherbench":
         extraction = ExtractWeatherbench(
             args.source_dir,
             args.destination_dir,
-            args.variables,
-            args.years,
-            args.months,
-            (args.sw_corner[0], ne_corner[0]),
-            (args.sw_corner[1], ne_corner[1]),
-            args.resolution,
+            domain
         )
-    elif args.dataset == "era5":
-        extraction = NewEra5Extraction()
+    elif domain.dataset == "era5":
+        extraction = ExtractERA5()
     else:
         raise ValueError("no other extractor.")
 
-    print("initialized extraction")
+    logging.info("initialized extraction")
     extraction()
 
 
 if __name__ == "__main__":
-    print("start script")
+    logging.info("start script")
     mp.set_start_method("spawn")  # fix cuda initalization issue
     main()
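+
+# Example invocation (paths and the variable spec are illustrative only):
+#
+#   python main_data_extraction.py weatherbench /data/weatherbench /data/out \
+#       2010 2011 '[{"name": "temperature", "lvl": [850], "interpolation": "p"}]' \
+#       -r 5.625 -m DJF -swc 30.0 6.0 -nyx 10 20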