1 #include <ATen/native/nested/NestedTensorMath.h>
2
3 #include <ATen/AccumulateType.h>
4 #include <ATen/Dispatch.h>
5 #include <ATen/Functions.h>
6 #include <ATen/NativeFunctions.h>
7 #include <ATen/NestedTensorImpl.h>
8 #include <ATen/ScalarOps.h>
9 #include <ATen/TensorIndexing.h>
10 #include <ATen/TensorOperators.h>
11 #include <ATen/TensorUtils.h>
12 #include <ATen/core/Tensor.h>
13 #include <ATen/native/layer_norm.h>
14 #include <ATen/native/nested/NestedTensorUtils.h>
15
16 namespace at::native {
17
// Out-of-place elementwise absolute value for a nested tensor, computed by
// mapping at::abs through map_nt.
Tensor NestedTensor_abs(const Tensor& self) {
  return map_nt(self, [](const Tensor& flat) { return at::abs(flat); });
}
21
NestedTensor_abs_(Tensor & self)22 Tensor& NestedTensor_abs_(Tensor& self) {
23 auto self_ptr = get_nested_tensor_impl(self);
24 check_numel_equals_buffer_size(self_ptr);
25 auto buffer = self_ptr->get_buffer();
26 at::abs_(buffer);
27 return self;
28 }
29
// where(condition, self, other) for a nested `condition`, a nested `other`
// and a regular (non-nested) `self`.
//
// The result is a new nested tensor with the nested sizes of `other`. For each
// component i, at::where_out(out[i], condition[i], self, other[i]) is
// evaluated, so the dense `self` is combined with every component directly.
//
// Output dtype: at::where_out checks that `out` has dtype
// result_type(self, other[i]). If the output were always allocated with
// other's dtype, valid promotions such as a floating-point `self` with an
// integer `other` would be rejected inside where_out. The promoted dtype is
// therefore computed before allocating.
Tensor NestedTensor_where(const Tensor& condition, const Tensor& self, const Tensor& other) {
  TORCH_CHECK(condition.is_nested(), "condition must be nested");
  TORCH_CHECK(other.is_nested(), "other must be nested");
  TORCH_CHECK(!self.is_nested(), "self must not be nested");

  auto condition_ptr = get_nested_tensor_impl(condition);
  auto other_ptr = get_nested_tensor_impl(other);

  const int64_t ntensors = condition_ptr->size(0);
  TORCH_CHECK(other_ptr->size(0) == ntensors, "condition and other must have the same number of tensors");

  // Split the nested inputs into their dense components.
  std::vector<Tensor> condition_unbind = condition.unbind();
  std::vector<Tensor> other_unbind = other.unbind();

  // Every component of `other` has other's dtype and the same dimensionality,
  // so the first component determines the promoted dtype used by all the
  // where_out calls below. With no components, fall back to other's dtype.
  const ScalarType out_dtype = other_unbind.empty()
      ? other.scalar_type()
      : at::result_type(self, other_unbind[0]);

  // Allocate exactly numel() elements instead of the full underlying storage
  // of `other`, which may be larger than `other` itself (e.g. when it is a
  // view). wrap_buffer is given only the sizes, so the output is packed.
  const Tensor& other_buffer = other_ptr->get_unsafe_storage_as_tensor();
  Tensor output_buffer = other_buffer.new_empty(
      {other_ptr->numel()}, other_buffer.options().dtype(out_dtype));

  // Clone the sizes so the output does not share other's metadata tensor.
  Tensor output = wrap_buffer(output_buffer, other_ptr->get_nested_sizes().clone());
  std::vector<Tensor> output_unbind = output.unbind();

  // Evaluate where for each (condition[i], self, other[i]) triplet, writing
  // directly into the corresponding component of the output.
  for (int64_t i = 0; i < ntensors; i++) {
    at::where_out(
        output_unbind[i],
        condition_unbind[i],
        self, // Note: self is not nested, so we use it directly
        other_unbind[i]);
  }

  return output;
}
67
// Out-of-place elementwise sign for a nested tensor, computed by mapping
// at::sgn through map_nt.
Tensor NestedTensor_sgn(const Tensor& self) {
  return map_nt(self, [](const Tensor& flat) { return at::sgn(flat); });
}
71
NestedTensor_sgn_(Tensor & self)72 Tensor& NestedTensor_sgn_(Tensor& self) {
73 auto self_ptr = get_nested_tensor_impl(self);
74 check_numel_equals_buffer_size(self_ptr);
75 auto buffer = self_ptr->get_buffer();
76 buffer.sgn_();
77 return self;
78 }
79
NestedTensor_logical_not_(Tensor & self)80 Tensor& NestedTensor_logical_not_(Tensor& self){
81 auto self_ptr = get_nested_tensor_impl(self);
82 check_numel_equals_buffer_size(self_ptr);
83 auto buffer = self_ptr->get_buffer();
84 buffer.logical_not_();
85 return self;
86 }
87
// Out-of-place elementwise logical NOT for a nested tensor, computed by
// mapping at::logical_not through map_nt.
Tensor NestedTensor_logical_not(const Tensor& self) {
  return map_nt(self, [](const Tensor& flat) { return at::logical_not(flat); });
}
91
NestedTensor_relu_(Tensor & self)92 Tensor& NestedTensor_relu_(Tensor& self) {
93 auto self_ptr = get_nested_tensor_impl(self);
94 check_numel_equals_buffer_size(self_ptr);
95 auto buffer = self_ptr->get_buffer();
96 at::relu_(buffer);
97 return self;
98 }
99
// Out-of-place ReLU for a nested tensor, computed by mapping at::relu
// through map_nt.
Tensor NestedTensor_relu(const Tensor& self) {
  return map_nt(self, [](const Tensor& flat) { return at::relu(flat); });
}
103
NestedTensor_gelu_(Tensor & self,c10::string_view approximate)104 Tensor& NestedTensor_gelu_(Tensor& self, c10::string_view approximate) {
105 auto self_ptr = get_nested_tensor_impl(self);
106 check_numel_equals_buffer_size(self_ptr);
107 auto buffer = self_ptr->get_buffer();
108 at::gelu_(buffer, approximate);
109 return self;
110 }
111
// Out-of-place GELU for a nested tensor. The approximation mode is captured
// by value into the mapping callable handed to map_nt.
Tensor NestedTensor_gelu(const Tensor& self, c10::string_view approximate) {
  auto apply_gelu = [approximate](const Tensor& flat) {
    return at::gelu(flat, approximate);
  };
  return map_nt(self, apply_gelu);
}
119
NestedTensor_tanh_(Tensor & self)120 Tensor& NestedTensor_tanh_(Tensor& self) {
121 auto self_ptr = get_nested_tensor_impl(self);
122 check_numel_equals_buffer_size(self_ptr);
123 auto buffer = self_ptr->get_buffer();
124 at::tanh_(buffer);
125 return self;
126 }
127
// Out-of-place hyperbolic tangent for a nested tensor, computed by mapping
// at::tanh through map_nt.
Tensor NestedTensor_tanh(const Tensor& self) {
  return map_nt(self, [](const Tensor& flat) { return at::tanh(flat); });
}
131
NestedTensor_neg_(Tensor & self)132 Tensor& NestedTensor_neg_(Tensor& self) {
133 auto self_ptr = get_nested_tensor_impl(self);
134 check_numel_equals_buffer_size(self_ptr);
135 auto buffer = self_ptr->get_buffer();
136 at::neg_(buffer);
137 return self;
138 }
139
// Out-of-place negation for a nested tensor, computed by mapping at::neg
// through map_nt.
Tensor NestedTensor_neg(const Tensor& self) {
  return map_nt(self, [](const Tensor& flat) { return at::neg(flat); });
}
143
zero_nested_(Tensor & self)144 Tensor& zero_nested_(Tensor& self) {
145 const auto& self_buf = get_nested_tensor_impl(self)->get_buffer();
146 self_buf.fill_(0);
147 return self;
148 }
149
// Out-of-place SiLU for a nested tensor, computed by mapping at::silu
// through map_nt.
Tensor NestedTensor_silu(const Tensor& self) {
  return map_nt(self, [](const Tensor& flat) { return at::silu(flat); });
}
153
NestedTensor_silu_(Tensor & self)154 Tensor& NestedTensor_silu_(Tensor& self){
155 auto self_ptr = get_nested_tensor_impl(self);
156 check_numel_equals_buffer_size(self_ptr);
157 auto buffer = self_ptr->get_buffer();
158 at::silu_(buffer);
159 return self;
160 }
161
// Out-of-place elementwise sine for a nested tensor, computed by mapping
// at::sin through map_nt.
Tensor sin_nested(const Tensor& self) {
  return map_nt(self, [](const Tensor& flat) { return at::sin(flat); });
}
165
// Out-of-place elementwise cosine for a nested tensor, computed by mapping
// at::cos through map_nt.
Tensor cos_nested(const Tensor& self) {
  return map_nt(self, [](const Tensor& flat) { return at::cos(flat); });
}
169
// Returns a nested tensor whose data lives in pinned memory. The underlying
// storage of `self` is pinned via at::_pin_memory and re-wrapped with the
// unchanged nested sizes, strides and storage offsets of `self`.
Tensor _pin_memory_nested(const Tensor& self, std::optional<Device> device) {
  auto* impl = get_nested_tensor_impl(self);
  const auto& nested_sizes = impl->get_nested_sizes();
  const auto& nested_strides = impl->get_nested_strides();
  const auto& storage_offsets = impl->get_storage_offsets();
  return wrap_buffer(
      at::_pin_memory(impl->get_unsafe_storage_as_tensor(), device),
      nested_sizes,
      nested_strides,
      storage_offsets);
}
179
180 } // namespace at::native
181