diff --git a/source/experiments/train_on_reduced_dataset.py b/source/experiments/train_on_reduced_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ed0c9bc11d529507584f70ba15bad74c5a737a1
--- /dev/null
+++ b/source/experiments/train_on_reduced_dataset.py
@@ -0,0 +1,256 @@
+"""
+Our experiments have shown that there exist important and
+unimportant stations in the AQ-Bench dataset.
+Here we see what happens if we discard the unimportant ones.
+"""
+
+# general
+import os
+import pdb
+
+# data science
+import numpy as np
+import pandas as pd
+
+# plotting
+import matplotlib.pyplot as plt
+import cartopy
+import cartopy.crs as ccrs
+import cartopy.feature as cfeature
+
+# own package
+import settings
+from preprocessing.data_for_clustering import DataForClustering
+from experiments.nearest_neighbors_activations import NearestNeighbor
+from models.random_forest import RandomForest
+from models.neural_network import NeuralNetwork
+
+
+__author__ = 'Clara Betancourt'
+
+
+class RedAQbench:
+    """
+    A class for training on a reduced AQ-Bench dataset.
+    """
+    def __init__(self, model):
+        """
+        Initialize the class.
+        """
+        print(f'Initialize RedAQB with {model}')
+        self.model = model
+        self.aqb = pd.read_csv(settings.resources_dir +
+                               settings.AQbench_dataset_file,
+                               index_col='id')
+
+    def find_unimportant_stations(self):
+        """
+        Read in the nearest neighbor (nene) data, then determine which
+        training stations contribute least to the predictions.
+        """
+        print('Find unimportant stations...')
+        # a df with all contribs
+        dfc = DataForClustering()
+        dfc.get_data(how='read')
+        df_template = dfc.data['rf_activations']
+
+        if self.model == 'rf':
+            df = df_template
+        elif self.model == 'nn':
+            df = pd.DataFrame(index=df_template.index,
+                              columns=df_template.columns,
+                              data=df_template.values*0.)
+            # We need the nearest neighbors.
+            nene = NearestNeighbor('nn', 'egu_pruned_nn_scale_targets_true/')
+            nene.read_data_from_pickle()
+            # TODO: make sure that the contribs and the ids are in
+            # the right order.
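+            # Fill the zero-initialised contribution matrix: for each test
+            # station, write the contribution of every nearest training
+            # neighbor into the corresponding cell.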
+            for index, row in nene.df.iterrows():
+                test_id = index
+                train_ids = row.contrib_ids
+                contribs = row.contribs
+                for idx, train_id in enumerate(train_ids):
+                    df.at[test_id, train_id] = contribs[idx]
+
+        # sum all contributions of all predictions
+        df.loc['total'] = df.sum(axis=0)
+        total_contribs = df.loc['total', :]
+
+        # a histogram of the contributions
+        if True:
+            plt.hist(total_contribs, bins=100)
+            plt.xlabel('Sum of contributions over test set')
+            plt.ylabel('Number of training stations with that '
+                       'contribution sum')
+            plt.savefig(settings.output_dir +
+                        f'red_aqb_contrib_hist_{self.model}.png')
+            plt.close()
+
+        # stations below the 10th percentile of total contribution
+        # count as unimportant
+        threshold = np.percentile(total_contribs, 10)
+        unimportant_ids = total_contribs[total_contribs < threshold].index
+        self.unimportant_ids = list(unimportant_ids)
+
+    def plot_unimportant_ids_on_map(self):
+        """
+        Plot the unimportant stations on a world map.
+        """
+        lons = self.aqb['lon'].loc[self.unimportant_ids]
+        lats = self.aqb['lat'].loc[self.unimportant_ids]
+
+        # plot the map
+        ax = plt.axes(projection=ccrs.PlateCarree())
+        ax.set_extent([-140, 154, -57, 84],
+                      crs=ccrs.PlateCarree())
+        ax.set_facecolor('black')
+        ax.add_feature(cartopy.feature.LAND.with_scale('50m'),
+                       color='dimgray', alpha=0.7, linewidths=0.,
+                       zorder=1)
+        plt.scatter(lons, y=lats, c='r', zorder=3,
+                    marker='X', lw=0.03, edgecolors='white',
+                    s=3, transform=ccrs.PlateCarree())
+        plt.savefig(settings.output_dir +
+                    f'red_aqb_map_{self.model}.png',
+                    bbox_inches='tight', pad_inches=0, dpi=500)
+        plt.close()
+
+    def train_on_important_data(self):
+        """
+        Let's see how the performance drops.
+        """
+        print('Train reference and reduced...')
+        ref_model_dir = f'reference_{self.model}/'  # Scarlet ??
+        red_model_dir = f'reduced_{self.model}/'  # Scarlet ??
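+        # Workflow: train a reference model on the full AQ-Bench dataset,
+        # temporarily replace the dataset file with a reduced version
+        # without the unimportant stations, train a second model on it,
+        # and finally restore the original file.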
+        if self.model == 'rf':
+            Model = RandomForest
+        elif self.model == 'nn':
+            Model = NeuralNetwork
+
+        # Train reference
+        ref_model = Model(output_folder=ref_model_dir)
+        ref_model.run_training(inputs='egu_pruned',
+                               targets=['o3_average_values'],
+                               plot=False)
+
+        # Back up the original aqb by renaming it
+        os.rename(settings.resources_dir + settings.AQbench_dataset_file,
+                  settings.resources_dir + settings.AQbench_dataset_file +
+                  '_orig')
+
+        # Drop unimportant stations from aqb, then save it
+        red_aqb = self.aqb.drop(self.unimportant_ids)
+        red_aqb.to_csv(settings.resources_dir +
+                       settings.AQbench_dataset_file)
+
+        # Train on the reduced aqb
+        red_model = Model(output_folder=red_model_dir)
+        red_model.run_training(inputs='egu_pruned',
+                               targets=['o3_average_values'],
+                               plot=False)
+
+        # Put the original aqb back where it belongs
+        os.rename(settings.resources_dir + settings.AQbench_dataset_file +
+                  '_orig',
+                  settings.resources_dir + settings.AQbench_dataset_file)
+
+        # keep the reduced aqb as a class attribute
+        self.red_aqb = red_aqb
+
+
+def training_experiment():
+    """
+    Training on the reduced datasets
+    """
+    print('Conduct training experiment on reduced AQBench')
+    raqbs = {}
+    for model in ['rf', 'nn']:
+        raqb = RedAQbench(model)
+        raqb.find_unimportant_stations()
+        raqb.plot_unimportant_ids_on_map()
+        raqb.train_on_important_data()
+        raqbs[model] = raqb
+
+    # RF
+    # Reference
+    # mse(test): 19.8935
+    # mae(test): 3.3032
+    # r2(test): 53.0299
+
+    # Reduced
+    # mse(test): 20.1962
+    # mae(test): 3.3438
+    # r2(test): 52.3152
+
+    # NN
+    # Reference
+    # mse(test): 21.028900146484375
+    # mae(test): 3.4154999256134033
+    # r2(test): 50.3493
+
+    # Reduced
+    # mse(test): 22.257200241088867
+    # mae(test): 3.5109000205993652
+    # r2(test): 47.4491
+
+
+def compare_unimportant_stations():
+    """
+    Compare which stations are unimportant for RF and NN.
+    """
+    print('Compare unimportant stations for RF and NN')
+    # find the ids
+    raqbs = {}
+    for model in ['rf', 'nn']:
+        raqb = RedAQbench(model)
+        raqb.find_unimportant_stations()
+        raqbs[model] = raqb
+    nn_u_ids = raqbs['nn'].unimportant_ids
+    rf_u_ids = raqbs['rf'].unimportant_ids
+    intersect_ids = [id_ for id_ in nn_u_ids if id_ in rf_u_ids]
+    only_nn_ids = [id_ for id_ in nn_u_ids if id_ not in rf_u_ids]
+    only_rf_ids = [id_ for id_ in rf_u_ids if id_ not in nn_u_ids]
+
+    # plot preparation
+    ax = plt.axes(projection=ccrs.PlateCarree())
+    ax.set_extent([-140, 154, -57, 84],
+                  crs=ccrs.PlateCarree())
+    ax.add_feature(cartopy.feature.LAND.with_scale('50m'),
+                   color='dimgray', alpha=0.7, linewidths=0.,
+                   zorder=1)
+    ax.set_facecolor('black')
+    specs = [
+        ('only nn', 'lightcoral', only_nn_ids, 2),
+        ('only rf', 'cornflowerblue', only_rf_ids, 2),
+        ('both', 'mediumorchid', intersect_ids, 3)
+    ]
+    aqb = pd.read_csv(settings.resources_dir +
+                      settings.AQbench_dataset_file,
+                      index_col='id')
+
+    for label, color, ids, zorder in specs:
+        print(f'found {len(ids)} of {label}:')
+        print(ids)
+
+        # plot the stations of this group
+        lons = aqb['lon'].loc[ids]
+        lats = aqb['lat'].loc[ids]
+
+        plt.scatter(lons, y=lats, c=color, zorder=zorder,
+                    marker='X', lw=0.03, edgecolors='white',
+                    s=4.5, transform=ccrs.PlateCarree(), label=label)
+
+    # Finalize plot
+    plt.legend(fontsize=4)
+    plt.savefig(settings.output_dir +
+                'red_aqb_map_compare.png',
+                bbox_inches='tight', pad_inches=0, dpi=1000)
+    plt.close()
+
+    # found 154 of only nn
+    # found 153 of only rf
+    # found 181 of both
+    pdb.set_trace()
+
+
+if __name__ == '__main__':
+    """
+    Conduct the experiment.
+    """
+    training_experiment()
+    compare_unimportant_stations()