/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <executorch/runtime/kernel/kernel_includes.h>
#include <executorch/runtime/platform/compiler.h>

#include <inttypes.h> // For PRId8, used in error messages below.
#include <math.h>
#include <string.h>

namespace torch {
namespace executor {
namespace native {
using Tensor = exec_aten::Tensor;
using ScalarType = exec_aten::ScalarType;
using Scalar = exec_aten::Scalar;
namespace {

/**
 * Returns true if the two arrays are close, as described in the docs for
 * `tensors_are_close()`.
 *
 * T must be a floating-point type. Non-floating-point data should be compared
 * directly.
 */
template <typename T>
bool data_is_close(
    const T* a,
    const T* b,
    size_t numel,
    double rtol,
    double atol) {
  for (size_t i = 0; i < numel; i++) {
    if (rtol == 0 && atol == 0) {
      // Exact comparison; avoid unnecessary math.
      if (a[i] != b[i]) {
        return false;
      }
    } else {
      auto allowed_error = atol + fabs(rtol * b[i]);
      auto actual_error = fabs(a[i] - b[i]);
      if (!isfinite(actual_error) || actual_error > allowed_error) {
        return false;
      }
    }
  }
  return true;
}

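// Example (illustrative): with rtol = 1e-5 and atol = 1e-8, comparing
// a[i] = 2.00001 against b[i] = 2.0 gives
//   allowed_error = 1e-8 + fabs(1e-5 * 2.0) = 2.001e-5
//   actual_error = fabs(2.00001 - 2.0) ~= 1e-5
// so the elements are considered close. A NaN or infinite difference always
// fails the check.
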
/**
 * Returns true if all elements of the two tensors are close to each other.
 * Assumes that `a` and `b` have the same shape and dtype; the caller is
 * responsible for checking this.
 *
 * A number A is close to B when either:
 *
 * (1) A is equal to B.
 * (2) The error abs(A - B) is finite and less than the max error
 *     (atol + abs(rtol * B)).
 *
 * NOTE: rtol/atol are ignored for non-floating-point dtypes.
 */
bool tensors_are_close(
    const Tensor& a,
    const Tensor& b,
    double rtol,
    double atol) {
  // TODO(dbort): Listen to strides instead of assuming that the data is
  // contiguous.

  if (a.scalar_type() == ScalarType::Float) {
    return data_is_close<float>(
        a.const_data_ptr<float>(),
        b.const_data_ptr<float>(),
        a.numel(),
        rtol,
        atol);
  } else if (a.scalar_type() == ScalarType::Double) {
    return data_is_close<double>(
        a.const_data_ptr<double>(),
        b.const_data_ptr<double>(),
        a.numel(),
        rtol,
        atol);
  } else {
    // Non-floating-point types can be compared bitwise.
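    // Only Float and Double are special-cased above, so every other dtype
    // (including Half/BFloat16, if it ever reached here) is compared bitwise.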
    return memcmp(a.const_data_ptr(), b.const_data_ptr(), a.nbytes()) == 0;
  }
}
} // namespace

Tensor& allclose_out(
    const Tensor& self,
    const Tensor& other,
    double rtol,
    double atol,
    ET_UNUSED bool equal_nan,
    ET_UNUSED bool dummy_param,
    Tensor& out) {
  ET_CHECK_SAME_SHAPE_AND_DTYPE2(self, other);
  ET_CHECK_MSG(
      out.scalar_type() == ScalarType::Bool,
      "Out tensor must be type Bool; saw type %" PRId8,
      static_cast<int8_t>(out.scalar_type()));
  ET_CHECK_MSG(
      tensors_have_same_dim_order(self, other, out),
      "self, other and out tensors should have the same dim order");
  ET_CHECK_MSG(
      out.numel() == 1,
      "Out tensor must be a single element; saw %zu elements",
      (size_t)out.numel());
  auto out_data = out.mutable_data_ptr<bool>();
  out_data[0] = tensors_are_close(self, other, rtol, atol);
  return out;
}

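// Usage sketch (illustrative): given `self` and `other` with identical shape
// and dtype, and `out` a single-element Bool tensor:
//
//   allclose_out(
//       self, other, /*rtol=*/1e-5, /*atol=*/1e-8,
//       /*equal_nan=*/false, /*dummy_param=*/false, out);
//   bool close = out.const_data_ptr<bool>()[0];
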
/**
 * Note: This custom operator has two variants: allclose.Tensor (a functional
 * variant that does not mutate its arguments) and allclose.out (an out
 * variant that writes to `out`). Both must be registered with the PyTorch
 * runtime so that they are visible to the ExecuTorch compiler; eventually
 * only allclose.out is seen by the ExecuTorch runtime. With this setup, the
 * portable kernel for allclose.Tensor can be implemented as a wrapper around
 * allclose.out: we instantiate an at::Tensor for the out argument and pass it
 * to allclose.out. This wrapper only needs to work in "ATen mode" of the
 * ExecuTorch compiler, since allclose.Tensor is not exposed in the ExecuTorch
 * runtime.
 */
Tensor allclose_tensor(
    ET_UNUSED const Tensor& self,
    ET_UNUSED const Tensor& other,
    ET_UNUSED double rtol,
    ET_UNUSED double atol,
    ET_UNUSED bool equal_nan,
    ET_UNUSED bool dummy_param) {
#ifdef USE_ATEN_LIB
  Tensor out =
      torch::tensor({false}, c10::TensorOptions(c10::ScalarType::Bool));
  allclose_out(self, other, rtol, atol, equal_nan, dummy_param, out);
  return out;
#else
  ET_ASSERT_UNREACHABLE();
#endif
}

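// In ATen mode (USE_ATEN_LIB defined), `Tensor` above is at::Tensor, so the
// functional variant can be exercised directly. Illustrative sketch:
//
//   at::Tensor a = torch::ones({2, 2});
//   at::Tensor b = torch::ones({2, 2});
//   at::Tensor result = allclose_tensor(
//       a, b, /*rtol=*/1e-5, /*atol=*/1e-8,
//       /*equal_nan=*/false, /*dummy_param=*/false);
//   bool all_close = result.item<bool>();
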
Tensor& allclose_out(
    KernelRuntimeContext& ctx,
    const Tensor& self,
    const Tensor& other,
    double rtol,
    double atol,
    ET_UNUSED bool equal_nan,
    ET_UNUSED bool dummy_param,
    Tensor& out) {
  (void)ctx;
  // TODO(larryliu): Add a context arg to the real op function and remove this
  // wrapper.
  return allclose_out(self, other, rtol, atol, equal_nan, dummy_param, out);
}

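// Runtime entry point sketch (illustrative): the ExecuTorch runtime calls the
// context-taking variant, e.g. in a kernel test:
//
//   KernelRuntimeContext context;
//   allclose_out(
//       context, self, other, /*rtol=*/1e-5, /*atol=*/1e-8,
//       /*equal_nan=*/false, /*dummy_param=*/false, out);
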
Tensor allclose_tensor(
    ET_UNUSED KernelRuntimeContext& ctx,
    ET_UNUSED const Tensor& self,
    ET_UNUSED const Tensor& other,
    ET_UNUSED double rtol,
    ET_UNUSED double atol,
    ET_UNUSED bool equal_nan,
    ET_UNUSED bool dummy_param) {
  // TODO(larryliu): Add a context arg to the real op function and remove this
  // wrapper.
  ET_ASSERT_UNREACHABLE();
}
} // namespace native
} // namespace executor
} // namespace torch