xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/ivalue.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/Dict.h>
2 #include <ATen/core/Formatting.h>
3 #include <ATen/core/class_type.h>
4 #include <ATen/core/enum_type.h>
5 #include <ATen/core/function.h>
6 #include <ATen/core/ivalue.h>
7 #include <ATen/core/jit_type.h>
8 #include <ATen/core/stack.h>
9 #include <ATen/core/type_factory.h>
10 #include <c10/util/StringUtil.h>
11 #include <c10/util/hash.h>
12 #include <c10/util/irange.h>
13 #include <cmath>
14 #include <iostream>
15 #include <utility>
16 
17 namespace c10 {
_fastEqualsForContainer(const IValue & lhs,const IValue & rhs)18 bool _fastEqualsForContainer(const IValue& lhs, const IValue& rhs) {
19   if (lhs.is(rhs)) {
20     // Like Python, for containers we consider identity equality to be
21     // sufficient but not necessary for value equality
22     return true;
23   }
24   return lhs == rhs;
25 }
26 
27 namespace ivalue {
28 
29 // This is in ivalue.cpp because we need to access Type::annotation_str, which
30 // is declared in jit_type.h
checkCustomClassType(const ClassType * expected_type,const Type * actual_type)31 void checkCustomClassType(const ClassType* expected_type, const Type* actual_type) {
32   // NB: doing pointer comparison here
33   // If in the future there ever arises a need to call operator== on custom class
34   // Type's, this needs to be changed!
35   TORCH_CHECK(actual_type == static_cast<const Type*>(expected_type),
36               "Tried to convert an IValue of type ",
37               actual_type ? actual_type->repr_str() : std::string("*NULL*"),
38               " to custom class type ",
39               expected_type ? expected_type->repr_str() : std::string("*NULL*"));
40 }
41 
create(std::string str_)42 TORCH_API c10::intrusive_ptr<ConstantString> ConstantString::create(
43     std::string str_) {
44   return c10::make_intrusive<ConstantString>(std::move(str_));
45 }
46 
create(c10::string_view str_)47 TORCH_API c10::intrusive_ptr<ConstantString> ConstantString::create(
48     c10::string_view str_) {
49   return c10::make_intrusive<ConstantString>(std::string(str_));
50 }
51 
create(const char * str_)52 TORCH_API c10::intrusive_ptr<ConstantString> ConstantString::create(
53     const char* str_) {
54   return c10::make_intrusive<ConstantString>(std::string(str_));
55 }
56 
operator ==(const ivalue::Tuple & lhs,const ivalue::Tuple & rhs)57 bool operator==(const ivalue::Tuple& lhs, const ivalue::Tuple& rhs) {
58   return lhs.size() == rhs.size() &&
59       // see [container equality]
60       std::equal(
61              lhs.elements().cbegin(),
62              lhs.elements().cend(),
63              rhs.elements().cbegin(),
64              _fastEqualsForContainer);
65 }
66 
operator <<(std::ostream & out,const ivalue::EnumHolder & v)67 std::ostream& operator<<(std::ostream& out, const ivalue::EnumHolder& v) {
68   out << v.qualifiedClassName() << "." << v.name();
69   return out;
70 }
71 
operator ==(const ivalue::EnumHolder & lhs,const ivalue::EnumHolder & rhs)72 bool operator==(const ivalue::EnumHolder& lhs, const ivalue::EnumHolder& rhs) {
73   return lhs.name() == rhs.name() && *rhs.type() == *lhs.type();
74 }
75 
qualifiedClassName() const76 const std::string& ivalue::EnumHolder::qualifiedClassName() const {
77   return type_->qualifiedClassName().qualifiedName();
78 }
79 
unqualifiedClassName() const80 const std::string& ivalue::EnumHolder::unqualifiedClassName() const {
81   return type_->qualifiedClassName().name();
82 }
83 
84 } // namespace ivalue
85 
get(const IValue & v)86 c10::TypePtr IValue::TagType<c10::Type>::get(const IValue& v) {
87   switch (v.tag) {
88       case Tag::None:
89         return NoneType::get();
90       case Tag::Tensor:
91         return TensorType::create(v.toTensor());
92       case Tag::Storage:
93         return StorageType::get();
94       case Tag::Double:
95         return FloatType::get();
96       case Tag::ComplexDouble:
97         return ComplexType::get();
98       case Tag::Int:
99         return IntType::get();
100       case Tag::SymInt:
101         return c10::SymIntType::get();
102       case Tag::SymFloat:
103         return c10::SymFloatType::get();
104       case Tag::SymBool:
105         return c10::SymBoolType::get();
106       case Tag::Bool:
107         return BoolType::get();
108       case Tag::String:
109         return StringType::get();
110       case Tag::Blob:
111         return AnyType::get();
112       case Tag::GenericDict: {
113         auto d = v.toGenericDict();
114         return DictType::create(d.keyType(), d.valueType());
115       }
116       case Tag::GenericList:
117         return ListType::create(v.toList().elementType());
118       case Tag::Await:
119         return AwaitType::create(v.toAwait()->elementType());
120       case Tag::Future:
121         return FutureType::create(v.toFuture()->elementType());
122       case Tag::RRef:
123         return RRefType::create(v.toRRef()->type());
124       case Tag::Device:
125         return DeviceObjType::get();
126       case Tag::Stream:
127         return StreamObjType::get();
128       case Tag::Object:
129         return v.toObjectRef().type();
130       case Tag::PyObject:
131         return PyObjectType::get();
132       case Tag::Uninitialized:
133         return AnyType::get();
134       case Tag::Capsule:
135         return CapsuleType::get();
136       case Tag::Tuple:
137         return v.toTupleRef().type();
138       case Tag::Generator:
139         return GeneratorType::get();
140       case Tag::Quantizer:
141         return QuantizerType::get();
142       case Tag::Enum:
143         return v.toEnumHolder()->type();
144   }
145   // switch above is complete but this silences compiler warnings
146   TORCH_INTERNAL_ASSERT(false, "unhandled case in IValue::type()");
147 
148   // This static_assert has to go into some IValue member function; I
149   // chose this one. It's not in the class body because that's in
150   // ivalue.h, which is a very high-fanout header file and we want to
151   // minimize build time.
152   static_assert(
153       kNumTags <= 32,
154       "IValue::isIntrusivePtr needs to be updated because it assumes there are at most 32 tags");
155 }
156 
visit(const std::function<bool (const IValue &)> & visitor) const157 void IValue::visit(const std::function<bool (const IValue &)>& visitor) const {
158   if (visitor(*this)) {
159     // Shortcut
160     return;
161   }
162   switch (this->tag) {
163     case Tag::Tuple:
164     case Tag::GenericList: {
165       c10::ArrayRef<IValue> elems;
166       if (isTuple()) {
167         elems = this->toTupleRef().elements();
168       } else {
169         elems = this->toListRef();
170       }
171       for (auto& elem : elems) {
172         elem.visit(visitor);
173       }
174       break;
175     }
176     case Tag::GenericDict:
177       for (const auto& pair : this->toGenericDict()) {
178         pair.value().visit(visitor);
179         pair.key().visit(visitor);
180       }
181       break;
182     case Tag::Object: {
183       auto obj_type = type()->expect<ClassType>();
184       auto obj_value = toObject();
185       auto attributes = obj_type->getAttributes();
186       for (const auto& attr: attributes) {
187         auto attribute = obj_value->getAttr(attr.getName());
188         attribute.visit(visitor);
189       }
190       break;
191     }
192     case Tag::PyObject: {
193       c10::intrusive_ptr<at::ivalue::PyObjectHolder> py_obj = toPyObjectHolder();
194       auto match = py_obj->tryToInferType();
195       if (match.success()) {
196         auto contained_value = py_obj->toIValue(match.type());
197         contained_value.visit(visitor);
198       }
199       break;
200     }
201     default:
202       break;
203  }
204 }
205 
getSubValues(HashAliasedIValues & subValues) const206 void IValue::getSubValues(HashAliasedIValues& subValues) const {
207   switch (this->tag) {
208     case Tag::Tensor:
209       subValues.insert(*this);
210       return;
211     case Tag::Tuple:
212     case Tag::GenericList: {
213       subValues.insert(*this);
214       c10::ArrayRef<IValue> elems;
215       if (isTuple()) {
216         elems = this->toTupleRef().elements();
217       } else {
218         elems = this->toListRef();
219       }
220       for (auto& elem : elems) {
221         elem.getSubValues(subValues);
222       }
223       break;
224     }
225     case Tag::GenericDict:
226       subValues.insert(*this);
227       for (const auto& pair : this->toGenericDict()) {
228         pair.value().getSubValues(subValues);
229         pair.key().getSubValues(subValues);
230       }
231       break;
232     case Tag::Object: {
233       // Record Object IValue and its attributes.
234       subValues.insert(*this);
235       auto obj_type = type()->expect<ClassType>();
236       auto obj_value = toObject();
237       auto attributes = obj_type->getAttributes();
238       for (const auto& attr: attributes) {
239         auto attribute = obj_value->getAttr(attr.getName());
240         attribute.getSubValues(subValues);
241       }
242       break;
243     }
244     case Tag::PyObject: {
245       subValues.insert(*this);
246       c10::intrusive_ptr<at::ivalue::PyObjectHolder> py_obj = toPyObjectHolder();
247       auto match = py_obj->tryToInferType();
248       TORCH_CHECK_TYPE(match.success(),
249             "Cannot infer type of ", py_obj->toStr(), ": ", match.reason());
250       auto contained_value = py_obj->toIValue(match.type());
251       contained_value.getSubValues(subValues);
252       break;
253     }
254     case Tag::Future:
255     case Tag::Await:
256     case Tag::Device:
257     case Tag::Uninitialized:
258     case Tag::Capsule:
259       TORCH_CHECK_TYPE(
260           false, "Cannot inspect value of type ", this->tagKind());
261     default:
262       // don't record scalars.
263       break;
264   }
265 }
266 
overlaps(const IValue & rhs) const267 bool IValue::overlaps(const IValue& rhs) const {
268   HashAliasedIValues rhsSubValues, thisSubValues;
269   rhs.getSubValues(rhsSubValues);
270   getSubValues(thisSubValues);
271   for (auto& sub : thisSubValues) {
272     if (rhsSubValues.count(sub)) {
273       return true;
274     }
275   }
276   return false;
277 }
278 
operator !=(const IValue & lhs,const IValue & rhs)279 bool operator!=(const IValue& lhs, const IValue& rhs) {
280   return !(lhs == rhs);
281 }
282 
operator ==(const IValue & lhs,const IValue & rhs)283 bool operator==(const IValue& lhs, const IValue& rhs) {
284   IValue eq = lhs.equals(rhs);
285   if (eq.isBool()) {
286     return eq.toBool();
287   }
288   // The only case we don't return bool is for tensor comparison. In Python,
289   // `bool()` is called on the return value of `__eq__` if the return value is
290   // not a boolean. Mimic that behavior here.
291   TORCH_INTERNAL_ASSERT(eq.isTensor());
292   return eq.toTensor().is_nonzero();
293 }
294 
ptrEqual(const IValue & lhs,const IValue & rhs)295 bool IValue::ptrEqual(const IValue& lhs, const IValue& rhs) {
296   TORCH_INTERNAL_ASSERT(lhs.isIntrusivePtr());
297   TORCH_INTERNAL_ASSERT(rhs.isIntrusivePtr());
298   return lhs.tag == rhs.tag &&
299       lhs.payload.u.as_intrusive_ptr == rhs.payload.u.as_intrusive_ptr;
300 }
301 
equals(const IValue & rhs) const302 IValue IValue::equals(const IValue& rhs) const {
303   const IValue& lhs = *this;
304   switch (lhs.tag) {
305     case Tag::None:
306       // In Python you're not supposed to do this comparison apparently. Not
307       // sure if we should warn here or what
308       return rhs.isNone();
309     case Tag::Tensor: {
310       if (!rhs.isTensor()) {
311         return false;
312       }
313       return lhs.toTensor().eq(rhs.toTensor());
314     }
315     case Tag::Storage:
316       return rhs.isStorage() && lhs.toStorage().unsafeGetStorageImpl() == rhs.toStorage().unsafeGetStorageImpl();
317     case Tag::Double:
318       return rhs.isDouble() && lhs.toDouble() == rhs.toDouble();
319     case Tag::ComplexDouble:
320       return rhs.isComplexDouble() && lhs.toComplexDouble() == rhs.toComplexDouble();
321     case Tag::Int:
322       return rhs.isInt() && lhs.toInt() == rhs.toInt();
323     case Tag::SymInt:
324       return rhs.isSymInt() && lhs.toSymInt() == rhs.toSymInt();
325     case Tag::SymFloat:
326       return rhs.isSymFloat() && lhs.toSymFloat() == rhs.toSymFloat();
327     case Tag::SymBool:
328       return rhs.isSymBool() && lhs.toSymBool() == rhs.toSymBool();
329     case Tag::Bool:
330       return rhs.isBool() && lhs.toBool() == rhs.toBool();
331     case Tag::String:
332       return rhs.isString() && lhs.toStringRef() == rhs.toStringRef();
333     case Tag::GenericDict:
334       return rhs.isGenericDict() && lhs.toGenericDict() == rhs.toGenericDict();
335     case Tag::Tuple:
336       return rhs.isTuple() && *lhs.toTuple() == *rhs.toTuple();
337     case Tag::Stream:
338       return rhs.isStream() && lhs.toStream() == rhs.toStream();
339     case Tag::Device:
340       return rhs.isDevice() && lhs.toDevice() == rhs.toDevice();
341     case Tag::GenericList:
342       return rhs.isList() && lhs.toList() == rhs.toList();
343     case Tag::Blob:
344     case Tag::Future:
345     case Tag::Await:
346     case Tag::RRef:
347     case Tag::Object:
348     case Tag::PyObject:
349     case Tag::Capsule:
350     case Tag::Generator:
351     case Tag::Quantizer:
352       return ptrEqual(lhs, rhs);
353     case Tag::Enum:
354       return lhs.toEnumHolder()->is(*rhs.toEnumHolder());
355     case Tag::Uninitialized:
356       // Unitialized ivalues show up in no-ops when the compiler can prove a
357       // value will never be used. Just return false on any equality comparison.
358       return false;
359   }
360   // the above switch should be exhaustive
361   TORCH_INTERNAL_ASSERT(false, "we should never reach here")
362 }
363 
hash(const IValue & v)364 size_t IValue::hash(const IValue& v) {
365   switch (v.tag) {
366     case Tag::None:
367       return 0;
368     case Tag::Bool:
369       return c10::get_hash(v.payload.u.as_bool);
370     case Tag::Double:
371       return c10::get_hash(v.payload.u.as_double);
372     case Tag::Tensor:
373       // Tensor __hash__ is equivalent to `id()`, so take the pointer value of
374       // the tensor to emulate it
375       return c10::get_hash(v.payload.as_tensor.unsafeGetTensorImpl());
376     // NOLINTNEXTLINE(bugprone-branch-clone)
377     case Tag::Storage:
378       return c10::get_hash(v.payload.u.as_int);
379     case Tag::Int:
380       return c10::get_hash(v.payload.u.as_int);
381     // NB: these are technically strict aliasing violations
382     case Tag::SymInt:
383       return c10::get_hash(v.payload.u.as_int);
384     case Tag::SymFloat:
385       return c10::get_hash(v.payload.u.as_int);
386     case Tag::SymBool:
387       return c10::get_hash(v.payload.u.as_int);
388     case Tag::String:
389       return c10::get_hash(v.toStringRef());
390     case Tag::Tuple:
391       return c10::get_hash(*v.toTuple());
392     case Tag::Device:
393       return c10::get_hash(v.toDevice());
394     case Tag::GenericDict:
395     case Tag::GenericList:
396     case Tag::Blob:
397     case Tag::Future:
398     case Tag::Await:
399     case Tag::RRef:
400     case Tag::Object:
401     case Tag::PyObject:
402     case Tag::Capsule:
403     case Tag::Generator:
404     case Tag::Quantizer:
405     case Tag::ComplexDouble:
406     case Tag::Enum:
407     case Tag::Stream:
408     case Tag::Uninitialized:
409       throw std::runtime_error(
410           "unhashable type: '" + v.type()->repr_str() + "'");
411   }
412   // the above switch should be exhaustive
413   TORCH_INTERNAL_ASSERT(false, "we should never reach here")
414 }
415 
isUndefinedTensor(const IValue & iv)416 static bool isUndefinedTensor(const IValue& iv) {
417   return iv.isTensor() && !iv.toTensor().defined();
418 }
419 
is(const IValue & rhs) const420 bool IValue::is(const IValue& rhs) const {
421   const IValue& lhs = *this;
422   // Special handling for undefined tensors:
423   // 1. Undefined_tensor is None and vice versa.
424   if ((isUndefinedTensor(lhs) && rhs.isNone()) ||
425       (lhs.isNone() && isUndefinedTensor(rhs))) {
426     return true;
427   }
428   // 2. Undefined_tensor is Undefined_tensor.
429   if (isUndefinedTensor(lhs) && isUndefinedTensor(rhs)) {
430     return true;
431   }
432 
433   if (lhs.isTensor()) {
434     // Use the standard way of comparing two tensors for identity
435     return rhs.isTensor() && lhs.toTensor().is_same(rhs.toTensor());
436   }
437 
438   if (lhs.isIntrusivePtr()) {
439     return rhs.isIntrusivePtr() && ptrEqual(lhs, rhs);
440   }
441   return lhs == rhs;
442 }
443 
444 template <typename T>
isListOf() const445 inline bool IValue::isListOf() const {
446   // note: avoids calling type() to avoid extra referencing counting for the returned type.
447   if (!isList()) {
448     return false;
449   }
450   const auto& ty = static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
451   if (ty->kind() == T::Kind) {
452     return true;
453   }
454   return *ty == *TypeFactory::get<T>();
455 }
456 
isDoubleList() const457 bool IValue::isDoubleList() const {
458   return isListOf<c10::FloatType>();
459 }
460 
isComplexDoubleList() const461 bool IValue::isComplexDoubleList() const {
462   return isListOf<c10::ComplexType>();
463 }
464 
isTensorList() const465 bool IValue::isTensorList() const {
466   return isListOf<c10::TensorType>();
467 }
468 
isOptionalTensorList() const469 bool IValue::isOptionalTensorList() const {
470   if (!isList()) {
471     return false;
472   }
473   const auto& ty = static_cast<detail::ListImpl*>(payload.u.as_intrusive_ptr)->elementType;
474   const auto& expected_ty = c10::getTypePtr<std::optional<at::Tensor>>();
475   return expected_ty == ty;
476 }
477 
isIntList() const478 bool IValue::isIntList() const {
479   return isListOf<c10::IntType>();
480 }
481 
isSymIntList() const482 bool IValue::isSymIntList() const {
483   return isListOf<c10::SymIntType>();
484 }
485 
isBoolList() const486 bool IValue::isBoolList() const {
487   return isListOf<c10::BoolType>();
488 }
489 
490 namespace {
491 
492 using IValueFormatter = std::function<void(std::ostream&, const IValue&)>;
493 
494 template <class T>
printList(std::ostream & out,const T & list,const std::string & start,const std::string & finish,const IValueFormatter & formatter)495 std::ostream& printList(
496     std::ostream& out,
497     const T& list,
498     const std::string& start,
499     const std::string& finish,
500     const IValueFormatter& formatter) {
501   out << start;
502   for (const auto i : c10::irange(list.size())) {
503     if (i > 0) {
504       out << ", ";
505     }
506     formatter(out, IValue(list[i]));
507   }
508   out << finish;
509   return out;
510 }
511 
512 // Properly disambiguate the type of an empty list
printMaybeAnnotatedList(std::ostream & out,const IValue & the_list,const IValueFormatter & formatter)513 std::ostream& printMaybeAnnotatedList(
514     std::ostream& out,
515     const IValue& the_list,
516     const IValueFormatter& formatter) {
517   auto list_elem_type = the_list.type()->containedType(0);
518   if (the_list.toListRef().empty() ||
519       !elementTypeCanBeInferredFromMembers(list_elem_type)) {
520     out << "annotate(" << the_list.type<c10::Type>()->annotation_str() << ", ";
521     printList(out, the_list.toListRef(), "[", "]", formatter);
522     out << ")";
523     return out;
524   } else {
525     return printList(out, the_list.toListRef(), "[", "]", formatter);
526   }
527 }
528 
529 template <typename Dict>
printDict(std::ostream & out,const Dict & v,const IValueFormatter & formatter)530 std::ostream& printDict(
531     std::ostream& out,
532     const Dict& v,
533     const IValueFormatter& formatter) {
534   out << "{";
535 
536   bool first = true;
537   for (const auto& pair : v) {
538     if (!first) {
539       out << ", ";
540     }
541 
542     formatter(out, pair.key());
543     out << ": ";
544     formatter(out, pair.value());
545     first = false;
546   }
547 
548   out << "}";
549   return out;
550 }
551 }
552 
553 // Properly disambiguate the type of an empty dict
printMaybeAnnotatedDict(std::ostream & out,const IValue & the_dict,const IValueFormatter & formatter)554 static std::ostream& printMaybeAnnotatedDict(
555     std::ostream& out,
556     const IValue& the_dict,
557     const IValueFormatter& formatter) {
558   auto value_type = the_dict.type()->castRaw<DictType>()->getValueType();
559   if (the_dict.toGenericDict().empty() ||
560       !elementTypeCanBeInferredFromMembers(value_type)) {
561     out << "annotate(" << the_dict.type<c10::Type>()->annotation_str() << ",";
562     printDict(out, the_dict.toGenericDict(), formatter) << ")";
563   } else {
564     return printDict(out, the_dict.toGenericDict(), formatter);
565   }
566   return out;
567 }
568 
printComplex(std::ostream & out,const IValue & v)569 static std::ostream& printComplex(std::ostream & out, const IValue & v) {
570   c10::complex<double> d = v.toComplexDouble();
571   IValue real(d.real()), imag(std::abs(d.imag()));
572   auto sign = "";
573   if (d.imag() >= 0) {
574     sign = "+";
575   } else {
576     sign = "-";
577   }
578   return out << real << sign << imag << "j";
579 }
580 
repr(std::ostream & out,std::function<bool (std::ostream &,const IValue & v)> customFormatter) const581 std::ostream& IValue::repr(
582     std::ostream& out,
583     std::function<bool(std::ostream&, const IValue& v)>
584         customFormatter) const {
585   // First check if the caller has provided a custom formatter. Use that if possible.
586   if (customFormatter(out, *this)) {
587     return out;
588   }
589 
590   const IValue& v = *this;
591   // continue to use custom formatter in recursion
592   auto formatter = [&](std::ostream& out, const IValue& input) {
593     input.repr(out, customFormatter);
594   };
595   switch (v.tag) {
596     case IValue::Tag::None:
597       return out << v.toNone();
598     case IValue::Tag::Double: {
599       double d = v.toDouble();
600       int c = std::fpclassify(d);
601       if ((c == FP_NORMAL || c == FP_ZERO ) && std::abs(d) < 1e10) {
602         int64_t i = int64_t(d);
603         if (double(i) == d) {
604           // -0.0 (signed zero) needs to be parsed as -0.
605           if (i == 0 && std::signbit(d)) {
606             return out << "-" << i << ".";
607           }
608           return out << i << ".";
609         }
610       }
611       auto orig_prec = out.precision();
612       return out << std::setprecision(std::numeric_limits<double>::max_digits10)
613                  << d << std::setprecision(static_cast<int>(orig_prec));
614     }
615     case IValue::Tag::ComplexDouble: {
616       return printComplex(out, v);
617     }
618     case IValue::Tag::Int:
619       return out << v.toInt();
620     case IValue::Tag::SymInt:
621       return out << v.toSymInt();
622     case IValue::Tag::SymFloat:
623       return out << v.toSymFloat();
624     case IValue::Tag::SymBool:
625       return out << v.toSymBool();
626     case IValue::Tag::Bool:
627       return out << (v.toBool() ? "True" : "False");
628     case IValue::Tag::Tuple: {
629       const auto& elements = v.toTupleRef().elements();
630       const auto& finish = elements.size() == 1 ? ",)" : ")";
631       return printList(out, elements, "(", finish, formatter);
632     }
633     case IValue::Tag::String:
634       c10::printQuotedString(out, v.toStringRef());
635       return out;
636     case IValue::Tag::GenericList: {
637       return printMaybeAnnotatedList(out, *this, formatter);
638     }
639     case IValue::Tag::Device: {
640       std::stringstream device_stream;
641       device_stream << v.toDevice();
642       out << "torch.device(";
643       c10::printQuotedString(out, device_stream.str());
644       return out << ")";
645     }
646     case IValue::Tag::Generator: {
647       auto generator = v.toGenerator();
648       out << "torch.Generator(device=";
649       c10::printQuotedString(out, generator.device().str());
650       out << ", seed=" << generator.current_seed() << ")";
651       return out;
652     }
653     case IValue::Tag::GenericDict:
654       return printMaybeAnnotatedDict(out, v, formatter);
655     case IValue::Tag::Enum: {
656       auto enum_holder = v.toEnumHolder();
657       return out << enum_holder->qualifiedClassName() << "." <<
658           enum_holder->name();
659     }
660     case IValue::Tag::Object: {
661       TORCH_INTERNAL_ASSERT(false, "repr() not defined on: ", v.tagKind(), ". Perhaps you've frozen a module with custom classes?");
662     }
663     default:
664       TORCH_INTERNAL_ASSERT(false, "repr() not defined on: ", v.tagKind());
665   }
666 }
667 
simpleClassTypeArg(const Argument & arg,const ClassTypePtr & type)668 static bool simpleClassTypeArg(const Argument& arg, const ClassTypePtr& type) {
669   return arg.type() == type && !arg.kwarg_only() && !arg.default_value();
670 }
671 
checkObjectSortSchema(const c10::ClassTypePtr & t,std::stringstream & why_not)672 torch::jit::Function* checkObjectSortSchema(const c10::ClassTypePtr& t, std::stringstream& why_not) {
673   if (auto method = t->findMethod("__lt__")) {
674       const auto& lt_schema = method->getSchema();
675       const auto& schema_args = lt_schema.arguments();
676       bool error =
677           (schema_args.size() != 2 ||
678            !simpleClassTypeArg(schema_args[0], t) ||
679            !simpleClassTypeArg(schema_args[1], t) ||
680            lt_schema.returns().size() != 1 ||
681            lt_schema.returns()[0].type() != BoolType::get());
682       if (!error) {
683         return method;
684       }
685     }
686 
687     why_not << "To sort a list of " << t->repr_str()
688             << " it must define a "
689             << "__lt__ method with two inputs of type "
690             << t->repr_str() << " that "
691             << "returns a bool";
692     return nullptr;
693 }
694 
getLessThanComparator(const IValue & v)695 IValueComparator getLessThanComparator(const IValue& v) {
696   if (v.isTensor()) {
697       return [](const IValue& a, const IValue& b) {
698         return a.toTensor().lt(b.toTensor()).is_nonzero();
699       };
700   }
701 
702   if (v.isDouble()) {
703       return [](const IValue& a, const IValue& b) {
704         return a.toDouble() < b.toDouble();
705       };
706   }
707 
708   if (v.isInt()) {
709       return [](const IValue& a, const IValue& b) {
710         return a.toInt() < b.toInt();
711       };
712   }
713 
714   if (v.isBool()) {
715       return [](const IValue& a, const IValue& b) {
716         return a.toBool() == false && b.toBool() == true;
717       };
718   }
719 
720   if (v.isString()) {
721       return [](const IValue& a, const IValue& b) {
722        return a.toStringRef() < b.toStringRef();
723       };
724   }
725 
726   if (v.isTuple()) {
727       const auto& elements = v.toTupleRef().elements();
728       size_t n = elements.size();
729 
730       std::vector<IValueComparator> elements_lts;
731       elements_lts.reserve(n);
732       for (const auto i : c10::irange(n)) {
733         elements_lts.push_back(getLessThanComparator(elements[i]));
734       }
735 
736       return [elements_lts=std::move(elements_lts), n](const IValue& a, const IValue& b) {
737         const auto& a_elements = a.toTupleRef().elements();
738         const auto& b_elements = b.toTupleRef().elements();
739 
740         for (const auto i : c10::irange(n)) {
741           if (elements_lts[i](a_elements[i], b_elements[i])) {
742             return true;
743           }
744           if (a_elements[i] == b_elements[i]) {
745             continue;
746           }
747           return false;
748         }
749         // Reaching here means two tuples are equal.
750         return false;
751       };
752   }
753 
754   if (v.isObject()) {
755     std::stringstream why_not;
756     torch::jit::Function* lt_func =
757         checkObjectSortSchema(v.type()->expect<ClassType>(), why_not);
758     if (!lt_func) {
759       AT_ERROR(why_not.str());
760     }
761 
762     return [lt_func](const IValue& a, const IValue& b) {
763       // Quick pass to satisfy "strict weak ordering" requirement
764       if (a.is(b)) {
765         return false;
766       }
767       torch::jit::Stack sort_stack;
768       sort_stack.push_back(a);
769       sort_stack.push_back(b);
770       lt_func->run(sort_stack);
771       return torch::jit::pop(sort_stack).toBool();
772     };
773   }
774 
775   AT_ERROR("IValues of type: ", v.tagKind(), " are not comparable");
776 }
777 
getGreaterThanComparator(const IValue & v)778 IValueComparator getGreaterThanComparator(const IValue& v) {
779   auto lt = getLessThanComparator(v);
780   return [lt = std::move(lt)](const IValue& a, const IValue& b) {
781     return lt(b, a);  // gt(a, b) === lt(b, a)
782   };
783 }
784 
operator <<(std::ostream & out,const IValue & v)785 std::ostream& operator<<(std::ostream & out, const IValue & v) {
786   auto formatter = [&](std::ostream& out, const IValue& v) {
787     out << v;
788   };
789   switch(v.tag) {
790     case IValue::Tag::None:
791       return out << v.toNone();
792     case IValue::Tag::Tensor:
793       return out << v.toTensor();
794     case IValue::Tag::Storage:
795       return out << v.toStorage().unsafeGetStorageImpl();
796     case IValue::Tag::Double: {
797       double d = v.toDouble();
798       int c = std::fpclassify(d);
799       if (c == FP_NORMAL || c == FP_ZERO) {
800         int64_t i = int64_t(d);
801         if (double(i) == d) {
802           return out << i << ".";
803         }
804       }
805       auto orig_prec = out.precision();
806       return out
807         << std::setprecision(std::numeric_limits<double>::max_digits10)
808         << v.toDouble()
809         << std::setprecision(static_cast<int>(orig_prec));
810     } case IValue::Tag::ComplexDouble: {
811       return printComplex(out, v);
812     } case IValue::Tag::Int:
813       return out << v.toInt();
814     case IValue::Tag::SymInt:
815       return out << v.toSymInt();
816     case IValue::Tag::SymFloat:
817       return out << v.toSymFloat();
818     case IValue::Tag::SymBool:
819       return out << v.toSymBool();
820     case IValue::Tag::Bool:
821       return out << (v.toBool() ? "True" : "False");
822     case IValue::Tag::Tuple: {
823       const auto& elements = v.toTupleRef().elements();
824       const auto& finish = elements.size() == 1 ? ",)" : ")";
825       return printList(out, elements, "(", finish, formatter);
826     }
827     case IValue::Tag::String:
828       return out << v.toStringRef();
829     case IValue::Tag::Blob:
830       return out << *v.toBlob();
831     case IValue::Tag::Capsule:
832       return out << "Capsule";
833     case IValue::Tag::GenericList:
834       return printList(out, v.toList(), "[", "]", formatter);
835     case IValue::Tag::RRef:
836       return out << "RRef";
837     case IValue::Tag::Future:
838       return out << "Future";
839     case IValue::Tag::Await:
840       return out << "Await";
841     case IValue::Tag::Uninitialized:
842       return out << "Uninitialized";
843     case IValue::Tag::Device:
844       return out << v.toDevice();
845     case IValue::Tag::Stream:
846       return out << v.toStream();
847     case IValue::Tag::GenericDict:
848       return printDict(out, v.toGenericDict(), formatter);
849     case IValue::Tag::PyObject: {
850       auto py_obj = v.toPyObject();
851       return out << "<PyObject at" << py_obj << ">";
852     }
853     case IValue::Tag::Generator:
854       return out << "Generator";
855     case IValue::Tag::Quantizer:
856       return out << "Quantizer";
857     case IValue::Tag::Object: {
858       // TODO we should attempt to call __str__ if the object defines it.
859       auto obj = v.toObject();
860       // print this out the way python would do it
861       return out << "<" << obj->name() << " object at " << obj.get() << ">";
862     }
863     case IValue::Tag::Enum: {
864       auto enum_holder = v.toEnumHolder();
865       return out << "Enum<" << enum_holder->unqualifiedClassName() << "." <<
866           enum_holder->name() << ">";
867     }
868 
869   }
870   return out << "<Invalid IValue tag=" << std::to_string(static_cast<uint32_t>(v.tag)) << ">";
871 }
872 
873 #undef TORCH_FORALL_TAGS
874 
dump() const875 void IValue::dump() const {
876   std::cout << *this << "\n";
877 }
878 
type() const879 std::shared_ptr<ClassType> ivalue::Object::type() const {
880   return type_.type_->expect<ClassType>();
881 }
882 
create(ClassTypePtr classType,size_t numSlots)883 c10::intrusive_ptr<ivalue::Object> ivalue::Object::create(
884     ClassTypePtr classType, size_t numSlots) {
885   return ivalue::Object::create(
886       StrongTypePtr(nullptr, std::move(classType)), numSlots);
887 }
888 
deepcopy(std::optional<at::Device> device) const889 IValue IValue::deepcopy(std::optional<at::Device> device) const {
890   IValue::HashIdentityIValueMap memo;
891   return deepcopy(memo, device);
892 }
893 
deepcopy(IValue::HashIdentityIValueMap & memo,std::optional<at::Device> device) const894 IValue IValue::deepcopy(
895     IValue::HashIdentityIValueMap& memo,
896     std::optional<at::Device> device) const {
897   if (memo.count(*this)) {
898     return memo.at(*this);
899   }
900   IValue copy;
901   switch(tag) {
902     case IValue::Tag::Tensor: {
903       const at::Tensor& src_tensor = toTensor();
904       copy = device.has_value() && !src_tensor.device().is_meta()
905           ? IValue(src_tensor.to(*device))
906           : IValue(src_tensor.clone());
907     } break;
908     case IValue::Tag::Tuple: {
909       std::vector<IValue> copied_tuple;
910       for (const auto& e : toTupleRef().elements()) {
911         copied_tuple.emplace_back(e.deepcopy(memo, device));
912       }
913       copy = IValue(ivalue::Tuple::create(std::move(copied_tuple)));
914     }
915       break;
916     case IValue::Tag::GenericList: {
917       auto list = toList();
918       auto copied_list = c10::impl::GenericList(list.elementType());
919       for (IValue v : list) {
920         copied_list.push_back(v.deepcopy(memo, device));
921       }
922       copy = IValue(copied_list);
923     }
924       break;
925     case IValue::Tag::GenericDict: {
926       auto dict = toGenericDict();
927       auto copied_dict = c10::impl::GenericDict(dict.keyType(), dict.valueType());
928       for (const auto& entry : dict) {
929         copied_dict.insert(
930             entry.key().deepcopy(memo, device),
931             entry.value().deepcopy(memo, device));
932       }
933       copy = IValue(copied_dict);
934     }
935       break;
936     case IValue::Tag::Object: {
937       auto class_type = type()->expect<ClassType>();
938       if (class_type->hasMethod("__getstate__") &&
939           class_type->hasMethod("__setstate__")) {
940         copy = ivalue::Object::create(
941             c10::StrongTypePtr(class_type->compilation_unit(), type()),
942             class_type->numAttributes());
943         auto state = class_type->getMethod("__getstate__")({*this});
944         class_type->getMethod("__setstate__")({copy, std::move(state)});
945       } else {
946         copy = IValue(toObject()->deepcopy(memo, device));
947       }
948     } break;
949     case IValue::Tag::Enum: {
950       auto enum_holder = toEnumHolder();
951       copy = IValue(c10::make_intrusive<ivalue::EnumHolder>(
952           enum_holder->type(),
953           enum_holder->name(),
954           enum_holder->value().deepcopy(memo, device)));
955     } break;
956     case IValue::Tag::String:
957     case IValue::Tag::None:
958     case IValue::Tag::Double:
959     case IValue::Tag::Int:
960     case IValue::Tag::SymInt:
961     case IValue::Tag::SymFloat:
962     case IValue::Tag::SymBool:
963     case IValue::Tag::Bool:
964     case IValue::Tag::Device:
965     case IValue::Tag::Generator:
966     case IValue::Tag::Uninitialized: {
967       copy = *this;
968     } break;
969     default: {
970       AT_ERROR("Can't deepcopy IValue with tag: ", tagKind());
971     }
972   }
973   // NB: this doesn't work if an object contains itself, and it may
974   // come up in the future when we expand the object system, we will
975   // have a follow up PR to fix this when it becomes an issue.
976   if (!isAliasOf(copy)) {
977     memo[*this] = copy;
978   }
979   return copy;
980 }
981 
reportToTensorTypeError() const982 void IValue::reportToTensorTypeError() const {
983   TORCH_CHECK(false, "Expected Tensor but got ", tagKind());
984 }
985 
name() const986 std::string ivalue::Object::name() const {
987   // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
988   return type()->name()->qualifiedName();
989 }
990 
getAttr(const std::string & name) const991 IValue ivalue::Object::getAttr(const std::string& name) const {
992   const size_t slot = type()->getAttributeSlot(name);
993   return getSlot(slot);
994 }
995 
setAttr(const std::string & name,IValue v)996 void ivalue::Object::setAttr(const std::string& name, IValue v) {
997   const size_t slot = type()->getAttributeSlot(name);
998   setSlot(slot, std::move(v));
999 }
1000 
unsafeRemoveAttr(const std::string & name)1001 void ivalue::Object::unsafeRemoveAttr(const std::string& name) {
1002   const size_t slot = type()->getAttributeSlot(name);
1003   unsafeRemoveSlot(slot);
1004 }
1005 
resizeObject(size_t slot)1006 void ivalue::Object::resizeObject(size_t slot) {
1007   AT_ASSERT(slot < type()->numAttributes());
1008   slots_.resize(type()->numAttributes());
1009 }
1010 
1011 
copy() const1012 c10::intrusive_ptr<ivalue::Object> ivalue::Object::copy() const {
1013   auto object = ivalue::Object::create(type_, type()->numAttributes());
1014   for (const auto i : c10::irange(slots_.size())) {
1015     object->setSlot(i, slots_[i]);
1016   }
1017   return object;
1018 }
1019 
copy_to_weak_compilation_ref() const1020 c10::intrusive_ptr<ivalue::Object> ivalue::Object::copy_to_weak_compilation_ref() const {
1021   auto object = ivalue::Object::create(
1022       WeakOrStrongTypePtr(type_.asWeakTypePtr()), type()->numAttributes());
1023   for (const auto i : c10::irange(slots_.size())) {
1024     object->setSlot(i, slots_[i]);
1025   }
1026   return object;
1027 }
1028 
deepcopy(std::optional<at::Device> device) const1029 c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
1030     std::optional<at::Device> device) const {
1031   IValue::HashIdentityIValueMap memo;
1032   return deepcopy(memo, device);
1033 }
1034 
deepcopy(IValue::HashIdentityIValueMap & memo,std::optional<at::Device> device) const1035 c10::intrusive_ptr<ivalue::Object> ivalue::Object::deepcopy(
1036     IValue::HashIdentityIValueMap& memo,
1037     std::optional<at::Device> device) const {
1038   auto cu = type_.cu_;
1039   auto object = ivalue::Object::create(WeakOrStrongTypePtr(type_.cu_, type_.type_), type()->numAttributes());
1040   for (const auto i : c10::irange(slots_.size())) {
1041     if (*slots_[i].type() == *c10::TypeFactory::get<CapsuleType>()) {
1042       // If we've gotten here, it means that we have *not* copied this
1043       // class via __getstate__ and __setstate__. That fact and the
1044       // fact that we have a Capsule attribute mean that this is a
1045       // custom C++ class without serialization methods defined.
1046       std::stringstream err;
1047       err << "Cannot serialize custom bound C++ class";
1048       if (auto qualname = type()->name()) {
1049         err << " " << qualname->qualifiedName();
1050       }
1051       err << ". Please define serialization methods via def_pickle() for "
1052             "this class.";
1053       AT_ERROR(err.str());
1054     }
1055     object->setSlot(i, slots_[i].deepcopy(memo, device));
1056   }
1057   return object;
1058 }
1059 
StrongTypePtr(std::shared_ptr<torch::jit::CompilationUnit> cu,TypePtr type)1060 StrongTypePtr::StrongTypePtr(
1061     std::shared_ptr<torch::jit::CompilationUnit> cu,
1062     TypePtr type) : cu_(std::move(cu)), type_(std::move(type)) {
1063   TORCH_INTERNAL_ASSERT(type_);
1064 }
1065 
WeakTypePtr(std::weak_ptr<torch::jit::CompilationUnit> cu,TypePtr type)1066 WeakTypePtr::WeakTypePtr(
1067     std::weak_ptr<torch::jit::CompilationUnit> cu,
1068     TypePtr type) : cu_(std::move(cu)), type_(std::move(type)) {}
1069 
asWeakTypePtr() const1070 WeakTypePtr WeakOrStrongTypePtr::asWeakTypePtr() const {
1071   if (!holds_strong_ref()) {
1072     return WeakTypePtr(cu_.getWeakRefOrThrow(), type_);
1073   } else {
1074     std::weak_ptr<torch::jit::CompilationUnit> weak_cu =
1075         cu_.getStrongRefOrThrow();
1076     return WeakTypePtr(std::move(weak_cu), type_);
1077   }
1078 }
1079 
1080 // Needs to be in this .cpp file to access the full definition of PyObjectHolder
extractStorages(const at::IValue & value)1081 std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> ivalue::Future::extractStorages(
1082     const at::IValue& value) {
1083   std::vector<c10::weak_intrusive_ptr<c10::StorageImpl>> weakStorageImpls;
1084   // getSubValues works poorly on Python objects: it only works if they can be
1085   // converted to a "regular" IValue type hence, for example, it doesn't support
1086   // custom subclasses. Thus, instead, we extract the tensors through pickling.
1087   if (value.isPyObject()) {
1088     std::vector<at::Tensor> tensors =
1089         value.toPyObjectHolder()->extractTensors();
1090     size_t num_storages = 0;
1091     for (const at::Tensor& tensor : tensors) {
1092       if (tensor.is_sparse()) {
1093         // Sparse tensor is indices and values. Both are tensors
1094         // and contain storage. Therefore num_storages needs to be
1095         // incremented by 2.
1096         num_storages += 2;
1097       } else {
1098         // A dense/strided tensor contains 1 storage.
1099         num_storages += 1;
1100       }
1101     }
1102     weakStorageImpls.reserve(num_storages);
1103     for (const at::Tensor& tensor : tensors) {
1104       if (tensor.is_sparse()) {
1105         // Sparse tensor is indices and values. Both are tensors
1106         // and contain storage.
1107         // TODO (rohan-varma): for tensors created with at::sparse_coo_tensor held
1108         // in a python object, this might need a coalesce().
1109         weakStorageImpls.emplace_back(tensor.indices().storage().getWeakStorageImpl());
1110         weakStorageImpls.emplace_back(tensor.values().storage().getWeakStorageImpl());
1111       } else {
1112         // A dense/strided tensor contains 1 storage
1113         weakStorageImpls.emplace_back(tensor.storage().getWeakStorageImpl());
1114       }
1115     }
1116   } else {
1117     at::IValue::HashAliasedIValues sub_values;
1118     // Prefer getSubValues() over visit() as the latter is a silent no-op for
1119     // some unsupported types, whereas the former at least fails loudly.
1120     value.getSubValues(sub_values);
1121     for (const at::IValue& sub_value : sub_values) {
1122       if (sub_value.isTensor()) {
1123         auto const & tens = sub_value.toTensor();
1124         if (tens.is_sparse()) {
1125           // sparse tensors have 2 storages! One for indices one for values
1126           auto coalesced = tens.coalesce();
1127           weakStorageImpls.emplace_back(coalesced.indices().storage().getWeakStorageImpl());
1128           weakStorageImpls.emplace_back(coalesced.values().storage().getWeakStorageImpl());
1129         } else {
1130           weakStorageImpls.emplace_back(tens.storage().getWeakStorageImpl());
1131         }
1132       }
1133     }
1134   }
1135   return weakStorageImpls;
1136 }
1137 
collectAll(const List<intrusive_ptr<ivalue::Future>> & srcs)1138 TORCH_API intrusive_ptr<ivalue::Future> collectAll(
1139     const List<intrusive_ptr<ivalue::Future>>& srcs) {
1140   struct Ctx {
1141     explicit Ctx(const List<intrusive_ptr<ivalue::Future>>& srcs)
1142         : remaining(srcs.size()),
1143           srcFutures(srcs),
1144           asIvalue(srcFutures),
1145           // No need to pass devices, because dstFuture won't directly contain
1146           // the value, it will contain the srcFutures (which have no DataPtrs).
1147           dstFuture(make_intrusive<ivalue::Future>(asIvalue.type())) {}
1148     std::atomic<size_t> remaining{0};
1149     List<intrusive_ptr<ivalue::Future>> srcFutures;
1150     IValue asIvalue;
1151     intrusive_ptr<ivalue::Future> dstFuture;
1152   };
1153 
1154   auto ctx = std::make_shared<Ctx>(srcs);
1155   if (ctx->srcFutures.empty()) {
1156     ctx->dstFuture->markCompleted(ctx->asIvalue);
1157   } else {
1158     for (const auto i : c10::irange(ctx->srcFutures.size())) {
1159 
1160       std::function<void(ivalue::Future&)> func = [ctx](ivalue::Future& fut) {
1161         // Set error and exit early if encountered.
1162         if (fut.hasError()) {
1163           ctx->dstFuture->setErrorIfNeeded(fut.exception_ptr());
1164           return;
1165         }
1166 
1167         if (--ctx->remaining == 0 && !ctx->dstFuture->completed()) {
1168           // No need to pass DataPtrs, because dstFuture won't directly contain
1169           // the value, it will contain the srcFutures (which have no DataPtrs).
1170           ctx->dstFuture->markCompleted(ctx->asIvalue);
1171         }
1172       };
1173       ctx->srcFutures.get(i)->addCallback(func);
1174     }
1175   }
1176   return ctx->dstFuture;
1177 }
1178 
1179 namespace {
1180 
1181 #ifndef STRIP_ERROR_MESSAGES
formatSetOfDevices(const std::vector<c10::Device> & devices)1182 std::string formatSetOfDevices(const std::vector<c10::Device>& devices) {
1183   std::ostringstream oss;
1184   std::copy(
1185       devices.begin(),
1186       devices.end(),
1187       std::ostream_iterator<c10::Device>(oss, ", "));
1188   return oss.str();
1189 }
1190 #endif
1191 
1192 }
1193 
collectAny(const List<intrusive_ptr<ivalue::Future>> & srcs)1194 TORCH_API intrusive_ptr<ivalue::Future> collectAny(
1195     const List<intrusive_ptr<ivalue::Future>>& srcs) {
1196   if (srcs.empty()) {
1197     auto res = make_intrusive<ivalue::Future>(NoneType::get());
1198     res->markCompleted();
1199     return res;
1200   }
1201   const TypePtr& typePtr = srcs.get(0)->elementType();
1202   const std::vector<c10::Device>& devices = srcs.get(0)->devices();
1203   for (const auto i : c10::irange(srcs.size())) {
1204     if (srcs.get(i)->completed()) {
1205       return srcs.get(i);
1206     }
1207     TORCH_CHECK_TYPE(
1208         i == 0 || (*typePtr == *srcs.get(i)->elementType()),
1209         "Expected all futures to have the same type, but found ", *typePtr,
1210         " in position 0 and ", *srcs.get(i)->elementType(), " in position ", i);
1211     TORCH_CHECK_VALUE(
1212         i == 0 || (devices == srcs.get(i)->devices()),
1213         "Expected all futures to have the same devices, but found ",
1214         formatSetOfDevices(devices), " in position 0 and ",
1215         formatSetOfDevices(srcs.get(i)->devices()), " in position ", i);
1216   }
1217   struct Ctx {
1218     explicit Ctx(
1219         const List<intrusive_ptr<ivalue::Future>>& srcs,
1220         TypePtr typePtr,
1221         std::vector<c10::Device> devices)
1222         : srcFutures(srcs),
1223           dstFuture(make_intrusive<ivalue::Future>(std::move(typePtr), std::move(devices))) {}
1224     std::atomic<bool> done{false};
1225     List<intrusive_ptr<ivalue::Future>> srcFutures;
1226     intrusive_ptr<ivalue::Future> dstFuture;
1227   };
1228   auto ctx = std::make_shared<Ctx>(srcs, typePtr, devices);
1229   std::function<void(ivalue::Future&)> func = [ctx](ivalue::Future& src) {
1230     if (!ctx->done.exchange(true)) {
1231       intrusive_ptr<ivalue::Future> dst = ctx->dstFuture;
1232       ctx->dstFuture.reset(); // Once future is satisfied, remove refs.
1233       ctx->srcFutures =
1234           List<intrusive_ptr<ivalue::Future>>(ctx->srcFutures.elementType());
1235       if (src.hasError()) {
1236         dst->setError(src.exception_ptr());
1237       } else {
1238         dst->markCompleted(src.constValue(), src.storages());
1239       }
1240     }
1241   };
1242   for (const auto i : c10::irange(ctx->srcFutures.size())) {
1243     ctx->srcFutures.get(i)->addCallback(func);
1244   }
1245   return ctx->dstFuture;
1246 }
1247 
1248 } // namespace c10
1249