diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py b/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py index 3d2bd6d7462320baaa6bc8f5414506491d6d0e1f..adfdf06539174a78be375fa1f2416dade078a1c5 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/moving_mnist.py @@ -29,6 +29,7 @@ class MovingMnist(object): self.input_dir = input_dir self.mode = mode self.seed = seed + self.sequence_length = None # will be set in get_example_info if self.mode not in ('train', 'val', 'test'): raise ValueError('Invalid mode %s' % self.mode) if not os.path.exists(self.input_dir): @@ -41,8 +42,6 @@ class MovingMnist(object): self.get_tfrecords_filename_base_datasplit() self.get_example_info() - - def get_datasplit(self): """ Get the datasplit json file @@ -51,8 +50,6 @@ class MovingMnist(object): self.d = json.load(f) return self.d - - def get_model_hparams_dict(self): """ Get model_hparams_dict from json file @@ -142,21 +139,23 @@ class MovingMnist(object): dict_message = MessageToDict(tf.train.Example.FromString(example)) feature = dict_message['features']['feature'] print("features in dataset:",feature.keys()) - self.video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', 'width', 'channels']) - self.image_shape = self.video_shape[1:] - + video_shape = tuple(int(feature[key]['int64List']['value'][0]) for key in ['sequence_length','height', + 'width', 'channels']) + self.sequence_length = video_shape[0] + self.image_shape = video_shape[1:] def num_examples_per_epoch(self): """ Calculate how many tfrecords samples in the train/val/test """ - #count how many tfrecords files for train/val/testing + # count how many tfrecords files for train/val/testing len_fnames = len(self.filenames) - seq_len_file = os.path.join(self.input_dir, 'number_sequences.txt') - with open(seq_len_file, 'r') as sequence_lengths_file: - sequence_lengths = sequence_lengths_file.readlines() - sequence_lengths = [int(sequence_length.strip()) for sequence_length in sequence_lengths] - self.num_examples_per_epoch = len_fnames * sequence_lengths[0] + num_seq_file = os.path.join(self.input_dir, 'number_sequences.txt') + with open(num_seq_file, 'r') as dfile: + num_seqs = dfile.readlines() + num_sequences = [int(num_seq.strip()) for num_seq in num_seqs] + self.num_examples_per_epoch = len_fnames * num_sequences[0] + return self.num_examples_per_epoch @@ -181,8 +180,10 @@ class MovingMnist(object): } parsed_features = tf.parse_single_example(serialized_example, keys_to_features) seq = tf.sparse_tensor_to_dense(parsed_features["images/encoded"]) - print("Image shape {}, {},{},{}".format(self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2])) - images = tf.reshape(seq, [self.video_shape[0],self.image_shape[0],self.image_shape[1], self.image_shape[2]], name = "reshape_new") + print("Image shape {}, {},{},{}".format(self.sequence_length,self.image_shape[0],self.image_shape[1], + self.image_shape[2])) + images = tf.reshape(seq, [self.sequence_length,self.image_shape[0],self.image_shape[1], + self.image_shape[2]], name = "reshape_new") seqs["images"] = images return seqs filenames = self.filenames