diff --git a/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py b/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py index a946bd555a603fd9be14306929e0a8e722a24673..61fdb9121c3f127a0fb873df546c629fb679ff79 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/mcnet_model.py @@ -3,22 +3,13 @@ __author__ = "Bing Gong" __date__ = "2020-08-22" -import collections -import functools import itertools -from collections import OrderedDict import numpy as np import tensorflow as tf -from tensorflow.python.util import nest -from model_modules.video_prediction import ops, flow_ops + +from model_helpers import set_and_check_pred_frames from model_modules.video_prediction.models import BaseVideoPredictionModel -from model_modules.video_prediction.models import networks from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat -from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell -from model_modules.video_prediction.utils import tf_utils -from datetime import datetime -from pathlib import Path -from model_modules.video_prediction.layers import layer_def as ld from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell from model_modules.video_prediction.layers.mcnet_ops import * from model_modules.video_prediction.utils.mcnet_utils import * @@ -32,7 +23,7 @@ class McNetVideoPredictionModel(BaseVideoPredictionModel): self.lr = self.hparams.lr self.context_frames = self.hparams.context_frames self.sequence_length = self.hparams.sequence_length - self.predict_frames = self.sequence_length - self.context_frames + self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) self.df_dim = self.hparams.df_dim self.gf_dim = self.hparams.gf_dim self.alpha = self.hparams.alpha diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py index b9c76c5a7addef60725008f881168f99e9b0f8a4..58172bca0401cdc2b2a4353ac2aeee092d59774a 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_convLSTM_model.py @@ -2,21 +2,8 @@ __email__ = "b.gong@fz-juelich.de" __author__ = "Bing Gong, Scarlet Stadtler,Michael Langguth" __date__ = "2020-11-05" -import collections -import functools -import itertools -from collections import OrderedDict -import numpy as np +from model_helpers import set_and_check_pred_frames import tensorflow as tf -from tensorflow.python.util import nest -from model_modules.video_prediction import ops, flow_ops -from model_modules.video_prediction.models import BaseVideoPredictionModel -from model_modules.video_prediction.models import networks -from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat -from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell -from model_modules.video_prediction.utils import tf_utils -from datetime import datetime -from pathlib import Path from model_modules.video_prediction.layers import layer_def as ld from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell from tensorflow.contrib.training import HParams @@ -36,8 +23,7 @@ class VanillaConvLstmVideoPredictionModel(object): self.total_loss = None self.context_frames = self.hparams.context_frames self.sequence_length = self.hparams.sequence_length - self.predict_frames = VanillaConvLstmVideoPredictionModel.set_and_check_pred_frames(self.sequence_length, - self.context_frames) + self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) self.max_epochs = self.hparams.max_epochs self.loss_fun = self.hparams.loss_fun diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py index 986e3626fe0746c2c714ee9fa9a76ad873044415..bc4516d9e953689b58b5277e629b66d275881cfb 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_vae_model.py @@ -3,21 +3,8 @@ __email__ = "b.gong@fz-juelich.de" __author__ = "Bing Gong" __date__ = "2020-09-01" -import collections -import functools -import itertools -from collections import OrderedDict -import numpy as np +from model_helpers import set_and_check_pred_frames import tensorflow as tf -from tensorflow.python.util import nest -from model_modules.video_prediction import ops, flow_ops -from model_modules.video_prediction.models import BaseVideoPredictionModel -from model_modules.video_prediction.models import networks -from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat -from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell -from model_modules.video_prediction.utils import tf_utils -from datetime import datetime -from pathlib import Path from model_modules.video_prediction.layers import layer_def as ld from tensorflow.contrib.training import HParams @@ -37,7 +24,7 @@ class VanillaVAEVideoPredictionModel(object): self.total_loss = None self.context_frames = self.hparams.context_frames self.sequence_length = self.hparams.sequence_length - self.predict_frames = self.sequence_length - self.context_frames + self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) self.max_epochs = self.hparams.max_epochs self.nz = self.hparams.nz self.loss_fun = self.hparams.loss_fun