# mypy: allow-untyped-defs
from typing import Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch import _prims
from torch._C import DispatchKey
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._prims_common import CUDARngStateHelper, make_contiguous_strides_for
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.types import _device, _dtype


def throw_on_non_cuda(device):
    raise RuntimeError(
        f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not "
        f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is "
        "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
    )


def register_rng_prim(name, schema, impl_aten, impl_meta, doc, tags=None):
    rngprim_def = torch.library.custom_op(
        "rngprims::" + name, impl_aten, mutates_args=(), schema=schema
    )
    rngprim_def.register_fake(impl_meta)

    prim_packet = getattr(torch._ops.ops.rngprims, name)
    prim = prim_packet.default
    if tags:
        prim._tags = tags

    for p in (prim_packet, prim):
        p.__doc__ = doc
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]

        p.schema = name + schema
        p.impl_aten = impl_aten
        p.prim_meta_impl = impl_meta


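# Note: register_rng_prim registers impl_aten as a custom op under the "rngprims"
# namespace and impl_meta as its fake-tensor kernel, so a prim registered with name
# "foo" (a hypothetical name used only for illustration) becomes callable as
# torch.ops.rngprims.foo. The attributes stashed on the packet and overload above
# (schema, impl_aten, prim_meta_impl, return_type) mirror the metadata that other
# torch._prims operators carry.
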
# Philox rand offsets could be shared with other Philox ops in the future, so
# these functions are kept in global scope.
def philox_rand_offset_meta(
    shape: torch.Size,
):
    return _prims.TensorLike(torch.tensor(0, dtype=torch.int64))


def philox_rand_offset(
    shape: torch.Size,
):
    # For the impl, look at the function calc_execution_policy in the file
    # aten/src/ATen/native/cuda/DistributionTemplates.h. The impl was copied at
    # commit hash 72aa0667bd16707d50eb8fa337092a1f5d11dfb6
    numel_scalar = 1
    for dim_size in shape:
        numel_scalar *= dim_size
    numel = torch.scalar_tensor(numel_scalar, dtype=torch.int64)

    block_size = 256
    unroll = 4
    curand4_engine_calls = 4
    device_property = torch.cuda.get_device_properties(torch.cuda.current_device())
    blocks_per_sm = device_property.max_threads_per_multi_processor // block_size
    grid_size = (numel + block_size - 1) // block_size
    grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm)
    offset = (
        (numel - 1) // (block_size * grid_size * unroll) + 1
    ) * curand4_engine_calls
    return offset


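# Worked example (sketch, hypothetical device): for a tensor with 1,000,000 elements
# on a GPU with 108 SMs and 2048 max threads per SM:
#   blocks_per_sm = 2048 // 256 = 8
#   grid_size     = min(ceil(1_000_000 / 256), 108 * 8) = min(3907, 864) = 864
#   offset        = ((999_999 // (256 * 864 * 4)) + 1) * 4 = (1 + 1) * 4 = 8
# The returned offset is always a multiple of curand4_engine_calls (4).
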
def register_philox_rand():
    name = "philox_rand"
    schema = "(SymInt[] size, Tensor seed, Tensor offset, int[]? stride, Device? device=None, ScalarType? dtype=None) -> (Tensor, Tensor)"  # noqa: B950

    def _philox_rand_meta(
        shape: torch.Size,
        seed: torch.Tensor,
        offset: torch.Tensor,
        stride: Optional[Tuple[int, ...]],
        device: _device,
        dtype: _dtype,
    ):
        # The stride arg will be useful for the distributed use case. Currently, it is unused.
        assert stride is None
        stride = make_contiguous_strides_for(shape)
        random_values = _prims.TensorMeta(
            shape=shape, strides=stride, dtype=dtype, device=device
        )
        offset = philox_rand_offset_meta(shape)
        return (random_values, offset)

    def _philox_rand(
        shape: torch.Size,
        seed: torch.Tensor,
        offset: torch.Tensor,
        stride: Optional[Tuple[int, ...]],
        device: _device,
        dtype: _dtype,
    ):
        # The stride arg will be useful for the distributed use case. Currently, it is unused.
        assert stride is None
        if device.type == "cpu":
            devices = []
        else:
            devices = [device]

        if device.type != "cuda":
            # throw_on_non_cuda always raises, so this line never returns.
            raise throw_on_non_cuda(device)

        with torch.random.fork_rng(devices):
            CUDARngStateHelper.set_torch_state_tensor(seed, offset)
            random_values = torch.rand(shape, device=device, dtype=dtype)

        return random_values, philox_rand_offset(shape)

    register_rng_prim(
        name=name,
        schema=schema,
        impl_aten=_philox_rand,
        impl_meta=_philox_rand_meta,
        doc="Philox based stateless rand operator",
        tags=(torch.Tag.nondeterministic_seeded,),
    )


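# Usage sketch (hypothetical, CUDA only): once register_philox_rand() has run, the
# prim is reachable as torch.ops.rngprims.philox_rand. Assuming seed/offset tensors
# captured from the current CUDA generator (e.g. via
# CUDARngStateHelper.get_torch_state_as_tuple()), a call could look like:
#
#   values, offset_advance = torch.ops.rngprims.philox_rand(
#       [4, 4], seed, offset, None, torch.device("cuda"), torch.float32
#   )
#
# values matches what torch.rand would produce from that (seed, offset) state, and
# offset_advance is the increment computed by philox_rand_offset for this shape.
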
def get_device(args, kwargs):
    if kwargs.get("device"):
        device = kwargs.get("device")
        if isinstance(device, str):
            device = torch.device(device)
        return device.type

    devices = {arg.device.type for arg in args if isinstance(arg, torch.Tensor)}
    if any(dev == "cuda" for dev in devices):
        return "cuda"
    elif any(dev == "xpu" for dev in devices):
        return "xpu"
    elif any(dev == "hpu" for dev in devices):
        return "hpu"
    elif any(dev == "cpu" for dev in devices):
        return "cpu"
    return None


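# Examples (sketch) of the routing performed by get_device:
#   get_device((torch.randn(2, device="cuda"),), {})  -> "cuda"  (assuming a CUDA build)
#   get_device((), {"device": "cpu"})                 -> "cpu"   (strings are normalized)
#   get_device((3, "foo"), {})                        -> None    (no tensors, no device kwarg)
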
def register_run_and_save_rng_state_op():
    class RunAndSaveRngState(HigherOrderOperator):
        def __init__(self):
            super().__init__("run_and_save_rng_state")

        def __call__(self, op, *args, **kwargs):
            return super().__call__(op, *args, **kwargs)

    run_and_save_rng_state = RunAndSaveRngState()

    run_and_save_rng_state.py_impl(DispatchKey.Autograd)(
        autograd_not_implemented(run_and_save_rng_state, deferred_error=True)
    )

    @run_and_save_rng_state.py_impl(DispatchKey.CUDA)
    def impl_cuda(op, *args, **kwargs):
        return torch.cuda.get_rng_state(), op(*args, **kwargs)

    @run_and_save_rng_state.py_impl(DispatchKey.CPU)
    def impl_cpu(op, *args, **kwargs):
        return torch.get_rng_state(), op(*args, **kwargs)

    @run_and_save_rng_state.py_impl(DispatchKey.HPU)
    def impl_hpu(op, *args, **kwargs):
        if hasattr(torch, "hpu"):
            return torch.hpu.get_rng_state(), op(*args, **kwargs)
        raise RuntimeError("Functionalizing an HPU RNG operator is not supported.")

    @run_and_save_rng_state.py_impl(DispatchKey.XPU)
    def impl_xpu(op, *args, **kwargs):
        return torch.xpu.get_rng_state(), op(*args, **kwargs)

    @run_and_save_rng_state.py_impl(DispatchKey.BackendSelect)
    def impl_backend_select(op, *args, **kwargs):
        impl_map = {
            "cuda": impl_cuda,
            "cpu": impl_cpu,
            "hpu": impl_hpu,
            "xpu": impl_xpu,
        }
        device = get_device(args, kwargs)
        assert device in impl_map, f"Backend not supported for {device}"
        impl = impl_map[device]
        return impl(op, *args, **kwargs)

    @run_and_save_rng_state.py_impl(FakeTensorMode)
    def impl_fake_tensor_mode(mode, op, *args, **kwargs):
        # Check the device to call the right impl
        with mode:
            return impl_backend_select(op, *args, **kwargs)

    @run_and_save_rng_state.py_impl(ProxyTorchDispatchMode)
    def impl_proxy_dispatch_mode(mode, op, *args, **kwargs):
        out = impl_backend_select(op, *args, **kwargs)
        proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
        proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
        out_proxy = mode.tracer.create_proxy(
            "call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
        )
        return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)

    return run_and_save_rng_state


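# Sketch of the contract: run_and_save_rng_state(op, *args, **kwargs) returns a
# (rng_state, result) pair, where rng_state is the generator state of the device
# inferred by get_device, captured *before* op ran, and result is op(*args, **kwargs).
# The BackendSelect impl above is the generic entry point that picks the per-device
# implementation.
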
def register_run_with_rng_state_op():
    class RunWithRngState(HigherOrderOperator):
        def __init__(self):
            super().__init__("run_with_rng_state")

        def __call__(self, rng_state, op, *args, **kwargs):
            return super().__call__(rng_state, op, *args, **kwargs)

    run_with_rng_state = RunWithRngState()

    run_with_rng_state.py_impl(DispatchKey.Autograd)(
        autograd_not_implemented(run_with_rng_state, deferred_error=True)
    )

    @run_with_rng_state.py_impl(DispatchKey.CUDA)
    def impl_cuda(rng_state, op, *args, **kwargs):
        current_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(rng_state.cpu())
        out = op(*args, **kwargs)
        torch.cuda.set_rng_state(current_state)
        return out

    @run_with_rng_state.py_impl(DispatchKey.CPU)
    def impl_cpu(rng_state, op, *args, **kwargs):
        current_state = torch.get_rng_state()
        torch.set_rng_state(rng_state)
        out = op(*args, **kwargs)
        torch.set_rng_state(current_state)
        return out

    @run_with_rng_state.py_impl(DispatchKey.HPU)
    def impl_hpu(rng_state, op, *args, **kwargs):
        if hasattr(torch, "hpu"):
            current_state = torch.hpu.get_rng_state()
            torch.hpu.set_rng_state(rng_state)
            out = op(*args, **kwargs)
            torch.hpu.set_rng_state(current_state)
            return out
        raise RuntimeError("Functionalizing an HPU RNG operator is not supported.")

    @run_with_rng_state.py_impl(DispatchKey.XPU)
    def impl_xpu(rng_state, op, *args, **kwargs):
        current_state = torch.xpu.get_rng_state()
        torch.xpu.set_rng_state(rng_state)
        out = op(*args, **kwargs)
        torch.xpu.set_rng_state(current_state)
        return out

    @run_with_rng_state.py_impl(ProxyTorchDispatchMode)
    def impl_proxy_dispatch_mode(mode, rng_state, op, *args, **kwargs):
        # TODO: you don't need to do this; the dispatch here already disables it.
        with disable_proxy_modes_tracing():
            out = run_with_rng_state(rng_state, op, *args, **kwargs)
        proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (rng_state, op, *args))
        proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
        out_proxy = mode.tracer.create_proxy(
            "call_function", run_with_rng_state, proxy_args, proxy_kwargs
        )
        return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)

    @run_with_rng_state.py_impl(DispatchKey.BackendSelect)
    def impl_backend_select(rng_state, op, *args, **kwargs):
        impl_map = {
            "cuda": impl_cuda,
            "cpu": impl_cpu,
            "hpu": impl_hpu,
            "xpu": impl_xpu,
        }
        device = get_device(args, kwargs)
        assert device in impl_map, f"Backend not supported for {device}"
        impl = impl_map[device]
        return impl(rng_state, op, *args, **kwargs)

    @run_with_rng_state.py_impl(FakeTensorMode)
    def impl_fake_tensor_mode(mode, rng_state, op, *args, **kwargs):
        # Skip setting the rng state, as it does not work well with fake tensors
        # and does not matter in fake tensor mode.
        with mode:
            return op(*args, **kwargs)

    @run_with_rng_state.py_functionalize_impl
    def impl_functional(ctx, rng_state, op, *args, **kwargs):
        unwrapped_rng_state = ctx.unwrap_tensors(rng_state)
        unwrapped_args = ctx.unwrap_tensors(args)
        unwrapped_kwargs = ctx.unwrap_tensors(kwargs)

        with ctx.redispatch_to_next():
            out = run_with_rng_state(
                unwrapped_rng_state, op, *unwrapped_args, **unwrapped_kwargs
            )
            return ctx.wrap_tensors(out)

    return run_with_rng_state


run_and_save_rng_state = register_run_and_save_rng_state_op()
run_with_rng_state = register_run_with_rng_state_op()
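# Sketch of the intended pairing (hypothetical op/device): the state captured by
# run_and_save_rng_state can later be fed to run_with_rng_state to replay the op
# with the same random numbers, e.g. when a randomized op must be recomputed:
#
#   state, out1 = run_and_save_rng_state(
#       torch.ops.aten.rand.default, [8], device=torch.device("cuda")
#   )
#   out2 = run_with_rng_state(
#       state, torch.ops.aten.rand.default, [8], device=torch.device("cuda")
#   )
#   # out1 and out2 are expected to match bitwise; the ambient RNG state is
#   # restored after the replay.
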


def register_rng_prims():
    register_philox_rand()