xref: /aosp_15_r20/external/pytorch/torch/multiprocessing/spawn.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1*da0073e9SAndroid Build Coastguard Worker# mypy: allow-untyped-defs
2*da0073e9SAndroid Build Coastguard Workerimport logging
3*da0073e9SAndroid Build Coastguard Workerimport multiprocessing
4*da0073e9SAndroid Build Coastguard Workerimport multiprocessing.connection
5*da0073e9SAndroid Build Coastguard Workerimport os
6*da0073e9SAndroid Build Coastguard Workerimport pickle
7*da0073e9SAndroid Build Coastguard Workerimport signal
8*da0073e9SAndroid Build Coastguard Workerimport sys
9*da0073e9SAndroid Build Coastguard Workerimport tempfile
10*da0073e9SAndroid Build Coastguard Workerimport time
11*da0073e9SAndroid Build Coastguard Workerimport warnings
12*da0073e9SAndroid Build Coastguard Workerfrom concurrent.futures import as_completed, ThreadPoolExecutor
13*da0073e9SAndroid Build Coastguard Workerfrom typing import Optional
14*da0073e9SAndroid Build Coastguard Worker
15*da0073e9SAndroid Build Coastguard Workerfrom . import _prctl_pr_set_pdeathsig  # type: ignore[attr-defined]
16*da0073e9SAndroid Build Coastguard Worker
17*da0073e9SAndroid Build Coastguard Worker
log = logging.getLogger(__name__)

# Opt-in switch read by start_processes: when set to "1" and the start method
# is 'forkserver', child processes are started concurrently from a thread pool.
ENV_VAR_PARALLEL_START = "TORCH_MP_PARALLEL_START"

# Public API of this module.
__all__ = [
    "ProcessContext",
    "ProcessException",
    "ProcessExitedException",
    "ProcessRaisedException",
    "spawn",
    "SpawnContext",
    "start_processes",
]
31*da0073e9SAndroid Build Coastguard Worker
32*da0073e9SAndroid Build Coastguard Worker
class ProcessException(Exception):
    """Base class for errors describing how a child process failed.

    Attributes:
        msg: Human-readable description of the failure.
        error_index: Index of the failed process within its group.
        pid: OS process id of the failed process.
        error_pid: Same value as ``pid``; kept so the name declared in
            ``__slots__`` (and used as the subclasses' keyword) is readable.
    """

    __slots__ = ["error_index", "error_pid"]

    def __init__(self, msg: str, error_index: int, pid: int):
        super().__init__(msg)
        self.msg = msg
        self.error_index = error_index
        self.pid = pid
        # ``error_pid`` is a declared slot; without this assignment reading
        # ``exc.error_pid`` raises AttributeError. Populate it as an alias.
        self.error_pid = pid

    def __reduce__(self):
        # Rebuild from constructor arguments so the exception survives pickling
        # (e.g. when it crosses a process boundary).
        return type(self), (self.msg, self.error_index, self.pid)
44*da0073e9SAndroid Build Coastguard Worker
45*da0073e9SAndroid Build Coastguard Worker
class ProcessRaisedException(ProcessException):
    """Raised when a child process failed because its code raised an exception.

    When raised by ``ProcessContext.join``, ``msg`` embeds the traceback text
    the child wrote to its error file.
    """

    def __init__(self, msg: str, error_index: int, error_pid: int):
        # Only the third argument's name differs; all state is stored by
        # ProcessException.
        super().__init__(msg, error_index, error_pid)
56*da0073e9SAndroid Build Coastguard Worker
57*da0073e9SAndroid Build Coastguard Worker
class ProcessExitedException(ProcessException):
    """Raised when a child process was killed by a signal or exited non-zero.

    Attributes:
        exit_code: The child's exit status. ``ProcessContext.join`` treats a
            negative value as "terminated by signal ``-exit_code``".
        signal_name: Name of that signal when known, otherwise ``None``.
    """

    __slots__ = ["exit_code"]

    def __init__(
        self,
        msg: str,
        error_index: int,
        error_pid: int,
        exit_code: int,
        signal_name: Optional[str] = None,
    ):
        super().__init__(msg, error_index, error_pid)
        self.exit_code, self.signal_name = exit_code, signal_name

    def __reduce__(self):
        # Include the extra fields so pickling round-trips the full state.
        ctor_args = (
            self.msg,
            self.error_index,
            self.pid,
            self.exit_code,
            self.signal_name,
        )
        return type(self), ctor_args
80*da0073e9SAndroid Build Coastguard Worker
81*da0073e9SAndroid Build Coastguard Worker
def _wrap(fn, i, args, error_file):
    """Child-process entry point: run ``fn(i, *args)`` and report failures.

    If ``fn`` raises an ``Exception``, the formatted traceback is pickled into
    ``error_file`` and the child exits with status 1. The parent's
    ``ProcessContext.join`` uses the file's presence to distinguish a raised
    exception from a plain non-zero exit. A ``KeyboardInterrupt`` (SIGINT)
    ends the child quietly.
    """
    # prctl(2) is Linux-specific; on other systems this call has no effect.
    # Requesting SIGINT on parent death lets non-daemonic children terminate
    # if their parent terminates before they do.
    _prctl_pr_set_pdeathsig(signal.SIGINT)

    try:
        fn(i, *args)
    except KeyboardInterrupt:
        # SIGINT, e.g. the parent-death signal requested above: nothing to report.
        return
    except Exception:
        # Hand the original traceback to the parent through the error file.
        import traceback

        with open(error_file, "wb") as sink:
            pickle.dump(traceback.format_exc(), sink)
        sys.exit(1)
100*da0073e9SAndroid Build Coastguard Worker
101*da0073e9SAndroid Build Coastguard Worker
class ProcessContext:
    """Handle to a group of started child processes.

    ``processes[i]`` is the process started for index ``i``.
    ``error_files[i]`` is the path that child writes a pickled traceback
    string to if its target raises (as ``_wrap`` does).
    """

    # Seconds to wait after SIGTERM before escalating to SIGKILL when one
    # process has failed and the rest must be shut down.
    _TERMINATE_GRACE_PERIOD = 30

    def __init__(self, processes, error_files):
        self.error_files = error_files
        self.processes = processes
        # Map each process's sentinel (it becomes ready when the process ends)
        # back to the process index; join() pops entries as processes finish.
        self.sentinels = {
            process.sentinel: index for index, process in enumerate(processes)
        }

    def pids(self):
        """Return the OS process ids of all processes, in index order."""
        return [int(process.pid) for process in self.processes]

    def join(self, timeout=None):
        r"""Join one or more processes within spawn context.

        Attempt to join one or more processes in this spawn context.
        If one of them exited with a non-zero exit status, this function
        kills the remaining processes and raises an exception with the cause
        of the first process exiting.

        Returns ``True`` if all processes have been joined successfully,
        ``False`` if there are more processes that need to be joined.

        Args:
            timeout (float): Wait this long before giving up on waiting.

        Raises:
            ProcessRaisedException: the failed process left a readable error
                file (its target raised); the message includes that traceback.
            ProcessExitedException: the failed process exited non-zero or was
                killed by a signal without leaving an error file.
        """
        # Ensure this function can be called even when we're done.
        if not self.sentinels:
            return True

        # Wait for any process to fail or all of them to succeed.
        ready = multiprocessing.connection.wait(
            self.sentinels.keys(),
            timeout=timeout,
        )

        error_index = None
        for sentinel in ready:
            index = self.sentinels.pop(sentinel)
            process = self.processes[index]
            process.join()
            if process.exitcode != 0:
                error_index = index
                break

        if error_index is None:
            # No failure so far; report whether every process has been joined.
            return not self.sentinels

        # Assume failure: bring the remaining processes down, then report why
        # the first failed process exited.
        self._terminate_processes(self._TERMINATE_GRACE_PERIOD)
        raise self._failure_exception(error_index)

    def _terminate_processes(self, grace_period):
        """SIGTERM live processes, then SIGKILL any still alive after ``grace_period`` s."""
        # Python signal handling is limited to the main thread; if that thread
        # is stuck in C/C++ code the process may never act on SIGTERM, which
        # has been observed in practice. Hence the SIGKILL fallback.
        for process in self.processes:
            if process.is_alive():
                log.warning("Terminating process %s via signal SIGTERM", process.pid)
                process.terminate()
        deadline = time.monotonic() + grace_period
        for process in self.processes:
            process.join(max(0, deadline - time.monotonic()))
        for process in self.processes:
            if process.is_alive():
                log.warning(
                    "Unable to shutdown process %s via SIGTERM , forcefully exiting via SIGKILL",
                    process.pid,
                )
                process.kill()
            process.join()

    def _failure_exception(self, error_index):
        """Build the exception describing why process ``error_index`` failed.

        If the process left an error file, it is read and then deleted.
        """
        failed_process = self.processes[error_index]
        error_file = self.error_files[error_index]

        # The error file is only written when the child's target raised.
        if not os.access(error_file, os.R_OK):
            exitcode = failed_process.exitcode
            if exitcode < 0:
                # A negative exit code means the process was killed by signal -exitcode.
                try:
                    name = signal.Signals(-exitcode).name
                except ValueError:
                    name = f"<Unknown signal {-exitcode}>"
                return ProcessExitedException(
                    "process %d terminated with signal %s" % (error_index, name),
                    error_index=error_index,
                    error_pid=failed_process.pid,
                    exit_code=exitcode,
                    signal_name=name,
                )
            return ProcessExitedException(
                "process %d terminated with exit code %d" % (error_index, exitcode),
                error_index=error_index,
                error_pid=failed_process.pid,
                exit_code=exitcode,
            )

        # NOTE(review): start_processes reserves this path by creating and then
        # unlinking a temp file, so another local user could create a file at
        # the same path before the child writes it. pickle.load is therefore
        # only as trustworthy as the temp directory.
        with open(error_file, "rb") as fh:
            original_trace = pickle.load(fh)
        # The traceback has been consumed. Remove the file so failed runs don't
        # leave temp files behind. This is deliberately best-effort, so a
        # deletion failure cannot mask the child's error.
        try:
            os.remove(error_file)
        except OSError:
            pass
        msg = "\n\n-- Process %d terminated with the following error:\n" % error_index
        msg += original_trace
        return ProcessRaisedException(msg, error_index, failed_process.pid)
204*da0073e9SAndroid Build Coastguard Worker
205*da0073e9SAndroid Build Coastguard Worker
class SpawnContext(ProcessContext):
    """Deprecated alias of :class:`ProcessContext` (renamed in the 1.4 release)."""

    def __init__(self, processes, error_files):
        # stacklevel=2 attributes the warning to the code constructing
        # SpawnContext rather than to this module.
        warnings.warn(
            "SpawnContext is renamed to ProcessContext since 1.4 release.",
            stacklevel=2,
        )
        super().__init__(processes, error_files)
210*da0073e9SAndroid Build Coastguard Worker
211*da0073e9SAndroid Build Coastguard Worker
# Note: [start_processes]
# mp.start_processes accepts any multiprocessing start_method (e.g. 'spawn',
# 'fork', 'forkserver') and is meant to be a more general API than mp.spawn.
# Currently we only document mp.spawn, because 'spawn' is the CUDA-compatible
# start_method. However, in environments like IPython notebooks, 'fork' works
# better than 'spawn'. Every helper function we created for mp.spawn is general
# enough that backends like XLA can reuse them in Colab notebooks as well.
# For now this API is added without documentation; we can consider documenting
# it as needed in the future.
def start_processes(
    fn,
    args=(),
    nprocs=1,
    join=True,
    daemon=False,
    start_method="spawn",
):
    r"""Start ``nprocs`` processes running ``fn(i, *args)`` with ``start_method``.

    Generalization of :func:`spawn` that accepts any ``multiprocessing`` start
    method. Each child runs through ``_wrap``, which reports an uncaught
    exception back to the parent through a per-process error file.

    Args:
        fn (function): Entry point, called as ``fn(i, *args)`` with ``i`` the
            process index.
        args (tuple): Extra arguments passed to ``fn``.
        nprocs (int): Number of processes to start.
        join (bool): Block until all processes have been joined.
        daemon (bool): The started processes' daemon flag.
        start_method (str): Name passed to ``multiprocessing.get_context``.

    Returns:
        None if ``join`` is ``True``; otherwise the :class:`ProcessContext`.

    Raises:
        ProcessRaisedException, ProcessExitedException: when ``join`` is
        ``True`` and a process fails (raised by ``ProcessContext.join``).
    """
    # To speed up performance in certain cases (see https://github.com/pytorch/pytorch/issues/133010),
    # processes are started in parallel when start_method is 'forkserver' and the
    # user opts in by setting env var TORCH_MP_PARALLEL_START to 1. In every other
    # case (the default) processes are started one after another.
    # todo: investigate why spawn does not work with threadpool and raises SIGINT
    start_parallel = (
        start_method == "forkserver"
        and os.environ.get(ENV_VAR_PARALLEL_START, "0") == "1"
    )
    if start_parallel:
        log.info("Starting processes in parallel.")

    mp = multiprocessing.get_context(start_method)
    error_files = [None] * nprocs
    processes = [None] * nprocs

    def start_process(i):
        # Each process is assigned a file to write tracebacks to.  We
        # use the file being non-empty to indicate an exception
        # occurred (vs an expected shutdown).  Note: this previously
        # used a multiprocessing.Queue but that can be prone to
        # deadlocks, so we went with a simpler solution for a one-shot
        # message between processes.
        # NOTE(review): the path is reserved by creating and then unlinking a
        # temp file, so until the child writes it another local user could
        # create a file there; ProcessContext.join later unpickles it. A
        # private directory (tempfile.mkdtemp) would close that window.
        tf = tempfile.NamedTemporaryFile(
            prefix="pytorch-errorfile-", suffix=".pickle", delete=False
        )
        tf.close()
        os.unlink(tf.name)
        process = mp.Process(
            target=_wrap,
            args=(fn, i, args, tf.name),
            daemon=daemon,
        )
        process.start()
        return i, process, tf.name

    if not start_parallel:
        for i in range(nprocs):
            idx, process, tf_name = start_process(i)
            error_files[idx] = tf_name
            processes[idx] = process
    else:
        # ThreadPoolExecutor rejects max_workers <= 0, so clamp it. With
        # nprocs <= 0 nothing is submitted and, like the sequential path, the
        # resulting context is simply empty.
        with ThreadPoolExecutor(max_workers=max(nprocs, 1)) as executor:
            futures = [executor.submit(start_process, i) for i in range(nprocs)]
            for fut in as_completed(futures):
                idx, process, tf_name = fut.result()
                # Futures finish in any order; store by index so that
                # processes[i] / error_files[i] line up with process rank i.
                error_files[idx] = tf_name
                processes[idx] = process
    context = ProcessContext(processes, error_files)
    if not join:
        return context

    # Loop on join until it returns True or raises an exception.
    while not context.join():
        pass
285*da0073e9SAndroid Build Coastguard Worker        pass
286*da0073e9SAndroid Build Coastguard Worker
287*da0073e9SAndroid Build Coastguard Worker
def spawn(fn, args=(), nprocs=1, join=True, daemon=False, start_method="spawn"):
    r"""Spawns ``nprocs`` processes that run ``fn`` with ``args``.

    If any process exits with a non-zero exit status, the remaining
    processes are killed and an exception is raised that describes why
    the process terminated. If the child process caught an exception, it is
    forwarded, and its traceback is included in the exception raised in the
    parent process.

    Args:
        fn (function): Entry point of each spawned process. It must be
            defined at the top level of a module so it can be pickled and
            spawned; this is a requirement imposed by multiprocessing.

            It is called as ``fn(i, *args)``, where ``i`` is the process
            index and ``args`` is the tuple of arguments passed through.

        args (tuple): Arguments passed to ``fn``.
        nprocs (int): Number of processes to spawn.
        join (bool): Perform a blocking join on all processes.
        daemon (bool): The spawned processes' daemon flag. If set to True,
                       daemonic processes will be created.
        start_method (str): (deprecated) this method always uses ``spawn``
                               as the start method. To use a different start
                               method, use ``start_processes()``.

    Returns:
        None if ``join`` is ``True``,
        :class:`~ProcessContext` if ``join`` is ``False``

    """
    if start_method != "spawn":
        # Other values are tolerated only for backward compatibility: warn,
        # then start the processes with "spawn" regardless.
        warnings.warn(
            f"This method only supports start_method=spawn (got: {start_method}).\n"
            "To use a different start_method use:\n\t\t"
            " torch.multiprocessing.start_processes(...)",
            FutureWarning,
            stacklevel=2,
        )
    return start_processes(
        fn,
        args=args,
        nprocs=nprocs,
        join=join,
        daemon=daemon,
        start_method="spawn",
    )
329