xref: /aosp_15_r20/external/pytorch/aten/src/ATen/core/operator_name.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/macros/Macros.h>
4 #include <c10/util/Exception.h>
5 #include <c10/util/string_view.h>
6 #include <optional>
7 #include <ostream>
8 #include <string>
9 #include <utility>
10 
11 namespace c10 {
12 
13 // TODO: consider storing namespace separately too
14 struct OperatorName final {
15   std::string name;
16   std::string overload_name;
OperatorNamefinal17   OperatorName(std::string name, std::string overload_name)
18       : name(std::move(name)), overload_name(std::move(overload_name)) {}
19 
20   // TODO: These two functions below are slow!  Fix internal data structures so
21   // I don't have to manually reconstruct the namespaces!
22 
23   // Return the namespace of this OperatorName, if it exists.  The
24   // returned string_view is only live as long as the OperatorName
25   // exists and name is not mutated
getNamespacefinal26   std::optional<c10::string_view> getNamespace() const {
27     auto pos = name.find("::");
28     if (pos == std::string::npos) {
29       return std::nullopt;
30     } else {
31       return std::make_optional(c10::string_view(name.data(), pos));
32     }
33   }
34 
35   // Returns true if we successfully set the namespace
setNamespaceIfNotSetfinal36   bool setNamespaceIfNotSet(const char* ns) {
37     if (!getNamespace().has_value()) {
38       const auto ns_len = strlen(ns);
39       const auto old_name_size = name.size();
40       name.resize(ns_len + 2 + old_name_size);
41       // Shift current value of name to the end of the new space.
42       name.replace(
43           name.size() - old_name_size, old_name_size, name, 0, old_name_size);
44       name.replace(0, ns_len, ns, ns_len);
45       name[ns_len] = ':';
46       name[ns_len + 1] = ':';
47       return true;
48     } else {
49       return false;
50     }
51   }
52 };
53 
54 // Non-owning view of an OperatorName.  Unlike OperatorName, most of
55 // its functions are constexpr, so it can be used for compile time
56 // computations
57 struct OperatorNameView final {
58   c10::string_view name;
59   c10::string_view overload_name;
OperatorNameViewfinal60   constexpr OperatorNameView(
61       c10::string_view name,
62       c10::string_view overload_name)
63       : name(name), overload_name(overload_name) {}
64   // Parses strings like "foo.overload" and also "foo"
parsefinal65   constexpr static OperatorNameView parse(c10::string_view full_name) {
66     auto i = full_name.find('.');
67     if (i == c10::string_view::npos) {
68       return OperatorNameView(full_name, c10::string_view());
69     } else {
70       return OperatorNameView(full_name.substr(0, i), full_name.substr(i + 1));
71     }
72   }
73 };
74 
75 inline bool operator==(const OperatorName& lhs, const OperatorName& rhs) {
76   return lhs.name == rhs.name && lhs.overload_name == rhs.overload_name;
77 }
78 
79 inline bool operator!=(const OperatorName& lhs, const OperatorName& rhs) {
80   return !operator==(lhs, rhs);
81 }
82 
83 TORCH_API std::string toString(const OperatorName& opName);
84 TORCH_API std::ostream& operator<<(std::ostream&, const OperatorName&);
85 
86 } // namespace c10
87 
88 namespace std {
89 template <>
90 struct hash<::c10::OperatorName> {
91   size_t operator()(const ::c10::OperatorName& x) const {
92     return std::hash<std::string>()(x.name) ^
93         (~std::hash<std::string>()(x.overload_name));
94   }
95 };
96 } // namespace std
97