1 #pragma once
2
3 #include <ATen/ATen.h>
4 #include <ATen/Utils.h>
5 #include <c10/util/Exception.h>
6 #include <torch/csrc/Export.h>
7
8 #ifdef _WIN32
9 #include <WinError.h>
10 #include <c10/util/Unicode.h>
11 #include <c10/util/win32-headers.h>
12 #include <fcntl.h>
13 #include <io.h>
14 #include <process.h>
15 #include <stdio.h>
16 #include <sys/stat.h>
17 #include <random>
18 #else
19 #include <unistd.h>
20 #endif
21
22 #include <string>
23 #include <vector>
24
25 namespace torch {
26 namespace jit {
27 namespace fuser {
28 namespace cpu {
29
30 #ifdef _MSC_VER
wmkstemps(wchar_t * tmpl,int suffix_len)31 int wmkstemps(wchar_t* tmpl, int suffix_len) {
32 int len;
33 wchar_t* name;
34 int fd = -1;
35 int save_errno = errno;
36
37 len = wcslen(tmpl);
38 if (len < 6 + suffix_len ||
39 wcsncmp(&tmpl[len - 6 - suffix_len], L"XXXXXX", 6)) {
40 return -1;
41 }
42
43 name = &tmpl[len - 6 - suffix_len];
44
45 std::random_device rd;
46 do {
47 for (unsigned i = 0; i < 6; ++i) {
48 name[i] = "abcdefghijklmnopqrstuvwxyz0123456789"[rd() % 36];
49 }
50
51 fd = _wopen(tmpl, _O_RDWR | _O_CREAT | _O_EXCL, _S_IWRITE | _S_IREAD);
52 } while (errno == EEXIST);
53
54 if (fd >= 0) {
55 errno = save_errno;
56 return fd;
57 } else {
58 return -1;
59 }
60 }
61 #endif
62
63 struct TempFile {
64 AT_DISALLOW_COPY_AND_ASSIGN(TempFile);
65
TempFileTempFile66 TempFile(const std::string& t, int suffix) {
67 #ifdef _MSC_VER
68 auto wt = c10::u8u16(t);
69 std::vector<wchar_t> tt(wt.c_str(), wt.c_str() + wt.size() + 1);
70 int fd = wmkstemps(tt.data(), suffix);
71 AT_ASSERT(fd != -1);
72 file_ = _wfdopen(fd, L"r+");
73 auto wname = std::wstring(tt.begin(), tt.end() - 1);
74 name_ = c10::u16u8(wname);
75 #else
76 // mkstemps edits its first argument in places
77 // so we make a copy of the string here, including null terminator
78 std::vector<char> tt(t.c_str(), t.c_str() + t.size() + 1);
79 int fd = mkstemps(tt.data(), suffix);
80 AT_ASSERT(fd != -1);
81 file_ = fdopen(fd, "r+");
82 // - 1 because tt.size() includes the null terminator,
83 // but std::string does not expect one
84 name_ = std::string(tt.begin(), tt.end() - 1);
85 #endif
86 }
87
nameTempFile88 const std::string& name() const {
89 return name_;
90 }
91
syncTempFile92 void sync() {
93 fflush(file_);
94 }
95
writeTempFile96 void write(const std::string& str) {
97 size_t result = fwrite(str.c_str(), 1, str.size(), file_);
98 AT_ASSERT(str.size() == result);
99 }
100
101 #ifdef _MSC_VER
closeTempFile102 void close() {
103 if (file_ != nullptr) {
104 fclose(file_);
105 }
106 file_ = nullptr;
107 }
108 #endif
109
fileTempFile110 FILE* file() {
111 return file_;
112 }
113
~TempFileTempFile114 ~TempFile() {
115 #ifdef _MSC_VER
116 if (file_ != nullptr) {
117 fclose(file_);
118 }
119 auto wname = c10::u8u16(name_);
120 if (!wname.empty() && _waccess(wname.c_str(), 0) != -1) {
121 _wunlink(wname.c_str());
122 }
123 #else
124 if (file_ != nullptr) {
125 // unlink first to ensure another mkstemps doesn't
126 // race between close and unlink
127 unlink(name_.c_str());
128 fclose(file_);
129 }
130 #endif
131 }
132
133 private:
134 FILE* file_ = nullptr;
135 std::string name_;
136 };
137
138 } // namespace cpu
139 } // namespace fuser
140 } // namespace jit
141 } // namespace torch
142