softmotion_dataset.py
    import itertools
    import os
    import re
    
    import tensorflow as tf
    
    from video_prediction.utils import tf_utils
    from .base_dataset import VideoDataset
    
    
    class SoftmotionVideoDataset(VideoDataset):
        """
        https://sites.google.com/view/sna-visual-mpc
        """
        def __init__(self, *args, **kwargs):
            super(SoftmotionVideoDataset, self).__init__(*args, **kwargs)
            # infer name of image feature and check if object_pos feature is present
            from google.protobuf.json_format import MessageToDict
            example = next(tf.python_io.tf_record_iterator(self.filenames[0]))
            dict_message = MessageToDict(tf.train.Example.FromString(example))
            feature = dict_message['features']['feature']
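            # feature keys follow a '<time index>/<name>' pattern implied by the
            # regexes below, e.g. '0/image_aux1/encoded', '0/endeffector_pos', '0/action'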
            image_names = set()
            for name in feature.keys():
                m = re.search(r'\d+/(\w+)/encoded', name)
                if m:
                    image_names.add(m.group(1))
            # look for image_aux1 and image_view0 in that order of priority
            image_name = None
            for name in ['image_aux1', 'image_view0']:
                if name in image_names:
                    image_name = name
                    break
            if not image_name:
                if len(image_names) == 1:
                    image_name = image_names.pop()
                else:
                    raise ValueError('The examples have images under more than one name: %r.' % sorted(image_names))
            # the '%%d' escape leaves a literal '%d' placeholder for the time index
            self.state_like_names_and_shapes['images'] = '%%d/%s/encoded' % image_name, None
            if self.hparams.use_state:
                self.state_like_names_and_shapes['states'] = '%d/endeffector_pos', (3,)
                self.action_like_names_and_shapes['actions'] = '%d/action', (4,)
            if any(re.search(r'\d+/object_pos', name) for name in feature.keys()):
                self.state_like_names_and_shapes['object_pos'] = '%d/object_pos', None  # shape is (2 * num_designated_pixels)
            self._check_or_infer_shapes()
    
        def get_default_hparams_dict(self):
            default_hparams = super(SoftmotionVideoDataset, self).get_default_hparams_dict()
            hparams = dict(
                context_frames=2,
                sequence_length=12,
                time_shift=2,
            )
            # later entries win in dict(itertools.chain(...)), so these subclass
            # values override the base-class defaults
            return dict(itertools.chain(default_hparams.items(), hparams.items()))
    
        @property
        def jpeg_encoding(self):
            # the 'encoded' image features hold raw bytes rather than JPEG strings
            return False
    
        def parser(self, serialized_example):
            state_like_seqs, action_like_seqs = super(SoftmotionVideoDataset, self).parser(serialized_example)
            if 'object_pos' in state_like_seqs:
                object_pos = state_like_seqs['object_pos']
                height, width, _ = self.state_like_names_and_shapes['images'][1]
                # reshape the flat (2 * num_designated_pixels) vectors into
                # per-pixel 2-d coordinates
                object_pos = tf.reshape(object_pos, [object_pos.shape[0].value, -1, 2])
                # build one spatial distribution per designated pixel and stack
                # them along the last (channel) axis
                pix_distribs = tf.stack([tf_utils.pixel_distribution(object_pos_, height, width)
                                         for object_pos_ in tf.unstack(object_pos, axis=1)], axis=-1)
                state_like_seqs['pix_distribs'] = pix_distribs
            return state_like_seqs, action_like_seqs
    
        def num_examples_per_epoch(self):
            # extract information from filename to count the number of trajectories in the dataset
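            # e.g. a shard named 'traj_256_to_511.tfrecords' contributes
            # 511 - 256 + 1 = 256 trajectories (illustrative numbers)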
            count = 0
            for filename in self.filenames:
                match = re.search(r'traj_(\d+)_to_(\d+)\.tfrecords', os.path.basename(filename))
                if match is None:
                    raise ValueError('Could not infer the trajectory count from filename %s.' % filename)
                start_traj_iter = int(match.group(1))
                end_traj_iter = int(match.group(2))
                count += end_traj_iter - start_traj_iter + 1
    
            # alternatively, the dataset size can be determined like this, but it's very slow
            # count = sum(sum(1 for _ in tf.python_io.tf_record_iterator(filename)) for filename in filenames)
            return count
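
For reference, a minimal usage sketch. The input path is hypothetical, and the constructor arguments and the make_batch call are assumptions based on the base VideoDataset class in this codebase:

    # Hedged usage sketch: the path is hypothetical; the constructor and
    # make_batch signatures are assumed from the base VideoDataset class.
    dataset = SoftmotionVideoDataset('data/softmotion30_44k/train', mode='train')
    print('trajectories per epoch:', dataset.num_examples_per_epoch())
    inputs = dataset.make_batch(batch_size=16)

And a sketch of what a pixel-distribution helper might compute. The real implementation lives in video_prediction.utils.tf_utils; this one-hot variant (rounding instead of interpolating) is an assumption:

    # Hypothetical stand-in for tf_utils.pixel_distribution: maps (T, 2) float
    # coordinates to a (T, height, width) map that is 1 at the rounded
    # coordinate and 0 elsewhere.
    def pixel_distribution_sketch(pos, height, width):
        pos = tf.cast(tf.round(pos), tf.int32)
        pos = tf.maximum(pos, 0)
        pos = tf.minimum(pos, tf.constant([height - 1, width - 1]))
        num_steps = tf.shape(pos)[0]
        indices = tf.concat([tf.expand_dims(tf.range(num_steps), 1), pos], axis=1)
        return tf.scatter_nd(indices, tf.ones([num_steps]), [num_steps, height, width])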