Extract the number of examples per dataset per station
We should add a method that extracts, for each station individually, the number of samples that fall into the train, val, and test sets. Below is an example function that I used in my MA.
def get_train_val_test_split_summary(table_name, gen_dict, values_per_table=None, **kwargs):
'''
This function creates a latex table containing the Station IDs as index, and number of vallid data points per
station per generator as well as used_meta_data:
could look like this
\begin{tabular}{llrrrlll}
\toprule
{} & station\_name & station\_lon & station\_lat & station\_alt & train & val & test \\
\midrule
DENW094 & Aachen-Burtscheid & 6.0939 & 50.7547 & 205.0 & 1875 & 584 & 1032 \\
DEBW029 & Aalen & 10.0963 & 48.8479 & 424.0 & 2958 & 715 & 1080 \\
DENI052 & Allertal & 9.6230 & 52.8294 & 38.0 & 2790 & 497 & 1080 \\
:param values_per_table:
:param table_name:
:param gen_dict:
:param kwargs:
:return:
'''
used_meta_data = kwargs.pop('used_meta_data', ['station_name', 'station_lon', 'station_lat', 'station_alt'])
meta_to_round = kwargs.pop('meta_to_round', ['station_lon', 'station_lat', 'station_alt'])
file_name_ = table_name.split('.')
file_name = file_name_[0]
if len(file_name_) > 2:
raise SyntaxError('only one "." is allowed, got at least two')
elif len(file_name_) == 1:
file_ending = 'tex'
if table_name.split('.')[-1] == 'tex':
file_ending = file_name_[-1]
elif table_name.split('.')[-1] == file_name:
pass
else:
raise NotImplementedError('only .tex files are supported')
df = pd.DataFrame(columns=used_meta_data + list(gen_dict.keys()))
for k, v in gen_dict.items():
for count, value in enumerate(v):
station = v[count][0].coords['Stations'].values.tolist()[0]
df.loc[station, k] = value[0].shape[0]
if True in [df.loc[station, curr_meta] is np.nan for _, curr_meta in enumerate(used_meta_data)]:
df.loc[station, used_meta_data] = v.get_data_generator(
station).meta.loc[used_meta_data,:].values.squeeze()
df[meta_to_round] = df[meta_to_round].astype(float).round(4)
df.sort_index(inplace=True)
df.index.name = 'stat. ID'
column_format = np.repeat('c', df.shape[1]+1)
column_format[0] = 'l'
column_format[-1] = 'r'
column_format = ''.join(column_format.tolist())
print('ok')
print('values_per_table: ', values_per_table)
if values_per_table is None:
df.to_latex('.'.join([file_name, file_ending]), na_rep='---', column_format=column_format).replace(
'\\toprule', '\\hline').replace('\\midrule', '\\hline').replace('\\bottomrule', '\\hline')
else:
num = 0
print('ok')
for i in range(0, df.shape[0]+values_per_table, values_per_table):
hdf = df.iloc[i:i+values_per_table]
if hdf.shape[0] >0:
print('df[{}:{}].shape: '.format(i, i + values_per_table), df.iloc[i:i + values_per_table].shape)
#hdf.to_latex('.'.join([file_name+'_{:02d}'.format(num), file_ending]), na_rep='---', column_format=column_format)
# hdf.to_latex('.'.join([file_name + '_{}'.format(num), file_ending]), na_rep='---',
# column_format=column_format).replace('\\toprule', '\\hline').replace(
# '\\midrule', '\\hline').replace('\\bottomrule', '\\hline')
hdf_str = hdf.to_latex(na_rep='---',
column_format=column_format).replace('\\toprule', '\\hline').replace(
'\\midrule', '\\hline').replace('\\bottomrule', '\\hline')
with open('.'.join([file_name + '_{}'.format(num), file_ending]), 'w') as f:
f.write(hdf_str)
num += 1
return df
It might also be a good idea to add a column containing the MSE on the test set for each station.