1 #include <c10/core/StorageImpl.h>
2 #include <c10/util/flat_hash_map.h>
3
4 namespace c10 {
5
6 // The array to save function pointer for custom storageImpl create.
7 C10_API std::array<StorageImplCreateHelper, at::COMPILE_TIME_MAX_DEVICE_TYPES>
8 StorageImplCreate;
9
10 // A allowlist of device type, currently available is PrivateUse1
GetBackendMetaAllowlist()11 inline ska::flat_hash_set<c10::DeviceType>& GetBackendMetaAllowlist() {
12 static ska::flat_hash_set<c10::DeviceType> DeviceTypeAllowList{
13 DeviceType::PrivateUse1};
14 return DeviceTypeAllowList;
15 }
16
throwNullDataPtrError()17 void throwNullDataPtrError() {
18 TORCH_CHECK(
19 false,
20 "Cannot access data pointer of Tensor (e.g. FakeTensor, FunctionalTensor). "
21 "If you're using torch.compile/export/fx, it is likely that we are erroneously "
22 "tracing into a custom kernel. To fix this, please wrap the custom kernel into "
23 "an opaque custom op. Please see the following for details: "
24 "https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html");
25 }
26
27 // NOTE: [FakeTensor.data_ptr deprecation]
28 // Today:
29 // - FakeTensor.data_ptr errors out in torch.compile.
30 // - FakeTensor.data_ptr raises the following deprecation warning otherwise.
31 // - the following deprecation warning is only for FakeTensor (for now).
32 // In the future we can consider extending to more wrapper Tensor subclasses.
warnDeprecatedDataPtr()33 void warnDeprecatedDataPtr() {
34 TORCH_WARN_ONCE(
35 "Accessing the data pointer of FakeTensor is deprecated and will error in "
36 "PyTorch 2.5. This is almost definitely a bug in your code and will "
37 "cause undefined behavior with subsystems like torch.compile. "
38 "Please wrap calls to tensor.data_ptr() in an opaque custom op; "
39 "If all else fails, you can guard accesses to tensor.data_ptr() on "
40 "isinstance(tensor, FakeTensor).")
41 }
42
SetStorageImplCreate(DeviceType t,StorageImplCreateHelper fptr)43 void SetStorageImplCreate(DeviceType t, StorageImplCreateHelper fptr) {
44 // Allowlist verification.
45 // Only if the devicetype is in the allowlist,
46 // we allow the extension to be registered for storageImpl create.
47 const auto& DeviceTypeAllowlist = GetBackendMetaAllowlist();
48 TORCH_CHECK(
49 DeviceTypeAllowlist.find(t) != DeviceTypeAllowlist.end(),
50 "It is only allowed to register the storageImpl create method ",
51 "for PrivateUse1. ",
52 "If you have related storageImpl requirements, ",
53 "please expand the allowlist");
54 // Register function pointer.
55 int device_type = static_cast<int>(t);
56 TORCH_CHECK(
57 StorageImplCreate[device_type] == nullptr,
58 "The StorageImplCreate function pointer for ",
59 t,
60 " has been registered.");
61 StorageImplCreate[device_type] = fptr;
62 }
63
GetStorageImplCreate(DeviceType t)64 StorageImplCreateHelper GetStorageImplCreate(DeviceType t) {
65 int device_type = static_cast<int>(t);
66 return StorageImplCreate[device_type];
67 }
68
make_storage_impl(c10::StorageImpl::use_byte_size_t use_byte_size,c10::SymInt size_bytes,c10::DataPtr data_ptr,c10::Allocator * allocator,bool resizable,std::optional<at::Device> device_opt)69 c10::intrusive_ptr<c10::StorageImpl> make_storage_impl(
70 c10::StorageImpl::use_byte_size_t use_byte_size,
71 c10::SymInt size_bytes,
72 c10::DataPtr data_ptr,
73 c10::Allocator* allocator,
74 bool resizable,
75 std::optional<at::Device> device_opt) {
76 // This will be non-nullptr only when there is a custom StorageImpl
77 // constructor for the given device
78 c10::StorageImplCreateHelper fptr = nullptr;
79 if (device_opt.has_value()) {
80 // We only need to check this here as this is the only case where we can
81 // have a device that is not CPU (and thus for which the StorageImpl
82 // constructor can be overwritten).
83 fptr = c10::GetStorageImplCreate(device_opt.value().type());
84 }
85
86 if (fptr != nullptr) {
87 return fptr(
88 use_byte_size,
89 std::move(size_bytes),
90 std::move(data_ptr),
91 allocator,
92 resizable);
93 }
94
95 // Create a c10::StorageImpl object.
96 if (data_ptr != nullptr) {
97 return c10::make_intrusive<c10::StorageImpl>(
98 use_byte_size,
99 std::move(size_bytes),
100 std::move(data_ptr),
101 allocator,
102 resizable);
103 }
104 return c10::make_intrusive<c10::StorageImpl>(
105 use_byte_size, std::move(size_bytes), allocator, resizable);
106 }
107
108 } // namespace c10
109