diff --git a/projects/asos/21_train.py b/projects/asos/21_train.py
index 4ee2756bea7f7f82b12688d65bb4c4159ef3fe4e..f1a12df6279982ab5f5587b939859d5ef8bf9407 100644
--- a/projects/asos/21_train.py
+++ b/projects/asos/21_train.py
@@ -26,6 +26,7 @@ if config.dataset in ['anthroprotect', 'places']:
         'unet_mode': 'bilinear',  # standard UNet has None, we use 'bilinear'
         'unet_activation': nn.Tanh(),
 
+        'dropout': None,  # standard UNet has None
         'final_activation': nn.Sigmoid(),  # nn.Sigmoid() or nn.Softmax(dim=1)
     }
 
diff --git a/projects/asos/22_load_model_and_data.ipynb b/projects/asos/22_load_model_and_data.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..56cf1811d8083331a079d9a894243f71da56c40c
--- /dev/null
+++ b/projects/asos/22_load_model_and_data.ipynb
@@ -0,0 +1,100 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Load Model and Data\n",
+    "\n",
+    "If you want to run your own experiments with the model, you can simply load it as follows."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from projects.asos import utils\n",
+    "\n",
+    "%load_ext autoreload\n",
+    "%autoreload 2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# load model and data\n",
+    "\n",
+    "trainer = utils.load_trainer()\n",
+    "\n",
+    "model = trainer.model\n",
+    "datamodule = trainer.datamodule\n",
+    "\n",
+    "test_dataset = datamodule.test_dataset\n",
+    "test_dataloader = datamodule.get_dataloader('test')"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# run single sample\n",
+    "\n",
+    "index = 0\n",
+    "sample = test_dataset[index]\n",
+    "print('file:', sample['file'])\n",
+    "\n",
+    "pred = model(sample['x'].unsqueeze(0))\n",
+    "print('prediction:', pred)\n",
+    "print('label:', sample['y'])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# run single batch\n",
+    "\n",
+    "batch = next(iter(test_dataloader))\n",
+    "pred = model(batch['x'])\n",
+    "pred"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/projects/asos/34_analyze_samples.ipynb b/projects/asos/34_analyze_samples.ipynb
index e3678382a32941ac484089c6b1a9635e867892a0..051bc9187714b430d6437a003d43fc04de1a7d92 100644
--- a/projects/asos/34_analyze_samples.ipynb
+++ b/projects/asos/34_analyze_samples.ipynb
@@ -65,7 +65,7 @@
     "# get files from given dataset\n",
     "fi = utils.load_file_infos()\n",
     "df = fi.df\n",
-    "df = df[df['datasplit'].isin(datasets)]\n",
+    "df = df[df['dataset'].isin(datasets)]\n",
     "print(f'# files: {len(df)}')\n",
     "files = df.index\n",
     "\n",
diff --git a/projects/asos/modules.py b/projects/asos/modules.py
index 6630f9675b51151b4058cf5aca24dee9b58e3e50..7bb167fb4f464186be53554a01f19d439020b70e 100644
--- a/projects/asos/modules.py
+++ b/projects/asos/modules.py
@@ -24,7 +24,7 @@ class Model(ttorch.model.Module):
     :param final_activation: activation function of the classifier; can be nn.Sigmoid() or nn.Softmax(dim=1)
     """
 
-    def __init__(
+    def tinit(
             self,
 
             in_channels: int,
@@ -37,13 +37,10 @@ class Model(ttorch.model.Module):
             unet_mode: str,  # 'bilinear', 'nearest' or None
             unet_activation,  # e.g. nn.Tanh()
 
+            dropout: float,  # standard UNet has None
             final_activation,  # nn.Sigmoid() or nn.Softmax(dim=1)
     ):
 
-        super().__init__()
-
-        self.final_activation = final_activation
-
         # unet
         
         unet_kwargs = {
@@ -52,6 +49,7 @@ class Model(ttorch.model.Module):
             'base_channels': unet_base_channels,
             'batch_norm': batch_norm,
             'double_conv': double_conv,
+            'dropout': dropout,
             'final_activation': unet_activation,
         }
 
@@ -65,21 +63,28 @@ class Model(ttorch.model.Module):
         self.random_occlusion = ttorch.modules.operations.RandomPixelOcclusion(probability=0.2)
 
         # classifier
-        
+
+        dropout = dropout if dropout is not None else 0
+
         self.classifier = nn.Sequential(
+            #  nn.Dropout(p=dropout), do not run because of random occlusions
             nn.Conv2d(n_unet_maps, 2 * n_unet_maps, kernel_size=5, stride=3),
             nn.ReLU(inplace=True),
 
+            nn.Dropout(p=dropout),
             nn.Conv2d(2 * n_unet_maps, 4 * n_unet_maps, kernel_size=5, stride=3),
             nn.ReLU(inplace=True),
 
+            nn.Dropout(p=dropout),
             nn.Conv2d(4 * n_unet_maps, 8 * n_unet_maps, kernel_size=5, stride=3),
             nn.ReLU(inplace=True),
 
             nn.Flatten(),
+            nn.Dropout(p=dropout),
             nn.Linear(512 * n_unet_maps, 128),
             nn.ReLU(inplace=True),
             
+            nn.Dropout(p=dropout),
             nn.Linear(128, n_classes),
         )
 
@@ -113,7 +118,7 @@ class Model(ttorch.model.Module):
 
         pred = self.classifier(unet_map)
 
-        if self.final_activation is not None:
-            pred = self.final_activation(pred)
+        if self.hparams['final_activation'] is not None:
+            pred = self.hparams['final_activation'](pred)
 
         return pred
diff --git a/setup.py b/setup.py
index ea44994fcf48805b58df897a4978b6073c32f8b8..f958ce9c3ad964c15c051755fbc022e15dde73c2 100644
--- a/setup.py
+++ b/setup.py
@@ -4,7 +4,7 @@ from setuptools import setup
 setup(
     name='tlib',
     version='1.0',
-    description='This code belongs to the research article "Exploring Wilderness Using Explainable Machine Learning in Satellite Imagery" (2022) by Timo T. Stomberg, Taylor Stone, Johannes Leonhardt, Immanuel Weber, and Ribana Roscher (https://doi.org/10.48550/arXiv.2203.00379).',
+    description='This code belongs to the research article "Exploring Wilderness Characteristics Using Explainable Machine Learning in Satellite Imagery" (2022) by Timo T. Stomberg, Taylor Stone, Johannes Leonhardt, Immanuel Weber, and Ribana Roscher (https://doi.org/10.48550/arXiv.2203.00379).',
     license='MIT',
     author='Timo Tjaden Stomberg',
     author_email='timo.stomberg@uni-bonn.de',
@@ -27,8 +27,6 @@ setup(
     ],
 
     install_requires=[
-        'gdal',
-
         'numpy',
         'pandas',
         'scikit-learn',
diff --git a/tlib/tgeo/tgdal.py b/tlib/tgeo/tgdal.py
index 21fa574d30335715f5c2d817a86f38682743ca4b..d47b841940dc590c4a0a05b985cd42d6d2b128d9 100644
--- a/tlib/tgeo/tgdal.py
+++ b/tlib/tgeo/tgdal.py
@@ -1,10 +1,15 @@
 import os
 import warnings
 
-from osgeo import gdal
 import numpy as np
 from tqdm.auto import tqdm
 
+try:
+    from osgeo import gdal
+except:
+    warnings.warn('WARNING: Python package osgeo.gdal is not available. Please run \'conda install gdal\' in your terminal.')
+
+
 
 def tile_file(file, output_folder, positions, tile_dims, disable_tqdm: bool = False):
     # open file with gdal
diff --git a/tlib/ttorch/data/images.py b/tlib/ttorch/data/images.py
index 7159c51351973e7da003b6db703ee79148fdc2f2..eaef8b05e4945144295f5a286e70d511b885f701 100644
--- a/tlib/ttorch/data/images.py
+++ b/tlib/ttorch/data/images.py
@@ -11,7 +11,7 @@ import torch.nn.functional as F
 import torchvision
 from tqdm.auto import tqdm
 
-from tlib.ttorch import data, utils  # circular import is caused if not imported like this
+from tlib.ttorch import data, utils
 
 
 def rotate_randomly_90(t):
@@ -116,7 +116,10 @@ class Dataset(data.basic.Dataset):
     :param x_normalization: tuple with mean and std (mean, std) to normalize input images
     :param clip_range: tubpe with min and max value (min, max) to clip input images
     :param rotate: bool if rotation data augmentation shall be performed (random rotation of 0°, 90°, 180° or 270°)
+    :param resize: resizes an image; tuple; shape (h, w) of the target size
+    :param crop: randomly crops and image; tuple; shape (h, w) of the target size
     :param cutmix: bool if cutmix shall be performed
+    
     :param y_normalization: tuple with mean and std (mean, std) to normalize target data
     :param n_classes: number of classes (labels); needed to create one-hot vectors
 
@@ -137,8 +140,12 @@ class Dataset(data.basic.Dataset):
         channels: list = None,
         x_normalization: tuple = None,
         clip_range: tuple = None,
+        resize: tuple = None,  # (h, w)
+        crop: tuple = None,  # (h, w)
+
         rotate: bool = False,
         cutmix: float = None,
+
         y_normalization: tuple = None,
         n_classes: int = None,
 
@@ -154,7 +161,10 @@ class Dataset(data.basic.Dataset):
         self.x_normalization = x_normalization
         self.clip_range = clip_range
         self.rotate = rotate
+        self.resize = torchvision.transforms.Resize(size=resize, antialias=True) if resize is not None else None
+        self.crop = torchvision.transforms.RandomCrop(size=crop) if crop is not None else None
         self.cutmix = cutmix
+
         self.y_normalization = y_normalization
         self.n_classes = n_classes
 
@@ -230,6 +240,17 @@ class Dataset(data.basic.Dataset):
         if self.clip_range is not None:
             x =  torch.clip(x, min=self.clip_range[0], max=self.clip_range[1])
 
+        # resize
+        if self.resize is not None:
+            x = self.resize(x)
+
+        # crop
+        if self.crop is not None:
+            x = self.crop(x)
+            if self.cutmix is not None:
+                print(self.cutmix)
+                warnings.warn('WARNING: You have defined CutMix as well as cropping. Note that CutMix is performed first.')
+        
         # x rotation
         if self.rotate and self.training:
             x = rotate_randomly_90(x)
@@ -312,6 +333,7 @@ class DataModule(data.basic.DataModule):
         'label': label
         'datasplit': 'train', 'val' or 'test'
     :param folder: path to main folder in which files are located (or from which file paths start)
+    :param crop_test_data: bool if test data shall be cropped; test data is only cropped if crop is not None
 
     :param batch_size: batch size for dataloader
     :param num_workers: number of workers to load data
@@ -327,8 +349,13 @@ class DataModule(data.basic.DataModule):
         channels: list = None,
         x_normalization: tuple = None,
         clip_range: tuple = None,
+        resize: tuple = None,  # (h, w)
+        crop: tuple = None,  # (h, w)
+        crop_test_data: bool = True,
+
         rotate: bool = False,
         cutmix: float = None,
+
         y_normalization: tuple = None,
         n_classes: int = None,
 
@@ -345,6 +372,10 @@ class DataModule(data.basic.DataModule):
         self.batch_size = batch_size  # used in inherited methods
         self.num_workers = num_workers  # used in inherited methods
 
+        self.resize = resize
+        self.crop = crop
+        self.crop_test_data = crop_test_data
+
         self.rgb_channels = rgb_channels
         self.val_range = val_range
 
@@ -352,6 +383,9 @@ class DataModule(data.basic.DataModule):
             'channels': channels,
             'x_normalization': x_normalization,
             'clip_range': clip_range,
+
+            'resize': self.resize,
+
             'y_normalization': y_normalization,
             'n_classes': n_classes,
 
@@ -363,6 +397,7 @@ class DataModule(data.basic.DataModule):
         self.train_dataset = self.DatasetClass(
             files=self.get_dirs(self.file_infos[self.file_infos['datasplit'] == 'train']),
             labels=self.file_infos[self.file_infos['datasplit'] == 'train']['label'].to_list(),
+            crop=self.crop,
             rotate=rotate,
             cutmix=cutmix,
             **self.dataset_kwargs,
@@ -371,16 +406,18 @@ class DataModule(data.basic.DataModule):
         self.val_dataset = self.DatasetClass(
             files=self.get_dirs(self.file_infos[self.file_infos['datasplit'] == 'val']),
             labels=self.file_infos[self.file_infos['datasplit'] == 'val']['label'].to_list(),
+            crop=self.crop,
             rotate=False,
-            cutmix=False,
+            cutmix=None,
             **self.dataset_kwargs,
         )
 
         self.test_dataset = self.DatasetClass(
             files=self.get_dirs(self.file_infos[self.file_infos['datasplit'] == 'test']),
             labels=self.file_infos[self.file_infos['datasplit'] == 'test']['label'].to_list(),
+            crop=self.crop if self.crop_test_data else None,
             rotate=False,
-            cutmix=False,
+            cutmix=None,
             **self.dataset_kwargs,
         )
     
@@ -401,8 +438,15 @@ class DataModule(data.basic.DataModule):
         if prepend_folder:
             files = [os.path.join(self.folder, file) for file in files]
         labels = [0] * len(files) if labels is None else labels  # labels cannot be None
-
-        return self.DatasetClass(files=files, labels=labels, rotate=False, cutmix=False, **self.dataset_kwargs)
+        
+        return self.DatasetClass(
+            files=files,
+            labels=labels,
+            rotate=False,
+            crop=self.crop if self.crop_test_data else None,
+            cutmix=None,
+            **self.dataset_kwargs
+        )
         
     def get_dirs(self, df=None):
             """
diff --git a/tlib/ttorch/model.py b/tlib/ttorch/model.py
index 63b72028cc87b16b6a7bae17c2ccf5057a64b5ee..14f966d348d97dfe69f258e263c6c54cf31a3350 100644
--- a/tlib/ttorch/model.py
+++ b/tlib/ttorch/model.py
@@ -1,4 +1,6 @@
 import warnings
+import inspect
+import importlib
 
 import torch
 import torch.nn as nn
@@ -7,19 +9,58 @@ import torch.nn as nn
 class Module(nn.Module):
     """
     Contains the very basic of a torch Module.
-    You need to overwrite the __init__() function and the tforward function. Use the tforward function like the pytorch
-    forward function, except of: Do not run the final activation here, but just define it in the __init__ function
-    as follows: self.final_activation = ...
+    You need to overwrite the tinit() function and the tforward function.
+    
+    Use the tinit function like the pytorch __init__() function EXCEPT of calling super().__init__().
+    You can but don't have to set final_activation as a parameter. If you don't set it, it must be given anyway
+    or a warning will occur. E.g.:
+
+    def tinit(self, a, b=3, final_activation=nn.Tanh()):
+        self.layer = ...
+
+    Use the tforward function like the pytorch forward function, EXCEPT of: Do not run the final activation here,
+    but set it as a variable in the init function.
 
     :param final_activation: torch activation function or object
     """
 
-    def __init__(self, final_activation=None):
+    def __init__(self_, **hparams):
 
         super().__init__()
 
-        self.final_activation = final_activation
+        # save hyperparams
+        self_.hparams = hparams
+
+        # remove keys: self, __class__
+        if 'self' in self_.hparams.keys():
+            del self_.hparams['self']
+        if '__class__' in self_.hparams.keys():
+            del self_.hparams['__class__']
 
+        # add keys that are standard values in self_.tinit() and not defined when creating object
+        params = inspect.signature(self_.tinit).parameters
+        params = {v.name: v.default for (v, v) in params.items()}  # to dict
+        params = {v: k for (v, k) in params.items() if k is not inspect.Parameter.empty and v not in self_.hparams.keys()}
+        self_.hparams = {**self_.hparams, **params}
+        
+        # if final_activation not given, print warning and set to None
+        if 'final_activation' not in self_.hparams.keys():
+            warnings.warn('WARNING: Your __init__() takes no argument \'final_activation\'. Therefore final_activation is set to None.')
+            self_.hparams['final_activation'] = None
+
+        # sort by name
+        self_.hparams = dict(sorted(self_.hparams.items()))
+        
+        # run self.tinit()
+        if 'final_activation' not in inspect.signature(self_.tinit).parameters.keys():  # remove it for kwargs
+            kwargs_for_tinit = {k: v for (k, v) in self_.hparams.items() if k != 'final_activation'}
+        else:
+            kwargs_for_tinit = self_.hparams
+        self_.tinit(**kwargs_for_tinit)
+
+    def tinit(self):
+        pass
+    
     @property
     def device(self):
         
@@ -70,47 +111,92 @@ class Module(nn.Module):
 
         x = self.forward_no_activation(x)
 
-        if self.final_activation is not None:
-            x = self.final_activation(x)
+        if self.hparams['final_activation'] is not None:
+            x = self.hparams['final_activation'](x)
 
         return x
 
-    def save(self, path):
+    def get_save_dict(self, **info_kwargs):
         """
-        Saves the pickled model and the model_state_dict to given path. Only one of them is needed to reload the model
-        (but currently both are saved).
+        Returns dictionary to save with torch.save() including model_type, model_hparams, model_state_dict and info_kwargs.
 
-        :param path: path where model is stored
-        :return:
+        :param info_kwargs: dictionary with infos, e.g. {'epoch': 5, 'iteration': 3012}
+        :return: save_dict including model_type, model_hparams, model_state_dict and info_kwargs
         """
 
         state_changer = ModelStateChanger()
-        state_changer(model=self.model, state='eval')
+        state_changer(model=self, state='eval')
+
+        save_dict = {
+            'model_type': str(type(self)).split(' ')[-1][1:-2],
+            'model_hparams': self.hparams,
+            'model_state_dict': self.state_dict(),
+            **info_kwargs
+        }
 
-        torch.save({
-            'model': self,
-            'model_state_dict': self.model.state_dict(),
-        }, path)
+        state_changer.reverse(model=self)  # return to train mode if model was in that mode before training
 
-        state_changer.reverse(model=self.model)  # return to train mode if model was in that mode before training
+        return save_dict
 
-    def load(self, path, device='cuda'):
+    def save(self, path, **info_kwargs):
         """
-        Load model from given path using model_state_dict within dictionary into this object.
-        Infos for loading models with pytorch: https://pytorch.org/tutorials/beginner/saving_loading_models.html
+        Saves the pickled model and the model_state_dict to given path. Only one of them is needed to reload the model
+        (but currently both are saved).
 
-        :param path: path to model checkpoint, which must be a dictionary with entry 'model_state_dict'
-        :param device: available and to be used device ('cuda', 'cuda:<cuda_id>' or 'cpu'; default: 'cuda')
-        :return: self (model); but also overwrites own __dict__
+        :param path: path where model is stored
+        :param info_kwargs: dictionary with infos, e.g. {'epoch': 5, 'iteration': 3012}
+        :return:
         """
 
-        checkpoint = torch.load(path, map_location=torch.device(device))
-        model_dict = checkpoint['model_state_dict']
+        save_dict = self.get_save_dict(**info_kwargs)
+        torch.save(save_dict, path)
+
+
+def load_model(checkpoint, ModelClass=None, model_hparams=None, device='cuda'):
+    """
+    Loads a model from checkpoint_path and uses model_hparams to load model without the need of initialization
+    parameters.
+
+    :param checkpoint: path to model checkpoint, e.g. model.pt, OR dict from checkpoint
+    :param ModelClass: class of the model that shall be loaded
+    :param device: 'cuda' or 'cpu'
+    :param model_hparams: if checkpoint has no key 'model_hparams' this can be used to define the model_hparams
+    """
+
+    if isinstance(checkpoint, dict): 
+        checkpoint = checkpoint
+    else:  # assume it is apath
+        checkpoint = torch.load(checkpoint, map_location=torch.device(device))
+
+    # ModelClass
+    if ModelClass is not None:
+        ModelClass = ModelClass
+    elif 'model_type' in checkpoint.keys():
+        model_type = checkpoint['model_type']
+        model_module = '.'.join(model_type.split('.')[:-1])
+        model_class = model_type.split('.')[-1]
+        ModelClass = getattr(importlib.import_module(model_module), model_class)
+    else:
+        warnings.warn('model_type is not stored in checkpoint and you have not defined a ModelClass in load function!')
+        return
+
+    # model_hparams
+    if model_hparams is not None:
+        model_hparams = model_hparams
+    elif 'model_hparams' in checkpoint.keys():
+        model_hparams = checkpoint['model_hparams']
+    else:
+        warnings.warn('model_hparams are not stored in checkpoint and you have not defined model_haparms in load function!')
+        return
         
-        self.__dict__.clear()
-        self.__dict__.update(model_dict)
+    # state dict
+    model_state_dict = checkpoint['model_state_dict']
+
+    # load model
+    model = ModelClass(**model_hparams)
+    model.load_state_dict(model_state_dict)
 
-        return self
+    return model
 
 
 class ModelStateChanger:
diff --git a/tlib/ttorch/modules/unet.py b/tlib/ttorch/modules/unet.py
index 4a6b5d71ac6a27ac5da0195fa570de6280d504a2..9fec38903cd6455ae612b72340e030b2760b1197 100644
--- a/tlib/ttorch/modules/unet.py
+++ b/tlib/ttorch/modules/unet.py
@@ -23,6 +23,7 @@ class StandardUNet(nn.Module):
             base_channels: int = 64,
             double_conv: bool = True,
             batch_norm: bool = False,
+            dropout: float = None,
             final_activation=None
     ):
 
@@ -30,21 +31,21 @@ class StandardUNet(nn.Module):
 
         if double_conv:
             self.conv_in = nn.Sequential(
-                ConvStep(in_channels, base_channels, batch_norm=batch_norm),
-                ConvStep(base_channels, base_channels, batch_norm=batch_norm)
+                ConvStep(in_channels, base_channels, batch_norm=batch_norm, dropout=0),  # no dropout in the beginning
+                ConvStep(base_channels, base_channels, batch_norm=batch_norm, dropout=0)  # no dropout in the beginning
             )
         else:
-            self.conv_in = ConvStep(in_channels, base_channels, batch_norm=batch_norm)
+            self.conv_in = ConvStep(in_channels, base_channels, batch_norm=batch_norm, dropout=0)  # no dropout in the beginning
 
-        self.enc1 = EncodingStep(base_channels, double_conv=double_conv, batch_norm=batch_norm)  # 64 in standard configuration
-        self.enc2 = EncodingStep(base_channels * 2, double_conv=double_conv, batch_norm=batch_norm)  # 128 in standard configuration
-        self.enc3 = EncodingStep(base_channels * 4, double_conv=double_conv, batch_norm=batch_norm)  # 256 in standard configuration
-        self.enc4 = EncodingStep(base_channels * 8, double_conv=double_conv, batch_norm=batch_norm)  # 512 in standard configuration
+        self.enc1 = EncodingStep(base_channels, double_conv=double_conv, batch_norm=batch_norm, dropout=0)  # 64 in standard configuration (no dropout in the beginning)
+        self.enc2 = EncodingStep(base_channels * 2, double_conv=double_conv, batch_norm=batch_norm, dropout=dropout)  # 128 in standard configuration
+        self.enc3 = EncodingStep(base_channels * 4, double_conv=double_conv, batch_norm=batch_norm, dropout=dropout)  # 256 in standard configuration
+        self.enc4 = EncodingStep(base_channels * 8, double_conv=double_conv, batch_norm=batch_norm, dropout=dropout)  # 512 in standard configuration
 
-        self.dec1 = DecodingStep(base_channels * 16, double_conv=double_conv, batch_norm=batch_norm)  # 1024 in standard configuration
-        self.dec2 = DecodingStep(base_channels * 8, double_conv=double_conv, batch_norm=batch_norm)  # 512 in standard configuration
-        self.dec3 = DecodingStep(base_channels * 4, double_conv=double_conv, batch_norm=batch_norm)  # 256 in standard configuration
-        self.dec4 = DecodingStep(base_channels * 2, double_conv=double_conv, batch_norm=batch_norm)  # 128 in standard configuration
+        self.dec1 = DecodingStep(base_channels * 16, double_conv=double_conv, batch_norm=batch_norm, dropout=dropout)  # 1024 in standard configuration
+        self.dec2 = DecodingStep(base_channels * 8, double_conv=double_conv, batch_norm=batch_norm, dropout=dropout)  # 512 in standard configuration
+        self.dec3 = DecodingStep(base_channels * 4, double_conv=double_conv, batch_norm=batch_norm, dropout=dropout)  # 256 in standard configuration
+        self.dec4 = DecodingStep(base_channels * 2, double_conv=double_conv, batch_norm=batch_norm, dropout=dropout)  # 128 in standard configuration
 
         self.conv_out = nn.Conv2d(base_channels, out_channels, kernel_size=1)
 
@@ -96,6 +97,7 @@ class UpsamplingUNet(StandardUNet):
             double_conv: bool = True,
             batch_norm: bool = False,
             final_activation=None,
+            dropout: float = None,
             mode: str = 'bilinear'
     ):
 
@@ -105,16 +107,17 @@ class UpsamplingUNet(StandardUNet):
             base_channels=base_channels,
             double_conv=double_conv,
             batch_norm=batch_norm,
+            dropout=dropout,
             final_activation=final_activation
         )
 
-        self.enc4 = EncodingStep(base_channels * 8, base_channels * 8, double_conv=double_conv, batch_norm=batch_norm)
+        self.enc4 = EncodingStep(base_channels * 8, base_channels * 8, double_conv=double_conv, batch_norm=batch_norm, dropout=dropout)
         # 512, 512 in standard configuration
 
-        self.dec1 = DecodingStep(base_channels * 16, base_channels * 4, double_conv=double_conv, batch_norm=batch_norm, mode=mode)
-        self.dec2 = DecodingStep(base_channels * 8, base_channels * 2, double_conv=double_conv, batch_norm=batch_norm, mode=mode)
-        self.dec3 = DecodingStep(base_channels * 4, base_channels, double_conv=double_conv, batch_norm=batch_norm, mode=mode)
-        self.dec4 = DecodingStep(base_channels * 2, double_conv=double_conv, batch_norm=batch_norm, mode=mode)
+        self.dec1 = DecodingStep(base_channels * 16, base_channels * 4, double_conv=double_conv, batch_norm=batch_norm, dropout=dropout, mode=mode)
+        self.dec2 = DecodingStep(base_channels * 8, base_channels * 2, double_conv=double_conv, batch_norm=batch_norm, dropout=dropout, mode=mode)
+        self.dec3 = DecodingStep(base_channels * 4, base_channels, double_conv=double_conv, batch_norm=batch_norm, dropout=dropout, mode=mode)
+        self.dec4 = DecodingStep(base_channels * 2, double_conv=double_conv, batch_norm=batch_norm, dropout=dropout, mode=mode)
 
 
 class ConvStep(nn.Module):
@@ -126,19 +129,28 @@ class ConvStep(nn.Module):
     :param batch_norm: bool if batch normalization shall be performed after convolution
     """
 
-    def __init__(self, in_channels: int, out_channels: int, batch_norm: bool = False):
+    def __init__(self, in_channels: int, out_channels: int, batch_norm: bool = False, dropout: float = None):
 
         super().__init__()
 
         self.perform_batch_norm = batch_norm
+        self.perform_dropout = True if dropout is not None else False
+
+        if self.perform_dropout:
+            self.drop = nn.Dropout(p=dropout)
 
         self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
+
         if self.perform_batch_norm:
             self.norm = nn.BatchNorm2d(out_channels)
+
         self.relu = nn.ReLU(inplace=True)
 
     def forward(self, x):
 
+        if 'perform_dropout' in self.__dict__.keys():  # TODO This is because old pickled modules didn't have this property
+            if self.perform_dropout:
+                x = self.drop(x)
         x = self.conv(x)
         if self.perform_batch_norm:
             x = self.norm(x)
@@ -157,7 +169,9 @@ class EncodingStep(nn.Module):
     :param batch_norm: bool if batch normalization shall be performed after each convolution
     """
 
-    def __init__(self, in_channels: int, out_channels: int = None, double_conv: bool = True, batch_norm: bool = False):
+    def __init__(
+        self, in_channels: int, out_channels: int = None, double_conv: bool = True, batch_norm: bool = False,
+        dropout: float = None):
 
         super().__init__()
 
@@ -168,11 +182,11 @@ class EncodingStep(nn.Module):
 
         if double_conv:
             self.conv = nn.Sequential(
-                ConvStep(in_channels, out_channels, batch_norm=batch_norm),
-                ConvStep(out_channels, out_channels, batch_norm=batch_norm)
+                ConvStep(in_channels, out_channels, batch_norm=batch_norm, dropout=dropout),
+                ConvStep(out_channels, out_channels, batch_norm=batch_norm, dropout=dropout)
             )
         else:
-            self.conv = ConvStep(in_channels, out_channels, batch_norm=batch_norm)
+            self.conv = ConvStep(in_channels, out_channels, batch_norm=batch_norm, dropout=dropout)
 
     def forward(self, x):
 
@@ -200,7 +214,8 @@ class DecodingStep(nn.Module):
         out_channels: int = None,
         mode: str = None,
         double_conv: bool = True,
-        batch_norm: bool = False
+        batch_norm: bool = False,
+        dropout: float = None
     ):
 
         super().__init__()
@@ -217,11 +232,11 @@ class DecodingStep(nn.Module):
             # convolution(s)
             if double_conv:
                 self.conv = nn.Sequential(
-                    ConvStep(in_channels, out_channels, batch_norm=batch_norm),
-                    ConvStep(out_channels, out_channels, batch_norm=batch_norm)
+                    ConvStep(in_channels, out_channels, batch_norm=batch_norm, dropout=dropout),
+                    ConvStep(out_channels, out_channels, batch_norm=batch_norm, dropout=dropout)
                 )
             else:
-                self.conv = ConvStep(in_channels, out_channels, batch_norm=batch_norm)
+                self.conv = ConvStep(in_channels, out_channels, batch_norm=batch_norm, dropout=dropout)
 
         else:
             self.dec = nn.Upsample(scale_factor=2, mode=mode, align_corners=True)
@@ -229,11 +244,11 @@ class DecodingStep(nn.Module):
             # convolution(s)
             if double_conv:
                 self.conv = nn.Sequential(
-                    ConvStep(in_channels, int(in_channels / 2), batch_norm=batch_norm),
-                    ConvStep(int(in_channels / 2), out_channels, batch_norm=batch_norm)
+                    ConvStep(in_channels, int(in_channels / 2), batch_norm=batch_norm, dropout=dropout),
+                    ConvStep(int(in_channels / 2), out_channels, batch_norm=batch_norm, dropout=dropout)
                 )
             else:
-                self.conv = ConvStep(in_channels, out_channels, batch_norm=batch_norm)
+                self.conv = ConvStep(in_channels, out_channels, batch_norm=batch_norm, dropout=dropout)
 
     def forward(self, current_x, recurrent_x):
 
diff --git a/tlib/ttorch/train.py b/tlib/ttorch/train.py
index 4ecbd98bb8eb59a00453b86f239970a8c10e7bbd..154565cb8058625df7c608204834cf44977b0340 100644
--- a/tlib/ttorch/train.py
+++ b/tlib/ttorch/train.py
@@ -14,6 +14,7 @@ from tqdm import tqdm as tqdm_dataloader  # tqdm(dataloader) causes an error if
 import psutil
 
 from tlib import tutils, ttorch, tlearn, tutils
+import tlib  # to save scripts in training
 
 
 def define_params_for_weight_decay(model, weight_decay: float, exclude: list = ('bias', 'bn')):
@@ -213,7 +214,7 @@ class R2(Metric):
         
         _, y = ttorch.utils.get_batch(batch)
 
-        pred = pred.detact().cpu()
+        pred = pred.detach().cpu()
         y = y.detach().cpu()
 
         value = tlearn.metrics.r2(labels=y, preds=pred)
@@ -262,8 +263,8 @@ class Trainer:
     :param lr: learning rate
     :param weight_decay: weight decay
     :param exclude_bias_from_weight_decay: bool if bias shall be excluded from weight decay
-    :param one_cycle_lr_epochs: bool if learning rate shall be adapted according to one cycle learning
-
+    :param one_cycle_lr_epochs: number of epochs as information for one cycle learning; if None, no one cycle learning
+        is performed
     :param metrics: list of metrics (metric objects of classes defined above) that shall be logged
     :param reduce_logging: reduces logging, e.g. does not store model checkpoint etc. every epoch but only at the end;
         should be set to True if an epoch runs very short
@@ -358,7 +359,7 @@ class Trainer:
             #self.writer = SummaryWriter(buffer)
             self.log_dir = None
             self.writer = None
-            if self.datamodule is not None: warnings.warn('WARNING: Trainer: You did not define a log_dir and no logging is performed')
+            if self.datamodule is not None: warnings.warn('You did not define a log_dir for your trainer and no logging is performed')
 
     def init_model(self, model=None, cuda=True):
         """
@@ -371,16 +372,23 @@ class Trainer:
 
         self.model = model
 
-        if cuda and torch.cuda.is_available() and self.model is not None:
-            self.model = self.model.cuda()
-        if model is not None:
-            print(f'Model is on device: {next(self.model.parameters()).device}')
+        if self.model is not None:
+
+            if not isinstance(self.model, ttorch.model.Module):
+                warnings.warn(f'WARNING: Your model is not instance of ttorch.model.Module but probably just an instance of nn.Module! Its type is: {type(model)}. This will probably cause issues using this trainer.')
 
-        # get sample and add model graph to tensorboard
-        if self.writer is not None and self.datamodule is not None and self.datamodule.train_dataset is not None:
-            sample = next(iter(self.datamodule.get_dataloader('train')))
-            x_sample = sample['x'] if isinstance(sample, dict) else sample[0]
-            self.writer.add_graph(self.model, x_sample)
+            if cuda and torch.cuda.is_available():
+                self.model = self.model.cuda()
+            #print(f'Model is on device: {next(self.model.parameters()).device}')
+
+            # get sample and add model graph to tensorboard
+            if self.writer is not None and self._current_iter == 0 and self.datamodule is not None and self.datamodule.train_dataset is not None:
+                sample = next(iter(self.datamodule.get_dataloader('train')))
+                x_sample = sample['x'] if isinstance(sample, dict) else sample[0]
+                try:
+                    self.writer.add_graph(self.model, torch.Tensor(x_sample))
+                except:
+                    warnings.warn('writer.add_graph does not work with unflatten layer. Therefore graph is not logged and NYI messages shows up.')
     
     def init_criterion(self, criterion=None):
         """
@@ -393,7 +401,7 @@ class Trainer:
         if criterion is not None or self.model is None:
             self.criterion = criterion
         else:
-            self.criterion = get_recommended_criterion(self.model.final_activation)
+            self.criterion = get_recommended_criterion(self.model.hparams['final_activation'])
             print(f'Criterion has been set to: {self.criterion}')
 
     def init_optimizer(
@@ -460,35 +468,47 @@ class Trainer:
         # epochs loop
         prog_bar = tqdm(range(self._current_epoch + 1, self._current_epoch + epochs + 1), desc='epochs')
         self._prog_bar_epochs_dict = {}  # to show infos in tqdm progressbar
+        self.log_time(mode='start', desc='whole/epoch')
         for epoch in prog_bar:
 
             self._current_epoch = epoch
 
             # run train epoch
+            self.log_time(mode='start', desc='train/epoch')
             self.train_epoch(train_dataloader)
+            self.log_time(mode='stop', desc='train/epoch')
 
             # run val epoch
             if val_dataloader is not None:
+                self.log_time(mode='start', desc='val/epoch')
                 self.val_epoch(val_dataloader)
+                self.log_time(mode='stop', desc='val/epoch')
 
             # logging
+            self.log_time(mode='start', desc='whole/epoch_logging')
             if self.writer is not None:
                 self.writer.add_scalar('epoch', self._current_epoch, self._current_iter)  # epoch
 
                 if not self.reduce_logging:
-                    fig = self.plot_loss()  # loss plot
-                    self.log_fig('loss', fig)
-
                     self.log_weights()
                     self.save()
             
             # show postfix in tqdm
             prog_bar.set_postfix(self._prog_bar_epochs_dict)
+            self.log_time(mode='stop', desc='whole/epoch_logging')
+
+            self.log_time(mode='stop', desc='whole/epoch')
+            self.log_time(mode='start', desc='whole/epoch')
 
         # run test epoch
         if test_dataloader is not None:
             self.test_epoch(test_dataloader)
 
+        # log loss plot
+        if not self.reduce_logging:
+            fig = self.plot_loss()  # plotting image can take very long time; therefore it is not in epoch loop
+            self.log_fig('loss', fig)
+
         # setup eval mode
         self.eval()
 
@@ -503,8 +523,6 @@ class Trainer:
         :return:
         """
 
-        self.log_time(mode='start', desc='train/epoch')
-
         # setup training_mode
         # (usually 'train' but sometimes, e.g. for activation maximization where the model is not trained, it should be set to 'eval')
         if self.training_mode == 'train':
@@ -572,8 +590,6 @@ class Trainer:
 
         self.log_epoch(dataset_name='train')
 
-        self.log_time(mode='stop', desc='train/epoch')
-
     def val_epoch(self, dataloader=None):
         """
         Runs validation epoch. If this class in inherited, this must probably NOT be overwritten but val_step() method.
@@ -689,13 +705,13 @@ class Trainer:
         y = ttorch.utils.automove_data(model=self.model, t=y)
 
         if hasattr(self.model, 'forward_no_activation') and check_if_criterion_includes_activation(
-            criterion=self.criterion, activation=self.model.final_activation):
+            criterion=self.criterion, activation=self.model.hparams['final_activation']):
 
             pred = self.model.forward_no_activation(x)
             loss = self.criterion(pred, y)
 
-            if self.model.final_activation is not None:
-                pred = self.model.final_activation(pred)
+            if self.model.hparams['final_activation'] is not None:
+                pred = self.model.hparams['final_activation'](pred)
 
         else:
             pred = self.model(x)
@@ -876,7 +892,7 @@ class Trainer:
         if self.writer is not None:
 
             if mode == 'start':
-                if not hasattr(self, 'start_times'):
+                if not hasattr(self, 'start_times') or self.start_times is None:
                     self.start_times = {}
                 self.start_times[desc] = time.time()
 
@@ -894,79 +910,87 @@ class Trainer:
         if self.writer is not None:
             self.writer.close()
 
-        trainer_dict = self.__dict__.copy()
+        trainer_state_dict = self.__dict__.copy()
 
-        #del trainer_dict['writer']
-        #del trainer_dict['datamodule']
+        del trainer_state_dict['model']
 
-        return trainer_dict
+        return trainer_state_dict
     
     def save(self):
         """
-        Save checkpoint dictionary with epoch, iteration, trainer_dict, model_state_dict and optimizer_state_dict.
+        Save checkpoint dictionary with epoch, iteration, trainer_state_dict, model_state_dict and optimizer_state_dict.
         :return:
         """
 
-        state_changer = ttorch.model.ModelStateChanger()
-        state_changer(model=self.model, state='eval')
+        # save model
+        self.model.save(
+            path=os.path.join(tutils.files.join_paths(self.log_dir, 'model_checkpoints'), f'model_{self._current_epoch}.pt'),
+            epoch=self._current_epoch,
+            iteration=self._current_iter,
+        )
+
+        # save trainer
+        model_save_dict = self.model.get_save_dict()
 
         torch.save({
             'epoch': self._current_epoch,
             'iteration': self._current_iter,
-            'trainer_dict': self.state_dict(),
-            'model_state_dict': self.model.state_dict() if hasattr(self.model, 'state_dict') else self.model,
-            'optimizer_state_dict': self.optimizer.state_dict() if hasattr(self.optimizer, 'state_dict') else self.optimizer,
-        }, os.path.join(tutils.files.join_paths(self.log_dir, 'checkpoints'), f'checkpoint_{self._current_epoch}.pt'))
-
-        state_changer.reverse(model=self.model)  # return to train mode if model was in that mode before training
+            'trainer_state_dict': self.state_dict(),
+            **model_save_dict,
+        }, os.path.join(tutils.files.join_paths(self.log_dir, 'trainer_checkpoints'), f'trainer_{self._current_epoch}.pt'))
 
-    def load(self, log_dir, dummy_model=None, dummy_optimizer=None, device='cuda'):
+    def load(self, log_dir, device='cuda'):
         """
         Load trainer from given path.
         Infos how to save and load pytorch modules: https://pytorch.org/tutorials/beginner/saving_loading_models.html
 
-        :param log_dir: logging directory
-        :param dummy_model: model object to load state_dict of model with
+        :param log_dir: logging directory (or checkpoint path)
         :param dummy_optimizer: optimizer object to load state_dict of optimizer with
         :param device: available and to be used device ('cuda', 'cuda:<cuda_id>' or 'cpu'; default: 'cuda')
         :return: self
         """
 
-        checkpoint_path = ttorch.utils.get_last_checkpoint(log_dir=log_dir)
+        log_dir_is_checkpoint_path = True if log_dir.split('.')[-1] == 'pt' else False
+
+        if log_dir_is_checkpoint_path:
+            checkpoint_path = log_dir
+        else:
+            checkpoint_path = ttorch.utils.get_last_checkpoint(log_dir=log_dir)            
 
         # load checkpoint
         checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
 
         # overwrite self
         self.__dict__.clear()
-        self.__dict__.update(checkpoint['trainer_dict'])
-
+        self.__dict__.update(checkpoint['trainer_state_dict'])
 
-        # overwrite model with dummy model if given and load state dict
-        if dummy_model is not None:
-            self.model = dummy_model
-            self.model.load_state_dict(checkpoint['model_state_dict'])
-
-        # overwrite optimizer with dummy optimizer if given and load state dict
-        if dummy_optimizer is not None:
-            self.optimizer = dummy_optimizer
-            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+        # load self.model
+        model = ttorch.model.load_model(checkpoint)
+        self.init_model(model=model, cuda=True if device == 'cuda' else False)
 
         # set to eval mode
         self.eval()
 
         # create new log_dir and summary writer with new log_dir
-        self.log_dir = log_dir
-        self.writer = SummaryWriter(os.path.join(log_dir, 'tf'))
+        if not log_dir_is_checkpoint_path:
+            self.log_dir = log_dir
+            self.writer = SummaryWriter(os.path.join(log_dir, 'tf'))
 
         return self
 
     def save_scripts(self):
         """
-        Save scripts of tlib and projects to log_dir. TODO implement!
+        Save scripts of tlib and projects to log_dir.
         :return:
         """
-        pass
+
+        # save important script folders to checkpoint path
+        destination_path = tutils.files.join_paths(self.log_dir, 'scripts')
+
+        source_path = os.path.dirname(os.path.abspath(tlib.__file__))  # path to tlib
+        source_path = '/'.join(str(source_path).split('/')[:-1])  # one folder back
+
+        shutil.copytree(source_path, os.path.join(destination_path, source_path.split('/')[-1]))
 
     def __str__(self):
         """
diff --git a/tlib/ttorch/utils.py b/tlib/ttorch/utils.py
index c014ac978fdb58af5d3d05f701648fdca2c0dc61..47499ebf2d05ec021a87d9fd1c37a0abb2080545 100644
--- a/tlib/ttorch/utils.py
+++ b/tlib/ttorch/utils.py
@@ -72,15 +72,15 @@ def get_last_checkpoint(log_dir):
     :return: checkpoint_path string of last version
     """
 
-    checkpoint_path = os.path.join(log_dir, 'checkpoints')
+    checkpoint_path = os.path.join(log_dir, 'trainer_checkpoints')
 
     # get last checkpoint
     checkpoints = os.listdir(checkpoint_path)
-    checkpoints = [int(c.replace('checkpoint_', '').replace('.pt', '')) for c in checkpoints if c.split('_')[0] == 'checkpoint']
+    checkpoints = [int(c.replace('trainer_', '').replace('.pt', '')) for c in checkpoints if c[:8] == 'trainer_']
     last_checkpoint = max(checkpoints)
 
     # get checkpoint path of last checkpoint
-    checkpoint_path = os.path.join(checkpoint_path, f'checkpoint_{last_checkpoint}.pt')
+    checkpoint_path = os.path.join(checkpoint_path, f'trainer_{last_checkpoint}.pt')
 
     return checkpoint_path