Add a threshold program

GitOrigin-RevId: 472df66d6a0ee76530c6698fa75ec28e2dd28566
Change-Id: I075a4d25cff1804a75561635560a018c37553df1
diff --git a/code/prithreshpng b/code/prithreshpng
new file mode 100755
index 0000000..9803d53
--- /dev/null
+++ b/code/prithreshpng
@@ -0,0 +1,105 @@
+#!/usr/bin/env python
+
+# prithreshpng
+# Threshold one image against another.
+# Output a 1-bit per channel PNG that is 1 where left <= right, else 0.
+
+import collections
+
+from array import array
+
+import png
+
+"""
+prithreshpng file1.png file2.png
+
+The `prithreshpng` tool compares channels from the input images.
+"""
+
+Image = collections.namedtuple("Image", "rows info")
+
+
+class Error(Exception):
+    """Base class for errors raised by prithreshpng."""
+
+class ImageError(Error):
+    """Raised when the input images cannot be compared with each other."""
+
+
+def thresh(out, args):
+    """Compare input PNG files and threshold the right image
+    against the left;
+    the output image is 1 when l <= r, that is when the right
+    image is at least as bright as the left.
+    """
+
+    paths = args.input
+
+    if len(paths) != 2:
+        raise Error("Required input is missing.")
+    
+    images = []
+
+    for image_index, path in enumerate(paths):
+        inp = png.cli_open(path)
+        rows, info = png.Reader(file=inp).asDirect()[2:]
+        rows = list(rows)
+        image = Image(rows, info)
+        images.append(image)
+
+    planes = images[0].info["planes"]
+    size = images[0].info["size"]
+    for image in images:
+        if image.info["planes"] != planes:
+            raise ImageError("All images should have same number of channels")
+        if image.info["size"] != size:
+            raise ImageError("All images should have same size")
+
+    size = images[0].info["size"]
+    out_channels = planes
+
+    # Values per row, of output image
+    vpr = out_channels * size[0]
+
+    def thresh_row_iter():
+        """
+        Yield each woven row in turn.
+        """
+        # The zip call creates an iterator that yields
+        # a tuple with each element containing the next row
+        # for each of the input images.
+        for row_tuple in zip(*(image.rows for image in images)):
+            # Compare values pairwise
+            vs = zip(*row_tuple)
+            # output row
+            row = array("B", [v[0] <= v[1] for v in vs])
+            yield row
+
+    w = png.Writer(
+        size[0],
+        size[1],
+        greyscale=True,
+        alpha=False,
+        bitdepth=1,
+    )
+    w.write(out, thresh_row_iter())
+
+
+def main(argv=None):
+    import argparse
+    import itertools
+    import sys
+
+    if argv is None:
+        argv = sys.argv
+    argv = argv[1:]
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input", nargs=2)
+    args = parser.parse_args(argv)
+
+    return thresh(png.binary_stdout(), args)
+
+
+if __name__ == "__main__":
+    main()  # Run only when executed as a script, not when imported.