#include <torch/csrc/jit/ir/constants.h>

#include <ATen/core/functional.h>
#include <torch/csrc/autograd/variable.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>

namespace torch::jit {

static bool insertableTensor(const at::Tensor& ten) {
  // Bail if the tensor requires gradients (we have no way of serializing
  // them, and they are mutable), has no storage (i.e. an opaque tensor, as
  // used in MKL-DNN), or is nested.
  return !ten.requires_grad() && ten.has_storage() && !ten.is_nested();
}
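
// A minimal usage sketch (illustrative only, not part of this file; assumes
// the usual libtorch factory functions):
//
//   at::Tensor plain = torch::ones({2, 2});
//   at::Tensor grad = torch::ones({2, 2}, torch::requires_grad());
//   insertableTensor(plain); // true: dense storage, no autograd, not nested
//   insertableTensor(grad);  // false: tensors that require grad cannot be
//                            // serialized as constants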

static bool insertableIValue(const IValue& ivalue) {
  if (ivalue.isInt() || ivalue.isNone() || ivalue.isBool() ||
      ivalue.isDouble() || ivalue.isComplexDouble() || ivalue.isString() ||
      ivalue.isDevice() || ivalue.isEnum()) {
    return true;
  }
  if (ivalue.isTensor()) {
    return insertableTensor(ivalue.toTensor());
  }
  if (ivalue.isList() || ivalue.isTuple()) {
    c10::ArrayRef<IValue> elems;
    if (ivalue.isTuple()) {
      elems = ivalue.toTupleRef().elements();
    } else {
      elems = ivalue.toListRef();
    }
    return std::all_of(elems.begin(), elems.end(), [](const IValue& tup_elem) {
      return insertableIValue(tup_elem);
    });
  }
  if (ivalue.isGenericDict()) {
    const auto& dict = ivalue.toGenericDict();
    return std::all_of(dict.begin(), dict.end(), [](const auto& entry) {
      return insertableIValue(entry.key()) && insertableIValue(entry.value());
    });
  }

  return false;
}
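
// A minimal sketch of the recursive check on containers (hypothetical
// values, assuming the usual c10::List and IValue constructors):
//
//   c10::List<int64_t> ints({1, 2, 3});
//   insertableIValue(IValue(ints));  // true: every leaf is an int
//
//   c10::List<at::Tensor> grads;
//   grads.push_back(torch::ones({1}, torch::requires_grad()));
//   insertableIValue(IValue(grads)); // false: a leaf tensor fails
//                                    // insertableTensor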

Value* insertConstant(
    Graph& g,
    const IValue& val,
    std::optional<SourceRange> loc,
    std::optional<ScopePtr> scope) {
  auto value = tryInsertConstant(g, val, std::move(loc), std::move(scope));
  if (value) {
    return *value;
  }
  throw constant_not_supported_error(
      "Unsupported value kind: " + val.tagKind());
}
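
// A minimal usage sketch (a hedged example, not part of this file; assumes
// the defaulted loc/scope arguments declared in constants.h):
//
//   auto g = std::make_shared<Graph>();
//   Value* three = insertConstant(*g, IValue(static_cast<int64_t>(3)));
//   // three->node()->kind() == prim::Constant
//   // three->type() == IntType::get()
//
// Unsupported values (e.g. a tensor with requires_grad set) make this
// overload throw constant_not_supported_error; see tryInsertConstant below
// for the non-throwing variant.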

// IValue -> Constant node
std::optional<Value*> tryInsertConstant(
    Graph& g,
    const IValue& val,
    std::optional<SourceRange> loc,
    std::optional<ScopePtr> scope) {
  Node* n = g.create(prim::Constant);
  if (val.isTensor()) {
    at::Tensor ref = val.toTensor();
    if (!insertableTensor(ref)) {
      n->destroy();
      return std::nullopt;
    }
    // An undefined tensor is represented as a None constant.
    if (!ref.defined()) {
      n->destroy();
      return g.insertNode(g.createNone())->output();
    }
    TORCH_INTERNAL_ASSERT(!ref.requires_grad());
    n->output()->inferTypeFrom(
        ref); // note: before t_ because of std::move(ref)
    n->t_(attr::value, std::move(ref));
  } else if (val.isInt()) {
    n->i_(attr::value, val.toInt());
    n->output()->setType(IntType::get());
  } else if (val.isDouble()) {
    n->f_(attr::value, val.toDouble());
    n->output()->setType(FloatType::get());
  } else if (val.isComplexDouble()) {
    n->c_(attr::value, val.toComplexDouble());
    n->output()->setType(ComplexType::get());
  } else if (val.isBool()) {
    n->i_(attr::value, val.toBool());
    n->output()->setType(BoolType::get());
  } else if (val.isList()) {
    // Lists of primitives are always insertable; other lists must pass the
    // recursive insertableIValue check.
    bool fast_path_list =
        val.isBoolList() || val.isIntList() || val.isDoubleList();
    if (fast_path_list || insertableIValue(val)) {
      n->ival_(attr::value, val);
      n->output()->setType(val.type());
    } else {
      n->destroy();
      return std::nullopt;
    }
  } else if (val.isString()) {
    n->s_(attr::value, val.toStringRef());
    n->output()->setType(StringType::get());
  } else if (val.isDevice()) {
    std::stringstream ss;
    ss << val.toDevice();
    n->s_(attr::value, ss.str());
    n->output()->setType(DeviceObjType::get());
  } else if (val.isGenerator()) {
    auto generator = val.toGenerator();
    n->ival_(attr::value, generator);
    n->output()->setType(GeneratorType::get());
  } else if (val.isStream()) {
    // packing into int64_t removed
    n->ival_(attr::value, val);
    n->output()->setType(StreamObjType::get());
  } else if (val.isNone()) {
    n->output()->setType(NoneType::get());
  } else if (val.isTuple()) {
    if (insertableIValue(val)) {
      n->ival_(attr::value, val);
      n->output()->setType(val.type());
    } else {
      n->destroy();
      return std::nullopt;
    }
  } else if (val.isObject()) {
    const auto& ref = val.toObjectRef();
    // see: [Constant Object Weak CompilationUnit Reference]
    if (!ref.type()->is_module() &&
        (ref.is_weak_compilation_ref() ||
         ref.is_empty_strong_compilation_ref())) {
      n->ival_(attr::value, val);
      n->output()->setType(val.type());
    } else {
      n->destroy();
      return std::nullopt;
    }
  } else if ((val.isGenericDict() && insertableIValue(val)) || (val.isEnum())) {
    n->ival_(attr::value, val);
    n->output()->setType(val.type());
  } else {
    n->destroy();
    return std::nullopt;
  }
  if (loc)
    n->setSourceRange(*loc);
  if (scope)
    n->setScope(*scope);
  return g.insertNode(n)->output();
}
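
// A minimal sketch of the non-throwing variant (hypothetical values):
//
//   auto g = std::make_shared<Graph>();
//   if (auto v = tryInsertConstant(*g, IValue(std::string("hi")))) {
//     // (*v)->type() == StringType::get()
//   }
//   // Values that cannot be serialized, e.g. a tensor with requires_grad,
//   // yield std::nullopt instead of throwing.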

std::optional<IValue> toIValue(const Value* v) {
  if (v->node()->kind() != prim::Constant || v->type()->cast<FunctionType>()) {
    return std::nullopt;
  }
  const Node* node = v->node();
  const TypePtr& type = v->type();
  if (type->isSubtypeOf(*TensorType::get())) {
    return node->t(attr::value);
  } else if (type->isSubtypeOf(*BoolType::get())) {
    return (bool)node->i(attr::value);
  } else if (
      type->isSubtypeOf(*NumberType::get()) &&
      node->kindOf(attr::value) == AttributeKind::i) {
    return node->i(attr::value);
  } else if (
      type->isSubtypeOf(*NumberType::get()) &&
      node->kindOf(attr::value) == AttributeKind::f) {
    return node->f(attr::value);
  } else if (
      type->isSubtypeOf(*NumberType::get()) &&
      node->kindOf(attr::value) == AttributeKind::c) {
    return node->c(attr::value);
  } else if (
      type->cast<ListType>() &&
      node->kindOf(attr::value) == AttributeKind::ival) {
    const auto& list = node->ival(attr::value);
    TORCH_INTERNAL_ASSERT(list.isList());
    return list;
  } else if (
      type->cast<DictType>() &&
      node->kindOf(attr::value) == AttributeKind::ival) {
    const auto& dict = node->ival(attr::value);
    TORCH_INTERNAL_ASSERT(dict.isGenericDict());
    return dict;
  } else if (
      type->cast<TupleType>() &&
      node->kindOf(attr::value) == AttributeKind::ival) {
    const auto& tup = node->ival(attr::value);
    TORCH_INTERNAL_ASSERT(tup.isTuple());
    return tup;
  } else if (type == StringType::get()) {
    const auto& s = node->s(attr::value);
    return s;
  } else if (type == DeviceObjType::get()) {
    auto d = c10::Device(node->s(attr::value));
    return d;
  } else if (type == GeneratorType::get()) {
    auto generator = node->ival(attr::value).toGenerator();
    return generator;
  } else if (type == StreamObjType::get()) {
    // int64_t packing removed
    auto s = node->ival(attr::value).toStream();
    return s;
  } else if (node->mustBeNone()) {
    return IValue();
  } else if (type->cast<EnumType>()) {
    const auto& enum_val = node->ival(attr::value);
    return enum_val;
  } else if (type->cast<ClassType>() && !type->is_module()) {
    const auto& class_val = node->ival(attr::value);
    return class_val;
  } else {
    std::stringstream ss;
    ss << "constant literal not supported for: " << type->str();
    throw std::runtime_error(ss.str());
  }
}
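
// A minimal round-trip sketch (hypothetical values): a constant inserted
// with insertConstant can be read back as an IValue.
//
//   auto g = std::make_shared<Graph>();
//   Value* v = insertConstant(*g, IValue(2.5));
//   std::optional<IValue> iv = toIValue(v);
//   // iv.has_value() && iv->toDouble() == 2.5
//
//   // Non-constant Values produce std::nullopt:
//   toIValue(g->addInput()); // nullopt: node kind is prim::Param,
//                            // not prim::Constant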

} // namespace torch::jit