xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
17 
18 #include <algorithm>
19 
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/None.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Casting.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h"  // from @llvm-project
27 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
28 #include "mlir/Dialect/Tensor/IR/Tensor.h"  // from @llvm-project
29 #include "mlir/IR/Attributes.h"  // from @llvm-project
30 #include "mlir/IR/Builders.h"  // from @llvm-project
31 #include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
32 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
33 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
34 #include "mlir/IR/Location.h"  // from @llvm-project
35 #include "mlir/IR/MLIRContext.h"  // from @llvm-project
36 #include "mlir/IR/OpDefinition.h"  // from @llvm-project
37 #include "mlir/IR/Operation.h"  // from @llvm-project
38 #include "mlir/IR/Types.h"  // from @llvm-project
39 #include "mlir/IR/Value.h"  // from @llvm-project
40 #include "mlir/Support/LLVM.h"  // from @llvm-project
41 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
42 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
44 
45 namespace mlir {
46 namespace TFL {
47 
48 namespace {
49 
CreateI32SplatConst(OpBuilder * builder,ArrayRef<int64_t> shape,int32_t val,mlir::Location location)50 Value CreateI32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
51                           int32_t val, mlir::Location location) {
52   auto type = RankedTensorType::get(shape, builder->getIntegerType(32));
53   auto attr = DenseElementsAttr::get(type, val);
54   return builder->create<arith::ConstantOp>(location, type, attr);
55 }
56 
CreateF32SplatConst(OpBuilder * builder,ArrayRef<int64_t> shape,float val,mlir::Location location)57 Value CreateF32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
58                           float val, mlir::Location location) {
59   auto type = RankedTensorType::get(shape, builder->getF32Type());
60   auto attr = DenseElementsAttr::get(type, val);
61   return builder->create<arith::ConstantOp>(location, type, attr);
62 }
63 
CreatTfF32ConstOp(OpBuilder * builder,ArrayRef<int64_t> shape,float val,mlir::Location location)64 Value CreatTfF32ConstOp(OpBuilder* builder, ArrayRef<int64_t> shape, float val,
65                         mlir::Location location) {
66   auto type = RankedTensorType::get(shape, builder->getF32Type());
67   auto ele_type = RankedTensorType::get({1}, builder->getF32Type());
68   auto attr = DenseElementsAttr::get(ele_type, val);
69   return builder->create<TF::ConstOp>(location, type, attr);
70 }
71 
CreateI64DenseConst(OpBuilder * builder,ArrayRef<int64_t> shape,ArrayRef<int64_t> values,mlir::Location location)72 Value CreateI64DenseConst(OpBuilder* builder, ArrayRef<int64_t> shape,
73                           ArrayRef<int64_t> values, mlir::Location location) {
74   auto type = RankedTensorType::get(static_cast<int>(shape.size()),
75                                     builder->getIntegerType(64));
76   auto attr = DenseElementsAttr::get(type, values);
77   return builder->create<arith::ConstantOp>(location, type, attr);
78 }
79 
CreateI32DenseConst(OpBuilder * builder,ArrayRef<int32_t> values,mlir::Location location)80 Value CreateI32DenseConst(OpBuilder* builder, ArrayRef<int32_t> values,
81                           mlir::Location location) {
82   auto type = RankedTensorType::get(static_cast<int>(values.size()),
83                                     builder->getIntegerType(32));
84   auto attr = DenseElementsAttr::get(type, values);
85   return builder->create<arith::ConstantOp>(location, type, attr);
86 }
87 
CreateNoneValue(OpBuilder * builder,mlir::Location location)88 Value CreateNoneValue(OpBuilder* builder, mlir::Location location) {
89   return builder->create<TFL::NoValueOp>(location, builder->getNoneType(),
90                                          builder->getUnitAttr());
91 }
92 
Transpose(OpBuilder * builder,Value value_to_transpose,SmallVector<int32_t,4> perm,RankedTensorType original_type,mlir::Location location)93 Value Transpose(OpBuilder* builder, Value value_to_transpose,
94                 SmallVector<int32_t, 4> perm, RankedTensorType original_type,
95                 mlir::Location location) {
96   // Create a constant op for transpose permutation.
97   auto perm_op = CreateI32DenseConst(builder, perm, location);
98 
99   // Create tensor type for the transpose result.
100   auto transpose_type = original_type;
101   auto transpose_shape =
102       llvm::to_vector<8>(llvm::map_range(perm, [transpose_type](int32_t dim) {
103         return transpose_type.getDimSize(dim);
104       }));
105   auto elem_type = transpose_type.getElementType();
106   auto result_type = RankedTensorType::get(transpose_shape, elem_type);
107 
108   return builder->create<TF::TransposeOp>(location, result_type,
109                                           value_to_transpose, perm_op);
110 }
111 
Transpose2D(OpBuilder * builder,Value value_to_transpose,RankedTensorType type,mlir::Location location)112 Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
113                   RankedTensorType type, mlir::Location location) {
114   // Create a constant op for transpose permutation.
115   SmallVector<int32_t, 4> perm = {1, 0};
116   return Transpose(builder, value_to_transpose, perm, type, location);
117 }
118 
Reverse(OpBuilder * builder,Value value_to_reverse,int axis,RankedTensorType type,mlir::Location location)119 Value Reverse(OpBuilder* builder, Value value_to_reverse, int axis,
120               RankedTensorType type, mlir::Location location) {
121   auto axis_op = CreateI32SplatConst(builder, {1}, axis, location);
122   // The result type will be the same as the input.
123   return builder->create<TF::ReverseV2Op>(location, type, value_to_reverse,
124                                           axis_op);
125 }
126 
GetRankedTensorShape(Value value)127 ArrayRef<int64_t> GetRankedTensorShape(Value value) {
128   return value.getType().cast<RankedTensorType>().getShape();
129 }
130 
SliceRankedTensor(OpBuilder * builder,Value input,ArrayRef<int64_t> begin_shape,ArrayRef<int64_t> begin_values,ArrayRef<int64_t> size_shape,ArrayRef<int64_t> size_values,mlir::Location location)131 Value SliceRankedTensor(OpBuilder* builder, Value input,
132                         ArrayRef<int64_t> begin_shape,
133                         ArrayRef<int64_t> begin_values,
134                         ArrayRef<int64_t> size_shape,
135                         ArrayRef<int64_t> size_values,
136                         mlir::Location location) {
137   // If the size of the tensor to be sliced from the input overflows
138   // the input tensor's dimensions, return 0-valued tensor of the requested
139   // shape.
140   ArrayRef<int64_t> input_shape = GetRankedTensorShape(input);
141   for (int i = 0, end = input_shape.size(); i < end; i++) {
142     if (begin_values[i] < 0 ||
143         (begin_values[i] + size_values[i] > input_shape[i])) {
144       return CreateF32SplatConst(builder, size_shape, 0, location);
145     }
146   }
147 
148   // Create a dense constant op for slice's begin
149   auto slice_i2c_begin =
150       CreateI64DenseConst(builder, begin_shape, begin_values, location);
151 
152   // Create a dense constant op for slice's size
153   auto slice_i2c_size =
154       CreateI64DenseConst(builder, size_shape, size_values, location);
155 
156   return builder->create<TF::SliceOp>(
157       location,
158       RankedTensorType::get(
159           size_values,
160           input.getType().cast<RankedTensorType>().getElementType()),
161       input, slice_i2c_begin, slice_i2c_size);
162 }
163 
CreateStridedSliceOp(mlir::Location loc,ArrayRef<int64_t> output_shape,Value input,ArrayRef<int32_t> begin,ArrayRef<int32_t> end,ArrayRef<int32_t> strides,int64_t begin_mask,int64_t end_mask,int64_t ellipsis_mask,int64_t new_axis_mask,int64_t shrink_axis_mask,OpBuilder * builder)164 Value CreateStridedSliceOp(mlir::Location loc, ArrayRef<int64_t> output_shape,
165                            Value input, ArrayRef<int32_t> begin,
166                            ArrayRef<int32_t> end, ArrayRef<int32_t> strides,
167                            int64_t begin_mask, int64_t end_mask,
168                            int64_t ellipsis_mask, int64_t new_axis_mask,
169                            int64_t shrink_axis_mask, OpBuilder* builder) {
170   auto output_type = RankedTensorType::get(
171       output_shape, input.getType().cast<RankedTensorType>().getElementType());
172   auto begin_tensor = CreateI32DenseConst(builder, begin, loc);
173   auto end_tensor = CreateI32DenseConst(builder, end, loc);
174   auto strides_tensor = CreateI32DenseConst(builder, strides, loc);
175 
176   return builder->create<TF::StridedSliceOp>(
177       loc, output_type, input, begin_tensor, end_tensor, strides_tensor,
178       builder->getI64IntegerAttr(begin_mask),
179       builder->getI64IntegerAttr(end_mask),
180       builder->getI64IntegerAttr(ellipsis_mask),
181       builder->getI64IntegerAttr(new_axis_mask),
182       builder->getI64IntegerAttr(shrink_axis_mask));
183 }
184 
185 }  // namespace
186 
SetWeightForInputToCellGate()187 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToCellGate() {
188   SmallVector<int64_t, 2> begin_i2c_values = {0, 0};
189   input2cell_ = SliceRankedTensor(
190       &builder_, weight_transposed_, weight_slice_shape_, begin_i2c_values,
191       weight_slice_shape_, weight_slice_size_input_values_,
192       fused_func_op_.getLoc());
193 }
194 
SetWeightForInputToInputGate()195 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToInputGate() {
196   SmallVector<int64_t, 2> begin_i2i_values = {n_cell_, 0};
197   input2input_ = couple_input_forget_gates_
198                      ? none_
199                      : SliceRankedTensor(&builder_, weight_transposed_,
200                                          weight_slice_shape_, begin_i2i_values,
201                                          weight_slice_shape_,
202                                          weight_slice_size_input_values_,
203                                          fused_func_op_.getLoc());
204 }
205 
SetWeightForInputToForgetGate()206 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToForgetGate() {
207   int input_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
208   SmallVector<int64_t, 2> begin_i2f_values = {input_forget_start, 0};
209   input2forget_ = SliceRankedTensor(
210       &builder_, weight_transposed_, weight_slice_shape_, begin_i2f_values,
211       weight_slice_shape_, weight_slice_size_input_values_,
212       fused_func_op_.getLoc());
213 }
214 
SetWeightForInputToOutputGate()215 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToOutputGate() {
216   int input_output_start =
217       couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
218   SmallVector<int64_t, 2> begin_i2o_values = {input_output_start, 0};
219   input2output_ = SliceRankedTensor(
220       &builder_, weight_transposed_, weight_slice_shape_, begin_i2o_values,
221       weight_slice_shape_, weight_slice_size_input_values_,
222       fused_func_op_.getLoc());
223 }
224 
SetWeightForRecurrentToCellGate()225 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToCellGate() {
226   SmallVector<int64_t, 2> begin_rec2c_values = {0, n_input_};
227   rec2cell_ = SliceRankedTensor(
228       &builder_, weight_transposed_, weight_slice_shape_, begin_rec2c_values,
229       weight_slice_shape_, weight_slice_size_recurrent_values_,
230       fused_func_op_.getLoc());
231 }
232 
SetWeightForRecurrentToInputGate()233 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToInputGate() {
234   SmallVector<int64_t, 2> begin_rec2i_values = {n_cell_, n_input_};
235   rec2input_ = couple_input_forget_gates_
236                    ? none_
237                    : SliceRankedTensor(&builder_, weight_transposed_,
238                                        weight_slice_shape_, begin_rec2i_values,
239                                        weight_slice_shape_,
240                                        weight_slice_size_recurrent_values_,
241                                        fused_func_op_.getLoc());
242 }
243 
SetWeightForRecurrentToForgetGate()244 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToForgetGate() {
245   int rec_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
246   SmallVector<int64_t, 2> begin_rec2f_values = {rec_forget_start, n_input_};
247   rec2forget_ = SliceRankedTensor(
248       &builder_, weight_transposed_, weight_slice_shape_, begin_rec2f_values,
249       weight_slice_shape_, weight_slice_size_recurrent_values_,
250       fused_func_op_.getLoc());
251 }
252 
SetWeightForRecurrentToOutputGate()253 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToOutputGate() {
254   int rec_output_start = couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
255   SmallVector<int64_t, 2> begin_rec2o_values = {rec_output_start, n_input_};
256   rec2output_ = SliceRankedTensor(
257       &builder_, weight_transposed_, weight_slice_shape_, begin_rec2o_values,
258       weight_slice_shape_, weight_slice_size_recurrent_values_,
259       fused_func_op_.getLoc());
260 }
261 
SetBiasToCellGate()262 void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToCellGate() {
263   SmallVector<int64_t, 1> begin_bias2c_values = {0};
264   bias2cell_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
265                                  begin_bias2c_values, bias_slice_shape_,
266                                  bias_size_values_, fused_func_op_.getLoc());
267 }
268 
SetBiasToInputGate()269 void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToInputGate() {
270   SmallVector<int64_t, 1> begin_bias2i_values = {n_cell_};
271   bias2input_ =
272       couple_input_forget_gates_
273           ? none_
274           : SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
275                               begin_bias2i_values, bias_slice_shape_,
276                               bias_size_values_, fused_func_op_.getLoc());
277 }
278 
SetBiasToForgetGate()279 void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToForgetGate() {
280   int bias_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
281   SmallVector<int64_t, 1> begin_bias2f_values = {bias_forget_start};
282   bias2forget_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
283                                    begin_bias2f_values, bias_slice_shape_,
284                                    bias_size_values_, fused_func_op_.getLoc());
285 }
286 
SetBiasToOutputGate()287 void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToOutputGate() {
288   int bias_output_start =
289       couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
290   SmallVector<int64_t, 1> begin_bias2o_values = {bias_output_start};
291   bias2output_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
292                                    begin_bias2o_values, bias_slice_shape_,
293                                    bias_size_values_, fused_func_op_.getLoc());
294 }
295 
SetProjection()296 void ConvertLSTMCellSimpleToFusedLSTM::SetProjection() {
297   SmallVector<int64_t, 2> projection_slice_shape = {
298       1, num_cols_projection_transposed_};
299   SmallVector<int64_t, 2> projection_slice_size_values = {n_output_, n_cell_};
300   SmallVector<int64_t, 2> projection_slice_begin_values = {0, 0};
301   proj_weight_ =
302       !projection_
303           ? none_
304           : SliceRankedTensor(
305                 &builder_, projection_transposed_, projection_slice_shape,
306                 projection_slice_begin_values, projection_slice_shape,
307                 projection_slice_size_values, fused_func_op_.getLoc());
308 }
309 
SetProjectionBias()310 void ConvertLSTMCellSimpleToFusedLSTM::SetProjectionBias() {
311   proj_bias_ = !projection_type_
312                    ? none_
313                    : CreateF32SplatConst(&builder_, {n_output_}, 0,
314                                          fused_func_op_.getLoc());
315 }
316 
SetInputActivationState()317 void ConvertLSTMCellSimpleToFusedLSTM::SetInputActivationState() {
318   input_activation_state_ = CreateF32SplatConst(&builder_, {1, n_output_}, 0,
319                                                 fused_func_op_.getLoc());
320 }
321 
SetInputCellState()322 void ConvertLSTMCellSimpleToFusedLSTM::SetInputCellState() {
323   input_cell_state_ =
324       CreateF32SplatConst(&builder_, {1, n_cell_}, 0, fused_func_op_.getLoc());
325 }
326 
SetCellLayerNormCoefficients()327 void ConvertLSTMCellSimpleToFusedLSTM::SetCellLayerNormCoefficients() {
328   cell_layer_norm_coefficients_ = none_;
329 }
330 
SetInputLayerNormCoefficients()331 void ConvertLSTMCellSimpleToFusedLSTM::SetInputLayerNormCoefficients() {
332   input_layer_norm_coefficients_ = none_;
333 }
334 
SetForgetLayerNormCoefficients()335 void ConvertLSTMCellSimpleToFusedLSTM::SetForgetLayerNormCoefficients() {
336   forget_layer_norm_coefficients_ = none_;
337 }
SetOutputLayerNormCoefficients()338 void ConvertLSTMCellSimpleToFusedLSTM::SetOutputLayerNormCoefficients() {
339   output_layer_norm_coefficients_ = none_;
340 }
341 
GenerateFusedOpOperands()342 void ConvertLSTMCellSimpleToFusedLSTM::GenerateFusedOpOperands() {
343   // Transpose both weight and projection.
344   weight_transposed_ =
345       Transpose2D(&builder_, weight_, weight_type_, fused_func_op_.getLoc());
346   projection_transposed_ = Transpose2D(&builder_, projection_, projection_type_,
347                                        fused_func_op_.getLoc());
348 
349   none_ = CreateNoneValue(&builder_, fused_func_op_.getLoc());
350   // Extract input to cifg gates via slicing the weight tensor
351   SetWeightForInputToCellGate();
352   SetWeightForInputToInputGate();
353   SetWeightForInputToForgetGate();
354   SetWeightForInputToOutputGate();
355 
356   // Extract recurrent to cifg gates via slicing the weight tensor
357   SetWeightForRecurrentToCellGate();
358   SetWeightForRecurrentToInputGate();
359   SetWeightForRecurrentToForgetGate();
360   SetWeightForRecurrentToOutputGate();
361 
362   // Extract bias to cifg gates via slicing the bias tensor
363   SetBiasToCellGate();
364   SetBiasToInputGate();
365   SetBiasToForgetGate();
366   SetBiasToOutputGate();
367 
368   // Extract projection and set an empty projection bias
369   SetProjection();
370   SetProjectionBias();
371 
372   // Set the variable tensors
373   SetInputActivationState();
374   SetInputCellState();
375 
376   // Extract the layer norm coefficients
377   SetCellLayerNormCoefficients();
378   SetInputLayerNormCoefficients();
379   SetForgetLayerNormCoefficients();
380   SetOutputLayerNormCoefficients();
381 }
382 
UpdateFuncSignature()383 void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() {
384   // https://github.com/tensorflow/community/pull/113
385   SmallVector<int64_t, 2> output_shape{1, -1};
386   auto input_types = fused_func_op_.getFunctionType().getInputs();
387   auto output_type = mlir::RankedTensorType::get(
388       output_shape, input_.getType().cast<RankedTensorType>().getElementType());
389   fused_func_op_.setType(mlir::FunctionType::get(fused_func_op_.getContext(),
390                                                  input_types, output_type));
391 }
392 
RewriteFunc()393 LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
394   LogicalResult result = Initialize();
395   if (failed(result)) {
396     return result;
397   }
398 
399   // Update the func signature, based on output shape.
400   // The func will ultimately return the output of the fused
401   // LSTM op.
402   UpdateFuncSignature();
403 
404   // Transform the weights, projection, bias and layer norm coefficients
405   // to generate operands for the TFL fused LSTM op.
406   GenerateFusedOpOperands();
407 
408   // Create the fused LSTM op.
409   SmallVector<int64_t, 2> output_shape = {1, n_output_};
410   auto result_type = mlir::RankedTensorType::get(
411       output_shape, input_.getType().cast<RankedTensorType>().getElementType());
412   lstm_ = builder_.create<mlir::TFL::LSTMOp>(
413       fused_func_op_.getLoc(), result_type, input_, input2input_, input2forget_,
414       input2cell_, input2output_, rec2input_, rec2forget_, rec2cell_,
415       rec2output_, /*cell_to_input_weights*/ none_,
416       /*cell_to_forget_weights*/ none_,
417       /*cell_to_output_weights*/ none_, bias2input_, bias2forget_, bias2cell_,
418       bias2output_, proj_weight_, proj_bias_, input_activation_state_,
419       input_cell_state_, input_layer_norm_coefficients_,
420       forget_layer_norm_coefficients_, cell_layer_norm_coefficients_,
421       output_layer_norm_coefficients_, builder_.getStringAttr("TANH"),
422       builder_.getF32FloatAttr(10.0), builder_.getF32FloatAttr(0.0),
423       mlir::TFL::LSTMKernelTypeAttr::get(builder_.getContext(),
424                                          mlir::TFL::LSTMKernelType::FULL),
425       /*asymmetric_quantize_inputs=*/mlir::BoolAttr(),
426       /*input_to_input_intermediate=*/mlir::TypeAttr(),
427       /*input_to_forget_intermediate=*/mlir::TypeAttr(),
428       /*input_to_cell_intermediate=*/mlir::TypeAttr(),
429       /*input_to_output_intermediate=*/mlir::TypeAttr(),
430       /*effective_hidden_scale_intermediate=*/mlir::TypeAttr());
431 
432   // Cast the static shaped lstm result to FuncOp's signature -
433   // Ranked but unknown 2nd dimension to support stacking these.
434   SmallVector<int64_t, 2> func_output_shape = {1, -1};
435   auto func_result_type = mlir::RankedTensorType::get(
436       func_output_shape,
437       input_.getType().cast<RankedTensorType>().getElementType());
438 
439   auto tensor_cast = builder_.create<mlir::tensor::CastOp>(
440       fused_func_op_.getLoc(), func_result_type, lstm_.getResult());
441   builder_.create<mlir::func::ReturnOp>(fused_func_op_.getLoc(),
442                                         tensor_cast.getResult());
443   return success();
444 }
445 
InitializeFromFuncAttributes()446 LogicalResult ConvertLSTMCellSimpleToFusedLSTM::InitializeFromFuncAttributes() {
447   auto attr = fused_func_op_->getAttrOfType<StringAttr>(kTFImplements);
448   if (!attr) {
449     return fused_func_op_.emitError()
450            << "Invalid function attribute, expected " << kTFImplements
451            << " attribute "
452               "not found";
453   }
454 
455   // TODO(ashwinm, b/144775479): Make these NamedAttribute on TF import
456   // once tf.function can support this.
457   llvm::SmallVector<llvm::StringRef, 4> attr_tokens;
458   attr.getValue().split(attr_tokens, ",");
459   if (attr_tokens.empty()) {
460     return fused_func_op_.emitError()
461            << kTFImplements << " attribute should be set";
462   }
463 
464   // Check if the interface matches.
465   if (GetCompositeOpName().str() != attr_tokens[0]) {
466     return fused_func_op_.emitError()
467            << "Unexpected interface for the composite op. Expected: "
468            << GetCompositeOpName() << " Actual: " << attr_tokens[0];
469   }
470 
471   // Extract other interface attributes, for now cifg.
472   couple_input_forget_gates_ =
473       std::find(attr_tokens.begin() + 1, attr_tokens.end(),
474                 kCoupleInputForgetGates) != attr_tokens.end();
475 
476   return success();
477 }
478 
Initialize()479 LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() {
480   if (failed(InitializeFromFuncAttributes())) {
481     return fused_func_op_.emitError()
482            << "Expected function attributes were not set on the function "
483               "encapsulating the composite op";
484   }
485 
486   num_gates_ = couple_input_forget_gates_ ? 3 : 4;
487 
488   input_ = fused_func_op_.getArgument(0);
489   bias_ = fused_func_op_.getArgument(2);
490 
491   weight_ = fused_func_op_.getArgument(1);
492   weight_type_ = weight_.getType().cast<RankedTensorType>();
493 
494   if (weight_type_.getRank() != 2) {
495     return fused_func_op_.emitError() << "The weight tensor was not of rank 2";
496   }
497 
498   if (weight_type_.getDimSize(1) % num_gates_ != 0) {
499     return fused_func_op_.emitError()
500            << "Invalid dimension 1 of weight tensor, "
501               "should be divisible by the number of gates";
502   }
503   n_cell_ = weight_type_.getDimSize(1) / num_gates_;
504 
505   projection_ = fused_func_op_.getArgument(3);
506   projection_type_ = projection_.getType().cast<RankedTensorType>();
507   if (projection_type_.getRank() != 2) {
508     n_output_ = n_cell_;
509   } else {
510     n_output_ = projection_type_.getDimSize(1);
511   }
512   n_input_ = weight_type_.getDimSize(0) - n_output_;
513   num_cols_weight_transposed_ = weight_type_.getDimSize(0);
514   num_cols_projection_transposed_ = projection_type_.getDimSize(0);
515 
516   bias_slice_shape_ = {n_cell_};
517   bias_size_values_ = {n_cell_};
518   weight_slice_shape_ = {1, num_cols_weight_transposed_};
519   weight_slice_size_input_values_ = {n_cell_, n_input_};
520   weight_slice_size_recurrent_values_ = {n_cell_, n_output_};
521 
522   return success();
523 }
524 
Initialize()525 LogicalResult ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::Initialize() {
526   if (failed(ConvertLSTMCellSimpleToFusedLSTM::Initialize())) {
527     return fused_func_op_.emitError()
528            << "Specified LayerNormalizedLSTMCellSimple was not of the expected "
529               "interface and cannot not be converted to the fused LSTM op";
530   }
531 
532   layer_norm_scale_ = fused_func_op_.getArgument(4);
533   layer_norm_scale_type_ = layer_norm_scale_.getType().cast<RankedTensorType>();
534   if (layer_norm_scale_type_.getRank() != 1) {
535     return fused_func_op_.emitError()
536            << "The layer_norm_scale tensor was not of rank 1";
537   }
538   layer_norm_slice_shape_ = {n_cell_};
539   layer_norm_size_values_ = {n_cell_};
540 
541   return success();
542 }
543 
544 void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetCellLayerNormCoefficients()545     SetCellLayerNormCoefficients() {
546   SmallVector<int64_t, 1> begin_cell_layer_norm_values = {0};
547   cell_layer_norm_coefficients_ =
548       SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
549                         begin_cell_layer_norm_values, layer_norm_slice_shape_,
550                         layer_norm_size_values_, fused_func_op_.getLoc());
551 }
552 
553 void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetInputLayerNormCoefficients()554     SetInputLayerNormCoefficients() {
555   SmallVector<int64_t, 1> begin_input_layer_norm_values = {n_cell_};
556   input_layer_norm_coefficients_ =
557       couple_input_forget_gates_
558           ? none_
559           : SliceRankedTensor(
560                 &builder_, layer_norm_scale_, layer_norm_slice_shape_,
561                 begin_input_layer_norm_values, layer_norm_slice_shape_,
562                 layer_norm_size_values_, fused_func_op_.getLoc());
563 }
564 
565 void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetForgetLayerNormCoefficients()566     SetForgetLayerNormCoefficients() {
567   SmallVector<int64_t, 1> begin_forget_layer_norm_values = {2 * n_cell_};
568   forget_layer_norm_coefficients_ =
569       SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
570                         begin_forget_layer_norm_values, layer_norm_slice_shape_,
571                         layer_norm_size_values_, fused_func_op_.getLoc());
572 }
573 
574 void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetOutputLayerNormCoefficients()575     SetOutputLayerNormCoefficients() {
576   SmallVector<int64_t, 1> begin_output_layer_norm_values = {3 * n_cell_};
577   output_layer_norm_coefficients_ =
578       SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
579                         begin_output_layer_norm_values, layer_norm_slice_shape_,
580                         layer_norm_size_values_, fused_func_op_.getLoc());
581 }
582 
Create1DConstantOp(const std::vector<int> & value,Location loc,OpBuilder * builder)583 TF::ConstOp Create1DConstantOp(const std::vector<int>& value, Location loc,
584                                OpBuilder* builder) {
585   auto type =
586       mlir::RankedTensorType::get(value.size(), builder->getIntegerType(32));
587   auto dense_values = mlir::DenseIntElementsAttr::get(type, value);
588   return builder->create<TF::ConstOp>(loc, dense_values);
589 }
590 
CreateScalarConstantOp(int value,Location loc,OpBuilder * builder)591 TF::ConstOp CreateScalarConstantOp(int value, Location loc,
592                                    OpBuilder* builder) {
593   return builder->create<TF::ConstOp>(loc, builder->getI32IntegerAttr(value));
594 }
595 
CreateEqualSizeSplitVOp(Value input,int axis,int splits,Location loc,OpBuilder * builder,Operation ** result)596 LogicalResult CreateEqualSizeSplitVOp(Value input, int axis, int splits,
597                                       Location loc, OpBuilder* builder,
598                                       Operation** result) {
599   auto input_type = input.getType().cast<RankedTensorType>();
600   SmallVector<int64_t, 4> output_shape;
601   int size_of_splits;
602   if (input_type.getRank() < axis || axis < 0) return failure();
603   for (int i = 0; i < input_type.getRank(); ++i) {
604     int dim = input_type.getDimSize(i);
605     if (i == axis) {
606       if (dim % splits != 0) {
607         return failure();
608       }
609       size_of_splits = dim / splits;
610       output_shape.push_back(size_of_splits);
611     } else {
612       output_shape.push_back(dim);
613     }
614   }
615 
616   SmallVector<mlir::Type, 4> output_types;
617   for (int i = 0; i < splits; ++i) {
618     output_types.push_back(
619         mlir::RankedTensorType::get(output_shape, input_type.getElementType()));
620   }
621   auto size_of_splits_op = Create1DConstantOp(
622       {size_of_splits, size_of_splits, size_of_splits, size_of_splits}, loc,
623       builder);
624 
625   auto axis_op = CreateScalarConstantOp(axis, loc, builder);
626   *result = builder->create<TF::SplitVOp>(loc, output_types, input,
627                                           size_of_splits_op.getResult(),
628                                           axis_op.getResult());
629   return success();
630 }
631 
632 // TODO(b/147436982): Consider refactor this to be more general.
ConvertKerasLSTMLayer(mlir::func::FuncOp func_op,OpBuilder * builder)633 LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op,
634                                     OpBuilder* builder) {
635   // For argument order, please check out standard_lstm under
636   // tensorflow/python/keras/layers/recurrent_v2.py
637   Value input = func_op.getArgument(0);
638   Value output_init_state = func_op.getArgument(1);
639   Value hidden_init_state = func_op.getArgument(2);
640   Value weight_kernel = func_op.getArgument(3);
641   Value recurrent_kernel = func_op.getArgument(4);
642   Value bias = func_op.getArgument(5);
643 
644   // The func op should have 5 outputs.
645   if (func_op.getNumResults() != 5) return failure();
646 
647   // TFL lstm only supports time-majored inputs, so if it's not time-majored,
648   // we will transpose the inputs and outputs.
649   auto time_major_attr = func_op->getAttrOfType<BoolAttr>("tf.time_major");
650   if (time_major_attr == nullptr) return failure();
651 
652   bool time_majored = time_major_attr.getValue();
653   auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
654   if (!input_type) {
655     func_op.emitError() << "Input type is not a ranked tensor type";
656     return failure();
657   }
658 
659   auto final_inputs = input;
660   auto final_input_type = input_type;
661 
662   // Handle go_backwards:
663   // LSTM in Keras semantic will reverse the input sequence if it's go_backwards
664   auto go_backwards_attr = func_op->getAttrOfType<BoolAttr>("tf.go_backwards");
665 
666   if (go_backwards_attr != nullptr && go_backwards_attr.getValue()) {
667     int time_dim = time_majored ? 0 : 1;
668     final_inputs = Reverse(builder, final_inputs, time_dim, final_input_type,
669                            func_op.getLoc());
670   }
671 
672   int batch = time_majored ? final_input_type.getDimSize(1)
673                            : final_input_type.getDimSize(0);
674   int time = time_majored ? final_input_type.getDimSize(0)
675                           : final_input_type.getDimSize(1);
676 
677   // Setup correct weights.
678   RankedTensorType weight_type =
679       weight_kernel.getType().cast<RankedTensorType>();
680   if (weight_type.getRank() != 2)
681     return func_op.emitError() << "The weight should be rank of 2";
682 
683   Value transposed_weight_kernel =
684       Transpose2D(builder, weight_kernel, weight_type, func_op.getLoc());
685 
686   RankedTensorType recurrent_kernel_type =
687       recurrent_kernel.getType().cast<RankedTensorType>();
688   const int n_output = recurrent_kernel_type.getDimSize(0);
689 
690   Value transpose_recurrent_kernel = Transpose2D(
691       builder, recurrent_kernel, recurrent_kernel_type, func_op.getLoc());
692 
693   // Splits the weights into 4: i, f, c, o.
694   const int splits = 4;
695 
696   Operation* weights_array;
697   if (failed(CreateEqualSizeSplitVOp(transposed_weight_kernel, 0, splits,
698                                      func_op.getLoc(), builder,
699                                      &weights_array)))
700     return failure();
701 
702   // Splits the recurrent_weights into 4:
703   Operation* recurrent_weights_array;
704   if (failed(CreateEqualSizeSplitVOp(transpose_recurrent_kernel, 0, splits,
705                                      func_op.getLoc(), builder,
706                                      &recurrent_weights_array)))
707     return failure();
708 
709   // Splits the bias into 4:
710   Operation* bias_array;
711   if (failed(CreateEqualSizeSplitVOp(bias, 0, splits, func_op.getLoc(), builder,
712                                      &bias_array)))
713     return failure();
714 
715   // Build the lstm op.
716   SmallVector<int64_t, 3> output_shape;
717   if (time_majored) {
718     output_shape = {time, batch, n_output};
719   } else {
720     output_shape = {batch, time, n_output};
721   }
722   auto result_type = mlir::RankedTensorType::get(
723       output_shape,
724       final_inputs.getType().cast<RankedTensorType>().getElementType());
725 
726   Value none = CreateNoneValue(builder, func_op.getLoc());
727   auto lstm = builder->create<mlir::TFL::UnidirectionalSequenceLSTMOp>(
728       func_op.getLoc(), result_type, /*input=*/final_inputs,
729       /*input_to_input_weights=*/weights_array->getResult(0),
730       /*input_to_forget_weights=*/weights_array->getResult(1),
731       /*input_to_cell_weights=*/weights_array->getResult(2),
732       /*input_to_output_weights=*/weights_array->getResult(3),
733       /*recurrent_to_input_weights=*/recurrent_weights_array->getResult(0),
734       /*recurrent_to_forget_weights=*/recurrent_weights_array->getResult(1),
735       /*recurrent_to_cell_weights=*/recurrent_weights_array->getResult(2),
736       /*recurrent_to_output_weights=*/recurrent_weights_array->getResult(3),
737       /*cell_to_input_weights=*/none,
738       /*cell_to_forget_weights=*/none,
739       /*cell_to_output_weights=*/none,
740       /*input_gate_bias=*/bias_array->getResult(0),
741       /*forget_gate_bias=*/bias_array->getResult(1),
742       /*cell_bias=*/bias_array->getResult(2),
743       /*output_gate_bias=*/bias_array->getResult(3),
744       /*projection_weights=*/none,
745       /*projection_bias=*/none,
746       /*input_activation_state=*/output_init_state,
747       /*input_cell_state=*/hidden_init_state,
748       /*input_layer_norm_coefficients=*/none,
749       /*forget_layer_norm_coefficients=*/none,
750       /*cell_layer_norm_coefficients=*/none,
751       /*output_layer_norm_coefficients=*/none,
752       /*fused_activation_function*/ builder->getStringAttr("TANH"),
753       /*cell_clip*/ builder->getF32FloatAttr(10.0),
754       /*proj_clip*/ builder->getF32FloatAttr(0.0),
755       /*time_major*/ builder->getBoolAttr(time_majored),
756       /*asymmetric_quantize_inputs=*/mlir::BoolAttr(),
757       /*input_to_input_intermediate=*/mlir::TypeAttr(),
758       /*input_to_forget_intermediate=*/mlir::TypeAttr(),
759       /*input_to_cell_intermediate=*/mlir::TypeAttr(),
760       /*input_to_output_intermediate=*/mlir::TypeAttr(),
761       /*effective_hidden_scale_intermediate=*/mlir::TypeAttr());
762 
763   auto final_output_full_sequences = lstm.getResult();
764 
765   // Populate the last output: last output is sliced from the full sequences.
766   // If time_major: last_output = outputs[-1, :, :]
767   // else: last_output = outputs[:, -1, :]
768   //
769   // As we are creating the strided_slice op, we need to populate the following
770   // fields:
771   // end: should always be (0, 0, 0)
772   // strides: should always be (1, 1, 1)
773   // begin: should be (0, -1, 0) or (-1, 0, 0) if it's time-majored.
774   // new_axis_mask: should always be 0.
775   // ellipsis_mask: should always be 0.
776   // begin_mask & end_mask: should be 0b101 = 5 or 0b110 = 4 if it's
777   // time-majored. shrink_axis_mask: should be 0b010 = 2 or 0b001 = 1 if it's
778   // time-majored.
779   SmallVector<int64_t, 2> last_output_shape({batch, n_output});
780 
781   SmallVector<int32_t, 3> end({0, 0, 0});
782   SmallVector<int32_t, 3> strides({1, 1, 1});
783   SmallVector<int32_t, 3> begin;
784 
785   int64_t new_axis_mask = 0;
786   int64_t ellipsis_mask = 0;
787   int64_t begin_mask;
788   int64_t end_mask;
789   int64_t shrink_axis_mask;
790   if (time_majored) {
791     begin_mask = 6;
792     end_mask = 6;
793     shrink_axis_mask = 1;
794     begin = {-1, 0, 0};
795   } else {
796     begin_mask = 5;
797     end_mask = 5;
798     shrink_axis_mask = 2;
799     begin = {0, -1, 0};
800   }
801 
802   auto last_output = CreateStridedSliceOp(
803       func_op.getLoc(), last_output_shape, final_output_full_sequences, begin,
804       end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
805       shrink_axis_mask, builder);
806 
807   SmallVector<Value, 5> outputs;
808   SmallVector<Type, 5> output_types;
809 
810   // Due to the existence of the while loop, the timestamp may be unknown
811   // for the signature, for us, since we know the inputs, we can infer the time
812   // steps.
813 
814   // Last output.
815   outputs.push_back(last_output);
816   output_types.push_back(last_output.getType());
817 
818   // Full sequences.
819   outputs.push_back(final_output_full_sequences);
820   output_types.push_back(final_output_full_sequences.getType());
821 
822   // All the rest: states, device.
823   for (int i = 2; i < 5; ++i) {
824     auto result_type =
825         func_op.getCallableResults()[i].dyn_cast<RankedTensorType>();
826     outputs.push_back(CreatTfF32ConstOp(builder, result_type.getShape(), 0.0f,
827                                         func_op.getLoc()));
828     output_types.push_back(result_type);
829   }
830 
831   // Update function signatures.
832   func_op.setType(mlir::FunctionType::get(func_op.getContext(),
833                                           func_op.getFunctionType().getInputs(),
834                                           output_types));
835 
836   builder->create<mlir::func::ReturnOp>(func_op.getLoc(), outputs);
837   return success();
838 }
839 
840 }  // namespace TFL
841 }  // namespace mlir
842