diff --git a/source/experiments/train_on_reduced_dataset.py b/source/experiments/train_on_reduced_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ed0c9bc11d529507584f70ba15bad74c5a737a1
--- /dev/null
+++ b/source/experiments/train_on_reduced_dataset.py
@@ -0,0 +1,256 @@
+"""
+Our experiments have shown that there exist important and
+unimportant stations in the AQ-Bench dataset.
+Here we see what happens if we discard the unimportant ones.
+"""
+
+# general
+import os
+import pdb
+
+# data science
+import numpy as np
+import pandas as pd
+
+# plotting
+import matplotlib.pyplot as plt
+import cartopy
+import cartopy.crs as ccrs
+import cartopy.feature as cfeature
+
+# own package
+import settings
+from preprocessing.data_for_clustering import DataForClustering
+from experiments.nearest_neighbors_activations import NearestNeighbor
+from models.random_forest import RandomForest
+from models.neural_network import NeuralNetwork
+
+
+__author__ = 'Clara Betancourt'
+
+
+class RedAQbench:
+    """
+    A class for training on a reduced dataset
+    """
+    def __init__(self, model):
+        """
+        Initialize the class.
+        """
+        print(f'Initialize RedAQB with {model}')
+        self.model = model
+        self.aqb = pd.read_csv(settings.resources_dir +
+                               settings.AQbench_dataset_file,
+                               index_col='id')
+
+    def find_unimportant_stations(self):
+        """
+        Read in the nene dataset, then check if there are unimportant
+        entries
+        """
+        print('Find unimportant stations...')
+        # a df with all contribs
+        dfc = DataForClustering()
+        dfc.get_data(how='read')
+        df_template = dfc.data['rf_activations']
+
+        if self.model == 'rf':
+            df = df_template
+        elif self.model == 'nn':
+            df = pd.DataFrame(index=df_template.index,
+                              columns=df_template.columns,
+                              data=df_template.values*0.)
+            # We need the nearest neighbors.
+            nene = NearestNeighbor('nn', 'egu_pruned_nn_scale_targets_true/')
+            nene.read_data_from_pickle()
+            # Todo: make sure that the contribs and the ids are in
+            # the right order.
+            for index, row in nene.df.iterrows():
+                test_id = index
+                train_ids = row.contrib_ids
+                contribs = row.contribs
+                for idx, train_id in enumerate(train_ids):
+                    df.at[test_id, train_id] = contribs[idx]
+
+        # sum all contributions of all predictions
+        df.loc['total'] = df.sum(axis=0)
+        total_contribs = df.loc['total', :]
+
+        # a histogram of the contributions
+        if True:
+            plt.hist(total_contribs, bins=100)
+            plt.xlabel('Sum of contributions over test set')
+            plt.ylabel('Number of stations with that contribution sum in training set')
+            plt.savefig(settings.output_dir+f'red_aqb_contrib_hist_{self.model}.png')
+            plt.close()
+
+        # 10th percentile for least important stations
+        threshold = np.percentile(total_contribs, 10)
+        unimportant_ids = total_contribs[total_contribs<threshold].index
+        self.unimportant_ids = list(unimportant_ids)
+
+    def plot_unimportant_ids_on_map(self):
+        """
+        A map plot of unimportant stations
+        """
+        lons = self.aqb['lon'].loc[self.unimportant_ids]
+        lats = self.aqb['lat'].loc[self.unimportant_ids]
+
+        # plot the map
+        ax = plt.axes(projection=ccrs.PlateCarree())
+        ax.set_extent([-140, 154, -57, 84],
+                      crs=ccrs.PlateCarree())
+        ax.set_facecolor('black')
+        ax.add_feature(cartopy.feature.LAND.with_scale('50m'),
+                       color='dimgray', alpha=0.7, linewidths=0.,
+                       zorder=1)
+        plt.scatter(lons, y=lats, c='r', zorder=3,
+                    marker='X', lw=0.03, edgecolors='white',
+                    s=3, transform=ccrs.PlateCarree())
+        plt.savefig(settings.output_dir +
+                    f'red_aqb_map_{self.model}.png',
+                    bbox_inches='tight', pad_inches=0, dpi=500)
+        plt.close()
+
+    def train_on_important_data(self):
+        """
+        Let's see how the performance drops.
+        """
+        print('Train reference and reduced...')
+        ref_model_dir = f'reference_{self.model}/' #Scarlet ??
+        red_model_dir = f'reduced_{self.model}/' #Scarlet ??
+        if self.model == 'rf':
+            Model = RandomForest
+        elif self.model == 'nn':
+            Model = NeuralNetwork
+
+        # Train reference
+        ref_model = Model(output_folder=ref_model_dir)
+        ref_model.run_training(inputs='egu_pruned',
+                               targets=['o3_average_values'],
+                               plot=False)
+
+        # Save the original aqb
+        os.rename(settings.resources_dir+settings.AQbench_dataset_file,
+                  settings.resources_dir+settings.AQbench_dataset_file+'_orig')
+
+        # Drop unimportant stations from aqb, then save it
+        red_aqb = self.aqb.drop(self.unimportant_ids)
+        red_aqb.to_csv(settings.resources_dir +
+                       settings.AQbench_dataset_file)
+
+        # Train on red aqb
+        red_model = Model(output_folder=red_model_dir)
+        red_model.run_training(inputs='egu_pruned',
+                               targets=['o3_average_values'],
+                               plot=False)
+
+        # Put the aqb back where it belongs
+        os.rename(settings.resources_dir+settings.AQbench_dataset_file+
+                  '_orig',
+                  settings.resources_dir+settings.AQbench_dataset_file)
+
+        # set reduced aqb as class var
+        self.red_aqb = red_aqb
+
+
+def training_experiment():
+    """
+    Training on the reduced datasets
+    """
+    print('Conduct training experiment on reduced AQBench')
+    raqbs = {}
+    for model in ['rf', 'nn']:
+        raqb = RedAQbench(model)
+        raqb.find_unimportant_stations()
+        raqb.plot_unimportant_ids_on_map()
+        raqb.train_on_important_data()
+        raqbs[model] = raqb
+
+    # RF
+    # Reference
+    # mse(test): 19.8935
+    # mae(test): 3.3032
+    # r2(test): 53.0299
+
+    # Reduced
+    # mse(test): 20.1962
+    # mae(test): 3.3438
+    # r2(test): 52.3152
+
+    # NN
+    # Reference
+    # mse(test): 21.028900146484375
+    # mae(test): 3.4154999256134033
+    # r2(test): 50.3493
+
+    # Reduced
+    # mse(test): 22.257200241088867
+    # mae(test): 3.5109000205993652
+    # r2(test): 47.4491
+
+
+def compare_unimportant_stations():
+    """
+    Comparing dataset reductions
+    """
+    print('Compare unimportant stations for RF and NN')
+    # find the ids
+    raqbs = {}
+    for model in ['rf', 'nn']:
+        raqb = RedAQbench(model)
+        raqb.find_unimportant_stations()
+        raqbs[model] = raqb
+    nn_u_ids = raqbs['nn'].unimportant_ids
+    rf_u_ids = raqbs['rf'].unimportant_ids
+    intersect_ids = [id_ for id_ in nn_u_ids if id_ in rf_u_ids]
+    only_nn_ids = [id_ for id_ in nn_u_ids if not id_ in rf_u_ids]
+    only_rf_ids = [id_ for id_ in rf_u_ids if not id_ in nn_u_ids]
+
+    # plot preparation
+    ax = plt.axes(projection=ccrs.PlateCarree())
+    ax.set_extent([-140, 154, -57, 84],
+                  crs=ccrs.PlateCarree())
+    ax.add_feature(cartopy.feature.LAND.with_scale('50m'),
+                   color='dimgray', alpha=0.7, linewidths=0.,
+                   zorder=1)
+    ax.set_facecolor('black')
+    specs = [
+          ('only nn', 'lightcoral', only_nn_ids, 2),
+          ('only rf', 'cornflowerblue', only_rf_ids, 2),
+          ('both', 'mediumorchid', intersect_ids, 3)
+          ]
+    aqb = pd.read_csv(settings.resources_dir +
+                      settings.AQbench_dataset_file,
+                      index_col='id')
+
+    for label, color, ids, zorder in specs:
+        print(f'found {len(ids)} of {label}:')
+        print(ids)
+
+        # plot the map
+        lons = aqb['lon'].loc[ids]
+        lats = aqb['lat'].loc[ids]
+
+        plt.scatter(lons, y=lats, c=color, zorder=zorder,
+                    marker='X', lw=0.03, edgecolors='white',
+                    s=4.5, transform=ccrs.PlateCarree(), label=label)
+
+    # Finalize plot
+    plt.legend(fontsize=4)
+    plt.savefig(settings.output_dir +
+                f'red_aqb_map_compare.png',
+                bbox_inches='tight', pad_inches=0, dpi=1000)
+    plt.close()
+
+    # found 154 of only nn
+    # found 153 of only rf
+    # found 181 of both
+    pdb.set_trace()
+
+if __name__ == '__main__':
+    """
+    Conduct the experiment.
+    """
+    training_experiment()
+    compare_unimportant_stations()