#include <unordered_map>
#include <c10/core/impl/alloc_cpu.h>
#include <c10/core/Allocator.h>
#include <c10/core/ScalarType.h>
#include <c10/util/ArrayRef.h>

#include <torch/csrc/Device.h>
#include <torch/csrc/jit/serialization/pickler.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <torch/extension.h>

#include <ATen/native/cpu/Loops.h>
#include <ATen/native/quantized/AffineQuantizer.h>
#include <ATen/native/DispatchStub.h>
#include <ATen/native/Resize.h>
#include <ATen/native/UnaryOps.h>
#include <ATen/native/CPUFallback.h>
#include <ATen/ops/abs_native.h>
#include <ATen/EmptyTensor.h>
#include <ATen/core/GeneratorForPrivateuseone.h>
#include <ATen/detail/PrivateUse1HooksInterface.h>
#include <ATen/ops/view.h>
#include <ATen/native/transformers/sdp_utils_cpp.h>
#include <ATen/native/transformers/attention.h>

static uint64_t add_counter = 0;
static uint64_t last_saved_value = 0;
static c10::DeviceIndex custom_device_index = 0;

static uint64_t abs_counter = 0;
static uint64_t last_abs_saved_value = 0;

static uint64_t storageImpl_counter = 0;
static uint64_t last_storageImpl_saved_value = 0;
// Register the device guard.
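// A DeviceGuardImpl must be registered for the PrivateUse1 key so that core
// device guards and related queries work for this backend; since this demo
// backend has no real device state, the no-op implementation is sufficient.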
namespace at {
namespace detail {

C10_REGISTER_GUARD_IMPL(
    PrivateUse1,
    c10::impl::NoOpDeviceGuardImpl<DeviceType::PrivateUse1>);

}} // namespace at::detail

namespace {

// Use the simplest approach to obtain contiguous Tensor data and process it.
// This is a demo of the operand API; you can add more complex logic
// for the input and output tensors based on your custom device kernel.
void abs_kernel(at::TensorIteratorBase& iter) {
  // abs has only one input tensor and one output tensor.
  auto& output_operand = iter.operand(0);
  auto& input_operand = iter.operand(1);
  auto& output_tensor_base = output_operand.tensor_base();
  auto& input_tensor_base = input_operand.tensor_base();
  TORCH_CHECK(!input_operand.original_tensor_base().defined(),
    "input original tensor is defined.");
  TORCH_CHECK(!output_operand.original_tensor_base().defined(),
    "output original tensor is defined.");
  // For easy testing, only accept contiguous input tensors for the calculation.
  auto memory_format = input_tensor_base.suggest_memory_format();
  TORCH_CHECK(input_tensor_base.is_contiguous(memory_format),
    "Input tensor needs to be contiguous.");
  // Add necessary restrictions to ensure the safety of the demo.
  TORCH_CHECK(input_tensor_base.sizes() == output_tensor_base.sizes(),
    "Input and output tensor sizes are not equal.");
  // The common dtype is computed by TensorIteratorBase.
  TORCH_CHECK(iter.common_dtype() == at::ScalarType::Float,
    "Only support float type.");
  // Use a plain for loop to compute abs.
  auto abs_function = [](float* output_ptr, const float* input_ptr,
                         const int64_t NUM) {
    for (int64_t i = 0; i < NUM; ++i) {
      *(output_ptr + i) = std::abs(*(input_ptr + i));
    }
  };
  // To simplify the logic of the test demo code,
  // we only use contiguous tensors for the device-side calculation,
  // and we use the input tensor's memory format.
  if (iter.is_contiguous()) {
    // Check the will_resize flag. You can convert to a different
    // tensor memory format when will_resize is True.
    // If the TensorIteratorConfig resize_outputs_ flag is true, there are two
    // situations:
    // 1) The out tensor is undefined, and TensorIterator sets will_resize to true;
    // 2) The out tensor is defined but its size is not equal to the input tensor size;
    //    TensorIterator sets will_resize to true and calls set_output_raw_strided
    //    to resize the output tensor.
    // When the output operand's will_resize flag is true, the dummy
    // device can convert the tensor to the dummy device's preferred memory format.
    // Here we don't convert the tensor memory format, because it gets complicated
    // when the dummy device wants to keep the same memory format while training a network.
    TORCH_CHECK(output_operand.will_resize,
      "output operand will_resize flag needs to be True.");
    abs_function((float*)iter.data_ptr(0), (float*)iter.data_ptr(1), iter.numel());
  } else {
    // Strided copy is not supported for the foo device, so use the cpu device instead.
    // For the abs op, the remaining case is: the output tensor is not contiguous and
    // the operand's will_resize is False.
    TORCH_CHECK(!output_operand.will_resize, "output operand will_resize is True.");
    // Get a contiguous tensor with the input memory format.
    at::Tensor output = at::empty(output_tensor_base.sizes(),
                                  input_tensor_base.options()
                                                   .memory_format(memory_format));
    // For structured ops that inherit from TensorIteratorBase, you may need to
    // call set_output_raw_strided to update the output stored in the structured op.
    // The abs op does not need this.
    output_operand.exchange_tensor(c10::MaybeOwned<at::TensorBase>::owned(std::in_place, output));
    abs_function((float*)output_operand.tensor_base().mutable_data_ptr(),
                 (float*)iter.data_ptr(1), iter.numel());
    // Copy the tensor base back to the original tensor base, keeping the same
    // scalar type and stride as the cpu and gpu paths.
    if (output_operand.original_tensor_base().defined() &&
        !output_operand.original_tensor_base().is_same(output_operand.tensor_base())) {
      output_operand.original_tensor().copy_(output_operand.tensor());
      output_operand.restore_original_tensor();
    }
  }
}
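
// Illustrative flow (an assumption about how the accompanying test drives this
// kernel, not something enforced in this file): a call such as
//   at::Tensor x = at::rand({3, 4}).to(at::Device(at::DeviceType::PrivateUse1, 0));
//   at::Tensor y = at::abs(x);
// dispatches abs.out to custom_abs_out below, which calls at::native::abs_out;
// that builds a TensorIterator and invokes abs_kernel through the abs_stub
// registered for PrivateUse1.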

void quantize_tensor_per_tensor_affine_privateuse1(
    const at::Tensor& rtensor,
    at::Tensor& qtensor,
    double scale,
    int64_t zero_point) {
    // do nothing
}

int64_t _fused_sdp_choice_privateuse1(const at::Tensor & query, const at::Tensor & key, const at::Tensor & value,
    const std::optional<at::Tensor> & attn_mask, double dropout_p, bool is_causal, std::optional<double> scale, bool enable_gqa){
  auto backend = sdp::SDPBackend::overrideable;
  return static_cast<int64_t>(backend);
}
} // namespace

namespace at::native {

REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &abs_kernel);
REGISTER_PRIVATEUSE1_DISPATCH(quantize_tensor_per_tensor_affine_stub, &quantize_tensor_per_tensor_affine_privateuse1);
REGISTER_PRIVATEUSE1_DISPATCH(_fused_sdp_choice_stub, &_fused_sdp_choice_privateuse1);

} // namespace at::native

struct CustomBackendMetadata : public c10::BackendMeta {
  // For testing: this field will mutate when clone() is called by shallow_copy_from.
  int backend_version_format_{-1};
  int format_number_{-1};
  mutable bool cloned_{false};
  // define the constructor
  CustomBackendMetadata(int backend_version_format, int format_number) :
      backend_version_format_(backend_version_format), format_number_(format_number) {}
  c10::intrusive_ptr<c10::BackendMeta> clone(
      const c10::intrusive_ptr<c10::BackendMeta>& ptr) const override {
    cloned_ = true;
    return c10::BackendMeta::clone(ptr);
  }
};

// we need to register two functions for serialization
void for_serialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
  if (t.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr() == nullptr) {
    return;
  }
  auto tmeta = dynamic_cast<CustomBackendMetadata*>(t.unsafeGetTensorImpl()->get_backend_meta());
  if (tmeta->backend_version_format_ == 1) {
    m["backend_version_format"] = true;
  }
  if (tmeta->format_number_ == 29) {
    m["format_number"] = true;
  }
}

void for_deserialization(const at::Tensor& t, std::unordered_map<std::string, bool>& m) {
  int backend_version_format{-1};
  int format_number{-1};
  if (m.find("backend_version_format") != m.end()) {
    backend_version_format = 1;
  }
  if (m.find("format_number") != m.end()) {
    format_number = 29;
  }
  c10::intrusive_ptr<c10::BackendMeta> new_tmeta{std::unique_ptr<c10::BackendMeta>(
      new CustomBackendMetadata(backend_version_format, format_number))};
  t.unsafeGetTensorImpl()->set_backend_meta(new_tmeta);
}

void custom_serialization_registry() {
  torch::jit::TensorBackendMetaRegistry(c10::DeviceType::PrivateUse1,
                                        &for_serialization,
                                        &for_deserialization);
}
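
// Sketch of the intended round trip (an assumption about the Python test that
// loads this extension): custom_serialization_registry() installs the two hooks
// above, custom_set_backend_meta() below tags a tensor with CustomBackendMetadata,
// torch.save()/torch.load() then call for_serialization()/for_deserialization()
// through the pickler, and check_backend_meta() verifies the metadata survived.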

// Check whether BackendMeta serialization and deserialization worked correctly.
bool check_backend_meta(const at::Tensor& t) {
  if (t.unsafeGetTensorImpl()->get_backend_meta_intrusive_ptr()) {
    CustomBackendMetadata* tmeta = dynamic_cast<CustomBackendMetadata*>(
        t.unsafeGetTensorImpl()->get_backend_meta());
    if (tmeta->backend_version_format_ == 1 && tmeta->format_number_ == 29) {
      return true;
    }
  }
  return false;
}

// a fake set function is exposed to the Python side
void custom_set_backend_meta(const at::Tensor& t) {
  int backend_version_format{1};
  int format_number{29};
  c10::intrusive_ptr<c10::BackendMeta> new_tmeta{std::unique_ptr<c10::BackendMeta>(
      new CustomBackendMetadata(backend_version_format, format_number))};
  t.unsafeGetTensorImpl()->set_backend_meta(new_tmeta);
}

// A dummy storageImpl for our custom device, that secretly uses the CPU
c10::intrusive_ptr<c10::StorageImpl> make_custom_storage_impl(c10::StorageImpl::use_byte_size_t,
                                                              c10::SymInt size_bytes,
                                                              c10::DataPtr data_ptr,
                                                              c10::Allocator* allocator,
                                                              bool resizable) {
  c10::intrusive_ptr<c10::StorageImpl> custom_storage_impl;
  if (data_ptr == nullptr) {
    custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(), size_bytes, allocator, resizable);
  } else {
    custom_storage_impl = c10::make_intrusive<c10::StorageImpl>(
      c10::StorageImpl::use_byte_size_t(), size_bytes, std::move(data_ptr), allocator, resizable);
  }
  storageImpl_counter += 1;
  return custom_storage_impl;
}

// Register our dummy storageImpl create method.
void custom_storage_registry() {
  c10::SetStorageImplCreate(c10::DeviceType::PrivateUse1, &make_custom_storage_impl);
}
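
// Once custom_storage_registry() has run, every StorageImpl created for a
// PrivateUse1 tensor is produced by make_custom_storage_impl above, which is
// what custom_storageImpl_called() below observes via storageImpl_counter.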

bool custom_storageImpl_called() {
  if (storageImpl_counter > last_storageImpl_saved_value) {
    last_storageImpl_saved_value = storageImpl_counter;
    return true;
  }
  return false;
}

// basic dummy add function
at::Tensor custom_add_Tensor(const at::Tensor& self, const at::Tensor& other, const at::Scalar& alpha) {
  add_counter += 1;
  // Since this custom device is just for testing, not bothering to implement kernels.
  return at::empty(self.sizes(), self.options());
}

// basic abs function
at::Tensor& custom_abs_out(const at::Tensor& self, at::Tensor& out) {
  return at::native::abs_out(self, out);
}

// A dummy allocator for our custom device, that secretly uses the CPU
struct DummyCustomAllocator final : at::Allocator {
  DummyCustomAllocator() = default;
  at::DataPtr allocate(size_t nbytes) override {
    void* data = c10::alloc_cpu(nbytes);
    return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, custom_device_index)};
  }

  static void ReportAndDelete(void* ptr) {
    if (!ptr) {
      return;
    }
    c10::free_cpu(ptr);
  }

  at::DeleterFnPtr raw_deleter() const override {
    return &ReportAndDelete;
  }

  void copy_data(void* dest, const void* src, std::size_t count) const final {
    default_copy_data(dest, src, count);
  }
};

// Register our dummy allocator
static DummyCustomAllocator global_custom_alloc;
REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_custom_alloc);
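
// With the allocator registered, c10::GetAllocator(c10::DeviceType::PrivateUse1)
// returns &global_custom_alloc, so the factory functions below can hand out
// "device" memory that actually lives on the CPU heap.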

// Basic dummy empty function, so we can directly construct tensors on the custom device.
// This dummy test device just uses the CPU allocator and ignores pinned memory.
at::Tensor custom_empty_memory_format(at::IntArrayRef size,
                                      std::optional<at::ScalarType> dtype,
                                      std::optional<at::Layout> layout,
                                      std::optional<at::Device> device,
                                      std::optional<bool> pin_memory,
                                      std::optional<at::MemoryFormat> memory_format) {
  constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
  return at::detail::empty_generic(size,
                                   &global_custom_alloc,
                                   private_use_ks,
                                   c10::dtype_or_default(dtype),
                                   memory_format);
}
at::Tensor custom_empty_symint(c10::IntArrayRef size,
                               std::optional<at::ScalarType> dtype,
                               std::optional<at::Layout> layout,
                               std::optional<at::Device> device,
                               std::optional<bool> pin_memory,
                               std::optional<at::MemoryFormat> memory_format) {
  constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
  return at::detail::empty_generic(size,
    &global_custom_alloc, private_use_ks, c10::dtype_or_default(dtype), memory_format);
}
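
// Note: of these two empty implementations, only custom_empty_symint is
// registered for empty.memory_format in the TORCH_LIBRARY_IMPL block below;
// custom_empty_memory_format is left unregistered in this file.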

at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) {
  // Not bothering to implement.
  return self;
}

// Unsafely create a cpu tensor from the dummy device tensor's data_ptr, sharing the same data_ptr.
at::Tensor unsafe_create_cpu_tensor_from_dummy_tensor(const at::Tensor& src) {
  TORCH_CHECK(src.device().type() == c10::DeviceType::PrivateUse1,
              "Only support dummy device.");
  const auto& sizes_ = src.sizes();
  const auto& strides_ = src.strides();
  auto storage_offset_ = src.storage_offset();
  at::detail::check_size_nonnegative(sizes_);

  size_t size_bytes = at::detail::computeStorageNbytes(sizes_, strides_,
                                                       src.element_size(),
                                                       storage_offset_);

  at::DataPtr data_ptr =
    c10::InefficientStdFunctionContext::makeDataPtr(src.storage().mutable_data_ptr().get(),
                                                    [](void*){}, at::kCPU);

  c10::Storage storage{c10::Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr),
    /*allocator=*/&global_custom_alloc, /*resizeable=*/false};

  constexpr c10::DispatchKeySet cpu_ks(c10::DispatchKey::CPU);
  at::Tensor tensor = at::detail::make_tensor<c10::TensorImpl>(
       std::move(storage), cpu_ks, src.dtype());

  c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
  tensor_impl->set_sizes_and_strides(sizes_, strides_);
  tensor_impl->set_storage_offset(storage_offset_);
  return tensor;
}
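
// The CPU tensor returned above aliases the dummy device storage (its DataPtr
// has a no-op deleter), so writes through it are visible in the original
// PrivateUse1 tensor; custom__copy_from below relies on this for strided copies.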

// Basic dummy copy_() function, so we can copy between the custom device and CPU.
at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) {
  TORCH_CHECK(
      self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1,
      "Dummy test only allows copy from cpu -> dummy device.");
  TORCH_CHECK(
      dst.is_cpu() || dst.device().type() == c10::DeviceType::PrivateUse1,
      "Dummy test only allows copy from cpu -> dummy device.");

  // Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous.
  TORCH_CHECK(self.sizes() == dst.sizes());
  TORCH_CHECK(self.scalar_type() == dst.scalar_type());

  if (self.is_contiguous() && dst.is_contiguous()) {
    std::memcpy(dst.storage().data_ptr().get(),
                self.storage().data_ptr().get(),
                self.storage().nbytes());
  } else {
    // Use cpu tensors to perform the strided copy.
    auto convert_to_cpu_tensor = [](const at::Tensor& src) -> at::Tensor {
      if (src.device().type() == c10::DeviceType::PrivateUse1) {
        return unsafe_create_cpu_tensor_from_dummy_tensor(src);
      } else {
        return src;
      }
    };
    at::Tensor cpu_self = convert_to_cpu_tensor(self);
    at::Tensor cpu_dst = convert_to_cpu_tensor(dst);
    cpu_dst.copy_(cpu_self);
  }

  return dst;
}

at::Tensor custom__copy_from_and_resize(const at::Tensor& self, const at::Tensor& dst) {
  return custom__copy_from(self, dst, false);
}

at::Tensor custom_empty_strided(c10::IntArrayRef size,
                                c10::IntArrayRef stride,
                                std::optional<at::ScalarType> dtype_opt,
                                std::optional<at::Layout> layout_opt,
                                std::optional<at::Device> device_opt,
                                std::optional<bool> pin_memory_opt) {
  constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1);
  auto dtype = c10::dtype_or_default(dtype_opt);
  return at::detail::empty_strided_generic(size, stride, &global_custom_alloc, private_use_ks, dtype);
}

// Some set operations for the basic use case
at::Tensor& custom_set_source_Storage(at::Tensor& result, c10::Storage src) {
  int64_t new_size = static_cast<int64_t>(src.nbytes() / result.dtype().itemsize());
  c10::IntArrayRef stride = {};
  result.unsafeGetTensorImpl()->set_storage_offset(0);
  at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : std::nullopt;
  at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(),
                               new_size, stride_opt,
                               /*resize_storage=*/!result.is_meta());
  return result;
}

// Some set operations for the basic use case
at::Tensor& custom_set_source_Storage_storage_offset(at::Tensor& result,
                                                     c10::Storage storage,
                                                     int64_t storage_offset,
                                                     c10::IntArrayRef size,
                                                     c10::IntArrayRef stride) {
  result.unsafeGetTensorImpl()->set_storage_offset(storage_offset);
  at::OptionalIntArrayRef stride_opt = stride.data() != nullptr ? at::OptionalIntArrayRef(stride) : std::nullopt;
  at::native::resize_impl_cpu_(result.unsafeGetTensorImpl(),
                               size, stride_opt,
                               /*resize_storage=*/!result.is_meta());
  return result;
}

const at::Tensor& custom_resize_(const at::Tensor& self, at::IntArrayRef size,
                          std::optional<at::MemoryFormat> optional_memory_format) {
  at::TensorImpl* tensor_impl = self.unsafeGetTensorImpl();
  tensor_impl->set_sizes_contiguous(size);
  const auto itemsize = tensor_impl->dtype().itemsize();
  const auto offset = tensor_impl->storage_offset();
  const auto storage_size = at::detail::computeStorageNbytesContiguous(size, itemsize, offset);
  // The dummy device uses the cpu allocator, so just call the cpu function
  // maybe_resize_storage_cpu from aten/src/ATen/native/Resize.h
  // to get sufficient memory space.
  at::native::maybe_resize_storage_cpu(tensor_impl, storage_size);
  if (optional_memory_format.has_value()) {
    auto memory_format =
        optional_memory_format.value();
    TORCH_CHECK(
        memory_format != at::MemoryFormat::Preserve,
        "Unsupported memory format ",
        memory_format);
    tensor_impl->empty_tensor_restride(memory_format);
  }
  return self;
}

std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor, c10::SymInt, c10::SymInt, at::Tensor, at::Tensor, at::Tensor>
custom_scaled_dot_product_fused_attention_overrideable(
    const at::Tensor & query,
    const at::Tensor & key,
    const at::Tensor & value,
    const std::optional<at::Tensor> & attn_bias,
    double dropout_p,
    bool is_causal,
    bool return_debug_mask,
    std::optional<double> scale) {
  const int64_t batch_size = query.size(0);
  const int64_t num_heads = query.size(1);
  const int64_t head_dim_qk = query.size(3);
  const int64_t head_dim_v = value.size(3);
  const int64_t max_seqlen_q = query.size(2);
  const int64_t max_seqlen_kv = key.size(2);

  auto opts = query.options();
  auto output = at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts);
  auto logsumexp = at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat));
  auto debug_attn_mask = at::empty({batch_size, num_heads, max_seqlen_q, max_seqlen_kv},
                                   opts.dtype(at::kFloat));
  auto philox_seed = at::empty({}, at::dtype(at::kLong));
  auto philox_offset = at::empty({}, at::dtype(at::kLong));

  return std::make_tuple(output, logsumexp, at::Tensor(), at::Tensor(), max_seqlen_q, max_seqlen_kv, philox_seed, philox_offset, debug_attn_mask);
}
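
// Only output, logsumexp, the philox state and the debug mask are materialized
// above; cum_seq_q/cum_seq_k are returned as undefined tensors, presumably
// because this dummy implementation has no varlen support.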
std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>
custom_scaled_dot_product_fused_attention_overrideable_backward(
    const at::Tensor & grad_out,
    const at::Tensor & query,
    const at::Tensor & key,
    const at::Tensor & value,
    const at::Tensor & attn_bias,
    std::array<bool,4> grad_input_mask,
    const at::Tensor & out,
    const at::Tensor & logsumexp,
    const at::Tensor & cum_seq_q,
    const at::Tensor & cum_seq_k,
    int64_t max_q,
    int64_t max_k,
    double dropout_p,
    bool is_causal,
    const at::Tensor & philox_seed,
    const at::Tensor & philox_offset,
    std::optional<double> scale) {
  return std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor>(
          at::empty_like(query),
          at::empty_like(key),
          at::empty_like(value),
          at::empty_like(attn_bias));
}

// This macro does the heavy lifting.
// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend.
// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key.
// Later in this file, we map a custom device to the PrivateUse1 device type,
// which allows user code that puts a tensor on your custom_device to eventually get plumbed
// into the kernels registered here.
//
// This macro registers your kernels to the PyTorch Dispatcher.
// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/.
TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("abs.out", &custom_abs_out);
  m.impl("add.Tensor", &custom_add_Tensor);
  m.impl("empty.memory_format", &custom_empty_symint);
  m.impl("fill_.Scalar", &custom_fill__scalar);
  m.impl("_copy_from", &custom__copy_from);
  m.impl("_copy_from_and_resize", &custom__copy_from_and_resize);
  m.impl("empty_strided", &custom_empty_strided);
  m.impl("set_.source_Storage", &custom_set_source_Storage);
  m.impl("set_.source_Storage_storage_offset", &custom_set_source_Storage_storage_offset);
  m.impl("resize_", &custom_resize_);
  m.impl("as_strided", at::native::as_strided_tensorimpl);
  m.impl("quantize_per_tensor", at::native::quantize_per_tensor);
  m.impl("_fused_sdp_choice", &_fused_sdp_choice_privateuse1);
  m.impl("_scaled_dot_product_fused_attention_overrideable", &custom_scaled_dot_product_fused_attention_overrideable);
  m.impl("_scaled_dot_product_fused_attention_overrideable_backward", &custom_scaled_dot_product_fused_attention_overrideable_backward);
}

void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) {
  at::native::cpu_fallback(op, stack);
}

TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) {
  m.impl("sub.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
  m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
  m.impl("_fused_adamw_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
  m.impl("index.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
  m.impl("triu_indices", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>());
}
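
// The boxed custom_cpu_fallback above lets the listed ops (sub.Tensor,
// _foreach_add.List, _fused_adamw_, index.Tensor, triu_indices), which have no
// PrivateUse1 kernels here, be redirected to their CPU implementations via
// at::native::cpu_fallback.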

// This basic implementation doesn't bother dealing with different device indices
// (e.g. custom_device:0 vs. custom_device:1).
// We could do that by letting the user pass in a device index in our exposed device function.
// Note that if you do that, you'll also need to register a device guard to core.
// See `c10/core/impl/DeviceGuardImplInterface.h:C10_REGISTER_GUARD_IMPL`.
c10::Device get_custom_device() {
  return c10::Device(c10::DeviceType::PrivateUse1, 0);
}

bool custom_add_called() {
  bool called = false;
  if (add_counter > last_saved_value) {
    called = true;
    last_saved_value = add_counter;
  }
  return called;
}

class PrivateGeneratorImpl : public at::CPUGeneratorImpl {
public:
  // Constructors
  PrivateGeneratorImpl(c10::DeviceIndex device_index) {
    device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index);
    key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1);
  }
  ~PrivateGeneratorImpl() override = default;
};

// This is used to register the generator for the custom device.
at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) {
  return at::make_generator<PrivateGeneratorImpl>(device_index);
}

void register_generator_first() {
  REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
}

void register_generator_second() {
  REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1)
}

void set_custom_device_index(c10::DeviceIndex device_index) {
  custom_device_index = device_index;
}

// a global flag used for dummy pin_memory of custom device
bool custom_pinned_flag = false;

struct FooHooksArgs : public at::PrivateUse1HooksArgs {};

struct FooHooksInterface : public at::PrivateUse1HooksInterface {
    FooHooksInterface(FooHooksArgs) {}
    ~FooHooksInterface() override = default;
    const at::Generator& getDefaultGenerator(c10::DeviceIndex device_index) const override {
      static auto device_gen = make_generator_privateuse1(device_index);
      return device_gen;
    }
    // This is a simple implementation: custom_pinned_flag is set to true
    // once tensor.pin_memory() is called, and afterwards tensor.is_pinned()
    // always returns true no matter which tensor it's called on.
    bool isPinnedPtr(const void* data) const override {
      return custom_pinned_flag;
    }
    c10::Allocator* getPinnedMemoryAllocator() const override {
      custom_pinned_flag = true;
      return c10::GetCPUAllocator();
    }
};

TORCH_DECLARE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs);
C10_DEFINE_REGISTRY(PrivateUse1HooksRegistry, FooHooksInterface, FooHooksArgs)
// Use the registry's Create function to get a PrivateUse1HooksInterface pointer from PrivateUse1HooksRegistry.
C10_REGISTER_TYPED_CLASS(PrivateUse1HooksRegistry, "FooHooks", FooHooksInterface)

static at::PrivateUse1HooksInterface* privateuse1_hooks_local = nullptr;
static at::PrivateUse1HooksInterface* get_private_hooks() {
  static c10::once_flag once;
  c10::call_once(once, [] {
    privateuse1_hooks_local = PrivateUse1HooksRegistry()->Create("FooHooks", {}).release();
    if (!privateuse1_hooks_local) {
      privateuse1_hooks_local = new FooHooksInterface(FooHooksArgs{});
    }
  });
  return privateuse1_hooks_local;
}

void register_hook() {
  at::RegisterPrivateUse1HooksInterface(get_private_hooks());
}

bool is_register_hook() {
  return privateuse1_hooks_local != nullptr;
}

const at::Generator& default_generator(c10::DeviceIndex device_index) {
  return at::globalContext().defaultGenerator(at::Device(c10::DeviceType::PrivateUse1, device_index));
}

void fallback_with_undefined_tensor() {
  at::Tensor first = at::empty({2, 3}).to(at::DeviceType::PrivateUse1);
  at::Tensor second = at::Tensor();
  at::Tensor step = at::empty({}).fill_(2).to(at::DeviceType::PrivateUse1);
  at::Tensor grad_scale = at::empty({}).fill_(0.00001).to(at::DeviceType::PrivateUse1);
  at::Tensor found_inf = at::empty({}).fill_(1).to(at::DeviceType::PrivateUse1);
  at::TensorList tensors = {first, first};
  at::TensorList undefined_tensors = {first, second};
  at::TensorList steps = {step, step};
  return at::_fused_adamw_(tensors, tensors, tensors, tensors, undefined_tensors,
                           steps, 0.001, 0.9, 0.999, 1e-2, 1e-8, false, false,
                           grad_scale, found_inf);
}
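
// The call above exercises the boxed CPU fallback with an undefined tensor in
// one of the TensorList arguments (max_exp_avg_sqs, which is unused because
// amsgrad is false).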

struct CustomAutogradFnReturnsSelf : public torch::autograd::Function<CustomAutogradFnReturnsSelf> {

  static at::Tensor forward(torch::autograd::AutogradContext* ctx, at::Tensor self) {
    return self;
  }

  static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) {
    return {grad_output[0] * 0.5};
  }
};

struct CustomAutogradFnAliasing : public torch::autograd::Function<CustomAutogradFnAliasing> {

  static at::Tensor forward(torch::autograd::AutogradContext* ctx, at::Tensor self) {
    return self.view_symint(self.sym_sizes());
  }

  static torch::autograd::variable_list backward(torch::autograd::AutogradContext* ctx, torch::autograd::variable_list grad_output) {
    return {grad_output[0] * 0.5};
  }
};

at::Tensor custom_autograd_fn_returns_self(at::Tensor x) {
  return CustomAutogradFnReturnsSelf::apply(x);
}

at::Tensor custom_autograd_fn_aliasing(at::Tensor x) {
  return CustomAutogradFnAliasing::apply(x);
}
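
// Both wrappers are exposed to Python: custom_autograd_fn_returns_self via the
// pybind module below, and custom_autograd_fn_aliasing through the _test_funcs
// library at the end of this file (callable as torch.ops._test_funcs.custom_autograd_fn_aliasing).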

// Here, we're exposing a custom device object that corresponds to our custom backend.
// We do this using pybind: exposing an "extension_name.custom_device()" function in python
// that's implemented in C++.
// The implementation in this file maps directly to the `PrivateUse1` device type.
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    m.def("custom_device", &get_custom_device, "get custom device object");
    m.def("custom_add_called", &custom_add_called, "check if our custom add function was called");
    m.def("register_generator_first", &register_generator_first, "register generator for custom device (first call)");
    m.def("register_generator_second", &register_generator_second, "register generator for custom device (second call)");
    m.def("set_custom_device_index", &set_custom_device_index, "set custom device index");
    m.def("custom_storage_registry", &custom_storage_registry, "set custom StorageImpl create method");
    m.def("custom_storageImpl_called", &custom_storageImpl_called, "check if our custom StorageImpl create method was called");
    m.def("custom_set_backend_meta", &custom_set_backend_meta, "a fake function that sets tensor BackendMeta");
    m.def("check_backend_meta", &check_backend_meta, "check if BackendMeta was serialized correctly");
    m.def("custom_serialization_registry", &custom_serialization_registry, "register custom serialization function");
    m.def("register_hook", &register_hook, "register_hook for privateuse1");
    m.def("is_register_hook", &is_register_hook, "is_register_hook for privateuse1");
    m.def("default_generator", &default_generator, "default_generator for privateuse1");
    m.def("fallback_with_undefined_tensor", &fallback_with_undefined_tensor, "fallback_with_undefined_tensor for privateuse1");

    // Co-opting this file to more easily test torch.compile'ing of custom autograd functions in C++
    m.def("custom_autograd_fn_returns_self", &custom_autograd_fn_returns_self);
}
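
// From Python, this extension is typically built with torch.utils.cpp_extension
// and paired with torch.utils.rename_privateuse1_backend() so that the kernels
// registered above back a named custom device (an assumption about the
// accompanying test script, not something enforced in this file).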

TORCH_LIBRARY(_test_funcs, m) {
  m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)");
}

TORCH_LIBRARY_IMPL(_test_funcs, AutogradCPU, m) {
  m.impl("custom_autograd_fn_aliasing", &custom_autograd_fn_aliasing);
}