1 #pragma once
2
3 #include <algorithm>
4 #include <cstddef>
5 #include <cstring>
6 #include <functional>
7 #include <iterator>
8 #include <limits>
9 #include <ostream>
10 #include <stdexcept>
11 #include <string>
12 #include <string_view>
13
14 #include <c10/macros/Macros.h>
15
16 namespace c10 {
17
/**
 * Port of std::string_view with methods from C++20.
 * Implemented following the interface definition in
 * https://en.cppreference.com/w/cpp/string/basic_string_view
 * See there for the API documentation.
 *
 * A basic_string_view stores only a pointer and a length; it never owns,
 * copies or frees the characters it refers to.
 *
 * Difference: We don't have a Traits template parameter because
 * std::char_traits isn't constexpr and we'd have to reimplement
 * std::char_traits if we wanted to use it with our constexpr basic_string_view.
 */
template <class CharT>
class basic_string_view final {
 public:
  using value_type = CharT;
  using pointer = CharT*;
  using const_pointer = const CharT*;
  using reference = CharT&;
  using const_reference = const CharT&;
  using const_iterator = const CharT*;
  using iterator = const_iterator;
  using const_reverse_iterator = std::reverse_iterator<const_iterator>;
  using reverse_iterator = const_reverse_iterator;
  using size_type = std::size_t;
  using difference_type = std::ptrdiff_t;

  // Largest size_type value. Returned by the search functions when nothing is
  // found and accepted as a count meaning "up to the end of the view".
  static constexpr size_type npos = size_type(-1);

  // Creates an empty view (null data pointer, size zero).
  constexpr basic_string_view() noexcept : begin_(nullptr) {}

  // Views exactly `count` characters starting at `str`.
  explicit constexpr basic_string_view(const_pointer str, size_type count)
      : begin_(str), size_(count) {}

  // Views a zero-terminated string; the length is found by scanning for the
  // terminator, so `str` must not be null.
  /* implicit */ constexpr basic_string_view(const_pointer str)
      : basic_string_view(str, strlen_(str)) {}

  // Views the current contents of a std::basic_string.
  /* implicit */ basic_string_view(const ::std::basic_string<CharT>& str)
      : basic_string_view(str.data(), str.size()) {}

  constexpr basic_string_view(const basic_string_view&) noexcept = default;

  constexpr basic_string_view& operator=(
      const basic_string_view& rhs) noexcept {
    begin_ = rhs.begin_;
    size_ = rhs.size_;
    return *this;
  }

  // Copies the viewed characters into a freshly constructed string.
  explicit operator ::std::basic_string<CharT>() const {
    return ::std::basic_string<CharT>(data(), size());
  }

  constexpr const_iterator begin() const noexcept {
    return cbegin();
  }

  constexpr const_iterator cbegin() const noexcept {
    return begin_;
  }

  constexpr const_iterator end() const noexcept {
    return cend();
  }

  constexpr const_iterator cend() const noexcept {
    return begin_ + size_;
  }

  constexpr const_reverse_iterator rbegin() const noexcept {
    return crbegin();
  }

  constexpr const_reverse_iterator crbegin() const noexcept {
    return const_reverse_iterator(this->end());
  }

  constexpr const_reverse_iterator rend() const noexcept {
    return crend();
  }

  constexpr const_reverse_iterator crend() const noexcept {
    return const_reverse_iterator(this->begin());
  }

  // Free begin()/end() overloads forwarding to the member functions.
  friend constexpr const_iterator begin(basic_string_view sv) noexcept {
    return sv.begin();
  }

  friend constexpr const_iterator end(basic_string_view sv) noexcept {
    return sv.end();
  }

  // Unchecked element access.
  constexpr const_reference operator[](size_type pos) const {
    // TODO: split out
    return at_(pos);
  }

  // Checked element access: throws std::out_of_range when pos >= size().
  // When compiling CUDA device code the check is compiled out.
  constexpr const_reference at(size_type pos) const {
#if !defined( \
    __CUDA_ARCH__) // CUDA doesn't like std::out_of_range in device code
    if (pos >= size_) {
      throw std::out_of_range(
          "string_view::operator[] or string_view::at() out of range. Index: " +
          std::to_string(pos) + ", size: " + std::to_string(size()));
    }
#endif
    return at_(pos);
  }

  // First character; the view must not be empty (not checked).
  constexpr const_reference front() const {
    return *begin_;
  }

  // Last character; the view must not be empty (not checked).
  constexpr const_reference back() const {
    return *(begin_ + size_ - 1);
  }

  constexpr const_pointer data() const noexcept {
    return begin_;
  }

  constexpr size_type size() const noexcept {
    return size_;
  }

  constexpr size_type length() const noexcept {
    return size();
  }

  constexpr size_type max_size() const noexcept {
    return std::numeric_limits<difference_type>::max();
  }

  [[nodiscard]] constexpr bool empty() const noexcept {
    return size() == 0;
  }

  // Drops the first n characters. Throws std::out_of_range if n > size().
  constexpr void remove_prefix(size_type n) {
    if (n > size()) {
      throw std::out_of_range(
          "basic_string_view::remove_prefix: out of range. PrefixLength: " +
          std::to_string(n) + ", size: " + std::to_string(size()));
    }
    begin_ += n;
    size_ -= n;
  }

  // Drops the last n characters. Throws std::out_of_range if n > size().
  constexpr void remove_suffix(size_type n) {
    if (n > size()) {
      throw std::out_of_range(
          "basic_string_view::remove_suffix: out of range. SuffixLength: " +
          std::to_string(n) + ", size: " + std::to_string(size()));
    }
    size_ -= n;
  }

  constexpr void swap(basic_string_view& sv) noexcept {
    auto tmp = *this;
    *this = sv;
    sv = tmp;
  }

  // Copies min(count, size() - pos) characters starting at pos into dest and
  // returns how many were copied. No terminator is written.
  // Throws std::out_of_range if pos > size().
  size_type copy(pointer dest, size_type count, size_type pos = 0) const {
    if (pos > size_) {
      throw std::out_of_range(
          "basic_string_view::copy: out of range. Index: " +
          std::to_string(pos) + ", size: " + std::to_string(size()));
    }
    const size_type copy_length = std::min(count, size_ - pos);
    std::copy_n(begin_ + pos, copy_length, dest);
    return copy_length;
  }

  // Returns the sub-view [pos, pos + min(count, size() - pos)).
  // Throws std::out_of_range if pos > size(); when compiling CUDA device code
  // the check is compiled out.
  constexpr basic_string_view substr(size_type pos = 0, size_type count = npos)
      const {
#if !defined( \
    __CUDA_ARCH__) // CUDA doesn't like std::out_of_range in device code
    if (pos > size_) {
      throw std::out_of_range(
          "basic_string_view::substr parameter out of bounds. Index: " +
          std::to_string(pos) + ", size: " + std::to_string(size()));
    }
#endif
    return substr_(pos, count);
  }

  // Three-way lexicographic comparison using operator< on CharT.
  // Returns -1, 0 or 1; a proper prefix orders before the longer view.
  constexpr int compare(basic_string_view rhs) const noexcept {
    // Write it iteratively. This is faster.
    const size_type common = std::min(size(), rhs.size());
    for (size_type i = 0; i < common; ++i) {
      if (at_(i) < rhs.at_(i)) {
        return -1;
      }
      if (at_(i) > rhs.at_(i)) {
        return 1;
      }
    }
    if (size() < rhs.size()) {
      return -1;
    }
    if (size() > rhs.size()) {
      return 1;
    }
    return 0;
  }

  constexpr int compare(size_type pos1, size_type count1, basic_string_view v)
      const {
    return substr(pos1, count1).compare(v);
  }

  constexpr int compare(
      size_type pos1,
      size_type count1,
      basic_string_view v,
      size_type pos2,
      size_type count2) const {
    return substr(pos1, count1).compare(v.substr(pos2, count2));
  }

  constexpr int compare(const_pointer s) const {
    return compare(basic_string_view(s));
  }

  constexpr int compare(size_type pos1, size_type count1, const_pointer s)
      const {
    return substr(pos1, count1).compare(basic_string_view(s));
  }

  constexpr int compare(
      size_type pos1,
      size_type count1,
      const_pointer s,
      size_type count2) const {
    return substr(pos1, count1).compare(basic_string_view(s, count2));
  }

  friend constexpr bool operator==(
      basic_string_view lhs,
      basic_string_view rhs) noexcept {
    return lhs.equals_(rhs);
  }

  friend constexpr bool operator!=(
      basic_string_view lhs,
      basic_string_view rhs) noexcept {
    return !(lhs == rhs);
  }

  friend constexpr bool operator<(
      basic_string_view lhs,
      basic_string_view rhs) noexcept {
    return lhs.compare(rhs) < 0;
  }

  friend constexpr bool operator>=(
      basic_string_view lhs,
      basic_string_view rhs) noexcept {
    return !(lhs < rhs);
  }

  friend constexpr bool operator>(
      basic_string_view lhs,
      basic_string_view rhs) noexcept {
    return rhs < lhs;
  }

  friend constexpr bool operator<=(
      basic_string_view lhs,
      basic_string_view rhs) noexcept {
    return !(lhs > rhs);
  }

  constexpr bool starts_with(basic_string_view prefix) const noexcept {
    return (prefix.size() > size()) ? false
                                    : prefix.equals_(substr_(0, prefix.size()));
  }

  constexpr bool starts_with(CharT prefix) const noexcept {
    return !empty() && prefix == front();
  }

  constexpr bool starts_with(const_pointer prefix) const {
    return starts_with(basic_string_view(prefix));
  }

  constexpr bool ends_with(basic_string_view suffix) const noexcept {
    return (suffix.size() > size())
        ? false
        : suffix.equals_(substr_(size() - suffix.size(), suffix.size()));
  }

  constexpr bool ends_with(CharT suffix) const noexcept {
    return !empty() && suffix == back();
  }

  constexpr bool ends_with(const_pointer suffix) const {
    return ends_with(basic_string_view(suffix));
  }

  // Index of the first occurrence of v that starts at or after pos, or npos.
  // An empty v is found at pos as long as pos <= size().
  constexpr size_type find(basic_string_view v, size_type pos = 0)
      const noexcept {
    if (v.size() == 0) {
      return pos <= size() ? pos : npos;
    }

    // The bounds check is phrased so that no unsigned arithmetic can wrap.
    // `pos + v.size() <= size()` wraps for pos close to npos, and when v is
    // also longer than this view, `size() - v.size()` wraps as well, which
    // would let the loop below read outside the view.
    if (v.size() <= size() && pos <= size() - v.size()) {
      for (size_type cur = pos, end = size() - v.size(); cur <= end; ++cur) {
        if (v.at_(0) == at_(cur) &&
            v.substr_(1).equals_(substr_(cur + 1, v.size() - 1))) {
          return cur;
        }
      }
    }
    return npos;
  }

  constexpr size_type find(CharT ch, size_type pos = 0) const noexcept {
    return find_first_if_(pos, charIsEqual_{ch});
  }

  constexpr size_type find(const_pointer s, size_type pos, size_type count)
      const {
    return find(basic_string_view(s, count), pos);
  }

  constexpr size_type find(const_pointer s, size_type pos = 0) const {
    return find(basic_string_view(s), pos);
  }

  // Index of the last occurrence of v that starts at or before pos, or npos.
  // An empty v is found at min(pos, size()).
  constexpr size_type rfind(basic_string_view v, size_type pos = npos)
      const noexcept {
    // Write it iteratively. This is faster.
    if (v.size() == 0) {
      return pos <= size() ? pos : size();
    }

    if (v.size() <= size()) {
      pos = std::min(size() - v.size(), pos);
      do {
        if (v.at_(0) == at_(pos) &&
            v.substr_(1).equals_(substr_(pos + 1, v.size() - 1))) {
          return pos;
        }
      } while (pos-- > 0);
    }
    return npos;
  }

  constexpr size_type rfind(CharT ch, size_type pos = npos) const noexcept {
    return find_last_if_(pos, charIsEqual_{ch});
  }

  constexpr size_type rfind(const_pointer s, size_type pos, size_type count)
      const {
    return rfind(basic_string_view(s, count), pos);
  }

  constexpr size_type rfind(const_pointer s, size_type pos = npos) const {
    return rfind(basic_string_view(s), pos);
  }

  constexpr size_type find_first_of(basic_string_view v, size_type pos = 0)
      const noexcept {
    return find_first_if_(pos, stringViewContainsChar_{v});
  }

  constexpr size_type find_first_of(CharT ch, size_type pos = 0)
      const noexcept {
    return find_first_if_(pos, charIsEqual_{ch});
  }

  constexpr size_type find_first_of(
      const_pointer s,
      size_type pos,
      size_type count) const {
    return find_first_of(basic_string_view(s, count), pos);
  }

  constexpr size_type find_first_of(const_pointer s, size_type pos = 0) const {
    return find_first_of(basic_string_view(s), pos);
  }

  constexpr size_type find_last_of(basic_string_view v, size_type pos = npos)
      const noexcept {
    return find_last_if_(pos, stringViewContainsChar_{v});
  }

  constexpr size_type find_last_of(CharT ch, size_type pos = npos)
      const noexcept {
    return find_last_if_(pos, charIsEqual_{ch});
  }

  constexpr size_type find_last_of(
      const_pointer s,
      size_type pos,
      size_type count) const {
    return find_last_of(basic_string_view(s, count), pos);
  }

  constexpr size_type find_last_of(const_pointer s, size_type pos = npos)
      const {
    return find_last_of(basic_string_view(s), pos);
  }

  constexpr size_type find_first_not_of(basic_string_view v, size_type pos = 0)
      const noexcept {
    return find_first_if_(pos, stringViewDoesNotContainChar_{v});
  }

  constexpr size_type find_first_not_of(CharT ch, size_type pos = 0)
      const noexcept {
    return find_first_if_(pos, charIsNotEqual_{ch});
  }

  constexpr size_type find_first_not_of(
      const_pointer s,
      size_type pos,
      size_type count) const {
    return find_first_not_of(basic_string_view(s, count), pos);
  }

  constexpr size_type find_first_not_of(const_pointer s, size_type pos = 0)
      const {
    return find_first_not_of(basic_string_view(s), pos);
  }

  constexpr size_type find_last_not_of(
      basic_string_view v,
      size_type pos = npos) const noexcept {
    return find_last_if_(pos, stringViewDoesNotContainChar_{v});
  }

  constexpr size_type find_last_not_of(CharT ch, size_type pos = npos)
      const noexcept {
    return find_last_if_(pos, charIsNotEqual_{ch});
  }

  constexpr size_type find_last_not_of(
      const_pointer s,
      size_type pos,
      size_type count) const {
    return find_last_not_of(basic_string_view(s, count), pos);
  }

  constexpr size_type find_last_not_of(const_pointer s, size_type pos = npos)
      const {
    return find_last_not_of(basic_string_view(s), pos);
  }

 private:
  // constexpr replacement for strlen: counts characters before the first
  // zero character.
  static constexpr size_type strlen_(const_pointer str) noexcept {
    const_pointer current = str;
    while (*current != '\0') {
      ++current;
    }
    return current - str;
  }

  // Unchecked element access shared by the public accessors.
  constexpr const_reference at_(size_type pos) const noexcept {
    return *(begin_ + pos);
  }

  // Unchecked substr: the caller must guarantee pos <= size(), otherwise
  // `size() - pos` wraps around.
  constexpr basic_string_view substr_(size_type pos = 0, size_type count = npos)
      const {
    return basic_string_view{begin_ + pos, std::min(count, size() - pos)};
  }

  // Index of the first character at or after pos for which condition(ch) is
  // true, or npos. The loop condition alone rejects every pos >= size(), so
  // no separate (and potentially wrapping) `pos + 1` guard is needed.
  template <class Condition>
  // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
  constexpr size_type find_first_if_(size_type pos, Condition&& condition)
      const noexcept {
    for (size_type cur = pos; cur < size(); ++cur) {
      if (condition(at_(cur))) {
        return cur;
      }
    }
    return npos;
  }

  // Index of the last character at or before pos for which condition(ch) is
  // true, or npos. pos is clamped to the last valid index.
  template <class Condition>
  // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
  constexpr size_type find_last_if_(size_type pos, Condition&& condition)
      const noexcept {
    // Write it iteratively. This is faster.
    if (size() > 0) {
      pos = std::min(size() - 1, pos);
      do {
        if (condition(at_(pos))) {
          return pos;
        }
      } while (pos-- > 0);
    }
    return npos;
  }

  // True when both views have the same length and the same characters.
  constexpr bool equals_(basic_string_view rhs) const {
    // We don't use string_view::compare() here but implement it manually
    // because only looking at equality allows for more optimized code.
    if (size() != rhs.size()) {
      return false;
    }
#if defined(__GNUC__) && !defined(__CUDACC__)
    // memcmp takes a length in bytes, so the character count has to be
    // scaled by sizeof(CharT); otherwise only a prefix of a wide-character
    // view would be compared. Empty views return early so that a null data()
    // pointer is never handed to memcmp.
    return size() == 0 ||
        0 == __builtin_memcmp(data(), rhs.data(), size() * sizeof(CharT));
#else
    // Yes, memcmp would be faster than this loop, but memcmp isn't constexpr
    // and I didn't feel like implementing a constexpr memcmp variant.
    // TODO At some point this should probably be done, including tricks
    // like comparing one machine word instead of a byte per iteration.
    for (size_type pos = 0; pos < size(); ++pos) {
      if (at_(pos) != rhs.at_(pos)) {
        return false;
      }
    }
    return true;
#endif
  }

  // Predicates handed to find_first_if_ / find_last_if_.

  struct charIsEqual_ final {
    CharT expected;
    constexpr bool operator()(CharT actual) const noexcept {
      return expected == actual;
    }
  };

  struct charIsNotEqual_ final {
    CharT expected;
    constexpr bool operator()(CharT actual) const noexcept {
      return expected != actual;
    }
  };

  struct stringViewContainsChar_ final {
    basic_string_view expected;
    constexpr bool operator()(CharT ch) const noexcept {
      return npos != expected.find(ch);
    }
  };

  struct stringViewDoesNotContainChar_ final {
    basic_string_view expected;
    constexpr bool operator()(CharT ch) const noexcept {
      return npos == expected.find(ch);
    }
  };

  const_pointer begin_; // first viewed character (null for a default view)
  size_type size_{}; // number of viewed characters
};
573
574 template <class CharT>
575 inline std::basic_ostream<CharT>& operator<<(
576 std::basic_ostream<CharT>& stream,
577 basic_string_view<CharT> sv) {
578 // The rules for operator<< are quite complex, so lets defer to the
579 // STL implementation.
580 using std_string_type = ::std::basic_string_view<CharT>;
581 return stream << std_string_type(sv.data(), sv.size());
582 }
583
584 template <class CharT>
swap(basic_string_view<CharT> & lhs,basic_string_view<CharT> & rhs)585 constexpr inline void swap(
586 basic_string_view<CharT>& lhs,
587 basic_string_view<CharT>& rhs) noexcept {
588 lhs.swap(rhs);
589 }
590
// Convenience alias: a view over `char` characters.
using string_view = basic_string_view<char>;
592
593 } // namespace c10
594
595 namespace std {
596 template <class CharT>
597 struct hash<::c10::basic_string_view<CharT>> {
598 size_t operator()(::c10::basic_string_view<CharT> x) const {
599 // The standard says that std::string_view hashing must do the same as
600 // std::string hashing but leaves the details of std::string hashing
601 // up to the implementer. So, to be conformant, we need to re-use and
602 // existing STL type's hash function. The std::string fallback is probably
603 // slow but the only way to be conformant.
604
605 using std_string_type = ::std::basic_string_view<CharT>;
606 return ::std::hash<std_string_type>{}(std_string_type(x.data(), x.size()));
607 }
608 };
609 } // namespace std
610