xref: /aosp_15_r20/external/tensorflow/tensorflow/core/framework/tensor_slice.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 
16 #include "tensorflow/core/framework/tensor_slice.h"
17 
18 #include <limits>
19 #include <vector>
20 
21 #include "tensorflow/core/lib/core/errors.h"
22 #include "tensorflow/core/lib/strings/numbers.h"
23 #include "tensorflow/core/lib/strings/str_util.h"
24 #include "tensorflow/core/lib/strings/strcat.h"
25 #include "tensorflow/core/platform/logging.h"
26 
27 namespace tensorflow {
28 
TensorSlice(int dim)29 TensorSlice::TensorSlice(int dim) { SetFullSlice(dim); }
30 
TensorSlice(const TensorSliceProto & proto)31 TensorSlice::TensorSlice(const TensorSliceProto& proto) {
32   starts_.reserve(proto.extent_size());
33   lengths_.reserve(proto.extent_size());
34   for (const auto& e : proto.extent()) {
35     starts_.push_back(e.start());
36     lengths_.push_back(GetExtentLength(e));
37   }
38 }
39 
TensorSlice(std::initializer_list<std::pair<int64_t,int64_t>> extents)40 TensorSlice::TensorSlice(
41     std::initializer_list<std::pair<int64_t, int64_t>> extents) {
42   starts_.reserve(extents.size());
43   lengths_.reserve(extents.size());
44   for (const auto& e : extents) {
45     starts_.push_back(e.first);
46     lengths_.push_back(e.second);
47   }
48 }
49 
BuildTensorSlice(const TensorSliceProto & proto,TensorSlice * output)50 Status TensorSlice::BuildTensorSlice(const TensorSliceProto& proto,
51                                      TensorSlice* output) {
52   output->Clear();
53   output->starts_.reserve(proto.extent_size());
54   output->lengths_.reserve(proto.extent_size());
55   for (const auto& e : proto.extent()) {
56     int64_t l = GetExtentLength(e);
57     if (e.start() != 0 || l != kFullExtent) {
58       if (e.start() < 0 || l <= 0) {
59         return errors::InvalidArgument(
60             "Expected non-negative start and positive length but got start = ",
61             e.start(), ", length = ", l, ": extent = ", e.ShortDebugString());
62       }
63       // Calculating the extent end must not cause signed integer overflow.
64       if (static_cast<uint64_t>(e.start()) + static_cast<uint64_t>(e.length()) >
65           std::numeric_limits<int64_t>::max()) {
66         return errors::InvalidArgument(
67             "Extent end exceeds the maximum possible size: extent = ",
68             e.ShortDebugString());
69       }
70     }
71     output->starts_.push_back(e.start());
72     output->lengths_.push_back(l);
73   }
74 
75   return OkStatus();
76 }
77 
Parse(const string & str,TensorSlice * slice)78 Status TensorSlice::Parse(const string& str, TensorSlice* slice) {
79   std::vector<string> items = str_util::Split(str, ':', str_util::SkipEmpty());
80   slice->starts_.reserve(items.size());
81   slice->lengths_.reserve(items.size());
82   for (const string& x : items) {
83     int64_t s, l;
84     if (x == "-") {
85       // "everything"
86       s = 0;
87       l = kFullExtent;
88     } else {
89       std::vector<string> sl = str_util::Split(x, ',', str_util::SkipEmpty());
90       if (sl.size() != 2 || !strings::safe_strto64(sl[0], &s) ||
91           !strings::safe_strto64(sl[1], &l)) {
92         return errors::InvalidArgument(
93             "Expected a pair of numbers or '-' "
94             "but got '",
95             x, "': string = ", str);
96       }
97       if (s < 0 || l <= 0) {
98         return errors::InvalidArgument(
99             "Expected non-negative start and "
100             "positive length but got start = ",
101             s, ", length = ", l, ": string = ", str);
102       }
103     }
104     slice->starts_.push_back(s);
105     slice->lengths_.push_back(l);
106   }
107 
108   return OkStatus();
109 }
110 
Clear()111 void TensorSlice::Clear() {
112   starts_.clear();
113   lengths_.clear();
114 }
115 
IsFull() const116 bool TensorSlice::IsFull() const {
117   for (int d = 0; d < dims(); ++d) {
118     if (!IsFullAt(d)) return false;
119   }
120   return true;
121 }
122 
SetFullSlice(int dim)123 void TensorSlice::SetFullSlice(int dim) {
124   Clear();
125   starts_.reserve(dim);
126   lengths_.reserve(dim);
127   for (int d = 0; d < dim; ++d) {
128     starts_.push_back(0);
129     lengths_.push_back(kFullExtent);
130   }
131 }
132 
Extend(int dim)133 void TensorSlice::Extend(int dim) {
134   int old_dim = dims();
135   DCHECK_LE(old_dim, dim);
136   starts_.resize(dim);
137   lengths_.resize(dim);
138   for (int d = old_dim; d < dim; ++d) {
139     starts_[d] = 0;
140     lengths_[d] = kFullExtent;
141   }
142 }
143 
AsProto(TensorSliceProto * proto) const144 void TensorSlice::AsProto(TensorSliceProto* proto) const {
145   for (int d = 0; d < dims(); ++d) {
146     TensorSliceProto::Extent* e = proto->add_extent();
147     // We only need to record the explicit slice for non-full slices
148     if (!IsFullAt(d)) {
149       e->set_start(starts_[d]);
150       e->set_length(lengths_[d]);
151     }
152   }
153 }
154 
DebugString() const155 string TensorSlice::DebugString() const {
156   string buffer;
157   bool first = true;
158   for (int d = 0; d < dims(); ++d) {
159     if (!first) {
160       buffer.append(":");
161     }
162     if (IsFullAt(d)) {
163       buffer.append("-");
164     } else {
165       strings::StrAppend(&buffer, starts_[d], ",", lengths_[d]);
166     }
167     first = false;
168   }
169   return buffer;
170 }
171 
Intersect(const TensorSlice & other,TensorSlice * result) const172 bool TensorSlice::Intersect(const TensorSlice& other,
173                             TensorSlice* result) const {
174   // First, if two slices have different ranks, they obviously don't overlap
175   // -- in fact they are not compatible.
176   if (dims() != other.dims()) {
177     return false;
178   }
179 
180   // Setting the result to the right dimension
181   if (result) {
182     result->SetFullSlice(dims());
183   }
184   // The two slices overlap if they overlap in all dimensions.
185   for (int d = 0; d < dims(); ++d) {
186     if (IsFullAt(d)) {
187       if (result) {
188         result->set_start(d, other.start(d));
189         result->set_length(d, other.length(d));
190       }
191     } else if (other.IsFullAt(d)) {
192       if (result) {
193         result->set_start(d, start(d));
194         result->set_length(d, length(d));
195       }
196     } else {
197       // If we have an intersection here, it should have a start that is the
198       // max of the two starts and an end that is the min of the two ends.
199       int64_t s = std::max(start(d), other.start(d));
200       int64_t l = std::min(end(d), other.end(d)) - s;
201       if (l > 0) {
202         // We have a real intersection
203         if (result) {
204           result->set_start(d, s);
205           result->set_length(d, l);
206         }
207       } else {
208         // We don't have an intersection for this dimension -- thus we don't
209         // have any intersection at all.
210         if (result) {
211           result->Clear();
212         }
213         return false;
214       }
215     }
216   }
217   // If we are here, we know there is overlap in every dimension.
218   return true;
219 }
220 
operator ==(const TensorSlice & other) const221 bool TensorSlice::operator==(const TensorSlice& other) const {
222   return dims() == other.dims() && starts_ == other.starts_ &&
223          lengths_ == other.lengths_;
224 }
225 
ComputeRelative(const TensorSlice & sub,TensorSlice * relative) const226 void TensorSlice::ComputeRelative(const TensorSlice& sub,
227                                   TensorSlice* relative) const {
228   DCHECK_EQ(dims(), sub.dims());
229   relative->SetFullSlice(dims());
230   for (int d = 0; d < dims(); ++d) {
231     if (IsFullAt(d)) {
232       relative->set_start(d, sub.start(d));
233       relative->set_length(d, sub.length(d));
234     } else {
235       // Otherwise the relative start is the difference between the start of
236       // sub and the start of base
237       relative->set_start(d, sub.start(d) - start(d));
238       relative->set_length(d, sub.length(d));
239     }
240   }
241 }
242 
UpdateToCover(const TensorSlice & other)243 void TensorSlice::UpdateToCover(const TensorSlice& other) {
244   DCHECK_EQ(dims(), other.dims());
245   for (int d = 0; d < dims(); ++d) {
246     if (!IsFullAt(d)) {
247       if (other.IsFullAt(d)) {
248         starts_[d] = 0;
249         lengths_[d] = kFullExtent;
250       } else {
251         const auto new_end = std::max(end(d), other.end(d));
252         set_start(d, std::min(start(d), other.start(d)));
253         set_length(d, new_end - start(d));
254       }
255     }
256   }
257 }
258 
259 // static
HasExtentLength(const TensorSliceProto::Extent & extent)260 bool TensorSlice::HasExtentLength(const TensorSliceProto::Extent& extent) {
261   return extent.has_length_case() == TensorSliceProto::Extent::kLength;
262 }
263 
264 // static
GetExtentLength(const TensorSliceProto::Extent & extent)265 int64_t TensorSlice::GetExtentLength(const TensorSliceProto::Extent& extent) {
266   if (!HasExtentLength(extent)) return -1;
267   return extent.length();
268 }
269 
SliceTensorShape(const TensorShape & shape,TensorShape * result_shape) const270 Status TensorSlice::SliceTensorShape(const TensorShape& shape,
271                                      TensorShape* result_shape) const {
272   result_shape->Clear();
273   // Mismatching ranks: we can't apply the slice at all.
274   if (shape.dims() != dims()) {
275     return errors::Internal("Mismatching ranks: shape = ", shape.DebugString(),
276                             ", slice = ", DebugString());
277   }
278   for (int d = 0; d < dims(); ++d) {
279     if (IsFullAt(d)) {
280       result_shape->AddDim(shape.dim_size(d));
281     } else {
282       // Check if the extent applies to the dimension
283       if (end(d) <= shape.dim_size(d)) {
284         // Yes: the end is within the range of the dim -- we adjust the result
285         // shape so that its size along this dimension is the length of the
286         // slice.
287         result_shape->AddDim(length(d));
288       } else {
289         // The extent doesn't apply to the dimension
290         result_shape->Clear();
291         return errors::Internal("Extent in dimension ", d,
292                                 " out of bounds: shape = ", shape.DebugString(),
293                                 ", slice = ", DebugString());
294       }
295     }
296   }
297   // If we are here, we have successfully applied the shape.
298   return OkStatus();
299 }
300 
301 const int64_t TensorSlice::kFullExtent = -1;
302 
303 }  // namespace tensorflow
304