xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/shape_inference_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/framework/shape_inference.h"
16 
17 #include "absl/strings/str_cat.h"
18 #include "tensorflow/core/framework/fake_input.h"
19 #include "tensorflow/core/framework/node_def_builder.h"
20 #include "tensorflow/core/framework/op_def_builder.h"
21 #include "tensorflow/core/framework/tensor_shape.pb.h"
22 #include "tensorflow/core/framework/tensor_testutil.h"
23 #include "tensorflow/core/framework/types.pb.h"
24 #include "tensorflow/core/lib/core/status_test_util.h"
25 #include "tensorflow/core/platform/status_matchers.h"
26 #include "tensorflow/core/platform/test.h"
27 #include "tensorflow/core/protobuf/error_codes.pb.h"
28 
29 namespace tensorflow {
30 namespace shape_inference {
31 namespace {
32 
33 using ::tensorflow::testing::StatusIs;
34 using ::testing::_;
35 using ::testing::AllOf;
36 using ::testing::HasSubstr;
37 
MakeOpDefWithLists()38 OpDef MakeOpDefWithLists() {
39   OpRegistrationData op_reg_data;
40   TF_EXPECT_OK(OpDefBuilder("dummy")
41                    .Input("input: N * float")
42                    .Output("output: N * float")
43                    .Attr("N:int >= 1")
44                    .Finalize(&op_reg_data));
45   return op_reg_data.op_def;
46 }
47 
S(std::initializer_list<int64_t> dims)48 PartialTensorShape S(std::initializer_list<int64_t> dims) {
49   return PartialTensorShape(dims);
50 }
51 
Unknown()52 PartialTensorShape Unknown() { return PartialTensorShape(); }
53 
54 }  // namespace
55 
56 class ShapeInferenceTest : public ::testing::Test {
57  protected:
58   // These give access to private functions of DimensionHandle and ShapeHandle.
SameHandle(DimensionHandle a,DimensionHandle b)59   bool SameHandle(DimensionHandle a, DimensionHandle b) {
60     return a.SameHandle(b);
61   }
SameHandle(ShapeHandle a,ShapeHandle b)62   bool SameHandle(ShapeHandle a, ShapeHandle b) { return a.SameHandle(b); }
IsSet(DimensionHandle d)63   bool IsSet(DimensionHandle d) { return d.IsSet(); }
IsSet(ShapeHandle s)64   bool IsSet(ShapeHandle s) { return s.IsSet(); }
Relax(InferenceContext * c,DimensionHandle d0,DimensionHandle d1,DimensionHandle * out)65   void Relax(InferenceContext* c, DimensionHandle d0, DimensionHandle d1,
66              DimensionHandle* out) {
67     c->Relax(d0, d1, out);
68   }
Relax(InferenceContext * c,ShapeHandle s0,ShapeHandle s1,ShapeHandle * out)69   void Relax(InferenceContext* c, ShapeHandle s0, ShapeHandle s1,
70              ShapeHandle* out) {
71     c->Relax(s0, s1, out);
72   }
73   void TestMergeHandles(bool input_not_output);
74   void TestRelaxHandles(bool input_not_output);
75 
76   static constexpr int kVersion = 0;  // used for graph-def version.
77 };
78 
79 namespace {
80 
// Exercises name-based access to list-valued arguments:
// - input(name) returns every shape in the named input list;
// - set_output(name) assigns a whole output list;
// - unknown names are rejected with INVALID_ARGUMENT.
TEST_F(ShapeInferenceTest, InputOutputByName) {
  // Setup test to contain an input tensor list of size 3.
  OpDef op_def = MakeOpDefWithLists();
  NodeDef def;
  // The builder status used to be stored in an unused local. Check it, so that
  // a malformed NodeDef fails here rather than as confusing errors below.
  TF_ASSERT_OK(NodeDefBuilder("dummy", &op_def)
                   .Attr("N", 3)
                   .Input(FakeInput(DT_FLOAT))
                   .Finalize(&def));
  InferenceContext c(kVersion, def, op_def, {S({1, 5}), S({2, 5}), S({1, 3})},
                     {}, {}, {});
  TF_ASSERT_OK(c.construction_status());

  EXPECT_EQ("5", c.DebugString(c.NumElements(c.input(0))));
  EXPECT_EQ("10", c.DebugString(c.NumElements(c.input(1))));
  EXPECT_EQ("3", c.DebugString(c.NumElements(c.input(2))));

  // Test getters.
  std::vector<ShapeHandle> shapes;
  EXPECT_THAT(
      c.input("nonexistent", &shapes),
      StatusIs(error::INVALID_ARGUMENT, HasSubstr("Unknown input name")));
  TF_EXPECT_OK(c.input("input", &shapes));
  // Guard the indexing below against a short list.
  ASSERT_EQ(3, shapes.size());
  EXPECT_EQ("[1,5]", c.DebugString(shapes[0]));
  EXPECT_EQ("[2,5]", c.DebugString(shapes[1]));
  EXPECT_EQ("[1,3]", c.DebugString(shapes[2]));

  // Test setters.
  EXPECT_THAT(
      c.set_output("nonexistent", shapes),
      StatusIs(error::INVALID_ARGUMENT, HasSubstr("Unknown output name")));
  TF_EXPECT_OK(c.set_output("output", shapes));
  EXPECT_EQ("5", c.DebugString(c.NumElements(c.output(0))));
  EXPECT_EQ("10", c.DebugString(c.NumElements(c.output(1))));
  EXPECT_EQ("3", c.DebugString(c.NumElements(c.output(2))));
}
114 
MakeOpDef(int num_inputs,int num_outputs)115 static OpDef MakeOpDef(int num_inputs, int num_outputs) {
116   OpRegistrationData op_reg_data;
117   OpDefBuilder b("dummy");
118   for (int i = 0; i < num_inputs; ++i) {
119     b.Input(absl::StrCat("i", i, ": float"));
120   }
121   for (int i = 0; i < num_outputs; ++i) {
122     b.Output(absl::StrCat("o", i, ": float"));
123   }
124   CHECK(b.Attr("foo:string").Finalize(&op_reg_data).ok());
125   return op_reg_data.op_def;
126 }
127 
// Value() hands back a known non-negative size or the kUnknownDim sentinel
// unchanged. Any other negative value trips a debug check.
TEST_F(ShapeInferenceTest, DimensionOrConstant) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 1), {Unknown()}, {},
                       {}, {});

  EXPECT_EQ(1, ctx.Value(1));
  EXPECT_EQ(InferenceContext::kUnknownDim,
            ctx.Value(InferenceContext::kUnknownDim));

#if !defined(NDEBUG)
  // The check is a DCHECK, so the death is only observable in builds that
  // keep debug checks enabled.
  EXPECT_DEATH(ctx.Value(-7), "Dimension must be non\\-negative or equal to");
#endif
}
140 
// Run() returns the shape fn's status. On failure it adds the node name and
// op type to the error.
TEST_F(ShapeInferenceTest, Run) {
  NodeDef def;
  def.set_name("foo");
  def.set_op("foo_op");
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1})}, {}, {}, {});
  TF_ASSERT_OK(c.construction_status());

  // Makes a shape fn that requires input 0 to have rank <= max_rank and then
  // copies input 0 to both outputs.
  const auto rank_limited_passthrough = [](int max_rank) {
    return [max_rank](InferenceContext* ic) {
      ShapeHandle unused;
      TF_RETURN_IF_ERROR(ic->WithRankAtMost(ic->input(0), max_rank, &unused));
      ic->set_output(0, ic->input(0));
      ic->set_output(1, ic->input(0));
      return OkStatus();
    };
  };

  // Input 0 has rank 1, which satisfies a limit of 6.
  TF_ASSERT_OK(c.Run(rank_limited_passthrough(6)));

  // A limit of 0 fails. Extra error context is attached when Run fails.
  EXPECT_THAT(
      c.Run(rank_limited_passthrough(0)),
      StatusIs(error::INVALID_ARGUMENT,
               AllOf(HasSubstr("Shape must be at most rank 0 but is rank 1"),
                     HasSubstr("node foo"), HasSubstr("foo_op"))));
}
175 
176 // Tests different context data added when Run returns error.
TEST_F(ShapeInferenceTest,AttachContext)177 TEST_F(ShapeInferenceTest, AttachContext) {
178   NodeDef def;
179   def.set_name("foo");
180   def.set_op("foo_op");
181   // Error when no constant tensors were requested.
182   {
183     InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, 2, 3})}, {}, {},
184                        {});
185     TF_ASSERT_OK(c.construction_status());
186     auto fn = [](InferenceContext* c) {
187       ShapeHandle h;
188       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
189       c->set_output(0, c->input(0));
190       return OkStatus();
191     };
192     EXPECT_THAT(
193         c.Run(fn),
194         StatusIs(error::INVALID_ARGUMENT,
195                  AllOf(HasSubstr("Shape must be at most rank 0 but is rank 3"),
196                        HasSubstr("node foo"), HasSubstr("foo_op"),
197                        HasSubstr("input shapes: [1,2,3]"))));
198   }
199 
200   // Error when a constant tensor value was requested.
201   {
202     Tensor input_t =
203         ::tensorflow::test::AsTensor<float>({1.1, 2.2, 3.3, 4.4, 5.5});
204     InferenceContext c(kVersion, def, MakeOpDef(2, 2),
205                        {S({1, 2, 3}), S({4, 5})}, {nullptr, &input_t}, {}, {});
206     TF_ASSERT_OK(c.construction_status());
207     auto fn = [](InferenceContext* c) {
208       c->input_tensor(0);  // It's null - will not appear in the error message.
209       c->input_tensor(1);  // This will appear in the error message.
210       ShapeHandle h;
211       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
212       c->set_output(0, c->input(0));
213       return OkStatus();
214     };
215     EXPECT_THAT(
216         c.Run(fn),
217         StatusIs(error::INVALID_ARGUMENT,
218                  AllOf(HasSubstr("Shape must be at most rank 0 but is rank 3"),
219                        HasSubstr("node foo"), HasSubstr("foo_op"),
220                        HasSubstr("input shapes: [1,2,3], [4,5] and with "
221                                  "computed input tensors: "
222                                  "input[1] = <1.1 2.2 3.3 4.4 5.5>."))));
223   }
224 
225   // Error when a constant tensor value as shape was requested, but no partial
226   // shapes provided.
227   {
228     Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
229     InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({3}), S({4})},
230                        {nullptr, &input_t}, {}, {});
231     TF_ASSERT_OK(c.construction_status());
232     auto fn = [](InferenceContext* c) {
233       ShapeHandle s;
234       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
235       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
236       ShapeHandle h;
237       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
238       c->set_output(0, c->input(0));
239       return OkStatus();
240     };
241     EXPECT_THAT(
242         c.Run(fn),
243         StatusIs(error::INVALID_ARGUMENT,
244                  AllOf(HasSubstr("Shape must be at most rank 0 but is rank 1"),
245                        HasSubstr("node foo"), HasSubstr("foo_op"),
246                        HasSubstr("with input shapes: [3], [4] and with "
247                                  "computed input tensors: input[1] "
248                                  "= <1 2 3 4 5>."))));
249   }
250 
251   // Error when a constant tensor value as shape was requested, and a partial
252   // shape was provided.
253   {
254     Tensor input_t = ::tensorflow::test::AsTensor<int32>({1, 2, 3, 4, 5});
255     InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({3}), S({4})},
256                        {nullptr, &input_t}, {S({10, -1, 5}), Unknown()}, {});
257     TF_ASSERT_OK(c.construction_status());
258     auto fn = [](InferenceContext* c) {
259       ShapeHandle s;
260       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(0, &s));
261       TF_RETURN_IF_ERROR(c->MakeShapeFromShapeTensor(1, &s));
262       ShapeHandle h;
263       TF_RETURN_IF_ERROR(c->WithRankAtMost(c->input(0), 0, &h));
264       c->set_output(0, c->input(0));
265       return OkStatus();
266     };
267     EXPECT_THAT(
268         c.Run(fn),
269         StatusIs(
270             error::INVALID_ARGUMENT,
271             AllOf(HasSubstr("Shape must be at most rank 0 but is rank 1"),
272                   HasSubstr("node foo"), HasSubstr("foo_op"),
273                   HasSubstr("with input shapes: [3], [4] and with computed "
274                             "input tensors: input[1] = <1 2 3 4 5> and with "
275                             "input tensors computed "
276                             "as partial shapes: input[0] = [10,?,5]."))));
277   }
278 }
279 
// Rank and dimension queries on inputs of unknown rank, partially known rank 3,
// and rank 0.
TEST_F(ShapeInferenceTest, RankAndDimInspection) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(3, 2),
                     {Unknown(), S({1, -1, 3}), S({})}, {}, {}, {});
  EXPECT_EQ(3, c.num_inputs());
  EXPECT_EQ(2, c.num_outputs());

  // Unknown rank: any index, including negative and out-of-range ones, yields
  // an unknown dim.
  const ShapeHandle unknown_rank = c.input(0);
  EXPECT_EQ("?", c.DebugString(unknown_rank));
  EXPECT_FALSE(c.RankKnown(unknown_rank));
  EXPECT_EQ(InferenceContext::kUnknownRank, c.Rank(unknown_rank));
  for (const int index : {0, -1, 1000}) {
    EXPECT_EQ("?", c.DebugString(c.Dim(unknown_rank, index)));
  }

  // Rank 3: every dim is reachable both by its index and by the matching
  // negative index counted from the end.
  const ShapeHandle rank3 = c.input(1);
  EXPECT_EQ("[1,?,3]", c.DebugString(rank3));
  EXPECT_TRUE(c.RankKnown(rank3));
  EXPECT_EQ(3, c.Rank(rank3));
  const auto expect_dim = [&](int index, int negative_index, int64_t value,
                              const char* text) {
    const DimensionHandle dim = c.Dim(rank3, index);
    EXPECT_EQ(value, c.Value(dim));
    EXPECT_EQ(value != InferenceContext::kUnknownDim, c.ValueKnown(dim));
    EXPECT_TRUE(SameHandle(dim, c.Dim(rank3, negative_index)));
    EXPECT_EQ(text, c.DebugString(dim));
  };
  expect_dim(0, -3, 1, "1");
  expect_dim(1, -2, InferenceContext::kUnknownDim, "?");
  expect_dim(2, -1, 3, "3");

  // Rank 0.
  const ShapeHandle scalar = c.input(2);
  EXPECT_EQ("[]", c.DebugString(scalar));
  EXPECT_TRUE(c.RankKnown(scalar));
  EXPECT_EQ(0, c.Rank(scalar));
}
320 
// NumElements() is the product of the dims when all are known, else unknown.
TEST_F(ShapeInferenceTest, NumElements) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(3, 2),
                     {Unknown(), S({1, -1, 3}), S({5, 4, 3, 2})}, {}, {}, {});

  // Unknown rank, or any unknown dim, gives an unknown element count.
  for (const int i : {0, 1}) {
    EXPECT_EQ("?", c.DebugString(c.NumElements(c.input(i))));
  }

  // That unknown count is a different handle from input 1's unknown dim.
  const ShapeHandle partially_known = c.input(1);
  EXPECT_FALSE(SameHandle(c.Dim(partially_known, 1),
                          c.NumElements(partially_known)));

  // Fully known: 5 * 4 * 3 * 2.
  EXPECT_EQ("120", c.DebugString(c.NumElements(c.input(2))));
}
334 
// WithRank(shape, rank, &out):
// - an unknown-rank shape is refined into a new shape of `rank` unknown dims;
// - a known-rank shape must match exactly and is then returned unchanged.
TEST_F(ShapeInferenceTest, WithRank) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})},
                     {}, {}, {});
  const ShapeHandle unknown_rank = c.input(0);
  const ShapeHandle rank3 = c.input(1);
  ShapeHandle first;
  ShapeHandle second;

  // Unknown rank: every request succeeds and builds a fresh shape whose
  // unknown dims are distinct handles.
  TF_EXPECT_OK(c.WithRank(unknown_rank, 1, &first));
  EXPECT_EQ("[?]", c.DebugString(first));

  TF_EXPECT_OK(c.WithRank(unknown_rank, 2, &second));
  EXPECT_EQ("[?,?]", c.DebugString(second));
  EXPECT_FALSE(SameHandle(first, second));
  EXPECT_FALSE(SameHandle(c.Dim(second, 0), c.Dim(second, 1)));

  // Asking for the same rank again still yields a different shape.
  TF_EXPECT_OK(c.WithRank(unknown_rank, 1, &second));
  EXPECT_EQ("[?]", c.DebugString(second));
  EXPECT_FALSE(SameHandle(first, second));

  TF_EXPECT_OK(c.WithRank(unknown_rank, 0, &first));
  EXPECT_EQ("[]", c.DebugString(first));

  // Known rank: a mismatch fails and clears the output ...
  first = rank3;
  EXPECT_THAT(c.WithRank(rank3, 2, &first),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Shape must be rank 2 but is rank 3")));
  EXPECT_FALSE(IsSet(first));
  // ... while a match hands back the input handle itself.
  TF_EXPECT_OK(c.WithRank(rank3, 3, &first));
  EXPECT_TRUE(SameHandle(first, rank3));

  // Inputs are unchanged.
  EXPECT_EQ("?", c.DebugString(unknown_rank));
  EXPECT_EQ("[1,?,3]", c.DebugString(rank3));
}
375 
// WithRankAtMost:
// - returns an unknown-rank shape unchanged;
// - accepts a known shape only when its rank does not exceed the limit.
TEST_F(ShapeInferenceTest, WithRankAtMost) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})},
                     {}, {}, {});
  const ShapeHandle unknown_rank = c.input(0);
  const ShapeHandle rank3 = c.input(1);
  ShapeHandle result;
  ShapeHandle other;

  // Unknown rank: always succeeds and returns the input itself.
  TF_EXPECT_OK(c.WithRankAtMost(unknown_rank, 1, &result));
  EXPECT_EQ("?", c.DebugString(result));
  EXPECT_TRUE(SameHandle(unknown_rank, result));

  TF_EXPECT_OK(c.WithRankAtMost(unknown_rank, 2, &other));
  EXPECT_EQ("?", c.DebugString(other));
  EXPECT_TRUE(SameHandle(result, other));

  // Known rank 3: a limit below 3 fails and clears the output.
  result = rank3;
  EXPECT_THAT(
      c.WithRankAtMost(rank3, 2, &result),
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("Shape must be at most rank 2 but is rank 3")));
  EXPECT_FALSE(IsSet(result));

  // Any limit of 3 or more succeeds and returns the input handle.
  for (const int limit : {3, 4, 5}) {
    TF_EXPECT_OK(c.WithRankAtMost(rank3, limit, &result));
    EXPECT_TRUE(SameHandle(result, rank3));
  }

  // Inputs are unchanged.
  EXPECT_EQ("?", c.DebugString(unknown_rank));
  EXPECT_EQ("[1,?,3]", c.DebugString(rank3));
}
414 
// WithRankAtLeast:
// - returns an unknown-rank shape unchanged;
// - accepts a known shape only when its rank is at least the limit.
TEST_F(ShapeInferenceTest, WithRankAtLeast) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(2, 2), {Unknown(), S({1, -1, 3})},
                     {}, {}, {});
  const ShapeHandle unknown_rank = c.input(0);
  const ShapeHandle rank3 = c.input(1);
  ShapeHandle result;
  ShapeHandle other;

  // Unknown rank: always succeeds and returns the input itself.
  TF_EXPECT_OK(c.WithRankAtLeast(unknown_rank, 1, &result));
  EXPECT_EQ("?", c.DebugString(result));
  EXPECT_TRUE(SameHandle(unknown_rank, result));

  TF_EXPECT_OK(c.WithRankAtLeast(unknown_rank, 2, &other));
  EXPECT_EQ("?", c.DebugString(other));
  EXPECT_TRUE(SameHandle(result, other));

  // Known rank 3: a limit above 3 fails and clears the output.
  result = rank3;
  EXPECT_THAT(
      c.WithRankAtLeast(rank3, 4, &result),
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("Shape must be at least rank 4 but is rank 3")));
  EXPECT_FALSE(IsSet(result));

  // Any limit of 3 or less succeeds and returns the input handle.
  for (const int limit : {3, 2, 0}) {
    TF_EXPECT_OK(c.WithRankAtLeast(rank3, limit, &result));
    EXPECT_TRUE(SameHandle(result, rank3));
  }

  // Inputs are unchanged.
  EXPECT_EQ("?", c.DebugString(unknown_rank));
  EXPECT_EQ("[1,?,3]", c.DebugString(rank3));
}
453 
// WithValue(dim, value, &out):
// - an unknown dim is refined into a new dim holding `value`;
// - a known dim must already equal `value`.
TEST_F(ShapeInferenceTest, WithValue) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({1, -1})}, {}, {}, {});
  const ShapeHandle in = c.input(0);
  const DimensionHandle known = c.Dim(in, 0);    // Value 1.
  const DimensionHandle unknown = c.Dim(in, 1);  // Value unknown.
  DimensionHandle refined_a;
  DimensionHandle refined_b;

  // Unknown dim: any value is accepted, and each call yields a new handle.
  TF_EXPECT_OK(c.WithValue(unknown, 1, &refined_a));
  EXPECT_EQ(1, c.Value(refined_a));

  TF_EXPECT_OK(c.WithValue(unknown, 2, &refined_b));
  EXPECT_EQ(2, c.Value(refined_b));
  EXPECT_FALSE(SameHandle(refined_a, refined_b));
  EXPECT_FALSE(SameHandle(refined_a, unknown));

  TF_EXPECT_OK(c.WithValue(unknown, 1, &refined_b));
  EXPECT_EQ(1, c.Value(refined_b));
  EXPECT_FALSE(SameHandle(refined_a, refined_b));

  // Known dim: a different value fails and clears the (pre-seeded) output.
  const auto expect_rejected = [&](int64_t wanted, const char* message) {
    DimensionHandle out = known;
    EXPECT_THAT(c.WithValue(known, wanted, &out),
                StatusIs(error::INVALID_ARGUMENT, HasSubstr(message)));
    EXPECT_FALSE(IsSet(out));
  };
  expect_rejected(0, "Dimension must be 0 but is 1");
  expect_rejected(2, "Dimension must be 2 but is 1");

  // The matching value hands back the input handle itself.
  DimensionHandle same;
  TF_EXPECT_OK(c.WithValue(known, 1, &same));
  EXPECT_TRUE(SameHandle(known, same));

  // Inputs are unchanged.
  EXPECT_EQ("1", c.DebugString(known));
  EXPECT_EQ("?", c.DebugString(unknown));
}
496 
TEST_F(ShapeInferenceTest,MergeDim)497 TEST_F(ShapeInferenceTest, MergeDim) {
498   NodeDef def;
499   InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({2, -1, 2, 1, -1})}, {},
500                      {}, {});
501 
502   auto d2 = c.Dim(c.input(0), 0);
503   auto d_unknown = c.Dim(c.input(0), 1);
504   auto d2_b = c.Dim(c.input(0), 2);
505   auto d1 = c.Dim(c.input(0), 3);
506   auto d_unknown_b = c.Dim(c.input(0), 4);
507   DimensionHandle out;
508 
509   // Merging anything with unknown returns the same pointer.
510   TF_EXPECT_OK(c.Merge(d2, d_unknown, &out));
511   EXPECT_TRUE(SameHandle(d2, out));
512   TF_EXPECT_OK(c.Merge(d_unknown, d2, &out));
513   EXPECT_TRUE(SameHandle(d2, out));
514   TF_EXPECT_OK(c.Merge(d_unknown, d_unknown_b, &out));
515   EXPECT_TRUE(SameHandle(d_unknown, out));
516 
517   auto merged_dims = c.MergedDims();
518   ASSERT_EQ(3, merged_dims.size());
519   EXPECT_TRUE(merged_dims[0].first.SameHandle(d2));
520   EXPECT_TRUE(merged_dims[0].second.SameHandle(d_unknown));
521   EXPECT_TRUE(merged_dims[1].first.SameHandle(d_unknown));
522   EXPECT_TRUE(merged_dims[1].second.SameHandle(d2));
523   EXPECT_TRUE(merged_dims[2].first.SameHandle(d_unknown));
524   EXPECT_TRUE(merged_dims[2].second.SameHandle(d_unknown_b));
525 
526   // Merging with self is a no-op and returns self.
527   TF_EXPECT_OK(c.Merge(d2, d2, &out));
528   EXPECT_TRUE(SameHandle(d2, out));
529   TF_EXPECT_OK(c.Merge(d_unknown, d_unknown, &out));
530   EXPECT_TRUE(SameHandle(d_unknown, out));
531 
532   merged_dims = c.MergedDims();
533   EXPECT_EQ(3, merged_dims.size());
534 
535   // Merging equal values is a no op and returns first one.
536   TF_EXPECT_OK(c.Merge(d2, d2_b, &out));
537   EXPECT_TRUE(SameHandle(d2, out));
538   TF_EXPECT_OK(c.Merge(d2_b, d2, &out));
539   EXPECT_TRUE(SameHandle(d2_b, out));
540 
541   merged_dims = c.MergedDims();
542   EXPECT_EQ(3, merged_dims.size());
543 
544   // Merging unequal values is an error.
545   EXPECT_THAT(c.Merge(d2, d1, &out),
546               StatusIs(error::INVALID_ARGUMENT,
547                        HasSubstr("Dimensions must be equal, but are 2 and 1")));
548 
549   EXPECT_FALSE(IsSet(out));
550   EXPECT_THAT(c.Merge(d1, d2, &out),
551               StatusIs(error::INVALID_ARGUMENT,
552                        HasSubstr("Dimensions must be equal, but are 1 and 2")));
553 
554   EXPECT_FALSE(IsSet(out));
555 
556   merged_dims = c.MergedDims();
557   EXPECT_EQ(3, merged_dims.size());
558 }
559 
// Relax() on dimensions keeps a dim only when both sides agree on a known
// value. Otherwise the result is an unknown dim.
TEST_F(ShapeInferenceTest, RelaxDim) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2),
                     {S({2, InferenceContext::kUnknownDim, 2, 1,
                         InferenceContext::kUnknownDim})},
                     {}, {}, {});

  const ShapeHandle in = c.input(0);
  const DimensionHandle d2 = c.Dim(in, 0);
  const DimensionHandle d_unknown = c.Dim(in, 1);
  const DimensionHandle d2_b = c.Dim(in, 2);
  const DimensionHandle d1 = c.Dim(in, 3);
  const DimensionHandle d_unknown_b = c.Dim(in, 4);

  // Runs Relax(first, second) and returns the relaxed dimension.
  const auto relax = [&](DimensionHandle first, DimensionHandle second) {
    DimensionHandle relaxed;
    Relax(&c, first, second, &relaxed);
    return relaxed;
  };

  // Against an unknown dim, the result is an existing or a new unknown dim.
  DimensionHandle out = relax(d2, d_unknown);
  EXPECT_TRUE(SameHandle(d_unknown, out));
  EXPECT_FALSE(SameHandle(d_unknown_b, out));
  EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out));

  out = relax(d_unknown, d2);
  EXPECT_FALSE(SameHandle(d_unknown, out));
  EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out));

  out = relax(d_unknown, d_unknown_b);
  EXPECT_FALSE(SameHandle(d_unknown, out));
  EXPECT_TRUE(SameHandle(d_unknown_b, out));
  EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(out));

  // Relaxing a dim with itself returns it.
  EXPECT_TRUE(SameHandle(d2, relax(d2, d2)));
  EXPECT_TRUE(SameHandle(d_unknown, relax(d_unknown, d_unknown)));

  // Equal known values keep the first argument.
  EXPECT_TRUE(SameHandle(d2, relax(d2, d2_b)));
  EXPECT_TRUE(SameHandle(d2_b, relax(d2_b, d2)));

  // Conflicting known values relax to unknown.
  EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(relax(d2, d1)));
  EXPECT_EQ(InferenceContext::kUnknownDim, c.Value(relax(d1, d2)));
}
606 
TEST_F(ShapeInferenceTest,RelaxShape)607 TEST_F(ShapeInferenceTest, RelaxShape) {
608   NodeDef def;
609   InferenceContext c(
610       kVersion, def, MakeOpDef(7, 2),
611       {Unknown(), S({1, 2}), S({InferenceContext::kUnknownDim, 2}),
612        S({1, InferenceContext::kUnknownDim}), S({1, 3}), Unknown(), S({1})},
613       {}, {}, {});
614 
615   auto s_unknown = c.input(0);
616   auto s_1_2 = c.input(1);
617   auto s_u_2 = c.input(2);
618   auto s_1_u = c.input(3);
619   auto s_1_3 = c.input(4);
620   auto s_unknown_b = c.input(5);
621   auto s_1 = c.input(6);
622   ShapeHandle out;
623 
624   // Relaxing any shape with unknown returns a new unknown.
625   Relax(&c, s_unknown, s_1_2, &out);
626   EXPECT_FALSE(SameHandle(s_u_2, s_unknown));
627   EXPECT_EQ("?", c.DebugString(out));
628   Relax(&c, s_u_2, s_unknown, &out);
629   EXPECT_FALSE(SameHandle(s_u_2, out));
630   EXPECT_EQ("?", c.DebugString(out));
631   Relax(&c, s_unknown, s_unknown_b, &out);
632   EXPECT_FALSE(SameHandle(s_unknown, out));
633   EXPECT_TRUE(SameHandle(s_unknown_b, out));
634   EXPECT_EQ("?", c.DebugString(out));
635 
636   // Relaxing with self returns self.
637   Relax(&c, s_1_2, s_1_2, &out);
638   EXPECT_TRUE(SameHandle(out, s_1_2));
639 
640   // Relaxing where one of the inputs has less information.
641   out = ShapeHandle();
642   Relax(&c, s_1_2, s_u_2, &out);
643   EXPECT_FALSE(SameHandle(s_u_2, out));
644   EXPECT_EQ("[?,2]", c.DebugString(out));
645   out = ShapeHandle();
646   Relax(&c, s_u_2, s_1_2, &out);
647   EXPECT_FALSE(SameHandle(s_u_2, out));
648   EXPECT_EQ("[?,2]", c.DebugString(out));
649 
650   // Relaxing where each input has one distinct unknown dimension.
651   Relax(&c, s_u_2, s_1_u, &out);
652   EXPECT_EQ("[?,?]", c.DebugString(out));
653   EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
654   EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 1), c.Dim(out, 1)));
655   auto s_u1 = c.UnknownShapeOfRank(1);
656   auto s_u2 = c.UnknownShapeOfRank(1);
657   Relax(&c, s_u1, s_u2, &out);
658   EXPECT_FALSE(SameHandle(s_u1, out));
659 
660   // Relaxing with mismatched values in a dimension returns a shape with that
661   // dimension unknown.
662   out = s_unknown;
663   Relax(&c, s_u_2, s_1_3, &out);
664   EXPECT_FALSE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
665   EXPECT_EQ("[?,?]", c.DebugString(out));
666   out = s_unknown;
667   Relax(&c, s_1_3, s_u_2, &out);
668   EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 0), c.Dim(out, 0)));
669   EXPECT_EQ("[?,?]", c.DebugString(out));
670   out = s_unknown;
671 
672   // Relaxing with mismatched ranks returns a new unknown.
673   Relax(&c, s_1, s_1_2, &out);
674   EXPECT_EQ("?", c.DebugString(out));
675 }
676 
TEST_F(ShapeInferenceTest,MergeShape)677 TEST_F(ShapeInferenceTest, MergeShape) {
678   NodeDef def;
679   InferenceContext c(kVersion, def, MakeOpDef(7, 2),
680                      {Unknown(), S({1, 2}), S({-1, 2}), S({1, -1}), S({1, 3}),
681                       Unknown(), S({1})},
682                      {}, {}, {});
683 
684   auto s_unknown = c.input(0);
685   auto s_1_2 = c.input(1);
686   auto s_u_2 = c.input(2);
687   auto s_1_u = c.input(3);
688   auto s_1_3 = c.input(4);
689   auto s_unknown_b = c.input(5);
690   auto s_1 = c.input(6);
691   ShapeHandle out;
692 
693   // Merging any shape with unknown returns the shape.
694   TF_EXPECT_OK(c.Merge(s_unknown, s_1_2, &out));
695   EXPECT_TRUE(SameHandle(s_1_2, out));
696   TF_EXPECT_OK(c.Merge(s_u_2, s_unknown, &out));
697   EXPECT_TRUE(SameHandle(s_u_2, out));
698   TF_EXPECT_OK(c.Merge(s_unknown, s_unknown_b, &out));
699   EXPECT_TRUE(SameHandle(s_unknown, out));
700 
701   auto merged_shapes = c.MergedShapes();
702   ASSERT_EQ(3, merged_shapes.size());
703   EXPECT_TRUE(merged_shapes[0].first.SameHandle(s_unknown));
704   EXPECT_TRUE(merged_shapes[0].second.SameHandle(s_1_2));
705   EXPECT_TRUE(merged_shapes[1].first.SameHandle(s_u_2));
706   EXPECT_TRUE(merged_shapes[1].second.SameHandle(s_unknown));
707   EXPECT_TRUE(merged_shapes[2].first.SameHandle(s_unknown));
708   EXPECT_TRUE(merged_shapes[2].second.SameHandle(s_unknown_b));
709 
710   // Merging with self returns self.
711   TF_EXPECT_OK(c.Merge(s_1_2, s_1_2, &out));
712   EXPECT_TRUE(SameHandle(out, s_1_2));
713 
714   merged_shapes = c.MergedShapes();
715   EXPECT_EQ(3, merged_shapes.size());
716 
717   // Merging where one of the inputs is the right answer - return that input.
718   out = ShapeHandle();
719   TF_EXPECT_OK(c.Merge(s_1_2, s_u_2, &out));
720   EXPECT_TRUE(SameHandle(s_1_2, out));
721   out = ShapeHandle();
722   TF_EXPECT_OK(c.Merge(s_u_2, s_1_2, &out));
723   EXPECT_TRUE(SameHandle(s_1_2, out));
724 
725   merged_shapes = c.MergedShapes();
726   ASSERT_EQ(5, merged_shapes.size());
727   EXPECT_TRUE(merged_shapes[3].first.SameHandle(s_1_2));
728   EXPECT_TRUE(merged_shapes[3].second.SameHandle(s_u_2));
729   EXPECT_TRUE(merged_shapes[4].first.SameHandle(s_u_2));
730   EXPECT_TRUE(merged_shapes[4].second.SameHandle(s_1_2));
731 
732   // Merging where neither input is the right answer.
733   TF_EXPECT_OK(c.Merge(s_u_2, s_1_u, &out));
734   EXPECT_FALSE(SameHandle(out, s_u_2));
735   EXPECT_FALSE(SameHandle(out, s_1_u));
736   EXPECT_EQ("[1,2]", c.DebugString(out));
737   EXPECT_TRUE(SameHandle(c.Dim(s_1_u, 0), c.Dim(out, 0)));
738   EXPECT_TRUE(SameHandle(c.Dim(s_u_2, 1), c.Dim(out, 1)));
739 
740   merged_shapes = c.MergedShapes();
741   ASSERT_EQ(7, merged_shapes.size());
742   EXPECT_TRUE(merged_shapes[5].first.SameHandle(s_u_2));
743   EXPECT_TRUE(merged_shapes[5].second.SameHandle(s_1_u));
744   EXPECT_TRUE(merged_shapes[6].first.SameHandle(s_u_2));
745   EXPECT_TRUE(merged_shapes[6].second.SameHandle(out));
746 
747   auto s_u1 = c.UnknownShapeOfRank(1);
748   auto s_u2 = c.UnknownShapeOfRank(1);
749   TF_EXPECT_OK(c.Merge(s_u1, s_u2, &out));
750   EXPECT_TRUE(SameHandle(s_u1, out));
751 
752   merged_shapes = c.MergedShapes();
753   ASSERT_EQ(8, merged_shapes.size());
754   EXPECT_TRUE(merged_shapes[7].first.SameHandle(s_u1));
755   EXPECT_TRUE(merged_shapes[7].second.SameHandle(s_u2));
756 
757   // Incompatible merges give errors and set out to nullptr.
758   out = s_unknown;
759   EXPECT_THAT(
760       c.Merge(s_u_2, s_1_3, &out),
761       StatusIs(
762           error::INVALID_ARGUMENT,
763           HasSubstr(
764               "Dimension 1 in both shapes must be equal, but are 2 and 3")));
765 
766   EXPECT_FALSE(IsSet(out));
767   out = s_unknown;
768   EXPECT_THAT(
769       c.Merge(s_1_3, s_u_2, &out),
770       StatusIs(
771           error::INVALID_ARGUMENT,
772           HasSubstr(
773               "Dimension 1 in both shapes must be equal, but are 3 and 2")));
774 
775   EXPECT_FALSE(IsSet(out));
776   out = s_unknown;
777   EXPECT_THAT(
778       c.Merge(s_1, s_1_2, &out),
779       StatusIs(error::INVALID_ARGUMENT,
780                HasSubstr("Shapes must be equal rank, but are 1 and 2")));
781 
782   EXPECT_FALSE(IsSet(out));
783 
784   merged_shapes = c.MergedShapes();
785   EXPECT_EQ(8, merged_shapes.size());
786 }
787 
// MergePrefix(s, prefix, &s_out, &prefix_out) merges `prefix` into the leading
// dims of `s`:
// - unknown rank on either side passes both inputs through;
// - a leading-dim conflict, or a prefix longer than `s`, fails and clears both
//   outputs.
TEST_F(ShapeInferenceTest, MergePrefix) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(4, 2),
                     {Unknown(), S({-1, 2}), S({1, -1, 3}), S({2, 4})}, {}, {},
                     {});

  const ShapeHandle s_unknown = c.input(0);
  const ShapeHandle s_u_2 = c.input(1);
  const ShapeHandle s_1_u_3 = c.input(2);
  const ShapeHandle s_2_4 = c.input(3);

  ShapeHandle merged;
  ShapeHandle merged_prefix;

  // Unknown rank on either side: both inputs come back unchanged.
  TF_EXPECT_OK(c.MergePrefix(s_unknown, s_u_2, &merged, &merged_prefix));
  EXPECT_TRUE(SameHandle(merged, s_unknown));
  EXPECT_TRUE(SameHandle(merged_prefix, s_u_2));
  TF_EXPECT_OK(c.MergePrefix(s_1_u_3, s_unknown, &merged, &merged_prefix));
  EXPECT_TRUE(SameHandle(merged, s_1_u_3));
  EXPECT_TRUE(SameHandle(merged_prefix, s_unknown));

  // [1,?,3] with prefix [?,2]: each side fills in the other's unknown dim, and
  // both outputs share the merged leading dims.
  TF_EXPECT_OK(c.MergePrefix(s_1_u_3, s_u_2, &merged, &merged_prefix));
  EXPECT_FALSE(SameHandle(merged, s_1_u_3));
  EXPECT_EQ("[1,2]", c.DebugString(merged_prefix));
  EXPECT_EQ("[1,2,3]", c.DebugString(merged));
  EXPECT_TRUE(SameHandle(c.Dim(merged_prefix, 0), c.Dim(merged, 0)));
  EXPECT_TRUE(SameHandle(c.Dim(merged, 0), c.Dim(s_1_u_3, 0)));
  EXPECT_TRUE(SameHandle(c.Dim(merged_prefix, 1), c.Dim(merged, 1)));
  EXPECT_TRUE(SameHandle(c.Dim(merged_prefix, 1), c.Dim(s_u_2, 1)));

  // Failures clear both outputs. Each case pre-seeds them so that is visible.
  const auto expect_failure = [&](ShapeHandle shape, ShapeHandle prefix,
                                  const char* message) {
    merged = s_unknown;
    merged_prefix = s_unknown;
    EXPECT_THAT(c.MergePrefix(shape, prefix, &merged, &merged_prefix),
                StatusIs(error::INVALID_ARGUMENT, HasSubstr(message)));
    EXPECT_FALSE(IsSet(merged));
    EXPECT_FALSE(IsSet(merged_prefix));
  };
  // Leading dims conflict (1 vs 2).
  expect_failure(s_1_u_3, s_2_4, "Dimensions must be equal, but are 1 and 2");
  // The prefix [1,?,3] is longer than the shape [2,4].
  expect_failure(s_2_4, s_1_u_3, "Shape must be at least rank 3 but is rank 2");
}
843 
// Exercises InferenceContext::Subshape in both its start-only and
// (start, end) forms. Covers unknown-rank input, every start/end pair
// (including negative and past-the-end bounds) on a rank-5 input, and the
// error paths.
TEST_F(ShapeInferenceTest, Subshape) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(2, 2),
                     {S({1, 2, 3, -1, 5}), Unknown()}, {}, {}, {});

  // Unknown-rank input: the result is always "?". start == 0 returns the
  // input handle itself; any other start returns a different handle.
  ShapeHandle unknown = c.input(1);
  ShapeHandle out;
  TF_EXPECT_OK(c.Subshape(unknown, 0, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_TRUE(SameHandle(out, unknown));
  TF_EXPECT_OK(c.Subshape(unknown, 1, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, unknown));
  TF_EXPECT_OK(c.Subshape(unknown, 200, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, unknown));

  // Known-rank input [1,2,3,?,5].
  const int kFullRank = 5;
  // One slot per (start, end) sign combination computed in the loop below.
  ShapeHandle out_arr[4];
  auto in0 = c.input(0);
  // start == 0 with no end yields the whole input, as the same handle.
  TF_EXPECT_OK(c.Subshape(in0, 0, &out));
  EXPECT_EQ("[1,2,3,?,5]", c.DebugString(out));
  EXPECT_TRUE(SameHandle(out, in0));
  EXPECT_EQ(kFullRank, c.Rank(out));
  // start and end deliberately run one past the rank. The expected rank below
  // uses std::min(kFullRank, ...), i.e. bounds past the rank are clamped.
  for (int start = 0; start <= kFullRank + 1; ++start) {
    for (int end = start; end <= kFullRank + 1; ++end) {
      // Get subshapes using different start and end values that give the same
      // range.
      // Bounds >= kFullRank have no negative spelling (bound - kFullRank
      // would be >= 0), so kFullRank itself is used for those.
      const int neg_start =
          start >= kFullRank ? kFullRank : (start - kFullRank);
      const int neg_end = end >= kFullRank ? kFullRank : (end - kFullRank);
      TF_ASSERT_OK(c.Subshape(in0, start, end, &out_arr[0]));
      TF_ASSERT_OK(c.Subshape(in0, neg_start, end, &out_arr[1]));
      TF_ASSERT_OK(c.Subshape(in0, start, neg_end, &out_arr[2]));
      TF_ASSERT_OK(c.Subshape(in0, neg_start, neg_end, &out_arr[3]));

      // Verify all computed subshapes.
      for (int arr_idx = 0; arr_idx < 4; ++arr_idx) {
        out = out_arr[arr_idx];
        ASSERT_EQ(std::min(kFullRank, end) - std::min(kFullRank, start),
                  c.Rank(out))
            << "start: " << start << " end: " << end << " arr_idx: " << arr_idx
            << " in0: " << c.DebugString(in0) << " out: " << c.DebugString(out);
        // Every output dimension must be the same handle as the matching
        // input dimension, not a copy.
        for (int d = 0; d < c.Rank(out); ++d) {
          EXPECT_TRUE(SameHandle(c.Dim(in0, start + d), c.Dim(out, d)))
              << "arr_idx: " << arr_idx;
        }
      }
    }
  }

  // Errors. `out` is seeded with a set handle before each failing call and
  // must be cleared by it.
  out = unknown;
  EXPECT_THAT(
      c.Subshape(in0, 6, -3, &out),
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("Subshape must have computed start <= end, but is 5 "
                         "and 2 (computed from start 6 and end -3 over shape "
                         "with rank 5)")));
  EXPECT_FALSE(IsSet(out));
  out = unknown;
  EXPECT_THAT(
      c.Subshape(in0, -50, 100, &out),
      StatusIs(
          error::INVALID_ARGUMENT,
          HasSubstr(
              "Subshape start out of bounds: -50, for shape with rank 5")));

  EXPECT_FALSE(IsSet(out));
  out = unknown;
  EXPECT_THAT(
      c.Subshape(in0, 0, -50, &out),
      StatusIs(
          error::INVALID_ARGUMENT,
          HasSubstr("Subshape end out of bounds: -50, for shape with rank 5")));

  EXPECT_FALSE(IsSet(out));
}
922 
// Exercises InferenceContext::Concatenate for unknown-rank and known-rank
// operands.
TEST_F(ShapeInferenceTest, Concatenate) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(3, 2),
                       {S({1, -1, 3}), S({4, 5}), Unknown()}, {}, {}, {});

  const auto lhs = ctx.input(0);
  const auto rhs = ctx.input(1);
  const ShapeHandle unknown = ctx.input(2);
  ShapeHandle result;

  // An unknown-rank operand yields a new unknown shape, not the input handle.
  TF_EXPECT_OK(ctx.Concatenate(unknown, unknown, &result));
  EXPECT_EQ("?", ctx.DebugString(result));
  EXPECT_FALSE(SameHandle(result, unknown));
  TF_EXPECT_OK(ctx.Concatenate(unknown, lhs, &result));
  EXPECT_EQ("?", ctx.DebugString(result));
  EXPECT_FALSE(SameHandle(result, unknown));

  // Known ranks: the output reuses every dimension handle of lhs, then rhs.
  TF_EXPECT_OK(ctx.Concatenate(lhs, rhs, &result));
  EXPECT_EQ("[1,?,3,4,5]", ctx.DebugString(result));
  const int lhs_rank = ctx.Rank(lhs);
  for (int i = 0; i < lhs_rank; ++i) {
    EXPECT_TRUE(SameHandle(ctx.Dim(lhs, i), ctx.Dim(result, i)));
  }
  const int rhs_rank = ctx.Rank(rhs);
  for (int j = 0; j < rhs_rank; ++j) {
    EXPECT_TRUE(SameHandle(ctx.Dim(rhs, j), ctx.Dim(result, lhs_rank + j)));
  }
}
949 
// Exercises InferenceContext::ReplaceDim with positive, negative, and
// out-of-range indices, on both known-rank and unknown-rank inputs.
TEST_F(ShapeInferenceTest, ReplaceDim) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(2, 0), {S({1, 2, 3}), Unknown()},
                     {}, {}, {});

  auto in = c.input(0);
  auto unknown = c.input(1);

  // Replace one dimension of [1,2,3] with another of its own dimensions.
  ShapeHandle replaced;
  TF_EXPECT_OK(c.ReplaceDim(in, 0, c.Dim(in, 1), &replaced));
  EXPECT_EQ("[2,2,3]", c.DebugString(replaced));
  TF_EXPECT_OK(c.ReplaceDim(in, 2, c.Dim(in, 1), &replaced));
  EXPECT_EQ("[1,2,2]", c.DebugString(replaced));
  TF_EXPECT_OK(c.ReplaceDim(in, 1, c.Dim(in, 2), &replaced));
  EXPECT_EQ("[1,3,3]", c.DebugString(replaced));
  // An unknown-rank input stays unknown.
  TF_EXPECT_OK(c.ReplaceDim(unknown, 0, c.Dim(in, 1), &replaced));
  EXPECT_EQ("?", c.DebugString(replaced));

  // Negative indexing.
  TF_EXPECT_OK(c.ReplaceDim(in, -1, c.Dim(in, 1), &replaced));
  EXPECT_EQ("[1,2,2]", c.DebugString(replaced));
  TF_EXPECT_OK(c.ReplaceDim(unknown, -1, c.Dim(in, 1), &replaced));
  EXPECT_EQ("?", c.DebugString(replaced));

  // Out-of-range indexing fails and clears the output. `replaced` is seeded
  // with a set handle before each call, so the IsSet check does not depend on
  // whatever the previous call left in it.
  replaced = in;
  EXPECT_THAT(
      c.ReplaceDim(in, 3, c.Dim(in, 1), &replaced),
      StatusIs(error::INVALID_ARGUMENT, HasSubstr("Out of range dim_index")));
  EXPECT_FALSE(IsSet(replaced));
  replaced = in;
  EXPECT_THAT(
      c.ReplaceDim(in, -4, c.Dim(in, 1), &replaced),
      StatusIs(error::INVALID_ARGUMENT, HasSubstr("Out of range dim_index")));
  EXPECT_FALSE(IsSet(replaced));
}
985 
// Exercises InferenceContext::MakeShape from a vector of dimension handles and
// from a mixed list of constants and handles.
TEST_F(ShapeInferenceTest, MakeShape) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 2),
                       {S({1, 2, 3, -1, 5})}, {}, {}, {});

  // Collect the input's dimension handles in reverse order.
  const auto input = ctx.input(0);
  const int rank = ctx.Rank(input);
  std::vector<DimensionHandle> reversed;
  reversed.reserve(rank);
  for (int i = rank - 1; i >= 0; --i) {
    reversed.push_back(ctx.Dim(input, i));
  }

  // The new shape reuses the supplied dimension handles...
  const auto first = ctx.MakeShape(reversed);
  EXPECT_EQ("[5,?,3,2,1]", ctx.DebugString(first));
  EXPECT_TRUE(SameHandle(ctx.Dim(first, 0), ctx.Dim(input, rank - 1)));

  // ...but every call produces a distinct shape handle.
  const auto second = ctx.MakeShape(reversed);
  EXPECT_FALSE(SameHandle(first, second));
  EXPECT_TRUE(SameHandle(ctx.Dim(second, 0), ctx.Dim(input, rank - 1)));

  // Constants and handles may be mixed in a single call.
  const auto mixed = ctx.MakeShape({1, 2, reversed[2]});
  EXPECT_FALSE(SameHandle(first, mixed));
  EXPECT_EQ("[1,2,3]", ctx.DebugString(mixed));
}
1011 
// Each InferenceContext::UnknownShape call returns a new unknown-rank shape.
TEST_F(ShapeInferenceTest, UnknownShape) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  const ShapeHandle first = ctx.UnknownShape();
  const ShapeHandle second = ctx.UnknownShape();
  for (const ShapeHandle& shape : {first, second}) {
    EXPECT_EQ("?", ctx.DebugString(shape));
  }
  EXPECT_FALSE(SameHandle(first, second));
}
1023 
// ShapeHandleToProto on a fully known shape writes every dimension size and
// leaves unknown_rank false.
TEST_F(ShapeInferenceTest, KnownShapeToProto) {
  NodeDef def;
  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, MakeOpDef(0, 2), empty, {}, {}, {});

  auto s = c.MakeShape({1, 2, 3});
  TensorShapeProto proto;
  c.ShapeHandleToProto(s, &proto);

  EXPECT_FALSE(proto.unknown_rank());
  // Fatal, so the indexed dim() accesses below never run past the end.
  ASSERT_EQ(3, proto.dim_size());
  // Check every dimension, not only the first, so that a dropped or
  // reordered dimension is caught.
  EXPECT_EQ(1, proto.dim(0).size());
  EXPECT_EQ(2, proto.dim(1).size());
  EXPECT_EQ(3, proto.dim(2).size());
}
1037 
// An unknown-rank handle serializes as unknown_rank with no dimensions.
TEST_F(ShapeInferenceTest, UnknownShapeToProto) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  const ShapeHandle unknown = ctx.UnknownShape();
  TensorShapeProto serialized;
  ctx.ShapeHandleToProto(unknown, &serialized);

  EXPECT_TRUE(serialized.unknown_rank());
  EXPECT_EQ(0, serialized.dim_size());
}
1050 
// Scalars render as "[]", and each InferenceContext::Scalar call yields a
// distinct handle.
TEST_F(ShapeInferenceTest, Scalar) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  const ShapeHandle first = ctx.Scalar();
  EXPECT_EQ("[]", ctx.DebugString(first));
  const ShapeHandle second = ctx.Scalar();
  EXPECT_EQ("[]", ctx.DebugString(second));
  EXPECT_FALSE(SameHandle(first, second));
}
1062 
// Exercises InferenceContext::Vector with constant lengths and with a
// dimension handle.
TEST_F(ShapeInferenceTest, Vector) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  // Constant lengths, including the kUnknownDim sentinel.
  const ShapeHandle of_one = ctx.Vector(1);
  EXPECT_EQ("[1]", ctx.DebugString(of_one));
  const ShapeHandle of_unknown = ctx.Vector(InferenceContext::kUnknownDim);
  EXPECT_EQ("[?]", ctx.DebugString(of_unknown));

  // A dimension handle is stored as-is.
  const DimensionHandle length = ctx.UnknownDim();
  const ShapeHandle from_dim = ctx.Vector(length);
  EXPECT_EQ("[?]", ctx.DebugString(from_dim));
  EXPECT_TRUE(SameHandle(length, ctx.Dim(from_dim, 0)));
}
1078 
// Exercises InferenceContext::Matrix with constants, dimension handles, and a
// mix of the two.
TEST_F(ShapeInferenceTest, Matrix) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  // Constant sizes, including the kUnknownDim sentinel.
  const ShapeHandle constants = ctx.Matrix(1, 2);
  EXPECT_EQ("[1,2]", ctx.DebugString(constants));
  const ShapeHandle with_unknown = ctx.Matrix(0, InferenceContext::kUnknownDim);
  EXPECT_EQ("[0,?]", ctx.DebugString(with_unknown));

  // Dimension handles are stored as-is.
  const DimensionHandle rows = ctx.UnknownDim();
  const DimensionHandle cols = ctx.UnknownDim();
  const ShapeHandle from_dims = ctx.Matrix(rows, cols);
  EXPECT_EQ("[?,?]", ctx.DebugString(from_dims));
  EXPECT_TRUE(SameHandle(rows, ctx.Dim(from_dims, 0)));
  EXPECT_TRUE(SameHandle(cols, ctx.Dim(from_dims, 1)));

  // A handle and a constant may be mixed.
  const ShapeHandle mixed = ctx.Matrix(rows, 100);
  EXPECT_EQ("[?,100]", ctx.DebugString(mixed));
  EXPECT_TRUE(SameHandle(rows, ctx.Dim(mixed, 0)));
}
1100 
TEST_F(ShapeInferenceTest,MakeShapeFromShapeTensor)1101 TEST_F(ShapeInferenceTest, MakeShapeFromShapeTensor) {
1102   auto create = [&](Tensor* t) {
1103     NodeDef def;
1104     InferenceContext c(kVersion, def, MakeOpDef(1, 0), {Unknown()}, {t}, {},
1105                        {});
1106     ShapeHandle out;
1107     Status s = c.MakeShapeFromShapeTensor(0, &out);
1108     if (s.ok()) {
1109       return c.DebugString(out);
1110     } else {
1111       EXPECT_FALSE(IsSet(out));
1112       return s.error_message();
1113     }
1114   };
1115 
1116   Tensor t;
1117   EXPECT_EQ("?", create(nullptr));
1118 
1119   t = ::tensorflow::test::AsTensor<int32>({1, 2, 3});
1120   EXPECT_EQ("[1,2,3]", create(&t));
1121 
1122   t = ::tensorflow::test::AsTensor<int64_t>({3, 2, 1});
1123   EXPECT_EQ("[3,2,1]", create(&t));
1124 
1125   t = ::tensorflow::test::AsTensor<int64_t>({3, -1, 1});
1126   EXPECT_EQ("[3,?,1]", create(&t));
1127 
1128   t = ::tensorflow::test::AsTensor<int64_t>({});
1129   EXPECT_EQ("[]", create(&t));
1130 
1131   // Test negative scalar
1132   t = ::tensorflow::test::AsScalar<int32>(-1);
1133   EXPECT_EQ("?", create(&t));
1134 
1135   t = ::tensorflow::test::AsTensor<float>({1, 2, 3});
1136   EXPECT_THAT(create(&t),
1137               HasSubstr("Input tensor must be int32 or int64, but was float"));
1138 
1139   t = ::tensorflow::test::AsScalar<int32>(1);
1140   auto s_scalar = create(&t);
1141   EXPECT_THAT(s_scalar, HasSubstr("Input tensor must be rank 1, or if its rank "
1142                                   "0 it must have value -1"));
1143 
1144   t = ::tensorflow::test::AsTensor<int32>({1, 2}, TensorShape{2, 1});
1145   auto s_matrix = create(&t);
1146   EXPECT_THAT(s_matrix,
1147               HasSubstr("Input tensor must be rank 1, but was rank 2"));
1148 
1149   // Test negative values for the dims.
1150   t = ::tensorflow::test::AsTensor<int64_t>({3, -2, 1});
1151   EXPECT_THAT(create(&t),
1152               HasSubstr("Invalid value in tensor used for shape: -2"));
1153 
1154   // Test negative values for the dims.
1155   t = ::tensorflow::test::AsTensor<int32>({3, -2, 1});
1156   EXPECT_THAT(create(&t),
1157               HasSubstr("Invalid value in tensor used for shape: -2"));
1158 
1159   // Test when the input shape is wrong.
1160   {
1161     NodeDef def;
1162     InferenceContext c(kVersion, def, MakeOpDef(1, 0), {S({1, -1})}, {nullptr},
1163                        {}, {});
1164     ShapeHandle out;
1165     EXPECT_EQ("Shape must be rank 1 but is rank 2",
1166               c.MakeShapeFromShapeTensor(0, &out).error_message());
1167   }
1168 }
1169 
// Exercises InferenceContext::MakeShapeFromPartialTensorShape for unknown and
// known ranks.
TEST_F(ShapeInferenceTest, MakeShapeFromPartialTensorShape) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});
  ShapeHandle shape;

  // A default-constructed PartialTensorShape has unknown rank.
  TF_ASSERT_OK(
      ctx.MakeShapeFromPartialTensorShape(PartialTensorShape(), &shape));
  EXPECT_EQ("?", ctx.DebugString(shape));

  // Known rank; -1 entries stay unknown.
  TF_ASSERT_OK(
      ctx.MakeShapeFromPartialTensorShape(PartialTensorShape({0}), &shape));
  EXPECT_EQ("[0]", ctx.DebugString(shape));
  TF_ASSERT_OK(ctx.MakeShapeFromPartialTensorShape(
      PartialTensorShape({0, -1, 1000}), &shape));
  EXPECT_EQ("[0,?,1000]", ctx.DebugString(shape));
}
1188 
// Exercises InferenceContext::MakeShapeFromTensorShape for rank 0, 1, and 3.
TEST_F(ShapeInferenceTest, MakeShapeFromTensorShape) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});
  ShapeHandle shape;

  // A default TensorShape is a scalar.
  TF_ASSERT_OK(ctx.MakeShapeFromTensorShape(TensorShape(), &shape));
  EXPECT_EQ("[]", ctx.DebugString(shape));

  TF_ASSERT_OK(ctx.MakeShapeFromTensorShape(TensorShape({0}), &shape));
  EXPECT_EQ("[0]", ctx.DebugString(shape));

  TF_ASSERT_OK(
      ctx.MakeShapeFromTensorShape(TensorShape({0, 7, 1000}), &shape));
  EXPECT_EQ("[0,7,1000]", ctx.DebugString(shape));
}
1202 
// Exercises InferenceContext::MakeShapeFromShapeProto as a single
// TensorShapeProto is built up step by step.
TEST_F(ShapeInferenceTest, MakeShapeFromShapeProto) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});
  TensorShapeProto shape_proto;
  ShapeHandle shape;

  // unknown_rank with no dimensions is accepted as "?"...
  shape_proto.set_unknown_rank(true);
  TF_EXPECT_OK(ctx.MakeShapeFromShapeProto(shape_proto, &shape));
  EXPECT_EQ("?", ctx.DebugString(shape));
  // ...but unknown_rank together with explicit dimensions is rejected.
  shape_proto.add_dim()->set_size(0);
  EXPECT_THAT(
      ctx.MakeShapeFromShapeProto(shape_proto, &shape),
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("An unknown shape must not have any dimensions set")));
  EXPECT_FALSE(IsSet(shape));

  // With unknown_rank cleared, the accumulated dimensions define the shape.
  shape_proto.set_unknown_rank(false);
  TF_EXPECT_OK(ctx.MakeShapeFromShapeProto(shape_proto, &shape));
  EXPECT_EQ("[0]", ctx.DebugString(shape));
  shape_proto.add_dim()->set_size(-1);
  shape_proto.add_dim()->set_size(1000);
  TF_EXPECT_OK(ctx.MakeShapeFromShapeProto(shape_proto, &shape));
  EXPECT_EQ("[0,?,1000]", ctx.DebugString(shape));

  // A -2 dimension is invalid.
  shape_proto.add_dim()->set_size(-2);
  EXPECT_THAT(
      ctx.MakeShapeFromShapeProto(shape_proto, &shape),
      StatusIs(
          error::INVALID_ARGUMENT,
          HasSubstr("Shape [0,?,1000,-2] has dimensions with values below -1 "
                    "(where -1 means unknown)")));
  EXPECT_FALSE(IsSet(shape));
}
1241 
// InferenceContext::MakeDim renders the given value, and each call yields a
// distinct handle even for equal values.
TEST_F(ShapeInferenceTest, MakeDim) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  const DimensionHandle one_a = ctx.MakeDim(1);
  const DimensionHandle one_b = ctx.MakeDim(1);
  const DimensionHandle two = ctx.MakeDim(2);
  EXPECT_EQ("1", ctx.DebugString(one_a));
  EXPECT_EQ("1", ctx.DebugString(one_b));
  EXPECT_FALSE(SameHandle(one_a, one_b));
  EXPECT_EQ("2", ctx.DebugString(two));
}
1255 
// Each InferenceContext::UnknownDim call returns a new unknown dimension.
TEST_F(ShapeInferenceTest, UnknownDim) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  const DimensionHandle first = ctx.UnknownDim();
  const DimensionHandle second = ctx.UnknownDim();
  for (const DimensionHandle& dim : {first, second}) {
    EXPECT_EQ("?", ctx.DebugString(dim));
  }
  EXPECT_FALSE(SameHandle(first, second));
}
1267 
// InferenceContext::UnknownShapeOfRank produces the requested number of
// unknown dimensions; rank 0 is a scalar.
TEST_F(ShapeInferenceTest, UnknownShapeOfRank) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  const ShapeHandle rank3 = ctx.UnknownShapeOfRank(3);
  EXPECT_EQ("[?,?,?]", ctx.DebugString(rank3));

  const ShapeHandle rank0 = ctx.UnknownShapeOfRank(0);
  EXPECT_EQ("[]", ctx.DebugString(rank0));
}
1279 
// input_tensor(i) returns the exact tensor pointer supplied at construction,
// or nullptr for inputs with no tensor.
TEST_F(ShapeInferenceTest, InputTensors) {
  const Tensor t1 = tensorflow::test::AsTensor<float>({10});
  const Tensor t2 = tensorflow::test::AsTensor<float>({20, 30});
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(3, 2), {S({1}), S({2}), S({3})},
                     {&t1, &t2}, {}, {});

  // EXPECT_EQ (rather than EXPECT_TRUE(a == b)) prints both pointers when the
  // check fails.
  EXPECT_EQ(&t1, c.input_tensor(0));
  EXPECT_EQ(&t2, c.input_tensor(1));
  EXPECT_EQ(nullptr, c.input_tensor(2));
}
1291 
// Exercises InferenceContext::MakeDimForScalarInput with int32 and int64
// scalar tensors holding a valid and a negative size.
TEST_F(ShapeInferenceTest, MakeDimForScalarInput) {
  // Input 0 holds a valid size; input 1 holds a negative one.
  Tensor valid = tensorflow::test::AsScalar<int32>(20);
  Tensor negative = tensorflow::test::AsScalar<int32>(-1);
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(2, 2), {S({}), S({})},
                       {&valid, &negative}, {}, {});

  DimensionHandle dim;
  TF_EXPECT_OK(ctx.MakeDimForScalarInput(0, &dim));
  EXPECT_EQ("20", ctx.DebugString(dim));
  EXPECT_THAT(
      ctx.MakeDimForScalarInput(1, &dim),
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("Dimension size, given by scalar input 1, must be "
                         "non-negative but is -1")));

  // The context was given pointers to these tensors, so overwriting them in
  // place reruns the same checks with int64 scalars.
  valid = tensorflow::test::AsScalar<int64_t>(20);
  negative = tensorflow::test::AsScalar<int64_t>(-1);
  TF_EXPECT_OK(ctx.MakeDimForScalarInput(0, &dim));
  EXPECT_EQ("20", ctx.DebugString(dim));
  EXPECT_THAT(
      ctx.MakeDimForScalarInput(1, &dim),
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("Dimension size, given by scalar input 1, must be "
                         "non-negative but is -1")));
}
1321 
// Exercises InferenceContext::MakeDimForScalarInputWithNegativeIndexing, where
// the scalar input is read as an index relative to a given rank.
TEST_F(ShapeInferenceTest, MakeDimForScalarInputWithNegativeIndexing) {
  Tensor minus_two = tensorflow::test::AsScalar<int32>(-2);
  Tensor three = tensorflow::test::AsScalar<int32>(3);
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(2, 2), {S({}), S({})},
                       {&minus_two, &three}, {}, {});

  DimensionHandle dim;

  // A negative input rank combined with a negative value in the input tensor
  // yields an unknown dimension.
  TF_EXPECT_OK(ctx.MakeDimForScalarInputWithNegativeIndexing(0, -1, &dim));
  EXPECT_EQ("?", ctx.DebugString(dim));

  // -2 against rank 4 resolves to 2; against rank 1 it is out of range.
  TF_EXPECT_OK(ctx.MakeDimForScalarInputWithNegativeIndexing(0, 4, &dim));
  EXPECT_EQ("2", ctx.DebugString(dim));
  EXPECT_THAT(ctx.MakeDimForScalarInputWithNegativeIndexing(0, 1, &dim),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Dimension size, given by scalar input -2 "
                                 "must be in range [-1, 1)")));

  // 3 is valid against rank 4 but out of range against rank 2.
  TF_EXPECT_OK(ctx.MakeDimForScalarInputWithNegativeIndexing(1, 4, &dim));
  EXPECT_EQ("3", ctx.DebugString(dim));
  EXPECT_THAT(ctx.MakeDimForScalarInputWithNegativeIndexing(1, 2, &dim),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Dimension size, given by scalar input 3 "
                                 "must be in range [-2, 2)")));
}
1352 
// An attribute set on the NodeDef can be read back through
// InferenceContext::GetAttr.
TEST_F(ShapeInferenceTest, GetAttr) {
  OpRegistrationData op_reg_data;
  op_reg_data.op_def = MakeOpDef(0, 2);
  NodeDef def;
  // Use a test assertion instead of CHECK. A builder failure is then reported
  // as a failed test with its status, instead of aborting the whole binary.
  TF_ASSERT_OK(NodeDefBuilder("dummy", &op_reg_data.op_def)
                   .Attr("foo", "bar")
                   .Finalize(&def));

  std::vector<ShapeHandle> empty;
  InferenceContext c(kVersion, def, op_reg_data.op_def, empty, {}, {}, {});
  string value;
  TF_EXPECT_OK(c.GetAttr("foo", &value));
  EXPECT_EQ("bar", value);
}
1368 
TEST_F(ShapeInferenceTest,Divide)1369 TEST_F(ShapeInferenceTest, Divide) {
1370   NodeDef def;
1371   InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 1, 2, 0})}, {},
1372                      {}, {});
1373 
1374   auto s = c.input(0);
1375   auto d_6 = c.Dim(s, 0);
1376   auto d_unknown = c.Dim(s, 1);
1377   auto d_1 = c.Dim(s, 2);
1378   auto d_2 = c.Dim(s, 3);
1379   auto d_0 = c.Dim(s, 4);
1380   bool evenly_divisible = true;
1381 
1382   // Dividing unknown by non-1 gives new unknown.
1383   DimensionHandle out;
1384   TF_EXPECT_OK(c.Divide(d_unknown, 2, evenly_divisible, &out));
1385   EXPECT_EQ("?", c.DebugString(out));
1386   EXPECT_FALSE(SameHandle(out, d_unknown));
1387 
1388   // Dividing anything by 1 returns the input.
1389   TF_EXPECT_OK(c.Divide(d_unknown, 1, evenly_divisible, &out));
1390   EXPECT_TRUE(SameHandle(out, d_unknown));
1391   TF_EXPECT_OK(c.Divide(d_6, 1, evenly_divisible, &out));
1392   EXPECT_TRUE(SameHandle(out, d_6));
1393   TF_EXPECT_OK(c.Divide(d_unknown, d_1, evenly_divisible, &out));
1394   EXPECT_TRUE(SameHandle(out, d_unknown));
1395   TF_EXPECT_OK(c.Divide(d_6, d_1, evenly_divisible, &out));
1396   EXPECT_TRUE(SameHandle(out, d_6));
1397 
1398   TF_EXPECT_OK(c.Divide(d_6, 2, evenly_divisible, &out));
1399   EXPECT_EQ("3", c.DebugString(out));
1400   TF_EXPECT_OK(c.Divide(d_6, d_2, evenly_divisible, &out));
1401   EXPECT_EQ("3", c.DebugString(out));
1402 
1403   EXPECT_THAT(
1404       c.Divide(d_6, 5, evenly_divisible, &out),
1405       StatusIs(
1406           error::INVALID_ARGUMENT,
1407           HasSubstr("Dimension size must be evenly divisible by 5 but is 6")));
1408 
1409   EXPECT_THAT(c.Divide(d_6, 0, evenly_divisible, &out),
1410               StatusIs(error::INVALID_ARGUMENT,
1411                        HasSubstr("Divisor must be positive but is 0")));
1412   EXPECT_THAT(c.Divide(d_6, d_0, evenly_divisible, &out),
1413               StatusIs(error::INVALID_ARGUMENT,
1414                        HasSubstr("Divisor must be positive but is 0")));
1415 
1416   EXPECT_THAT(c.Divide(d_6, -1, evenly_divisible, &out),
1417               StatusIs(error::INVALID_ARGUMENT,
1418                        HasSubstr("Divisor must be positive but is -1")));
1419 
1420   // Repeat error cases above with evenly_divisible=false.
1421   evenly_divisible = false;
1422   TF_EXPECT_OK(c.Divide(d_6, 5, evenly_divisible, &out));
1423   EXPECT_EQ("1", c.DebugString(out));
1424 
1425   EXPECT_THAT(c.Divide(d_6, 0, evenly_divisible, &out),
1426               StatusIs(error::INVALID_ARGUMENT,
1427                        HasSubstr("Divisor must be positive but is 0")));
1428 
1429   EXPECT_THAT(c.Divide(d_6, -1, evenly_divisible, &out),
1430               StatusIs(error::INVALID_ARGUMENT,
1431                        HasSubstr("Divisor must be positive but is -1")));
1432 }
1433 
// Exercises InferenceContext::Add with constant and dimension addends,
// identity cases, and int64 overflow.
TEST_F(ShapeInferenceTest, Add) {
  NodeDef def;
  InferenceContext c(kVersion, def, MakeOpDef(1, 2), {S({6, -1, 0})}, {}, {},
                     {});

  auto s = c.input(0);
  auto d_6 = c.Dim(s, 0);
  auto d_unknown = c.Dim(s, 1);
  auto d_0 = c.Dim(s, 2);

  // Adding non-zero to unknown gives new unknown.
  DimensionHandle out;
  TF_EXPECT_OK(c.Add(d_unknown, 1, &out));
  EXPECT_EQ("?", c.DebugString(out));
  EXPECT_FALSE(SameHandle(out, d_unknown));

  // Adding 0 to anything gives input.
  TF_EXPECT_OK(c.Add(d_unknown, 0, &out));
  EXPECT_TRUE(SameHandle(out, d_unknown));
  TF_EXPECT_OK(c.Add(d_6, 0, &out));
  EXPECT_TRUE(SameHandle(out, d_6));

  // Adding dimension with value 0 to anything gives input.
  TF_EXPECT_OK(c.Add(d_unknown, c.MakeDim(0ll), &out));
  EXPECT_TRUE(SameHandle(out, d_unknown));
  TF_EXPECT_OK(c.Add(d_6, c.MakeDim(0ll), &out));
  EXPECT_TRUE(SameHandle(out, d_6));

  // Test addition.
  TF_EXPECT_OK(c.Add(d_6, 2, &out));
  EXPECT_EQ("8", c.DebugString(out));
  TF_EXPECT_OK(c.Add(d_6, std::numeric_limits<int64_t>::max() - 6, &out));
  EXPECT_EQ(std::numeric_limits<int64_t>::max(), c.Value(out));

  // Test addition using dimension as second value.
  TF_EXPECT_OK(c.Add(d_6, c.MakeDim(2), &out));
  EXPECT_EQ("8", c.DebugString(out));
  // TF_EXPECT_OK, like every other call here, so a failure prints the
  // returned status instead of just "false".
  TF_EXPECT_OK(
      c.Add(d_6, c.MakeDim(std::numeric_limits<int64_t>::max() - 6), &out));
  EXPECT_EQ(std::numeric_limits<int64_t>::max(), c.Value(out));
  TF_EXPECT_OK(c.Add(d_6, c.UnknownDim(), &out));
  EXPECT_EQ("?", c.DebugString(out));
  TF_EXPECT_OK(c.Add(d_0, d_6, &out));
  EXPECT_TRUE(SameHandle(out, d_6));

  // A sum past int64 max is rejected.
  EXPECT_THAT(c.Add(d_6, std::numeric_limits<int64_t>::max() - 5, &out),
              StatusIs(error::INVALID_ARGUMENT,
                       HasSubstr("Dimension size overflow from adding 6 and "
                                 "9223372036854775802")));
}
1485 
// Exercises InferenceContext::Subtract with constant and dimension
// subtrahends, identity cases, and negative results.
TEST_F(ShapeInferenceTest, Subtract) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 2), {S({6, -1, 0, 5})},
                       {}, {}, {});

  const ShapeHandle in = ctx.input(0);
  const DimensionHandle six = ctx.Dim(in, 0);
  const DimensionHandle unknown = ctx.Dim(in, 1);
  const DimensionHandle zero = ctx.Dim(in, 2);
  const DimensionHandle five = ctx.Dim(in, 3);
  DimensionHandle difference;

  // unknown - non-zero constant yields a new unknown handle.
  TF_EXPECT_OK(ctx.Subtract(unknown, 1, &difference));
  EXPECT_EQ("?", ctx.DebugString(difference));
  EXPECT_FALSE(SameHandle(difference, unknown));

  // Subtracting a constant 0 returns the minuend's own handle.
  TF_EXPECT_OK(ctx.Subtract(unknown, 0ll, &difference));
  EXPECT_TRUE(SameHandle(difference, unknown));
  TF_EXPECT_OK(ctx.Subtract(six, 0ll, &difference));
  EXPECT_TRUE(SameHandle(difference, six));

  // Likewise for a dimension whose value is 0.
  TF_EXPECT_OK(ctx.Subtract(unknown, ctx.MakeDim(0ll), &difference));
  EXPECT_TRUE(SameHandle(difference, unknown));
  TF_EXPECT_OK(ctx.Subtract(six, ctx.MakeDim(0ll), &difference));
  EXPECT_TRUE(SameHandle(difference, six));

  // Plain subtraction with constant subtrahends.
  TF_EXPECT_OK(ctx.Subtract(six, 2, &difference));
  EXPECT_EQ("4", ctx.DebugString(difference));
  TF_EXPECT_OK(ctx.Subtract(six, 6, &difference));
  EXPECT_EQ("0", ctx.DebugString(difference));

  // Plain subtraction with dimension subtrahends.
  TF_EXPECT_OK(ctx.Subtract(six, ctx.MakeDim(2), &difference));
  EXPECT_EQ("4", ctx.DebugString(difference));
  TF_EXPECT_OK(ctx.Subtract(six, five, &difference));
  EXPECT_EQ("1", ctx.DebugString(difference));
  TF_EXPECT_OK(ctx.Subtract(six, ctx.UnknownDim(), &difference));
  EXPECT_EQ("?", ctx.DebugString(difference));
  TF_EXPECT_OK(ctx.Subtract(six, zero, &difference));
  EXPECT_TRUE(SameHandle(difference, six));

  // A negative result is an error.
  EXPECT_THAT(
      ctx.Subtract(five, six, &difference),
      StatusIs(
          error::INVALID_ARGUMENT,
          HasSubstr("Negative dimension size caused by subtracting 6 from 5")));
}
1537 
// Exercises InferenceContext::Multiply: zero and one factors, unknown operands,
// plain products, and int64 overflow.
TEST_F(ShapeInferenceTest, Multiply) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 2), {S({6, -1, 0, 1})},
                       {}, {}, {});

  const ShapeHandle in = ctx.input(0);
  const DimensionHandle six = ctx.Dim(in, 0);
  const DimensionHandle unknown = ctx.Dim(in, 1);
  const DimensionHandle zero = ctx.Dim(in, 2);
  const DimensionHandle one = ctx.Dim(in, 3);
  DimensionHandle product;

  // unknown * non-zero constant is unknown.
  TF_EXPECT_OK(ctx.Multiply(unknown, 2, &product));
  EXPECT_EQ("?", ctx.DebugString(product));

  // A zero factor on either side gives 0, even against an unknown.
  TF_EXPECT_OK(ctx.Multiply(unknown, 0, &product));
  EXPECT_EQ("0", ctx.DebugString(product));
  TF_EXPECT_OK(ctx.Multiply(unknown, zero, &product));
  EXPECT_EQ("0", ctx.DebugString(product));
  TF_EXPECT_OK(ctx.Multiply(zero, unknown, &product));
  EXPECT_EQ("0", ctx.DebugString(product));

  // A factor of one returns the other operand's handle when it is unknown...
  TF_EXPECT_OK(ctx.Multiply(unknown, 1, &product));
  EXPECT_TRUE(SameHandle(unknown, product));
  TF_EXPECT_OK(ctx.Multiply(unknown, one, &product));
  EXPECT_TRUE(SameHandle(unknown, product));
  TF_EXPECT_OK(ctx.Multiply(one, unknown, &product));
  EXPECT_TRUE(SameHandle(unknown, product));
  // ...and when it is known.
  TF_EXPECT_OK(ctx.Multiply(six, 1, &product));
  EXPECT_TRUE(SameHandle(six, product));
  TF_EXPECT_OK(ctx.Multiply(six, one, &product));
  EXPECT_TRUE(SameHandle(six, product));
  TF_EXPECT_OK(ctx.Multiply(one, six, &product));
  EXPECT_TRUE(SameHandle(six, product));

  // Plain products with constant second operands.
  TF_EXPECT_OK(ctx.Multiply(six, 2, &product));
  EXPECT_EQ("12", ctx.DebugString(product));
  TF_EXPECT_OK(ctx.Multiply(six, 6, &product));
  EXPECT_EQ("36", ctx.DebugString(product));

  // Plain products with dimension second operands.
  TF_EXPECT_OK(ctx.Multiply(six, ctx.MakeDim(2), &product));
  EXPECT_EQ("12", ctx.DebugString(product));
  TF_EXPECT_OK(ctx.Multiply(six, ctx.UnknownDim(), &product));
  EXPECT_EQ("?", ctx.DebugString(product));

  // int64 overflow is reported as an error.
  EXPECT_THAT(
      ctx.Multiply(six, std::numeric_limits<int64_t>::max() / 2, &product),
      StatusIs(error::INVALID_ARGUMENT,
               HasSubstr("Negative dimension size caused by overflow")));
}
1595 
// InferenceContext::FullyDefined is true only when the rank and every
// dimension are known.
TEST_F(ShapeInferenceTest, FullyDefined) {
  NodeDef node_def;
  std::vector<ShapeHandle> no_inputs;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(0, 2), no_inputs, {}, {},
                       {});

  // Unknown rank, or any unknown dimension, is not fully defined.
  const ShapeHandle unknown_rank = ctx.UnknownShape();
  EXPECT_FALSE(ctx.FullyDefined(unknown_rank));
  const ShapeHandle partly_known =
      ctx.Matrix(ctx.MakeDim(1), ctx.UnknownDim());
  EXPECT_FALSE(ctx.FullyDefined(partly_known));

  // All dimensions known (including the rank-0 case) is fully defined.
  const ShapeHandle all_known = ctx.Matrix(ctx.MakeDim(1), ctx.MakeDim(2));
  EXPECT_TRUE(ctx.FullyDefined(all_known));
  const ShapeHandle scalar = ctx.Scalar();
  EXPECT_TRUE(ctx.FullyDefined(scalar));
}
1609 
TEST_F(ShapeInferenceTest, Min) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 2),
                       {S({1, 2, -1, 0})}, {}, {}, {});

  // Dimensions of the single input, in order: 1, 2, unknown, 0.
  const ShapeHandle input_shape = ctx.input(0);
  const DimensionHandle dim_1 = ctx.Dim(input_shape, 0);
  const DimensionHandle dim_2 = ctx.Dim(input_shape, 1);
  const DimensionHandle dim_unknown = ctx.Dim(input_shape, 2);
  const DimensionHandle dim_0 = ctx.Dim(input_shape, 3);

  DimensionHandle result;

  // Zero beats unknown: an existing zero handle comes back unchanged, and a
  // freshly made or constant zero yields a dimension printing as "0".
  TF_EXPECT_OK(ctx.Min(dim_0, dim_unknown, &result));
  EXPECT_TRUE(SameHandle(dim_0, result));
  TF_EXPECT_OK(ctx.Min(dim_unknown, dim_0, &result));
  EXPECT_TRUE(SameHandle(dim_0, result));
  TF_EXPECT_OK(ctx.Min(ctx.MakeDim(0ll), dim_unknown, &result));
  EXPECT_EQ("0", ctx.DebugString(result));
  TF_EXPECT_OK(ctx.Min(dim_unknown, 0ll, &result));
  EXPECT_EQ("0", ctx.DebugString(result));

  // An unknown operand paired with anything non-zero gives an unknown result.
  TF_EXPECT_OK(ctx.Min(dim_unknown, dim_unknown, &result));
  EXPECT_EQ("?", ctx.DebugString(result));
  TF_EXPECT_OK(ctx.Min(dim_unknown, 1, &result));
  EXPECT_EQ("?", ctx.DebugString(result));
  TF_EXPECT_OK(ctx.Min(dim_1, dim_unknown, &result));
  EXPECT_EQ("?", ctx.DebugString(result));

  // Dimension vs. constant: the dimension handle itself is returned when it is
  // the smaller (or equal) value; otherwise the result takes the constant.
  TF_EXPECT_OK(ctx.Min(dim_1, 1, &result));
  EXPECT_TRUE(SameHandle(dim_1, result));
  TF_EXPECT_OK(ctx.Min(dim_1, 3, &result));
  EXPECT_TRUE(SameHandle(dim_1, result));
  TF_EXPECT_OK(ctx.Min(dim_2, 1, &result));
  EXPECT_EQ("1", ctx.DebugString(result));

  // Dimension vs. dimension: the handle of the smaller value is returned,
  // regardless of argument order.
  TF_EXPECT_OK(ctx.Min(dim_1, dim_1, &result));
  EXPECT_TRUE(SameHandle(dim_1, result));
  TF_EXPECT_OK(ctx.Min(dim_1, dim_2, &result));
  EXPECT_TRUE(SameHandle(dim_1, result));
  TF_EXPECT_OK(ctx.Min(dim_2, dim_1, &result));
  EXPECT_TRUE(SameHandle(dim_1, result));
  TF_EXPECT_OK(ctx.Min(dim_2, dim_2, &result));
  EXPECT_TRUE(SameHandle(dim_2, result));
}
1658 
TEST_F(ShapeInferenceTest, Max) {
  NodeDef node_def;
  InferenceContext ctx(kVersion, node_def, MakeOpDef(1, 2), {S({1, 2, -1})},
                       {}, {}, {});

  // Dimensions of the single input, in order: 1, 2, unknown.
  const ShapeHandle input_shape = ctx.input(0);
  const DimensionHandle dim_1 = ctx.Dim(input_shape, 0);
  const DimensionHandle dim_2 = ctx.Dim(input_shape, 1);
  const DimensionHandle dim_unknown = ctx.Dim(input_shape, 2);

  DimensionHandle result;

  // Any unknown operand makes the maximum unknown.
  TF_EXPECT_OK(ctx.Max(dim_unknown, dim_unknown, &result));
  EXPECT_EQ("?", ctx.DebugString(result));
  TF_EXPECT_OK(ctx.Max(dim_unknown, 1, &result));
  EXPECT_EQ("?", ctx.DebugString(result));
  TF_EXPECT_OK(ctx.Max(dim_1, dim_unknown, &result));
  EXPECT_EQ("?", ctx.DebugString(result));

  // Dimension vs. constant: the dimension handle itself is returned when it is
  // the larger (or equal) value; otherwise the result takes the constant.
  TF_EXPECT_OK(ctx.Max(dim_1, 1, &result));
  EXPECT_TRUE(SameHandle(dim_1, result));
  TF_EXPECT_OK(ctx.Max(dim_2, 1, &result));
  EXPECT_TRUE(SameHandle(dim_2, result));
  TF_EXPECT_OK(ctx.Max(dim_2, 3, &result));
  EXPECT_EQ("3", ctx.DebugString(result));

  // Dimension vs. dimension: the handle of the larger value is returned,
  // regardless of argument order.
  TF_EXPECT_OK(ctx.Max(dim_1, dim_1, &result));
  EXPECT_TRUE(SameHandle(dim_1, result));
  TF_EXPECT_OK(ctx.Max(dim_1, dim_2, &result));
  EXPECT_TRUE(SameHandle(dim_2, result));
  TF_EXPECT_OK(ctx.Max(dim_2, dim_1, &result));
  EXPECT_TRUE(SameHandle(dim_2, result));
  TF_EXPECT_OK(ctx.Max(dim_2, dim_2, &result));
  EXPECT_TRUE(SameHandle(dim_2, result));
}
1696 
1697 class ShapeInferenceHandlesTest : public ShapeInferenceTest,
1698                                   public ::testing::WithParamInterface<bool> {};
1699 
1700 INSTANTIATE_TEST_SUITE_P(All, ShapeInferenceHandlesTest,
1701                          ::testing::Bool() /* input_not_output */,
1702                          ::testing::PrintToStringParamName());
1703 
TEST_P(ShapeInferenceHandlesTest,MergeHandles)1704 TEST_P(ShapeInferenceHandlesTest, MergeHandles) {
1705   bool input_not_output = GetParam();
1706   NodeDef def;
1707   InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
1708                      {});
1709   auto make_shape = [&c](std::initializer_list<int64_t> dim_sizes) {
1710     ShapeHandle s;
1711     TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
1712     return s;
1713   };
1714   auto get_shapes_and_types_from_context = [&](int idx) {
1715     if (input_not_output) {
1716       return c.input_handle_shapes_and_types(idx);
1717     } else {
1718       return c.output_handle_shapes_and_types(idx);
1719     }
1720   };
1721   auto merge_shapes_and_types_to_context =
1722       [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1723         if (input_not_output) {
1724           return c.MergeInputHandleShapesAndTypes(idx, shapes_and_types);
1725         } else {
1726           return c.MergeOutputHandleShapesAndTypes(idx, shapes_and_types);
1727         }
1728       };
1729 
1730   EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr);
1731   EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr);
1732 
1733   // First merge will take the input completely.
1734   std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT},
1735                               {c.UnknownShape(), DT_INVALID},
1736                               {make_shape({4, 3, 2, 1}), DT_INT32}};
1737   ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
1738   ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr);
1739   std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0);
1740   ASSERT_EQ(3, v.size());
1741   for (int i = 0; i < v.size(); ++i) {
1742     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1743     EXPECT_EQ(t[i].dtype, v[i].dtype);
1744   }
1745 
1746   // Merge that fails because wrong number of values passed.
1747   // Fails, and no changes made.
1748   ASSERT_FALSE(merge_shapes_and_types_to_context(
1749       0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}}));
1750   v = *get_shapes_and_types_from_context(0);
1751   ASSERT_EQ(3, v.size());
1752   for (int i = 0; i < v.size(); ++i) {
1753     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1754     EXPECT_EQ(t[i].dtype, v[i].dtype);
1755   }
1756 
1757   // Only difference is in a mismatched shape. That is ignored,
1758   // and there are no other changes, so nothing is done.
1759   //
1760   // TODO(cwhipkey): in mismatch cases, change Merge*HandleShapesAndTypes to
1761   // return an error (separate error from 'refined' output)?
1762   auto t2 = t;
1763   t2[2].shape = make_shape({4, 3, 4, 1});
1764   ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
1765   v = *get_shapes_and_types_from_context(0);
1766   ASSERT_EQ(3, v.size());
1767   for (int i = 0; i < v.size(); ++i) {
1768     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1769     EXPECT_EQ(t[i].dtype, v[i].dtype);
1770   }
1771 
1772   // Only difference is in a mismatched dtype, but that cannot be
1773   // updated unless original dtype is DT_INVALID.
1774   t2 = t;
1775   t2[2].dtype = DT_FLOAT;
1776   ASSERT_FALSE(merge_shapes_and_types_to_context(0, t2));
1777   v = *get_shapes_and_types_from_context(0);
1778   ASSERT_EQ(3, v.size());
1779   for (int i = 0; i < v.size(); ++i) {
1780     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1781     EXPECT_EQ(t[i].dtype, v[i].dtype);
1782   }
1783 
1784   // Difference is mergeable (new shape).
1785   t[1].shape = make_shape({1, 10});
1786   ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
1787   v = *get_shapes_and_types_from_context(0);
1788   ASSERT_EQ(3, v.size());
1789   for (int i = 0; i < v.size(); ++i) {
1790     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1791     EXPECT_EQ(t[i].dtype, v[i].dtype);
1792   }
1793 
1794   // Difference is mergeable (new type).
1795   t[1].dtype = DT_DOUBLE;
1796   ASSERT_TRUE(merge_shapes_and_types_to_context(0, t));
1797   v = *get_shapes_and_types_from_context(0);
1798   ASSERT_EQ(3, v.size());
1799   for (int i = 0; i < v.size(); ++i) {
1800     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1801     EXPECT_EQ(t[i].dtype, v[i].dtype);
1802   }
1803 
1804   // No difference.
1805   ASSERT_FALSE(merge_shapes_and_types_to_context(0, t));
1806 }
1807 
TEST_P(ShapeInferenceHandlesTest,RelaxHandles)1808 TEST_P(ShapeInferenceHandlesTest, RelaxHandles) {
1809   bool input_not_output = GetParam();
1810   NodeDef def;
1811   InferenceContext c(kVersion, def, MakeOpDef(2, 2), {S({}), S({})}, {}, {},
1812                      {});
1813   auto make_shape = [&c](std::initializer_list<int64_t> dim_sizes) {
1814     ShapeHandle s;
1815     TF_CHECK_OK(c.MakeShapeFromPartialTensorShape(S(dim_sizes), &s));
1816     return s;
1817   };
1818   auto get_shapes_and_types_from_context = [&](int idx) {
1819     if (input_not_output) {
1820       return c.input_handle_shapes_and_types(idx);
1821     } else {
1822       return c.output_handle_shapes_and_types(idx);
1823     }
1824   };
1825   auto relax_shapes_and_types_to_context =
1826       [&](int idx, const std::vector<ShapeAndType>& shapes_and_types) {
1827         if (input_not_output) {
1828           return c.RelaxInputHandleShapesAndMergeTypes(idx, shapes_and_types);
1829         } else {
1830           return c.RelaxOutputHandleShapesAndMergeTypes(idx, shapes_and_types);
1831         }
1832       };
1833 
1834   EXPECT_TRUE(get_shapes_and_types_from_context(0) == nullptr);
1835   EXPECT_TRUE(get_shapes_and_types_from_context(1) == nullptr);
1836 
1837   // First relax will take the input completely.
1838   std::vector<ShapeAndType> t{{make_shape({1, 2, 3}), DT_FLOAT},
1839                               {c.UnknownShape(), DT_INVALID},
1840                               {make_shape({4, 3, 2, 1}), DT_INT32}};
1841   ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
1842   ASSERT_TRUE(get_shapes_and_types_from_context(0) != nullptr);
1843   std::vector<ShapeAndType> v = *get_shapes_and_types_from_context(0);
1844   ASSERT_EQ(3, v.size());
1845   for (int i = 0; i < v.size(); ++i) {
1846     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1847     EXPECT_EQ(t[i].dtype, v[i].dtype);
1848   }
1849 
1850   // Relax that fails because wrong number of values passed.
1851   // Fails, and no changes made.
1852   ASSERT_FALSE(relax_shapes_and_types_to_context(
1853       0, std::vector<ShapeAndType>{{make_shape({1, 2, 3}), DT_FLOAT}}));
1854   v = *get_shapes_and_types_from_context(0);
1855   ASSERT_EQ(3, v.size());
1856   for (int i = 0; i < v.size(); ++i) {
1857     EXPECT_TRUE(SameHandle(t[i].shape, v[i].shape)) << i;
1858     EXPECT_EQ(t[i].dtype, v[i].dtype);
1859   }
1860 
1861   // Only difference is in a mismatched shape. This should replace
1862   // the mismatched dimension with an UnknownDim.
1863   auto t2 = t;
1864   t2[2].shape = make_shape({4, 3, 4, 1});
1865   ASSERT_TRUE(relax_shapes_and_types_to_context(0, t2));
1866   v = *get_shapes_and_types_from_context(0);
1867   EXPECT_EQ("[4,3,?,1]", c.DebugString(v[2].shape));
1868   for (int i = 0; i < v.size(); ++i) {
1869     EXPECT_EQ(t[i].dtype, v[i].dtype);
1870   }
1871 
1872   // Only difference is in a mismatched dtype, but that cannot be
1873   // updated unless original dtype is DT_INVALID.
1874   t2 = t;
1875   t2[2].dtype = DT_FLOAT;
1876   ASSERT_FALSE(relax_shapes_and_types_to_context(0, t2));
1877   v = *get_shapes_and_types_from_context(0);
1878   ASSERT_EQ(3, v.size());
1879   for (int i = 0; i < v.size(); ++i) {
1880     EXPECT_EQ(t[i].dtype, v[i].dtype);
1881   }
1882 
1883   // Difference is a new shape, which will result in a new UnknownShape.
1884   t[1].shape = make_shape({1, 10});
1885   ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
1886   v = *get_shapes_and_types_from_context(0);
1887   ASSERT_EQ(3, v.size());
1888   EXPECT_FALSE(SameHandle(t[1].shape, v[1].shape));
1889   EXPECT_EQ("?", c.DebugString(v[1].shape));
1890   for (int i = 0; i < v.size(); ++i) {
1891     EXPECT_EQ(t[i].dtype, v[i].dtype);
1892   }
1893 
1894   // Difference is relaxable (new type).
1895   t[1].dtype = DT_DOUBLE;
1896   ASSERT_TRUE(relax_shapes_and_types_to_context(0, t));
1897   v = *get_shapes_and_types_from_context(0);
1898   EXPECT_EQ(t[1].dtype, v[1].dtype);
1899 }
1900 
1901 }  // namespace
1902 }  // namespace shape_inference
1903 }  // namespace tensorflow
1904