Commit ff32ac91 authored by lukas leufen

batch size is now part of the experiment setup and no longer tied to the model class

parent eac8cf47
3 merge requests: !125 Release v0.10.0, !124 Update Master to new version v0.10.0, !104 Resolve "batch size definition in exp setup"
Pipeline #39352 failed
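In short, the batch size is no longer hard-coded per model class (previously `int(256)` up to `int(256 * 4)`) but is passed through `run()` and `ExperimentSetup` and stored in the experiment data store, with a default of `int(256 * 2)`, i.e. 512. A minimal sketch of the new call; the argument names come from this commit, while the import path is an assumption:

# Hypothetical usage after this commit: batch_size is an experiment-setup
# argument instead of a model-class attribute.
from run import run  # assumed entry-point module, not shown in this commit

run(stations=['DEBW107', 'DEBY081'],  # any subset of the default stations
    batch_size=256)                   # omit to fall back to int(256 * 2) == 512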
@@ -119,7 +119,6 @@ from typing import Any, Callable, Dict
 import keras
 import tensorflow as tf
-import logging
 from src.model_modules.inception_model import InceptionModelBase
 from src.model_modules.flatten import flatten_tail
 from src.model_modules.advanced_paddings import PadUtils, Padding2D
@@ -356,7 +355,6 @@ class MyLittleModel(AbstractModelClass):
 self.dropout_rate = 0.1
 self.regularizer = keras.regularizers.l2(0.1)
 self.epochs = 20
-self.batch_size = int(256)
 self.activation = keras.layers.PReLU
 # apply to model
@@ -430,7 +428,6 @@ class MyBranchedModel(AbstractModelClass):
 self.dropout_rate = 0.1
 self.regularizer = keras.regularizers.l2(0.1)
 self.epochs = 20
-self.batch_size = int(256)
 self.activation = keras.layers.PReLU
 # apply to model
@@ -505,7 +502,6 @@ class MyTowerModel(AbstractModelClass):
 self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
                                                                      epochs_drop=10)
 self.epochs = 20
-self.batch_size = int(256 * 4)
 self.activation = keras.layers.PReLU
 # apply to model
@@ -619,7 +615,6 @@ class MyPaperModel(AbstractModelClass):
 self.lr_decay = src.model_modules.keras_extensions.LearningRateDecay(base_lr=self.initial_lr, drop=.94,
                                                                      epochs_drop=10)
 self.epochs = 150
-self.batch_size = int(256 * 2)
 self.activation = keras.layers.ELU
 self.padding = "SymPad2D"
...
@@ -25,7 +25,8 @@ def run(stations=['DEBW107', 'DEBY081', 'DEBW013', 'DEBW076', 'DEBW087', 'DEBW00
 train_min_length=None, val_min_length=None, test_min_length=None,
 evaluate_bootstraps=True, number_of_bootstraps=None, create_new_bootstraps=False,
 plot_list=None,
-model=None):
+model=None,
+batch_size=None):
 params = inspect.getfullargspec(ExperimentSetup).args
 kwargs = {k: v for k, v in locals().items() if k in params}
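The last two lines of this hunk are the forwarding mechanism: every local of `run()` whose name matches a parameter of `ExperimentSetup` is collected into `kwargs` and passed on, so adding `batch_size=None` to the `run()` signature is enough to route the value through. The pattern in isolation, with toy names not taken from the repository:

import inspect

def target(a=1, b=2):
    return a + b

def wrapper(a=None, b=None, unrelated=None):
    # keep only the locals whose names are parameters of `target`
    params = inspect.getfullargspec(target).args
    kwargs = {k: v for k, v in locals().items() if k in params}
    return target(**kwargs)

print(wrapper(a=3, b=4))  # prints 7; `unrelated` is filtered out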
...
@@ -233,12 +233,12 @@ class ExperimentSetup(RunEnvironment):
 create_new_model: bool = None, bootstrap_path=None, permute_data_on_training: bool = None, transformation=None,
 train_min_length=None, val_min_length=None, test_min_length=None, extreme_values: list = None,
 extremes_on_right_tail_only: bool = None, evaluate_bootstraps=True, plot_list=None, number_of_bootstraps=None,
-create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None, model=None):
+create_new_bootstraps=None, data_path: str = None, login_nodes=None, hpc_hosts=None, model=None, batch_size=None):

 # create run framework
 super().__init__()

-# experiment setup
+# experiment setup, hyperparameters
 self._set_param("data_path", path_config.prepare_host(data_path=data_path, sampling=sampling))
 self._set_param("hostname", path_config.get_host())
 self._set_param("hpc_hosts", hpc_hosts, default=DEFAULT_HPC_HOST_LIST + DEFAULT_HPC_LOGIN_LIST)
@@ -257,6 +257,7 @@ class ExperimentSetup(RunEnvironment):
 upsampling = self.data_store.get("upsampling", "train")
 permute_data = False if permute_data_on_training is None else permute_data_on_training
 self._set_param("permute_data", permute_data or upsampling, scope="train")
+self._set_param("batch_size", batch_size, default=int(256 * 2))
 # set experiment name
 exp_date = self._get_parser_args(parser_args).get("experiment_date")
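The new `_set_param("batch_size", ...)` line registers the hyperparameter in the central data store, falling back to the default when the caller passed `None`, so `batch_size=None` ends up as `int(256 * 2) == 512`. The helper's body is not part of this commit; a rough sketch of the assumed fallback logic:

# Assumed shape of ExperimentSetup._set_param (not shown in this commit):
# store the given value, or the default when the caller passed None.
def _set_param(self, param: str, value, default=None, scope="general"):
    value = default if value is None else value
    self.data_store.set(param, value, scope)  # `set` assumed to mirror the visible `get(name, scope)`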
...
@@ -33,6 +33,7 @@ class ModelSetup(RunEnvironment):
 * `generator` [train]
 * `window_lead_time` [.]
 * `window_history_size` [.]
+* `model_class` [.]

 Optional objects
 * `lr_decay` [model]
@@ -43,7 +44,7 @@ class ModelSetup(RunEnvironment):
 * `hist` [model]
 * `callbacks` [model]
 * `model_name` [model]
-* all settings from model class like `dropout_rate`, `initial_lr`, `batch_size`, and `optimizer` [model]
+* all settings from model class like `dropout_rate`, `initial_lr`, and `optimizer` [model]

 Creates
 * plot of model architecture `<model_name>.pdf`
...
@@ -33,7 +33,7 @@ class Training(RunEnvironment):
 Required objects [scope] from data store:
 * `model` [model]
-* `batch_size` [model]
+* `batch_size` [.]
 * `epochs` [model]
 * `callbacks` [model]
 * `model_name` [model]
@@ -67,7 +67,7 @@ class Training(RunEnvironment):
 self.train_set: Union[Distributor, None] = None
 self.val_set: Union[Distributor, None] = None
 self.test_set: Union[Distributor, None] = None
-self.batch_size = self.data_store.get("batch_size", "model")
+self.batch_size = self.data_store.get("batch_size")
 self.epochs = self.data_store.get("epochs", "model")
 self.callbacks: CallbackHandler = self.data_store.get("callbacks", "model")
 self.experiment_name = self.data_store.get("experiment_name")
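This final hunk is the consumer side: `Training` now reads `batch_size` without the `"model"` scope argument because the value lives in the general experiment scope. A toy scoped store to illustrate the lookup; the project's real data store class is not part of this commit:

# Toy illustration of a scoped data store; not the project's implementation.
class ToyDataStore:
    def __init__(self):
        self._data = {}  # maps (name, scope) -> value

    def set(self, name, value, scope="general"):
        self._data[(name, scope)] = value

    def get(self, name, scope="general"):
        # fall back to the general scope if `name` is absent in `scope`
        return self._data.get((name, scope), self._data.get((name, "general")))

store = ToyDataStore()
store.set("batch_size", int(256 * 2))  # written once during ExperimentSetup
assert store.get("batch_size") == 512  # read by Training with no scope needed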
...