xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/transformers/hip/aotriton_adapter.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #ifdef USE_ROCM
4 
#include <array>
#include <cstddef>
#include <limits>
#include <utility>

#include <aotriton/dtypes.h>
#include <aotriton/util.h>
7 
8 ////////////////////////////////////////////////////////////////////////////////
9 // Common macros copied from cuda/mem_eff_attention/gemm_kernel_utils.h
10 ////////////////////////////////////////////////////////////////////////////////
11 
// Asserts (via TORCH_CHECK) that TENSOR is a CUDA tensor, is not sparse, and
// is contiguous; each failure names the offending tensor.
// The expansion is wrapped in braces so that an unbraced
// `if (cond) CHECK_NOSPARSE_CONTIGUOUS_CUDA(t);` guards all three checks, not
// only the first; a trailing `;` at the call site remains optional.
// TENSOR is parenthesized so expression arguments such as `*ptr` expand
// correctly.
#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR)                              \
  {                                                                        \
    TORCH_CHECK((TENSOR).is_cuda(), #TENSOR " must be a CUDA tensor");     \
    TORCH_CHECK(!(TENSOR).is_sparse(), #TENSOR " must be a dense tensor"); \
    TORCH_CHECK((TENSOR).is_contiguous(), #TENSOR " must be contiguous");  \
  }
16 
// Asserts (via TORCH_CHECK) that TENSOR is a CUDA tensor, is not sparse, and
// has stride 1 in its last dimension; other dimensions may be strided.
// Braces make the whole check list behave as a single statement (so an
// unbraced `if` guards every check); a trailing `;` at the call site remains
// optional. TENSOR is parenthesized so expression arguments expand correctly.
#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR)                          \
  {                                                                        \
    TORCH_CHECK((TENSOR).is_cuda(), #TENSOR " must be a CUDA tensor");     \
    TORCH_CHECK(!(TENSOR).is_sparse(), #TENSOR " must be a dense tensor"); \
    TORCH_CHECK(                                                           \
        (TENSOR).stride(-1) == 1,                                          \
        #TENSOR ": last dimension must be contiguous");                    \
  }
22 
// Fails (via TORCH_CHECK) with "<PTR> is not correctly aligned" unless the
// address held by PTR, converted to uint64_t, is a multiple of ALIGNMENT.
// ALIGNMENT is parenthesized so that expression arguments such as `a * b` or
// `8 + 8` are evaluated before the modulo rather than split by precedence.
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
  TORCH_CHECK(                            \
      uint64_t(PTR) % (ALIGNMENT) == 0, #PTR " is not correctly aligned")
26 
// Assigns B to A after verifying (via TORCH_CHECK) that B does not exceed the
// largest value representable by A's declared type.
//  - B is evaluated exactly once, so arguments with side effects are safe.
//  - The bound is inclusive: B == max() fits and is accepted (the previous
//    strict `<` rejected it).
//  - The check runs before the store, so A is left untouched on failure.
// Only the upper bound is checked, as before; the comparison follows the
// usual arithmetic conversions between B's type and decltype(A).
#define ASSIGN_CHECK_OVERFLOW(A, B)                                     \
  {                                                                     \
    const auto assign_check_overflow_value_ = (B);                      \
    TORCH_CHECK(                                                        \
        assign_check_overflow_value_ <=                                 \
            std::numeric_limits<decltype(A)>::max(),                    \
        #B " overflows");                                               \
    A = assign_check_overflow_value_;                                   \
  }
33 
34 namespace sdp {
35 
36 namespace aotriton_adapter {
37 
cast_dtype(caffe2::TypeMeta t_dtype)38 inline aotriton::DType cast_dtype(caffe2::TypeMeta t_dtype)
39 {
40 #define CAST_TYPE(aname, dtname) if (t_dtype == at::aname) return aotriton::DType::dtname
41   CAST_TYPE(kByte, kUInt8);
42   CAST_TYPE(kUInt16, kUInt16);
43   CAST_TYPE(kUInt32, kUInt32);
44   CAST_TYPE(kUInt64, kUInt64);
45   CAST_TYPE(kChar, kInt8);
46   CAST_TYPE(kShort, kInt16);
47   CAST_TYPE(kInt, kInt32);
48   CAST_TYPE(kLong, kInt64);
49   CAST_TYPE(kHalf, kFloat16);
50   CAST_TYPE(kFloat, kFloat32);
51   CAST_TYPE(kBFloat16, kBFloat16);
52   return aotriton::DType::kUnknown;
53 #undef CAST_TYPE
54 }
55 
56 template<typename TargetType, int Rank>
57 struct IntArrayRefCaster {
58   // std::array<TargetType, Rank> cast(IntArrayRef);
59 };
60 
61 template<typename TargetType>
62 struct IntArrayRefCaster<TargetType, 1> {
63   static auto cast(at::IntArrayRef ref) {
64     return std::array<TargetType, 1>{{ static_cast<TargetType>(ref.at(0)) }};
65   }
66 };
67 
68 template<typename TargetType>
69 struct IntArrayRefCaster<TargetType, 2> {
70   static auto cast(at::IntArrayRef ref) {
71     return std::array<TargetType, 2>{{
72       static_cast<TargetType>(ref.at(0)),
73       static_cast<TargetType>(ref.at(1))
74     }};
75   }
76 };
77 
78 template<typename TargetType>
79 struct IntArrayRefCaster<TargetType, 3> {
80   static auto cast(at::IntArrayRef ref) {
81     return std::array<TargetType, 3>{{
82       static_cast<TargetType>(ref.at(0)),
83       static_cast<TargetType>(ref.at(1)),
84       static_cast<TargetType>(ref.at(2))
85     }};
86   }
87 };
88 
89 template<typename TargetType>
90 struct IntArrayRefCaster<TargetType, 4> {
91   static auto cast(at::IntArrayRef ref) {
92     return std::array<TargetType, 4>{{
93       static_cast<TargetType>(ref.at(0)),
94       static_cast<TargetType>(ref.at(1)),
95       static_cast<TargetType>(ref.at(2)),
96       static_cast<TargetType>(ref.at(3))
97     }};
98   }
99 };
100 
101 
102 template<int Rank = 4>
103 aotriton::TensorView<Rank> mk_aotensor(const at::Tensor& q, c10::string_view tensor_name)
104 {
105   const auto strides = q.strides();
106   int real_rank = strides.size();
107   if (real_rank != Rank) {  // Lazy convertion of tensor_name
108     TORCH_CHECK(false,
109                 std::string(tensor_name) + "'s rank should be " + std::to_string(Rank)
110                 + " but is " + std::to_string(real_rank));
111   }
112   return aotriton::TensorView<Rank>(reinterpret_cast<intptr_t>(q.data_ptr()),
113                                     IntArrayRefCaster<uint64_t, Rank>::cast(q.sizes()),
114                                     IntArrayRefCaster<uint64_t, Rank>::cast(strides),
115                                     cast_dtype(q.dtype()));
116 }
117 
118 inline aotriton::TensorView<0> mk_aoscalartensor(const at::Tensor& q)
119 {
120   return aotriton::TensorView<0>(reinterpret_cast<intptr_t>(q.data_ptr()),
121                                  cast_dtype(q.dtype()));
122 }
123 
124 inline aotriton::TensorView<0> mk_philoxtensor(const int64_t* ptr)
125 {
126   return aotriton::TensorView<0>(reinterpret_cast<intptr_t>(ptr),
127                                  aotriton::DType::kUInt64);  // AOTriton excepts unsigned int64
128 }
129 
130 } // namespace aotriton_adapter
131 
132 } // namespace sdp
133 
namespace at::native {

/// Returns ceil(numerator / denominator) using exact integer arithmetic.
///
/// The previous `(numerator + denominator - 1) / denominator` form overflowed
/// (undefined behaviour) when `numerator` was close to INT64_MAX. It was also
/// wrong for negative operands, because C++ integer division truncates toward
/// zero. This version never forms an intermediate sum and is correct for every
/// sign combination. For non-negative numerators and positive denominators it
/// returns the same values as before.
///
/// @pre denominator != 0, and not (numerator == INT64_MIN && denominator == -1);
///      both are undefined for built-in integer division.
inline int64_t ceil_div(int64_t numerator, int64_t denominator) {
  const int64_t quotient = numerator / denominator;   // truncated toward zero
  const int64_t remainder = numerator % denominator;  // carries numerator's sign
  // When the exact quotient is negative, truncation has already rounded up.
  // Only a non-zero remainder with the same sign as the denominator (a
  // positive exact quotient) needs bumping to the next integer.
  const bool round_up = remainder != 0 && ((remainder > 0) == (denominator > 0));
  return quotient + (round_up ? 1 : 0);
}

} // namespace at::native
141 
142 #endif // USE_ROCM
143