Remove explicit check for correct number of rows

GitOrigin-RevId: 3c14a310b41ce98fe56d23085dffb47da5872f5e
Change-Id: If5fe62a00c785da6d36aeb47b116dbc0c624aea3
diff --git a/code/png.py b/code/png.py
index ad1690f..1103bc0 100755
--- a/code/png.py
+++ b/code/png.py
@@ -659,9 +659,10 @@
         Write a PNG image to the output file.
         `rows` should be an iterable that yields each row
         (each row is a sequence of values).
-        The rows should be the rows of the original image,
-        so there should be ``self.height`` rows of
-        ``self.width * self.planes`` values.
+        The first `self.height` rows will be used for the PNG file.
+        Extra rows are left unconsumed, but insufficient rows
+        will raise a `ProtocolError`.
+        Each row should have `self.width * self.planes` values.
         """
 
         # Values per row
@@ -688,12 +689,7 @@
                     )
                 yield row
 
-        nrows = self.write_passes(outfile, check_rows(rows))
-        if nrows != self.height:
-            raise ProtocolError(
-                "rows supplied (%d) does not match height (%d)" % (nrows, self.height)
-            )
-        return nrows
+        return self.write_passes(outfile, check_rows(rows))
 
     def write_passes(self, outfile, rows):
         """
@@ -757,9 +753,12 @@
         data = array("B")
 
         # raise i scope out of the for loop. set to -1, because the for loop
-        # sets i to 0 on the first pass
-        i = -1
-        for i, row in enumerate(rows):
+        irows = iter(rows)
+        for i in range(self.height):
+            try:
+                row = next(irows)
+            except StopIteration:
+                raise ProtocolError("Not enough rows: %d supplied; %d required" % (i, self.height))
             # Add "None" filter type.
             # Currently, it's essential that this filter type be used
             # for every scanline as
@@ -781,7 +780,6 @@
             write_chunk(outfile, b"IDAT", compressed + flushed)
         # https://www.w3.org/TR/PNG/#11IEND
         write_chunk(outfile, b"IEND")
-        return i + 1
 
     def write_preamble(self, outfile):
         # https://www.w3.org/TR/PNG/#5PNG-file-signature
diff --git a/code/test_png.py b/code/test_png.py
index 307dd4b..52e83db 100644
--- a/code/test_png.py
+++ b/code/test_png.py
@@ -363,36 +363,13 @@
             w.write(o, rows)
 
     def test_write_empty(self):
-        """Test writing an empty file expecting an error."""
+        """Test writing no rows raise expected error."""
         w = png.Writer(1, 1)
         o = BytesIO()
         empty = []
 
-        with self.assertRaises(png.ProtocolError) as cm:
-            try:
-                w.write(o, empty)
-            except UnboundLocalError as e:
-                """
-                Protect against:
-                File "test_png.py", line 399, in test_write_empty
-                w.write(o, empty)
-                UnboundLocalError: local variable 'i' referenced before
-                assignment
-                """
-                self.fail("UnexpectedLocalError exception: {}".format(e))
-
-        self.assertEqual(
-            str(cm.exception),
-            "ProtocolError: rows supplied (0) does not match height (1)",
-        )
-
-    def test_write_length(self):
-        """Test row length is returned from Writer.write()"""
-        w = png.Writer(1, 2)
-        o = BytesIO()
-        rows = [[1], [1]]
-        row_count = w.write(o, rows)
-        self.assertEqual(row_count, 2)
+        with self.assertRaises(png.ProtocolError):
+            w.write(o, empty)
 
     def test_write_background_colour(self):
         """Test that background keyword works."""
@@ -1094,14 +1071,6 @@
             png.ProtocolError, png.Writer, 1, 4, bitdepth=2, palette=[a, b, c]
         )
 
-    def test_wrong_rows(self):
-        """
-        Wrong number of rows.
-        """
-        rows = [[0, 0xAA], [0x55, 0xFF]]
-        w = png.Writer(size=(2, 1))
-        self.assertRaises(png.ProtocolError, w.write, BytesIO(), rows)
-
     def test_write_palette_bad_fraction(self):
         """Palette with fractions should raise error."""
         a = (255, 255, 255, 0.9)