xref: /aosp_15_r20/external/pytorch/torch/csrc/utils/tensor_new.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/python_headers.h>
2 #include <torch/csrc/utils/tensor_new.h>
3 
4 #include <pybind11/pybind11.h>
5 #include <torch/csrc/DynamicTypes.h>
6 #include <torch/csrc/Exceptions.h>
7 #include <torch/csrc/Size.h>
8 #include <torch/csrc/autograd/generated/variable_factories.h>
9 #include <torch/csrc/autograd/variable.h>
10 #include <torch/csrc/utils/device_lazy_init.h>
11 #include <torch/csrc/utils/numpy_stub.h>
12 #include <torch/csrc/utils/pybind.h>
13 #include <torch/csrc/utils/python_arg_parser.h>
14 #include <torch/csrc/utils/python_numbers.h>
15 #include <torch/csrc/utils/python_scalars.h>
16 #include <torch/csrc/utils/python_strings.h>
17 #include <torch/csrc/utils/tensor_numpy.h>
18 
19 #include <ATen/ATen.h>
20 #include <ATen/DLConvertor.h>
21 #include <ATen/InitialTensorOptions.h>
22 #include <ATen/NamedTensorUtils.h>
23 #include <ATen/NativeFunctions.h>
24 #include <ATen/SparseCsrTensorUtils.h>
25 #include <ATen/TracerMode.h>
26 #include <ATen/dlpack.h>
27 #include <c10/core/Backend.h>
28 #include <c10/core/DispatchKeySet.h>
29 #include <c10/core/Layout.h>
30 #include <c10/util/Exception.h>
31 #include <c10/util/irange.h>
32 #include <optional>
33 
34 #include <stdexcept>
35 #include <vector>
36 
37 using at::Device;
38 using at::IntArrayRef;
39 using at::kInt;
40 using at::kLong;
41 using at::ScalarType;
42 using at::Storage;
43 using at::Tensor;
44 using at::TensorOptions;
45 using std::optional;
46 
47 namespace torch::utils {
48 namespace {
49 const int MAX_DIMS = 128;
50 
51 thread_local bool kOnlyLiftCPUTensors = false;
52 
53 TensorOptions build_options(
54     c10::TensorOptions options,
55     at::ScalarType scalar_type,
56     const std::optional<Device>& device = std::nullopt) {
57   options = options.dtype(scalar_type);
58   if (device.has_value()) {
59     return options.device(device);
60   }
61   return options;
62 }
63 
64 // NB: It appears there is some consistency invariant between options and
65 // device, where if device is non-empty, its type must be consistent with the
66 // device type in options.
67 // TODO: Refactor this so we just pass everything in via options
68 
69 Tensor new_with_sizes(
70     c10::TensorOptions options,
71     at::ScalarType scalar_type,
72     const std::optional<Device>& device,
73     c10::SymIntArrayRef sizes) {
74   maybe_initialize_device(options.device());
75   pybind11::gil_scoped_release no_gil;
76   return at::empty_symint(sizes, build_options(options, scalar_type, device));
77 }
78 
79 Tensor new_with_storage(
80     c10::TensorOptions options,
81     at::ScalarType scalar_type,
82     Storage storage) {
83   auto tensor = at::empty({}, build_options(options, scalar_type));
84   tensor.set_(std::move(storage));
85   return tensor;
86 }
87 
88 std::vector<int64_t> compute_sizes(PyObject* seq, ScalarType scalar_type) {
89   bool is_storage = isStorage(seq);
90   std::vector<int64_t> sizes;
91   // Note that after the first iteration, obj is the only thing that keeps
92   // the seq raw pointer alive.
93   THPObjectPtr obj;
94   while (PySequence_Check(seq)) {
95     auto length = PySequence_Length(seq);
96     if (length < 0)
97       throw python_error();
98     if (is_storage) {
99       length /= static_cast<int64_t>(elementSize(scalar_type));
100     }
101     sizes.push_back(length);
102     TORCH_CHECK_VALUE(
103         sizes.size() <= MAX_DIMS,
104         "too many dimensions '",
105         Py_TYPE(seq)->tp_name,
106         "'");
107     if (length == 0)
108       break;
109     PyObject* new_obj = PySequence_GetItem(seq, 0);
110     // This line uses seq, so we must NOT overwrite obj before this line
111     TORCH_CHECK_VALUE(
112         new_obj,
113         "could not determine the shape of object type '",
114         Py_TYPE(seq)->tp_name,
115         "'");
116     obj = THPObjectPtr(new_obj);
117     seq = obj.get();
118   }
119 
120   return sizes;
121 }
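// Illustrative sketch (not from the upstream source): compute_sizes inspects
// only element 0 at each nesting level, so e.g. [[1, 2, 3], [4, 5, 6]] yields
// sizes {2, 3}, while a storage argument has its length divided by the element
// size of scalar_type. Ragged nesting is not detected here; mismatched inner
// lengths are reported later by recursive_store's per-dimension length check.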
122 
123 ScalarType infer_scalar_type(PyObject* obj) {
124   if (torch::is_symint(obj)) {
125     return ScalarType::Long;
126   }
127   if (torch::is_symfloat(obj)) {
128     return torch::tensors::get_default_scalar_type();
129   }
130 #ifdef USE_NUMPY
131   if (is_numpy_available()) {
132     if (PyArray_Check(obj)) {
133       return numpy_dtype_to_aten(PyArray_TYPE((PyArrayObject*)obj));
134     }
135     if (PyArray_CheckScalar(obj)) {
136       THPObjectPtr arr(PyArray_FromScalar(obj, nullptr));
137       return numpy_dtype_to_aten(PyArray_TYPE((PyArrayObject*)arr.get()));
138     }
139   }
140 #endif
141   if (PyFloat_Check(obj)) {
142     // this is always guaranteed to be a floating-point type, and makes it more
143     // convenient to write e.g. torch.tensor(0.) than torch.tensor(0.,
144     // dtype=torch.Tensor.dtype).
145     return torch::tensors::get_default_scalar_type();
146   }
147   if (THPUtils_checkLong(obj)) {
148     return ScalarType::Long;
149   }
150   if (PyBool_Check(obj)) {
151     return ScalarType::Bool;
152   }
153   if (PyComplex_Check(obj)) {
154     switch (torch::tensors::get_default_scalar_type()) {
155       case ScalarType::Float:
156         return ScalarType::ComplexFloat;
157       case ScalarType::Double:
158         return ScalarType::ComplexDouble;
159       case ScalarType::Half:
160         return ScalarType::ComplexHalf;
161       default:
162         TORCH_CHECK(false, "invalid default scalar type for complex");
163     }
164   }
165   if (THPVariable_Check(obj)) {
166     const auto& var = THPVariable_Unpack(obj);
167     return var.scalar_type();
168   }
169   TORCH_CHECK_TYPE(
170       !THPUtils_checkString(obj),
171       "new(): invalid data type '",
172       Py_TYPE(obj)->tp_name,
173       "'");
174   if (PySequence_Check(obj)) {
175     std::optional<ScalarType> scalarType;
176     auto length = PySequence_Length(obj);
177     if (length < 0)
178       throw python_error();
179     // match NumPy semantics, except use default tensor type instead of double.
180     if (length == 0)
181       return torch::tensors::get_default_scalar_type();
182     for (const auto i : c10::irange(length)) {
183       THPObjectPtr handle(PySequence_GetItem(obj, i));
184       if (!handle)
185         throw python_error();
186       auto cur_item = handle.get();
187       TORCH_CHECK_TYPE(
188           cur_item != obj, "new(): self-referential lists are incompatible");
189       ScalarType item_scalarType = infer_scalar_type(cur_item);
190       scalarType = (scalarType) ? at::promoteTypes(*scalarType, item_scalarType)
191                                 : item_scalarType;
192       if (scalarType == ScalarType::ComplexDouble) {
193         // this won't change (unless we hit undefined, but that will fail
194         // later).
195         return *scalarType;
196       }
197     }
198     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
199     return *scalarType;
200   }
201   AT_ERROR("Could not infer dtype of ", Py_TYPE(obj)->tp_name);
202 }
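// Illustrative sketch (not from the upstream source), assuming the default
// scalar type is Float:
//   3.14      -> Float (the default scalar type)
//   7         -> Long
//   True      -> Bool
//   1 + 2j    -> ComplexFloat (complex counterpart of the default)
//   [1, 2.5]  -> promoteTypes(Long, Float) == Float
//   []        -> Float (default scalar type)
//   a Tensor  -> that tensor's scalar_type()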
203 
204 void recursive_store(
205     char* data,
206     IntArrayRef sizes,
207     IntArrayRef strides,
208     int64_t dim,
209     ScalarType scalarType,
210     size_t elementSize,
211     PyObject* obj) {
212   TORCH_INTERNAL_ASSERT_DEBUG_ONLY(data != nullptr);
213 
214   int64_t ndim = static_cast<int64_t>(sizes.size());
215   bool is_symfloat = torch::is_symfloat(obj);
216   bool is_symint = torch::is_symint(obj);
217   if (dim == ndim) {
218     if (is_symfloat) {
219       auto new_obj = py::reinterpret_borrow<py::object>(obj);
220       auto val = new_obj.cast<c10::SymFloat>();
221       const double double_val = val.guard_float(__FILE__, __LINE__);
222       switch (elementSize) {
223         case 8:
224           *reinterpret_cast<double*>(data) = double_val;
225           break;
226         case 4:
227           *reinterpret_cast<float*>(data) = static_cast<float>(double_val);
228           break;
229       }
230       return;
231     }
232     if (is_symint) {
233       auto new_obj = py::reinterpret_borrow<py::object>(obj);
234       auto val = new_obj.cast<c10::SymInt>();
235       const auto int_val = val.guard_int(__FILE__, __LINE__);
236       switch (elementSize) {
237         case 8:
238           *reinterpret_cast<int64_t*>(data) = int_val;
239           break;
240         case 4:
241           *reinterpret_cast<int32_t*>(data) = static_cast<int32_t>(int_val);
242           break;
243         case 2:
244           *reinterpret_cast<int16_t*>(data) = static_cast<int16_t>(int_val);
245           break;
246         case 1:
247           *reinterpret_cast<int8_t*>(data) = static_cast<int8_t>(int_val);
248           break;
249         default:
250           TORCH_CHECK(false, "Unexpected elementSize ", elementSize);
251       }
252       return;
253     }
254     torch::utils::store_scalar(data, scalarType, obj);
255     return;
256   }
257 
258   auto n = sizes[dim];
259   auto seq = THPObjectPtr(PySequence_Fast(obj, "not a sequence"));
260   if (!seq)
261     throw python_error();
262   // NOLINTNEXTLINE(bugprone-branch-clone)
263   auto seq_size = PySequence_Fast_GET_SIZE(seq.get());
264   TORCH_CHECK_VALUE(
265       seq_size == n,
266       "expected sequence of length ",
267       n,
268       " at dim ",
269       dim,
270       " (got ",
271       seq_size,
272       ")");
273 
274   PyObject** items = PySequence_Fast_ITEMS(seq.get());
275   for (const auto i : c10::irange(n)) {
276 #ifdef USE_NUMPY
277     if (is_numpy_available() && PyArray_Check(items[i])) {
278       TORCH_WARN_ONCE(
279           "Creating a tensor from a list of numpy.ndarrays is extremely slow. "
280           "Please consider converting the list to a single numpy.ndarray with "
281           "numpy.array() before converting to a tensor.");
282     }
283 #endif
284     recursive_store(
285         data, sizes, strides, dim + 1, scalarType, elementSize, items[i]);
286     data += strides[dim] * elementSize;
287   }
288 }
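// Illustrative sketch (not from the upstream source): for obj = [[1, 2], [3, 4]]
// written into a contiguous int64 tensor with sizes {2, 2}, strides {2, 1} and
// elementSize 8, the recursion checks each sequence length against sizes[dim],
// stores scalars once dim == ndim, and advances `data` by strides[dim] *
// elementSize after each child, filling the buffer in row-major order: 1 2 3 4.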
289 
290 Tensor internal_new_from_data(
291     c10::TensorOptions options,
292     at::ScalarType scalar_type,
293     std::optional<Device> device_opt,
294     PyObject* data,
295     bool copy_variables,
296     bool copy_numpy,
297     bool type_inference,
298     bool pin_memory = false) {
299   TORCH_CHECK_TYPE(
300       !THPUtils_checkString(data),
301       "new(): invalid data type '",
302       Py_TYPE(data)->tp_name,
303       "'");
304 
305   if (THPVariable_Check(data)) {
306     TORCH_CHECK(!pin_memory, "Can't pin tensor constructed from a variable");
307     // TODO: use MaybeOwned
308     auto var = THPVariable_Unpack(data);
309     if (copy_variables) {
310       var = var.detach();
311     }
312     // infer the scalar type and device type; it's not expected to infer the
313     // layout since these constructors are defined per-layout-type (e.g. tensor
314     // vs sparse_coo_tensor).
315     const auto& inferred_scalar_type =
316         type_inference ? var.scalar_type() : scalar_type;
317     auto device = device_opt.has_value() ? *device_opt : var.device();
318     pybind11::gil_scoped_release no_gil;
319     maybe_initialize_device(device);
320     return var.to(
321         device,
322         inferred_scalar_type,
323         /*non_blocking=*/false,
324         /*copy=*/copy_variables);
325   }
326 
327 #ifdef USE_NUMPY
328   if (PyObject_HasAttrString(data, "__cuda_array_interface__")) {
329     TORCH_CHECK(
330         !pin_memory,
331         "Can't pin tensor constructed from __cuda_array_interface__");
332     auto tensor = tensor_from_cuda_array_interface(data);
333     const auto& inferred_scalar_type =
334         type_inference ? tensor.scalar_type() : scalar_type;
335 
336     // Device preference is:
337     //  - explicitly user specified device in `device_opt`
338     //      - either by setting device='...'
339     //      - or setting torch.set_default_device(...)
340     //  - device of already constructed tensor
341     // This prevents an unnecessary device -> host copy when the tensor is
342     // already on the device, while respecting a default device and allows the
343     // user to overwrite the behavior explicitly.
344     at::Device device = device_opt.has_value() ? *device_opt : tensor.device();
345 
346     pybind11::gil_scoped_release no_gil;
347     maybe_initialize_device(device);
348     return tensor.to(
349         device,
350         inferred_scalar_type,
351         /*non_blocking=*/false,
352         /*copy=*/copy_numpy);
353   }
354 
355   if (is_numpy_available() && PyArray_Check(data)) {
356     TORCH_CHECK(!pin_memory, "Can't pin tensor constructed from numpy");
357     auto tensor =
358         tensor_from_numpy(data, /*warn_if_not_writeable=*/!copy_numpy);
359     const auto& inferred_scalar_type =
360         type_inference ? tensor.scalar_type() : scalar_type;
361     auto device = device_opt.has_value() ? *device_opt : options.device();
362     pybind11::gil_scoped_release no_gil;
363     maybe_initialize_device(device);
364     return tensor.to(
365         device,
366         inferred_scalar_type,
367         /*non_blocking=*/false,
368         /*copy=*/copy_numpy);
369   }
370 #endif
371 
372   auto device = device_opt.has_value() ? *device_opt : options.device();
373 
374   auto sizes = compute_sizes(data, scalar_type);
375 
376   ScalarType inferred_scalar_type =
377       type_inference ? infer_scalar_type(data) : scalar_type;
378   // This exists to prevent us from tracing the call to empty().  The actual
379   // autograd code doesn't really matter, because requires_grad is always false
380   // here.
381   // What are the semantics of tensor_new()?
382   // We manually construct a tensor and place on it on the correct device with
383   // empty() and to(). We then have to "lift" the newly constructed tensor in
384   // some cases, like when we're performing a functorch transform or running
385   // functionalization. The exclude guards are all to ensure that extra logic
386   // doesn't run when we're constructing the raw tensor.
387   Tensor tensor;
388   {
389     at::AutoDispatchBelowADInplaceOrView guard;
390     c10::impl::ExcludeDispatchKeyGuard torchdispatchmode_guard(
391         c10::DispatchKey::Python);
392     c10::impl::ExcludeDispatchKeyGuard torchdispatchmode_snapshot_guard(
393         c10::DispatchKey::PythonTLSSnapshot);
394     // functorch uses FuncTorchDynamicLayerBackMode as a mode key to wrap all
395     // tensors returned from operators in special TensorWrapper tensor extension
396     c10::impl::ExcludeDispatchKeyGuard functorch_front_guard(
397         c10::DispatchKey::FuncTorchDynamicLayerFrontMode);
398     c10::impl::ExcludeDispatchKeyGuard functorch_back_guard(
399         c10::DispatchKey::FuncTorchDynamicLayerBackMode);
400     // We disable Fake and DeferredInit handlers for similar reasons as
401     // functorch.
402     c10::impl::ExcludeDispatchKeyGuard fake_and_deferred_init_guard(
403         c10::DispatchKeySet{
404             c10::DispatchKey::Fake, c10::DispatchKey::DeferredInit});
405     // Note [Functionalization <> torch.Tensor constructor]
406     // Functionalization "lifts" the newly constructed tensor into a wrapper
407     // using aten::lift().
408     c10::impl::ExcludeDispatchKeyGuard functionalize_guard(
409         c10::DispatchKey::Functionalize);
410     {
411       // Tracing should probably also use the "lift" operator to add the tensor
412       // to a trace, but it's technically BC-breaking to do that, since we
413       // currently trace .to() calls.
414       at::tracer::impl::NoTracerDispatchMode tracer_guard;
415 
416       if (isStorage(data)) {
417         auto [storage, storage_scalar_type, is_typed_storage] =
418             createStorageGetType(data);
419 
420         TORCH_CHECK(
421             !is_typed_storage || storage_scalar_type == scalar_type,
422             "Expected a Storage of type ",
423             scalar_type,
424             " or an UntypedStorage, but got ",
425             storage_scalar_type);
426         tensor = at::empty(
427             sizes,
428             at::initialTensorOptions()
429                 .dtype(
430                     is_typed_storage ? storage_scalar_type
431                                      : inferred_scalar_type)
432                 .pinned_memory(pin_memory)
433                 .device(storage.device()));
434         tensor.set_(storage);
435 
436       } else {
437         TensorOptions opts =
438             at::initialTensorOptions().dtype(inferred_scalar_type);
439 
440         // If the device is Meta, take the shortcut. We don't want to allocate
441         // an empty CPU tensor which would break our contract for meta tensors.
442         if (device == at::kMeta) {
443           return at::empty(sizes, opts.device(device));
444         }
445         tensor = at::empty(sizes, opts.pinned_memory(pin_memory));
446         if (c10::multiply_integers(tensor.sizes()) != 0) {
447           recursive_store(
448               (char*)tensor.data_ptr(),
449               tensor.sizes(),
450               tensor.strides(),
451               0,
452               inferred_scalar_type,
453               tensor.dtype().itemsize(),
454               data);
455         }
456       }
457     }
458     pybind11::gil_scoped_release no_gil;
459     maybe_initialize_device(device);
460     // However, it is VERY important that we trace the to() call here (even
461     // though the reason this is important is a hack).  Without *some* factory
462     // function call that is traced at construction time, we will consider
463     // a tensor constant as originating from "outside" the trace, and if you
464     // try to return it directly we will fail with the error saying
465     // "no observable data dependence".  In an ideal world, we wouldn't trace
466     // a to() call but I need to think harder about what exactly we should trace
467     // in this case.
468     if (only_lift_cpu_tensors()) {
469       tensor = tensor.to(
470           inferred_scalar_type, /*non_blocking=*/false, /*copy=*/false);
471 
472     } else {
473       tensor = tensor.to(
474           device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/false);
475     }
476   }
477 
478   // torch.jit.trace will continue to trace out `.to()` instead of `.lift()`,
479   // since changing it is BC-breaking.
480   at::tracer::impl::NoTracerDispatchMode tracer_guard;
481   {
482     // lift has no autograd implementation, so we need to make sure we don't try
483     // to dispatch to it.
484     // TODO: arguably it should have an autograd implementation that noops
485     at::AutoDispatchBelowADInplaceOrView guard;
486     tensor = at::lift_fresh(tensor);
487   }
488   if (only_lift_cpu_tensors() && device.type() != DeviceType::CPU) {
489     if (!device.has_index() &&
490         !torch::utils::is_device_initialized(device.type())) {
491       // Infer device 0 to avoid device init
492       device = c10::Device(device.type(), 0);
493     }
494     tensor = tensor.to(device, /*non_blocking=*/false, /*copy=*/false);
495   }
496   return tensor;
497 }
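// Descriptive note (not from the upstream source), summarizing the paths above:
//   - torch.Tensor input:        optional detach(), then var.to(device, dtype)
//   - __cuda_array_interface__:  tensor_from_cuda_array_interface(), then to()
//   - numpy.ndarray:             tensor_from_numpy(), then to()
//   - other Python data:         compute_sizes() + at::empty() +
//                                recursive_store(), then a traced to() and
//                                finally at::lift_fresh().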
498 
499 Tensor new_from_data_copy(
500     c10::TensorOptions options,
501     at::ScalarType scalar_type,
502     std::optional<Device> device,
503     PyObject* data) {
504   return internal_new_from_data(
505       options,
506       scalar_type,
507       device,
508       data,
509       /*copy_variables=*/true,
510       /*copy_numpy=*/true,
511       /*type_inference=*/false);
512 }
513 
514 Tensor legacy_new_from_sequence(
515     c10::TensorOptions options,
516     at::ScalarType scalar_type,
517     std::optional<Device> device,
518     PyObject* data) {
519   TORCH_CHECK_TYPE(
520       PySequence_Check(data),
521       "new(): data must be a sequence (got ",
522       Py_TYPE(data)->tp_name,
523       ")");
524   return internal_new_from_data(
525       options,
526       scalar_type,
527       device,
528       data,
529       /*copy_variables=*/false,
530       /*copy_numpy=*/false,
531       /*type_inference=*/false);
532 }
533 
534 // "base" here refers to the Tensor type on which the function was invoked,
535 // e.g.: in x.new(y), 'x' is the base.
536 // TODO: Rewrite this using dispatchKeyToTensorOptions
537 void check_base_legacy_new(
538     c10::DispatchKey dispatch_key,
539     at::Layout expected_layout) {
540   if (expected_layout == c10::kStrided) {
541     constexpr c10::DispatchKeySet expected_key_set({
542         c10::DispatchKey::CPU,
543         c10::DispatchKey::CUDA,
544         c10::DispatchKey::HIP,
545         c10::DispatchKey::XLA,
546         c10::DispatchKey::Lazy,
547         c10::DispatchKey::IPU,
548         c10::DispatchKey::XPU,
549         c10::DispatchKey::HPU,
550         c10::DispatchKey::MPS,
551         c10::DispatchKey::Meta,
552         c10::DispatchKey::PrivateUse1,
553     });
554     TORCH_CHECK(
555         expected_key_set.has(dispatch_key),
556         "new(): expected key in ",
557         expected_key_set,
558         " but got: ",
559         dispatch_key);
560   } else if (expected_layout == c10::kSparse) {
561     // NOTE: no sparse XLA or Lazy
562     constexpr c10::DispatchKeySet expected_key_set({
563         c10::DispatchKey::SparseCPU,
564         c10::DispatchKey::SparseCUDA,
565         c10::DispatchKey::SparseHIP,
566         c10::DispatchKey::SparseXPU,
567         c10::DispatchKey::SparsePrivateUse1,
568     });
569     TORCH_CHECK(
570         expected_key_set.has(dispatch_key),
571         "new(): expected key in ",
572         expected_key_set,
573         " but got: ",
574         dispatch_key);
575   } else {
576     TORCH_INTERNAL_ASSERT(false, "unexpected layout");
577   }
578 }
579 
580 // TODO: Make this accept options instead of dispatch key
581 void check_legacy_ctor_device(
582     c10::DispatchKey dispatch_key,
583     std::optional<Device> device) {
584   if (device.has_value()) {
585     TORCH_CHECK(
586         dispatchKeyToDeviceType(dispatch_key) == device.value().type(),
587         "legacy constructor expects device type: ",
588         dispatchKeyToDeviceType(dispatch_key),
589         " but device type: ",
590         device.value().type(),
591         " was passed");
592   }
593 }
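// Illustrative sketch (not from the upstream source): assuming a legacy type
// such as torch.cuda.FloatTensor, an explicit mismatching device, e.g.
// torch.cuda.FloatTensor(3, device="cpu"), fails this check because the
// dispatch key maps to DeviceType::CUDA while the requested device is CPU.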
594 
595 enum class CtorOrNew {
596   BASE_CTOR,
597   CTOR,
598   NEW,
599 };
600 
601 Tensor legacy_sparse_tensor_generic_ctor_new(
602     c10::DispatchKey dispatch_key,
603     at::ScalarType scalar_type,
604     PyObject* args,
605     PyObject* kwargs,
606     CtorOrNew ctor_or_new) {
607   auto options = dispatchKeyToTensorOptions(dispatch_key);
608   static PythonArgParser parser({
609       "new(*, Device? device=None)",
610       "new(*, int64_t cdata)|hidden",
611       "new(Tensor indices, Tensor values, *, Device? device=None)",
612       "new(Tensor indices, Tensor values, IntArrayRef size, *, Device? device=None)",
613       "new(SymIntArrayRef size, *, Device? device=None)",
614   });
615   if (ctor_or_new == CtorOrNew::NEW)
616     check_base_legacy_new(dispatch_key, c10::kSparse);
617   ParsedArgs<4> parsed_args;
618   auto r = parser.parse(args, kwargs, parsed_args);
619   if (r.idx == 0) {
620     if (ctor_or_new == CtorOrNew::CTOR) {
621       TORCH_WARN_ONCE(
622           "torch.sparse.SparseTensor() is deprecated."
623           "  Please use torch.sparse_coo_tensor((0,), dtype=).");
624     }
625     auto deviceOptional = r.deviceOptional(0);
626     check_legacy_ctor_device(dispatch_key, deviceOptional);
627     return at::empty({0}, build_options(options, scalar_type, deviceOptional));
628   } else if (r.idx == 1) {
629     if (ctor_or_new == CtorOrNew::CTOR) {
630       TORCH_WARN_ONCE(
631           "torch.sparse.SparseTensor(cdata=x._cdata) is deprecated."
632           "  Please use torch.sparse_coo_tensor(x._indices(), x._values(), x.shape).");
633     }
634     // NOLINTNEXTLINE(performance-no-int-to-ptr)
635     auto cdata = reinterpret_cast<void*>(r.toInt64(0));
636     return at::unsafeTensorFromTH(cdata, true);
637   } else if (r.idx == 2) {
638     if (ctor_or_new == CtorOrNew::CTOR) {
639       TORCH_WARN_ONCE(
640           "torch.sparse.SparseTensor(indices, values, *, device=) is deprecated."
641           "  Please use torch.sparse_coo_tensor(indices, values, dtype=, device=).");
642     }
643     // Note: this signature doesn't have a dtype, even though it has a device;
644     // it probably shouldn't have a device (we should infer it).
645     auto deviceOptional = r.deviceOptional(2);
646     check_legacy_ctor_device(dispatch_key, deviceOptional);
647     at::OptionalDeviceGuard device_guard(deviceOptional);
648     return at::sparse_coo_tensor(r.tensor(0), r.tensor(1));
649   } else if (r.idx == 3) {
650     if (ctor_or_new == CtorOrNew::CTOR) {
651       TORCH_WARN_ONCE(
652           "torch.sparse.SparseTensor(indices, values, shape, *, device=) is deprecated."
653           "  Please use torch.sparse_coo_tensor(indices, values, shape, dtype=, device=).");
654     }
655     // Note: this signature doesn't have a dtype, even though it has a device;
656     // it probably shouldn't have a device (we should infer it).
657     auto deviceOptional = r.deviceOptional(3);
658     check_legacy_ctor_device(dispatch_key, deviceOptional);
659     at::OptionalDeviceGuard device_guard(deviceOptional);
660     return at::sparse_coo_tensor(r.tensor(0), r.tensor(1), r.intlist(2));
661   } else if (r.idx == 4) {
662     PyObject* arg = r.pyobject(0);
663     auto deviceOptional = r.deviceOptional(1);
664     check_legacy_ctor_device(dispatch_key, deviceOptional);
665     if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 &&
666         arg == PyTuple_GET_ITEM(args, 0)) {
667       // new(sequence) binds to this signature but should be treated differently
668       // unless the sequence is a torch.Size
669       if (ctor_or_new == CtorOrNew::CTOR) {
670         throw TypeError(
671             "torch.sparse.SparseTensor(sequence) only accepts sizes.  Please use torch.sparse_coo_tensor() "
672             "or construct a strided tensor and convert it to sparse via to_sparse.");
673       } else {
674         throw TypeError(
675             "SparseTensor.new(sequence) only accepts sizes.  Please use torch.sparse_coo_tensor() "
676             "or construct a strided tensor and convert it to sparse via to_sparse.");
677       }
678     }
679     if (ctor_or_new == CtorOrNew::CTOR) {
680       TORCH_WARN_ONCE(
681           "torch.sparse.SparseTensor(shape, *, device=) is deprecated."
682           "  Please use torch.sparse_coo_tensor(shape, dtype=, device=).");
683     }
684     return new_with_sizes(
685         options, scalar_type, r.deviceOptional(1), r.symintlist(0));
686   }
687   throw std::runtime_error("new(): invalid arguments");
688 }
689 
690 // NB: device_idx here is NOT a DeviceIndex, but index into PythonArgs
691 c10::TensorOptions typeIdWithDefault(
692     PythonArgs& r,
693     int64_t device_idx,
694     c10::DispatchKey dispatch_key) {
695   auto options = dispatchKeyToTensorOptions(dispatch_key);
696   if (!r.isNone(static_cast<int>(device_idx))) {
697     // TODO: This line doesn't seem to be exercised at all in tests
698     options = options.device(r.device(static_cast<int>(device_idx)).type());
699   }
700   return options;
701 }
702 
703 } // namespace
704 
705 Tensor legacy_tensor_generic_ctor_new(
706     c10::DispatchKey dispatch_key,
707     at::ScalarType scalar_type,
708     PyObject* args,
709     PyObject* kwargs,
710     CtorOrNew ctor_or_new) {
711   auto options = dispatchKeyToTensorOptions(dispatch_key);
712   static PythonArgParser parser({
713       "new(*, Device? device=None)",
714       "new(Storage storage)",
715       "new(*, int64_t cdata)|hidden",
716       // This constructor is no longer legacy, it will also be usable for
717       // subclass initialization
718       "new(Tensor other)",
719       "new(Tensor other, *, Device? device=None)|hidden", // prevent Tensor
720                                                           // matching with
721                                                           // IntArrayRef,
722                                                           // PyObject*
723       "new(SymIntArrayRef size, *, Device? device=None)",
724       "new(PyObject* data, *, Device? device=None)",
725   });
726 
727   if (isSparse(dispatchKeyToBackend(dispatch_key))) {
728     return legacy_sparse_tensor_generic_ctor_new(
729         dispatch_key, scalar_type, args, kwargs, ctor_or_new);
730   }
731 
732   if (ctor_or_new == CtorOrNew::NEW)
733     check_base_legacy_new(dispatch_key, c10::kStrided);
734 
735   ParsedArgs<2> parsed_args;
736   auto r = parser.parse(args, kwargs, parsed_args);
737   if (r.idx == 0) {
738     auto deviceOptional = r.deviceOptional(0);
739     check_legacy_ctor_device(dispatch_key, deviceOptional);
740     at::OptionalDeviceGuard device_guard(deviceOptional);
741     return at::empty({0}, build_options(options, scalar_type));
742   } else if (r.idx == 1) {
743     at::ScalarType storage_scalar_type{at::ScalarType::Undefined};
744     bool is_typed_storage = false;
745     at::Storage storage = r.storage(0, storage_scalar_type, is_typed_storage);
746     if (storage_scalar_type != at::ScalarType::Undefined && is_typed_storage) {
747       TORCH_CHECK(
748           storage_scalar_type == scalar_type,
749           "Expected a Storage of type ",
750           scalar_type,
751           " or an UntypedStorage, but got type ",
752           storage_scalar_type,
753           " for argument 1 'storage'");
754     }
755     return new_with_storage(options, scalar_type, storage);
756   } else if (r.idx == 2) {
757     // NOLINTNEXTLINE(performance-no-int-to-ptr)
758     auto cdata = reinterpret_cast<void*>(r.toInt64(0));
759     return at::unsafeTensorFromTH(cdata, true);
760   } else if (r.idx == 3) {
761     const auto& other = r.tensor(0);
762     // BASE_CTOR (aka torch.Tensor) is now relaxed to accept any
763     // dtype; previously it was "float" biased
764     if (ctor_or_new != CtorOrNew::BASE_CTOR) {
765       options = options.dtype(scalar_type);
766       TORCH_CHECK_TYPE(
767           other.options().type_equal(options),
768           "expected ",
769           options,
770           " (got ",
771           other.options(),
772           ")");
773     }
774     return other.alias();
775   } else if (r.idx == 4) {
776     if (ctor_or_new == CtorOrNew::CTOR || ctor_or_new == CtorOrNew::BASE_CTOR) {
777       TORCH_CHECK(
778           false,
779           "Legacy tensor constructor of the form torch.Tensor(tensor, device=device) "
780           "is not supported.  Use torch.tensor(...) or torch.as_tensor(...) instead.");
781     } else {
782       TORCH_CHECK(
783           false,
784           "Legacy tensor new of the form tensor.new(tensor, device=device) "
785           "is not supported.  Use torch.as_tensor(...) instead.");
786     }
787   } else if (r.idx == 5) {
788     PyObject* arg = r.pyobject(0);
789     auto deviceOptional = r.deviceOptional(1);
790     check_legacy_ctor_device(dispatch_key, deviceOptional);
791     if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 &&
792         arg == PyTuple_GET_ITEM(args, 0)) {
793       // new(sequence) binds to this signature but should be treated differently
794       // unless the sequence is a torch.Size
795       return legacy_new_from_sequence(
796           options, scalar_type, deviceOptional, r.pyobject(0));
797     }
798     return new_with_sizes(
799         options, scalar_type, r.deviceOptional(1), r.symintlist(0));
800   } else if (r.idx == 6) {
801     auto deviceOptional = r.deviceOptional(1);
802     check_legacy_ctor_device(dispatch_key, deviceOptional);
803     return legacy_new_from_sequence(
804         options, scalar_type, deviceOptional, r.pyobject(0));
805   }
806   throw std::runtime_error("new(): invalid arguments");
807 }
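// Illustrative sketch (not from the upstream source), assuming a legacy call
// such as torch.FloatTensor routed here via legacy_tensor_ctor:
//   torch.FloatTensor()           -> r.idx == 0, empty tensor of size {0}
//   torch.FloatTensor(storage)    -> r.idx == 1, wraps the storage
//   torch.FloatTensor(other)      -> r.idx == 3, returns other.alias()
//   torch.FloatTensor(2, 3)       -> r.idx == 5, uninitialized 2x3 tensor
//   torch.FloatTensor([1, 2, 3])  -> r.idx == 5 but redirected to
//                                    legacy_new_from_sequence (not a torch.Size)
//   torch.FloatTensor([1.0, 2.0]) -> r.idx == 6, legacy_new_from_sequence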
808 
809 // Handles ONLY torch.Tensor
810 // Unlike the legacy dtype/device specialized constructors, this one is
811 // relaxed to accept any device/dtype input tensor (even if it doesn't
812 // match the default)
813 Tensor base_tensor_ctor(PyObject* args, PyObject* kwargs) {
814   return legacy_tensor_generic_ctor_new(
815       torch::tensors::get_default_dispatch_key(),
816       torch::tensors::get_default_scalar_type(),
817       args,
818       kwargs,
819       CtorOrNew::BASE_CTOR);
820 }
821 
822 // Handles calls like torch.DoubleTensor, torch.cuda.FloatTensor,
823 // torch.sparse.FloatTensor, etc.
824 Tensor legacy_tensor_ctor(
825     c10::DispatchKey dispatch_key,
826     at::ScalarType scalar_type,
827     PyObject* args,
828     PyObject* kwargs) {
829   return legacy_tensor_generic_ctor_new(
830       dispatch_key, scalar_type, args, kwargs, CtorOrNew::CTOR);
831 }
832 
833 // Handles tensor.new(...)
834 Tensor legacy_tensor_new(
835     c10::DispatchKey dispatch_key,
836     at::ScalarType scalar_type,
837     PyObject* args,
838     PyObject* kwargs) {
839   return legacy_tensor_generic_ctor_new(
840       dispatch_key, scalar_type, args, kwargs, CtorOrNew::NEW);
841 }
842 
843 Tensor indexing_tensor_from_data(
844     c10::TensorOptions options,
845     at::ScalarType scalar_type,
846     std::optional<Device> device,
847     PyObject* data) {
848   // Specific to tensor indexing, converts an indexing list to an
849   // indexing tensor (type Byte or Long)
850   ScalarType inferred_scalar_type = infer_scalar_type(data);
851   if (inferred_scalar_type == ScalarType::Byte ||
852       inferred_scalar_type == ScalarType::Bool) {
853     return internal_new_from_data(
854         options,
855         inferred_scalar_type,
856         device,
857         data,
858         /*copy_variables=*/false,
859         /*copy_numpy=*/false,
860         /*type_inference=*/false);
861   } else {
862     return internal_new_from_data(
863         options,
864         scalar_type,
865         device,
866         data,
867         /*copy_variables=*/false,
868         /*copy_numpy=*/false,
869         /*type_inference=*/false);
870   }
871 }
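// Descriptive note (not from the upstream source): Byte/Bool data keeps its
// inferred type so it can serve as a mask; any other data is built with the
// caller-supplied scalar_type (per the comment above, an index tensor of type
// Byte or Long).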
872 
873 class CheckSparseTensorInvariantsContext {
874  public:
875   CheckSparseTensorInvariantsContext()
876       : state{at::globalContext().checkSparseTensorInvariants()} {}
877   ~CheckSparseTensorInvariantsContext() {
878     at::globalContext().setCheckSparseTensorInvariants(state);
879   }
880 
881  private:
882   bool state;
883 };
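// Descriptive note (not from the upstream source): RAII guard; it snapshots the
// global check-sparse-tensor-invariants flag at construction and restores it at
// destruction, so the ctors below can honor an explicit check_invariants=
// argument without permanently changing the global setting.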
884 
885 static Tensor sparse_compressed_tensor_ctor_worker(
886     const std::string& name,
887     c10::DispatchKey dispatch_key,
888     at::ScalarType scalar_type,
889     PythonArgs& r,
890     std::optional<c10::Layout> required_layout) {
891   TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
892   TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key)));
893   enum {
894     ARG_COMPRESSED_INDICES = 0,
895     ARG_PLAIN_INDICES,
896     ARG_VALUES,
897     ARG_SIZE,
898     ARG_TYPE,
899     ARG_LAYOUT,
900     ARG_DEVICE,
901     ARG_PIN_MEMORY,
902     ARG_REQUIRES_GRAD,
903     ARG_CHECK_INVARIANTS,
904     ARGS_COUNT
905   };
906   enum {
907     ARG_VALUES1 = ARG_VALUES,
908     ARG_TYPE1,
909     ARG_LAYOUT1,
910     ARG_DEVICE1,
911     ARG_PIN_MEMORY1,
912     ARG_REQUIRES_GRAD1,
913     ARG_CHECK_INVARIANTS1,
914     ARGS_COUNT1
915   };
916 
917   auto safe_get_attr_string = [](PyObject* o,
918                                  const char* attr_name) -> PyObject* {
919     // Clear the error indicator if the attribute does not exist.
920     // Otherwise subsequent Python C API calls might return bogus values.
921     // See https://github.com/pytorch/pytorch/issues/58520 for more details
922     auto rc = PyObject_GetAttrString(o, attr_name);
923     if (!rc) {
924       if (!PyErr_ExceptionMatches(PyExc_AttributeError)) {
925         throw python_error();
926       }
927       // Warning: a wrong attribute error may be suppressed here
928       PyErr_Clear();
929     }
930     return rc;
931   };
932   THPObjectPtr compressed_indices_dtype_attr(
933       safe_get_attr_string(r.pyobject(ARG_COMPRESSED_INDICES), "dtype"));
934   THPObjectPtr plain_indices_dtype_attr(
935       safe_get_attr_string(r.pyobject(ARG_PLAIN_INDICES), "dtype"));
936   at::ScalarType compressed_indices_scalar_type = compressed_indices_dtype_attr
937       ? reinterpret_cast<THPDtype*>(compressed_indices_dtype_attr.get())
938             ->scalar_type
939       : kInt;
940   at::ScalarType plain_indices_scalar_type = plain_indices_dtype_attr
941       ? reinterpret_cast<THPDtype*>(plain_indices_dtype_attr.get())->scalar_type
942       : kInt;
943   CheckSparseTensorInvariantsContext
944       restores_check_sparse_tensor_invariants_global_state{};
945   bool default_check_invariants =
946       at::globalContext().checkSparseTensorInvariants();
947 
948   if (r.idx == 0) {
949     const bool pin_memory = r.toBool(ARG_PIN_MEMORY);
950     bool type_inference = r.isNone(ARG_TYPE);
951     const auto inferred_options =
952         typeIdWithDefault(r, ARG_DEVICE, dispatch_key);
953     const auto inferred_scalar_type =
954         r.scalartypeWithDefault(ARG_TYPE, scalar_type);
955     at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE));
956     // the global state of invariants check flag will be restored via
957     // CheckSparseTensorInvariantsContext destructor
958     at::globalContext().setCheckSparseTensorInvariants(
959         r.toBoolWithDefault(ARG_CHECK_INVARIANTS, default_check_invariants));
960     Tensor values = internal_new_from_data(
961         inferred_options,
962         inferred_scalar_type,
963         r.deviceOptional(ARG_DEVICE),
964         r.pyobject(ARG_VALUES),
965         /*copy_variables=*/false,
966         /*copy_numpy=*/true,
967         /*type_inference=*/type_inference);
968     Tensor compressed_indices = internal_new_from_data(
969         values.options(),
970         compressed_indices_scalar_type,
971         r.deviceOptional(ARG_DEVICE),
972         r.pyobject(ARG_COMPRESSED_INDICES),
973         /*copy_variables=*/false,
974         /*copy_numpy=*/true,
975         /*type_inference=*/true);
976     Tensor plain_indices = internal_new_from_data(
977         values.options(),
978         plain_indices_scalar_type,
979         r.deviceOptional(ARG_DEVICE),
980         r.pyobject(ARG_PLAIN_INDICES),
981         /*copy_variables=*/false,
982         /*copy_numpy=*/true,
983         /*type_inference=*/true);
984     std::optional<c10::Layout> layout =
985         (required_layout
986              ? r.layoutWithDefault(ARG_LAYOUT, required_layout.value())
987              : r.layoutOptional(ARG_LAYOUT));
988     if (required_layout) {
989       TORCH_CHECK(
990           layout.value() == required_layout.value(),
991           name,
992           ": layout must be ",
993           required_layout.value(),
994           " but got ",
995           layout.value());
996     }
997     return at::sparse_compressed_tensor(
998                compressed_indices,
999                plain_indices,
1000                values,
1001                r.intlist(ARG_SIZE),
1002                values.options().layout(layout).pinned_memory(pin_memory))
1003         .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD));
1004   } else if (r.idx == 1) {
1005     bool type_inference = r.isNone(ARG_TYPE1);
1006     const auto inferred_options =
1007         typeIdWithDefault(r, ARG_DEVICE1, dispatch_key);
1008     const auto inferred_scalar_type =
1009         r.scalartypeWithDefault(ARG_TYPE1, scalar_type);
1010     at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE1));
1011     const bool pin_memory = r.toBool(ARG_PIN_MEMORY1);
1012     // the global state of invariants check flag will be restored via
1013     // CheckSparseTensorInvariantsContext destructor
1014     at::globalContext().setCheckSparseTensorInvariants(
1015         r.toBoolWithDefault(ARG_CHECK_INVARIANTS1, default_check_invariants));
1016     Tensor values = internal_new_from_data(
1017         inferred_options,
1018         inferred_scalar_type,
1019         r.deviceOptional(ARG_DEVICE1),
1020         r.pyobject(ARG_VALUES),
1021         /*copy_variables=*/false,
1022         /*copy_numpy=*/true,
1023         /*type_inference=*/type_inference);
1024     Tensor compressed_indices = internal_new_from_data(
1025         values.options(),
1026         compressed_indices_scalar_type,
1027         r.deviceOptional(ARG_DEVICE1),
1028         r.pyobject(ARG_COMPRESSED_INDICES),
1029         /*copy_variables=*/false,
1030         /*copy_numpy=*/true,
1031         /*type_inference=*/true);
1032     Tensor plain_indices = internal_new_from_data(
1033         values.options(),
1034         plain_indices_scalar_type,
1035         r.deviceOptional(ARG_DEVICE1),
1036         r.pyobject(ARG_PLAIN_INDICES),
1037         /*copy_variables=*/false,
1038         /*copy_numpy=*/true,
1039         /*type_inference=*/true);
1040     std::optional<c10::Layout> layout =
1041         (required_layout
1042              ? r.layoutWithDefault(ARG_LAYOUT1, required_layout.value())
1043              : r.layoutOptional(ARG_LAYOUT1));
1044     if (required_layout) {
1045       TORCH_CHECK(
1046           layout.value() == required_layout.value(),
1047           name,
1048           ": layout must be ",
1049           required_layout.value(),
1050           " but got ",
1051           layout.value());
1052     }
1053     return at::sparse_compressed_tensor(
1054                compressed_indices,
1055                plain_indices,
1056                values,
1057                values.options().layout(layout).pinned_memory(pin_memory))
1058         .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD1));
1059   }
1060   throw std::runtime_error(name + ": invalid arguments");
1061 }
1062 
1063 Tensor sparse_compressed_tensor_ctor(
1064     c10::DispatchKey dispatch_key,
1065     at::ScalarType scalar_type,
1066     PythonArgs& r) {
1067   std::optional<c10::Layout> required_layout{};
1068   return sparse_compressed_tensor_ctor_worker(
1069       "sparse_compressed_tensor",
1070       dispatch_key,
1071       scalar_type,
1072       r,
1073       required_layout);
1074 }
1075 
1076 Tensor sparse_csr_tensor_ctor(
1077     c10::DispatchKey dispatch_key,
1078     at::ScalarType scalar_type,
1079     PythonArgs& r) {
1080   std::optional<c10::Layout> required_layout(c10::Layout::SparseCsr);
1081   return sparse_compressed_tensor_ctor_worker(
1082       "sparse_csr_tensor", dispatch_key, scalar_type, r, required_layout);
1083 }
1084 
1085 Tensor sparse_csc_tensor_ctor(
1086     c10::DispatchKey dispatch_key,
1087     at::ScalarType scalar_type,
1088     PythonArgs& r) {
1089   std::optional<c10::Layout> required_layout(c10::Layout::SparseCsc);
1090   return sparse_compressed_tensor_ctor_worker(
1091       "sparse_csc_tensor", dispatch_key, scalar_type, r, required_layout);
1092 }
1093 
1094 Tensor sparse_bsr_tensor_ctor(
1095     c10::DispatchKey dispatch_key,
1096     at::ScalarType scalar_type,
1097     PythonArgs& r) {
1098   std::optional<c10::Layout> required_layout(c10::Layout::SparseBsr);
1099   return sparse_compressed_tensor_ctor_worker(
1100       "sparse_bsr_tensor", dispatch_key, scalar_type, r, required_layout);
1101 }
1102 
1103 Tensor sparse_bsc_tensor_ctor(
1104     c10::DispatchKey dispatch_key,
1105     at::ScalarType scalar_type,
1106     PythonArgs& r) {
1107   std::optional<c10::Layout> required_layout(c10::Layout::SparseBsc);
1108   return sparse_compressed_tensor_ctor_worker(
1109       "sparse_bsc_tensor", dispatch_key, scalar_type, r, required_layout);
1110 }
1111 
1112 // Note [Ensuring sparse values and indices match devices]
1113 // ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
1114 // In all places where we construct indices, we read out options from values
1115 // (rather than use inferred_options).  Why?  This handles the case when
1116 // values is a CUDA tensor, but indices is a non-Tensor value (and the device
1117 // argument is not set).  Example:
1118 //
1119 //  torch.sparse_coo_tensor(([0, 1],), self.empty(2, 0).cuda(), (4, 0))
1120 //
1121 // Sparse tensors require both indices and values to live on the same device.
1122 // If values lives on CUDA, we can infer where the indices should live, and
1123 // should accept even ordinary index sequences (and just make sure we write them
1124 // into the correct device).  values is the ONLY way we know that the index
1125 // tensor should go to CUDA, so we have to get the information in somehow.
1126 //
1127 // This code is kind of jank.  For one, the dtype in options is silently ignored
1128 // by internal_new_from_data.  Also, in classic janky code style, it used to
1129 // not work quite right: if values lives on "cuda:1", previously all we said
1130 // was "this needs to be CUDA" and indices would be allocated on the wrong device.
1131 // Options is more right and gets this correct.
1132 
1133 Tensor sparse_coo_tensor_ctor(
1134     c10::DispatchKey dispatch_key,
1135     at::ScalarType scalar_type,
1136     PythonArgs& r) {
1137   TORCH_INTERNAL_ASSERT(!isSparse(dispatchKeyToBackend(dispatch_key)));
1138   TORCH_INTERNAL_ASSERT(!isSparseCsr(dispatchKeyToBackend(dispatch_key)));
1139   enum {
1140     ARG_INDICES = 0,
1141     ARG_VALUES,
1142     ARG_TYPE,
1143     ARG_DEVICE,
1144     ARG_PIN_MEMORY,
1145     ARG_REQUIRES_GRAD,
1146     ARG_CHECK_INVARIANTS,
1147     ARGS_COUNT
1148   };
1149   enum {
1150     ARG_INDICES1 = 0,
1151     ARG_VALUES1,
1152     ARG_SIZE1,
1153     ARG_TYPE1,
1154     ARG_DEVICE1,
1155     ARG_PIN_MEMORY1,
1156     ARG_REQUIRES_GRAD1,
1157     ARG_CHECK_INVARIANTS1,
1158     ARG_IS_COALESCED1,
1159     ARGS_COUNT1
1160   };
1161   enum {
1162     ARG_SIZE2 = 0,
1163     ARG_TYPE2,
1164     ARG_DEVICE2,
1165     ARG_REQUIRES_GRAD2,
1166     ARG_CHECK_INVARIANTS2,
1167     ARGS_COUNT2
1168   };
1169 
1170   CheckSparseTensorInvariantsContext
1171       restores_check_sparse_tensor_invariants_global_state{};
1172   bool default_check_invariants =
1173       at::globalContext().checkSparseTensorInvariants();
1174 
1175   if (r.idx == 0) {
1176     bool pin_memory = r.toBool(ARG_PIN_MEMORY);
1177     bool type_inference = r.isNone(ARG_TYPE);
1178     const auto inferred_options =
1179         typeIdWithDefault(r, ARG_DEVICE, dispatch_key);
1180     const auto inferred_scalar_type =
1181         r.scalartypeWithDefault(ARG_TYPE, scalar_type);
1182     at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE));
1183     at::globalContext().setCheckSparseTensorInvariants(
1184         r.toBoolWithDefault(ARG_CHECK_INVARIANTS, default_check_invariants));
1185 
1186     // if no dtype provided, infer type based on value type.
1187     Tensor values = internal_new_from_data(
1188         inferred_options,
1189         inferred_scalar_type,
1190         r.deviceOptional(ARG_DEVICE),
1191         r.pyobject(ARG_VALUES),
1192         /*copy_variables=*/false,
1193         /*copy_numpy=*/true,
1194         /*type_inference=*/type_inference);
1195     // See Note [Ensuring sparse values and indices match devices]
1196     Tensor indices = internal_new_from_data(
1197         values.options(),
1198         kLong,
1199         r.deviceOptional(ARG_DEVICE),
1200         r.pyobject(ARG_INDICES),
1201         /*copy_variables=*/false,
1202         /*copy_numpy=*/true,
1203         /*type_inference=*/false);
1204     return at::sparse_coo_tensor(
1205                indices,
1206                values,
1207                values.options().layout(at::kSparse).pinned_memory(pin_memory))
1208         .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD));
1209   } else if (r.idx == 1) {
1210     bool pin_memory = r.toBool(ARG_PIN_MEMORY1);
1211     bool type_inference = r.isNone(ARG_TYPE1);
1212     const auto inferred_options =
1213         typeIdWithDefault(r, ARG_DEVICE1, dispatch_key);
1214     const auto inferred_scalar_type =
1215         r.scalartypeWithDefault(ARG_TYPE1, scalar_type);
1216     at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE1));
1217     at::globalContext().setCheckSparseTensorInvariants(
1218         r.toBoolWithDefault(ARG_CHECK_INVARIANTS1, default_check_invariants));
1219 
1220     Tensor values = internal_new_from_data(
1221         inferred_options,
1222         inferred_scalar_type,
1223         r.deviceOptional(ARG_DEVICE1),
1224         r.pyobject(ARG_VALUES1),
1225         /*copy_variables=*/false,
1226         /*copy_numpy=*/true,
1227         /*type_inference=*/type_inference);
1228     // See Note [Ensuring sparse values and indices match devices]
1229     Tensor indices = internal_new_from_data(
1230         values.options(),
1231         kLong,
1232         r.deviceOptional(ARG_DEVICE1),
1233         r.pyobject(ARG_INDICES1),
1234         /*copy_variables=*/false,
1235         /*copy_numpy=*/true,
1236         /*type_inference=*/false);
1237     return at::sparse_coo_tensor(
1238                indices,
1239                values,
1240                r.intlist(ARG_SIZE1),
1241                values.options().layout(at::kSparse).pinned_memory(pin_memory),
1242                r.toBoolOptional(ARG_IS_COALESCED1))
1243         .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD1));
1244   } else if (r.idx == 2) {
1245     const auto inferred_options =
1246         typeIdWithDefault(r, ARG_DEVICE2, dispatch_key);
1247     const auto inferred_scalar_type =
1248         r.scalartypeWithDefault(ARG_TYPE2, scalar_type);
1249     at::OptionalDeviceGuard device_guard(r.deviceOptional(ARG_DEVICE2));
1250     at::globalContext().setCheckSparseTensorInvariants(
1251         r.toBoolWithDefault(ARG_CHECK_INVARIANTS2, default_check_invariants));
1252 
1253     return at::sparse_coo_tensor(
1254                r.intlist(ARG_SIZE2),
1255                inferred_options.dtype(inferred_scalar_type).layout(at::kSparse))
1256         .set_requires_grad(r.toBool(ARG_REQUIRES_GRAD2));
1257   }
1258   throw std::runtime_error("sparse_coo_tensor(): invalid arguments");
1259 }
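// Illustrative sketch (not from the upstream source): the three overloads above
// correspond to
//   torch.sparse_coo_tensor(indices, values)        -> r.idx == 0
//   torch.sparse_coo_tensor(indices, values, size)  -> r.idx == 1
//   torch.sparse_coo_tensor(size)                   -> r.idx == 2 (empty sparse)
// with indices always rebuilt from values.options(), so they land on the same
// device as values (see the Note above).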
1260 
1261 void _validate_sparse_coo_tensor_args(
1262     c10::DispatchKey dispatch_key,
1263     at::ScalarType scalar_type,
1264     PyObject* args,
1265     PyObject* kwargs) {
1266   auto options = dispatchKeyToTensorOptions(dispatch_key);
1267   static PythonArgParser parser({
1268       "_validate_sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size)",
1269   });
1270 
1271   ParsedArgs<3> parsed_args;
1272   auto r = parser.parse(args, kwargs, parsed_args);
1273   Tensor values = internal_new_from_data(
1274       options,
1275       scalar_type,
1276       std::nullopt,
1277       r.pyobject(1),
1278       /*copy_variables=*/false,
1279       /*copy_numpy=*/true,
1280       /*type_inference=*/true);
1281   // See Note [Ensuring sparse values and indices match devices]
1282   Tensor indices = internal_new_from_data(
1283       values.options(),
1284       kLong,
1285       std::nullopt,
1286       r.pyobject(0),
1287       /*copy_variables=*/false,
1288       /*copy_numpy=*/true,
1289       /*type_inference=*/false);
1290   at::native::_validate_sparse_coo_tensor_args(indices, values, r.intlist(2));
1291 }
1292 
1293 void _validate_sparse_compressed_tensor_args(
1294     c10::DispatchKey dispatch_key,
1295     at::ScalarType scalar_type,
1296     PyObject* args,
1297     PyObject* kwargs) {
1298   auto options = dispatchKeyToTensorOptions(dispatch_key);
1299   enum {
1300     ARG_COMPRESSED_INDICES = 0,
1301     ARG_PLAIN_INDICES,
1302     ARG_VALUES,
1303     ARG_SIZE,
1304     ARG_LAYOUT,
1305     ARGS_COUNT
1306   };
1307 
1308   const std::string signature =
1309       "_validate_sparse_compressed_tensor(PyObject* compressed_indices, PyObject* plain_indices, PyObject* values, IntArrayRef size, Layout layout)";
1310   static PythonArgParser parser({signature});
1311 
1312   ParsedArgs<ARGS_COUNT> parsed_args;
1313   auto r = parser.parse(args, kwargs, parsed_args);
1314   Tensor values = internal_new_from_data(
1315       options,
1316       scalar_type,
1317       std::nullopt,
1318       r.pyobject(ARG_VALUES),
1319       /*copy_variables=*/false,
1320       /*copy_numpy=*/true,
1321       /*type_inference=*/true);
1322   // See Note [Ensuring sparse values and indices match devices]
1323   Tensor compressed_indices = internal_new_from_data(
1324       values.options(),
1325       kInt,
1326       std::nullopt,
1327       r.pyobject(ARG_COMPRESSED_INDICES),
1328       /*copy_variables=*/false,
1329       /*copy_numpy=*/true,
1330       /*type_inference=*/true);
1331   Tensor plain_indices = internal_new_from_data(
1332       values.options(),
1333       kInt,
1334       std::nullopt,
1335       r.pyobject(ARG_PLAIN_INDICES),
1336       /*copy_variables=*/false,
1337       /*copy_numpy=*/true,
1338       /*type_inference=*/true);
1339   at::native::_validate_sparse_compressed_tensor_args(
1340       compressed_indices,
1341       plain_indices,
1342       values,
1343       r.intlist(ARG_SIZE),
1344       r.layout(ARG_LAYOUT));
1345 }
1346 
1347 template <c10::Layout required_layout>
1348 void _validate_sparse_compressed_tensor_args_template(
1349     c10::DispatchKey dispatch_key,
1350     at::ScalarType scalar_type,
1351     PyObject* args,
1352     PyObject* kwargs) {
1353   auto options = dispatchKeyToTensorOptions(dispatch_key);
1354   enum {
1355     ARG_COMPRESSED_INDICES = 0,
1356     ARG_PLAIN_INDICES,
1357     ARG_VALUES,
1358     ARG_SIZE,
1359     ARGS_COUNT
1360   };
1361   static std::string sig;
1362   switch (required_layout) {
1363     case c10::Layout::SparseCsr:
1364       sig =
1365           "_validate_sparse_csr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size)";
1366       break;
1367     case c10::Layout::SparseCsc:
1368       sig =
1369           "_validate_sparse_csc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size)";
1370       break;
1371     case c10::Layout::SparseBsr:
1372       sig =
1373           "_validate_sparse_bsr_tensor(PyObject* crow_indices, PyObject* col_indices, PyObject* values, IntArrayRef size)";
1374       break;
1375     case c10::Layout::SparseBsc:
1376       sig =
1377           "_validate_sparse_bsc_tensor(PyObject* ccol_indices, PyObject* row_indices, PyObject* values, IntArrayRef size)";
1378       break;
1379     default:;
1380   }
1381   static PythonArgParser parser({sig});
1382 
1383   ParsedArgs<ARGS_COUNT> parsed_args;
1384   auto r = parser.parse(args, kwargs, parsed_args);
1385   Tensor values = internal_new_from_data(
1386       options,
1387       scalar_type,
1388       std::nullopt,
1389       r.pyobject(ARG_VALUES),
1390       /*copy_variables=*/false,
1391       /*copy_numpy=*/true,
1392       /*type_inference=*/true);
1393   // See Note [Ensuring sparse values and indices match devices]
1394   Tensor compressed_indices = internal_new_from_data(
1395       values.options(),
1396       kInt,
1397       std::nullopt,
1398       r.pyobject(ARG_COMPRESSED_INDICES),
1399       /*copy_variables=*/false,
1400       /*copy_numpy=*/true,
1401       /*type_inference=*/true);
1402   Tensor plain_indices = internal_new_from_data(
1403       values.options(),
1404       kInt,
1405       std::nullopt,
1406       r.pyobject(ARG_PLAIN_INDICES),
1407       /*copy_variables=*/false,
1408       /*copy_numpy=*/true,
1409       /*type_inference=*/true);
1410 
1411   at::native::_validate_sparse_compressed_tensor_args(
1412       compressed_indices, plain_indices, values, r.intlist(3), required_layout);
1413 }
1414 
1415 void _validate_sparse_csr_tensor_args(
1416     c10::DispatchKey dispatch_key,
1417     at::ScalarType scalar_type,
1418     PyObject* args,
1419     PyObject* kwargs) {
1420   _validate_sparse_compressed_tensor_args_template<c10::Layout::SparseCsr>(
1421       dispatch_key, scalar_type, args, kwargs);
1422 }
1423 
1424 void _validate_sparse_csc_tensor_args(
1425     c10::DispatchKey dispatch_key,
1426     at::ScalarType scalar_type,
1427     PyObject* args,
1428     PyObject* kwargs) {
1429   _validate_sparse_compressed_tensor_args_template<c10::Layout::SparseCsc>(
1430       dispatch_key, scalar_type, args, kwargs);
1431 }
1432 
1433 void _validate_sparse_bsr_tensor_args(
1434     c10::DispatchKey dispatch_key,
1435     at::ScalarType scalar_type,
1436     PyObject* args,
1437     PyObject* kwargs) {
1438   _validate_sparse_compressed_tensor_args_template<c10::Layout::SparseBsr>(
1439       dispatch_key, scalar_type, args, kwargs);
1440 }
1441 
1442 void _validate_sparse_bsc_tensor_args(
1443     c10::DispatchKey dispatch_key,
1444     at::ScalarType scalar_type,
1445     PyObject* args,
1446     PyObject* kwargs) {
1447   _validate_sparse_compressed_tensor_args_template<c10::Layout::SparseBsc>(
1448       dispatch_key, scalar_type, args, kwargs);
1449 }
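// Editorial note (Python-level, illustrative; assumes the usual private
// bindings such as torch._validate_sparse_csr_tensor_args): each wrapper
// above validates one compressed layout, e.g.
//
//   crow = torch.tensor([0, 1, 2])           # nrows + 1 entries
//   col = torch.tensor([0, 1])
//   val = torch.tensor([1.0, 2.0])
//   torch._validate_sparse_csr_tensor_args(crow, col, val, (2, 2))  # ok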
1450 
1451 Tensor tensor_ctor(
1452     c10::DispatchKey dispatch_key,
1453     at::ScalarType scalar_type,
1454     PythonArgs& r) {
1455   if (r.idx == 0) {
1456     PyObject* data = r.pyobject(0);
1457     if (THPVariable_Check(data)) {
1458       auto ret = PyErr_WarnEx(
1459           PyExc_UserWarning,
1460           "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() "
1461           "or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).",
1462           1);
1463       if (ret != 0)
1464         throw python_error();
1465     }
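    // Editorial example (Python-level, illustrative only): the recommended
    // replacements for torch.tensor(sourceTensor) named in the warning above
    // are
    //   new_t = sourceTensor.clone().detach()
    //   new_t = sourceTensor.clone().detach().requires_grad_(True)
    // both of which yield a leaf tensor detached from sourceTensor's
    // autograd history.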
1466 
1467     bool type_inference = r.isNone(1);
1468     bool pin_memory = r.toBool(3);
1469     bool args_requires_grad = r.toBool(4);
1470     auto new_tensor = internal_new_from_data(
1471         typeIdWithDefault(r, 2, dispatch_key),
1472         r.scalartypeWithDefault(1, scalar_type),
1473         r.deviceOptional(2),
1474         data,
1475         /*copy_variables=*/true,
1476         /*copy_numpy=*/true,
1477         /*type_inference=*/type_inference,
1478         pin_memory);
1479     auto names = r.toDimnameListOptional(5);
1480     if (names) {
1481       at::namedinference::propagate_names(
1482           new_tensor, *names, /*validate_names=*/true);
1483     }
1484     new_tensor.detach_(); // ensure new_tensor is a leaf node
1485     new_tensor.set_requires_grad(args_requires_grad);
1486     return new_tensor;
1487   }
1488   throw std::runtime_error("tensor(): invalid arguments");
1489 }
1490 
1491 Tensor as_tensor(
1492     c10::DispatchKey dispatch_key,
1493     at::ScalarType scalar_type,
1494     PythonArgs& r) {
1495   // TODO: add requires_grad once we decide on semantics for sharing data.
1496   if (r.idx == 0) {
1497     bool type_inference = r.isNone(1);
1498     return internal_new_from_data(
1499         typeIdWithDefault(r, 2, dispatch_key),
1500         r.scalartypeWithDefault(1, scalar_type),
1501         r.deviceOptional(2),
1502         r.pyobject(0),
1503         /*copy_variables=*/false,
1504         /*copy_numpy=*/false,
1505         /*type_inference=*/type_inference);
1506   }
1507   throw std::runtime_error("as_tensor(): invalid arguments");
1508 }
1509 
1510 Tensor new_tensor(
1511     c10::DispatchKey dispatch_key,
1512     at::ScalarType scalar_type,
1513     PyObject* args,
1514     PyObject* kwargs) {
1515   static PythonArgParser parser({
1516       "new_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)",
1517   });
1518 
1519   ParsedArgs<4> parsed_args;
1520   auto r = parser.parse(args, kwargs, parsed_args);
1521   if (r.idx == 0) {
1522     PyObject* data = r.pyobject(0);
1523     if (THPVariable_Check(data)) {
1524       auto ret = PyErr_WarnEx(
1525           PyExc_UserWarning,
1526           "To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() "
1527           "or sourceTensor.clone().detach().requires_grad_(True), rather than tensor.new_tensor(sourceTensor).",
1528           1);
1529       if (ret != 0)
1530         throw python_error();
1531     }
1532 
1533     bool args_requires_grad = r.toBool(3);
1534     auto new_tensor = new_from_data_copy(
1535         typeIdWithDefault(r, 2, dispatch_key),
1536         r.scalartypeWithDefault(1, scalar_type),
1537         r.deviceOptional(2),
1538         data);
1539     new_tensor.detach_(); // ensure new_tensor is a leaf node
1540     new_tensor.set_requires_grad(args_requires_grad);
1541     return new_tensor;
1542   }
1543   throw std::runtime_error("new_tensor(): invalid arguments");
1544 }
1545 
1546 Tensor tensor_frombuffer(
1547     PyObject* buffer,
1548     ScalarType dtype,
1549     int64_t count,
1550     int64_t offset,
1551     bool requires_grad) {
1552   auto elsize = at::elementSize(dtype);
1553   size_t actual_count = 0;
1554 
1555   Py_buffer view;
1556   if (PyObject_GetBuffer(buffer, &view, PyBUF_WRITABLE) < 0) {
1557     TORCH_CHECK(
1558         PyObject_GetBuffer(buffer, &view, PyBUF_SIMPLE) >= 0,
1559         "could not retrieve buffer from object");
1560     TORCH_WARN_ONCE(
1561         "The given buffer is not writable, and PyTorch does "
1562         "not support non-writable tensors. This means you can write to the "
1563         "underlying (supposedly non-writable) buffer using the tensor. "
1564         "You may want to copy the buffer to protect its data or make it writable "
1565         "before converting it to a tensor. This type of warning will be "
1566         "suppressed for the rest of this program.");
1567     PyErr_Clear();
1568   }
1569 
1570   Py_INCREF(view.obj);
1571   THPObjectPtr obj(view.obj);
1572 
1573   auto len = view.len;
1574   auto buf = view.buf;
1575   PyBuffer_Release(&view);
1576 
1577   TORCH_CHECK_VALUE(
1578       len > 0 && count != 0,
1579       "both buffer length (",
1580       len,
1581       ") and count (",
1582       count,
1583       ") must not be 0");
1584   TORCH_CHECK_VALUE(
1585       offset >= 0 && offset < len,
1586       "offset (",
1587       offset,
1588       " bytes) must be non-negative and no greater than "
1589       "buffer length (",
1590       len,
1591       " bytes) minus 1");
1592   TORCH_CHECK_VALUE(
1593       count > 0 || (len - offset) % elsize == 0,
1594       "buffer length (",
1595       len - offset,
1596       " bytes) after offset (",
1597       offset,
1598       " bytes) "
1599       "must be a multiple of element size (",
1600       elsize,
1601       ")");
1602 
1603   if (count < 0) {
1604     actual_count = (len - offset) / elsize;
1605   } else {
1606     actual_count = static_cast<size_t>(count);
1607   }
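  // Editorial worked example (illustrative): with len = 16 bytes, offset = 4
  // and elsize = 4 (e.g. torch.float32), a negative count infers
  // actual_count = (16 - 4) / 4 = 3 elements, and the check below then
  // requires 4 + 3 * 4 = 16 <= 16, which holds.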
1608 
1609   TORCH_CHECK_VALUE(
1610       static_cast<size_t>(offset) + actual_count * elsize <=
1611           static_cast<size_t>(len),
1612       "requested buffer length (",
1613       actual_count,
1614       " * ",
1615       elsize,
1616       " bytes) "
1617       "after offset (",
1618       offset,
1619       " bytes) must not be greater than actual "
1620       "buffer length (",
1621       len,
1622       " bytes)");
1623 
1624   auto offset_buf = static_cast<char*>(buf) + offset;
1625   auto options = TensorOptions().dtype(dtype).device(c10::kCPU);
1626 
1627   auto tensor = at::for_blob(offset_buf, static_cast<int64_t>(actual_count))
1628                     .options(options)
1629                     .deleter([obj = obj.release()](void*) {
1630                       pybind11::gil_scoped_acquire gil;
1631                       Py_DECREF(obj);
1632                     })
1633                     .make_tensor();
1634   tensor.set_requires_grad(requires_grad);
1635   return tensor;
1636 }
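// Editorial usage sketch (Python-level, illustrative; assumes the public
// torch.frombuffer binding that forwards here):
//
//   buf = bytearray(8)                            # writable buffer
//   t = torch.frombuffer(buf, dtype=torch.uint8)  # aliases buf's memory
//   t[0] = 255                                    # also mutates buf[0]
//
// Passing a read-only object such as bytes still works, but emits the
// non-writable warning above because the tensor can write through it.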
1637 
1638 Tensor tensor_fromDLPack(PyObject* data) {
1639   DLManagedTensor* dlMTensor =
1640       (DLManagedTensor*)PyCapsule_GetPointer(data, "dltensor");
1641   TORCH_CHECK(
1642       dlMTensor,
1643       "from_dlpack received an invalid capsule. "
1644       "Note that DLTensor capsules can be consumed only once, "
1645       "so you might have already constructed a tensor from it once.");
1646 
1647   auto deleter_with_gil = [dlMTensor](void*) {
1648     if (dlMTensor->deleter) {
1649       pybind11::gil_scoped_acquire gil;
1650       dlMTensor->deleter(dlMTensor);
1651     }
1652   };
1653 
1654   // atensor steals the ownership of the underlying storage. It also passes a
1655   // destructor function that will be called when the underlying storage goes
1656   // out of scope. When the destructor is called, the dlMTensor is destructed
1657   // too.
1658   // HACK: Ensure that we hold the GIL here just in case the managed
1659   // tensor originates from a buggy NumPy build.
1660   auto atensor = torch::utils::is_numpy_dlpack_deleter_bugged()
1661       ? at::fromDLPack(dlMTensor, std::move(deleter_with_gil))
1662       : at::fromDLPack(dlMTensor);
1663 
1664   // Make sure this capsule will never be used again.
1665   PyCapsule_SetName(data, "used_dltensor");
1666 
1667   // It is possible that the call to at::fromDLPack is the very first
1668   // call to create a Tensor in PyTorch. If so, then _lazy_init has
1669   // not been called, and the attempt to call createPyObject will fail
1670   // because cuda ATen types have not been registered in Python yet.
1671   // So, if we have a CUDA tensor, we need to make sure _lazy_init
1672   // has been called here.
1673   maybe_initialize_device(atensor.device());
1674   return atensor;
1675 }
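// Editorial usage sketch (Python-level, illustrative): a "dltensor" capsule
// can be consumed only once, e.g.
//
//   capsule = producer_array.__dlpack__()   # producer-specific
//   t = torch.from_dlpack(capsule)          # capsule renamed "used_dltensor"
//   torch.from_dlpack(capsule)              # now fails: capsule already used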
1676 
1677 Tensor asarray(
1678     PyObject* obj,
1679     std::optional<ScalarType> dtype,
1680     std::optional<Device> device,
1681     std::optional<bool> copy,
1682     bool requires_grad) {
1683   Tensor tensor;
1684 
1685   bool force_copy = copy.value_or(false);
1686   bool force_alias = !copy.value_or(true);
1687   bool should_warn_numpy_not_writable = false;
1688 
1689   // Used when:
1690   // 1. 'obj' implements the buffer protocol and no type is given.
1691   // 2. creating a new tensor from a Python sequence.
1692   auto dtype_unwrapped =
1693       dtype.value_or(torch::tensors::get_default_scalar_type());
1694 
1695   // Check whether 'obj' is a 'Tensor'
1696   if (THPVariable_Check(obj)) {
1697     tensor = THPVariable_Unpack(obj);
1698   }
1699 
1700 #ifdef USE_NUMPY
1701   if (is_numpy_available()) {
1702     // Check whether 'obj' is a NumPy Array or Scalar.
1703     bool is_numpy_array = PyArray_Check(obj);
1704     bool is_numpy_scalar = PyArray_CheckScalar(obj);
1705 
1706     if (is_numpy_array || is_numpy_scalar) {
1707       THPObjectPtr ptr;
1708       auto arr = obj;
1709 
1710       // PyArray_CheckScalar is true for both scalars and 0-dim arrays, per
1711       // https://numpy.org/devdocs/reference/c-api/array.html#c.PyArray_CheckScalar
1712       // But for 0-dim arrays no `PyArray_FromScalar` call is needed
1713       if (is_numpy_scalar && !is_numpy_array) {
1714         TORCH_CHECK(
1715             !force_alias,
1716             "can't alias NumPy scalars. ",
1717             "Either remove copy=False or convert it to an ndarray.");
1718 
1719         ptr = PyArray_FromScalar(obj, nullptr);
1720         arr = ptr.get();
1721       }
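      // Editorial note (illustrative): np.float64(1.0) is a NumPy scalar,
      // while np.array(1.0) is a 0-dim ndarray; only the former takes the
      // PyArray_FromScalar path above, and e.g.
      // torch.asarray(np.float64(1.0), copy=False) is rejected because a
      // scalar's storage cannot be aliased.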
1722 
1723       tensor = tensor_from_numpy(arr, /*warn_if_not_writeable=*/false);
1724       should_warn_numpy_not_writable =
1725           !PyArray_ISWRITEABLE((PyArrayObject*)arr);
1726 
1727       if (is_numpy_scalar) {
1728         // Use a newly cloned storage instead of the shared one: the
1729         // THPObjectPtr releases the temporary array (and its storage)
1730         // when it goes out of scope.
1731         tensor = tensor.clone();
1732 
1733         // No need to clone again, later.
1734         force_copy = false;
1735       }
1736     }
1737   }
1738 #endif
1739 
1740   // Check whether 'obj' is a 'DLPack' capsule
1741   if (!tensor.defined() && PyCapsule_IsValid(obj, "dltensor") != 0) {
1742     tensor = tensor_fromDLPack(obj);
1743   }
1744 
1745   // Check whether 'obj' implements the buffer protocol
1746   if (!tensor.defined() && PyObject_CheckBuffer(obj) != 0) {
1747     tensor = tensor_frombuffer(obj, dtype_unwrapped, -1, 0, requires_grad);
1748   }
1749 
1750   if (tensor.defined()) {
1751     // Given an aliasable tensor, should we copy it?
1752     bool wrong_device = device.has_value() && device.value() != tensor.device();
1753     bool wrong_dtype =
1754         dtype.has_value() && dtype.value() != tensor.scalar_type();
1755     bool needs_copying = !copy.has_value() && (wrong_device || wrong_dtype);
1756 
1757     // Given a defined tensor, we copy it if either we have to (copy=True) or
1758     // if we need to (copy=None) because of mismatched device or dtype.
1759     if (force_copy || needs_copying) {
1760       if (wrong_device || wrong_dtype) {
1761         tensor = tensor.to(
1762             device.value_or(tensor.device()),
1763             dtype.value_or(tensor.scalar_type()),
1764             /*non_blocking=*/false,
1765             /*copy=*/force_copy);
1766       } else {
1767         tensor = tensor.clone();
1768       }
1769     } else {
1770       // If we are not copying, we have to check whether the tensor is on
1771       // the right device and has the right dtype.
1772       TORCH_CHECK_VALUE(
1773           !wrong_device,
1774           "can't alias tensor from device '",
1775           tensor.device(),
1776           "' to '",
1777           device.value(),
1778           "'.");
1779       TORCH_CHECK_VALUE(
1780           !wrong_dtype,
1781           "can't alias tensor with dtype '",
1782           tensor.scalar_type(),
1783           "' into dtype '",
1784           dtype.value(),
1785           "'.");
1786       // If the tensor is a view of a non-writeable NumPy array, warn the
1787       // user about it.
1788       if (should_warn_numpy_not_writable) {
1789         warn_numpy_not_writeable();
1790       }
1791     }
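    // Editorial summary (illustrative) of the copy decision above:
    //   copy=True  -> always copy (via .to(..., copy=True) or clone())
    //   copy=None  -> copy only if device or dtype does not match
    //   copy=False -> never copy; a device/dtype mismatch raises instead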
1792 
1793     // Setting 'requires_grad' when the tensor is not a leaf does not work.
1794     // Whenever that happens, we have to use 'detach'.
1795     if (!tensor.is_leaf() && !requires_grad) {
1796       tensor = tensor.detach();
1797     } else {
1798       tensor.set_requires_grad(requires_grad);
1799     }
1800   } else {
1801     // An undefined tensor means 'obj' implements neither DLPack nor the
1802     // buffer protocol. The only remaining case is a sequence, which we
1803     // must copy (so copy cannot be False).
1804     TORCH_CHECK_VALUE(
1805         !force_alias, "can't alias arbitrary sequence into a tensor.");
1806 
1807     // Make tensor from sequence, inferring its type, and then convert
1808     // it to the desired type.
1809     // Type inference is activated only if the dtype has not been specified.
1810     // Otherwise, we force the unwrapped dtype.
1811     tensor = internal_new_from_data(
1812         TensorOptions(),
1813         dtype_unwrapped,
1814         device,
1815         obj,
1816         /* copy_variables = */ false,
1817         /* copy_numpy = */ false,
1818         /* type_inference = */ !dtype.has_value());
1819     tensor.set_requires_grad(requires_grad);
1820   }
1821 
1822   return tensor;
1823 }
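// Editorial usage sketch (Python-level, illustrative; assumes the public
// torch.asarray binding that forwards here):
//
//   a = numpy.ones(3)                           # float64 array
//   t = torch.asarray(a)                        # aliases 'a' when possible
//   t = torch.asarray(a, dtype=torch.float32)   # dtype mismatch forces a copy
//   t = torch.asarray(a, copy=True)             # always copies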
1824 
1825 bool only_lift_cpu_tensors() {
1826   return kOnlyLiftCPUTensors;
1827 }
1828 
1829 void set_only_lift_cpu_tensors(bool value) {
1830   kOnlyLiftCPUTensors = value;
1831 }
1832 
1833 } // namespace torch::utils
1834