xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/vulkan/ops/Expand.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/native/vulkan/ops/Common.h>
2 #include <ATen/native/vulkan/ops/Utils.h>
3 #include <torch/library.h>
4 
5 #ifndef AT_PER_OPERATOR_HEADERS
6 #include <ATen/Functions.h>
7 #else
8 #include <ATen/ops/repeat.h>
9 #endif
10 
11 #include <ATen/native/vulkan/ops/Utils.h>
12 
13 namespace at {
14 namespace native {
15 namespace vulkan {
16 namespace ops {
17 namespace {
18 
19 using namespace api::utils;
20 
expand(const at::Tensor & self,const IntArrayRef output_size,bool implicit=false)21 Tensor expand(
22     const at::Tensor& self,
23     const IntArrayRef output_size,
24     bool implicit = false) {
25   TORCH_CHECK(
26       self.dim() > 0 && self.dim() <= 4,
27       "Vulkan expand supports up to 4d tensors");
28   TORCH_CHECK(
29       static_cast<size_t>(self.dim()) <= output_size.size(),
30       "Vulkan expand: the number of sizes provided (",
31       output_size.size(),
32       ") must be greater or equal to the number of dimensions in the tensor (",
33       self.dim(),
34       ").");
35 
36   std::vector<int64_t> repeat_size = std::vector<int64_t>(output_size.size());
37   std::vector<int64_t> input_size = self.sizes().vec();
38 
39   int in_idx = input_size.size() - 1;
40   for (int i = output_size.size() - 1; i >= 0; --i) {
41     if (in_idx >= 0) {
42       TORCH_CHECK(
43           input_size[in_idx] == output_size[i] || input_size[in_idx] == 1 ||
44               output_size[i] == -1,
45           "Vulkan expand: the expanded size of the tensor (",
46           output_size[i],
47           ") must match the existing size (",
48           input_size[in_idx],
49           ") at non-singleton dimension ",
50           i);
51 
52       if (input_size[in_idx] == output_size[i] || output_size[i] == -1) {
53         repeat_size[i] = 1;
54       } else if (input_size[in_idx] == 1) {
55         repeat_size[i] = output_size[i];
56       }
57       --in_idx;
58     } else {
59       TORCH_CHECK(
60           output_size[i] != -1,
61           "Vulkan expand: the expanded size of the tensor (-1) is not allowed in a leading, non-existing dimension 0.");
62 
63       repeat_size[i] = output_size[i];
64     }
65   }
66 
67   return self.repeat(repeat_size);
68 }
69 
expand_as(const at::Tensor & self,const at::Tensor & other)70 Tensor expand_as(const at::Tensor& self, const at::Tensor& other) {
71   return expand(self, other.sizes());
72 }
73 
74 #ifdef USE_VULKAN_API
75 
// Register this file's kernels for the aten::expand_as and aten::expand ops
// under the Vulkan dispatch key. The two registrations target distinct ops and
// do not depend on each other's order.
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::expand_as"), TORCH_FN(expand_as));
  m.impl(TORCH_SELECTIVE_NAME("aten::expand"), TORCH_FN(expand));
}
80 
81 #endif /* USE_VULKAN_API */
82 
83 } // namespace
84 } // namespace ops
85 } // namespace vulkan
86 } // namespace native
87 } // namespace at
88