Commit d2e233cd authored by Simon Grasse

implement extractor base class + weatherbench

parent 7d2b4c5b
import logging
import itertools as it
from pathlib import Path
from zipfile import ZipFile
from multiprocessing import Pool
from abc import ABC, abstractmethod

import pandas as pd
import xarray as xr

from utils.dataset_utils import get_filename_template
from data_extraction.data_info import VariableInfo, DomainInfo

logging.basicConfig(level=logging.INFO)

PROCESSES = 20


class Extraction(ABC):
    """Extractor base class."""

    def __init__(
            self,
            dirin: Path,
            dirout: Path,
            domain: DomainInfo
    ):
        """
        Initialize Extractor.

        :param dirin: directory to the weatherbench data
        :param dirout: directory where the output data will be saved
        :param domain: information on spatial and temporal (sub)domain to extract
        """
        self.dirin: Path = dirin
        self.dirout: Path = dirout
        self.domain: DomainInfo = domain

    def __call__(self):
        """
        Run extraction.

        :return: -
        """
        logging.info(f"domain to be extracted: {self.domain}")
        # TODO: ensure data is moved efficiently
        data = self.load_data()
        data = self.selection(data)
        data = self.interpolation(data)
        self.write_output(data)

    @abstractmethod
    def load_data(self) -> xr.Dataset:
        """Load data for relevant years from files as xarray Dataset.

        Needs to be overwritten by the implementing class!

        :return: Data for entire Dataset.
        :rtype: xr.Dataset
        """
        pass

    def selection(self, ds: xr.Dataset) -> xr.Dataset:
        """Do spatial and temporal selection.

        :param ds: Dataset containing spatiotemporal data.
        :type ds: xr.Dataset
        :return: Reduced dataset only containing relevant region/timeframe.
        :rtype: xr.Dataset
        """
        # spatial selection
        ds = ds.sel(
            lat=slice(*self.domain.lat_range),
            lon=slice(*self.domain.lon_range)
        )
        # temporal selection
        ds = ds.isel(time=ds.time.dt.month.isin(self.domain.months))
        return ds
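
    # Illustrative example (values assumed, not from the commit): with
    # lat_range=(30, 60), lon_range=(0, 45) and months=[6, 7, 8], selection()
    # reduces the dataset to the 30-60N / 0-45E box and keeps only
    # June-August time steps:
    #
    #     ds = ds.sel(lat=slice(30, 60), lon=slice(0, 45))
    #     ds = ds.isel(time=ds.time.dt.month.isin([6, 7, 8]))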

    def interpolation(self, data: xr.Dataset) -> xr.Dataset:
        """Interpolate/select correct variable levels.

        :param data: Dataset containing variable data.
        :type data: xr.Dataset
        :return: Reduced dataset containing only relevant levels.
        :rtype: xr.Dataset
        """
        return data

    def write_output(self, data: xr.Dataset):
        """Write extracted domain to monthly NetCDF efficiently.

        :param data: Data from extracted domain.
        :type data: xr.Dataset
        """
        # TODO: use zarr for better parallel IO
        templ = get_filename_template(self.domain.dataset)
        names = [
            templ.format(year=year, month=month)
            for year, month in zip(data.time.dt.year.values, data.time.dt.month.values)
        ]
        data.coords["year_month"] = ("time", names)
        names, datasets = zip(*data.groupby("year_month"))
        # use dask for parallel writing; save_mfdataset is a module-level
        # xarray function, not a Dataset method
        files = [self.dirout / name for name in names]
        xr.save_mfdataset(datasets, files)


class ExtractWeatherbench(Extraction):

    def __init__(self, *args, **kwargs):
        """
        Initialize Weatherbench extraction.
        """
        super().__init__(*args, **kwargs)
        # TODO: handle special variables for resolution 5.625 (temperature_850, geopotential_500)
        if self.domain.resolution == 5.625:
            for var in self.domain.variables:
                combined_name = f"{var.name}_{var.lvl[0]}"
                if combined_name in {"temperature_850", "geopotential_500"}:
                    var.name = combined_name

    def load_data(self) -> xr.Dataset:
        """
        Implement weatherbench-specific loading of relevant data.

        :return: xarray Dataset instance (may use dask arrays internally)
        """
        zipfiles = self.get_data_files()
        args = it.chain.from_iterable(
            zip(it.repeat(self.dirout), it.repeat(zipfile), self.get_names(zipfile))
            for zipfile in zipfiles
        )
        with Pool(PROCESSES) as pool:
            files = pool.starmap(
                ExtractWeatherbench.extract_file,
                args,
                # chunksize must be >= 1; months_count may be < PROCESSES
                chunksize=max(1, self.domain.months_count // PROCESSES)
            )
        return xr.open_mfdataset(files)
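
    # How the starmap arguments above are assembled (illustrative member
    # names): for a zipfile containing "t2m_2000.nc" and "t2m_2001.nc",
    #     it.chain.from_iterable(zip(it.repeat(dirout), it.repeat(zf), names))
    # yields (dirout, zf, "t2m_2000.nc"), (dirout, zf, "t2m_2001.nc"), ...
    # i.e. one argument tuple per archive member to extract.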

    def get_data_files(self):
        res_str = f"{self.domain.resolution}deg"
        zip_files = [
            self.dirin / res_str / var.name / f"{var.name}_{res_str}.zip"
            for var in self.domain.variables
        ]
        return zip_files

    def get_names(self, zipfile):
        with ZipFile(zipfile, "r") as myzip:
            return filter(
                lambda name: any(str(year) in name for year in self.domain.years),
                myzip.namelist()
            )

    @staticmethod
    def extract_file(dirout: Path, zipfile: Path, name: str) -> Path:
        """
        Extract a single member from a zip archive.

        :return: path of the extracted file
        """
        try:
            with ZipFile(zipfile, "r") as myzip:
                myzip.extract(path=dirout, member=name)
        except Exception:
            print(f"runtime exception for file {zipfile}")
        return dirout / name
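
A minimal usage sketch (illustrative only; the paths, module location, and the
DomainInfo construction are assumptions, since only the call sequence is
defined by the code above):

    from pathlib import Path
    from data_extraction.data_info import DomainInfo
    from data_extraction.extraction import ExtractWeatherbench  # module path assumed

    domain = DomainInfo(...)  # resolution, variables, years, months, lat/lon ranges
    extractor = ExtractWeatherbench(
        dirin=Path("/data/weatherbench"),
        dirout=Path("/data/extracted"),
        domain=domain,
    )
    extractor()  # load -> select -> interpolate -> write monthly NetCDF files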

@@ -7,7 +7,7 @@
from .stats import MinMax, ZScore
from .dataset import Dataset
import dask
from dask.base import tokenize
from utils.dataset_utils import get_dataset_info, get_filename_template

normalise = {"MinMax": MinMax,
             "ZScore": ZScore}
@@ -19,6 +19,6 @@ def get_dataset(name: str, *args, **kwargs):
        raise ValueError(f"unknown dataset: {name}")
    return Dataset(*args, **kwargs,
                   normalize=normalise[ds_info["normalize"]],  # FIXME: normalize not present in dataset info
                   filename_template=get_filename_template(name)
                   )
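
Usage sketch for get_dataset (illustrative; the dataset name and extra
arguments are assumptions, since the diff only shows the lookup and
normaliser wiring):

    ds = get_dataset("weatherbench", ...)  # unknown names raise ValueError
    # returns a Dataset configured with the matching normaliser
    # (MinMax/ZScore) and the dataset's filename template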