1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/Support/Debug.h"
18 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
19 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
20 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
21
22 #define DEBUG_TYPE "tf-device-mark-input-output-aliases"
23
24 namespace mlir {
25 namespace TFDevice {
26
27 namespace {
28 struct MarkInputOutputAliasesPass
29 : public TF::MarkInputOutputAliasesPassBase<MarkInputOutputAliasesPass> {
30 void runOnOperation() override;
31 };
32
// Name of the integer argument attribute set on the device function. Its value
// is the index of the cluster_func output that the argument is paired with.
constexpr char kAliasingAttr[] = "tf.aliasing_output";
// Sentinel stored in AliasInfo indices that have not been assigned yet.
constexpr int kUnassigned = -1;
35
36 struct AliasInfo {
AliasInfomlir::TFDevice::__anoncdd3ccd70111::AliasInfo37 AliasInfo() : input_index(kUnassigned), output_index(kUnassigned) {}
38 int input_index;
39 int output_index;
40 };
41
42 // Idenitfy tf_device.cluster_func input-output alias pairs.
43 // This is currently conservative, and handles the following simple case:
44 // ```
45 // %value = tf.ReadVariableOp(%resource_var)
46 // %output:N = tf_device.cluster_func(..., /*input index = a*/ %value, ...)
47 // tf.AssignVariableOp(%resource_var, %output#b) // write output #b to resource
48 // ```
49 // where `%value` and `%output#b` have only one use. (a, b) would be added as
50 // input-output alias pair for `%resource_var`.
51 //
52 // TODO(b/184420848): Explore relaxing these constraints.
BuildAliasingInfo(tf_device::ClusterFuncOp cluster_func,llvm::DenseMap<Value,AliasInfo> & resource_alias_info_map)53 LogicalResult BuildAliasingInfo(
54 tf_device::ClusterFuncOp cluster_func,
55 llvm::DenseMap<Value, AliasInfo>& resource_alias_info_map) {
56 for (auto result : cluster_func.getResults()) {
57 if (!result.hasOneUse()) continue;
58 auto assign_op = llvm::dyn_cast_or_null<TF::AssignVariableOp>(
59 result.use_begin()->getOwner());
60 if (!assign_op) continue;
61 AliasInfo& alias_info = resource_alias_info_map[assign_op.resource()];
62 // TODO(b/184420848): We may not need to skip aliasing for entire function
63 // in case of multiple assigns.
64 if (alias_info.output_index != kUnassigned) {
65 LLVM_DEBUG(
66 llvm::dbgs()
67 << "Skip adding aliasing information because of multiple assigns to "
68 "the same resource from tf_device.cluster_func outputs. This can "
69 "lead to poor memory management on device.\n");
70
71 return failure();
72 }
73 alias_info.output_index = result.getResultNumber();
74 }
75
76 for (auto& operand : cluster_func->getOpOperands()) {
77 auto read_op = llvm::dyn_cast_or_null<TF::ReadVariableOp>(
78 operand.get().getDefiningOp());
79 if (!read_op) continue;
80 if (!read_op->hasOneUse()) continue;
81 auto it = resource_alias_info_map.find(read_op.resource());
82 if (it == resource_alias_info_map.end()) continue;
83 AliasInfo& alias_info = it->getSecond();
84 // TODO(b/184420848): We may not need to skip aliasing for entire function
85 // in case of multiple reads from same resource variable.
86 if (alias_info.input_index != kUnassigned) {
87 LLVM_DEBUG(
88 llvm::dbgs()
89 << "Skip adding aliasing information because of multiple reads of "
90 "the same resource in tf_device.cluster_func inputs. This can "
91 "lead to poor memory management on device.\n");
92 return failure();
93 }
94
95 alias_info.input_index = operand.getOperandNumber();
96 }
97 return success();
98 }
99
AddAliasingAttributeToDeviceFunc(func::FuncOp device_func,llvm::DenseMap<Value,AliasInfo> & resource_alias_info_map)100 void AddAliasingAttributeToDeviceFunc(
101 func::FuncOp device_func,
102 llvm::DenseMap<Value, AliasInfo>& resource_alias_info_map) {
103 OpBuilder builder(device_func.getContext());
104 for (const auto& resource_alias_entry : resource_alias_info_map) {
105 const AliasInfo& alias_info = resource_alias_entry.second;
106 if (alias_info.input_index == kUnassigned ||
107 alias_info.output_index == kUnassigned)
108 continue;
109 auto aliasing_attr = device_func.getArgAttrOfType<mlir::IntegerAttr>(
110 alias_info.input_index, kAliasingAttr);
111
112 // Set only if aliasing attribute does not exist.
113 if (!aliasing_attr) {
114 device_func.setArgAttr(
115 alias_info.input_index, kAliasingAttr,
116 builder.getI64IntegerAttr(alias_info.output_index));
117 continue;
118 }
119 // If aliasing attribute already exists, it must match the new value.
120 assert(aliasing_attr.getInt() == alias_info.output_index);
121 }
122 }
123
runOnOperation()124 void MarkInputOutputAliasesPass::runOnOperation() {
125 SmallVector<tf_device::ClusterFuncOp, 4> cluster_funcs;
126 ModuleOp module = getOperation();
127 module.walk([&](tf_device::ClusterFuncOp cluster_func) {
128 // Map resource values to pair of input-output indices.
129 llvm::DenseMap<Value, AliasInfo> resource_alias_info_map;
130 if (failed(BuildAliasingInfo(cluster_func, resource_alias_info_map)) ||
131 resource_alias_info_map.empty()) {
132 return;
133 }
134
135 FlatSymbolRefAttr func_attr = cluster_func.funcAttr();
136 func::FuncOp device_func =
137 module.lookupSymbol<func::FuncOp>(func_attr.getValue());
138 AddAliasingAttributeToDeviceFunc(device_func, resource_alias_info_map);
139 });
140 }
141
142 } // namespace
143
CreateMarkInputOutputAliasesPass()144 std::unique_ptr<OperationPass<ModuleOp>> CreateMarkInputOutputAliasesPass() {
145 return std::make_unique<MarkInputOutputAliasesPass>();
146 }
147
148 } // namespace TFDevice
149 } // namespace mlir
150