1 #include <ATen/core/TensorBase.h> 2 #include <algorithm> 3 #include <vector> 4 5 namespace at::native { 6 ensure_nonempty_dim(int64_t dim)7inline int64_t ensure_nonempty_dim(int64_t dim) { 8 return std::max<int64_t>(dim, 1); 9 } 10 ensure_nonempty_size(const TensorBase & t,int64_t dim)11inline int64_t ensure_nonempty_size(const TensorBase &t, int64_t dim) { 12 return t.dim() == 0 ? 1 : t.size(dim); 13 } 14 ensure_nonempty_stride(const TensorBase & t,int64_t dim)15inline int64_t ensure_nonempty_stride(const TensorBase &t, int64_t dim) { 16 return t.dim() == 0 ? 1 : t.stride(dim); 17 } 18 19 using IdxVec = std::vector<int64_t>; ensure_nonempty_vec(IdxVec vec)20inline IdxVec ensure_nonempty_vec(IdxVec vec) { 21 if (vec.empty()) { 22 vec.push_back(1); 23 } 24 return vec; 25 } 26 27 } // namespace at::native 28