1syntax = "proto3"; 2 3package tensorflow.tpu; 4 5import "tensorflow/compiler/xla/xla.proto"; 6import "tensorflow/compiler/xla/xla_data.proto"; 7import "tensorflow/core/framework/tensor_shape.proto"; 8import "tensorflow/core/framework/types.proto"; 9import "tensorflow/core/protobuf/tpu/dynamic_padding.proto"; 10 11option cc_enable_arenas = true; 12 13// This is an experimental proto used in the TF/XLA bridge to store metadata to 14// a compile op (e.g. _TPUCompileMlir). 15// TODO(lyandy): Deprecate proto once generic metadata proto is created. 16message TPUCompileMetadataProto { 17 // Description of the types and shapes of the arguments to a computation. 18 message Arg { 19 enum Kind { 20 INVALID = 0; 21 PARAMETER = 1; 22 VARIABLE = 2; 23 // These are args which have been guaranteed to be constants during the 24 // session lifetime by the use of the GuaranteeConstOp (or ConstantOp). 25 GUARANTEED_CONSTANT = 3; 26 } 27 DataType dtype = 1; 28 TensorShapeProto shape = 2; 29 Kind kind = 3; 30 31 // The cross-core sharding of this input within each replica, e.g., 32 // assigning to one core, or replicate across all cores. 33 xla.OpSharding sharding = 4; 34 35 // Whether this argument will receive the same data across all replicas. 36 bool is_same_data_across_replicas = 5; 37 38 enum EnableXlaSharding { 39 DISALLOWED = 0; 40 // Sharding is allowed if host training loop exists. 41 TENTATIVE = 1; 42 ALLOWED = 2; 43 } 44 // Whether to allow XLA to produce separate programs to shard/unshard this 45 // argument. Requires this arg to be an on-device Kind::VARIABLE, or a 46 // Kind::PARAMETER. For Kind::PARAMETER, it represents the initial value of 47 // a variable, and retval_index_for_sharding must be specified for the 48 // corresponding updated value. 49 EnableXlaSharding enable_xla_sharding = 6; 50 51 // If XLA sharding is allowed on a Kind::PARAMETER, this field is used to 52 // specify the corresponding updated value in the return values. Use -1 for 53 // variables that are not updated. 54 int32 retval_index_for_sharding = 8; 55 56 // Whether this argument is placed on fast memory or not. 57 bool fast_mem = 7; 58 59 // Whether to let XLA to decide the layout during compilation, as opposed to 60 // using a fixed layout determined by the shape. 61 bool unrestricted_layout = 9; 62 63 // Name of the node that the arg comes from. 64 string name = 10; 65 66 // Whether to use XLA collectives to broadcast this parameter to all 67 // replicas, instead of using TensorFlow Send/Recv among the tasks. 68 bool requires_xla_broadcast = 11; 69 } 70 repeated Arg args = 1; 71 72 // Description of the return values from a computation. 73 message Retval { 74 // The cross-core sharding of this return value within each replica, e.g., 75 // assigning to one core, or replicate across all cores. 76 xla.OpSharding sharding = 1; 77 } 78 repeated Retval retvals = 2; 79 80 // Number of replicas of the computation and number of cores in each replica. 81 // TODO(b/140721404): it may not be necessary to state the number of cores per 82 // replica here. Reconsider when replicated model-parallelism is implemented 83 // in XLA. 84 int32 num_replicas = 3; 85 int32 num_cores_per_replica = 4; 86 87 reserved 5; // was device_names 88 reserved 7; // was replica_device_assignment 89 90 xla.DeviceAssignmentProto device_assignment = 8; 91 92 // A fingerprint of the function library. Ensures that any functions called 93 // by the computation have matching definitions. 
94 uint64 function_library_fingerprint = 6; 95 96 // Unique session identifier. Can be empty. 97 string session_handle = 9; 98 99 // Fingerprint of guaranteed_const value. The fingerprint computation inside 100 // tpu_compile_op may be slow. The computation can be avoided by setting the 101 // fingerprint value here. 102 string guaranteed_const_fingerprint = 10; 103 104 repeated tpu.PaddingMap padding_maps = 11; 105 106 // The location of step markers that XLA compile will instrument. 107 xla.DebugOptions.StepMarkerLocation step_marker_location = 12; 108 109 // Minimum number of batches run through the XLA graph before XLA fusion 110 // autotuner is enabled. Default value of zero disables the autotuner. 111 // The XLA fusion autotuner can improve performance by executing a heuristic 112 // search on the compiler parameters. 113 int64 xla_fusion_autotuner_thresh = 13; 114 115 // Enables TPU compiler to add partitioning policies for inputs/outputs to 116 // the XLA computation for model parallelism. 117 bool enable_automatic_model_parallelism = 14; 118 119 // Whether to use XLA's SPMD or MPMD partitioner when compiler partitioning is 120 // requested. 121 bool use_spmd_for_xla_partitioning = 15; 122 123 // Whether to automatically generate XLA shardings for SPMD partitioner. 124 bool use_auto_spmd_for_xla_partitioning = 18; 125 126 // Device mesh shape used to create the sharding search space when 127 // use_auto_spmd_partitioning=true. 128 repeated int64 auto_spmd_mesh_shape = 19; 129 130 // Device mesh ids compatible with the above mesh_shape used when 131 // use_auto_spmd_partitioning=true. 132 repeated int64 auto_spmd_mesh_ids = 20; 133 134 reserved 16; // Was broadcast_replicated_parameters_via_collectives 135 136 // A fingerprint generated by hashing the MLIR module content. 137 uint64 mlir_fingerprint = 17; 138 139 TPUCompileOptions compile_options = 21; 140} 141 142// Stable protobuf for TPU compilation options, suitable for persistent storage. 143// This proto needs to be backward compatible under maintenance. 144// TODO(timshen): investigate and migrate other options from 145// TPUCompileMetadataProto. 146message TPUCompileOptions { 147 enum Precision { 148 DEFAULT = 0; 149 BFLOAT16 = 1; 150 FLOAT32 = 2; 151 TENSOR_FLOAT32 = 3; 152 } 153 Precision matrix_unit_operand_precision = 1; 154} 155
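// Illustrative sketch (not part of the upstream file): a text-format
// TPUCompileMetadataProto for a hypothetical single-replica, single-core
// computation with one float32 parameter. The concrete values below are
// assumptions chosen only to show how the messages above nest together.
//
//   args {
//     dtype: DT_FLOAT
//     shape { dim { size: 128 } dim { size: 1024 } }
//     kind: PARAMETER
//     sharding { type: REPLICATED }
//   }
//   retvals {
//     sharding { type: REPLICATED }
//   }
//   num_replicas: 1
//   num_cores_per_replica: 1
//   use_spmd_for_xla_partitioning: true
//   compile_options { matrix_unit_operand_precision: DEFAULT }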