xref: /aosp_15_r20/external/pytorch/aten/src/ATen/native/ComplexHelper.h (revision da0073e96a02ea20f0ac840b70461e3646d07c45)
1 #pragma once
2 
3 #include <ATen/core/Tensor.h>
4 #include <c10/util/irange.h>
5 
6 #ifndef AT_PER_OPERATOR_HEADERS
7 #include <ATen/NativeFunctions.h>
8 #else
9 #include <ATen/ops/view_as_real_native.h>
10 #include <ATen/ops/view_as_complex_native.h>
11 
12 #include <utility>
13 #endif
14 
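// WARNING: this header contains non-inline function definitions (view_as_real
// and view_as_complex) and should only be included from ONE .cpp file;
// including it from several translation units produces duplicate definitions
// at link time.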
17 
18 namespace at::native {
19 
20 // View tensor with new dtype, storage offset, sizes and strides
view_tensor(const Tensor & tensor,ScalarType dtype,c10::SymInt offset,SymIntArrayRef sizes,SymIntArrayRef strides)21 inline Tensor view_tensor(
22     const Tensor &tensor, ScalarType dtype,
23     c10::SymInt offset, SymIntArrayRef sizes, SymIntArrayRef strides) {
24   Storage storage = tensor.storage();
25   auto key_set = tensor.key_set().remove(DispatchKey::Conjugate);
26   auto new_tensor = detail::make_tensor<TensorImpl>(
27       c10::TensorImpl::VIEW, std::move(storage), key_set, scalarTypeToTypeMeta(dtype));
28   auto * impl = new_tensor.unsafeGetTensorImpl();
29   impl->set_sizes_and_strides(sizes, strides, offset);
30   return new_tensor;
31 }
32 
computeStrideForViewAsReal(SymIntArrayRef oldstride)33 inline SymDimVector computeStrideForViewAsReal(SymIntArrayRef oldstride) {
34   SymDimVector res(oldstride.size() + 1);
35   for (const auto i : c10::irange(oldstride.size())) {
36     res[i] = oldstride[i] * 2;
37   }
38   res.back() = 1;
39   return res;
40 }
41 
// Views a complex tensor as a tensor of the matching real dtype (chosen by
// c10::toRealValueType). All existing dimensions are kept, and a trailing
// dimension of size 2 is appended that holds each element's (real, imag) pair.
// Unlike view_as_real, this does not reject tensors whose conjugate bit is set.
// Fails a TORCH_CHECK if `self` is not complex.
inline Tensor _view_as_real_physical(const Tensor& self) {
  TORCH_CHECK(self.is_complex(), "view_as_real is only supported for complex tensors");

  const auto complex_sizes = self.sym_sizes();
  SymDimVector real_sizes(complex_sizes.begin(), complex_sizes.end());
  // New innermost dimension: index 0 is the real part, index 1 the imaginary part.
  real_sizes.push_back(2);

  // Each complex element spans two real elements, so both the strides and the
  // storage offset are doubled.
  const auto real_strides = computeStrideForViewAsReal(self.sym_strides());
  auto real_offset = self.sym_storage_offset() * 2;

  return view_tensor(
      self, c10::toRealValueType(self.scalar_type()), std::move(real_offset), real_sizes, real_strides);
}
55 
56 // expects as input a complex tensor and returns back a tensor
57 // with corresponding real dtype containing the complex values
58 // in the last two dimensions
view_as_real(const Tensor & self)59 Tensor view_as_real(const Tensor& self) {
60   TORCH_CHECK(!self.is_conj(), "view_as_real doesn't work on unresolved conjugated tensors.  To resolve the conjugate tensor so you can view it as real, use self.resolve_conj(); however, be warned that the resulting tensor will NOT alias the original.");
61   return _view_as_real_physical(self);
62 }
63 
computeStrideForViewAsComplex(SymIntArrayRef oldstride)64 inline SymDimVector computeStrideForViewAsComplex(SymIntArrayRef oldstride) {
65   const auto dim = oldstride.size();
66   TORCH_CHECK(dim > 0 && oldstride[dim - 1] == 1, "Tensor must have a last dimension with stride 1");
67 
68   SymDimVector res(dim - 1);
69   for (const auto i : c10::irange(res.size())) {
70     TORCH_CHECK(oldstride[i] % 2 == 0, "Tensor must have a stride divisible by 2 for all but last dimension");
71     res[i] = oldstride[i] / 2;
72   }
73   return res;
74 }
75 
76 // expects as input a float or double tensor with last dimension of size 2
77 // and returns back a tensor with corresponding complex dtype
view_as_complex(const Tensor & self)78 Tensor view_as_complex(const Tensor& self) {
79   TORCH_CHECK(
80     self.scalar_type() == kFloat || self.scalar_type() == kDouble || self.scalar_type() == kHalf,
81     "view_as_complex is only supported for half, float and double tensors, but got a tensor of scalar type: ", self.scalar_type());
82 
83   auto old_sizes = self.sym_sizes();
84   TORCH_CHECK(!old_sizes.empty(), "Input tensor must have one or more dimensions");
85   TORCH_CHECK(old_sizes[old_sizes.size()-1] == 2, "Tensor must have a last dimension of size 2");
86   SymDimVector new_sizes(old_sizes.begin(), old_sizes.end() - 1);
87 
88   const auto new_strides = computeStrideForViewAsComplex(self.sym_strides());
89   const auto complex_type = c10::toComplexType(self.scalar_type());
90 
91   TORCH_CHECK(self.sym_storage_offset() % 2 == 0, "Tensor must have a storage_offset divisible by 2");
92   const auto new_storage_offset = self.sym_storage_offset() / 2;
93 
94   return view_tensor(self, complex_type, new_storage_offset, new_sizes, new_strides);
95 }
96 
97 } // namespace at::native
98