Skip to content
Snippets Groups Projects
Commit 038beaf2 authored by lukas leufen's avatar lukas leufen
Browse files

Merge branch 'lukas_issue090_feat_extract-station-sample-sizes' into 'develop'

Resolve "Extract number of examples per data-set per station"

See merge request toar/machinelearningtools!86
parents 26d51f3d 3b33060e
No related branches found
No related tags found
3 merge requests!90WIP: new release update,!89Resolve "release branch / CI on gpu",!86Resolve "Extract number of examples per data-set per station"
Pipeline #33998 passed
...@@ -53,6 +53,7 @@ seaborn==0.10.0 ...@@ -53,6 +53,7 @@ seaborn==0.10.0
--no-binary shapely Shapely==1.7.0 --no-binary shapely Shapely==1.7.0
six==1.11.0 six==1.11.0
statsmodels==0.11.1 statsmodels==0.11.1
tabulate
tensorboard==1.13.1 tensorboard==1.13.1
tensorflow-estimator==1.13.0 tensorflow-estimator==1.13.0
tensorflow==1.13.1 tensorflow==1.13.1
......
...@@ -53,6 +53,7 @@ seaborn==0.10.0 ...@@ -53,6 +53,7 @@ seaborn==0.10.0
--no-binary shapely Shapely==1.7.0 --no-binary shapely Shapely==1.7.0
six==1.11.0 six==1.11.0
statsmodels==0.11.1 statsmodels==0.11.1
tabulate
tensorboard==1.13.1 tensorboard==1.13.1
tensorflow-estimator==1.13.0 tensorflow-estimator==1.13.0
tensorflow-gpu==1.13.1 tensorflow-gpu==1.13.1
......
...@@ -3,10 +3,14 @@ __date__ = '2019-11-25' ...@@ -3,10 +3,14 @@ __date__ = '2019-11-25'
import logging import logging
import os
from typing import Tuple, Dict, List from typing import Tuple, Dict, List
import numpy as np
import pandas as pd
from src.data_handling.data_generator import DataGenerator from src.data_handling.data_generator import DataGenerator
from src.helpers import TimeTracking from src.helpers import TimeTracking, check_path_and_create
from src.join import EmptyQueryResult from src.join import EmptyQueryResult
from src.run_modules.run_environment import RunEnvironment from src.run_modules.run_environment import RunEnvironment
...@@ -54,6 +58,58 @@ class PreProcessing(RunEnvironment): ...@@ -54,6 +58,58 @@ class PreProcessing(RunEnvironment):
logging.debug(f"Number of test stations: {n_test}") logging.debug(f"Number of test stations: {n_test}")
logging.debug(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'test')[0][0].shape}" logging.debug(f"TEST SHAPE OF GENERATOR CALL: {self.data_store.get('generator', 'test')[0][0].shape}"
f"{self.data_store.get('generator', 'test')[0][1].shape}") f"{self.data_store.get('generator', 'test')[0][1].shape}")
self.create_latex_report()
def create_latex_report(self):
    """
    Create tables with station meta data and the sample size of each subset (train/val/test).

    Written files (all inside <experiment_path>/latex_report):
    * station_sample_size.md: markdown table, see example below
    * station_sample_size.tex: same table as latex table
    * station_sample_size_short.tex: reduced latex table without any meta data besides station ID

    The table format (e.g. which meta data is highlighted) is currently hardcoded to have a stable
    table style. If further styles are needed, it is better to add an additional style than
    modifying the existing table styles.

    | stat. ID   | station_name                              | station_lon   | station_lat   | station_alt   | train   | val   | test   |
    |------------|-------------------------------------------|---------------|---------------|---------------|---------|-------|--------|
    | DEBW013    | Stuttgart Bad Cannstatt                   | 9.2297        | 48.8088       | 235           | 1434    | 712   | 1080   |
    | DEBW076    | Baden-Baden                               | 8.2202        | 48.7731       | 148           | 3037    | 722   | 710    |
    | DEBW087    | Schwäbische_Alb                           | 9.2076        | 48.3458       | 798           | 3044    | 714   | 1087   |
    | DEBW107    | Tübingen                                  | 9.0512        | 48.5077       | 325           | 1803    | 715   | 1087   |
    | DEBY081    | Garmisch-Partenkirchen/Kreuzeckbahnstraße | 11.0631       | 47.4764       | 735           | 2935    | 525   | 714    |
    | # Stations | nan                                       | nan           | nan           | nan           | 5       | 5     | 5      |
    | # Samples  | nan                                       | nan           | nan           | nan           | 12253   | 3388  | 4678   |
    """
    meta_data = ['station_name', 'station_lon', 'station_lat', 'station_alt']
    meta_round = ["station_lon", "station_lat", "station_alt"]
    precision = 4
    path = os.path.join(self.data_store.get("experiment_path"), "latex_report")
    check_path_and_create(path)
    set_names = ["train", "val", "test"]
    df = pd.DataFrame(columns=meta_data + set_names)
    for set_name in set_names:
        data: DataGenerator = self.data_store.get("generator", set_name)
        for station in data.stations:
            # number of samples per station = length of the first (time) axis of the labels
            df.loc[station, set_name] = data.get_data_generator(station).get_transposed_label().shape[0]
            # fill meta data only once per station (on the first subset that contains it)
            if df.loc[station, meta_data].isnull().any():
                df.loc[station, meta_data] = data.get_data_generator(station).meta.loc[meta_data].values.flatten()
        # BUGFIX: compute both aggregates from the station rows BEFORE writing either summary row;
        # previously "# Samples" was assigned first, so count() also counted that row and
        # "# Stations" was one too high (e.g. reported 6 for 5 stations).
        n_samples = df.loc[:, set_name].sum()
        n_stations = df.loc[:, set_name].count()
        df.loc["# Samples", set_name] = n_samples
        df.loc["# Stations", set_name] = n_stations
    df[meta_round] = df[meta_round].astype(float).round(precision)
    df.sort_index(inplace=True)
    # move the two summary rows to the bottom of the table
    df = df.reindex(df.index.drop(["# Stations", "# Samples"]).to_list() + ["# Stations", "# Samples"], )
    df.index.name = 'stat. ID'
    # latex column spec: left-align index, right-align last column, center everything else
    column_format = np.repeat('c', df.shape[1] + 1)
    column_format[0] = 'l'
    column_format[-1] = 'r'
    column_format = ''.join(column_format.tolist())
    df.to_latex(os.path.join(path, "station_sample_size.tex"), na_rep='---', column_format=column_format)
    # BUGFIX: close the markdown file deterministically (was a bare open() handed to to_markdown)
    with open(os.path.join(path, "station_sample_size.md"), mode="w", encoding='utf-8') as md_file:
        df.to_markdown(md_file, tablefmt="github")
    df.drop(meta_data, axis=1).to_latex(os.path.join(path, "station_sample_size_short.tex"), na_rep='---',
                                        column_format=column_format)
def split_train_val_test(self) -> None: def split_train_val_test(self) -> None:
""" """
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or sign in to comment