from __future__ import annotations

from dataclasses import dataclass


# This class holds information about a single operator used to determine
# the outcome of a selective/custom PyTorch build that doesn't include
# registration code for all the supported operators. This is done to
# reduce the size of the generated binary so that it can be deployed in
# situations where binary size comes at a premium.
#
@dataclass(frozen=True)
class SelectiveBuildOperator:
    # The name of the operator, including the namespace prefix (aten::, etc.).
    # The operator name may or may not include the overload name. If it does
    # not, check the 'include_all_overloads' flag in this class to determine
    # whether this entry refers to the whole family of operators with this
    # base name or just the single operator with this exact name.
    name: str

    # True if this is a root operator (i.e. called directly from a
    # TorchScript model, etc.). An operator is considered to be a
    # root operator if it is called directly from any one of the models
    # that this instance of the PyTorch library was built for. Hence, it
    # may not be a root operator in all of the models that are used in
    # this instance of the PyTorch library.
    is_root_operator: bool

    # Is this operator used for on-device training? If True, then we need to
    # use this information to generate code in VariableType_N.cpp for the
    # registration of training-related operators. Again, this is True if this
    # operator is used for training in one or more models used by this
    # instance of the PyTorch library.
    is_used_for_training: bool

    # If True, it indicates that this operator instance (object) refers to an
    # operator without the overload name and should apply to all overloads
    # which have this operator name as the base name. This flag is applicable
    # only for objects that have operator names without a DOT (period)
    # character in them.
    #
    # Note: This flag is a temporary workaround to grandfather in the current
    # static selective (custom) build mechanism, which largely ignores overload
    # names when determining whether to select operators for registration
    # purposes.
    include_all_overloads: bool

    # Debug information at the operator level
    _debug_info: tuple[str, ...] | None

    @staticmethod
    def from_yaml_dict(
        op_name: str, op_info: dict[str, object]
    ) -> SelectiveBuildOperator:
        allowed_keys = {
            "name",
            "is_root_operator",
            "is_used_for_training",
            "include_all_overloads",
            "debug_info",
        }

        if len(set(op_info.keys()) - allowed_keys) > 0:
            raise Exception(  # noqa: TRY002
                "Got unexpected top level keys: {}".format(
                    ",".join(set(op_info.keys()) - allowed_keys),
                )
            )

        if "name" in op_info:
            assert op_name == op_info["name"]

        is_root_operator = op_info.get("is_root_operator", True)
        assert isinstance(is_root_operator, bool)

        is_used_for_training = op_info.get("is_used_for_training", True)
        assert isinstance(is_used_for_training, bool)

        include_all_overloads = op_info.get("include_all_overloads", True)
        assert isinstance(include_all_overloads, bool)

        debug_info: tuple[str, ...] | None = None
        if "debug_info" in op_info:
            di_list = op_info["debug_info"]
            assert isinstance(di_list, list)
            debug_info = tuple(str(x) for x in di_list)

        return SelectiveBuildOperator(
            name=op_name,
            is_root_operator=is_root_operator,
            is_used_for_training=is_used_for_training,
            include_all_overloads=include_all_overloads,
            _debug_info=debug_info,
        )

    @staticmethod
    def from_legacy_operator_name_without_overload(
        name: str,
    ) -> SelectiveBuildOperator:
        return SelectiveBuildOperator(
            name=name,
            is_root_operator=True,
            is_used_for_training=True,
            include_all_overloads=True,
            _debug_info=None,
        )

    def to_dict(self) -> dict[str, object]:
        ret: dict[str, object] = {
            "is_root_operator": self.is_root_operator,
            "is_used_for_training": self.is_used_for_training,
            "include_all_overloads": self.include_all_overloads,
        }
        if self._debug_info is not None:
            ret["debug_info"] = self._debug_info

        return ret


def merge_debug_info(
    lhs: tuple[str, ...] | None,
    rhs: tuple[str, ...] | None,
) -> tuple[str, ...] | None:
    # Ensure that when merging, each entry shows up just once.
    if lhs is None and rhs is None:
        return None

    return tuple(set((lhs or ()) + (rhs or ())))


def combine_operators(
    lhs: SelectiveBuildOperator, rhs: SelectiveBuildOperator
) -> SelectiveBuildOperator:
    if str(lhs.name) != str(rhs.name):
        raise Exception(  # noqa: TRY002
            f"Expected both arguments to have the same name, but got '{str(lhs.name)}' and '{str(rhs.name)}' instead"
        )

    return SelectiveBuildOperator(
        name=lhs.name,
        # Consider this operator to be a root operator if it is a
        # root operator in any of the models used in this instance of
        # the PyTorch library.
        is_root_operator=lhs.is_root_operator or rhs.is_root_operator,
        # Consider this operator to be a training operator if it is
        # an operator used for training in any of the models used
        # in this instance of the PyTorch library.
        is_used_for_training=lhs.is_used_for_training or rhs.is_used_for_training,
        include_all_overloads=lhs.include_all_overloads or rhs.include_all_overloads,
        _debug_info=merge_debug_info(lhs._debug_info, rhs._debug_info),
    )


def merge_operator_dicts(
    lhs: dict[str, SelectiveBuildOperator],
    rhs: dict[str, SelectiveBuildOperator],
) -> dict[str, SelectiveBuildOperator]:
    operators: dict[str, SelectiveBuildOperator] = {}
    for op_name, op in list(lhs.items()) + list(rhs.items()):
        new_op = op
        if op_name in operators:
            new_op = combine_operators(operators[op_name], op)

        operators[op_name] = new_op

    return operators


def strip_operator_overload_name(op_name: str) -> str:
    return op_name.split(".")[0]
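

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of this module's API): how per-model operator
# dictionaries parsed via from_yaml_dict() might be merged for a selective
# build. The operator payloads and the 'model_a.ptl' / 'model_b.ptl' debug
# strings below are hypothetical examples, chosen only to show that the
# boolean flags are OR-ed across models and the debug info is deduplicated.
if __name__ == "__main__":
    model_a = {
        "aten::add.Tensor": SelectiveBuildOperator.from_yaml_dict(
            "aten::add.Tensor",
            {
                "is_root_operator": True,
                "is_used_for_training": False,
                "include_all_overloads": False,
                "debug_info": ["model_a.ptl"],  # hypothetical provenance tag
            },
        )
    }
    model_b = {
        "aten::add.Tensor": SelectiveBuildOperator.from_yaml_dict(
            "aten::add.Tensor",
            {
                "is_root_operator": False,
                "is_used_for_training": True,
                "include_all_overloads": False,
                "debug_info": ["model_b.ptl"],  # hypothetical provenance tag
            },
        )
    }

    merged = merge_operator_dicts(model_a, model_b)
    op = merged["aten::add.Tensor"]

    # Each flag is True if it is True in any contributing model:
    # prints "True True False" here.
    print(op.is_root_operator, op.is_used_for_training, op.include_all_overloads)

    # The base name without the overload suffix: prints "aten::add".
    print(strip_operator_overload_name("aten::add.Tensor"))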