# Owner(s): ["oncall: jit"]

"""
Unit Tests for Nnapi backend with delegate
Inherits most tests from TestNNAPI, which loads Android NNAPI models
without the delegate API.
"""

import os
import sys
import unittest
from pathlib import Path

import torch
import torch._C
from torch.testing._internal.common_utils import IS_FBCODE, skipIfTorchDynamo


# Make the helper files in test/ importable. This is done before the
# ``test_nnapi`` import below so that the helper module (expected in test/)
# can be found regardless of how this file was launched.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)

# hacky way to skip these tests in fbcode:
# during test execution in fbcode, test_nnapi is available during test discovery,
# but not during test execution. So we can't try-catch here, otherwise it'll think
# it sees tests but then fails when it tries to actually run them.
if not IS_FBCODE:
    from test_nnapi import TestNNAPI

    HAS_TEST_NNAPI = True
else:
    from torch.testing._internal.common_utils import TestCase as TestNNAPI

    HAS_TEST_NNAPI = False

if __name__ == "__main__":
    raise RuntimeError(
        "This test file is not meant to be run directly, use:\n\n"
        "\tpython test/test_jit.py TESTNAME\n\n"
        "instead."
    )

# Location of the NNAPI delegate shared library. The test class below is
# skipped when this file does not exist; per the original note, this is what
# skips the tests on platforms such as Windows and macOS (presumably because
# the library is not built there).
# Three directory levels above this file; presumably the repository root,
# since this file lives in a subdirectory of test/.
torch_root = Path(__file__).resolve().parent.parent.parent
lib_path = torch_root / "build" / "lib" / "libnnapi_backend.so"


@skipIfTorchDynamo("weird py38 failures")
@unittest.skipIf(
    not os.path.exists(lib_path),
    "Skipping the test as libnnapi_backend.so was not found",
)
@unittest.skipIf(IS_FBCODE, "test_nnapi.py not found")
class TestNnapiBackend(TestNNAPI):
    """Re-run the TestNNAPI suite through the NNAPI backend delegate.

    Most tests are inherited from TestNNAPI. This class overrides its
    ``call_lowering_to_nnapi`` hook so that models are lowered with
    ``torch._C._jit_to_backend("nnapi", ...)``. It also adds
    delegate-specific tests for compile-spec handling.
    """

    def setUp(self):
        super().setUp()

        # Save the current default dtype so it can be restored afterwards.
        self.default_dtype = torch.get_default_dtype()
        # addCleanup callbacks run even when the rest of setUp raises (e.g. if
        # the delegate library fails to load), whereas tearDown does not.
        self.addCleanup(torch.set_default_dtype, self.default_dtype)
        # Change dtype to float32 (since a different unit test changed dtype to
        # float64, which is not supported by the Android NNAPI delegate).
        # Float32 should typically be the default in other files.
        torch.set_default_dtype(torch.float32)

        # Load nnapi delegate library
        torch.ops.load_library(str(lib_path))

    # Override
    def call_lowering_to_nnapi(self, traced_module, args):
        """Lower ``traced_module`` to the "nnapi" backend via the delegate API.

        ``args`` (a Tensor or a list of Tensors) is passed as the ``inputs``
        entry of the ``forward`` method's compile spec.
        """
        compile_spec = {"forward": {"inputs": args}}
        return torch._C._jit_to_backend("nnapi", traced_module, compile_spec)

    def _assert_compile_spec_error(self, traced_module, compile_spec, expected_msg):
        """Assert that lowering with ``compile_spec`` raises a RuntimeError
        whose message contains ``expected_msg`` verbatim.

        A literal substring check is used instead of ``assertRaisesRegex``
        because the expected messages contain regex metacharacters
        (``.``, ``{``, ``}``) that would otherwise be treated as a pattern.
        """
        with self.assertRaises(RuntimeError) as ctx:
            torch._C._jit_to_backend("nnapi", traced_module, compile_spec)
        self.assertIn(expected_msg, str(ctx.exception))

    def test_tensor_input(self):
        # Lower a simple module
        args = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        module = torch.nn.PReLU()
        traced = torch.jit.trace(module, args)

        # Argument input is a single Tensor
        self.call_lowering_to_nnapi(traced, args)
        # Argument input is a Tensor in a list
        self.call_lowering_to_nnapi(traced, [args])

    # Test exceptions for incorrect compile specs.
    # NOTE: "santiy" is a misspelling of "sanity"; the name is kept unchanged so
    # existing test selectors keep matching it.
    def test_compile_spec_santiy(self):
        args = torch.tensor([[1.0, -1.0, 2.0, -2.0]]).unsqueeze(-1).unsqueeze(-1)
        module = torch.nn.PReLU()
        traced = torch.jit.trace(module, args)

        # Common tail expected verbatim (including the leading newline) at the
        # end of every compile-spec validation error checked below.
        error_msg_tail = r"""
method_compile_spec should contain a Tensor or Tensor List which bundles input parameters: shape, dtype, quantization, and dimorder.
For input shapes, use 0 for run/load time flexible input.
method_compile_spec must use the following format:
{"forward": {"inputs": at::Tensor}} OR {"forward": {"inputs": c10::List<at::Tensor>}}"""

        no_inputs_msg = (
            'method_compile_spec does not contain a dictionary with an "inputs" key, '
            'under it\'s "forward" key.'
        )
        no_tensor_msg = (
            "method_compile_spec does not contain either a Tensor or TensorList, "
            'under it\'s "inputs" key.'
        )
        cases = [
            # No forward key
            (
                {"backward": {"inputs": args}},
                'method_compile_spec does not contain the "forward" key.',
            ),
            # No dictionary under the forward key
            ({"forward": 1}, no_inputs_msg),
            # No inputs key (in the dictionary under the forward key)
            ({"forward": {"not inputs": args}}, no_inputs_msg),
            # No Tensor or TensorList under the inputs key
            ({"forward": {"inputs": 1}}, no_tensor_msg),
            ({"forward": {"inputs": [1]}}, no_tensor_msg),
        ]
        for compile_spec, msg in cases:
            self._assert_compile_spec_error(traced, compile_spec, msg + error_msg_tail)

    def tearDown(self):
        # Change dtype back to default (Otherwise, other unit tests will
        # complain). The cleanup registered in setUp repeats this harmlessly.
        torch.set_default_dtype(self.default_dtype)
        # Chain to the base classes so their teardown logic is not skipped.
        super().tearDown()