/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_
#define TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace functor {

// Each training algorithm has an ApplyXYZ functor struct declared in
// this header file. They are specialized for different devices
// (CPUDevice in training_ops.cc or GPUDevice in training_ops_gpu.cc).

template <typename Device, typename T>
struct ApplyGradientDescent {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::ConstScalar alpha,
                  typename TTypes<T>::ConstFlat delta);
};
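
// Illustrative sketch: each per-device specialization (defined in
// training_ops.cc for CPU and training_ops_gpu.cc for GPU) evaluates its
// update rule through Eigen device expressions. For gradient descent
// (var -= alpha * delta), a CPU specialization looks roughly like the
// following, assuming the usual `using CPUDevice = Eigen::ThreadPoolDevice;`
// alias:
//
//   template <typename T>
//   struct ApplyGradientDescent<CPUDevice, T> {
//     void operator()(const CPUDevice& d, typename TTypes<T>::Flat var,
//                     typename TTypes<T>::ConstScalar alpha,
//                     typename TTypes<T>::ConstFlat delta) {
//       var.device(d) -= delta * alpha();
//     }
//   };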

template <typename Device, typename T>
struct ApplyAdadelta {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat accum_update,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad);
};

template <typename Device, typename T, typename Tindex>
struct SparseApplyAdadelta {
  void operator()(const Device& d, typename TTypes<T>::Matrix var,
                  typename TTypes<T>::Matrix accum,
                  typename TTypes<T>::Matrix accum_update,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstMatrix grad,
                  typename TTypes<Tindex>::ConstFlat indices);
};

template <typename Device, typename T>
struct FobosElasticNet {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad);
};

template <typename Device, typename T>
struct ApplyProximalGradientDescent {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad);
};

template <typename Device, typename T>
struct ApplyAdagrad {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad, bool update_slots);
};

template <typename Device, typename T>
struct ApplyAdagradV2 {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad, bool update_slots);
};

template <typename Device, typename T>
struct ApplyAdagradDA {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat gradient_accum,
                  typename TTypes<T>::Flat gradient_squared_accum,
                  typename TTypes<T>::ConstScalar lr, int64_t global_step,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad);
};

template <typename Device, typename T, typename Tindex, bool has_epsilon>
struct SparseApplyAdagrad {
  // Note that epsilon is ignored if has_epsilon is false.
  Status operator()(const Device& d, typename TTypes<T>::Matrix var,
                    typename TTypes<T>::Matrix accum,
                    typename TTypes<T>::ConstScalar lr,
                    typename TTypes<T>::ConstScalar epsilon,
                    typename TTypes<T>::ConstMatrix grad,
                    typename TTypes<Tindex>::ConstVec indices,
                    int64_t inner_dim, bool update_slots);
};
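
// Rough sketch of the dense Adagrad update the functors above evaluate
// (pseudocode; ApplyAdagrad omits the epsilon term, while ApplyAdagradV2 and
// SparseApplyAdagrad with has_epsilon == true include it):
//
//   if (update_slots) accum += grad * grad;
//   var -= lr * grad / (sqrt(accum) + epsilon);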

template <typename Device, typename T>
struct ApplyProximalAdagrad {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstFlat grad);
};

template <typename Device, typename T, typename Tindex>
struct SparseApplyProximalAdagrad {
  Status operator()(const Device& d, typename TTypes<T>::Matrix var,
                    typename TTypes<T>::Matrix accum,
                    typename TTypes<T>::ConstScalar lr,
                    typename TTypes<T>::ConstScalar l1,
                    typename TTypes<T>::ConstScalar l2,
                    typename TTypes<T>::ConstMatrix grad,
                    typename TTypes<Tindex>::ConstVec indices,
                    int64_t inner_dim);
};

template <typename Device, typename T>
struct ApplyFtrl {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar lr_power);
};

template <typename Device, typename T>
struct ApplyFtrlMultiplyLinearByLr {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar lr_power);
};

template <typename Device, typename T>
struct ApplyFtrlV2 {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar l2_shrinkage,
                  typename TTypes<T>::ConstScalar lr_power);
};

template <typename Device, typename T>
struct ApplyFtrlV2MultiplyLinearByLr {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::Flat linear,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar l1,
                  typename TTypes<T>::ConstScalar l2,
                  typename TTypes<T>::ConstScalar l2_shrinkage,
                  typename TTypes<T>::ConstScalar lr_power);
};

template <typename Device, typename T, typename Tindex, bool has_l2_shrinkage>
struct SparseApplyFtrl {
  Status operator()(const Device& d, typename TTypes<T>::Matrix var_flat,
                    typename TTypes<T>::Matrix accum_flat,
                    typename TTypes<T>::Matrix linear_flat,
                    typename TTypes<T>::ConstScalar lr,
                    typename TTypes<T>::ConstScalar l1,
                    typename TTypes<T>::ConstScalar l2,
                    typename TTypes<T>::ConstScalar l2_shrinkage,
                    typename TTypes<T>::ConstScalar lr_power,
                    typename TTypes<T>::ConstMatrix grad_flat,
                    typename TTypes<Tindex>::ConstVec indices_vec,
                    int64_t inner_dim, bool multiply_linear_by_lr);
};

template <typename Device, typename T>
struct ApplyMomentum {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov);
};

template <typename Device, typename T>
struct ApplyKerasMomentum {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat accum,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstFlat grad,
                  typename TTypes<T>::ConstScalar momentum, bool use_nesterov);
};

template <typename Device, typename T, typename Tindex>
struct SparseApplyKerasMomentum {
  Tindex operator()(const Device& d, typename TTypes<T>::Matrix var,
                    typename TTypes<T>::Matrix accum,
                    typename TTypes<T>::ConstScalar lr,
                    typename TTypes<T>::ConstMatrix grad,
                    typename TTypes<Tindex>::ConstFlat indices,
                    typename TTypes<T>::ConstScalar momentum,
                    bool use_nesterov);
};

template <typename Device, typename T>
struct ApplyAdam {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  typename TTypes<T>::ConstScalar beta1_power,
                  typename TTypes<T>::ConstScalar beta2_power,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar beta1,
                  typename TTypes<T>::ConstScalar beta2,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad, bool use_nesterov);
};

template <typename Device, typename T>
struct ApplyAdamWithAmsgrad {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  typename TTypes<T>::Flat vhat,
                  typename TTypes<T>::ConstScalar beta1_power,
                  typename TTypes<T>::ConstScalar beta2_power,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar beta1,
                  typename TTypes<T>::ConstScalar beta2,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad);
};

template <typename Device, typename T>
struct ApplyAdaMax {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m, typename TTypes<T>::Flat v,
                  typename TTypes<T>::ConstScalar beta1_power,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar beta1,
                  typename TTypes<T>::ConstScalar beta2,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad);
};

template <typename Device, typename T>
struct ApplyRMSProp {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat ms, typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad);
};

template <typename Device, typename T>
struct ApplyCenteredRMSProp {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat mg, typename TTypes<T>::Flat ms,
                  typename TTypes<T>::Flat mom,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar rho,
                  typename TTypes<T>::ConstScalar momentum,
                  typename TTypes<T>::ConstScalar epsilon,
                  typename TTypes<T>::ConstFlat grad);
};

template <typename Device, typename T>
struct ApplyAddSign {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar alpha,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad);
};

template <typename Device, typename T>
struct ApplyPowerSign {
  void operator()(const Device& d, typename TTypes<T>::Flat var,
                  typename TTypes<T>::Flat m,
                  typename TTypes<T>::ConstScalar lr,
                  typename TTypes<T>::ConstScalar logbase,
                  typename TTypes<T>::ConstScalar sign_decay,
                  typename TTypes<T>::ConstScalar beta,
                  typename TTypes<T>::ConstFlat grad);
};
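
// Rough sketch of how a kernel in training_ops.cc typically dispatches to one
// of these functors from its Compute method (names here are illustrative;
// `ctx` is the OpKernelContext and `var`, `alpha`, `delta` are the kernel's
// input tensors):
//
//   functor::ApplyGradientDescent<Device, T>()(
//       ctx->eigen_device<Device>(), var.flat<T>(), alpha.scalar<T>(),
//       delta.flat<T>());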

}  // end namespace functor
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_TRAINING_OPS_H_