/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <limits>

#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/kernels/cpu_backend_context.h"
#include "tensorflow/lite/kernels/gru_cell.h"
#include "tensorflow/lite/kernels/internal/tensor.h"
#include "tensorflow/lite/kernels/kernel_util.h"

// Unidirectional_sequence_gru is the fused version of GRU:
// https://www.tensorflow.org/api_docs/python/tf/keras/layers/GRU.
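//
// Per time step the cell computes, in the standard Keras-style GRU
// formulation (the exact fused gate layout lives in gru_cell.h):
//   [r, u] = sigmoid([x, h] * gate_weight^T + gate_bias)
//   c      = tanh([x, r .* h] * candidate_weight^T + candidate_bias)
//   h'     = u .* h + (1 - u) .* c
// where x is the input, h the state, r the reset gate, u the update gate,
// and c the candidate activation.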
namespace tflite {
namespace ops {
namespace custom {
namespace unidirectional_sequence_gru {
namespace {

void GruImpl(const TfLiteTensor* input, const TfLiteTensor* input_state,
             const TfLiteTensor* gate_weight, const TfLiteTensor* gate_bias,
             const TfLiteTensor* candidate_weight,
             const TfLiteTensor* candidate_bias, TfLiteTensor* output,
             TfLiteTensor* output_state, TfLiteTensor* activation,
             TfLiteTensor* concat,
             tflite::CpuBackendContext* cpu_backend_context) {
  const int n_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];
  const int n_output = output->dims->data[2];
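  // The sequence tensors are time-major, so one time step is a contiguous
  // [n_batch, n_input] (or [n_batch, n_output]) slice; the strides below
  // advance the data pointers one step at a time.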
  const int n_batch_input = n_batch * n_input;
  const int n_batch_output = n_batch * n_output;
  const RuntimeShape input_shape({n_batch, n_input});
  const float* input_data = GetTensorData<float>(input);
  const RuntimeShape state_shape = GetTensorShape(input_state);
  const float* input_state_data = GetTensorData<float>(input_state);
  const RuntimeShape gate_weight_shape = GetTensorShape(gate_weight);
  const float* gate_weight_data = GetTensorData<float>(gate_weight);
  const RuntimeShape gate_bias_shape = GetTensorShape(gate_bias);
  const float* gate_bias_data = GetTensorData<float>(gate_bias);
  const RuntimeShape candidate_weight_shape = GetTensorShape(candidate_weight);
  const float* candidate_weight_data = GetTensorData<float>(candidate_weight);
  const RuntimeShape candidate_bias_shape = GetTensorShape(candidate_bias);
  const float* candidate_bias_data = GetTensorData<float>(candidate_bias);
  const RuntimeShape activation_shape = GetTensorShape(activation);
  const RuntimeShape output_shape = RuntimeShape({n_batch, n_output});
  float* output_data = GetTensorData<float>(output);
  float* output_state_data = GetTensorData<float>(output_state);
  float* activation_data = GetTensorData<float>(activation);
  const RuntimeShape concat_shape = GetTensorShape(concat);
  float* concat_data = GetTensorData<float>(concat);
  tflite::FullyConnectedParams fc_params;
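  // No fused activation on the fully connected outputs: clamp to the full
  // float range (equivalent to kTfLiteActNone).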
  fc_params.float_activation_min = std::numeric_limits<float>::lowest();
  fc_params.float_activation_max = std::numeric_limits<float>::max();

  // The LHS is cacheable only when both the gate weight and the candidate
  // weight are constant tensors.
  fc_params.lhs_cacheable =
      IsConstantTensor(gate_weight) && IsConstantTensor(candidate_weight);
  fc_params.rhs_cacheable = false;
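  // Unroll over time: each iteration runs the cell on one [n_batch, n_input]
  // slice, writes one [n_batch, n_output] slice of the output, and feeds the
  // freshly written state back in as the next step's input state.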
  for (int i = 0; i < n_time; ++i) {
    gru_cell::GruCell(
        input_shape, input_data, state_shape, input_state_data,
        gate_weight_shape, gate_weight_data, gate_bias_shape, gate_bias_data,
        candidate_weight_shape, candidate_weight_data, candidate_bias_shape,
        candidate_bias_data, output_shape, output_data, output_state_data,
        activation_shape, activation_data, concat_shape, concat_data, fc_params,
        cpu_backend_context);
    input_data += n_batch_input;
    output_data += n_batch_output;
    input_state_data = output_state_data;
  }
}

}  // namespace

enum InputTensor {
  // Input tensor of size [n_time, n_batch, n_input]
  kInput = 0,
  // Input state tensor of size [n_batch, n_output]
  kInputState = 1,
  // Gate weight tensor of size [2*n_output, n_input+n_output]
  kGateWeight = 2,
  // Gate bias tensor of size [2*n_output]
  kGateBias = 3,
  // Candidate weight tensor of size [n_output, n_input+n_output]
  kCandidateWeight = 4,
  // Candidate bias tensor of size [n_output]
  kCandidateBias = 5,
  kInputNum = 6
};

enum OutputTensor {
  // Output tensor of size [n_time, n_batch, n_output]
  kOutput = 0,
  // Output state tensor of size [n_batch, n_output]
  kOutputState = 1,
  kOutputNum = 2
};

enum TemporaryTensor {
  // Scratch buffer for the gate activations, of size [n_batch, 2*n_output]
  kActivation = 0,
  // Scratch buffer for concatenation, of size [n_batch, n_input+n_output]
  kConcat = 1,
  kTemporaryNum = 2
};

void* Init(TfLiteContext* context, const char* buffer, size_t length) {
  auto* scratch_tensor_index = new int;
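  // AddTensors reserves kTemporaryNum consecutive tensor slots and stores the
  // index of the first reserved slot in *scratch_tensor_index; Prepare wires
  // the temporaries to these slots.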
  context->AddTensors(context, kTemporaryNum, scratch_tensor_index);
  return scratch_tensor_index;
}

void Free(TfLiteContext* context, void* buffer) {
  delete reinterpret_cast<int*>(buffer);
}

TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
  int* scratch_tensor_index = reinterpret_cast<int*>(node->user_data);

  TF_LITE_ENSURE_EQ(context, node->inputs->size, kInputNum);
  TF_LITE_ENSURE_EQ(context, node->outputs->size, kOutputNum);

  // input's dim = [n_time, n_batch, n_input]
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
  TF_LITE_ENSURE_EQ(context, input->dims->size, 3);
  const int n_time = input->dims->data[0];
  const int n_batch = input->dims->data[1];
  const int n_input = input->dims->data[2];

  // input_state's dim = [n_batch, n_output]
  const TfLiteTensor* input_state;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputState, &input_state));
  TF_LITE_ENSURE_EQ(context, input_state->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, input_state->dims->data[0], n_batch);
  const int n_output = input_state->dims->data[1];

  // gate_weight's dim = [2 * n_output, n_input + n_output]
  const TfLiteTensor* gate_weight;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateWeight, &gate_weight));
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[0], 2 * n_output);
  TF_LITE_ENSURE_EQ(context, gate_weight->dims->data[1], n_input + n_output);

  // gate_bias's dim = [2 * n_output]
  const TfLiteTensor* gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateBias, &gate_bias));
  TF_LITE_ENSURE_EQ(context, gate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, gate_bias->dims->data[0], 2 * n_output);

  // candidate_weight's dim = [n_output, n_input + n_output]
  const TfLiteTensor* candidate_weight;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCandidateWeight,
                                          &candidate_weight));
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->size, 2);
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[0], n_output);
  TF_LITE_ENSURE_EQ(context, candidate_weight->dims->data[1],
                    n_input + n_output);

  // candidate_bias's dim = [n_output]
  const TfLiteTensor* candidate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kCandidateBias, &candidate_bias));
  TF_LITE_ENSURE_EQ(context, candidate_bias->dims->size, 1);
  TF_LITE_ENSURE_EQ(context, candidate_bias->dims->data[0], n_output);

  // output's dim = [n_time, n_batch, n_output]
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutput, &output));
  TfLiteIntArray* output_size = TfLiteIntArrayCreate(3);
  output_size->data[0] = n_time;
  output_size->data[1] = n_batch;
  output_size->data[2] = n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, output, output_size));

  // output_state's dim = [n_batch, n_output]
  TfLiteTensor* output_state;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputState, &output_state));
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, output_state,
                                     TfLiteIntArrayCopy(input_state->dims)));

  TfLiteIntArrayFree(node->temporaries);
  node->temporaries = TfLiteIntArrayCreate(kTemporaryNum);

  // activation's dim = [n_batch, 2 * n_output]
  node->temporaries->data[kActivation] = *scratch_tensor_index;
  TfLiteTensor* activation;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, kActivation, &activation));
  activation->type = input->type;
  activation->allocation_type = kTfLiteArenaRw;
  TfLiteIntArray* activation_size = TfLiteIntArrayCreate(2);
  activation_size->data[0] = n_batch;
  activation_size->data[1] = 2 * n_output;
  TF_LITE_ENSURE_OK(
      context, context->ResizeTensor(context, activation, activation_size));

  // concat's dim = [n_batch, n_input + n_output]
  node->temporaries->data[kConcat] = (*scratch_tensor_index) + kConcat;
  TfLiteTensor* concat;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kConcat, &concat));
  concat->type = input->type;
  concat->allocation_type = kTfLiteArenaRw;
  TfLiteIntArray* concat_size = TfLiteIntArrayCreate(2);
  concat_size->data[0] = n_batch;
  concat_size->data[1] = n_input + n_output;
  TF_LITE_ENSURE_OK(context,
                    context->ResizeTensor(context, concat, concat_size));

  return kTfLiteOk;
}

TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kInput, &input));
  const TfLiteTensor* input_state;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kInputState, &input_state));
  const TfLiteTensor* gate_weight;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateWeight, &gate_weight));
  const TfLiteTensor* gate_bias;
  TF_LITE_ENSURE_OK(context,
                    GetInputSafe(context, node, kGateBias, &gate_bias));
  const TfLiteTensor* candidate_weight;
  TF_LITE_ENSURE_OK(context, GetInputSafe(context, node, kCandidateWeight,
                                          &candidate_weight));
  const TfLiteTensor* candidate_bias;
  TF_LITE_ENSURE_OK(
      context, GetInputSafe(context, node, kCandidateBias, &candidate_bias));
  TfLiteTensor* output;
  TF_LITE_ENSURE_OK(context, GetOutputSafe(context, node, kOutput, &output));
  TfLiteTensor* output_state;
  TF_LITE_ENSURE_OK(context,
                    GetOutputSafe(context, node, kOutputState, &output_state));
  TfLiteTensor* activation;
  TF_LITE_ENSURE_OK(context,
                    GetTemporarySafe(context, node, kActivation, &activation));
  TfLiteTensor* concat;
  TF_LITE_ENSURE_OK(context, GetTemporarySafe(context, node, kConcat, &concat));
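  // The shared CPU backend context supplies the (possibly multi-threaded)
  // GEMM used by the fully connected ops inside GruCell.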
  auto cpu_backend_context = CpuBackendContext::GetFromContext(context);

  if (gate_weight->type == kTfLiteFloat32) {
    GruImpl(input, input_state, gate_weight, gate_bias, candidate_weight,
            candidate_bias, output, output_state, activation, concat,
            cpu_backend_context);
  } else {
    TF_LITE_KERNEL_LOG(context,
                       "Unsupported combination of data types for GruCell");
    return kTfLiteError;
  }

  return kTfLiteOk;
}

}  // namespace unidirectional_sequence_gru

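// Returns the registration for this custom op. Clients must add it to their
// op resolver before building the interpreter; a minimal sketch (the custom
// op name string is an assumption and must match the one the converter used):
//   tflite::MutableOpResolver resolver;
//   resolver.AddCustom(
//       "UnidirectionalSequenceGRU",
//       tflite::ops::custom::Register_UNIDIRECTIONAL_SEQUENCE_GRU());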
TfLiteRegistration* Register_UNIDIRECTIONAL_SEQUENCE_GRU() {
  static TfLiteRegistration r = {
      unidirectional_sequence_gru::Init, unidirectional_sequence_gru::Free,
      unidirectional_sequence_gru::Prepare, unidirectional_sequence_gru::Eval};
  return &r;
}

}  // namespace custom
}  // namespace ops
}  // namespace tflite