xref: /aosp_15_r20/external/pytorch/tools/test/gen_oplist_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Copyright 2004-present Facebook. All Rights Reserved.
3
4import unittest
5from unittest.mock import MagicMock
6
7from tools.code_analyzer.gen_oplist import throw_if_any_op_includes_overloads
8
9
10class GenOplistTest(unittest.TestCase):
11    def setUp(self) -> None:
12        pass
13
14    def test_throw_if_any_op_includes_overloads(self) -> None:
15        selective_builder = MagicMock()
16        selective_builder.operators = MagicMock()
17        selective_builder.operators.items.return_value = [
18            ("op1", MagicMock(include_all_overloads=True)),
19            ("op2", MagicMock(include_all_overloads=False)),
20            ("op3", MagicMock(include_all_overloads=True)),
21        ]
22
23        self.assertRaises(
24            Exception, throw_if_any_op_includes_overloads, selective_builder
25        )
26
27        selective_builder.operators.items.return_value = [
28            ("op1", MagicMock(include_all_overloads=False)),
29            ("op2", MagicMock(include_all_overloads=False)),
30            ("op3", MagicMock(include_all_overloads=False)),
31        ]
32
33        # Here we do not expect it to throw an exception since none of the ops
34        # include all overloads.
35        throw_if_any_op_includes_overloads(selective_builder)
36