diff --git a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py
index 3ab3a423a001e903ec4ca9fe1bd7ec78e18dc731..972f95576a90563241f6eede2c4d973115092d06 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/convLSTM_GAN_model.py
@@ -2,21 +2,8 @@ __email__ = "b.gong@fz-juelich.de"
 __author__ = "Bing Gong,Yanji"
 __date__ = "2021-04-13"
 
-import collections
-import functools
-import itertools
-from collections import OrderedDict
-import numpy as np
+from model_helpers import set_and_check_pred_frames
 import tensorflow as tf
-from tensorflow.python.util import nest
-from model_modules.video_prediction import ops, flow_ops
-from model_modules.video_prediction.models import BaseVideoPredictionModel
-from model_modules.video_prediction.models import networks
-from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
-from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
-from model_modules.video_prediction.utils import tf_utils
-from datetime import datetime
-from pathlib import Path
 from model_modules.video_prediction.layers import layer_def as ld
 from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
 from tensorflow.contrib.training import HParams
@@ -53,7 +40,7 @@ class ConvLstmGANVideoPredictionModel(object):
         self.total_loss = None
         self.context_frames = self.hparams.context_frames
         self.sequence_length = self.hparams.sequence_length
-        self.predict_frames = self.sequence_length - self.context_frames
+        self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames)
         self.max_epochs = self.hparams.max_epochs
         self.loss_fun = self.hparams.loss_fun
         self.batch_size = self.hparams.batch_size
diff --git a/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py b/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py
index e0b0d61edcc2464492fbd00e733ff4ce0130c04a..747e19334e37d001439529f7a449a69b1a3a56f8 100644
--- a/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py
+++ b/video_prediction_tools/model_modules/video_prediction/models/vanilla_GAN_model.py
@@ -10,23 +10,11 @@ This code implement take the following as references:
 2) cousera GAN courses
 3) https://github.com/hwalsuklee/tensorflow-generative-model-collections/blob/master/GAN.py
 """
-import collections
-import functools
-import itertools
-from collections import OrderedDict
-import numpy as np
+
 import tensorflow as tf
-from tensorflow.python.util import nest
-from model_modules.video_prediction import ops, flow_ops
-from model_modules.video_prediction.models import BaseVideoPredictionModel
-from model_modules.video_prediction.models import networks
-from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
-from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
-from model_modules.video_prediction.utils import tf_utils
-from datetime import datetime
-from pathlib import Path
+
+from model_helpers import set_and_check_pred_frames
 from model_modules.video_prediction.layers import layer_def as ld
-from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
 from tensorflow.contrib.training import HParams
 
 class VanillaGANVideoPredictionModel(object):
@@ -44,11 +32,11 @@ class VanillaGANVideoPredictionModel(object):
         self.total_loss = None
         self.context_frames = self.hparams.context_frames
         self.sequence_length = self.hparams.sequence_length
-        self.predict_frames = self.sequence_length - self.context_frames
+        self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames)
         self.max_epochs = self.hparams.max_epochs
         self.loss_fun = self.hparams.loss_fun
         self.batch_size = self.hparams.batch_size
-        self.z_dim = self.hparams.z_dim #dim of noise-vector
+        self.z_dim = self.hparams.z_dim  # dim of noise-vector
 
     def get_default_hparams(self):
         return HParams(**self.get_default_hparams_dict())