xref: /aosp_15_r20/external/pytorch/test/jit/test_backend_nnapi.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1# Owner(s): ["oncall: jit"]
2
3import os
4import sys
5import unittest
6from pathlib import Path
7
8import torch
9import torch._C
10from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo
11
12
# Hacky way to skip these tests in fbcode: during test execution in fbcode,
# test_nnapi is available during test discovery but not during test execution.
# So we can't wrap the import in try/except here; otherwise the runner would
# see the tests at discovery time and then fail when it actually runs them.
HAS_TEST_NNAPI = not IS_FBCODE
if HAS_TEST_NNAPI:
    from test_nnapi import TestNNAPI
else:
    # Stand-in base class so the (skipped) subclass below can still be defined.
    from torch.testing._internal.common_utils import TestCase as TestNNAPI
25
26
# Make the helper files in test/ importable: test/ is the parent of the
# directory holding this file (after resolving symlinks).
pytorch_test_dir = str(Path(os.path.realpath(__file__)).parent.parent)
sys.path.append(pytorch_test_dir)
30
if __name__ == "__main__":
    # This file only defines test cases; they are meant to be run via test_jit.py.
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

"""
Unit tests for the NNAPI backend with the delegate API.
Inherits most tests from TestNNAPI, which loads Android NNAPI models
without the delegate API.
"""
# Root of the PyTorch checkout: three directories above this (resolved) file.
torch_root = Path(__file__).resolve().parent.parent.parent
# NNAPI delegate library from the build tree. When it is absent (e.g. on
# Windows or macOS), the skipIf on TestNnapiBackend skips the whole class.
lib_path = torch_root.joinpath("build", "lib", "libnnapi_backend.so")
46
47
@skipIfTorchDynamo("weird py38 failures")
@unittest.skipIf(
    not os.path.exists(lib_path),
    "Skipping the test as libnnapi_backend.so was not found",
)
@unittest.skipIf(IS_FBCODE, "test_nnapi.py not found")
class TestNnapiBackend(TestNNAPI):
    """TestNNAPI suite run through the NNAPI delegate backend.

    Inherits the TestNNAPI tests and overrides ``call_lowering_to_nnapi`` so
    that lowering goes through ``torch._C._jit_to_backend("nnapi", ...)``.
    Adds delegate-specific tests for the accepted input formats and for the
    errors raised on malformed compile specs.
    """

    def setUp(self):
        super().setUp()

        # Load the nnapi delegate library before touching any global state:
        # if loading raises, unittest does not call tearDown, so anything
        # changed below would never be restored.
        torch.ops.load_library(str(lib_path))

        # Save the current default dtype so tearDown can restore it.
        self.default_dtype = torch.get_default_dtype()
        # Change dtype to float32 (a different unit test may have changed it
        # to float64, which is not supported by the Android NNAPI delegate).
        # Float32 should typically be the default in other files.
        torch.set_default_dtype(torch.float32)

    # Override
    def call_lowering_to_nnapi(self, traced_module, args):
        """Lower ``traced_module`` to the "nnapi" backend via the delegate API.

        Args:
            traced_module: a traced TorchScript module.
            args: a Tensor or a list of Tensors, placed under the
                ``"inputs"`` key of the ``"forward"`` method compile spec.

        Returns:
            The module returned by ``torch._C._jit_to_backend``.
        """
        compile_spec = {"forward": {"inputs": args}}
        return torch._C._jit_to_backend("nnapi", traced_module, compile_spec)

    def test_tensor_input(self):
        """Lowering accepts the inputs as a single Tensor or as a Tensor list."""
        # Lower a simple module; args has shape (1, 4, 1, 1).
        args = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        module = torch.nn.PReLU()
        traced = torch.jit.trace(module, args)

        # Argument input is a single Tensor
        self.call_lowering_to_nnapi(traced, args)
        # Argument input is a Tensor in a list
        self.call_lowering_to_nnapi(traced, [args])

    # NOTE: "santiy" is a misspelling of "sanity"; the name is kept because
    # test selectors or skip lists may reference it.
    def test_compile_spec_santiy(self):
        """Each malformed compile spec raises RuntimeError with a specific message."""
        args = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        module = torch.nn.PReLU()
        traced = torch.jit.trace(module, args)

        # Common suffix of every compile-spec error message (used as a regex).
        error_msg_tail = r"""
method_compile_spec should contain a Tensor or Tensor List which bundles input parameters: shape, dtype, quantization, and dimorder.
For input shapes, use 0 for run/load time flexible input.
method_compile_spec must use the following format:
{"forward": {"inputs": at::Tensor}} OR {"forward": {"inputs": c10::List<at::Tensor>}}"""

        no_forward_key = 'method_compile_spec does not contain the "forward" key.'
        no_inputs_key = (
            'method_compile_spec does not contain a dictionary with an "inputs" key, '
            'under it\'s "forward" key.'
        )
        no_tensor_inputs = 'method_compile_spec does not contain either a Tensor or TensorList, under it\'s "inputs" key.'

        # (description, malformed compile spec, expected message head)
        cases = [
            ("no forward key", {"backward": {"inputs": args}}, no_forward_key),
            ("no dictionary under the forward key", {"forward": 1}, no_inputs_key),
            (
                "no inputs key in the dictionary under the forward key",
                {"forward": {"not inputs": args}},
                no_inputs_key,
            ),
            ("inputs is not a Tensor", {"forward": {"inputs": 1}}, no_tensor_inputs),
            ("inputs is not a TensorList", {"forward": {"inputs": [1]}}, no_tensor_inputs),
        ]
        for description, compile_spec, message in cases:
            # subTest reports every failing case instead of stopping at the first.
            with self.subTest(description):
                with self.assertRaisesRegex(RuntimeError, message + error_msg_tail):
                    torch._C._jit_to_backend("nnapi", traced, compile_spec)

    def tearDown(self):
        # Change dtype back to the one saved in setUp (otherwise, other unit
        # tests will complain), then let the base class run its own teardown.
        torch.set_default_dtype(self.default_dtype)
        super().tearDown()
142