From d3011f622b4c2334ef995080f6d3e4a4f6591747 Mon Sep 17 00:00:00 2001
From: Felix Kleinert <f.kleinert@fz-juelich.de>
Date: Thu, 24 Sep 2020 10:11:45 +0200
Subject: [PATCH] update examples

---
 Examples_from_manuscript.ipynb | 12 ++++++++----
 1 file changed, 8 insertions(+), 4 deletions(-)

diff --git a/Examples_from_manuscript.ipynb b/Examples_from_manuscript.ipynb
index 5c89cd9a..dd258554 100644
--- a/Examples_from_manuscript.ipynb
+++ b/Examples_from_manuscript.ipynb
@@ -105,12 +105,12 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# Figure 5\n",
     "import keras\n",
     "from keras.losses import mean_squared_error as mse\n",
     "from keras.layers import PReLU, Input, Conv2D, Flatten, Dropout, Dense\n",
     "\n",
     "from mlair.model_modules import AbstractModelClass\n",
+    "from mlair.workflows import DefaultWorkflow\n",
     "\n",
     "class MyCustomisedModel(AbstractModelClass):\n",
     "\n",
@@ -129,14 +129,14 @@
     "        self.set_custom_objects(loss=self.compile_options['loss'])\n",
     "\n",
     "    def set_model(self):\n",
-    "        x_input = Input(shape=self.shape_inputs)\n",
+    "        x_input = Input(shape=self._input_shape)\n",
     "        x_in = Conv2D(4, (1, 1))(x_input)\n",
     "        x_in = PReLU()(x_in)\n",
     "        x_in = Flatten()(x_in)\n",
     "        x_in = Dropout(0.1)(x_in)\n",
     "        x_in = Dense(16)(x_in)\n",
     "        x_in = PReLU()(x_in)\n",
-    "        x_in = Dense(self.shape_outputs)(x_in)\n",
+    "        x_in = Dense(self._output_shape)(x_in)\n",
     "        out = PReLU()(x_in)\n",
     "        self.model = keras.Model(inputs=x_input, outputs=[out])\n",
     "\n",
@@ -144,7 +144,11 @@
     "        self.initial_lr = 1e-2\n",
     "        self.optimizer = keras.optimizers.SGD(lr=self.initial_lr, momentum=0.9)\n",
     "        self.loss = mse\n",
-    "        self.compile_options = {\"metrics\": [\"mse\", \"mae\"]}\n"
+    "        self.compile_options = {\"metrics\": [\"mse\", \"mae\"]}\n",
+    "\n",
+    "# Make use of MyCustomisedModel within the DefaultWorkflow\n",
+    "workflow = DefaultWorkflow(model=MyCustomisedModel, epochs=2)\n",
+    "workflow.run()\n"
    ]
   },
   {
-- 
GitLab