xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/union_type.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/Dict.h>
2 #include <ATen/core/Tensor.h>
3 #include <ATen/core/function.h>
4 #include <ATen/core/function_schema.h>
5 #include <ATen/core/grad_mode.h>
6 #include <ATen/core/jit_type.h>
7 #include <ATen/core/type_factory.h>
8 #include <c10/macros/Macros.h>
9 #include <c10/util/irange.h>
10 #include <ostream>
11 #include <sstream>
12 #include <utility>
13 
14 namespace c10 {
15 
create(const TypePtr & contained)16 OptionalTypePtr OptionalType::create(const TypePtr& contained) {
17   return OptionalTypePtr(new OptionalType(contained));
18 }
19 
ofTensor()20 TypePtr OptionalType::ofTensor() {
21   static auto value = OptionalType::create(TensorType::get());
22   return value;
23 }
24 
ofOptionalTensors()25 ListTypePtr ListType::ofOptionalTensors() {
26   static auto value = ListType::create(OptionalType::ofTensor());
27   return value;
28 }
29 
30 namespace {
31 
subtractTypeSetFrom(std::vector<TypePtr> & to_subtract,ArrayRef<TypePtr> from)32 std::optional<TypePtr> subtractTypeSetFrom(std::vector<TypePtr>& to_subtract, ArrayRef<TypePtr> from) {
33   std::vector<TypePtr> types;
34 
35   // Given a TypePtr `lhs`, this function says whether or not `lhs` (or
36   // one of its parent types) is in the `to_subtract` vector
37   auto should_subtract = [&](const TypePtr& lhs) -> bool {
38     return std::any_of(to_subtract.begin(), to_subtract.end(),
39                         [&](const TypePtr& rhs) {
40                           return lhs->isSubtypeOf(*rhs);
41                         });
42   };
43 
44   // Copy all the elements that should NOT be subtracted to the `types`
45   // vector
46   std::copy_if(from.begin(), from.end(),
47               std::back_inserter(types),
48               [&](const TypePtr& t) {
49                 return !should_subtract(t);
50               });
51 
52   if (types.empty()) {
53     return std::nullopt;
54   } else if (types.size() == 1) {
55     return types[0];
56   } else {
57     return UnionType::create(std::move(types));
58   }
59 }
60 
61 // Remove nested Optionals/Unions during the instantiation of a Union or
62 // an Optional. This populates `types` with all the types found during
63 // flattening. At the end of `flattenUnion`, `types` may have
64 // duplicates, but it will not have nested Optionals/Unions
flattenUnion(const TypePtr & type,std::vector<TypePtr> * to_fill)65 void flattenUnion(const TypePtr& type, std::vector<TypePtr>* to_fill) {
66   if (auto* union_type = type->castRaw<UnionType>()) {
67     for (const auto& inner : union_type->containedTypes()) {
68       flattenUnion(inner, to_fill);
69     }
70   } else if (auto* opt_type = type->castRaw<OptionalType>()) {
71     const auto& inner = opt_type->getElementType();
72     flattenUnion(inner, to_fill);
73     to_fill->emplace_back(NoneType::get());
74   } else if (type->kind() == NumberType::Kind) {
75     to_fill->emplace_back(IntType::get());
76     to_fill->emplace_back(FloatType::get());
77     to_fill->emplace_back(ComplexType::get());
78   } else {
79     to_fill->emplace_back(type);
80   }
81 }
82 
83 // Helper function for `standardizeUnion`
84 //
85 // NB: If we have types `T1`, `T2`, `T3`, and `PARENT_T` such that `T1`,
86 // `T2`, and `T2` are children of `PARENT_T`, then `unifyTypes(T1, T2)`
87 // will return `PARENT_T`. This could be a problem if we didn't want our
88 // Union to also be able to take `T3 `. In our current type hierarchy,
89 // this isn't an issue--most types SHOULD be unified even if the parent
90 // type wasn't in the original vector. However, later additions to the
91 // type system might necessitate reworking `get_supertype`
filterDuplicateSubtypes(std::vector<TypePtr> * types)92 void filterDuplicateSubtypes(std::vector<TypePtr>* types) {
93   if (types->empty()) {
94     return;
95   }
96   auto get_supertype = [](const TypePtr& t1, const TypePtr& t2) -> std::optional<TypePtr> {
97     // We don't want nested Optionals. Also, prematurely unifying to
98     // `Optional` could prevent us from coalescing other types
99     if ((t1->isSubtypeOf(*NoneType::get()) && !t2->isSubtypeOf(*NoneType::get()))
100         || (!t1->isSubtypeOf(*NoneType::get()) && t2->isSubtypeOf(*NoneType::get()))) {
101           return std::nullopt;
102     } else {
103       return unifyTypes(t1, t2, /*default_to_union=*/false);
104     }
105   };
106 
107   // Coalesce types and delete all duplicates. Moving from right to left
108   // through the vector, we try to unify the current element (`i`) with
109   // each element (`j`) before the "new" end of the vector (`end`).
110   // If we're able to unify the types at `types[i]` and `types[j]`, we
111   // decrement `end`, swap `types[j]` with the unified type, and
112   // break. Otherwise, we keep `end` where it is to signify that the
113   // new end of the vector hasn't shifted
114   size_t end_idx = types->size()-1;
115   for (size_t i = types->size()-1; i > 0; --i) {
116     for (size_t j = std::min(i-1, end_idx); ; --j) {
117       std::optional<TypePtr> unified;
118       unified = get_supertype((*types)[i], (*types)[j]);
119       if (unified) {
120         (*types)[j] = *unified;
121         (*types)[i] = (*types)[end_idx];
122         --end_idx;
123         break;
124       }
125       // Break condition here so we don't get `j = 0; j = j-1` and end
126       // up with MAX_INT
127       if (j == 0) {
128         break;
129       }
130     }
131   }
132   // Cut off the vector's tail so that `end` is the real last element
133   types->erase(types->begin() + static_cast<std::ptrdiff_t>(end_idx) + 1, types->end());
134 
135 }
136 
137 }
138 
sortUnion(std::vector<TypePtr> * types)139 static void sortUnion(std::vector<TypePtr>* types) {
140   // We want the elements to be sorted so we can easily compare two
141   // UnionType objects for equality in the future. Note that this order
142   // is guaranteed to be stable since we've already coalesced any
143   // possible types
144   std::sort(types->begin(), types->end(),
145           [](const TypePtr& a, const TypePtr& b) -> bool {
146             if (a->kind() != b->kind()) {
147               return a->kind() < b->kind();
148             }
149             return a->str() < b->str();
150           });
151 }
152 
standardizeVectorForUnion(std::vector<TypePtr> & reference,std::vector<TypePtr> * to_fill)153 void standardizeVectorForUnion(std::vector<TypePtr>& reference, std::vector<TypePtr>* to_fill) {
154   for (const auto& type : reference) {
155     flattenUnion(type, to_fill);
156   }
157   filterDuplicateSubtypes(to_fill);
158   sortUnion(to_fill);
159 }
160 
standardizeVectorForUnion(std::vector<TypePtr> * to_flatten)161 void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten) {
162   TORCH_INTERNAL_ASSERT(to_flatten, "`standardizeVectorForUnion` was ",
163                         "passed a `nullptr`");
164   std::vector<TypePtr> to_fill;
165   standardizeVectorForUnion(*to_flatten, &to_fill);
166   *to_flatten = std::move(to_fill);
167 }
168 
OptionalType(const TypePtr & contained)169 OptionalType::OptionalType(const TypePtr& contained)
170                            : UnionType({contained, NoneType::get()}, TypeKind::OptionalType) {
171   bool is_numbertype = false;
172   if (auto as_union = contained->cast<UnionType>()) {
173     is_numbertype = as_union->containedTypes().size() == 3 &&
174                     as_union->canHoldType(*NumberType::get());
175   }
176   if (UnionType::containedTypes().size() == 2) {
177     contained_ = UnionType::containedTypes()[0]->kind()!= NoneType::Kind
178                  ? UnionType::containedTypes()[0]
179                  : UnionType::containedTypes()[1];
180   } else if (contained == NumberType::get() || is_numbertype) {
181     contained_ = NumberType::get();
182     types_.clear();
183     types_.emplace_back(NumberType::get());
184     types_.emplace_back(NoneType::get());
185   } else {
186     std::vector<TypePtr> to_subtract{NoneType::get()};
187     auto without_none = subtractTypeSetFrom(to_subtract, types_);
188     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
189     contained_ = UnionType::create({*without_none});
190   }
191   has_free_variables_ = contained_->hasFreeVariables();
192 }
193 
UnionType(std::vector<TypePtr> reference,TypeKind kind)194 UnionType::UnionType(std::vector<TypePtr> reference, TypeKind kind) : SharedType(kind) {
195   TORCH_INTERNAL_ASSERT(!reference.empty(), "Cannot create an empty Union");
196 
197   standardizeVectorForUnion(reference, &types_);
198 
199   // Gate the assert in a regular conditional so that we don't create
200   // this long error message unnecessarily
201   if (types_.size() == 1) {
202     std::stringstream msg;
203     msg << "After type unification was performed, the Union with the "
204         << "original types {";
205     for (const auto i : c10::irange(reference.size())) {
206       msg << reference[i]->repr_str();
207       if (i > 0) {
208         msg << ",";
209       }
210       msg << " ";
211     }
212     msg << "} has the single type " << types_[0]->repr_str()
213          << ". Use the common supertype instead of creating a Union"
214          << "type";
215     TORCH_INTERNAL_ASSERT(false, msg.str());
216   }
217 
218   can_hold_none_ = false;
219   has_free_variables_ = false;
220 
221   for (const TypePtr& type : types_) {
222     if (type->kind() == NoneType::Kind) {
223       can_hold_none_ = true;
224     }
225     if (type->hasFreeVariables()) {
226       has_free_variables_ = true;
227     }
228   }
229 
230 }
231 
create(std::vector<TypePtr> reference)232 UnionTypePtr UnionType::create(std::vector<TypePtr> reference) {
233   UnionTypePtr union_type(new UnionType(std::move(reference)));
234 
235   // Some very special-cased logic for `Optional`. This will be deleted
236   // in a later PR
237   bool int_found = false;
238   bool float_found = false;
239   bool complex_found = false;
240   bool nonetype_found = false;
241 
242   auto update_is_opt_flags = [&](const TypePtr& t) {
243     if (t == IntType::get()) {
244       int_found = true;
245     } else if (t == FloatType::get()) {
246       float_found  = true;
247     } else if (t == ComplexType::get()) {
248       complex_found = true;
249     } else if (t == NoneType::get()) {
250       nonetype_found = true;
251     }
252   };
253 
254   for (const auto& t : union_type->containedTypes()) {
255     update_is_opt_flags(t);
256   }
257 
258   bool numbertype_found = int_found && float_found && complex_found;
259 
260   if (nonetype_found) {
261     if (union_type->containedTypes().size() == 4 && numbertype_found) {
262       return OptionalType::create(NumberType::get());
263     }
264     if (union_type->containedTypes().size() == 2) {
265       auto not_none = union_type->containedTypes()[0] != NoneType::get()
266                       ? union_type->containedTypes()[0]
267                       : union_type->containedTypes()[1];
268       return OptionalType::create(not_none);
269     }
270   }
271 
272   return union_type;
273 }
274 
subtractTypeSet(std::vector<TypePtr> & to_subtract) const275 std::optional<TypePtr> UnionType::subtractTypeSet(std::vector<TypePtr>& to_subtract) const {
276   return subtractTypeSetFrom(to_subtract, containedTypes());
277 }
278 
toOptional() const279 std::optional<TypePtr> UnionType::toOptional() const {
280   if (!canHoldType(*NoneType::get())) {
281       return std::nullopt;
282   }
283 
284   std::vector<TypePtr> copied_types = this->containedTypes().vec();
285 
286   auto maybe_opt = UnionType::create(std::move(copied_types));
287 
288   if (maybe_opt->kind() == UnionType::Kind) {
289     return std::nullopt;
290   } else {
291     return maybe_opt;
292   }
293 }
294 
equals(const Type & rhs) const295 bool UnionType::equals(const Type& rhs) const {
296   if (auto union_rhs = rhs.cast<UnionType>()) {
297     // We can't compare the type vectors for equality using `operator=`,
298     // because the vectors hold `TypePtr`s and we want to compare `Type`
299     // equality
300     if (union_rhs->containedTypes().size() != this->containedTypes().size()) {
301       return false;
302     }
303     // Check that all the types in `this->types_` are also in
304     // `union_rhs->types_`
305     return std::all_of(this->containedTypes().begin(), this->containedTypes().end(),
306                        [&](TypePtr lhs_type) {
307                          return std::any_of(union_rhs->containedTypes().begin(),
308                                             union_rhs->containedTypes().end(),
309                                             [&](const TypePtr& rhs_type) {
310                                               return *lhs_type == *rhs_type;
311                                             });
312                        });
313   } else if (auto optional_rhs = rhs.cast<OptionalType>()) {
314     if (optional_rhs->getElementType() == NumberType::get()) {
315       return this->containedTypes().size() == 4
316              && this->can_hold_none_
317              && this->canHoldType(*NumberType::get());
318     }
319     auto optional_lhs = this->toOptional();
320     return optional_lhs && *optional_rhs == *((optional_lhs.value())->expect<OptionalType>());
321   } else if (rhs.kind() == NumberType::Kind) {
322     return this->containedTypes().size() == 3 && canHoldType(*NumberType::get());
323   } else {
324     return false;
325   }
326 }
327 
isSubtypeOfExt(const Type & rhs,std::ostream * why_not) const328 bool UnionType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
329   std::vector<const Type*> rhs_types;
330   if (const auto union_rhs = rhs.cast<UnionType>()) {
331     // Fast path
332     if (this->containedTypes() == rhs.containedTypes()) {
333       return true;
334     }
335     for (const auto& typePtr: rhs.containedTypes()) {
336       rhs_types.push_back(typePtr.get());
337     }
338   } else if (const auto optional_rhs = rhs.cast<OptionalType>()) {
339     rhs_types.push_back(NoneType::get().get());
340     if (optional_rhs->getElementType() == NumberType::get()) {
341       std::array<const Type*, 3> number_types{IntType::get().get(), FloatType::get().get(), ComplexType::get().get()};
342       rhs_types.insert(rhs_types.end(), number_types.begin(), number_types.end());
343     } else {
344       rhs_types.push_back(optional_rhs->getElementType().get());
345     }
346   } else if (const auto number_rhs = rhs.cast<NumberType>()) {
347     std::array<const Type*, 3> number_types{IntType::get().get(), FloatType::get().get(), ComplexType::get().get()};
348     rhs_types.insert(rhs_types.end(), number_types.begin(), number_types.end());
349   } else {
350     rhs_types.push_back(&rhs);
351   }
352   return std::all_of(this->containedTypes().begin(), this->containedTypes().end(),
353                      [&](const TypePtr& lhs_type) -> bool {
354                       return std::any_of(rhs_types.begin(),
355                                          rhs_types.end(),
356                                          [&](const Type* rhs_type) -> bool {
357                                            return lhs_type->isSubtypeOfExt(*rhs_type, why_not);
358                                          });
359   });
360 }
361 
unionStr(const TypePrinter & printer,bool is_annotation_str) const362 std::string UnionType::unionStr(const TypePrinter& printer, bool is_annotation_str)
363     const {
364   std::stringstream ss;
365 
366   bool can_hold_numbertype = this->canHoldType(*NumberType::get());
367 
368   std::vector<TypePtr> number_types{IntType::get(), FloatType::get(), ComplexType::get()};
369 
370   auto is_numbertype = [&](const TypePtr& lhs) {
371     for (const auto& rhs : number_types) {
372       if (*lhs == *rhs) {
373         return true;
374       }
375     }
376     return false;
377   };
378 
379   std::string open_delimeter = is_annotation_str ? "[" : "(";
380   std::string close_delimeter = is_annotation_str ? "]" : ")";
381 
382   ss << "Union" + open_delimeter;
383   bool printed = false;
384   for (size_t i = 0; i < types_.size(); ++i) {
385     if (!can_hold_numbertype || !is_numbertype(types_[i])) {
386       if (i > 0) {
387         ss << ", ";
388         printed = true;
389       }
390       if (is_annotation_str) {
391         ss << this->containedTypes()[i]->annotation_str(printer);
392       } else {
393         ss << this->containedTypes()[i]->str();
394       }
395     }
396   }
397   if (can_hold_numbertype) {
398     if (printed) {
399       ss << ", ";
400     }
401     if (is_annotation_str) {
402       ss << NumberType::get()->annotation_str(printer);
403     } else {
404       ss << NumberType::get()->str();
405     }
406   }
407   ss << close_delimeter;
408   return ss.str();
409 }
410 
str() const411 std::string UnionType::str() const {
412   return this->unionStr(nullptr, /*is_annotation_str=*/false);
413 }
414 
annotation_str_impl(const TypePrinter & printer) const415 std::string UnionType::annotation_str_impl(const TypePrinter& printer) const {
416   return this->unionStr(printer, /*is_annotation_str=*/true);
417 }
418 
canHoldType(const Type & type) const419 bool UnionType::canHoldType(const Type& type) const {
420   if (&type == NumberType::get().get()) {
421     return canHoldType(*IntType::get())
422            && canHoldType(*FloatType::get())
423            && canHoldType(*ComplexType::get());
424   } else {
425     return std::any_of(this->containedTypes().begin(), this->containedTypes().end(),
426                     [&](const TypePtr& inner) {
427                       return type.isSubtypeOf(*inner);
428                     });
429   }
430 }
431 
equals(const Type & rhs) const432 bool OptionalType::equals(const Type& rhs) const {
433   if (auto union_rhs = rhs.cast<UnionType>()) {
434     auto optional_rhs = union_rhs->toOptional();
435     // `**optional_rhs` = `*` to get value of `std::optional<TypePtr>`,
436     // then `*` to dereference the pointer
437     return optional_rhs && *this == **optional_rhs;
438   } else if (auto optional_rhs = rhs.cast<OptionalType>()) {
439     return *this->getElementType() == *optional_rhs->getElementType();
440   } else {
441     return false;
442   }
443 }
444 
isSubtypeOfExt(const Type & rhs,std::ostream * why_not) const445 bool OptionalType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
446   if (auto optional_rhs = rhs.castRaw<OptionalType>()) {
447     return getElementType()->isSubtypeOfExt(*optional_rhs->getElementType(), why_not);
448   } else if (auto union_rhs = rhs.castRaw<UnionType>()) {
449     if (!union_rhs->canHoldType(*NoneType::get())) {
450       if (why_not) {
451         *why_not << rhs.repr_str() << " cannot hold None";
452       }
453       return false;
454     } else if (!union_rhs->canHoldType(*this->getElementType())) {
455       if (why_not) {
456         *why_not << rhs.repr_str() << " cannot hold " << this->getElementType();
457       }
458       return false;
459     } else {
460       return true;
461     }
462   } else {
463     // NOLINTNEXTLINE(bugprone-parent-virtual-call)
464     return Type::isSubtypeOfExt(rhs, why_not);
465   }
466 }
467 
} // namespace c10
469