xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/type_factory.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/type_factory.h>
2 
3 #include <ATen/core/jit_type.h>
4 
5 namespace c10 {
6 
// Table of the Python-level type names recognized by the type factories,
// written as an X-macro: each entry is `_(PythonName, JitType)`. The
// factories in this file stringify `PythonName` into a lookup key and use
// `JitType` as the name of a c10 type class.
//
// Dtypes are not enforced at compile time, so every dtype-specific tensor
// class (LongTensor, FloatTensor, ...) is mapped onto the same underlying
// TensorType. Users of these subclasses are meant to be warned that the dtype
// may change. NOTE(review): that warning is not issued in this file.
//
// `number` is not a real Python type, but it is needed when parsing
// serialized methods that rely on implicit conversions to Scalar. `None` and
// `NoneType` are both accepted and resolve to the same NoneType.
#define FORALL_BASE_PYTHON_TYPES(_)                         \
  /* Tensor and its dtype-specific aliases. */              \
  _(Tensor, TensorType)                                     \
  _(LongTensor, TensorType)                                 \
  _(DoubleTensor, TensorType)                               \
  _(FloatTensor, TensorType)                                \
  _(IntTensor, TensorType)                                  \
  _(ShortTensor, TensorType)                                \
  _(HalfTensor, TensorType)                                 \
  _(CharTensor, TensorType)                                 \
  _(ByteTensor, TensorType)                                 \
  _(BoolTensor, TensorType)                                 \
  /* Python builtin scalars and strings. */                 \
  _(int, IntType)                                           \
  _(float, FloatType)                                       \
  _(bool, BoolType)                                         \
  _(complex, ComplexType)                                   \
  _(str, StringType)                                        \
  /* Device, Generator and Stream objects. */               \
  _(Device, DeviceObjType)                                  \
  _(Generator, GeneratorType)                               \
  _(Stream, StreamObjType)                                  \
  /* Pseudo-type for implicit conversions to Scalar. */     \
  _(number, NumberType)                                     \
  /* Two accepted spellings of None. */                     \
  _(None, NoneType)                                         \
  _(NoneType, NoneType)                                     \
  /* Any and Capsule. */                                    \
  _(Any, AnyType)                                           \
  _(Capsule, CapsuleType)                                   \
  /* Bare container names. */                               \
  _(list, AnyListType)                                      \
  _(tuple, AnyTupleType)
40 
41 const std::unordered_map<std::string, c10::TypePtr>& DynamicTypeFactory::
basePythonTypes()42     basePythonTypes() {
43   static const std::unordered_map<std::string, c10::TypePtr> map = {
44 #define MAP_ITEM(NAME, TYPE) \
45   {#NAME, c10::DynamicTypeTrait<c10::TYPE>::getBaseType()},
46       FORALL_BASE_PYTHON_TYPES(MAP_ITEM)
47 #undef MAP_ITEM
48   };
49   return map;
50 }
51 
52 const std::unordered_map<std::string, c10::TypePtr>& DefaultTypeFactory::
basePythonTypes()53     basePythonTypes() {
54   static const std::unordered_map<std::string, c10::TypePtr> map = {
55 #define MAP_ITEM(NAME, TYPE) {#NAME, c10::TYPE::get()},
56       FORALL_BASE_PYTHON_TYPES(MAP_ITEM)
57 #undef MAP_ITEM
58   };
59   return map;
60 }
61 
createNamedTuple(const std::string & name,const std::vector<c10::string_view> & fields,const std::vector<c10::TypePtr> & types)62 c10::TypePtr DefaultTypeFactory::createNamedTuple(
63     const std::string& name,
64     const std::vector<c10::string_view>& fields,
65     const std::vector<c10::TypePtr>& types) {
66   return c10::TupleType::createNamed(name, fields, types);
67 }
68 
69 } // namespace c10
70