xref: /aosp_15_r20/external/tensorflow/tensorflow/c/c_api_function_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/c/c_api.h"
17 #include "tensorflow/c/c_api_internal.h"
18 #include "tensorflow/c/c_test_util.h"
19 #include "tensorflow/core/framework/common_shape_fns.h"
20 #include "tensorflow/core/framework/function.pb.h"
21 #include "tensorflow/core/framework/op_def.pb.h"
22 #include "tensorflow/core/lib/hash/hash.h"
23 #include "tensorflow/core/lib/strings/proto_serialization.h"
24 #include "tensorflow/core/platform/logging.h"
25 #include "tensorflow/core/platform/status.h"
26 #include "tensorflow/core/platform/str_util.h"
27 #include "tensorflow/core/platform/strcat.h"
28 #include "tensorflow/core/platform/test.h"
29 
30 namespace tensorflow {
31 namespace {
32 
33 // Specification for expected input/output and its type.
34 // DataType value of DT_INVALID signifies that we don't want to
35 // check the data type.
36 typedef std::pair<string, DataType> IOSpec;
37 
M(const std::initializer_list<string> & names)38 std::vector<IOSpec> M(const std::initializer_list<string>& names) {
39   std::vector<IOSpec> v;
40   for (const string& name : names) {
41     v.push_back(IOSpec(name, DT_INVALID));
42   }
43   return v;
44 }
45 
46 // Specification for an expected edge.
47 // src is either:
48 // - input name (as it appears in FunctionDef)
49 // - name of output tensor (in nested "add:z:0" format)
50 // dst is either:
51 // - output name (as it appears in FunctionDef)
52 // - <name_of_node>:<index_of_this_input_into_node> (this looks the same as
53 //      output tensor naming, but it the index is actually an input index)
54 struct EdgeSpec : public std::pair<string, string> {
55   typedef std::pair<string, string> Base;
56 
57   // Inherit the set of constructors
58   using Base::pair;
59 
ToStringtensorflow::__anon10e0d3b10111::EdgeSpec60   string ToString() const { return strings::StrCat(first, "->", second); }
61 };
62 
63 class CApiFunctionTest : public ::testing::Test {
64  protected:
CApiFunctionTest()65   CApiFunctionTest()
66       : s_(TF_NewStatus()),
67         func_graph_(TF_NewGraph()),
68         host_graph_(TF_NewGraph()),
69         func_(nullptr) {}
70 
SetUp()71   void SetUp() override {}
72 
~CApiFunctionTest()73   ~CApiFunctionTest() override {
74     TF_DeleteFunction(func_);
75     TF_DeleteGraph(host_graph_);
76     TF_DeleteGraph(func_graph_);
77     TF_DeleteStatus(s_);
78   }
79 
Run(const std::vector<std::pair<TF_Operation *,TF_Tensor * >> & inputs,TF_Operation * output,int32_t expected_result)80   void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
81            TF_Operation* output, int32_t expected_result) {
82     Run(inputs, {{output, 0}}, {expected_result});
83   }
84 
85   // Run the host graph, which now contains a function and check that
86   // outputs are as expected.
87   // 'T' stands for 'tensor' since the outputs are tensors, not scalars.
RunT(const std::vector<std::pair<TF_Operation *,TF_Tensor * >> & inputs,std::initializer_list<TF_Output> outputs,const std::vector<std::vector<int32_t>> & expected_results)88   void RunT(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
89             std::initializer_list<TF_Output> outputs,
90             const std::vector<std::vector<int32_t>>& expected_results) {
91     // Create a session for this graph
92     CSession csession(host_graph_, s_);
93     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
94 
95     // Run
96     csession.SetInputs(inputs);
97     csession.SetOutputs(outputs);
98     csession.Run(s_);
99     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
100 
101     // Check results
102     for (int i = 0; i < expected_results.size(); ++i) {
103       TF_Tensor* out = csession.output_tensor(i);
104       ASSERT_TRUE(out != nullptr);
105       EXPECT_EQ(TF_INT32, TF_TensorType(out));
106       EXPECT_EQ(1, TF_NumDims(out));
107       CompareInt32Tensor(expected_results[i], out);
108     }
109   }
110 
111   // Run the host graph, which now contains a function and check that
112   // outputs are as expected.
Run(const std::vector<std::pair<TF_Operation *,TF_Tensor * >> & inputs,std::initializer_list<TF_Output> outputs,const std::vector<int32_t> & expected_results)113   void Run(const std::vector<std::pair<TF_Operation*, TF_Tensor*>>& inputs,
114            std::initializer_list<TF_Output> outputs,
115            const std::vector<int32_t>& expected_results) {
116     // Create a session for this graph.
117     CSession csession(host_graph_, s_);
118     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
119 
120     csession.SetInputs(inputs);
121     csession.SetOutputs(outputs);
122     csession.Run(s_);
123     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
124 
125     for (int i = 0; i < expected_results.size(); ++i) {
126       TF_Tensor* out = csession.output_tensor(i);
127       ASSERT_TRUE(out != nullptr);
128       EXPECT_EQ(TF_INT32, TF_TensorType(out));
129       EXPECT_EQ(0, TF_NumDims(out));  // scalar
130       ASSERT_EQ(sizeof(int32_t), TF_TensorByteSize(out));
131       int32_t* output_contents = static_cast<int32_t*>(TF_TensorData(out));
132       EXPECT_EQ(expected_results[i], *output_contents);
133     }
134   }
135 
CompareInt32Tensor(const std::vector<int32_t> & expected,TF_Tensor * t)136   void CompareInt32Tensor(const std::vector<int32_t>& expected, TF_Tensor* t) {
137     int32_t* data = static_cast<int32_t*>(TF_TensorData(t));
138     size_t size = TF_TensorByteSize(t);
139     ASSERT_EQ(expected.size() * sizeof(int32_t), size);
140     for (int i = 0; i < expected.size(); ++i) {
141       ASSERT_EQ(expected[i], data[i]) << "Different data at index " << i;
142     }
143   }
144 
ToOutput(const std::vector<TF_Operation * > ops)145   std::vector<TF_Output> ToOutput(const std::vector<TF_Operation*> ops) {
146     std::vector<TF_Output> out;
147     for (auto op : ops) {
148       out.push_back({op, 0});
149     }
150     return out;
151   }
152 
Define(int num_opers,const std::vector<TF_Operation * > & opers,const std::vector<TF_Operation * > & inputs,const std::vector<TF_Operation * > & outputs,const std::vector<string> & output_names,bool expect_failure=false)153   void Define(int num_opers, const std::vector<TF_Operation*>& opers,
154               const std::vector<TF_Operation*>& inputs,
155               const std::vector<TF_Operation*>& outputs,
156               const std::vector<string>& output_names,
157               bool expect_failure = false) {
158     DefineT(num_opers, opers, ToOutput(inputs), ToOutput(outputs), output_names,
159             expect_failure);
160   }
161 
162   // Caller must delete[] the returned value
ToArray(const std::vector<string> & strs)163   static const char** ToArray(const std::vector<string>& strs) {
164     const char** ptr = nullptr;
165     if (!strs.empty()) {
166       ptr = new const char*[strs.size()];
167       for (size_t i = 0; i < strs.size(); ++i) {
168         ptr[i] = strs[i].c_str();
169       }
170     }
171     return ptr;
172   }
173 
174   // An explicit `num_opers` is needed so that we can distinguish between the
175   // case of no operations specified (-1) and the case of an empty set of
176   // operations specified (0).
DefineT(int num_opers,const std::vector<TF_Operation * > & opers,const std::vector<TF_Output> & inputs,const std::vector<TF_Output> & outputs,const std::vector<string> & output_names,bool expect_failure=false)177   void DefineT(int num_opers, const std::vector<TF_Operation*>& opers,
178                const std::vector<TF_Output>& inputs,
179                const std::vector<TF_Output>& outputs,
180                const std::vector<string>& output_names,
181                bool expect_failure = false) {
182     ASSERT_EQ(func_, nullptr);
183     const char** output_names_ptr = ToArray(output_names);
184     func_ = TF_GraphToFunction(func_graph_, func_name_, false, num_opers,
185                                num_opers == -1 ? nullptr : opers.data(),
186                                inputs.size(), inputs.data(), outputs.size(),
187                                outputs.data(), output_names_ptr,
188                                /*opts=*/nullptr, /*description=*/nullptr, s_);
189     delete[] output_names_ptr;
190     if (expect_failure) {
191       ASSERT_EQ(func_, nullptr);
192       return;
193     }
194 
195     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
196     ASSERT_NE(func_, nullptr);
197     ASSERT_EQ(std::string(func_name_), std::string(TF_FunctionName(func_)));
198     TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
199     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
200   }
201 
Use(const std::vector<TF_Operation * > & inputs)202   TF_Operation* Use(const std::vector<TF_Operation*>& inputs) {
203     return UseT(ToOutput(inputs));
204   }
205 
UseT(const std::vector<TF_Output> & inputs)206   TF_Operation* UseT(const std::vector<TF_Output>& inputs) {
207     TF_Operation* op;
208     UseHelper(inputs, &op);
209     return op;
210   }
211 
212   // All the *Helper methods are used as a workaround for the restrictions that
213   // one cannot call ASSERT_* methods in non-void-returning functions (when
214   // exceptions are disabled during compilation)
UseHelper(const std::vector<TF_Output> & inputs,TF_Operation ** op)215   void UseHelper(const std::vector<TF_Output>& inputs, TF_Operation** op) {
216     TF_OperationDescription* desc =
217         TF_NewOperation(host_graph_, func_name_, func_node_name_);
218     for (auto input : inputs) {
219       TF_AddInput(desc, input);
220     }
221     // Set device to CPU because some ops inside the function might not be
222     // available on GPU.
223     TF_SetDevice(desc, "/cpu:0");
224     *op = TF_FinishOperation(desc, s_);
225     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
226     ASSERT_NE(*op, nullptr);
227   }
228 
fdef()229   FunctionDef fdef() {
230     tensorflow::FunctionDef fdef;
231     EXPECT_TRUE(GetFunctionDef(func_, &fdef));
232     return fdef;
233   }
234 
235   // logging utility
236   template <class Container>
ToString(const Container & v)237   string ToString(const Container& v) {
238     std::stringstream ss;
239     ss << "{";
240     size_t i = 0;
241     for (const auto& e : v) {
242       if (i != 0) {
243         ss << ", ";
244       }
245       ss << e.ToString();
246       ++i;
247     }
248     ss << "}";
249     return ss.str();
250   }
251 
VerifyFDefNodes(const tensorflow::FunctionDef & fdef,const std::unordered_set<string> & nodes)252   void VerifyFDefNodes(const tensorflow::FunctionDef& fdef,
253                        const std::unordered_set<string>& nodes) {
254     ASSERT_EQ(nodes.size(), fdef.node_def_size())
255         << "Got unexpected number of nodes. Expected: ["
256         << absl::StrJoin(nodes, ", ")
257         << "] Actual nodes in fdef: " << fdef.DebugString();
258     for (const NodeDef& node_def : fdef.node_def()) {
259       ASSERT_TRUE(nodes.find(node_def.name()) != nodes.end())
260           << "Got unexpected node: " << node_def.name()
261           << " in fdef: " << fdef.DebugString();
262     }
263   }
264 
VerifyFDefInputs(const tensorflow::FunctionDef & fdef,const std::vector<IOSpec> & inputs)265   void VerifyFDefInputs(const tensorflow::FunctionDef& fdef,
266                         const std::vector<IOSpec>& inputs) {
267     const OpDef& signature = fdef.signature();
268     ASSERT_EQ(inputs.size(), signature.input_arg_size());
269     for (int i = 0; i < inputs.size(); ++i) {
270       const OpDef::ArgDef& arg = signature.input_arg(i);
271       const IOSpec& in = inputs[i];
272       if (in.second != DT_INVALID) {
273         ASSERT_EQ(arg.type(), in.second)
274             << "Got unexpected type for input " << i
275             << ". fdef: " << fdef.DebugString();
276       }
277       ASSERT_EQ(arg.name(), in.first) << "Got unexpected name for input " << i
278                                       << ". fdef: " << fdef.DebugString();
279     }
280   }
281 
VerifyFDefOutputs(const tensorflow::FunctionDef & fdef,const std::vector<IOSpec> & outputs)282   void VerifyFDefOutputs(const tensorflow::FunctionDef& fdef,
283                          const std::vector<IOSpec>& outputs) {
284     const OpDef& signature = fdef.signature();
285     ASSERT_EQ(outputs.size(), signature.output_arg_size());
286     for (int i = 0; i < outputs.size(); ++i) {
287       const OpDef::ArgDef& arg = signature.output_arg(i);
288       const IOSpec& out = outputs[i];
289       if (out.second != DT_INVALID) {
290         ASSERT_EQ(arg.type(), out.second)
291             << "Got unexpected type for output " << i
292             << ". fdef: " << fdef.DebugString();
293       }
294       ASSERT_EQ(arg.name(), out.first) << "Got unexpected name for output " << i
295                                        << ". fdef: " << fdef.DebugString();
296     }
297   }
298 
VerifyFDefEdges(const tensorflow::FunctionDef & fdef,const std::vector<EdgeSpec> & e_edges,const std::vector<EdgeSpec> & c_edges,bool is_exact_edges=true)299   void VerifyFDefEdges(
300       const tensorflow::FunctionDef& fdef,
301       const std::vector<EdgeSpec>& e_edges,  // expected edges
302       const std::vector<EdgeSpec>& c_edges,  // expected ctrl edges
303       bool is_exact_edges = true) {
304     // Build a set of edges from fdef
305     std::set<EdgeSpec> a_edges;  // actual edges
306     // Get edges from inputs to body nodes and between body nodes
307     for (const NodeDef& node_def : fdef.node_def()) {
308       for (int i = 0; i < node_def.input_size(); ++i) {
309         const string& in = node_def.input(i);
310         const auto& v =
311             a_edges.insert({in, strings::StrCat(node_def.name(), ":", i)});
312         ASSERT_TRUE(v.second) << "Duplicate edge " << in << " -> "
313                               << strings::StrCat(node_def.name(), ":", i)
314                               << ". fdef: " << fdef.DebugString();
315       }
316     }
317     // Get edges from body nodes to outputs and from inputs to outputs
318     for (const OpDef::ArgDef& arg : fdef.signature().output_arg()) {
319       const auto& iter = fdef.ret().find(arg.name());
320       if (iter != fdef.ret().end()) {
321         const auto& v = a_edges.insert({iter->second, arg.name()});
322         ASSERT_TRUE(v.second) << "Duplicate edge " << iter->second << " -> "
323                               << arg.name() << ". fdef: " << fdef.DebugString();
324       } else {
325         const auto& v = a_edges.insert({arg.name(), arg.name()});
326         ASSERT_TRUE(v.second) << "Duplicate edge " << arg.name() << " -> "
327                               << arg.name() << ". fdef: " << fdef.DebugString();
328       }
329     }
330 
331     // Verify edges
332     for (const EdgeSpec& e : e_edges) {
333       ASSERT_TRUE(a_edges.find(e) != a_edges.end())
334           << "Failed to find expected edge " << e.ToString()
335           << " in fdef: " << fdef.DebugString();
336     }
337     for (const EdgeSpec& e : c_edges) {
338       ASSERT_TRUE(a_edges.find(e) != a_edges.end())
339           << "Failed to find expected control edge " << e.ToString()
340           << " in fdef: " << fdef.DebugString();
341     }
342 
343     // If caller specified all edges, check that we have seen all
344     if (is_exact_edges) {
345       ASSERT_EQ(e_edges.size() + c_edges.size(), a_edges.size())
346           << "Expected edges: " << ToString(e_edges)
347           << " Expected Control edges: " << ToString(c_edges)
348           << " Actual edges: " << ToString(a_edges)
349           << " in fdef: " << fdef.DebugString();
350     }
351   }
352 
VerifyFDef(const std::unordered_set<string> & nodes,const std::vector<IOSpec> & inputs,const std::vector<IOSpec> & outputs,const std::vector<EdgeSpec> & e_edges,const std::vector<EdgeSpec> & c_edges,bool is_exact_edges=true)353   void VerifyFDef(const std::unordered_set<string>& nodes,
354                   const std::vector<IOSpec>& inputs,
355                   const std::vector<IOSpec>& outputs,
356                   const std::vector<EdgeSpec>& e_edges,  // expected edges
357                   const std::vector<EdgeSpec>& c_edges,  // expected ctrl edges
358                   bool is_exact_edges = true) {
359     tensorflow::FunctionDef fdef;
360     ASSERT_TRUE(GetFunctionDef(func_, &fdef));
361     VerifyFDefNodes(fdef, nodes);
362     VerifyFDefInputs(fdef, inputs);
363     VerifyFDefOutputs(fdef, outputs);
364     VerifyFDefEdges(fdef, e_edges, c_edges, is_exact_edges);
365   }
366 
367   // Serialize func_ to fdef and import it back
Reincarnate()368   void Reincarnate() {
369     // func_ -> fdef
370     tensorflow::FunctionDef fdef;
371     ASSERT_TRUE(GetFunctionDef(func_, &fdef));
372     TF_DeleteFunction(func_);
373 
374     // fdef -> func_
375     string buf;
376     ASSERT_TRUE(fdef.SerializeToString(&buf));
377     func_ = TF_FunctionImportFunctionDef(buf.data(), buf.size(), s_);
378     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
379   }
380 
GetAttr(const char * attr_name,AttrValue * out_attr)381   void GetAttr(const char* attr_name, AttrValue* out_attr) {
382     TF_Buffer* attr_buf = TF_NewBuffer();
383     TF_FunctionGetAttrValueProto(func_, attr_name, attr_buf, s_);
384     ASSERT_TRUE(out_attr->ParseFromArray(attr_buf->data, attr_buf->length));
385     TF_DeleteBuffer(attr_buf);
386   }
387 
388   const char* func_name_ = "MyFunc";
389   const char* func_node_name_ = "MyFunc_0";
390   TF_Status* s_;
391   TF_Graph* func_graph_;
392   TF_Graph* host_graph_;
393   TF_Function* func_;
394 
395   // Workaround for not being able to initialize empty map using {}
396   std::unordered_set<string> empty_;
397 };
398 
TEST_F(CApiFunctionTest, OneOp_ZeroInputs_OneOutput) {
  /*
   *                constant
   *                   |
   *                   v
   */
  // Function body: a single constant that is also the only output.
  TF_Operation* const ten = ScalarConst(10, func_graph_, s_, "scalar10");
  Define(-1, {}, {}, {ten}, {});

  // Call it (no arguments) from the host graph and check the value.
  TF_Operation* const call = Use({});
  Run({}, call, 10);

  // Check the structure of the generated FunctionDef.
  VerifyFDef({"scalar10_0"}, {}, {{"scalar10", DT_INT32}},
             {{"scalar10_0:output:0", "scalar10"}}, {});
}
415 
TEST_F(CApiFunctionTest, OneOp_OneInput_OneOutput) {
  /*
   *                   |
   *                   v
   *                 negate
   *                   |
   *                   v
   */
  // Function body: negate the single input.
  TF_Operation* const arg = Placeholder(func_graph_, s_);
  TF_Operation* const negated = Neg(arg, func_graph_, s_);
  Define(-1, {}, {arg}, {negated}, {});

  // Feed 3 into the call and expect -3 back.
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, -3);

  // Check the structure of the generated FunctionDef.
  VerifyFDef({"neg_0"}, {{"feed", DT_INT32}}, {{"neg", DT_INT32}},
             {{"feed", "neg_0:0"}, {"neg_0:y:0", "neg"}}, {});
}
436 
TEST_F(CApiFunctionTest, OneOutput_OutputNames) {
  /*
   *                   |
   *                   v
   *                 negate
   *                   |
   *                   v
   */
  // Function body: negate the input, with an explicit output name.
  TF_Operation* const arg = Placeholder(func_graph_, s_);
  TF_Operation* const negated = Neg(arg, func_graph_, s_);
  Define(-1, {}, {arg}, {negated}, {"negated_num"});

  // Feed 3 into the call and expect -3 back.
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, -3);

  // The output arg must carry the caller-supplied name.
  VerifyFDef({"neg"}, {{"feed", DT_INT32}}, {{"negated_num", DT_INT32}},
             {{"feed", "neg:0"}, {"neg:y:0", "negated_num"}}, {});
}
457 
TEST_F(CApiFunctionTest, OutputNames_SameNameAsInput) {
  /*
   *                   |
   *                   v
   *                 negate
   *                   |
   *                   v
   */
  // The requested output name collides with the input placeholder's name.
  TF_Operation* const arg = Placeholder(func_graph_, s_, "negation");
  TF_Operation* const negated = Neg(arg, func_graph_, s_, "neg");
  Define(-1, {}, {arg}, {negated}, {"negation"});

  // Feed 3 into the call and expect -3 back.
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, -3);

  // The output keeps the requested name; the input is expected as
  // "negation_0".
  VerifyFDef({"neg"}, {{"negation_0", DT_INT32}}, {{"negation", DT_INT32}},
             {{"negation_0", "neg:0"}, {"neg:y:0", "negation"}}, {});
}
478 
TEST_F(CApiFunctionTest, ZeroOps_Identity) {
  /*
   *                   |
   *                   |
   *                   |
   *                   v
   */
  // The same placeholder is both the function input and its output.
  TF_Operation* const passthrough = Placeholder(func_graph_, s_);
  Define(-1, {}, {passthrough}, {passthrough}, {});

  // The call must return its argument unchanged.
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, 3);

  // No body nodes are expected; the input is wired straight to the output.
  VerifyFDef(empty_, {{"feed_0", DT_INT32}}, {{"feed", DT_INT32}},
             {{"feed_0", "feed"}}, {});
}
497 
TEST_F(CApiFunctionTest, ZeroOps_Permutation) {
  /*
   *                   |   |
   *                   \  /
   *                    \/
   *                    x
   *                   /\
   *                  /  \
   *                 |   |
   *                 v   v
   */
  // Two inputs returned in swapped order, with no body operations.
  TF_Operation* const arg_a = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const arg_b = Placeholder(func_graph_, s_, "feed2");
  Define(-1, {}, {arg_a, arg_b}, {arg_b, arg_a}, {});

  // Call with (2, 3) and expect (3, 2).
  TF_Operation* const two = ScalarConst(2, host_graph_, s_);
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {3, 2});

  // Check the structure of the generated FunctionDef.
  VerifyFDef(empty_, M({{"feed1_0"}, {"feed2_0"}}), M({{"feed2"}, {"feed1"}}),
             {{"feed1_0", "feed1"}, {"feed2_0", "feed2"}}, {});
}
522 
TEST_F(CApiFunctionTest, ZeroOps_Permutation_OutputNames) {
  /*
   *                   |   |
   *                   \  /
   *                    \/
   *                    x
   *                   /\
   *                  /  \
   *                 |   |
   *                 v   v
   */
  // Swapped pass-through outputs, this time with explicit output names.
  TF_Operation* const arg_a = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const arg_b = Placeholder(func_graph_, s_, "feed2");
  Define(-1, {}, {arg_a, arg_b}, {arg_b, arg_a}, {"first", "second"});

  // Call with (2, 3) and expect (3, 2).
  TF_Operation* const two = ScalarConst(2, host_graph_, s_);
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {3, 2});

  // Check the structure of the generated FunctionDef.
  VerifyFDef(empty_, M({{"feed1"}, {"feed2"}}), M({{"first"}, {"second"}}),
             {{"feed1", "second"}, {"feed2", "first"}}, {});
}
547 
TEST_F(CApiFunctionTest, OneOp_TwoInputs_OneOutput) {
  /*
   *                  |  |
   *                  v  v
   *                  add
   *                   |
   *                   v
   */
  // Function body: a single Add of the two inputs.
  TF_Operation* const lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const sum = Add(lhs, rhs, func_graph_, s_);
  Define(-1, {}, {lhs, rhs}, {sum}, {});

  // Call with (2, 3) and expect their sum.
  TF_Operation* const two = ScalarConst(2, host_graph_, s_);
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, 2 + 3);

  // Check the structure of the generated FunctionDef.
  VerifyFDef(
      {"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
      {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}}, {});
}
571 
TEST_F(CApiFunctionTest, OneOp_TwoInputs_ZeroOutputs) {
  /*
   *                  |  |
   *                  v  v
   *                  add
   *
   *            (output ignored)
   */
  // Function body: an Add whose result is not returned.
  TF_Operation* const lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const rhs = Placeholder(func_graph_, s_, "feed2");
  Add(lhs, rhs, func_graph_, s_);
  Define(-1, {}, {lhs, rhs}, {}, {});

  // The function can be instantiated even though it produces nothing.
  TF_Operation* const two = ScalarConst(2, host_graph_, s_);
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  Use({two, host_arg});

  // Check the structure of the generated FunctionDef.
  VerifyFDef({"add"}, M({{"feed1"}, {"feed2"}}), {},
             {{"feed1", "add:0"}, {"feed2", "add:1"}}, {});
}
593 
TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_OneOutput) {
  /*
   *                  |  |   |
   *                  v  v   /
   *                  add1  /
   *                   |   |
   *                   v   v
   *                   add2
   *                    |
   *                    v
   */
  // Function body: (feed1 + feed2) + feed3.
  TF_Operation* const first = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const second = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const third = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* const partial = Add(first, second, func_graph_, s_, "add1");
  TF_Operation* const total = Add(partial, third, func_graph_, s_, "add2");
  Define(-1, {}, {first, second, third}, {total}, {});

  // Call with (2, 10, 3) and expect the sum of all three.
  TF_Operation* const two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* const ten = ScalarConst(10, host_graph_, s_, "ten");
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({two, ten, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, 2 + 10 + 3);

  // Check the structure of the generated FunctionDef.
  VerifyFDef({"add1", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
             M({{"add2"}}),
             {{"feed1", "add1:0"},
              {"feed2", "add1:1"},
              {"add1:sum:0", "add2_0:0"},
              {"feed3", "add2_0:1"},
              {"add2_0:sum:0", "add2"}},
             {});
}
628 
TEST_F(CApiFunctionTest, OneOp_TwoInputs_TwoDuplicateOutputs) {
  /*
   *                  |  |
   *                  v  v
   *                  add
   *                   |
   *                 +-+-+
   *                 |   |
   *                 v   v
   */
  // The same Add result is returned twice.
  TF_Operation* const lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const sum = Add(lhs, rhs, func_graph_, s_);
  Define(-1, {}, {lhs, rhs}, {sum, sum}, {});

  // Both outputs must carry 2 + 3.
  TF_Operation* const two = ScalarConst(2, host_graph_, s_);
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {5, 5});

  // Check the structure of the generated FunctionDef.
  VerifyFDef({"add_1"}, M({{"feed1"}, {"feed2"}}), M({{"add"}, {"add_0"}}),
             {{"feed1", "add_1:0"},
              {"feed2", "add_1:1"},
              {"add_1:sum:0", "add"},
              {"add_1:sum:0", "add_0"}},
             {});
}
657 
TEST_F(CApiFunctionTest, TwoDuplicateOutputs_OutputNames) {
  /*
   *                  |  |
   *                  v  v
   *                  add
   *                   |
   *                 +-+-+
   *                 |   |
   *                 v   v
   */
  // The same Add result is returned twice under explicit names.
  TF_Operation* const lhs = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const rhs = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const sum = Add(lhs, rhs, func_graph_, s_);
  Define(-1, {}, {lhs, rhs}, {sum, sum}, {"out1", "out2"});

  // Both outputs must carry 2 + 3.
  TF_Operation* const two = ScalarConst(2, host_graph_, s_);
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {5, 5});

  // Check the structure of the generated FunctionDef.
  VerifyFDef({"add"}, M({{"feed1"}, {"feed2"}}), M({{"out1"}, {"out2"}}),
             {{"feed1", "add:0"},
              {"feed2", "add:1"},
              {"add:sum:0", "out1"},
              {"add:sum:0", "out2"}},
             {});
}
686 
TEST_F(CApiFunctionTest, TwoOps_ThreeInputs_TwoOutputs) {
  /*
   *                  |  |  |
   *                  v  v  /
   *                  add  /
   *                   |  |
   *                 +-+  |
   *                 | |  |
   *                 | v  v
   *                 | add
   *                 |  |
   *                 v  v
   */
  // Both the intermediate sum and the final sum are returned.
  TF_Operation* const first = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const second = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const third = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* const partial = Add(first, second, func_graph_, s_, "add1");
  TF_Operation* const total = Add(partial, third, func_graph_, s_, "add2");
  Define(-1, {}, {first, second, third}, {partial, total}, {});

  // Call with (2, 10, 3): expect 2 + 10 and 2 + 10 + 3.
  TF_Operation* const two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* const ten = ScalarConst(10, host_graph_, s_, "ten");
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({two, ten, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {12, 15});

  // Check the structure of the generated FunctionDef.
  VerifyFDef({"add1_0", "add2_0"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
             M({{"add1"}, {"add2"}}),
             {{"feed1", "add1_0:0"},
              {"feed2", "add1_0:1"},
              {"add1_0:sum:0", "add2_0:0"},
              {"feed3", "add2_0:1"},
              {"add1_0:sum:0", "add1"},
              {"add2_0:sum:0", "add2"}},
             {});
}
724 
TEST_F(CApiFunctionTest, FromSubsetOfOps) {
  /*
   *                  |  |  |
   *                  v  v  /
   *                  add  /
   *                   |  |
   *               +---+--+---+
   *  Ops used     |   |  |   |
   *  for func     |   v  v   |
   *     |         |   add    |
   *     +-------> |    |     |
   *               |    v     |
   *               |          |
   *               +----------+
   */
  // Only add2 becomes the function body. add1 and feed3 act as its inputs.
  TF_Operation* const first = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const second = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const third = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* const partial = Add(first, second, func_graph_, s_, "add1");
  TF_Operation* const total = Add(partial, third, func_graph_, s_, "add2");
  Define(1, {total}, {partial, third}, {total}, {});

  // Call with (2, 3) and expect their sum.
  TF_Operation* const two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* const host_arg = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({two, host_arg});
  Run({{host_arg, Int32Tensor(3)}}, call, 2 + 3);

  // Check the structure of the generated FunctionDef.
  VerifyFDef(
      {"add2_0"}, M({{"add1"}, {"feed3"}}), M({{"add2"}}),
      {{"add1", "add2_0:0"}, {"feed3", "add2_0:1"}, {"add2_0:sum:0", "add2"}},
      {});
}
758 
TEST_F(CApiFunctionTest, UsingOneOutputOfSplit) {
  // Body: the whole graph, i.e. a Split3 of the single input. The resulting
  // FunctionDef also holds a "split3_const0" node feeding split's input 0.
  // Only the middle slice (split output 1) is exported as function output.
  TF_Operation* const input = Placeholder(func_graph_, s_);
  TF_Operation* const splitter = Split3(input, func_graph_, s_);
  DefineT(-1, {}, {{input, 0}}, {{splitter, 1}}, {});

  // Splitting {1, ..., 6} three ways leaves {3, 4} as the middle slice.
  TF_Operation* const host_feed = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({host_feed});
  RunT({{host_feed, Int32Tensor({1, 2, 3, 4, 5, 6})}}, {{call, 0}},
       {{3, 4}});
  VerifyFDef({"split3_const0", "split3_0"}, M({{"feed"}}), M({{"split3"}}),
             {{"split3_const0:output:0", "split3_0:0"},
              {"feed", "split3_0:1"},
              {"split3_0:output:1", "split3"}},
             {});
}
793 
TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplit) {
  // Body: the whole graph, i.e. a Split3 of the single input. Slices 0 and 2
  // are the function outputs; the middle slice (output 1) is not exported.
  TF_Operation* const input = Placeholder(func_graph_, s_);
  TF_Operation* const splitter = Split3(input, func_graph_, s_);
  DefineT(-1, {}, {{input, 0}}, {{splitter, 0}, {splitter, 2}}, {});

  // {1, ..., 6} split three ways gives {1, 2}, {3, 4}, {5, 6}; the function
  // returns the first and the last of these.
  TF_Operation* const host_feed = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({host_feed});
  RunT({{host_feed, Int32Tensor({1, 2, 3, 4, 5, 6})}},
       {{call, 0}, {call, 1}}, {{1, 2}, {5, 6}});
  VerifyFDef({"split3_const0", "split3_1"}, M({{"feed"}}),
             M({{"split3"}, {"split3_0"}}),
             {{"split3_const0:output:0", "split3_1:0"},
              {"feed", "split3_1:1"},
              {"split3_1:output:0", "split3"},
              {"split3_1:output:2", "split3_0"}},
             {});
}
830 
TEST_F(CApiFunctionTest, UsingTwoOutputsOfSplitAsInputs) {
  // Split3 stays outside the function: only `add` forms the body
  // (num_opers = 1). Split outputs 0 and 2 become the function's two inputs,
  // which the FunctionDef names "split3" and "split3_0".
  TF_Operation* const input = Placeholder(func_graph_, s_);
  TF_Operation* const splitter = Split3(input, func_graph_, s_);
  TF_Operation* const sum = Add({splitter, 0}, {splitter, 2}, func_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  DefineT(1, {sum}, {{splitter, 0}, {splitter, 2}}, {{sum, 0}}, {});

  // Call it with (2, 3); the body just adds its two inputs.
  TF_Operation* const c2 = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* const host_feed = Placeholder(host_graph_, s_);
  TF_Operation* const call = Use({c2, host_feed});
  Run({{host_feed, Int32Tensor(3)}}, call, 2 + 3);
  VerifyFDef(
      {"add_0"}, M({{"split3"}, {"split3_0"}}), M({{"add"}}),
      {{"split3", "add_0:0"}, {"split3_0", "add_0:1"}, {"add_0:sum:0", "add"}},
      {});
}
866 
TEST_F(CApiFunctionTest, NodesUsedInInputsMustHaveSingleOutput) {
  /*
   *                    |
   *                    v
   *                  split
   *                 |  |  |
   *                 |  v  |
   *                 |     |
   *       input --->|     |<--- input
   *                 |     |
   *                 v     v
   *                   add
   *                    |
   *                    |
   *                    v
   */
  // Define
  // This test owns the constant's tensor. Holding it in a unique_ptr releases
  // it at scope exit, including when one of the ASSERT_* checks below returns
  // from the test early (the previous manual TF_DeleteTensor at the end of
  // the body was skipped in that case).
  std::unique_ptr<TF_Tensor, decltype(&TF_DeleteTensor)> tensor_123(
      Int32Tensor({1, 2, 3}), TF_DeleteTensor);
  TF_Operation* c = Const(tensor_123.get(), func_graph_, s_, "const_array");
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_Operation* split = Split3(c, func_graph_, s_);
  TF_Operation* add = Add({split, 0}, {split, 2}, func_graph_, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // With num_opers == -1, inputs must name single-output nodes. split3 has
  // three outputs, so using two of them as inputs is rejected.
  DefineT(-1, {}, {{split, 0}, {split, 2}}, {{add, 0}}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("When `num_opers` is set to -1, nodes referenced in "
                   "`inputs` must have a single output. Node split3 has "
                   "3 outputs. Encountered while creating function 'MyFunc'"),
            string(TF_Message(s_)));
}
899 
TEST_F(CApiFunctionTest,FunctionWithWhileLoop)900 TEST_F(CApiFunctionTest, FunctionWithWhileLoop) {
901   // Inputs to the while loop and the function as a whole
902   TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
903   TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
904 
905   // Outputs of the while loop corresponding to the two inputs above
906   // The first one will the function's output
907   std::vector<TF_Output> outputs;
908 
909   // Add while loop to func_graph_
910   {
911     // The inputs to the while loop
912     std::vector<TF_Output> inputs = {{feed1, 0}, {feed2, 0}};
913     std::unique_ptr<TF_WhileParams> params(new TF_WhileParams(
914         TF_NewWhile(func_graph_, &inputs[0], inputs.size(), s_)));
915     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
916     params->name = "test_loop";
917 
918     // Initialize outputs so we can easily detect errors/bugs
919     outputs.resize(2, {nullptr, -1});
920 
921     // Create loop: while (input1 < input2) input1 += input2 + 1
922     TF_Operation* less_than = LessThan(
923         params->cond_inputs[0], params->cond_inputs[1], params->cond_graph, s_);
924     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
925     params->cond_output = {less_than, 0};
926 
927     TF_Operation* add1 = Add(params->body_inputs[0], params->body_inputs[1],
928                              params->body_graph, s_, "add1");
929     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
930     TF_Operation* one = ScalarConst(1, params->body_graph, s_);
931     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
932     TF_Operation* add2 = Add(add1, one, params->body_graph, s_, "add2");
933     ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
934     params->body_outputs[0] = {add2, 0};
935     params->body_outputs[1] = params->body_inputs[1];
936 
937     // Finalize while loop
938     TF_FinishWhile(params.get(), s_, &outputs[0]);
939     EXPECT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
940   }
941 
942   // Define function, use it in graph, and run
943   DefineT(-1, {}, {{feed1, 0}, {feed2, 0}}, {outputs[0]}, {});
944   TF_Operation* five = ScalarConst(5, host_graph_, s_, "five");
945   TF_Operation* func_feed = Placeholder(host_graph_, s_);
946   TF_Operation* func_op = Use({func_feed, five});
947   Run({{func_feed, Int32Tensor(2)}}, func_op, 2 /*+=*/ + 5 + 1);
948 
949   // Verify input, output, and subset of edges in fdef.
950   // The subset of edges we verify is a chain between feed1 and output to
951   // make sure that the correct output is picked.
952   tensorflow::FunctionDef fdef;
953   ASSERT_TRUE(GetFunctionDef(func_, &fdef));
954   VerifyFDefInputs(fdef, M({{"feed1"}, {"feed2"}}));
955   VerifyFDefOutputs(fdef, M({{"test_loop_exit"}}));
956   VerifyFDefEdges(fdef,
957                   {{"feed1", "test_loop/Enter:0"},
958                    {"test_loop/Enter:output:0", "test_loop/Merge:0"},
959                    {"test_loop/Merge:output:0", "test_loop/Switch:0"},
960                    {"test_loop/Switch:output_false:0", "test_loop/Exit:0"},
961                    {"test_loop/Exit:output:0", "test_loop_exit"}},
962                   {}, false);
963 }
964 
TEST_F(CApiFunctionTest, ControlDependency) {
  /*
   *                  |  |    scalar
   *                  |  |    .
   *                  v  v   . <---- control dependency
   *                  add < -
   *                   |
   *                   v
   */
  // Define: `add` sums the two feeds and also has a control dependency on a
  // scalar const (node name "scalar") that is part of the function body.
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* five = ScalarConst(5, func_graph_, s_);
  TF_Operation* add =
      AddWithCtrlDependency(feed1, feed2, func_graph_, five, s_);
  // ASSERT, like the other setup checks in this file: `add` is used by
  // Define() right away, so continuing after a failure is meaningless.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  Define(-1, {}, {feed1, feed2}, {add}, {});

  // Use, run, and verify. The control edge appears as a "^scalar" input.
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* func_feed = Placeholder(host_graph_, s_);
  TF_Operation* func_op = Use({two, func_feed});
  Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3);
  VerifyFDef(
      {"add_0", "scalar"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
      {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}},
      {{"^scalar", "add_0:2"}});
}
993 
TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody) {
  /*
   *                  |  |    scalar
   *                  |  |    .
   *                  v  v   . <---- control dependency
   *                  add < -
   *                   |
   *                   v
   */
  // Define: `add` has a control dependency on a scalar const, but only `add`
  // is placed in the body (num_opers = 1). The control edge's source is
  // therefore outside the body, and conversion must fail.
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* five = ScalarConst(5, func_graph_, s_);
  TF_Operation* add =
      AddWithCtrlDependency(feed1, feed2, func_graph_, five, s_);
  // ASSERT: a setup failure here must not be mistaken for the expected
  // INVALID_ARGUMENT from Define() below.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  Define(1, {add}, {feed1, feed2}, {add}, {}, true);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("The source of control edge [id=3 scalar:-1 -> add:-1] "
                   "is not in the body. Encountered while creating "
                   "function 'MyFunc'"),
            string(TF_Message(s_)));
}
1017 
TEST_F(CApiFunctionTest, ControlDependencyOutsideOfBody_FromInputNode) {
  /*
   *                  |  |.
   *                  |  |  .
   *                  |  |   .
   *                  v  v   . <---- control dependency
   *                  add < -
   *                   |
   *                   v
   */
  // Define: `add` sums the feeds and also has a control dependency on feed1.
  // feed1 is a function input, so the control edge is kept (as "^feed1")
  // instead of being rejected.
  TF_Operation* feed1 = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* feed2 = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* add =
      AddWithCtrlDependency(feed1, feed2, func_graph_, feed1, s_);
  // ASSERT, like the other setup checks in this file: `add` is used by
  // Define() right away.
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  Define(-1, {}, {feed1, feed2}, {add}, {});

  // Use, run, and verify
  TF_Operation* two = ScalarConst(2, host_graph_, s_);
  TF_Operation* func_feed = Placeholder(host_graph_, s_);
  TF_Operation* func_op = Use({two, func_feed});
  Run({{func_feed, Int32Tensor(3)}}, func_op, 2 + 3);
  VerifyFDef(
      {"add_0"}, M({{"feed1"}, {"feed2"}}), M({{"add"}}),
      {{"feed1", "add_0:0"}, {"feed2", "add_0:1"}, {"add_0:sum:0", "add"}},
      {{"^feed1", "add_0:2"}});
}
1046 
TEST_F(CApiFunctionTest, DuplicateInputsAreNotAllowed) {
  // One placeholder feeds both operands of `add` and is also listed twice in
  // the function's input list. Define() is expected to fail with
  // TF_INVALID_ARGUMENT.
  //
  //            feed
  //             | |
  //             v v
  //             add
  //              |
  //              v
  TF_Operation* const feed = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const doubled = Add(feed, feed, func_graph_, s_);
  Define(-1, {}, {feed, feed}, {doubled}, {}, true);

  const string expected_error =
      "TF_Output feed1:0 appears more than once in the input list";
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(expected_error, string(TF_Message(s_)));
}
1071 
TEST_F(CApiFunctionTest, DuplicateOutputNamesAreNotAllowed) {
  // Two chained adds (add1 = feed1 + feed2, add2 = add1 + feed3) are both
  // exported, but under the same output name "my_out". Define() is expected
  // to fail with TF_INVALID_ARGUMENT.
  TF_Operation* const in_a = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const in_b = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const in_c = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* const partial_sum = Add(in_a, in_b, func_graph_, s_, "add1");
  TF_Operation* const full_sum =
      Add(partial_sum, in_c, func_graph_, s_, "add2");
  Define(-1, {}, {in_a, in_b, in_c}, {partial_sum, full_sum},
         {"my_out", "my_out"}, true);

  const string expected_error =
      "Cannot have duplicate output names. Name 'my_out' "
      "appears more than once in 'output_names' array.";
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(expected_error, string(TF_Message(s_)));
}
1098 
TEST_F(CApiFunctionTest, InvalidInputTensor_HighIndex) {
  // `add` sums two placeholders. The second function input refers to
  // output 2 of feed2, which has only one output, so the status must be
  // TF_OUT_OF_RANGE.
  TF_Operation* const first = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const second = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const sum = Add(first, second, func_graph_, s_);
  DefineT(-1, {}, {{first, 0}, {second, 2}}, {{sum, 0}}, {}, true);

  const string expected_error =
      "Node 'feed2' (type: 'Placeholder', num of outputs: 1) does "
      "not have output 2\n\tEncountered while processing "
      "input 1 into function 'MyFunc'";
  EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_));
  EXPECT_EQ(expected_error, string(TF_Message(s_)));
}
1117 
TEST_F(CApiFunctionTest, InvalidInputTensor_BadNodePtr) {
  // `add` sums two placeholders, but the second function input is given as
  // a null node, which must be rejected with TF_INVALID_ARGUMENT.
  TF_Operation* const first = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const second = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const sum = Add(first, second, func_graph_, s_);
  DefineT(-1, {}, {{first, 0}, {nullptr, 0}}, {{sum, 0}}, {}, true);

  const string expected_error =
      "Node is null\n\tEncountered while processing input 1 "
      "into function 'MyFunc'";
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(expected_error, string(TF_Message(s_)));
}
1135 
TEST_F(CApiFunctionTest, InvalidOutputTensor_HighIndex) {
  // `add` (an AddN node with a single output) is requested as function
  // output 0 via its nonexistent output 3, so the status must be
  // TF_OUT_OF_RANGE.
  TF_Operation* const first = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const second = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const sum = Add(first, second, func_graph_, s_);
  DefineT(-1, {}, {{first, 0}, {second, 0}}, {{sum, 3}}, {}, true);

  const string expected_error =
      "Node 'add' (type: 'AddN', num of outputs: 1) does "
      "not have output 3\n\tEncountered while processing "
      "output 0 from function 'MyFunc'";
  EXPECT_EQ(TF_OUT_OF_RANGE, TF_GetCode(s_));
  EXPECT_EQ(expected_error, string(TF_Message(s_)));
}
1154 
TEST_F(CApiFunctionTest, InvalidOutputTensor_BadNodePtr) {
  // The only requested function output is a null node, which must be
  // rejected with TF_INVALID_ARGUMENT.
  TF_Operation* const first = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const second = Placeholder(func_graph_, s_, "feed2");
  // The add node is created but deliberately not referenced below.
  Add(first, second, func_graph_, s_);
  DefineT(-1, {}, {{first, 0}, {second, 0}}, {{nullptr, 3}}, {}, true);

  const string expected_error =
      "Node is null\n\tEncountered while processing output 0 "
      "from function 'MyFunc'";
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(expected_error, string(TF_Message(s_)));
}
1172 
TEST_F(CApiFunctionTest, NodeMissingInput) {
  // The body is just `add` (num_opers = 1) and only feed1 is declared as a
  // function input. add's second operand (feed2) is then unavailable inside
  // the function, so Define() must fail with TF_INVALID_ARGUMENT.
  //
  //   input ---> |  | <--- missing input
  //              v  v
  //   body ----> add
  //               |
  //               v
  TF_Operation* const first = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const second = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const sum = Add(first, second, func_graph_, s_);
  DefineT(1, {sum}, {{first, 0}}, {{sum, 0}}, {}, true);

  const string expected_error =
      "Input 1, 'feed2:0', of node 'add' in function 'MyFunc' "
      "is not available. You might need to include it in inputs "
      "or include its source node in the body";
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(expected_error, string(TF_Message(s_)));
}
1191 
TEST_F(CApiFunctionTest, OutputOpNotInBody) {
  // Only `add` (feed1 + feed2) is in the body, yet a scalar const that is
  // neither in the body nor a function input is also requested as an
  // output. Define() must fail with TF_INVALID_ARGUMENT.
  //
  //      |  |
  //      v  v
  //      add    scalar    (scalar is not part of the body)
  //       |       |
  //       v       v       (two requested outputs)
  TF_Operation* const first = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* const second = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* const stray = ScalarConst(2, func_graph_, s_);
  TF_Operation* const sum = Add(first, second, func_graph_, s_);
  Define(1, {sum}, {first, second}, {sum, stray}, {}, true);

  const string expected_error =
      "TF_Output scalar:0 is neither in the function body nor "
      "among function inputs. Encountered while creating "
      "function 'MyFunc'";
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(expected_error, string(TF_Message(s_)));
}
1212 
DefineFunction(const char * name,TF_Function ** func,const char * description=nullptr,bool append_hash=false)1213 void DefineFunction(const char* name, TF_Function** func,
1214                     const char* description = nullptr,
1215                     bool append_hash = false) {
1216   std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
1217       TF_NewGraph(), TF_DeleteGraph);
1218   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
1219                                                            TF_DeleteStatus);
1220 
1221   TF_Operation* feed = Placeholder(func_graph.get(), s.get());
1222   TF_Operation* neg = Neg(feed, func_graph.get(), s.get());
1223 
1224   TF_Output inputs[] = {{feed, 0}};
1225   TF_Output outputs[] = {{neg, 0}};
1226   *func = TF_GraphToFunction(func_graph.get(), name, append_hash, -1,
1227                              /*opers=*/nullptr, 1, inputs, 1, outputs,
1228                              /*output_names=*/nullptr,
1229                              /*opts=*/nullptr, description, s.get());
1230   ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
1231   ASSERT_NE(*func, nullptr);
1232 }
1233 
// Test-only op with no inputs, a single float32 output, and an int attr
// "index". The output shape is left unknown (UnknownShape shape function).
// NodeWithPlaceholderAttrHelper binds "index" to an attr placeholder rather
// than to a concrete value.
REGISTER_OP("CustomOp")
    .Output("output: float32")
    .Attr("index: int")
    .SetShapeFn(tensorflow::shape_inference::UnknownShape);
1238 
NodeWithPlaceholderAttrHelper(TF_Graph * graph,TF_Status * s,const char * name,const char * placeholder,TF_Operation ** op)1239 void NodeWithPlaceholderAttrHelper(TF_Graph* graph, TF_Status* s,
1240                                    const char* name, const char* placeholder,
1241                                    TF_Operation** op) {
1242   TF_OperationDescription* desc = TF_NewOperation(graph, "CustomOp", name);
1243   TF_SetAttrPlaceholder(desc, "index", placeholder);
1244   *op = TF_FinishOperation(desc, s);
1245   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1246   ASSERT_NE(*op, nullptr);
1247 }
1248 
TEST_F(CApiFunctionTest, GraphToFunctionDefWithPlaceholderAttr) {
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
      TF_NewGraph(), TF_DeleteGraph);
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
                                                           TF_DeleteStatus);

  // Three CustomOp nodes: node1 and node2 share the attr placeholder "v1",
  // and node3 uses "v2".
  // The helper reports failures with ASSERT_*, which only returns from the
  // helper. Each call is therefore wrapped in ASSERT_NO_FATAL_FAILURE so the
  // test stops as well. The pointers start as nullptr so they are never read
  // uninitialized.
  TF_Operation* node1 = nullptr;
  TF_Operation* node2 = nullptr;
  TF_Operation* node3 = nullptr;
  ASSERT_NO_FATAL_FAILURE(NodeWithPlaceholderAttrHelper(
      func_graph.get(), s.get(), "node1", "v1", &node1));
  ASSERT_NO_FATAL_FAILURE(NodeWithPlaceholderAttrHelper(
      func_graph.get(), s.get(), "node2", "v1", &node2));
  ASSERT_NO_FATAL_FAILURE(NodeWithPlaceholderAttrHelper(
      func_graph.get(), s.get(), "node3", "v2", &node3));

  TF_Output outputs[] = {{node1, 0}, {node2, 0}, {node3, 0}};
  func_ = TF_GraphToFunction(
      func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
      /*opers=*/nullptr, 0, nullptr, 3, outputs,
      /*output_names=*/nullptr,
      /*opts=*/nullptr, /*description=*/nullptr, s.get());
  ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
  ASSERT_NE(func_, nullptr);

  // Verify that FunctionDef has 2 attributes, "v1" and "v2": the placeholder
  // shared by node1 and node2 yields a single signature attr.
  ASSERT_EQ(func_->fdef.signature().attr().size(), 2);
  EXPECT_EQ(func_->fdef.signature().attr(0).name(), "v1");
  EXPECT_EQ(func_->fdef.signature().attr(0).type(), "int");
  EXPECT_EQ(func_->fdef.signature().attr(1).name(), "v2");
  EXPECT_EQ(func_->fdef.signature().attr(1).type(), "int");
}
1279 
NodeWithAttrHelper(TF_Graph * graph,TF_Status * s,const char * name,const char * attr_name,const char * attr_value,TF_Operation ** op)1280 void NodeWithAttrHelper(TF_Graph* graph, TF_Status* s, const char* name,
1281                         const char* attr_name, const char* attr_value,
1282                         TF_Operation** op) {
1283   TF_OperationDescription* desc = TF_NewOperation(graph, "Placeholder", name);
1284   TF_SetAttrType(desc, "dtype", TF_INT32);
1285   TF_SetAttrString(desc, attr_name, attr_value, strlen(attr_value));
1286   *op = TF_FinishOperation(desc, s);
1287   ASSERT_EQ(TF_OK, TF_GetCode(s)) << TF_Message(s);
1288   ASSERT_NE(*op, nullptr);
1289 }
1290 
TEST_F(CApiFunctionTest, GraphToFunctionDefWithArgAttr) {
  std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
      TF_NewGraph(), TF_DeleteGraph);
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
                                                           TF_DeleteStatus);

  // A single placeholder carrying the string attr "_test_attr" = "value".
  // The helper reports failures with ASSERT_*, which only returns from the
  // helper, so the call is wrapped to stop this test too. `node` starts as
  // nullptr so it is never read uninitialized.
  TF_Operation* node = nullptr;
  ASSERT_NO_FATAL_FAILURE(NodeWithAttrHelper(
      func_graph.get(), s.get(), "node", "_test_attr", "value", &node));

  // The placeholder is the function's only input; there are no outputs.
  TF_Output inputs[] = {{node, 0}};
  func_ = TF_GraphToFunction(
      func_graph.get(), "func", /*append_hash_to_fn_name=*/false, -1,
      /*opers=*/nullptr, 1, inputs, 0, nullptr,
      /*output_names=*/nullptr,
      /*opts=*/nullptr, /*description=*/nullptr, s.get());
  ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
  ASSERT_NE(func_, nullptr);

  // Verify that FunctionDef ArgDef has attributes: the node's attr must be
  // carried over as an arg attr of argument 0.
  ASSERT_EQ(func_->fdef.arg_attr_size(), 1);
  auto arg_attrs = func_->fdef.arg_attr().find(0);
  ASSERT_NE(arg_attrs, func_->fdef.arg_attr().end());
  auto iter = arg_attrs->second.attr().find("_test_attr");
  ASSERT_NE(iter, arg_attrs->second.attr().end());
  EXPECT_EQ(iter->second.s(), "value");
}
1318 
TEST_F(CApiFunctionTest, SetGradientAndRun) {
  // Owning handle for functions created by this test, so they are released
  // even when an ASSERT_* below returns from the test early.
  using FunctionPtr =
      std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>;
  // Defines a function via DefineFunction() and takes ownership of it.
  auto define_owned = [](const char* name) {
    TF_Function* fn = nullptr;
    DefineFunction(name, &fn);
    return FunctionPtr(fn, TF_DeleteFunction);
  };

  // Define the function and its grad. DefineFunction reports problems with
  // ASSERT_*, which only leaves the helper, so stop here on a fatal failure.
  // func_ belongs to the fixture and is not wrapped here.
  ASSERT_NO_FATAL_FAILURE(DefineFunction(func_name_, &func_));
  FunctionPtr grad_func = define_owned("MyGrad");
  ASSERT_FALSE(HasFatalFailure());

  // Add func and its gradient to host graph
  TF_GraphCopyFunction(host_graph_, func_, grad_func.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Verify that function and its grad are in host graph's GraphDef
  GraphDef gdef;
  GetGraphDef(host_graph_, &gdef);
  std::vector<string> func_names = GetFuncNames(gdef);
  ASSERT_EQ(2, func_names.size());
  ASSERT_EQ(func_name_, func_names[0]);
  ASSERT_EQ("MyGrad", func_names[1]);
  std::vector<std::pair<string, string>> grads = GetGradDefs(gdef);
  ASSERT_EQ(1, grads.size());
  ASSERT_EQ(func_name_, grads[0].first);
  ASSERT_EQ("MyGrad", grads[0].second);

  // These calls must be noops
  TF_GraphCopyFunction(host_graph_, func_, grad_func.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Delete the gradient func.
  // It is safe to delete after adding a copy to host graph.
  grad_func.reset();

  // Check that GraphDef did not change
  GraphDef gdef2;
  GetGraphDef(host_graph_, &gdef2);
  ASSERT_EQ(gdef.DebugString(), gdef2.DebugString());

  // Use and run func
  TF_Operation* func_feed = Placeholder(host_graph_, s_);
  TF_Operation* func_op = Use({func_feed});
  Run({{func_feed, Int32Tensor(3)}}, func_op, -3);
}
1361 
TEST_F(CApiFunctionTest, SameGradForTwoFunctions) {
  // Owning handle for functions created by this test, so they are released
  // even when an ASSERT_* below returns from the test early (the previous
  // manual TF_DeleteFunction calls at the end were skipped in that case).
  using FunctionPtr =
      std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>;
  // Defines a function via DefineFunction() and takes ownership of it.
  auto define_owned = [](const char* name) {
    TF_Function* fn = nullptr;
    DefineFunction(name, &fn);
    return FunctionPtr(fn, TF_DeleteFunction);
  };

  // Define the functions. DefineFunction reports problems with ASSERT_*,
  // which only leaves the helper, so stop here on a fatal failure.
  FunctionPtr func1 = define_owned("FooFunc1");
  FunctionPtr func2 = define_owned("FooFunc2");
  FunctionPtr grad_func = define_owned("MyGrad");
  ASSERT_FALSE(HasFatalFailure());

  // Make grad_func be a gradient of func1 and func2
  TF_GraphCopyFunction(host_graph_, func1.get(), grad_func.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, func2.get(), grad_func.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Verify that functions and their gradients are in host graph's GraphDef
  GraphDef gdef;
  GetGraphDef(host_graph_, &gdef);
  std::vector<std::pair<string, string>> grads = GetGradDefs(gdef);
  ASSERT_EQ(2, grads.size());
  ASSERT_EQ("FooFunc1", grads[0].first);
  ASSERT_EQ("MyGrad", grads[0].second);
  ASSERT_EQ("FooFunc2", grads[1].first);
  ASSERT_EQ("MyGrad", grads[1].second);
}
1391 
TEST_F(CApiFunctionTest, AddFunctionsThenMakeOneGradientOfAnother) {
  // Owning handle for functions created by this test, so they are released
  // even when an ASSERT_* below returns from the test early (the previous
  // manual TF_DeleteFunction calls at the end were skipped in that case).
  using FunctionPtr =
      std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>;
  // Defines a function via DefineFunction() and takes ownership of it.
  auto define_owned = [](const char* name) {
    TF_Function* fn = nullptr;
    DefineFunction(name, &fn);
    return FunctionPtr(fn, TF_DeleteFunction);
  };

  // Define the functions. DefineFunction reports problems with ASSERT_*,
  // which only leaves the helper, so stop here on a fatal failure.
  FunctionPtr func = define_owned("FooFunc");
  FunctionPtr grad_func = define_owned("MyGrad");
  ASSERT_FALSE(HasFatalFailure());

  // Add functions individually
  TF_GraphCopyFunction(host_graph_, func.get(), nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, grad_func.get(), nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Check that functions are added but not linked
  GraphDef gdef;
  GetGraphDef(host_graph_, &gdef);
  std::vector<string> func_names = GetFuncNames(gdef);
  ASSERT_EQ(2, func_names.size());
  ASSERT_EQ("FooFunc", func_names[0]);
  ASSERT_EQ("MyGrad", func_names[1]);
  ASSERT_EQ(0, GetGradDefs(gdef).size());

  // Make grad_func a gradient of func
  TF_GraphCopyFunction(host_graph_, func.get(), grad_func.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Verify that function and its grad are linked
  gdef.Clear();
  GetGraphDef(host_graph_, &gdef);
  std::vector<std::pair<string, string>> grads = GetGradDefs(gdef);
  ASSERT_EQ(1, grads.size());
  ASSERT_EQ("FooFunc", grads[0].first);
  ASSERT_EQ("MyGrad", grads[0].second);
}
1429 
TEST_F(CApiFunctionTest, GradientErrorCases) {
  // Owning handle for functions created by this test, so they are released
  // even when an ASSERT_* below returns from the test early (the previous
  // manual TF_DeleteFunction calls at the end were skipped in that case).
  using FunctionPtr =
      std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>;
  // Defines a function via DefineFunction() and takes ownership of it.
  auto define_owned = [](const char* name) {
    TF_Function* fn = nullptr;
    DefineFunction(name, &fn);
    return FunctionPtr(fn, TF_DeleteFunction);
  };

  // Define the function and two candidate gradients. DefineFunction reports
  // problems with ASSERT_*, which only leaves the helper, so stop here on a
  // fatal failure. func_ belongs to the fixture and is not wrapped here.
  ASSERT_NO_FATAL_FAILURE(DefineFunction(func_name_, &func_));
  FunctionPtr grad_func1 = define_owned("MyGrad1");
  FunctionPtr grad_func2 = define_owned("MyGrad2");
  ASSERT_FALSE(HasFatalFailure());

  // func cannot be null
  TF_GraphCopyFunction(host_graph_, nullptr, func_, s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("'func' argument to TF_GraphCopyFunction cannot be null"),
            string(TF_Message(s_)));

  // Cannot change gradient
  TF_GraphCopyFunction(host_graph_, func_, grad_func1.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  TF_GraphCopyFunction(host_graph_, func_, grad_func2.get(), s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Cannot assign gradient function 'MyGrad2' to 'MyFunc' "
                   "because it already has gradient function 'MyGrad1'"),
            string(TF_Message(s_)));
}
1456 
// Round-trips a three-input, two-output function through its FunctionDef
// form (via Reincarnate) and checks that the re-imported function still
// computes the right values and has the expected nodes, inputs, outputs and
// edges.
TEST_F(CApiFunctionTest, ImportFunctionDef) {
  // Function body, with named outputs:
  //
  //   feed1 + feed2 -> add1 ---------------> "internal_out"
  //                     |
  //                     +--> add1 + feed3 -> add2 -> "final_out"
  TF_Operation* in_a = Placeholder(func_graph_, s_, "feed1");
  TF_Operation* in_b = Placeholder(func_graph_, s_, "feed2");
  TF_Operation* in_c = Placeholder(func_graph_, s_, "feed3");
  TF_Operation* partial_sum = Add(in_a, in_b, func_graph_, s_, "add1");
  TF_Operation* total_sum = Add(partial_sum, in_c, func_graph_, s_, "add2");
  // -1 presumably selects every op in func_graph_ as the body (see Define).
  Define(-1, {}, {in_a, in_b, in_c}, {partial_sum, total_sum},
         {"internal_out", "final_out"});

  // Serialize func_ to a FunctionDef and import it back in its place.
  Reincarnate();

  // Call the imported function with inputs (2, 10, 3):
  //   internal_out = 2 + 10 = 12, final_out = 12 + 3 = 15.
  TF_Operation* const_two = ScalarConst(2, host_graph_, s_, "two");
  TF_Operation* const_ten = ScalarConst(10, host_graph_, s_, "ten");
  TF_Operation* runtime_feed = Placeholder(host_graph_, s_);
  TF_Operation* call = Use({const_two, const_ten, runtime_feed});
  Run({{runtime_feed, Int32Tensor(3)}}, {{call, 0}, {call, 1}}, {12, 15});

  // The imported FunctionDef must keep the original structure.
  VerifyFDef({"add1", "add2"}, M({{"feed1"}, {"feed2"}, {"feed3"}}),
             M({{"internal_out"}, {"final_out"}}),
             {{"feed1", "add1:0"},
              {"feed2", "add1:1"},
              {"add1:sum:0", "add2:0"},
              {"feed3", "add2:1"},
              {"add1:sum:0", "internal_out"},
              {"add2:sum:0", "final_out"}},
             {});
}
1501 
// TF_FunctionImportFunctionDef must reject bytes that do not parse as a
// FunctionDef: it returns nullptr and reports INVALID_ARGUMENT.
TEST_F(CApiFunctionTest, ImportFunctionDef_InvalidProto) {
  // Invalid protobuf data: a zero tag byte encodes field number 0, which is
  // not valid in the protobuf wire format.
  char proto[] = {0x0, 0x0, 0x0, 0x0};
  // Derive the length from the array so it cannot drift from the data.
  func_ = TF_FunctionImportFunctionDef(proto, sizeof(proto), s_);
  EXPECT_EQ(nullptr, func_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Invalid FunctionDef given to TF_FunctionImportFunctionDef"),
            string(TF_Message(s_)));
}
1511 
// Exercises TF_FunctionGetAttrValueProto / TF_FunctionSetAttrValueProto:
// reading a missing attr fails with INVALID_ARGUMENT, a set attr reads back
// unchanged, and it is preserved across a FunctionDef save/restore.
TEST_F(CApiFunctionTest, Attribute) {
  DefineFunction(func_name_, &func_);

  // Get non existent attribute
  TF_Buffer* attr_buf = TF_NewBuffer();
  TF_FunctionGetAttrValueProto(func_, "foo_attr", attr_buf, s_);
  EXPECT_EQ(TF_INVALID_ARGUMENT, TF_GetCode(s_));
  EXPECT_EQ(string("Function 'MyFunc' has no attr named 'foo_attr'."),
            string(TF_Message(s_)));
  TF_DeleteBuffer(attr_buf);

  // Set attr
  AttrValue attr;
  attr.set_s("test_attr_value");
  string bytes;
  // Check the result so a serialization failure is reported here rather than
  // showing up later as a confusing attr mismatch.
  ASSERT_TRUE(attr.SerializeToString(&bytes));
  TF_FunctionSetAttrValueProto(func_, "test_attr_name", bytes.data(),
                               bytes.size(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Get attr
  AttrValue read_attr;
  GetAttr("test_attr_name", &read_attr);
  ASSERT_EQ(attr.DebugString(), read_attr.DebugString());

  // Retrieve the same attr after save/restore
  Reincarnate();
  AttrValue read_attr2;
  GetAttr("test_attr_name", &read_attr2);
  ASSERT_EQ(attr.DebugString(), read_attr2.DebugString());
}
1543 
// The description supplied when defining a function must be stored verbatim
// in the signature of its FunctionDef.
TEST_F(CApiFunctionTest, Description) {
  const char* const kDescription = "Return something";
  DefineFunction(func_name_, &func_, kDescription);

  tensorflow::FunctionDef fdef;
  ASSERT_TRUE(GetFunctionDef(func_, &fdef));
  ASSERT_EQ(string(kDescription), fdef.signature().description());
}
1550 
// With append_hash disabled, the function keeps exactly the requested name.
TEST_F(CApiFunctionTest, Name) {
  const char* const kRequestedName = "long_func_name";
  DefineFunction(kRequestedName, &func_, "Return something",
                 /*append_hash=*/false);

  tensorflow::FunctionDef fdef;
  ASSERT_TRUE(GetFunctionDef(func_, &fdef));
  ASSERT_EQ(string(kRequestedName), fdef.signature().name());
}
1558 
// With append_hash enabled, a hash suffix is appended to the base name. The
// expected suffix differs by host byte order, so it is selected at compile
// time.
TEST_F(CApiFunctionTest, AppendHash) {
#if (__BYTE_ORDER__ == __ORDER_BIG_ENDIAN__)
  const char* const kExpectedName = "func_name_base_ZpgUD4x8oqk";
#else
  const char* const kExpectedName = "func_name_base_qaJ8jA8UmGY";
#endif
  DefineFunction("func_name_base", &func_, "Return something",
                 /*append_hash=*/true);

  tensorflow::FunctionDef fdef;
  ASSERT_TRUE(GetFunctionDef(func_, &fdef));
  ASSERT_EQ(string(kExpectedName), fdef.signature().name());
}
1570 
// A function copied into a graph must be retrievable through
// TF_GraphGetOpDef, with an OpDef that has the function's name, one input,
// one output, and is not marked stateful.
TEST_F(CApiFunctionTest, GetOpDef) {
  DefineFunction(func_name_, &func_);
  TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Test we can retrieve function OpDef from graph. The buffer is owned by a
  // unique_ptr so it is released even if an ASSERT_* below returns early.
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> buffer(
      TF_NewBuffer(), TF_DeleteBuffer);
  TF_GraphGetOpDef(host_graph_, func_name_, buffer.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Sanity check returned OpDef. A failed parse is reported explicitly
  // instead of surfacing as field mismatches on a default OpDef.
  string data(static_cast<const char*>(buffer->data), buffer->length);
  OpDef op_def;
  ASSERT_TRUE(op_def.ParseFromString(data));
  EXPECT_EQ(op_def.name(), func_name_);
  EXPECT_EQ(op_def.input_arg_size(), 1);
  EXPECT_EQ(op_def.output_arg_size(), 1);
  EXPECT_FALSE(op_def.is_stateful());
}
1592 
DefineStatefulFunction(const char * name,TF_Function ** func)1593 void DefineStatefulFunction(const char* name, TF_Function** func) {
1594   std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)> func_graph(
1595       TF_NewGraph(), TF_DeleteGraph);
1596   std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> s(TF_NewStatus(),
1597                                                            TF_DeleteStatus);
1598 
1599   TF_Tensor* tensor_shape = Int32Tensor({37, 1});
1600   TF_Operation* shape = Const(tensor_shape, func_graph.get(), s.get(), "shape");
1601   TF_Operation* random =
1602       RandomUniform(shape, TF_FLOAT, func_graph.get(), s.get());
1603 
1604   TF_Output outputs[] = {{random, 0}};
1605   *func = TF_GraphToFunction(func_graph.get(), name,
1606                              /*append_hash_to_fn_name=*/false, -1,
1607                              /*opers=*/nullptr, 0, nullptr, 1, outputs,
1608                              /*output_names=*/nullptr,
1609                              /*opts=*/nullptr, "", s.get());
1610   ASSERT_EQ(TF_OK, TF_GetCode(s.get())) << TF_Message(s.get());
1611   ASSERT_NE(*func, nullptr);
1612   TF_DeleteTensor(tensor_shape);
1613 }
1614 
// A function whose body contains a stateful op must surface through
// TF_GraphGetOpDef as an OpDef with is_stateful() set, no inputs, and one
// output.
TEST_F(CApiFunctionTest, StatefulOpDef) {
  // The helper reports failures with ASSERT_*, which only returns from the
  // helper; propagate them so this test stops here.
  ASSERT_NO_FATAL_FAILURE(DefineStatefulFunction(func_name_, &func_));
  TF_GraphCopyFunction(host_graph_, func_, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Test we can retrieve function OpDef from graph. The buffer is owned by a
  // unique_ptr so it is released even if an ASSERT_* below returns early.
  std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> buffer(
      TF_NewBuffer(), TF_DeleteBuffer);
  TF_GraphGetOpDef(host_graph_, func_name_, buffer.get(), s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Sanity check returned OpDef. A failed parse is reported explicitly
  // instead of surfacing as field mismatches on a default OpDef.
  string data(static_cast<const char*>(buffer->data), buffer->length);
  OpDef op_def;
  ASSERT_TRUE(op_def.ParseFromString(data));
  EXPECT_EQ(op_def.name(), func_name_);
  EXPECT_EQ(op_def.input_arg_size(), 0);
  EXPECT_EQ(op_def.output_arg_size(), 1);
  EXPECT_TRUE(op_def.is_stateful());
}
1636 
AssertEqual(TF_Function * f1,TF_Function * f2)1637 void AssertEqual(TF_Function* f1, TF_Function* f2) {
1638   string s1, s2;
1639   tensorflow::FunctionDef fdef1, fdef2;
1640   ASSERT_TRUE(GetFunctionDef(f1, &fdef1));
1641   ASSERT_TRUE(GetFunctionDef(f2, &fdef2));
1642   SerializeToStringDeterministic(fdef1, &s1);
1643   SerializeToStringDeterministic(fdef2, &s2);
1644   ASSERT_EQ(s1, s2);
1645 }
1646 
GetName(TF_Function * func)1647 string GetName(TF_Function* func) {
1648   tensorflow::FunctionDef fdef;
1649   GetFunctionDef(func, &fdef);
1650   return fdef.signature().name();
1651 }
1652 
// Exercises TF_GraphNumFunctions / TF_GraphGetFunctions on a graph holding
// zero, one, and two functions, with max_func smaller than, equal to, and
// larger than the number of functions present.
TEST_F(CApiFunctionTest, GetFunctionsFromGraph) {
  using FunctionPtr =
      std::unique_ptr<TF_Function, decltype(&TF_DeleteFunction)>;
  // Null-initialized so no slot is ever read or deleted while indeterminate.
  TF_Function* funcs[2] = {nullptr, nullptr};

  // Get functions from empty graph
  EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 0);
  TF_GraphGetFunctions(host_graph_, nullptr, 0, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Define a function and add it to host_graph_. The owner releases it even
  // if an ASSERT_* below returns early.
  TF_Function* func0 = nullptr;
  DefineFunction("FooFunc0", &func0);
  FunctionPtr func0_owner(func0, TF_DeleteFunction);
  TF_GraphCopyFunction(host_graph_, func0, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Get this function from host_graph_. Counts that gate reading funcs[] are
  // checked with ASSERT so a short read never dereferences an empty slot.
  EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 1);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 0, s_), 0);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  int num_returned = TF_GraphGetFunctions(host_graph_, funcs, 1, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  ASSERT_EQ(1, num_returned);
  AssertEqual(func0, funcs[0]);
  TF_DeleteFunction(funcs[0]);
  funcs[0] = nullptr;

  num_returned = TF_GraphGetFunctions(host_graph_, funcs, 2, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  ASSERT_EQ(1, num_returned);
  AssertEqual(func0, funcs[0]);
  TF_DeleteFunction(funcs[0]);
  funcs[0] = nullptr;

  // Define a second function
  TF_Function* func1 = nullptr;
  DefineFunction("FooFunc1", &func1);
  FunctionPtr func1_owner(func1, TF_DeleteFunction);
  TF_GraphCopyFunction(host_graph_, func1, nullptr, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  // Get both functions from host_graph_
  EXPECT_EQ(TF_GraphNumFunctions(host_graph_), 2);
  EXPECT_EQ(TF_GraphGetFunctions(host_graph_, funcs, 0, s_), 0);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);

  num_returned = TF_GraphGetFunctions(host_graph_, funcs, 2, s_);
  ASSERT_EQ(TF_OK, TF_GetCode(s_)) << TF_Message(s_);
  ASSERT_EQ(2, num_returned);
  // The retrieval order is not relied upon; pair the results up by name.
  if (GetName(funcs[0]) == GetName(func0)) {
    AssertEqual(func0, funcs[0]);
    AssertEqual(func1, funcs[1]);
  } else {
    AssertEqual(func0, funcs[1]);
    AssertEqual(func1, funcs[0]);
  }

  TF_DeleteFunction(funcs[0]);
  TF_DeleteFunction(funcs[1]);
}
1706 
1707 }  // namespace
1708 }  // namespace tensorflow
1709