xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/nested/NestedTensorUnaryOps.cpp (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #include <ATen/native/nested/NestedTensorMath.h>
2 
3 #include <ATen/AccumulateType.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/Functions.h>
6 #include <ATen/NativeFunctions.h>
7 #include <ATen/NestedTensorImpl.h>
8 #include <ATen/ScalarOps.h>
9 #include <ATen/TensorIndexing.h>
10 #include <ATen/TensorOperators.h>
11 #include <ATen/TensorUtils.h>
12 #include <ATen/core/Tensor.h>
13 #include <ATen/native/layer_norm.h>
14 #include <ATen/native/nested/NestedTensorUtils.h>
15 
16 namespace at::native {
17 
// Out-of-place abs() for nested tensors. The elementwise op is handed to the
// map_nt helper, whose result is returned directly.
Tensor NestedTensor_abs(const Tensor& self) {
  const auto abs_op = [](const Tensor& t) { return at::abs(t); };
  return map_nt(self, abs_op);
}
21 
NestedTensor_abs_(Tensor & self)22 Tensor& NestedTensor_abs_(Tensor& self) {
23   auto self_ptr = get_nested_tensor_impl(self);
24   check_numel_equals_buffer_size(self_ptr);
25   auto buffer = self_ptr->get_buffer();
26   at::abs_(buffer);
27   return self;
28 }
29 
// where(condition, self, other) for the case where `condition` and `other`
// are nested tensors and `self` is a regular (non-nested) tensor.
//
// The result has the same nested sizes as `other` and is allocated with the
// options (dtype/device) of other's storage. Component i of the result is
// computed as at::where(condition[i], self, other[i]); `self` is used as-is for
// every component.
//
// Errors (TORCH_CHECK) if `condition` or `other` is not nested, if `self` is
// nested, or if `condition` and `other` have a different number of components.
Tensor NestedTensor_where(const Tensor& condition, const Tensor& self, const Tensor& other) {
  TORCH_CHECK(condition.is_nested(), "condition must be nested");
  TORCH_CHECK(other.is_nested(), "other must be nested");
  TORCH_CHECK(!self.is_nested(), "self must not be nested");

  auto condition_ptr = get_nested_tensor_impl(condition);
  auto other_ptr = get_nested_tensor_impl(other);

  int64_t ntensors = condition_ptr->size(0);
  TORCH_CHECK(other_ptr->size(0) == ntensors, "condition and other must have the same number of tensors");

  // other's raw storage is only used as a template for the output's options.
  const Tensor& other_buffer = other_ptr->get_unsafe_storage_as_tensor();
  const Tensor& other_sizes = other_ptr->get_nested_sizes();

  // Size the output by the number of elements `other` actually holds, not by
  // the size of its underlying storage. `other` may be a view that does not
  // span its whole storage. An oversized buffer would leave uninitialized
  // trailing memory in the result, and would make later in-place ops that call
  // check_numel_equals_buffer_size reject the result.
  Tensor output_buffer = other_buffer.new_empty({other_ptr->numel()});

  // The output is a fresh, contiguously packed nested tensor with other's
  // nested sizes; the sizes are cloned so the result does not share metadata
  // with `other`.
  Tensor output = wrap_buffer(output_buffer, other_sizes.clone());

  // Per-component views of the inputs and of the freshly allocated output.
  std::vector<Tensor> condition_unbind = condition.unbind();
  std::vector<Tensor> other_unbind = other.unbind();
  std::vector<Tensor> output_unbind = output.unbind();

  // NOTE(review): shapes of condition[i] and `self` are not validated against
  // other[i]. If broadcasting yields a shape different from other[i],
  // where_out would try to resize an output view — TODO confirm whether an
  // explicit check is wanted here.
  for (int64_t i = 0; i < ntensors; i++) {
    at::where_out(
      output_unbind[i],
      condition_unbind[i],
      self,  // Note: self is not nested, so we use it directly
      other_unbind[i]);
  }

  return output;
}
67 
// Out-of-place sgn() for nested tensors, delegated to the map_nt helper.
Tensor NestedTensor_sgn(const Tensor& self) {
  const auto sgn_op = [](const Tensor& t) { return at::sgn(t); };
  return map_nt(self, sgn_op);
}
71 
NestedTensor_sgn_(Tensor & self)72 Tensor& NestedTensor_sgn_(Tensor& self) {
73   auto self_ptr = get_nested_tensor_impl(self);
74   check_numel_equals_buffer_size(self_ptr);
75   auto buffer = self_ptr->get_buffer();
76   buffer.sgn_();
77   return self;
78 }
79 
NestedTensor_logical_not_(Tensor & self)80 Tensor& NestedTensor_logical_not_(Tensor& self){
81   auto self_ptr = get_nested_tensor_impl(self);
82   check_numel_equals_buffer_size(self_ptr);
83   auto buffer = self_ptr->get_buffer();
84   buffer.logical_not_();
85   return self;
86 }
87 
// Out-of-place logical_not() for nested tensors, delegated to map_nt.
Tensor NestedTensor_logical_not(const Tensor& self) {
  const auto not_op = [](const Tensor& t) { return at::logical_not(t); };
  return map_nt(self, not_op);
}
91 
NestedTensor_relu_(Tensor & self)92 Tensor& NestedTensor_relu_(Tensor& self) {
93   auto self_ptr = get_nested_tensor_impl(self);
94   check_numel_equals_buffer_size(self_ptr);
95   auto buffer = self_ptr->get_buffer();
96   at::relu_(buffer);
97   return self;
98 }
99 
// Out-of-place relu() for nested tensors, delegated to map_nt.
Tensor NestedTensor_relu(const Tensor& self) {
  const auto relu_op = [](const Tensor& t) { return at::relu(t); };
  return map_nt(self, relu_op);
}
103 
NestedTensor_gelu_(Tensor & self,c10::string_view approximate)104 Tensor& NestedTensor_gelu_(Tensor& self, c10::string_view approximate) {
105   auto self_ptr = get_nested_tensor_impl(self);
106   check_numel_equals_buffer_size(self_ptr);
107   auto buffer = self_ptr->get_buffer();
108   at::gelu_(buffer, approximate);
109   return self;
110 }
111 
// Out-of-place gelu() for nested tensors. `approximate` is forwarded unchanged
// to at::gelu; the work itself is delegated to map_nt.
Tensor NestedTensor_gelu(const Tensor& self, c10::string_view approximate) {
  // The string_view is captured by value; map_nt runs synchronously within
  // this call, so the viewed characters outlive the lambda's use.
  const auto gelu_op = [approx = approximate](const Tensor& packed) {
    return at::gelu(packed, approx);
  };
  return map_nt(self, gelu_op);
}
119 
NestedTensor_tanh_(Tensor & self)120 Tensor& NestedTensor_tanh_(Tensor& self) {
121   auto self_ptr = get_nested_tensor_impl(self);
122   check_numel_equals_buffer_size(self_ptr);
123   auto buffer = self_ptr->get_buffer();
124   at::tanh_(buffer);
125   return self;
126 }
127 
// Out-of-place tanh() for nested tensors, delegated to map_nt.
Tensor NestedTensor_tanh(const Tensor& self) {
  const auto tanh_op = [](const Tensor& t) { return at::tanh(t); };
  return map_nt(self, tanh_op);
}
131 
NestedTensor_neg_(Tensor & self)132 Tensor& NestedTensor_neg_(Tensor& self) {
133   auto self_ptr = get_nested_tensor_impl(self);
134   check_numel_equals_buffer_size(self_ptr);
135   auto buffer = self_ptr->get_buffer();
136   at::neg_(buffer);
137   return self;
138 }
139 
// Out-of-place neg() for nested tensors, delegated to map_nt.
Tensor NestedTensor_neg(const Tensor& self) {
  const auto neg_op = [](const Tensor& t) { return at::neg(t); };
  return map_nt(self, neg_op);
}
143 
zero_nested_(Tensor & self)144 Tensor& zero_nested_(Tensor& self) {
145   const auto& self_buf = get_nested_tensor_impl(self)->get_buffer();
146   self_buf.fill_(0);
147   return self;
148 }
149 
// Out-of-place silu() for nested tensors, delegated to map_nt.
Tensor NestedTensor_silu(const Tensor& self) {
  const auto silu_op = [](const Tensor& t) { return at::silu(t); };
  return map_nt(self, silu_op);
}
153 
NestedTensor_silu_(Tensor & self)154 Tensor& NestedTensor_silu_(Tensor& self){
155   auto self_ptr = get_nested_tensor_impl(self);
156   check_numel_equals_buffer_size(self_ptr);
157   auto buffer = self_ptr->get_buffer();
158   at::silu_(buffer);
159   return self;
160 }
161 
// Out-of-place sin() for nested tensors, delegated to map_nt.
Tensor sin_nested(const Tensor& self) {
  const auto sin_op = [](const Tensor& t) { return at::sin(t); };
  return map_nt(self, sin_op);
}
165 
// Out-of-place cos() for nested tensors, delegated to map_nt.
Tensor cos_nested(const Tensor& self) {
  const auto cos_op = [](const Tensor& t) { return at::cos(t); };
  return map_nt(self, cos_op);
}
169 
// _pin_memory() for nested tensors. The raw storage (viewed as a flat tensor)
// is passed to at::_pin_memory, and the result is wrapped back into a nested
// tensor. That nested tensor reuses self's nested sizes, nested strides and
// storage offsets unchanged.
Tensor _pin_memory_nested(const Tensor& self, std::optional<Device> device) {
  auto* impl = get_nested_tensor_impl(self);

  // Metadata describing how the components are laid out in the storage.
  const Tensor& sizes = impl->get_nested_sizes();
  const Tensor& strides = impl->get_nested_strides();
  const Tensor& offsets = impl->get_storage_offsets();

  const Tensor storage_view = impl->get_unsafe_storage_as_tensor();
  return wrap_buffer(at::_pin_memory(storage_view, device), sizes, strides, offsets);
}
179 
180 } // namespace at::native
181