1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <iterator>
17 #include <memory>
18 #include <string>
19 #include <tuple>
20 #include <utility>
21
22 #include "absl/strings/str_cat.h"
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/Support/Casting.h"
29 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
30 #include "mlir/IR/Attributes.h" // from @llvm-project
31 #include "mlir/IR/Builders.h" // from @llvm-project
32 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
33 #include "mlir/IR/Location.h" // from @llvm-project
34 #include "mlir/IR/MLIRContext.h" // from @llvm-project
35 #include "mlir/IR/Operation.h" // from @llvm-project
36 #include "mlir/IR/TypeUtilities.h" // from @llvm-project
37 #include "mlir/IR/Types.h" // from @llvm-project
38 #include "mlir/IR/Value.h" // from @llvm-project
39 #include "mlir/Pass/Pass.h" // from @llvm-project
40 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
41 #include "mlir/Transforms/RegionUtils.h" // from @llvm-project
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
46 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
47 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
48 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
49 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
50 #include "tensorflow/core/framework/tensor.h"
51 #include "tensorflow/core/framework/tensor_shape.pb.h"
52 #include "tensorflow/core/framework/types.pb.h"
53 #include "tensorflow/core/platform/random.h"
54 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
55
56 namespace mlir {
57 namespace TFTPU {
58
59 namespace {
60
// Attribute name used in this file to attach a device string to the state
// variable handles that are created.
constexpr char kDeviceAttr[] = "device";
// Function-level device attribute name. NOTE(review): not referenced by the
// code visible in this file; confirm whether it is still needed.
constexpr char kFuncDeviceAttr[] = "tf.device";
// Value written into every element of the default state key, which is passed
// to the unformat op built after the loop.
constexpr char kDefaultShardingValue[] = "";
// Name of the array attribute on `tf_device.replicate` that lists the block
// argument indices of mirrored variables.
constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";
65
GetRandomStateVariableName()66 std::string GetRandomStateVariableName() {
67 return absl::StrCat("VariablesFormatState_", tensorflow::random::New64());
68 }
69
70 struct TPUVariableRuntimeReformattingPass
71 : public TF::TPUVariableRuntimeReformattingPassBase<
72 TPUVariableRuntimeReformattingPass> {
73 void runOnOperation() final;
74 };
75
76 // Returns the earlier value of which `v` is an identity. If `skipped` is
77 // provided, it will be used to store the identity nodes skipped.
SkipIdentity(Value v,bool allow_other_use,llvm::SmallPtrSet<Operation *,4> * skipped=nullptr)78 Value SkipIdentity(Value v, bool allow_other_use,
79 llvm::SmallPtrSet<Operation*, 4>* skipped = nullptr) {
80 while (auto result = v.dyn_cast<OpResult>()) {
81 if (!(allow_other_use || v.hasOneUse())) break;
82 auto op = result.getDefiningOp();
83 if (!llvm::isa<TF::IdentityOp, TF::IdentityNOp>(op)) {
84 break;
85 }
86 v = op->getOperand(result.getResultNumber());
87 if (skipped) skipped->insert(op);
88 }
89 return v;
90 }
91
92 // Finds the formattable arguments of `execute` and annotates the metadata of
93 // `compile` to record these arguments. In addition, it returns a mapping from
94 // the formattable arguments of `execute` to the corresponding operand of
95 // `replicate`. The
96 // entries in the mapping are sorted in the order of operands of `execute`.
97 llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4>
AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(TF::WhileRegionOp while_op,tf_device::ReplicateOp replicate,TF::TPUExecuteAndUpdateVariablesOp execute,tf_device::LaunchOp compile_launch)98 AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
99 TF::WhileRegionOp while_op, tf_device::ReplicateOp replicate,
100 TF::TPUExecuteAndUpdateVariablesOp execute,
101 tf_device::LaunchOp compile_launch) {
102 Region& body = while_op.body();
103 Region& cond = while_op.cond();
104
105 llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4> mapping;
106 auto mirrored_variable_indices_attr =
107 replicate->getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
108 if (!mirrored_variable_indices_attr) return mapping;
109
110 // Finds the mapping from a replicate argument to an execute operand.
111 llvm::SmallDenseMap<int64_t, int64_t, 8> replicate_arg_to_execute_arg;
112 for (auto index_and_arg : llvm::enumerate(execute.args())) {
113 auto arg = SkipIdentity(index_and_arg.value(), /*allow_other_use=*/false);
114 if (!arg.hasOneUse() ||
115 !getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) {
116 continue;
117 }
118 auto block_arg = arg.dyn_cast<BlockArgument>();
119 if (!block_arg || block_arg.getOwner() != &replicate.GetBody()) continue;
120 assert(replicate_arg_to_execute_arg.count(block_arg.getArgNumber()) == 0 &&
121 "Found duplicate use of a resource in the execute op.");
122 replicate_arg_to_execute_arg[block_arg.getArgNumber()] =
123 index_and_arg.index();
124 }
125 if (replicate_arg_to_execute_arg.empty()) return mapping;
126
127 // Parse the original compile metadata.
128 Operation& compile = compile_launch.GetBody().front();
129 auto metadata_str = compile.getAttrOfType<StringAttr>("metadata");
130 assert(metadata_str && "Missing compilation metadata");
131 tensorflow::tpu::TPUCompileMetadataProto metadata;
132 metadata.ParseFromString(std::string(metadata_str.getValue()));
133 int64_t num_replicas = replicate.n();
134 // Find the formattable operands of `execute`, which must be mirrored
135 // variables (arguments of `replicate`), and must be pass-throughs from while
136 // operands.
137 for (const auto& mirrored_index : mirrored_variable_indices_attr) {
138 int64_t replicate_arg = mirrored_index.cast<IntegerAttr>().getInt();
139 // Check if the mirrored variable is an input to `execute`.
140 auto it = replicate_arg_to_execute_arg.find(replicate_arg);
141 if (it == replicate_arg_to_execute_arg.end()) continue;
142 // Get the data type of the resource.
143 auto subtypes = getElementTypeOrSelf(execute.getOperand(it->second))
144 .cast<TF::ResourceType>()
145 .getSubtypes();
146 if (subtypes.size() != 1) continue;
147 auto data_type = getElementTypeOrSelf(subtypes[0]);
148 // The XLA backend does not yet support formatting 64-bit data types.
149 if (data_type.getIntOrFloatBitWidth() == 64) continue;
150
151 const auto& block_arg = replicate.GetBody().getArgument(replicate_arg);
152
153 int64_t num_inputs = 0;
154 if (replicate.IsReplicatedBlockArgument(block_arg)) {
155 num_inputs = num_replicas;
156 } else {
157 num_inputs = 1;
158 }
159
160 // We have found a mirrored variable which is an input to the replicated
161 // `execute`. Now find if this mirrored variable is a pass-through of while
162 // arguments.
163 llvm::SmallVector<Value, 4> replicate_args;
164 for (int64_t i = 0; i < num_inputs; ++i) {
165 llvm::SmallPtrSet<Operation*, 4> skipped_identities;
166
167 auto replicate_operand = SkipIdentity(
168 replicate.GetReplicaOperandForBlockArgument(block_arg, i),
169 /*allow_other_use=*/false, &skipped_identities);
170 // For region based control flow, the resource operand for the replicate
171 // should be a region capture. If this has any use other than the
172 // replicate op (within the body of the while) or the skipped identities,
173 // then do not apply the transformation to this variable.
174 bool is_region_capture =
175 replicate_operand.getParentRegion()->isProperAncestor(&body);
176 bool has_other_use_in_body =
177 llvm::any_of(replicate_operand.getUsers(), [&](Operation* user) {
178 // Ignore uses that are not in the while body or condition.
179 if (!body.isAncestor(user->getParentRegion()) &&
180 !cond.isAncestor(user->getParentRegion()))
181 return false;
182 // Within the body or cond, only uses in replicate and the skipped
183 // identities is allowed.
184 return user != replicate && skipped_identities.count(user) == 0;
185 });
186
187 if (!is_region_capture || has_other_use_in_body) {
188 replicate_args.clear();
189 break;
190 }
191 replicate_args.push_back(replicate_operand);
192 }
193 if (replicate_args.empty()) continue;
194 // Now set the enable_xla_sharding field in the metadata to inform the
195 // compile op.
196 auto metadata_arg = metadata.mutable_args(it->second);
197 metadata_arg->set_enable_xla_sharding(
198 ::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED);
199 mapping.emplace_back(it->second, std::move(replicate_args));
200 }
201 // Sort the mapping according to execute operand order.
202 llvm::sort(mapping, llvm::less_first());
203 // Populate the `retval_index_for_sharding` field of the argument metadate.
204 for (auto entry : llvm::enumerate(execute.device_var_reads_indices())) {
205 int64_t arg_index = entry.value().cast<IntegerAttr>().getInt();
206 auto arg_metadata = metadata.mutable_args(arg_index);
207 if (arg_metadata->enable_xla_sharding() ==
208 ::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED) {
209 int64_t ret_index = execute.device_var_updates_indices()
210 .getValue()[entry.index()]
211 .cast<IntegerAttr>()
212 .getInt();
213 arg_metadata->set_retval_index_for_sharding(ret_index);
214 }
215 }
216 // Update the metadata of the compile op.
217 compile.setAttr("metadata", StringAttr::get(compile.getContext(),
218 metadata.SerializeAsString()));
219 return mapping;
220 }
221
222 // Adds a new replicated input to the replicate op.
AddInputsToReplicateOp(tf_device::ReplicateOp replicate,MutableArrayRef<TF::VarHandleOp> new_inputs,const llvm::SmallDenseMap<llvm::StringRef,llvm::SmallVector<StringRef,4>> & devices)223 tf_device::ReplicateOp AddInputsToReplicateOp(
224 tf_device::ReplicateOp replicate,
225 MutableArrayRef<TF::VarHandleOp> new_inputs,
226 const llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>&
227 devices) {
228 int64_t num_replicas = replicate.n();
229 assert(new_inputs.size() == num_replicas);
230
231 // As model parallelism is not yet supported, we assume that all ops are
232 // placed in logical core 0.
233 // TODO(b/148913020): Remove this constraint once model parallelism is
234 // supported.
235 assert(devices.find(tensorflow::GetDeviceAliasForLogicalCore(0))
236 ->getSecond()
237 .size() == num_replicas);
238
239 llvm::SmallVector<std::pair<ValueRange, Type>, 8> new_replicated_inputs;
240 llvm::SmallVector<Value, 8> new_packed_inputs;
241 llvm::SmallVector<llvm::SmallVector<Value, 8>, 8> replicated_inputs;
242 replicated_inputs.reserve(replicate.GetNumReplicatedBlockArguments());
243 new_packed_inputs.reserve(replicate.GetNumPackedBlockArguments());
244 for (const auto& arg : replicate.GetReplicatedBlockArguments()) {
245 replicated_inputs.emplace_back();
246 for (int64_t i = 0; i < num_replicas; ++i) {
247 replicated_inputs.back().push_back(
248 replicate.GetReplicaOperandForBlockArgument(arg, i));
249 }
250 new_replicated_inputs.emplace_back(replicated_inputs.back(), arg.getType());
251 }
252 for (const auto& arg : replicate.GetPackedBlockArguments()) {
253 new_packed_inputs.emplace_back(
254 replicate.GetReplicaOperandForBlockArgument(arg, /*replica=*/0));
255 }
256 SmallVector<Value, 4> new_input_values;
257 new_input_values.reserve(new_inputs.size());
258 for (auto var : new_inputs) new_input_values.push_back(var.resource());
259 new_replicated_inputs.emplace_back(new_input_values,
260 new_input_values.front().getType());
261 OpBuilder builder(replicate);
262 auto new_replicate = builder.create<tf_device::ReplicateOp>(
263 replicate.getLoc(), num_replicas, devices, new_replicated_inputs,
264 new_packed_inputs,
265 replicate.GetBody().getTerminator()->getOperandTypes());
266 for (auto arg : replicate.GetBody().getArguments()) {
267 if (replicate.IsReplicatedBlockArgument(arg)) {
268 arg.replaceAllUsesWith(
269 new_replicate.GetBody().getArgument(arg.getArgNumber()));
270 } else {
271 // There is a new added replicated state variable between replicated args
272 // and packed args.
273 arg.replaceAllUsesWith(
274 new_replicate.GetBody().getArgument(arg.getArgNumber() + 1));
275 }
276 }
277 for (auto& op : llvm::make_early_inc_range(replicate.GetBody())) {
278 op.moveBefore(&new_replicate.GetBody(), new_replicate.GetBody().end());
279 }
280 replicate.replaceAllUsesWith(new_replicate);
281 replicate.erase();
282 return new_replicate;
283 }
284
285 // Creates the per-device variables that represent the formatting state of each
286 // device.
CreateStateVars(const llvm::SmallDenseMap<llvm::StringRef,llvm::SmallVector<StringRef,4>> & devices,Location loc,RankedTensorType key_type,OpBuilder * builder)287 llvm::SmallVector<TF::VarHandleOp, 4> CreateStateVars(
288 const llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>&
289 devices,
290 Location loc, RankedTensorType key_type, OpBuilder* builder) {
291 llvm::SmallVector<TF::VarHandleOp, 4> state_vars;
292
293 // TODO(b/148913020): Remove this constraint once model parallelism is
294 // supported.
295 const auto& device_list =
296 devices.find(tensorflow::GetDeviceAliasForLogicalCore(0))->getSecond();
297
298 // Create the state variable for each device.
299 for (llvm::StringRef device : device_list) {
300 state_vars.push_back(builder->create<TF::VarHandleOp>(
301 loc,
302 llvm::ArrayRef<Type>{RankedTensorType::get(
303 {}, TF::ResourceType::get(llvm::ArrayRef<TensorType>{key_type},
304 builder->getContext()))},
305 llvm::ArrayRef<Value>{},
306 llvm::ArrayRef<NamedAttribute>{
307 builder->getNamedAttr(kDeviceAttr, builder->getStringAttr(device)),
308 builder->getNamedAttr("container", builder->getStringAttr("")),
309 builder->getNamedAttr(
310 "shared_name",
311 builder->getStringAttr(GetRandomStateVariableName()))}));
312 }
313 return state_vars;
314 }
315
316 // Wraps single op in `tf_device.launch` for explicit device assignment.
WrapOpInLaunch(OpBuilder * builder,Location loc,Operation * op,llvm::StringRef device)317 void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op,
318 llvm::StringRef device) {
319 OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();
320
321 auto launch = builder->create<tf_device::LaunchOp>(
322 loc, builder->getStringAttr(device), op->getResultTypes());
323 launch.body().push_back(new Block);
324
325 builder->setInsertionPointToEnd(&launch.GetBody());
326 builder->create<tf_device::ReturnOp>(loc, op->getResults());
327
328 // Move op inside launch.
329 op->moveBefore(launch.GetBody().getTerminator());
330
331 builder->restoreInsertionPoint(insert_point);
332 }
333
334 // Performs the transformation for a replicate op inside a while loop.
HandleReplicateOp(TF::WhileRegionOp while_op,tf_device::ReplicateOp replicate)335 void HandleReplicateOp(TF::WhileRegionOp while_op,
336 tf_device::ReplicateOp replicate) {
337 int64_t num_replicas = replicate.n();
338 if (num_replicas == 1) return;
339 tf_device::LaunchOp execute_launch;
340 for (auto execute_launch_op :
341 replicate.GetBody().getOps<tf_device::LaunchOp>()) {
342 if (!execute_launch_op.WrapsSingleOp() ||
343 !llvm::isa<TF::TPUExecuteAndUpdateVariablesOp>(
344 execute_launch_op.GetBody().front()))
345 continue;
346
347 if (execute_launch == nullptr) {
348 execute_launch = execute_launch_op;
349 } else {
350 // We only support one execute op inside replicate.
351 execute_launch = nullptr;
352 break;
353 }
354 }
355 if (!execute_launch) return;
356 auto execute = llvm::cast<TF::TPUExecuteAndUpdateVariablesOp>(
357 execute_launch.GetBody().front());
358 auto compile =
359 SkipIdentity(execute.key(), /*allow_other_use=*/true).getDefiningOp();
360 if (!compile) return;
361 auto compile_launch = llvm::dyn_cast<tf_device::LaunchOp>(compile);
362 if (!compile_launch || !compile_launch.WrapsSingleOp() ||
363 !llvm::isa<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front()))
364 return;
365
366 // Analyze the formattable inputs.
367 auto execute_arg_to_outer_args =
368 AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
369 while_op, replicate, execute, compile_launch);
370 if (execute_arg_to_outer_args.empty()) return;
371
372 // Extract the replicated devices.
373 auto devices_attr = replicate.devices();
374 if (!devices_attr) return;
375
376 auto device_map = devices_attr.getValue();
377 llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>> devices;
378 devices.reserve(device_map.size());
379
380 for (auto it : device_map) {
381 auto device_alias = it.getName().strref();
382 auto device_list = it.getValue().cast<ArrayAttr>();
383 llvm::SmallVector<StringRef, 4> device_list_for_alias;
384 device_list_for_alias.reserve(device_list.size());
385
386 for (auto device : device_list)
387 device_list_for_alias.emplace_back(device.cast<StringAttr>().getValue());
388
389 devices.insert({device_alias, device_list_for_alias});
390 }
391
392 OpBuilder builder(replicate);
393 builder.setInsertionPoint(while_op);
394 // Create per-device variables for formatting state, and add them to the while
395 // loop.
396 auto key_type =
397 RankedTensorType::get({2}, TF::StringType::get(builder.getContext()));
398 auto state_vars =
399 CreateStateVars(devices, while_op.getLoc(), key_type, &builder);
400 replicate = AddInputsToReplicateOp(replicate, state_vars, devices);
401 // Build the reformat according to the compilation. Build it inside
402 // `replicate`.
403 llvm::SmallVector<Value, 8> reformat_operands;
404 for (const auto& entry : execute_arg_to_outer_args) {
405 reformat_operands.push_back(execute.args()[entry.first]);
406 }
407 reformat_operands.push_back(compile_launch.getResult(1));
408 reformat_operands.push_back(replicate.GetBody().getArgument(
409 replicate.GetNumReplicatedBlockArguments() - 1));
410 builder.setInsertionPoint(execute_launch);
411 auto reformat_op = builder.create<TF::TPUReshardVariablesOp>(
412 execute_launch.getLoc(), llvm::ArrayRef<Type>{}, reformat_operands);
413 WrapOpInLaunch(&builder, execute_launch.getLoc(), reformat_op,
414 execute_launch.device());
415
416 // Build the replicated unformat op after the loop. First prepare building the
417 // replicate op.
418 llvm::SmallVector<std::pair<ValueRange, Type>, 8> unformat_replicate_operands;
419 llvm::SmallVector<Value, 8> unformat_packed_operands;
420 for (const auto& entry : execute_arg_to_outer_args) {
421 if (entry.second.size() > 1) {
422 unformat_replicate_operands.emplace_back(entry.second,
423 entry.second.front().getType());
424 } else {
425 unformat_packed_operands.emplace_back(entry.second.front());
426 }
427 }
428 llvm::SmallVector<Value, 4> state_var_vals(state_vars.size());
429 for (const auto& entry : llvm::enumerate(state_vars)) {
430 state_var_vals[entry.index()] = entry.value().resource();
431 }
432 // Add the replicated state var to the end of the replicate operands.
433 unformat_replicate_operands.emplace_back(state_var_vals,
434 state_var_vals.front().getType());
435 // Build a constant default key to specify that the unformatting should
436 // transform the variables to the original format.
437 builder.setInsertionPointAfter(while_op);
438 tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {3});
439 default_key_tensor.vec<tensorflow::tstring>()(0) = kDefaultShardingValue;
440 default_key_tensor.vec<tensorflow::tstring>()(1) = kDefaultShardingValue;
441 default_key_tensor.vec<tensorflow::tstring>()(2) = kDefaultShardingValue;
442 auto default_state_key = builder.create<TF::ConstOp>(
443 while_op.getLoc(),
444 tensorflow::ConvertTensor(default_key_tensor, &builder).ValueOrDie());
445 // With all replicated inputs, now build the replicate op.
446 auto unformat_replicate = builder.create<tf_device::ReplicateOp>(
447 while_op.getLoc(), num_replicas, devices, unformat_replicate_operands,
448 unformat_packed_operands, TypeRange{});
449 // Then build the unformat op in the replicate op.
450 builder.setInsertionPointToEnd(&unformat_replicate.GetBody());
451 llvm::SmallVector<Value, 8> unformat_operands;
452 // Add the replicated state var (the last replicated operand of the
453 // ReplicateOp) as the last operand of TPUReshardVariablesOp.
454 BlockArgument state = unformat_replicate.GetReplicatedBlockArguments().back();
455 auto replicated_block_args =
456 unformat_replicate.GetReplicatedBlockArguments().drop_back(1);
457 auto packed_block_args = unformat_replicate.GetPackedBlockArguments();
458 unformat_operands.append(replicated_block_args.begin(),
459 replicated_block_args.end());
460 unformat_operands.append(packed_block_args.begin(), packed_block_args.end());
461 unformat_operands.push_back(state);
462
463 // Insert the default key as the second last operand.
464 unformat_operands.insert(
465 unformat_operands.begin() + unformat_operands.size() - 1,
466 default_state_key.getResult());
467 // Unformat op.
468 auto unformat_op = builder.create<TF::TPUReshardVariablesOp>(
469 while_op.getLoc(), llvm::ArrayRef<Type>{}, unformat_operands);
470 WrapOpInLaunch(&builder, execute_launch.getLoc(), unformat_op,
471 execute_launch.device());
472 builder.create<tf_device::ReturnOp>(while_op.getLoc(), ArrayRef<Value>{});
473 }
474
runOnOperation()475 void TPUVariableRuntimeReformattingPass::runOnOperation() {
476 auto module = getOperation();
477 module.walk([&](TF::WhileRegionOp while_op) {
478 tf_device::ReplicateOp replicate;
479 while_op.body().walk([&](tf_device::ReplicateOp replicate_op) {
480 if (replicate == nullptr) {
481 replicate = replicate_op;
482 return WalkResult::advance();
483 }
484 // We do not handle loops with multiple replicate ops.
485 replicate = nullptr;
486 return WalkResult::interrupt();
487 });
488 // Model parallelism is not supported, and can be detected when a
489 // `tf_device.parallel_execute` op in the `tf_device.replicate` is present.
490 if (replicate &&
491 replicate.GetBody().getOps<tf_device::ParallelExecuteOp>().empty())
492 HandleReplicateOp(while_op, replicate);
493 });
494 }
495
496 } // namespace
497
498 std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUVariableRuntimeReformattingPass()499 CreateTPUVariableRuntimeReformattingPass() {
500 return std::make_unique<TPUVariableRuntimeReformattingPass>();
501 }
502
503 } // namespace TFTPU
504 } // namespace mlir
505