#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/cuda/Activation.h>

#include <ATen/core/DimVector.h>
#include <ATen/core/Tensor.h>
#include <ATen/TensorIterator.h>
#include <ATen/WrapDimUtils.h>
#include <ATen/native/Resize.h>
#include <c10/util/irange.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/empty.h>
#include <ATen/ops/empty_like.h>
#include <ATen/ops/gelu_backward_native.h>
#include <ATen/ops/gelu_native.h>
#include <ATen/ops/glu_backward_native.h>
#include <ATen/ops/log_sigmoid_forward_native.h>
#endif

namespace at::native {

// -----------------------------------
// glu backward
// -----------------------------------

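// Sketch of the math the launched kernel is expected to compute, assuming the
// standard GLU definition (the kernel itself is implemented with the other
// CUDA activation kernels): the input is split into halves a and b along
// `dim`, the forward pass computes y = a * sigmoid(b), so the backward pass is
//
//   dL/da = dL/dy * sigmoid(b)
//   dL/db = dL/dy * a * sigmoid(b) * (1 - sigmoid(b))
//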
Tensor& glu_backward_cuda_out(const Tensor& grad_output, const Tensor& input,
                              int64_t dim, Tensor& grad_input) {
  TORCH_CHECK(input.dim() > 0, "glu does not support 0-dimensional tensors");
  auto wrap_dim = maybe_wrap_dim(dim, input.dim());
  auto input_sizes = input.sizes();
  const int64_t nIn = input_sizes[wrap_dim];
  TORCH_CHECK(nIn % 2 == 0, "Halving dimension must be even, but dimension ",
              wrap_dim, " is size ", nIn);

  resize_output(grad_input, input_sizes);

  // The iterator shape is the input shape with the halving dimension cut in
  // half; grad_output must match it exactly.
  DimVector iter_shape(input_sizes);
  const auto dim_size = nIn / 2;
  iter_shape[wrap_dim] = dim_size;
  TORCH_CHECK(grad_output.sizes() == IntArrayRef{iter_shape});

  const auto iter = at::TensorIteratorConfig()
    .add_output(grad_input)
    .add_const_input(input)
    .add_const_input(grad_output)
    .resize_outputs(false)
    .declare_static_shape(iter_shape)
    .build();

  if (iter.numel() == 0) {
    return grad_input;
  }

  // Element offset from the first half (a) to the second half (b) of the
  // input / grad_input along the halving dimension.
  const auto I_stride = input.strides()[wrap_dim] * dim_size;
  const auto gI_stride = grad_input.strides()[wrap_dim] * dim_size;

  if (iter.can_use_32bit_indexing()) {
    launch_glu_backward_kernel(iter, gI_stride, I_stride);
  } else {
    // Tensors too large for 32-bit indexing are split into sub-iterators that
    // each fit in 32-bit index space.
    for (const auto& sub_iter: iter.with_32bit_indexing()) {
      launch_glu_backward_kernel(sub_iter, gI_stride, I_stride);
    }
  }
  return grad_input;
}

Tensor glu_backward_cuda(const Tensor& grad_output, const Tensor& input, int64_t dim) {
  auto grad_input = at::empty({0}, input.options());
  return glu_backward_cuda_out(grad_output, input, dim, grad_input);
}

// -----------------------------------
// log_sigmoid forward
// -----------------------------------

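// The launched kernel computes log(sigmoid(x)) elementwise. A numerically
// stable formulation (assumed here; see the CUDA kernel for the exact code)
// avoids overflow of exp for large |x|:
//
//   log_sigmoid(x) = min(x, 0) - log1p(exp(-|x|))
//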
std::tuple<Tensor&, Tensor&> log_sigmoid_forward_out_cuda(const Tensor& input, Tensor& result, Tensor& buffer) {
  // NOTE: buffer is only used by the CPU dispatch; it is ignored here.
  auto iter = TensorIteratorConfig()
    .add_output(result)
    .add_const_input(input)
    .build();
  launch_log_sigmoid_forward_kernel(iter);
  return std::forward_as_tuple(result, buffer);
}

std::tuple<Tensor, Tensor> log_sigmoid_forward_cuda(const Tensor& input) {
  auto result = at::empty_like(input);
  auto buffer = at::empty({0}, input.options());
  log_sigmoid_forward_out_cuda(input, result, buffer);
  return std::forward_as_tuple(result, buffer);
}

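// -----------------------------------
// gelu
// -----------------------------------

// `approximate` selects between the exact and the tanh-approximated GELU. As
// a sketch (standard definitions; the kernels are implemented elsewhere):
//
//   "none": gelu(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
//   "tanh": gelu(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
//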
TORCH_IMPL_FUNC(gelu_out_cuda) (
  const Tensor& /*self*/, c10::string_view approximate, const Tensor& /*result*/
) {
  GeluCUDAKernelImpl(*this, get_gelutype_enum(approximate));
}

TORCH_IMPL_FUNC(gelu_backward_out_cuda) (
  const Tensor& /*grad*/, const Tensor& /*self*/, c10::string_view approximate, const Tensor& /*grad_input*/
) {
  GeluBackwardCUDAKernelImpl(*this, get_gelutype_enum(approximate));
}

}  // namespace at::native