/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"

#include <cstdint>
#include <initializer_list>
#include <utility>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SCCIterator.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/Analysis/CallGraph.h"  // from @llvm-project
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Interfaces/CallInterfaces.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

namespace mlir {
namespace TF {
namespace detail {
//===----------------------------------------------------------------------===//
// BacktrackAnalysisInfo
//===----------------------------------------------------------------------===//
// Class to hold backtrack analysis for the results of a region. Backtrack
// analysis traces the definition of a region's return values back through
// pass-through operations, so that the return value of the region has the
// same value as the backtracked value.
class BacktrackAnalysisInfo {
 public:
  // Initializes the backtrack analysis for the given region.
  explicit BacktrackAnalysisInfo(Region& region,
                                 detail::BacktrackAnalysis& backtrack_analysis);

  BacktrackAnalysisInfo(BacktrackAnalysisInfo&&) = default;

  // Returns the value to which the given result number of the region can be
  // backtracked.
  Value GetValue(int result_index) const {
    return backtracked_values_[result_index];
  }

  // Returns the argument index of the region to which the given result number
  // can be backtracked. Such results will be called "function passthrough". If
  // the result cannot be backtracked to a region argument, returns llvm::None.
  llvm::Optional<int> GetArg(int result_index) const {
    if (auto arg = GetValue(result_index).dyn_cast<BlockArgument>())
      if (arg.getParentBlock() == &region_->front()) return arg.getArgNumber();
    return llvm::None;
  }

 private:
  friend class detail::BacktrackAnalysis;

  // Region for which this object holds the analysis info.
  Region* region_;

  // Backtracked values indexed by the result number.
  llvm::SmallVector<Value, 4> backtracked_values_;
};

//===----------------------------------------------------------------------===//
// BacktrackAnalysis
//===----------------------------------------------------------------------===//
// Holds backtrack analysis for all functions and regions within a module.
class BacktrackAnalysis {
 public:
  using InfoT = BacktrackAnalysisInfo;

  // Constructs the analysis by analyzing the given module.
  BacktrackAnalysis(ModuleOp module,
                    SymbolTableCollection& symbol_table_collection);

  // Returns backtracking analysis for the given region.
  const InfoT& GetAnalysisForRegion(Region& region) const {
    auto it = info_map_.find(&region);
    assert(it != info_map_.end());
    return it->second;
  }

  // Returns backtracking analysis for the given function.
  const InfoT& GetAnalysisForFunc(func::FuncOp func) const {
    return GetAnalysisForRegion(func.getBody());
  }

  // Backtracks the given value.
  Value BacktrackValue(Value value);

 private:
  // Returns the analysis for the given region (analyzing the region if it has
  // not yet been analyzed).
  const InfoT& GetOrCreateAnalysis(Region& region) {
    auto it = info_map_.find(&region);
    if (it == info_map_.end()) {
      // Note: Keep object construction and insertion separate. If we use
      // emplace() to construct and insert in a single shot, then while
      // analyzing this region, calls to BacktrackValue() may end up inserting
      // additional entries in the map, causing the underlying storage to be
      // moved. This would also move the partially constructed object that we
      // have just inserted into the map and are still constructing. To avoid
      // this issue, construct the analysis object separately and then insert
      // it into the map.
      InfoT info(region, *this);
      info_map_.insert({&region, std::move(info)});
    }

    return GetAnalysisForRegion(region);
  }

  // Returns the backtrack analysis for the given region if it exists.
  // If the region has not yet been analyzed, returns llvm::None.
  Optional<const InfoT*> GetAnalysisIfExists(Region& region) const {
    auto it = info_map_.find(&region);
    if (it == info_map_.end()) return llvm::None;
    return &it->second;
  }

  Optional<const InfoT*> GetAnalysisIfExists(func::FuncOp func) const {
    return GetAnalysisIfExists(func.getBody());
  }

 private:
  llvm::SmallDenseMap<Region*, InfoT> info_map_;
  SymbolTableCollection& symbol_table_collection_;
};

// Analyzes all regions attached to all operations in the module.
BacktrackAnalysis::BacktrackAnalysis(
    ModuleOp module, SymbolTableCollection& symbol_table_collection)
    : symbol_table_collection_(symbol_table_collection) {
  const CallGraph call_graph(module);

  // Visit functions bottom up when doing the analysis. Note that the SCC
  // iterator has the property that if there is an edge from SCC1->SCC2, then
  // SCC1 is visited after SCC2, i.e., the graph is traversed bottom up just
  // the way we want.
  auto scc_begin = llvm::scc_begin(&call_graph);
  auto scc_end = llvm::scc_end(&call_graph);
  for (auto& scc : make_range(scc_begin, scc_end)) {
    // Each SCC node is a collection of callgraph nodes that form a cycle. We
    // will visit these nodes in an arbitrary order. If a node being visited
    // calls a function that has not yet been analyzed, we will not be able to
    // backtrack through that function call (our analysis will be correct but
    // pessimistic).
    for (CallGraphNode* node : scc) {
      if (node->isExternal()) continue;
      Region* region = node->getCallableRegion();
      GetOrCreateAnalysis(*region);
    }
  }

  // The call graph analysis above covers all regions attached to functions,
  // but we also need to analyze regions attached to other ops.
  module->walk([this](Operation* op) {
    if (op->hasTrait<OpTrait::NoTerminator>()) return;
    for (Region& region : op->getRegions()) GetOrCreateAnalysis(region);
  });
}

// Backtracks the definition of `value`, looking through passthrough ops.
// Always returns a non-null value; may return `value` itself if backtracking
// is not possible.
Value BacktrackAnalysis::BacktrackValue(Value value) {
  while (Operation* op = value.getDefiningOp()) {
    int res_index = value.cast<OpResult>().getResultNumber();
    if (auto graph = dyn_cast<tf_executor::GraphOp>(op)) {
      value = graph.GetFetch().getOperand(res_index);
    } else if (auto island = dyn_cast<tf_executor::IslandOp>(op)) {
      // The control output is generated by the IslandOp itself, not the yield
      // in the island body.
      if (value == island.control()) break;
      value = island.GetYield().getOperand(res_index);
    } else if (isa<IdentityNOp, IdentityOp>(op)) {
      value = op->getOperand(res_index);
    } else if (auto call = dyn_cast<CallOpInterface>(op)) {
      func::FuncOp func = dyn_cast<func::FuncOp>(
          call.resolveCallable(&symbol_table_collection_));
      if (!func) break;
      // Check if the function being called has been analyzed. If not, we
      // cannot backtrack the value further.
      Optional<const InfoT*> callee_info = GetAnalysisIfExists(func);
      if (!callee_info) break;
      Optional<int> passthrough_arg =
          callee_info.getValue()->GetArg(res_index);
      if (!passthrough_arg) break;
      value = call.getArgOperands()[passthrough_arg.getValue()];
    } else if (isa<tf_device::LaunchOp, tf_device::ClusterOp>(op)) {
      value = op->getRegion(0).front().getTerminator()->getOperand(res_index);
    } else {
      break;
    }
  }
  return value;
}

// Analyze the region.
BacktrackAnalysisInfo::BacktrackAnalysisInfo(
    Region& region, detail::BacktrackAnalysis& backtrack_analysis)
    : region_(&region) {
  if (region.empty()) return;

  assert(llvm::hasSingleElement(region.getBlocks()));

  auto results = region.front().getTerminator()->getOperands();
  if (results.empty()) return;

  backtracked_values_.reserve(results.size());
  for (auto result : results)
    backtracked_values_.push_back(backtrack_analysis.BacktrackValue(result));
}

//===----------------------------------------------------------------------===//
// ResourceAliasAnalysisInfo
//===----------------------------------------------------------------------===//

namespace {

constexpr char kResourceArgUniqueIdAttr[] = "tf._resource_arg_unique_id";

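// Returns true if `op` allocates at least one of its resource-type results,
// according to the op's declared memory effects.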
bool IsResourceAllocatingOp(Operation* op) {
  auto mem_interface = dyn_cast<MemoryEffectOpInterface>(op);
  if (!mem_interface) return false;

  for (Value value : filter_resources(op->getResults())) {
    llvm::SmallVector<MemoryEffects::EffectInstance, 4> effects;
    mem_interface.getEffectsOnValue(value, effects);
    for (auto& effect_instance : effects) {
      if (isa<MemoryEffects::Allocate>(effect_instance.getEffect())) {
        return true;
      }
    }
  }
  return false;
}

}  // namespace

constexpr int64_t ResourceAliasAnalysisInfo::kUnknownResourceId;

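// Increments `resource_type_id` unless the maximum supported number of
// resource types has been reached, in which case the ID saturates at
// `kMaxResourceTypeId` (correct but overly conservative, see below).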
void IncrementResourceTypeId(int64_t& resource_type_id) {
  if (resource_type_id == ResourceAliasAnalysisInfo::kMaxResourceTypeId) {
    // We don't expect this to happen; currently there are 10 resource types in
    // the TF dialect. Still, it should be visible if this ever happens.
    LOG(WARNING) << "reached limit for supported number of resource types ("
                 << ResourceAliasAnalysisInfo::kMaxResourceTypeId
                 << "); this could lead to overly conservative execution order";
    // Note: By not incrementing `resource_type_id` we still maintain
    // correctness; we might only handle different resource types as the same
    // type (for ID `kMaxResourceTypeId`), which is overly conservative.
  } else {
    ++resource_type_id;
  }
}

// Constructs the analysis info by analyzing the given function.
ResourceAliasAnalysisInfo::ResourceAliasAnalysisInfo(
    func::FuncOp func_op, const BacktrackAnalysis& backtrack_analysis,
    SymbolTableCollection& symbol_table_collection) {
  // This function populates resource_value_to_ids_ and id_to_resource_values_.

  // See `ResourceAliasAnalysisInfo` class for ID semantics.
  int64_t next_unique_type_id = 0;
  int64_t next_unique_instance_id = kMaxResourceTypeId + 1;

  // Helper to assign a new unique ID to all resources in the given list of
  // values.
  auto assign_unique_id_to_all = [&](ValueRange values) {
    for (Value value : filter_resources(values)) {
      AddValueUniqueIDMapping(value, next_unique_instance_id++);
    }
  };

  // Helper to assign the unknown ID to all resources in the given list of
  // values.
  auto assign_unknown_id_to_all = [&](ValueRange values) {
    for (Value value : filter_resources(values)) {
      AddValueUniqueIDMapping(value, kUnknownResourceId);
    }
  };

  // If `tf._resource_arg_unique_id` argument attributes are present for
  // resource-type arguments, use those to decide which arguments correspond to
  // the same resource (and thus need the same ID). Otherwise, the arguments
  // are assumed not to alias.
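  // For example (illustrative MLIR only), two resource arguments carrying the
  // same attribute value are treated as the same resource:
  //
  //   func.func @f(
  //       %arg0: tensor<!tf_type.resource> {tf._resource_arg_unique_id = 0 : i64},
  //       %arg1: tensor<!tf_type.resource> {tf._resource_arg_unique_id = 0 : i64})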
  const bool has_arg_unique_id_attrs =
      llvm::any_of(func_op.getArguments(), [&](const BlockArgument& arg) {
        return func_op.getArgAttr(arg.getArgNumber(), kResourceArgUniqueIdAttr);
      });
  if (has_arg_unique_id_attrs) {
    // Resource arguments have IDs attached (via `kResourceArgUniqueIdAttr`)
    // that represent different resources. Map those IDs to the internal
    // instance IDs used by this pass.
    llvm::SmallDenseMap<int64_t, int64_t> attr_id_to_internal_id;
    for (auto arg : filter_resources(func_op.getArguments())) {
      auto id_attr = func_op.getArgAttrOfType<IntegerAttr>(
          arg.getArgNumber(), kResourceArgUniqueIdAttr);
      assert(id_attr &&
             "tf.resource_arg_unique_id attribute should exist on either "
             "none or all arguments.");
      auto emplace_res = attr_id_to_internal_id.try_emplace(
          id_attr.getInt(), next_unique_instance_id);
      AddValueUniqueIDMapping(arg, emplace_res.first->getSecond());
      // Only increment ID if it has been used.
      if (emplace_res.second) ++next_unique_instance_id;
    }
  } else {
    // No `kResourceArgUniqueIdAttr` attribute is present, so all resource
    // arguments must correspond to different resources and we can assign
    // unique IDs.
    assign_unique_id_to_all(func_op.getArguments());
  }

  // Since this analysis is neither inter-procedural nor inter-regional, each
  // region attached to an op within the function is analyzed independently.
  // Seed this analysis for each such region by mapping all resource arguments
  // of such regions to a new unique ID. This is required because walk() visits
  // the attached regions before visiting the op itself, so there is no
  // opportunity during the walk to seed region arguments. Also note that the
  // walk eventually visits the op on which walk() is called, so make sure we
  // do not overwrite the function argument mapping here.
  func_op.walk([&](Operation* op) {
    if (op == func_op) return;
    for (Region& region : op->getRegions()) {
      assign_unique_id_to_all(region.getArguments());
    }
  });

  llvm::SmallDenseMap<ResourceHandle, int64_t> resource_handle_id_map;
  func_op.walk([&](Operation* op) {
    if (auto resource_alloc = dyn_cast<ResourceHandleAllocatorInterface>(op)) {
      llvm::SmallVector<ResourceHandleValueAndId, 4> resources =
          resource_alloc.GetResourceHandleValueAndIdList(
              resource_handle_id_map, next_unique_instance_id);
      for (auto& resource_handle : resources) {
        AddValueUniqueIDMapping(resource_handle.value, resource_handle.id);
        // Keep track of IDs of resources that are allocated by ops with the
        // `UniqueResourceAllocation` trait; this can be utilized for
        // while-loop parallelization (every iteration creates a new unique
        // resource).
        if (op->hasTrait<OpTrait::TF::UniqueResourceAllocation>()) {
          unique_resource_allocation_ids_.insert(resource_handle.id);
        }
      }
    } else if (llvm::isa<TPUReplicatedInputOp>(op)) {
      // TPUReplicatedInput only has a single result, but we iterate over all
      // results to use filter_resources and for consistency with the other
      // cases.
      for (auto result : filter_resources(op->getResults())) {
        for (auto operand : op->getOperands()) {
          PropagateInputToOutput(operand, result);
        }
      }
    } else if (llvm::isa<IdentityNOp, IdentityOp>(op)) {
      for (auto result : filter_resources(op->getResults()))
        PropagateInputToOutput(op->getOperand(result.getResultNumber()),
                               result);
    } else if (auto while_op = dyn_cast<WhileOp>(op)) {
      AnalyzeWhileLoop(while_op, backtrack_analysis.GetAnalysisForFunc(
                                     while_op.body_function()));
    } else if (auto while_region = dyn_cast<WhileRegionOp>(op)) {
      AnalyzeWhileLoop(while_region, backtrack_analysis.GetAnalysisForRegion(
                                         while_region.body()));
    } else if (auto case_op = dyn_cast<CaseOp>(op)) {
      llvm::SmallVector<func::FuncOp, 4> functions;
      case_op.get_branch_functions(functions);
      AnalyzeFunctionalCaseOrIfOp(case_op, functions, backtrack_analysis);
    } else if (auto if_op = dyn_cast<IfOp>(op)) {
      AnalyzeFunctionalCaseOrIfOp(
          if_op, {if_op.then_function(), if_op.else_function()},
          backtrack_analysis);
    } else if (llvm::isa<CaseRegionOp, IfRegionOp>(op)) {
      AnalyzeRegionCaseOrIfOp(op, backtrack_analysis);
    } else if (auto call = dyn_cast<CallOpInterface>(op)) {
      func::FuncOp func = dyn_cast_or_null<func::FuncOp>(
          call.resolveCallable(&symbol_table_collection));
      if (!func) {
        assign_unknown_id_to_all(op->getResults());
        return WalkResult::advance();
      }
      const auto& func_info = backtrack_analysis.GetAnalysisForFunc(func);
      for (auto result : filter_resources(op->getResults())) {
        auto passthrough_arg = func_info.GetArg(result.getResultNumber());
        if (passthrough_arg) {
          PropagateInputToOutput(
              call.getArgOperands()[passthrough_arg.getValue()], result);
        } else {
          AddValueUniqueIDMapping(result, kUnknownResourceId);
        }
      }
    } else if (isa<tf_device::LaunchOp, tf_device::ClusterOp,
                   tf_executor::IslandOp, tf_executor::GraphOp>(op) &&
               op->getNumRegions() == 1) {
      Region& region = op->getRegion(0);
      const auto& body_info = backtrack_analysis.GetAnalysisForRegion(region);
      for (auto result : filter_resources(op->getResults())) {
        Value body_result = body_info.GetValue(result.getResultNumber());
        PropagateInputToOutput(body_result, result);
      }
    } else {
      auto mem_interface = dyn_cast<MemoryEffectOpInterface>(op);
      for (Value value : filter_resources(op->getResults())) {
        // Set unknown ID first, reset later if applicable.
        int64_t resource_id = kUnknownResourceId;

        if (mem_interface) {
          auto alloc_effect =
              mem_interface.getEffectOnValue<MemoryEffects::Allocate>(value);
          if (alloc_effect) {
            TypeID mlir_type_id =
                alloc_effect.getValue().getResource()->getResourceID();
            // Update or lookup internal type ID.
            auto emplace_result = type_id_to_internal_type_id_.try_emplace(
                mlir_type_id, next_unique_type_id);
            // Change unknown ID to type-based ID.
            resource_id = emplace_result.first->getSecond();
            // Only increment ID if we have encountered a new resource type.
            if (emplace_result.second)
              IncrementResourceTypeId(next_unique_type_id);
          }
        }
        AddValueUniqueIDMapping(value, resource_id);
      }
    }
    return WalkResult::advance();
  });
}

// Propagates the resource IDs from an input operand to a result. Returns true
// if the mapping changed.
bool ResourceAliasAnalysisInfo::PropagateInputToOutput(const Value& operand,
                                                       const OpResult& result) {
  auto operand_it = resource_value_to_ids_.find(operand);
  assert(operand_it != resource_value_to_ids_.end() &&
         "A resource-type output does not have the corresponding "
         "resource-type input.");
  bool change = false;
  for (int64_t id : operand_it->second)
    change = AddValueUniqueIDMapping(result, id) || change;
  return change;
}

// Analyzes while loops to compute resource IDs for the loop results.
//
// (1) The base case for the analysis is that if the loop body does not execute
//     at all, the resource IDs for each result are the same as the resource
//     IDs of the corresponding input.
// (2) If the loop does execute one or more times, then we need to account for
//     data flow through the body of the while loop. If result #r is the same
//     as arg #a of the loop body (a passthrough argument), then we can reason
//     further; if the result is not a passthrough, we mark it as unknown.
// (3) For passthrough results, if result #r is the same as arg #a of the loop
//     body, then after one iteration result #r = arg #a, so we need to also
//     propagate arg #a to result #r. After another iteration, arg #a of the
//     loop body will be result #a of the previous iteration, so then we need
//     to propagate from result #a to result #r. Generalizing, the resource ID
//     propagation (for results which are passthrough) looks like:
//
//     for r in (0, num_results) : result[r] = arg[r];
//     repeat till no change {
//       a = passthrough arg for result #r;
//       result[r] += result[a];
//     }
//
void ResourceAliasAnalysisInfo::AnalyzeWhileLoop(
    Operation* while_op, const BacktrackAnalysisInfo& body_info) {
  // Seed the resource IDs for the results using either the resource ID of the
  // passthrough arg, or unknown. We need to perform further analysis if we
  // find a passthrough arg which is not the same as the corresponding result
  // index.
  llvm::SmallVector<Optional<int>, 4> passthrough_args(
      while_op->getNumResults());
  bool need_analysis = false;
  for (auto result : filter_resources(while_op->getResults())) {
    int result_index = result.getResultNumber();
    passthrough_args[result_index] = body_info.GetArg(result_index);
    if (passthrough_args[result_index]) {
      int passthru_index = passthrough_args[result_index].getValue();
      PropagateInputToOutput(while_op->getOperand(passthru_index), result);
      need_analysis |=
          !IsUnknownResource(result) && passthru_index != result_index;
    } else {
      AddValueUniqueIDMapping(result, kUnknownResourceId);
    }
  }

  if (!need_analysis) return;

  // We found a result that is not unknown and whose passthrough operand index
  // is not the same as the result index, which means there is "crosstalk"
  // between 2 or more operands. In that case, we do an iterative propagation
  // of resource IDs till the results converge.
  bool change = true;
  while (change) {
    change = false;
    for (auto result : filter_resources(while_op->getResults())) {
      if (IsUnknownResource(result)) continue;
      // This result has a valid passthrough arg; propagate resource IDs from
      // the result at the passthrough arg index.
      int result_index = result.getResultNumber();
      int passthru_index = passthrough_args[result_index].getValue();
      change =
          PropagateInputToOutput(while_op->getResult(passthru_index), result) ||
          change;
    }
  }
}

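// Analyzes functional Case/If ops: a result aliases the operands that all
// branch functions pass through to that result; otherwise it is mapped to the
// unknown resource ID.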
template <class CaseOrIfOp>
void ResourceAliasAnalysisInfo::AnalyzeFunctionalCaseOrIfOp(
    CaseOrIfOp case_or_if_op, llvm::ArrayRef<func::FuncOp> functions,
    const BacktrackAnalysis& backtrack_analysis) {
  llvm::SmallVector<const BacktrackAnalysisInfo*, 2> infos;
  infos.reserve(functions.size());
  for (func::FuncOp func : functions)
    infos.push_back(&backtrack_analysis.GetAnalysisForFunc(func));

  // If a result is a passthrough of all branches' inputs, merge the resource
  // IDs of the corresponding operands for all the branches.
  for (auto result : filter_resources(case_or_if_op.getResults())) {
    llvm::SmallVector<llvm::Optional<int>, 2> passthrough_args;
    passthrough_args.reserve(functions.size());
    for (const auto* info : infos)
      passthrough_args.emplace_back(info->GetArg(result.getResultNumber()));

    const bool all_passthrough_args_known = llvm::all_of(
        passthrough_args, [](const llvm::Optional<int>& passthrough_arg) {
          return passthrough_arg.has_value();
        });
    if (all_passthrough_args_known) {
      for (const auto& passthrough_arg : passthrough_args) {
        Value operand = case_or_if_op.input()[passthrough_arg.getValue()];
        PropagateInputToOutput(operand, result);
      }
    } else {
      AddValueUniqueIDMapping(result, kUnknownResourceId);
    }
  }
}

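// Analyzes region-based Case/If ops: each result merges the resource IDs of
// the values that the branch regions return for that result.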
void ResourceAliasAnalysisInfo::AnalyzeRegionCaseOrIfOp(
    Operation* case_or_if_op, const BacktrackAnalysis& backtrack_analysis) {
  llvm::SmallVector<const BacktrackAnalysisInfo*, 2> infos;
  infos.reserve(case_or_if_op->getNumRegions());
  for (Region& region : case_or_if_op->getRegions())
    infos.push_back(&backtrack_analysis.GetAnalysisForRegion(region));

  // For region Case/If, the walk will have visited all branch regions before
  // visiting the Case/If op. Backtracking each region result will yield either
  // a value computed within the region or a region capture. If it is a region
  // capture computed before this Case/If, it will have been visited earlier
  // and a mapping will exist for that value. If it is computed within the
  // region, then again a mapping will exist.
  for (auto result : filter_resources(case_or_if_op->getResults())) {
    for (const auto* info : infos) {
      Value region_result = info->GetValue(result.getResultNumber());
      PropagateInputToOutput(region_result, result);
    }
  }
}

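// Returns true if `resource` is mapped to the special "unknown" resource ID.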
bool ResourceAliasAnalysisInfo::IsUnknownResource(Value resource) const {
  auto it = resource_value_to_ids_.find(resource);
  assert(it != resource_value_to_ids_.end() && !it->getSecond().empty());
  // The set is sorted so we only need to check the first element since
  // kUnknownResourceId < 0.
  static_assert(kUnknownResourceId < 0,
                "kUnknownResourceId should be negative");
  return *it->getSecond().begin() == kUnknownResourceId;
}

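// Returns the set of unique IDs assigned to `resource`. `resource` must have
// been seen during analysis and must not be the unknown resource.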
const llvm::SmallSet<int64_t, 8>&
ResourceAliasAnalysisInfo::GetResourceUniqueIds(Value resource) const {
  assert(!IsUnknownResource(resource));
  auto it = resource_value_to_ids_.find(resource);
  assert(it != resource_value_to_ids_.end() && "Unseen resource was queried");
  return it->getSecond();
}

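// Returns the set of resource values that have been assigned the given unique
// ID. The ID must have been seen during analysis.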
const llvm::SmallSetVector<Value, 8>&
ResourceAliasAnalysisInfo::GetUniqueIdResources(const int64_t id) const {
  auto it = id_to_resource_values_.find(id);
  assert(it != id_to_resource_values_.end() && "Unseen id was queried");
  return it->getSecond();
}

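// Returns all resource values that may alias `resource`: every resource that
// shares at least one unique ID with it, plus all resources mapped to the
// unknown ID (which conservatively alias everything).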
llvm::SmallSetVector<Value, 8> ResourceAliasAnalysisInfo::GetResourceAliases(
    Value resource) const {
  assert(!IsUnknownResource(resource) && "Unknown resource was queried");
  llvm::SmallSetVector<Value, 8> aliases;
  for (int64_t id : GetResourceUniqueIds(resource)) {
    const llvm::SmallSetVector<Value, 8>& resources_aliasing_id =
        GetUniqueIdResources(id);
    aliases.insert(resources_aliasing_id.begin(), resources_aliasing_id.end());
  }
  // If there are resources that were marked as unknown, they alias with all
  // other resources.
  auto it = id_to_resource_values_.find(kUnknownResourceId);
  if (it != id_to_resource_values_.end())
    aliases.insert(it->getSecond().begin(), it->getSecond().end());
  return aliases;
}

}  // namespace detail

//===----------------------------------------------------------------------===//
// ResourceAliasAnalysis
//===----------------------------------------------------------------------===//

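// Constructs the resource alias analysis for an entire module by analyzing
// each of its functions. A consumer pass would typically query it roughly as
// in the following sketch (illustrative only; it assumes the per-function info
// is retrieved via `GetAnalysisForFunc` on the aggregate analysis declared in
// the header):
//
//   ResourceAliasAnalysis analysis(module);
//   const auto& info = analysis.GetAnalysisForFunc(func);
//   if (!info.IsUnknownResource(resource)) {
//     auto aliases = info.GetResourceAliases(resource);
//     // ... use `aliases` ...
//   }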
ResourceAliasAnalysis::ResourceAliasAnalysis(ModuleOp module) {
  // Create symbol table for module.
  SymbolTableCollection symbol_table_collection;
  symbol_table_collection.getSymbolTable(module);
  // Analyze all regions for backtracking info.
  detail::BacktrackAnalysis backtrack_analysis(module,
                                               symbol_table_collection);

  // Analyze each function.
  for (auto func : module.getOps<func::FuncOp>())
    this->info_map_.try_emplace(func, func, backtrack_analysis,
                                symbol_table_collection);
}

}  // namespace TF
}  // namespace mlir