Skip to content
Snippets Groups Projects
Select Git revision
  • ad302db1fb194e46731af79dca057e881c73e965
  • master default protected
  • enxhi_issue460_remove_TOAR-I_access
  • michael_issue459_preprocess_german_stations
  • sh_pollutants
  • develop protected
  • release_v2.4.0
  • michael_issue450_feat_load-ifs-data
  • lukas_issue457_feat_set-config-paths-as-parameter
  • lukas_issue454_feat_use-toar-statistics-api-v2
  • lukas_issue453_refac_advanced-retry-strategy
  • lukas_issue452_bug_update-proj-version
  • lukas_issue449_refac_load-era5-data-from-toar-db
  • lukas_issue451_feat_robust-apriori-estimate-for-short-timeseries
  • lukas_issue448_feat_load-model-from-path
  • lukas_issue447_feat_store-and-load-local-clim-apriori-data
  • lukas_issue445_feat_data-insight-plot-monthly-distribution
  • lukas_issue442_feat_bias-free-evaluation
  • lukas_issue444_feat_choose-interp-method-cams
  • 414-include-crps-analysis-and-other-ens-verif-methods-or-plots
  • lukas_issue384_feat_aqw-data-handler
  • v2.4.0 protected
  • v2.3.0 protected
  • v2.2.0 protected
  • v2.1.0 protected
  • Kleinert_etal_2022_initial_submission
  • v2.0.0 protected
  • v1.5.0 protected
  • v1.4.0 protected
  • v1.3.0 protected
  • v1.2.1 protected
  • v1.2.0 protected
  • v1.1.0 protected
  • IntelliO3-ts-v1.0_R1-submit
  • v1.0.0 protected
  • v0.12.2 protected
  • v0.12.1 protected
  • v0.12.0 protected
  • v0.11.0 protected
  • v0.10.0 protected
  • IntelliO3-ts-v1.0_initial-submit
41 results

inception_model.py

Blame
  • kth_dataset.py 4.19 KiB
    import argparse
    import glob
    import itertools
    import os
    import random
    
    import cv2
    import numpy as np
    import tensorflow as tf
    
    from video_prediction.datasets.base_dataset import VarLenFeatureVideoDataset
    
    
    class KTHVideoDataset(VarLenFeatureVideoDataset):
        """Variable-length video dataset for KTH clips stored in TFRecord files.

        Frames are read from the ``'images/encoded'`` feature and are declared
        to have shape ``(64, 64, 3)``.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)
            # Register the 'images' state: (serialized feature key, frame shape).
            feature_key, frame_shape = 'images/encoded', (64, 64, 3)
            self.state_like_names_and_shapes['images'] = feature_key, frame_shape

        def get_default_hparams_dict(self):
            """Return the parent's default hparams with KTH overrides applied."""
            inherited = super().get_default_hparams_dict()
            merged = dict(inherited.items())
            # Dataset-specific values take precedence over the inherited ones.
            merged.update(dict(
                context_frames=10,
                sequence_length=20,
            ).items())
            return merged

        @property
        def jpeg_encoding(self):
            """Always ``False`` for this dataset."""
            return False

        def num_examples_per_epoch(self):
            """Return the number of entries in ``self.filenames``."""
            return len(self.filenames)
    
    
    def _bytes_feature(value):
        """Wrap a single bytes value in a ``tf.train.Feature``."""
        payload = tf.train.BytesList(value=[value])
        return tf.train.Feature(bytes_list=payload)
    
    
    def _bytes_list_feature(values):
        """Wrap a sequence of bytes values in a ``tf.train.Feature``."""
        payload = tf.train.BytesList(value=values)
        return tf.train.Feature(bytes_list=payload)
    
    
    def _int64_feature(value):
        """Wrap a single integer value in a ``tf.train.Feature``."""
        payload = tf.train.Int64List(value=[value])
        return tf.train.Feature(int64_list=payload)
    
    
    def partition_data(input_dir):
        """Split the ``.avi`` files under *input_dir* into train/val/test lists.

        Files are searched one directory level below *input_dir* and must be
        named ``person<ID>_...``. Persons with ID <= 16 go to the training pool,
        all others to the test set. The training pool is shuffled with the
        global ``random`` state and its last 5% is held out for validation.

        Args:
            input_dir: Directory containing one sub-directory per action class.

        Returns:
            A tuple ``(train_fnames, val_fnames, test_fnames)`` of path lists.

        Raises:
            ValueError: If a file name does not start with ``person<ID>``.
        """
        prefix = 'person'

        def person_id(fname):
            # Parse the ID from the file name only. Splitting the whole path on
            # '/person' breaks when a parent directory also starts with
            # 'person' and when the platform separator is not '/'.
            name = os.path.basename(fname)
            if not name.startswith(prefix):
                raise ValueError('cannot parse person id from file name: %s' % fname)
            return int(name[len(prefix):].split('_')[0])

        # glob order depends on the filesystem; sorting makes the split
        # reproducible whenever the caller seeds `random`.
        files = sorted(glob.glob(os.path.join(input_dir, '*/*.avi')))

        train_fnames, test_fnames = [], []
        for fname in files:
            if person_id(fname) <= 16:
                train_fnames.append(fname)
            else:
                test_fnames.append(fname)

        random.shuffle(train_fnames)

        pivot = int(0.95 * len(train_fnames))
        train_fnames, val_fnames = train_fnames[:pivot], train_fnames[pivot:]
        return train_fnames, val_fnames, test_fnames
    
    
    def read_video(fname):
        """Decode all frames of the video file *fname*.

        Args:
            fname: Path to the video file.

        Returns:
            A list with one image per frame, in the order ``cv2.VideoCapture``
            yields them. The list is empty if no frame could be read.

        Raises:
            FileNotFoundError: If *fname* is not an existing regular file.
        """
        if not os.path.isfile(fname):
            raise FileNotFoundError('video file not found: %s' % fname)
        vidcap = cv2.VideoCapture(fname)
        frames = []
        try:
            success, image = vidcap.read()
            while success:
                frames.append(image)
                success, image = vidcap.read()
        finally:
            # Release the capture even if decoding raises, so the underlying
            # file handle is not leaked across many videos.
            vidcap.release()
        return frames
    
    
    def save_tf_record(output_fname, sequences, preprocess_image):
        """Write each frame sequence as one ``tf.train.Example`` to *output_fname*.

        Args:
            output_fname: Path of the TFRecord file to create.
            sequences: Iterable of frame sequences; each frame exposes a
                three-element ``.shape``.
            preprocess_image: Callable applied to every frame; its results are
                stored as the ``'images/encoded'`` bytes list.
        """
        print('saving sequences to %s' % output_fname)
        with tf.python_io.TFRecordWriter(output_fname) as writer:
            for frames in sequences:
                frame_count = len(frames)
                # NOTE(review): the dimensions are read from the first frame
                # *before* preprocess_image is applied - confirm readers expect that.
                rows, cols, depth = frames[0].shape
                encoded_frames = list(map(preprocess_image, frames))
                feature_map = {
                    'sequence_length': _int64_feature(frame_count),
                    'height': _int64_feature(rows),
                    'width': _int64_feature(cols),
                    'channels': _int64_feature(depth),
                    'images/encoded': _bytes_list_feature(encoded_frames),
                }
                record = tf.train.Example(features=tf.train.Features(feature=feature_map))
                writer.write(record.SerializeToString())
    
    
    def read_videos_and_save_tf_records(output_dir, fnames):
        """Convert each video in *fnames* to its own TFRecord file in *output_dir*.

        The output file takes the video's base name with the extension replaced
        by ``.tfrecords`` and contains a single sequence holding all frames.

        Args:
            output_dir: Existing directory that receives the ``.tfrecords`` files.
            fnames: Iterable of video file paths.
        """
        def preprocess_image(image):
            # OpenCV decodes to BGR; store frames as RGB.
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # Drop 20 columns on each side, then scale to 64x64.
            # (presumably squares 160x120 KTH frames - TODO confirm)
            image = cv2.resize(image[:, 20:-20], (64, 64), interpolation=cv2.INTER_LINEAR)
            # tobytes() returns the same bytes as the deprecated tostring() alias.
            return image.tobytes()

        for fname in fnames:
            stem = os.path.splitext(os.path.basename(fname))[0]
            output_fname = os.path.join(output_dir, stem + '.tfrecords')
            sequence = read_video(fname)
            save_tf_record(output_fname, [sequence], preprocess_image)
    
    
    def main():
        """Command-line entry point: convert the KTH videos to TFRecord files.

        Parses ``input_dir`` and ``output_dir`` from the command line, partitions
        the videos into train/val/test, and writes each partition's records to
        ``<output_dir>/<partition>/``.
        """
        parser = argparse.ArgumentParser()
        parser.add_argument("input_dir", type=str, help="directory containing the directories "
                                                        "boxing, handclapping, handwaving, "
                                                        "jogging, running, walking")
        parser.add_argument("output_dir", type=str)
        args = parser.parse_args()

        partition_names = ['train', 'val', 'test']
        partitions = partition_data(args.input_dir)

        # A distinct loop variable avoids shadowing the tuple being iterated.
        for partition_name, fnames in zip(partition_names, partitions):
            partition_dir = os.path.join(args.output_dir, partition_name)
            # exist_ok=True already tolerates an existing directory, so no
            # separate existence check is needed.
            os.makedirs(partition_dir, exist_ok=True)

            read_videos_and_save_tf_records(partition_dir, fnames)
    
    
    # Run the conversion only when executed as a script, not when imported.
    if __name__ == '__main__':
        main()