Add a threshold program
GitOrigin-RevId: 472df66d6a0ee76530c6698fa75ec28e2dd28566
Change-Id: I075a4d25cff1804a75561635560a018c37553df1
diff --git a/code/prithreshpng b/code/prithreshpng
new file mode 100755
index 0000000..9803d53
--- /dev/null
+++ b/code/prithreshpng
@@ -0,0 +1,105 @@
+#!/usr/bin/env python
+
+"""
+prithreshpng file1.png file2.png
+
+Threshold one image against another.
+
+The `prithreshpng` tool compares channels from the input images and
+writes a 1-bit per channel PNG to standard output.  Each output value is
+1 where l <= r, that is where the value from file2.png (right) is at
+least as large as the value from file1.png (left), and 0 otherwise.
+"""
+
+import collections
+
+from array import array
+
+import png
+
+# One decoded input image: its rows and the info dict from png.Reader.asDirect().
+Image = collections.namedtuple("Image", ["rows", "info"])
+
+
+class Error(Exception):
+    """Base class for the errors prithreshpng raises."""
+
+
+class ImageError(Error):
+    """The input images cannot be compared with each other."""
+
+
+def thresh(out, args):
+    """Threshold the right image against the left and write a PNG to *out*.
+
+    The left image is ``args.input[0]`` and the right image is
+    ``args.input[1]``.  Each output channel value is 1 when l <= r, that
+    is when the right image is at least as bright as the left, and 0
+    otherwise.  Inputs of different bit depths are compared on a common
+    scale.  The output keeps the size and channel layout of the inputs
+    and is written with a bit depth of 1.
+
+    :param out: binary file-like object that receives the PNG.
+    :param args: parsed arguments whose ``input`` attribute holds exactly
+        two paths, each opened with ``png.cli_open``.
+    :raises Error: if ``args.input`` does not hold exactly two paths.
+    :raises ImageError: if the images differ in size or number of channels.
+    """
+
+    paths = args.input
+
+    if len(paths) != 2:
+        raise Error("Required input is missing.")
+
+    images = []
+    for path in paths:
+        inp = png.cli_open(path)
+        rows, info = png.Reader(file=inp).asDirect()[2:]
+        images.append(Image(list(rows), info))
+
+    planes = images[0].info["planes"]
+    size = images[0].info["size"]
+    for image in images:
+        if image.info["planes"] != planes:
+            raise ImageError("All images should have same number of channels")
+        if image.info["size"] != size:
+            raise ImageError("All images should have same size")
+
+    left, right = images
+    # Largest sample value of each input.  The bit depths need not match
+    # (for example an 8-bit image against a 16-bit one).
+    left_max = 2 ** left.info["bitdepth"] - 1
+    right_max = 2 ** right.info["bitdepth"] - 1
+
+    def thresh_row_iter():
+        """Yield each output row in turn."""
+        for left_row, right_row in zip(left.rows, right.rows):
+            # Tests l / left_max <= r / right_max, cross-multiplied so the
+            # comparison stays in exact integer arithmetic.  With equal
+            # bit depths this is simply l <= r.
+            yield array(
+                "B",
+                [
+                    lv * right_max <= rv * left_max
+                    for lv, rv in zip(left_row, right_row)
+                ],
+            )
+
+    # Each row holds planes * width values, so the writer's colour type
+    # must match the input channel count:
+    # 1 = grey, 2 = grey + alpha, 3 = RGB, 4 = RGBA.
+    w = png.Writer(
+        size[0],
+        size[1],
+        greyscale=planes < 3,
+        alpha=planes in (2, 4),
+        bitdepth=1,
+    )
+    w.write(out, thresh_row_iter())
+
+
+def main(argv=None):
+    """Run prithreshpng from the command line.
+
+    Usage: ``prithreshpng file1.png file2.png``.  The thresholded PNG is
+    written to binary standard output.
+
+    :param argv: full argument vector with the program name first;
+        defaults to ``sys.argv``.
+    """
+    import argparse
+    import sys
+
+    if argv is None:
+        argv = sys.argv
+    argv = argv[1:]
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("input", nargs=2)
+    args = parser.parse_args(argv)
+
+    try:
+        return thresh(png.binary_stdout(), args)
+    except Error as e:
+        # Incompatible inputs are user errors: report them as a one-line
+        # message on stderr with exit status 1 instead of a traceback.
+        sys.exit("prithreshpng: {}".format(e))
+
+
+if __name__ == "__main__":
+    # Script entry point.  main() falls back to sys.argv when argv is None.
+    main(argv=None)