/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/sdca_internal.h"

#include <limits>
#include <numeric>
#include <random>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/math/math_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"

#if defined(TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL)
#include "tensorflow/core/kernels/eigen_contraction_kernel.h"
#endif

namespace tensorflow {
namespace sdca {

using UnalignedFloatVector = TTypes<const float>::UnalignedConstVec;
using UnalignedInt64Vector = TTypes<const int64_t>::UnalignedConstVec;

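// Accumulates a dense example's contribution into the delta weights. With a
// single weight vector this is an elementwise multiply-add; with several, the
// dual deltas are contracted against the example row in double precision and
// the result is cast back to float.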
void FeatureWeightsDenseStorage::UpdateDenseDeltaWeights(
    const Eigen::ThreadPoolDevice& device,
    const Example::DenseVector& dense_vector,
    const std::vector<double>& normalized_bounded_dual_delta) {
  const size_t num_weight_vectors = normalized_bounded_dual_delta.size();
  if (num_weight_vectors == 1) {
    deltas_.device(device) =
        deltas_ + dense_vector.RowAsMatrix() *
                      deltas_.constant(normalized_bounded_dual_delta[0]);
  } else {
    // Transform the dual vector into a column matrix.
    const Eigen::TensorMap<Eigen::Tensor<const double, 2, Eigen::RowMajor>>
        dual_matrix(normalized_bounded_dual_delta.data(), num_weight_vectors,
                    1);
    const Eigen::array<Eigen::IndexPair<int>, 1> product_dims = {
        Eigen::IndexPair<int>(1, 0)};
    // This computes delta_w += delta_vector / \lambda * N.
    deltas_.device(device) =
        (deltas_.cast<double>() +
         dual_matrix.contract(dense_vector.RowAsMatrix().cast<double>(),
                              product_dims))
            .cast<float>();
  }
}

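// Scatters each dual delta into the delta columns named by the example's
// sparse indices. The unchecked find() assumes every index seen here was
// registered in indices_to_id_ during initialization.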
void FeatureWeightsSparseStorage::UpdateSparseDeltaWeights(
    const Eigen::ThreadPoolDevice& device,
    const Example::SparseFeatures& sparse_features,
    const std::vector<double>& normalized_bounded_dual_delta) {
  for (int64_t k = 0; k < sparse_features.indices->size(); ++k) {
    const double feature_value =
        sparse_features.values == nullptr ? 1.0 : (*sparse_features.values)(k);
    auto it = indices_to_id_.find((*sparse_features.indices)(k));
    for (size_t l = 0; l < normalized_bounded_dual_delta.size(); ++l) {
      deltas_(l, it->second) +=
          feature_value * normalized_bounded_dual_delta[l];
    }
  }
}

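// Applies one example's dual delta across every sparse and dense feature
// group of the model.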
void ModelWeights::UpdateDeltaWeights(
    const Eigen::ThreadPoolDevice& device, const Example& example,
    const std::vector<double>& normalized_bounded_dual_delta) {
  // Sparse weights.
  for (size_t j = 0; j < sparse_weights_.size(); ++j) {
    sparse_weights_[j].UpdateSparseDeltaWeights(
        device, example.sparse_features_[j], normalized_bounded_dual_delta);
  }

  // Dense weights.
  for (size_t j = 0; j < dense_weights_.size(); ++j) {
    dense_weights_[j].UpdateDenseDeltaWeights(
        device, *example.dense_vectors_[j], normalized_bounded_dual_delta);
  }
}

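// Reads the weight inputs and allocates matching zero-initialized delta
// outputs, storing both as 1 x N row matrices in the internal representation.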
Status ModelWeights::Initialize(OpKernelContext* const context) {
  OpInputList sparse_indices_inputs;
  TF_RETURN_IF_ERROR(
      context->input_list("sparse_indices", &sparse_indices_inputs));
  OpInputList sparse_weights_inputs;
  TF_RETURN_IF_ERROR(
      context->input_list("sparse_weights", &sparse_weights_inputs));
  if (sparse_indices_inputs.size() != sparse_weights_inputs.size())
    return errors::InvalidArgument(
        "sparse_indices and sparse_weights must have the same length, got ",
        sparse_indices_inputs.size(), " and ", sparse_weights_inputs.size());
  OpInputList dense_weights_inputs;
  TF_RETURN_IF_ERROR(
      context->input_list("dense_weights", &dense_weights_inputs));

  OpOutputList sparse_weights_outputs;
  TF_RETURN_IF_ERROR(context->output_list("out_delta_sparse_weights",
                                          &sparse_weights_outputs));
  if (sparse_weights_outputs.size() != sparse_weights_inputs.size())
    return errors::InvalidArgument(
        "out_delta_sparse_weights and sparse_weights must have the same "
        "length, got ",
        sparse_weights_outputs.size(), " and ", sparse_weights_inputs.size());

  OpOutputList dense_weights_outputs;
  TF_RETURN_IF_ERROR(
      context->output_list("out_delta_dense_weights", &dense_weights_outputs));
  if (dense_weights_outputs.size() != dense_weights_inputs.size())
    return errors::InvalidArgument(
        "out_delta_dense_weights and dense_weights must have the same length, "
        "got ",
        dense_weights_outputs.size(), " and ", dense_weights_inputs.size());

  for (int i = 0; i < sparse_weights_inputs.size(); ++i) {
    Tensor* delta_t;
    TF_RETURN_IF_ERROR(sparse_weights_outputs.allocate(
        i, sparse_weights_inputs[i].shape(), &delta_t));
    // Convert the input vector to a row matrix in internal representation.
    auto deltas = delta_t->shaped<float, 2>({1, delta_t->NumElements()});
    deltas.setZero();
    sparse_weights_.emplace_back(FeatureWeightsSparseStorage{
        sparse_indices_inputs[i].flat<int64_t>(),
        sparse_weights_inputs[i].shaped<float, 2>(
            {1, sparse_weights_inputs[i].NumElements()}),
        deltas});
  }

  // Reads in the weights, and allocates and initializes the delta weights.
  const auto initialize_weights =
      [&](const OpInputList& weight_inputs, OpOutputList* const weight_outputs,
          std::vector<FeatureWeightsDenseStorage>* const feature_weights) {
        for (int i = 0; i < weight_inputs.size(); ++i) {
          Tensor* delta_t;
          TF_RETURN_IF_ERROR(
              weight_outputs->allocate(i, weight_inputs[i].shape(), &delta_t));
          // Convert the input vector to a row matrix in internal
          // representation.
          auto deltas = delta_t->shaped<float, 2>({1, delta_t->NumElements()});
          deltas.setZero();
          feature_weights->emplace_back(FeatureWeightsDenseStorage{
              weight_inputs[i].shaped<float, 2>(
                  {1, weight_inputs[i].NumElements()}),
              deltas});
        }
        return OkStatus();
      };

  return initialize_weights(dense_weights_inputs, &dense_weights_outputs,
                            &dense_weights_);
}

// Computes the example statistics for the given example and model. Defined
// here as we need the definitions of ModelWeights and Regularizations.
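// For each weight vector l, the effective weight of feature j is
//   w_lj = nominal_lj + num_loss_partitions * delta_lj,
// and wx[l] = sum_j x_j * Shrink(w_lj), while prev_wx[l] takes the same sum
// over the nominal weights alone. Shrink() applies the L1 shrinkage from
// Regularizations.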
const ExampleStatistics Example::ComputeWxAndWeightedExampleNorm(
    const int num_loss_partitions, const ModelWeights& model_weights,
    const Regularizations& regularization, const int num_weight_vectors) const {
  ExampleStatistics result(num_weight_vectors);

  result.normalized_squared_norm =
      squared_norm_ / regularization.symmetric_l2();

  // Compute w \dot x and prev_w \dot x.
  // This is the sparse features' contribution to the logit.
  for (size_t j = 0; j < sparse_features_.size(); ++j) {
    const Example::SparseFeatures& sparse_features = sparse_features_[j];
    const FeatureWeightsSparseStorage& sparse_weights =
        model_weights.sparse_weights()[j];

    for (int64_t k = 0; k < sparse_features.indices->size(); ++k) {
      const int64_t feature_index = (*sparse_features.indices)(k);
      const double feature_value = sparse_features.values == nullptr
                                       ? 1.0
                                       : (*sparse_features.values)(k);
      for (int l = 0; l < num_weight_vectors; ++l) {
        const float sparse_weight = sparse_weights.nominals(l, feature_index);
        const double feature_weight =
            sparse_weight +
            sparse_weights.deltas(l, feature_index) * num_loss_partitions;
        result.prev_wx[l] +=
            feature_value * regularization.Shrink(sparse_weight);
        result.wx[l] += feature_value * regularization.Shrink(feature_weight);
      }
    }
  }

  // Compute w \dot x and prev_w \dot x.
  // This is the dense features' contribution to the logit.
  for (size_t j = 0; j < dense_vectors_.size(); ++j) {
    const Example::DenseVector& dense_vector = *dense_vectors_[j];
    const FeatureWeightsDenseStorage& dense_weights =
        model_weights.dense_weights()[j];

    const Eigen::Tensor<float, 2, Eigen::RowMajor> feature_weights =
        dense_weights.nominals() +
        dense_weights.deltas() *
            dense_weights.deltas().constant(num_loss_partitions);
    if (num_weight_vectors == 1) {
      const Eigen::Tensor<float, 0, Eigen::RowMajor> prev_prediction =
          (dense_vector.Row() *
           regularization.EigenShrinkVector(
               Eigen::TensorMap<Eigen::Tensor<const float, 1, Eigen::RowMajor>>(
                   dense_weights.nominals().data(),
                   dense_weights.nominals().dimension(1))))
              .sum();
      const Eigen::Tensor<float, 0, Eigen::RowMajor> prediction =
          (dense_vector.Row() *
           regularization.EigenShrinkVector(
               Eigen::TensorMap<Eigen::Tensor<const float, 1, Eigen::RowMajor>>(
                   feature_weights.data(), feature_weights.dimension(1))))
              .sum();
      result.prev_wx[0] += prev_prediction();
      result.wx[0] += prediction();
    } else {
      const Eigen::array<Eigen::IndexPair<int>, 1> product_dims = {
          Eigen::IndexPair<int>(1, 1)};
      const Eigen::Tensor<float, 2, Eigen::RowMajor> prev_prediction =
          regularization.EigenShrinkMatrix(dense_weights.nominals())
              .contract(dense_vector.RowAsMatrix(), product_dims);
      const Eigen::Tensor<float, 2, Eigen::RowMajor> prediction =
          regularization.EigenShrinkMatrix(feature_weights)
              .contract(dense_vector.RowAsMatrix(), product_dims);
      // The result of the tensor contraction (matrix multiplication) above
      // has dimensions num_weight_vectors x 1.
      for (int l = 0; l < num_weight_vectors; ++l) {
        result.prev_wx[l] += prev_prediction(l, 0);
        result.wx[l] += prediction(l, 0);
      }
    }
  }

  return result;
}

// Examples contains all the training examples that SDCA uses for a mini-batch.
Status Examples::SampleAdaptiveProbabilities(
    const int num_loss_partitions, const Regularizations& regularization,
    const ModelWeights& model_weights,
    const TTypes<float>::Matrix example_state_data,
    const std::unique_ptr<DualLossUpdater>& loss_updater,
    const int num_weight_vectors) {
  if (num_weight_vectors != 1) {
    return errors::InvalidArgument(
        "Adaptive SDCA only works with binary SDCA, "
        "where num_weight_vectors should be 1.");
  }
  // Compute the probabilities
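  // Each example i gets weight proportional to
  //   example_weight_i * sqrt(||x_i||^2 + lambda * gamma) * |kappa_i|,
  // where lambda is the symmetric L2 strength, gamma is the loss updater's
  // smoothness constant, and kappa_i is the dual residual: the current dual
  // variable plus the primal loss derivative at wx_i.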
  for (int example_id = 0; example_id < num_examples(); ++example_id) {
    const Example& example = examples_[example_id];
    const double example_weight = example.example_weight();
    float label = example.example_label();
    TF_RETURN_IF_ERROR(loss_updater->ConvertLabel(&label));
    const ExampleStatistics example_statistics =
        example.ComputeWxAndWeightedExampleNorm(num_loss_partitions,
                                                model_weights, regularization,
                                                num_weight_vectors);
    const double kappa = example_state_data(example_id, 0) +
                         loss_updater->PrimalLossDerivative(
                             example_statistics.wx[0], label, 1.0);
    probabilities_[example_id] = example_weight *
                                 sqrt(examples_[example_id].squared_norm_ +
                                      regularization.symmetric_l2() *
                                          loss_updater->SmoothnessConstant()) *
                                 std::abs(kappa);
  }

  // Sample the index
  random::DistributionSampler sampler(probabilities_);
  GuardedPhiloxRandom generator;
  generator.Init(0, 0);
  auto local_gen = generator.ReserveSamples32(num_examples());
  random::SimplePhilox random(&local_gen);
  std::random_device rd;
  std::mt19937 gen(rd());
  std::uniform_real_distribution<> dis(0, 1);

  // We use a decay of 10: the probability of an example is divided by 10
  // once that example is picked. A good approximation of that is to only
  // keep a picked example with probability (1 / 10) ^ k, where k is the
  // number of times we have already picked that example. We cap the number
  // of retries to avoid taking too long to sample. We then fill sampled_index
  // with unseen examples sorted by probability.
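  // For instance, an example already picked twice survives its third draw
  // with probability 0.1^2 = 0.01.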
  int id = 0;
  int num_retries = 0;
  while (id < num_examples() && num_retries < num_examples()) {
    int picked_id = sampler.Sample(&random);
    if (dis(gen) > MathUtil::IPow(0.1, sampled_count_[picked_id])) {
      num_retries++;
      continue;
    }
    sampled_count_[picked_id]++;
    sampled_index_[id++] = picked_id;
  }

  std::vector<std::pair<int, float>> examples_not_seen;
  examples_not_seen.reserve(num_examples());
  for (int i = 0; i < num_examples(); ++i) {
    if (sampled_count_[i] == 0)
      examples_not_seen.emplace_back(i, probabilities_[i]);
  }
  std::sort(
      examples_not_seen.begin(), examples_not_seen.end(),
      [](const std::pair<int, float>& lhs, const std::pair<int, float>& rhs) {
        return lhs.second > rhs.second;
      });
  for (int i = id; i < num_examples(); ++i) {
    sampled_index_[i] = examples_not_seen[i - id].first;
  }
  return OkStatus();
}

void Examples::RandomShuffle() {
  std::iota(sampled_index_.begin(), sampled_index_.end(), 0);

  std::random_device rd;
  std::mt19937 rng(rd());
  std::shuffle(sampled_index_.begin(), sampled_index_.end(), rng);
}

// TODO(sibyl-Aix6ihai): Refactor/shorten this function.
Status Examples::Initialize(OpKernelContext* const context,
                            const ModelWeights& weights,
                            const int num_sparse_features,
                            const int num_sparse_features_with_values,
                            const int num_dense_features) {
  num_features_ = num_sparse_features + num_dense_features;

  OpInputList sparse_example_indices_inputs;
  TF_RETURN_IF_ERROR(context->input_list("sparse_example_indices",
                                         &sparse_example_indices_inputs));
  if (sparse_example_indices_inputs.size() != num_sparse_features)
    return errors::InvalidArgument(
        "Expected ", num_sparse_features,
        " tensors in sparse_example_indices but got ",
        sparse_example_indices_inputs.size());
  OpInputList sparse_feature_indices_inputs;
  TF_RETURN_IF_ERROR(context->input_list("sparse_feature_indices",
                                         &sparse_feature_indices_inputs));
  if (sparse_feature_indices_inputs.size() != num_sparse_features)
    return errors::InvalidArgument(
        "Expected ", num_sparse_features,
        " tensors in sparse_feature_indices but got ",
        sparse_feature_indices_inputs.size());
  OpInputList sparse_feature_values_inputs;
  if (num_sparse_features_with_values > 0) {
    TF_RETURN_IF_ERROR(context->input_list("sparse_feature_values",
                                           &sparse_feature_values_inputs));
    if (sparse_feature_values_inputs.size() != num_sparse_features_with_values)
      return errors::InvalidArgument(
          "Expected ", num_sparse_features_with_values,
          " tensors in sparse_feature_values but got ",
          sparse_feature_values_inputs.size());
  }

  const Tensor* example_weights_t;
  TF_RETURN_IF_ERROR(context->input("example_weights", &example_weights_t));
  auto example_weights = example_weights_t->flat<float>();

  if (example_weights.size() >= std::numeric_limits<int>::max()) {
    return errors::InvalidArgument(strings::Printf(
        "Too many examples in a mini-batch: %zu > %d",
        static_cast<size_t>(example_weights.size()),
        std::numeric_limits<int>::max()));
  }

  // The static_cast here is safe since num_examples can be at max an int.
  const int num_examples = static_cast<int>(example_weights.size());
  const Tensor* example_labels_t;
  TF_RETURN_IF_ERROR(context->input("example_labels", &example_labels_t));
  auto example_labels = example_labels_t->flat<float>();
  if (example_labels.size() != num_examples) {
    return errors::InvalidArgument("Expected ", num_examples,
                                   " example labels but got ",
                                   example_labels.size());
  }

  OpInputList dense_features_inputs;
  TF_RETURN_IF_ERROR(
      context->input_list("dense_features", &dense_features_inputs));

  examples_.clear();
  examples_.resize(num_examples);
  probabilities_.resize(num_examples);
  sampled_index_.resize(num_examples);
  sampled_count_.resize(num_examples);
  for (int example_id = 0; example_id < num_examples; ++example_id) {
    Example* const example = &examples_[example_id];
    example->sparse_features_.resize(num_sparse_features);
    example->dense_vectors_.resize(num_dense_features);
    example->example_weight_ = example_weights(example_id);
    example->example_label_ = example_labels(example_id);
  }
  const DeviceBase::CpuWorkerThreads& worker_threads =
      *context->device()->tensorflow_cpu_worker_threads();
  TF_RETURN_IF_ERROR(CreateSparseFeatureRepresentation(
      worker_threads, num_examples, num_sparse_features, weights,
      sparse_example_indices_inputs, sparse_feature_indices_inputs,
      sparse_feature_values_inputs, &examples_));
  TF_RETURN_IF_ERROR(CreateDenseFeatureRepresentation(
      worker_threads, num_examples, num_dense_features, weights,
      dense_features_inputs, &examples_));
  TF_RETURN_IF_ERROR(ComputeSquaredNormPerExample(
      worker_threads, num_examples, num_sparse_features, num_dense_features,
      &examples_));
  return OkStatus();
}

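// The sparse inputs use a split-COO layout: for feature group i,
// sparse_example_indices_inputs[i] and sparse_feature_indices_inputs[i] are
// parallel vectors whose example ids appear grouped in ascending order (the
// scan below relies on that ordering), with optional parallel feature values.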
Status Examples::CreateSparseFeatureRepresentation(
    const DeviceBase::CpuWorkerThreads& worker_threads, const int num_examples,
    const int num_sparse_features, const ModelWeights& weights,
    const OpInputList& sparse_example_indices_inputs,
    const OpInputList& sparse_feature_indices_inputs,
    const OpInputList& sparse_feature_values_inputs,
    std::vector<Example>* const examples) {
  mutex mu;
  Status result;  // Guarded by mu
  auto parse_partition = [&](const int64_t begin, const int64_t end) {
    // The static_cast here is safe since begin and end can be at most
    // num_sparse_features which is an int.
    for (int i = static_cast<int>(begin); i < end; ++i) {
      auto example_indices =
          sparse_example_indices_inputs[i].template flat<int64_t>();
      auto feature_indices =
          sparse_feature_indices_inputs[i].template flat<int64_t>();
      if (example_indices.size() != feature_indices.size()) {
        mutex_lock l(mu);
        result = errors::InvalidArgument(
            "Found mismatched example_indices and feature_indices [",
            example_indices, "] vs [", feature_indices, "]");
        return;
      }

      // Parse features for each example. Features for a particular example
      // are at the offsets [start_id, end_id).
      int start_id = -1;
      int end_id = 0;
      for (int example_id = 0; example_id < num_examples; ++example_id) {
        start_id = end_id;
        while (end_id < example_indices.size() &&
               example_indices(end_id) == example_id) {
          ++end_id;
        }
        Example::SparseFeatures* const sparse_features =
            &(*examples)[example_id].sparse_features_[i];
        if (start_id < example_indices.size() &&
            example_indices(start_id) == example_id) {
          sparse_features->indices.reset(new UnalignedInt64Vector(
              &(feature_indices(start_id)), end_id - start_id));
          if (sparse_feature_values_inputs.size() > i) {
            auto feature_weights =
                sparse_feature_values_inputs[i].flat<float>();
            sparse_features->values.reset(new UnalignedFloatVector(
                &(feature_weights(start_id)), end_id - start_id));
          }
          // If features are non-empty.
          if (end_id - start_id > 0) {
            // TODO(sibyl-Aix6ihai): Write this efficiently using vectorized
            // operations from eigen.
            for (int64_t k = 0; k < sparse_features->indices->size(); ++k) {
              const int64_t feature_index = (*sparse_features->indices)(k);
              if (!weights.SparseIndexValid(i, feature_index)) {
                mutex_lock l(mu);
                result = errors::InvalidArgument(
                    "Found sparse feature indices out of valid range: ",
                    (*sparse_features->indices)(k));
                return;
              }
            }
          }
        } else {
          // Add a Tensor that has size 0.
          sparse_features->indices.reset(
              new UnalignedInt64Vector(&(feature_indices(0)), 0));
          // If values exist for this feature group.
          if (sparse_feature_values_inputs.size() > i) {
            auto feature_weights =
                sparse_feature_values_inputs[i].flat<float>();
            sparse_features->values.reset(
                new UnalignedFloatVector(&(feature_weights(0)), 0));
          }
        }
      }
    }
  };
  // For each column, the cost of parsing it is O(num_examples). We use
  // num_examples here, as empirically Shard() creates the right amount of
  // threads based on the problem size.
  // TODO(sibyl-Aix6ihai): Tune this as a function of dataset size.
  const int64_t kCostPerUnit = num_examples;
  Shard(worker_threads.num_threads, worker_threads.workers, num_sparse_features,
        kCostPerUnit, parse_partition);
  return result;
}

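// Each dense feature group is a dense [num_examples x dims] matrix; every
// example gets a view of its own row, so no per-example copies are made.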
Status Examples::CreateDenseFeatureRepresentation(
    const DeviceBase::CpuWorkerThreads& worker_threads, const int num_examples,
    const int num_dense_features, const ModelWeights& weights,
    const OpInputList& dense_features_inputs,
    std::vector<Example>* const examples) {
  mutex mu;
  Status result;  // Guarded by mu
  auto parse_partition = [&](const int64_t begin, const int64_t end) {
    // The static_cast here is safe since begin and end can be at most
    // num_dense_features which is an int.
    for (int i = static_cast<int>(begin); i < end; ++i) {
      auto dense_features = dense_features_inputs[i].template matrix<float>();
      for (int example_id = 0; example_id < num_examples; ++example_id) {
        (*examples)[example_id].dense_vectors_[i].reset(
            new Example::DenseVector{dense_features, example_id});
      }
      if (!weights.DenseIndexValid(i, dense_features.dimension(1) - 1)) {
        mutex_lock l(mu);
        result = errors::InvalidArgument(
            "More dense features than we have parameters for: ",
            dense_features.dimension(1));
        return;
      }
    }
  };
  // TODO(sibyl-Aix6ihai): Tune this as a function of dataset size.
  const int64_t kCostPerUnit = num_examples;
  Shard(worker_threads.num_threads, worker_threads.workers, num_dense_features,
        kCostPerUnit, parse_partition);
  return result;
}

Status Examples::ComputeSquaredNormPerExample(
    const DeviceBase::CpuWorkerThreads& worker_threads, const int num_examples,
    const int num_sparse_features, const int num_dense_features,
    std::vector<Example>* const examples) {
  mutex mu;
  Status result;  // Guarded by mu
  // Compute norm of examples.
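  // squared_norm = sum of v^2 over all sparse feature values plus the squared
  // L2 norm of each dense row; duplicate sparse indices within a feature
  // group are rejected as invalid input.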
  auto compute_example_norm = [&](const int64_t begin, const int64_t end) {
    // The static_cast here is safe since begin and end can be at most
    // num_examples which is an int.
    gtl::FlatSet<int64_t> previous_indices;
    for (int example_id = static_cast<int>(begin); example_id < end;
         ++example_id) {
      double squared_norm = 0;
      Example* const example = &(*examples)[example_id];
      for (int j = 0; j < num_sparse_features; ++j) {
        const Example::SparseFeatures& sparse_features =
            example->sparse_features_[j];
        previous_indices.clear();
        for (int64_t k = 0; k < sparse_features.indices->size(); ++k) {
          const int64_t feature_index = (*sparse_features.indices)(k);
          if (!previous_indices.insert(feature_index).second) {
            mutex_lock l(mu);
            result =
                errors::InvalidArgument("Duplicate index in sparse vector.");
            return;
          }
          const double feature_value = sparse_features.values == nullptr
                                           ? 1.0
                                           : (*sparse_features.values)(k);
          squared_norm += feature_value * feature_value;
        }
      }
      for (int j = 0; j < num_dense_features; ++j) {
        const Eigen::Tensor<float, 0, Eigen::RowMajor> sn =
            example->dense_vectors_[j]->Row().square().sum();
        squared_norm += sn();
      }
      example->squared_norm_ = squared_norm;
    }
  };
  // TODO(sibyl-Aix6ihai): Compute the cost optimally.
  const int64_t kCostPerUnit = num_dense_features + num_sparse_features;
  Shard(worker_threads.num_threads, worker_threads.workers, num_examples,
        kCostPerUnit, compute_example_norm);
  return result;
}

}  // namespace sdca
}  // namespace tensorflow