1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/python/framework/python_api_dispatcher.h"
17
#include <algorithm>
#include <set>
#include <string>
#include <utility>
#include <vector>
19
20 #include "absl/strings/str_join.h"
21 #include "tensorflow/core/platform/logging.h"
22 #include "tensorflow/core/platform/macros.h"
23 #include "tensorflow/python/lib/core/py_util.h"
24 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
25 #include "tensorflow/python/util/util.h"
26
27 namespace tensorflow {
28 namespace py_dispatch {
29
30 namespace {
31
GetRegisteredDispatchableTypes()32 std::vector<Safe_PyObjectPtr>& GetRegisteredDispatchableTypes() {
33 static std::vector<Safe_PyObjectPtr>* registered_dispatchable_types =
34 new std::vector<Safe_PyObjectPtr>();
35 if (registered_dispatchable_types->empty()) {
36 static PyObject* composite_tensor =
37 swig::GetRegisteredPyObject("CompositeTensor");
38 Py_INCREF(composite_tensor);
39 registered_dispatchable_types->push_back(
40 Safe_PyObjectPtr(composite_tensor));
41 }
42 return *registered_dispatchable_types;
43 }
44
45 // Returns true if `py_class` is a registered dispatchable type.
IsRegisteredDispatchableType(PyObject * py_class)46 bool IsRegisteredDispatchableType(PyObject* py_class) {
47 DCheckPyGilState();
48 for (const auto& registered_type : GetRegisteredDispatchableTypes()) {
49 int result = PyObject_IsSubclass(py_class, registered_type.get());
50 if (result > 0) return true;
51 if (result < 0) PyErr_Clear();
52 }
53 return false;
54 }
55
56 // Raises an exception indicating that multiple dispatch targets matched.
RaiseDispatchConflictError(const std::string & api_name,PyObject * selected,PyObject * target)57 Safe_PyObjectPtr RaiseDispatchConflictError(const std::string& api_name,
58 PyObject* selected,
59 PyObject* target) {
60 Safe_PyObjectPtr s1(PyObject_Str(selected));
61 Safe_PyObjectPtr s2(PyObject_Str(target));
62 PyErr_SetString(PyExc_ValueError,
63 absl::StrCat("Multiple dispatch targets that were "
64 "registered with tf.dispatch_for (",
65 s1 ? PyUnicode_AsUTF8(s1.get()) : "?", " and ",
66 s2 ? PyUnicode_AsUTF8(s2.get()) : "?",
67 ") match the arguments to ", api_name)
68 .c_str());
69 return nullptr;
70 }
71
72 } // namespace
73
RegisterDispatchableType(PyObject * py_class)74 bool RegisterDispatchableType(PyObject* py_class) {
75 DCheckPyGilState();
76 if (!PyType_Check(py_class)) {
77 PyErr_SetString(
78 PyExc_ValueError,
79 absl::StrCat("Expected a type object; got object with type ",
80 py_class->ob_type->tp_name)
81 .c_str());
82 return false;
83 }
84 if (IsRegisteredDispatchableType(py_class)) {
85 Safe_PyObjectPtr s(PyObject_Str(py_class));
86 PyErr_SetString(PyExc_ValueError,
87 absl::StrCat("Type ", s ? PyUnicode_AsUTF8(s.get()) : "?",
88 " (or one of its bases clases) has "
89 "already been registered")
90 .c_str());
91 return false;
92 }
93 Py_INCREF(py_class);
94 GetRegisteredDispatchableTypes().push_back(Safe_PyObjectPtr(py_class));
95 return true;
96 }
97
// Creates a dispatcher for the API named `api_name`.
//
// `arg_names` and `defaults` are forwarded to the argument canonicalizer.
// canonicalized_args_storage_ is sized by canonicalizer_.GetArgSize(), and
// canonicalized_args_span_ is a view over that storage. Dispatch() reuses this
// buffer on every call.
// NOTE(review): these initializers read canonicalizer_ and
// canonicalized_args_storage_, so they rely on the members being declared in
// this order in the class definition — confirm in the header.
PythonAPIDispatcher::PythonAPIDispatcher(const std::string& api_name,
                                         absl::Span<const char*> arg_names,
                                         absl::Span<PyObject*> defaults)
    : api_name_(api_name),
      canonicalizer_(arg_names, defaults),
      canonicalized_args_storage_(canonicalizer_.GetArgSize()),
      canonicalized_args_span_(canonicalized_args_storage_) {}
105
Register(PySignatureChecker signature_checker,PyObject * dispatch_target)106 void PythonAPIDispatcher::Register(PySignatureChecker signature_checker,
107 PyObject* dispatch_target) {
108 DCheckPyGilState();
109 Py_INCREF(dispatch_target);
110 targets_.emplace_back(std::move(signature_checker),
111 Safe_PyObjectPtr(dispatch_target));
112 }
113
Dispatch(PyObject * args,PyObject * kwargs)114 Safe_PyObjectPtr PythonAPIDispatcher::Dispatch(PyObject* args,
115 PyObject* kwargs) {
116 DCheckPyGilState();
117 if (kwargs == Py_None) {
118 kwargs = nullptr;
119 }
120 // Canonicalize args (so we don't need to deal with kwargs).
121 if (!canonicalizer_.Canonicalize(args, kwargs, canonicalized_args_span_)) {
122 return nullptr;
123 }
124
125 PyObject* selected = nullptr;
126 for (auto& target : targets_) {
127 if (target.first.CheckCanonicalizedArgs(canonicalized_args_span_)) {
128 if (selected && selected != target.second.get()) {
129 return RaiseDispatchConflictError(api_name_, selected,
130 target.second.get());
131 }
132 selected = target.second.get();
133 }
134 }
135 if (selected) {
136 return Safe_PyObjectPtr(PyObject_Call(selected, args, kwargs));
137 } else {
138 Py_INCREF(Py_NotImplemented);
139 return Safe_PyObjectPtr(Py_NotImplemented);
140 }
141 }
142
143 // TODO(b/194903203) Raise an error if `func` is not registered.
Unregister(PyObject * func)144 void PythonAPIDispatcher::Unregister(PyObject* func) {
145 DCheckPyGilState();
146 using DispatchTargetPair = std::pair<PySignatureChecker, Safe_PyObjectPtr>;
147 targets_.erase(std::remove_if(targets_.begin(), targets_.end(),
148 [func](const DispatchTargetPair& t) {
149 return t.second.get() == func;
150 }),
151 targets_.end());
152 }
153
DebugString() const154 std::string PythonAPIDispatcher::DebugString() const {
155 DCheckPyGilState();
156 std::string out = absl::StrCat("<Dispatch(", api_name_, "): ");
157
158 const char* sep = "";
159 for (const auto& target : targets_) {
160 Safe_PyObjectPtr target_str(PyObject_Str(target.second.get()));
161 absl::StrAppend(&out, sep, target.first.DebugString(), " -> ",
162 target_str ? PyUnicode_AsUTF8(target_str.get()) : "?");
163 sep = ", ";
164 }
165 return out;
166 }
167
PySignatureChecker(std::vector<ParamChecker> parameter_checkers)168 PySignatureChecker::PySignatureChecker(
169 std::vector<ParamChecker> parameter_checkers)
170 : positional_parameter_checkers_(std::move(parameter_checkers)) {
171 // Check less expensive parameters first.
172 std::sort(positional_parameter_checkers_.begin(),
173 positional_parameter_checkers_.end(),
174 [](ParamChecker a, ParamChecker b) {
175 return a.second->cost() < b.second->cost();
176 });
177 }
178
CheckCanonicalizedArgs(absl::Span<PyObject * > canon_args) const179 bool PySignatureChecker::CheckCanonicalizedArgs(
180 absl::Span<PyObject*> canon_args) const {
181 bool matched_dispatchable_type = false;
182 for (auto& c : positional_parameter_checkers_) {
183 int index = c.first;
184 auto& param_checker = c.second;
185 if (index >= canon_args.size()) {
186 return false;
187 }
188 switch (param_checker->Check(canon_args[index])) {
189 case PyTypeChecker::MatchType::NO_MATCH:
190 return false;
191 case PyTypeChecker::MatchType::MATCH_DISPATCHABLE:
192 matched_dispatchable_type = true;
193 break;
194 case PyTypeChecker::MatchType::MATCH:
195 break;
196 }
197 }
198 return matched_dispatchable_type;
199 }
200
DebugString() const201 std::string PySignatureChecker::DebugString() const {
202 return absl::StrJoin(positional_parameter_checkers_, ", ",
203 [](std::string* out, ParamChecker p) {
204 absl::StrAppend(out, "args[", p.first,
205 "]:", p.second->DebugString());
206 });
207 }
208
PyInstanceChecker(const std::vector<PyObject * > & py_classes)209 PyInstanceChecker::PyInstanceChecker(const std::vector<PyObject*>& py_classes) {
210 DCheckPyGilState();
211 py_classes_.reserve(py_classes.size());
212 for (PyObject* py_class : py_classes) {
213 py_classes_.emplace_back(py_class);
214 Py_INCREF(py_class);
215 }
216 }
217
~PyInstanceChecker()218 PyInstanceChecker::~PyInstanceChecker() {
219 DCheckPyGilState();
220 for (const auto& pair : py_class_cache_) {
221 Py_DECREF(pair.first);
222 }
223 }
224
Check(PyObject * value)225 PyTypeChecker::MatchType PyInstanceChecker::Check(PyObject* value) {
226 DCheckPyGilState();
227 auto* type = Py_TYPE(value);
228 auto it = py_class_cache_.find(type);
229 if (it != py_class_cache_.end()) {
230 return it->second;
231 }
232
233 MatchType result = MatchType::NO_MATCH;
234 for (const auto& py_class : py_classes_) {
235 int is_instance = PyObject_IsInstance(value, py_class.get());
236 if (is_instance == 1) {
237 if (IsRegisteredDispatchableType(py_class.get())) {
238 result = MatchType::MATCH_DISPATCHABLE;
239 break;
240 } else {
241 result = MatchType::MATCH;
242 }
243 } else if (is_instance < 0) {
244 PyErr_Clear();
245 return MatchType::NO_MATCH;
246 }
247 }
248
249 if (py_class_cache_.size() < kMaxItemsInCache) {
250 Py_INCREF(type);
251 auto insert_result = py_class_cache_.insert({type, result});
252 if (!insert_result.second) {
253 Py_DECREF(type); // Result was added by a different thread.
254 }
255 }
256 return result;
257 }
258
cost() const259 int PyInstanceChecker::cost() const { return py_classes_.size(); }
260
DebugString() const261 std::string PyInstanceChecker::DebugString() const {
262 DCheckPyGilState();
263 std::vector<const char*> type_names;
264 for (const auto& py_class : py_classes_) {
265 type_names.push_back(
266 reinterpret_cast<PyTypeObject*>(py_class.get())->tp_name);
267 }
268 return absl::StrJoin(
269 py_classes_, ", ", [](std::string* out, const Safe_PyObjectPtr& v) {
270 out->append(reinterpret_cast<PyTypeObject*>(v.get())->tp_name);
271 });
272 }
273
Check(PyObject * value)274 PyTypeChecker::MatchType PyListChecker::Check(PyObject* value) {
275 DCheckPyGilState();
276 if (!(PyList_Check(value) || PyTuple_Check(value))) {
277 return MatchType::NO_MATCH;
278 }
279
280 Safe_PyObjectPtr seq(PySequence_Fast(value, ""));
281 if (!seq) {
282 PyErr_Clear();
283 return MatchType::NO_MATCH; // value is not a sequence.
284 }
285
286 MatchType result = MatchType::MATCH;
287 for (int i = 0; i < PySequence_Fast_GET_SIZE(seq.get()); ++i) {
288 switch (element_type_->Check(PySequence_Fast_GET_ITEM(seq.get(), i))) {
289 case MatchType::NO_MATCH:
290 return MatchType::NO_MATCH;
291 case MatchType::MATCH_DISPATCHABLE:
292 result = MatchType::MATCH_DISPATCHABLE;
293 break;
294 case MatchType::MATCH:
295 break;
296 }
297 }
298 return result;
299 }
300
cost() const301 int PyListChecker::cost() const { return 10 * element_type_->cost(); }
302
DebugString() const303 std::string PyListChecker::DebugString() const {
304 return absl::StrCat("List[", element_type_->DebugString(), "]");
305 }
306
Check(PyObject * value)307 PyTypeChecker::MatchType PyUnionChecker::Check(PyObject* value) {
308 MatchType result = MatchType::NO_MATCH;
309 for (auto& type_option : options_) {
310 switch (type_option->Check(value)) {
311 case MatchType::MATCH:
312 result = MatchType::MATCH;
313 break;
314 case MatchType::MATCH_DISPATCHABLE:
315 return MatchType::MATCH_DISPATCHABLE;
316 case MatchType::NO_MATCH:
317 break;
318 }
319 }
320 return result;
321 }
322
cost() const323 int PyUnionChecker::cost() const {
324 int cost = 1;
325 for (auto& type_option : options_) {
326 cost += type_option->cost();
327 }
328 return cost;
329 }
330
DebugString() const331 std::string PyUnionChecker::DebugString() const {
332 return absl::StrCat("Union[",
333 absl::StrJoin(options_, ", ",
334 [](std::string* out, PyTypeChecker_ptr v) {
335 out->append(v->DebugString());
336 }),
337 "]");
338 }
339
340 } // namespace py_dispatch
341 } // namespace tensorflow
342