1 /*
2 * Copyright 2022 Google LLC
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #ifndef FCP_AGGREGATION_TESTING_TESTING_H_
18 #define FCP_AGGREGATION_TESTING_TESTING_H_
19
#include <algorithm>
#include <cstdint>
#include <initializer_list>
#include <ostream>
#include <string>
#include <utility>
#include <vector>

#include "gmock/gmock.h"
#include "gtest/gtest.h"
#include "fcp/aggregation/core/datatype.h"
#include "fcp/aggregation/core/tensor.h"
#include "fcp/aggregation/core/tensor_shape.h"
#include "tensorflow/cc/framework/ops.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.pb.h"
33
34 namespace fcp::aggregation {
35
36 namespace tf = ::tensorflow;
37
38 template <typename T>
CreateTfTensor(tf::DataType data_type,std::initializer_list<int64_t> dim_sizes,std::initializer_list<T> values)39 tf::Tensor CreateTfTensor(tf::DataType data_type,
40 std::initializer_list<int64_t> dim_sizes,
41 std::initializer_list<T> values) {
42 tf::TensorShape shape;
43 EXPECT_TRUE(tf::TensorShape::BuildTensorShape(dim_sizes, &shape).ok());
44 tf::Tensor tensor(data_type, shape);
45 T* tensor_data_ptr = reinterpret_cast<T*>(tensor.data());
46 for (auto value : values) {
47 *tensor_data_ptr++ = value;
48 }
49 return tensor;
50 }
51
// Creates a tf::Tensor with shape `dim_sizes` holding the string `values`.
// Defined out of line; presumably the dtype is tf::DT_STRING and values are
// stored in the order given — confirm against the definition.
tf::Tensor CreateStringTfTensor(std::initializer_list<int64_t> dim_sizes,
                                std::initializer_list<string_view> values);

// Wrapper around tf::ops::Save that sets up and runs the op.
// Presumably saves `tensors` under the names in `tensor_names` to the file
// named by `filename`, returning the status of running the op — confirm
// against the definition.
tf::Status CreateTfCheckpoint(tf::Input filename, tf::Input tensor_names,
                              tf::InputList tensors);

// Returns a summary of the checkpoint as a map of tensor names and values.
// NOTE(review): `checkpoint` is assumed to hold the serialized checkpoint
// contents, and each map value a string rendering of the tensor — verify
// against the definition.
absl::StatusOr<absl::flat_hash_map<std::string, std::string>>
SummarizeCheckpoint(const absl::Cord& checkpoint);
62
63 // Converts a potentially sparse tensor to a flat vector of tensor values.
64 template <typename T>
TensorValuesToVector(const Tensor & arg)65 std::vector<T> TensorValuesToVector(const Tensor& arg) {
66 std::vector<T> vec(arg.shape().NumElements());
67 AggVector<T> agg_vector = arg.AsAggVector<T>();
68 for (auto [i, v] : agg_vector) {
69 vec[i] = v;
70 }
71 return vec;
72 }
73
74 // Writes description of a tensor to the ostream.
75 template <typename T>
DescribeTensor(::std::ostream * os,DataType dtype,TensorShape shape,std::vector<T> values)76 void DescribeTensor(::std::ostream* os, DataType dtype, TensorShape shape,
77 std::vector<T> values) {
78 // Max number of tensor values to be printed.
79 constexpr int kMaxValues = 100;
80 // TODO(team): Print dtype name istead of number.
81 *os << "{dtype: " << dtype;
82 *os << ", shape: {";
83 bool insert_comma = false;
84 for (auto dim_size : shape.dim_sizes()) {
85 if (insert_comma) {
86 *os << ", ";
87 }
88 *os << dim_size;
89 insert_comma = true;
90 }
91 *os << "}, values: {";
92 int num_values = 0;
93 insert_comma = false;
94 for (auto v : values) {
95 if (++num_values > kMaxValues) {
96 *os << "...";
97 break;
98 }
99 if (insert_comma) {
100 *os << ", ";
101 }
102 *os << v;
103 insert_comma = true;
104 }
105 *os << "}}";
106 }
107
// Writes a description of `tensor` to `os` and returns the stream.
// Defined out of line; presumably it uses the same format as DescribeTensor.
// Confirm against the definition. gtest uses this operator when printing
// Tensor values in assertion failure messages.
std::ostream& operator<<(std::ostream& os, const Tensor& tensor);
110
111 // TensorMatcher implementation.
112 template <typename T>
113 class TensorMatcherImpl : public ::testing::MatcherInterface<const Tensor&> {
114 public:
TensorMatcherImpl(DataType expected_dtype,TensorShape expected_shape,std::vector<T> expected_values)115 TensorMatcherImpl(DataType expected_dtype, TensorShape expected_shape,
116 std::vector<T> expected_values)
117 : expected_dtype_(expected_dtype),
118 expected_shape_(expected_shape),
119 expected_values_(expected_values) {}
120
DescribeTo(std::ostream * os)121 void DescribeTo(std::ostream* os) const override {
122 DescribeTensor<T>(os, expected_dtype_, expected_shape_, expected_values_);
123 }
124
MatchAndExplain(const Tensor & arg,::testing::MatchResultListener * listener)125 bool MatchAndExplain(
126 const Tensor& arg,
127 ::testing::MatchResultListener* listener) const override {
128 return arg.dtype() == expected_dtype_ && arg.shape() == expected_shape_ &&
129 TensorValuesToVector<T>(arg) == expected_values_;
130 }
131
132 private:
133 DataType expected_dtype_;
134 TensorShape expected_shape_;
135 std::vector<T> expected_values_;
136 };
137
138 // TensorMatcher can be used to compare a tensor against an expected
139 // value type, shape, and the list of values.
140 template <typename T>
141 class TensorMatcher {
142 public:
TensorMatcher(DataType expected_dtype,TensorShape expected_shape,std::initializer_list<T> expected_values)143 explicit TensorMatcher(DataType expected_dtype, TensorShape expected_shape,
144 std::initializer_list<T> expected_values)
145 : expected_dtype_(expected_dtype),
146 expected_shape_(expected_shape),
147 expected_values_(expected_values.begin(), expected_values.end()) {}
148 // Intentionally allowed to be implicit.
149 operator ::testing::Matcher<const Tensor&>() const { // NOLINT
150 return ::testing::MakeMatcher(new TensorMatcherImpl<T>(
151 expected_dtype_, expected_shape_, expected_values_));
152 }
153
154 private:
155 DataType expected_dtype_;
156 TensorShape expected_shape_;
157 std::vector<T> expected_values_;
158 };
159
160 template <typename T>
IsTensor(TensorShape expected_shape,std::initializer_list<T> expected_values)161 TensorMatcher<T> IsTensor(TensorShape expected_shape,
162 std::initializer_list<T> expected_values) {
163 return TensorMatcher<T>(internal::TypeTraits<T>::kDataType, expected_shape,
164 expected_values);
165 }
166
167 } // namespace fcp::aggregation
168
169 #endif // FCP_AGGREGATION_TESTING_TESTING_H_
170