xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/rnn/gru_ops.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_CORE_KERNELS_RNN_GRU_OPS_H_
17 #define TENSORFLOW_CORE_KERNELS_RNN_GRU_OPS_H_
18 
19 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
20 #include "tensorflow/core/framework/tensor_types.h"
21 #include "tensorflow/core/kernels/rnn/blas_gemm.h"
22 #include "tensorflow/core/platform/types.h"
23 
24 namespace tensorflow {
25 
26 class OpKernelContext;
27 
28 namespace functor {
29 
30 struct GRUCell {
GRUCellGRUCell31   GRUCell(const int batch_size, const int input_size, const int cell_size)
32       : batch_size_(batch_size),
33         input_size_(input_size),
34         cell_size_(cell_size) {}
35 
x_offsetsGRUCell36   inline Eigen::array<Eigen::DenseIndex, 2> x_offsets() const { return {0, 0}; }
37 
x_extendsGRUCell38   inline Eigen::array<Eigen::DenseIndex, 2> x_extends() const {
39     return {batch_size_, input_size_};
40   }
41 
h_offsetsGRUCell42   inline Eigen::array<Eigen::DenseIndex, 2> h_offsets() const {
43     return {0, input_size_};
44   }
45 
h_extendsGRUCell46   inline Eigen::array<Eigen::DenseIndex, 2> h_extends() const {
47     return {batch_size_, cell_size_};
48   }
49 
ru_r_offsetGRUCell50   inline Eigen::array<Eigen::DenseIndex, 2> ru_r_offset() const {
51     return {0, 0};
52   }
53 
ru_u_offsetGRUCell54   inline Eigen::array<Eigen::DenseIndex, 2> ru_u_offset() const {
55     return {0, cell_size_};
56   }
57 
cell_extentsGRUCell58   inline Eigen::array<Eigen::DenseIndex, 2> cell_extents() const {
59     return {batch_size_, cell_size_};
60   }
61 
62  protected:
63   const int batch_size_;
64   const int input_size_;
65   const int cell_size_;
66 };
67 
68 template <typename Device, typename T, bool USE_CUBLAS>
69 struct GRUBlockCellFprop : public GRUCell {
GRUBlockCellFpropGRUBlockCellFprop70   GRUBlockCellFprop(const int batch_size, const int input_size,
71                     const int cell_size)
72       : GRUCell(batch_size, input_size, cell_size) {}
73 
operatorGRUBlockCellFprop74   void operator()(
75       OpKernelContext* ctx, const Device& d, typename TTypes<T>::ConstMatrix x,
76       typename TTypes<T>::ConstMatrix h_prev,
77       typename TTypes<T>::ConstMatrix w_ru, typename TTypes<T>::ConstMatrix w_c,
78       typename TTypes<T>::ConstVec b_ru, typename TTypes<T>::ConstVec b_c,
79       typename TTypes<T>::Matrix r_u_bar, typename TTypes<T>::Matrix r,
80       typename TTypes<T>::Matrix u, typename TTypes<T>::Matrix c,
81       typename TTypes<T>::Matrix h, typename TTypes<T>::Matrix x_h_prev,
82       typename TTypes<T>::Matrix x_h_prevr) {
83     // Concat x_h_prev = [x, h_prev].
84     x_h_prev.slice(x_offsets(), x_extends()).device(d) = x;
85     x_h_prev.slice(h_offsets(), h_extends()).device(d) = h_prev;
86 
87     // r_u_bar = x_h_prev * w_ru + b_ru
88     typename TTypes<T>::ConstMatrix const_x_h_prev(x_h_prev.data(),
89                                                    x_h_prev.dimensions());
90     TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
91         ctx, d, false, false, typename gemm_compute_type<T>::type(1.f),
92         const_x_h_prev, w_ru, typename gemm_compute_type<T>::type(0.f),
93         r_u_bar);
94 
95     // Creating a bias matrix for adding by broadcasting 'b_ru'
96     Eigen::array<Eigen::DenseIndex, 2> broadcast_shape({batch_size_, 1});
97     Eigen::array<Eigen::DenseIndex, 2> b_ru_shape({1, b_ru.dimensions()[0]});
98     r_u_bar.device(d) += b_ru.reshape(b_ru_shape).broadcast(broadcast_shape);
99 
100     // Slice r_u_bar into r, u and apply the sigmoid.
101     r.device(d) = (r_u_bar.slice(ru_r_offset(), cell_extents())).sigmoid();
102     u.device(d) = (r_u_bar.slice(ru_u_offset(), cell_extents())).sigmoid();
103 
104     // Concat x_h_prevr = [x,h_prev*r]
105     x_h_prevr.slice(x_offsets(), x_extends()).device(d) = x;
106     x_h_prevr.slice(h_offsets(), h_extends()).device(d) = h_prev * r;
107 
108     // c = tanh(x_h_prevr*w_c+b_c), Note b_c is broadcasted before adding.
109     typename TTypes<T>::ConstMatrix const_x_h_prevr(x_h_prevr.data(),
110                                                     x_h_prevr.dimensions());
111     TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
112         ctx, d, false, false, typename gemm_compute_type<T>::type(1.f),
113         const_x_h_prevr, w_c, typename gemm_compute_type<T>::type(0.f), c);
114 
115     Eigen::array<Eigen::DenseIndex, 2> b_c_shape({1, b_c.dimensions()[0]});
116     c.device(d) += (b_c.reshape(b_c_shape).broadcast(broadcast_shape));
117     c.device(d) = c.tanh();
118 
119     // h= u*h_prev + (1-u)*c
120     h.device(d) = u * (h_prev - c) + c;
121   }
122 };
123 
124 template <typename Device, typename T, bool USE_CUBLAS>
125 struct GRUBlockCellBprop : public GRUCell {
GRUBlockCellBpropGRUBlockCellBprop126   GRUBlockCellBprop(const int batch_size, const int input_size,
127                     const int cell_size)
128       : GRUCell(batch_size, input_size, cell_size) {}
129 
operatorGRUBlockCellBprop130   void operator()(
131       OpKernelContext* ctx, const Device& d, typename TTypes<T>::ConstMatrix x,
132       typename TTypes<T>::ConstMatrix h_prev,
133       typename TTypes<T>::ConstMatrix w_ru, typename TTypes<T>::ConstMatrix w_c,
134       typename TTypes<T>::ConstVec b_ru, typename TTypes<T>::ConstVec b_c,
135       typename TTypes<T>::ConstMatrix r, typename TTypes<T>::ConstMatrix u,
136       typename TTypes<T>::ConstMatrix c, typename TTypes<T>::ConstMatrix d_h,
137       typename TTypes<T>::Matrix d_x, typename TTypes<T>::Matrix d_h_prev,
138       typename TTypes<T>::Matrix d_c_bar,
139       typename TTypes<T>::Matrix d_r_bar_u_bar,
140       typename TTypes<T>::Matrix d_r_bar, typename TTypes<T>::Matrix d_u_bar,
141       typename TTypes<T>::Matrix d_hr,
142       typename TTypes<T>::Matrix d_x_comp1_and_h_prev_comp1,
143       typename TTypes<T>::Matrix d_x_comp2_and_h_prevr) {
144     // d_c_bar = d_h*(1-u)*(1-(c*c))
145     d_c_bar.device(d) =
146         ((d_h * (u.constant(T(1)) - u)) * (c.constant(T(1)) - c * c));
147 
148     // d_u_bar = d_h*(h-c)*(u*(1-u))
149     d_u_bar.device(d) = d_h * (h_prev - c) * u * (u.constant(T(1)) - u);
150 
151     // [2nd_component_of_d_x d_h_prevr] = d_c_bar X w_c^T
152     typename TTypes<T>::ConstMatrix const_d_c_bar(d_c_bar.data(),
153                                                   d_c_bar.dimensions());
154     TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
155         ctx, d, false, true, typename gemm_compute_type<T>::type(1.f),
156         const_d_c_bar, w_c, typename gemm_compute_type<T>::type(0.f),
157         d_x_comp2_and_h_prevr);
158 
159     d_hr.device(d) = d_x_comp2_and_h_prevr.slice(h_offsets(), h_extends());
160     d_r_bar.device(d) = (d_hr * h_prev * r) * (r.constant(T(1)) - r);
161 
162     // d_r_bar_u_bar = concatenate(d_r_bar, d_u_bar) along axis = 1.
163     d_r_bar_u_bar.slice(ru_r_offset(), cell_extents()).device(d) = d_r_bar;
164     d_r_bar_u_bar.slice(ru_u_offset(), cell_extents()).device(d) = d_u_bar;
165 
166     // [1st_component_of_d_x 1st_component_of_d_h_prev] = [d_r_bar d_u_bar] X
167     // w_ru^T
168     typename TTypes<T>::ConstMatrix const_d_r_bar_u_bar(
169         d_r_bar_u_bar.data(), d_r_bar_u_bar.dimensions());
170     TensorBlasGemm<Device, T, USE_CUBLAS>::compute(
171         ctx, d, false, true, typename gemm_compute_type<T>::type(1.f),
172         const_d_r_bar_u_bar, w_ru, typename gemm_compute_type<T>::type(0.f),
173         d_x_comp1_and_h_prev_comp1);
174 
175     // d_x = d_x_comp1 + d_x_comp2
176     d_x.device(d) = (d_x_comp1_and_h_prev_comp1 + d_x_comp2_and_h_prevr)
177                         .slice(x_offsets(), x_extends());
178 
179     // d_h_prev = d_h_comp1 + d_hr*r + d_h*u
180     d_h_prev.device(d) =
181         d_x_comp1_and_h_prev_comp1.slice(h_offsets(), h_extends()) +
182         (d_hr * r) + (d_h * u);
183   }
184 };
185 
186 }  // namespace functor
187 }  // namespace tensorflow
188 
189 #endif  // TENSORFLOW_CORE_KERNELS_RNN_GRU_OPS_H_
190