#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/NamedTensorUtils.h>
#if defined(C10_MOBILE) && defined(USE_XNNPACK)
#include <ATen/native/xnnpack/Engine.h>
#endif
#include <c10/util/Exception.h>

#include <ATen/native/cpu/ChannelShuffleKernel.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/channel_shuffle_native.h>
#include <ATen/ops/empty.h>
#include <ATen/ops/native_channel_shuffle.h>
#include <ATen/ops/native_channel_shuffle_native.h>
#endif

namespace at::native {

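// CPU path: allocates the output in the input's suggested memory format and
// defers the actual shuffle to the kernel registered via DEFINE_DISPATCH below.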
Tensor channel_shuffle_cpu(const Tensor& self, int64_t groups) {
  Tensor output;
  if (self.numel() == 0) {
    output = self.alias();
  } else {
    auto memory_format = self.suggest_memory_format();
    output = at::empty({0}, self.options());
    output.resize_(self.sizes(), memory_format);
    auto input = self.contiguous(memory_format);
    channel_shuffle_kernel(kCPU, output, input, groups);
  }
  return namedinference::propagate_names_if_nonempty(
      output,
      self.has_names() ? self.names() : at::ArrayRef<Dimname>{});
}

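// Composite entry point: validates the input, then routes to the XNNPACK
// channel shuffle for channels-last tensors on mobile, or to
// at::native_channel_shuffle otherwise.
//
// Usage sketch (illustrative; the tensor shape and values are assumptions,
// not taken from this file):
//   auto x = at::rand({1, 6, 4, 4});                // NCHW, 6 channels
//   auto y = at::channel_shuffle(x, /*groups=*/2);  // same sizes as x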
Tensor channel_shuffle(const Tensor& self, int64_t groups) {
  TORCH_CHECK(self.dim() > 2,
              "channel_shuffle expects input with > 2 dims, but got input with sizes ",
              self.sizes());
  int64_t c = self.size(1);
  TORCH_CHECK(groups > 0,
              "Number of groups to divide channels in must be positive.",
              " Value of groups:", groups);
  TORCH_CHECK((c % groups) == 0,
              "Number of channels must be divisible by groups. Got ",
              c, " channels and ", groups, " groups.");
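  // For instance, c = 6 channels with groups = 4 is rejected here since
  // 6 % 4 != 0, while groups = 2 or 3 would pass.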

#if defined(C10_MOBILE) && defined(USE_XNNPACK)
  if (self.is_contiguous(MemoryFormat::ChannelsLast) &&
      xnnpack::use_channel_shuffle(self, groups)) {
    auto output = self.numel() == 0 ? self.alias() : xnnpack::channel_shuffle(self, groups);
    return output;
  }
#endif

  auto output = self.numel() == 0 ? self.alias() : at::native_channel_shuffle(self, groups);
  return namedinference::propagate_names_if_nonempty(
      output,
      self.has_names() ? self.names() : at::ArrayRef<Dimname>{});
}

Tensor math_channel_shuffle(const Tensor& self, int64_t groups) {
  int64_t b = self.size(0);
  int64_t c = self.size(1);
  int64_t oc = c / groups;

  auto input_reshaped = self.view({b, groups, oc, -1});
  // TODO: contiguous can be made to preserve the memory format
  // of the input. However, since the reshape above clobbers h and w,
  // it may not be safe to do so: a channels_last contiguous call could
  // treat oc and the last dim as if they were h and w. From an initial
  // look this does not appear to be correct, so preserving channels_last
  // memory order would likely require a custom implementation.
  // XNNPACK has a channel shuffle op for NHWC, which covers the mobile
  // use case. For server, a custom implementation would be needed, and
  // a fast custom implementation may also be needed for the ChannelsFirst
  // (a.k.a. Contiguous) memory format.
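  // Worked example (illustrative, not part of the original comment): with
  // c = 6 and groups = 2 (so oc = 3), input channel g * oc + j lands at
  // output position j * groups + g, so the channel order
  // [0, 1, 2, 3, 4, 5] becomes [0, 3, 1, 4, 2, 5] after the permute below.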
  Tensor output_tensor =
      input_reshaped.permute({0 /* b */, 2 /* oc */, 1 /* groups */, 3})
          .contiguous()
          .reshape(self.sizes());
  return namedinference::propagate_names_if_nonempty(
      output_tensor,
      self.has_names() ? self.names() : at::ArrayRef<Dimname>{});
}

DEFINE_DISPATCH(channel_shuffle_kernel);

} // namespace at::native