/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/c/ops.h"

#include "absl/strings/str_cat.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/op_def_builder.h"
#include "tensorflow/core/framework/shape_inference_testutil.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace {

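// Registers "SomeOp" through the C op definition builder and verifies that it
// appears in TF_GetAllOpList() with the expected inputs and attr.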
TEST(OpsTest, TestBasicOpRegistration) {
  TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeOp");
  TF_OpDefinitionBuilderAddAttr(builder, "attr1: string");
  TF_OpDefinitionBuilderAddInput(builder, "input1: uint8");
  TF_OpDefinitionBuilderAddInput(builder, "input2: uint16");
  TF_OpDefinitionBuilderAddOutput(builder, "output1: uint32");
  TF_Status* status = TF_NewStatus();
  TF_RegisterOpDefinition(builder, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
  TF_Buffer* op_list_buffer = TF_GetAllOpList();
  ::tensorflow::OpList op_list;
  op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
  bool found = false;
  for (const auto& op : op_list.op()) {
    if (op.name() == "SomeOp") {
      ASSERT_EQ(2, op.input_arg_size());
      ASSERT_EQ("input1", op.input_arg(0).name());
      ASSERT_EQ(::tensorflow::DT_UINT8, op.input_arg(0).type());
      ASSERT_EQ(1, op.attr_size());
      ASSERT_EQ("string", op.attr(0).type());
      found = true;
    }
  }
  EXPECT_TRUE(found);
  TF_DeleteStatus(status);
  TF_DeleteBuffer(op_list_buffer);
}

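// Shape inference function that forwards the shape of the first input to the
// first output.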
void identity_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
  TF_ShapeHandle* handle = TF_NewShapeHandle();
  TF_ShapeInferenceContextGetInput(ctx, 0, handle, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status));
  TF_ShapeInferenceContextSetOutput(ctx, 0, handle, status);
  TF_DeleteShapeHandle(handle);
}

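// Registers an op that uses identity_shape_fn and checks that the inferred
// output shape matches the input shape ("in0" refers to input 0).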
TEST(OpsTest, TestShapeInference_IdentityFunction) {
  ShapeInferenceTestOp op("SomeTestOp");

  TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("SomeTestOp");
  TF_OpDefinitionBuilderAddInput(builder, "input1: uint8");
  TF_OpDefinitionBuilderAddOutput(builder, "output1: uint8");
  TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &identity_shape_fn);
  TF_Status* status = TF_NewStatus();
  TF_RegisterOpDefinition(builder, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_ASSERT_OK(
      shape_inference::ShapeInferenceTestutil::InferShapes(op, "[1,2]", "in0"));
  TF_DeleteStatus(status);
}

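// Verifies that TF_ShapeInferenceContextSetUnknownShape can be used directly
// as a shape inference function: both output shapes are reported as unknown.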
TEST(OpsTest, TestShapeInference_UnknownShape) {
  ShapeInferenceTestOp op("UnknownShapeOp");

  TF_OpDefinitionBuilder* builder = TF_NewOpDefinitionBuilder("UnknownShapeOp");
  TF_OpDefinitionBuilderAddInput(builder, "input1: uint8");
  TF_OpDefinitionBuilderAddInput(builder, "input2: uint32");
  TF_OpDefinitionBuilderAddOutput(builder, "output1: uint8");
  TF_OpDefinitionBuilderAddOutput(builder, "output2: uint8");
  TF_OpDefinitionBuilderSetShapeInferenceFunction(
      builder, &TF_ShapeInferenceContextSetUnknownShape);
  TF_Status* status = TF_NewStatus();
  TF_RegisterOpDefinition(builder, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_ASSERT_OK(shape_inference::ShapeInferenceTestutil::InferShapes(
      op, "[1,2];[3,4]", "?;?"));
  TF_DeleteStatus(status);
}

// Shape inference function that sets the output to a vector whose length is
// the rank of the first input, as reported by TF_ShapeInferenceContextRank.
void vectorize_shape_fn(TF_ShapeInferenceContext* ctx, TF_Status* status) {
  TF_ShapeHandle* handle = TF_NewShapeHandle();
  TF_ShapeInferenceContextGetInput(ctx, 0, handle, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status));
  TF_ShapeHandle* new_shape = TF_ShapeInferenceContextVectorFromSize(
      ctx, TF_ShapeInferenceContextRank(ctx, handle));
  TF_ShapeInferenceContextSetOutput(ctx, 0, new_shape, status);
  TF_DeleteShapeHandle(handle);
  TF_DeleteShapeHandle(new_shape);
}

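// A rank-3 input such as [4,5,9] should therefore produce a [3] output.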
TEST(OpsTest, TestShapeInference_VectorizeFunction) {
  ShapeInferenceTestOp op("VectorizeTestOp");

  TF_OpDefinitionBuilder* builder =
      TF_NewOpDefinitionBuilder("VectorizeTestOp");
  TF_OpDefinitionBuilderAddInput(builder, "input1: uint8");
  TF_OpDefinitionBuilderAddOutput(builder, "output1: uint8");
  TF_OpDefinitionBuilderSetShapeInferenceFunction(builder, &vectorize_shape_fn);
  TF_Status* status = TF_NewStatus();
  TF_RegisterOpDefinition(builder, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);

  TF_ASSERT_OK(shape_inference::ShapeInferenceTestutil::InferShapes(
      op, "[4,5,9]", "[3]"));
  TF_DeleteStatus(status);
}

TEST(OpsTest, AttributeAccessors) {
  TF_OpDefinitionBuilder* builder =
      TF_NewOpDefinitionBuilder("AttributeAccessorsOp");
  TF_OpDefinitionBuilderAddAttr(builder, "foo1: int >= 2");
  TF_OpDefinitionBuilderAddAttr(builder, "foo2: string=\"my string\"");
  TF_OpDefinitionBuilderSetIsCommutative(builder, true);
  TF_OpDefinitionBuilderSetIsAggregate(builder, true);
  TF_OpDefinitionBuilderSetAllowsUninitializedInput(builder, true);
  std::string deprecation_msg = "use something else instead";
  TF_OpDefinitionBuilderDeprecated(builder, 4, deprecation_msg.c_str());

  TF_Status* status = TF_NewStatus();
  TF_RegisterOpDefinition(builder, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status));

  TF_Buffer* op_list_buffer = TF_GetAllOpList();
  ::tensorflow::OpList op_list;
  op_list.ParseFromArray(op_list_buffer->data, op_list_buffer->length);
  bool found = false;
  for (const auto& op : op_list.op()) {
    if (op.name() == "AttributeAccessorsOp") {
      ASSERT_TRUE(op.is_commutative());
      ASSERT_TRUE(op.is_aggregate());
      ASSERT_TRUE(op.allows_uninitialized_input());
      ASSERT_EQ(4, op.deprecation().version());
      ASSERT_EQ(deprecation_msg, op.deprecation().explanation());
      ASSERT_EQ(2, op.attr_size());
      ASSERT_EQ("int", op.attr(0).type());
      ASSERT_EQ(2, op.attr(0).minimum());
      ASSERT_EQ("string", op.attr(1).type());
      ASSERT_EQ("my string", op.attr(1).default_value().s());
      found = true;
    }
  }
  ASSERT_TRUE(found);
  TF_DeleteStatus(status);
  TF_DeleteBuffer(op_list_buffer);
}

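// Convenience casts between the C++ shape inference types and their C API
// handle equivalents.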
#define C_CTX(x) reinterpret_cast<TF_ShapeInferenceContext*>(x)
#define C_SHP(x) reinterpret_cast<TF_ShapeHandle*>(x)

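// Builds an OpDef named "dummy" with `num_inputs` float inputs, `num_outputs`
// float outputs, and a string attr, so that InferenceContexts can be
// constructed directly in the tests below.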
static OpDef MakeOpDef(int num_inputs, int num_outputs) {
  OpRegistrationData op_reg_data;
  OpDefBuilder b("dummy");
  for (int i = 0; i < num_inputs; ++i) {
    b.Input(strings::StrCat("i", i, ": float"));
  }
  for (int i = 0; i < num_outputs; ++i) {
    b.Output(strings::StrCat("o", i, ": float"));
  }
  CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
  return op_reg_data.op_def;
}

// Tests for shape inference

PartialTensorShape S(std::initializer_list<int64_t> dims) {
  return PartialTensorShape(dims);
}

PartialTensorShape Unknown() { return PartialTensorShape(); }

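// Exercises TF_ShapeInferenceContextWithRank{,AtLeast,AtMost} against a fully
// known [10,20,30] input shape.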
TEST(OpsTest, ShapeInferenceWithRank) {
  NodeDef def;
  shape_inference::InferenceContext c(0, def, MakeOpDef(1, 0),
                                      {S({10, 20, 30})}, {}, {}, {});

  shape_inference::ShapeHandle in0 = c.input(0);
  shape_inference::ShapeHandle s1;

  TF_Status* status = TF_NewStatus();
  TF_ShapeInferenceContextWithRankAtMost(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1),
                                         status);
  EXPECT_EQ("[10,20,30]", c.DebugString(s1));
  EXPECT_EQ(TF_OK, TF_GetCode(status));

  TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1),
                                          status);
  EXPECT_EQ("[10,20,30]", c.DebugString(s1));
  EXPECT_EQ(TF_OK, TF_GetCode(status));

  TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&c), C_SHP(&in0), 6, C_SHP(&s1),
                                          status);
  ASSERT_NE(TF_OK, TF_GetCode(status));

  TF_SetStatus(status, TF_OK, "");
  TF_ShapeInferenceContextWithRankAtMost(C_CTX(&c), C_SHP(&in0), 1, C_SHP(&s1),
                                         status);
  ASSERT_NE(TF_OK, TF_GetCode(status));

  TF_SetStatus(status, TF_OK, "");
  TF_ShapeInferenceContextWithRank(C_CTX(&c), C_SHP(&in0), 3, C_SHP(&s1),
                                   status);
  ASSERT_EQ(TF_OK, TF_GetCode(status));

  TF_ShapeInferenceContextWithRank(C_CTX(&c), C_SHP(&in0), 4, C_SHP(&s1),
                                   status);
  ASSERT_NE(TF_OK, TF_GetCode(status));

  TF_DeleteStatus(status);
}

TEST(OpsTest, ShapeInferenceWithRank_UnknownRank) {
  NodeDef def;
  shape_inference::InferenceContext c(0, def, MakeOpDef(2, 2),
                                      {Unknown(), S({1, -1, 3})}, {}, {}, {});

  shape_inference::ShapeHandle in0 = c.input(0);
  shape_inference::ShapeHandle s1;

  // WithRankAtMost and WithRankAtLeast on a shape with unknown dimensionality
  // always succeed.
  TF_Status* status = TF_NewStatus();
  TF_ShapeInferenceContextWithRankAtMost(C_CTX(&c), C_SHP(&in0), 1, C_SHP(&s1),
                                         status);
  EXPECT_EQ("?", c.DebugString(s1));
  EXPECT_EQ(TF_OK, TF_GetCode(status));

  TF_ShapeInferenceContextWithRankAtLeast(C_CTX(&c), C_SHP(&in0), 1, C_SHP(&s1),
                                          status);
  EXPECT_EQ("?", c.DebugString(s1));
  EXPECT_EQ(TF_OK, TF_GetCode(status));

  TF_DeleteStatus(status);
}

TEST(OpsTest, ShapeInferenceConcatenateShapes) {
  NodeDef def;
  shape_inference::InferenceContext c(0, def, MakeOpDef(2, 0),
                                      {S({1, 2}), S({3, 4})}, {}, {}, {});
  ASSERT_EQ(2, TF_ShapeInferenceContextNumInputs(C_CTX(&c)));
  shape_inference::ShapeHandle a = c.input(0);
  shape_inference::ShapeHandle b = c.input(1);
  TF_ShapeHandle* result = TF_NewShapeHandle();
  TF_Status* status = TF_NewStatus();
  TF_ShapeInferenceContextConcatenateShapes(C_CTX(&c), C_SHP(&a), C_SHP(&b),
                                            result, status);
  EXPECT_EQ(
      "[1,2,3,4]",
      c.DebugString(*reinterpret_cast<shape_inference::ShapeHandle*>(result)));
  EXPECT_EQ(TF_OK, TF_GetCode(status));
  TF_DeleteShapeHandle(result);
  TF_DeleteStatus(status);
}

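// Builds a [43] vector shape and checks that its single dimension is reported
// as known, with value 43.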
TEST(OpsTest, DimensionHandleValueKnown) {
  NodeDef def;
  shape_inference::InferenceContext c(0, def, MakeOpDef(2, 0),
                                      {S({1, 2}), S({3, 4})}, {}, {}, {});
  TF_ShapeHandle* handle =
      TF_ShapeInferenceContextVectorFromSize(C_CTX(&c), 43);
  ASSERT_EQ(
      "[43]",
      c.DebugString(*reinterpret_cast<shape_inference::ShapeHandle*>(handle)));
  ASSERT_EQ(1, TF_ShapeInferenceContextRankKnown(C_CTX(&c), handle));
  ASSERT_EQ(1, TF_ShapeInferenceContextRank(C_CTX(&c), handle));

  TF_DimensionHandle* dim_handle = TF_NewDimensionHandle();
  TF_ShapeInferenceContextDim(C_CTX(&c), handle, 0, dim_handle);
  ASSERT_EQ(1, TF_DimensionHandleValueKnown(dim_handle));
  ASSERT_EQ(43, TF_DimensionHandleValue(dim_handle));
  TF_DeleteShapeHandle(handle);
  TF_DeleteDimensionHandle(dim_handle);
}

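// Taking the subshape [1, -1) of a rank-5 shape keeps everything but the
// first and last dimension.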
TEST(OpsTest, ShapeInferenceSubshape) {
  NodeDef def;
  shape_inference::InferenceContext c(0, def, MakeOpDef(1, 0),
                                      {S({10, 20, 30, 40, 50})}, {}, {}, {});
  ASSERT_EQ("[10,20,30,40,50]", c.DebugString(c.input(0)));

  TF_ShapeHandle* handle = TF_NewShapeHandle();
  TF_Status* status = TF_NewStatus();
  TF_ShapeInferenceContextGetInput(C_CTX(&c), 0, handle, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status));
  TF_ShapeInferenceContextSubshape(C_CTX(&c), handle, 1, -1, handle, status);
  ASSERT_EQ(TF_OK, TF_GetCode(status));
  ASSERT_EQ(
      "[20,30,40]",
      c.DebugString(*reinterpret_cast<shape_inference::ShapeHandle*>(handle)));
  TF_DeleteStatus(status);
  TF_DeleteShapeHandle(handle);
}

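// TF_ShapeInferenceContextScalar should produce a rank-0 ("[]") shape.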
TEST(OpsTest, ShapeInferenceScalarShape) {
  NodeDef def;
  shape_inference::InferenceContext c(0, def, MakeOpDef(0, 0), {S({})}, {}, {},
                                      {});
  TF_ShapeHandle* TF_scalar_shape = TF_ShapeInferenceContextScalar(C_CTX(&c));
  shape_inference::ShapeHandle* scalar_shape =
      reinterpret_cast<shape_inference::ShapeHandle*>(TF_scalar_shape);
  ASSERT_EQ("[]", c.DebugString(*scalar_shape));
  TF_DeleteShapeHandle(TF_scalar_shape);
}

}  // namespace
}  // namespace tensorflow