From 980b3d1fb6b4b343e1607aff2877684ae5a53ae4 Mon Sep 17 00:00:00 2001 From: Michael <m.langguth@fz-juelich.de> Date: Mon, 21 Jun 2021 11:37:56 +0200 Subject: [PATCH] Reduction of default hyperparameters to the relevant ones for the dataset and corrected usage of Hparams-handling in era5_dataset.py. Conflicts: video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py --- .../video_prediction/datasets/era5_dataset.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py index ce62965a..664a817a 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py @@ -10,7 +10,7 @@ import tensorflow as tf from collections import OrderedDict from tensorflow.contrib.training import HParams from google.protobuf.json_format import MessageToDict - +from general_utils import reduce_dict class ERA5Dataset(object): @@ -58,20 +58,18 @@ class ERA5Dataset(object): def get_default_hparams_dict(self): """ - The function that contains default hparams + Provide dictionary containing default hyperparameters for the dataset Returns: - A dict with the following hyperparameters. - context_frames : the number of ground-truth frames to pass in at start. + A dict with the following hyperparameters relevant for the dataset. + context_frames : the number of ground-truth frames to pass in at start + batch_size : number of training examples per mini-batch max_epochs : the number of epochs to train model - lr : learning rate loss_fun : the loss function """ hparams = dict( context_frames=10, max_epochs = 20, batch_size = 40, - lr = 0.001, - loss_fun = "rmse", shuffle_on_val= True, ) return hparams @@ -89,7 +87,9 @@ class ERA5Dataset(object): """ Parse the hparams setting to ovoerride the default ones """ + self.hparams_dict = reduce_dict(self.hparams_dict, self.get_default_hparams().values()) parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) + return parsed_hparams -- GitLab