#pragma once

#include <torch/csrc/python_headers.h>
#include <torch/csrc/utils/python_arg_parser.h>

#include <ATen/core/Tensor.h>

namespace torch::utils {

// NOTE: [torch.tensor, lift_fresh, and device movement]
//
// The `only_lift_cpu_tensors` flag controls what happens on
// torch.tensor([1, 2, 3], device="cuda") (or any other non-CPU device).
//
// If false (default):
// 1. the data gets moved into a CPU Tensor
// 2. then, it gets moved to cuda (via .to)
// 3. finally, we call lift_fresh() on it.
// Steps 1 and 2 happen with all modes disabled.
//
// If true:
// 1. the data gets moved into a CPU Tensor (with the correct dtype)
// 2. we call lift_fresh() on it
// 3. finally, we move it to cuda (via .to)
// Step 1 happens with all modes disabled.
//
// `only_lift_cpu_tensors=true` is useful for preventing CUDA initialization
// under FakeTensorMode, because it avoids moving concrete data to CUDA.
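//
// A rough sketch of the two orderings, for illustration only (the real logic
// lives in tensor_new.cpp and also handles dtype inference, pin_memory, and
// mode re-enablement; `new_cpu_tensor_from_data` is a made-up placeholder
// name):
//
//   at::Tensor t = new_cpu_tensor_from_data(data);  // modes disabled
//   if (!only_lift_cpu_tensors()) {
//     t = t.to(device);       // move to cuda with modes still disabled
//     t = at::lift_fresh(t);  // lift the cuda tensor under the active modes
//   } else {
//     t = at::lift_fresh(t);  // lift while still a CPU tensor
//     t = t.to(device);       // .to() now runs under the active modes, so
//                             // FakeTensorMode never touches a real CUDA device
//   }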
TORCH_API bool only_lift_cpu_tensors();
TORCH_API void set_only_lift_cpu_tensors(bool value);

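// Constructors for torch.Tensor(...), the legacy typed constructors
// (e.g. torch.FloatTensor(...)), and Tensor.new(...).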
at::Tensor base_tensor_ctor(PyObject* args, PyObject* kwargs);
at::Tensor legacy_tensor_ctor(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PyObject* args,
    PyObject* kwargs);
at::Tensor legacy_tensor_new(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PyObject* args,
    PyObject* kwargs);
at::Tensor indexing_tensor_from_data(
    c10::TensorOptions options,
    at::ScalarType scalar_type,
    std::optional<at::Device> device,
    PyObject* data);
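
// Constructors and argument validators for the sparse layouts (COO, CSR, CSC,
// BSR, and BSC), backing e.g. torch.sparse_coo_tensor and
// torch.sparse_csr_tensor.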
at::Tensor sparse_coo_tensor_ctor(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PythonArgs& r);
void _validate_sparse_coo_tensor_args(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PyObject* args,
    PyObject* kwargs);

at::Tensor sparse_compressed_tensor_ctor(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PythonArgs& r);
at::Tensor sparse_csr_tensor_ctor(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PythonArgs& r);
at::Tensor sparse_csc_tensor_ctor(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PythonArgs& r);
at::Tensor sparse_bsr_tensor_ctor(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PythonArgs& r);
at::Tensor sparse_bsc_tensor_ctor(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PythonArgs& r);

void _validate_sparse_compressed_tensor_args(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PyObject* args,
    PyObject* kwargs);
void _validate_sparse_csr_tensor_args(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PyObject* args,
    PyObject* kwargs);
void _validate_sparse_csc_tensor_args(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PyObject* args,
    PyObject* kwargs);
void _validate_sparse_bsr_tensor_args(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PyObject* args,
    PyObject* kwargs);
void _validate_sparse_bsc_tensor_args(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PyObject* args,
    PyObject* kwargs);

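// Constructors backing torch.tensor, torch.as_tensor, the Tensor.new_tensor /
// Tensor.new_ones methods, torch.frombuffer, torch.from_dlpack, and
// torch.asarray.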
at::Tensor tensor_ctor(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PythonArgs& r);
at::Tensor as_tensor(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PythonArgs& r);
at::Tensor new_tensor(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PyObject* args,
    PyObject* kwargs);
at::Tensor new_ones(
    c10::DispatchKey dispatch_key,
    at::ScalarType scalar_type,
    PyObject* args,
    PyObject* kwargs);
at::Tensor tensor_frombuffer(
    PyObject* buffer,
    at::ScalarType dtype,
    int64_t count,
    int64_t offset,
    bool requires_grad);
at::Tensor tensor_fromDLPack(PyObject* data);
at::Tensor asarray(
    PyObject* obj,
    std::optional<c10::ScalarType> dtype,
    std::optional<c10::Device> device,
    std::optional<bool> copy,
    bool requires_grad);
} // namespace torch::utils