# xref: /aosp_15_r20/external/executorch/backends/xnnpack/test/ops/div.py (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Optional, Tuple

import torch
from executorch.backends.xnnpack.test.tester import Tester


class TestDiv(unittest.TestCase):
    """XNNPACK delegation tests for elementwise tensor division (``aten.div.Tensor``)."""

    class Div(torch.nn.Module):
        """Two-input division module: returns ``x / y``."""

        def forward(self, x, y):
            return x / y

    class DivSingleInput(torch.nn.Module):
        """Single-input division module: returns ``x / x``."""

        def forward(self, x):
            # Both operands of the div come from the same graph input.
            return x / x

    def _test_div(
        self,
        inputs: Tuple[torch.Tensor, ...],
        module: Optional[torch.nn.Module] = None,
    ) -> None:
        """Export, lower, serialize and run a division module, checking each stage.

        Args:
            inputs: Example inputs passed positionally to the module's ``forward``.
            module: Module under test. Defaults to a fresh ``Div()`` so existing
                two-input callers keep working unchanged.

        The pipeline asserts that the exported graph contains exactly one
        ``torch.ops.aten.div.Tensor``, that lowering yields exactly one
        ``executorch_call_delegate`` and no remaining edge-dialect div op (i.e.
        the div was delegated), and finally runs the serialized program via
        ``run_method_and_compare_outputs``.
        """
        if module is None:
            module = self.Div()
        (
            Tester(module, inputs)
            .export()
            .check_count({"torch.ops.aten.div.Tensor": 1})
            .to_edge_transform_and_lower()
            .check_count({"torch.ops.higher_order.executorch_call_delegate": 1})
            .check_not(["executorch_exir_dialects_edge__ops_aten_div_Tensor"])
            .to_executorch()
            .serialize()
            .run_method_and_compare_outputs()
        )

    def test_fp16_div(self):
        # torch.randn samples N(0, 1); adding 4 shifts the values ~4 standard
        # deviations away from 0 so the divisor is effectively never near zero.
        inputs = (
            (torch.randn(1) + 4).to(torch.float16),
            (torch.randn(1) + 4).to(torch.float16),
        )
        self._test_div(inputs)

    def test_fp32_div(self):
        # Shift N(0, 1) samples by 4 (~4 std devs) to keep the divisor away from 0.
        inputs = (torch.randn(1) + 4, torch.randn(1) + 4)
        self._test_div(inputs)

    def test_fp32_div_single_input(self):
        # Shift N(0, 1) samples by 4 (~4 std devs) to keep the divisor away from 0.
        inputs = (torch.randn(1) + 4,)
        self._test_div(inputs, module=self.DivSingleInput())