xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/cpu/NativeMultiheadAttnKernel.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/Dispatch.h>
4 #include <ATen/OpMathType.h>
5 #include <ATen/Parallel.h>
6 #include <ATen/TensorIterator.h>
7 #include <ATen/cpu/vec/vec.h>
8 #include <ATen/cpu/vec/functional.h>
9 #include <ATen/native/cpu/utils.h>
10 #include <ATen/native/transformers/attention.h>
11 #include <c10/util/irange.h>
12 
13 namespace at::native {
14 
15 namespace {
16 
17 template <typename scalar_t>
cpu_transform_bias_rescale_qkv(scalar_t * q_k_v_data,const scalar_t * qkv_data,const scalar_t * qkv_bias_data,int64_t B,int64_t T,int64_t D,int64_t num_head)18 void cpu_transform_bias_rescale_qkv(
19     scalar_t* q_k_v_data,
20     const scalar_t* qkv_data,
21     const scalar_t* qkv_bias_data,
22     int64_t B,
23     int64_t T,
24     int64_t D,
25     int64_t num_head) {
26 
27   int64_t dim_per_head = D / num_head;
28 
29   // shapes and strides:
30   //   qkv      : {B, T, 3, num_head, dim_per_head}
31   //   qkv_bias : {3, num_head, dim_per_head}
32   //   q_k_v    : {3, B, num_head, T, dim_per_head}
33   //
34   int64_t i_strideB = T * 3 * D;
35   int64_t i_strideT = 3 * D;
36   int64_t o_stride = B * num_head * T * dim_per_head;
37 
38   // inv_sqrt_dim_per_head in accumulate type
39   using acc_t = at::opmath_type<scalar_t>;
40   using Vec =  vec::Vectorized<acc_t>;
41   const acc_t s = 1.0 / std::sqrt(static_cast<acc_t>(dim_per_head));
42 
43   // parallel on {B, num_head, T}
44   int64_t grain_size = std::max(at::internal::GRAIN_SIZE / (3 * dim_per_head), (int64_t)1);
45   at::parallel_for(0, B * num_head * T, grain_size, [&](int64_t begin, int64_t end) {
46     int64_t b{0}, nh{0}, t{0};
47     data_index_init(begin, b, B, nh, num_head, t, T);
48 
49     for (const auto i : c10::irange(begin, end)) {
50       const scalar_t* q_in_ptr = qkv_data + b * i_strideB + t * i_strideT + 0 * D + nh * dim_per_head;
51       const scalar_t* k_in_ptr = qkv_data + b * i_strideB + t * i_strideT + 1 * D + nh * dim_per_head;
52       const scalar_t* v_in_ptr = qkv_data + b * i_strideB + t * i_strideT + 2 * D + nh * dim_per_head;
53 
54       const scalar_t* q_bias_ptr = qkv_bias_data + 0 * D + nh * dim_per_head;
55       const scalar_t* k_bias_ptr = qkv_bias_data + 1 * D + nh * dim_per_head;
56       const scalar_t* v_bias_ptr = qkv_bias_data + 2 * D + nh * dim_per_head;
57 
58       // we can use global index i here for output
59       scalar_t* q_out_ptr = q_k_v_data + 0 * o_stride + i * dim_per_head;
60       scalar_t* k_out_ptr = q_k_v_data + 1 * o_stride + i * dim_per_head;
61       scalar_t* v_out_ptr = q_k_v_data + 2 * o_stride + i * dim_per_head;
62 
63       // q = (q + bias) * inv_sqrt_dim_per_head
64       vec::map2<scalar_t>(
65           [s](Vec q, Vec q_bias) { return (q + q_bias) * Vec(s); },
66           q_out_ptr, q_in_ptr, q_bias_ptr, dim_per_head);
67 
68       // k = k + bias
69       vec::map2<scalar_t>([](Vec k, Vec k_bias) { return k + k_bias; },
70           k_out_ptr, k_in_ptr, k_bias_ptr, dim_per_head);
71 
72       // v = v + bias
73       vec::map2<scalar_t>([](Vec v, Vec v_bias) { return v + v_bias; },
74           v_out_ptr, v_in_ptr, v_bias_ptr, dim_per_head);
75 
76       // move to the next index
77       data_index_step(b, B, nh, num_head, t, T);
78     }
79   });
80 }
81 
transform_bias_rescale_qkv_kernel_impl(at::ScalarType type,void * _q_k_v,const void * _qkv,const void * _qkv_bias,int64_t B,int64_t T,int64_t D,int64_t num_head)82 void transform_bias_rescale_qkv_kernel_impl(
83     at::ScalarType type,
84     void* _q_k_v,
85     const void* _qkv,
86     const void* _qkv_bias,
87     int64_t B,
88     int64_t T,
89     int64_t D,
90     int64_t num_head) {
91 
92   AT_DISPATCH_FLOATING_TYPES_AND2(kHalf, kBFloat16, type, "transform_bias_rescale_qkv", [&] {
93     scalar_t* q_k_v = static_cast<scalar_t*>(_q_k_v);
94     const scalar_t* qkv = static_cast<const scalar_t*>(_qkv);
95     const scalar_t* qkv_bias = static_cast<const scalar_t*>(_qkv_bias);
96     cpu_transform_bias_rescale_qkv<scalar_t>(
97         q_k_v,
98         qkv,
99         qkv_bias,
100         B,
101         T,
102         D,
103         num_head);
104   });
105 }
106 
107 } // anonymous namespace
108 
// Binds transform_bias_rescale_qkv_kernel_impl as the CPU implementation of
// transform_bias_rescale_qkv_stub. The stub is presumably declared in the
// included ATen/native/transformers/attention.h; verify there.
REGISTER_DISPATCH(transform_bias_rescale_qkv_stub, &transform_bias_rescale_qkv_kernel_impl);
110 
111 } // at::native
112