1 /* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_COMPILER_SERVICE_XLA_COMPILATION_ENVIRONMENTS_H_
17 #define TENSORFLOW_COMPILER_SERVICE_XLA_COMPILATION_ENVIRONMENTS_H_
18
19 #include <cstdint>
20 #include <memory>
21 #include <string_view>
22 #include <typeindex>
23 #include <utility>
24
25 #include "absl/container/flat_hash_map.h"
26 #include "tensorflow/core/platform/casts.h"
27 #include "tensorflow/core/platform/protobuf.h"
28
29 namespace xla {
30
31 // A class for holding CompilationEnvironments, i.e., protos holding the values
32 // of command line flags and environment variables that affect compilation.
33 //
34 // CompilationEnvironments uses lazy initialization, (see GetEnv() for more
35 // details). Lazy initialization is used so we can avoid:
36 // A) Requiring every code path to explitily construct all needed compilation
37 // environments, particularly when the default constructed environment is
38 // all we need AND
39 // B) Requiring CompilationEnvironments to implicitly construct all needed
40 // environments, thereby requiring it to statically know the types of all
41 // such environments
42 //
43 // CompilationEnvironments is not thread-safe.
44 class CompilationEnvironments {
45 public:
46 CompilationEnvironments() = default;
CompilationEnvironments(const CompilationEnvironments & rhs)47 CompilationEnvironments(const CompilationEnvironments& rhs) { *this = rhs; }
48 CompilationEnvironments& operator=(const CompilationEnvironments& rhs);
49 ~CompilationEnvironments() = default;
50
51 // Users of CompilationEnvironments must specialize this method for each type
52 // of CompilationEnvironment they wish to use in code.
53 //
54 // Users are requested to call
55 // DefaultEnvCreated(T::descriptor()->full_name()); from their
56 // implementations, to track the number of calls to the default creator.
57 //
58 // REQUIRES:
59 // - T must be a type of proto message.
60 template <typename T>
61 static std::unique_ptr<T> CreateDefaultEnv() = delete;
62
63 // Adds env to the list of CompilationEnvironments. If an environment with
64 // std::type_index equal to env.GetTypeid() has already been added, env
65 // will replace it.
66 void AddEnv(std::unique_ptr<tensorflow::protobuf::Message> env);
67
68 // Returns the CompilationEnvironment corresponding to T. If such an
69 // environment has not been added, CreateDefaultEnv<T>() will be called to
70 // create one that is then added.
71 //
72 // GetEnv() is not const because it can perform lazy initialization, thereby
73 // modifying the CompilationEnvironments's data members.
74 //
75 // GetEnv<T> will not compile for type T, unless CreateDefaultEnv<T> is
76 // defined.
77 template <typename T>
78 const T& GetEnv();
79
80 // Removes all added environments.
Clear()81 void Clear() { environments_.clear(); }
82
83 private:
84 // Called by implementations of CreateDefaultEnv(), to globally track stats
85 // about default environment creation.
86 static void DefaultEnvCreated(std::string_view env_type);
87
88 // Called by GetEnv() when it calls CreateDefaultEnv(), to globally track
89 // stats about how many of the created default environments are created by
90 // CompilationEnvironments.
91 static void DefaultEnvCreatedByCompilationEnvironments(
92 std::string_view env_type);
93
94 // Called by AddEnv(), to globally track stats about how many environments
95 // are added to CompilationEnvironments.
96 static void EnvAdded(std::string_view env_type);
97
98 absl::flat_hash_map<const tensorflow::protobuf::Descriptor*,
99 std::unique_ptr<tensorflow::protobuf::Message>>
100 environments_;
101 };
102
103 // ----- Template implementation below -----
104
105 // Make sure no one tries to specialize CreateDefaultEnv() for raw
106 // tensorflow::protobuf::Message. Specialization should always be for a specific
107 // type of proto message.
108 template <>
109 std::unique_ptr<tensorflow::protobuf::Message>
110 CompilationEnvironments::CreateDefaultEnv() = delete;
111
112 template <typename T>
GetEnv()113 const T& CompilationEnvironments::GetEnv() {
114 auto descriptor = T::descriptor();
115 auto it = environments_.find(descriptor);
116 if (it == environments_.end()) {
117 AddEnv(CreateDefaultEnv<T>());
118 DefaultEnvCreatedByCompilationEnvironments(descriptor->full_name());
119 it = environments_.find(descriptor);
120 }
121 return tensorflow::down_cast<const T&>(*it->second);
122 }
123
124 } // namespace xla
125
126 #endif // TENSORFLOW_COMPILER_SERVICE_XLA_COMPILATION_ENVIRONMENTS_H_
127