xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/mlir/lite/utils/lstm_utils.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
// This header file defines utilities used by TFLite transformation passes to
// convert composite LSTM ops into fused TFL LSTM ops.
18 
19 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_
20 #define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_
21 
22 #include "llvm/ADT/StringRef.h"
23 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
24 #include "mlir/IR/Builders.h"  // from @llvm-project
25 #include "mlir/IR/BuiltinOps.h"  // from @llvm-project
26 #include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
27 #include "mlir/IR/Location.h"  // from @llvm-project
28 #include "mlir/IR/Value.h"  // from @llvm-project
29 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
30 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
31 
32 namespace mlir {
33 namespace TFL {
34 
// String constants shared by the LSTM conversion utilities.
//
// They are declared `inline` so that every translation unit including this
// header refers to the same object. The inline member functions in this header
// that return these names (GetCompositeOpName) then have the same definition
// in every translation unit, instead of each one referring to its own
// internal-linkage copy.

// NOTE(review): presumably the name of the function attribute that marks a
// composite op implementation -- confirm against the users of this constant.
inline constexpr char kTFImplements[] = "tf._implements";

// Composite op name reported by ConvertLSTMCellSimpleToFusedLSTM.
inline constexpr char kLstmCellSimple[] = "LSTMCellSimple";

// Composite op name reported by
// ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM.
inline constexpr char kLayerNormalizedLstmCellSimple[] =
    "LayerNormalizedLstmCellSimple";

// NOTE(review): presumably the attribute name that controls
// couple_input_forget_gates_ -- confirm against the implementation file.
inline constexpr char kCoupleInputForgetGates[] = "CoupleInputForgetGates";
40 
41 // A utility class that enables the conversion of the LSTMCellSimple composite
42 // op into a fused TFL LSTM op. The fused op is contained within a FuncOp
43 // that also contains other supporting ops needed to construct the operands for
44 // the fused op. The caller provides the containing FuncOp as input with
45 // arguments specifying the input, weight, projection and bias.
46 // The weight, projection, bias and layer norm scale all need to be
47 // RankedTensorType.
48 // This class sets the layer norm coefficients to NoneType.
49 class ConvertLSTMCellSimpleToFusedLSTM {
50  public:
ConvertLSTMCellSimpleToFusedLSTM(mlir::func::FuncOp fused_func_op)51   explicit ConvertLSTMCellSimpleToFusedLSTM(mlir::func::FuncOp fused_func_op)
52       : fused_func_op_(fused_func_op),
53         couple_input_forget_gates_(false),
54         builder_(fused_func_op.getBody()) {}
55 
56   // not copyable.
57   ConvertLSTMCellSimpleToFusedLSTM(const ConvertLSTMCellSimpleToFusedLSTM&) =
58       delete;
59   ConvertLSTMCellSimpleToFusedLSTM& operator=(
60       const ConvertLSTMCellSimpleToFusedLSTM&) = delete;
~ConvertLSTMCellSimpleToFusedLSTM()61   virtual ~ConvertLSTMCellSimpleToFusedLSTM() {}
62 
GetCompositeOpName()63   virtual llvm::StringRef GetCompositeOpName() { return kLstmCellSimple; }
64 
65   // Rewrite the func body with constructed fused lstm.
66   LogicalResult RewriteFunc();
67 
GetNumInputs()68   int GetNumInputs() { return n_input_; }
69 
70  protected:
71   // verify input func op arguments/attributes and initialize internal state.
72   virtual LogicalResult InitializeFromFuncAttributes();
73   virtual LogicalResult Initialize();
74 
75   void UpdateFuncSignature();
76   void GenerateFusedOpOperands();
77 
78   void SetWeightForInputToCellGate();
79   void SetWeightForInputToInputGate();
80   void SetWeightForInputToForgetGate();
81   void SetWeightForInputToOutputGate();
82 
83   void SetWeightForRecurrentToCellGate();
84   void SetWeightForRecurrentToInputGate();
85   void SetWeightForRecurrentToForgetGate();
86   void SetWeightForRecurrentToOutputGate();
87 
88   void SetBiasToCellGate();
89   void SetBiasToInputGate();
90   void SetBiasToForgetGate();
91   void SetBiasToOutputGate();
92 
93   void SetProjection();
94   void SetProjectionBias();
95 
96   void SetInputActivationState();
97   void SetInputCellState();
98 
99   virtual void SetCellLayerNormCoefficients();
100   virtual void SetInputLayerNormCoefficients();
101   virtual void SetForgetLayerNormCoefficients();
102   virtual void SetOutputLayerNormCoefficients();
103 
104   // specified state
105   func::FuncOp fused_func_op_;
106   Value input_;
107   Value weight_;
108   Value bias_;
109   Value projection_;
110   bool couple_input_forget_gates_;
111 
112   // internal state
113   Value weight_transposed_;
114   Value projection_transposed_;
115   RankedTensorType weight_type_;
116   RankedTensorType projection_type_;
117   int num_gates_;
118   int n_cell_;
119   int n_output_;
120   int n_input_;
121   int num_cols_weight_transposed_;
122   int num_cols_projection_transposed_;
123 
124   // input -> cifg
125   Value input2input_;
126   Value input2forget_;
127   Value input2cell_;
128   Value input2output_;
129 
130   // recurrent -> cifg
131   Value rec2input_;
132   Value rec2forget_;
133   Value rec2cell_;
134   Value rec2output_;
135 
136   // bias -> cifg
137   Value bias2input_;
138   Value bias2forget_;
139   Value bias2cell_;
140   Value bias2output_;
141 
142   // projection
143   Value proj_weight_;
144   Value proj_bias_;
145 
146   // state
147   Value input_activation_state_;
148   Value input_cell_state_;
149 
150   // layer norm coefficients
151   Value input_layer_norm_coefficients_;
152   Value forget_layer_norm_coefficients_;
153   Value cell_layer_norm_coefficients_;
154   Value output_layer_norm_coefficients_;
155 
156   mlir::TFL::LSTMOp lstm_;
157 
158   Value none_;
159   SmallVector<int64_t, 1> bias_slice_shape_;
160   SmallVector<int64_t, 1> bias_size_values_;
161   SmallVector<int64_t, 2> weight_slice_shape_;
162   SmallVector<int64_t, 2> weight_slice_size_input_values_;
163   SmallVector<int64_t, 2> weight_slice_size_recurrent_values_;
164   OpBuilder builder_;
165 };
166 
167 // A utility class that enables the conversion of the
168 // LayerNormalizedLSTMCellSimple composite op into a fused TFL LSTM op. The
169 // fused op is contained within a FuncOp that also contains other supporting ops
170 // needed to construct the operands for the fused op. The caller provides the
171 // containing FuncOp as input with arguments specifying the input, weight,
172 // projection, bias and layer norm scale. The weight, projection, bias and
173 // layer norm scale all need to be RankedTensorType.
174 // This class overrides the layer norm coefficient setters from the base class.
175 class ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM
176     : public ConvertLSTMCellSimpleToFusedLSTM {
177  public:
ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(mlir::func::FuncOp fused_func_op)178   explicit ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(
179       mlir::func::FuncOp fused_func_op)
180       : ConvertLSTMCellSimpleToFusedLSTM(fused_func_op) {}
181 
182   // not copyable.
183   ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(
184       const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete;
185   ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM& operator=(
186       const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete;
~ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM()187   ~ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM() override {}
188 
GetCompositeOpName()189   llvm::StringRef GetCompositeOpName() override {
190     return kLayerNormalizedLstmCellSimple;
191   }
192 
193  protected:
194   LogicalResult Initialize() override;
195 
196   void SetCellLayerNormCoefficients() override;
197   void SetInputLayerNormCoefficients() override;
198   void SetForgetLayerNormCoefficients() override;
199   void SetOutputLayerNormCoefficients() override;
200 
201  private:
202   // specified state
203   Value layer_norm_scale_;
204 
205   // internal state
206   RankedTensorType layer_norm_scale_type_;
207   SmallVector<int64_t, 1> layer_norm_slice_shape_;
208   SmallVector<int64_t, 1> layer_norm_size_values_;
209 };
210 
211 LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op,
212                                     OpBuilder* builder);
213 
214 }  // end namespace TFL
215 }  // end namespace mlir
216 
217 #endif  // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_
218