From 54a21996a1c15d4f6c163ff07e91c55d0090dd54 Mon Sep 17 00:00:00 2001 From: Yan Ji <y.ji@fz-juelich.de> Date: Thu, 21 Jul 2022 17:48:19 +0200 Subject: [PATCH] update main_predict.py --- .../main_scripts/main_predict.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/video_prediction_tools/main_scripts/main_predict.py b/video_prediction_tools/main_scripts/main_predict.py index 5c148ed4..8dac1618 100644 --- a/video_prediction_tools/main_scripts/main_predict.py +++ b/video_prediction_tools/main_scripts/main_predict.py @@ -407,7 +407,10 @@ class Postprocess(TrainModel): test_handle = test_iterator.string_handle() dataset_iterator = tf.data.Iterator.from_string_handle(test_handle, test_tf_dataset.output_types, test_tf_dataset.output_shapes) - self.input_iter, self.ts_iter = dataset_iterator.get_next() + + self.inputs = dataset_iterator.get_next() + self.input_iter, self.ts_iter = self.inputs[0], self.inputs[1] + #self.input_iter, self.ts_iter = dataset_iterator.get_next() # ts_iter = input_iter["T_start"] # return input_iter, ts_iter @@ -558,6 +561,7 @@ class Postprocess(TrainModel): # init sample index for looping sample_ind = 0 nsamples = self.num_samples_per_epoch + print('test samples: {}'.format(nsamples)) # # initialize xarray datasets # eval_metric_ds = Postprocess.init_metric_ds(self.fcst_products, self.eval_metrics, self.vars_in[self.channel], # nsamples, self.future_length) @@ -569,9 +573,11 @@ class Postprocess(TrainModel): # feed and run the trained model; returned array has the shape [batchsize, seq_len, lat, lon, channel] print("%{0}: Start generating {1:d} predictions at current sample index {2:d}".format(method, self.batch_size, sample_ind)) - - input_results = self.sess.run(self.input_iter) - t_starts = self.sess.run(self.ts_iter) + + input_data = self.sess.run(self.inputs) + input_results, t_starts = input_data[0], input_data[1] + #input_results = self.sess.run(self.input_iter) + #t_starts = self.sess.run(self.ts_iter) # print('self.input_iter: {}'.format(self.input_iter.items())) # t_starts = input_results["T_start"] -- GitLab