/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// This header file defines common utils used by TFLite transformation
// passes to convert composite LSTM ops into fused TFL LSTM ops.

#ifndef TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_
#define TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_

#include "llvm/ADT/StringRef.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Location.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"

namespace mlir {
namespace TFL {

constexpr char kTFImplements[] = "tf._implements";
constexpr char kLstmCellSimple[] = "LSTMCellSimple";
constexpr char kLayerNormalizedLstmCellSimple[] =
    "LayerNormalizedLstmCellSimple";
constexpr char kCoupleInputForgetGates[] = "CoupleInputForgetGates";

// A utility class that enables the conversion of the LSTMCellSimple composite
// op into a fused TFL LSTM op. The fused op is contained within a FuncOp
// that also contains other supporting ops needed to construct the operands
// for the fused op. The caller provides the containing FuncOp as input with
// arguments specifying the input, weight, projection and bias.
// The weight, projection and bias all need to be RankedTensorType.
// This class sets the layer norm coefficients to NoneType.
class ConvertLSTMCellSimpleToFusedLSTM {
 public:
  explicit ConvertLSTMCellSimpleToFusedLSTM(mlir::func::FuncOp fused_func_op)
      : fused_func_op_(fused_func_op),
        couple_input_forget_gates_(false),
        builder_(fused_func_op.getBody()) {}

  // not copyable.
  ConvertLSTMCellSimpleToFusedLSTM(const ConvertLSTMCellSimpleToFusedLSTM&) =
      delete;
  ConvertLSTMCellSimpleToFusedLSTM& operator=(
      const ConvertLSTMCellSimpleToFusedLSTM&) = delete;
  virtual ~ConvertLSTMCellSimpleToFusedLSTM() {}

  virtual llvm::StringRef GetCompositeOpName() { return kLstmCellSimple; }

  // Rewrites the func body with the constructed fused LSTM.
  LogicalResult RewriteFunc();

  int GetNumInputs() { return n_input_; }

 protected:
  // Verifies the input func op arguments/attributes and initializes internal
  // state.
  virtual LogicalResult InitializeFromFuncAttributes();
  virtual LogicalResult Initialize();

  void UpdateFuncSignature();
  void GenerateFusedOpOperands();

  void SetWeightForInputToCellGate();
  void SetWeightForInputToInputGate();
  void SetWeightForInputToForgetGate();
  void SetWeightForInputToOutputGate();

  void SetWeightForRecurrentToCellGate();
  void SetWeightForRecurrentToInputGate();
  void SetWeightForRecurrentToForgetGate();
  void SetWeightForRecurrentToOutputGate();

  void SetBiasToCellGate();
  void SetBiasToInputGate();
  void SetBiasToForgetGate();
  void SetBiasToOutputGate();

  void SetProjection();
  void SetProjectionBias();

  void SetInputActivationState();
  void SetInputCellState();

  virtual void SetCellLayerNormCoefficients();
  virtual void SetInputLayerNormCoefficients();
  virtual void SetForgetLayerNormCoefficients();
  virtual void SetOutputLayerNormCoefficients();

  // specified state
  func::FuncOp fused_func_op_;
  Value input_;
  Value weight_;
  Value bias_;
  Value projection_;
  bool couple_input_forget_gates_;

  // internal state
  Value weight_transposed_;
  Value projection_transposed_;
  RankedTensorType weight_type_;
  RankedTensorType projection_type_;
  int num_gates_;
  int n_cell_;
  int n_output_;
  int n_input_;
  int num_cols_weight_transposed_;
  int num_cols_projection_transposed_;

  // input -> cifg
  Value input2input_;
  Value input2forget_;
  Value input2cell_;
  Value input2output_;

  // recurrent -> cifg
  Value rec2input_;
  Value rec2forget_;
  Value rec2cell_;
  Value rec2output_;

  // bias -> cifg
  Value bias2input_;
  Value bias2forget_;
  Value bias2cell_;
  Value bias2output_;

  // projection
  Value proj_weight_;
  Value proj_bias_;

  // state
  Value input_activation_state_;
  Value input_cell_state_;

  // layer norm coefficients
  Value input_layer_norm_coefficients_;
  Value forget_layer_norm_coefficients_;
  Value cell_layer_norm_coefficients_;
  Value output_layer_norm_coefficients_;

  mlir::TFL::LSTMOp lstm_;

  Value none_;
  SmallVector<int64_t, 1> bias_slice_shape_;
  SmallVector<int64_t, 1> bias_size_values_;
  SmallVector<int64_t, 2> weight_slice_shape_;
  SmallVector<int64_t, 2> weight_slice_size_input_values_;
  SmallVector<int64_t, 2> weight_slice_size_recurrent_values_;
  OpBuilder builder_;
};
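// Example usage (a minimal sketch, not part of this header): the converter is
// given the composite FuncOp, typically found by walking a module and matching
// the tf._implements attribute. `module` below is an assumed ModuleOp.
//
//   module.walk([](mlir::func::FuncOp func) {
//     auto attr = func->getAttrOfType<mlir::StringAttr>(kTFImplements);
//     if (!attr || attr.getValue() != kLstmCellSimple) return;
//     ConvertLSTMCellSimpleToFusedLSTM converter(func);
//     if (failed(converter.RewriteFunc()))
//       func.emitError("failed to fuse LSTMCellSimple composite op");
//   });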
// A utility class that enables the conversion of the
// LayerNormalizedLSTMCellSimple composite op into a fused TFL LSTM op. The
// fused op is contained within a FuncOp that also contains other supporting
// ops needed to construct the operands for the fused op. The caller provides
// the containing FuncOp as input with arguments specifying the input, weight,
// projection, bias and layer norm scale. The weight, projection, bias and
// layer norm scale all need to be RankedTensorType.
// This class overrides the layer norm coefficient setters from the base class.
class ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM
    : public ConvertLSTMCellSimpleToFusedLSTM {
 public:
  explicit ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(
      mlir::func::FuncOp fused_func_op)
      : ConvertLSTMCellSimpleToFusedLSTM(fused_func_op) {}

  // not copyable.
  ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM(
      const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete;
  ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM& operator=(
      const ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM&) = delete;
  ~ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM() override {}

  llvm::StringRef GetCompositeOpName() override {
    return kLayerNormalizedLstmCellSimple;
  }

 protected:
  LogicalResult Initialize() override;

  void SetCellLayerNormCoefficients() override;
  void SetInputLayerNormCoefficients() override;
  void SetForgetLayerNormCoefficients() override;
  void SetOutputLayerNormCoefficients() override;

 private:
  // specified state
  Value layer_norm_scale_;

  // internal state
  RankedTensorType layer_norm_scale_type_;
  SmallVector<int64_t, 1> layer_norm_slice_shape_;
  SmallVector<int64_t, 1> layer_norm_size_values_;
};
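// The layer normalized variant is used the same way; a caller would typically
// dispatch on the tf._implements attribute value (sketch only, with `attr` and
// `func` as in the example above):
//
//   if (attr.getValue() == kLayerNormalizedLstmCellSimple) {
//     ConvertLayerNormalizedLSTMCellSimpleToFusedLSTM converter(func);
//     if (failed(converter.RewriteFunc()))
//       func.emitError("failed to fuse LayerNormalizedLstmCellSimple op");
//   }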
// Converts the Keras LSTM composite function `func_op` into fused TFL LSTM
// ops, using `builder` to create the replacement ops.
LogicalResult ConvertKerasLSTMLayer(mlir::func::FuncOp func_op,
                                    OpBuilder* builder);

}  // end namespace TFL
}  // end namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_LITE_UTILS_LSTM_UTILS_H_