/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// CUDA userspace driver library wrapper functionality.

#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_

#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/node_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_driver.h"

namespace stream_executor {
namespace gpu {
// Formats CUresult to output prettified values into a log stream.
static std::string ToString(CUresult result) {
  const char* error_name;
  if (cuGetErrorName(result, &error_name)) {
    return absl::StrCat("UNKNOWN ERROR (", static_cast<int>(result), ")");
  }
  const char* error_string;
  if (cuGetErrorString(result, &error_string)) {
    return error_name;
  }
  return absl::StrCat(error_name, ": ", error_string);
}

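// Usage sketch (illustrative, not part of this header): ToString is typically
// used to render a driver status into a log message. `ctx` below is a
// hypothetical CUcontext obtained elsewhere.
//
//   CUresult res = cuCtxSetCurrent(ctx);
//   if (res != CUDA_SUCCESS) {
//     LOG(ERROR) << "cuCtxSetCurrent failed: " << ToString(res);
//   }
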
// GpuContext wraps a CUDA CUcontext handle and includes a unique id. The
// unique id is positive, and ids are not repeated within the process.
class GpuContext {
 public:
  GpuContext(CUcontext context, int64_t id) : context_(context), id_(id) {}

  CUcontext context() const { return context_; }
  int64_t id() const { return id_; }

  // Disallow copying and moving.
  GpuContext(GpuContext&&) = delete;
  GpuContext(const GpuContext&) = delete;
  GpuContext& operator=(GpuContext&&) = delete;
  GpuContext& operator=(const GpuContext&) = delete;

 private:
  CUcontext const context_;
  const int64_t id_;
};

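// Usage sketch (illustrative, not part of this header): because ids are unique
// within the process, two GpuContext pointers can be compared cheaply by id
// rather than by comparing CUcontext handles. `lhs` and `rhs` are hypothetical
// GpuContext* values obtained from CreatedContexts::Add below.
//
//   if (lhs->id() == rhs->id()) {
//     // Same underlying context.
//   }
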
// Manages the singleton map of contexts that we've created, mapping
// from the CUcontext to the GpuContext* that we pass around internally.
// This also manages assignment of unique ids to GpuContexts, to allow
// for fast comparison of a context against the current context.
//
// When triple-angle-bracket launches are required, we avoid
// CUDA-runtime-created contexts by using the scoped activations in
// gpu/gpu_activation.h.
class CreatedContexts {
 public:
  // Returns whether context is a member of the live set.
  static bool Has(CUcontext context) {
    absl::ReaderMutexLock lock(&mu_);
    return Live()->find(context) != Live()->end();
  }

  // Adds context to the live set, or returns it if it's already present.
  static GpuContext* Add(CUcontext context, int device_ordinal) {
    CHECK(context != nullptr);
    absl::MutexLock lock(&mu_);

    auto insert_result = Live()->insert(std::make_pair(context, nullptr));
    auto it = insert_result.first;
    if (insert_result.second) {
      // context was not present in the map.  Add it.
      it->second = std::make_unique<GpuContext>(context, next_id_++);
      (*LiveOrdinal())[device_ordinal].push_back(context);
    }
    return it->second.get();
  }

  // Removes context from the live set.
  static void Remove(CUcontext context) {
    CHECK(context != nullptr);
    absl::MutexLock lock(&mu_);
    auto it = Live()->find(context);
    CHECK(it != Live()->end()) << context;
    Live()->erase(it);
    // Also remove the context from the per-device-ordinal lists; iterate by
    // reference so the erase below mutates the stored vector, not a copy.
    for (auto& p : (*LiveOrdinal())) {
      auto it2 = std::find(p.second.begin(), p.second.end(), context);
      if (it2 != p.second.end()) {
        p.second.erase(it2);
        if (p.second.empty()) {
          LiveOrdinal()->erase(p.first);
        }
        break;
      }
    }
  }

  // Returns the context associated with the given pointer.
  static CUcontext GetAnyContext(void* ptr) {
    absl::ReaderMutexLock lock(&mu_);
    int device_ordinal;
    CUresult result = cuPointerGetAttribute(static_cast<void*>(&device_ordinal),
                                            CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
                                            reinterpret_cast<CUdeviceptr>(ptr));
    if (result != CUDA_SUCCESS) {
      LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << ptr
                 << ". Error: " << ToString(result);
    }
    CHECK_EQ(LiveOrdinal()->count(device_ordinal), 1);
    CHECK(!LiveOrdinal()->at(device_ordinal).empty())
        << "Need at least one context.";
    return LiveOrdinal()->at(device_ordinal)[0];
  }

 private:
  // Returns the live map singleton.
  static absl::node_hash_map<CUcontext, std::unique_ptr<GpuContext>>* Live() {
    static auto singleton =
        new absl::node_hash_map<CUcontext, std::unique_ptr<GpuContext>>;
    return singleton;
  }

  // Returns the per-device-ordinal map of live contexts.
  static absl::node_hash_map<int, std::vector<CUcontext>>* LiveOrdinal() {
    static auto singleton =
        new absl::node_hash_map<int, std::vector<CUcontext>>;
    return singleton;
  }

  // Lock that guards access-to/mutation-of the live set.
  static absl::Mutex mu_;
  static int64_t next_id_;
};
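
// Usage sketch (illustrative, not part of this header): a typical lifecycle
// registers a driver-created context, looks it up, and unregisters it on
// teardown. `ctx` is a hypothetical CUcontext created via cuCtxCreate or
// cuDevicePrimaryCtxRetain, and `dev_ptr` is a hypothetical device allocation
// made while that context was current.
//
//   GpuContext* wrapped = CreatedContexts::Add(ctx, /*device_ordinal=*/0);
//   CHECK(CreatedContexts::Has(ctx));
//   CUcontext owner = CreatedContexts::GetAnyContext(dev_ptr);
//   ...
//   CreatedContexts::Remove(ctx);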
}  // namespace gpu

namespace cuda {

using MemorySpace = gpu::MemorySpace;

using CUDADriver = gpu::GpuDriver;

using ScopedActivateContext = gpu::ScopedActivateContext;

using CudaContext = gpu::GpuContext;

// Returns the current context set in CUDA. This is done by calling the cuda
// driver (i.e., this value is not our cached view of the current context).
CUcontext CurrentContextOrDie();
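
// Usage sketch (illustrative, not part of this header): query the driver for
// the real current context and check it against the contexts we created.
//
//   CUcontext current = cuda::CurrentContextOrDie();
//   if (!gpu::CreatedContexts::Has(current)) {
//     LOG(WARNING) << "Current CUDA context was not created by this process.";
//   }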

}  // namespace cuda
}  // namespace stream_executor

#endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_