Commit 449f2564 authored by mova

fix error propagation for torch_geometric batches

parent daa3a064
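
Context: a queued element that caused a worker to fail is reported back through a multiprocessing error queue. Tensors and torch_geometric batches are pickled across process boundaries via torch's shared-memory machinery, which is fragile once the failing worker exits; converting the payload to numpy first turns it into an ordinary picklable object. A minimal sketch of the idea behind this fix (hypothetical names, not this repo's API):

    import torch

    def put_picklable(error_queue, msg, payload, error):
        # Hypothetical helper: detach torch objects from shared-memory
        # pickling by converting them to plain numpy arrays first.
        if isinstance(payload, torch.Tensor):
            payload = payload.detach().cpu().numpy()
        error_queue.put((msg, payload, error))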
@@ -46,15 +46,9 @@ class PoolStep(StepBase):
             self.n_pool_workers,
         )
 
-    def propagete_error(self, element, error=Exception):
-        workermsg = (
-            f"""{self.workername} failed on element of type {type(element)}."""
-        )
-        self.error_queue.put((workermsg, element, error))
-
     def _worker(self):
         self.set_workername()
-        logger.info(
+        logger.debug(
             f"{self.workername} pool initalizing with {self.n_pool_workers} subprocesses"
         )
         self.pool = mp.Pool(self.n_pool_workers)
@@ -108,7 +102,7 @@ class PoolStep(StepBase):
             except Exception as error:
                 logger.warn(f"""{self.workername} got error""")
-                self.propagete_error(wkin, error)
+                self.handle_error(error, wkin)
                 break
         # if wkin_iter.count > 200:
......
@@ -77,9 +77,7 @@ class ProcessStep(StepBase):
             # Catch Errors in the worker function
             except Exception as error:
-                workermsg = f"""
-{self.workername} failed on element of type of type {type(wkin)}.\n\n{wkin}"""
-                self.error_queue.put((workermsg, wkin, error))
+                self.handle_error(error, wkin)
                 break
         logger.debug(
......
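
Both the PoolStep and ProcessStep workers now share one pattern: catch the exception, hand the offending input to `handle_error` (defined on `StepBase` below), and leave the loop. Schematically, with assumed attribute names (`inqueue`, `outqueue`, `workerfn` are illustrative, not necessarily the repo's):

    def _worker(self):
        while True:
            wkin = self.inqueue.get()
            try:
                wkout = self.workerfn(wkin)
            except Exception as error:
                # Report the failing element, then shut this worker down.
                self.handle_error(error, wkin)
                break
            self.outqueue.put(wkout)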
@@ -144,8 +144,10 @@ class Sequence:
     def stop(self):
         logger.info("Before Sequence Stop\n" + str(self.flowstatus()))
+        logger.warn("Setting shutdown event!")
+        self.shutdown_event.set()
         # # Drain the queues:
         for queue in self.queues:
             while True:
                 try:
@@ -157,8 +159,8 @@ class Sequence:
             logger.debug(f"Stopping sequence step {istep}")
             step.stop()
-        self.queues[0].close()
-        self.queues[0].join_thread()
+        # self.queues[0].close()
+        # self.queues[0].join_thread()
         self.queues[-1].close()
         self.queues[-1].join_thread()
@@ -175,7 +177,6 @@ class Sequence:
 Process {ip} (of {len(step.processes)}) of step {istep} is still alive!"""
                 )
             logger.debug(f"Stopping sequence step {istep}")
-            step.stop()
 
     def read_error_queue(self):
         threading.current_thread().setName("readErrorQueue")
......
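
`stop()` now signals the workers through `shutdown_event` before draining, calls `step.stop()` only once per step, and no longer closes the first queue from this process (only the last one). The drain loop itself is truncated in the hunk above; it presumably follows the standard non-blocking pattern:

    import queue

    for q in self.queues:
        while True:
            try:
                q.get(block=False)  # discard anything still buffered
            except queue.Empty:
                break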
@@ -6,7 +6,7 @@ import torch
 import torch_geometric
 from torch import multiprocessing as mp
 
-from ...utils.batch_utils import clone_batch
+from ...utils.batch_utils import batch_to_numpy_dict, clone_batch
 from ...utils.logger import logger
@@ -94,5 +94,14 @@ Had to kill process of name {self.name}."""
     def process_status(self):
         return (sum([p.is_alive() for p in self.processes]), self.nworkers)
 
+    def handle_error(self, error, obj):
+        if isinstance(obj, torch.Tensor):
+            obj = obj.numpy()
+        if isinstance(obj, torch_geometric.data.Data):
+            obj = batch_to_numpy_dict(obj)
+        workermsg = f"""
+{self.workername} failed on element of type {type(obj)}."""
+        self.error_queue.put((workermsg, obj, error))
+
     def _worker(self):
         raise NotImplementedError
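
With this, every entry on the error queue is a `(workermsg, obj, error)` triple whose `obj` is numpy-backed, so the parent process can unpickle and inspect it even after the worker has died. A usage sketch (the `step` instance and the `RuntimeError` are illustrative):

    step.handle_error(RuntimeError("boom"), batch)
    workermsg, obj, error = step.error_queue.get()
    # obj is the batch with every tensor attribute converted to a
    # numpy array; safe to log or inspect in the parent process.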
@@ -228,6 +228,3 @@ Skipping {n_skip_events} events => {n_skip_chunks} chunks and {n_skip_batches} batches.
     def __iter__(self):
         return iter(self.qfseq)
 
-    def __del__(self):
-        self.qfseq.stop()
@@ -66,10 +66,11 @@ def training_procedure() -> None:
             batch = next(train_state.loader.qfseq)
         except StopIteration:
             # If there is no next batch go to the next epoch
-            train_state.experiment.log_epoch_end(
-                train_state.state["epoch"],
-                step=train_state.state["grad_step"],
-            )
+            if not conf.debug:
+                train_state.experiment.log_epoch_end(
+                    train_state.state["epoch"],
+                    step=train_state.state["grad_step"],
+                )
             logger.warning("New epoch!")
             train_state.state.epoch += 1
             train_state.state.ibatch = 0
......
@@ -75,3 +75,24 @@ def check_batch_device(batch):
         else:
             if v.device != device:
                 print(f"Key {k} on wrong device {v.device}")
+
+
+def batch_to_numpy_dict(batch):
+    def tonumpy(element):
+        if torch.is_tensor(element):
+            return element.numpy()
+        elif isinstance(element, list):
+            return [tonumpy(ee) for ee in element]
+        elif isinstance(element, dict):
+            return {k: tonumpy(ee) for k, ee in element.items()}
+        elif element is None:
+            return None
+        elif isinstance(element, (int, str, float)):
+            return element
+        else:
+            raise ValueError
+
+    batch_new = torch_geometric.data.Batch().from_dict(
+        {k: tonumpy(v) for k, v in batch.to_dict().items()}
+    )
+    return batch_new
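
A quick usage sketch of the new helper, assuming CPU tensors (`.numpy()` raises for CUDA tensors) and a torch_geometric version where `to_dict`/`from_dict` round-trip as used above:

    import torch
    import torch_geometric

    data = torch_geometric.data.Data(
        x=torch.randn(4, 3),  # node features
        edge_index=torch.tensor([[0, 1, 2], [1, 2, 3]]),  # COO connectivity
    )
    batch = torch_geometric.data.Batch.from_data_list([data, data])

    numpy_batch = batch_to_numpy_dict(batch)
    print(type(numpy_batch.x))  # <class 'numpy.ndarray'>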