// Copyright 2004-present Facebook. All Rights Reserved.
#define TORCH_ASSERT_ONLY_METHOD_OPERATORS

#include <ATen/core/Tensor.h>
#include <ATen/FunctionalInverses.h>
#include <ATen/ScalarOps.h>
#include <ATen/Parallel.h>

#ifndef AT_PER_OPERATOR_HEADERS
#include <ATen/Functions.h>
#include <ATen/NativeFunctions.h>
#else
#include <ATen/ops/_test_ambiguous_defaults_native.h>
#include <ATen/ops/_test_autograd_multiple_dispatch_native.h>
#include <ATen/ops/_test_autograd_multiple_dispatch_view_native.h>
#include <ATen/ops/_test_check_tensor_native.h>
#include <ATen/ops/_test_parallel_materialize_native.h>
#include <ATen/ops/_test_optional_filled_intlist_native.h>
#include <ATen/ops/_test_optional_floatlist_native.h>
#include <ATen/ops/_test_optional_intlist_native.h>
#include <ATen/ops/_test_string_default_native.h>
#include <ATen/ops/_test_warn_in_autograd_native.h>
#include <ATen/ops/empty_like.h>
#endif

#include <c10/util/irange.h>

namespace at::native {

/// If addends is nullopt, return values.
/// Else, return a new tensor containing the elementwise sums.
Tensor _test_optional_intlist(
    const Tensor& values,
    at::OptionalIntArrayRef addends) {
  if (!addends) {
    return values;
  }
  TORCH_CHECK(values.dim() == 1);
  Tensor output = at::empty_like(values);
  auto inp = values.accessor<int,1>();
  auto out = output.accessor<int,1>();
  for (const auto i : c10::irange(values.size(0))) {
    out[i] = inp[i] + addends->at(i);
  }
  return output;
}
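
// A minimal call-site sketch for the op above (hypothetical usage; the
// accessor<int,1> requires a 1-D kInt tensor, and the loop assumes addends
// holds at least values.size(0) entries):
//
//   Tensor values = at::tensor({1, 2, 3}, at::kInt);
//   std::vector<int64_t> addends = {10, 20, 30};
//   Tensor summed = at::native::_test_optional_intlist(values, addends);
//   // summed == {11, 22, 33}; passing std::nullopt returns values as-is.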

/// If addends is nullopt, return values.
/// Else, return a new tensor containing the elementwise sums.
Tensor _test_optional_floatlist(
    const Tensor& values,
    std::optional<ArrayRef<double>> addends) {
  if (!addends) {
    return values;
  }
  TORCH_CHECK(values.dim() == 1);
  Tensor output = at::empty_like(values);
  auto inp = values.accessor<float,1>();
  auto out = output.accessor<float,1>();
  for (const auto i : c10::irange(values.size(0))) {
    out[i] = inp[i] + addends->at(i);
  }
  return output;
}
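
// Note the dtype pairing above: addends carries doubles while the accessor
// reads floats, so each sum is implicitly narrowed back to float on
// assignment. A hypothetical call site, assuming a 1-D kFloat tensor:
//
//   Tensor values = at::tensor({1.0f, 2.0f}, at::kFloat);
//   std::vector<double> addends = {0.5, 0.25};
//   Tensor summed = at::native::_test_optional_floatlist(values, addends);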

// Test that default strings can handle escape sequences properly (although
// commas are broken).
Tensor _test_string_default(const Tensor& dummy, c10::string_view a, c10::string_view b) {
  const c10::string_view expect = "\"'\\";
  TORCH_CHECK(a == expect, "Default A failed");
  TORCH_CHECK(b == expect, "Default B failed");
  return dummy;
}
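
// For reference: the C++ literal "\"'\\" spells the three characters " ' \
// so the checks above pass only if the defaults declared in the op's schema
// (in native_functions.yaml, not reproduced here) survive parsing intact.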

// Test that overloads with ambiguity created by defaulted parameters work.
// The operator declared first always takes priority.

// Overload a
Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, int64_t b) {
  TORCH_CHECK(a == 1);
  TORCH_CHECK(b == 1);
  return c10::scalar_to_tensor(1);
}

// Overload b
Tensor _test_ambiguous_defaults(const Tensor& dummy, int64_t a, c10::string_view b) {
  TORCH_CHECK(a == 2);
  TORCH_CHECK(b == "2");
  return c10::scalar_to_tensor(2);
}
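
// Declaration order matters for callers that resolve overloads by schema
// matching (e.g. the Python frontend), where a bare
// _test_ambiguous_defaults(dummy) should land in overload a. From C++, the
// ambiguity is avoided by spelling out the arguments (hypothetical calls):
//
//   at::_test_ambiguous_defaults(dummy, 1, 1);    // overload a
//   at::_test_ambiguous_defaults(dummy, 2, "2");  // overload b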

Tensor _test_warn_in_autograd(const Tensor& self) {
  return self.clone();
}

// Test registration of per-dispatch-key derivatives in derivatives.yaml.
// See derivatives.yaml for dummy registrations.

Tensor _test_autograd_multiple_dispatch_fullcoverage(const Tensor& self) {
  return self.clone();
}

Tensor _test_autograd_multiple_dispatch_ntonly(const Tensor& self, bool b) {
  return self.clone();
}

// Test derivative dispatch registration for view_copy ops
Tensor _test_autograd_multiple_dispatch_view(const Tensor& self) {
  return self.view(-1);
}
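
// Since view(-1) aliases self's storage, functionalization needs an inverse
// for this op; see the FunctionalInverses stub at the end of this file.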

Tensor _test_check_tensor(const Tensor& self) {
  TORCH_CHECK_TENSOR_ALL(self, "Test message for TORCH_CHECK_TENSOR_ALL");
  return self.clone();
}
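
// TORCH_CHECK_TENSOR_ALL is expected to throw unless every element of self
// is truthy, so a hypothetical pair of calls looks like:
//
//   _test_check_tensor(at::ones({2}));   // passes, returns a clone
//   _test_check_tensor(at::zeros({2}));  // raises with the message above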

Tensor _test_parallel_materialize(
    const Tensor& self,
    int64_t num_parallel,
    bool skip_first) {
  at::parallel_for(0, num_parallel, 1, [&](int64_t begin, int64_t end) {
    // NOTE: skip_first is meant to avoid triggering the materialization from
    // the first thread, to ensure that the subthreads throw the error
    // correctly. On some platforms, the first thread is the main thread and it
    // begins executing the loop function much earlier than the subthreads.
    if (skip_first && begin == 0 && end == 1) {
      return;
    } else {
      self.mutable_data_ptr();
    }
  });
  return self;
}
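
// What the loop above exercises: mutable_data_ptr() is the call that triggers
// the materialization mentioned in the NOTE (e.g. for lazily-copied storage),
// so issuing it from several parallel_for tasks at once stress-tests that
// path. A hypothetical call that materializes only from worker threads:
//
//   _test_parallel_materialize(self, /*num_parallel=*/4, /*skip_first=*/true);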

} // namespace at::native

namespace at::functionalization {

// view ops must have a functional inverse registered
Tensor FunctionalInverses::_test_autograd_multiple_dispatch_view_inverse(
    const at::Tensor& base,
    const at::Tensor& mutated_view,
    InverseReturnMode inverse_return_mode) {
  TORCH_INTERNAL_ASSERT(false,
      "Attempted to call _test_autograd_multiple_dispatch_view_inverse() during the functionalization pass. ",
      "This function is for testing only and should never be called.");
  return Tensor();
}

} // namespace at::functionalization