Skip to content
Snippets Groups Projects
Commit 4416932e authored by gong1's avatar gong1
Browse files

Add back get_dataset_class and update the sequence_length from dataset object

parent 3eb3c9db
No related branches found
No related tags found
No related merge requests found
Pipeline #70724 passed
......@@ -139,12 +139,13 @@ class TrainModel(object):
Setup train and val dataset instance with the corresponding data split configuration.
Simultaneously, sequence_length is attached to the hyperparameter dictionary.
"""
VideoDataset = datasets.get_dataset_class(self.dataset)
self.train_dataset = VideoDataset(input_dir=self.input_dir, mode='train', datasplit_config=self.datasplit_dict,
hparams_dict_config=self.model_hparams_dict)
self.val_dataset = VideoDataset(input_dir=self.input_dir, mode='val', datasplit_config=self.datasplit_dict,
hparams_dict_config=self.model_hparams_dict)
# ML/BG 2021-06-15: Is the following needed?
# self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length})
self.model_hparams_dict_load.update({"sequence_length": self.train_dataset.sequence_length})
def setup_model(self):
"""
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment