1 /* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include <functional>
17 #include <memory>
18 #include <numeric>
19 #include <string>
20 #include <vector>
21
22 #include "absl/strings/match.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/types/span.h"
25 #include "tensorflow/core/framework/graph.pb.h"
26 #include "tensorflow/core/framework/tensor.h"
27 #include "tensorflow/core/framework/tensor_shape.h"
28 #include "tensorflow/core/framework/tensor_testutil.h"
29 #include "tensorflow/core/framework/types.h"
30 #include "tensorflow/core/framework/types.pb.h"
31 #include "tensorflow/core/graph/graph.h"
32 #include "tensorflow/core/graph/node_builder.h"
33 #include "tensorflow/core/graph/testlib.h"
34 #include "tensorflow/core/lib/core/status_test_util.h"
35 #include "tensorflow/core/platform/errors.h"
36 #include "tensorflow/core/platform/status.h"
37 #include "tensorflow/core/platform/test.h"
38 #include "tensorflow/core/protobuf/config.pb.h"
39 #include "tensorflow/core/protobuf/error_codes.pb.h"
40 #include "tensorflow/core/public/session.h"
41 #include "tensorflow/core/public/session_options.h"
42
43 namespace tensorflow {
44 namespace {
45
// gtest matcher: succeeds when a Status carries `error_code` and its error
// message contains `error_message` as a substring.
MATCHER_P2(IsStatus, error_code, error_message, "") {
  return arg.code() == error_code &&
         absl::StrContains(arg.error_message(), error_message);
}
50
RunGraph(const Graph & graph,const std::vector<std::string> & output_tensor_names,const std::vector<std::string> & target_tensor_names,std::vector<Tensor> * output_tensors)51 Status RunGraph(const Graph& graph,
52 const std::vector<std::string>& output_tensor_names,
53 const std::vector<std::string>& target_tensor_names,
54 std::vector<Tensor>* output_tensors) {
55 GraphDef graph_def;
56 graph.ToGraphDef(&graph_def);
57 SessionOptions session_options;
58 std::unique_ptr<Session> session(NewSession(session_options));
59 TF_RETURN_IF_ERROR(session->Create(graph_def));
60 RunOptions run_options;
61 return session->Run(run_options, /*inputs=*/{}, output_tensor_names,
62 target_tensor_names, output_tensors,
63 /*run_metadata=*/nullptr);
64 }
65
// ReadVariableXlaSplitND reading a variable handle that was never assigned
// must fail at run time with INVALID_ARGUMENT ("cannot be found").
TEST(ReadVariableXlaSplitNDOpTest, VariableMissing) {
  Graph graph(OpRegistry::Global());

  Node* var_handle = nullptr;
  DataType data_type = DataTypeToEnum<int32>::value;
  const TensorShape input_shape({4, 4});
  TF_ASSERT_OK(NodeBuilder(graph.NewName("var_handle"), "VarHandleOp")
                   .Attr("dtype", data_type)
                   .Attr("shape", input_shape)
                   .Finalize(&graph, &var_handle));

  // Note: no AssignVariableOp — the variable is intentionally uninitialized.
  Node* xla_op = nullptr;
  const std::vector<int32> num_splits = {2, 2};
  const int num_outputs = 4;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("xla_op"), "ReadVariableXlaSplitND")
                   .Input(var_handle)
                   .Attr("num_splits", num_splits)
                   .Attr("T", data_type)
                   .Attr("N", num_outputs)
                   .Finalize(&graph, &xla_op));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(RunGraph(graph, /*output_tensor_names=*/{xla_op->name()},
                       /*target_tensor_names=*/{}, &output_tensors),
              IsStatus(error::INVALID_ARGUMENT, "cannot be found"));
}
92
// A 'T' attr (float) that disagrees with the variable's dtype (int32) must
// fail with INVALID_ARGUMENT ("'T' must match 'resource'").
TEST(ReadVariableXlaSplitNDOpTest, DTypeInvalid) {
  Graph graph(OpRegistry::Global());

  Node* var_handle = nullptr;
  DataType data_type = DataTypeToEnum<int32>::value;
  const TensorShape input_shape({4, 4});
  TF_ASSERT_OK(NodeBuilder(graph.NewName("var_handle"), "VarHandleOp")
                   .Attr("dtype", data_type)
                   .Attr("shape", input_shape)
                   .Finalize(&graph, &var_handle));

  Tensor input_tensor(data_type, input_shape);
  test::FillIota<int32>(&input_tensor, /*val=*/0);
  Node* input = test::graph::Constant(&graph, input_tensor);

  // Initialize the variable so the failure is the dtype check, not a
  // missing-variable error.
  Node* assign_var = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("assign_var"), "AssignVariableOp")
                   .Input(var_handle)
                   .Input(input)
                   .Attr("dtype", data_type)
                   .Finalize(&graph, &assign_var));

  Node* xla_op = nullptr;
  const std::vector<int32> num_splits = {2, 2};
  const int num_outputs = 4;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("xla_op"), "ReadVariableXlaSplitND")
                   .Input(var_handle)
                   .ControlInput(assign_var)
                   .Attr("num_splits", num_splits)
                   .Attr("T", DataTypeToEnum<float>::value)  // mismatched 'T'
                   .Attr("N", num_outputs)
                   .Finalize(&graph, &xla_op));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(RunGraph(graph, /*output_tensor_names=*/{xla_op->name()},
                       /*target_tensor_names=*/{}, &output_tensors),
              IsStatus(error::INVALID_ARGUMENT, "'T' must match 'resource'"));
}
131
CreateSplitTensorGraph(const TensorShape & input_shape,absl::Span<const int32> num_splits,absl::Span<const int32> paddings,const int num_outputs,Graph * graph,std::vector<std::string> * output_tensor_names)132 Status CreateSplitTensorGraph(const TensorShape& input_shape,
133 absl::Span<const int32> num_splits,
134 absl::Span<const int32> paddings,
135 const int num_outputs, Graph* graph,
136 std::vector<std::string>* output_tensor_names) {
137 DataType data_type = DataTypeToEnum<int32>::value;
138 Tensor input_tensor(data_type, input_shape);
139 test::FillIota<int32>(&input_tensor, /*val=*/0);
140 Node* input = test::graph::Constant(graph, input_tensor);
141
142 Node* xla_op = nullptr;
143 TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("xla_op"), "XlaSplitND")
144 .Input(input)
145 .Attr("num_splits", num_splits)
146 .Attr("paddings", paddings)
147 .Attr("T", data_type)
148 .Attr("N", num_outputs)
149 .Finalize(graph, &xla_op));
150
151 output_tensor_names->reserve(num_outputs);
152 for (int i = 0; i < num_outputs; ++i) {
153 output_tensor_names->push_back(absl::StrCat(xla_op->name(), ":", i));
154 }
155
156 return OkStatus();
157 }
158
CreateSplitResourceGraph(const TensorShape & input_shape,absl::Span<const int32> num_splits,absl::Span<const int32> paddings,const int num_outputs,Graph * graph,std::vector<std::string> * output_tensor_names)159 Status CreateSplitResourceGraph(const TensorShape& input_shape,
160 absl::Span<const int32> num_splits,
161 absl::Span<const int32> paddings,
162 const int num_outputs, Graph* graph,
163 std::vector<std::string>* output_tensor_names) {
164 Node* var_handle = nullptr;
165 DataType data_type = DataTypeToEnum<int32>::value;
166 TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("var_handle"), "VarHandleOp")
167 .Attr("dtype", data_type)
168 .Attr("shape", input_shape)
169 .Finalize(graph, &var_handle));
170
171 Tensor input_tensor(data_type, input_shape);
172 test::FillIota<int32>(&input_tensor, /*val=*/0);
173 Node* input = test::graph::Constant(graph, input_tensor);
174
175 Node* assign_var = nullptr;
176 TF_RETURN_IF_ERROR(
177 NodeBuilder(graph->NewName("assign_var"), "AssignVariableOp")
178 .Input(var_handle)
179 .Input(input)
180 .Attr("dtype", data_type)
181 .Finalize(graph, &assign_var));
182
183 Node* xla_op = nullptr;
184 TF_RETURN_IF_ERROR(
185 NodeBuilder(graph->NewName("xla_op"), "ReadVariableXlaSplitND")
186 .Input(var_handle)
187 .ControlInput(assign_var)
188 .Attr("num_splits", num_splits)
189 .Attr("paddings", paddings)
190 .Attr("T", data_type)
191 .Attr("N", num_outputs)
192 .Finalize(graph, &xla_op));
193
194 output_tensor_names->reserve(num_outputs);
195 for (int i = 0; i < num_outputs; ++i) {
196 output_tensor_names->push_back(absl::StrCat(xla_op->name(), ":", i));
197 }
198
199 return OkStatus();
200 }
201
// Parameterizes XlaSplitNDOpTest over the graph construction strategy:
// plain tensor input (XlaSplitND) vs. resource variable
// (ReadVariableXlaSplitND).
struct XlaSplitNDTestParam {
  std::string name;  // Used as the parameterized-test name suffix.
  // Builds the graph under test; signature matches CreateSplitTensorGraph /
  // CreateSplitResourceGraph.
  std::function<Status(const TensorShape&, absl::Span<const int32>,
                       absl::Span<const int32>, const int num_outputs, Graph*,
                       std::vector<std::string>*)>
      graph_creator;
};

using XlaSplitNDOpTest = ::testing::TestWithParam<XlaSplitNDTestParam>;
211
// A zero entry in num_splits must be rejected with INVALID_ARGUMENT.
TEST_P(XlaSplitNDOpTest, SplitDimensionZero) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({1, 1, 1});
  const std::vector<int32> num_splits = {1, 1, 0};
  const std::vector<int32> paddings;
  const int num_outputs = 1;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(
      RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
               &output_tensors),
      IsStatus(error::INVALID_ARGUMENT, "index 2 must be positive, but got 0"));
}
229
// A negative entry in num_splits must be rejected with INVALID_ARGUMENT.
TEST_P(XlaSplitNDOpTest, SplitDimensionNegative) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({1, 1, 1});
  const std::vector<int32> num_splits = {1, -1, 1};
  const std::vector<int32> paddings;
  const int num_outputs = 1;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                       &output_tensors),
              IsStatus(error::INVALID_ARGUMENT,
                       "index 1 must be positive, but got -1"));
}
247
// 'N' (1) disagreeing with the product of num_splits (2) must be rejected.
TEST_P(XlaSplitNDOpTest, NumOutputsMismatch) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({2});
  const std::vector<int32> num_splits = {2};
  const std::vector<int> paddings;
  const int num_outputs = 1;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(
      RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
               &output_tensors),
      IsStatus(error::INVALID_ARGUMENT, "'N' must match number of slices 2"));
}
265
// paddings shorter than num_splits must be rejected with INVALID_ARGUMENT.
TEST_P(XlaSplitNDOpTest, PaddingsLengthMismatch) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({2, 2});
  const std::vector<int32> num_splits = {2, 2};
  const std::vector<int32> paddings = {0};
  const int num_outputs = 4;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                       &output_tensors),
              IsStatus(error::INVALID_ARGUMENT, "length 2, but got 1"));
}
282
// A negative padding entry must be rejected with INVALID_ARGUMENT.
TEST_P(XlaSplitNDOpTest, PaddingsNegative) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({2, 2});
  const std::vector<int32> num_splits = {2, 2};
  const std::vector<int32> paddings = {0, -1};
  const int num_outputs = 4;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(
      RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
               &output_tensors),
      IsStatus(error::INVALID_ARGUMENT, "non-negative, but got -1 at index 1"));
}
300
// Scalar (rank-0) input is outside the supported rank range (0, 8].
TEST_P(XlaSplitNDOpTest, InputRank0) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({});
  const std::vector<int32> num_splits = {2};
  const std::vector<int32> paddings;
  const int num_outputs = 2;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                       &output_tensors),
              IsStatus(error::INVALID_ARGUMENT, "range (0, 8], but got 0"));
}
317
// Rank-9 input is outside the supported rank range (0, 8].
TEST_P(XlaSplitNDOpTest, InputRank9) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({2, 2, 2, 2, 2, 2, 2, 2, 2});
  const std::vector<int32> num_splits(9, 2);
  const std::vector<int32> paddings;
  const int num_outputs = 512;  // 2^9 slices requested.
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                       &output_tensors),
              IsStatus(error::INVALID_ARGUMENT, "range (0, 8], but got 9"));
}
334
// num_splits length (3) disagreeing with the input rank (2) must be rejected.
TEST_P(XlaSplitNDOpTest, InputRankSplitMismatch) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({2, 2});
  const std::vector<int32> num_splits = {2, 2, 2};
  const std::vector<int32> paddings;
  const int num_outputs = 8;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                       &output_tensors),
              IsStatus(error::INVALID_ARGUMENT,
                       "'num_splits' length 3, but got rank 2"));
}
352
// A dimension (4) not divisible by its num_splits entry (3) must be rejected
// when there is no padding to make it divisible.
TEST_P(XlaSplitNDOpTest, DimNotEvenlySplit) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({4, 2});
  const std::vector<int32> num_splits = {3, 2};
  const std::vector<int32> paddings;
  const int num_outputs = 6;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                       &output_tensors),
              IsStatus(error::INVALID_ARGUMENT, "divisible by 'num_splits' 3"));
}
369
// A padded dimension (2 + padding 1 = 3) not divisible by its num_splits
// entry (2) must be rejected.
TEST_P(XlaSplitNDOpTest, DimWithPaddingNotEvenlySplit) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({4, 2});
  const std::vector<int32> num_splits = {2, 2};
  const std::vector<int32> paddings = {0, 1};
  const int num_outputs = 4;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                       &output_tensors),
              IsStatus(error::INVALID_ARGUMENT, "divisible by 'num_splits' 2"));
}
386
// All-ones num_splits: the single output is the unmodified input tensor.
TEST_P(XlaSplitNDOpTest, NoSplits) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({2, 2, 2});
  const std::vector<int32> num_splits = {1, 1, 1};
  const std::vector<int> paddings;
  const int num_outputs = 1;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  TF_ASSERT_OK(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                        &output_tensors));
  ASSERT_EQ(output_tensors.size(), 1);
  // Expect the iota values 0..7 passed through untouched.
  test::ExpectTensorEqual<int32>(
      output_tensors[0],
      test::AsTensor<int32>({0, 1, 2, 3, 4, 5, 6, 7}, TensorShape({2, 2, 2})));
}
406
// All-ones num_splits with padding: the single output is the input zero-padded
// from {2,1,1} to {2,2,2}.
TEST_P(XlaSplitNDOpTest, NoSplitsWithPadding) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({2, 1, 1});
  const std::vector<int32> num_splits = {1, 1, 1};
  const std::vector<int> paddings = {0, 1, 1};
  const int num_outputs = 1;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  TF_ASSERT_OK(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                        &output_tensors));
  ASSERT_EQ(output_tensors.size(), 1);
  // Input values 0 and 1 land at indices (0,0,0) and (1,0,0); everything else
  // is zero padding.
  test::ExpectTensorEqual<int32>(
      output_tensors[0],
      test::AsTensor<int32>({0, 0, 0, 0, 1, 0, 0, 0}, TensorShape({2, 2, 2})));
}
427
// 2x2 split of a 4x4 iota tensor with no padding: each output is one 2x2
// quadrant, in row-major slice order.
TEST_P(XlaSplitNDOpTest, SplitNoPadding) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({4, 4});
  const std::vector<int32> num_splits = {2, 2};
  const std::vector<int32> paddings;
  const int num_outputs = 4;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  TF_ASSERT_OK(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                        &output_tensors));
  ASSERT_EQ(output_tensors.size(), num_outputs);
  // Quadrants: top-left, top-right, bottom-left, bottom-right.
  test::ExpectTensorEqual<int32>(
      output_tensors[0],
      test::AsTensor<int32>({0, 1, 4, 5}, TensorShape({2, 2})));
  test::ExpectTensorEqual<int32>(
      output_tensors[1],
      test::AsTensor<int32>({2, 3, 6, 7}, TensorShape({2, 2})));
  test::ExpectTensorEqual<int32>(
      output_tensors[2],
      test::AsTensor<int32>({8, 9, 12, 13}, TensorShape({2, 2})));
  test::ExpectTensorEqual<int32>(
      output_tensors[3],
      test::AsTensor<int32>({10, 11, 14, 15}, TensorShape({2, 2})));
}
456
// 2x2 split of a 3x3 tensor padded to 4x4: slices touching the padded edges
// contain trailing zeros.
TEST_P(XlaSplitNDOpTest, SplitPartialPadding) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({3, 3});
  const std::vector<int32> num_splits = {2, 2};
  const std::vector<int32> paddings = {1, 1};
  const int num_outputs = 4;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  TF_ASSERT_OK(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                        &output_tensors));
  ASSERT_EQ(output_tensors.size(), num_outputs);
  test::ExpectTensorEqual<int32>(
      output_tensors[0],
      test::AsTensor<int32>({0, 1, 3, 4}, TensorShape({2, 2})));
  test::ExpectTensorEqual<int32>(
      output_tensors[1],
      test::AsTensor<int32>({2, 0, 5, 0}, TensorShape({2, 2})));
  test::ExpectTensorEqual<int32>(
      output_tensors[2],
      test::AsTensor<int32>({6, 7, 0, 0}, TensorShape({2, 2})));
  test::ExpectTensorEqual<int32>(
      output_tensors[3],
      test::AsTensor<int32>({8, 0, 0, 0}, TensorShape({2, 2})));
}
485
// 2x2 split of a 2x1 tensor padded to 4x4: only the first slice holds data;
// the other three slices are entirely zero padding.
TEST_P(XlaSplitNDOpTest, SplitCompletePadding) {
  Graph graph(OpRegistry::Global());
  const TensorShape input_shape({2, 1});
  const std::vector<int32> num_splits = {2, 2};
  const std::vector<int32> paddings = {2, 3};
  const int num_outputs = 4;
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  TF_ASSERT_OK(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                        &output_tensors));
  ASSERT_EQ(output_tensors.size(), num_outputs);
  test::ExpectTensorEqual<int32>(
      output_tensors[0],
      test::AsTensor<int32>({0, 0, 1, 0}, TensorShape({2, 2})));
  test::ExpectTensorEqual<int32>(
      output_tensors[1],
      test::AsTensor<int32>({0, 0, 0, 0}, TensorShape({2, 2})));
  test::ExpectTensorEqual<int32>(
      output_tensors[2],
      test::AsTensor<int32>({0, 0, 0, 0}, TensorShape({2, 2})));
  test::ExpectTensorEqual<int32>(
      output_tensors[3],
      test::AsTensor<int32>({0, 0, 0, 0}, TensorShape({2, 2})));
}
514
// Runs every XlaSplitNDOpTest case against both graph builders: tensor input
// (XlaSplitND) and resource variable (ReadVariableXlaSplitND).
INSTANTIATE_TEST_SUITE_P(
    XlaSplitNDOpTest, XlaSplitNDOpTest,
    ::testing::ValuesIn<XlaSplitNDTestParam>(
        {{"Tensor", CreateSplitTensorGraph},
         {"Resource", CreateSplitResourceGraph}}),
    // Name each instantiation after the param's `name` field.
    [](const ::testing::TestParamInfo<XlaSplitNDOpTest::ParamType>& info) {
      return info.param.name;
    });
523
// Parameterizes RankedXlaSplitNDOpTest over input rank and graph builder.
struct RankedXlaSplitNDTestParam {
  std::string name;  // Used as the parameterized-test name suffix.
  int rank = 0;      // Input tensor rank (1..8).
  // Builds the graph under test; signature matches CreateSplitTensorGraph /
  // CreateSplitResourceGraph.
  std::function<Status(const TensorShape&, absl::Span<const int32>,
                       absl::Span<const int32>, const int num_outputs, Graph*,
                       std::vector<std::string>*)>
      graph_creator;
};
532
// Fixture for rank-parameterized split tests (rank 1 through 8).
class RankedXlaSplitNDOpTest
    : public ::testing::TestWithParam<RankedXlaSplitNDTestParam> {};
535
// Splits a rank-`rank` tensor with every dimension of size 2 into 2^rank
// single-element slices; slice i must contain the iota value i.
TEST_P(RankedXlaSplitNDOpTest, TestSubscriptRank) {
  const int rank = GetParam().rank;
  const std::vector<int32> num_splits(rank, 2);

  Graph graph(OpRegistry::Global());
  const TensorShape input_shape(std::vector<int64_t>(rank, 2));
  const std::vector<int32> paddings;
  // 2 << (rank - 1) == 2^rank, one output per slice.
  const int num_outputs = 2 << (rank - 1);
  std::vector<std::string> output_tensor_names;
  TF_ASSERT_OK(GetParam().graph_creator(input_shape, num_splits, paddings,
                                        num_outputs, &graph,
                                        &output_tensor_names));

  std::vector<Tensor> output_tensors;
  TF_ASSERT_OK(RunGraph(graph, output_tensor_names, /*target_tensor_names=*/{},
                        &output_tensors));
  ASSERT_EQ(output_tensors.size(), num_outputs);
  TensorShape output_shape(std::vector<int64_t>(rank, 1));
  for (int i = 0; i < num_outputs; ++i) {
    test::ExpectTensorEqual<int32>(output_tensors[i],
                                   test::AsTensor<int32>({i}, output_shape));
  }
}
559
// Covers ranks 1-8 for both the tensor and resource graph builders.
INSTANTIATE_TEST_SUITE_P(
    RankedXlaSplitNDOpTest, RankedXlaSplitNDOpTest,
    ::testing::ValuesIn<RankedXlaSplitNDTestParam>(
        {{"TensorRanked1", 1, CreateSplitTensorGraph},
         {"TensorRanked2", 2, CreateSplitTensorGraph},
         {"TensorRanked3", 3, CreateSplitTensorGraph},
         {"TensorRanked4", 4, CreateSplitTensorGraph},
         {"TensorRanked5", 5, CreateSplitTensorGraph},
         {"TensorRanked6", 6, CreateSplitTensorGraph},
         {"TensorRanked7", 7, CreateSplitTensorGraph},
         {"TensorRanked8", 8, CreateSplitTensorGraph},
         {"ResourceRanked1", 1, CreateSplitResourceGraph},
         {"ResourceRanked2", 2, CreateSplitResourceGraph},
         {"ResourceRanked3", 3, CreateSplitResourceGraph},
         {"ResourceRanked4", 4, CreateSplitResourceGraph},
         {"ResourceRanked5", 5, CreateSplitResourceGraph},
         {"ResourceRanked6", 6, CreateSplitResourceGraph},
         {"ResourceRanked7", 7, CreateSplitResourceGraph},
         {"ResourceRanked8", 8, CreateSplitResourceGraph}}),
    // Name each instantiation after the param's `name` field.
    [](const ::testing::TestParamInfo<RankedXlaSplitNDOpTest::ParamType>&
           info) { return info.param.name; });
581
// AssignVariableXlaConcatND with 'T' (float) disagreeing with the variable
// handle's declared dtype (int32) must fail with INVALID_ARGUMENT.
TEST(AssignVariableXlaConcatNDOpTest, HandleDTypeInvalid) {
  Graph graph(OpRegistry::Global());
  Node* var_handle = nullptr;
  DataType handle_dtype = DataTypeToEnum<int32>::value;
  PartialTensorShape handle_shape;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("var_handle"), "VarHandleOp")
                   .Attr("dtype", handle_dtype)
                   .Attr("shape", handle_shape)
                   .Finalize(&graph, &var_handle));
  DataType update_data_type = DataTypeToEnum<float>::value;
  const TensorShape update_input_shape({4, 4});
  Tensor update_input_tensor(update_data_type, update_input_shape);
  test::FillIota<float>(&update_input_tensor, /*val=*/0.f);
  Node* update_input = test::graph::Constant(&graph, update_input_tensor);
  Node* xla_op = nullptr;
  const std::vector<int32> num_concats = {1, 1};
  const int num_inputs = 1;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("xla_op"), "AssignVariableXlaConcatND")
                   .Input(var_handle)
                   .Input(std::vector<NodeBuilder::NodeOut>{update_input})
                   .Attr("num_concats", num_concats)
                   .Attr("T", update_data_type)
                   .Attr("N", num_inputs)
                   .Finalize(&graph, &xla_op));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(
      RunGraph(graph, /*output_tensor_names=*/{},
               /*target_tensor_names=*/{xla_op->name()}, &output_tensors),
      IsStatus(error::INVALID_ARGUMENT, "dtype int32, but got float"));
}
613
// AssignVariableXlaConcatND with 'T' (float) disagreeing with the dtype of
// the tensor already stored in the variable (int32) must fail.
TEST(AssignVariableXlaConcatNDOpTest, TensorDTypeInvalid) {
  Graph graph(OpRegistry::Global());

  Node* var_handle = nullptr;
  DataType handle_dtype = DataTypeToEnum<float>::value;
  PartialTensorShape handle_shape;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("var_handle"), "VarHandleOp")
                   .Attr("dtype", handle_dtype)
                   .Attr("shape", handle_shape)
                   .Finalize(&graph, &var_handle));

  // Store an int32 tensor in the variable so the stored dtype, not the
  // handle's declared dtype, triggers the mismatch.
  DataType init_data_type = DataTypeToEnum<int32>::value;
  const TensorShape init_input_shape({4, 4});
  Tensor init_input_tensor(init_data_type, init_input_shape);
  test::FillIota<int32>(&init_input_tensor, /*val=*/0);
  Node* input = test::graph::Constant(&graph, init_input_tensor);

  Node* assign_var = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("assign_var"), "AssignVariableOp")
                   .Input(var_handle)
                   .Input(input)
                   .Attr("dtype", init_data_type)
                   .Finalize(&graph, &assign_var));

  DataType update_data_type = DataTypeToEnum<float>::value;
  const TensorShape update_input_shape({4, 4});
  Tensor update_input_tensor(update_data_type, update_input_shape);
  test::FillIota<float>(&update_input_tensor, /*val=*/0.f);
  Node* update_input = test::graph::Constant(&graph, update_input_tensor);

  Node* xla_op = nullptr;
  const std::vector<int32> num_concats = {1, 1};
  const int num_inputs = 1;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("xla_op"), "AssignVariableXlaConcatND")
                   .Input(var_handle)
                   .Input(std::vector<NodeBuilder::NodeOut>{update_input})
                   .ControlInput(assign_var)
                   .Attr("num_concats", num_concats)
                   .Attr("T", update_data_type)
                   .Attr("N", num_inputs)
                   .Finalize(&graph, &xla_op));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(
      RunGraph(graph, /*output_tensor_names=*/{},
               /*target_tensor_names=*/{xla_op->name()}, &output_tensors),
      IsStatus(error::INVALID_ARGUMENT, "dtype int32, but got float"));
}
662
// A scalar handle shape incompatible with the concatenated update shape
// [4,4] must fail with INVALID_ARGUMENT.
TEST(AssignVariableXlaConcatNDOpTest, HandleShapeIncompatible) {
  Graph graph(OpRegistry::Global());

  Node* var_handle = nullptr;
  DataType handle_dtype = DataTypeToEnum<float>::value;
  PartialTensorShape handle_shape({});  // Scalar — incompatible with [4,4].
  TF_ASSERT_OK(NodeBuilder(graph.NewName("var_handle"), "VarHandleOp")
                   .Attr("dtype", handle_dtype)
                   .Attr("shape", handle_shape)
                   .Finalize(&graph, &var_handle));

  DataType update_data_type = DataTypeToEnum<float>::value;
  const TensorShape update_input_shape({4, 4});
  Tensor update_input_tensor(update_data_type, update_input_shape);
  test::FillIota<float>(&update_input_tensor, /*val=*/0.f);
  Node* update_input = test::graph::Constant(&graph, update_input_tensor);

  Node* xla_op = nullptr;
  const std::vector<int32> num_concats = {1, 1};
  const int num_inputs = 1;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("xla_op"), "AssignVariableXlaConcatND")
                   .Input(var_handle)
                   .Input(std::vector<NodeBuilder::NodeOut>{update_input})
                   .Attr("num_concats", num_concats)
                   .Attr("T", update_data_type)
                   .Attr("N", num_inputs)
                   .Finalize(&graph, &xla_op));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(
      RunGraph(graph, /*output_tensor_names=*/{},
               /*target_tensor_names=*/{xla_op->name()}, &output_tensors),
      IsStatus(error::INVALID_ARGUMENT, "expected shape [4,4], but got []"));
}
697
// With paddings {1,1} the concatenated-and-unpadded result shape is [3,3],
// which must be rejected against a handle declared as [4,4].
TEST(AssignVariableXlaConcatNDOpTest, HandleShapeWithPaddingIncompatible) {
  Graph graph(OpRegistry::Global());

  Node* var_handle = nullptr;
  DataType handle_dtype = DataTypeToEnum<float>::value;
  PartialTensorShape handle_shape({4, 4});
  TF_ASSERT_OK(NodeBuilder(graph.NewName("var_handle"), "VarHandleOp")
                   .Attr("dtype", handle_dtype)
                   .Attr("shape", handle_shape)
                   .Finalize(&graph, &var_handle));

  DataType update_data_type = DataTypeToEnum<float>::value;
  const TensorShape update_input_shape({4, 4});
  Tensor update_input_tensor(update_data_type, update_input_shape);
  test::FillIota<float>(&update_input_tensor, /*val=*/0.f);
  Node* update_input = test::graph::Constant(&graph, update_input_tensor);

  Node* xla_op = nullptr;
  const std::vector<int32> num_concats = {1, 1};
  const std::vector<int32> paddings = {1, 1};
  const int num_inputs = 1;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("xla_op"), "AssignVariableXlaConcatND")
                   .Input(var_handle)
                   .Input(std::vector<NodeBuilder::NodeOut>{update_input})
                   .Attr("num_concats", num_concats)
                   .Attr("paddings", paddings)
                   .Attr("T", update_data_type)
                   .Attr("N", num_inputs)
                   .Finalize(&graph, &xla_op));

  std::vector<Tensor> output_tensors;
  EXPECT_THAT(
      RunGraph(graph, /*output_tensor_names=*/{},
               /*target_tensor_names=*/{xla_op->name()}, &output_tensors),
      IsStatus(error::INVALID_ARGUMENT, "expected shape [3,3], but got [4,4]"));
}
734
// With a partially-defined handle shape ([4,-1]), assigning a [4,4] concat
// result to a variable previously holding a [4,2] tensor must succeed, and a
// subsequent read must observe the new value.
TEST(AssignVariableXlaConcatNDOpTest, AssignDifferentShape) {
  Graph graph(OpRegistry::Global());

  Node* var_handle = nullptr;
  DataType data_type = DataTypeToEnum<float>::value;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("var_handle"), "VarHandleOp")
                   .Attr("dtype", data_type)
                   // Second dimension unknown, so both [4,2] and [4,4] fit.
                   .Attr("shape", PartialTensorShape({4, -1}))
                   .Finalize(&graph, &var_handle));

  // Initialize the variable with a [4,2] tensor of -1s.
  const TensorShape init_input_shape({4, 2});
  Tensor init_input_tensor(data_type, init_input_shape);
  test::FillFn<float>(&init_input_tensor, [](int unused) { return -1.f; });
  Node* init_input = test::graph::Constant(&graph, init_input_tensor);

  Node* assign_var = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("assign_var"), "AssignVariableOp")
                   .Input(var_handle)
                   .Input(init_input)
                   .Attr("dtype", data_type)
                   .Finalize(&graph, &assign_var));

  const TensorShape update_input_shape({4, 4});
  Tensor update_input_tensor(data_type, update_input_shape);
  test::FillIota<float>(&update_input_tensor, /*val=*/0.f);
  Node* update_input = test::graph::Constant(&graph, update_input_tensor);

  Node* xla_op = nullptr;
  const std::vector<int32> num_concats = {1, 1};
  const int num_inputs = 1;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("xla_op"), "AssignVariableXlaConcatND")
                   .Input(var_handle)
                   .Input(std::vector<NodeBuilder::NodeOut>{update_input})
                   .ControlInput(assign_var)
                   .Attr("num_concats", num_concats)
                   .Attr("T", data_type)
                   .Attr("N", num_inputs)
                   .Finalize(&graph, &xla_op));

  // Read back after the concat-assign to verify the variable's new value.
  Node* read_var = nullptr;
  TF_ASSERT_OK(NodeBuilder(graph.NewName("read_var"), "ReadVariableOp")
                   .Input(var_handle)
                   .ControlInput(xla_op)
                   .Attr("dtype", data_type)
                   .Finalize(&graph, &read_var));

  std::vector<Tensor> output_tensors;
  TF_ASSERT_OK(RunGraph(
      graph, /*output_tensor_names=*/{absl::StrCat(read_var->name(), ":", 0)},
      /*target_tensor_names=*/{}, &output_tensors));
  ASSERT_EQ(output_tensors.size(), 1);
  test::ExpectTensorNear<float>(output_tensors[0], update_input_tensor,
                                /*atol=*/1e-6);
}
789
CreateConcatTensorGraph(absl::Span<const TensorShape> input_shapes,absl::Span<const int32> num_concats,absl::Span<const int32> paddings,Graph * graph,std::vector<std::string> * output_tensor_names)790 Status CreateConcatTensorGraph(absl::Span<const TensorShape> input_shapes,
791 absl::Span<const int32> num_concats,
792 absl::Span<const int32> paddings, Graph* graph,
793 std::vector<std::string>* output_tensor_names) {
794 int32_t val = 0;
795 DataType data_type = DataTypeToEnum<int32>::value;
796 std::vector<NodeBuilder::NodeOut> inputs;
797 inputs.reserve(input_shapes.size());
798 for (const TensorShape& input_shape : input_shapes) {
799 Tensor input_tensor(data_type, input_shape);
800 test::FillIota<int32>(&input_tensor, val);
801 val += input_tensor.NumElements();
802 inputs.push_back(test::graph::Constant(graph, input_tensor));
803 }
804
805 Node* xla_op = nullptr;
806 TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("xla_op"), "XlaConcatND")
807 .Input(inputs)
808 .Attr("num_concats", num_concats)
809 .Attr("paddings", paddings)
810 .Attr("T", data_type)
811 .Attr("N", static_cast<int64_t>(input_shapes.size()))
812 .Finalize(graph, &xla_op));
813
814 output_tensor_names->push_back(absl::StrCat(xla_op->name(), ":", 0));
815
816 return OkStatus();
817 }
818
819 template <bool Init>
CreateConcatResourceGraph(absl::Span<const TensorShape> input_shapes,absl::Span<const int32> num_concats,absl::Span<const int32> paddings,Graph * graph,std::vector<std::string> * output_tensor_names)820 Status CreateConcatResourceGraph(
821 absl::Span<const TensorShape> input_shapes,
822 absl::Span<const int32> num_concats, absl::Span<const int32> paddings,
823 Graph* graph, std::vector<std::string>* output_tensor_names) {
824 Node* var_handle = nullptr;
825 DataType data_type = DataTypeToEnum<int32>::value;
826 TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("var_handle"), "VarHandleOp")
827 .Attr("dtype", data_type)
828 .Attr("shape", PartialTensorShape())
829 .Finalize(graph, &var_handle));
830
831 Node* assign_var = nullptr;
832 if (Init) {
833 Tensor init_input_tensor(data_type, input_shapes.front());
834 test::FillFn<int32>(&init_input_tensor, [](int unused) { return -1; });
835 Node* init_input = test::graph::Constant(graph, init_input_tensor);
836
837 TF_RETURN_IF_ERROR(
838 NodeBuilder(graph->NewName("assign_var"), "AssignVariableOp")
839 .Input(var_handle)
840 .Input(init_input)
841 .Attr("dtype", data_type)
842 .Finalize(graph, &assign_var));
843 }
844
845 int32_t val = 0;
846 std::vector<NodeBuilder::NodeOut> inputs;
847 inputs.reserve(input_shapes.size());
848 for (const TensorShape& input_shape : input_shapes) {
849 Tensor input_tensor(data_type, input_shape);
850 test::FillIota<int32>(&input_tensor, val);
851 val += input_tensor.NumElements();
852 inputs.push_back(test::graph::Constant(graph, input_tensor));
853 }
854
855 Node* xla_op = nullptr;
856 NodeBuilder builder(graph->NewName("xla_op"), "AssignVariableXlaConcatND");
857 builder.Input(var_handle);
858 builder.Input(inputs);
859 if (assign_var != nullptr) {
860 builder.ControlInput(assign_var);
861 }
862 TF_RETURN_IF_ERROR(builder.Attr("num_concats", num_concats)
863 .Attr("paddings", paddings)
864 .Attr("T", data_type)
865 .Attr("N", static_cast<int64_t>(input_shapes.size()))
866 .Finalize(graph, &xla_op));
867
868 Node* read_var = nullptr;
869 TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("read_var"), "ReadVariableOp")
870 .Input(var_handle)
871 .ControlInput(xla_op)
872 .Attr("dtype", data_type)
873 .Finalize(graph, &read_var));
874
875 output_tensor_names->push_back(absl::StrCat(read_var->name(), ":", 0));
876
877 return OkStatus();
878 }
879
// Parameterization of XlaConcatNDOpTest: a test-case name plus a factory that
// builds the graph under test (tensor-input or resource-variable variant).
struct XlaConcatNDTestParam {
  std::string name;
  // Builds the graph: (input_shapes, num_concats, paddings, graph,
  // output_tensor_names) -> Status.
  std::function<Status(absl::Span<const TensorShape>, absl::Span<const int32>,
                       absl::Span<const int32>, Graph*,
                       std::vector<std::string>*)>
      graph_creator;
};

using XlaConcatNDOpTest = ::testing::TestWithParam<XlaConcatNDTestParam>;
889
TEST_P(XlaConcatNDOpTest, ConcatDimensionZero) {
  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // A zero entry in 'num_concats' is invalid.
  TF_ASSERT_OK(GetParam().graph_creator(
      {TensorShape({1, 1, 1})}, /*num_concats=*/{1, 1, 0},
      /*paddings=*/{}, &graph, &fetch_names));

  std::vector<Tensor> outputs;
  EXPECT_THAT(
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &outputs),
      IsStatus(error::INVALID_ARGUMENT, "index 2 must be positive, but got 0"));
}
905
TEST_P(XlaConcatNDOpTest, ConcatDimensionNegative) {
  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // A negative entry in 'num_concats' is invalid.
  TF_ASSERT_OK(GetParam().graph_creator(
      {TensorShape({1, 1, 1})}, /*num_concats=*/{1, -1, 1},
      /*paddings=*/{}, &graph, &fetch_names));

  std::vector<Tensor> outputs;
  EXPECT_THAT(RunGraph(graph, fetch_names, /*target_tensor_names=*/{},
                       &outputs),
              IsStatus(error::INVALID_ARGUMENT,
                       "index 1 must be positive, but got -1"));
}
921
TEST_P(XlaConcatNDOpTest, NumInputsMismatch) {
  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // One input is provided but 'num_concats' implies two slices.
  TF_ASSERT_OK(GetParam().graph_creator({TensorShape({2})},
                                        /*num_concats=*/{2},
                                        /*paddings=*/{}, &graph, &fetch_names));

  std::vector<Tensor> outputs;
  EXPECT_THAT(
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &outputs),
      IsStatus(error::INVALID_ARGUMENT, "'N' must match number of slices 2"));
}
937
TEST_P(XlaConcatNDOpTest, PaddingsLengthMismatch) {
  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // 'paddings' must have one entry per dimension; only one is given here.
  TF_ASSERT_OK(GetParam().graph_creator(
      {TensorShape({2, 2})}, /*num_concats=*/{1, 1},
      /*paddings=*/{0}, &graph, &fetch_names));

  std::vector<Tensor> outputs;
  EXPECT_THAT(RunGraph(graph, fetch_names, /*target_tensor_names=*/{},
                       &outputs),
              IsStatus(error::INVALID_ARGUMENT, "length 2, but got 1"));
}
952
TEST_P(XlaConcatNDOpTest, PaddingsNegative) {
  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // Negative padding values are rejected.
  TF_ASSERT_OK(GetParam().graph_creator(
      {TensorShape({2, 2})}, /*num_concats=*/{1, 1},
      /*paddings=*/{0, -1}, &graph, &fetch_names));

  std::vector<Tensor> outputs;
  EXPECT_THAT(
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &outputs),
      IsStatus(error::INVALID_ARGUMENT, "non-negative, but got -1 at index 1"));
}
968
TEST_P(XlaConcatNDOpTest, InputRank0) {
  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // Scalar (rank-0) inputs are outside the supported rank range.
  TF_ASSERT_OK(GetParam().graph_creator({TensorShape({})},
                                        /*num_concats=*/{},
                                        /*paddings=*/{}, &graph, &fetch_names));

  std::vector<Tensor> outputs;
  EXPECT_THAT(RunGraph(graph, fetch_names, /*target_tensor_names=*/{},
                       &outputs),
              IsStatus(error::INVALID_ARGUMENT, "range (0, 8], but got 0"));
}
983
TEST_P(XlaConcatNDOpTest, InputRank9) {
  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // Rank 9 exceeds the supported maximum of 8 dimensions.
  TF_ASSERT_OK(GetParam().graph_creator(
      {TensorShape({1, 1, 1, 1, 1, 1, 1, 1, 1})},
      std::vector<int32>(9, 1), /*paddings=*/{}, &graph, &fetch_names));

  std::vector<Tensor> outputs;
  EXPECT_THAT(RunGraph(graph, fetch_names, /*target_tensor_names=*/{},
                       &outputs),
              IsStatus(error::INVALID_ARGUMENT, "range (0, 8], but got 9"));
}
998
TEST_P(XlaConcatNDOpTest, InputRankConcatMismatch) {
  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // 'num_concats' has two entries but the input is rank 1.
  TF_ASSERT_OK(GetParam().graph_creator(
      {TensorShape({1})}, /*num_concats=*/{1, 1},
      /*paddings=*/{}, &graph, &fetch_names));

  std::vector<Tensor> outputs;
  EXPECT_THAT(RunGraph(graph, fetch_names, /*target_tensor_names=*/{},
                       &outputs),
              IsStatus(error::INVALID_ARGUMENT,
                       "'num_concats' length 2, but got rank 1"));
}
1014
TEST_P(XlaConcatNDOpTest, DifferentShapedInputs) {
  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // All slices must share the first slice's shape; the second differs.
  TF_ASSERT_OK(GetParam().graph_creator(
      {TensorShape({1}), TensorShape({2})}, /*num_concats=*/{2},
      /*paddings=*/{}, &graph, &fetch_names));

  std::vector<Tensor> outputs;
  EXPECT_THAT(RunGraph(graph, fetch_names, /*target_tensor_names=*/{},
                       &outputs),
              IsStatus(error::INVALID_ARGUMENT,
                       "same expected shape [1], but got [2] at index 1"));
}
1030
TEST_P(XlaConcatNDOpTest, PaddingExceedsOutputDimSize) {
  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // Padding of 2 is larger than the dimension being padded (size 1).
  TF_ASSERT_OK(GetParam().graph_creator(
      {TensorShape({1})}, /*num_concats=*/{1},
      /*paddings=*/{2}, &graph, &fetch_names));

  std::vector<Tensor> outputs;
  EXPECT_THAT(
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &outputs),
      IsStatus(
          error::INVALID_ARGUMENT,
          "exceed expected output shape dimension 1 at index 0, but got 2"));
}
1048
TEST_P(XlaConcatNDOpTest, NoConcats) {
  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // All-ones 'num_concats' means the single input passes through unchanged.
  TF_ASSERT_OK(GetParam().graph_creator(
      {TensorShape({2, 2, 2})}, /*num_concats=*/{1, 1, 1},
      /*paddings=*/{}, &graph, &fetch_names));

  std::vector<Tensor> outputs;
  TF_ASSERT_OK(
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &outputs));
  ASSERT_EQ(outputs.size(), 1);
  test::ExpectTensorEqual<int32>(
      outputs[0],
      test::AsTensor<int32>({0, 1, 2, 3, 4, 5, 6, 7}, TensorShape({2, 2, 2})));
}
1066
TEST_P(XlaConcatNDOpTest, NoConcatsWithPadding) {
  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // With no concatenation, padding of 1 per dimension strips the input from
  // [2, 2, 2] down to [1, 1, 1].
  TF_ASSERT_OK(GetParam().graph_creator(
      {TensorShape({2, 2, 2})}, /*num_concats=*/{1, 1, 1},
      /*paddings=*/{1, 1, 1}, &graph, &fetch_names));

  std::vector<Tensor> outputs;
  TF_ASSERT_OK(
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &outputs));
  ASSERT_EQ(outputs.size(), 1);
  test::ExpectTensorEqual<int32>(
      outputs[0], test::AsTensor<int32>({0}, TensorShape({1, 1, 1})));
}
1083
TEST_P(XlaConcatNDOpTest, ConcatNoPadding) {
  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // Four [2, 2] slices concatenated 2x2 yield a [4, 4] tensor.
  TF_ASSERT_OK(GetParam().graph_creator(
      std::vector<TensorShape>(4, TensorShape({2, 2})),
      /*num_concats=*/{2, 2}, /*paddings=*/{}, &graph, &fetch_names));

  std::vector<Tensor> outputs;
  TF_ASSERT_OK(
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &outputs));
  ASSERT_EQ(outputs.size(), 1);
  test::ExpectTensorEqual<int32>(
      outputs[0], test::AsTensor<int32>({0, 1, 4, 5, 2, 3, 6, 7, 8, 9,
                                         12, 13, 10, 11, 14, 15},
                                        TensorShape({4, 4})));
}
1102
TEST_P(XlaConcatNDOpTest, ConcatPartialPadding) {
  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // Concatenating 2x2 slices of [2, 2] with padding 1 per dimension trims the
  // [4, 4] result down to [3, 3].
  TF_ASSERT_OK(GetParam().graph_creator(
      std::vector<TensorShape>(4, TensorShape({2, 2})),
      /*num_concats=*/{2, 2}, /*paddings=*/{1, 1}, &graph, &fetch_names));

  std::vector<Tensor> outputs;
  TF_ASSERT_OK(
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &outputs));
  ASSERT_EQ(outputs.size(), 1);
  test::ExpectTensorEqual<int32>(
      outputs[0],
      test::AsTensor<int32>({0, 1, 4, 2, 3, 6, 8, 9, 12}, TensorShape({3, 3})));
}
1120
TEST_P(XlaConcatNDOpTest, ConcatCompletePadding) {
  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // Padding of 2 per dimension removes everything but the first slice.
  TF_ASSERT_OK(GetParam().graph_creator(
      std::vector<TensorShape>(4, TensorShape({2, 2})),
      /*num_concats=*/{2, 2}, /*paddings=*/{2, 2}, &graph, &fetch_names));

  std::vector<Tensor> outputs;
  TF_ASSERT_OK(
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &outputs));
  ASSERT_EQ(outputs.size(), 1);
  test::ExpectTensorEqual<int32>(
      outputs[0], test::AsTensor<int32>({0, 1, 2, 3}, TensorShape({2, 2})));
}
1138
// Instantiates the suite once per graph variant; the trailing lambda names
// each test instance after the parameter's `name` field.
INSTANTIATE_TEST_SUITE_P(
    XlaConcatNDOpTest, XlaConcatNDOpTest,
    ::testing::ValuesIn<XlaConcatNDTestParam>(
        {{"Tensor", CreateConcatTensorGraph},
         {"InitializedResource", CreateConcatResourceGraph<true>},
         {"UninitializedResource", CreateConcatResourceGraph<false>}}),
    [](const ::testing::TestParamInfo<XlaConcatNDOpTest::ParamType>& info) {
      return info.param.name;
    });
1148
// Parameterization of RankedXlaConcatNDOpTest: adds the tensor rank to
// exercise (1..8) alongside the graph factory.
struct RankedXlaConcatNDTestParam {
  std::string name;
  int rank = 0;
  // Same factory signature as XlaConcatNDTestParam::graph_creator.
  std::function<Status(absl::Span<const TensorShape>, absl::Span<const int32>,
                       absl::Span<const int32>, Graph*,
                       std::vector<std::string>*)>
      graph_creator;
};

class RankedXlaConcatNDOpTest
    : public ::testing::TestWithParam<RankedXlaConcatNDTestParam> {};
1160
TEST_P(RankedXlaConcatNDOpTest, TestSubscriptRank) {
  const int rank = GetParam().rank;

  Graph graph(OpRegistry::Global());
  // 2 concats along each of `rank` dimensions requires 2^rank inputs, each of
  // shape [1, 1, ..., 1].
  const int num_inputs = 2 << (rank - 1);
  std::vector<std::string> fetch_names;
  TF_ASSERT_OK(GetParam().graph_creator(
      std::vector<TensorShape>(num_inputs,
                               TensorShape(std::vector<int64_t>(rank, 1))),
      std::vector<int32>(rank, 2), /*paddings=*/{}, &graph, &fetch_names));

  std::vector<Tensor> outputs;
  TF_ASSERT_OK(
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  // The result is a [2, 2, ..., 2] tensor holding 0..num_inputs-1 in order.
  std::vector<int32> expected(num_inputs);
  std::iota(expected.begin(), expected.end(), 0);
  test::ExpectTensorEqual<int32>(
      outputs[0], test::AsTensor<int32>(
                      expected, TensorShape(std::vector<int64_t>(rank, 2))));
}
1185
// Exercises every supported rank (1..8) for each graph variant; the trailing
// lambda names each instance after the parameter's `name` field.
INSTANTIATE_TEST_SUITE_P(
    RankedXlaConcatNDOpTest, RankedXlaConcatNDOpTest,
    ::testing::ValuesIn<RankedXlaConcatNDTestParam>(
        {{"TensorRanked1", 1, CreateConcatTensorGraph},
         {"TensorRanked2", 2, CreateConcatTensorGraph},
         {"TensorRanked3", 3, CreateConcatTensorGraph},
         {"TensorRanked4", 4, CreateConcatTensorGraph},
         {"TensorRanked5", 5, CreateConcatTensorGraph},
         {"TensorRanked6", 6, CreateConcatTensorGraph},
         {"TensorRanked7", 7, CreateConcatTensorGraph},
         {"TensorRanked8", 8, CreateConcatTensorGraph},
         {"InitializedResourceRanked1", 1, CreateConcatResourceGraph<true>},
         {"InitializedResourceRanked2", 2, CreateConcatResourceGraph<true>},
         {"InitializedResourceRanked3", 3, CreateConcatResourceGraph<true>},
         {"InitializedResourceRanked4", 4, CreateConcatResourceGraph<true>},
         {"InitializedResourceRanked5", 5, CreateConcatResourceGraph<true>},
         {"InitializedResourceRanked6", 6, CreateConcatResourceGraph<true>},
         {"InitializedResourceRanked7", 7, CreateConcatResourceGraph<true>},
         {"InitializedResourceRanked8", 8, CreateConcatResourceGraph<true>},
         {"UninitializedResourceRanked1", 1, CreateConcatResourceGraph<false>},
         {"UninitializedResourceRanked2", 2, CreateConcatResourceGraph<false>},
         {"UninitializedResourceRanked3", 3, CreateConcatResourceGraph<false>},
         {"UninitializedResourceRanked4", 4, CreateConcatResourceGraph<false>},
         {"UninitializedResourceRanked5", 5, CreateConcatResourceGraph<false>},
         {"UninitializedResourceRanked6", 6, CreateConcatResourceGraph<false>},
         {"UninitializedResourceRanked7", 7, CreateConcatResourceGraph<false>},
         {"UninitializedResourceRanked8", 8,
          CreateConcatResourceGraph<false>}}),
    [](const ::testing::TestParamInfo<RankedXlaConcatNDOpTest::ParamType>&
           info) { return info.param.name; });
1216
CreateRoundtripTensorGraph(const TensorShape & input_shape,absl::Span<const int32> num_partitions,absl::Span<const int32> paddings,Graph * graph,std::vector<std::string> * output_tensor_names)1217 Status CreateRoundtripTensorGraph(
1218 const TensorShape& input_shape, absl::Span<const int32> num_partitions,
1219 absl::Span<const int32> paddings, Graph* graph,
1220 std::vector<std::string>* output_tensor_names) {
1221 const int32_t num_partitions_size =
1222 std::accumulate(num_partitions.begin(), num_partitions.end(), 1,
1223 std::multiplies<int32>());
1224
1225 DataType data_type = DataTypeToEnum<int32>::value;
1226 Tensor input_tensor(data_type, input_shape);
1227 test::FillIota<int32>(&input_tensor, /*val=*/0);
1228 Node* input = test::graph::Constant(graph, input_tensor);
1229
1230 Node* xla_split_op = nullptr;
1231 TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("xla_split_op"), "XlaSplitND")
1232 .Input(input)
1233 .Attr("num_splits", num_partitions)
1234 .Attr("paddings", paddings)
1235 .Attr("T", data_type)
1236 .Attr("N", num_partitions_size)
1237 .Finalize(graph, &xla_split_op));
1238
1239 std::vector<NodeBuilder::NodeOut> concat_inputs;
1240 concat_inputs.reserve(num_partitions_size);
1241 for (int i = 0; i < num_partitions_size; ++i) {
1242 concat_inputs.push_back({xla_split_op, i});
1243 }
1244
1245 Node* xla_concat_op = nullptr;
1246 TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("xla_concat_op"), "XlaConcatND")
1247 .Input(concat_inputs)
1248 .Attr("num_concats", num_partitions)
1249 .Attr("paddings", paddings)
1250 .Attr("T", data_type)
1251 .Attr("N", num_partitions_size)
1252 .Finalize(graph, &xla_concat_op));
1253
1254 Node* equal = nullptr;
1255 TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("equal"), "Equal")
1256 .Input(input)
1257 .Input(xla_concat_op)
1258 .Attr("T", data_type)
1259 .Finalize(graph, &equal));
1260
1261 output_tensor_names->push_back(absl::StrCat(equal->name(), ":", 0));
1262
1263 return OkStatus();
1264 }
1265
CreateRoundtripResourceGraph(const TensorShape & input_shape,absl::Span<const int32> num_partitions,absl::Span<const int32> paddings,Graph * graph,std::vector<std::string> * output_tensor_names)1266 Status CreateRoundtripResourceGraph(
1267 const TensorShape& input_shape, absl::Span<const int32> num_partitions,
1268 absl::Span<const int32> paddings, Graph* graph,
1269 std::vector<std::string>* output_tensor_names) {
1270 const int32_t num_partitions_size =
1271 std::accumulate(num_partitions.begin(), num_partitions.end(), 1,
1272 std::multiplies<int32>());
1273
1274 Node* var_handle = nullptr;
1275 DataType data_type = DataTypeToEnum<int32>::value;
1276 TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("var_handle"), "VarHandleOp")
1277 .Attr("dtype", data_type)
1278 .Attr("shape", PartialTensorShape())
1279 .Finalize(graph, &var_handle));
1280
1281 Tensor input_tensor(data_type, input_shape);
1282 test::FillIota<int32>(&input_tensor, 0);
1283 Node* input = test::graph::Constant(graph, input_tensor);
1284
1285 Node* assign_var = nullptr;
1286 TF_RETURN_IF_ERROR(
1287 NodeBuilder(graph->NewName("assign_var"), "AssignVariableOp")
1288 .Input(var_handle)
1289 .Input(input)
1290 .Attr("dtype", data_type)
1291 .Finalize(graph, &assign_var));
1292
1293 Node* xla_split_op = nullptr;
1294 TF_RETURN_IF_ERROR(
1295 NodeBuilder(graph->NewName("xla_split_op"), "ReadVariableXlaSplitND")
1296 .Input(var_handle)
1297 .ControlInput(assign_var)
1298 .Attr("num_splits", num_partitions)
1299 .Attr("paddings", paddings)
1300 .Attr("T", data_type)
1301 .Attr("N", num_partitions_size)
1302 .Finalize(graph, &xla_split_op));
1303
1304 std::vector<NodeBuilder::NodeOut> concat_inputs;
1305 concat_inputs.reserve(num_partitions_size);
1306 for (int i = 0; i < num_partitions_size; ++i) {
1307 concat_inputs.push_back({xla_split_op, i});
1308 }
1309
1310 Node* xla_concat_op = nullptr;
1311 TF_RETURN_IF_ERROR(
1312 NodeBuilder(graph->NewName("xla_op"), "AssignVariableXlaConcatND")
1313 .Input(var_handle)
1314 .Input(concat_inputs)
1315 .Attr("num_concats", num_partitions)
1316 .Attr("paddings", paddings)
1317 .Attr("T", data_type)
1318 .Attr("N", num_partitions_size)
1319 .Finalize(graph, &xla_concat_op));
1320
1321 Node* read_var = nullptr;
1322 TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("read_var"), "ReadVariableOp")
1323 .Input(var_handle)
1324 .ControlInput(xla_concat_op)
1325 .Attr("dtype", data_type)
1326 .Finalize(graph, &read_var));
1327
1328 Node* equal = nullptr;
1329 TF_RETURN_IF_ERROR(NodeBuilder(graph->NewName("equal"), "Equal")
1330 .Input(input)
1331 .Input(read_var)
1332 .Attr("T", data_type)
1333 .Finalize(graph, &equal));
1334
1335 output_tensor_names->push_back(absl::StrCat(equal->name(), ":", 0));
1336
1337 return OkStatus();
1338 }
1339
// Parameterization of RoundtripXlaSplitConcatNDTest: test-case name, tensor
// rank, and a factory that builds the split-then-concat roundtrip graph.
struct RoundtripXlaSplitConcatNDTestParam {
  std::string name;
  int rank = 0;
  // Builds the graph: (input_shape, num_partitions, paddings, graph,
  // output_tensor_names) -> Status.
  std::function<Status(const TensorShape&, absl::Span<const int32>,
                       absl::Span<const int32>, Graph*,
                       std::vector<std::string>*)>
      graph_creator;
};

class RoundtripXlaSplitConcatNDTest
    : public ::testing::TestWithParam<RoundtripXlaSplitConcatNDTestParam> {};
1351
1352 template <typename T>
Constant(T v,TensorShape shape)1353 Tensor Constant(T v, TensorShape shape) {
1354 Tensor ret(DataTypeToEnum<T>::value, shape);
1355 ret.flat<T>().setConstant(v);
1356 return ret;
1357 }
1358
TEST_P(RoundtripXlaSplitConcatNDTest, NoPadding) {
  const int rank = GetParam().rank;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // Split a [4, ..., 4] tensor 2 ways per dimension and concat it back.
  TF_ASSERT_OK(GetParam().graph_creator(
      TensorShape(std::vector<int64_t>(rank, 4)), std::vector<int32>(rank, 2),
      /*paddings=*/{}, &graph, &fetch_names));

  std::vector<Tensor> outputs;
  TF_ASSERT_OK(
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  // Every element of the roundtripped tensor equals the original.
  test::ExpectTensorEqual<bool>(
      outputs[0],
      Constant<bool>(true, TensorShape(std::vector<int64_t>(rank, 4))));
}
1379
TEST_P(RoundtripXlaSplitConcatNDTest, PartialPadding) {
  const int rank = GetParam().rank;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // Padding of 2 per dimension is applied symmetrically by split and concat,
  // so the roundtrip still reproduces the input exactly.
  TF_ASSERT_OK(GetParam().graph_creator(
      TensorShape(std::vector<int64_t>(rank, 4)), std::vector<int32>(rank, 2),
      std::vector<int32>(rank, 2), &graph, &fetch_names));

  std::vector<Tensor> outputs;
  TF_ASSERT_OK(
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  // Every element of the roundtripped tensor equals the original.
  test::ExpectTensorEqual<bool>(
      outputs[0],
      Constant<bool>(true, TensorShape(std::vector<int64_t>(rank, 4))));
}
1400
TEST_P(RoundtripXlaSplitConcatNDTest, CompletePadding) {
  const int rank = GetParam().rank;

  Graph graph(OpRegistry::Global());
  std::vector<std::string> fetch_names;
  // Padding equal to the full dimension size (4) is also undone by the
  // matching concat, so the roundtrip reproduces the input exactly.
  TF_ASSERT_OK(GetParam().graph_creator(
      TensorShape(std::vector<int64_t>(rank, 4)), std::vector<int32>(rank, 2),
      std::vector<int32>(rank, 4), &graph, &fetch_names));

  std::vector<Tensor> outputs;
  TF_ASSERT_OK(
      RunGraph(graph, fetch_names, /*target_tensor_names=*/{}, &outputs));
  ASSERT_EQ(outputs.size(), 1);

  // Every element of the roundtripped tensor equals the original.
  test::ExpectTensorEqual<bool>(
      outputs[0],
      Constant<bool>(true, TensorShape(std::vector<int64_t>(rank, 4))));
}
1421
// Exercises every supported rank (1..8) for both the tensor and resource
// roundtrip variants; the trailing lambda names each instance after the
// parameter's `name` field.
INSTANTIATE_TEST_SUITE_P(
    RoundtripXlaSplitConcatNDTest, RoundtripXlaSplitConcatNDTest,
    ::testing::ValuesIn<RoundtripXlaSplitConcatNDTestParam>(
        {{"TensorRanked1", 1, CreateRoundtripTensorGraph},
         {"TensorRanked2", 2, CreateRoundtripTensorGraph},
         {"TensorRanked3", 3, CreateRoundtripTensorGraph},
         {"TensorRanked4", 4, CreateRoundtripTensorGraph},
         {"TensorRanked5", 5, CreateRoundtripTensorGraph},
         {"TensorRanked6", 6, CreateRoundtripTensorGraph},
         {"TensorRanked7", 7, CreateRoundtripTensorGraph},
         {"TensorRanked8", 8, CreateRoundtripTensorGraph},
         {"ResourceRanked1", 1, CreateRoundtripResourceGraph},
         {"ResourceRanked2", 2, CreateRoundtripResourceGraph},
         {"ResourceRanked3", 3, CreateRoundtripResourceGraph},
         {"ResourceRanked4", 4, CreateRoundtripResourceGraph},
         {"ResourceRanked5", 5, CreateRoundtripResourceGraph},
         {"ResourceRanked6", 6, CreateRoundtripResourceGraph},
         {"ResourceRanked7", 7, CreateRoundtripResourceGraph},
         {"ResourceRanked8", 8, CreateRoundtripResourceGraph}}),
    [](const ::testing::TestParamInfo<RoundtripXlaSplitConcatNDTest::ParamType>&
           info) { return info.param.name; });
1443
1444 } // namespace
1445 } // namespace tensorflow
1446