xref: /aosp_15_r20/external/pytorch/aten/src/ATen/templates/Functions.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <array>
2 
3 #include <ATen/Functions.h>
4 #include <ATen/Utils.h>
5 #include <c10/core/Allocator.h>
6 
7 namespace at {
8 
make_tensor()9 Tensor TensorMaker::make_tensor() {
10    AutoDispatchBelowADInplaceOrView guard{}; // TODO: Remove.
11    tracer::impl::NoTracerDispatchMode tracer_guard{};
12 
13    check_size_nonnegative(sizes_);
14 
15    TORCH_CHECK_VALUE(
16        !deleter_ || !ctx_,
17        "The deleter and context arguments are mutually exclusive.");
18 
19    if (device_ == std::nullopt) {
20      device_ = globalContext().getDeviceFromPtr(data_, opts_.device().type());
21    }
22 
23    if (opts_.device().has_index()) {
24      // clang-format off
25      TORCH_CHECK_VALUE(
26          opts_.device() == *device_,
27          "Specified device ", opts_.device(), " does not match device of data ", *device_);
28      // clang-format on
29    }
30 
31    std::size_t size_bytes = computeStorageSize();
32 
33    DataPtr data_ptr{};
34    if (deleter_) {
35      data_ptr = makeDataPtrFromDeleter();
36    } else {
37      data_ptr = makeDataPtrFromContext();
38    }
39 
40    TORCH_CHECK(!resizeable_ || allocator_ != nullptr, "Must specify an allocator with allocator() if you want to use resizeable_storage()");
41    Storage storage{Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr), /*allocator=*/allocator_, /*resizable=*/resizeable_};
42 
43    Tensor tensor = detail::make_tensor<TensorImpl>(
44        std::move(storage), opts_.computeDispatchKey(), opts_.dtype());
45 
46   TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl();
47   if (strides_) {
48     tensor_impl->set_sizes_and_strides(sizes_, *strides_);
49   } else {
50     tensor_impl->set_sizes_contiguous(sizes_);
51   }
52   if (storage_offset_) {
53     tensor_impl->set_storage_offset(*storage_offset_);
54   }
55 
56    return tensor;
57  }
58 
computeStorageSize() const59  std::size_t TensorMaker::computeStorageSize() const noexcept {
60    std::size_t itemsize = opts_.dtype().itemsize();
61 
62    if (strides_) {
63      auto storage_size = detail::computeStorageNbytes(sizes_, *strides_, itemsize);
64      if (storage_offset_) {
65        storage_size += storage_offset_.value();
66      }
67      return storage_size;
68    }
69 
70    std::size_t size = 1;
71    for (std::int64_t s : sizes_) {
72      size *= static_cast<std::size_t>(s);
73    }
74    auto storage_size = size * itemsize;
75    if (storage_offset_) {
76      storage_size += storage_offset_.value();
77    }
78    return storage_size;
79  }
80 
makeDataPtrFromDeleter()81  inline DataPtr TensorMaker::makeDataPtrFromDeleter() noexcept {
82    return InefficientStdFunctionContext::makeDataPtr(data_, std::move(deleter_), *device_);
83  }
84 
makeDataPtrFromContext()85  inline DataPtr TensorMaker::makeDataPtrFromContext() noexcept {
86    return DataPtr{data_, ctx_.release(), ctx_.get_deleter(), *device_};
87  }
88 
makeTempSizes() const89  IntArrayRef TensorMaker::makeTempSizes() const noexcept {
90    static std::int64_t zeros[5] = {0, 0, 0, 0, 0};
91    if (opts_.has_memory_format()) {
92      MemoryFormat format = *opts_.memory_format_opt();
93      if (format == MemoryFormat::ChannelsLast) {
94        return IntArrayRef(zeros, 4);
95      }
96      if (format == MemoryFormat::ChannelsLast3d) {
97        return IntArrayRef(zeros, 5);
98      }
99    }
100    return IntArrayRef(zeros, 1);
101  }
102 
103 } // namespace at
104