Home
last modified time | relevance | path

Searched defs:MatmulGradK (Results 1 – 1 of 1) sorted by relevance

/aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/cuda/mem_eff_attention/
H A Dkernel_backward.h546 struct MatmulGradK { struct
548 using ThreadblockShape =
550 using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>;
551 using InstructionShape = typename GemmType::InstructionShape;
553 using DefaultGemm = cutlass::gemm::kernel::DefaultGemm<
574 using WarpIteratorA = typename cutlass::gemm::threadblock::
580 using DefaultMmaFromSmemN =
586 using DefaultMmaFromSmemT =
593 using DefaultMmaFromSmem = typename cutlass::platform::conditional<
597 using Mma = typename DefaultMmaFromSmem::Mma;
[all …]