1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <cstddef>
17 #include <cstdint>
18 #include <string>
19 #include <utility>
20
21 #include "llvm/ADT/DenseMap.h"
22 #include "llvm/ADT/STLExtras.h"
23 #include "llvm/ADT/SmallVector.h"
24 #include "llvm/Support/Debug.h"
25 #include "mlir/Pass/Pass.h" // from @llvm-project
26 #include "mlir/Pass/PassManager.h" // from @llvm-project
27 #include "mlir/Support/LLVM.h" // from @llvm-project
28 #include "mlir/Transforms/Passes.h" // from @llvm-project
29 #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h"
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31
32 namespace mlir {
33 namespace tf_test {
34 namespace {
35
36 // A pass that annotates each operation with a resource type result with the
37 // aliasing values for each such result. Each value is assigned a unique ID, and
38 // that ID is used to annotate the operations.
39 struct TestResourceAliasAnalysis
40 : public TF::PerFunctionAggregateAnalysisConsumerPass<
41 TestResourceAliasAnalysis, TF::ResourceAliasAnalysis> {
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_IDmlir::tf_test::__anon24e9faff0111::TestResourceAliasAnalysis42 MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestResourceAliasAnalysis)
43
44 StringRef getArgument() const final {
45 return "tf-test-resource-alias-analysis";
46 }
47
getDescriptionmlir::tf_test::__anon24e9faff0111::TestResourceAliasAnalysis48 StringRef getDescription() const final {
49 return "Add remarks based on resource alias analysis result, for testing "
50 "purpose.";
51 }
52
runOnFunctionmlir::tf_test::__anon24e9faff0111::TestResourceAliasAnalysis53 void runOnFunction(func::FuncOp func,
54 const TF::ResourceAliasAnalysis::Info& analysis) {
55 int64_t next_id = 0;
56 llvm::SmallDenseMap<Value, int64_t, 8> ids;
57
58 auto assign_id = [&](Value value) {
59 if (ids.find(value) == ids.end()) ids.insert({value, next_id++});
60 };
61
62 auto get_id = [&](Value value) -> int64_t {
63 auto it = ids.find(value);
64 assert(it != ids.end());
65 return it->second;
66 };
67
68 auto print_aliases = [&](InFlightDiagnostic& diag, Value value) {
69 diag << ", ID " << get_id(value) << " : ";
70 if (analysis.IsUnknownResource(value)) {
71 diag << "Unknown";
72 } else {
73 auto aliases = llvm::to_vector<4>(analysis.GetResourceAliases(value));
74 llvm::sort(aliases,
75 [&](Value v1, Value v2) { return get_id(v1) < get_id(v2); });
76 llvm::interleaveComma(aliases, diag,
77 [&](Value v) { diag << get_id(v); });
78 }
79 };
80
81 // Assign a unique ID to each value seen in this function.
82 func.walk([&](Operation* op) {
83 // For all attached regions, assign ID to the region arguments.
84 for (Region& region : op->getRegions()) {
85 for (auto region_arg : TF::filter_resources(region.getArguments()))
86 assign_id(region_arg);
87 }
88
89 // Assign ID for all results.
90 for (auto result : TF::filter_resources(op->getResults()))
91 assign_id(result);
92 });
93
94 // Now walk each operation, and annotate it wil remarks for aliases for
95 // each resource type result
96 func.walk([&](Operation* op) {
97 // For all attached regions, assign ID to the region arguments.
98 for (Region& region : op->getRegions()) {
99 for (auto region_arg : TF::filter_resources(region.getArguments())) {
100 InFlightDiagnostic diag = op->emitRemark("Region #")
101 << region.getRegionNumber() << ", Arg #"
102 << region_arg.getArgNumber();
103 print_aliases(diag, region_arg);
104 }
105 }
106
107 for (auto result : TF::filter_resources(op->getResults())) {
108 InFlightDiagnostic diag = op->emitRemark("Result #")
109 << result.getResultNumber();
110 print_aliases(diag, result);
111 }
112 });
113 }
114 };
115
116 } // anonymous namespace
117
CreateTestResourceAliasAnalysisPass()118 std::unique_ptr<OperationPass<ModuleOp>> CreateTestResourceAliasAnalysisPass() {
119 return std::make_unique<TestResourceAliasAnalysis>();
120 }
121
122 } // namespace tf_test
123 } // namespace mlir
124