pool.py

from collections.abc import Iterable
from queue import Empty

from torch import multiprocessing as mp

from fgsim.utils.count_iterations import CountIterations
from fgsim.utils.logger import logger

from .step_base import StepBase
from .terminate_queue import TerminateQueue


class PoolStep(StepBase):
    """Class for simple processing steps pooled over multiple workes.
    Each incoming object is processed by a multiple subprocesses
    per worker into a single outgoing element."""
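    # A hedged usage sketch (StepBase's constructor is defined elsewhere;
    # taking `workerfn` as its first positional argument is an assumption):
    #
    #     step = PoolStep(workerfn, nworkers=4)
    #     step.start()  # one manager process that owns the mp.Pool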

    def __init__(
        self,
        *args,
        nworkers: int,
        **kwargs,
    ):
        # Spawn only one non-daemonized process, which can then spawn the
        # Pool (daemonic processes are not allowed to have children).
        kwargs["deamonize"] = False

        # Make sure the constructor of the base class only initializes
        # one process, which manages the pool.
        self.n_pool_workers = nworkers
        kwargs["nworkers"] = 1
        super().__init__(*args, **kwargs)

    def start(self):
        for p in self.processes:
            p.daemon = self.deamonize
            p.start()

    def stop(self):
        for p in self.processes:
            if p.is_alive():
                # Give the process five seconds to join, then kill it.
                p.join(5)
                p.kill()

    def process_status(self):
        # Report (number of live pool subprocesses, expected pool size).
        return (
            sum(p.is_alive() for p in self.processes) * self.n_pool_workers,
            self.n_pool_workers,
        )

    def _worker(self):
        self.set_workername()
        logger.debug(
            f"{self.workername} pool  initalizing with {self.n_pool_workers} subprocesses"
        )
        self.pool = mp.Pool(self.n_pool_workers)

        while not self.shutdown_event.is_set():
            # Poll the input queue with a short timeout so that the
            # shutdown event is checked regularly.
            try:
                wkin = self.inq.get(block=True, timeout=0.005)
            except Empty:
                continue
            logger.debug(
                f"""\
{self.workername} working on element of type {type(wkin)} from queue {id(self.inq)}."""
            )

            # If the process gets a TerminateQueue object, it puts the
            # terminal element in the outgoing queue and resets its counters.
            if isinstance(wkin, TerminateQueue):
                logger.info(f"{self.workername} terminating")
                self.safe_put(self.outq, TerminateQueue())
                logger.warning(
                    f"""\
{self.workername} finished with iterable (in {self.count_in}/out {self.count_out})"""
                )
                self.count_in, self.count_out = 0, 0
                continue
            self.count_in += 1
            wkin = self._clone_tensors(wkin)

            # The incoming element must be iterable; its items are
            # distributed over the pool subprocesses. CountIterations
            # wraps the iterable to record how many items are consumed.
            assert isinstance(wkin, Iterable)
            logger.debug(
                f"{self.workername} got element"
                f" {id(wkin)} of element type {type(wkin)}."
            )
            wkin_iter = CountIterations(wkin)

            try:
                wkout_async_res = self.pool.map_async(
                    self.workerfn,
                    wkin_iter,
                )
                # Wait for the pool, but wake up every second to check
                # for a shutdown request.
                while True:
                    if wkout_async_res.ready():
                        wkout = wkout_async_res.get()
                        break
                    elif self.shutdown_event.is_set():
                        break
                    wkout_async_res.wait(1)
                if self.shutdown_event.is_set():
                    break

            except Exception as error:
                logger.warning(f"{self.workername} got error")
                self.handle_error(error, wkin)
                break

            # if wkin_iter.count > 200:
            #     logger.warning(
            #         f"""\
            # Giving large iterables ({wkin_iter.count}) \
            # to a worker can lead to crashes.
            # Lower the number here if you see an error like \
            # 'RuntimeError: unable to mmap x bytes from file </torch_x>:
            # Cannot allocate memory'"""
            #     )
            logger.debug(
                f"""\
{self.workername} push pool output list {id(wkout)} with \
element type {type(wkin)} into output queue {id(self.outq)}."""
            )
            # Put only while no shutdown event is set.
            self.safe_put(self.outq, wkout)
            self.count_out += 1
            # Drop references so the (possibly tensor-backed) input
            # can be freed.
            del wkin_iter
            del wkin
        self.pool.close()
        self.pool.terminate()
        logger.debug(f"{self.workername} pool closed")
        self.outq.cancel_join_thread()
        self._close_queues()
        logger.debug(f"""{self.workername} queues closed""")