xref: /aosp_15_r20/external/pytorch/tools/autograd/templates/variable_factories.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 // ${generated_comment}
4 
5 #include <ATen/core/Tensor.h>
6 #include <ATen/TracerMode.h>
7 #include <ATen/core/grad_mode.h>
8 #include <c10/util/ArrayRef.h>
9 #include <c10/core/MemoryFormat.h>
10 #include <torch/csrc/api/include/torch/detail/TensorDataContainer.h>
11 #include <torch/csrc/autograd/variable.h>
12 
13 #ifndef AT_PER_OPERATOR_HEADERS
14 #include <ATen/Functions.h>
15 #else
16 #include <ATen/ops/from_blob.h>
17 $ops_headers
18 #endif
19 
20 #include <functional>
21 #include <initializer_list>
22 #include <utility>
23 
24 namespace torch {
25 
26 /// NOTE: Currently `torch::tensor(...)` doesn't support mixed data types
27 /// (i.e. `torch::tensor({{bool, 2.0}})` doesn't work). We might be able to
28 /// support it in the future by iterating over all sub-lists to find
29 /// the largest data type that can represent all of the elements, or by using
30 /// variadic templates.
31 ///
32 /// NOTE: C++ `torch::tensor` with a floating-point type or an `at::ArrayRef` / `std::vector` /
33 /// (nested) braced-init-list of floating-point types always produces a tensor of dtype
34 /// `torch::get_default_dtype()`, matching Python `torch.tensor` behavior.
35 ///
36 /// NOTE: C++ `torch::tensor` with an integer type or an `at::ArrayRef` / `std::vector` /
37 /// (nested) braced-init-list of integer types always produces a tensor of dtype `at::kLong`
38 /// (aka. int64_t), matching Python `torch.tensor` behavior.
39 ///
40 /// NOTE: The following dtypes are not supported by `torch::tensor` currently:
41 /// - `unsigned int`
42 /// - `unsigned long int`
43 /// - `unsigned long long int`
44 /// - `long long int`
45 inline at::Tensor tensor(detail::TensorDataContainer tensor_data_container, const at::TensorOptions& options = {}) {
46   return autograd::make_variable(
47     // note: we remove the requires_grad setting from the TensorOptions because
48     // it is ignored anyways (and we actually have an assertion that it isn't set
49     // which would fail otherwise). We handle requires_grad explicitly here
50     // instead of passing it through to the kernel.
51     tensor_data_container.convert_to_tensor(options.requires_grad(::std::nullopt)),
52     options.requires_grad());
53 }
54 
/// A generic deleter function.
using Deleter = std::function<void(void*)>;
// Re-export `at::MemoryFormat` into the `torch` namespace for convenience.
using at::MemoryFormat;
58 
59 /// Exposes the given `data` as a `Tensor` without taking ownership of the
60 /// original data. `sizes` should specify the shape of the tensor, `strides` the
61 /// stride in each dimension. The `deleter` function (a
62 /// `std::function<void(void*)>`) will be called on the `data` when the Tensor
63 /// data would normally be deallocated. The `TensorOptions` specify additional
64 /// configuration options for the returned tensor, such as what type to
65 /// interpret the `data` as.
66 inline at::Tensor from_blob(
67     void* data,
68     at::IntArrayRef sizes,
69     at::IntArrayRef strides,
70     const Deleter& deleter,
71     const at::TensorOptions& options = at::TensorOptions()) {
72   at::Tensor tensor = ([&]() {
73     at::AutoDispatchBelowAutograd guard;  // TODO: remove
74     at::tracer::impl::NoTracerDispatchMode tracer_guard;
75     return at::from_blob(data, sizes, strides, deleter, options.requires_grad(::std::nullopt));
76   })();
77   return autograd::make_variable(tensor, options.requires_grad());
78 }
79 
80 /// Exposes the given `data` as a `Tensor` without taking ownership of the
81 /// original data. `sizes` should specify the shape of the tensor, `strides` the
82 /// stride in each dimension. The `TensorOptions`
83 /// specify additional configuration options for the returned tensor, such as
84 /// what type to interpret the `data` as.
85 inline at::Tensor from_blob(
86     void* data,
87     at::IntArrayRef sizes,
88     at::IntArrayRef strides,
89     const at::TensorOptions& options = at::TensorOptions()) {
90   at::Tensor tensor = ([&]() {
91     at::AutoDispatchBelowAutograd guard;  // TODO: remove
92     at::tracer::impl::NoTracerDispatchMode tracer_guard;
93     return at::from_blob(data, sizes, strides, options.requires_grad(::std::nullopt));
94   })();
95   return autograd::make_variable(tensor, options.requires_grad());
96 }
97 
98 /// Exposes the given `data` as a `Tensor` without taking ownership of the
99 /// original data. `sizes` should specify the shape of the tensor. The `deleter`
100 /// (a `std::function<void(void*)>`) function will be called on the `data` when
101 /// the Tensor data would normally be deallocated. The `TensorOptions` specify
102 /// additional configuration options for the returned tensor, such as what type
103 /// to interpret the `data` as.
104 inline at::Tensor from_blob(
105     void* data,
106     at::IntArrayRef sizes,
107     const Deleter& deleter,
108     const at::TensorOptions& options = at::TensorOptions()) {
109   at::Tensor tensor = ([&]() {
110     at::AutoDispatchBelowAutograd guard;  // TODO: remove
111     at::tracer::impl::NoTracerDispatchMode tracer_guard;
112     return at::from_blob(data, sizes, deleter, options.requires_grad(::std::nullopt));
113   })();
114   return autograd::make_variable(tensor, options.requires_grad());
115 }
116 
117 /// Exposes the given `data` as a `Tensor` without taking ownership of the
118 /// original data. `sizes` should specify the shape of the tensor. The
119 /// `TensorOptions` specify additional configuration options for the returned
120 /// tensor, such as what type to interpret the `data` as.
121 inline at::Tensor from_blob(
122     void* data,
123     at::IntArrayRef sizes,
124     const at::TensorOptions& options = at::TensorOptions()) {
125   at::Tensor tensor = ([&]() {
126     at::AutoDispatchBelowAutograd guard;  // TODO: remove
127     at::tracer::impl::NoTracerDispatchMode tracer_guard;
128     return at::from_blob(data, sizes, options.requires_grad(::std::nullopt));
129   })();
130   return autograd::make_variable(tensor, options.requires_grad());
131 }
132 
133 ${function_definitions}
134 
135 } // namespace torch
136