xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/TensorShape.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 #include <ATen/core/Tensor.h>
3 #include <c10/util/irange.h>
4 #include <ATen/core/IListRef.h>
5 
6 namespace at::native {
7 
8 TORCH_API at::Tensor clone_preserve_strides(const at::Tensor& self);
9 
cat_should_skip_tensor(const Tensor & t)10 inline bool cat_should_skip_tensor(const Tensor& t) {
11   return t.sym_numel() == 0 && t.dim() == 1;
12 }
13 
14  // Check to see if the shape of tensors is compatible
15  // for being concatenated along a given dimension.
check_cat_shape_except_dim(const Tensor & first,const Tensor & second,int64_t dimension,int64_t index)16 inline void check_cat_shape_except_dim(const Tensor & first, const Tensor & second, int64_t dimension, int64_t index) {
17    int64_t first_dims = first.dim();
18    int64_t second_dims = second.dim();
19    TORCH_CHECK(first_dims == second_dims, "Tensors must have same number of dimensions: got ",
20                first_dims, " and ", second_dims);
21    for (const auto dim : c10::irange(first_dims)) {
22      if (dim == dimension) {
23        continue;
24      }
25      int64_t first_dim_size = first.sizes()[dim];
26      int64_t second_dim_size = second.sizes()[dim];
27      TORCH_CHECK(first_dim_size == second_dim_size, "Sizes of tensors must match except in dimension ",
28                  dimension, ". Expected size ", static_cast<long long>(first_dim_size), " but got size ", static_cast<long long>(second_dim_size), " for tensor number ", index, " in the list.");
29    }
30  }
31 
check_cat_no_zero_dim(const MaterializedITensorListRef & tensors)32 inline void check_cat_no_zero_dim(const MaterializedITensorListRef& tensors) {
33   int64_t i = 0;
34   for(const Tensor& t : tensors) {
35     TORCH_CHECK(t.dim() > 0,
36              "zero-dimensional tensor (at position ", i, ") cannot be concatenated");
37     i++;
38   }
39 }
40 
get_num_splits(const Tensor & self,int64_t split_size,int64_t dim)41 inline int64_t get_num_splits(const Tensor& self, int64_t split_size, int64_t dim) {
42   TORCH_CHECK(self.dim() != 0, "split expects at least a 1-dimensional tensor");
43   TORCH_CHECK(split_size >= 0,  "split expects split_size be non-negative, but got split_size=", split_size);
44   int64_t dim_size = self.size(dim);
45   TORCH_CHECK(split_size > 0 || dim_size == 0,
46            "split_size can only be 0 if dimension size is 0, "
47            "but got dimension size of ", dim_size);
48   // if split_size is 0 and dimension size is 0, there is 1 split.
49   int64_t num_splits = 1;
50   if (split_size != 0) {
51     // ensuring num_splits is at least 1 makes consistent the case where split_size > dim_size
52     // (returns a single split).  We might want to error here, but keep it for BC.
53     num_splits = std::max<int64_t>((dim_size + split_size - 1) / split_size, 1);
54   }
55   return num_splits;
56 }
57 
have_same_ndims(TensorList tensors)58 inline bool have_same_ndims(TensorList tensors) {
59   auto ndim = tensors[0].dim();
60   for (const auto tensor_idx : c10::irange(tensors.size())) {
61     if(tensors[tensor_idx].dim() != ndim) {
62       return false;
63     }
64   }
65   return true;
66 }
67 
leading_dimension_matches(TensorList tensors,int64_t dim)68 inline void leading_dimension_matches(TensorList tensors, int64_t dim) {
69   auto tensor_zero_size = tensors[0].sizes();
70   std::vector<c10::SymInt> leading_dim_sizes(tensor_zero_size.begin(), tensor_zero_size.begin() + dim);
71   for (const auto i : c10::irange(tensors.size())) {
72     at::Tensor tensor = tensors[i];
73     for(const auto j : c10::irange(dim)) {
74       TORCH_CHECK(
75         tensor.size(j) == leading_dim_sizes[j],
76         "_chunk_cat expects same sizes of 0,...,dim-1 dimensions for all tensors"
77       );
78     }
79   }
80 }
81 
preprocess_chunk_cat_inputs(TensorList tensors,int64_t dim,int64_t num_chunks)82 inline int64_t preprocess_chunk_cat_inputs(TensorList tensors, int64_t dim, int64_t num_chunks) {
83   TORCH_CHECK(num_chunks >= 1, "_chunk_cat expects positive num_chunks");
84   TORCH_CHECK(!tensors.empty(),
85            "_chunk_cat expects a non-empty input tensor list");
86   auto expected_dtype = tensors[0].dtype();
87   auto expected_device = tensors[0].device();
88   for(const auto i : c10::irange(tensors.size())) {
89     TORCH_CHECK(tensors[i].numel() > 0, "_chunk_cat expects non-empty tensor");
90     TORCH_CHECK(tensors[i].dtype() == expected_dtype, "_chunk_cat expects all input tensors with the same dtype");
91     TORCH_CHECK(tensors[i].device() == expected_device, "_chunk_cat expects all inputs tensors on the same device");
92   }
93   if (have_same_ndims(tensors)) {
94     dim = maybe_wrap_dim(dim, tensors[0].dim());
95   } else {
96     TORCH_CHECK(dim >= 0, "_chunk_cat expects non-negative dim when input tensors have different ndims")
97     for(const auto i : c10::irange(tensors.size())) {
98       TORCH_CHECK(dim < tensors[i].ndimension(), "_chunk_cat expects dim < ndim for all input tensors");
99     }
100   }
101   leading_dimension_matches(tensors, dim);
102   return dim;
103 }
104 
105 } // namespace at::native
106