xref: /aosp_15_r20/external/pytorch/torch/_subclasses/fake_impls.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
# mypy: ignore-errors

import functools
import itertools
import math
import sys
from typing import Callable, Union

import torch
import torch._custom_op
import torch._logging
from torch._dispatch.python import no_python_dispatcher
from torch._ops import OpOverload
from torch._prims_common import (
    elementwise_dtypes,
    ELEMENTWISE_TYPE_PROMOTION_KIND,
    is_boolean_dtype,
    is_float_dtype,
    is_integer_dtype,
)
from torch._subclasses.fake_tensor import (
    DataDependentOutputException,
    DynamicOutputShapeException,
    FakeTensor,
    in_kernel_invocation_manager,
    run_fallback_kernel,
    UnsupportedOperatorException,
)
from torch.fx.operator_schemas import normalize_function
from torch.utils._stats import count_label


pytree = torch.utils._pytree

__all__ = [
    "op_implementations_checks",
    "get_fast_op_impls",
    "stride_incorrect_op",
    "has_meta",
]

op_implementations_dict = {}
op_implementations_checks = []


aten = torch._ops.ops.aten

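# Note: a dict doubles as a cheap ordered set below: insertion order is
# preserved and `x in d` is an O(1) membership test on the keys, e.g.
#
#   s = ordered_set(aten.copy.default, aten.copy_.default)
#   aten.copy.default in s  # True, with stable iteration order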
def ordered_set(*items):
    return dict.fromkeys(items, True)


# This function indicates if the backend device
# supports non-contiguous tensors
def is_noncontiguous_supported(device):
    return device.type != "hpu"


_like_tensor_constructors = ordered_set(
    aten.empty_like.default,
    aten.empty_like.out,
    aten.full_like.default,
    aten.full_like.out,
    aten.ones_like.default,
    aten.ones_like.out,
    aten.rand_like.default,
    aten.rand_like.out,
    aten.randn_like.default,
    aten.randn_like.out,
    aten.randint_like.default,
    aten.randint_like.out,
    aten.randint_like.low_dtype,
    aten.randint_like.low_dtype_out,
    aten.zeros_like.default,
    aten.zeros_like.out,
    aten.new_empty.default,
    aten.new_empty.out,
    aten.new_empty_strided.default,
    aten.new_empty_strided.out,
    aten.new_full.default,
    aten.new_full.out,
    aten.new_zeros.default,
    aten.new_zeros.out,
    aten.new_ones.default,
    aten.new_ones.out,
)


_device_not_kwarg_ops = ordered_set(
    aten._resize_output_.default,
    aten._nested_tensor_from_tensor_list.default,
    aten._nested_tensor_from_tensor_list.out,
    aten.pin_memory.default,
    aten.to.device,
    aten.to.prim_Device,
    aten.is_pinned.default,
    aten._pin_memory.default,
    aten._pin_memory.out,
    aten._resize_output.default,
    aten._resize_output.out,
)

# this op is never actually used
_non_kwarg_device_constructors = (aten._list_to_tensor,)


def contains_tensor_types(type):
    tensor_type = torch._C.TensorType.get()
    return type.isSubtypeOf(tensor_type) or any(
        contains_tensor_types(e) for e in type.containedTypes()
    )


@functools.lru_cache(None)
def _is_tensor_constructor(func: OpOverload):
    assert isinstance(func, OpOverload)
    schema = func._schema
    if any(contains_tensor_types(arg.type) for arg in schema.arguments):
        return False
    # TODO: no real reason to restrict multiple outputs
    return (
        len(schema.returns) == 1 and schema.returns[0].type is torch._C.TensorType.get()
    )
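
# A quick illustration of the predicate above (worked out by hand, not a
# doctest): aten.empty.memory_format takes no Tensor arguments and returns a
# single Tensor, so it is treated as a constructor, while aten.add.Tensor has
# Tensor inputs and is rejected.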


def register_op_impl(run_impl_check: Union[Callable[[OpOverload], bool], OpOverload]):
    def impl_decorator(op_impl):
        if isinstance(run_impl_check, OpOverload):
            assert (
                run_impl_check not in op_implementations_dict
            ), f"duplicate registration: {run_impl_check}"
            op_implementations_dict[run_impl_check] = op_impl
        elif isinstance(run_impl_check, (list, tuple)):
            for op in run_impl_check:
                register_op_impl(op)(op_impl)
        else:
            assert callable(run_impl_check)
            op_implementations_checks.append((run_impl_check, op_impl))

        return op_impl

    return impl_decorator
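
# register_op_impl accepts three kinds of keys, all of which appear in this
# file: a concrete OpOverload (e.g. @register_op_impl(aten.is_pinned.default)),
# a list/tuple of overloads (e.g. @register_op_impl([*_like_tensor_constructors])),
# or a predicate on the overload (e.g. @register_op_impl(stride_incorrect_op)).
# Direct registrations go into op_implementations_dict for O(1) dispatch;
# predicates are appended to op_implementations_checks and tried in order.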


@register_op_impl(op_implementations_dict.__contains__)
def dispatch_to_op_implementations_dict(fake_mode, func, *args, **kwargs):
    return op_implementations_dict[func](fake_mode, func, *args, **kwargs)


@register_op_impl(_is_tensor_constructor)
@register_op_impl([*_like_tensor_constructors])
def constructors(fake_mode, func, *args, **kwargs):
    assert func not in _non_kwarg_device_constructors
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    if "names" in kwargs:
        raise UnsupportedOperatorException(
            "torch.compile doesn't support named tensors"
        )

    if func in _like_tensor_constructors:
        default_device = new_kwargs["input"].device
        # TODO: file issue
        args = (new_kwargs.pop("input"),)
    else:
        # cpu is default device if none is specified
        default_device = torch.device("cpu")
        args = ()
    out_device = new_kwargs.pop("device", None)
    out_device = out_device if out_device is not None else default_device
    new_kwargs["device"] = torch.device("meta")
    # _like constructors have fake tensor inputs (maybe this causes the non-like
    # to fail? hmmm)
    with in_kernel_invocation_manager(fake_mode):
        r = func(*args, **new_kwargs)
    return FakeTensor(fake_mode, r, out_device)
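
# The pattern above is the core of most fake implementations in this file:
# record the device the output should report, rewrite the call so it actually
# runs on the "meta" device inside in_kernel_invocation_manager (which only
# computes shapes/dtypes/strides, no data), and then wrap the meta result in a
# FakeTensor that carries the recorded device.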


@register_op_impl(aten.is_pinned.default)
def non_kwarg_is_pinned(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args, kwargs, normalize_to_only_use_kwargs=True
    )
    inp = new_kwargs.pop("input")
    # we'll ignore device argument because it is deprecated and not
    # actually used by is_pinned.
    with in_kernel_invocation_manager(fake_mode):
        r = func(inp)
    return r


@register_op_impl(aten.to.prim_Device)
@register_op_impl(aten.to.device)
def non_kwarg_to(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args, kwargs, normalize_to_only_use_kwargs=True
    )
    input_device = new_kwargs["device"]
    out_device = input_device if input_device else new_kwargs["input"].device
    new_kwargs["device"] = torch.device("meta")
    inp = new_kwargs.pop("input")
    with in_kernel_invocation_manager(fake_mode):
        r = func(inp, **new_kwargs)
    # TODO: I think this does the wrong thing if r is inp
    return fake_mode.fake_tensor_converter.from_meta_and_device(
        fake_mode, r, out_device
    )


def stride_incorrect_op(op):
    if op.namespace not in ("aten", "prims"):
        return False
    if op is aten._fft_c2c.default:
        return False

    op_name = op.name()
    if "fft" in op_name:
        return True
    return False


# These operators have meta implementations with incorrect strides
@register_op_impl(stride_incorrect_op)
def workaround_stride_incorrect_op(fake_mode, func, *args, **kwargs):
    # This is a workaround for meta implementations with incorrect strides

    def is_symbolic(x):
        if isinstance(x, FakeTensor):
            return x._has_symbolic_sizes_strides
        if isinstance(x, (torch.SymInt, torch.SymFloat, torch.SymBool)):
            return True
        return False

    # For static shapes, we can fall back to eager for the real strides
    if fake_mode.allow_fallback_kernels:
        require_dynamic = any(
            is_symbolic(x) for x in itertools.chain(args, kwargs.values())
        )
        if not require_dynamic:
            flat_args, args_spec = pytree.tree_flatten((args, kwargs))
            return run_fallback_kernel(fake_mode, func, flat_args, args_spec, None)

    raise UnsupportedOperatorException(func)
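
# In other words: FFT-family ops (whose meta kernels report wrong strides) are
# only handled when every input has static shapes and fallback kernels are
# allowed, in which case the real kernel is run to obtain correct strides;
# with symbolic shapes they surface as UnsupportedOperatorException instead.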


# Don't default to default device handling,
# since the device of `the_template` is ignored
@register_op_impl(aten.resize_as_.default)
def resize_as_(fake_mode, func, *args, **kwargs):
    with in_kernel_invocation_manager(fake_mode):
        return func(*args, **kwargs)


@register_op_impl(aten._sparse_coo_tensor_with_dims_and_tensors.default)
def _sparse_coo_tensor_with_dims_and_tensors(fake_mode, func, *args, **kwargs):
    # TODO: remove me
    return constructors(fake_mode, func, *args, **kwargs)


# index.Tensor is data-dependent in only some conditions
@register_op_impl(
    lambda func: torch.Tag.dynamic_output_shape in func.tags
    and func
    not in [aten.index.Tensor, aten.nonzero.default, aten.repeat_interleave.Tensor]
)
def dyn_shape(fake_mode, func, *args, **kwargs):
    raise DynamicOutputShapeException(func)


def _unique(
    fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    # Do not use a memo for unique_dim
    if dim is not None or (nnz := arg.unique_memo) is None:
        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )

        if not has_free_symbols(arg.numel()) and arg.numel() == 0:
            # If numel is zero, then the output size must be zero.
            # In this case, we must not allocate an unbacked SymInt,
            # because if we do, it will immediately get refined to
            # zero, but this will be inconsistent with size oblivious
            # tests (which will continue to claim that the unbacked
            # symint cannot equal zero).  We could also unconditionally
            # allocate an unbacked SymInt and not refine its range,
            # but this seems more precise.
            nnz = 0
        else:
            nnz = fake_mode.shape_env.create_unbacked_symint()

            maxval = sys.maxsize - 1

            numel = arg.numel() if dim is None else arg.size(dim)
            if not has_free_symbols(numel):
                maxval = int(numel)

            _constrain_range_for_size(nnz, max=maxval)

        if dim is None:
            arg.unique_memo = nnz

    if dim is None:
        ret = [arg.new_empty((nnz,))]
    else:
        ret = [arg.new_empty(*arg.shape[:dim], nnz, *arg.shape[dim + 1 :])]

    return_if_dim_and_cpu = dim is not None and arg.fake_device == torch.device("cpu")
    if return_inverse or return_if_dim_and_cpu:
        inverse = arg.new_empty(arg.shape if dim is None else (arg.shape[dim],))
    else:
        inverse = arg.new_empty(0)
    ret.append(inverse)

    if return_counts or return_if_dim_and_cpu:
        counts = arg.new_empty(ret[0].shape if dim is None else (ret[0].shape[dim],))
    else:
        counts = arg.new_empty(0)
    ret.append(counts)

    return tuple(ret)
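
# Summary of the pattern above: the number of unique elements is data
# dependent, so under a ShapeEnv that allows dynamic output shapes we allocate
# a fresh unbacked SymInt for it, constrain it to a size-like range (at most
# numel / size(dim) when that is statically known), and, for the dim=None
# overloads, memoize it on the input so repeated torch.unique calls on the
# same fake tensor agree on the output size.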


@register_op_impl(aten._unique2.default)
def unique2(
    fake_mode, func, arg, sorted=True, return_inverse=False, return_counts=False
):
    return _unique(fake_mode, func, arg, None, sorted, return_inverse, return_counts)


@register_op_impl(aten.unique_dim.default)
def unique_dim(
    fake_mode, func, arg, dim, sorted=True, return_inverse=False, return_counts=False
):
    return _unique(
        fake_mode,
        func,
        arg,
        # normalize dim to be non-negative
        dim if dim >= 0 else dim % max(arg.ndim, 1),
        sorted,
        return_inverse,
        return_counts,
    )


@register_op_impl(aten.repeat_interleave.Tensor)
def repeat_interleave_tensor(fake_mode, func, repeats, output_size=None):
    if output_size is None:
        if (
            fake_mode.shape_env is None
            or not fake_mode.shape_env.allow_dynamic_output_shape_ops
        ):
            raise DynamicOutputShapeException(func)

        output_size = fake_mode.shape_env.create_unbacked_symint()

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

        _constrain_range_for_size(output_size)
        # TODO: consider a memo
    return repeats.new_empty(output_size)


@register_op_impl(torch.ops.aten.item.default)
@register_op_impl(torch.ops.aten._local_scalar_dense.default)
def local_scalar_dense(fake_mode, func, arg):
    if (r := arg.item_memo) is not None:
        return r
    if fake_mode.shape_env is None or (
        not fake_mode.shape_env.allow_scalar_outputs
        and not fake_mode.allow_scalar_outputs
    ):
        # Without symints/symfloats, cannot handle this
        raise DataDependentOutputException(func)
    if is_float_dtype(arg.dtype):
        r = fake_mode.shape_env.create_unbacked_symfloat()
    elif is_integer_dtype(arg.dtype):
        r = fake_mode.shape_env.create_unbacked_symint()
    elif is_boolean_dtype(arg.dtype):
        r = fake_mode.shape_env.create_unbacked_symbool()
    else:
        raise NotImplementedError(f"local_scalar_dense/item NYI for {arg.dtype}")
    arg.item_memo = r
    return r
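
# So, for example, calling .item() on a fake integer tensor yields an unbacked
# SymInt (a SymFloat for float dtypes, a SymBool for bool), and the result is
# cached in item_memo so that calling .item() twice on the same fake tensor
# returns the same symbol rather than allocating a new one. Without a ShapeEnv
# that allows scalar outputs, a DataDependentOutputException is raised.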


@register_op_impl(torch.ops.aten.nonzero.default)
def nonzero(fake_mode, func, arg):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    if (nnz := arg.nonzero_memo) is None:
        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import (
            _constrain_range_for_size,
            has_free_symbols,
        )

        if not has_free_symbols(arg.numel()) and arg.numel() == 0:
            # If numel is zero, then the output size must be zero.
            # In this case, we must not allocate an unbacked SymInt,
            # because if we do, it will immediately get refined to
            # zero, but this will be inconsistent with size oblivious
            # tests (which will continue to claim that the unbacked
            # symint cannot equal zero).  We could also unconditionally
            # allocate an unbacked SymInt and not refine its range,
            # but this seems more precise.
            nnz = 0
        else:
            nnz = fake_mode.shape_env.create_unbacked_symint()

            maxval = sys.maxsize - 1

            if not has_free_symbols(arg.numel()):
                maxval = int(arg.numel())

            _constrain_range_for_size(nnz, max=maxval)

        arg.nonzero_memo = nnz

    return arg.new_empty((nnz, arg.dim()), dtype=torch.int64)
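
# The same unbacked-SymInt scheme as _unique above, specialized to nonzero:
# the fake output has shape (nnz, arg.dim()) where nnz is memoized on the
# input. So, roughly, a fake tensor of shape (3, 4) produces a nonzero result
# of shape (u0, 2) with u0 constrained to at most 12.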


@register_op_impl(torch.ops.aten.masked_select.default)
def masked_select(fake_mode, func, self, mask):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    nnz = fake_mode.shape_env.create_unbacked_symint()

    # see nonzero for commentary
    maxval = sys.maxsize - 1

    # Avoid importing sympy at a module level
    from torch.fx.experimental.symbolic_shapes import (
        _constrain_range_for_size,
        has_free_symbols,
    )
    from torch.utils._sympy.numbers import IntInfinity
    from torch.utils._sympy.value_ranges import bound_sympy

    # If num elements is expressed symbolically, calculate
    # the concrete value based on upper bounds. Otherwise,
    # we can set max val directly.
    if not has_free_symbols(self.numel()):
        num_elements = int(self.numel())
    else:
        prod_node = math.prod(self.shape).node
        prod_range = bound_sympy(prod_node.expr, prod_node.shape_env.var_to_range)
        if isinstance(prod_range.upper, IntInfinity):
            num_elements = sys.maxsize - 1
        else:
            num_elements = prod_range.upper
    if num_elements > 2:
        maxval = num_elements

    _constrain_range_for_size(nnz, max=maxval)

    return self.new_empty((nnz,))


# NB: this must be ordered after local_scalar_dense
@register_op_impl(lambda func: torch.Tag.data_dependent_output in func.tags)
def data_dep(fake_mode, func, *args, **kwargs):
    raise DataDependentOutputException(func)


# Bool Indices get Expanded as Masks
# See: IndexingUtils.h:expandTensors
def check_no_bool_index_tensors(func, self, indices):
    for index in indices:
        if index is not None and index.dtype in (torch.bool, torch.uint8):
            raise DynamicOutputShapeException(func)


def run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    out_device = new_kwargs["input"].device
    with in_kernel_invocation_manager(fake_mode):
        out = func(*args, **kwargs)
        if not is_noncontiguous_supported(out_device):
            out = out.new_empty(out.shape)

    if out is new_kwargs["input"]:
        return out  # copy_
    return FakeTensor(fake_mode, out, out_device)
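
# Helper used by several multi-device ops below: run the op on meta, then
# report the output on the device of the "input" argument. If the op mutated
# its input in place (out is the input), return it unchanged; on backends
# that do not support non-contiguous tensors (currently hpu), replace the
# result with a contiguous new_empty of the same shape.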


_is_builtin_namespaces = ordered_set("aten", "prims", "prim")


def is_builtin(op):
    return op.namespace in _is_builtin_namespaces


def has_meta(func):
    return torch._C._dispatch_has_computed_kernel_for_dispatch_key(func.name(), "Meta")


@register_op_impl(
    lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func)
)
def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs):
    tensor_lists = []
    for arg in itertools.chain(args, kwargs.values()):
        if (
            isinstance(arg, (list, tuple))
            and len(arg)
            and isinstance(arg[0], torch.Tensor)
        ):
            tensor_lists.append(arg)

    try:
        with in_kernel_invocation_manager(fake_mode):
            out_meta = func(*args, **kwargs)
    except NotImplementedError:
        return NotImplemented

    if not out_meta:
        return out_meta

    assert tensor_lists
    out_fake = []

    for i, meta_t in enumerate(out_meta):
        device, _ = FakeTensor._find_common_device(func, [tl[i] for tl in tensor_lists])
        out_fake.append(
            fake_mode.fake_tensor_converter.from_meta_and_device(
                fake_mode, meta_t, device
            )
        )

    return out_fake
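
# For _foreach_* ops the output is a list, so the device is recomputed per
# element: the i-th output takes the common device of the i-th tensors across
# all input tensor lists, which keeps mixed-device foreach calls faithful.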


# Don't default to default device handling,
# since the op can take non-zero-sized cpu
# index tensors with a cuda self tensor
@register_op_impl(aten.index.Tensor)
def index_tensor(fake_mode, func, *args, **kwargs):
    from torch._meta_registrations import meta_index_Tensor

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    out_device = new_kwargs["input"].device
    # ensure nonzero call goes to fake tensor
    with fake_mode:
        out = meta_index_Tensor(*args, **kwargs)
        return out.to(out_device)


# Can take mixed meta/non-meta arguments; the meta registration
# will roughly do the right thing even when given real devices
@register_op_impl(aten._embedding_bag.default)
def embedding_bag(fake_mode, func, *args, **kwargs):
    from torch._meta_registrations import meta_embedding_bag

    with fake_mode:
        return meta_embedding_bag(*args, **kwargs)


# takes in multiple devices, don't default to default device handling
@register_op_impl(aten._unsafe_index_put.default)
@register_op_impl(aten.copy.default)
@register_op_impl(aten.copy_.default)
@register_op_impl(aten.slice_scatter.default)
def multi_device_op_default(fake_mode, func, *args, **kwargs):
    return run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)


# same as multi_device_op_default, but returns the input
@register_op_impl(aten.copy.out)
@register_op_impl(aten.slice_scatter.out)
def multi_device_op_out(fake_mode, func, *args, **kwargs):
    with in_kernel_invocation_manager(fake_mode):
        out = func(*args, **kwargs)

    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    return new_kwargs["input"]


@register_op_impl(aten.index_put.default)
@register_op_impl(aten.index_put_.default)
def index_put_impl(fake_mode, func, *args, **kwargs):
    _, new_kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )

    values = new_kwargs["values"]
    self_device = new_kwargs["input"].fake_device
    torch._check(
        self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1),
        lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})",
    )

    out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
    if func is aten.index_put_.default:
        return new_kwargs["input"]
    else:
        return out


@register_op_impl(aten._nested_tensor_from_tensor_list.default)
@register_op_impl(aten._nested_tensor_from_tensor_list.out)
@register_op_impl(aten._nested_view_from_buffer.default)
@register_op_impl(aten._nested_view_from_buffer_copy.default)
def nested_tensors_unsupported(fake_mode, func, *args, **kwargs):
    raise UnsupportedOperatorException(
        "torch.compile does not support strided NestedTensor"
    )


@register_op_impl(
    [
        x
        for x in _device_not_kwarg_ops
        if x
        not in (
            # these are already registered elsewhere
            aten.is_pinned.default,
            aten.to.device,
            aten.to.prim_Device,
            aten._nested_tensor_from_tensor_list.default,
            aten._nested_tensor_from_tensor_list.out,
        )
    ]
)
def nyi(fake_mode, func, *args, **kwargs):
    assert func not in _device_not_kwarg_ops, f"NYI: {func}"


@register_op_impl([aten.convolution.default, aten.convolution_backward.default])
def conv(fake_mode, func, *args, **kwargs):
    _, kwargs = normalize_function(
        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
    )
    device = kwargs["input"].fake_device
    # need to re-enable mode so the tensors report fake device
    with fake_mode:
        # if the input unsqueezing is done in Convolution.cpp we get a segfault
        k = kwargs["weight"].ndim
        batch = kwargs["input"].shape[0]

        # Avoid importing sympy at a module level
        from torch.fx.experimental.symbolic_shapes import has_hint

        if not has_hint(batch):
            # TODO: We can make this a little more faithful with best effort
            # channels last detection (but only if it's statically obvious!)
            mem_fmt = None
        elif k == 3 and not kwargs["input"].is_mkldnn and not kwargs["input"].is_xpu:
            mem_fmt = None
        else:
            if func is aten.convolution.default:
                conv_backend = torch._C._select_conv_backend(**kwargs)
            else:
                conv_backend = torch._C._select_conv_backend(
                    kwargs["input"],
                    kwargs["weight"],
                    bias=None,
                    stride=kwargs["stride"],
                    padding=kwargs["padding"],
                    dilation=kwargs["dilation"],
                    transposed=kwargs["transposed"],
                    output_padding=kwargs["output_padding"],
                    groups=kwargs["groups"],
                    bias_sizes=kwargs["bias_sizes"],
                )
            mem_fmt = torch._C._conv_determine_backend_memory_format(
                kwargs["input"], kwargs["weight"], conv_backend
            )

    def convert(t, mem_fmt):
        if t is None:
            return t
        if mem_fmt is not None:
            t = t.to(memory_format=mem_fmt)
        return FakeTensor(fake_mode, t, device)

    with in_kernel_invocation_manager(fake_mode):
        out = func(**kwargs)

        if func is aten.convolution.default:
            return convert(out, mem_fmt)
        else:
            return (
                convert(out[0], mem_fmt),
                convert(out[1], mem_fmt),
                convert(out[2], None),
            )

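# Note on the above: memory format matters for convolution outputs, so rather
# than guessing, the fake impl asks the real backend-selection logic
# (torch._C._select_conv_backend / _conv_determine_backend_memory_format)
# which memory format the chosen backend would produce, then materializes the
# fake outputs with that layout. The bias gradient (out[2]) keeps the default
# layout.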

@register_op_impl(torch.ops.aten._pack_padded_sequence.default)
def _pack_padded_sequence(fake_mode, func, inputs, lengths, batch_first):
    if (
        fake_mode.shape_env is None
        or not fake_mode.shape_env.allow_dynamic_output_shape_ops
    ):
        # Without symints/symfloats, cannot handle this
        raise DynamicOutputShapeException(func)

    new_batch_size = fake_mode.shape_env.create_unbacked_symint()

    from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size

    _constrain_range_for_size(new_batch_size)

    if not batch_first:
        # Inputs should have shape (batch_size, seq_len, *)
        inputs = inputs.transpose(0, 1)

    res_size = inputs.shape[1:]
    packed_data = inputs.new_empty(res_size)
    batch_size = inputs.new_empty((new_batch_size,))
    return (packed_data, batch_size)


FAST_OP_IMPLEMENTATIONS = {}


# Unlike register_op_impl, these don't do the slow iteration for
# run_impl_check, and these run BEFORE decompositions
def register_fast_op_impl(func: OpOverload):
    def impl_decorator(op_impl):
        FAST_OP_IMPLEMENTATIONS[func] = op_impl
        return op_impl

    return impl_decorator


# infer_size_impl in ExpandUtils
def infer_size(a, b):
    from torch.fx.experimental.symbolic_shapes import guard_size_oblivious

    dimsA = len(a)
    dimsB = len(b)
    ndim = max(dimsA, dimsB)
    expandedSizes = [0] * ndim
    for i in range(ndim - 1, -1, -1):
        offset = ndim - 1 - i
        dimA = dimsA - 1 - offset
        dimB = dimsB - 1 - offset
        sizeA = a[dimA] if dimA >= 0 else 1
        sizeB = b[dimB] if dimB >= 0 else 1

        # NB: It is very important to test for broadcasting, before testing
        # sizeA == sizeB.  This is because the broadcasting tests are likely
        # to be statically known (in particular, if sizeA/sizeB is unbacked
        # but size-like, we will unsoundly assume they never equal 1), but
        # the sizeA == sizeB test may not be statically known.  However, once
        # we have established that no broadcasting is happening, the
        # sizeA == sizeB is now expect_true and we can defer it as a runtime
        # assert (this works because Python will return the terminal
        # expression of an or statement as-is, without bool()'ing it; if this
        # were not the case, we'd need to write this using torch.sym_or() or
        # something like that).
        torch._check(
            guard_size_oblivious(sizeA == 1)
            or guard_size_oblivious(sizeB == 1)
            or sizeA == sizeB,
            lambda: f"The size of tensor a ({sizeA}) "
            f"must match the size of tensor b ({sizeB}) "
            f"at non-singleton dimension {i})",
        )
        expandedSizes[i] = sizeB if guard_size_oblivious(sizeA == 1) else sizeA
    return tuple(expandedSizes)
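
# Hand-worked example of the broadcasting rule above (standard PyTorch/NumPy
# semantics): infer_size((5, 1, 3), (4, 1)) aligns trailing dimensions and
# returns (5, 4, 3); infer_size((2, 3), (4, 3)) raises, since 2 != 4 and
# neither is 1.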


def make_fast_binary_impl(slow_ref):
    def fast_binary_impl(mode, *args, **kwargs):
        def slow(msg):
            count_label(f"slow {msg}")
            with mode:
                return slow_ref(*args, **kwargs)

        count_label("attempt fast")

        # Fast path (based off of TensorIterator fast path).
        # Unfortunately, there is no way to easily deduplicate
        # this with either the TensorIterator C++ implementation
        # (which we don't want to SymIntify, and also the algorithm
        # here is slightly different from TensorIterator to allow
        # for broadcasting), nor the PrimTorch implementation
        # (which does not actually implement a fast path.)

        operands = args

        # compute_shape
        has_scalars = False
        has_tensors = False
        final_shape = None
        for op in operands:
            shape = op.shape if isinstance(op, torch.Tensor) else ()
            if len(shape) == 0:
                has_scalars = True
            else:
                has_tensors = True
            if final_shape is None:
                final_shape = shape
            # TODO: Minor optimization: track if the shapes
            # were equal so you can skip the equality check
            # below if unnecessary
            final_shape = infer_size(final_shape, shape)
        assert final_shape is not None

        from torch.fx.experimental.symbolic_shapes import guard_size_oblivious, sym_eq

        # Do some extra safety checks to see if the output
        # stride is obvious
        for op in operands:
            if (
                isinstance(op, torch.Tensor)
                and len(op.shape) == len(final_shape)
                and guard_size_oblivious(sym_eq(op.shape, final_shape))
            ):
                break
        else:
            return slow("both tensors nontrivially broadcast")

        # compute_types
        cpu = torch.device("cpu")
        common_device = cpu
        common_dtype = None
        output_dtype = None
        has_different_input_dtypes = False
        for op in operands:
            if not isinstance(op, torch.Tensor):
                # Use elementwise_dtypes for the tricky case
                has_different_input_dtypes = True
                continue
            if common_device == cpu and not op.device.type == "cpu":
                common_device = op.device
            # Slightly simplified here as target_dtype cannot vary
            if common_dtype is None:
                common_dtype = op.dtype
            elif common_dtype != op.dtype:
                has_different_input_dtypes = True

        if has_different_input_dtypes:
            # compute promotion
            # TODO: we don't need the compute type
            _, common_dtype = elementwise_dtypes(
                *operands, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
            )

        # check that all tensors are on the same device
        # (cpu scalars are assumed to be allowed)
        current_cpu_scalars_on_non_cpu = 0
        max_cpu_scalars_on_non_cpu = 1  # hard coded atm
        for op in operands:
            if not isinstance(op, torch.Tensor):
                continue
            if common_device != cpu and op.dim() == 0 and op.device == cpu:
                if current_cpu_scalars_on_non_cpu >= max_cpu_scalars_on_non_cpu:
                    return slow("error")
                current_cpu_scalars_on_non_cpu += 1
            elif op.device != common_device:
                return slow("error")

        # compute_fast_setup_type
        is_contiguous = True
        is_channels_last = True
        # TODO: is_non-overlapping_and_dense (not bound from Python)
        # no inplace, no out, everything defined

        if is_noncontiguous_supported(common_device):
            for op in operands:
                if not isinstance(op, torch.Tensor):
                    continue
                is_contiguous = is_contiguous and op.is_contiguous(
                    memory_format=torch.contiguous_format
                )
                is_channels_last = is_channels_last and op.is_contiguous(
                    memory_format=torch.channels_last
                )
        if is_contiguous:
            # do contiguous
            count_label("fast is_contiguous")
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.contiguous_format,
                ),
                device=common_device,
            )
        if is_channels_last:
            count_label("fast channels_last")
            # do channels last
            return FakeTensor(
                mode,
                torch.empty(
                    final_shape,
                    dtype=common_dtype,
                    device="meta",
                    memory_format=torch.channels_last,
                ),
                device=common_device,
            )

        return slow("no contiguity match")

    return fast_binary_impl
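
# In short, the fast path only fires when the output shape is directly
# inherited from one operand (no nontrivial two-sided broadcast), all tensors
# share a device (modulo at most one cpu scalar), and the layout is obvious
# (all contiguous or all channels_last); otherwise it falls back to the
# PrimTorch reference with a counter noting why.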


# disable the python dispatcher to avoid decomposing detach() further
# (proxy_mode should still decompose detach() though)
def fast_detach(fake_mode, x):
    with no_python_dispatcher(), in_kernel_invocation_manager(fake_mode):
        out = torch.ops.aten.detach.default(x)
    return FakeTensor(fake_mode, out, x.device)


@functools.lru_cache(None)
def get_fast_op_impls():
    import torch._refs

    register_fast_op_impl(torch.ops.aten.add.Tensor)(
        make_fast_binary_impl(torch._refs.add)
    )
    register_fast_op_impl(torch.ops.aten.sub.Tensor)(
        make_fast_binary_impl(torch._refs.sub)
    )
    register_fast_op_impl(torch.ops.aten.mul.Tensor)(make_fast_binary_impl(torch._refs.mul))  # type: ignore[has-type]
    register_fast_op_impl(torch.ops.aten.div.Tensor)(
        make_fast_binary_impl(torch._refs.div)
    )
    register_fast_op_impl(torch.ops.aten.detach.default)(fast_detach)
    return FAST_OP_IMPLEMENTATIONS

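# Closing note: get_fast_op_impls() lazily builds and caches the fast table
# for add/sub/mul/div (via make_fast_binary_impl over the torch._refs
# references) and detach; per the register_fast_op_impl comment above, these
# are intended to be consulted before decompositions, unlike the
# register_op_impl handlers registered throughout this file.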