// xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/ir/alias_analysis.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/ir/alias_analysis.h>
2 
3 #include <ATen/core/interned_strings.h>
4 #include <c10/util/flat_hash_map.h>
5 #include <c10/util/irange.h>
6 #include <torch/csrc/jit/api/function_impl.h>
7 #include <torch/csrc/jit/jit_log.h>
8 #include <torch/csrc/jit/passes/inliner.h>
9 #include <torch/csrc/jit/passes/utils/subgraph_utils.h>
10 #include <torch/csrc/jit/runtime/operator.h>
11 #include <fstream>
12 #include <iostream>
13 
14 namespace torch::jit {
15 
16 namespace {
17 
18 c10::MaybeOwned<TypePtr> toSingleType(const AliasTypeSet& mut_types) {
19   return mut_types.size() == 1
20       ? c10::MaybeOwned<TypePtr>::borrowed(mut_types[0])
21       : c10::MaybeOwned<TypePtr>::owned(c10::UnionType::create(mut_types));
22 }
23 
24 // This class determines whether a type is mutable, and, if so, it maps
25 // the type to its "mutable equivalent" (see definition in
26 // `mapTypeToAliasTypeSet`). It uses a cache of TypePtrs to speed up these
27 // type lookups
28 class MutableTypePtrHelper {
29  public:
30   explicit MutableTypePtrHelper(
31       ska::flat_hash_map<TypePtr, AliasTypeSet>* mutable_type_cache)
32       : mutable_type_cache_(mutable_type_cache) {}
33 
34   // Map any mutable type to a type such that all other types which the
35   // mutable type can alias will be mapped to the same type. For
36   // example, calling this method on `Optional[List[int]]` should be
37   // the same as calling this method on `List[int]`.
38   //
39   // Rules:
40   //   - If the type is not mutable, return `nullopt`
41   //   - If the type is a `Tuple`, that means that it's an immutable
42   //     object that can itself contain mutable objects. We want to make
43   //     sure that the mutable objects are correctly aliased, so we
44   //     remove the immutable objects. (For example,
45   //     `Tuple[int, Tensor]` would become `Tuple[Tensor]`, while
46   //     `Tuple[int, str]` would be returned as `nullopt`.) This is a
47   //     convenience that makes it easy to check if the `Tuple`
48   //     contains only immutable objects, though it's not technically
49   //     necessary
50   //   - For any Tensor type (including Tensor types that are part of
51   //     a larger container, e.g. `List[Tensor]`), return the
52   //     "unshaped" version of that Tensor. An "unshaped" Tensor is a
53   //     Tensor with shape information removed. For example, a Tensor
54   //     of dimension 4 would map to the same type as a Tensor of
55   //     dimension 1. This allows us to treat all subclasses of Tensor
56   //     as a single, homogeneous "Tensor" type.
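  //
  // A few mappings implied by the rules above (informal summary, for
  // orientation only):
  //   Tensor (any shape/dtype)  -> {unshaped Tensor}
  //   Optional[List[int]]       -> {List[int]}
  //   Tuple[int, Tensor]        -> {Tuple[Tensor]}
  //   Tuple[int, str]           -> nullopt
  //   Union[Tensor, str]        -> {Tensor}
  //   int, str, None            -> nullopt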
57   std::optional<AliasTypeSet> mapTypeToAliasTypeSet(const TypePtr& type) {
58     if (mutable_type_cache_) {
59       const AliasTypeSet* result = mapTypeToBorrowedAliasTypeSet(type);
60       if (result) {
61         return *result;
62       }
63     }
64     return mapTypeToAliasTypeSetImpl(type);
65   }
66 
67   const AliasTypeSet* mapTypeToBorrowedAliasTypeSet(const TypePtr& type) {
68     TORCH_INTERNAL_ASSERT_DEBUG_ONLY(mutable_type_cache_ != nullptr);
69     auto maybe_type_mapping = mutable_type_cache_->find(type);
70     if (maybe_type_mapping != mutable_type_cache_->end()) {
71       return &maybe_type_mapping->second;
72     }
73 
74     auto mutable_types = mapTypeToAliasTypeSetImpl(type);
75     if (mutable_types) {
76       auto it =
77           mutable_type_cache_->emplace(type, std::move(*mutable_types)).first;
78       return &it->second;
79     } else {
80       return nullptr;
81     }
82   }
83 
84  private:
85   std::optional<AliasTypeSet> mapTypeToAliasTypeSetImpl(const TypePtr& type) {
86     switch (type->kind()) {
87       case TypeKind::ListType:
88       case TypeKind::DictType:
89       case TypeKind::ClassType:
90       case TypeKind::TensorType:
91         // TODO: Look up cached contained types. this is kind of tricky
92         // because a `List[Optional[T]]` should still be
93         // `List[Optional[Unshaped(T)]]`, but
94         // `mapTypeToAliasTypeSet(Optional[T])` should be `T`
95         return AliasTypeSet{unshapedType(type)};
96       case TypeKind::UnionType: {
97         AliasTypeSet mutable_types;
98         for (const TypePtr& inner :
99              type->expectRef<UnionType>().containedTypes()) {
100           if (auto maybe_inner_types = mapTypeToAliasTypeSet(inner)) {
101             mutable_types.insert(
102                 mutable_types.end(),
103                 (*maybe_inner_types).begin(),
104                 (*maybe_inner_types).end());
105           }
106         }
107         if (mutable_types.empty()) {
108           return std::nullopt;
109         }
110         return mutable_types;
111       }
112       case TypeKind::OptionalType: {
113         auto inner = type->castRaw<OptionalType>()->getElementType();
114         return mapTypeToAliasTypeSet(inner);
115       }
116       case TypeKind::AnyType:
117         return {AliasTypeSet{type}};
118       case TypeKind::FutureType: {
119         if (auto maybe_mut_types = mapTypeToAliasTypeSet(
120                 type->castRaw<FutureType>()->getElementType())) {
121           return {AliasTypeSet{
122               FutureType::create(*toSingleType(*maybe_mut_types))}};
123         }
124         return std::nullopt;
125       }
126       case TypeKind::AwaitType: {
127         if (auto maybe_mut_types = mapTypeToAliasTypeSet(
128                 type->castRaw<AwaitType>()->getElementType())) {
129           return {
130               AliasTypeSet{AwaitType::create(*toSingleType(*maybe_mut_types))}};
131         }
132         return std::nullopt;
133       }
134       case TypeKind::TupleType: {
135         std::vector<TypePtr> mutable_types;
136         for (const TypePtr& inner : type->expectRef<TupleType>().elements()) {
137           if (auto maybe_inner_types = mapTypeToAliasTypeSet(inner)) {
138             mutable_types.insert(
139                 mutable_types.end(),
140                 (*maybe_inner_types).begin(),
141                 (*maybe_inner_types).end());
142           }
143         }
144         if (mutable_types.empty()) {
145           return std::nullopt;
146         }
147         return {AliasTypeSet{TupleType::create(mutable_types)}};
148       }
149       default:
150         return std::nullopt;
151     }
152   }
153   ska::flat_hash_map<TypePtr, AliasTypeSet>* mutable_type_cache_;
154 };
155 
156 bool isMutableTypeImpl(
157     const TypePtr& type,
158     ska::flat_hash_map<TypePtr, AliasTypeSet>* mutable_type_cache) {
159   // Check common cases to avoid recursively constructing type in
160   // `mapTypeToAliasTypeSetImpl`
161   auto kind = type->kind();
162   if (kind == TypeKind::TensorType || kind == TypeKind::ListType ||
163       kind == TypeKind::ClassType || kind == TypeKind::DictType) {
164     return true;
165   }
166   MutableTypePtrHelper helper(mutable_type_cache);
167   if (mutable_type_cache) {
168     return helper.mapTypeToBorrowedAliasTypeSet(type) != nullptr;
169   } else {
170     return helper.mapTypeToAliasTypeSet(type).has_value();
171   }
172 }
173 
174 } // namespace
175 
176 // Static `isMutableType` does not use cache of type -> mutable type equivalent
177 bool AliasDb::isMutableType(const TypePtr& type) {
178   return isMutableTypeImpl(type, nullptr);
179 }
180 
181 bool AliasDb::isMutableType(const Value* v) {
182   return isMutableType(v->type());
183 }
184 
185 // Make use of type -> mutable cache
186 bool AliasDb::isMutableTypeInternal(const TypePtr& type) const {
187   return isMutableTypeImpl(type, &mapped_mutable_types_);
188 }
189 
190 bool AliasDb::isMutableTypeInternal(const Value* v) const {
191   return isMutableTypeInternal(v->type());
192 }
193 
194 const AliasTypeSet* AliasDb::mapTypeToAliasTypeSetPtr(
195     const TypePtr& type) const {
196   MutableTypePtrHelper helper(&mapped_mutable_types_);
197   return helper.mapTypeToBorrowedAliasTypeSet(type);
198 }
199 
200 AliasDb::~AliasDb() = default;
201 
202 // Structure used during analysis to keep track of all writes at a high
203 // level. When the analysis is completed, this will be used to construct
204 // a more efficient WriteIndex
205 struct AliasDb::WriteRegistry {
206   void registerWrite(const Value* v, Node* n) {
207     writes_[n].emplace_back(v);
208   }
209   void registerWriteToAllContained(const Value* v, Node* n) {
210     containedWrites_[n].emplace_back(v);
211   }
212   void registerWriteToAllWildcards(Node* n) {
213     writesToAllWildcards_.insert(n);
214   }
215   std::unordered_map<Node*, std::vector<const Value*>> writes_;
216   std::unordered_map<Node*, std::vector<const Value*>> containedWrites_;
217   std::unordered_set<Node*> writesToAllWildcards_;
218 };
219 
220 AliasDb::AliasDb(
221     std::shared_ptr<Graph> graph,
222     bool isFrozen,
223     bool descendFunctionCalls)
224     : graph_(std::move(graph)),
225       isFrozen_(isFrozen),
226       descend_function_calls_(descendFunctionCalls),
227       memoryDAGBuilder_(std::make_unique<MemoryDAGBuilder>()),
228       writeRegistry_(std::make_unique<AliasDb::WriteRegistry>()) {
229   analyze(graph_);
230 
231   memoryDAG_ = std::move(*memoryDAGBuilder_).createMemoryDAG();
232   // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
233   memoryDAGBuilder_ = nullptr; // to make further access a hard error
234 
235   memoryDAG_->setWildcards(
236       wildcards_, elementMap_, [&](const Value* v) -> Element* {
237         return getWildcard(v->type());
238       });
239 
240   // Now we build up the various write indices based on information in the write
241   // registry that we populated during analysis
242 
243   // Initialize the write index
244   writeIndex_ = TWriteIndex();
245   auto& writeIndex = *writeIndex_; // to make operator[] less ugly
246 
247   // Build the write index
248   for (const auto& write : writeRegistry_->writes_) {
249     Node* node = write.first;
250     const std::vector<const Value*> writtenValues = write.second;
251     for (const Value* writtenValue : writtenValues) {
252       auto it = elementMap_.find(writtenValue);
253       TORCH_INTERNAL_ASSERT(
254           it != elementMap_.end(), "Tried to write to value not in MemoryDAG");
255       const auto& writtenMemoryLocations =
256           memoryDAG_->getMemoryLocations(it->second);
257       writeIndex[node] |= writtenMemoryLocations;
258     }
259   }
260 
261   for (const auto& write : writeRegistry_->containedWrites_) {
262     Node* node = write.first;
263     const std::vector<const Value*>& writtenValues = write.second;
264     for (const Value* writtenValue : writtenValues) {
265       auto elem = elementMap_.at(writtenValue);
266       MemoryLocations writtenMemoryLocations;
267       memoryDAG_->collectAllContainedMemoryLocations(
268           elem, writtenMemoryLocations);
269       writeIndex[node] |= writtenMemoryLocations;
270     }
271   }
272 
273   for (const auto& write : writeRegistry_->writesToAllWildcards_) {
274     for (const auto& pr : wildcardIndex_) {
275       writeIndex[write].set(pr.second->index);
276     }
277   }
278 
279   // Now that we've built the write index, we can null out the WriteRegistry to
280   // make future access an error. In this way we prevent the index from getting
281   // out of sync (since we have no way of registering new writes)
282   // NOLINTNEXTLINE(cppcoreguidelines-prefer-member-initializer)
283   writeRegistry_ = nullptr;
284 
285   // Initialize the write cache
286   buildWrittenToLocationsIndex();
287   GRAPH_DEBUG(toString());
288 }
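// A rough usage sketch for callers (illustrative only; see alias_analysis.h
// for the authoritative API and the default constructor arguments):
//
//   AliasDb aliasDb(graph);
//   if (!aliasDb.hasWriters(value)) {
//     // nothing in the graph writes to `value` or to memory it may alias
//   }
//   if (aliasDb.moveAfterTopologicallyValid(node, movePoint)) {
//     // `node` was moved after `movePoint` without reordering any
//     // conflicting reads/writes
//   }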
289 
290 bool AliasDb::isMutable(Node* n) const {
291   ValueSet vs;
292   for (const auto input : n->inputs()) {
293     vs.insert(input);
294   }
295   return writesToAlias(n, vs);
296 }
297 
298 bool AliasDb::hasInputWriters(const Node* n) const {
299   for (const auto input : n->inputs()) {
300     if (hasWriters(input)) {
301       return true;
302     }
303   }
304   return false;
305 }
306 
307 bool AliasDb::hasOutputWriters(const Node* n) const {
308   for (const auto output : n->outputs()) {
309     if (hasWriters(output)) {
310       return true;
311     }
312   }
313   return false;
314 }
315 
316 bool AliasDb::hasWriters(const Node* n) const {
317   return hasInputWriters(n) || hasOutputWriters(n);
318 }
319 
320 bool AliasDb::hasWriters(const Value* v) const {
321   if (v->mustBeNone()) {
322     return false;
323   }
324 
325   auto it = elementMap_.find(v);
326   if (it == elementMap_.end()) {
327     return false;
328   }
329 
330   const auto& el = it->second;
331   return writtenToLocationsIndex_->intersects(
332       memoryDAG_->getMemoryLocations(el));
333 }
334 
335 void AliasDb::getWritesImpl(Node* n, MemoryLocations& ret) const {
336   if (writeIndex_->count(n)) {
337     const auto& writes = writeIndex_->at(n);
338     ret |= writes;
339   }
340 
341   for (auto block : n->blocks()) {
342     for (auto node : block->nodes()) {
343       getWritesImpl(node, ret);
344     }
345   }
346 }
347 
348 // Does `n` write to an alias of one of the values in `vs`?
349 bool AliasDb::writesToAlias(Node* n, const ValueSet& vs) const {
350   const auto writtenTo = getWrites(n);
351   if (writtenTo.empty()) {
352     return false;
353   }
354 
355   MemoryLocations locs;
356   for (const auto v : vs) {
357     auto it = elementMap_.find(v);
358     if (it != elementMap_.end()) {
359       const auto& vlocs = memoryDAG_->getMemoryLocations(it->second);
360       if (writtenTo.intersects(vlocs)) {
361         return true;
362       }
363     }
364   }
365 
366   return false;
367 }
368 
369 MemoryLocations AliasDb::getWrites(Node* n) const {
370   MemoryLocations writes;
371   getWritesImpl(n, writes);
372   return writes;
373 }
374 
375 void AliasDb::getReadsImpl(Node* n, MemoryLocations& ret) const {
376   for (const auto input : n->inputs()) {
377     auto it = elementMap_.find(input);
378     if (it != elementMap_.end()) {
379       auto el = it->second;
380 
381       // Add all memory locations this element may alias and their contained
382       // elements
383       memoryDAG_->collectAllContainedMemoryLocations(el, ret);
384     }
385   }
386 
387   for (auto block : n->blocks()) {
388     for (auto node : block->nodes()) {
389       getReadsImpl(node, ret);
390     }
391   }
392 }
393 
394 MemoryLocations AliasDb::getReads(Node* n) const {
395   MemoryLocations reads;
396   getReadsImpl(n, reads);
397   return reads;
398 }
399 
400 std::string AliasDb::getElementName(const Element* e) const {
401   if (e->values.empty()) {
402     // Not the most efficient way, but given that there are
403     // not too many types and even fewer of them will end up in
404     // `wildcardIndex_`, we should be fine with a linear search
405     // each time we hit a Wildcard leaf
406     for (const auto& ent : wildcardIndex_) {
407       if (ent.second == e) {
408         return std::string("WILDCARD for type ") + ent.first->str();
409       }
410     }
411     return "WILDCARD";
412   } else {
413     std::ostringstream ss;
414     if (e->values.size() == 1) {
415       ss << "%" << (*e->values.begin())->debugName();
416       return ss.str();
417     }
418     ss << "(";
419     for (const Value* v : e->values) {
420       ss << "%" << v->debugName() << ", ";
421     }
422     ss << ")";
423     return ss.str();
424   }
425 }
426 
427 void AliasDb::dump() const {
428   std::cout << toString();
429 }
430 
431 std::string AliasDb::toString() const {
432   std::stringstream ss{};
433 
434   ss << "\n===1. GRAPH===\n";
435   ss << graph_->toString();
436 
437   ss << "\n===2. ALIAS DB===\n";
438   for (const auto& ptrPair : elementMap_) {
439     const auto element = ptrPair.second;
440     int ct = 0;
441     if (!element->pointsTo.empty()) {
442       ss << getElementName(element) << " points to: ";
443       for (const auto pointedTo : element->pointsTo) {
444         if (ct > 0) {
445           ss << ", ";
446         }
447         ++ct;
448         ss << getElementName(memoryDAG_->fromIndex(pointedTo));
449       }
450       ss << "\n";
451     }
452     ct = 0;
453     if (!element->containedElements.empty()) {
454       ss << getElementName(element) << " contains: ";
455       for (const auto contained : element->containedElements) {
456         ss << getElementName(memoryDAG_->fromIndex(contained));
457         if (ct > 0) {
458           ss << ", ";
459         }
460         ++ct;
461       }
462       ss << "\n";
463     }
464   }
465 
466   ss << "\n===3. Writes===\n";
467   for (const auto& pr : *writeIndex_) {
468     const auto node = pr.first;
469     const auto& values = pr.second;
470     ss << *node;
471     ss << "  ";
472     for (const auto value : values) {
473       ss << getElementName(memoryDAG_->fromIndex(value)) << ", ";
474     }
475     ss << "\n";
476   }
477   ss << "\n";
478   return ss.str();
479 }
480 
481 bool AliasDb::dumpToGraphvizFile(const char* filename) const {
482   std::ofstream dot_file(filename);
483   if (!dot_file.good()) {
484     std::cout << "Failed to create Graphviz file: '" << filename << "'\n";
485     return false;
486   }
487   dot_file << toGraphviz();
488   return true;
489 }
490 
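// For reference, the emitted Graphviz text has roughly this shape (element
// names and edge styles come from the code below; the actual contents depend
// on the graph being analyzed):
//
//   /* ...textual AliasDb::toString() dump... */
//   digraph alias_db {
//     rankdir=LR
//     node [shape=rect, color=gray];
//     edge [color=black];
//     "\%a" -> "\%b"                               // %a points to %b
//     "\%lst" -> "\%a" [style=dashed, color=blue]  // %lst contains %a
//   }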
491 std::string AliasDb::toGraphviz() const {
492   std::stringstream dot;
493 
494   // Local helper to generate a graphviz-friendly name encoding
495   // See also AliasDb::getElementName()
496   const auto name = [this](const Element* e) -> std::string {
497     if (e->values.empty()) {
498       for (const auto& ent : wildcardIndex_) {
499         if (ent.second == e) {
500           return std::string("\"WILDCARD for ") + ent.first->str() + "\"";
501         }
502       }
503       return "\"WILDCARD\"";
504     } else {
505       std::ostringstream ss;
506       if (e->values.size() == 1) {
507         ss << "\"\\%" << (*e->values.begin())->debugName() << "\"";
508         return ss.str();
509       }
510       ss << "\"(";
511       for (const Value* v : e->values) {
512         ss << "\\%" << v->debugName() << ", ";
513       }
514       ss << ")\"";
515       return ss.str();
516     }
517   };
518 
519   // Include the textual representation for reference
520   dot << "/*\n";
521   dot << toString();
522   dot << "*/\n";
523 
524   dot << "digraph alias_db {\n"
525       << "  rankdir=LR\n"
526       << "  node [shape=rect, color=gray];\n"
527       << "  edge [color=black];\n";
528 
529   for (const auto& ptrPair : elementMap_) {
530     const auto element = ptrPair.second;
531     if (!element->pointsTo.empty()) {
532       for (const auto pointedTo : element->pointsTo) {
533         dot << "  " << name(element) << " -> "
534             << name(memoryDAG_->fromIndex(pointedTo)) << "\n";
535       }
536     }
537     if (!element->containedElements.empty()) {
538       for (const auto contained : element->containedElements) {
539         dot << "  " << name(element) << " -> "
540             << name(memoryDAG_->fromIndex(contained))
541             << " [style=dashed, color=blue]\n";
542       }
543     }
544   }
545 
546   dot << "}\n";
547   return dot.str();
548 }
549 
550 void AliasDb::analyze(const std::shared_ptr<Graph>& graph) {
551   for (auto input : graph->inputs()) {
552     setWildcard(input);
553   }
554   analyze(graph->block());
555 }
556 
557 void AliasDb::analyze(Block* block) {
558   for (auto node : block->nodes()) {
559     analyze(node);
560   }
561 }
562 
563 void AliasDb::analyze(Node* node) {
564   analyzeImpl(node);
565 }
566 
567 // Returns true if analysis was run using
568 // the registered analyzer.
569 bool AliasDb::tryRegisteredAnalysis(Node* node) {
570   const Operator& op = node->getOperator();
571   auto analysis = op.aliasAnalysisKind();
572   if (AliasAnalysisKind::PURE_FUNCTION == analysis) {
573     analyzeCreator(node);
574     return true;
575   }
576   return false;
577 }
578 
579 // The basic strategy is:
580 //   1. Retrieve alias information for every input.
581 //   2. Use the node's schema's alias annotations to propagate alias/write
582 //      information to the outputs. For unschematized nodes, a special analyzer
583 //      will have to be handwritten.
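//
// For example, given an in-place schema such as
//   aten::add_(Tensor(a!) self, Tensor other, *, Scalar alpha=1) -> Tensor(a!)
// the `(a!)` annotation on `self` names an alias set `a` and marks the
// argument as written to; the matching `(a!)` on the return value makes the
// output point to `self`, and the node is registered as a writer of that
// memory.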
584 void AliasDb::analyzeImpl(Node* node) {
585   auto op = node->maybeOperator();
586   const bool hasSpecialCase = aliasAnalysisHasSpecialCaseFor(node->kind());
587   if (op) {
588     const auto analysis = op->aliasAnalysisKind();
589 
590     const bool registeredAsSpecialCase =
591         analysis == AliasAnalysisKind::INTERNAL_SPECIAL_CASE;
592     if (C10_UNLIKELY(registeredAsSpecialCase && !hasSpecialCase)) {
593       TORCH_INTERNAL_ASSERT(
594           false,
595           "Op ",
596           node->kind().toDisplayString(),
597           " is registered with AliasAnalysisKind::INTERNAL_SPECIAL_CASE but doesn't have a special case.");
598     } else if (C10_UNLIKELY(!registeredAsSpecialCase && hasSpecialCase)) {
599       TORCH_INTERNAL_ASSERT(
600           false,
601           "Op ",
602           node->kind().toDisplayString(),
603           " has a special case and should be registered with AliasAnalysisKind::INTERNAL_SPECIAL_CASE but is registered with ",
604           c10::toString(analysis));
605     }
606   } else {
607     if (!hasSpecialCase) {
608       std::ostringstream oss;
609       for (const auto input : node->inputs()) {
610         oss << input->type()->str() << ", ";
611       }
612       oss << "\n\nCandidates:";
613       const auto& candidates = getAllOperatorsFor(node->kind());
614       for (const auto& candidate : candidates) {
615         oss << "\n\t" << candidate->schema();
616       }
617       TORCH_INTERNAL_ASSERT(
618           0,
619           "We don't have an op for ",
620           node->kind().toDisplayString(),
621           " but it isn't a special case.  ",
622           "Argument types: ",
623           oss.str());
624     }
625   }
626 
627   // These nodes are not schematized, so we need to handle them specially
628   switch (node->kind()) {
629     case prim::If:
630       return analyzeIf(node);
631     case prim::Loop:
632       return analyzeLoop(node);
633     case prim::FusionGroup:
634     case prim::CudaFusionGroup:
635     case prim::oneDNNFusionGroup:
636     case prim::FunctionalGraph:
637     case prim::DifferentiableGraph:
638     case prim::FallbackGraph:
639       return analyzeSubgraph(node);
640     case prim::fork:
641       return analyzeFork(node);
642     case aten::wait:
643       return analyzeWait(node);
644     case prim::awaitable:
645     case prim::awaitable_nowait:
646       return analyzeAwaitable(node);
647     case prim::awaitable_wait:
648       return analyzeAwaitableWait(node);
649     case prim::rpc_async:
650     case prim::rpc_sync:
651     case prim::rpc_remote:
652       return analyzeRpcAsync(node);
653     case aten::batch_norm:
654       return analyzeBatchNorm(node);
655     case aten::instance_norm:
656       return analyzeInstanceNorm(node);
657     case prim::GradOf:
658       return analyzeGradOf(node);
659     case prim::BroadcastMKLDNNTensors: {
660       makePointerTo(node->outputs().at(0), node->inputs().at(0));
661       makePointerTo(node->outputs().at(1), node->inputs().at(1));
662       return;
663     }
664     // TODO: think more about TensorExpr alias correctness
665     case prim::TensorExprGroup:
666     case prim::TensorExprDynamicGroup:
667     case prim::MKLDNNGroup:
668     case prim::ConstantMKLDNNTensor:
669     case prim::StaticSubgraph:
670     case prim::Constant:
671     case prim::AutogradZero:
672     case prim::AutogradAdd:
673     case prim::FusedConcat:
674     case prim::MMTreeReduce:
675     case prim::MMBatchSide:
676     case prim::BroadcastSizes:
677     case prim::ChunkSizes:
678     // this should never be seen outside of initial compilation
679     // but because of some dependencies with closures invoking the alias
680     // db, it needs to be handled here
681     case prim::EmptyListLiteral:
682     case prim::Closure:
683     case prim::CreateObject:
684     case prim::tolist:
685     case prim::Uninitialized:
686       return analyzeCreator(node);
687     case prim::TupleConstruct:
688     case prim::DictConstruct:
689     case prim::ListConstruct:
690       return analyzeContainerConstruct(node);
691     case prim::TupleUnpack:
692     case prim::TupleIndex:
693     case prim::TupleSlice:
694     case prim::ListUnpack:
695     case prim::PythonOp:
696     case prim::GetAttr:
697       if (isFrozen_ && node->kind() == prim::GetAttr) {
698         auto& ty = node->input()->type();
699         if (ty->expectRef<ClassType>().is_module()) {
700           return analyzeCreator(node);
701         }
702       }
703       return analyzeExtractor(node);
704     case prim::unchecked_cast:
705       return makePointerTo(node->output(), node->input());
706     case prim::ConstantChunk:
707       return analyzeChunk(node);
708     case prim::BroadcastingChunk:
709       return analyzeBroadcastingChunk(node);
710     case prim::SetAttr:
711       return analyzeSetAttr(node);
712     case prim::profile_ivalue:
713     case prim::profile:
714       makePointerTo(node->output(), node->inputs().at(0));
715       return;
716     case prim::TypeCheck:
717     case prim::RequiresGradCheck: {
718       auto num_inputs = node->inputs().size();
719       for (const auto i : c10::irange(num_inputs)) {
720         makePointerTo(node->outputs().at(i), node->inputs().at(i));
721       }
722       return;
723     }
724     case prim::BailOut:
725       TORCH_INTERNAL_ASSERT(
726           node->inputs().at(0)->node()->kind() == prim::BailoutTemplate);
727       makePointerTo(node->output(), node->inputs().at(1));
728       return;
729     case prim::Guard:
730       makePointerTo(node->output(), node->inputs().at(0));
731       return;
732     case prim::CallFunction:
733     case prim::CallMethod: {
734       // TODO: this can be improved with summaries of what the function does
735       // for now we assume the worst
736       if (!descend_function_calls_) {
737         return analyzeConservative(node);
738       }
739       auto g = tryToGraphFunction(node);
740       if (!g) {
741         return analyzeConservative(node);
742       }
743       // this is an unoptimized path - we copy the subgraph for each function
744       // call past the first - so we do not generally enable the recursive
745       // analysis. use cases for fine-grained alias analysis without inlining
746       // are very uncommon
747       auto graph = g->optimized_graph();
748       // alias analysis will use Value* as mappings for information,
749       // so for each analysis of a particular function call we need a new graph.
750       // All copies made are stored for the duration of analysis so we do not
751       // run into lifetime issues with the graph
752       std::vector<std::shared_ptr<Graph>>& graphs =
753           function_call_copies_[graph.get()];
754       if (graphs.empty()) {
755         graphs.push_back(graph);
756         analyzeSubgraph(node, graph);
757       } else {
758         auto copied_graph = graph->copy();
759         graphs.push_back(copied_graph);
760         analyzeSubgraph(node, copied_graph);
761       }
762       return;
763     }
764     case prim::Enter:
765     case prim::Exit:
766       // TODO: this can be improved with summaries of what the function does
767       // for now we assume the worst
768       // NB: update safeToChangeAliasingRelationship if changed
769       return analyzeConservative(node);
770     case prim::Print:
771     case prim::isinstance:
772       // These ops do nothing
773       return;
774     default:
775       if (tryRegisteredAnalysis(node)) {
776         return;
777       }
778   }
779 
780   TORCH_INTERNAL_ASSERT(op, "We should have an op schema if we get to here");
781   const AliasAnalysisKind analysis = op->aliasAnalysisKind();
782   TORCH_INTERNAL_ASSERT(
783       analysis != AliasAnalysisKind::INTERNAL_SPECIAL_CASE &&
784           !aliasAnalysisHasSpecialCaseFor(node->kind()),
785       "Special cases should be handled already if we're here.");
786 
787   if (node->kind().is_aten() || node->kind().is_prim() ||
788       node->kind().is_cuda()) {
789     // TODO There is nothing in the system that relies on aten:: and prim::
790     // ops using AliasAnalysisKind::FROM_SCHEMA or
791     // AliasAnalysisKind::INTERNAL_SPECIAL_CASE, but this is the intended
792     // behavior for all current ops and a good error check. We can consider
793     // lifting this constraint later if we have a use case for it.
794     TORCH_INTERNAL_ASSERT(
795         analysis == AliasAnalysisKind::FROM_SCHEMA ||
796             analysis == AliasAnalysisKind::CONSERVATIVE,
797         "aten:: and prim:: operators should use AliasAnalysisKind::FROM_SCHEMA or "
798         "AliasAnalysisKind::CONSERVATIVE(if really necessary), but ",
799         node->kind().toDisplayString(),
800         " doesn't. Note: Ideally, prim:: operators actually shouldn't have a schema ",
801         "and then use AliasAnalysisKind::INTERNAL_SPECIAL_CASE instead.");
802   }
803 
804   if (analysis == AliasAnalysisKind::CONSERVATIVE) {
805     // TODO A previous implementation of alias analysis always accessed
806     // node->schema, which caused the schema caches in the Node class to be
807     // filled for the full graph. Unfortunately, our JIT passes started relying
808     // on that, so we need to keep doing this. Details: in
809     // caffe2/torch/onnx/utils.py, _jit_pass_onnx is called on an invalid JIT
810     // graph because we called _jit_pass_erase_number_types right before and
811     // ints are now Tensors instead. So if _jit_pass_onnx tries to look up
812     // operator schemas, it will crash. However, _jit_pass_constant_propagation,
813     // which is called before it, runs alias analysis and prefills the schema
814     // cache in all the Node instances so that _jit_pass_onnx doesn't look up
815     // operators to get the schemas anymore. We should fix this.
816     node->schema(); // fill the schema cache in the Node class
817 
818     return analyzeConservative(node);
819   }
820 
821   TORCH_INTERNAL_ASSERT(
822       analysis == AliasAnalysisKind::FROM_SCHEMA,
823       "AliasAnalysisKind::CONSERVATIVE/PURE_FUNCTION/INTERNAL_SPECIAL_CASE should already have been handled above");
824   const auto& schema = node->schema();
825 
826   // Bind the schema's "formal" alias annotation to the actual values those
827   // schema arguments represent
828   std::unordered_map<Symbol, Value*> formalToActual;
829   for (const auto i : c10::irange(schema.arguments().size())) {
830     const at::AliasInfo* formal = schema.arguments()[i].alias_info();
831     const auto& actualValue = node->inputs().at(i);
832 
833     // Skip if there's no alias annotation
834     if (!formal) {
835       continue;
836     }
837 
838     // If this type cannot alias, continue. Can occur with a VarType schema
839     if (!isMutableTypeInternal(actualValue)) {
840       continue;
841     }
842 
843     // Do sanity checks on the alias annotation
844     TORCH_INTERNAL_ASSERT(
845         formal->containedTypes().size() <= 1,
846         "Composite types for alias analysis not yet supported");
847     TORCH_INTERNAL_ASSERT(
848         !formal->isWildcardBefore(),
849         "Doesn't make sense for an input value to begin as a wildcard");
850     // This is a special case where we have alias info before [] but not after,
851     // such as `Tensor(a!)[]`
852     if (formal->containedTypes().size() == 1 && formal->beforeSets().empty()) {
853       // Use the first containedType in alias info.
854       formal = &(formal->containedTypes()[0]);
855     }
856 
857     const auto& formalAlias = formal->beforeSet();
858 
859     // skip if we've already bound this alias
860     if (formalToActual.count(formalAlias) != 0) {
861       continue;
862     }
863 
864     // Bind the formal to the actual
865     formalToActual[formalAlias] = actualValue;
866 
867     // Record writes
868     if (formal->isWrite()) {
869       registerWrite(actualValue, node);
870     }
871 
872     // Now deal with sets after the '->'
873     if (formal->isWildcardAfter()) {
874       TORCH_INTERNAL_ASSERT(
875           formal->afterSets().size() == 1,
876           "If the after set contains a wildcard, "
877           "there should be no other alias sets specified.");
878       setWildcard(actualValue);
879     } else {
880       // We don't understand anything else in the after yet, so assert there's
881       // been no change.
882       TORCH_INTERNAL_ASSERT(formal->beforeSets() == formal->afterSets());
883     }
884   }
885 
886   // Use the formal-actual mapping to give aliases to the outputs
887   for (const auto i : c10::irange(schema.returns().size())) {
888     const auto actual = node->outputs().at(i);
889     const at::AliasInfo* formal = schema.returns()[i].alias_info();
890     if (!formal) {
891       // This is a fresh tensor
892       giveFreshAlias(actual);
893       continue;
894     }
895 
896     // If this type cannot alias, continue. Can occur with a VarType schema
897     if (!isMutableType(actual)) {
898       continue;
899     }
900 
901     TORCH_INTERNAL_ASSERT(
902         formal->containedTypes().size() <= 1,
903         "Composite types for alias analysis not yet supported");
904     TORCH_INTERNAL_ASSERT(formal->beforeSets() == formal->afterSets());
905     if (formal->containedTypes().size() == 1 && formal->beforeSets().empty()) {
906       // Use the first containedType in alias info.
907       formal = &(formal->containedTypes()[0]);
908     }
909     if (formal->isWildcardBefore()) {
910       TORCH_INTERNAL_ASSERT(
911           formal->beforeSets().size() == 1,
912           "If an output is a wildcard, "
913           "there should be no other alias sets specified.");
914       setWildcard(actual);
915       continue;
916     }
917 
918     bool inputs_has_alias = false;
919     for (const auto& formalAlias : formal->beforeSets()) {
920       if (formalToActual.count(formalAlias)) {
921         inputs_has_alias = true;
922         auto toAlias = formalToActual.at(formalAlias);
923         makePointerTo(actual, toAlias);
924       }
925     }
926     // If none of the alias annotations we encountered were bound to inputs,
927     //   e.g. foo(Tensor(a) self) -> Tensor(b)
928     //   or foo(Tensor(a) self) -> Tensor(b|c),
929     // give the output a fresh alias. Otherwise the annotation has the form
930     // a|fresh, which we can ignore, taking the conservative assumption that
931     // the output must alias `a`, e.g. aten::cuda(Tensor(a) self) -> Tensor(a|fresh)
932     if (!inputs_has_alias && !formal->beforeSets().empty()) {
933       giveFreshAlias(actual);
934     }
935 
936     // Record writes
937     if (formal->isWrite()) {
938       registerWrite(actual, node);
939     }
940   }
941 }
942 
943 // Register the fact that `n` writes to `v`.
944 void AliasDb::registerWrite(const Value* v, Node* n, bool writeToContained) {
945   if (!isMutableTypeInternal(v)) {
946     // don't need to register a write if the value isn't mutable
947     return;
948   }
949   if (writeToContained) {
950     writeRegistry_->registerWriteToAllContained(v, n);
951   } else {
952     writeRegistry_->registerWrite(v, n);
953   }
954 }
955 
956 void AliasDb::analyzeIf(Node* node) {
957   // For if statements, the alias set of an output is the union of the
958   // alias sets generated by the if and else block
959   const auto trueBlock = node->blocks().at(0);
960   const auto falseBlock = node->blocks().at(1);
961   analyze(trueBlock);
962   analyze(falseBlock);
963 
964   for (const auto i : c10::irange(node->outputs().size())) {
965     const auto nodeOutput = node->outputs()[i];
966 
967     const auto trueOutput = trueBlock->outputs().at(i);
968     const auto falseOutput = falseBlock->outputs().at(i);
969 
970     makePointerTo(nodeOutput, trueOutput);
971     makePointerTo(nodeOutput, falseOutput);
972   }
973 }
974 
975 void AliasDb::analyzeLoop(Node* node) {
976   const auto bodyBlock = node->blocks().at(0);
977   const auto loopCarriedInputs = node->inputs().slice(2); // skip max, cond
978   const auto blockInputs = bodyBlock->inputs().slice(1); // skip trip
979   const auto blockOutputs = bodyBlock->outputs().slice(1); // skip trip
980   TORCH_INTERNAL_ASSERT(loopCarriedInputs.size() == blockInputs.size());
981   TORCH_INTERNAL_ASSERT(blockOutputs.size() == node->outputs().size());
982 
983   // Run alias analysis on the loop body, iterating until the block output
984   // alias info converges. Copy node input aliases to block input
985   mapAliases(blockInputs, loopCarriedInputs);
986 
987   // Populate block output alias info by analyzing the body
988   analyze(bodyBlock);
989 
990   // Copy the alias info from the block output to the node output
991   mapAliases(node->outputs(), blockOutputs);
992 }
993 
994 void AliasDb::analyzeGradOf(Node* node) {
995   const auto grad_of_block = node->blocks().at(0);
996   analyze(grad_of_block);
997   mapAliases(node->outputs(), grad_of_block->outputs());
998 }
999 
1000 void AliasDb::analyzeSubgraph(
1001     Node* node,
1002     const std::shared_ptr<Graph>& subgraph) {
1003   const auto subgraphBlock = subgraph->block();
1004   // CallFunction nodes have an extra first parameter
1005   if (node->kind() == prim::CallFunction) {
1006     mapAliases(subgraphBlock->inputs(), node->inputs().slice(1));
1007   } else {
1008     mapAliases(subgraphBlock->inputs(), node->inputs());
1009   }
1010 
1011   analyze(subgraphBlock);
1012 
1013   // Note: the subgraph outputs and node outputs are NOT NECESSARILY the
1014   // same length. Autodifferentiation may capture additional outputs in the
1015   // subgraph block.
1016   TORCH_INTERNAL_ASSERT(
1017       subgraphBlock->outputs().size() >= node->outputs().size());
1018   for (size_t i = 0; i < node->outputs().size(); i++) {
1019     makePointerTo(node->outputs()[i], subgraphBlock->outputs()[i]);
1020   }
1021 }
1022 
1023 void AliasDb::analyzeSubgraph(Node* node) {
1024   const auto subgraph = node->g(attr::Subgraph);
1025   return analyzeSubgraph(node, subgraph);
1026 }
1027 // For nodes that generate a fresh value from nothing
1028 void AliasDb::analyzeCreator(Node* node) {
1029   for (Value* output : node->outputs()) {
1030     giveFreshAlias(output);
1031   }
1032 }
1033 
1034 // For nodes that extract values from a composite type. Right now, this just
1035 // gives up and creates wildcards for everything.
1036 void AliasDb::analyzeExtractor(Node* node) {
1037   for (const auto output : node->outputs()) {
1038     setWildcard(output);
1039   }
1040 }
1041 
1042 // For torch.chunk(), all returned tensors may alias the input tensor
1043 void AliasDb::analyzeChunk(Node* node) {
1044   for (auto output : node->outputs()) {
1045     makePointerTo(output, node->input());
1046   }
1047 }
1048 
1049 void AliasDb::analyzeFork(Node* node) {
1050   for (const auto input : node->inputs()) {
1051     setWildcard(input);
1052   }
1053 
1054   // Give the future that the fork emits a fresh value
1055   for (const auto output : node->outputs()) {
1056     giveFreshAlias(output);
1057   }
1058 }
1059 
1060 void AliasDb::analyzeWait(Node* node) {
1061   TORCH_INTERNAL_ASSERT(node->kind() == aten::wait);
1062   for (const auto output : node->outputs()) {
1063     setWildcard(output);
1064   }
1065   // the forked subgraph that `wait` is waiting on may write to any of its
1066   // inputs. We don't have a reliable way of recovering the fork inputs, so
1067   // for safety we just register a write to every wildcard.
1068   writeRegistry_->registerWriteToAllWildcards(node);
1069 }
1070 
1071 void AliasDb::analyzeAwaitable(Node* node) {
1072   for (const auto input : node->inputs()) {
1073     setWildcard(input);
1074   }
1075 
1076   for (const auto output : node->outputs()) {
1077     giveFreshAlias(output);
1078   }
1079 }
1080 
1081 void AliasDb::analyzeAwaitableWait(Node* node) {
1082   TORCH_INTERNAL_ASSERT(node->kind() == prim::awaitable_wait);
1083   for (const auto output : node->outputs()) {
1084     setWildcard(output);
1085   }
1086   // the awaitable subgraph that `wait` is waiting on may write to any of its
1087   // inputs. We don't have a reliable way of recovering the awaitable inputs, so
1088   // for safety we just register a write to every wildcard.
1089   writeRegistry_->registerWriteToAllWildcards(node);
1090 }
1091 
1092 void AliasDb::analyzeRpcAsync(Node* node) {
1093   for (const auto input : node->inputs()) {
1094     setWildcard(input);
1095   }
1096 
1097   // Give the future that the rpc_async emits a fresh value
1098   for (const auto output : node->outputs()) {
1099     giveFreshAlias(output);
1100   }
1101 }
1102 
1103 namespace {
1104 std::optional<bool> getConstantBooleanInput(
1105     Node* node,
1106     const std::string& inputName) {
1107   TORCH_INTERNAL_ASSERT(
1108       node->hasNamedInput(inputName), inputName + " input is expected");
1109   auto value = node->namedInput(inputName);
1110   TORCH_INTERNAL_ASSERT(
1111       value->type() == BoolType::get(),
1112       inputName + " input is expected to be a bool");
1113   return constant_as<bool>(value);
1114 }
1115 } // namespace
1116 
1117 // custom behavior for batch_norm because (a!)? annotations currently
1118 // aren't supported, and because behavior differs depending on the value of
1119 // training
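//
// In rough terms, for
//   aten::batch_norm(input, weight, bias, running_mean, running_var,
//                    training, momentum, eps, cudnn_enabled)
// every output gets a fresh alias; unless the graph is frozen, `running_mean`
// and `running_var` are additionally registered as written to whenever
// `training` is true or is not a compile-time constant.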
1120 void AliasDb::analyzeBatchNorm(Node* node) {
1121   // We invoke freezing for inference, so we assume training will be folded to
1122   // a constant false; this avoids needing to invoke freezing multiple times in
1123   // order to make batch norm weights constant
1124   for (Value* output : node->outputs()) {
1125     giveFreshAlias(output);
1126   }
1127 
1128   if (isFrozen_) {
1129     return;
1130   }
1131 
1132   auto isTraining = getConstantBooleanInput(node, "training");
1133 
1134   if (!isTraining.has_value() || *isTraining) {
1135     TORCH_INTERNAL_ASSERT(
1136         node->hasNamedInput("running_mean"), "running_mean input is expected");
1137     auto runningMean = node->namedInput("running_mean");
1138     TORCH_INTERNAL_ASSERT(
1139         node->hasNamedInput("running_var"), "running_var input is expected");
1140     auto runningVar = node->namedInput("running_var");
1141 
1142     registerWrite(runningMean, node);
1143     registerWrite(runningVar, node);
1144   }
1145 }
1146 
1147 // custom behavior for instance_norm, because (a!)? annotations currently
1148 // aren't supported, and because behavior differs depending on the value of
1149 // use_input_stats
1150 void AliasDb::analyzeInstanceNorm(Node* node) {
1151   for (Value* output : node->outputs()) {
1152     giveFreshAlias(output);
1153   }
1154 
1155   auto useInputStats = getConstantBooleanInput(node, "use_input_stats");
1156 
1157   if (!useInputStats.has_value() || *useInputStats) {
1158     TORCH_INTERNAL_ASSERT(
1159         node->hasNamedInput("running_mean"), "running_mean input is expected");
1160     auto runningMean = node->namedInput("running_mean");
1161     TORCH_INTERNAL_ASSERT(
1162         node->hasNamedInput("running_var"), "running_var input is expected");
1163     auto runningVar = node->namedInput("running_var");
1164 
1165     registerWrite(runningMean, node);
1166     registerWrite(runningVar, node);
1167   }
1168 }
1169 
1170 // SetAttr: writes to the `self` field
1171 void AliasDb::analyzeSetAttr(Node* node) {
1172   const auto self = node->inputs().at(0);
1173   TORCH_INTERNAL_ASSERT(self->type()->kind() == TypeKind::ClassType);
1174   registerWrite(self, node);
1175   // Also the value being set must become a wildcard.
1176   const auto newValue = node->inputs().at(1);
1177   setWildcard(newValue);
1178 }
1179 
1180 // Used for anything where we do not have accurate alias summaries
1181 // may write to any input and produce wildcards
1182 void AliasDb::analyzeConservative(Node* node) {
1183   for (const auto input : node->inputs()) {
1184     if (!isMutableTypeInternal(input)) {
1185       continue;
1186     }
1187     registerWrite(input, node, /*writeToContained=*/true);
1188     setWildcard(input);
1189   }
1190 
1191   for (const auto output : node->outputs()) {
1192     setWildcard(output);
1193   }
1194 }
1195 
1196 bool AliasDb::functionalNonEscapingListUse(const Use& use) const {
1197   Node* n = use.user;
1198   size_t offset = use.offset;
1199   Value* container = n->inputs().at(offset);
1200 
1201   // only consider aten op uses of lists
1202   if (!container->type()->cast<ListType>()) {
1203     return false;
1204   }
1205 
1206   /*
1207   in the general case, we consider any Value that enters another container as
1208   entering the heap, and thus aliasing all other heap values of the same type.
1209   the advantages of this approach are:
1210   - there are many composite list/container ops that would be tricky to
1211   schematize if we did something more complicated
1212   - limits the size of the AliasDb, because a container of size 10 only contains
1213   1 memory dag element instead of 10
1214   - we do not need to worry about adding contained elements to the wildcard set
1215   when a container escapes the graph.
1216   The downside of this approach is we are unable to handle the common case of a
1217   list constructed and passed into an aten op. Here, optimize for a set of
1218   common ops where the output does not alias the list or the list elements
1219   */
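  // For example, in
  //
  //   %lst : Tensor[] = prim::ListConstruct(%a, %b)
  //   %out : Tensor = aten::cat(%lst, %dim)
  //
  // `%lst` has a single use, and `aten::cat` is one of the ops recognized
  // below as neither aliasing nor leaking the list's elements, so
  // analyzeContainerConstruct records `%a` and `%b` as contained in `%lst`
  // instead of adding them to the wildcard set.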
1220 
1221   // only used in output of graph - no further uses,
1222   // so there will be no use of it where the contained element leaks
1223   if (use.user->kind() == prim::Return) {
1224     return use.user->owningBlock() == graph_->block();
1225   }
1226 
1227   switch (use.user->kind()) {
1228     case aten::cat:
1229     case aten::broadcast_tensors:
1230     case aten::stack:
1231     case aten::vstack:
1232     case aten::hstack:
1233     case aten::dstack:
1234       return true;
1235   }
1236   auto op = use.user->maybeOperator();
1237   if (op && op->aliasAnalysisKind() == AliasAnalysisKind::PURE_FUNCTION) {
1238     return true;
1239   }
1240   return false;
1241 }
1242 
1243 bool AliasDb::functionalNonEscapingTupleUse(const Use& use) const {
1244   Node* n = use.user;
1245   size_t offset = use.offset;
1246   Value* container = n->inputs().at(offset);
1247   if (!container->type()->cast<TupleType>()) {
1248     return false;
1249   }
1250   // TODO(T97387453): Cover more ops that do not let tuples' elements escape.
1251   bool in_return_outputs = use.user->kind() == prim::Return;
1252   bool not_in_nested_subgraph = use.user->owningBlock() == graph_->block();
1253   return in_return_outputs && not_in_nested_subgraph;
1254 }
1255 
1256 // List or dict or tuple construct: create an aliasing element for the actual
1257 // container, then mark all inputs as wildcards, since they've gone inside the
1258 // container. Then, add the wildcard sets of appropriate type to the contained
1259 // elements of the container.
1260 void AliasDb::analyzeContainerConstruct(Node* node) {
1261   TORCH_INTERNAL_ASSERT(
1262       node->kind() == prim::ListConstruct ||
1263       node->kind() == prim::DictConstruct ||
1264       node->kind() == prim::TupleConstruct);
1265 
1266   // tuples which contain immutable types are immutable
1267   if (!isMutableTypeInternal(node->output())) {
1268     return;
1269   }
1270 
1271   TORCH_INTERNAL_ASSERT(node->outputs().size() == 1);
1272   auto container = node->output();
1273 
1274   // optimization:
1275   // if a list is only used once in an aten op, and the op output
1276   // doesn't alias the input, then we can add all inputs to the list's
1277   // contained elements instead of the wildcard set.
1278   if (container->uses().size() == 1 &&
1279       (functionalNonEscapingListUse(container->uses().at(0)) ||
1280        functionalNonEscapingTupleUse(container->uses().at(0)))) {
1281     giveFreshAlias(container, false);
1282     for (Value* v : node->inputs()) {
1283       addToContainedElements(v, container);
1284     }
1285     return;
1286   }
1287 
1288   giveFreshAlias(container);
1289   auto container_elem = elementMap_.at(container);
1290   for (auto input : node->inputs()) {
1291     auto maybe_wildcard_elem = setWildcard(input);
1292     if (maybe_wildcard_elem) {
1293       memoryDAGBuilder_->addToContainedElements(
1294           *maybe_wildcard_elem, container_elem);
1295     }
1296   }
1297 }
1298 
1299 // BroadcastingChunk: all inputs are broadcasted, and then individually chunked.
1300 // This is an intermediate node used only in the graph fuser.
1301 void AliasDb::analyzeBroadcastingChunk(Node* node) {
1302   auto inputs = node->inputs();
1303   auto outputs = node->outputs();
1304   auto nchunks = node->i(attr::chunks);
1305   for (const auto index : c10::irange(inputs.size())) {
1306     // Each inputs[i] is aliased by exactly `nchunks` distinct output tensors:
1307     // inputs[i] produces chunks outputs[i * nchunks + k] for k in [0..nchunks)
1308     auto output_begin = outputs.begin() + index * nchunks;
1309     for (auto it = output_begin; it != output_begin + nchunks; ++it) {
1310       makePointerTo(*it, inputs.at(index));
1311     }
1312   }
1313 }
1314 
1315 bool AliasDb::nonAliasingValue(const Value* elem) const {
1316   // these are values which can point to aliasing types in the graph,
1317   // as with a None value pointing to an Optional-typed if node output,
1318   // but will never alias themselves
1319   return elem->mustBeNone() || elem->node()->kind() == prim::Uninitialized;
1320 }
1321 
1322 // Register the fact that `from` is a pointer to `to`
1323 void AliasDb::makePointerTo(const Value* from, const Value* to) {
1324   if (nonAliasingValue(from) || nonAliasingValue(to)) {
1325     // if either value is guaranteed to be non-aliasing, we do not need to
1326     // connect the two elements. however, it is invariant that aliasing types
1327     // that are not wildcards have a memory dag element, so we create one if
1328     // needed
1329     giveFreshAlias(from);
1330     giveFreshAlias(to);
1331     return;
1332   }
1333 
1334   // The contained types of immutable type containers (`Optional`,
1335   // `Tuple`, `Future`, and `Union`) are unified, so these types can be
1336   // mutable or immutable and point to a type which is mutable or
1337   // immutable. `Any` is mutable but can point to an immutable type
1338   // through refinement
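  // E.g. `prim::unchecked_cast` may refine `Optional[Tensor]` to `Tensor`
  // (both map to the same mutable "Tensor" type, so the two values are
  // connected below), or `Any` to `int` (the mutabilities differ, so we
  // simply return here).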
1339   if (isMutableTypeInternal(from) != isMutableTypeInternal(to)) {
1340     return;
1341   }
1342   // both immutable
1343   if (!isMutableTypeInternal(from)) {
1344     return;
1345   }
1346   if (from == to) {
1347     return;
1348   }
1349 
1350   // At this point, we are dealing with two mutable types
1351   auto from_el = getOrCreateElement(from);
1352   auto to_el = getOrCreateElement(to);
1353 
1354   memoryDAGBuilder_->makePointerTo(from_el, to_el);
1355 }
1356 
1357 void AliasDb::addToContainedElements(
1358     const Value* inner,
1359     const Value* container) {
1360   if (!isMutableTypeInternal(inner)) {
1361     return;
1362   }
1363 
1364   auto inner_el = getOrCreateElement(inner);
1365   auto cont_el = getOrCreateElement(container);
1366 
1367   memoryDAGBuilder_->addToContainedElements(inner_el, cont_el);
1368 }
1369 
1370 bool AliasDb::mayAlias(const Value* a, const Value* b) const {
1371   if (!isMutableTypeInternal(a) || !isMutableTypeInternal(b)) {
1372     return false;
1373   }
1374 
1375   return memoryDAG_->mayAlias(elementMap_.at(a), elementMap_.at(b));
1376 }
1377 
1378 bool AliasDb::mayAlias(const ValueSet& a, const ValueSet& b) const {
1379   if (a.empty() || b.empty()) {
1380     return false;
1381   }
1382 
1383   // Record all memory locations from group `a`
1384   MemoryLocations aMemLocs;
1385   for (const auto value : a) {
1386     auto it = elementMap_.find(value);
1387     if (it != elementMap_.end()) {
1388       aMemLocs |= memoryDAG_->getMemoryLocations(it->second);
1389     }
1390   }
1391 
1392   // If any of group `b`s memory locations overlap, return true.
1393   for (const auto value : b) {
1394     auto it = elementMap_.find(value);
1395     if (it != elementMap_.end()) {
1396       if (aMemLocs.intersects(memoryDAG_->getMemoryLocations(it->second))) {
1397         return true;
1398       }
1399     }
1400   }
1401   // No overlap, so group `a` and `b` do not share a memory location
1402   return false;
1403 }
1404 
1405 bool AliasDb::mayContainAlias(Value* a, Value* b) const {
1406   if (!isMutableTypeInternal(a) || !isMutableTypeInternal(b)) {
1407     return false;
1408   }
1409   return memoryDAG_->mayContainAlias(elementMap_.at(a), elementMap_.at(b));
1410 }
1411 
1412 std::vector<Element*> AliasDb::getElements(at::ArrayRef<Value*> vs) const {
1413   std::vector<Element*> elements;
1414   for (const auto& val : vs) {
1415     if (isMutableTypeInternal(val)) {
1416       elements.push_back(elementMap_.at(val));
1417     }
1418   }
1419   return elements;
1420 }
1421 
1422 bool AliasDb::mayContainAlias(
1423     const at::ArrayRef<Value*> a,
1424     const at::ArrayRef<Value*> b) const {
1425   auto a_elems = getElements(a);
1426   return a_elems.empty() ? false
1427                          : memoryDAG_->mayContainAlias(a_elems, getElements(b));
1428 }
1429 
1430 bool AliasDb::mayContainAlias(Value* a, const at::ArrayRef<Value*> b) const {
1431   if (!isMutableTypeInternal(a)) {
1432     return false;
1433   }
1434   auto b_elems = getElements(b);
1435   return b_elems.empty()
1436       ? false
1437       : memoryDAG_->mayContainAlias(elementMap_.at(a), b_elems);
1438 }
1439 
1440 // Make each value in the `from` list point to its partner in the `to` list
1441 void AliasDb::mapAliases(at::ArrayRef<Value*> from, at::ArrayRef<Value*> to) {
1442   TORCH_INTERNAL_ASSERT(to.size() == from.size());
1443   for (const auto i : c10::irange(to.size())) {
1444     makePointerTo(from[i], to[i]);
1445   }
1446 }
1447 
1448 // Should only be called from create_functional_graphs.
1449 // The asserts are to guard against unintentional use.
1450 // FIXME refactor aliasdb construction to be more robust to mutation so this
1451 // hack isn't necessary.
1452 void AliasDb::createValue(const Value* value) {
1453   TORCH_INTERNAL_ASSERT(isMutableTypeInternal(value->type()));
1454   auto new_elem = memoryDAG_->unsafeMakeFreshValue(value);
1455   elementMap_[value] = new_elem;
1456 }
1457 
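// Create a brand-new Element (a fresh memory location) for `value`,
// provided it has a mutable type and no Element yet. When
// `add_wildcard_to_contained_elems` is true, the fresh element is also
// wired to the wildcard elements of its contained types (or, for a union
// type, pointed at the wildcard of each member type) so that anything
// stored inside the container is handled conservatively.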
1458 void AliasDb::giveFreshAlias(
1459     const Value* value,
1460     bool add_wildcard_to_contained_elems) {
1461   auto maybe_mut_types = mapTypeToAliasTypeSetPtr(value->type());
1462   if (!maybe_mut_types) {
1463     return;
1464   }
1465 
1466   if (elementMap_.count(value)) {
1467     // Inside a loop, we may have given a fresh alias to this value already, so
1468     // skip
1469     return;
1470   }
1471 
1472   auto new_elem = memoryDAGBuilder_->makeFreshValue(value);
1473   elementMap_[value] = new_elem;
1474   if (add_wildcard_to_contained_elems) {
1475     if (maybe_mut_types->size() > 1) {
1476       pointUnionTypeElementToAllContainedTypes(new_elem, *maybe_mut_types);
1477     } else {
1478       addContainedTypesToFreshElement(new_elem, *maybe_mut_types);
1479     }
1480   }
1481 }
1482 
1483 Element* AliasDb::getOrCreateElement(const Value* value) {
1484   if (!elementMap_.count(value)) {
1485     giveFreshAlias(value);
1486   }
1487   return elementMap_.at(value);
1488 }
1489 
1490 void AliasDb::replaceWithNewValue(Value* existing, Value* new_value) {
1491   TORCH_INTERNAL_ASSERT(
1492       *unshapedType(existing->type()) == *unshapedType(new_value->type()),
1493       "Types must be strictly equal if you are replacing aliasing information. ",
1494       "Got existing: '",
1495       existing->type()->repr_str(),
1496       "', new_value: '",
1497       new_value->type()->repr_str(),
1498       "'");
1499   if (!isMutableTypeInternal(existing)) {
1500     return;
1501   }
1502   auto existing_elem = elementMap_.at(existing);
1503   elementMap_[new_value] = existing_elem;
1504   elementMap_.erase(existing);
1505   existing_elem->values = {new_value};
1506 }
1507 
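// Make `to` share `from`'s Element, so both values map to the same memory
// locations. Contrast with `replaceWithNewValue` above, which transfers
// the Element from `existing` to `new_value` and erases the old mapping.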
1508 void AliasDb::copyValue(Value* from, Value* to) {
1509   TORCH_INTERNAL_ASSERT(
1510       *unshapedType(from->type()) == *unshapedType(to->type()),
1511       "Types must be strictly equal if you are copying aliasing information. ",
1512       "Got from: '",
1513       from->type()->repr_str(),
1514       "', to: '",
1515       to->type()->repr_str(),
1516       "'");
1517   if (!isMutableTypeInternal(to)) {
1518     return;
1519   }
1520   auto origElem = elementMap_.at(from);
1521   elementMap_[to] = origElem;
1522   origElem->values.insert(to);
1523 }
1524 
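// Entry points for topologically-safe node motion. The `couldMove*`
// variants only test whether a move is possible; the
// `move*TopologicallyValid` variants perform it when safe. A minimal
// sketch of how a pass might use them (hypothetical variables, not
// defined in this file):
//
//   AliasDb aliasDb(graph);
//   if (aliasDb.moveAfterTopologicallyValid(consumer, producer)) {
//     // `consumer` (plus any nodes that had to move with it) now sits
//     // after `producer`, with data and mutability dependencies intact.
//   }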
1525 bool AliasDb::moveAfterTopologicallyValid(Node* n, Node* movePoint) {
1526   return tryMove(n, movePoint, MoveSide::AFTER, /*dryRun=*/false);
1527 }
1528 
1529 bool AliasDb::couldMoveAfterTopologically(Node* n, Node* movePoint) {
1530   return tryMove(n, movePoint, MoveSide::AFTER, /*dryRun=*/true);
1531 }
1532 
1533 bool AliasDb::moveBeforeTopologicallyValid(Node* n, Node* movePoint) {
1534   // We have to distinguish the move side (instead of just moving after
1535   // n->prev()). Consider the following example:
1536   // If the dependency graph looks like
1537   //   n -> movePoint -> o
1538   // then moveBefore(o) will end up with
1539   //   n, o, movePoint
1540   // but moveAfter(n) will return false.
1541   return tryMove(n, movePoint, MoveSide::BEFORE, /*dryRun=*/false);
1542 }
1543 
1544 bool AliasDb::couldMoveBeforeTopologically(Node* n, Node* movePoint) {
1545   return tryMove(n, movePoint, MoveSide::BEFORE, /*dryRun=*/true);
1546 }
1547 
1548 bool AliasDb::hasWriters(const at::ArrayRef<Value*>& values) const {
1549   return std::any_of(values.begin(), values.end(), [&](Value* value) {
1550     return hasWriters(value);
1551   });
1552 }
1553 
1554 bool AliasDb::escapesScope(const at::ArrayRef<Value*>& vs) const {
1555   return mayContainAlias(graph_->inputs(), vs) ||
1556       mayContainAlias(graph_->outputs(), vs) || mayAliasWildcard(vs);
1557 }
1558 
1559 // Correctness conditions:
1560 // no values in either set may have writers, and it must not be the case that
1561 // both sets escape the current graph scope. Values escape the current scope
1562 // by aliasing a graph output or input, or by aliasing the wildcard set.
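// A hedged sketch of intended use (hypothetical pass code): before
// rewriting two value groups so that they may share storage, a pass
// could check
//
//   if (aliasDb.safeToChangeAliasingRelationship(n1->outputs(),
//                                                n2->outputs())) {
//     // neither group has writers, and the two groups do not both
//     // escape the graph scope, so new aliasing between them is safe
//   }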
1563 bool AliasDb::safeToChangeAliasingRelationship(
1564     const at::ArrayRef<Value*>& a,
1565     const at::ArrayRef<Value*>& b) const {
1566   if (hasWriters(a) || hasWriters(b)) {
1567     return false;
1568   }
1569 
1570   return !(escapesScope(a) && escapesScope(b));
1571 }
1572 
1573 // Helper for topologically-safe node moves. See `tryMove()` for details.
1574 class AliasDb::WorkingSet {
1575  public:
1576   explicit WorkingSet(Node* mover, const AliasDb& aliasDb)
1577       : aliasDb_(aliasDb), mover_(mover) {
1578     for (const auto user : getUsersSameBlock(mover_)) {
1579       moverUsers_.insert(user);
1580     }
1581     moverWrites_ |= aliasDb_.getWrites(mover_);
1582     moverReads_ |= aliasDb_.getReads(mover_);
1583   }
1584 
1585   // Add `n` to the working set
1586   void add(Node* n) {
1587     nodes_.push_back(n);
1588     node_to_index_[n] = static_cast<int64_t>(nodes_.size()) - 1;
1589     for (const auto user : getUsersSameBlock(n)) {
1590       users_.insert(user);
1591     }
1592 
1593     writes_ |= aliasDb_.getWrites(n);
1594     reads_ |= aliasDb_.getReads(n);
1595   }
1596 
1597   void eraseMover() {
1598     mover_ = nullptr;
1599     moverWrites_.clear();
1600     moverReads_.clear();
1601     moverUsers_.clear();
1602   }
1603 
1604   const std::vector<Node*>& dependentNodes() {
1605     return nodes_;
1606   }
1607 
1608   // Does the working set depend on `n`?
1609   bool dependsOn(Node* n) const {
1610     if (!mover_ && nodes_.empty()) {
1611       return false;
1612     }
1613 
1614     return hasDataDependency(n) || hasMutabilityDependency(n);
1615   }
1616 
1617  private:
1618   bool hasDataDependency(Node* n) const {
1619     if (!mover_ && nodes_.empty()) {
1620       return false;
1621     }
1622     const Node* pivot = mover_ ? mover_ : nodes_.front();
1623     if (n->isAfter(pivot)) {
1624       return producesFor(n);
1625     } else {
1626       return consumesFrom(n);
1627     }
1628   }
1629 
1630   bool hasMutabilityDependency(Node* n) const {
1631     // Check that `n` does not write to anything used by the working set
1632     const auto& nWrites = aliasDb_.getWrites(n);
1633     if (reads_.intersects(nWrites)) {
1634       return true;
1635     }
1636     if (mover_ && moverReads_.intersects(nWrites)) {
1637       return true;
1638     }
1639 
1640     // Check that the working set doesn't write to anything that `n` uses.
1641     const auto& nReads = aliasDb_.getReads(n);
1642     if (writes_.intersects(nReads)) {
1643       return true;
1644     }
1645     if (mover_ && moverWrites_.intersects(nReads)) {
1646       return true;
1647     }
1648     return false;
1649   }
1650 
1651   // Does the working set produce any values consumed by `n`?
1652   bool producesFor(Node* n) const {
1653     // This is equivalent to asking: does the total use-set of all the nodes in
1654     // working set include `n`?
1655     if (mover_ && moverUsers_.count(n)) {
1656       return true;
1657     }
1658     return users_.count(n) != 0;
1659   }
1660 
1661   // Does the working set consume any values produced by `n`?
1662   bool consumesFrom(Node* n) const {
1663     const auto users = getUsersSameBlock(n);
1664 
1665     if (mover_ && users.count(mover_)) {
1666       return true;
1667     }
1668     return std::any_of(users.begin(), users.end(), [&](Node* user) {
1669       return node_to_index_.find(user) != node_to_index_.end();
1670     });
1671   }
1672 
1673   // Get all users of outputs of `n`, in the same block as `n`.
1674   // This means if there is an `if` node that uses an output of `n` in some
1675   // inner sub-block, we will consider the whole `if` node a user of `n`.
1676   std::unordered_set<Node*> getUsersSameBlock(Node* n) const {
1677     std::unordered_set<Node*> users;
1678     for (const auto output : n->outputs()) {
1679       for (const auto& use : output->uses()) {
1680         if (auto sameBlock = findSameBlock(use.user, n)) {
1681           users.insert(sameBlock);
1682         }
1683       }
1684     }
1685     return users;
1686   }
1687 
1688   // Walk up `target`'s chain of owning blocks until we find a node that
1689   // shares a block with `n`.
1690   //
1691   // If one can't be found (say, because `n` is an inner block and target is
1692   // outside), then return nullptr. Since we can only reorder nodes within a
1693   // block, `target` would be irrelevant.
1694   static Node* findSameBlock(Node* target, Node* n) {
1695     TORCH_INTERNAL_ASSERT(target->owningGraph() == n->owningGraph());
1696     if (target->owningBlock() == n->owningBlock()) {
1697       return target;
1698     } else {
1699       // This user is in a sub-block. Walk up its chain of owning blocks
1700       // until we arrive at a node that shares a block with `n`.
1701       auto curNode = target;
1702       while (curNode->owningBlock() != n->owningBlock()) {
1703         curNode = curNode->owningBlock()->owningNode();
1704         if (curNode == nullptr) {
1705           return curNode;
1706         }
1707       }
1708       return curNode;
1709     }
1710   }
1711 
1712   // NOLINTNEXTLINE(cppcoreguidelines-avoid-const-or-ref-data-members)
1713   const AliasDb& aliasDb_;
1714   std::vector<Node*> nodes_;
1715   // Extra index over `nodes_` for faster membership lookups.
1716   // Since the tryMove method is used a lot, we want to
1717   // make it as fast as possible.
1718   std::unordered_map<Node*, int64_t> node_to_index_;
1719 
1720   // Mover dependencies. We track these separately since we may erase the mover
1721   // from the working set.
1722   Node* mover_;
1723   MemoryLocations moverWrites_;
1724   MemoryLocations moverReads_;
1725   std::unordered_set<Node*> moverUsers_;
1726 
1727   // Nodes (in the same block as the working set) that use values it produces
1728   std::unordered_set<Node*> users_;
1729   // Memory locations written to / read by the working set
1730   MemoryLocations writes_;
1731   MemoryLocations reads_;
1732 };
1733 
1734 // Try to move `toMove` before/after `movePoint` while preserving value
1735 // dependencies. Returns false iff such a move could not be made.
1736 //
1737 // If `dryRun` is set, don't actually execute the move, just check if the move
1738 // is possible
1739 //
1740 // The basic approach is: have a "working set" that we are moving forward, one
1741 // node at a time. When we can't move past a node (because it depends on the
1742 // working set), then add it to the working set and keep moving until we hit
1743 // `moveAfter`.
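// Illustrative sketch with hypothetical nodes (not from this file): in
//
//   a = f(x); b = g(a); c = h(x)
//
// moving `a` after `c` proceeds by walking forward from `a`. Node `b`
// consumes `a`'s output, so it joins the working set; `c` depends on
// neither `a` nor `b`, so the move succeeds and the resulting order is
// c, a, b.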
1744 bool AliasDb::tryMove(
1745     Node* toMove,
1746     Node* movePoint,
1747     MoveSide moveSide,
1748     bool dryRun) {
1749   if (toMove->owningBlock() != movePoint->owningBlock()) {
1750     return false;
1751   }
1752   if (toMove == movePoint) {
1753     return true;
1754   }
1755 
1756   // 1. Move from `toMove` toward `movePoint`, building up the working set
1757   // of dependencies
1758   WorkingSet workingSet(toMove, *this);
1759 
1760   auto direction = kNextDirection;
1761   if (toMove->isAfter(movePoint)) {
1762     direction = kPrevDirection;
1763   }
1764 
1765   auto curNode = toMove->next_in_graph[direction];
1766 
1767   bool toMoveIsOnMoveSide =
1768       (moveSide == MoveSide::BEFORE && toMove->isBefore(movePoint)) ||
1769       (moveSide == MoveSide::AFTER && toMove->isAfter(movePoint));
1770 
1771   if (toMoveIsOnMoveSide && curNode == movePoint) {
1772     return true;
1773   }
1774 
1775   // it is never valid to reorder a node with side effects
1776   if (toMove->hasSideEffects() ||
1777       (!toMoveIsOnMoveSide && movePoint->hasSideEffects())) {
1778     return false;
1779   }
1780 
1781   // Move forward one node at a time
1782   while (curNode != movePoint) {
1783     // never valid to reorder around a node with side effects
1784     if (curNode->hasSideEffects()) {
1785       return false;
1786     }
1787 
1788     if (workingSet.dependsOn(curNode)) {
1789       // If we can't move past this node, add it to the working set
1790       workingSet.add(curNode);
1791     }
1792     curNode = curNode->next_in_graph[direction];
1793   }
1794 
1795   // 2. Decide whether we can move it all to `movePoint`.
1796 
1797   // Say we are moving directly before movePoint and `toMove` starts before
1798   // movePoint in the graph. The move looks like
1799   //
1800   //  `toMove`            `toMove`         |
1801   //  <dependencies>  ->  `movePoint`      | `toMove` and deps are split
1802   //  `movePoint`         <dependencies>   |
1803   //
1804   // Contrast with the case where `toMove` starts AFTER movePoint:
1805   //
1806   //  `movePoint`           <dependencies>   |
1807   //  <dependencies>  ->    `toMove`         | `toMove` and deps are together
1808   //  `toMove`              `movePoint`      |
1809   //
1810   // In the first case, we need to split `toMove` off from its dependencies,
1811   // so we can move the dependencies below `movePoint` and keep `toMove` above.
1812   const bool splitToMoveAndDeps =
1813       (moveSide == MoveSide::BEFORE && toMove->isBefore(movePoint)) ||
1814       (moveSide == MoveSide::AFTER && toMove->isAfter(movePoint));
1815 
1816   if (splitToMoveAndDeps) {
1817     // remove `toMove` from the dependencies to be moved past `movePoint`
1818     workingSet.eraseMover();
1819   }
1820 
1821   // Check if we can move the working set past the move point
1822   if (workingSet.dependsOn(movePoint)) {
1823     // if we can't, then there are intermediate dependencies between `toMove`
1824     // and `movePoint`, so we can't do the move
1825     return false;
1826   }
1827 
1828   if (dryRun) {
1829     return true;
1830   }
1831 
1832   // 3. Execute the move
1833   TORCH_INTERNAL_ASSERT(curNode == movePoint);
1834   if (splitToMoveAndDeps) {
1835     // Move `toMove`
1836     move(toMove, movePoint, moveSide);
1837 
1838     // Then move all of its dependencies on the other side of `movePoint`
1839     const auto reversed =
1840         moveSide == MoveSide::BEFORE ? MoveSide::AFTER : MoveSide::BEFORE;
1841     for (auto n : workingSet.dependentNodes()) {
1842       move(n, curNode, reversed);
1843       curNode = n;
1844     }
1845   } else {
1846     // Just append/prepend everything to `movePoint`
1847     move(toMove, curNode, moveSide);
1848     curNode = toMove;
1849     for (auto n : workingSet.dependentNodes()) {
1850       move(n, curNode, moveSide);
1851       curNode = n;
1852     }
1853   }
1854   return true;
1855 }
1856 
1857 // Helper function so we can generalize `tryMove`
1858 void AliasDb::move(Node* toMove, Node* movePoint, MoveSide moveSide) {
1859   switch (moveSide) {
1860     case MoveSide::BEFORE:
1861       toMove->moveBefore(movePoint);
1862       break;
1863     case MoveSide::AFTER:
1864       toMove->moveAfter(movePoint);
1865       break;
1866   }
1867 }
1868 
1869 bool AliasDb::writesToWildcard(Node* n) const {
1870   if (!writeIndex_->count(n)) {
1871     return false;
1872   }
1873   const auto& writes = writeIndex_->at(n);
1874 
1875   // Are any of these memoryLocs a wildcard element?
1876   for (const auto& pr : wildcardIndex_) {
1877     const auto wildcardElement = pr.second;
1878     if (writes.test(wildcardElement->index)) {
1879       return true;
1880     }
1881   }
1882   return false;
1883 }
1884 
1885 bool AliasDb::mayAliasWildcard(const Value* v) const {
1886   if (auto e = getWildcard(v->type())) {
1887     return memoryDAG_->mayAlias(elementMap_.at(v), e);
1888   }
1889   // There were no wildcards of this type, so return false.
1890   return false;
1891 }
1892 
1893 bool AliasDb::mayAliasWildcard(const at::ArrayRef<Value*> vs) const {
1894   return std::any_of(
1895       vs.begin(), vs.end(), [&](Value* v) { return mayAliasWildcard(v); });
1896 }
1897 
1898 std::optional<Element*> AliasDb::tryGetOrCreateWildcard(const TypePtr& type) {
1899   auto maybe_mut_types = mapTypeToAliasTypeSetPtr(type);
1900   if (!maybe_mut_types) {
1901     return std::nullopt;
1902   }
1903   auto mut_type = toSingleType(*maybe_mut_types);
1904   auto existing_wildcard = wildcardIndex_.find(*mut_type);
1905   if (existing_wildcard != wildcardIndex_.end()) {
1906     return existing_wildcard->second;
1907   }
1908 
1909   auto wildcard_elem = memoryDAGBuilder_->makeFreshValue(nullptr);
1910   wildcardIndex_.emplace(*std::move(mut_type), wildcard_elem);
1911   if (maybe_mut_types->size() > 1) {
1912     pointUnionTypeElementToAllContainedTypes(wildcard_elem, *maybe_mut_types);
1913   } else {
1914     addContainedTypesToFreshElement(wildcard_elem, *maybe_mut_types);
1915   }
1916   return wildcard_elem;
1917 }
1918 
1919 void AliasDb::pointUnionTypeElementToAllContainedTypes(
1920     Element* container_elem,
1921     const AliasTypeSet& mut_types) {
1922   for (const auto& mut_type : mut_types) {
1923     auto maybe_elem = tryGetOrCreateWildcard(mut_type);
1924     if (maybe_elem) {
1925       TORCH_INTERNAL_ASSERT(*maybe_elem != container_elem);
1926       memoryDAGBuilder_->makePointerTo(container_elem, *maybe_elem);
1927     }
1928   }
1929 }
1930 
1931 void AliasDb::addContainedTypesToFreshElement(
1932     Element* container_elem,
1933     const AliasTypeSet& mut_types) {
1934   for (const auto& mut_type : mut_types) {
1935     for (const auto& contained : mut_type->containedTypes()) {
1936       auto maybe_elem = tryGetOrCreateWildcard(contained);
1937       if (maybe_elem) {
1938         memoryDAGBuilder_->addToContainedElements(*maybe_elem, container_elem);
1939       }
1940     }
1941   }
1942 }
1943 
1944 // Search the wildcard index for an element that corresponds to the given type.
1945 // Returns nullptr if no corresponding wildcard element exists.
1946 Element* AliasDb::getWildcard(const TypePtr& type) const {
1947   auto maybe_mut_types = mapTypeToAliasTypeSetPtr(type);
1948   if (!maybe_mut_types) {
1949     return {};
1950   }
1951   if (maybe_mut_types->size() > 1) {
1952     auto union_type = UnionType::create(*maybe_mut_types);
1953     // Get a <TypePtr, Element*> pair where the TypePtr is this Union
1954     // type and the Element is the corresponding Wildcard
1955     auto maybe_union_pair = wildcardIndex_.find(union_type);
1956     if (maybe_union_pair != wildcardIndex_.end()) {
1957       return (*maybe_union_pair).second;
1958     }
1959   } else {
1960     // Get a <TypePtr, Element*> pair where the TypePtr is the given
1961     // type and the Element is the corresponding Wildcard
1962     auto type_pair = wildcardIndex_.find((*maybe_mut_types)[0]);
1963     if (type_pair != wildcardIndex_.end()) {
1964       return type_pair->second;
1965     }
1966   }
1967   return {};
1968 }
1969 
1970 // Register `v` as a wildcard value.
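// A "wildcard" value is one whose aliasing we cannot track precisely (for
// example, a value extracted from an opaque container or produced by an
// op with unknown aliasing behavior). Registering `v` here means it is
// conservatively treated as possibly aliasing any other value of the same
// mutable type.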
1971 std::optional<Element*> AliasDb::setWildcard(const Value* v) {
1972   std::optional<Element*> maybe_wildcardElement =
1973       tryGetOrCreateWildcard(v->type());
1974   if (!maybe_wildcardElement) {
1975     return std::nullopt;
1976   }
1977   // Still create a corresponding Element for `v`, since it is an invariant
1978   // that all mutable values have an Element
1979   getOrCreateElement(v);
1980   wildcards_.insert(v);
1981   return maybe_wildcardElement;
1982 }
1983 
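// Union the memory locations written to by every node in the graph into a
// single bitset, so later queries can cheaply test whether a given
// location is ever written to.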
1984 void AliasDb::buildWrittenToLocationsIndex() {
1985   MemoryLocations ret;
1986   for (const auto& pr : *writeIndex_) {
1987     const auto& writtenLocs = pr.second;
1988     ret |= writtenLocs;
1989   }
1990   writtenToLocationsIndex_ = ret;
1991 }
1992 
1993 void Lint(const AliasDb* db) {
1994   bool failed = false;
1995 
1996   std::stringstream ss;
1997   // Every mutable value in the system has a corresponding element.
1998   for (const auto& v : db->graph_->all_values) {
1999     if (!db->isMutableTypeInternal(v)) {
2000       continue;
2001     }
2002     auto it = db->elementMap_.find(v);
2003     if (it == db->elementMap_.end()) {
2004       failed = true;
2005       ss << "Value %" << v->debugName() << " of type " << v->type()->repr_str()
2006          << " wasn't found in the element map.\n"
2007          << "It was defined in " << *v->node();
2008     }
2009   }
2010   TORCH_INTERNAL_ASSERT(!failed, ss.str());
2011 
2012   // Two checks that we want to add but can't until the mutation API is more
2013   // fully developed.
2014   // - Every mutable value in the aliasdb belongs to the graph
2015   // - All container values have contained elements
2016 }
2017 
2018 } // namespace torch::jit
2019