xref: /aosp_15_r20/external/swiftshader/third_party/SPIRV-Tools/source/opt/scalar_analysis.cpp (revision 03ce13f70fcc45d86ee91b7ee4cab1936a95046e)
1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "source/opt/scalar_analysis.h"
16 
17 #include <functional>
18 #include <string>
19 #include <utility>
20 
21 #include "source/opt/ir_context.h"
22 
23 // Transforms a given scalar operation instruction into a DAG representation.
24 //
25 // 1. Take an instruction and traverse its operands until we reach a
26 // constant node or an instruction which we do not know how to compute the
27 // value, such as a load.
28 //
29 // 2. Create a new node for each instruction traversed and build the nodes for
30 // the in operands of that instruction as well.
31 //
32 // 3. Add the operand nodes as children of the first and hash the node. Use the
33 // hash to see if the node is already in the cache. We ensure the children are
34 // always in sorted order so that two nodes with the same children but inserted
35 // in a different order have the same hash and so that the overloaded operator==
36 // will return true. If the node is already in the cache return the cached
37 // version instead.
38 //
39 // 4. The created DAG can then be simplified by
40 // ScalarAnalysis::SimplifyExpression, implemented in
41 // scalar_analysis_simplification.cpp. See that file for further information on
42 // the simplification process.
43 //
44 
45 namespace spvtools {
46 namespace opt {
47 
48 uint32_t SENode::NumberOfNodes = 0;
49 
ScalarEvolutionAnalysis(IRContext * context)50 ScalarEvolutionAnalysis::ScalarEvolutionAnalysis(IRContext* context)
51     : context_(context), pretend_equal_{} {
52   // Create and cached the CantComputeNode.
53   cached_cant_compute_ =
54       GetCachedOrAdd(std::unique_ptr<SECantCompute>(new SECantCompute(this)));
55 }
56 
CreateNegation(SENode * operand)57 SENode* ScalarEvolutionAnalysis::CreateNegation(SENode* operand) {
58   // If operand is can't compute then the whole graph is can't compute.
59   if (operand->IsCantCompute()) return CreateCantComputeNode();
60 
61   if (operand->GetType() == SENode::Constant) {
62     return CreateConstant(-operand->AsSEConstantNode()->FoldToSingleValue());
63   }
64   std::unique_ptr<SENode> negation_node{new SENegative(this)};
65   negation_node->AddChild(operand);
66   return GetCachedOrAdd(std::move(negation_node));
67 }
68 
CreateConstant(int64_t integer)69 SENode* ScalarEvolutionAnalysis::CreateConstant(int64_t integer) {
70   return GetCachedOrAdd(
71       std::unique_ptr<SENode>(new SEConstantNode(this, integer)));
72 }
73 
CreateRecurrentExpression(const Loop * loop,SENode * offset,SENode * coefficient)74 SENode* ScalarEvolutionAnalysis::CreateRecurrentExpression(
75     const Loop* loop, SENode* offset, SENode* coefficient) {
76   assert(loop && "Recurrent add expressions must have a valid loop.");
77 
78   // If operands are can't compute then the whole graph is can't compute.
79   if (offset->IsCantCompute() || coefficient->IsCantCompute())
80     return CreateCantComputeNode();
81 
82   const Loop* loop_to_use = nullptr;
83   if (pretend_equal_[loop]) {
84     loop_to_use = pretend_equal_[loop];
85   } else {
86     loop_to_use = loop;
87   }
88 
89   std::unique_ptr<SERecurrentNode> phi_node{
90       new SERecurrentNode(this, loop_to_use)};
91   phi_node->AddOffset(offset);
92   phi_node->AddCoefficient(coefficient);
93 
94   return GetCachedOrAdd(std::move(phi_node));
95 }
96 
AnalyzeMultiplyOp(const Instruction * multiply)97 SENode* ScalarEvolutionAnalysis::AnalyzeMultiplyOp(
98     const Instruction* multiply) {
99   assert(multiply->opcode() == spv::Op::OpIMul &&
100          "Multiply node did not come from a multiply instruction");
101   analysis::DefUseManager* def_use = context_->get_def_use_mgr();
102 
103   SENode* op1 =
104       AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(0)));
105   SENode* op2 =
106       AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(1)));
107 
108   return CreateMultiplyNode(op1, op2);
109 }
110 
CreateMultiplyNode(SENode * operand_1,SENode * operand_2)111 SENode* ScalarEvolutionAnalysis::CreateMultiplyNode(SENode* operand_1,
112                                                     SENode* operand_2) {
113   // If operands are can't compute then the whole graph is can't compute.
114   if (operand_1->IsCantCompute() || operand_2->IsCantCompute())
115     return CreateCantComputeNode();
116 
117   if (operand_1->GetType() == SENode::Constant &&
118       operand_2->GetType() == SENode::Constant) {
119     return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() *
120                           operand_2->AsSEConstantNode()->FoldToSingleValue());
121   }
122 
123   std::unique_ptr<SENode> multiply_node{new SEMultiplyNode(this)};
124 
125   multiply_node->AddChild(operand_1);
126   multiply_node->AddChild(operand_2);
127 
128   return GetCachedOrAdd(std::move(multiply_node));
129 }
130 
CreateSubtraction(SENode * operand_1,SENode * operand_2)131 SENode* ScalarEvolutionAnalysis::CreateSubtraction(SENode* operand_1,
132                                                    SENode* operand_2) {
133   // Fold if both operands are constant.
134   if (operand_1->GetType() == SENode::Constant &&
135       operand_2->GetType() == SENode::Constant) {
136     return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() -
137                           operand_2->AsSEConstantNode()->FoldToSingleValue());
138   }
139 
140   return CreateAddNode(operand_1, CreateNegation(operand_2));
141 }
142 
CreateAddNode(SENode * operand_1,SENode * operand_2)143 SENode* ScalarEvolutionAnalysis::CreateAddNode(SENode* operand_1,
144                                                SENode* operand_2) {
145   // Fold if both operands are constant and the |simplify| flag is true.
146   if (operand_1->GetType() == SENode::Constant &&
147       operand_2->GetType() == SENode::Constant) {
148     return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() +
149                           operand_2->AsSEConstantNode()->FoldToSingleValue());
150   }
151 
152   // If operands are can't compute then the whole graph is can't compute.
153   if (operand_1->IsCantCompute() || operand_2->IsCantCompute())
154     return CreateCantComputeNode();
155 
156   std::unique_ptr<SENode> add_node{new SEAddNode(this)};
157 
158   add_node->AddChild(operand_1);
159   add_node->AddChild(operand_2);
160 
161   return GetCachedOrAdd(std::move(add_node));
162 }
163 
AnalyzeInstruction(const Instruction * inst)164 SENode* ScalarEvolutionAnalysis::AnalyzeInstruction(const Instruction* inst) {
165   auto itr = recurrent_node_map_.find(inst);
166   if (itr != recurrent_node_map_.end()) return itr->second;
167 
168   SENode* output = nullptr;
169   switch (inst->opcode()) {
170     case spv::Op::OpPhi: {
171       output = AnalyzePhiInstruction(inst);
172       break;
173     }
174     case spv::Op::OpConstant:
175     case spv::Op::OpConstantNull: {
176       output = AnalyzeConstant(inst);
177       break;
178     }
179     case spv::Op::OpISub:
180     case spv::Op::OpIAdd: {
181       output = AnalyzeAddOp(inst);
182       break;
183     }
184     case spv::Op::OpIMul: {
185       output = AnalyzeMultiplyOp(inst);
186       break;
187     }
188     default: {
189       output = CreateValueUnknownNode(inst);
190       break;
191     }
192   }
193 
194   return output;
195 }
196 
AnalyzeConstant(const Instruction * inst)197 SENode* ScalarEvolutionAnalysis::AnalyzeConstant(const Instruction* inst) {
198   if (inst->opcode() == spv::Op::OpConstantNull) return CreateConstant(0);
199 
200   assert(inst->opcode() == spv::Op::OpConstant);
201   assert(inst->NumInOperands() == 1);
202   int64_t value = 0;
203 
204   // Look up the instruction in the constant manager.
205   const analysis::Constant* constant =
206       context_->get_constant_mgr()->FindDeclaredConstant(inst->result_id());
207 
208   if (!constant) return CreateCantComputeNode();
209 
210   const analysis::IntConstant* int_constant = constant->AsIntConstant();
211 
212   // Exit out if it is a 64 bit integer.
213   if (!int_constant || int_constant->words().size() != 1)
214     return CreateCantComputeNode();
215 
216   if (int_constant->type()->AsInteger()->IsSigned()) {
217     value = int_constant->GetS32BitValue();
218   } else {
219     value = int_constant->GetU32BitValue();
220   }
221 
222   return CreateConstant(value);
223 }
224 
225 // Handles both addition and subtraction. If the |sub| flag is set then the
226 // addition will be op1+(-op2) otherwise op1+op2.
AnalyzeAddOp(const Instruction * inst)227 SENode* ScalarEvolutionAnalysis::AnalyzeAddOp(const Instruction* inst) {
228   assert((inst->opcode() == spv::Op::OpIAdd ||
229           inst->opcode() == spv::Op::OpISub) &&
230          "Add node must be created from a OpIAdd or OpISub instruction");
231 
232   analysis::DefUseManager* def_use = context_->get_def_use_mgr();
233 
234   SENode* op1 =
235       AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(0)));
236 
237   SENode* op2 =
238       AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(1)));
239 
240   // To handle subtraction we wrap the second operand in a unary negation node.
241   if (inst->opcode() == spv::Op::OpISub) {
242     op2 = CreateNegation(op2);
243   }
244 
245   return CreateAddNode(op1, op2);
246 }
247 
AnalyzePhiInstruction(const Instruction * phi)248 SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction(const Instruction* phi) {
249   // The phi should only have two incoming value pairs.
250   if (phi->NumInOperands() != 4) {
251     return CreateCantComputeNode();
252   }
253 
254   analysis::DefUseManager* def_use = context_->get_def_use_mgr();
255 
256   // Get the basic block this instruction belongs to.
257   BasicBlock* basic_block =
258       context_->get_instr_block(const_cast<Instruction*>(phi));
259 
260   // And then the function that the basic blocks belongs to.
261   Function* function = basic_block->GetParent();
262 
263   // Use the function to get the loop descriptor.
264   LoopDescriptor* loop_descriptor = context_->GetLoopDescriptor(function);
265 
266   // We only handle phis in loops at the moment.
267   if (!loop_descriptor) return CreateCantComputeNode();
268 
269   // Get the innermost loop which this block belongs to.
270   Loop* loop = (*loop_descriptor)[basic_block->id()];
271 
272   // If the loop doesn't exist or doesn't have a preheader or latch block, exit
273   // out.
274   if (!loop || !loop->GetLatchBlock() || !loop->GetPreHeaderBlock() ||
275       loop->GetHeaderBlock() != basic_block)
276     return recurrent_node_map_[phi] = CreateCantComputeNode();
277 
278   const Loop* loop_to_use = nullptr;
279   if (pretend_equal_[loop]) {
280     loop_to_use = pretend_equal_[loop];
281   } else {
282     loop_to_use = loop;
283   }
284   std::unique_ptr<SERecurrentNode> phi_node{
285       new SERecurrentNode(this, loop_to_use)};
286 
287   // We add the node to this map to allow it to be returned before the node is
288   // fully built. This is needed as the subsequent call to AnalyzeInstruction
289   // could lead back to this |phi| instruction so we return the pointer
290   // immediately in AnalyzeInstruction to break the recursion.
291   recurrent_node_map_[phi] = phi_node.get();
292 
293   // Traverse the operands of the instruction an create new nodes for each one.
294   for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
295     uint32_t value_id = phi->GetSingleWordInOperand(i);
296     uint32_t incoming_label_id = phi->GetSingleWordInOperand(i + 1);
297 
298     Instruction* value_inst = def_use->GetDef(value_id);
299     SENode* value_node = AnalyzeInstruction(value_inst);
300 
301     // If any operand is CantCompute then the whole graph is CantCompute.
302     if (value_node->IsCantCompute())
303       return recurrent_node_map_[phi] = CreateCantComputeNode();
304 
305     // If the value is coming from the preheader block then the value is the
306     // initial value of the phi.
307     if (incoming_label_id == loop->GetPreHeaderBlock()->id()) {
308       phi_node->AddOffset(value_node);
309     } else if (incoming_label_id == loop->GetLatchBlock()->id()) {
310       // Assumed to be in the form of step + phi.
311       if (value_node->GetType() != SENode::Add)
312         return recurrent_node_map_[phi] = CreateCantComputeNode();
313 
314       SENode* step_node = nullptr;
315       SENode* phi_operand = nullptr;
316       SENode* operand_1 = value_node->GetChild(0);
317       SENode* operand_2 = value_node->GetChild(1);
318 
319       // Find which node is the step term.
320       if (!operand_1->AsSERecurrentNode())
321         step_node = operand_1;
322       else if (!operand_2->AsSERecurrentNode())
323         step_node = operand_2;
324 
325       // Find which node is the recurrent expression.
326       if (operand_1->AsSERecurrentNode())
327         phi_operand = operand_1;
328       else if (operand_2->AsSERecurrentNode())
329         phi_operand = operand_2;
330 
331       // If it is not in the form step + phi exit out.
332       if (!(step_node && phi_operand))
333         return recurrent_node_map_[phi] = CreateCantComputeNode();
334 
335       // If the phi operand is not the same phi node exit out.
336       if (phi_operand != phi_node.get())
337         return recurrent_node_map_[phi] = CreateCantComputeNode();
338 
339       if (!IsLoopInvariant(loop, step_node))
340         return recurrent_node_map_[phi] = CreateCantComputeNode();
341 
342       phi_node->AddCoefficient(step_node);
343     }
344   }
345 
346   // Once the node is fully built we update the map with the version from the
347   // cache (if it has already been added to the cache).
348   return recurrent_node_map_[phi] = GetCachedOrAdd(std::move(phi_node));
349 }
350 
CreateValueUnknownNode(const Instruction * inst)351 SENode* ScalarEvolutionAnalysis::CreateValueUnknownNode(
352     const Instruction* inst) {
353   std::unique_ptr<SEValueUnknown> load_node{
354       new SEValueUnknown(this, inst->result_id())};
355   return GetCachedOrAdd(std::move(load_node));
356 }
357 
CreateCantComputeNode()358 SENode* ScalarEvolutionAnalysis::CreateCantComputeNode() {
359   return cached_cant_compute_;
360 }
361 
362 // Add the created node into the cache of nodes. If it already exists return it.
GetCachedOrAdd(std::unique_ptr<SENode> prospective_node)363 SENode* ScalarEvolutionAnalysis::GetCachedOrAdd(
364     std::unique_ptr<SENode> prospective_node) {
365   auto itr = node_cache_.find(prospective_node);
366   if (itr != node_cache_.end()) {
367     return (*itr).get();
368   }
369 
370   SENode* raw_ptr_to_node = prospective_node.get();
371   node_cache_.insert(std::move(prospective_node));
372   return raw_ptr_to_node;
373 }
374 
IsLoopInvariant(const Loop * loop,const SENode * node) const375 bool ScalarEvolutionAnalysis::IsLoopInvariant(const Loop* loop,
376                                               const SENode* node) const {
377   for (auto itr = node->graph_cbegin(); itr != node->graph_cend(); ++itr) {
378     if (const SERecurrentNode* rec = itr->AsSERecurrentNode()) {
379       const BasicBlock* header = rec->GetLoop()->GetHeaderBlock();
380 
381       // If the loop which the recurrent expression belongs to is either |loop
382       // or a nested loop inside |loop| then we assume it is variant.
383       if (loop->IsInsideLoop(header)) {
384         return false;
385       }
386     } else if (const SEValueUnknown* unknown = itr->AsSEValueUnknown()) {
387       // If the instruction is inside the loop we conservatively assume it is
388       // loop variant.
389       if (loop->IsInsideLoop(unknown->ResultId())) return false;
390     }
391   }
392 
393   return true;
394 }
395 
GetCoefficientFromRecurrentTerm(SENode * node,const Loop * loop)396 SENode* ScalarEvolutionAnalysis::GetCoefficientFromRecurrentTerm(
397     SENode* node, const Loop* loop) {
398   // Traverse the DAG to find the recurrent expression belonging to |loop|.
399   for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) {
400     SERecurrentNode* rec = itr->AsSERecurrentNode();
401     if (rec && rec->GetLoop() == loop) {
402       return rec->GetCoefficient();
403     }
404   }
405   return CreateConstant(0);
406 }
407 
UpdateChildNode(SENode * parent,SENode * old_child,SENode * new_child)408 SENode* ScalarEvolutionAnalysis::UpdateChildNode(SENode* parent,
409                                                  SENode* old_child,
410                                                  SENode* new_child) {
411   // Only handles add.
412   if (parent->GetType() != SENode::Add) return parent;
413 
414   std::vector<SENode*> new_children;
415   for (SENode* child : *parent) {
416     if (child == old_child) {
417       new_children.push_back(new_child);
418     } else {
419       new_children.push_back(child);
420     }
421   }
422 
423   std::unique_ptr<SENode> add_node{new SEAddNode(this)};
424   for (SENode* child : new_children) {
425     add_node->AddChild(child);
426   }
427 
428   return SimplifyExpression(GetCachedOrAdd(std::move(add_node)));
429 }
430 
431 // Rebuild the |node| eliminating, if it exists, the recurrent term which
432 // belongs to the |loop|.
BuildGraphWithoutRecurrentTerm(SENode * node,const Loop * loop)433 SENode* ScalarEvolutionAnalysis::BuildGraphWithoutRecurrentTerm(
434     SENode* node, const Loop* loop) {
435   // If the node is already a recurrent expression belonging to loop then just
436   // return the offset.
437   SERecurrentNode* recurrent = node->AsSERecurrentNode();
438   if (recurrent) {
439     if (recurrent->GetLoop() == loop) {
440       return recurrent->GetOffset();
441     } else {
442       return node;
443     }
444   }
445 
446   std::vector<SENode*> new_children;
447   // Otherwise find the recurrent node in the children of this node.
448   for (auto itr : *node) {
449     recurrent = itr->AsSERecurrentNode();
450     if (recurrent && recurrent->GetLoop() == loop) {
451       new_children.push_back(recurrent->GetOffset());
452     } else {
453       new_children.push_back(itr);
454     }
455   }
456 
457   std::unique_ptr<SENode> add_node{new SEAddNode(this)};
458   for (SENode* child : new_children) {
459     add_node->AddChild(child);
460   }
461 
462   return SimplifyExpression(GetCachedOrAdd(std::move(add_node)));
463 }
464 
465 // Return the recurrent term belonging to |loop| if it appears in the graph
466 // starting at |node| or null if it doesn't.
GetRecurrentTerm(SENode * node,const Loop * loop)467 SERecurrentNode* ScalarEvolutionAnalysis::GetRecurrentTerm(SENode* node,
468                                                            const Loop* loop) {
469   for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) {
470     SERecurrentNode* rec = itr->AsSERecurrentNode();
471     if (rec && rec->GetLoop() == loop) {
472       return rec;
473     }
474   }
475   return nullptr;
476 }
AsString() const477 std::string SENode::AsString() const {
478   switch (GetType()) {
479     case Constant:
480       return "Constant";
481     case RecurrentAddExpr:
482       return "RecurrentAddExpr";
483     case Add:
484       return "Add";
485     case Negative:
486       return "Negative";
487     case Multiply:
488       return "Multiply";
489     case ValueUnknown:
490       return "Value Unknown";
491     case CanNotCompute:
492       return "Can not compute";
493   }
494   return "NULL";
495 }
496 
operator ==(const SENode & other) const497 bool SENode::operator==(const SENode& other) const {
498   if (GetType() != other.GetType()) return false;
499 
500   if (other.GetChildren().size() != children_.size()) return false;
501 
502   const SERecurrentNode* this_as_recurrent = AsSERecurrentNode();
503 
504   // Check the children are the same, for SERecurrentNodes we need to check the
505   // offset and coefficient manually as the child vector is sorted by ids so the
506   // offset/coefficient information is lost.
507   if (!this_as_recurrent) {
508     for (size_t index = 0; index < children_.size(); ++index) {
509       if (other.GetChildren()[index] != children_[index]) return false;
510     }
511   } else {
512     const SERecurrentNode* other_as_recurrent = other.AsSERecurrentNode();
513 
514     // We've already checked the types are the same, this should not fail if
515     // this->AsSERecurrentNode() succeeded.
516     assert(other_as_recurrent);
517 
518     if (this_as_recurrent->GetCoefficient() !=
519         other_as_recurrent->GetCoefficient())
520       return false;
521 
522     if (this_as_recurrent->GetOffset() != other_as_recurrent->GetOffset())
523       return false;
524 
525     if (this_as_recurrent->GetLoop() != other_as_recurrent->GetLoop())
526       return false;
527   }
528 
529   // If we're dealing with a value unknown node check both nodes were created by
530   // the same instruction.
531   if (GetType() == SENode::ValueUnknown) {
532     if (AsSEValueUnknown()->ResultId() !=
533         other.AsSEValueUnknown()->ResultId()) {
534       return false;
535     }
536   }
537 
538   if (AsSEConstantNode()) {
539     if (AsSEConstantNode()->FoldToSingleValue() !=
540         other.AsSEConstantNode()->FoldToSingleValue())
541       return false;
542   }
543 
544   return true;
545 }
546 
operator !=(const SENode & other) const547 bool SENode::operator!=(const SENode& other) const { return !(*this == other); }
548 
namespace {
// Helpers that append a 32 or 64 bit value to the 32 bit hash string. They
// let pointers be added to the string once they have been reinterpreted as
// uintptr_t. PushToString deduces the type of its argument and dispatches on
// sizeof(T) to the PushToStringImpl specialisation for that width.

template <typename T, size_t size_of_t>
struct PushToStringImpl;

// 8 byte values are split into two 32 bit words, most significant word first.
template <typename T>
struct PushToStringImpl<T, 8> {
  void operator()(T id, std::u32string* str) {
    const uint32_t upper_word = static_cast<uint32_t>(id >> 32);
    const uint32_t lower_word = static_cast<uint32_t>(id);
    str->push_back(upper_word);
    str->push_back(lower_word);
  }
};

// 4 byte values are appended as a single word.
template <typename T>
struct PushToStringImpl<T, 4> {
  void operator()(T id, std::u32string* str) {
    const uint32_t word = static_cast<uint32_t>(id);
    str->push_back(word);
  }
};

// Appends |id| to |str| using the implementation matching sizeof(T).
template <typename T>
void PushToString(T id, std::u32string* str) {
  PushToStringImpl<T, sizeof(T)> push;
  push(id, str);
}

}  // namespace
580 
581 // Implements the hashing of SENodes.
operator ()(const SENode * node) const582 size_t SENodeHash::operator()(const SENode* node) const {
583   // Concatenate the terms into a string which we can hash.
584   std::u32string hash_string{};
585 
586   // Hashing the type as a string is safer than hashing the enum as the enum is
587   // very likely to collide with constants.
588   for (char ch : node->AsString()) {
589     hash_string.push_back(static_cast<char32_t>(ch));
590   }
591 
592   // We just ignore the literal value unless it is a constant.
593   if (node->GetType() == SENode::Constant)
594     PushToString(node->AsSEConstantNode()->FoldToSingleValue(), &hash_string);
595 
596   const SERecurrentNode* recurrent = node->AsSERecurrentNode();
597 
598   // If we're dealing with a recurrent expression hash the loop as well so that
599   // nested inductions like i=0,i++ and j=0,j++ correspond to different nodes.
600   if (recurrent) {
601     PushToString(reinterpret_cast<uintptr_t>(recurrent->GetLoop()),
602                  &hash_string);
603 
604     // Recurrent expressions can't be hashed using the normal method as the
605     // order of coefficient and offset matters to the hash.
606     PushToString(reinterpret_cast<uintptr_t>(recurrent->GetCoefficient()),
607                  &hash_string);
608     PushToString(reinterpret_cast<uintptr_t>(recurrent->GetOffset()),
609                  &hash_string);
610 
611     return std::hash<std::u32string>{}(hash_string);
612   }
613 
614   // Hash the result id of the original instruction which created this node if
615   // it is a value unknown node.
616   if (node->GetType() == SENode::ValueUnknown) {
617     PushToString(node->AsSEValueUnknown()->ResultId(), &hash_string);
618   }
619 
620   // Hash the pointers of the child nodes, each SENode has a unique pointer
621   // associated with it.
622   const std::vector<SENode*>& children = node->GetChildren();
623   for (const SENode* child : children) {
624     PushToString(reinterpret_cast<uintptr_t>(child), &hash_string);
625   }
626 
627   return std::hash<std::u32string>{}(hash_string);
628 }
629 
630 // This overload is the actual overload used by the node_cache_ set.
operator ()(const std::unique_ptr<SENode> & node) const631 size_t SENodeHash::operator()(const std::unique_ptr<SENode>& node) const {
632   return this->operator()(node.get());
633 }
634 
DumpDot(std::ostream & out,bool recurse) const635 void SENode::DumpDot(std::ostream& out, bool recurse) const {
636   size_t unique_id = std::hash<const SENode*>{}(this);
637   out << unique_id << " [label=\"" << AsString() << " ";
638   if (GetType() == SENode::Constant) {
639     out << "\nwith value: " << this->AsSEConstantNode()->FoldToSingleValue();
640   }
641   out << "\"]\n";
642   for (const SENode* child : children_) {
643     size_t child_unique_id = std::hash<const SENode*>{}(child);
644     out << unique_id << " -> " << child_unique_id << " \n";
645     if (recurse) child->DumpDot(out, true);
646   }
647 }
648 
649 namespace {
650 class IsGreaterThanZero {
651  public:
IsGreaterThanZero(IRContext * context)652   explicit IsGreaterThanZero(IRContext* context) : context_(context) {}
653 
654   // Determine if the value of |node| is always strictly greater than zero if
655   // |or_equal_zero| is false or greater or equal to zero if |or_equal_zero| is
656   // true. It returns true is the evaluation was able to conclude something, in
657   // which case the result is stored in |result|.
658   // The algorithm work by going through all the nodes and determine the
659   // sign of each of them.
Eval(const SENode * node,bool or_equal_zero,bool * result)660   bool Eval(const SENode* node, bool or_equal_zero, bool* result) {
661     *result = false;
662     switch (Visit(node)) {
663       case Signedness::kPositiveOrNegative: {
664         return false;
665       }
666       case Signedness::kStrictlyNegative: {
667         *result = false;
668         break;
669       }
670       case Signedness::kNegative: {
671         if (!or_equal_zero) {
672           return false;
673         }
674         *result = false;
675         break;
676       }
677       case Signedness::kStrictlyPositive: {
678         *result = true;
679         break;
680       }
681       case Signedness::kPositive: {
682         if (!or_equal_zero) {
683           return false;
684         }
685         *result = true;
686         break;
687       }
688     }
689     return true;
690   }
691 
692  private:
693   enum class Signedness {
694     kPositiveOrNegative,  // Yield a value positive or negative.
695     kStrictlyNegative,    // Yield a value strictly less than 0.
696     kNegative,            // Yield a value less or equal to 0.
697     kStrictlyPositive,    // Yield a value strictly greater than 0.
698     kPositive             // Yield a value greater or equal to 0.
699   };
700 
701   // Combine the signedness according to arithmetic rules of a given operator.
702   using Combiner = std::function<Signedness(Signedness, Signedness)>;
703 
704   // Returns a functor to interpret the signedness of 2 expressions as if they
705   // were added.
GetAddCombiner() const706   Combiner GetAddCombiner() const {
707     return [](Signedness lhs, Signedness rhs) {
708       switch (lhs) {
709         case Signedness::kPositiveOrNegative:
710           break;
711         case Signedness::kStrictlyNegative:
712           if (rhs == Signedness::kStrictlyNegative ||
713               rhs == Signedness::kNegative)
714             return lhs;
715           break;
716         case Signedness::kNegative: {
717           if (rhs == Signedness::kStrictlyNegative)
718             return Signedness::kStrictlyNegative;
719           if (rhs == Signedness::kNegative) return Signedness::kNegative;
720           break;
721         }
722         case Signedness::kStrictlyPositive: {
723           if (rhs == Signedness::kStrictlyPositive ||
724               rhs == Signedness::kPositive) {
725             return Signedness::kStrictlyPositive;
726           }
727           break;
728         }
729         case Signedness::kPositive: {
730           if (rhs == Signedness::kStrictlyPositive)
731             return Signedness::kStrictlyPositive;
732           if (rhs == Signedness::kPositive) return Signedness::kPositive;
733           break;
734         }
735       }
736       return Signedness::kPositiveOrNegative;
737     };
738   }
739 
740   // Returns a functor to interpret the signedness of 2 expressions as if they
741   // were multiplied.
GetMulCombiner() const742   Combiner GetMulCombiner() const {
743     return [](Signedness lhs, Signedness rhs) {
744       switch (lhs) {
745         case Signedness::kPositiveOrNegative:
746           break;
747         case Signedness::kStrictlyNegative: {
748           switch (rhs) {
749             case Signedness::kPositiveOrNegative: {
750               break;
751             }
752             case Signedness::kStrictlyNegative: {
753               return Signedness::kStrictlyPositive;
754             }
755             case Signedness::kNegative: {
756               return Signedness::kPositive;
757             }
758             case Signedness::kStrictlyPositive: {
759               return Signedness::kStrictlyNegative;
760             }
761             case Signedness::kPositive: {
762               return Signedness::kNegative;
763             }
764           }
765           break;
766         }
767         case Signedness::kNegative: {
768           switch (rhs) {
769             case Signedness::kPositiveOrNegative: {
770               break;
771             }
772             case Signedness::kStrictlyNegative:
773             case Signedness::kNegative: {
774               return Signedness::kPositive;
775             }
776             case Signedness::kStrictlyPositive:
777             case Signedness::kPositive: {
778               return Signedness::kNegative;
779             }
780           }
781           break;
782         }
783         case Signedness::kStrictlyPositive: {
784           return rhs;
785         }
786         case Signedness::kPositive: {
787           switch (rhs) {
788             case Signedness::kPositiveOrNegative: {
789               break;
790             }
791             case Signedness::kStrictlyNegative:
792             case Signedness::kNegative: {
793               return Signedness::kNegative;
794             }
795             case Signedness::kStrictlyPositive:
796             case Signedness::kPositive: {
797               return Signedness::kPositive;
798             }
799           }
800           break;
801         }
802       }
803       return Signedness::kPositiveOrNegative;
804     };
805   }
806 
Visit(const SENode * node)807   Signedness Visit(const SENode* node) {
808     switch (node->GetType()) {
809       case SENode::Constant:
810         return Visit(node->AsSEConstantNode());
811         break;
812       case SENode::RecurrentAddExpr:
813         return Visit(node->AsSERecurrentNode());
814         break;
815       case SENode::Negative:
816         return Visit(node->AsSENegative());
817         break;
818       case SENode::CanNotCompute:
819         return Visit(node->AsSECantCompute());
820         break;
821       case SENode::ValueUnknown:
822         return Visit(node->AsSEValueUnknown());
823         break;
824       case SENode::Add:
825         return VisitExpr(node, GetAddCombiner());
826         break;
827       case SENode::Multiply:
828         return VisitExpr(node, GetMulCombiner());
829         break;
830     }
831     return Signedness::kPositiveOrNegative;
832   }
833 
834   // Returns the signedness of a constant |node|.
Visit(const SEConstantNode * node)835   Signedness Visit(const SEConstantNode* node) {
836     if (0 == node->FoldToSingleValue()) return Signedness::kPositive;
837     if (0 < node->FoldToSingleValue()) return Signedness::kStrictlyPositive;
838     if (0 > node->FoldToSingleValue()) return Signedness::kStrictlyNegative;
839     return Signedness::kPositiveOrNegative;
840   }
841 
842   // Returns the signedness of an unknown |node| based on its type.
Visit(const SEValueUnknown * node)843   Signedness Visit(const SEValueUnknown* node) {
844     Instruction* insn = context_->get_def_use_mgr()->GetDef(node->ResultId());
845     analysis::Type* type = context_->get_type_mgr()->GetType(insn->type_id());
846     assert(type && "Can't retrieve a type for the instruction");
847     analysis::Integer* int_type = type->AsInteger();
848     assert(type && "Can't retrieve an integer type for the instruction");
849     return int_type->IsSigned() ? Signedness::kPositiveOrNegative
850                                 : Signedness::kPositive;
851   }
852 
853   // Returns the signedness of a recurring expression.
Visit(const SERecurrentNode * node)854   Signedness Visit(const SERecurrentNode* node) {
855     Signedness coeff_sign = Visit(node->GetCoefficient());
856     // SERecurrentNode represent an affine expression in the range [0,
857     // loop_bound], so the result cannot be strictly positive or negative.
858     switch (coeff_sign) {
859       default:
860         break;
861       case Signedness::kStrictlyNegative:
862         coeff_sign = Signedness::kNegative;
863         break;
864       case Signedness::kStrictlyPositive:
865         coeff_sign = Signedness::kPositive;
866         break;
867     }
868     return GetAddCombiner()(coeff_sign, Visit(node->GetOffset()));
869   }
870 
871   // Returns the signedness of a negation |node|.
Visit(const SENegative * node)872   Signedness Visit(const SENegative* node) {
873     switch (Visit(*node->begin())) {
874       case Signedness::kPositiveOrNegative: {
875         return Signedness::kPositiveOrNegative;
876       }
877       case Signedness::kStrictlyNegative: {
878         return Signedness::kStrictlyPositive;
879       }
880       case Signedness::kNegative: {
881         return Signedness::kPositive;
882       }
883       case Signedness::kStrictlyPositive: {
884         return Signedness::kStrictlyNegative;
885       }
886       case Signedness::kPositive: {
887         return Signedness::kNegative;
888       }
889     }
890     return Signedness::kPositiveOrNegative;
891   }
892 
Visit(const SECantCompute *)893   Signedness Visit(const SECantCompute*) {
894     return Signedness::kPositiveOrNegative;
895   }
896 
897   // Returns the signedness of a binary expression by using the combiner
898   // |reduce|.
VisitExpr(const SENode * node,std::function<Signedness (Signedness,Signedness)> reduce)899   Signedness VisitExpr(
900       const SENode* node,
901       std::function<Signedness(Signedness, Signedness)> reduce) {
902     Signedness result = Visit(*node->begin());
903     for (const SENode* operand : make_range(++node->begin(), node->end())) {
904       if (result == Signedness::kPositiveOrNegative) {
905         return Signedness::kPositiveOrNegative;
906       }
907       result = reduce(result, Visit(operand));
908     }
909     return result;
910   }
911 
912   IRContext* context_;
913 };
914 }  // namespace
915 
IsAlwaysGreaterThanZero(SENode * node,bool * is_gt_zero) const916 bool ScalarEvolutionAnalysis::IsAlwaysGreaterThanZero(SENode* node,
917                                                       bool* is_gt_zero) const {
918   return IsGreaterThanZero(context_).Eval(node, false, is_gt_zero);
919 }
920 
IsAlwaysGreaterOrEqualToZero(SENode * node,bool * is_ge_zero) const921 bool ScalarEvolutionAnalysis::IsAlwaysGreaterOrEqualToZero(
922     SENode* node, bool* is_ge_zero) const {
923   return IsGreaterThanZero(context_).Eval(node, true, is_ge_zero);
924 }
925 
926 namespace {
927 
928 // Remove |node| from the |mul| chain (of the form A * ... * |node| * ... * Z),
929 // if |node| is not in the chain, returns the original chain.
RemoveOneNodeFromMultiplyChain(SEMultiplyNode * mul,const SENode * node)930 SENode* RemoveOneNodeFromMultiplyChain(SEMultiplyNode* mul,
931                                        const SENode* node) {
932   SENode* lhs = mul->GetChildren()[0];
933   SENode* rhs = mul->GetChildren()[1];
934   if (lhs == node) {
935     return rhs;
936   }
937   if (rhs == node) {
938     return lhs;
939   }
940   if (lhs->AsSEMultiplyNode()) {
941     SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), node);
942     if (res != lhs)
943       return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs);
944   }
945   if (rhs->AsSEMultiplyNode()) {
946     SENode* res = RemoveOneNodeFromMultiplyChain(rhs->AsSEMultiplyNode(), node);
947     if (res != rhs)
948       return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs);
949   }
950 
951   return mul;
952 }
953 }  // namespace
954 
operator /(SExpression rhs_wrapper) const955 std::pair<SExpression, int64_t> SExpression::operator/(
956     SExpression rhs_wrapper) const {
957   SENode* lhs = node_;
958   SENode* rhs = rhs_wrapper.node_;
959   // Check for division by 0.
960   if (rhs->AsSEConstantNode() &&
961       !rhs->AsSEConstantNode()->FoldToSingleValue()) {
962     return {scev_->CreateCantComputeNode(), 0};
963   }
964 
965   // Trivial case.
966   if (lhs->AsSEConstantNode() && rhs->AsSEConstantNode()) {
967     int64_t lhs_value = lhs->AsSEConstantNode()->FoldToSingleValue();
968     int64_t rhs_value = rhs->AsSEConstantNode()->FoldToSingleValue();
969     return {scev_->CreateConstant(lhs_value / rhs_value),
970             lhs_value % rhs_value};
971   }
972 
973   // look for a "c U / U" pattern.
974   if (lhs->AsSEMultiplyNode()) {
975     assert(lhs->GetChildren().size() == 2 &&
976            "More than 2 operand for a multiply node.");
977     SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), rhs);
978     if (res != lhs) {
979       return {res, 0};
980     }
981   }
982 
983   return {scev_->CreateCantComputeNode(), 0};
984 }
985 
986 }  // namespace opt
987 }  // namespace spvtools
988