1 #pragma once 2 3 #include <c10/util/ArrayRef.h> 4 #include <c10/util/Exception.h> 5 #include <c10/util/irange.h> 6 #include <optional> 7 8 #include <algorithm> 9 #include <array> 10 #include <cstdint> 11 #include <initializer_list> 12 #include <string> 13 #include <vector> 14 15 namespace torch { 16 17 /// A utility class that accepts either a container of `D`-many values, or a 18 /// single value, which is internally repeated `D` times. This is useful to 19 /// represent parameters that are multidimensional, but often equally sized in 20 /// all dimensions. For example, the kernel size of a 2D convolution has an `x` 21 /// and `y` length, but `x` and `y` are often equal. In such a case you could 22 /// just pass `3` to an `ExpandingArray<2>` and it would "expand" to `{3, 3}`. 23 template <size_t D, typename T = int64_t> 24 class ExpandingArray { 25 public: 26 /// Constructs an `ExpandingArray` from an `initializer_list`. The extent of 27 /// the length is checked against the `ExpandingArray`'s extent parameter `D` 28 /// at runtime. ExpandingArray(std::initializer_list<T> list)29 /*implicit*/ ExpandingArray(std::initializer_list<T> list) 30 : ExpandingArray(at::ArrayRef<T>(list)) {} 31 32 /// Constructs an `ExpandingArray` from an `std::vector`. The extent of 33 /// the length is checked against the `ExpandingArray`'s extent parameter `D` 34 /// at runtime. ExpandingArray(std::vector<T> vec)35 /*implicit*/ ExpandingArray(std::vector<T> vec) 36 : ExpandingArray(at::ArrayRef<T>(vec)) {} 37 38 /// Constructs an `ExpandingArray` from an `at::ArrayRef`. The extent of 39 /// the length is checked against the `ExpandingArray`'s extent parameter `D` 40 /// at runtime. ExpandingArray(at::ArrayRef<T> values)41 /*implicit*/ ExpandingArray(at::ArrayRef<T> values) { 42 // clang-format off 43 TORCH_CHECK( 44 values.size() == D, 45 "Expected ", D, " values, but instead got ", values.size()); 46 // clang-format on 47 std::copy(values.begin(), values.end(), values_.begin()); 48 } 49 50 /// Constructs an `ExpandingArray` from a single value, which is repeated `D` 51 /// times (where `D` is the extent parameter of the `ExpandingArray`). ExpandingArray(T single_size)52 /*implicit*/ ExpandingArray(T single_size) { 53 values_.fill(single_size); 54 } 55 56 /// Constructs an `ExpandingArray` from a correctly sized `std::array`. ExpandingArray(const std::array<T,D> & values)57 /*implicit*/ ExpandingArray(const std::array<T, D>& values) 58 : values_(values) {} 59 60 /// Accesses the underlying `std::array`. 61 std::array<T, D>& operator*() { 62 return values_; 63 } 64 65 /// Accesses the underlying `std::array`. 66 const std::array<T, D>& operator*() const { 67 return values_; 68 } 69 70 /// Accesses the underlying `std::array`. 71 std::array<T, D>* operator->() { 72 return &values_; 73 } 74 75 /// Accesses the underlying `std::array`. 76 const std::array<T, D>* operator->() const { 77 return &values_; 78 } 79 80 /// Returns an `ArrayRef` to the underlying `std::array`. 81 operator at::ArrayRef<T>() const { 82 return values_; 83 } 84 85 /// Returns the extent of the `ExpandingArray`. size()86 size_t size() const noexcept { 87 return D; 88 } 89 90 protected: 91 /// The backing array. 92 // NOLINTNEXTLINE(cppcoreguidelines-non-private-member-variables-in-classes) 93 std::array<T, D> values_; 94 }; 95 96 template <size_t D, typename T> 97 std::ostream& operator<<( 98 std::ostream& stream, 99 const ExpandingArray<D, T>& expanding_array) { 100 if (expanding_array.size() == 1) { 101 return stream << expanding_array->at(0); 102 } 103 return stream << static_cast<at::ArrayRef<T>>(expanding_array); 104 } 105 106 /// A utility class that accepts either a container of `D`-many 107 /// `std::optional<T>` values, or a single `std::optional<T>` value, which is 108 /// internally repeated `D` times. It has the additional ability to accept 109 /// containers of the underlying type `T` and convert them to a container of 110 /// `std::optional<T>`. 111 template <size_t D, typename T = int64_t> 112 class ExpandingArrayWithOptionalElem 113 : public ExpandingArray<D, std::optional<T>> { 114 public: 115 using ExpandingArray<D, std::optional<T>>::ExpandingArray; 116 117 /// Constructs an `ExpandingArrayWithOptionalElem` from an `initializer_list` 118 /// of the underlying type `T`. The extent of the length is checked against 119 /// the `ExpandingArrayWithOptionalElem`'s extent parameter `D` at runtime. ExpandingArrayWithOptionalElem(std::initializer_list<T> list)120 /*implicit*/ ExpandingArrayWithOptionalElem(std::initializer_list<T> list) 121 : ExpandingArrayWithOptionalElem(at::ArrayRef<T>(list)) {} 122 123 /// Constructs an `ExpandingArrayWithOptionalElem` from an `std::vector` of 124 /// the underlying type `T`. The extent of the length is checked against the 125 /// `ExpandingArrayWithOptionalElem`'s extent parameter `D` at runtime. ExpandingArrayWithOptionalElem(std::vector<T> vec)126 /*implicit*/ ExpandingArrayWithOptionalElem(std::vector<T> vec) 127 : ExpandingArrayWithOptionalElem(at::ArrayRef<T>(vec)) {} 128 129 /// Constructs an `ExpandingArrayWithOptionalElem` from an `at::ArrayRef` of 130 /// the underlying type `T`. The extent of the length is checked against the 131 /// `ExpandingArrayWithOptionalElem`'s extent parameter `D` at runtime. ExpandingArrayWithOptionalElem(at::ArrayRef<T> values)132 /*implicit*/ ExpandingArrayWithOptionalElem(at::ArrayRef<T> values) 133 : ExpandingArray<D, std::optional<T>>(0) { 134 // clang-format off 135 TORCH_CHECK( 136 values.size() == D, 137 "Expected ", D, " values, but instead got ", values.size()); 138 // clang-format on 139 for (const auto i : c10::irange(this->values_.size())) { 140 this->values_[i] = values[i]; 141 } 142 } 143 144 /// Constructs an `ExpandingArrayWithOptionalElem` from a single value of the 145 /// underlying type `T`, which is repeated `D` times (where `D` is the extent 146 /// parameter of the `ExpandingArrayWithOptionalElem`). ExpandingArrayWithOptionalElem(T single_size)147 /*implicit*/ ExpandingArrayWithOptionalElem(T single_size) 148 : ExpandingArray<D, std::optional<T>>(0) { 149 for (const auto i : c10::irange(this->values_.size())) { 150 this->values_[i] = single_size; 151 } 152 } 153 154 /// Constructs an `ExpandingArrayWithOptionalElem` from a correctly sized 155 /// `std::array` of the underlying type `T`. ExpandingArrayWithOptionalElem(const std::array<T,D> & values)156 /*implicit*/ ExpandingArrayWithOptionalElem(const std::array<T, D>& values) 157 : ExpandingArray<D, std::optional<T>>(0) { 158 for (const auto i : c10::irange(this->values_.size())) { 159 this->values_[i] = values[i]; 160 } 161 } 162 }; 163 164 template <size_t D, typename T> 165 std::ostream& operator<<( 166 std::ostream& stream, 167 const ExpandingArrayWithOptionalElem<D, T>& expanding_array_with_opt_elem) { 168 if (expanding_array_with_opt_elem.size() == 1) { 169 const auto& elem = expanding_array_with_opt_elem->at(0); 170 stream << (elem.has_value() ? c10::str(elem.value()) : "None"); 171 } else { 172 std::vector<std::string> str_array; 173 for (const auto& elem : *expanding_array_with_opt_elem) { 174 str_array.emplace_back( 175 elem.has_value() ? c10::str(elem.value()) : "None"); 176 } 177 stream << at::ArrayRef<std::string>(str_array); 178 } 179 return stream; 180 } 181 182 } // namespace torch 183