1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker *
4*993b0882SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker *
8*993b0882SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker *
10*993b0882SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker */
16*993b0882SAndroid Build Coastguard Worker
17*993b0882SAndroid Build Coastguard Worker #include "utils/tflite/dist_diversification.h"
18*993b0882SAndroid Build Coastguard Worker
19*993b0882SAndroid Build Coastguard Worker #include <algorithm>
20*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/context.h"
21*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/kernels/kernel_util.h"
22*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/model.h"
23*993b0882SAndroid Build Coastguard Worker
24*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
25*993b0882SAndroid Build Coastguard Worker namespace {
26*993b0882SAndroid Build Coastguard Worker
// Greedily selects mutually distant rows of a square distance matrix.
//
// Candidate row indices are scanned in increasing order (0, 1, ...,
// `matrix_size - 1`). A candidate is kept unless its distance to some
// already-kept index is strictly less than `min_distance`. As a result, every
// kept pair is at least `min_distance` apart. Index 0 is always kept when
// there is anything to select. Scanning stops once `max_num_results` indices
// have been kept.
//
// `distance_matrix` is any callable `(int row, int col) -> float-like`. It is
// only queried as `distance_matrix(candidate, selected)` with
// candidate > selected.
//
// Returns the kept indices in increasing order. Returns an empty vector when
// `matrix_size` or `max_num_results` is not positive.
template <typename DistanceMatrixType>
std::vector<int> DiversifyByDistance(const DistanceMatrixType& distance_matrix,
                                     const int matrix_size,
                                     const float min_distance,
                                     const int max_num_results) {
  std::vector<int> result;
  // An empty matrix has no valid index (not even 0). A non-positive budget
  // leaves no room in the caller's output. A negative budget would also
  // wrap around in reserve().
  if (matrix_size <= 0 || max_num_results <= 0) {
    return result;
  }
  result.reserve(std::min(matrix_size, max_num_results));
  const auto max_results = static_cast<std::size_t>(max_num_results);
  for (int candidate = 0;
       candidate < matrix_size && result.size() < max_results; ++candidate) {
    // Written as !(d < min) rather than d >= min so that a NaN distance does
    // not reject the candidate.
    const bool far_from_all_selected = std::all_of(
        result.begin(), result.end(), [&](const int selected) {
          return !(distance_matrix(candidate, selected) < min_distance);
        });
    if (far_from_all_selected) {
      result.push_back(candidate);
    }
  }
  return result;
}
56*993b0882SAndroid Build Coastguard Worker
// Positions of the op's input tensors within node->inputs.
enum DistDiversificationInputs {
  DIST_DIVERSIFICATION_INPUT_DISTANCE_MATRIX = 0,  // float matrix, row-major.
  DIST_DIVERSIFICATION_INPUT_MIN_DISTANCE,         // float; element 0 is read.
  DIST_DIVERSIFICATION_INPUT_NUM_RESULTS,          // int32; element 0 is read.
};
63*993b0882SAndroid Build Coastguard Worker
// Positions of the op's output tensors within node->outputs.
enum DistDiversificationOutputs {
  // int32 vector of selected row indices; unused slots are filled with -1.
  DIST_DIVERSIFICATION_OUTPUT_INDICES = 0,
  // int32 holding how many entries of the indices output are valid.
  DIST_DIVERSIFICATION_OUTPUT_LENGTH
};
69*993b0882SAndroid Build Coastguard Worker
CreateSizeArray(const std::initializer_list<int> & sizes)70*993b0882SAndroid Build Coastguard Worker TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
71*993b0882SAndroid Build Coastguard Worker TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size());
72*993b0882SAndroid Build Coastguard Worker int index = 0;
73*993b0882SAndroid Build Coastguard Worker for (const int size : sizes) {
74*993b0882SAndroid Build Coastguard Worker array_size->data[index++] = size;
75*993b0882SAndroid Build Coastguard Worker }
76*993b0882SAndroid Build Coastguard Worker return array_size;
77*993b0882SAndroid Build Coastguard Worker }
78*993b0882SAndroid Build Coastguard Worker
AllocateOutputIndexes(TfLiteContext * context,TfLiteNode * node)79*993b0882SAndroid Build Coastguard Worker TfLiteStatus AllocateOutputIndexes(TfLiteContext* context, TfLiteNode* node) {
80*993b0882SAndroid Build Coastguard Worker const TfLiteTensor& num_results =
81*993b0882SAndroid Build Coastguard Worker context
82*993b0882SAndroid Build Coastguard Worker ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]];
83*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_indices =
84*993b0882SAndroid Build Coastguard Worker context
85*993b0882SAndroid Build Coastguard Worker ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
86*993b0882SAndroid Build Coastguard Worker return context->ResizeTensor(context, &output_indices,
87*993b0882SAndroid Build Coastguard Worker CreateSizeArray({num_results.data.i32[0]}));
88*993b0882SAndroid Build Coastguard Worker }
89*993b0882SAndroid Build Coastguard Worker
Prepare(TfLiteContext * context,TfLiteNode * node)90*993b0882SAndroid Build Coastguard Worker TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
91*993b0882SAndroid Build Coastguard Worker const TfLiteTensor& num_results =
92*993b0882SAndroid Build Coastguard Worker context
93*993b0882SAndroid Build Coastguard Worker ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]];
94*993b0882SAndroid Build Coastguard Worker if (tflite::IsConstantTensor(&num_results)) {
95*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_OK(context, AllocateOutputIndexes(context, node));
96*993b0882SAndroid Build Coastguard Worker } else {
97*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_indices =
98*993b0882SAndroid Build Coastguard Worker context
99*993b0882SAndroid Build Coastguard Worker ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
100*993b0882SAndroid Build Coastguard Worker tflite::SetTensorToDynamic(&output_indices);
101*993b0882SAndroid Build Coastguard Worker }
102*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_length =
103*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_LENGTH]];
104*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, &output_length,
105*993b0882SAndroid Build Coastguard Worker CreateSizeArray({1})));
106*993b0882SAndroid Build Coastguard Worker return kTfLiteOk;
107*993b0882SAndroid Build Coastguard Worker }
108*993b0882SAndroid Build Coastguard Worker
Eval(TfLiteContext * context,TfLiteNode * node)109*993b0882SAndroid Build Coastguard Worker TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
110*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_indices =
111*993b0882SAndroid Build Coastguard Worker context
112*993b0882SAndroid Build Coastguard Worker ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
113*993b0882SAndroid Build Coastguard Worker if (tflite::IsDynamicTensor(&output_indices)) {
114*993b0882SAndroid Build Coastguard Worker TF_LITE_ENSURE_OK(context, AllocateOutputIndexes(context, node));
115*993b0882SAndroid Build Coastguard Worker }
116*993b0882SAndroid Build Coastguard Worker const TfLiteTensor& distance_matrix =
117*993b0882SAndroid Build Coastguard Worker context->tensors[node->inputs
118*993b0882SAndroid Build Coastguard Worker ->data[DIST_DIVERSIFICATION_INPUT_DISTANCE_MATRIX]];
119*993b0882SAndroid Build Coastguard Worker const int distance_matrix_dim = distance_matrix.dims->data[0];
120*993b0882SAndroid Build Coastguard Worker const float min_distance =
121*993b0882SAndroid Build Coastguard Worker context
122*993b0882SAndroid Build Coastguard Worker ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_MIN_DISTANCE]]
123*993b0882SAndroid Build Coastguard Worker .data.f[0];
124*993b0882SAndroid Build Coastguard Worker const int num_results =
125*993b0882SAndroid Build Coastguard Worker context
126*993b0882SAndroid Build Coastguard Worker ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]]
127*993b0882SAndroid Build Coastguard Worker .data.i32[0];
128*993b0882SAndroid Build Coastguard Worker const auto indices = DiversifyByDistance(
129*993b0882SAndroid Build Coastguard Worker [&](int row, int col) {
130*993b0882SAndroid Build Coastguard Worker return distance_matrix.data.f[row * distance_matrix_dim + col];
131*993b0882SAndroid Build Coastguard Worker },
132*993b0882SAndroid Build Coastguard Worker distance_matrix_dim, min_distance, num_results);
133*993b0882SAndroid Build Coastguard Worker std::copy(indices.begin(), indices.end(), output_indices.data.i32);
134*993b0882SAndroid Build Coastguard Worker std::fill_n(output_indices.data.i32 + indices.size(),
135*993b0882SAndroid Build Coastguard Worker num_results - indices.size(), -1);
136*993b0882SAndroid Build Coastguard Worker TfLiteTensor& output_length =
137*993b0882SAndroid Build Coastguard Worker context->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_LENGTH]];
138*993b0882SAndroid Build Coastguard Worker *output_length.data.i32 = indices.size();
139*993b0882SAndroid Build Coastguard Worker return kTfLiteOk;
140*993b0882SAndroid Build Coastguard Worker }
141*993b0882SAndroid Build Coastguard Worker
142*993b0882SAndroid Build Coastguard Worker } // namespace
143*993b0882SAndroid Build Coastguard Worker } // namespace libtextclassifier3
144*993b0882SAndroid Build Coastguard Worker
145*993b0882SAndroid Build Coastguard Worker namespace tflite {
146*993b0882SAndroid Build Coastguard Worker namespace ops {
147*993b0882SAndroid Build Coastguard Worker namespace custom {
Register_DISTANCE_DIVERSIFICATION()148*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_DISTANCE_DIVERSIFICATION() {
149*993b0882SAndroid Build Coastguard Worker static TfLiteRegistration r = {nullptr, nullptr, libtextclassifier3::Prepare,
150*993b0882SAndroid Build Coastguard Worker libtextclassifier3::Eval};
151*993b0882SAndroid Build Coastguard Worker return &r;
152*993b0882SAndroid Build Coastguard Worker }
153*993b0882SAndroid Build Coastguard Worker } // namespace custom
154*993b0882SAndroid Build Coastguard Worker } // namespace ops
155*993b0882SAndroid Build Coastguard Worker } // namespace tflite
156