syntax = "proto3";

package tensorflow.tpu;

import "tensorflow/compiler/xla/xla.proto";
import "tensorflow/compiler/xla/xla_data.proto";
import "tensorflow/core/framework/tensor_shape.proto";
import "tensorflow/core/framework/types.proto";
import "tensorflow/core/protobuf/tpu/dynamic_padding.proto";

option cc_enable_arenas = true;

// This is an experimental proto used in the TF/XLA bridge to attach metadata
// to a compile op (e.g. _TPUCompileMlir).
// TODO(lyandy): Deprecate proto once generic metadata proto is created.
message TPUCompileMetadataProto {
  // Description of the types and shapes of the arguments to a computation.
  message Arg {
    enum Kind {
      INVALID = 0;
      PARAMETER = 1;
      VARIABLE = 2;
      // These are args which have been guaranteed to be constants during the
      // session lifetime by the use of the GuaranteeConstOp (or ConstantOp).
      GUARANTEED_CONSTANT = 3;
    }
    DataType dtype = 1;
    TensorShapeProto shape = 2;
    Kind kind = 3;

    // The cross-core sharding of this input within each replica, e.g.,
    // assigned to one core or replicated across all cores.
    xla.OpSharding sharding = 4;

    // Whether this argument will receive the same data across all replicas.
    bool is_same_data_across_replicas = 5;

    enum EnableXlaSharding {
      DISALLOWED = 0;
      // Sharding is allowed if a host training loop exists.
      TENTATIVE = 1;
      ALLOWED = 2;
    }
    // Whether to allow XLA to produce separate programs to shard/unshard this
    // argument. Requires this arg to be an on-device Kind::VARIABLE or a
    // Kind::PARAMETER. For Kind::PARAMETER, it represents the initial value of
    // a variable, and retval_index_for_sharding must be specified for the
    // corresponding updated value.
    EnableXlaSharding enable_xla_sharding = 6;

    // If XLA sharding is allowed on a Kind::PARAMETER, this field is used to
    // specify the corresponding updated value in the return values. Use -1 for
    // variables that are not updated.
    int32 retval_index_for_sharding = 8;

    // Whether this argument is placed in fast memory or not.
    bool fast_mem = 7;

    // Whether to let XLA decide the layout during compilation, as opposed to
    // using a fixed layout determined by the shape.
    bool unrestricted_layout = 9;

    // Name of the node that the arg comes from.
    string name = 10;

    // Whether to use XLA collectives to broadcast this parameter to all
    // replicas, instead of using TensorFlow Send/Recv among the tasks.
    bool requires_xla_broadcast = 11;
  }
  repeated Arg args = 1;
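
  // Illustrative sketch: one possible text-format value for an `args` entry, a
  // float32 2x2 variable pinned to core 0 with XLA sharding allowed. The
  // OpSharding fields used here (type, tile_assignment_devices) are assumed
  // from the imported xla_data.proto; names and values are hypothetical.
  //
  //   args {
  //     dtype: DT_FLOAT
  //     shape { dim { size: 2 } dim { size: 2 } }
  //     kind: VARIABLE
  //     sharding { type: MAXIMAL tile_assignment_devices: 0 }
  //     enable_xla_sharding: ALLOWED
  //     name: "my_variable"
  //   }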

  // Description of the return values from a computation.
  message Retval {
    // The cross-core sharding of this return value within each replica, e.g.,
    // assigned to one core or replicated across all cores.
    xla.OpSharding sharding = 1;
  }
  repeated Retval retvals = 2;
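
  // Illustrative sketch: a `retvals` entry replicated across all cores, again
  // assuming the OpSharding message from the imported xla_data.proto.
  //
  //   retvals { sharding { type: REPLICATED } }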

  // Number of replicas of the computation and number of cores in each replica.
  // TODO(b/140721404): it may not be necessary to state the number of cores per
  // replica here. Reconsider when replicated model-parallelism is implemented
  // in XLA.
  int32 num_replicas = 3;
  int32 num_cores_per_replica = 4;

  reserved 5;  // was device_names
  reserved 7;  // was replica_device_assignment

  xla.DeviceAssignmentProto device_assignment = 8;
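
  // Illustrative sketch: with the hypothetical values below, the program
  // occupies num_replicas * num_cores_per_replica = 2 * 2 = 4 logical cores,
  // and device_assignment maps each (replica, core) pair to a physical TPU
  // device using the DeviceAssignmentProto from the imported XLA protos.
  //
  //   num_replicas: 2
  //   num_cores_per_replica: 2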

  // A fingerprint of the function library. Ensures that any functions called
  // by the computation have matching definitions.
  uint64 function_library_fingerprint = 6;

  // Unique session identifier. Can be empty.
  string session_handle = 9;

  // Fingerprint of the guaranteed_const value. Computing the fingerprint
  // inside the TPU compile op may be slow; that computation can be avoided by
  // setting the fingerprint value here.
  string guaranteed_const_fingerprint = 10;

  repeated tpu.PaddingMap padding_maps = 11;

  // The location of step markers that XLA compile will instrument.
  xla.DebugOptions.StepMarkerLocation step_marker_location = 12;

  // Minimum number of batches run through the XLA graph before the XLA fusion
  // autotuner is enabled. The default value of zero disables the autotuner.
  // The XLA fusion autotuner can improve performance by executing a heuristic
  // search on the compiler parameters.
  int64 xla_fusion_autotuner_thresh = 13;

  // Enables the TPU compiler to add partitioning policies for inputs/outputs to
  // the XLA computation for model parallelism.
  bool enable_automatic_model_parallelism = 14;

  // Whether to use XLA's SPMD or MPMD partitioner when compiler partitioning is
  // requested.
  bool use_spmd_for_xla_partitioning = 15;

  // Whether to automatically generate XLA shardings for the SPMD partitioner.
  bool use_auto_spmd_for_xla_partitioning = 18;

  // Device mesh shape used to create the sharding search space when
  // use_auto_spmd_for_xla_partitioning=true.
  repeated int64 auto_spmd_mesh_shape = 19;

  // Device mesh ids compatible with the above mesh_shape, used when
  // use_auto_spmd_for_xla_partitioning=true.
  repeated int64 auto_spmd_mesh_ids = 20;
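
  // Illustrative sketch: an auto-SPMD sharding search over a hypothetical 2x4
  // device mesh; auto_spmd_mesh_shape gives the mesh dimensions and
  // auto_spmd_mesh_ids lists the eight device ids that fill the mesh.
  //
  //   use_auto_spmd_for_xla_partitioning: true
  //   auto_spmd_mesh_shape: [2, 4]
  //   auto_spmd_mesh_ids: [0, 1, 2, 3, 4, 5, 6, 7]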

  reserved 16;  // Was broadcast_replicated_parameters_via_collectives

  // A fingerprint generated by hashing the MLIR module content.
  uint64 mlir_fingerprint = 17;

  TPUCompileOptions compile_options = 21;
}

// Stable protobuf for TPU compilation options, suitable for persistent storage.
// This proto needs to remain backward compatible as it is maintained.
// TODO(timshen): investigate and migrate other options from
// TPUCompileMetadataProto.
message TPUCompileOptions {
  enum Precision {
    DEFAULT = 0;
    BFLOAT16 = 1;
    FLOAT32 = 2;
    TENSOR_FLOAT32 = 3;
  }
  Precision matrix_unit_operand_precision = 1;
}
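
// Illustrative sketch: selecting bfloat16 operand precision for the matrix
// units through the compile_options field of TPUCompileMetadataProto; the
// value is hypothetical.
//
//   compile_options { matrix_unit_operand_precision: BFLOAT16 }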