1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/python/util/util.h"
16
17 #include <functional>
18 #include <memory>
19 #include <unordered_map>
20 #include <vector>
21
22 #include "absl/memory/memory.h"
23 #include "tensorflow/core/lib/gtl/map_util.h"
24 #include "tensorflow/core/lib/strings/strcat.h"
25 #include "tensorflow/core/platform/logging.h"
26 #include "tensorflow/core/platform/mutex.h"
27 #include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
28
29 namespace tensorflow {
30 namespace swig {
31
32 namespace {
33 string PyObjectToString(PyObject* o);
34 } // namespace
35
RegisteredPyObjectMap()36 std::unordered_map<string, PyObject*>* RegisteredPyObjectMap() {
37 static auto* m = new std::unordered_map<string, PyObject*>();
38 return m;
39 }
40
GetRegisteredPyObject(const string & name)41 PyObject* GetRegisteredPyObject(const string& name) {
42 const auto* m = RegisteredPyObjectMap();
43 auto it = m->find(name);
44 if (it == m->end()) {
45 PyErr_SetString(PyExc_TypeError,
46 tensorflow::strings::StrCat("No object with name ", name,
47 " has been registered.")
48 .c_str());
49 return nullptr;
50 }
51 return it->second;
52 }
53
RegisterType(PyObject * type_name,PyObject * type)54 PyObject* RegisterType(PyObject* type_name, PyObject* type) {
55 if (!PyType_Check(type)) {
56 PyErr_SetString(PyExc_TypeError,
57 tensorflow::strings::StrCat("Expecting a type, got ",
58 Py_TYPE(type)->tp_name)
59 .c_str());
60 return nullptr;
61 }
62 return RegisterPyObject(type_name, type);
63 }
64
RegisterPyObject(PyObject * name,PyObject * value)65 PyObject* RegisterPyObject(PyObject* name, PyObject* value) {
66 string key;
67 if (PyBytes_Check(name)) {
68 key = PyBytes_AsString(name);
69 #if PY_MAJOR_VERSION >= 3
70 } else if (PyUnicode_Check(name)) {
71 key = PyUnicode_AsUTF8(name);
72 #endif
73 } else {
74 PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
75 "Expected name to be a str, got",
76 PyObjectToString(name))
77 .c_str());
78 return nullptr;
79 }
80
81 auto* m = RegisteredPyObjectMap();
82 if (m->find(key) != m->end()) {
83 PyErr_SetString(PyExc_TypeError, tensorflow::strings::StrCat(
84 "Value already registered for ", key)
85 .c_str());
86 return nullptr;
87 }
88
89 Py_INCREF(value);
90 m->emplace(key, value);
91
92 Py_RETURN_NONE;
93 }
94
95 namespace {
// Upper bound on the number of distinct Python types that a single
// CachedTypeCheck instance will remember.
constexpr int kMaxItemsInCache = 1024;
97
IsString(PyObject * o)98 bool IsString(PyObject* o) {
99 return PyBytes_Check(o) ||
100 #if PY_MAJOR_VERSION < 3
101 PyString_Check(o) ||
102 #endif
103 PyUnicode_Check(o);
104 }
105
106 // Equivalent to Python's 'o.__class__.__name__'
107 // Note that '__class__' attribute is set only in new-style classes.
108 // A lot of tensorflow code uses __class__ without checks, so it seems like
109 // we only support new-style classes.
GetClassName(PyObject * o)110 StringPiece GetClassName(PyObject* o) {
111 // __class__ is equivalent to type() for new style classes.
112 // type() is equivalent to PyObject_Type()
113 // (https://docs.python.org/3.5/c-api/object.html#c.PyObject_Type)
114 // PyObject_Type() is equivalent to o->ob_type except for Py_INCREF, which
115 // we don't need here.
116 PyTypeObject* type = o->ob_type;
117
118 // __name__ is the value of `tp_name` after the last '.'
119 // (https://docs.python.org/2/c-api/typeobj.html#c.PyTypeObject.tp_name)
120 StringPiece name(type->tp_name);
121 size_t pos = name.rfind('.');
122 if (pos != StringPiece::npos) {
123 name.remove_prefix(pos + 1);
124 }
125 return name;
126 }
127
PyObjectToString(PyObject * o)128 string PyObjectToString(PyObject* o) {
129 if (o == nullptr) {
130 return "<null object>";
131 }
132 PyObject* str = PyObject_Str(o);
133 if (str) {
134 #if PY_MAJOR_VERSION < 3
135 string s(PyString_AS_STRING(str));
136 #else
137 string s(PyUnicode_AsUTF8(str));
138 #endif
139 Py_DECREF(str);
140 return tensorflow::strings::StrCat("type=", GetClassName(o), " str=", s);
141 } else {
142 return "<failed to execute str() on object>";
143 }
144 }
145
146 class CachedTypeCheck {
147 public:
CachedTypeCheck(std::function<int (PyObject *)> ternary_predicate)148 explicit CachedTypeCheck(std::function<int(PyObject*)> ternary_predicate)
149 : ternary_predicate_(std::move(ternary_predicate)) {}
150
~CachedTypeCheck()151 ~CachedTypeCheck() {
152 mutex_lock l(type_to_sequence_map_mu_);
153 for (const auto& pair : type_to_sequence_map_) {
154 Py_DECREF(pair.first);
155 }
156 }
157
158 // Caches successful executions of the one-argument (PyObject*) callable
159 // "ternary_predicate" based on the type of "o". -1 from the callable
160 // indicates an unsuccessful check (not cached), 0 indicates that "o"'s type
161 // does not match the predicate, and 1 indicates that it does. Used to avoid
162 // calling back into Python for expensive isinstance checks.
CachedLookup(PyObject * o)163 int CachedLookup(PyObject* o) {
164 // Try not to return to Python - see if the type has already been seen
165 // before.
166
167 auto* type = Py_TYPE(o);
168
169 {
170 tf_shared_lock l(type_to_sequence_map_mu_);
171 auto it = type_to_sequence_map_.find(type);
172 if (it != type_to_sequence_map_.end()) {
173 return it->second;
174 }
175 }
176
177 int check_result = ternary_predicate_(o);
178
179 if (check_result == -1) {
180 return -1; // Type check error, not cached.
181 }
182
183 // NOTE: This is never decref'd as long as the object lives, which is likely
184 // forever, but we don't want the type to get deleted as long as it is in
185 // the map. This should not be too much of a leak, as there should only be a
186 // relatively small number of types in the map, and an even smaller number
187 // that are eligible for decref. As a precaution, we limit the size of the
188 // map to 1024.
189 {
190 mutex_lock l(type_to_sequence_map_mu_);
191 if (type_to_sequence_map_.size() < kMaxItemsInCache) {
192 Py_INCREF(type);
193 auto insert_result = type_to_sequence_map_.insert({type, check_result});
194 if (!insert_result.second) {
195 // The type was added to the cache by a concurrent thread after we
196 // looked it up above.
197 Py_DECREF(type);
198 }
199 }
200 }
201
202 return check_result;
203 }
204
205 private:
206 std::function<int(PyObject*)> ternary_predicate_;
207 mutex type_to_sequence_map_mu_;
208 std::unordered_map<PyTypeObject*, bool> type_to_sequence_map_
209 TF_GUARDED_BY(type_to_sequence_map_mu_);
210 };
211
212 // Returns 1 if 'obj' is an instance of 'type_name'
213 // Returns 0 otherwise.
214 // Returns -1 if an error occurred (e.g., if 'type_name' is not registered.)
IsInstanceOfRegisteredType(PyObject * obj,const char * type_name)215 int IsInstanceOfRegisteredType(PyObject* obj, const char* type_name) {
216 PyObject* type_obj = GetRegisteredPyObject(type_name);
217 if (TF_PREDICT_FALSE(type_obj == nullptr)) {
218 PyErr_SetString(PyExc_RuntimeError,
219 tensorflow::strings::StrCat(
220 type_name,
221 " type has not been set. "
222 "Please register the type with the identifier \"",
223 type_name, "\" using RegisterType.")
224 .c_str());
225 return -1;
226 }
227 return PyObject_IsInstance(obj, type_obj);
228 }
229
230 // Returns 1 if `o` is considered a mapping for the purposes of Flatten().
231 // Returns 0 otherwise.
232 // Returns -1 if an error occurred.
IsMappingHelper(PyObject * o)233 int IsMappingHelper(PyObject* o) {
234 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
235 return IsInstanceOfRegisteredType(to_check, "Mapping");
236 });
237 if (PyDict_Check(o)) return true;
238 return check_cache->CachedLookup(o);
239 }
240
241 // Returns 1 if `o` is considered a mutable mapping for the purposes of
242 // Flatten(). Returns 0 otherwise. Returns -1 if an error occurred.
IsMutableMappingHelper(PyObject * o)243 int IsMutableMappingHelper(PyObject* o) {
244 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
245 return IsInstanceOfRegisteredType(to_check, "MutableMapping");
246 });
247 if (PyDict_Check(o)) return true;
248 return check_cache->CachedLookup(o);
249 }
250
251 // Returns 1 if `o` is considered a mapping view for the purposes of Flatten().
252 // Returns 0 otherwise.
253 // Returns -1 if an error occurred.
IsMappingViewHelper(PyObject * o)254 int IsMappingViewHelper(PyObject* o) {
255 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
256 return IsInstanceOfRegisteredType(to_check, "MappingView");
257 });
258 return check_cache->CachedLookup(o);
259 }
260
261 // Returns 1 if `o` is considered an object proxy
262 // Returns 0 otherwise.
263 // Returns -1 if an error occurred.
IsObjectProxy(PyObject * o)264 int IsObjectProxy(PyObject* o) {
265 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
266 return IsInstanceOfRegisteredType(to_check, "ObjectProxy");
267 });
268 return check_cache->CachedLookup(o);
269 }
270
271 // Returns 1 if `o` is an instance of attrs-decorated class.
272 // Returns 0 otherwise.
IsAttrsHelper(PyObject * o)273 int IsAttrsHelper(PyObject* o) {
274 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
275 Safe_PyObjectPtr cls(PyObject_GetAttrString(to_check, "__class__"));
276 if (cls) {
277 return PyObject_HasAttrString(cls.get(), "__attrs_attrs__");
278 }
279
280 // PyObject_GetAttrString returns null on error
281 PyErr_Clear();
282 return 0;
283 });
284 return check_cache->CachedLookup(o);
285 }
286
287 // Returns 1 if `o` is an object of type IndexedSlices.
288 // Returns 0 otherwise.
289 // Returns -1 if an error occurred.
IsIndexedSlicesHelper(PyObject * o)290 int IsIndexedSlicesHelper(PyObject* o) {
291 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
292 return IsInstanceOfRegisteredType(to_check, "IndexedSlices");
293 });
294 return check_cache->CachedLookup(o);
295 }
296
297 // Returns 1 if `o` is a Tensor.
298 // Returns 0 otherwise.
299 // Returns -1 if an error occurred.
IsTensorHelper(PyObject * o)300 int IsTensorHelper(PyObject* o) {
301 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
302 return IsInstanceOfRegisteredType(to_check, "Tensor");
303 });
304 return check_cache->CachedLookup(o);
305 }
306
307 // Returns 1 if `o` is a TensorSpec.
308 // Returns 0 otherwise.
309 // Returns -1 if an error occurred.
IsTensorSpecHelper(PyObject * o)310 int IsTensorSpecHelper(PyObject* o) {
311 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
312 return IsInstanceOfRegisteredType(to_check, "TensorSpec");
313 });
314 return check_cache->CachedLookup(o);
315 }
316
317 // Returns 1 if `o` is an EagerTensor.
318 // Returns 0 otherwise.
319 // Returns -1 if an error occurred.
IsEagerTensorHelper(PyObject * o)320 int IsEagerTensorHelper(PyObject* o) {
321 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
322 return IsInstanceOfRegisteredType(to_check, "EagerTensor");
323 });
324 return check_cache->CachedLookup(o);
325 }
326
327 // Returns 1 if `o` is a ResourceVariable.
328 // Returns 0 otherwise.
329 // Returns -1 if an error occurred.
IsResourceVariableHelper(PyObject * o)330 int IsResourceVariableHelper(PyObject* o) {
331 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
332 return IsInstanceOfRegisteredType(to_check, "ResourceVariable");
333 });
334 return check_cache->CachedLookup(o);
335 }
336
337 // Returns 1 if `o` is a OwnedIterator.
338 // Returns 0 otherwise.
339 // Returns -1 if an error occurred.
IsOwnedIteratorHelper(PyObject * o)340 int IsOwnedIteratorHelper(PyObject* o) {
341 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
342 return IsInstanceOfRegisteredType(to_check, "OwnedIterator");
343 });
344 return check_cache->CachedLookup(o);
345 }
346
347 // Returns 1 if `o` is a ResourceVariable.
348 // Returns 0 otherwise.
349 // Returns -1 if an error occurred.
IsVariableHelper(PyObject * o)350 int IsVariableHelper(PyObject* o) {
351 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
352 return IsInstanceOfRegisteredType(to_check, "Variable");
353 });
354 return check_cache->CachedLookup(o);
355 }
356
357 // Returns 1 if `o` is considered a sequence for the purposes of Flatten().
358 // Returns 0 otherwise.
359 // Returns -1 if an error occurred.
IsNestedHelper(PyObject * o)360 int IsNestedHelper(PyObject* o) {
361 // We treat dicts and other mappings as special cases of sequences.
362 if (IsMappingHelper(o)) return true;
363 if (IsMappingViewHelper(o)) return true;
364 if (IsAttrsHelper(o)) return true;
365
366 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
367 int is_instance = IsInstanceOfRegisteredType(to_check, "Sequence");
368
369 // Don't cache a failed is_instance check.
370 if (is_instance == -1) return -1;
371
372 return static_cast<int>(is_instance != 0 && !IsString(to_check));
373 });
374 return check_cache->CachedLookup(o);
375 }
376
377 // Returns 1 if `o`'s class has a `__tf_dispatch__` attribute.
378 // Returns 0 otherwise.
IsDispatchableHelper(PyObject * o)379 int IsDispatchableHelper(PyObject* o) {
380 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
381 return PyObject_HasAttrString(
382 reinterpret_cast<PyObject*>(to_check->ob_type), "__tf_dispatch__");
383 });
384 return check_cache->CachedLookup(o);
385 }
386
387 // ValueIterator interface
388 class ValueIterator {
389 public:
~ValueIterator()390 virtual ~ValueIterator() {}
391 virtual Safe_PyObjectPtr next() = 0;
392
valid() const393 bool valid() const { return is_valid_; }
394
395 protected:
invalidate()396 void invalidate() { is_valid_ = false; }
397
398 private:
399 bool is_valid_ = true;
400 };
401
402 using ValueIteratorPtr = std::unique_ptr<ValueIterator>;
403
404 // Iterate through dictionaries in a deterministic order by sorting the
405 // keys. Notice this means that we ignore the original order of
406 // `OrderedDict` instances. This is intentional, to avoid potential
407 // bugs caused by mixing ordered and plain dicts (e.g., flattening
408 // a dict but using a corresponding `OrderedDict` to pack it back).
409 class DictValueIterator : public ValueIterator {
410 public:
DictValueIterator(PyObject * dict)411 explicit DictValueIterator(PyObject* dict)
412 : dict_(dict), keys_(PyDict_Keys(dict)) {
413 if (PyList_Sort(keys_.get()) == -1) {
414 invalidate();
415 } else {
416 iter_.reset(PyObject_GetIter(keys_.get()));
417 }
418 }
419
next()420 Safe_PyObjectPtr next() override {
421 Safe_PyObjectPtr result;
422 Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
423 if (key) {
424 // PyDict_GetItem returns a borrowed reference.
425 PyObject* elem = PyDict_GetItem(dict_, key.get());
426 if (elem) {
427 Py_INCREF(elem);
428 result.reset(elem);
429 } else {
430 PyErr_SetString(PyExc_RuntimeError,
431 "Dictionary was modified during iteration over it");
432 }
433 }
434 return result;
435 }
436
437 private:
438 PyObject* dict_;
439 Safe_PyObjectPtr keys_;
440 Safe_PyObjectPtr iter_;
441 };
442
443 // Iterate over mapping objects by sorting the keys first
444 class MappingValueIterator : public ValueIterator {
445 public:
MappingValueIterator(PyObject * mapping)446 explicit MappingValueIterator(PyObject* mapping)
447 : mapping_(mapping), keys_(MappingKeys(mapping)) {
448 if (!keys_ || PyList_Sort(keys_.get()) == -1) {
449 invalidate();
450 } else {
451 iter_.reset(PyObject_GetIter(keys_.get()));
452 }
453 }
454
next()455 Safe_PyObjectPtr next() override {
456 Safe_PyObjectPtr result;
457 Safe_PyObjectPtr key(PyIter_Next(iter_.get()));
458 if (key) {
459 // Unlike PyDict_GetItem, PyObject_GetItem returns a new reference.
460 PyObject* elem = PyObject_GetItem(mapping_, key.get());
461 if (elem) {
462 result.reset(elem);
463 } else {
464 PyErr_SetString(PyExc_RuntimeError,
465 "Mapping was modified during iteration over it");
466 }
467 }
468 return result;
469 }
470
471 private:
472 PyObject* mapping_;
473 Safe_PyObjectPtr keys_;
474 Safe_PyObjectPtr iter_;
475 };
476
477 // Iterate over a sequence, by index.
478 class SequenceValueIterator : public ValueIterator {
479 public:
SequenceValueIterator(PyObject * iterable)480 explicit SequenceValueIterator(PyObject* iterable)
481 : seq_(PySequence_Fast(iterable, "")),
482 size_(seq_.get() ? PySequence_Fast_GET_SIZE(seq_.get()) : 0),
483 index_(0) {}
484
next()485 Safe_PyObjectPtr next() override {
486 Safe_PyObjectPtr result;
487 if (index_ < size_) {
488 // PySequence_Fast_GET_ITEM returns a borrowed reference.
489 PyObject* elem = PySequence_Fast_GET_ITEM(seq_.get(), index_);
490 ++index_;
491 if (elem) {
492 Py_INCREF(elem);
493 result.reset(elem);
494 }
495 }
496
497 return result;
498 }
499
500 private:
501 Safe_PyObjectPtr seq_;
502 const Py_ssize_t size_;
503 Py_ssize_t index_;
504 };
505
506 // Iterator that just returns a single python object.
507 class SingleValueIterator : public ValueIterator {
508 public:
SingleValueIterator(PyObject * x)509 explicit SingleValueIterator(PyObject* x) : x_(x) { Py_INCREF(x); }
510
next()511 Safe_PyObjectPtr next() override { return std::move(x_); }
512
513 private:
514 Safe_PyObjectPtr x_;
515 };
516
517 // Returns nullptr (to raise an exception) when next() is called. Caller
518 // should have already called PyErr_SetString.
519 class ErrorValueIterator : public ValueIterator {
520 public:
ErrorValueIterator()521 ErrorValueIterator() {}
next()522 Safe_PyObjectPtr next() override { return nullptr; }
523 };
524
525 class AttrsValueIterator : public ValueIterator {
526 public:
AttrsValueIterator(PyObject * nested)527 explicit AttrsValueIterator(PyObject* nested) : nested_(nested) {
528 Py_INCREF(nested);
529 cls_.reset(PyObject_GetAttrString(nested_.get(), "__class__"));
530 if (cls_) {
531 attrs_.reset(PyObject_GetAttrString(cls_.get(), "__attrs_attrs__"));
532 if (attrs_) {
533 iter_.reset(PyObject_GetIter(attrs_.get()));
534 }
535 }
536 if (!iter_ || PyErr_Occurred()) invalidate();
537 }
538
next()539 Safe_PyObjectPtr next() override {
540 Safe_PyObjectPtr result;
541 Safe_PyObjectPtr item(PyIter_Next(iter_.get()));
542 if (item) {
543 Safe_PyObjectPtr name(PyObject_GetAttrString(item.get(), "name"));
544 result.reset(PyObject_GetAttr(nested_.get(), name.get()));
545 }
546
547 return result;
548 }
549
550 private:
551 Safe_PyObjectPtr nested_;
552 Safe_PyObjectPtr cls_;
553 Safe_PyObjectPtr attrs_;
554 Safe_PyObjectPtr iter_;
555 };
556
IsSparseTensorValueType(PyObject * o)557 bool IsSparseTensorValueType(PyObject* o) {
558 PyObject* sparse_tensor_value_type =
559 GetRegisteredPyObject("SparseTensorValue");
560 if (TF_PREDICT_FALSE(sparse_tensor_value_type == nullptr)) {
561 return false;
562 }
563
564 return PyObject_TypeCheck(
565 o, reinterpret_cast<PyTypeObject*>(sparse_tensor_value_type)) == 1;
566 }
567
568 // Returns 1 if `o` is an instance of CompositeTensor.
569 // Returns 0 otherwise.
570 // Returns -1 if an error occurred.
IsCompositeTensorHelper(PyObject * o)571 bool IsCompositeTensorHelper(PyObject* o) {
572 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
573 return IsInstanceOfRegisteredType(to_check, "CompositeTensor");
574 });
575 return check_cache->CachedLookup(o);
576 }
577
578 // Returns 1 if `o` is an instance of TypeSpec, but is not TensorSpec or
579 // VariableSpec.
580 // Returns 0 otherwise.
581 // Returns -1 if an error occurred.
IsTypeSpecHelper(PyObject * o)582 bool IsTypeSpecHelper(PyObject* o) {
583 static auto* const check_cache = new CachedTypeCheck([](PyObject* to_check) {
584 int is_type_spec = IsInstanceOfRegisteredType(to_check, "TypeSpec");
585 int is_dense_spec = (IsInstanceOfRegisteredType(to_check, "TensorSpec") ||
586 IsInstanceOfRegisteredType(to_check, "VariableSpec"));
587 if ((is_type_spec == -1) || (is_dense_spec == -1)) return -1;
588 return static_cast<int>(is_type_spec && !is_dense_spec);
589 });
590 return check_cache->CachedLookup(o);
591 }
592
593 // Returns 1 if `o` is a (non-string) sequence or CompositeTensor or
594 // (non-TensorSpec and non-VariableSpec) TypeSpec.
595 // Returns 0 otherwise.
596 // Returns -1 if an error occurred.
IsNestedOrCompositeHelper(PyObject * o)597 int IsNestedOrCompositeHelper(PyObject* o) {
598 int is_nested = IsNestedHelper(o);
599 int is_composite = IsCompositeTensorHelper(o);
600 int is_type_spec = IsTypeSpecHelper(o);
601 if ((is_nested == -1) || (is_composite == -1) || (is_type_spec == -1)) {
602 return -1;
603 }
604 return is_nested || is_composite || is_type_spec;
605 }
606
IsNestedForDataHelper(PyObject * o)607 int IsNestedForDataHelper(PyObject* o) {
608 return IsNestedHelper(o) == 1 && !PyList_Check(o) &&
609 !IsSparseTensorValueType(o);
610 }
611
GetValueIterator(PyObject * nested)612 ValueIteratorPtr GetValueIterator(PyObject* nested) {
613 if (PyDict_Check(nested)) {
614 return absl::make_unique<DictValueIterator>(nested);
615 } else if (IsMappingHelper(nested)) {
616 return absl::make_unique<MappingValueIterator>(nested);
617 } else if (IsAttrsHelper(nested)) {
618 return absl::make_unique<AttrsValueIterator>(nested);
619 } else {
620 return absl::make_unique<SequenceValueIterator>(nested);
621 }
622 }
623
624 // Similar to above, just specialized for the functions in the data package.
GetValueIteratorForData(PyObject * nested)625 ValueIteratorPtr GetValueIteratorForData(PyObject* nested) {
626 if (PyDict_Check(nested)) {
627 return absl::make_unique<DictValueIterator>(nested);
628 } else if (IsMappingHelper(nested)) {
629 return absl::make_unique<MappingValueIterator>(nested);
630 } else if (IsAttrsHelper(nested)) {
631 return absl::make_unique<AttrsValueIterator>(nested);
632 } else if (IsSparseTensorValueType(nested)) {
633 return absl::make_unique<SingleValueIterator>(nested);
634 } else {
635 return absl::make_unique<SequenceValueIterator>(nested);
636 }
637 }
638
639 // Similar to GetValueIterator above, but expands CompositeTensor and TypeSpec.
GetValueIteratorForComposite(PyObject * nested)640 ValueIteratorPtr GetValueIteratorForComposite(PyObject* nested) {
641 if (IsCompositeTensor(nested)) {
642 Safe_PyObjectPtr spec(PyObject_GetAttrString(nested, "_type_spec"));
643 if (PyErr_Occurred() || !spec) {
644 return absl::make_unique<ErrorValueIterator>();
645 }
646
647 static char to_components[] = "_to_components";
648 static char argspec[] = "(O)";
649 Safe_PyObjectPtr components(
650 PyObject_CallMethod(spec.get(), to_components, argspec, nested));
651 if (PyErr_Occurred() || components == nullptr) {
652 return absl::make_unique<ErrorValueIterator>();
653 }
654 return absl::make_unique<SingleValueIterator>(components.get());
655 }
656
657 if (IsTypeSpec(nested)) {
658 Safe_PyObjectPtr specs(PyObject_GetAttrString(nested, "_component_specs"));
659 if (PyErr_Occurred() || specs == nullptr) {
660 return absl::make_unique<ErrorValueIterator>();
661 }
662 return absl::make_unique<SingleValueIterator>(specs.get());
663 }
664
665 return GetValueIterator(nested);
666 }
667
FlattenHelper(PyObject * nested,PyObject * list,const std::function<int (PyObject *)> & is_nested_helper,const std::function<ValueIteratorPtr (PyObject *)> & value_iterator_getter)668 bool FlattenHelper(
669 PyObject* nested, PyObject* list,
670 const std::function<int(PyObject*)>& is_nested_helper,
671 const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter) {
672 // if nested is not a sequence, append itself and exit
673 int is_nested = is_nested_helper(nested);
674 if (is_nested == -1) return false;
675 if (!is_nested) {
676 return PyList_Append(list, nested) != -1;
677 }
678
679 ValueIteratorPtr iter = value_iterator_getter(nested);
680 if (!iter->valid()) return false;
681
682 for (Safe_PyObjectPtr item = iter->next(); item; item = iter->next()) {
683 if (Py_EnterRecursiveCall(" in flatten")) {
684 return false;
685 }
686 const bool success = FlattenHelper(item.get(), list, is_nested_helper,
687 value_iterator_getter);
688 Py_LeaveRecursiveCall();
689 if (!success) {
690 return false;
691 }
692 }
693 return true;
694 }
695
696 // Sets error using keys of 'dict1' and 'dict2'.
697 // 'dict1' and 'dict2' are assumed to be Python dictionaries.
SetDifferentKeysError(PyObject * dict1,PyObject * dict2,string * error_msg,bool * is_type_error)698 void SetDifferentKeysError(PyObject* dict1, PyObject* dict2, string* error_msg,
699 bool* is_type_error) {
700 Safe_PyObjectPtr k1(MappingKeys(dict1));
701 if (PyErr_Occurred() || k1.get() == nullptr) {
702 *error_msg =
703 ("The two dictionaries don't have the same set of keys. Failed to "
704 "fetch keys.");
705 return;
706 }
707 Safe_PyObjectPtr k2(MappingKeys(dict2));
708 if (PyErr_Occurred() || k2.get() == nullptr) {
709 *error_msg =
710 ("The two dictionaries don't have the same set of keys. Failed to "
711 "fetch keys.");
712 return;
713 }
714 *is_type_error = false;
715 *error_msg = tensorflow::strings::StrCat(
716 "The two dictionaries don't have the same set of keys. "
717 "First structure has keys ",
718 PyObjectToString(k1.get()), ", while second structure has keys ",
719 PyObjectToString(k2.get()));
720 }
721
722 // Returns true iff there were no "internal" errors. In other words,
723 // errors that has nothing to do with structure checking.
724 // If an "internal" error occurred, the appropriate Python error will be
725 // set and the caller can propage it directly to the user.
726 //
727 // Both `error_msg` and `is_type_error` must be non-null. `error_msg` must
728 // be empty.
729 // Leaves `error_msg` empty if structures matched. Else, fills `error_msg`
730 // with appropriate error and sets `is_type_error` to true iff
731 // the error to be raised should be TypeError.
AssertSameStructureHelper(PyObject * o1,PyObject * o2,bool check_types,string * error_msg,bool * is_type_error,const std::function<int (PyObject *)> & is_nested_helper,const std::function<ValueIteratorPtr (PyObject *)> & value_iterator_getter,bool check_composite_tensor_type_spec)732 bool AssertSameStructureHelper(
733 PyObject* o1, PyObject* o2, bool check_types, string* error_msg,
734 bool* is_type_error, const std::function<int(PyObject*)>& is_nested_helper,
735 const std::function<ValueIteratorPtr(PyObject*)>& value_iterator_getter,
736 bool check_composite_tensor_type_spec) {
737 DCHECK(error_msg);
738 DCHECK(is_type_error);
739 const bool is_nested1 = is_nested_helper(o1);
740 const bool is_nested2 = is_nested_helper(o2);
741 if (PyErr_Occurred()) return false;
742 if (is_nested1 != is_nested2) {
743 string seq_str = is_nested1 ? PyObjectToString(o1) : PyObjectToString(o2);
744 string non_seq_str =
745 is_nested1 ? PyObjectToString(o2) : PyObjectToString(o1);
746 *is_type_error = false;
747 *error_msg = tensorflow::strings::StrCat(
748 "Substructure \"", seq_str, "\" is a sequence, while substructure \"",
749 non_seq_str, "\" is not");
750 return true;
751 }
752
753 // Got to objects that are considered non-sequences. Note that in tf.data
754 // use case lists and sparse_tensors are not considered sequences. So finished
755 // checking, structures are the same.
756 if (!is_nested1) return true;
757
758 if (check_types) {
759 // Treat wrapped tuples as tuples.
760 tensorflow::Safe_PyObjectPtr o1_wrapped;
761 if (IsObjectProxy(o1)) {
762 o1_wrapped.reset(PyObject_GetAttrString(o1, "__wrapped__"));
763 o1 = o1_wrapped.get();
764 }
765 tensorflow::Safe_PyObjectPtr o2_wrapped;
766 if (IsObjectProxy(o2)) {
767 o2_wrapped.reset(PyObject_GetAttrString(o2, "__wrapped__"));
768 o2 = o2_wrapped.get();
769 }
770
771 const PyTypeObject* type1 = o1->ob_type;
772 const PyTypeObject* type2 = o2->ob_type;
773
774 // We treat two different namedtuples with identical name and fields
775 // as having the same type.
776 const PyObject* o1_tuple = IsNamedtuple(o1, false);
777 if (o1_tuple == nullptr) return false;
778 const PyObject* o2_tuple = IsNamedtuple(o2, false);
779 if (o2_tuple == nullptr) {
780 Py_DECREF(o1_tuple);
781 return false;
782 }
783 bool both_tuples = o1_tuple == Py_True && o2_tuple == Py_True;
784 Py_DECREF(o1_tuple);
785 Py_DECREF(o2_tuple);
786
787 if (both_tuples) {
788 const PyObject* same_tuples = SameNamedtuples(o1, o2);
789 if (same_tuples == nullptr) return false;
790 bool not_same_tuples = same_tuples != Py_True;
791 Py_DECREF(same_tuples);
792 if (not_same_tuples) {
793 *is_type_error = true;
794 *error_msg = tensorflow::strings::StrCat(
795 "The two namedtuples don't have the same sequence type. "
796 "First structure ",
797 PyObjectToString(o1), " has type ", type1->tp_name,
798 ", while second structure ", PyObjectToString(o2), " has type ",
799 type2->tp_name);
800 return true;
801 }
802 } else if (type1 != type2
803 /* If both sequences are list types, don't complain. This allows
804 one to be a list subclass (e.g. _ListWrapper used for
805 automatic dependency tracking.) */
806 && !(PyList_Check(o1) && PyList_Check(o2))
807 /* Two mapping types will also compare equal, making _DictWrapper
808 and dict compare equal. */
809 && !(IsMappingHelper(o1) && IsMappingHelper(o2))
810 /* For CompositeTensor & TypeSpec, we check below. */
811 && !(check_composite_tensor_type_spec &&
812 (IsCompositeTensor(o1) || IsCompositeTensor(o2)) &&
813 (IsTypeSpec(o1) || IsTypeSpec(o2)))) {
814 *is_type_error = true;
815 *error_msg = tensorflow::strings::StrCat(
816 "The two namedtuples don't have the same sequence type. "
817 "First structure ",
818 PyObjectToString(o1), " has type ", type1->tp_name,
819 ", while second structure ", PyObjectToString(o2), " has type ",
820 type2->tp_name);
821 return true;
822 }
823
824 if (PyDict_Check(o1) && PyDict_Check(o2)) {
825 if (PyDict_Size(o1) != PyDict_Size(o2)) {
826 SetDifferentKeysError(o1, o2, error_msg, is_type_error);
827 return true;
828 }
829
830 PyObject* key;
831 Py_ssize_t pos = 0;
832 while (PyDict_Next(o1, &pos, &key, nullptr)) {
833 if (PyDict_GetItem(o2, key) == nullptr) {
834 SetDifferentKeysError(o1, o2, error_msg, is_type_error);
835 return true;
836 }
837 }
838 } else if (IsMappingHelper(o1)) {
839 // Fallback for custom mapping types. Instead of using PyDict methods
840 // which stay in C, we call iter(o1).
841 if (PyMapping_Size(o1) != PyMapping_Size(o2)) {
842 SetDifferentKeysError(o1, o2, error_msg, is_type_error);
843 return true;
844 }
845
846 Safe_PyObjectPtr iter(PyObject_GetIter(o1));
847 PyObject* key;
848 while ((key = PyIter_Next(iter.get())) != nullptr) {
849 if (!PyMapping_HasKey(o2, key)) {
850 SetDifferentKeysError(o1, o2, error_msg, is_type_error);
851 Py_DECREF(key);
852 return true;
853 }
854 Py_DECREF(key);
855 }
856 }
857 }
858
859 if (check_composite_tensor_type_spec &&
860 (IsCompositeTensor(o1) || IsCompositeTensor(o2))) {
861 Safe_PyObjectPtr owned_type_spec_1;
862 PyObject* type_spec_1 = o1;
863 if (IsCompositeTensor(o1)) {
864 owned_type_spec_1.reset(PyObject_GetAttrString(o1, "_type_spec"));
865 type_spec_1 = owned_type_spec_1.get();
866 }
867
868 Safe_PyObjectPtr owned_type_spec_2;
869 PyObject* type_spec_2 = o2;
870 if (IsCompositeTensor(o2)) {
871 owned_type_spec_2.reset(PyObject_GetAttrString(o2, "_type_spec"));
872 type_spec_2 = owned_type_spec_2.get();
873 }
874
875 // Two composite tensors are considered to have the same structure if
876 // they share a type spec that is a supertype of both of them. We do *not*
877 // use is_subtype_of, since that would prevent us from e.g. using a
878 // cond statement where the two sides have different shapes.
879
880 // TODO(b/206014848): We have to explicitly remove the names.
881 Safe_PyObjectPtr owned_nameless_type_spec_1(
882 PyObject_CallMethod(type_spec_1, "_without_tensor_names", nullptr));
883 Safe_PyObjectPtr owned_nameless_type_spec_2(
884 PyObject_CallMethod(type_spec_2, "_without_tensor_names", nullptr));
885 // TODO(b/222123181): Reconsider most_specific_common_supertype usage.
886 static char compatible_type[] = "most_specific_common_supertype";
887 static char argspec[] = "([O])";
888 Safe_PyObjectPtr struct_compatible(
889 PyObject_CallMethod(owned_nameless_type_spec_1.get(), compatible_type,
890 argspec, owned_nameless_type_spec_2.get()));
891 if (PyErr_Occurred()) {
892 return false;
893 }
894 if (struct_compatible.get() == Py_None) {
895 *is_type_error = false;
896 *error_msg = tensorflow::strings::StrCat(
897 "Incompatible CompositeTensor TypeSpecs: ",
898 PyObjectToString(type_spec_1), " vs. ",
899 PyObjectToString(type_spec_2));
900 return true;
901 }
902 }
903
904 ValueIteratorPtr iter1 = value_iterator_getter(o1);
905 ValueIteratorPtr iter2 = value_iterator_getter(o2);
906
907 if (!iter1->valid() || !iter2->valid()) return false;
908
909 while (true) {
910 Safe_PyObjectPtr v1 = iter1->next();
911 Safe_PyObjectPtr v2 = iter2->next();
912 if (v1 && v2) {
913 if (Py_EnterRecursiveCall(" in assert_same_structure")) {
914 return false;
915 }
916 bool no_internal_errors = AssertSameStructureHelper(
917 v1.get(), v2.get(), check_types, error_msg, is_type_error,
918 is_nested_helper, value_iterator_getter,
919 check_composite_tensor_type_spec);
920 Py_LeaveRecursiveCall();
921 if (!no_internal_errors) return false;
922 if (!error_msg->empty()) return true;
923 } else if (!v1 && !v2) {
924 // Done with all recursive calls. Structure matched.
925 return true;
926 } else {
927 *is_type_error = false;
928 *error_msg = tensorflow::strings::StrCat(
929 "The two structures don't have the same number of elements. ",
930 "First structure: ", PyObjectToString(o1),
931 ". Second structure: ", PyObjectToString(o2));
932 return true;
933 }
934 }
935 }
936
937 } // namespace
938
IsNested(PyObject * o)939 bool IsNested(PyObject* o) { return IsNestedHelper(o) == 1; }
IsMapping(PyObject * o)940 bool IsMapping(PyObject* o) { return IsMappingHelper(o) == 1; }
IsMutableMapping(PyObject * o)941 bool IsMutableMapping(PyObject* o) { return IsMutableMappingHelper(o) == 1; }
IsMappingView(PyObject * o)942 bool IsMappingView(PyObject* o) { return IsMappingViewHelper(o) == 1; }
IsAttrs(PyObject * o)943 bool IsAttrs(PyObject* o) { return IsAttrsHelper(o) == 1; }
IsTensor(PyObject * o)944 bool IsTensor(PyObject* o) { return IsTensorHelper(o) == 1; }
IsTensorSpec(PyObject * o)945 bool IsTensorSpec(PyObject* o) { return IsTensorSpecHelper(o) == 1; }
IsEagerTensorSlow(PyObject * o)946 bool IsEagerTensorSlow(PyObject* o) { return IsEagerTensorHelper(o) == 1; }
IsResourceVariable(PyObject * o)947 bool IsResourceVariable(PyObject* o) {
948 return IsResourceVariableHelper(o) == 1;
949 }
IsOwnedIterator(PyObject * o)950 bool IsOwnedIterator(PyObject* o) { return IsOwnedIteratorHelper(o) == 1; }
IsVariable(PyObject * o)951 bool IsVariable(PyObject* o) { return IsVariableHelper(o) == 1; }
IsIndexedSlices(PyObject * o)952 bool IsIndexedSlices(PyObject* o) { return IsIndexedSlicesHelper(o) == 1; }
IsDispatchable(PyObject * o)953 bool IsDispatchable(PyObject* o) { return IsDispatchableHelper(o) == 1; }
954
IsTuple(PyObject * o)955 bool IsTuple(PyObject* o) {
956 tensorflow::Safe_PyObjectPtr wrapped;
957 if (IsObjectProxy(o)) {
958 wrapped.reset(PyObject_GetAttrString(o, "__wrapped__"));
959 o = wrapped.get();
960 }
961 return PyTuple_Check(o);
962 }
963
964 // Work around a writable-strings warning with Python 2's PyMapping_Keys macro,
965 // and while we're at it give them consistent behavior by making sure the
966 // returned value is a list.
967 //
968 // As with PyMapping_Keys, returns a new reference.
969 //
970 // On failure, returns nullptr.
MappingKeys(PyObject * o)971 PyObject* MappingKeys(PyObject* o) {
972 #if PY_MAJOR_VERSION >= 3
973 return PyMapping_Keys(o);
974 #else
975 static char key_method_name[] = "keys";
976 Safe_PyObjectPtr raw_result(PyObject_CallMethod(o, key_method_name, nullptr));
977 if (PyErr_Occurred() || raw_result.get() == nullptr) {
978 return nullptr;
979 }
980 return PySequence_Fast(
981 raw_result.get(),
982 "The '.keys()' method of a custom mapping returned a non-sequence.");
983 #endif
984 }
985
Flatten(PyObject * nested,bool expand_composites)986 PyObject* Flatten(PyObject* nested, bool expand_composites) {
987 PyObject* list = PyList_New(0);
988 const std::function<int(PyObject*)>& is_nested_helper =
989 expand_composites ? IsNestedOrCompositeHelper : IsNestedHelper;
990 const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
991 expand_composites ? GetValueIteratorForComposite : GetValueIterator;
992 if (FlattenHelper(nested, list, is_nested_helper, get_value_iterator)) {
993 return list;
994 } else {
995 Py_DECREF(list);
996 return nullptr;
997 }
998 }
999
IsNestedOrComposite(PyObject * o)1000 bool IsNestedOrComposite(PyObject* o) {
1001 return IsNestedOrCompositeHelper(o) == 1;
1002 }
1003
IsCompositeTensor(PyObject * o)1004 bool IsCompositeTensor(PyObject* o) { return IsCompositeTensorHelper(o) == 1; }
1005
IsTypeSpec(PyObject * o)1006 bool IsTypeSpec(PyObject* o) { return IsTypeSpecHelper(o) == 1; }
1007
IsNestedForData(PyObject * o)1008 bool IsNestedForData(PyObject* o) { return IsNestedForDataHelper(o) == 1; }
1009
FlattenForData(PyObject * nested)1010 PyObject* FlattenForData(PyObject* nested) {
1011 PyObject* list = PyList_New(0);
1012 if (FlattenHelper(nested, list, IsNestedForDataHelper,
1013 GetValueIteratorForData)) {
1014 return list;
1015 } else {
1016 Py_DECREF(list);
1017 return nullptr;
1018 }
1019 }
1020
IsNamedtuple(PyObject * o,bool strict)1021 PyObject* IsNamedtuple(PyObject* o, bool strict) {
1022 // Some low-level CPython calls do not work with wrapt.ObjectProxy, so they
1023 // require some unwrapping if we want to treat them like the objects they're
1024 // wrapping.
1025 tensorflow::Safe_PyObjectPtr o_wrapped;
1026 if (IsObjectProxy(o)) {
1027 o_wrapped.reset(PyObject_GetAttrString(o, "__wrapped__"));
1028 o = o_wrapped.get();
1029 }
1030
1031 // Must be subclass of tuple
1032 if (!PyTuple_Check(o)) {
1033 Py_RETURN_FALSE;
1034 }
1035
1036 // If strict, o.__class__.__base__ must be tuple
1037 if (strict) {
1038 PyObject* klass = PyObject_GetAttrString(o, "__class__");
1039 if (klass == nullptr) return nullptr;
1040 PyObject* base = PyObject_GetAttrString(klass, "__base__");
1041 Py_DECREF(klass);
1042 if (base == nullptr) return nullptr;
1043
1044 const PyTypeObject* base_type = reinterpret_cast<PyTypeObject*>(base);
1045 // built-in object types are singletons
1046 bool tuple_base = base_type == &PyTuple_Type;
1047 Py_DECREF(base);
1048 if (!tuple_base) {
1049 Py_RETURN_FALSE;
1050 }
1051 }
1052
1053 // o must have attribute '_fields' and every element in
1054 // '_fields' must be a string.
1055 int has_fields = PyObject_HasAttrString(o, "_fields");
1056 if (!has_fields) {
1057 Py_RETURN_FALSE;
1058 }
1059
1060 Safe_PyObjectPtr fields = make_safe(PyObject_GetAttrString(o, "_fields"));
1061 int is_instance = IsInstanceOfRegisteredType(fields.get(), "Sequence");
1062 if (is_instance == 0) {
1063 Py_RETURN_FALSE;
1064 } else if (is_instance == -1) {
1065 return nullptr;
1066 }
1067
1068 Safe_PyObjectPtr seq = make_safe(PySequence_Fast(fields.get(), ""));
1069 const Py_ssize_t s = PySequence_Fast_GET_SIZE(seq.get());
1070 for (Py_ssize_t i = 0; i < s; ++i) {
1071 // PySequence_Fast_GET_ITEM returns borrowed ref
1072 PyObject* elem = PySequence_Fast_GET_ITEM(seq.get(), i);
1073 if (!IsString(elem)) {
1074 Py_RETURN_FALSE;
1075 }
1076 }
1077
1078 Py_RETURN_TRUE;
1079 }
1080
SameNamedtuples(PyObject * o1,PyObject * o2)1081 PyObject* SameNamedtuples(PyObject* o1, PyObject* o2) {
1082 Safe_PyObjectPtr f1 = make_safe(PyObject_GetAttrString(o1, "_fields"));
1083 Safe_PyObjectPtr f2 = make_safe(PyObject_GetAttrString(o2, "_fields"));
1084 if (f1 == nullptr || f2 == nullptr) {
1085 PyErr_SetString(
1086 PyExc_RuntimeError,
1087 "Expected namedtuple-like objects (that have _fields attr)");
1088 return nullptr;
1089 }
1090
1091 if (PyObject_RichCompareBool(f1.get(), f2.get(), Py_NE)) {
1092 Py_RETURN_FALSE;
1093 }
1094
1095 if (GetClassName(o1).compare(GetClassName(o2)) == 0) {
1096 Py_RETURN_TRUE;
1097 } else {
1098 Py_RETURN_FALSE;
1099 }
1100 }
1101
AssertSameStructure(PyObject * o1,PyObject * o2,bool check_types,bool expand_composites)1102 PyObject* AssertSameStructure(PyObject* o1, PyObject* o2, bool check_types,
1103 bool expand_composites) {
1104 const std::function<int(PyObject*)>& is_nested_helper =
1105 expand_composites ? IsNestedOrCompositeHelper : IsNestedHelper;
1106 const std::function<ValueIteratorPtr(PyObject*)>& get_value_iterator =
1107 expand_composites ? GetValueIteratorForComposite : GetValueIterator;
1108 const bool check_composite_tensor_type_spec = expand_composites;
1109 string error_msg;
1110 bool is_type_error = false;
1111 AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
1112 is_nested_helper, get_value_iterator,
1113 check_composite_tensor_type_spec);
1114 if (PyErr_Occurred()) {
1115 // Don't hide Python exceptions while checking (e.g. errors fetching keys
1116 // from custom mappings).
1117 return nullptr;
1118 }
1119 if (!error_msg.empty()) {
1120 PyErr_SetString(
1121 is_type_error ? PyExc_TypeError : PyExc_ValueError,
1122 tensorflow::strings::StrCat(
1123 "The two structures don't have the same nested structure.\n\n",
1124 "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
1125 PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
1126 .c_str());
1127 return nullptr;
1128 }
1129 Py_RETURN_NONE;
1130 }
1131
AssertSameStructureForData(PyObject * o1,PyObject * o2,bool check_types)1132 PyObject* AssertSameStructureForData(PyObject* o1, PyObject* o2,
1133 bool check_types) {
1134 string error_msg;
1135 bool is_type_error = false;
1136 AssertSameStructureHelper(o1, o2, check_types, &error_msg, &is_type_error,
1137 IsNestedForDataHelper, GetValueIterator, false);
1138 if (PyErr_Occurred()) {
1139 // Don't hide Python exceptions while checking (e.g. errors fetching keys
1140 // from custom mappings).
1141 return nullptr;
1142 }
1143 if (!error_msg.empty()) {
1144 PyErr_SetString(
1145 is_type_error ? PyExc_TypeError : PyExc_ValueError,
1146 tensorflow::strings::StrCat(
1147 "The two structures don't have the same nested structure.\n\n",
1148 "First structure: ", PyObjectToString(o1), "\n\nSecond structure: ",
1149 PyObjectToString(o2), "\n\nMore specifically: ", error_msg)
1150 .c_str());
1151 return nullptr;
1152 }
1153 Py_RETURN_NONE;
1154 }
1155
1156 } // namespace swig
1157 } // namespace tensorflow
1158