# Copyright 2021 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""API for specifying `tf.data` options."""

import enum

from absl import logging

from tensorflow.core.framework import dataset_options_pb2
from tensorflow.core.framework import model_pb2
from tensorflow.python.data.util import options as options_lib
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export


@tf_export("data.experimental.AutotuneAlgorithm")
class AutotuneAlgorithm(enum.Enum):
  """Represents the type of autotuning algorithm to use.

  DEFAULT: The default behavior is implementation specific and may change over
  time.

  HILL_CLIMB: In each optimization step, this algorithm chooses the optimal
  parameter and increases its value by 1.

  GRADIENT_DESCENT: In each optimization step, this algorithm updates the
  parameter values in the optimal direction.

  MAX_PARALLELISM: Similar to HILL_CLIMB but uses a relaxed stopping condition,
  allowing the optimization to oversubscribe the CPU.

  STAGE_BASED: In each optimization step, this algorithm chooses the worst
  bottleneck parameter and increases its value by 1.
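
  For example, a specific algorithm can be requested through `tf.data.Options`
  (an illustrative sketch; the setting only takes effect while autotuning is
  enabled):

  ```python
  options = tf.data.Options()
  options.autotune.enabled = True
  options.autotune.autotune_algorithm = (
      tf.data.experimental.AutotuneAlgorithm.HILL_CLIMB)
  dataset = dataset.with_options(options)
  ```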
46  """
47  DEFAULT = 0
48  HILL_CLIMB = 1
49  GRADIENT_DESCENT = 2
50  MAX_PARALLELISM = 3
51  STAGE_BASED = 4
52
53  @classmethod
54  def _to_proto(cls, obj):
55    if obj == cls.DEFAULT:
56      return model_pb2.AutotuneAlgorithm.DEFAULT
57    if obj == cls.HILL_CLIMB:
58      return model_pb2.AutotuneAlgorithm.HILL_CLIMB
59    if obj == cls.GRADIENT_DESCENT:
60      return model_pb2.AutotuneAlgorithm.GRADIENT_DESCENT
61    if obj == cls.MAX_PARALLELISM:
62      return model_pb2.AutotuneAlgorithm.MAX_PARALLELISM
63    if obj == cls.STAGE_BASED:
64      return model_pb2.AutotuneAlgorithm.STAGE_BASED
65    raise ValueError(
66        f"Invalid `obj.` Supported values include `DEFAULT`, `HILL_CLIMB` "
67        f"`GRADIENT_DESCENT`, and `STAGE_BASED`. Got {obj.name}.")
68
69  @classmethod
70  def _from_proto(cls, pb):
71    if pb == model_pb2.AutotuneAlgorithm.DEFAULT:
72      return cls.DEFAULT
73    if pb == model_pb2.AutotuneAlgorithm.HILL_CLIMB:
74      return cls.HILL_CLIMB
75    if pb == model_pb2.AutotuneAlgorithm.GRADIENT_DESCENT:
76      return cls.GRADIENT_DESCENT
77    if pb == model_pb2.AutotuneAlgorithm.MAX_PARALLELISM:
78      return cls.MAX_PARALLELISM
79    if pb == model_pb2.AutotuneAlgorithm.STAGE_BASED:
80      return cls.STAGE_BASED
81    raise ValueError(
82        f"Invalid `pb.` Supported values include `DEFAULT`, `HILL_CLIMB`, "
83        f"`GRADIENT_DESCENT` and `STAGE_BASED`. Got {pb}.")
84
85
86@tf_export("data.experimental.AutoShardPolicy")
87class AutoShardPolicy(enum.IntEnum):
88  """Represents the type of auto-sharding to use.
89
90  OFF: No sharding will be performed.
91
92  AUTO: Attempts FILE-based sharding, falling back to DATA-based sharding.
93
94  FILE: Shards by input files (i.e. each worker will get a set of files to
95  process). When this option is selected, make sure that there is at least as
96  many files as workers. If there are fewer input files than workers, a runtime
97  error will be raised.
98
99  DATA: Shards by elements produced by the dataset. Each worker will process the
100  whole dataset and discard the portion that is not for itself. Note that for
101  this mode to correctly partitions the dataset elements, the dataset needs to
102  produce elements in a deterministic order.
103
104  HINT: Looks for the presence of `shard(SHARD_HINT, ...)` which is treated as a
105  placeholder to replace with `shard(num_workers, worker_index)`.
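
  For example, to request element-wise (DATA) sharding instead of the default
  AUTO behavior (an illustrative sketch):

  ```python
  options = tf.data.Options()
  options.experimental_distribute.auto_shard_policy = (
      tf.data.experimental.AutoShardPolicy.DATA)
  dataset = dataset.with_options(options)
  ```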
106  """
107
108  # LINT.IfChange
109  OFF = -1
110  AUTO = 0
111  FILE = 1
112  DATA = 2
113  HINT = 3
114  # LINT.ThenChange(//tensorflow/python/data/experimental/ops/data_service_ops.py:tf_data_service_sharding_policy)
115
116  @classmethod
117  def _to_proto(cls, obj):
118    """Convert enum to proto."""
119    if obj == cls.OFF:
120      return dataset_options_pb2.AutoShardPolicy.OFF
121    if obj == cls.FILE:
122      return dataset_options_pb2.AutoShardPolicy.FILE
123    if obj == cls.DATA:
124      return dataset_options_pb2.AutoShardPolicy.DATA
125    if obj == cls.AUTO:
126      return dataset_options_pb2.AutoShardPolicy.AUTO
127    if obj == cls.HINT:
128      return dataset_options_pb2.AutoShardPolicy.HINT
129    raise ValueError(
130        f"Invalid `obj.` Supported values include `OFF`, `FILE`, `DATA`,"
131        f"`AUTO`, and `HINT`. Got {obj.name}."
132    )
133
134  @classmethod
135  def _from_proto(cls, pb):
136    """Convert proto to enum."""
137    if pb == dataset_options_pb2.AutoShardPolicy.OFF:
138      return cls.OFF
139    if pb == dataset_options_pb2.AutoShardPolicy.FILE:
140      return cls.FILE
141    if pb == dataset_options_pb2.AutoShardPolicy.DATA:
142      return cls.DATA
143    if pb == dataset_options_pb2.AutoShardPolicy.AUTO:
144      return cls.AUTO
145    if pb == dataset_options_pb2.AutoShardPolicy.HINT:
146      return cls.HINT
147    raise ValueError(
148        f"Invalid `pb.` Supported values include `OFF`, `FILE`, `DATA`,"
149        f"`AUTO`, and `HINT`. Got {pb}."
150    )
151
152
153@tf_export("data.experimental.ExternalStatePolicy")
154class ExternalStatePolicy(enum.Enum):
155  """Represents how to handle external state during serialization.
156
157  See the `tf.data.Options.experimental_external_state_policy` documentation
158  for more information.
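
  For example (an illustrative sketch):

  ```python
  options = tf.data.Options()
  options.experimental_external_state_policy = (
      tf.data.experimental.ExternalStatePolicy.FAIL)
  dataset = dataset.with_options(options)
  ```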
159  """
160  WARN = 0
161  IGNORE = 1
162  FAIL = 2
163
164  @classmethod
165  def _to_proto(cls, obj):
166    """Convert enum to proto."""
167    if obj == cls.IGNORE:
168      return dataset_options_pb2.ExternalStatePolicy.POLICY_IGNORE
169    if obj == cls.FAIL:
170      return dataset_options_pb2.ExternalStatePolicy.POLICY_FAIL
171    if obj == cls.WARN:
172      return dataset_options_pb2.ExternalStatePolicy.POLICY_WARN
173    raise ValueError(
174        f"Invalid `obj.` Supported values include `POLICY_IGNORE`,"
175        f"`POLICY_FAIL`, `POLICY_WARN`. Got {obj.name}.")
176
177  @classmethod
178  def _from_proto(cls, pb):
179    """Convert proto to enum."""
180    if pb == dataset_options_pb2.ExternalStatePolicy.POLICY_IGNORE:
181      return cls.IGNORE
182    if pb == dataset_options_pb2.ExternalStatePolicy.POLICY_FAIL:
183      return cls.FAIL
184    if pb == dataset_options_pb2.ExternalStatePolicy.POLICY_WARN:
185      return cls.WARN
186    raise ValueError(
187        f"Invalid `pb.` Supported values include `POLICY_IGNORE`,"
188        f"`POLICY_FAIL`, `POLICY_WARN`. Got {pb}.")
189
190
191@tf_export("data.experimental.AutotuneOptions")
192class AutotuneOptions(options_lib.OptionsBase):
193  """Represents options for autotuning dataset performance.
194
195  ```python
196  options = tf.data.Options()
197  options.autotune.enabled = False
198  dataset = dataset.with_options(options)
199  ```
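
  The CPU and RAM budgets used by autotuning can be adjusted the same way (the
  values below are purely illustrative):

  ```python
  options = tf.data.Options()
  options.autotune.cpu_budget = 4          # use at most 4 CPU cores
  options.autotune.ram_budget = 2 << 30    # use at most 2 GiB of RAM
  dataset = dataset.with_options(options)
  ```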
200  """
201
202  enabled = options_lib.create_option(
203      name="enabled",
204      ty=bool,
205      docstring="Whether to automatically tune performance knobs. If None, "
206      "defaults to True.")
207
208  cpu_budget = options_lib.create_option(
209      name="cpu_budget",
210      ty=int,
211      docstring="When autotuning is enabled (through `autotune`), determines "
212      "the CPU budget to use. Values greater than the number of schedulable "
213      "CPU cores are allowed but may result in CPU contention. If None, "
214      "defaults to the number of schedulable CPU cores.")
215
216  ram_budget = options_lib.create_option(
217      name="ram_budget",
218      ty=int,
219      docstring="When autotuning is enabled (through `autotune`), determines "
220      "the RAM budget to use. Values greater than the available RAM in bytes "
221      "may result in OOM. If None, defaults to half of the available RAM in "
222      "bytes.")
223
224  autotune_algorithm = options_lib.create_option(
225      name="autotune_algorithm",
226      ty=AutotuneAlgorithm,
227      docstring="When autotuning is enabled (through `autotune`), determines "
228      "the algorithm to use.")
229
230  def _to_proto(self):
231    pb = dataset_options_pb2.AutotuneOptions()
232    if self.enabled is not None:
233      pb.enabled = self.enabled
234    if self.cpu_budget is not None:
235      pb.cpu_budget = self.cpu_budget
236    if self.ram_budget is not None:
237      pb.ram_budget = self.ram_budget
238    if self.autotune_algorithm is not None:
239      pb.autotune_algorithm = AutotuneAlgorithm._to_proto(  # pylint: disable=protected-access
240          self.autotune_algorithm)
241    return pb
242
243  def _from_proto(self, pb):
244    if pb.WhichOneof("optional_enabled") is not None:
245      self.enabled = pb.enabled
246    if pb.WhichOneof("optional_cpu_budget") is not None:
247      self.cpu_budget = pb.cpu_budget
248    if pb.WhichOneof("optional_ram_budget") is not None:
249      self.ram_budget = pb.ram_budget
250    if pb.WhichOneof("optional_autotune_algorithm") is not None:
251      self.autotune_algorithm = AutotuneAlgorithm._from_proto(  # pylint: disable=protected-access
252          pb.autotune_algorithm)
253
254  def _set_mutable(self, mutable):
255    """Change the mutability value to `mutable` on this options and children."""
256    # pylint: disable=protected-access
257    object.__setattr__(self, "_mutable", mutable)
258
259
260@tf_export("data.experimental.DistributeOptions")
261class DistributeOptions(options_lib.OptionsBase):
262  """Represents options for distributed data processing.
263
264  You can set the distribution options of a dataset through the
265  `experimental_distribute` property of `tf.data.Options`; the property is
266  an instance of `tf.data.experimental.DistributeOptions`.
267
268  ```python
269  options = tf.data.Options()
270  options.experimental_distribute.auto_shard_policy = tf.data.experimental.AutoShardPolicy.OFF
271  dataset = dataset.with_options(options)
272  ```
273  """
274
275  auto_shard_policy = options_lib.create_option(
276      name="auto_shard_policy",
277      ty=AutoShardPolicy,
278      docstring="The type of sharding to use. See "
279      "`tf.data.experimental.AutoShardPolicy` for additional information.",
280      default_factory=lambda: AutoShardPolicy.AUTO)
281
282  num_devices = options_lib.create_option(
283      name="num_devices",
284      ty=int,
285      docstring=
286      "The number of devices attached to this input pipeline. This will be "
287      "automatically set by `MultiDeviceIterator`.")
288
289  def _to_proto(self):
290    pb = dataset_options_pb2.DistributeOptions()
291    pb.auto_shard_policy = AutoShardPolicy._to_proto(self.auto_shard_policy)  # pylint: disable=protected-access
292    if self.num_devices is not None:
293      pb.num_devices = self.num_devices
294    return pb
295
296  def _from_proto(self, pb):
297    self.auto_shard_policy = AutoShardPolicy._from_proto(pb.auto_shard_policy)  # pylint: disable=protected-access
298    if pb.WhichOneof("optional_num_devices") is not None:
299      self.num_devices = pb.num_devices
300
301
302@tf_export("data.experimental.OptimizationOptions")
303class OptimizationOptions(options_lib.OptionsBase):
304  """Represents options for dataset optimizations.
305
306  You can set the optimization options of a dataset through the
307  `experimental_optimization` property of `tf.data.Options`; the property is
308  an instance of `tf.data.experimental.OptimizationOptions`.
309
310  ```python
311  options = tf.data.Options()
312  options.experimental_optimization.noop_elimination = True
313  options.experimental_optimization.apply_default_optimizations = False
314  dataset = dataset.with_options(options)
315  ```
316  """
317  apply_default_optimizations = options_lib.create_option(
318      name="apply_default_optimizations",
319      ty=bool,
320      docstring=
321      "Whether to apply default graph optimizations. If False, only graph "
322      "optimizations that have been explicitly enabled will be applied.")
323
324  filter_fusion = options_lib.create_option(
325      name="filter_fusion",
326      ty=bool,
327      docstring=
328      "Whether to fuse filter transformations. If None, defaults to False.")
329
330  filter_parallelization = options_lib.create_option(
331      name="filter_parallelization",
332      ty=bool,
333      docstring=
334      "Whether to parallelize stateless filter transformations. If None, "
335      "defaults to False.")
336
337  inject_prefetch = options_lib.create_option(
338      name="inject_prefetch",
339      ty=bool,
340      docstring=
341      "Whether to inject prefetch transformation as the last transformation "
342      "when the last transformation is a synchronous transformation. If None, "
343      "defaults to False.")
344
345  map_and_batch_fusion = options_lib.create_option(
346      name="map_and_batch_fusion",
347      ty=bool,
348      docstring=
349      "Whether to fuse map and batch transformations. If None, defaults to "
350      "True.")
351
352  map_and_filter_fusion = options_lib.create_option(
353      name="map_and_filter_fusion",
354      ty=bool,
355      docstring=
356      "Whether to fuse map and filter transformations. If None, defaults to "
357      "False.")
358
359  map_fusion = options_lib.create_option(
360      name="map_fusion",
361      ty=bool,
362      docstring="Whether to fuse map transformations. If None, defaults to "
363      "False.")
364
365  map_parallelization = options_lib.create_option(
366      name="map_parallelization",
367      ty=bool,
368      docstring=
369      "Whether to parallelize stateless map transformations. If None, defaults "
370      "to True.")
371
372  noop_elimination = options_lib.create_option(
373      name="noop_elimination",
374      ty=bool,
375      docstring=
376      "Whether to eliminate no-op transformations. If None, defaults to True.")
377
378  parallel_batch = options_lib.create_option(
379      name="parallel_batch",
380      ty=bool,
381      docstring="Whether to parallelize copying of batch elements. If None, "
382      "defaults to True.")
383
384  shuffle_and_repeat_fusion = options_lib.create_option(
385      name="shuffle_and_repeat_fusion",
386      ty=bool,
387      docstring="Whether to fuse shuffle and repeat transformations. If None, "
388      "defaults to True.")
389
390  def _to_proto(self):
391    pb = dataset_options_pb2.OptimizationOptions()
392    if self.apply_default_optimizations is not None:
393      pb.apply_default_optimizations = self.apply_default_optimizations
394    if self.filter_fusion is not None:
395      pb.filter_fusion = self.filter_fusion
396    if self.filter_parallelization is not None:
397      pb.filter_parallelization = self.filter_parallelization
398    if self.inject_prefetch is not None:
399      pb.inject_prefetch = self.inject_prefetch
400    if self.map_and_batch_fusion is not None:
401      pb.map_and_batch_fusion = self.map_and_batch_fusion
402    if self.map_and_filter_fusion is not None:
403      pb.map_and_filter_fusion = self.map_and_filter_fusion
404    if self.map_fusion is not None:
405      pb.map_fusion = self.map_fusion
406    if self.map_parallelization is not None:
407      pb.map_parallelization = self.map_parallelization
408    if self.noop_elimination is not None:
409      pb.noop_elimination = self.noop_elimination
410    if self.parallel_batch is not None:
411      pb.parallel_batch = self.parallel_batch
412    if self.shuffle_and_repeat_fusion is not None:
413      pb.shuffle_and_repeat_fusion = self.shuffle_and_repeat_fusion
414    return pb
415
416  def _from_proto(self, pb):
417    if pb.WhichOneof("optional_apply_default_optimizations") is not None:
418      self.apply_default_optimizations = pb.apply_default_optimizations
419    if pb.WhichOneof("optional_filter_fusion") is not None:
420      self.filter_fusion = pb.filter_fusion
421    if pb.WhichOneof("optional_filter_parallelization") is not None:
422      self.filter_parallelization = pb.filter_parallelization
423    if pb.WhichOneof("optional_inject_prefetch") is not None:
424      self.inject_prefetch = pb.inject_prefetch
425    if pb.WhichOneof("optional_map_and_batch_fusion") is not None:
426      self.map_and_batch_fusion = pb.map_and_batch_fusion
427    if pb.WhichOneof("optional_map_and_filter_fusion") is not None:
428      self.map_and_filter_fusion = pb.map_and_filter_fusion
429    if pb.WhichOneof("optional_map_fusion") is not None:
430      self.map_fusion = pb.map_fusion
431    if pb.WhichOneof("optional_map_parallelization") is not None:
432      self.map_parallelization = pb.map_parallelization
433    if pb.WhichOneof("optional_noop_elimination") is not None:
434      self.noop_elimination = pb.noop_elimination
435    if pb.WhichOneof("optional_parallel_batch") is not None:
436      self.parallel_batch = pb.parallel_batch
437    if pb.WhichOneof("optional_shuffle_and_repeat_fusion") is not None:
438      self.shuffle_and_repeat_fusion = pb.shuffle_and_repeat_fusion
439
440  def _set_mutable(self, mutable):
441    """Change the mutability value to `mutable` on this options and children."""
442    # pylint: disable=protected-access
443    object.__setattr__(self, "_mutable", mutable)
444
445
446@deprecation.deprecated_endpoints("data.experimental.ThreadingOptions")
447@tf_export("data.experimental.ThreadingOptions", "data.ThreadingOptions")
448class ThreadingOptions(options_lib.OptionsBase):
449  """Represents options for dataset threading.
450
  You can set the threading options of a dataset through the `threading`
  property of `tf.data.Options`; the property is an instance of
  `tf.data.ThreadingOptions`.

  ```python
  options = tf.data.Options()
  options.threading.private_threadpool_size = 10
  dataset = dataset.with_options(options)
  ```
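
  The maximum intra-op parallelism can be limited the same way (the value
  below is illustrative):

  ```python
  options = tf.data.Options()
  options.threading.max_intra_op_parallelism = 1
  dataset = dataset.with_options(options)
  ```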
460  """
461
462  max_intra_op_parallelism = options_lib.create_option(
463      name="max_intra_op_parallelism",
464      ty=int,
465      docstring=
466      "If set, it overrides the maximum degree of intra-op parallelism.")
467
468  private_threadpool_size = options_lib.create_option(
469      name="private_threadpool_size",
470      ty=int,
471      docstring=
472      "If set, the dataset will use a private threadpool of the given size. "
473      "The value 0 can be used to indicate that the threadpool size should be "
474      "determined at runtime based on the number of available CPU cores.")
475
476  def _to_proto(self):
477    pb = dataset_options_pb2.ThreadingOptions()
478    if self.max_intra_op_parallelism is not None:
479      pb.max_intra_op_parallelism = self.max_intra_op_parallelism
480    if self.private_threadpool_size is not None:
481      pb.private_threadpool_size = self.private_threadpool_size
482    return pb
483
484  def _from_proto(self, pb):
485    if pb.WhichOneof("optional_max_intra_op_parallelism") is not None:
486      self.max_intra_op_parallelism = pb.max_intra_op_parallelism
487    if pb.WhichOneof("optional_private_threadpool_size") is not None:
488      self.private_threadpool_size = pb.private_threadpool_size
489
490
491@tf_export("data.Options")
492class Options(options_lib.OptionsBase):
493  """Represents options for `tf.data.Dataset`.
494
495  A `tf.data.Options` object can be, for instance, used to control which static
496  optimizations to apply to the input pipeline graph or whether to use
497  performance modeling to dynamically tune the parallelism of operations such as
498  `tf.data.Dataset.map` or `tf.data.Dataset.interleave`.
499
500  The options are set for the entire dataset and are carried over to datasets
501  created through tf.data transformations.
502
503  The options can be set by constructing an `Options` object and using the
504  `tf.data.Dataset.with_options(options)` transformation, which returns a
505  dataset with the options set.
506
507  >>> dataset = tf.data.Dataset.range(42)
508  >>> options = tf.data.Options()
509  >>> options.deterministic = False
510  >>> dataset = dataset.with_options(options)
511  >>> print(dataset.options().deterministic)
512  False
513
514  Note: A known limitation of the `tf.data.Options` implementation is that the
515  options are not preserved across tf.function boundaries. In particular, to
516  set options for a dataset that is iterated within a tf.function, the options
517  need to be set within the same tf.function.
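
  For example, setting options on a dataset created inside a `tf.function`
  (a minimal, illustrative sketch):

  ```python
  @tf.function
  def make_dataset():
    dataset = tf.data.Dataset.range(42)
    options = tf.data.Options()
    options.deterministic = False
    return dataset.with_options(options)
  ```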
518  """
519
520  autotune = options_lib.create_option(
521      name="autotune",
522      ty=AutotuneOptions,
523      docstring="The autotuning options associated with the dataset. See "
524      "`tf.data.experimental.AutotuneOptions` for more details.",
525      default_factory=AutotuneOptions)
526
527  deterministic = options_lib.create_option(
528      name="deterministic",
529      ty=bool,
530      docstring=
531      "Whether the outputs need to be produced in deterministic order. If None,"
532      " defaults to True.")
533
534  experimental_deterministic = options_lib.create_option(
535      name="experimental_deterministic",
536      ty=bool,
537      docstring="DEPRECATED. Use `deterministic` instead.")
538
539  experimental_distribute = options_lib.create_option(
540      name="experimental_distribute",
541      ty=DistributeOptions,
542      docstring=
543      "The distribution strategy options associated with the dataset. See "
544      "`tf.data.experimental.DistributeOptions` for more details.",
545      default_factory=DistributeOptions)
546
547  experimental_external_state_policy = options_lib.create_option(
548      name="experimental_external_state_policy",
549      ty=ExternalStatePolicy,
550      docstring="This option can be used to override the default policy for "
551      "how to handle external state when serializing a dataset or "
552      "checkpointing its iterator. There are three settings available - "
553      "IGNORE: External state is ignored without a warning; WARN: External "
554      "state is ignored and a warning is logged; FAIL: External state results "
555      "in an error.")
556
557  experimental_optimization = options_lib.create_option(
558      name="experimental_optimization",
559      ty=OptimizationOptions,
560      docstring=
561      "The optimization options associated with the dataset. See "
562      "`tf.data.experimental.OptimizationOptions` for more details.",
563      default_factory=OptimizationOptions)
564
565  experimental_slack = options_lib.create_option(
566      name="experimental_slack",
567      ty=bool,
568      docstring="Whether to introduce 'slack' in the last `prefetch` of the "
569      "input pipeline, if it exists. This may reduce CPU contention with "
570      "accelerator host-side activity at the start of a step. The slack "
571      "frequency is determined by the number of devices attached to this "
572      "input pipeline. If None, defaults to False.")
573
574  experimental_threading = options_lib.create_option(
575      name="experimental_threading",
576      ty=ThreadingOptions,
577      docstring="DEPRECATED. Use `threading` instead.")
578
579  threading = options_lib.create_option(
580      name="threading",
581      ty=ThreadingOptions,
582      docstring="The threading options associated with the dataset. See "
583      "`tf.data.ThreadingOptions` for more details.",
584      default_factory=ThreadingOptions)
585
586  def __getattribute__(self, name):
587    if name == "experimental_threading":
588      logging.warning("options.experimental_threading is deprecated. "
589                      "Use options.threading instead.")
590      return getattr(self, "threading")
591    if name == "experimental_deterministic":
592      # TODO(aaudibert): Uncomment after internal uses have been updated.
593      # logging.warning("options.experimental_deterministic is deprecated. "
594      #                 "Use options.deterministic instead.")
595      return getattr(self, "deterministic")
596    return super(Options, self).__getattribute__(name)
597
598  def __setattr__(self, name, value):
599    if name == "experimental_threading":
600      logging.warning("options.experimental_threading is deprecated. "
601                      "Use options.threading instead.")
602      super(Options, self).__setattr__("threading", value)
603      return
604    if name == "experimental_deterministic":
605      # TODO(aaudibert): Uncomment after internal uses have been updated.
606      # logging.warning("options.experimental_deterministic is deprecated. "
607      #                 "Use options.deterministic instead.")
608      super(Options, self).__setattr__("deterministic", value)
609      return
610    super(Options, self).__setattr__(name, value)
611
612  def _to_proto(self):
613    pb = dataset_options_pb2.Options()
614    if self.deterministic is not None:
615      pb.deterministic = self.deterministic
616    pb.autotune_options.CopyFrom(self.autotune._to_proto())  # pylint: disable=protected-access
617    pb.distribute_options.CopyFrom(self.experimental_distribute._to_proto())  # pylint: disable=protected-access
618    if self.experimental_external_state_policy is not None:
619      pb.external_state_policy = (
620          ExternalStatePolicy._to_proto(  # pylint: disable=protected-access
621              self.experimental_external_state_policy))
622    pb.optimization_options.CopyFrom(self.experimental_optimization._to_proto())  # pylint: disable=protected-access
623    if self.experimental_slack is not None:
624      pb.slack = self.experimental_slack
625    pb.threading_options.CopyFrom(self.threading._to_proto())  # pylint: disable=protected-access
626    return pb
627
628  def _from_proto(self, pb):
629    if pb.WhichOneof("optional_deterministic") is not None:
630      self.deterministic = pb.deterministic
631    self.autotune._from_proto(pb.autotune_options)  # pylint: disable=protected-access
632    self.experimental_distribute._from_proto(pb.distribute_options)  # pylint: disable=protected-access
633    if pb.WhichOneof("optional_external_state_policy") is not None:
634      self.experimental_external_state_policy = (
635          ExternalStatePolicy._from_proto(  # pylint: disable=protected-access
636              pb.external_state_policy))
637    self.experimental_optimization._from_proto(pb.optimization_options)  # pylint: disable=protected-access
638    if pb.WhichOneof("optional_slack") is not None:
639      self.experimental_slack = pb.slack
640    self.threading._from_proto(pb.threading_options)  # pylint: disable=protected-access
641
642  def _set_mutable(self, mutable):
643    """Change the mutability value to `mutable` on this options and children."""
644    # pylint: disable=protected-access
645    object.__setattr__(self, "_mutable", mutable)
646    self.autotune._set_mutable(mutable)
647    self.experimental_distribute._set_mutable(mutable)
648    self.experimental_optimization._set_mutable(mutable)
649    self.threading._set_mutable(mutable)
650
651  def merge(self, options):
652    """Merges itself with the given `tf.data.Options`.
653
654    If this object and the `options` to merge set an option differently, a
655    warning is generated and this object's value is updated with the `options`
656    object's value.
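
    For example (an illustrative sketch; on conflicting settings, the values
    from the `options` argument take precedence):

    ```python
    options1 = tf.data.Options()
    options1.deterministic = True
    options2 = tf.data.Options()
    options2.deterministic = False
    merged = options1.merge(options2)  # merged.deterministic is False
    ```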

    Args:
      options: The `tf.data.Options` to merge with.

    Returns:
      New `tf.data.Options` object which is the result of merging self with
      the input `tf.data.Options`.
    """
    return options_lib.merge_options(self, options)
