xref: /aosp_15_r20/external/tensorflow/tensorflow/dtensor/cc/constants.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_DTENSOR_CC_CONSTANTS_H_
17 #define TENSORFLOW_DTENSOR_CC_CONSTANTS_H_
18 
namespace tensorflow {
namespace dtensor {
// Constants used within dtensor scope.
//
// All constants below are declared `inline constexpr` (C++17) rather than
// `static constexpr`: a `static` namespace-scope variable in a header has
// internal linkage, so every translation unit gets its own distinct copy, and
// odr-using such a constant from an inline function defined in a header would
// violate the One Definition Rule. `inline` yields a single program-wide
// entity with the same name, type and value.

// Qualified attribute without `_` prefix.
// Used in Ops attribute registration.
inline constexpr char kQualifiedLayoutAttr[] = "layout";

// Internal attribute to DTensor MLIR passes and Graph nodes.
// Prefixed with `_` so that it doesn't require op attribute registration.
inline constexpr char kLayoutAttr[] = "_layout";

// Indicates a non-binding layout hint provided by the user.
// `tf` prefix attached in MLIR importer for dialect requirements.
inline constexpr char kCustomDefaultLayoutAttr[] = "tf._default_layout";

// Indicates a non-binding layout hint provided by the user. Same hint as
// kCustomDefaultLayoutAttr, but spelled without the `tf.` dialect prefix.
inline constexpr char kDefaultLayoutAttr[] = "_default_layout";

// Attribute carries layout information from Custom Device Arguments.
// `tf` prefix attached in MLIR importer for dialect requirements.
inline constexpr char kCustomDeviceAttr[] = "tf._layout";

// Attribute attached on _Arg node for the mesh config.
inline constexpr char kMeshAttr[] = "_mesh";

// Attribute carries mesh information from Custom Device Arguments.
// `tf` prefix attached in MLIR importer for dialect requirements.
inline constexpr char kCustomDeviceMeshAttr[] = "tf._mesh";

// Attribute carries argument indices for newly inferred layout of resource
// handle.
inline constexpr char kNewResourceLayoutIndices[] =
    "_inferred_resource_indices";

// Attribute carries layout for newly inferred layout of resource handle.
inline constexpr char kNewResourceArgLayouts[] = "_inferred_resource_layouts";

// Attribute carries input layout information for shape op.
inline constexpr char kShapeOpInputLayout[] = "_shape_input_layout";

// Attribute carries input layout index for shape op. This forms a 1 -> 1
// mapping for kShapeOpInputLayout above.
inline constexpr char kShapeOpInputLayoutIndices[] = "_shape_input_indices";

// Attribute that carries global shape of operation. Used to preserve global
// shape to be used during SPMD expansion.
inline constexpr char kGlobalShape[] = "_global_shape";

// Global shape attribute with `tf.` dialect to be used for annotating func op
// arguments/return values.
inline constexpr char kGlobalShapeDialectAttr[] = "tf._global_shape";

// Attribute attached to resource-type function arguments containing the local
// shape of the tensor that is being assigned to it.
inline constexpr char kAssignedResourceLocalShape[] =
    "tf._assigned_resource_local_shape";

// Tensor handles smaller than this are considered small tensors, and we
// perform some optimizations around them. For example, they are transformed
// into constant values during graph building instead of being passed as
// inputs. In addition, we allow automatic broadcasting of small non-DTensor
// tensors to the DTensor device, which is very useful for shape/axis info
// tensors in eager mode (it removes the need for users to do an explicit
// copy-to-mesh).
// NOTE(review): the unit of this threshold (presumably number of elements) is
// not stated here — confirm at the use sites.
inline constexpr int kSmallTensorThreshold = 20;

// Contains a serialized mesh. Will be attached to a FloorMod op to denote which
// mesh the output of the FloorMod op is giving coordinates for.
inline constexpr char kMeshCoordinatesAttr[] = "_mesh_coordinates";

// Attribute used to determine if a module pass should log long form information
// such as IR dumps etc.
inline constexpr char kDoNotLog[] = "dtensor.do_not_log";

// The number of TPU cores in a donut.
inline constexpr int kTpuDonutSize = 8;

// An attribute used to cache the computation of device seeds, so that we don't
// constantly recompute device seeds in a cluster for a given layout.
inline constexpr char kDeviceSeedForMeshDims[] =
    "dtensor.device_seed_for_mesh_dims";

// Attribute that determines whether to skip XLA compilation. There are some ops
// that run on a TPU mesh but are not expected to be compiled by XLA, e.g.
// VarHandleOp, DestroyResourceOp, etc. For such a case, set this attribute
// to true on the StatefulPartitionedCallOp generated by MLIR lowering.
inline constexpr char kSkipXlaCompilation[] = "_skip_xla_compilation";

// Prefix of pipelining mesh name (kPipelineMeshNamePrefix + composite device
// name).
inline constexpr char kPipelineMeshNamePrefix[] = "pipe_cluster:";

// An attribute which stores the cache_key for the graph in the module. Used
// to uniquely name functions.
inline constexpr char kCacheKey[] = "dtensor.cache_key";

// An attribute that determines whether a tensor is a sparse tensor. If this
// attribute exists in a tensor, then this tensor is a sparse tensor.
inline constexpr char kSparseValue[] = "tf._sparse";

// TPUEmbedding configuration attribute with `tf.` dialect to be used for
// annotating func op that contains tpu embedding configuration ops.
inline constexpr char kTPUEmbeddingConfiguration[] =
    "tf._tpu_embedding_configuration";

// Attribute mapping table_id to func op arguments used as TPUEmbedding tables.
// `tf` prefix attached in MLIR importer for dialect requirements.
inline constexpr char kTPUEmbeddingTableID[] = "tf._tpu_embedding_table_id";

// Attribute mapping slot_id to func op arguments used as TPUEmbedding slot
// variables. `tf` prefix attached in MLIR importer for dialect requirements.
inline constexpr char kTPUEmbeddingSlotID[] = "tf._tpu_embedding_slot_id";

// Name of dtensor load embedding function.
inline constexpr char kLoadEmbeddingFn[] = "load_embedding_fn";

// Name of dtensor retrieve embedding function.
inline constexpr char kRetrieveEmbeddingFn[] = "retrieve_embedding_fn";
}  // namespace dtensor
}  // namespace tensorflow
139 
140 #endif  // TENSORFLOW_DTENSOR_CC_CONSTANTS_H_
141