# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.dim_order_utils import get_memory_format

from torch.library import impl, Library

# Defines (mode "DEF") the new `dim_order_ops` operator namespace.
lib = Library("dim_order_ops", "DEF")

# Functional variant: same arguments as aten._to_copy, except that an explicit
# `dim_order` takes the place of `memory_format`.
lib.define(
    "_to_dim_order_copy(Tensor self, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, bool non_blocking=False, int[]? dim_order=None) -> Tensor"
)

# Out variant drops TensorOptions
lib.define(
    "_to_dim_order_copy.out(Tensor self, *, bool non_blocking=False, int[]? dim_order=None, Tensor(a!) out) -> Tensor(a!)"
)


def _op_impl(target, *args, **kwargs) -> torch.Tensor:
    """Forward a dim_order op call to its memory_format-based aten counterpart.

    The ``dim_order`` keyword argument (absent is treated as ``None``) is
    removed and translated into a ``memory_format`` keyword argument via
    ``get_memory_format``. All other arguments reach ``target`` unchanged.

    Args:
        target: The aten op overload to delegate to, e.g.
            ``torch.ops.aten._to_copy``.
        *args: Positional arguments forwarded to ``target``.
        **kwargs: Keyword arguments forwarded to ``target``. ``dim_order`` is
            consumed here and replaced by ``memory_format``.

    Returns:
        The result of calling ``target``.
    """
    # pop() reads and removes `dim_order` in one step. The aten target has no
    # such parameter, so it must not be forwarded.
    kwargs["memory_format"] = get_memory_format(kwargs.pop("dim_order", None))
    # NOTE(review): the result's dim order is not verified against the
    # requested `dim_order`.
    return target(*args, **kwargs)


@impl(lib, "_to_dim_order_copy", "CompositeImplicitAutograd")
def _to_dim_order_copy_impl(*args, **kwargs):
    """Implement ``_to_dim_order_copy`` by delegating to ``aten._to_copy``."""
    return _op_impl(torch.ops.aten._to_copy, *args, **kwargs)


@impl(lib, "_to_dim_order_copy.out", "CompositeImplicitAutograd")
def _to_dim_order_copy_out_impl(*args, **kwargs):
    """Implement ``_to_dim_order_copy.out`` by delegating to
    ``aten._to_copy.out``."""
    return _op_impl(torch.ops.aten._to_copy.out, *args, **kwargs)


# Maps aten op names (as strings) to the edge dim_order op that replaces them.
DimOrderOpsMap = {
    "aten._to_copy.default": exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
}

# Maps dim_order op names (as strings) to the edge memory_format (aten) op they
# stand in for, i.e. the reverse direction of DimOrderOpsMap.
MemoryFormatOpsMap = {
    "dim_order_ops._to_dim_order_copy.default": exir_ops.edge.aten._to_copy.default,
}

# If we are replacing an aten op with a dim_order op, we must have a 1:1 mapping through these dicts.
# Module-level sanity check that the two op maps stay the same size.
# It raises explicitly instead of using `assert`, so that it still runs under
# `python -O`. The AssertionError type of the original `assert` is kept.
if len(DimOrderOpsMap) != len(MemoryFormatOpsMap):
    raise AssertionError(
        "DimOrderOpsMap and MemoryFormatOpsMap must form a 1:1 mapping, but "
        f"they have {len(DimOrderOpsMap)} and {len(MemoryFormatOpsMap)} "
        "entries respectively"
    )

# TODO stricter check for 1:1 mapping