1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference.h"
16
17 #include "absl/strings/str_cat.h"
18 #include "tensorflow/core/framework/fake_input.h"
19 #include "tensorflow/core/framework/node_def_builder.h"
20 #include "tensorflow/core/framework/op_def_builder.h"
21 #include "tensorflow/core/framework/tensor_shape.pb.h"
22 #include "tensorflow/core/framework/tensor_testutil.h"
23 #include "tensorflow/core/framework/types.pb.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/status_matchers.h"
26 #include "tensorflow/core/platform/test.h"
27 #include "tensorflow/core/protobuf/error_codes.pb.h"
28
29 namespace tensorflow {
30 namespace shape_inference {
31 namespace {
32
33 using ::tensorflow::testing::StatusIs;
34 using ::testing::_;
35 using ::testing::AllOf;
36 using ::testing::HasSubstr;
37
MakeOpDefWithLists()38 OpDef MakeOpDefWithLists() {
39 OpRegistrationData op_reg_data;
40 TF_EXPECT_OK(OpDefBuilder("dummy")
41 .Input("input: N * float")
42 .Output("output: N * float")
43 .Attr("N:int >= 1")
44 .Finalize(&op_reg_data));
45 return op_reg_data.op_def;
46 }
47
S(std::initializer_list<int64_t> dims)48 PartialTensorShape S(std::initializer_list<int64_t> dims) {
49 return PartialTensorShape(dims);
50 }
51
Unknown()52 PartialTensorShape Unknown() { return PartialTensorShape(); }
53
54 } // namespace
55
56 class ShapeInferenceTest : public ::testing::Test {
57 protected:
58 // These give access to private functions of DimensionHandle and ShapeHandle.
SameHandle(DimensionHandle a,DimensionHandle b)59 bool SameHandle(DimensionHandle a, DimensionHandle b) {
60 return a.SameHandle(b);
61 }
SameHandle(ShapeHandle a,ShapeHandle b)62 bool SameHandle(ShapeHandle a, ShapeHandle b) { return a.SameHandle(b); }
IsSet(DimensionHandle d)63 bool IsSet(DimensionHandle d) { return d.IsSet(); }
IsSet(ShapeHandle s)64 bool IsSet(ShapeHandle s) { return s.IsSet(); }
Relax(InferenceContext * c,DimensionHandle d0,DimensionHandle d1,DimensionHandle * out)65 void Relax(InferenceContext* c, DimensionHandle d0, DimensionHandle d1,
66 DimensionHandle* out) {
67 c->Relax(d0, d1, out);
68 }
Relax(InferenceContext * c,ShapeHandle s0,ShapeHandle s1,ShapeHandle * out)69 void Relax(InferenceContext* c, ShapeHandle s0, ShapeHandle s1,
70 ShapeHandle* out) {
71 c->Relax(s0, s1, out);
72 }
73 void TestMergeHandles(bool input_not_output);
74 void TestRelaxHandles(bool input_not_output);
75
76 static constexpr int kVersion = 0; // used for graph-def version.
77 };
78
79 namespace {
80
// Exercises the name-based input()/set_output() accessors on an op whose input
// and output are both lists of N tensors.
TEST_F(ShapeInferenceTest, InputOutputByName) {
  // Build a node whose "input" (and "output") list holds 3 tensors.
  OpDef op_def = MakeOpDefWithLists();
  NodeDef def;
  // Check the builder status up front; a malformed NodeDef would otherwise
  // only show up as confusing failures further down.
  TF_ASSERT_OK(NodeDefBuilder("dummy", &op_def)
                   .Attr("N", 3)
                   .Input(FakeInput(DT_FLOAT))
                   .Finalize(&def));
  InferenceContext c(kVersion, def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})},
                     {}, {}, {});
  TF_ASSERT_OK(c.construction_status());

  EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
  EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1))));
  EXPECT_EQ("3", c.DebugString(c.NumElements(c.input(2))));

  // Test getters.
  std::vector<ShapeHandle> shapes;
  EXPECT_THAT(
      c.input("nonexistent", &shapes),
      StatusIs(error::INVALID_ARGUMENT, HasSubstr("Unknown input name")));
  // ASSERT rather than EXPECT: the indexing below needs all 3 shapes.
  TF_ASSERT_OK(c.input("input", &shapes));
  ASSERT_EQ(3u, shapes.size());
  EXPECT_EQ("[1,5]", c.DebugString(shapes[0]));
  EXPECT_EQ("[2,5]", c.DebugString(shapes[1]));
  EXPECT_EQ("[1,3]", c.DebugString(shapes[2]));

  // Test setters.
  EXPECT_THAT(
      c.set_output("nonexistent", shapes),
      StatusIs(error::INVALID_ARGUMENT, HasSubstr("Unknown output name")));
  TF_EXPECT_OK(c.set_output("output", shapes));
  EXPECT_EQ("5", c.DebugString(c.NumElements(c.output(0))));
  EXPECT_EQ("10", c.DebugString(c.NumElements(c.output(1))));
  EXPECT_EQ("3", c.DebugString(c.NumElements(c.output(2))));
}
114
MakeOpDef(int num_inputs,int num_outputs)115 static OpDef MakeOpDef(int num_inputs, int num_outputs) {
116 OpRegistrationData op_reg_data;
117 OpDefBuilder b("dummy");
118 for (int i = 0; i < num_inputs; ++i) {
119 b.Input(absl::StrCat("i", i, ": float"));
120 }
121 for (int i = 0; i < num_outputs; ++i) {
122 b.Output(absl::StrCat("o", i, ": float"));
123 }
124 CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
125 return op_reg_data.op_def;
126 }
127
// Value() applied to raw integers: the unknown-dim sentinel and non-negative
// sizes pass through; other negative values are rejected.
TEST_F(ShapeInferenceTest, DimensionOrConstant) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 1), {Unknown()}, {},
                       {}, {});
  EXPECT_EQ(InferenceContext::kUnknownDim,
            ctx.Value(InferenceContext::kUnknownDim));
  EXPECT_EQ(1, ctx.Value(1));

#ifndef NDEBUG
  // The rejection is a DCHECK, so the death test only makes sense when DCHECKs
  // are compiled in.
  EXPECT_DEATH(ctx.Value(-7), "Dimension must be non\\-negative or equal to");
#endif
}
140
// Run() returns the shape function's status and, on failure, decorates the
// message with the node name and op type.
TEST_F(ShapeInferenceTest, Run) {
  NodeDef node_def;
  node_def.set_name("foo");
  node_def.set_op("foo_op");
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 2), {S({1})}, {}, {},
                       {});
  TF_ASSERT_OK(ctx.construction_status());

  // Returns a shape function that requires input 0 to have rank at most
  // `max_rank` and then forwards input 0 to both outputs.
  auto rank_at_most_fn = [](int max_rank) {
    return [max_rank](InferenceContext* ic) -> Status {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(ic->WithRankAtMost(ic->input(0), max_rank, &unused));
      ic->set_output(0, ic->input(0));
      ic->set_output(1, ic->input(0));
      return OkStatus();
    };
  };

  // Input 0 is rank 1, so a limit of 6 is satisfied.
  TF_ASSERT_OK(ctx.Run(rank_at_most_fn(6)));

  // A limit of 0 is violated; the error carries the extra node context.
  EXPECT_THAT(
      ctx.Run(rank_at_most_fn(0)),
      StatusIs(error::INVALID_ARGUMENT,
               AllOf(HasSubstr("Shape must be at most rank 0 but is rank 1"),
                     HasSubstr("node foo"), HasSubstr("foo_op"))));
}
175
176 // Tests different context data added when Run returns error.
// Checks the context Run() appends to a failing shape function's error: the
// input shapes, any input tensors that were requested, and any shape tensors
// for which a partial shape was provided.
TEST_F(ShapeInferenceTest, AttachContext) {
  NodeDef node_def;
  node_def.set_name("foo");
  node_def.set_op("foo_op");

  // Shape fn shared by the last two cases: requests inputs 0 and 1 as shape
  // tensors, then requires input 0 to have rank at most 0.
  const auto request_shape_tensors_fn = [](InferenceContext* ic) -> Status {
    ShapeHandle from_tensor;
    TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeTensor(0, &from_tensor));
    TF_RETURN_IF_ERROR(ic->MakeShapeFromShapeTensor(1, &from_tensor));
    ShapeHandle unused;
    TF_RETURN_IF_ERROR(ic->WithRankAtMost(ic->input(0), 0, &unused));
    ic->set_output(0, ic->input(0));
    return OkStatus();
  };

  // Case 1: nothing beyond the input shapes was requested.
  {
    InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 2), {S({1, 2, 3})},
                         {}, {}, {});
    TF_ASSERT_OK(ctx.construction_status());
    const auto rank_check_fn = [](InferenceContext* ic) -> Status {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(ic->WithRankAtMost(ic->input(0), 0, &unused));
      ic->set_output(0, ic->input(0));
      return OkStatus();
    };
    EXPECT_THAT(
        ctx.Run(rank_check_fn),
        StatusIs(error::INVALID_ARGUMENT,
                 AllOf(HasSubstr("Shape must be at most rank 0 but is rank 3"),
                       HasSubstr("node foo"), HasSubstr("foo_op"),
                       HasSubstr("input shapes: [1,2,3]"))));
  }

  // Case 2: constant input tensors were requested.
  {
    Tensor float_values =
        ::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5});
    InferenceContext ctx(kVersion, node_def, MakeOpDef(2, 2),
                         {S({1, 2, 3}), S({4, 5})}, {nullptr, &float_values},
                         {}, {});
    TF_ASSERT_OK(ctx.construction_status());
    const auto tensor_check_fn = [](InferenceContext* ic) -> Status {
      ic->input_tensor(0);  // Null, so it is left out of the error message.
      ic->input_tensor(1);  // Non-null, so its value is listed in the error.
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(ic->WithRankAtMost(ic->input(0), 0, &unused));
      ic->set_output(0, ic->input(0));
      return OkStatus();
    };
    EXPECT_THAT(
        ctx.Run(tensor_check_fn),
        StatusIs(error::INVALID_ARGUMENT,
                 AllOf(HasSubstr("Shape must be at most rank 0 but is rank 3"),
                       HasSubstr("node foo"), HasSubstr("foo_op"),
                       HasSubstr("input shapes: [1,2,3], [4,5] and with "
                                 "computed input tensors: "
                                 "input[1] = <1.1 2.2 3.3 4.4 5.5>."))));
  }

  // Case 3: shape tensors were requested, with no partial shapes supplied.
  {
    Tensor int_values = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
    InferenceContext ctx(kVersion, node_def, MakeOpDef(2, 2), {S({3}), S({4})},
                         {nullptr, &int_values}, {}, {});
    TF_ASSERT_OK(ctx.construction_status());
    EXPECT_THAT(
        ctx.Run(request_shape_tensors_fn),
        StatusIs(error::INVALID_ARGUMENT,
                 AllOf(HasSubstr("Shape must be at most rank 0 but is rank 1"),
                       HasSubstr("node foo"), HasSubstr("foo_op"),
                       HasSubstr("with input shapes: [3], [4] and with "
                                 "computed input tensors: input[1] "
                                 "= <1 2 3 4 5>."))));
  }

  // Case 4: shape tensors were requested and a partial shape was supplied
  // for input 0.
  {
    Tensor int_values = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
    InferenceContext ctx(kVersion, node_def, MakeOpDef(2, 2), {S({3}), S({4})},
                         {nullptr, &int_values}, {S({10, -1, 5}), Unknown()},
                         {});
    TF_ASSERT_OK(ctx.construction_status());
    EXPECT_THAT(
        ctx.Run(request_shape_tensors_fn),
        StatusIs(
            error::INVALID_ARGUMENT,
            AllOf(HasSubstr("Shape must be at most rank 0 but is rank 1"),
                  HasSubstr("node foo"), HasSubstr("foo_op"),
                  HasSubstr("with input shapes: [3], [4] and with computed "
                            "input tensors: input[1] = <1 2 3 4 5> and with "
                            "input tensors computed "
                            "as partial shapes: input[0] = [10,?,5]."))));
  }
}
279
// Rank/dimension accessors on unknown-rank, partially known and scalar inputs.
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(3, 2),
                       {Unknown(), S({1, -1, 3}), S({})}, {}, {}, {});
  EXPECT_EQ(3, ctx.num_inputs());
  EXPECT_EQ(2, ctx.num_outputs());

  // Unknown rank: any dimension lookup, including negative and out-of-range
  // indices, yields an unknown dimension.
  const ShapeHandle unknown_rank = ctx.input(0);
  EXPECT_EQ("?", ctx.DebugString(unknown_rank));
  EXPECT_FALSE(ctx.RankKnown(unknown_rank));
  EXPECT_EQ(InferenceContext::kUnknownRank, ctx.Rank(unknown_rank));
  for (const int idx : {0, -1, 1000}) {
    EXPECT_EQ("?", ctx.DebugString(ctx.Dim(unknown_rank, idx)));
  }

  // [1,?,3]: each dimension is reachable by its non-negative index or the
  // equivalent negative index, and both return the same handle.
  const ShapeHandle partial = ctx.input(1);
  EXPECT_EQ("[1,?,3]", ctx.DebugString(partial));
  EXPECT_TRUE(ctx.RankKnown(partial));
  EXPECT_EQ(3, ctx.Rank(partial));

  DimensionHandle dim = ctx.Dim(partial, 0);
  EXPECT_EQ(1, ctx.Value(dim));
  EXPECT_TRUE(SameHandle(dim, ctx.Dim(partial, -3)));
  EXPECT_TRUE(ctx.ValueKnown(dim));
  EXPECT_EQ("1", ctx.DebugString(dim));

  dim = ctx.Dim(partial, 1);
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(dim));
  EXPECT_FALSE(ctx.ValueKnown(dim));
  EXPECT_TRUE(SameHandle(dim, ctx.Dim(partial, -2)));
  EXPECT_EQ("?", ctx.DebugString(dim));

  dim = ctx.Dim(partial, 2);
  EXPECT_EQ(3, ctx.Value(dim));
  EXPECT_TRUE(SameHandle(dim, ctx.Dim(partial, -1)));
  EXPECT_TRUE(ctx.ValueKnown(dim));
  EXPECT_EQ("3", ctx.DebugString(dim));

  // A scalar has known rank 0.
  const ShapeHandle scalar = ctx.input(2);
  EXPECT_EQ("[]", ctx.DebugString(scalar));
  EXPECT_TRUE(ctx.RankKnown(scalar));
  EXPECT_EQ(0, ctx.Rank(scalar));
}
320
// NumElements() is known only when every dimension is known.
TEST_F(ShapeInferenceTest, NumElements) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(3, 2),
                       {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {},
                       {});
  const ShapeHandle unknown_rank = ctx.input(0);
  const ShapeHandle partially_known = ctx.input(1);
  const ShapeHandle fully_known = ctx.input(2);

  EXPECT_EQ("?", ctx.DebugString(ctx.NumElements(unknown_rank)));
  EXPECT_EQ("?", ctx.DebugString(ctx.NumElements(partially_known)));

  // The unknown element count is not the same handle as the shape's own
  // unknown dimension.
  EXPECT_FALSE(SameHandle(ctx.Dim(partially_known, 1),
                          ctx.NumElements(partially_known)));

  // 5 * 4 * 3 * 2.
  EXPECT_EQ("120", ctx.DebugString(ctx.NumElements(fully_known)));
}
334
// WithRank() on unknown-rank and known-rank inputs.
TEST_F(ShapeInferenceTest, WithRank) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(2, 2),
                       {Unknown(), S({1, -1, 3})}, {}, {}, {});
  const ShapeHandle unknown_rank = ctx.input(0);
  const ShapeHandle rank3 = ctx.input(1);
  ShapeHandle first;
  ShapeHandle second;

  // An unknown-rank input always satisfies the constraint and yields a shape
  // of the requested rank with unknown dimensions.
  TF_EXPECT_OK(ctx.WithRank(unknown_rank, 1, &first));
  EXPECT_EQ("[?]", ctx.DebugString(first));

  TF_EXPECT_OK(ctx.WithRank(unknown_rank, 2, &second));
  EXPECT_EQ("[?,?]", ctx.DebugString(second));
  EXPECT_FALSE(SameHandle(first, second));
  EXPECT_FALSE(SameHandle(ctx.Dim(second, 0), ctx.Dim(second, 1)));

  // Requesting the same rank again still produces a different handle.
  TF_EXPECT_OK(ctx.WithRank(unknown_rank, 1, &second));
  EXPECT_EQ("[?]", ctx.DebugString(second));
  EXPECT_FALSE(SameHandle(first, second));

  TF_EXPECT_OK(ctx.WithRank(unknown_rank, 0, &first));
  EXPECT_EQ("[]", ctx.DebugString(first));

  // A known-rank input rejects a different rank and unsets the output...
  first = rank3;
  EXPECT_THAT(ctx.WithRank(rank3, 2, &first),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Shape must be rank 2 but is rank 3")));
  EXPECT_FALSE(IsSet(first));

  // ...and returns the input handle itself for the matching rank.
  TF_EXPECT_OK(ctx.WithRank(rank3, 3, &first));
  EXPECT_TRUE(SameHandle(first, rank3));

  // The inputs are left untouched.
  EXPECT_EQ("?", ctx.DebugString(unknown_rank));
  EXPECT_EQ("[1,?,3]", ctx.DebugString(rank3));
}
375
// WithRankAtMost() on unknown-rank and known-rank inputs.
TEST_F(ShapeInferenceTest, WithRankAtMost) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(2, 2),
                       {Unknown(), S({1, -1, 3})}, {}, {}, {});
  const ShapeHandle unknown_rank = ctx.input(0);
  const ShapeHandle rank3 = ctx.input(1);
  ShapeHandle first;
  ShapeHandle second;

  // An unknown-rank input always passes and is returned as-is.
  TF_EXPECT_OK(ctx.WithRankAtMost(unknown_rank, 1, &first));
  EXPECT_EQ("?", ctx.DebugString(first));
  EXPECT_TRUE(SameHandle(unknown_rank, first));

  TF_EXPECT_OK(ctx.WithRankAtMost(unknown_rank, 2, &second));
  EXPECT_EQ("?", ctx.DebugString(second));
  EXPECT_TRUE(SameHandle(first, second));

  // A rank-3 input fails a limit of 2 (unsetting the output)...
  first = rank3;
  EXPECT_THAT(
      ctx.WithRankAtMost(rank3, 2, &first),
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("Shape must be at most rank 2 but is rank 3")));
  EXPECT_FALSE(IsSet(first));

  // ...and passes, returning itself, for any limit of 3 or more.
  for (const int max_rank : {3, 4, 5}) {
    TF_EXPECT_OK(ctx.WithRankAtMost(rank3, max_rank, &first));
    EXPECT_TRUE(SameHandle(first, rank3));
  }

  // The inputs are left untouched.
  EXPECT_EQ("?", ctx.DebugString(unknown_rank));
  EXPECT_EQ("[1,?,3]", ctx.DebugString(rank3));
}
414
// WithRankAtLeast() on unknown-rank and known-rank inputs.
TEST_F(ShapeInferenceTest, WithRankAtLeast) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(2, 2),
                       {Unknown(), S({1, -1, 3})}, {}, {}, {});
  const ShapeHandle unknown_rank = ctx.input(0);
  const ShapeHandle rank3 = ctx.input(1);
  ShapeHandle first;
  ShapeHandle second;

  // An unknown-rank input always passes and is returned as-is.
  TF_EXPECT_OK(ctx.WithRankAtLeast(unknown_rank, 1, &first));
  EXPECT_EQ("?", ctx.DebugString(first));
  EXPECT_TRUE(SameHandle(unknown_rank, first));

  TF_EXPECT_OK(ctx.WithRankAtLeast(unknown_rank, 2, &second));
  EXPECT_EQ("?", ctx.DebugString(second));
  EXPECT_TRUE(SameHandle(first, second));

  // A rank-3 input fails a minimum of 4 (unsetting the output)...
  first = rank3;
  EXPECT_THAT(
      ctx.WithRankAtLeast(rank3, 4, &first),
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("Shape must be at least rank 4 but is rank 3")));
  EXPECT_FALSE(IsSet(first));

  // ...and passes, returning itself, for any minimum of 3 or less.
  for (const int min_rank : {3, 2, 0}) {
    TF_EXPECT_OK(ctx.WithRankAtLeast(rank3, min_rank, &first));
    EXPECT_TRUE(SameHandle(first, rank3));
  }

  // The inputs are left untouched.
  EXPECT_EQ("?", ctx.DebugString(unknown_rank));
  EXPECT_EQ("[1,?,3]", ctx.DebugString(rank3));
}
453
// WithValue() on unknown and known dimensions.
TEST_F(ShapeInferenceTest, WithValue) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 2), {S({1, -1})}, {},
                       {}, {});
  const DimensionHandle one = ctx.Dim(ctx.input(0), 0);
  const DimensionHandle unknown = ctx.Dim(ctx.input(0), 1);
  DimensionHandle first;
  DimensionHandle second;

  // An unknown dimension always passes and yields a handle with the requested
  // value; repeated requests give distinct handles.
  TF_EXPECT_OK(ctx.WithValue(unknown, 1, &first));
  EXPECT_EQ(1, ctx.Value(first));

  TF_EXPECT_OK(ctx.WithValue(unknown, 2, &second));
  EXPECT_EQ(2, ctx.Value(second));
  EXPECT_FALSE(SameHandle(first, second));
  EXPECT_FALSE(SameHandle(first, unknown));

  TF_EXPECT_OK(ctx.WithValue(unknown, 1, &second));
  EXPECT_EQ(1, ctx.Value(second));
  EXPECT_FALSE(SameHandle(first, second));

  // A known dimension rejects any other value and unsets the output.
  first = one;
  EXPECT_THAT(ctx.WithValue(one, 0, &first),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Dimension must be 0 but is 1")));
  EXPECT_FALSE(IsSet(first));

  first = one;
  EXPECT_THAT(ctx.WithValue(one, 2, &first),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Dimension must be 2 but is 1")));
  EXPECT_FALSE(IsSet(first));

  // Its own value succeeds and returns the input handle.
  TF_EXPECT_OK(ctx.WithValue(one, 1, &first));
  EXPECT_TRUE(SameHandle(one, first));

  // The inputs are left untouched.
  EXPECT_EQ("1", ctx.DebugString(one));
  EXPECT_EQ("?", ctx.DebugString(unknown));
}
496
// Merge() on dimensions, including which merges MergedDims() records.
TEST_F(ShapeInferenceTest, MergeDim) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 2),
                       {S({2, -1, 2, 1, -1})}, {}, {}, {});

  const ShapeHandle input = ctx.input(0);
  const DimensionHandle two = ctx.Dim(input, 0);
  const DimensionHandle unknown = ctx.Dim(input, 1);
  const DimensionHandle two_b = ctx.Dim(input, 2);
  const DimensionHandle one = ctx.Dim(input, 3);
  const DimensionHandle unknown_b = ctx.Dim(input, 4);
  DimensionHandle merged;

  // An unknown operand defers to the other one (to the first when both are
  // unknown).
  TF_EXPECT_OK(ctx.Merge(two, unknown, &merged));
  EXPECT_TRUE(SameHandle(two, merged));
  TF_EXPECT_OK(ctx.Merge(unknown, two, &merged));
  EXPECT_TRUE(SameHandle(two, merged));
  TF_EXPECT_OK(ctx.Merge(unknown, unknown_b, &merged));
  EXPECT_TRUE(SameHandle(unknown, merged));

  // Those three merges are recorded, in call order.
  auto recorded = ctx.MergedDims();
  ASSERT_EQ(3, recorded.size());
  EXPECT_TRUE(SameHandle(recorded[0].first, two));
  EXPECT_TRUE(SameHandle(recorded[0].second, unknown));
  EXPECT_TRUE(SameHandle(recorded[1].first, unknown));
  EXPECT_TRUE(SameHandle(recorded[1].second, two));
  EXPECT_TRUE(SameHandle(recorded[2].first, unknown));
  EXPECT_TRUE(SameHandle(recorded[2].second, unknown_b));

  // Self-merges return the operand and record nothing.
  TF_EXPECT_OK(ctx.Merge(two, two, &merged));
  EXPECT_TRUE(SameHandle(two, merged));
  TF_EXPECT_OK(ctx.Merge(unknown, unknown, &merged));
  EXPECT_TRUE(SameHandle(unknown, merged));
  recorded = ctx.MergedDims();
  EXPECT_EQ(3, recorded.size());

  // Equal known values return the first operand and record nothing.
  TF_EXPECT_OK(ctx.Merge(two, two_b, &merged));
  EXPECT_TRUE(SameHandle(two, merged));
  TF_EXPECT_OK(ctx.Merge(two_b, two, &merged));
  EXPECT_TRUE(SameHandle(two_b, merged));
  recorded = ctx.MergedDims();
  EXPECT_EQ(3, recorded.size());

  // Unequal known values fail, unset the output and record nothing.
  EXPECT_THAT(ctx.Merge(two, one, &merged),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Dimensions must be equal, but are 2 and 1")));
  EXPECT_FALSE(IsSet(merged));
  EXPECT_THAT(ctx.Merge(one, two, &merged),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Dimensions must be equal, but are 1 and 2")));
  EXPECT_FALSE(IsSet(merged));

  recorded = ctx.MergedDims();
  EXPECT_EQ(3, recorded.size());
}
559
// Relax() on dimensions: any disagreement yields an unknown dimension.
TEST_F(ShapeInferenceTest, RelaxDim) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 2),
                       {S({2, InferenceContext::kUnknownDim, 2, 1,
                           InferenceContext::kUnknownDim})},
                       {}, {}, {});

  const ShapeHandle input = ctx.input(0);
  const DimensionHandle two = ctx.Dim(input, 0);
  const DimensionHandle unknown = ctx.Dim(input, 1);
  const DimensionHandle two_b = ctx.Dim(input, 2);
  const DimensionHandle one = ctx.Dim(input, 3);
  const DimensionHandle unknown_b = ctx.Dim(input, 4);
  DimensionHandle relaxed;

  // With an unknown on either side the result is unknown:
  // Relax(known, unknown) reuses the unknown handle, Relax(unknown, known)
  // yields some other unknown, and Relax(unknown, unknown_b) yields unknown_b.
  Relax(&ctx, two, unknown, &relaxed);
  EXPECT_TRUE(SameHandle(unknown, relaxed));
  EXPECT_FALSE(SameHandle(unknown_b, relaxed));
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(relaxed));

  Relax(&ctx, unknown, two, &relaxed);
  EXPECT_FALSE(SameHandle(unknown, relaxed));
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(relaxed));

  Relax(&ctx, unknown, unknown_b, &relaxed);
  EXPECT_FALSE(SameHandle(unknown, relaxed));
  EXPECT_TRUE(SameHandle(unknown_b, relaxed));
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(relaxed));

  // Relaxing a handle with itself returns that handle.
  Relax(&ctx, two, two, &relaxed);
  EXPECT_TRUE(SameHandle(two, relaxed));
  Relax(&ctx, unknown, unknown, &relaxed);
  EXPECT_TRUE(SameHandle(unknown, relaxed));

  // Equal known values return the first operand.
  Relax(&ctx, two, two_b, &relaxed);
  EXPECT_TRUE(SameHandle(two, relaxed));
  Relax(&ctx, two_b, two, &relaxed);
  EXPECT_TRUE(SameHandle(two_b, relaxed));

  // Different known values relax to unknown.
  Relax(&ctx, two, one, &relaxed);
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(relaxed));
  Relax(&ctx, one, two, &relaxed);
  EXPECT_EQ(InferenceContext::kUnknownDim, ctx.Value(relaxed));
}
606
// Relax() on shapes: dimensions (or the whole rank) that disagree become
// unknown.
TEST_F(ShapeInferenceTest, RelaxShape) {
  NodeDef def;
  InferenceContext c(
      kVersion, def, MakeOpDef(7, 2),
      {Unknown(), S({1, 2}), S({InferenceContext::kUnknownDim, 2}),
       S({1, InferenceContext::kUnknownDim}), S({1, 3}), Unknown(), S({1})},
      {}, {}, {});

  auto s_unknown = c.input(0);
  auto s_1_2 = c.input(1);
  auto s_u_2 = c.input(2);
  auto s_1_u = c.input(3);
  auto s_1_3 = c.input(4);
  auto s_unknown_b = c.input(5);
  auto s_1 = c.input(6);
  ShapeHandle out;

  // Relaxing with an unknown-rank shape on either side yields an unknown-rank
  // shape.
  Relax(&c, s_unknown, s_1_2, &out);
  // Compare against the result (not two unrelated inputs, which could never
  // be the same handle): the unknown result is a new handle.
  EXPECT_FALSE(SameHandle(s_unknown, out));
  EXPECT_EQ("?", c.DebugString(out));
  Relax(&c, s_u_2, s_unknown, &out);
  EXPECT_FALSE(SameHandle(s_u_2, out));
  EXPECT_EQ("?", c.DebugString(out));
  Relax(&c, s_unknown, s_unknown_b, &out);
  EXPECT_FALSE(SameHandle(s_unknown, out));
  EXPECT_TRUE(SameHandle(s_unknown_b, out));
  EXPECT_EQ("?", c.DebugString(out));

  // Relaxing with self returns self.
  Relax(&c, s_1_2, s_1_2, &out);
  EXPECT_TRUE(SameHandle(out, s_1_2));

  // Relaxing where one of the inputs has less information.
  out = ShapeHandle();
  Relax(&c, s_1_2, s_u_2, &out);
  EXPECT_FALSE(SameHandle(s_u_2, out));
  EXPECT_EQ("[?,2]", c.DebugString(out));
  out = ShapeHandle();
  Relax(&c, s_u_2, s_1_2, &out);
  EXPECT_FALSE(SameHandle(s_u_2, out));
  EXPECT_EQ("[?,2]", c.DebugString(out));

  // Relaxing where each input has one distinct unknown dimension.
  Relax(&c, s_u_2, s_1_u, &out);
  EXPECT_EQ("[?,?]", c.DebugString(out));
  EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
  EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 1), c.Dim(out, 1)));
  auto s_u1 = c.UnknownShapeOfRank(1);
  auto s_u2 = c.UnknownShapeOfRank(1);
  Relax(&c, s_u1, s_u2, &out);
  EXPECT_FALSE(SameHandle(s_u1, out));

  // Relaxing with mismatched values in a dimension returns a shape with that
  // dimension unknown.
  out = s_unknown;
  Relax(&c, s_u_2, s_1_3, &out);
  EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
  EXPECT_EQ("[?,?]", c.DebugString(out));
  out = s_unknown;
  Relax(&c, s_1_3, s_u_2, &out);
  EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
  EXPECT_EQ("[?,?]", c.DebugString(out));

  // Relaxing with mismatched ranks returns an unknown-rank shape.
  out = s_unknown;
  Relax(&c, s_1, s_1_2, &out);
  EXPECT_EQ("?", c.DebugString(out));
}
676
// Merge() on shapes, including which merges MergedShapes() records.
TEST_F(ShapeInferenceTest, MergeShape) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(7, 2),
                       {Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}),
                        S({1, 3}), Unknown(), S({1})},
                       {}, {}, {});

  const ShapeHandle unknown = ctx.input(0);
  const ShapeHandle known_1_2 = ctx.input(1);
  const ShapeHandle partial_u_2 = ctx.input(2);
  const ShapeHandle partial_1_u = ctx.input(3);
  const ShapeHandle known_1_3 = ctx.input(4);
  const ShapeHandle unknown_b = ctx.input(5);
  const ShapeHandle rank1 = ctx.input(6);
  ShapeHandle merged;

  // An unknown-rank operand defers to the other one (to the first when both
  // are unknown).
  TF_EXPECT_OK(ctx.Merge(unknown, known_1_2, &merged));
  EXPECT_TRUE(SameHandle(known_1_2, merged));
  TF_EXPECT_OK(ctx.Merge(partial_u_2, unknown, &merged));
  EXPECT_TRUE(SameHandle(partial_u_2, merged));
  TF_EXPECT_OK(ctx.Merge(unknown, unknown_b, &merged));
  EXPECT_TRUE(SameHandle(unknown, merged));

  // Those three merges are recorded, in call order.
  auto recorded = ctx.MergedShapes();
  ASSERT_EQ(3, recorded.size());
  EXPECT_TRUE(SameHandle(recorded[0].first, unknown));
  EXPECT_TRUE(SameHandle(recorded[0].second, known_1_2));
  EXPECT_TRUE(SameHandle(recorded[1].first, partial_u_2));
  EXPECT_TRUE(SameHandle(recorded[1].second, unknown));
  EXPECT_TRUE(SameHandle(recorded[2].first, unknown));
  EXPECT_TRUE(SameHandle(recorded[2].second, unknown_b));

  // A self-merge returns the operand and records nothing.
  TF_EXPECT_OK(ctx.Merge(known_1_2, known_1_2, &merged));
  EXPECT_TRUE(SameHandle(merged, known_1_2));
  recorded = ctx.MergedShapes();
  EXPECT_EQ(3, recorded.size());

  // When one operand already is the answer, that operand is returned.
  merged = ShapeHandle();
  TF_EXPECT_OK(ctx.Merge(known_1_2, partial_u_2, &merged));
  EXPECT_TRUE(SameHandle(known_1_2, merged));
  merged = ShapeHandle();
  TF_EXPECT_OK(ctx.Merge(partial_u_2, known_1_2, &merged));
  EXPECT_TRUE(SameHandle(known_1_2, merged));

  recorded = ctx.MergedShapes();
  ASSERT_EQ(5, recorded.size());
  EXPECT_TRUE(SameHandle(recorded[3].first, known_1_2));
  EXPECT_TRUE(SameHandle(recorded[3].second, partial_u_2));
  EXPECT_TRUE(SameHandle(recorded[4].first, partial_u_2));
  EXPECT_TRUE(SameHandle(recorded[4].second, known_1_2));

  // When neither operand is the answer, a new shape is built from the known
  // dimensions of each.
  TF_EXPECT_OK(ctx.Merge(partial_u_2, partial_1_u, &merged));
  EXPECT_FALSE(SameHandle(merged, partial_u_2));
  EXPECT_FALSE(SameHandle(merged, partial_1_u));
  EXPECT_EQ("[1,2]", ctx.DebugString(merged));
  EXPECT_TRUE(SameHandle(ctx.Dim(partial_1_u, 0), ctx.Dim(merged, 0)));
  EXPECT_TRUE(SameHandle(ctx.Dim(partial_u_2, 1), ctx.Dim(merged, 1)));

  recorded = ctx.MergedShapes();
  ASSERT_EQ(7, recorded.size());
  EXPECT_TRUE(SameHandle(recorded[5].first, partial_u_2));
  EXPECT_TRUE(SameHandle(recorded[5].second, partial_1_u));
  EXPECT_TRUE(SameHandle(recorded[6].first, partial_u_2));
  EXPECT_TRUE(SameHandle(recorded[6].second, merged));

  const ShapeHandle vec_a = ctx.UnknownShapeOfRank(1);
  const ShapeHandle vec_b = ctx.UnknownShapeOfRank(1);
  TF_EXPECT_OK(ctx.Merge(vec_a, vec_b, &merged));
  EXPECT_TRUE(SameHandle(vec_a, merged));

  recorded = ctx.MergedShapes();
  ASSERT_EQ(8, recorded.size());
  EXPECT_TRUE(SameHandle(recorded[7].first, vec_a));
  EXPECT_TRUE(SameHandle(recorded[7].second, vec_b));

  // Incompatible merges fail, unset the output and record nothing.
  merged = unknown;
  EXPECT_THAT(
      ctx.Merge(partial_u_2, known_1_3, &merged),
      StatusIs(
          error::INVALID_ARGUMENT,
          HasSubstr(
              "Dimension 1 in both shapes must be equal, but are 2 and 3")));
  EXPECT_FALSE(IsSet(merged));

  merged = unknown;
  EXPECT_THAT(
      ctx.Merge(known_1_3, partial_u_2, &merged),
      StatusIs(
          error::INVALID_ARGUMENT,
          HasSubstr(
              "Dimension 1 in both shapes must be equal, but are 3 and 2")));
  EXPECT_FALSE(IsSet(merged));

  merged = unknown;
  EXPECT_THAT(
      ctx.Merge(rank1, known_1_2, &merged),
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("Shapes must be equal rank, but are 1 and 2")));
  EXPECT_FALSE(IsSet(merged));

  recorded = ctx.MergedShapes();
  EXPECT_EQ(8, recorded.size());
}
787
// MergePrefix() merges a shape's leading dimensions with a prefix shape.
TEST_F(ShapeInferenceTest, MergePrefix) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(4, 2),
                       {Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4})}, {},
                       {}, {});

  const ShapeHandle unknown = ctx.input(0);
  const ShapeHandle prefix_u_2 = ctx.input(1);
  const ShapeHandle shape_1_u_3 = ctx.input(2);
  const ShapeHandle shape_2_4 = ctx.input(3);

  ShapeHandle merged;
  ShapeHandle merged_prefix;

  // With an unknown-rank operand, both inputs come back unchanged.
  TF_EXPECT_OK(ctx.MergePrefix(unknown, prefix_u_2, &merged, &merged_prefix));
  EXPECT_TRUE(SameHandle(merged, unknown));
  EXPECT_TRUE(SameHandle(merged_prefix, prefix_u_2));
  TF_EXPECT_OK(
      ctx.MergePrefix(shape_1_u_3, unknown, &merged, &merged_prefix));
  EXPECT_TRUE(SameHandle(merged, shape_1_u_3));
  EXPECT_TRUE(SameHandle(merged_prefix, unknown));

  // [1,?,3] with prefix [?,2] gives [1,2,3] and prefix [1,2], sharing the
  // known dimension handles of the inputs.
  TF_EXPECT_OK(
      ctx.MergePrefix(shape_1_u_3, prefix_u_2, &merged, &merged_prefix));
  EXPECT_FALSE(SameHandle(merged, shape_1_u_3));
  EXPECT_EQ("[1,2]", ctx.DebugString(merged_prefix));
  EXPECT_EQ("[1,2,3]", ctx.DebugString(merged));
  EXPECT_TRUE(SameHandle(ctx.Dim(merged_prefix, 0), ctx.Dim(merged, 0)));
  EXPECT_TRUE(SameHandle(ctx.Dim(merged, 0), ctx.Dim(shape_1_u_3, 0)));
  EXPECT_TRUE(SameHandle(ctx.Dim(merged_prefix, 1), ctx.Dim(merged, 1)));
  EXPECT_TRUE(SameHandle(ctx.Dim(merged_prefix, 1), ctx.Dim(prefix_u_2, 1)));

  // Incompatible leading dimensions fail and leave both outputs unset.
  merged = unknown;
  merged_prefix = unknown;
  EXPECT_THAT(
      ctx.MergePrefix(shape_1_u_3, shape_2_4, &merged, &merged_prefix),
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("Dimensions must be equal, but are 1 and 2")));
  EXPECT_FALSE(IsSet(merged));
  EXPECT_FALSE(IsSet(merged_prefix));

  // A prefix longer than the shape fails the same way.
  merged = unknown;
  merged_prefix = unknown;
  EXPECT_THAT(
      ctx.MergePrefix(shape_2_4, shape_1_u_3, &merged, &merged_prefix),
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("Shape must be at least rank 3 but is rank 2")));
  EXPECT_FALSE(IsSet(merged));
  EXPECT_FALSE(IsSet(merged_prefix));
}
843
// Subshape() with non-negative, negative and out-of-range bounds.
TEST_F(ShapeInferenceTest, Subshape) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(2, 2),
                       {S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {});

  // Unknown rank: starting at 0 returns the input itself; any other start
  // returns a different unknown-rank shape.
  const ShapeHandle unknown = ctx.input(1);
  ShapeHandle out;
  TF_EXPECT_OK(ctx.Subshape(unknown, 0, &out));
  EXPECT_EQ("?", ctx.DebugString(out));
  EXPECT_TRUE(SameHandle(out, unknown));
  for (const int start : {1, 200}) {
    TF_EXPECT_OK(ctx.Subshape(unknown, start, &out));
    EXPECT_EQ("?", ctx.DebugString(out));
    EXPECT_FALSE(SameHandle(out, unknown));
  }

  const int kFullRank = 5;
  ShapeHandle variants[4];
  const ShapeHandle in0 = ctx.input(0);
  TF_EXPECT_OK(ctx.Subshape(in0, 0, &out));
  EXPECT_EQ("[1,2,3,?,5]", ctx.DebugString(out));
  EXPECT_TRUE(SameHandle(out, in0));
  EXPECT_EQ(kFullRank, ctx.Rank(out));

  // For every [start, end) range, including bounds past the rank, spelling a
  // bound as a non-negative or a negative index must give the same result.
  for (int start = 0; start <= kFullRank + 1; ++start) {
    // Bounds >= kFullRank have no negative spelling, so kFullRank is used.
    const int neg_start = start >= kFullRank ? kFullRank : start - kFullRank;
    for (int end = start; end <= kFullRank + 1; ++end) {
      const int neg_end = end >= kFullRank ? kFullRank : end - kFullRank;
      TF_ASSERT_OK(ctx.Subshape(in0, start, end, &variants[0]));
      TF_ASSERT_OK(ctx.Subshape(in0, neg_start, end, &variants[1]));
      TF_ASSERT_OK(ctx.Subshape(in0, start, neg_end, &variants[2]));
      TF_ASSERT_OK(ctx.Subshape(in0, neg_start, neg_end, &variants[3]));

      const int expected_rank =
          std::min(kFullRank, end) - std::min(kFullRank, start);
      for (int arr_idx = 0; arr_idx < 4; ++arr_idx) {
        out = variants[arr_idx];
        ASSERT_EQ(expected_rank, ctx.Rank(out))
            << "start: " << start << " end: " << end << " arr_idx: " << arr_idx
            << " in0: " << ctx.DebugString(in0)
            << " out: " << ctx.DebugString(out);
        for (int d = 0; d < ctx.Rank(out); ++d) {
          EXPECT_TRUE(SameHandle(ctx.Dim(in0, start + d), ctx.Dim(out, d)))
              << "arr_idx: " << arr_idx;
        }
      }
    }
  }

  // Invalid bounds fail and unset the output.
  out = unknown;
  EXPECT_THAT(
      ctx.Subshape(in0, 6, -3, &out),
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("Subshape must have computed start <= end, but is 5 "
                         "and 2 (computed from start 6 and end -3 over shape "
                         "with rank 5)")));
  EXPECT_FALSE(IsSet(out));

  out = unknown;
  EXPECT_THAT(
      ctx.Subshape(in0, -50, 100, &out),
      StatusIs(
          error::INVALID_ARGUMENT,
          HasSubstr(
              "Subshape start out of bounds: -50, for shape with rank 5")));
  EXPECT_FALSE(IsSet(out));

  out = unknown;
  EXPECT_THAT(
      ctx.Subshape(in0, 0, -50, &out),
      StatusIs(
          error::INVALID_ARGUMENT,
          HasSubstr("Subshape end out of bounds: -50, for shape with rank 5")));
  EXPECT_FALSE(IsSet(out));
}
922
TEST_F(ShapeInferenceTest, Concatenate) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(3, 2),
                     {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {});

  const ShapeHandle lhs = c.input(0);
  const ShapeHandle rhs = c.input(1);
  const ShapeHandle unknown = c.input(2);
  ShapeHandle out;

  // Concatenating with an unknown-rank shape (here on both sides, then on the
  // left only) yields a new unknown shape.
  TF_EXPECT_OK(c.Concatenate(unknown, unknown, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, unknown));
  TF_EXPECT_OK(c.Concatenate(unknown, lhs, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, unknown));

  // With known ranks, the output dims are lhs's dims followed by rhs's dims,
  // and they reuse the very same dimension handles.
  TF_EXPECT_OK(c.Concatenate(lhs, rhs, &out));
  EXPECT_EQ("[1,?,3,4,5]", c.DebugString(out));
  const int lhs_rank = c.Rank(lhs);
  for (int i = 0; i < lhs_rank; ++i) {
    EXPECT_TRUE(SameHandle(c.Dim(lhs, i), c.Dim(out, i)));
  }
  for (int j = 0; j < c.Rank(rhs); ++j) {
    EXPECT_TRUE(SameHandle(c.Dim(rhs, j), c.Dim(out, lhs_rank + j)));
  }
}
949
TEST_F(ShapeInferenceTest, ReplaceDim) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()},
                     {}, {}, {});

  const ShapeHandle base = c.input(0);
  const ShapeHandle unknown = c.input(1);
  // Replacement dimensions taken from the input itself (values 2 and 3).
  const DimensionHandle dim_2 = c.Dim(base, 1);
  const DimensionHandle dim_3 = c.Dim(base, 2);

  ShapeHandle replaced;
  // Non-negative indices, for each position of the rank-3 input.
  TF_EXPECT_OK(c.ReplaceDim(base, 0, dim_2, &replaced));
  EXPECT_EQ("[2,2,3]", c.DebugString(replaced));
  TF_EXPECT_OK(c.ReplaceDim(base, 2, dim_2, &replaced));
  EXPECT_EQ("[1,2,2]", c.DebugString(replaced));
  TF_EXPECT_OK(c.ReplaceDim(base, 1, dim_3, &replaced));
  EXPECT_EQ("[1,3,3]", c.DebugString(replaced));
  // An unknown-rank shape stays unknown.
  TF_EXPECT_OK(c.ReplaceDim(unknown, 0, dim_2, &replaced));
  EXPECT_EQ("?", c.DebugString(replaced));

  // A negative index counts from the end.
  TF_EXPECT_OK(c.ReplaceDim(base, -1, dim_2, &replaced));
  EXPECT_EQ("[1,2,2]", c.DebugString(replaced));
  TF_EXPECT_OK(c.ReplaceDim(unknown, -1, dim_2, &replaced));
  EXPECT_EQ("?", c.DebugString(replaced));

  // Index 3 (past the end) and -4 (before the start) are rejected and leave
  // `replaced` unset.
  EXPECT_THAT(
      c.ReplaceDim(base, 3, dim_2, &replaced),
      StatusIs(error::INVALID_ARGUMENT, HasSubstr("Out of range dim_index")));
  EXPECT_FALSE(IsSet(replaced));
  replaced = base;
  EXPECT_THAT(
      c.ReplaceDim(base, -4, dim_2, &replaced),
      StatusIs(error::INVALID_ARGUMENT, HasSubstr("Out of range dim_index")));
  EXPECT_FALSE(IsSet(replaced));
}
985
TEST_F(ShapeInferenceTest, MakeShape) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, 3, -1, 5})}, {},
                     {}, {});

  // Collect the input's dimension handles, last dimension first.
  const ShapeHandle in0 = c.input(0);
  const int rank = c.Rank(in0);
  std::vector<DimensionHandle> reversed;
  reversed.reserve(rank);
  for (int idx = rank - 1; idx >= 0; --idx) {
    reversed.push_back(c.Dim(in0, idx));
  }

  // MakeShape keeps the supplied dimension handles...
  const ShapeHandle first = c.MakeShape(reversed);
  EXPECT_EQ("[5,?,3,2,1]", c.DebugString(first));
  EXPECT_TRUE(SameHandle(c.Dim(first, 0), c.Dim(in0, rank - 1)));

  // ...but each call returns a different shape handle.
  const ShapeHandle second = c.MakeShape(reversed);
  EXPECT_FALSE(SameHandle(first, second));
  EXPECT_TRUE(SameHandle(c.Dim(second, 0), c.Dim(in0, rank - 1)));

  // Constants and existing dimension handles can be mixed.
  const ShapeHandle mixed = c.MakeShape({1, 2, reversed[2]});
  EXPECT_FALSE(SameHandle(first, mixed));
  EXPECT_EQ("[1,2,3]", c.DebugString(mixed));
}
1011
TEST_F(ShapeInferenceTest, UnknownShape) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  // Both calls print as "?" but must return distinct handles.
  const ShapeHandle first = c.UnknownShape();
  const ShapeHandle second = c.UnknownShape();
  for (const ShapeHandle& shape : {first, second}) {
    EXPECT_EQ("?", c.DebugString(shape));
  }
  EXPECT_FALSE(SameHandle(first, second));
}
1023
TEST_F(ShapeInferenceTest, KnownShapeToProto) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});

  auto s = c.MakeShape({1, 2, 3});
  TensorShapeProto proto;
  c.ShapeHandleToProto(s, &proto);

  // A known-rank shape serializes with unknown_rank unset and one dim entry
  // per dimension, in order.
  EXPECT_FALSE(proto.unknown_rank());
  // ASSERT (not EXPECT) so that a wrong size stops the test before the
  // proto.dim(i) accesses below go out of range.
  ASSERT_EQ(3, proto.dim_size());
  EXPECT_EQ(1, proto.dim(0).size());
  EXPECT_EQ(2, proto.dim(1).size());
  EXPECT_EQ(3, proto.dim(2).size());
}
1037
TEST_F(ShapeInferenceTest, UnknownShapeToProto) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  // An unknown-rank shape serializes as unknown_rank with no dim entries.
  TensorShapeProto serialized;
  c.ShapeHandleToProto(c.UnknownShape(), &serialized);
  EXPECT_TRUE(serialized.unknown_rank());
  EXPECT_EQ(0, serialized.dim_size());
}
1050
TEST_F(ShapeInferenceTest, Scalar) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  // Scalar() prints as a rank-0 shape, and each call returns its own handle.
  const ShapeHandle a = c.Scalar();
  EXPECT_EQ("[]", c.DebugString(a));
  const ShapeHandle b = c.Scalar();
  EXPECT_EQ("[]", c.DebugString(b));
  EXPECT_FALSE(SameHandle(a, b));
}
1062
TEST_F(ShapeInferenceTest, Vector) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  // From constants; kUnknownDim prints as "?".
  EXPECT_EQ("[1]", c.DebugString(c.Vector(1)));
  EXPECT_EQ("[?]", c.DebugString(c.Vector(InferenceContext::kUnknownDim)));

  // From an existing dimension handle, which the vector reuses as-is.
  const DimensionHandle dim = c.UnknownDim();
  const ShapeHandle vec = c.Vector(dim);
  EXPECT_EQ("[?]", c.DebugString(vec));
  EXPECT_TRUE(SameHandle(dim, c.Dim(vec, 0)));
}
1078
TEST_F(ShapeInferenceTest, Matrix) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  // From constants.
  EXPECT_EQ("[1,2]", c.DebugString(c.Matrix(1, 2)));
  EXPECT_EQ("[0,?]",
            c.DebugString(c.Matrix(0, InferenceContext::kUnknownDim)));

  // From dimension handles, which the matrix reuses as-is.
  const DimensionHandle rows = c.UnknownDim();
  const DimensionHandle cols = c.UnknownDim();
  const ShapeHandle from_dims = c.Matrix(rows, cols);
  EXPECT_EQ("[?,?]", c.DebugString(from_dims));
  EXPECT_TRUE(SameHandle(rows, c.Dim(from_dims, 0)));
  EXPECT_TRUE(SameHandle(cols, c.Dim(from_dims, 1)));

  // Mixing a dimension handle with a constant.
  const ShapeHandle mixed = c.Matrix(rows, 100);
  EXPECT_EQ("[?,100]", c.DebugString(mixed));
  EXPECT_TRUE(SameHandle(rows, c.Dim(mixed, 0)));
}
1100
TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
  // Builds a one-input context whose input shape is unknown and whose input
  // tensor is `t` (nullptr = no tensor value available), then calls
  // MakeShapeFromShapeTensor(0). On success, returns the result's DebugString.
  // On failure, checks that the output was left unset and returns the error
  // message.
  auto create = [&](Tensor* t) {
    NodeDef def;
    InferenceContext c(kVersion, def, MakeOpDef(1, 0), {Unknown()}, {t}, {},
                       {});
    ShapeHandle out;
    Status s = c.MakeShapeFromShapeTensor(0, &out);
    if (s.ok()) {
      return c.DebugString(out);
    } else {
      EXPECT_FALSE(IsSet(out));
      return s.error_message();
    }
  };

  Tensor t;
  // No tensor value: the result is an unknown shape.
  EXPECT_EQ("?", create(nullptr));

  // Rank-1 int32/int64 tensors; -1 entries become unknown dims and the empty
  // vector gives a rank-0 shape.
  t = ::tensorflow::test::AsTensor<int32>({1, 2, 3});
  EXPECT_EQ("[1,2,3]", create(&t));

  t = ::tensorflow::test::AsTensor<int64_t>({3, 2, 1});
  EXPECT_EQ("[3,2,1]", create(&t));

  t = ::tensorflow::test::AsTensor<int64_t>({3, -1, 1});
  EXPECT_EQ("[3,?,1]", create(&t));

  t = ::tensorflow::test::AsTensor<int64_t>({});
  EXPECT_EQ("[]", create(&t));

  // A scalar -1 yields an unknown shape.
  t = ::tensorflow::test::AsScalar<int32>(-1);
  EXPECT_EQ("?", create(&t));

  // Only int32 and int64 tensors are accepted.
  t = ::tensorflow::test::AsTensor<float>({1, 2, 3});
  EXPECT_THAT(create(&t),
              HasSubstr("Input tensor must be int32 or int64, but was float"));

  // Any other scalar, and any tensor of rank > 1, is rejected.
  t = ::tensorflow::test::AsScalar<int32>(1);
  auto s_scalar = create(&t);
  EXPECT_THAT(s_scalar, HasSubstr("Input tensor must be rank 1, or if its rank "
                                  "0 it must have value -1"));

  t = ::tensorflow::test::AsTensor<int32>({1, 2}, TensorShape{2, 1});
  auto s_matrix = create(&t);
  EXPECT_THAT(s_matrix,
              HasSubstr("Input tensor must be rank 1, but was rank 2"));

  // A negative dim other than -1 is rejected (int64 tensor).
  t = ::tensorflow::test::AsTensor<int64_t>({3, -2, 1});
  EXPECT_THAT(create(&t),
              HasSubstr("Invalid value in tensor used for shape: -2"));

  // Same check for an int32 tensor.
  t = ::tensorflow::test::AsTensor<int32>({3, -2, 1});
  EXPECT_THAT(create(&t),
              HasSubstr("Invalid value in tensor used for shape: -2"));

  // The input's own shape must be rank 1. Check the status code as well as
  // the message, consistent with the other error checks in this file.
  {
    NodeDef def;
    InferenceContext c(kVersion, def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr},
                       {}, {});
    ShapeHandle out;
    EXPECT_THAT(c.MakeShapeFromShapeTensor(0, &out),
                StatusIs(error::INVALID_ARGUMENT,
                         HasSubstr("Shape must be rank 1 but is rank 2")));
  }
}
1169
TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  ShapeHandle shape;

  // A default-constructed PartialTensorShape has unknown rank.
  TF_ASSERT_OK(
      c.MakeShapeFromPartialTensorShape(PartialTensorShape(), &shape));
  EXPECT_EQ("?", c.DebugString(shape));

  // Known rank; a -1 entry becomes an unknown dimension.
  TF_ASSERT_OK(
      c.MakeShapeFromPartialTensorShape(PartialTensorShape({0}), &shape));
  EXPECT_EQ("[0]", c.DebugString(shape));

  const PartialTensorShape partial({0, -1, 1000});
  TF_ASSERT_OK(c.MakeShapeFromPartialTensorShape(partial, &shape));
  EXPECT_EQ("[0,?,1000]", c.DebugString(shape));
}
1188
TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  // Each fully defined TensorShape converts to a shape with the same dims.
  struct Case {
    TensorShape shape;
    const char* expected;
  };
  const Case cases[] = {
      {TensorShape(), "[]"},
      {TensorShape({0}), "[0]"},
      {TensorShape({0, 7, 1000}), "[0,7,1000]"},
  };
  for (const Case& test_case : cases) {
    ShapeHandle out;
    TF_ASSERT_OK(c.MakeShapeFromTensorShape(test_case.shape, &out));
    EXPECT_EQ(test_case.expected, c.DebugString(out));
  }
}
1202
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), no_inputs, {}, {}, {});
  TensorShapeProto proto;
  ShapeHandle out;

  // unknown_rank=true with no dims gives an unknown shape.
  proto.set_unknown_rank(true);
  TF_EXPECT_OK(c.MakeShapeFromShapeProto(proto, &out));
  EXPECT_EQ("?", c.DebugString(out));

  // unknown_rank=true together with a dim is rejected and leaves `out` unset.
  proto.add_dim()->set_size(0);
  EXPECT_THAT(
      c.MakeShapeFromShapeProto(proto, &out),
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("An unknown shape must not have any dimensions set")));
  EXPECT_FALSE(IsSet(out));

  // Clearing unknown_rank makes the dims meaningful; size -1 means unknown.
  proto.set_unknown_rank(false);
  TF_EXPECT_OK(c.MakeShapeFromShapeProto(proto, &out));
  EXPECT_EQ("[0]", c.DebugString(out));
  proto.add_dim()->set_size(-1);
  proto.add_dim()->set_size(1000);
  TF_EXPECT_OK(c.MakeShapeFromShapeProto(proto, &out));
  EXPECT_EQ("[0,?,1000]", c.DebugString(out));

  // A size below -1 is rejected and leaves `out` unset.
  proto.add_dim()->set_size(-2);
  EXPECT_THAT(
      c.MakeShapeFromShapeProto(proto, &out),
      StatusIs(
          error::INVALID_ARGUMENT,
          HasSubstr("Shape [0,?,1000,-2] has dimensions with values below -1 "
                    "(where -1 means unknown)")));
  EXPECT_FALSE(IsSet(out));
}
1241
TEST_F(ShapeInferenceTest, MakeDim) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  // Equal values still produce distinct handles.
  const DimensionHandle one_a = c.MakeDim(1);
  const DimensionHandle one_b = c.MakeDim(1);
  const DimensionHandle two = c.MakeDim(2);
  EXPECT_EQ("1", c.DebugString(one_a));
  EXPECT_EQ("1", c.DebugString(one_b));
  EXPECT_FALSE(SameHandle(one_a, one_b));
  EXPECT_EQ("2", c.DebugString(two));
}
1255
TEST_F(ShapeInferenceTest, UnknownDim) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  // Two calls give two distinct dimensions, both printed as "?".
  const DimensionHandle first = c.UnknownDim();
  const DimensionHandle second = c.UnknownDim();
  EXPECT_EQ("?", c.DebugString(first));
  EXPECT_EQ("?", c.DebugString(second));
  EXPECT_FALSE(SameHandle(first, second));
}
1267
TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  // Rank 3 gives three unknown dims; rank 0 gives a scalar shape.
  EXPECT_EQ("[?,?,?]", c.DebugString(c.UnknownShapeOfRank(3)));
  EXPECT_EQ("[]", c.DebugString(c.UnknownShapeOfRank(0)));
}
1279
TEST_F(ShapeInferenceTest, InputTensors) {
  const Tensor t1 = tensorflow::test::AsTensor<float>({10});
  const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
  NodeDef def;
  // Three inputs but only two tensors: the third input has no tensor.
  InferenceContext c(kVersion, def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
                     {&t1, &t2}, {}, {});

  // EXPECT_EQ (rather than EXPECT_TRUE(a == b)) prints both pointers when a
  // check fails.
  EXPECT_EQ(&t1, c.input_tensor(0));
  EXPECT_EQ(&t2, c.input_tensor(1));
  EXPECT_EQ(nullptr, c.input_tensor(2));
}
1291
TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
  Tensor t1 = tensorflow::test::AsScalar<int32>(20);
  Tensor t2 = tensorflow::test::AsScalar<int32>(-1);
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {&t1, &t2},
                     {}, {});

  // Input 0 holds 20, a valid size. Input 1 holds -1, which is rejected. The
  // context was given pointers to t1/t2, so the same checks can be re-run
  // after reassigning the tensors.
  DimensionHandle d;
  auto verify_scalar_inputs = [&c, &d]() {
    TF_EXPECT_OK(c.MakeDimForScalarInput(0, &d));
    EXPECT_EQ("20", c.DebugString(d));
    EXPECT_THAT(
        c.MakeDimForScalarInput(1, &d),
        StatusIs(error::INVALID_ARGUMENT,
                 HasSubstr("Dimension size, given by scalar input 1, must be "
                           "non-negative but is -1")));
  };

  verify_scalar_inputs();  // int32 values.

  t1 = tensorflow::test::AsScalar<int64_t>(20);
  t2 = tensorflow::test::AsScalar<int64_t>(-1);
  verify_scalar_inputs();  // The same values as int64.
}
1321
TEST_F(ShapeInferenceTest, MakeDimForScalarInputWithNegativeIndexing) {
  Tensor neg_two = tensorflow::test::AsScalar<int32>(-2);
  Tensor three = tensorflow::test::AsScalar<int32>(3);
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})},
                     {&neg_two, &three}, {}, {});
  DimensionHandle d;

  // Input 0 holds -2.
  // With a negative input rank, a negative value gives an unknown dimension.
  TF_EXPECT_OK(c.MakeDimForScalarInputWithNegativeIndexing(0, -1, &d));
  EXPECT_EQ("?", c.DebugString(d));
  // With rank 4, -2 resolves to 2.
  TF_EXPECT_OK(c.MakeDimForScalarInputWithNegativeIndexing(0, 4, &d));
  EXPECT_EQ("2", c.DebugString(d));
  // With rank 1, -2 falls outside [-1, 1).
  EXPECT_THAT(c.MakeDimForScalarInputWithNegativeIndexing(0, 1, &d),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Dimension size, given by scalar input -2 "
                                 "must be in range [-1, 1)")));

  // Input 1 holds 3: valid for rank 4, out of range for rank 2.
  TF_EXPECT_OK(c.MakeDimForScalarInputWithNegativeIndexing(1, 4, &d));
  EXPECT_EQ("3", c.DebugString(d));
  EXPECT_THAT(c.MakeDimForScalarInputWithNegativeIndexing(1, 2, &d),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Dimension size, given by scalar input 3 "
                                 "must be in range [-2, 2)")));
}
1352
TEST_F(ShapeInferenceTest, GetAttr) {
  OpRegistrationData op_reg_data;
  op_reg_data.op_def = MakeOpDef(0, 2);
  NodeDef def;
  // Use TF_ASSERT_OK rather than CHECK(...ok()). It reports the builder's
  // error status and fails only this test; a CHECK failure would abort the
  // whole test binary.
  TF_ASSERT_OK(NodeDefBuilder("dummy", &op_reg_data.op_def)
                   .Attr("foo", "bar")
                   .Finalize(&def));

  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, op_reg_data.op_def, empty, {}, {}, {});
  // The attr set on the NodeDef can be read back through the context.
  string value;
  TF_EXPECT_OK(c.GetAttr("foo", &value));
  EXPECT_EQ("bar", value);
}
1368
TEST_F(ShapeInferenceTest, Divide) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {},
                     {}, {});

  auto s = c.input(0);
  auto d_6 = c.Dim(s, 0);
  auto d_unknown = c.Dim(s, 1);
  auto d_1 = c.Dim(s, 2);
  auto d_2 = c.Dim(s, 3);
  auto d_0 = c.Dim(s, 4);
  bool evenly_divisible = true;

  // Dividing unknown by non-1 gives new unknown.
  DimensionHandle out;
  TF_EXPECT_OK(c.Divide(d_unknown, 2, evenly_divisible, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, d_unknown));

  // Dividing anything by 1 (constant or dimension) returns the input handle.
  TF_EXPECT_OK(c.Divide(d_unknown, 1, evenly_divisible, &out));
  EXPECT_TRUE(SameHandle(out, d_unknown));
  TF_EXPECT_OK(c.Divide(d_6, 1, evenly_divisible, &out));
  EXPECT_TRUE(SameHandle(out, d_6));
  TF_EXPECT_OK(c.Divide(d_unknown, d_1, evenly_divisible, &out));
  EXPECT_TRUE(SameHandle(out, d_unknown));
  TF_EXPECT_OK(c.Divide(d_6, d_1, evenly_divisible, &out));
  EXPECT_TRUE(SameHandle(out, d_6));

  // Known / known, with the divisor as a constant or a dimension.
  TF_EXPECT_OK(c.Divide(d_6, 2, evenly_divisible, &out));
  EXPECT_EQ("3", c.DebugString(out));
  TF_EXPECT_OK(c.Divide(d_6, d_2, evenly_divisible, &out));
  EXPECT_EQ("3", c.DebugString(out));

  EXPECT_THAT(
      c.Divide(d_6, 5, evenly_divisible, &out),
      StatusIs(
          error::INVALID_ARGUMENT,
          HasSubstr("Dimension size must be evenly divisible by 5 but is 6")));

  // Non-positive divisors are rejected, as a constant or as a dimension.
  EXPECT_THAT(c.Divide(d_6, 0, evenly_divisible, &out),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Divisor must be positive but is 0")));
  EXPECT_THAT(c.Divide(d_6, d_0, evenly_divisible, &out),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Divisor must be positive but is 0")));

  EXPECT_THAT(c.Divide(d_6, -1, evenly_divisible, &out),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Divisor must be positive but is -1")));

  // Repeat the error cases above with evenly_divisible=false. Divisibility is
  // no longer required (6 / 5 gives 1), but a non-positive divisor is still
  // an error, both as a constant and as a dimension.
  evenly_divisible = false;
  TF_EXPECT_OK(c.Divide(d_6, 5, evenly_divisible, &out));
  EXPECT_EQ("1", c.DebugString(out));

  EXPECT_THAT(c.Divide(d_6, 0, evenly_divisible, &out),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Divisor must be positive but is 0")));
  EXPECT_THAT(c.Divide(d_6, d_0, evenly_divisible, &out),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Divisor must be positive but is 0")));

  EXPECT_THAT(c.Divide(d_6, -1, evenly_divisible, &out),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Divisor must be positive but is -1")));
}
1433
TEST_F(ShapeInferenceTest, Add) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {},
                     {});

  auto s = c.input(0);
  auto d_6 = c.Dim(s, 0);
  auto d_unknown = c.Dim(s, 1);
  auto d_0 = c.Dim(s, 2);

  // Adding non-zero to unknown gives new unknown.
  DimensionHandle out;
  TF_EXPECT_OK(c.Add(d_unknown, 1, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, d_unknown));

  // Adding 0 to anything gives input.
  TF_EXPECT_OK(c.Add(d_unknown, 0, &out));
  EXPECT_TRUE(SameHandle(out, d_unknown));
  TF_EXPECT_OK(c.Add(d_6, 0, &out));
  EXPECT_TRUE(SameHandle(out, d_6));

  // Adding dimension with value 0 to anything gives input.
  TF_EXPECT_OK(c.Add(d_unknown, c.MakeDim(0ll), &out));
  EXPECT_TRUE(SameHandle(out, d_unknown));
  TF_EXPECT_OK(c.Add(d_6, c.MakeDim(0ll), &out));
  EXPECT_TRUE(SameHandle(out, d_6));

  // Test addition, including a sum that lands exactly on int64 max.
  TF_EXPECT_OK(c.Add(d_6, 2, &out));
  EXPECT_EQ("8", c.DebugString(out));
  TF_EXPECT_OK(c.Add(d_6, std::numeric_limits<int64_t>::max() - 6, &out));
  EXPECT_EQ(std::numeric_limits<int64_t>::max(), c.Value(out));

  // Test addition using dimension as second value. Use TF_EXPECT_OK here too
  // (not EXPECT_TRUE(...ok())) so a failure prints the returned status.
  TF_EXPECT_OK(c.Add(d_6, c.MakeDim(2), &out));
  EXPECT_EQ("8", c.DebugString(out));
  TF_EXPECT_OK(
      c.Add(d_6, c.MakeDim(std::numeric_limits<int64_t>::max() - 6), &out));
  EXPECT_EQ(std::numeric_limits<int64_t>::max(), c.Value(out));
  TF_EXPECT_OK(c.Add(d_6, c.UnknownDim(), &out));
  EXPECT_EQ("?", c.DebugString(out));
  TF_EXPECT_OK(c.Add(d_0, d_6, &out));
  EXPECT_TRUE(SameHandle(out, d_6));

  // One past int64 max overflows and is reported as an error.
  EXPECT_THAT(c.Add(d_6, std::numeric_limits<int64_t>::max() - 5, &out),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Dimension size overflow from adding 6 and "
                                 "9223372036854775802")));
}
1485
TEST_F(ShapeInferenceTest, Subtract) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0, 5})}, {}, {},
                     {});

  const ShapeHandle in = c.input(0);
  const DimensionHandle six = c.Dim(in, 0);
  const DimensionHandle unknown = c.Dim(in, 1);
  const DimensionHandle zero = c.Dim(in, 2);
  const DimensionHandle five = c.Dim(in, 3);
  DimensionHandle result;

  // Unknown minus a non-zero constant gives a new unknown dimension.
  TF_EXPECT_OK(c.Subtract(unknown, 1, &result));
  EXPECT_EQ("?", c.DebugString(result));
  EXPECT_FALSE(SameHandle(result, unknown));

  // Subtracting zero, first as a constant and then as a dimension, returns
  // the minuend's own handle whether or not its value is known.
  for (const DimensionHandle& minuend : {unknown, six}) {
    TF_EXPECT_OK(c.Subtract(minuend, 0ll, &result));
    EXPECT_TRUE(SameHandle(result, minuend));
  }
  for (const DimensionHandle& minuend : {unknown, six}) {
    TF_EXPECT_OK(c.Subtract(minuend, c.MakeDim(0ll), &result));
    EXPECT_TRUE(SameHandle(result, minuend));
  }

  // Known minus a constant.
  TF_EXPECT_OK(c.Subtract(six, 2, &result));
  EXPECT_EQ("4", c.DebugString(result));
  TF_EXPECT_OK(c.Subtract(six, 6, &result));
  EXPECT_EQ("0", c.DebugString(result));

  // Known minus a dimension.
  TF_EXPECT_OK(c.Subtract(six, c.MakeDim(2), &result));
  EXPECT_EQ("4", c.DebugString(result));
  TF_EXPECT_OK(c.Subtract(six, five, &result));
  EXPECT_EQ("1", c.DebugString(result));
  TF_EXPECT_OK(c.Subtract(six, c.UnknownDim(), &result));
  EXPECT_EQ("?", c.DebugString(result));
  TF_EXPECT_OK(c.Subtract(six, zero, &result));
  EXPECT_TRUE(SameHandle(result, six));

  // A negative result is an error.
  EXPECT_THAT(
      c.Subtract(five, six, &result),
      StatusIs(
          error::INVALID_ARGUMENT,
          HasSubstr("Negative dimension size caused by subtracting 6 from 5")));
}
1537
TEST_F(ShapeInferenceTest, Multiply) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0, 1})}, {}, {},
                     {});

  auto s = c.input(0);
  auto d_6 = c.Dim(s, 0);
  auto d_unknown = c.Dim(s, 1);
  auto d_0 = c.Dim(s, 2);
  auto d_1 = c.Dim(s, 3);

  // Multiplying unknown by a value other than 0 or 1 gives a new unknown. The
  // result must be a different handle, not the input's.
  DimensionHandle out;
  TF_EXPECT_OK(c.Multiply(d_unknown, 2, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, d_unknown));

  // Multiplying 0 to anything gives 0.
  TF_EXPECT_OK(c.Multiply(d_unknown, 0, &out));
  EXPECT_EQ("0", c.DebugString(out));
  TF_EXPECT_OK(c.Multiply(d_unknown, d_0, &out));
  EXPECT_EQ("0", c.DebugString(out));
  TF_EXPECT_OK(c.Multiply(d_0, d_unknown, &out));
  EXPECT_EQ("0", c.DebugString(out));

  // Multiplying 1 to anything gives the original handle.
  // (unknown -> unknown)
  TF_EXPECT_OK(c.Multiply(d_unknown, 1, &out));
  EXPECT_TRUE(SameHandle(d_unknown, out));
  TF_EXPECT_OK(c.Multiply(d_unknown, d_1, &out));
  EXPECT_TRUE(SameHandle(d_unknown, out));
  TF_EXPECT_OK(c.Multiply(d_1, d_unknown, &out));
  EXPECT_TRUE(SameHandle(d_unknown, out));
  // (known -> known)
  TF_EXPECT_OK(c.Multiply(d_6, 1, &out));
  EXPECT_TRUE(SameHandle(d_6, out));
  TF_EXPECT_OK(c.Multiply(d_6, d_1, &out));
  EXPECT_TRUE(SameHandle(d_6, out));
  TF_EXPECT_OK(c.Multiply(d_1, d_6, &out));
  EXPECT_TRUE(SameHandle(d_6, out));

  // Test multiplication.
  TF_EXPECT_OK(c.Multiply(d_6, 2, &out));
  EXPECT_EQ("12", c.DebugString(out));
  TF_EXPECT_OK(c.Multiply(d_6, 6, &out));
  EXPECT_EQ("36", c.DebugString(out));

  // Test multiplication using dimension as second value.
  TF_EXPECT_OK(c.Multiply(d_6, c.MakeDim(2), &out));
  EXPECT_EQ("12", c.DebugString(out));
  TF_EXPECT_OK(c.Multiply(d_6, c.UnknownDim(), &out));
  EXPECT_EQ("?", c.DebugString(out));

  // 6 * (int64 max / 2) does not fit in int64 and is reported as an error.
  EXPECT_THAT(
      c.Multiply(d_6, std::numeric_limits<int64_t>::max() / 2, &out),
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("Negative dimension size caused by overflow")));
}
1595
TEST_F(ShapeInferenceTest, FullyDefined) {
  NodeDef def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), no_inputs, {}, {}, {});

  // A shape counts as fully defined only when its rank and every dimension
  // value are known.
  const bool unknown_rank = c.FullyDefined(c.UnknownShape());
  const bool partly_known =
      c.FullyDefined(c.Matrix(c.MakeDim(1), c.UnknownDim()));
  EXPECT_FALSE(unknown_rank);
  EXPECT_FALSE(partly_known);

  const bool all_known = c.FullyDefined(c.Matrix(c.MakeDim(1), c.MakeDim(2)));
  const bool scalar = c.FullyDefined(c.Scalar());
  EXPECT_TRUE(all_known);
  EXPECT_TRUE(scalar);
}
1609
TEST_F(ShapeInferenceTest, Min) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, -1, 0})}, {}, {},
                     {});

  const ShapeHandle in = c.input(0);
  const DimensionHandle one = c.Dim(in, 0);
  const DimensionHandle two = c.Dim(in, 1);
  const DimensionHandle unknown = c.Dim(in, 2);
  const DimensionHandle zero = c.Dim(in, 3);
  DimensionHandle result;

  // Zero wins over unknown in either argument position. A zero dimension
  // handle is returned as-is.
  TF_EXPECT_OK(c.Min(zero, unknown, &result));
  EXPECT_TRUE(SameHandle(zero, result));
  TF_EXPECT_OK(c.Min(unknown, zero, &result));
  EXPECT_TRUE(SameHandle(zero, result));
  TF_EXPECT_OK(c.Min(c.MakeDim(0ll), unknown, &result));
  EXPECT_EQ("0", c.DebugString(result));
  TF_EXPECT_OK(c.Min(unknown, 0ll, &result));
  EXPECT_EQ("0", c.DebugString(result));

  // With no zero involved, an unknown operand makes the result unknown.
  TF_EXPECT_OK(c.Min(unknown, unknown, &result));
  EXPECT_EQ("?", c.DebugString(result));
  TF_EXPECT_OK(c.Min(unknown, 1, &result));
  EXPECT_EQ("?", c.DebugString(result));
  TF_EXPECT_OK(c.Min(one, unknown, &result));
  EXPECT_EQ("?", c.DebugString(result));

  // Known dimension vs. constant: the dimension's own handle is kept when it
  // is not larger than the constant.
  TF_EXPECT_OK(c.Min(one, 1, &result));
  EXPECT_TRUE(SameHandle(one, result));
  TF_EXPECT_OK(c.Min(one, 3, &result));
  EXPECT_TRUE(SameHandle(one, result));
  TF_EXPECT_OK(c.Min(two, 1, &result));
  EXPECT_EQ("1", c.DebugString(result));

  // Two known dimensions: the handle of the smaller one is returned.
  struct MinCase {
    DimensionHandle lhs;
    DimensionHandle rhs;
    DimensionHandle expected;
  };
  for (const MinCase& mc : {MinCase{one, one, one}, MinCase{one, two, one},
                            MinCase{two, one, one}, MinCase{two, two, two}}) {
    TF_EXPECT_OK(c.Min(mc.lhs, mc.rhs, &result));
    EXPECT_TRUE(SameHandle(mc.expected, result));
  }
}
1658
TEST_F(ShapeInferenceTest, Max) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, -1})}, {}, {},
                     {});

  const ShapeHandle in = c.input(0);
  const DimensionHandle one = c.Dim(in, 0);
  const DimensionHandle two = c.Dim(in, 1);
  const DimensionHandle unknown = c.Dim(in, 2);
  DimensionHandle result;

  // Any unknown operand makes the result unknown.
  TF_EXPECT_OK(c.Max(unknown, unknown, &result));
  EXPECT_EQ("?", c.DebugString(result));
  TF_EXPECT_OK(c.Max(unknown, 1, &result));
  EXPECT_EQ("?", c.DebugString(result));
  TF_EXPECT_OK(c.Max(one, unknown, &result));
  EXPECT_EQ("?", c.DebugString(result));

  // Known dimension vs. constant: the dimension's own handle is kept when it
  // is not smaller than the constant.
  TF_EXPECT_OK(c.Max(one, 1, &result));
  EXPECT_TRUE(SameHandle(one, result));
  TF_EXPECT_OK(c.Max(two, 1, &result));
  EXPECT_TRUE(SameHandle(two, result));
  TF_EXPECT_OK(c.Max(two, 3, &result));
  EXPECT_EQ("3", c.DebugString(result));

  // Two known dimensions: the handle of the larger one is returned.
  struct MaxCase {
    DimensionHandle lhs;
    DimensionHandle rhs;
    DimensionHandle expected;
  };
  for (const MaxCase& mc : {MaxCase{one, one, one}, MaxCase{one, two, two},
                            MaxCase{two, one, two}, MaxCase{two, two, two}}) {
    TF_EXPECT_OK(c.Max(mc.lhs, mc.rhs, &result));
    EXPECT_TRUE(SameHandle(mc.expected, result));
  }
}
1696
1697 class ShapeInferenceHandlesTest : public ShapeInferenceTest,
1698 public ::testing::WithParamInterface<bool> {};
1699
1700 INSTANTIATE_TEST_SUITE_P(All, ShapeInferenceHandlesTest,
1701 ::testing::Bool() /* input_not_output */,
1702 ::testing::PrintToStringParamName());
1703
// Exercises Merge{Input,Output}HandleShapesAndTypes (chosen by the test
// parameter) on handle slot 0. The first merge adopts the given data
// wholesale; later merges report true only when they refine the stored data
// (a new shape for an unknown one, a dtype for a DT_INVALID slot), and report
// false, leaving the stored data untouched, for size mismatches, shape
// mismatches, dtype conflicts, or when there is nothing new.
TEST_P(ShapeInferenceHandlesTest, MergeHandles) {
  const bool input_not_output = GetParam();
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
                     {});
  auto make_shape = [&c](std::initializer_list<int64_t> dim_sizes) {
    ShapeHandle s;
    TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
    return s;
  };
  auto get_shapes_and_types_from_context = [&](int idx) {
    if (input_not_output) {
      return c.input_handle_shapes_and_types(idx);
    } else {
      return c.output_handle_shapes_and_types(idx);
    }
  };
  auto merge_shapes_and_types_to_context =
      [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) {
        if (input_not_output) {
          return c.MergeInputHandleShapesAndTypes(idx, shapes_and_types);
        } else {
          return c.MergeOutputHandleShapesAndTypes(idx, shapes_and_types);
        }
      };
  // Succeeds iff handle slot 0 holds exactly `expected`: same entry count,
  // the identical shape handles (handle identity, not just equal shapes) and
  // the same dtypes. On failure the message names the first offending index.
  auto slot0_matches = [&](const std::vector<ShapeAndType>& expected)
      -> ::testing::AssertionResult {
    const auto* stored = get_shapes_and_types_from_context(0);
    if (stored == nullptr) {
      return ::testing::AssertionFailure() << "no handle data stored";
    }
    if (stored->size() != expected.size()) {
      return ::testing::AssertionFailure()
             << "stored " << stored->size() << " entries, expected "
             << expected.size();
    }
    for (size_t i = 0; i < expected.size(); ++i) {
      if (!SameHandle(expected[i].shape, (*stored)[i].shape)) {
        return ::testing::AssertionFailure()
               << "shape handle differs at index " << i;
      }
      if (expected[i].dtype != (*stored)[i].dtype) {
        return ::testing::AssertionFailure()
               << "dtype differs at index " << i << ": expected "
               << expected[i].dtype << ", got " << (*stored)[i].dtype;
      }
    }
    return ::testing::AssertionSuccess();
  };

  EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr);
  EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr);

  // First merge will take the input completely.
  std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT},
                              {c.UnknownShape(), DT_INVALID},
                              {make_shape({4, 3, 2, 1}), DT_INT32}};
  ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
  ASSERT_TRUE(slot0_matches(t));

  // Merge that fails because wrong number of values passed.
  // Fails, and no changes made.
  ASSERT_FALSE(merge_shapes_and_types_to_context(
      0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}}));
  ASSERT_TRUE(slot0_matches(t));

  // Only difference is in a mismatched shape. That is ignored,
  // and there are no other changes, so nothing is done.
  //
  // TODO(cwhipkey): in mismatch cases, change Merge*HandleShapesAndTypes to
  // return an error (separate error from 'refined' output)?
  auto t2 = t;
  t2[2].shape = make_shape({4, 3, 4, 1});
  ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
  ASSERT_TRUE(slot0_matches(t));

  // Only difference is in a mismatched dtype, but that cannot be
  // updated unless original dtype is DT_INVALID.
  t2 = t;
  t2[2].dtype = DT_FLOAT;
  ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
  ASSERT_TRUE(slot0_matches(t));

  // Difference is mergeable (new shape).
  t[1].shape = make_shape({1, 10});
  ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
  ASSERT_TRUE(slot0_matches(t));

  // Difference is mergeable (new type).
  t[1].dtype = DT_DOUBLE;
  ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
  ASSERT_TRUE(slot0_matches(t));

  // No difference: nothing is refined and the stored data stays as-is.
  ASSERT_FALSE(merge_shapes_and_types_to_context(0, t));
  EXPECT_TRUE(slot0_matches(t));
}
1807
// Exercises Relax{Input,Output}HandleShapesAndMergeTypes (chosen by the test
// parameter) on handle slot 0. The first relax adopts the given data
// wholesale. Afterwards shapes are generalized rather than refined: a
// mismatched dimension becomes unknown, and relaxing the stored unknown shape
// against a known one yields a fresh unknown shape. A stored dtype can only
// change while it is still DT_INVALID.
TEST_P(ShapeInferenceHandlesTest, RelaxHandles) {
  const bool input_not_output = GetParam();
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
                     {});
  auto make_shape = [&c](std::initializer_list<int64_t> dim_sizes) {
    ShapeHandle s;
    TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
    return s;
  };
  auto get_shapes_and_types_from_context = [&](int idx) {
    if (input_not_output) {
      return c.input_handle_shapes_and_types(idx);
    } else {
      return c.output_handle_shapes_and_types(idx);
    }
  };
  auto relax_shapes_and_types_to_context =
      [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) {
        if (input_not_output) {
          return c.RelaxInputHandleShapesAndMergeTypes(idx, shapes_and_types);
        } else {
          return c.RelaxOutputHandleShapesAndMergeTypes(idx, shapes_and_types);
        }
      };

  EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr);
  EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr);

  // First relax will take the input completely.
  std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT},
                              {c.UnknownShape(), DT_INVALID},
                              {make_shape({4, 3, 2, 1}), DT_INT32}};
  ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
  ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr);
  std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0);
  ASSERT_EQ(t.size(), v.size());
  for (size_t i = 0; i < v.size(); ++i) {
    EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
    EXPECT_EQ(t[i].dtype, v[i].dtype) << i;
  }

  // Relax that fails because wrong number of values passed.
  // Fails, and no changes made.
  ASSERT_FALSE(relax_shapes_and_types_to_context(
      0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}}));
  v = *get_shapes_and_types_from_context(0);
  ASSERT_EQ(t.size(), v.size());
  for (size_t i = 0; i < v.size(); ++i) {
    EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
    EXPECT_EQ(t[i].dtype, v[i].dtype) << i;
  }

  // Only difference is in a mismatched shape. This should replace
  // the mismatched dimension with an UnknownDim.
  auto t2 = t;
  t2[2].shape = make_shape({4, 3, 4, 1});
  ASSERT_TRUE(relax_shapes_and_types_to_context(0, t2));
  v = *get_shapes_and_types_from_context(0);
  // Guard the v[2] access and the t[i]/v[i] loop below.
  ASSERT_EQ(t.size(), v.size());
  EXPECT_EQ("[4,3,?,1]", c.DebugString(v[2].shape));
  for (size_t i = 0; i < v.size(); ++i) {
    EXPECT_EQ(t[i].dtype, v[i].dtype) << i;
  }

  // Only difference is in a mismatched dtype, but that cannot be
  // updated unless original dtype is DT_INVALID.
  t2 = t;
  t2[2].dtype = DT_FLOAT;
  ASSERT_FALSE(relax_shapes_and_types_to_context(0, t2));
  v = *get_shapes_and_types_from_context(0);
  ASSERT_EQ(t.size(), v.size());
  for (size_t i = 0; i < v.size(); ++i) {
    EXPECT_EQ(t[i].dtype, v[i].dtype) << i;
  }

  // Difference is a new shape, which will result in a new UnknownShape.
  t[1].shape = make_shape({1, 10});
  ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
  v = *get_shapes_and_types_from_context(0);
  ASSERT_EQ(t.size(), v.size());
  EXPECT_FALSE(SameHandle(t[1].shape, v[1].shape));
  EXPECT_EQ("?", c.DebugString(v[1].shape));
  for (size_t i = 0; i < v.size(); ++i) {
    EXPECT_EQ(t[i].dtype, v[i].dtype) << i;
  }

  // Difference is relaxable (new type).
  t[1].dtype = DT_DOUBLE;
  ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
  v = *get_shapes_and_types_from_context(0);
  // Guard the v[1] access below.
  ASSERT_EQ(t.size(), v.size());
  EXPECT_EQ(t[1].dtype, v[1].dtype);
}
1900
1901 } // namespace
1902 } // namespace shape_inference
1903 } // namespace tensorflow
1904