Commit f37acc99 authored by Bing Gong's avatar Bing Gong
Browse files

update main_train_models.py

parent c759d326
Pipeline #115021 failed with stages
in 1 minute and 8 seconds
......@@ -159,12 +159,12 @@ class TrainModel(object):
# create dataset instance
VideoDataset = datasets.get_dataset_class(self.dataset)
self.train_dataset = VideoDataset(input_dir=self.input_dir, mode='train', datasplit_config=self.datasplit_dict,
self.train_dataset = VideoDataset(input_dir=self.input_dir, output_dir = self.output_dir, mode='train', datasplit_config=self.datasplit_dict,
hparams_dict_config=self.model_hparams_dict)
self.calculate_samples_and_epochs()
self.num_examples = self.calculate_samples_and_epochs()
self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length})
# set-up validation dataset and calculate number of batches for calculating validation loss
self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val', datasplit_config=self.datasplit_dict,
self.val_dataset = VideoDataset(input_dir=self.input_dir, output_dir = self.output_dir, mode='val', datasplit_config=self.datasplit_dict,
hparams_dict_config=self.model_hparams_dict, nsamples_ref=self.num_examples)
# Retrieve sequence length from dataset
self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length})
......@@ -175,7 +175,7 @@ class TrainModel(object):
:param mode: "train" used the model graph in train process; "test" for postprocessing step
"""
VideoPredictionModel = models.get_model_class(self.model)
self.video_model = VideoPredictionModel(hparams_dict=self.model_hparams_dict, mode=mode)
self.video_model = VideoPredictionModel(hparams_dict=self.model_hparams_dict_load, mode=mode)
def setup_graph(self):
"""
......@@ -201,8 +201,8 @@ class TrainModel(object):
self.inputs = self.iterator.get_next()
# since era5 tfrecords include T_start, we need to remove it from the tfrecord when we train SAVP
# Otherwise an error will be risen by SAVP
if self.dataset == "era5" and self.model == "savp":
del self.inputs["T_start"]
if self.dataset == "era5":
self.inputs = self.inputs
def save_dataset_model_params_to_checkpoint_dir(self, dataset, video_model):
"""
......
......@@ -14,10 +14,8 @@ import tensorflow as tf
class BaseModels(ABC):
def __init__(self, hparams_dict_config=None):
self.hparams_dict_config = hparams_dict_config
self.hparams_dict = self.get_model_hparams_dict()
self.hparams = self.parse_hparams()
def __init__(self, hparams_dict_config:dict = None):
self.hparams = self.parse_hparams(hparams_dict_config)
# Attributes set during runtime
self.total_loss = None
self.loss_summary = None
......@@ -33,24 +31,12 @@ class BaseModels(ABC):
self.x_hat_predict_frames = None
def get_model_hparams_dict(self):
"""
Get model_hparams_dict from json file
"""
if self.hparams_dict_config:
with open(self.hparams_dict_config, 'r') as f:
hparams_dict = json.loads(f.read())
else:
raise FileNotFoundError("hparam directory doesn't exist! please check {}!".format(self.hparams_dict_config))
return hparams_dict
def parse_hparams(self):
def parse_hparams(self,hparams_dict_config):
"""
Obtain the parameters from directory
"""
hparams = dotdict(self.hparams_dict)
hparams = dotdict(hparams_dict_config)
return hparams
@abstractmethod
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment