/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <cstddef>
#include <cstdint>

#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"

namespace mlir {
namespace TFTPU {
namespace {

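// Reorders `tf.TPUReplicatedInput` ops whose operands are all produced by
// `tf.TPUPartitionedInput` ops: the per-replica partitioned inputs feeding one
// replicated input are rewritten into per-core replicated inputs feeding a
// single partitioned input.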
struct TPUReorderReplicateAndPartitionedInputsPass
    : public TF::TPUReorderReplicateAndPartitionedInputsPassBase<
          TPUReorderReplicateAndPartitionedInputsPass> {
  void runOnOperation() override;
};

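// Rewrites a single `tf.TPUReplicatedInput` op whose operands all come from
// `tf.TPUPartitionedInput` ops by swapping the nesting of the two op kinds.
// As a sketch (value names and the 2 replica x 2 core layout are
// illustrative):
//
//   %pi0 = "tf.TPUPartitionedInput"(%r0_core0, %r0_core1)
//   %pi1 = "tf.TPUPartitionedInput"(%r1_core0, %r1_core1)
//   %ri  = "tf.TPUReplicatedInput"(%pi0, %pi1)
//
// becomes
//
//   %ri0 = "tf.TPUReplicatedInput"(%r0_core0, %r1_core0)
//   %ri1 = "tf.TPUReplicatedInput"(%r0_core1, %r1_core1)
//   %pi  = "tf.TPUPartitionedInput"(%ri0, %ri1)
//
// Fails if any operand is not produced by a `tf.TPUPartitionedInput` op or if
// the producers disagree on partition_dim, operand count, or `_XlaSharding`.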
LogicalResult ReorderReplicateAndPartitionedInputs(
    TF::TPUReplicatedInputOp replicated_input) {
  if (!llvm::all_of(replicated_input.inputs(), [](Value input) {
        return llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(
            input.getDefiningOp());
      }))
    return replicated_input.emitOpError()
           << "expects all inputs from 'tf.TPUPartitionedInput' ops";

  auto first_partitioned_input = llvm::cast<TF::TPUPartitionedInputOp>(
      replicated_input.getOperand(0).getDefiningOp());
  llvm::Optional<::llvm::StringRef> xla_sharding =
      first_partitioned_input._XlaSharding();
  int64_t partition_dim = first_partitioned_input.partition_dim();
  size_t num_cores_per_replica = first_partitioned_input.getNumOperands();

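  // Every remaining `tf.TPUPartitionedInput` producer must match the first
  // one; mixed attributes cannot be reordered.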
  for (auto operand : replicated_input.inputs().drop_front()) {
    auto partitioned_input =
        llvm::cast<TF::TPUPartitionedInputOp>(operand.getDefiningOp());
    llvm::Optional<::llvm::StringRef> op_xla_sharding =
        partitioned_input._XlaSharding();
    int64_t op_partition_dim = partitioned_input.partition_dim();
    // Abort if TPUPartitionedInput(s) do not have the same attributes.
    if (partition_dim != op_partition_dim)
      return partitioned_input->emitOpError()
             << "expects partition_dim = " << partition_dim << " but found "
             << op_partition_dim;
    if (partitioned_input.getNumOperands() != num_cores_per_replica)
      return partitioned_input->emitOpError()
             << "expects " << num_cores_per_replica << " operands but found "
             << partitioned_input.getNumOperands();
    if (xla_sharding != op_xla_sharding)
      return replicated_input.emitOpError()
             << "expects all inputs from 'tf.TPUPartitionedInput' ops to have "
                "identical XLA sharding";
  }

  // 2D matrix to store per-core, per-replica operands. The matrix dimensions
  // are num_cores_per_replica x num_replicas. The i-th row holds the operands
  // for the i-th core; the j-th column holds the operands for the j-th
  // replica.
  llvm::SmallVector<llvm::SmallVector<Value, 4>, 4>
      operands_per_replica_per_core;
  operands_per_replica_per_core.resize(num_cores_per_replica);

  // Collect all operands in the 2D matrix.
  for (auto operand : replicated_input.inputs()) {
    auto pi = llvm::cast<TF::TPUPartitionedInputOp>(operand.getDefiningOp());
    for (auto& pi_operand : pi->getOpOperands()) {
      unsigned core_id = pi_operand.getOperandNumber();
      operands_per_replica_per_core[core_id].push_back(pi_operand.get());
    }
  }

  // Create new `tf.TPUReplicatedInput` ops feeding into one
  // `tf.TPUPartitionedInput` op.
  OpBuilder builder(replicated_input);
  llvm::SmallVector<Value, 4> operands_per_core;
  for (const auto& operands_per_replica : operands_per_replica_per_core) {
    auto replicate_op = builder.create<TF::TPUReplicatedInputOp>(
        replicated_input.getLoc(), replicated_input.getType(),
        operands_per_replica, replicated_input->getAttrs());
    operands_per_core.push_back(replicate_op);
  }

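  // Build one `tf.TPUPartitionedInput` op over the per-core replicated values
  // and redirect all uses of the original `tf.TPUReplicatedInput` op to it.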
  auto pi = builder.create<TF::TPUPartitionedInputOp>(
      first_partitioned_input.getLoc(), replicated_input.getType(),
      operands_per_core, first_partitioned_input->getAttrs());
  replicated_input.replaceAllUsesWith(pi.output());
  return success();
}

void TPUReorderReplicateAndPartitionedInputsPass::runOnOperation() {
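  // Reorder every `tf.TPUReplicatedInput` op that has at least one
  // `tf.TPUPartitionedInput` operand; ops with none are skipped, and a failed
  // reorder aborts the pass.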
  auto result =
      getOperation()->walk([](TF::TPUReplicatedInputOp replicated_input) {
        if (llvm::none_of(replicated_input.inputs(), [](Value input) {
              return llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(
                  input.getDefiningOp());
            }))
          return WalkResult::advance();
        if (failed(ReorderReplicateAndPartitionedInputs(replicated_input)))
          return WalkResult::interrupt();

        assert(replicated_input->use_empty());
        replicated_input->erase();
        return WalkResult::advance();
      });

  if (result.wasInterrupted()) {
    signalPassFailure();
    return;
  }

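  // Erase `tf.TPUPartitionedInput` ops left without uses by the reordering.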
  getOperation()->walk([](TF::TPUPartitionedInputOp partitioned_input) {
    if (partitioned_input->use_empty()) partitioned_input->erase();
  });
}

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
CreateTPUReorderReplicateAndPartitionedInputsPass() {
  return std::make_unique<TPUReorderReplicateAndPartitionedInputsPass>();
}

}  // namespace TFTPU
}  // namespace mlir