xref: /aosp_15_r20/external/libtextclassifier/native/utils/tflite/dist_diversification.cc (revision 993b0882672172b81d12fad7a7ac0c3e5c824a12)
1*993b0882SAndroid Build Coastguard Worker /*
2*993b0882SAndroid Build Coastguard Worker  * Copyright (C) 2018 The Android Open Source Project
3*993b0882SAndroid Build Coastguard Worker  *
4*993b0882SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*993b0882SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*993b0882SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*993b0882SAndroid Build Coastguard Worker  *
8*993b0882SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*993b0882SAndroid Build Coastguard Worker  *
10*993b0882SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*993b0882SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*993b0882SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*993b0882SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*993b0882SAndroid Build Coastguard Worker  * limitations under the License.
15*993b0882SAndroid Build Coastguard Worker  */
16*993b0882SAndroid Build Coastguard Worker 
17*993b0882SAndroid Build Coastguard Worker #include "utils/tflite/dist_diversification.h"
18*993b0882SAndroid Build Coastguard Worker 
19*993b0882SAndroid Build Coastguard Worker #include <algorithm>
20*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/context.h"
21*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/kernels/kernel_util.h"
22*993b0882SAndroid Build Coastguard Worker #include "tensorflow/lite/model.h"
23*993b0882SAndroid Build Coastguard Worker 
24*993b0882SAndroid Build Coastguard Worker namespace libtextclassifier3 {
25*993b0882SAndroid Build Coastguard Worker namespace {
26*993b0882SAndroid Build Coastguard Worker 
// Greedily selects row indices of a distance matrix so that every selected
// row is at least `min_distance` away from every previously selected row.
//
// Row 0 is always selected first (when the matrix is non-empty). Each later
// row `i` is then examined in increasing order and appended iff
// `distance_matrix(i, j) >= min_distance` for every already-selected `j`.
// Selection stops once `max_num_results` indices have been chosen or all
// `matrix_size` rows have been examined, so the returned indices are strictly
// increasing.
//
// `distance_matrix` is any callable taking `(int row, int col)` and returning
// a value comparable with `min_distance`.
//
// Returns an empty vector when `matrix_size <= 0` or `max_num_results <= 0`.
// The result therefore never holds an out-of-range index, and never holds more
// than `max_num_results` elements (callers size their output buffer by it).
template <typename DistanceMatrixType>
std::vector<int> DiversifyByDistance(const DistanceMatrixType& distance_matrix,
                                     const int matrix_size,
                                     const float min_distance,
                                     const int max_num_results) {
  std::vector<int> result;
  if (matrix_size <= 0 || max_num_results <= 0) {
    return result;
  }
  result.reserve(std::min(matrix_size, max_num_results));
  result.push_back(0);
  for (int index = 1; index < matrix_size &&
                      static_cast<int>(result.size()) < max_num_results;
       ++index) {
    const bool too_close = std::any_of(
        result.begin(), result.end(), [&](const int selected_index) {
          return distance_matrix(index, selected_index) < min_distance;
        });
    if (!too_close) {
      result.push_back(index);
    }
  }
  return result;
}
56*993b0882SAndroid Build Coastguard Worker 
// Positions of the op's inputs within `node->inputs`.
enum DistDiversificationInputs {
  // Float distance matrix, indexed row-major as `row * dims[0] + col`.
  DIST_DIVERSIFICATION_INPUT_DISTANCE_MATRIX = 0,
  // Float threshold: rows closer than this to a selected row are skipped.
  DIST_DIVERSIFICATION_INPUT_MIN_DISTANCE,
  // Int32 count of indices to return; also the length of the index output.
  DIST_DIVERSIFICATION_INPUT_NUM_RESULTS,
};
63*993b0882SAndroid Build Coastguard Worker 
// Positions of the op's outputs within `node->outputs`.
enum DistDiversificationOutputs {
  // Int32 vector of selected row indices, padded with -1 up to num_results.
  DIST_DIVERSIFICATION_OUTPUT_INDICES = 0,
  // Single-element int32 tensor holding the number of selected indices.
  DIST_DIVERSIFICATION_OUTPUT_LENGTH,
};
69*993b0882SAndroid Build Coastguard Worker 
CreateSizeArray(const std::initializer_list<int> & sizes)70*993b0882SAndroid Build Coastguard Worker TfLiteIntArray* CreateSizeArray(const std::initializer_list<int>& sizes) {
71*993b0882SAndroid Build Coastguard Worker   TfLiteIntArray* array_size = TfLiteIntArrayCreate(sizes.size());
72*993b0882SAndroid Build Coastguard Worker   int index = 0;
73*993b0882SAndroid Build Coastguard Worker   for (const int size : sizes) {
74*993b0882SAndroid Build Coastguard Worker     array_size->data[index++] = size;
75*993b0882SAndroid Build Coastguard Worker   }
76*993b0882SAndroid Build Coastguard Worker   return array_size;
77*993b0882SAndroid Build Coastguard Worker }
78*993b0882SAndroid Build Coastguard Worker 
AllocateOutputIndexes(TfLiteContext * context,TfLiteNode * node)79*993b0882SAndroid Build Coastguard Worker TfLiteStatus AllocateOutputIndexes(TfLiteContext* context, TfLiteNode* node) {
80*993b0882SAndroid Build Coastguard Worker   const TfLiteTensor& num_results =
81*993b0882SAndroid Build Coastguard Worker       context
82*993b0882SAndroid Build Coastguard Worker           ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]];
83*993b0882SAndroid Build Coastguard Worker   TfLiteTensor& output_indices =
84*993b0882SAndroid Build Coastguard Worker       context
85*993b0882SAndroid Build Coastguard Worker           ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
86*993b0882SAndroid Build Coastguard Worker   return context->ResizeTensor(context, &output_indices,
87*993b0882SAndroid Build Coastguard Worker                                CreateSizeArray({num_results.data.i32[0]}));
88*993b0882SAndroid Build Coastguard Worker }
89*993b0882SAndroid Build Coastguard Worker 
Prepare(TfLiteContext * context,TfLiteNode * node)90*993b0882SAndroid Build Coastguard Worker TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) {
91*993b0882SAndroid Build Coastguard Worker   const TfLiteTensor& num_results =
92*993b0882SAndroid Build Coastguard Worker       context
93*993b0882SAndroid Build Coastguard Worker           ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]];
94*993b0882SAndroid Build Coastguard Worker   if (tflite::IsConstantTensor(&num_results)) {
95*993b0882SAndroid Build Coastguard Worker     TF_LITE_ENSURE_OK(context, AllocateOutputIndexes(context, node));
96*993b0882SAndroid Build Coastguard Worker   } else {
97*993b0882SAndroid Build Coastguard Worker     TfLiteTensor& output_indices =
98*993b0882SAndroid Build Coastguard Worker         context
99*993b0882SAndroid Build Coastguard Worker             ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
100*993b0882SAndroid Build Coastguard Worker     tflite::SetTensorToDynamic(&output_indices);
101*993b0882SAndroid Build Coastguard Worker   }
102*993b0882SAndroid Build Coastguard Worker   TfLiteTensor& output_length =
103*993b0882SAndroid Build Coastguard Worker       context->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_LENGTH]];
104*993b0882SAndroid Build Coastguard Worker   TF_LITE_ENSURE_OK(context, context->ResizeTensor(context, &output_length,
105*993b0882SAndroid Build Coastguard Worker                                                    CreateSizeArray({1})));
106*993b0882SAndroid Build Coastguard Worker   return kTfLiteOk;
107*993b0882SAndroid Build Coastguard Worker }
108*993b0882SAndroid Build Coastguard Worker 
Eval(TfLiteContext * context,TfLiteNode * node)109*993b0882SAndroid Build Coastguard Worker TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
110*993b0882SAndroid Build Coastguard Worker   TfLiteTensor& output_indices =
111*993b0882SAndroid Build Coastguard Worker       context
112*993b0882SAndroid Build Coastguard Worker           ->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_INDICES]];
113*993b0882SAndroid Build Coastguard Worker   if (tflite::IsDynamicTensor(&output_indices)) {
114*993b0882SAndroid Build Coastguard Worker     TF_LITE_ENSURE_OK(context, AllocateOutputIndexes(context, node));
115*993b0882SAndroid Build Coastguard Worker   }
116*993b0882SAndroid Build Coastguard Worker   const TfLiteTensor& distance_matrix =
117*993b0882SAndroid Build Coastguard Worker       context->tensors[node->inputs
118*993b0882SAndroid Build Coastguard Worker                            ->data[DIST_DIVERSIFICATION_INPUT_DISTANCE_MATRIX]];
119*993b0882SAndroid Build Coastguard Worker   const int distance_matrix_dim = distance_matrix.dims->data[0];
120*993b0882SAndroid Build Coastguard Worker   const float min_distance =
121*993b0882SAndroid Build Coastguard Worker       context
122*993b0882SAndroid Build Coastguard Worker           ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_MIN_DISTANCE]]
123*993b0882SAndroid Build Coastguard Worker           .data.f[0];
124*993b0882SAndroid Build Coastguard Worker   const int num_results =
125*993b0882SAndroid Build Coastguard Worker       context
126*993b0882SAndroid Build Coastguard Worker           ->tensors[node->inputs->data[DIST_DIVERSIFICATION_INPUT_NUM_RESULTS]]
127*993b0882SAndroid Build Coastguard Worker           .data.i32[0];
128*993b0882SAndroid Build Coastguard Worker   const auto indices = DiversifyByDistance(
129*993b0882SAndroid Build Coastguard Worker       [&](int row, int col) {
130*993b0882SAndroid Build Coastguard Worker         return distance_matrix.data.f[row * distance_matrix_dim + col];
131*993b0882SAndroid Build Coastguard Worker       },
132*993b0882SAndroid Build Coastguard Worker       distance_matrix_dim, min_distance, num_results);
133*993b0882SAndroid Build Coastguard Worker   std::copy(indices.begin(), indices.end(), output_indices.data.i32);
134*993b0882SAndroid Build Coastguard Worker   std::fill_n(output_indices.data.i32 + indices.size(),
135*993b0882SAndroid Build Coastguard Worker               num_results - indices.size(), -1);
136*993b0882SAndroid Build Coastguard Worker   TfLiteTensor& output_length =
137*993b0882SAndroid Build Coastguard Worker       context->tensors[node->outputs->data[DIST_DIVERSIFICATION_OUTPUT_LENGTH]];
138*993b0882SAndroid Build Coastguard Worker   *output_length.data.i32 = indices.size();
139*993b0882SAndroid Build Coastguard Worker   return kTfLiteOk;
140*993b0882SAndroid Build Coastguard Worker }
141*993b0882SAndroid Build Coastguard Worker 
142*993b0882SAndroid Build Coastguard Worker }  // namespace
143*993b0882SAndroid Build Coastguard Worker }  // namespace libtextclassifier3
144*993b0882SAndroid Build Coastguard Worker 
145*993b0882SAndroid Build Coastguard Worker namespace tflite {
146*993b0882SAndroid Build Coastguard Worker namespace ops {
147*993b0882SAndroid Build Coastguard Worker namespace custom {
Register_DISTANCE_DIVERSIFICATION()148*993b0882SAndroid Build Coastguard Worker TfLiteRegistration* Register_DISTANCE_DIVERSIFICATION() {
149*993b0882SAndroid Build Coastguard Worker   static TfLiteRegistration r = {nullptr, nullptr, libtextclassifier3::Prepare,
150*993b0882SAndroid Build Coastguard Worker                                  libtextclassifier3::Eval};
151*993b0882SAndroid Build Coastguard Worker   return &r;
152*993b0882SAndroid Build Coastguard Worker }
153*993b0882SAndroid Build Coastguard Worker }  // namespace custom
154*993b0882SAndroid Build Coastguard Worker }  // namespace ops
155*993b0882SAndroid Build Coastguard Worker }  // namespace tflite
156