xref: /aosp_15_r20/external/pytorch/tools/test/gen_operators_yaml_test.py (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1#!/usr/bin/env python3
2# Copyright 2004-present Facebook. All Rights Reserved.
3
4import argparse
5import json
6import unittest
7from collections import defaultdict
8from unittest.mock import Mock, patch
9
10from gen_operators_yaml import (
11    fill_output,
12    get_parser_options,
13    make_filter_from_options,
14    verify_all_specified_present,
15)
16
17
18def _mock_options():
19    options = argparse.Namespace()
20    options.root_ops = "aten::add,aten::cat"
21    options.training_root_ops = []
22    options.output_path = "/tmp"
23    options.dep_graph_yaml_path = "dummy_pytorch_op_deps.yaml"
24    options.model_name = "test_model"
25    options.model_versions = None
26    options.model_assets = None
27    options.model_backends = None
28    options.models_yaml_path = None
29    options.include_all_operators = False
30    options.rule_name = "test_rule"
31    options.not_include_all_overloads_static_root_ops = True
32    options.not_include_all_overloads_closure_ops = True
33
34    return options
35
36
37def _mock_load_op_dep_graph():
38    result = defaultdict(set)
39    result["aten::add"] = {"aten::add", "aten::as_strided_"}
40    result["aten::cat"] = {"aten::cat", "aten::as_strided_"}
41    return dict(result)
42
43
class GenOperatorsYAMLTest(unittest.TestCase):
    """Tests for ``gen_operators_yaml``: the model filter built by
    ``make_filter_from_options``, ``verify_all_specified_present``, and
    ``fill_output``."""

    def test_filter_creation(self) -> None:
        """A filter for model ``abc`` at versions 100/101 keeps two of four configs.

        Presumably the surviving entries are the ``abc`` v100 and v101 configs;
        the v102 entry and the ``abcd`` entry should be filtered out.
        """
        filter_func = make_filter_from_options(
            model_name="abc",
            model_versions=["100", "101"],
            model_assets=None,
            model_backends=None,
        )
        config = [
            {
                "model": {
                    "name": "abc",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 102,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
            {
                "model": {
                    "name": "abcd",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 101,
                    "asset": "asset-2",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
        ]

        filtered_configs = list(filter(filter_func, config))
        # assertEqual rather than a bare ``assert``: it is not stripped under
        # ``python -O`` and unittest reports it as a failure, not an error.
        self.assertEqual(
            len(filtered_configs),
            2,
            f"Expected 2 elements in filtered_configs, but got {len(filtered_configs)}",
        )

    def test_verification_success(self) -> None:
        """Verification passes when every requested asset and version is present."""
        filter_func = make_filter_from_options(
            model_name="abc",
            model_versions=["100", "101"],
            model_assets=["asset-1", "asset-2"],
            model_backends=None,
        )
        config = [
            {
                "model": {
                    "name": "abc",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 101,
                    "asset": "asset-2",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
        ]
        filtered_configs = list(filter(filter_func, config))
        try:
            verify_all_specified_present(
                model_assets=["asset-1", "asset-2"],
                model_versions=["100", "101"],
                selected_models_yaml=filtered_configs,
                rule_name="test",
                model_name="abc",
                new_style_rule=True,
            )
        except Exception as e:
            # Report as a test failure, keeping the underlying error visible
            # in the message so the cause is not lost.
            self.fail(
                "expected verify_all_specified_present to succeed instead it "
                f"raised an exception: {e!r}"
            )

    def test_verification_fail(self) -> None:
        """Verification raises RuntimeError when a requested asset, version, or
        model name has no matching config."""
        config = [
            {
                "model": {
                    "name": "abc",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 101,
                    "asset": "asset-2",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
        ]

        good_assets = ["asset-1", "asset-2"]
        good_versions = ["100", "101"]
        good_name = "abc"

        # (label, model_name, model_versions, model_assets). Each case breaks
        # exactly one dimension by requesting something absent from ``config``.
        cases = [
            ("bad asset", good_name, good_versions, ["asset-1", "asset-2", "asset-3"]),
            ("bad version", good_name, ["100", "101", "102"], good_assets),
            ("bad name", "abcd", good_versions, good_assets),
        ]
        for label, model_name, model_versions, model_assets in cases:
            # subTest so every case runs and is reported independently.
            with self.subTest(label):
                filter_func = make_filter_from_options(
                    model_name=model_name,
                    model_versions=model_versions,
                    model_assets=model_assets,
                    model_backends=None,
                )
                filtered_configs = list(filter(filter_func, config))
                with self.assertRaises(RuntimeError):
                    verify_all_specified_present(
                        model_assets=model_assets,
                        model_versions=model_versions,
                        selected_models_yaml=filtered_configs,
                        rule_name="test",
                        model_name=model_name,
                        new_style_rule=True,
                    )

    @patch("gen_operators_yaml.parse_options", return_value=_mock_options())
    @patch(
        "gen_operators_yaml.load_op_dep_graph", return_value=_mock_load_op_dep_graph()
    )
    def test_fill_output_with_arguments_not_include_all_overloads(
        self, mock_load_op_dep_graph: Mock, mock_parse_options: Mock
    ) -> None:
        """With both ``not_include_all_overloads_*`` options set in
        ``_mock_options``, no operator in the filled output has
        ``include_all_overloads`` enabled.

        Stacked ``@patch`` decorators inject their mocks bottom-up: the
        decorator nearest the ``def`` (``load_op_dep_graph``) supplies the
        first mock argument, hence the parameter order above.
        """
        parser = argparse.ArgumentParser(description="Generate used operators YAML")
        # parse_options is patched, so this presumably yields _mock_options()
        # instead of reading the real command line.
        options = get_parser_options(parser)

        model_dict = {
            "model_name": options.model_name,
            "asset_info": {},
            "is_new_style_rule": False,
        }
        output = {"debug_info": [json.dumps(model_dict)]}

        fill_output(output, options)

        # NOTE(review): this loop passes vacuously if fill_output produces no
        # operators; consider also asserting that output["operators"] is non-empty.
        for op_name, op_val in output["operators"].items():
            with self.subTest(op=op_name):
                self.assertFalse(op_val["include_all_overloads"])
246