Searched defs:grad_scale_ptr (Results 1 – 12 of 12) sorted by relevance
24 const float* grad_scale_ptr) { in sgd_math()161 float* grad_scale_ptr = in _fused_sgd_with_momentum_kernel_cuda_() local236 float* grad_scale_ptr = in _fused_sgd_with_momentum_kernel_cuda_() local302 float* grad_scale_ptr = in _fused_sgd_kernel_cuda_() local395 float* grad_scale_ptr = in _fused_sgd_kernel_cuda_() local
28 const float* grad_scale_ptr = in _fused_adam_cuda_impl_() local74 const float* grad_scale_ptr = in _fused_adam_cuda_impl_() local
29 const float* grad_scale_ptr = in _fused_adamw_cuda_impl_() local75 const float* grad_scale_ptr = in _fused_adamw_cuda_impl_() local
33 const float* grad_scale_ptr = in _fused_adam_amsgrad_cuda_impl_() local84 const float* grad_scale_ptr = in _fused_adam_amsgrad_cuda_impl_() local
34 const float* grad_scale_ptr = in _fused_adamw_amsgrad_cuda_impl_() local85 const float* grad_scale_ptr = in _fused_adamw_amsgrad_cuda_impl_() local
35 const float* grad_scale_ptr, in adam_math()
26 const float* grad_scale_ptr, in adagrad_math()95 const float* grad_scale_ptr, in adagrad_math()150 const float* grad_scale_ptr) { in adagrad_fused_step_impl()197 const float* grad_scale_ptr in fused_adagrad_kernel()
29 const float* grad_scale_ptr, in sgd_math()121 const float* grad_scale_ptr, in sgd_math()196 const float* grad_scale_ptr) { in sgd_fused_step_impl()246 const float* grad_scale_ptr in fused_sgd_kernel()
35 const float* grad_scale_ptr, in adam_math()178 const float* grad_scale_ptr, in adam_math()280 const float* grad_scale_ptr) { in adam_fused_step_impl()351 const float* grad_scale_ptr, in fused_adam_kernel()
35 const float* grad_scale_ptr = in _fused_adam_kernel_cpu_() local111 const float* grad_scale_ptr = in _fused_adamw_kernel_cpu_() local
29 const float* grad_scale_ptr = in _fused_adagrad_kernel_cpu_() local
31 const float* grad_scale_ptr = in _fused_sgd_kernel_cpu_() local