Skip to content
GitLab
Projects
Groups
Snippets
/
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
esde
machine-learning
AMBS
Commits
f37acc99
Commit
f37acc99
authored
Oct 18, 2022
by
Bing Gong
Browse files
update main_train_modles.py
parent
c759d326
Pipeline
#115021
failed with stages
in 1 minute and 8 seconds
Changes
2
Pipelines
1
Hide whitespace changes
Inline
Side-by-side
video_prediction_tools/main_scripts/main_train_models.py
View file @
f37acc99
...
...
@@ -159,12 +159,12 @@ class TrainModel(object):
# create dataset instance
VideoDataset
=
datasets
.
get_dataset_class
(
self
.
dataset
)
self
.
train_dataset
=
VideoDataset
(
input_dir
=
self
.
input_dir
,
mode
=
'train'
,
datasplit_config
=
self
.
datasplit_dict
,
self
.
train_dataset
=
VideoDataset
(
input_dir
=
self
.
input_dir
,
output_dir
=
self
.
output_dir
,
mode
=
'train'
,
datasplit_config
=
self
.
datasplit_dict
,
hparams_dict_config
=
self
.
model_hparams_dict
)
self
.
calculate_samples_and_epochs
()
self
.
num_examples
=
self
.
calculate_samples_and_epochs
()
self
.
model_hparams_dict_load
.
update
({
"sequence_length"
:
self
.
train_dataset
.
sequence_length
})
# set-up validation dataset and calculate number of batches for calculating validation loss
self
.
val_dataset
=
VideoDataset
(
input_dir
=
self
.
input_dir
,
mode
=
'val'
,
datasplit_config
=
self
.
datasplit_dict
,
self
.
val_dataset
=
VideoDataset
(
input_dir
=
self
.
input_dir
,
output_dir
=
self
.
output_dir
,
mode
=
'val'
,
datasplit_config
=
self
.
datasplit_dict
,
hparams_dict_config
=
self
.
model_hparams_dict
,
nsamples_ref
=
self
.
num_examples
)
# Retrieve sequence length from dataset
self
.
model_hparams_dict_load
.
update
({
"sequence_length"
:
self
.
train_dataset
.
sequence_length
})
...
...
@@ -175,7 +175,7 @@ class TrainModel(object):
:param mode: "train" used the model graph in train process; "test" for postprocessing step
"""
VideoPredictionModel
=
models
.
get_model_class
(
self
.
model
)
self
.
video_model
=
VideoPredictionModel
(
hparams_dict
=
self
.
model_hparams_dict
,
mode
=
mode
)
self
.
video_model
=
VideoPredictionModel
(
hparams_dict
=
self
.
model_hparams_dict
_load
,
mode
=
mode
)
def
setup_graph
(
self
):
"""
...
...
@@ -201,8 +201,8 @@ class TrainModel(object):
self
.
inputs
=
self
.
iterator
.
get_next
()
# since era5 tfrecords include T_start, we need to remove it from the tfrecord when we train SAVP
# Otherwise an error will be risen by SAVP
if
self
.
dataset
==
"era5"
and
self
.
model
==
"savp"
:
d
el
self
.
inputs
[
"T_start"
]
if
self
.
dataset
==
"era5"
:
s
el
f
.
inputs
=
self
.
inputs
def
save_dataset_model_params_to_checkpoint_dir
(
self
,
dataset
,
video_model
):
"""
...
...
video_prediction_tools/model_modules/video_prediction/models/our_base_model.py
View file @
f37acc99
...
...
@@ -14,10 +14,8 @@ import tensorflow as tf
class
BaseModels
(
ABC
):
def
__init__
(
self
,
hparams_dict_config
=
None
):
self
.
hparams_dict_config
=
hparams_dict_config
self
.
hparams_dict
=
self
.
get_model_hparams_dict
()
self
.
hparams
=
self
.
parse_hparams
()
def
__init__
(
self
,
hparams_dict_config
:
dict
=
None
):
self
.
hparams
=
self
.
parse_hparams
(
hparams_dict_config
)
# Attributes set during runtime
self
.
total_loss
=
None
self
.
loss_summary
=
None
...
...
@@ -33,24 +31,12 @@ class BaseModels(ABC):
self
.
x_hat_predict_frames
=
None
def
get_model_hparams_dict
(
self
):
"""
Get model_hparams_dict from json file
"""
if
self
.
hparams_dict_config
:
with
open
(
self
.
hparams_dict_config
,
'r'
)
as
f
:
hparams_dict
=
json
.
loads
(
f
.
read
())
else
:
raise
FileNotFoundError
(
"hparam directory doesn't exist! please check {}!"
.
format
(
self
.
hparams_dict_config
))
return
hparams_dict
def
parse_hparams
(
self
):
def
parse_hparams
(
self
,
hparams_dict_config
):
"""
Obtain the parameters from directory
"""
hparams
=
dotdict
(
self
.
hparams_dict
)
hparams
=
dotdict
(
hparams_dict_config
)
return
hparams
@
abstractmethod
...
...
Write
Preview
Supports
Markdown
0%
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment