#!/usr/bin/env fbpython
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# Source listing: /aosp_15_r20/external/executorch/exir/tests/test_arg_validator.py
# (revision 523fa7a60841cd1ecfb9cc4201f1ca8b03ed023a)

8import unittest
9
10import torch
11from executorch.exir import EdgeCompileConfig, to_edge
12from executorch.exir.dialects._ops import ops
13from executorch.exir.dialects.edge._ops import EdgeOpOverload
14from executorch.exir.verification.arg_validator import EdgeOpArgValidator
15from torch.export import export
16
17
class TestArgValidator(unittest.TestCase):
    """Tests for EdgeOpArgValidator.

    Each test lowers a small module to the Edge dialect with ``to_edge``, runs
    ``EdgeOpArgValidator`` over the resulting graph module with the same
    example inputs, and inspects ``validator.violating_ops``.
    """

    def test_edge_dialect_passes(self) -> None:
        """An int32 elementwise add produces no violating ops."""

        class TestModel(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return x + x

        m = TestModel()
        # Generate int32 values directly rather than truncating Gaussian noise
        # (randn().to(int)), so both the dtype and the value range are explicit.
        inputs = (torch.randint(-100, 100, (1, 3, 100, 100), dtype=torch.int),)
        egm = to_edge(export(m, inputs)).exported_program().graph_module
        validator = EdgeOpArgValidator(egm)
        validator.run(*inputs)
        # Include the violations in the failure message so a regression shows
        # which ops/dtypes were flagged, not just a count mismatch.
        self.assertEqual(
            len(validator.violating_ops),
            0,
            msg=f"unexpected violating ops: {validator.violating_ops}",
        )

    def test_edge_dialect_fails(self) -> None:
        """A bfloat16 log-softmax is reported as the single violating op."""

        # torch.bfloat16 is not supported by edge::aten::_log_softmax
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.m = torch.nn.LogSoftmax(dim=1)

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                return self.m(x)

        inputs = (torch.randn(1, 3, 100, 100).to(dtype=torch.bfloat16),)
        # IR validity checking is turned off -- presumably so to_edge itself
        # does not reject the unsupported dtype before the validator sees it.
        egm = (
            to_edge(
                export(M(), inputs),
                compile_config=EdgeCompileConfig(_check_ir_validity=False),
            )
            .exported_program()
            .graph_module
        )
        validator = EdgeOpArgValidator(egm)
        validator.run(*inputs)
        self.assertEqual(
            len(validator.violating_ops),
            1,
            msg=f"violating ops: {validator.violating_ops}",
        )
        key: EdgeOpOverload = next(iter(validator.violating_ops))
        self.assertEqual(
            key.name(),
            ops.edge.aten._log_softmax.default.name(),
        )
        # The offending input (schema name "self") and the op's first return
        # ("__ret_0") are each reported with the dtype that was rejected.
        self.assertDictEqual(
            validator.violating_ops[key],
            {
                "self": torch.bfloat16,
                "__ret_0": torch.bfloat16,
            },
        )
73