# xref: /aosp_15_r20/external/executorch/exir/passes/dim_order_ops_registry.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright (c) Meta Platforms, Inc. and affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import torch
8from executorch.exir.dialects._ops import ops as exir_ops
9from executorch.exir.dim_order_utils import get_memory_format
10
11from torch.library import impl, Library
12
# Register the custom "dim_order_ops" namespace and declare the schemas for
# the dim-order variants of aten._to_copy.
lib = Library("dim_order_ops", "DEF")

# Functional variant: mirrors aten._to_copy but takes an explicit dim_order
# instead of a memory_format.
_TO_DIM_ORDER_COPY_SCHEMA = "_to_dim_order_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, int[]? dim_order=None) -> Tensor"

# Out variant drops TensorOptions (dtype/layout/device/pin_memory).
_TO_DIM_ORDER_COPY_OUT_SCHEMA = "_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"

lib.define(_TO_DIM_ORDER_COPY_SCHEMA)
lib.define(_TO_DIM_ORDER_COPY_OUT_SCHEMA)
22
23
def _op_impl(target, *args, **kwargs):
    """Invoke ``target`` after translating ``dim_order`` to ``memory_format``.

    The dim_order ops share the schema of their aten counterparts except that
    they take an optional ``dim_order`` keyword instead of ``memory_format``.
    This shim removes ``dim_order`` from ``kwargs``, converts it to the
    equivalent ``torch.memory_format`` via ``get_memory_format``, and forwards
    everything else to ``target`` unchanged.

    Args:
        target: the aten op (e.g. ``torch.ops.aten._to_copy``) to dispatch to.
        *args: positional arguments passed through to ``target``.
        **kwargs: keyword arguments; ``dim_order`` (if present) is consumed.

    Returns:
        Whatever ``target`` returns.
    """
    # A single pop() both reads and removes the key (the original code used a
    # get() followed by a redundant pop()). A missing key maps to None, which
    # get_memory_format is expected to treat as the default memory format.
    kwargs["memory_format"] = get_memory_format(kwargs.pop("dim_order", None))
    return target(*args, **kwargs)
30
31
@impl(lib, "_to_dim_order_copy", "CompositeImplicitAutograd")
def _to_dim_order_copy_impl(*args, **kwargs):
    """Functional dim_order copy: lower to aten._to_copy via _op_impl."""
    aten_target = torch.ops.aten._to_copy
    return _op_impl(aten_target, *args, **kwargs)
35
36
@impl(lib, "_to_dim_order_copy.out", "CompositeImplicitAutograd")
def _to_dim_order_copy_out_impl(*args, **kwargs):
    """Out-variant dim_order copy: lower to aten._to_copy.out via _op_impl."""
    aten_target = torch.ops.aten._to_copy.out
    return _op_impl(aten_target, *args, **kwargs)
40
41
42"""
43Defines a map of aten or edge ops to the corresponding dim_order ops for quick lookup
44"""
45DimOrderOpsMap = {
46    "aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
47}
48
49"""
50Defines a map of aten or edge ops to the corresponding memory format ops for quick lookup
51"""
52MemoryFormatOpsMap = {
53    "dim_order_ops._to_dim_order_copy.default": exir_ops.edge.aten._to_copy.default,
54}
55
# If we are replacing an aten op with a dim_order op (or the reverse), the two
# lookup tables above must describe a 1:1 mapping: same number of entries, and
# no two source ops sharing a target op on either side.
assert len(DimOrderOpsMap) == len(MemoryFormatOpsMap), (
    "DimOrderOpsMap and MemoryFormatOpsMap must have the same number of entries"
)
# Stricter 1:1 check (previously a TODO): every source op must map to a
# distinct target op. Distinctness is checked by object identity, since the
# edge-dialect op objects are resolved once per op.
assert len({id(op) for op in DimOrderOpsMap.values()}) == len(DimOrderOpsMap), (
    "DimOrderOpsMap maps two aten ops to the same dim_order op"
)
assert len({id(op) for op in MemoryFormatOpsMap.values()}) == len(MemoryFormatOpsMap), (
    "MemoryFormatOpsMap maps two dim_order ops to the same aten op"
)