xref: /aosp_15_r20/external/pytorch/c10/core/ConstantSymNodeImpl.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/core/ConstantSymNodeImpl.h>
2 
3 namespace c10 {
4 
5 // This is used to support the case where the lhs is a constant symnode
6 // and the rhs is a nested int symnode. This situation occurs today when we
7 // perform a binary op between nested int and plain int and the
8 // int is promoted into a constant symnode. If we'd like to
9 // support more combinations in the future, we may need to implement some
10 // kind of multiple dispatch.
11 #define DEFINE_BINARY_OP(OP, ROP)                                        \
12   template <typename T>                                                  \
13   c10::SymNode ConstantSymNodeImpl<T>::OP(const c10::SymNode& other) {   \
14     TORCH_INTERNAL_ASSERT(other->is_nested_int());                       \
15     return other->ROP(                                                   \
16         c10::intrusive_ptr<ConstantSymNodeImpl<T>>::reclaim_copy(this)); \
17   }
18 
19 DEFINE_BINARY_OP(eq, eq)
20 DEFINE_BINARY_OP(ne, ne)
21 DEFINE_BINARY_OP(ge, le)
22 DEFINE_BINARY_OP(le, ge)
23 DEFINE_BINARY_OP(lt, gt)
24 DEFINE_BINARY_OP(gt, lt)
25 DEFINE_BINARY_OP(mul, mul)
26 
27 #undef DEFINE_BINARY_OP
28 
29 template class ConstantSymNodeImpl<bool>;
30 template class ConstantSymNodeImpl<int64_t>;
31 
32 } // namespace c10
33