blob: cfe6eb70bd108971d9ebefb75957ec798596f538 [file] [log] [blame] [edit]
// Copyright 2023 syzkaller project authors. All rights reserved.
// Use of this source code is governed by Apache 2 LICENSE that can be found in the LICENSE file.
package prog
import (
"errors"
"fmt"
)
func (bo *BinaryExpression) Evaluate(finder ArgFinder) (uint64, bool) {
left, ok := bo.Left.Evaluate(finder)
if !ok {
return 0, false
}
right, ok := bo.Right.Evaluate(finder)
if !ok {
return 0, false
}
switch bo.Operator {
case OperatorCompareEq:
if left == right {
return 1, true
}
return 0, true
case OperatorCompareNeq:
if left != right {
return 1, true
}
return 0, true
case OperatorBinaryAnd:
return left & right, true
}
panic(fmt.Sprintf("unknown operator %q", bo.Operator))
}
func (v *Value) Evaluate(finder ArgFinder) (uint64, bool) {
if len(v.Path) == 0 {
return v.Value, true
}
found := finder(v.Path)
if found == SquashedArgFound {
// This is expectable.
return 0, false
}
if found == nil {
panic(fmt.Sprintf("no argument was found by %v", v.Path))
}
constArg, ok := found.(*ConstArg)
if !ok {
panic("value expressions must only rely on int fields")
}
return constArg.Val, true
}
func makeArgFinder(t *Target, c *Call, unionArg *UnionArg, parents parentStack) ArgFinder {
return func(path []string) Arg {
f := t.findArg(unionArg.Option, path, nil, nil, parents, 0)
if f == nil {
return nil
}
if f.isAnyPtr {
return SquashedArgFound
}
return f.arg
}
}
func (r *randGen) patchConditionalFields(c *Call, s *state) (extra []*Call, changed bool) {
if r.inPatchConditional {
return nil, false
}
r.inPatchConditional = true
defer func() { r.inPatchConditional = false }()
var extraCalls []*Call
var anyPatched bool
for {
replace := map[Arg]Arg{}
forEachStaleUnion(r.target, c,
func(unionArg *UnionArg, unionType *UnionType, okIndices []int) {
idx := okIndices[r.Intn(len(okIndices))]
newType, newDir := unionType.Fields[idx].Type,
unionType.Fields[idx].Dir(unionArg.Dir())
newTypeArg, newCalls := r.generateArg(s, newType, newDir)
replace[unionArg] = MakeUnionArg(unionType, newDir, newTypeArg, idx)
extraCalls = append(extraCalls, newCalls...)
anyPatched = true
})
for old, new := range replace {
replaceArg(old, new)
}
// The newly inserted argument might contain more arguments we need
// to patch.
// Repeat until we have to change nothing.
if len(replace) == 0 {
break
}
}
return extraCalls, anyPatched
}
func forEachStaleUnion(target *Target, c *Call, cb func(*UnionArg, *UnionType, []int)) {
for _, callArg := range c.Args {
foreachSubArgWithStack(callArg, func(arg Arg, argCtx *ArgCtx) {
if target.isAnyPtr(arg.Type()) {
argCtx.Stop = true
return
}
unionArg, ok := arg.(*UnionArg)
if !ok {
return
}
unionType, ok := arg.Type().(*UnionType)
if !ok || !unionType.isConditional() {
return
}
argFinder := makeArgFinder(target, c, unionArg, argCtx.parentStack)
ok, calculated := checkUnionArg(unionArg.Index, unionType, argFinder)
if !calculated {
// Let it stay as is.
return
}
if !unionArg.transient && ok {
return
}
matchingIndices := matchingUnionArgs(unionType, argFinder)
if len(matchingIndices) == 0 {
// Conditional fields are transformed in such a way
// that one field always matches.
// For unions we demand that there's a field w/o conditions.
panic(fmt.Sprintf("no matching union fields: %#v", unionType))
}
cb(unionArg, unionType, matchingIndices)
})
}
}
func checkUnionArg(idx int, typ *UnionType, finder ArgFinder) (ok, calculated bool) {
field := typ.Fields[idx]
if field.Condition == nil {
return true, true
}
val, ok := field.Condition.Evaluate(finder)
if !ok {
// We could not calculate the expression.
// Let the union stay as it was.
return true, false
}
return val != 0, true
}
func matchingUnionArgs(typ *UnionType, finder ArgFinder) []int {
var ret []int
for i := range typ.Fields {
ok, _ := checkUnionArg(i, typ, finder)
if ok {
ret = append(ret, i)
}
}
return ret
}
func (p *Prog) checkConditions() error {
for _, c := range p.Calls {
err := c.checkConditions(p.Target, false)
if err != nil {
return err
}
}
return nil
}
var ErrViolatedConditions = errors.New("conditional fields rules violation")
func (c *Call) checkConditions(target *Target, ignoreTransient bool) error {
var ret error
forEachStaleUnion(target, c,
func(a *UnionArg, t *UnionType, okIndices []int) {
if ignoreTransient && a.transient {
return
}
ret = fmt.Errorf("%w union %s field is #%d(%s), but %v satisfy conditions",
ErrViolatedConditions, t.Name(), a.Index, t.Fields[a.Index].Name,
okIndices)
})
return ret
}
func (c *Call) setDefaultConditions(target *Target, transientOnly bool) bool {
var anyReplaced bool
// Replace stale conditions with the default values of their correct types.
for {
replace := map[Arg]Arg{}
forEachStaleUnion(target, c,
func(unionArg *UnionArg, unionType *UnionType, okIndices []int) {
if transientOnly && !unionArg.transient {
return
}
// If several union options match, take the first one.
idx := okIndices[0]
field := unionType.Fields[idx]
replace[unionArg] = MakeUnionArg(unionType,
unionArg.Dir(),
field.DefaultArg(field.Dir(unionArg.Dir())),
idx)
})
for old, new := range replace {
anyReplaced = true
replaceArg(old, new)
}
if len(replace) == 0 {
break
}
}
return anyReplaced
}