1 #include <array>
2
3 #include <ATen/Functions.h>
4 #include <ATen/Utils.h>
5 #include <c10/core/Allocator.h>
6
7 namespace at {
8
make_tensor()9 Tensor TensorMaker::make_tensor() {
10 AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
11 tracer::impl::NoTracerDispatchMode tracer_guard{};
12
13 check_size_nonnegative(sizes_);
14
15 TORCH_CHECK_VALUE(
16 !deleter_ || !ctx_,
17 "The deleter and context arguments are mutually exclusive.");
18
19 if (device_ == std::nullopt) {
20 device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type());
21 }
22
23 if (opts_.device().has_index()) {
24 // clang-format off
25 TORCH_CHECK_VALUE(
26 opts_.device() == *device_,
27 "Specified device ", opts_.device(), " does not match device of data ", *device_);
28 // clang-format on
29 }
30
31 std::size_t size_bytes = computeStorageSize();
32
33 DataPtr data_ptr{};
34 if (deleter_) {
35 data_ptr = makeDataPtrFromDeleter();
36 } else {
37 data_ptr = makeDataPtrFromContext();
38 }
39
40 TORCH_CHECK(!resizeable_ || allocator_ != nullptr, "Must specify an allocator with allocator() if you want to use resizeable_storage()");
41 Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr), /*allocator=*/allocator_, /*resizable=*/resizeable_};
42
43 Tensor tensor = detail::make_tensor<TensorImpl>(
44 std::move(storage), opts_.computeDispatchKey(), opts_.dtype());
45
46 TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
47 if (strides_) {
48 tensor_impl->set_sizes_and_strides(sizes_, *strides_);
49 } else {
50 tensor_impl->set_sizes_contiguous(sizes_);
51 }
52 if (storage_offset_) {
53 tensor_impl->set_storage_offset(*storage_offset_);
54 }
55
56 return tensor;
57 }
58
computeStorageSize() const59 std::size_t TensorMaker::computeStorageSize() const noexcept {
60 std::size_t itemsize = opts_.dtype().itemsize();
61
62 if (strides_) {
63 auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize);
64 if (storage_offset_) {
65 storage_size += storage_offset_.value();
66 }
67 return storage_size;
68 }
69
70 std::size_t size = 1;
71 for (std::int64_t s : sizes_) {
72 size *= static_cast<std::size_t>(s);
73 }
74 auto storage_size = size * itemsize;
75 if (storage_offset_) {
76 storage_size += storage_offset_.value();
77 }
78 return storage_size;
79 }
80
makeDataPtrFromDeleter()81 inline DataPtr TensorMaker::makeDataPtrFromDeleter() noexcept {
82 return InefficientStdFunctionContext::makeDataPtr(data_, std::move(deleter_), *device_);
83 }
84
makeDataPtrFromContext()85 inline DataPtr TensorMaker::makeDataPtrFromContext() noexcept {
86 return DataPtr{data_, ctx_.release(), ctx_.get_deleter(), *device_};
87 }
88
makeTempSizes() const89 IntArrayRef TensorMaker::makeTempSizes() const noexcept {
90 static std::int64_t zeros[5] = {0, 0, 0, 0, 0};
91 if (opts_.has_memory_format()) {
92 MemoryFormat format = *opts_.memory_format_opt();
93 if (format == MemoryFormat::ChannelsLast) {
94 return IntArrayRef(zeros, 4);
95 }
96 if (format == MemoryFormat::ChannelsLast3d) {
97 return IntArrayRef(zeros, 5);
98 }
99 }
100 return IntArrayRef(zeros, 1);
101 }
102
103 } // namespace at
104