/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/random_ops.cc.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/multinomial_op.h"

#include <algorithm>
#include <cmath>
#include <memory>

#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/util/guarded_philox_random.h"
#include "tensorflow/core/util/work_sharder.h"

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace functor {
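// Fills `output(b, s)` with a class index drawn from the distribution given
// by row b of `logits`, for each of `num_samples` draws.  `noises`, `scores`,
// and `scratch` are workspaces used only by the GPU specialization.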
template <typename Device, typename T, typename OutputType>
struct MultinomialFunctor {
  void operator()(OpKernelContext* ctx, const Device& d,
                  typename TTypes<T>::ConstMatrix logits,
                  typename TTypes<float>::Flat noises,
                  typename TTypes<float>::Flat scores,
                  typename TTypes<float>::Flat scratch, int batch_size,
                  int num_classes, int num_samples,
                  const random::PhiloxRandom& gen,
                  typename TTypes<OutputType>::Matrix output);
};

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
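// These GPU specializations are declared here and instantiated in
// multinomial_op_gpu.cu.cc.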
extern template struct MultinomialFunctor<GPUDevice, Eigen::half, int32>;
extern template struct MultinomialFunctor<GPUDevice, float, int32>;
extern template struct MultinomialFunctor<GPUDevice, double, int32>;
extern template struct MultinomialFunctor<GPUDevice, int32, int32>;
extern template struct MultinomialFunctor<GPUDevice, int64_t, int32>;

extern template struct MultinomialFunctor<GPUDevice, Eigen::half, int64_t>;
extern template struct MultinomialFunctor<GPUDevice, float, int64_t>;
extern template struct MultinomialFunctor<GPUDevice, double, int64_t>;
extern template struct MultinomialFunctor<GPUDevice, int32, int64_t>;
extern template struct MultinomialFunctor<GPUDevice, int64_t, int64_t>;
#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename T, typename OutputType>
struct MultinomialFunctor<CPUDevice, T, OutputType> {
  void operator()(OpKernelContext* ctx, const CPUDevice& d,
                  typename TTypes<T>::ConstMatrix logits,
                  typename TTypes<float>::Flat /* noises */,
                  typename TTypes<float>::Flat /* scores */,
                  typename TTypes<float>::Flat /* scratch */, int batch_size,
                  int num_classes, int num_samples,
                  const random::PhiloxRandom& gen,
                  typename TTypes<OutputType>::Matrix output) {
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    // The implementation parallelizes only by batch.
    //
    // This takes O(BatchSize * NumSamples * log(NumClasses) + NumClasses) CPU
    // time.
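    // The log(NumClasses) term comes from the per-sample binary search
    // (std::upper_bound) over each row's CDF; building the CDF itself is
    // linear in NumClasses.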
    auto DoWork = [ctx, num_samples, num_classes, &gen, &output, &logits](
                      int64_t start_row, int64_t limit_row) {
      // Capturing "gen" by value would make a single copy shared by every
      // invocation of the lambda.  Since each worker needs its own copy, we
      // pass "gen" by reference and copy it explicitly here.
      random::PhiloxRandom gen_copy = gen;
      // Skip() advances the stream in units of 128 bits.  The +3 rounds up,
      // so that integer division never makes two batches share state.
      gen_copy.Skip(start_row * (num_samples + 3) / 4);
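      // PhiloxRandom produces four 32-bit outputs per 128-bit block, so this
      // places consecutive rows' streams at least ceil(num_samples / 4)
      // blocks apart.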
      random::SimplePhilox simple_philox(&gen_copy);

      Tensor cdf_tensor;
      OP_REQUIRES_OK(ctx,
                     ctx->allocate_temp(DT_DOUBLE, TensorShape({num_classes}),
                                        &cdf_tensor));
      auto cdf = cdf_tensor.flat<double>();
      for (int64_t b = start_row; b < limit_row; ++b) {
        const auto* logits_row = &logits(b, 0);

        // Take the maximum across classes (for numerical stability).
        T max = std::numeric_limits<T>::lowest();
        for (int64_t j = 0; j < num_classes; ++j) {
          if (Eigen::numext::isfinite(logits_row[j])) {
            max = std::max(max, logits_row[j]);
          }
        }
        const double max_logit = static_cast<double>(max);

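        // Subtracting max_logit before exponentiating keeps every exp() in
        // (0, 1], so large logits cannot overflow; the constant factor
        // exp(-max_logit) cancels when samples are drawn against
        // running_total below.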
        // Precompute cumulative probability distribution across classes.
        // Note: This isn't normalized.
        cdf = (logits.template chip<0>(b).template cast<double>() - max_logit)
                  .exp();
        double running_total = 0;
        for (int64_t j = 0; j < num_classes; ++j) {
          if (Eigen::numext::isfinite(logits_row[j])) {
            running_total += cdf(j);
          }
          cdf(j) = running_total;
        }
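        // For example, logits_row = {0, ln(2), ln(3)} gives unnormalized
        // probabilities {1/3, 2/3, 1} after subtracting max_logit = ln(3),
        // so the running CDF is {1/3, 1, 2} with running_total = 2.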
        // Generate each sample.
        const double* cdf_begin = cdf.data();
        const double* cdf_end = cdf.data() + num_classes;
        for (int64_t j = 0; j < num_samples; ++j) {
          const double to_find = simple_philox.RandDouble() * running_total;
          auto found_iter = std::upper_bound(cdf_begin, cdf_end, to_find);
          output(b, j) = std::distance(cdf_begin, found_iter);
        }
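        // This is inverse-CDF sampling: to_find is uniform on
        // [0, running_total), and upper_bound returns the first class j with
        // cdf(j) > to_find, so class j is chosen with probability
        // (cdf(j) - cdf(j - 1)) / running_total, i.e. proportional to
        // exp(logits_row[j]).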
      }
    };
    // Incredibly rough estimate of clock cycles per row for DoWork().
    const int64_t cost =
        50 * (num_samples * std::log(num_classes) / std::log(2) + num_classes);
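    // Shard splits the row range [0, batch_size) across the worker threads;
    // the per-row cost estimate controls how finely the range is subdivided.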
    Shard(worker_threads.num_threads, worker_threads.workers, batch_size, cost,
          DoWork);
  }
};

}  // namespace functor

namespace {

// Samples from a multinomial distribution.
template <typename Device, typename T, typename OutputType>
class MultinomialOp : public OpKernel {
 public:
  explicit MultinomialOp(OpKernelConstruction* context) : OpKernel(context) {}

  void DoCompute(OpKernelContext* ctx, const Tensor& logits_t,
                 const Tensor& num_samples_t, GuardedPhiloxRandom* generator) {
    OP_REQUIRES(ctx, TensorShapeUtils::IsMatrix(logits_t.shape()),
                errors::InvalidArgument("logits should be a matrix, got shape ",
                                        logits_t.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsScalar(num_samples_t.shape()),
        errors::InvalidArgument("num_samples should be a scalar, got shape ",
                                num_samples_t.shape().DebugString()));

    const int num_samples = num_samples_t.scalar<int>()();
    OP_REQUIRES(ctx, num_samples >= 0,
                errors::InvalidArgument(
                    "num_samples should be nonnegative, got ", num_samples));

    for (int i = 0; i < 2; i++) {
      const int64_t dim = logits_t.dim_size(i);
      OP_REQUIRES(ctx, static_cast<int>(dim) == dim,
                  errors::InvalidArgument(
                      "logits.shape = ", logits_t.shape().DebugString(),
                      " too large for int"));
    }
    const int batch_size = static_cast<int>(logits_t.dim_size(0));
    const int num_classes = static_cast<int>(logits_t.dim_size(1));
    OP_REQUIRES(ctx, num_classes > 0,
                errors::InvalidArgument("num_classes should be positive, got ",
                                        num_classes));

    Tensor* samples_t;
    OP_REQUIRES_OK(
        ctx, ctx->allocate_output(0, TensorShape({batch_size, num_samples}),
                                  &samples_t));

    // Execute kernel only for nonempty output; otherwise Eigen crashes on GPU.
    if (samples_t->NumElements() > 0) {
      Tensor noises, scores, scratch;  // Scratch space only used for GPU.
      if (std::is_same<Device, GPUDevice>::value) {
        OP_REQUIRES_OK(
            ctx,
            ctx->allocate_temp(
                DT_FLOAT, TensorShape({batch_size, num_samples, num_classes}),
                &noises));
        OP_REQUIRES_OK(
            ctx,
            ctx->allocate_temp(
                DT_FLOAT, TensorShape({batch_size, num_samples, num_classes}),
                &scores));
        OP_REQUIRES_OK(
            ctx,
            ctx->allocate_temp(DT_FLOAT, TensorShape({batch_size, num_samples}),
                               &scratch));
      }
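      // The GPU functor consumes one noise value per (batch, sample, class)
      // and one scratch slot per (batch, sample); see multinomial_op_gpu.cu.cc
      // for the sampling scheme itself.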

      int num_samples_ceil_4 = (num_samples + 3) / 4 * 4;
      // The CPU path draws doubles, each consuming two 32-bit outputs.
      if (std::is_same<Device, CPUDevice>::value) num_samples_ceil_4 *= 2;
      auto rng =
          generator->ReserveRandomOutputs(batch_size * num_samples_ceil_4, 256);
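      // Rounding num_samples up to a multiple of 4 matches the 128-bit block
      // granularity of PhiloxRandom::Skip() used by the CPU functor above.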
      functor::MultinomialFunctor<Device, T, OutputType>()(
          ctx, ctx->eigen_device<Device>(), logits_t.matrix<T>(),
          noises.flat<float>(), scores.flat<float>(), scratch.flat<float>(),
          batch_size, num_classes, num_samples, rng,
          samples_t->matrix<OutputType>());
    }
  }
};

template <typename Device, typename T, typename OutputType>
class StatefulMultinomialOp : public MultinomialOp<Device, T, OutputType> {
 public:
  explicit StatefulMultinomialOp(OpKernelConstruction* ctx)
      : MultinomialOp<Device, T, OutputType>(ctx) {
    OP_REQUIRES_OK(ctx, generator_.Init(ctx));
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& logits_t = ctx->input(0);
    const Tensor& num_samples_t = ctx->input(1);
    this->DoCompute(ctx, logits_t, num_samples_t, &generator_);
  }

 private:
  GuardedPhiloxRandom generator_;
};

// TODO(b/77906027): Add a TPU implementation.
#define REGISTER(TYPE)                                                    \
  REGISTER_KERNEL_BUILDER(Name("Multinomial")                             \
                              .Device(DEVICE_CPU)                         \
                              .TypeConstraint<TYPE>("T")                  \
                              .TypeConstraint("output_dtype", DT_INT32),  \
                          StatefulMultinomialOp<CPUDevice, TYPE, int32>); \
  REGISTER_KERNEL_BUILDER(Name("Multinomial")                             \
                              .Device(DEVICE_CPU)                         \
                              .TypeConstraint<TYPE>("T")                  \
                              .TypeConstraint("output_dtype", DT_INT64),  \
                          StatefulMultinomialOp<CPUDevice, TYPE, int64>);

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
#undef REGISTER

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER(TYPE)                                                   \
  REGISTER_KERNEL_BUILDER(Name("Multinomial")                            \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("num_samples")                 \
                              .TypeConstraint<TYPE>("T")                 \
                              .TypeConstraint("output_dtype", DT_INT32), \
                          StatefulMultinomialOp<GPUDevice, TYPE, int32>) \
  REGISTER_KERNEL_BUILDER(Name("Multinomial")                            \
                              .Device(DEVICE_GPU)                        \
                              .HostMemory("num_samples")                 \
                              .TypeConstraint<TYPE>("T")                 \
                              .TypeConstraint("output_dtype", DT_INT64), \
                          StatefulMultinomialOp<GPUDevice, TYPE, int64>)

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
#undef REGISTER

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

template <typename Device, typename T, typename OutputType>
class StatelessMultinomialOp : public MultinomialOp<Device, T, OutputType> {
 public:
  explicit StatelessMultinomialOp(OpKernelConstruction* ctx)
      : MultinomialOp<Device, T, OutputType>(ctx) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& logits_t = ctx->input(0);
    const Tensor& num_samples_t = ctx->input(1);

    const Tensor& seed_t = ctx->input(2);
    OP_REQUIRES(ctx, seed_t.dims() == 1 && seed_t.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_t.shape().DebugString()));

    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    OP_REQUIRES_OK(ctx, GenerateKey(seed_t, &key, &counter));

    GuardedPhiloxRandom generator;
    generator.Init(counter, key);
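    // The key and counter are derived entirely from the seed input, so the
    // same seed always yields the same samples, unlike the stateful op above.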

    this->DoCompute(ctx, logits_t, num_samples_t, &generator);
  }
};

#define REGISTER(TYPE)                                                     \
  REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial")                     \
                              .Device(DEVICE_CPU)                          \
                              .TypeConstraint<TYPE>("T")                   \
                              .TypeConstraint("output_dtype", DT_INT32),   \
                          StatelessMultinomialOp<CPUDevice, TYPE, int32>); \
  REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial")                     \
                              .Device(DEVICE_CPU)                          \
                              .TypeConstraint<TYPE>("T")                   \
                              .TypeConstraint("output_dtype", DT_INT64),   \
                          StatelessMultinomialOp<CPUDevice, TYPE, int64>);

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
#undef REGISTER

#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
#define REGISTER(TYPE)                                                    \
  REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial")                    \
                              .Device(DEVICE_GPU)                         \
                              .HostMemory("num_samples")                  \
                              .HostMemory("seed")                         \
                              .TypeConstraint<TYPE>("T")                  \
                              .TypeConstraint("output_dtype", DT_INT32),  \
                          StatelessMultinomialOp<GPUDevice, TYPE, int32>) \
  REGISTER_KERNEL_BUILDER(Name("StatelessMultinomial")                    \
                              .Device(DEVICE_GPU)                         \
                              .HostMemory("num_samples")                  \
                              .HostMemory("seed")                         \
                              .TypeConstraint<TYPE>("T")                  \
                              .TypeConstraint("output_dtype", DT_INT64),  \
                          StatelessMultinomialOp<GPUDevice, TYPE, int64>)

TF_CALL_half(REGISTER);
TF_CALL_float(REGISTER);
TF_CALL_double(REGISTER);
#undef REGISTER

#endif  // GOOGLE_CUDA || TENSORFLOW_USE_ROCM

}  // end namespace

}  // end namespace tensorflow