Skip to content
Snippets Groups Projects
Commit 980b3d1f authored by Michael Langguth's avatar Michael Langguth
Browse files

Reduction of default hyperparameters to the relevant ones for the dataset and...

Reduction of default hyperparameters to the relevant ones for the dataset and corrected usage of Hparams-handling in era5_dataset.py.

Conflicts:
	video_prediction_tools/model_modules/video_prediction/datasets/era5_dataset.py
parent 6bfdc3ad
Branches
Tags
No related merge requests found
...@@ -10,7 +10,7 @@ import tensorflow as tf ...@@ -10,7 +10,7 @@ import tensorflow as tf
from collections import OrderedDict from collections import OrderedDict
from tensorflow.contrib.training import HParams from tensorflow.contrib.training import HParams
from google.protobuf.json_format import MessageToDict from google.protobuf.json_format import MessageToDict
from general_utils import reduce_dict
class ERA5Dataset(object): class ERA5Dataset(object):
...@@ -58,20 +58,18 @@ class ERA5Dataset(object): ...@@ -58,20 +58,18 @@ class ERA5Dataset(object):
def get_default_hparams_dict(self): def get_default_hparams_dict(self):
""" """
The function that contains default hparams Provide dictionary containing default hyperparameters for the dataset
Returns: Returns:
A dict with the following hyperparameters. A dict with the following hyperparameters relevant for the dataset.
context_frames : the number of ground-truth frames to pass in at start. context_frames : the number of ground-truth frames to pass in at start
batch_size : number of training examples per mini-batch
max_epochs : the number of epochs to train model max_epochs : the number of epochs to train model
lr : learning rate
loss_fun : the loss function loss_fun : the loss function
""" """
hparams = dict( hparams = dict(
context_frames=10, context_frames=10,
max_epochs = 20, max_epochs = 20,
batch_size = 40, batch_size = 40,
lr = 0.001,
loss_fun = "rmse",
shuffle_on_val= True, shuffle_on_val= True,
) )
return hparams return hparams
...@@ -89,7 +87,9 @@ class ERA5Dataset(object): ...@@ -89,7 +87,9 @@ class ERA5Dataset(object):
""" """
Parse the hparams setting to ovoerride the default ones Parse the hparams setting to ovoerride the default ones
""" """
self.hparams_dict = reduce_dict(self.hparams_dict, self.get_default_hparams().values())
parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {}) parsed_hparams = self.get_default_hparams().override_from_dict(self.hparams_dict or {})
return parsed_hparams return parsed_hparams
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment