Add include_error_handler parameter.

This works like error_handler but is specific to a template
when included in another using the include tag.

Change-Id: Ie5506a8cba42c71519c703eacc82050902b9ceba
Pull-request: https://bitbucket.org/zzzeek/mako/pull-requests/22
diff --git a/doc/build/changelog.rst b/doc/build/changelog.rst
index d01e91d..a1f743d 100644
--- a/doc/build/changelog.rst
+++ b/doc/build/changelog.rst
@@ -8,6 +8,15 @@
 .. changelog::
     :version: 1.0.6
 
+    .. change::
+        :tags: feature
+
+      Added new parameter :paramref:`.Template.include_error_handler` .
+      This works like :paramref:`.Template.error_handler` but indicates the
+      handler should take place when this template is included within another
+      template via the ``<%include>`` tag.  Pull request courtesy
+      Huayi Zhang.
+
 .. changelog::
     :version: 1.0.5
     :released: Wed Nov 2 2016
diff --git a/mako/lookup.py b/mako/lookup.py
index a9c5bb2..0d3f304 100644
--- a/mako/lookup.py
+++ b/mako/lookup.py
@@ -180,7 +180,8 @@
                  enable_loop=True,
                  input_encoding=None,
                  preprocessor=None,
-                 lexer_cls=None):
+                 lexer_cls=None,
+                 include_error_handler=None):
 
         self.directories = [posixpath.normpath(d) for d in
                             util.to_list(directories, ())
@@ -203,6 +204,7 @@
         self.template_args = {
             'format_exceptions': format_exceptions,
             'error_handler': error_handler,
+            'include_error_handler': include_error_handler,
             'disable_unicode': disable_unicode,
             'bytestring_passthrough': bytestring_passthrough,
             'output_encoding': output_encoding,
diff --git a/mako/runtime.py b/mako/runtime.py
index 5c40381..769541c 100644
--- a/mako/runtime.py
+++ b/mako/runtime.py
@@ -749,7 +749,16 @@
     (callable_, ctx) = _populate_self_namespace(
         context._clean_inheritance_tokens(),
         template)
-    callable_(ctx, **_kwargs_for_include(callable_, context._data, **kwargs))
+    kwargs = _kwargs_for_include(callable_, context._data, **kwargs)
+    if template.include_error_handler:
+        try:
+            callable_(ctx, **kwargs)
+        except Exception:
+            result = template.include_error_handler(ctx, compat.exception_as())
+            if not result:
+                compat.reraise(*sys.exc_info())
+    else:
+        callable_(ctx, **kwargs)
 
 
 def _inherit_from(context, uri, calling_uri):
diff --git a/mako/template.py b/mako/template.py
index bacbc13..c3e0c25 100644
--- a/mako/template.py
+++ b/mako/template.py
@@ -109,6 +109,11 @@
      completes. Is used to provide custom error-rendering
      functions.
 
+     .. seealso::
+
+        :paramref:`.Template.include_error_handler` - include-specific
+        error handler function
+
     :param format_exceptions: if ``True``, exceptions which occur during
      the render phase of this template will be caught and
      formatted into an HTML error page, which then becomes the
@@ -129,6 +134,16 @@
      import will not appear as the first executed statement in the generated
      code and will therefore not have the desired effect.
 
+    :param include_error_handler: An error handler that runs when this template
+     is included within another one via the ``<%include>`` tag, and raises an
+     error.  Compare to the :paramref:`.Template.error_handler` option.
+
+     .. versionadded:: 1.0.6
+
+     .. seealso::
+
+        :paramref:`.Template.error_handler` - top-level error handler function
+
     :param input_encoding: Encoding of the template's source code.  Can
      be used in lieu of the coding comment. See
      :ref:`usage_unicode` as well as :ref:`unicode_toplevel` for
@@ -243,7 +258,8 @@
                  future_imports=None,
                  enable_loop=True,
                  preprocessor=None,
-                 lexer_cls=None):
+                 lexer_cls=None,
+                 include_error_handler=None):
         if uri:
             self.module_id = re.sub(r'\W', "_", uri)
             self.uri = uri
@@ -329,6 +345,7 @@
         self.callable_ = self.module.render_body
         self.format_exceptions = format_exceptions
         self.error_handler = error_handler
+        self.include_error_handler = include_error_handler
         self.lookup = lookup
 
         self.module_directory = module_directory
@@ -528,6 +545,7 @@
                  cache_type=None,
                  cache_dir=None,
                  cache_url=None,
+                 include_error_handler=None,
                  ):
         self.module_id = re.sub(r'\W', "_", module._template_uri)
         self.uri = module._template_uri
@@ -559,6 +577,7 @@
         self.callable_ = self.module.render_body
         self.format_exceptions = format_exceptions
         self.error_handler = error_handler
+        self.include_error_handler = include_error_handler
         self.lookup = lookup
         self._setup_cache_args(
             cache_impl, cache_enabled, cache_args,
@@ -579,6 +598,7 @@
         self.encoding_errors = parent.encoding_errors
         self.format_exceptions = parent.format_exceptions
         self.error_handler = parent.error_handler
+        self.include_error_handler = parent.include_error_handler
         self.enable_loop = parent.enable_loop
         self.lookup = parent.lookup
         self.bytestring_passthrough = parent.bytestring_passthrough
diff --git a/test/test_template.py b/test/test_template.py
index f551230..7f21978 100644
--- a/test/test_template.py
+++ b/test/test_template.py
@@ -613,6 +613,21 @@
         """)
         assert flatten_result(lookup.get_template("c").render()) == "bar: calling bar this is a"
 
+    def test_include_error_handler(self):
+        def handle(context, error):
+            context.write('include error')
+            return True
+
+        lookup = TemplateLookup(include_error_handler=handle)
+        lookup.put_string("a", """
+            this is a.
+            <%include file="b"/>
+        """)
+        lookup.put_string("b", """
+            this is b ${1/0} end.
+        """)
+        assert flatten_result(lookup.get_template("a").render()) == "this is a. this is b include error"
+
 class UndefinedVarsTest(TemplateTest):
     def test_undefined(self):
         t = Template("""