diff --git a/video_prediction/datasets/base_dataset.py b/video_prediction/datasets/base_dataset.py
index e66cf5206c77076927ac518e4d74127b363fee23..cc69d987bc1842a61eebbfa3e8a054517758c69d 100644
--- a/video_prediction/datasets/base_dataset.py
+++ b/video_prediction/datasets/base_dataset.py
@@ -245,7 +245,7 @@ class VideoDataset(BaseVideoDataset):
                                              list(action_like_names_and_shapes.items())):
             name, shape = name_and_shape
             feature = self._dict_message['features']['feature']
-            names = [name_ for name_ in feature.keys() if re.search(name.replace('%d', '(\d+)'), name_) is not None]
+            names = [name_ for name_ in feature.keys() if re.search(name.replace('%d', '\d+'), name_) is not None]
             if not names:
                 raise ValueError('Could not found any feature with name pattern %s.' % name)
             if example_name in self.state_like_names_and_shapes:
diff --git a/video_prediction/datasets/softmotion_dataset.py b/video_prediction/datasets/softmotion_dataset.py
index f12e330119fd583a8a26433183cb67e46b0dd6a0..286acd42629f2a25167f7439fadf0c0b068254d2 100644
--- a/video_prediction/datasets/softmotion_dataset.py
+++ b/video_prediction/datasets/softmotion_dataset.py
@@ -14,14 +14,32 @@ class SoftmotionVideoDataset(VideoDataset):
     """
     def __init__(self, *args, **kwargs):
         super(SoftmotionVideoDataset, self).__init__(*args, **kwargs)
-        if 'softmotion30_44k' in self.input_dir.split('/'):
-            self.state_like_names_and_shapes['images'] = '%d/image_aux1/encoded', None
-        else:
-            self.state_like_names_and_shapes['images'] = '%d/image_view0/encoded', None
+        # infer name of image feature and check if object_pos feature is present
+        from google.protobuf.json_format import MessageToDict
+        example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
+        dict_message = MessageToDict(tf.train.Example.FromString(example))
+        feature = dict_message['features']['feature']
+        image_names = set()
+        for name in feature.keys():
+            m = re.search('\d+/(\w+)/encoded', name)
+            if m:
+                image_names.add(m.group(1))
+        # look for image_aux1 and image_view0 in that order of priority
+        image_name = None
+        for name in ['image_aux1', 'image_view0']:
+            if name in image_names:
+                image_name = name
+                break
+        if not image_name:
+            if len(image_names) == 1:
+                image_name = image_names.pop()
+            else:
+                raise ValueError('The examples have images under more than one name.')
+        self.state_like_names_and_shapes['images'] = '%%d/%s/encoded' % image_name, None
         if self.hparams.use_state:
             self.state_like_names_and_shapes['states'] = '%d/endeffector_pos', (3,)
         self.action_like_names_and_shapes['actions'] = '%d/action', (4,)
-        if os.path.basename(self.input_dir).endswith('annotations'):
+        if any([re.search('\d+/object_pos', name) for name in feature.keys()]):
             self.state_like_names_and_shapes['object_pos'] = '%d/object_pos', None  # shape is (2 * num_designated_pixels)
         self._check_or_infer_shapes()
 