// aten/src/ATen/native/cpu/CrossKernel.cpp
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS
#include <ATen/native/Cross.h>

#include <numeric>
#include <iterator>
#include <algorithm>
#include <vector>

#include <ATen/core/Tensor.h>
#include <ATen/Dispatch.h>
#include <ATen/Parallel.h>
#include <ATen/TensorIterator.h>
#include <c10/util/irange.h>
namespace at::native {
namespace {

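// Computes the 3-element cross product result = a x b along dimension `dim`.
// The tensors are viewed as a collection of `numel() / 3` independent vector
// triples, and the work is parallelized over those triples with parallel_for.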
template<typename scalar_t>
static void apply_cross(const Tensor& result, const Tensor& a, const Tensor& b, const int64_t dim) {
  int64_t total = a.numel() / 3;
  int64_t a_stride = a.stride(dim);
  int64_t b_stride = b.stride(dim);
  int64_t r_stride = result.stride(dim);

  const scalar_t *a_ptr = a.const_data_ptr<scalar_t>();
  const scalar_t *b_ptr = b.const_data_ptr<scalar_t>();
  scalar_t *r_ptr = result.data_ptr<scalar_t>();

  parallel_for(0, total, internal::GRAIN_SIZE, [&](int64_t s, int64_t e) {
    const int64_t a_dim = a.dim();
    std::vector<int64_t> position_in_dims(a_dim);
    int64_t index_in_curr_dim = s;
    int64_t a_start = 0;
    int64_t b_start = 0;
    int64_t r_start = 0;
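    // Decompose the linear index s into coordinates over every dimension
    // except `dim`, accumulating the corresponding starting element offsets
    // into a, b, and result.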
    for (const auto i : c10::irange(a.dim())) {
      if (i == dim) continue;
      position_in_dims[i] = index_in_curr_dim % a.size(i);
      a_start += (index_in_curr_dim % a.size(i)) * a.stride(i);
      b_start += (index_in_curr_dim % b.size(i)) * b.stride(i);
      r_start += (index_in_curr_dim % result.size(i)) * result.stride(i);
      index_in_curr_dim = index_in_curr_dim / a.size(i);
    }

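    // Each iteration writes one full cross product
    // (r0 = a1*b2 - a2*b1, r1 = a2*b0 - a0*b2, r2 = a0*b1 - a1*b0),
    // then advances to the next vector triple in this chunk.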
    while (s < e) {
      r_ptr[r_start+0*r_stride] = a_ptr[a_start+1*a_stride]*b_ptr[b_start+2*b_stride] - a_ptr[a_start+2*a_stride]*b_ptr[b_start+1*b_stride];
      r_ptr[r_start+1*r_stride] = a_ptr[a_start+2*a_stride]*b_ptr[b_start+0*b_stride] - a_ptr[a_start+0*a_stride]*b_ptr[b_start+2*b_stride];
      r_ptr[r_start+2*r_stride] = a_ptr[a_start+0*a_stride]*b_ptr[b_start+1*b_stride] - a_ptr[a_start+1*a_stride]*b_ptr[b_start+0*b_stride];
      s++;

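      // Advance the coordinates odometer-style: bump dimension i, and if it
      // wraps past its size, subtract the accumulated offset, reset it to
      // zero, and carry into the next non-`dim` dimension.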
      for (const auto i : c10::irange(a.dim())) {
        if (i == dim) {
          continue;
        }
        position_in_dims[i]++;
        a_start += a.stride(i);
        b_start += b.stride(i);
        r_start += result.stride(i);
        if (position_in_dims[i] == a.size(i) && i != a.dim()-1) {
            a_start -= position_in_dims[i] * a.stride(i);
            b_start -= position_in_dims[i] * b.stride(i);
            r_start -= position_in_dims[i] * result.stride(i);
            position_in_dims[i] = 0;
        } else {
          break;
        }
      }
    }
  });
}

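// Dispatches apply_cross over all supported dtypes (all standard types plus
// complex, BFloat16, and Half), selected by the result's scalar type.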
static void cross_kernel_impl(const Tensor& result, const Tensor& a, const Tensor& b, const int64_t dim) {
  AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND2(kBFloat16, kHalf, result.scalar_type(), "cross", [&]() {
    apply_cross<scalar_t>(result, a, b, dim);
  });
}

} // anonymous namespace

REGISTER_DISPATCH(cross_stub, &cross_kernel_impl);

} // namespace at::native