# mypy: allow-untyped-defs
from typing import Optional, Tuple

import torch
import torch.utils._pytree as pytree
from torch import _prims
from torch._C import DispatchKey
from torch._higher_order_ops.utils import autograd_not_implemented
from torch._ops import HigherOrderOperator
from torch._prims_common import CUDARngStateHelper, make_contiguous_strides_for
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.types import _device, _dtype


def throw_on_non_cuda(device):
    raise RuntimeError(
        f"You are trying to functionalize a {device.type} RNG operator but {device.type} does not "
        f"use Philox/counter-based RNG. Therefore, functionalizing a {device.type} RNG operator is "
        "not supported. We are discussing the possibility of a Philox-based RNG implementation for CPU."
    )


def register_rng_prim(name, schema, impl_aten, impl_meta, doc, tags=None):
    rngprim_def = torch.library.custom_op(
        "rngprims::" + name, impl_aten, mutates_args=(), schema=schema
    )
    rngprim_def.register_fake(impl_meta)

    prim_packet = getattr(torch._ops.ops.rngprims, name)
    prim = prim_packet.default
    if tags:
        prim._tags = tags

    for p in (prim_packet, prim):
        p.__doc__ = doc
        p.return_type = torch._prims_common.RETURN_TYPE.NEW  # type: ignore[attr-defined]

        p.schema = name + schema
        p.impl_aten = impl_aten
        p.prim_meta_impl = impl_meta


# Philox rand offsets could be shared in the future with other Philox ops, so
# these functions are kept at module scope.
def philox_rand_offset_meta(
    shape: torch.Size,
):
    return _prims.TensorLike(torch.tensor(0, dtype=torch.int64))


def philox_rand_offset(
    shape: torch.Size,
):
    # For the impl, see the function calc_execution_policy in
    # aten/src/ATen/native/cuda/DistributionTemplates.h. The impl was copied at
    # commit hash 72aa0667bd16707d50eb8fa337092a1f5d11dfb6
    numel_scalar = 1
    for dim_size in shape:
        numel_scalar *= dim_size
    numel = torch.scalar_tensor(numel_scalar, dtype=torch.int64)

    block_size = 256
    unroll = 4
    curand4_engine_calls = 4
    device_property = torch.cuda.get_device_properties(torch.cuda.current_device())
    blocks_per_sm = device_property.max_threads_per_multi_processor // block_size
    grid_size = (numel + block_size - 1) // block_size
    grid_size = min(grid_size, device_property.multi_processor_count * blocks_per_sm)
    offset = (
        (numel - 1) // (block_size * grid_size * unroll) + 1
    ) * curand4_engine_calls
    return offset
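

# Worked example of the offset formula above (a sketch with assumed hardware
# numbers, not measured values): for shape (64, 64), numel = 4096, on a
# hypothetical device with 108 SMs and 2048 max threads per SM:
#   blocks_per_sm = 2048 // 256 = 8
#   grid_size     = min(ceil(4096 / 256), 108 * 8) = min(16, 864) = 16
#   offset        = ((4096 - 1) // (256 * 16 * 4) + 1) * 4 = (0 + 1) * 4 = 4
# i.e. a philox_rand call of this size advances the Philox offset by 4.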


def register_philox_rand():
    name = "philox_rand"
    schema = "(SymInt[] size, Tensor seed, Tensor offset, int[]? stride, Device? device=None, ScalarType? dtype=None) -> (Tensor, Tensor)"  # noqa: B950

    def _philox_rand_meta(
        shape: torch.Size,
        seed: torch.Tensor,
        offset: torch.Tensor,
        stride: Optional[Tuple[int, ...]],
        device: _device,
        dtype: _dtype,
    ):
        # The stride arg will be useful for the distributed use case. Currently, it's unused.
        assert stride is None
        stride = make_contiguous_strides_for(shape)
        random_values = _prims.TensorMeta(
            shape=shape, strides=stride, dtype=dtype, device=device
        )
        offset = philox_rand_offset_meta(shape)
        return (random_values, offset)

    def _philox_rand(
        shape: torch.Size,
        seed: torch.Tensor,
        offset: torch.Tensor,
        stride: Optional[Tuple[int, ...]],
        device: _device,
        dtype: _dtype,
    ):
        # The stride arg will be useful for the distributed use case. Currently, it's unused.
        assert stride is None
        if device.type == "cpu":
            devices = []
        else:
            devices = [device]

        if device.type != "cuda":
            raise throw_on_non_cuda(device)

        with torch.random.fork_rng(devices):
            CUDARngStateHelper.set_torch_state_tensor(seed, offset)
            random_values = torch.rand(shape, device=device, dtype=dtype)

        return random_values, philox_rand_offset(shape)

    register_rng_prim(
        name=name,
        schema=schema,
        impl_aten=_philox_rand,
        impl_meta=_philox_rand_meta,
        doc="Philox-based stateless rand operator",
        tags=(torch.Tag.nondeterministic_seeded,),
    )
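

# Usage sketch (illustrative only, not invoked anywhere in this module): once
# register_rng_prims() below has run, the prim can be called directly with an
# explicit seed/offset pair on a CUDA device. The concrete tensors here are
# assumptions for the example, not values this module produces.
#
#     seed = torch.tensor(123, dtype=torch.int64)
#     offset = torch.tensor(0, dtype=torch.int64)
#     values, consumed_offset = torch.ops.rngprims.philox_rand(
#         (4, 4), seed, offset, None, torch.device("cuda"), torch.float32
#     )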


def get_device(args, kwargs):
    if kwargs.get("device"):
        device = kwargs.get("device")
        if isinstance(device, str):
            device = torch.device(device)
        return device.type

    devices = {arg.device.type for arg in args if isinstance(arg, torch.Tensor)}
    if any(dev == "cuda" for dev in devices):
        return "cuda"
    elif any(dev == "xpu" for dev in devices):
        return "xpu"
    elif any(dev == "hpu" for dev in devices):
        return "hpu"
    elif any(dev == "cpu" for dev in devices):
        return "cpu"
    return None


def register_run_and_save_rng_state_op():
    class RunAndSaveRngState(HigherOrderOperator):
        def __init__(self):
            super().__init__("run_and_save_rng_state")

        def __call__(self, op, *args, **kwargs):
            return super().__call__(op, *args, **kwargs)

    run_and_save_rng_state = RunAndSaveRngState()

    run_and_save_rng_state.py_impl(DispatchKey.Autograd)(
        autograd_not_implemented(run_and_save_rng_state, deferred_error=True)
    )

    @run_and_save_rng_state.py_impl(DispatchKey.CUDA)
    def impl_cuda(op, *args, **kwargs):
        return torch.cuda.get_rng_state(), op(*args, **kwargs)

    @run_and_save_rng_state.py_impl(DispatchKey.CPU)
    def impl_cpu(op, *args, **kwargs):
        return torch.get_rng_state(), op(*args, **kwargs)

    @run_and_save_rng_state.py_impl(DispatchKey.HPU)
    def impl_hpu(op, *args, **kwargs):
        if hasattr(torch, "hpu"):
            return torch.hpu.get_rng_state(), op(*args, **kwargs)
        raise RuntimeError("Functionalizing an HPU RNG operator is not supported.")

    @run_and_save_rng_state.py_impl(DispatchKey.XPU)
    def impl_xpu(op, *args, **kwargs):
        return torch.xpu.get_rng_state(), op(*args, **kwargs)

    @run_and_save_rng_state.py_impl(DispatchKey.BackendSelect)
    def impl_backend_select(op, *args, **kwargs):
        impl_map = {
            "cuda": impl_cuda,
            "cpu": impl_cpu,
            "hpu": impl_hpu,
            "xpu": impl_xpu,
        }
        device = get_device(args, kwargs)
        assert device in impl_map, f"Backend not supported for {device}"
        impl = impl_map[device]
        return impl(op, *args, **kwargs)

    @run_and_save_rng_state.py_impl(FakeTensorMode)
    def impl_fake_tensor_mode(mode, op, *args, **kwargs):
        # Check the device to call the right impl
        with mode:
            return impl_backend_select(op, *args, **kwargs)

    @run_and_save_rng_state.py_impl(ProxyTorchDispatchMode)
    def impl_proxy_dispatch_mode(mode, op, *args, **kwargs):
        out = impl_backend_select(op, *args, **kwargs)
        proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
        proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
        out_proxy = mode.tracer.create_proxy(
            "call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
        )
        return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)

    return run_and_save_rng_state


def register_run_with_rng_state_op():
    class RunWithRngState(HigherOrderOperator):
        def __init__(self):
            super().__init__("run_with_rng_state")

        def __call__(self, rng_state, op, *args, **kwargs):
            return super().__call__(rng_state, op, *args, **kwargs)

    run_with_rng_state = RunWithRngState()

    run_with_rng_state.py_impl(DispatchKey.Autograd)(
        autograd_not_implemented(run_with_rng_state, deferred_error=True)
    )

    @run_with_rng_state.py_impl(DispatchKey.CUDA)
    def impl_cuda(rng_state, op, *args, **kwargs):
        current_state = torch.cuda.get_rng_state()
        torch.cuda.set_rng_state(rng_state.cpu())
        out = op(*args, **kwargs)
        torch.cuda.set_rng_state(current_state)
        return out

    @run_with_rng_state.py_impl(DispatchKey.CPU)
    def impl_cpu(rng_state, op, *args, **kwargs):
        current_state = torch.get_rng_state()
        torch.set_rng_state(rng_state)
        out = op(*args, **kwargs)
        torch.set_rng_state(current_state)
        return out

    @run_with_rng_state.py_impl(DispatchKey.HPU)
    def impl_hpu(rng_state, op, *args, **kwargs):
        if hasattr(torch, "hpu"):
            current_state = torch.hpu.get_rng_state()
            torch.hpu.set_rng_state(rng_state)
            out = op(*args, **kwargs)
            torch.hpu.set_rng_state(current_state)
            return out
        raise RuntimeError("Functionalizing an HPU RNG operator is not supported.")

    @run_with_rng_state.py_impl(DispatchKey.XPU)
    def impl_xpu(rng_state, op, *args, **kwargs):
        current_state = torch.xpu.get_rng_state()
        torch.xpu.set_rng_state(rng_state)
        out = op(*args, **kwargs)
        torch.xpu.set_rng_state(current_state)
        return out

    @run_with_rng_state.py_impl(ProxyTorchDispatchMode)
    def impl_proxy_dispatch_mode(mode, rng_state, op, *args, **kwargs):
        # TODO: this is not strictly needed; dispatch has already disabled proxy
        # tracing by this point.
        with disable_proxy_modes_tracing():
            out = run_with_rng_state(rng_state, op, *args, **kwargs)
        proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (rng_state, op, *args))
        proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
        out_proxy = mode.tracer.create_proxy(
            "call_function", run_with_rng_state, proxy_args, proxy_kwargs
        )
        return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)

    @run_with_rng_state.py_impl(DispatchKey.BackendSelect)
    def impl_backend_select(rng_state, op, *args, **kwargs):
        impl_map = {
            "cuda": impl_cuda,
            "cpu": impl_cpu,
            "hpu": impl_hpu,
            "xpu": impl_xpu,
        }
        device = get_device(args, kwargs)
        assert device in impl_map, f"Backend not supported for {device}"
        impl = impl_map[device]
        return impl(rng_state, op, *args, **kwargs)

    @run_with_rng_state.py_impl(FakeTensorMode)
    def impl_fake_tensor_mode(mode, rng_state, op, *args, **kwargs):
        # Skip restoring the RNG state: it does not work well with fake tensors,
        # and it does not matter under fake tensor mode.
        with mode:
            return op(*args, **kwargs)

    @run_with_rng_state.py_functionalize_impl
    def impl_functional(ctx, rng_state, op, *args, **kwargs):
        unwrapped_rng_state = ctx.unwrap_tensors(rng_state)
        unwrapped_args = ctx.unwrap_tensors(args)
        unwrapped_kwargs = ctx.unwrap_tensors(kwargs)

        with ctx.redispatch_to_next():
            out = run_with_rng_state(
                unwrapped_rng_state, op, *unwrapped_args, **unwrapped_kwargs
            )
            return ctx.wrap_tensors(out)

    return run_with_rng_state


run_and_save_rng_state = register_run_and_save_rng_state_op()
run_with_rng_state = register_run_with_rng_state_op()


def register_rng_prims():
    register_philox_rand()
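

# Illustrative sketch (hypothetical helper, not part of this module's API): how
# the two higher-order ops above are meant to pair up. run_and_save_rng_state
# captures the RNG state alongside the op's output; run_with_rng_state replays
# the op later under that captured state, so both calls draw the same random
# numbers. The op and tensor shape below are assumptions for the example only.
def _example_rng_state_replay():
    x = torch.randn(4, 4)
    rng_state, out = run_and_save_rng_state(torch.ops.aten.dropout, x, 0.5, True)
    replayed = run_with_rng_state(rng_state, torch.ops.aten.dropout, x, 0.5, True)
    # `out` and `replayed` should match, since both run under the same RNG state.
    return out, replayed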