1// Copyright 2023 The Go Authors. All rights reserved. 2// Use of this source code is governed by a BSD-style 3// license that can be found in the LICENSE file. 4 5package ssa 6 7import ( 8 "fmt" 9) 10 11// ---------------------------------------------------------------------------- 12// Sparse Conditional Constant Propagation 13// 14// Described in 15// Mark N. Wegman, F. Kenneth Zadeck: Constant Propagation with Conditional Branches. 16// TOPLAS 1991. 17// 18// This algorithm uses three level lattice for SSA value 19// 20// Top undefined 21// / | \ 22// .. 1 2 3 .. constant 23// \ | / 24// Bottom not constant 25// 26// It starts with optimistically assuming that all SSA values are initially Top 27// and then propagates constant facts only along reachable control flow paths. 28// Since some basic blocks are not visited yet, corresponding inputs of phi become 29// Top, we use the meet(phi) to compute its lattice. 30// 31// Top ∩ any = any 32// Bottom ∩ any = Bottom 33// ConstantA ∩ ConstantA = ConstantA 34// ConstantA ∩ ConstantB = Bottom 35// 36// Each lattice value is lowered most twice(Top to Constant, Constant to Bottom) 37// due to lattice depth, resulting in a fast convergence speed of the algorithm. 38// In this way, sccp can discover optimization opportunities that cannot be found 39// by just combining constant folding and constant propagation and dead code 40// elimination separately. 
41 42// Three level lattice holds compile time knowledge about SSA value 43const ( 44 top int8 = iota // undefined 45 constant // constant 46 bottom // not a constant 47) 48 49type lattice struct { 50 tag int8 // lattice type 51 val *Value // constant value 52} 53 54type worklist struct { 55 f *Func // the target function to be optimized out 56 edges []Edge // propagate constant facts through edges 57 uses []*Value // re-visiting set 58 visited map[Edge]bool // visited edges 59 latticeCells map[*Value]lattice // constant lattices 60 defUse map[*Value][]*Value // def-use chains for some values 61 defBlock map[*Value][]*Block // use blocks of def 62 visitedBlock []bool // visited block 63} 64 65// sccp stands for sparse conditional constant propagation, it propagates constants 66// through CFG conditionally and applies constant folding, constant replacement and 67// dead code elimination all together. 68func sccp(f *Func) { 69 var t worklist 70 t.f = f 71 t.edges = make([]Edge, 0) 72 t.visited = make(map[Edge]bool) 73 t.edges = append(t.edges, Edge{f.Entry, 0}) 74 t.defUse = make(map[*Value][]*Value) 75 t.defBlock = make(map[*Value][]*Block) 76 t.latticeCells = make(map[*Value]lattice) 77 t.visitedBlock = f.Cache.allocBoolSlice(f.NumBlocks()) 78 defer f.Cache.freeBoolSlice(t.visitedBlock) 79 80 // build it early since we rely heavily on the def-use chain later 81 t.buildDefUses() 82 83 // pick up either an edge or SSA value from worklist, process it 84 for { 85 if len(t.edges) > 0 { 86 edge := t.edges[0] 87 t.edges = t.edges[1:] 88 if _, exist := t.visited[edge]; !exist { 89 dest := edge.b 90 destVisited := t.visitedBlock[dest.ID] 91 92 // mark edge as visited 93 t.visited[edge] = true 94 t.visitedBlock[dest.ID] = true 95 for _, val := range dest.Values { 96 if val.Op == OpPhi || !destVisited { 97 t.visitValue(val) 98 } 99 } 100 // propagates constants facts through CFG, taking condition test 101 // into account 102 if !destVisited { 103 t.propagate(dest) 104 } 105 } 
106 continue 107 } 108 if len(t.uses) > 0 { 109 use := t.uses[0] 110 t.uses = t.uses[1:] 111 t.visitValue(use) 112 continue 113 } 114 break 115 } 116 117 // apply optimizations based on discovered constants 118 constCnt, rewireCnt := t.replaceConst() 119 if f.pass.debug > 0 { 120 if constCnt > 0 || rewireCnt > 0 { 121 fmt.Printf("Phase SCCP for %v : %v constants, %v dce\n", f.Name, constCnt, rewireCnt) 122 } 123 } 124} 125 126func equals(a, b lattice) bool { 127 if a == b { 128 // fast path 129 return true 130 } 131 if a.tag != b.tag { 132 return false 133 } 134 if a.tag == constant { 135 // The same content of const value may be different, we should 136 // compare with auxInt instead 137 v1 := a.val 138 v2 := b.val 139 if v1.Op == v2.Op && v1.AuxInt == v2.AuxInt { 140 return true 141 } else { 142 return false 143 } 144 } 145 return true 146} 147 148// possibleConst checks if Value can be folded to const. For those Values that can 149// never become constants(e.g. StaticCall), we don't make futile efforts. 
func possibleConst(val *Value) bool {
	if isConst(val) {
		return true
	}
	switch val.Op {
	case OpCopy:
		return true
	case OpPhi:
		return true
	case
		// negate
		OpNeg8, OpNeg16, OpNeg32, OpNeg64, OpNeg32F, OpNeg64F,
		OpCom8, OpCom16, OpCom32, OpCom64,
		// math
		OpFloor, OpCeil, OpTrunc, OpRoundToEven, OpSqrt,
		// conversion
		OpTrunc16to8, OpTrunc32to8, OpTrunc32to16, OpTrunc64to8,
		OpTrunc64to16, OpTrunc64to32, OpCvt32to32F, OpCvt32to64F,
		OpCvt64to32F, OpCvt64to64F, OpCvt32Fto32, OpCvt32Fto64,
		OpCvt64Fto32, OpCvt64Fto64, OpCvt32Fto64F, OpCvt64Fto32F,
		OpCvtBoolToUint8,
		OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, OpZeroExt16to32,
		OpZeroExt16to64, OpZeroExt32to64, OpSignExt8to16, OpSignExt8to32,
		OpSignExt8to64, OpSignExt16to32, OpSignExt16to64, OpSignExt32to64,
		// bit
		OpCtz8, OpCtz16, OpCtz32, OpCtz64,
		// mask
		OpSlicemask,
		// safety check
		OpIsNonNil,
		// not
		OpNot:
		// 1-input operations: foldable when the single argument is constant.
		return true
	case
		// add
		OpAdd64, OpAdd32, OpAdd16, OpAdd8,
		OpAdd32F, OpAdd64F,
		// sub
		OpSub64, OpSub32, OpSub16, OpSub8,
		OpSub32F, OpSub64F,
		// mul
		OpMul64, OpMul32, OpMul16, OpMul8,
		OpMul32F, OpMul64F,
		// div
		OpDiv32F, OpDiv64F,
		OpDiv8, OpDiv16, OpDiv32, OpDiv64,
		OpDiv8u, OpDiv16u, OpDiv32u, OpDiv64u,
		// mod
		OpMod8, OpMod16, OpMod32, OpMod64,
		OpMod8u, OpMod16u, OpMod32u, OpMod64u,
		// compare
		OpEq64, OpEq32, OpEq16, OpEq8,
		OpEq32F, OpEq64F,
		OpLess64, OpLess32, OpLess16, OpLess8,
		OpLess64U, OpLess32U, OpLess16U, OpLess8U,
		OpLess32F, OpLess64F,
		OpLeq64, OpLeq32, OpLeq16, OpLeq8,
		OpLeq64U, OpLeq32U, OpLeq16U, OpLeq8U,
		OpLeq32F, OpLeq64F,
		OpEqB, OpNeqB,
		// shift
		OpLsh64x64, OpRsh64x64, OpRsh64Ux64, OpLsh32x64,
		OpRsh32x64, OpRsh32Ux64, OpLsh16x64, OpRsh16x64,
		OpRsh16Ux64, OpLsh8x64, OpRsh8x64, OpRsh8Ux64,
		// safety check
		OpIsInBounds, OpIsSliceInBounds,
		// bit
		OpAnd8, OpAnd16, OpAnd32, OpAnd64,
		OpOr8, OpOr16, OpOr32, OpOr64,
		OpXor8, OpXor16, OpXor32, OpXor64:
		// 2-input operations: foldable when both arguments are constant.
		return true
	default:
		return false
	}
}

// getLatticeCell returns the current lattice of val. Values that can never
// fold to a constant are pinned at Bottom; values not yet visited are
// optimistically assumed Top.
func (t *worklist) getLatticeCell(val *Value) lattice {
	if !possibleConst(val) {
		// they are always worst
		return lattice{bottom, nil}
	}
	lt, exist := t.latticeCells[val]
	if !exist {
		return lattice{top, nil} // optimistically for un-visited value
	}
	return lt
}

// isConst reports whether val already is one of the constant opcodes that
// sccp knows how to handle.
func isConst(val *Value) bool {
	switch val.Op {
	case OpConst64, OpConst32, OpConst16, OpConst8,
		OpConstBool, OpConst32F, OpConst64F:
		return true
	default:
		return false
	}
}

// buildDefUses builds def-use chain for some values early, because once the
// lattice of a value is changed, we need to update lattices of use. But we don't
// need all uses of it, only uses that can become constants would be added into
// re-visit worklist since no matter how many times they are revisited, uses which
// can't become constants lattice remains unchanged, i.e. Bottom.
func (t *worklist) buildDefUses() {
	for _, block := range t.f.Blocks {
		for _, val := range block.Values {
			for _, arg := range val.Args {
				// find its uses, only uses that can become constants take into account
				if possibleConst(arg) && possibleConst(val) {
					if _, exist := t.defUse[arg]; !exist {
						// pre-size with the known use count to avoid re-growth
						t.defUse[arg] = make([]*Value, 0, arg.Uses)
					}
					t.defUse[arg] = append(t.defUse[arg], val)
				}
			}
		}
		for _, ctl := range block.ControlValues() {
			// for control values that can become constants, find their use blocks
			if possibleConst(ctl) {
				t.defBlock[ctl] = append(t.defBlock[ctl], block)
			}
		}
	}
}

// addUses finds all uses of value and appends them into work list for further process
func (t *worklist) addUses(val *Value) {
	for _, use := range t.defUse[val] {
		if val == use {
			// Phi may refer to itself as uses, ignore them to avoid re-visiting phi
			// for performance reason
			continue
		}
		t.uses = append(t.uses, use)
	}
	for _, block := range t.defBlock[val] {
		// only re-propagate through blocks that have already been reached,
		// unreachable blocks are handled when their incoming edge is processed
		if t.visitedBlock[block.ID] {
			t.propagate(block)
		}
	}
}

// meet meets all of phi arguments and computes result lattice
func (t *worklist) meet(val *Value) lattice {
	optimisticLt := lattice{top, nil}
	for i := 0; i < len(val.Args); i++ {
		edge := Edge{val.Block, i}
		// If incoming edge for phi is not visited, assume top optimistically.
		// According to rules of meet:
		// 		Top ∩ any = any
		// Top participates in meet() but does not affect the result, so here
		// we will ignore Top and only take other lattices into consideration.
		if _, exist := t.visited[edge]; exist {
			lt := t.getLatticeCell(val.Args[i])
			if lt.tag == constant {
				if optimisticLt.tag == top {
					optimisticLt = lt
				} else {
					if !equals(optimisticLt, lt) {
						// ConstantA ∩ ConstantB = Bottom
						return lattice{bottom, nil}
					}
				}
			} else if lt.tag == bottom {
				// Bottom ∩ any = Bottom
				return lattice{bottom, nil}
			} else {
				// Top ∩ any = any
			}
		} else {
			// Top ∩ any = any
		}
	}

	// ConstantA ∩ ConstantA = ConstantA or Top ∩ any = any
	return optimisticLt
}

// computeLattice folds val with the given constant args by reusing the generic
// rewrite rules on a temporary value, and returns the resulting lattice.
func computeLattice(f *Func, val *Value, args ...*Value) lattice {
	// In general, we need to perform constant evaluation based on constant args:
	//
	//  res := lattice{constant, nil}
	// 	switch op {
	// 	case OpAdd16:
	//		res.val = newConst(argLt1.val.AuxInt16() + argLt2.val.AuxInt16())
	// 	case OpAdd32:
	// 		res.val = newConst(argLt1.val.AuxInt32() + argLt2.val.AuxInt32())
	//	case OpDiv8:
	//		if !isDivideByZero(argLt2.val.AuxInt8()) {
	//			res.val = newConst(argLt1.val.AuxInt8() / argLt2.val.AuxInt8())
	//		}
	//  ...
	// 	}
	//
	// However, this would create a huge switch for all opcodes that can be
	// evaluated during compile time. Moreover, some operations can be evaluated
	// only if its arguments satisfy additional conditions(e.g. divide by zero).
	// It's fragile and error-prone. We did a trick by reusing the existing rules
	// in generic rules for compile-time evaluation. But generic rules rewrite
	// original value, this behavior is undesired, because the lattice of values
	// may change multiple times, once it was rewritten, we lose the opportunity
	// to change it permanently, which can lead to errors. For example, We cannot
	// change its value immediately after visiting Phi, because some of its input
	// edges may still not be visited at this moment.
	constValue := f.newValue(val.Op, val.Type, f.Entry, val.Pos)
	constValue.AddArgs(args...)
	matched := rewriteValuegeneric(constValue)
	if matched {
		if isConst(constValue) {
			return lattice{constant, constValue}
		}
	}
	// Either we can not match generic rules for given value or it does not
	// satisfy additional constraints(e.g. divide by zero), in these cases, clean
	// up temporary value immediately in case they are not dominated by their args.
	constValue.reset(OpInvalid)
	return lattice{bottom, nil}
}

// visitValue recomputes the lattice of val from its arguments and, if the
// lattice changed, queues all of val's uses for re-visiting.
func (t *worklist) visitValue(val *Value) {
	if !possibleConst(val) {
		// fast fail for always worst Values, i.e. there is no lowering happen
		// on them, their lattices must be initially worse Bottom.
		return
	}

	oldLt := t.getLatticeCell(val)
	defer func() {
		// re-visit all uses of value if its lattice is changed
		newLt := t.getLatticeCell(val)
		if !equals(newLt, oldLt) {
			// A lattice may only ever be lowered (Top -> Constant -> Bottom);
			// raising one would break the algorithm's convergence guarantee.
			if int8(oldLt.tag) > int8(newLt.tag) {
				t.f.Fatalf("Must lower lattice\n")
			}
			t.addUses(val)
		}
	}()

	switch val.Op {
	// they are constant values, aren't they?
	case OpConst64, OpConst32, OpConst16, OpConst8,
		OpConstBool, OpConst32F, OpConst64F: //TODO: support ConstNil ConstString etc
		t.latticeCells[val] = lattice{constant, val}
	// lattice value of copy(x) actually means lattice value of (x)
	case OpCopy:
		t.latticeCells[val] = t.getLatticeCell(val.Args[0])
	// phi should be processed specially
	case OpPhi:
		t.latticeCells[val] = t.meet(val)
	// fold 1-input operations:
	case
		// negate
		OpNeg8, OpNeg16, OpNeg32, OpNeg64, OpNeg32F, OpNeg64F,
		OpCom8, OpCom16, OpCom32, OpCom64,
		// math
		OpFloor, OpCeil, OpTrunc, OpRoundToEven, OpSqrt,
		// conversion
		OpTrunc16to8, OpTrunc32to8, OpTrunc32to16, OpTrunc64to8,
		OpTrunc64to16, OpTrunc64to32, OpCvt32to32F, OpCvt32to64F,
		OpCvt64to32F, OpCvt64to64F, OpCvt32Fto32, OpCvt32Fto64,
		OpCvt64Fto32, OpCvt64Fto64, OpCvt32Fto64F, OpCvt64Fto32F,
		OpCvtBoolToUint8,
		OpZeroExt8to16, OpZeroExt8to32, OpZeroExt8to64, OpZeroExt16to32,
		OpZeroExt16to64, OpZeroExt32to64, OpSignExt8to16, OpSignExt8to32,
		OpSignExt8to64, OpSignExt16to32, OpSignExt16to64, OpSignExt32to64,
		// bit
		OpCtz8, OpCtz16, OpCtz32, OpCtz64,
		// mask
		OpSlicemask,
		// safety check
		OpIsNonNil,
		// not
		OpNot:
		lt1 := t.getLatticeCell(val.Args[0])

		if lt1.tag == constant {
			// here we take a shortcut by reusing generic rules to fold constants
			t.latticeCells[val] = computeLattice(t.f, val, lt1.val)
		} else {
			// Top or Bottom propagates through unchanged.
			t.latticeCells[val] = lattice{lt1.tag, nil}
		}
	// fold 2-input operations
	case
		// add
		OpAdd64, OpAdd32, OpAdd16, OpAdd8,
		OpAdd32F, OpAdd64F,
		// sub
		OpSub64, OpSub32, OpSub16, OpSub8,
		OpSub32F, OpSub64F,
		// mul
		OpMul64, OpMul32, OpMul16, OpMul8,
		OpMul32F, OpMul64F,
		// div
		OpDiv32F, OpDiv64F,
		OpDiv8, OpDiv16, OpDiv32, OpDiv64,
		OpDiv8u, OpDiv16u, OpDiv32u, OpDiv64u, //TODO: support div128u
		// mod
		OpMod8, OpMod16, OpMod32, OpMod64,
		OpMod8u, OpMod16u, OpMod32u, OpMod64u,
		// compare
		OpEq64, OpEq32, OpEq16, OpEq8,
		OpEq32F, OpEq64F,
		OpLess64, OpLess32, OpLess16, OpLess8,
		OpLess64U, OpLess32U, OpLess16U, OpLess8U,
		OpLess32F, OpLess64F,
		OpLeq64, OpLeq32, OpLeq16, OpLeq8,
		OpLeq64U, OpLeq32U, OpLeq16U, OpLeq8U,
		OpLeq32F, OpLeq64F,
		OpEqB, OpNeqB,
		// shift
		OpLsh64x64, OpRsh64x64, OpRsh64Ux64, OpLsh32x64,
		OpRsh32x64, OpRsh32Ux64, OpLsh16x64, OpRsh16x64,
		OpRsh16Ux64, OpLsh8x64, OpRsh8x64, OpRsh8Ux64,
		// safety check
		OpIsInBounds, OpIsSliceInBounds,
		// bit
		OpAnd8, OpAnd16, OpAnd32, OpAnd64,
		OpOr8, OpOr16, OpOr32, OpOr64,
		OpXor8, OpXor16, OpXor32, OpXor64:
		lt1 := t.getLatticeCell(val.Args[0])
		lt2 := t.getLatticeCell(val.Args[1])

		if lt1.tag == constant && lt2.tag == constant {
			// here we take a shortcut by reusing generic rules to fold constants
			t.latticeCells[val] = computeLattice(t.f, val, lt1.val, lt2.val)
		} else {
			if lt1.tag == bottom || lt2.tag == bottom {
				// Bottom ∩ any = Bottom
				t.latticeCells[val] = lattice{bottom, nil}
			} else {
				// at least one arg still Top: stay optimistic
				t.latticeCells[val] = lattice{top, nil}
			}
		}
	default:
		// Any other type of value cannot be a constant, they are always worst(Bottom)
	}
}

// propagate propagates constants facts through CFG. If the block has single successor,
// add the successor anyway. If the block has multiple successors, only add the
// branch destination corresponding to lattice value of condition value.
func (t *worklist) propagate(block *Block) {
	switch block.Kind {
	case BlockExit, BlockRet, BlockRetJmp, BlockInvalid:
		// control flow ends, do nothing then
		break
	case BlockDefer:
		// we know nothing about control flow, add all branch destinations
		t.edges = append(t.edges, block.Succs...)
	case BlockFirst:
		fallthrough // always takes the first branch
	case BlockPlain:
		t.edges = append(t.edges, block.Succs[0])
	case BlockIf, BlockJumpTable:
		cond := block.ControlValues()[0]
		condLattice := t.getLatticeCell(cond)
		if condLattice.tag == bottom {
			// we know nothing about control flow, add all branch destinations
			t.edges = append(t.edges, block.Succs...)
		} else if condLattice.tag == constant {
			// add branchIdx destinations depends on its condition
			var branchIdx int64
			if block.Kind == BlockIf {
				// for BlockIf, Succs[0] is taken when the condition is true
				branchIdx = 1 - condLattice.val.AuxInt
			} else {
				branchIdx = condLattice.val.AuxInt
			}
			t.edges = append(t.edges, block.Succs[branchIdx])
		} else {
			// condition value is not visited yet, don't propagate it now
		}
	default:
		t.f.Fatalf("All kind of block should be processed above.")
	}
}

// rewireSuccessor rewires corresponding successors according to constant value
// discovered by previous analysis. As the result, some successors become unreachable
// and thus can be removed in further deadcode phase
func rewireSuccessor(block *Block, constVal *Value) bool {
	switch block.Kind {
	case BlockIf:
		block.removeEdge(int(constVal.AuxInt))
		block.Kind = BlockPlain
		block.Likely = BranchUnknown
		block.ResetControls()
		return true
	case BlockJumpTable:
		// Remove everything but the known taken branch.
		idx := int(constVal.AuxInt)
		if idx < 0 || idx >= len(block.Succs) {
			// This can only happen in unreachable code,
			// as an invariant of jump tables is that their
			// input index is in range.
			// See issue 64826.
			return false
		}
		block.swapSuccessorsByIdx(0, idx)
		for len(block.Succs) > 1 {
			block.removeEdge(1)
		}
		block.Kind = BlockPlain
		block.Likely = BranchUnknown
		block.ResetControls()
		return true
	default:
		return false
	}
}

// replaceConst will replace non-constant values that have been proven by sccp
// to be constants. It returns the number of values replaced and the number of
// blocks whose successors were rewired.
func (t *worklist) replaceConst() (int, int) {
	constCnt, rewireCnt := 0, 0
	for val, lt := range t.latticeCells {
		if lt.tag == constant {
			if !isConst(val) {
				if t.f.pass.debug > 0 {
					fmt.Printf("Replace %v with %v\n", val.LongString(), lt.val.LongString())
				}
				val.reset(lt.val.Op)
				val.AuxInt = lt.val.AuxInt
				constCnt++
			}
			// If const value controls this block, rewires successors according to its value
			ctrlBlock := t.defBlock[val]
			for _, block := range ctrlBlock {
				if rewireSuccessor(block, lt.val) {
					rewireCnt++
					if t.f.pass.debug > 0 {
						fmt.Printf("Rewire %v %v successors\n", block.Kind, block)
					}
				}
			}
		}
	}
	return constCnt, rewireCnt
}