| //===- switches.go - switch IR generation ---------------------------------===// |
| // |
| // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| // See https://llvm.org/LICENSE.txt for license information. |
| // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| // |
| //===----------------------------------------------------------------------===// |
| // |
| // This file implements transformations and IR generation for switches. |
| // |
| //===----------------------------------------------------------------------===// |
| |
| package irgen |
| |
| import ( |
| "go/token" |
| |
| "llvm.org/llgo/third_party/gotools/go/exact" |
| "llvm.org/llgo/third_party/gotools/go/ssa" |
| "llvm.org/llgo/third_party/gotools/go/ssa/ssautil" |
| "llvm.org/llvm/bindings/go/llvm" |
| ) |
| |
| // switchInstr is an instruction representing a switch on constant |
| // integer values. |
| type switchInstr struct { |
| ssa.Instruction |
| ssautil.Switch |
| } |
| |
| func (sw *switchInstr) String() string { |
| return sw.Switch.String() |
| } |
| |
| func (sw *switchInstr) Parent() *ssa.Function { |
| return sw.Default.Instrs[0].Parent() |
| } |
| |
| func (sw *switchInstr) Block() *ssa.BasicBlock { |
| return sw.Start |
| } |
| |
| func (sw *switchInstr) Operands(rands []*ssa.Value) []*ssa.Value { |
| return nil |
| } |
| |
| func (sw *switchInstr) Pos() token.Pos { |
| return token.NoPos |
| } |
| |
| // emitSwitch emits an LLVM switch instruction. |
| func (fr *frame) emitSwitch(instr *switchInstr) { |
| cases, _ := dedupConstCases(fr, instr.ConstCases) |
| ncases := len(cases) |
| elseblock := fr.block(instr.Default) |
| llswitch := fr.builder.CreateSwitch(fr.llvmvalue(instr.X), elseblock, ncases) |
| for _, c := range cases { |
| llswitch.AddCase(fr.llvmvalue(c.Value), fr.block(c.Body)) |
| } |
| } |
| |
| // transformSwitches replaces the final If statement in start blocks |
| // with a high-level switch instruction, and erases chained condition |
| // blocks. |
| func (fr *frame) transformSwitches(f *ssa.Function) { |
| for _, sw := range ssautil.Switches(f) { |
| if sw.ConstCases == nil { |
| // TODO(axw) investigate switch |
| // on hashes in type switches. |
| continue |
| } |
| if !isInteger(sw.X.Type()) && !isBoolean(sw.X.Type()) { |
| // LLVM switches can only operate on integers. |
| continue |
| } |
| instr := &switchInstr{Switch: sw} |
| sw.Start.Instrs[len(sw.Start.Instrs)-1] = instr |
| for _, c := range sw.ConstCases[1:] { |
| fr.blocks[c.Block.Index].EraseFromParent() |
| fr.blocks[c.Block.Index] = llvm.BasicBlock{} |
| } |
| |
| // Fix predecessors in successor blocks for fixupPhis. |
| cases, duplicates := dedupConstCases(fr, instr.ConstCases) |
| for _, c := range cases { |
| for _, succ := range c.Block.Succs { |
| for i, pred := range succ.Preds { |
| if pred == c.Block { |
| succ.Preds[i] = sw.Start |
| break |
| } |
| } |
| } |
| } |
| |
| // Remove redundant edges corresponding to duplicate cases |
| // that will not feature in the LLVM switch instruction. |
| for _, c := range duplicates { |
| for _, succ := range c.Block.Succs { |
| for i, pred := range succ.Preds { |
| if pred == c.Block { |
| head := succ.Preds[:i] |
| tail := succ.Preds[i+1:] |
| succ.Preds = append(head, tail...) |
| removePhiEdge(succ, i) |
| break |
| } |
| } |
| } |
| } |
| } |
| } |
| |
| // dedupConstCases separates duplicate const cases. |
| // |
| // TODO(axw) fix this in go/ssa/ssautil. |
| func dedupConstCases(fr *frame, in []ssautil.ConstCase) (unique, duplicates []ssautil.ConstCase) { |
| unique = make([]ssautil.ConstCase, 0, len(in)) |
| dedup: |
| for i, c1 := range in { |
| for _, c2 := range in[i+1:] { |
| if exact.Compare(c1.Value.Value, token.EQL, c2.Value.Value) { |
| duplicates = append(duplicates, c1) |
| continue dedup |
| } |
| } |
| unique = append(unique, c1) |
| } |
| return unique, duplicates |
| } |
| |
| // removePhiEdge removes the i'th edge from each PHI |
| // instruction in the specified basic block. |
| func removePhiEdge(bb *ssa.BasicBlock, i int) { |
| for _, instr := range bb.Instrs { |
| instr, ok := instr.(*ssa.Phi) |
| if !ok { |
| return |
| } |
| head := instr.Edges[:i] |
| tail := instr.Edges[i+1:] |
| instr.Edges = append(head, tail...) |
| } |
| } |