#!/usr/bin/env python3
# Copyright 2004-present Facebook. All Rights Reserved.
"""Unit tests for the helpers exposed by ``gen_operators_yaml``."""

import argparse
import json
import unittest
from collections import defaultdict
from unittest.mock import Mock, patch

from gen_operators_yaml import (
    fill_output,
    get_parser_options,
    make_filter_from_options,
    verify_all_specified_present,
)


def _mock_options():
    """Return an ``argparse.Namespace`` with canned option values.

    Used as the return value of the patched ``gen_operators_yaml.parse_options``
    in the ``fill_output`` test below. Both ``not_include_all_overloads_*``
    flags are set to ``True``.
    """
    options = argparse.Namespace()
    options.root_ops = "aten::add,aten::cat"
    options.training_root_ops = []
    options.output_path = "/tmp"
    options.dep_graph_yaml_path = "dummy_pytorch_op_deps.yaml"
    options.model_name = "test_model"
    options.model_versions = None
    options.model_assets = None
    options.model_backends = None
    options.models_yaml_path = None
    options.include_all_operators = False
    options.rule_name = "test_rule"
    options.not_include_all_overloads_static_root_ops = True
    options.not_include_all_overloads_closure_ops = True

    return options


def _mock_load_op_dep_graph():
    """Return a small canned operator dependency mapping.

    Maps each of ``aten::add`` and ``aten::cat`` to a set containing itself
    and ``aten::as_strided_``. The result is a plain ``dict`` (not a
    ``defaultdict``), so looking up an unknown key raises ``KeyError``.
    """
    result = defaultdict(set)
    result["aten::add"] = {"aten::add", "aten::as_strided_"}
    result["aten::cat"] = {"aten::cat", "aten::as_strided_"}
    return dict(result)


class GenOperatorsYAMLTest(unittest.TestCase):
    """Tests for model filtering, verification and output filling."""

    def test_filter_creation(self) -> None:
        """Filter built for name ``abc`` and versions 100/101 keeps two entries."""
        filter_func = make_filter_from_options(
            model_name="abc",
            model_versions=["100", "101"],
            model_assets=None,
            model_backends=None,
        )
        config = [
            {
                "model": {
                    "name": "abc",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 102,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
            {
                "model": {
                    "name": "abcd",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 101,
                    "asset": "asset-2",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
        ]

        filtered_configs = list(filter(filter_func, config))
        # Use a unittest assertion rather than a bare ``assert`` so the check
        # is not stripped when the interpreter runs with ``-O``.
        self.assertEqual(
            len(filtered_configs),
            2,
            f"Expected 2 elements in filtered_configs, but got {len(filtered_configs)}",
        )

    def test_verification_success(self) -> None:
        """Verification must not raise when every requested asset/version exists."""
        filter_func = make_filter_from_options(
            model_name="abc",
            model_versions=["100", "101"],
            model_assets=["asset-1", "asset-2"],
            model_backends=None,
        )
        config = [
            {
                "model": {
                    "name": "abc",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 101,
                    "asset": "asset-2",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
        ]
        filtered_configs = list(filter(filter_func, config))
        try:
            verify_all_specified_present(
                model_assets=["asset-1", "asset-2"],
                model_versions=["100", "101"],
                selected_models_yaml=filtered_configs,
                rule_name="test",
                model_name="abc",
                new_style_rule=True,
            )
        except Exception as exc:
            # Surface the underlying exception so a failure is diagnosable.
            self.fail(
                "expected verify_all_specified_present to succeed instead it "
                f"raised an exception: {exc!r}"
            )

    def test_verification_fail(self) -> None:
        """Verification raises ``RuntimeError`` for an unmatched asset, version or name."""
        config = [
            {
                "model": {
                    "name": "abc",
                    "version": 100,
                    "asset": "asset-1",
                    "backend": "CPU",
                },
                "root_operators": [],
                "traced_operators": [],
            },
            {
                "model": {
                    "name": "abc",
                    "version": 101,
                    "asset": "asset-2",
                    "backend": "CPU",
                },
                "root_operators": [],
            },
        ]

        good_assets = ["asset-1", "asset-2"]
        good_versions = ["100", "101"]
        good_name = "abc"

        # (label, model name, model versions, model assets); each case asks
        # for exactly one thing that ``config`` does not contain.
        cases = [
            ("bad asset", good_name, good_versions, ["asset-1", "asset-2", "asset-3"]),
            ("bad version", good_name, ["100", "101", "102"], good_assets),
            ("bad name", "abcd", good_versions, good_assets),
        ]

        for label, name, versions, assets in cases:
            # subTest keeps the scenarios independent: one failing case no
            # longer prevents the remaining ones from running.
            with self.subTest(case=label):
                filter_func = make_filter_from_options(
                    model_name=name,
                    model_versions=versions,
                    model_assets=assets,
                    model_backends=None,
                )
                filtered_configs = list(filter(filter_func, config))
                with self.assertRaises(RuntimeError):
                    verify_all_specified_present(
                        model_assets=assets,
                        model_versions=versions,
                        selected_models_yaml=filtered_configs,
                        rule_name="test",
                        model_name=name,
                        new_style_rule=True,
                    )

    @patch("gen_operators_yaml.parse_options", return_value=_mock_options())
    @patch(
        "gen_operators_yaml.load_op_dep_graph", return_value=_mock_load_op_dep_graph()
    )
    def test_fill_output_with_arguments_not_include_all_overloads(
        self, mock_load_op_dep_graph: Mock, mock_parse_options: Mock
    ) -> None:
        """Every operator emitted by ``fill_output`` has ``include_all_overloads`` false.

        Stacked ``@patch`` decorators inject their mocks bottom-up, so the
        first mock argument belongs to ``load_op_dep_graph`` and the second to
        ``parse_options``.
        """
        parser = argparse.ArgumentParser(description="Generate used operators YAML")
        # NOTE(review): presumably routed through the patched parse_options,
        # yielding the _mock_options() namespace -- confirm in gen_operators_yaml.
        options = get_parser_options(parser)

        model_dict = {
            "model_name": options.model_name,
            "asset_info": {},
            "is_new_style_rule": False,
        }
        output = {"debug_info": [json.dumps(model_dict)]}

        fill_output(output, options)

        for op_val in output["operators"].values():
            self.assertFalse(op_val["include_all_overloads"])