/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
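
// This pass rewrites `tf.TPUReplicatedInput` ops whose operands are all
// produced by `tf.TPUPartitionedInput` ops so that partitioning happens after
// replication. For example, with two replicas and two cores per replica
// (IR sketch; value names are illustrative, not from the source):
//
//   %pi_0 = "tf.TPUPartitionedInput"(%arg0, %arg1)  // replica 0
//   %pi_1 = "tf.TPUPartitionedInput"(%arg2, %arg3)  // replica 1
//   %ri = "tf.TPUReplicatedInput"(%pi_0, %pi_1)
//
// becomes:
//
//   %ri_0 = "tf.TPUReplicatedInput"(%arg0, %arg2)  // core 0
//   %ri_1 = "tf.TPUReplicatedInput"(%arg1, %arg3)  // core 1
//   %pi = "tf.TPUPartitionedInput"(%ri_0, %ri_1)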

#include <cstddef>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/StringRef.h"
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"

namespace mlir {
namespace TFTPU {
namespace {

struct TPUReorderReplicateAndPartitionedInputsPass
    : public TF::TPUReorderReplicateAndPartitionedInputsPassBase<
          TPUReorderReplicateAndPartitionedInputsPass> {
  void runOnOperation() override;
};

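// Rewrites one `tf.TPUReplicatedInput` op as described in the header comment.
// Emits an op error and fails if any operand is not produced by a
// `tf.TPUPartitionedInput` op, or if the feeding `tf.TPUPartitionedInput` ops
// disagree on `_XlaSharding`, `partition_dim`, or operand count.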
LogicalResult ReorderReplicateAndPartitionedInputs(
    TF::TPUReplicatedInputOp replicated_input) {
  // Every operand of the replicated input must come from a
  // `tf.TPUPartitionedInput` op.
  if (!llvm::all_of(replicated_input.inputs(), [](Value input) {
        return llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(
            input.getDefiningOp());
      }))
    return replicated_input.emitOpError()
           << "expects all inputs from 'tf.TPUPartitionedInput' ops";

  // Use the first `tf.TPUPartitionedInput` op as the reference for the
  // attributes that all of them must share.
  auto first_partitioned_input = llvm::cast<TF::TPUPartitionedInputOp>(
      replicated_input.getOperand(0).getDefiningOp());
  llvm::Optional<::llvm::StringRef> xla_sharding =
      first_partitioned_input._XlaSharding();
  int64_t partition_dim = first_partitioned_input.partition_dim();
  size_t num_cores_per_replica = first_partitioned_input.getNumOperands();

  for (auto operand : replicated_input.inputs().drop_front()) {
    auto partitioned_input =
        llvm::cast<TF::TPUPartitionedInputOp>(operand.getDefiningOp());
    llvm::Optional<::llvm::StringRef> op_xla_sharding =
        partitioned_input._XlaSharding();
    int64_t op_partition_dim = partitioned_input.partition_dim();
    // Abort if the `tf.TPUPartitionedInput` ops do not all have the same
    // attributes.
    if (partition_dim != op_partition_dim)
      return partitioned_input->emitOpError()
             << "expects partition_dim = " << partition_dim << " but found "
             << op_partition_dim;
    if (partitioned_input.getNumOperands() != num_cores_per_replica)
      return partitioned_input->emitOpError()
             << "expects " << num_cores_per_replica << " operands but found "
             << partitioned_input.getNumOperands();
    if (xla_sharding != op_xla_sharding)
      return replicated_input.emitOpError()
             << "expects all inputs from 'tf.TPUPartitionedInput' ops to have "
                "identical XLA sharding";
  }

  // 2D matrix storing the per-core, per-replica operands. Its dimensions are
  // num_cores_per_replica x num_replicas: the i-th row holds the operands for
  // the i-th core, and the j-th column holds the operands for the j-th
  // replica.
  llvm::SmallVector<llvm::SmallVector<Value, 4>, 4>
      operands_per_replica_per_core;
  operands_per_replica_per_core.resize(num_cores_per_replica);

  // Collect all operands in the 2D matrix. The operand number of a value
  // within a `tf.TPUPartitionedInput` op is its core index.
  for (auto operand : replicated_input.inputs()) {
    auto pi = llvm::cast<TF::TPUPartitionedInputOp>(operand.getDefiningOp());
    for (auto& pi_operand : pi->getOpOperands()) {
      unsigned core_id = pi_operand.getOperandNumber();
      operands_per_replica_per_core[core_id].push_back(pi_operand.get());
    }
  }
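  // For the two-replica, two-core example in the header comment (value names
  // illustrative), the matrix is now:
  //   operands_per_replica_per_core[0] = {%arg0, %arg2}  // core 0
  //   operands_per_replica_per_core[1] = {%arg1, %arg3}  // core 1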

  // Create new `tf.TPUReplicatedInput` ops feeding into one
  // `tf.TPUPartitionedInput` op.
  OpBuilder builder(replicated_input);
  llvm::SmallVector<Value, 4> operands_per_core;
  for (const auto& operands_per_replica : operands_per_replica_per_core) {
    auto replicate_op = builder.create<TF::TPUReplicatedInputOp>(
        replicated_input.getLoc(), replicated_input.getType(),
        operands_per_replica, replicated_input->getAttrs());
    operands_per_core.push_back(replicate_op);
  }

  auto pi = builder.create<TF::TPUPartitionedInputOp>(
      first_partitioned_input.getLoc(), replicated_input.getType(),
      operands_per_core, first_partitioned_input->getAttrs());
  replicated_input.replaceAllUsesWith(pi.output());
  return success();
}

void TPUReorderReplicateAndPartitionedInputsPass::runOnOperation() {
  auto result =
      getOperation()->walk([](TF::TPUReplicatedInputOp replicated_input) {
        // Skip replicated inputs that take no `tf.TPUPartitionedInput`
        // operands at all; a mix of partitioned and unpartitioned operands is
        // rejected inside ReorderReplicateAndPartitionedInputs.
        if (llvm::none_of(replicated_input.inputs(), [](Value input) {
              return llvm::isa_and_nonnull<TF::TPUPartitionedInputOp>(
                  input.getDefiningOp());
            }))
          return WalkResult::advance();
        if (failed(ReorderReplicateAndPartitionedInputs(replicated_input)))
          return WalkResult::interrupt();

        assert(replicated_input->use_empty());
        replicated_input->erase();
        return WalkResult::advance();
      });

  if (result.wasInterrupted()) {
    signalPassFailure();
    return;
  }

  // Remove the original `tf.TPUPartitionedInput` ops, which are now dead.
  getOperation()->walk([](TF::TPUPartitionedInputOp partitioned_input) {
    if (partitioned_input->use_empty()) partitioned_input->erase();
  });
}

}  // namespace

std::unique_ptr<OperationPass<func::FuncOp>>
CreateTPUReorderReplicateAndPartitionedInputsPass() {
  return std::make_unique<TPUReorderReplicateAndPartitionedInputsPass>();
}

}  // namespace TFTPU
}  // namespace mlir