Commit 13c64e0d authored by mova's avatar mova
Browse files

qf: increase timeout, unify terminal element handling in pack, only clone when...

qf: increase timeout, unify terminal element handling in pack, only clone when getting the element from the incoming queue
parent 79ff3d4b
......@@ -86,4 +86,3 @@ __all__ = ["pack", "process_step", "sequence", "pool"]
# print_with_lock(res.flowstatus())
# print_with_lock("Done Iterating")
# %%
......@@ -10,7 +10,7 @@ from .terminate_queue import TerminateQueue
class InOutStep:
def __init__(self):
pass
self.shutdown_event = NotImplemented
def safe_put(self, queue, element):
while not self.shutdown_event.is_set():
......@@ -56,7 +56,7 @@ class OutputStep(InOutStep):
def __next__(self):
while not self.shutdown_event.is_set():
try:
out = self.inq.get(block=True, timeout=0.005)
out = self.inq.get(block=True, timeout=0.05)
if isinstance(out, TerminateQueue):
logger.debug("OutputStep got terminal element.")
break
......
......@@ -14,13 +14,12 @@ class UnpackStep(StepBase):
def __init__(self):
super().__init__(name="Unpack")
def _terminate(self):
def __handle_terminal(self):
logger.debug(
f"{self.workername} push terminal element into output queue {id(self.outq)}."
f"""\
{self.workername} push terminal element into output queue {id(self.outq)}."""
)
self.safe_put(TerminateQueue())
self._close_queues()
logger.info(f"{self.workername} terminating")
self.safe_put(self.outq, TerminateQueue())
def _worker(self):
self.set_workername()
......@@ -29,7 +28,8 @@ class UnpackStep(StepBase):
if self.shutdown_event.is_set():
break
try:
wkin = self.inq.get(block=True, timeout=0.005)
wkin = self.inq.get(block=True, timeout=0.05)
wkin = self._clone_tensors(wkin)
except Empty:
continue
logger.debug(
......@@ -37,28 +37,26 @@ class UnpackStep(StepBase):
{self.workername} working type {type(wkin)} from queue {id(self.inq)}."""
)
if isinstance(wkin, TerminateQueue):
self.__handle_terminal()
continue
if not isinstance(wkin, Iterable):
errormsg = f"""\
{self.workername} cannot iterate over element type {type(wkin)}."""
self.error_queue.put((errormsg, wkin, ValueError))
break
else:
if not isinstance(wkin, Iterable):
errormsg = (
f"{self.workername} cannot iterate over "
f"element type {type(wkin)}."
)
self.error_queue.put((errormsg, wkin, ValueError))
break
logger.debug(
f"{self.workername} got element of element type {type(wkin)}."
)
for element in wkin:
logger.debug(
f"{self.workername} got element of element type {type(wkin)}."
f"""\
{self.workername} push element of type {type(wkin)} into output queue."""
)
for element in wkin:
logger.debug(
f"{self.workername} push element of type "
+ f"{type(wkin)} into output queue."
)
if hasattr(element, "clone"):
element = self._clone_tensors(element)
self.safe_put(self.outq, element)
del wkin
self._terminate()
self.safe_put(self.outq, element)
del wkin
self._close_queues()
logger.info(f"{self.workername} terminating")
class PackStep(StepBase):
......@@ -73,17 +71,18 @@ class PackStep(StepBase):
self.nelements = nelements
self.collected_elements = []
def _terminate(self):
def __handle_terminal(self):
if len(self.collected_elements) > 0:
logger.debug(
f"""\
{self.workername} terminal element of type \
{type(self.collected_elements[0])} into output queue {id(self.outq)}."""
{self.workername} put remainder of size {len(self.collected_elements)} into output queue."""
)
self.safe_put(self.outq, self.collected_elements)
logger.info(f"{self.workername} terminating")
logger.debug(
f"""\
{self.workername} terminal element into output queue {id(self.outq)}."""
)
self.safe_put(self.outq, TerminateQueue())
self._close_queues()
def _worker(self):
self.set_workername()
......@@ -92,18 +91,23 @@ class PackStep(StepBase):
if self.shutdown_event.is_set():
break
try:
wkin = self.inq.get(block=True, timeout=0.005)
wkin = self.inq.get(block=True, timeout=0.05)
wkin = self._clone_tensors(wkin)
except Empty:
continue
logger.debug(
f"""\
{self.workername} working on type {type(wkin)} from queue {id(self.inq)}."""
)
if isinstance(wkin, TerminateQueue):
break
wkin = self._clone_tensors(wkin)
self.__handle_terminal()
continue
logger.debug(f"{self.workername} storing element of type {type(wkin)}.")
logger.debug(
f"""\
{self.workername} storing element of type {type(wkin)}."""
)
self.collected_elements.append(wkin)
if len(self.collected_elements) == self.nelements:
......@@ -115,7 +119,7 @@ class PackStep(StepBase):
self.safe_put(self.outq, self.collected_elements)
self.collected_elements = []
del wkin
self._terminate()
self._close_queues()
class RepackStep(StepBase):
......@@ -152,7 +156,7 @@ class RepackStep(StepBase):
if self.shutdown_event.is_set():
break
try:
wkin = self.inq.get(block=True, timeout=0.005)
wkin = self.inq.get(block=True, timeout=0.05)
except Empty:
continue
logger.debug(
......@@ -174,8 +178,7 @@ class RepackStep(StepBase):
(len {len(wkin) if hasattr(wkin,'__len__') else '?'})."""
)
for element in wkin:
e_cloned = self._clone_tensors(element)
self.collected_elements.append(e_cloned)
self.collected_elements.append(element)
if len(self.collected_elements) == self.nelements:
logger.debug(
f"""\
......
......@@ -50,13 +50,14 @@ class PoolStep(StepBase):
def _worker(self):
self.set_workername()
logger.debug(
f"{self.workername} pool initalizing with {self.n_pool_workers} subprocesses"
f"{self.workername} pool initalizing with"
f" {self.n_pool_workers} subprocesses"
)
self.pool = mp.Pool(self.n_pool_workers)
while not self.shutdown_event.is_set():
try:
wkin = self.inq.get(block=True, timeout=0.005)
wkin = self.inq.get(block=True, timeout=0.05)
except Empty:
continue
logger.debug(
......
......@@ -54,7 +54,7 @@ class ProcessStep(StepBase):
)
while not self.shutdown_event.is_set():
try:
wkin = self.inq.get(block=True, timeout=0.005)
wkin = self.inq.get(block=True, timeout=0.05)
except Empty:
continue
logger.debug(
......
......@@ -53,13 +53,13 @@ class StepBase:
def _clone_tensors(self, wkin):
if isinstance(wkin, list):
wkin = [self._clone_tensors(e) for e in wkin]
return [self._clone_tensors(e) for e in wkin]
elif isinstance(wkin, GeneratorType):
return (self._clone_tensors(e) for e in wkin)
elif isinstance(wkin, torch_geometric.data.batch.Data):
return clone_batch(wkin)
elif isinstance(wkin, torch.Tensor):
wkin = wkin.clone()
return wkin.clone()
return wkin
def set_workername(self):
......@@ -84,6 +84,7 @@ Had to kill process of name {self.name}."""
p.join(0)
def safe_put(self, queue, element):
# element = self._clone_tensors(element)
while not self.shutdown_event.is_set():
try:
queue.put(element, True, 1)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment