/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <limits>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/gru_cell.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"

// Unidirectional_sequence_gru is the fused version of GRU:
// https://www.tensorflow.org/api_docs/python/tf/keras/layers/GRU.
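//
// Per time step the kernel applies a Keras-style GRU recurrence. The sketch
// below is for orientation only; the exact gate ordering, concatenation
// layout, and placement of the reset gate are defined by GruCell in
// gru_cell.h:
//   [r_t, z_t] = sigmoid([h_{t-1}, x_t] x gate_weight + gate_bias)
//   h~_t       = tanh([r_t .* h_{t-1}, x_t] x candidate_weight + candidate_bias)
//   h_t        = z_t .* h_{t-1} + (1 - z_t) .* h~_t
// where r_t is the reset gate, z_t the update gate, and h~_t the candidate
// state; gate_weight packs both gates, which is why it has 2 * n_output rows.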
namespace tflite {
namespace ops {
namespace custom {
namespace unidirectional_sequence_gru {
namespace {

void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
             const TfLiteTensor* gate_weight, const TfLiteTensor* gate_bias,
             const TfLiteTensor* candidate_weight,
             const TfLiteTensor* candidate_bias, TfLiteTensor* output,
             TfLiteTensor* output_state, TfLiteTensor* activation,
             TfLiteTensor* concat,
             tflite::CpuBackendContext* cpu_backend_context) {
  const int n_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];
  const int n_output = output->dims->data[2];
  const int n_batch_input = n_batch * n_input;
  const int n_batch_output = n_batch * n_output;
  const RuntimeShape input_shape({n_batch, n_input});
  const float* input_data = GetTensorData<float>(input);
  const RuntimeShape state_shape = GetTensorShape(input_state);
  const float* input_state_data = GetTensorData<float>(input_state);
  const RuntimeShape gate_weight_shape = GetTensorShape(gate_weight);
  const float* gate_weight_data = GetTensorData<float>(gate_weight);
  const RuntimeShape gate_bias_shape = GetTensorShape(gate_bias);
  const float* gate_bias_data = GetTensorData<float>(gate_bias);
  const RuntimeShape candidate_weight_shape = GetTensorShape(candidate_weight);
  const float* candidate_weight_data = GetTensorData<float>(candidate_weight);
  const RuntimeShape candidate_bias_shape = GetTensorShape(candidate_bias);
  const float* candidate_bias_data = GetTensorData<float>(candidate_bias);
  const RuntimeShape activation_shape = GetTensorShape(activation);
  const RuntimeShape output_shape = RuntimeShape({n_batch, n_output});
  float* output_data = GetTensorData<float>(output);
  float* output_state_data = GetTensorData<float>(output_state);
  float* activation_data = GetTensorData<float>(activation);
  const RuntimeShape concat_shape = GetTensorShape(concat);
  float* concat_data = GetTensorData<float>(concat);
  tflite::FullyConnectedParams fc_params;
  fc_params.float_activation_min = std::numeric_limits<float>::lowest();
  fc_params.float_activation_max = std::numeric_limits<float>::max();

  // The lhs is cacheable only when both the gate weight and the candidate
  // weight are constant tensors.
  fc_params.lhs_cacheable =
      IsConstantTensor(gate_weight) && IsConstantTensor(candidate_weight);
  fc_params.rhs_cacheable = false;
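  // Unroll over time: each iteration runs one GRU cell step on a
  // [n_batch, n_input] slice of the input and writes a [n_batch, n_output]
  // slice of the output; the state written in step t is fed back as the
  // input state of step t + 1.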
  for (int i = 0; i < n_time; ++i) {
    gru_cell::GruCell(
        input_shape, input_data, state_shape, input_state_data,
        gate_weight_shape, gate_weight_data, gate_bias_shape, gate_bias_data,
        candidate_weight_shape, candidate_weight_data, candidate_bias_shape,
        candidate_bias_data, output_shape, output_data, output_state_data,
        activation_shape, activation_data, concat_shape, concat_data, fc_params,
        cpu_backend_context);
    input_data += n_batch_input;
    output_data += n_batch_output;
    input_state_data = output_state_data;
  }
}

}  // namespace

enum InputTensor {
  // Input tensor of size [n_time, n_batch, n_input]
  kInput = 0,
  // Input state tensor of size [n_batch, n_output]
  kInputState = 1,
  // Gate weight tensor of size [2*n_output, n_input+n_output]
  kGateWeight = 2,
  // Gate bias tensor of size [2*n_output]
  kGateBias = 3,
  // Candidate weight tensor of size [n_output, n_input+n_output]
  kCandidateWeight = 4,
  // Candidate bias tensor of size [n_output]
  kCandidateBias = 5,
  kInputNum = 6
};

enum OutputTensor {
  // Output tensor of size [n_time, n_batch, n_output]
  kOutput = 0,
  // Output state tensor of size [n_batch, n_output]
  kOutputState = 1,
  kOutputNum = 2
};

enum TemporaryTensor {
  // Scratch buffer for activation of size [n_batch, 2*n_output]
  kActivation = 0,
  // Scratch buffer for concat of size [n_batch, n_input+n_output]
  kConcat = 1,
  kTemporaryNum = 2
};
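// Note: these two temporaries back the intermediate results of the fully
// connected products inside GruCell: `activation` matches the gate weight's
// output dimension (2 * n_output) and `concat` matches the input dimension
// shared by both weight matrices (n_input + n_output). The exact layout is an
// implementation detail of GruCell.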

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* scratch_tensor_index = new int;
  context->AddTensors(context, kTemporaryNum, scratch_tensor_index);
  return scratch_tensor_index;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<int*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
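  // Init() stored the base index of this op's scratch tensors in user_data;
  // the temporaries below are addressed relative to that index.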
  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, node->inputs->size, kInputNum);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, kOutputNum);

  // input's dim = [n_time, n_batch, n_input]
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const int n_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];

  // input_state's dim = [n_batch, n_output]
  const TfLiteTensor* input_state;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputState, &input_state));
  TF_LITE_ENSURE_EQ(context, input_state->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_state->dims->data[0], n_batch);
  const int n_output = input_state->dims->data[1];

  // gate_weight's dim = [2 * n_output, n_input + n_output]
  const TfLiteTensor* gate_weight;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateWeight, &gate_weight));
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[0], 2 * n_output);
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[1], n_input + n_output);

  // gate_bias's dim = [2 * n_output]
  const TfLiteTensor* gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateBias, &gate_bias));
  TF_LITE_ENSURE_EQ(context, gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, gate_bias->dims->data[0], 2 * n_output);

  // candidate_weight's dim = [n_output, n_input + n_output]
  const TfLiteTensor* candidate_weight;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCandidateWeight,
                                          &candidate_weight));
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[0], n_output);
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[1],
                    n_input + n_output);

  // candidate_bias's dim = [n_output]
  const TfLiteTensor* candidate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kCandidateBias, &candidate_bias));
  TF_LITE_ENSURE_EQ(context, candidate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, candidate_bias->dims->data[0], n_output);

  // output's dim = [n_time, n_batch, n_output]
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutput, &output));
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
  output_size->data[0] = n_time;
  output_size->data[1] = n_batch;
  output_size->data[2] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  // output_state's dim = [n_batch, n_output]
  TfLiteTensor* output_state;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputState, &output_state));
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, output_state,
                                     TfLiteIntArrayCopy(input_state->dims)));

  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(kTemporaryNum);

  // activation's dim = [n_batch, 2 * n_output]
  node->temporaries->data[kActivation] = *scratch_tensor_index;
  TfLiteTensor* activation;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, kActivation, &activation));
  activation->type = input->type;
  activation->allocation_type = kTfLiteArenaRw;
  TfLiteIntArray* activation_size = TfLiteIntArrayCreate(2);
  activation_size->data[0] = n_batch;
  activation_size->data[1] = 2 * n_output;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, activation, activation_size));

  // concat's dim = [n_batch, n_input + n_output]
  node->temporaries->data[kConcat] = (*scratch_tensor_index) + kConcat;
  TfLiteTensor* concat;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kConcat, &concat));
  concat->type = input->type;
  concat->allocation_type = kTfLiteArenaRw;
  TfLiteIntArray* concat_size = TfLiteIntArrayCreate(2);
  concat_size->data[0] = n_batch;
  concat_size->data[1] = n_input + n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, concat, concat_size));

  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
  const TfLiteTensor* input_state;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputState, &input_state));
  const TfLiteTensor* gate_weight;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateWeight, &gate_weight));
  const TfLiteTensor* gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateBias, &gate_bias));
  const TfLiteTensor* candidate_weight;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCandidateWeight,
                                          &candidate_weight));
  const TfLiteTensor* candidate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kCandidateBias, &candidate_bias));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutput, &output));
  TfLiteTensor* output_state;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputState, &output_state));
  TfLiteTensor* activation;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, kActivation, &activation));
  TfLiteTensor* concat;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kConcat, &concat));
  auto cpu_backend_context = CpuBackendContext::GetFromContext(context);

  if (gate_weight->type == kTfLiteFloat32) {
    GruImpl(input, input_state, gate_weight, gate_bias, candidate_weight,
            candidate_bias, output, output_state, activation, concat,
            cpu_backend_context);
  } else {
    TF_LITE_KERNEL_LOG(context,
                       "Unsupported combination of data types for GruCell");
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace unidirectional_sequence_gru

TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_GRU() {
  static TfLiteRegistration r = {
      unidirectional_sequence_gru::Init, unidirectional_sequence_gru::Free,
      unidirectional_sequence_gru::Prepare, unidirectional_sequence_gru::Eval};
  return &r;
}
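// A minimal sketch of how a client might expose this custom op to an
// interpreter. The op name string is an assumption and must match the name
// recorded in the flatbuffer when the model was converted:
//   tflite::MutableOpResolver resolver;
//   resolver.AddCustom("UNIDIRECTIONAL_SEQUENCE_GRU",
//                      tflite::ops::custom::Register_UNIDIRECTIONAL_SEQUENCE_GRU());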

}  // namespace custom
}  // namespace ops
}  // namespace tflite