1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/Support/Debug.h"
18 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
19 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
20 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
21 
22 #define DEBUG_TYPE "tf-device-mark-input-output-aliases"
23 
24 namespace mlir {
25 namespace TFDevice {
26 
27 namespace {
28 struct MarkInputOutputAliasesPass
29     : public TF::MarkInputOutputAliasesPassBase<MarkInputOutputAliasesPass> {
30   void runOnOperation() override;
31 };
32 
// Name of the argument attribute written on the device function; its value is
// the index of the cluster_func output paired with that argument.
constexpr char kAliasingAttr[] = "tf.aliasing_output";
// Sentinel meaning "no index has been recorded yet".
constexpr int kUnassigned = -1;

// Input/output index pair tracked per resource value. Both indices start out
// as `kUnassigned` and are filled in independently.
struct AliasInfo {
  // Operand index of the cluster_func input, or `kUnassigned`.
  int input_index = kUnassigned;
  // Result index of the cluster_func output, or `kUnassigned`.
  int output_index = kUnassigned;
};
41 
42 // Idenitfy tf_device.cluster_func input-output alias pairs.
43 // This is currently conservative, and handles the following simple case:
44 // ```
45 // %value = tf.ReadVariableOp(%resource_var)
46 // %output:N = tf_device.cluster_func(..., /*input index = a*/ %value, ...)
47 // tf.AssignVariableOp(%resource_var, %output#b) // write output #b to resource
48 // ```
49 // where `%value` and `%output#b` have only one use. (a, b) would be added as
50 // input-output alias pair for `%resource_var`.
51 //
52 // TODO(b/184420848): Explore relaxing these constraints.
BuildAliasingInfo(tf_device::ClusterFuncOp cluster_func,llvm::DenseMap<Value,AliasInfo> & resource_alias_info_map)53 LogicalResult BuildAliasingInfo(
54     tf_device::ClusterFuncOp cluster_func,
55     llvm::DenseMap<Value, AliasInfo>& resource_alias_info_map) {
56   for (auto result : cluster_func.getResults()) {
57     if (!result.hasOneUse()) continue;
58     auto assign_op = llvm::dyn_cast_or_null<TF::AssignVariableOp>(
59         result.use_begin()->getOwner());
60     if (!assign_op) continue;
61     AliasInfo& alias_info = resource_alias_info_map[assign_op.resource()];
62     // TODO(b/184420848): We may not need to skip aliasing for entire function
63     // in case of multiple assigns.
64     if (alias_info.output_index != kUnassigned) {
65       LLVM_DEBUG(
66           llvm::dbgs()
67           << "Skip adding aliasing information because of multiple assigns to "
68              "the same resource from tf_device.cluster_func outputs. This can "
69              "lead to poor memory management on device.\n");
70 
71       return failure();
72     }
73     alias_info.output_index = result.getResultNumber();
74   }
75 
76   for (auto& operand : cluster_func->getOpOperands()) {
77     auto read_op = llvm::dyn_cast_or_null<TF::ReadVariableOp>(
78         operand.get().getDefiningOp());
79     if (!read_op) continue;
80     if (!read_op->hasOneUse()) continue;
81     auto it = resource_alias_info_map.find(read_op.resource());
82     if (it == resource_alias_info_map.end()) continue;
83     AliasInfo& alias_info = it->getSecond();
84     // TODO(b/184420848): We may not need to skip aliasing for entire function
85     // in case of multiple reads from same resource variable.
86     if (alias_info.input_index != kUnassigned) {
87       LLVM_DEBUG(
88           llvm::dbgs()
89           << "Skip adding aliasing information because of multiple reads of "
90              "the same resource in tf_device.cluster_func inputs. This can "
91              "lead to poor memory management on device.\n");
92       return failure();
93     }
94 
95     alias_info.input_index = operand.getOperandNumber();
96   }
97   return success();
98 }
99 
AddAliasingAttributeToDeviceFunc(func::FuncOp device_func,llvm::DenseMap<Value,AliasInfo> & resource_alias_info_map)100 void AddAliasingAttributeToDeviceFunc(
101     func::FuncOp device_func,
102     llvm::DenseMap<Value, AliasInfo>& resource_alias_info_map) {
103   OpBuilder builder(device_func.getContext());
104   for (const auto& resource_alias_entry : resource_alias_info_map) {
105     const AliasInfo& alias_info = resource_alias_entry.second;
106     if (alias_info.input_index == kUnassigned ||
107         alias_info.output_index == kUnassigned)
108       continue;
109     auto aliasing_attr = device_func.getArgAttrOfType<mlir::IntegerAttr>(
110         alias_info.input_index, kAliasingAttr);
111 
112     // Set only if aliasing attribute does not exist.
113     if (!aliasing_attr) {
114       device_func.setArgAttr(
115           alias_info.input_index, kAliasingAttr,
116           builder.getI64IntegerAttr(alias_info.output_index));
117       continue;
118     }
119     // If aliasing attribute already exists, it must match the new value.
120     assert(aliasing_attr.getInt() == alias_info.output_index);
121   }
122 }
123 
runOnOperation()124 void MarkInputOutputAliasesPass::runOnOperation() {
125   SmallVector<tf_device::ClusterFuncOp, 4> cluster_funcs;
126   ModuleOp module = getOperation();
127   module.walk([&](tf_device::ClusterFuncOp cluster_func) {
128     // Map resource values to pair of input-output indices.
129     llvm::DenseMap<Value, AliasInfo> resource_alias_info_map;
130     if (failed(BuildAliasingInfo(cluster_func, resource_alias_info_map)) ||
131         resource_alias_info_map.empty()) {
132       return;
133     }
134 
135     FlatSymbolRefAttr func_attr = cluster_func.funcAttr();
136     func::FuncOp device_func =
137         module.lookupSymbol<func::FuncOp>(func_attr.getValue());
138     AddAliasingAttributeToDeviceFunc(device_func, resource_alias_info_map);
139   });
140 }
141 
142 }  // namespace
143 
CreateMarkInputOutputAliasesPass()144 std::unique_ptr<OperationPass<ModuleOp>> CreateMarkInputOutputAliasesPass() {
145   return std::make_unique<MarkInputOutputAliasesPass>();
146 }
147 
148 }  // namespace TFDevice
149 }  // namespace mlir
150