xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/ChanelShuffle.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/NamedTensorUtils.h>
3 #if defined(C10_MOBILE) && defined(USE_XNNPACK)
4 #include <ATen/native/xnnpack/Engine.h>
5 #endif
6 #include <c10/util/Exception.h>
7 
8 #include <ATen/native/cpu/ChannelShuffleKernel.h>
9 
10 #ifndef AT_PER_OPERATOR_HEADERS
11 #include <ATen/Functions.h>
12 #include <ATen/NativeFunctions.h>
13 #else
14 #include <ATen/ops/channel_shuffle_native.h>
15 #include <ATen/ops/empty.h>
16 #include <ATen/ops/native_channel_shuffle.h>
17 #include <ATen/ops/native_channel_shuffle_native.h>
18 #endif
19 
20 namespace at::native {
21 
// Native CPU implementation of channel_shuffle.
//
// Delegates the actual data movement to the per-architecture kernel
// registered on the channel_shuffle_kernel dispatch stub; this wrapper
// only handles the empty-tensor case, output allocation, and named-tensor
// name propagation.
Tensor channel_shuffle_cpu(const Tensor& self, int64_t groups) {
  Tensor result;
  if (self.numel() == 0) {
    // Nothing to shuffle — aliasing avoids allocating a fresh empty buffer.
    result = self.alias();
  } else {
    const auto memory_format = self.suggest_memory_format();
    result = at::empty({0}, self.options());
    result.resize_(self.sizes(), memory_format);
    // The kernel expects its input contiguous in the chosen memory format.
    auto contig_input = self.contiguous(memory_format);
    channel_shuffle_kernel(kCPU, result, contig_input, groups);
  }
  // Carry dimension names over to the output when the input has them.
  return namedinference::propagate_names_if_nonempty(
      result,
      self.has_names() ? self.names() : at::ArrayRef<Dimname>{});
}
37 
// Public entry point for channel_shuffle.
//
// Validates the arguments (rank > 2, positive group count, channel count
// divisible by groups), then either takes the XNNPACK fast path (mobile
// builds, ChannelsLast input) or falls through to at::native_channel_shuffle.
//
// @param self    input tensor; dim 1 is treated as the channel dimension.
// @param groups  number of groups to split the channels into; must be > 0
//                and must evenly divide self.size(1).
// @return tensor of the same shape with channels shuffled across groups.
Tensor channel_shuffle(const Tensor& self, int64_t groups) {
  TORCH_CHECK(self.dim() > 2,
              "channel_shuffle expects input with > 2 dims, but got input with sizes ",
              self.sizes());
  int64_t c = self.size(1);
  TORCH_CHECK(groups > 0,
              "Number of groups to divide channels in must be positive.",
              " Value of groups:", groups);
  TORCH_CHECK((c % groups) == 0,
              "Number of channels must be divisible by groups. Got ",
              c, " channels and ", groups, " groups.");

#if defined(C10_MOBILE) && defined(USE_XNNPACK)
  if (self.is_contiguous(MemoryFormat::ChannelsLast) &&
      xnnpack::use_channel_shuffle(self, groups)) {
    auto output = self.numel() == 0 ? self.alias() : xnnpack::channel_shuffle(self, groups);
    // Fix: propagate dimension names on this path as well, consistent with
    // the generic path below (previously names were silently dropped here).
    // No-op for unnamed tensors.
    return namedinference::propagate_names_if_nonempty(
        output,
        self.has_names() ? self.names() : at::ArrayRef<Dimname>{});
  }
#endif

  auto output = self.numel() == 0 ? self.alias() : at::native_channel_shuffle(self, groups);
  return namedinference::propagate_names_if_nonempty(
      output,
      self.has_names() ? self.names() : at::ArrayRef<Dimname>{});
}
63 
// Composite (math) reference implementation of channel_shuffle, expressed
// entirely in terms of existing tensor ops: view -> permute -> contiguous
// -> reshape. Used when no dedicated backend kernel applies.
Tensor math_channel_shuffle(const Tensor& self, int64_t groups) {
  const int64_t batch = self.size(0);
  const int64_t channels = self.size(1);
  const int64_t channels_per_group = channels / groups;

  // Split the channel dim into (groups, channels_per_group) and collapse
  // all trailing spatial dims into one.
  auto reshaped = self.view({batch, groups, channels_per_group, -1});
  // TODO: contiguous can be made to preserve the memory format of the
  // input. However, since the view above clobbers h and w it may not be
  // safe to do that: channels_last contiguity may interpret
  // channels_per_group and the last dim as h,w. It is not clear, but from
  // initial investigation this may not be correct. In that case
  // channels_last will likely require a custom implementation if we want
  // to preserve the memory order. XNNPACK has a channel shuffle op for
  // NHWC, which covers the mobile use case; for server a custom
  // implementation would be needed. For ChannelsFirst (a.k.a. Contiguous)
  // memory format a fast custom implementation may also be warranted.
  auto shuffled =
      reshaped.permute({0 /* batch */, 2 /* channels_per_group */, 1 /* groups */, 3})
          .contiguous()
          .reshape(self.sizes());
  // Restore dimension names (named-tensor support) when present.
  return namedinference::propagate_names_if_nonempty(
      shuffled,
      self.has_names() ? self.names() : at::ArrayRef<Dimname>{});
}
90 
// Defines the dispatch stub declared in ChannelShuffleKernel.h; the
// per-architecture CPU kernels register themselves against this stub.
DEFINE_DISPATCH(channel_shuffle_kernel);
92 
93 } // namespace at::native
94