xref: /aosp_15_r20/external/tensorflow/tensorflow/core/util/tensor_slice_set_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/util/tensor_slice_set.h"
17 
18 #include <vector>
19 #include "tensorflow/core/lib/core/status.h"
20 #include "tensorflow/core/platform/logging.h"
21 #include "tensorflow/core/platform/test.h"
22 #include "tensorflow/core/platform/test_benchmark.h"
23 
24 namespace tensorflow {
25 
26 namespace checkpoint {
27 
28 namespace {
29 
30 // A simple test: we have a 2-d tensor of shape 4 X 5 that looks like this:
31 //
32 //   0   1   2   3   4
33 //   5   6   7   8   9
34 //  10  11  12  13  14
35 //  15  16  17  18  19
36 //
37 // We assume this is a row-major matrix.
38 //
39 // Testing the meta version of the tensor slice set.
TEST(TensorSliceSetTest,QueryMetaTwoD)40 TEST(TensorSliceSetTest, QueryMetaTwoD) {
41   TensorShape shape({4, 5});
42 
43   TensorSliceSet tss(shape, DT_INT32);
44   // We store a few slices.
45 
46   // Slice #1 is the top two rows:
47   //   0   1   2   3   4
48   //   5   6   7   8   9
49   //   .   .   .   .   .
50   //   .   .   .   .   .
51   TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
52   TF_CHECK_OK(tss.Register(slice_1, "slice_1"));
53 
54   // Slice #2 is the bottom left corner
55   //   .   .   .   .   .
56   //   .   .   .   .   .
57   //  10  11  12   .   .
58   //  15  16  17   .   .
59   TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
60   TF_CHECK_OK(tss.Register(slice_2, "slice_2"));
61 
62   // Slice #3 is the bottom right corner
63   //   .   .   .   .   .
64   //   .   .   .   .   .
65   //   .   .   .   .   .
66   //   .   .   .  18  19
67   TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
68   TF_CHECK_OK(tss.Register(slice_3, "slice_3"));
69 
70   // Notice that we leave a hole in the tensor
71   //   .   .   .   .   .
72   //   .   .   .   .   .
73   //   .   .   . (13) (14)
74   //   .   .   .   .   .
75 
76   // Now we query some of the slices
77 
78   // Slice #1 is an exact match
79   //   0   1   2   3   4
80   //   5   6   7   8   9
81   //   .   .   .   .   .
82   //   .   .   .   .   .
83   // We just need slice_1 for this
84   {
85     TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
86     std::vector<std::pair<TensorSlice, string>> results;
87     EXPECT_TRUE(tss.QueryMeta(s, &results));
88     EXPECT_EQ(1, results.size());
89     EXPECT_EQ("0,2:-", results[0].first.DebugString());
90     EXPECT_EQ("slice_1", results[0].second);
91   }
92 
93   // Slice #2 is a subset match
94   //   .   .   .   .   .
95   //   5   6   7   8   9
96   //   .   .   .   .   .
97   //   .   .   .   .   .
98   // We just need slice_1 for this
99   {
100     TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
101     std::vector<std::pair<TensorSlice, string>> results;
102     EXPECT_TRUE(tss.QueryMeta(s, &results));
103     EXPECT_EQ(1, results.size());
104     EXPECT_EQ("0,2:-", results[0].first.DebugString());
105     EXPECT_EQ("slice_1", results[0].second);
106   }
107 
108   // Slice #3 is a more complicated match: it needs the combination of a couple
109   // of slices
110   //   .   .   .   .   .
111   //   5   6   7   .   .
112   //  10  11  12   .   .
113   //   .   .   .   .   .
114   // We need both slice_1 and slice_2 for this.
115   {
116     TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3");
117     std::vector<std::pair<TensorSlice, string>> results;
118     EXPECT_TRUE(tss.QueryMeta(s, &results));
119     EXPECT_EQ(2, results.size());
120     // Allow results to be returned in either order
121     if (results[0].second == "slice_2") {
122       EXPECT_EQ("2,2:0,3", results[0].first.DebugString());
123       EXPECT_EQ("slice_2", results[0].second);
124       EXPECT_EQ("0,2:-", results[1].first.DebugString());
125       EXPECT_EQ("slice_1", results[1].second);
126     } else {
127       EXPECT_EQ("0,2:-", results[0].first.DebugString());
128       EXPECT_EQ("slice_1", results[0].second);
129       EXPECT_EQ("2,2:0,3", results[1].first.DebugString());
130       EXPECT_EQ("slice_2", results[1].second);
131     }
132   }
133 
134   // Slice #4 includes the hole and so there is no match
135   //   .   .   .   .   .
136   //   .   .   7   8   9
137   //   .   .  12  13  14
138   //   .   .   .   .   .
139   {
140     TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
141     std::vector<std::pair<TensorSlice, string>> results;
142     EXPECT_FALSE(tss.QueryMeta(s, &results));
143     EXPECT_EQ(0, results.size());
144   }
145 }
146 
BM_RegisterOneByOne(::testing::benchmark::State & state)147 static void BM_RegisterOneByOne(::testing::benchmark::State& state) {
148   TensorShape shape({static_cast<int>(state.max_iterations), 41});
149   TensorSliceSet slice_set(shape, DT_INT32);
150   int i = 0;
151   for (auto s : state) {
152     TensorSlice part({{i, 1}, {0, -1}});
153     TF_CHECK_OK(slice_set.Register(part, part.DebugString()));
154     ++i;
155   }
156 }
157 
158 BENCHMARK(BM_RegisterOneByOne);
159 
160 }  // namespace
161 
162 }  // namespace checkpoint
163 
164 }  // namespace tensorflow
165