# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging

import torch
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.dim_order_utils import get_dim_order, get_memory_format
from executorch.exir.pass_base import ExportPass, ProxyValue
from executorch.exir.passes.dim_order_ops_registry import (
    DimOrderOpsMap,
    MemoryFormatOpsMap,
)

logger = logging.getLogger(__file__)
logger.setLevel(logging.INFO)

# TODO - these passes are too specialized on a single to_copy op.
# We should be able to replace (or revert) any of the dim_order ops in the future.


def _get_ndim(arg) -> int:
    """
    Return the rank of ``arg``, the first positional argument of the op being
    rewritten.

    ``arg`` may be a ``ProxyValue`` wrapping a tensor or a plain
    ``torch.Tensor``. Rank is assumed to be specialized (static).

    Raises:
        AssertionError: if ``arg`` is neither a tensor-holding ``ProxyValue``
            nor a ``torch.Tensor``. An explicit ``raise`` is used rather than
            ``assert`` so the check is not stripped under ``python -O`` (which
            would otherwise surface later as an unbound-variable error); the
            exception type is kept as ``AssertionError`` for compatibility.
    """
    if isinstance(arg, ProxyValue) and arg.is_tensor():
        return arg.to_tensor().dim()
    if isinstance(arg, torch.Tensor):
        return arg.dim()
    raise AssertionError(f"Expecting a Tensor or a ProxyValue but got {type(arg)}")


class MemoryFormatOpsPass(ExportPass):
    """
    This pass replaces ops which take torch.memory_format as an argument with
    'equivalent' ops which take dim_order. This is towards the larger ExecuTorch
    goal to move away from torch.memory_format. There is a 1:1 mapping between
    the aten op and the new edge dialect dim_order op.
    """

    def call_operator(self, op, args, kwargs, meta):
        """
        Rewrite an edge op listed in ``DimOrderOpsMap`` into its dim_order
        counterpart; any other op is forwarded unchanged.

        The ``memory_format`` kwarg is removed (``torch.contiguous_format`` is
        assumed when it is absent) and replaced by a ``dim_order`` kwarg
        computed by ``get_dim_order`` from that format and the rank of
        ``args[0]``. All other args/kwargs and ``meta`` are passed through.
        """
        if not (isinstance(op, EdgeOpOverload) and op.__name__ in DimOrderOpsMap):
            return super().call_operator(op, args, kwargs, meta)

        # New kwargs with dim_order, and no memory_format, for the new op.
        nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable

        # The "to" memory format requested by the original EdgeOp.
        mem_format = nkwargs.pop("memory_format", torch.contiguous_format)

        # Can always get the shape, assuming rank is specialized.
        ndim = _get_ndim(args[0])

        nkwargs["dim_order"] = get_dim_order(mem_format, ndim)
        logger.debug(
            "_to_copy = rank: %s, memory_format: %s."
            " _to_dim_order_copy = dim_order: %s",
            ndim,
            mem_format,
            nkwargs["dim_order"],
        )

        # Membership was checked above, so this lookup cannot miss.
        target_op = DimOrderOpsMap[op.__name__]
        return super().call_operator(target_op, args, nkwargs, meta)


class DimOrderOpsRevertPass(ExportPass):
    """
    This pass is to revert the dim_order ops back to the memory format ops.
    """

    def call_operator(self, op, args, kwargs, meta):
        """
        Rewrite an edge op listed in ``MemoryFormatOpsMap`` back into its
        memory_format counterpart; any other op is forwarded unchanged.

        The ``dim_order`` kwarg is removed (the identity order
        ``[0, 1, ..., ndim - 1]`` is assumed when it is absent) and replaced by
        a ``memory_format`` kwarg computed by ``get_memory_format``. All other
        args/kwargs and ``meta`` are passed through.
        """
        if not (isinstance(op, EdgeOpOverload) and op.__name__ in MemoryFormatOpsMap):
            return super().call_operator(op, args, kwargs, meta)

        # New kwargs with memory_format, and no dim_order, for the new op.
        nkwargs = dict(copy.deepcopy(kwargs))  # orig kwargs are immutable

        # Can always get the shape, assuming rank is specialized.
        ndim = _get_ndim(args[0])

        # The "to" dim_order requested by the original EdgeOp.
        default_dim_order = list(range(ndim))
        dim_order = nkwargs.pop("dim_order", default_dim_order)

        nkwargs["memory_format"] = get_memory_format(dim_order)

        logger.debug(
            "_to_dim_order_copy = dim_order: %s."
            " _to_copy = rank: %s, memory_format: %s.",
            dim_order,
            ndim,
            nkwargs["memory_format"],
        )

        # Membership was checked above, so this lookup cannot miss.
        target_op = MemoryFormatOpsMap[op.__name__]
        return super().call_operator(target_op, args, nkwargs, meta)