# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

"""Passes that remove no-op operators (and redundant dq->q pairs) from a graph."""

from typing import List, Tuple

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult
from torch.fx import GraphModule

# Dequantize ops in both the ATen and edge dialects. NOTE: the annotation is a
# variadic tuple (`..., ...]`); `Tuple[OpOverload]` alone would incorrectly
# declare a 1-element tuple.
_DEQUANT_OPS: Tuple[torch._ops.OpOverload, ...] = (
    torch.ops.quantized_decomposed.dequantize_per_tensor.default,
    torch.ops.quantized_decomposed.dequantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.dequantize_per_channel.default,
)
# Quantize ops in both the ATen and edge dialects (mirrors _DEQUANT_OPS).
_QUANT_OPS: Tuple[torch._ops.OpOverload, ...] = (
    torch.ops.quantized_decomposed.quantize_per_tensor.default,
    torch.ops.quantized_decomposed.quantize_per_channel.default,
    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
    exir_ops.edge.quantized_decomposed.quantize_per_channel.default,
)


def eliminate_dq_q(
    graph_module: GraphModule,
    dequant_nodes: List[torch.fx.Node],
) -> None:
    """Remove redundant dequantize->quantize pairs.

    For each dequant node in ``dequant_nodes``, any user that is a quantize op
    with identical qparams is bypassed: its users are rewired to the dequant
    node's (quantized) input. The now-dead nodes are left for a subsequent
    ``eliminate_dead_code`` call to clean up.

    Args:
        graph_module: The graph module being mutated (kept in the signature
            for interface compatibility; mutation happens through the nodes).
        dequant_nodes: Dequantize nodes whose consumers may be quantize ops.
    """
    for node in dequant_nodes:
        assert node.target in _DEQUANT_OPS
        # Snapshot users: replace_all_uses_with mutates the users mapping.
        for user in list(node.users):
            if user.target in _QUANT_OPS:
                # Drop the input arg and check that the qparams are the same.
                qparams_dq = list(node.args)[1:]
                qparams_q = list(user.args)[1:]
                if qparams_dq != qparams_q:
                    # Differing qparams means the dq->q pair is a real
                    # requantization, not a noop — keep it.
                    continue
                # Bypass both nodes: consumers of the quantize now read the
                # original quantized tensor directly.
                user.replace_all_uses_with(node.args[0])  # pyre-fixme[6]


class RemoveNoopPass(ExportPass):
    """
    Removes noops that pass through arguments.
    """

    def call(self, graph_module: GraphModule) -> PassResult:

        # In this list we'll collect all the dequant nodes that are inputs to ops that
        # are removed in this pass and later check for redundant dq->q patterns and
        # remove them.
        dequant_nodes = []

        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue

            if node.target not in (
                torch.ops.aten.to.dtype,
                torch.ops.aten.dropout.default,
                torch.ops.aten.slice_copy.Tensor,
            ):
                continue

            orig_tensor = node.args[0].meta["val"]

            # Identity check: the op produced the exact same fake tensor as its
            # input, so it is a pure pass-through.
            if orig_tensor is node.meta["val"]:
                # If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
                # Otherwise, removing only the op will suffice.
                if node.args[0].target in _DEQUANT_OPS:
                    dequant_nodes += [node.args[0]]
                node.replace_all_uses_with(node.args[0])
                continue

            if node.target == torch.ops.aten.slice_copy.Tensor:
                # Only do this check if all the dims are static.
                if all(isinstance(dim, int) for dim in orig_tensor.size()):
                    if orig_tensor.shape == node.meta["val"].shape:
                        # A slice that does not change the shape is a copy of
                        # the whole tensor, i.e. a noop.
                        # If the graph is quantized, we must remove the entire pattern consisting of dq->op->q.
                        # Otherwise, removing only the op will suffice.
                        if node.args[0].target in _DEQUANT_OPS:
                            dequant_nodes += [node.args[0]]
                        node.replace_all_uses_with(node.args[0])

        graph_module.graph.eliminate_dead_code()
        # Removing the noop may have exposed dq->q pairs; eliminate them, then
        # clean up again since the pairs leave dead nodes behind.
        eliminate_dq_q(graph_module, dequant_nodes)
        graph_module.graph.lint()
        graph_module.graph.eliminate_dead_code()

        return PassResult(graph_module, True)


class RemoveToCopyPass(ExportPass):
    """
    Removes _to_copy that pass through arguments.
    """

    def call(self, graph_module: GraphModule) -> PassResult:
        for node in graph_module.graph.nodes:
            if node.op != "call_function":
                continue

            if node.target not in (torch.ops.aten._to_copy.default,):
                continue

            orig_tensor = node.args[0].meta["val"]

            # A _to_copy that changes neither dtype, device, shape, nor
            # strides is a pure copy — bypass it.
            if (
                orig_tensor.dtype == node.meta["val"].dtype
                and orig_tensor.device == node.meta["val"].device
                and orig_tensor.shape == node.meta["val"].shape
                and orig_tensor.stride() == node.meta["val"].stride()
            ):
                node.replace_all_uses_with(node.args[0])

        graph_module.graph.eliminate_dead_code()
        graph_module.graph.lint()

        return PassResult(graph_module, True)