xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/core/tensor.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_AGGREGATION_CORE_TENSOR_H_
18 #define FCP_AGGREGATION_CORE_TENSOR_H_
19 
20 #include <memory>
21 #include <utility>
22 
23 #include "fcp/aggregation/core/agg_vector.h"
24 #include "fcp/aggregation/core/datatype.h"
25 #include "fcp/aggregation/core/tensor_data.h"
26 #include "fcp/aggregation/core/tensor_shape.h"
27 #include "fcp/base/monitoring.h"
28 
29 #ifndef FCP_NANOLIBC
30 #include "fcp/aggregation/core/tensor.pb.h"
31 #endif
32 
33 namespace fcp {
34 namespace aggregation {
35 
36 // Tensor class is a container that packages the tensor data with the tensor
37 // metadata such as the value type and the shape.
38 //
39 // For the most part, the aggregation code won't be consuming tensors directly.
40 // Instead the aggregation code will be working with AggVector instances that
41 // represent the tensor data in a flattened way.
42 class Tensor final {
43  public:
44   // Tensor class isn't copyable.
45   Tensor(const Tensor&) = delete;
46 
47   // Move constructor.
Tensor(Tensor && other)48   Tensor(Tensor&& other)
49       : dtype_(other.dtype_),
50         shape_(std::move(other.shape_)),
51         data_(std::move(other.data_)) {
52     other.dtype_ = DT_INVALID;
53   }
54 
55   // Move assignment.
56   Tensor& operator=(Tensor&& other) {
57     dtype_ = other.dtype_;
58     shape_ = std::move(other.shape_);
59     data_ = std::move(other.data_);
60     other.dtype_ = DT_INVALID;
61     return *this;
62   }
63 
64   // Define a default constructor to allow for initalization of array
65   // to enable creation of a vector of Tensors.
66   // A tensor created with the default constructor is not valid and thus should
67   // not actually be used.
Tensor()68   Tensor() : dtype_(DT_INVALID), shape_{}, data_(nullptr) {}
69 
70   // Validates parameters and creates a Tensor instance.
71   static StatusOr<Tensor> Create(DataType dtype, TensorShape shape,
72                                  std::unique_ptr<TensorData> data);
73 
74 #ifndef FCP_NANOLIBC
75   // Creates a Tensor instance from a TensorProto.
76   static StatusOr<Tensor> FromProto(const TensorProto& tensor_proto);
77 
78   // Creates a Tensor instance from a TensorProto, consuming the proto.
79   static StatusOr<Tensor> FromProto(TensorProto&& tensor_proto);
80 
81   // Converts Tensor to TensorProto
82   TensorProto ToProto() const;
83 #endif  // FCP_NANOLIBC
84 
85   // Validates the tensor.
86   Status CheckValid() const;
87 
88   // Gets the tensor value type.
dtype()89   DataType dtype() const { return dtype_; }
90 
91   // Gets the tensor shape.
shape()92   const TensorShape& shape() const { return shape_; }
93 
94   // Readonly access to the tensor data.
data()95   const TensorData& data() const { return *data_; }
96 
97   // Returns true is the current tensor data is dense.
98   // TODO(team): Implement sparse tensors.
is_dense()99   bool is_dense() const { return true; }
100 
101   // Provides access to the tensor data via a strongly typed AggVector.
102   template <typename T>
AsAggVector()103   AggVector<T> AsAggVector() const {
104     FCP_CHECK(internal::TypeTraits<T>::kDataType == dtype_)
105         << "Incompatible tensor dtype()";
106     return AggVector<T>(data_.get());
107   }
108 
109   // TODO(team): Add serialization functions.
110 
111  private:
Tensor(DataType dtype,TensorShape shape,std::unique_ptr<TensorData> data)112   Tensor(DataType dtype, TensorShape shape, std::unique_ptr<TensorData> data)
113       : dtype_(dtype), shape_(std::move(shape)), data_(std::move(data)) {}
114 
115   // Tensor data type.
116   DataType dtype_;
117   // Tensor shape.
118   TensorShape shape_;
119   // The underlying tensor data.
120   std::unique_ptr<TensorData> data_;
121 };
122 
123 }  // namespace aggregation
124 }  // namespace fcp
125 
126 #endif  // FCP_AGGREGATION_CORE_TENSOR_H_
127