xref: /aosp_15_r20/external/pytorch/c10/util/string_view.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <algorithm>
4 #include <cstddef>
5 #include <cstring>
6 #include <functional>
7 #include <iterator>
8 #include <limits>
9 #include <ostream>
10 #include <stdexcept>
11 #include <string>
12 #include <string_view>
13 
14 #include <c10/macros/Macros.h>
15 
16 namespace c10 {
17 
/**
 * Port of std::string_view with methods from C++20.
 * Implemented following the interface definition in
 * https://en.cppreference.com/w/cpp/string/basic_string_view
 * See there for the API documentation.
 *
 * Difference: We don't have a Traits template parameter because
 * std::char_traits isn't constexpr and we'd have to reimplement
 * std::char_traits if we wanted to use it with our constexpr basic_string_view.
 *
 * The view stores a pointer to the first character and a character count.
 * It never owns or copies the characters it refers to.
 */
template <class CharT>
class basic_string_view final {
 public:
  using value_type = CharT;
  using pointer = CharT*;
  using const_pointer = const CharT*;
  using reference = CharT&;
  using const_reference = const CharT&;
  using const_iterator = const CharT*;
  using iterator = const_iterator;
  using const_reverse_iterator = std::reverse_iterator<const_iterator>;
  using reverse_iterator = const_reverse_iterator;
  using size_type = std::size_t;
  using difference_type = std::ptrdiff_t;

  // Sentinel returned by the search functions when nothing was found; also
  // used as the "until the end" / "from the end" default argument.
  static constexpr size_type npos = size_type(-1);

  // Creates an empty view (null data pointer, size 0).
  constexpr basic_string_view() noexcept : begin_(nullptr) {}

  // Views exactly `count` characters starting at `str`.
  explicit constexpr basic_string_view(const_pointer str, size_type count)
      : begin_(str), size_(count) {}

  // Views a zero-terminated string; the terminator is not part of the view.
  /* implicit */ constexpr basic_string_view(const_pointer str)
      : basic_string_view(str, strlen_(str)) {}

  // Views the current contents of a std::basic_string.
  /* implicit */ basic_string_view(const ::std::basic_string<CharT>& str)
      : basic_string_view(str.data(), str.size()) {}

  constexpr basic_string_view(const basic_string_view&) noexcept = default;

  constexpr basic_string_view& operator=(
      const basic_string_view& rhs) noexcept {
    begin_ = rhs.begin_;
    size_ = rhs.size_;
    return *this;
  }

  // Copies the viewed characters into a newly constructed std::basic_string.
  explicit operator ::std::basic_string<CharT>() const {
    return ::std::basic_string<CharT>(data(), size());
  }

  // ---- Iteration -----------------------------------------------------------

  constexpr const_iterator begin() const noexcept {
    return cbegin();
  }

  constexpr const_iterator cbegin() const noexcept {
    return begin_;
  }

  constexpr const_iterator end() const noexcept {
    return cend();
  }

  constexpr const_iterator cend() const noexcept {
    return begin_ + size_;
  }

  constexpr const_reverse_iterator rbegin() const noexcept {
    return crbegin();
  }

  constexpr const_reverse_iterator crbegin() const noexcept {
    return const_reverse_iterator(this->end());
  }

  constexpr const_reverse_iterator rend() const noexcept {
    return crend();
  }

  constexpr const_reverse_iterator crend() const noexcept {
    return const_reverse_iterator(this->begin());
  }

  friend constexpr const_iterator begin(basic_string_view sv) noexcept {
    return sv.begin();
  }

  friend constexpr const_iterator end(basic_string_view sv) noexcept {
    return sv.end();
  }

  // ---- Element access ------------------------------------------------------

  // Unchecked element access.
  constexpr const_reference operator[](size_type pos) const {
    // TODO: split out
    return at_(pos);
  }

  // Element access that throws std::out_of_range when pos >= size().
  // In CUDA device code the check is compiled out and access is unchecked.
  constexpr const_reference at(size_type pos) const {
#if !defined(__CUDA_ARCH__) // CUDA doesn't like std::out_of_range in device code
    if (pos >= size_) {
      throw std::out_of_range(
          "string_view::operator[] or string_view::at() out of range. Index: " +
          std::to_string(pos) + ", size: " + std::to_string(size()));
    }
#endif
    return at_(pos);
  }

  // First character; performs no emptiness check.
  constexpr const_reference front() const {
    return *begin_;
  }

  // Last character; performs no emptiness check.
  constexpr const_reference back() const {
    return *(begin_ + size_ - 1);
  }

  constexpr const_pointer data() const noexcept {
    return begin_;
  }

  // ---- Capacity ------------------------------------------------------------

  constexpr size_type size() const noexcept {
    return size_;
  }

  constexpr size_type length() const noexcept {
    return size();
  }

  constexpr size_type max_size() const noexcept {
    return std::numeric_limits<difference_type>::max();
  }

  [[nodiscard]] constexpr bool empty() const noexcept {
    return size() == 0;
  }

  // ---- Modifiers -----------------------------------------------------------

  // Drops the first n characters. Throws std::out_of_range if n > size().
  constexpr void remove_prefix(size_type n) {
    if (n > size()) {
      throw std::out_of_range(
          "basic_string_view::remove_prefix: out of range. PrefixLength: " +
          std::to_string(n) + ", size: " + std::to_string(size()));
    }
    begin_ += n;
    size_ -= n;
  }

  // Drops the last n characters. Throws std::out_of_range if n > size().
  constexpr void remove_suffix(size_type n) {
    if (n > size()) {
      throw std::out_of_range(
          "basic_string_view::remove_suffix: out of range. SuffixLength: " +
          std::to_string(n) + ", size: " + std::to_string(size()));
    }
    size_ -= n;
  }

  constexpr void swap(basic_string_view& sv) noexcept {
    auto tmp = *this;
    *this = sv;
    sv = tmp;
  }

  // ---- Operations ----------------------------------------------------------

  // Copies min(count, size() - pos) characters starting at pos into dest and
  // returns the number of characters copied.
  // Throws std::out_of_range if pos > size().
  size_type copy(pointer dest, size_type count, size_type pos = 0) const {
    if (pos > size_) {
      throw std::out_of_range(
          "basic_string_view::copy: out of range. Index: " +
          std::to_string(pos) + ", size: " + std::to_string(size()));
    }
    const size_type copy_length = std::min(count, size_ - pos);
    std::copy_n(begin_ + pos, copy_length, dest);
    return copy_length;
  }

  // Returns the sub-view [pos, pos + min(count, size() - pos)).
  // Throws std::out_of_range if pos > size(); in CUDA device code the check
  // is compiled out.
  constexpr basic_string_view substr(size_type pos = 0, size_type count = npos)
      const {
#if !defined(__CUDA_ARCH__) // CUDA doesn't like std::out_of_range in device code
    if (pos > size_) {
      throw std::out_of_range(
          "basic_string_view::substr parameter out of bounds. Index: " +
          std::to_string(pos) + ", size: " + std::to_string(size()));
    }
#endif
    return substr_(pos, count);
  }

  // Lexicographic three-way comparison: negative, zero or positive when this
  // view orders before, equal to, or after rhs.
  constexpr int compare(basic_string_view rhs) const noexcept {
    // Write it iteratively. This is faster.
    const size_type common = std::min(size(), rhs.size());
    for (size_type i = 0; i < common; ++i) {
      if (at_(i) < rhs.at_(i)) {
        return -1;
      }
      if (at_(i) > rhs.at_(i)) {
        return 1;
      }
    }
    // Equal over the common prefix: the shorter view orders first.
    if (size() < rhs.size()) {
      return -1;
    }
    if (size() > rhs.size()) {
      return 1;
    }
    return 0;
  }

  constexpr int compare(size_type pos1, size_type count1, basic_string_view v)
      const {
    return substr(pos1, count1).compare(v);
  }

  constexpr int compare(
      size_type pos1,
      size_type count1,
      basic_string_view v,
      size_type pos2,
      size_type count2) const {
    return substr(pos1, count1).compare(v.substr(pos2, count2));
  }

  constexpr int compare(const_pointer s) const {
    return compare(basic_string_view(s));
  }

  constexpr int compare(size_type pos1, size_type count1, const_pointer s)
      const {
    return substr(pos1, count1).compare(basic_string_view(s));
  }

  constexpr int compare(
      size_type pos1,
      size_type count1,
      const_pointer s,
      size_type count2) const {
    return substr(pos1, count1).compare(basic_string_view(s, count2));
  }

  friend constexpr bool operator==(
      basic_string_view lhs,
      basic_string_view rhs) noexcept {
    return lhs.equals_(rhs);
  }

  friend constexpr bool operator!=(
      basic_string_view lhs,
      basic_string_view rhs) noexcept {
    return !(lhs == rhs);
  }

  friend constexpr bool operator<(
      basic_string_view lhs,
      basic_string_view rhs) noexcept {
    return lhs.compare(rhs) < 0;
  }

  friend constexpr bool operator>=(
      basic_string_view lhs,
      basic_string_view rhs) noexcept {
    return !(lhs < rhs);
  }

  friend constexpr bool operator>(
      basic_string_view lhs,
      basic_string_view rhs) noexcept {
    return rhs < lhs;
  }

  friend constexpr bool operator<=(
      basic_string_view lhs,
      basic_string_view rhs) noexcept {
    return !(lhs > rhs);
  }

  constexpr bool starts_with(basic_string_view prefix) const noexcept {
    if (prefix.size() > size()) {
      return false;
    }
    return prefix.equals_(substr_(0, prefix.size()));
  }

  constexpr bool starts_with(CharT prefix) const noexcept {
    return !empty() && prefix == front();
  }

  constexpr bool starts_with(const_pointer prefix) const {
    return starts_with(basic_string_view(prefix));
  }

  constexpr bool ends_with(basic_string_view suffix) const noexcept {
    if (suffix.size() > size()) {
      return false;
    }
    return suffix.equals_(substr_(size() - suffix.size(), suffix.size()));
  }

  constexpr bool ends_with(CharT suffix) const noexcept {
    return !empty() && suffix == back();
  }

  constexpr bool ends_with(const_pointer suffix) const {
    return ends_with(basic_string_view(suffix));
  }

  // ---- Searching -----------------------------------------------------------

  // Index of the first occurrence of v that starts at or after pos, or npos.
  // An empty v matches at pos as long as pos <= size().
  constexpr size_type find(basic_string_view v, size_type pos = 0)
      const noexcept {
    if (v.size() == 0) {
      return pos <= size() ? pos : npos;
    }

    // Compare by subtraction rather than `pos + v.size() <= size()`: the sum
    // wraps around for pos close to npos, and `size() - v.size()` must not be
    // evaluated when the needle is longer than this view.
    if (v.size() > size() || pos > size() - v.size()) {
      return npos;
    }

    const size_type last_start = size() - v.size();
    const basic_string_view needle_tail = v.substr_(1);
    for (size_type cur = pos; cur <= last_start; ++cur) {
      // Cheap first-character test before comparing the rest of the needle.
      if (v.at_(0) == at_(cur) &&
          needle_tail.equals_(substr_(cur + 1, v.size() - 1))) {
        return cur;
      }
    }
    return npos;
  }

  constexpr size_type find(CharT ch, size_type pos = 0) const noexcept {
    return find_first_if_(pos, charIsEqual_{ch});
  }

  constexpr size_type find(const_pointer s, size_type pos, size_type count)
      const {
    return find(basic_string_view(s, count), pos);
  }

  constexpr size_type find(const_pointer s, size_type pos = 0) const {
    return find(basic_string_view(s), pos);
  }

  // Index of the last occurrence of v that starts at or before pos, or npos.
  // An empty v matches at min(pos, size()).
  constexpr size_type rfind(basic_string_view v, size_type pos = npos)
      const noexcept {
    // Write it iteratively. This is faster.
    if (v.size() == 0) {
      return pos <= size() ? pos : size();
    }

    if (v.size() <= size()) {
      // Clamp to the last index at which v still fits.
      pos = std::min(size() - v.size(), pos);
      do {
        if (v.at_(0) == at_(pos) &&
            v.substr_(1).equals_(substr_(pos + 1, v.size() - 1))) {
          return pos;
        }
      } while (pos-- > 0);
    }
    return npos;
  }

  constexpr size_type rfind(CharT ch, size_type pos = npos) const noexcept {
    return find_last_if_(pos, charIsEqual_{ch});
  }

  constexpr size_type rfind(const_pointer s, size_type pos, size_type count)
      const {
    return rfind(basic_string_view(s, count), pos);
  }

  constexpr size_type rfind(const_pointer s, size_type pos = npos) const {
    return rfind(basic_string_view(s), pos);
  }

  // First index >= pos whose character occurs in v, or npos.
  constexpr size_type find_first_of(basic_string_view v, size_type pos = 0)
      const noexcept {
    return find_first_if_(pos, stringViewContainsChar_{v});
  }

  constexpr size_type find_first_of(CharT ch, size_type pos = 0)
      const noexcept {
    return find_first_if_(pos, charIsEqual_{ch});
  }

  constexpr size_type find_first_of(
      const_pointer s,
      size_type pos,
      size_type count) const {
    return find_first_of(basic_string_view(s, count), pos);
  }

  constexpr size_type find_first_of(const_pointer s, size_type pos = 0) const {
    return find_first_of(basic_string_view(s), pos);
  }

  // Last index <= pos whose character occurs in v, or npos.
  constexpr size_type find_last_of(basic_string_view v, size_type pos = npos)
      const noexcept {
    return find_last_if_(pos, stringViewContainsChar_{v});
  }

  constexpr size_type find_last_of(CharT ch, size_type pos = npos)
      const noexcept {
    return find_last_if_(pos, charIsEqual_{ch});
  }

  constexpr size_type find_last_of(
      const_pointer s,
      size_type pos,
      size_type count) const {
    return find_last_of(basic_string_view(s, count), pos);
  }

  constexpr size_type find_last_of(const_pointer s, size_type pos = npos)
      const {
    return find_last_of(basic_string_view(s), pos);
  }

  // First index >= pos whose character does not occur in v, or npos.
  constexpr size_type find_first_not_of(basic_string_view v, size_type pos = 0)
      const noexcept {
    return find_first_if_(pos, stringViewDoesNotContainChar_{v});
  }

  constexpr size_type find_first_not_of(CharT ch, size_type pos = 0)
      const noexcept {
    return find_first_if_(pos, charIsNotEqual_{ch});
  }

  constexpr size_type find_first_not_of(
      const_pointer s,
      size_type pos,
      size_type count) const {
    return find_first_not_of(basic_string_view(s, count), pos);
  }

  constexpr size_type find_first_not_of(const_pointer s, size_type pos = 0)
      const {
    return find_first_not_of(basic_string_view(s), pos);
  }

  // Last index <= pos whose character does not occur in v, or npos.
  constexpr size_type find_last_not_of(
      basic_string_view v,
      size_type pos = npos) const noexcept {
    return find_last_if_(pos, stringViewDoesNotContainChar_{v});
  }

  constexpr size_type find_last_not_of(CharT ch, size_type pos = npos)
      const noexcept {
    return find_last_if_(pos, charIsNotEqual_{ch});
  }

  constexpr size_type find_last_not_of(
      const_pointer s,
      size_type pos,
      size_type count) const {
    return find_last_not_of(basic_string_view(s, count), pos);
  }

  constexpr size_type find_last_not_of(const_pointer s, size_type pos = npos)
      const {
    return find_last_not_of(basic_string_view(s), pos);
  }

 private:
  // constexpr strlen: number of characters before the first zero character.
  static constexpr size_type strlen_(const_pointer str) noexcept {
    const_pointer current = str;
    while (*current != '\0') {
      ++current;
    }
    return static_cast<size_type>(current - str);
  }

  // Unchecked element access shared by operator[] and at().
  constexpr const_reference at_(size_type pos) const noexcept {
    return *(begin_ + pos);
  }

  // Unchecked substr(): the caller guarantees pos <= size().
  constexpr basic_string_view substr_(size_type pos = 0, size_type count = npos)
      const {
    return basic_string_view{begin_ + pos, std::min(count, size() - pos)};
  }

  // Index of the first character at or after pos that satisfies condition,
  // or npos if there is none.
  template <class Condition>
  // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
  constexpr size_type find_first_if_(size_type pos, Condition&& condition)
      const noexcept {
    // The loop bound alone rejects pos >= size(), including pos == npos.
    for (size_type cur = pos; cur < size(); ++cur) {
      if (condition(at_(cur))) {
        return cur;
      }
    }
    return npos;
  }

  // Index of the last character at or before pos that satisfies condition,
  // or npos if there is none.
  template <class Condition>
  // NOLINTNEXTLINE(cppcoreguidelines-missing-std-forward)
  constexpr size_type find_last_if_(size_type pos, Condition&& condition)
      const noexcept {
    // Write it iteratively. This is faster.
    if (size() > 0) {
      pos = std::min(size() - 1, pos);
      do {
        if (condition(at_(pos))) {
          return pos;
        }
      } while (pos-- > 0);
    }
    return npos;
  }

  // True when both views have the same length and the same characters.
  constexpr bool equals_(basic_string_view rhs) const {
    // We don't use string_view::compare() here but implement it manually
    // because only looking at equality allows for more optimized code.
#if defined(__GNUC__) && !defined(__CUDACC__)
    // memcmp takes a length in bytes, so scale the element count by
    // sizeof(CharT); otherwise wide character views would only have a prefix
    // compared. Empty views are handled without calling memcmp so that it is
    // never handed the null data() of a default-constructed view.
    return size() == rhs.size() &&
        (size() == 0 ||
         0 == __builtin_memcmp(data(), rhs.data(), size() * sizeof(CharT)));
#else
    if (size() != rhs.size()) {
      return false;
    }
    // Yes, memcmp would be faster than this loop, but memcmp isn't constexpr
    // and I didn't feel like implementing a constexpr memcmp variant.
    // TODO At some point this should probably be done, including tricks
    // like comparing one machine word instead of a byte per iteration.
    for (typename basic_string_view<CharT>::size_type pos = 0; pos < size();
         ++pos) {
      if (at_(pos) != rhs.at_(pos)) {
        return false;
      }
    }
    return true;
#endif
  }

  // Predicates handed to find_first_if_ / find_last_if_.

  struct charIsEqual_ final {
    CharT expected;
    constexpr bool operator()(CharT actual) const noexcept {
      return expected == actual;
    }
  };

  struct charIsNotEqual_ final {
    CharT expected;
    constexpr bool operator()(CharT actual) const noexcept {
      return expected != actual;
    }
  };

  struct stringViewContainsChar_ final {
    basic_string_view expected;
    constexpr bool operator()(CharT ch) const noexcept {
      return npos != expected.find(ch);
    }
  };

  struct stringViewDoesNotContainChar_ final {
    basic_string_view expected;
    constexpr bool operator()(CharT ch) const noexcept {
      return npos == expected.find(ch);
    }
  };

  const_pointer begin_; // first viewed character (nullptr for a default view)
  size_type size_{}; // number of viewed characters
};
573 
574 template <class CharT>
575 inline std::basic_ostream<CharT>& operator<<(
576     std::basic_ostream<CharT>& stream,
577     basic_string_view<CharT> sv) {
578   // The rules for operator<< are quite complex, so lets defer to the
579   // STL implementation.
580   using std_string_type = ::std::basic_string_view<CharT>;
581   return stream << std_string_type(sv.data(), sv.size());
582 }
583 
584 template <class CharT>
swap(basic_string_view<CharT> & lhs,basic_string_view<CharT> & rhs)585 constexpr inline void swap(
586     basic_string_view<CharT>& lhs,
587     basic_string_view<CharT>& rhs) noexcept {
588   lhs.swap(rhs);
589 }
590 
// The char instantiation of basic_string_view.
using string_view = basic_string_view<char>;
592 
593 } // namespace c10
594 
595 namespace std {
596 template <class CharT>
597 struct hash<::c10::basic_string_view<CharT>> {
598   size_t operator()(::c10::basic_string_view<CharT> x) const {
599     // The standard says that std::string_view hashing must do the same as
600     // std::string hashing but leaves the details of std::string hashing
601     // up to the implementer. So, to be conformant, we need to re-use and
602     // existing STL type's hash function. The std::string fallback is probably
603     // slow but the only way to be conformant.
604 
605     using std_string_type = ::std::basic_string_view<CharT>;
606     return ::std::hash<std_string_type>{}(std_string_type(x.data(), x.size()));
607   }
608 };
609 } // namespace std
610