diff --git a/video_prediction_tools/data_extraction/weatherbench.py b/video_prediction_tools/data_extraction/weatherbench.py
index 49d6f09db3dff17656291edbaf82137d3500e278..d9f2c0103165e0f1541422c18afd2aacc2000ad0 100644
--- a/video_prediction_tools/data_extraction/weatherbench.py
+++ b/video_prediction_tools/data_extraction/weatherbench.py
@@ -1,13 +1,21 @@
-from pathlib import Path
 import logging
-from xml.dom.minidom import DOMImplementation
+import itertools as it
+from pathlib import Path
+from zipfile import ZipFile
+from multiprocessing import Pool
+from abc import ABC, abstractmethod
+
+import xarray as xr
 
 from utils.dataset_utils import get_filename_template
 from data_extraction.data_info import VariableInfo, DomainInfo
 
 logging.basicConfig(level=logging.INFO)
+PROCESSES = 20
 
-class ExtractWeatherbench:
+class Extraction(ABC):
+    """Extractor base class."""
+
     def __init__(
         self,
         dirin: Path,
@@ -15,39 +23,155 @@
         domain: DomainInfo
     ):
         """
-        This script performs several sanity checks and sets the class attributes accordingly.
+        Initialize Extractor.
         :param dirin: directory to the weatherbench data
         :param dirout: directory where the output data will be saved
         :param domain: information on spatial and temporal (sub)domain to extract
         """
-        self.dirin = dirin
-        self.dirout = dirout
+        self.dirin: Path = dirin
+        self.dirout: Path = dirout
+
+        self.domain: DomainInfo = domain
+
+    def __call__(self):
+        """
+        Run extraction.
+        :return: -
+        """
+        logging.info(f"domain to be extracted: {self.domain}")
+
+        # TODO: ensure data is moved efficiently
+        data = self.load_data()
+        data = self.selection(data)
+        data = self.interpolation(data)
+        self.write_output(data)
+
+    @abstractmethod
+    def load_data(self) -> xr.Dataset:
+        """Load data for relevant years from files as an xarray Dataset.
+
+        Needs to be overwritten by the implementing class!
+
+        :return: Data for entire dataset.
+        :rtype: xr.Dataset
+        """
+        pass
+
+    def selection(self, ds: xr.Dataset) -> xr.Dataset:
+        """Do spatial and temporal selection.
+
+        :param ds: Dataset containing spatiotemporal data.
+        :type ds: xr.Dataset
+        :return: Reduced dataset only containing relevant region/timeframe.
+        :rtype: xr.Dataset
+        """
+        # spatial selection/interpolation (assumes ascending lat/lon coordinates)
+        ds = ds.sel(
+            lat=slice(*self.domain.lat_range),
+            lon=slice(*self.domain.lon_range)
+        )
+
+        # temporal selection
+        ds = ds.isel(time=ds.time.dt.month.isin(self.domain.months))
+
+        return ds
+
+    def interpolation(self, data: xr.Dataset) -> xr.Dataset:
+        """Interpolate/select correct variable levels.
+
+        :param data: Dataset containing variable data.
+        :type data: xr.Dataset
+        :return: Reduced dataset containing only relevant levels.
+        :rtype: xr.Dataset
+        """
+        return data
+
+    def write_output(self, data: xr.Dataset):
+        """Write extracted domain to monthly NetCDF files efficiently.
 
-        self.domain = DOMImplementation
+        :param data: Data from extracted domain.
+        :type data: xr.Dataset
+        """
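+        # NOTE (assumption): the template from get_filename_template is expected
+        # to expose "{year}" and "{month}" fields; one dataset per year/month
+        # group is then written out in parallel by xr.save_mfdataset via dask.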
+        # TODO: use zarr for better parallel IO
+        templ = get_filename_template(self.domain.dataset)
+        names = [
+            templ.format(year=year, month=month) for year, month in
+            zip(data.time.dt.year.values, data.time.dt.month.values)
+        ]
+
+        data.coords["year_month"] = ("time", names)
+        names, datasets = zip(*data.groupby("year_month"))
+
+        # use dask for parallel writing
+        files = [self.dirout / name for name in names]
+        xr.save_mfdataset(datasets, files)
+
 
-        self.years = domain.years
-        self.months = domain.months
+class ExtractWeatherbench(Extraction):
+    def __init__(self, *args, **kwargs):
+        """
+        Initialize Weatherbench extraction.
+        """
+        super().__init__(*args, **kwargs)
 
         # TODO handle special variables for resolution 5.625 (temperature_850, geopotential_500)
-        if domain.resolution == 5.625:
-            for var in domain.variables:
+        if self.domain.resolution == 5.625:
+            for var in self.domain.variables:
                 combined_name = f"{var.name}_{var.lvl[0]}"
                 if combined_name in {"temperature_850", "geopotential_500"}:
                     var.name = combined_name
-
-        self.variables = domain.variables
-
-        self.lat_range = (domain.coords_sw[0], domain.coords_ne[0])
-        self.lon_range = (domain.coords_sw[1], domain.coords_ne[1])
-
-        self.resolution = domain.resolution
-
+
+    def load_data(self) -> xr.Dataset:
+        """
+        Implement weatherbench-specific loading of relevant data.
+        :return: xarray Dataset instance (may be backed by dask arrays)
+        """
+        zipfiles = self.get_data_files()
+        args = it.chain.from_iterable(
+            zip(it.repeat(self.dirout), it.repeat(zipfile), self.get_names(zipfile))
+            for zipfile in zipfiles
+        )
+
+        with Pool(PROCESSES) as pool:
+            files = pool.starmap(
+                ExtractWeatherbench.extract_file,
+                args,
+                chunksize=max(1, self.domain.months_count // PROCESSES)
+            )
+
+        return xr.open_mfdataset(files)
+
+    def get_data_files(self):
+        res_str = f"{self.domain.resolution}deg"
+        zip_files = [
+            self.dirin / res_str / var.name / f"{var.name}_{res_str}.zip"
+            for var in self.domain.variables
+        ]
+        return zip_files
 
-    def __call__(self):
+    def get_names(self, zipfile):
+        with ZipFile(zipfile, "r") as myzip:
+            return filter(
+                lambda name: any(str(year) in name for year in self.domain.years),
+                myzip.namelist()
+            )
+
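+    # extract_file is a staticmethod so that Pool.starmap can pickle it and send
+    # it to worker processes without serializing the extractor instance.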
+    @staticmethod
+    def extract_file(dirout, zipfile, name) -> Path:
         """
-        Run extraction stub.
-        :return: -
+        Extract a single member from a zip archive.
+        :return: path to the extracted file
         """
-        logging.info(f"domain to be extracted: {self.domain}")
-        logging.info("dummy exection, nothing todo")
\ No newline at end of file
+        try:
+            with ZipFile(zipfile, "r") as myzip:
+                myzip.extract(path=dirout, member=name)
+        except Exception:
+            logging.exception(f"failed to extract {name} from {zipfile}")
+
+        return dirout / name
+
diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py
index 3b7afcc929e4affcf6f7aa14da37808d1e1faf78..dfaf48208e63364064866c03ed516a2fc01f17c7 100644
--- a/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py
+++ b/video_prediction_tools/model_modules/video_prediction/datasets/__init__.py
@@ -7,7 +7,7 @@ from .stats import MinMax, ZScore
 from .dataset import Dataset
 import dask
 from dask.base import tokenize
-from utils.dataset_utils import DATASETS, get_dataset_info, get_filename_template
+from utils.dataset_utils import get_dataset_info, get_filename_template
 
 normalise = {"MinMax": MinMax, "ZScore": ZScore}
 
@@ -19,6 +19,6 @@ def get_dataset(name: str, *args, **kwargs):
         raise ValueError(f"unknown dataset: {name}")
 
     return Dataset(*args, **kwargs,
-                   normalize=normalise[ds_info["normalize"]],
+                   normalize=normalise[ds_info["normalize"]],  # FIXME: normalize not present in dataset info
                    filename_template=get_filename_template(name)
                    )
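
Usage sketch (illustrative only, not part of the patch): assuming DomainInfo
exposes the fields referenced above (years, months, months_count, lat_range,
lon_range, resolution, variables, dataset), the new extractor would be driven
roughly like this:

    from pathlib import Path
    from data_extraction.data_info import DomainInfo
    from data_extraction.weatherbench import ExtractWeatherbench

    domain = ...  # a DomainInfo describing the target (sub)domain, built elsewhere

    extractor = ExtractWeatherbench(
        dirin=Path("/data/weatherbench"),   # hypothetical input directory
        dirout=Path("/data/extracted"),     # hypothetical output directory
        domain=domain,
    )
    extractor()  # load_data -> selection -> interpolation -> write_output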