1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 Licensed under the Apache License, Version 2.0 (the "License");
3 you may not use this file except in compliance with the License.
4 You may obtain a copy of the License at
5 http://www.apache.org/licenses/LICENSE-2.0
6 Unless required by applicable law or agreed to in writing, software
7 distributed under the License is distributed on an "AS IS" BASIS,
8 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9 See the License for the specific language governing permissions and
10 limitations under the License.
11 ==============================================================================*/
12 #include "tensorflow/core/kernels/data/experimental/sampling_dataset_op.h"
13
14 #include "tensorflow/core/data/dataset_test_base.h"
15
16 namespace tensorflow {
17 namespace data {
18 namespace experimental {
19 namespace {
20
// Node name assigned to every SamplingDataset constructed by these tests.
constexpr char kNodeName[] = "sampling_dataset";
// Fixed values fed to the op's `seed` and `seed2` inputs. Keeping them
// constant is what lets the tests pin down exactly which elements are sampled.
constexpr int64_t kRandomSeed{42};
constexpr int64_t kRandomSeed2{7};
24
25 class SamplingDatasetParams : public DatasetParams {
26 public:
27 template <typename T>
SamplingDatasetParams(T input_dataset_params,float rate,DataTypeVector output_dtypes,std::vector<PartialTensorShape> output_shapes,string node_name)28 SamplingDatasetParams(T input_dataset_params, float rate,
29 DataTypeVector output_dtypes,
30 std::vector<PartialTensorShape> output_shapes,
31 string node_name)
32 : DatasetParams(std::move(output_dtypes), std::move(output_shapes),
33 std::move(node_name)),
34 rate_(rate) {
35 input_dataset_params_.push_back(std::make_unique<T>(input_dataset_params));
36 iterator_prefix_ =
37 name_utils::IteratorPrefix(input_dataset_params.dataset_type(),
38 input_dataset_params.iterator_prefix());
39 }
40
GetInputTensors() const41 std::vector<Tensor> GetInputTensors() const override {
42 Tensor rate = CreateTensor<float>(TensorShape({}), {rate_});
43 Tensor seed_tensor = CreateTensor<int64_t>(TensorShape({}), {seed_tensor_});
44 Tensor seed2_tensor =
45 CreateTensor<int64_t>(TensorShape({}), {seed2_tensor_});
46 return {rate, seed_tensor, seed2_tensor};
47 }
48
GetInputNames(std::vector<string> * input_names) const49 Status GetInputNames(std::vector<string>* input_names) const override {
50 *input_names = {SamplingDatasetOp::kInputDataset, SamplingDatasetOp::kRate,
51 SamplingDatasetOp::kSeed, SamplingDatasetOp::kSeed2};
52
53 return OkStatus();
54 }
55
GetAttributes(AttributeVector * attr_vector) const56 Status GetAttributes(AttributeVector* attr_vector) const override {
57 *attr_vector = {{SamplingDatasetOp::kOutputTypes, output_dtypes_},
58 {SamplingDatasetOp::kOutputShapes, output_shapes_}};
59 return OkStatus();
60 }
61
dataset_type() const62 string dataset_type() const override {
63 return SamplingDatasetOp::kDatasetType;
64 }
65
66 private:
67 // Target sample rate, range (0,1], wrapped in a scalar Tensor
68 float rate_;
69 // Boxed versions of kRandomSeed and kRandomSeed2.
70 int64_t seed_tensor_ = kRandomSeed;
71 int64_t seed2_tensor_ = kRandomSeed2;
72 };
73
74 class SamplingDatasetOpTest : public DatasetOpsTestBase {};
75
OneHundredPercentSampleParams()76 SamplingDatasetParams OneHundredPercentSampleParams() {
77 return SamplingDatasetParams(RangeDatasetParams(0, 3, 1),
78 /*rate=*/1.0,
79 /*output_dtypes=*/{DT_INT64},
80 /*output_shapes=*/{PartialTensorShape({})},
81 /*node_name=*/kNodeName);
82 }
83
TenPercentSampleParams()84 SamplingDatasetParams TenPercentSampleParams() {
85 return SamplingDatasetParams(RangeDatasetParams(0, 20, 1),
86 /*rate=*/0.1,
87 /*output_dtypes=*/{DT_INT64},
88 /*output_shapes=*/{PartialTensorShape({})},
89 /*node_name=*/kNodeName);
90 }
91
ZeroPercentSampleParams()92 SamplingDatasetParams ZeroPercentSampleParams() {
93 return SamplingDatasetParams(RangeDatasetParams(0, 20, 1),
94 /*rate=*/0.0,
95 /*output_dtypes=*/{DT_INT64},
96 /*output_shapes=*/{PartialTensorShape({})},
97 /*node_name=*/kNodeName);
98 }
99
GetNextTestCases()100 std::vector<GetNextTestCase<SamplingDatasetParams>> GetNextTestCases() {
101 return {
102 // Test case 1: 100% sample should return all inputs
103 {/*dataset_params=*/OneHundredPercentSampleParams(),
104 /*expected_outputs=*/CreateTensors<int64_t>(TensorShape({}),
105 {{0}, {1}, {2}})},
106
107 // Test case 2: 10% sample should return about 10% of inputs, and the
108 // specific inputs returned shouldn't change across build environments.
109 {/*dataset_params=*/TenPercentSampleParams(),
110 /*expected_outputs=*/CreateTensors<int64_t>(TensorShape({}),
111 {{9}, {11}, {19}})},
112
113 // Test case 3: 0% sample should return nothing and should not crash.
114 {/*dataset_params=*/ZeroPercentSampleParams(), /*expected_outputs=*/{}}};
115 }
116
ITERATOR_GET_NEXT_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,GetNextTestCases ())117 ITERATOR_GET_NEXT_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
118 GetNextTestCases())
119
120 std::vector<DatasetNodeNameTestCase<SamplingDatasetParams>>
121 DatasetNodeNameTestCases() {
122 return {{/*dataset_params=*/TenPercentSampleParams(),
123 /*expected_node_name=*/kNodeName}};
124 }
125
DATASET_NODE_NAME_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,DatasetNodeNameTestCases ())126 DATASET_NODE_NAME_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
127 DatasetNodeNameTestCases())
128
129 std::vector<DatasetTypeStringTestCase<SamplingDatasetParams>>
130 DatasetTypeStringTestCases() {
131 return {{/*dataset_params=*/TenPercentSampleParams(),
132 /*expected_dataset_type_string=*/name_utils::OpName(
133 SamplingDatasetOp::kDatasetType)}};
134 }
135
DATASET_TYPE_STRING_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,DatasetTypeStringTestCases ())136 DATASET_TYPE_STRING_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
137 DatasetTypeStringTestCases())
138
139 std::vector<DatasetOutputDtypesTestCase<SamplingDatasetParams>>
140 DatasetOutputDtypesTestCases() {
141 return {{/*dataset_params=*/TenPercentSampleParams(),
142 /*expected_output_dtypes=*/{DT_INT64}}};
143 }
144
DATASET_OUTPUT_DTYPES_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,DatasetOutputDtypesTestCases ())145 DATASET_OUTPUT_DTYPES_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
146 DatasetOutputDtypesTestCases())
147
148 std::vector<DatasetOutputShapesTestCase<SamplingDatasetParams>>
149 DatasetOutputShapesTestCases() {
150 return {{/*dataset_params=*/TenPercentSampleParams(),
151 /*expected_output_shapes=*/{PartialTensorShape({})}}};
152 }
153
DATASET_OUTPUT_SHAPES_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,DatasetOutputShapesTestCases ())154 DATASET_OUTPUT_SHAPES_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
155 DatasetOutputShapesTestCases())
156
157 std::vector<CardinalityTestCase<SamplingDatasetParams>> CardinalityTestCases() {
158 return {{/*dataset_params=*/OneHundredPercentSampleParams(),
159 /*expected_cardinality=*/kUnknownCardinality},
160 {/*dataset_params=*/TenPercentSampleParams(),
161 /*expected,cardinality=*/kUnknownCardinality},
162 {/*dataset_params=*/ZeroPercentSampleParams(),
163 /*expected_cardinality=*/kUnknownCardinality}};
164 }
165
DATASET_CARDINALITY_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,CardinalityTestCases ())166 DATASET_CARDINALITY_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
167 CardinalityTestCases())
168
169 std::vector<IteratorOutputDtypesTestCase<SamplingDatasetParams>>
170 IteratorOutputDtypesTestCases() {
171 return {{/*dataset_params=*/TenPercentSampleParams(),
172 /*expected_output_dtypes=*/{DT_INT64}}};
173 }
174
ITERATOR_OUTPUT_DTYPES_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,IteratorOutputDtypesTestCases ())175 ITERATOR_OUTPUT_DTYPES_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
176 IteratorOutputDtypesTestCases())
177
178 std::vector<IteratorOutputShapesTestCase<SamplingDatasetParams>>
179 IteratorOutputShapesTestCases() {
180 return {{/*dataset_params=*/TenPercentSampleParams(),
181 /*expected_output_shapes=*/{PartialTensorShape({})}}};
182 }
183
ITERATOR_OUTPUT_SHAPES_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,IteratorOutputShapesTestCases ())184 ITERATOR_OUTPUT_SHAPES_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
185 IteratorOutputShapesTestCases())
186
187 std::vector<IteratorPrefixTestCase<SamplingDatasetParams>>
188 IteratorOutputPrefixTestCases() {
189 return {{/*dataset_params=*/TenPercentSampleParams(),
190 /*expected_iterator_prefix=*/name_utils::IteratorPrefix(
191 SamplingDatasetOp::kDatasetType,
192 TenPercentSampleParams().iterator_prefix())}};
193 }
194
ITERATOR_PREFIX_TEST_P(SamplingDatasetOpTest,SamplingDatasetParams,IteratorOutputPrefixTestCases ())195 ITERATOR_PREFIX_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
196 IteratorOutputPrefixTestCases())
197
198 std::vector<IteratorSaveAndRestoreTestCase<SamplingDatasetParams>>
199 IteratorSaveAndRestoreTestCases() {
200 return {{/*dataset_params=*/OneHundredPercentSampleParams(),
201 /*breakpoints=*/{0, 2, 5},
202 /*expected_outputs=*/
203 CreateTensors<int64_t>(TensorShape({}), {{0}, {1}, {2}})},
204 {/*dataset_params=*/TenPercentSampleParams(),
205 /*breakpoints=*/{0, 2, 5},
206 /*expected_outputs=*/
207 CreateTensors<int64_t>(TensorShape({}), {{9}, {11}, {19}})},
208 {/*dataset_params=*/ZeroPercentSampleParams(),
209 /*breakpoints=*/{0, 2, 5},
210 /*expected_outputs=*/{}}};
211 }
212
213 ITERATOR_SAVE_AND_RESTORE_TEST_P(SamplingDatasetOpTest, SamplingDatasetParams,
214 IteratorSaveAndRestoreTestCases())
215
216 } // namespace
217 } // namespace experimental
218 } // namespace data
219 } // namespace tensorflow
220