1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
17 
18 #include <cstdint>
19 #include <initializer_list>
20 #include <utility>
21 
22 #include "llvm/ADT/ArrayRef.h"
23 #include "llvm/ADT/DenseMap.h"
24 #include "llvm/ADT/Optional.h"
25 #include "llvm/ADT/SCCIterator.h"
26 #include "llvm/ADT/STLExtras.h"
27 #include "llvm/ADT/SmallVector.h"
28 #include "llvm/ADT/iterator_range.h"
29 #include "llvm/Support/Casting.h"
30 #include "mlir/Analysis/CallGraph.h"  // from @llvm-project
31 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
32 #include "mlir/IR/Attributes.h"  // from @llvm-project
33 #include "mlir/IR/Block.h"  // from @llvm-project
34 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
35 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
36 #include "mlir/IR/Operation.h"  // from @llvm-project
37 #include "mlir/IR/Value.h"  // from @llvm-project
38 #include "mlir/IR/Visitors.h"  // from @llvm-project
39 #include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
40 #include "mlir/Support/LLVM.h"  // from @llvm-project
41 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
46 
47 namespace mlir {
48 namespace TF {
49 namespace detail {
50 
51 //===----------------------------------------------------------------------===//
52 // BacktrackAnalysisInfo
53 //===----------------------------------------------------------------------===//
54 // Class to hold backtrack analysis for a results of a region. Backtrack
55 // analysis will trace back the definition of return values of regions through
56 // pass-through operations, so that the return value of the region will have the
57 // same value as the backtracked value.
58 class BacktrackAnalysisInfo {
59  public:
60   // Initializes the backtrack analysis for the given region.
61   explicit BacktrackAnalysisInfo(Region& region,
62                                  detail::BacktrackAnalysis& backtrack_analysis);
63 
64   BacktrackAnalysisInfo(BacktrackAnalysisInfo&&) = default;
65 
66   // Returns the value to which the given result number of the region can be
67   // backtracked to.
GetValue(int result_index) const68   Value GetValue(int result_index) const {
69     return backtracked_values_[result_index];
70   }
71 
72   // Returns the argument index of the region to which the given result number
73   // can backtracked to. Such results will be called "function passthrough". If
74   // the result cannot be backtracked to a region argument, returns llvm::None.
GetArg(int result_index) const75   llvm::Optional<int> GetArg(int result_index) const {
76     if (auto arg = GetValue(result_index).dyn_cast<BlockArgument>())
77       if (arg.getParentBlock() == &region_->front()) return arg.getArgNumber();
78     return llvm::None;
79   }
80 
81  private:
82   friend class detail::BacktrackAnalysis;
83 
84   // Region for which this object holds the analysis info.
85   Region* region_;
86 
87   // Backtracked values indexed by the result number.
88   llvm::SmallVector<Value, 4> backtracked_values_;
89 };
90 
91 //===----------------------------------------------------------------------===//
92 // BacktrackAnalysis
93 //===----------------------------------------------------------------------===//
94 // Holds backtrack analysis for all functions and regions within a module.
95 class BacktrackAnalysis {
96  public:
97   using InfoT = BacktrackAnalysisInfo;
98 
99   // Constructs the analysis by analyzing the given module.
100   BacktrackAnalysis(ModuleOp module,
101                     SymbolTableCollection& symbol_table_collection);
102 
103   // Returns backtracking analysis for the given region.
GetAnalysisForRegion(Region & region) const104   const InfoT& GetAnalysisForRegion(Region& region) const {
105     auto it = info_map_.find(&region);
106     assert(it != info_map_.end());
107     return it->second;
108   }
109 
110   // Returns backtracking analysis for the given function.
GetAnalysisForFunc(func::FuncOp func) const111   const InfoT& GetAnalysisForFunc(func::FuncOp func) const {
112     return GetAnalysisForRegion(func.getBody());
113   }
114 
115   // Backtracks the given value.
116   Value BacktrackValue(Value value);
117 
118  private:
119   // Returns the analysis for the given region (analyzing the region if it has
120   // not yet been analyzed).
GetOrCreateAnalysis(Region & region)121   const InfoT& GetOrCreateAnalysis(Region& region) {
122     auto it = info_map_.find(&region);
123     if (it == info_map_.end()) {
124       // Note: Keep object construction and insertion separate. If we use
125       // emplace() to construct and insert in a single shot, when analyzing
126       // this region, calls to BacktrackValue() may end up inserting additional
127       // entries in the map, causing the underlying storage to be moved. This
128       // would also include this pertially constructed object that we have just
129       // inserted into the map and are constructing it. To avoid this issue,
130       // construct the analysis object separately and then insert it into the
131       // map.
132       InfoT info(region, *this);
133       info_map_.insert({&region, std::move(info)});
134     }
135 
136     return GetAnalysisForRegion(region);
137   }
138 
139   // Returns the backtrack analysis for the given region if it exists.
140   // If the region has not yet been analyzed, returns llvm::None.
GetAnalysisIfExists(Region & region) const141   Optional<const InfoT*> GetAnalysisIfExists(Region& region) const {
142     auto it = info_map_.find(&region);
143     if (it == info_map_.end()) return llvm::None;
144     return &it->second;
145   }
146 
GetAnalysisIfExists(func::FuncOp func) const147   Optional<const InfoT*> GetAnalysisIfExists(func::FuncOp func) const {
148     return GetAnalysisIfExists(func.getBody());
149   }
150 
151  private:
152   llvm::SmallDenseMap<Region*, InfoT> info_map_;
153   SymbolTableCollection& symbol_table_collection_;
154 };
155 
156 // Analyzes all regions attached to all operations in the module.
BacktrackAnalysis(ModuleOp module,SymbolTableCollection & symbol_table_collection)157 BacktrackAnalysis::BacktrackAnalysis(
158     ModuleOp module, SymbolTableCollection& symbol_table_collection)
159     : symbol_table_collection_(symbol_table_collection) {
160   const CallGraph call_graph(module);
161 
162   // Visit functions bottom up when doing the analysis. Note that SCC iterator
163   // has the property that if there is an edge from SCC1->SCC2, SCC1 is visited
164   // after SCC2, i.e., the graph is traversed bottom up just the way we want.
165   auto scc_begin = llvm::scc_begin(&call_graph);
166   auto scc_end = llvm::scc_end(&call_graph);
167   for (auto& scc : make_range(scc_begin, scc_end)) {
168     // Each SCC node is a collection of callgraph nodes that form a cycle. We
169     // will visit these nodes in an arbitrary order. If a node being visited
170     // calls a function that has not yet been analyzed, we will not be able to
171     // backtrack through that function call (our analysis will be correct but
172     // pessimistic).
173     for (CallGraphNode* node : scc) {
174       if (node->isExternal()) continue;
175       Region* region = node->getCallableRegion();
176       GetOrCreateAnalysis(*region);
177     }
178   }
179 
180   // This above call graph analysis will cover all regions attached to functions
181   // but we also need to analyze regions attached to other ops.
182   module->walk([this](Operation* op) {
183     if (op->hasTrait<OpTrait::NoTerminator>()) return;
184     for (Region& region : op->getRegions()) GetOrCreateAnalysis(region);
185   });
186 }
187 
188 // Backtracks the definition of `value` looking through passthrough ops.
189 // Returns a non-null value and can return `value` if backtracking is not
190 // possible.
BacktrackValue(Value value)191 Value BacktrackAnalysis::BacktrackValue(Value value) {
192   while (Operation* op = value.getDefiningOp()) {
193     int res_index = value.cast<OpResult>().getResultNumber();
194     if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) {
195       value = graph.GetFetch().getOperand(res_index);
196     } else if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
197       // Control output is generated by the IslandOp, not the yield in
198       // in the Island body.
199       if (value == island.control()) break;
200       value = island.GetYield().getOperand(res_index);
201     } else if (isa<IdentityNOp, IdentityOp>(op)) {
202       value = op->getOperand(res_index);
203     } else if (auto call = dyn_cast<CallOpInterface>(op)) {
204       func::FuncOp func = dyn_cast<func::FuncOp>(
205           call.resolveCallable(&symbol_table_collection_));
206       if (!func) break;
207       // Check if the function being called has been analyzed. if not,
208       // we cannot backtrack the value further.
209       Optional<const InfoT*> callee_info = GetAnalysisIfExists(func);
210       if (!callee_info) break;
211       Optional<int> passthrough_arg = callee_info.getValue()->GetArg(res_index);
212       if (!passthrough_arg) break;
213       value = call.getArgOperands()[passthrough_arg.getValue()];
214     } else if (isa<tf_device::LaunchOp, tf_device::ClusterOp>(op)) {
215       value = op->getRegion(0).front().getTerminator()->getOperand(res_index);
216     } else {
217       break;
218     }
219   }
220   return value;
221 }
222 
223 // Analyze the region.
BacktrackAnalysisInfo(Region & region,detail::BacktrackAnalysis & backtrack_analysis)224 BacktrackAnalysisInfo::BacktrackAnalysisInfo(
225     Region& region, detail::BacktrackAnalysis& backtrack_analysis)
226     : region_(&region) {
227   if (region.empty()) return;
228 
229   assert(llvm::hasSingleElement(region.getBlocks()));
230 
231   auto results = region.front().getTerminator()->getOperands();
232   if (results.empty()) return;
233 
234   backtracked_values_.reserve(results.size());
235   for (auto result : results)
236     backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result));
237 }
238 
239 //===----------------------------------------------------------------------===//
240 // ResourceAliasAnalysisInfo
241 //===----------------------------------------------------------------------===//
242 
243 namespace {
244 
245 constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";
246 
IsResourceAllocatingOp(Operation * op)247 bool IsResourceAllocatingOp(Operation* op) {
248   auto mem_interface = dyn_cast<MemoryEffectOpInterface>(op);
249   if (!mem_interface) return false;
250 
251   for (Value value : filter_resources(op->getResults())) {
252     llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
253     mem_interface.getEffectsOnValue(value, effects);
254     for (auto& effect_instance : effects) {
255       if (isa<MemoryEffects::Allocate>(effect_instance.getEffect())) {
256         return true;
257       }
258     }
259   }
260   return false;
261 }
262 
263 }  // namespace
264 
265 constexpr int64_t ResourceAliasAnalysisInfo::kUnknownResourceId;
266 
IncrementResourceTypeId(int64_t & resource_type_id)267 void IncrementResourceTypeId(int64_t& resource_type_id) {
268   if (resource_type_id == ResourceAliasAnalysisInfo::kMaxResourceTypeId) {
269     // We don't expect this to happen, currently there are 10 resource types in
270     // TF dialect. Still, it should be visible if this ever happens.
271     LOG(WARNING) << "reached limit for supported number of resource types ("
272                  << ResourceAliasAnalysisInfo::kMaxResourceTypeId
273                  << "); this could lead to overly conservative execution order";
274     // Note: By not incrementing `resource_type_id` we still maintain
275     // correctness, we might only handle different resource types as the same
276     // type (for ID `kMaxResourceTypeId`) which is overly conservative.
277   } else {
278     ++resource_type_id;
279   }
280 }
281 
282 // Constructs the analysis info by analyzing the given function.
ResourceAliasAnalysisInfo(func::FuncOp func_op,const BacktrackAnalysis & backtrack_analysis,SymbolTableCollection & symbol_table_collection)283 ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
284     func::FuncOp func_op, const BacktrackAnalysis& backtrack_analysis,
285     SymbolTableCollection& symbol_table_collection) {
286   // This function populates resource_value_to_ids_ and id_to_resource_values_.
287 
288   // See `ResourceAliasAnalysisInfo` class for ID semantics.
289   int64_t next_unique_type_id = 0;
290   int64_t next_unique_instance_id = kMaxResourceTypeId + 1;
291 
292   // Helper to assign new unique id for all resources in the given list of
293   // values.
294   auto assign_unique_id_to_all = [&](ValueRange values) {
295     for (Value value : filter_resources(values)) {
296       AddValueUniqueIDMapping(value, next_unique_instance_id++);
297     }
298   };
299 
300   // Helper to assign new unknown id for all resources in the given list of
301   // values.
302   auto assign_unknown_id_to_all = [&](ValueRange values) {
303     for (Value value : filter_resources(values)) {
304       AddValueUniqueIDMapping(value, kUnknownResourceId);
305     }
306   };
307 
308   // If `tf.resource_arg_unique_id` argument attributes are present for
309   // resource-type arguments, use those to decide which arguments correspond to
310   // the same resource (and thus need the same ID). Otherwise, they must not
311   // alias.
312   const bool has_arg_unique_id_attrs =
313       llvm::any_of(func_op.getArguments(), [&](const BlockArgument& arg) {
314         return func_op.getArgAttr(arg.getArgNumber(), kResourceArgUniqueIdAttr);
315       });
316   if (has_arg_unique_id_attrs) {
317     // Resource arguments have IDs attached (via `kResourceArgUniqueIdAttr`)
318     // that represent different resources. Map those IDs to the internal
319     // instance IDs used by this pass.
320     llvm::SmallDenseMap<int64_t, int64_t> attr_id_to_internal_id;
321     for (auto arg : filter_resources(func_op.getArguments())) {
322       auto id_attr = func_op.getArgAttrOfType<IntegerAttr>(
323           arg.getArgNumber(), kResourceArgUniqueIdAttr);
324       assert(id_attr &&
325              "tf.resource_arg_unique_id attribute should exist on either "
326              "none or all arguments.");
327       auto emplace_res = attr_id_to_internal_id.try_emplace(
328           id_attr.getInt(), next_unique_instance_id);
329       AddValueUniqueIDMapping(arg, emplace_res.first->getSecond());
330       // Only increment ID if it has been used.
331       if (emplace_res.second) ++next_unique_instance_id;
332     }
333   } else {
334     // No `kResourceArgUniqueIdAttr` attribute is present, so all resource
335     // arguments must correspond to different resources and we can assign unique
336     // IDs.
337     assign_unique_id_to_all(func_op.getArguments());
338   }
339 
340   // Since this analysis is neither inter-procedural nor inter-regional,
341   // each region attached to Op's within a function is analyzed independently.
342   // Seed this analysis for each such region by mapping all resource arguments
343   // for such regions to a new unique-id. This is required because walk() walks
344   // the attached regions first before visiting the op, so there is no
345   // opportunity during the walk to seed region arguments. Also note that walk
346   // eventually also visits the Op on which the walk() is called, so make sure
347   // we do not overwrite the function argument mapping here.
348   func_op.walk([&](Operation* op) {
349     if (op == func_op) return;
350     for (Region& region : op->getRegions()) {
351       assign_unique_id_to_all(region.getArguments());
352     }
353   });
354 
355   llvm::SmallDenseMap<ResourceHandle, int64_t> resource_handle_id_map;
356   func_op.walk([&](Operation* op) {
357     if (auto resource_alloc = dyn_cast<ResourceHandleAllocatorInterface>(op)) {
358       llvm::SmallVector<ResourceHandleValueAndId, 4> resources =
359           resource_alloc.GetResourceHandleValueAndIdList(
360               resource_handle_id_map, next_unique_instance_id);
361       for (auto& resource_handle : resources) {
362         AddValueUniqueIDMapping(resource_handle.value, resource_handle.id);
363         // Keep track of IDs of resources that are allocated by ops with
364         // `UniqueResourceAllocation` trait, this can be utilized for while-loop
365         // parallelization (every iteration creates a new unique resource).
366         if (op->hasTrait<OpTrait::TF::UniqueResourceAllocation>()) {
367           unique_resource_allocation_ids_.insert(resource_handle.id);
368         }
369       }
370     } else if (llvm::isa<TPUReplicatedInputOp>(op)) {
371       // TPUReplicateInput only has a single result but we get all results
372       // to use filter_resources and for consistency.
373       for (auto result : filter_resources(op->getResults())) {
374         for (auto operand : op->getOperands()) {
375           PropagateInputToOutput(operand, result);
376         }
377       }
378     } else if (llvm::isa<IdentityNOp, IdentityOp>(op)) {
379       for (auto result : filter_resources(op->getResults()))
380         PropagateInputToOutput(op->getOperand(result.getResultNumber()),
381                                result);
382     } else if (auto while_op = dyn_cast<WhileOp>(op)) {
383       AnalyzeWhileLoop(while_op, backtrack_analysis.GetAnalysisForFunc(
384                                      while_op.body_function()));
385     } else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
386       AnalyzeWhileLoop(while_region, backtrack_analysis.GetAnalysisForRegion(
387                                          while_region.body()));
388     } else if (auto case_op = dyn_cast<CaseOp>(op)) {
389       llvm::SmallVector<func::FuncOp, 4> functions;
390       case_op.get_branch_functions(functions);
391       AnalyzeFunctionalCaseOrIfOp(case_op, functions, backtrack_analysis);
392     } else if (auto if_op = dyn_cast<IfOp>(op)) {
393       AnalyzeFunctionalCaseOrIfOp(
394           if_op, {if_op.then_function(), if_op.else_function()},
395           backtrack_analysis);
396     } else if (llvm::isa<CaseRegionOp, IfRegionOp>(op)) {
397       AnalyzeRegionCaseOrIfOp(op, backtrack_analysis);
398     } else if (auto call = dyn_cast<CallOpInterface>(op)) {
399       func::FuncOp func = dyn_cast_or_null<func::FuncOp>(
400           call.resolveCallable(&symbol_table_collection));
401       if (!func) {
402         assign_unknown_id_to_all(op->getResults());
403         return WalkResult::advance();
404       }
405       const auto& func_info = backtrack_analysis.GetAnalysisForFunc(func);
406       for (auto result : filter_resources(op->getResults())) {
407         auto passthrough_arg = func_info.GetArg(result.getResultNumber());
408         if (passthrough_arg) {
409           PropagateInputToOutput(
410               call.getArgOperands()[passthrough_arg.getValue()], result);
411         } else {
412           AddValueUniqueIDMapping(result, kUnknownResourceId);
413         }
414       }
415     } else if (isa<tf_device::LaunchOp, tf_device::ClusterOp,
416                    tf_executor::IslandOp, tf_executor::GraphOp>(op) &&
417                op->getNumRegions() == 1) {
418       Region& region = op->getRegion(0);
419       const auto& body_info = backtrack_analysis.GetAnalysisForRegion(region);
420       for (auto result : filter_resources(op->getResults())) {
421         Value body_result = body_info.GetValue(result.getResultNumber());
422         PropagateInputToOutput(body_result, result);
423       }
424     } else {
425       auto mem_interface = dyn_cast<MemoryEffectOpInterface>(op);
426       for (Value value : filter_resources(op->getResults())) {
427         // Set unknown ID first, reset later if applicable.
428         int64_t resource_id = kUnknownResourceId;
429 
430         if (mem_interface) {
431           auto alloc_effect =
432               mem_interface.getEffectOnValue<MemoryEffects::Allocate>(value);
433           if (alloc_effect) {
434             TypeID mlir_type_id =
435                 alloc_effect.getValue().getResource()->getResourceID();
436             // Update or lookup internal type ID.
437             auto emplace_result = type_id_to_internal_type_id_.try_emplace(
438                 mlir_type_id, next_unique_type_id);
439             // Change unknown ID to type-based ID.
440             resource_id = emplace_result.first->getSecond();
441             // Only increment ID if we have encountered a new resource type.
442             if (emplace_result.second)
443               IncrementResourceTypeId(next_unique_type_id);
444           }
445         }
446         AddValueUniqueIDMapping(value, resource_id);
447       }
448     }
449     return WalkResult::advance();
450   });
451 }
452 
453 // Propagates the resource IDs from an input operand to a result. Returns true
454 // if the mapping changed.
PropagateInputToOutput(const Value & operand,const OpResult & result)455 bool ResourceAliasAnalysisInfo::PropagateInputToOutput(const Value& operand,
456                                                        const OpResult& result) {
457   auto operand_it = resource_value_to_ids_.find(operand);
458   assert(operand_it != resource_value_to_ids_.end() &&
459          "A resource-type output does not have the corresponding "
460          "resource-type input.");
461   bool change = false;
462   for (int64_t id : operand_it->second)
463     change = AddValueUniqueIDMapping(result, id) || change;
464   return change;
465 }
466 
467 // Analyzes while loops to compute resourceIDs for the loop results.
468 //
469 // (1) The base case for the analysis is that if the loop body does not execute
470 //     at all, the resource IDs for each result is the same as the resource IDs
471 //     of the corresponding input.
472 // (2) If the loop does execute one or more times, then we need to account for
473 //     data flow through the body of the while loop. If result #r is the same
474 //     as arg #a of the loop body (pass through argument), then we can reason
475 //     further, else if the result is not a passthrough, we mark it as unknown.
476 // (3) For passthrough results, if result #r is the same as arg #a of the loop
477 //     body, after one iteration, result #r = arg #a, so we need to also
478 //     propagate arg #a to result #r. After another iteration, arg #a of the
479 //     loop body will be result #a of the previous iteration. So then we need
480 //     propagate from result #a to result #r. Generalizing, the resource ID
481 //     propagation (for results which are passthrough) looks like:
482 //
483 //     for r in (0, num_results) : result[r] = arg[r];
484 //     repeat till no change {
485 //       a = passthrough arg for result #r;
486 //       result[r] += result[a];
487 //     }
488 //
AnalyzeWhileLoop(Operation * while_op,const BacktrackAnalysisInfo & body_info)489 void ResourceAliasAnalysisInfo::AnalyzeWhileLoop(
490     Operation* while_op, const BacktrackAnalysisInfo& body_info) {
491   // Seed the resource IDs for the results using either the resource ID of the
492   // passthrough arg, or unknown. We need to perform further analysis if we
493   // find a passthrough arg which is not the same as corresponding the result #.
494   llvm::SmallVector<Optional<int>, 4> passthrough_args(
495       while_op->getNumResults());
496   bool need_analysis = false;
497   for (auto result : filter_resources(while_op->getResults())) {
498     int result_index = result.getResultNumber();
499     passthrough_args[result_index] = body_info.GetArg(result_index);
500     if (passthrough_args[result_index]) {
501       int passthru_index = passthrough_args[result_index].getValue();
502       PropagateInputToOutput(while_op->getOperand(passthru_index), result);
503       need_analysis |=
504           !IsUnknownResource(result) && passthru_index != result_index;
505     } else {
506       AddValueUniqueIDMapping(result, kUnknownResourceId);
507     }
508   }
509 
510   if (!need_analysis) return;
511 
512   // We found a result that is not unknown and whose passthrough operand index
513   // is not the same as the result index, which means there is "crosstalk"
514   // between 2 or more operands. In that case, we do an iterative propagation
515   // of resource IDs till the results converge.
516   bool change = true;
517   while (change) {
518     change = false;
519     for (auto result : filter_resources(while_op->getResults())) {
520       if (IsUnknownResource(result)) continue;
521       // If this result has a valid passthrough arg, propagate resource IDs
522       // from the result of the passthrough arg
523       int result_index = result.getResultNumber();
524       int passthru_index = passthrough_args[result_index].getValue();
525       change =
526           PropagateInputToOutput(while_op->getResult(passthru_index), result) ||
527           change;
528     }
529   }
530 }
531 
532 template <class CaseOrIfOp>
AnalyzeFunctionalCaseOrIfOp(CaseOrIfOp case_or_if_op,llvm::ArrayRef<func::FuncOp> functions,const BacktrackAnalysis & backtrack_analysis)533 void ResourceAliasAnalysisInfo::AnalyzeFunctionalCaseOrIfOp(
534     CaseOrIfOp case_or_if_op, llvm::ArrayRef<func::FuncOp> functions,
535     const BacktrackAnalysis& backtrack_analysis) {
536   llvm::SmallVector<const BacktrackAnalysisInfo*, 2> infos;
537   infos.reserve(functions.size());
538   for (func::FuncOp func : functions)
539     infos.push_back(&backtrack_analysis.GetAnalysisForFunc(func));
540 
541   // If a result is a passthrough of all branches' inputs, merge the resource
542   // IDs of corresponding operands for all the inputs.
543   for (auto result : filter_resources(case_or_if_op.getResults())) {
544     llvm::SmallVector<llvm::Optional<int>, 2> passthrough_args;
545     passthrough_args.reserve(functions.size());
546     for (const auto* info : infos)
547       passthrough_args.emplace_back(info->GetArg(result.getResultNumber()));
548 
549     const bool all_passthrough_args_known = llvm::all_of(
550         passthrough_args, [](const llvm::Optional<int>& passthrough_arg) {
551           return passthrough_arg.has_value();
552         });
553     if (all_passthrough_args_known) {
554       for (const auto& passthrough_arg : passthrough_args) {
555         Value operand = case_or_if_op.input()[passthrough_arg.getValue()];
556         PropagateInputToOutput(operand, result);
557       }
558     } else {
559       AddValueUniqueIDMapping(result, kUnknownResourceId);
560     }
561   }
562 }
563 
AnalyzeRegionCaseOrIfOp(Operation * case_or_if_op,const BacktrackAnalysis & backtrack_analysis)564 void ResourceAliasAnalysisInfo::AnalyzeRegionCaseOrIfOp(
565     Operation* case_or_if_op, const BacktrackAnalysis& backtrack_analysis) {
566   llvm::SmallVector<const BacktrackAnalysisInfo*, 2> infos;
567   infos.reserve(case_or_if_op->getNumRegions());
568   for (Region& region : case_or_if_op->getRegions())
569     infos.push_back(&backtrack_analysis.GetAnalysisForRegion(region));
570 
571   // For region Case/If, the walk would have visited all branch regions before
572   // visiting the Case/If op. Backtracking of each region results will either
573   // give a value computed within these regions, or a region capture. If it is a
574   // region capture computed before this Case/If, it will have been visited
575   // earlier and a mapping would exist for that value. If it is computed within
576   // the region, then again a mapping would exist.
577   for (auto result : filter_resources(case_or_if_op->getResults())) {
578     for (const auto* info : infos) {
579       Value region_result = info->GetValue(result.getResultNumber());
580       PropagateInputToOutput(region_result, result);
581     }
582   }
583 }
584 
IsUnknownResource(Value resource) const585 bool ResourceAliasAnalysisInfo::IsUnknownResource(Value resource) const {
586   auto it = resource_value_to_ids_.find(resource);
587   assert(it != resource_value_to_ids_.end() && !it->getSecond().empty());
588   // The set is sorted so we only need to check the first element since
589   // kUnknownResourceId < 0.
590   static_assert(kUnknownResourceId < 0,
591                 "kUnknownResourceId should be negative");
592   return *it->getSecond().begin() == kUnknownResourceId;
593 }
594 
595 const llvm::SmallSet<int64_t, 8>&
GetResourceUniqueIds(Value resource) const596 ResourceAliasAnalysisInfo::GetResourceUniqueIds(Value resource) const {
597   assert(!IsUnknownResource(resource));
598   auto it = resource_value_to_ids_.find(resource);
599   assert(it != resource_value_to_ids_.end() && "Unseen resource was queried");
600   return it->getSecond();
601 }
602 
603 const llvm::SmallSetVector<Value, 8>&
GetUniqueIdResources(const int64_t id) const604 ResourceAliasAnalysisInfo::GetUniqueIdResources(const int64_t id) const {
605   auto it = id_to_resource_values_.find(id);
606   assert(it != id_to_resource_values_.end() && "Unseen id was queried");
607   return it->getSecond();
608 }
609 
GetResourceAliases(Value resource) const610 llvm::SmallSetVector<Value, 8> ResourceAliasAnalysisInfo::GetResourceAliases(
611     Value resource) const {
612   assert(!IsUnknownResource(resource) && "Unknown resource was queried");
613   llvm::SmallSetVector<Value, 8> aliases;
614   for (int64_t id : GetResourceUniqueIds(resource)) {
615     const llvm::SmallSetVector<Value, 8>& resources_aliasing_id =
616         GetUniqueIdResources(id);
617     aliases.insert(resources_aliasing_id.begin(), resources_aliasing_id.end());
618   }
619   // If there are resources that were marked as unknown, they alias with all
620   // other resources.
621   auto it = id_to_resource_values_.find(kUnknownResourceId);
622   if (it != id_to_resource_values_.end())
623     aliases.insert(it->getSecond().begin(), it->getSecond().end());
624   return aliases;
625 }
626 
627 }  // namespace detail
628 
629 //===----------------------------------------------------------------------===//
630 // ResourceAliasAnalysis
631 //===----------------------------------------------------------------------===//
632 
ResourceAliasAnalysis(ModuleOp module)633 ResourceAliasAnalysis::ResourceAliasAnalysis(ModuleOp module) {
634   // Create symbol table for module.
635   SymbolTableCollection symbol_table_collection;
636   symbol_table_collection.getSymbolTable(module);
637   // Analyze all regions for backtracking info.
638   detail::BacktrackAnalysis backtrack_analysis(module, symbol_table_collection);
639 
640   // Analyze each function.
641   for (auto func : module.getOps<func::FuncOp>())
642     this->info_map_.try_emplace(func, func, backtrack_analysis,
643                                 symbol_table_collection);
644 }
645 
646 }  // namespace TF
647 }  // namespace mlir
648