// xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/python_api_dispatcher.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Support for API dispatch at the Python level.
//
// The dispatcher is implemented in C++ for efficiency.
//
// * PythonAPIDispatcher: Class that handles dispatch for a single Python API.
//   Contains a mapping from PySignatureCheckers to dispatch targets (Python
//   functions).
//
// * PySignatureChecker: Class to efficiently check whether dispatch should be
//   invoked for a given set of parameters.  Contains a collection of
//   PyTypeCheckers.
//
// * PyTypeChecker: Class to efficiently check whether a Python value matches
//   a type annotation.  Three subclasses (PyInstanceChecker, PyListChecker,
//   and PyUnionChecker) handle the different kinds of type annotation.
30 
31 #ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_
32 #define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_
33 
#include <Python.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/types/span.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
#include "tensorflow/python/util/function_parameter_canonicalizer.h"
42 
43 namespace tensorflow {
44 
45 namespace py_dispatch {
46 
47 class PyTypeChecker;
48 class PySignatureChecker;
49 
50 // Dispatcher for a single TensorFlow Python API (e.g. `tf.add` or `tf.concat`).
51 //
52 // A separate `PythonAPIDispatcher` object is created for each API, and handles
53 // dispatch for that API.  The `Register` method can be used to add new
54 // "dispatch targets", which override the default behavior of the API when it
55 // is called with parameters matching a given signature.  The `Dispatch` method
56 // checks if any registered target matches parameters, and if so, then calls
57 // that target.
58 //
59 // This class is *not* thread-safe.  It is assumed that the Python Global
60 // Interpreter Lock (GIL) will be held when any method is called.
61 class PythonAPIDispatcher {
62   // TODO(b/196369143) Add benchmarking for this class.
63  public:
64   // Creates a new PythonAPIDispatcher for the named API.
65   //
66   // Args:
67   //   api_name: The name of the API (used for error messages).
68   //   arg_names: The argument names (used for parameter canonicalization).
69   //   defaults: The argument defaults, as returned by `inspect.getargspec`
70   //       (used for parameter canonicalization).
71   PythonAPIDispatcher(const std::string& api_name,
72                       absl::Span<const char*> arg_names,
73                       absl::Span<PyObject*> defaults);
74 
75   // Registers a new dispatch target for this dispatcher.  If the API is
76   // called with parameters that match `signature_checker`, then
77   // `dispatch_target` will be called instead of the default API implementation.
78   void Register(PySignatureChecker signature_checker,
79                 PyObject* dispatch_target);
80 
81   // Performs dispatch with the given set of parameters.
82   //
83   // * If a single target matches the parameters, then that target is called.
84   // * If multiple targets match the parameters, then an exception is raised.
85   // * If no targets match the parameters, then returns `Py_NotImplemented`.
86   //
87   // On error, returns nullptr and sets a Python exception.
88   Safe_PyObjectPtr Dispatch(PyObject* args, PyObject* kwargs);
89 
90   // Remove a dispatch target from this dispatcher.  If the target was
91   // registered with multiple signatures, then all entries will be removed.
92   // (This method is primarily intended for regression tests.)
93   void Unregister(PyObject* func);
94 
95   std::string DebugString() const;
96 
97  private:
98   // Name of the API.
99   std::string api_name_;
100 
101   // Mapping from signature checkers to dispatch targets.
102   std::vector<std::pair<PySignatureChecker, Safe_PyObjectPtr>> targets_;
103 
104   // Parameter canonicalizer.
105   FunctionParameterCanonicalizer canonicalizer_;
106 
107   // Target storage for canonicalization.  (Note: for efficiency, `Dispatch`
108   // writes to this pre-allocated storage, rather than allocating new storage
109   // each time it is called.)
110   std::vector<PyObject*> canonicalized_args_storage_;
111   absl::Span<PyObject*> canonicalized_args_span_;
112 };
113 
114 // Registers a type for use with dispatch.  Dispatch will only occur if at least
115 // one parameter value matches an annotation corresponding to a registered
116 // dispatchable type.
117 //
118 // Returns true on success; or sets a Python exception and returns false
119 // on error.
120 //
121 // Must be called before any PyInstanceChecker object is created from py_class.
122 //
123 // (Note: the CompositeTensor class is automatically registered for dispatch,
124 // so you do not need to use this method for any class that is a subclass of
125 // CompositeTensor or ExtensionType.)
126 bool RegisterDispatchableType(PyObject* py_class);
127 
128 // Class used by dispatch to check if parameters' values match a signature.
129 //
130 // Currently only supports checking parameters with kind POSITIONAL_ONLY or
131 // POSITIONAL_OR_KEYWORD.  (Does not support checking parameters with kind
132 // VAR_POSITIONAL, VAR_KEYWORD, or KEYWORD_ONLY.)
133 class PySignatureChecker {
134  public:
135   // A parameter index and a TypeChecker for the parameter at that index.
136   using ParamChecker = std::pair<int, std::shared_ptr<PyTypeChecker>>;
137 
138   // Constructs a signature checker that will check the specified positional
139   // parameters.
140   explicit PySignatureChecker(std::vector<ParamChecker> parameter_checkers);
141 
142   // Returns true if the given canonicalized arguments match this signature
143   // checker.
144   bool CheckCanonicalizedArgs(absl::Span<PyObject*> canon_args) const;
145 
146   std::string DebugString() const;
147 
148  private:
149   // Type checkers for individual parameters.  Only annotated parameters will
150   // be checked.  This list is sorted to perform less expensive checks first.
151   // E.g., we check simple values before list values.
152   std::vector<ParamChecker> positional_parameter_checkers_;
153 };
154 
155 // Abstract base class that checks if a Python value matches a type annotation.
156 //
157 // Subclasses of PyTypeChecker are defined for different annotations (List,
158 // Union, etc). Currently, we support the minimum set of type checkers that are
159 // required for CompositeTensor dispatch -- namely, `List`, `Union`, and simple
160 // types (`IsInstance`).  Support for additional annotations may be added in the
161 // future.
162 class PyTypeChecker {
163  public:
164   using PyTypeChecker_ptr = std::shared_ptr<PyTypeChecker>;
165   PyTypeChecker() = default;
166   PyTypeChecker(const PyTypeChecker&) = delete;
167   PyTypeChecker(PyTypeChecker&&) = delete;
~PyTypeChecker()168   virtual ~PyTypeChecker() {}
169 
170   // Enumeration used to indicate whether a Python value matches a type
171   // annotation. MATCH and NO_MATCH simply indicate whether a value matches the
172   // annotation.
173   //
174   // MATCH_DISPATCHABLE indicates that a value matches the annotation, and
175   // additionally that the value (or one of its nested values) matched a type
176   // that has been registered for dispatch. This is important information
177   // because we only want to perform dispatch if at least one such value
178   // matches. Otherwise, we would end up using dispatch in undesirable cases.
179   // Examples:
180   //
181   // @tf.dispatch_for(tf.concat)(x=List[MyType])
182   //
183   //    We should not dispatch to `my_concat` when the user calls
184   //    `tf.concat([])` (even though it's technically true that the empty
185   //    list satisfies the type annotation `List[MyType]`).
186   //
187   // @tf.dispatch_for(tf.add)(x=Union[MyType, Tensor], y=Union[MyType, Tensor])
188   //
189   //   We should not dispatch to `my_add` when the user calls
190   //   `tf.add(tf.constant(1), tf.constant(2))` (even though this technically
191   //   matches the annotated types).
192   enum class MatchType { NO_MATCH, MATCH, MATCH_DISPATCHABLE };
193 
194   // Returns a value indicating how this type checker matched with the given
195   // value.
196   virtual MatchType Check(PyObject* value) = 0;
197 
198   // Approximate cost of calling this type checker, so we can perform less
199   // expensive checks first.  (E.g., checking if every element in a list has a
200   // given type is more costly than checking a single value.)
201   virtual int cost() const = 0;
202 
203   virtual std::string DebugString() const = 0;
204 };
205 
206 // PyTypeChecker that checks if a value is an instance of a given Python type.
207 class PyInstanceChecker : public PyTypeChecker {
208  public:
209   explicit PyInstanceChecker(const std::vector<PyObject*>& py_classes);
210   ~PyInstanceChecker() override;
211   MatchType Check(PyObject* value) override;
212   int cost() const override;
213   std::string DebugString() const override;
214 
215   // Size of the cache (for regression testing).
cache_size()216   size_t cache_size() const { return py_class_cache_.size(); }
217 
218  private:
219   // Python class to check values against.
220   std::vector<Safe_PyObjectPtr> py_classes_;
221 
222   // Cache to avoid having to call PyObject_IsInstance.  Note: we rely on the
223   // Python GIL (global interpreter lock) to avoid concurrent writes to this
224   // cache, since `Check()` is always called from Python (via pybind11).
225   absl::flat_hash_map<PyTypeObject*, MatchType> py_class_cache_;
226 
227   // Maximum cache size.  In typical user programs, the cache will never become
228   // full, but we use a maximum size in case the user creates types dynamically,
229   // to avoid having an unbounded number of items in the cache.
230   // TODO(b/194903203) Consider switching to an LRU cache.
231   static constexpr int kMaxItemsInCache = 1024;
232 };
233 
234 // PyTypeChecker that checks if a value is a list whose elements all match a
235 // nested PyTypeChecker.
236 class PyListChecker : public PyTypeChecker {
237  public:
PyListChecker(PyTypeChecker_ptr element_type)238   explicit PyListChecker(PyTypeChecker_ptr element_type)
239       : element_type_(element_type) {}
240   MatchType Check(PyObject* value) override;
241   int cost() const override;
242   std::string DebugString() const override;
243 
244  private:
245   PyTypeChecker_ptr element_type_;
246 };
247 
248 // PyTypeChecker that checks if a value matches at least one nested
249 // PyTypeChecker.
250 class PyUnionChecker : public PyTypeChecker {
251  public:
PyUnionChecker(std::vector<PyTypeChecker_ptr> options)252   explicit PyUnionChecker(std::vector<PyTypeChecker_ptr> options)
253       : options_(options) {}
254   MatchType Check(PyObject* value) override;
255   int cost() const override;
256   std::string DebugString() const override;
257 
258  private:
259   std::vector<PyTypeChecker_ptr> options_;
260 };
261 
262 }  // namespace py_dispatch
263 }  // namespace tensorflow
264 
265 #endif  // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_
266