1 #define TORCH_ASSERT_ONLY_METHOD_OPERATORS
2 #include <ATen/core/Tensor.h>
3 #include <ATen/TensorOperators.h>
4
5 #ifndef AT_PER_OPERATOR_HEADERS
6 #include <ATen/Functions.h>
7 #include <ATen/NativeFunctions.h>
8 #else
9 #include <ATen/ops/arange.h>
10 #include <ATen/ops/cartesian_prod_native.h>
11 #include <ATen/ops/combinations_native.h>
12 #include <ATen/ops/empty.h>
13 #include <ATen/ops/full.h>
14 #include <ATen/ops/meshgrid.h>
15 #include <ATen/ops/stack.h>
16 #endif
17
18 #include <vector>
19
20 namespace {
21
22 using namespace at;
23
_triu_mask(int64_t n,int64_t dims,bool diagonal,TensorOptions opt)24 Tensor _triu_mask(int64_t n, int64_t dims, bool diagonal, TensorOptions opt) {
25 // get a mask that has value 1 whose indices satisfies i < j < k < ...
26 // or i <= j <= k <= ... (depending on diagonal)
27 Tensor range = at::arange(n, opt.dtype(kLong));
28 std::vector<Tensor> index_grids = at::meshgrid(std::vector<Tensor>(dims, range), "ij");
29 Tensor mask = at::full(index_grids[0].sizes(), true, opt.dtype(kBool));
30 if(diagonal) {
31 for(int64_t i = 0; i < dims - 1; i++) {
32 mask *= index_grids[i] <= index_grids[i+1];
33 }
34 } else {
35 for(int64_t i = 0; i < dims - 1; i++) {
36 mask *= index_grids[i] < index_grids[i+1];
37 }
38 }
39 return mask;
40 }
41
42 } // namespace
43
44 namespace at::native {
45
cartesian_prod(TensorList tensors)46 Tensor cartesian_prod(TensorList tensors) {
47 for(const Tensor &t : tensors) {
48 TORCH_CHECK(t.dim() == 1, "Expect a 1D vector, but got shape ", t.sizes());
49 }
50 if (tensors.size() == 1) {
51 return tensors[0];
52 }
53 std::vector<Tensor> grids = at::meshgrid(tensors, "ij");
54 for(Tensor &t : grids) {
55 t = t.flatten();
56 }
57 return at::stack(grids, 1);
58 }
59
Tensor combinations(const Tensor& self, int64_t r, bool with_replacement) {
  // Returns every length-r combination of the elements of the 1-D tensor
  // `self`, one combination per row (shape (count, r)). Index tuples must be
  // strictly increasing, or non-decreasing when `with_replacement` is true.
  TORCH_CHECK(self.dim() == 1, "Expect a 1D vector, but got shape ", self.sizes());
  TORCH_CHECK(r >= 0, "Expect a non-negative number, but got ", r);
  // r == 0 yields an empty 1-D tensor with self's options.
  if (r == 0) {
    return at::empty({0}, self.options());
  }
  // With "ij" indexing, grids[k] holds self[i_k] at grid point (i_0, ..., i_{r-1}).
  const std::vector<Tensor> grids = at::meshgrid(std::vector<Tensor>(r, self), "ij");
  // Keeps only the grid points whose index tuple satisfies the ordering rule.
  const Tensor keep = _triu_mask(self.numel(), r, with_replacement, self.options());
  std::vector<Tensor> columns;
  columns.reserve(grids.size());
  for (const Tensor& grid : grids) {
    columns.push_back(grid.masked_select(keep));
  }
  // Column k is the k-th element of every surviving combination.
  return at::stack(columns, 1);
}
74
75 } // namespace at::native
76