diff --git a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py index 3ab3a423a001e903ec4ca9fe1bd7ec78e18dc731..972f95576a90563241f6eede2c4d973115092d06 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py @@ -2,21 +2,8 @@ __email__ = "b.gong@fz-juelich.de" __author__ = "Bing Gong,Yanji" __date__ = "2021-04-13" -import collections -import functools -import itertools -from collections import OrderedDict -import numpy as np +from model_helpers import set_and_check_pred_frames import tensorflow as tf -from tensorflow.python.util import nest -from model_modules.video_prediction import ops, flow_ops -from model_modules.video_prediction.models import BaseVideoPredictionModel -from model_modules.video_prediction.models import networks -from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat -from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell -from model_modules.video_prediction.utils import tf_utils -from datetime import datetime -from pathlib import Path from model_modules.video_prediction.layers import layer_def as ld from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell from tensorflow.contrib.training import HParams @@ -53,7 +40,7 @@ class ConvLstmGANVideoPredictionModel(object): self.total_loss = None self.context_frames = self.hparams.context_frames self.sequence_length = self.hparams.sequence_length - self.predict_frames = self.sequence_length - self.context_frames + self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) self.max_epochs = self.hparams.max_epochs self.loss_fun = self.hparams.loss_fun self.batch_size = self.hparams.batch_size diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py index e0b0d61edcc2464492fbd00e733ff4ce0130c04a..747e19334e37d001439529f7a449a69b1a3a56f8 100644 --- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py +++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py @@ -10,23 +10,11 @@ This code implement take the following as references: 2) cousera GAN courses 3) https://github.com/hwalsuklee/tensorflow-generative-model-collections/blob/master/GAN.py """ -import collections -import functools -import itertools -from collections import OrderedDict -import numpy as np + import tensorflow as tf -from tensorflow.python.util import nest -from model_modules.video_prediction import ops, flow_ops -from model_modules.video_prediction.models import BaseVideoPredictionModel -from model_modules.video_prediction.models import networks -from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat -from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell -from model_modules.video_prediction.utils import tf_utils -from datetime import datetime -from pathlib import Path + +from model_helpers import set_and_check_pred_frames from model_modules.video_prediction.layers import layer_def as ld -from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell from tensorflow.contrib.training import HParams class VanillaGANVideoPredictionModel(object): @@ -44,11 +32,11 @@ class VanillaGANVideoPredictionModel(object): self.total_loss = None self.context_frames = self.hparams.context_frames self.sequence_length = self.hparams.sequence_length - self.predict_frames = self.sequence_length - self.context_frames + self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames) self.max_epochs = self.hparams.max_epochs self.loss_fun = self.hparams.loss_fun self.batch_size = self.hparams.batch_size - self.z_dim = self.hparams.z_dim #dim of noise-vector + self.z_dim = self.hparams.z_dim # dim of noise-vector def get_default_hparams(self): return HParams(**self.get_default_hparams_dict())