xref: /aosp_15_r20/external/tensorflow/tensorflow/compiler/xla/service/compilation_environments.h (revision b6fb3261f9314811a0f4371741dbb8839866f948)
1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #ifndef TENSORFLOW_COMPILER_SERVICE_XLA_COMPILATION_ENVIRONMENTS_H_
17 #define TENSORFLOW_COMPILER_SERVICE_XLA_COMPILATION_ENVIRONMENTS_H_
18 
19 #include <cstdint>
20 #include <memory>
21 #include <string_view>
22 #include <typeindex>
23 #include <utility>
24 
25 #include "absl/container/flat_hash_map.h"
26 #include "tensorflow/core/platform/casts.h"
27 #include "tensorflow/core/platform/protobuf.h"
28 
29 namespace xla {
30 
31 // A class for holding CompilationEnvironments, i.e., protos holding the values
32 // of command line flags and environment variables that affect compilation.
33 //
34 // CompilationEnvironments uses lazy initialization, (see GetEnv() for more
35 // details). Lazy initialization is used so we can avoid:
36 // A) Requiring every code path to explitily construct all needed compilation
37 //    environments, particularly when the default constructed environment is
38 //    all we need AND
39 // B) Requiring CompilationEnvironments to implicitly construct all needed
40 //    environments, thereby requiring it to statically know the types of all
41 //    such environments
42 //
43 // CompilationEnvironments is not thread-safe.
44 class CompilationEnvironments {
45  public:
46   CompilationEnvironments() = default;
CompilationEnvironments(const CompilationEnvironments & rhs)47   CompilationEnvironments(const CompilationEnvironments& rhs) { *this = rhs; }
48   CompilationEnvironments& operator=(const CompilationEnvironments& rhs);
49   ~CompilationEnvironments() = default;
50 
51   // Users of CompilationEnvironments must specialize this method for each type
52   // of CompilationEnvironment they wish to use in code.
53   //
54   // Users are requested to call
55   // DefaultEnvCreated(T::descriptor()->full_name()); from their
56   // implementations, to track the number of calls to the default creator.
57   //
58   // REQUIRES:
59   // - T must be a type of proto message.
60   template <typename T>
61   static std::unique_ptr<T> CreateDefaultEnv() = delete;
62 
63   // Adds env to the list of CompilationEnvironments. If an environment with
64   // std::type_index equal to env.GetTypeid() has already been added, env
65   // will replace it.
66   void AddEnv(std::unique_ptr<tensorflow::protobuf::Message> env);
67 
68   // Returns the CompilationEnvironment corresponding to T. If such an
69   // environment has not been added, CreateDefaultEnv<T>() will be called to
70   // create one that is then added.
71   //
72   // GetEnv() is not const because it can perform lazy initialization, thereby
73   // modifying the CompilationEnvironments's data members.
74   //
75   // GetEnv<T> will not compile for type T, unless CreateDefaultEnv<T> is
76   // defined.
77   template <typename T>
78   const T& GetEnv();
79 
80   // Removes all added environments.
Clear()81   void Clear() { environments_.clear(); }
82 
83  private:
84   // Called by implementations of CreateDefaultEnv(), to globally track stats
85   // about default environment creation.
86   static void DefaultEnvCreated(std::string_view env_type);
87 
88   // Called by GetEnv() when it calls CreateDefaultEnv(), to globally track
89   // stats about how many of the created default environments are created by
90   // CompilationEnvironments.
91   static void DefaultEnvCreatedByCompilationEnvironments(
92       std::string_view env_type);
93 
94   // Called by AddEnv(), to globally track stats about how many environments
95   // are added to CompilationEnvironments.
96   static void EnvAdded(std::string_view env_type);
97 
98   absl::flat_hash_map<const tensorflow::protobuf::Descriptor*,
99                       std::unique_ptr<tensorflow::protobuf::Message>>
100       environments_;
101 };
102 
103 // ----- Template implementation below -----
104 
105 // Make sure no one tries to specialize CreateDefaultEnv() for raw
106 // tensorflow::protobuf::Message. Specialization should always be for a specific
107 // type of proto message.
108 template <>
109 std::unique_ptr<tensorflow::protobuf::Message>
110 CompilationEnvironments::CreateDefaultEnv() = delete;
111 
112 template <typename T>
GetEnv()113 const T& CompilationEnvironments::GetEnv() {
114   auto descriptor = T::descriptor();
115   auto it = environments_.find(descriptor);
116   if (it == environments_.end()) {
117     AddEnv(CreateDefaultEnv<T>());
118     DefaultEnvCreatedByCompilationEnvironments(descriptor->full_name());
119     it = environments_.find(descriptor);
120   }
121   return tensorflow::down_cast<const T&>(*it->second);
122 }
123 
124 }  // namespace xla
125 
126 #endif  // TENSORFLOW_COMPILER_SERVICE_XLA_COMPILATION_ENVIRONMENTS_H_
127