xref: /aosp_15_r20/external/pytorch/torch/csrc/jit/codegen/fuser/cpu/temp_file.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/ATen.h>
4 #include <ATen/Utils.h>
5 #include <c10/util/Exception.h>
6 #include <torch/csrc/Export.h>
7 
8 #ifdef _WIN32
9 #include <WinError.h>
10 #include <c10/util/Unicode.h>
11 #include <c10/util/win32-headers.h>
12 #include <fcntl.h>
13 #include <io.h>
14 #include <process.h>
15 #include <stdio.h>
16 #include <sys/stat.h>
17 #include <random>
18 #else
19 #include <unistd.h>
20 #endif
21 
22 #include <string>
23 #include <vector>
24 
25 namespace torch {
26 namespace jit {
27 namespace fuser {
28 namespace cpu {
29 
30 #ifdef _MSC_VER
wmkstemps(wchar_t * tmpl,int suffix_len)31 int wmkstemps(wchar_t* tmpl, int suffix_len) {
32   int len;
33   wchar_t* name;
34   int fd = -1;
35   int save_errno = errno;
36 
37   len = wcslen(tmpl);
38   if (len < 6 + suffix_len ||
39       wcsncmp(&tmpl[len - 6 - suffix_len], L"XXXXXX", 6)) {
40     return -1;
41   }
42 
43   name = &tmpl[len - 6 - suffix_len];
44 
45   std::random_device rd;
46   do {
47     for (unsigned i = 0; i < 6; ++i) {
48       name[i] = "abcdefghijklmnopqrstuvwxyz0123456789"[rd() % 36];
49     }
50 
51     fd = _wopen(tmpl, _O_RDWR | _O_CREAT | _O_EXCL, _S_IWRITE | _S_IREAD);
52   } while (errno == EEXIST);
53 
54   if (fd >= 0) {
55     errno = save_errno;
56     return fd;
57   } else {
58     return -1;
59   }
60 }
61 #endif
62 
63 struct TempFile {
64   AT_DISALLOW_COPY_AND_ASSIGN(TempFile);
65 
TempFileTempFile66   TempFile(const std::string& t, int suffix) {
67 #ifdef _MSC_VER
68     auto wt = c10::u8u16(t);
69     std::vector<wchar_t> tt(wt.c_str(), wt.c_str() + wt.size() + 1);
70     int fd = wmkstemps(tt.data(), suffix);
71     AT_ASSERT(fd != -1);
72     file_ = _wfdopen(fd, L"r+");
73     auto wname = std::wstring(tt.begin(), tt.end() - 1);
74     name_ = c10::u16u8(wname);
75 #else
76     // mkstemps edits its first argument in places
77     // so we make a copy of the string here, including null terminator
78     std::vector<char> tt(t.c_str(), t.c_str() + t.size() + 1);
79     int fd = mkstemps(tt.data(), suffix);
80     AT_ASSERT(fd != -1);
81     file_ = fdopen(fd, "r+");
82     // - 1 because tt.size() includes the null terminator,
83     // but std::string does not expect one
84     name_ = std::string(tt.begin(), tt.end() - 1);
85 #endif
86   }
87 
nameTempFile88   const std::string& name() const {
89     return name_;
90   }
91 
syncTempFile92   void sync() {
93     fflush(file_);
94   }
95 
writeTempFile96   void write(const std::string& str) {
97     size_t result = fwrite(str.c_str(), 1, str.size(), file_);
98     AT_ASSERT(str.size() == result);
99   }
100 
101 #ifdef _MSC_VER
closeTempFile102   void close() {
103     if (file_ != nullptr) {
104       fclose(file_);
105     }
106     file_ = nullptr;
107   }
108 #endif
109 
fileTempFile110   FILE* file() {
111     return file_;
112   }
113 
~TempFileTempFile114   ~TempFile() {
115 #ifdef _MSC_VER
116     if (file_ != nullptr) {
117       fclose(file_);
118     }
119     auto wname = c10::u8u16(name_);
120     if (!wname.empty() && _waccess(wname.c_str(), 0) != -1) {
121       _wunlink(wname.c_str());
122     }
123 #else
124     if (file_ != nullptr) {
125       // unlink first to ensure another mkstemps doesn't
126       // race between close and unlink
127       unlink(name_.c_str());
128       fclose(file_);
129     }
130 #endif
131   }
132 
133  private:
134   FILE* file_ = nullptr;
135   std::string name_;
136 };
137 
138 } // namespace cpu
139 } // namespace fuser
140 } // namespace jit
141 } // namespace torch
142