xref: /aosp_15_r20/external/tensorflow/tensorflow/core/kernels/data/experimental/sampling_dataset_op_test.cc (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5     http://www.apache.org/licenses/LICENSE-2.0
6 Unless required by applicable law or agreed to in writing, software
7 distributed under the License is distributed on an "AS IS" BASIS,
8 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 See the License for the specific language governing permissions and
10 limitations under the License.
11 ==============================================================================*/
12 #include "tensorflow/core/kernels/data/experimental/sampling_dataset_op.h"
13 
14 #include "tensorflow/core/data/dataset_test_base.h"
15 
16 namespace tensorflow {
17 namespace data {
18 namespace experimental {
19 namespace {
20 
// Node name given to every SamplingDataset constructed in this file.
constexpr char kNodeName[] = "sampling_dataset";

// Fixed values fed to the op's `seed` and `seed2` inputs so that the sampled
// subset is reproducible from run to run.
constexpr int64_t kRandomSeed{42};
constexpr int64_t kRandomSeed2{7};
24 
25 class SamplingDatasetParams : public DatasetParams {
26  public:
27   template <typename T>
SamplingDatasetParams(T input_dataset_params,float rate,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)28   SamplingDatasetParams(T input_dataset_params, float rate,
29                         DataTypeVector output_dtypes,
30                         std::vector<PartialTensorShape> output_shapes,
31                         string node_name)
32       : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
33                       std::move(node_name)),
34         rate_(rate) {
35     input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
36     iterator_prefix_ =
37         name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
38                                    input_dataset_params.iterator_prefix());
39   }
40 
GetInputTensors() const41   std::vector<Tensor> GetInputTensors() const override {
42     Tensor rate = CreateTensor<float>(TensorShape({}), {rate_});
43     Tensor seed_tensor = CreateTensor<int64_t>(TensorShape({}), {seed_tensor_});
44     Tensor seed2_tensor =
45         CreateTensor<int64_t>(TensorShape({}), {seed2_tensor_});
46     return {rate, seed_tensor, seed2_tensor};
47   }
48 
GetInputNames(std::vector<string> * input_names) const49   Status GetInputNames(std::vector<string>* input_names) const override {
50     *input_names = {SamplingDatasetOp::kInputDataset, SamplingDatasetOp::kRate,
51                     SamplingDatasetOp::kSeed, SamplingDatasetOp::kSeed2};
52 
53     return OkStatus();
54   }
55 
GetAttributes(AttributeVector * attr_vector) const56   Status GetAttributes(AttributeVector* attr_vector) const override {
57     *attr_vector = {{SamplingDatasetOp::kOutputTypes, output_dtypes_},
58                     {SamplingDatasetOp::kOutputShapes, output_shapes_}};
59     return OkStatus();
60   }
61 
dataset_type() const62   string dataset_type() const override {
63     return SamplingDatasetOp::kDatasetType;
64   }
65 
66  private:
67   // Target sample rate, range (0,1], wrapped in a scalar Tensor
68   float rate_;
69   // Boxed versions of kRandomSeed and kRandomSeed2.
70   int64_t seed_tensor_ = kRandomSeed;
71   int64_t seed2_tensor_ = kRandomSeed2;
72 };
73 
74 class SamplingDatasetOpTest : public DatasetOpsTestBase {};
75 
OneHundredPercentSampleParams()76 SamplingDatasetParams OneHundredPercentSampleParams() {
77   return SamplingDatasetParams(RangeDatasetParams(0, 3, 1),
78                                /*rate=*/1.0,
79                                /*output_dtypes=*/{DT_INT64},
80                                /*output_shapes=*/{PartialTensorShape({})},
81                                /*node_name=*/kNodeName);
82 }
83 
TenPercentSampleParams()84 SamplingDatasetParams TenPercentSampleParams() {
85   return SamplingDatasetParams(RangeDatasetParams(0, 20, 1),
86                                /*rate=*/0.1,
87                                /*output_dtypes=*/{DT_INT64},
88                                /*output_shapes=*/{PartialTensorShape({})},
89                                /*node_name=*/kNodeName);
90 }
91 
ZeroPercentSampleParams()92 SamplingDatasetParams ZeroPercentSampleParams() {
93   return SamplingDatasetParams(RangeDatasetParams(0, 20, 1),
94                                /*rate=*/0.0,
95                                /*output_dtypes=*/{DT_INT64},
96                                /*output_shapes=*/{PartialTensorShape({})},
97                                /*node_name=*/kNodeName);
98 }
99 
GetNextTestCases()100 std::vector<GetNextTestCase<SamplingDatasetParams>> GetNextTestCases() {
101   return {
102       // Test case 1: 100% sample should return all inputs
103       {/*dataset_params=*/OneHundredPercentSampleParams(),
104        /*expected_outputs=*/CreateTensors<int64_t>(TensorShape({}),
105                                                    {{0}, {1}, {2}})},
106 
107       // Test case 2: 10% sample should return about 10% of inputs, and the
108       // specific inputs returned shouldn't change across build environments.
109       {/*dataset_params=*/TenPercentSampleParams(),
110        /*expected_outputs=*/CreateTensors<int64_t>(TensorShape({}),
111                                                    {{9}, {11}, {19}})},
112 
113       // Test case 3: 0% sample should return nothing and should not crash.
114       {/*dataset_params=*/ZeroPercentSampleParams(), /*expected_outputs=*/{}}};
115 }
116 
ITERATOR_GET_NEXT_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,GetNextTestCases ())117 ITERATOR_GET_NEXT_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
118                          GetNextTestCases())
119 
120 std::vector<DatasetNodeNameTestCase<SamplingDatasetParams>>
121 DatasetNodeNameTestCases() {
122   return {{/*dataset_params=*/TenPercentSampleParams(),
123            /*expected_node_name=*/kNodeName}};
124 }
125 
DATASET_NODE_NAME_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,DatasetNodeNameTestCases ())126 DATASET_NODE_NAME_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
127                          DatasetNodeNameTestCases())
128 
129 std::vector<DatasetTypeStringTestCase<SamplingDatasetParams>>
130 DatasetTypeStringTestCases() {
131   return {{/*dataset_params=*/TenPercentSampleParams(),
132            /*expected_dataset_type_string=*/name_utils::OpName(
133                SamplingDatasetOp::kDatasetType)}};
134 }
135 
DATASET_TYPE_STRING_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,DatasetTypeStringTestCases ())136 DATASET_TYPE_STRING_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
137                            DatasetTypeStringTestCases())
138 
139 std::vector<DatasetOutputDtypesTestCase<SamplingDatasetParams>>
140 DatasetOutputDtypesTestCases() {
141   return {{/*dataset_params=*/TenPercentSampleParams(),
142            /*expected_output_dtypes=*/{DT_INT64}}};
143 }
144 
DATASET_OUTPUT_DTYPES_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,DatasetOutputDtypesTestCases ())145 DATASET_OUTPUT_DTYPES_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
146                              DatasetOutputDtypesTestCases())
147 
148 std::vector<DatasetOutputShapesTestCase<SamplingDatasetParams>>
149 DatasetOutputShapesTestCases() {
150   return {{/*dataset_params=*/TenPercentSampleParams(),
151            /*expected_output_shapes=*/{PartialTensorShape({})}}};
152 }
153 
DATASET_OUTPUT_SHAPES_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,DatasetOutputShapesTestCases ())154 DATASET_OUTPUT_SHAPES_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
155                              DatasetOutputShapesTestCases())
156 
157 std::vector<CardinalityTestCase<SamplingDatasetParams>> CardinalityTestCases() {
158   return {{/*dataset_params=*/OneHundredPercentSampleParams(),
159            /*expected_cardinality=*/kUnknownCardinality},
160           {/*dataset_params=*/TenPercentSampleParams(),
161            /*expected,cardinality=*/kUnknownCardinality},
162           {/*dataset_params=*/ZeroPercentSampleParams(),
163            /*expected_cardinality=*/kUnknownCardinality}};
164 }
165 
DATASET_CARDINALITY_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,CardinalityTestCases ())166 DATASET_CARDINALITY_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
167                            CardinalityTestCases())
168 
169 std::vector<IteratorOutputDtypesTestCase<SamplingDatasetParams>>
170 IteratorOutputDtypesTestCases() {
171   return {{/*dataset_params=*/TenPercentSampleParams(),
172            /*expected_output_dtypes=*/{DT_INT64}}};
173 }
174 
ITERATOR_OUTPUT_DTYPES_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,IteratorOutputDtypesTestCases ())175 ITERATOR_OUTPUT_DTYPES_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
176                               IteratorOutputDtypesTestCases())
177 
178 std::vector<IteratorOutputShapesTestCase<SamplingDatasetParams>>
179 IteratorOutputShapesTestCases() {
180   return {{/*dataset_params=*/TenPercentSampleParams(),
181            /*expected_output_shapes=*/{PartialTensorShape({})}}};
182 }
183 
ITERATOR_OUTPUT_SHAPES_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,IteratorOutputShapesTestCases ())184 ITERATOR_OUTPUT_SHAPES_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
185                               IteratorOutputShapesTestCases())
186 
187 std::vector<IteratorPrefixTestCase<SamplingDatasetParams>>
188 IteratorOutputPrefixTestCases() {
189   return {{/*dataset_params=*/TenPercentSampleParams(),
190            /*expected_iterator_prefix=*/name_utils::IteratorPrefix(
191                SamplingDatasetOp::kDatasetType,
192                TenPercentSampleParams().iterator_prefix())}};
193 }
194 
ITERATOR_PREFIX_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,IteratorOutputPrefixTestCases ())195 ITERATOR_PREFIX_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
196                        IteratorOutputPrefixTestCases())
197 
198 std::vector<IteratorSaveAndRestoreTestCase<SamplingDatasetParams>>
199 IteratorSaveAndRestoreTestCases() {
200   return {{/*dataset_params=*/OneHundredPercentSampleParams(),
201            /*breakpoints=*/{0, 2, 5},
202            /*expected_outputs=*/
203            CreateTensors<int64_t>(TensorShape({}), {{0}, {1}, {2}})},
204           {/*dataset_params=*/TenPercentSampleParams(),
205            /*breakpoints=*/{0, 2, 5},
206            /*expected_outputs=*/
207            CreateTensors<int64_t>(TensorShape({}), {{9}, {11}, {19}})},
208           {/*dataset_params=*/ZeroPercentSampleParams(),
209            /*breakpoints=*/{0, 2, 5},
210            /*expected_outputs=*/{}}};
211 }
212 
213 ITERATOR_SAVE_AND_RESTORE_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
214                                  IteratorSaveAndRestoreTestCases())
215 
216 }  // namespace
217 }  // namespace experimental
218 }  // namespace data
219 }  // namespace tensorflow
220