diff --git a/source/experiments/cs_uba_graph.py b/source/experiments/cs_uba_graph.py
index facfc45a6888d61e893e73436d4512f2291b13a8..69ba9262c6d9a3e3dee2bb471c5fb04c45dd481b 100644
--- a/source/experiments/cs_uba_graph.py
+++ b/source/experiments/cs_uba_graph.py
@@ -1,5 +1,9 @@
 """
-Try the correct and smooth algorithm for missing data imputation
+Correct and smooth algorithm for missing data imputation of the
+UBA dataset. These are the experiments we are publishing.
+
+Run apply_correct_and_smooth(save=False) to reproduce the published
+evaluations.
 """
 
 # general
@@ -175,15 +179,16 @@ def tune_correct_and_smooth():
         val_r2, val_rmse, val_d = csw.evaluate()
         tuning_df.at[count, ['r2', 'rmse', 'd']] = round(val_r2, 3), \
                                     round(val_rmse, 3), round(val_d, 3)
-        if count % 10 == 0: tuning_df.to_csv(settings.output_dir+'tuning_stm_scales.csv')
+        if count % 10 == 0: tuning_df.to_csv(settings.output_dir+'tuning.csv')
 
-    tuning_df.to_csv(settings.output_dir+'tuning_stm_scales.csv')
+    tuning_df.to_csv(settings.output_dir+'tuning.csv')
     print(f'tuning results written to {settings.output_dir}tuning.csv')
 
 
-def apply_correct_and_smooth():
+def apply_correct_and_smooth(save=False):
     """
     Finally, apply correct using training and validation set!
+    If save is True, the previously saved predictions are overwritten.
     """
     print(f"\n{'*'*30}\n** APPLY CORRECT AND SMOOTH **\n{'*'*30}\n")
     specs = [
@@ -204,10 +209,11 @@ def apply_correct_and_smooth():
                                num_smoothing_layers=num_smoothing_layers,
                                scale=scale)
         csw.evaluate()
-        dir_ = settings.resources_dir + 'models/'
-        file_ = f'{simple_model()}_cs_predictions_useval.pyt'
-        torch.save(csw.y_hat_smooth, dir_+file_)
-        print(f'written to {dir_+file_}\n\n')
+        if save:
+            dir_ = settings.resources_dir + 'models/'
+            file_ = f'{simple_model()}_cs_predictions_useval.pyt'
+            torch.save(csw.y_hat_smooth, dir_+file_)
+            print(f'written to {dir_+file_}\n\n')
 
 
 if __name__ == '__main__':
diff --git a/source/experiments/gat_aqbench_graph.py b/source/experiments/gat_aqbench_graph.py
index a8b420b2f35ad7af3514ee6ab4dc3e64523e0039..7fd9a4b845ff46687353668e14e583fa4ae0b79f 100644
--- a/source/experiments/gat_aqbench_graph.py
+++ b/source/experiments/gat_aqbench_graph.py
@@ -1,5 +1,5 @@
 """
-Trying some Graph ML architectures on AQ-Bench
+Training a pytorch geometric graph attention network on AQ-Bench
 """
 
 # general
diff --git a/source/experiments/nn_aqbench.py b/source/experiments/nn_aqbench.py
index 4d596fd04847e97daaa509ac25fb462b54fd9622..a0e046e6e39c82c76190f6ced29416964a1b9bbe 100644
--- a/source/experiments/nn_aqbench.py
+++ b/source/experiments/nn_aqbench.py
@@ -1,5 +1,5 @@
 """
-Trying some Graph ML architectures on AQ-Bench
+Training a basic pytorch model on AQ-Bench
 """
 
 # general
diff --git a/source/models/correct_and_smooth.py b/source/models/correct_and_smooth.py
index d4b5a6c4997976829462667552bb7293db465d6e..34c7c3e2d66a0210a73cae2ecb67c2cd0f8d47cf 100644
--- a/source/models/correct_and_smooth.py
+++ b/source/models/correct_and_smooth.py
@@ -1,3 +1,9 @@
+"""
+Correct and smooth from pytorch geometric.
+We adapted this script to work for regression problems, as described
+in our paper.
+"""
+
 import torch
 import torch.nn.functional as F
 from torch import Tensor
@@ -101,7 +107,7 @@ class CorrectAndSmooth(torch.nn.Module):
         error = torch.zeros_like(y_soft)
         error[mask] = y_true - y_soft[mask]
 
-        if self.autoscale:  # This should be False
+        if self.autoscale:
             smoothed_error = self.prop1(error, edge_index,
                                         edge_weight=edge_weight,
                                         post_step=lambda x: x)  # .clamp_(-1., 1.))  # CB
diff --git a/source/models/graph_attention_network.py b/source/models/graph_attention_network.py
index 359a7aad8ce3177b33851946507ca612fa9a4570..4915021623f5a0b87e1532edbd627438e4f89086 100644
--- a/source/models/graph_attention_network.py
+++ b/source/models/graph_attention_network.py
@@ -1,5 +1,5 @@
 """
-Trying some Graph ML architectures on AQ-Bench
+A pytorch geometric graph attention network.
 """
 
 # pytorch
diff --git a/source/models/label_propagation.py b/source/models/label_propagation.py
index 2ae4ff412061886aca01d081cb2fbda8555679d0..27bf61a72d9ef1ab578f37ec7bf829be462594bd 100644
--- a/source/models/label_propagation.py
+++ b/source/models/label_propagation.py
@@ -1,3 +1,8 @@
+"""
+Label propagation from pytorch geometric. We did not change this
+script; it is only used for debugging.
+"""
+
 from typing import Callable, Optional
 
 import torch
diff --git a/source/models/message_passing.py b/source/models/message_passing.py
index 4d6342ff743deb9f41ad814bedc84c05b8bd85dc..09ca20feac62a4a5499c3cab75d3d91ea9c02b34 100644
--- a/source/models/message_passing.py
+++ b/source/models/message_passing.py
@@ -1,3 +1,8 @@
+"""
+Message passing from pytorch geometric.
+We did not change this class; it is only used for debugging.
+"""
+
 import inspect
 import os
 import os.path as osp
diff --git a/source/models/modelwrapper.py b/source/models/modelwrapper.py
index 230fa1a3d98a021f935e2e61078a96c291ff9f5f..51ab6a685293a21dfebaa21baab2b9f8482613fe 100644
--- a/source/models/modelwrapper.py
+++ b/source/models/modelwrapper.py
@@ -1,5 +1,6 @@
 """
-A wrapper for training pytorch models
+A wrapper for training pytorch neural networks (not for correct
+and smooth!)
 """
 
 # pytorch
diff --git a/source/models/simple.py b/source/models/simple.py
index 2f002ead633d80460a6d75cb5c29f612b83bf093..e81827f5e0bab3ca24f79071086ddabf8b7b0b5c 100644
--- a/source/models/simple.py
+++ b/source/models/simple.py
@@ -1,5 +1,6 @@
 """
-Simple models as baselines and to combine with C&S
+This script contains routines to fit and evaluate simple models as
+baselines and to combine with correct and smooth.
 """
 
 # general
@@ -548,7 +549,7 @@ def evaluate_models():
 
 if __name__ == '__main__':
     """
-    Tune, the simple models, then save their predictions.
+    Tune the simple models, then save their predictions.
     Also save the predictions where the validation set is used
     for training.
     """
diff --git a/source/postprocessing/create_final_imputation.py b/source/postprocessing/create_final_imputation.py
index 11a160fd2f0475ff7da182262015d9e1c693e1f1..213598f14f306fe216c45bbc7c46540e911724fd 100644
--- a/source/postprocessing/create_final_imputation.py
+++ b/source/postprocessing/create_final_imputation.py
@@ -1,5 +1,8 @@
 """
-Create the final dataset.
+Create the final dataset. It contains ozone measurements wherever
+they are available. According to the results of this study,
+gaps of up to 5 h length are linearly interpolated, while gaps of 6 h
+or more are imputed with random forest + correct and smooth.
 """
 
 # general
@@ -32,6 +35,7 @@ def create_final_imputation():
     missing_o3_mask = mask_df.missing_o3_mask.to_numpy().reshape(-1)
     y_true_df = pd.read_csv(ug.y_path, index_col=0)
     y_true = y_true_df.y.to_numpy().reshape(-1)
+    reg_df = pd.read_csv(ug.reg_path, index_col=0)
 
     # load imputations of the models
     model_path = settings.resources_dir + 'models/'
@@ -40,7 +44,7 @@ def create_final_imputation():
     lin_imp = torch.load(model_path+lin_file).numpy().reshape(-1)
     rf_cs_imp = torch.load(model_path+rf_cs_file).numpy().reshape(-1)
 
-    # prepare data
+    # write imputations to gaps
     imp = y_true.copy()
     imp[missing_o3_mask] = rf_cs_imp[missing_o3_mask]
     gap_df_filtered = gap_df[(gap_df.type=='missing_o3_mask') &
@@ -54,18 +58,30 @@ def create_final_imputation():
                           data=imp)
     imp_df.index.name = 'node_index'
 
-    # save data
-    save_path = settings.resources_dir + 'imputed_dataset/y_imputed.csv'
-    imp_df.to_csv(save_path)
-    print(f'written to {save_path}')
-    pdb.set_trace()
-
-
-
-
-
-
+    # prepare data for saving
+    station_list = np.unique(reg_df.station_id)
+    n_stations = len(station_list)
+    datetime_list = np.unique(reg_df.datetime)
+    n_datetime = len(datetime_list)
+    imp_o3_df = pd.DataFrame(index=datetime_list,
+                             columns=station_list,
+                             data=imp.reshape(n_stations, n_datetime).T)
+    imp_o3_df.index.name = 'datetime'
+
+    # prepare an info dataframe
+    imp_info_df = pd.DataFrame(index=datetime_list,
+                               columns=station_list,
+                               data=missing_o3_mask.reshape(n_stations,
+                                                            n_datetime).T)
+    imp_info_df.index.name = 'datetime'
 
+    # save data
+    o3_save_path = settings.resources_dir + 'imputed_dataset/imputed_o3.csv'
+    info_save_path = settings.resources_dir + 'imputed_dataset/imputed_info.csv'
+    imp_o3_df.to_csv(o3_save_path)
+    print(f'written to {o3_save_path}')
+    imp_info_df.to_csv(info_save_path)
+    print(f'written to {info_save_path}')
 
 
 if __name__ == '__main__':
diff --git a/source/postprocessing/evaluation_tools.py b/source/postprocessing/evaluation_tools.py
index 40e22605ad0cb044bb68a62cb81fb2021a336d1d..19c329a958e61ea14af79ca2abca986fce820ad3 100644
--- a/source/postprocessing/evaluation_tools.py
+++ b/source/postprocessing/evaluation_tools.py
@@ -264,9 +264,10 @@ def count_exceedances(print_=False):
     ug = UBAGraph()
     y_true_df = pd.read_csv(ug.y_path, index_col=0)
     y_true = y_true_df.y.to_numpy().reshape(-1)
-    imputed_path = settings.resources_dir + 'imputed_dataset/y_imputed.csv'
+    imputed_path = settings.resources_dir + \
+                   'imputed_dataset/imputed_o3.csv'
     y_imp_df = pd.read_csv(imputed_path, index_col=0)
-    y_imp = y_imp_df.y_imputed.to_numpy().reshape(-1)
+    y_imp = y_imp_df.values.reshape(-1)
 
     # set up data frame
     columns = ['threshold', 'n_measured', 'n_imputed',
@@ -286,13 +287,13 @@ def count_exceedances(print_=False):
         if not print_:
             continue
         print(f'\nthreshold: {threshold}')
-        print(f'measured exc: {tru_exc}, imputed exc: {imp_exc}')
-        print('after imputation:', (y_imp>threshold).sum())
+        print(f'measured exc: {tru_exc}, imputed exc: {diff}')
+        print('after imputation:', tru_exc+diff)
 
     return exc_df
 
 
-def number_of_neighbors(station_id):
+def number_of_neighbors(station_id, return_ids=False):
     """
     How many stations are in a radius of 50 km?
     """
@@ -316,8 +317,11 @@ def number_of_neighbors(station_id):
                                                 radius=max_dist)
     n_neighbors = len(idx_lists[0]) - 1
     distances = sorted(dist_lists[0])[1:]
+    ids = sd.df.index[idx_lists[0]].to_list()[1:]
 
     # return
+    if return_ids:
+        return(n_neighbors, distances, ids)
     return(n_neighbors, distances)
 
 
@@ -411,9 +415,9 @@ if __name__ == '__main__':
     gap_length_evaluation_ = False
     print_best_hyperparameters_ = False
     test_bootstrap_ = False
-    count_exceedances_ = False
+    count_exceedances_ = True
     number_of_neighbors_ = False
-    get_characteristics_df_ = True
+    get_characteristics_df_ = False
 
     if index_of_agreement_:
         # test index of agreement should be.959
diff --git a/source/preprocessing/aqbench.py b/source/preprocessing/aqbench.py
index 36405504040dedd8618259629bfe9feaab668521..5d444966eb4dbcbd0adec57fa32f59c23c5c758b 100644
--- a/source/preprocessing/aqbench.py
+++ b/source/preprocessing/aqbench.py
@@ -1,5 +1,6 @@
 """
-Prepare datasets for Pytorch Geometric
+This script prepares a pytorch geometric dataset from the AQ-Bench
+dataset. It can be used for first graph machine learning tryouts.
 """
 
 # general
diff --git a/source/preprocessing/get_cams_data_v2.py b/source/preprocessing/get_cams_data_v2.py
index 1d59625d29bfe94fb741b26b9430c8317418358d..135082c5692f2c239135b75e39773b33c4637395 100644
--- a/source/preprocessing/get_cams_data_v2.py
+++ b/source/preprocessing/get_cams_data_v2.py
@@ -1,3 +1,8 @@
+"""
+This script preprocesses raw CAMS data into .csv files. These files
+are later used in uba.py to create the graph dataset.
+"""
+
 from datetime import datetime, timedelta
 import pandas as pd
 import numpy as np
@@ -32,54 +37,54 @@ class GetInterpolatedCAMSData():
         self.csv_out_path = self.in_dir + 'cams_data_csv/'
         self.plot_out_path = settings.output_dir
         print(f'Resource path: {self.in_dir}')
-        
-    
+
+
     def read_eac4_ncfile(self):
-        
+
         # Convert unit from mmr to vmr (ppb)
         no_unit = 28.9644/30.0061*1.e9
-        no2_unit = 28.9644/46.0055*1.e9 
+        no2_unit = 28.9644/46.0055*1.e9
         o3_unit = 28.9644/47.9982*1.e9
-        
+
         ## READ nc file
         eac4_path = self.in_dir + 'cams_eac4_2011_3hourly.nc'
         nc_data = xr.open_dataset(eac4_path)
-        self.no = nc_data['no'][:,:,:] * no_unit 
-        self.no2 = nc_data['no2'][:,:,:] * no2_unit 
-        self.o3 = nc_data['go3'][:,:,:] * o3_unit 
-        self.lati  = nc_data['latitude'][:]                       
+        self.no = nc_data['no'][:,:,:] * no_unit
+        self.no2 = nc_data['no2'][:,:,:] * no2_unit
+        self.o3 = nc_data['go3'][:,:,:] * o3_unit
+        self.lati  = nc_data['latitude'][:]
         self.loni  = nc_data['longitude'][:]
         self.ti = nc_data['time'][:]
         self.nt = len(self.ti)
-        
+
     def read_emiss_ncfile(self):
         ## READ nc file
         emiss_path = self.in_dir + 'CAMS-GLOB-ANT_Glb_0.1x0.1_anthro_nox_v5.3_monthly_2011_DE.nc'
         nc_data = xr.open_dataset(emiss_path)
         esum_nox = nc_data['sum'][:,:,:]
-        self.late  = nc_data['lat'][:]                       
+        self.late  = nc_data['lat'][:]
         self.lone  = nc_data['lon'][:]
         self.te = nc_data['time'][:]
-        self.nte = len(self.te)        
+        self.nte = len(self.te)
         self.emiss_nox = esum_nox * 12.0 * 1.e12 / 100.0 / 1.e6 ## Convert unit - from Tg per month to g m−2 yr−1
         ## Print arrays to check
-                
+
     def read_stn_info(self):
         stn_data = pd.read_csv(self.stn_path)
         self.ids = stn_data['id']
         self.lats = stn_data['lat']
         self.lons = stn_data['lon']
         self.ns = len(self.ids)
-        print(f' No. of stations: {self.ns}')     
+        print(f' No. of stations: {self.ns}')
         print(f' Station IDs:{self.ids}')
-                
+
     def plot_ncfile(self):
         def ymean_plot(varname, lons, lats, x, y, data):
             vstr=varname
             lon = x
             lat = y
             minlon = np.min(lon)
-            maxlon = np.max(lon) 
+            maxlon = np.max(lon)
             minlat = np.min(lat)
             maxlat = np.max(lat)
             # make color plots with station locations
@@ -87,7 +92,7 @@ class GetInterpolatedCAMSData():
             fig = plt.figure(figsize=(8, 8))
             ymean = data.mean(dim='time')
             # Plot map lines
-            m =Basemap(projection='merc',llcrnrlon=minlon,llcrnrlat=minlat,urcrnrlon=maxlon,urcrnrlat=maxlat,resolution='i') 
+            m =Basemap(projection='merc',llcrnrlon=minlon,llcrnrlat=minlat,urcrnrlon=maxlon,urcrnrlat=maxlat,resolution='i')
             m.drawcountries()
             m.drawcoastlines()
             # Plot color mesh
@@ -98,24 +103,24 @@ class GetInterpolatedCAMSData():
             m.plot(xs,ys,'ro',markersize=2)
             plt.savefig(self.plot_out_path + 'cmap_'+vstr+'_2011avg.png')
             plt.clf()
-                    
+
         ymean_plot('no',self.lons,self.lats,self.loni, self.lati, self.no)
         print(f'Plotted NO color map...')
         ymean_plot('no2',self.lons,self.lats,self.loni, self.lati,self.no2)
         print(f'Plotted NO2 color map...')
         ymean_plot('o3',self.lons,self.lats,self.loni, self.lati,self.o3)
         print(f'Plotted O3 color map...')
-        
+
     def get_data_latlon(self):
         def inp_latlon(var,data,time_index,lats,lons,ids):
             var_str = var
             nt,nlat,nlon = np.shape(data)
             ns = self.ns  # station dimension
-        
-            # Create empty station interpolated arrays 
-            data_stn = np.zeros(nt*ns) 
+
+            # Create empty station interpolated arrays
+            data_stn = np.zeros(nt*ns)
             data_stn = data_stn.reshape(nt,ns)
-            
+
             # Select data by interpolating the neareast lat & lon grids
             for i in range(ns):
             #for i in range (0,5):
@@ -124,14 +129,14 @@ class GetInterpolatedCAMSData():
                 lon = lons[i]
                 print (f'Interpolating Station {id} ({i+1}/{ns})...')
                 print (f'lat & lon : ({lat},{lon})')
-            
+
                 for t in range(nt):
                     try:
                         data_stn[t,i]=data[t,:,:].interp(latitude=lat, longitude=lon)
                     except:
                         data_stn[t,i]=data[t,:,:].interp(lat=lat, lon=lon)
             print(f'Interpolated data: {data_stn}')
-            
+
             # Convert to panda dataframe
             stn_df = pd.DataFrame(data_stn)
             id_str = ids.to_numpy(dtype=str)
@@ -141,25 +146,25 @@ class GetInterpolatedCAMSData():
             stn_df['time'] = time_index
             stn_df = stn_df.set_index('time')
             print(f'station data : {stn_df.head()}')
-                
+
             # Resample and interpolate data to one-hour interval
             stn_df_2011 = stn_df['2011']
             last_hour = pd.to_datetime('2011-12-31 23:00:00')
             stn_df_2011 = stn_df_2011.append(pd.DataFrame(index=[last_hour]))
             stn_df_1h = stn_df_2011.resample('60T')
-            stn_df_1h = stn_df_1h.interpolate(method='linear',axis=0) 
+            stn_df_1h = stn_df_1h.interpolate(method='linear',axis=0)
             #stn_df_1h = stn_df_1h.interpolate(method='spline',order=2,axis=0)
             print(f'station 1h data : {stn_df_1h.head()}')
-            
+
             # Save as csv file
-            out_csv_file = self.out_dir + 'hourly_cams_' + var_str + '.csv' 
+            out_csv_file = self.out_dir + 'hourly_cams_' + var_str + '.csv'
             stn_df_1h.to_csv(out_csv_file)
-    
+
         inp_latlon("no",self.no,self.ti,self.lats,self.lons,self.ids)
         inp_latlon("no2",self.no2,self.ti,self.lats,self.lons,self.ids)
         inp_latlon("o3",self.o3,self.ti,self.lats,self.lons,self.ids)
         inp_latlon("Enox",self.emiss_nox,self.te,self.lats,self.lons,self.ids)
-    
+
     def stn_data_stats(self):
         def plot_conc_hist(varname, data, binwidth):
             vstr = varname
@@ -173,18 +178,18 @@ class GetInterpolatedCAMSData():
             plt.ylabel('count')
             plt.savefig(settings.output_dir + 'cams_'+vstr+'_2011_hist.png')
             plt.clf()
-        # Present conc arrays    
-        m = 365*24 + 1 # read one year of data in 1-hour interval    
+        # Present conc arrays
+        m = 365*24 + 1 # read one year of data in 1-hour interval
         n = self.ns  # station dimension
-        
-        no_stn = np.zeros(n*m) 
+
+        no_stn = np.zeros(n*m)
         no_stn = no_stn.reshape(n,m)
-        no2_stn = np.zeros(n*m) 
+        no2_stn = np.zeros(n*m)
         no2_stn = no2_stn.reshape(n,m)
-        o3_stn = np.zeros(n*m) 
-        o3_stn = o3_stn.reshape(n,m)        
+        o3_stn = np.zeros(n*m)
+        o3_stn = o3_stn.reshape(n,m)
         # READ station interpolated data from precompiled csv files
-        for i in range(n):            
+        for i in range(n):
             id = self.ids[i]
             id_str = str(id).zfill(4)
             data_stn = pd.read_csv(self.csv_out_path + 'ID_' + id_str + '_1h.csv',index_col=0)
@@ -199,13 +204,13 @@ class GetInterpolatedCAMSData():
         print(no_df.describe())
         print(no2_df.describe())
         print(o3_df.describe())
-        # Plot distribution histogram 
+        # Plot distribution histogram
         plot_conc_hist('no',no_df,10)
         plot_conc_hist('no2',no2_df,5)
         plot_conc_hist('o3',o3_df,5)
-            
-            
-            
+
+
+
 if __name__ == '__main__':
     """
     Create the dataset for testing purposes.
@@ -216,11 +221,10 @@ if __name__ == '__main__':
     gicd.read_emiss_ncfile()
     ## READ station file
     gicd.read_stn_info()
-    ## PLOT colour maps 
+    ## PLOT colour maps
     #gied.plot_ncfile()
     ## GET station-interpolated data
-    gicd.get_data_latlon()  
+    gicd.get_data_latlon()
     ## OUTPUT basic interpolated data statistics
     #gicd.stn_data_stats()
-        
-        
\ No newline at end of file
+
diff --git a/source/preprocessing/uba.py b/source/preprocessing/uba.py
index 430a1876923993b94e9a020a533416b2519123da..d98e3d4373317b8cce9356c80f9570a9cba7870f 100644
--- a/source/preprocessing/uba.py
+++ b/source/preprocessing/uba.py
@@ -1,6 +1,6 @@
 """
-Creates a Pytorch Dataset which contains time
-resolved data.
+Creates the pytorch geometric dataset that is used in our publication.
+To reuse the dataset, call the UBAGraph.get_dataset() routine.
 """
 
 # general
@@ -759,14 +759,14 @@ def dataset_workflow():
     Full workflow to create the dataset
     """
     ug = UBAGraph()
-    # ug.get_reg()
-    # ug.get_pos()
-    # ug.get_x()
-    # ug.get_y()
-    # ug.get_gaps()
-    # ug.get_mask()
-    # ug.get_edges()
-    # ug.get_edge_weights()
+    ug.get_reg()
+    ug.get_pos()
+    ug.get_x()
+    ug.get_y()
+    ug.get_gaps()
+    ug.get_mask()
+    ug.get_edges()
+    ug.get_edge_weights()
     ug.convert_to_tensor()
 
 
@@ -802,5 +802,5 @@ if __name__ == '__main__':
     """
     Start routines
     """
-    dataset_workflow()
-    # print_graph_statistics()
+    # dataset_workflow()
+    print_graph_statistics()
diff --git a/source/retrieval/toar_db.py b/source/retrieval/toar_db.py
index 100187647497e47d73abc0af186b5a8f2ee036fd..f5ec4f318a0543635fd6433d2bcf4cb9f30a759b 100644
--- a/source/retrieval/toar_db.py
+++ b/source/retrieval/toar_db.py
@@ -1,5 +1,6 @@
 """
-Retrieving TOAR data
+This file contains all classes necessary to retrieve data from
+the TOAR database.
 """
 
 # general
diff --git a/source/settings.py b/source/settings.py
index 37041fd4356617a215c7fb1637b412868bf53ed9..60849b84e16304179e11b6e44a94c996c4aa0ff3 100644
--- a/source/settings.py
+++ b/source/settings.py
@@ -23,7 +23,7 @@ ROOTDIR = str(SOURCEDIR_pos.parent)
 output_dir = ROOTDIR + '/output/'
 
 # data resources directory
-resources_dir = '/p/project/deepacf/intelliaq/spatial-patterns-data/'
+resources_dir = '/p/project/deepacf/intelliaq/ozone-imputation-data/'
 
 # random seed
 random_seed = 1
diff --git a/source/visualizations/plots_for_paper.py b/source/visualizations/plots_for_paper.py
index 360beb781b2f19113eb2fafc60d38c2c5f9c9dd0..fe891176f87f71bbb05f0d5c15f3f77baa9ded10 100644
--- a/source/visualizations/plots_for_paper.py
+++ b/source/visualizations/plots_for_paper.py
@@ -174,7 +174,8 @@ def station_loc_on_map():
 
 def station_ids_on_map():
     """
-    To analyze what the station ids mean
+    To analyze what the station ids mean.
+    This plot was not published.
     """
     print('station ids on map...')
 
@@ -287,7 +288,8 @@ def o3_value_matrix():
 
 def mask_matrix():
     """
-    the missing, train, test, val masks as a matrix
+    The missing, train, test, and val masks as a matrix.
+    This plot was not published.
     """
     print('mask matrix...')
 
@@ -516,17 +518,16 @@ def exceedances():
     # get data
     exc_df = count_exceedances()
 
+    # plot style
+    plt.style.use('seaborn-darkgrid')
+    sns.set(rc={'axes.facecolor':'whitesmoke'})
+
     # plot
     fig, ax = plt.subplots()
     for idx, row in exc_df[:-4].iterrows():
-        # ax.bar(row.threshold,
-        #        row.n_measured,
-        #        color='gainsboro',
-        #        alpha=.95)
         ax.bar(row.threshold,
                row.difference,
                4,
-               # bottom=row.n_measured,
                color='steelblue',
                alpha=.8)
 
@@ -544,138 +545,114 @@ def exceedances():
 
 def true_vs_imputed():
     """
-    A simple scatter plot
+    True vs. imputed heatmaps: isolated gaps, correlated gaps, summary.
     """
     print('true versus imputed...')
-
-    # prepare scatter true vs. imputed
-    df = get_characteristics_df()
-    y_tru_short = df[(df.test_mask) &
-                     (df.gap_len<=5)].y_true
-    y_imp_short = df[(df.test_mask) &
-                     (df.gap_len<=5)].y_imputed
-    y_tru_long = df[(df.test_mask) &
-                     (df.gap_len>5)].y_true
-    y_imp_long = df[(df.test_mask) &
-                     (df.gap_len>5)].y_imputed
-    y_tru = df[df.test_mask].y_true
-    y_imp = df[df.test_mask].y_imputed
-
-    # plot style
-    plt.style.use('seaborn-darkgrid')
-    sns.set(rc={'axes.facecolor':'whitesmoke'})
-
-    # scatter true vs. imputed
-    fig, ax = plt.subplots(1, 2)
-    min_val = -7.
-    max_val = 115
-
-    ax[0].scatter(y_tru_short,
-                  y_imp_short,
-                  s=2.,  # 1.5,
-                  alpha=.1,
-                  zorder=2,
-                  color='steelblue',
-                  lw=0.
-                  )
-    ax[1].scatter(y_tru_long,
-                  y_imp_long,
-                  s=2.,  # 1.5,
-                  alpha=.1,
-                  zorder=2,
-                  color='steelblue',
-                  lw=0.
-                  )
-    for ax_ in [0, 1]:
-        ax[ax_].plot([min_val, max_val],
-                     [min_val, max_val],
-                     linestyle='--',
-                     lw=1.2,
-                     color='gray',
-                     alpha=.55,
-                     zorder=1,
-                     )
-        ax[ax_].set_aspect('equal', adjustable='box')
-        ax[ax_].set_xlim(min_val, max_val)
-        ax[ax_].set_ylim(min_val, max_val)
-        ax[ax_].set_xlabel('true O3 [ppb]')
-        ax[ax_].set_ylabel('imputed O3 [ppb]')
-    fig.set_size_inches(10, 5)
-       
-
-    # print statistics
-    r2 = r2_score(y_tru, y_imp)
-    rmse = (mean_squared_error(y_tru, y_imp))**.5
-    d = index_of_agreement(y_tru, y_imp)
-    print(f'r2: {r2:.2f}')
-    print(f'rmse: {rmse:.2f}')
-    print(f'd: {d:.2f}')
-
-    # save
-    true_vs_imp_path = settings.output_dir + 'true_vs_imp.png'
-    plt.savefig(true_vs_imp_path, dpi=500)
-    print(f'written to {true_vs_imp_path}')
-    plt.close()
-    
-     
-    """
-    with heatmap
-    """
-    # plot style
-    plt.style.use('seaborn-darkgrid')
-    sns.set(rc={'axes.facecolor':'whitesmoke'})
-    
-    # scatter true vs. imputed
-    fig, ax = plt.subplots(1, 2)
-    min_val = -7.
-    max_val = 115
-
-    # Construct 2D histogram from data using the 'plasma' colormap
-    norm = LogNorm(vmin=1, vmax=600) # Adjust here for the scale of the heatmap
-    
-    ax[0].hist2d(y_tru_short, 
-                 y_imp_short, 
-                 bins=(np.arange(min_val, max_val, 1.0), np.arange(min_val, max_val, 1.0)),
-                 norm=norm,
-                 cmap="Blues"
-                )
-    
-    ax[1].hist2d(y_tru_long, 
-                 y_imp_long, 
-                 bins=(np.arange(min_val, max_val, 1.0), np.arange(min_val, max_val, 1.0)),
-                 norm=norm,
-                 cmap="Blues"
-                )
-    
-    for ax_ in [0, 1]:
-        ax[ax_].plot([min_val, max_val],
-                     [min_val, max_val],
-                     linestyle='--',
-                     lw=1.2,
-                     color='gray',
-                     alpha=.55,
-                     zorder=1,
-                     )
-        ax[ax_].set_aspect('equal', adjustable='box')
-        ax[ax_].set_xlim(min_val, max_val)
-        ax[ax_].set_ylim(min_val, max_val)
-        ax[ax_].set_xlabel('true O3 [ppb]')
-        #ax[ax_].set_ylabel('imputed O3 [ppb]')
-        ax[ax_].grid(True)
-    fig.set_size_inches(10, 5)
-    ax[0].set_ylabel('imputed O3 [ppb]')
-    
-    # Plot a colorbar with label.
-    m = cm.ScalarMappable(cmap="Blues",norm=norm)
-    m.set_array([])
-    cb = plt.colorbar(m, ax=ax,shrink=0.4)
-    cb.set_label('Number of entries')
-    
-    # save
-    true_vs_imp_heat_path = settings.output_dir + 'true_vs_imp_heatmap.png'
-    plt.savefig(true_vs_imp_heat_path, dpi=500)
-    print(f'written to {true_vs_imp_heat_path}')
-    plt.close()
+    dicts = [
+             {'identifier': '_single',
+              'corr_list': [False]},
+             {'identifier': '_correlated',
+              'corr_list': [True]},
+             {'identifier': '',
+              'corr_list': [True, False]}
+              ]
+    for dict_ in dicts:
+        print('\n', dict_['identifier'], '\n')
+
+        # prepare scatter true vs. imputed
+        corr_list = dict_['corr_list']
+        df = get_characteristics_df()
+        y_tru_short = df[(df.test_mask) &
+                         (df.gap_len<=5) &
+                         (df.correlated.isin(corr_list))].y_true
+        y_imp_short = df[(df.test_mask) &
+                         (df.gap_len<=5) &
+                         (df.correlated.isin(corr_list))].y_imputed
+        y_tru_long = df[(df.test_mask) &
+                         (df.gap_len>5) &
+                         (df.correlated.isin(corr_list))].y_true
+        y_imp_long = df[(df.test_mask) &
+                         (df.gap_len>5) &
+                         (df.correlated.isin(corr_list))].y_imputed
+        y_tru = df[df.test_mask].y_true
+        y_imp = df[df.test_mask].y_imputed
+
+        # print statistics (summary only for the combined case)
+        for tru, imp, what in [(y_tru_short, y_imp_short, 'short'),
+                               (y_tru_long, y_imp_long, 'long'),
+                               (y_tru, y_imp, 'summary') ]:
+            if (dict_['identifier'] != '') and (what == 'summary'):
+                continue
+            r2 = r2_score(tru, imp)
+            rmse = (mean_squared_error(tru, imp))**.5
+            d = index_of_agreement(tru, imp)
+            print(what)
+            print(f'r2: {r2:.2f}')
+            print(f'rmse: {rmse:.2f}')
+            print(f'd: {d:.2f}\n')
+
+        # plot style
+        plt.style.use('seaborn-darkgrid')
+        sns.set(rc={'axes.facecolor':'whitesmoke'})
+
+        # scatter true vs. imputed
+        fig, ax = plt.subplots(1, 2)
+        min_val = -7.
+        max_val = 115
+
+        # Construct 2D histograms of the data on a logarithmic color scale
+        norm = LogNorm(vmin=1, vmax=600) # Adjust scale of the heatmap
+        colormap = 'Blues_r'  # alternatives: winter, PuBu_r
+
+        ax[0].hist2d(y_tru_short,
+                     y_imp_short,
+                     bins=(np.arange(min_val, max_val, 1.0),
+                           np.arange(min_val, max_val, 1.0)),
+                     norm=norm,
+                     # cmap="Blues_r"
+                     cmap=colormap
+                    )
+
+        ax[1].hist2d(y_tru_long,
+                     y_imp_long,
+                     bins=(np.arange(min_val, max_val, 1.0),
+                           np.arange(min_val, max_val, 1.0)),
+                     norm=norm,
+                     # cmap="Blues_r"
+                     cmap=colormap
+                    )
+
+        for ax_ in [0, 1]:
+            ax[ax_].plot([min_val, max_val],
+                         [min_val, max_val],
+                         linestyle='--',
+                         lw=1.2,
+                         color='gray',
+                         alpha=.55,
+                         zorder=1,
+                         )
+            ax[ax_].set_aspect('equal', adjustable='box')
+            ax[ax_].set_xlim(min_val, max_val)
+            ax[ax_].set_ylim(min_val, max_val)
+            ax[ax_].set_xlabel('true O3 [ppb]')
+            #ax[ax_].set_ylabel('imputed O3 [ppb]')
+            ax[ax_].grid(True)
+        fig.set_size_inches(10, 5)
+        ax[0].set_ylabel('imputed O3 [ppb]')
+
+        # Plot a colorbar with label.
+        m = cm.ScalarMappable(cmap=colormap, norm=norm)
+        m.set_array([])
+        cb = plt.colorbar(m, ax=ax, shrink=0.4)
+        cb.set_label('Number of entries')
+
+        # save
+        id_ = dict_['identifier']
+        true_vs_imp_heat_path = settings.output_dir + \
+                                f'true_vs_imp_heatmap{id_}.png'
+        plt.savefig(true_vs_imp_heat_path, dpi=500)
+        print(f'written to {true_vs_imp_heat_path}')
+        plt.close()
 
 
 def n_neighbors_vs_r2():
@@ -787,14 +764,15 @@ def gap_len_vs_r2():
 
 def imputed_dataset_matrix():
     """
-    Creates a matrix plot of the imputed datset.
+    Creates a matrix plot of the imputed dataset. This plot is used
+    for the graphical abstract / TOC figure.
     """
     print('imputed matrix...')
 
     # read data
     directory = settings.resources_dir + 'imputed_dataset/'
-    filename = 'y_imputed.csv'
-    y_imp_flat = pd.read_csv(directory+filename, index_col=0)
+    filename = 'imputed_o3.csv'
+    y_imp = pd.read_csv(directory+filename, index_col=0)
 
     # for reshape and vmax
     hd = HourlyData('o3')
@@ -805,11 +783,9 @@ def imputed_dataset_matrix():
     vmin = 0.
     vmax = np.nanpercentile(o3_df.values, 99)
 
-    # prepare data
-    y_imp = y_imp_flat.values.reshape(n_stations, n_timesteps)
-
     # plot
-    im = plt.imshow(y_imp, interpolation='none', vmin=vmin, vmax=vmax)
+    im = plt.imshow(y_imp.values.T, interpolation='none',
+                    vmin=vmin, vmax=vmax)
     ax = plt.gca()
     ax.axis('off')
 
@@ -820,6 +796,79 @@ def imputed_dataset_matrix():
     plt.close()
 
 
+def modeled_time_series():
+    """
+    A completely modeled time series, as a proof of concept.
+    """
+    print('modeled time series...')
+
+    # load imputed dataset
+    dir_ = settings.resources_dir + 'imputed_dataset/'
+    o3_file = 'imputed_o3.csv'
+    info_file = 'imputed_info.csv'
+    o3_df = pd.read_csv(dir_+o3_file, index_col=0)
+    info_df = pd.read_csv(dir_+info_file, index_col=0)
+
+    # load station data
+    sd = StationData()
+    sd.read_from_file()
+    station_df = sd.df
+
+    # find stations without data
+    find_stations = False
+    if find_stations:
+        filter_ = info_df.all(axis=0)
+        null_station_list = info_df.columns[filter_].to_list()
+        for station in null_station_list:
+            print(station)
+            print(number_of_neighbors(int(station), return_ids=True))
+    chosen_station_ids = [3625, 3637, 4258]
+
+    # plot info on chosen stations
+    for station_id in chosen_station_ids:
+        n_ngh, _ = number_of_neighbors(station_id)
+        print(station_id)
+        print(n_ngh, 'neighbors')
+        print(sd.df.loc[station_id], '\n')
+
+    # choose time steps to plot
+    # datetimes[5086] = 2011-08-01 00:00:00
+    # datetimes[5830] = 2011-09-01 00:00:00
+    datetimes = o3_df.index.to_list()
+    start_idx = 5086
+    end_idx = 5830
+
+    # plot
+    fig, ax = plt.subplots(3, 1, figsize=(15,7))
+    for plt_idx, station_id in enumerate(chosen_station_ids):
+        data = o3_df[str(station_id)].to_list()
+        _, ngh_dists, ngh_ids = number_of_neighbors(station_id,
+                                                    return_ids=True)
+        ax[plt_idx].set_facecolor('whitesmoke')
+        ax[plt_idx].set_alpha(0.35)
+        ax[plt_idx].plot(data[start_idx:end_idx+1],
+                         color='steelblue',
+                         zorder=2)
+        for ngh_id in ngh_ids:
+            data = o3_df[str(ngh_id)].to_list()
+            ax[plt_idx].plot(data[start_idx:end_idx+1],
+                             color='dimgray',
+                             alpha=0.5,
+                             lw=.4,
+                             zorder=1)
+        ax[plt_idx].set_xticks([])
+        ax[plt_idx].set_xlim(-12, end_idx-start_idx+12)
+        ax[plt_idx].set_ylim(-4, 104)
+    ax[plt_idx].set_xticks(range(0, end_idx-start_idx+1, 24))
+    ax[plt_idx].set_xticklabels([])
+
+    # save figure
+    save_pth = settings.output_dir + 'modeled_time_series.png'
+    plt.savefig(save_pth, bbox_inches='tight', pad_inches=0.2, dpi=800)
+    print(f'saved to {save_pth}')
+    plt.close()
+
+
 if __name__ == '__main__':
     station_loc_on_map_ = False
     station_ids_on_map_ = False
@@ -832,6 +881,7 @@ if __name__ == '__main__':
     n_neighbors_vs_r2_ = False
     gap_len_vs_r2_ = False
     imputed_dataset_matrix_ = False
+    modeled_time_series_ = False
 
     if station_loc_on_map_:
         station_loc_on_map()
@@ -857,4 +907,6 @@ if __name__ == '__main__':
         gap_len_vs_r2()
     if imputed_dataset_matrix_:
         imputed_dataset_matrix()
+    if modeled_time_series_:
+        modeled_time_series()
 
diff --git a/source/visualizations/preanalysis_cams_plots.py b/source/visualizations/preanalysis_cams_plots.py
index 7a634405bcaf60d65423e4b4dde7621fa22522e1..7b70d33d29baf80e169877989b7715a99974c09a 100644
--- a/source/visualizations/preanalysis_cams_plots.py
+++ b/source/visualizations/preanalysis_cams_plots.py
@@ -1,6 +1,6 @@
 """
-This file contains routines that were used in visualise and pre-analyse the CAMS data.
-
+This file contains routines that were used to visualise and
+pre-analyse the CAMS data.
 """
 
 from datetime import datetime, timedelta
@@ -38,44 +38,44 @@ class PlotCAMSData():
         self.hourly_no2_path = self.cams_dir + 'hourly_cams_no2.csv'
         self.hourly_o3_path = self.cams_dir + 'hourly_cams_o3.csv'
         self.hourly_Enox_path = self.cams_dir + 'hourly_cams_Enox.csv'
-        
+
     def read_eac4_ncfile(self):
         # Convert unit from mmr to vmr (ppb)
         no_unit = 28.9644/30.0061*1.e9
-        no2_unit = 28.9644/46.0055*1.e9 
+        no2_unit = 28.9644/46.0055*1.e9
         o3_unit = 28.9644/47.9982*1.e9
-        
+
         ## READ nc file
         eac4_path = self.cams_dir + 'cams_eac4_2011_3hourly.nc'
         nc_data = xr.open_dataset(eac4_path)
-        self.no = nc_data['no'][:,:,:] * no_unit 
-        self.no2 = nc_data['no2'][:,:,:] * no2_unit 
-        self.o3 = nc_data['go3'][:,:,:] * o3_unit 
-        self.lati  = nc_data['latitude'][:]                       
+        self.no = nc_data['no'][:,:,:] * no_unit
+        self.no2 = nc_data['no2'][:,:,:] * no2_unit
+        self.o3 = nc_data['go3'][:,:,:] * o3_unit
+        self.lati  = nc_data['latitude'][:]
         self.loni  = nc_data['longitude'][:]
         self.ti = nc_data['time'][:]
         self.nt = len(self.ti)
-        
+
     def read_emiss_ncfile(self):
         ## READ nc file
         emiss_path = self.cams_dir + 'CAMS-GLOB-ANT_Glb_0.1x0.1_anthro_nox_v5.3_monthly_2011_DE.nc'
         nc_data = xr.open_dataset(emiss_path)
         esum_nox = nc_data['sum'][:,:,:]
-        self.late  = nc_data['lat'][:]                       
+        self.late  = nc_data['lat'][:]
         self.lone  = nc_data['lon'][:]
         self.te = nc_data['time'][:]
-        self.nte = len(self.te)        
+        self.nte = len(self.te)
         self.emiss_nox = esum_nox * 12.0 * 1.e12 / 100.0 / 1.e6 ## Convert unit - from Tg per month to g m−2 yr−1
-    
+
     def read_stn_info(self):
         stn_data = pd.read_csv(self.stn_path)
         self.ids = stn_data['id']
         self.lats = stn_data['lat']
         self.lons = stn_data['lon']
         self.ns = len(self.ids)
-        print(f' No. of stations: {self.ns}')     
+        print(f' No. of stations: {self.ns}')
         print(f' Station IDs:{self.ids}')
-        
+
     def ymean_plot(self,varname, lons, lats, x, y, data):
         vstr=varname
         lon = x
@@ -101,13 +101,13 @@ class PlotCAMSData():
             plt.colorbar(label= vstr+'(g m−2 yr−1)')
         else:
             plt.colorbar(label= vstr+' VMR (ppb)')
-        
+
         # add stations
         xs,ys = m(lons,lats)
         m.plot(xs,ys,'ro',markersize=2)
         plt.savefig(self.plot_out_dir + 'cmap_'+vstr+'_2011avg.png')
         plt.clf()
-    
+
     def plot_hist(self,varname, data, binwidth):
         vstr = varname
         maxvmr = data.max(axis=0)
@@ -140,17 +140,17 @@ class PlotCAMSData():
         self.hourly_no2 = pd.read_csv(self.hourly_no2_path,index_col=0)
         self.hourly_o3 = pd.read_csv(self.hourly_o3_path,index_col=0)
         self.hourly_Enox = pd.read_csv(self.hourly_Enox_path,index_col=0)
-        
+
         no_df = pd.DataFrame(self.hourly_no.to_numpy().flatten())
         no2_df = pd.DataFrame(self.hourly_no2.to_numpy().flatten())
         o3_df = pd.DataFrame(self.hourly_o3.to_numpy().flatten())
         Enox_df = pd.DataFrame(self.hourly_Enox.to_numpy().flatten())
-        
+
         self.plot_hist('no',no_df,10)
         self.plot_hist('no2',no2_df,5)
         self.plot_hist('o3',o3_df,5)
         self.plot_hist('Enox',Enox_df,10)
-       
+
         print(no_df.describe())
         print(no2_df.describe())
         print(o3_df.describe())
@@ -162,7 +162,7 @@ if __name__ == '__main__':
     """
     pcd = PlotCAMSData()
     ## READ ncfile
-    pcd.read_eac4_ncfile()    
+    pcd.read_eac4_ncfile()
     pcd.read_emiss_ncfile()
     ## READ station info
     pcd.read_stn_info()
diff --git a/source/visualizations/preanalysis_plots.py b/source/visualizations/preanalysis_plots.py
index c646b841c8bf28f00b60991681717433304e309b..03f02b602870989fa552131b4afbe36fb9b26616 100644
--- a/source/visualizations/preanalysis_plots.py
+++ b/source/visualizations/preanalysis_plots.py
@@ -1,5 +1,6 @@
 """
-A script for the visualization and analysis of time series available
+This is an old script that was used to visualize data directly from
+the TOAR database. Not meant for reuse.
 """
 
 # general
diff --git a/source/visualizations/stack_overflow.py b/source/visualizations/projection_tryouts.py
similarity index 97%
rename from source/visualizations/stack_overflow.py
rename to source/visualizations/projection_tryouts.py
index 4d49b585ddb9753a2cb3f205d4d90c16fab1bc78..29fd31feb44426c737368bbe57fc473eee45cfff 100644
--- a/source/visualizations/stack_overflow.py
+++ b/source/visualizations/projection_tryouts.py
@@ -1,3 +1,8 @@
+"""
+Projection tryouts for a map of Germany.
+"""
+
+
 import geopandas
 import matplotlib.pyplot as plt
 import cartopy.crs as ccrs
diff --git a/source/visualizations/station_gaps_analysis.py b/source/visualizations/station_gaps_analysis.py
index be9d7625046d27f42bfeb48dc9f141097e592f49..c05ff2964794c3340032aae3bb48ed3ee69b2fa6 100644
--- a/source/visualizations/station_gaps_analysis.py
+++ b/source/visualizations/station_gaps_analysis.py
@@ -1,6 +1,5 @@
 """
-Creates a pytorch Dataset which contains time
-resolved data.
+An old script for analyzing gap lengths. Not meant for reuse.
 """
 
 # general
@@ -37,41 +36,41 @@ class StnGapAnalysis():
     def __init__(self):
         ### LOAD RAW DATASETS ###
         self.stn_data = pd.read_csv(settings.resources_dir + 'time_resolved_raw/' +
-                          'stations.csv', index_col=0)        
+                          'stations.csv', index_col=0)
         self.gap_data = pd.read_csv(settings.resources_dir + 'time_resolved_preproc/' +
                                'gap_2011.csv', index_col=0)  ## Read only 2011 data
         self.reg_data = pd.read_csv(settings.resources_dir + 'time_resolved_preproc/' +
                                'reg.csv', index_col=0)
         print(f'Loaded station data and gap data......')
-    
+
     def select_dates(self):
-        print(f' No. of nodes in total: {len(self.reg_data)}')    
+        print(f' No. of nodes in total: {len(self.reg_data)}')
         data_2011 = self.reg_data[self.reg_data['datetime'].str.contains('2011')] # Select the year 2011
         print(f'data in 2011: {data_2011.head()}')
         print(f'data in 2011: {data_2011.tail()}')
         self.id_2011 = data_2011.index.astype('float')
         print(f' Node index for 2011: {self.id_2011[:5]}')
         print(f' Node index for 2011: {self.id_2011[-1]}')
-        print(f' No. of nodes in 2011: {len(self.id_2011)}')    
-    
+        print(f' No. of nodes in 2011: {len(self.id_2011)}')
+
     def define_large_gap(self):
       ### DEFINE LARGE GAPS AND GAP BINS ###
         real_gap = self.gap_data[(self.gap_data['type']=='missing_o3')]
         print(f'No. of real gaps in total: {len(real_gap)}')
-        
+
         self.large_gap = real_gap[real_gap['len'] > 12.0] # Define large gaps as gaps longer than 12.0 hours
         #self.large_gap = real_gap_2011[real_gap_2011['len'] > 12.0] # Select gaps only in 2011
         print(f'No. of large gaps: {len(self.large_gap)}')
-        # Define gap bin = [12h, 1d, 2d, 3d, 1w, 2w, 1m, 2m, 3m, 6m, 1y 2y] 
+        # Define gap bin = [12h, 1d, 2d, 3d, 1w, 2w, 1m, 2m, 3m, 6m, 1y 2y]
         self.gbin = [12, 24, 48, 72, 168, 336, 720, 1440, 2160, 4320, 8760, 17520, 43824] # define bin for gap length
-        
+
     def define_gap_bins(self):
         real_gap = self.gap_data[(self.gap_data['type']=='missing_o3')]
         self.large_gap = real_gap ## this includes all gaps, not only "large" gap, but just stick to the terminology
-        
-        # Define gap bin = [1h, 2h, 3h, 4h, 6h, 12h, 1d, 2d, 3d, 1w, 2w, 1m, 2m, 3m, 6m, 1y] 
+
+        # Define gap bin = [1h, 2h, 3h, 4h, 6h, 12h, 1d, 2d, 3d, 1w, 2w, 1m, 2m, 3m, 6m, 1y]
         self.gbin = [1, 2, 3, 4, 6, 12, 24, 48, 72, 168, 336, 720, 1440, 2160, 4320, 8760, 17520] # define bin for gap length
-        
+
     def large_gap_hist(self):
         ### PLOT GAP BIN HISTOGRAM ###
         glen = self.large_gap['len']
@@ -84,7 +83,7 @@ class StnGapAnalysis():
         plt.savefig(settings.output_dir + 'gap_len_hist_bins.png')
         plt.clf()
         print(f'Number of gaps per bin: { glen.value_counts(bins=self.gbin) }')
-    
+
     def stn_typs_hist(self):
         ### Analysis station data by station types ###
         # Print counts of stations per type per type of area
@@ -105,7 +104,7 @@ class StnGapAnalysis():
         plt.savefig(settings.output_dir + 'station_types.png')
         plt.clf()
 
-    
+
     def stn_typs_gap_count(self):
         ### Count gaps per bin in each station ###
         # Classify gaps to gap bin (group)
@@ -113,7 +112,7 @@ class StnGapAnalysis():
         self.large_gap['group'] = pd.cut(self.large_gap.len, self.gbin, right=False, labels=labels)
         print(self.large_gap.head())
         #self.large_gap.to_csv(settings.output_dir + 'gap_tobins.csv')
-        # Count no. of gaps per bin in each station 
+        # Count no. of gaps per bin in each station
         stn_ggroup = pd.crosstab(self.large_gap['station_id'],self.large_gap['group'])
         print(stn_ggroup.head())
 
@@ -126,7 +125,7 @@ class StnGapAnalysis():
         stn_ggroup_new = stn_ggroup_data.rename(columns={'station_id': 'id'})
         stn_data_ggroup = pd.merge(self.stn_data, stn_ggroup_new, how='left', on='id')
         print(stn_data_ggroup.head())
-    
+
         #Save to csv output
         stn_data_ggroup.to_csv(settings.output_dir + 'station_fullinfo_gap_length.csv')
         print(f'written to csv file station_fullinfo_gap_length.csv......')
@@ -139,16 +138,16 @@ class StnGapAnalysis():
             # Count no. of gaps per bin per station type
             stn_typ_gap = stn_data_ggroup.groupby(['type','type_of_area'])[gap_name].sum()
             print(stn_typ_gap)
-        
+
             # Plot histogram
             stn_typ_gap.unstack().plot.bar()
             plt.title(f"Gap {i}")
             plt.xlabel("type")
             plt.xticks(rotation=0)
-            plt.ylabel("Number of gaps")    
+            plt.ylabel("Number of gaps")
             gn_str = str(i).zfill(2)
             plt.savefig(settings.output_dir + f"station_types_gap"+gn_str+".png")
-            
+
 
 if __name__ == '__main__':
     """
@@ -161,6 +160,6 @@ if __name__ == '__main__':
     sga.large_gap_hist()
     sga.stn_typs_hist()
     sga.stn_typs_gap_count()
-    
-        
+
+