diff --git a/video_prediction_tools/data_extraction/data_info.py b/video_prediction_tools/data_extraction/data_info.py
index b62018861abed4668f725c8517ba2745807ed7ac..1d2c0876d391e92773337b8eaee407907c8f1ca8 100644
--- a/video_prediction_tools/data_extraction/data_info.py
+++ b/video_prediction_tools/data_extraction/data_info.py
@@ -1,10 +1,10 @@
-from typing import Dict, Any, List, Tuple, Union, Literal
+from typing import List, Tuple, Union, Literal
 import json
 import logging
 
 from pydantic import BaseModel, validator, root_validator, conint, PositiveInt, conlist
 
-from utils.dataset_utils import DATASETS, get_dataset_info, get_vars
+from utils.dataset_utils import DATASETS, get_dataset_info, get_vars, INTERPOLATION_UNITS
 
 logging.basicConfig(level=logging.DEBUG)
 
@@ -13,7 +13,7 @@ class VariableInfo(BaseModel):
     dataset: str
     name: str
     lvl: List[int]
-    interpolation: Literal["z","p"]
+    interpolation: Union[Literal[*INTERPOLATION_UNITS], None]  # TODO align units with units defined in InterpInfo
 
     @validator("name")
     def check_variable_name(cls, name, values):
@@ -22,18 +22,41 @@ class VariableInfo(BaseModel):
             raise ValueError(f"no variable '{name}' available for dataset {values['dataset']}")
         return name
 
+    @validator("interpolation")
+    def check_interpolation_availability(cls, interpolation, values):
+        if interpolation is None:  # surface variables need no interpolation
+            return interpolation
+        info = get_dataset_info(values["dataset"])
+        if interpolation not in info.levels.interpolation_units:
+            raise ValueError(f"no information on how to interpolate dataset {values['dataset']} in unit {interpolation}")
+        return interpolation
+
 
     @root_validator(skip_on_failure=True)
     def check_lvl_availability(cls, values):
-        variables = get_dataset_info(values["dataset"])["variables"]
-        variables = list(filter(
-            lambda v: v["name"] == values["name"] and set(values["lvl"]).issubset(v["lvl"]),
-            variables
-        ))
-        if not len(variables) > 0:
-            raise ValueError(f"variable {variables[0]['name']} at lvl {values['lvl']} is not available for dataset {values['dataset']}.")
+        dataset, name, lvls, interpolation = (values[k] for k in ("dataset", "name", "lvl", "interpolation"))
+        info = get_dataset_info(dataset)
+        variable = info.get_var(name)
+
+        if variable.level_type == "sfc":  # convention: surface variables carry no lvls
+            if len(lvls) > 0:
+                raise ValueError(f"surface variable {name} does not take lvls, got: {list(lvls)}")
+        elif variable.level_type == "ml":
+            # lvls that are not native model lvls have to be interpolated
+            diff = set(lvls).difference(info.levels.ml)
+            if diff:
+                if interpolation is None:
+                    raise ValueError(f"lvls {sorted(diff)} of {name} require interpolation, but no unit is given")
+                interp_info = info.levels.get_interp(interpolation)
+                out_of_range = [lvl for lvl in diff if not interp_info.start <= lvl <= interp_info.end]
+                if out_of_range:
+                    raise ValueError(f"Cannot interpolate {name} for lvls {out_of_range}: out of range")
+        elif variable.level_type == "pl":
+            diff = set(lvls).difference(info.levels.pl)
+            if diff:
+                raise ValueError(f"variable {name} not available for lvls: {sorted(diff)}")
+
         return values
-
 
     def __str__(self):
         return "_".join(f"{self.name}-{l}{self.interpolation}" for l in self.lvl)
@@ -56,7 +79,7 @@ class DomainInfo(BaseModel):
     @validator("years", pre=True)
     def all_years(cls, years, values):
         if years == "all":
-            return get_dataset_info(values["name"])["years"]
+            return get_dataset_info(values["name"]).years
         else:
             return years
 
@@ -126,4 +149,24 @@
 
     @property
     def coords_ne(self) -> Tuple[float]:
-        return DomainInfo._nxy_to_coords(self.coords_sw, self.nyx, self.resolution)
\ No newline at end of file
+        return DomainInfo._nxy_to_coords(self.coords_sw, self.nyx, self.resolution)
+
+    @property
+    def lat_range(self) -> Tuple[float, float]:
+        return (self.coords_sw[0], self.coords_ne[0])
+
+    @property
+    def lon_range(self) -> Tuple[float, float]:
+        return (self.coords_sw[1], self.coords_ne[1])
+
+    @property
+    def years_count(self) -> int:
+        return len(self.years)
+
+    @property
+    def months_count(self) -> int:
+        return self.years_count * len(self.months)
+
+    @property
+    def variable_names(self) -> List[str]:
+        return [var.name for var in self.variables]
\ No newline at end of file
diff --git a/video_prediction_tools/main_scripts/main_data_extraction.py b/video_prediction_tools/main_scripts/main_data_extraction.py
index a09af10bca7a70b081ee402f8676012dc9f4e447..b2618eaa4f764b551d6a3af9d2f08e89aba1d95d 100644
--- a/video_prediction_tools/main_scripts/main_data_extraction.py
+++ b/video_prediction_tools/main_scripts/main_data_extraction.py
@@ -14,7 +14,6 @@ from pydantic import ValidationError
 from data_extraction.weatherbench import ExtractWeatherbench
 from data_extraction.era5 import ExtractERA5
 from data_extraction.data_info import VariableInfo, DomainInfo
-from utils.dataset_utils import DATASETS, get_dataset_info
 
 logging.basicConfig(level=logging.DEBUG)
 
diff --git a/video_prediction_tools/utils/dataset_utils.py b/video_prediction_tools/utils/dataset_utils.py
index 89a838e3c9963606bb33efb6d0d4f3057f21e322..79dac353e756245dbe783cd2960d0a160dd6f46b 100644
--- a/video_prediction_tools/utils/dataset_utils.py
+++ b/video_prediction_tools/utils/dataset_utils.py
@@ -7,6 +7,7 @@ functions providing info about available options
 Provides: * DATASET_META_LOCATION
     * DATASETS
     * get_dataset_info
+    * DatasetInfo
 """
 
 # import sys
@@ -14,46 +15,115 @@ Provides: * DATASET_META_LOCATION
 import json
 from pathlib import Path
-from typing import Dict, Any, List, Tuple
-from dataclasses import dataclass
+from functools import cache
+from typing import List, Literal
+
+from pydantic import BaseModel, PositiveInt, conint, PositiveFloat, root_validator, ValidationError
 
 DATASET_META_LOCATION = Path(__file__).parent.parent / "config" / "datasets" / "info"
 DATASETS = [path.stem for path in DATASET_META_LOCATION.iterdir() if path.name.endswith(".json")]
-
+
 DATE_TEMPLATE = "{year}-{month:02d}"
+INTERPOLATION_UNITS = ["hpa", "m", "p"]
+
+
+class VariableInfo(BaseModel):
+    name: str
+    level_type: Literal["sfc", "pl", "ml"]
+
+    def __eq__(self, other):
+        return str(self) == str(other)
+
+    def __hash__(self):
+        return hash(self.name)
+
+    def __str__(self):
+        return self.name
+
+
+class InterpInfo(BaseModel):
+    unit: Literal[*INTERPOLATION_UNITS]
+    start: PositiveInt
+    end: PositiveInt
+
+    @root_validator
+    def start_lt_end(cls, values):
+        if not values["start"] < values["end"]:
+            raise ValueError(
+                f"Interpolation: unit {values['unit']}: start value must be smaller than end value."
+            )
+
+        return values
+
+    def __eq__(self, other):
+        return str(self) == str(other)
+
+    def __hash__(self):
+        return hash(self.unit)
+
+    def __str__(self):
+        return self.unit
+
+
+class LevelInfo(BaseModel):
+    ml: List[PositiveInt]
+    pl: List[PositiveInt]
+    interpolation: List[InterpInfo]
+
+    @property
+    def interpolation_units(self) -> List[str]:
+        return [i_info.unit for i_info in self.interpolation]
+
+    def get_interp(self, unit: Literal[*INTERPOLATION_UNITS]) -> InterpInfo:
+        # InterpInfo compares equal to its unit string, so list.index accepts a unit
+        return self.interpolation[self.interpolation.index(unit)]
+
+
+class Resolution(BaseModel):
+    deg: PositiveFloat
+    nx: PositiveInt
+    ny: PositiveInt
+
+
+class GridInfo(BaseModel):
+    grid_type: Literal["lonlat"]
+    xname: str
+    xunits: Literal["degree"]  # maybe unnecessary?
+    yname: str
+    yunits: Literal["degree"]  # maybe unnecessary?
+    grid_spacing: List[Resolution]
+
+
+class DatasetInfo(BaseModel):
+    variables: List[VariableInfo]
+    levels: LevelInfo
+    grid: GridInfo
+    years: List[conint(ge=1979)]
+
+    @property
+    def var_names(self) -> List[str]:
+        return [var.name for var in self.variables]
+
+    def get_var(self, name: str) -> VariableInfo:
+        return self.variables[self.var_names.index(name)]
+
 
 def get_filename_template(name: str) -> str:
     return f"{name}_{DATE_TEMPLATE}.nc"
 
 
-def get_dataset_info(name: str) -> Dict[str,Any]:
+@cache
+def get_dataset_info(name: str) -> DatasetInfo:
     """Extract metainformation about dataset from corresponding JSON file."""
     file = DATASET_META_LOCATION / f"{name}.json"
     try:
         with open(file, "r") as f:
-            return json.load(f) # TODO: input validation => specify schema / pydantic ?
+            return DatasetInfo(**json.load(f))
     except FileNotFoundError as e:
-        raise ValueError(f"Information on dataset '{dataset}' doesnt exist.")
+        raise ValueError(f"Cannot access {name} information: {file} not available") from e
+    except ValidationError as e:
+        raise ValueError(f"Cannot access {name} information: invalid format of {file}\n{str(e)}") from e
 
 
 def get_vars(name: str) -> List[str]:
     """Extract names of available variables."""
-    info = get_dataset_info(name)
-    return [variable["name"] for variable in info["variables"]]
-
-var_schema = {
-    "type": "object",
-    "properties": {
-        "name": {"type": "string"},
-        "lvl": {
-            "type": "array",
-            "items": {
-                "type": "integer"
-            },
-            "minItems": 1,
-            "uniqueItems": True,
-        },
-        "interpolation": {"type": "string", "enum": ["p", "z"]}
-    },
-    "required": ["name"],
-    "additionalProperties": False
-}
\ No newline at end of file
+    return get_dataset_info(name).var_names
\ No newline at end of file