1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <iterator>
17 #include <memory>
18 #include <string>
19 #include <tuple>
20 #include <utility>
21 
22 #include "absl/strings/str_cat.h"
23 #include "llvm/ADT/ArrayRef.h"
24 #include "llvm/ADT/DenseMap.h"
25 #include "llvm/ADT/STLExtras.h"
26 #include "llvm/ADT/SmallVector.h"
27 #include "llvm/ADT/StringRef.h"
28 #include "llvm/Support/Casting.h"
29 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
30 #include "mlir/IR/Attributes.h"  // from @llvm-project
31 #include "mlir/IR/Builders.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
33 #include "mlir/IR/Location.h"  // from @llvm-project
34 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
35 #include "mlir/IR/Operation.h"  // from @llvm-project
36 #include "mlir/IR/TypeUtilities.h"  // from @llvm-project
37 #include "mlir/IR/Types.h"  // from @llvm-project
38 #include "mlir/IR/Value.h"  // from @llvm-project
39 #include "mlir/Pass/Pass.h"  // from @llvm-project
40 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
41 #include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
42 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
44 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
45 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
46 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
47 #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h"
48 #include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h"
49 #include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
50 #include "tensorflow/core/framework/tensor.h"
51 #include "tensorflow/core/framework/tensor_shape.pb.h"
52 #include "tensorflow/core/framework/types.pb.h"
53 #include "tensorflow/core/platform/random.h"
54 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
55 
56 namespace mlir {
57 namespace TFTPU {
58 
59 namespace {
60 
// String constants used by this pass.

// Value written into every element of the default key tensor that is handed
// to the unformatting TPUReshardVariablesOp built after the while loop.
constexpr char kDefaultShardingValue[] = "";
// Name of the attribute that carries the device on the state VarHandleOps
// created by this pass.
constexpr char kDeviceAttr[] = "device";
// Function-level device attribute name. Not referenced by the code visible in
// this file.
constexpr char kFuncDeviceAttr[] = "tf.device";
// Name of the ArrayAttr on `tf_device.replicate` that lists the indices of the
// replicate block arguments that are mirrored variables.
constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";
65 
GetRandomStateVariableName()66 std::string GetRandomStateVariableName() {
67   return absl::StrCat("VariablesFormatState_", tensorflow::random::New64());
68 }
69 
70 struct TPUVariableRuntimeReformattingPass
71     : public TF::TPUVariableRuntimeReformattingPassBase<
72           TPUVariableRuntimeReformattingPass> {
73   void runOnOperation() final;
74 };
75 
76 // Returns the earlier value of which `v` is an identity. If `skipped` is
77 // provided, it will be used to store the identity nodes skipped.
SkipIdentity(Value v,bool allow_other_use,llvm::SmallPtrSet<Operation *,4> * skipped=nullptr)78 Value SkipIdentity(Value v, bool allow_other_use,
79                    llvm::SmallPtrSet<Operation*, 4>* skipped = nullptr) {
80   while (auto result = v.dyn_cast<OpResult>()) {
81     if (!(allow_other_use || v.hasOneUse())) break;
82     auto op = result.getDefiningOp();
83     if (!llvm::isa<TF::IdentityOp, TF::IdentityNOp>(op)) {
84       break;
85     }
86     v = op->getOperand(result.getResultNumber());
87     if (skipped) skipped->insert(op);
88   }
89   return v;
90 }
91 
92 // Finds the formattable arguments of `execute` and annotates the metadata of
93 // `compile` to record these arguments. In addition, it returns a mapping from
94 // the formattable arguments of `execute` to the corresponding operand of
95 // `replicate`. The
96 // entries in the mapping are sorted in the order of operands of `execute`.
97 llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4>
AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(TF::WhileRegionOp while_op,tf_device::ReplicateOp replicate,TF::TPUExecuteAndUpdateVariablesOp execute,tf_device::LaunchOp compile_launch)98 AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
99     TF::WhileRegionOp while_op, tf_device::ReplicateOp replicate,
100     TF::TPUExecuteAndUpdateVariablesOp execute,
101     tf_device::LaunchOp compile_launch) {
102   Region& body = while_op.body();
103   Region& cond = while_op.cond();
104 
105   llvm::SmallVector<std::pair<int64_t, llvm::SmallVector<Value, 4>>, 4> mapping;
106   auto mirrored_variable_indices_attr =
107       replicate->getAttrOfType<ArrayAttr>(kMirroredVariableIndicesAttr);
108   if (!mirrored_variable_indices_attr) return mapping;
109 
110   // Finds the mapping from a replicate argument to an execute operand.
111   llvm::SmallDenseMap<int64_t, int64_t, 8> replicate_arg_to_execute_arg;
112   for (auto index_and_arg : llvm::enumerate(execute.args())) {
113     auto arg = SkipIdentity(index_and_arg.value(), /*allow_other_use=*/false);
114     if (!arg.hasOneUse() ||
115         !getElementTypeOrSelf(arg.getType()).isa<TF::ResourceType>()) {
116       continue;
117     }
118     auto block_arg = arg.dyn_cast<BlockArgument>();
119     if (!block_arg || block_arg.getOwner() != &replicate.GetBody()) continue;
120     assert(replicate_arg_to_execute_arg.count(block_arg.getArgNumber()) == 0 &&
121            "Found duplicate use of a resource in the execute op.");
122     replicate_arg_to_execute_arg[block_arg.getArgNumber()] =
123         index_and_arg.index();
124   }
125   if (replicate_arg_to_execute_arg.empty()) return mapping;
126 
127   // Parse the original compile metadata.
128   Operation& compile = compile_launch.GetBody().front();
129   auto metadata_str = compile.getAttrOfType<StringAttr>("metadata");
130   assert(metadata_str && "Missing compilation metadata");
131   tensorflow::tpu::TPUCompileMetadataProto metadata;
132   metadata.ParseFromString(std::string(metadata_str.getValue()));
133   int64_t num_replicas = replicate.n();
134   // Find the formattable operands of `execute`, which must be mirrored
135   // variables (arguments of `replicate`), and must be pass-throughs from while
136   // operands.
137   for (const auto& mirrored_index : mirrored_variable_indices_attr) {
138     int64_t replicate_arg = mirrored_index.cast<IntegerAttr>().getInt();
139     // Check if the mirrored variable is an input to `execute`.
140     auto it = replicate_arg_to_execute_arg.find(replicate_arg);
141     if (it == replicate_arg_to_execute_arg.end()) continue;
142     // Get the data type of the resource.
143     auto subtypes = getElementTypeOrSelf(execute.getOperand(it->second))
144                         .cast<TF::ResourceType>()
145                         .getSubtypes();
146     if (subtypes.size() != 1) continue;
147     auto data_type = getElementTypeOrSelf(subtypes[0]);
148     // The XLA backend does not yet support formatting 64-bit data types.
149     if (data_type.getIntOrFloatBitWidth() == 64) continue;
150 
151     const auto& block_arg = replicate.GetBody().getArgument(replicate_arg);
152 
153     int64_t num_inputs = 0;
154     if (replicate.IsReplicatedBlockArgument(block_arg)) {
155       num_inputs = num_replicas;
156     } else {
157       num_inputs = 1;
158     }
159 
160     // We have found a mirrored variable which is an input to the replicated
161     // `execute`. Now find if this mirrored variable is a pass-through of while
162     // arguments.
163     llvm::SmallVector<Value, 4> replicate_args;
164     for (int64_t i = 0; i < num_inputs; ++i) {
165       llvm::SmallPtrSet<Operation*, 4> skipped_identities;
166 
167       auto replicate_operand = SkipIdentity(
168           replicate.GetReplicaOperandForBlockArgument(block_arg, i),
169           /*allow_other_use=*/false, &skipped_identities);
170       // For region based control flow, the resource operand for the replicate
171       // should be a region capture. If this has any use other than the
172       // replicate op (within the body of the while) or the skipped identities,
173       // then do not apply the transformation to this variable.
174       bool is_region_capture =
175           replicate_operand.getParentRegion()->isProperAncestor(&body);
176       bool has_other_use_in_body =
177           llvm::any_of(replicate_operand.getUsers(), [&](Operation* user) {
178             // Ignore uses that are not in the while body or condition.
179             if (!body.isAncestor(user->getParentRegion()) &&
180                 !cond.isAncestor(user->getParentRegion()))
181               return false;
182             // Within the body or cond, only uses in replicate and the skipped
183             // identities is allowed.
184             return user != replicate && skipped_identities.count(user) == 0;
185           });
186 
187       if (!is_region_capture || has_other_use_in_body) {
188         replicate_args.clear();
189         break;
190       }
191       replicate_args.push_back(replicate_operand);
192     }
193     if (replicate_args.empty()) continue;
194     // Now set the enable_xla_sharding field in the metadata to inform the
195     // compile op.
196     auto metadata_arg = metadata.mutable_args(it->second);
197     metadata_arg->set_enable_xla_sharding(
198         ::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED);
199     mapping.emplace_back(it->second, std::move(replicate_args));
200   }
201   // Sort the mapping according to execute operand order.
202   llvm::sort(mapping, llvm::less_first());
203   // Populate the `retval_index_for_sharding` field of the argument metadate.
204   for (auto entry : llvm::enumerate(execute.device_var_reads_indices())) {
205     int64_t arg_index = entry.value().cast<IntegerAttr>().getInt();
206     auto arg_metadata = metadata.mutable_args(arg_index);
207     if (arg_metadata->enable_xla_sharding() ==
208         ::tensorflow::tpu::TPUCompileMetadataProto_Arg::ALLOWED) {
209       int64_t ret_index = execute.device_var_updates_indices()
210                               .getValue()[entry.index()]
211                               .cast<IntegerAttr>()
212                               .getInt();
213       arg_metadata->set_retval_index_for_sharding(ret_index);
214     }
215   }
216   // Update the metadata of the compile op.
217   compile.setAttr("metadata", StringAttr::get(compile.getContext(),
218                                               metadata.SerializeAsString()));
219   return mapping;
220 }
221 
222 // Adds a new replicated input to the replicate op.
AddInputsToReplicateOp(tf_device::ReplicateOp replicate,MutableArrayRef<TF::VarHandleOp> new_inputs,const llvm::SmallDenseMap<llvm::StringRef,llvm::SmallVector<StringRef,4>> & devices)223 tf_device::ReplicateOp AddInputsToReplicateOp(
224     tf_device::ReplicateOp replicate,
225     MutableArrayRef<TF::VarHandleOp> new_inputs,
226     const llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>&
227         devices) {
228   int64_t num_replicas = replicate.n();
229   assert(new_inputs.size() == num_replicas);
230 
231   // As model parallelism is not yet supported, we assume that all ops are
232   // placed in logical core 0.
233   // TODO(b/148913020): Remove this constraint once model parallelism is
234   // supported.
235   assert(devices.find(tensorflow::GetDeviceAliasForLogicalCore(0))
236              ->getSecond()
237              .size() == num_replicas);
238 
239   llvm::SmallVector<std::pair<ValueRange, Type>, 8> new_replicated_inputs;
240   llvm::SmallVector<Value, 8> new_packed_inputs;
241   llvm::SmallVector<llvm::SmallVector<Value, 8>, 8> replicated_inputs;
242   replicated_inputs.reserve(replicate.GetNumReplicatedBlockArguments());
243   new_packed_inputs.reserve(replicate.GetNumPackedBlockArguments());
244   for (const auto& arg : replicate.GetReplicatedBlockArguments()) {
245     replicated_inputs.emplace_back();
246     for (int64_t i = 0; i < num_replicas; ++i) {
247       replicated_inputs.back().push_back(
248           replicate.GetReplicaOperandForBlockArgument(arg, i));
249     }
250     new_replicated_inputs.emplace_back(replicated_inputs.back(), arg.getType());
251   }
252   for (const auto& arg : replicate.GetPackedBlockArguments()) {
253     new_packed_inputs.emplace_back(
254         replicate.GetReplicaOperandForBlockArgument(arg, /*replica=*/0));
255   }
256   SmallVector<Value, 4> new_input_values;
257   new_input_values.reserve(new_inputs.size());
258   for (auto var : new_inputs) new_input_values.push_back(var.resource());
259   new_replicated_inputs.emplace_back(new_input_values,
260                                      new_input_values.front().getType());
261   OpBuilder builder(replicate);
262   auto new_replicate = builder.create<tf_device::ReplicateOp>(
263       replicate.getLoc(), num_replicas, devices, new_replicated_inputs,
264       new_packed_inputs,
265       replicate.GetBody().getTerminator()->getOperandTypes());
266   for (auto arg : replicate.GetBody().getArguments()) {
267     if (replicate.IsReplicatedBlockArgument(arg)) {
268       arg.replaceAllUsesWith(
269           new_replicate.GetBody().getArgument(arg.getArgNumber()));
270     } else {
271       // There is a new added replicated state variable between replicated args
272       // and packed args.
273       arg.replaceAllUsesWith(
274           new_replicate.GetBody().getArgument(arg.getArgNumber() + 1));
275     }
276   }
277   for (auto& op : llvm::make_early_inc_range(replicate.GetBody())) {
278     op.moveBefore(&new_replicate.GetBody(), new_replicate.GetBody().end());
279   }
280   replicate.replaceAllUsesWith(new_replicate);
281   replicate.erase();
282   return new_replicate;
283 }
284 
285 // Creates the per-device variables that represent the formatting state of each
286 // device.
CreateStateVars(const llvm::SmallDenseMap<llvm::StringRef,llvm::SmallVector<StringRef,4>> & devices,Location loc,RankedTensorType key_type,OpBuilder * builder)287 llvm::SmallVector<TF::VarHandleOp, 4> CreateStateVars(
288     const llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>&
289         devices,
290     Location loc, RankedTensorType key_type, OpBuilder* builder) {
291   llvm::SmallVector<TF::VarHandleOp, 4> state_vars;
292 
293   // TODO(b/148913020): Remove this constraint once model parallelism is
294   // supported.
295   const auto& device_list =
296       devices.find(tensorflow::GetDeviceAliasForLogicalCore(0))->getSecond();
297 
298   // Create the state variable for each device.
299   for (llvm::StringRef device : device_list) {
300     state_vars.push_back(builder->create<TF::VarHandleOp>(
301         loc,
302         llvm::ArrayRef<Type>{RankedTensorType::get(
303             {}, TF::ResourceType::get(llvm::ArrayRef<TensorType>{key_type},
304                                       builder->getContext()))},
305         llvm::ArrayRef<Value>{},
306         llvm::ArrayRef<NamedAttribute>{
307             builder->getNamedAttr(kDeviceAttr, builder->getStringAttr(device)),
308             builder->getNamedAttr("container", builder->getStringAttr("")),
309             builder->getNamedAttr(
310                 "shared_name",
311                 builder->getStringAttr(GetRandomStateVariableName()))}));
312   }
313   return state_vars;
314 }
315 
316 // Wraps single op in `tf_device.launch` for explicit device assignment.
WrapOpInLaunch(OpBuilder * builder,Location loc,Operation * op,llvm::StringRef device)317 void WrapOpInLaunch(OpBuilder* builder, Location loc, Operation* op,
318                     llvm::StringRef device) {
319   OpBuilder::InsertPoint insert_point = builder->saveInsertionPoint();
320 
321   auto launch = builder->create<tf_device::LaunchOp>(
322       loc, builder->getStringAttr(device), op->getResultTypes());
323   launch.body().push_back(new Block);
324 
325   builder->setInsertionPointToEnd(&launch.GetBody());
326   builder->create<tf_device::ReturnOp>(loc, op->getResults());
327 
328   // Move op inside launch.
329   op->moveBefore(launch.GetBody().getTerminator());
330 
331   builder->restoreInsertionPoint(insert_point);
332 }
333 
334 // Performs the transformation for a replicate op inside a while loop.
HandleReplicateOp(TF::WhileRegionOp while_op,tf_device::ReplicateOp replicate)335 void HandleReplicateOp(TF::WhileRegionOp while_op,
336                        tf_device::ReplicateOp replicate) {
337   int64_t num_replicas = replicate.n();
338   if (num_replicas == 1) return;
339   tf_device::LaunchOp execute_launch;
340   for (auto execute_launch_op :
341        replicate.GetBody().getOps<tf_device::LaunchOp>()) {
342     if (!execute_launch_op.WrapsSingleOp() ||
343         !llvm::isa<TF::TPUExecuteAndUpdateVariablesOp>(
344             execute_launch_op.GetBody().front()))
345       continue;
346 
347     if (execute_launch == nullptr) {
348       execute_launch = execute_launch_op;
349     } else {
350       // We only support one execute op inside replicate.
351       execute_launch = nullptr;
352       break;
353     }
354   }
355   if (!execute_launch) return;
356   auto execute = llvm::cast<TF::TPUExecuteAndUpdateVariablesOp>(
357       execute_launch.GetBody().front());
358   auto compile =
359       SkipIdentity(execute.key(), /*allow_other_use=*/true).getDefiningOp();
360   if (!compile) return;
361   auto compile_launch = llvm::dyn_cast<tf_device::LaunchOp>(compile);
362   if (!compile_launch || !compile_launch.WrapsSingleOp() ||
363       !llvm::isa<TF::_TPUCompileMlirOp>(compile_launch.GetBody().front()))
364     return;
365 
366   // Analyze the formattable inputs.
367   auto execute_arg_to_outer_args =
368       AnnotateCompileOpAndGetExecuteArgToWhileArgsMapping(
369           while_op, replicate, execute, compile_launch);
370   if (execute_arg_to_outer_args.empty()) return;
371 
372   // Extract the replicated devices.
373   auto devices_attr = replicate.devices();
374   if (!devices_attr) return;
375 
376   auto device_map = devices_attr.getValue();
377   llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>> devices;
378   devices.reserve(device_map.size());
379 
380   for (auto it : device_map) {
381     auto device_alias = it.getName().strref();
382     auto device_list = it.getValue().cast<ArrayAttr>();
383     llvm::SmallVector<StringRef, 4> device_list_for_alias;
384     device_list_for_alias.reserve(device_list.size());
385 
386     for (auto device : device_list)
387       device_list_for_alias.emplace_back(device.cast<StringAttr>().getValue());
388 
389     devices.insert({device_alias, device_list_for_alias});
390   }
391 
392   OpBuilder builder(replicate);
393   builder.setInsertionPoint(while_op);
394   // Create per-device variables for formatting state, and add them to the while
395   // loop.
396   auto key_type =
397       RankedTensorType::get({2}, TF::StringType::get(builder.getContext()));
398   auto state_vars =
399       CreateStateVars(devices, while_op.getLoc(), key_type, &builder);
400   replicate = AddInputsToReplicateOp(replicate, state_vars, devices);
401   // Build the reformat according to the compilation. Build it inside
402   // `replicate`.
403   llvm::SmallVector<Value, 8> reformat_operands;
404   for (const auto& entry : execute_arg_to_outer_args) {
405     reformat_operands.push_back(execute.args()[entry.first]);
406   }
407   reformat_operands.push_back(compile_launch.getResult(1));
408   reformat_operands.push_back(replicate.GetBody().getArgument(
409       replicate.GetNumReplicatedBlockArguments() - 1));
410   builder.setInsertionPoint(execute_launch);
411   auto reformat_op = builder.create<TF::TPUReshardVariablesOp>(
412       execute_launch.getLoc(), llvm::ArrayRef<Type>{}, reformat_operands);
413   WrapOpInLaunch(&builder, execute_launch.getLoc(), reformat_op,
414                  execute_launch.device());
415 
416   // Build the replicated unformat op after the loop. First prepare building the
417   // replicate op.
418   llvm::SmallVector<std::pair<ValueRange, Type>, 8> unformat_replicate_operands;
419   llvm::SmallVector<Value, 8> unformat_packed_operands;
420   for (const auto& entry : execute_arg_to_outer_args) {
421     if (entry.second.size() > 1) {
422       unformat_replicate_operands.emplace_back(entry.second,
423                                                entry.second.front().getType());
424     } else {
425       unformat_packed_operands.emplace_back(entry.second.front());
426     }
427   }
428   llvm::SmallVector<Value, 4> state_var_vals(state_vars.size());
429   for (const auto& entry : llvm::enumerate(state_vars)) {
430     state_var_vals[entry.index()] = entry.value().resource();
431   }
432   // Add the replicated state var to the end of the replicate operands.
433   unformat_replicate_operands.emplace_back(state_var_vals,
434                                            state_var_vals.front().getType());
435   // Build a constant default key to specify that the unformatting should
436   // transform the variables to the original format.
437   builder.setInsertionPointAfter(while_op);
438   tensorflow::Tensor default_key_tensor(tensorflow::DT_STRING, {3});
439   default_key_tensor.vec<tensorflow::tstring>()(0) = kDefaultShardingValue;
440   default_key_tensor.vec<tensorflow::tstring>()(1) = kDefaultShardingValue;
441   default_key_tensor.vec<tensorflow::tstring>()(2) = kDefaultShardingValue;
442   auto default_state_key = builder.create<TF::ConstOp>(
443       while_op.getLoc(),
444       tensorflow::ConvertTensor(default_key_tensor, &builder).ValueOrDie());
445   // With all replicated inputs, now build the replicate op.
446   auto unformat_replicate = builder.create<tf_device::ReplicateOp>(
447       while_op.getLoc(), num_replicas, devices, unformat_replicate_operands,
448       unformat_packed_operands, TypeRange{});
449   // Then build the unformat op in the replicate op.
450   builder.setInsertionPointToEnd(&unformat_replicate.GetBody());
451   llvm::SmallVector<Value, 8> unformat_operands;
452   // Add the replicated state var (the last replicated operand of the
453   // ReplicateOp) as the last operand of TPUReshardVariablesOp.
454   BlockArgument state = unformat_replicate.GetReplicatedBlockArguments().back();
455   auto replicated_block_args =
456       unformat_replicate.GetReplicatedBlockArguments().drop_back(1);
457   auto packed_block_args = unformat_replicate.GetPackedBlockArguments();
458   unformat_operands.append(replicated_block_args.begin(),
459                            replicated_block_args.end());
460   unformat_operands.append(packed_block_args.begin(), packed_block_args.end());
461   unformat_operands.push_back(state);
462 
463   // Insert the default key as the second last operand.
464   unformat_operands.insert(
465       unformat_operands.begin() + unformat_operands.size() - 1,
466       default_state_key.getResult());
467   // Unformat op.
468   auto unformat_op = builder.create<TF::TPUReshardVariablesOp>(
469       while_op.getLoc(), llvm::ArrayRef<Type>{}, unformat_operands);
470   WrapOpInLaunch(&builder, execute_launch.getLoc(), unformat_op,
471                  execute_launch.device());
472   builder.create<tf_device::ReturnOp>(while_op.getLoc(), ArrayRef<Value>{});
473 }
474 
runOnOperation()475 void TPUVariableRuntimeReformattingPass::runOnOperation() {
476   auto module = getOperation();
477   module.walk([&](TF::WhileRegionOp while_op) {
478     tf_device::ReplicateOp replicate;
479     while_op.body().walk([&](tf_device::ReplicateOp replicate_op) {
480       if (replicate == nullptr) {
481         replicate = replicate_op;
482         return WalkResult::advance();
483       }
484       // We do not handle loops with multiple replicate ops.
485       replicate = nullptr;
486       return WalkResult::interrupt();
487     });
488     // Model parallelism is not supported, and can be detected when a
489     // `tf_device.parallel_execute` op in the `tf_device.replicate` is present.
490     if (replicate &&
491         replicate.GetBody().getOps<tf_device::ParallelExecuteOp>().empty())
492       HandleReplicateOp(while_op, replicate);
493   });
494 }
495 
496 }  // namespace
497 
498 std::unique_ptr<OperationPass<ModuleOp>>
CreateTPUVariableRuntimeReformattingPass()499 CreateTPUVariableRuntimeReformattingPass() {
500   return std::make_unique<TPUVariableRuntimeReformattingPass>();
501 }
502 
503 }  // namespace TFTPU
504 }  // namespace mlir
505