# WeatherGenerator: Datareader implementation
All data-dependent adjustments of the training pipeline for an experiment happen in the datareader: there, we open, select and provide the data, while the rest of the repository stays largely untouched. In this tutorial, we use the stable version of the WeatherGenerator on the main branch and work on a fork of the repository. The tutorial assumes that a data store for our experiment already exists in a shared scratch directory.
On Juwels, for example, this is ```/p/scratch/weatherai/shared/weather_generator_data```; on Leonardo, it is ```/leonardo_work/AIFAC_5C0_154/weathergen/data/```. Otherwise, we first go through the [dataset onboarding](https://gitlab.jsc.fz-juelich.de/esde/WeatherGenerator-private/-/wikis/home/Operating-procedures/dataset-onboarding).
## Setting up fork and environment
We fork the WeatherGenerator repository to our personal GitHub account, clone the fork on Juwels (see [Introduction](https://gitlab.jsc.fz-juelich.de/hedgedoc/N5iZmj_ZT5eKer_xNOjCKA?view#)) and add the original repository as upstream:
```
git remote add upstream https://github.com/ecmwf/WeatherGenerator.git
```
The environment, including the optional GPU extras for PyTorch development, can be installed with
```
uv sync --extra gpu
```
or with
```
./scripts/actions.sh sync
```
#### Updating a fork
The repository is under active development. To stay in line with it, we regularly sync our local copy. First, sync the main branch of our fork on GitHub with the upstream main branch (e.g. with GitHub's "Sync fork" button), so that ```origin/main``` contains the latest changes. Make sure that all changes on the local feature branch are committed (or stashed). Then run:
```bash
git checkout main
git pull origin main
git checkout my_datareader
git merge main
```
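After the merge, sync the environment again (```uv sync --extra gpu``` or ```./scripts/actions.sh sync```), since the dependencies may have changed.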
## Interactive debugging
While implementing the datareader, we work on an interactive GPU node. To request one, we run the following from the WeatherGenerator directory (see also [Getting started on Jureca](https://gitlab.jsc.fz-juelich.de/hedgedoc/N5iZmj_ZT5eKer_xNOjCKA?view#)):
#### On Juwels
```
srun --gres=gpu:4 --partition=develbooster --account=weatherai -t 02:00:00 --pty bash -i
```
```develbooster``` is a small development partition with a maximum wall time of 2 hours, so a new GPU request has to be submitted once that time is up.
#### On Leonardo
```
srun --gres=gpu:4 --partition=boost_usr_prod --account=aifac_5c0_154 -t 02:00:00 --pty bash -i
```
#### Launch script with uv
In the following, we will do prognostic **finetuning** of the [Hackathon I](https://gitlab.jsc.fz-juelich.de/hedgedoc/yU0GQXLmSByFnZlYrUzsJA?view#Prepare-data-configs) model (run_id ```lfr79awq```). The finetuning can be launched with:
```
uv run --offline train_continue --from_run_id lfr79awq --options istep=0 num_epochs=32 forecast_steps=4 lr_max=0.00005 streams_directory=config/streams/cerra_era5/ 'freeze_modules=".*ERA5.*"'
```
Running ```train_continue``` directly in the interactive session lets us catch errors as soon as they occur.
For more convenience, we can use the Python standard-library module ```code```: adding ```import code``` and ```code.interact(local=locals())``` wherever debugging is needed pauses execution and opens an interactive Python shell in the terminal, in which local objects can be inspected. Leaving the shell with Ctrl-D resumes execution.
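A minimal sketch of how the ```code.interact``` call can be wrapped so that it only triggers on demand. The helper name ```debug_shell``` and the environment variable ```WG_DEBUG``` are hypothetical and not part of the WeatherGenerator:

```python
import code
import os


def debug_shell(local_vars: dict) -> bool:
    """Open an interactive shell on `local_vars` if WG_DEBUG=1 (hypothetical variable).

    Returns True if a shell was opened, False otherwise.
    """
    if os.environ.get("WG_DEBUG") != "1":
        return False
    # Blocks until the shell is left with Ctrl-D; execution then continues.
    code.interact(banner="Debug shell, press Ctrl-D to continue", local=local_vars)
    return True
```

Inside, e.g., ```MyDataReader._get```, we would call ```debug_shell(locals())``` and launch the run with ```WG_DEBUG=1```; without the variable, the call does nothing, so it can stay in the code during normal runs.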
## File dependencies
We will adapt the experiment above to finetune on our data store by implementing and integrating a custom datareader class, called MyDataReader hereafter. The implementation is based on [the datareader template for synop station data](https://github.com/ecmwf/WeatherGenerator/pull/860/files). The new datareader script should be placed in ```src/weathergen/datasets/```. The script depends on a couple of files that we need to add or adjust.
**Stream config** The entry point ```train_continue``` takes the directory containing the stream configs for our data store as the option ```streams_directory```:
```
uv run --offline train_continue --from_run_id=lfr79awq --options streams_directory=/p/home/jusers/wesselkamp1/jureca/WeatherGenerator/config/streams/my_stream sampling_rate_target=0.01 istep=0 num_epochs=2 'freeze_modules=".*ERA5.*"'
```
Hence, we create this folder for our new data stream with ```mkdir config/streams/my_stream``` and place a config there: ```config/streams/my_stream/stream.yaml```. Other stream configs can serve as templates.
This config specifies the input (source) and target variables used in the experiment, as well as the architecture of the prediction head. Suppose our data comes as a .zarr store that contains geoinformation and one variable. Then, for a prognostic prediction of land surface temperature (LST), the top of the stream config could look like this:
```yaml
MY_STREAM:
  type: my_type # used to select MyDataReader, see MultiStreamDataSampler
  filenames: ["my_datastore.zarr"] # name of the data store in the shared data scratch, see introduction
  data_start_time: "2017-01-01 00:00" # start time for loaded data
  data_end_time: "2017-09-30 00:00" # end time for loaded data
  target: ["LST"] # our target variable
  source: ["LST"] # same as target, required for prognostic finetuning
  geoinfo: ["DEM"] # included geoinformation, if needed; otherwise adjust the datareader accordingly
  val_source_channels: ["LST"]
  val_target_channels: ["LST"] # same as the target channels, used for evaluation
  ...
```
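To catch config mistakes before launching a run, a small sanity check can be sketched in Python. This is a hypothetical helper, not part of the repository; it assumes the YAML has already been parsed into a dict (e.g. with PyYAML's ```yaml.safe_load```), and its checks only mirror the comments in the config above. The key ```filenames``` follows the sampler code further down, which reads ```stream_info["filenames"]```.

```python
import numpy as np


def _to_datetime64(value) -> np.datetime64:
    # Accept "YYYY-MM-DD HH:MM" as written in the config by using the ISO 'T' separator.
    return np.datetime64(str(value).replace(" ", "T"))


def check_stream_config(name: str, cfg: dict) -> list[str]:
    """Return a list of problems found in one stream config entry (empty if none)."""
    problems = []
    filenames = cfg.get("filenames")
    if not isinstance(filenames, list) or not filenames:
        problems.append(f"{name}: 'filenames' must be a non-empty list")
    if _to_datetime64(cfg["data_start_time"]) >= _to_datetime64(cfg["data_end_time"]):
        problems.append(f"{name}: data_start_time must be before data_end_time")
    # Prognostic finetuning: every target channel should also be a source channel.
    missing = set(cfg.get("target", [])) - set(cfg.get("source", []))
    if missing:
        problems.append(f"{name}: target channels missing from source: {sorted(missing)}")
    return problems
```

An empty list means that none of these checks found a problem.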
For finetuning ```lfr79awq```, the other streams used during pre-training also need to be in the streams folder. Here, this is the ERA5 stream config:
```
cp config/streams/cerra_era5/era5.yaml config/streams/my_stream/
```
**MultiStreamDataSampler** MyDataReader is imported by the [MultiStreamDataSampler](https://github.com/ecmwf/WeatherGenerator/blob/main/src/weathergen/datasets/multi_stream_data_sampler.py) in ```src/weathergen/datasets/multi_stream_data_sampler.py```. The MultiStreamDataSampler collects all stream configs in ```config/streams/my_stream``` into a list (```cf.streams```). Each entry, ```stream_info```, is one stream-specific config and contains the names of its data stores, hence also *my_datastore.zarr*.
We import MyDataReader (add ```from weathergen.datasets.my_data_reader import MyDataReader```) so that it can be selected via ```stream_info["type"]```, by adding it to the cases as:
```python=128
...
for _, stream_info in enumerate(cf.streams):
    self.streams_datasets.append([])
    for fname in stream_info["filenames"]:
        ...
        match stream_info["type"]:
            case "obs":
                dataset = DataReaderObs
                datapath = cf.data_path_obs
            # ... other cases ...
            case "my_type":
                dataset = MyDataReader
                datapath = cf.data_path_obs
```
where ```"my_type"``` is the type specified in the stream config (see above). As the datapath, we use the path to the shared scratch (see introduction), which is stored as a default in ```cf.data_path_obs``` (for path specs, see the WeatherGenerator-private repo, hpc).
**Training config** The higher-level folder ```config/streams``` contains config files with model parameters and training specifications for pre-training experiments. We do not use such a file for finetuning. Still, for consistency, we create a file for the experiment, ```config/streams/our_experiment.yaml```, which can be copied from the default config, and change its top line to point to our new directory ```config/streams/my_stream/```:
```yaml
streams_directory: "./config/streams/my_stream/"
...
```
**Eval config**
This is the config in ```config/evaluate``` that is used when calling the evaluation after inference with the finetuned model. Strictly speaking, it is not a file dependency of the datareader, but it is worth a look if we want to go through the full pipeline.
## Datareader implementation
If all of the above is set up, we can start implementing and debugging the datareader class MyDataReader in ```src/weathergen/datasets/my_data_reader.py```.
```python=
from pathlib import Path
from typing import override  # Python >= 3.12; older versions: typing_extensions

import numpy as np

# assumed location of the base class and helper types (see data_reader_base.py)
from weathergen.datasets.data_reader_base import (
    DataReaderTimestep,
    ReaderData,
    TIndex,
    TimeWindowHandler,
)


class MyDataReader(DataReaderTimestep):
    """Data reader for satellite observations."""

    def __init__(
        self,
        tw_handler: TimeWindowHandler,
        filename: Path,
        stream_info: dict,
    ) -> None:
        """Initializes a data reader."""
        np32 = np.float32
        ...

    def _statistics(self):
        """Loads or computes channel statistics."""
        ...

    @override
    def init_empty(self):
        """Initialises an empty dataset."""
        super().init_empty()
        self.ds = None
        self.len = 0

    @override
    def length(self) -> int:
        """Returns the length of the dataset."""
        return self.len

    @override
    def _get(self, idx: TIndex, channels_idx: list[int]) -> ReaderData:
        """
        Get data for a window (for either source or target, through the public interface).
        """
        ...

    def select_channels(self, ds, ch_type: str):
        """Select channels based on stream info for either source or target."""
        ...
```
MyDataReader inherits from [DataReaderTimestep](https://github.com/ecmwf/WeatherGenerator/blob/main/src/weathergen/datasets/data_reader_base.py). We implement or modify the following class methods:
- **\_\_init__**: Initialises the datareader; opens the data store and gathers the information needed for training.
- **_statistics**: Defines how to load or compute the statistics.
- **length**: Returns the length of the dataset.
- **_get**: Serves the data. From a time window index, it randomly selects one time step. ```channels_idx``` allows reusing the function for source and target.
- **select_channels**: Already implemented in the template; modify only if special handling is needed.
**\_\_init__** *Arguments: tw_handler: TimeWindowHandler; filename: Path; stream_info: dict*
During initialisation, the data store is opened as a dataset, the length of the dataset is defined for sampling, and the source and target channels, their indices and their statistics are selected from the dataset. To this end, a couple of class attributes have to be set that are needed by DataReaderTimestep or the MultiStreamDataSampler.
Assuming the data store is provided in zarr format, the first step is to open it, e.g. with xarray's zarr engine:
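```python
import xarray as xr

# group="seviri" is specific to the example data store; drop it if the data sits at the root
ds = xr.open_dataset(filename, group="seviri", engine="zarr")
```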
The data store must contain time, latitude and longitude as coordinates that are readable and accessible, e.g. via ```ds.time```. MyDataReader receives a TimeWindowHandler as an argument, which stores the time window of the sample. A quick check ensures that the data store contains data within that window; otherwise, an empty dataset is initialised:
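```python
# _logger: module-level logger, e.g. logging.getLogger(__name__)
if tw_handler.t_start >= ds.time.max().values or tw_handler.t_end <= ds.time.min().values:
    # no overlap between the data store and the loader's time window
    name = stream_info["name"]
    _logger.warning(f"{name} is not supported over data loader window. Stream is skipped.")
    super().__init__(tw_handler, stream_info)
    self.init_empty()
    return
```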
The time information and period of the dataset, the TimeWindowHandler and the stream_info (i.e. the stream config) are passed to and stored in the base class DataReaderTimestep:
```python
# idx_start / idx_end: indices of the first and last time step to use (see below)
data_start_time = ds.time[idx_start].values
data_end_time = ds.time[idx_end].values
# period: time step between consecutive time stamps (assumes a regular time axis)
period = ds.time[1].values - ds.time[0].values
# sets the time arguments and stream info in the base class
super().__init__(
    tw_handler,
    stream_info,
    data_start_time,
    data_end_time,
    period,
)
```
*self.len*: Above, ```idx_start``` and ```idx_end``` are the indices of the first and last time step in our data that we want to cover during the experiment. Then, we check again whether this selected period overlaps with the window of the TimeWindowHandler. Importantly, the length of the dataset, ```self.len```, must be defined here; it is later returned by the class method ```length``` for the sampling iterations.
```python
if tw_handler.t_start >= data_end_time or tw_handler.t_end <= data_start_time:
    self.init_empty()  # return an empty data reader
    return
else:
    self.ds = ds
    self.len = idx_end - idx_start  # or len(ds.time), if the whole time axis is used
```
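One way to obtain ```idx_start``` and ```idx_end``` is to look up ```data_start_time``` and ```data_end_time``` from the stream config on the time axis of the data store. The helper below is a hypothetical sketch, assuming a sorted (monotonically increasing) time axis and times given as strings like ```"2017-01-01 00:00"```:

```python
import numpy as np


def time_index_range(times: np.ndarray, start: str, end: str) -> tuple[int, int]:
    """Return (idx_start, idx_end), the first and last index of `times` within [start, end].

    Assumes `times` is a sorted datetime64 array; both returned indices are inclusive.
    """
    # Cast everything to nanoseconds so that the comparison is exact across units.
    times = np.asarray(times).astype("datetime64[ns]")
    t0 = np.datetime64(start.replace(" ", "T")).astype("datetime64[ns]")
    t1 = np.datetime64(end.replace(" ", "T")).astype("datetime64[ns]")
    idx_start = int(np.searchsorted(times, t0, side="left"))
    idx_end = int(np.searchsorted(times, t1, side="right")) - 1
    if idx_start > idx_end:
        raise ValueError(f"no time stamps between {start} and {end}")
    return idx_start, idx_end
```

Usage could look like ```idx_start, idx_end = time_index_range(ds.time.values, stream_info["data_start_time"], stream_info["data_end_time"])```. Both indices are inclusive, matching their use in ```ds.time[idx_end]``` above.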
*self.channels_file*: We need to set the list of all channels present in our dataset, which is used later on for channel selection.
```python
self.channels_file = list(self.ds.data_vars)
```
*self.latitudes, self.longitudes*: Next, ```self.latitudes``` and ```self.longitudes``` are set from the coordinates in the dataset. If the coordinates are named differently, their names can be given in the stream config yaml via ```latitude_name``` and ```longitude_name```. If the experiment only needs a spatial subset of the data, it can be selected here.
```python
lat_name = stream_info.get("latitude_name", "latitude")
self.latitudes = _clip_lat(np.array(ds[lat_name], dtype=np32))
lon_name = stream_info.get("longitude_name", "longitude")
self.longitudes = _clip_lon(np.array(ds[lon_name], dtype=np32))
```
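If only a spatial subset is needed, a bounding box can be applied to the coordinate axes. The sketch below is a hypothetical helper that assumes 1-D latitude and longitude axes of a regular grid (as implied by the meshgrid step in ```_get```), longitudes in the same convention as the box, and a box that does not cross the antimeridian:

```python
import numpy as np


def bbox_indices(
    lats: np.ndarray,
    lons: np.ndarray,
    lat_range: tuple[float, float],
    lon_range: tuple[float, float],
) -> tuple[np.ndarray, np.ndarray]:
    """Return the indices along the 1-D latitude and longitude axes inside a box (bounds inclusive)."""
    lat_idx = np.flatnonzero((lats >= lat_range[0]) & (lats <= lat_range[1]))
    lon_idx = np.flatnonzero((lons >= lon_range[0]) & (lons <= lon_range[1]))
    return lat_idx, lon_idx
```

The indices can then subset both the coordinates and the dataset, e.g. ```self.latitudes = self.latitudes[lat_idx]```, ```self.longitudes = self.longitudes[lon_idx]``` and ```ds = ds.isel({lat_name: lat_idx, lon_name: lon_idx})```, assuming the coordinate names are also the dimension names.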
*self.source_idx, self.source_channels*: The MultiStreamDataSampler requires indices and channel information for the geoinformation, the source and the target features. We set them with the class method ```select_channels```:
```python
self.source_idx, self.source_channels = self.select_channels(ds, "source")
```
*self.target_idx, self.target_channels*: See above.
*self.geoinfo_idx, self.geoinfo_channels*: See above.
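To illustrate what these index/channel pairs contain, here is a simplified, hypothetical stand-in for ```select_channels``` that matches requested channel names (e.g. ```stream_info["source"]```) against ```self.channels_file```. The preimplemented version in the template may behave differently, so this is only a sketch:

```python
def select_channels_by_name(
    channels_file: list[str], requested: list[str]
) -> tuple[list[int], list[str]]:
    """Map requested channel names to their indices in channels_file.

    Hypothetical, simplified stand-in for MyDataReader.select_channels.
    """
    missing = [ch for ch in requested if ch not in channels_file]
    if missing:
        raise KeyError(f"channels not found in the data store: {missing}")
    indices = [channels_file.index(ch) for ch in requested]
    return indices, list(requested)
```

For example, with ```channels_file = ["DEM", "LST"]``` and ```stream_info["source"] = ["LST"]```, the source index is ```[1]```.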
*self.mean, self.stdev*: The indices are used to access the statistics, both here and in the MultiStreamDataSampler. We compute or load them with:
```python
self.mean, self.stdev = self._statistics()
self.mean_geoinfo, self.stdev_geoinfo = self.mean[self.geoinfo_idx], self.stdev[self.geoinfo_idx]
```
*self.mean_geoinfo, self.stdev_geoinfo*: See above.
**Notes:**
- ```self.mean``` and ```self.stdev``` need to be provided for all variables listed in ```self.channels_file```!
- Because memory is a limited resource, we may want to open our dataset lazily. This can be done directly in ```__init__``` or in a dedicated function, by using `zarr` to open and index the dataset. Note that ```xr.open_dataset``` also opens zarr stores lazily and only reads data when it is accessed.
**_statistics**
The statistics (```self.mean```, ```self.stdev```) must contain one entry per variable in ```self.channels_file```, in the same order. They are indexed in ```data_reader_base.py``` with ```normalize_target_channels()```.
This function does whatever is needed to retrieve the statistics, e.g. loading them from a file or computing them from the data.
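A minimal sketch of computing the statistics from the data, assuming the values of each variable fit in memory as NumPy arrays (for large stores, a temporal subsample or precomputed statistics would be preferable). The function name is hypothetical; it ignores NaNs and returns the statistics in the order of ```channels_file```:

```python
import numpy as np


def compute_channel_statistics(
    arrays: dict[str, np.ndarray], channels_file: list[str]
) -> tuple[np.ndarray, np.ndarray]:
    """Return per-channel mean and standard deviation, ordered like channels_file.

    NaNs (e.g. missing pixels) are ignored.
    """
    means = np.empty(len(channels_file), dtype=np.float32)
    stdevs = np.empty(len(channels_file), dtype=np.float32)
    for i, name in enumerate(channels_file):
        values = np.asarray(arrays[name], dtype=np.float64)  # accumulate in float64
        means[i] = np.nanmean(values)
        stdevs[i] = np.nanstd(values)
    # Constant fields would give a zero standard deviation and a division by zero
    # during normalisation; fall back to 1 in that case.
    stdevs[stdevs == 0] = 1.0
    return means, stdevs
```

In ```_statistics```, this could be called as ```compute_channel_statistics({ch: self.ds[ch].values for ch in self.channels_file}, self.channels_file)```. Replacing a zero standard deviation by 1 is a design choice of this sketch.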
**_get** *Arguments: idx: TIndex, channels_idx: list[int]*
This function is called when ```__getitem__``` requests a time-window sample from the dataset. Its arguments are ```idx```, which identifies that window, and ```channels_idx```, the list of channels to select.
```_get``` assembles *coordinates*, *geoinformation*, *data* and *time* into a [ReaderData](https://github.com/ecmwf/WeatherGenerator/blob/main/src/weathergen/datasets/data_reader_base.py) object that is returned to the Dataloader. All four are arrays handed over to ReaderData in a flat, time-major format with `time_steps*spatial_points` rows, i.e. with shape `(time_steps*spatial_points, N)`, where N is the number of columns (time is a 1-D vector).
This function should return early with an empty ReaderData if no data or channels are available in this time window (```t_idxs``` holding the dataset time indices of the window), hence we add
```python
if len(t_idxs) == 0 or len(channels_idx) == 0:
return ReaderData.empty(
num_data_fields=len(channels_idx), num_geo_fields=len(self.geoinfo_idx)
)
```
We store the first and last index of the time window, which are used for indexing the data store:
```python
didx_start = t_idxs[0]
didx_end = t_idxs[-1]
```
Then we follow four steps:
1. Select the data by time, space and channels, and reshape it to `(time_steps*spatial_points, len(channels))`.
2. Select the geoinformation and tile it along the time window (it is static), reshaping to `(time_steps*spatial_points, len(geoinfos))`.
3. Construct the coordinates as a full meshgrid so that each data point has its own lat/lon pair. Concatenate and tile along the time window so that the result has shape `(time_steps*spatial_points, 2)`.
4. Create a one-dimensional time vector by repeating each time step once per spatial point, giving shape `(time_steps*spatial_points, )`.
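The four steps can be sketched with NumPy on a regular lat/lon grid. The function below is hypothetical and assumes the window has already been read into arrays: the data as `(time_steps, n_lat, n_lon, n_channels)` (e.g. by stacking the selected variables along the last axis), the static geoinformation as `(n_lat, n_lon, n_geo)`, 1-D latitude and longitude axes, and the time stamps of the window. The `(lat, lon)` column order of the coordinates is also an assumption:

```python
import numpy as np


def flatten_window(
    data: np.ndarray,
    geoinfo: np.ndarray,
    lats: np.ndarray,
    lons: np.ndarray,
    times: np.ndarray,
):
    """Flatten one time window into time-major arrays.

    data:    (time_steps, n_lat, n_lon, n_channels)
    geoinfo: (n_lat, n_lon, n_geo), static in time
    lats:    (n_lat,), lons: (n_lon,), times: (time_steps,)
    Returns coords, geoinfos, data, times with time_steps * n_lat * n_lon rows.
    """
    n_t, n_lat, n_lon, n_ch = data.shape
    n_space = n_lat * n_lon

    # 1. Data: C-order reshape keeps time as the slowest-varying index.
    flat_data = data.reshape(n_t * n_space, n_ch)

    # 2. Geoinformation is static, so repeat the flattened field for every time step.
    flat_geo = np.tile(geoinfo.reshape(n_space, geoinfo.shape[-1]), (n_t, 1))

    # 3. Full meshgrid: every grid point gets its own (lat, lon) pair.
    lat2d, lon2d = np.meshgrid(lats, lons, indexing="ij")  # both (n_lat, n_lon)
    coords = np.stack([lat2d.ravel(), lon2d.ravel()], axis=1)  # (n_space, 2)
    flat_coords = np.tile(coords, (n_t, 1))

    # 4. Time vector: each time stamp repeated once per spatial point.
    flat_times = np.repeat(times, n_space)

    return flat_coords, flat_geo, flat_data, flat_times
```

Since ```didx_end``` is the last index of the window, the time slice should include it, e.g. ```ds.isel(time=slice(didx_start, didx_end + 1))```. The returned arrays would then be passed to ReaderData.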