Skip to content
Snippets Groups Projects
Commit 2974626b authored by Alex Lee's avatar Alex Lee
Browse files

Remove assumptions of directory name for softmotion dataset.

parent a2f5b1b7
No related branches found
No related tags found
No related merge requests found
...@@ -245,7 +245,7 @@ class VideoDataset(BaseVideoDataset): ...@@ -245,7 +245,7 @@ class VideoDataset(BaseVideoDataset):
list(action_like_names_and_shapes.items())): list(action_like_names_and_shapes.items())):
name, shape = name_and_shape name, shape = name_and_shape
feature = self._dict_message['features']['feature'] feature = self._dict_message['features']['feature']
names = [name_ for name_ in feature.keys() if re.search(name.replace('%d', '(\d+)'), name_) is not None] names = [name_ for name_ in feature.keys() if re.search(name.replace('%d', '\d+'), name_) is not None]
if not names: if not names:
raise ValueError('Could not found any feature with name pattern %s.' % name) raise ValueError('Could not found any feature with name pattern %s.' % name)
if example_name in self.state_like_names_and_shapes: if example_name in self.state_like_names_and_shapes:
......
...@@ -14,14 +14,32 @@ class SoftmotionVideoDataset(VideoDataset): ...@@ -14,14 +14,32 @@ class SoftmotionVideoDataset(VideoDataset):
""" """
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super(SoftmotionVideoDataset, self).__init__(*args, **kwargs) super(SoftmotionVideoDataset, self).__init__(*args, **kwargs)
if 'softmotion30_44k' in self.input_dir.split('/'): # infer name of image feature and check if object_pos feature is present
self.state_like_names_and_shapes['images'] = '%d/image_aux1/encoded', None from google.protobuf.json_format import MessageToDict
example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
dict_message = MessageToDict(tf.train.Example.FromString(example))
feature = dict_message['features']['feature']
image_names = set()
for name in feature.keys():
m = re.search('\d+/(\w+)/encoded', name)
if m:
image_names.add(m.group(1))
# look for image_aux1 and image_view0 in that order of priority
image_name = None
for name in ['image_aux1', 'image_view0']:
if name in image_names:
image_name = name
break
if not image_name:
if len(image_names) == 1:
image_name = image_names.pop()
else: else:
self.state_like_names_and_shapes['images'] = '%d/image_view0/encoded', None raise ValueError('The examples have images under more than one name.')
self.state_like_names_and_shapes['images'] = '%%d/%s/encoded' % image_name, None
if self.hparams.use_state: if self.hparams.use_state:
self.state_like_names_and_shapes['states'] = '%d/endeffector_pos', (3,) self.state_like_names_and_shapes['states'] = '%d/endeffector_pos', (3,)
self.action_like_names_and_shapes['actions'] = '%d/action', (4,) self.action_like_names_and_shapes['actions'] = '%d/action', (4,)
if os.path.basename(self.input_dir).endswith('annotations'): if any([re.search('\d+/object_pos', name) for name in feature.keys()]):
self.state_like_names_and_shapes['object_pos'] = '%d/object_pos', None # shape is (2 * num_designated_pixels) self.state_like_names_and_shapes['object_pos'] = '%d/object_pos', None # shape is (2 * num_designated_pixels)
self._check_or_infer_shapes() self._check_or_infer_shapes()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment