1 // Copyright (c) 2018 Google LLC.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
#include "source/opt/scalar_analysis.h"

#include <functional>
#include <limits>
#include <string>
#include <utility>

#include "source/opt/ir_context.h"
22
23 // Transforms a given scalar operation instruction into a DAG representation.
24 //
25 // 1. Take an instruction and traverse its operands until we reach a
26 // constant node or an instruction which we do not know how to compute the
27 // value, such as a load.
28 //
29 // 2. Create a new node for each instruction traversed and build the nodes for
30 // the in operands of that instruction as well.
31 //
32 // 3. Add the operand nodes as children of the first and hash the node. Use the
33 // hash to see if the node is already in the cache. We ensure the children are
34 // always in sorted order so that two nodes with the same children but inserted
35 // in a different order have the same hash and so that the overloaded operator==
36 // will return true. If the node is already in the cache return the cached
37 // version instead.
38 //
// 4. The created DAG can then be simplified by
// ScalarEvolutionAnalysis::SimplifyExpression, implemented in
// scalar_analysis_simplification.cpp. See that file for further information on
// the simplification process.
43 //
44
45 namespace spvtools {
46 namespace opt {
47
48 uint32_t SENode::NumberOfNodes = 0;
49
ScalarEvolutionAnalysis(IRContext * context)50 ScalarEvolutionAnalysis::ScalarEvolutionAnalysis(IRContext* context)
51 : context_(context), pretend_equal_{} {
52 // Create and cached the CantComputeNode.
53 cached_cant_compute_ =
54 GetCachedOrAdd(std::unique_ptr<SECantCompute>(new SECantCompute(this)));
55 }
56
CreateNegation(SENode * operand)57 SENode* ScalarEvolutionAnalysis::CreateNegation(SENode* operand) {
58 // If operand is can't compute then the whole graph is can't compute.
59 if (operand->IsCantCompute()) return CreateCantComputeNode();
60
61 if (operand->GetType() == SENode::Constant) {
62 return CreateConstant(-operand->AsSEConstantNode()->FoldToSingleValue());
63 }
64 std::unique_ptr<SENode> negation_node{new SENegative(this)};
65 negation_node->AddChild(operand);
66 return GetCachedOrAdd(std::move(negation_node));
67 }
68
CreateConstant(int64_t integer)69 SENode* ScalarEvolutionAnalysis::CreateConstant(int64_t integer) {
70 return GetCachedOrAdd(
71 std::unique_ptr<SENode>(new SEConstantNode(this, integer)));
72 }
73
CreateRecurrentExpression(const Loop * loop,SENode * offset,SENode * coefficient)74 SENode* ScalarEvolutionAnalysis::CreateRecurrentExpression(
75 const Loop* loop, SENode* offset, SENode* coefficient) {
76 assert(loop && "Recurrent add expressions must have a valid loop.");
77
78 // If operands are can't compute then the whole graph is can't compute.
79 if (offset->IsCantCompute() || coefficient->IsCantCompute())
80 return CreateCantComputeNode();
81
82 const Loop* loop_to_use = nullptr;
83 if (pretend_equal_[loop]) {
84 loop_to_use = pretend_equal_[loop];
85 } else {
86 loop_to_use = loop;
87 }
88
89 std::unique_ptr<SERecurrentNode> phi_node{
90 new SERecurrentNode(this, loop_to_use)};
91 phi_node->AddOffset(offset);
92 phi_node->AddCoefficient(coefficient);
93
94 return GetCachedOrAdd(std::move(phi_node));
95 }
96
AnalyzeMultiplyOp(const Instruction * multiply)97 SENode* ScalarEvolutionAnalysis::AnalyzeMultiplyOp(
98 const Instruction* multiply) {
99 assert(multiply->opcode() == spv::Op::OpIMul &&
100 "Multiply node did not come from a multiply instruction");
101 analysis::DefUseManager* def_use = context_->get_def_use_mgr();
102
103 SENode* op1 =
104 AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(0)));
105 SENode* op2 =
106 AnalyzeInstruction(def_use->GetDef(multiply->GetSingleWordInOperand(1)));
107
108 return CreateMultiplyNode(op1, op2);
109 }
110
CreateMultiplyNode(SENode * operand_1,SENode * operand_2)111 SENode* ScalarEvolutionAnalysis::CreateMultiplyNode(SENode* operand_1,
112 SENode* operand_2) {
113 // If operands are can't compute then the whole graph is can't compute.
114 if (operand_1->IsCantCompute() || operand_2->IsCantCompute())
115 return CreateCantComputeNode();
116
117 if (operand_1->GetType() == SENode::Constant &&
118 operand_2->GetType() == SENode::Constant) {
119 return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() *
120 operand_2->AsSEConstantNode()->FoldToSingleValue());
121 }
122
123 std::unique_ptr<SENode> multiply_node{new SEMultiplyNode(this)};
124
125 multiply_node->AddChild(operand_1);
126 multiply_node->AddChild(operand_2);
127
128 return GetCachedOrAdd(std::move(multiply_node));
129 }
130
CreateSubtraction(SENode * operand_1,SENode * operand_2)131 SENode* ScalarEvolutionAnalysis::CreateSubtraction(SENode* operand_1,
132 SENode* operand_2) {
133 // Fold if both operands are constant.
134 if (operand_1->GetType() == SENode::Constant &&
135 operand_2->GetType() == SENode::Constant) {
136 return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() -
137 operand_2->AsSEConstantNode()->FoldToSingleValue());
138 }
139
140 return CreateAddNode(operand_1, CreateNegation(operand_2));
141 }
142
CreateAddNode(SENode * operand_1,SENode * operand_2)143 SENode* ScalarEvolutionAnalysis::CreateAddNode(SENode* operand_1,
144 SENode* operand_2) {
145 // Fold if both operands are constant and the |simplify| flag is true.
146 if (operand_1->GetType() == SENode::Constant &&
147 operand_2->GetType() == SENode::Constant) {
148 return CreateConstant(operand_1->AsSEConstantNode()->FoldToSingleValue() +
149 operand_2->AsSEConstantNode()->FoldToSingleValue());
150 }
151
152 // If operands are can't compute then the whole graph is can't compute.
153 if (operand_1->IsCantCompute() || operand_2->IsCantCompute())
154 return CreateCantComputeNode();
155
156 std::unique_ptr<SENode> add_node{new SEAddNode(this)};
157
158 add_node->AddChild(operand_1);
159 add_node->AddChild(operand_2);
160
161 return GetCachedOrAdd(std::move(add_node));
162 }
163
AnalyzeInstruction(const Instruction * inst)164 SENode* ScalarEvolutionAnalysis::AnalyzeInstruction(const Instruction* inst) {
165 auto itr = recurrent_node_map_.find(inst);
166 if (itr != recurrent_node_map_.end()) return itr->second;
167
168 SENode* output = nullptr;
169 switch (inst->opcode()) {
170 case spv::Op::OpPhi: {
171 output = AnalyzePhiInstruction(inst);
172 break;
173 }
174 case spv::Op::OpConstant:
175 case spv::Op::OpConstantNull: {
176 output = AnalyzeConstant(inst);
177 break;
178 }
179 case spv::Op::OpISub:
180 case spv::Op::OpIAdd: {
181 output = AnalyzeAddOp(inst);
182 break;
183 }
184 case spv::Op::OpIMul: {
185 output = AnalyzeMultiplyOp(inst);
186 break;
187 }
188 default: {
189 output = CreateValueUnknownNode(inst);
190 break;
191 }
192 }
193
194 return output;
195 }
196
AnalyzeConstant(const Instruction * inst)197 SENode* ScalarEvolutionAnalysis::AnalyzeConstant(const Instruction* inst) {
198 if (inst->opcode() == spv::Op::OpConstantNull) return CreateConstant(0);
199
200 assert(inst->opcode() == spv::Op::OpConstant);
201 assert(inst->NumInOperands() == 1);
202 int64_t value = 0;
203
204 // Look up the instruction in the constant manager.
205 const analysis::Constant* constant =
206 context_->get_constant_mgr()->FindDeclaredConstant(inst->result_id());
207
208 if (!constant) return CreateCantComputeNode();
209
210 const analysis::IntConstant* int_constant = constant->AsIntConstant();
211
212 // Exit out if it is a 64 bit integer.
213 if (!int_constant || int_constant->words().size() != 1)
214 return CreateCantComputeNode();
215
216 if (int_constant->type()->AsInteger()->IsSigned()) {
217 value = int_constant->GetS32BitValue();
218 } else {
219 value = int_constant->GetU32BitValue();
220 }
221
222 return CreateConstant(value);
223 }
224
225 // Handles both addition and subtraction. If the |sub| flag is set then the
226 // addition will be op1+(-op2) otherwise op1+op2.
AnalyzeAddOp(const Instruction * inst)227 SENode* ScalarEvolutionAnalysis::AnalyzeAddOp(const Instruction* inst) {
228 assert((inst->opcode() == spv::Op::OpIAdd ||
229 inst->opcode() == spv::Op::OpISub) &&
230 "Add node must be created from a OpIAdd or OpISub instruction");
231
232 analysis::DefUseManager* def_use = context_->get_def_use_mgr();
233
234 SENode* op1 =
235 AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(0)));
236
237 SENode* op2 =
238 AnalyzeInstruction(def_use->GetDef(inst->GetSingleWordInOperand(1)));
239
240 // To handle subtraction we wrap the second operand in a unary negation node.
241 if (inst->opcode() == spv::Op::OpISub) {
242 op2 = CreateNegation(op2);
243 }
244
245 return CreateAddNode(op1, op2);
246 }
247
AnalyzePhiInstruction(const Instruction * phi)248 SENode* ScalarEvolutionAnalysis::AnalyzePhiInstruction(const Instruction* phi) {
249 // The phi should only have two incoming value pairs.
250 if (phi->NumInOperands() != 4) {
251 return CreateCantComputeNode();
252 }
253
254 analysis::DefUseManager* def_use = context_->get_def_use_mgr();
255
256 // Get the basic block this instruction belongs to.
257 BasicBlock* basic_block =
258 context_->get_instr_block(const_cast<Instruction*>(phi));
259
260 // And then the function that the basic blocks belongs to.
261 Function* function = basic_block->GetParent();
262
263 // Use the function to get the loop descriptor.
264 LoopDescriptor* loop_descriptor = context_->GetLoopDescriptor(function);
265
266 // We only handle phis in loops at the moment.
267 if (!loop_descriptor) return CreateCantComputeNode();
268
269 // Get the innermost loop which this block belongs to.
270 Loop* loop = (*loop_descriptor)[basic_block->id()];
271
272 // If the loop doesn't exist or doesn't have a preheader or latch block, exit
273 // out.
274 if (!loop || !loop->GetLatchBlock() || !loop->GetPreHeaderBlock() ||
275 loop->GetHeaderBlock() != basic_block)
276 return recurrent_node_map_[phi] = CreateCantComputeNode();
277
278 const Loop* loop_to_use = nullptr;
279 if (pretend_equal_[loop]) {
280 loop_to_use = pretend_equal_[loop];
281 } else {
282 loop_to_use = loop;
283 }
284 std::unique_ptr<SERecurrentNode> phi_node{
285 new SERecurrentNode(this, loop_to_use)};
286
287 // We add the node to this map to allow it to be returned before the node is
288 // fully built. This is needed as the subsequent call to AnalyzeInstruction
289 // could lead back to this |phi| instruction so we return the pointer
290 // immediately in AnalyzeInstruction to break the recursion.
291 recurrent_node_map_[phi] = phi_node.get();
292
293 // Traverse the operands of the instruction an create new nodes for each one.
294 for (uint32_t i = 0; i < phi->NumInOperands(); i += 2) {
295 uint32_t value_id = phi->GetSingleWordInOperand(i);
296 uint32_t incoming_label_id = phi->GetSingleWordInOperand(i + 1);
297
298 Instruction* value_inst = def_use->GetDef(value_id);
299 SENode* value_node = AnalyzeInstruction(value_inst);
300
301 // If any operand is CantCompute then the whole graph is CantCompute.
302 if (value_node->IsCantCompute())
303 return recurrent_node_map_[phi] = CreateCantComputeNode();
304
305 // If the value is coming from the preheader block then the value is the
306 // initial value of the phi.
307 if (incoming_label_id == loop->GetPreHeaderBlock()->id()) {
308 phi_node->AddOffset(value_node);
309 } else if (incoming_label_id == loop->GetLatchBlock()->id()) {
310 // Assumed to be in the form of step + phi.
311 if (value_node->GetType() != SENode::Add)
312 return recurrent_node_map_[phi] = CreateCantComputeNode();
313
314 SENode* step_node = nullptr;
315 SENode* phi_operand = nullptr;
316 SENode* operand_1 = value_node->GetChild(0);
317 SENode* operand_2 = value_node->GetChild(1);
318
319 // Find which node is the step term.
320 if (!operand_1->AsSERecurrentNode())
321 step_node = operand_1;
322 else if (!operand_2->AsSERecurrentNode())
323 step_node = operand_2;
324
325 // Find which node is the recurrent expression.
326 if (operand_1->AsSERecurrentNode())
327 phi_operand = operand_1;
328 else if (operand_2->AsSERecurrentNode())
329 phi_operand = operand_2;
330
331 // If it is not in the form step + phi exit out.
332 if (!(step_node && phi_operand))
333 return recurrent_node_map_[phi] = CreateCantComputeNode();
334
335 // If the phi operand is not the same phi node exit out.
336 if (phi_operand != phi_node.get())
337 return recurrent_node_map_[phi] = CreateCantComputeNode();
338
339 if (!IsLoopInvariant(loop, step_node))
340 return recurrent_node_map_[phi] = CreateCantComputeNode();
341
342 phi_node->AddCoefficient(step_node);
343 }
344 }
345
346 // Once the node is fully built we update the map with the version from the
347 // cache (if it has already been added to the cache).
348 return recurrent_node_map_[phi] = GetCachedOrAdd(std::move(phi_node));
349 }
350
CreateValueUnknownNode(const Instruction * inst)351 SENode* ScalarEvolutionAnalysis::CreateValueUnknownNode(
352 const Instruction* inst) {
353 std::unique_ptr<SEValueUnknown> load_node{
354 new SEValueUnknown(this, inst->result_id())};
355 return GetCachedOrAdd(std::move(load_node));
356 }
357
CreateCantComputeNode()358 SENode* ScalarEvolutionAnalysis::CreateCantComputeNode() {
359 return cached_cant_compute_;
360 }
361
362 // Add the created node into the cache of nodes. If it already exists return it.
GetCachedOrAdd(std::unique_ptr<SENode> prospective_node)363 SENode* ScalarEvolutionAnalysis::GetCachedOrAdd(
364 std::unique_ptr<SENode> prospective_node) {
365 auto itr = node_cache_.find(prospective_node);
366 if (itr != node_cache_.end()) {
367 return (*itr).get();
368 }
369
370 SENode* raw_ptr_to_node = prospective_node.get();
371 node_cache_.insert(std::move(prospective_node));
372 return raw_ptr_to_node;
373 }
374
IsLoopInvariant(const Loop * loop,const SENode * node) const375 bool ScalarEvolutionAnalysis::IsLoopInvariant(const Loop* loop,
376 const SENode* node) const {
377 for (auto itr = node->graph_cbegin(); itr != node->graph_cend(); ++itr) {
378 if (const SERecurrentNode* rec = itr->AsSERecurrentNode()) {
379 const BasicBlock* header = rec->GetLoop()->GetHeaderBlock();
380
381 // If the loop which the recurrent expression belongs to is either |loop
382 // or a nested loop inside |loop| then we assume it is variant.
383 if (loop->IsInsideLoop(header)) {
384 return false;
385 }
386 } else if (const SEValueUnknown* unknown = itr->AsSEValueUnknown()) {
387 // If the instruction is inside the loop we conservatively assume it is
388 // loop variant.
389 if (loop->IsInsideLoop(unknown->ResultId())) return false;
390 }
391 }
392
393 return true;
394 }
395
GetCoefficientFromRecurrentTerm(SENode * node,const Loop * loop)396 SENode* ScalarEvolutionAnalysis::GetCoefficientFromRecurrentTerm(
397 SENode* node, const Loop* loop) {
398 // Traverse the DAG to find the recurrent expression belonging to |loop|.
399 for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) {
400 SERecurrentNode* rec = itr->AsSERecurrentNode();
401 if (rec && rec->GetLoop() == loop) {
402 return rec->GetCoefficient();
403 }
404 }
405 return CreateConstant(0);
406 }
407
UpdateChildNode(SENode * parent,SENode * old_child,SENode * new_child)408 SENode* ScalarEvolutionAnalysis::UpdateChildNode(SENode* parent,
409 SENode* old_child,
410 SENode* new_child) {
411 // Only handles add.
412 if (parent->GetType() != SENode::Add) return parent;
413
414 std::vector<SENode*> new_children;
415 for (SENode* child : *parent) {
416 if (child == old_child) {
417 new_children.push_back(new_child);
418 } else {
419 new_children.push_back(child);
420 }
421 }
422
423 std::unique_ptr<SENode> add_node{new SEAddNode(this)};
424 for (SENode* child : new_children) {
425 add_node->AddChild(child);
426 }
427
428 return SimplifyExpression(GetCachedOrAdd(std::move(add_node)));
429 }
430
431 // Rebuild the |node| eliminating, if it exists, the recurrent term which
432 // belongs to the |loop|.
BuildGraphWithoutRecurrentTerm(SENode * node,const Loop * loop)433 SENode* ScalarEvolutionAnalysis::BuildGraphWithoutRecurrentTerm(
434 SENode* node, const Loop* loop) {
435 // If the node is already a recurrent expression belonging to loop then just
436 // return the offset.
437 SERecurrentNode* recurrent = node->AsSERecurrentNode();
438 if (recurrent) {
439 if (recurrent->GetLoop() == loop) {
440 return recurrent->GetOffset();
441 } else {
442 return node;
443 }
444 }
445
446 std::vector<SENode*> new_children;
447 // Otherwise find the recurrent node in the children of this node.
448 for (auto itr : *node) {
449 recurrent = itr->AsSERecurrentNode();
450 if (recurrent && recurrent->GetLoop() == loop) {
451 new_children.push_back(recurrent->GetOffset());
452 } else {
453 new_children.push_back(itr);
454 }
455 }
456
457 std::unique_ptr<SENode> add_node{new SEAddNode(this)};
458 for (SENode* child : new_children) {
459 add_node->AddChild(child);
460 }
461
462 return SimplifyExpression(GetCachedOrAdd(std::move(add_node)));
463 }
464
465 // Return the recurrent term belonging to |loop| if it appears in the graph
466 // starting at |node| or null if it doesn't.
GetRecurrentTerm(SENode * node,const Loop * loop)467 SERecurrentNode* ScalarEvolutionAnalysis::GetRecurrentTerm(SENode* node,
468 const Loop* loop) {
469 for (auto itr = node->graph_begin(); itr != node->graph_end(); ++itr) {
470 SERecurrentNode* rec = itr->AsSERecurrentNode();
471 if (rec && rec->GetLoop() == loop) {
472 return rec;
473 }
474 }
475 return nullptr;
476 }
AsString() const477 std::string SENode::AsString() const {
478 switch (GetType()) {
479 case Constant:
480 return "Constant";
481 case RecurrentAddExpr:
482 return "RecurrentAddExpr";
483 case Add:
484 return "Add";
485 case Negative:
486 return "Negative";
487 case Multiply:
488 return "Multiply";
489 case ValueUnknown:
490 return "Value Unknown";
491 case CanNotCompute:
492 return "Can not compute";
493 }
494 return "NULL";
495 }
496
operator ==(const SENode & other) const497 bool SENode::operator==(const SENode& other) const {
498 if (GetType() != other.GetType()) return false;
499
500 if (other.GetChildren().size() != children_.size()) return false;
501
502 const SERecurrentNode* this_as_recurrent = AsSERecurrentNode();
503
504 // Check the children are the same, for SERecurrentNodes we need to check the
505 // offset and coefficient manually as the child vector is sorted by ids so the
506 // offset/coefficient information is lost.
507 if (!this_as_recurrent) {
508 for (size_t index = 0; index < children_.size(); ++index) {
509 if (other.GetChildren()[index] != children_[index]) return false;
510 }
511 } else {
512 const SERecurrentNode* other_as_recurrent = other.AsSERecurrentNode();
513
514 // We've already checked the types are the same, this should not fail if
515 // this->AsSERecurrentNode() succeeded.
516 assert(other_as_recurrent);
517
518 if (this_as_recurrent->GetCoefficient() !=
519 other_as_recurrent->GetCoefficient())
520 return false;
521
522 if (this_as_recurrent->GetOffset() != other_as_recurrent->GetOffset())
523 return false;
524
525 if (this_as_recurrent->GetLoop() != other_as_recurrent->GetLoop())
526 return false;
527 }
528
529 // If we're dealing with a value unknown node check both nodes were created by
530 // the same instruction.
531 if (GetType() == SENode::ValueUnknown) {
532 if (AsSEValueUnknown()->ResultId() !=
533 other.AsSEValueUnknown()->ResultId()) {
534 return false;
535 }
536 }
537
538 if (AsSEConstantNode()) {
539 if (AsSEConstantNode()->FoldToSingleValue() !=
540 other.AsSEConstantNode()->FoldToSingleValue())
541 return false;
542 }
543
544 return true;
545 }
546
operator !=(const SENode & other) const547 bool SENode::operator!=(const SENode& other) const { return !(*this == other); }
548
namespace {
// Helpers that append a 32 or 64 bit value to the std::u32string used as hash
// input. Pointers can be appended as well once they have been reinterpreted as
// uintptr_t. PushToString deduces the value type and selects the
// PushToStringImpl specialization matching sizeof(T); sizes other than 4 and 8
// bytes have no specialization and therefore fail to compile.

template <typename T, size_t size_of_t>
struct PushToStringImpl;

// 32 bit values occupy exactly one element of the string.
template <typename T>
struct PushToStringImpl<T, 4> {
  void operator()(T id, std::u32string* str) {
    const uint32_t word = static_cast<uint32_t>(id);
    str->push_back(word);
  }
};

// 64 bit values are split into two elements, most significant half first.
template <typename T>
struct PushToStringImpl<T, 8> {
  void operator()(T id, std::u32string* str) {
    const uint32_t high_word = static_cast<uint32_t>(id >> 32);
    const uint32_t low_word = static_cast<uint32_t>(id);
    str->push_back(high_word);
    str->push_back(low_word);
  }
};

template <typename T>
void PushToString(T id, std::u32string* str) {
  PushToStringImpl<T, sizeof(T)> push;
  push(id, str);
}

}  // namespace
580
581 // Implements the hashing of SENodes.
operator ()(const SENode * node) const582 size_t SENodeHash::operator()(const SENode* node) const {
583 // Concatenate the terms into a string which we can hash.
584 std::u32string hash_string{};
585
586 // Hashing the type as a string is safer than hashing the enum as the enum is
587 // very likely to collide with constants.
588 for (char ch : node->AsString()) {
589 hash_string.push_back(static_cast<char32_t>(ch));
590 }
591
592 // We just ignore the literal value unless it is a constant.
593 if (node->GetType() == SENode::Constant)
594 PushToString(node->AsSEConstantNode()->FoldToSingleValue(), &hash_string);
595
596 const SERecurrentNode* recurrent = node->AsSERecurrentNode();
597
598 // If we're dealing with a recurrent expression hash the loop as well so that
599 // nested inductions like i=0,i++ and j=0,j++ correspond to different nodes.
600 if (recurrent) {
601 PushToString(reinterpret_cast<uintptr_t>(recurrent->GetLoop()),
602 &hash_string);
603
604 // Recurrent expressions can't be hashed using the normal method as the
605 // order of coefficient and offset matters to the hash.
606 PushToString(reinterpret_cast<uintptr_t>(recurrent->GetCoefficient()),
607 &hash_string);
608 PushToString(reinterpret_cast<uintptr_t>(recurrent->GetOffset()),
609 &hash_string);
610
611 return std::hash<std::u32string>{}(hash_string);
612 }
613
614 // Hash the result id of the original instruction which created this node if
615 // it is a value unknown node.
616 if (node->GetType() == SENode::ValueUnknown) {
617 PushToString(node->AsSEValueUnknown()->ResultId(), &hash_string);
618 }
619
620 // Hash the pointers of the child nodes, each SENode has a unique pointer
621 // associated with it.
622 const std::vector<SENode*>& children = node->GetChildren();
623 for (const SENode* child : children) {
624 PushToString(reinterpret_cast<uintptr_t>(child), &hash_string);
625 }
626
627 return std::hash<std::u32string>{}(hash_string);
628 }
629
630 // This overload is the actual overload used by the node_cache_ set.
operator ()(const std::unique_ptr<SENode> & node) const631 size_t SENodeHash::operator()(const std::unique_ptr<SENode>& node) const {
632 return this->operator()(node.get());
633 }
634
DumpDot(std::ostream & out,bool recurse) const635 void SENode::DumpDot(std::ostream& out, bool recurse) const {
636 size_t unique_id = std::hash<const SENode*>{}(this);
637 out << unique_id << " [label=\"" << AsString() << " ";
638 if (GetType() == SENode::Constant) {
639 out << "\nwith value: " << this->AsSEConstantNode()->FoldToSingleValue();
640 }
641 out << "\"]\n";
642 for (const SENode* child : children_) {
643 size_t child_unique_id = std::hash<const SENode*>{}(child);
644 out << unique_id << " -> " << child_unique_id << " \n";
645 if (recurse) child->DumpDot(out, true);
646 }
647 }
648
649 namespace {
650 class IsGreaterThanZero {
651 public:
IsGreaterThanZero(IRContext * context)652 explicit IsGreaterThanZero(IRContext* context) : context_(context) {}
653
654 // Determine if the value of |node| is always strictly greater than zero if
655 // |or_equal_zero| is false or greater or equal to zero if |or_equal_zero| is
656 // true. It returns true is the evaluation was able to conclude something, in
657 // which case the result is stored in |result|.
658 // The algorithm work by going through all the nodes and determine the
659 // sign of each of them.
Eval(const SENode * node,bool or_equal_zero,bool * result)660 bool Eval(const SENode* node, bool or_equal_zero, bool* result) {
661 *result = false;
662 switch (Visit(node)) {
663 case Signedness::kPositiveOrNegative: {
664 return false;
665 }
666 case Signedness::kStrictlyNegative: {
667 *result = false;
668 break;
669 }
670 case Signedness::kNegative: {
671 if (!or_equal_zero) {
672 return false;
673 }
674 *result = false;
675 break;
676 }
677 case Signedness::kStrictlyPositive: {
678 *result = true;
679 break;
680 }
681 case Signedness::kPositive: {
682 if (!or_equal_zero) {
683 return false;
684 }
685 *result = true;
686 break;
687 }
688 }
689 return true;
690 }
691
692 private:
693 enum class Signedness {
694 kPositiveOrNegative, // Yield a value positive or negative.
695 kStrictlyNegative, // Yield a value strictly less than 0.
696 kNegative, // Yield a value less or equal to 0.
697 kStrictlyPositive, // Yield a value strictly greater than 0.
698 kPositive // Yield a value greater or equal to 0.
699 };
700
701 // Combine the signedness according to arithmetic rules of a given operator.
702 using Combiner = std::function<Signedness(Signedness, Signedness)>;
703
704 // Returns a functor to interpret the signedness of 2 expressions as if they
705 // were added.
GetAddCombiner() const706 Combiner GetAddCombiner() const {
707 return [](Signedness lhs, Signedness rhs) {
708 switch (lhs) {
709 case Signedness::kPositiveOrNegative:
710 break;
711 case Signedness::kStrictlyNegative:
712 if (rhs == Signedness::kStrictlyNegative ||
713 rhs == Signedness::kNegative)
714 return lhs;
715 break;
716 case Signedness::kNegative: {
717 if (rhs == Signedness::kStrictlyNegative)
718 return Signedness::kStrictlyNegative;
719 if (rhs == Signedness::kNegative) return Signedness::kNegative;
720 break;
721 }
722 case Signedness::kStrictlyPositive: {
723 if (rhs == Signedness::kStrictlyPositive ||
724 rhs == Signedness::kPositive) {
725 return Signedness::kStrictlyPositive;
726 }
727 break;
728 }
729 case Signedness::kPositive: {
730 if (rhs == Signedness::kStrictlyPositive)
731 return Signedness::kStrictlyPositive;
732 if (rhs == Signedness::kPositive) return Signedness::kPositive;
733 break;
734 }
735 }
736 return Signedness::kPositiveOrNegative;
737 };
738 }
739
740 // Returns a functor to interpret the signedness of 2 expressions as if they
741 // were multiplied.
GetMulCombiner() const742 Combiner GetMulCombiner() const {
743 return [](Signedness lhs, Signedness rhs) {
744 switch (lhs) {
745 case Signedness::kPositiveOrNegative:
746 break;
747 case Signedness::kStrictlyNegative: {
748 switch (rhs) {
749 case Signedness::kPositiveOrNegative: {
750 break;
751 }
752 case Signedness::kStrictlyNegative: {
753 return Signedness::kStrictlyPositive;
754 }
755 case Signedness::kNegative: {
756 return Signedness::kPositive;
757 }
758 case Signedness::kStrictlyPositive: {
759 return Signedness::kStrictlyNegative;
760 }
761 case Signedness::kPositive: {
762 return Signedness::kNegative;
763 }
764 }
765 break;
766 }
767 case Signedness::kNegative: {
768 switch (rhs) {
769 case Signedness::kPositiveOrNegative: {
770 break;
771 }
772 case Signedness::kStrictlyNegative:
773 case Signedness::kNegative: {
774 return Signedness::kPositive;
775 }
776 case Signedness::kStrictlyPositive:
777 case Signedness::kPositive: {
778 return Signedness::kNegative;
779 }
780 }
781 break;
782 }
783 case Signedness::kStrictlyPositive: {
784 return rhs;
785 }
786 case Signedness::kPositive: {
787 switch (rhs) {
788 case Signedness::kPositiveOrNegative: {
789 break;
790 }
791 case Signedness::kStrictlyNegative:
792 case Signedness::kNegative: {
793 return Signedness::kNegative;
794 }
795 case Signedness::kStrictlyPositive:
796 case Signedness::kPositive: {
797 return Signedness::kPositive;
798 }
799 }
800 break;
801 }
802 }
803 return Signedness::kPositiveOrNegative;
804 };
805 }
806
Visit(const SENode * node)807 Signedness Visit(const SENode* node) {
808 switch (node->GetType()) {
809 case SENode::Constant:
810 return Visit(node->AsSEConstantNode());
811 break;
812 case SENode::RecurrentAddExpr:
813 return Visit(node->AsSERecurrentNode());
814 break;
815 case SENode::Negative:
816 return Visit(node->AsSENegative());
817 break;
818 case SENode::CanNotCompute:
819 return Visit(node->AsSECantCompute());
820 break;
821 case SENode::ValueUnknown:
822 return Visit(node->AsSEValueUnknown());
823 break;
824 case SENode::Add:
825 return VisitExpr(node, GetAddCombiner());
826 break;
827 case SENode::Multiply:
828 return VisitExpr(node, GetMulCombiner());
829 break;
830 }
831 return Signedness::kPositiveOrNegative;
832 }
833
834 // Returns the signedness of a constant |node|.
Visit(const SEConstantNode * node)835 Signedness Visit(const SEConstantNode* node) {
836 if (0 == node->FoldToSingleValue()) return Signedness::kPositive;
837 if (0 < node->FoldToSingleValue()) return Signedness::kStrictlyPositive;
838 if (0 > node->FoldToSingleValue()) return Signedness::kStrictlyNegative;
839 return Signedness::kPositiveOrNegative;
840 }
841
842 // Returns the signedness of an unknown |node| based on its type.
Visit(const SEValueUnknown * node)843 Signedness Visit(const SEValueUnknown* node) {
844 Instruction* insn = context_->get_def_use_mgr()->GetDef(node->ResultId());
845 analysis::Type* type = context_->get_type_mgr()->GetType(insn->type_id());
846 assert(type && "Can't retrieve a type for the instruction");
847 analysis::Integer* int_type = type->AsInteger();
848 assert(type && "Can't retrieve an integer type for the instruction");
849 return int_type->IsSigned() ? Signedness::kPositiveOrNegative
850 : Signedness::kPositive;
851 }
852
853 // Returns the signedness of a recurring expression.
Visit(const SERecurrentNode * node)854 Signedness Visit(const SERecurrentNode* node) {
855 Signedness coeff_sign = Visit(node->GetCoefficient());
856 // SERecurrentNode represent an affine expression in the range [0,
857 // loop_bound], so the result cannot be strictly positive or negative.
858 switch (coeff_sign) {
859 default:
860 break;
861 case Signedness::kStrictlyNegative:
862 coeff_sign = Signedness::kNegative;
863 break;
864 case Signedness::kStrictlyPositive:
865 coeff_sign = Signedness::kPositive;
866 break;
867 }
868 return GetAddCombiner()(coeff_sign, Visit(node->GetOffset()));
869 }
870
871 // Returns the signedness of a negation |node|.
Visit(const SENegative * node)872 Signedness Visit(const SENegative* node) {
873 switch (Visit(*node->begin())) {
874 case Signedness::kPositiveOrNegative: {
875 return Signedness::kPositiveOrNegative;
876 }
877 case Signedness::kStrictlyNegative: {
878 return Signedness::kStrictlyPositive;
879 }
880 case Signedness::kNegative: {
881 return Signedness::kPositive;
882 }
883 case Signedness::kStrictlyPositive: {
884 return Signedness::kStrictlyNegative;
885 }
886 case Signedness::kPositive: {
887 return Signedness::kNegative;
888 }
889 }
890 return Signedness::kPositiveOrNegative;
891 }
892
Visit(const SECantCompute *)893 Signedness Visit(const SECantCompute*) {
894 return Signedness::kPositiveOrNegative;
895 }
896
897 // Returns the signedness of a binary expression by using the combiner
898 // |reduce|.
VisitExpr(const SENode * node,std::function<Signedness (Signedness,Signedness)> reduce)899 Signedness VisitExpr(
900 const SENode* node,
901 std::function<Signedness(Signedness, Signedness)> reduce) {
902 Signedness result = Visit(*node->begin());
903 for (const SENode* operand : make_range(++node->begin(), node->end())) {
904 if (result == Signedness::kPositiveOrNegative) {
905 return Signedness::kPositiveOrNegative;
906 }
907 result = reduce(result, Visit(operand));
908 }
909 return result;
910 }
911
912 IRContext* context_;
913 };
914 } // namespace
915
IsAlwaysGreaterThanZero(SENode * node,bool * is_gt_zero) const916 bool ScalarEvolutionAnalysis::IsAlwaysGreaterThanZero(SENode* node,
917 bool* is_gt_zero) const {
918 return IsGreaterThanZero(context_).Eval(node, false, is_gt_zero);
919 }
920
IsAlwaysGreaterOrEqualToZero(SENode * node,bool * is_ge_zero) const921 bool ScalarEvolutionAnalysis::IsAlwaysGreaterOrEqualToZero(
922 SENode* node, bool* is_ge_zero) const {
923 return IsGreaterThanZero(context_).Eval(node, true, is_ge_zero);
924 }
925
926 namespace {
927
928 // Remove |node| from the |mul| chain (of the form A * ... * |node| * ... * Z),
929 // if |node| is not in the chain, returns the original chain.
RemoveOneNodeFromMultiplyChain(SEMultiplyNode * mul,const SENode * node)930 SENode* RemoveOneNodeFromMultiplyChain(SEMultiplyNode* mul,
931 const SENode* node) {
932 SENode* lhs = mul->GetChildren()[0];
933 SENode* rhs = mul->GetChildren()[1];
934 if (lhs == node) {
935 return rhs;
936 }
937 if (rhs == node) {
938 return lhs;
939 }
940 if (lhs->AsSEMultiplyNode()) {
941 SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), node);
942 if (res != lhs)
943 return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs);
944 }
945 if (rhs->AsSEMultiplyNode()) {
946 SENode* res = RemoveOneNodeFromMultiplyChain(rhs->AsSEMultiplyNode(), node);
947 if (res != rhs)
948 return mul->GetParentAnalysis()->CreateMultiplyNode(res, rhs);
949 }
950
951 return mul;
952 }
953 } // namespace
954
operator /(SExpression rhs_wrapper) const955 std::pair<SExpression, int64_t> SExpression::operator/(
956 SExpression rhs_wrapper) const {
957 SENode* lhs = node_;
958 SENode* rhs = rhs_wrapper.node_;
959 // Check for division by 0.
960 if (rhs->AsSEConstantNode() &&
961 !rhs->AsSEConstantNode()->FoldToSingleValue()) {
962 return {scev_->CreateCantComputeNode(), 0};
963 }
964
965 // Trivial case.
966 if (lhs->AsSEConstantNode() && rhs->AsSEConstantNode()) {
967 int64_t lhs_value = lhs->AsSEConstantNode()->FoldToSingleValue();
968 int64_t rhs_value = rhs->AsSEConstantNode()->FoldToSingleValue();
969 return {scev_->CreateConstant(lhs_value / rhs_value),
970 lhs_value % rhs_value};
971 }
972
973 // look for a "c U / U" pattern.
974 if (lhs->AsSEMultiplyNode()) {
975 assert(lhs->GetChildren().size() == 2 &&
976 "More than 2 operand for a multiply node.");
977 SENode* res = RemoveOneNodeFromMultiplyChain(lhs->AsSEMultiplyNode(), rhs);
978 if (res != lhs) {
979 return {res, 0};
980 }
981 }
982
983 return {scev_->CreateCantComputeNode(), 0};
984 }
985
986 } // namespace opt
987 } // namespace spvtools
988