1 /* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/util/tensor_slice_set.h"
17
18 #include <vector>
19 #include "tensorflow/core/lib/core/status.h"
20 #include "tensorflow/core/platform/logging.h"
21 #include "tensorflow/core/platform/test.h"
22 #include "tensorflow/core/platform/test_benchmark.h"
23
24 namespace tensorflow {
25
26 namespace checkpoint {
27
28 namespace {
29
30 // A simple test: we have a 2-d tensor of shape 4 X 5 that looks like this:
31 //
32 // 0 1 2 3 4
33 // 5 6 7 8 9
34 // 10 11 12 13 14
35 // 15 16 17 18 19
36 //
37 // We assume this is a row-major matrix.
38 //
39 // Testing the meta version of the tensor slice set.
// Registers three slices of the 4 x 5 tensor described above, deliberately
// leaving elements 13 and 14 uncovered. It then queries QueryMeta with:
//   1. an exact match of one registered slice,
//   2. a strict subset of one registered slice,
//   3. a region that needs two registered slices, and
//   4. a region that overlaps the uncovered hole (must fail).
//
// Result counts are checked with ASSERT_EQ before results[i] is read, so an
// unexpected count aborts the test cleanly. With EXPECT_EQ, the test would go
// on to index past the end of the vector.
TEST(TensorSliceSetTest, QueryMetaTwoD) {
  TensorShape shape({4, 5});

  TensorSliceSet tss(shape, DT_INT32);
  // We register a few slices.

  // Slice #1 is the top two rows:
  // 0 1 2 3 4
  // 5 6 7 8 9
  // . . . . .
  // . . . . .
  TensorSlice slice_1 = TensorSlice::ParseOrDie("0,2:-");
  TF_CHECK_OK(tss.Register(slice_1, "slice_1"));

  // Slice #2 is the bottom left corner:
  // . . . . .
  // . . . . .
  // 10 11 12 . .
  // 15 16 17 . .
  TensorSlice slice_2 = TensorSlice::ParseOrDie("2,2:0,3");
  TF_CHECK_OK(tss.Register(slice_2, "slice_2"));

  // Slice #3 is the bottom right corner:
  // . . . . .
  // . . . . .
  // . . . . .
  // . . . 18 19
  TensorSlice slice_3 = TensorSlice::ParseOrDie("3,1:3,2");
  TF_CHECK_OK(tss.Register(slice_3, "slice_3"));

  // Notice that we leave a hole in the tensor:
  // . . . . .
  // . . . . .
  // . . . (13) (14)
  // . . . . .

  // Now we run some queries against the registered slices.

  // Query #1 is an exact match of slice_1:
  // 0 1 2 3 4
  // 5 6 7 8 9
  // . . . . .
  // . . . . .
  // Only slice_1 is needed.
  {
    TensorSlice s = TensorSlice::ParseOrDie("0,2:-");
    std::vector<std::pair<TensorSlice, string>> results;
    EXPECT_TRUE(tss.QueryMeta(s, &results));
    ASSERT_EQ(1, results.size());
    EXPECT_EQ("0,2:-", results[0].first.DebugString());
    EXPECT_EQ("slice_1", results[0].second);
  }

  // Query #2 is a strict subset of slice_1 (row 1 only):
  // . . . . .
  // 5 6 7 8 9
  // . . . . .
  // . . . . .
  // Only slice_1 is needed.
  {
    TensorSlice s = TensorSlice::ParseOrDie("1,1:-");
    std::vector<std::pair<TensorSlice, string>> results;
    EXPECT_TRUE(tss.QueryMeta(s, &results));
    ASSERT_EQ(1, results.size());
    EXPECT_EQ("0,2:-", results[0].first.DebugString());
    EXPECT_EQ("slice_1", results[0].second);
  }

  // Query #3 spans two registered slices:
  // . . . . .
  // 5 6 7 . .
  // 10 11 12 . .
  // . . . . .
  // Both slice_1 and slice_2 are needed.
  {
    TensorSlice s = TensorSlice::ParseOrDie("1,2:0,3");
    std::vector<std::pair<TensorSlice, string>> results;
    EXPECT_TRUE(tss.QueryMeta(s, &results));
    ASSERT_EQ(2, results.size());
    // The two slices may be returned in either order.
    const int slice_1_index = (results[0].second == "slice_1") ? 0 : 1;
    const int slice_2_index = 1 - slice_1_index;
    EXPECT_EQ("0,2:-", results[slice_1_index].first.DebugString());
    EXPECT_EQ("slice_1", results[slice_1_index].second);
    EXPECT_EQ("2,2:0,3", results[slice_2_index].first.DebugString());
    EXPECT_EQ("slice_2", results[slice_2_index].second);
  }

  // Query #4 overlaps the hole, so the query cannot be satisfied:
  // . . . . .
  // . . 7 8 9
  // . . 12 13 14
  // . . . . .
  {
    TensorSlice s = TensorSlice::ParseOrDie("1,2:2,3");
    std::vector<std::pair<TensorSlice, string>> results;
    EXPECT_FALSE(tss.QueryMeta(s, &results));
    EXPECT_EQ(0, results.size());
  }
}
146
BM_RegisterOneByOne(::testing::benchmark::State & state)147 static void BM_RegisterOneByOne(::testing::benchmark::State& state) {
148 TensorShape shape({static_cast<int>(state.max_iterations), 41});
149 TensorSliceSet slice_set(shape, DT_INT32);
150 int i = 0;
151 for (auto s : state) {
152 TensorSlice part({{i, 1}, {0, -1}});
153 TF_CHECK_OK(slice_set.Register(part, part.DebugString()));
154 ++i;
155 }
156 }
157
158 BENCHMARK(BM_RegisterOneByOne);
159
160 } // namespace
161
162 } // namespace checkpoint
163
164 } // namespace tensorflow
165