1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/lite/delegates/gpu/gl/kernels/lstm.h"
17
18 #include <memory>
19 #include <string>
20 #include <utility>
21
22 #include "absl/memory/memory.h"
23 #include "tensorflow/lite/delegates/gpu/common/status.h"
24 #include "tensorflow/lite/delegates/gpu/common/types.h"
25 #include "tensorflow/lite/delegates/gpu/gl/node_shader.h"
26
27 namespace tflite {
28 namespace gpu {
29 namespace gl {
30 namespace {
31
// Basic LSTMCell gates.
//
//  inputs:  0             1
//           activ_temp    prev_state
//                  \      /
//                 [[LSTM gates]]
//                  /      \
//             new_state    activation
//  outputs:   0            1
//
// The size of activ_temp should be 4x the size of new_state along z: it holds
// the four gate pre-activations (input, new, forget, output) stacked one
// after another.
// The sizes of prev_state, new_state and activation are all equal.
//
45 class LstmNodeShader : public NodeShader {
46 public:
GenerateCode(const GenerationContext & ctx,GeneratedCode * generated_code) const47 absl::Status GenerateCode(const GenerationContext& ctx,
48 GeneratedCode* generated_code) const final {
49 std::string code = R"(
50 vec4 prev_state = $input_data_1[gid.x, gid.y, gid.z]$;
51
52 int c0 = 0 * $workload_z$;
53 int c1 = 1 * $workload_z$;
54 int c2 = 2 * $workload_z$;
55 int c3 = 3 * $workload_z$;
56
57 // input, new, forget, output
58 vec4 gate_0 = $input_data_0[gid.x, gid.y, gid.z + c0]$;
59 vec4 gate_1 = $input_data_0[gid.x, gid.y, gid.z + c1]$;
60 vec4 gate_2 = $input_data_0[gid.x, gid.y, gid.z + c2]$;
61 vec4 gate_3 = $input_data_0[gid.x, gid.y, gid.z + c3]$;
62
63 vec4 input_gate = 1.0f / (1.0f + exp(-1.0 * gate_0)); // sig(x)
64 vec4 new_input = tanh(gate_1); // tanh(x)
65 vec4 forget_gate = 1.0f / (1.0f + exp(-1.0 * gate_2)); // sig(x)
66 vec4 output_gate = 1.0f / (1.0f + exp(-1.0 * gate_3)); // sig(x)
67
68 vec4 new_state = input_gate * new_input + forget_gate * prev_state;
69 vec4 activation = output_gate * tanh(new_state);
70
71 value_0 = new_state;
72 value_1 = activation;
73 )";
74 *generated_code = {
75 /*parameters=*/{},
76 /*objects=*/{},
77 /*shared_variables=*/{},
78 /*workload=*/uint3(),
79 /*workgroup=*/uint3(),
80 /*source_code=*/std::move(code),
81 /*input=*/IOStructure::ONLY_DEFINITIONS,
82 /*output=*/IOStructure::AUTO,
83 };
84 return absl::OkStatus();
85 }
86 };
87
88 } // namespace
89
NewLstmNodeShader()90 std::unique_ptr<NodeShader> NewLstmNodeShader() {
91 return std::make_unique<LstmNodeShader>();
92 }
93
94 } // namespace gl
95 } // namespace gpu
96 } // namespace tflite
97