diff --git a/video_prediction_tools/main_scripts/main_predict.py b/video_prediction_tools/main_scripts/main_predict.py index 76dbc7c5347211bb21b5e70ddc44e3e22ba8742d..5c148ed4ed0113b32d9f63c9ab72899d7c9d878e 100644 --- a/video_prediction_tools/main_scripts/main_predict.py +++ b/video_prediction_tools/main_scripts/main_predict.py @@ -80,7 +80,7 @@ class Postprocess(TrainModel): self.run_mode = run_mode self.data_mode = data_mode #self.channel = channel - #self.lquick = lquick + self.lquick = lquick #self.frac_data = frac_data # Attributes set during runtime #self.norm_cls = None @@ -116,6 +116,9 @@ class Postprocess(TrainModel): #self.cond_quantile_vars = self.init_cond_quantile_vars() # setup test dataset and model self.test_dataset, self.num_samples_per_epoch = self.setup_dataset() + self.lats = self.test_dataset.lats + self.lons = self.test_dataset.lons + self.vars_in = self.test_dataset.variables # if lquick and self.test_dataset.shuffled: # self.num_samples_per_epoch = Postprocess.reduce_samples(self.num_samples_per_epoch, frac_data) # self.num_samples_per_epoch = 100 # reduced number of epoch samples -> useful for testing @@ -569,12 +572,11 @@ class Postprocess(TrainModel): input_results = self.sess.run(self.input_iter) t_starts = self.sess.run(self.ts_iter) - print('self.input_iter: {}'.format(self.input_iter)) - print('input_results: {}'.format(input_results)) + # print('self.input_iter: {}'.format(self.input_iter.items())) # t_starts = input_results["T_start"] # feed_dict = {input_ph: input_results[name] for name, input_ph in self.input_iter.items()} - feed_dict = {"x": input_results} + feed_dict = {"IteratorGetNext:0": input_results} gen_images = self.sess.run(self.video_model.outputs['gen_images'], feed_dict=feed_dict) # sanity check on length of forecast sequence @@ -679,20 +681,20 @@ class Postprocess(TrainModel): """ method = Postprocess.get_init_time.__name__ - t_starts = np.squeeze(np.asarray(t_starts)) - if not np.ndim(t_starts) == 1: - raise ValueError("%{0}: Inputted t_starts must be a 1D list/array of date-strings with format %Y%m%d%H" - .format(method)) + #t_starts = np.squeeze(np.asarray(t_starts)) + #print('t_starts: {}'.format(t_starts)) + #if not np.ndim(t_starts) == 1: + # raise ValueError("%{0}: Inputted t_starts must be a 1D list/array of date-strings with format %Y%m%d%H" + # .format(method)) for i, t_start in enumerate(t_starts): try: - #seq_ts = pd.date_range(dt.datetime.strptime(str(t_start), "%Y%m%d%H%M"), periods=self.context_frames, - # freq="10min") - print('t_start: ',t_start) - t0 = pd.date_range(dt.datetime.strptime(str(t_start), "%Y%m%d%H%M"), periods=4, - freq="-10min") - t1 = pd.date_range(dt.datetime.strptime(str(t_start), "%Y%m%d%H%M"),periods=self.context_frames-3, - freq="10min") - seq_ts = t0.append(t1)[1:] + seq_ts = pd.date_range(dt.datetime.strptime(str(t_start[0])[2:-1], "%Y-%m-%dT%H:%M:00"), periods=self.context_frames, + freq="10min") + #t0 = pd.date_range(dt.datetime.strptime(str(t_start), "%Y-%m-%dT%H:%M:00"), periods=4, + # freq="-10min") + #t1 = pd.date_range(dt.datetime.strptime(str(t_start), "%Y-%m-%dT%H:%M:00"),periods=self.context_frames-3, + # freq="10min") + #seq_ts = t0.append(t1)[1:] print('seq_ts: ',seq_ts) except Exception as err: print("%{0}: Could not convert {1} to datetime object. Ensure that the date-string format is 'Y%m%d%H'". diff --git a/video_prediction_tools/model_modules/video_prediction/datasets/gzaws_dataset.py b/video_prediction_tools/model_modules/video_prediction/datasets/gzaws_dataset.py index e4106d3900fb460c0ee6919e2f96cf9260c53bd0..ba37513e3b6d92aec1738e0ff26d7f21ddef5793 100644 --- a/video_prediction_tools/model_modules/video_prediction/datasets/gzaws_dataset.py +++ b/video_prediction_tools/model_modules/video_prediction/datasets/gzaws_dataset.py @@ -85,6 +85,9 @@ class GZawsDataset(BaseDataset): self.nlon = len(ds["lon"].values) # .values[1:-3] self.n_samples = data_arr.shape[0] self.n_vars = len(self.variables) + + self.lons = ds["lon"].values + self.lats = ds["lat"].values return data_arr, init_times