/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_
#define TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_

#include <cmath>
#include <string>
#include <unordered_map>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/config.pb.h"

namespace tensorflow {
class GraphDef;
class CostGraphDef;

namespace grappler {
struct GrapplerItem;

constexpr int64_t kMemoryUnknown = -1ll;
constexpr int64_t kZeroMemory = 0ll;

struct DeviceInfo {
  // Billions of operations executed per second.
  double gigaops;

  // Bandwidth to main memory in GB per second.
  double gb_per_sec;

  // Read bandwidth to intermediate memory in GB per second.
  double intermediate_read_gb_per_sec;

  // Write bandwidth to intermediate memory in GB per second.
  double intermediate_write_gb_per_sec;

  DeviceInfo()
      : gigaops(INFINITY),
        gb_per_sec(INFINITY),
        intermediate_read_gb_per_sec(INFINITY),
        intermediate_write_gb_per_sec(INFINITY) {}

  DeviceInfo(const DeviceInfo& input)
      : gigaops(input.gigaops),
        gb_per_sec(input.gb_per_sec),
        intermediate_read_gb_per_sec(input.intermediate_read_gb_per_sec),
        intermediate_write_gb_per_sec(input.intermediate_write_gb_per_sec) {}

  DeviceInfo(double gigaops, double gb_per_sec,
             double intermediate_read_gb_per_sec = INFINITY,
             double intermediate_write_gb_per_sec = INFINITY)
      : gigaops(gigaops),
        gb_per_sec(gb_per_sec),
        intermediate_read_gb_per_sec(intermediate_read_gb_per_sec),
        intermediate_write_gb_per_sec(intermediate_write_gb_per_sec) {}
};
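
// A minimal roofline-style sketch of how these peak rates can be turned into a
// per-op time estimate. The device and op sizes below are illustrative
// placeholders, not values defined anywhere in TensorFlow:
//
//   DeviceInfo device(/*gigaops=*/1000.0, /*gb_per_sec=*/100.0);
//   double flops = 2.0e9;   // hypothetical op: 2 GFLOPs of compute
//   double bytes = 4.0e8;   // hypothetical op: 400 MB of memory traffic
//   double compute_ns = flops / (device.gigaops * 1e9) * 1e9;
//   double memory_ns = bytes / (device.gb_per_sec * 1e9) * 1e9;
//   // An estimator might then combine these, e.g. by summing them or taking
//   // the larger of the two, depending on its execution model.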

// Holds the set of things we might want to estimate or measure in Grappler.
// Execution time is always produced; the other fields are optional and depend
// on the estimator being used.
struct Costs {
  // Returns a Costs structure with default values for all of the fields.
  inline Costs();

  // Builds a Costs structure with all zero values, rather than unknowns.
  static inline Costs ZeroCosts(bool inaccurate = false);

  struct MilliSeconds : std::chrono::milliseconds {
    MilliSeconds() : std::chrono::milliseconds(0) {}
    MilliSeconds(double d)
        : std::chrono::milliseconds(static_cast<int64_t>(d)) {}
    MilliSeconds(const std::chrono::milliseconds& d)
        : std::chrono::milliseconds(d) {}
    MilliSeconds& operator=(const std::chrono::milliseconds& d) {
      std::chrono::milliseconds::operator=(d);
      return *this;
    }
  };
  struct MicroSeconds : std::chrono::microseconds {
    MicroSeconds() : std::chrono::microseconds(0) {}
    MicroSeconds(double d)
        : std::chrono::microseconds(static_cast<int64_t>(d)) {}
    MicroSeconds(const std::chrono::microseconds& d)
        : std::chrono::microseconds(d) {}
    MicroSeconds& operator=(const std::chrono::microseconds& d) {
      std::chrono::microseconds::operator=(d);
      return *this;
    }
    MilliSeconds asMilliSeconds() const {
      return std::chrono::duration_cast<std::chrono::milliseconds>(*this);
    }
  };
  struct NanoSeconds : std::chrono::nanoseconds {
    NanoSeconds() : std::chrono::nanoseconds(0) {}
    NanoSeconds(double d) : std::chrono::nanoseconds(static_cast<int64_t>(d)) {}
    NanoSeconds(const std::chrono::nanoseconds& d)
        : std::chrono::nanoseconds(d) {}
    NanoSeconds& operator=(const std::chrono::nanoseconds& d) {
      std::chrono::nanoseconds::operator=(d);
      return *this;
    }
    MicroSeconds asMicroSeconds() const {
      return std::chrono::duration_cast<std::chrono::microseconds>(*this);
    }
    MilliSeconds asMilliSeconds() const {
      return std::chrono::duration_cast<std::chrono::milliseconds>(*this);
    }
    static NanoSeconds infinity() {
      return NanoSeconds(std::chrono::nanoseconds::max());
    }
  };
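
  // A short usage sketch of the duration wrappers above (values are
  // illustrative only):
  //
  //   Costs::NanoSeconds t(1500.0);                  // 1500 ns
  //   Costs::MicroSeconds us = t.asMicroSeconds();   // 1 us
  //   Costs::MilliSeconds ms = t.asMilliSeconds();   // 0 ms
  //
  // Note that the double constructors truncate toward zero via
  // static_cast<int64_t>, and duration_cast also truncates, so sub-unit
  // fractions are dropped.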

  // We store all our times in nanoseconds. If needs be, we can always switch to
  // picoseconds in the future by updating this typedef.
  typedef NanoSeconds Duration;

  // Overall cost of running the graph; latency.
  Duration execution_time;

  // Computation cost of running the graph.
  Duration compute_time;

  // Memory access cost of running the graph.
  Duration memory_time;

  // Intermediate memory access cost of running the graph.
  Duration intermediate_memory_time;
  Duration intermediate_memory_read_time;   // Intermediate memory read cost.
  Duration intermediate_memory_write_time;  // Intermediate memory write cost.

  // This field can be a very pessimistic estimate of the main memory
  // requirements of a graph. For example, it might assume that all activations
  // are live for all of a graph's execution.
  int64_t max_memory;  // Maximum main memory requirement in bytes over all ops.
  int64_t persistent_memory;
  int64_t temporary_memory;

  // Output memory usage per port.
  absl::flat_hash_map<int32_t, int64_t> output_tensor_size_bytes;

  // Track persistent versus temporary memory.
  absl::flat_hash_set<int32_t> persistent_output_ports;

  // These fields are used for TPU-related estimations. They are per-op
  // maximums, so each op is evaluated independently, but we want the maximum of
  // the value over all ops.
  int64_t max_per_op_buffers;    // Sum of all buffers used by the ops.
  int64_t max_per_op_streaming;  // Ignore largest input buffer, assuming it
                                 // streams from main memory.

  // Number of ops included in this Costs in total.
  // Default initialized to be one.
  int64_t num_ops_total = 1;
  // If the time estimation is inaccurate.
  bool inaccurate = false;
  // Number of ops that are estimated with unknown shapes.
  int64_t num_ops_with_unknown_shapes = 0;
  // TODO(pcma): include a counter for total inaccurate ops and counters for
  // other reasons causing the inaccuracy

  // Max possible memory usage per device.
  std::unordered_map<string, uint64> estimated_max_memory_per_device;
};

inline std::ostream& operator<<(std::ostream& os, const Costs::MilliSeconds d) {
  os << d.count() << "ms";
  return os;
}
inline std::ostream& operator<<(std::ostream& os, const Costs::MicroSeconds d) {
  os << d.count() << "us";
  return os;
}
inline std::ostream& operator<<(std::ostream& os, const Costs::NanoSeconds d) {
  os << d.count() << "ns";
  return os;
}

Costs::Costs() {
  execution_time = Duration::zero();
  compute_time = Duration::zero();
  memory_time = Duration::zero();
  intermediate_memory_time = Duration::zero();
  max_memory = kMemoryUnknown;
  persistent_memory = kMemoryUnknown;
  temporary_memory = kMemoryUnknown;
  max_per_op_buffers = kMemoryUnknown;
  max_per_op_streaming = kMemoryUnknown;
}

Costs Costs::ZeroCosts(bool inaccurate) {
  Costs costs;
  costs.execution_time = Duration::zero();
  costs.compute_time = Duration::zero();
  costs.memory_time = Duration::zero();
  costs.intermediate_memory_time = Duration::zero();
  costs.max_memory = kZeroMemory;
  costs.persistent_memory = kZeroMemory;
  costs.temporary_memory = kZeroMemory;
  costs.max_per_op_buffers = kZeroMemory;
  costs.max_per_op_streaming = kZeroMemory;
  costs.inaccurate = inaccurate;
  return costs;
}

Costs CombineCosts(const Costs& left, const Costs& right);

// Multiplies Costs by a scalar.
// Equivalent to applying CombineCosts "multiplier" times.
Costs MultiplyCosts(const Costs& costs, int multiplier);
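
// A usage sketch for the two helpers above. The Costs values here are
// placeholders (how an estimator actually fills them in is up to the
// estimator), and `epilogue` is a hypothetical, separately estimated Costs:
//
//   Costs body = Costs::ZeroCosts();
//   body.execution_time = Costs::NanoSeconds(2000.0);
//   Costs loop = MultiplyCosts(body, /*multiplier=*/10);  // ~10 iterations
//   Costs total = CombineCosts(loop, epilogue);           // aggregate cost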

// Given a GrapplerItem and an optimized implementation of the corresponding
// TensorFlow graph, the CostEstimator attempts to predict the actual cost of
// running the graph.
class CostEstimator {
 public:
  virtual ~CostEstimator() {}

  // Initializes the estimator for the specified grappler item.
  // The estimator shouldn't be used if this function returns any status other
  // than OK.
  virtual Status Initialize(const GrapplerItem& item) = 0;

  // Predicts the cost of running the given optimized version of the grappler
  // item.
  // If a RunMetadata is passed, it will be populated with detailed information
  // about the cost of running each operation of the optimized graph.
  // If a Costs object is passed, it will be set to a value that reflects the
  // overall cost of running the graph (e.g. the latency of the computation).
  // Returns a status that indicates whether the performance could be estimated
  // or not.
  virtual Status PredictCosts(const GraphDef& optimized_graph,
                              RunMetadata* run_metadata, Costs* cost) const = 0;
};
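
// A minimal sketch of what a concrete estimator could look like. This class is
// not part of TensorFlow; it only illustrates how the interface above is meant
// to be implemented and assumes OkStatus() is available from status.h:
//
//   class TrivialCostEstimator : public CostEstimator {
//    public:
//     Status Initialize(const GrapplerItem& item) override {
//       item_ = &item;
//       return OkStatus();
//     }
//     Status PredictCosts(const GraphDef& optimized_graph,
//                         RunMetadata* run_metadata,
//                         Costs* cost) const override {
//       // Report zero costs and flag the estimate as inaccurate.
//       if (cost != nullptr) *cost = Costs::ZeroCosts(/*inaccurate=*/true);
//       return OkStatus();
//     }
//    private:
//     const GrapplerItem* item_ = nullptr;
//   };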

}  // end namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_