1 #pragma once
2
3 #ifdef USE_ROCM
4
5 #include <aotriton/dtypes.h>
6 #include <aotriton/util.h>
7
8 ////////////////////////////////////////////////////////////////////////////////
9 // Common macros copied from cuda/mem_eff_attention/gemm_kernel_utils.h
10 ////////////////////////////////////////////////////////////////////////////////
11
// Checks, via TORCH_CHECK, that TENSOR reports is_cuda(), does not report
// is_sparse(), and reports is_contiguous().
//
// TENSOR is parenthesized so that expression arguments such as `*ptr` are
// handled correctly. The checks are grouped in one compound statement so that
// all of them are governed by an enclosing unbraced `if`. The macro may be
// invoked with or without a trailing semicolon.
#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR)                             \
  {                                                                        \
    TORCH_CHECK((TENSOR).is_cuda(), #TENSOR " must be a CUDA tensor");     \
    TORCH_CHECK(!(TENSOR).is_sparse(), #TENSOR " must be a dense tensor"); \
    TORCH_CHECK((TENSOR).is_contiguous(), #TENSOR " must be contiguous");  \
  }
16
// Checks, via TORCH_CHECK, that TENSOR reports is_cuda(), does not report
// is_sparse(), and that TENSOR.stride(-1) equals 1.
//
// TENSOR is parenthesized so that expression arguments such as `*ptr` are
// handled correctly. The checks are grouped in one compound statement so that
// all of them are governed by an enclosing unbraced `if`. The macro may be
// invoked with or without a trailing semicolon.
#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR)                         \
  {                                                                        \
    TORCH_CHECK((TENSOR).is_cuda(), #TENSOR " must be a CUDA tensor");     \
    TORCH_CHECK(!(TENSOR).is_sparse(), #TENSOR " must be a dense tensor"); \
    TORCH_CHECK(                                                           \
        (TENSOR).stride(-1) == 1,                                          \
        #TENSOR ": last dimension must be contiguous");                    \
  }
22
// Checks, via TORCH_CHECK, that PTR converted to uint64_t is a multiple of
// ALIGNMENT. ALIGNMENT is parenthesized so that expression arguments
// (e.g. `8 + 8`, `sizeof(T) * 4`) are applied as a whole to the `%` operator.
// No trailing semicolon is supplied; the caller provides it.
#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \
  TORCH_CHECK(                            \
      uint64_t(PTR) % (ALIGNMENT) == 0, #PTR " is not correctly aligned")
26
// Assigns B to A, then checks via TORCH_CHECK that the value of B is strictly
// below std::numeric_limits<decltype(A)>::max().
//
// B is evaluated exactly once (so arguments with side effects behave as
// written) and is parenthesized, so that low-precedence expressions such as
// `cond ? x : y` are compared as a whole rather than being split by `<`.
// The assignment happens before the check, so A has already been written when
// the check fails.
#define ASSIGN_CHECK_OVERFLOW(A, B)                                 \
  {                                                                 \
    const auto assign_check_overflow_value_ = (B);                  \
    A = assign_check_overflow_value_;                               \
    TORCH_CHECK(                                                    \
        assign_check_overflow_value_ <                              \
            std::numeric_limits<decltype(A)>::max(),                \
        #B " overflows");                                           \
  }
33
34 namespace sdp {
35
36 namespace aotriton_adapter {
37
cast_dtype(caffe2::TypeMeta t_dtype)38 inline aotriton::DType cast_dtype(caffe2::TypeMeta t_dtype)
39 {
40 #define CAST_TYPE(aname, dtname) if (t_dtype == at::aname) return aotriton::DType::dtname
41 CAST_TYPE(kByte, kUInt8);
42 CAST_TYPE(kUInt16, kUInt16);
43 CAST_TYPE(kUInt32, kUInt32);
44 CAST_TYPE(kUInt64, kUInt64);
45 CAST_TYPE(kChar, kInt8);
46 CAST_TYPE(kShort, kInt16);
47 CAST_TYPE(kInt, kInt32);
48 CAST_TYPE(kLong, kInt64);
49 CAST_TYPE(kHalf, kFloat16);
50 CAST_TYPE(kFloat, kFloat32);
51 CAST_TYPE(kBFloat16, kBFloat16);
52 return aotriton::DType::kUnknown;
53 #undef CAST_TYPE
54 }
55
56 template<typename TargetType, int Rank>
57 struct IntArrayRefCaster {
58 // std::array<TargetType, Rank> cast(IntArrayRef);
59 };
60
61 template<typename TargetType>
62 struct IntArrayRefCaster<TargetType, 1> {
63 static auto cast(at::IntArrayRef ref) {
64 return std::array<TargetType, 1>{{ static_cast<TargetType>(ref.at(0)) }};
65 }
66 };
67
68 template<typename TargetType>
69 struct IntArrayRefCaster<TargetType, 2> {
70 static auto cast(at::IntArrayRef ref) {
71 return std::array<TargetType, 2>{{
72 static_cast<TargetType>(ref.at(0)),
73 static_cast<TargetType>(ref.at(1))
74 }};
75 }
76 };
77
78 template<typename TargetType>
79 struct IntArrayRefCaster<TargetType, 3> {
80 static auto cast(at::IntArrayRef ref) {
81 return std::array<TargetType, 3>{{
82 static_cast<TargetType>(ref.at(0)),
83 static_cast<TargetType>(ref.at(1)),
84 static_cast<TargetType>(ref.at(2))
85 }};
86 }
87 };
88
89 template<typename TargetType>
90 struct IntArrayRefCaster<TargetType, 4> {
91 static auto cast(at::IntArrayRef ref) {
92 return std::array<TargetType, 4>{{
93 static_cast<TargetType>(ref.at(0)),
94 static_cast<TargetType>(ref.at(1)),
95 static_cast<TargetType>(ref.at(2)),
96 static_cast<TargetType>(ref.at(3))
97 }};
98 }
99 };
100
101
102 template<int Rank = 4>
103 aotriton::TensorView<Rank> mk_aotensor(const at::Tensor& q, c10::string_view tensor_name)
104 {
105 const auto strides = q.strides();
106 int real_rank = strides.size();
107 if (real_rank != Rank) { // Lazy convertion of tensor_name
108 TORCH_CHECK(false,
109 std::string(tensor_name) + "'s rank should be " + std::to_string(Rank)
110 + " but is " + std::to_string(real_rank));
111 }
112 return aotriton::TensorView<Rank>(reinterpret_cast<intptr_t>(q.data_ptr()),
113 IntArrayRefCaster<uint64_t, Rank>::cast(q.sizes()),
114 IntArrayRefCaster<uint64_t, Rank>::cast(strides),
115 cast_dtype(q.dtype()));
116 }
117
118 inline aotriton::TensorView<0> mk_aoscalartensor(const at::Tensor& q)
119 {
120 return aotriton::TensorView<0>(reinterpret_cast<intptr_t>(q.data_ptr()),
121 cast_dtype(q.dtype()));
122 }
123
124 inline aotriton::TensorView<0> mk_philoxtensor(const int64_t* ptr)
125 {
126 return aotriton::TensorView<0>(reinterpret_cast<intptr_t>(ptr),
127 aotriton::DType::kUInt64); // AOTriton excepts unsigned int64
128 }
129
130 } // namespace aotriton_adapter
131
132 } // namespace sdp
133
134 namespace at::native {
135
// Returns numerator / denominator rounded towards positive infinity.
//
// The result is derived from the truncated quotient and the remainder rather
// than from `numerator + denominator - 1`, so it cannot overflow for large
// operands and is also correct when either operand is negative.
//
// Preconditions: denominator must be non-zero, and the pair
// (INT64_MIN, -1) must not be passed; the division itself is undefined for
// those inputs.
inline int64_t ceil_div(int64_t numerator, int64_t denominator) {
  const int64_t quotient = numerator / denominator;
  const int64_t remainder = numerator % denominator;
  // Integer division truncates towards zero. That already equals the ceiling
  // when the exact quotient is negative; when it is positive and inexact
  // (non-zero remainder with the same sign as the denominator) bump it by one.
  const bool round_up = remainder != 0 && ((remainder < 0) == (denominator < 0));
  return round_up ? quotient + 1 : quotient;
}
139
140 }
141
142 #endif // USE_ROCM
143