| #!/usr/bin/env python3 |
| """Stub generator for C modules. |
| |
| The public interface is via the mypy.stubgen module. |
| """ |
| |
| from __future__ import annotations |
| |
| import importlib |
| import inspect |
| import os.path |
| import re |
| from abc import abstractmethod |
| from types import ModuleType |
| from typing import Any, Iterable, Mapping |
| from typing_extensions import Final |
| |
| from mypy.moduleinspect import is_c_module |
| from mypy.stubdoc import ( |
| ArgSig, |
| FunctionSig, |
| infer_arg_sig_from_anon_docstring, |
| infer_prop_type_from_docstring, |
| infer_ret_type_sig_from_anon_docstring, |
| infer_ret_type_sig_from_docstring, |
| infer_sig_from_docstring, |
| ) |
| |
# Candidate names from the typing module. add_typing_import() scans the
# generated stub for each of them and imports those that occur, keeping
# the order given here.
_DEFAULT_TYPING_IMPORTS: Final = (
    "Any",
    "Callable",
    "ClassVar",
    "Dict",
    "Iterable",
    "Iterator",
    "List",
    "Optional",
    "Tuple",
    "Union",
)
| |
| |
class SignatureGenerator:
    """Base class for strategies that produce the FunctionSigs of a function or method.

    Subclasses provide get_function_sig() and get_method_sig(). Either may
    return None when it has no signature to offer for the requested name.
    """

    def remove_self_type(
        self, inferred: list[FunctionSig] | None, self_var: str
    ) -> list[FunctionSig] | None:
        """Drop the type annotation of a leading self/cls argument.

        For every signature whose first argument is named ``self_var`` the
        type of that argument is reset to None. Signatures are modified in
        place and the very same ``inferred`` object is returned; None or an
        empty list is passed through untouched.
        """
        for sig in inferred or ():
            leading_arg = sig.args[0] if sig.args else None
            if leading_arg is not None and leading_arg.name == self_var:
                leading_arg.type = None
        return inferred

    @abstractmethod
    def get_function_sig(
        self, func: object, module_name: str, name: str
    ) -> list[FunctionSig] | None:
        """Return the signatures of module-level function ``name``, if any."""

    @abstractmethod
    def get_method_sig(
        self, cls: type, func: object, module_name: str, class_name: str, name: str, self_var: str
    ) -> list[FunctionSig] | None:
        """Return the signatures of method ``name`` of class ``class_name``, if any."""
| |
| |
class ExternalSignatureGenerator(SignatureGenerator):
    """Signature generator that answers from externally supplied signature strings."""

    def __init__(
        self, func_sigs: dict[str, str] | None = None, class_sigs: dict[str, str] | None = None
    ) -> None:
        """Store the lookup tables used to answer signature queries.

        Args:
            func_sigs: maps function/method names to signature strings.
            class_sigs: maps class names to class signatures (usually
                corresponding to ``__init__``).
        """
        self.func_sigs = func_sigs if func_sigs else {}
        self.class_sigs = class_sigs if class_sigs else {}

    def get_function_sig(
        self, func: object, module_name: str, name: str
    ) -> list[FunctionSig] | None:
        """Return a single ``-> Any`` signature for ``name`` if it is listed in func_sigs.

        Returns None when ``name`` has no entry.
        """
        if name not in self.func_sigs:
            return None
        parsed_args = infer_arg_sig_from_anon_docstring(self.func_sigs[name])
        return [FunctionSig(name=name, args=parsed_args, ret_type="Any")]

    def get_method_sig(
        self, cls: type, func: object, module_name: str, class_name: str, name: str, self_var: str
    ) -> list[FunctionSig] | None:
        """Return the signature of method ``name``.

        ``__new__``/``__init__`` fall back to the class signature stored for
        ``class_name`` when the method itself has no entry in func_sigs.
        Everything else is looked up like a plain function, after which the
        type annotation of the ``self_var`` argument is removed.
        """
        is_constructor = name in ("__new__", "__init__")
        if is_constructor and name not in self.func_sigs and class_name in self.class_sigs:
            ctor_args = infer_arg_sig_from_anon_docstring(self.class_sigs[class_name])
            return [FunctionSig(name=name, args=ctor_args, ret_type=infer_method_ret_type(name))]
        looked_up = self.get_function_sig(func, module_name, name)
        return self.remove_self_type(looked_up, self_var)
| |
| |
class DocstringSignatureGenerator(SignatureGenerator):
    """Signature generator that parses signatures out of ``__doc__`` strings."""

    def get_function_sig(
        self, func: object, module_name: str, name: str
    ) -> list[FunctionSig] | None:
        """Infer the signatures of ``name`` from the docstring of ``func``.

        The result of infer_sig_from_docstring() is returned; for pybind11
        overloaded-function docstrings its last entry is removed first.
        """
        doc = getattr(func, "__doc__", None)
        sigs = infer_sig_from_docstring(doc, name)
        if not sigs:
            return sigs
        assert doc is not None
        if is_pybind11_overloaded_function_docstring(doc, name):
            # Remove pybind11 umbrella (*args, **kwargs) for overloaded functions
            del sigs[-1]
        return sigs

    def get_method_sig(
        self,
        cls: type,
        func: object,
        module_name: str,
        class_name: str,
        func_name: str,
        self_var: str,
    ) -> list[FunctionSig] | None:
        """Infer the signatures of method ``func_name`` from docstrings.

        When ``__init__`` yields nothing, the docstring of the class itself is
        searched for a constructor signature of the form
        ``<class_name>(<signature>)``. The type of the ``self_var`` argument
        is removed from whatever was found.
        """
        sigs = self.get_function_sig(func, module_name, func_name)
        if func_name == "__init__" and not sigs:
            # look for class-level constructor signatures of the form <class_name>(<signature>)
            sigs = self.get_function_sig(cls, module_name, class_name)
        return self.remove_self_type(sigs, self_var)
| |
| |
class FallbackSignatureGenerator(SignatureGenerator):
    """Signature generator of last resort: it always returns one signature."""

    def get_function_sig(
        self, func: object, module_name: str, name: str
    ) -> list[FunctionSig] | None:
        """Return a catch-all ``(*args, **kwargs) -> Any`` signature for ``name``."""
        catch_all_args = infer_arg_sig_from_anon_docstring("(*args, **kwargs)")
        return [FunctionSig(name=name, args=catch_all_args, ret_type="Any")]

    def get_method_sig(
        self, cls: type, func: object, module_name: str, class_name: str, name: str, self_var: str
    ) -> list[FunctionSig] | None:
        """Return a signature derived purely from the method name.

        Arguments come from infer_method_args() and the return type from
        infer_method_ret_type().
        """
        guessed_args = infer_method_args(name, self_var)
        guessed_ret = infer_method_ret_type(name)
        return [FunctionSig(name=name, args=guessed_args, ret_type=guessed_ret)]
| |
| |
def generate_stub_for_c_module(
    module_name: str,
    target: str,
    known_modules: list[str],
    sig_generators: Iterable[SignatureGenerator],
) -> None:
    """Generate stub for C module.

    Signature generators are called in order until a list of signatures is returned.
    The order is chosen by the caller; the intended order is:
    - signatures inferred from .rst documentation (if given)
    - simple runtime introspection (looking for docstrings and attributes
      with simple builtin types)
    - fallback based special method names or "(*args, **kwargs)"

    If directory for target doesn't exist it will be created. Existing stub
    will be overwritten. The stub is written as UTF-8.

    Arguments:
        module_name: importable name of the C module to generate a stub for
        target: path of the stub file to write
        known_modules: other modules being processed
        sig_generators: signature generators to consult, in priority order
    """
    module = importlib.import_module(module_name)
    assert is_c_module(module), f"{module_name} is not a C module"
    # The generators are consulted once per function/method, so a one-shot
    # iterator would be exhausted after the first member; keep them in a list.
    sig_generators = list(sig_generators)
    subdir = os.path.dirname(target)
    if subdir:
        # exist_ok avoids a failure when the directory appears concurrently.
        os.makedirs(subdir, exist_ok=True)
    imports: list[str] = []
    functions: list[str] = []
    done: set[str] = set()
    items = sorted(get_members(module), key=lambda x: x[0])
    # First pass: module-level C functions.
    for name, obj in items:
        if is_c_function(obj):
            generate_c_function_stub(
                module,
                name,
                obj,
                output=functions,
                known_modules=known_modules,
                imports=imports,
                sig_generators=sig_generators,
            )
            done.add(name)
    # Second pass: classes (dunder-named members are ignored).
    types: list[str] = []
    for name, obj in items:
        if name.startswith("__") and name.endswith("__"):
            continue
        if is_c_type(obj):
            generate_c_type_stub(
                module,
                name,
                obj,
                output=types,
                known_modules=known_modules,
                imports=imports,
                sig_generators=sig_generators,
            )
            done.add(name)
    # Third pass: everything not handled above, except submodules, becomes a variable.
    variables: list[str] = []
    for name, obj in items:
        if name.startswith("__") and name.endswith("__"):
            continue
        if name not in done and not inspect.ismodule(obj):
            type_str = strip_or_import(
                get_type_fullname(type(obj)), module, known_modules, imports
            )
            variables.append(f"{name}: {type_str}")
    # Assemble the file: imports, variables, classes, functions.
    output = sorted(set(imports))
    output.extend(variables)
    for line in types:
        # Separate each class from whatever non-blank line precedes it.
        if line.startswith("class") and output and output[-1]:
            output.append("")
        output.append(line)
    if output and functions:
        output.append("")
    output.extend(functions)
    output = add_typing_import(output)
    # Stub files are Python source, whose default encoding is UTF-8; do not
    # depend on the locale's preferred encoding.
    with open(target, "w", encoding="utf-8") as file:
        file.writelines(f"{line}\n" for line in output)
| |
| |
def add_typing_import(output: list[str]) -> list[str]:
    """Prepend a ``from typing import ...`` line for the typing names used in the stub.

    Every name in _DEFAULT_TYPING_IMPORTS that occurs as a whole word in any
    line of ``output`` is imported. A new list is returned in either case;
    ``output`` itself is not modified.
    """
    used_names = [
        typing_name
        for typing_name in _DEFAULT_TYPING_IMPORTS
        if any(re.search(r"\b%s\b" % typing_name, line) for line in output)
    ]
    if not used_names:
        return output.copy()
    import_line = f"from typing import {', '.join(used_names)}"
    return [import_line, ""] + output
| |
| |
def get_members(obj: object) -> list[tuple[str, Any]]:
    """Return ``(name, value)`` pairs for the entries of ``obj.__dict__``.

    Names rejected by is_skipped_attribute() are left out, as are names
    for which ``getattr(obj, name)`` raises AttributeError. Values are
    obtained through getattr(), not read from ``__dict__`` directly.
    """
    namespace: Mapping[str, Any] = getattr(obj, "__dict__")  # noqa: B009
    members: list[tuple[str, Any]] = []
    for attr_name in namespace:
        if is_skipped_attribute(attr_name):
            continue
        try:
            attr_value = getattr(obj, attr_name)
        except AttributeError:
            # listed in __dict__ but not retrievable through getattr
            continue
        members.append((attr_name, attr_value))
    return members
| |
| |
def is_c_function(obj: object) -> bool:
    """Return True if ``obj`` is a builtin or has exactly the type of ``ord``."""
    if inspect.isbuiltin(obj):
        return True
    return type(obj) is type(ord)
| |
| |
def is_c_method(obj: object) -> bool:
    """Return True if ``obj`` looks like a method implemented in C.

    That is the case for method descriptors and for objects whose exact type
    equals that of ``str.index``, ``str.__add__`` or ``str.__new__``.
    """
    if inspect.ismethoddescriptor(obj):
        return True
    c_method_types = (type(str.index), type(str.__add__), type(str.__new__))
    return type(obj) in c_method_types
| |
| |
def is_c_classmethod(obj: object) -> bool:
    """Return True if ``obj`` is a builtin or a classmethod-like object.

    "Classmethod-like" is decided by the name of the object's type:
    ``classmethod`` or ``classmethod_descriptor``.
    """
    if inspect.isbuiltin(obj):
        return True
    type_name = type(obj).__name__
    return type_name in ("classmethod", "classmethod_descriptor")
| |
| |
def is_c_property(obj: object) -> bool:
    """Return True if ``obj`` is a data descriptor or exposes an ``fget`` attribute."""
    if inspect.isdatadescriptor(obj):
        return True
    return hasattr(obj, "fget")
| |
| |
def is_c_property_readonly(prop: Any) -> bool:
    """Return True if ``prop`` has an ``fset`` attribute and that attribute is None."""
    if not hasattr(prop, "fset"):
        return False
    return prop.fset is None
| |
| |
def is_c_type(obj: object) -> bool:
    """Return True if ``obj`` is a class or its exact type equals the type of ``int``."""
    if inspect.isclass(obj):
        return True
    return type(obj) is type(int)
| |
| |
def is_pybind11_overloaded_function_docstring(docstr: str, name: str) -> bool:
    """Return True if ``docstr`` starts with the overload header for function ``name``.

    The header is ``<name>(*args, **kwargs)`` on the first line, followed by
    ``Overloaded function.`` and an empty line.
    """
    overload_header = f"{name}(*args, **kwargs)\n" + "Overloaded function.\n\n"
    return docstr.startswith(overload_header)
| |
| |
def generate_c_function_stub(
    module: ModuleType,
    name: str,
    obj: object,
    *,
    known_modules: list[str],
    sig_generators: Iterable[SignatureGenerator],
    output: list[str],
    imports: list[str],
    self_var: str | None = None,
    cls: type | None = None,
    class_name: str | None = None,
) -> None:
    """Generate the stub for a single function or method.

    The signature generators are tried in order and the first non-empty
    result is used. One ``def`` line per signature is appended to 'output',
    preceded by ``@overload`` when there are several signatures and by
    ``@classmethod`` when a method's first argument is named ``cls``.
    If necessary, any required names will be added to 'imports'.

    A truthy 'class_name' selects method mode; 'cls' and 'self_var' must then
    be given, and 'self_var' is inserted as first argument of every signature
    that does not already start with ``self`` or ``cls``.

    Raises:
        ValueError: if no generator produced a signature.
    """
    inferred: list[FunctionSig] | None = None
    if class_name:
        # method:
        assert cls is not None, "cls should be provided for methods"
        assert self_var is not None, "self_var should be provided for methods"
        for generator in sig_generators:
            inferred = generator.get_method_sig(
                cls, obj, module.__name__, class_name, name, self_var
            )
            if not inferred:
                continue
            # add self/cls var, if not present
            for sig in inferred:
                if not sig.args or sig.args[0].name not in ("self", "cls"):
                    sig.args.insert(0, ArgSig(name=self_var))
            break
    else:
        # function:
        for generator in sig_generators:
            inferred = generator.get_function_sig(obj, module.__name__, name)
            if inferred:
                break

    if not inferred:
        raise ValueError(
            "No signature was found. This should never happen "
            "if FallbackSignatureGenerator is provided"
        )

    overloaded = len(inferred) > 1
    if overloaded:
        imports.append("from typing import overload")

    for sig in inferred:
        rendered_args: list[str] = []
        for arg in sig.args:
            # None is not a valid argument name
            text = "_none" if arg.name == "None" else arg.name
            if arg.type:
                text += ": " + strip_or_import(arg.type, module, known_modules, imports)
            if arg.default:
                text += " = ..."
            rendered_args.append(text)

        if overloaded:
            output.append("@overload")
        # a sig generator indicates @classmethod by specifying the cls arg
        if class_name and sig.args and sig.args[0].name == "cls":
            output.append("@classmethod")
        ret = strip_or_import(sig.ret_type, module, known_modules, imports)
        output.append(f"def {name}({', '.join(rendered_args)}) -> {ret}: ...")
| |
| |
def strip_or_import(
    typ: str, module: ModuleType, known_modules: list[str], imports: list[str]
) -> str:
    """Strips unnecessary module names from typ.

    If typ represents a type that is inside module or is a type coming from builtins, remove
    module declaration from it. Return stripped name of the type.

    Types containing ``[`` or ``,`` are split into their components, each of
    which is processed recursively. A qualified name from any other module
    is kept as is and an ``import <module>`` line is appended to imports.
    ``NoneType`` is rewritten to ``None``.

    Arguments:
        typ: name of the type
        module: in which this type is used
        known_modules: other modules being processed
        imports: list of import statements (may be modified during the call)
    """
    local_modules = ["builtins"]
    if module:
        local_modules.append(module.__name__)

    stripped_type = typ
    if any(c in typ for c in "[,"):
        for subtyp in re.split(r"[\[,\]]", typ):
            stripped_subtyp = strip_or_import(subtyp.strip(), module, known_modules, imports)
            if stripped_subtyp != subtyp:
                # Use a function as replacement: in a template string a component
                # starting with a digit would fuse with "\1" into a bogus group
                # reference (e.g. "\12") and backslashes would be reinterpreted.
                stripped_type = re.sub(
                    r"(^|[\[, ]+)" + re.escape(subtyp) + r"($|[\], ]+)",
                    lambda m, repl=stripped_subtyp: m.group(1) + repl + m.group(2),
                    stripped_type,
                )
    elif "." in typ:
        # Later entries of known_modules are tried first.
        for module_name in local_modules + list(reversed(known_modules)):
            if typ.startswith(module_name + "."):
                if module_name in local_modules:
                    stripped_type = typ[len(module_name) + 1 :]
                arg_module = module_name
                break
        else:
            # Not a known module: treat everything before the last dot as the module.
            arg_module = typ[: typ.rindex(".")]
        if arg_module not in local_modules:
            imports.append(f"import {arg_module}")
    if stripped_type == "NoneType":
        stripped_type = "None"
    return stripped_type
| |
| |
def is_static_property(obj: object) -> bool:
    """Return True if the type of ``obj`` is named ``pybind11_static_property``."""
    type_name = type(obj).__name__
    return type_name == "pybind11_static_property"
| |
| |
def generate_c_property_stub(
    name: str,
    obj: object,
    static_properties: list[str],
    rw_properties: list[str],
    ro_properties: list[str],
    readonly: bool,
    module: ModuleType | None = None,
    known_modules: list[str] | None = None,
    imports: list[str] | None = None,
) -> None:
    """Generate property stub using introspection of 'obj'.

    The type is inferred from the docstring of 'obj', then from the docstring
    of its ``fget``, and defaults to ``Any``. When 'module', 'known_modules'
    and 'imports' are all given, the type is passed through strip_or_import().

    The resulting lines are appended to exactly one of the lists:
    'static_properties' for pybind11 static properties, 'ro_properties'
    for read-only regular properties, 'rw_properties' otherwise.
    """

    def infer_prop_type(docstr: str | None) -> str | None:
        """Infer property type from docstring or docstring signature."""
        if docstr is None:
            return None
        # Try the inference helpers in turn; the first truthy result wins.
        return (
            infer_ret_type_sig_from_anon_docstring(docstr)
            or infer_ret_type_sig_from_docstring(docstr, name)
            or infer_prop_type_from_docstring(docstr)
        )

    prop_type = infer_prop_type(getattr(obj, "__doc__", None))
    if not prop_type:
        getter = getattr(obj, "fget", None)
        prop_type = infer_prop_type(getattr(getter, "__doc__", None))
    if not prop_type:
        prop_type = "Any"

    can_strip = module is not None and imports is not None and known_modules is not None
    if can_strip:
        prop_type = strip_or_import(prop_type, module, known_modules, imports)

    if is_static_property(obj):
        trailing_comment = " # read-only" if readonly else ""
        static_properties.append(f"{name}: ClassVar[{prop_type}] = ...{trailing_comment}")
        return

    # regular property
    if readonly:
        ro_properties.extend(["@property", f"def {name}(self) -> {prop_type}: ..."])
    else:
        rw_properties.append(f"{name}: {prop_type}")
| |
| |
def generate_c_type_stub(
    module: ModuleType,
    class_name: str,
    obj: type,
    output: list[str],
    known_modules: list[str],
    imports: list[str],
    sig_generators: Iterable[SignatureGenerator],
) -> None:
    """Generate stub for a single class using runtime introspection.

    Members are sorted with method_name_sort_key() and classified as
    methods, properties, nested classes or plain attributes. The class
    body is emitted in this order: nested classes, static properties and
    plain attributes, read/write properties, methods, read-only properties.
    A class with nothing to emit becomes a one-line ``class X: ...``.

    The result lines will be appended to 'output'. If necessary, any
    required names will be added to 'imports'.
    """
    raw_lookup = getattr(obj, "__dict__")  # noqa: B009
    members = sorted(get_members(obj), key=lambda member: method_name_sort_key(member[0]))
    member_names = {member_name for member_name, _ in members}
    methods: list[str] = []
    nested_types: list[str] = []
    static_properties: list[str] = []
    rw_properties: list[str] = []
    ro_properties: list[str] = []
    plain_attrs: list[tuple[str, Any]] = []

    for attr, value in members:
        # use unevaluated descriptors when dealing with property inspection
        raw_value = raw_lookup.get(attr, value)
        if is_c_method(value) or is_c_classmethod(value):
            if attr == "__new__":
                # TODO: We should support __new__.
                if "__init__" in member_names:
                    # Avoid duplicate functions if both are present.
                    # But is there any case where .__new__() has a
                    # better signature than __init__() ?
                    continue
                attr = "__init__"
            self_var = "cls" if is_c_classmethod(value) else "self"
            generate_c_function_stub(
                module,
                attr,
                value,
                output=methods,
                known_modules=known_modules,
                imports=imports,
                self_var=self_var,
                cls=obj,
                class_name=class_name,
                sig_generators=sig_generators,
            )
        elif is_c_property(raw_value):
            generate_c_property_stub(
                attr,
                raw_value,
                static_properties,
                rw_properties,
                ro_properties,
                is_c_property_readonly(raw_value),
                module=module,
                known_modules=known_modules,
                imports=imports,
            )
        elif is_c_type(value):
            generate_c_type_stub(
                module,
                attr,
                value,
                nested_types,
                imports=imports,
                known_modules=known_modules,
                sig_generators=sig_generators,
            )
        else:
            plain_attrs.append((attr, value))

    # Plain attributes are emitted as class variables after the static properties.
    for attr, value in plain_attrs:
        attr_type = strip_or_import(
            get_type_fullname(type(value)), module, known_modules, imports
        )
        static_properties.append(f"{attr}: ClassVar[{attr_type}] = ...")

    mro = type.mro(obj)
    if mro[-1] is object:
        # TODO: Is this always object?
        del mro[-1]
    # remove pybind11_object. All classes generated by pybind11 have pybind11_object in their MRO,
    # which only overrides a few functions in object type
    if mro and mro[-1].__name__ == "pybind11_object":
        del mro[-1]
    # remove the class itself
    ancestors = mro[1:]
    # Remove base classes of other bases as redundant.
    bases: list[type] = []
    for candidate in ancestors:
        if any(issubclass(kept, candidate) for kept in bases):
            continue
        bases.append(candidate)
    if bases:
        bases_str = "(%s)" % ", ".join(
            strip_or_import(get_type_fullname(base), module, known_modules, imports)
            for base in bases
        )
    else:
        bases_str = ""

    header = f"class {class_name}{bases_str}:"
    has_body = bool(
        nested_types or static_properties or rw_properties or methods or ro_properties
    )
    if not has_body:
        output.append(header + " ...")
        return

    output.append(header)
    for line in nested_types:
        # Blank line before a nested class, unless it directly follows a class header.
        needs_gap = (
            output
            and output[-1]
            and not output[-1].startswith("class")
            and line.startswith("class")
        )
        if needs_gap:
            output.append("")
        output.append(" " + line)
    for group in (static_properties, rw_properties, methods, ro_properties):
        for line in group:
            output.append(f" {line}")
| |
| |
def get_type_fullname(typ: type) -> str:
    """Return ``<module>.<qualified name>`` for ``typ``.

    ``__qualname__`` is used when available, ``__name__`` otherwise.
    """
    module_name = typ.__module__
    qualified_name = getattr(typ, "__qualname__", typ.__name__)
    return f"{module_name}.{qualified_name}"
| |
| |
def method_name_sort_key(name: str) -> tuple[int, str]:
    """Return a sort key that puts methods of a class in a typical order.

    Constructors (``__new__``/``__init__``) rank first, other dunder names
    last and everything else in between; ties are broken by name.
    """
    if name in ("__new__", "__init__"):
        rank = 0
    elif name.startswith("__") and name.endswith("__"):
        rank = 2
    else:
        rank = 1
    return rank, name
| |
| |
def is_pybind_skipped_attribute(attr: str) -> bool:
    """Return True if ``attr`` starts with the ``__pybind11_module_local_`` prefix."""
    skipped_prefix = "__pybind11_module_local_"
    return attr.startswith(skipped_prefix)
| |
| |
def is_skipped_attribute(attr: str) -> bool:
    """Return True if attribute ``attr`` is to be left out of generated stubs.

    That is the case for a fixed set of dunder names and for every name
    accepted by is_pybind_skipped_attribute().
    """
    # NOTE(review): the original code carried a "For pickling" remark here;
    # it is unclear which entry it referred to.
    always_skipped = (
        "__class__",
        "__getattribute__",
        "__str__",
        "__repr__",
        "__doc__",
        "__dict__",
        "__module__",
        "__weakref__",
    )
    if attr in always_skipped:
        return True
    return is_pybind_skipped_attribute(attr)
| |
| |
def infer_method_args(name: str, self_var: str | None = None) -> list[ArgSig]:
    """Guess the argument list of a method from its name alone.

    Recognised special (dunder) methods get their conventional parameters;
    every other name gets ``*args, **kwargs``. The list always starts with
    ``self_var`` (``self`` if not given). Fresh ArgSig objects are created
    on each call.
    """
    # Special methods (name without the surrounding "__") that take no argument besides self.
    no_extra_args = (
        "hash",
        "iter",
        "next",
        "sizeof",
        "copy",
        "deepcopy",
        "reduce",
        "getinitargs",
        "int",
        "float",
        "trunc",
        "complex",
        "bool",
        "abs",
        "bytes",
        "dir",
        "len",
        "reversed",
        "round",
        "index",
        "enter",
        "getstate",
        "neg",
        "pos",
        "invert",
    )
    # Special methods that take a single argument called "other".
    single_other_arg = (
        "eq",
        "ne",
        "lt",
        "le",
        "gt",
        "ge",
        "add",
        "radd",
        "sub",
        "rsub",
        "mul",
        "rmul",
        "mod",
        "rmod",
        "floordiv",
        "rfloordiv",
        "truediv",
        "rtruediv",
        "divmod",
        "rdivmod",
        "pow",
        "rpow",
        "xor",
        "rxor",
        "or",
        "ror",
        "and",
        "rand",
        "lshift",
        "rlshift",
        "rshift",
        "rrshift",
        "contains",
        "delitem",
        "iadd",
        "iand",
        "ifloordiv",
        "ilshift",
        "imod",
        "imul",
        "ior",
        "ipow",
        "irshift",
        "isub",
        "itruediv",
        "ixor",
    )
    # Special methods with their own parameter names.
    named_params: dict[str, tuple[str, ...]] = {
        "getitem": ("index",),
        "setitem": ("index", "object"),
        "delattr": ("name",),
        "getattr": ("name",),
        "setattr": ("name", "value"),
        "setstate": ("state",),
        "get": ("instance", "owner"),
        "set": ("instance", "value"),
        "reduce_ex": ("protocol",),
        "exit": ("type", "value", "traceback"),
    }

    param_names: tuple[str, ...] | None = None
    # Only dunder names are looked up; a plain "len" is not "__len__".
    if name.startswith("__") and name.endswith("__"):
        bare_name = name[2:-2]
        if bare_name in no_extra_args:
            param_names = ()
        elif bare_name in single_other_arg:
            param_names = ("other",)
        else:
            param_names = named_params.get(bare_name)
    if param_names is None:
        param_names = ("*args", "**kwargs")
    return [ArgSig(name=self_var or "self")] + [ArgSig(name=param) for param in param_names]
| |
| |
def infer_method_ret_type(name: str) -> str:
    """Guess the return type of a method from its name alone.

    Only special (dunder) method names are recognised; anything else,
    including unknown dunder names, yields ``Any``.
    """
    if not (name.startswith("__") and name.endswith("__")):
        return "Any"
    bare_name = name[2:-2]
    if bare_name in ("float", "bool", "bytes", "int"):
        # conversion methods return the type they are named after
        return bare_name
    # Note: __eq__ and co may return arbitrary types, but bool is good enough for stubgen.
    if bare_name in ("eq", "ne", "lt", "le", "gt", "ge", "contains"):
        return "bool"
    if bare_name in ("len", "hash", "sizeof", "trunc", "floor", "ceil"):
        return "int"
    if bare_name in ("init", "setitem"):
        return "None"
    return "Any"