1 #include <ATen/core/Dict.h>
2 #include <ATen/core/Tensor.h>
3 #include <ATen/core/function.h>
4 #include <ATen/core/function_schema.h>
5 #include <ATen/core/grad_mode.h>
6 #include <ATen/core/jit_type.h>
7 #include <ATen/core/type_factory.h>
8 #include <c10/macros/Macros.h>
9 #include <c10/util/irange.h>
10 #include <ostream>
11 #include <sstream>
12 #include <utility>
13
14 namespace c10 {
15
create(const TypePtr & contained)16 OptionalTypePtr OptionalType::create(const TypePtr& contained) {
17 return OptionalTypePtr(new OptionalType(contained));
18 }
19
ofTensor()20 TypePtr OptionalType::ofTensor() {
21 static auto value = OptionalType::create(TensorType::get());
22 return value;
23 }
24
ofOptionalTensors()25 ListTypePtr ListType::ofOptionalTensors() {
26 static auto value = ListType::create(OptionalType::ofTensor());
27 return value;
28 }
29
30 namespace {
31
subtractTypeSetFrom(std::vector<TypePtr> & to_subtract,ArrayRef<TypePtr> from)32 std::optional<TypePtr> subtractTypeSetFrom(std::vector<TypePtr>& to_subtract, ArrayRef<TypePtr> from) {
33 std::vector<TypePtr> types;
34
35 // Given a TypePtr `lhs`, this function says whether or not `lhs` (or
36 // one of its parent types) is in the `to_subtract` vector
37 auto should_subtract = [&](const TypePtr& lhs) -> bool {
38 return std::any_of(to_subtract.begin(), to_subtract.end(),
39 [&](const TypePtr& rhs) {
40 return lhs->isSubtypeOf(*rhs);
41 });
42 };
43
44 // Copy all the elements that should NOT be subtracted to the `types`
45 // vector
46 std::copy_if(from.begin(), from.end(),
47 std::back_inserter(types),
48 [&](const TypePtr& t) {
49 return !should_subtract(t);
50 });
51
52 if (types.empty()) {
53 return std::nullopt;
54 } else if (types.size() == 1) {
55 return types[0];
56 } else {
57 return UnionType::create(std::move(types));
58 }
59 }
60
61 // Remove nested Optionals/Unions during the instantiation of a Union or
62 // an Optional. This populates `types` with all the types found during
63 // flattening. At the end of `flattenUnion`, `types` may have
64 // duplicates, but it will not have nested Optionals/Unions
flattenUnion(const TypePtr & type,std::vector<TypePtr> * to_fill)65 void flattenUnion(const TypePtr& type, std::vector<TypePtr>* to_fill) {
66 if (auto* union_type = type->castRaw<UnionType>()) {
67 for (const auto& inner : union_type->containedTypes()) {
68 flattenUnion(inner, to_fill);
69 }
70 } else if (auto* opt_type = type->castRaw<OptionalType>()) {
71 const auto& inner = opt_type->getElementType();
72 flattenUnion(inner, to_fill);
73 to_fill->emplace_back(NoneType::get());
74 } else if (type->kind() == NumberType::Kind) {
75 to_fill->emplace_back(IntType::get());
76 to_fill->emplace_back(FloatType::get());
77 to_fill->emplace_back(ComplexType::get());
78 } else {
79 to_fill->emplace_back(type);
80 }
81 }
82
83 // Helper function for `standardizeUnion`
84 //
85 // NB: If we have types `T1`, `T2`, `T3`, and `PARENT_T` such that `T1`,
86 // `T2`, and `T2` are children of `PARENT_T`, then `unifyTypes(T1, T2)`
87 // will return `PARENT_T`. This could be a problem if we didn't want our
88 // Union to also be able to take `T3 `. In our current type hierarchy,
89 // this isn't an issue--most types SHOULD be unified even if the parent
90 // type wasn't in the original vector. However, later additions to the
91 // type system might necessitate reworking `get_supertype`
filterDuplicateSubtypes(std::vector<TypePtr> * types)92 void filterDuplicateSubtypes(std::vector<TypePtr>* types) {
93 if (types->empty()) {
94 return;
95 }
96 auto get_supertype = [](const TypePtr& t1, const TypePtr& t2) -> std::optional<TypePtr> {
97 // We don't want nested Optionals. Also, prematurely unifying to
98 // `Optional` could prevent us from coalescing other types
99 if ((t1->isSubtypeOf(*NoneType::get()) && !t2->isSubtypeOf(*NoneType::get()))
100 || (!t1->isSubtypeOf(*NoneType::get()) && t2->isSubtypeOf(*NoneType::get()))) {
101 return std::nullopt;
102 } else {
103 return unifyTypes(t1, t2, /*default_to_union=*/false);
104 }
105 };
106
107 // Coalesce types and delete all duplicates. Moving from right to left
108 // through the vector, we try to unify the current element (`i`) with
109 // each element (`j`) before the "new" end of the vector (`end`).
110 // If we're able to unify the types at `types[i]` and `types[j]`, we
111 // decrement `end`, swap `types[j]` with the unified type, and
112 // break. Otherwise, we keep `end` where it is to signify that the
113 // new end of the vector hasn't shifted
114 size_t end_idx = types->size()-1;
115 for (size_t i = types->size()-1; i > 0; --i) {
116 for (size_t j = std::min(i-1, end_idx); ; --j) {
117 std::optional<TypePtr> unified;
118 unified = get_supertype((*types)[i], (*types)[j]);
119 if (unified) {
120 (*types)[j] = *unified;
121 (*types)[i] = (*types)[end_idx];
122 --end_idx;
123 break;
124 }
125 // Break condition here so we don't get `j = 0; j = j-1` and end
126 // up with MAX_INT
127 if (j == 0) {
128 break;
129 }
130 }
131 }
132 // Cut off the vector's tail so that `end` is the real last element
133 types->erase(types->begin() + static_cast<std::ptrdiff_t>(end_idx) + 1, types->end());
134
135 }
136
137 }
138
sortUnion(std::vector<TypePtr> * types)139 static void sortUnion(std::vector<TypePtr>* types) {
140 // We want the elements to be sorted so we can easily compare two
141 // UnionType objects for equality in the future. Note that this order
142 // is guaranteed to be stable since we've already coalesced any
143 // possible types
144 std::sort(types->begin(), types->end(),
145 [](const TypePtr& a, const TypePtr& b) -> bool {
146 if (a->kind() != b->kind()) {
147 return a->kind() < b->kind();
148 }
149 return a->str() < b->str();
150 });
151 }
152
standardizeVectorForUnion(std::vector<TypePtr> & reference,std::vector<TypePtr> * to_fill)153 void standardizeVectorForUnion(std::vector<TypePtr>& reference, std::vector<TypePtr>* to_fill) {
154 for (const auto& type : reference) {
155 flattenUnion(type, to_fill);
156 }
157 filterDuplicateSubtypes(to_fill);
158 sortUnion(to_fill);
159 }
160
standardizeVectorForUnion(std::vector<TypePtr> * to_flatten)161 void standardizeVectorForUnion(std::vector<TypePtr>* to_flatten) {
162 TORCH_INTERNAL_ASSERT(to_flatten, "`standardizeVectorForUnion` was ",
163 "passed a `nullptr`");
164 std::vector<TypePtr> to_fill;
165 standardizeVectorForUnion(*to_flatten, &to_fill);
166 *to_flatten = std::move(to_fill);
167 }
168
OptionalType(const TypePtr & contained)169 OptionalType::OptionalType(const TypePtr& contained)
170 : UnionType({contained, NoneType::get()}, TypeKind::OptionalType) {
171 bool is_numbertype = false;
172 if (auto as_union = contained->cast<UnionType>()) {
173 is_numbertype = as_union->containedTypes().size() == 3 &&
174 as_union->canHoldType(*NumberType::get());
175 }
176 if (UnionType::containedTypes().size() == 2) {
177 contained_ = UnionType::containedTypes()[0]->kind()!= NoneType::Kind
178 ? UnionType::containedTypes()[0]
179 : UnionType::containedTypes()[1];
180 } else if (contained == NumberType::get() || is_numbertype) {
181 contained_ = NumberType::get();
182 types_.clear();
183 types_.emplace_back(NumberType::get());
184 types_.emplace_back(NoneType::get());
185 } else {
186 std::vector<TypePtr> to_subtract{NoneType::get()};
187 auto without_none = subtractTypeSetFrom(to_subtract, types_);
188 // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
189 contained_ = UnionType::create({*without_none});
190 }
191 has_free_variables_ = contained_->hasFreeVariables();
192 }
193
UnionType(std::vector<TypePtr> reference,TypeKind kind)194 UnionType::UnionType(std::vector<TypePtr> reference, TypeKind kind) : SharedType(kind) {
195 TORCH_INTERNAL_ASSERT(!reference.empty(), "Cannot create an empty Union");
196
197 standardizeVectorForUnion(reference, &types_);
198
199 // Gate the assert in a regular conditional so that we don't create
200 // this long error message unnecessarily
201 if (types_.size() == 1) {
202 std::stringstream msg;
203 msg << "After type unification was performed, the Union with the "
204 << "original types {";
205 for (const auto i : c10::irange(reference.size())) {
206 msg << reference[i]->repr_str();
207 if (i > 0) {
208 msg << ",";
209 }
210 msg << " ";
211 }
212 msg << "} has the single type " << types_[0]->repr_str()
213 << ". Use the common supertype instead of creating a Union"
214 << "type";
215 TORCH_INTERNAL_ASSERT(false, msg.str());
216 }
217
218 can_hold_none_ = false;
219 has_free_variables_ = false;
220
221 for (const TypePtr& type : types_) {
222 if (type->kind() == NoneType::Kind) {
223 can_hold_none_ = true;
224 }
225 if (type->hasFreeVariables()) {
226 has_free_variables_ = true;
227 }
228 }
229
230 }
231
create(std::vector<TypePtr> reference)232 UnionTypePtr UnionType::create(std::vector<TypePtr> reference) {
233 UnionTypePtr union_type(new UnionType(std::move(reference)));
234
235 // Some very special-cased logic for `Optional`. This will be deleted
236 // in a later PR
237 bool int_found = false;
238 bool float_found = false;
239 bool complex_found = false;
240 bool nonetype_found = false;
241
242 auto update_is_opt_flags = [&](const TypePtr& t) {
243 if (t == IntType::get()) {
244 int_found = true;
245 } else if (t == FloatType::get()) {
246 float_found = true;
247 } else if (t == ComplexType::get()) {
248 complex_found = true;
249 } else if (t == NoneType::get()) {
250 nonetype_found = true;
251 }
252 };
253
254 for (const auto& t : union_type->containedTypes()) {
255 update_is_opt_flags(t);
256 }
257
258 bool numbertype_found = int_found && float_found && complex_found;
259
260 if (nonetype_found) {
261 if (union_type->containedTypes().size() == 4 && numbertype_found) {
262 return OptionalType::create(NumberType::get());
263 }
264 if (union_type->containedTypes().size() == 2) {
265 auto not_none = union_type->containedTypes()[0] != NoneType::get()
266 ? union_type->containedTypes()[0]
267 : union_type->containedTypes()[1];
268 return OptionalType::create(not_none);
269 }
270 }
271
272 return union_type;
273 }
274
subtractTypeSet(std::vector<TypePtr> & to_subtract) const275 std::optional<TypePtr> UnionType::subtractTypeSet(std::vector<TypePtr>& to_subtract) const {
276 return subtractTypeSetFrom(to_subtract, containedTypes());
277 }
278
toOptional() const279 std::optional<TypePtr> UnionType::toOptional() const {
280 if (!canHoldType(*NoneType::get())) {
281 return std::nullopt;
282 }
283
284 std::vector<TypePtr> copied_types = this->containedTypes().vec();
285
286 auto maybe_opt = UnionType::create(std::move(copied_types));
287
288 if (maybe_opt->kind() == UnionType::Kind) {
289 return std::nullopt;
290 } else {
291 return maybe_opt;
292 }
293 }
294
equals(const Type & rhs) const295 bool UnionType::equals(const Type& rhs) const {
296 if (auto union_rhs = rhs.cast<UnionType>()) {
297 // We can't compare the type vectors for equality using `operator=`,
298 // because the vectors hold `TypePtr`s and we want to compare `Type`
299 // equality
300 if (union_rhs->containedTypes().size() != this->containedTypes().size()) {
301 return false;
302 }
303 // Check that all the types in `this->types_` are also in
304 // `union_rhs->types_`
305 return std::all_of(this->containedTypes().begin(), this->containedTypes().end(),
306 [&](TypePtr lhs_type) {
307 return std::any_of(union_rhs->containedTypes().begin(),
308 union_rhs->containedTypes().end(),
309 [&](const TypePtr& rhs_type) {
310 return *lhs_type == *rhs_type;
311 });
312 });
313 } else if (auto optional_rhs = rhs.cast<OptionalType>()) {
314 if (optional_rhs->getElementType() == NumberType::get()) {
315 return this->containedTypes().size() == 4
316 && this->can_hold_none_
317 && this->canHoldType(*NumberType::get());
318 }
319 auto optional_lhs = this->toOptional();
320 return optional_lhs && *optional_rhs == *((optional_lhs.value())->expect<OptionalType>());
321 } else if (rhs.kind() == NumberType::Kind) {
322 return this->containedTypes().size() == 3 && canHoldType(*NumberType::get());
323 } else {
324 return false;
325 }
326 }
327
isSubtypeOfExt(const Type & rhs,std::ostream * why_not) const328 bool UnionType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
329 std::vector<const Type*> rhs_types;
330 if (const auto union_rhs = rhs.cast<UnionType>()) {
331 // Fast path
332 if (this->containedTypes() == rhs.containedTypes()) {
333 return true;
334 }
335 for (const auto& typePtr: rhs.containedTypes()) {
336 rhs_types.push_back(typePtr.get());
337 }
338 } else if (const auto optional_rhs = rhs.cast<OptionalType>()) {
339 rhs_types.push_back(NoneType::get().get());
340 if (optional_rhs->getElementType() == NumberType::get()) {
341 std::array<const Type*, 3> number_types{IntType::get().get(), FloatType::get().get(), ComplexType::get().get()};
342 rhs_types.insert(rhs_types.end(), number_types.begin(), number_types.end());
343 } else {
344 rhs_types.push_back(optional_rhs->getElementType().get());
345 }
346 } else if (const auto number_rhs = rhs.cast<NumberType>()) {
347 std::array<const Type*, 3> number_types{IntType::get().get(), FloatType::get().get(), ComplexType::get().get()};
348 rhs_types.insert(rhs_types.end(), number_types.begin(), number_types.end());
349 } else {
350 rhs_types.push_back(&rhs);
351 }
352 return std::all_of(this->containedTypes().begin(), this->containedTypes().end(),
353 [&](const TypePtr& lhs_type) -> bool {
354 return std::any_of(rhs_types.begin(),
355 rhs_types.end(),
356 [&](const Type* rhs_type) -> bool {
357 return lhs_type->isSubtypeOfExt(*rhs_type, why_not);
358 });
359 });
360 }
361
unionStr(const TypePrinter & printer,bool is_annotation_str) const362 std::string UnionType::unionStr(const TypePrinter& printer, bool is_annotation_str)
363 const {
364 std::stringstream ss;
365
366 bool can_hold_numbertype = this->canHoldType(*NumberType::get());
367
368 std::vector<TypePtr> number_types{IntType::get(), FloatType::get(), ComplexType::get()};
369
370 auto is_numbertype = [&](const TypePtr& lhs) {
371 for (const auto& rhs : number_types) {
372 if (*lhs == *rhs) {
373 return true;
374 }
375 }
376 return false;
377 };
378
379 std::string open_delimeter = is_annotation_str ? "[" : "(";
380 std::string close_delimeter = is_annotation_str ? "]" : ")";
381
382 ss << "Union" + open_delimeter;
383 bool printed = false;
384 for (size_t i = 0; i < types_.size(); ++i) {
385 if (!can_hold_numbertype || !is_numbertype(types_[i])) {
386 if (i > 0) {
387 ss << ", ";
388 printed = true;
389 }
390 if (is_annotation_str) {
391 ss << this->containedTypes()[i]->annotation_str(printer);
392 } else {
393 ss << this->containedTypes()[i]->str();
394 }
395 }
396 }
397 if (can_hold_numbertype) {
398 if (printed) {
399 ss << ", ";
400 }
401 if (is_annotation_str) {
402 ss << NumberType::get()->annotation_str(printer);
403 } else {
404 ss << NumberType::get()->str();
405 }
406 }
407 ss << close_delimeter;
408 return ss.str();
409 }
410
str() const411 std::string UnionType::str() const {
412 return this->unionStr(nullptr, /*is_annotation_str=*/false);
413 }
414
annotation_str_impl(const TypePrinter & printer) const415 std::string UnionType::annotation_str_impl(const TypePrinter& printer) const {
416 return this->unionStr(printer, /*is_annotation_str=*/true);
417 }
418
canHoldType(const Type & type) const419 bool UnionType::canHoldType(const Type& type) const {
420 if (&type == NumberType::get().get()) {
421 return canHoldType(*IntType::get())
422 && canHoldType(*FloatType::get())
423 && canHoldType(*ComplexType::get());
424 } else {
425 return std::any_of(this->containedTypes().begin(), this->containedTypes().end(),
426 [&](const TypePtr& inner) {
427 return type.isSubtypeOf(*inner);
428 });
429 }
430 }
431
equals(const Type & rhs) const432 bool OptionalType::equals(const Type& rhs) const {
433 if (auto union_rhs = rhs.cast<UnionType>()) {
434 auto optional_rhs = union_rhs->toOptional();
435 // `**optional_rhs` = `*` to get value of `std::optional<TypePtr>`,
436 // then `*` to dereference the pointer
437 return optional_rhs && *this == **optional_rhs;
438 } else if (auto optional_rhs = rhs.cast<OptionalType>()) {
439 return *this->getElementType() == *optional_rhs->getElementType();
440 } else {
441 return false;
442 }
443 }
444
isSubtypeOfExt(const Type & rhs,std::ostream * why_not) const445 bool OptionalType::isSubtypeOfExt(const Type& rhs, std::ostream* why_not) const {
446 if (auto optional_rhs = rhs.castRaw<OptionalType>()) {
447 return getElementType()->isSubtypeOfExt(*optional_rhs->getElementType(), why_not);
448 } else if (auto union_rhs = rhs.castRaw<UnionType>()) {
449 if (!union_rhs->canHoldType(*NoneType::get())) {
450 if (why_not) {
451 *why_not << rhs.repr_str() << " cannot hold None";
452 }
453 return false;
454 } else if (!union_rhs->canHoldType(*this->getElementType())) {
455 if (why_not) {
456 *why_not << rhs.repr_str() << " cannot hold " << this->getElementType();
457 }
458 return false;
459 } else {
460 return true;
461 }
462 } else {
463 // NOLINTNEXTLINE(bugprone-parent-virtual-call)
464 return Type::isSubtypeOfExt(rhs, why_not);
465 }
466 }
467
} // namespace c10
469