/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/compiler.h>
#include <math.h>
#include <string.h>

namespace torch {
namespace executor {
namespace native {
using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;
using Scalar = exec_aten::Scalar;
namespace {

/**
 * Returns true if the two arrays are close according to the description of
 * `tensors_are_close()`.
 *
 * T must be a floating point type. Non-floating-point data should be
 * compared directly.
 */
template <typename T>
bool data_is_close(
    const T* a,
    const T* b,
    size_t numel,
    double rtol,
    double atol) {
  for (size_t i = 0; i < numel; i++) {
    if (rtol == 0 && atol == 0) {
      // Exact comparison; avoid unnecessary math.
      if (a[i] != b[i]) {
        return false;
      }
    } else {
      auto allowed_error = atol + fabs(rtol * b[i]);
      auto actual_error = fabs(a[i] - b[i]);
      if (!isfinite(actual_error) || actual_error > allowed_error) {
        return false;
      }
    }
  }
  return true;
}
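
// Worked example of the tolerance check above (illustrative values, not part
// of the kernel): with rtol = 1e-5 and atol = 1e-8, comparing a[i] = 1.000011
// against b[i] = 1.0 gives allowed_error = 1e-8 + fabs(1e-5 * 1.0) = 1.001e-5
// and actual_error = 1.1e-5, so data_is_close() returns false; a[i] = 1.000009
// (actual_error = 9e-6) would pass.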

/**
 * Returns true if the tensors have the same shape and dtype, and if all of
 * their elements are close to each other.
 *
 * A number A is close to B when either:
 *
 * (1) A is equal to B.
 * (2) The error abs(A - B) is finite and no greater than the max error
 *     (atol + abs(rtol * B)).
 *
 * NOTE: rtol/atol are ignored for non-floating-point dtypes.
 */
bool tensors_are_close(
    const Tensor& a,
    const Tensor& b,
    double rtol,
    double atol) {
  // TODO(dbort): Listen to strides instead of assuming that the data is
  // contiguous.

  if (a.scalar_type() == ScalarType::Float) {
    return data_is_close<float>(
        a.const_data_ptr<float>(),
        b.const_data_ptr<float>(),
        a.numel(),
        rtol,
        atol);
  } else if (a.scalar_type() == ScalarType::Double) {
    return data_is_close<double>(
        a.const_data_ptr<double>(),
        b.const_data_ptr<double>(),
        a.numel(),
        rtol,
        atol);
  } else {
    // Non-floating-point types can be compared bitwise.
    return memcmp(a.const_data_ptr(), b.const_data_ptr(), a.nbytes()) == 0;
  }
}
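
// Note on the dispatch above: for non-floating-point dtypes the tolerances
// are ignored, so e.g. two Long tensors holding {1, 2} and {1, 3} compare
// not-close even with large rtol/atol, because their bytes differ under
// memcmp.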
} // namespace

Tensor& allclose_out(
    const Tensor& self,
    const Tensor& other,
    double rtol,
    double atol,
    ET_UNUSED bool equal_nan,
    ET_UNUSED bool dummy_param,
    Tensor& out) {
  ET_CHECK_SAME_SHAPE_AND_DTYPE2(self, other);
  ET_CHECK_MSG(
      out.scalar_type() == ScalarType::Bool,
      "Out tensor must be type Bool; saw type %" PRId8,
      static_cast<int8_t>(out.scalar_type()));
  ET_CHECK_MSG(
      tensors_have_same_dim_order(self, other, out),
      "self, other and out tensors should have same dim order");
  ET_CHECK_MSG(
      out.numel() == 1,
      "Out tensor must be a single element; saw %zu elements",
      (size_t)out.numel());
  auto out_data = out.mutable_data_ptr<bool>();
  out_data[0] = tensors_are_close(self, other, rtol, atol);
  return out;
}
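
// Usage sketch (hypothetical; assumes the caller has already allocated
// `self`, `other`, and a single-element Bool tensor `bool_out`, e.g. via the
// ExecuTorch test TensorFactory):
//
//   allclose_out(
//       self, other, /*rtol=*/1e-5, /*atol=*/1e-8,
//       /*equal_nan=*/false, /*dummy_param=*/false, bool_out);
//   bool close = bool_out.const_data_ptr<bool>()[0];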

/**
 * Note: This custom operator has two variants: allclose.Tensor (a functional
 * variant that does not mutate its arguments) and allclose.out (an out
 * variant that mutates `out`). We need to register both with the PyTorch
 * runtime so that they are visible to the ExecuTorch compiler. Eventually
 * only allclose.out will be seen by the ExecuTorch runtime. With this setup,
 * the portable kernel for allclose.Tensor can be implemented as a wrapper
 * around allclose.out: we can easily instantiate an at::Tensor for the out
 * argument and pass it to allclose.out. This logic only needs to work in
 * "ATen mode" for the ExecuTorch compiler, since allclose.Tensor is not
 * exposed in the ExecuTorch runtime.
 */
Tensor allclose_tensor(
    ET_UNUSED const Tensor& self,
    ET_UNUSED const Tensor& other,
    ET_UNUSED double rtol,
    ET_UNUSED double atol,
    ET_UNUSED bool equal_nan,
    ET_UNUSED bool dummy_param) {
#ifdef USE_ATEN_LIB
  Tensor out =
      torch::tensor({false}, c10::TensorOptions(c10::ScalarType::Bool));
  allclose_out(self, other, rtol, atol, equal_nan, dummy_param, out);
  return out;
#else
  ET_ASSERT_UNREACHABLE();
#endif
}
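
// ATen-mode usage sketch (hypothetical; only meaningful when USE_ATEN_LIB is
// defined, per the note above):
//
//   at::Tensor a = at::rand({2, 2});
//   at::Tensor b = a.clone();
//   at::Tensor result = allclose_tensor(
//       a, b, /*rtol=*/1e-5, /*atol=*/1e-8,
//       /*equal_nan=*/false, /*dummy_param=*/false);
//   bool close = result.item<bool>(); // expected true: b is a copy of a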

Tensor& allclose_out(
    KernelRuntimeContext& ctx,
    const Tensor& self,
    const Tensor& other,
    double rtol,
    double atol,
    ET_UNUSED bool equal_nan,
    ET_UNUSED bool dummy_param,
    Tensor& out) {
  (void)ctx;
  // TODO(larryliu): Add a context arg to the real op function and remove this
  // wrapper
  return allclose_out(self, other, rtol, atol, equal_nan, dummy_param, out);
}

Tensor allclose_tensor(
    ET_UNUSED KernelRuntimeContext& ctx,
    ET_UNUSED const Tensor& self,
    ET_UNUSED const Tensor& other,
    ET_UNUSED double rtol,
    ET_UNUSED double atol,
    ET_UNUSED bool equal_nan,
    ET_UNUSED bool dummy_param) {
  // TODO(larryliu): Add a context arg to the real op function and remove this
  // wrapper
  ET_ASSERT_UNREACHABLE();
}
} // namespace native
} // namespace executor
} // namespace torch