xref: /aosp_15_r20/external/pytorch/torch/csrc/api/include/torch/expanding_array.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <c10/util/ArrayRef.h>
4 #include <c10/util/Exception.h>
5 #include <c10/util/irange.h>
6 #include <optional>
7 
8 #include <algorithm>
9 #include <array>
10 #include <cstdint>
11 #include <initializer_list>
12 #include <string>
13 #include <vector>
14 
15 namespace torch {
16 
17 /// A utility class that accepts either a container of `D`-many values, or a
18 /// single value, which is internally repeated `D` times. This is useful to
19 /// represent parameters that are multidimensional, but often equally sized in
20 /// all dimensions. For example, the kernel size of a 2D convolution has an `x`
21 /// and `y` length, but `x` and `y` are often equal. In such a case you could
22 /// just pass `3` to an `ExpandingArray<2>` and it would "expand" to `{3, 3}`.
23 template <size_t D, typename T = int64_t>
24 class ExpandingArray {
25  public:
26   /// Constructs an `ExpandingArray` from an `initializer_list`. The extent of
27   /// the length is checked against the `ExpandingArray`'s extent parameter `D`
28   /// at runtime.
ExpandingArray(std::initializer_list<T> list)29   /*implicit*/ ExpandingArray(std::initializer_list<T> list)
30       : ExpandingArray(at::ArrayRef<T>(list)) {}
31 
32   /// Constructs an `ExpandingArray` from an `std::vector`. The extent of
33   /// the length is checked against the `ExpandingArray`'s extent parameter `D`
34   /// at runtime.
ExpandingArray(std::vector<T> vec)35   /*implicit*/ ExpandingArray(std::vector<T> vec)
36       : ExpandingArray(at::ArrayRef<T>(vec)) {}
37 
38   /// Constructs an `ExpandingArray` from an `at::ArrayRef`. The extent of
39   /// the length is checked against the `ExpandingArray`'s extent parameter `D`
40   /// at runtime.
ExpandingArray(at::ArrayRef<T> values)41   /*implicit*/ ExpandingArray(at::ArrayRef<T> values) {
42     // clang-format off
43     TORCH_CHECK(
44         values.size() == D,
45         "Expected ", D, " values, but instead got ", values.size());
46     // clang-format on
47     std::copy(values.begin(), values.end(), values_.begin());
48   }
49 
50   /// Constructs an `ExpandingArray` from a single value, which is repeated `D`
51   /// times (where `D` is the extent parameter of the `ExpandingArray`).
ExpandingArray(T single_size)52   /*implicit*/ ExpandingArray(T single_size) {
53     values_.fill(single_size);
54   }
55 
56   /// Constructs an `ExpandingArray` from a correctly sized `std::array`.
ExpandingArray(const std::array<T,D> & values)57   /*implicit*/ ExpandingArray(const std::array<T, D>& values)
58       : values_(values) {}
59 
60   /// Accesses the underlying `std::array`.
61   std::array<T, D>& operator*() {
62     return values_;
63   }
64 
65   /// Accesses the underlying `std::array`.
66   const std::array<T, D>& operator*() const {
67     return values_;
68   }
69 
70   /// Accesses the underlying `std::array`.
71   std::array<T, D>* operator->() {
72     return &values_;
73   }
74 
75   /// Accesses the underlying `std::array`.
76   const std::array<T, D>* operator->() const {
77     return &values_;
78   }
79 
80   /// Returns an `ArrayRef` to the underlying `std::array`.
81   operator at::ArrayRef<T>() const {
82     return values_;
83   }
84 
85   /// Returns the extent of the `ExpandingArray`.
size()86   size_t size() const noexcept {
87     return D;
88   }
89 
90  protected:
91   /// The backing array.
92   // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes)
93   std::array<T, D> values_;
94 };
95 
96 template <size_t D, typename T>
97 std::ostream& operator<<(
98     std::ostream& stream,
99     const ExpandingArray<D, T>& expanding_array) {
100   if (expanding_array.size() == 1) {
101     return stream << expanding_array->at(0);
102   }
103   return stream << static_cast<at::ArrayRef<T>>(expanding_array);
104 }
105 
106 /// A utility class that accepts either a container of `D`-many
107 /// `std::optional<T>` values, or a single `std::optional<T>` value, which is
108 /// internally repeated `D` times. It has the additional ability to accept
109 /// containers of the underlying type `T` and convert them to a container of
110 /// `std::optional<T>`.
111 template <size_t D, typename T = int64_t>
112 class ExpandingArrayWithOptionalElem
113     : public ExpandingArray<D, std::optional<T>> {
114  public:
115   using ExpandingArray<D, std::optional<T>>::ExpandingArray;
116 
117   /// Constructs an `ExpandingArrayWithOptionalElem` from an `initializer_list`
118   /// of the underlying type `T`. The extent of the length is checked against
119   /// the `ExpandingArrayWithOptionalElem`'s extent parameter `D` at runtime.
ExpandingArrayWithOptionalElem(std::initializer_list<T> list)120   /*implicit*/ ExpandingArrayWithOptionalElem(std::initializer_list<T> list)
121       : ExpandingArrayWithOptionalElem(at::ArrayRef<T>(list)) {}
122 
123   /// Constructs an `ExpandingArrayWithOptionalElem` from an `std::vector` of
124   /// the underlying type `T`. The extent of the length is checked against the
125   /// `ExpandingArrayWithOptionalElem`'s extent parameter `D` at runtime.
ExpandingArrayWithOptionalElem(std::vector<T> vec)126   /*implicit*/ ExpandingArrayWithOptionalElem(std::vector<T> vec)
127       : ExpandingArrayWithOptionalElem(at::ArrayRef<T>(vec)) {}
128 
129   /// Constructs an `ExpandingArrayWithOptionalElem` from an `at::ArrayRef` of
130   /// the underlying type `T`. The extent of the length is checked against the
131   /// `ExpandingArrayWithOptionalElem`'s extent parameter `D` at runtime.
ExpandingArrayWithOptionalElem(at::ArrayRef<T> values)132   /*implicit*/ ExpandingArrayWithOptionalElem(at::ArrayRef<T> values)
133       : ExpandingArray<D, std::optional<T>>(0) {
134     // clang-format off
135     TORCH_CHECK(
136         values.size() == D,
137         "Expected ", D, " values, but instead got ", values.size());
138     // clang-format on
139     for (const auto i : c10::irange(this->values_.size())) {
140       this->values_[i] = values[i];
141     }
142   }
143 
144   /// Constructs an `ExpandingArrayWithOptionalElem` from a single value of the
145   /// underlying type `T`, which is repeated `D` times (where `D` is the extent
146   /// parameter of the `ExpandingArrayWithOptionalElem`).
ExpandingArrayWithOptionalElem(T single_size)147   /*implicit*/ ExpandingArrayWithOptionalElem(T single_size)
148       : ExpandingArray<D, std::optional<T>>(0) {
149     for (const auto i : c10::irange(this->values_.size())) {
150       this->values_[i] = single_size;
151     }
152   }
153 
154   /// Constructs an `ExpandingArrayWithOptionalElem` from a correctly sized
155   /// `std::array` of the underlying type `T`.
ExpandingArrayWithOptionalElem(const std::array<T,D> & values)156   /*implicit*/ ExpandingArrayWithOptionalElem(const std::array<T, D>& values)
157       : ExpandingArray<D, std::optional<T>>(0) {
158     for (const auto i : c10::irange(this->values_.size())) {
159       this->values_[i] = values[i];
160     }
161   }
162 };
163 
164 template <size_t D, typename T>
165 std::ostream& operator<<(
166     std::ostream& stream,
167     const ExpandingArrayWithOptionalElem<D, T>& expanding_array_with_opt_elem) {
168   if (expanding_array_with_opt_elem.size() == 1) {
169     const auto& elem = expanding_array_with_opt_elem->at(0);
170     stream << (elem.has_value() ? c10::str(elem.value()) : "None");
171   } else {
172     std::vector<std::string> str_array;
173     for (const auto& elem : *expanding_array_with_opt_elem) {
174       str_array.emplace_back(
175           elem.has_value() ? c10::str(elem.value()) : "None");
176     }
177     stream << at::ArrayRef<std::string>(str_array);
178   }
179   return stream;
180 }
181 
182 } // namespace torch
183