xref: /aosp_15_r20/external/tensorflow/tensorflow/core/tpu/kernels/tpu_compile_op_support.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_
16 #define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_
17 
18 #include <string>
19 #include <vector>
20 
21 #include "absl/strings/string_view.h"
22 #include "absl/types/optional.h"
23 #include "absl/types/span.h"
24 #include "absl/types/variant.h"
25 #include "tensorflow/cc/framework/ops.h"
26 #include "tensorflow/compiler/tf2xla/xla_compiler.h"
27 #include "tensorflow/compiler/xla/service/hlo_module_config.h"
28 #include "tensorflow/compiler/xla/service/hlo_module_group.h"
29 #include "tensorflow/compiler/xla/service/hlo_sharding.h"
30 #include "tensorflow/compiler/xla/shape.h"
31 #include "tensorflow/compiler/xla/shape_tree.h"
32 #include "tensorflow/compiler/xla/statusor.h"
33 #include "tensorflow/compiler/xla/xla_data.pb.h"
34 #include "tensorflow/core/framework/attr_value.pb.h"
35 #include "tensorflow/core/framework/function.h"
36 #include "tensorflow/core/framework/op_kernel.h"
37 #include "tensorflow/core/framework/tensor.pb.h"
38 #include "tensorflow/core/framework/tensor_shape.h"
39 #include "tensorflow/core/framework/types.pb.h"
40 #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
41 #include "tensorflow/core/tpu/kernels/tpu_compile.pb.h"
42 
43 namespace tensorflow {
44 namespace tpu {
45 
// Shorthand for stream_executor. NOTE(review): this alias lives in a header,
// so it is injected into tensorflow::tpu for every includer — kept as-is for
// compatibility with existing declarations below.
namespace se = ::stream_executor;
47 
// List of parameters for lowering Mlir to HLO IR.
struct MlirToHloArgs {
  // Serialized MLIR module to lower. Non-owning view: the caller must keep
  // the backing string alive for as long as this struct is used.
  absl::string_view mlir_module;
  // MLIR bridge rollout state; defaults to MLIR_BRIDGE_ROLLOUT_ENABLED.
  ConfigProto::Experimental::MlirBridgeRollout rollout_state =
      ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED;
};
54 
// Variant over the two ways guaranteed-constant tensors are supplied: either
// a span of serialized TensorProtos or the op's input list. Both alternatives
// are non-owning; the pointed-to data must outlive any use of this variant.
using GuaranteedConsts = absl::variant<absl::Span<const TensorProto* const>,
                                       const OpInputList* const>;
58 
// List of parameters for lowering function library definition to HLO IR.
struct FunctionToHloArgs {
  // Function to compile (name plus attributes). Non-owning; must outlive
  // this struct.
  const NameAttrList* const function;
  // Library used to resolve `function` and its callees. Non-owning.
  const FunctionLibraryDefinition* const flib_def;
  // GraphDef version associated with the function's graph.
  int graph_def_version;
  // Compile-time constant inputs (see GuaranteedConsts for ownership notes).
  GuaranteedConsts guaranteed_constants;
};
66 
// Persistent cache for compiled TPU program and the related compiler metadata
// intended for TPU inference.
// TODO(henrytan): there is an opportunity to consolidate the interface with the
// `TpuCompilationCacheInterface` once `TpuPersistentCompilationCache` is
// converted into a ref count based class.
class TpuPersistentCompilationCacheInterface {
 public:
  // Virtual destructor so implementations can be deleted via this interface.
  virtual ~TpuPersistentCompilationCacheInterface() = default;

  // Returns the location where cache entries are stored.
  virtual std::string cache_location() const = 0;
};
79 
// Describes the position of an argument or return value after the computation
// has been partitioned into cores.
struct ShardingAndIndex {
  // Sharding across cores.
  ::xla::OpSharding sharding;
  // Argument/return value number. If sharding is single-core, `indices` has a
  // single element; otherwise, it has num_cores elements.
  std::vector<int> indices;
};
89 
// TODO(b/158279168): Dedup with internal version.
// Return the per-device shape for a `shape` with a given `sharding`, for the
// core identified by the `device` ordinal.
xla::Shape GetPerDeviceShape(const xla::Shape& shape,
                             const xla::HloSharding& sharding, int64_t device);
94 
// Builds an xla::HloModuleConfig for a computation with the given
// `program_shape` and `argument_shapes`. Pointer and optional parameters act
// as "unset when null/absent".
// NOTE(review): the exact defaulting behavior for unset parameters is defined
// in the .cc implementation — confirm there before relying on it.
stream_executor::port::StatusOr<std::unique_ptr<xla::HloModuleConfig>>
CreateModuleConfig(
    const xla::ProgramShape& program_shape,
    absl::Span<const xla::Shape> argument_shapes,
    absl::optional<const xla::Shape> result_layout,
    absl::optional<const xla::DeviceAssignment> device_assignment,
    int replica_count, int num_partitions,
    const xla::DebugOptions* debug_options, const int* seed,
    const int* launch_id, const bool* alias_passthrough_params,
    const xla::FusionConfigCollection* fusion_config_collection,
    const std::vector<std::vector<bool>>* fusion_config);
106 
// Overload of CreateModuleConfig without the seed/launch-id/aliasing/fusion
// parameters.
stream_executor::port::StatusOr<std::unique_ptr<xla::HloModuleConfig>>
CreateModuleConfig(
    const xla::ProgramShape& program_shape,
    absl::Span<const xla::Shape> argument_shapes,
    absl::optional<const xla::Shape> result_layout,
    absl::optional<const xla::DeviceAssignment> device_assignment,
    int replica_count,
    int num_partitions, const xla::DebugOptions* debug_options);
115 
// Returns the sharding subtree of `tuple_shape_tree` rooted at top-level
// tuple element `element_index`.
xla::ShapeTree<xla::HloSharding> GetSubtree(
    const xla::ShapeTree<xla::HloSharding>& tuple_shape_tree,
    int element_index);
119 
// NOTE: a second, byte-identical declaration of GetPerDeviceShape was removed
// here; the function is already declared earlier in this header.
122 
// Derives per-core variable-update bookkeeping from `metadata`, the compiler
// output and the argument-to-core mapping, writing into the three out-params:
// `may_modify_variables`, `per_core_output_shapes` and
// `per_core_variable_indices`.
// NOTE(review): semantics inferred from parameter names — confirm against the
// .cc implementation.
Status AddVariableUpdatesToCores(
    const TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    const std::vector<ShardingAndIndex>& arg_core_mapping,
    std::vector<bool>* may_modify_variables,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes,
    std::vector<std::vector<std::pair<int, bool>>>* per_core_variable_indices);
130 
// Computes, for each TPU core, the shapes of the computation's outputs and
// stores them in `*per_core_output_shapes`.
se::port::Status ComputeOutputShapesForEachCore(
    const tpu::TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    std::vector<std::vector<xla::Shape>>* per_core_output_shapes);
135 
// Builds xla::HloModules from the compiled computation in
// `compilation_result` (optionally using `device_assignment`), appending the
// modules to `*hlo_modules`.
se::port::Status CreateHloModules(
    const TPUCompileMetadataProto& metadata,
    const XlaCompiler::CompilationResult& compilation_result,
    const absl::optional<xla::DeviceAssignment>& device_assignment,
    std::vector<std::unique_ptr<xla::HloModule>>* hlo_modules);
141 
// Packs the inputs of a TPU compilation — the computation (either MLIR or a
// TF function), the compile metadata and the argument shapes — into a
// TpuCompilationRequestProto.
se::port::StatusOr<TpuCompilationRequestProto> CreateTpuCompilationRequest(
    const absl::variant<MlirToHloArgs, FunctionToHloArgs>& computation,
    const TPUCompileMetadataProto& metadata,
    const std::vector<TensorShape>& arg_shapes);
146 
// Reads the compile op's attributes from `ctx` at kernel-construction time,
// filling in `metadata`, `function_name` and `mlir_module`.
se::port::Status CompileOpMetadataFromContext(OpKernelConstruction* ctx,
                                              TPUCompileMetadataProto* metadata,
                                              NameAttrList* function_name,
                                              std::string* mlir_module);
151 
// Computes shapes for each argument. Uses both the static shape from the
// metadata, and the dynamic shapes where the static shape is not
// defined. There must be one dynamic_shape for each argument with a
// partially defined shape, in index order. On success, `*arg_shapes`
// contains one shape per argument.
Status ComputeArgumentShapes(const TPUCompileMetadataProto& metadata,
                             const std::vector<TensorShape>& dynamic_shapes,
                             std::vector<TensorShape>* arg_shapes);
159 }  // namespace tpu
160 }  // namespace tensorflow
161 
162 #endif  // TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_OP_SUPPORT_H_
163