# Provenance (recovered from the repository web page this file was copied from):
#   Commit 30b0542a, authored by Clara Betancourt
#   Merge branch 'clara_issue030_train_on_reduced_dataset' into devel
#   Parents: 450f9272 8f6ff11f
"""
Our experiments have shown that there exist important and
unimportant stations in the AQ-Bench dataset.
Here we see what happens if we discard the unimportant ones.
"""
# general
import os
import pdb
# data science
import numpy as np
import pandas as pd
# plotting
import matplotlib.pyplot as plt
import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
# own package
import settings
from preprocessing.data_for_clustering import DataForClustering
from experiments.nearest_neighbors_activations import NearestNeighbor
from models.random_forest import RandomForest
from models.neural_network import NeuralNetwork
__author__ = 'Clara Betancourt'
class RedAQbench:
    """
    A class for training on a reduced dataset.

    Intended call order: find_unimportant_stations() sets
    self.unimportant_ids, which both plot_unimportant_ids_on_map() and
    train_on_important_data() read.
    """

    def __init__(self, model):
        """
        Initialize the class and read the AQ-Bench table.

        :param model: model key; the methods of this class handle 'rf'
            (RandomForest) and 'nn' (NeuralNetwork) and raise ValueError
            for any other value.
        """
        print(f'Initialize RedAQB with {model}')
        self.model = model
        # Full dataset, indexed by the 'id' column.
        self.aqb = pd.read_csv(settings.resources_dir +
                               settings.AQbench_dataset_file,
                               index_col='id')

    def _check_model(self):
        """Raise ValueError unless self.model is 'rf' or 'nn'."""
        if self.model not in ('rf', 'nn'):
            raise ValueError(
                f"Unknown model {self.model!r}; expected 'rf' or 'nn'.")

    @staticmethod
    def _nn_contributions(df_template):
        """
        Build a zero-filled frame shaped like df_template and fill in the
        nearest-neighbor contributions (rows: test ids, columns: train ids).
        """
        df = pd.DataFrame(index=df_template.index,
                          columns=df_template.columns,
                          data=df_template.values*0.)
        # We need the nearest neighbors.
        nene = NearestNeighbor('nn', 'egu_pruned_nn_scale_targets_true/')
        nene.read_data_from_pickle()
        # Todo: make sure that the contribs and the ids are in
        # the right order.
        for index, row in nene.df.iterrows():
            test_id = index
            train_ids = row.contrib_ids
            contribs = row.contribs
            for idx, train_id in enumerate(train_ids):
                df.at[test_id, train_id] = contribs[idx]
        return df

    @staticmethod
    def _ids_below_percentile(total_contribs, percentile=10):
        """
        Return the index labels of total_contribs whose value is strictly
        below the given percentile of all values, as a list.
        """
        threshold = np.percentile(total_contribs, percentile)
        return list(total_contribs[total_contribs < threshold].index)

    def find_unimportant_stations(self):
        """
        Read in the contributions, then check if there are unimportant
        entries.

        For 'rf' the 'rf_activations' table of DataForClustering is used
        directly; for 'nn' a table of the same shape is filled from the
        NearestNeighbor data. The contributions are summed per column, a
        histogram of the sums is saved to settings.output_dir, and the
        ids below the 10th percentile are stored in self.unimportant_ids.

        :raises ValueError: if self.model is neither 'rf' nor 'nn'.
        """
        print('Find unimportant stations...')
        # Fail with a clear message instead of an unbound local further down.
        self._check_model()
        # a df with all contribs
        dfc = DataForClustering()
        dfc.get_data(how='read')
        df_template = dfc.data['rf_activations']
        if self.model == 'rf':
            df = df_template
        else:
            df = self._nn_contributions(df_template)
        # Sum all contributions of all predictions. The sum is kept in its
        # own Series so that no 'total' row is appended to the frame that
        # DataForClustering handed out.
        total_contribs = df.sum(axis=0)
        # a histogram of the contributions
        plt.hist(total_contribs, bins=100)
        plt.xlabel('Sum of contributions over test set')
        plt.ylabel('Number of stations with that contribution sum in training set')
        plt.savefig(settings.output_dir+f'red_aqb_contrib_hist_{self.model}.png')
        plt.close()
        # 10th percentile for least important stations
        self.unimportant_ids = self._ids_below_percentile(total_contribs, 10)

    def plot_unimportant_ids_on_map(self):
        """
        A map plot of unimportant stations.

        Reads self.unimportant_ids, so find_unimportant_stations() has to
        run first. The figure is saved to settings.output_dir as
        'red_aqb_map_<model>.png'.
        """
        lons = self.aqb['lon'].loc[self.unimportant_ids]
        lats = self.aqb['lat'].loc[self.unimportant_ids]
        # plot the map
        ax = plt.axes(projection=ccrs.PlateCarree())
        ax.set_extent([-140, 154, -57, 84],
                      crs=ccrs.PlateCarree())
        ax.set_facecolor('black')
        ax.add_feature(cartopy.feature.LAND.with_scale('50m'),
                       color='dimgray', alpha=0.7, linewidths=0.,
                       zorder=1)
        plt.scatter(lons, y=lats, c='r', zorder=3,
                    marker='X', lw=0.03, edgecolors='white',
                    s=3, transform=ccrs.PlateCarree())
        plt.savefig(settings.output_dir +
                    f'red_aqb_map_{self.model}.png',
                    bbox_inches='tight', pad_inches=0, dpi=500)
        plt.close()

    def train_on_important_data(self):
        """
        Let's see how the performance drops.

        Trains a reference model on the dataset file as it is, then
        temporarily replaces the dataset file by a copy without the
        stations in self.unimportant_ids, trains a second model on it and
        puts the original file back. The reduced table is stored in
        self.red_aqb.

        :raises ValueError: if self.model is neither 'rf' nor 'nn'.
        :raises FileExistsError: if a '<dataset>_orig' backup is already
            present (e.g. left over from an interrupted run).
        """
        print('Train reference and reduced...')
        self._check_model()
        # NOTE: the original author marked both folder names with
        # "Scarlet ??" - presumably an open question; confirm the names.
        ref_model_dir = f'reference_{self.model}/'
        red_model_dir = f'reduced_{self.model}/'
        if self.model == 'rf':
            Model = RandomForest
        else:
            Model = NeuralNetwork
        # Train reference
        ref_model = Model(output_folder=ref_model_dir)
        ref_model.run_training(inputs='egu_pruned',
                               targets=['o3_average_values'],
                               plot=False)
        # Drop unimportant stations from aqb. Done before any file is
        # moved, so a failing drop cannot leave the dataset file missing.
        red_aqb = self.aqb.drop(self.unimportant_ids)
        dataset_path = settings.resources_dir + settings.AQbench_dataset_file
        backup_path = dataset_path + '_orig'
        # A leftover backup means the current dataset file may already be
        # a reduced one; overwriting the backup would lose the original.
        if os.path.exists(backup_path):
            raise FileExistsError(
                f'Backup {backup_path} already exists; restore the '
                'original dataset before running again.')
        # Save the original aqb
        os.rename(dataset_path, backup_path)
        try:
            # Save the reduced aqb under the original file name
            red_aqb.to_csv(dataset_path)
            # Train on red aqb
            red_model = Model(output_folder=red_model_dir)
            red_model.run_training(inputs='egu_pruned',
                                   targets=['o3_average_values'],
                                   plot=False)
        finally:
            # Put the aqb back where it belongs, even if training failed.
            # os.replace overwrites the reduced file on every platform.
            os.replace(backup_path, dataset_path)
        # set reduced aqb as class var
        self.red_aqb = red_aqb
def training_experiment():
    """
    Training on the reduced datasets

    For each of 'rf' and 'nn': find the unimportant stations, plot them
    on a map, and train a reference and a reduced model.

    :return: dict mapping the model key ('rf', 'nn') to the RedAQbench
        instance that ran the experiment for it.
    """
    print('Conduct training experiment on reduced AQBench')
    raqbs = {}
    for model in ['rf', 'nn']:
        raqb = RedAQbench(model)
        raqb.find_unimportant_stations()
        raqb.plot_unimportant_ids_on_map()
        raqb.train_on_important_data()
        raqbs[model] = raqb
    # Results noted by the original author:
    # RF
    # Reference
    # mse(test): 19.8935
    # mae(test): 3.3032
    # r2(test): 53.0299
    # Reduced
    # mse(test): 20.1962
    # mae(test): 3.3438
    # r2(test): 52.3152
    # NN
    # Reference
    # mse(test): 21.028900146484375
    # mae(test): 3.4154999256134033
    # r2(test): 50.3493
    # Reduced
    # mse(test): 22.257200241088867
    # mae(test): 3.5109000205993652
    # r2(test): 47.4491
    # Hand the instances back instead of discarding them.
    return raqbs
def compare_unimportant_stations():
    """
    Comparing dataset reductions

    Finds the unimportant stations for 'rf' and 'nn', splits them into
    the ids found only by NN, only by RF, and by both, prints each group
    and plots all three on one map, saved to settings.output_dir as
    'red_aqb_map_compare.png'.
    """
    print('Compare unimportant stations for RF and NN')
    # find the ids
    raqbs = {}
    for model in ['rf', 'nn']:
        raqb = RedAQbench(model)
        raqb.find_unimportant_stations()
        raqbs[model] = raqb
    nn_u_ids = raqbs['nn'].unimportant_ids
    rf_u_ids = raqbs['rf'].unimportant_ids
    # Sets give constant-time membership tests; the comprehensions still
    # walk the lists so the original order of the ids is kept.
    nn_id_set = set(nn_u_ids)
    rf_id_set = set(rf_u_ids)
    intersect_ids = [id_ for id_ in nn_u_ids if id_ in rf_id_set]
    only_nn_ids = [id_ for id_ in nn_u_ids if id_ not in rf_id_set]
    only_rf_ids = [id_ for id_ in rf_u_ids if id_ not in nn_id_set]
    # plot preparation
    ax = plt.axes(projection=ccrs.PlateCarree())
    ax.set_extent([-140, 154, -57, 84],
                  crs=ccrs.PlateCarree())
    ax.add_feature(cartopy.feature.LAND.with_scale('50m'),
                   color='dimgray', alpha=0.7, linewidths=0.,
                   zorder=1)
    ax.set_facecolor('black')
    # (legend label, marker color, station ids, zorder); 'both' has the
    # highest zorder so it is drawn on top of the other two groups.
    specs = [
        ('only nn', 'lightcoral', only_nn_ids, 2),
        ('only rf', 'cornflowerblue', only_rf_ids, 2),
        ('both', 'mediumorchid', intersect_ids, 3),
    ]
    aqb = pd.read_csv(settings.resources_dir +
                      settings.AQbench_dataset_file,
                      index_col='id')
    for label, color, ids, zorder in specs:
        print(f'found {len(ids)} of {label}:')
        print(ids)
        # plot the map
        lons = aqb['lon'].loc[ids]
        lats = aqb['lat'].loc[ids]
        plt.scatter(lons, y=lats, c=color, zorder=zorder,
                    marker='X', lw=0.03, edgecolors='white',
                    s=4.5, transform=ccrs.PlateCarree(), label=label)
    # Finalize plot
    plt.legend(fontsize=4)
    plt.savefig(settings.output_dir +
                'red_aqb_map_compare.png',
                bbox_inches='tight', pad_inches=0, dpi=1000)
    plt.close()
    # Results noted by the original author:
    # found 154 of only nn
    # found 153 of only rf
    # found 181 of both
    # The debugger breakpoint that used to follow here was removed so the
    # script also finishes in non-interactive runs.
if __name__ == '__main__':
    # Conduct the experiment: first the training on the reduced
    # datasets, then the comparison of the unimportant stations.
    training_experiment()
    compare_unimportant_stations()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment