#include <torch/csrc/jit/ir/ir.h>

#include <algorithm>
#include <cmath>
#include <type_traits>
#include <unordered_map>
#include <utility>
#include <vector>

#include <ATen/core/functional.h>
#include <ATen/core/symbol.h>
#include <c10/util/Exception.h>
#include <c10/util/hash.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/ir/node_hashing.h>
#include <torch/csrc/jit/passes/common_subexpression_elimination.h>
13
14 namespace torch::jit {
15
16 namespace {
17
tensorEqual(const at::Tensor & lhs,const at::Tensor & rhs)18 bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) {
19 // type_equal doesnt distinguish between mkldnn/pytorch cpu tensors,
20 // and we dont want to coalesce mkldnn tensors bc they do layout
21 // transformations based on usage
22 if (lhs.is_mkldnn() || rhs.is_mkldnn()) {
23 return false;
24 }
25 if (lhs.is_nested() || rhs.is_nested()) {
26 return false;
27 }
28 // If device is not equal, lhs.equal(rhs) would throw an error.
29 if (lhs.device() != rhs.device()) {
30 return false;
31 }
32 return lhs.options().type_equal(rhs.options()) && lhs.equal(rhs);
33 }
34
typeListEqual(const std::vector<TypePtr> & lhs,const std::vector<TypePtr> & rhs)35 bool typeListEqual(
36 const std::vector<TypePtr>& lhs,
37 const std::vector<TypePtr>& rhs) {
38 if (lhs.size() != rhs.size())
39 return false;
40 for (const auto i : c10::irange(lhs.size())) {
41 if (*lhs[i] != *rhs[i]) {
42 return false;
43 }
44 }
45 return true;
46 }
47
// True when `T` exposes `real()` and `imag()` accessors, i.e. is a complex
// number type such as the value stored in `c` / `cs` attributes.
template <typename T, typename = void>
constexpr bool kHasRealImag = false;

template <typename T>
constexpr bool kHasRealImag<
    T,
    std::void_t<
        decltype(std::declval<const T&>().real()),
        decltype(std::declval<const T&>().imag())>> = true;

// Equality for scalar-valued attributes (e.g. i, f, c, s). Arguments are
// taken by const reference so string attributes are not copied on every
// comparison.
//
// Floating-point values (and both components of complex values) treat 0.0
// and -0.0 as different. `==` considers them equal, but they are observably
// distinct (e.g. 1.0 / x), so merging such constants would change results.
// NaN never compares equal, which only makes CSE more conservative.
template <typename attribute_type>
bool attributesEqual(const attribute_type& a1, const attribute_type& a2) {
  if constexpr (std::is_floating_point_v<attribute_type>) {
    return a1 == a2 && std::signbit(a1) == std::signbit(a2);
  } else if constexpr (kHasRealImag<attribute_type>) {
    return attributesEqual(a1.real(), a2.real()) &&
        attributesEqual(a1.imag(), a2.imag());
  } else {
    return a1 == a2;
  }
}

// Equality for list-valued attributes (e.g. is, fs, cs, ss). The lists must
// have the same length and be element-wise equal under the scalar rules
// above. Partial ordering prefers this overload over the scalar template for
// std::vector arguments.
template <typename T>
bool attributesEqual(const std::vector<T>& a1, const std::vector<T>& a2) {
  return std::equal(
      a1.begin(), a1.end(), a2.begin(), a2.end(), [](const T& x, const T& y) {
        return attributesEqual(x, y);
      });
}
52
attributesEqual(const at::Tensor & a1,const at::Tensor & a2)53 bool attributesEqual(const at::Tensor& a1, const at::Tensor& a2) {
54 return tensorEqual(a1, a2);
55 }
56
// Forward declaration: the IValue-based attributesEqual overloads that follow
// recurse into ivaluesEqual, whose definition appears after them.
bool ivaluesEqual(const IValue& a1, const IValue& a2);
58
attributesEqual(const std::vector<at::Tensor> & lhs,const std::vector<at::Tensor> & rhs)59 bool attributesEqual(
60 const std::vector<at::Tensor>& lhs,
61 const std::vector<at::Tensor>& rhs) {
62 if (lhs.size() != rhs.size())
63 return false;
64 return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual);
65 }
66
attributesEqual(at::ArrayRef<IValue> a1,at::ArrayRef<IValue> a2)67 bool attributesEqual(at::ArrayRef<IValue> a1, at::ArrayRef<IValue> a2) {
68 if (a1.size() != a2.size()) {
69 return false;
70 }
71 for (const auto i : c10::irange(a1.size())) {
72 if (!ivaluesEqual(a1[i], a2[i])) {
73 return false;
74 }
75 }
76 return true;
77 }
78
attributesEqual(const IValue & a1,const IValue & a2)79 bool attributesEqual(const IValue& a1, const IValue& a2) {
80 return ivaluesEqual(a1, a2);
81 }
82
83 // this is not a general-purpose comparison of IValues, it only covers the
84 // ivalues that are allowed as attributes, and it does not check type
85 // equivalence of containers.
ivaluesEqual(const IValue & a1,const IValue & a2)86 bool ivaluesEqual(const IValue& a1, const IValue& a2) {
87 if (a1.tagKind() != a2.tagKind()) {
88 return false;
89 }
90 if (a1.isInt()) {
91 return a1.toInt() == a2.toInt();
92 }
93 if (a1.isBool()) {
94 return a1.toBool() == a2.toBool();
95 }
96 if (a1.isDouble()) {
97 return a1.toDouble() == a2.toDouble();
98 }
99 if (a1.isTensor()) {
100 return attributesEqual(a1.toTensor(), a2.toTensor());
101 }
102 if (a1.isNone()) {
103 return true;
104 }
105 if (a1.isString()) {
106 return a1.toStringRef() == a2.toStringRef();
107 }
108 if (a1.isList()) {
109 return attributesEqual(a1.toListRef(), a2.toListRef());
110 }
111 if (a1.isTuple()) {
112 at::ArrayRef<IValue> a1_elem = a1.toTupleRef().elements();
113 at::ArrayRef<IValue> a2_elem = a2.toTupleRef().elements();
114 return attributesEqual(a1_elem, a2_elem);
115 }
116 if (a1.isGenericDict()) {
117 auto a1_dict = a1.toGenericDict();
118 auto a2_dict = a2.toGenericDict();
119 if (a1_dict.size() != a2_dict.size()) {
120 return false;
121 }
122
123 auto it_a1 = a1_dict.begin();
124 auto it_a2 = a2_dict.begin();
125
126 while (it_a1 != a1_dict.end()) {
127 const auto& e_a1 = *it_a1;
128 const auto& e_a2 = *it_a2;
129
130 if (!ivaluesEqual(e_a1.key(), e_a2.key()) ||
131 !ivaluesEqual(e_a1.value(), e_a2.value())) {
132 return false;
133 }
134 it_a1++;
135 it_a2++;
136 }
137 return true;
138 }
139 if (a1.isEnum()) {
140 return a1.toEnumHolder() == a2.toEnumHolder();
141 }
142 if (a1.isObject()) {
143 return &a1.toObjectRef() == &a2.toObjectRef();
144 }
145 if (a1.isGenerator()) {
146 return a1.toGenerator() == a2.toGenerator();
147 }
148 TORCH_INTERNAL_ASSERT(false);
149 }
150
151 // Check whether two nodes have the same attributes in CSE.
152 // This function may be too conservative for general use.
153 // Do NOT support g/gs attributes.
attributesEqualCSE(const Node * lhs,const Node * rhs)154 bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
155 AT_ASSERT(lhs != nullptr);
156 AT_ASSERT(rhs != nullptr);
157 // One has attributes, the other does not.
158 if (lhs->hasAttributes() != rhs->hasAttributes())
159 return false;
160 // Neither has attributes.
161 if (!lhs->hasAttributes() && !rhs->hasAttributes())
162 return true;
163
164 auto lnames = lhs->attributeNames();
165 auto rnames = rhs->attributeNames();
166 std::sort(lnames.begin(), lnames.end());
167 std::sort(rnames.begin(), rnames.end());
168 if (lnames != rnames)
169 return false;
170
171 for (auto name : lnames) {
172 if (lhs->kindOf(name) != rhs->kindOf(name))
173 return false;
174
175 #define COMPARE_ATTRIBUTEVALUE(selector) \
176 case AttributeKind::selector: { \
177 if (!attributesEqual(lhs->selector(name), rhs->selector(name))) \
178 return false; \
179 } break;
180
181 switch (lhs->kindOf(name)) {
182 COMPARE_ATTRIBUTEVALUE(f)
183 COMPARE_ATTRIBUTEVALUE(c)
184 COMPARE_ATTRIBUTEVALUE(fs)
185 COMPARE_ATTRIBUTEVALUE(cs)
186 COMPARE_ATTRIBUTEVALUE(i)
187 COMPARE_ATTRIBUTEVALUE(is)
188 COMPARE_ATTRIBUTEVALUE(s)
189 COMPARE_ATTRIBUTEVALUE(ss)
190 COMPARE_ATTRIBUTEVALUE(t)
191 COMPARE_ATTRIBUTEVALUE(ts)
192 COMPARE_ATTRIBUTEVALUE(ival)
193 case AttributeKind::ty:
194 if (*lhs->ty(name) != *rhs->ty(name)) {
195 return false;
196 }
197 break;
198 case AttributeKind::tys:
199 if (!typeListEqual(lhs->tys(name), rhs->tys(name))) {
200 return false;
201 }
202 break;
203 case AttributeKind::g:
204 case AttributeKind::gs:
205 return false;
206 }
207
208 #undef COMPARE_ATTRIBUTEVALUE
209 }
210
211 return true;
212 }
213
214 } // anonymous namespace
215
216 // Makes a hash that hashes the input Value, the output type
217 // as well as the node attributes
operator ()(const Node * k) const218 size_t HashNode::operator()(const Node* k) const {
219 AT_ASSERT(k != nullptr);
220 size_t constant_hash = 0;
221 if (k->kind() == prim::Constant) {
222 TypePtr type = k->output()->type();
223 if (type->isSubtypeOf(*NumberType::get()) &&
224 k->kindOf(attr::value) == AttributeKind::i) {
225 constant_hash = std::hash<int64_t>{}(k->i(attr::value));
226 } else if (
227 type->isSubtypeOf(*NumberType::get()) &&
228 k->kindOf(attr::value) == AttributeKind::f) {
229 constant_hash = std::hash<double>{}(k->f(attr::value));
230 } else if (
231 type->isSubtypeOf(*NumberType::get()) &&
232 k->kindOf(attr::value) == AttributeKind::c) {
233 constant_hash = c10::hash<c10::complex<double>>{}(k->c(attr::value));
234 } else if (type->isSubtypeOf(*BoolType::get())) {
235 constant_hash = std::hash<bool>{}(k->i(attr::value));
236 }
237 }
238 return get_hash(
239 k->kind(),
240 fmap(k->outputs(), [](const Value* v) { return v->type()->kind(); }),
241 fmap(k->inputs(), [](const Value* v) { return v->unique(); }),
242 constant_hash);
243 }
244
245 // Checks that two nodes have the same inputs, output types
246 // and node attributes.
operator ()(const Node * lhs,const Node * rhs) const247 bool EqualNode::operator()(const Node* lhs, const Node* rhs) const {
248 if (lhs == nullptr && rhs == nullptr)
249 return true;
250 if (lhs == nullptr || rhs == nullptr)
251 return false;
252
253 if (lhs->kind() != rhs->kind())
254 return false;
255
256 // Check whether the output types are the same.
257 auto lhs_outputs = lhs->outputs();
258 auto rhs_outputs = rhs->outputs();
259 if (lhs_outputs.size() != rhs_outputs.size())
260 return false;
261 for (const auto i : c10::irange(lhs_outputs.size())) {
262 const auto& lt = lhs_outputs[i]->type();
263 const auto& rt = rhs_outputs[i]->type();
264 if (!(lt == rt || *lt == *rt))
265 return false;
266 }
267
268 // Check whether the inputs are the same.
269 auto lhs_inputs = lhs->inputs();
270 auto rhs_inputs = rhs->inputs();
271 if (lhs_inputs.size() != rhs_inputs.size())
272 return false;
273 if (!std::equal(lhs_inputs.begin(), lhs_inputs.end(), rhs_inputs.begin()))
274 return false;
275
276 if (!attributesEqualCSE(lhs, rhs))
277 return false;
278
279 // Check if the blocks contained in a op are the same
280 if (lhs->blocks().size() != rhs->blocks().size()) {
281 return false;
282 }
283 for (size_t i = 0; i < lhs->blocks().size(); ++i) {
284 if (lhs->blocks()[i] != rhs->blocks()[i]) {
285 return false;
286 }
287 }
288
289 return true;
290 }
291
292 } // namespace torch::jit
293