1 #include <ATen/native/vulkan/ops/Common.h>
2 #include <ATen/native/vulkan/ops/Utils.h>
3 #include <torch/library.h>
4
5 #ifndef AT_PER_OPERATOR_HEADERS
6 #include <ATen/Functions.h>
7 #else
8 #include <ATen/ops/repeat.h>
9 #endif
10
11 #include <ATen/native/vulkan/ops/Utils.h>
12
13 namespace at {
14 namespace native {
15 namespace vulkan {
16 namespace ops {
17 namespace {
18
19 using namespace api::utils;
20
expand(const at::Tensor & self,const IntArrayRef output_size,bool implicit=false)21 Tensor expand(
22 const at::Tensor& self,
23 const IntArrayRef output_size,
24 bool implicit = false) {
25 TORCH_CHECK(
26 self.dim() > 0 && self.dim() <= 4,
27 "Vulkan expand supports up to 4d tensors");
28 TORCH_CHECK(
29 static_cast<size_t>(self.dim()) <= output_size.size(),
30 "Vulkan expand: the number of sizes provided (",
31 output_size.size(),
32 ") must be greater or equal to the number of dimensions in the tensor (",
33 self.dim(),
34 ").");
35
36 std::vector<int64_t> repeat_size = std::vector<int64_t>(output_size.size());
37 std::vector<int64_t> input_size = self.sizes().vec();
38
39 int in_idx = input_size.size() - 1;
40 for (int i = output_size.size() - 1; i >= 0; --i) {
41 if (in_idx >= 0) {
42 TORCH_CHECK(
43 input_size[in_idx] == output_size[i] || input_size[in_idx] == 1 ||
44 output_size[i] == -1,
45 "Vulkan expand: the expanded size of the tensor (",
46 output_size[i],
47 ") must match the existing size (",
48 input_size[in_idx],
49 ") at non-singleton dimension ",
50 i);
51
52 if (input_size[in_idx] == output_size[i] || output_size[i] == -1) {
53 repeat_size[i] = 1;
54 } else if (input_size[in_idx] == 1) {
55 repeat_size[i] = output_size[i];
56 }
57 --in_idx;
58 } else {
59 TORCH_CHECK(
60 output_size[i] != -1,
61 "Vulkan expand: the expanded size of the tensor (-1) is not allowed in a leading, non-existing dimension 0.");
62
63 repeat_size[i] = output_size[i];
64 }
65 }
66
67 return self.repeat(repeat_size);
68 }
69
expand_as(const at::Tensor & self,const at::Tensor & other)70 Tensor expand_as(const at::Tensor& self, const at::Tensor& other) {
71 return expand(self, other.sizes());
72 }
73
74 #ifdef USE_VULKAN_API
75
// Register this file's kernels with the dispatcher for the Vulkan dispatch
// key. Each m.impl call binds one aten operator name to one kernel, so the
// registration order is irrelevant.
TORCH_LIBRARY_IMPL(aten, Vulkan, m) {
  m.impl(TORCH_SELECTIVE_NAME("aten::expand_as"), TORCH_FN(expand_as));
  m.impl(TORCH_SELECTIVE_NAME("aten::expand"), TORCH_FN(expand));
}
80
81 #endif /* USE_VULKAN_API */
82
83 } // namespace
84 } // namespace ops
85 } // namespace vulkan
86 } // namespace native
87 } // namespace at
88