
Extract number of examples per data-set per station

We should add a method that extracts the number of samples falling into the train, val, and test set for each station individually. Below is an example function which I used in my MA.

import numpy as np
import pandas as pd


def get_train_val_test_split_summary(table_name, gen_dict, values_per_table=None, **kwargs):
    '''
    Create a LaTeX table with the station IDs as index, the number of valid data points per
    station and generator (train/val/test), and the selected meta data. The output could look like this:

    \begin{tabular}{llrrrlll}
    \toprule
    {} &                 station\_name &  station\_lon &  station\_lat &  station\_alt & train &  val &  test \\
    \midrule
    DENW094 &            Aachen-Burtscheid &       6.0939 &      50.7547 &        205.0 &  1875 &  584 &  1032 \\
    DEBW029 &                        Aalen &      10.0963 &      48.8479 &        424.0 &  2958 &  715 &  1080 \\
    DENI052 &                     Allertal &       9.6230 &      52.8294 &         38.0 &  2790 &  497 &  1080 \\

    :param table_name: name of the output file, either without extension or ending in '.tex'
    :param gen_dict: mapping of subset name ('train', 'val', 'test') to the corresponding data generator
    :param values_per_table: if given, split the output into several files with at most this many rows each
    :param kwargs: optional 'used_meta_data' and 'meta_to_round' lists controlling the meta data columns
    :return: the assembled summary DataFrame
    '''
    used_meta_data = kwargs.pop('used_meta_data', ['station_name', 'station_lon', 'station_lat', 'station_alt'])
    meta_to_round = kwargs.pop('meta_to_round', ['station_lon', 'station_lat', 'station_alt'])

    # determine output file name and extension; only '.tex' (or no extension) is supported
    file_name_ = table_name.split('.')
    file_name = file_name_[0]
    if len(file_name_) > 2:
        raise ValueError('only one "." is allowed, got at least two')
    elif len(file_name_) == 1:
        file_ending = 'tex'
    elif file_name_[-1] == 'tex':
        file_ending = file_name_[-1]
    else:
        raise NotImplementedError('only .tex files are supported')

    # collect the number of samples per subset and the station meta data
    df = pd.DataFrame(columns=used_meta_data + list(gen_dict.keys()))
    for subset_name, generator in gen_dict.items():
        for value in generator:
            station = value[0].coords['Stations'].values.tolist()[0]
            # number of samples of this station in the current subset
            df.loc[station, subset_name] = value[0].shape[0]
            # fill the meta data columns once per station
            if df.loc[station, used_meta_data].isnull().any():
                df.loc[station, used_meta_data] = generator.get_data_generator(
                    station).meta.loc[used_meta_data, :].values.squeeze()
    df[meta_to_round] = df[meta_to_round].astype(float).round(4)
    df.sort_index(inplace=True)
    df.index.name = 'stat. ID'
    # column format: left-aligned index, centered columns, right-aligned last column
    column_format = 'l' + 'c' * (df.shape[1] - 1) + 'r'
    if values_per_table is None:
        # write a single table; replace the booktabs rules by \hline
        table_str = df.to_latex(na_rep='---', column_format=column_format).replace(
            '\\toprule', '\\hline').replace('\\midrule', '\\hline').replace('\\bottomrule', '\\hline')
        with open('.'.join([file_name, file_ending]), 'w') as f:
            f.write(table_str)
    else:
        # split the table into chunks of at most values_per_table rows and write one file per chunk
        for num, i in enumerate(range(0, df.shape[0], values_per_table)):
            hdf = df.iloc[i:i + values_per_table]
            hdf_str = hdf.to_latex(na_rep='---', column_format=column_format).replace(
                '\\toprule', '\\hline').replace('\\midrule', '\\hline').replace('\\bottomrule', '\\hline')
            with open('.'.join([file_name + '_{}'.format(num), file_ending]), 'w') as f:
                f.write(hdf_str)
    return df
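
For reference, a call could look roughly like the sketch below. The names data_gen_train, data_gen_val and data_gen_test are only placeholders for the existing data generators of the three subsets; the exact objects depend on the run setup, so this is an assumption for illustration, not the final interface.

# hypothetical usage sketch: the three data generators are assumed to exist already
gen_dict = {'train': data_gen_train, 'val': data_gen_val, 'test': data_gen_test}
summary = get_train_val_test_split_summary('split_summary.tex', gen_dict, values_per_table=25)
print(summary[['train', 'val', 'test']].sum())  # total number of samples per subset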

It might also be a good idea to add a column containing the MSE on the test set for each station.
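
A minimal sketch of how such a column could be appended to the returned DataFrame, assuming a dict mse_per_station (station ID -> MSE on the test set) is available from the evaluation step; both the dict and the column name test_mse are placeholders:

# hypothetical: mse_per_station maps station ID -> MSE on the test set
summary['test_mse'] = pd.Series(mse_per_station).round(4)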