Skip to content
Snippets Groups Projects
Commit 39d9ffbc authored by Bing Gong's avatar Bing Gong
Browse files

correct the train_iterator function

parent 87ae8bc7
No related branches found
No related tags found
No related merge requests found
Pipeline #46314 failed
...@@ -144,10 +144,10 @@ def make_dataset_iterator(train_dataset, val_dataset, batch_size ): ...@@ -144,10 +144,10 @@ def make_dataset_iterator(train_dataset, val_dataset, batch_size ):
val_tf_dataset = val_dataset.make_dataset_v2(batch_size) val_tf_dataset = val_dataset.make_dataset_v2(batch_size)
val_iterator = val_tf_dataset.make_one_shot_iterator() val_iterator = val_tf_dataset.make_one_shot_iterator()
val_handle = val_iterator.string_handle() val_handle = val_iterator.string_handle()
#iterator = tf.data.Iterator.from_string_handle( iterator = tf.data.Iterator.from_string_handle(
# train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes) train_handle, train_tf_dataset.output_types, train_tf_dataset.output_shapes)
inputs = train_iterator.get_next() inputs = train_iterator.get_next()
val = val_iterator.get_next()
return inputs,train_handle, val_handle return inputs,train_handle, val_handle
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment