diff --git a/video_prediction_tools/main_scripts/main_predict.py b/video_prediction_tools/main_scripts/main_predict.py index 5c148ed4ed0113b32d9f63c9ab72899d7c9d878e..8dac161844b1c4b19b6395624cc6248cd2240852 100644 --- a/video_prediction_tools/main_scripts/main_predict.py +++ b/video_prediction_tools/main_scripts/main_predict.py @@ -407,7 +407,10 @@ class Postprocess(TrainModel): test_handle = test_iterator.string_handle() dataset_iterator = tf.data.Iterator.from_string_handle(test_handle, test_tf_dataset.output_types, test_tf_dataset.output_shapes) - self.input_iter, self.ts_iter = dataset_iterator.get_next() + + self.inputs = dataset_iterator.get_next() + self.input_iter, self.ts_iter = self.inputs[0], self.inputs[1] + #self.input_iter, self.ts_iter = dataset_iterator.get_next() # ts_iter = input_iter["T_start"] # return input_iter, ts_iter @@ -558,6 +561,7 @@ class Postprocess(TrainModel): # init sample index for looping sample_ind = 0 nsamples = self.num_samples_per_epoch + print('test samples: {}'.format(nsamples)) # # initialize xarray datasets # eval_metric_ds = Postprocess.init_metric_ds(self.fcst_products, self.eval_metrics, self.vars_in[self.channel], # nsamples, self.future_length) @@ -569,9 +573,11 @@ class Postprocess(TrainModel): # feed and run the trained model; returned array has the shape [batchsize, seq_len, lat, lon, channel] print("%{0}: Start generating {1:d} predictions at current sample index {2:d}".format(method, self.batch_size, sample_ind)) - - input_results = self.sess.run(self.input_iter) - t_starts = self.sess.run(self.ts_iter) + + input_data = self.sess.run(self.inputs) + input_results, t_starts = input_data[0], input_data[1] + #input_results = self.sess.run(self.input_iter) + #t_starts = self.sess.run(self.ts_iter) # print('self.input_iter: {}'.format(self.input_iter.items())) # t_starts = input_results["T_start"]