/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Attributes.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/MLIRContext.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeUtilities.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace mlir {
namespace TFTPU {

namespace {

constexpr char kDeviceAttr[] = "device";
constexpr char kDeviceCPU[] = "CPU";
constexpr char kFuncDeviceAttr[] = "tf.device";

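// Pass that inserts TPUGetLayout and TPUCopyWithLayout ops after TPU program
// compilation, so that host-produced inputs (for example, tf.IteratorGetNext
// outputs on CPU) are copied to the device using the layout chosen by the
// compiled program at run time rather than a fixed compile-time layout.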
struct TPUDynamicLayoutPass
    : public TF::PerFunctionAggregateAnalysisConsumerPass<
          TPUDynamicLayoutPass, TF::ResourceAliasAnalysis> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TPUDynamicLayoutPass)

  void runOnFunction(
      func::FuncOp func,
      const TF::ResourceAliasAnalysis::Info& resource_alias_analysis);

  StringRef getArgument() const final { return "tf-tpu-dynamic-layout-pass"; }

  StringRef getDescription() const final {
    return "Inserts TPU layout ops to determine layout at run time.";
  }
};

// Checks if the input producer op is supported in this transform. Right now,
// we only check if it is a tf.IteratorGetNext whose resource input comes from
// a VarHandle on CPU or a function argument assigned to CPU.
bool IsSupportedInputOp(
    Operation* op,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  TF::IteratorGetNextOp iterator_op = llvm::dyn_cast<TF::IteratorGetNextOp>(op);
  if (!iterator_op) return false;

  Value resource_iterator = iterator_op.iterator();

  if (resource_alias_analysis.IsUnknownResource(resource_iterator))
    return false;
  llvm::SmallSetVector<Value, 8> aliases =
      resource_alias_analysis.GetResourceAliases(resource_iterator);

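  // A value is treated as a "generator" of the resource if it is a function
  // argument or is produced by an op with no operands and a single result
  // (e.g. a VarHandleOp).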
  auto is_generator = [](Value val) {
    if (val.isa<BlockArgument>()) return true;
    Operation* definition = val.getDefiningOp();
    return definition->getNumOperands() == 0 &&
           definition->getNumResults() == 1;
  };

  // Check that all generator aliases (ops or function arguments) are on CPU.
  func::FuncOp func = iterator_op->getParentOfType<func::FuncOp>();
  return llvm::all_of(aliases, [&](Value alias) {
    // Ignore non-generator aliases.
    if (!is_generator(alias)) return true;

    StringAttr device;
    if (auto arg = alias.dyn_cast<BlockArgument>()) {
      device = func.getArgAttrOfType<mlir::StringAttr>(arg.getArgNumber(),
                                                       kFuncDeviceAttr);
    } else {
      device = alias.getDefiningOp()->getAttrOfType<StringAttr>(kDeviceAttr);
    }

    if (!device) return false;
    tensorflow::DeviceNameUtils::ParsedName parsed_device;
    if (!tensorflow::DeviceNameUtils::ParseFullName(device.getValue().str(),
                                                    &parsed_device)) {
      return false;
    }
    return parsed_device.has_type && parsed_device.type == kDeviceCPU;
  });
}

// Returns an OpBuilder whose insertion point is immediately after `op`.
OpBuilder CreateBuilderAfterOp(Operation* op) {
  return OpBuilder(op->getBlock(), ++Block::iterator(op));
}

// Builds a TPUGetLayoutOp with the given compile op and input index.
TF::TPUGetLayoutOp BuildGetLayout(const int64_t execute_arg_index,
                                  Value compilation_key,
                                  tf_device::LaunchOp compile_launch,
                                  OpBuilder* builder) {
  return builder->create<TF::TPUGetLayoutOp>(
      compile_launch.getLoc(),
      llvm::ArrayRef<Type>{RankedTensorType::get({ShapedType::kDynamicSize},
                                                 builder->getIntegerType(64))},
      llvm::ArrayRef<Value>{compilation_key},
      llvm::ArrayRef<NamedAttribute>{
          builder->getNamedAttr("index",
                                builder->getI64IntegerAttr(execute_arg_index)),
          builder->getNamedAttr("is_output", builder->getBoolAttr(false))});
}

// Builds a TPUCopyWithLayoutOp with the given get_layout op and input.
TF::TPUCopyWithLayoutOp BuildCopyWithLayout(tf_device::LaunchOp execute_launch,
                                            tf_device::LaunchOp compile_launch,
                                            TF::TPUGetLayoutOp get_layout,
                                            Value input, OpBuilder* builder) {
  return builder->create<TF::TPUCopyWithLayoutOp>(
      execute_launch.getLoc(), llvm::ArrayRef<Type>{input.getType()},
      llvm::ArrayRef<Value>{input, get_layout.layout()});
}

// Performs the transformation for a non-replicated input.
void HandleInput(Value input, const int64_t execute_arg_index,
                 TF::TPUExecuteOp execute, tf_device::LaunchOp execute_launch,
                 tf_device::LaunchOp compile_launch) {
  OpBuilder builder = CreateBuilderAfterOp(compile_launch);
  auto get_layout = BuildGetLayout(execute_arg_index, execute.key(),
                                   compile_launch, &builder);
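  // Place the layout-aware copy right before the execute's launch so it runs
  // immediately before the input is consumed, and pin it to the execute
  // launch's device.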
  builder.setInsertionPoint(execute_launch);
  auto copy_with_layout = BuildCopyWithLayout(execute_launch, compile_launch,
                                              get_layout, input, &builder);
  copy_with_layout->setAttr(kDeviceAttr, execute_launch.deviceAttr());
  execute.setOperand(execute_arg_index, copy_with_layout);
}

// Performs the transformation for replicated inputs. Returns true if this is
// a supported case (and the transform was applied).
bool HandleReplicatedInputs(
    const int64_t execute_arg_index, Value compilation_key,
    tf_device::LaunchOp execute_launch, tf_device::LaunchOp compile_launch,
    mlir::BlockArgument replicate_arg, tf_device::ReplicateOp replicate,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  // We need to know the devices to copy to.
  if (!replicate.devices()) return false;

  MutableArrayRef<OpOperand> inputs =
      replicate.GetOperandsForBlockArgument(replicate_arg);
  for (auto entry : llvm::enumerate(inputs)) {
    auto input_op = entry.value().get().getDefiningOp();
    if (!input_op || !IsSupportedInputOp(input_op, resource_alias_analysis))
      return false;
  }
  OpBuilder builder = CreateBuilderAfterOp(compile_launch);
  auto get_layout = BuildGetLayout(execute_arg_index, compilation_key,
                                   compile_launch, &builder);
  builder.setInsertionPoint(replicate);
  for (auto entry : llvm::enumerate(inputs)) {
    auto copy_with_layout =
        BuildCopyWithLayout(execute_launch, compile_launch, get_layout,
                            entry.value().get(), &builder);

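    // Look up the per-replica device list keyed by the execute launch's
    // device alias, and assign this copy to the device of the corresponding
    // replica.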
    auto device_list = replicate.devices()
                           .getValue()
                           .get(execute_launch.getDevice())
                           .cast<ArrayAttr>();
    copy_with_layout->setAttr(kDeviceAttr,
                              device_list.getValue()[entry.index()]);

    entry.value().set(copy_with_layout);
  }
  return true;
}

// Performs the transformation on a compile op and its associated execute
// op(s). The compile op should have no other uses.
void HandleCompileAndExecutes(
    tf_device::LaunchOp compile_launch,
    llvm::MutableArrayRef<tf_device::LaunchOp> execute_launches,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  auto compile =
      llvm::cast<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front());
  tensorflow::tpu::TPUCompileMetadataProto metadata;
  metadata.ParseFromString(compile.metadata().str());
  llvm::SmallVector<llvm::SmallVector<int64_t, 4>, 4> input_mappings =
      tensorflow::GetMetadataArgumentMapping(metadata);

  bool metadata_updated = false;
  auto maybe_replicate =
      execute_launches.front()->getParentOfType<tf_device::ReplicateOp>();

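  // Walk each execute launch together with the mapping from its argument
  // indices to argument indices in the compile metadata.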
  for (auto execute_and_input_mapping :
       llvm::zip(execute_launches, input_mappings)) {
    auto& execute_launch = std::get<0>(execute_and_input_mapping);
    auto execute =
        llvm::cast<TF::TPUExecuteOp>(execute_launch.GetBody().front());
    const auto& input_mapping = std::get<1>(execute_and_input_mapping);

    for (auto& input_and_idx : llvm::enumerate(execute.args())) {
      Value input = input_and_idx.value();
      const int64_t execute_arg_index = input_and_idx.index();
      if (auto block_arg = input.dyn_cast<BlockArgument>()) {
        // For a block argument, consider transforms only when it is a
        // replicated input (defining ops will be outside the replicate node).
        if (maybe_replicate != block_arg.getParentRegion()->getParentOp() ||
            !HandleReplicatedInputs(execute_arg_index, execute.key(),
                                    execute_launch, compile_launch, block_arg,
                                    maybe_replicate, resource_alias_analysis)) {
          continue;
        }
      } else {
        // For an op output, consider transforms only when 1) there is no
        // replication, or 2) the op is outside the replicate node that
        // encloses the execute node (if the op is inside the replicate, it is
        // probably not on the host).
        auto* input_op = input.getDefiningOp();
        if (maybe_replicate &&
            maybe_replicate.body().isAncestor(input_op->getParentRegion())) {
          continue;
        }
        if (!IsSupportedInputOp(input_op, resource_alias_analysis)) continue;
        HandleInput(input, execute_arg_index, execute, execute_launch,
                    compile_launch);
      }

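      // The input was rewritten, so mark the corresponding argument in the
      // compile metadata as having an unrestricted layout; the layout is now
      // decided at run time via TPUGetLayout/TPUCopyWithLayout.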
      metadata.mutable_args(input_mapping[execute_arg_index])
          ->set_unrestricted_layout(true);
      metadata_updated = true;
    }
  }

  if (metadata_updated)
    compile->setAttr("metadata", StringAttr::get(compile.getContext(),
                                                 metadata.SerializeAsString()));
}

void TPUDynamicLayoutPass::runOnFunction(
    func::FuncOp func,
    const TF::ResourceAliasAnalysis::Info& resource_alias_analysis) {
  func.walk([&](TF::_TPUCompileMlirOp compile) {
    // Detect tf._TPUCompileMlir -> tf.TPUExecute(s).
    auto compile_launch =
        llvm::dyn_cast<tf_device::LaunchOp>(compile->getParentOp());
    if (!compile_launch || !compile_launch.WrapsSingleOp()) return;

    llvm::SmallVector<tf_device::LaunchOp, 4> execute_launches;
    execute_launches.reserve(compile_launch.getNumResults() - 1);
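    // The first result of the compile launch is the compilation status; each
    // remaining result is a program key that must feed exactly one TPUExecute
    // wrapped in its own single-op launch.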
    for (Value program_result : llvm::drop_begin(compile_launch.results(), 1)) {
      if (!program_result.hasOneUse()) return;
      Operation* user = *program_result.user_begin();
      auto execute = llvm::dyn_cast<TF::TPUExecuteOp>(user);
      if (!execute) return;
      auto execute_launch =
          llvm::dyn_cast<tf_device::LaunchOp>(execute->getParentOp());
      if (!execute_launch || !execute_launch.WrapsSingleOp()) return;
      execute_launches.push_back(execute_launch);
    }

    HandleCompileAndExecutes(compile_launch, execute_launches,
                             resource_alias_analysis);
  });
}

}  // namespace

std::unique_ptr<OperationPass<ModuleOp>> CreateTPUDynamicLayoutPass() {
  return std::make_unique<TPUDynamicLayoutPass>();
}

}  // namespace TFTPU
}  // namespace mlir