#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest

import torch
from executorch.exir import EdgeCompileConfig, to_edge
from executorch.exir.dialects._ops import ops
from executorch.exir.dialects.edge._ops import EdgeOpOverload
from executorch.exir.verification.arg_validator import EdgeOpArgValidator
from torch.export import export


class TestArgValidator(unittest.TestCase):
    """Tests for EdgeOpArgValidator.

    The validator walks an edge-dialect graph module and records, per op,
    any argument/return dtypes that fall outside the op's supported set
    (exposed via its ``violating_ops`` mapping).
    """

    # NOTE: the previous no-op setUp() (which only called super().setUp())
    # has been removed; unittest's default lifecycle is sufficient here.

    def test_edge_dialect_passes(self) -> None:
        """A graph whose dtypes are all supported yields no violations.

        int32 addition is within edge::aten::add's supported dtypes, so
        ``violating_ops`` must be empty after running the validator.
        """

        class TestModel(torch.nn.Module):
            # No state to initialize, so no __init__ override is needed.
            def forward(self, x):
                return x + x

        inputs = (torch.randn(1, 3, 100, 100).to(dtype=torch.int),)
        egm = (
            to_edge(export(TestModel(), inputs)).exported_program().graph_module
        )
        validator = EdgeOpArgValidator(egm)
        validator.run(*inputs)
        self.assertEqual(len(validator.violating_ops), 0)

    def test_edge_dialect_fails(self) -> None:
        """An unsupported dtype is reported against the offending op.

        torch.bfloat16 is not supported by edge::aten::_log_softmax, so the
        validator must flag exactly that op, attributing the bad dtype to
        both the ``self`` input and the first return value (``__ret_0``).
        """

        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.m = torch.nn.LogSoftmax(dim=1)

            def forward(self, x):
                return self.m(x)

        inputs = (torch.randn(1, 3, 100, 100).to(dtype=torch.bfloat16),)
        # _check_ir_validity=False disables to_edge's own IR verification so
        # the intentionally-invalid graph reaches the validator under test
        # instead of failing earlier during lowering.
        egm = (
            to_edge(
                export(M(), inputs),
                compile_config=EdgeCompileConfig(_check_ir_validity=False),
            )
            .exported_program()
            .graph_module
        )
        validator = EdgeOpArgValidator(egm)
        validator.run(*inputs)
        self.assertEqual(len(validator.violating_ops), 1)
        key: EdgeOpOverload = next(iter(validator.violating_ops))
        self.assertEqual(
            key.name(),
            ops.edge.aten._log_softmax.default.name(),
        )
        self.assertDictEqual(
            validator.violating_ops[key],
            {
                "self": torch.bfloat16,
                "__ret_0": torch.bfloat16,
            },
        )