xref: /aosp_15_r20/external/federated-compute/fcp/aggregation/testing/testing.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1 /*
2  * Copyright 2022 Google LLC
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #ifndef FCP_AGGREGATION_TESTING_TESTING_H_
18 #define FCP_AGGREGATION_TESTING_TESTING_H_
19 
#include <initializer_list>
#include <ostream>
#include <string>
#include <utility>
#include <vector>
24 
25 #include "gmock/gmock.h"
26 #include "gtest/gtest.h"
27 #include "fcp/aggregation/core/datatype.h"
28 #include "fcp/aggregation/core/tensor.h"
29 #include "fcp/aggregation/core/tensor_shape.h"
30 #include "tensorflow/cc/framework/ops.h"
31 #include "tensorflow/core/framework/tensor_shape.h"
32 #include "tensorflow/core/framework/types.pb.h"
33 
34 namespace fcp::aggregation {
35 
36 namespace tf = ::tensorflow;
37 
38 template <typename T>
CreateTfTensor(tf::DataType data_type,std::initializer_list<int64_t> dim_sizes,std::initializer_list<T> values)39 tf::Tensor CreateTfTensor(tf::DataType data_type,
40                           std::initializer_list<int64_t> dim_sizes,
41                           std::initializer_list<T> values) {
42   tf::TensorShape shape;
43   EXPECT_TRUE(tf::TensorShape::BuildTensorShape(dim_sizes, &shape).ok());
44   tf::Tensor tensor(data_type, shape);
45   T* tensor_data_ptr = reinterpret_cast<T*>(tensor.data());
46   for (auto value : values) {
47     *tensor_data_ptr++ = value;
48   }
49   return tensor;
50 }
51 
52 tf::Tensor CreateStringTfTensor(std::initializer_list<int64_t> dim_sizes,
53                                 std::initializer_list<string_view> values);
54 
55 // Wrapper around tf::ops::Save that sets up and runs the op.
56 tf::Status CreateTfCheckpoint(tf::Input filename, tf::Input tensor_names,
57                               tf::InputList tensors);
58 
59 // Returns a summary of the checkpoint as a map of tensor names and values.
60 absl::StatusOr<absl::flat_hash_map<std::string, std::string>>
61 SummarizeCheckpoint(const absl::Cord& checkpoint);
62 
63 // Converts a potentially sparse tensor to a flat vector of tensor values.
64 template <typename T>
TensorValuesToVector(const Tensor & arg)65 std::vector<T> TensorValuesToVector(const Tensor& arg) {
66   std::vector<T> vec(arg.shape().NumElements());
67   AggVector<T> agg_vector = arg.AsAggVector<T>();
68   for (auto [i, v] : agg_vector) {
69     vec[i] = v;
70   }
71   return vec;
72 }
73 
74 // Writes description of a tensor to the ostream.
75 template <typename T>
DescribeTensor(::std::ostream * os,DataType dtype,TensorShape shape,std::vector<T> values)76 void DescribeTensor(::std::ostream* os, DataType dtype, TensorShape shape,
77                     std::vector<T> values) {
78   // Max number of tensor values to be printed.
79   constexpr int kMaxValues = 100;
80   // TODO(team): Print dtype name istead of number.
81   *os << "{dtype: " << dtype;
82   *os << ", shape: {";
83   bool insert_comma = false;
84   for (auto dim_size : shape.dim_sizes()) {
85     if (insert_comma) {
86       *os << ", ";
87     }
88     *os << dim_size;
89     insert_comma = true;
90   }
91   *os << "}, values: {";
92   int num_values = 0;
93   insert_comma = false;
94   for (auto v : values) {
95     if (++num_values > kMaxValues) {
96       *os << "...";
97       break;
98     }
99     if (insert_comma) {
100       *os << ", ";
101     }
102     *os << v;
103     insert_comma = true;
104   }
105   *os << "}}";
106 }
107 
108 // Writes description of a tensor to the ostream.
109 std::ostream& operator<<(std::ostream& os, const Tensor& tensor);
110 
111 // TensorMatcher implementation.
112 template <typename T>
113 class TensorMatcherImpl : public ::testing::MatcherInterface<const Tensor&> {
114  public:
TensorMatcherImpl(DataType expected_dtype,TensorShape expected_shape,std::vector<T> expected_values)115   TensorMatcherImpl(DataType expected_dtype, TensorShape expected_shape,
116                     std::vector<T> expected_values)
117       : expected_dtype_(expected_dtype),
118         expected_shape_(expected_shape),
119         expected_values_(expected_values) {}
120 
DescribeTo(std::ostream * os)121   void DescribeTo(std::ostream* os) const override {
122     DescribeTensor<T>(os, expected_dtype_, expected_shape_, expected_values_);
123   }
124 
MatchAndExplain(const Tensor & arg,::testing::MatchResultListener * listener)125   bool MatchAndExplain(
126       const Tensor& arg,
127       ::testing::MatchResultListener* listener) const override {
128     return arg.dtype() == expected_dtype_ && arg.shape() == expected_shape_ &&
129            TensorValuesToVector<T>(arg) == expected_values_;
130   }
131 
132  private:
133   DataType expected_dtype_;
134   TensorShape expected_shape_;
135   std::vector<T> expected_values_;
136 };
137 
138 // TensorMatcher can be used to compare a tensor against an expected
139 // value type, shape, and the list of values.
140 template <typename T>
141 class TensorMatcher {
142  public:
TensorMatcher(DataType expected_dtype,TensorShape expected_shape,std::initializer_list<T> expected_values)143   explicit TensorMatcher(DataType expected_dtype, TensorShape expected_shape,
144                          std::initializer_list<T> expected_values)
145       : expected_dtype_(expected_dtype),
146         expected_shape_(expected_shape),
147         expected_values_(expected_values.begin(), expected_values.end()) {}
148   // Intentionally allowed to be implicit.
149   operator ::testing::Matcher<const Tensor&>() const {  // NOLINT
150     return ::testing::MakeMatcher(new TensorMatcherImpl<T>(
151         expected_dtype_, expected_shape_, expected_values_));
152   }
153 
154  private:
155   DataType expected_dtype_;
156   TensorShape expected_shape_;
157   std::vector<T> expected_values_;
158 };
159 
160 template <typename T>
IsTensor(TensorShape expected_shape,std::initializer_list<T> expected_values)161 TensorMatcher<T> IsTensor(TensorShape expected_shape,
162                           std::initializer_list<T> expected_values) {
163   return TensorMatcher<T>(internal::TypeTraits<T>::kDataType, expected_shape,
164                           expected_values);
165 }
166 
167 }  // namespace fcp::aggregation
168 
169 #endif  // FCP_AGGREGATION_TESTING_TESTING_H_
170