#include <cstring>
#include <mutex>
#include <unordered_map>
#include <vector>

#include <c10/core/impl/alloc_cpu.h>
#include <c10/core/Allocator.h>
#include <c10/core/ScalarType.h>
#include <c10/util/ArrayRef.h>

#include <torch/csrc/Device.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <torch/extension.h>

#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/AffineQuantizer.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Resize.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/ops/abs_native.h>
#include <ATen/EmptyTensor.h>
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/ops/view.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/native/transformers/attention.h>
26
27 static uint64_t add_counter = 0;
28 static uint64_t last_saved_value = 0;
29 static c10::DeviceIndex custom_device_index = 0;
30
31 static uint64_t abs_counter = 0;
32 static uint64_t last_abs_saved_value = 0;
33
34 static uint64_t storageImpl_counter = 0;
35 static uint64_t last_storageImpl_saved_value = 0;
36 // register guard
37 namespace at {
38 namespace detail {
39
40 C10_REGISTER_GUARD_IMPL(
41 PrivateUse1,
42 c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);
43
44 }} // namespace at::detail
45
46 namespace {
47
48 // Using the simplest way to obtain continuous Tensor data and process it.
49 // This is a demo for using operand API, and you can add more complex logic
50 // for input and output tensor based on your custom device kernel.
abs_kernel(at::TensorIteratorBase & iter)51 void abs_kernel(at::TensorIteratorBase& iter) {
52 // Abs only have a input tensor and a output tensor.
53 auto& output_operand = iter.operand(0);
54 auto& input_operand = iter.operand(1);
55 auto& output_tensor_base = output_operand.tensor_base();
56 auto& input_tensor_base = input_operand.tensor_base();
57 TORCH_CHECK(!input_operand.original_tensor_base().defined(),
58 "input original tensor is defined.");
59 TORCH_CHECK(!output_operand.original_tensor_base().defined(),
60 "output original tensor is defined.");
61 // For easy test, only accept contiguous input tensor for calculate.
62 auto memory_format = input_tensor_base.suggest_memory_format();
63 TORCH_CHECK(input_tensor_base.is_contiguous(memory_format),
64 "Input tensor need be contiguous.");
65 // Add necessary restrictions to ensure the security of the demo.
66 TORCH_CHECK(input_tensor_base.sizes() == output_tensor_base.sizes(),
67 "Intput and output tensor size are not equal.");
68 // Common dtype is calculate in TensorIteratorBase.
69 TORCH_CHECK(iter.common_dtype() == at::ScalarType::Float,
70 "Only support float type.")
71 // Using for loop for abs calculate.
72 auto abs_function = [](float* output_ptr, const float* input_ptr,
73 const int64_t NUM) {
74 for (int64_t i = 0; i < NUM; ++i) {
75 *(output_ptr + i) = std::abs(*(input_ptr + i));
76 }
77 };
78 // To simplify the logic of the test demo code,
79 // we only use contiguous tensor to calculate on device side.
80 // And using input tensor memory format.
81 if (iter.is_contiguous()) {
82 // Add for will_resize flag check. You can convert to differernt
83 // tensor memory format when will_resize is True.
84 // If TensorIteratorConfig resize_outputs_ flag is true, and there are two
85 // situations:
86 // 1) Out tensor is undefined, and TensorIterator set will_resize to true;
87 // 2) Out tensor is defined and tensor size is not equal to input tensor size;
88 // TensorIterator set will_resize to true, and call set_output_raw_strided
89 // to resize output tensor.
90 // When output operand will_resize flag is ture, dummy
91 // device can convert tensor to dummy device preferred memory format.
92 // Here we don't convert tensor memory format, because it will become complex
93 // when dummy device want keep same memory format for training network.
94 TORCH_CHECK(output_operand.will_resize,
95 "output operand will_resize flag need be True.");
96 abs_function((float*)iter.data_ptr(0), (float*)iter.data_ptr(1), iter.numel());
97 } else {
98 // Stride copy is not support for foo device, using cpu device instead.
99 // For abs op, the last situation is: output tensor is not contiguous with
100 // operand will_resize is False.
101 TORCH_CHECK(!output_operand.will_resize, "output operand will_resize is True.");
102 // Get a contiguous tensor with input memory format.
103 at::Tensor output = at::empty(output_tensor_base.sizes(),
104 input_tensor_base.options()
105 .memory_format(memory_format));
106 // For structured op which inheried from TensorIteratorBase, maybe you need to
107 // call set_output_raw_strided function to update output stored in op sturctured.
108 // abs op is no need to do this.
109 output_operand.exchange_tensor(c10::MaybeOwned<at::TensorBase>::owned(std::in_place, output));
110 abs_function((float*)output_operand.tensor_base().mutable_data_ptr(),
111 (float*)iter.data_ptr(1), iter.numel());
112 // Copy tensor base to original tensor base, and keep same scalar type and
113 // stride with cpu and gpu.
114 if (output_operand.original_tensor_base().defined() &&
115 !output_operand.original_tensor_base().is_same(output_operand.tensor_base())) {
116 output_operand.original_tensor().copy_(output_operand.tensor());
117 output_operand.restore_original_tensor();
118 }
119 }
120 }
121
quantize_tensor_per_tensor_affine_privateuse1(const at::Tensor & rtensor,at::Tensor & qtensor,double scale,int64_t zero_point)122 void quantize_tensor_per_tensor_affine_privateuse1(
123 const at::Tensor& rtensor,
124 at::Tensor& qtensor,
125 double scale,
126 int64_t zero_point) {
127 // do nothing
128 }
129
_fused_sdp_choice_privateuse1(const at::Tensor & query,const at::Tensor & key,const at::Tensor & value,const std::optional<at::Tensor> & attn_mask,double dropout_p,bool is_causal,std::optional<double> scale,bool enable_gqa)130 int64_t _fused_sdp_choice_privateuse1(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value,
131 const std::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa){
132 auto backend = sdp::SDPBackend::overrideable;
133 return static_cast<int64_t>(backend);
134 }
135 } // namespace
136
137 namespace at::native {
138
139 REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel);
140 REGISTER_PRIVATEUSE1_DISPATCH(quantize_tensor_per_tensor_affine_stub, &quantize_tensor_per_tensor_affine_privateuse1);
141 REGISTER_PRIVATEUSE1_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_privateuse1);
142
143 } // namespace at::native
144 struct CustomBackendMetadata : public c10::BackendMeta {
145 // for testing this field will mutate when clone() is called by shallow_copy_from.
146 int backend_version_format_{-1};
147 int format_number_{-1};
148 mutable bool cloned_{false};
149 // define the constructor
CustomBackendMetadataCustomBackendMetadata150 CustomBackendMetadata(int backend_version_format, int format_number) :
151 backend_version_format_(backend_version_format), format_number_(format_number) {}
cloneCustomBackendMetadata152 c10::intrusive_ptr<c10::BackendMeta> clone(
153 const c10::intrusive_ptr<c10::BackendMeta>& ptr) const override {
154 cloned_ = true;
155 return c10::BackendMeta::clone(ptr);
156 }
157 };
158
159 // we need to register two functions for serialization
for_serialization(const at::Tensor & t,std::unordered_map<std::string,bool> & m)160 void for_serialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
161 if (t.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr() == nullptr) {
162 return;
163 }
164 auto tmeta = dynamic_cast<CustomBackendMetadata*>(t.unsafeGetTensorImpl()->get_backend_meta());
165 if (tmeta->backend_version_format_ == 1) {
166 m["backend_version_format"] = true;
167 }
168 if (tmeta->format_number_ == 29) {
169 m["format_number"] = true;
170 }
171 }
172
for_deserialization(const at::Tensor & t,std::unordered_map<std::string,bool> & m)173 void for_deserialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
174 int backend_version_format{-1};
175 int format_number{-1};
176 if (m.find("backend_version_format") != m.end()) {
177 backend_version_format = 1;
178 }
179 if (m.find("format_number") != m.end()) {
180 format_number = 29;
181 }
182 c10::intrusive_ptr<c10::BackendMeta> new_tmeta{std::unique_ptr<c10::BackendMeta>(
183 new CustomBackendMetadata(backend_version_format, format_number))};
184 t.unsafeGetTensorImpl()->set_backend_meta(new_tmeta);
185 }
186
custom_serialization_registry()187 void custom_serialization_registry() {
188 torch::jit::TensorBackendMetaRegistry(c10::DeviceType::PrivateUse1,
189 &for_serialization,
190 &for_deserialization);
191 }
192
193 //check if BackendMeta serialization correctly
check_backend_meta(const at::Tensor & t)194 bool check_backend_meta(const at::Tensor& t) {
195 if (t.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr()) {
196 CustomBackendMetadata* tmeta = dynamic_cast<CustomBackendMetadata*>(
197 t.unsafeGetTensorImpl()->get_backend_meta());
198 if (tmeta->backend_version_format_==1 && tmeta->format_number_==29) {
199 return true;
200 }
201 }
202 return false;
203 }
204
205 // a fake set function is exposed to the Python side
custom_set_backend_meta(const at::Tensor & t)206 void custom_set_backend_meta(const at::Tensor& t) {
207 int backend_version_format{1};
208 int format_number{29};
209 c10::intrusive_ptr<c10::BackendMeta> new_tmeta{std::unique_ptr<c10::BackendMeta>(
210 new CustomBackendMetadata(backend_version_format, format_number))};
211 t.unsafeGetTensorImpl()->set_backend_meta(new_tmeta);
212 }
213
214 // A dummy storageImpl for our custom device, that secretly uses the CPU
make_custom_storage_impl(c10::StorageImpl::use_byte_size_t,c10::SymInt size_bytes,c10::DataPtr data_ptr,c10::Allocator * allocator,bool resizable)215 c10::intrusive_ptr<c10::StorageImpl> make_custom_storage_impl(c10::StorageImpl::use_byte_size_t,
216 c10::SymInt size_bytes,
217 c10::DataPtr data_ptr,
218 c10::Allocator* allocator,
219 bool resizable) {
220 c10::intrusive_ptr<c10::StorageImpl> custom_storage_impl;
221 if (data_ptr == nullptr){
222 custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
223 c10::StorageImpl::use_byte_size_t(), size_bytes, allocator, resizable);
224 } else {
225 custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
226 c10::StorageImpl::use_byte_size_t(), size_bytes, std::move(data_ptr), allocator, resizable);
227 }
228 storageImpl_counter += 1;
229 return custom_storage_impl;
230 }
231
232 // Register our dummy storageImpl create method.
custom_storage_registry()233 void custom_storage_registry() {
234 c10::SetStorageImplCreate(c10::DeviceType::PrivateUse1, &make_custom_storage_impl);
235 }
236
custom_storageImpl_called()237 bool custom_storageImpl_called() {
238 if (storageImpl_counter > last_storageImpl_saved_value) {
239 last_storageImpl_saved_value = storageImpl_counter;
240 return true;
241 }
242 return false;
243 }
244
245 // basic dummy add function
custom_add_Tensor(const at::Tensor & self,const at::Tensor & other,const at::Scalar & alpha)246 at::Tensor custom_add_Tensor(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) {
247 add_counter += 1;
248 // Since this custom device is just for testing, not bothering to implement kernels.
249 return at::empty(self.sizes(), self.options());
250 }
251
252 // basic abs function
custom_abs_out(const at::Tensor & self,at::Tensor & out)253 at::Tensor& custom_abs_out(const at::Tensor& self, at::Tensor& out) {
254 return at::native::abs_out(self, out);
255 }
256
257 // A dummy allocator for our custom device, that secretly uses the CPU
258 struct DummyCustomAllocator final : at::Allocator {
259 DummyCustomAllocator() = default;
allocateDummyCustomAllocator260 at::DataPtr allocate(size_t nbytes) override {
261 void* data = c10::alloc_cpu(nbytes);
262 return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, custom_device_index)};
263 }
264
ReportAndDeleteDummyCustomAllocator265 static void ReportAndDelete(void* ptr) {
266 if (!ptr) {
267 return;
268 }
269 c10::free_cpu(ptr);
270 }
271
raw_deleterDummyCustomAllocator272 at::DeleterFnPtr raw_deleter() const override {
273 return &ReportAndDelete;
274 }
275
copy_dataDummyCustomAllocator276 void copy_data(void* dest, const void* src, std::size_t count) const final {
277 default_copy_data(dest, src, count);
278 }
279 };
280
281 // Register our dummy allocator
282 static DummyCustomAllocator global_custom_alloc;
283 REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_custom_alloc);
284
285 // basic dummy empty function, so we can directly construct tensors on the custom device
286 // This dummy test device will just use the CPU allocator, and ignores pinned memory.
custom_empty_memory_format(at::IntArrayRef size,std::optional<at::ScalarType> dtype,std::optional<at::Layout> layout,std::optional<at::Device> device,std::optional<bool> pin_memory,std::optional<at::MemoryFormat> memory_format)287 at::Tensor custom_empty_memory_format(at::IntArrayRef size,
288 std::optional<at::ScalarType> dtype,
289 std::optional<at::Layout> layout,
290 std::optional<at::Device> device,
291 std::optional<bool> pin_memory,
292 std::optional<at::MemoryFormat> memory_format) {
293 constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
294 return at::detail::empty_generic(size,
295 &global_custom_alloc,
296 private_use_ks,
297 c10::dtype_or_default(dtype),
298 memory_format);
299 }
custom_empty_symint(c10::IntArrayRef size,std::optional<at::ScalarType> dtype,std::optional<at::Layout> layout,std::optional<at::Device> device,std::optional<bool> pin_memory,std::optional<at::MemoryFormat> memory_format)300 at::Tensor custom_empty_symint(c10::IntArrayRef size,
301 std::optional<at::ScalarType> dtype,
302 std::optional<at::Layout> layout,
303 std::optional<at::Device> device,
304 std::optional<bool> pin_memory,
305 std::optional<at::MemoryFormat> memory_format) {
306 constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
307 return at::detail::empty_generic(size,
308 &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format);
309 }
310
custom_fill__scalar(at::Tensor & self,const at::Scalar & value)311 at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) {
312 // Not bothering to implement.
313 return self;
314 }
315
316 // Unsafe using dummy device data_ptr to creat a cpu tensor, and shared data_ptr.
unsafe_create_cpu_tensor_from_dummy_tensor(const at::Tensor & src)317 at::Tensor unsafe_create_cpu_tensor_from_dummy_tensor(const at::Tensor& src) {
318 TORCH_CHECK(src.device().type() == c10::DeviceType::PrivateUse1,
319 "Only support dummy device.");
320 const auto& sizes_ = src.sizes();
321 const auto& strides_ = src.strides();
322 auto storage_offset_ = src.storage_offset();
323 at::detail::check_size_nonnegative(sizes_);
324
325 size_t size_bytes = at::detail::computeStorageNbytes(sizes_, strides_,
326 src.element_size(),
327 storage_offset_);
328
329 at::DataPtr data_ptr =
330 c10::InefficientStdFunctionContext::makeDataPtr(src.storage().mutable_data_ptr().get(),
331 [](void*){}, at::kCPU);
332
333 c10::Storage storage{c10::Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr),
334 /*allocator=*/&global_custom_alloc, /*resizeable=*/false};
335
336 constexpr c10::DispatchKeySet cpu_ks(c10::DispatchKey::CPU);
337 at::Tensor tensor = at::detail::make_tensor<c10::TensorImpl>(
338 std::move(storage), cpu_ks, src.dtype());
339
340 c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
341 tensor_impl->set_sizes_and_strides(sizes_, strides_);
342 tensor_impl->set_storage_offset(storage_offset_);
343 return tensor;
344 }
345
346 // basic dummy copy_() function, so we can copy from the custom device to/from CPU
custom__copy_from(const at::Tensor & self,const at::Tensor & dst,bool non_blocking)347 at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) {
348 TORCH_CHECK(
349 self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1,
350 "Dummy test only allows copy from cpu -> dummy device.");
351 TORCH_CHECK(
352 dst.is_cpu() || dst.device().type() == c10::DeviceType::PrivateUse1,
353 "Dummy test only allows copy from cpu -> dummy device.");
354
355 // Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous.
356 TORCH_CHECK(self.sizes() == dst.sizes());
357 TORCH_CHECK(self.scalar_type() == dst.scalar_type());
358
359 if (self.is_contiguous() && dst.is_contiguous()) {
360 std::memcpy(dst.storage().data_ptr().get(),
361 self.storage().data_ptr().get(),
362 self.storage().nbytes());
363 } else {
364 // Using cpu tensor to accomplishment stride copy.
365 auto convert_to_cpu_tensor = [](const at::Tensor& src) -> at::Tensor {
366 if (src.device().type() == c10::DeviceType::PrivateUse1) {
367 return unsafe_create_cpu_tensor_from_dummy_tensor(src);
368 } else {
369 return src;
370 }
371 };
372 at::Tensor cpu_self = convert_to_cpu_tensor(self);
373 at::Tensor cpu_dst = convert_to_cpu_tensor(dst);
374 cpu_dst.copy_(cpu_self);
375 }
376
377 return dst;
378 }
379
// _copy_from_and_resize kernel: makes `dst` the same shape as `self`, then
// copies with custom__copy_from. Previously `dst` was never resized, so any
// shape mismatch failed the size check inside custom__copy_from.
// Presumably this is reached when the CPU fallback writes results back into
// device tensors -- TODO confirm callers.
at::Tensor custom__copy_from_and_resize(const at::Tensor& self, const at::Tensor& dst) {
  if (dst.sizes() != self.sizes()) {
    // Goes through the dispatcher (this device's resize_ kernel). It runs only
    // when the shapes differ, so a matching, possibly strided dst keeps its layout.
    dst.resize_(self.sizes());
  }
  return custom__copy_from(self, dst, false);
}
383
custom_empty_strided(c10::IntArrayRef size,c10::IntArrayRef stride,std::optional<at::ScalarType> dtype_opt,std::optional<at::Layout> layout_opt,std::optional<at::Device> device_opt,std::optional<bool> pin_memory_opt)384 at::Tensor custom_empty_strided(c10::IntArrayRef size,
385 c10::IntArrayRef stride,
386 std::optional<at::ScalarType> dtype_opt,
387 std::optional<at::Layout> layout_opt,
388 std::optional<at::Device> device_opt,
389 std::optional<bool> pin_memory_opt) {
390 constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
391 auto dtype = c10::dtype_or_default(dtype_opt);
392 return at::detail::empty_strided_generic(size, stride, &global_custom_alloc, private_use_ks, dtype);
393 }
394
395 // Some set operations for the basic use case
custom_set_source_Storage(at::Tensor & result,c10::Storage src)396 at::Tensor& custom_set_source_Storage(at::Tensor& result, c10::Storage src) {
397 int64_t new_size = static_cast<int64_t>(src.nbytes() / result.dtype().itemsize());
398 c10::IntArrayRef stride = {};
399 result.unsafeGetTensorImpl()->set_storage_offset(0);
400 at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : std::nullopt;
401 at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(),
402 new_size, stride_opt,
403 /*resize_storage=*/!result.is_meta());
404 return result;
405 }
406
407 // Some set operations for the basic use case
custom_set_source_Storage_storage_offset(at::Tensor & result,c10::Storage storage,int64_t storage_offset,c10::IntArrayRef size,c10::IntArrayRef stride)408 at::Tensor& custom_set_source_Storage_storage_offset(at::Tensor& result,
409 c10::Storage storage,
410 int64_t storage_offset,
411 c10::IntArrayRef size,
412 c10::IntArrayRef stride) {
413 result.unsafeGetTensorImpl()->set_storage_offset(storage_offset);
414 at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : std::nullopt;
415 at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(),
416 size, stride_opt,
417 /*resize_storage=*/!result.is_meta());
418 return result;
419 }
420
// resize_ kernel for the dummy device. The dummy device uses CPU memory, so this
// delegates to the CPU helper resize_impl_cpu_ from ATen/native/Resize.h, like
// the set_ kernels in this file do. That helper:
//   * returns early when `size` already matches, keeping existing strides --
//     unconditionally calling set_sizes_contiguous reset them;
//   * otherwise sets contiguous sizes and grows the storage through
//     maybe_resize_storage_cpu as needed.
// An explicit memory format then restrides the tensor; MemoryFormat::Preserve
// is rejected.
const at::Tensor& custom_resize_(const at::Tensor& self, at::IntArrayRef size,
                                 std::optional<at::MemoryFormat> optional_memory_format) {
  at::TensorImpl* tensor_impl = self.unsafeGetTensorImpl();
  at::native::resize_impl_cpu_(tensor_impl, size, /*stride=*/std::nullopt,
                               /*resize_storage=*/true);
  if (optional_memory_format.has_value()) {
    auto memory_format =
        optional_memory_format.value();
    TORCH_CHECK(
        memory_format != at::MemoryFormat::Preserve,
        "Unsupported memory format",
        memory_format);
    tensor_impl->empty_tensor_restride(memory_format);
  }
  return self;
}
443
444 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, c10::SymInt, c10::SymInt, at::Tensor, at::Tensor, at::Tensor>
custom_scaled_dot_product_fused_attention_overrideable(const at::Tensor & query,const at::Tensor & key,const at::Tensor & value,const std::optional<at::Tensor> & attn_bias,double dropout_p,bool is_causal,bool return_debug_mask,std::optional<double> scale)445 custom_scaled_dot_product_fused_attention_overrideable(
446 const at::Tensor & query,
447 const at::Tensor & key,
448 const at::Tensor & value,
449 const std::optional<at::Tensor> & attn_bias,
450 double dropout_p,
451 bool is_causal,
452 bool return_debug_mask,
453 std::optional<double> scale) {
454 const int64_t batch_size = query.size(0);
455 const int64_t num_heads = query.size(1);
456 const int64_t head_dim_qk = query.size(3);
457 const int64_t head_dim_v = value.size(3);
458 const int64_t max_seqlen_q = query.size(2);
459 const int64_t max_seqlen_kv = key.size(2);
460
461 auto opts = query.options();
462 auto output = at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts);
463 auto logsumexp = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
464 auto debug_attn_mask = at::empty({batch_size, num_heads, max_seqlen_q, max_seqlen_kv},
465 opts.dtype(at::kFloat));
466 auto philox_seed = at::empty({}, at::dtype(at::kLong));
467 auto philox_offset = at::empty({}, at::dtype(at::kLong));
468
469 return std::make_tuple(output, logsumexp, at::Tensor(), at::Tensor(), max_seqlen_q, max_seqlen_kv, philox_seed, philox_offset, debug_attn_mask);
470 }
471 std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
custom_scaled_dot_product_fused_attention_overrideable_backward(const at::Tensor & grad_out,const at::Tensor & query,const at::Tensor & key,const at::Tensor & value,const at::Tensor & attn_bias,std::array<bool,4> grad_input_mask,const at::Tensor & out,const at::Tensor & logsumexp,const at::Tensor & cum_seq_q,const at::Tensor & cum_seq_k,int64_t max_q,int64_t max_k,double dropout_p,bool is_causal,const at::Tensor & philox_seed,const at::Tensor & philox_offset,std::optional<double> scale)472 custom_scaled_dot_product_fused_attention_overrideable_backward(
473 const at::Tensor & grad_out,
474 const at::Tensor & query,
475 const at::Tensor & key,
476 const at::Tensor & value,
477 const at::Tensor & attn_bias,
478 std::array<bool,4> grad_input_mask,
479 const at::Tensor & out,
480 const at::Tensor & logsumexp,
481 const at::Tensor & cum_seq_q,
482 const at::Tensor & cum_seq_k,
483 int64_t max_q,
484 int64_t max_k,
485 double dropout_p,
486 bool is_causal,
487 const at::Tensor & philox_seed,
488 const at::Tensor & philox_offset,
489 std::optional<double> scale) {
490 return std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>(
491 at::empty_like(query),
492 at::empty_like(key),
493 at::empty_like(value),
494 at::empty_like(attn_bias));
495 }
496
497 // This macro does the heavy lifting.
498 // With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
499 // For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.
500 // Later in this file, we map a custom device to the PrivateUse1 device type,
501 // which allows user code that puts a tensor on your custom_device to eventually get plumbed
502 // into the kernels registered here.
503 //
504 // This macro registers your kernels to the PyTorch Dispatcher.
505 // More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
TORCH_LIBRARY_IMPL(aten,PrivateUse1,m)506 TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
507 m.impl("abs.out", &custom_abs_out);
508 m.impl("add.Tensor", &custom_add_Tensor);
509 m.impl("empty.memory_format", &custom_empty_symint);
510 m.impl("fill_.Scalar", &custom_fill__scalar);
511 m.impl("_copy_from", &custom__copy_from);
512 m.impl("_copy_from_and_resize", &custom__copy_from_and_resize);
513 m.impl("empty_strided", &custom_empty_strided);
514 m.impl("set_.source_Storage", &custom_set_source_Storage);
515 m.impl("set_.source_Storage_storage_offset",&custom_set_source_Storage_storage_offset);
516 m.impl("resize_", &custom_resize_);
517 m.impl("as_strided", at::native::as_strided_tensorimpl);
518 m.impl("quantize_per_tensor", at::native::quantize_per_tensor);
519 m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1);
520 m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable);
521 m.impl("_scaled_dot_product_fused_attention_overrideable_backward", &custom_scaled_dot_product_fused_attention_overrideable_backward);
522 }
523
custom_cpu_fallback(const c10::OperatorHandle & op,torch::jit::Stack * stack)524 void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
525 at::native::cpu_fallback(op, stack);
526 }
527
// Operators without a dedicated PrivateUse1 kernel; each one is served by the
// boxed CPU fallback (custom_cpu_fallback).
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  for (const char* op_name : {"sub.Tensor", "_foreach_add.List", "_fused_adamw_",
                              "index.Tensor", "triu_indices"}) {
    m.impl(op_name, torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
  }
}
535
536 // This basic implementation doesn't bother dealing with different device indices
537 // (e.g. custom_device:0 vs. custom_device:1).
538 // We could do that by letting the user pass in a device index in our exposed device function.
539 // Note that if you do that, you'll also need to register a device guard to core.
540 // See `c10/core/impl/DeviceGuardImplInterface.h:C10_REGISTER_GUARD_IMPL`.
get_custom_device()541 c10::Device get_custom_device() {
542 return c10::Device(c10::DeviceType::PrivateUse1, 0);
543 }
544
custom_add_called()545 bool custom_add_called() {
546 bool called = false;
547 if (add_counter > last_saved_value) {
548 called = true;
549 last_saved_value = add_counter;
550 }
551 return called;
552 }
553
554 class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
555 public:
556 // Constructors
PrivateGeneratorImpl(c10::DeviceIndex device_index)557 PrivateGeneratorImpl(c10::DeviceIndex device_index) {
558 device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
559 key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1);
560 }
561 ~PrivateGeneratorImpl() override = default;
562 };
563
564 // this is used to register generator
make_generator_privateuse1(c10::DeviceIndex device_index)565 at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) {
566 return at::make_generator<PrivateGeneratorImpl>(device_index);
567 }
568
// Registers make_generator_privateuse1 as the PrivateUse1 generator factory.
// The macro is written without a trailing semicolon; presumably it supplies its
// own -- TODO confirm against GeneratorForPrivateuseone.h.
// register_generator_second performs the same registration, presumably so tests
// can exercise a second registration attempt.
void register_generator_first() {
  REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
}
572
// Repeats the PrivateUse1 generator registration done by
// register_generator_first; presumably used by tests to check how a second
// registration is handled -- TODO confirm expected outcome.
void register_generator_second() {
  REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
}
576
set_custom_device_index(c10::DeviceIndex device_index)577 void set_custom_device_index(c10::DeviceIndex device_index) {
578 custom_device_index = device_index;
579 }
580
// Global flag backing the dummy pin_memory support of the custom device. It is
// set to true by FooHooksInterface::getPinnedMemoryAllocator and returned as-is
// by FooHooksInterface::isPinnedPtr.
bool custom_pinned_flag{false};
583
584 struct FooHooksArgs : public at::PrivateUse1HooksArgs {};
585
586 struct FooHooksInterface : public at::PrivateUse1HooksInterface {
FooHooksInterfaceFooHooksInterface587 FooHooksInterface(FooHooksArgs) {}
588 ~FooHooksInterface() override = default;
getDefaultGeneratorFooHooksInterface589 const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) const override {
590 static auto device_gen = make_generator_privateuse1(device_index);
591 return device_gen;
592 }
593 // this is a simple implementation, custom_pinned_flag will be set as true
594 // once tensor.pin_memory() is called. And then tensor.is_pinned()
595 // always return true no matter what tensor it's called on.
isPinnedPtrFooHooksInterface596 bool isPinnedPtr(const void* data) const override {
597 return custom_pinned_flag;
598 }
getPinnedMemoryAllocatorFooHooksInterface599 c10::Allocator* getPinnedMemoryAllocator() const override {
600 custom_pinned_flag = true;
601 return c10::GetCPUAllocator();
602 }
603 };
604
// File-local registry of PrivateUse1 hooks implementations, keyed by name and
// constructed from FooHooksArgs. FooHooksInterface is registered under
// "FooHooks"; get_private_hooks() instantiates it via Create("FooHooks", ...).
TORCH_DECLARE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs);
C10_DEFINE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs)
// The Create function of PrivateUse1HooksRegistry is used to obtain a PrivateUse1HooksInterface pointer.
C10_REGISTER_TYPED_CLASS(PrivateUse1HooksRegistry, "FooHooks", FooHooksInterface)
609
610 static at::PrivateUse1HooksInterface* privateuse1_hooks_local = nullptr;
get_private_hooks()611 static at::PrivateUse1HooksInterface* get_private_hooks() {
612 static c10::once_flag once;
613 c10::call_once(once, [] {
614 privateuse1_hooks_local = PrivateUse1HooksRegistry()->Create("FooHooks", {}).release();
615 if (!privateuse1_hooks_local) {
616 privateuse1_hooks_local = new FooHooksInterface(FooHooksArgs{});
617 }
618 });
619 return privateuse1_hooks_local;
620 }
621
register_hook()622 void register_hook() {
623 at::RegisterPrivateUse1HooksInterface(get_private_hooks());
624 }
625
is_register_hook()626 bool is_register_hook() {
627 return privateuse1_hooks_local != nullptr;
628 }
629
default_generator(c10::DeviceIndex device_index)630 const at::Generator& default_generator(c10::DeviceIndex device_index) {
631 return at::globalContext().defaultGenerator(at::Device(c10::DeviceType::PrivateUse1, device_index));;
632 }
633
fallback_with_undefined_tensor()634 void fallback_with_undefined_tensor() {
635 at::Tensor first = at::empty((2,3)).to(at::DeviceType::PrivateUse1);
636 at::Tensor second = at::Tensor();
637 at::Tensor step = at::empty({}).fill_(2).to(at::DeviceType::PrivateUse1);
638 at::Tensor grad_scale = at::empty({}).fill_(0.00001).to(at::DeviceType::PrivateUse1);
639 at::Tensor found_inf = at::empty({}).fill_(1).to(at::DeviceType::PrivateUse1);
640 at::TensorList tensors = {first, first};
641 at::TensorList undefined_tensors = {first, second};
642 at::TensorList steps = {step, step};
643 return at::_fused_adamw_(tensors, tensors, tensors, tensors, undefined_tensors,
644 steps, 0.001, 0.9, 0.999, 1e-2, 1e-8, false, false,
645 grad_scale, found_inf);
646 }
647
648 struct CustomAutogradFnReturnsSelf : public torch::autograd::Function<CustomAutogradFnReturnsSelf> {
649
forwardCustomAutogradFnReturnsSelf650 static at::Tensor forward(torch::autograd::AutogradContext* ctx, at::Tensor self) {
651 return self;
652 }
653
backwardCustomAutogradFnReturnsSelf654 static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) {
655 return {grad_output[0] * 0.5};
656 }
657 };
658
659 struct CustomAutogradFnAliasing : public torch::autograd::Function<CustomAutogradFnAliasing> {
660
forwardCustomAutogradFnAliasing661 static at::Tensor forward(torch::autograd::AutogradContext* ctx, at::Tensor self) {
662 return self.view_symint(self.sym_sizes());
663 }
664
backwardCustomAutogradFnAliasing665 static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) {
666 return {grad_output[0] * 0.5};
667 }
668 };
669
custom_autograd_fn_returns_self(at::Tensor x)670 at::Tensor custom_autograd_fn_returns_self(at::Tensor x) {
671 return CustomAutogradFnReturnsSelf::apply(x);
672 }
673
custom_autograd_fn_aliasing(at::Tensor x)674 at::Tensor custom_autograd_fn_aliasing(at::Tensor x) {
675 return CustomAutogradFnAliasing::apply(x);
676 }
677
678 // Here, we're exposing a custom device object that corresponds to our custom backend.
679 // We do this using pybind: exposing an "extension_name.custom_device()" function in python,
680 // that's implemented in C++.
681 // The implementation in this file maps directly to the `PrivateUse1` device type.
PYBIND11_MODULE(TORCH_EXTENSION_NAME,m)682 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
683 m.def("custom_device", &get_custom_device, "get custom device object");
684 m.def("custom_add_called", &custom_add_called, "check if our custom add function was called");
685 m.def("register_generator_first", ®ister_generator_first, "register generator for custom device firstly");
686 m.def("register_generator_second", ®ister_generator_second, "register generator for custom device secondly");
687 m.def("set_custom_device_index", &set_custom_device_index, "set custom device index");
688 m.def("custom_storage_registry", &custom_storage_registry, "set custom storageImpl creat method");
689 m.def("custom_storageImpl_called", &custom_storageImpl_called, "check if our custom abs function was called");
690 m.def("custom_set_backend_meta", &custom_set_backend_meta, "a fake set tensor BackendMeta function");
691 m.def("check_backend_meta", &check_backend_meta, "check if BackendMeta serialization correctly");
692 m.def("custom_serialization_registry", &custom_serialization_registry, "register custom serialization function");
693 m.def("register_hook", ®ister_hook, "register_hook for privateuse1");
694 m.def("is_register_hook", &is_register_hook, "is_register_hook for privateuse1");
695 m.def("default_generator", &default_generator, "default_generator for privateuse1");
696 m.def("fallback_with_undefined_tensor", &fallback_with_undefined_tensor, "fallback_with_undefined_tensor for privateuse1");
697
698 // Co-opting this file to more easily test torch.compile'ing of custom autograd functions in C++
699 m.def("custom_autograd_fn_returns_self", &custom_autograd_fn_returns_self);
700 }
701
// Defines the _test_funcs library with one operator. Its schema marks the
// output as aliasing the input (both annotated Tensor(a)).
TORCH_LIBRARY(_test_funcs, m) {
  m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)");
}
// AutogradCPU kernel for _test_funcs::custom_autograd_fn_aliasing.
TORCH_LIBRARY_IMPL(_test_funcs, AutogradCPU, m) {
  m.impl("custom_autograd_fn_aliasing", custom_autograd_fn_aliasing);
}
708