/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Diagnostics.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes_classes.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"

namespace tensorflow {
namespace dtensor {
namespace {

constexpr char kFuncDeviceAttr[] = "tf.device";

// Returns whether `val` is a block argument of resource type.
bool IsResourceType(mlir::Value val) {
  return val.isa<mlir::BlockArgument>() && val.getType()
                                               .cast<mlir::TensorType>()
                                               .getElementType()
                                               .isa<mlir::TF::ResourceType>();
}

// Adds a device attribute to `arg` matching the device placement of
// `execute_op`.
void AddPlaceholderDeviceAttributeToResource(
    mlir::BlockArgument arg, mlir::TF::TPUExecuteOp execute_op) {
  // TPUExecute op is wrapped inside a tf_device.Launch op for device
  // assignment.
  auto tpu_execute_device_launch =
      execute_op->getParentOfType<mlir::tf_device::LaunchOp>();
  mlir::StringRef tpu_device_attr = tpu_execute_device_launch.device();

  auto function = execute_op->getParentOfType<mlir::func::FuncOp>();
  mlir::OpBuilder builder(execute_op);
  function.setArgAttr(arg.getArgNumber(), kFuncDeviceAttr,
                      builder.getStringAttr(tpu_device_attr));
}
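
// For illustration only (schematic, not exact MLIR syntax): after the helper
// above runs, the enclosing function's resource argument carries the launch
// device as an argument attribute, e.g.
//   func.func @fn(%arg0: tensor<!tf_type.resource<tensor<f32>>>
//                     {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"})
// The function name, tensor type, and device string above are made up for the
// example.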

// Returns the AssignVariableOp that consumes the output of `val`. `val` is an
// output of a TPUExecute op, which is wrapped inside a single tf_device.Launch
// operation. As such, the output of the parent launch op is queried to
// identify the connected AssignVariable op.
mlir::Operation* IdentifyConnectedAssignVariableOp(mlir::Value val) {
  for (mlir::OpOperand& use : val.getUses()) {
    auto return_op = llvm::dyn_cast<mlir::tf_device::ReturnOp>(use.getOwner());
    if (!return_op) continue;

    auto parent_launch =
        val.getDefiningOp()->getParentOfType<mlir::tf_device::LaunchOp>();
    mlir::Value launch_output = parent_launch.getResult(use.getOperandNumber());
    for (mlir::Operation* user : launch_output.getUsers()) {
      auto assign_variable = llvm::dyn_cast<mlir::TF::AssignVariableOp>(user);
      if (!assign_variable) continue;

      return assign_variable;
    }
  }
  return nullptr;
}
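
// For illustration only (schematic, not exact MLIR syntax), the pattern the
// helper above looks for is roughly:
//   %out = "tf_device.launch"() ({
//     %exec = "tf.TPUExecute"(...)
//     tf_device.return %exec
//   }) {device = "..."}
//   "tf.AssignVariableOp"(%resource, %out)
// Value names here are invented for the example.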

struct DTensorTpuAddResourceDeviceAttribute
    : public DTensorTpuAddResourceDeviceAttributeBase<
          DTensorTpuAddResourceDeviceAttribute> {
  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::OpBuilder op_builder(&context);
    mlir::ModuleOp module = getOperation();
    // For each resource value that feeds a TPUExecute op, either directly or
    // through a ReadVariableOp, or that is updated from a TPUExecute output by
    // an AssignVariableOp, add a placeholder device attribute to the
    // corresponding resource argument.
    mlir::WalkResult walk_result =
        module.walk([](mlir::TF::TPUExecuteOp tpu_execute) {
          for (mlir::Value tpu_input : tpu_execute.getOperands()) {
            if (IsResourceType(tpu_input))
              AddPlaceholderDeviceAttributeToResource(
                  tpu_input.cast<mlir::BlockArgument>(), tpu_execute);

            mlir::Operation* input_op = tpu_input.getDefiningOp();
            auto read_variable_op =
                llvm::dyn_cast_or_null<mlir::TF::ReadVariableOp>(input_op);
            if (!read_variable_op) continue;

            AddPlaceholderDeviceAttributeToResource(
                read_variable_op.resource().cast<mlir::BlockArgument>(),
                tpu_execute);
          }

          for (mlir::Value result : tpu_execute.getResults()) {
            mlir::Operation* assign_variable =
                IdentifyConnectedAssignVariableOp(result);
            if (assign_variable == nullptr) continue;

            AddPlaceholderDeviceAttributeToResource(
                llvm::cast<mlir::TF::AssignVariableOp>(assign_variable)
                    .resource()
                    .cast<mlir::BlockArgument>(),
                tpu_execute);
          }

          return mlir::WalkResult::advance();
        });

    if (walk_result.wasInterrupted()) return signalPassFailure();
  }
};

}  // namespace

// Adds placeholder device attributes to resource arguments of TPU functions.
// The added device attribute is consistent with the device placement of the
// TPUExecute op. This is required to enable
// CreateTPUMergeVariablesWithExecutePass, as that pass requires all resources
// to have device placement consistent with the TPUExecute op before it enables
// buffer aliasing.
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorTpuAddResourceDeviceAttribute() {
  return std::make_unique<DTensorTpuAddResourceDeviceAttribute>();
}
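
// Example usage (illustrative; `pm` is a hypothetical mlir::PassManager
// running over a ModuleOp):
//   pm.addPass(CreateDTensorTpuAddResourceDeviceAttribute());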

}  // namespace dtensor
}  // namespace tensorflow