Fix 'from __future__ import foo' and source file coding directives. (#6)
diff --git a/compiler/python_archive.py b/compiler/python_archive.py
index db66a0b..1eb9f5a 100755
--- a/compiler/python_archive.py
+++ b/compiler/python_archive.py
@@ -32,6 +32,7 @@
import logging
import os
import pkgutil
+import re
import sys
import tempfile
import zipfile
@@ -41,7 +42,7 @@
from subpar.compiler import stored_resource
# Boilerplate code added to __main__.py
-_main_template = """\
+_boilerplate_template = """\
# Boilerplate added by subpar/compiler/python_archive.py
from %(runtime_package)s import support as _
_.setup(import_roots=%(import_roots)s)
@@ -127,25 +128,51 @@
output_dir = os.path.dirname(self.output_filename)
return tempfile.NamedTemporaryFile(dir=output_dir, delete=False)
- def generate_main(self):
+ def generate_boilerplate(self):
+ """Generate boilerplate to be insert into __main__.py
+
+ We don't know the encoding of the main source file, so
+ require that the template be pure ascii, which we can safely
+ insert.
+
+ Returns:
+ A string containing only ascii characters
+ """
+ boilerplate_contents = _boilerplate_template % {
+ 'runtime_package': _runtime_package,
+ 'import_roots': str(self.import_roots),
+ }
+ return boilerplate_contents.encode('ascii').decode('ascii')
+
+ def generate_main(self, main_filename, boilerplate_contents):
"""Generate the contents of the __main__.py file
We take the module that is specified as the main entry point,
- and prepend some boilerplate to invoke import helper code.
+ and insert some boilerplate to invoke import helper code.
Returns:
A StoredResource
"""
- template_contents = _main_template % {
- 'runtime_package': _runtime_package,
- 'import_roots': str(self.import_roots),
- }
- with open(self.main_filename, 'rb') as main_file:
- main_contents = main_file.read()
- # We don't know the encoding of the main source file, so
- # require that the template be pure ascii, which we can safely
- # prepend.
- contents = template_contents.encode('ascii') + main_contents
+ # Read main source file, in unknown encoding. We use latin-1
+ # here, but any single-byte encoding that doesn't raise errors
+ # would work.
+ output_lines = []
+ with io.open(main_filename, 'rt', encoding='latin-1') as main_file:
+ output_lines = list(main_file)
+
+ # Find a good place to insert the boilerplate, which is the
+ # first line that is not a comment, blank line, or future
+ # import.
+ skip_regex = re.compile('''(#.*)|(\\s+)|(from\\s+__future__\\s+import)''')
+ idx = 0
+ while idx < len(output_lines):
+ if not skip_regex.match(output_lines[idx]):
+ break
+ idx += 1
+
+ # Insert boilerplate (might be beginning, middle or end)
+ output_lines[idx:idx] = [boilerplate_contents]
+ contents = ''.join(output_lines).encode('latin-1')
return stored_resource.StoredContent('__main__.py', contents)
def scan_manifest(self, manifest):
@@ -177,7 +204,8 @@
('Configuration error for [%s]: Manifest file included a '
'file named __main__.py, which is not allowed') %
self.manifest_filename)
- stored_resources['__main__.py'] = self.generate_main()
+ stored_resources['__main__.py'] = self.generate_main(
+ self.main_filename, self.generate_boilerplate())
# Add an __init__.py for each parent package of the support files
for stored_filename in _runtime_init_files:
diff --git a/compiler/python_archive_test.py b/compiler/python_archive_test.py
index aa7922e..c1974d1 100644
--- a/compiler/python_archive_test.py
+++ b/compiler/python_archive_test.py
@@ -122,6 +122,36 @@
# t closed but not deleted
self.assertTrue(os.path.exists(t.name))
+ def test_generate_boilerplate(self):
+ par = self._construct()
+ boilerplate = par.generate_boilerplate()
+ self.assertIn('Boilerplate', boilerplate)
+
+ def test_generate_main(self):
+ par = self._construct()
+ boilerplate = 'BOILERPLATE\n'
+ cases = [
+ # Insert at beginning
+ (b'spam = eggs\n',
+ b'BOILERPLATE\nspam = eggs\n'),
+ # Insert in the middle
+ (b'# a comment\nspam = eggs\n',
+ b'# a comment\nBOILERPLATE\nspam = eggs\n'),
+ # Insert after the end
+ (b'# a comment\n',
+ b'# a comment\nBOILERPLATE\n'),
+ # Blank lines
+ (b'\n \t\n',
+ b'\n \t\nBOILERPLATE\n'),
+ # Future import
+ (b'from __future__ import print_function\n',
+ b'from __future__ import print_function\nBOILERPLATE\n'),
+ ]
+ for main_content, expected in cases:
+ with test_utils.temp_file(main_content) as main_file:
+ actual = par.generate_main(main_file.name, boilerplate)
+ self.assertEqual(expected, actual.content)
+
def test_scan_manifest(self):
par = self._construct()
manifest = {'foo.py': '/something/foo.py', 'bar.py': None,}
diff --git a/compiler/test_utils.py b/compiler/test_utils.py
index 62a16c4..d94e0f6 100644
--- a/compiler/test_utils.py
+++ b/compiler/test_utils.py
@@ -15,7 +15,6 @@
"""Common test utilities"""
import os
-import sys
import tempfile
diff --git a/tests/BUILD b/tests/BUILD
index 43d3c3e..c09a933 100644
--- a/tests/BUILD
+++ b/tests/BUILD
@@ -72,6 +72,17 @@
srcs_version = "PY2AND3",
)
+[par_binary(
+ name = "package_g/g_%s" % version,
+ srcs = ["package_g/g.py"],
+ default_python_version = version,
+ main = "package_g/g.py",
+ srcs_version = "PY2AND3",
+) for version in [
+ "PY2",
+ "PY3",
+]]
+
# Test targets
[sh_test(
name = "%s_%s" % (name, version),
@@ -92,6 +103,7 @@
("import_root_test", ":package_d/d", "/package_d/d"),
("external_workspace_test", ":package_e/e", "/package_e/e"),
("version_test", ":package_f/f", "/package_f/f"),
+ ("main_boilerplate_test", ":package_g/g", "/package_g/g"),
] for version in [
"PY2",
"PY3",
diff --git a/tests/package_g/g.py b/tests/package_g/g.py
new file mode 100755
index 0000000..2006319
--- /dev/null
+++ b/tests/package_g/g.py
@@ -0,0 +1,20 @@
+#!/usr/bin/env python3
+# -*- coding: latin-1
+
+# Test __future__ imports
+from __future__ import print_function
+
+"""Integration test program G for Subpar.
+
+Test bootstrap interaction with __future__ imports and source file encodings.
+"""
+
+# Test the source file encoding specification above. See PEP 263 for
+# details. In the line below, this source file contains a byte
+# sequence that is valid latin-1 but not valid utf-8. Specifically,
+# between the two single quotes is a single byte 0xE4 (latin-1
+# encoding of LATIN SMALL LETTER A WITH DIAERESIS), and _not_ the
+# two-byte UTF-8 sequence 0xC3 0xA4.
+latin_1_bytes = u'ä'
+assert len(latin_1_bytes) == 1
+assert ord(latin_1_bytes[0]) == 0xE4
diff --git a/tests/package_g/g_PY2_filelist.txt b/tests/package_g/g_PY2_filelist.txt
new file mode 100644
index 0000000..d495021
--- /dev/null
+++ b/tests/package_g/g_PY2_filelist.txt
@@ -0,0 +1,8 @@
+__main__.py
+subpar/__init__.py
+subpar/runtime/__init__.py
+subpar/runtime/support.py
+subpar/tests/__init__.py
+subpar/tests/package_g/__init__.py
+subpar/tests/package_g/g.py
+subpar/tests/package_g/g_PY2
diff --git a/tests/package_g/g_PY3_filelist.txt b/tests/package_g/g_PY3_filelist.txt
new file mode 100644
index 0000000..028bf66
--- /dev/null
+++ b/tests/package_g/g_PY3_filelist.txt
@@ -0,0 +1,8 @@
+__main__.py
+subpar/__init__.py
+subpar/runtime/__init__.py
+subpar/runtime/support.py
+subpar/tests/__init__.py
+subpar/tests/package_g/__init__.py
+subpar/tests/package_g/g.py
+subpar/tests/package_g/g_PY3