xref: /aosp_15_r20/external/tensorflow/tensorflow/python/framework/python_api_dispatcher.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/python/framework/python_api_dispatcher.h"
17 
18 #include <set>
19 
20 #include "absl/strings/str_join.h"
21 #include "tensorflow/core/platform/logging.h"
22 #include "tensorflow/core/platform/macros.h"
23 #include "tensorflow/python/lib/core/py_util.h"
24 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
25 #include "tensorflow/python/util/util.h"
26 
27 namespace tensorflow {
28 namespace py_dispatch {
29 
30 namespace {
31 
GetRegisteredDispatchableTypes()32 std::vector<Safe_PyObjectPtr>& GetRegisteredDispatchableTypes() {
33   static std::vector<Safe_PyObjectPtr>* registered_dispatchable_types =
34       new std::vector<Safe_PyObjectPtr>();
35   if (registered_dispatchable_types->empty()) {
36     static PyObject* composite_tensor =
37         swig::GetRegisteredPyObject("CompositeTensor");
38     Py_INCREF(composite_tensor);
39     registered_dispatchable_types->push_back(
40         Safe_PyObjectPtr(composite_tensor));
41   }
42   return *registered_dispatchable_types;
43 }
44 
45 // Returns true if `py_class` is a registered dispatchable type.
IsRegisteredDispatchableType(PyObject * py_class)46 bool IsRegisteredDispatchableType(PyObject* py_class) {
47   DCheckPyGilState();
48   for (const auto& registered_type : GetRegisteredDispatchableTypes()) {
49     int result = PyObject_IsSubclass(py_class, registered_type.get());
50     if (result > 0) return true;
51     if (result < 0) PyErr_Clear();
52   }
53   return false;
54 }
55 
56 // Raises an exception indicating that multiple dispatch targets matched.
RaiseDispatchConflictError(const std::string & api_name,PyObject * selected,PyObject * target)57 Safe_PyObjectPtr RaiseDispatchConflictError(const std::string& api_name,
58                                             PyObject* selected,
59                                             PyObject* target) {
60   Safe_PyObjectPtr s1(PyObject_Str(selected));
61   Safe_PyObjectPtr s2(PyObject_Str(target));
62   PyErr_SetString(PyExc_ValueError,
63                   absl::StrCat("Multiple dispatch targets that were "
64                                "registered with tf.dispatch_for (",
65                                s1 ? PyUnicode_AsUTF8(s1.get()) : "?", " and ",
66                                s2 ? PyUnicode_AsUTF8(s2.get()) : "?",
67                                ") match the arguments to ", api_name)
68                       .c_str());
69   return nullptr;
70 }
71 
72 }  // namespace
73 
RegisterDispatchableType(PyObject * py_class)74 bool RegisterDispatchableType(PyObject* py_class) {
75   DCheckPyGilState();
76   if (!PyType_Check(py_class)) {
77     PyErr_SetString(
78         PyExc_ValueError,
79         absl::StrCat("Expected a type object; got object with type ",
80                      py_class->ob_type->tp_name)
81             .c_str());
82     return false;
83   }
84   if (IsRegisteredDispatchableType(py_class)) {
85     Safe_PyObjectPtr s(PyObject_Str(py_class));
86     PyErr_SetString(PyExc_ValueError,
87                     absl::StrCat("Type ", s ? PyUnicode_AsUTF8(s.get()) : "?",
88                                  " (or one of its bases clases) has "
89                                  "already been registered")
90                         .c_str());
91     return false;
92   }
93   Py_INCREF(py_class);
94   GetRegisteredDispatchableTypes().push_back(Safe_PyObjectPtr(py_class));
95   return true;
96 }
97 
// Creates a dispatcher for the API named `api_name`.
//
// `arg_names` and `defaults` configure `canonicalizer_`. Dispatch() uses it
// to turn (args, kwargs) into a single list of arguments.
//
// The canonicalized arguments are written into `canonicalized_args_storage_`,
// which is sized by `canonicalizer_.GetArgSize()`, and are read through
// `canonicalized_args_span_`.
// NOTE(review): the span refers to the storage, and the storage size comes
// from `canonicalizer_`. These members must therefore be declared (and so
// initialized) in this order in the header; confirm there.
PythonAPIDispatcher::PythonAPIDispatcher(const std::string& api_name,
                                         absl::Span<const char*> arg_names,
                                         absl::Span<PyObject*> defaults)
    : api_name_(api_name),
      canonicalizer_(arg_names, defaults),
      canonicalized_args_storage_(canonicalizer_.GetArgSize()),
      canonicalized_args_span_(canonicalized_args_storage_) {}
104       canonicalized_args_span_(canonicalized_args_storage_) {}
105 
Register(PySignatureChecker signature_checker,PyObject * dispatch_target)106 void PythonAPIDispatcher::Register(PySignatureChecker signature_checker,
107                                    PyObject* dispatch_target) {
108   DCheckPyGilState();
109   Py_INCREF(dispatch_target);
110   targets_.emplace_back(std::move(signature_checker),
111                         Safe_PyObjectPtr(dispatch_target));
112 }
113 
Dispatch(PyObject * args,PyObject * kwargs)114 Safe_PyObjectPtr PythonAPIDispatcher::Dispatch(PyObject* args,
115                                                PyObject* kwargs) {
116   DCheckPyGilState();
117   if (kwargs == Py_None) {
118     kwargs = nullptr;
119   }
120   // Canonicalize args (so we don't need to deal with kwargs).
121   if (!canonicalizer_.Canonicalize(args, kwargs, canonicalized_args_span_)) {
122     return nullptr;
123   }
124 
125   PyObject* selected = nullptr;
126   for (auto& target : targets_) {
127     if (target.first.CheckCanonicalizedArgs(canonicalized_args_span_)) {
128       if (selected && selected != target.second.get()) {
129         return RaiseDispatchConflictError(api_name_, selected,
130                                           target.second.get());
131       }
132       selected = target.second.get();
133     }
134   }
135   if (selected) {
136     return Safe_PyObjectPtr(PyObject_Call(selected, args, kwargs));
137   } else {
138     Py_INCREF(Py_NotImplemented);
139     return Safe_PyObjectPtr(Py_NotImplemented);
140   }
141 }
142 
143 // TODO(b/194903203) Raise an error if `func` is not registered.
Unregister(PyObject * func)144 void PythonAPIDispatcher::Unregister(PyObject* func) {
145   DCheckPyGilState();
146   using DispatchTargetPair = std::pair<PySignatureChecker, Safe_PyObjectPtr>;
147   targets_.erase(std::remove_if(targets_.begin(), targets_.end(),
148                                 [func](const DispatchTargetPair& t) {
149                                   return t.second.get() == func;
150                                 }),
151                  targets_.end());
152 }
153 
DebugString() const154 std::string PythonAPIDispatcher::DebugString() const {
155   DCheckPyGilState();
156   std::string out = absl::StrCat("<Dispatch(", api_name_, "): ");
157 
158   const char* sep = "";
159   for (const auto& target : targets_) {
160     Safe_PyObjectPtr target_str(PyObject_Str(target.second.get()));
161     absl::StrAppend(&out, sep, target.first.DebugString(), " -> ",
162                     target_str ? PyUnicode_AsUTF8(target_str.get()) : "?");
163     sep = ", ";
164   }
165   return out;
166 }
167 
PySignatureChecker(std::vector<ParamChecker> parameter_checkers)168 PySignatureChecker::PySignatureChecker(
169     std::vector<ParamChecker> parameter_checkers)
170     : positional_parameter_checkers_(std::move(parameter_checkers)) {
171   // Check less expensive parameters first.
172   std::sort(positional_parameter_checkers_.begin(),
173             positional_parameter_checkers_.end(),
174             [](ParamChecker a, ParamChecker b) {
175               return a.second->cost() < b.second->cost();
176             });
177 }
178 
CheckCanonicalizedArgs(absl::Span<PyObject * > canon_args) const179 bool PySignatureChecker::CheckCanonicalizedArgs(
180     absl::Span<PyObject*> canon_args) const {
181   bool matched_dispatchable_type = false;
182   for (auto& c : positional_parameter_checkers_) {
183     int index = c.first;
184     auto& param_checker = c.second;
185     if (index >= canon_args.size()) {
186       return false;
187     }
188     switch (param_checker->Check(canon_args[index])) {
189       case PyTypeChecker::MatchType::NO_MATCH:
190         return false;
191       case PyTypeChecker::MatchType::MATCH_DISPATCHABLE:
192         matched_dispatchable_type = true;
193         break;
194       case PyTypeChecker::MatchType::MATCH:
195         break;
196     }
197   }
198   return matched_dispatchable_type;
199 }
200 
DebugString() const201 std::string PySignatureChecker::DebugString() const {
202   return absl::StrJoin(positional_parameter_checkers_, ", ",
203                        [](std::string* out, ParamChecker p) {
204                          absl::StrAppend(out, "args[", p.first,
205                                          "]:", p.second->DebugString());
206                        });
207 }
208 
PyInstanceChecker(const std::vector<PyObject * > & py_classes)209 PyInstanceChecker::PyInstanceChecker(const std::vector<PyObject*>& py_classes) {
210   DCheckPyGilState();
211   py_classes_.reserve(py_classes.size());
212   for (PyObject* py_class : py_classes) {
213     py_classes_.emplace_back(py_class);
214     Py_INCREF(py_class);
215   }
216 }
217 
~PyInstanceChecker()218 PyInstanceChecker::~PyInstanceChecker() {
219   DCheckPyGilState();
220   for (const auto& pair : py_class_cache_) {
221     Py_DECREF(pair.first);
222   }
223 }
224 
Check(PyObject * value)225 PyTypeChecker::MatchType PyInstanceChecker::Check(PyObject* value) {
226   DCheckPyGilState();
227   auto* type = Py_TYPE(value);
228   auto it = py_class_cache_.find(type);
229   if (it != py_class_cache_.end()) {
230     return it->second;
231   }
232 
233   MatchType result = MatchType::NO_MATCH;
234   for (const auto& py_class : py_classes_) {
235     int is_instance = PyObject_IsInstance(value, py_class.get());
236     if (is_instance == 1) {
237       if (IsRegisteredDispatchableType(py_class.get())) {
238         result = MatchType::MATCH_DISPATCHABLE;
239         break;
240       } else {
241         result = MatchType::MATCH;
242       }
243     } else if (is_instance < 0) {
244       PyErr_Clear();
245       return MatchType::NO_MATCH;
246     }
247   }
248 
249   if (py_class_cache_.size() < kMaxItemsInCache) {
250     Py_INCREF(type);
251     auto insert_result = py_class_cache_.insert({type, result});
252     if (!insert_result.second) {
253       Py_DECREF(type);  // Result was added by a different thread.
254     }
255   }
256   return result;
257 }
258 
cost() const259 int PyInstanceChecker::cost() const { return py_classes_.size(); }
260 
DebugString() const261 std::string PyInstanceChecker::DebugString() const {
262   DCheckPyGilState();
263   std::vector<const char*> type_names;
264   for (const auto& py_class : py_classes_) {
265     type_names.push_back(
266         reinterpret_cast<PyTypeObject*>(py_class.get())->tp_name);
267   }
268   return absl::StrJoin(
269       py_classes_, ", ", [](std::string* out, const Safe_PyObjectPtr& v) {
270         out->append(reinterpret_cast<PyTypeObject*>(v.get())->tp_name);
271       });
272 }
273 
Check(PyObject * value)274 PyTypeChecker::MatchType PyListChecker::Check(PyObject* value) {
275   DCheckPyGilState();
276   if (!(PyList_Check(value) || PyTuple_Check(value))) {
277     return MatchType::NO_MATCH;
278   }
279 
280   Safe_PyObjectPtr seq(PySequence_Fast(value, ""));
281   if (!seq) {
282     PyErr_Clear();
283     return MatchType::NO_MATCH;  // value is not a sequence.
284   }
285 
286   MatchType result = MatchType::MATCH;
287   for (int i = 0; i < PySequence_Fast_GET_SIZE(seq.get()); ++i) {
288     switch (element_type_->Check(PySequence_Fast_GET_ITEM(seq.get(), i))) {
289       case MatchType::NO_MATCH:
290         return MatchType::NO_MATCH;
291       case MatchType::MATCH_DISPATCHABLE:
292         result = MatchType::MATCH_DISPATCHABLE;
293         break;
294       case MatchType::MATCH:
295         break;
296     }
297   }
298   return result;
299 }
300 
cost() const301 int PyListChecker::cost() const { return 10 * element_type_->cost(); }
302 
DebugString() const303 std::string PyListChecker::DebugString() const {
304   return absl::StrCat("List[", element_type_->DebugString(), "]");
305 }
306 
Check(PyObject * value)307 PyTypeChecker::MatchType PyUnionChecker::Check(PyObject* value) {
308   MatchType result = MatchType::NO_MATCH;
309   for (auto& type_option : options_) {
310     switch (type_option->Check(value)) {
311       case MatchType::MATCH:
312         result = MatchType::MATCH;
313         break;
314       case MatchType::MATCH_DISPATCHABLE:
315         return MatchType::MATCH_DISPATCHABLE;
316       case MatchType::NO_MATCH:
317         break;
318     }
319   }
320   return result;
321 }
322 
cost() const323 int PyUnionChecker::cost() const {
324   int cost = 1;
325   for (auto& type_option : options_) {
326     cost += type_option->cost();
327   }
328   return cost;
329 }
330 
DebugString() const331 std::string PyUnionChecker::DebugString() const {
332   return absl::StrCat("Union[",
333                       absl::StrJoin(options_, ", ",
334                                     [](std::string* out, PyTypeChecker_ptr v) {
335                                       out->append(v->DebugString());
336                                     }),
337                       "]");
338 }
339 
340 }  // namespace py_dispatch
341 }  // namespace tensorflow
342