Skip to content
Snippets Groups Projects
Commit c02129ae authored by Michael Langguth's avatar Michael Langguth
Browse files

Adopt the models GAN and convLSTM_GAN to use automatically retreived...

Adopt the models GAN and convLSTM_GAN to use automatically retreived sequence_length and add a consistency check for predict_frames-attribute.
parent 25f3b4ba
No related branches found
No related tags found
No related merge requests found
Pipeline #68253 passed
......@@ -2,21 +2,8 @@ __email__ = "b.gong@fz-juelich.de"
__author__ = "Bing Gong,Yanji"
__date__ = "2021-04-13"
import collections
import functools
import itertools
from collections import OrderedDict
import numpy as np
from model_helpers import set_and_check_pred_frames
import tensorflow as tf
from tensorflow.python.util import nest
from model_modules.video_prediction import ops, flow_ops
from model_modules.video_prediction.models import BaseVideoPredictionModel
from model_modules.video_prediction.models import networks
from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
from model_modules.video_prediction.utils import tf_utils
from datetime import datetime
from pathlib import Path
from model_modules.video_prediction.layers import layer_def as ld
from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
from tensorflow.contrib.training import HParams
......@@ -53,7 +40,7 @@ class ConvLstmGANVideoPredictionModel(object):
self.total_loss = None
self.context_frames = self.hparams.context_frames
self.sequence_length = self.hparams.sequence_length
self.predict_frames = self.sequence_length - self.context_frames
self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames)
self.max_epochs = self.hparams.max_epochs
self.loss_fun = self.hparams.loss_fun
self.batch_size = self.hparams.batch_size
......
......@@ -10,23 +10,11 @@ This code implement take the following as references:
2) cousera GAN courses
3) https://github.com/hwalsuklee/tensorflow-generative-model-collections/blob/master/GAN.py
"""
import collections
import functools
import itertools
from collections import OrderedDict
import numpy as np
import tensorflow as tf
from tensorflow.python.util import nest
from model_modules.video_prediction import ops, flow_ops
from model_modules.video_prediction.models import BaseVideoPredictionModel
from model_modules.video_prediction.models import networks
from model_modules.video_prediction.ops import dense, pad2d, conv2d, flatten, tile_concat
from model_modules.video_prediction.rnn_ops import BasicConv2DLSTMCell, Conv2DGRUCell
from model_modules.video_prediction.utils import tf_utils
from datetime import datetime
from pathlib import Path
from model_helpers import set_and_check_pred_frames
from model_modules.video_prediction.layers import layer_def as ld
from model_modules.video_prediction.layers.BasicConvLSTMCell import BasicConvLSTMCell
from tensorflow.contrib.training import HParams
class VanillaGANVideoPredictionModel(object):
......@@ -44,7 +32,7 @@ class VanillaGANVideoPredictionModel(object):
self.total_loss = None
self.context_frames = self.hparams.context_frames
self.sequence_length = self.hparams.sequence_length
self.predict_frames = self.sequence_length - self.context_frames
self.predict_frames = set_and_check_pred_frames(self.sequence_length, self.context_frames)
self.max_epochs = self.hparams.max_epochs
self.loss_fun = self.hparams.loss_fun
self.batch_size = self.hparams.batch_size
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment