Remove explicit check for correct number of rows
GitOrigin-RevId: 3c14a310b41ce98fe56d23085dffb47da5872f5e
Change-Id: If5fe62a00c785da6d36aeb47b116dbc0c624aea3
diff --git a/code/png.py b/code/png.py
index ad1690f..1103bc0 100755
--- a/code/png.py
+++ b/code/png.py
@@ -659,9 +659,10 @@
Write a PNG image to the output file.
`rows` should be an iterable that yields each row
(each row is a sequence of values).
- The rows should be the rows of the original image,
- so there should be ``self.height`` rows of
- ``self.width * self.planes`` values.
+ The first `self.height` rows will be used for the PNG file.
+ Extra rows are left unconsumed, but insufficient rows
+ will raise a `ProtocolError`.
+ Each row should have `self.width * self.planes` values.
"""
# Values per row
@@ -688,12 +689,7 @@
)
yield row
- nrows = self.write_passes(outfile, check_rows(rows))
- if nrows != self.height:
- raise ProtocolError(
- "rows supplied (%d) does not match height (%d)" % (nrows, self.height)
- )
- return nrows
+ return self.write_passes(outfile, check_rows(rows))
def write_passes(self, outfile, rows):
"""
@@ -757,9 +753,12 @@
data = array("B")
# raise i scope out of the for loop. set to -1, because the for loop
- # sets i to 0 on the first pass
- i = -1
- for i, row in enumerate(rows):
+ irows = iter(rows)
+ for i in range(self.height):
+ try:
+ row = next(irows)
+ except StopIteration:
+ raise ProtocolError("Not enough rows: %d supplied; %d required" % (i, self.height))
# Add "None" filter type.
# Currently, it's essential that this filter type be used
# for every scanline as
@@ -781,7 +780,6 @@
write_chunk(outfile, b"IDAT", compressed + flushed)
# https://www.w3.org/TR/PNG/#11IEND
write_chunk(outfile, b"IEND")
- return i + 1
def write_preamble(self, outfile):
# https://www.w3.org/TR/PNG/#5PNG-file-signature
diff --git a/code/test_png.py b/code/test_png.py
index 307dd4b..52e83db 100644
--- a/code/test_png.py
+++ b/code/test_png.py
@@ -363,36 +363,13 @@
w.write(o, rows)
def test_write_empty(self):
- """Test writing an empty file expecting an error."""
+ """Test writing no rows raise expected error."""
w = png.Writer(1, 1)
o = BytesIO()
empty = []
- with self.assertRaises(png.ProtocolError) as cm:
- try:
- w.write(o, empty)
- except UnboundLocalError as e:
- """
- Protect against:
- File "test_png.py", line 399, in test_write_empty
- w.write(o, empty)
- UnboundLocalError: local variable 'i' referenced before
- assignment
- """
- self.fail("UnexpectedLocalError exception: {}".format(e))
-
- self.assertEqual(
- str(cm.exception),
- "ProtocolError: rows supplied (0) does not match height (1)",
- )
-
- def test_write_length(self):
- """Test row length is returned from Writer.write()"""
- w = png.Writer(1, 2)
- o = BytesIO()
- rows = [[1], [1]]
- row_count = w.write(o, rows)
- self.assertEqual(row_count, 2)
+ with self.assertRaises(png.ProtocolError):
+ w.write(o, empty)
def test_write_background_colour(self):
"""Test that background keyword works."""
@@ -1094,14 +1071,6 @@
png.ProtocolError, png.Writer, 1, 4, bitdepth=2, palette=[a, b, c]
)
- def test_wrong_rows(self):
- """
- Wrong number of rows.
- """
- rows = [[0, 0xAA], [0x55, 0xFF]]
- w = png.Writer(size=(2, 1))
- self.assertRaises(png.ProtocolError, w.write, BytesIO(), rows)
-
def test_write_palette_bad_fraction(self):
"""Palette with fractions should raise error."""
a = (255, 255, 255, 0.9)