xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/NonEmptyUtils.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/core/TensorBase.h>
2 #include <algorithm>
3 #include <vector>
4 
5 namespace at::native {
6 
ensure_nonempty_dim(int64_t dim)7 inline int64_t ensure_nonempty_dim(int64_t dim) {
8   return std::max<int64_t>(dim, 1);
9 }
10 
ensure_nonempty_size(const TensorBase & t,int64_t dim)11 inline int64_t ensure_nonempty_size(const TensorBase &t, int64_t dim) {
12   return t.dim() == 0 ? 1 : t.size(dim);
13 }
14 
ensure_nonempty_stride(const TensorBase & t,int64_t dim)15 inline int64_t ensure_nonempty_stride(const TensorBase &t, int64_t dim) {
16   return t.dim() == 0 ? 1 : t.stride(dim);
17 }
18 
19 using IdxVec = std::vector<int64_t>;
ensure_nonempty_vec(IdxVec vec)20 inline IdxVec ensure_nonempty_vec(IdxVec vec) {
21   if (vec.empty()) {
22     vec.push_back(1);
23   }
24   return vec;
25 }
26 
27 }  // namespace at::native
28