1 #pragma once 2 3 #include <c10/core/Device.h> 4 #include <c10/core/DispatchKey.h> 5 #include <c10/core/ScalarType.h> 6 #include <torch/csrc/python_headers.h> 7 8 namespace at { 9 class Tensor; 10 } // namespace at 11 12 namespace torch { 13 namespace tensors { 14 15 // Initializes the Python tensor type objects: torch.FloatTensor, 16 // torch.DoubleTensor, etc. and binds them in their containing modules. 17 void initialize_python_bindings(); 18 19 // Same as set_default_tensor_type() but takes a PyObject* 20 void py_set_default_tensor_type(PyObject* type_obj); 21 22 // Same as py_set_default_tensor_type, but only changes the dtype (ScalarType). 23 void py_set_default_dtype(PyObject* dtype_obj); 24 25 // Gets the DispatchKey for the default tensor type. 26 // 27 // TODO: This is nuts! There is no reason to let the default tensor type id 28 // change. Probably only store ScalarType, as that's the only flex point 29 // we support. 30 TORCH_API c10::DispatchKey get_default_dispatch_key(); 31 at::Device get_default_device(); 32 33 // Gets the ScalarType for the default tensor type. 34 at::ScalarType get_default_scalar_type(); 35 } // namespace tensors 36 } // namespace torch 37