#  Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
#  See https://llvm.org/LICENSE.txt for license information.
#  SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from ._func_ops_gen import *
from ._func_ops_gen import _Dialect

try:
    from ..ir import *
    from ._ods_common import (
        get_default_loc_context as _get_default_loc_context,
        _cext as _ods_cext,
    )

    import inspect

    from typing import Any, List, Optional, Sequence, Union
except ImportError as e:
    raise RuntimeError("Error loading imports from extension module") from e

# Keys into an op's attribute map used by the `FuncOp.arg_attrs` and
# `FuncOp.result_attrs` accessors defined in this module.
ARGUMENT_ATTRIBUTE_NAME = "arg_attrs"
RESULT_ATTRIBUTE_NAME = "res_attrs"


@_ods_cext.register_operation(_Dialect, replace=True)
class ConstantOp(ConstantOp):
    """Extension of the generated ConstantOp class with convenience accessors."""

    @property
    def type(self):
        """Return the type of this op's first result."""
        first_result = self.results[0]
        return first_result.type


@_ods_cext.register_operation(_Dialect, replace=True)
class FuncOp(FuncOp):
    """Specialization for the func op class."""

    def __init__(
        self, name, type, *, visibility=None, body_builder=None, loc=None, ip=None
    ):
        """
        Create a FuncOp with the provided `name`, `type`, and `visibility`.
        - `name` is converted with `str()` and used as the function's symbol
          name.
        - `type` is either a FunctionType or a tuple `(inputs, results)` of
          type sequences from which a FunctionType is built.
        - `visibility` is a string matching `public`, `private`, or `nested`.
          When None, no visibility attribute is passed to the underlying
          builder.
        - `body_builder` is an optional callback, when provided a new entry block
          is created and the callback is invoked with the new op as argument within
          an InsertionPoint context already set for the block. The callback is
          expected to insert a terminator in the block.
        - `loc` and `ip` are forwarded unchanged to the generated builder.
        """
        sym_name = StringAttr.get(str(name))

        # If the type is passed as a tuple, build a FunctionType on the fly.
        if isinstance(type, tuple):
            type = FunctionType.get(inputs=type[0], results=type[1])

        type = TypeAttr.get(type)
        sym_visibility = (
            StringAttr.get(str(visibility)) if visibility is not None else None
        )
        super().__init__(sym_name, type, sym_visibility=sym_visibility, loc=loc, ip=ip)
        if body_builder:
            entry_block = self.add_entry_block()
            with InsertionPoint(entry_block):
                body_builder(self)

    @property
    def is_external(self):
        """True when the function's first region contains no blocks."""
        return len(self.regions[0].blocks) == 0

    @property
    def body(self):
        """The function's first (body) region."""
        return self.regions[0]

    @property
    def type(self):
        """The FunctionType stored in the `function_type` attribute."""
        return FunctionType(TypeAttr(self.attributes["function_type"]).value)

    @property
    def visibility(self):
        """The `sym_visibility` attribute, looked up directly in the attribute map.

        NOTE(review): the lookup is unguarded, so this presumably fails when the
        op was created without a visibility - confirm against the bindings.
        """
        return self.attributes["sym_visibility"]

    @property
    def name(self) -> StringAttr:
        """The `sym_name` attribute as a StringAttr."""
        return StringAttr(self.attributes["sym_name"])

    @property
    def entry_block(self):
        """The first block of the body.

        Raises IndexError if the function is external (has no blocks).
        """
        if self.is_external:
            raise IndexError("External function does not have a body")
        return self.regions[0].blocks[0]

    def add_entry_block(self, arg_locs: Optional[Sequence[Location]] = None):
        """
        Add an entry block to the function body using the function signature to
        infer block arguments.
        - `arg_locs` is forwarded to the block creation call as the locations
          of the block arguments.
        Returns the newly created block.
        Raises IndexError if the function already has a block.
        """
        if not self.is_external:
            raise IndexError("The function already has an entry block!")
        self.body.blocks.append(*self.type.inputs, arg_locs=arg_locs)
        return self.body.blocks[0]

    @property
    def arg_attrs(self):
        """The argument attributes (`arg_attrs`) as an ArrayAttr."""
        return ArrayAttr(self.attributes[ARGUMENT_ATTRIBUTE_NAME])

    @arg_attrs.setter
    def arg_attrs(self, attribute: Union[ArrayAttr, list]):
        """Set the argument attributes from an ArrayAttr or a list of attributes."""
        if isinstance(attribute, ArrayAttr):
            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = attribute
        else:
            self.attributes[ARGUMENT_ATTRIBUTE_NAME] = ArrayAttr.get(
                attribute, context=self.context
            )

    @property
    def arguments(self):
        """The arguments of the entry block (see `entry_block` for errors)."""
        return self.entry_block.arguments

    @property
    def result_attrs(self):
        """The result attributes (`res_attrs`) as an ArrayAttr."""
        # Wrapped in ArrayAttr so the getter behaves like `arg_attrs`.
        return ArrayAttr(self.attributes[RESULT_ATTRIBUTE_NAME])

    @result_attrs.setter
    def result_attrs(self, attribute: Union[ArrayAttr, list]):
        """Set the result attributes from an ArrayAttr or a list of attributes."""
        # Accept plain lists as well, mirroring the `arg_attrs` setter.
        if isinstance(attribute, ArrayAttr):
            self.attributes[RESULT_ATTRIBUTE_NAME] = attribute
        else:
            self.attributes[RESULT_ATTRIBUTE_NAME] = ArrayAttr.get(
                attribute, context=self.context
            )

    @classmethod
    def from_py_func(
        FuncOp,
        *inputs: Type,
        results: Optional[Sequence[Type]] = None,
        name: Optional[str] = None,
    ):
        """Decorator to define an MLIR FuncOp specified as a python function.

        Requires that an `mlir.ir.InsertionPoint` and `mlir.ir.Location` are
        active for the current thread (i.e. established in a `with` block).

        When applied as a decorator to a Python function, an entry block will
        be constructed for the FuncOp with types as specified in `*inputs`. The
        block arguments will be passed positionally to the Python function. In
        addition, if the Python function accepts keyword arguments generally or
        has a corresponding keyword argument, the following will be passed:
          * `func_op`: The `func` op being defined.

        By default, the function name will be the Python function `__name__`. This
        can be overridden by passing the `name` argument to the decorator.

        If `results` is not specified, then the decorator will implicitly
        insert a `ReturnOp` with the `Value`'s returned from the decorated
        function. It will also set the `FuncOp` type with the actual return
        value types. If `results` is specified, then the decorated function
        must return `None` and no implicit `ReturnOp` is added (nor are the result
        types updated). The implicit behavior is intended for simple, single-block
        cases, and users should specify result types explicitly for any complicated
        cases.

        The decorated function can further be called from Python and will insert
        a `CallOp` at the then-current insertion point, returning either None
        (if no return values), a unary Value (for one result), or a list of
        Values. The created `FuncOp` is available as the `func_op` attribute of
        the returned callable.
        This mechanism cannot be used to emit recursive calls (by construction).
        """

        def decorator(f):
            from . import func

            # Introspect the callable for optional features: `func_op` is passed
            # when `f` takes **kwargs or declares a keyword-capable `func_op`.
            has_arg_func_op = any(
                param.kind == param.VAR_KEYWORD
                or (
                    param.name == "func_op"
                    and param.kind
                    in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
                )
                for param in inspect.signature(f).parameters.values()
            )

            # Emit the FuncOp. With implicit returns the result types are not
            # known yet, so start with an empty result list and patch it later.
            implicit_return = results is None
            symbol_name = name or f.__name__
            function_type = FunctionType.get(
                inputs=inputs, results=[] if implicit_return else results
            )
            func_op = FuncOp(name=symbol_name, type=function_type)
            with InsertionPoint(func_op.add_entry_block()):
                func_args = func_op.entry_block.arguments
                func_kwargs = {"func_op": func_op} if has_arg_func_op else {}
                return_values = f(*func_args, **func_kwargs)
                if not implicit_return:
                    return_types = list(results)
                    assert return_values is None, (
                        "Capturing a python function with explicit `results=` "
                        "requires that the wrapped function returns None."
                    )
                else:
                    # Coerce return values, add ReturnOp and rewrite func type.
                    if return_values is None:
                        return_values = []
                    elif isinstance(return_values, tuple):
                        return_values = list(return_values)
                    elif isinstance(return_values, Value):
                        # Returning a single value is fine, coerce it into a list.
                        return_values = [return_values]
                    elif isinstance(return_values, OpView):
                        # Returning a single operation is fine, use its results.
                        return_values = return_values.operation.results
                    elif isinstance(return_values, Operation):
                        # Returning a single operation is fine, use its results.
                        return_values = return_values.results
                    else:
                        return_values = list(return_values)
                    func.ReturnOp(return_values)
                    # Recompute the function type.
                    return_types = [v.type for v in return_values]
                    function_type = FunctionType.get(
                        inputs=inputs, results=return_types
                    )
                    func_op.attributes["function_type"] = TypeAttr.get(function_type)

            def emit_call_op(*call_args):
                """Emit a CallOp to the captured function and unpack its results."""
                call_op = func.CallOp(
                    return_types, FlatSymbolRefAttr.get(symbol_name), call_args
                )
                # `return_types` is always a list here, so test for emptiness
                # (not `is None`) to honor the documented "None for no results".
                if not return_types:
                    return None
                elif len(return_types) == 1:
                    return call_op.result
                else:
                    return call_op.results

            wrapped = emit_call_op
            wrapped.__name__ = f.__name__
            wrapped.func_op = func_op
            return wrapped

        return decorator


# Module-level alias for the `FuncOp.from_py_func` decorator factory.
func = FuncOp.from_py_func


@_ods_cext.register_operation(_Dialect, replace=True)
class CallOp(CallOp):
    """Specialization for the call op class."""

    def __init__(
        self,
        calleeOrResults: Union[FuncOp, List[Type]],
        argumentsOrCallee: Union[List, FlatSymbolRefAttr, str],
        arguments: Optional[List] = None,
        *,
        loc=None,
        ip=None,
    ):
        """Creates a call operation.

        The constructor accepts three different forms:

          1. A function op to be called followed by a list of arguments.
          2. A list of result types, followed by the name of the function to be
             called as string, followed by a list of arguments.
          3. A list of result types, followed by the name of the function to be
             called as symbol reference attribute, followed by a list of arguments.

        For example

            f = func.FuncOp("foo", ...)
            func.CallOp(f, [args])
            func.CallOp([result_types], "foo", [args])

        In forms 2 and 3 an omitted (None) argument list is treated as an empty
        list of call arguments.

        In all cases, the location and insertion point may be specified as keyword
        arguments if not provided by the surrounding context managers.

        Raises ValueError when the positional arguments do not match any of the
        three forms above.
        """

        # TODO: consider supporting constructor "overloads", e.g., through a custom
        # or pybind-provided metaclass.
        if isinstance(calleeOrResults, FuncOp):
            # Form 1: result types and symbol name are taken from the FuncOp.
            if not isinstance(argumentsOrCallee, list):
                raise ValueError(
                    "when constructing a call to a function, expected "
                    + "the second argument to be a list of call arguments, "
                    + f"got {type(argumentsOrCallee)}"
                )
            if arguments is not None:
                raise ValueError(
                    "unexpected third argument when constructing a call "
                    + "to a function"
                )

            super().__init__(
                calleeOrResults.type.results,
                FlatSymbolRefAttr.get(
                    calleeOrResults.name.value, context=_get_default_loc_context(loc)
                ),
                argumentsOrCallee,
                loc=loc,
                ip=ip,
            )
            return

        # Forms 2 and 3: normalize the callee to a FlatSymbolRefAttr.
        if isinstance(argumentsOrCallee, FlatSymbolRefAttr):
            callee = argumentsOrCallee
        elif isinstance(argumentsOrCallee, str):
            callee = FlatSymbolRefAttr.get(
                argumentsOrCallee, context=_get_default_loc_context(loc)
            )
        else:
            # Reject lists and any other type explicitly; falling through here
            # would skip the base constructor and leave the op uninitialized.
            raise ValueError(
                "when constructing a call to a function by name, "
                + "expected the second argument to be a string or a "
                + f"FlatSymbolRefAttr, got {type(argumentsOrCallee)}"
            )

        super().__init__(
            calleeOrResults,
            callee,
            [] if arguments is None else arguments,
            loc=loc,
            ip=ip,
        )
