/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// See docs in ../ops/random_ops.cc.
// NOTE: If the algorithm is changed, please run the test
// .../python/kernel_tests/random:random_binomial_test
// commenting out the "tf.set_random_seed(seed)" lines, and using the
// "--runs-per-test=1000" flag. This tests the statistical correctness of the
// op results.

#define EIGEN_USE_THREADS

#include "tensorflow/core/kernels/random_binomial_op.h"

#include <algorithm>
#include <cmath>
#include <memory>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/rng_alg.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/kernels/random_ops_util.h"
#include "tensorflow/core/kernels/stateful_random_ops_cpu_gpu.h"
#include "tensorflow/core/kernels/stateless_random_ops.h"
#include "tensorflow/core/kernels/training_op_helpers.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/random/random_distributions.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/util/bcast.h"
#include "tensorflow/core/util/guarded_philox_random.h"
#include "tensorflow/core/util/work_sharder.h"

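// UNIFORM(X) declares a fresh double `X` drawn from the unit interval. Philox
// produces uniform variates in batches of Uniform::kResultElementCount, so the
// macro only calls the generator when the cached batch has been used up.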
#define UNIFORM(X)                                    \
  if (uniform_remaining == 0) {                       \
    uniform_remaining = Uniform::kResultElementCount; \
    uniform_result = uniform(gen);                    \
  }                                                   \
  uniform_remaining--;                                \
  double X = uniform_result[uniform_remaining]

namespace tensorflow {

typedef Eigen::ThreadPoolDevice CPUDevice;
typedef Eigen::GpuDevice GPUDevice;

namespace {

typedef random::UniformDistribution<random::PhiloxRandom, double> Uniform;

// Binomial inversion: given prob, sum Geometric(prob) random variables until
// the sum exceeds count. The number of variables drawn before that happens is
// Binomial(count, prob) distributed, so this is equivalent to inverting the
// binomial CDF.
double binomial_inversion(double count, double prob,
                          random::PhiloxRandom* gen) {
  using Eigen::numext::ceil;
  using Eigen::numext::log;
  using Eigen::numext::log1p;

  double geom_sum = 0;
  int num_geom = 0;

  Uniform uniform;
  typename Uniform::ResultType uniform_result;
  int16_t uniform_remaining = 0;

  while (true) {
    UNIFORM(u);
    double geom = ceil(log(u) / log1p(-prob));
    geom_sum += geom;
    if (geom_sum > count) {
      break;
    }
    ++num_geom;
  }
  return num_geom;
}

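// Correction term of Stirling's series for log(k!): values for k <= 9 are
// precomputed, and larger k fall back to the asymptotic expansion in
// 1 / (k + 1). btrs() below uses these terms when evaluating its acceptance
// bound.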
inline double stirling_approx_tail(double k) {
  static double kTailValues[] = {0.0810614667953272,  0.0413406959554092,
                                 0.0276779256849983,  0.02079067210376509,
                                 0.0166446911898211,  0.0138761288230707,
                                 0.0118967099458917,  0.0104112652619720,
                                 0.00925546218271273, 0.00833056343336287};
  if (k <= 9) {
    return kTailValues[static_cast<int>(k)];
  }
  double kp1sq = (k + 1) * (k + 1);
  return (1.0 / 12 - (1.0 / 360 - 1.0 / 1260 / kp1sq) / kp1sq) / (k + 1);
}

// We use a transformation-rejection algorithm from
// pairs of uniform random variables due to Hormann.
// https://www.tandfonline.com/doi/abs/10.1080/00949659308811496
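// Each iteration draws a (u, v) pair and maps it to a candidate k; the sample
// is accepted either by the cheap "squeeze" test (us >= 0.07 && v <= v_r) or
// by the full comparison against the log-density bound below.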
inline double btrs(double count, double prob, random::PhiloxRandom* gen) {
  using Eigen::numext::abs;
  using Eigen::numext::floor;
  using Eigen::numext::log;
  using Eigen::numext::log1p;
  using Eigen::numext::sqrt;

  // This is spq in the paper.
  const double stddev = sqrt(count * prob * (1 - prob));

  // Other coefficients for Transformed Rejection sampling.
  const double b = 1.15 + 2.53 * stddev;
  const double a = -0.0873 + 0.0248 * b + 0.01 * prob;
  const double c = count * prob + 0.5;
  const double v_r = 0.92 - 4.2 / b;
  const double r = prob / (1 - prob);

  const double alpha = (2.83 + 5.1 / b) * stddev;
  const double m = floor((count + 1) * prob);

  Uniform uniform;
  typename Uniform::ResultType uniform_result;
  int16_t uniform_remaining = 0;

  while (true) {
    UNIFORM(u);
    UNIFORM(v);
    u = u - 0.5;
    double us = 0.5 - abs(u);
    double k = floor((2 * a / us + b) * u + c);

    // Region for which the box is tight, and we can return our calculated
    // value. This should happen roughly 0.86 * v_r of the time. In the limit
    // as n * p becomes large, the acceptance rate converges to ~79% (and in
    // the lower regime it is ~24%).
    if (us >= 0.07 && v <= v_r) {
      return k;
    }
    // Reject nonsensical answers.
    if (k < 0 || k > count) {
      continue;
    }

    // This deviates from Hormann's BTRS algorithm, as there is a log missing.
    // For all (u, v) pairs outside of the bounding box, this calculates the
    // transformed-reject ratio.
    v = log(v * alpha / (a / (us * us) + b));
    double upperbound =
        ((m + 0.5) * log((m + 1) / (r * (count - m + 1))) +
         (count + 1) * log((count - m + 1) / (count - k + 1)) +
         (k + 0.5) * log(r * (count - k + 1) / (k + 1)) +
         stirling_approx_tail(m) + stirling_approx_tail(count - m) -
         stirling_approx_tail(k) - stirling_approx_tail(count - k));
    if (v <= upperbound) {
      return k;
    }
  }
}

}  // namespace

namespace functor {

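// CPU specialization: fills `output` with binomial samples by sharding the
// flat output across the intra-op threadpool. Each output element re-derives
// its own Philox stream from `gen` via Skip(), so results are deterministic
// regardless of how the work is partitioned.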
template <typename T, typename U>
struct RandomBinomialFunctor<CPUDevice, T, U> {
  void operator()(OpKernelContext* ctx, const CPUDevice& d, int64_t num_batches,
                  int64_t samples_per_batch, int64_t num_elements,
                  const BCast& bcast, typename TTypes<T>::ConstFlat counts,
                  typename TTypes<T>::ConstFlat probs,
                  const random::PhiloxRandom& gen,
                  typename TTypes<U>::Flat output) {
    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());

    // The output layout is [B1, ... Bk, H1, ... Hm]. We have [B1, ... Bk] for
    // the sample shape and [H1, ... Hm] for the batch shape of the samples.
    // We need B1 * ... * Bk samples per batch member.
    auto DoWork = [num_batches, samples_per_batch, &bcast, &counts, &probs,
                   &gen, &output](int64_t start_output, int64_t limit_output) {
      // Vectorized intermediate calculations for uniform rejection sampling.
      // We always generate at most 4 samples.
      Eigen::array<T, 4> z;
      Eigen::array<T, 4> g;
      const bool should_bcast = bcast.IsBroadcastingRequired();
      const auto& counts_batch_indices = bcast.x_batch_indices();
      const auto& probs_batch_indices = bcast.y_batch_indices();
      auto output_flat = output.data();

      // We partition work across batches (count, prob) and then across samples
      // per batch member, to avoid extra work.
      for (int64_t output_idx = start_output; output_idx < limit_output;
           // output_idx is incremented with the inner loops below.
      ) {
        int64_t batch_idx = output_idx / samples_per_batch;
        U* const output_batch_offset = output_flat + batch_idx;
        // Generate batch counts from BCast, as it has the right indices to loop
        // over.
        T count, prob;
        if (should_bcast) {
          count = counts(counts_batch_indices[batch_idx]);
          prob = probs(probs_batch_indices[batch_idx]);
        } else {
          count = counts(batch_idx);
          prob = probs(batch_idx);
        }

        // Determine which sampling method to use, based on count and prob.
        double dcount = static_cast<double>(count);
        if (dcount <= 0.0 || prob <= T(0.0)) {
          for (int64_t sample_idx = output_idx % samples_per_batch;
               sample_idx < samples_per_batch && output_idx < limit_output;
               ++sample_idx, ++output_idx) {
            output_batch_offset[sample_idx * num_batches] = static_cast<U>(0.0);
          }
        } else if (prob >= T(1.0)) {
          for (int64_t sample_idx = output_idx % samples_per_batch;
               sample_idx < samples_per_batch && output_idx < limit_output;
               ++sample_idx, ++output_idx) {
            output_batch_offset[sample_idx * num_batches] =
                static_cast<U>(dcount);
          }
        } else if (prob <= T(0.5)) {
          double dp = static_cast<double>(prob);
          if (count * prob >= T(10)) {
            for (int64_t sample_idx = output_idx % samples_per_batch;
                 sample_idx < samples_per_batch && output_idx < limit_output;
                 ++sample_idx, ++output_idx) {
              random::PhiloxRandom gen_copy = gen;
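              // BTRS consumes a variable number of uniforms per sample, so
              // each output element gets its own block of 256 Philox samples;
              // this keeps the per-element streams from overlapping.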
              gen_copy.Skip(256 * output_idx);
              output_batch_offset[sample_idx * num_batches] =
                  static_cast<U>(btrs(dcount, dp, &gen_copy));
            }
          } else {
            for (int64_t sample_idx = output_idx % samples_per_batch;
                 sample_idx < samples_per_batch && output_idx < limit_output;
                 ++sample_idx, ++output_idx) {
              random::PhiloxRandom gen_copy = gen;
              // For binomial inversion, we have mean <= 10, variance <= 10.
              // This means we need at most 10 samples on average, and 42
              // samples to cover 10 standard deviations. We reserve that many.
              gen_copy.Skip(42 * output_idx);
              output_batch_offset[sample_idx * num_batches] =
                  static_cast<U>(binomial_inversion(dcount, dp, &gen_copy));
            }
          }
        } else if (prob > T(0.5)) {
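          // For prob > 0.5, draw the number of failures with probability
          // q = 1 - prob and return count minus that draw, so the probability
          // handed to the samplers stays at or below 0.5.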
          T q = T(1) - prob;
          double dq = static_cast<double>(q);
          if (count * q >= T(10)) {
            for (int64_t sample_idx = output_idx % samples_per_batch;
                 sample_idx < samples_per_batch && output_idx < limit_output;
                 ++sample_idx, ++output_idx) {
              random::PhiloxRandom gen_copy = gen;
              gen_copy.Skip(256 * output_idx);
              output_batch_offset[sample_idx * num_batches] =
                  static_cast<U>(dcount - btrs(dcount, dq, &gen_copy));
            }
          } else {
            for (int64_t sample_idx = output_idx % samples_per_batch;
                 sample_idx < samples_per_batch && output_idx < limit_output;
                 ++sample_idx, ++output_idx) {
              random::PhiloxRandom gen_copy = gen;
              // For binomial inversion, we have mean <= 10, variance <= 10.
              // This means we need at most 10 samples on average, and 42
              // samples to cover 10 standard deviations. We reserve that many.
              gen_copy.Skip(42 * output_idx);
              output_batch_offset[sample_idx * num_batches] = static_cast<U>(
                  dcount - binomial_inversion(dcount, dq, &gen_copy));
            }
          }
        } else {  // prob is NaN
          // TODO(srvasude): What should happen if prob is NaN but the output
          // type is an integer (which doesn't have a sentinel for NaN)?  Fail
          // the whole batch sample?  Return a specialized sentinel like -1?
          for (int64_t sample_idx = output_idx % samples_per_batch;
               sample_idx < samples_per_batch && output_idx < limit_output;
               ++sample_idx, ++output_idx) {
            output_batch_offset[sample_idx * num_batches] = static_cast<U>(NAN);
          }
        }
      }
    };

    // The cost estimate depends on count * p (or count * q).
    // For n * p < 10, binomial inversion needs O(n * p) uniform draws on
    // average (with as many multiplies): ~10 uniform calls on average, plus
    // ~200 cycles' worth of other ops.
    //
    // Very roughly, for rate >= 10, the four calls to log occur for ~72
    // percent of samples:
    //   4 x 100 (64-bit cycles per log) * 0.72 = ~288.
    // Additionally, there are ~10 other ops (+, *, /, ...) at 3-6 cycles each:
    //   40 * 0.72 = ~25.
    //
    // Finally, every loop iteration also does 2 uniform generations plus ~5
    // other ops at 3-6 cycles each:
    //   ~15 / 0.89 = ~16.
    //
    // In total, the rate >= 10 path should cost
    // ~329 + 2 * Uniform::kElementCost. We assume that half the tensor has
    // rate < 10, so on average 6 uniform draws will be needed. We upper-bound
    // the remaining op cost by the one for rate >= 10.
    static const int kElementCost = 329 + 6 * Uniform::kElementCost +
                                    6 * random::PhiloxRandom::kElementCost;
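    // Shard partitions [0, num_elements) across the intra-op threadpool, using
    // kElementCost (approximate cycles per element) to pick the granularity.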
    Shard(worker_threads.num_threads, worker_threads.workers, num_elements,
          kElementCost, DoWork);
  }
};

}  // namespace functor

namespace {

// Samples from a binomial distribution, using the given parameters.
template <typename Device, typename T, typename U>
class RandomBinomialOp : public OpKernel {
  // Reshape batches so each batch is this size if possible.
  static constexpr int32_t kDesiredBatchSize = 100;

 public:
  explicit RandomBinomialOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& alg_tensor = ctx->input(1);
    const Tensor& shape_tensor = ctx->input(2);
    const Tensor& counts_tensor = ctx->input(3);
    const Tensor& probs_tensor = ctx->input(4);

    tensorflow::BCast bcast(counts_tensor.shape().dim_sizes(),
                            probs_tensor.shape().dim_sizes(),
                            /*fewer_dims_optimization=*/false,
                            /*return_flattened_batch_indices=*/true);
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "counts and probs must have compatible batch dimensions: ",
                    counts_tensor.shape().DebugString(), " vs. ",
                    probs_tensor.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
        errors::InvalidArgument("Input shape should be a vector, got shape: ",
                                shape_tensor.shape().DebugString()));
    OP_REQUIRES(ctx,
                (shape_tensor.dtype() == DataType::DT_INT32 ||
                 shape_tensor.dtype() == DataType::DT_INT64),
                errors::InvalidArgument(
                    "Input shape should have dtype {int32, int64}."));

    // Let's check that the shape tensor dominates the broadcasted tensor.
    TensorShape bcast_shape = BCast::ToShape(bcast.output_shape());
    TensorShape output_shape;
    if (shape_tensor.dtype() == DataType::DT_INT32) {
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int32>(),
                                                      &output_shape));
    } else {
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
                              shape_tensor.vec<int64_t>(), &output_shape));
    }
    OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(output_shape, bcast_shape),
                errors::InvalidArgument(
                    "Shape passed in must end with broadcasted shape."));
    // Now that we have a guarantee, we can get the additional dimensions added
    // by sampling.
    OP_REQUIRES(ctx, alg_tensor.dims() == 0,
                errors::InvalidArgument("algorithm must be of shape [], not ",
                                        alg_tensor.shape().DebugString()));
    Algorithm alg = Algorithm(alg_tensor.flat<int64_t>()(0));

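    // The leading (shape rank - broadcast rank) entries of `shape` are the
    // extra sample dimensions; the trailing entries were verified above to
    // match the broadcast batch shape of counts and probs.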
    int64_t samples_per_batch = 1;
    const int64_t num_sample_dims =
        (shape_tensor.dim_size(0) - bcast.output_shape().size());
    for (int64_t i = 0; i < num_sample_dims; ++i) {
      samples_per_batch *= shape_tensor.flat<int32>()(i);
    }
    int64_t num_batches = 1;
    for (int64_t i = num_sample_dims; i < shape_tensor.dim_size(0); ++i) {
      num_batches *= shape_tensor.flat<int32>()(i);
    }
    const int64_t num_elements = num_batches * samples_per_batch;

    Tensor* samples_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &samples_tensor));

    core::RefCountPtr<Var> var;
    OP_REQUIRES_OK(ctx, LookupResource(ctx, HandleFromInput(ctx, 0), &var));

    Tensor* var_tensor = var->tensor();
    OP_REQUIRES(
        ctx, var_tensor->dtype() == STATE_ELEMENT_DTYPE,
        errors::InvalidArgument("dtype of RNG state variable must be ",
                                DataTypeString(STATE_ELEMENT_DTYPE), ", not ",
                                DataTypeString(var_tensor->dtype())));
    OP_REQUIRES(ctx, var_tensor->dims() == 1,
                errors::InvalidArgument(
                    "RNG state must have one and only one dimension, not ",
                    var_tensor->dims()));
    auto var_tensor_flat = var_tensor->flat<StateElementType>();
    OP_REQUIRES(ctx, alg == RNG_ALG_PHILOX,
                errors::InvalidArgument("Unsupported algorithm id: ", alg));
    static_assert(std::is_same<StateElementType, int64_t>::value,
                  "StateElementType must be int64");
    static_assert(std::is_same<PhiloxRandom::ResultElementType, uint32>::value,
                  "PhiloxRandom::ResultElementType must be uint32");
    OP_REQUIRES(ctx, var_tensor_flat.size() >= PHILOX_MIN_STATE_SIZE,
                errors::InvalidArgument(
                    "For Philox algorithm, the size of state must be at least ",
                    PHILOX_MIN_STATE_SIZE, "; got ", var_tensor_flat.size()));

    OP_REQUIRES_OK(ctx, PrepareToUpdateVariable<Device, StateElementType>(
                            ctx, var_tensor, var->copy_on_read_mode.load()));
    auto var_data = var_tensor_flat.data();
    auto philox = GetPhiloxRandomFromMem(var_data);
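    // Advance the RNG state stored in the variable past the counters this op
    // will consume, so later stateful ops draw from a non-overlapping stream.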
    UpdateMemWithPhiloxRandom(
        philox, num_batches * 2 * 100 * (samples_per_batch + 3) / 4, var_data);

    auto binomial_functor = functor::RandomBinomialFunctor<Device, T, U>();
    binomial_functor(ctx, ctx->eigen_device<Device>(), num_batches,
                     samples_per_batch, num_elements, bcast,
                     counts_tensor.flat<T>(), probs_tensor.flat<T>(), philox,
                     samples_tensor->flat<U>());
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(RandomBinomialOp);
};

// Samples from a binomial distribution, using the given parameters.
template <typename Device, typename T, typename U>
class StatelessRandomBinomialOp : public OpKernel {
  // Reshape batches so each batch is this size if possible.
  static constexpr int32_t kDesiredBatchSize = 100;

 public:
  explicit StatelessRandomBinomialOp(OpKernelConstruction* context)
      : OpKernel(context) {}

  void Compute(OpKernelContext* ctx) override {
    const Tensor& shape_tensor = ctx->input(0);
    const Tensor& seed_tensor = ctx->input(1);
    const Tensor& counts_tensor = ctx->input(2);
    const Tensor& probs_tensor = ctx->input(3);

    OP_REQUIRES(ctx, seed_tensor.dims() == 1 && seed_tensor.dim_size(0) == 2,
                errors::InvalidArgument("seed must have shape [2], not ",
                                        seed_tensor.shape().DebugString()));

    tensorflow::BCast bcast(counts_tensor.shape().dim_sizes(),
                            probs_tensor.shape().dim_sizes(),
                            /*fewer_dims_optimization=*/false,
                            /*return_flattened_batch_indices=*/true);
    OP_REQUIRES(ctx, bcast.IsValid(),
                errors::InvalidArgument(
                    "counts and probs must have compatible batch dimensions: ",
                    counts_tensor.shape().DebugString(), " vs. ",
                    probs_tensor.shape().DebugString()));
    OP_REQUIRES(
        ctx, TensorShapeUtils::IsVector(shape_tensor.shape()),
        errors::InvalidArgument("Input shape should be a vector, got shape: ",
                                shape_tensor.shape().DebugString()));
    OP_REQUIRES(ctx,
                (shape_tensor.dtype() == DataType::DT_INT32 ||
                 shape_tensor.dtype() == DataType::DT_INT64),
                errors::InvalidArgument(
                    "Input shape should have dtype {int32, int64}."));

    // Let's check that the shape tensor dominates the broadcasted tensor.
    TensorShape bcast_shape = BCast::ToShape(bcast.output_shape());
    TensorShape output_shape;
    if (shape_tensor.dtype() == DataType::DT_INT32) {
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(shape_tensor.vec<int32>(),
                                                      &output_shape));
    } else {
      OP_REQUIRES_OK(ctx, TensorShapeUtils::MakeShape(
                              shape_tensor.vec<int64_t>(), &output_shape));
    }
    OP_REQUIRES(ctx, TensorShapeUtils::EndsWith(output_shape, bcast_shape),
                errors::InvalidArgument(
                    "Shape passed in must end with broadcasted shape."));
    // Now that we have a guarantee, we can get the additional dimensions added
    // by sampling.
    int64_t samples_per_batch = 1;
    const int64_t num_sample_dims =
        (shape_tensor.dim_size(0) - bcast.output_shape().size());
    for (int64_t i = 0; i < num_sample_dims; ++i) {
      samples_per_batch *= shape_tensor.flat<int32>()(i);
    }
    int64_t num_batches = 1;
    for (int64_t i = num_sample_dims; i < shape_tensor.dim_size(0); ++i) {
      num_batches *= shape_tensor.flat<int32>()(i);
    }
    const int64_t num_elements = num_batches * samples_per_batch;

    Tensor* samples_tensor;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, output_shape, &samples_tensor));
    if (output_shape.num_elements() == 0) return;

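    // Derive the Philox key and counter deterministically from the [2] seed so
    // that identical seeds reproduce identical samples.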
    random::PhiloxRandom::Key key;
    random::PhiloxRandom::ResultType counter;
    OP_REQUIRES_OK(ctx, GenerateKey(seed_tensor, &key, &counter));

    auto philox = random::PhiloxRandom(counter, key);
    auto binomial_functor = functor::RandomBinomialFunctor<Device, T, U>();
    binomial_functor(ctx, ctx->eigen_device<Device>(), num_batches,
                     samples_per_batch, num_elements, bcast,
                     counts_tensor.flat<T>(), probs_tensor.flat<T>(), philox,
                     samples_tensor->flat<U>());
  }

 private:
  TF_DISALLOW_COPY_AND_ASSIGN(StatelessRandomBinomialOp);
};

}  // namespace

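// Registers the stateful and stateless CPU kernels. RTYPE is the output sample
// dtype (the "dtype" attr) and TYPE is the dtype of counts and probs ("T").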
#define REGISTER(RTYPE, TYPE)                                        \
  REGISTER_KERNEL_BUILDER(Name("StatefulRandomBinomial")             \
                              .Device(DEVICE_CPU)                    \
                              .HostMemory("resource")                \
                              .HostMemory("algorithm")               \
                              .HostMemory("shape")                   \
                              .HostMemory("counts")                  \
                              .HostMemory("probs")                   \
                              .TypeConstraint<RTYPE>("dtype")        \
                              .TypeConstraint<TYPE>("T"),            \
                          RandomBinomialOp<CPUDevice, TYPE, RTYPE>); \
  REGISTER_KERNEL_BUILDER(Name("StatelessRandomBinomial")            \
                              .Device(DEVICE_CPU)                    \
                              .HostMemory("shape")                   \
                              .HostMemory("seed")                    \
                              .HostMemory("counts")                  \
                              .HostMemory("probs")                   \
                              .TypeConstraint<RTYPE>("dtype")        \
                              .TypeConstraint<TYPE>("T"),            \
                          StatelessRandomBinomialOp<CPUDevice, TYPE, RTYPE>)

#define REGISTER_ALL(RTYPE)     \
  REGISTER(RTYPE, Eigen::half); \
  REGISTER(RTYPE, float);       \
  REGISTER(RTYPE, double);

REGISTER_ALL(Eigen::half);
REGISTER_ALL(float);
REGISTER_ALL(double);
REGISTER_ALL(int32);
REGISTER_ALL(int64_t);

#undef REGISTER
#undef REGISTER_ALL

}  // end namespace tensorflow
563