diff --git a/video_prediction_tools/main_scripts/main_train_models.py b/video_prediction_tools/main_scripts/main_train_models.py index ff9486cd29cf928f49cd8d161af28eca56380dca..d54393473313e0604028db278d652b18fbd00d4b 100644 --- a/video_prediction_tools/main_scripts/main_train_models.py +++ b/video_prediction_tools/main_scripts/main_train_models.py @@ -159,12 +159,12 @@ class TrainModel(object): # create dataset instance VideoDataset = datasets.get_dataset_class(self.dataset) - self.train_dataset = VideoDataset(input_dir=self.input_dir, mode='train', datasplit_config=self.datasplit_dict, + self.train_dataset = VideoDataset(input_dir=self.input_dir, output_dir = self.output_dir, mode='train', datasplit_config=self.datasplit_dict, hparams_dict_config=self.model_hparams_dict) - self.calculate_samples_and_epochs() + self.num_examples = self.calculate_samples_and_epochs() self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) # set-up validation dataset and calculate number of batches for calculating validation loss - self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val', datasplit_config=self.datasplit_dict, + self.val_dataset = VideoDataset(input_dir=self.input_dir, output_dir = self.output_dir, mode='val', datasplit_config=self.datasplit_dict, hparams_dict_config=self.model_hparams_dict, nsamples_ref=self.num_examples) # Retrieve sequence length from dataset self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length}) @@ -175,7 +175,7 @@ class TrainModel(object): :param mode: "train" used the model graph in train process; "test" for postprocessing step """ VideoPredictionModel = models.get_model_class(self.model) - self.video_model = VideoPredictionModel(hparams_dict=self.model_hparams_dict, mode=mode) + self.video_model = VideoPredictionModel(hparams_dict=self.model_hparams_dict_load, mode=mode) def setup_graph(self): """ @@ -201,8 +201,8 @@ class TrainModel(object): self.inputs = self.iterator.get_next() # since era5 tfrecords include T_start, we need to remove it from the tfrecord when we train SAVP # Otherwise an error will be risen by SAVP - if self.dataset == "era5" and self.model == "savp": - del self.inputs["T_start"] + if self.dataset == "era5": + self.inputs = self.inputs def save_dataset_model_params_to_checkpoint_dir(self, dataset, video_model): """ diff --git a/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py b/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py index 7393911145db16798d3a8acdf3ff30870a32e0ad..0c527b2b0f403714bfc4020e2778d569abb6ef8d 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/our_base_model.py @@ -14,10 +14,8 @@ import tensorflow as tf class BaseModels(ABC): - def __init__(self, hparams_dict_config=None): - self.hparams_dict_config = hparams_dict_config - self.hparams_dict = self.get_model_hparams_dict() - self.hparams = self.parse_hparams() + def __init__(self, hparams_dict_config:dict = None): + self.hparams = self.parse_hparams(hparams_dict_config) # Attributes set during runtime self.total_loss = None self.loss_summary = None @@ -33,24 +31,12 @@ class BaseModels(ABC): self.x_hat_predict_frames = None - def get_model_hparams_dict(self): - """ - Get model_hparams_dict from json file - """ - if self.hparams_dict_config: - with open(self.hparams_dict_config, 'r') as f: - hparams_dict = json.loads(f.read()) - else: - raise FileNotFoundError("hparam directory doesn't exist! please check {}!".format(self.hparams_dict_config)) - return hparams_dict - - def parse_hparams(self): + def parse_hparams(self,hparams_dict_config): """ Obtain the parameters from directory """ - - hparams = dotdict(self.hparams_dict) + hparams = dotdict(hparams_dict_config) return hparams @abstractmethod