# Copyright 2024 Arm Limited and/or its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-unsafe

import torch
from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass

# Division overloads handled by DecomposeDivPass, grouped by dialect so that
# the replacement reciprocal/mul ops can be taken from the same dialect.
edge_div_ops = (exir_ops.edge.aten.div.Tensor,)
aten_div_ops = (torch.ops.aten.div.Tensor, torch.ops.aten.div_.Tensor)


def get_div_decomposition(op) -> tuple:
    """
    Return the ``(reciprocal_op, mul_op)`` pair used to decompose ``op``.

    The returned ops belong to the same dialect as ``op``:
    ``exir_ops.edge.aten`` ops for an edge-dialect div, or ``torch.ops.aten``
    ops for an aten-dialect div (including the in-place ``div_.Tensor``).

    Args:
        op: The division operator to decompose; expected to be one of
            ``edge_div_ops`` or ``aten_div_ops``.

    Returns:
        A tuple ``(reciprocal_op, mul_op)``.

    Raises:
        RuntimeError: If ``op`` is not a supported div operator.
    """
    if op in edge_div_ops:
        return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor)
    if op in aten_div_ops:
        return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor)
    raise RuntimeError(f"Can't get div decomposition for op {op}")


class DecomposeDivPass(ExportPass):
    """
    This pass decomposes div into a mul and a reciprocal node.

    Example:
        y = div(a,b)
    Becomes:
        x = reciprocal(b)
        y = mul(a,x)
    """

    def call_operator(self, op, args, kwargs, meta):
        """
        Rewrite a supported div op as ``mul(numerator, reciprocal(denominator))``.

        Any other operator is forwarded unchanged to ``ExportPass``.

        Args:
            op: The operator being visited.
            args: Positional arguments of the call; for a div op,
                ``args[0]`` is the numerator and ``args[1]`` the denominator.
            kwargs: Keyword arguments of the call. They are not forwarded to
                the replacement ops.
            meta: Node metadata. It is reused for both the reciprocal and the
                mul node.

        Returns:
            The result of the final ``mul`` call, or of the original call for
            non-div ops.
        """
        if op not in (edge_div_ops + aten_div_ops):
            return super().call_operator(op, args, kwargs, meta)

        reciprocal_op, mul_op = get_div_decomposition(op)

        # NOTE(review): the in-place aten ``div_.Tensor`` is also mapped to
        # out-of-place reciprocal/mul. This appears to rely on the graph
        # not depending on the in-place mutation; confirm for callers.
        numerator = args[0]
        denominator = args[1]
        reciprocal = super().call_operator(reciprocal_op, (denominator,), {}, meta)

        return super().call_operator(mul_op, (numerator, reciprocal), {}, meta)