1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_
17 #define TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_
18
#include <chrono>
#include <cmath>
#include <cstdint>
#include <limits>
#include <ostream>
#include <string>
#include <unordered_map>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/protobuf/config.pb.h"
27
28 namespace tensorflow {
29 class GraphDef;
30 class CostGraphDef;
31
32 namespace grappler {
33 struct GrapplerItem;
34
// Sentinel memory size used where a memory figure is not known.
constexpr int64_t kMemoryUnknown = int64_t{-1};
// A memory size of zero bytes (a known value, as opposed to kMemoryUnknown).
constexpr int64_t kZeroMemory = int64_t{0};
37
// Throughput characteristics of a device. Every rate defaults to INFINITY
// unless a constructor argument overrides it.
struct DeviceInfo {
  // Billions of operations executed per second.
  double gigaops = INFINITY;

  // Bandwidth to main memory in GB per second.
  double gb_per_sec = INFINITY;

  // Read bandwidth to intermediate memory in GB per second.
  double intermediate_read_gb_per_sec = INFINITY;

  // Write bandwidth to intermediate memory in GB per second.
  double intermediate_write_gb_per_sec = INFINITY;

  // All rates take their in-class default of INFINITY.
  DeviceInfo() = default;

  // Copy (and move) operations are intentionally left implicit (Rule of
  // Zero). This keeps DeviceInfo trivially copyable and gives it a
  // non-deprecated implicit copy assignment operator.

  // Sets the compute and main-memory rates; the intermediate-memory rates
  // default to INFINITY when omitted.
  DeviceInfo(double gigaops, double gb_per_sec,
             double intermediate_read_gb_per_sec = INFINITY,
             double intermediate_write_gb_per_sec = INFINITY)
      : gigaops(gigaops),
        gb_per_sec(gb_per_sec),
        intermediate_read_gb_per_sec(intermediate_read_gb_per_sec),
        intermediate_write_gb_per_sec(intermediate_write_gb_per_sec) {}
};
71
72 // Holds the set of things we might want to estimate or measure in Grappler.
73 // Always produce execution time. Other fields are optional depending on the
74 // estimator being used.
75 struct Costs {
76 // Returns a Costs structure with default values for all of the fields.
77 inline Costs();
78
79 // Builds a Costs structure with all zero values, rather than unknowns.
80 static inline Costs ZeroCosts(bool inaccurate = false);
81
82 struct MilliSeconds : std::chrono::milliseconds {
MilliSecondsCosts::MilliSeconds83 MilliSeconds() : std::chrono::milliseconds(0) {}
MilliSecondsCosts::MilliSeconds84 MilliSeconds(double d)
85 : std::chrono::milliseconds(static_cast<int64_t>(d)) {}
MilliSecondsCosts::MilliSeconds86 MilliSeconds(const std::chrono::milliseconds& d)
87 : std::chrono::milliseconds(d) {}
88 MilliSeconds& operator=(const std::chrono::milliseconds& d) {
89 std::chrono::milliseconds::operator=(d);
90 return *this;
91 }
92 };
93 struct MicroSeconds : std::chrono::microseconds {
MicroSecondsCosts::MicroSeconds94 MicroSeconds() : std::chrono::microseconds(0) {}
MicroSecondsCosts::MicroSeconds95 MicroSeconds(double d)
96 : std::chrono::microseconds(static_cast<int64_t>(d)) {}
MicroSecondsCosts::MicroSeconds97 MicroSeconds(const std::chrono::microseconds& d)
98 : std::chrono::microseconds(d) {}
99 MicroSeconds& operator=(const std::chrono::microseconds& d) {
100 std::chrono::microseconds::operator=(d);
101 return *this;
102 }
asMilliSecondsCosts::MicroSeconds103 MilliSeconds asMilliSeconds() const {
104 return std::chrono::duration_cast<std::chrono::milliseconds>(*this);
105 }
106 };
107 struct NanoSeconds : std::chrono::nanoseconds {
NanoSecondsCosts::NanoSeconds108 NanoSeconds() : std::chrono::nanoseconds(0) {}
NanoSecondsCosts::NanoSeconds109 NanoSeconds(double d) : std::chrono::nanoseconds(static_cast<int64_t>(d)) {}
NanoSecondsCosts::NanoSeconds110 NanoSeconds(const std::chrono::nanoseconds& d)
111 : std::chrono::nanoseconds(d) {}
112 NanoSeconds& operator=(const std::chrono::nanoseconds& d) {
113 std::chrono::nanoseconds::operator=(d);
114 return *this;
115 }
asMicroSecondsCosts::NanoSeconds116 MicroSeconds asMicroSeconds() const {
117 return std::chrono::duration_cast<std::chrono::microseconds>(*this);
118 }
asMilliSecondsCosts::NanoSeconds119 MilliSeconds asMilliSeconds() const {
120 return std::chrono::duration_cast<std::chrono::milliseconds>(*this);
121 }
infinityCosts::NanoSeconds122 static NanoSeconds infinity() {
123 return NanoSeconds(std::chrono::nanoseconds::max());
124 }
125 };
126
127 // We store all our times in nanoseconds. If needs be, we can always switch to
128 // picoseconds in the future by updating this typedef.
129 typedef NanoSeconds Duration;
130
131 // Overall cost of running the graph; latency.
132 Duration execution_time;
133
134 // Computation cost of running the graph.
135 Duration compute_time;
136
137 // Memory access cost of running the graph.
138 Duration memory_time;
139
140 // Intermediate memory access cost of running the graph
141 Duration intermediate_memory_time;
142 Duration intermediate_memory_read_time; // Intermediate memory read cost.
143 Duration intermediate_memory_write_time; // Intermediate memory write cost.
144
145 // This field can be a very pessimistic estimate of the main memory
146 // requirements of a graph. For example, it might assume that all activations
147 // are live for all of a graph's execution.
148 int64_t max_memory; // Maximum main memory requirement in bytes over all ops.
149 int64_t persistent_memory;
150 int64_t temporary_memory;
151
152 // Output memory usage per port.
153 absl::flat_hash_map<int32_t, int64_t> output_tensor_size_bytes;
154
155 // Track persistent versus temporary memory.
156 absl::flat_hash_set<int32_t> persistent_output_ports;
157
158 // These fields are used for TPU-related estimations. They are per-op
159 // maximums, so each op is evaluated independently, but we want the maximum of
160 // the value over all ops.
161 int64_t max_per_op_buffers; // Sum of all buffers used by the ops.
162 int64_t max_per_op_streaming; // Ignore largest input buffer, assuming it
163 // streams from main memory.
164
165 // Number of ops included in this Costs in total.
166 // Default initialized to be one.
167 int64_t num_ops_total = 1;
168 // If the time estimation is inaccurate.
169 bool inaccurate = false;
170 // Number of ops that are estimated with unknown shapes.
171 int64_t num_ops_with_unknown_shapes = 0;
172 // TODO(pcma): include a counter for total inaccurate ops and counters for
173 // other reasons causing the inaccuracy
174
175 // Max possible memory usage per device.
176 std::unordered_map<string, uint64> estimated_max_memory_per_device;
177 };
178
179 inline std::ostream& operator<<(std::ostream& os, const Costs::MilliSeconds d) {
180 os << d.count() << "ms";
181 return os;
182 }
183 inline std::ostream& operator<<(std::ostream& os, const Costs::MicroSeconds d) {
184 os << d.count() << "us";
185 return os;
186 }
187 inline std::ostream& operator<<(std::ostream& os, const Costs::NanoSeconds d) {
188 os << d.count() << "ns";
189 return os;
190 }
191
Costs()192 Costs::Costs() {
193 execution_time = Duration::zero();
194 compute_time = Duration::zero();
195 memory_time = Duration::zero();
196 intermediate_memory_time = Duration::zero();
197 max_memory = kMemoryUnknown;
198 persistent_memory = kMemoryUnknown;
199 temporary_memory = kMemoryUnknown;
200 max_per_op_buffers = kMemoryUnknown;
201 max_per_op_streaming = kMemoryUnknown;
202 }
203
ZeroCosts(bool inaccurate)204 Costs Costs::ZeroCosts(bool inaccurate) {
205 Costs costs;
206 costs.execution_time = Duration::zero();
207 costs.compute_time = Duration::zero();
208 costs.memory_time = Duration::zero();
209 costs.intermediate_memory_time = Duration::zero();
210 costs.max_memory = kZeroMemory;
211 costs.persistent_memory = kZeroMemory;
212 costs.temporary_memory = kZeroMemory;
213 costs.max_per_op_buffers = kZeroMemory;
214 costs.max_per_op_streaming = kZeroMemory;
215 costs.inaccurate = inaccurate;
216 return costs;
217 }
218
// Combines `left` and `right` into a single Costs. The definition is not in
// this header. NOTE(review): confirm the per-field semantics. The field
// comments on Costs suggest times/memory accumulate, while
// max_per_op_buffers / max_per_op_streaming should take the maximum.
Costs CombineCosts(const Costs& left, const Costs& right);

// Multiplies Costs by a scalar.
// Equivalent to applying CombineCosts "multiplier" times.
// NOTE(review): behavior for multiplier <= 0 is not visible here; confirm
// against the definition.
Costs MultiplyCosts(const Costs& costs, int multiplier);
224
225 // Given a GrapperItem and an optimized implementation of the corresponding
226 // TensorFlow graph, the CostEstimator attempts to predicts the actual cost of
227 // running the graph.
228 class CostEstimator {
229 public:
~CostEstimator()230 virtual ~CostEstimator() {}
231
232 // Initializes the estimator for the specified grappler item.
233 // The estimator shouldn't be used if this function returns any status other
234 // that OK.
235 virtual Status Initialize(const GrapplerItem& item) = 0;
236
237 // Predicts the cost of running the given optimized version of the grappler
238 // item.
239 // If a RunMetadata is passed, it will be populated with detailed information
240 // about the cost of running each operation of the optimized graph.
241 // if a double value is passed, it will be set to a value that reflects the
242 // overall cost of running the graph (e.g. the latency of the computation).
243 // Returns a status that indicate is the performance could be estimated or
244 // not.
245 virtual Status PredictCosts(const GraphDef& optimized_graph,
246 RunMetadata* run_metadata, Costs* cost) const = 0;
247 };
248
249 } // end namespace grappler
250 } // end namespace tensorflow
251
252 #endif // TENSORFLOW_CORE_GRAPPLER_COSTS_COST_ESTIMATOR_H_
253