diff --git a/video_prediction_tools/data_preprocess/preprocess_data_step2.py b/video_prediction_tools/data_preprocess/preprocess_data_step2.py index c7df46397291288ff7f6c502158abd0b59889cfc..4170b20e58e29036d766010a7ffbe341e311d755 100644 --- a/video_prediction_tools/data_preprocess/preprocess_data_step2.py +++ b/video_prediction_tools/data_preprocess/preprocess_data_step2.py @@ -116,8 +116,10 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): """ sequences = np.array(sequences) # sanity checks + print(t_start_points[0]) + print(type(t_start_points[0])) assert sequences.shape[0] == len(t_start_points) - assert type(t_start_points[0]) == datetime.datetime + assert type(t_start_points) == datetime.datetime, "What's that: {0} (type {1})".format(t_start_points[0], type(t_start_points[0])) with tf.python_io.TFRecordWriter(output_fname) as writer: for i in range(len(sequences)): @@ -199,7 +201,7 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): X_end = X_start + self.sequence_length seq = X_train[X_start:X_end, ...] # recording the start point of the timestamps (already datetime-objects) - t_start = T_train[X_start] + t_start = ERA5Pkl2Tfrecords.ensure_datetime(T_train[X_start][0]) print("t_start,", t_start) print("type of t_starty", type(t_start)) seq = list(np.array(seq).reshape((self.sequence_length, self.height, self.width, self.nvars))) @@ -244,6 +246,24 @@ class ERA5Pkl2Tfrecords(ERA5Dataset): with open(os.path.join(self.output_dir, 'number_sequences.txt'), 'w') as seq_file: seq_file.write("%d\n" % self.sequences_per_file) + + @staticmethod + def ensure_datetime(date): + """ + Wrapper to return a datetime-object + """ + fmt = "%Y%m%d %H:%M" + if isinstance(date, datetime.datetime): + date_new = date + else: + try: + date_new=pd.to_datetime(date) + date_new=datetime.datetime(date_new.strptime(fmt), fmt) + except Exception as err: + print("%{0}: Could not handle input data {1} which is of type {2}.".format(method, date, type(date))) + raise err + + return date_new # def num_examples_per_epoch(self): # with open(os.path.join(self.input_dir, 'sequence_lengths.txt'), 'r') as sequence_lengths_file: # sequence_lengths = sequence_lengths_file.readlines()