1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/ADT/StringMap.h"
18 #include "llvm/Support/Casting.h"
19 #include "mlir/IR/Attributes.h" // from @llvm-project
20 #include "mlir/IR/Block.h" // from @llvm-project
21 #include "mlir/IR/Builders.h" // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
23 #include "mlir/IR/Operation.h" // from @llvm-project
24 #include "mlir/IR/Value.h" // from @llvm-project
25 #include "mlir/IR/Visitors.h" // from @llvm-project
26 #include "mlir/Pass/Pass.h" // from @llvm-project
27 #include "mlir/Pass/PassRegistry.h" // from @llvm-project
28 #include "mlir/Support/LLVM.h" // from @llvm-project
29 #include "mlir/Support/LogicalResult.h" // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
33 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
34
35 namespace mlir {
36 namespace TFTPU {
37 namespace {
38
// Attribute name looked up on each enqueue op; when present, its value is
// copied onto the mode constant this pass creates for that op.
constexpr char kXlaOutsideCompilationAttr[] = "_xla_outside_compilation";
// String attribute naming the TPU embedding layer an op belongs to. Ops are
// grouped by its value, and the attribute is removed once it has been read.
constexpr char kTPUEmbeddingAttr[] = "_tpu_embedding_layer";
41
42 struct TPUUpdateEmbeddingEnqueueOpInputsPass
43 : public TF::TPUUpdateEmbeddingEnqueueOpInputsPassBase<
44 TPUUpdateEmbeddingEnqueueOpInputsPass> {
45 void runOnOperation() override;
46 };
47
48 // Extracts `_tpu_embedding_layer` attribute from TPU embedding ops and
49 // clear the attribute from the operation. This ensures that future optimization
50 // passes does not trigger additional logic due to presence of this attribute.
ExtractEmbeddingAttribute(Operation * op,llvm::StringMap<Operation * > * embedding_op_map)51 LogicalResult ExtractEmbeddingAttribute(
52 Operation* op, llvm::StringMap<Operation*>* embedding_op_map) {
53 auto embedding_attr = op->getAttrOfType<StringAttr>(kTPUEmbeddingAttr);
54 if (!embedding_attr) return mlir::success();
55
56 if (!embedding_op_map->insert({embedding_attr.getValue(), op}).second)
57 return op->emitOpError(
58 "found duplicate TPU embedding ops potentially from multiple "
59 "TPUEmbedding layers");
60
61 op->removeAttr(kTPUEmbeddingAttr);
62 return success();
63 }
64
FindTPUEmbeddingOps(func::FuncOp func_op,llvm::StringMap<Operation * > * enqueue_op_map,llvm::StringMap<Operation * > * recv_activation_op_map,llvm::StringMap<Operation * > * send_gradient_op_map)65 LogicalResult FindTPUEmbeddingOps(
66 func::FuncOp func_op, llvm::StringMap<Operation*>* enqueue_op_map,
67 llvm::StringMap<Operation*>* recv_activation_op_map,
68 llvm::StringMap<Operation*>* send_gradient_op_map) {
69 auto walk_result = func_op.walk([&](Operation* op) {
70 if (llvm::isa<TF::RecvTPUEmbeddingActivationsOp>(op))
71 if (failed(ExtractEmbeddingAttribute(op, recv_activation_op_map)))
72 return WalkResult::interrupt();
73
74 if (llvm::isa<TF::SendTPUEmbeddingGradientsOp>(op))
75 if (failed(ExtractEmbeddingAttribute(op, send_gradient_op_map)))
76 return WalkResult::interrupt();
77
78 if (llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
79 TF::EnqueueTPUEmbeddingRaggedTensorBatchOp,
80 TF::EnqueueTPUEmbeddingArbitraryTensorBatchOp>(op))
81 if (failed(ExtractEmbeddingAttribute(op, enqueue_op_map)))
82 return WalkResult::interrupt();
83
84 return WalkResult::advance();
85 });
86 return failure(walk_result.wasInterrupted());
87 }
88
89 // Updates the operand of TPU embedding enqueue ops depending on whether
90 // the graph is in training mode or in non-training mode.
91 // If SendTPUEmbeddingGradients op is present, this means that graph is in
92 // training mode. As so, correctly feed in `then` branch value of SelectV2
93 // operand as inputs to the TPU embedding enqueue ops.
UpdateEmbeddingEnqueueOpInput(const llvm::StringMap<Operation * > & enqueue_op_map,const llvm::StringMap<Operation * > & recv_activation_op_map,const llvm::StringMap<Operation * > & send_gradient_op_map,OpBuilder * builder)94 LogicalResult UpdateEmbeddingEnqueueOpInput(
95 const llvm::StringMap<Operation*>& enqueue_op_map,
96 const llvm::StringMap<Operation*>& recv_activation_op_map,
97 const llvm::StringMap<Operation*>& send_gradient_op_map,
98 OpBuilder* builder) {
99 for (const auto& it : enqueue_op_map) {
100 const auto& embedding_attr = it.getKey();
101 Operation* embedding_op = it.second;
102 if (!recv_activation_op_map.count(embedding_attr))
103 return embedding_op->emitOpError()
104 << "must have a corresponding '"
105 << TF::RecvTPUEmbeddingActivationsOp::getOperationName() << "' op";
106
107 // TPU Embedding enqueue ops take different inputs depending on whether
108 // graph is in training mode or in eval/prediction mode. During training,
109 // the mode parameter for TPUEmbeddingEnqueue op must be `train` and for
110 // evaluation or prediction, mode must be set to `inference`.
111 // If SendTPUEmbeddingGradients op exists in the graph, then graph is
112 // in training mode, so create a const op with value `train` use the
113 // output value of the constant as an operand to the TPU embedding
114 // enqueue op.
115 bool is_training = send_gradient_op_map.count(embedding_attr);
116
117 // The last operand of TPUEmbeddingEnqueue ops is the mode which
118 // represents whether graph is in training mode or in evaluation mode.
119 auto& mode_enqueue_operand =
120 embedding_op->getOpOperand(embedding_op->getNumOperands() - 1);
121
122 llvm::SmallVector<StringRef, 1> mode_string_value;
123 mode_string_value.emplace_back(is_training ? "train" : "inference");
124 builder->setInsertionPoint(embedding_op);
125 auto enqueue_mode = builder->create<TF::ConstOp>(
126 embedding_op->getLoc(),
127 DenseStringElementsAttr::get(
128 RankedTensorType::get({}, builder->getType<TF::StringType>()),
129 mode_string_value));
130
131 auto outside_compilation_attr =
132 embedding_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr);
133 if (outside_compilation_attr)
134 enqueue_mode->setAttr(kXlaOutsideCompilationAttr,
135 outside_compilation_attr);
136
137 mode_enqueue_operand.set(enqueue_mode);
138 }
139
140 return success();
141 }
142
runOnOperation()143 void TPUUpdateEmbeddingEnqueueOpInputsPass::runOnOperation() {
144 OpBuilder builder(&getContext());
145 auto func_op = getOperation();
146
147 // All TPU embedding layer related ops are annotated with
148 // `_tpu_embedding_layer` attribute along with corresponding string attribute.
149 // Store all tpu embedding layer related ops with value of
150 // `_tpu_embedding_layer` attribute as map key.
151 llvm::StringMap<Operation*> enqueue_op_map;
152 llvm::StringMap<Operation*> recv_activation_op_map;
153 llvm::StringMap<Operation*> send_gradient_op_map;
154 if (failed(FindTPUEmbeddingOps(func_op, &enqueue_op_map,
155 &recv_activation_op_map,
156 &send_gradient_op_map)))
157 return signalPassFailure();
158
159 if (enqueue_op_map.size() != recv_activation_op_map.size()) {
160 func_op.emitError() << "expects the number of embedding enqueue ops to "
161 "match the number of '"
162 << TF::RecvTPUEmbeddingActivationsOp::getOperationName()
163 << "' ops";
164 return signalPassFailure();
165 }
166
167 if (failed(UpdateEmbeddingEnqueueOpInput(enqueue_op_map,
168 recv_activation_op_map,
169 send_gradient_op_map, &builder)))
170 return signalPassFailure();
171 }
172
173 } // anonymous namespace
174
175 std::unique_ptr<OperationPass<func::FuncOp>>
CreateTPUUpdateEmbeddingEnqueueOpInputsPass()176 CreateTPUUpdateEmbeddingEnqueueOpInputsPass() {
177 return std::make_unique<TPUUpdateEmbeddingEnqueueOpInputsPass>();
178 }
179
180 } // namespace TFTPU
181 } // namespace mlir
182