xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tpu/kernels/sharding_util_ops_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include <functional>
17 #include <memory>
18 #include <numeric>
19 #include <string>
20 #include <vector>
21 
22 #include "absl/strings/match.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/core/framework/graph.pb.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/framework/tensor_testutil.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/graph/graph.h"
32 #include "tensorflow/core/graph/node_builder.h"
33 #include "tensorflow/core/graph/testlib.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/core/platform/status.h"
37 #include "tensorflow/core/platform/test.h"
38 #include "tensorflow/core/protobuf/config.pb.h"
39 #include "tensorflow/core/protobuf/error_codes.pb.h"
40 #include "tensorflow/core/public/session.h"
41 #include "tensorflow/core/public/session_options.h"
42 
43 namespace tensorflow {
44 namespace {
45 
46 MATCHER_P2(IsStatus, error_code, error_message, "") {
47   return arg.code() == error_code &&
48          absl::StrContains(arg.error_message(), error_message);
49 }
50 
RunGraph(const Graph & graph,const std::vector<std::string> & output_tensor_names,const std::vector<std::string> & target_tensor_names,std::vector<Tensor> * output_tensors)51 Status RunGraph(const Graph& graph,
52                 const std::vector<std::string>& output_tensor_names,
53                 const std::vector<std::string>& target_tensor_names,
54                 std::vector<Tensor>* output_tensors) {
55   GraphDef graph_def;
56   graph.ToGraphDef(&graph_def);
57   SessionOptions session_options;
58   std::unique_ptr<Session> session(NewSession(session_options));
59   TF_RETURN_IF_ERROR(session->Create(graph_def));
60   RunOptions run_options;
61   return session->Run(run_options, /*inputs=*/{}, output_tensor_names,
62                       target_tensor_names, output_tensors,
63                       /*run_metadata=*/nullptr);
64 }
65 
// A VarHandleOp that is never assigned is read through ReadVariableXlaSplitND;
// the read must fail because the variable does not exist.
TEST(ReadVariableXlaSplitNDOpTest, VariableMissing) {
  Graph graph(OpRegistry::Global());
  const DataType variable_dtype = DataTypeToEnum<int32>::value;
  const TensorShape variable_shape({4, 4});

  Node* handle = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("var_handle"), "VarHandleOp")
                   .Attr("dtype", variable_dtype)
                   .Attr("shape", variable_shape)
                   .Finalize(&graph, &handle));

  const std::vector<int32> splits_per_dim = {2, 2};
  const int split_count = 4;
  Node* split_op = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("xla_op"), "ReadVariableXlaSplitND")
                   .Input(handle)
                   .Attr("num_splits", splits_per_dim)
                   .Attr("T", variable_dtype)
                   .Attr("N", split_count)
                   .Finalize(&graph, &split_op));

  std::vector<Tensor> fetched;
  const Status run_status =
      RunGraph(graph, /*output_tensor_names=*/{split_op->name()},
               /*target_tensor_names=*/{}, &fetched);
  EXPECT_THAT(run_status,
              IsStatus(error::INVALID_ARGUMENT, "cannot be found"));
}
92 
// The variable holds int32 data, but the split op is built with T=float; the
// op must reject the dtype mismatch between 'T' and the resource.
TEST(ReadVariableXlaSplitNDOpTest, DTypeInvalid) {
  Graph graph(OpRegistry::Global());
  const DataType variable_dtype = DataTypeToEnum<int32>::value;
  const DataType requested_dtype = DataTypeToEnum<float>::value;
  const TensorShape variable_shape({4, 4});

  Node* handle = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("var_handle"), "VarHandleOp")
                   .Attr("dtype", variable_dtype)
                   .Attr("shape", variable_shape)
                   .Finalize(&graph, &handle));

  // Initialize the variable with 0, 1, 2, ... so the read has real data.
  Tensor initial_value(variable_dtype, variable_shape);
  test::FillIota<int32>(&initial_value, /*val=*/0);
  Node* initial_value_node = test::graph::Constant(&graph, initial_value);

  Node* assign = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("assign_var"), "AssignVariableOp")
                   .Input(handle)
                   .Input(initial_value_node)
                   .Attr("dtype", variable_dtype)
                   .Finalize(&graph, &assign));

  const std::vector<int32> splits_per_dim = {2, 2};
  const int split_count = 4;
  Node* split_op = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("xla_op"), "ReadVariableXlaSplitND")
                   .Input(handle)
                   .ControlInput(assign)  // Read only after the assignment.
                   .Attr("num_splits", splits_per_dim)
                   .Attr("T", requested_dtype)
                   .Attr("N", split_count)
                   .Finalize(&graph, &split_op));

  std::vector<Tensor> fetched;
  const Status run_status =
      RunGraph(graph, /*output_tensor_names=*/{split_op->name()},
               /*target_tensor_names=*/{}, &fetched);
  EXPECT_THAT(run_status,
              IsStatus(error::INVALID_ARGUMENT, "'T' must match 'resource'"));
}
131 
CreateSplitTensorGraph(const TensorShape & input_shape,absl::Span<const int32> num_splits,absl::Span<const int32> paddings,const int num_outputs,Graph * graph,std::vector<std::string> * output_tensor_names)132 Status CreateSplitTensorGraph(const TensorShape& input_shape,
133                               absl::Span<const int32> num_splits,
134                               absl::Span<const int32> paddings,
135                               const int num_outputs, Graph* graph,
136                               std::vector<std::string>* output_tensor_names) {
137   DataType data_type = DataTypeToEnum<int32>::value;
138   Tensor input_tensor(data_type, input_shape);
139   test::FillIota<int32>(&input_tensor, /*val=*/0);
140   Node* input = test::graph::Constant(graph, input_tensor);
141 
142   Node* xla_op = nullptr;
143   TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("xla_op"), "XlaSplitND")
144                          .Input(input)
145                          .Attr("num_splits", num_splits)
146                          .Attr("paddings", paddings)
147                          .Attr("T", data_type)
148                          .Attr("N", num_outputs)
149                          .Finalize(graph, &xla_op));
150 
151   output_tensor_names->reserve(num_outputs);
152   for (int i = 0; i < num_outputs; ++i) {
153     output_tensor_names->push_back(absl::StrCat(xla_op->name(), ":", i));
154   }
155 
156   return OkStatus();
157 }
158 
CreateSplitResourceGraph(const TensorShape & input_shape,absl::Span<const int32> num_splits,absl::Span<const int32> paddings,const int num_outputs,Graph * graph,std::vector<std::string> * output_tensor_names)159 Status CreateSplitResourceGraph(const TensorShape& input_shape,
160                                 absl::Span<const int32> num_splits,
161                                 absl::Span<const int32> paddings,
162                                 const int num_outputs, Graph* graph,
163                                 std::vector<std::string>* output_tensor_names) {
164   Node* var_handle = nullptr;
165   DataType data_type = DataTypeToEnum<int32>::value;
166   TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("var_handle"), "VarHandleOp")
167                          .Attr("dtype", data_type)
168                          .Attr("shape", input_shape)
169                          .Finalize(graph, &var_handle));
170 
171   Tensor input_tensor(data_type, input_shape);
172   test::FillIota<int32>(&input_tensor, /*val=*/0);
173   Node* input = test::graph::Constant(graph, input_tensor);
174 
175   Node* assign_var = nullptr;
176   TF_RETURN_IF_ERROR(
177       NodeBuilder(graph->NewName("assign_var"), "AssignVariableOp")
178           .Input(var_handle)
179           .Input(input)
180           .Attr("dtype", data_type)
181           .Finalize(graph, &assign_var));
182 
183   Node* xla_op = nullptr;
184   TF_RETURN_IF_ERROR(
185       NodeBuilder(graph->NewName("xla_op"), "ReadVariableXlaSplitND")
186           .Input(var_handle)
187           .ControlInput(assign_var)
188           .Attr("num_splits", num_splits)
189           .Attr("paddings", paddings)
190           .Attr("T", data_type)
191           .Attr("N", num_outputs)
192           .Finalize(graph, &xla_op));
193 
194   output_tensor_names->reserve(num_outputs);
195   for (int i = 0; i < num_outputs; ++i) {
196     output_tensor_names->push_back(absl::StrCat(xla_op->name(), ":", i));
197   }
198 
199   return OkStatus();
200 }
201 
202 struct XlaSplitNDTestParam {
203   std::string name;
204   std::function<Status(const TensorShape&, absl::Span<const int32>,
205                        absl::Span<const int32>, const int num_outputs, Graph*,
206                        std::vector<std::string>*)>
207       graph_creator;
208 };
209 
210 using XlaSplitNDOpTest = ::testing::TestWithParam<XlaSplitNDTestParam>;
211 
// A zero entry in 'num_splits' (here at index 2) must be rejected.
TEST_P(XlaSplitNDOpTest, SplitDimensionZero) {
  const TensorShape shape({1, 1, 1});
  const std::vector<int32> splits_per_dim{1, 1, 0};
  const std::vector<int32> no_paddings{};
  const int split_count = 1;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator(shape, splits_per_dim, no_paddings,
                                        split_count, &graph, &fetch_names));

  std::vector<Tensor> results;
  const Status run_status =
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &results);
  EXPECT_THAT(run_status, IsStatus(error::INVALID_ARGUMENT,
                                   "index 2 must be positive, but got 0"));
}
229 
// A negative entry in 'num_splits' (here -1 at index 1) must be rejected.
TEST_P(XlaSplitNDOpTest, SplitDimensionNegative) {
  const TensorShape shape({1, 1, 1});
  const std::vector<int32> splits_per_dim{1, -1, 1};
  const std::vector<int32> no_paddings{};
  const int split_count = 1;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator(shape, splits_per_dim, no_paddings,
                                        split_count, &graph, &fetch_names));

  std::vector<Tensor> results;
  const Status run_status =
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &results);
  EXPECT_THAT(run_status, IsStatus(error::INVALID_ARGUMENT,
                                   "index 1 must be positive, but got -1"));
}
247 
// 'num_splits' = {2} produces 2 slices, but the op is declared with N = 1;
// the mismatch must be reported.
TEST_P(XlaSplitNDOpTest, NumOutputsMismatch) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({2});
  const std::vector<int32> num_splits = {2};
  // int32 (not int) to match the graph_creator's absl::Span<const int32>
  // parameter and the other tests in this suite.
  const std::vector<int32> paddings;
  const int num_outputs = 1;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(
      RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
               &output_tensors),
      IsStatus(error::INVALID_ARGUMENT, "'N' must match number of slices 2"));
}
265 
// A rank-2 input with a single 'paddings' entry must be rejected: paddings
// must have one entry per dimension.
TEST_P(XlaSplitNDOpTest, PaddingsLengthMismatch) {
  const TensorShape shape({2, 2});
  const std::vector<int32> splits_per_dim{2, 2};
  const std::vector<int32> pad_amounts{0};
  const int split_count = 4;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator(shape, splits_per_dim, pad_amounts,
                                        split_count, &graph, &fetch_names));

  std::vector<Tensor> results;
  const Status run_status =
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &results);
  EXPECT_THAT(run_status,
              IsStatus(error::INVALID_ARGUMENT, "length 2, but got 1"));
}
282 
// A negative 'paddings' entry (-1 at index 1) must be rejected.
TEST_P(XlaSplitNDOpTest, PaddingsNegative) {
  const TensorShape shape({2, 2});
  const std::vector<int32> splits_per_dim{2, 2};
  const std::vector<int32> pad_amounts{0, -1};
  const int split_count = 4;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator(shape, splits_per_dim, pad_amounts,
                                        split_count, &graph, &fetch_names));

  std::vector<Tensor> results;
  const Status run_status =
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &results);
  EXPECT_THAT(run_status, IsStatus(error::INVALID_ARGUMENT,
                                   "non-negative, but got -1 at index 1"));
}
300 
// A rank-0 (scalar) input is outside the supported rank range (0, 8].
TEST_P(XlaSplitNDOpTest, InputRank0) {
  const TensorShape shape({});
  const std::vector<int32> splits_per_dim{2};
  const std::vector<int32> no_paddings{};
  const int split_count = 2;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator(shape, splits_per_dim, no_paddings,
                                        split_count, &graph, &fetch_names));

  std::vector<Tensor> results;
  const Status run_status =
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &results);
  EXPECT_THAT(run_status,
              IsStatus(error::INVALID_ARGUMENT, "range (0, 8], but got 0"));
}
317 
// A rank-9 input is one past the supported rank range (0, 8].
TEST_P(XlaSplitNDOpTest, InputRank9) {
  const int rank = 9;
  const TensorShape shape(std::vector<int64_t>(rank, 2));
  const std::vector<int32> splits_per_dim(rank, 2);
  const std::vector<int32> no_paddings{};
  const int split_count = 512;  // 2^9 slices.

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator(shape, splits_per_dim, no_paddings,
                                        split_count, &graph, &fetch_names));

  std::vector<Tensor> results;
  const Status run_status =
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &results);
  EXPECT_THAT(run_status,
              IsStatus(error::INVALID_ARGUMENT, "range (0, 8], but got 9"));
}
334 
// 'num_splits' has 3 entries but the input is rank 2; must be rejected.
TEST_P(XlaSplitNDOpTest, InputRankSplitMismatch) {
  const TensorShape shape({2, 2});
  const std::vector<int32> splits_per_dim{2, 2, 2};
  const std::vector<int32> no_paddings{};
  const int split_count = 8;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator(shape, splits_per_dim, no_paddings,
                                        split_count, &graph, &fetch_names));

  std::vector<Tensor> results;
  const Status run_status =
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &results);
  EXPECT_THAT(run_status, IsStatus(error::INVALID_ARGUMENT,
                                   "'num_splits' length 3, but got rank 2"));
}
352 
// Dimension 0 has size 4, which cannot be split evenly into 3 parts.
TEST_P(XlaSplitNDOpTest, DimNotEvenlySplit) {
  const TensorShape shape({4, 2});
  const std::vector<int32> splits_per_dim{3, 2};
  const std::vector<int32> no_paddings{};
  const int split_count = 6;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator(shape, splits_per_dim, no_paddings,
                                        split_count, &graph, &fetch_names));

  std::vector<Tensor> results;
  const Status run_status =
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &results);
  EXPECT_THAT(run_status, IsStatus(error::INVALID_ARGUMENT,
                                   "divisible by 'num_splits' 3"));
}
369 
// Padding dimension 1 (size 2) by 1 gives size 3, which is not divisible by
// its 2 splits.
TEST_P(XlaSplitNDOpTest, DimWithPaddingNotEvenlySplit) {
  const TensorShape shape({4, 2});
  const std::vector<int32> splits_per_dim{2, 2};
  const std::vector<int32> pad_amounts{0, 1};
  const int split_count = 4;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator(shape, splits_per_dim, pad_amounts,
                                        split_count, &graph, &fetch_names));

  std::vector<Tensor> results;
  const Status run_status =
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &results);
  EXPECT_THAT(run_status, IsStatus(error::INVALID_ARGUMENT,
                                   "divisible by 'num_splits' 2"));
}
386 
// With one split per dimension and no padding, the single output equals the
// (iota-filled) input unchanged.
TEST_P(XlaSplitNDOpTest, NoSplits) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({2, 2, 2});
  const std::vector<int32> num_splits = {1, 1, 1};
  // int32 (not int) to match the graph_creator's absl::Span<const int32>
  // parameter and the other tests in this suite.
  const std::vector<int32> paddings;
  const int num_outputs = 1;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  TF_ASSERT_OK(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                        &output_tensors));
  ASSERT_EQ(output_tensors.size(), 1);
  test::ExpectTensorEqual<int32>(
      output_tensors[0],
      test::AsTensor<int32>({0, 1, 2, 3, 4, 5, 6, 7}, TensorShape({2, 2, 2})));
}
406 
// With one split per dimension, the single output is the [2, 1, 1] input
// {0, 1} padded with zeros by one element at the end of dimensions 1 and 2,
// giving a [2, 2, 2] tensor.
TEST_P(XlaSplitNDOpTest, NoSplitsWithPadding) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({2, 1, 1});
  const std::vector<int32> num_splits = {1, 1, 1};
  // int32 (not int) to match the graph_creator's absl::Span<const int32>
  // parameter and the other tests in this suite.
  const std::vector<int32> paddings = {0, 1, 1};
  const int num_outputs = 1;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  TF_ASSERT_OK(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                        &output_tensors));
  ASSERT_EQ(output_tensors.size(), 1);
  test::ExpectTensorEqual<int32>(
      output_tensors[0],
      test::AsTensor<int32>({0, 0, 0, 0, 1, 0, 0, 0}, TensorShape({2, 2, 2})));
}
427 
// A 4x4 iota tensor split 2x2 yields four 2x2 tiles. Each tile is compared
// against the values the test expects at that output index.
TEST_P(XlaSplitNDOpTest, SplitNoPadding) {
  const TensorShape shape({4, 4});
  const std::vector<int32> splits_per_dim{2, 2};
  const std::vector<int32> no_paddings{};
  const int split_count = 4;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator(shape, splits_per_dim, no_paddings,
                                        split_count, &graph, &fetch_names));

  std::vector<Tensor> results;
  TF_ASSERT_OK(
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &results));
  ASSERT_EQ(results.size(), split_count);

  const TensorShape tile_shape({2, 2});
  const std::vector<Tensor> expected_tiles = {
      test::AsTensor<int32>({0, 1, 4, 5}, tile_shape),
      test::AsTensor<int32>({2, 3, 6, 7}, tile_shape),
      test::AsTensor<int32>({8, 9, 12, 13}, tile_shape),
      test::AsTensor<int32>({10, 11, 14, 15}, tile_shape),
  };
  for (size_t tile = 0; tile < expected_tiles.size(); ++tile) {
    test::ExpectTensorEqual<int32>(results[tile], expected_tiles[tile]);
  }
}
456 
// A 3x3 iota tensor padded by one row and one column (zero-filled, per the
// expected values) and split 2x2 yields four 2x2 tiles.
TEST_P(XlaSplitNDOpTest, SplitPartialPadding) {
  const TensorShape shape({3, 3});
  const std::vector<int32> splits_per_dim{2, 2};
  const std::vector<int32> pad_amounts{1, 1};
  const int split_count = 4;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator(shape, splits_per_dim, pad_amounts,
                                        split_count, &graph, &fetch_names));

  std::vector<Tensor> results;
  TF_ASSERT_OK(
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &results));
  ASSERT_EQ(results.size(), split_count);

  const TensorShape tile_shape({2, 2});
  const std::vector<Tensor> expected_tiles = {
      test::AsTensor<int32>({0, 1, 3, 4}, tile_shape),
      test::AsTensor<int32>({2, 0, 5, 0}, tile_shape),
      test::AsTensor<int32>({6, 7, 0, 0}, tile_shape),
      test::AsTensor<int32>({8, 0, 0, 0}, tile_shape),
  };
  for (size_t tile = 0; tile < expected_tiles.size(); ++tile) {
    test::ExpectTensorEqual<int32>(results[tile], expected_tiles[tile]);
  }
}
485 
// A 2x1 input padded by {2, 3} becomes 4x4. All real data lands in the first
// tile; the remaining three tiles consist entirely of padding zeros.
TEST_P(XlaSplitNDOpTest, SplitCompletePadding) {
  const TensorShape shape({2, 1});
  const std::vector<int32> splits_per_dim{2, 2};
  const std::vector<int32> pad_amounts{2, 3};
  const int split_count = 4;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator(shape, splits_per_dim, pad_amounts,
                                        split_count, &graph, &fetch_names));

  std::vector<Tensor> results;
  TF_ASSERT_OK(
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &results));
  ASSERT_EQ(results.size(), split_count);

  const TensorShape tile_shape({2, 2});
  const std::vector<Tensor> expected_tiles = {
      test::AsTensor<int32>({0, 0, 1, 0}, tile_shape),
      test::AsTensor<int32>({0, 0, 0, 0}, tile_shape),
      test::AsTensor<int32>({0, 0, 0, 0}, tile_shape),
      test::AsTensor<int32>({0, 0, 0, 0}, tile_shape),
  };
  for (size_t tile = 0; tile < expected_tiles.size(); ++tile) {
    test::ExpectTensorEqual<int32>(results[tile], expected_tiles[tile]);
  }
}
514 
515 INSTANTIATE_TEST_SUITE_P(
516     XlaSplitNDOpTest, XlaSplitNDOpTest,
517     ::testing::ValuesIn<XlaSplitNDTestParam>(
518         {{"Tensor", CreateSplitTensorGraph},
519          {"Resource", CreateSplitResourceGraph}}),
__anon3d7075df0202(const ::testing::TestParamInfo<XlaSplitNDOpTest::ParamType>& info) 520     [](const ::testing::TestParamInfo<XlaSplitNDOpTest::ParamType>& info) {
521       return info.param.name;
522     });
523 
524 struct RankedXlaSplitNDTestParam {
525   std::string name;
526   int rank = 0;
527   std::function<Status(const TensorShape&, absl::Span<const int32>,
528                        absl::Span<const int32>, const int num_outputs, Graph*,
529                        std::vector<std::string>*)>
530       graph_creator;
531 };
532 
533 class RankedXlaSplitNDOpTest
534     : public ::testing::TestWithParam<RankedXlaSplitNDTestParam> {};
535 
// Splits an iota-filled input of shape [2] * rank in half along every
// dimension, producing 2^rank single-element slices. Checks that slice i holds
// the value i, i.e. that output i corresponds to flat input index i.
TEST_P(RankedXlaSplitNDOpTest, TestSubscriptRank) {
  const int rank = GetParam().rank;
  const std::vector<int32> num_splits(rank, 2);

  Graph graph(OpRegistry::Global());
  const TensorShape input_shape(std::vector<int64_t>(rank, 2));
  const std::vector<int32> paddings;
  // 2^rank slices. Written as `1 << rank` instead of `2 << (rank - 1)`: the
  // two agree for rank >= 1, but the latter is a negative shift (UB) at 0.
  const int num_outputs = 1 << rank;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  TF_ASSERT_OK(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                        &output_tensors));
  // Compare as size_t to avoid a signed/unsigned comparison inside gtest.
  ASSERT_EQ(output_tensors.size(), static_cast<size_t>(num_outputs));
  const TensorShape output_shape(std::vector<int64_t>(rank, 1));
  for (int i = 0; i < num_outputs; ++i) {
    test::ExpectTensorEqual<int32>(output_tensors[i],
                                   test::AsTensor<int32>({i}, output_shape));
  }
}
559 
560 INSTANTIATE_TEST_SUITE_P(
561     RankedXlaSplitNDOpTest, RankedXlaSplitNDOpTest,
562     ::testing::ValuesIn<RankedXlaSplitNDTestParam>(
563         {{"TensorRanked1", 1, CreateSplitTensorGraph},
564          {"TensorRanked2", 2, CreateSplitTensorGraph},
565          {"TensorRanked3", 3, CreateSplitTensorGraph},
566          {"TensorRanked4", 4, CreateSplitTensorGraph},
567          {"TensorRanked5", 5, CreateSplitTensorGraph},
568          {"TensorRanked6", 6, CreateSplitTensorGraph},
569          {"TensorRanked7", 7, CreateSplitTensorGraph},
570          {"TensorRanked8", 8, CreateSplitTensorGraph},
571          {"ResourceRanked1", 1, CreateSplitResourceGraph},
572          {"ResourceRanked2", 2, CreateSplitResourceGraph},
573          {"ResourceRanked3", 3, CreateSplitResourceGraph},
574          {"ResourceRanked4", 4, CreateSplitResourceGraph},
575          {"ResourceRanked5", 5, CreateSplitResourceGraph},
576          {"ResourceRanked6", 6, CreateSplitResourceGraph},
577          {"ResourceRanked7", 7, CreateSplitResourceGraph},
578          {"ResourceRanked8", 8, CreateSplitResourceGraph}}),
579     [](const ::testing::TestParamInfo<RankedXlaSplitNDOpTest::ParamType>&
__anon3d7075df0302(const ::testing::TestParamInfo<RankedXlaSplitNDOpTest::ParamType>& info) 580            info) { return info.param.name; });
581 
// The handle is declared int32 (unknown shape) while the concatenated update
// is float; the assignment must fail with a dtype mismatch.
TEST(AssignVariableXlaConcatNDOpTest, HandleDTypeInvalid) {
  Graph graph(OpRegistry::Global());
  const DataType variable_dtype = DataTypeToEnum<int32>::value;
  const DataType update_dtype = DataTypeToEnum<float>::value;

  Node* handle = nullptr;
  PartialTensorShape unknown_shape;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("var_handle"), "VarHandleOp")
                   .Attr("dtype", variable_dtype)
                   .Attr("shape", unknown_shape)
                   .Finalize(&graph, &handle));

  Tensor update_values(update_dtype, TensorShape({4, 4}));
  test::FillIota<float>(&update_values, /*val=*/0.f);
  Node* update_node = test::graph::Constant(&graph, update_values);

  const std::vector<int32> concats_per_dim = {1, 1};
  const int piece_count = 1;
  Node* concat_op = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("xla_op"), "AssignVariableXlaConcatND")
                   .Input(handle)
                   .Input(std::vector<NodeBuilder::NodeOut>{update_node})
                   .Attr("num_concats", concats_per_dim)
                   .Attr("T", update_dtype)
                   .Attr("N", piece_count)
                   .Finalize(&graph, &concat_op));

  std::vector<Tensor> fetched;
  const Status run_status =
      RunGraph(graph, /*output_tensor_names=*/{},
               /*target_tensor_names=*/{concat_op->name()}, &fetched);
  EXPECT_THAT(run_status,
              IsStatus(error::INVALID_ARGUMENT, "dtype int32, but got float"));
}
613 
// The handle is declared float, but the variable is first assigned int32
// data; a later float concat-assign must then fail with a dtype mismatch.
TEST(AssignVariableXlaConcatNDOpTest, TensorDTypeInvalid) {
  Graph graph(OpRegistry::Global());
  const DataType handle_dtype = DataTypeToEnum<float>::value;
  const DataType stored_dtype = DataTypeToEnum<int32>::value;
  const DataType update_dtype = DataTypeToEnum<float>::value;
  const TensorShape value_shape({4, 4});

  Node* handle = nullptr;
  PartialTensorShape unknown_shape;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("var_handle"), "VarHandleOp")
                   .Attr("dtype", handle_dtype)
                   .Attr("shape", unknown_shape)
                   .Finalize(&graph, &handle));

  // First assignment stores int32 values into the variable.
  Tensor stored_values(stored_dtype, value_shape);
  test::FillIota<int32>(&stored_values, /*val=*/0);
  Node* stored_node = test::graph::Constant(&graph, stored_values);

  Node* initial_assign = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("assign_var"), "AssignVariableOp")
                   .Input(handle)
                   .Input(stored_node)
                   .Attr("dtype", stored_dtype)
                   .Finalize(&graph, &initial_assign));

  Tensor update_values(update_dtype, value_shape);
  test::FillIota<float>(&update_values, /*val=*/0.f);
  Node* update_node = test::graph::Constant(&graph, update_values);

  const std::vector<int32> concats_per_dim = {1, 1};
  const int piece_count = 1;
  Node* concat_op = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("xla_op"), "AssignVariableXlaConcatND")
                   .Input(handle)
                   .Input(std::vector<NodeBuilder::NodeOut>{update_node})
                   .ControlInput(initial_assign)
                   .Attr("num_concats", concats_per_dim)
                   .Attr("T", update_dtype)
                   .Attr("N", piece_count)
                   .Finalize(&graph, &concat_op));

  std::vector<Tensor> fetched;
  const Status run_status =
      RunGraph(graph, /*output_tensor_names=*/{},
               /*target_tensor_names=*/{concat_op->name()}, &fetched);
  EXPECT_THAT(run_status,
              IsStatus(error::INVALID_ARGUMENT, "dtype int32, but got float"));
}
662 
// The handle's declared shape is rank 0 ([]), but concatenating a single
// 4x4 piece produces a [4,4] value; the shape mismatch must be reported.
TEST(AssignVariableXlaConcatNDOpTest, HandleShapeIncompatible) {
  Graph graph(OpRegistry::Global());
  const DataType element_dtype = DataTypeToEnum<float>::value;

  Node* handle = nullptr;
  // `({})` selects the initializer-list constructor, i.e. a known rank-0
  // shape, not an unknown-rank default shape.
  PartialTensorShape scalar_shape({});
  TF_ASSERT_OK(NodeBuilder(graph.NewName("var_handle"), "VarHandleOp")
                   .Attr("dtype", element_dtype)
                   .Attr("shape", scalar_shape)
                   .Finalize(&graph, &handle));

  Tensor update_values(element_dtype, TensorShape({4, 4}));
  test::FillIota<float>(&update_values, /*val=*/0.f);
  Node* update_node = test::graph::Constant(&graph, update_values);

  const std::vector<int32> concats_per_dim = {1, 1};
  const int piece_count = 1;
  Node* concat_op = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("xla_op"), "AssignVariableXlaConcatND")
                   .Input(handle)
                   .Input(std::vector<NodeBuilder::NodeOut>{update_node})
                   .Attr("num_concats", concats_per_dim)
                   .Attr("T", element_dtype)
                   .Attr("N", piece_count)
                   .Finalize(&graph, &concat_op));

  std::vector<Tensor> fetched;
  const Status run_status =
      RunGraph(graph, /*output_tensor_names=*/{},
               /*target_tensor_names=*/{concat_op->name()}, &fetched);
  EXPECT_THAT(run_status, IsStatus(error::INVALID_ARGUMENT,
                                   "expected shape [4,4], but got []"));
}
697 
TEST(AssignVariableXlaConcatNDOpTest, HandleShapeWithPaddingIncompatible) {
  // The variable is declared [4, 4]. Concatenating a single [4, 4] slice with
  // paddings {1, 1} yields a [3, 3] value, which must be rejected as
  // incompatible with the variable's declared shape.
  Graph graph(OpRegistry::Global());

  const PartialTensorShape declared_shape({4, 4});
  Node* resource = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("var_handle"), "VarHandleOp")
                   .Attr("dtype", DT_FLOAT)
                   .Attr("shape", declared_shape)
                   .Finalize(&graph, &resource));

  Tensor slice(DT_FLOAT, TensorShape({4, 4}));
  test::FillIota<float>(&slice, /*val=*/0.f);
  Node* slice_node = test::graph::Constant(&graph, slice);

  const std::vector<int32> concats_per_dim = {1, 1};
  const std::vector<int32> trailing_padding = {1, 1};
  Node* concat_assign = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("xla_op"), "AssignVariableXlaConcatND")
                   .Input(resource)
                   .Input(std::vector<NodeBuilder::NodeOut>{slice_node})
                   .Attr("num_concats", concats_per_dim)
                   .Attr("paddings", trailing_padding)
                   .Attr("T", DT_FLOAT)
                   .Attr("N", 1)
                   .Finalize(&graph, &concat_assign));

  std::vector<Tensor> unused_outputs;
  const auto run_status =
      RunGraph(graph, /*output_tensor_names=*/{},
               /*target_tensor_names=*/{concat_assign->name()}, &unused_outputs);
  EXPECT_THAT(run_status, IsStatus(error::INVALID_ARGUMENT,
                                   "expected shape [3,3], but got [4,4]"));
}
734 
TEST(AssignVariableXlaConcatNDOpTest, AssignDifferentShape) {
  // The variable is declared with shape [4, -1] and first initialized with a
  // [4, 2] tensor of -1s. A later concat-assign of a [4, 4] iota tensor must
  // overwrite it, so reading the variable back yields exactly the new value.
  Graph graph(OpRegistry::Global());
  const DataType dtype = DT_FLOAT;

  Node* resource = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("var_handle"), "VarHandleOp")
                   .Attr("dtype", dtype)
                   .Attr("shape", PartialTensorShape({4, -1}))
                   .Finalize(&graph, &resource));

  Tensor initial_value(dtype, TensorShape({4, 2}));
  test::FillFn<float>(&initial_value, [](int) { return -1.f; });
  Node* initial_node = test::graph::Constant(&graph, initial_value);

  Node* init_assign = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("assign_var"), "AssignVariableOp")
                   .Input(resource)
                   .Input(initial_node)
                   .Attr("dtype", dtype)
                   .Finalize(&graph, &init_assign));

  Tensor new_value(dtype, TensorShape({4, 4}));
  test::FillIota<float>(&new_value, /*val=*/0.f);
  Node* new_value_node = test::graph::Constant(&graph, new_value);

  // Ordered after the initial assignment via a control edge.
  const std::vector<int32> concats_per_dim = {1, 1};
  Node* concat_assign = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("xla_op"), "AssignVariableXlaConcatND")
                   .Input(resource)
                   .Input(std::vector<NodeBuilder::NodeOut>{new_value_node})
                   .ControlInput(init_assign)
                   .Attr("num_concats", concats_per_dim)
                   .Attr("T", dtype)
                   .Attr("N", 1)
                   .Finalize(&graph, &concat_assign));

  // Read back only after the concat-assign has run.
  Node* reader = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("read_var"), "ReadVariableOp")
                   .Input(resource)
                   .ControlInput(concat_assign)
                   .Attr("dtype", dtype)
                   .Finalize(&graph, &reader));

  std::vector<Tensor> fetched;
  TF_ASSERT_OK(RunGraph(
      graph, /*output_tensor_names=*/{absl::StrCat(reader->name(), ":", 0)},
      /*target_tensor_names=*/{}, &fetched));
  ASSERT_EQ(fetched.size(), 1);
  test::ExpectTensorNear<float>(fetched[0], new_value, /*atol=*/1e-6);
}
789 
CreateConcatTensorGraph(absl::Span<const TensorShape> input_shapes,absl::Span<const int32> num_concats,absl::Span<const int32> paddings,Graph * graph,std::vector<std::string> * output_tensor_names)790 Status CreateConcatTensorGraph(absl::Span<const TensorShape> input_shapes,
791                                absl::Span<const int32> num_concats,
792                                absl::Span<const int32> paddings, Graph* graph,
793                                std::vector<std::string>* output_tensor_names) {
794   int32_t val = 0;
795   DataType data_type = DataTypeToEnum<int32>::value;
796   std::vector<NodeBuilder::NodeOut> inputs;
797   inputs.reserve(input_shapes.size());
798   for (const TensorShape& input_shape : input_shapes) {
799     Tensor input_tensor(data_type, input_shape);
800     test::FillIota<int32>(&input_tensor, val);
801     val += input_tensor.NumElements();
802     inputs.push_back(test::graph::Constant(graph, input_tensor));
803   }
804 
805   Node* xla_op = nullptr;
806   TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("xla_op"), "XlaConcatND")
807                          .Input(inputs)
808                          .Attr("num_concats", num_concats)
809                          .Attr("paddings", paddings)
810                          .Attr("T", data_type)
811                          .Attr("N", static_cast<int64_t>(input_shapes.size()))
812                          .Finalize(graph, &xla_op));
813 
814   output_tensor_names->push_back(absl::StrCat(xla_op->name(), ":", 0));
815 
816   return OkStatus();
817 }
818 
819 template <bool Init>
CreateConcatResourceGraph(absl::Span<const TensorShape> input_shapes,absl::Span<const int32> num_concats,absl::Span<const int32> paddings,Graph * graph,std::vector<std::string> * output_tensor_names)820 Status CreateConcatResourceGraph(
821     absl::Span<const TensorShape> input_shapes,
822     absl::Span<const int32> num_concats, absl::Span<const int32> paddings,
823     Graph* graph, std::vector<std::string>* output_tensor_names) {
824   Node* var_handle = nullptr;
825   DataType data_type = DataTypeToEnum<int32>::value;
826   TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("var_handle"), "VarHandleOp")
827                          .Attr("dtype", data_type)
828                          .Attr("shape", PartialTensorShape())
829                          .Finalize(graph, &var_handle));
830 
831   Node* assign_var = nullptr;
832   if (Init) {
833     Tensor init_input_tensor(data_type, input_shapes.front());
834     test::FillFn<int32>(&init_input_tensor, [](int unused) { return -1; });
835     Node* init_input = test::graph::Constant(graph, init_input_tensor);
836 
837     TF_RETURN_IF_ERROR(
838         NodeBuilder(graph->NewName("assign_var"), "AssignVariableOp")
839             .Input(var_handle)
840             .Input(init_input)
841             .Attr("dtype", data_type)
842             .Finalize(graph, &assign_var));
843   }
844 
845   int32_t val = 0;
846   std::vector<NodeBuilder::NodeOut> inputs;
847   inputs.reserve(input_shapes.size());
848   for (const TensorShape& input_shape : input_shapes) {
849     Tensor input_tensor(data_type, input_shape);
850     test::FillIota<int32>(&input_tensor, val);
851     val += input_tensor.NumElements();
852     inputs.push_back(test::graph::Constant(graph, input_tensor));
853   }
854 
855   Node* xla_op = nullptr;
856   NodeBuilder builder(graph->NewName("xla_op"), "AssignVariableXlaConcatND");
857   builder.Input(var_handle);
858   builder.Input(inputs);
859   if (assign_var != nullptr) {
860     builder.ControlInput(assign_var);
861   }
862   TF_RETURN_IF_ERROR(builder.Attr("num_concats", num_concats)
863                          .Attr("paddings", paddings)
864                          .Attr("T", data_type)
865                          .Attr("N", static_cast<int64_t>(input_shapes.size()))
866                          .Finalize(graph, &xla_op));
867 
868   Node* read_var = nullptr;
869   TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("read_var"), "ReadVariableOp")
870                          .Input(var_handle)
871                          .ControlInput(xla_op)
872                          .Attr("dtype", data_type)
873                          .Finalize(graph, &read_var));
874 
875   output_tensor_names->push_back(absl::StrCat(read_var->name(), ":", 0));
876 
877   return OkStatus();
878 }
879 
880 struct XlaConcatNDTestParam {
881   std::string name;
882   std::function<Status(absl::Span<const TensorShape>, absl::Span<const int32>,
883                        absl::Span<const int32>, Graph*,
884                        std::vector<std::string>*)>
885       graph_creator;
886 };
887 
888 using XlaConcatNDOpTest = ::testing::TestWithParam<XlaConcatNDTestParam>;
889 
TEST_P(XlaConcatNDOpTest, ConcatDimensionZero) {
  // A zero entry in `num_concats` (last dimension) builds fine but must be
  // rejected when the graph runs.
  const TensorShape slice_shape({1, 1, 1});
  const std::vector<int32> concats_per_dim = {1, 1, 0};
  const std::vector<int32> no_paddings;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator({slice_shape}, concats_per_dim,
                                        no_paddings, &graph, &fetch_names));

  std::vector<Tensor> fetched;
  const auto run_status = RunGraph(graph, fetch_names,
                                   /*target_tensor_names=*/{}, &fetched);
  EXPECT_THAT(run_status, IsStatus(error::INVALID_ARGUMENT,
                                   "index 2 must be positive, but got 0"));
}
905 
TEST_P(XlaConcatNDOpTest, ConcatDimensionNegative) {
  // A negative entry in `num_concats` (dimension 1) builds fine but must be
  // rejected when the graph runs.
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({1, 1, 1});
  // Named after the `num_concats` attribute it feeds, matching the sibling
  // concat tests (was misleadingly called `num_splits`).
  const std::vector<int32> num_concats = {1, -1, 1};
  const std::vector<int32> paddings;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator({input_shape}, num_concats, paddings,
                                        &graph, &output_tensor_names));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(RunGraph(graph, output_tensor_names,
                       /*target_tensor_names=*/{}, &output_tensors),
              IsStatus(error::INVALID_ARGUMENT,
                       "index 1 must be positive, but got -1"));
}
921 
TEST_P(XlaConcatNDOpTest, NumInputsMismatch) {
  // num_concats = {2} describes two slices, but the graph creator receives
  // only one input shape (so N == 1). The op must report the mismatch.
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({2});
  const std::vector<int32> num_concats = {2};
  // int32 elements to match the absl::Span<const int32> the graph creators
  // take, as in the other tests of this suite.
  const std::vector<int32> paddings;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator({input_shape}, num_concats, paddings,
                                        &graph, &output_tensor_names));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(
      RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
               &output_tensors),
      IsStatus(error::INVALID_ARGUMENT, "'N' must match number of slices 2"));
}
937 
TEST_P(XlaConcatNDOpTest, PaddingsLengthMismatch) {
  // Two concat dimensions but only one padding entry: the length mismatch
  // must be rejected when the graph runs.
  const TensorShape slice_shape({2, 2});
  const std::vector<int32> concats_per_dim = {1, 1};
  const std::vector<int32> short_paddings = {0};

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator({slice_shape}, concats_per_dim,
                                        short_paddings, &graph, &fetch_names));

  std::vector<Tensor> fetched;
  const auto run_status = RunGraph(graph, fetch_names,
                                   /*target_tensor_names=*/{}, &fetched);
  EXPECT_THAT(run_status,
              IsStatus(error::INVALID_ARGUMENT, "length 2, but got 1"));
}
952 
TEST_P(XlaConcatNDOpTest, PaddingsNegative) {
  // A negative padding (dimension 1) must be rejected when the graph runs.
  const TensorShape slice_shape({2, 2});
  const std::vector<int32> concats_per_dim = {1, 1};
  const std::vector<int32> bad_paddings = {0, -1};

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator({slice_shape}, concats_per_dim,
                                        bad_paddings, &graph, &fetch_names));

  std::vector<Tensor> fetched;
  const auto run_status = RunGraph(graph, fetch_names,
                                   /*target_tensor_names=*/{}, &fetched);
  EXPECT_THAT(run_status, IsStatus(error::INVALID_ARGUMENT,
                                   "non-negative, but got -1 at index 1"));
}
968 
TEST_P(XlaConcatNDOpTest, InputRank0) {
  // Scalar inputs (rank 0) fall outside the supported rank range (0, 8].
  const TensorShape scalar_shape({});
  const std::vector<int32> no_concats;
  const std::vector<int32> no_paddings;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator({scalar_shape}, no_concats,
                                        no_paddings, &graph, &fetch_names));

  std::vector<Tensor> fetched;
  const auto run_status = RunGraph(graph, fetch_names,
                                   /*target_tensor_names=*/{}, &fetched);
  EXPECT_THAT(run_status,
              IsStatus(error::INVALID_ARGUMENT, "range (0, 8], but got 0"));
}
983 
TEST_P(XlaConcatNDOpTest, InputRank9) {
  // Rank-9 inputs exceed the supported rank range (0, 8].
  const TensorShape rank9_shape(std::vector<int64_t>(9, 1));
  const std::vector<int32> concats_per_dim(9, 1);
  const std::vector<int32> no_paddings;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator({rank9_shape}, concats_per_dim,
                                        no_paddings, &graph, &fetch_names));

  std::vector<Tensor> fetched;
  const auto run_status = RunGraph(graph, fetch_names,
                                   /*target_tensor_names=*/{}, &fetched);
  EXPECT_THAT(run_status,
              IsStatus(error::INVALID_ARGUMENT, "range (0, 8], but got 9"));
}
998 
TEST_P(XlaConcatNDOpTest, InputRankConcatMismatch) {
  // `num_concats` has two entries, but the input is rank 1.
  const TensorShape rank1_shape({1});
  const std::vector<int32> concats_per_dim = {1, 1};
  const std::vector<int32> no_paddings;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator({rank1_shape}, concats_per_dim,
                                        no_paddings, &graph, &fetch_names));

  std::vector<Tensor> fetched;
  const auto run_status = RunGraph(graph, fetch_names,
                                   /*target_tensor_names=*/{}, &fetched);
  EXPECT_THAT(run_status, IsStatus(error::INVALID_ARGUMENT,
                                   "'num_concats' length 2, but got rank 1"));
}
1014 
TEST_P(XlaConcatNDOpTest, DifferentShapedInputs) {
  // Every slice must share the shape of the first one; the second slice here
  // is [2] instead of [1].
  const std::vector<TensorShape> slice_shapes = {TensorShape({1}),
                                                 TensorShape({2})};
  const std::vector<int32> concats_per_dim = {2};
  const std::vector<int32> no_paddings;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator(slice_shapes, concats_per_dim,
                                        no_paddings, &graph, &fetch_names));

  std::vector<Tensor> fetched;
  const auto run_status = RunGraph(graph, fetch_names,
                                   /*target_tensor_names=*/{}, &fetched);
  EXPECT_THAT(run_status,
              IsStatus(error::INVALID_ARGUMENT,
                       "same expected shape [1], but got [2] at index 1"));
}
1030 
TEST_P(XlaConcatNDOpTest, PaddingExceedsOutputDimSize) {
  // A padding of 2 on a dimension whose concatenated size is 1 is rejected.
  const TensorShape slice_shape({1});
  const std::vector<int32> concats_per_dim = {1};
  const std::vector<int32> oversized_padding = {2};

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator({slice_shape}, concats_per_dim,
                                        oversized_padding, &graph,
                                        &fetch_names));

  std::vector<Tensor> fetched;
  const auto run_status = RunGraph(graph, fetch_names,
                                   /*target_tensor_names=*/{}, &fetched);
  EXPECT_THAT(
      run_status,
      IsStatus(
          error::INVALID_ARGUMENT,
          "exceed expected output shape dimension 1 at index 0, but got 2"));
}
1048 
TEST_P(XlaConcatNDOpTest, NoConcats) {
  // With one slice per dimension and no padding, the output must equal the
  // single [2, 2, 2] input (iota 0..7).
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({2, 2, 2});
  const std::vector<int32> num_concats = {1, 1, 1};
  // int32 elements to match the absl::Span<const int32> the graph creators
  // take, as in the other tests of this suite.
  const std::vector<int32> paddings;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator({input_shape}, num_concats, paddings,
                                        &graph, &output_tensor_names));

  std::vector<Tensor> output_tensors;
  TF_ASSERT_OK(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                        &output_tensors));
  ASSERT_EQ(output_tensors.size(), 1);
  test::ExpectTensorEqual<int32>(
      output_tensors[0],
      test::AsTensor<int32>({0, 1, 2, 3, 4, 5, 6, 7}, TensorShape({2, 2, 2})));
}
1066 
TEST_P(XlaConcatNDOpTest, NoConcatsWithPadding) {
  // A single [2, 2, 2] slice with padding 1 in every dimension leaves a
  // [1, 1, 1] output holding only the first element.
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({2, 2, 2});
  const std::vector<int32> num_concats = {1, 1, 1};
  // int32 elements to match the absl::Span<const int32> the graph creators
  // take, as in the other tests of this suite.
  const std::vector<int32> paddings = {1, 1, 1};
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator({input_shape}, num_concats, paddings,
                                        &graph, &output_tensor_names));

  std::vector<Tensor> output_tensors;
  TF_ASSERT_OK(RunGraph(graph, output_tensor_names,
                        /*target_tensor_names=*/{}, &output_tensors));
  ASSERT_EQ(output_tensors.size(), 1);
  test::ExpectTensorEqual<int32>(
      output_tensors[0], test::AsTensor<int32>({0}, TensorShape({1, 1, 1})));
}
1083 
TEST_P(XlaConcatNDOpTest, ConcatNoPadding) {
  // Four [2, 2] slices (values 0..3, 4..7, 8..11, 12..15) are tiled 2x2 into
  // a [4, 4] output: slice 0 top-left, 1 top-right, 2 bottom-left,
  // 3 bottom-right.
  const TensorShape slice_shape({2, 2});
  const std::vector<TensorShape> slice_shapes(4, slice_shape);
  const std::vector<int32> concats_per_dim = {2, 2};
  const std::vector<int32> no_paddings;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator(slice_shapes, concats_per_dim,
                                        no_paddings, &graph, &fetch_names));

  std::vector<Tensor> fetched;
  TF_ASSERT_OK(RunGraph(graph, fetch_names,
                        /*target_tensor_names=*/{}, &fetched));
  ASSERT_EQ(fetched.size(), 1);

  const Tensor expected = test::AsTensor<int32>(
      {0, 1, 4, 5,     //
       2, 3, 6, 7,     //
       8, 9, 12, 13,   //
       10, 11, 14, 15},
      TensorShape({4, 4}));
  test::ExpectTensorEqual<int32>(fetched[0], expected);
}
1102 
TEST_P(XlaConcatNDOpTest, ConcatPartialPadding) {
  // Same 2x2 tiling of four [2, 2] slices as ConcatNoPadding. The expected
  // values show that padding {1, 1} drops the last row and last column,
  // leaving a [3, 3] output.
  const TensorShape slice_shape({2, 2});
  const std::vector<TensorShape> slice_shapes(4, slice_shape);
  const std::vector<int32> concats_per_dim = {2, 2};
  const std::vector<int32> trailing_padding = {1, 1};

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator(slice_shapes, concats_per_dim,
                                        trailing_padding, &graph,
                                        &fetch_names));

  std::vector<Tensor> fetched;
  TF_ASSERT_OK(RunGraph(graph, fetch_names,
                        /*target_tensor_names=*/{}, &fetched));
  ASSERT_EQ(fetched.size(), 1);

  const Tensor expected = test::AsTensor<int32>({0, 1, 4,  //
                                                 2, 3, 6,  //
                                                 8, 9, 12},
                                                TensorShape({3, 3}));
  test::ExpectTensorEqual<int32>(fetched[0], expected);
}
1120 
TEST_P(XlaConcatNDOpTest, ConcatCompletePadding) {
  // Padding {2, 2} removes every slice except the top-left one, so the output
  // is just the first [2, 2] slice.
  const TensorShape slice_shape({2, 2});
  const std::vector<TensorShape> slice_shapes(4, slice_shape);
  const std::vector<int32> concats_per_dim = {2, 2};
  const std::vector<int32> full_slice_padding = {2, 2};

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator(slice_shapes, concats_per_dim,
                                        full_slice_padding, &graph,
                                        &fetch_names));

  std::vector<Tensor> fetched;
  TF_ASSERT_OK(RunGraph(graph, fetch_names,
                        /*target_tensor_names=*/{}, &fetched));
  ASSERT_EQ(fetched.size(), 1);

  const Tensor expected =
      test::AsTensor<int32>({0, 1, 2, 3}, TensorShape({2, 2}));
  test::ExpectTensorEqual<int32>(fetched[0], expected);
}
1138 
1139 INSTANTIATE_TEST_SUITE_P(
1140     XlaConcatNDOpTest, XlaConcatNDOpTest,
1141     ::testing::ValuesIn<XlaConcatNDTestParam>(
1142         {{"Tensor", CreateConcatTensorGraph},
1143          {"InitializedResource", CreateConcatResourceGraph<true>},
1144          {"UninitializedResource", CreateConcatResourceGraph<false>}}),
__anon3d7075df0602(const ::testing::TestParamInfo<XlaConcatNDOpTest::ParamType>& info) 1145     [](const ::testing::TestParamInfo<XlaConcatNDOpTest::ParamType>& info) {
1146       return info.param.name;
1147     });
1148 
1149 struct RankedXlaConcatNDTestParam {
1150   std::string name;
1151   int rank = 0;
1152   std::function<Status(absl::Span<const TensorShape>, absl::Span<const int32>,
1153                        absl::Span<const int32>, Graph*,
1154                        std::vector<std::string>*)>
1155       graph_creator;
1156 };
1157 
1158 class RankedXlaConcatNDOpTest
1159     : public ::testing::TestWithParam<RankedXlaConcatNDTestParam> {};
1160 
TEST_P(RankedXlaConcatNDOpTest, TestSubscriptRank) {
  // Concatenates 2^rank single-element slices, two per dimension, into an
  // output of shape [2, 2, ..., 2] (rank dims). Checks that slice i's value
  // (i, from the running iota) lands at flat index i.
  const int rank = GetParam().rank;
  const std::vector<int32> num_concats(rank, 2);

  Graph graph(OpRegistry::Global());
  // Two slices per dimension => 2^rank slices. `1 << rank` equals the former
  // `2 << (rank - 1)` for rank >= 1, and stays well-defined for rank == 0,
  // where the old form shifted by a negative amount (undefined behavior).
  const int num_inputs = 1 << rank;
  const TensorShape base_input_shape(std::vector<int64_t>(rank, 1));
  const std::vector<TensorShape> input_shapes(num_inputs, base_input_shape);
  const std::vector<int32> paddings;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shapes, num_concats, paddings,
                                        &graph, &output_tensor_names));

  std::vector<Tensor> output_tensors;
  TF_ASSERT_OK(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                        &output_tensors));
  ASSERT_EQ(output_tensors.size(), 1);
  std::vector<int32> expected_values(num_inputs);
  std::iota(expected_values.begin(), expected_values.end(), 0);
  test::ExpectTensorEqual<int32>(
      output_tensors[0],
      test::AsTensor<int32>(expected_values,
                            TensorShape(std::vector<int64_t>(rank, 2))));
}
1185 
1186 INSTANTIATE_TEST_SUITE_P(
1187     RankedXlaConcatNDOpTest, RankedXlaConcatNDOpTest,
1188     ::testing::ValuesIn<RankedXlaConcatNDTestParam>(
1189         {{"TensorRanked1", 1, CreateConcatTensorGraph},
1190          {"TensorRanked2", 2, CreateConcatTensorGraph},
1191          {"TensorRanked3", 3, CreateConcatTensorGraph},
1192          {"TensorRanked4", 4, CreateConcatTensorGraph},
1193          {"TensorRanked5", 5, CreateConcatTensorGraph},
1194          {"TensorRanked6", 6, CreateConcatTensorGraph},
1195          {"TensorRanked7", 7, CreateConcatTensorGraph},
1196          {"TensorRanked8", 8, CreateConcatTensorGraph},
1197          {"InitializedResourceRanked1", 1, CreateConcatResourceGraph<true>},
1198          {"InitializedResourceRanked2", 2, CreateConcatResourceGraph<true>},
1199          {"InitializedResourceRanked3", 3, CreateConcatResourceGraph<true>},
1200          {"InitializedResourceRanked4", 4, CreateConcatResourceGraph<true>},
1201          {"InitializedResourceRanked5", 5, CreateConcatResourceGraph<true>},
1202          {"InitializedResourceRanked6", 6, CreateConcatResourceGraph<true>},
1203          {"InitializedResourceRanked7", 7, CreateConcatResourceGraph<true>},
1204          {"InitializedResourceRanked8", 8, CreateConcatResourceGraph<true>},
1205          {"UninitializedResourceRanked1", 1, CreateConcatResourceGraph<false>},
1206          {"UninitializedResourceRanked2", 2, CreateConcatResourceGraph<false>},
1207          {"UninitializedResourceRanked3", 3, CreateConcatResourceGraph<false>},
1208          {"UninitializedResourceRanked4", 4, CreateConcatResourceGraph<false>},
1209          {"UninitializedResourceRanked5", 5, CreateConcatResourceGraph<false>},
1210          {"UninitializedResourceRanked6", 6, CreateConcatResourceGraph<false>},
1211          {"UninitializedResourceRanked7", 7, CreateConcatResourceGraph<false>},
1212          {"UninitializedResourceRanked8", 8,
1213           CreateConcatResourceGraph<false>}}),
1214     [](const ::testing::TestParamInfo<RankedXlaConcatNDOpTest::ParamType>&
__anon3d7075df0702(const ::testing::TestParamInfo<RankedXlaConcatNDOpTest::ParamType>& info) 1215            info) { return info.param.name; });
1216 
CreateRoundtripTensorGraph(const TensorShape & input_shape,absl::Span<const int32> num_partitions,absl::Span<const int32> paddings,Graph * graph,std::vector<std::string> * output_tensor_names)1217 Status CreateRoundtripTensorGraph(
1218     const TensorShape& input_shape, absl::Span<const int32> num_partitions,
1219     absl::Span<const int32> paddings, Graph* graph,
1220     std::vector<std::string>* output_tensor_names) {
1221   const int32_t num_partitions_size =
1222       std::accumulate(num_partitions.begin(), num_partitions.end(), 1,
1223                       std::multiplies<int32>());
1224 
1225   DataType data_type = DataTypeToEnum<int32>::value;
1226   Tensor input_tensor(data_type, input_shape);
1227   test::FillIota<int32>(&input_tensor, /*val=*/0);
1228   Node* input = test::graph::Constant(graph, input_tensor);
1229 
1230   Node* xla_split_op = nullptr;
1231   TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("xla_split_op"), "XlaSplitND")
1232                          .Input(input)
1233                          .Attr("num_splits", num_partitions)
1234                          .Attr("paddings", paddings)
1235                          .Attr("T", data_type)
1236                          .Attr("N", num_partitions_size)
1237                          .Finalize(graph, &xla_split_op));
1238 
1239   std::vector<NodeBuilder::NodeOut> concat_inputs;
1240   concat_inputs.reserve(num_partitions_size);
1241   for (int i = 0; i < num_partitions_size; ++i) {
1242     concat_inputs.push_back({xla_split_op, i});
1243   }
1244 
1245   Node* xla_concat_op = nullptr;
1246   TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("xla_concat_op"), "XlaConcatND")
1247                          .Input(concat_inputs)
1248                          .Attr("num_concats", num_partitions)
1249                          .Attr("paddings", paddings)
1250                          .Attr("T", data_type)
1251                          .Attr("N", num_partitions_size)
1252                          .Finalize(graph, &xla_concat_op));
1253 
1254   Node* equal = nullptr;
1255   TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("equal"), "Equal")
1256                          .Input(input)
1257                          .Input(xla_concat_op)
1258                          .Attr("T", data_type)
1259                          .Finalize(graph, &equal));
1260 
1261   output_tensor_names->push_back(absl::StrCat(equal->name(), ":", 0));
1262 
1263   return OkStatus();
1264 }
1265 
CreateRoundtripResourceGraph(const TensorShape & input_shape,absl::Span<const int32> num_partitions,absl::Span<const int32> paddings,Graph * graph,std::vector<std::string> * output_tensor_names)1266 Status CreateRoundtripResourceGraph(
1267     const TensorShape& input_shape, absl::Span<const int32> num_partitions,
1268     absl::Span<const int32> paddings, Graph* graph,
1269     std::vector<std::string>* output_tensor_names) {
1270   const int32_t num_partitions_size =
1271       std::accumulate(num_partitions.begin(), num_partitions.end(), 1,
1272                       std::multiplies<int32>());
1273 
1274   Node* var_handle = nullptr;
1275   DataType data_type = DataTypeToEnum<int32>::value;
1276   TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("var_handle"), "VarHandleOp")
1277                          .Attr("dtype", data_type)
1278                          .Attr("shape", PartialTensorShape())
1279                          .Finalize(graph, &var_handle));
1280 
1281   Tensor input_tensor(data_type, input_shape);
1282   test::FillIota<int32>(&input_tensor, 0);
1283   Node* input = test::graph::Constant(graph, input_tensor);
1284 
1285   Node* assign_var = nullptr;
1286   TF_RETURN_IF_ERROR(
1287       NodeBuilder(graph->NewName("assign_var"), "AssignVariableOp")
1288           .Input(var_handle)
1289           .Input(input)
1290           .Attr("dtype", data_type)
1291           .Finalize(graph, &assign_var));
1292 
1293   Node* xla_split_op = nullptr;
1294   TF_RETURN_IF_ERROR(
1295       NodeBuilder(graph->NewName("xla_split_op"), "ReadVariableXlaSplitND")
1296           .Input(var_handle)
1297           .ControlInput(assign_var)
1298           .Attr("num_splits", num_partitions)
1299           .Attr("paddings", paddings)
1300           .Attr("T", data_type)
1301           .Attr("N", num_partitions_size)
1302           .Finalize(graph, &xla_split_op));
1303 
1304   std::vector<NodeBuilder::NodeOut> concat_inputs;
1305   concat_inputs.reserve(num_partitions_size);
1306   for (int i = 0; i < num_partitions_size; ++i) {
1307     concat_inputs.push_back({xla_split_op, i});
1308   }
1309 
1310   Node* xla_concat_op = nullptr;
1311   TF_RETURN_IF_ERROR(
1312       NodeBuilder(graph->NewName("xla_op"), "AssignVariableXlaConcatND")
1313           .Input(var_handle)
1314           .Input(concat_inputs)
1315           .Attr("num_concats", num_partitions)
1316           .Attr("paddings", paddings)
1317           .Attr("T", data_type)
1318           .Attr("N", num_partitions_size)
1319           .Finalize(graph, &xla_concat_op));
1320 
1321   Node* read_var = nullptr;
1322   TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("read_var"), "ReadVariableOp")
1323                          .Input(var_handle)
1324                          .ControlInput(xla_concat_op)
1325                          .Attr("dtype", data_type)
1326                          .Finalize(graph, &read_var));
1327 
1328   Node* equal = nullptr;
1329   TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("equal"), "Equal")
1330                          .Input(input)
1331                          .Input(read_var)
1332                          .Attr("T", data_type)
1333                          .Finalize(graph, &equal));
1334 
1335   output_tensor_names->push_back(absl::StrCat(equal->name(), ":", 0));
1336 
1337   return OkStatus();
1338 }
1339 
1340 struct RoundtripXlaSplitConcatNDTestParam {
1341   std::string name;
1342   int rank = 0;
1343   std::function<Status(const TensorShape&, absl::Span<const int32>,
1344                        absl::Span<const int32>, Graph*,
1345                        std::vector<std::string>*)>
1346       graph_creator;
1347 };
1348 
1349 class RoundtripXlaSplitConcatNDTest
1350     : public ::testing::TestWithParam<RoundtripXlaSplitConcatNDTestParam> {};
1351 
1352 template <typename T>
Constant(T v,TensorShape shape)1353 Tensor Constant(T v, TensorShape shape) {
1354   Tensor ret(DataTypeToEnum<T>::value, shape);
1355   ret.flat<T>().setConstant(v);
1356   return ret;
1357 }
1358 
TEST_P(RoundtripXlaSplitConcatNDTest,NoPadding)1359 TEST_P(RoundtripXlaSplitConcatNDTest, NoPadding) {
1360   const int rank = GetParam().rank;
1361   const std::vector<int32> num_partitions(rank, 2);
1362 
1363   Graph graph(OpRegistry::Global());
1364   const TensorShape input_shape(std::vector<int64_t>(rank, 4));
1365   const std::vector<int32> paddings;
1366   std::vector<std::string> output_tensor_names;
1367   TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_partitions, paddings,
1368                                         &graph, &output_tensor_names));
1369 
1370   std::vector<Tensor> output_tensors;
1371   TF_ASSERT_OK(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
1372                         &output_tensors));
1373   ASSERT_EQ(output_tensors.size(), 1);
1374 
1375   test::ExpectTensorEqual<bool>(
1376       output_tensors[0],
1377       Constant<bool>(true, TensorShape(std::vector<int64_t>(rank, 4))));
1378 }
1379 
TEST_P(RoundtripXlaSplitConcatNDTest,PartialPadding)1380 TEST_P(RoundtripXlaSplitConcatNDTest, PartialPadding) {
1381   const int rank = GetParam().rank;
1382   const std::vector<int32> num_partitions(rank, 2);
1383 
1384   Graph graph(OpRegistry::Global());
1385   const TensorShape input_shape(std::vector<int64_t>(rank, 4));
1386   const std::vector<int32> paddings(rank, 2);
1387   std::vector<std::string> output_tensor_names;
1388   TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_partitions, paddings,
1389                                         &graph, &output_tensor_names));
1390 
1391   std::vector<Tensor> output_tensors;
1392   TF_ASSERT_OK(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
1393                         &output_tensors));
1394   ASSERT_EQ(output_tensors.size(), 1);
1395 
1396   test::ExpectTensorEqual<bool>(
1397       output_tensors[0],
1398       Constant<bool>(true, TensorShape(std::vector<int64_t>(rank, 4))));
1399 }
1400 
TEST_P(RoundtripXlaSplitConcatNDTest,CompletePadding)1401 TEST_P(RoundtripXlaSplitConcatNDTest, CompletePadding) {
1402   const int rank = GetParam().rank;
1403   const std::vector<int32> num_partitions(rank, 2);
1404 
1405   Graph graph(OpRegistry::Global());
1406   const TensorShape input_shape(std::vector<int64_t>(rank, 4));
1407   const std::vector<int32> paddings(rank, 4);
1408   std::vector<std::string> output_tensor_names;
1409   TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_partitions, paddings,
1410                                         &graph, &output_tensor_names));
1411 
1412   std::vector<Tensor> output_tensors;
1413   TF_ASSERT_OK(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
1414                         &output_tensors));
1415   ASSERT_EQ(output_tensors.size(), 1);
1416 
1417   test::ExpectTensorEqual<bool>(
1418       output_tensors[0],
1419       Constant<bool>(true, TensorShape(std::vector<int64_t>(rank, 4))));
1420 }
1421 
1422 INSTANTIATE_TEST_SUITE_P(
1423     RoundtripXlaSplitConcatNDTest, RoundtripXlaSplitConcatNDTest,
1424     ::testing::ValuesIn<RoundtripXlaSplitConcatNDTestParam>(
1425         {{"TensorRanked1", 1, CreateRoundtripTensorGraph},
1426          {"TensorRanked2", 2, CreateRoundtripTensorGraph},
1427          {"TensorRanked3", 3, CreateRoundtripTensorGraph},
1428          {"TensorRanked4", 4, CreateRoundtripTensorGraph},
1429          {"TensorRanked5", 5, CreateRoundtripTensorGraph},
1430          {"TensorRanked6", 6, CreateRoundtripTensorGraph},
1431          {"TensorRanked7", 7, CreateRoundtripTensorGraph},
1432          {"TensorRanked8", 8, CreateRoundtripTensorGraph},
1433          {"ResourceRanked1", 1, CreateRoundtripResourceGraph},
1434          {"ResourceRanked2", 2, CreateRoundtripResourceGraph},
1435          {"ResourceRanked3", 3, CreateRoundtripResourceGraph},
1436          {"ResourceRanked4", 4, CreateRoundtripResourceGraph},
1437          {"ResourceRanked5", 5, CreateRoundtripResourceGraph},
1438          {"ResourceRanked6", 6, CreateRoundtripResourceGraph},
1439          {"ResourceRanked7", 7, CreateRoundtripResourceGraph},
1440          {"ResourceRanked8", 8, CreateRoundtripResourceGraph}}),
1441     [](const ::testing::TestParamInfo<RoundtripXlaSplitConcatNDTest::ParamType>&
__anon3d7075df0802(const ::testing::TestParamInfo<RoundtripXlaSplitConcatNDTest::ParamType>& info) 1442            info) { return info.param.name; });
1443 
1444 }  // namespace
1445 }  // namespace tensorflow
1446