xref: /aosp_15_r20/external/tensorflow/tensorflow/c/ops_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/ops.h"
17 
18 #include "absl/strings/str_cat.h"
19 #include "tensorflow/c/c_api.h"
20 #include "tensorflow/core/framework/attr_value.pb.h"
21 #include "tensorflow/core/framework/fake_input.h"
22 #include "tensorflow/core/framework/op_def.pb.h"
23 #include "tensorflow/core/framework/op_def_builder.h"
24 #include "tensorflow/core/framework/shape_inference_testutil.h"
25 #include "tensorflow/core/framework/tensor_testutil.h"
26 #include "tensorflow/core/framework/types.pb.h"
27 #include "tensorflow/core/lib/core/status_test_util.h"
28 #include "tensorflow/core/platform/test.h"
29 
30 namespace tensorflow {
31 namespace {
32 
// Registers an op through the C API and checks that the OpDef reported by
// TF_GetAllOpList carries the declared attr, inputs and output.
TEST(OpsTest, TestBasicOpRegistration) {
  TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeOp");
  TF_OpDefinitionBuilderAddAttr(builder, "attr1: string");
  TF_OpDefinitionBuilderAddInput(builder, "input1: uint8");
  TF_OpDefinitionBuilderAddInput(builder, "input2: uint16");
  TF_OpDefinitionBuilderAddOutput(builder, "output1: uint32");

  // Copy the registration result out of the status and free it before
  // asserting, so a failed ASSERT (which returns from the test) leaks nothing.
  TF_Status* status = TF_NewStatus();
  TF_RegisterOpDefinition(builder, status);
  const TF_Code register_code = TF_GetCode(status);
  const std::string register_message = TF_Message(status);
  TF_DeleteStatus(status);
  ASSERT_EQ(TF_OK, register_code) << register_message;

  // The serialized OpList is only needed for parsing, so release it right
  // away. Also check the parse result instead of silently ignoring it.
  TF_Buffer* op_list_buffer = TF_GetAllOpList();
  ::tensorflow::OpList op_list;
  const bool parsed = op_list.ParseFromArray(
      op_list_buffer->data, static_cast<int>(op_list_buffer->length));
  TF_DeleteBuffer(op_list_buffer);
  ASSERT_TRUE(parsed);

  bool found = false;
  for (const auto& op : op_list.op()) {
    if (op.name() == "SomeOp") {
      ASSERT_EQ(2, op.input_arg_size());
      ASSERT_EQ("input1", op.input_arg(0).name());
      ASSERT_EQ(::tensorflow::DT_UINT8, op.input_arg(0).type());
      ASSERT_EQ("input2", op.input_arg(1).name());
      ASSERT_EQ(::tensorflow::DT_UINT16, op.input_arg(1).type());
      ASSERT_EQ(1, op.output_arg_size());
      ASSERT_EQ("output1", op.output_arg(0).name());
      ASSERT_EQ(::tensorflow::DT_UINT32, op.output_arg(0).type());
      ASSERT_EQ(1, op.attr_size());
      ASSERT_EQ("attr1", op.attr(0).name());
      ASSERT_EQ("string", op.attr(0).type());
      found = true;
    }
  }
  EXPECT_TRUE(found);
}
60 
identity_shape_fn(TF_ShapeInferenceContext * ctx,TF_Status * status)61 void identity_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
62   TF_ShapeHandle* handle = TF_NewShapeHandle();
63   TF_ShapeInferenceContextGetInput(ctx, 0, handle, status);
64   ASSERT_EQ(TF_OK, TF_GetCode(status));
65   TF_ShapeInferenceContextSetOutput(ctx, 0, handle, status);
66   TF_DeleteShapeHandle(handle);
67 }
68 
// Registers "SomeTestOp" with identity_shape_fn as its shape function and
// checks that inference passes the input shape through to the output.
TEST(OpsTest, TestShapeInference_IdentityFunction) {
  ShapeInferenceTestOp op("SomeTestOp");

  TF_OpDefinitionBuilder* op_builder = TF_NewOpDefinitionBuilder("SomeTestOp");
  TF_OpDefinitionBuilderAddInput(op_builder, "input1: uint8");
  TF_OpDefinitionBuilderAddOutput(op_builder, "output1: uint8");
  TF_OpDefinitionBuilderSetShapeInferenceFunction(op_builder,
                                                  &identity_shape_fn);

  TF_Status* register_status = TF_NewStatus();
  TF_RegisterOpDefinition(op_builder, register_status);
  ASSERT_EQ(TF_OK, TF_GetCode(register_status))
      << TF_Message(register_status);

  // "in0" expects output 0 to be input 0's shape.
  TF_ASSERT_OK(
      shape_inference::ShapeInferenceTestutil::InferShapes(op, "[1,2]", "in0"));
  TF_DeleteStatus(register_status);
}
84 
// Registers "UnknownShapeOp" with TF_ShapeInferenceContextSetUnknownShape as
// its shape function and checks that both outputs are inferred as unknown.
TEST(OpsTest, TestShapeInference_UnknownShape) {
  ShapeInferenceTestOp op("UnknownShapeOp");

  TF_OpDefinitionBuilder* op_builder =
      TF_NewOpDefinitionBuilder("UnknownShapeOp");
  TF_OpDefinitionBuilderAddInput(op_builder, "input1: uint8");
  TF_OpDefinitionBuilderAddInput(op_builder, "input2: uint32");
  TF_OpDefinitionBuilderAddOutput(op_builder, "output1: uint8");
  TF_OpDefinitionBuilderAddOutput(op_builder, "output2: uint8");
  TF_OpDefinitionBuilderSetShapeInferenceFunction(
      op_builder, &TF_ShapeInferenceContextSetUnknownShape);

  TF_Status* register_status = TF_NewStatus();
  TF_RegisterOpDefinition(op_builder, register_status);
  ASSERT_EQ(TF_OK, TF_GetCode(register_status))
      << TF_Message(register_status);

  // Two fully known input shapes in, two unknown ("?") output shapes out.
  TF_ASSERT_OK(shape_inference::ShapeInferenceTestutil::InferShapes(
      op, "[1,2];[3,4]", "?;?"));
  TF_DeleteStatus(register_status);
}
103 
104 // Creates an output whose shape is a vector of length
105 // TF_ShapeInferenceContextRank.
vectorize_shape_fn(TF_ShapeInferenceContext * ctx,TF_Status * status)106 void vectorize_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
107   TF_ShapeHandle* handle = TF_NewShapeHandle();
108   TF_ShapeInferenceContextGetInput(ctx, 0, handle, status);
109   ASSERT_EQ(TF_OK, TF_GetCode(status));
110   TF_ShapeHandle* new_shape = TF_ShapeInferenceContextVectorFromSize(
111       ctx, TF_ShapeInferenceContextRank(ctx, handle));
112   TF_ShapeInferenceContextSetOutput(ctx, 0, new_shape, status);
113   TF_DeleteShapeHandle(handle);
114   TF_DeleteShapeHandle(new_shape);
115 }
116 
// Registers "VectorizeTestOp" with vectorize_shape_fn and checks that a rank-3
// input produces the output shape [3].
TEST(OpsTest, TestShapeInference_VectorizeFunction) {
  ShapeInferenceTestOp op("VectorizeTestOp");

  TF_OpDefinitionBuilder* op_builder =
      TF_NewOpDefinitionBuilder("VectorizeTestOp");
  TF_OpDefinitionBuilderAddInput(op_builder, "input1: uint8");
  TF_OpDefinitionBuilderAddOutput(op_builder, "output1: uint8");
  TF_OpDefinitionBuilderSetShapeInferenceFunction(op_builder,
                                                  &vectorize_shape_fn);

  TF_Status* register_status = TF_NewStatus();
  TF_RegisterOpDefinition(op_builder, register_status);
  ASSERT_EQ(TF_OK, TF_GetCode(register_status))
      << TF_Message(register_status);

  TF_ASSERT_OK(shape_inference::ShapeInferenceTestutil::InferShapes(
      op, "[4,5,9]", "[3]"));
  TF_DeleteStatus(register_status);
}
133 
// Registers an op using the attr, flag and deprecation setters of
// TF_OpDefinitionBuilder, then checks that the OpDef reported by
// TF_GetAllOpList reflects each of them.
TEST(OpsTest, AttributeAccessors) {
  TF_OpDefinitionBuilder* builder =
      TF_NewOpDefinitionBuilder("AttributeAccessorsOp");
  TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2");
  TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\"");
  TF_OpDefinitionBuilderSetIsCommutative(builder, true);
  TF_OpDefinitionBuilderSetIsAggregate(builder, true);
  TF_OpDefinitionBuilderSetAllowsUninitializedInput(builder, true);
  const std::string deprecation_msg = "use something else instead";
  TF_OpDefinitionBuilderDeprecated(builder, 4, deprecation_msg.c_str());

  // Copy the result out and free the status before asserting, so a failure
  // neither leaks it nor loses the error message.
  TF_Status* status = TF_NewStatus();
  TF_RegisterOpDefinition(builder, status);
  const TF_Code register_code = TF_GetCode(status);
  const std::string register_message = TF_Message(status);
  TF_DeleteStatus(status);
  ASSERT_EQ(TF_OK, register_code) << register_message;

  // The serialized OpList is only needed for parsing; release it immediately
  // and check the parse result.
  TF_Buffer* op_list_buffer = TF_GetAllOpList();
  ::tensorflow::OpList op_list;
  const bool parsed = op_list.ParseFromArray(
      op_list_buffer->data, static_cast<int>(op_list_buffer->length));
  TF_DeleteBuffer(op_list_buffer);
  ASSERT_TRUE(parsed);

  bool found = false;
  for (const auto& op : op_list.op()) {
    if (op.name() == "AttributeAccessorsOp") {
      ASSERT_TRUE(op.is_commutative());
      ASSERT_TRUE(op.is_aggregate());
      ASSERT_TRUE(op.allows_uninitialized_input());
      ASSERT_EQ(4, op.deprecation().version());
      ASSERT_EQ(deprecation_msg, op.deprecation().explanation());
      ASSERT_EQ(2, op.attr_size());
      ASSERT_EQ("foo1", op.attr(0).name());
      ASSERT_EQ("int", op.attr(0).type());
      ASSERT_TRUE(op.attr(0).has_minimum());
      ASSERT_EQ(2, op.attr(0).minimum());
      ASSERT_EQ("foo2", op.attr(1).name());
      ASSERT_EQ("string", op.attr(1).type());
      ASSERT_EQ("my string", op.attr(1).default_value().s());
      found = true;
    }
  }
  ASSERT_TRUE(found);
}
172 
// Reinterpret the C++ shape-inference objects used by the tests below as the
// corresponding C API handle types, so the TF_ShapeInferenceContext* functions
// can be called on them directly.
#define C_CTX(ctx_ptr) (reinterpret_cast<TF_ShapeInferenceContext*>(ctx_ptr))
#define C_SHP(shape_ptr) (reinterpret_cast<TF_ShapeHandle*>(shape_ptr))
175 
MakeOpDef(int num_inputs,int num_outputs)176 static OpDef MakeOpDef(int num_inputs, int num_outputs) {
177   OpRegistrationData op_reg_data;
178   OpDefBuilder b("dummy");
179   for (int i = 0; i < num_inputs; ++i) {
180     b.Input(strings::StrCat("i", i, ": float"));
181   }
182   for (int i = 0; i < num_outputs; ++i) {
183     b.Output(strings::StrCat("o", i, ": float"));
184   }
185   CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
186   return op_reg_data.op_def;
187 }
188 
189 // Tests for shape inference
190 
S(std::initializer_list<int64_t> dims)191 PartialTensorShape S(std::initializer_list<int64_t> dims) {
192   return PartialTensorShape(dims);
193 }
194 
Unknown()195 PartialTensorShape Unknown() { return PartialTensorShape(); }
196 
// Exercises the WithRank / WithRankAtLeast / WithRankAtMost C wrappers on a
// fully known rank-3 input, for both satisfied and violated constraints.
TEST(OpsTest, ShapeInferenceWithRank) {
  NodeDef def;
  shape_inference::InferenceContext c(0, def, MakeOpDef(1, 0),
                                      {S({10, 20, 30})}, {}, {}, {});

  shape_inference::ShapeHandle in0 = c.input(0);
  shape_inference::ShapeHandle s1;

  TF_Status* status = TF_NewStatus();
  TF_ShapeInferenceContextWithRankAtMost(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1),
                                         status);
  EXPECT_EQ("[10,20,30]", c.DebugString(s1));
  EXPECT_EQ(TF_OK, TF_GetCode(status));

  TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1),
                                          status);
  EXPECT_EQ("[10,20,30]", c.DebugString(s1));
  EXPECT_EQ(TF_OK, TF_GetCode(status));

  // Violated constraints must report an error. EXPECT (not ASSERT) keeps the
  // test running, so `status` is always freed at the end.
  TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&c), C_SHP(&in0), 6, C_SHP(&s1),
                                          status);
  EXPECT_NE(TF_OK, TF_GetCode(status));

  TF_SetStatus(status, TF_OK, "");
  TF_ShapeInferenceContextWithRankAtMost(C_CTX(&c), C_SHP(&in0), 1, C_SHP(&s1),
                                         status);
  EXPECT_NE(TF_OK, TF_GetCode(status));

  TF_SetStatus(status, TF_OK, "");
  TF_ShapeInferenceContextWithRank(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1),
                                   status);
  EXPECT_EQ(TF_OK, TF_GetCode(status));
  // A matching rank must also pass the input shape through to the output.
  EXPECT_EQ("[10,20,30]", c.DebugString(s1));

  // Reset so this check does not depend on the previous call's outcome.
  TF_SetStatus(status, TF_OK, "");
  TF_ShapeInferenceContextWithRank(C_CTX(&c), C_SHP(&in0), 4, C_SHP(&s1),
                                   status);
  EXPECT_NE(TF_OK, TF_GetCode(status));

  TF_DeleteStatus(status);
}
236 
// On an input whose rank is unknown, WithRankAtMost and WithRankAtLeast
// cannot be violated: both succeed and leave the output shape unknown ("?").
TEST(OpsTest, ShapeInferenceWithRank_UnknownRank) {
  NodeDef node_def;
  shape_inference::InferenceContext ctx(0, node_def, MakeOpDef(2, 2),
                                        {Unknown(), S({1, -1, 3})}, {}, {}, {});

  shape_inference::ShapeHandle unknown_input = ctx.input(0);
  shape_inference::ShapeHandle result;
  TF_Status* status = TF_NewStatus();

  TF_ShapeInferenceContextWithRankAtMost(C_CTX(&ctx), C_SHP(&unknown_input), 1,
                                         C_SHP(&result), status);
  EXPECT_EQ("?", ctx.DebugString(result));
  EXPECT_EQ(TF_OK, TF_GetCode(status));

  TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&ctx), C_SHP(&unknown_input), 1,
                                          C_SHP(&result), status);
  EXPECT_EQ("?", ctx.DebugString(result));
  EXPECT_EQ(TF_OK, TF_GetCode(status));

  TF_DeleteStatus(status);
}
260 
// Concatenating the input shapes [1,2] and [3,4] through the C API yields
// [1,2,3,4].
TEST(OpsTest, ShapeInferenceConcatenateShapes) {
  NodeDef node_def;
  shape_inference::InferenceContext ctx(0, node_def, MakeOpDef(2, 0),
                                        {S({1, 2}), S({3, 4})}, {}, {}, {});
  ASSERT_EQ(2, TF_ShapeInferenceContextNumInputs(C_CTX(&ctx)));

  shape_inference::ShapeHandle first = ctx.input(0);
  shape_inference::ShapeHandle second = ctx.input(1);
  TF_ShapeHandle* concatenated = TF_NewShapeHandle();
  TF_Status* status = TF_NewStatus();

  TF_ShapeInferenceContextConcatenateShapes(C_CTX(&ctx), C_SHP(&first),
                                            C_SHP(&second), concatenated,
                                            status);
  const shape_inference::ShapeHandle& concatenated_cc =
      *reinterpret_cast<shape_inference::ShapeHandle*>(concatenated);
  EXPECT_EQ("[1,2,3,4]", ctx.DebugString(concatenated_cc));
  EXPECT_EQ(TF_OK, TF_GetCode(status));

  TF_DeleteShapeHandle(concatenated);
  TF_DeleteStatus(status);
}
279 
// A vector shape built from size 43 has known rank 1, and its only dimension
// is known and equal to 43.
TEST(OpsTest, DimensionHandleValueKnown) {
  NodeDef node_def;
  shape_inference::InferenceContext ctx(0, node_def, MakeOpDef(2, 0),
                                        {S({1, 2}), S({3, 4})}, {}, {}, {});
  TF_ShapeHandle* vector_shape =
      TF_ShapeInferenceContextVectorFromSize(C_CTX(&ctx), 43);
  const shape_inference::ShapeHandle& vector_shape_cc =
      *reinterpret_cast<shape_inference::ShapeHandle*>(vector_shape);
  ASSERT_EQ("[43]", ctx.DebugString(vector_shape_cc));
  ASSERT_EQ(1, TF_ShapeInferenceContextRankKnown(C_CTX(&ctx), vector_shape));
  ASSERT_EQ(1, TF_ShapeInferenceContextRank(C_CTX(&ctx), vector_shape));

  TF_DimensionHandle* first_dim = TF_NewDimensionHandle();
  TF_ShapeInferenceContextDim(C_CTX(&ctx), vector_shape, 0, first_dim);
  ASSERT_EQ(1, TF_DimensionHandleValueKnown(first_dim));
  ASSERT_EQ(43, TF_DimensionHandleValue(first_dim));

  TF_DeleteShapeHandle(vector_shape);
  TF_DeleteDimensionHandle(first_dim);
}
299 
// Subshape(1, -1) of [10,20,30,40,50] yields [20,30,40]. The same handle is
// passed as both the source and the destination of the subshape call.
TEST(OpsTest, ShapeInferenceSubshape) {
  NodeDef node_def;
  shape_inference::InferenceContext ctx(0, node_def, MakeOpDef(1, 0),
                                        {S({10, 20, 30, 40, 50})}, {}, {}, {});
  ASSERT_EQ("[10,20,30,40,50]", ctx.DebugString(ctx.input(0)));

  TF_ShapeHandle* shape = TF_NewShapeHandle();
  TF_Status* status = TF_NewStatus();
  TF_ShapeInferenceContextGetInput(C_CTX(&ctx), 0, shape, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status));
  TF_ShapeInferenceContextSubshape(C_CTX(&ctx), shape, 1, -1, shape, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status));
  const shape_inference::ShapeHandle& shape_cc =
      *reinterpret_cast<shape_inference::ShapeHandle*>(shape);
  ASSERT_EQ("[20,30,40]", ctx.DebugString(shape_cc));

  TF_DeleteStatus(status);
  TF_DeleteShapeHandle(shape);
}
318 
// TF_ShapeInferenceContextScalar returns a rank-0 shape, printed as "[]".
TEST(OpsTest, ShapeInferenceScalarShape) {
  NodeDef def;
  // One (scalar) input shape is supplied, so the op def must declare exactly
  // one input; MakeOpDef(0, 0) would not match that input count.
  shape_inference::InferenceContext c(0, def, MakeOpDef(1, 0), {S({})}, {}, {},
                                      {});
  TF_EXPECT_OK(c.construct_status());

  TF_ShapeHandle* scalar_handle = TF_ShapeInferenceContextScalar(C_CTX(&c));
  shape_inference::ShapeHandle* scalar_shape =
      reinterpret_cast<shape_inference::ShapeHandle*>(scalar_handle);
  // EXPECT rather than ASSERT so the handle is always deleted.
  EXPECT_EQ("[]", c.DebugString(*scalar_shape));
  TF_DeleteShapeHandle(scalar_handle);
}
329 
330 }  // namespace
331 }  // namespace tensorflow
332