xref: /aosp_15_r20/external/pytorch/torch/csrc/tensor/python_tensor.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/core/Device.h>
4 #include <c10/core/DispatchKey.h>
5 #include <c10/core/ScalarType.h>
6 #include <torch/csrc/python_headers.h>
7 
// Forward declaration only: none of the declarations in this header name
// at::Tensor, so the full definition is not pulled in here.
// NOTE(review): presumably kept for files that include this header and rely
// on the forward declaration -- verify before removing.
namespace at { class Tensor; } // namespace at
11 
12 namespace torch {
13 namespace tensors {
14 
15 // Initializes the Python tensor type objects: torch.FloatTensor,
16 // torch.DoubleTensor, etc. and binds them in their containing modules.
17 void initialize_python_bindings();
18 
19 // Same as set_default_tensor_type() but takes a PyObject*
20 void py_set_default_tensor_type(PyObject* type_obj);
21 
22 // Same as py_set_default_tensor_type, but only changes the dtype (ScalarType).
23 void py_set_default_dtype(PyObject* dtype_obj);
24 
25 // Gets the DispatchKey for the default tensor type.
26 //
27 // TODO: This is nuts!  There is no reason to let the default tensor type id
28 // change.  Probably only store ScalarType, as that's the only flex point
29 // we support.
30 TORCH_API c10::DispatchKey get_default_dispatch_key();
31 at::Device get_default_device();
32 
33 // Gets the ScalarType for the default tensor type.
34 at::ScalarType get_default_scalar_type();
35 } // namespace tensors
36 } // namespace torch
37