xref: /aosp_15_r20/external/tensorflow/tensorflow/python/ops/candidate_sampling_ops.py (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15
16"""Wrappers for candidate sampling operations."""
17
18from tensorflow.python.framework import random_seed
19from tensorflow.python.ops import array_ops  # pylint: disable=unused-import
20from tensorflow.python.ops import gen_candidate_sampling_ops
21from tensorflow.python.ops import math_ops  # pylint: disable=unused-import
22from tensorflow.python.util import deprecation
23from tensorflow.python.util import dispatch
24from tensorflow.python.util.tf_export import tf_export
25
26
@tf_export(
    'random.uniform_candidate_sampler',
    v1=['random.uniform_candidate_sampler', 'nn.uniform_candidate_sampler'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('nn.uniform_candidate_sampler')
def uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
                              range_max, seed=None, name=None):
  """Draws candidate classes from a uniform base distribution.

  The op produces a tensor of sampled classes (`sampled_candidates`) whose
  entries are integers taken from `[0, range_max)`.  The base distribution is
  the uniform distribution over that range.

  Sampling is done without replacement when `unique=True` and with
  replacement when `unique=False`.

  Besides the samples themselves, two tensors are returned,
  `true_expected_count` and `sampled_expected_count`.  They give the number
  of times each target class (`true_classes`) and each sampled class
  (`sampled_candidates`) is expected to occur in an average tensor of sampled
  classes.  These values correspond to `Q(y|x)` defined in [this
  document](http://www.tensorflow.org/extras/candidate_sampling.pdf).
  When `unique=True` they are post-rejection probabilities and are computed
  approximately.

  Args:
    true_classes: A `Tensor` of type `int64` and shape `[batch_size,
      num_true]`. The target classes.
    num_true: An `int`.  The number of target classes per training example.
    num_sampled: An `int`.  The number of classes to randomly sample. The
      `sampled_candidates` return value will have shape `[num_sampled]`. If
      `unique=True`, `num_sampled` must be less than or equal to `range_max`.
    unique: A `bool`. Determines whether all sampled classes in a batch are
      unique.
    range_max: An `int`. The number of possible classes.
    seed: An `int`. An operation-specific seed. Defaults to `None`; the two
      seeds handed to the underlying op come from
      `random_seed.get_seed(seed)`.
    name: A name for the operation (optional).

  Returns:
    sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.  The
      sampled classes, either with possible duplicates (`unique=False`) or all
      unique (`unique=True`). In either case, `sampled_candidates` is
      independent of the true classes.
    true_expected_count: A tensor of type `float`.  Same shape as
      `true_classes`. The expected counts under the sampling distribution
      of each of `true_classes`.
    sampled_expected_count: A tensor of type `float`. Same shape as
      `sampled_candidates`. The expected counts under the sampling distribution
      of each of `sampled_candidates`.
  """
  # Expand the user-facing seed into the seed pair the generated op takes.
  op_seed, op_seed2 = random_seed.get_seed(seed)
  return gen_candidate_sampling_ops.uniform_candidate_sampler(
      true_classes,
      num_true,
      num_sampled,
      unique,
      range_max,
      seed=op_seed,
      seed2=op_seed2,
      name=name)
85
86
@tf_export(
    'random.log_uniform_candidate_sampler',
    v1=[
        'random.log_uniform_candidate_sampler',
        'nn.log_uniform_candidate_sampler'
    ])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('nn.log_uniform_candidate_sampler')
def log_uniform_candidate_sampler(true_classes, num_true, num_sampled, unique,
                                  range_max, seed=None, name=None):
  """Draws candidate classes from a log-uniform (Zipfian) base distribution.

  The op produces a tensor of sampled classes (`sampled_candidates`) whose
  entries are integers taken from `[0, range_max)`.

  Sampling is done without replacement when `unique=True` and with
  replacement when `unique=False`.

  The base distribution is an approximately log-uniform or Zipfian
  distribution:

  `P(class) = (log(class + 2) - log(class + 1)) / log(range_max + 1)`

  Use this sampler when the target classes approximately follow such a
  distribution - for example, when the classes are words in a lexicon sorted
  in decreasing order of frequency. If your classes are not ordered by
  decreasing frequency, do not use this op.

  Besides the samples themselves, two tensors are returned,
  `true_expected_count` and `sampled_expected_count`.  They give the number
  of times each target class (`true_classes`) and each sampled class
  (`sampled_candidates`) is expected to occur in an average tensor of sampled
  classes.  These values correspond to `Q(y|x)` defined in [this
  document](http://www.tensorflow.org/extras/candidate_sampling.pdf).
  When `unique=True` they are post-rejection probabilities and are computed
  approximately.

  Args:
    true_classes: A `Tensor` of type `int64` and shape `[batch_size,
      num_true]`. The target classes.
    num_true: An `int`.  The number of target classes per training example.
    num_sampled: An `int`.  The number of classes to randomly sample.
    unique: A `bool`. Determines whether all sampled classes in a batch are
      unique.
    range_max: An `int`. The number of possible classes.
    seed: An `int`. An operation-specific seed. Defaults to `None`; the two
      seeds handed to the underlying op come from
      `random_seed.get_seed(seed)`.
    name: A name for the operation (optional).

  Returns:
    sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.
      The sampled classes.
    true_expected_count: A tensor of type `float`.  Same shape as
      `true_classes`. The expected counts under the sampling distribution
      of each of `true_classes`.
    sampled_expected_count: A tensor of type `float`. Same shape as
      `sampled_candidates`. The expected counts under the sampling distribution
      of each of `sampled_candidates`.
  """
  # Expand the user-facing seed into the seed pair the generated op takes.
  op_seed, op_seed2 = random_seed.get_seed(seed)
  return gen_candidate_sampling_ops.log_uniform_candidate_sampler(
      true_classes,
      num_true,
      num_sampled,
      unique,
      range_max,
      seed=op_seed,
      seed2=op_seed2,
      name=name)
151
152
@tf_export(
    'random.learned_unigram_candidate_sampler',
    'nn.learned_unigram_candidate_sampler')
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints(['nn.learned_unigram_candidate_sampler'])
def learned_unigram_candidate_sampler(true_classes, num_true, num_sampled,
                                      unique, range_max, seed=None, name=None):
  """Draws candidate classes from a distribution learned during training.

  The op produces a tensor of sampled classes (`sampled_candidates`) whose
  entries are integers taken from `[0, range_max)`.

  Sampling is done without replacement when `unique=True` and with
  replacement when `unique=False`.

  The base distribution is built on the fly during training.  It is a unigram
  distribution over the target classes seen so far: every integer in
  `[0, range_max)` begins with a weight of 1, and that weight is incremented
  by 1 each time the integer is seen as a target class.  The base
  distribution is not saved to checkpoints, so it is reset when the model is
  reloaded.

  Besides the samples themselves, two tensors are returned,
  `true_expected_count` and `sampled_expected_count`.  They give the number
  of times each target class (`true_classes`) and each sampled class
  (`sampled_candidates`) is expected to occur in an average tensor of sampled
  classes.  These values correspond to `Q(y|x)` defined in [this
  document](http://www.tensorflow.org/extras/candidate_sampling.pdf).
  When `unique=True` they are post-rejection probabilities and are computed
  approximately.

  Args:
    true_classes: A `Tensor` of type `int64` and shape `[batch_size,
      num_true]`. The target classes.
    num_true: An `int`.  The number of target classes per training example.
    num_sampled: An `int`.  The number of classes to randomly sample.
    unique: A `bool`. Determines whether all sampled classes in a batch are
      unique.
    range_max: An `int`. The number of possible classes. Must not exceed
      2147483647 (the largest `int32` value).
    seed: An `int`. An operation-specific seed. Defaults to `None`; the two
      seeds handed to the underlying op come from
      `random_seed.get_seed(seed)`.
    name: A name for the operation (optional).

  Returns:
    sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.
      The sampled classes.
    true_expected_count: A tensor of type `float`.  Same shape as
      `true_classes`. The expected counts under the sampling distribution
      of each of `true_classes`.
    sampled_expected_count: A tensor of type `float`. Same shape as
      `sampled_candidates`. The expected counts under the sampling distribution
      of each of `sampled_candidates`.

  Raises:
    ValueError: If `range_max` is greater than 2147483647.
  """
  # Expand the user-facing seed into the seed pair the generated op takes.
  op_seed, op_seed2 = random_seed.get_seed(seed)
  # `range_max` is capped at the largest int32 value.
  max_int32 = 2147483647
  if range_max > max_int32:
    raise ValueError(f'Value of range_max:{range_max} is too large to handle')
  return gen_candidate_sampling_ops.learned_unigram_candidate_sampler(
      true_classes,
      num_true,
      num_sampled,
      unique,
      range_max,
      seed=op_seed,
      seed2=op_seed2,
      name=name)
215
216
@tf_export('random.fixed_unigram_candidate_sampler',
           'nn.fixed_unigram_candidate_sampler')
@dispatch.add_dispatch_support
def fixed_unigram_candidate_sampler(true_classes,
                                    num_true,
                                    num_sampled,
                                    unique,
                                    range_max,
                                    vocab_file='',
                                    distortion=1.0,
                                    num_reserved_ids=0,
                                    num_shards=1,
                                    shard=0,
                                    unigrams=(),
                                    seed=None,
                                    name=None):
  """Draws candidate classes from a provided (fixed) base distribution.

  The op produces a tensor of sampled classes (`sampled_candidates`) whose
  entries are integers taken from `[0, range_max)`.

  Sampling is done without replacement when `unique=True` and with
  replacement when `unique=False`.

  The base distribution is either read from a file or passed in as an
  in-memory array. The distribution can optionally be skewed by applying a
  distortion power to the weights.

  Besides the samples themselves, two tensors are returned,
  `true_expected_count` and `sampled_expected_count`.  They give the number
  of times each target class (`true_classes`) and each sampled class
  (`sampled_candidates`) is expected to occur in an average tensor of sampled
  classes.  These values correspond to `Q(y|x)` defined in [this
  document](http://www.tensorflow.org/extras/candidate_sampling.pdf).
  When `unique=True` they are post-rejection probabilities and are computed
  approximately.

  Args:
    true_classes: A `Tensor` of type `int64` and shape `[batch_size,
      num_true]`. The target classes.
    num_true: An `int`.  The number of target classes per training example.
    num_sampled: An `int`.  The number of classes to randomly sample.
    unique: A `bool`. Determines whether all sampled classes in a batch are
      unique.
    range_max: An `int`. The number of possible classes.
    vocab_file: Each valid line in this file (which should have a CSV-like
      format) corresponds to a valid word ID. IDs are in sequential order,
      starting from num_reserved_ids. The last entry in each line is expected
      to be a value corresponding to the count or relative probability. Exactly
      one of `vocab_file` and `unigrams` needs to be passed to this operation.
    distortion: The distortion is used to skew the unigram probability
      distribution.  Each weight is first raised to the distortion's power
      before adding to the internal unigram distribution. As a result,
      `distortion = 1.0` gives regular unigram sampling (as defined by the vocab
      file), and `distortion = 0.0` gives a uniform distribution.
    num_reserved_ids: Optionally some reserved IDs can be added in the range
      `[0, num_reserved_ids)` by the users. One use case is that a special
      unknown word token is used as ID 0. These IDs will have a sampling
      probability of 0.
    num_shards: A sampler can be used to sample from a subset of the original
      range in order to speed up the whole computation through parallelism. This
      parameter (together with `shard`) indicates the number of partitions that
      are being used in the overall computation.
    shard: A sampler can be used to sample from a subset of the original range
      in order to speed up the whole computation through parallelism. This
      parameter (together with `num_shards`) indicates the particular partition
      number of the operation, when partitioning is being used.
    unigrams: A list of unigram counts or probabilities, one per ID in
      sequential order. Exactly one of `vocab_file` and `unigrams` should be
      passed to this operation.
    seed: An `int`. An operation-specific seed. Defaults to `None`; the two
      seeds handed to the underlying op come from
      `random_seed.get_seed(seed)`.
    name: A name for the operation (optional).

  Returns:
    sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.
      The sampled classes.
    true_expected_count: A tensor of type `float`.  Same shape as
      `true_classes`. The expected counts under the sampling distribution
      of each of `true_classes`.
    sampled_expected_count: A tensor of type `float`. Same shape as
      `sampled_candidates`. The expected counts under the sampling distribution
      of each of `sampled_candidates`.

  """
  # Expand the user-facing seed into the seed pair the generated op takes.
  op_seed, op_seed2 = random_seed.get_seed(seed)
  return gen_candidate_sampling_ops.fixed_unigram_candidate_sampler(
      true_classes,
      num_true,
      num_sampled,
      unique,
      range_max,
      vocab_file=vocab_file,
      distortion=distortion,
      num_reserved_ids=num_reserved_ids,
      num_shards=num_shards,
      shard=shard,
      unigrams=unigrams,
      seed=op_seed,
      seed2=op_seed2,
      name=name)
309
310
@tf_export('random.all_candidate_sampler', 'nn.all_candidate_sampler')
def all_candidate_sampler(true_classes, num_true, num_sampled, unique,
                          seed=None, name=None):
  """Generate the set of all classes.

  Returns the set of all possible classes, generated deterministically.
  Intended for testing purposes.  There is no need to use this, since you
  might as well use full softmax or full logistic regression.

  Args:
    true_classes: A `Tensor` of type `int64` and shape `[batch_size,
      num_true]`. The target classes.
    num_true: An `int`.  The number of target classes per training example.
    num_sampled: An `int`.  The number of possible classes.
    unique: A `bool`. Ignored.
    seed: An `int`. An operation-specific seed. Defaults to `None`; the two
      seeds handed to the underlying op come from
      `random_seed.get_seed(seed)`.
    name: A name for the operation (optional).

  Returns:
    sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.
      This operation deterministically returns the entire range
      `[0, num_sampled)`.
    true_expected_count: A tensor of type `float`.  Same shape as
      `true_classes`. The expected counts under the sampling distribution
      of each of `true_classes`. All returned values are 1.0.
    sampled_expected_count: A tensor of type `float`. Same shape as
      `sampled_candidates`. The expected counts under the sampling distribution
      of each of `sampled_candidates`. All returned values are 1.0.
  """
  # Expand the user-facing seed into the seed pair the generated op takes.
  op_seed, op_seed2 = random_seed.get_seed(seed)
  return gen_candidate_sampling_ops.all_candidate_sampler(
      true_classes,
      num_true,
      num_sampled,
      unique,
      seed=op_seed,
      seed2=op_seed2,
      name=name)
345
346
@tf_export('nn.compute_accidental_hits')
@dispatch.add_dispatch_support
def compute_accidental_hits(true_classes, sampled_candidates, num_true,
                            seed=None, name=None):
  """Compute the position ids in `sampled_candidates` matching `true_classes`.

  In Candidate Sampling, this operation makes it possible to virtually remove
  sampled classes which happen to match target classes.  Sampled Softmax and
  Sampled Logistic both do this.

  See our [Candidate Sampling Algorithms
  Reference](http://www.tensorflow.org/extras/candidate_sampling.pdf).

  The `sampled_candidates` are presumed to be unique.

  An 'accidental hit' is a target class that matches one of the sampled
  classes.  Each accidental hit is reported as a triple
  `(index, id, weight)`, where `index` is the row number in `true_classes`,
  `id` is the position in `sampled_candidates`, and weight is `-FLOAT_MAX`.

  The result of this op should be passed through a `sparse_to_dense`
  operation, then added to the logits of the sampled classes. Doing so
  removes the contradictory effect of accidentally sampling the true target
  classes as noise classes for the same example.

  Args:
    true_classes: A `Tensor` of type `int64` and shape `[batch_size,
      num_true]`. The target classes.
    sampled_candidates: A tensor of type `int64` and shape `[num_sampled]`.
      The sampled_candidates output of CandidateSampler.
    num_true: An `int`.  The number of target classes per training example.
    seed: An `int`. An operation-specific seed. Defaults to `None`; the two
      seeds handed to the underlying op come from
      `random_seed.get_seed(seed)`.
    name: A name for the operation (optional).

  Returns:
    indices: A `Tensor` of type `int32` and shape `[num_accidental_hits]`.
      Values indicate rows in `true_classes`.
    ids: A `Tensor` of type `int64` and shape `[num_accidental_hits]`.
      Values indicate positions in `sampled_candidates`.
    weights: A `Tensor` of type `float` and shape `[num_accidental_hits]`.
      Each value is `-FLOAT_MAX`.

  """
  # Expand the user-facing seed into the seed pair the generated op takes.
  op_seed, op_seed2 = random_seed.get_seed(seed)
  return gen_candidate_sampling_ops.compute_accidental_hits(
      true_classes,
      sampled_candidates,
      num_true,
      seed=op_seed,
      seed2=op_seed2,
      name=name)
395