1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 #ifndef TENSORFLOW_PYTHON_UTIL_PYTHON_API_INFO_H_ 16 #define TENSORFLOW_PYTHON_UTIL_PYTHON_API_INFO_H_ 17 18 #include <Python.h> 19 20 #include <map> 21 #include <string> 22 #include <vector> 23 24 #include "absl/types/span.h" 25 #include "tensorflow/core/framework/op_def.pb.h" 26 #include "tensorflow/core/framework/types.pb.h" 27 #include "tensorflow/core/platform/status.h" 28 #include "tensorflow/python/framework/op_def_util.h" 29 #include "tensorflow/python/framework/python_tensor_converter.h" 30 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" 31 32 namespace tensorflow { 33 34 // Precomputed information about a TensorFlow Python API. 35 // 36 // PythonAPIInfo records information about a single TensorFlow Python API, 37 // in order to allow calls to the API to be executed more efficiently. This 38 // information includes: 39 // 40 // * The name of the API. (E.g. "tf.math.add") 41 // 42 // * The name of the registered op that implements the API, if applicable 43 // (e.g. "AddV2"). 44 // 45 // * Information about the API's parameters. Parameters are divided into two 46 // "kinds": inputs and attributes. An *input* is a parameter that 47 // expects a Tensor or list of Tensors, and it is described by an `ArgDef`. 48 // An *attribute* is a parameter that expects any other value type, and it is 49 // described by an `AttrDef`. 50 // 51 // * Default values for the API's attribute parameters. 52 // 53 // * Information about "inferred attributes" -- attributes whose values are 54 // inferred from `input` parameters. There are two kinds of inferred 55 // attributes: Tensor dtypes, which are inferred from tensor and list(tensor) 56 // parameters; and list lengths, which are inferred from list(tensor) 57 // parameters. 58 class PythonAPIInfo { 59 public: 60 // The index of a parameter in the canonicalized parameter list. The 61 // canonicalized parameter list includes inputs and attributes (but does 62 // not include inferred attributes). `-1` is used for inferred attributes. 63 using ParamIndex = int; 64 65 // Information about a parameter that expects a non-Tensor value. 66 struct Attribute { 67 ParamIndex index; // -1 if this is an inferred attribute 68 AttributeType type; 69 const char* name; // Interned python string 70 int inferred_index; // index to store attribute in InferredAttributes 71 }; 72 73 // Information about a parameter that expects a Tensor or list(Tensor). 74 // Additional information about tensor parameters is stored in types 75 // defined below, in order to simplify dtype/length inference: 76 // * FixedDTypeInput: inputs with fixed dtypes. 77 // * InputsWithTypeAttr: groups inputs that use a type_attr for dtype. 78 // * InputsWithTypeListAttr: groups inputs that use a type_list_attr. 79 // * InputsWithNumberAttr: groups inputs by a number_attr for length. 80 struct Input { 81 ParamIndex index; 82 bool is_list; 83 }; 84 85 // Information about a Tensor parameter w/ fixed dtype. 86 struct InputWithFixedDType { 87 DataType dtype; 88 ParamIndex index; 89 bool is_list; 90 }; 91 92 // Information about Tensor parameters whose DType is specified by a single 93 // `type_attr` attribute. 94 struct InputsWithTypeAttr { 95 Attribute* type_attr; // not owned. 96 DataType default_dtype; // DT_INVALID if no default. 97 std::vector<ParamIndex> tensor_params; // single-tensor inputs. 98 std::vector<ParamIndex> tensor_list_params; // list(tensor) inputs. 99 std::vector<DataType> ok_dtypes; 100 }; 101 102 // Information about Tensor parameters whose DType is specified by a single 103 // `type_list_attr` attribute. 104 struct InputsWithTypeListAttr { 105 Attribute* type_list_attr; // not owned. 106 std::vector<DataType> default_dtypes; // empty if no default. 107 std::vector<ParamIndex> tensor_list_params; // list(tensor) inputs. 108 std::vector<DataType> ok_dtypes; 109 }; 110 111 // Information about Tensor-list parameters whose length is specified by a 112 // single `int` attribute. 113 struct InputsWithNumberAttr { 114 Attribute* number_attr; // not owned. 115 int64_t default_length; // -1 for no default. 116 std::vector<ParamIndex> tensor_list_params; // list(tensor) inputs. 117 }; 118 119 // Structure used to return inferred attribute values. 120 // * types[i] is the inferred value for inferred_type_attrs()[i] 121 // * type_lists[i] is the inferred value for inferred_type_list_attrs()[i] 122 // * lengths[i] is the inferred value for inferred_length_attrs()[i] 123 struct InferredAttributes { 124 std::vector<DataType> types; 125 std::vector<std::vector<DataType>> type_lists; 126 std::vector<int64_t> lengths; 127 }; 128 129 // Constructs a new PythonAPIInfo. 130 // 131 // Note: One of the `Initialize()` functions must be called before the 132 // `PythonAPIInfo` is used. 133 // 134 // Args: 135 // api_name: The fully-qualified name of the python API (e.g., tf.math.sum). 136 explicit PythonAPIInfo(const std::string& api_name); 137 138 // Initializes this PythonAPIInfo. 139 // 140 // Args: 141 // op_def: Contains information about the parameters. 142 // param_names: The argument names for the python API, in canonical order. 143 // defaults_tuple: Tuple containing default values for the parameters, 144 // right-aligned with `param_names` -- i.e., `defaults[-i]` is the default 145 // for `param_names[-i]`. 146 Status Initialize(const OpDef& op_def, const std::vector<string> param_names, 147 PyObject* defaults_tuple); 148 149 // Initialize this PythonAPIInfo based on the registered OpDef for the given 150 // operation. 151 // 152 // Args: 153 // op_name: The registered name of the operation (e.g. "AddV2"). 154 Status InitializeFromRegisteredOp(const std::string& op_name); 155 156 // Initializes this PythonAPIInfo based on a set of parameter specifications. 157 // 158 // Args: 159 // input_specs: Mapping from parameter name to specification string for 160 // each input (parameter that expects a tensor value). 161 // attr_specs: Mapping from parameter name to specification string for 162 // each attribute (parameter that expects a non-tensor value). 163 // param_names: The argument names for the python API, in canonical order. 164 // defaults_tuple: Tuple containing default values for the parameters, 165 // right-aligned with `param_names` -- i.e., `defaults[-i]` is the default 166 // for `param_names[-i]`. 167 // 168 // Note: the `name` parameter should not be included in `input_specs` or 169 // `attr_specs`. 170 Status InitializeFromParamSpecs( 171 const std::map<std::string, std::string>& input_specs, 172 const std::map<std::string, std::string>& attr_specs, 173 const std::vector<string> param_names, PyObject* defaults_tuple); 174 175 // The name of the API that is described by this PythonAPIInfo. api_name()176 const char* api_name() const { return api_name_; } 177 178 // The ordered names of the canononical parameters that this API expects. param_names()179 const std::vector<const char*>& param_names() const { return param_names_; } 180 181 // A Python tuple containing the default values for parameters. This is 182 // right-aligned with `param_name` -- i.e., `defaults[-i]` is the default 183 // for `param_names[-i]`. defaults_tuple()184 const PyObject* defaults_tuple() const { return defaults_tuple_.get(); } 185 186 // Information about the attribute (non-tensor) parameters for this API. attributes()187 const std::vector<Attribute>& attributes() const { return attributes_; } 188 189 // Information about the input (tensor) parameters for this API. inputs()190 const std::vector<Input>& inputs() const { return inputs_; } inputs_with_fixed_dtype()191 const std::vector<InputWithFixedDType>& inputs_with_fixed_dtype() const { 192 return inputs_with_fixed_dtype_; 193 } inputs_with_type_attrs()194 const std::vector<InputsWithTypeAttr>& inputs_with_type_attrs() const { 195 return inputs_with_type_attrs_; 196 } inputs_with_type_list_attrs()197 const std::vector<InputsWithTypeListAttr>& inputs_with_type_list_attrs() 198 const { 199 return inputs_with_type_list_attrs_; 200 } inputs_with_number_attrs()201 const std::vector<InputsWithNumberAttr>& inputs_with_number_attrs() const { 202 return inputs_with_number_attrs_; 203 } 204 205 // Names of inferred attributes. inferred_type_attrs()206 const std::vector<const char*>& inferred_type_attrs() const { 207 return inferred_type_attrs_; 208 } inferred_type_list_attrs()209 const std::vector<const char*>& inferred_type_list_attrs() const { 210 return inferred_type_list_attrs_; 211 } inferred_length_attrs()212 const std::vector<const char*>& inferred_length_attrs() const { 213 return inferred_length_attrs_; 214 } 215 216 // Returns a string summarizing the internal state of this type converter. 217 string DebugInfo() const; 218 219 private: 220 // Adds an entry to the attributes_ vector based on the given `AttrDef`. 221 // 222 // If `attr_def` describes a type attribute, then adds a value to 223 // inputs_with_type_attrs_ or inputs_with_type_list_attrs_ (to record any 224 // tensor inputs that use this dtype). 225 // 226 // If `attr_def` describes an int attribute, then adds a value to 227 // inputs_with_number_attrs_ (to record any tensor inputs that use this 228 // value as a list length). 229 Status InitializeAttribute( 230 const OpDef::AttrDef& attr_def, 231 const std::map<std::string, ParamIndex>& param_name_to_index); 232 233 // Adds an entry to the inputs_ vector based on the given `ArgDef`. 234 // 235 // If `arg_def` has a fixed dtype, then adds a value to `fixed_dtype_inputs`. 236 // 237 // If `arg_def`'s dtype is described by a `type` attr, then updates the 238 // appropriate value in `inputs_with_type_attrs_` with information about the 239 // `arg_def`. 240 // 241 // If `arg_def`'s dtype is described by a `list(type)` attr, then updates the 242 // appropriate value in `inputs_with_type_list_attrs_` with information about 243 // the `arg_def`. 244 Status InitializeInput(const OpDef::ArgDef& arg_def, 245 const std::map<std::string, int>& param_name_to_index); 246 247 // Checks that the OpDef used to initialize this PythonAPIInfo 248 // had an AttrDef or ArgDef specification for each parameter. 249 Status CheckParamNames() const; 250 251 // Searches inputs_with_type_attrs_ for an input with the given name. 252 InputsWithTypeAttr* FindInputsWithTypeAttr(const string& name); 253 254 // Searches inputs_with_type_list_attrs_ for an input with the given name. 255 InputsWithTypeListAttr* FindInputsWithTypeListAttr(const string& name); 256 257 // Searches inputs_with_type_list_attrs_ for an input with the given name. 258 InputsWithNumberAttr* FindInputsWithNumberAttr(const string& name); 259 260 ABSL_MUST_USE_RESULT 261 bool InferLengthAttributes(const absl::Span<PyObject*> params, 262 std::vector<int64_t>& inferred_length_attrs) const; 263 264 // ========================================================================== 265 // Member Variables 266 // ========================================================================== 267 268 // The name of the API that is described by this PythonAPIInfo. 269 // (Interned python string). 270 const char* api_name_; 271 272 // The names of the parameters that this API expects. 273 // (Interned python strings.) 274 std::vector<const char*> param_names_; 275 276 // Tuple containing default values for the parameters, right-aligned with 277 // `param_names` -- i.e., `defaults[-i]` is the default for `param_names[-i]`. 278 Safe_PyObjectPtr defaults_tuple_; 279 280 // Information about the non-tensor-valued parameters that this API expects. 281 std::vector<Attribute> attributes_; 282 283 // Information about the tensor-valued parameters that this API expects. 284 std::vector<Input> inputs_; 285 std::vector<InputWithFixedDType> inputs_with_fixed_dtype_; 286 std::vector<InputsWithTypeAttr> inputs_with_type_attrs_; 287 std::vector<InputsWithTypeListAttr> inputs_with_type_list_attrs_; 288 std::vector<InputsWithNumberAttr> inputs_with_number_attrs_; 289 290 // Names of inferred attributes. (Interned python strings.) 291 std::vector<const char*> inferred_type_attrs_; 292 std::vector<const char*> inferred_type_list_attrs_; 293 std::vector<const char*> inferred_length_attrs_; 294 }; 295 296 } // namespace tensorflow 297 298 #endif // TENSORFLOW_PYTHON_UTIL_PYTHON_API_INFO_H_ 299