1 #include <ATen/core/Dict.h>
2 #include <ATen/core/Formatting.h>
3 #include <ATen/core/class_type.h>
4 #include <ATen/core/enum_type.h>
5 #include <ATen/core/function.h>
6 #include <ATen/core/ivalue.h>
7 #include <ATen/core/jit_type.h>
8 #include <ATen/core/stack.h>
9 #include <ATen/core/type_factory.h>
10 #include <c10/util/StringUtil.h>
11 #include <c10/util/hash.h>
12 #include <c10/util/irange.h>
13 #include <cmath>
14 #include <iostream>
15 #include <utility>
16
17 namespace c10 {
_fastEqualsForContainer(const IValue & lhs,const IValue & rhs)18 bool _fastEqualsForContainer(const IValue& lhs, const IValue& rhs) {
19 if (lhs.is(rhs)) {
20 // Like Python, for containers we consider identity equality to be
21 // sufficient but not necessary for value equality
22 return true;
23 }
24 return lhs == rhs;
25 }
26
27 namespace ivalue {
28
29 // This is in ivalue.cpp because we need to access Type::annotation_str, which
30 // is declared in jit_type.h
checkCustomClassType(const ClassType * expected_type,const Type * actual_type)31 void checkCustomClassType(const ClassType* expected_type, const Type* actual_type) {
32 // NB: doing pointer comparison here
33 // If in the future there ever arises a need to call operator== on custom class
34 // Type's, this needs to be changed!
35 TORCH_CHECK(actual_type == static_cast<const Type*>(expected_type),
36 "Tried to convert an IValue of type ",
37 actual_type ? actual_type->repr_str() : std::string("*NULL*"),
38 " to custom class type ",
39 expected_type ? expected_type->repr_str() : std::string("*NULL*"));
40 }
41
create(std::string str_)42 TORCH_API c10::intrusive_ptr<ConstantString> ConstantString::create(
43 std::string str_) {
44 return c10::make_intrusive<ConstantString>(std::move(str_));
45 }
46
create(c10::string_view str_)47 TORCH_API c10::intrusive_ptr<ConstantString> ConstantString::create(
48 c10::string_view str_) {
49 return c10::make_intrusive<ConstantString>(std::string(str_));
50 }
51
create(const char * str_)52 TORCH_API c10::intrusive_ptr<ConstantString> ConstantString::create(
53 const char* str_) {
54 return c10::make_intrusive<ConstantString>(std::string(str_));
55 }
56
operator ==(const ivalue::Tuple & lhs,const ivalue::Tuple & rhs)57 bool operator==(const ivalue::Tuple& lhs, const ivalue::Tuple& rhs) {
58 return lhs.size() == rhs.size() &&
59 // see [container equality]
60 std::equal(
61 lhs.elements().cbegin(),
62 lhs.elements().cend(),
63 rhs.elements().cbegin(),
64 _fastEqualsForContainer);
65 }
66
operator <<(std::ostream & out,const ivalue::EnumHolder & v)67 std::ostream& operator<<(std::ostream& out, const ivalue::EnumHolder& v) {
68 out << v.qualifiedClassName() << "." << v.name();
69 return out;
70 }
71
operator ==(const ivalue::EnumHolder & lhs,const ivalue::EnumHolder & rhs)72 bool operator==(const ivalue::EnumHolder& lhs, const ivalue::EnumHolder& rhs) {
73 return lhs.name() == rhs.name() && *rhs.type() == *lhs.type();
74 }
75
qualifiedClassName() const76 const std::string& ivalue::EnumHolder::qualifiedClassName() const {
77 return type_->qualifiedClassName().qualifiedName();
78 }
79
unqualifiedClassName() const80 const std::string& ivalue::EnumHolder::unqualifiedClassName() const {
81 return type_->qualifiedClassName().name();
82 }
83
84 } // namespace ivalue
85
// Computes the dynamic (runtime) type of an IValue by dispatching on its tag.
// Scalar-like tags map to singleton types; container tags recurse into the
// contained element/key/value types, so the result reflects the runtime
// contents (e.g. Dict[str, int] rather than just Dict).
c10::TypePtr IValue::TagType<c10::Type>::get(const IValue& v) {
  switch (v.tag) {
    case Tag::None:
      return NoneType::get();
    case Tag::Tensor:
      // TensorType::create derives a type from the concrete tensor.
      return TensorType::create(v.toTensor());
    case Tag::Storage:
      return StorageType::get();
    case Tag::Double:
      return FloatType::get();
    case Tag::ComplexDouble:
      return ComplexType::get();
    case Tag::Int:
      return IntType::get();
    case Tag::SymInt:
      return c10::SymIntType::get();
    case Tag::SymFloat:
      return c10::SymFloatType::get();
    case Tag::SymBool:
      return c10::SymBoolType::get();
    case Tag::Bool:
      return BoolType::get();
    case Tag::String:
      return StringType::get();
    case Tag::Blob:
      // Blobs are opaque; the best we can say is Any.
      return AnyType::get();
    case Tag::GenericDict: {
      auto d = v.toGenericDict();
      return DictType::create(d.keyType(), d.valueType());
    }
    case Tag::GenericList:
      return ListType::create(v.toList().elementType());
    case Tag::Await:
      return AwaitType::create(v.toAwait()->elementType());
    case Tag::Future:
      return FutureType::create(v.toFuture()->elementType());
    case Tag::RRef:
      return RRefType::create(v.toRRef()->type());
    case Tag::Device:
      return DeviceObjType::get();
    case Tag::Stream:
      return StreamObjType::get();
    case Tag::Object:
      return v.toObjectRef().type();
    case Tag::PyObject:
      return PyObjectType::get();
    case Tag::Uninitialized:
      // Uninitialized values have no real type; Any is the safe answer.
      return AnyType::get();
    case Tag::Capsule:
      return CapsuleType::get();
    case Tag::Tuple:
      return v.toTupleRef().type();
    case Tag::Generator:
      return GeneratorType::get();
    case Tag::Quantizer:
      return QuantizerType::get();
    case Tag::Enum:
      return v.toEnumHolder()->type();
  }
  // switch above is complete but this silences compiler warnings
  TORCH_INTERNAL_ASSERT(false, "unhandled case in IValue::type()");

  // This static_assert has to go into some IValue member function; I
  // chose this one. It's not in the class body because that's in
  // ivalue.h, which is a very high-fanout header file and we want to
  // minimize build time.
  static_assert(
      kNumTags <= 32,
      "IValue::isIntrusivePtr needs to be updated because it assumes there are at most 32 tags");
}
156
// Recursively applies `visitor` to this IValue and to everything it
// transitively contains: tuple/list elements, dict keys and values, object
// attributes, and — when the type can be inferred — the contents of a
// wrapped Python object. If `visitor` returns true for a value, recursion
// stops below that value (the subtree is pruned); scalars and other leaf
// tags are visited but never descended into.
void IValue::visit(const std::function<bool (const IValue &)>& visitor) const {
  if (visitor(*this)) {
    // Shortcut: a true return means "do not descend into children".
    return;
  }
  switch (this->tag) {
    case Tag::Tuple:
    case Tag::GenericList: {
      c10::ArrayRef<IValue> elems;
      if (isTuple()) {
        elems = this->toTupleRef().elements();
      } else {
        elems = this->toListRef();
      }
      for (auto& elem : elems) {
        elem.visit(visitor);
      }
      break;
    }
    case Tag::GenericDict:
      for (const auto& pair : this->toGenericDict()) {
        pair.value().visit(visitor);
        pair.key().visit(visitor);
      }
      break;
    case Tag::Object: {
      auto obj_type = type()->expect<ClassType>();
      auto obj_value = toObject();
      auto attributes = obj_type->getAttributes();
      for (const auto& attr: attributes) {
        auto attribute = obj_value->getAttr(attr.getName());
        attribute.visit(visitor);
      }
      break;
    }
    case Tag::PyObject: {
      c10::intrusive_ptr<at::ivalue::PyObjectHolder> py_obj = toPyObjectHolder();
      auto match = py_obj->tryToInferType();
      if (match.success()) {
        // Only descend into the PyObject if its type could be inferred;
        // otherwise it is treated as an opaque leaf.
        auto contained_value = py_obj->toIValue(match.type());
        contained_value.visit(visitor);
      }
      break;
    }
    default:
      break;
  }
}
205
// Collects this IValue plus every transitively contained IValue that can
// participate in aliasing (tensors, containers, objects, inferable
// PyObjects) into `subValues`. Scalars are deliberately not recorded.
// Types whose contents cannot be inspected (Future/Await/Device/
// Uninitialized/Capsule) raise a TypeError.
void IValue::getSubValues(HashAliasedIValues& subValues) const {
  switch (this->tag) {
    case Tag::Tensor:
      subValues.insert(*this);
      return;
    case Tag::Tuple:
    case Tag::GenericList: {
      subValues.insert(*this);
      c10::ArrayRef<IValue> elems;
      if (isTuple()) {
        elems = this->toTupleRef().elements();
      } else {
        elems = this->toListRef();
      }
      for (auto& elem : elems) {
        elem.getSubValues(subValues);
      }
      break;
    }
    case Tag::GenericDict:
      subValues.insert(*this);
      for (const auto& pair : this->toGenericDict()) {
        pair.value().getSubValues(subValues);
        pair.key().getSubValues(subValues);
      }
      break;
    case Tag::Object: {
      // Record Object IValue and its attributes.
      subValues.insert(*this);
      auto obj_type = type()->expect<ClassType>();
      auto obj_value = toObject();
      auto attributes = obj_type->getAttributes();
      for (const auto& attr: attributes) {
        auto attribute = obj_value->getAttr(attr.getName());
        attribute.getSubValues(subValues);
      }
      break;
    }
    case Tag::PyObject: {
      subValues.insert(*this);
      c10::intrusive_ptr<at::ivalue::PyObjectHolder> py_obj = toPyObjectHolder();
      auto match = py_obj->tryToInferType();
      // Unlike visit(), failure to infer the PyObject's type is an error
      // here, because the caller needs a complete alias set.
      TORCH_CHECK_TYPE(match.success(),
            "Cannot infer type of ", py_obj->toStr(), ": ", match.reason());
      auto contained_value = py_obj->toIValue(match.type());
      contained_value.getSubValues(subValues);
      break;
    }
    case Tag::Future:
    case Tag::Await:
    case Tag::Device:
    case Tag::Uninitialized:
    case Tag::Capsule:
      TORCH_CHECK_TYPE(
          false, "Cannot inspect value of type ", this->tagKind());
    default:
      // don't record scalars.
      break;
  }
}
266
overlaps(const IValue & rhs) const267 bool IValue::overlaps(const IValue& rhs) const {
268 HashAliasedIValues rhsSubValues, thisSubValues;
269 rhs.getSubValues(rhsSubValues);
270 getSubValues(thisSubValues);
271 for (auto& sub : thisSubValues) {
272 if (rhsSubValues.count(sub)) {
273 return true;
274 }
275 }
276 return false;
277 }
278
operator !=(const IValue & lhs,const IValue & rhs)279 bool operator!=(const IValue& lhs, const IValue& rhs) {
280 return !(lhs == rhs);
281 }
282
operator ==(const IValue & lhs,const IValue & rhs)283 bool operator==(const IValue& lhs, const IValue& rhs) {
284 IValue eq = lhs.equals(rhs);
285 if (eq.isBool()) {
286 return eq.toBool();
287 }
288 // The only case we don't return bool is for tensor comparison. In Python,
289 // `bool()` is called on the return value of `__eq__` if the return value is
290 // not a boolean. Mimic that behavior here.
291 TORCH_INTERNAL_ASSERT(eq.isTensor());
292 return eq.toTensor().is_nonzero();
293 }
294
ptrEqual(const IValue & lhs,const IValue & rhs)295 bool IValue::ptrEqual(const IValue& lhs, const IValue& rhs) {
296 TORCH_INTERNAL_ASSERT(lhs.isIntrusivePtr());
297 TORCH_INTERNAL_ASSERT(rhs.isIntrusivePtr());
298 return lhs.tag == rhs.tag &&
299 lhs.payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
300 }
301
// Structural equality dispatched on the lhs tag. Returns a bool IValue for
// every kind except Tensor-vs-Tensor, which returns an elementwise boolean
// tensor (matching Python's Tensor.__eq__); operator== above collapses that
// case. Reference types (Future, Object, PyObject, ...) compare by pointer
// identity.
IValue IValue::equals(const IValue& rhs) const {
  const IValue& lhs = *this;
  switch (lhs.tag) {
    case Tag::None:
      // In Python you're not supposed to do this comparison apparently. Not
      // sure if we should warn here or what
      return rhs.isNone();
    case Tag::Tensor: {
      if (!rhs.isTensor()) {
        return false;
      }
      // Elementwise comparison; the only branch that returns a non-bool.
      return lhs.toTensor().eq(rhs.toTensor());
    }
    case Tag::Storage:
      return rhs.isStorage() && lhs.toStorage().unsafeGetStorageImpl() == rhs.toStorage().unsafeGetStorageImpl();
    case Tag::Double:
      return rhs.isDouble() && lhs.toDouble() == rhs.toDouble();
    case Tag::ComplexDouble:
      return rhs.isComplexDouble() && lhs.toComplexDouble() == rhs.toComplexDouble();
    case Tag::Int:
      return rhs.isInt() && lhs.toInt() == rhs.toInt();
    case Tag::SymInt:
      return rhs.isSymInt() && lhs.toSymInt() == rhs.toSymInt();
    case Tag::SymFloat:
      return rhs.isSymFloat() && lhs.toSymFloat() == rhs.toSymFloat();
    case Tag::SymBool:
      return rhs.isSymBool() && lhs.toSymBool() == rhs.toSymBool();
    case Tag::Bool:
      return rhs.isBool() && lhs.toBool() == rhs.toBool();
    case Tag::String:
      return rhs.isString() && lhs.toStringRef() == rhs.toStringRef();
    case Tag::GenericDict:
      return rhs.isGenericDict() && lhs.toGenericDict() == rhs.toGenericDict();
    case Tag::Tuple:
      return rhs.isTuple() && *lhs.toTuple() == *rhs.toTuple();
    case Tag::Stream:
      return rhs.isStream() && lhs.toStream() == rhs.toStream();
    case Tag::Device:
      return rhs.isDevice() && lhs.toDevice() == rhs.toDevice();
    case Tag::GenericList:
      return rhs.isList() && lhs.toList() == rhs.toList();
    case Tag::Blob:
    case Tag::Future:
    case Tag::Await:
    case Tag::RRef:
    case Tag::Object:
    case Tag::PyObject:
    case Tag::Capsule:
    case Tag::Generator:
    case Tag::Quantizer:
      // Reference semantics: equal only if they are the same object.
      return ptrEqual(lhs, rhs);
    case Tag::Enum:
      return lhs.toEnumHolder()->is(*rhs.toEnumHolder());
    case Tag::Uninitialized:
      // Unitialized ivalues show up in no-ops when the compiler can prove a
      // value will never be used. Just return false on any equality comparison.
      return false;
  }
  // the above switch should be exhaustive
  TORCH_INTERNAL_ASSERT(false, "we should never reach here")
}
363
// Python-style hash. Only immutable scalar-ish values, strings, tuples and
// devices are hashable; containers and reference types throw, mirroring
// Python's "unhashable type" error. Note tensors hash by pointer identity,
// emulating Python's id()-based Tensor.__hash__ — so equal-valued tensors
// generally hash differently.
size_t IValue::hash(const IValue& v) {
  switch (v.tag) {
    case Tag::None:
      return 0;
    case Tag::Bool:
      return c10::get_hash(v.payload.u.as_bool);
    case Tag::Double:
      return c10::get_hash(v.payload.u.as_double);
    case Tag::Tensor:
      // Tensor __hash__ is equivalent to `id()`, so take the pointer value of
      // the tensor to emulate it
      return c10::get_hash(v.payload.as_tensor.unsafeGetTensorImpl());
    // NOLINTNEXTLINE(bugprone-branch-clone)
    case Tag::Storage:
      return c10::get_hash(v.payload.u.as_int);
    case Tag::Int:
      return c10::get_hash(v.payload.u.as_int);
    // NB: these are technically strict aliasing violations
    case Tag::SymInt:
      return c10::get_hash(v.payload.u.as_int);
    case Tag::SymFloat:
      return c10::get_hash(v.payload.u.as_int);
    case Tag::SymBool:
      return c10::get_hash(v.payload.u.as_int);
    case Tag::String:
      return c10::get_hash(v.toStringRef());
    case Tag::Tuple:
      return c10::get_hash(*v.toTuple());
    case Tag::Device:
      return c10::get_hash(v.toDevice());
    case Tag::GenericDict:
    case Tag::GenericList:
    case Tag::Blob:
    case Tag::Future:
    case Tag::Await:
    case Tag::RRef:
    case Tag::Object:
    case Tag::PyObject:
    case Tag::Capsule:
    case Tag::Generator:
    case Tag::Quantizer:
    case Tag::ComplexDouble:
    case Tag::Enum:
    case Tag::Stream:
    case Tag::Uninitialized:
      throw std::runtime_error(
          "unhashable type: '" + v.type()->repr_str() + "'");
  }
  // the above switch should be exhaustive
  TORCH_INTERNAL_ASSERT(false, "we should never reach here")
}
415
isUndefinedTensor(const IValue & iv)416 static bool isUndefinedTensor(const IValue& iv) {
417 return iv.isTensor() && !iv.toTensor().defined();
418 }
419
// Python-style `is` identity check, with two special cases up front: an
// undefined tensor is considered identical to None (in either direction),
// and two undefined tensors are identical to each other. Defined tensors
// compare via TensorImpl identity; other heap values via pointer equality;
// everything else falls back to value equality (scalars have no identity).
bool IValue::is(const IValue& rhs) const {
  const IValue& lhs = *this;
  // Special handling for undefined tensors:
  // 1. Undefined_tensor is None and vice versa.
  if ((isUndefinedTensor(lhs) && rhs.isNone()) ||
      (lhs.isNone() && isUndefinedTensor(rhs))) {
    return true;
  }
  // 2. Undefined_tensor is Undefined_tensor.
  if (isUndefinedTensor(lhs) && isUndefinedTensor(rhs)) {
    return true;
  }

  if (lhs.isTensor()) {
    // Use the standard way of comparing two tensors for identity
    return rhs.isTensor() && lhs.toTensor().is_same(rhs.toTensor());
  }

  if (lhs.isIntrusivePtr()) {
    return rhs.isIntrusivePtr() && ptrEqual(lhs, rhs);
  }
  // Value types: identity degenerates to equality, as in CPython for
  // interned scalars.
  return lhs == rhs;
}
443
// True iff this is a List whose declared element type matches T, either by
// type kind (fast path) or by full type equality.
template <typename T>
inline bool IValue::isListOf() const {
  // note: avoids calling type() to avoid extra referencing counting for the returned type.
  if (!isList()) {
    return false;
  }
  const auto& ty = static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
  if (ty->kind() == T::Kind) {
    return true;
  }
  return *ty == *TypeFactory::get<T>();
}
456
// True iff this is a List[float] (doubles in C++).
bool IValue::isDoubleList() const {
  return isListOf<c10::FloatType>();
}
460
// True iff this is a List[complex].
bool IValue::isComplexDoubleList() const {
  return isListOf<c10::ComplexType>();
}
464
// True iff this is a List[Tensor].
bool IValue::isTensorList() const {
  return isListOf<c10::TensorType>();
}
468
isOptionalTensorList() const469 bool IValue::isOptionalTensorList() const {
470 if (!isList()) {
471 return false;
472 }
473 const auto& ty = static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
474 const auto& expected_ty = c10::getTypePtr<std::optional<at::Tensor>>();
475 return expected_ty == ty;
476 }
477
// True iff this is a List[int].
bool IValue::isIntList() const {
  return isListOf<c10::IntType>();
}
481
// True iff this is a List[SymInt].
bool IValue::isSymIntList() const {
  return isListOf<c10::SymIntType>();
}
485
// True iff this is a List[bool].
bool IValue::isBoolList() const {
  return isListOf<c10::BoolType>();
}
489
490 namespace {
491
492 using IValueFormatter = std::function<void(std::ostream&, const IValue&)>;
493
494 template <class T>
printList(std::ostream & out,const T & list,const std::string & start,const std::string & finish,const IValueFormatter & formatter)495 std::ostream& printList(
496 std::ostream& out,
497 const T& list,
498 const std::string& start,
499 const std::string& finish,
500 const IValueFormatter& formatter) {
501 out << start;
502 for (const auto i : c10::irange(list.size())) {
503 if (i > 0) {
504 out << ", ";
505 }
506 formatter(out, IValue(list[i]));
507 }
508 out << finish;
509 return out;
510 }
511
512 // Properly disambiguate the type of an empty list
printMaybeAnnotatedList(std::ostream & out,const IValue & the_list,const IValueFormatter & formatter)513 std::ostream& printMaybeAnnotatedList(
514 std::ostream& out,
515 const IValue& the_list,
516 const IValueFormatter& formatter) {
517 auto list_elem_type = the_list.type()->containedType(0);
518 if (the_list.toListRef().empty() ||
519 !elementTypeCanBeInferredFromMembers(list_elem_type)) {
520 out << "annotate(" << the_list.type<c10::Type>()->annotation_str() << ", ";
521 printList(out, the_list.toListRef(), "[", "]", formatter);
522 out << ")";
523 return out;
524 } else {
525 return printList(out, the_list.toListRef(), "[", "]", formatter);
526 }
527 }
528
529 template <typename Dict>
printDict(std::ostream & out,const Dict & v,const IValueFormatter & formatter)530 std::ostream& printDict(
531 std::ostream& out,
532 const Dict& v,
533 const IValueFormatter& formatter) {
534 out << "{";
535
536 bool first = true;
537 for (const auto& pair : v) {
538 if (!first) {
539 out << ", ";
540 }
541
542 formatter(out, pair.key());
543 out << ": ";
544 formatter(out, pair.value());
545 first = false;
546 }
547
548 out << "}";
549 return out;
550 }
551 }
552
553 // Properly disambiguate the type of an empty dict
printMaybeAnnotatedDict(std::ostream & out,const IValue & the_dict,const IValueFormatter & formatter)554 static std::ostream& printMaybeAnnotatedDict(
555 std::ostream& out,
556 const IValue& the_dict,
557 const IValueFormatter& formatter) {
558 auto value_type = the_dict.type()->castRaw<DictType>()->getValueType();
559 if (the_dict.toGenericDict().empty() ||
560 !elementTypeCanBeInferredFromMembers(value_type)) {
561 out << "annotate(" << the_dict.type<c10::Type>()->annotation_str() << ",";
562 printDict(out, the_dict.toGenericDict(), formatter) << ")";
563 } else {
564 return printDict(out, the_dict.toGenericDict(), formatter);
565 }
566 return out;
567 }
568
printComplex(std::ostream & out,const IValue & v)569 static std::ostream& printComplex(std::ostream & out, const IValue & v) {
570 c10::complex<double> d = v.toComplexDouble();
571 IValue real(d.real()), imag(std::abs(d.imag()));
572 auto sign = "";
573 if (d.imag() >= 0) {
574 sign = "+";
575 } else {
576 sign = "-";
577 }
578 return out << real << sign << imag << "j";
579 }
580
// Prints a Python-parseable representation of this IValue (the analogue of
// Python's repr()), used e.g. by TorchScript serialization. The caller's
// `customFormatter` gets first crack at every value — including children
// printed recursively — and returns true when it handled the output.
// Kinds with no parseable textual form (Tensor, Object, etc.) assert.
std::ostream& IValue::repr(
    std::ostream& out,
    std::function<bool(std::ostream&, const IValue& v)>
        customFormatter) const {
  // First check if the caller has provided a custom formatter. Use that if possible.
  if (customFormatter(out, *this)) {
    return out;
  }

  const IValue& v = *this;
  // continue to use custom formatter in recursion
  auto formatter = [&](std::ostream& out, const IValue& input) {
    input.repr(out, customFormatter);
  };
  switch (v.tag) {
    case IValue::Tag::None:
      return out << v.toNone();
    case IValue::Tag::Double: {
      double d = v.toDouble();
      int c = std::fpclassify(d);
      // Smallish integral doubles print as "N." so they parse back as
      // floats rather than ints.
      if ((c == FP_NORMAL || c == FP_ZERO ) && std::abs(d) < 1e10) {
        int64_t i = int64_t(d);
        if (double(i) == d) {
          // -0.0 (signed zero) needs to be parsed as -0.
          if (i == 0 && std::signbit(d)) {
            return out << "-" << i << ".";
          }
          return out << i << ".";
        }
      }
      // Full precision so the value round-trips; restore stream precision
      // afterwards.
      auto orig_prec = out.precision();
      return out << std::setprecision(std::numeric_limits<double>::max_digits10)
                 << d << std::setprecision(static_cast<int>(orig_prec));
    }
    case IValue::Tag::ComplexDouble: {
      return printComplex(out, v);
    }
    case IValue::Tag::Int:
      return out << v.toInt();
    case IValue::Tag::SymInt:
      return out << v.toSymInt();
    case IValue::Tag::SymFloat:
      return out << v.toSymFloat();
    case IValue::Tag::SymBool:
      return out << v.toSymBool();
    case IValue::Tag::Bool:
      return out << (v.toBool() ? "True" : "False");
    case IValue::Tag::Tuple: {
      const auto& elements = v.toTupleRef().elements();
      // Python prints 1-tuples with a trailing comma: "(x,)".
      const auto& finish = elements.size() == 1 ? ",)" : ")";
      return printList(out, elements, "(", finish, formatter);
    }
    case IValue::Tag::String:
      c10::printQuotedString(out, v.toStringRef());
      return out;
    case IValue::Tag::GenericList: {
      return printMaybeAnnotatedList(out, *this, formatter);
    }
    case IValue::Tag::Device: {
      std::stringstream device_stream;
      device_stream << v.toDevice();
      out << "torch.device(";
      c10::printQuotedString(out, device_stream.str());
      return out << ")";
    }
    case IValue::Tag::Generator: {
      auto generator = v.toGenerator();
      out << "torch.Generator(device=";
      c10::printQuotedString(out, generator.device().str());
      out << ", seed=" << generator.current_seed() << ")";
      return out;
    }
    case IValue::Tag::GenericDict:
      return printMaybeAnnotatedDict(out, v, formatter);
    case IValue::Tag::Enum: {
      auto enum_holder = v.toEnumHolder();
      return out << enum_holder->qualifiedClassName() << "." <<
          enum_holder->name();
    }
    case IValue::Tag::Object: {
      TORCH_INTERNAL_ASSERT(false, "repr() not defined on: ", v.tagKind(), ". Perhaps you've frozen a module with custom classes?");
    }
    default:
      TORCH_INTERNAL_ASSERT(false, "repr() not defined on: ", v.tagKind());
  }
}
667
simpleClassTypeArg(const Argument & arg,const ClassTypePtr & type)668 static bool simpleClassTypeArg(const Argument& arg, const ClassTypePtr& type) {
669 return arg.type() == type && !arg.kwarg_only() && !arg.default_value();
670 }
671
checkObjectSortSchema(const c10::ClassTypePtr & t,std::stringstream & why_not)672 torch::jit::Function* checkObjectSortSchema(const c10::ClassTypePtr& t, std::stringstream& why_not) {
673 if (auto method = t->findMethod("__lt__")) {
674 const auto& lt_schema = method->getSchema();
675 const auto& schema_args = lt_schema.arguments();
676 bool error =
677 (schema_args.size() != 2 ||
678 !simpleClassTypeArg(schema_args[0], t) ||
679 !simpleClassTypeArg(schema_args[1], t) ||
680 lt_schema.returns().size() != 1 ||
681 lt_schema.returns()[0].type() != BoolType::get());
682 if (!error) {
683 return method;
684 }
685 }
686
687 why_not << "To sort a list of " << t->repr_str()
688 << " it must define a "
689 << "__lt__ method with two inputs of type "
690 << t->repr_str() << " that "
691 << "returns a bool";
692 return nullptr;
693 }
694
// Builds a "<" comparator appropriate for values of the same runtime kind
// as `v`. Scalars/strings compare natively; tuples compare
// lexicographically using per-element comparators built up front; objects
// must define a schema-conforming __lt__ (see checkObjectSortSchema).
// Other kinds raise.
IValueComparator getLessThanComparator(const IValue& v) {
  if (v.isTensor()) {
      return [](const IValue& a, const IValue& b) {
        return a.toTensor().lt(b.toTensor()).is_nonzero();
      };
  }

  if (v.isDouble()) {
      return [](const IValue& a, const IValue& b) {
        return a.toDouble() < b.toDouble();
      };
  }

  if (v.isInt()) {
      return [](const IValue& a, const IValue& b) {
        return a.toInt() < b.toInt();
      };
  }

  if (v.isBool()) {
      // False < True, as in Python.
      return [](const IValue& a, const IValue& b) {
        return a.toBool() == false && b.toBool() == true;
      };
  }

  if (v.isString()) {
      return [](const IValue& a, const IValue& b) {
       return a.toStringRef() < b.toStringRef();
      };
  }

  if (v.isTuple()) {
      const auto& elements = v.toTupleRef().elements();
      size_t n = elements.size();

      // Build one comparator per position, based on the sample tuple's
      // element kinds; assumes compared tuples share this structure.
      std::vector<IValueComparator> elements_lts;
      elements_lts.reserve(n);
      for (const auto i : c10::irange(n)) {
        elements_lts.push_back(getLessThanComparator(elements[i]));
      }

      return [elements_lts=std::move(elements_lts), n](const IValue& a, const IValue& b) {
        const auto& a_elements = a.toTupleRef().elements();
        const auto& b_elements = b.toTupleRef().elements();

        for (const auto i : c10::irange(n)) {
          if (elements_lts[i](a_elements[i], b_elements[i])) {
            return true;
          }
          if (a_elements[i] == b_elements[i]) {
            continue;
          }
          // a[i] > b[i]: decided, a is not less.
          return false;
        }
        // Reaching here means two tuples are equal.
        return false;
      };
  }

  if (v.isObject()) {
    std::stringstream why_not;
    torch::jit::Function* lt_func =
        checkObjectSortSchema(v.type()->expect<ClassType>(), why_not);
    if (!lt_func) {
      AT_ERROR(why_not.str());
    }

    return [lt_func](const IValue& a, const IValue& b) {
      // Quick pass to satisfy "strict weak ordering" requirement
      if (a.is(b)) {
        return false;
      }
      // Invoke the TorchScript __lt__ via an interpreter stack.
      torch::jit::Stack sort_stack;
      sort_stack.push_back(a);
      sort_stack.push_back(b);
      lt_func->run(sort_stack);
      return torch::jit::pop(sort_stack).toBool();
    };
  }

  AT_ERROR("IValues of type: ", v.tagKind(), " are not comparable");
}
777
getGreaterThanComparator(const IValue & v)778 IValueComparator getGreaterThanComparator(const IValue& v) {
779 auto lt = getLessThanComparator(v);
780 return [lt = std::move(lt)](const IValue& a, const IValue& b) {
781 return lt(b, a); // gt(a, b) === lt(b, a)
782 };
783 }
784
operator <<(std::ostream & out,const IValue & v)785 std::ostream& operator<<(std::ostream & out, const IValue & v) {
786 auto formatter = [&](std::ostream& out, const IValue& v) {
787 out << v;
788 };
789 switch(v.tag) {
790 case IValue::Tag::None:
791 return out << v.toNone();
792 case IValue::Tag::Tensor:
793 return out << v.toTensor();
794 case IValue::Tag::Storage:
795 return out << v.toStorage().unsafeGetStorageImpl();
796 case IValue::Tag::Double: {
797 double d = v.toDouble();
798 int c = std::fpclassify(d);
799 if (c == FP_NORMAL || c == FP_ZERO) {
800 int64_t i = int64_t(d);
801 if (double(i) == d) {
802 return out << i << ".";
803 }
804 }
805 auto orig_prec = out.precision();
806 return out
807 << std::setprecision(std::numeric_limits<double>::max_digits10)
808 << v.toDouble()
809 << std::setprecision(static_cast<int>(orig_prec));
810 } case IValue::Tag::ComplexDouble: {
811 return printComplex(out, v);
812 } case IValue::Tag::Int:
813 return out << v.toInt();
814 case IValue::Tag::SymInt:
815 return out << v.toSymInt();
816 case IValue::Tag::SymFloat:
817 return out << v.toSymFloat();
818 case IValue::Tag::SymBool:
819 return out << v.toSymBool();
820 case IValue::Tag::Bool:
821 return out << (v.toBool() ? "True" : "False");
822 case IValue::Tag::Tuple: {
823 const auto& elements = v.toTupleRef().elements();
824 const auto& finish = elements.size() == 1 ? ",)" : ")";
825 return printList(out, elements, "(", finish, formatter);
826 }
827 case IValue::Tag::String:
828 return out << v.toStringRef();
829 case IValue::Tag::Blob:
830 return out << *v.toBlob();
831 case IValue::Tag::Capsule:
832 return out << "Capsule";
833 case IValue::Tag::GenericList:
834 return printList(out, v.toList(), "[", "]", formatter);
835 case IValue::Tag::RRef:
836 return out << "RRef";
837 case IValue::Tag::Future:
838 return out << "Future";
839 case IValue::Tag::Await:
840 return out << "Await";
841 case IValue::Tag::Uninitialized:
842 return out << "Uninitialized";
843 case IValue::Tag::Device:
844 return out << v.toDevice();
845 case IValue::Tag::Stream:
846 return out << v.toStream();
847 case IValue::Tag::GenericDict:
848 return printDict(out, v.toGenericDict(), formatter);
849 case IValue::Tag::PyObject: {
850 auto py_obj = v.toPyObject();
851 return out << "<PyObject at" << py_obj << ">";
852 }
853 case IValue::Tag::Generator:
854 return out << "Generator";
855 case IValue::Tag::Quantizer:
856 return out << "Quantizer";
857 case IValue::Tag::Object: {
858 // TODO we should attempt to call __str__ if the object defines it.
859 auto obj = v.toObject();
860 // print this out the way python would do it
861 return out << "<" << obj->name() << " object at " << obj.get() << ">";
862 }
863 case IValue::Tag::Enum: {
864 auto enum_holder = v.toEnumHolder();
865 return out << "Enum<" << enum_holder->unqualifiedClassName() << "." <<
866 enum_holder->name() << ">";
867 }
868
869 }
870 return out << "<Invalid IValue tag=" << std::to_string(static_cast<uint32_t>(v.tag)) << ">";
871 }
872
873 #undef TORCH_FORALL_TAGS
874
dump() const875 void IValue::dump() const {
876 std::cout << *this << "\n";
877 }
878
type() const879 std::shared_ptr<ClassType> ivalue::Object::type() const {
880 return type_.type_->expect<ClassType>();
881 }
882
create(ClassTypePtr classType,size_t numSlots)883 c10::intrusive_ptr<ivalue::Object> ivalue::Object::create(
884 ClassTypePtr classType, size_t numSlots) {
885 return ivalue::Object::create(
886 StrongTypePtr(nullptr, std::move(classType)), numSlots);
887 }
888
deepcopy(std::optional<at::Device> device) const889 IValue IValue::deepcopy(std::optional<at::Device> device) const {
890 IValue::HashIdentityIValueMap memo;
891 return deepcopy(memo, device);
892 }
893
// Memoized deep copy. `memo` maps already-copied values (by identity) to
// their copies so shared sub-values stay shared in the result. Tensors are
// cloned (or moved to `device` when given and the source isn't a meta
// tensor); containers recurse; objects prefer __getstate__/__setstate__
// round-tripping when the class defines both; immutable scalars are copied
// by value. Unsupported tags raise.
IValue IValue::deepcopy(
    IValue::HashIdentityIValueMap& memo,
    std::optional<at::Device> device) const {
  if (memo.count(*this)) {
    return memo.at(*this);
  }
  IValue copy;
  switch(tag) {
    case IValue::Tag::Tensor: {
      const at::Tensor& src_tensor = toTensor();
      copy = device.has_value() && !src_tensor.device().is_meta()
          ? IValue(src_tensor.to(*device))
          : IValue(src_tensor.clone());
    } break;
    case IValue::Tag::Tuple: {
      std::vector<IValue> copied_tuple;
      for (const auto& e : toTupleRef().elements()) {
        copied_tuple.emplace_back(e.deepcopy(memo, device));
      }
      copy = IValue(ivalue::Tuple::create(std::move(copied_tuple)));
    }
      break;
    case IValue::Tag::GenericList: {
      auto list = toList();
      auto copied_list = c10::impl::GenericList(list.elementType());
      for (IValue v : list) {
        copied_list.push_back(v.deepcopy(memo, device));
      }
      copy = IValue(copied_list);
    }
      break;
    case IValue::Tag::GenericDict: {
      auto dict = toGenericDict();
      auto copied_dict = c10::impl::GenericDict(dict.keyType(), dict.valueType());
      for (const auto& entry : dict) {
        copied_dict.insert(
            entry.key().deepcopy(memo, device),
            entry.value().deepcopy(memo, device));
      }
      copy = IValue(copied_dict);
    }
      break;
    case IValue::Tag::Object: {
      auto class_type = type()->expect<ClassType>();
      if (class_type->hasMethod("__getstate__") &&
          class_type->hasMethod("__setstate__")) {
        // Serialize-then-deserialize through the class's own state methods,
        // which is how pickling copies such objects.
        copy = ivalue::Object::create(
            c10::StrongTypePtr(class_type->compilation_unit(), type()),
            class_type->numAttributes());
        auto state = class_type->getMethod("__getstate__")({*this});
        class_type->getMethod("__setstate__")({copy, std::move(state)});
      } else {
        copy = IValue(toObject()->deepcopy(memo, device));
      }
    } break;
    case IValue::Tag::Enum: {
      auto enum_holder = toEnumHolder();
      copy = IValue(c10::make_intrusive<ivalue::EnumHolder>(
          enum_holder->type(),
          enum_holder->name(),
          enum_holder->value().deepcopy(memo, device)));
    } break;
    case IValue::Tag::String:
    case IValue::Tag::None:
    case IValue::Tag::Double:
    case IValue::Tag::Int:
    case IValue::Tag::SymInt:
    case IValue::Tag::SymFloat:
    case IValue::Tag::SymBool:
    case IValue::Tag::Bool:
    case IValue::Tag::Device:
    case IValue::Tag::Generator:
    case IValue::Tag::Uninitialized: {
      // Immutable (or copy-by-value) kinds: a shallow copy is a deep copy.
      copy = *this;
    } break;
    default: {
      AT_ERROR("Can't deepcopy IValue with tag: ", tagKind());
    }
  }
  // NB: this doesn't work if an object contains itself, and it may
  // come up in the future when we expand the object system, we will
  // have a follow up PR to fix this when it becomes an issue.
  if (!isAliasOf(copy)) {
    memo[*this] = copy;
  }
  return copy;
}
981
// Unconditionally raises a TORCH_CHECK failure stating the tag actually held
// by this IValue; called from the Tensor accessors when the tag is not Tensor.
void IValue::reportToTensorTypeError() const {
  TORCH_CHECK(false, "Expected Tensor but got ", tagKind());
}
985
name() const986 std::string ivalue::Object::name() const {
987 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
988 return type()->name()->qualifiedName();
989 }
990
getAttr(const std::string & name) const991 IValue ivalue::Object::getAttr(const std::string& name) const {
992 const size_t slot = type()->getAttributeSlot(name);
993 return getSlot(slot);
994 }
995
setAttr(const std::string & name,IValue v)996 void ivalue::Object::setAttr(const std::string& name, IValue v) {
997 const size_t slot = type()->getAttributeSlot(name);
998 setSlot(slot, std::move(v));
999 }
1000
unsafeRemoveAttr(const std::string & name)1001 void ivalue::Object::unsafeRemoveAttr(const std::string& name) {
1002 const size_t slot = type()->getAttributeSlot(name);
1003 unsafeRemoveSlot(slot);
1004 }
1005
resizeObject(size_t slot)1006 void ivalue::Object::resizeObject(size_t slot) {
1007 AT_ASSERT(slot < type()->numAttributes());
1008 slots_.resize(type()->numAttributes());
1009 }
1010
1011
copy() const1012 c10::intrusive_ptr<ivalue::Object> ivalue::Object::copy() const {
1013 auto object = ivalue::Object::create(type_, type()->numAttributes());
1014 for (const auto i : c10::irange(slots_.size())) {
1015 object->setSlot(i, slots_[i]);
1016 }
1017 return object;
1018 }
1019
copy_to_weak_compilation_ref() const1020 c10::intrusive_ptr<ivalue::Object> ivalue::Object::copy_to_weak_compilation_ref() const {
1021 auto object = ivalue::Object::create(
1022 WeakOrStrongTypePtr(type_.asWeakTypePtr()), type()->numAttributes());
1023 for (const auto i : c10::irange(slots_.size())) {
1024 object->setSlot(i, slots_[i]);
1025 }
1026 return object;
1027 }
1028
// Entry point for deep-copying an Object: starts a fresh identity memo table
// (so aliasing within this object graph is preserved in the copy) and
// delegates to the memoized overload below.
c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
    std::optional<at::Device> device) const {
  IValue::HashIdentityIValueMap memo;
  return deepcopy(memo, device);
}
1034
deepcopy(IValue::HashIdentityIValueMap & memo,std::optional<at::Device> device) const1035 c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
1036 IValue::HashIdentityIValueMap& memo,
1037 std::optional<at::Device> device) const {
1038 auto cu = type_.cu_;
1039 auto object = ivalue::Object::create(WeakOrStrongTypePtr(type_.cu_, type_.type_), type()->numAttributes());
1040 for (const auto i : c10::irange(slots_.size())) {
1041 if (*slots_[i].type() == *c10::TypeFactory::get<CapsuleType>()) {
1042 // If we've gotten here, it means that we have *not* copied this
1043 // class via __getstate__ and __setstate__. That fact and the
1044 // fact that we have a Capsule attribute mean that this is a
1045 // custom C++ class without serialization methods defined.
1046 std::stringstream err;
1047 err << "Cannot serialize custom bound C++ class";
1048 if (auto qualname = type()->name()) {
1049 err << " " << qualname->qualifiedName();
1050 }
1051 err << ". Please define serialization methods via def_pickle() for "
1052 "this class.";
1053 AT_ERROR(err.str());
1054 }
1055 object->setSlot(i, slots_[i].deepcopy(memo, device));
1056 }
1057 return object;
1058 }
1059
// A StrongTypePtr owns (keeps alive) its compilation unit via shared_ptr.
// The type pointer must be non-null; the compilation unit may be null for
// types that do not belong to one.
StrongTypePtr::StrongTypePtr(
    std::shared_ptr<torch::jit::CompilationUnit> cu,
    TypePtr type) : cu_(std::move(cu)), type_(std::move(type)) {
  TORCH_INTERNAL_ASSERT(type_);
}
1065
// A WeakTypePtr references its compilation unit without keeping it alive.
// Unlike StrongTypePtr, a null type is permitted here (no assert).
WeakTypePtr::WeakTypePtr(
    std::weak_ptr<torch::jit::CompilationUnit> cu,
    TypePtr type) : cu_(std::move(cu)), type_(std::move(type)) {}
1069
asWeakTypePtr() const1070 WeakTypePtr WeakOrStrongTypePtr::asWeakTypePtr() const {
1071 if (!holds_strong_ref()) {
1072 return WeakTypePtr(cu_.getWeakRefOrThrow(), type_);
1073 } else {
1074 std::weak_ptr<torch::jit::CompilationUnit> weak_cu =
1075 cu_.getStrongRefOrThrow();
1076 return WeakTypePtr(std::move(weak_cu), type_);
1077 }
1078 }
1079
1080 // Needs to be in this .cpp file to access the full definition of PyObjectHolder
// Collects weak references to every StorageImpl reachable from `value`.
// Used by Future to know which storages a value touches (e.g. for recording
// CUDA events) without extending their lifetimes.
// Needs to be in this .cpp file to access the full definition of PyObjectHolder
std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> ivalue::Future::extractStorages(
    const at::IValue& value) {
  std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> weakStorageImpls;
  // getSubValues works poorly on Python objects: it only works if they can be
  // converted to a "regular" IValue type hence, for example, it doesn't support
  // custom subclasses. Thus, instead, we extract the tensors through pickling.
  if (value.isPyObject()) {
    std::vector<at::Tensor> tensors =
        value.toPyObjectHolder()->extractTensors();
    // First pass: count storages so the vector is allocated exactly once.
    size_t num_storages = 0;
    for (const at::Tensor& tensor : tensors) {
      if (tensor.is_sparse()) {
        // Sparse tensor is indices and values. Both are tensors
        // and contain storage. Therefore num_storages needs to be
        // incremented by 2.
        num_storages += 2;
      } else {
        // A dense/strided tensor contains 1 storage.
        num_storages += 1;
      }
    }
    weakStorageImpls.reserve(num_storages);
    // Second pass: actually collect the weak storage refs.
    for (const at::Tensor& tensor : tensors) {
      if (tensor.is_sparse()) {
        // Sparse tensor is indices and values. Both are tensors
        // and contain storage.
        // TODO (rohan-varma): for tensors created with at::sparse_coo_tensor held
        // in a python object, this might need a coalesce().
        weakStorageImpls.emplace_back(tensor.indices().storage().getWeakStorageImpl());
        weakStorageImpls.emplace_back(tensor.values().storage().getWeakStorageImpl());
      } else {
        // A dense/strided tensor contains 1 storage
        weakStorageImpls.emplace_back(tensor.storage().getWeakStorageImpl());
      }
    }
  } else {
    at::IValue::HashAliasedIValues sub_values;
    // Prefer getSubValues() over visit() as the latter is a silent no-op for
    // some unsupported types, whereas the former at least fails loudly.
    value.getSubValues(sub_values);
    for (const at::IValue& sub_value : sub_values) {
      if (sub_value.isTensor()) {
        auto const & tens = sub_value.toTensor();
        if (tens.is_sparse()) {
          // sparse tensors have 2 storages! One for indices one for values
          // NOTE(review): this branch coalesces first, unlike the PyObject
          // path above (see the TODO there) — confirm the asymmetry is
          // intentional.
          auto coalesced = tens.coalesce();
          weakStorageImpls.emplace_back(coalesced.indices().storage().getWeakStorageImpl());
          weakStorageImpls.emplace_back(coalesced.values().storage().getWeakStorageImpl());
        } else {
          weakStorageImpls.emplace_back(tens.storage().getWeakStorageImpl());
        }
      }
    }
  }
  return weakStorageImpls;
}
1137
// Returns a Future that completes once ALL futures in `srcs` have completed,
// or errors as soon as any source future errors. The result value is the
// list of source futures itself (callers then read each source's value).
TORCH_API intrusive_ptr<ivalue::Future> collectAll(
    const List<intrusive_ptr<ivalue::Future>>& srcs) {
  // Shared state: kept alive by the shared_ptr captured in each callback
  // until the last source future has fired.
  struct Ctx {
    explicit Ctx(const List<intrusive_ptr<ivalue::Future>>& srcs)
        : remaining(srcs.size()),
          srcFutures(srcs),
          asIvalue(srcFutures),
          // No need to pass devices, because dstFuture won't directly contain
          // the value, it will contain the srcFutures (which have no DataPtrs).
          dstFuture(make_intrusive<ivalue::Future>(asIvalue.type())) {}
    std::atomic<size_t> remaining{0}; // count of not-yet-completed sources
    List<intrusive_ptr<ivalue::Future>> srcFutures;
    IValue asIvalue;
    intrusive_ptr<ivalue::Future> dstFuture;
  };

  auto ctx = std::make_shared<Ctx>(srcs);
  if (ctx->srcFutures.empty()) {
    // Nothing to wait for: complete immediately with the empty list.
    ctx->dstFuture->markCompleted(ctx->asIvalue);
  } else {
    for (const auto i : c10::irange(ctx->srcFutures.size())) {

      std::function<void(ivalue::Future&)> func = [ctx](ivalue::Future& fut) {
        // Set error and exit early if encountered.
        if (fut.hasError()) {
          ctx->dstFuture->setErrorIfNeeded(fut.exception_ptr());
          return;
        }

        // The last source to complete successfully marks the destination
        // done; the completed() check guards against a racing error having
        // already completed it.
        if (--ctx->remaining == 0 && !ctx->dstFuture->completed()) {
          // No need to pass DataPtrs, because dstFuture won't directly contain
          // the value, it will contain the srcFutures (which have no DataPtrs).
          ctx->dstFuture->markCompleted(ctx->asIvalue);
        }
      };
      ctx->srcFutures.get(i)->addCallback(func);
    }
  }
  return ctx->dstFuture;
}
1178
1179 namespace {
1180
1181 #ifndef STRIP_ERROR_MESSAGES
formatSetOfDevices(const std::vector<c10::Device> & devices)1182 std::string formatSetOfDevices(const std::vector<c10::Device>& devices) {
1183 std::ostringstream oss;
1184 std::copy(
1185 devices.begin(),
1186 devices.end(),
1187 std::ostream_iterator<c10::Device>(oss, ", "));
1188 return oss.str();
1189 }
1190 #endif
1191
1192 }
1193
// Returns a Future that adopts the outcome (value or error) of whichever
// future in `srcs` completes FIRST. All sources must share the same element
// type and device set. An empty input yields an already-completed None future.
TORCH_API intrusive_ptr<ivalue::Future> collectAny(
    const List<intrusive_ptr<ivalue::Future>>& srcs) {
  if (srcs.empty()) {
    auto res = make_intrusive<ivalue::Future>(NoneType::get());
    res->markCompleted();
    return res;
  }
  // Validate homogeneity against the first source; if any source is already
  // complete we can return it directly without allocating anything.
  const TypePtr& typePtr = srcs.get(0)->elementType();
  const std::vector<c10::Device>& devices = srcs.get(0)->devices();
  for (const auto i : c10::irange(srcs.size())) {
    if (srcs.get(i)->completed()) {
      return srcs.get(i);
    }
    TORCH_CHECK_TYPE(
        i == 0 || (*typePtr == *srcs.get(i)->elementType()),
        "Expected all futures to have the same type, but found ", *typePtr,
        " in position 0 and ", *srcs.get(i)->elementType(), " in position ", i);
    TORCH_CHECK_VALUE(
        i == 0 || (devices == srcs.get(i)->devices()),
        "Expected all futures to have the same devices, but found ",
        formatSetOfDevices(devices), " in position 0 and ",
        formatSetOfDevices(srcs.get(i)->devices()), " in position ", i);
  }
  // Shared state for the race; `done` ensures only the first completion wins.
  struct Ctx {
    explicit Ctx(
        const List<intrusive_ptr<ivalue::Future>>& srcs,
        TypePtr typePtr,
        std::vector<c10::Device> devices)
        : srcFutures(srcs),
          dstFuture(make_intrusive<ivalue::Future>(std::move(typePtr), std::move(devices))) {}
    std::atomic<bool> done{false};
    List<intrusive_ptr<ivalue::Future>> srcFutures;
    intrusive_ptr<ivalue::Future> dstFuture;
  };
  auto ctx = std::make_shared<Ctx>(srcs, typePtr, devices);
  std::function<void(ivalue::Future&)> func = [ctx](ivalue::Future& src) {
    // exchange() makes exactly one callback the winner.
    if (!ctx->done.exchange(true)) {
      intrusive_ptr<ivalue::Future> dst = ctx->dstFuture;
      ctx->dstFuture.reset(); // Once future is satisfied, remove refs.
      ctx->srcFutures =
          List<intrusive_ptr<ivalue::Future>>(ctx->srcFutures.elementType());
      if (src.hasError()) {
        dst->setError(src.exception_ptr());
      } else {
        dst->markCompleted(src.constValue(), src.storages());
      }
    }
  };
  for (const auto i : c10::irange(ctx->srcFutures.size())) {
    ctx->srcFutures.get(i)->addCallback(func);
  }
  return ctx->dstFuture;
}
1247
1248 } // namespace c10
1249