xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/core/tensor_data.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_AGGREGATION_CORE_TENSOR_DATA_H_
18 #define FCP_AGGREGATION_CORE_TENSOR_DATA_H_
19 
20 #include <cstddef>
21 
22 #include "fcp/base/monitoring.h"
23 
24 namespace fcp {
25 namespace aggregation {
26 
27 // Abstract representation of tensor data storage.
28 //
29 // Tensor data is a flattened one-dimensional array of tensor values
30 // where each value takes sizeof(T) bytes.
31 //
32 // All tensor values are stored in a single blob regardless of whether the
33 // tensor is dense or sparse.
34 //
35 // If the tensor is dense, then the values are flattened into a
36 // one-dimensional array in the following way:
37 // - First iterating over the last dimension
38 // - Then incrementing the second from the last dimension and then iterating
39 //   over the last dimension
40 // - Then gradually moving towards the first dimension.
41 // For example, if we had a 3-dimensional {3 x 2 x 4} Tensor, the values
42 // in TensorData would be ordered in the following way, showing 3-dimensional
43 // indices of the tensor values:
44 //   (0,0,0), (0,0,1), (0,0,2), (0,0,3)
45 //   (0,1,0), (0,1,1), (0,1,2), (0,1,3)
46 //   (1,0,0), (1,0,1), (1,0,2), (1,0,3)
47 //   (1,1,0), (1,1,1), (1,1,2), (1,1,3)
48 //   (2,0,0), (2,0,1), (2,0,2), (2,0,3)
49 //   (2,1,0), (2,1,1), (2,1,2), (2,1,3)
50 //
51 // If the tensor is sparse, then the order of values in the array is arbitrary
52 // and can be described by the tensor SparsityParameters which describes the
53 // mapping from the value indices in tensor data to indices in the dense tensor
54 // flattened the way described above.
55 //
56 // The tensor data can be backed by different implementations depending on
57 // where the data comes from.
58 class TensorData {
59  public:
60   virtual ~TensorData() = default;
61 
62   // Returns the tensor data pointer (read-only, untyped).
63   virtual const void* data() const = 0;
64 
65   // Returns the overall size of the tensor data, in bytes.
66   virtual size_t byte_size() const = 0;
67 
68   // Validates TensorData constraints given the specified value_size.
69   // The value_size is the size of the native data type (e.g. 4 bytes for int32
70   // or float, 8 bytes for int64). This is used to verify data alignment - that
71   // all offsets and sizes are multiples of value_size and that pointers are
72   // memory aligned to the value_size. The outcome is reported as a Status.
73   // TODO(team): Consider separate sizes for the pointer alignment and
74   // the slice offsets/sizes. The latter may need to be more coarse.
75   Status CheckValid(size_t value_size) const;
76 };
77 
78 }  // namespace aggregation
79 }  // namespace fcp
80 
81 #endif  // FCP_AGGREGATION_CORE_TENSOR_DATA_H_
82