diff --git a/video_prediction_tools/main_scripts/main_predict.py b/video_prediction_tools/main_scripts/main_predict.py
index 5c148ed4ed0113b32d9f63c9ab72899d7c9d878e..8dac161844b1c4b19b6395624cc6248cd2240852 100644
--- a/video_prediction_tools/main_scripts/main_predict.py
+++ b/video_prediction_tools/main_scripts/main_predict.py
@@ -407,7 +407,10 @@ class Postprocess(TrainModel):
         test_handle = test_iterator.string_handle()
         dataset_iterator = tf.data.Iterator.from_string_handle(test_handle, test_tf_dataset.output_types,
                                                                test_tf_dataset.output_shapes)
-        self.input_iter, self.ts_iter = dataset_iterator.get_next()
+
+        self.inputs = dataset_iterator.get_next()
+        self.input_iter, self.ts_iter = self.inputs[0], self.inputs[1]
+        #self.input_iter, self.ts_iter = dataset_iterator.get_next()
         # ts_iter = input_iter["T_start"]
 
         # return input_iter, ts_iter
@@ -558,6 +561,7 @@ class Postprocess(TrainModel):
         # init sample index for looping
         sample_ind = 0
         nsamples = self.num_samples_per_epoch
+        print('test samples: {}'.format(nsamples))
         # # initialize xarray datasets
         # eval_metric_ds = Postprocess.init_metric_ds(self.fcst_products, self.eval_metrics, self.vars_in[self.channel],
         #                                             nsamples, self.future_length)
@@ -569,9 +573,11 @@ class Postprocess(TrainModel):
             # feed and run the trained model; returned array has the shape [batchsize, seq_len, lat, lon, channel]
             print("%{0}: Start generating {1:d} predictions at current sample index {2:d}".format(method, self.batch_size,
                                                                                                   sample_ind))
-                              
-            input_results = self.sess.run(self.input_iter)    
-            t_starts = self.sess.run(self.ts_iter)
+             
+            input_data = self.sess.run(self.inputs)
+            input_results, t_starts = input_data[0], input_data[1]
+            #input_results = self.sess.run(self.input_iter)    
+            #t_starts = self.sess.run(self.ts_iter)
             # print('self.input_iter: {}'.format(self.input_iter.items()))
             # t_starts = input_results["T_start"]