| // Copyright 2023 syzkaller project authors. All rights reserved. |
| // Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file. |
| |
| package prog |
| |
| import ( |
| "errors" |
| "fmt" |
| ) |
| |
| func (bo *BinaryExpression) Evaluate(finder ArgFinder) (uint64, bool) { |
| left, ok := bo.Left.Evaluate(finder) |
| if !ok { |
| return 0, false |
| } |
| right, ok := bo.Right.Evaluate(finder) |
| if !ok { |
| return 0, false |
| } |
| switch bo.Operator { |
| case OperatorCompareEq: |
| if left == right { |
| return 1, true |
| } |
| return 0, true |
| case OperatorCompareNeq: |
| if left != right { |
| return 1, true |
| } |
| return 0, true |
| case OperatorBinaryAnd: |
| return left & right, true |
| } |
| panic(fmt.Sprintf("unknown operator %q", bo.Operator)) |
| } |
| |
| func (v *Value) Evaluate(finder ArgFinder) (uint64, bool) { |
| if len(v.Path) == 0 { |
| return v.Value, true |
| } |
| found := finder(v.Path) |
| if found == SquashedArgFound { |
| // This is expectable. |
| return 0, false |
| } |
| if found == nil { |
| panic(fmt.Sprintf("no argument was found by %v", v.Path)) |
| } |
| constArg, ok := found.(*ConstArg) |
| if !ok { |
| panic("value expressions must only rely on int fields") |
| } |
| return constArg.Val, true |
| } |
| |
| func makeArgFinder(t *Target, c *Call, unionArg *UnionArg, parents parentStack) ArgFinder { |
| return func(path []string) Arg { |
| f := t.findArg(unionArg.Option, path, nil, nil, parents, 0) |
| if f == nil { |
| return nil |
| } |
| if f.isAnyPtr { |
| return SquashedArgFound |
| } |
| return f.arg |
| } |
| } |
| |
| func (r *randGen) patchConditionalFields(c *Call, s *state) (extra []*Call, changed bool) { |
| if r.inPatchConditional { |
| return nil, false |
| } |
| r.inPatchConditional = true |
| defer func() { r.inPatchConditional = false }() |
| |
| var extraCalls []*Call |
| var anyPatched bool |
| for { |
| replace := map[Arg]Arg{} |
| forEachStaleUnion(r.target, c, |
| func(unionArg *UnionArg, unionType *UnionType, okIndices []int) { |
| idx := okIndices[r.Intn(len(okIndices))] |
| newType, newDir := unionType.Fields[idx].Type, |
| unionType.Fields[idx].Dir(unionArg.Dir()) |
| newTypeArg, newCalls := r.generateArg(s, newType, newDir) |
| replace[unionArg] = MakeUnionArg(unionType, newDir, newTypeArg, idx) |
| extraCalls = append(extraCalls, newCalls...) |
| anyPatched = true |
| }) |
| for old, new := range replace { |
| replaceArg(old, new) |
| } |
| // The newly inserted argument might contain more arguments we need |
| // to patch. |
| // Repeat until we have to change nothing. |
| if len(replace) == 0 { |
| break |
| } |
| } |
| return extraCalls, anyPatched |
| } |
| |
| func forEachStaleUnion(target *Target, c *Call, cb func(*UnionArg, *UnionType, []int)) { |
| for _, callArg := range c.Args { |
| foreachSubArgWithStack(callArg, func(arg Arg, argCtx *ArgCtx) { |
| if target.isAnyPtr(arg.Type()) { |
| argCtx.Stop = true |
| return |
| } |
| unionArg, ok := arg.(*UnionArg) |
| if !ok { |
| return |
| } |
| unionType, ok := arg.Type().(*UnionType) |
| if !ok || !unionType.isConditional() { |
| return |
| } |
| argFinder := makeArgFinder(target, c, unionArg, argCtx.parentStack) |
| ok, calculated := checkUnionArg(unionArg.Index, unionType, argFinder) |
| if !calculated { |
| // Let it stay as is. |
| return |
| } |
| if !unionArg.transient && ok { |
| return |
| } |
| matchingIndices := matchingUnionArgs(unionType, argFinder) |
| if len(matchingIndices) == 0 { |
| // Conditional fields are transformed in such a way |
| // that one field always matches. |
| // For unions we demand that there's a field w/o conditions. |
| panic(fmt.Sprintf("no matching union fields: %#v", unionType)) |
| } |
| cb(unionArg, unionType, matchingIndices) |
| }) |
| } |
| } |
| |
| func checkUnionArg(idx int, typ *UnionType, finder ArgFinder) (ok, calculated bool) { |
| field := typ.Fields[idx] |
| if field.Condition == nil { |
| return true, true |
| } |
| val, ok := field.Condition.Evaluate(finder) |
| if !ok { |
| // We could not calculate the expression. |
| // Let the union stay as it was. |
| return true, false |
| } |
| return val != 0, true |
| } |
| |
| func matchingUnionArgs(typ *UnionType, finder ArgFinder) []int { |
| var ret []int |
| for i := range typ.Fields { |
| ok, _ := checkUnionArg(i, typ, finder) |
| if ok { |
| ret = append(ret, i) |
| } |
| } |
| return ret |
| } |
| |
| func (p *Prog) checkConditions() error { |
| for _, c := range p.Calls { |
| err := c.checkConditions(p.Target, false) |
| if err != nil { |
| return err |
| } |
| } |
| return nil |
| } |
| |
| var ErrViolatedConditions = errors.New("conditional fields rules violation") |
| |
| func (c *Call) checkConditions(target *Target, ignoreTransient bool) error { |
| var ret error |
| forEachStaleUnion(target, c, |
| func(a *UnionArg, t *UnionType, okIndices []int) { |
| if ignoreTransient && a.transient { |
| return |
| } |
| ret = fmt.Errorf("%w union %s field is #%d(%s), but %v satisfy conditions", |
| ErrViolatedConditions, t.Name(), a.Index, t.Fields[a.Index].Name, |
| okIndices) |
| }) |
| return ret |
| } |
| |
| func (c *Call) setDefaultConditions(target *Target, transientOnly bool) bool { |
| var anyReplaced bool |
| // Replace stale conditions with the default values of their correct types. |
| for { |
| replace := map[Arg]Arg{} |
| forEachStaleUnion(target, c, |
| func(unionArg *UnionArg, unionType *UnionType, okIndices []int) { |
| if transientOnly && !unionArg.transient { |
| return |
| } |
| // If several union options match, take the first one. |
| idx := okIndices[0] |
| field := unionType.Fields[idx] |
| replace[unionArg] = MakeUnionArg(unionType, |
| unionArg.Dir(), |
| field.DefaultArg(field.Dir(unionArg.Dir())), |
| idx) |
| }) |
| for old, new := range replace { |
| anyReplaced = true |
| replaceArg(old, new) |
| } |
| if len(replace) == 0 { |
| break |
| } |
| } |
| return anyReplaced |
| } |