Home
last modified time | relevance | path

Searched defs:grad_scale_ptr (Results 1 – 12 of 12) sorted by relevance

/aosp_15_r20/external/pytorch/aten/src/ATen/native/cuda/
H A DFusedSgdKernel.cu24 const float* grad_scale_ptr) { in sgd_math()
161 float* grad_scale_ptr = in _fused_sgd_with_momentum_kernel_cuda_() local
236 float* grad_scale_ptr = in _fused_sgd_with_momentum_kernel_cuda_() local
302 float* grad_scale_ptr = in _fused_sgd_kernel_cuda_() local
395 float* grad_scale_ptr = in _fused_sgd_kernel_cuda_() local
H A Dfused_adam_impl.cu28 const float* grad_scale_ptr = in _fused_adam_cuda_impl_() local
74 const float* grad_scale_ptr = in _fused_adam_cuda_impl_() local
H A Dfused_adamw_impl.cu29 const float* grad_scale_ptr = in _fused_adamw_cuda_impl_() local
75 const float* grad_scale_ptr = in _fused_adamw_cuda_impl_() local
H A Dfused_adam_amsgrad_impl.cu33 const float* grad_scale_ptr = in _fused_adam_amsgrad_cuda_impl_() local
84 const float* grad_scale_ptr = in _fused_adam_amsgrad_cuda_impl_() local
H A Dfused_adamw_amsgrad_impl.cu34 const float* grad_scale_ptr = in _fused_adamw_amsgrad_cuda_impl_() local
85 const float* grad_scale_ptr = in _fused_adamw_amsgrad_cuda_impl_() local
H A Dfused_adam_utils.cuh35 const float* grad_scale_ptr, in adam_math()
/aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/
H A DFusedAdagradKernel.cpp26 const float* grad_scale_ptr, in adagrad_math()
95 const float* grad_scale_ptr, in adagrad_math()
150 const float* grad_scale_ptr) { in adagrad_fused_step_impl()
197 const float* grad_scale_ptr in fused_adagrad_kernel()
H A DFusedSGDKernel.cpp29 const float* grad_scale_ptr, in sgd_math()
121 const float* grad_scale_ptr, in sgd_math()
196 const float* grad_scale_ptr) { in sgd_fused_step_impl()
246 const float* grad_scale_ptr in fused_sgd_kernel()
H A DFusedAdamKernel.cpp35 const float* grad_scale_ptr, in adam_math()
178 const float* grad_scale_ptr, in adam_math()
280 const float* grad_scale_ptr) { in adam_fused_step_impl()
351 const float* grad_scale_ptr, in fused_adam_kernel()
/aosp_15_r20/external/pytorch/aten/src/ATen/native/
H A DFusedAdam.cpp35 const float* grad_scale_ptr = in _fused_adam_kernel_cpu_() local
111 const float* grad_scale_ptr = in _fused_adamw_kernel_cpu_() local
H A DFusedAdagrad.cpp29 const float* grad_scale_ptr = in _fused_adagrad_kernel_cpu_() local
H A DFusedSGD.cpp31 const float* grad_scale_ptr = in _fused_sgd_kernel_cpu_() local