/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// CUDA userspace driver library wrapper functionality.

#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_

#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/node_hash_map.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/synchronization/mutex.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_driver.h"

namespace stream_executor {
namespace gpu {

// Formats a CUresult into a human-readable string for logging.
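//
// For illustration only (the exact wording comes from the installed driver):
//   ToString(CUDA_ERROR_OUT_OF_MEMORY)
// would return roughly "CUDA_ERROR_OUT_OF_MEMORY: out of memory".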
static std::string ToString(CUresult result) {
  const char* error_name;
  if (cuGetErrorName(result, &error_name)) {
    return absl::StrCat("UNKNOWN ERROR (", static_cast<int>(result), ")");
  }
  const char* error_string;
  if (cuGetErrorString(result, &error_string)) {
    return error_name;
  }
  return absl::StrCat(error_name, ": ", error_string);
}

// GpuContext wraps a CUDA CUcontext handle and carries a unique id. The
// unique id is positive, and ids are not repeated within the process.
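//
// Illustrative sketch only (not part of the API contract): because ids are
// unique within the process, two GpuContext* values `a` and `b` obtained from
// CreatedContexts::Add below refer to the same underlying CUcontext exactly
// when
//   a->id() == b->id()
// which allows comparing contexts without calling into the CUDA driver.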
class GpuContext {
 public:
  GpuContext(CUcontext context, int64_t id) : context_(context), id_(id) {}

  CUcontext context() const { return context_; }
  int64_t id() const { return id_; }

  // Disallow copying and moving.
  GpuContext(GpuContext&&) = delete;
  GpuContext(const GpuContext&) = delete;
  GpuContext& operator=(GpuContext&&) = delete;
  GpuContext& operator=(const GpuContext&) = delete;

 private:
  CUcontext const context_;
  const int64_t id_;
};

// Manages the singleton map of contexts that we've created, mapping
// from the CUcontext to the GpuContext* that we pass around internally.
// This also manages assignment of unique ids to GpuContexts, to allow
// for fast comparison of a context against the current context.
//
// When triple-angle-bracket kernel launches are required, contexts created by
// the CUDA runtime are avoided by using the scoped activations in
// gpu/gpu_activation.h.
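//
// Illustrative usage sketch only; `ctx` and `device_ordinal` are assumed to
// come from driver-level context creation (e.g. GpuDriver::CreateContext):
//
//   GpuContext* gpu_ctx = CreatedContexts::Add(ctx, device_ordinal);
//   CHECK(CreatedContexts::Has(ctx));
//   ...
//   CreatedContexts::Remove(ctx);  // when the context is being destroyed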
class CreatedContexts {
 public:
  // Returns whether context is a member of the live set.
  static bool Has(CUcontext context) {
    absl::ReaderMutexLock lock(&mu_);
    return Live()->find(context) != Live()->end();
  }

  // Adds context to the live set and returns its GpuContext*; if the context
  // is already present, returns the existing entry.
  static GpuContext* Add(CUcontext context, int device_ordinal) {
    CHECK(context != nullptr);
    absl::MutexLock lock(&mu_);

    auto insert_result = Live()->insert(std::make_pair(context, nullptr));
    auto it = insert_result.first;
    if (insert_result.second) {
      // context was not present in the map. Add it.
      it->second = std::make_unique<GpuContext>(context, next_id_++);
      (*LiveOrdinal())[device_ordinal].push_back(context);
    }
    return it->second.get();
  }

  // Removes context from the live set.
  static void Remove(CUcontext context) {
    CHECK(context != nullptr);
    absl::MutexLock lock(&mu_);
    auto it = Live()->find(context);
    CHECK(it != Live()->end()) << context;
    Live()->erase(it);
    // Also drop the context from the per-device-ordinal list it belongs to.
    for (auto& p : *LiveOrdinal()) {
      auto it2 = std::find(p.second.begin(), p.second.end(), context);
      if (it2 != p.second.end()) {
        p.second.erase(it2);
        if (p.second.empty()) {
          LiveOrdinal()->erase(p.first);
        }
        break;
      }
    }
  }

  // Returns a live context associated with the device that owns `ptr`.
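  //
  // Illustrative sketch only (`dev_ptr` is assumed to be a CUdeviceptr
  // returned by a driver allocation on some live context):
  //
  //   CUcontext owner =
  //       CreatedContexts::GetAnyContext(reinterpret_cast<void*>(dev_ptr));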
  static CUcontext GetAnyContext(void* ptr) {
    absl::ReaderMutexLock lock(&mu_);
    int device_ordinal;
    CUresult result =
        cuPointerGetAttribute(static_cast<void*>(&device_ordinal),
                              CU_POINTER_ATTRIBUTE_DEVICE_ORDINAL,
                              reinterpret_cast<CUdeviceptr>(ptr));
    if (result != CUDA_SUCCESS) {
      LOG(FATAL) << "Not able to get the device_ordinal for ptr: " << ptr
                 << ". Error: " << ToString(result);
    }
    CHECK_EQ(LiveOrdinal()->count(device_ordinal), 1);
    CHECK(!LiveOrdinal()->at(device_ordinal).empty())
        << "Need at least one context.";
    return LiveOrdinal()->at(device_ordinal)[0];
  }

 private:
  // Returns the singleton map from CUcontext to its GpuContext wrapper.
  static absl::node_hash_map<CUcontext, std::unique_ptr<GpuContext>>* Live() {
    static auto singleton =
        new absl::node_hash_map<CUcontext, std::unique_ptr<GpuContext>>;
    return singleton;
  }

  // Returns the singleton map from device ordinal to its live CUcontexts.
  static absl::node_hash_map<int, std::vector<CUcontext>>* LiveOrdinal() {
    static auto singleton =
        new absl::node_hash_map<int, std::vector<CUcontext>>;
    return singleton;
  }

  // Lock that guards access to and mutation of the live set.
  static absl::Mutex mu_;
  // Next unique id to assign to a newly added GpuContext.
  static int64_t next_id_;
};
}  // namespace gpu

namespace cuda {

using MemorySpace = gpu::MemorySpace;

using CUDADriver = gpu::GpuDriver;

using ScopedActivateContext = gpu::ScopedActivateContext;

using CudaContext = gpu::GpuContext;

// Returns the current context set in CUDA. This is obtained by calling the
// CUDA driver directly (i.e., it is not our cached view of the current
// context).
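//
// Illustrative check only (assumes a ScopedActivateContext has activated a
// GpuContext* `ctx` on the calling thread):
//
//   CHECK_EQ(CurrentContextOrDie(), ctx->context());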
CUcontext CurrentContextOrDie();

}  // namespace cuda
}  // namespace stream_executor

#endif  // TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_CUDA_CUDA_DRIVER_H_