1 #include <torch/csrc/jit/ir/alias_analysis.h>
2
3 #include <ATen/core/interned_strings.h>
4 #include <c10/util/flat_hash_map.h>
5 #include <c10/util/irange.h>
6 #include <torch/csrc/jit/api/function_impl.h>
7 #include <torch/csrc/jit/jit_log.h>
8 #include <torch/csrc/jit/passes/inliner.h>
9 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
10 #include <torch/csrc/jit/runtime/operator.h>
11 #include <fstream>
12 #include <iostream>
13
14 namespace torch::jit {
15
16 namespace {
17
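// Collapse an AliasTypeSet into a single TypePtr: a one-element set such as
// {List[Tensor]} is borrowed as-is, while a multi-element set such as
// {List[Tensor], Dict[str, Tensor]} is wrapped into an owned Union of its
// members.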
c10::MaybeOwned<TypePtr> toSingleType(const AliasTypeSet& mut_types) {
19 return mut_types.size() == 1
20 ? c10::MaybeOwned<TypePtr>::borrowed(mut_types[0])
21 : c10::MaybeOwned<TypePtr>::owned(c10::UnionType::create(mut_types));
22 }
23
24 // This class determines whether a type is mutable, and, if so, it maps
25 // the type to its "mutable equivalent" (see definition in
26 // `mapTypeToAliasTypeSet`). It uses a cache of TypePtrs to speed up these
27 // type lookups
28 class MutableTypePtrHelper {
29 public:
  explicit MutableTypePtrHelper(
31 ska::flat_hash_map<TypePtr, AliasTypeSet>* mutable_type_cache)
32 : mutable_type_cache_(mutable_type_cache) {}
33
34 // Map any mutable type to a type such that all other types which the
35 // mutable type can alias will be mapped to the same type. For
36 // example, calling this method on `Optional[List[int]]` should be
37 // the same as calling this method on `List[int]`.
38 //
39 // Rules:
40 // - If the type is not mutable, return `nullopt`
41 // - If the type is a `Tuple`, that means that it's an immutable
42 // object that can itself contain mutable objects. We want to make
43 // sure that the mutable objects are correctly aliased, so we
44 // remove the immutable objects. (For example,
45 // `Tuple[int, Tensor]` would become `Tuple[Tensor]`, while
46 // `Tuple[int, str]` would be returned as `nullopt`.) This is a
47 // convenience that makes it easy to check if the `Tuple`
48 // contains only immutable objects, though it's not technically
49 // necessary
50 // - For any Tensor type (including Tensor types that are part of
51 // a larger container, e.g. `List[Tensor]`), return the
52 // "unshaped" version of that Tensor. An "unshaped" Tensor is a
53 // Tensor with shape information removed. For example, a Tensor
54 // of dimension 4 would map to the same type as a Tensor of
55 // dimension 1. This allows us to treat all subclasses of Tensor
  // as a single, homogeneous "Tensor" type.
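  //
  // A few illustrative mappings under these rules:
  //   `Optional[List[int]]` -> {List[int]}
  //   `Union[Tensor, str]`  -> {unshaped Tensor}
  //   `Tuple[int, Tensor]`  -> {Tuple[unshaped Tensor]}
  //   `int`, `str`          -> nullopt (immutable)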
  std::optional<AliasTypeSet> mapTypeToAliasTypeSet(const TypePtr& type) {
58 if (mutable_type_cache_) {
59 const AliasTypeSet* result = mapTypeToBorrowedAliasTypeSet(type);
60 if (result) {
61 return *result;
62 }
63 }
64 return mapTypeToAliasTypeSetImpl(type);
65 }
66
  const AliasTypeSet* mapTypeToBorrowedAliasTypeSet(const TypePtr& type) {
68 TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mutable_type_cache_ != nullptr);
69 auto maybe_type_mapping = mutable_type_cache_->find(type);
70 if (maybe_type_mapping != mutable_type_cache_->end()) {
71 return &maybe_type_mapping->second;
72 }
73
74 auto mutable_types = mapTypeToAliasTypeSetImpl(type);
75 if (mutable_types) {
76 auto it =
77 mutable_type_cache_->emplace(type, std::move(*mutable_types)).first;
78 return &it->second;
79 } else {
80 return nullptr;
81 }
82 }
83
84 private:
  std::optional<AliasTypeSet> mapTypeToAliasTypeSetImpl(const TypePtr& type) {
86 switch (type->kind()) {
87 case TypeKind::ListType:
88 case TypeKind::DictType:
89 case TypeKind::ClassType:
90 case TypeKind::TensorType:
91 // TODO: Look up cached contained types. this is kind of tricky
92 // because a `List[Optional[T]]` should still be
93 // `List[Optional[Unshaped(T)]]`, but
94 // `mapTypeToAliasTypeSet(Optional[T])` should be `T`
95 return AliasTypeSet{unshapedType(type)};
96 case TypeKind::UnionType: {
97 AliasTypeSet mutable_types;
98 for (const TypePtr& inner :
99 type->expectRef<UnionType>().containedTypes()) {
100 if (auto maybe_inner_types = mapTypeToAliasTypeSet(inner)) {
101 mutable_types.insert(
102 mutable_types.end(),
103 (*maybe_inner_types).begin(),
104 (*maybe_inner_types).end());
105 }
106 }
107 if (mutable_types.empty()) {
108 return std::nullopt;
109 }
110 return mutable_types;
111 }
112 case TypeKind::OptionalType: {
113 auto inner = type->castRaw<OptionalType>()->getElementType();
114 return mapTypeToAliasTypeSet(inner);
115 }
116 case TypeKind::AnyType:
117 return {AliasTypeSet{type}};
118 case TypeKind::FutureType: {
119 if (auto maybe_mut_types = mapTypeToAliasTypeSet(
120 type->castRaw<FutureType>()->getElementType())) {
121 return {AliasTypeSet{
122 FutureType::create(*toSingleType(*maybe_mut_types))}};
123 }
124 return std::nullopt;
125 }
126 case TypeKind::AwaitType: {
127 if (auto maybe_mut_types = mapTypeToAliasTypeSet(
128 type->castRaw<AwaitType>()->getElementType())) {
129 return {
130 AliasTypeSet{AwaitType::create(*toSingleType(*maybe_mut_types))}};
131 }
132 return std::nullopt;
133 }
134 case TypeKind::TupleType: {
135 std::vector<TypePtr> mutable_types;
136 for (const TypePtr& inner : type->expectRef<TupleType>().elements()) {
137 if (auto maybe_inner_types = mapTypeToAliasTypeSet(inner)) {
138 mutable_types.insert(
139 mutable_types.end(),
140 (*maybe_inner_types).begin(),
141 (*maybe_inner_types).end());
142 }
143 }
144 if (mutable_types.empty()) {
145 return std::nullopt;
146 }
147 return {AliasTypeSet{TupleType::create(mutable_types)}};
148 }
149 default:
150 return std::nullopt;
151 }
152 }
153 ska::flat_hash_map<TypePtr, AliasTypeSet>* mutable_type_cache_;
154 };
155
bool isMutableTypeImpl(
157 const TypePtr& type,
158 ska::flat_hash_map<TypePtr, AliasTypeSet>* mutable_type_cache) {
  // Check common cases to avoid recursively constructing the type in
  // `mapTypeToAliasTypeSetImpl`
161 auto kind = type->kind();
162 if (kind == TypeKind::TensorType || kind == TypeKind::ListType ||
163 kind == TypeKind::ClassType || kind == TypeKind::DictType) {
164 return true;
165 }
166 MutableTypePtrHelper helper(mutable_type_cache);
167 if (mutable_type_cache) {
168 return helper.mapTypeToBorrowedAliasTypeSet(type) != nullptr;
169 } else {
170 return helper.mapTypeToAliasTypeSet(type).has_value();
171 }
172 }
173
174 } // namespace
175
176 // Static `isMutableType` does not use cache of type -> mutable type equivalent
bool AliasDb::isMutableType(const TypePtr& type) {
178 return isMutableTypeImpl(type, nullptr);
179 }
180
bool AliasDb::isMutableType(const Value* v) {
182 return isMutableType(v->type());
183 }
184
185 // Make use of type -> mutable cache
bool AliasDb::isMutableTypeInternal(const TypePtr& type) const {
187 return isMutableTypeImpl(type, &mapped_mutable_types_);
188 }
189
bool AliasDb::isMutableTypeInternal(const Value* v) const {
191 return isMutableTypeInternal(v->type());
192 }
193
const AliasTypeSet* AliasDb::mapTypeToAliasTypeSetPtr(
195 const TypePtr& type) const {
196 MutableTypePtrHelper helper(&mapped_mutable_types_);
197 return helper.mapTypeToBorrowedAliasTypeSet(type);
198 }
199
200 AliasDb::~AliasDb() = default;
201
202 // Structure used during analysis to keep track of all writes at a high
203 // level. When the analysis is completed, this will be used to construct
204 // a more efficient WriteIndex
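// For example, an in-place op registers a plain write to its mutated input,
// analyzeConservative() registers a write to everything contained in an
// input, and analyzeWait() registers a write to every wildcard.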
205 struct AliasDb::WriteRegistry {
  void registerWrite(const Value* v, Node* n) {
207 writes_[n].emplace_back(v);
208 }
  void registerWriteToAllContained(const Value* v, Node* n) {
210 containedWrites_[n].emplace_back(v);
211 }
  void registerWriteToAllWildcards(Node* n) {
213 writesToAllWildcards_.insert(n);
214 }
215 std::unordered_map<Node*, std::vector<const Value*>> writes_;
216 std::unordered_map<Node*, std::vector<const Value*>> containedWrites_;
217 std::unordered_set<Node*> writesToAllWildcards_;
218 };
219
AliasDb::AliasDb(
221 std::shared_ptr<Graph> graph,
222 bool isFrozen,
223 bool descendFunctionCalls)
224 : graph_(std::move(graph)),
225 isFrozen_(isFrozen),
226 descend_function_calls_(descendFunctionCalls),
227 memoryDAGBuilder_(std::make_unique<MemoryDAGBuilder>()),
228 writeRegistry_(std::make_unique<AliasDb::WriteRegistry>()) {
229 analyze(graph_);
230
231 memoryDAG_ = std::move(*memoryDAGBuilder_).createMemoryDAG();
232 // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
233 memoryDAGBuilder_ = nullptr; // to make further access a hard error
234
235 memoryDAG_->setWildcards(
236 wildcards_, elementMap_, [&](const Value* v) -> Element* {
237 return getWildcard(v->type());
238 });
239
240 // Now we build up the various write indices based on information in the write
241 // registry that we populated during analysis
242
243 // Initialize the write index
244 writeIndex_ = TWriteIndex();
245 auto& writeIndex = *writeIndex_; // to make operator[] less ugly
246
247 // Build the write index
248 for (const auto& write : writeRegistry_->writes_) {
249 Node* node = write.first;
    const std::vector<const Value*>& writtenValues = write.second;
251 for (const Value* writtenValue : writtenValues) {
252 auto it = elementMap_.find(writtenValue);
253 TORCH_INTERNAL_ASSERT(
254 it != elementMap_.end(), "Tried to write to value not in MemoryDAG");
255 const auto& writtenMemoryLocations =
256 memoryDAG_->getMemoryLocations(it->second);
257 writeIndex[node] |= writtenMemoryLocations;
258 }
259 }
260
261 for (const auto& write : writeRegistry_->containedWrites_) {
262 Node* node = write.first;
263 const std::vector<const Value*>& writtenValues = write.second;
264 for (const Value* writtenValue : writtenValues) {
265 auto elem = elementMap_.at(writtenValue);
266 MemoryLocations writtenMemoryLocations;
267 memoryDAG_->collectAllContainedMemoryLocations(
268 elem, writtenMemoryLocations);
269 writeIndex[node] |= writtenMemoryLocations;
270 }
271 }
272
273 for (const auto& write : writeRegistry_->writesToAllWildcards_) {
274 for (const auto& pr : wildcardIndex_) {
275 writeIndex[write].set(pr.second->index);
276 }
277 }
278
279 // Now that we've built the write index, we can null out the WriteRegistry to
280 // make future access an error. In this way we prevent the index from getting
281 // out of sync (since we have no way of registering new writes)
282 // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
283 writeRegistry_ = nullptr;
284
285 // Initialize the write cache
286 buildWrittenToLocationsIndex();
287 GRAPH_DEBUG(toString());
288 }
289
bool AliasDb::isMutable(Node* n) const {
291 ValueSet vs;
292 for (const auto input : n->inputs()) {
293 vs.insert(input);
294 }
295 return writesToAlias(n, vs);
296 }
297
bool AliasDb::hasInputWriters(const Node* n) const {
299 for (const auto input : n->inputs()) {
300 if (hasWriters(input)) {
301 return true;
302 }
303 }
304 return false;
305 }
306
bool AliasDb::hasOutputWriters(const Node* n) const {
308 for (const auto output : n->outputs()) {
309 if (hasWriters(output)) {
310 return true;
311 }
312 }
313 return false;
314 }
315
bool AliasDb::hasWriters(const Node* n) const {
317 return hasInputWriters(n) || hasOutputWriters(n);
318 }
319
bool AliasDb::hasWriters(const Value* v) const {
321 if (v->mustBeNone()) {
322 return false;
323 }
324
325 auto it = elementMap_.find(v);
326 if (it == elementMap_.end()) {
327 return false;
328 }
329
330 const auto& el = it->second;
331 return writtenToLocationsIndex_->intersects(
332 memoryDAG_->getMemoryLocations(el));
333 }
334
void AliasDb::getWritesImpl(Node* n, MemoryLocations& ret) const {
336 if (writeIndex_->count(n)) {
337 const auto& writes = writeIndex_->at(n);
338 ret |= writes;
339 }
340
341 for (auto block : n->blocks()) {
342 for (auto node : block->nodes()) {
343 getWritesImpl(node, ret);
344 }
345 }
346 }
347
348 // Does `n` write to an alias of one of the values in `vs`?
bool AliasDb::writesToAlias(Node* n, const ValueSet& vs) const {
350 const auto writtenTo = getWrites(n);
351 if (writtenTo.empty()) {
352 return false;
353 }
354
355 MemoryLocations locs;
356 for (const auto v : vs) {
357 auto it = elementMap_.find(v);
358 if (it != elementMap_.end()) {
359 const auto& vlocs = memoryDAG_->getMemoryLocations(it->second);
360 if (writtenTo.intersects(vlocs)) {
361 return true;
362 }
363 }
364 }
365
366 return false;
367 }
368
MemoryLocations AliasDb::getWrites(Node* n) const {
370 MemoryLocations writes;
371 getWritesImpl(n, writes);
372 return writes;
373 }
374
void AliasDb::getReadsImpl(Node* n, MemoryLocations& ret) const {
376 for (const auto input : n->inputs()) {
377 auto it = elementMap_.find(input);
378 if (it != elementMap_.end()) {
379 auto el = it->second;
380
381 // Add all memory locations this element may alias and their contained
382 // elements
383 memoryDAG_->collectAllContainedMemoryLocations(el, ret);
384 }
385 }
386
387 for (auto block : n->blocks()) {
388 for (auto node : block->nodes()) {
389 getReadsImpl(node, ret);
390 }
391 }
392 }
393
MemoryLocations AliasDb::getReads(Node* n) const {
395 MemoryLocations reads;
396 getReadsImpl(n, reads);
397 return reads;
398 }
399
std::string AliasDb::getElementName(const Element* e) const {
401 if (e->values.empty()) {
402 // Not the most efficient way, but given the fact there are
403 // not too many types and even fewer of them will end up in
404 // `wildcardIndex_`, we should be fine with a linear search
405 // each time we hit a Wildcard leaf
406 for (const auto& ent : wildcardIndex_) {
407 if (ent.second == e) {
408 return std::string("WILDCARD for type ") + ent.first->str();
409 }
410 }
411 return "WILDCARD";
412 } else {
413 std::ostringstream ss;
414 if (e->values.size() == 1) {
415 ss << "%" << (*e->values.begin())->debugName();
416 return ss.str();
417 }
418 ss << "(";
419 for (const Value* v : e->values) {
420 ss << "%" << v->debugName() << ", ";
421 }
422 ss << ")";
423 return ss.str();
424 }
425 }
426
void AliasDb::dump() const {
428 std::cout << toString();
429 }
430
std::string AliasDb::toString() const {
432 std::stringstream ss{};
433
434 ss << "\n===1. GRAPH===\n";
435 ss << graph_->toString();
436
437 ss << "\n===2. ALIAS DB===\n";
438 for (const auto& ptrPair : elementMap_) {
439 const auto element = ptrPair.second;
440 int ct = 0;
441 if (!element->pointsTo.empty()) {
442 ss << getElementName(element) << " points to: ";
443 for (const auto pointedTo : element->pointsTo) {
444 if (ct > 0) {
445 ss << ", ";
446 }
447 ++ct;
448 ss << getElementName(memoryDAG_->fromIndex(pointedTo));
449 }
450 ss << "\n";
451 }
452 ct = 0;
453 if (!element->containedElements.empty()) {
454 ss << getElementName(element) << " contains: ";
455 for (const auto contained : element->containedElements) {
456 ss << getElementName(memoryDAG_->fromIndex(contained));
457 if (ct > 0) {
458 ss << ", ";
459 }
460 ++ct;
461 }
462 ss << "\n";
463 }
464 }
465
466 ss << "\n===3. Writes===\n";
467 for (const auto& pr : *writeIndex_) {
468 const auto node = pr.first;
469 const auto& values = pr.second;
470 ss << *node;
471 ss << " ";
472 for (const auto value : values) {
473 ss << getElementName(memoryDAG_->fromIndex(value)) << ", ";
474 }
475 ss << "\n";
476 }
477 ss << "\n";
478 return ss.str();
479 }
480
bool AliasDb::dumpToGraphvizFile(const char* filename) const {
482 std::ofstream dot_file(filename);
483 if (!dot_file.good()) {
484 std::cout << "Failed to create Graphviz file: '" << filename << "'\n";
485 return false;
486 }
487 dot_file << toGraphviz();
488 return true;
489 }
490
std::string AliasDb::toGraphviz() const {
492 std::stringstream dot;
493
494 // Local helper to generate a graphviz-friendly name encoding
495 // See also AliasDb::getElementName()
496 const auto name = [this](const Element* e) -> std::string {
497 if (e->values.empty()) {
498 for (const auto& ent : wildcardIndex_) {
499 if (ent.second == e) {
500 return std::string("\"WILDCARD for ") + ent.first->str() + "\"";
501 }
502 }
503 return "\"WILDCARD\"";
504 } else {
505 std::ostringstream ss;
506 if (e->values.size() == 1) {
507 ss << "\"\\%" << (*e->values.begin())->debugName() << "\"";
508 return ss.str();
509 }
510 ss << "\"(";
511 for (const Value* v : e->values) {
512 ss << "\\%" << v->debugName() << ", ";
513 }
514 ss << ")\"";
515 return ss.str();
516 }
517 };
518
519 // Include the textual representation for reference
520 dot << "/*\n";
521 dot << toString();
522 dot << "*/\n";
523
524 dot << "digraph alias_db {\n"
525 << " rankdir=LR\n"
526 << " node [shape=rect, color=gray];\n"
527 << " edge [color=black];\n";
528
529 for (const auto& ptrPair : elementMap_) {
530 const auto element = ptrPair.second;
531 if (!element->pointsTo.empty()) {
532 for (const auto pointedTo : element->pointsTo) {
533 dot << " " << name(element) << " -> "
534 << name(memoryDAG_->fromIndex(pointedTo)) << "\n";
535 }
536 }
537 if (!element->containedElements.empty()) {
538 for (const auto contained : element->containedElements) {
539 dot << " " << name(element) << " -> "
540 << name(memoryDAG_->fromIndex(contained))
541 << " [style=dashed, color=blue]\n";
542 }
543 }
544 }
545
546 dot << "}\n";
547 return dot.str();
548 }
549
void AliasDb::analyze(const std::shared_ptr<Graph>& graph) {
551 for (auto input : graph->inputs()) {
552 setWildcard(input);
553 }
554 analyze(graph->block());
555 }
556
void AliasDb::analyze(Block* block) {
558 for (auto node : block->nodes()) {
559 analyze(node);
560 }
561 }
562
void AliasDb::analyze(Node* node) {
564 analyzeImpl(node);
565 }
566
567 // Returns true if analysis was run using
568 // the registered analyzer.
bool AliasDb::tryRegisteredAnalysis(Node* node) {
570 const Operator& op = node->getOperator();
571 auto analysis = op.aliasAnalysisKind();
572 if (AliasAnalysisKind::PURE_FUNCTION == analysis) {
573 analyzeCreator(node);
574 return true;
575 }
576 return false;
577 }
578
579 // The basic strategy is:
580 // 1. Retrieve alias information for every input.
//   2. Use the node's schema's alias annotations to propagate alias/write
582 // information to the outputs. For unschematized nodes, a special analyzer
583 // will have to be handwritten.
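// For instance, the schema
//   aten::add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
// places `self` in alias set `a`, marks `a` as written to (the `!`), and says
// the output aliases `a`; the code below binds `a` to the actual `self` value,
// registers the write, and points the output at `self`.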
void AliasDb::analyzeImpl(Node* node) {
585 auto op = node->maybeOperator();
586 const bool hasSpecialCase = aliasAnalysisHasSpecialCaseFor(node->kind());
587 if (op) {
588 const auto analysis = op->aliasAnalysisKind();
589
590 const bool registeredAsSpecialCase =
591 analysis == AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
592 if (C10_UNLIKELY(registeredAsSpecialCase && !hasSpecialCase)) {
593 TORCH_INTERNAL_ASSERT(
594 false,
595 "Op ",
596 node->kind().toDisplayString(),
597 " is registered with AliasAnalysisKind::INTERNAL_SPECIAL_CASE but doesn't have a special case.");
598 } else if (C10_UNLIKELY(!registeredAsSpecialCase && hasSpecialCase)) {
599 TORCH_INTERNAL_ASSERT(
600 false,
601 "Op ",
602 node->kind().toDisplayString(),
603 " has a special case and should be registered with AliasAnalysisKind::INTERNAL_SPECIAL_CASE but is registered with ",
604 c10::toString(analysis));
605 }
606 } else {
607 if (!hasSpecialCase) {
608 std::ostringstream oss;
609 for (const auto input : node->inputs()) {
610 oss << input->type()->str() << ", ";
611 }
612 oss << "\n\nCandidates:";
613 const auto& candidates = getAllOperatorsFor(node->kind());
614 for (const auto& candidate : candidates) {
615 oss << "\n\t" << candidate->schema();
616 }
617 TORCH_INTERNAL_ASSERT(
618 0,
619 "We don't have an op for ",
620 node->kind().toDisplayString(),
621 " but it isn't a special case. ",
622 "Argument types: ",
623 oss.str());
624 }
625 }
626
627 // These nodes are not schematized, so we need to handle them specially
628 switch (node->kind()) {
629 case prim::If:
630 return analyzeIf(node);
631 case prim::Loop:
632 return analyzeLoop(node);
633 case prim::FusionGroup:
634 case prim::CudaFusionGroup:
635 case prim::oneDNNFusionGroup:
636 case prim::FunctionalGraph:
637 case prim::DifferentiableGraph:
638 case prim::FallbackGraph:
639 return analyzeSubgraph(node);
640 case prim::fork:
641 return analyzeFork(node);
642 case aten::wait:
643 return analyzeWait(node);
644 case prim::awaitable:
645 case prim::awaitable_nowait:
646 return analyzeAwaitable(node);
647 case prim::awaitable_wait:
648 return analyzeAwaitableWait(node);
649 case prim::rpc_async:
650 case prim::rpc_sync:
651 case prim::rpc_remote:
652 return analyzeRpcAsync(node);
653 case aten::batch_norm:
654 return analyzeBatchNorm(node);
655 case aten::instance_norm:
656 return analyzeInstanceNorm(node);
657 case prim::GradOf:
658 return analyzeGradOf(node);
659 case prim::BroadcastMKLDNNTensors: {
660 makePointerTo(node->outputs().at(0), node->inputs().at(0));
661 makePointerTo(node->outputs().at(1), node->inputs().at(1));
662 return;
663 }
664 // TODO: think more about TensorExpr alias correctness
665 case prim::TensorExprGroup:
666 case prim::TensorExprDynamicGroup:
667 case prim::MKLDNNGroup:
668 case prim::ConstantMKLDNNTensor:
669 case prim::StaticSubgraph:
670 case prim::Constant:
671 case prim::AutogradZero:
672 case prim::AutogradAdd:
673 case prim::FusedConcat:
674 case prim::MMTreeReduce:
675 case prim::MMBatchSide:
676 case prim::BroadcastSizes:
677 case prim::ChunkSizes:
    // this should never be seen outside of initial compilation,
    // but because of some dependencies with closures invoking the alias
    // db, it needs to be handled here
681 case prim::EmptyListLiteral:
682 case prim::Closure:
683 case prim::CreateObject:
684 case prim::tolist:
685 case prim::Uninitialized:
686 return analyzeCreator(node);
687 case prim::TupleConstruct:
688 case prim::DictConstruct:
689 case prim::ListConstruct:
690 return analyzeContainerConstruct(node);
691 case prim::TupleUnpack:
692 case prim::TupleIndex:
693 case prim::TupleSlice:
694 case prim::ListUnpack:
695 case prim::PythonOp:
696 case prim::GetAttr:
697 if (isFrozen_ && node->kind() == prim::GetAttr) {
698 auto& ty = node->input()->type();
699 if (ty->expectRef<ClassType>().is_module()) {
700 return analyzeCreator(node);
701 }
702 }
703 return analyzeExtractor(node);
704 case prim::unchecked_cast:
705 return makePointerTo(node->output(), node->input());
706 case prim::ConstantChunk:
707 return analyzeChunk(node);
708 case prim::BroadcastingChunk:
709 return analyzeBroadcastingChunk(node);
710 case prim::SetAttr:
711 return analyzeSetAttr(node);
712 case prim::profile_ivalue:
713 case prim::profile:
714 makePointerTo(node->output(), node->inputs().at(0));
715 return;
716 case prim::TypeCheck:
717 case prim::RequiresGradCheck: {
718 auto num_inputs = node->inputs().size();
719 for (const auto i : c10::irange(num_inputs)) {
720 makePointerTo(node->outputs().at(i), node->inputs().at(i));
721 }
722 return;
723 }
724 case prim::BailOut:
725 TORCH_INTERNAL_ASSERT(
726 node->inputs().at(0)->node()->kind() == prim::BailoutTemplate);
727 makePointerTo(node->output(), node->inputs().at(1));
728 return;
729 case prim::Guard:
730 makePointerTo(node->output(), node->inputs().at(0));
731 return;
732 case prim::CallFunction:
733 case prim::CallMethod: {
      // TODO: this can be improved with summaries of what the function does
735 // for now we assume the worst
736 if (!descend_function_calls_) {
737 return analyzeConservative(node);
738 }
739 auto g = tryToGraphFunction(node);
740 if (!g) {
741 return analyzeConservative(node);
742 }
743 // this is an unoptimized path - we copy the subgraph for each function
744 // call past the first - so we do not generally enable the recursive
745 // analysis. use cases for fine-grained alias analysis without inlining
746 // are very uncommon
747 auto graph = g->optimized_graph();
      // alias analysis uses Value* as the key for aliasing information, so
      // each analysis of a particular function call needs a fresh graph. For
      // all copies made, store them for the duration of the analysis so we do
      // not run into lifetime issues with the graph
752 std::vector<std::shared_ptr<Graph>>& graphs =
753 function_call_copies_[graph.get()];
754 if (graphs.empty()) {
755 graphs.push_back(graph);
756 analyzeSubgraph(node, graph);
757 } else {
758 auto copied_graph = graph->copy();
759 graphs.push_back(copied_graph);
760 analyzeSubgraph(node, copied_graph);
761 }
762 return;
763 }
764 case prim::Enter:
765 case prim::Exit:
      // TODO: this can be improved with summaries of what the function does
767 // for now we assume the worst
768 // NB: update safeToChangeAliasingRelationship if changed
769 return analyzeConservative(node);
770 case prim::Print:
771 case prim::isinstance:
772 // These ops do nothing
773 return;
774 default:
775 if (tryRegisteredAnalysis(node)) {
776 return;
777 }
778 }
779
780 TORCH_INTERNAL_ASSERT(op, "We should have an op schema if we get to here");
781 const AliasAnalysisKind analysis = op->aliasAnalysisKind();
782 TORCH_INTERNAL_ASSERT(
783 analysis != AliasAnalysisKind::INTERNAL_SPECIAL_CASE &&
784 !aliasAnalysisHasSpecialCaseFor(node->kind()),
785 "Special cases should be handled already if we're here.");
786
787 if (node->kind().is_aten() || node->kind().is_prim() ||
788 node->kind().is_cuda()) {
789 // TODO There is nothing in the system that relies on aten:: and prim::
790 // ops using AliasAnalysisKind::FROM_SCHEMA or
791 // AliasAnalysisKind::INTERNAL_SPECIAL_CASE, but this is the intended
792 // behavior for all current ops and a good error check. We can consider
793 // lifting this constraint later if we have a use case for it.
794 TORCH_INTERNAL_ASSERT(
795 analysis == AliasAnalysisKind::FROM_SCHEMA ||
796 analysis == AliasAnalysisKind::CONSERVATIVE,
797 "aten:: and prim:: operators should use AliasAnalysisKind::FROM_SCHEMA or "
798 "AliasAnalysisKind::CONSERVATIVE(if really necessary), but ",
799 node->kind().toDisplayString(),
800 " doesn't. Note: Ideally, prim:: operators actually shouldn't have a schema ",
801 "and then use AliasAnalysisKind::INTERNAL_SPECIAL_CASE instead.");
802 }
803
804 if (analysis == AliasAnalysisKind::CONSERVATIVE) {
805 // TODO A previous implementation of alias analysis always accessed
    // node->schema, which caused the schema caches in the Node class to be
807 // filled for the full graph. Unfortunately, our JIT passes started relying
808 // on that, so we need to keep doing this. Details: in
809 // caffe2/torch/onnx/utils.py, _jit_pass_onnx is called on an invalid JIT
810 // graph because we called _jit_pass_erase_number_types right before and
811 // ints are now Tensors instead. So if _jit_pass_onnx tries to look up
812 // operator schemas, it will crash. However, _jit_pass_constant_propagation,
813 // which is called before it, runs alias analysis and prefills the schema
    // cache in all Node instances so that _jit_pass_onnx doesn't look up
815 // operators to get the schemas anymore. We should fix this.
816 node->schema(); // fill the schema cache in the Node class
817
818 return analyzeConservative(node);
819 }
820
821 TORCH_INTERNAL_ASSERT(
822 analysis == AliasAnalysisKind::FROM_SCHEMA,
823 "AliasAnalysisKind::CONSERVATIVE/PURE_FUNCTION/INTERNAL_SPECIAL_CASE should already have been handled above");
824 const auto& schema = node->schema();
825
826 // Bind the schema's "formal" alias annotation to the actual values those
827 // schema arguments represent
828 std::unordered_map<Symbol, Value*> formalToActual;
829 for (const auto i : c10::irange(schema.arguments().size())) {
830 const at::AliasInfo* formal = schema.arguments()[i].alias_info();
831 const auto& actualValue = node->inputs().at(i);
832
833 // Skip if there's no alias annotation
834 if (!formal) {
835 continue;
836 }
837
838 // If this type cannot alias, continue. Can occur with a VarType schema
839 if (!isMutableTypeInternal(actualValue)) {
840 continue;
841 }
842
843 // Do sanity checks on the alias annotation
844 TORCH_INTERNAL_ASSERT(
845 formal->containedTypes().size() <= 1,
846 "Composite types for alias analysis not yet supported");
847 TORCH_INTERNAL_ASSERT(
848 !formal->isWildcardBefore(),
        "Doesn't make sense for an input value to begin as a wildcard");
850 // This is a special case where we have alias info before [] but not after,
851 // such as `Tensor(a!)[]`
852 if (formal->containedTypes().size() == 1 && formal->beforeSets().empty()) {
853 // Use the first containedType in alias info.
854 formal = &(formal->containedTypes()[0]);
855 }
856
857 const auto& formalAlias = formal->beforeSet();
858
859 // skip if we've already bound this alias
860 if (formalToActual.count(formalAlias) != 0) {
861 continue;
862 }
863
864 // Bind the formal to the actual
865 formalToActual[formalAlias] = actualValue;
866
867 // Record writes
868 if (formal->isWrite()) {
869 registerWrite(actualValue, node);
870 }
871
872 // Now deal with sets after the '->'
873 if (formal->isWildcardAfter()) {
874 TORCH_INTERNAL_ASSERT(
875 formal->afterSets().size() == 1,
876 "If the after set contains a wildcard, "
877 "there should be no other alias sets specified.");
878 setWildcard(actualValue);
879 } else {
880 // We don't understand anything else in the after yet, so assert there's
881 // been no change.
882 TORCH_INTERNAL_ASSERT(formal->beforeSets() == formal->afterSets());
883 }
884 }
885
886 // Use the formal-actual mapping to give aliases to the outputs
887 for (const auto i : c10::irange(schema.returns().size())) {
888 const auto actual = node->outputs().at(i);
889 const at::AliasInfo* formal = schema.returns()[i].alias_info();
890 if (!formal) {
891 // This is a fresh tensor
892 giveFreshAlias(actual);
893 continue;
894 }
895
896 // If this type cannot alias, continue. Can occur with a VarType schema
897 if (!isMutableType(actual)) {
898 continue;
899 }
900
901 TORCH_INTERNAL_ASSERT(
902 formal->containedTypes().size() <= 1,
903 "Composite types for alias analysis not yet supported");
904 TORCH_INTERNAL_ASSERT(formal->beforeSets() == formal->afterSets());
905 if (formal->containedTypes().size() == 1 && formal->beforeSets().empty()) {
906 // Use the first containedType in alias info.
907 formal = &(formal->containedTypes()[0]);
908 }
909 if (formal->isWildcardBefore()) {
910 TORCH_INTERNAL_ASSERT(
911 formal->beforeSets().size() == 1,
912 "If an output is a wildcard, "
913 "there should be no other alias sets specified.");
914 setWildcard(actual);
915 continue;
916 }
917
918 bool inputs_has_alias = false;
919 for (const auto& formalAlias : formal->beforeSets()) {
920 if (formalToActual.count(formalAlias)) {
921 inputs_has_alias = true;
922 auto toAlias = formalToActual.at(formalAlias);
923 makePointerTo(actual, toAlias);
924 }
925 }
    // If none of the alias annotations we encountered were bound to inputs,
    //   e.g. foo(Tensor(a) self) -> Tensor(b)
    //   or foo(Tensor(a) self) -> Tensor(b|c),
    // give the output a fresh alias. Otherwise the annotation has the form
    // a|fresh, which we can ignore, taking the conservative assumption that
    // the output must alias `a`, e.g.
    //   aten::cuda(Tensor(a) self) -> Tensor(a|fresh)
932 if (!inputs_has_alias && !formal->beforeSets().empty()) {
933 giveFreshAlias(actual);
934 }
935
936 // Record writes
937 if (formal->isWrite()) {
938 registerWrite(actual, node);
939 }
940 }
941 }
942
943 // Register the fact that `n` writes to `v`.
void AliasDb::registerWrite(const Value* v, Node* n, bool writeToContained) {
945 if (!isMutableTypeInternal(v)) {
946 // don't need to register a write if the value isn't mutable
947 return;
948 }
949 if (writeToContained) {
950 writeRegistry_->registerWriteToAllContained(v, n);
951 } else {
952 writeRegistry_->registerWrite(v, n);
953 }
954 }
955
void AliasDb::analyzeIf(Node* node) {
957 // For if statements, the alias set of an output is the union of the
958 // alias sets generated by the if and else block
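  // e.g. for `%out = prim::If(%cond)` with then-output %t and else-output %f,
  // %out is made to point to both %t and %f below.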
959 const auto trueBlock = node->blocks().at(0);
960 const auto falseBlock = node->blocks().at(1);
961 analyze(trueBlock);
962 analyze(falseBlock);
963
964 for (const auto i : c10::irange(node->outputs().size())) {
965 const auto nodeOutput = node->outputs()[i];
966
967 const auto trueOutput = trueBlock->outputs().at(i);
968 const auto falseOutput = falseBlock->outputs().at(i);
969
970 makePointerTo(nodeOutput, trueOutput);
971 makePointerTo(nodeOutput, falseOutput);
972 }
973 }
974
void AliasDb::analyzeLoop(Node* node) {
976 const auto bodyBlock = node->blocks().at(0);
977 const auto loopCarriedInputs = node->inputs().slice(2); // skip max, cond
978 const auto blockInputs = bodyBlock->inputs().slice(1); // skip trip
979 const auto blockOutputs = bodyBlock->outputs().slice(1); // skip trip
980 TORCH_INTERNAL_ASSERT(loopCarriedInputs.size() == blockInputs.size());
981 TORCH_INTERNAL_ASSERT(blockOutputs.size() == node->outputs().size());
982
983 // Run alias analysis on the loop body, iterating until the block output
984 // alias info converges. Copy node input aliases to block input
985 mapAliases(blockInputs, loopCarriedInputs);
986
987 // Populate block output alias info by analyzing the body
988 analyze(bodyBlock);
989
990 // Copy the alias info from the block output to the node output
991 mapAliases(node->outputs(), blockOutputs);
992 }
993
void AliasDb::analyzeGradOf(Node* node) {
995 const auto grad_of_block = node->blocks().at(0);
996 analyze(grad_of_block);
997 mapAliases(node->outputs(), grad_of_block->outputs());
998 }
999
void AliasDb::analyzeSubgraph(
1001 Node* node,
1002 const std::shared_ptr<Graph>& subgraph) {
1003 const auto subgraphBlock = subgraph->block();
1004 // CallFunction nodes have an extra first parameter
1005 if (node->kind() == prim::CallFunction) {
1006 mapAliases(subgraphBlock->inputs(), node->inputs().slice(1));
1007 } else {
1008 mapAliases(subgraphBlock->inputs(), node->inputs());
1009 }
1010
1011 analyze(subgraphBlock);
1012
1013 // Note: the subgraph outputs and node outputs are NOT NECESSARILY the
  // same length. Autodifferentiation may capture additional outputs in the
1015 // subgraph block.
1016 TORCH_INTERNAL_ASSERT(
1017 subgraphBlock->outputs().size() >= node->outputs().size());
1018 for (size_t i = 0; i < node->outputs().size(); i++) {
1019 makePointerTo(node->outputs()[i], subgraphBlock->outputs()[i]);
1020 }
1021 }
1022
void AliasDb::analyzeSubgraph(Node* node) {
1024 const auto subgraph = node->g(attr::Subgraph);
1025 return analyzeSubgraph(node, subgraph);
1026 }
1027 // For nodes that generate a fresh value from nothing
void AliasDb::analyzeCreator(Node* node) {
1029 for (Value* output : node->outputs()) {
1030 giveFreshAlias(output);
1031 }
1032 }
1033
1034 // For nodes that extract values from a composite type. Right now, this just
1035 // gives up and creates wildcards for everything.
void AliasDb::analyzeExtractor(Node* node) {
1037 for (const auto output : node->outputs()) {
1038 setWildcard(output);
1039 }
1040 }
1041
1042 // For torch.chunk(), all returned tensors may alias the input tensor
void AliasDb::analyzeChunk(Node* node) {
1044 for (auto output : node->outputs()) {
1045 makePointerTo(output, node->input());
1046 }
1047 }
1048
void AliasDb::analyzeFork(Node* node) {
1050 for (const auto input : node->inputs()) {
1051 setWildcard(input);
1052 }
1053
1054 // Give the future that the fork emits a fresh value
1055 for (const auto output : node->outputs()) {
1056 giveFreshAlias(output);
1057 }
1058 }
1059
void AliasDb::analyzeWait(Node* node) {
1061 TORCH_INTERNAL_ASSERT(node->kind() == aten::wait);
1062 for (const auto output : node->outputs()) {
1063 setWildcard(output);
1064 }
1065 // the forked subgraph that `wait` is waiting on may write to any of its
1066 // inputs. We don't have a reliable way of recovering the fork inputs, so
1067 // for safety we just register a write to every wildcard.
1068 writeRegistry_->registerWriteToAllWildcards(node);
1069 }
1070
void AliasDb::analyzeAwaitable(Node* node) {
1072 for (const auto input : node->inputs()) {
1073 setWildcard(input);
1074 }
1075
1076 for (const auto output : node->outputs()) {
1077 giveFreshAlias(output);
1078 }
1079 }
1080
void AliasDb::analyzeAwaitableWait(Node* node) {
1082 TORCH_INTERNAL_ASSERT(node->kind() == prim::awaitable_wait);
1083 for (const auto output : node->outputs()) {
1084 setWildcard(output);
1085 }
1086 // the awaitable subgraph that `wait` is waiting on may write to any of its
1087 // inputs. We don't have a reliable way of recovering the awaitable inputs, so
1088 // for safety we just register a write to every wildcard.
1089 writeRegistry_->registerWriteToAllWildcards(node);
1090 }
1091
void AliasDb::analyzeRpcAsync(Node* node) {
1093 for (const auto input : node->inputs()) {
1094 setWildcard(input);
1095 }
1096
1097 // Give the future that the rpc_async emits a fresh value
1098 for (const auto output : node->outputs()) {
1099 giveFreshAlias(output);
1100 }
1101 }
1102
1103 namespace {
std::optional<bool> getConstantBooleanInput(
1105 Node* node,
1106 const std::string& inputName) {
1107 TORCH_INTERNAL_ASSERT(
1108 node->hasNamedInput(inputName), inputName + " input is expected");
1109 auto value = node->namedInput(inputName);
1110 TORCH_INTERNAL_ASSERT(
1111 value->type() == BoolType::get(),
      inputName + " input is expected to be a bool");
1113 return constant_as<bool>(value);
1114 }
1115 } // namespace
1116
1117 // custom behavior for batch_norm because (a!)? annotations currently
1118 // aren't supported, and because behavior differs depending on the value of
1119 // training
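// When `training` is true (or cannot be resolved to a constant), batch_norm
// updates running_mean and running_var in place, so those inputs are
// registered as written to below.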
void AliasDb::analyzeBatchNorm(Node* node) {
  // we invoke freezing for inference, so we assume `training` will be folded
  // to a constant false; this avoids needing to invoke freezing multiple times
  // in order to make batch norm weights constant
1124 for (Value* output : node->outputs()) {
1125 giveFreshAlias(output);
1126 }
1127
1128 if (isFrozen_) {
1129 return;
1130 }
1131
1132 auto isTraining = getConstantBooleanInput(node, "training");
1133
1134 if (!isTraining.has_value() || *isTraining) {
1135 TORCH_INTERNAL_ASSERT(
1136 node->hasNamedInput("running_mean"), "running_mean input is expected");
1137 auto runningMean = node->namedInput("running_mean");
1138 TORCH_INTERNAL_ASSERT(
1139 node->hasNamedInput("running_var"), "running_var input is expected");
1140 auto runningVar = node->namedInput("running_var");
1141
1142 registerWrite(runningMean, node);
1143 registerWrite(runningVar, node);
1144 }
1145 }
1146
1147 // custom behavior for instance_norm, because (a!)? annotations currently
1148 // aren't supported, and because behavior differs depending on the value of
1149 // use_input_stats
void AliasDb::analyzeInstanceNorm(Node* node) {
1151 for (Value* output : node->outputs()) {
1152 giveFreshAlias(output);
1153 }
1154
1155 auto useInputStats = getConstantBooleanInput(node, "use_input_stats");
1156
1157 if (!useInputStats.has_value() || *useInputStats) {
1158 TORCH_INTERNAL_ASSERT(
1159 node->hasNamedInput("running_mean"), "running_mean input is expected");
1160 auto runningMean = node->namedInput("running_mean");
1161 TORCH_INTERNAL_ASSERT(
1162 node->hasNamedInput("running_var"), "running_var input is expected");
1163 auto runningVar = node->namedInput("running_var");
1164
1165 registerWrite(runningMean, node);
1166 registerWrite(runningVar, node);
1167 }
1168 }
1169
1170 // SetAttr: writes to the `self` field
void AliasDb::analyzeSetAttr(Node* node) {
1172 const auto self = node->inputs().at(0);
1173 TORCH_INTERNAL_ASSERT(self->type()->kind() == TypeKind::ClassType);
1174 registerWrite(self, node);
1175 // Also the value being set must become a wildcard.
1176 const auto newValue = node->inputs().at(1);
1177 setWildcard(newValue);
1178 }
1179
1180 // Used for anything where we do not have accurate alias summaries
1181 // may write to any input and produce wildcards
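// e.g. a prim::CallMethod into an unknown function: every mutable input is
// treated as written to (including its contained elements), and every input
// and output joins the wildcard set.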
void AliasDb::analyzeConservative(Node* node) {
1183 for (const auto input : node->inputs()) {
1184 if (!isMutableTypeInternal(input)) {
1185 continue;
1186 }
1187 registerWrite(input, node, /*writeToContained=*/true);
1188 setWildcard(input);
1189 }
1190
1191 for (const auto output : node->outputs()) {
1192 setWildcard(output);
1193 }
1194 }
1195
bool AliasDb::functionalNonEscapingListUse(const Use& use) const {
1197 Node* n = use.user;
1198 size_t offset = use.offset;
1199 Value* container = n->inputs().at(offset);
1200
1201 // only consider aten op uses of lists
1202 if (!container->type()->cast<ListType>()) {
1203 return false;
1204 }
1205
1206 /*
1207 in the general case, we consider any Value that enters another container as
1208 entering the heap, and thus aliasing all other heap values of the same type.
  the advantages of this approach are:
1210 - there are many composite list/container ops that would be tricky to
1211 schematize if we did something more complicated
1212 - limits the size of the AliasDb, because a container of size 10 only contains
1213 1 memory dag element instead of 10
1214 - we do not need to worry about adding contained elements to the wildcard set
1215 when a container escapes the graph.
1216 The downside of this approach is we are unable to handle the common case of a
1217 list constructed and passed into an aten op. Here, optimize for a set of
1218 common ops where the output does not alias the list or the list elements
1219 */
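  // The common pattern this optimizes for is a list that is constructed and
  // consumed exactly once, e.g.
  //   %xs : Tensor[] = prim::ListConstruct(%a, %b)
  //   %out : Tensor = aten::cat(%xs, %dim)
  // where the output aliases neither the list nor its elements.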
1220
1221 // only used in output of graph - no further uses,
1222 // so there will be no use of it where the contained element leaks
1223 if (use.user->kind() == prim::Return) {
1224 return use.user->owningBlock() == graph_->block();
1225 }
1226
1227 switch (use.user->kind()) {
1228 case aten::cat:
1229 case aten::broadcast_tensors:
1230 case aten::stack:
1231 case aten::vstack:
1232 case aten::hstack:
1233 case aten::dstack:
1234 return true;
1235 }
1236 auto op = use.user->maybeOperator();
1237 if (op && op->aliasAnalysisKind() == AliasAnalysisKind::PURE_FUNCTION) {
1238 return true;
1239 }
1240 return false;
1241 }
1242
bool AliasDb::functionalNonEscapingTupleUse(const Use& use) const {
1244 Node* n = use.user;
1245 size_t offset = use.offset;
1246 Value* container = n->inputs().at(offset);
1247 if (!container->type()->cast<TupleType>()) {
1248 return false;
1249 }
1250 // TODO(T97387453): Cover more ops that do not let escape tuples' elements.
1251 bool in_return_outputs = use.user->kind() == prim::Return;
1252 bool not_in_nested_subgraph = use.user->owningBlock() == graph_->block();
1253 return in_return_outputs && not_in_nested_subgraph;
1254 }
1255
1256 // List or dict or tuple construct: create an aliasing element for the actual
1257 // container, then mark all inputs as wildcards, since they've gone inside the
1258 // container. Then, add the wildcard sets of appropriate type to the contained
1259 // elements of the container.
void AliasDb::analyzeContainerConstruct(Node* node) {
1261 TORCH_INTERNAL_ASSERT(
1262 node->kind() == prim::ListConstruct ||
1263 node->kind() == prim::DictConstruct ||
1264 node->kind() == prim::TupleConstruct);
1265
1266 // tuples which contain immutable types are immutable
1267 if (!isMutableTypeInternal(node->output())) {
1268 return;
1269 }
1270
1271 TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
1272 auto container = node->output();
1273
1274 // optimization:
1275 // if a list is only used once in an aten op, and the op output
1276 // doesn't alias the input, then we can add all inputs to the list's
1277 // contained elements instead of the wildcard set.
1278 if (container->uses().size() == 1 &&
1279 (functionalNonEscapingListUse(container->uses().at(0)) ||
1280 functionalNonEscapingTupleUse(container->uses().at(0)))) {
1281 giveFreshAlias(container, false);
1282 for (Value* v : node->inputs()) {
1283 addToContainedElements(v, container);
1284 }
1285 return;
1286 }
1287
1288 giveFreshAlias(container);
1289 auto container_elem = elementMap_.at(container);
1290 for (auto input : node->inputs()) {
1291 auto maybe_wildcard_elem = setWildcard(input);
1292 if (maybe_wildcard_elem) {
1293 memoryDAGBuilder_->addToContainedElements(
1294 *maybe_wildcard_elem, container_elem);
1295 }
1296 }
1297 }
1298
1299 // BroadcastingChunk: all inputs are broadcasted, and then individually chunked.
1300 // This is an intermediate node used only in the graph fuser.
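// e.g. with 2 inputs and chunks=3, outputs 0..2 alias inputs[0] and
// outputs 3..5 alias inputs[1].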
void AliasDb::analyzeBroadcastingChunk(Node* node) {
1302 auto inputs = node->inputs();
1303 auto outputs = node->outputs();
1304 auto nchunks = node->i(attr::chunks);
1305 for (const auto index : c10::irange(inputs.size())) {
1306 // Each inputs[i] is aliased by exactly `nchunks` distinct output tensors:
1307 // inputs[i] produces chunks outputs[i * nchunks + k] for k in [0..nchunks)
1308 auto output_begin = outputs.begin() + index * nchunks;
1309 for (auto it = output_begin; it != output_begin + nchunks; ++it) {
1310 makePointerTo(*it, inputs.at(index));
1311 }
1312 }
1313 }
1314
bool AliasDb::nonAliasingValue(const Value* elem) const {
1316 // these are values which can point to aliasing types in the graph,
1317 // as with a None value pointing to an optional if node output,
1318 // but will never alias themselves
1319 return elem->mustBeNone() || elem->node()->kind() == prim::Uninitialized;
1320 }
1321
1322 // Register the fact that `from` is a pointer to `to`
void AliasDb::makePointerTo(const Value* from, const Value* to) {
1324 if (nonAliasingValue(from) || nonAliasingValue(to)) {
1325 // if either value is guaranteed to be non-aliasing, we do not need to
1326 // connect the two elements. however, it is invariant that aliasing types
1327 // that are not wildcards have a memory dag element, so we create one if
1328 // needed
1329 giveFreshAlias(from);
1330 giveFreshAlias(to);
1331 return;
1332 }
1333
1334 // The contained types of immutable type containers (`Optional`,
1335 // `Tuple`, `Future`, and `Union`) are unified, so these types can be
1336 // mutable or immutable and point to a type which is mutable or
1337 // immutable. `Any` is mutable but can point to an immutable type
1338 // through refinement
1339 if (isMutableTypeInternal(from) != isMutableTypeInternal(to)) {
1340 return;
1341 }
1342 // both immutable
1343 if (!isMutableTypeInternal(from)) {
1344 return;
1345 }
1346 if (from == to) {
1347 return;
1348 }
1349
1350 // At this point, we are dealing with two mutable types
1351 auto from_el = getOrCreateElement(from);
1352 auto to_el = getOrCreateElement(to);
1353
1354 memoryDAGBuilder_->makePointerTo(from_el, to_el);
1355 }
1356
void AliasDb::addToContainedElements(
1358 const Value* inner,
1359 const Value* container) {
1360 if (!isMutableTypeInternal(inner)) {
1361 return;
1362 }
1363
1364 auto inner_el = getOrCreateElement(inner);
1365 auto cont_el = getOrCreateElement(container);
1366
1367 memoryDAGBuilder_->addToContainedElements(inner_el, cont_el);
1368 }
1369
bool AliasDb::mayAlias(const Value* a, const Value* b) const {
1371 if (!isMutableTypeInternal(a) || !isMutableTypeInternal(b)) {
1372 return false;
1373 }
1374
1375 return memoryDAG_->mayAlias(elementMap_.at(a), elementMap_.at(b));
1376 }
1377
bool AliasDb::mayAlias(const ValueSet& a, const ValueSet& b) const {
1379 if (a.empty() || b.empty()) {
1380 return false;
1381 }
1382
1383 // Record all memory locations from group `a`
1384 MemoryLocations aMemLocs;
1385 for (const auto value : a) {
1386 auto it = elementMap_.find(value);
1387 if (it != elementMap_.end()) {
1388 aMemLocs |= memoryDAG_->getMemoryLocations(it->second);
1389 }
1390 }
1391
1392 // If any of group `b`s memory locations overlap, return true.
1393 for (const auto value : b) {
1394 auto it = elementMap_.find(value);
1395 if (it != elementMap_.end()) {
1396 if (aMemLocs.intersects(memoryDAG_->getMemoryLocations(it->second))) {
1397 return true;
1398 }
1399 }
1400 }
1401 // No overlap, so group `a` and `b` do not share a memory location
1402 return false;
1403 }
1404
bool AliasDb::mayContainAlias(Value* a, Value* b) const {
1406 if (!isMutableTypeInternal(a) || !isMutableTypeInternal(b)) {
1407 return false;
1408 }
1409 return memoryDAG_->mayContainAlias(elementMap_.at(a), elementMap_.at(b));
1410 }
1411
std::vector<Element*> AliasDb::getElements(at::ArrayRef<Value*> vs) const {
1413 std::vector<Element*> elements;
1414 for (const auto& val : vs) {
1415 if (isMutableTypeInternal(val)) {
1416 elements.push_back(elementMap_.at(val));
1417 }
1418 }
1419 return elements;
1420 }
1421
bool AliasDb::mayContainAlias(
1423 const at::ArrayRef<Value*> a,
1424 const at::ArrayRef<Value*> b) const {
1425 auto a_elems = getElements(a);
1426 return a_elems.empty() ? false
1427 : memoryDAG_->mayContainAlias(a_elems, getElements(b));
1428 }
1429
bool AliasDb::mayContainAlias(Value* a, const at::ArrayRef<Value*> b) const {
1431 if (!isMutableTypeInternal(a)) {
1432 return false;
1433 }
1434 auto b_elems = getElements(b);
1435 return b_elems.empty()
1436 ? false
1437 : memoryDAG_->mayContainAlias(elementMap_.at(a), b_elems);
1438 }
1439
1440 // Make each value in the `from` list point to its partner in the `to` list
void AliasDb::mapAliases(at::ArrayRef<Value*> from, at::ArrayRef<Value*> to) {
1442 TORCH_INTERNAL_ASSERT(to.size() == from.size());
1443 for (const auto i : c10::irange(to.size())) {
1444 makePointerTo(from[i], to[i]);
1445 }
1446 }
1447
1448 // Should only be called from create_functional_graphs.
1449 // The asserts are to guard against unintentional use.
1450 // FIXME refactor aliasdb construction to be more robust to mutation so this
1451 // hack isn't necessary.
void AliasDb::createValue(const Value* value) {
1453 TORCH_INTERNAL_ASSERT(isMutableTypeInternal(value->type()));
1454 auto new_elem = memoryDAG_->unsafeMakeFreshValue(value);
1455 elementMap_[value] = new_elem;
1456 }
1457
void AliasDb::giveFreshAlias(
1459 const Value* value,
1460 bool add_wildcard_to_contained_elems) {
1461 auto maybe_mut_types = mapTypeToAliasTypeSetPtr(value->type());
1462 if (!maybe_mut_types) {
1463 return;
1464 }
1465
1466 if (elementMap_.count(value)) {
1467 // Inside a loop, we may have given a fresh alias to this value already, so
1468 // skip
1469 return;
1470 }
1471
1472 auto new_elem = memoryDAGBuilder_->makeFreshValue(value);
1473 elementMap_[value] = new_elem;
1474 if (add_wildcard_to_contained_elems) {
1475 if (maybe_mut_types->size() > 1) {
1476 pointUnionTypeElementToAllContainedTypes(new_elem, *maybe_mut_types);
1477 } else {
1478 addContainedTypesToFreshElement(new_elem, *maybe_mut_types);
1479 }
1480 }
1481 }
1482
Element* AliasDb::getOrCreateElement(const Value* value) {
1484 if (!elementMap_.count(value)) {
1485 giveFreshAlias(value);
1486 }
1487 return elementMap_.at(value);
1488 }
1489
void AliasDb::replaceWithNewValue(Value* existing, Value* new_value) {
1491 TORCH_INTERNAL_ASSERT(
1492 *unshapedType(existing->type()) == *unshapedType(new_value->type()),
1493 "Types must be strictly equal if you are replacing aliasing information. ",
1494 "Got existing: '",
1495 existing->type()->repr_str(),
1496 "', new_value: '",
1497 new_value->type()->repr_str(),
1498 "'");
1499 if (!isMutableTypeInternal(existing)) {
1500 return;
1501 }
1502 auto existing_elem = elementMap_.at(existing);
1503 elementMap_[new_value] = existing_elem;
1504 elementMap_.erase(existing);
1505 existing_elem->values = {new_value};
1506 }
1507
void AliasDb::copyValue(Value* from, Value* to) {
1509 TORCH_INTERNAL_ASSERT(
1510 *unshapedType(from->type()) == *unshapedType(to->type()),
1511 "Types must be strictly equal if you are copying aliasing information. ",
1512 "Got from: '",
1513 from->type()->repr_str(),
1514 "', to: '",
1515 to->type()->repr_str(),
1516 "'");
1517 if (!isMutableTypeInternal(to)) {
1518 return;
1519 }
1520 auto origElem = elementMap_.at(from);
1521 elementMap_[to] = origElem;
1522 origElem->values.insert(to);
1523 }
1524
bool AliasDb::moveAfterTopologicallyValid(Node* n, Node* movePoint) {
1526 return tryMove(n, movePoint, MoveSide::AFTER, /*dryRun=*/false);
1527 }
1528
bool AliasDb::couldMoveAfterTopologically(Node* n, Node* movePoint) {
1530 return tryMove(n, movePoint, MoveSide::AFTER, /*dryRun=*/true);
1531 }
1532
bool AliasDb::moveBeforeTopologicallyValid(Node* n, Node* movePoint) {
1534 // We have to distinguish the move side (instead of just moving after
1535 // n->prev()). Consider the following example:
1536 // If the dependency graph looks like
1537 // n -> movePoint -> o
1538 // then moveBefore(o) will end up with
1539 // n, o, movePoint
1540 // but moveAfter(n) will return false.
1541 return tryMove(n, movePoint, MoveSide::BEFORE, /*dryRun=*/false);
1542 }
1543
bool AliasDb::couldMoveBeforeTopologically(Node* n, Node* movePoint) {
1545 return tryMove(n, movePoint, MoveSide::BEFORE, /*dryRun=*/true);
1546 }
1547
bool AliasDb::hasWriters(const at::ArrayRef<Value*>& values) const {
1549 return std::any_of(values.begin(), values.end(), [&](Value* value) {
1550 return hasWriters(value);
1551 });
1552 }
1553
bool AliasDb::escapesScope(const at::ArrayRef<Value*>& vs) const {
1555 return mayContainAlias(graph_->inputs(), vs) ||
1556 mayContainAlias(graph_->outputs(), vs) || mayAliasWildcard(vs);
1557 }
1558
1559 // Correctness conditions:
1560 // no values in either set can have writers, and values in both sets
1561 // cannot escape the current graph scope. Values can escape the current scope
1562 // by aliasing a graph output or input, or by aliasing the wildcard set.
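// Concretely, the check below returns false if any value in `a` or `b` has a
// writer, or if both sets escape the current scope; otherwise it is safe.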
bool AliasDb::safeToChangeAliasingRelationship(
1564 const at::ArrayRef<Value*>& a,
1565 const at::ArrayRef<Value*>& b) const {
1566 if (hasWriters(a) || hasWriters(b)) {
1567 return false;
1568 }
1569
1570 return !(escapesScope(a) && escapesScope(b));
1571 }
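// Illustrative sketch for safeToChangeAliasingRelationship above (hypothetical
// value lists `a_vals` and `b_vals`): a pass that wants to introduce or remove
// aliasing between two sets of values can guard on this check. If it returns
// true, no value in either set is written to and at most one of the sets
// escapes the graph scope, so the change cannot be observed through mutation.
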
// Helper for topologically-safe node moves. See `tryMove()` for details.
class AliasDb::WorkingSet {
 public:
  explicit WorkingSet(Node* mover, const AliasDb& aliasDb)
      : aliasDb_(aliasDb), mover_(mover) {
    for (const auto user : getUsersSameBlock(mover_)) {
      moverUsers_.insert(user);
    }
    moverWrites_ |= aliasDb_.getWrites(mover_);
    moverReads_ |= aliasDb_.getReads(mover_);
  }

  // Add `n` to the working set
  void add(Node* n) {
    nodes_.push_back(n);
    node_to_index_[n] = static_cast<int64_t>(nodes_.size()) - 1;
    for (const auto user : getUsersSameBlock(n)) {
      users_.insert(user);
    }

    writes_ |= aliasDb_.getWrites(n);
    reads_ |= aliasDb_.getReads(n);
  }

  void eraseMover() {
    mover_ = nullptr;
    moverWrites_.clear();
    moverReads_.clear();
    moverUsers_.clear();
  }

  const std::vector<Node*>& dependentNodes() {
    return nodes_;
  }

  // Does the working set depend on `n`?
  bool dependsOn(Node* n) const {
    if (!mover_ && nodes_.empty()) {
      return false;
    }

    return hasDataDependency(n) || hasMutabilityDependency(n);
  }

 private:
  bool hasDataDependency(Node* n) const {
    if (!mover_ && nodes_.empty()) {
      return false;
    }
    const Node* pivot = mover_ ? mover_ : nodes_.front();
    if (n->isAfter(pivot)) {
      return producesFor(n);
    } else {
      return consumesFrom(n);
    }
  }

  bool hasMutabilityDependency(Node* n) const {
    // Check that `n` does not write to anything used by the working set
    const auto& nWrites = aliasDb_.getWrites(n);
    if (reads_.intersects(nWrites)) {
      return true;
    }
    if (mover_ && moverReads_.intersects(nWrites)) {
      return true;
    }

    // Check that the working set doesn't write to anything that `n` uses.
    const auto& nReads = aliasDb_.getReads(n);
    if (writes_.intersects(nReads)) {
      return true;
    }
    if (mover_ && moverWrites_.intersects(nReads)) {
      return true;
    }
    return false;
  }
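  // Illustrative example of the checks above (hypothetical IR): if the working
  // set reads %x and `n` is an in-place op such as `aten::add_(%x, ...)`, then
  // `reads_.intersects(nWrites)` fires; symmetrically, if the working set
  // writes %y and `n` reads %y, `writes_.intersects(nReads)` fires. Either
  // case means the working set cannot be reordered past `n`.
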
  // Does the working set produce any values consumed by `n`?
  bool producesFor(Node* n) const {
    // This is equivalent to asking: does the total use-set of all the nodes in
    // the working set include `n`?
    if (mover_ && moverUsers_.count(n)) {
      return true;
    }
    return users_.count(n) != 0;
  }

  // Does the working set consume any values produced by `n`?
  bool consumesFrom(Node* n) const {
    const auto users = getUsersSameBlock(n);

    if (mover_ && users.count(mover_)) {
      return true;
    }
    return std::any_of(users.begin(), users.end(), [&](Node* user) {
      return node_to_index_.find(user) != node_to_index_.end();
    });
  }

  // Get all users of outputs of `n`, in the same block as `n`.
  // This means if there is an `if` node that uses an output of `n` in some
  // inner sub-block, we will consider the whole `if` node a user of `n`.
  std::unordered_set<Node*> getUsersSameBlock(Node* n) const {
    std::unordered_set<Node*> users;
    for (const auto output : n->outputs()) {
      for (const auto& use : output->uses()) {
        if (auto sameBlock = findSameBlock(use.user, n)) {
          users.insert(sameBlock);
        }
      }
    }
    return users;
  }

  // Traverse `target`'s chain of owning blocks upward until we find a node
  // that shares a block with `n`.
  //
  // If one can't be found (say, because `n` is in an inner block and `target`
  // is outside), then return nullptr. Since we can only reorder nodes within a
  // block, `target` would be irrelevant.
  static Node* findSameBlock(Node* target, Node* n) {
    TORCH_INTERNAL_ASSERT(target->owningGraph() == n->owningGraph());
    if (target->owningBlock() == n->owningBlock()) {
      return target;
    } else {
      // This user is in a sub-block. Traverse the chain of owning blocks
      // upward until we arrive at a node that shares a block with `n`.
      auto curNode = target;
      while (curNode->owningBlock() != n->owningBlock()) {
        curNode = curNode->owningBlock()->owningNode();
        if (curNode == nullptr) {
          return curNode;
        }
      }
      return curNode;
    }
  }

  // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
  const AliasDb& aliasDb_;
  std::vector<Node*> nodes_;
  // Extra index over `nodes_` for faster lookup. Since `tryMove` is called
  // frequently, we want membership checks to be as fast as possible.
  std::unordered_map<Node*, int64_t> node_to_index_;

  // Mover dependencies. We track these separately since we may erase the mover
  // from the working set.
  Node* mover_;
  MemoryLocations moverWrites_;
  MemoryLocations moverReads_;
  std::unordered_set<Node*> moverUsers_;

  // Users of any output produced by the working set
  std::unordered_set<Node*> users_;
  // Memory locations written to / read from by the working set
  MemoryLocations writes_;
  MemoryLocations reads_;
};
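// Worked example of the working-set approach used by tryMove() below
// (hypothetical IR): to move node `a` after node `d` in the sequence
//   a; b (uses a's output); c; d
// we walk forward from `a`. `b` consumes `a`'s output, so it joins the
// working set; `c` is independent and is skipped over. If `d` has no
// dependency on {a, b}, the set ends up just after `d`, giving c, d, a, b;
// otherwise the move is rejected and the graph is left untouched.
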
// Try to move `toMove` before/after `movePoint` while preserving value
// dependencies. Returns false iff such a move could not be made.
//
// If `dryRun` is set, don't actually execute the move; just check whether the
// move is possible.
//
// The basic approach is: have a "working set" that we are moving forward, one
// node at a time. When we can't move past a node (because it depends on the
// working set), then add it to the working set and keep moving until we hit
// `movePoint`.
bool AliasDb::tryMove(
    Node* toMove,
    Node* movePoint,
    MoveSide moveSide,
    bool dryRun) {
  if (toMove->owningBlock() != movePoint->owningBlock()) {
    return false;
  }
  if (toMove == movePoint) {
    return true;
  }

  // 1. Move from `toMove` toward `movePoint`, building up the working set of
  // dependencies
  WorkingSet workingSet(toMove, *this);

  auto direction = kNextDirection;
  if (toMove->isAfter(movePoint)) {
    direction = kPrevDirection;
  }

  auto curNode = toMove->next_in_graph[direction];

  bool toMoveIsOnMoveSide =
      (moveSide == MoveSide::BEFORE && toMove->isBefore(movePoint)) ||
      (moveSide == MoveSide::AFTER && toMove->isAfter(movePoint));

  if (toMoveIsOnMoveSide && curNode == movePoint) {
    return true;
  }

  // It is never valid to reorder a node with side effects
  if (toMove->hasSideEffects() ||
      (!toMoveIsOnMoveSide && movePoint->hasSideEffects())) {
    return false;
  }

  // Move forward one node at a time
  while (curNode != movePoint) {
    // never valid to reorder around a node with side effects
    if (curNode->hasSideEffects()) {
      return false;
    }

    if (workingSet.dependsOn(curNode)) {
      // If we can't move past this node, add it to the working set
      workingSet.add(curNode);
    }
    curNode = curNode->next_in_graph[direction];
  }

  // 2. Decide whether we can move it all to `movePoint`.

  // Say we are moving directly before `movePoint` and `toMove` starts before
  // `movePoint` in the graph. The move looks like
  //
  //   `toMove`            `toMove`         |
  //   <dependencies>  ->  `movePoint`      | `toMove` and deps are split
  //   `movePoint`         <dependencies>   |
  //
  // Contrast with the case where `toMove` starts AFTER `movePoint`:
  //
  //   `movePoint`         <dependencies>   |
  //   <dependencies>  ->  `toMove`         | `toMove` and deps are together
  //   `toMove`            `movePoint`      |
  //
  // In the first case, we need to split `toMove` off from its dependencies, so
  // we can move the dependencies below `movePoint` and keep `toMove` above.
  const bool splitToMoveAndDeps =
      (moveSide == MoveSide::BEFORE && toMove->isBefore(movePoint)) ||
      (moveSide == MoveSide::AFTER && toMove->isAfter(movePoint));

  if (splitToMoveAndDeps) {
    // remove `toMove` from the dependencies to be moved past `movePoint`
    workingSet.eraseMover();
  }

  // Check if we can move the working set past the move point
  if (workingSet.dependsOn(movePoint)) {
    // If we can't, then there are intermediate dependencies between `toMove`
    // and `movePoint`, so we can't do the move
    return false;
  }

  if (dryRun) {
    return true;
  }

  // 3. Execute the move
  TORCH_INTERNAL_ASSERT(curNode == movePoint);
  if (splitToMoveAndDeps) {
    // Move `toMove`
    move(toMove, movePoint, moveSide);

    // Then move all of its dependencies to the other side of `movePoint`
    const auto reversed =
        moveSide == MoveSide::BEFORE ? MoveSide::AFTER : MoveSide::BEFORE;
    for (auto n : workingSet.dependentNodes()) {
      move(n, curNode, reversed);
      curNode = n;
    }
  } else {
    // Just append/prepend everything to `movePoint`
    move(toMove, curNode, moveSide);
    curNode = toMove;
    for (auto n : workingSet.dependentNodes()) {
      move(n, curNode, moveSide);
      curNode = n;
    }
  }
  return true;
}
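// Note on tryMove() above: on a successful non-dry-run move, the dependent
// nodes collected in the working set keep their original relative order; only
// their position with respect to `movePoint` changes.
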
// Helper function so we can generalize `tryMove`
void AliasDb::move(Node* toMove, Node* movePoint, MoveSide moveSide) {
  switch (moveSide) {
    case MoveSide::BEFORE:
      toMove->moveBefore(movePoint);
      break;
    case MoveSide::AFTER:
      toMove->moveAfter(movePoint);
      break;
  }
}

bool AliasDb::writesToWildcard(Node* n) const {
  if (!writeIndex_->count(n)) {
    return false;
  }
  const auto& writes = writeIndex_->at(n);

  // Are any of these memoryLocs a wildcard element?
  for (const auto& pr : wildcardIndex_) {
    const auto wildcardElement = pr.second;
    if (writes.test(wildcardElement->index)) {
      return true;
    }
  }
  return false;
}

bool AliasDb::mayAliasWildcard(const Value* v) const {
  if (auto e = getWildcard(v->type())) {
    return memoryDAG_->mayAlias(elementMap_.at(v), e);
  }
  // There were no wildcards of this type, so return false.
  return false;
}

bool AliasDb::mayAliasWildcard(const at::ArrayRef<Value*> vs) const {
  return std::any_of(
      vs.begin(), vs.end(), [&](Value* v) { return mayAliasWildcard(v); });
}

std::optional<Element*> AliasDb::tryGetOrCreateWildcard(const TypePtr& type) {
  auto maybe_mut_types = mapTypeToAliasTypeSetPtr(type);
  if (!maybe_mut_types) {
    return std::nullopt;
  }
  auto mut_type = toSingleType(*maybe_mut_types);
  auto existing_wildcard = wildcardIndex_.find(*mut_type);
  if (existing_wildcard != wildcardIndex_.end()) {
    return existing_wildcard->second;
  }

  auto wildcard_elem = memoryDAGBuilder_->makeFreshValue(nullptr);
  wildcardIndex_.emplace(*std::move(mut_type), wildcard_elem);
  if (maybe_mut_types->size() > 1) {
    pointUnionTypeElementToAllContainedTypes(wildcard_elem, *maybe_mut_types);
  } else {
    addContainedTypesToFreshElement(wildcard_elem, *maybe_mut_types);
  }
  return wildcard_elem;
}
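// Note on tryGetOrCreateWildcard() above: in the common case the mapped
// mutable type set has a single entry, so the wildcard is keyed on that type
// and its contained types are registered via addContainedTypesToFreshElement
// below. When the set has several entries (a union of mutable types), the
// union's wildcard instead points at each member type's wildcard, so the
// union may alias any of its members.
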
void AliasDb::pointUnionTypeElementToAllContainedTypes(
    Element* container_elem,
    const AliasTypeSet& mut_types) {
  for (const auto& mut_type : mut_types) {
    auto maybe_elem = tryGetOrCreateWildcard(mut_type);
    if (maybe_elem) {
      TORCH_INTERNAL_ASSERT(*maybe_elem != container_elem);
      memoryDAGBuilder_->makePointerTo(container_elem, *maybe_elem);
    }
  }
}

void AliasDb::addContainedTypesToFreshElement(
    Element* container_elem,
    const AliasTypeSet& mut_types) {
  for (const auto& mut_type : mut_types) {
    for (const auto& contained : mut_type->containedTypes()) {
      auto maybe_elem = tryGetOrCreateWildcard(contained);
      if (maybe_elem) {
        memoryDAGBuilder_->addToContainedElements(*maybe_elem, container_elem);
      }
    }
  }
}

// Search the wildcard index for an element that corresponds to the given type.
// Returns nullptr if no corresponding wildcard element exists.
Element* AliasDb::getWildcard(const TypePtr& type) const {
  auto maybe_mut_types = mapTypeToAliasTypeSetPtr(type);
  if (!maybe_mut_types) {
    return {};
  }
  if (maybe_mut_types->size() > 1) {
    auto union_type = UnionType::create(*maybe_mut_types);
    // Get a <TypePtr, Element*> pair where the TypePtr is this Union
    // type and the Element is the corresponding Wildcard
    auto maybe_union_pair = wildcardIndex_.find(union_type);
    if (maybe_union_pair != wildcardIndex_.end()) {
      return (*maybe_union_pair).second;
    }
  } else {
    // Get a <TypePtr, Element*> pair where the TypePtr is the given
    // type and the Element is the corresponding Wildcard
    auto type_pair = wildcardIndex_.find((*maybe_mut_types)[0]);
    if (type_pair != wildcardIndex_.end()) {
      return type_pair->second;
    }
  }
  return {};
}

// Register `v` as a wildcard value.
std::optional<Element*> AliasDb::setWildcard(const Value* v) {
  std::optional<Element*> maybe_wildcardElement =
      tryGetOrCreateWildcard(v->type());
  if (!maybe_wildcardElement) {
    return std::nullopt;
  }
  // Ensure that we still create a corresponding Element for `v`, as it is an
  // invariant that all mutable values have an Element
  getOrCreateElement(v);
  wildcards_.insert(v);
  return maybe_wildcardElement;
}

void AliasDb::buildWrittenToLocationsIndex() {
  MemoryLocations ret;
  for (const auto& pr : *writeIndex_) {
    const auto& writtenLocs = pr.second;
    ret |= writtenLocs;
  }
  writtenToLocationsIndex_ = ret;
}
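// Note on buildWrittenToLocationsIndex() above: the index is the union of
// every node's write set, so answering "was this memory location written to
// anywhere in the graph?" becomes a single bitset intersection rather than a
// walk over writeIndex_.
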
void Lint(const AliasDb* db) {
  bool failed = false;

  std::stringstream ss;
  // Every mutable value in the system has a corresponding element.
  for (const auto& v : db->graph_->all_values) {
    if (!db->isMutableTypeInternal(v)) {
      continue;
    }
    auto it = db->elementMap_.find(v);
    if (it == db->elementMap_.end()) {
      failed = true;
      ss << "Value %" << v->debugName() << " of type " << v->type()->repr_str()
         << " wasn't found in the element map.\n"
         << "It was defined in " << *v->node();
    }
  }
  TORCH_INTERNAL_ASSERT(!failed, ss.str());

  // Two checks that we want to add but can't until the mutation API is more
  // fully developed.
  // - Every mutable value in the aliasdb belongs to the graph
  // - All container values have contained elements
}

} // namespace torch::jit