xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/ir/constants.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/ir/constants.h>
2 
3 #include <ATen/core/functional.h>
4 #include <torch/csrc/autograd/variable.h>
5 #include <torch/csrc/jit/ir/ir.h>
6 #include <torch/csrc/jit/runtime/custom_operator.h>
7 #include <torch/csrc/jit/runtime/operator.h>
8 #include <torch/csrc/jit/runtime/register_ops_utils.h>
9 
10 namespace torch::jit {
11 
insertableTensor(const at::Tensor & ten)12 static bool insertableTensor(const at::Tensor& ten) {
13   // bail if tensor has no storage i.e. opaque tensor used in MKLdnn.
14   // or gradients because we have no way of serializing them & are mutable
15   return !ten.requires_grad() && ten.has_storage() && !ten.is_nested();
16 }
17 
insertableIValue(const IValue & ivalue)18 static bool insertableIValue(const IValue& ivalue) {
19   if (ivalue.isInt() || ivalue.isNone() || ivalue.isBool() ||
20       ivalue.isDouble() || ivalue.isComplexDouble() || ivalue.isString() ||
21       ivalue.isDevice() || ivalue.isEnum()) {
22     return true;
23   }
24   if (ivalue.isTensor()) {
25     return insertableTensor(ivalue.toTensor());
26   }
27   if (ivalue.isList() || ivalue.isTuple()) {
28     c10::ArrayRef<IValue> elems;
29     if (ivalue.isTuple()) {
30       elems = ivalue.toTupleRef().elements();
31     } else {
32       elems = ivalue.toListRef();
33     }
34     return std::all_of(elems.begin(), elems.end(), [](const IValue& tup_elem) {
35       return insertableIValue(tup_elem);
36     });
37   }
38   if (ivalue.isGenericDict()) {
39     const auto& dict = ivalue.toGenericDict();
40     return std::all_of(dict.begin(), dict.end(), [](const auto& entry) {
41       return insertableIValue(entry.key()) && insertableIValue(entry.value());
42     });
43   }
44 
45   return false;
46 }
47 
insertConstant(Graph & g,const IValue & val,std::optional<SourceRange> loc,std::optional<ScopePtr> scope)48 Value* insertConstant(
49     Graph& g,
50     const IValue& val,
51     std::optional<SourceRange> loc,
52     std::optional<ScopePtr> scope) {
53   auto value = tryInsertConstant(g, val, std::move(loc), std::move(scope));
54   if (value) {
55     return *value;
56   }
57   throw constant_not_supported_error(
58       "Unsupported value kind: " + val.tagKind());
59 }
60 
61 // IValue -> Constant node
tryInsertConstant(Graph & g,const IValue & val,std::optional<SourceRange> loc,std::optional<ScopePtr> scope)62 std::optional<Value*> tryInsertConstant(
63     Graph& g,
64     const IValue& val,
65     std::optional<SourceRange> loc,
66     std::optional<ScopePtr> scope) {
67   Node* n = g.create(prim::Constant);
68   if (val.isTensor()) {
69     at::Tensor ref = val.toTensor();
70     if (!insertableTensor(val.toTensor())) {
71       n->destroy();
72       return std::nullopt;
73     }
74     if (!ref.defined()) {
75       n->destroy();
76       return g.insertNode(g.createNone())->output();
77     }
78     TORCH_INTERNAL_ASSERT(!ref.requires_grad());
79     n->output()->inferTypeFrom(
80         ref); // note: before t_ because of std::move(ref)
81     n->t_(attr::value, std::move(ref));
82   } else if (val.isInt()) {
83     n->i_(attr::value, val.toInt());
84     n->output()->setType(IntType::get());
85   } else if (val.isDouble()) {
86     n->f_(attr::value, val.toDouble());
87     n->output()->setType(FloatType::get());
88   } else if (val.isComplexDouble()) {
89     n->c_(attr::value, val.toComplexDouble());
90     n->output()->setType(ComplexType::get());
91   } else if (val.isBool()) {
92     n->i_(attr::value, val.toBool());
93     n->output()->setType(BoolType::get());
94   } else if (val.isList()) {
95     bool fast_path_list =
96         val.isBoolList() || val.isIntList() || val.isDoubleList();
97     if (fast_path_list || insertableIValue(val)) {
98       n->ival_(attr::value, val);
99       n->output()->setType(val.type());
100     } else {
101       n->destroy();
102       return std::nullopt;
103     }
104   } else if (val.isString()) {
105     n->s_(attr::value, val.toStringRef());
106     n->output()->setType(StringType::get());
107   } else if (val.isDevice()) {
108     std::stringstream ss;
109     ss << val.toDevice();
110     n->s_(attr::value, ss.str());
111     n->output()->setType(DeviceObjType::get());
112   } else if (val.isGenerator()) {
113     auto generator = val.toGenerator();
114     n->ival_(attr::value, generator);
115     n->output()->setType(GeneratorType::get());
116   } else if (val.isStream()) {
117     // packing into int64_t removed
118     n->ival_(attr::value, val);
119     n->output()->setType(StreamObjType::get());
120   } else if (val.isNone()) {
121     n->output()->setType(NoneType::get());
122   } else if (val.isTuple()) {
123     if (insertableIValue(val)) {
124       n->ival_(attr::value, val);
125       n->output()->setType(val.type());
126     } else {
127       n->destroy();
128       return std::nullopt;
129     };
130   } else if (val.isObject()) {
131     const auto& ref = val.toObjectRef();
132     // see: [Constant Object Weak CompilationUnit Reference]
133     if (!ref.type()->is_module() &&
134         (ref.is_weak_compilation_ref() ||
135          ref.is_empty_strong_compilation_ref())) {
136       n->ival_(attr::value, val);
137       n->output()->setType(val.type());
138     } else {
139       n->destroy();
140       return std::nullopt;
141     }
142   } else if ((val.isGenericDict() && insertableIValue(val)) || (val.isEnum())) {
143     n->ival_(attr::value, val);
144     n->output()->setType(val.type());
145   } else {
146     n->destroy();
147     return std::nullopt;
148   }
149   if (loc)
150     n->setSourceRange(*loc);
151   if (scope)
152     n->setScope(*scope);
153   return g.insertNode(n)->output();
154 }
155 
toIValue(const Value * v)156 std::optional<IValue> toIValue(const Value* v) {
157   if (v->node()->kind() != prim::Constant || v->type()->cast<FunctionType>()) {
158     return std::nullopt;
159   }
160   const Node* node = v->node();
161   const TypePtr& type = v->type();
162   if (type->isSubtypeOf(*TensorType::get())) {
163     return node->t(attr::value);
164   } else if (type->isSubtypeOf(*BoolType::get())) {
165     return (bool)node->i(attr::value);
166   } else if (
167       type->isSubtypeOf(*NumberType::get()) &&
168       node->kindOf(attr::value) == AttributeKind::i) {
169     return node->i(attr::value);
170   } else if (
171       type->isSubtypeOf(*NumberType::get()) &&
172       node->kindOf(attr::value) == AttributeKind::f) {
173     return node->f(attr::value);
174   } else if (
175       type->isSubtypeOf(*NumberType::get()) &&
176       node->kindOf(attr::value) == AttributeKind::c) {
177     return node->c(attr::value);
178   } else if (
179       type->cast<ListType>() &&
180       node->kindOf(attr::value) == AttributeKind::ival) {
181     const auto& list = node->ival(attr::value);
182     TORCH_INTERNAL_ASSERT(list.isList());
183     return list;
184   } else if (
185       type->cast<DictType>() &&
186       node->kindOf(attr::value) == AttributeKind::ival) {
187     const auto& dict = node->ival(attr::value);
188     TORCH_INTERNAL_ASSERT(dict.isGenericDict());
189     return dict;
190   } else if (
191       type->cast<TupleType>() &&
192       node->kindOf(attr::value) == AttributeKind::ival) {
193     const auto& tup = node->ival(attr::value);
194     TORCH_INTERNAL_ASSERT(tup.isTuple());
195     return tup;
196   } else if (type == StringType::get()) {
197     const auto& s = node->s(attr::value);
198     return s;
199   } else if (type == DeviceObjType::get()) {
200     auto d = c10::Device(node->s(attr::value));
201     return d;
202   } else if (type == GeneratorType::get()) {
203     auto generator = node->ival(attr::value).toGenerator();
204     return generator;
205   } else if (type == StreamObjType::get()) {
206     // int64_t packing removed
207     auto s = node->ival(attr::value).toStream();
208     return s;
209   } else if (node->mustBeNone()) {
210     return IValue();
211   } else if (type->cast<EnumType>()) {
212     const auto& enum_val = node->ival(attr::value);
213     return enum_val;
214   } else if (type->cast<ClassType>() && !type->is_module()) {
215     const auto& class_val = node->ival(attr::value);
216     return class_val;
217   } else {
218     std::stringstream ss;
219     ss << "constant literal not supported for: " << type->str();
220     throw std::runtime_error(ss.str());
221   }
222 }
223 
224 } // namespace torch::jit
225