xref: /aosp_15_r20/external/swiftshader/third_party/SPIRV-Tools/source/opt/optimizer.cpp (revision 03ce13f70fcc45d86ee91b7ee4cab1936a95046e)
1 // Copyright (c) 2016 Google Inc.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "spirv-tools/optimizer.hpp"
16 
17 #include <cassert>
18 #include <charconv>
19 #include <memory>
20 #include <string>
21 #include <unordered_map>
22 #include <utility>
23 #include <vector>
24 
25 #include "source/opt/build_module.h"
26 #include "source/opt/graphics_robust_access_pass.h"
27 #include "source/opt/log.h"
28 #include "source/opt/pass_manager.h"
29 #include "source/opt/passes.h"
30 #include "source/spirv_optimizer_options.h"
31 #include "source/util/make_unique.h"
32 #include "source/util/string_utils.h"
33 
34 namespace spvtools {
35 
// Copies |string_count| C strings from the array |strings| into a vector of
// std::string, preserving their order.
//
// Every one of the first |string_count| entries of |strings| must point to a
// NUL-terminated string.  |strings| itself is not dereferenced when
// |string_count| is zero, in which case an empty vector is returned.
std::vector<std::string> GetVectorOfStrings(const char** strings,
                                            const size_t string_count) {
  std::vector<std::string> result;
  // The final size is known up front, so allocate once.
  result.reserve(string_count);
  // Index with size_t so the counter has the same width as |string_count|;
  // a narrower index would wrap before reaching a very large count.
  for (size_t i = 0; i < string_count; ++i) {
    result.emplace_back(strings[i]);
  }
  return result;
}
44 
45 struct Optimizer::PassToken::Impl {
Implspvtools::Optimizer::PassToken::Impl46   Impl(std::unique_ptr<opt::Pass> p) : pass(std::move(p)) {}
47 
48   std::unique_ptr<opt::Pass> pass;  // Internal implementation pass.
49 };
50 
PassToken(std::unique_ptr<Optimizer::PassToken::Impl> impl)51 Optimizer::PassToken::PassToken(
52     std::unique_ptr<Optimizer::PassToken::Impl> impl)
53     : impl_(std::move(impl)) {}
54 
PassToken(std::unique_ptr<opt::Pass> && pass)55 Optimizer::PassToken::PassToken(std::unique_ptr<opt::Pass>&& pass)
56     : impl_(MakeUnique<Optimizer::PassToken::Impl>(std::move(pass))) {}
57 
PassToken(PassToken && that)58 Optimizer::PassToken::PassToken(PassToken&& that)
59     : impl_(std::move(that.impl_)) {}
60 
operator =(PassToken && that)61 Optimizer::PassToken& Optimizer::PassToken::operator=(PassToken&& that) {
62   impl_ = std::move(that.impl_);
63   return *this;
64 }
65 
~PassToken()66 Optimizer::PassToken::~PassToken() {}
67 
68 struct Optimizer::Impl {
Implspvtools::Optimizer::Impl69   explicit Impl(spv_target_env env) : target_env(env), pass_manager() {}
70 
71   spv_target_env target_env;      // Target environment.
72   opt::PassManager pass_manager;  // Internal implementation pass manager.
73   std::unordered_set<uint32_t> live_locs;  // Arg to debug dead output passes
74 };
75 
Optimizer(spv_target_env env)76 Optimizer::Optimizer(spv_target_env env) : impl_(new Impl(env)) {
77   assert(env != SPV_ENV_WEBGPU_0);
78 }
79 
~Optimizer()80 Optimizer::~Optimizer() {}
81 
SetMessageConsumer(MessageConsumer c)82 void Optimizer::SetMessageConsumer(MessageConsumer c) {
83   // All passes' message consumer needs to be updated.
84   for (uint32_t i = 0; i < impl_->pass_manager.NumPasses(); ++i) {
85     impl_->pass_manager.GetPass(i)->SetMessageConsumer(c);
86   }
87   impl_->pass_manager.SetMessageConsumer(std::move(c));
88 }
89 
consumer() const90 const MessageConsumer& Optimizer::consumer() const {
91   return impl_->pass_manager.consumer();
92 }
93 
RegisterPass(PassToken && p)94 Optimizer& Optimizer::RegisterPass(PassToken&& p) {
95   // Change to use the pass manager's consumer.
96   p.impl_->pass->SetMessageConsumer(consumer());
97   impl_->pass_manager.AddPass(std::move(p.impl_->pass));
98   return *this;
99 }
100 
101 // The legalization passes take a spir-v shader generated by an HLSL front-end
102 // and turn it into a valid vulkan spir-v shader.  There are two ways in which
103 // the code will be invalid at the start:
104 //
105 // 1) There will be opaque objects, like images, which will be passed around
106 //    in intermediate objects.  Valid spir-v will have to replace the use of
107 //    the opaque object with an intermediate object that is the result of the
108 //    load of the global opaque object.
109 //
// 2) There will be variables that contain pointers to structured or uniform
//    buffers.  To be legal, the variables must be eliminated, and the
//    references to the structured buffers must use the result of OpVariable
//    in the Uniform storage class.
114 //
// Optimizations in this list must accept shaders with these relaxations of
// the rules.  There is no guarantee that this list of optimizations is able
// to legalize all inputs, but it is on a best effort basis.
118 //
// The legalization problem is essentially a very general copy propagation
// problem.  The optimizations we use are all used to either do copy
// propagation or enable more copy propagation.
RegisterLegalizationPasses(bool preserve_interface)122 Optimizer& Optimizer::RegisterLegalizationPasses(bool preserve_interface) {
123   return
124       // Wrap OpKill instructions so all other code can be inlined.
125       RegisterPass(CreateWrapOpKillPass())
126           // Remove unreachable block so that merge return works.
127           .RegisterPass(CreateDeadBranchElimPass())
128           // Merge the returns so we can inline.
129           .RegisterPass(CreateMergeReturnPass())
130           // Make sure uses and definitions are in the same function.
131           .RegisterPass(CreateInlineExhaustivePass())
132           // Make private variable function scope
133           .RegisterPass(CreateEliminateDeadFunctionsPass())
134           .RegisterPass(CreatePrivateToLocalPass())
135           // Fix up the storage classes that DXC may have purposely generated
136           // incorrectly.  All functions are inlined, and a lot of dead code has
137           // been removed.
138           .RegisterPass(CreateFixStorageClassPass())
139           // Propagate the value stored to the loads in very simple cases.
140           .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
141           .RegisterPass(CreateLocalSingleStoreElimPass())
142           .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
143           // Split up aggregates so they are easier to deal with.
144           .RegisterPass(CreateScalarReplacementPass(0))
145           // Remove loads and stores so everything is in intermediate values.
146           // Takes care of copy propagation of non-members.
147           .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
148           .RegisterPass(CreateLocalSingleStoreElimPass())
149           .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
150           .RegisterPass(CreateLocalMultiStoreElimPass())
151           .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
152           // Propagate constants to get as many constant conditions on branches
153           // as possible.
154           .RegisterPass(CreateCCPPass())
155           .RegisterPass(CreateLoopUnrollPass(true))
156           .RegisterPass(CreateDeadBranchElimPass())
157           // Copy propagate members.  Cleans up code sequences generated by
158           // scalar replacement.  Also important for removing OpPhi nodes.
159           .RegisterPass(CreateSimplificationPass())
160           .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
161           .RegisterPass(CreateCopyPropagateArraysPass())
162           // May need loop unrolling here see
163           // https://github.com/Microsoft/DirectXShaderCompiler/pull/930
164           // Get rid of unused code that contain traces of illegal code
165           // or unused references to unbound external objects
166           .RegisterPass(CreateVectorDCEPass())
167           .RegisterPass(CreateDeadInsertElimPass())
168           .RegisterPass(CreateReduceLoadSizePass())
169           .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
170           .RegisterPass(CreateRemoveUnusedInterfaceVariablesPass())
171           .RegisterPass(CreateInterpolateFixupPass())
172           .RegisterPass(CreateInvocationInterlockPlacementPass());
173 }
174 
RegisterLegalizationPasses()175 Optimizer& Optimizer::RegisterLegalizationPasses() {
176   return RegisterLegalizationPasses(false);
177 }
178 
RegisterPerformancePasses(bool preserve_interface)179 Optimizer& Optimizer::RegisterPerformancePasses(bool preserve_interface) {
180   return RegisterPass(CreateWrapOpKillPass())
181       .RegisterPass(CreateDeadBranchElimPass())
182       .RegisterPass(CreateMergeReturnPass())
183       .RegisterPass(CreateInlineExhaustivePass())
184       .RegisterPass(CreateEliminateDeadFunctionsPass())
185       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
186       .RegisterPass(CreatePrivateToLocalPass())
187       .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
188       .RegisterPass(CreateLocalSingleStoreElimPass())
189       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
190       .RegisterPass(CreateScalarReplacementPass())
191       .RegisterPass(CreateLocalAccessChainConvertPass())
192       .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
193       .RegisterPass(CreateLocalSingleStoreElimPass())
194       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
195       .RegisterPass(CreateLocalMultiStoreElimPass())
196       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
197       .RegisterPass(CreateCCPPass())
198       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
199       .RegisterPass(CreateLoopUnrollPass(true))
200       .RegisterPass(CreateDeadBranchElimPass())
201       .RegisterPass(CreateRedundancyEliminationPass())
202       .RegisterPass(CreateCombineAccessChainsPass())
203       .RegisterPass(CreateSimplificationPass())
204       .RegisterPass(CreateScalarReplacementPass())
205       .RegisterPass(CreateLocalAccessChainConvertPass())
206       .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
207       .RegisterPass(CreateLocalSingleStoreElimPass())
208       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
209       .RegisterPass(CreateSSARewritePass())
210       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
211       .RegisterPass(CreateVectorDCEPass())
212       .RegisterPass(CreateDeadInsertElimPass())
213       .RegisterPass(CreateDeadBranchElimPass())
214       .RegisterPass(CreateSimplificationPass())
215       .RegisterPass(CreateIfConversionPass())
216       .RegisterPass(CreateCopyPropagateArraysPass())
217       .RegisterPass(CreateReduceLoadSizePass())
218       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
219       .RegisterPass(CreateBlockMergePass())
220       .RegisterPass(CreateRedundancyEliminationPass())
221       .RegisterPass(CreateDeadBranchElimPass())
222       .RegisterPass(CreateBlockMergePass())
223       .RegisterPass(CreateSimplificationPass());
224 }
225 
RegisterPerformancePasses()226 Optimizer& Optimizer::RegisterPerformancePasses() {
227   return RegisterPerformancePasses(false);
228 }
229 
RegisterSizePasses(bool preserve_interface)230 Optimizer& Optimizer::RegisterSizePasses(bool preserve_interface) {
231   return RegisterPass(CreateWrapOpKillPass())
232       .RegisterPass(CreateDeadBranchElimPass())
233       .RegisterPass(CreateMergeReturnPass())
234       .RegisterPass(CreateInlineExhaustivePass())
235       .RegisterPass(CreateEliminateDeadFunctionsPass())
236       .RegisterPass(CreatePrivateToLocalPass())
237       .RegisterPass(CreateScalarReplacementPass(0))
238       .RegisterPass(CreateLocalMultiStoreElimPass())
239       .RegisterPass(CreateCCPPass())
240       .RegisterPass(CreateLoopUnrollPass(true))
241       .RegisterPass(CreateDeadBranchElimPass())
242       .RegisterPass(CreateSimplificationPass())
243       .RegisterPass(CreateScalarReplacementPass(0))
244       .RegisterPass(CreateLocalSingleStoreElimPass())
245       .RegisterPass(CreateIfConversionPass())
246       .RegisterPass(CreateSimplificationPass())
247       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
248       .RegisterPass(CreateDeadBranchElimPass())
249       .RegisterPass(CreateBlockMergePass())
250       .RegisterPass(CreateLocalAccessChainConvertPass())
251       .RegisterPass(CreateLocalSingleBlockLoadStoreElimPass())
252       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
253       .RegisterPass(CreateCopyPropagateArraysPass())
254       .RegisterPass(CreateVectorDCEPass())
255       .RegisterPass(CreateDeadInsertElimPass())
256       .RegisterPass(CreateEliminateDeadMembersPass())
257       .RegisterPass(CreateLocalSingleStoreElimPass())
258       .RegisterPass(CreateBlockMergePass())
259       .RegisterPass(CreateLocalMultiStoreElimPass())
260       .RegisterPass(CreateRedundancyEliminationPass())
261       .RegisterPass(CreateSimplificationPass())
262       .RegisterPass(CreateAggressiveDCEPass(preserve_interface))
263       .RegisterPass(CreateCFGCleanupPass());
264 }
265 
RegisterSizePasses()266 Optimizer& Optimizer::RegisterSizePasses() { return RegisterSizePasses(false); }
267 
RegisterPassesFromFlags(const std::vector<std::string> & flags)268 bool Optimizer::RegisterPassesFromFlags(const std::vector<std::string>& flags) {
269   return RegisterPassesFromFlags(flags, false);
270 }
271 
RegisterPassesFromFlags(const std::vector<std::string> & flags,bool preserve_interface)272 bool Optimizer::RegisterPassesFromFlags(const std::vector<std::string>& flags,
273                                         bool preserve_interface) {
274   for (const auto& flag : flags) {
275     if (!RegisterPassFromFlag(flag, preserve_interface)) {
276       return false;
277     }
278   }
279 
280   return true;
281 }
282 
FlagHasValidForm(const std::string & flag) const283 bool Optimizer::FlagHasValidForm(const std::string& flag) const {
284   if (flag == "-O" || flag == "-Os") {
285     return true;
286   } else if (flag.size() > 2 && flag.substr(0, 2) == "--") {
287     return true;
288   }
289 
290   Errorf(consumer(), nullptr, {},
291          "%s is not a valid flag.  Flag passes should have the form "
292          "'--pass_name[=pass_args]'. Special flag names also accepted: -O "
293          "and -Os.",
294          flag.c_str());
295   return false;
296 }
297 
RegisterPassFromFlag(const std::string & flag)298 bool Optimizer::RegisterPassFromFlag(const std::string& flag) {
299   return RegisterPassFromFlag(flag, false);
300 }
301 
RegisterPassFromFlag(const std::string & flag,bool preserve_interface)302 bool Optimizer::RegisterPassFromFlag(const std::string& flag,
303                                      bool preserve_interface) {
304   if (!FlagHasValidForm(flag)) {
305     return false;
306   }
307 
308   // Split flags of the form --pass_name=pass_args.
309   auto p = utils::SplitFlagArgs(flag);
310   std::string pass_name = p.first;
311   std::string pass_args = p.second;
312 
313   // FIXME(dnovillo): This should be re-factored so that pass names can be
314   // automatically checked against Pass::name() and PassToken instances created
315   // via a template function.  Additionally, class Pass should have a desc()
316   // method that describes the pass (so it can be used in --help).
317   //
318   // Both Pass::name() and Pass::desc() should be static class members so they
319   // can be invoked without creating a pass instance.
320   if (pass_name == "strip-debug") {
321     RegisterPass(CreateStripDebugInfoPass());
322   } else if (pass_name == "strip-reflect") {
323     RegisterPass(CreateStripReflectInfoPass());
324   } else if (pass_name == "strip-nonsemantic") {
325     RegisterPass(CreateStripNonSemanticInfoPass());
326   } else if (pass_name == "set-spec-const-default-value") {
327     if (pass_args.size() > 0) {
328       auto spec_ids_vals =
329           opt::SetSpecConstantDefaultValuePass::ParseDefaultValuesString(
330               pass_args.c_str());
331       if (!spec_ids_vals) {
332         Errorf(consumer(), nullptr, {},
333                "Invalid argument for --set-spec-const-default-value: %s",
334                pass_args.c_str());
335         return false;
336       }
337       RegisterPass(
338           CreateSetSpecConstantDefaultValuePass(std::move(*spec_ids_vals)));
339     } else {
340       Errorf(consumer(), nullptr, {},
341              "Invalid spec constant value string '%s'. Expected a string of "
342              "<spec id>:<default value> pairs.",
343              pass_args.c_str());
344       return false;
345     }
346   } else if (pass_name == "if-conversion") {
347     RegisterPass(CreateIfConversionPass());
348   } else if (pass_name == "freeze-spec-const") {
349     RegisterPass(CreateFreezeSpecConstantValuePass());
350   } else if (pass_name == "inline-entry-points-exhaustive") {
351     RegisterPass(CreateInlineExhaustivePass());
352   } else if (pass_name == "inline-entry-points-opaque") {
353     RegisterPass(CreateInlineOpaquePass());
354   } else if (pass_name == "combine-access-chains") {
355     RegisterPass(CreateCombineAccessChainsPass());
356   } else if (pass_name == "convert-local-access-chains") {
357     RegisterPass(CreateLocalAccessChainConvertPass());
358   } else if (pass_name == "replace-desc-array-access-using-var-index") {
359     RegisterPass(CreateReplaceDescArrayAccessUsingVarIndexPass());
360   } else if (pass_name == "spread-volatile-semantics") {
361     RegisterPass(CreateSpreadVolatileSemanticsPass());
362   } else if (pass_name == "descriptor-scalar-replacement") {
363     RegisterPass(CreateDescriptorScalarReplacementPass());
364   } else if (pass_name == "eliminate-dead-code-aggressive") {
365     RegisterPass(CreateAggressiveDCEPass(preserve_interface));
366   } else if (pass_name == "eliminate-insert-extract") {
367     RegisterPass(CreateInsertExtractElimPass());
368   } else if (pass_name == "eliminate-local-single-block") {
369     RegisterPass(CreateLocalSingleBlockLoadStoreElimPass());
370   } else if (pass_name == "eliminate-local-single-store") {
371     RegisterPass(CreateLocalSingleStoreElimPass());
372   } else if (pass_name == "merge-blocks") {
373     RegisterPass(CreateBlockMergePass());
374   } else if (pass_name == "merge-return") {
375     RegisterPass(CreateMergeReturnPass());
376   } else if (pass_name == "eliminate-dead-branches") {
377     RegisterPass(CreateDeadBranchElimPass());
378   } else if (pass_name == "eliminate-dead-functions") {
379     RegisterPass(CreateEliminateDeadFunctionsPass());
380   } else if (pass_name == "eliminate-local-multi-store") {
381     RegisterPass(CreateLocalMultiStoreElimPass());
382   } else if (pass_name == "eliminate-dead-const") {
383     RegisterPass(CreateEliminateDeadConstantPass());
384   } else if (pass_name == "eliminate-dead-inserts") {
385     RegisterPass(CreateDeadInsertElimPass());
386   } else if (pass_name == "eliminate-dead-variables") {
387     RegisterPass(CreateDeadVariableEliminationPass());
388   } else if (pass_name == "eliminate-dead-members") {
389     RegisterPass(CreateEliminateDeadMembersPass());
390   } else if (pass_name == "fold-spec-const-op-composite") {
391     RegisterPass(CreateFoldSpecConstantOpAndCompositePass());
392   } else if (pass_name == "loop-unswitch") {
393     RegisterPass(CreateLoopUnswitchPass());
394   } else if (pass_name == "scalar-replacement") {
395     if (pass_args.size() == 0) {
396       RegisterPass(CreateScalarReplacementPass());
397     } else {
398       int limit = -1;
399       if (pass_args.find_first_not_of("0123456789") == std::string::npos) {
400         limit = atoi(pass_args.c_str());
401       }
402 
403       if (limit >= 0) {
404         RegisterPass(CreateScalarReplacementPass(limit));
405       } else {
406         Error(consumer(), nullptr, {},
407               "--scalar-replacement must have no arguments or a non-negative "
408               "integer argument");
409         return false;
410       }
411     }
412   } else if (pass_name == "strength-reduction") {
413     RegisterPass(CreateStrengthReductionPass());
414   } else if (pass_name == "unify-const") {
415     RegisterPass(CreateUnifyConstantPass());
416   } else if (pass_name == "flatten-decorations") {
417     RegisterPass(CreateFlattenDecorationPass());
418   } else if (pass_name == "compact-ids") {
419     RegisterPass(CreateCompactIdsPass());
420   } else if (pass_name == "cfg-cleanup") {
421     RegisterPass(CreateCFGCleanupPass());
422   } else if (pass_name == "local-redundancy-elimination") {
423     RegisterPass(CreateLocalRedundancyEliminationPass());
424   } else if (pass_name == "loop-invariant-code-motion") {
425     RegisterPass(CreateLoopInvariantCodeMotionPass());
426   } else if (pass_name == "reduce-load-size") {
427     if (pass_args.size() == 0) {
428       RegisterPass(CreateReduceLoadSizePass());
429     } else {
430       double load_replacement_threshold = 0.9;
431       if (pass_args.find_first_not_of(".0123456789") == std::string::npos) {
432         load_replacement_threshold = atof(pass_args.c_str());
433       }
434 
435       if (load_replacement_threshold >= 0) {
436         RegisterPass(CreateReduceLoadSizePass(load_replacement_threshold));
437       } else {
438         Error(consumer(), nullptr, {},
439               "--reduce-load-size must have no arguments or a non-negative "
440               "double argument");
441         return false;
442       }
443     }
444   } else if (pass_name == "redundancy-elimination") {
445     RegisterPass(CreateRedundancyEliminationPass());
446   } else if (pass_name == "private-to-local") {
447     RegisterPass(CreatePrivateToLocalPass());
448   } else if (pass_name == "remove-duplicates") {
449     RegisterPass(CreateRemoveDuplicatesPass());
450   } else if (pass_name == "workaround-1209") {
451     RegisterPass(CreateWorkaround1209Pass());
452   } else if (pass_name == "replace-invalid-opcode") {
453     RegisterPass(CreateReplaceInvalidOpcodePass());
454   } else if (pass_name == "inst-bindless-check" ||
455              pass_name == "inst-desc-idx-check" ||
456              pass_name == "inst-buff-oob-check") {
457     // preserve legacy names
458     RegisterPass(CreateInstBindlessCheckPass(23));
459     RegisterPass(CreateSimplificationPass());
460     RegisterPass(CreateDeadBranchElimPass());
461     RegisterPass(CreateBlockMergePass());
462   } else if (pass_name == "inst-buff-addr-check") {
463     RegisterPass(CreateInstBuffAddrCheckPass(23));
464   } else if (pass_name == "convert-relaxed-to-half") {
465     RegisterPass(CreateConvertRelaxedToHalfPass());
466   } else if (pass_name == "relax-float-ops") {
467     RegisterPass(CreateRelaxFloatOpsPass());
468   } else if (pass_name == "inst-debug-printf") {
469     // This private option is not for user consumption.
470     // It is here to assist in debugging and fixing the debug printf
471     // instrumentation pass.
472     // For users who wish to utilize debug printf, see the white paper at
473     // https://www.lunarg.com/wp-content/uploads/2021/08/Using-Debug-Printf-02August2021.pdf
474     RegisterPass(CreateInstDebugPrintfPass(7, 23));
475   } else if (pass_name == "simplify-instructions") {
476     RegisterPass(CreateSimplificationPass());
477   } else if (pass_name == "ssa-rewrite") {
478     RegisterPass(CreateSSARewritePass());
479   } else if (pass_name == "copy-propagate-arrays") {
480     RegisterPass(CreateCopyPropagateArraysPass());
481   } else if (pass_name == "loop-fission") {
482     int register_threshold_to_split =
483         (pass_args.size() > 0) ? atoi(pass_args.c_str()) : -1;
484     if (register_threshold_to_split > 0) {
485       RegisterPass(CreateLoopFissionPass(
486           static_cast<size_t>(register_threshold_to_split)));
487     } else {
488       Error(consumer(), nullptr, {},
489             "--loop-fission must have a positive integer argument");
490       return false;
491     }
492   } else if (pass_name == "loop-fusion") {
493     int max_registers_per_loop =
494         (pass_args.size() > 0) ? atoi(pass_args.c_str()) : -1;
495     if (max_registers_per_loop > 0) {
496       RegisterPass(
497           CreateLoopFusionPass(static_cast<size_t>(max_registers_per_loop)));
498     } else {
499       Error(consumer(), nullptr, {},
500             "--loop-fusion must have a positive integer argument");
501       return false;
502     }
503   } else if (pass_name == "loop-unroll") {
504     RegisterPass(CreateLoopUnrollPass(true));
505   } else if (pass_name == "upgrade-memory-model") {
506     RegisterPass(CreateUpgradeMemoryModelPass());
507   } else if (pass_name == "vector-dce") {
508     RegisterPass(CreateVectorDCEPass());
509   } else if (pass_name == "loop-unroll-partial") {
510     int factor = (pass_args.size() > 0) ? atoi(pass_args.c_str()) : 0;
511     if (factor > 0) {
512       RegisterPass(CreateLoopUnrollPass(false, factor));
513     } else {
514       Error(consumer(), nullptr, {},
515             "--loop-unroll-partial must have a positive integer argument");
516       return false;
517     }
518   } else if (pass_name == "loop-peeling") {
519     RegisterPass(CreateLoopPeelingPass());
520   } else if (pass_name == "loop-peeling-threshold") {
521     int factor = (pass_args.size() > 0) ? atoi(pass_args.c_str()) : 0;
522     if (factor > 0) {
523       opt::LoopPeelingPass::SetLoopPeelingThreshold(factor);
524     } else {
525       Error(consumer(), nullptr, {},
526             "--loop-peeling-threshold must have a positive integer argument");
527       return false;
528     }
529   } else if (pass_name == "ccp") {
530     RegisterPass(CreateCCPPass());
531   } else if (pass_name == "code-sink") {
532     RegisterPass(CreateCodeSinkingPass());
533   } else if (pass_name == "fix-storage-class") {
534     RegisterPass(CreateFixStorageClassPass());
535   } else if (pass_name == "O") {
536     RegisterPerformancePasses(preserve_interface);
537   } else if (pass_name == "Os") {
538     RegisterSizePasses(preserve_interface);
539   } else if (pass_name == "legalize-hlsl") {
540     RegisterLegalizationPasses(preserve_interface);
541   } else if (pass_name == "remove-unused-interface-variables") {
542     RegisterPass(CreateRemoveUnusedInterfaceVariablesPass());
543   } else if (pass_name == "graphics-robust-access") {
544     RegisterPass(CreateGraphicsRobustAccessPass());
545   } else if (pass_name == "wrap-opkill") {
546     RegisterPass(CreateWrapOpKillPass());
547   } else if (pass_name == "amd-ext-to-khr") {
548     RegisterPass(CreateAmdExtToKhrPass());
549   } else if (pass_name == "interpolate-fixup") {
550     RegisterPass(CreateInterpolateFixupPass());
551   } else if (pass_name == "remove-dont-inline") {
552     RegisterPass(CreateRemoveDontInlinePass());
553   } else if (pass_name == "eliminate-dead-input-components") {
554     RegisterPass(CreateEliminateDeadInputComponentsSafePass());
555   } else if (pass_name == "fix-func-call-param") {
556     RegisterPass(CreateFixFuncCallArgumentsPass());
557   } else if (pass_name == "convert-to-sampled-image") {
558     if (pass_args.size() > 0) {
559       auto descriptor_set_binding_pairs =
560           opt::ConvertToSampledImagePass::ParseDescriptorSetBindingPairsString(
561               pass_args.c_str());
562       if (!descriptor_set_binding_pairs) {
563         Errorf(consumer(), nullptr, {},
564                "Invalid argument for --convert-to-sampled-image: %s",
565                pass_args.c_str());
566         return false;
567       }
568       RegisterPass(CreateConvertToSampledImagePass(
569           std::move(*descriptor_set_binding_pairs)));
570     } else {
571       Errorf(consumer(), nullptr, {},
572              "Invalid pairs of descriptor set and binding '%s'. Expected a "
573              "string of <descriptor set>:<binding> pairs.",
574              pass_args.c_str());
575       return false;
576     }
577   } else if (pass_name == "switch-descriptorset") {
578     if (pass_args.size() == 0) {
579       Error(consumer(), nullptr, {},
580             "--switch-descriptorset requires a from:to argument.");
581       return false;
582     }
583     uint32_t from_set = 0, to_set = 0;
584     const char* start = pass_args.data();
585     const char* end = pass_args.data() + pass_args.size();
586 
587     auto result = std::from_chars(start, end, from_set);
588     if (result.ec != std::errc()) {
589       Errorf(consumer(), nullptr, {},
590              "Invalid argument for --switch-descriptorset: %s",
591              pass_args.c_str());
592       return false;
593     }
594     start = result.ptr;
595     if (start[0] != ':') {
596       Errorf(consumer(), nullptr, {},
597              "Invalid argument for --switch-descriptorset: %s",
598              pass_args.c_str());
599       return false;
600     }
601     start++;
602     result = std::from_chars(start, end, to_set);
603     if (result.ec != std::errc() || result.ptr != end) {
604       Errorf(consumer(), nullptr, {},
605              "Invalid argument for --switch-descriptorset: %s",
606              pass_args.c_str());
607       return false;
608     }
609     RegisterPass(CreateSwitchDescriptorSetPass(from_set, to_set));
610   } else if (pass_name == "modify-maximal-reconvergence") {
611     if (pass_args.size() == 0) {
612       Error(consumer(), nullptr, {},
613             "--modify-maximal-reconvergence requires an argument");
614       return false;
615     }
616     if (pass_args == "add") {
617       RegisterPass(CreateModifyMaximalReconvergencePass(true));
618     } else if (pass_args == "remove") {
619       RegisterPass(CreateModifyMaximalReconvergencePass(false));
620     } else {
621       Errorf(consumer(), nullptr, {},
622              "Invalid argument for --modify-maximal-reconvergence: %s (must be "
623              "'add' or 'remove')",
624              pass_args.c_str());
625       return false;
626     }
627   } else if (pass_name == "trim-capabilities") {
628     RegisterPass(CreateTrimCapabilitiesPass());
629   } else {
630     Errorf(consumer(), nullptr, {},
631            "Unknown flag '--%s'. Use --help for a list of valid flags",
632            pass_name.c_str());
633     return false;
634   }
635 
636   return true;
637 }
638 
SetTargetEnv(const spv_target_env env)639 void Optimizer::SetTargetEnv(const spv_target_env env) {
640   impl_->target_env = env;
641 }
642 
Run(const uint32_t * original_binary,const size_t original_binary_size,std::vector<uint32_t> * optimized_binary) const643 bool Optimizer::Run(const uint32_t* original_binary,
644                     const size_t original_binary_size,
645                     std::vector<uint32_t>* optimized_binary) const {
646   return Run(original_binary, original_binary_size, optimized_binary,
647              OptimizerOptions());
648 }
649 
Run(const uint32_t * original_binary,const size_t original_binary_size,std::vector<uint32_t> * optimized_binary,const ValidatorOptions & validator_options,bool skip_validation) const650 bool Optimizer::Run(const uint32_t* original_binary,
651                     const size_t original_binary_size,
652                     std::vector<uint32_t>* optimized_binary,
653                     const ValidatorOptions& validator_options,
654                     bool skip_validation) const {
655   OptimizerOptions opt_options;
656   opt_options.set_run_validator(!skip_validation);
657   opt_options.set_validator_options(validator_options);
658   return Run(original_binary, original_binary_size, optimized_binary,
659              opt_options);
660 }
661 
Run(const uint32_t * original_binary,const size_t original_binary_size,std::vector<uint32_t> * optimized_binary,const spv_optimizer_options opt_options) const662 bool Optimizer::Run(const uint32_t* original_binary,
663                     const size_t original_binary_size,
664                     std::vector<uint32_t>* optimized_binary,
665                     const spv_optimizer_options opt_options) const {
666   spvtools::SpirvTools tools(impl_->target_env);
667   tools.SetMessageConsumer(impl_->pass_manager.consumer());
668   if (opt_options->run_validator_ &&
669       !tools.Validate(original_binary, original_binary_size,
670                       &opt_options->val_options_)) {
671     return false;
672   }
673 
674   std::unique_ptr<opt::IRContext> context = BuildModule(
675       impl_->target_env, consumer(), original_binary, original_binary_size);
676   if (context == nullptr) return false;
677 
678   context->set_max_id_bound(opt_options->max_id_bound_);
679   context->set_preserve_bindings(opt_options->preserve_bindings_);
680   context->set_preserve_spec_constants(opt_options->preserve_spec_constants_);
681 
682   impl_->pass_manager.SetValidatorOptions(&opt_options->val_options_);
683   impl_->pass_manager.SetTargetEnv(impl_->target_env);
684   auto status = impl_->pass_manager.Run(context.get());
685 
686   if (status == opt::Pass::Status::Failure) {
687     return false;
688   }
689 
690 #ifndef NDEBUG
691   // We do not keep the result id of DebugScope in struct DebugScope.
692   // Instead, we assign random ids for them, which results in integrity
693   // check failures. In addition, propagating the OpLine/OpNoLine to preserve
694   // the debug information through transformations results in integrity
695   // check failures. We want to skip the integrity check when the module
696   // contains DebugScope or OpLine/OpNoLine instructions.
697   if (status == opt::Pass::Status::SuccessWithoutChange &&
698       !context->module()->ContainsDebugInfo()) {
699     std::vector<uint32_t> optimized_binary_with_nop;
700     context->module()->ToBinary(&optimized_binary_with_nop,
701                                 /* skip_nop = */ false);
702     assert(optimized_binary_with_nop.size() == original_binary_size &&
703            "Binary size unexpectedly changed despite the optimizer saying "
704            "there was no change");
705 
706     // Compare the magic number to make sure the binaries were encoded in the
707     // endianness.  If not, the contents of the binaries will be different, so
708     // do not check the contents.
709     if (optimized_binary_with_nop[0] == original_binary[0]) {
710       assert(memcmp(optimized_binary_with_nop.data(), original_binary,
711                     original_binary_size) == 0 &&
712              "Binary content unexpectedly changed despite the optimizer saying "
713              "there was no change");
714     }
715   }
716 #endif  // !NDEBUG
717 
718   // Note that |original_binary| and |optimized_binary| may share the same
719   // buffer and the below will invalidate |original_binary|.
720   optimized_binary->clear();
721   context->module()->ToBinary(optimized_binary, /* skip_nop = */ true);
722 
723   return true;
724 }
725 
SetPrintAll(std::ostream * out)726 Optimizer& Optimizer::SetPrintAll(std::ostream* out) {
727   impl_->pass_manager.SetPrintAll(out);
728   return *this;
729 }
730 
SetTimeReport(std::ostream * out)731 Optimizer& Optimizer::SetTimeReport(std::ostream* out) {
732   impl_->pass_manager.SetTimeReport(out);
733   return *this;
734 }
735 
SetValidateAfterAll(bool validate)736 Optimizer& Optimizer::SetValidateAfterAll(bool validate) {
737   impl_->pass_manager.SetValidateAfterAll(validate);
738   return *this;
739 }
740 
CreateNullPass()741 Optimizer::PassToken CreateNullPass() {
742   return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::NullPass>());
743 }
744 
CreateStripDebugInfoPass()745 Optimizer::PassToken CreateStripDebugInfoPass() {
746   return MakeUnique<Optimizer::PassToken::Impl>(
747       MakeUnique<opt::StripDebugInfoPass>());
748 }
749 
CreateStripReflectInfoPass()750 Optimizer::PassToken CreateStripReflectInfoPass() {
751   return CreateStripNonSemanticInfoPass();
752 }
753 
CreateStripNonSemanticInfoPass()754 Optimizer::PassToken CreateStripNonSemanticInfoPass() {
755   return MakeUnique<Optimizer::PassToken::Impl>(
756       MakeUnique<opt::StripNonSemanticInfoPass>());
757 }
758 
CreateEliminateDeadFunctionsPass()759 Optimizer::PassToken CreateEliminateDeadFunctionsPass() {
760   return MakeUnique<Optimizer::PassToken::Impl>(
761       MakeUnique<opt::EliminateDeadFunctionsPass>());
762 }
763 
CreateEliminateDeadMembersPass()764 Optimizer::PassToken CreateEliminateDeadMembersPass() {
765   return MakeUnique<Optimizer::PassToken::Impl>(
766       MakeUnique<opt::EliminateDeadMembersPass>());
767 }
768 
CreateSetSpecConstantDefaultValuePass(const std::unordered_map<uint32_t,std::string> & id_value_map)769 Optimizer::PassToken CreateSetSpecConstantDefaultValuePass(
770     const std::unordered_map<uint32_t, std::string>& id_value_map) {
771   return MakeUnique<Optimizer::PassToken::Impl>(
772       MakeUnique<opt::SetSpecConstantDefaultValuePass>(id_value_map));
773 }
774 
CreateSetSpecConstantDefaultValuePass(const std::unordered_map<uint32_t,std::vector<uint32_t>> & id_value_map)775 Optimizer::PassToken CreateSetSpecConstantDefaultValuePass(
776     const std::unordered_map<uint32_t, std::vector<uint32_t>>& id_value_map) {
777   return MakeUnique<Optimizer::PassToken::Impl>(
778       MakeUnique<opt::SetSpecConstantDefaultValuePass>(id_value_map));
779 }
780 
CreateFlattenDecorationPass()781 Optimizer::PassToken CreateFlattenDecorationPass() {
782   return MakeUnique<Optimizer::PassToken::Impl>(
783       MakeUnique<opt::FlattenDecorationPass>());
784 }
785 
CreateFreezeSpecConstantValuePass()786 Optimizer::PassToken CreateFreezeSpecConstantValuePass() {
787   return MakeUnique<Optimizer::PassToken::Impl>(
788       MakeUnique<opt::FreezeSpecConstantValuePass>());
789 }
790 
CreateFoldSpecConstantOpAndCompositePass()791 Optimizer::PassToken CreateFoldSpecConstantOpAndCompositePass() {
792   return MakeUnique<Optimizer::PassToken::Impl>(
793       MakeUnique<opt::FoldSpecConstantOpAndCompositePass>());
794 }
795 
CreateUnifyConstantPass()796 Optimizer::PassToken CreateUnifyConstantPass() {
797   return MakeUnique<Optimizer::PassToken::Impl>(
798       MakeUnique<opt::UnifyConstantPass>());
799 }
800 
CreateEliminateDeadConstantPass()801 Optimizer::PassToken CreateEliminateDeadConstantPass() {
802   return MakeUnique<Optimizer::PassToken::Impl>(
803       MakeUnique<opt::EliminateDeadConstantPass>());
804 }
805 
CreateDeadVariableEliminationPass()806 Optimizer::PassToken CreateDeadVariableEliminationPass() {
807   return MakeUnique<Optimizer::PassToken::Impl>(
808       MakeUnique<opt::DeadVariableElimination>());
809 }
810 
CreateStrengthReductionPass()811 Optimizer::PassToken CreateStrengthReductionPass() {
812   return MakeUnique<Optimizer::PassToken::Impl>(
813       MakeUnique<opt::StrengthReductionPass>());
814 }
815 
CreateBlockMergePass()816 Optimizer::PassToken CreateBlockMergePass() {
817   return MakeUnique<Optimizer::PassToken::Impl>(
818       MakeUnique<opt::BlockMergePass>());
819 }
820 
CreateInlineExhaustivePass()821 Optimizer::PassToken CreateInlineExhaustivePass() {
822   return MakeUnique<Optimizer::PassToken::Impl>(
823       MakeUnique<opt::InlineExhaustivePass>());
824 }
825 
CreateInlineOpaquePass()826 Optimizer::PassToken CreateInlineOpaquePass() {
827   return MakeUnique<Optimizer::PassToken::Impl>(
828       MakeUnique<opt::InlineOpaquePass>());
829 }
830 
CreateLocalAccessChainConvertPass()831 Optimizer::PassToken CreateLocalAccessChainConvertPass() {
832   return MakeUnique<Optimizer::PassToken::Impl>(
833       MakeUnique<opt::LocalAccessChainConvertPass>());
834 }
835 
CreateLocalSingleBlockLoadStoreElimPass()836 Optimizer::PassToken CreateLocalSingleBlockLoadStoreElimPass() {
837   return MakeUnique<Optimizer::PassToken::Impl>(
838       MakeUnique<opt::LocalSingleBlockLoadStoreElimPass>());
839 }
840 
CreateLocalSingleStoreElimPass()841 Optimizer::PassToken CreateLocalSingleStoreElimPass() {
842   return MakeUnique<Optimizer::PassToken::Impl>(
843       MakeUnique<opt::LocalSingleStoreElimPass>());
844 }
845 
CreateInsertExtractElimPass()846 Optimizer::PassToken CreateInsertExtractElimPass() {
847   return MakeUnique<Optimizer::PassToken::Impl>(
848       MakeUnique<opt::SimplificationPass>());
849 }
850 
CreateDeadInsertElimPass()851 Optimizer::PassToken CreateDeadInsertElimPass() {
852   return MakeUnique<Optimizer::PassToken::Impl>(
853       MakeUnique<opt::DeadInsertElimPass>());
854 }
855 
CreateDeadBranchElimPass()856 Optimizer::PassToken CreateDeadBranchElimPass() {
857   return MakeUnique<Optimizer::PassToken::Impl>(
858       MakeUnique<opt::DeadBranchElimPass>());
859 }
860 
CreateLocalMultiStoreElimPass()861 Optimizer::PassToken CreateLocalMultiStoreElimPass() {
862   return MakeUnique<Optimizer::PassToken::Impl>(
863       MakeUnique<opt::SSARewritePass>());
864 }
865 
CreateAggressiveDCEPass()866 Optimizer::PassToken CreateAggressiveDCEPass() {
867   return MakeUnique<Optimizer::PassToken::Impl>(
868       MakeUnique<opt::AggressiveDCEPass>(false, false));
869 }
870 
CreateAggressiveDCEPass(bool preserve_interface)871 Optimizer::PassToken CreateAggressiveDCEPass(bool preserve_interface) {
872   return MakeUnique<Optimizer::PassToken::Impl>(
873       MakeUnique<opt::AggressiveDCEPass>(preserve_interface, false));
874 }
875 
CreateAggressiveDCEPass(bool preserve_interface,bool remove_outputs)876 Optimizer::PassToken CreateAggressiveDCEPass(bool preserve_interface,
877                                              bool remove_outputs) {
878   return MakeUnique<Optimizer::PassToken::Impl>(
879       MakeUnique<opt::AggressiveDCEPass>(preserve_interface, remove_outputs));
880 }
881 
CreateRemoveUnusedInterfaceVariablesPass()882 Optimizer::PassToken CreateRemoveUnusedInterfaceVariablesPass() {
883   return MakeUnique<Optimizer::PassToken::Impl>(
884       MakeUnique<opt::RemoveUnusedInterfaceVariablesPass>());
885 }
886 
CreatePropagateLineInfoPass()887 Optimizer::PassToken CreatePropagateLineInfoPass() {
888   return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::EmptyPass>());
889 }
890 
CreateRedundantLineInfoElimPass()891 Optimizer::PassToken CreateRedundantLineInfoElimPass() {
892   return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::EmptyPass>());
893 }
894 
CreateCompactIdsPass()895 Optimizer::PassToken CreateCompactIdsPass() {
896   return MakeUnique<Optimizer::PassToken::Impl>(
897       MakeUnique<opt::CompactIdsPass>());
898 }
899 
CreateMergeReturnPass()900 Optimizer::PassToken CreateMergeReturnPass() {
901   return MakeUnique<Optimizer::PassToken::Impl>(
902       MakeUnique<opt::MergeReturnPass>());
903 }
904 
GetPassNames() const905 std::vector<const char*> Optimizer::GetPassNames() const {
906   std::vector<const char*> v;
907   for (uint32_t i = 0; i < impl_->pass_manager.NumPasses(); i++) {
908     v.push_back(impl_->pass_manager.GetPass(i)->name());
909   }
910   return v;
911 }
912 
CreateCFGCleanupPass()913 Optimizer::PassToken CreateCFGCleanupPass() {
914   return MakeUnique<Optimizer::PassToken::Impl>(
915       MakeUnique<opt::CFGCleanupPass>());
916 }
917 
CreateLocalRedundancyEliminationPass()918 Optimizer::PassToken CreateLocalRedundancyEliminationPass() {
919   return MakeUnique<Optimizer::PassToken::Impl>(
920       MakeUnique<opt::LocalRedundancyEliminationPass>());
921 }
922 
CreateLoopFissionPass(size_t threshold)923 Optimizer::PassToken CreateLoopFissionPass(size_t threshold) {
924   return MakeUnique<Optimizer::PassToken::Impl>(
925       MakeUnique<opt::LoopFissionPass>(threshold));
926 }
927 
CreateLoopFusionPass(size_t max_registers_per_loop)928 Optimizer::PassToken CreateLoopFusionPass(size_t max_registers_per_loop) {
929   return MakeUnique<Optimizer::PassToken::Impl>(
930       MakeUnique<opt::LoopFusionPass>(max_registers_per_loop));
931 }
932 
CreateLoopInvariantCodeMotionPass()933 Optimizer::PassToken CreateLoopInvariantCodeMotionPass() {
934   return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::LICMPass>());
935 }
936 
CreateLoopPeelingPass()937 Optimizer::PassToken CreateLoopPeelingPass() {
938   return MakeUnique<Optimizer::PassToken::Impl>(
939       MakeUnique<opt::LoopPeelingPass>());
940 }
941 
CreateLoopUnswitchPass()942 Optimizer::PassToken CreateLoopUnswitchPass() {
943   return MakeUnique<Optimizer::PassToken::Impl>(
944       MakeUnique<opt::LoopUnswitchPass>());
945 }
946 
CreateRedundancyEliminationPass()947 Optimizer::PassToken CreateRedundancyEliminationPass() {
948   return MakeUnique<Optimizer::PassToken::Impl>(
949       MakeUnique<opt::RedundancyEliminationPass>());
950 }
951 
CreateRemoveDuplicatesPass()952 Optimizer::PassToken CreateRemoveDuplicatesPass() {
953   return MakeUnique<Optimizer::PassToken::Impl>(
954       MakeUnique<opt::RemoveDuplicatesPass>());
955 }
956 
CreateScalarReplacementPass(uint32_t size_limit)957 Optimizer::PassToken CreateScalarReplacementPass(uint32_t size_limit) {
958   return MakeUnique<Optimizer::PassToken::Impl>(
959       MakeUnique<opt::ScalarReplacementPass>(size_limit));
960 }
961 
CreatePrivateToLocalPass()962 Optimizer::PassToken CreatePrivateToLocalPass() {
963   return MakeUnique<Optimizer::PassToken::Impl>(
964       MakeUnique<opt::PrivateToLocalPass>());
965 }
966 
CreateCCPPass()967 Optimizer::PassToken CreateCCPPass() {
968   return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::CCPPass>());
969 }
970 
CreateWorkaround1209Pass()971 Optimizer::PassToken CreateWorkaround1209Pass() {
972   return MakeUnique<Optimizer::PassToken::Impl>(
973       MakeUnique<opt::Workaround1209>());
974 }
975 
CreateIfConversionPass()976 Optimizer::PassToken CreateIfConversionPass() {
977   return MakeUnique<Optimizer::PassToken::Impl>(
978       MakeUnique<opt::IfConversion>());
979 }
980 
CreateReplaceInvalidOpcodePass()981 Optimizer::PassToken CreateReplaceInvalidOpcodePass() {
982   return MakeUnique<Optimizer::PassToken::Impl>(
983       MakeUnique<opt::ReplaceInvalidOpcodePass>());
984 }
985 
CreateSimplificationPass()986 Optimizer::PassToken CreateSimplificationPass() {
987   return MakeUnique<Optimizer::PassToken::Impl>(
988       MakeUnique<opt::SimplificationPass>());
989 }
990 
CreateLoopUnrollPass(bool fully_unroll,int factor)991 Optimizer::PassToken CreateLoopUnrollPass(bool fully_unroll, int factor) {
992   return MakeUnique<Optimizer::PassToken::Impl>(
993       MakeUnique<opt::LoopUnroller>(fully_unroll, factor));
994 }
995 
CreateSSARewritePass()996 Optimizer::PassToken CreateSSARewritePass() {
997   return MakeUnique<Optimizer::PassToken::Impl>(
998       MakeUnique<opt::SSARewritePass>());
999 }
1000 
CreateCopyPropagateArraysPass()1001 Optimizer::PassToken CreateCopyPropagateArraysPass() {
1002   return MakeUnique<Optimizer::PassToken::Impl>(
1003       MakeUnique<opt::CopyPropagateArrays>());
1004 }
1005 
CreateVectorDCEPass()1006 Optimizer::PassToken CreateVectorDCEPass() {
1007   return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::VectorDCE>());
1008 }
1009 
CreateReduceLoadSizePass(double load_replacement_threshold)1010 Optimizer::PassToken CreateReduceLoadSizePass(
1011     double load_replacement_threshold) {
1012   return MakeUnique<Optimizer::PassToken::Impl>(
1013       MakeUnique<opt::ReduceLoadSize>(load_replacement_threshold));
1014 }
1015 
CreateCombineAccessChainsPass()1016 Optimizer::PassToken CreateCombineAccessChainsPass() {
1017   return MakeUnique<Optimizer::PassToken::Impl>(
1018       MakeUnique<opt::CombineAccessChains>());
1019 }
1020 
CreateUpgradeMemoryModelPass()1021 Optimizer::PassToken CreateUpgradeMemoryModelPass() {
1022   return MakeUnique<Optimizer::PassToken::Impl>(
1023       MakeUnique<opt::UpgradeMemoryModel>());
1024 }
1025 
CreateInstBindlessCheckPass(uint32_t shader_id)1026 Optimizer::PassToken CreateInstBindlessCheckPass(uint32_t shader_id) {
1027   return MakeUnique<Optimizer::PassToken::Impl>(
1028       MakeUnique<opt::InstBindlessCheckPass>(shader_id));
1029 }
1030 
CreateInstDebugPrintfPass(uint32_t desc_set,uint32_t shader_id)1031 Optimizer::PassToken CreateInstDebugPrintfPass(uint32_t desc_set,
1032                                                uint32_t shader_id) {
1033   return MakeUnique<Optimizer::PassToken::Impl>(
1034       MakeUnique<opt::InstDebugPrintfPass>(desc_set, shader_id));
1035 }
1036 
CreateInstBuffAddrCheckPass(uint32_t shader_id)1037 Optimizer::PassToken CreateInstBuffAddrCheckPass(uint32_t shader_id) {
1038   return MakeUnique<Optimizer::PassToken::Impl>(
1039       MakeUnique<opt::InstBuffAddrCheckPass>(shader_id));
1040 }
1041 
CreateConvertRelaxedToHalfPass()1042 Optimizer::PassToken CreateConvertRelaxedToHalfPass() {
1043   return MakeUnique<Optimizer::PassToken::Impl>(
1044       MakeUnique<opt::ConvertToHalfPass>());
1045 }
1046 
CreateRelaxFloatOpsPass()1047 Optimizer::PassToken CreateRelaxFloatOpsPass() {
1048   return MakeUnique<Optimizer::PassToken::Impl>(
1049       MakeUnique<opt::RelaxFloatOpsPass>());
1050 }
1051 
CreateCodeSinkingPass()1052 Optimizer::PassToken CreateCodeSinkingPass() {
1053   return MakeUnique<Optimizer::PassToken::Impl>(
1054       MakeUnique<opt::CodeSinkingPass>());
1055 }
1056 
CreateFixStorageClassPass()1057 Optimizer::PassToken CreateFixStorageClassPass() {
1058   return MakeUnique<Optimizer::PassToken::Impl>(
1059       MakeUnique<opt::FixStorageClass>());
1060 }
1061 
CreateGraphicsRobustAccessPass()1062 Optimizer::PassToken CreateGraphicsRobustAccessPass() {
1063   return MakeUnique<Optimizer::PassToken::Impl>(
1064       MakeUnique<opt::GraphicsRobustAccessPass>());
1065 }
1066 
CreateReplaceDescArrayAccessUsingVarIndexPass()1067 Optimizer::PassToken CreateReplaceDescArrayAccessUsingVarIndexPass() {
1068   return MakeUnique<Optimizer::PassToken::Impl>(
1069       MakeUnique<opt::ReplaceDescArrayAccessUsingVarIndex>());
1070 }
1071 
CreateSpreadVolatileSemanticsPass()1072 Optimizer::PassToken CreateSpreadVolatileSemanticsPass() {
1073   return MakeUnique<Optimizer::PassToken::Impl>(
1074       MakeUnique<opt::SpreadVolatileSemantics>());
1075 }
1076 
CreateDescriptorScalarReplacementPass()1077 Optimizer::PassToken CreateDescriptorScalarReplacementPass() {
1078   return MakeUnique<Optimizer::PassToken::Impl>(
1079       MakeUnique<opt::DescriptorScalarReplacement>());
1080 }
1081 
CreateWrapOpKillPass()1082 Optimizer::PassToken CreateWrapOpKillPass() {
1083   return MakeUnique<Optimizer::PassToken::Impl>(MakeUnique<opt::WrapOpKill>());
1084 }
1085 
CreateAmdExtToKhrPass()1086 Optimizer::PassToken CreateAmdExtToKhrPass() {
1087   return MakeUnique<Optimizer::PassToken::Impl>(
1088       MakeUnique<opt::AmdExtensionToKhrPass>());
1089 }
1090 
CreateInterpolateFixupPass()1091 Optimizer::PassToken CreateInterpolateFixupPass() {
1092   return MakeUnique<Optimizer::PassToken::Impl>(
1093       MakeUnique<opt::InterpFixupPass>());
1094 }
1095 
CreateEliminateDeadInputComponentsPass()1096 Optimizer::PassToken CreateEliminateDeadInputComponentsPass() {
1097   return MakeUnique<Optimizer::PassToken::Impl>(
1098       MakeUnique<opt::EliminateDeadIOComponentsPass>(spv::StorageClass::Input,
1099                                                      /* safe_mode */ false));
1100 }
1101 
CreateEliminateDeadOutputComponentsPass()1102 Optimizer::PassToken CreateEliminateDeadOutputComponentsPass() {
1103   return MakeUnique<Optimizer::PassToken::Impl>(
1104       MakeUnique<opt::EliminateDeadIOComponentsPass>(spv::StorageClass::Output,
1105                                                      /* safe_mode */ false));
1106 }
1107 
CreateEliminateDeadInputComponentsSafePass()1108 Optimizer::PassToken CreateEliminateDeadInputComponentsSafePass() {
1109   return MakeUnique<Optimizer::PassToken::Impl>(
1110       MakeUnique<opt::EliminateDeadIOComponentsPass>(spv::StorageClass::Input,
1111                                                      /* safe_mode */ true));
1112 }
1113 
CreateAnalyzeLiveInputPass(std::unordered_set<uint32_t> * live_locs,std::unordered_set<uint32_t> * live_builtins)1114 Optimizer::PassToken CreateAnalyzeLiveInputPass(
1115     std::unordered_set<uint32_t>* live_locs,
1116     std::unordered_set<uint32_t>* live_builtins) {
1117   return MakeUnique<Optimizer::PassToken::Impl>(
1118       MakeUnique<opt::AnalyzeLiveInputPass>(live_locs, live_builtins));
1119 }
1120 
CreateEliminateDeadOutputStoresPass(std::unordered_set<uint32_t> * live_locs,std::unordered_set<uint32_t> * live_builtins)1121 Optimizer::PassToken CreateEliminateDeadOutputStoresPass(
1122     std::unordered_set<uint32_t>* live_locs,
1123     std::unordered_set<uint32_t>* live_builtins) {
1124   return MakeUnique<Optimizer::PassToken::Impl>(
1125       MakeUnique<opt::EliminateDeadOutputStoresPass>(live_locs, live_builtins));
1126 }
1127 
CreateConvertToSampledImagePass(const std::vector<opt::DescriptorSetAndBinding> & descriptor_set_binding_pairs)1128 Optimizer::PassToken CreateConvertToSampledImagePass(
1129     const std::vector<opt::DescriptorSetAndBinding>&
1130         descriptor_set_binding_pairs) {
1131   return MakeUnique<Optimizer::PassToken::Impl>(
1132       MakeUnique<opt::ConvertToSampledImagePass>(descriptor_set_binding_pairs));
1133 }
1134 
CreateInterfaceVariableScalarReplacementPass()1135 Optimizer::PassToken CreateInterfaceVariableScalarReplacementPass() {
1136   return MakeUnique<Optimizer::PassToken::Impl>(
1137       MakeUnique<opt::InterfaceVariableScalarReplacement>());
1138 }
1139 
CreateRemoveDontInlinePass()1140 Optimizer::PassToken CreateRemoveDontInlinePass() {
1141   return MakeUnique<Optimizer::PassToken::Impl>(
1142       MakeUnique<opt::RemoveDontInline>());
1143 }
1144 
CreateFixFuncCallArgumentsPass()1145 Optimizer::PassToken CreateFixFuncCallArgumentsPass() {
1146   return MakeUnique<Optimizer::PassToken::Impl>(
1147       MakeUnique<opt::FixFuncCallArgumentsPass>());
1148 }
1149 
CreateTrimCapabilitiesPass()1150 Optimizer::PassToken CreateTrimCapabilitiesPass() {
1151   return MakeUnique<Optimizer::PassToken::Impl>(
1152       MakeUnique<opt::TrimCapabilitiesPass>());
1153 }
1154 
CreateSwitchDescriptorSetPass(uint32_t from,uint32_t to)1155 Optimizer::PassToken CreateSwitchDescriptorSetPass(uint32_t from, uint32_t to) {
1156   return MakeUnique<Optimizer::PassToken::Impl>(
1157       MakeUnique<opt::SwitchDescriptorSetPass>(from, to));
1158 }
1159 
CreateInvocationInterlockPlacementPass()1160 Optimizer::PassToken CreateInvocationInterlockPlacementPass() {
1161   return MakeUnique<Optimizer::PassToken::Impl>(
1162       MakeUnique<opt::InvocationInterlockPlacementPass>());
1163 }
1164 
CreateModifyMaximalReconvergencePass(bool add)1165 Optimizer::PassToken CreateModifyMaximalReconvergencePass(bool add) {
1166   return MakeUnique<Optimizer::PassToken::Impl>(
1167       MakeUnique<opt::ModifyMaximalReconvergence>(add));
1168 }
1169 }  // namespace spvtools
1170 
1171 extern "C" {
1172 
spvOptimizerCreate(spv_target_env env)1173 SPIRV_TOOLS_EXPORT spv_optimizer_t* spvOptimizerCreate(spv_target_env env) {
1174   return reinterpret_cast<spv_optimizer_t*>(new spvtools::Optimizer(env));
1175 }
1176 
spvOptimizerDestroy(spv_optimizer_t * optimizer)1177 SPIRV_TOOLS_EXPORT void spvOptimizerDestroy(spv_optimizer_t* optimizer) {
1178   delete reinterpret_cast<spvtools::Optimizer*>(optimizer);
1179 }
1180 
spvOptimizerSetMessageConsumer(spv_optimizer_t * optimizer,spv_message_consumer consumer)1181 SPIRV_TOOLS_EXPORT void spvOptimizerSetMessageConsumer(
1182     spv_optimizer_t* optimizer, spv_message_consumer consumer) {
1183   reinterpret_cast<spvtools::Optimizer*>(optimizer)->
1184       SetMessageConsumer(
1185           [consumer](spv_message_level_t level, const char* source,
1186                      const spv_position_t& position, const char* message) {
1187             return consumer(level, source, &position, message);
1188           });
1189 }
1190 
spvOptimizerRegisterLegalizationPasses(spv_optimizer_t * optimizer)1191 SPIRV_TOOLS_EXPORT void spvOptimizerRegisterLegalizationPasses(
1192     spv_optimizer_t* optimizer) {
1193   reinterpret_cast<spvtools::Optimizer*>(optimizer)->
1194       RegisterLegalizationPasses();
1195 }
1196 
spvOptimizerRegisterPerformancePasses(spv_optimizer_t * optimizer)1197 SPIRV_TOOLS_EXPORT void spvOptimizerRegisterPerformancePasses(
1198     spv_optimizer_t* optimizer) {
1199   reinterpret_cast<spvtools::Optimizer*>(optimizer)->
1200       RegisterPerformancePasses();
1201 }
1202 
spvOptimizerRegisterSizePasses(spv_optimizer_t * optimizer)1203 SPIRV_TOOLS_EXPORT void spvOptimizerRegisterSizePasses(
1204     spv_optimizer_t* optimizer) {
1205   reinterpret_cast<spvtools::Optimizer*>(optimizer)->RegisterSizePasses();
1206 }
1207 
spvOptimizerRegisterPassFromFlag(spv_optimizer_t * optimizer,const char * flag)1208 SPIRV_TOOLS_EXPORT bool spvOptimizerRegisterPassFromFlag(
1209     spv_optimizer_t* optimizer, const char* flag)
1210 {
1211   return reinterpret_cast<spvtools::Optimizer*>(optimizer)->
1212       RegisterPassFromFlag(flag);
1213 }
1214 
spvOptimizerRegisterPassesFromFlags(spv_optimizer_t * optimizer,const char ** flags,const size_t flag_count)1215 SPIRV_TOOLS_EXPORT bool spvOptimizerRegisterPassesFromFlags(
1216     spv_optimizer_t* optimizer, const char** flags, const size_t flag_count) {
1217   std::vector<std::string> opt_flags =
1218       spvtools::GetVectorOfStrings(flags, flag_count);
1219   return reinterpret_cast<spvtools::Optimizer*>(optimizer)
1220       ->RegisterPassesFromFlags(opt_flags, false);
1221 }
1222 
1223 SPIRV_TOOLS_EXPORT bool
spvOptimizerRegisterPassesFromFlagsWhilePreservingTheInterface(spv_optimizer_t * optimizer,const char ** flags,const size_t flag_count)1224 spvOptimizerRegisterPassesFromFlagsWhilePreservingTheInterface(
1225     spv_optimizer_t* optimizer, const char** flags, const size_t flag_count) {
1226   std::vector<std::string> opt_flags =
1227       spvtools::GetVectorOfStrings(flags, flag_count);
1228   return reinterpret_cast<spvtools::Optimizer*>(optimizer)
1229       ->RegisterPassesFromFlags(opt_flags, true);
1230 }
1231 
1232 SPIRV_TOOLS_EXPORT
spvOptimizerRun(spv_optimizer_t * optimizer,const uint32_t * binary,const size_t word_count,spv_binary * optimized_binary,const spv_optimizer_options options)1233 spv_result_t spvOptimizerRun(spv_optimizer_t* optimizer,
1234                              const uint32_t* binary,
1235                              const size_t word_count,
1236                              spv_binary* optimized_binary,
1237                              const spv_optimizer_options options) {
1238   std::vector<uint32_t> optimized;
1239 
1240   if (!reinterpret_cast<spvtools::Optimizer*>(optimizer)->
1241       Run(binary, word_count, &optimized, options)) {
1242     return SPV_ERROR_INTERNAL;
1243   }
1244 
1245   auto result_binary = new spv_binary_t();
1246   if (!result_binary) {
1247       *optimized_binary = nullptr;
1248       return SPV_ERROR_OUT_OF_MEMORY;
1249   }
1250 
1251   result_binary->code = new uint32_t[optimized.size()];
1252   if (!result_binary->code) {
1253       delete result_binary;
1254       *optimized_binary = nullptr;
1255       return SPV_ERROR_OUT_OF_MEMORY;
1256   }
1257   result_binary->wordCount = optimized.size();
1258 
1259   memcpy(result_binary->code, optimized.data(),
1260          optimized.size() * sizeof(uint32_t));
1261 
1262   *optimized_binary = result_binary;
1263 
1264   return SPV_SUCCESS;
1265 }
1266 
1267 }  // extern "C"
1268