1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/mlir/lite/utils/lstm_utils.h"
17
18 #include <algorithm>
19
20 #include "llvm/ADT/ArrayRef.h"
21 #include "llvm/ADT/None.h"
22 #include "llvm/ADT/SmallVector.h"
23 #include "llvm/ADT/StringRef.h"
24 #include "llvm/Support/Casting.h"
25 #include "llvm/Support/raw_ostream.h"
26 #include "mlir/Dialect/Arithmetic/IR/Arithmetic.h" // from @llvm-project
27 #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
28 #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project
29 #include "mlir/IR/Attributes.h" // from @llvm-project
30 #include "mlir/IR/Builders.h" // from @llvm-project
31 #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
32 #include "mlir/IR/BuiltinOps.h" // from @llvm-project
33 #include "mlir/IR/BuiltinTypes.h" // from @llvm-project
34 #include "mlir/IR/Location.h" // from @llvm-project
35 #include "mlir/IR/MLIRContext.h" // from @llvm-project
36 #include "mlir/IR/OpDefinition.h" // from @llvm-project
37 #include "mlir/IR/Operation.h" // from @llvm-project
38 #include "mlir/IR/Types.h" // from @llvm-project
39 #include "mlir/IR/Value.h" // from @llvm-project
40 #include "mlir/Support/LLVM.h" // from @llvm-project
41 #include "mlir/Support/LogicalResult.h" // from @llvm-project
42 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
43 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
44
45 namespace mlir {
46 namespace TFL {
47
48 namespace {
49
CreateI32SplatConst(OpBuilder * builder,ArrayRef<int64_t> shape,int32_t val,mlir::Location location)50 Value CreateI32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
51 int32_t val, mlir::Location location) {
52 auto type = RankedTensorType::get(shape, builder->getIntegerType(32));
53 auto attr = DenseElementsAttr::get(type, val);
54 return builder->create<arith::ConstantOp>(location, type, attr);
55 }
56
CreateF32SplatConst(OpBuilder * builder,ArrayRef<int64_t> shape,float val,mlir::Location location)57 Value CreateF32SplatConst(OpBuilder* builder, ArrayRef<int64_t> shape,
58 float val, mlir::Location location) {
59 auto type = RankedTensorType::get(shape, builder->getF32Type());
60 auto attr = DenseElementsAttr::get(type, val);
61 return builder->create<arith::ConstantOp>(location, type, attr);
62 }
63
CreatTfF32ConstOp(OpBuilder * builder,ArrayRef<int64_t> shape,float val,mlir::Location location)64 Value CreatTfF32ConstOp(OpBuilder* builder, ArrayRef<int64_t> shape, float val,
65 mlir::Location location) {
66 auto type = RankedTensorType::get(shape, builder->getF32Type());
67 auto ele_type = RankedTensorType::get({1}, builder->getF32Type());
68 auto attr = DenseElementsAttr::get(ele_type, val);
69 return builder->create<TF::ConstOp>(location, type, attr);
70 }
71
CreateI64DenseConst(OpBuilder * builder,ArrayRef<int64_t> shape,ArrayRef<int64_t> values,mlir::Location location)72 Value CreateI64DenseConst(OpBuilder* builder, ArrayRef<int64_t> shape,
73 ArrayRef<int64_t> values, mlir::Location location) {
74 auto type = RankedTensorType::get(static_cast<int>(shape.size()),
75 builder->getIntegerType(64));
76 auto attr = DenseElementsAttr::get(type, values);
77 return builder->create<arith::ConstantOp>(location, type, attr);
78 }
79
CreateI32DenseConst(OpBuilder * builder,ArrayRef<int32_t> values,mlir::Location location)80 Value CreateI32DenseConst(OpBuilder* builder, ArrayRef<int32_t> values,
81 mlir::Location location) {
82 auto type = RankedTensorType::get(static_cast<int>(values.size()),
83 builder->getIntegerType(32));
84 auto attr = DenseElementsAttr::get(type, values);
85 return builder->create<arith::ConstantOp>(location, type, attr);
86 }
87
CreateNoneValue(OpBuilder * builder,mlir::Location location)88 Value CreateNoneValue(OpBuilder* builder, mlir::Location location) {
89 return builder->create<TFL::NoValueOp>(location, builder->getNoneType(),
90 builder->getUnitAttr());
91 }
92
Transpose(OpBuilder * builder,Value value_to_transpose,SmallVector<int32_t,4> perm,RankedTensorType original_type,mlir::Location location)93 Value Transpose(OpBuilder* builder, Value value_to_transpose,
94 SmallVector<int32_t, 4> perm, RankedTensorType original_type,
95 mlir::Location location) {
96 // Create a constant op for transpose permutation.
97 auto perm_op = CreateI32DenseConst(builder, perm, location);
98
99 // Create tensor type for the transpose result.
100 auto transpose_type = original_type;
101 auto transpose_shape =
102 llvm::to_vector<8>(llvm::map_range(perm, [transpose_type](int32_t dim) {
103 return transpose_type.getDimSize(dim);
104 }));
105 auto elem_type = transpose_type.getElementType();
106 auto result_type = RankedTensorType::get(transpose_shape, elem_type);
107
108 return builder->create<TF::TransposeOp>(location, result_type,
109 value_to_transpose, perm_op);
110 }
111
Transpose2D(OpBuilder * builder,Value value_to_transpose,RankedTensorType type,mlir::Location location)112 Value Transpose2D(OpBuilder* builder, Value value_to_transpose,
113 RankedTensorType type, mlir::Location location) {
114 // Create a constant op for transpose permutation.
115 SmallVector<int32_t, 4> perm = {1, 0};
116 return Transpose(builder, value_to_transpose, perm, type, location);
117 }
118
Reverse(OpBuilder * builder,Value value_to_reverse,int axis,RankedTensorType type,mlir::Location location)119 Value Reverse(OpBuilder* builder, Value value_to_reverse, int axis,
120 RankedTensorType type, mlir::Location location) {
121 auto axis_op = CreateI32SplatConst(builder, {1}, axis, location);
122 // The result type will be the same as the input.
123 return builder->create<TF::ReverseV2Op>(location, type, value_to_reverse,
124 axis_op);
125 }
126
GetRankedTensorShape(Value value)127 ArrayRef<int64_t> GetRankedTensorShape(Value value) {
128 return value.getType().cast<RankedTensorType>().getShape();
129 }
130
SliceRankedTensor(OpBuilder * builder,Value input,ArrayRef<int64_t> begin_shape,ArrayRef<int64_t> begin_values,ArrayRef<int64_t> size_shape,ArrayRef<int64_t> size_values,mlir::Location location)131 Value SliceRankedTensor(OpBuilder* builder, Value input,
132 ArrayRef<int64_t> begin_shape,
133 ArrayRef<int64_t> begin_values,
134 ArrayRef<int64_t> size_shape,
135 ArrayRef<int64_t> size_values,
136 mlir::Location location) {
137 // If the size of the tensor to be sliced from the input overflows
138 // the input tensor's dimensions, return 0-valued tensor of the requested
139 // shape.
140 ArrayRef<int64_t> input_shape = GetRankedTensorShape(input);
141 for (int i = 0, end = input_shape.size(); i < end; i++) {
142 if (begin_values[i] < 0 ||
143 (begin_values[i] + size_values[i] > input_shape[i])) {
144 return CreateF32SplatConst(builder, size_shape, 0, location);
145 }
146 }
147
148 // Create a dense constant op for slice's begin
149 auto slice_i2c_begin =
150 CreateI64DenseConst(builder, begin_shape, begin_values, location);
151
152 // Create a dense constant op for slice's size
153 auto slice_i2c_size =
154 CreateI64DenseConst(builder, size_shape, size_values, location);
155
156 return builder->create<TF::SliceOp>(
157 location,
158 RankedTensorType::get(
159 size_values,
160 input.getType().cast<RankedTensorType>().getElementType()),
161 input, slice_i2c_begin, slice_i2c_size);
162 }
163
CreateStridedSliceOp(mlir::Location loc,ArrayRef<int64_t> output_shape,Value input,ArrayRef<int32_t> begin,ArrayRef<int32_t> end,ArrayRef<int32_t> strides,int64_t begin_mask,int64_t end_mask,int64_t ellipsis_mask,int64_t new_axis_mask,int64_t shrink_axis_mask,OpBuilder * builder)164 Value CreateStridedSliceOp(mlir::Location loc, ArrayRef<int64_t> output_shape,
165 Value input, ArrayRef<int32_t> begin,
166 ArrayRef<int32_t> end, ArrayRef<int32_t> strides,
167 int64_t begin_mask, int64_t end_mask,
168 int64_t ellipsis_mask, int64_t new_axis_mask,
169 int64_t shrink_axis_mask, OpBuilder* builder) {
170 auto output_type = RankedTensorType::get(
171 output_shape, input.getType().cast<RankedTensorType>().getElementType());
172 auto begin_tensor = CreateI32DenseConst(builder, begin, loc);
173 auto end_tensor = CreateI32DenseConst(builder, end, loc);
174 auto strides_tensor = CreateI32DenseConst(builder, strides, loc);
175
176 return builder->create<TF::StridedSliceOp>(
177 loc, output_type, input, begin_tensor, end_tensor, strides_tensor,
178 builder->getI64IntegerAttr(begin_mask),
179 builder->getI64IntegerAttr(end_mask),
180 builder->getI64IntegerAttr(ellipsis_mask),
181 builder->getI64IntegerAttr(new_axis_mask),
182 builder->getI64IntegerAttr(shrink_axis_mask));
183 }
184
185 } // namespace
186
SetWeightForInputToCellGate()187 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToCellGate() {
188 SmallVector<int64_t, 2> begin_i2c_values = {0, 0};
189 input2cell_ = SliceRankedTensor(
190 &builder_, weight_transposed_, weight_slice_shape_, begin_i2c_values,
191 weight_slice_shape_, weight_slice_size_input_values_,
192 fused_func_op_.getLoc());
193 }
194
SetWeightForInputToInputGate()195 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToInputGate() {
196 SmallVector<int64_t, 2> begin_i2i_values = {n_cell_, 0};
197 input2input_ = couple_input_forget_gates_
198 ? none_
199 : SliceRankedTensor(&builder_, weight_transposed_,
200 weight_slice_shape_, begin_i2i_values,
201 weight_slice_shape_,
202 weight_slice_size_input_values_,
203 fused_func_op_.getLoc());
204 }
205
SetWeightForInputToForgetGate()206 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToForgetGate() {
207 int input_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
208 SmallVector<int64_t, 2> begin_i2f_values = {input_forget_start, 0};
209 input2forget_ = SliceRankedTensor(
210 &builder_, weight_transposed_, weight_slice_shape_, begin_i2f_values,
211 weight_slice_shape_, weight_slice_size_input_values_,
212 fused_func_op_.getLoc());
213 }
214
SetWeightForInputToOutputGate()215 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForInputToOutputGate() {
216 int input_output_start =
217 couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
218 SmallVector<int64_t, 2> begin_i2o_values = {input_output_start, 0};
219 input2output_ = SliceRankedTensor(
220 &builder_, weight_transposed_, weight_slice_shape_, begin_i2o_values,
221 weight_slice_shape_, weight_slice_size_input_values_,
222 fused_func_op_.getLoc());
223 }
224
SetWeightForRecurrentToCellGate()225 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToCellGate() {
226 SmallVector<int64_t, 2> begin_rec2c_values = {0, n_input_};
227 rec2cell_ = SliceRankedTensor(
228 &builder_, weight_transposed_, weight_slice_shape_, begin_rec2c_values,
229 weight_slice_shape_, weight_slice_size_recurrent_values_,
230 fused_func_op_.getLoc());
231 }
232
SetWeightForRecurrentToInputGate()233 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToInputGate() {
234 SmallVector<int64_t, 2> begin_rec2i_values = {n_cell_, n_input_};
235 rec2input_ = couple_input_forget_gates_
236 ? none_
237 : SliceRankedTensor(&builder_, weight_transposed_,
238 weight_slice_shape_, begin_rec2i_values,
239 weight_slice_shape_,
240 weight_slice_size_recurrent_values_,
241 fused_func_op_.getLoc());
242 }
243
SetWeightForRecurrentToForgetGate()244 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToForgetGate() {
245 int rec_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
246 SmallVector<int64_t, 2> begin_rec2f_values = {rec_forget_start, n_input_};
247 rec2forget_ = SliceRankedTensor(
248 &builder_, weight_transposed_, weight_slice_shape_, begin_rec2f_values,
249 weight_slice_shape_, weight_slice_size_recurrent_values_,
250 fused_func_op_.getLoc());
251 }
252
SetWeightForRecurrentToOutputGate()253 void ConvertLSTMCellSimpleToFusedLSTM::SetWeightForRecurrentToOutputGate() {
254 int rec_output_start = couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
255 SmallVector<int64_t, 2> begin_rec2o_values = {rec_output_start, n_input_};
256 rec2output_ = SliceRankedTensor(
257 &builder_, weight_transposed_, weight_slice_shape_, begin_rec2o_values,
258 weight_slice_shape_, weight_slice_size_recurrent_values_,
259 fused_func_op_.getLoc());
260 }
261
SetBiasToCellGate()262 void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToCellGate() {
263 SmallVector<int64_t, 1> begin_bias2c_values = {0};
264 bias2cell_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
265 begin_bias2c_values, bias_slice_shape_,
266 bias_size_values_, fused_func_op_.getLoc());
267 }
268
SetBiasToInputGate()269 void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToInputGate() {
270 SmallVector<int64_t, 1> begin_bias2i_values = {n_cell_};
271 bias2input_ =
272 couple_input_forget_gates_
273 ? none_
274 : SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
275 begin_bias2i_values, bias_slice_shape_,
276 bias_size_values_, fused_func_op_.getLoc());
277 }
278
SetBiasToForgetGate()279 void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToForgetGate() {
280 int bias_forget_start = couple_input_forget_gates_ ? n_cell_ : 2 * n_cell_;
281 SmallVector<int64_t, 1> begin_bias2f_values = {bias_forget_start};
282 bias2forget_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
283 begin_bias2f_values, bias_slice_shape_,
284 bias_size_values_, fused_func_op_.getLoc());
285 }
286
SetBiasToOutputGate()287 void ConvertLSTMCellSimpleToFusedLSTM::SetBiasToOutputGate() {
288 int bias_output_start =
289 couple_input_forget_gates_ ? 2 * n_cell_ : 3 * n_cell_;
290 SmallVector<int64_t, 1> begin_bias2o_values = {bias_output_start};
291 bias2output_ = SliceRankedTensor(&builder_, bias_, bias_slice_shape_,
292 begin_bias2o_values, bias_slice_shape_,
293 bias_size_values_, fused_func_op_.getLoc());
294 }
295
SetProjection()296 void ConvertLSTMCellSimpleToFusedLSTM::SetProjection() {
297 SmallVector<int64_t, 2> projection_slice_shape = {
298 1, num_cols_projection_transposed_};
299 SmallVector<int64_t, 2> projection_slice_size_values = {n_output_, n_cell_};
300 SmallVector<int64_t, 2> projection_slice_begin_values = {0, 0};
301 proj_weight_ =
302 !projection_
303 ? none_
304 : SliceRankedTensor(
305 &builder_, projection_transposed_, projection_slice_shape,
306 projection_slice_begin_values, projection_slice_shape,
307 projection_slice_size_values, fused_func_op_.getLoc());
308 }
309
SetProjectionBias()310 void ConvertLSTMCellSimpleToFusedLSTM::SetProjectionBias() {
311 proj_bias_ = !projection_type_
312 ? none_
313 : CreateF32SplatConst(&builder_, {n_output_}, 0,
314 fused_func_op_.getLoc());
315 }
316
SetInputActivationState()317 void ConvertLSTMCellSimpleToFusedLSTM::SetInputActivationState() {
318 input_activation_state_ = CreateF32SplatConst(&builder_, {1, n_output_}, 0,
319 fused_func_op_.getLoc());
320 }
321
SetInputCellState()322 void ConvertLSTMCellSimpleToFusedLSTM::SetInputCellState() {
323 input_cell_state_ =
324 CreateF32SplatConst(&builder_, {1, n_cell_}, 0, fused_func_op_.getLoc());
325 }
326
SetCellLayerNormCoefficients()327 void ConvertLSTMCellSimpleToFusedLSTM::SetCellLayerNormCoefficients() {
328 cell_layer_norm_coefficients_ = none_;
329 }
330
SetInputLayerNormCoefficients()331 void ConvertLSTMCellSimpleToFusedLSTM::SetInputLayerNormCoefficients() {
332 input_layer_norm_coefficients_ = none_;
333 }
334
SetForgetLayerNormCoefficients()335 void ConvertLSTMCellSimpleToFusedLSTM::SetForgetLayerNormCoefficients() {
336 forget_layer_norm_coefficients_ = none_;
337 }
SetOutputLayerNormCoefficients()338 void ConvertLSTMCellSimpleToFusedLSTM::SetOutputLayerNormCoefficients() {
339 output_layer_norm_coefficients_ = none_;
340 }
341
GenerateFusedOpOperands()342 void ConvertLSTMCellSimpleToFusedLSTM::GenerateFusedOpOperands() {
343 // Transpose both weight and projection.
344 weight_transposed_ =
345 Transpose2D(&builder_, weight_, weight_type_, fused_func_op_.getLoc());
346 projection_transposed_ = Transpose2D(&builder_, projection_, projection_type_,
347 fused_func_op_.getLoc());
348
349 none_ = CreateNoneValue(&builder_, fused_func_op_.getLoc());
350 // Extract input to cifg gates via slicing the weight tensor
351 SetWeightForInputToCellGate();
352 SetWeightForInputToInputGate();
353 SetWeightForInputToForgetGate();
354 SetWeightForInputToOutputGate();
355
356 // Extract recurrent to cifg gates via slicing the weight tensor
357 SetWeightForRecurrentToCellGate();
358 SetWeightForRecurrentToInputGate();
359 SetWeightForRecurrentToForgetGate();
360 SetWeightForRecurrentToOutputGate();
361
362 // Extract bias to cifg gates via slicing the bias tensor
363 SetBiasToCellGate();
364 SetBiasToInputGate();
365 SetBiasToForgetGate();
366 SetBiasToOutputGate();
367
368 // Extract projection and set an empty projection bias
369 SetProjection();
370 SetProjectionBias();
371
372 // Set the variable tensors
373 SetInputActivationState();
374 SetInputCellState();
375
376 // Extract the layer norm coefficients
377 SetCellLayerNormCoefficients();
378 SetInputLayerNormCoefficients();
379 SetForgetLayerNormCoefficients();
380 SetOutputLayerNormCoefficients();
381 }
382
UpdateFuncSignature()383 void ConvertLSTMCellSimpleToFusedLSTM::UpdateFuncSignature() {
384 // https://github.com/tensorflow/community/pull/113
385 SmallVector<int64_t, 2> output_shape{1, -1};
386 auto input_types = fused_func_op_.getFunctionType().getInputs();
387 auto output_type = mlir::RankedTensorType::get(
388 output_shape, input_.getType().cast<RankedTensorType>().getElementType());
389 fused_func_op_.setType(mlir::FunctionType::get(fused_func_op_.getContext(),
390 input_types, output_type));
391 }
392
RewriteFunc()393 LogicalResult ConvertLSTMCellSimpleToFusedLSTM::RewriteFunc() {
394 LogicalResult result = Initialize();
395 if (failed(result)) {
396 return result;
397 }
398
399 // Update the func signature, based on output shape.
400 // The func will ultimately return the output of the fused
401 // LSTM op.
402 UpdateFuncSignature();
403
404 // Transform the weights, projection, bias and layer norm coefficients
405 // to generate operands for the TFL fused LSTM op.
406 GenerateFusedOpOperands();
407
408 // Create the fused LSTM op.
409 SmallVector<int64_t, 2> output_shape = {1, n_output_};
410 auto result_type = mlir::RankedTensorType::get(
411 output_shape, input_.getType().cast<RankedTensorType>().getElementType());
412 lstm_ = builder_.create<mlir::TFL::LSTMOp>(
413 fused_func_op_.getLoc(), result_type, input_, input2input_, input2forget_,
414 input2cell_, input2output_, rec2input_, rec2forget_, rec2cell_,
415 rec2output_, /*cell_to_input_weights*/ none_,
416 /*cell_to_forget_weights*/ none_,
417 /*cell_to_output_weights*/ none_, bias2input_, bias2forget_, bias2cell_,
418 bias2output_, proj_weight_, proj_bias_, input_activation_state_,
419 input_cell_state_, input_layer_norm_coefficients_,
420 forget_layer_norm_coefficients_, cell_layer_norm_coefficients_,
421 output_layer_norm_coefficients_, builder_.getStringAttr("TANH"),
422 builder_.getF32FloatAttr(10.0), builder_.getF32FloatAttr(0.0),
423 mlir::TFL::LSTMKernelTypeAttr::get(builder_.getContext(),
424 mlir::TFL::LSTMKernelType::FULL),
425 /*asymmetric_quantize_inputs=*/mlir::BoolAttr(),
426 /*input_to_input_intermediate=*/mlir::TypeAttr(),
427 /*input_to_forget_intermediate=*/mlir::TypeAttr(),
428 /*input_to_cell_intermediate=*/mlir::TypeAttr(),
429 /*input_to_output_intermediate=*/mlir::TypeAttr(),
430 /*effective_hidden_scale_intermediate=*/mlir::TypeAttr());
431
432 // Cast the static shaped lstm result to FuncOp's signature -
433 // Ranked but unknown 2nd dimension to support stacking these.
434 SmallVector<int64_t, 2> func_output_shape = {1, -1};
435 auto func_result_type = mlir::RankedTensorType::get(
436 func_output_shape,
437 input_.getType().cast<RankedTensorType>().getElementType());
438
439 auto tensor_cast = builder_.create<mlir::tensor::CastOp>(
440 fused_func_op_.getLoc(), func_result_type, lstm_.getResult());
441 builder_.create<mlir::func::ReturnOp>(fused_func_op_.getLoc(),
442 tensor_cast.getResult());
443 return success();
444 }
445
InitializeFromFuncAttributes()446 LogicalResult ConvertLSTMCellSimpleToFusedLSTM::InitializeFromFuncAttributes() {
447 auto attr = fused_func_op_->getAttrOfType<StringAttr>(kTFImplements);
448 if (!attr) {
449 return fused_func_op_.emitError()
450 << "Invalid function attribute, expected " << kTFImplements
451 << " attribute "
452 "not found";
453 }
454
455 // TODO(ashwinm, b/144775479): Make these NamedAttribute on TF import
456 // once tf.function can support this.
457 llvm::SmallVector<llvm::StringRef, 4> attr_tokens;
458 attr.getValue().split(attr_tokens, ",");
459 if (attr_tokens.empty()) {
460 return fused_func_op_.emitError()
461 << kTFImplements << " attribute should be set";
462 }
463
464 // Check if the interface matches.
465 if (GetCompositeOpName().str() != attr_tokens[0]) {
466 return fused_func_op_.emitError()
467 << "Unexpected interface for the composite op. Expected: "
468 << GetCompositeOpName() << " Actual: " << attr_tokens[0];
469 }
470
471 // Extract other interface attributes, for now cifg.
472 couple_input_forget_gates_ =
473 std::find(attr_tokens.begin() + 1, attr_tokens.end(),
474 kCoupleInputForgetGates) != attr_tokens.end();
475
476 return success();
477 }
478
Initialize()479 LogicalResult ConvertLSTMCellSimpleToFusedLSTM::Initialize() {
480 if (failed(InitializeFromFuncAttributes())) {
481 return fused_func_op_.emitError()
482 << "Expected function attributes were not set on the function "
483 "encapsulating the composite op";
484 }
485
486 num_gates_ = couple_input_forget_gates_ ? 3 : 4;
487
488 input_ = fused_func_op_.getArgument(0);
489 bias_ = fused_func_op_.getArgument(2);
490
491 weight_ = fused_func_op_.getArgument(1);
492 weight_type_ = weight_.getType().cast<RankedTensorType>();
493
494 if (weight_type_.getRank() != 2) {
495 return fused_func_op_.emitError() << "The weight tensor was not of rank 2";
496 }
497
498 if (weight_type_.getDimSize(1) % num_gates_ != 0) {
499 return fused_func_op_.emitError()
500 << "Invalid dimension 1 of weight tensor, "
501 "should be divisible by the number of gates";
502 }
503 n_cell_ = weight_type_.getDimSize(1) / num_gates_;
504
505 projection_ = fused_func_op_.getArgument(3);
506 projection_type_ = projection_.getType().cast<RankedTensorType>();
507 if (projection_type_.getRank() != 2) {
508 n_output_ = n_cell_;
509 } else {
510 n_output_ = projection_type_.getDimSize(1);
511 }
512 n_input_ = weight_type_.getDimSize(0) - n_output_;
513 num_cols_weight_transposed_ = weight_type_.getDimSize(0);
514 num_cols_projection_transposed_ = projection_type_.getDimSize(0);
515
516 bias_slice_shape_ = {n_cell_};
517 bias_size_values_ = {n_cell_};
518 weight_slice_shape_ = {1, num_cols_weight_transposed_};
519 weight_slice_size_input_values_ = {n_cell_, n_input_};
520 weight_slice_size_recurrent_values_ = {n_cell_, n_output_};
521
522 return success();
523 }
524
Initialize()525 LogicalResult ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::Initialize() {
526 if (failed(ConvertLSTMCellSimpleToFusedLSTM::Initialize())) {
527 return fused_func_op_.emitError()
528 << "Specified LayerNormalizedLSTMCellSimple was not of the expected "
529 "interface and cannot not be converted to the fused LSTM op";
530 }
531
532 layer_norm_scale_ = fused_func_op_.getArgument(4);
533 layer_norm_scale_type_ = layer_norm_scale_.getType().cast<RankedTensorType>();
534 if (layer_norm_scale_type_.getRank() != 1) {
535 return fused_func_op_.emitError()
536 << "The layer_norm_scale tensor was not of rank 1";
537 }
538 layer_norm_slice_shape_ = {n_cell_};
539 layer_norm_size_values_ = {n_cell_};
540
541 return success();
542 }
543
544 void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetCellLayerNormCoefficients()545 SetCellLayerNormCoefficients() {
546 SmallVector<int64_t, 1> begin_cell_layer_norm_values = {0};
547 cell_layer_norm_coefficients_ =
548 SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
549 begin_cell_layer_norm_values, layer_norm_slice_shape_,
550 layer_norm_size_values_, fused_func_op_.getLoc());
551 }
552
553 void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetInputLayerNormCoefficients()554 SetInputLayerNormCoefficients() {
555 SmallVector<int64_t, 1> begin_input_layer_norm_values = {n_cell_};
556 input_layer_norm_coefficients_ =
557 couple_input_forget_gates_
558 ? none_
559 : SliceRankedTensor(
560 &builder_, layer_norm_scale_, layer_norm_slice_shape_,
561 begin_input_layer_norm_values, layer_norm_slice_shape_,
562 layer_norm_size_values_, fused_func_op_.getLoc());
563 }
564
565 void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetForgetLayerNormCoefficients()566 SetForgetLayerNormCoefficients() {
567 SmallVector<int64_t, 1> begin_forget_layer_norm_values = {2 * n_cell_};
568 forget_layer_norm_coefficients_ =
569 SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
570 begin_forget_layer_norm_values, layer_norm_slice_shape_,
571 layer_norm_size_values_, fused_func_op_.getLoc());
572 }
573
574 void ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM::
SetOutputLayerNormCoefficients()575 SetOutputLayerNormCoefficients() {
576 SmallVector<int64_t, 1> begin_output_layer_norm_values = {3 * n_cell_};
577 output_layer_norm_coefficients_ =
578 SliceRankedTensor(&builder_, layer_norm_scale_, layer_norm_slice_shape_,
579 begin_output_layer_norm_values, layer_norm_slice_shape_,
580 layer_norm_size_values_, fused_func_op_.getLoc());
581 }
582
Create1DConstantOp(const std::vector<int> & value,Location loc,OpBuilder * builder)583 TF::ConstOp Create1DConstantOp(const std::vector<int>& value, Location loc,
584 OpBuilder* builder) {
585 auto type =
586 mlir::RankedTensorType::get(value.size(), builder->getIntegerType(32));
587 auto dense_values = mlir::DenseIntElementsAttr::get(type, value);
588 return builder->create<TF::ConstOp>(loc, dense_values);
589 }
590
CreateScalarConstantOp(int value,Location loc,OpBuilder * builder)591 TF::ConstOp CreateScalarConstantOp(int value, Location loc,
592 OpBuilder* builder) {
593 return builder->create<TF::ConstOp>(loc, builder->getI32IntegerAttr(value));
594 }
595
CreateEqualSizeSplitVOp(Value input,int axis,int splits,Location loc,OpBuilder * builder,Operation ** result)596 LogicalResult CreateEqualSizeSplitVOp(Value input, int axis, int splits,
597 Location loc, OpBuilder* builder,
598 Operation** result) {
599 auto input_type = input.getType().cast<RankedTensorType>();
600 SmallVector<int64_t, 4> output_shape;
601 int size_of_splits;
602 if (input_type.getRank() < axis || axis < 0) return failure();
603 for (int i = 0; i < input_type.getRank(); ++i) {
604 int dim = input_type.getDimSize(i);
605 if (i == axis) {
606 if (dim % splits != 0) {
607 return failure();
608 }
609 size_of_splits = dim / splits;
610 output_shape.push_back(size_of_splits);
611 } else {
612 output_shape.push_back(dim);
613 }
614 }
615
616 SmallVector<mlir::Type, 4> output_types;
617 for (int i = 0; i < splits; ++i) {
618 output_types.push_back(
619 mlir::RankedTensorType::get(output_shape, input_type.getElementType()));
620 }
621 auto size_of_splits_op = Create1DConstantOp(
622 {size_of_splits, size_of_splits, size_of_splits, size_of_splits}, loc,
623 builder);
624
625 auto axis_op = CreateScalarConstantOp(axis, loc, builder);
626 *result = builder->create<TF::SplitVOp>(loc, output_types, input,
627 size_of_splits_op.getResult(),
628 axis_op.getResult());
629 return success();
630 }
631
632 // TODO(b/147436982): Consider refactor this to be more general.
ConvertKerasLSTMLayer(mlir::func::FuncOp func_op,OpBuilder * builder)633 LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op,
634 OpBuilder* builder) {
635 // For argument order, please check out standard_lstm under
636 // tensorflow/python/keras/layers/recurrent_v2.py
637 Value input = func_op.getArgument(0);
638 Value output_init_state = func_op.getArgument(1);
639 Value hidden_init_state = func_op.getArgument(2);
640 Value weight_kernel = func_op.getArgument(3);
641 Value recurrent_kernel = func_op.getArgument(4);
642 Value bias = func_op.getArgument(5);
643
644 // The func op should have 5 outputs.
645 if (func_op.getNumResults() != 5) return failure();
646
647 // TFL lstm only supports time-majored inputs, so if it's not time-majored,
648 // we will transpose the inputs and outputs.
649 auto time_major_attr = func_op->getAttrOfType<BoolAttr>("tf.time_major");
650 if (time_major_attr == nullptr) return failure();
651
652 bool time_majored = time_major_attr.getValue();
653 auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
654 if (!input_type) {
655 func_op.emitError() << "Input type is not a ranked tensor type";
656 return failure();
657 }
658
659 auto final_inputs = input;
660 auto final_input_type = input_type;
661
662 // Handle go_backwards:
663 // LSTM in Keras semantic will reverse the input sequence if it's go_backwards
664 auto go_backwards_attr = func_op->getAttrOfType<BoolAttr>("tf.go_backwards");
665
666 if (go_backwards_attr != nullptr && go_backwards_attr.getValue()) {
667 int time_dim = time_majored ? 0 : 1;
668 final_inputs = Reverse(builder, final_inputs, time_dim, final_input_type,
669 func_op.getLoc());
670 }
671
672 int batch = time_majored ? final_input_type.getDimSize(1)
673 : final_input_type.getDimSize(0);
674 int time = time_majored ? final_input_type.getDimSize(0)
675 : final_input_type.getDimSize(1);
676
677 // Setup correct weights.
678 RankedTensorType weight_type =
679 weight_kernel.getType().cast<RankedTensorType>();
680 if (weight_type.getRank() != 2)
681 return func_op.emitError() << "The weight should be rank of 2";
682
683 Value transposed_weight_kernel =
684 Transpose2D(builder, weight_kernel, weight_type, func_op.getLoc());
685
686 RankedTensorType recurrent_kernel_type =
687 recurrent_kernel.getType().cast<RankedTensorType>();
688 const int n_output = recurrent_kernel_type.getDimSize(0);
689
690 Value transpose_recurrent_kernel = Transpose2D(
691 builder, recurrent_kernel, recurrent_kernel_type, func_op.getLoc());
692
693 // Splits the weights into 4: i, f, c, o.
694 const int splits = 4;
695
696 Operation* weights_array;
697 if (failed(CreateEqualSizeSplitVOp(transposed_weight_kernel, 0, splits,
698 func_op.getLoc(), builder,
699 &weights_array)))
700 return failure();
701
702 // Splits the recurrent_weights into 4:
703 Operation* recurrent_weights_array;
704 if (failed(CreateEqualSizeSplitVOp(transpose_recurrent_kernel, 0, splits,
705 func_op.getLoc(), builder,
706 &recurrent_weights_array)))
707 return failure();
708
709 // Splits the bias into 4:
710 Operation* bias_array;
711 if (failed(CreateEqualSizeSplitVOp(bias, 0, splits, func_op.getLoc(), builder,
712 &bias_array)))
713 return failure();
714
715 // Build the lstm op.
716 SmallVector<int64_t, 3> output_shape;
717 if (time_majored) {
718 output_shape = {time, batch, n_output};
719 } else {
720 output_shape = {batch, time, n_output};
721 }
722 auto result_type = mlir::RankedTensorType::get(
723 output_shape,
724 final_inputs.getType().cast<RankedTensorType>().getElementType());
725
726 Value none = CreateNoneValue(builder, func_op.getLoc());
727 auto lstm = builder->create<mlir::TFL::UnidirectionalSequenceLSTMOp>(
728 func_op.getLoc(), result_type, /*input=*/final_inputs,
729 /*input_to_input_weights=*/weights_array->getResult(0),
730 /*input_to_forget_weights=*/weights_array->getResult(1),
731 /*input_to_cell_weights=*/weights_array->getResult(2),
732 /*input_to_output_weights=*/weights_array->getResult(3),
733 /*recurrent_to_input_weights=*/recurrent_weights_array->getResult(0),
734 /*recurrent_to_forget_weights=*/recurrent_weights_array->getResult(1),
735 /*recurrent_to_cell_weights=*/recurrent_weights_array->getResult(2),
736 /*recurrent_to_output_weights=*/recurrent_weights_array->getResult(3),
737 /*cell_to_input_weights=*/none,
738 /*cell_to_forget_weights=*/none,
739 /*cell_to_output_weights=*/none,
740 /*input_gate_bias=*/bias_array->getResult(0),
741 /*forget_gate_bias=*/bias_array->getResult(1),
742 /*cell_bias=*/bias_array->getResult(2),
743 /*output_gate_bias=*/bias_array->getResult(3),
744 /*projection_weights=*/none,
745 /*projection_bias=*/none,
746 /*input_activation_state=*/output_init_state,
747 /*input_cell_state=*/hidden_init_state,
748 /*input_layer_norm_coefficients=*/none,
749 /*forget_layer_norm_coefficients=*/none,
750 /*cell_layer_norm_coefficients=*/none,
751 /*output_layer_norm_coefficients=*/none,
752 /*fused_activation_function*/ builder->getStringAttr("TANH"),
753 /*cell_clip*/ builder->getF32FloatAttr(10.0),
754 /*proj_clip*/ builder->getF32FloatAttr(0.0),
755 /*time_major*/ builder->getBoolAttr(time_majored),
756 /*asymmetric_quantize_inputs=*/mlir::BoolAttr(),
757 /*input_to_input_intermediate=*/mlir::TypeAttr(),
758 /*input_to_forget_intermediate=*/mlir::TypeAttr(),
759 /*input_to_cell_intermediate=*/mlir::TypeAttr(),
760 /*input_to_output_intermediate=*/mlir::TypeAttr(),
761 /*effective_hidden_scale_intermediate=*/mlir::TypeAttr());
762
763 auto final_output_full_sequences = lstm.getResult();
764
765 // Populate the last output: last output is sliced from the full sequences.
766 // If time_major: last_output = outputs[-1, :, :]
767 // else: last_output = outputs[:, -1, :]
768 //
769 // As we are creating the strided_slice op, we need to populate the following
770 // fields:
771 // end: should always be (0, 0, 0)
772 // strides: should always be (1, 1, 1)
773 // begin: should be (0, -1, 0) or (-1, 0, 0) if it's time-majored.
774 // new_axis_mask: should always be 0.
775 // ellipsis_mask: should always be 0.
776 // begin_mask & end_mask: should be 0b101 = 5 or 0b110 = 6 if it's
777 // time-majored.
778 // shrink_axis_mask: should be 0b010 = 2 or 0b001 = 1 if it's time-majored.
779 SmallVector<int64_t, 2> last_output_shape({batch, n_output});
780
781 SmallVector<int32_t, 3> end({0, 0, 0});
782 SmallVector<int32_t, 3> strides({1, 1, 1});
783 SmallVector<int32_t, 3> begin;
784
785 int64_t new_axis_mask = 0;
786 int64_t ellipsis_mask = 0;
787 int64_t begin_mask;
788 int64_t end_mask;
789 int64_t shrink_axis_mask;
790 if (time_majored) {
791 begin_mask = 6;
792 end_mask = 6;
793 shrink_axis_mask = 1;
794 begin = {-1, 0, 0};
795 } else {
796 begin_mask = 5;
797 end_mask = 5;
798 shrink_axis_mask = 2;
799 begin = {0, -1, 0};
800 }
801
802 auto last_output = CreateStridedSliceOp(
803 func_op.getLoc(), last_output_shape, final_output_full_sequences, begin,
804 end, strides, begin_mask, end_mask, ellipsis_mask, new_axis_mask,
805 shrink_axis_mask, builder);
806
807 SmallVector<Value, 5> outputs;
808 SmallVector<Type, 5> output_types;
809
810 // Due to the existence of the while loop, the number of time steps may be
811 // unknown in the function signature; since we know the inputs here, we can
812 // infer the time steps.
813
814 // Last output.
815 outputs.push_back(last_output);
816 output_types.push_back(last_output.getType());
817
818 // Full sequences.
819 outputs.push_back(final_output_full_sequences);
820 output_types.push_back(final_output_full_sequences.getType());
821
822 // All the rest: states, device.
823 for (int i = 2; i < 5; ++i) {
824 auto result_type =
825 func_op.getCallableResults()[i].dyn_cast<RankedTensorType>();
826 outputs.push_back(CreatTfF32ConstOp(builder, result_type.getShape(), 0.0f,
827 func_op.getLoc()));
828 output_types.push_back(result_type);
829 }
830
831 // Update function signatures.
832 func_op.setType(mlir::FunctionType::get(func_op.getContext(),
833 func_op.getFunctionType().getInputs(),
834 output_types));
835
836 builder->create<mlir::func::ReturnOp>(func_op.getLoc(), outputs);
837 return success();
838 }
839
840 } // namespace TFL
841 } // namespace mlir
842