Skip to content
Snippets Groups Projects
Commit 30b0542a authored by Clara Betancourt's avatar Clara Betancourt
Browse files

Merge branch 'clara_issue030_train_on_reduced_dataset' into devel

parents 450f9272 8f6ff11f
No related branches found
No related tags found
No related merge requests found
"""
Our experiments have shown that there exist important and
unimportant stations in the AQ-Bench dataset.
Here we see what happens if we discard the unimportant ones.
"""
# general
import os
import pdb
# data science
import numpy as np
import pandas as pd
# plotting
import matplotlib.pyplot as plt
import cartopy
import cartopy.crs as ccrs
import cartopy.feature as cfeature
# own package
import settings
from preprocessing.data_for_clustering import DataForClustering
from experiments.nearest_neighbors_activations import NearestNeighbor
from models.random_forest import RandomForest
from models.neural_network import NeuralNetwork
__author__ = 'Clara Betancourt'
class RedAQbench:
"""
A class for training on a reduced dataset
"""
def __init__(self, model):
"""
Initialize the class.
"""
print(f'Initialize RedAQB with {model}')
self.model = model
self.aqb = pd.read_csv(settings.resources_dir +
settings.AQbench_dataset_file,
index_col='id')
def find_unimportant_stations(self):
"""
Read in the nene dataset, then check if there are unimportant
entries
"""
print('Find unimportant stations...')
# a df with all contribs
dfc = DataForClustering()
dfc.get_data(how='read')
df_template = dfc.data['rf_activations']
if self.model == 'rf':
df = df_template
elif self.model == 'nn':
df = pd.DataFrame(index=df_template.index,
columns=df_template.columns,
data=df_template.values*0.)
# We need the nearest neighbors.
nene = NearestNeighbor('nn', 'egu_pruned_nn_scale_targets_true/')
nene.read_data_from_pickle()
# Todo: make sure that the contribs and the ids are in
# the right order.
for index, row in nene.df.iterrows():
test_id = index
train_ids = row.contrib_ids
contribs = row.contribs
for idx, train_id in enumerate(train_ids):
df.at[test_id, train_id] = contribs[idx]
# sum all contributions of all predictions
df.loc['total'] = df.sum(axis=0)
total_contribs = df.loc['total', :]
# a histogram of the contributions
if True:
plt.hist(total_contribs, bins=100)
plt.xlabel('Sum of contributions over test set')
plt.ylabel('Number of stations with that contribution sum in training set')
plt.savefig(settings.output_dir+f'red_aqb_contrib_hist_{self.model}.png')
plt.close()
# 10th percentile for least important stations
threshold = np.percentile(total_contribs, 10)
unimportant_ids = total_contribs[total_contribs<threshold].index
self.unimportant_ids = list(unimportant_ids)
def plot_unimportant_ids_on_map(self):
"""
A map plot of unimportant stations
"""
lons = self.aqb['lon'].loc[self.unimportant_ids]
lats = self.aqb['lat'].loc[self.unimportant_ids]
# plot the map
ax = plt.axes(projection=ccrs.PlateCarree())
ax.set_extent([-140, 154, -57, 84],
crs=ccrs.PlateCarree())
ax.set_facecolor('black')
ax.add_feature(cartopy.feature.LAND.with_scale('50m'),
color='dimgray', alpha=0.7, linewidths=0.,
zorder=1)
plt.scatter(lons, y=lats, c='r', zorder=3,
marker='X', lw=0.03, edgecolors='white',
s=3, transform=ccrs.PlateCarree())
plt.savefig(settings.output_dir +
f'red_aqb_map_{self.model}.png',
bbox_inches='tight', pad_inches=0, dpi=500)
plt.close()
def train_on_important_data(self):
"""
Let's see how the performance drops.
"""
print('Train reference and reduced...')
ref_model_dir = f'reference_{self.model}/' #Scarlet ??
red_model_dir = f'reduced_{self.model}/' #Scarlet ??
if self.model == 'rf':
Model = RandomForest
elif self.model == 'nn':
Model = NeuralNetwork
# Train reference
ref_model = Model(output_folder=ref_model_dir)
ref_model.run_training(inputs='egu_pruned',
targets=['o3_average_values'],
plot=False)
# Save the original aqb
os.rename(settings.resources_dir+settings.AQbench_dataset_file,
settings.resources_dir+settings.AQbench_dataset_file+'_orig')
# Drop unimportant stations from aqb, then save it
red_aqb = self.aqb.drop(self.unimportant_ids)
red_aqb.to_csv(settings.resources_dir +
settings.AQbench_dataset_file)
# Train on red aqb
red_model = Model(output_folder=red_model_dir)
red_model.run_training(inputs='egu_pruned',
targets=['o3_average_values'],
plot=False)
# Put the aqb back where it belongs
os.rename(settings.resources_dir+settings.AQbench_dataset_file+
'_orig',
settings.resources_dir+settings.AQbench_dataset_file)
# set reduced aqb as class var
self.red_aqb = red_aqb
def training_experiment():
"""
Training on the reduced datasets
"""
print('Conduct training experiment on reduced AQBench')
raqbs = {}
for model in ['rf', 'nn']:
raqb = RedAQbench(model)
raqb.find_unimportant_stations()
raqb.plot_unimportant_ids_on_map()
raqb.train_on_important_data()
raqbs[model] = raqb
# RF
# Reference
# mse(test): 19.8935
# mae(test): 3.3032
# r2(test): 53.0299
# Reduced
# mse(test): 20.1962
# mae(test): 3.3438
# r2(test): 52.3152
# NN
# Reference
# mse(test): 21.028900146484375
# mae(test): 3.4154999256134033
# r2(test): 50.3493
# Reduced
# mse(test): 22.257200241088867
# mae(test): 3.5109000205993652
# r2(test): 47.4491
def compare_unimportant_stations():
"""
Comparing dataset reductions
"""
print('Compare unimportant stations for RF and NN')
# find the ids
raqbs = {}
for model in ['rf', 'nn']:
raqb = RedAQbench(model)
raqb.find_unimportant_stations()
raqbs[model] = raqb
nn_u_ids = raqbs['nn'].unimportant_ids
rf_u_ids = raqbs['rf'].unimportant_ids
intersect_ids = [id_ for id_ in nn_u_ids if id_ in rf_u_ids]
only_nn_ids = [id_ for id_ in nn_u_ids if not id_ in rf_u_ids]
only_rf_ids = [id_ for id_ in rf_u_ids if not id_ in nn_u_ids]
# plot preparation
ax = plt.axes(projection=ccrs.PlateCarree())
ax.set_extent([-140, 154, -57, 84],
crs=ccrs.PlateCarree())
ax.add_feature(cartopy.feature.LAND.with_scale('50m'),
color='dimgray', alpha=0.7, linewidths=0.,
zorder=1)
ax.set_facecolor('black')
specs = [
('only nn', 'lightcoral', only_nn_ids, 2),
('only rf', 'cornflowerblue', only_rf_ids, 2),
('both', 'mediumorchid', intersect_ids, 3)
]
aqb = pd.read_csv(settings.resources_dir +
settings.AQbench_dataset_file,
index_col='id')
for label, color, ids, zorder in specs:
print(f'found {len(ids)} of {label}:')
print(ids)
# plot the map
lons = aqb['lon'].loc[ids]
lats = aqb['lat'].loc[ids]
plt.scatter(lons, y=lats, c=color, zorder=zorder,
marker='X', lw=0.03, edgecolors='white',
s=4.5, transform=ccrs.PlateCarree(), label=label)
# Finalize plot
plt.legend(fontsize=4)
plt.savefig(settings.output_dir +
f'red_aqb_map_compare.png',
bbox_inches='tight', pad_inches=0, dpi=1000)
plt.close()
# found 154 of only nn
# found 153 of only rf
# found 181 of both
pdb.set_trace()
if __name__ == '__main__':
"""
Conduct the experiment.
"""
training_experiment()
compare_unimportant_stations()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment