For "x shift y", check that y < (x's bit-size)
diff --git a/lang/check/bounds.go b/lang/check/bounds.go
index 175673a..615dfc7 100644
--- a/lang/check/bounds.go
+++ b/lang/check/bounds.go
@@ -27,6 +27,13 @@
type bounds = interval.IntRange
+var numShiftBounds = [...]bounds{
+ t.IDU8: {zero, big.NewInt(7)},
+ t.IDU16: {zero, big.NewInt(15)},
+ t.IDU32: {zero, big.NewInt(31)},
+ t.IDU64: {zero, big.NewInt(63)},
+}
+
var numTypeBounds = [...]bounds{
t.IDI8: {big.NewInt(-1 << 7), big.NewInt(1<<7 - 1)},
t.IDI16: {big.NewInt(-1 << 15), big.NewInt(1<<15 - 1)},
@@ -1391,34 +1398,37 @@
big.NewInt(0).Sub(rb[1], one),
}, nil
- case t.IDXBinaryShiftL, t.IDXBinaryTildeModShiftL:
- if lb[0].Sign() < 0 {
- return bounds{}, fmt.Errorf("check: shift op argument %q is possibly negative", lhs.Str(q.tm))
- }
- if rb[0].Sign() < 0 {
- return bounds{}, fmt.Errorf("check: shift op argument %q is possibly negative", rhs.Str(q.tm))
- }
- if rb[1].Cmp(ffff) > 0 {
- return bounds{}, fmt.Errorf("check: shift %q out of range", rhs.Str(q.tm))
- }
- nb, _ := lb.Lsh(rb)
- if op == t.IDXBinaryTildeModShiftL {
- if qid := lhs.MType().QID(); qid[0] == t.IDBase {
- b := numTypeBounds[qid[1]]
- nb[1] = min(nb[1], b[1])
+ case t.IDXBinaryShiftL, t.IDXBinaryTildeModShiftL, t.IDXBinaryShiftR:
+ shiftBounds := bounds{}
+ typeBounds := bounds{}
+ if lTyp := lhs.MType(); lTyp.IsNumType() {
+ id := int(lTyp.QID()[1])
+ if id < len(numShiftBounds) {
+ shiftBounds = numShiftBounds[id]
+ }
+ if id < len(numTypeBounds) {
+ typeBounds = numTypeBounds[id]
}
}
- return nb, nil
+ if shiftBounds[0] == nil {
+ return bounds{}, fmt.Errorf("check: shift op argument %q of type %q does not have unsigned integer type",
+ lhs.Str(q.tm), lhs.MType().Str(q.tm))
+ } else if !shiftBounds.ContainsIntRange(rb) {
+ return bounds{}, fmt.Errorf("check: shift op argument %q is outside the range %s", rhs.Str(q.tm), shiftBounds)
+ }
- case t.IDXBinaryShiftR:
- if lb[0].Sign() < 0 {
- return bounds{}, fmt.Errorf("check: shift op argument %q is possibly negative", lhs.Str(q.tm))
+ switch op {
+ case t.IDXBinaryShiftL:
+ nb, _ := lb.Lsh(rb)
+ return nb, nil
+ case t.IDXBinaryTildeModShiftL:
+ nb, _ := lb.Lsh(rb)
+ nb[1] = min(nb[1], typeBounds[1])
+ return nb, nil
+ case t.IDXBinaryShiftR:
+ nb, _ := lb.Rsh(rb)
+ return nb, nil
}
- if rb[0].Sign() < 0 {
- return bounds{}, fmt.Errorf("check: shift op argument %q is possibly negative", rhs.Str(q.tm))
- }
- nb, _ := lb.Rsh(rb)
- return nb, nil
case t.IDXBinaryAmp, t.IDXBinaryPipe, t.IDXBinaryHat:
// TODO: should type-checking ensure that bitwise ops only apply to
diff --git a/lib/interval/interval.go b/lib/interval/interval.go
index 0668ff8..a896add 100644
--- a/lib/interval/interval.go
+++ b/lib/interval/interval.go
@@ -196,12 +196,28 @@
(x[1] == nil || x[1].Sign() >= 0)
}
-// Contains returns whether x contains i.
-func (x IntRange) Contains(i *big.Int) bool {
+// ContainsInt returns whether x contains i.
+func (x IntRange) ContainsInt(i *big.Int) bool {
return (x[0] == nil || x[0].Cmp(i) <= 0) &&
(x[1] == nil || x[1].Cmp(i) >= 0)
}
+// ContainsIntRange returns whether x contains every element of y.
+//
+// It returns true if y is empty.
+func (x IntRange) ContainsIntRange(y IntRange) bool {
+ if y.Empty() {
+ return true
+ }
+ if (x[0] != nil) && (y[0] == nil || x[0].Cmp(y[0]) > 0) {
+ return false
+ }
+ if (x[1] != nil) && (y[1] == nil || x[1].Cmp(y[1]) < 0) {
+ return false
+ }
+ return true
+}
+
// Eq returns whether x equals y.
func (x IntRange) Eq(y IntRange) bool {
if xe, ye := x.Empty(), y.Empty(); xe || ye {
@@ -602,10 +618,10 @@
// The smaller of those is also a sufficient bound if that smaller
// value is contained in the other interval. For example, if both xx
// and yy can be x[0], then (x[0] | x[0]) is simply x[0].
- if x.Contains(y[0]) {
+ if x.ContainsInt(y[0]) {
return IntRange{y[0], nil}, true
}
- if y.Contains(x[0]) {
+ if y.ContainsInt(x[0]) {
return IntRange{x[0], nil}, true
}
if x[1] == nil && y[1] == nil {
diff --git a/release/c/wuffs-unsupported-snapshot.c b/release/c/wuffs-unsupported-snapshot.c
index b79925d..f52ed58 100644
--- a/release/c/wuffs-unsupported-snapshot.c
+++ b/release/c/wuffs-unsupported-snapshot.c
@@ -5872,7 +5872,8 @@
WUFFS_BASE__COROUTINE_SUSPENSION_POINT_0;
if ((self->private_impl.f_n_bits >= 8) ||
- ((self->private_impl.f_bits >> self->private_impl.f_n_bits) != 0)) {
+ ((self->private_impl.f_bits >> (self->private_impl.f_n_bits & 7)) !=
+ 0)) {
status = wuffs_deflate__error__internal_error_inconsistent_n_bits;
goto exit;
}
@@ -6545,7 +6546,7 @@
}
if ((self->private_impl.f_n_bits >= 8) ||
- ((self->private_impl.f_bits >> self->private_impl.f_n_bits) != 0)) {
+ ((self->private_impl.f_bits >> (self->private_impl.f_n_bits & 7)) != 0)) {
status = wuffs_deflate__error__internal_error_inconsistent_n_bits;
goto exit;
}
@@ -6886,7 +6887,8 @@
WUFFS_BASE__COROUTINE_SUSPENSION_POINT_0;
if ((self->private_impl.f_n_bits >= 8) ||
- ((self->private_impl.f_bits >> self->private_impl.f_n_bits) != 0)) {
+ ((self->private_impl.f_bits >> (self->private_impl.f_n_bits & 7)) !=
+ 0)) {
status = wuffs_deflate__error__internal_error_inconsistent_n_bits;
goto exit;
}
@@ -7186,7 +7188,8 @@
self->private_impl.f_bits = v_bits;
self->private_impl.f_n_bits = v_n_bits;
if ((self->private_impl.f_n_bits >= 8) ||
- ((self->private_impl.f_bits >> self->private_impl.f_n_bits) != 0)) {
+ ((self->private_impl.f_bits >> (self->private_impl.f_n_bits & 7)) !=
+ 0)) {
status = wuffs_deflate__error__internal_error_inconsistent_n_bits;
goto exit;
}
diff --git a/std/deflate/decode_deflate.wuffs b/std/deflate/decode_deflate.wuffs
index 33a549e..4244b84 100644
--- a/std/deflate/decode_deflate.wuffs
+++ b/std/deflate/decode_deflate.wuffs
@@ -242,7 +242,7 @@
// TODO: make this "if" into a function invariant?
//
// Ditto for decode_huffman_slow and decode_huffman_fast.
- if (this.n_bits >= 8) or ((this.bits >> this.n_bits) != 0) {
+ if (this.n_bits >= 8) or ((this.bits >> (this.n_bits & 7)) != 0) {
return "?internal error: inconsistent n_bits"
}
this.n_bits = 0
diff --git a/std/deflate/decode_huffman_fast.wuffs b/std/deflate/decode_huffman_fast.wuffs
index 1a5dc91..e9ee4d0 100644
--- a/std/deflate/decode_huffman_fast.wuffs
+++ b/std/deflate/decode_huffman_fast.wuffs
@@ -34,7 +34,7 @@
var hlen base.u32[..0x7FFF]
var hdist base.u32
- if (this.n_bits >= 8) or ((this.bits >> this.n_bits) != 0) {
+ if (this.n_bits >= 8) or ((this.bits >> (this.n_bits & 7)) != 0) {
return "?internal error: inconsistent n_bits"
}
diff --git a/std/deflate/decode_huffman_slow.wuffs b/std/deflate/decode_huffman_slow.wuffs
index dd59f7a..fd3fcc5 100644
--- a/std/deflate/decode_huffman_slow.wuffs
+++ b/std/deflate/decode_huffman_slow.wuffs
@@ -38,7 +38,7 @@
// decode_huffman_*.wuffs files as small as possible, while retaining both
// correctness and performance.
- if (this.n_bits >= 8) or ((this.bits >> this.n_bits) != 0) {
+ if (this.n_bits >= 8) or ((this.bits >> (this.n_bits & 7)) != 0) {
return "?internal error: inconsistent n_bits"
}
@@ -285,7 +285,7 @@
this.bits = bits
this.n_bits = n_bits
- if (this.n_bits >= 8) or ((this.bits >> this.n_bits) != 0) {
+ if (this.n_bits >= 8) or ((this.bits >> (this.n_bits & 7)) != 0) {
return "?internal error: inconsistent n_bits"
}
}