xref: /aosp_15_r20/external/pytorch/c10/util/numa.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <c10/util/Exception.h>
2 #include <c10/util/numa.h>
3 
4 C10_DEFINE_bool(caffe2_cpu_numa_enabled, false, "Use NUMA whenever possible.");
5 
6 #if defined(__linux__) && defined(C10_USE_NUMA) && !defined(C10_MOBILE)
7 #include <numa.h>
8 #include <numaif.h>
9 #include <unistd.h>
10 #define C10_ENABLE_NUMA
11 #endif
12 
13 // This code used to have a lot of VLOGs. However, because allocation might be
14 // triggered during static initialization, it's unsafe to invoke VLOG here
15 
16 namespace c10 {
17 
18 #ifdef C10_ENABLE_NUMA
IsNUMAEnabled()19 bool IsNUMAEnabled() {
20   return FLAGS_caffe2_cpu_numa_enabled && numa_available() >= 0;
21 }
22 
NUMABind(int numa_node_id)23 void NUMABind(int numa_node_id) {
24   if (numa_node_id < 0) {
25     return;
26   }
27   if (!IsNUMAEnabled()) {
28     return;
29   }
30 
31   TORCH_CHECK(
32       numa_node_id <= numa_max_node(),
33       "NUMA node id ",
34       numa_node_id,
35       " is unavailable");
36 
37   auto bm = numa_allocate_nodemask();
38   numa_bitmask_setbit(bm, numa_node_id);
39   numa_bind(bm);
40   numa_bitmask_free(bm);
41 }
42 
GetNUMANode(const void * ptr)43 int GetNUMANode(const void* ptr) {
44   if (!IsNUMAEnabled()) {
45     return -1;
46   }
47   AT_ASSERT(ptr);
48 
49   int numa_node = -1;
50   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
51   TORCH_CHECK(
52       get_mempolicy(
53           &numa_node,
54           nullptr,
55           0,
56           const_cast<void*>(ptr),
57           MPOL_F_NODE | MPOL_F_ADDR) == 0,
58       "Unable to get memory policy, errno:",
59       errno);
60   return numa_node;
61 }
62 
GetNumNUMANodes()63 int GetNumNUMANodes() {
64   if (!IsNUMAEnabled()) {
65     return -1;
66   }
67 
68   return numa_num_configured_nodes();
69 }
70 
NUMAMove(void * ptr,size_t size,int numa_node_id)71 void NUMAMove(void* ptr, size_t size, int numa_node_id) {
72   if (numa_node_id < 0) {
73     return;
74   }
75   if (!IsNUMAEnabled()) {
76     return;
77   }
78   AT_ASSERT(ptr);
79 
80   uintptr_t page_start_ptr =
81       ((reinterpret_cast<uintptr_t>(ptr)) & ~(getpagesize() - 1));
82   // NOLINTNEXTLINE(*-conversions)
83   ptrdiff_t offset = reinterpret_cast<uintptr_t>(ptr) - page_start_ptr;
84   // Avoid extra dynamic allocation and NUMA api calls
85   AT_ASSERT(
86       numa_node_id >= 0 &&
87       static_cast<unsigned>(numa_node_id) < sizeof(unsigned long) * 8);
88   unsigned long mask = 1UL << numa_node_id;
89   // NOLINTNEXTLINE(performance-no-int-to-ptr)
90   TORCH_CHECK(
91       mbind(
92           reinterpret_cast<void*>(page_start_ptr),
93           size + offset,
94           MPOL_BIND,
95           &mask,
96           sizeof(mask) * 8,
97           MPOL_MF_MOVE | MPOL_MF_STRICT) == 0,
98       "Could not move memory to a NUMA node");
99 }
100 
GetCurrentNUMANode()101 int GetCurrentNUMANode() {
102   if (!IsNUMAEnabled()) {
103     return -1;
104   }
105 
106   auto n = numa_node_of_cpu(sched_getcpu());
107   return n;
108 }
109 
110 #else // C10_ENABLE_NUMA
111 
112 bool IsNUMAEnabled() {
113   return false;
114 }
115 
116 void NUMABind(int numa_node_id) {}
117 
118 int GetNUMANode(const void* ptr) {
119   return -1;
120 }
121 
122 int GetNumNUMANodes() {
123   return -1;
124 }
125 
126 void NUMAMove(void* ptr, size_t size, int numa_node_id) {}
127 
128 int GetCurrentNUMANode() {
129   return -1;
130 }
131 
132 #endif // C10_NUMA_ENABLED
133 
134 } // namespace c10
135