xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/ir/node_hashing.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <torch/csrc/jit/ir/ir.h>
2 
3 #include <algorithm>
4 #include <unordered_map>
5 
6 #include <ATen/core/functional.h>
7 #include <ATen/core/symbol.h>
8 #include <c10/util/Exception.h>
9 #include <c10/util/hash.h>
10 #include <c10/util/irange.h>
11 #include <torch/csrc/jit/ir/node_hashing.h>
12 #include <torch/csrc/jit/passes/common_subexpression_elimination.h>
13 
14 namespace torch::jit {
15 
16 namespace {
17 
tensorEqual(const at::Tensor & lhs,const at::Tensor & rhs)18 bool tensorEqual(const at::Tensor& lhs, const at::Tensor& rhs) {
19   // type_equal doesnt distinguish between mkldnn/pytorch cpu tensors,
20   // and we dont want to coalesce mkldnn tensors bc they do layout
21   // transformations based on usage
22   if (lhs.is_mkldnn() || rhs.is_mkldnn()) {
23     return false;
24   }
25   if (lhs.is_nested() || rhs.is_nested()) {
26     return false;
27   }
28   // If device is not equal, lhs.equal(rhs) would throw an error.
29   if (lhs.device() != rhs.device()) {
30     return false;
31   }
32   return lhs.options().type_equal(rhs.options()) && lhs.equal(rhs);
33 }
34 
typeListEqual(const std::vector<TypePtr> & lhs,const std::vector<TypePtr> & rhs)35 bool typeListEqual(
36     const std::vector<TypePtr>& lhs,
37     const std::vector<TypePtr>& rhs) {
38   if (lhs.size() != rhs.size())
39     return false;
40   for (const auto i : c10::irange(lhs.size())) {
41     if (*lhs[i] != *rhs[i]) {
42       return false;
43     }
44   }
45   return true;
46 }
47 
// Generic attribute comparison for every attribute value type that is
// compared with plain operator== (scalars such as int64_t, bool and double,
// but also strings and vectors of scalars/strings).
// The arguments are taken by const reference so that string- and
// vector-valued attributes are compared in place instead of being copied on
// every call. The non-template overloads (tensors, tensor lists, IValues)
// are still preferred by overload resolution for their argument types.
template <typename attribute_type>
bool attributesEqual(const attribute_type& a1, const attribute_type& a2) {
  return a1 == a2;
}
52 
attributesEqual(const at::Tensor & a1,const at::Tensor & a2)53 bool attributesEqual(const at::Tensor& a1, const at::Tensor& a2) {
54   return tensorEqual(a1, a2);
55 }
56 
57 bool ivaluesEqual(const IValue& a1, const IValue& a2);
58 
attributesEqual(const std::vector<at::Tensor> & lhs,const std::vector<at::Tensor> & rhs)59 bool attributesEqual(
60     const std::vector<at::Tensor>& lhs,
61     const std::vector<at::Tensor>& rhs) {
62   if (lhs.size() != rhs.size())
63     return false;
64   return std::equal(lhs.begin(), lhs.end(), rhs.begin(), tensorEqual);
65 }
66 
attributesEqual(at::ArrayRef<IValue> a1,at::ArrayRef<IValue> a2)67 bool attributesEqual(at::ArrayRef<IValue> a1, at::ArrayRef<IValue> a2) {
68   if (a1.size() != a2.size()) {
69     return false;
70   }
71   for (const auto i : c10::irange(a1.size())) {
72     if (!ivaluesEqual(a1[i], a2[i])) {
73       return false;
74     }
75   }
76   return true;
77 }
78 
attributesEqual(const IValue & a1,const IValue & a2)79 bool attributesEqual(const IValue& a1, const IValue& a2) {
80   return ivaluesEqual(a1, a2);
81 }
82 
83 // this is not a general-purpose comparison of IValues, it only covers the
84 // ivalues that are allowed as attributes, and it does not check type
85 // equivalence of containers.
ivaluesEqual(const IValue & a1,const IValue & a2)86 bool ivaluesEqual(const IValue& a1, const IValue& a2) {
87   if (a1.tagKind() != a2.tagKind()) {
88     return false;
89   }
90   if (a1.isInt()) {
91     return a1.toInt() == a2.toInt();
92   }
93   if (a1.isBool()) {
94     return a1.toBool() == a2.toBool();
95   }
96   if (a1.isDouble()) {
97     return a1.toDouble() == a2.toDouble();
98   }
99   if (a1.isTensor()) {
100     return attributesEqual(a1.toTensor(), a2.toTensor());
101   }
102   if (a1.isNone()) {
103     return true;
104   }
105   if (a1.isString()) {
106     return a1.toStringRef() == a2.toStringRef();
107   }
108   if (a1.isList()) {
109     return attributesEqual(a1.toListRef(), a2.toListRef());
110   }
111   if (a1.isTuple()) {
112     at::ArrayRef<IValue> a1_elem = a1.toTupleRef().elements();
113     at::ArrayRef<IValue> a2_elem = a2.toTupleRef().elements();
114     return attributesEqual(a1_elem, a2_elem);
115   }
116   if (a1.isGenericDict()) {
117     auto a1_dict = a1.toGenericDict();
118     auto a2_dict = a2.toGenericDict();
119     if (a1_dict.size() != a2_dict.size()) {
120       return false;
121     }
122 
123     auto it_a1 = a1_dict.begin();
124     auto it_a2 = a2_dict.begin();
125 
126     while (it_a1 != a1_dict.end()) {
127       const auto& e_a1 = *it_a1;
128       const auto& e_a2 = *it_a2;
129 
130       if (!ivaluesEqual(e_a1.key(), e_a2.key()) ||
131           !ivaluesEqual(e_a1.value(), e_a2.value())) {
132         return false;
133       }
134       it_a1++;
135       it_a2++;
136     }
137     return true;
138   }
139   if (a1.isEnum()) {
140     return a1.toEnumHolder() == a2.toEnumHolder();
141   }
142   if (a1.isObject()) {
143     return &a1.toObjectRef() == &a2.toObjectRef();
144   }
145   if (a1.isGenerator()) {
146     return a1.toGenerator() == a2.toGenerator();
147   }
148   TORCH_INTERNAL_ASSERT(false);
149 }
150 
151 // Check whether two nodes have the same attributes in CSE.
152 // This function may be too conservative for general use.
153 // Do NOT support g/gs attributes.
attributesEqualCSE(const Node * lhs,const Node * rhs)154 bool attributesEqualCSE(const Node* lhs, const Node* rhs) {
155   AT_ASSERT(lhs != nullptr);
156   AT_ASSERT(rhs != nullptr);
157   // One has attributes, the other does not.
158   if (lhs->hasAttributes() != rhs->hasAttributes())
159     return false;
160   // Neither has attributes.
161   if (!lhs->hasAttributes() && !rhs->hasAttributes())
162     return true;
163 
164   auto lnames = lhs->attributeNames();
165   auto rnames = rhs->attributeNames();
166   std::sort(lnames.begin(), lnames.end());
167   std::sort(rnames.begin(), rnames.end());
168   if (lnames != rnames)
169     return false;
170 
171   for (auto name : lnames) {
172     if (lhs->kindOf(name) != rhs->kindOf(name))
173       return false;
174 
175 #define COMPARE_ATTRIBUTEVALUE(selector)                            \
176   case AttributeKind::selector: {                                   \
177     if (!attributesEqual(lhs->selector(name), rhs->selector(name))) \
178       return false;                                                 \
179   } break;
180 
181     switch (lhs->kindOf(name)) {
182       COMPARE_ATTRIBUTEVALUE(f)
183       COMPARE_ATTRIBUTEVALUE(c)
184       COMPARE_ATTRIBUTEVALUE(fs)
185       COMPARE_ATTRIBUTEVALUE(cs)
186       COMPARE_ATTRIBUTEVALUE(i)
187       COMPARE_ATTRIBUTEVALUE(is)
188       COMPARE_ATTRIBUTEVALUE(s)
189       COMPARE_ATTRIBUTEVALUE(ss)
190       COMPARE_ATTRIBUTEVALUE(t)
191       COMPARE_ATTRIBUTEVALUE(ts)
192       COMPARE_ATTRIBUTEVALUE(ival)
193       case AttributeKind::ty:
194         if (*lhs->ty(name) != *rhs->ty(name)) {
195           return false;
196         }
197         break;
198       case AttributeKind::tys:
199         if (!typeListEqual(lhs->tys(name), rhs->tys(name))) {
200           return false;
201         }
202         break;
203       case AttributeKind::g:
204       case AttributeKind::gs:
205         return false;
206     }
207 
208 #undef COMPARE_ATTRIBUTEVALUE
209   }
210 
211   return true;
212 }
213 
214 } // anonymous namespace
215 
216 // Makes a hash that hashes the input Value, the output type
217 // as well as the node attributes
operator ()(const Node * k) const218 size_t HashNode::operator()(const Node* k) const {
219   AT_ASSERT(k != nullptr);
220   size_t constant_hash = 0;
221   if (k->kind() == prim::Constant) {
222     TypePtr type = k->output()->type();
223     if (type->isSubtypeOf(*NumberType::get()) &&
224         k->kindOf(attr::value) == AttributeKind::i) {
225       constant_hash = std::hash<int64_t>{}(k->i(attr::value));
226     } else if (
227         type->isSubtypeOf(*NumberType::get()) &&
228         k->kindOf(attr::value) == AttributeKind::f) {
229       constant_hash = std::hash<double>{}(k->f(attr::value));
230     } else if (
231         type->isSubtypeOf(*NumberType::get()) &&
232         k->kindOf(attr::value) == AttributeKind::c) {
233       constant_hash = c10::hash<c10::complex<double>>{}(k->c(attr::value));
234     } else if (type->isSubtypeOf(*BoolType::get())) {
235       constant_hash = std::hash<bool>{}(k->i(attr::value));
236     }
237   }
238   return get_hash(
239       k->kind(),
240       fmap(k->outputs(), [](const Value* v) { return v->type()->kind(); }),
241       fmap(k->inputs(), [](const Value* v) { return v->unique(); }),
242       constant_hash);
243 }
244 
245 // Checks that two nodes have the same inputs, output types
246 // and node attributes.
operator ()(const Node * lhs,const Node * rhs) const247 bool EqualNode::operator()(const Node* lhs, const Node* rhs) const {
248   if (lhs == nullptr && rhs == nullptr)
249     return true;
250   if (lhs == nullptr || rhs == nullptr)
251     return false;
252 
253   if (lhs->kind() != rhs->kind())
254     return false;
255 
256   // Check whether the output types are the same.
257   auto lhs_outputs = lhs->outputs();
258   auto rhs_outputs = rhs->outputs();
259   if (lhs_outputs.size() != rhs_outputs.size())
260     return false;
261   for (const auto i : c10::irange(lhs_outputs.size())) {
262     const auto& lt = lhs_outputs[i]->type();
263     const auto& rt = rhs_outputs[i]->type();
264     if (!(lt == rt || *lt == *rt))
265       return false;
266   }
267 
268   // Check whether the inputs are the same.
269   auto lhs_inputs = lhs->inputs();
270   auto rhs_inputs = rhs->inputs();
271   if (lhs_inputs.size() != rhs_inputs.size())
272     return false;
273   if (!std::equal(lhs_inputs.begin(), lhs_inputs.end(), rhs_inputs.begin()))
274     return false;
275 
276   if (!attributesEqualCSE(lhs, rhs))
277     return false;
278 
279   // Check if the blocks contained in a op are the same
280   if (lhs->blocks().size() != rhs->blocks().size()) {
281     return false;
282   }
283   for (size_t i = 0; i < lhs->blocks().size(); ++i) {
284     if (lhs->blocks()[i] != rhs->blocks()[i]) {
285       return false;
286     }
287   }
288 
289   return true;
290 }
291 
292 } // namespace torch::jit
293