1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
#include <memory>

#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_internal.h"
#include "tensorflow/c/c_test_util.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/core/platform/test.h"
29
30 namespace tensorflow {
31 namespace {
32
33 // Specification for expected input/output and its type.
34 // DataType value of DT_INVALID signifies that we don't want to
35 // check the data type.
36 typedef std::pair<string, DataType> IOSpec;
37
M(const std::initializer_list<string> & names)38 std::vector<IOSpec> M(const std::initializer_list<string>& names) {
39 std::vector<IOSpec> v;
40 for (const string& name : names) {
41 v.push_back(IOSpec(name, DT_INVALID));
42 }
43 return v;
44 }
45
46 // Specification for an expected edge.
47 // src is either:
48 // - input name (as it appears in FunctionDef)
49 // - name of output tensor (in nested "add:z:0" format)
50 // dst is either:
51 // - output name (as it appears in FunctionDef)
52 // - <name_of_node>:<index_of_this_input_into_node> (this looks the same as
53 // output tensor naming, but it the index is actually an input index)
54 struct EdgeSpec : public std::pair<string, string> {
55 typedef std::pair<string, string> Base;
56
57 // Inherit the set of constructors
58 using Base::pair;
59
ToStringtensorflow::__anon10e0d3b10111::EdgeSpec60 string ToString() const { return strings::StrCat(first, "->", second); }
61 };
62
63 class CApiFunctionTest : public ::testing::Test {
64 protected:
CApiFunctionTest()65 CApiFunctionTest()
66 : s_(TF_NewStatus()),
67 func_graph_(TF_NewGraph()),
68 host_graph_(TF_NewGraph()),
69 func_(nullptr) {}
70
SetUp()71 void SetUp() override {}
72
~CApiFunctionTest()73 ~CApiFunctionTest() override {
74 TF_DeleteFunction(func_);
75 TF_DeleteGraph(host_graph_);
76 TF_DeleteGraph(func_graph_);
77 TF_DeleteStatus(s_);
78 }
79
Run(const std::vector<std::pair<TF_Operation *,TF_Tensor * >> & inputs,TF_Operation * output,int32_t expected_result)80 void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
81 TF_Operation* output, int32_t expected_result) {
82 Run(inputs, {{output, 0}}, {expected_result});
83 }
84
85 // Run the host graph, which now contains a function and check that
86 // outputs are as expected.
87 // 'T' stands for 'tensor' since the outputs are tensors, not scalars.
RunT(const std::vector<std::pair<TF_Operation *,TF_Tensor * >> & inputs,std::initializer_list<TF_Output> outputs,const std::vector<std::vector<int32_t>> & expected_results)88 void RunT(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
89 std::initializer_list<TF_Output> outputs,
90 const std::vector<std::vector<int32_t>>& expected_results) {
91 // Create a session for this graph
92 CSession csession(host_graph_, s_);
93 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
94
95 // Run
96 csession.SetInputs(inputs);
97 csession.SetOutputs(outputs);
98 csession.Run(s_);
99 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
100
101 // Check results
102 for (int i = 0; i < expected_results.size(); ++i) {
103 TF_Tensor* out = csession.output_tensor(i);
104 ASSERT_TRUE(out != nullptr);
105 EXPECT_EQ(TF_INT32, TF_TensorType(out));
106 EXPECT_EQ(1, TF_NumDims(out));
107 CompareInt32Tensor(expected_results[i], out);
108 }
109 }
110
111 // Run the host graph, which now contains a function and check that
112 // outputs are as expected.
Run(const std::vector<std::pair<TF_Operation *,TF_Tensor * >> & inputs,std::initializer_list<TF_Output> outputs,const std::vector<int32_t> & expected_results)113 void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
114 std::initializer_list<TF_Output> outputs,
115 const std::vector<int32_t>& expected_results) {
116 // Create a session for this graph.
117 CSession csession(host_graph_, s_);
118 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
119
120 csession.SetInputs(inputs);
121 csession.SetOutputs(outputs);
122 csession.Run(s_);
123 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
124
125 for (int i = 0; i < expected_results.size(); ++i) {
126 TF_Tensor* out = csession.output_tensor(i);
127 ASSERT_TRUE(out != nullptr);
128 EXPECT_EQ(TF_INT32, TF_TensorType(out));
129 EXPECT_EQ(0, TF_NumDims(out)); // scalar
130 ASSERT_EQ(sizeof(int32_t), TF_TensorByteSize(out));
131 int32_t* output_contents = static_cast<int32_t*>(TF_TensorData(out));
132 EXPECT_EQ(expected_results[i], *output_contents);
133 }
134 }
135
CompareInt32Tensor(const std::vector<int32_t> & expected,TF_Tensor * t)136 void CompareInt32Tensor(const std::vector<int32_t>& expected, TF_Tensor* t) {
137 int32_t* data = static_cast<int32_t*>(TF_TensorData(t));
138 size_t size = TF_TensorByteSize(t);
139 ASSERT_EQ(expected.size() * sizeof(int32_t), size);
140 for (int i = 0; i < expected.size(); ++i) {
141 ASSERT_EQ(expected[i], data[i]) << "Different data at index " << i;
142 }
143 }
144
ToOutput(const std::vector<TF_Operation * > ops)145 std::vector<TF_Output> ToOutput(const std::vector<TF_Operation*> ops) {
146 std::vector<TF_Output> out;
147 for (auto op : ops) {
148 out.push_back({op, 0});
149 }
150 return out;
151 }
152
Define(int num_opers,const std::vector<TF_Operation * > & opers,const std::vector<TF_Operation * > & inputs,const std::vector<TF_Operation * > & outputs,const std::vector<string> & output_names,bool expect_failure=false)153 void Define(int num_opers, const std::vector<TF_Operation*>& opers,
154 const std::vector<TF_Operation*>& inputs,
155 const std::vector<TF_Operation*>& outputs,
156 const std::vector<string>& output_names,
157 bool expect_failure = false) {
158 DefineT(num_opers, opers, ToOutput(inputs), ToOutput(outputs), output_names,
159 expect_failure);
160 }
161
162 // Caller must delete[] the returned value
ToArray(const std::vector<string> & strs)163 static const char** ToArray(const std::vector<string>& strs) {
164 const char** ptr = nullptr;
165 if (!strs.empty()) {
166 ptr = new const char*[strs.size()];
167 for (size_t i = 0; i < strs.size(); ++i) {
168 ptr[i] = strs[i].c_str();
169 }
170 }
171 return ptr;
172 }
173
174 // An explicit `num_opers` is needed so that we can distinguish between the
175 // case of no operations specified (-1) and the case of an empty set of
176 // operations specified (0).
DefineT(int num_opers,const std::vector<TF_Operation * > & opers,const std::vector<TF_Output> & inputs,const std::vector<TF_Output> & outputs,const std::vector<string> & output_names,bool expect_failure=false)177 void DefineT(int num_opers, const std::vector<TF_Operation*>& opers,
178 const std::vector<TF_Output>& inputs,
179 const std::vector<TF_Output>& outputs,
180 const std::vector<string>& output_names,
181 bool expect_failure = false) {
182 ASSERT_EQ(func_, nullptr);
183 const char** output_names_ptr = ToArray(output_names);
184 func_ = TF_GraphToFunction(func_graph_, func_name_, false, num_opers,
185 num_opers == -1 ? nullptr : opers.data(),
186 inputs.size(), inputs.data(), outputs.size(),
187 outputs.data(), output_names_ptr,
188 /*opts=*/nullptr, /*description=*/nullptr, s_);
189 delete[] output_names_ptr;
190 if (expect_failure) {
191 ASSERT_EQ(func_, nullptr);
192 return;
193 }
194
195 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
196 ASSERT_NE(func_, nullptr);
197 ASSERT_EQ(std::string(func_name_), std::string(TF_FunctionName(func_)));
198 TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
199 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
200 }
201
Use(const std::vector<TF_Operation * > & inputs)202 TF_Operation* Use(const std::vector<TF_Operation*>& inputs) {
203 return UseT(ToOutput(inputs));
204 }
205
UseT(const std::vector<TF_Output> & inputs)206 TF_Operation* UseT(const std::vector<TF_Output>& inputs) {
207 TF_Operation* op;
208 UseHelper(inputs, &op);
209 return op;
210 }
211
212 // All the *Helper methods are used as a workaround for the restrictions that
213 // one cannot call ASSERT_* methods in non-void-returning functions (when
214 // exceptions are disabled during compilation)
UseHelper(const std::vector<TF_Output> & inputs,TF_Operation ** op)215 void UseHelper(const std::vector<TF_Output>& inputs, TF_Operation** op) {
216 TF_OperationDescription* desc =
217 TF_NewOperation(host_graph_, func_name_, func_node_name_);
218 for (auto input : inputs) {
219 TF_AddInput(desc, input);
220 }
221 // Set device to CPU because some ops inside the function might not be
222 // available on GPU.
223 TF_SetDevice(desc, "/cpu:0");
224 *op = TF_FinishOperation(desc, s_);
225 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
226 ASSERT_NE(*op, nullptr);
227 }
228
fdef()229 FunctionDef fdef() {
230 tensorflow::FunctionDef fdef;
231 EXPECT_TRUE(GetFunctionDef(func_, &fdef));
232 return fdef;
233 }
234
235 // logging utility
236 template <class Container>
ToString(const Container & v)237 string ToString(const Container& v) {
238 std::stringstream ss;
239 ss << "{";
240 size_t i = 0;
241 for (const auto& e : v) {
242 if (i != 0) {
243 ss << ", ";
244 }
245 ss << e.ToString();
246 ++i;
247 }
248 ss << "}";
249 return ss.str();
250 }
251
VerifyFDefNodes(const tensorflow::FunctionDef & fdef,const std::unordered_set<string> & nodes)252 void VerifyFDefNodes(const tensorflow::FunctionDef& fdef,
253 const std::unordered_set<string>& nodes) {
254 ASSERT_EQ(nodes.size(), fdef.node_def_size())
255 << "Got unexpected number of nodes. Expected: ["
256 << absl::StrJoin(nodes, ", ")
257 << "] Actual nodes in fdef: " << fdef.DebugString();
258 for (const NodeDef& node_def : fdef.node_def()) {
259 ASSERT_TRUE(nodes.find(node_def.name()) != nodes.end())
260 << "Got unexpected node: " << node_def.name()
261 << " in fdef: " << fdef.DebugString();
262 }
263 }
264
VerifyFDefInputs(const tensorflow::FunctionDef & fdef,const std::vector<IOSpec> & inputs)265 void VerifyFDefInputs(const tensorflow::FunctionDef& fdef,
266 const std::vector<IOSpec>& inputs) {
267 const OpDef& signature = fdef.signature();
268 ASSERT_EQ(inputs.size(), signature.input_arg_size());
269 for (int i = 0; i < inputs.size(); ++i) {
270 const OpDef::ArgDef& arg = signature.input_arg(i);
271 const IOSpec& in = inputs[i];
272 if (in.second != DT_INVALID) {
273 ASSERT_EQ(arg.type(), in.second)
274 << "Got unexpected type for input " << i
275 << ". fdef: " << fdef.DebugString();
276 }
277 ASSERT_EQ(arg.name(), in.first) << "Got unexpected name for input " << i
278 << ". fdef: " << fdef.DebugString();
279 }
280 }
281
VerifyFDefOutputs(const tensorflow::FunctionDef & fdef,const std::vector<IOSpec> & outputs)282 void VerifyFDefOutputs(const tensorflow::FunctionDef& fdef,
283 const std::vector<IOSpec>& outputs) {
284 const OpDef& signature = fdef.signature();
285 ASSERT_EQ(outputs.size(), signature.output_arg_size());
286 for (int i = 0; i < outputs.size(); ++i) {
287 const OpDef::ArgDef& arg = signature.output_arg(i);
288 const IOSpec& out = outputs[i];
289 if (out.second != DT_INVALID) {
290 ASSERT_EQ(arg.type(), out.second)
291 << "Got unexpected type for output " << i
292 << ". fdef: " << fdef.DebugString();
293 }
294 ASSERT_EQ(arg.name(), out.first) << "Got unexpected name for output " << i
295 << ". fdef: " << fdef.DebugString();
296 }
297 }
298
VerifyFDefEdges(const tensorflow::FunctionDef & fdef,const std::vector<EdgeSpec> & e_edges,const std::vector<EdgeSpec> & c_edges,bool is_exact_edges=true)299 void VerifyFDefEdges(
300 const tensorflow::FunctionDef& fdef,
301 const std::vector<EdgeSpec>& e_edges, // expected edges
302 const std::vector<EdgeSpec>& c_edges, // expected ctrl edges
303 bool is_exact_edges = true) {
304 // Build a set of edges from fdef
305 std::set<EdgeSpec> a_edges; // actual edges
306 // Get edges from inputs to body nodes and between body nodes
307 for (const NodeDef& node_def : fdef.node_def()) {
308 for (int i = 0; i < node_def.input_size(); ++i) {
309 const string& in = node_def.input(i);
310 const auto& v =
311 a_edges.insert({in, strings::StrCat(node_def.name(), ":", i)});
312 ASSERT_TRUE(v.second) << "Duplicate edge " << in << " -> "
313 << strings::StrCat(node_def.name(), ":", i)
314 << ". fdef: " << fdef.DebugString();
315 }
316 }
317 // Get edges from body nodes to outputs and from inputs to outputs
318 for (const OpDef::ArgDef& arg : fdef.signature().output_arg()) {
319 const auto& iter = fdef.ret().find(arg.name());
320 if (iter != fdef.ret().end()) {
321 const auto& v = a_edges.insert({iter->second, arg.name()});
322 ASSERT_TRUE(v.second) << "Duplicate edge " << iter->second << " -> "
323 << arg.name() << ". fdef: " << fdef.DebugString();
324 } else {
325 const auto& v = a_edges.insert({arg.name(), arg.name()});
326 ASSERT_TRUE(v.second) << "Duplicate edge " << arg.name() << " -> "
327 << arg.name() << ". fdef: " << fdef.DebugString();
328 }
329 }
330
331 // Verify edges
332 for (const EdgeSpec& e : e_edges) {
333 ASSERT_TRUE(a_edges.find(e) != a_edges.end())
334 << "Failed to find expected edge " << e.ToString()
335 << " in fdef: " << fdef.DebugString();
336 }
337 for (const EdgeSpec& e : c_edges) {
338 ASSERT_TRUE(a_edges.find(e) != a_edges.end())
339 << "Failed to find expected control edge " << e.ToString()
340 << " in fdef: " << fdef.DebugString();
341 }
342
343 // If caller specified all edges, check that we have seen all
344 if (is_exact_edges) {
345 ASSERT_EQ(e_edges.size() + c_edges.size(), a_edges.size())
346 << "Expected edges: " << ToString(e_edges)
347 << " Expected Control edges: " << ToString(c_edges)
348 << " Actual edges: " << ToString(a_edges)
349 << " in fdef: " << fdef.DebugString();
350 }
351 }
352
VerifyFDef(const std::unordered_set<string> & nodes,const std::vector<IOSpec> & inputs,const std::vector<IOSpec> & outputs,const std::vector<EdgeSpec> & e_edges,const std::vector<EdgeSpec> & c_edges,bool is_exact_edges=true)353 void VerifyFDef(const std::unordered_set<string>& nodes,
354 const std::vector<IOSpec>& inputs,
355 const std::vector<IOSpec>& outputs,
356 const std::vector<EdgeSpec>& e_edges, // expected edges
357 const std::vector<EdgeSpec>& c_edges, // expected ctrl edges
358 bool is_exact_edges = true) {
359 tensorflow::FunctionDef fdef;
360 ASSERT_TRUE(GetFunctionDef(func_, &fdef));
361 VerifyFDefNodes(fdef, nodes);
362 VerifyFDefInputs(fdef, inputs);
363 VerifyFDefOutputs(fdef, outputs);
364 VerifyFDefEdges(fdef, e_edges, c_edges, is_exact_edges);
365 }
366
367 // Serialize func_ to fdef and import it back
Reincarnate()368 void Reincarnate() {
369 // func_ -> fdef
370 tensorflow::FunctionDef fdef;
371 ASSERT_TRUE(GetFunctionDef(func_, &fdef));
372 TF_DeleteFunction(func_);
373
374 // fdef -> func_
375 string buf;
376 ASSERT_TRUE(fdef.SerializeToString(&buf));
377 func_ = TF_FunctionImportFunctionDef(buf.data(), buf.size(), s_);
378 ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
379 }
380
GetAttr(const char * attr_name,AttrValue * out_attr)381 void GetAttr(const char* attr_name, AttrValue* out_attr) {
382 TF_Buffer* attr_buf = TF_NewBuffer();
383 TF_FunctionGetAttrValueProto(func_, attr_name, attr_buf, s_);
384 ASSERT_TRUE(out_attr->ParseFromArray(attr_buf->data, attr_buf->length));
385 TF_DeleteBuffer(attr_buf);
386 }
387
388 const char* func_name_ = "MyFunc";
389 const char* func_node_name_ = "MyFunc_0";
390 TF_Status* s_;
391 TF_Graph* func_graph_;
392 TF_Graph* host_graph_;
393 TF_Function* func_;
394
395 // Workaround for not being able to initialize empty map using {}
396 std::unordered_set<string> empty_;
397 };
398
TEST_F(CApiFunctionTest, OneOp_ZeroInputs_OneOutput) {
  // A function without arguments whose single result is a constant.
  /*
   *   constant
   *      |
   *      v
   */
  TF_Operation* ten = ScalarConst(10, func_graph_, s_, "scalar10");
  Define(-1, {}, {}, {ten}, {});

  // Instantiate the function in the host graph and evaluate it.
  TF_Operation* call = Use({});
  Run({}, call, 10);

  // Inspect the generated FunctionDef.
  const std::unordered_set<string> want_nodes = {"scalar10_0"};
  const std::vector<IOSpec> want_outputs = {{"scalar10", DT_INT32}};
  const std::vector<EdgeSpec> want_edges = {
      {"scalar10_0:output:0", "scalar10"}};
  VerifyFDef(want_nodes, {}, want_outputs, want_edges, {});
}
415
TEST_F(CApiFunctionTest, OneOp_OneInput_OneOutput) {
  // A function that negates its only argument.
  /*
   *     |
   *     v
   *   negate
   *     |
   *     v
   */
  TF_Operation* arg = Placeholder(func_graph_, s_);
  TF_Operation* negated = Neg(arg, func_graph_, s_);
  Define(-1, {}, {arg}, {negated}, {});

  // Feed 3 through the function and expect -3 back.
  TF_Operation* host_arg = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, -3);

  // Inspect the generated FunctionDef.
  const std::vector<IOSpec> want_inputs = {{"feed", DT_INT32}};
  const std::vector<IOSpec> want_outputs = {{"neg", DT_INT32}};
  const std::vector<EdgeSpec> want_edges = {{"feed", "neg_0:0"},
                                            {"neg_0:y:0", "neg"}};
  VerifyFDef({"neg_0"}, want_inputs, want_outputs, want_edges, {});
}
436
TEST_F(CApiFunctionTest, OneOutput_OutputNames) {
  // Negation function whose result is given an explicit name.
  /*
   *     |
   *     v
   *   negate
   *     |
   *     v
   */
  TF_Operation* arg = Placeholder(func_graph_, s_);
  TF_Operation* negated = Neg(arg, func_graph_, s_);
  Define(-1, {}, {arg}, {negated}, {"negated_num"});

  // Feed 3 through the function and expect -3 back.
  TF_Operation* host_arg = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, -3);

  // The result is exposed under the requested name "negated_num".
  const std::vector<IOSpec> want_inputs = {{"feed", DT_INT32}};
  const std::vector<IOSpec> want_outputs = {{"negated_num", DT_INT32}};
  const std::vector<EdgeSpec> want_edges = {{"feed", "neg:0"},
                                            {"neg:y:0", "negated_num"}};
  VerifyFDef({"neg"}, want_inputs, want_outputs, want_edges, {});
}
457
TEST_F(CApiFunctionTest, OutputNames_SameNameAsInput) {
  // The requested output name collides with the argument's name.
  /*
   *     |
   *     v
   *   negate
   *     |
   *     v
   */
  TF_Operation* arg = Placeholder(func_graph_, s_, "negation");
  TF_Operation* negated = Neg(arg, func_graph_, s_, "neg");
  Define(-1, {}, {arg}, {negated}, {"negation"});

  // Feed 3 through the function and expect -3 back.
  TF_Operation* host_arg = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, -3);

  // The output keeps "negation"; the argument shows up as "negation_0".
  const std::vector<IOSpec> want_inputs = {{"negation_0", DT_INT32}};
  const std::vector<IOSpec> want_outputs = {{"negation", DT_INT32}};
  const std::vector<EdgeSpec> want_edges = {{"negation_0", "neg:0"},
                                            {"neg:y:0", "negation"}};
  VerifyFDef({"neg"}, want_inputs, want_outputs, want_edges, {});
}
478
TEST_F(CApiFunctionTest, ZeroOps_Identity) {
  // A function with no body nodes that returns its argument unchanged.
  /*
   *     |
   *     |
   *     |
   *     v
   */
  TF_Operation* arg = Placeholder(func_graph_, s_);
  Define(-1, {}, {arg}, {arg}, {});

  // Whatever goes in (3) must come back out.
  TF_Operation* host_arg = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, 3);

  // No body nodes; the argument is renamed to "feed_0".
  const std::vector<IOSpec> want_inputs = {{"feed_0", DT_INT32}};
  const std::vector<IOSpec> want_outputs = {{"feed", DT_INT32}};
  const std::vector<EdgeSpec> want_edges = {{"feed_0", "feed"}};
  VerifyFDef(empty_, want_inputs, want_outputs, want_edges, {});
}
497
TEST_F(CApiFunctionTest, ZeroOps_Permutation) {
  // A function with no body nodes that swaps its two arguments.
  /*
   *   |  |
   *   \  /
   *    \/
   *    x
   *    /\
   *   /  \
   *   |  |
   *   v  v
   */
  TF_Operation* first_arg = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* second_arg = Placeholder(func_graph_, s_, "feed2");
  Define(-1, {}, {first_arg, second_arg}, {second_arg, first_arg}, {});

  // Call with (2, 3) and expect (3, 2).
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_arg = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {3, 2});

  // Inspect the generated FunctionDef.
  const std::vector<IOSpec> want_inputs = M({{"feed1_0"}, {"feed2_0"}});
  const std::vector<IOSpec> want_outputs = M({{"feed2"}, {"feed1"}});
  const std::vector<EdgeSpec> want_edges = {{"feed1_0", "feed1"},
                                            {"feed2_0", "feed2"}};
  VerifyFDef(empty_, want_inputs, want_outputs, want_edges, {});
}
522
TEST_F(CApiFunctionTest, ZeroOps_Permutation_OutputNames) {
  // Argument swap where both results carry explicit names.
  /*
   *   |  |
   *   \  /
   *    \/
   *    x
   *    /\
   *   /  \
   *   |  |
   *   v  v
   */
  TF_Operation* first_arg = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* second_arg = Placeholder(func_graph_, s_, "feed2");
  Define(-1, {}, {first_arg, second_arg}, {second_arg, first_arg},
         {"first", "second"});

  // Call with (2, 3) and expect (3, 2).
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_arg = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {3, 2});

  // Arguments keep their names; results use the requested ones.
  const std::vector<IOSpec> want_inputs = M({{"feed1"}, {"feed2"}});
  const std::vector<IOSpec> want_outputs = M({{"first"}, {"second"}});
  const std::vector<EdgeSpec> want_edges = {{"feed1", "second"},
                                            {"feed2", "first"}};
  VerifyFDef(empty_, want_inputs, want_outputs, want_edges, {});
}
547
TEST_F(CApiFunctionTest, OneOp_TwoInputs_OneOutput) {
  // A function that adds its two arguments.
  /*
   *   |  |
   *   v  v
   *   add
   *    |
   *    v
   */
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);
  Define(-1, {}, {lhs, rhs}, {sum}, {});

  // 2 + 3 should come back as 5.
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_arg = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, 2 + 3);

  // Inspect the generated FunctionDef.
  const std::vector<EdgeSpec> want_edges = {
      {"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}};
  VerifyFDef({"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}), want_edges,
             {});
}
571
TEST_F(CApiFunctionTest, OneOp_TwoInputs_ZeroOutputs) {
  // An add node whose result is not exported as a function output.
  /*
   *   |  |
   *   v  v
   *   add
   *
   *   (output ignored)
   */
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  Add(lhs, rhs, func_graph_, s_);  // Result deliberately unused.
  Define(-1, {}, {lhs, rhs}, {}, {});

  // Instantiate the call; there is nothing to fetch, so nothing is run.
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_arg = Placeholder(host_graph_, s_);
  Use({two, host_arg});

  // The add node is still part of the body, with no output edges.
  const std::vector<EdgeSpec> want_edges = {{"feed1", "add:0"},
                                            {"feed2", "add:1"}};
  VerifyFDef({"add"}, M({{"feed1"}, {"feed2"}}), {}, want_edges, {});
}
593
TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_OneOutput) {
  // Two chained adds computing (a + b) + c.
  /*
   *   |  |  |
   *   v  v  /
   *   add1 /
   *    |  |
   *    v  v
   *    add2
   *     |
   *     v
   */
  TF_Operation* a = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* b = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* c = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* partial = Add(a, b, func_graph_, s_, "add1");
  TF_Operation* total = Add(partial, c, func_graph_, s_, "add2");
  Define(-1, {}, {a, b, c}, {total}, {});

  // 2 + 10 + 3 should come back as 15.
  TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten");
  TF_Operation* host_arg = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, ten, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, 2 + 10 + 3);

  // Inspect the generated FunctionDef.
  const std::unordered_set<string> want_nodes = {"add1", "add2_0"};
  const std::vector<EdgeSpec> want_edges = {{"feed1", "add1:0"},
                                            {"feed2", "add1:1"},
                                            {"add1:sum:0", "add2_0:0"},
                                            {"feed3", "add2_0:1"},
                                            {"add2_0:sum:0", "add2"}};
  VerifyFDef(want_nodes, M({{"feed1"}, {"feed2"}, {"feed3"}}), M({{"add2"}}),
             want_edges, {});
}
628
TEST_F(CApiFunctionTest, OneOp_TwoInputs_TwoDuplicateOutputs) {
  // The same add result is exported twice.
  /*
   *   |  |
   *   v  v
   *   add
   *    |
   *  +-+-+
   *  |   |
   *  v   v
   */
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);
  Define(-1, {}, {lhs, rhs}, {sum, sum}, {});

  // Both results should equal 2 + 3.
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_arg = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {5, 5});

  // The two outputs get distinct names and share one source tensor.
  const std::vector<EdgeSpec> want_edges = {{"feed1", "add_1:0"},
                                            {"feed2", "add_1:1"},
                                            {"add_1:sum:0", "add"},
                                            {"add_1:sum:0", "add_0"}};
  VerifyFDef({"add_1"}, M({{"feed1"}, {"feed2"}}), M({{"add"}, {"add_0"}}),
             want_edges, {});
}
657
TEST_F(CApiFunctionTest, TwoDuplicateOutputs_OutputNames) {
  // The same add result is exported twice under explicit names.
  /*
   *   |  |
   *   v  v
   *   add
   *    |
   *  +-+-+
   *  |   |
   *  v   v
   */
  TF_Operation* lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* sum = Add(lhs, rhs, func_graph_, s_);
  Define(-1, {}, {lhs, rhs}, {sum, sum}, {"out1", "out2"});

  // Both results should equal 2 + 3.
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* host_arg = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {5, 5});

  // Inspect the generated FunctionDef.
  const std::vector<EdgeSpec> want_edges = {{"feed1", "add:0"},
                                            {"feed2", "add:1"},
                                            {"add:sum:0", "out1"},
                                            {"add:sum:0", "out2"}};
  VerifyFDef({"add"}, M({{"feed1"}, {"feed2"}}), M({{"out1"}, {"out2"}}),
             want_edges, {});
}
686
TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_TwoOutputs) {
  // Chained adds where the intermediate sum is also a function output.
  /*
   *   |  |  |
   *   v  v  /
   *   add  /
   *    |  |
   *   +-+ |
   *   | | |
   *   | v v
   *   | add
   *   |  |
   *   v  v
   */
  TF_Operation* a = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* b = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* c = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* partial = Add(a, b, func_graph_, s_, "add1");
  TF_Operation* total = Add(partial, c, func_graph_, s_, "add2");
  Define(-1, {}, {a, b, c}, {partial, total}, {});

  // Expect (2 + 10, 2 + 10 + 3).
  TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten");
  TF_Operation* host_arg = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, ten, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {12, 15});

  // Inspect the generated FunctionDef.
  const std::unordered_set<string> want_nodes = {"add1_0", "add2_0"};
  const std::vector<EdgeSpec> want_edges = {{"feed1", "add1_0:0"},
                                            {"feed2", "add1_0:1"},
                                            {"add1_0:sum:0", "add2_0:0"},
                                            {"feed3", "add2_0:1"},
                                            {"add1_0:sum:0", "add1"},
                                            {"add2_0:sum:0", "add2"}};
  VerifyFDef(want_nodes, M({{"feed1"}, {"feed2"}, {"feed3"}}),
             M({{"add1"}, {"add2"}}), want_edges, {});
}
724
TEST_F(CApiFunctionTest, FromSubsetOfOps) {
  // Only the second add becomes the function body; the first add's result
  // is treated as an argument.
  /*
   *   |  |  |
   *   v  v  /
   *   add  /
   *    |  |
   *  +---+--+---+
   *  |   |  |   |  <- ops used for func
   *  |   v  v   |
   *  |   add    |
   *  |    |     |
   *  |    v     |
   *  +----------+
   */
  TF_Operation* a = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* b = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* c = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* partial = Add(a, b, func_graph_, s_, "add1");
  TF_Operation* total = Add(partial, c, func_graph_, s_, "add2");
  Define(1, {total}, {partial, c}, {total}, {});

  // The function now has two arguments: 2 + 3.
  TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* host_arg = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, 2 + 3);

  // Inspect the generated FunctionDef.
  const std::vector<EdgeSpec> want_edges = {
      {"add1", "add2_0:0"}, {"feed3", "add2_0:1"}, {"add2_0:sum:0", "add2"}};
  VerifyFDef({"add2_0"}, M({{"add1"}, {"feed3"}}), M({{"add2"}}), want_edges,
             {});
}
758
TEST_F(CApiFunctionTest, UsingOneOutputOfSplit) {
  // Split the argument into three parts and return only the middle one.
  /*
   *            feed
   *             |
   *   +---------+---+
   *   | const0  |   |
   *   |    |    |   |
   *   |    v    /   |
   *   |    split    |
   *   |   |  |  |   |
   *   |   v  |  v   |
   *   |      |      |
   *   +------+------+
   *          |
   *          v
   *
   * Only the second output from split is used as function output
   */
  TF_Operation* arg = Placeholder(func_graph_, s_);
  TF_Operation* splitter = Split3(arg, func_graph_, s_);
  DefineT(-1, {}, {{arg, 0}}, {{splitter, 1}}, {});

  // The middle third of {1..6} is {3, 4}.
  TF_Operation* host_arg = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_arg});
  RunT({{host_arg, Int32Tensor({1, 2, 3, 4, 5, 6})}}, {{call, 0}},
       {{3, 4}});

  // Inspect the generated FunctionDef.
  const std::unordered_set<string> want_nodes = {"split3_const0", "split3_0"};
  const std::vector<EdgeSpec> want_edges = {
      {"split3_const0:output:0", "split3_0:0"},
      {"feed", "split3_0:1"},
      {"split3_0:output:1", "split3"}};
  VerifyFDef(want_nodes, M({{"feed"}}), M({{"split3"}}), want_edges, {});
}
793
TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplit) {
  // Split the argument into three parts and return the first and the last.
  /*
   *            feed
   *             |
   *   +---------+---+
   *   | const0  |   |
   *   |    |    |   |
   *   |    v    /   |
   *   |    split    |
   *   |   |  |  |   |
   *   |   |  v  |   |
   *   |   |     |   |
   *   +---+-----+---+
   *       |     |
   *       v     v
   *
   * Second output from split is not used as function output
   */
  TF_Operation* arg = Placeholder(func_graph_, s_);
  TF_Operation* splitter = Split3(arg, func_graph_, s_);
  DefineT(-1, {}, {{arg, 0}}, {{splitter, 0}, {splitter, 2}}, {});

  // Thirds of {1..6}: the first is {1, 2}, the last is {5, 6}.
  TF_Operation* host_arg = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_arg});
  RunT({{host_arg, Int32Tensor({1, 2, 3, 4, 5, 6})}},
       {{call, 0}, {call, 1}}, {{1, 2}, {5, 6}});

  // Inspect the generated FunctionDef.
  const std::unordered_set<string> want_nodes = {"split3_const0", "split3_1"};
  const std::vector<EdgeSpec> want_edges = {
      {"split3_const0:output:0", "split3_1:0"},
      {"feed", "split3_1:1"},
      {"split3_1:output:0", "split3"},
      {"split3_1:output:2", "split3_0"}};
  VerifyFDef(want_nodes, M({{"feed"}}), M({{"split3"}, {"split3_0"}}),
             want_edges, {});
}
830
TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplitAsInputs) {
  // Two outputs of a multi-output node become the function's arguments;
  // only the add node forms the body.
  /*
   *          |
   *          v
   *        split
   *       |  |  |
   *       |  v  |
   *       |     |
   *   +---+-----+---+
   *   |   |     |   |
   *   |   v     v   |
   *   |     add     |
   *   |      |      |
   *   |      |      |
   *   +------+------+
   *          |
   *          v
   */
  TF_Operation* arg = Placeholder(func_graph_, s_);
  TF_Operation* splitter = Split3(arg, func_graph_, s_);
  TF_Operation* sum = Add({splitter, 0}, {splitter, 2}, func_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  DefineT(1, {sum}, {{splitter, 0}, {splitter, 2}}, {{sum, 0}}, {});

  // 2 + 3 should come back as 5.
  TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* host_arg = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, 2 + 3);

  // Inspect the generated FunctionDef.
  const std::vector<EdgeSpec> want_edges = {
      {"split3", "add_0:0"}, {"split3_0", "add_0:1"}, {"add_0:sum:0", "add"}};
  VerifyFDef({"add_0"}, M({{"split3"}, {"split3_0"}}), M({{"add"}}),
             want_edges, {});
}
866
TEST_F(CApiFunctionTest, NodesUsedInInputsMustHaveSingleOutput) {
  // With num_opers == -1, using outputs of a multi-output node (split) as
  // function inputs must be rejected with INVALID_ARGUMENT.
  /*
   *          |
   *          v
   *        split
   *       |  |  |
   *       |  v  |
   *       |     |
   *  input --->|     |<--- input
   *       |     |
   *       v     v
   *        add
   *         |
   *         |
   *         v
   */
  // Define
  // Own the tensor so it is released even when an ASSERT_* below returns
  // early from the test body (previously it leaked on those paths). It is
  // destroyed at the end of the body, as with the former explicit delete.
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor_123(
      Int32Tensor({1, 2, 3}), TF_DeleteTensor);
  TF_Operation* c = Const(tensor_123.get(), func_graph_, s_, "const_array");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* split = Split3(c, func_graph_, s_);
  TF_Operation* add = Add({split, 0}, {split, 2}, func_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  DefineT(-1, {}, {{split, 0}, {split, 2}}, {{add, 0}}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("When `num_opers` is set to -1, nodes referenced in "
                   "`inputs` must have a single output. Node split3 has "
                   "3 outputs. Encountered while creating function 'MyFunc'"),
            string(TF_Message(s_)));
}
899
TEST_F(CApiFunctionTest, FunctionWithWhileLoop) {
  // Inputs to the while loop and to the function as a whole.
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");

  // Outputs of the while loop corresponding to the two inputs above.
  // The first one will be the function's output.
  std::vector<TF_Output> outputs;

  // Add the while loop to func_graph_.
  {
    // The inputs to the while loop.
    std::vector<TF_Output> inputs = {{feed1, 0}, {feed2, 0}};
    std::unique_ptr<TF_WhileParams> params(new TF_WhileParams(TF_NewWhile(
        func_graph_, inputs.data(), static_cast<int>(inputs.size()), s_)));
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    params->name = "test_loop";

    // Pre-fill outputs with sentinel values so errors/bugs are easy to detect.
    outputs.resize(2, {nullptr, -1});

    // Create loop: while (input1 < input2) input1 += input2 + 1
    TF_Operation* less_than = LessThan(
        params->cond_inputs[0], params->cond_inputs[1], params->cond_graph, s_);
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    params->cond_output = {less_than, 0};

    TF_Operation* add1 = Add(params->body_inputs[0], params->body_inputs[1],
                             params->body_graph, s_, "add1");
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    TF_Operation* one = ScalarConst(1, params->body_graph, s_);
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    TF_Operation* add2 = Add(add1, one, params->body_graph, s_, "add2");
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
    params->body_outputs[0] = {add2, 0};
    params->body_outputs[1] = params->body_inputs[1];

    // Finalize the while loop. outputs[0] is handed to DefineT below, so stop
    // here (instead of continuing with the {nullptr, -1} sentinel) if
    // finalizing failed.
    TF_FinishWhile(params.get(), s_, outputs.data());
    ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  }

  // Define the function, use it in the host graph, and run it:
  // 2 < 5, so input1 becomes 2 + 5 + 1 = 8, which ends the loop.
  DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {outputs[0]}, {});
  TF_Operation* five = ScalarConst(5, host_graph_, s_, "five");
  TF_Operation* func_feed = Placeholder(host_graph_, s_);
  TF_Operation* func_op = Use({func_feed, five});
  Run({{func_feed, Int32Tensor(2)}}, func_op, 2 /*+=*/ + 5 + 1);

  // Verify input, output, and a subset of edges in the fdef.
  // The subset of edges we verify is a chain between feed1 and the output, to
  // make sure that the correct output is picked.
  tensorflow::FunctionDef fdef;
  ASSERT_TRUE(GetFunctionDef(func_, &fdef));
  VerifyFDefInputs(fdef, M({{"feed1"}, {"feed2"}}));
  VerifyFDefOutputs(fdef, M({{"test_loop_exit"}}));
  VerifyFDefEdges(fdef,
                  {{"feed1", "test_loop/Enter:0"},
                   {"test_loop/Enter:output:0", "test_loop/Merge:0"},
                   {"test_loop/Merge:output:0", "test_loop/Switch:0"},
                   {"test_loop/Switch:output_false:0", "test_loop/Exit:0"},
                   {"test_loop/Exit:output:0", "test_loop_exit"}},
                  {}, false);
}
964
TEST_F(CApiFunctionTest, ControlDependency) {
  // `add` sums feed1 and feed2 and also has a control dependency on a Const
  // node (named "scalar"). With the body inferred, both `add` and "scalar"
  // end up in the FunctionDef and the control edge is kept as "^scalar".
  //
  //   feed1   feed2   scalar
  //     |       |       .
  //     v       v       .   <---- control dependency
  //      add  <..........
  //       |
  //       v

  // Define
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* five = ScalarConst(5, func_graph_, s_);
  TF_Operation* add =
      AddWithCtrlDependency(feed1, feed2, func_graph_, five, s_);
  // `add` is passed to Define below, so stop if building it failed.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  Define(-1, {}, {feed1, feed2}, {add}, {});

  // Use, run, and verify
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* func_feed = Placeholder(host_graph_, s_);
  TF_Operation* func_op = Use({two, func_feed});
  Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3);
  VerifyFDef(
      {"add_0", "scalar"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
      {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}},
      {{"^scalar", "add_0:2"}});
}
993
TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody) {
  // Same graph as ControlDependency, but the body is explicitly just `add`
  // (num_opers = 1). The control input from "scalar" then originates outside
  // the body, so defining the function must fail.
  //
  //   feed1   feed2   scalar (not in body)
  //     |       |       .
  //     v       v       .   <---- control dependency
  //      add  <..........
  //       |
  //       v

  // Define
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* five = ScalarConst(5, func_graph_, s_);
  TF_Operation* add =
      AddWithCtrlDependency(feed1, feed2, func_graph_, five, s_);
  // Stop on a setup failure; otherwise it could be hidden behind the
  // expected error checked below.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  Define(1, {add}, {feed1, feed2}, {add}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("The source of control edge [id=3 scalar:-1 -> add:-1] "
                   "is not in the body. Encountered while creating "
                   "function 'MyFunc'"),
            string(TF_Message(s_)));
}
1017
TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody_FromInputNode) {
  // `add` has a control dependency on feed1, which is also a function input.
  // This is accepted; the FunctionDef keeps it as a "^feed1" control edge.
  //
  //   feed1   feed2
  //     |  .    |
  //     |   .   |
  //     v    .  v   <---- control dependency from feed1
  //       add
  //        |
  //        v

  // Define
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* add =
      AddWithCtrlDependency(feed1, feed2, func_graph_, feed1, s_);
  // `add` is passed to Define below, so stop if building it failed.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  Define(-1, {}, {feed1, feed2}, {add}, {});

  // Use, run, and verify
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* func_feed = Placeholder(host_graph_, s_);
  TF_Operation* func_op = Use({two, func_feed});
  Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3);
  VerifyFDef(
      {"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
      {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}},
      {{"^feed1", "add_0:2"}});
}
1046
TEST_F(CApiFunctionTest, DuplicateInputsAreNotAllowed) {
  // feed1 is used for both operands of `add` and is listed twice as a
  // function input. Listing the same TF_Output twice must be rejected.
  //
  //        feed1
  //        |   |
  //        v   v
  //         add
  //          |
  //          v
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* add = Add(feed1, feed1, func_graph_, s_);
  // Stop on a setup failure; otherwise it could be hidden behind the
  // expected error checked below.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  Define(-1, {}, {feed1, feed1}, {add}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(
      string("TF_Output feed1:0 appears more than once in the input list"),
      string(TF_Message(s_)));
}
1071
TEST_F(CApiFunctionTest, DuplicateOutputNamesAreNotAllowed) {
  // add1 = feed1 + feed2 and add2 = add1 + feed3 are both function outputs,
  // but both are given the name "my_out", which must be rejected.
  //
  //   feed1  feed2  feed3
  //     |      |      |
  //     v      v      |
  //      add1         |
  //       |  |        |
  //       |  v        v
  //       |     add2
  //       |      |
  //       v      v
  //    my_out  my_out

  // Define
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* feed3 = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* add1 = Add(feed1, feed2, func_graph_, s_, "add1");
  TF_Operation* add2 = Add(add1, feed3, func_graph_, s_, "add2");
  // Stop on a setup failure; otherwise it could be hidden behind the
  // expected error checked below.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  Define(-1, {}, {feed1, feed2, feed3}, {add1, add2}, {"my_out", "my_out"},
         true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Cannot have duplicate output names. Name 'my_out' "
                   "appears more than once in 'output_names' array."),
            string(TF_Message(s_)));
}
1098
TEST_F(CApiFunctionTest, InvalidInputTensor_HighIndex) {
  // The second declared input is output 2 of feed2, which has only one
  // output, so defining the function must fail with OUT_OF_RANGE.
  //
  //   feed1   feed2
  //     |       |
  //     v       v
  //       add
  //        |
  //        v
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
  // Stop on a setup failure; otherwise it could be hidden behind the
  // expected error checked below.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  DefineT(-1, {}, {{feed1, 0}, {feed2, 2}}, {{add, 0}}, {}, true);
  EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_));
  EXPECT_EQ(string("Node 'feed2' (type: 'Placeholder', num of outputs: 1) does "
                   "not have output 2\n\tEncountered while processing "
                   "input 1 into function 'MyFunc'"),
            string(TF_Message(s_)));
}
1117
TEST_F(CApiFunctionTest, InvalidInputTensor_BadNodePtr) {
  // The second declared input has a null operation pointer, which must be
  // rejected with INVALID_ARGUMENT.
  //
  //   feed1   feed2
  //     |       |
  //     v       v
  //       add
  //        |
  //        v
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
  // Stop on a setup failure; otherwise it could be hidden behind the
  // expected error checked below.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  DefineT(-1, {}, {{feed1, 0}, {nullptr, 0}}, {{add, 0}}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Node is null\n\tEncountered while processing input 1 "
                   "into function 'MyFunc'"),
            string(TF_Message(s_)));
}
1135
TEST_F(CApiFunctionTest, InvalidOutputTensor_HighIndex) {
  // The declared output is output 3 of `add` (an AddN with one output), so
  // defining the function must fail with OUT_OF_RANGE.
  //
  //   feed1   feed2
  //     |       |
  //     v       v
  //       add
  //        |
  //        v
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
  // Stop on a setup failure; otherwise it could be hidden behind the
  // expected error checked below.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{add, 3}}, {}, true);
  EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_));
  EXPECT_EQ(string("Node 'add' (type: 'AddN', num of outputs: 1) does "
                   "not have output 3\n\tEncountered while processing "
                   "output 0 from function 'MyFunc'"),
            string(TF_Message(s_)));
}
1154
TEST_F(CApiFunctionTest, InvalidOutputTensor_BadNodePtr) {
  // The declared output has a null operation pointer, which must be rejected
  // with INVALID_ARGUMENT.
  //
  //   feed1   feed2
  //     |       |
  //     v       v
  //       add
  //        |
  //        v
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  Add(feed1, feed2, func_graph_, s_);
  // Stop on a setup failure; otherwise it could be hidden behind the
  // expected error checked below.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {{nullptr, 3}}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Node is null\n\tEncountered while processing output 0 "
                   "from function 'MyFunc'"),
            string(TF_Message(s_)));
}
1172
TEST_F(CApiFunctionTest, NodeMissingInput) {
  // The body is just `add` (num_opers = 1) and only feed1 is declared as an
  // input. add's second input, feed2, is neither a function input nor
  // produced inside the body, so defining the function must fail.
  //
  //   input ---> feed1    feed2 <---- missing input
  //                |        |
  //                v        v
  //   body ---->      add
  //                    |
  //                    v
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
  // Stop on a setup failure; otherwise it could be hidden behind the
  // expected error checked below.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  DefineT(1, {add}, {{feed1, 0}}, {{add, 0}}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Input 1, 'feed2:0', of node 'add' in function 'MyFunc' "
                   "is not available. You might need to include it in inputs "
                   "or include its source node in the body"),
            string(TF_Message(s_)));
}
1191
TEST_F(CApiFunctionTest, OutputOpNotInBody) {
  // The function has two outputs, `add` and "scalar", but only `add` is in the
  // body and "scalar" is not an input either, so defining the function must
  // fail.
  //
  //   feed1   feed2
  //     |       |
  //     v       v
  //       add          scalar   (scalar not included in body)
  //        |             |
  //        v             v      (function has two outputs)

  // Define
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* scalar = ScalarConst(2, func_graph_, s_);
  TF_Operation* add = Add(feed1, feed2, func_graph_, s_);
  // Stop on a setup failure; otherwise it could be hidden behind the
  // expected error checked below.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  Define(1, {add}, {feed1, feed2}, {add, scalar}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("TF_Output scalar:0 is neither in the function body nor "
                   "among function inputs. Encountered while creating "
                   "function 'MyFunc'"),
            string(TF_Message(s_)));
}
1212
DefineFunction(const char * name,TF_Function ** func,const char * description=nullptr,bool append_hash=false)1213 void DefineFunction(const char* name, TF_Function** func,
1214 const char* description = nullptr,
1215 bool append_hash = false) {
1216 std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
1217 TF_NewGraph(), TF_DeleteGraph);
1218 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
1219 TF_DeleteStatus);
1220
1221 TF_Operation* feed = Placeholder(func_graph.get(), s.get());
1222 TF_Operation* neg = Neg(feed, func_graph.get(), s.get());
1223
1224 TF_Output inputs[] = {{feed, 0}};
1225 TF_Output outputs[] = {{neg, 0}};
1226 *func = TF_GraphToFunction(func_graph.get(), name, append_hash, -1,
1227 /*opers=*/nullptr, 1, inputs, 1, outputs,
1228 /*output_names=*/nullptr,
1229 /*opts=*/nullptr, description, s.get());
1230 ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
1231 ASSERT_NE(*func, nullptr);
1232 }
1233
// Test-only op used by NodeWithPlaceholderAttrHelper: a single float32 output
// and an int attr "index". The helper sets "index" to an attr placeholder, so
// converting a graph of these nodes into a function yields function-level
// attrs. The output shape is left unknown.
REGISTER_OP("CustomOp")
    .Output("output: float32")
    .Attr("index: int")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape);
1238
NodeWithPlaceholderAttrHelper(TF_Graph * graph,TF_Status * s,const char * name,const char * placeholder,TF_Operation ** op)1239 void NodeWithPlaceholderAttrHelper(TF_Graph* graph, TF_Status* s,
1240 const char* name, const char* placeholder,
1241 TF_Operation** op) {
1242 TF_OperationDescription* desc = TF_NewOperation(graph, "CustomOp", name);
1243 TF_SetAttrPlaceholder(desc, "index", placeholder);
1244 *op = TF_FinishOperation(desc, s);
1245 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1246 ASSERT_NE(*op, nullptr);
1247 }
1248
TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
  // Three CustomOp nodes whose "index" attrs are placeholders: node1 and
  // node2 share "v1", node3 uses "v2". The resulting FunctionDef signature
  // must declare exactly the two int attrs "v1" and "v2".
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
      TF_NewGraph(), TF_DeleteGraph);
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
                                                           TF_DeleteStatus);

  // An ASSERT inside the helper only returns from the helper, so wrap each
  // call in ASSERT_NO_FATAL_FAILURE to stop this test as well.
  TF_Operation* node1 = nullptr;
  TF_Operation* node2 = nullptr;
  TF_Operation* node3 = nullptr;
  ASSERT_NO_FATAL_FAILURE(NodeWithPlaceholderAttrHelper(
      func_graph.get(), s.get(), "node1", "v1", &node1));
  ASSERT_NO_FATAL_FAILURE(NodeWithPlaceholderAttrHelper(
      func_graph.get(), s.get(), "node2", "v1", &node2));
  ASSERT_NO_FATAL_FAILURE(NodeWithPlaceholderAttrHelper(
      func_graph.get(), s.get(), "node3", "v2", &node3));

  TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}};
  func_ = TF_GraphToFunction(
      func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
      /*opers=*/nullptr, 0, nullptr, 3, outputs,
      /*output_names=*/nullptr,
      /*opts=*/nullptr, /*description=*/nullptr, s.get());
  ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
  ASSERT_NE(func_, nullptr);

  // Verify that FunctionDef has 2 attributes, "v1" and "v2".
  ASSERT_EQ(func_->fdef.signature().attr().size(), 2);
  EXPECT_EQ(func_->fdef.signature().attr(0).name(), "v1");
  EXPECT_EQ(func_->fdef.signature().attr(0).type(), "int");
  EXPECT_EQ(func_->fdef.signature().attr(1).name(), "v2");
  EXPECT_EQ(func_->fdef.signature().attr(1).type(), "int");
}
1279
NodeWithAttrHelper(TF_Graph * graph,TF_Status * s,const char * name,const char * attr_name,const char * attr_value,TF_Operation ** op)1280 void NodeWithAttrHelper(TF_Graph* graph, TF_Status* s, const char* name,
1281 const char* attr_name, const char* attr_value,
1282 TF_Operation** op) {
1283 TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
1284 TF_SetAttrType(desc, "dtype", TF_INT32);
1285 TF_SetAttrString(desc, attr_name, attr_value, strlen(attr_value));
1286 *op = TF_FinishOperation(desc, s);
1287 ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1288 ASSERT_NE(*op, nullptr);
1289 }
1290
TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) {
  // A Placeholder carrying the attr "_test_attr" = "value" is the function's
  // only input. The FunctionDef must record that attr in arg_attr for arg 0.
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
      TF_NewGraph(), TF_DeleteGraph);
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
                                                           TF_DeleteStatus);

  // An ASSERT inside the helper only returns from the helper, so wrap it in
  // ASSERT_NO_FATAL_FAILURE to stop this test as well.
  TF_Operation* node = nullptr;
  ASSERT_NO_FATAL_FAILURE(NodeWithAttrHelper(func_graph.get(), s.get(), "node",
                                             "_test_attr", "value", &node));

  TF_Output inputs[] = {{node, 0}};
  func_ = TF_GraphToFunction(
      func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
      /*opers=*/nullptr, 1, inputs, 0, nullptr,
      /*output_names=*/nullptr,
      /*opts=*/nullptr, /*description=*/nullptr, s.get());
  ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
  ASSERT_NE(func_, nullptr);

  // Verify that FunctionDef ArgDef has attributes.
  ASSERT_EQ(func_->fdef.arg_attr_size(), 1);
  auto arg_attrs = func_->fdef.arg_attr().find(0);
  ASSERT_NE(arg_attrs, func_->fdef.arg_attr().end());
  auto iter = arg_attrs->second.attr().find("_test_attr");
  ASSERT_NE(iter, arg_attrs->second.attr().end());
  EXPECT_EQ(iter->second.s(), "value");
}
1318
TEST_F(CApiFunctionTest, SetGradientAndRun) {
  // Build the function and a second function that will act as its gradient.
  DefineFunction(func_name_, &func_);
  TF_Function* gradient;
  DefineFunction("MyGrad", &gradient);

  // Copy the function, linked to its gradient, into the host graph.
  TF_GraphCopyFunction(host_graph_, func_, gradient, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // The host GraphDef must list both functions and record MyGrad as the
  // gradient of func_name_.
  GraphDef before;
  GetGraphDef(host_graph_, &before);
  const std::vector<string> names = GetFuncNames(before);
  ASSERT_EQ(2, names.size());
  ASSERT_EQ(func_name_, names[0]);
  ASSERT_EQ("MyGrad", names[1]);
  const std::vector<std::pair<string, string>> grad_pairs =
      GetGradDefs(before);
  ASSERT_EQ(1, grad_pairs.size());
  ASSERT_EQ(func_name_, grad_pairs[0].first);
  ASSERT_EQ("MyGrad", grad_pairs[0].second);

  // Copying again with the same gradient, or with no gradient, must be a
  // no-op.
  TF_GraphCopyFunction(host_graph_, func_, gradient, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // The host graph holds its own copy, so the gradient can be freed here.
  TF_DeleteFunction(gradient);

  // The GraphDef is unchanged by the no-op copies.
  GraphDef after;
  GetGraphDef(host_graph_, &after);
  ASSERT_EQ(before.DebugString(), after.DebugString());

  // Call the function; it negates its input.
  TF_Operation* host_feed = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({host_feed});
  Run({{host_feed, Int32Tensor(3)}}, call, -3);
}
1361
TEST_F(CApiFunctionTest, SameGradForTwoFunctions) {
  // One gradient function may be registered as the gradient of several
  // functions.
  TF_Function* foo1;
  TF_Function* foo2;
  TF_Function* shared_grad;
  DefineFunction("FooFunc1", &foo1);
  DefineFunction("FooFunc2", &foo2);
  DefineFunction("MyGrad", &shared_grad);

  // Link shared_grad to both functions in the host graph.
  TF_GraphCopyFunction(host_graph_, foo1, shared_grad, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, foo2, shared_grad, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Both links are recorded in the host GraphDef.
  GraphDef gdef;
  GetGraphDef(host_graph_, &gdef);
  const std::vector<std::pair<string, string>> links = GetGradDefs(gdef);
  ASSERT_EQ(2, links.size());
  ASSERT_EQ("FooFunc1", links[0].first);
  ASSERT_EQ("MyGrad", links[0].second);
  ASSERT_EQ("FooFunc2", links[1].first);
  ASSERT_EQ("MyGrad", links[1].second);

  for (TF_Function* f : {foo1, foo2, shared_grad}) {
    TF_DeleteFunction(f);
  }
}
1391
TEST_F(CApiFunctionTest, AddFunctionsThenMakeOneGradientOfAnother) {
  // Two functions are first added independently and only later linked as
  // function/gradient.
  TF_Function* foo;
  TF_Function* gradient;
  DefineFunction("FooFunc", &foo);
  DefineFunction("MyGrad", &gradient);

  // Add each function on its own, without a gradient.
  TF_GraphCopyFunction(host_graph_, foo, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, gradient, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Both are present, but no gradient link exists yet.
  GraphDef unlinked;
  GetGraphDef(host_graph_, &unlinked);
  const std::vector<string> names = GetFuncNames(unlinked);
  ASSERT_EQ(2, names.size());
  ASSERT_EQ("FooFunc", names[0]);
  ASSERT_EQ("MyGrad", names[1]);
  ASSERT_EQ(0, GetGradDefs(unlinked).size());

  // Copying again with a gradient establishes the link.
  TF_GraphCopyFunction(host_graph_, foo, gradient, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  GraphDef linked;
  GetGraphDef(host_graph_, &linked);
  const std::vector<std::pair<string, string>> links = GetGradDefs(linked);
  ASSERT_EQ(1, links.size());
  ASSERT_EQ("FooFunc", links[0].first);
  ASSERT_EQ("MyGrad", links[0].second);

  TF_DeleteFunction(foo);
  TF_DeleteFunction(gradient);
}
1429
TEST_F(CApiFunctionTest, GradientErrorCases) {
  // The function under test plus two candidate gradients.
  DefineFunction(func_name_, &func_);
  TF_Function* first_grad;
  TF_Function* second_grad;
  DefineFunction("MyGrad1", &first_grad);
  DefineFunction("MyGrad2", &second_grad);

  // A null `func` argument is rejected.
  TF_GraphCopyFunction(host_graph_, nullptr, func_, s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("'func' argument to TF_GraphCopyFunction cannot be null"),
            string(TF_Message(s_)));

  // Once a gradient is assigned, assigning a different one is an error.
  TF_GraphCopyFunction(host_graph_, func_, first_grad, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, func_, second_grad, s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Cannot assign gradient function 'MyGrad2' to 'MyFunc' "
                   "because it already has gradient function 'MyGrad1'"),
            string(TF_Message(s_)));

  TF_DeleteFunction(first_grad);
  TF_DeleteFunction(second_grad);
}
1456
TEST_F(CApiFunctionTest, ImportFunctionDef) {
  // A function with three inputs and two named outputs is serialized to a
  // FunctionDef and imported back (Reincarnate); the imported copy must
  // still run and have the same structure.
  //
  //   feed1  feed2  feed3
  //     |      |      |
  //     v      v      |
  //      add1         |
  //       |  |        |
  //       |  v        v
  //       |     add2
  //       |      |
  //       v      v
  //  internal_out  final_out

  // Define
  TF_Operation* in1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* in2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* in3 = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* partial = Add(in1, in2, func_graph_, s_, "add1");
  TF_Operation* total = Add(partial, in3, func_graph_, s_, "add2");
  Define(-1, {}, {in1, in2, in3}, {partial, total},
         {"internal_out", "final_out"});

  // Round-trip func_ through a FunctionDef.
  Reincarnate();

  // Call with (2, 10, 3): internal_out = 2 + 10, final_out = 2 + 10 + 3.
  TF_Operation* two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* ten = ScalarConst(10, host_graph_, s_, "ten");
  TF_Operation* host_feed = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({two, ten, host_feed});
  Run({{host_feed, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {12, 15});
  VerifyFDef({"add1", "add2"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
             M({{"internal_out"}, {"final_out"}}),
             {{"feed1", "add1:0"},
              {"feed2", "add1:1"},
              {"add1:sum:0", "add2:0"},
              {"feed3", "add2:1"},
              {"add1:sum:0", "internal_out"},
              {"add2:sum:0", "final_out"}},
             {});
}
1501
TEST_F(CApiFunctionTest, ImportFunctionDef_InvalidProto) {
  // Four zero bytes are not a valid serialized protobuf (a protobuf cannot
  // start with 4 bytes of zeros), so the import must fail.
  char zeros[] = {0x0, 0x0, 0x0, 0x0};
  func_ = TF_FunctionImportFunctionDef(zeros, sizeof(zeros), s_);
  EXPECT_EQ(func_, nullptr);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Invalid FunctionDef given to TF_FunctionImportFunctionDef"),
            string(TF_Message(s_)));
}
1511
TEST_F(CApiFunctionTest, Attribute) {
  // The attr calls below need a valid func_. An ASSERT inside DefineFunction
  // only returns from the helper, so stop this test too if it failed.
  ASSERT_NO_FATAL_FAILURE(DefineFunction(func_name_, &func_));

  // Reading an attr that was never set is an INVALID_ARGUMENT error.
  TF_Buffer* attr_buf = TF_NewBuffer();
  TF_FunctionGetAttrValueProto(func_, "foo_attr", attr_buf, s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Function 'MyFunc' has no attr named 'foo_attr'."),
            string(TF_Message(s_)));
  TF_DeleteBuffer(attr_buf);

  // Set a string attr from its serialized AttrValue. Check serialization so a
  // failure is not reported later as a confusing mismatch.
  tensorflow::AttrValue attr;
  attr.set_s("test_attr_value");
  string bytes;
  ASSERT_TRUE(attr.SerializeToString(&bytes));
  TF_FunctionSetAttrValueProto(func_, "test_attr_name", bytes.data(),
                               bytes.size(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Get attr
  AttrValue read_attr;
  GetAttr("test_attr_name", &read_attr);
  ASSERT_EQ(attr.DebugString(), read_attr.DebugString());

  // Retrieve the same attr after save/restore.
  Reincarnate();
  AttrValue read_attr2;
  GetAttr("test_attr_name", &read_attr2);
  ASSERT_EQ(attr.DebugString(), read_attr2.DebugString());
}
1543
TEST_F(CApiFunctionTest, Description) {
  // The description passed to TF_GraphToFunction ends up in the signature.
  // Stop if DefineFunction failed, since func_ is used right away.
  ASSERT_NO_FATAL_FAILURE(
      DefineFunction(func_name_, &func_, "Return something"));
  tensorflow::FunctionDef fdef;
  ASSERT_TRUE(GetFunctionDef(func_, &fdef));
  ASSERT_EQ(string("Return something"), fdef.signature().description());
}
1550
TEST_F(CApiFunctionTest, Name) {
  // Without append_hash, the signature name is exactly the requested name.
  // Stop if DefineFunction failed, since func_ is used right away.
  ASSERT_NO_FATAL_FAILURE(DefineFunction("long_func_name", &func_,
                                         "Return something",
                                         /*append_hash=*/false));
  tensorflow::FunctionDef fdef;
  ASSERT_TRUE(GetFunctionDef(func_, &fdef));
  ASSERT_EQ(string("long_func_name"), fdef.signature().name());
}
1558
TEST_F(CApiFunctionTest, AppendHash) {
  // With append_hash, a hash suffix is added to the requested name. The
  // expected suffix differs between big- and little-endian builds.
  // Stop if DefineFunction failed, since func_ is used right away.
  ASSERT_NO_FATAL_FAILURE(DefineFunction("func_name_base", &func_,
                                         "Return something",
                                         /*append_hash=*/true));
  tensorflow::FunctionDef fdef;
  ASSERT_TRUE(GetFunctionDef(func_, &fdef));
#if (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
  ASSERT_EQ(string("func_name_base_ZpgUD4x8oqk"), fdef.signature().name());
#else
  ASSERT_EQ(string("func_name_base_qaJ8jA8UmGY"), fdef.signature().name());
#endif
}
1570
TEST_F(CApiFunctionTest, GetOpDef) {
  DefineFunction(func_name_, &func_);
  TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Test we can retrieve the function's OpDef from the graph. The buffer is
  // owned by a unique_ptr so it is freed even if an ASSERT returns early.
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> buffer(
      TF_NewBuffer(), TF_DeleteBuffer);
  TF_GraphGetOpDef(host_graph_, func_name_, buffer.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Sanity check the returned OpDef. Fail on unparsable bytes instead of
  // checking fields of an empty message.
  string data(static_cast<const char*>(buffer->data), buffer->length);
  OpDef op_def;
  ASSERT_TRUE(op_def.ParseFromString(data));
  EXPECT_EQ(op_def.name(), func_name_);
  EXPECT_EQ(op_def.input_arg_size(), 1);
  EXPECT_EQ(op_def.output_arg_size(), 1);
  EXPECT_FALSE(op_def.is_stateful());
}
1592
DefineStatefulFunction(const char * name,TF_Function ** func)1593 void DefineStatefulFunction(const char* name, TF_Function** func) {
1594 std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
1595 TF_NewGraph(), TF_DeleteGraph);
1596 std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
1597 TF_DeleteStatus);
1598
1599 TF_Tensor* tensor_shape = Int32Tensor({37, 1});
1600 TF_Operation* shape = Const(tensor_shape, func_graph.get(), s.get(), "shape");
1601 TF_Operation* random =
1602 RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get());
1603
1604 TF_Output outputs[] = {{random, 0}};
1605 *func = TF_GraphToFunction(func_graph.get(), name,
1606 /*append_hash_to_fn_name=*/false, -1,
1607 /*opers=*/nullptr, 0, nullptr, 1, outputs,
1608 /*output_names=*/nullptr,
1609 /*opts=*/nullptr, "", s.get());
1610 ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
1611 ASSERT_NE(*func, nullptr);
1612 TF_DeleteTensor(tensor_shape);
1613 }
1614
TEST_F(CApiFunctionTest, StatefulOpDef) {
  DefineStatefulFunction(func_name_, &func_);
  TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Test we can retrieve the function's OpDef from the graph. The buffer is
  // owned by a unique_ptr so it is freed even if an ASSERT returns early.
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> buffer(
      TF_NewBuffer(), TF_DeleteBuffer);
  TF_GraphGetOpDef(host_graph_, func_name_, buffer.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Sanity check the returned OpDef. Fail on unparsable bytes instead of
  // checking fields of an empty message.
  string data(static_cast<const char*>(buffer->data), buffer->length);
  OpDef op_def;
  ASSERT_TRUE(op_def.ParseFromString(data));
  EXPECT_EQ(op_def.name(), func_name_);
  EXPECT_EQ(op_def.input_arg_size(), 0);
  EXPECT_EQ(op_def.output_arg_size(), 1);
  EXPECT_TRUE(op_def.is_stateful());
}
1636
AssertEqual(TF_Function * f1,TF_Function * f2)1637 void AssertEqual(TF_Function* f1, TF_Function* f2) {
1638 string s1, s2;
1639 tensorflow::FunctionDef fdef1, fdef2;
1640 ASSERT_TRUE(GetFunctionDef(f1, &fdef1));
1641 ASSERT_TRUE(GetFunctionDef(f2, &fdef2));
1642 SerializeToStringDeterministic(fdef1, &s1);
1643 SerializeToStringDeterministic(fdef2, &s2);
1644 ASSERT_EQ(s1, s2);
1645 }
1646
GetName(TF_Function * func)1647 string GetName(TF_Function* func) {
1648 tensorflow::FunctionDef fdef;
1649 GetFunctionDef(func, &fdef);
1650 return fdef.signature().name();
1651 }
1652
TEST_F(CApiFunctionTest, GetFunctionsFromGraph) {
  TF_Function* fetched[2];

  // An empty graph has no functions, and asking for zero of them succeeds.
  EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 0);
  TF_GraphGetFunctions(host_graph_, nullptr, 0, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Add a first function to the host graph.
  TF_Function* foo0;
  DefineFunction("FooFunc0", &foo0);
  TF_GraphCopyFunction(host_graph_, foo0, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // The returned count is bounded by both the capacity and the number of
  // functions: capacity 0 -> 0, capacity 1 or 2 -> 1.
  EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 1);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, fetched, 0, s_), 0);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, fetched, 1, s_), 1);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  AssertEqual(foo0, fetched[0]);
  TF_DeleteFunction(fetched[0]);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, fetched, 2, s_), 1);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  AssertEqual(foo0, fetched[0]);
  TF_DeleteFunction(fetched[0]);

  // Add a second function.
  TF_Function* foo1;
  DefineFunction("FooFunc1", &foo1);
  TF_GraphCopyFunction(host_graph_, foo1, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Fetch both functions.
  EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 2);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, fetched, 0, s_), 0);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, fetched, 2, s_), 2);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Either order is accepted, so match the fetched copies by name.
  const bool same_order = GetName(fetched[0]) == GetName(foo0);
  TF_Function* match0 = same_order ? fetched[0] : fetched[1];
  TF_Function* match1 = same_order ? fetched[1] : fetched[0];
  AssertEqual(foo0, match0);
  AssertEqual(foo1, match1);

  TF_DeleteFunction(fetched[0]);
  TF_DeleteFunction(fetched[1]);

  TF_DeleteFunction(foo0);
  TF_DeleteFunction(foo1);
}
1706
1707 } // namespace
1708 } // namespace tensorflow
1709