1 /* 2 * Copyright 2022 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FCP_AGGREGATION_CORE_TENSOR_H_ 18 #define FCP_AGGREGATION_CORE_TENSOR_H_ 19 20 #include <memory> 21 #include <utility> 22 23 #include "fcp/aggregation/core/agg_vector.h" 24 #include "fcp/aggregation/core/datatype.h" 25 #include "fcp/aggregation/core/tensor_data.h" 26 #include "fcp/aggregation/core/tensor_shape.h" 27 #include "fcp/base/monitoring.h" 28 29 #ifndef FCP_NANOLIBC 30 #include "fcp/aggregation/core/tensor.pb.h" 31 #endif 32 33 namespace fcp { 34 namespace aggregation { 35 36 // Tensor class is a container that packages the tensor data with the tensor 37 // metadata such as the value type and the shape. 38 // 39 // For the most part, the aggregation code won't be consuming tensors directly. 40 // Instead the aggregation code will be working with AggVector instances that 41 // represent the tensor data in a flattened way. 42 class Tensor final { 43 public: 44 // Tensor class isn't copyable. 45 Tensor(const Tensor&) = delete; 46 47 // Move constructor. Tensor(Tensor && other)48 Tensor(Tensor&& other) 49 : dtype_(other.dtype_), 50 shape_(std::move(other.shape_)), 51 data_(std::move(other.data_)) { 52 other.dtype_ = DT_INVALID; 53 } 54 55 // Move assignment. 56 Tensor& operator=(Tensor&& other) { 57 dtype_ = other.dtype_; 58 shape_ = std::move(other.shape_); 59 data_ = std::move(other.data_); 60 other.dtype_ = DT_INVALID; 61 return *this; 62 } 63 64 // Define a default constructor to allow for initalization of array 65 // to enable creation of a vector of Tensors. 66 // A tensor created with the default constructor is not valid and thus should 67 // not actually be used. Tensor()68 Tensor() : dtype_(DT_INVALID), shape_{}, data_(nullptr) {} 69 70 // Validates parameters and creates a Tensor instance. 71 static StatusOr<Tensor> Create(DataType dtype, TensorShape shape, 72 std::unique_ptr<TensorData> data); 73 74 #ifndef FCP_NANOLIBC 75 // Creates a Tensor instance from a TensorProto. 76 static StatusOr<Tensor> FromProto(const TensorProto& tensor_proto); 77 78 // Creates a Tensor instance from a TensorProto, consuming the proto. 79 static StatusOr<Tensor> FromProto(TensorProto&& tensor_proto); 80 81 // Converts Tensor to TensorProto 82 TensorProto ToProto() const; 83 #endif // FCP_NANOLIBC 84 85 // Validates the tensor. 86 Status CheckValid() const; 87 88 // Gets the tensor value type. dtype()89 DataType dtype() const { return dtype_; } 90 91 // Gets the tensor shape. shape()92 const TensorShape& shape() const { return shape_; } 93 94 // Readonly access to the tensor data. data()95 const TensorData& data() const { return *data_; } 96 97 // Returns true is the current tensor data is dense. 98 // TODO(team): Implement sparse tensors. is_dense()99 bool is_dense() const { return true; } 100 101 // Provides access to the tensor data via a strongly typed AggVector. 102 template <typename T> AsAggVector()103 AggVector<T> AsAggVector() const { 104 FCP_CHECK(internal::TypeTraits<T>::kDataType == dtype_) 105 << "Incompatible tensor dtype()"; 106 return AggVector<T>(data_.get()); 107 } 108 109 // TODO(team): Add serialization functions. 110 111 private: Tensor(DataType dtype,TensorShape shape,std::unique_ptr<TensorData> data)112 Tensor(DataType dtype, TensorShape shape, std::unique_ptr<TensorData> data) 113 : dtype_(dtype), shape_(std::move(shape)), data_(std::move(data)) {} 114 115 // Tensor data type. 116 DataType dtype_; 117 // Tensor shape. 118 TensorShape shape_; 119 // The underlying tensor data. 120 std::unique_ptr<TensorData> data_; 121 }; 122 123 } // namespace aggregation 124 } // namespace fcp 125 126 #endif // FCP_AGGREGATION_CORE_TENSOR_H_ 127