xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/pjrt/pjrt_executable.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_XLA_PJRT_PJRT_EXECUTABLE_H_
17 #define TENSORFLOW_COMPILER_XLA_PJRT_PJRT_EXECUTABLE_H_
18 
19 #include <cstdint>
20 #include <memory>
21 #include <string>
22 #include <vector>
23 
24 #include "absl/strings/string_view.h"
25 #include "tensorflow/compiler/xla/service/hlo_module.h"
26 #include "tensorflow/compiler/xla/statusor.h"
27 #include "tensorflow/compiler/xla/util.h"
28 
29 namespace xla {
30 
// Static device memory usage for a compiled program.
// The on-device memory needed to run an executable is at least
//   generated_code_size_in_bytes
//   + argument_size_in_bytes + output_size_in_bytes - alias_size_in_bytes
//   + temp_size_in_bytes.
//
// Every field is a signed 64-bit quantity that defaults to zero, so a
// default-constructed instance reports no memory usage.
struct CompiledMemoryStats {
  int64_t generated_code_size_in_bytes = 0;
  int64_t argument_size_in_bytes = 0;
  int64_t output_size_in_bytes = 0;
  // How much argument is reused for output. It is subtracted in the formula
  // above so that aliased bytes are not counted twice.
  int64_t alias_size_in_bytes = 0;
  int64_t temp_size_in_bytes = 0;

  // Returns a string form of these stats. Only declared here; the definition
  // is not part of this struct body.
  std::string DebugString() const;
};
46 
47 class PjRtExecutable {
48  public:
49   virtual ~PjRtExecutable() = default;
50 
51   virtual int num_replicas() const = 0;
52 
53   virtual int num_partitions() const = 0;
54 
55   virtual int64_t SizeOfGeneratedCodeInBytes() const = 0;
56 
57   // Unique name for this executable, e.g., HloModule name.
58   virtual absl::string_view name() const = 0;
59 
60   // Return an HloModule (optimized) per partition.
61   virtual StatusOr<std::vector<std::shared_ptr<HloModule>>> GetHloModules()
62       const = 0;
63 
64   // Return memory stats that allow callers to estimate device memory usage
65   // when running this executable.
GetCompiledMemoryStats()66   virtual StatusOr<CompiledMemoryStats> GetCompiledMemoryStats() const {
67     return Unimplemented("Retrieving CompiledMemoryStats is not supported.");
68   }
69 
70   // Serialize this executable into a string and return the value.
SerializeExecutable()71   virtual StatusOr<std::string> SerializeExecutable() const {
72     return Unimplemented("Serializing executable is not supported.");
73   }
74 
75   // Return a fingerprint of this executable.
FingerprintExecutable()76   virtual StatusOr<std::string> FingerprintExecutable() const {
77     return Unimplemented("Fingerprinting executable is not supported.");
78   }
79 };
80 
81 }  // namespace xla
82 
83 #endif  // TENSORFLOW_COMPILER_XLA_PJRT_PJRT_EXECUTABLE_H_
84