xref: /aosp_15_r20/external/executorch/backends/arm/_passes/decompose_div_pass.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
1# Copyright 2024 Arm Limited and/or its affiliates.
2# All rights reserved.
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7# pyre-unsafe
8
9import torch
10from executorch.exir.dialects._ops import ops as exir_ops
11from executorch.exir.pass_base import ExportPass
12
# Div overloads rewritten by DecomposeDivPass, grouped by dialect so that
# get_div_decomposition can pick replacement ops from the same dialect.

# Edge-dialect div (exir_ops.edge).
edge_div_ops = (exir_ops.edge.aten.div.Tensor,)

# torch.ops.aten div overloads: out-of-place and in-place variants.
aten_div_ops = (
    torch.ops.aten.div.Tensor,
    torch.ops.aten.div_.Tensor,
)
15
16
def get_div_decomposition(op) -> tuple:
    """
    Return the (reciprocal_op, mul_op) pair used to decompose a div op.

    The replacement ops come from the same dialect as ``op``: edge-dialect
    ops (``exir_ops.edge.aten``) when ``op`` is in ``edge_div_ops``, and
    ``torch.ops.aten`` ops when ``op`` is in ``aten_div_ops``.

    Args:
        op: The div operator being decomposed.

    Returns:
        A ``(reciprocal_op, mul_op)`` tuple.

    Raises:
        RuntimeError: If ``op`` is neither an edge nor an aten div op.
    """
    if op in edge_div_ops:
        return (exir_ops.edge.aten.reciprocal.default, exir_ops.edge.aten.mul.Tensor)
    if op in aten_div_ops:
        # NOTE(review): the in-place div_.Tensor is also mapped to the
        # out-of-place mul.Tensor; presumably fine because exported graphs
        # are functionalized — confirm.
        return (torch.ops.aten.reciprocal.default, torch.ops.aten.mul.Tensor)
    raise RuntimeError(f"Can't get div decomposition for op {op}")
27
28
class DecomposeDivPass(ExportPass):
    """
    Rewrite each supported div node as a reciprocal followed by a mul.

    Example:
        y = div(a, b)
    is emitted as:
        x = reciprocal(b)
        y = mul(a, x)

    Both edge-dialect and torch.ops.aten div ops are handled; the
    replacement ops are taken from the same dialect as the original op.
    Every other operator is forwarded unchanged.
    """

    def call_operator(self, op, args, kwargs, meta):
        """
        Decompose ``op`` when it is a supported div op, else forward it as-is.

        Args:
            op: Operator of the node being visited.
            args: Positional arguments; for a div op, args[0] is the
                numerator and args[1] the denominator.
            kwargs: Keyword arguments. Forwarded for non-div ops; the
                replacement reciprocal/mul calls receive empty kwargs.
            meta: Node metadata, reused for both replacement nodes.

        Returns:
            Whatever the base ExportPass.call_operator returns for the final
            emitted operator.
        """
        handled = op in edge_div_ops or op in aten_div_ops
        if handled:
            reciprocal_op, mul_op = get_div_decomposition(op)
            lhs, rhs = args[0], args[1]
            # NOTE(review): a * (1 / b) rounds twice, so floating-point
            # results may differ slightly from a / b.
            inv_rhs = super().call_operator(reciprocal_op, (rhs,), {}, meta)
            return super().call_operator(mul_op, (lhs, inv_rhs), {}, meta)
        return super().call_operator(op, args, kwargs, meta)
51