1 #include <ATen/core/type_factory.h>
2
3 #include <ATen/core/jit_type.h>
4
5 namespace c10 {
6
// Table of the base type names understood by the type factories in this
// file, each paired with the c10 type class it maps to.
//
// FORALL_BASE_PYTHON_TYPES(_) invokes the callback macro `_` once per entry,
// as `_(NAME, TYPE)`:
//   - NAME is the spelling used as the lookup key (the factories below
//     stringify it with `#NAME`);
//   - TYPE is the name of a class in namespace c10 (used as `c10::TYPE`).
//
// Dtype constraints are not enforced during compilation. Therefore, we map
// all tensor subclasses with different dtypes (LongTensor, FloatTensor, ...)
// to the same underlying TensorType. Users of these subclasses are meant to be
// warned about a possible dtype change whenever they use one of them (that
// warning is not emitted in this file).
//
// Technically "number" is not a Python type, but we need it when parsing
// serialized methods that use implicit conversions to Scalar.
//
// Both "None" and "NoneType" map to NoneType; "list" and "tuple" map to
// AnyListType and AnyTupleType respectively.
//
// NOTE: do not add `//` comments on the continuation lines below. Line
// splicing happens before comment removal, so the trailing backslash would
// pull the next entry into the comment and silently drop it.
#define FORALL_BASE_PYTHON_TYPES(_) \
  _(Tensor, TensorType) \
  _(LongTensor, TensorType) \
  _(DoubleTensor, TensorType) \
  _(FloatTensor, TensorType) \
  _(IntTensor, TensorType) \
  _(ShortTensor, TensorType) \
  _(HalfTensor, TensorType) \
  _(CharTensor, TensorType) \
  _(ByteTensor, TensorType) \
  _(BoolTensor, TensorType) \
  _(int, IntType) \
  _(float, FloatType) \
  _(bool, BoolType) \
  _(complex, ComplexType) \
  _(str, StringType) \
  _(Device, DeviceObjType) \
  _(Generator, GeneratorType) \
  _(Stream, StreamObjType) \
  _(number, NumberType) \
  _(None, NoneType) \
  _(NoneType, NoneType) \
  _(Any, AnyType) \
  _(Capsule, CapsuleType) \
  _(list, AnyListType) \
  _(tuple, AnyTupleType)
40
41 const std::unordered_map<std::string, c10::TypePtr>& DynamicTypeFactory::
basePythonTypes()42 basePythonTypes() {
43 static const std::unordered_map<std::string, c10::TypePtr> map = {
44 #define MAP_ITEM(NAME, TYPE) \
45 {#NAME, c10::DynamicTypeTrait<c10::TYPE>::getBaseType()},
46 FORALL_BASE_PYTHON_TYPES(MAP_ITEM)
47 #undef MAP_ITEM
48 };
49 return map;
50 }
51
52 const std::unordered_map<std::string, c10::TypePtr>& DefaultTypeFactory::
basePythonTypes()53 basePythonTypes() {
54 static const std::unordered_map<std::string, c10::TypePtr> map = {
55 #define MAP_ITEM(NAME, TYPE) {#NAME, c10::TYPE::get()},
56 FORALL_BASE_PYTHON_TYPES(MAP_ITEM)
57 #undef MAP_ITEM
58 };
59 return map;
60 }
61
createNamedTuple(const std::string & name,const std::vector<c10::string_view> & fields,const std::vector<c10::TypePtr> & types)62 c10::TypePtr DefaultTypeFactory::createNamedTuple(
63 const std::string& name,
64 const std::vector<c10::string_view>& fields,
65 const std::vector<c10::TypePtr>& types) {
66 return c10::TupleType::createNamed(name, fields, types);
67 }
68
69 } // namespace c10
70