diff --git a/README.md b/README.md index 5fa4b7c2bc02d5325ce39955af61a0b817884876..5743ea41d3f33fb4d374284c429b25911f0549e7 100644 --- a/README.md +++ b/README.md @@ -1,68 +1,31 @@ -# Air quality mapping with the AQ-Bench dataset - -The goal of this project is to map metadata at station locations to air quality statistics. - -These instructions will get you a copy of the project up and running on your PC. - -## Structure of the project - -This project consists of two parts: -* Obtaining the training dataset from TOAR-DB and JOIN. We call it AQ-Bench for now. -* The mapping part - -## Hyperparameter tuning "Hackathon" - -Get yourself up and ready: -* Download the project from Git -* Run ```source prepare.sh``` for python environment -* Start the Jupyter notebook ```cd source```, ```jupyter notebook``` - -Rules for the Game: -* We provide you with training data (train/dev split as you like) -* Try out hyperparameters -* Submit your best hyper-parameters to be tested with our secret test set -* Best hyper-parameters win the price! - -## Downloading the AQ-Bench dataset - -* We provide the dataset in the data folder of this project. Nevertheless, you can also download it by yourself. -* If would like to download the AQ-Bench dataset, turn on FZJ VPN for TOAR access. -* Create a file ```dataset_dbaccess.py``` in the source directory which contains your credentials for TOAR-DB (if you do not have access to TOAR-db, then just leave '***' for username and password): - -``` -db_user = '****' -db_password = '****' -db_host = 'zam10131.zam.kfa-juelich.de' -db_port = '5432' -db_name = 'surface_observations_toar' -``` -## Resources to describe AQ-Bench - -The resources folder contains .csv files with necessary info to handle the dataset. - -* ```AQbench_variables.csv```: Info on all variables in the dataset -* ```*_cols.csv```: Info for dataset retrieval -* ```climatic_zone.csv```, ```htap_region.csv```, ```climatic_landcover.csv```: Info on decoded variables - -## Run Scripts - -Run ```source run.sh``` to start the interactive script starter. You may choose from various options: - -* ```prepare ``` - * Creates folders for logs (where your log files are stored), data (where the dataset is stored) and plots (where your plots will be stored) - * Creates and activates the mapping environment -* ```test ``` - * Starts all tests in the test folder -* ```retrieval ``` - * Starts the dataset retrieval from TOAR-DB and JOIN -* ```sanitycheck ``` - * Carries out a sanitycheck for your dataset -* ```preanalysis ``` - * Preliminary analysis of dataset statistics - * Visualisation of missing values -* ```mapping ``` - * Mapping of the dataset (multi layer perceptron) - * Mapping of the dataset (random forest) +<img src="/doc/graphical_abstract.png" alt="graphical_abstract" + title="Graphical abstract" width="600" height="146" /> + +# Machine learning on the AQ-Bench dataset + +This repository enables a machine learning quickstart on the AQ-Bench dataset. + +The AQ-Bench Benchmark dataset is described in Betancourt et al. (manuscript): "AQ-Bench: A Benchmark Dataset for Machine Learning on Global Air Quality Metrics" (link follows) + +## Quickstart + +Run it on binder! Click on the badge below to start machine learning on AQ-Bench in your browser (might take a couple of minutes to launch). + +[](https://mybinder.org/v2/git/https%3A%2F%2Fgitlab.version.fz-juelich.de%2Ftoar%2Fozone-mapping/master?filepath=source%2Fintroduction_jupyter.ipynb) + + + +## Get the project running on your PC + +* Prerequisite: Conda or MiniConda with Python 3.6 +* Use ```environment.yml``` to create an environment, then activate it by prompting ```source activate aqbench``` +* Navigate to ```source``` and start the ```introduction_jupyter.ipynp``` by prompting ```jupyter notebook``` + +## Structure of the repository + +* ```resources``` contains the data +* ```source``` contains the scripts +* ```doc``` contains documentation files ## Authors diff --git a/apt.txt b/apt.txt deleted file mode 100644 index b4d31fa80bd4a0fcb5a9d64f27914951ae932746..0000000000000000000000000000000000000000 --- a/apt.txt +++ /dev/null @@ -1 +0,0 @@ -libgeos-dev diff --git a/doc/graphical_abstract.png b/doc/graphical_abstract.png new file mode 100644 index 0000000000000000000000000000000000000000..de2b94001563a26c4ad822d40254f1968a629cd6 Binary files /dev/null and b/doc/graphical_abstract.png differ diff --git a/doc/launch-binder.svg b/doc/launch-binder.svg new file mode 100644 index 0000000000000000000000000000000000000000..e3b14fbcc34d6d883f77369fe503b8a1a735258a --- /dev/null +++ b/doc/launch-binder.svg @@ -0,0 +1 @@ +<svg xmlns="http://www.w3.org/2000/svg" xmlns:xlink="http://www.w3.org/1999/xlink" width="109" height="20" role="img" aria-label="launch: binder"><title>launch: binder</title><linearGradient id="s" x2="0" y2="100%"><stop offset="0" stop-color="#bbb" stop-opacity=".1"/><stop offset="1" stop-opacity=".1"/></linearGradient><clipPath id="r"><rect width="109" height="20" rx="3" fill="#fff"/></clipPath><g clip-path="url(#r)"><rect width="64" height="20" fill="#555"/><rect x="64" width="45" height="20" fill="#579aca"/><rect width="109" height="20" fill="url(#s)"/></g><g fill="#fff" text-anchor="middle" font-family="Verdana,Geneva,DejaVu Sans,sans-serif" text-rendering="geometricPrecision" font-size="110"><image x="5" y="3" width="14" height="14" xlink:href="data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAFkAAABZCAMAAABi1XidAAAB8lBMVEX///9XmsrmZYH1olJXmsr1olJXmsrmZYH1olJXmsr1olJXmsrmZYH1olL1olJXmsr1olJXmsrmZYH1olL1olJXmsrmZYH1olJXmsr1olL1olJXmsrmZYH1olL1olJXmsrmZYH1olL1olL0nFf1olJXmsrmZYH1olJXmsq8dZb1olJXmsrmZYH1olJXmspXmspXmsr1olL1olJXmsrmZYH1olJXmsr1olL1olJXmsrmZYH1olL1olLeaIVXmsrmZYH1olL1olL1olJXmsrmZYH1olLna31Xmsr1olJXmsr1olJXmsrmZYH1olLqoVr1olJXmsr1olJXmsrmZYH1olL1olKkfaPobXvviGabgadXmsqThKuofKHmZ4Dobnr1olJXmsr1olJXmspXmsr1olJXmsrfZ4TuhWn1olL1olJXmsqBi7X1olJXmspZmslbmMhbmsdemsVfl8ZgmsNim8Jpk8F0m7R4m7F5nLB6jbh7jbiDirOEibOGnKaMhq+PnaCVg6qWg6qegKaff6WhnpKofKGtnomxeZy3noG6dZi+n3vCcpPDcpPGn3bLb4/Mb47UbIrVa4rYoGjdaIbeaIXhoWHmZYHobXvpcHjqdHXreHLroVrsfG/uhGnuh2bwj2Hxk17yl1vzmljzm1j0nlX1olL3AJXWAAAAbXRSTlMAEBAQHx8gICAuLjAwMDw9PUBAQEpQUFBXV1hgYGBkcHBwcXl8gICAgoiIkJCQlJicnJ2goKCmqK+wsLC4usDAwMjP0NDQ1NbW3Nzg4ODi5+3v8PDw8/T09PX29vb39/f5+fr7+/z8/Pz9/v7+zczCxgAABC5JREFUeAHN1ul3k0UUBvCb1CTVpmpaitAGSLSpSuKCLWpbTKNJFGlcSMAFF63iUmRccNG6gLbuxkXU66JAUef/9LSpmXnyLr3T5AO/rzl5zj137p136BISy44fKJXuGN/d19PUfYeO67Znqtf2KH33Id1psXoFdW30sPZ1sMvs2D060AHqws4FHeJojLZqnw53cmfvg+XR8mC0OEjuxrXEkX5ydeVJLVIlV0e10PXk5k7dYeHu7Cj1j+49uKg7uLU61tGLw1lq27ugQYlclHC4bgv7VQ+TAyj5Zc/UjsPvs1sd5cWryWObtvWT2EPa4rtnWW3JkpjggEpbOsPr7F7EyNewtpBIslA7p43HCsnwooXTEc3UmPmCNn5lrqTJxy6nRmcavGZVt/3Da2pD5NHvsOHJCrdc1G2r3DITpU7yic7w/7Rxnjc0kt5GC4djiv2Sz3Fb2iEZg41/ddsFDoyuYrIkmFehz0HR2thPgQqMyQYb2OtB0WxsZ3BeG3+wpRb1vzl2UYBog8FfGhttFKjtAclnZYrRo9ryG9uG/FZQU4AEg8ZE9LjGMzTmqKXPLnlWVnIlQQTvxJf8ip7VgjZjyVPrjw1te5otM7RmP7xm+sK2Gv9I8Gi++BRbEkR9EBw8zRUcKxwp73xkaLiqQb+kGduJTNHG72zcW9LoJgqQxpP3/Tj//c3yB0tqzaml05/+orHLksVO+95kX7/7qgJvnjlrfr2Ggsyx0eoy9uPzN5SPd86aXggOsEKW2Prz7du3VID3/tzs/sSRs2w7ovVHKtjrX2pd7ZMlTxAYfBAL9jiDwfLkq55Tm7ifhMlTGPyCAs7RFRhn47JnlcB9RM5T97ASuZXIcVNuUDIndpDbdsfrqsOppeXl5Y+XVKdjFCTh+zGaVuj0d9zy05PPK3QzBamxdwtTCrzyg/2Rvf2EstUjordGwa/kx9mSJLr8mLLtCW8HHGJc2R5hS219IiF6PnTusOqcMl57gm0Z8kanKMAQg0qSyuZfn7zItsbGyO9QlnxY0eCuD1XL2ys/MsrQhltE7Ug0uFOzufJFE2PxBo/YAx8XPPdDwWN0MrDRYIZF0mSMKCNHgaIVFoBbNoLJ7tEQDKxGF0kcLQimojCZopv0OkNOyWCCg9XMVAi7ARJzQdM2QUh0gmBozjc3Skg6dSBRqDGYSUOu66Zg+I2fNZs/M3/f/Grl/XnyF1Gw3VKCez0PN5IUfFLqvgUN4C0qNqYs5YhPL+aVZYDE4IpUk57oSFnJm4FyCqqOE0jhY2SMyLFoo56zyo6becOS5UVDdj7Vih0zp+tcMhwRpBeLyqtIjlJKAIZSbI8SGSF3k0pA3mR5tHuwPFoa7N7reoq2bqCsAk1HqCu5uvI1n6JuRXI+S1Mco54YmYTwcn6Aeic+kssXi8XpXC4V3t7/ADuTNKaQJdScAAAAAElFTkSuQmCC"/><text aria-hidden="true" x="415" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="370">launch</text><text x="415" y="140" transform="scale(.1)" fill="#fff" textLength="370">launch</text><text aria-hidden="true" x="855" y="150" fill="#010101" fill-opacity=".3" transform="scale(.1)" textLength="350">binder</text><text x="855" y="140" transform="scale(.1)" fill="#fff" textLength="350">binder</text></g></svg> \ No newline at end of file diff --git a/prepare.sh b/prepare.sh deleted file mode 100755 index 27db8c99aecee491c136f8de0cb867b158ef02d5..0000000000000000000000000000000000000000 --- a/prepare.sh +++ /dev/null @@ -1,61 +0,0 @@ -#!/usr/bin/env bash - -# 2020-11-05 this preparation script should work for the current version of -# AQ-bench -# venv is created in this directory. - -# check if we are really in the ozone-mapping directory -S="ozone-mapping" -if [[ $(pwd) == *$S ]] -then - echo "Prepare..." -else - echo "You are not in the ozone-mapping directory. Abort" - return -fi - -# check if file was sourced -if [[ ${BASH_SOURCE[0]} == ${0} ]]; then - echo ERROR: 'prepare.sh' must be sourced. - echo Execute by prompting "'source prepare.sh'" - exit 1 -fi - -# create logs dir in project directory -mkdir -p log - -# create plots dir in project directory -mkdir -p output - - -# create virtual python environment -ENV_NAME=venv -ENV_DIR=./${ENV_NAME} -unset PYTHONPATH -python3 -m venv ./$ENV_NAME -source ${ENV_DIR}/bin/activate - -# python packages -pip3 install --upgrade pip -pip3 install pandas -pip3 install psycopg2-binary # same functionality as psycopg2 but works better - # with UDFs -pip3 install -U scikit-learn -pip3 install pytest -pip3 install missingno -pip3 install tensorflow-cpu -pip3 install torch torchvision -pip3 install jupyterlab -pip3 install notebook -pip3 install ipywidgets - - -# basemap -sudo apt-get install libgeos-3.5.0 -sudo apt-get install libgeos-dev -pip install https://github.com/matplotlib/basemap/archive/master.zip - -# add current directory to pythonpath -CWD=$(pwd) -export PYTHONPATH=$CWD:$PYTHONPATH -export PYTHONPATH=$CWD"/source/":$PYTHONPATH diff --git a/run.sh b/run.sh deleted file mode 100755 index 8115225241c1232b6313cd5787962a1f02353939..0000000000000000000000000000000000000000 --- a/run.sh +++ /dev/null @@ -1,2 +0,0 @@ -cd source -jupyter notebook diff --git a/source/dataset_mapplot.py b/source/dataset_mapplot.py index 617e0595d3a41914a05ddc3f7c5f7f469bca0708..144aa2c0b02997fab8811f1ae9398ed228df1382 100644 --- a/source/dataset_mapplot.py +++ b/source/dataset_mapplot.py @@ -1,4 +1,3 @@ -import numpy as np import pandas as pd import matplotlib.pyplot as plt import cartopy @@ -19,12 +18,12 @@ class MapPlot: # name of the file where data is stored / read from self.data = pd.read_csv(resources_dir + AQbench_dataset_file) self.available_colors =\ - ['red', 'blue', 'green', 'yellow', 'cyan', 'magenta', - 'orangered', 'sienna', 'chocolate', 'orange', 'goldenrod', + ['orangered', 'turquoise', 'green', 'yellow', 'cyan', 'magenta', + 'hotpink', 'sienna', 'chocolate', 'orange', 'goldenrod', 'gold', 'olive', 'yellowgreen', 'lawngreen', 'limegreen', - 'springgreen', 'turquoise', 'teal', 'deepskyblue', 'dodgerblue', + 'springgreen', 'blue', 'teal', 'deepskyblue', 'dodgerblue', 'navy', 'mediumslateblue', 'blueviolet', 'violet', - 'mediumvioletred', 'deeppink', 'hotpink'] + 'mediumvioletred', 'deeppink', 'red'] self.available_markers =\ ['o', 's', 'D', '*', 'P', 'p', 'd', 'X', 'v', '^', '<', '>'] @@ -37,7 +36,8 @@ class MapPlot: def plot(self, annotate=False): ax = plt.axes(projection=cartopy.crs.PlateCarree()) ax.add_feature(cartopy.feature.COASTLINE) - ax.add_feature(cartopy.feature.BORDERS, linestyle=':') + # No borders due to political reasons + # ax.add_feature(cartopy.feature.BORDERS, linestyle=':') ax.add_feature(cartopy.feature.LAND) ax.add_feature(cartopy.feature.OCEAN) @@ -48,10 +48,10 @@ class MapPlot: x, y = lons, lats ax.scatter(x, y, - s=50, + s=12, color=self.colors[i], marker=self.markers[i], - alpha=0.5, + alpha=0.7, edgecolor='black', label=self.labels[i], zorder=2) @@ -125,7 +125,7 @@ class MapPlot: self.data_plot = [train_pd, val_pd, test_pd] self.labels = ['training set', 'validation set', 'test set'] self.colors = self.available_colors[:len(self.labels)] - self.markers = self.available_markers[:len(self.labels)] + self.markers = self.available_markers[0] * 3 # [:len(self.labels)] self.plot() @@ -177,8 +177,7 @@ class MapPlot: label=self.labels[i], zorder=2) # plt.legend() - plt.savefig(output_dir + 'all_toar.png', dpi = 500) - + plt.savefig(output_dir + 'all_toar.png', dpi=500) if __name__ == '__main__': diff --git a/source/dataset_preanalysis.py b/source/dataset_preanalysis.py index 0af621e83ab8d77b5e1aca0ae541e7a17bf443be..6ce5f212d5389f1ab0432371f2ec679428ac3772 100644 --- a/source/dataset_preanalysis.py +++ b/source/dataset_preanalysis.py @@ -18,7 +18,7 @@ import missingno as msno from utils import read_csv_to_df # settings -from settings import resources_dir, AQbench_variables_file, log_dir, \ +from settings import resources_dir, AQbench_variables_file, \ AQbench_dataset_file, output_dir # info on this file @@ -234,7 +234,7 @@ class PreMis: """ plot = msno.matrix(self.data, fontsize=8, labels=True) fig = plot.get_figure() - fig.savefig(self.plot_dir+'missingno_matrix.png', dpi=500) + fig.savefig(self.plot_dir+'missingno_matrix.png', dpi=100) def missingno_heatmap(self): """ @@ -242,7 +242,7 @@ class PreMis: """ plot = msno.heatmap(self.data, fontsize=8, cmap='coolwarm') fig = plot.get_figure() - fig.savefig(self.plot_dir+'missingno_heatmap.png', dpi=500) + fig.savefig(self.plot_dir+'missingno_heatmap.png', dpi=100) class PreMap: @@ -287,7 +287,6 @@ def main_previs(): previs.vis(col) plt.close() except Exception as exc: - # logging.warning(col + ' no plot produced') print(exc) @@ -302,29 +301,9 @@ def main_premis(): premis.missingno_heatmap() -def main_premap(): - """ - Two plots: - 1) small map with dots of color of metric (relief or satellite) for - graphical abstract - 2) world map (non relief) with one equally colored dot for every TOAR - station to clarify all stations - """ - pass - - if __name__ == '__main__': """ Start routine """ - - # logging - # log_file = __file__.replace('py', 'log') - - # logging.basicConfig( - # level=logging.INFO, - # format="%(asctime)s [%(levelname)s] %(message)s", - # handlers=[logging.FileHandler(log_dir+log_file), - # logging.StreamHandler()]) main_previs() main_premis() diff --git a/source/introduction_jupyter.ipynb b/source/introduction_jupyter.ipynb index d0f0c9ff75cc2e6b82e254c7e69568144fbdc081..1f1d810f3acc1617681d95c17e7e23eb0fe0199c 100644 --- a/source/introduction_jupyter.ipynb +++ b/source/introduction_jupyter.ipynb @@ -6,7 +6,7 @@ "source": [ "# Ozone Mapping Introduction\n", "\n", - "The following Jupyter notebook will introduce you to the AQ-Bench data set, its' data preprocessing and three different baseline experiments: linear regression, neural network and random forest. Let's first import some modules and do some settings." + "The following Jupyter notebook will introduce you to the AQ-Bench data set, its data preprocessing and three different baseline experiments: linear regression, neural network and random forest. Let's first import some modules and setup." ] }, { @@ -18,6 +18,7 @@ "import numpy as np\n", "from matplotlib import pyplot as plt\n", "import pandas as pd\n", + "from importlib import reload\n", "\n", "import ipywidgets as widgets\n", "from IPython.display import display, clear_output, display_html\n", @@ -44,7 +45,7 @@ "source": [ "## AQ-Bench data set\n", "\n", - "Let's have a look to the AQ-Bench data set. You see several meta data and ozone metrics - each of them has its' own collumn. Each row shows data from a different station. It is linked to a unique ID." + "Let's have a look to the AQ-Bench data set. You see several meta data and ozone metrics - each of them has its own column. Each row shows data from a different station. It is linked to a unique ID." ] }, { @@ -82,7 +83,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "If you like to get information about the variables' distributions, feel free to play around with the following widget. Just choose a variable and a plot, which describes the variable, will emerge automatically:" + "If you like to get information about the variables' distributions, feel free to play around with the following widget. Just choose a variable and a logarithmic histogram, which describes the variable, will emerge automatically:" ] }, { @@ -108,7 +109,7 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "The data set has some missing values. Common machine learning algorithms cannot handle this so that rows with missing values will be dropped later. The following two plots (matrix and heatmap) will give you an overview of missing data:" + "The data set has some missing values. Common machine learning algorithms cannot handle this so that rows with missing values will be dropped later. The following plot will give you an overview of missing data:" ] }, { @@ -122,7 +123,6 @@ "premis = PreMis(resources_dir + AQbench_dataset_file, resources_dir + AQbench_variables_file, output_dir)\n", "premis.fill_nan()\n", "premis.missingno_matrix()\n", - "premis.missingno_heatmap()\n", "plt.show()" ] }, @@ -159,7 +159,9 @@ "source": [ "## Datasplit\n", "\n", - "For the baseline experiments, you are going to perform later, the following data split will be used:" + "For the baseline experiments, you are going to perform later, the following data split will be used:\n", + "\n", + "(The plot is interactive, zoom in unsing the buttons below)" ] }, { @@ -168,6 +170,7 @@ "metadata": {}, "outputs": [], "source": [ + "reload(plt)\n", "%matplotlib notebook\n", "datasplit = DataSplit()\n", "datasplit.read_datasplit()\n", @@ -234,13 +237,6 @@ "source": [ "%run mapping_jupyter.ipynb" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { @@ -259,7 +255,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.9" + "version": "3.7.8" } }, "nbformat": 4, diff --git a/source/mapping_linear_regression.py b/source/mapping_linear_regression.py index fada8f31dbf5e85be701b28274037454b105ce6c..220ce1a011971ea83346ab1ca31506cc6e12077b 100644 --- a/source/mapping_linear_regression.py +++ b/source/mapping_linear_regression.py @@ -1,7 +1,4 @@ # data science -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt import sklearn.linear_model # settings diff --git a/source/mapping_neural_network.py b/source/mapping_neural_network.py index e90313bcc3bfc9ef0a21980220ce404545f8e6b6..1ab65cd209cb2da26b1ffc8464a62b440cc0fef7 100644 --- a/source/mapping_neural_network.py +++ b/source/mapping_neural_network.py @@ -17,7 +17,8 @@ __author__ = 'Timo Stomberg' class NeuralNetwork(Model): """ - In this class the main properties and functions are defined to define and test + In this class the main properties and + functions are defined to setup and test neural network models. Input: data object of class Data. """ @@ -186,7 +187,8 @@ class NeuralNetwork(Model): f'L2 \u03BB = {self.l2}, ' \ f'batch size = {self.batch_size}\n' - plot_name = str(self.__class__.__name__) + '_' + naming + '_' + datetime_.strftime('%Y-%m-%d_%H-%M-%S_') + str(np.random.randint(10000, 99999)) + plot_name = str(self.__class__.__name__) + '_' + naming + '_' + \ + datetime_.strftime('%Y-%m-%d_%H-%M-%S_') + str(np.random.randint(10000, 99999)) # Define plot. diff --git a/source/mapping_random_forest.py b/source/mapping_random_forest.py index 8d6d9209d976a5d3b34e7fc7529d24db030ed125..9f7bdad8abbe5ad1c4f677160e3a86b3e0ec4072 100644 --- a/source/mapping_random_forest.py +++ b/source/mapping_random_forest.py @@ -1,12 +1,7 @@ -# data science -import numpy as np -import pandas as pd -import matplotlib.pyplot as plt - -import sklearn.linear_model +# machine learning from sklearn.ensemble import RandomForestRegressor -# settings +# model class from mapping_model import Model __author__ = 'Timo Stomberg' @@ -21,7 +16,8 @@ class RandomForest(Model): Model.__init__(self) def define(self): - self.model = RandomForestRegressor(n_estimators=100, random_state=0, verbose=0, n_jobs=-1) + self.model = RandomForestRegressor(n_estimators=100, random_state=0, + verbose=0, n_jobs=-1) def predict(self): self.pred_train = self.model.predict(self.data.x_train) diff --git a/source/mapping_tensorflow.py b/source/mapping_tensorflow.py index 2b8c7da7f9c7136a7ecf87564be52cb85c72de86..6966fcccdc211c30b2f364e4ee4f85028f7a034d 100644 --- a/source/mapping_tensorflow.py +++ b/source/mapping_tensorflow.py @@ -33,17 +33,21 @@ class NNTensorflow(NeuralNetwork): # other layers for units in self.hidden_layers[1:]: self.model.add(keras.layers.Dense( - units=units, activation=self.activation, kernel_initializer=initializer, + units=units, activation=self.activation, + kernel_initializer=initializer, kernel_regularizer=regularizer)) # last layer self.model.add(keras.layers.Dense( - units=1, activation=None, kernel_initializer=initializer, kernel_regularizer=regularizer)) + units=1, activation=None, kernel_initializer=initializer, + kernel_regularizer=regularizer)) # backpropagation with Adam - optimizer = keras.optimizers.Adam(learning_rate=self.learning_rate, epsilon=1e-08) + optimizer = keras.optimizers.Adam(learning_rate=self.learning_rate, + epsilon=1e-08) - self.model.compile(optimizer=optimizer, loss=self.loss, metrics=['mse', 'mae']) + self.model.compile(optimizer=optimizer, loss=self.loss, + metrics=['mse', 'mae']) def print_model(self): """ @@ -65,11 +69,14 @@ class NNTensorflow(NeuralNetwork): def evaluate(self): - self.eval_train['loss'] = self.model.evaluate(x=self.data.x_train, y=self.data.y_train)[0] - self.eval_val['loss'] = self.model.evaluate(x=self.data.x_val, y=self.data.y_val)[0] + self.eval_train['loss'] = self.model.evaluate(x=self.data.x_train, + y=self.data.y_train)[0] + self.eval_val['loss'] = self.model.evaluate(x=self.data.x_val, + y=self.data.y_val)[0] if self.analyze_test: - self.eval_test['loss'] = self.model.evaluate(x=self.data.x_test, y=self.data.y_test)[0] + self.eval_test['loss'] = self.model.evaluate(x=self.data.x_test, + y=self.data.y_test)[0] NeuralNetwork.evaluate(self) @@ -83,8 +90,8 @@ class NNTensorflow(NeuralNetwork): print("\n--- Training ---\n") fit = self.model.fit( self.data.x_train, self.data.y_train, batch_size=self.batch_size, - epochs=self.epochs, validation_data=(self.data.x_val, self.data.y_val), - verbose=2) + epochs=self.epochs, validation_data=(self.data.x_val, + self.data.y_val), verbose=2) # Save all losses, mse and mae during fitting. diff --git a/source/settings.py b/source/settings.py index 165ee32a8f21ec5a7aa9cc50312b649a80a3142d..217e9666bd44fbd8d280f4b311486cc0817c2847 100644 --- a/source/settings.py +++ b/source/settings.py @@ -4,7 +4,6 @@ and data files that we use in our project. """ # imports -import sys import os import pathlib @@ -14,9 +13,6 @@ CWD = os.getcwd() CWD_pos = pathlib.Path(CWD) ROOTDIR = str(CWD_pos.parent) -# log file directory -log_dir = ROOTDIR + '/log/' - # for plots, evaluation etc... output_dir = ROOTDIR + '/output/' @@ -31,4 +27,3 @@ AQbench_dataset_file = 'AQbench_dataset.csv' AQbench_variables_file = 'AQbench_variables.csv' datasplit_file = 'datasplit.csv' hyperparameters_file = 'hyperparameters.csv' - diff --git a/source/utils.py b/source/utils.py index 9433fcbc58d03e2c40c1979920baa1c3825e97f7..68cec4ffef5dab0513067d26fcbcc5bf2b0ed8f1 100644 --- a/source/utils.py +++ b/source/utils.py @@ -1,19 +1,9 @@ """ Helper functions for our project - -... Requires logging already set up! """ -# general -# import logging -import pdb - # data science import pandas as pd -import json - -# settings -from settings import * def read_csv_to_df(filename, converters=None): @@ -28,7 +18,6 @@ def read_csv_to_df(filename, converters=None): df = pd.read_csv(filename, converters=converters) return df except Exception as exc: - # logging.error(f'Error reading csv file {filename}') print(exc) exit() @@ -41,8 +30,8 @@ def read_pkl_to_df(filename): df = pd.read_pickle(filename) return df except Exception as exc: - # logging.error(f'Error reading pickle {filename}') print(exc) + exit() def save_data_to_file(df, filename): @@ -52,13 +41,10 @@ def save_data_to_file(df, filename): try: if filename.endswith('pkl'): df.to_pickle(filename) - # logging.info(f'Data saved to pickle {filename}') elif filename.endswith('csv'): df.to_csv(filename, index=False) - # logging.info(f'Data saved to csv {filename}') except Exception as exc: - # logging.warning(f'Could not write file {filename}') print(exc) @@ -67,4 +53,3 @@ if __name__ == '__main__': Tryouts """ pass -