/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_
#define TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_

#include <limits>
#include <string>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/statusor.h"

namespace tensorflow {

// START_SKIP_DOXYGEN
template <class Shape>
class TensorShapeIter;
class TensorShape;
class TensorShapeProto;
class PartialTensorShape;
// END_SKIP_DOXYGEN

/// Internal representation for both TensorShape and PartialTensorShape.
class TensorShapeRep {
 public:
  ~TensorShapeRep();

  /// Copy the specified shape.
  TensorShapeRep(const TensorShapeRep& b);
  void operator=(const TensorShapeRep& b);

  /// Move the specified shape.  After moving, `b` is safe for destruction and
  /// can be reassigned into, but its dimensions and number of elements can be
  /// nonsensical (e.g., negative dimension sizes, or a number of elements that
  /// has not been properly recomputed).
  TensorShapeRep(TensorShapeRep&& b);
  void operator=(TensorShapeRep&& b);

  /// Clear a tensor shape, producing the scalar shape.
  void Clear();

  // Maximum number of dimensions in a tensor.
  // It's 254 because 255 = kUnknownRank is used to represent unknown rank.
  static constexpr int MaxDimensions() { return 254; }

  /// \brief Returns the number of elements in the tensor.
  ///
  /// We use `int64` and not `size_t` to be compatible with `Eigen::Tensor`
  /// which uses `ptrdiff_t`.  For PartialTensorShape, -1 means not fully
  /// defined.
  int64_t num_elements() const { return num_elements_; }

  /// For error messages.
  std::string DebugString() const;
  static std::string DebugString(const TensorShapeProto& proto);

 protected:
  // Constructable only via TensorShapeBase.
  TensorShapeRep() = default;

  void ClearAllButDataType();

  // We use 16 bytes to represent a TensorShape.  Because we need to
  // be able to support full 64-bit dimension sizes and an arbitrary
  // number of dimensions for a Tensor, but most tensor dimensions are
  // significantly smaller than 64 bits and most tensors are 1, 2, or 3
  // dimensions, we have several representations.
  // Rep16: Supports up to 6 dimensions where each dimension is < 2^16 - 1
  // Rep32: Supports up to 3 dimensions where each dimension is < 2^32 - 1
  // Rep64: Supports arbitrary dimensionality, 64-bit dimensions using
  //        an out of line vector.
  // For PartialTensorShape, a dimension of static_cast<uint??>(-1) is unknown.
  // This value is not allowed in TensorShape either for format compatibility.
  struct Rep16 {
    uint16 dims_[6];
  };
  struct Rep32 {
    uint32 dims_[3];
  };
  struct Rep64 {
    gtl::InlinedVector<int64_t, 4>* dims_;
  };

  // We use the max value of uint16 or uint32 to represent unknown shapes, so
  // the maximum representable valid shape in these representations is one
  // less.
  static constexpr int64_t kMaxRep16 = std::numeric_limits<uint16>::max() - 1;
  static constexpr int64_t kMaxRep32 = std::numeric_limits<uint32>::max() - 1;
  static constexpr uint16 kUnknownRep16 = std::numeric_limits<uint16>::max();
  static constexpr uint32 kUnknownRep32 = std::numeric_limits<uint32>::max();
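
  // Illustration (explanatory note, not part of the implementation): the
  // representation is chosen from the rank and the largest dimension size,
  // roughly as follows:
  //   shape {3, 4}                -> Rep16 (<= 6 dims, each < 2^16 - 1)
  //   shape {1 << 20, 8}          -> Rep32 (<= 3 dims, each < 2^32 - 1)
  //   shape {2, 2, 2, 2, 2, 2, 2} -> Rep64 (> 6 dims, stored out of line)
  // The exact selection logic lives in tensor_shape.cc; treat these cases as
  // a sketch of the intent rather than a specification.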

  Rep16* as16() { return reinterpret_cast<Rep16*>(buf()); }
  Rep32* as32() { return reinterpret_cast<Rep32*>(buf()); }
  Rep64* as64() { return reinterpret_cast<Rep64*>(buf()); }

  const Rep16* as16() const { return reinterpret_cast<const Rep16*>(buf()); }
  const Rep32* as32() const { return reinterpret_cast<const Rep32*>(buf()); }
  const Rep64* as64() const { return reinterpret_cast<const Rep64*>(buf()); }

  enum RepTag { REP16 = 0, REP32 = 1, REP_OUT_OF_LINE = 2 };

  // Since we have a convenient extra byte available, we allow the
  // Tensor class to store an 8-bit value in this extra storage.  This
  // allows it to store the Tensor's datatype enum value here and avoid
  // an extra word of storage.
  friend class Tensor;
  friend class TensorShapeTestHelper;
  DataType data_type() const { return static_cast<DataType>(buf()[13]); }
  void set_data_type(DataType dt) {
    // We only have 8 bits available to store DataType, so make sure it fits.
    DCHECK_LT(static_cast<uint32>(dt), 256u);
    buf()[13] = static_cast<uint8>(dt);
  }

  // We store the number of dimensions in byte 14, and the RepTag in byte 15.
  // Bytes [0..13] vary depending on the representation.
  // A value of 255 indicates unknown rank in the PartialTensorShape case.
  static constexpr uint8 kUnknownRank = 255;
  uint8 ndims_byte() const { return buf()[14]; }
  void set_ndims_byte(uint8 nd) { buf()[14] = nd; }

  RepTag tag() const { return static_cast<RepTag>(buf()[15]); }
  void set_tag(RepTag tag) { buf()[15] = static_cast<uint8>(tag); }

  void set_num_elements(int64_t n) { num_elements_ = n; }

 private:
  void DestructorOutOfLine();
  void SlowCopyFrom(const TensorShapeRep& b);

  uint8* buf() { return &u_.buf[0]; }
  const uint8* buf() const { return &u_.buf[0]; }

  union {
    uint8 buf[16];
    // Force data to be aligned enough for a pointer.
    Rep64* unused_aligner;
  } u_;
  int64_t num_elements_;
};

/// Base class for TensorShape and PartialTensorShape.
/// The class is templatized by either TensorShape or PartialTensorShape to
/// allow skipping known/unknown checks in the TensorShape case, but the
/// representation is shared exactly for fast conversion.
template <class Shape>
class TensorShapeBase : public TensorShapeRep {
 public:
  /// \brief Construct a `TensorShapeBase` from the provided sizes.
  /// REQUIRES: `dim_sizes[i] >= 0` (or >= -1 for PartialTensorShape)
  explicit TensorShapeBase(gtl::ArraySlice<int64_t> dim_sizes);
  TensorShapeBase(std::initializer_list<int64_t> dim_sizes)
      : TensorShapeBase(gtl::ArraySlice<int64_t>(dim_sizes)) {}

  /// Construct an empty TensorShape, or an unknown-rank PartialTensorShape.
  TensorShapeBase();

  // Cannot be made explicit because we rely on conversion between proto and
  // `TensorShapeBase` throughout the codebase (needs bigger cleanup).
  TensorShapeBase(const TensorShapeProto& proto);

  // These factory methods should be used instead of the constructors that take
  // an array of sizes if calling code cannot validate that the sizes specify a
  // valid `TensorShape`.
  // The value in `*out` is valid iff the returned value is `Status::OK`.
  static Status BuildTensorShapeBase(gtl::ArraySlice<int64_t> dim_sizes,
                                     TensorShapeBase* out);
  static Status BuildTensorShapeBase(std::initializer_list<int64_t> dim_sizes,
                                     TensorShapeBase* out) {
    return BuildTensorShapeBase(gtl::ArraySlice<int64_t>(dim_sizes), out);
  }
  static Status BuildTensorShapeBase(const TensorShapeProto& proto,
                                     TensorShapeBase* out);
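
  // Minimal usage sketch (illustrative only; `shape` and `s` are placeholder
  // names, and the error branch is elided):
  //
  //   TensorShape shape;
  //   Status s = TensorShape::BuildTensorShape({2, 3, 5}, &shape);
  //   if (!s.ok()) { /* the sizes were rejected, e.g. negative or too big */ }
  //
  // Prefer these factories to the constructors whenever the sizes come from
  // untrusted input; the constructors require already-validated sizes.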

  /// Returns `true` iff `proto` is a valid tensor shape.
  // For TensorShape, the proto shape must be fully defined.
  static bool IsValid(const TensorShapeProto& proto);

  /// Returns `OK` iff `proto` is a valid tensor shape, and a descriptive error
  /// status otherwise.
  static Status IsValidShape(const TensorShapeProto& proto);

  /// Returns `true` iff this is a valid tensor shape.
  bool IsValid();

  /// \brief Add a dimension to the end ("inner-most").
  /// REQUIRES: `size >= 0`
  void AddDim(int64_t size);

  /// Same as `AddDim` but returns a `Status`.
  /// Use if unsure whether `size >= 0`, to prevent `CHECK`-crashes.
  Status AddDimWithStatus(int64_t size);
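
  // For example (illustrative only), growing a shape one dimension at a time:
  //
  //   TensorShape shape;   // scalar: 0 dimensions, 1 element
  //   shape.AddDim(2);     // [2]
  //   shape.AddDim(3);     // [2, 3]; num_elements() == 6
  //
  // When `size` comes from untrusted input, prefer the `Status` variant:
  //
  //   TF_RETURN_IF_ERROR(shape.AddDimWithStatus(size));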

  /// Appends all the dimensions from `shape`.
  void AppendShape(const TensorShapeBase& shape);

  /// Same as `AppendShape` but returns a `Status`.
  /// Use if you cannot validate all invariants, to prevent `CHECK`-fail.
  Status AppendShapeWithStatus(const TensorShapeBase& shape);

  /// \brief Insert a dimension somewhere in the `TensorShape`.
  /// REQUIRES: `0 <= d <= dims()`
  /// REQUIRES: `size >= 0`
  void InsertDim(int d, int64_t size);

  /// Same as `InsertDim` but returns a `Status`.
  /// Use if unsure whether the requirements in `InsertDim` are satisfied, to
  /// prevent `CHECK`-fail crashes.
  Status InsertDimWithStatus(int d, int64_t size);

  /// \brief Modifies the size of the dimension `d` to be `size`.
  /// REQUIRES: `0 <= d < dims()`
  /// REQUIRES: `size >= 0`
  void set_dim(int d, int64_t size);

  /// Same as `set_dim` but returns a `Status`.
  /// Use if unsure whether the requirements in `set_dim` are satisfied, to
  /// prevent `CHECK`-fail crashes.
  Status SetDimWithStatus(int d, int64_t size);

  /// \brief Removes dimension `d` from the `TensorShape`.
  /// REQUIRES: `0 <= d < dims()`
  void RemoveDim(int d) {
    CHECK_GE(d, 0);
    RemoveDimRange(d, d + 1);
  }

  /// Same as `RemoveDim` but returns a `Status`.
  /// Use if unsure whether `0 <= d < dims()`, to prevent `CHECK`-crashes.
  Status RemoveDimWithStatus(int64_t d) {
    if (TF_PREDICT_FALSE(d < 0)) {
      return errors::Internal(
          "Expected dimension index to be non-negative, got ", d);
    }
    return RemoveDimRangeWithStatus(d, d + 1);
  }

  /// \brief Removes the last `n` dimensions from the `TensorShape`.
  /// REQUIRES: `0 <= n <= dims()`
  void RemoveLastDims(int n) {
    CHECK_LE(n, dims());
    RemoveDimRange(dims() - n, dims());
  }

  /// Same as `RemoveLastDims` but returns a `Status`.
  /// Use if unsure whether `0 <= n <= dims()`, to prevent `CHECK`-crashes.
  Status RemoveLastDimsWithStatus(int64_t n) {
    if (TF_PREDICT_FALSE(n > dims())) {
      return errors::Internal("Expected dimension index to be at most ", dims(),
                              " got ", n);
    }
    return RemoveDimRangeWithStatus(dims() - n, dims());
  }

  /// \brief Removes the dimensions in range `[begin:end)` from `TensorShape`.
  /// Negative values of `end` are interpreted as `dims() + end + 1` (as in
  /// Python). The same is true for negative values of `begin`.
  /// REQUIRES: `-(dims()+1) <= begin <= dims()`
  /// REQUIRES: `-(dims()+1) <= end <= dims()`
  void RemoveDimRange(int begin, int end);

  /// Same as `RemoveDimRange` but returns a `Status`.
  /// Use if unsure whether the requirements in `RemoveDimRange` are satisfied,
  /// to prevent `CHECK`-fail crashes.
  Status RemoveDimRangeWithStatus(int begin, int end);
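
  // Worked example (illustrative): for a shape {2, 3, 5, 7}, dims() == 4, so
  // RemoveDimRange(-2, -1) is interpreted as RemoveDimRange(3, 4) and removes
  // the dimension at index 3, leaving {2, 3, 5}.  RemoveDimRange(0, -1)
  // removes every dimension, leaving a scalar shape.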

  /// Return whether the rank is unknown.
  bool unknown_rank() const {
    return kIsPartial && ndims_byte() == kUnknownRank;
  }

  /// Return the number of dimensions in the tensor.
  /// Can be -1 meaning unknown rank for PartialTensorShape.
  int dims() const {
    uint8 dims = ndims_byte();
    return kIsPartial && dims == kUnknownRank ? -1 : dims;
  }

  /// \brief Returns the number of elements in dimension `d`.
  /// REQUIRES: `0 <= d < dims()`
  // TODO(touts): Rename to `dimension()` to match
  // `Eigen::Tensor::dimension()`?
  int64_t dim_size(int d) const;

  /// Returns sizes of all dimensions.
  // Returns an empty list for unknown rank PartialTensorShape.
  gtl::InlinedVector<int64_t, 4> dim_sizes() const;

  /// Return true iff the rank and all of the dimensions are well defined.
  // TODO(irving): Rename to is_fully_defined now that it's fast.
  bool IsFullyDefined() const { return !kIsPartial || num_elements() != -1; }

  /// Fill `*proto` from `*this`.
  void AsProto(TensorShapeProto* proto) const;
  TensorShapeProto AsProto() const;

  /// For iterating through the dimensions.
  TensorShapeIter<Shape> begin() const;
  TensorShapeIter<Shape> end() const;
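
  // For example (illustrative; `shape` is a placeholder name):
  //
  //   for (const TensorShapeDim& dim : shape) {
  //     // dim.size is the size of one dimension, in order from outermost to
  //     // innermost (it can be -1 for an unknown PartialTensorShape dim).
  //   }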

 protected:
  // Optimized constructor for a shape representing an empty vector.
  //
  // This constructor is provided to optimize the default constructor for
  // `Tensor`.
  explicit TensorShapeBase(DataType dt);

 private:
  Status RecomputeNumElements();
  Status InitDims(gtl::ArraySlice<int64_t> dim_sizes);

  // True for PartialTensorShape, false for TensorShape.
  static constexpr bool kIsPartial =
      std::is_same<Shape, PartialTensorShape>::value;
  static_assert(kIsPartial || std::is_same<Shape, TensorShape>::value,
                "Shape is neither TensorShape nor PartialTensorShape");

  // Used by AddDim and MakeShapeHelper.  Does no error checking.
  void UnsafeAddDim(int64_t size, int64_t new_num_elements);

  // For use by TensorShapeUtils::MakeShape.
  template <class T, class S>
  friend Status MakeShapeHelper(const T*, int64_t, S*);
};

/// Outputs `TensorShapeBase` to `std::ostream`.
template <typename Shape>
std::ostream& operator<<(std::ostream& os, const TensorShapeBase<Shape>& tsb) {
  return os << tsb.DebugString();
}

/// Represents the shape of a Tensor.
///
/// A tensor's shape is denoted by its number of dimensions and a size for each
/// dimension.  For example, a 3 x 4 matrix Tensor has 2 dimensions and a shape
/// of [3, 4].
///
/// If you know the exact shape of your Tensor when you create the TensorShape
/// object, you can specify it then, or you can create a TensorShape with
/// zero dimensions and one element, and call AddDim() to add dimensions later.
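///
/// Example (illustrative only; variable names are placeholders):
///
///     TensorShape a({3, 4});   // 2-D shape [3, 4] with 12 elements
///     TensorShape b;           // scalar shape: 0 dimensions, 1 element
///     b.AddDim(3);
///     b.AddDim(4);             // b is now also [3, 4]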
class TensorShape : public TensorShapeBase<TensorShape> {
 public:
  using TensorShapeBase<TensorShape>::TensorShapeBase;

  // These factory methods should be used instead of the constructors that take
  // an array of sizes if calling code cannot validate that the sizes specify a
  // valid `TensorShape`.
  // The value in `*out` is valid iff the returned value is `Status::OK`.
  static Status BuildTensorShape(gtl::ArraySlice<int64_t> dim_sizes,
                                 TensorShape* out) {
    return BuildTensorShapeBase(dim_sizes, out);
  }
  static Status BuildTensorShape(std::initializer_list<int64_t> dim_sizes,
                                 TensorShape* out) {
    return BuildTensorShape(gtl::ArraySlice<int64_t>(dim_sizes), out);
  }
  static Status BuildTensorShape(const TensorShapeProto& proto,
                                 TensorShape* out) {
    return BuildTensorShapeBase(proto, out);
  }

  static StatusOr<TensorShape> BuildTensorShape(const TensorShapeProto& proto) {
    TensorShape out;
    TF_RETURN_IF_ERROR(BuildTensorShape(proto, &out));
    return out;
  }

  /// Allow a TensorShape to be used as a PartialTensorShape without copying.
  operator const PartialTensorShape&() const;  // NOLINT(runtime/explicit)

  /// Returns true if `*this` and `b` have the same sizes. Ignores
  /// dimension names.
  bool IsSameSize(const TensorShape& b) const;
  bool operator==(const TensorShape& b) const { return IsSameSize(b); }
  bool operator!=(const TensorShape& b) const { return !IsSameSize(b); }

  /// Fill `*dsizes` from `*this`.
  /// Notice: Using IndexType=int32 in combination with To32Bit() can
  /// significantly improve performance on GPU.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizes() const;
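
  // Typical usage sketch (illustrative; assumes `shape` is a rank-2
  // TensorShape):
  //
  //   Eigen::DSizes<Eigen::DenseIndex, 2> dsizes = shape.AsEigenDSizes<2>();
  //   // CHECK-fails if shape.dims() != 2; dsizes[d] == shape.dim_size(d).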

  // Same as `AsEigenDSizes()` but returns a `Status` instead.
  // Use this method to surface errors to the user instead of crashing if
  // `NDIMS` is not equal to `dims()`.
  // The caller retains ownership of `out`.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Status AsEigenDSizesWithStatus(Eigen::DSizes<IndexType, NDIMS>* out) const;

  /// Same as `AsEigenDSizes()` but allows for `NDIMS > dims()` -- in
  /// which case we pad the rest of the sizes with 1.
  /// Notice: Using IndexType=int32 in combination with To32Bit() can
  /// significantly improve performance on GPU.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesWithPadding() const;

  // Same as `AsEigenDSizesWithPadding()` but returns a `Status` instead.
  // Use this method to surface errors to the user instead of crashing if
  // `NDIMS` is less than `dims()`.
  // The caller retains ownership of `out`.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Status AsEigenDSizesWithPaddingWithStatus(
      Eigen::DSizes<IndexType, NDIMS>* out) const;

 private:
  // These CHECK-fail to ease debugging.
  // REQUIRES: dims() == NDIMS
  void CheckDimsEqual(int NDIMS) const;
  // REQUIRES: dims() <= NDIMS
  void CheckDimsAtMost(int NDIMS) const;

  // Fill output from `*this`.
  // Helper method for common code between `AsEigenDSizes()` and
  // `AsEigenDSizesWithStatus()`.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopy() const;

  // Fill output from `*this`.
  // Helper method for common code between `AsEigenDSizesWithPadding()` and
  // `AsEigenDSizesWithPaddingWithStatus()`.
  template <int NDIMS, typename IndexType = Eigen::DenseIndex>
  Eigen::DSizes<IndexType, NDIMS> AsEigenDSizesCopyAndPad() const;

  // For access to TensorShapeBase(DataType).
  friend class Tensor;
};

/// Outputs `TensorShape` to `std::ostream`.
inline std::ostream& operator<<(std::ostream& os, const TensorShape& ts) {
  return os << ts.DebugString();
}

/// Represents the value of one dimension in a TensorShape.
struct TensorShapeDim {
  explicit TensorShapeDim(int64_t s) : size(s) {}
  int64_t size;
};

// START_SKIP_DOXYGEN
template <class Shape>
class TensorShapeIter {
 public:
  TensorShapeIter(const Shape* shape, int d) : shape_(shape), d_(d) {}
  bool operator==(const TensorShapeIter& rhs) {
    DCHECK(shape_ == rhs.shape_);
    return d_ == rhs.d_;
  }
  bool operator!=(const TensorShapeIter& rhs) {
    DCHECK(shape_ == rhs.shape_);
    return d_ != rhs.d_;
  }
  void operator++() { ++d_; }
  TensorShapeDim operator*() { return TensorShapeDim(shape_->dim_size(d_)); }

 private:
  const Shape* shape_;
  int d_;
};
// END_SKIP_DOXYGEN

/// \brief Static helper routines for `TensorShape`. Includes a few common
/// predicates on a tensor shape.
class TensorShapeUtils {
 public:
  static bool IsScalar(const TensorShape& shape) { return shape.dims() == 0; }

  static bool IsVector(const TensorShape& shape) { return shape.dims() == 1; }

  static bool IsVectorOrHigher(const TensorShape& shape) {
    return shape.dims() >= 1;
  }

  static bool IsMatrix(const TensorShape& shape) { return shape.dims() == 2; }

  static bool IsSquareMatrix(const TensorShape& shape) {
    return shape.dims() == 2 && shape.dim_size(0) == shape.dim_size(1);
  }

  static bool IsMatrixOrHigher(const TensorShape& shape) {
    return shape.dims() >= 2;
  }

  /// \brief Returns a `TensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.  (A usage sketch follows this
  /// class.)
  static Status MakeShape(const int32* dims, int64_t n, TensorShape* out);
  static Status MakeShape(const int64_t* dims, int64_t n, TensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int32> shape, TensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int64_t> shape, TensorShape* out);
  static Status MakeShape(const int32* dims, int64_t n,
                          PartialTensorShape* out);
  static Status MakeShape(const int64_t* dims, int64_t n,
                          PartialTensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int32> shape,
                          PartialTensorShape* out);
  static Status MakeShape(gtl::ArraySlice<int64_t> shape,
                          PartialTensorShape* out);

  static std::string ShapeListString(
      const gtl::ArraySlice<TensorShape>& shapes);

  /// \brief Returns true iff `shape` starts with `prefix`.
  static bool StartsWith(const TensorShape& shape, const TensorShape& prefix);

  /// \brief Returns true iff `shape` ends with `suffix`.
  static bool EndsWith(const TensorShape& shape, const TensorShape& suffix);

  /// \brief Returns the product of values in an int64 array,
  /// or a failing Status if the array represents a value larger than
  /// a `TensorShape` can hold.
  static Status NumElements(gtl::ArraySlice<int64_t> shape,
                            int64_t* num_elements);
};
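
// Usage sketch for `TensorShapeUtils` (illustrative only; `dims` and `shape`
// are placeholder names):
//
//   int64_t dims[] = {2, 3};
//   TensorShape shape;
//   TF_RETURN_IF_ERROR(TensorShapeUtils::MakeShape(dims, 2, &shape));
//   DCHECK(TensorShapeUtils::IsMatrix(shape));
//   DCHECK(!TensorShapeUtils::IsSquareMatrix(shape));  // 2 != 3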

/// Manages the partially known dimensions of a Tensor and their sizes.
class PartialTensorShape : public TensorShapeBase<PartialTensorShape> {
 public:
  PartialTensorShape() {}
  using TensorShapeBase<PartialTensorShape>::TensorShapeBase;

  // These factory methods should be used instead of the constructors that take
  // an array of sizes if calling code cannot validate that the sizes specify a
  // valid `PartialTensorShape`.
  // The value in `*out` is valid iff the returned value is `Status::OK`.
  static Status BuildPartialTensorShape(gtl::ArraySlice<int64_t> dim_sizes,
                                        PartialTensorShape* out) {
    return BuildTensorShapeBase(dim_sizes, out);
  }
  static Status BuildPartialTensorShape(
      std::initializer_list<int64_t> dim_sizes, PartialTensorShape* out) {
    return BuildPartialTensorShape(gtl::ArraySlice<int64_t>(dim_sizes), out);
  }
  static Status BuildPartialTensorShape(const TensorShapeProto& proto,
                                        PartialTensorShape* out) {
    return BuildTensorShapeBase(proto, out);
  }

  static StatusOr<PartialTensorShape> BuildPartialTensorShape(
      const TensorShapeProto& proto) {
    PartialTensorShape out;
    TF_RETURN_IF_ERROR(BuildTensorShapeBase(proto, &out));
    return out;
  }

  /// Add a dimension to the end ("inner-most"), returning a new
  /// PartialTensorShape.
  /// REQUIRES: `size >= -1`, where -1 means unknown
  PartialTensorShape Concatenate(int64_t size) const;

  /// Similar to `Concatenate` but returns a `Status`.
  /// Use if calling code cannot validate all requirements and if `CHECK`-fails
  /// are to be avoided.
  Status ConcatenateWithStatus(int64_t size, PartialTensorShape* out) const;

  /// Appends all the dimensions from `shape`.  Returns a new
  /// PartialTensorShape.
  PartialTensorShape Concatenate(const PartialTensorShape& shape) const;

  /// Similar to `Concatenate` but returns a `Status`.
  /// Use if calling code cannot validate all requirements and if `CHECK`-fails
  /// are to be avoided.
  Status ConcatenateWithStatus(const PartialTensorShape& shape,
                               PartialTensorShape* out) const;

  /// Merges all the dimensions from `shape`.  Returns an
  /// `InvalidArgument` error if either `shape` has a different rank
  /// or if any of the dimensions are incompatible.  (A usage sketch follows
  /// this class.)
  Status MergeWith(const PartialTensorShape& shape,
                   PartialTensorShape* result) const;

  /// Exact equality test. Returns true iff the ranks match (i.e., both are
  /// unknown, or both are known and equal), and all dimensions are equal
  /// (i.e., both dimensions are unknown, or both are known and equal). This is
  /// a stronger condition than IsCompatibleWith.
  bool IsIdenticalTo(const PartialTensorShape& shape) const;

  /// Return true iff the ranks match, and if the
  /// dimensions all either match or one is unknown.
  bool IsCompatibleWith(const PartialTensorShape& shape) const;

  // Fill `*shape` from `*this`.
  // If `*this` is not fully defined, returns false and
  // `*shape` is left in an intermediate state.  Otherwise
  // returns true.
  bool AsTensorShape(TensorShape* shape) const;

  /// \brief Returns a `PartialTensorShape` whose dimensions are
  /// `dims[0]`, `dims[1]`, ..., `dims[n-1]`.  Values of -1 are
  /// considered "unknown".
  template <class T>
  static Status MakePartialShape(const T* dims, int n,
                                 PartialTensorShape* out) {
    return TensorShapeUtils::MakeShape(dims, n, out);
  }
};
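
// Usage sketch for `PartialTensorShape` (illustrative only; variable names
// are placeholders):
//
//   PartialTensorShape a({-1, 3});   // rank 2, first dimension unknown
//   PartialTensorShape b({5, 3});
//   PartialTensorShape merged;
//   if (a.IsCompatibleWith(b)) {
//     TF_CHECK_OK(a.MergeWith(b, &merged));   // merged is [5, 3]
//   }
//   TensorShape full;
//   if (merged.AsTensorShape(&full)) { /* merged was fully defined */ }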

/// \brief Static helper routines for `PartialTensorShape`. Includes a few
/// common predicates on a partially known tensor shape.
class PartialTensorShapeUtils {
 public:
  static std::string PartialShapeListString(
      const gtl::ArraySlice<PartialTensorShape>& shapes);

  static bool AreIdentical(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                           const gtl::ArraySlice<PartialTensorShape>& shapes1);

  static bool AreCompatible(const gtl::ArraySlice<PartialTensorShape>& shapes0,
                            const gtl::ArraySlice<PartialTensorShape>& shapes1);
};

// ----------------------------------------------------------------------------
// Template method implementation details below
// ----------------------------------------------------------------------------

template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopy() const {
  Eigen::DSizes<IndexType, NDIMS> dsizes;
  for (int d = 0; d < NDIMS; d++) {
    dsizes[d] = static_cast<IndexType>(dim_size(d));
  }
  return dsizes;
}

template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesCopyAndPad() const {
  static_assert(NDIMS <= TensorShape::MaxDimensions(), "Too many dimensions");
  Eigen::DSizes<IndexType, NDIMS> dsizes;
  for (int d = 0; d < dims(); d++) {
    dsizes[d] = static_cast<IndexType>(dim_size(d));
  }
  for (int d = dims(); d < NDIMS; d++) {
    dsizes[d] = 1;
  }
  return dsizes;
}

template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizes() const {
  CheckDimsEqual(NDIMS);
  return AsEigenDSizesCopy<NDIMS, IndexType>();
}

template <int NDIMS, typename IndexType>
Status TensorShape::AsEigenDSizesWithStatus(
    Eigen::DSizes<IndexType, NDIMS>* out) const {
  if (TF_PREDICT_FALSE(NDIMS != dims())) {
    return errors::Internal("Asking for tensor of ", NDIMS,
                            " dimensions from a tensor of ", dims(),
                            " dimensions");
  }
  *out = AsEigenDSizesCopy<NDIMS, IndexType>();
  return Status::OK();
}

template <int NDIMS, typename IndexType>
Eigen::DSizes<IndexType, NDIMS> TensorShape::AsEigenDSizesWithPadding() const {
  CheckDimsAtMost(NDIMS);
  return AsEigenDSizesCopyAndPad<NDIMS, IndexType>();
}

template <int NDIMS, typename IndexType>
Status TensorShape::AsEigenDSizesWithPaddingWithStatus(
    Eigen::DSizes<IndexType, NDIMS>* out) const {
  if (TF_PREDICT_FALSE(NDIMS < dims())) {
    return errors::Internal("Asking for tensor of at least ", NDIMS,
                            " dimensions from a tensor of ", dims(),
                            " dimensions");
  }
  *out = AsEigenDSizesCopyAndPad<NDIMS, IndexType>();
  return Status::OK();
}

// ----------------------------------------------------------------------------
// Inlining of some performance critical routines
// ----------------------------------------------------------------------------

inline TensorShapeRep::TensorShapeRep(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (b.tag() != REP_OUT_OF_LINE) {
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly does:
    //   set_ndims_byte(b.ndims_byte());
    //   set_tag(b.tag());
  } else {
    set_tag(REP16);  // So that SlowCopyFrom does not try to deallocate.
    SlowCopyFrom(b);
  }
}

inline TensorShapeRep::TensorShapeRep(TensorShapeRep&& b) {
  num_elements_ = b.num_elements_;
  memcpy(buf(), b.buf(), sizeof(u_.buf));
  // memcpy above implicitly does:
  //   set_ndims_byte(b.ndims_byte());
  //   set_tag(b.tag());
  b.set_tag(REP16);  // The other shape no longer owns out-of-line data, if any.
}

inline TensorShapeRep::~TensorShapeRep() {
  if (tag() == REP_OUT_OF_LINE) {
    DestructorOutOfLine();
  }
}

inline void TensorShapeRep::operator=(const TensorShapeRep& b) {
  num_elements_ = b.num_elements_;
  if (tag() != REP_OUT_OF_LINE && b.tag() != REP_OUT_OF_LINE) {
    memcpy(buf(), b.buf(), sizeof(u_.buf));
    // memcpy above implicitly also does:
    //   set_tag(b.tag());
    //   set_ndims_byte(b.ndims_byte());
  } else {
    SlowCopyFrom(b);
  }
}

inline void TensorShapeRep::operator=(TensorShapeRep&& b) {
  if (tag() == REP_OUT_OF_LINE) {
    DestructorOutOfLine();
  }
  num_elements_ = b.num_elements_;
  memcpy(buf(), b.buf(), sizeof(u_.buf));
  // memcpy above implicitly does:
  //   set_ndims_byte(b.ndims_byte());
  //   set_tag(b.tag());
  b.set_tag(REP16);  // The other shape no longer owns out-of-line data, if any.
}

inline TensorShape::operator const PartialTensorShape&() const {
  // Downcast to the shared representation and upcast to PartialTensorShape.
  const TensorShapeRep* rep = this;
  return *static_cast<const PartialTensorShape*>(rep);
}

template <class Shape>
inline TensorShapeBase<Shape>::TensorShapeBase(DataType dt) {
  set_tag(REP16);
  set_data_type(dt);

  // Optimized implementation of InitDims() where the shape is statically known
  // to be {0}.
  set_ndims_byte(1);
  uint16* dst = as16()->dims_;
  *dst = 0;
  set_num_elements(0);
}

// Declare explicit instantiations in .cc file.
extern template class TensorShapeBase<TensorShape>;
extern template class TensorShapeBase<PartialTensorShape>;

// A convenient struct to represent a (DataType, PartialTensorShape) pair. It's
// often used in shape inference.
struct DtypeAndPartialTensorShape {
  DataType dtype;
  PartialTensorShape shape;
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_FRAMEWORK_TENSOR_SHAPE_H_