Skip to content
Snippets Groups Projects
Commit 54a21996 authored by Yan Ji's avatar Yan Ji
Browse files

update main_predict.py

parent 6fb659a9
No related branches found
No related tags found
No related merge requests found
Pipeline #106902 failed
......@@ -407,7 +407,10 @@ class Postprocess(TrainModel):
test_handle = test_iterator.string_handle()
dataset_iterator = tf.data.Iterator.from_string_handle(test_handle, test_tf_dataset.output_types,
test_tf_dataset.output_shapes)
self.input_iter, self.ts_iter = dataset_iterator.get_next()
self.inputs = dataset_iterator.get_next()
self.input_iter, self.ts_iter = self.inputs[0], self.inputs[1]
#self.input_iter, self.ts_iter = dataset_iterator.get_next()
# ts_iter = input_iter["T_start"]
# return input_iter, ts_iter
......@@ -558,6 +561,7 @@ class Postprocess(TrainModel):
# init sample index for looping
sample_ind = 0
nsamples = self.num_samples_per_epoch
print('test samples: {}'.format(nsamples))
# # initialize xarray datasets
# eval_metric_ds = Postprocess.init_metric_ds(self.fcst_products, self.eval_metrics, self.vars_in[self.channel],
# nsamples, self.future_length)
......@@ -570,8 +574,10 @@ class Postprocess(TrainModel):
print("%{0}: Start generating {1:d} predictions at current sample index {2:d}".format(method, self.batch_size,
sample_ind))
input_results = self.sess.run(self.input_iter)
t_starts = self.sess.run(self.ts_iter)
input_data = self.sess.run(self.inputs)
input_results, t_starts = input_data[0], input_data[1]
#input_results = self.sess.run(self.input_iter)
#t_starts = self.sess.run(self.ts_iter)
# print('self.input_iter: {}'.format(self.input_iter.items()))
# t_starts = input_results["T_start"]
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment