xref: /aosp_15_r20/external/tensorflow/tensorflow/lite/delegates/gpu/metal/kernels/elementwise_test.mm (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15
16#import <XCTest/XCTest.h>
17
18#include "tensorflow/lite/delegates/gpu/common/status.h"
19#include "tensorflow/lite/delegates/gpu/common/tasks/elementwise_test_util.h"
20#include "tensorflow/lite/delegates/gpu/metal/kernels/test_util.h"
21
/// XCTest wrapper that runs the shared elementwise kernel tests (declared in
/// elementwise_test_util.h) against the Metal execution environment.
@interface ElementwiseTest : XCTestCase
@end

// Calls the shared test function `test_fn` with this test case's Metal
// execution environment. If the returned status is not OK, it records an
// XCTest failure whose message is the status message.
// This is a macro rather than a helper method so that XCTAssertTrue's
// __FILE__/__LINE__ still refer to the invoking test method.
// It is #undef'd right after @end to keep it local to this class.
#define ASSERT_ELEMENTWISE_TEST_OK(test_fn)                             \
  do {                                                                  \
    const auto elementwise_status = test_fn(&exec_env_);                \
    XCTAssertTrue(elementwise_status.ok(), @"%s",                       \
                  std::string(elementwise_status.message()).c_str());   \
  } while (0)

@implementation ElementwiseTest {
  // Execution environment passed by pointer to every shared test function.
  // XCTest instantiates a fresh test case object per test method, so each
  // test gets its own environment.
  tflite::gpu::metal::MetalExecutionEnvironment exec_env_;
}

#pragma mark - Per-operation tests

- (void)testAbsUnit {
  ASSERT_ELEMENTWISE_TEST_OK(AbsTest);
}

- (void)testCosUnit {
  ASSERT_ELEMENTWISE_TEST_OK(CosTest);
}

- (void)testCopyUnit {
  ASSERT_ELEMENTWISE_TEST_OK(CopyTest);
}

- (void)testEluUnit {
  ASSERT_ELEMENTWISE_TEST_OK(EluTest);
}

- (void)testExpUnit {
  ASSERT_ELEMENTWISE_TEST_OK(ExpTest);
}

- (void)testFloorUnit {
  ASSERT_ELEMENTWISE_TEST_OK(FloorTest);
}

- (void)testFloorDivUnit {
  ASSERT_ELEMENTWISE_TEST_OK(FloorDivTest);
}

- (void)testFloorModUnit {
  ASSERT_ELEMENTWISE_TEST_OK(FloorModTest);
}

- (void)testHardSwishUnit {
  ASSERT_ELEMENTWISE_TEST_OK(HardSwishTest);
}

- (void)testLogUnit {
  ASSERT_ELEMENTWISE_TEST_OK(LogTest);
}

- (void)testNegUnit {
  ASSERT_ELEMENTWISE_TEST_OK(NegTest);
}

- (void)testRsqrtUnit {
  ASSERT_ELEMENTWISE_TEST_OK(RsqrtTest);
}

- (void)testSigmoidUnit {
  ASSERT_ELEMENTWISE_TEST_OK(SigmoidTest);
}

- (void)testSinUnit {
  ASSERT_ELEMENTWISE_TEST_OK(SinTest);
}

- (void)testSqrtUnit {
  ASSERT_ELEMENTWISE_TEST_OK(SqrtTest);
}

- (void)testSquareUnit {
  ASSERT_ELEMENTWISE_TEST_OK(SquareTest);
}

- (void)testTanhUnit {
  ASSERT_ELEMENTWISE_TEST_OK(TanhTest);
}

- (void)testSubUnit {
  ASSERT_ELEMENTWISE_TEST_OK(SubTest);
}

- (void)testSquaredDiffUnit {
  ASSERT_ELEMENTWISE_TEST_OK(SquaredDiffTest);
}

- (void)testDivUnit {
  ASSERT_ELEMENTWISE_TEST_OK(DivTest);
}

- (void)testPowUnit {
  ASSERT_ELEMENTWISE_TEST_OK(PowTest);
}

- (void)testAddUnit {
  ASSERT_ELEMENTWISE_TEST_OK(AddTest);
}

- (void)testMaximumUnit {
  ASSERT_ELEMENTWISE_TEST_OK(MaximumTest);
}

- (void)testMaximumWithScalarUnit {
  ASSERT_ELEMENTWISE_TEST_OK(MaximumWithScalarTest);
}

- (void)testMaximumWithConstantLinearTensorUnit {
  ASSERT_ELEMENTWISE_TEST_OK(MaximumWithConstantLinearTensorTest);
}

- (void)testMaximumWithConstantHWCTensorUnit {
  ASSERT_ELEMENTWISE_TEST_OK(MaximumWithConstantHWCTensorTest);
}

- (void)testMaximumWithConstantHWCTensorBroadcastChannelsUnit {
  ASSERT_ELEMENTWISE_TEST_OK(MaximumWithConstantHWCTensorBroadcastChannelsTest);
}

- (void)testMinimumUnit {
  ASSERT_ELEMENTWISE_TEST_OK(MinimumTest);
}

- (void)testMinimumWithScalarUnit {
  ASSERT_ELEMENTWISE_TEST_OK(MinimumWithScalarTest);
}

- (void)testMulUnit {
  ASSERT_ELEMENTWISE_TEST_OK(MulTest);
}

- (void)testMulBroadcastHWUnit {
  ASSERT_ELEMENTWISE_TEST_OK(MulBroadcastHWTest);
}

- (void)testMulBroadcastChannelsUnit {
  ASSERT_ELEMENTWISE_TEST_OK(MulBroadcastChannelsTest);
}

- (void)testSubWithScalarAtFirstPositionUnit {
  ASSERT_ELEMENTWISE_TEST_OK(SubWithScalarAtFirstPositionTest);
}

#pragma mark - Comparison tests

- (void)testLessUnit {
  ASSERT_ELEMENTWISE_TEST_OK(LessTest);
}

- (void)testLessEqualUnit {
  ASSERT_ELEMENTWISE_TEST_OK(LessEqualTest);
}

- (void)testGreaterUnit {
  ASSERT_ELEMENTWISE_TEST_OK(GreaterTest);
}

- (void)testGreaterEqualUnit {
  ASSERT_ELEMENTWISE_TEST_OK(GreaterEqualTest);
}

- (void)testEqualUnit {
  ASSERT_ELEMENTWISE_TEST_OK(EqualTest);
}

- (void)testNotEqualUnit {
  ASSERT_ELEMENTWISE_TEST_OK(NotEqualTest);
}

#pragma mark - Input broadcasting tests

- (void)testCosBroadcast {
  ASSERT_ELEMENTWISE_TEST_OK(CosBroadcastTest);
}

- (void)testMaximumScalarBroadcastInput {
  ASSERT_ELEMENTWISE_TEST_OK(MaximumScalarBroadcastInputTest);
}

- (void)testMulLinearBroadcastInput {
  ASSERT_ELEMENTWISE_TEST_OK(MulLinearBroadcastInputTest);
}

- (void)testMulBroadcastBothInputs {
  ASSERT_ELEMENTWISE_TEST_OK(MulBroadcastBothInputsTest);
}

@end

#undef ASSERT_ELEMENTWISE_TEST_OK
245