1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/StringRef.h"
17 #include "llvm/Support/FormatVariadic.h"
18 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
19 #include "mlir/IR/Attributes.h" // from @llvm-project
20 #include "mlir/IR/Builders.h" // from @llvm-project
21 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
22 #include "mlir/IR/Diagnostics.h" // from @llvm-project
23 #include "mlir/IR/Operation.h" // from @llvm-project
24 #include "mlir/IR/Types.h" // from @llvm-project
25 #include "mlir/IR/Value.h" // from @llvm-project
26 #include "mlir/Pass/Pass.h" // from @llvm-project
27 #include "mlir/Pass/PassManager.h" // from @llvm-project
28 #include "mlir/Support/LogicalResult.h" // from @llvm-project
29 #include "mlir/Transforms/Passes.h" // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
32 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
33 #include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
34 #include "tensorflow/dtensor/mlir/layout_parsing.h"
35
36 namespace tensorflow {
37 namespace dtensor {
38 namespace {
39
40 constexpr char kFuncDeviceAttr[] = "tf.device";
41
42 // Returns whether `val` is of resource type.
IsResourceType(mlir::Value val)43 bool IsResourceType(mlir::Value val) {
44 return val.isa<mlir::BlockArgument>() && val.getType()
45 .cast<mlir::TensorType>()
46 .getElementType()
47 .isa<mlir::TF::ResourceType>();
48 }
49
50 // Adds device attribute to `arg` with the device placement of `execute_op`
AddPlaceholderDeviceAttributeToResource(mlir::BlockArgument arg,mlir::TF::TPUExecuteOp execute_op)51 void AddPlaceholderDeviceAttributeToResource(
52 mlir::BlockArgument arg, mlir::TF::TPUExecuteOp execute_op) {
53 // TPUExecute op is wrapped inside tf_device.Launch op for device assignment.
54 auto tpu_execute_device_launch =
55 execute_op->getParentOfType<mlir::tf_device::LaunchOp>();
56 mlir::StringRef tpu_device_attr = tpu_execute_device_launch.device();
57
58 auto function = execute_op->getParentOfType<mlir::func::FuncOp>();
59 mlir::OpBuilder builder(execute_op);
60 function.setArgAttr(arg.getArgNumber(), kFuncDeviceAttr,
61 builder.getStringAttr(tpu_device_attr));
62 }
63
64 // Returns AssignVariableOp that consumes output of `val`. `val` is a output
65 // from TPUExecute op which is wrapped inside a single tf_device.Launch
66 // operation. As so, output of parent launch op is queried to identify connected
67 // AssignVariable op.
IdentifyConnectedAssignVariableOp(mlir::Value val)68 mlir::Operation* IdentifyConnectedAssignVariableOp(mlir::Value val) {
69 for (mlir::OpOperand& use : val.getUses()) {
70 auto return_op = llvm::dyn_cast<mlir::tf_device::ReturnOp>(use.getOwner());
71 if (!return_op) continue;
72
73 auto parent_launch =
74 val.getDefiningOp()->getParentOfType<mlir::tf_device::LaunchOp>();
75 mlir::Value launch_output = parent_launch.getResult(use.getOperandNumber());
76 for (mlir::Operation* user : launch_output.getUsers()) {
77 auto assign_variable = llvm::dyn_cast<mlir::TF::AssignVariableOp>(user);
78 if (!assign_variable) continue;
79
80 return assign_variable;
81 }
82 }
83 return nullptr;
84 }
85
86 struct DTensorTpuAddResourceDeviceAttribute
87 : public DTensorTpuAddResourceDeviceAttributeBase<
88 DTensorTpuAddResourceDeviceAttribute> {
runOnOperationtensorflow::dtensor::__anon390ddbda0111::DTensorTpuAddResourceDeviceAttribute89 void runOnOperation() override {
90 mlir::MLIRContext& context = getContext();
91 mlir::OpBuilder op_builder(&context);
92 mlir::ModuleOp module = getOperation();
93 // For each resource value that is input or that is consumed by TPUExecute
94 // op, add placeholder device attribute to the resource argument.
95 mlir::WalkResult walk_result =
96 module.walk([](mlir::TF::TPUExecuteOp tpu_execute) {
97 for (mlir::Value tpu_input : tpu_execute.getOperands()) {
98 if (IsResourceType(tpu_input))
99 AddPlaceholderDeviceAttributeToResource(
100 tpu_input.cast<mlir::BlockArgument>(), tpu_execute);
101
102 mlir::Operation* input_op = tpu_input.getDefiningOp();
103 auto read_variable_op =
104 llvm::dyn_cast_or_null<mlir::TF::ReadVariableOp>(input_op);
105 if (!read_variable_op) continue;
106
107 AddPlaceholderDeviceAttributeToResource(
108 read_variable_op.resource().cast<mlir::BlockArgument>(),
109 tpu_execute);
110 }
111
112 for (mlir::Value result : tpu_execute.getResults()) {
113 mlir::Operation* assign_variable =
114 IdentifyConnectedAssignVariableOp(result);
115 if (assign_variable == nullptr) continue;
116
117 AddPlaceholderDeviceAttributeToResource(
118 llvm::cast<mlir::TF::AssignVariableOp>(assign_variable)
119 .resource()
120 .cast<mlir::BlockArgument>(),
121 tpu_execute);
122 }
123
124 return mlir::WalkResult::advance();
125 });
126
127 if (walk_result.wasInterrupted()) return signalPassFailure();
128 };
129 };
130
131 } // namespace
132
133 // Adds placeholder device attributes to resource arguments of TPU functions.
134 // Device attribute added is consistent with device placement of TPUExecute op.
135 // This is required for enabling CreateTPUMergeVariablesWithExecutePass as the
136 // pass checks that all resources must have consistent device placement with
137 // TPUExecute op in order to enable buffer aliasing.
138 std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorTpuAddResourceDeviceAttribute()139 CreateDTensorTpuAddResourceDeviceAttribute() {
140 return std::make_unique<DTensorTpuAddResourceDeviceAttribute>();
141 }
142
143 } // namespace dtensor
144 } // namespace tensorflow
145