1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. 2 3 Licensed under the Apache License, Version 2.0 (the "License"); 4 you may not use this file except in compliance with the License. 5 You may obtain a copy of the License at 6 7 http://www.apache.org/licenses/LICENSE-2.0 8 9 Unless required by applicable law or agreed to in writing, software 10 distributed under the License is distributed on an "AS IS" BASIS, 11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 See the License for the specific language governing permissions and 13 limitations under the License. 14 ==============================================================================*/ 15 // Support for API dispatch at the Python level. 16 // 17 // The dispatcher is implemented in c++ for efficiency. 18 // 19 // * PythonAPIDispatcher: Class that handles dispatch for a single Python API. 20 // Contains a mapping from PySignatureCheckers to dispatch targets (python 21 // functions). 22 // 23 // * PySignatureChecker: Class to efficiently check whether dispatch should be 24 // invoked for a given set of parameters. Contains a collection of 25 // PyTypeCheckers. 26 // 27 // * PyTypeChecker: Class to efficiently check whether a Python value matches 28 // a type annotation. Three subclasses (PyInstanceChecker, PyListChecker, 29 // and PyUnionChecker) handle the different kinds of type annotation. 30 31 #ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_ 32 #define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_ 33 34 #include <Python.h> 35 36 #include <string> 37 #include <vector> 38 39 #include "absl/container/flat_hash_map.h" 40 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h" 41 #include "tensorflow/python/util/function_parameter_canonicalizer.h" 42 43 namespace tensorflow { 44 45 namespace py_dispatch { 46 47 class PyTypeChecker; 48 class PySignatureChecker; 49 50 // Dispatcher for a single TensorFlow Python API (e.g. `tf.add` or `tf.concat`). 51 // 52 // A separate `PythonAPIDispatcher` object is created for each API, and handles 53 // dispatch for that API. The `Register` method can be used to add new 54 // "dispatch targets", which override the default behavior of the API when it 55 // is called with parameters matching a given signature. The `Dispatch` method 56 // checks if any registered target matches parameters, and if so, then calls 57 // that target. 58 // 59 // This class is *not* thread-safe. It is assumed that the Python Global 60 // Interpreter Lock (GIL) will be held when any method is called. 61 class PythonAPIDispatcher { 62 // TODO(b/196369143) Add benchmarking for this class. 63 public: 64 // Creates a new PythonAPIDispatcher for the named API. 65 // 66 // Args: 67 // api_name: The name of the API (used for error messages). 68 // arg_names: The argument names (used for parameter canonicalization). 69 // defaults: The argument defaults, as returned by `inspect.getargspec` 70 // (used for parameter canonicalization). 71 PythonAPIDispatcher(const std::string& api_name, 72 absl::Span<const char*> arg_names, 73 absl::Span<PyObject*> defaults); 74 75 // Registers a new dispatch target for this dispatcher. If the API is 76 // called with parameters that match `signature_checker`, then 77 // `dispatch_target` will be called instead of the default API implementation. 78 void Register(PySignatureChecker signature_checker, 79 PyObject* dispatch_target); 80 81 // Performs dispatch with the given set of parameters. 82 // 83 // * If a single target matches the parameters, then that target is called. 84 // * If multiple targets match the parameters, then an exception is raised. 85 // * If no targets match the parameters, then returns `Py_NotImplemented`. 86 // 87 // On error, returns nullptr and sets a Python exception. 88 Safe_PyObjectPtr Dispatch(PyObject* args, PyObject* kwargs); 89 90 // Remove a dispatch target from this dispatcher. If the target was 91 // registered with multiple signatures, then all entries will be removed. 92 // (This method is primarily intended for regression tests.) 93 void Unregister(PyObject* func); 94 95 std::string DebugString() const; 96 97 private: 98 // Name of the API. 99 std::string api_name_; 100 101 // Mapping from signature checkers to dispatch targets. 102 std::vector<std::pair<PySignatureChecker, Safe_PyObjectPtr>> targets_; 103 104 // Parameter canonicalizer. 105 FunctionParameterCanonicalizer canonicalizer_; 106 107 // Target storage for canonicalization. (Note: for efficiency, `Dispatch` 108 // writes to this pre-allocated storage, rather than allocating new storage 109 // each time it is called.) 110 std::vector<PyObject*> canonicalized_args_storage_; 111 absl::Span<PyObject*> canonicalized_args_span_; 112 }; 113 114 // Registers a type for use with dispatch. Dispatch will only occur if at least 115 // one parameter value matches an annotation corresponding to a registered 116 // dispatchable type. 117 // 118 // Returns true on success; or sets a Python exception and returns false 119 // on error. 120 // 121 // Must be called before any PyInstanceChecker object is created from py_class. 122 // 123 // (Note: the CompositeTensor class is automatically registered for dispatch, 124 // so you do not need to use this method for any class that is a subclass of 125 // CompositeTensor or ExtensionType.) 126 bool RegisterDispatchableType(PyObject* py_class); 127 128 // Class used by dispatch to check if parameters' values match a signature. 129 // 130 // Currently only supports checking parameters with kind POSITIONAL_ONLY or 131 // POSITIONAL_OR_KEYWORD. (Does not support checking parameters with kind 132 // VAR_POSITIONAL, VAR_KEYWORD, or KEYWORD_ONLY.) 133 class PySignatureChecker { 134 public: 135 // A parameter index and a TypeChecker for the parameter at that index. 136 using ParamChecker = std::pair<int, std::shared_ptr<PyTypeChecker>>; 137 138 // Constructs a signature checker that will check the specified positional 139 // parameters. 140 explicit PySignatureChecker(std::vector<ParamChecker> parameter_checkers); 141 142 // Returns true if the given canonicalized arguments match this signature 143 // checker. 144 bool CheckCanonicalizedArgs(absl::Span<PyObject*> canon_args) const; 145 146 std::string DebugString() const; 147 148 private: 149 // Type checkers for individual parameters. Only annotated parameters will 150 // be checked. This list is sorted to perform less expensive checks first. 151 // E.g., we check simple values before list values. 152 std::vector<ParamChecker> positional_parameter_checkers_; 153 }; 154 155 // Abstract base class that checks if a Python value matches a type annotation. 156 // 157 // Subclasses of PyTypeChecker are defined for different annotations (List, 158 // Union, etc). Currently, we support the minimum set of type checkers that are 159 // required for CompositeTensor dispatch -- namely, `List`, `Union`, and simple 160 // types (`IsInstance`). Support for additional annotations may be added in the 161 // future. 162 class PyTypeChecker { 163 public: 164 using PyTypeChecker_ptr = std::shared_ptr<PyTypeChecker>; 165 PyTypeChecker() = default; 166 PyTypeChecker(const PyTypeChecker&) = delete; 167 PyTypeChecker(PyTypeChecker&&) = delete; ~PyTypeChecker()168 virtual ~PyTypeChecker() {} 169 170 // Enumeration used to indicate whether a Python value matches a type 171 // annotation. MATCH and NO_MATCH simply indicate whether a value matches the 172 // annotation. 173 // 174 // MATCH_DISPATCHABLE indicates that a value matches the annotation, and 175 // additionally that the value (or one of its nested values) matched a type 176 // that has been registered for dispatch. This is important information 177 // because we only want to perform dispatch if at least one such value 178 // matches. Otherwise, we would end up using dispatch in undesirable cases. 179 // Examples: 180 // 181 // @tf.dispatch_for(tf.concat)(x=List[MyType]) 182 // 183 // We should not dispatch to `my_concat` when the user calls 184 // `tf.concat([])` (even though it's technically true that the empty 185 // list satisfies the type annotation `List[MyType]`). 186 // 187 // @tf.dispatch_for(tf.add)(x=Union[MyType, Tensor], y=Union[MyType, Tensor]) 188 // 189 // We should not dispatch to `my_add` when the user calls 190 // `tf.add(tf.constant(1), tf.constant(2))` (even though this technically 191 // matches the annotated types). 192 enum class MatchType { NO_MATCH, MATCH, MATCH_DISPATCHABLE }; 193 194 // Returns a value indicating how this type checker matched with the given 195 // value. 196 virtual MatchType Check(PyObject* value) = 0; 197 198 // Approximate cost of calling this type checker, so we can perform less 199 // expensive checks first. (E.g., checking if every element in a list has a 200 // given type is more costly than checking a single value.) 201 virtual int cost() const = 0; 202 203 virtual std::string DebugString() const = 0; 204 }; 205 206 // PyTypeChecker that checks if a value is an instance of a given Python type. 207 class PyInstanceChecker : public PyTypeChecker { 208 public: 209 explicit PyInstanceChecker(const std::vector<PyObject*>& py_classes); 210 ~PyInstanceChecker() override; 211 MatchType Check(PyObject* value) override; 212 int cost() const override; 213 std::string DebugString() const override; 214 215 // Size of the cache (for regression testing). cache_size()216 size_t cache_size() const { return py_class_cache_.size(); } 217 218 private: 219 // Python class to check values against. 220 std::vector<Safe_PyObjectPtr> py_classes_; 221 222 // Cache to avoid having to call PyObject_IsInstance. Note: we rely on the 223 // Python GIL (global interpreter lock) to avoid concurrent writes to this 224 // cache, since `Check()` is always called from Python (via pybind11). 225 absl::flat_hash_map<PyTypeObject*, MatchType> py_class_cache_; 226 227 // Maximum cache size. In typical user programs, the cache will never become 228 // full, but we use a maximum size in case the user creates types dynamically, 229 // to avoid having an unbounded number of items in the cache. 230 // TODO(b/194903203) Consider switching to an LRU cache. 231 static constexpr int kMaxItemsInCache = 1024; 232 }; 233 234 // PyTypeChecker that checks if a value is a list whose elements all match a 235 // nested PyTypeChecker. 236 class PyListChecker : public PyTypeChecker { 237 public: PyListChecker(PyTypeChecker_ptr element_type)238 explicit PyListChecker(PyTypeChecker_ptr element_type) 239 : element_type_(element_type) {} 240 MatchType Check(PyObject* value) override; 241 int cost() const override; 242 std::string DebugString() const override; 243 244 private: 245 PyTypeChecker_ptr element_type_; 246 }; 247 248 // PyTypeChecker that checks if a value matches at least one nested 249 // PyTypeChecker. 250 class PyUnionChecker : public PyTypeChecker { 251 public: PyUnionChecker(std::vector<PyTypeChecker_ptr> options)252 explicit PyUnionChecker(std::vector<PyTypeChecker_ptr> options) 253 : options_(options) {} 254 MatchType Check(PyObject* value) override; 255 int cost() const override; 256 std::string DebugString() const override; 257 258 private: 259 std::vector<PyTypeChecker_ptr> options_; 260 }; 261 262 } // namespace py_dispatch 263 } // namespace tensorflow 264 265 #endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_ 266