1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "llvm/ADT/SmallVector.h"
17 #include "llvm/ADT/StringMap.h"
18 #include "llvm/Support/Casting.h"
19 #include "mlir/IR/Attributes.h"  // from @llvm-project
20 #include "mlir/IR/Block.h"  // from @llvm-project
21 #include "mlir/IR/Builders.h"  // from @llvm-project
22 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
23 #include "mlir/IR/Operation.h"  // from @llvm-project
24 #include "mlir/IR/Value.h"  // from @llvm-project
25 #include "mlir/IR/Visitors.h"  // from @llvm-project
26 #include "mlir/Pass/Pass.h"  // from @llvm-project
27 #include "mlir/Pass/PassRegistry.h"  // from @llvm-project
28 #include "mlir/Support/LLVM.h"  // from @llvm-project
29 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
30 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
31 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
32 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
33 #include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
34 
35 namespace mlir {
36 namespace TFTPU {
37 namespace {
38 
// Attribute name read from an embedding enqueue op; when present, its value is
// copied onto the mode constant this pass creates for that op.
constexpr char kXlaOutsideCompilationAttr[]{"_xla_outside_compilation"};
// Attribute name that tags TPU embedding ops with the name of the embedding
// layer they belong to. Its string value is the key used to group related ops.
constexpr char kTPUEmbeddingAttr[]{"_tpu_embedding_layer"};
41 
42 struct TPUUpdateEmbeddingEnqueueOpInputsPass
43     : public TF::TPUUpdateEmbeddingEnqueueOpInputsPassBase<
44           TPUUpdateEmbeddingEnqueueOpInputsPass> {
45   void runOnOperation() override;
46 };
47 
48 // Extracts `_tpu_embedding_layer` attribute from TPU embedding ops and
49 // clear the attribute from the operation. This ensures that future optimization
50 // passes does not trigger additional logic due to presence of this attribute.
ExtractEmbeddingAttribute(Operation * op,llvm::StringMap<Operation * > * embedding_op_map)51 LogicalResult ExtractEmbeddingAttribute(
52     Operation* op, llvm::StringMap<Operation*>* embedding_op_map) {
53   auto embedding_attr = op->getAttrOfType<StringAttr>(kTPUEmbeddingAttr);
54   if (!embedding_attr) return mlir::success();
55 
56   if (!embedding_op_map->insert({embedding_attr.getValue(), op}).second)
57     return op->emitOpError(
58         "found duplicate TPU embedding ops potentially from multiple "
59         "TPUEmbedding layers");
60 
61   op->removeAttr(kTPUEmbeddingAttr);
62   return success();
63 }
64 
FindTPUEmbeddingOps(func::FuncOp func_op,llvm::StringMap<Operation * > * enqueue_op_map,llvm::StringMap<Operation * > * recv_activation_op_map,llvm::StringMap<Operation * > * send_gradient_op_map)65 LogicalResult FindTPUEmbeddingOps(
66     func::FuncOp func_op, llvm::StringMap<Operation*>* enqueue_op_map,
67     llvm::StringMap<Operation*>* recv_activation_op_map,
68     llvm::StringMap<Operation*>* send_gradient_op_map) {
69   auto walk_result = func_op.walk([&](Operation* op) {
70     if (llvm::isa<TF::RecvTPUEmbeddingActivationsOp>(op))
71       if (failed(ExtractEmbeddingAttribute(op, recv_activation_op_map)))
72         return WalkResult::interrupt();
73 
74     if (llvm::isa<TF::SendTPUEmbeddingGradientsOp>(op))
75       if (failed(ExtractEmbeddingAttribute(op, send_gradient_op_map)))
76         return WalkResult::interrupt();
77 
78     if (llvm::isa<TF::EnqueueTPUEmbeddingSparseTensorBatchOp,
79                   TF::EnqueueTPUEmbeddingRaggedTensorBatchOp,
80                   TF::EnqueueTPUEmbeddingArbitraryTensorBatchOp>(op))
81       if (failed(ExtractEmbeddingAttribute(op, enqueue_op_map)))
82         return WalkResult::interrupt();
83 
84     return WalkResult::advance();
85   });
86   return failure(walk_result.wasInterrupted());
87 }
88 
89 // Updates the operand of TPU embedding enqueue ops depending on whether
90 // the graph is in training mode or in non-training mode.
91 // If SendTPUEmbeddingGradients op is present, this means that graph is in
92 // training mode. As so, correctly feed in `then` branch value of SelectV2
93 // operand as inputs to the TPU embedding enqueue ops.
UpdateEmbeddingEnqueueOpInput(const llvm::StringMap<Operation * > & enqueue_op_map,const llvm::StringMap<Operation * > & recv_activation_op_map,const llvm::StringMap<Operation * > & send_gradient_op_map,OpBuilder * builder)94 LogicalResult UpdateEmbeddingEnqueueOpInput(
95     const llvm::StringMap<Operation*>& enqueue_op_map,
96     const llvm::StringMap<Operation*>& recv_activation_op_map,
97     const llvm::StringMap<Operation*>& send_gradient_op_map,
98     OpBuilder* builder) {
99   for (const auto& it : enqueue_op_map) {
100     const auto& embedding_attr = it.getKey();
101     Operation* embedding_op = it.second;
102     if (!recv_activation_op_map.count(embedding_attr))
103       return embedding_op->emitOpError()
104              << "must have a corresponding '"
105              << TF::RecvTPUEmbeddingActivationsOp::getOperationName() << "' op";
106 
107     // TPU Embedding enqueue ops take different inputs depending on whether
108     // graph is in training mode or in eval/prediction mode. During training,
109     // the mode parameter for TPUEmbeddingEnqueue op must be `train` and for
110     // evaluation or prediction, mode must be set to `inference`.
111     // If SendTPUEmbeddingGradients op exists in the graph, then graph is
112     // in training mode, so create a const op with value `train` use the
113     // output value of the constant as an operand to the TPU embedding
114     // enqueue op.
115     bool is_training = send_gradient_op_map.count(embedding_attr);
116 
117     // The last operand of TPUEmbeddingEnqueue ops is the mode which
118     // represents whether graph is in training mode or in evaluation mode.
119     auto& mode_enqueue_operand =
120         embedding_op->getOpOperand(embedding_op->getNumOperands() - 1);
121 
122     llvm::SmallVector<StringRef, 1> mode_string_value;
123     mode_string_value.emplace_back(is_training ? "train" : "inference");
124     builder->setInsertionPoint(embedding_op);
125     auto enqueue_mode = builder->create<TF::ConstOp>(
126         embedding_op->getLoc(),
127         DenseStringElementsAttr::get(
128             RankedTensorType::get({}, builder->getType<TF::StringType>()),
129             mode_string_value));
130 
131     auto outside_compilation_attr =
132         embedding_op->getAttrOfType<StringAttr>(kXlaOutsideCompilationAttr);
133     if (outside_compilation_attr)
134       enqueue_mode->setAttr(kXlaOutsideCompilationAttr,
135                             outside_compilation_attr);
136 
137     mode_enqueue_operand.set(enqueue_mode);
138   }
139 
140   return success();
141 }
142 
runOnOperation()143 void TPUUpdateEmbeddingEnqueueOpInputsPass::runOnOperation() {
144   OpBuilder builder(&getContext());
145   auto func_op = getOperation();
146 
147   // All TPU embedding layer related ops are annotated with
148   // `_tpu_embedding_layer` attribute along with corresponding string attribute.
149   // Store all tpu embedding layer related ops with value of
150   // `_tpu_embedding_layer` attribute as map key.
151   llvm::StringMap<Operation*> enqueue_op_map;
152   llvm::StringMap<Operation*> recv_activation_op_map;
153   llvm::StringMap<Operation*> send_gradient_op_map;
154   if (failed(FindTPUEmbeddingOps(func_op, &enqueue_op_map,
155                                  &recv_activation_op_map,
156                                  &send_gradient_op_map)))
157     return signalPassFailure();
158 
159   if (enqueue_op_map.size() != recv_activation_op_map.size()) {
160     func_op.emitError() << "expects the number of embedding enqueue ops to "
161                            "match the number of '"
162                         << TF::RecvTPUEmbeddingActivationsOp::getOperationName()
163                         << "' ops";
164     return signalPassFailure();
165   }
166 
167   if (failed(UpdateEmbeddingEnqueueOpInput(enqueue_op_map,
168                                            recv_activation_op_map,
169                                            send_gradient_op_map, &builder)))
170     return signalPassFailure();
171 }
172 
173 }  // anonymous namespace
174 
175 std::unique_ptr<OperationPass<func::FuncOp>>
CreateTPUUpdateEmbeddingEnqueueOpInputsPass()176 CreateTPUUpdateEmbeddingEnqueueOpInputsPass() {
177   return std::make_unique<TPUUpdateEmbeddingEnqueueOpInputsPass>();
178 }
179 
180 }  // namespace TFTPU
181 }  // namespace mlir
182