#include <torch/csrc/distributed/c10d/FileStore.hpp>

#include <fcntl.h>
#include <sys/stat.h>
#include <cassert>
#include <cstdint>

#ifdef _WIN32
#include <c10/util/win32-headers.h>
#include <fileapi.h>
#include <io.h>
#include <filesystem>
#else
#include <sys/file.h>
#include <unistd.h>
#endif

#include <cerrno>
#include <chrono>
#include <cstdio>
#include <cstring>
#include <limits>
#include <thread>
#include <utility>

#include <c10/util/Exception.h>
24
// Throws DistStoreError when the syscall-style return value `rv` is negative.
// The remaining arguments describe the failing operation (e.g. "lseek" or
// "open(<path>)") and are prefixed to strerror(errno) in the error message.
// errno is captured before the message is built, so evaluating the arguments
// (which may allocate) cannot clobber it. The do/while(0) wrapper makes the
// macro a single statement, avoiding dangling-else surprises.
#define SYSASSERT(rv, ...)                                               \
  do {                                                                   \
    if ((rv) < 0) {                                                      \
      const int sysassert_errno_ = errno;                                \
      C10_THROW_ERROR(                                                   \
          DistStoreError,                                                \
          c10::str(__VA_ARGS__, ": ", std::strerror(sysassert_errno_))); \
    }                                                                    \
  } while (0)
29
#ifdef _WIN32
#define LOCK_EX 0x00000001
#define LOCK_SH 0x00000010
#define LOCK_UN 0x00000100

// Minimal flock(2) emulation for Windows, built on LockFileEx/UnlockFileEx.
// It locks (or unlocks) a single byte at offset 0 of the file behind `fd`,
// which works as a whole-file advisory lock as long as every participant goes
// through this function. `op` is one of LOCK_EX / LOCK_SH / LOCK_UN.
// Returns 0 on success. On failure it returns -1. Once a valid handle is
// obtained, errno is set to EINVAL; the Win32 error code is not translated.
// Internal linkage: this helper is private to this translation unit.
static int flock_(int fd, int op) {
  const intptr_t osHandle = _get_osfhandle(fd);
  // _get_osfhandle reports failure with negative sentinel values.
  if (osHandle < 0) {
    return -1;
  }
  HANDLE hdl = reinterpret_cast<HANDLE>(osHandle);

  const DWORD low = 1;
  const DWORD high = 0;
  OVERLAPPED overlapped{}; // offset 0, no event handle

  BOOL ok = FALSE;
  switch (op) {
    case LOCK_EX:
      ok = LockFileEx(
          hdl, LOCKFILE_EXCLUSIVE_LOCK, 0, low, high, &overlapped);
      break;
    case LOCK_SH:
      ok = LockFileEx(hdl, 0, 0, low, high, &overlapped);
      break;
    case LOCK_UN:
      ok = UnlockFileEx(hdl, 0, low, high, &overlapped);
      break;
    default:
      break;
  }
  if (ok) {
    return 0;
  }
  errno = EINVAL;
  return -1;
}
#endif
63
64 namespace c10d {
65
66 namespace {
67
// Invokes `fn`, a wrapper around a POSIX-style call that returns -1 and sets
// errno on failure. If the call fails with EINTR (interrupted by a signal), it
// is retried. Returns the first result that is not an EINTR failure; errno is
// left as set by that final call.
template <typename F>
auto syscall(F fn) {
  auto rv = fn();
  while (rv == -1 && errno == EINTR) {
    rv = fn();
  }
  return rv;
}
81
82 // For a comprehensive overview of file locking methods,
83 // see: https://gavv.github.io/blog/file-locks/.
84 // We stick to flock(2) here because we don't care about
85 // locking byte ranges and don't want locks to be process-wide.
86
87 // RAII wrapper around flock(2)
88 class Lock {
89 public:
Lock(int fd,int operation)90 explicit Lock(int fd, int operation) : fd_(fd) {
91 flock(operation);
92 }
93
94 // NOLINTNEXTLINE(bugprone-exception-escape)
~Lock()95 ~Lock() {
96 unlock();
97 }
98
99 Lock(const Lock& that) = delete;
100
operator =(Lock && other)101 Lock& operator=(Lock&& other) noexcept {
102 if (this != &other) {
103 fd_ = other.fd_;
104 other.fd_ = -1;
105 }
106 return *this;
107 }
108
Lock(Lock && other)109 Lock(Lock&& other) noexcept {
110 *this = std::move(other);
111 }
112
unlock()113 void unlock() {
114 if (fd_ >= 0) {
115 flock(LOCK_UN);
116 fd_ = -1;
117 }
118 }
119
120 protected:
121 int fd_{-1};
122
flock(int operation)123 void flock(int operation) {
124 #ifdef _WIN32
125 auto rv = syscall(std::bind(::flock_, fd_, operation));
126 #else
127 auto rv = syscall([this, operation] { return ::flock(fd_, operation); });
128 #endif
129 SYSASSERT(rv, "flock");
130 }
131 };
132
133 class File {
134 public:
File(const std::string & path,int flags,std::chrono::milliseconds timeout)135 explicit File(
136 const std::string& path,
137 int flags,
138 std::chrono::milliseconds timeout) {
139 const auto start = std::chrono::steady_clock::now();
140 while (true) {
141 #ifdef _WIN32
142 fd_ = syscall(std::bind(
143 ::open, path.c_str(), flags | _O_BINARY, _S_IREAD | _S_IWRITE));
144 #else
145 fd_ = syscall([capture0 = path.c_str(), flags] {
146 return ::open(capture0, flags, 0644);
147 });
148 #endif
149 // Only retry when the file doesn't exist, since we are waiting for the
150 // file to be created in this case to address the following issue:
151 // https://github.com/pytorch/pytorch/issues/13750
152 if (fd_ >= 0 || errno != ENOENT) {
153 break;
154 }
155 #ifdef _WIN32
156 // if the parent folder doesn't exist it will never be able to create the
157 // file so we can skip the retry
158 if (!std::filesystem::exists(std::filesystem::path(path).parent_path())) {
159 break;
160 }
161 #endif
162 const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
163 std::chrono::steady_clock::now() - start);
164 if (timeout != c10d::Store::kNoTimeout && elapsed > timeout) {
165 break;
166 }
167 std::this_thread::sleep_for(std::chrono::milliseconds(10));
168 }
169 SYSASSERT(fd_, "open(" + path + ")");
170 }
171
~File()172 ~File() {
173 ::close(fd_);
174 }
175
lockShared()176 Lock lockShared() {
177 return Lock(fd_, LOCK_SH);
178 }
179
lockExclusive()180 Lock lockExclusive() {
181 return Lock(fd_, LOCK_EX);
182 }
183
seek(off_t offset,int whence)184 off_t seek(off_t offset, int whence) {
185 auto rv =
186 syscall([this, offset, whence] { return lseek(fd_, offset, whence); });
187 SYSASSERT(rv, "lseek");
188 return rv;
189 }
190
tell()191 off_t tell() {
192 auto rv = syscall([this] { return lseek(fd_, 0, SEEK_CUR); });
193 SYSASSERT(rv, "lseek");
194 return rv;
195 }
196
size()197 off_t size() {
198 auto pos = tell();
199 auto size = seek(0, SEEK_END);
200 seek(pos, SEEK_SET);
201 return size;
202 }
203
write(const void * buf,size_t count)204 void write(const void* buf, size_t count) {
205 while (count > 0) {
206 auto rv =
207 syscall([this, buf, count] { return ::write(fd_, buf, count); });
208 SYSASSERT(rv, "write");
209 buf = (uint8_t*)buf + rv;
210 count -= rv;
211 }
212 }
213
read(void * buf,size_t count)214 void read(void* buf, size_t count) {
215 while (count > 0) {
216 auto rv = syscall([this, buf, count] { return ::read(fd_, buf, count); });
217 SYSASSERT(rv, "read");
218 buf = (uint8_t*)buf + rv;
219 count -= rv;
220 }
221 }
222
write(const std::string & str)223 void write(const std::string& str) {
224 uint32_t len = str.size();
225 assert(str.size() <= std::numeric_limits<decltype(len)>::max());
226 write(&len, sizeof(len));
227 write(str.c_str(), len);
228 }
229
write(const std::vector<uint8_t> & data)230 void write(const std::vector<uint8_t>& data) {
231 uint32_t len = data.size();
232 assert(data.size() <= std::numeric_limits<decltype(len)>::max());
233 write(&len, sizeof(len));
234 write(data.data(), len);
235 }
236
read(std::string & str)237 void read(std::string& str) {
238 uint32_t len = 0;
239 read(&len, sizeof(len));
240 std::vector<uint8_t> buf(len);
241 read(buf.data(), len);
242 str.assign(buf.begin(), buf.end());
243 }
244
read(std::vector<uint8_t> & data)245 void read(std::vector<uint8_t>& data) {
246 uint32_t len = 0;
247 read(&len, sizeof(len));
248 data.resize(len);
249 read(data.data(), len);
250 }
251
252 protected:
253 int fd_;
254 };
255
refresh(File & file,off_t pos,std::unordered_map<std::string,std::vector<uint8_t>> & cache,const std::string & deletePrefix)256 off_t refresh(
257 File& file,
258 off_t pos,
259 std::unordered_map<std::string, std::vector<uint8_t>>& cache,
260 const std::string& deletePrefix) {
261 auto size = file.size();
262 if (size != pos) {
263 std::string tmpKey;
264 std::vector<uint8_t> tmpValue;
265 file.seek(pos, SEEK_SET);
266 while (size > pos) {
267 file.read(tmpKey);
268 file.read(tmpValue);
269 if (tmpKey.compare(0, deletePrefix.size(), deletePrefix) == 0) {
270 cache.erase(tmpKey.substr(deletePrefix.size()));
271 } else {
272 cache[tmpKey] = std::move(tmpValue);
273 }
274 pos = file.tell();
275 }
276 }
277 file.seek(0, SEEK_SET);
278 return pos;
279 }
280
281 } // namespace
282
FileStore(std::string path,int numWorkers)283 FileStore::FileStore(std::string path, int numWorkers)
284 : Store(),
285 path_(std::move(path)),
286
287 numWorkers_(numWorkers),
288 cleanupKey_("cleanup/"),
289 refCountKey_("refcount/"),
290 regularPrefix_("/"),
291 deletePrefix_("-") {
292 addHelper(refCountKey_, 1);
293 }
294
295 // NOLINTNEXTLINE(bugprone-exception-escape)
~FileStore()296 FileStore::~FileStore() {
297 // If the file does not exist - exit.
298 // This can happen when FileStore is invoked from python language which has
299 // GC. If python code has directory cleanup procedure, the race condition may
300 // occur between that code and this deconstructor. As a result, we check for
301 // file existense before cleanup
302 #ifdef _WIN32
303 int res = syscall(std::bind(::_access, path_.c_str(), 0));
304 #else
305 int res =
306 syscall([filepath = path_.c_str()] { return ::access(filepath, F_OK); });
307 #endif
308 if (res == -1) {
309 return;
310 }
311
312 // cleanup key will be different from all rest keys since all rest keys will
313 // have a regular prefix.
314 auto numFinishedWorker = addHelper(cleanupKey_, 1);
315 auto refCount = addHelper(refCountKey_, -1);
316 // The last worker cleans up the file. If numWorkers was not initialized to
317 // a specific postive value (i.e. meaning that there was not a fixed number
318 // of workers), we don't attempt to clean.
319 // Clean up the file if number of references is 0.
320 if (refCount == 0 && numWorkers_ >= 0 && numFinishedWorker >= numWorkers_) {
321 // Best effort removal without checking the return
322 ::remove(path_.c_str());
323 }
324 }
325
set(const std::string & key,const std::vector<uint8_t> & value)326 void FileStore::set(const std::string& key, const std::vector<uint8_t>& value) {
327 std::string regKey = regularPrefix_ + key;
328 std::unique_lock<std::mutex> l(activeFileOpLock_);
329 File file(path_, O_RDWR | O_CREAT, timeout_);
330 auto lock = file.lockExclusive();
331 file.seek(0, SEEK_END);
332 file.write(regKey);
333 file.write(value);
334 }
335
compareSet(const std::string & key,const std::vector<uint8_t> & expectedValue,const std::vector<uint8_t> & desiredValue)336 std::vector<uint8_t> FileStore::compareSet(
337 const std::string& key,
338 const std::vector<uint8_t>& expectedValue,
339 const std::vector<uint8_t>& desiredValue) {
340 std::string regKey = regularPrefix_ + key;
341 std::unique_lock<std::mutex> l(activeFileOpLock_);
342 File file(path_, O_RDWR | O_CREAT, timeout_);
343 auto lock = file.lockExclusive();
344 // Always refresh since even though the key exists in the cache,
345 // it might be outdated
346 pos_ = refresh(file, pos_, cache_, deletePrefix_);
347 if ((cache_.count(regKey) == 0 && expectedValue.empty()) ||
348 (cache_.count(regKey) != 0 && cache_[regKey] == expectedValue)) {
349 // if the key does not exist and currentValue arg is empty or
350 // the key does exist and current value is what is expected, then set it
351 file.seek(0, SEEK_END);
352 file.write(regKey);
353 file.write(desiredValue);
354 return desiredValue;
355 } else if (cache_.count(regKey) == 0) {
356 // if the key does not exist
357 return expectedValue;
358 }
359 // key exists but current value is not expected
360 return cache_[regKey];
361 }
362
get(const std::string & key)363 std::vector<uint8_t> FileStore::get(const std::string& key) {
364 std::string regKey = regularPrefix_ + key;
365 const auto start = std::chrono::steady_clock::now();
366 while (true) {
367 std::unique_lock<std::mutex> l(activeFileOpLock_);
368 File file(path_, O_RDONLY, timeout_);
369 auto lock = file.lockShared();
370 auto size = file.size();
371 if (cache_.count(regKey) == 0 && size == pos_) {
372 // No new entries; release the shared lock and sleep for a bit
373 lock.unlock();
374 l.unlock();
375 const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
376 std::chrono::steady_clock::now() - start);
377 if (timeout_ != kNoTimeout && elapsed > timeout_) {
378 auto err = c10::str(
379 "Timeout waiting for key: ",
380 key,
381 " after ",
382 timeout_.count(),
383 " ms");
384 TORCH_CHECK(false, err);
385 }
386 std::this_thread::sleep_for(std::chrono::milliseconds(10));
387 continue;
388 }
389 // Always refresh since even though the key exists in the cache,
390 // it might be outdated
391 pos_ = refresh(file, pos_, cache_, deletePrefix_);
392 if (cache_.count(regKey) != 0) {
393 return cache_[regKey];
394 }
395 }
396 }
397
addHelper(const std::string & key,int64_t i)398 int64_t FileStore::addHelper(const std::string& key, int64_t i) {
399 std::unique_lock<std::mutex> l(activeFileOpLock_);
400 File file(path_, O_RDWR | O_CREAT, timeout_);
401 auto lock = file.lockExclusive();
402 pos_ = refresh(file, pos_, cache_, deletePrefix_);
403
404 const auto& value = cache_[key];
405 int64_t ti = i;
406 if (!value.empty()) {
407 auto buf = reinterpret_cast<const char*>(value.data());
408 auto len = value.size();
409 ti += std::stoll(std::string(buf, len));
410 }
411 // Always seek to the end to write
412 file.seek(0, SEEK_END);
413 // File cursor is at the end of the file now, and we have an
414 // exclusive lock, so we can write the new value.
415 file.write(key);
416 file.write(std::to_string(ti));
417 return ti;
418 }
419
add(const std::string & key,int64_t value)420 int64_t FileStore::add(const std::string& key, int64_t value) {
421 std::string regKey = regularPrefix_ + key;
422 return addHelper(regKey, value);
423 }
424
getNumKeys()425 int64_t FileStore::getNumKeys() {
426 std::unique_lock<std::mutex> l(activeFileOpLock_);
427 File file(path_, O_RDONLY, timeout_);
428 auto lock = file.lockShared();
429 pos_ = refresh(file, pos_, cache_, deletePrefix_);
430 return static_cast<int64_t>(cache_.size());
431 }
432
deleteKey(const std::string & key)433 bool FileStore::deleteKey(const std::string& key) {
434 std::string deleteKey = deletePrefix_ + regularPrefix_ + key;
435 std::unique_lock<std::mutex> l(activeFileOpLock_);
436 File file(path_, O_RDWR, timeout_);
437 auto lock = file.lockExclusive();
438 file.seek(0, SEEK_END);
439 file.write(deleteKey);
440 file.write(std::vector<uint8_t>{});
441 return true;
442 }
443
check(const std::vector<std::string> & keys)444 bool FileStore::check(const std::vector<std::string>& keys) {
445 std::unique_lock<std::mutex> l(activeFileOpLock_);
446 File file(path_, O_RDONLY, timeout_);
447 auto lock = file.lockShared();
448 pos_ = refresh(file, pos_, cache_, deletePrefix_);
449
450 for (const auto& key : keys) {
451 std::string regKey = regularPrefix_ + key;
452 if (cache_.count(regKey) == 0) {
453 return false;
454 }
455 }
456
457 return true;
458 }
459
wait(const std::vector<std::string> & keys)460 void FileStore::wait(const std::vector<std::string>& keys) {
461 wait(keys, timeout_);
462 }
463
wait(const std::vector<std::string> & keys,const std::chrono::milliseconds & timeout)464 void FileStore::wait(
465 const std::vector<std::string>& keys,
466 const std::chrono::milliseconds& timeout) {
467 // Not using inotify because it doesn't work on many
468 // shared filesystems (such as NFS).
469 const auto start = std::chrono::steady_clock::now();
470 while (!check(keys)) {
471 const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
472 std::chrono::steady_clock::now() - start);
473 if (timeout != kNoTimeout && elapsed > timeout) {
474 TORCH_CHECK(false, "Wait timeout");
475 }
476
477 /* sleep override */
478 std::this_thread::sleep_for(std::chrono::milliseconds(10));
479 }
480 }
481
482 } // namespace c10d
483