| #!/usr/bin/env python3 |
| """Generator of dynamically typed draft stubs for arbitrary modules. |
| |
| The logic of this script can be split in three steps: |
| * parsing options and finding sources: |
| - use runtime imports be default (to find also C modules) |
| - or use mypy's mechanisms, if importing is prohibited |
| * (optionally) semantically analysing the sources using mypy (as a single set) |
| * emitting the stubs text: |
| - for Python modules: from ASTs using StubGenerator |
| - for C modules using runtime introspection and (optionally) Sphinx docs |
| |
| During first and third steps some problematic files can be skipped, but any |
| blocking error during second step will cause the whole program to stop. |
| |
| Basic usage: |
| |
| $ stubgen foo.py bar.py some_directory |
| => Generate out/foo.pyi, out/bar.pyi, and stubs for some_directory (recursively). |
| |
| $ stubgen -m urllib.parse |
| => Generate out/urllib/parse.pyi. |
| |
| $ stubgen -p urllib |
| => Generate stubs for whole urlib package (recursively). |
| |
| For C modules, you can get more precise function signatures by parsing .rst (Sphinx) |
| documentation for extra information. For this, use the --doc-dir option: |
| |
| $ stubgen --doc-dir <DIR>/Python-3.4.2/Doc/library -m curses |
| |
| Note: The generated stubs should be verified manually. |
| |
| TODO: |
| - maybe use .rst docs also for Python modules |
| - maybe export more imported names if there is no __all__ (this affects ssl.SSLError, for example) |
| - a quick and dirty heuristic would be to turn this on if a module has something like |
| 'from x import y as _y' |
| - we don't seem to always detect properties ('closed' in 'io', for example) |
| """ |
| |
| from __future__ import annotations |
| |
| import argparse |
| import glob |
| import keyword |
| import os |
| import os.path |
| import sys |
| import traceback |
| from collections import defaultdict |
| from typing import Final, Iterable, Mapping |
| |
| import mypy.build |
| import mypy.mixedtraverser |
| import mypy.parse |
| import mypy.traverser |
| import mypy.util |
| from mypy.build import build |
| from mypy.errors import CompileError, Errors |
| from mypy.find_sources import InvalidSourceList, create_source_list |
| from mypy.modulefinder import ( |
| BuildSource, |
| FindModuleCache, |
| ModuleNotFoundReason, |
| SearchPaths, |
| default_lib_path, |
| ) |
| from mypy.moduleinspect import ModuleInspect |
| from mypy.nodes import ( |
| ARG_NAMED, |
| ARG_POS, |
| ARG_STAR, |
| ARG_STAR2, |
| IS_ABSTRACT, |
| NOT_ABSTRACT, |
| AssignmentStmt, |
| Block, |
| BytesExpr, |
| CallExpr, |
| ClassDef, |
| ComparisonExpr, |
| ComplexExpr, |
| Decorator, |
| DictExpr, |
| EllipsisExpr, |
| Expression, |
| FloatExpr, |
| FuncBase, |
| FuncDef, |
| IfStmt, |
| Import, |
| ImportAll, |
| ImportFrom, |
| IndexExpr, |
| IntExpr, |
| ListExpr, |
| MemberExpr, |
| MypyFile, |
| NameExpr, |
| OpExpr, |
| OverloadedFuncDef, |
| Statement, |
| StrExpr, |
| TempNode, |
| TupleExpr, |
| TypeInfo, |
| UnaryExpr, |
| Var, |
| ) |
| from mypy.options import Options as MypyOptions |
| from mypy.stubdoc import Sig, find_unique_signatures, parse_all_signatures |
| from mypy.stubgenc import ( |
| DocstringSignatureGenerator, |
| ExternalSignatureGenerator, |
| FallbackSignatureGenerator, |
| SignatureGenerator, |
| generate_stub_for_c_module, |
| ) |
| from mypy.stubutil import ( |
| CantImport, |
| common_dir_prefix, |
| fail_missing, |
| find_module_path_and_all_py3, |
| generate_guarded, |
| remove_misplaced_type_comments, |
| report_missing, |
| walk_packages, |
| ) |
| from mypy.traverser import ( |
| all_yield_expressions, |
| has_return_statement, |
| has_yield_expression, |
| has_yield_from_expression, |
| ) |
| from mypy.types import ( |
| OVERLOAD_NAMES, |
| TPDICT_NAMES, |
| TYPED_NAMEDTUPLE_NAMES, |
| AnyType, |
| CallableType, |
| Instance, |
| NoneType, |
| TupleType, |
| Type, |
| TypeList, |
| TypeStrVisitor, |
| UnboundType, |
| UnionType, |
| get_proper_type, |
| ) |
| from mypy.visitor import NodeVisitor |
| |
| TYPING_MODULE_NAMES: Final = ("typing", "typing_extensions") |
| |
| # Common ways of naming package containing vendored modules. |
| VENDOR_PACKAGES: Final = ["packages", "vendor", "vendored", "_vendor", "_vendored_packages"] |
| |
| # Avoid some file names that are unnecessary or likely to cause trouble (\n for end of path). |
| BLACKLIST: Final = [ |
| "/six.py\n", # Likely vendored six; too dynamic for us to handle |
| "/vendored/", # Vendored packages |
| "/vendor/", # Vendored packages |
| "/_vendor/", |
| "/_vendored_packages/", |
| ] |
| |
| # Special-cased names that are implicitly exported from the stub (from m import y as y). |
| EXTRA_EXPORTED: Final = { |
| "pyasn1_modules.rfc2437.univ", |
| "pyasn1_modules.rfc2459.char", |
| "pyasn1_modules.rfc2459.univ", |
| } |
| |
| # These names should be omitted from generated stubs. |
| IGNORED_DUNDERS: Final = { |
| "__all__", |
| "__author__", |
| "__version__", |
| "__about__", |
| "__copyright__", |
| "__email__", |
| "__license__", |
| "__summary__", |
| "__title__", |
| "__uri__", |
| "__str__", |
| "__repr__", |
| "__getstate__", |
| "__setstate__", |
| "__slots__", |
| } |
| |
| # These methods are expected to always return a non-trivial value. |
| METHODS_WITH_RETURN_VALUE: Final = { |
| "__ne__", |
| "__eq__", |
| "__lt__", |
| "__le__", |
| "__gt__", |
| "__ge__", |
| "__hash__", |
| "__iter__", |
| } |
| |
| # These magic methods always return the same type. |
| KNOWN_MAGIC_METHODS_RETURN_TYPES: Final = { |
| "__len__": "int", |
| "__length_hint__": "int", |
| "__init__": "None", |
| "__del__": "None", |
| "__bool__": "bool", |
| "__bytes__": "bytes", |
| "__format__": "str", |
| "__contains__": "bool", |
| "__complex__": "complex", |
| "__int__": "int", |
| "__float__": "float", |
| "__index__": "int", |
| } |
| |
| |
| class Options: |
| """Represents stubgen options. |
| |
| This class is mutable to simplify testing. |
| """ |
| |
| def __init__( |
| self, |
| pyversion: tuple[int, int], |
| no_import: bool, |
| doc_dir: str, |
| search_path: list[str], |
| interpreter: str, |
| parse_only: bool, |
| ignore_errors: bool, |
| include_private: bool, |
| output_dir: str, |
| modules: list[str], |
| packages: list[str], |
| files: list[str], |
| verbose: bool, |
| quiet: bool, |
| export_less: bool, |
| include_docstrings: bool, |
| ) -> None: |
| # See parse_options for descriptions of the flags. |
| self.pyversion = pyversion |
| self.no_import = no_import |
| self.doc_dir = doc_dir |
| self.search_path = search_path |
| self.interpreter = interpreter |
| self.decointerpreter = interpreter |
| self.parse_only = parse_only |
| self.ignore_errors = ignore_errors |
| self.include_private = include_private |
| self.output_dir = output_dir |
| self.modules = modules |
| self.packages = packages |
| self.files = files |
| self.verbose = verbose |
| self.quiet = quiet |
| self.export_less = export_less |
| self.include_docstrings = include_docstrings |
| |
| |
| class StubSource: |
| """A single source for stub: can be a Python or C module. |
| |
| A simple extension of BuildSource that also carries the AST and |
| the value of __all__ detected at runtime. |
| """ |
| |
| def __init__( |
| self, module: str, path: str | None = None, runtime_all: list[str] | None = None |
| ) -> None: |
| self.source = BuildSource(path, module, None) |
| self.runtime_all = runtime_all |
| self.ast: MypyFile | None = None |
| |
| @property |
| def module(self) -> str: |
| return self.source.module |
| |
| @property |
| def path(self) -> str | None: |
| return self.source.path |
| |
| |
| # What was generated previously in the stub file. We keep track of these to generate |
| # nicely formatted output (add empty line between non-empty classes, for example). |
| EMPTY: Final = "EMPTY" |
| FUNC: Final = "FUNC" |
| CLASS: Final = "CLASS" |
| EMPTY_CLASS: Final = "EMPTY_CLASS" |
| VAR: Final = "VAR" |
| NOT_IN_ALL: Final = "NOT_IN_ALL" |
| |
| # Indicates that we failed to generate a reasonable output |
| # for a given node. These should be manually replaced by a user. |
| |
| ERROR_MARKER: Final = "<ERROR>" |
| |
| |
| class AnnotationPrinter(TypeStrVisitor): |
| """Visitor used to print existing annotations in a file. |
| |
| The main difference from TypeStrVisitor is a better treatment of |
| unbound types. |
| |
| Notes: |
| * This visitor doesn't add imports necessary for annotations, this is done separately |
| by ImportTracker. |
| * It can print all kinds of types, but the generated strings may not be valid (notably |
| callable types) since it prints the same string that reveal_type() does. |
| * For Instance types it prints the fully qualified names. |
| """ |
| |
| # TODO: Generate valid string representation for callable types. |
| # TODO: Use short names for Instances. |
| def __init__(self, stubgen: StubGenerator) -> None: |
| super().__init__(options=mypy.options.Options()) |
| self.stubgen = stubgen |
| |
| def visit_any(self, t: AnyType) -> str: |
| s = super().visit_any(t) |
| self.stubgen.import_tracker.require_name(s) |
| return s |
| |
| def visit_unbound_type(self, t: UnboundType) -> str: |
| s = t.name |
| self.stubgen.import_tracker.require_name(s) |
| if t.args: |
| s += f"[{self.args_str(t.args)}]" |
| return s |
| |
| def visit_none_type(self, t: NoneType) -> str: |
| return "None" |
| |
| def visit_type_list(self, t: TypeList) -> str: |
| return f"[{self.list_str(t.items)}]" |
| |
| def visit_union_type(self, t: UnionType) -> str: |
| return " | ".join([item.accept(self) for item in t.items]) |
| |
| def args_str(self, args: Iterable[Type]) -> str: |
| """Convert an array of arguments to strings and join the results with commas. |
| |
| The main difference from list_str is the preservation of quotes for string |
| arguments |
| """ |
| types = ["builtins.bytes", "builtins.str"] |
| res = [] |
| for arg in args: |
| arg_str = arg.accept(self) |
| if isinstance(arg, UnboundType) and arg.original_str_fallback in types: |
| res.append(f"'{arg_str}'") |
| else: |
| res.append(arg_str) |
| return ", ".join(res) |
| |
| |
| class AliasPrinter(NodeVisitor[str]): |
| """Visitor used to collect type aliases _and_ type variable definitions. |
| |
| Visit r.h.s of the definition to get the string representation of type alias. |
| """ |
| |
| def __init__(self, stubgen: StubGenerator) -> None: |
| self.stubgen = stubgen |
| super().__init__() |
| |
| def visit_call_expr(self, node: CallExpr) -> str: |
| # Call expressions are not usually types, but we also treat `X = TypeVar(...)` as a |
| # type alias that has to be preserved (even if TypeVar is not the same as an alias) |
| callee = node.callee.accept(self) |
| args = [] |
| for name, arg, kind in zip(node.arg_names, node.args, node.arg_kinds): |
| if kind == ARG_POS: |
| args.append(arg.accept(self)) |
| elif kind == ARG_STAR: |
| args.append("*" + arg.accept(self)) |
| elif kind == ARG_STAR2: |
| args.append("**" + arg.accept(self)) |
| elif kind == ARG_NAMED: |
| args.append(f"{name}={arg.accept(self)}") |
| else: |
| raise ValueError(f"Unknown argument kind {kind} in call") |
| return f"{callee}({', '.join(args)})" |
| |
| def visit_name_expr(self, node: NameExpr) -> str: |
| self.stubgen.import_tracker.require_name(node.name) |
| return node.name |
| |
| def visit_member_expr(self, o: MemberExpr) -> str: |
| node: Expression = o |
| trailer = "" |
| while isinstance(node, MemberExpr): |
| trailer = "." + node.name + trailer |
| node = node.expr |
| if not isinstance(node, NameExpr): |
| return ERROR_MARKER |
| self.stubgen.import_tracker.require_name(node.name) |
| return node.name + trailer |
| |
| def visit_str_expr(self, node: StrExpr) -> str: |
| return repr(node.value) |
| |
| def visit_index_expr(self, node: IndexExpr) -> str: |
| base = node.base.accept(self) |
| index = node.index.accept(self) |
| if len(index) > 2 and index.startswith("(") and index.endswith(")"): |
| index = index[1:-1] |
| return f"{base}[{index}]" |
| |
| def visit_tuple_expr(self, node: TupleExpr) -> str: |
| return f"({', '.join(n.accept(self) for n in node.items)})" |
| |
| def visit_list_expr(self, node: ListExpr) -> str: |
| return f"[{', '.join(n.accept(self) for n in node.items)}]" |
| |
| def visit_dict_expr(self, o: DictExpr) -> str: |
| dict_items = [] |
| for key, value in o.items: |
| # This is currently only used for TypedDict where all keys are strings. |
| assert isinstance(key, StrExpr) |
| dict_items.append(f"{key.accept(self)}: {value.accept(self)}") |
| return f"{{{', '.join(dict_items)}}}" |
| |
| def visit_ellipsis(self, node: EllipsisExpr) -> str: |
| return "..." |
| |
| def visit_op_expr(self, o: OpExpr) -> str: |
| return f"{o.left.accept(self)} {o.op} {o.right.accept(self)}" |
| |
| |
| class ImportTracker: |
| """Record necessary imports during stub generation.""" |
| |
| def __init__(self) -> None: |
| # module_for['foo'] has the module name where 'foo' was imported from, or None if |
| # 'foo' is a module imported directly; examples |
| # 'from pkg.m import f as foo' ==> module_for['foo'] == 'pkg.m' |
| # 'from m import f' ==> module_for['f'] == 'm' |
| # 'import m' ==> module_for['m'] == None |
| # 'import pkg.m' ==> module_for['pkg.m'] == None |
| # ==> module_for['pkg'] == None |
| self.module_for: dict[str, str | None] = {} |
| |
| # direct_imports['foo'] is the module path used when the name 'foo' was added to the |
| # namespace. |
| # import foo.bar.baz ==> direct_imports['foo'] == 'foo.bar.baz' |
| # ==> direct_imports['foo.bar'] == 'foo.bar.baz' |
| # ==> direct_imports['foo.bar.baz'] == 'foo.bar.baz' |
| self.direct_imports: dict[str, str] = {} |
| |
| # reverse_alias['foo'] is the name that 'foo' had originally when imported with an |
| # alias; examples |
| # 'import numpy as np' ==> reverse_alias['np'] == 'numpy' |
| # 'import foo.bar as bar' ==> reverse_alias['bar'] == 'foo.bar' |
| # 'from decimal import Decimal as D' ==> reverse_alias['D'] == 'Decimal' |
| self.reverse_alias: dict[str, str] = {} |
| |
| # required_names is the set of names that are actually used in a type annotation |
| self.required_names: set[str] = set() |
| |
| # Names that should be reexported if they come from another module |
| self.reexports: set[str] = set() |
| |
| def add_import_from(self, module: str, names: list[tuple[str, str | None]]) -> None: |
| for name, alias in names: |
| if alias: |
| # 'from {module} import {name} as {alias}' |
| self.module_for[alias] = module |
| self.reverse_alias[alias] = name |
| else: |
| # 'from {module} import {name}' |
| self.module_for[name] = module |
| self.reverse_alias.pop(name, None) |
| self.direct_imports.pop(alias or name, None) |
| |
| def add_import(self, module: str, alias: str | None = None) -> None: |
| if alias: |
| # 'import {module} as {alias}' |
| self.module_for[alias] = None |
| self.reverse_alias[alias] = module |
| else: |
| # 'import {module}' |
| name = module |
| # add module and its parent packages |
| while name: |
| self.module_for[name] = None |
| self.direct_imports[name] = module |
| self.reverse_alias.pop(name, None) |
| name = name.rpartition(".")[0] |
| |
| def require_name(self, name: str) -> None: |
| self.required_names.add(name.split(".")[0]) |
| |
| def reexport(self, name: str) -> None: |
| """Mark a given non qualified name as needed in __all__. |
| |
| This means that in case it comes from a module, it should be |
| imported with an alias even is the alias is the same as the name. |
| """ |
| self.require_name(name) |
| self.reexports.add(name) |
| |
| def import_lines(self) -> list[str]: |
| """The list of required import lines (as strings with python code).""" |
| result = [] |
| |
| # To summarize multiple names imported from a same module, we collect those |
| # in the `module_map` dictionary, mapping a module path to the list of names that should |
| # be imported from it. the names can also be alias in the form 'original as alias' |
| module_map: Mapping[str, list[str]] = defaultdict(list) |
| |
| for name in sorted(self.required_names): |
| # If we haven't seen this name in an import statement, ignore it |
| if name not in self.module_for: |
| continue |
| |
| m = self.module_for[name] |
| if m is not None: |
| # This name was found in a from ... import ... |
| # Collect the name in the module_map |
| if name in self.reverse_alias: |
| name = f"{self.reverse_alias[name]} as {name}" |
| elif name in self.reexports: |
| name = f"{name} as {name}" |
| module_map[m].append(name) |
| else: |
| # This name was found in an import ... |
| # We can already generate the import line |
| if name in self.reverse_alias: |
| source = self.reverse_alias[name] |
| result.append(f"import {source} as {name}\n") |
| elif name in self.reexports: |
| assert "." not in name # Because reexports only has nonqualified names |
| result.append(f"import {name} as {name}\n") |
| else: |
| result.append(f"import {self.direct_imports[name]}\n") |
| |
| # Now generate all the from ... import ... lines collected in module_map |
| for module, names in sorted(module_map.items()): |
| result.append(f"from {module} import {', '.join(sorted(names))}\n") |
| return result |
| |
| |
| def find_defined_names(file: MypyFile) -> set[str]: |
| finder = DefinitionFinder() |
| file.accept(finder) |
| return finder.names |
| |
| |
| class DefinitionFinder(mypy.traverser.TraverserVisitor): |
| """Find names of things defined at the top level of a module.""" |
| |
| # TODO: Assignment statements etc. |
| |
| def __init__(self) -> None: |
| # Short names of things defined at the top level. |
| self.names: set[str] = set() |
| |
| def visit_class_def(self, o: ClassDef) -> None: |
| # Don't recurse into classes, as we only keep track of top-level definitions. |
| self.names.add(o.name) |
| |
| def visit_func_def(self, o: FuncDef) -> None: |
| # Don't recurse, as we only keep track of top-level definitions. |
| self.names.add(o.name) |
| |
| |
| def find_referenced_names(file: MypyFile) -> set[str]: |
| finder = ReferenceFinder() |
| file.accept(finder) |
| return finder.refs |
| |
| |
| class ReferenceFinder(mypy.mixedtraverser.MixedTraverserVisitor): |
| """Find all name references (both local and global).""" |
| |
| # TODO: Filter out local variable and class attribute references |
| |
| def __init__(self) -> None: |
| # Short names of things defined at the top level. |
| self.refs: set[str] = set() |
| |
| def visit_block(self, block: Block) -> None: |
| if not block.is_unreachable: |
| super().visit_block(block) |
| |
| def visit_name_expr(self, e: NameExpr) -> None: |
| self.refs.add(e.name) |
| |
| def visit_instance(self, t: Instance) -> None: |
| self.add_ref(t.type.fullname) |
| super().visit_instance(t) |
| |
| def visit_unbound_type(self, t: UnboundType) -> None: |
| if t.name: |
| self.add_ref(t.name) |
| |
| def visit_tuple_type(self, t: TupleType) -> None: |
| # Ignore fallback |
| for item in t.items: |
| item.accept(self) |
| |
| def visit_callable_type(self, t: CallableType) -> None: |
| # Ignore fallback |
| for arg in t.arg_types: |
| arg.accept(self) |
| t.ret_type.accept(self) |
| |
| def add_ref(self, fullname: str) -> None: |
| self.refs.add(fullname.split(".")[-1]) |
| |
| |
| class StubGenerator(mypy.traverser.TraverserVisitor): |
| """Generate stub text from a mypy AST.""" |
| |
| def __init__( |
| self, |
| _all_: list[str] | None, |
| include_private: bool = False, |
| analyzed: bool = False, |
| export_less: bool = False, |
| include_docstrings: bool = False, |
| ) -> None: |
| # Best known value of __all__. |
| self._all_ = _all_ |
| self._output: list[str] = [] |
| self._decorators: list[str] = [] |
| self._import_lines: list[str] = [] |
| # Current indent level (indent is hardcoded to 4 spaces). |
| self._indent = "" |
| # Stack of defined variables (per scope). |
| self._vars: list[list[str]] = [[]] |
| # What was generated previously in the stub file. |
| self._state = EMPTY |
| self._toplevel_names: list[str] = [] |
| self._include_private = include_private |
| self._include_docstrings = include_docstrings |
| self._current_class: ClassDef | None = None |
| self.import_tracker = ImportTracker() |
| # Was the tree semantically analysed before? |
| self.analyzed = analyzed |
| # Disable implicit exports of package-internal imports? |
| self.export_less = export_less |
| # Add imports that could be implicitly generated |
| self.import_tracker.add_import_from("typing", [("NamedTuple", None)]) |
| # Names in __all__ are required |
| for name in _all_ or (): |
| if name not in IGNORED_DUNDERS: |
| self.import_tracker.reexport(name) |
| self.defined_names: set[str] = set() |
| # Short names of methods defined in the body of the current class |
| self.method_names: set[str] = set() |
| |
| def visit_mypy_file(self, o: MypyFile) -> None: |
| self.module = o.fullname # Current module being processed |
| self.path = o.path |
| self.defined_names = find_defined_names(o) |
| self.referenced_names = find_referenced_names(o) |
| known_imports = { |
| "_typeshed": ["Incomplete"], |
| "typing": ["Any", "TypeVar", "NamedTuple"], |
| "collections.abc": ["Generator"], |
| "typing_extensions": ["TypedDict", "ParamSpec", "TypeVarTuple"], |
| } |
| for pkg, imports in known_imports.items(): |
| for t in imports: |
| if t not in self.defined_names: |
| alias = None |
| else: |
| alias = "_" + t |
| self.import_tracker.add_import_from(pkg, [(t, alias)]) |
| super().visit_mypy_file(o) |
| undefined_names = [name for name in self._all_ or [] if name not in self._toplevel_names] |
| if undefined_names: |
| if self._state != EMPTY: |
| self.add("\n") |
| self.add("# Names in __all__ with no definition:\n") |
| for name in sorted(undefined_names): |
| self.add(f"# {name}\n") |
| |
| def visit_overloaded_func_def(self, o: OverloadedFuncDef) -> None: |
| """@property with setters and getters, @overload chain and some others.""" |
| overload_chain = False |
| for item in o.items: |
| if not isinstance(item, Decorator): |
| continue |
| if self.is_private_name(item.func.name, item.func.fullname): |
| continue |
| |
| self.process_decorator(item) |
| if not overload_chain: |
| self.visit_func_def(item.func) |
| if item.func.is_overload: |
| overload_chain = True |
| elif item.func.is_overload: |
| self.visit_func_def(item.func) |
| else: |
| # skip the overload implementation and clear the decorator we just processed |
| self.clear_decorators() |
| |
| def visit_func_def(self, o: FuncDef) -> None: |
| if ( |
| self.is_private_name(o.name, o.fullname) |
| or self.is_not_in_all(o.name) |
| or (self.is_recorded_name(o.name) and not o.is_overload) |
| ): |
| self.clear_decorators() |
| return |
| if not self._indent and self._state not in (EMPTY, FUNC) and not o.is_awaitable_coroutine: |
| self.add("\n") |
| if not self.is_top_level(): |
| self_inits = find_self_initializers(o) |
| for init, value in self_inits: |
| if init in self.method_names: |
| # Can't have both an attribute and a method/property with the same name. |
| continue |
| init_code = self.get_init(init, value) |
| if init_code: |
| self.add(init_code) |
| # dump decorators, just before "def ..." |
| for s in self._decorators: |
| self.add(s) |
| self.clear_decorators() |
| self.add(f"{self._indent}{'async ' if o.is_coroutine else ''}def {o.name}(") |
| self.record_name(o.name) |
| args: list[str] = [] |
| for i, arg_ in enumerate(o.arguments): |
| var = arg_.variable |
| kind = arg_.kind |
| name = var.name |
| annotated_type = ( |
| o.unanalyzed_type.arg_types[i] |
| if isinstance(o.unanalyzed_type, CallableType) |
| else None |
| ) |
| # I think the name check is incorrect: there are libraries which |
| # name their 0th argument other than self/cls |
| is_self_arg = i == 0 and name == "self" |
| is_cls_arg = i == 0 and name == "cls" |
| annotation = "" |
| if annotated_type and not is_self_arg and not is_cls_arg: |
| # Luckily, an argument explicitly annotated with "Any" has |
| # type "UnboundType" and will not match. |
| if not isinstance(get_proper_type(annotated_type), AnyType): |
| annotation = f": {self.print_annotation(annotated_type)}" |
| |
| if kind.is_named() and not any(arg.startswith("*") for arg in args): |
| args.append("*") |
| |
| if arg_.initializer: |
| if not annotation: |
| typename = self.get_str_type_of_node(arg_.initializer, True, False) |
| if typename == "": |
| annotation = "=..." |
| else: |
| annotation = f": {typename} = ..." |
| else: |
| annotation += " = ..." |
| arg = name + annotation |
| elif kind == ARG_STAR: |
| arg = f"*{name}{annotation}" |
| elif kind == ARG_STAR2: |
| arg = f"**{name}{annotation}" |
| else: |
| arg = name + annotation |
| args.append(arg) |
| retname = None |
| if o.name != "__init__" and isinstance(o.unanalyzed_type, CallableType): |
| if isinstance(get_proper_type(o.unanalyzed_type.ret_type), AnyType): |
| # Luckily, a return type explicitly annotated with "Any" has |
| # type "UnboundType" and will enter the else branch. |
| retname = None # implicit Any |
| else: |
| retname = self.print_annotation(o.unanalyzed_type.ret_type) |
| elif o.abstract_status == IS_ABSTRACT or o.name in METHODS_WITH_RETURN_VALUE: |
| # Always assume abstract methods return Any unless explicitly annotated. Also |
| # some dunder methods should not have a None return type. |
| retname = None # implicit Any |
| elif o.name in KNOWN_MAGIC_METHODS_RETURN_TYPES: |
| retname = KNOWN_MAGIC_METHODS_RETURN_TYPES[o.name] |
| elif has_yield_expression(o) or has_yield_from_expression(o): |
| generator_name = self.add_typing_import("Generator") |
| yield_name = "None" |
| send_name = "None" |
| return_name = "None" |
| if has_yield_from_expression(o): |
| yield_name = send_name = self.add_typing_import("Incomplete") |
| else: |
| for expr, in_assignment in all_yield_expressions(o): |
| if expr.expr is not None and not self.is_none_expr(expr.expr): |
| yield_name = self.add_typing_import("Incomplete") |
| if in_assignment: |
| send_name = self.add_typing_import("Incomplete") |
| if has_return_statement(o): |
| return_name = self.add_typing_import("Incomplete") |
| retname = f"{generator_name}[{yield_name}, {send_name}, {return_name}]" |
| elif not has_return_statement(o) and o.abstract_status == NOT_ABSTRACT: |
| retname = "None" |
| retfield = "" |
| if retname is not None: |
| retfield = " -> " + retname |
| |
| self.add(", ".join(args)) |
| self.add(f"){retfield}:") |
| if self._include_docstrings and o.docstring: |
| docstring = mypy.util.quote_docstring(o.docstring) |
| self.add(f"\n{self._indent} {docstring}\n") |
| else: |
| self.add(" ...\n") |
| |
| self._state = FUNC |
| |
| def is_none_expr(self, expr: Expression) -> bool: |
| return isinstance(expr, NameExpr) and expr.name == "None" |
| |
| def visit_decorator(self, o: Decorator) -> None: |
| if self.is_private_name(o.func.name, o.func.fullname): |
| return |
| self.process_decorator(o) |
| self.visit_func_def(o.func) |
| |
| def process_decorator(self, o: Decorator) -> None: |
| """Process a series of decorators. |
| |
| Only preserve certain special decorators such as @abstractmethod. |
| """ |
| for decorator in o.original_decorators: |
| if not isinstance(decorator, (NameExpr, MemberExpr)): |
| continue |
| qualname = get_qualified_name(decorator) |
| fullname = self.get_fullname(decorator) |
| if fullname in ( |
| "builtins.property", |
| "builtins.staticmethod", |
| "builtins.classmethod", |
| "functools.cached_property", |
| ): |
| self.add_decorator(qualname, require_name=True) |
| elif fullname in ( |
| "asyncio.coroutine", |
| "asyncio.coroutines.coroutine", |
| "types.coroutine", |
| ): |
| o.func.is_awaitable_coroutine = True |
| self.add_decorator(qualname, require_name=True) |
| elif fullname == "abc.abstractmethod": |
| self.add_decorator(qualname, require_name=True) |
| o.func.abstract_status = IS_ABSTRACT |
| elif fullname in ( |
| "abc.abstractproperty", |
| "abc.abstractstaticmethod", |
| "abc.abstractclassmethod", |
| ): |
| abc_module = qualname.rpartition(".")[0] |
| if not abc_module: |
| self.import_tracker.add_import("abc") |
| builtin_decorator_replacement = fullname[len("abc.abstract") :] |
| self.add_decorator(builtin_decorator_replacement, require_name=False) |
| self.add_decorator(f"{abc_module or 'abc'}.abstractmethod", require_name=True) |
| o.func.abstract_status = IS_ABSTRACT |
| elif fullname in OVERLOAD_NAMES: |
| self.add_decorator(qualname, require_name=True) |
| o.func.is_overload = True |
| elif qualname.endswith(".setter"): |
| self.add_decorator(qualname, require_name=False) |
| |
| def get_fullname(self, expr: Expression) -> str: |
| """Return the full name resolving imports and import aliases.""" |
| if ( |
| self.analyzed |
| and isinstance(expr, (NameExpr, MemberExpr)) |
| and expr.fullname |
| and not (isinstance(expr.node, Var) and expr.node.is_suppressed_import) |
| ): |
| return expr.fullname |
| name = get_qualified_name(expr) |
| if "." not in name: |
| real_module = self.import_tracker.module_for.get(name) |
| real_short = self.import_tracker.reverse_alias.get(name, name) |
| if real_module is None and real_short not in self.defined_names: |
| real_module = "builtins" # not imported and not defined, must be a builtin |
| else: |
| name_module, real_short = name.split(".", 1) |
| real_module = self.import_tracker.reverse_alias.get(name_module, name_module) |
| resolved_name = real_short if real_module is None else f"{real_module}.{real_short}" |
| return resolved_name |
| |
| def visit_class_def(self, o: ClassDef) -> None: |
| self._current_class = o |
| self.method_names = find_method_names(o.defs.body) |
| sep: int | None = None |
| if not self._indent and self._state != EMPTY: |
| sep = len(self._output) |
| self.add("\n") |
| self.add(f"{self._indent}class {o.name}") |
| self.record_name(o.name) |
| base_types = self.get_base_types(o) |
| if base_types: |
| for base in base_types: |
| self.import_tracker.require_name(base) |
| if isinstance(o.metaclass, (NameExpr, MemberExpr)): |
| meta = o.metaclass.accept(AliasPrinter(self)) |
| base_types.append("metaclass=" + meta) |
| elif self.analyzed and o.info.is_abstract and not o.info.is_protocol: |
| base_types.append("metaclass=abc.ABCMeta") |
| self.import_tracker.add_import("abc") |
| self.import_tracker.require_name("abc") |
| if base_types: |
| self.add(f"({', '.join(base_types)})") |
| self.add(":\n") |
| self._indent += " " |
| if self._include_docstrings and o.docstring: |
| docstring = mypy.util.quote_docstring(o.docstring) |
| self.add(f"{self._indent}{docstring}\n") |
| n = len(self._output) |
| self._vars.append([]) |
| super().visit_class_def(o) |
| self._indent = self._indent[:-4] |
| self._vars.pop() |
| self._vars[-1].append(o.name) |
| if len(self._output) == n: |
| if self._state == EMPTY_CLASS and sep is not None: |
| self._output[sep] = "" |
| if not (self._include_docstrings and o.docstring): |
| self._output[-1] = self._output[-1][:-1] + " ...\n" |
| self._state = EMPTY_CLASS |
| else: |
| self._state = CLASS |
| self.method_names = set() |
| self._current_class = None |
| |
| def get_base_types(self, cdef: ClassDef) -> list[str]: |
| """Get list of base classes for a class.""" |
| base_types: list[str] = [] |
| p = AliasPrinter(self) |
| for base in cdef.base_type_exprs + cdef.removed_base_type_exprs: |
| if isinstance(base, (NameExpr, MemberExpr)): |
| if self.get_fullname(base) != "builtins.object": |
| base_types.append(get_qualified_name(base)) |
| elif isinstance(base, IndexExpr): |
| base_types.append(base.accept(p)) |
| elif isinstance(base, CallExpr): |
| # namedtuple(typename, fields), NamedTuple(typename, fields) calls can |
| # be used as a base class. The first argument is a string literal that |
| # is usually the same as the class name. |
| # |
| # Note: |
| # A call-based named tuple as a base class cannot be safely converted to |
| # a class-based NamedTuple definition because class attributes defined |
| # in the body of the class inheriting from the named tuple call are not |
| # namedtuple fields at runtime. |
| if self.is_namedtuple(base): |
| nt_fields = self._get_namedtuple_fields(base) |
| assert isinstance(base.args[0], StrExpr) |
| typename = base.args[0].value |
| if nt_fields is None: |
| # Invalid namedtuple() call, cannot determine fields |
| base_types.append(self.add_typing_import("Incomplete")) |
| continue |
| fields_str = ", ".join(f"({f!r}, {t})" for f, t in nt_fields) |
| namedtuple_name = self.add_typing_import("NamedTuple") |
| base_types.append(f"{namedtuple_name}({typename!r}, [{fields_str}])") |
| elif self.is_typed_namedtuple(base): |
| base_types.append(base.accept(p)) |
| else: |
| # At this point, we don't know what the base class is, so we |
| # just use Incomplete as the base class. |
| base_types.append(self.add_typing_import("Incomplete")) |
| for name, value in cdef.keywords.items(): |
| if name == "metaclass": |
| continue # handled separately |
| base_types.append(f"{name}={value.accept(p)}") |
| return base_types |
| |
| def visit_block(self, o: Block) -> None: |
| # Unreachable statements may be partially uninitialized and that may |
| # cause trouble. |
| if not o.is_unreachable: |
| super().visit_block(o) |
| |
| def visit_assignment_stmt(self, o: AssignmentStmt) -> None: |
| foundl = [] |
| |
| for lvalue in o.lvalues: |
| if isinstance(lvalue, NameExpr) and isinstance(o.rvalue, CallExpr): |
| if self.is_namedtuple(o.rvalue) or self.is_typed_namedtuple(o.rvalue): |
| self.process_namedtuple(lvalue, o.rvalue) |
| foundl.append(False) # state is updated in process_namedtuple |
| continue |
| if self.is_typeddict(o.rvalue): |
| self.process_typeddict(lvalue, o.rvalue) |
| foundl.append(False) # state is updated in process_typeddict |
| continue |
| if ( |
| isinstance(lvalue, NameExpr) |
| and not self.is_private_name(lvalue.name) |
| # it is never an alias with explicit annotation |
| and not o.unanalyzed_type |
| and self.is_alias_expression(o.rvalue) |
| ): |
| self.process_typealias(lvalue, o.rvalue) |
| continue |
| if isinstance(lvalue, (TupleExpr, ListExpr)): |
| items = lvalue.items |
| if isinstance(o.unanalyzed_type, TupleType): # type: ignore[misc] |
| annotations: Iterable[Type | None] = o.unanalyzed_type.items |
| else: |
| annotations = [None] * len(items) |
| else: |
| items = [lvalue] |
| annotations = [o.unanalyzed_type] |
| sep = False |
| found = False |
| for item, annotation in zip(items, annotations): |
| if isinstance(item, NameExpr): |
| init = self.get_init(item.name, o.rvalue, annotation) |
| if init: |
| found = True |
| if not sep and not self._indent and self._state not in (EMPTY, VAR): |
| init = "\n" + init |
| sep = True |
| self.add(init) |
| self.record_name(item.name) |
| foundl.append(found) |
| |
| if all(foundl): |
| self._state = VAR |
| |
| def is_namedtuple(self, expr: CallExpr) -> bool: |
| return self.get_fullname(expr.callee) == "collections.namedtuple" |
| |
| def is_typed_namedtuple(self, expr: CallExpr) -> bool: |
| return self.get_fullname(expr.callee) in TYPED_NAMEDTUPLE_NAMES |
| |
| def _get_namedtuple_fields(self, call: CallExpr) -> list[tuple[str, str]] | None: |
| if self.is_namedtuple(call): |
| fields_arg = call.args[1] |
| if isinstance(fields_arg, StrExpr): |
| field_names = fields_arg.value.replace(",", " ").split() |
| elif isinstance(fields_arg, (ListExpr, TupleExpr)): |
| field_names = [] |
| for field in fields_arg.items: |
| if not isinstance(field, StrExpr): |
| return None |
| field_names.append(field.value) |
| else: |
| return None # Invalid namedtuple fields type |
| if not field_names: |
| return [] |
| incomplete = self.add_typing_import("Incomplete") |
| return [(field_name, incomplete) for field_name in field_names] |
| elif self.is_typed_namedtuple(call): |
| fields_arg = call.args[1] |
| if not isinstance(fields_arg, (ListExpr, TupleExpr)): |
| return None |
| fields: list[tuple[str, str]] = [] |
| p = AliasPrinter(self) |
| for field in fields_arg.items: |
| if not (isinstance(field, TupleExpr) and len(field.items) == 2): |
| return None |
| field_name, field_type = field.items |
| if not isinstance(field_name, StrExpr): |
| return None |
| fields.append((field_name.value, field_type.accept(p))) |
| return fields |
| else: |
| return None # Not a named tuple call |
| |
| def process_namedtuple(self, lvalue: NameExpr, rvalue: CallExpr) -> None: |
| if self._state == CLASS: |
| self.add("\n") |
| |
| if not isinstance(rvalue.args[0], StrExpr): |
| self.annotate_as_incomplete(lvalue) |
| return |
| |
| fields = self._get_namedtuple_fields(rvalue) |
| if fields is None: |
| self.annotate_as_incomplete(lvalue) |
| return |
| bases = self.add_typing_import("NamedTuple") |
| # TODO: Add support for generic NamedTuples. Requires `Generic` as base class. |
| class_def = f"{self._indent}class {lvalue.name}({bases}):" |
| if len(fields) == 0: |
| self.add(f"{class_def} ...\n") |
| self._state = EMPTY_CLASS |
| else: |
| if self._state not in (EMPTY, CLASS): |
| self.add("\n") |
| self.add(f"{class_def}\n") |
| for f_name, f_type in fields: |
| self.add(f"{self._indent} {f_name}: {f_type}\n") |
| self._state = CLASS |
| |
| def is_typeddict(self, expr: CallExpr) -> bool: |
| return self.get_fullname(expr.callee) in TPDICT_NAMES |
| |
| def process_typeddict(self, lvalue: NameExpr, rvalue: CallExpr) -> None: |
| if self._state == CLASS: |
| self.add("\n") |
| |
| if not isinstance(rvalue.args[0], StrExpr): |
| self.annotate_as_incomplete(lvalue) |
| return |
| |
| items: list[tuple[str, Expression]] = [] |
| total: Expression | None = None |
| if len(rvalue.args) > 1 and rvalue.arg_kinds[1] == ARG_POS: |
| if not isinstance(rvalue.args[1], DictExpr): |
| self.annotate_as_incomplete(lvalue) |
| return |
| for attr_name, attr_type in rvalue.args[1].items: |
| if not isinstance(attr_name, StrExpr): |
| self.annotate_as_incomplete(lvalue) |
| return |
| items.append((attr_name.value, attr_type)) |
| if len(rvalue.args) > 2: |
| if rvalue.arg_kinds[2] != ARG_NAMED or rvalue.arg_names[2] != "total": |
| self.annotate_as_incomplete(lvalue) |
| return |
| total = rvalue.args[2] |
| else: |
| for arg_name, arg in zip(rvalue.arg_names[1:], rvalue.args[1:]): |
| if not isinstance(arg_name, str): |
| self.annotate_as_incomplete(lvalue) |
| return |
| if arg_name == "total": |
| total = arg |
| else: |
| items.append((arg_name, arg)) |
| bases = self.add_typing_import("TypedDict") |
| p = AliasPrinter(self) |
| if any(not key.isidentifier() or keyword.iskeyword(key) for key, _ in items): |
| # Keep the call syntax if there are non-identifier or reserved keyword keys. |
| self.add(f"{self._indent}{lvalue.name} = {rvalue.accept(p)}\n") |
| self._state = VAR |
| else: |
| # TODO: Add support for generic TypedDicts. Requires `Generic` as base class. |
| if total is not None: |
| bases += f", total={total.accept(p)}" |
| class_def = f"{self._indent}class {lvalue.name}({bases}):" |
| if len(items) == 0: |
| self.add(f"{class_def} ...\n") |
| self._state = EMPTY_CLASS |
| else: |
| if self._state not in (EMPTY, CLASS): |
| self.add("\n") |
| self.add(f"{class_def}\n") |
| for key, key_type in items: |
| self.add(f"{self._indent} {key}: {key_type.accept(p)}\n") |
| self._state = CLASS |
| |
| def annotate_as_incomplete(self, lvalue: NameExpr) -> None: |
| self.add(f"{self._indent}{lvalue.name}: {self.add_typing_import('Incomplete')}\n") |
| self._state = VAR |
| |
| def is_alias_expression(self, expr: Expression, top_level: bool = True) -> bool: |
| """Return True for things that look like target for an alias. |
| |
| Used to know if assignments look like type aliases, function alias, |
| or module alias. |
| """ |
| # Assignment of TypeVar(...) and other typevar-likes are passed through |
| if isinstance(expr, CallExpr) and self.get_fullname(expr.callee) in ( |
| "typing.TypeVar", |
| "typing_extensions.TypeVar", |
| "typing.ParamSpec", |
| "typing_extensions.ParamSpec", |
| "typing.TypeVarTuple", |
| "typing_extensions.TypeVarTuple", |
| ): |
| return True |
| elif isinstance(expr, EllipsisExpr): |
| return not top_level |
| elif isinstance(expr, NameExpr): |
| if expr.name in ("True", "False"): |
| return False |
| elif expr.name == "None": |
| return not top_level |
| else: |
| return not self.is_private_name(expr.name) |
| elif isinstance(expr, MemberExpr) and self.analyzed: |
| # Also add function and module aliases. |
| return ( |
| top_level |
| and isinstance(expr.node, (FuncDef, Decorator, MypyFile)) |
| or isinstance(expr.node, TypeInfo) |
| ) and not self.is_private_member(expr.node.fullname) |
| elif ( |
| isinstance(expr, IndexExpr) |
| and isinstance(expr.base, NameExpr) |
| and not self.is_private_name(expr.base.name) |
| ): |
| if isinstance(expr.index, TupleExpr): |
| indices = expr.index.items |
| else: |
| indices = [expr.index] |
| if expr.base.name == "Callable" and len(indices) == 2: |
| args, ret = indices |
| if isinstance(args, EllipsisExpr): |
| indices = [ret] |
| elif isinstance(args, ListExpr): |
| indices = args.items + [ret] |
| else: |
| return False |
| return all(self.is_alias_expression(i, top_level=False) for i in indices) |
| else: |
| return False |
| |
| def process_typealias(self, lvalue: NameExpr, rvalue: Expression) -> None: |
| p = AliasPrinter(self) |
| self.add(f"{self._indent}{lvalue.name} = {rvalue.accept(p)}\n") |
| self.record_name(lvalue.name) |
| self._vars[-1].append(lvalue.name) |
| |
| def visit_if_stmt(self, o: IfStmt) -> None: |
| # Ignore if __name__ == '__main__'. |
| expr = o.expr[0] |
| if ( |
| isinstance(expr, ComparisonExpr) |
| and isinstance(expr.operands[0], NameExpr) |
| and isinstance(expr.operands[1], StrExpr) |
| and expr.operands[0].name == "__name__" |
| and "__main__" in expr.operands[1].value |
| ): |
| return |
| super().visit_if_stmt(o) |
| |
| def visit_import_all(self, o: ImportAll) -> None: |
| self.add_import_line(f"from {'.' * o.relative}{o.id} import *\n") |
| |
| def visit_import_from(self, o: ImportFrom) -> None: |
| exported_names: set[str] = set() |
| import_names = [] |
| module, relative = translate_module_name(o.id, o.relative) |
| if self.module: |
| full_module, ok = mypy.util.correct_relative_import( |
| self.module, relative, module, self.path.endswith(".__init__.py") |
| ) |
| if not ok: |
| full_module = module |
| else: |
| full_module = module |
| if module == "__future__": |
| return # Not preserved |
| for name, as_name in o.names: |
| if name == "six": |
| # Vendored six -- translate into plain 'import six'. |
| self.visit_import(Import([("six", None)])) |
| continue |
| exported = False |
| if as_name is None and self.module and (self.module + "." + name) in EXTRA_EXPORTED: |
| # Special case certain names that should be exported, against our general rules. |
| exported = True |
| is_private = self.is_private_name(name, full_module + "." + name) |
| if ( |
| as_name is None |
| and name not in self.referenced_names |
| and (not self._all_ or name in IGNORED_DUNDERS) |
| and not is_private |
| and module not in ("abc", "asyncio") + TYPING_MODULE_NAMES |
| ): |
| # An imported name that is never referenced in the module is assumed to be |
| # exported, unless there is an explicit __all__. Note that we need to special |
| # case 'abc' since some references are deleted during semantic analysis. |
| exported = True |
| top_level = full_module.split(".")[0] |
| if ( |
| as_name is None |
| and not self.export_less |
| and (not self._all_ or name in IGNORED_DUNDERS) |
| and self.module |
| and not is_private |
| and top_level in (self.module.split(".")[0], "_" + self.module.split(".")[0]) |
| ): |
| # Export imports from the same package, since we can't reliably tell whether they |
| # are part of the public API. |
| exported = True |
| if exported: |
| self.import_tracker.reexport(name) |
| as_name = name |
| import_names.append((name, as_name)) |
| self.import_tracker.add_import_from("." * relative + module, import_names) |
| self._vars[-1].extend(alias or name for name, alias in import_names) |
| for name, alias in import_names: |
| self.record_name(alias or name) |
| |
| if self._all_: |
| # Include "import from"s that import names defined in __all__. |
| names = [ |
| name |
| for name, alias in o.names |
| if name in self._all_ and alias is None and name not in IGNORED_DUNDERS |
| ] |
| exported_names.update(names) |
| |
| def visit_import(self, o: Import) -> None: |
| for id, as_id in o.ids: |
| self.import_tracker.add_import(id, as_id) |
| if as_id is None: |
| target_name = id.split(".")[0] |
| else: |
| target_name = as_id |
| self._vars[-1].append(target_name) |
| self.record_name(target_name) |
| |
| def get_init( |
| self, lvalue: str, rvalue: Expression, annotation: Type | None = None |
| ) -> str | None: |
| """Return initializer for a variable. |
| |
| Return None if we've generated one already or if the variable is internal. |
| """ |
| if lvalue in self._vars[-1]: |
| # We've generated an initializer already for this variable. |
| return None |
| # TODO: Only do this at module top level. |
| if self.is_private_name(lvalue) or self.is_not_in_all(lvalue): |
| return None |
| self._vars[-1].append(lvalue) |
| if annotation is not None: |
| typename = self.print_annotation(annotation) |
| if ( |
| isinstance(annotation, UnboundType) |
| and not annotation.args |
| and annotation.name == "Final" |
| and self.import_tracker.module_for.get("Final") in TYPING_MODULE_NAMES |
| ): |
| # Final without type argument is invalid in stubs. |
| final_arg = self.get_str_type_of_node(rvalue) |
| typename += f"[{final_arg}]" |
| else: |
| typename = self.get_str_type_of_node(rvalue) |
| initializer = self.get_assign_initializer(rvalue) |
| return f"{self._indent}{lvalue}: {typename}{initializer}\n" |
| |
| def get_assign_initializer(self, rvalue: Expression) -> str: |
| """Does this rvalue need some special initializer value?""" |
| if self._current_class and self._current_class.info: |
| # Current rules |
| # 1. Return `...` if we are dealing with `NamedTuple` and it has an existing default value |
| if self._current_class.info.is_named_tuple and not isinstance(rvalue, TempNode): |
| return " = ..." |
| # TODO: support other possible cases, where initializer is important |
| |
| # By default, no initializer is required: |
| return "" |
| |
| def add(self, string: str) -> None: |
| """Add text to generated stub.""" |
| self._output.append(string) |
| |
| def add_decorator(self, name: str, require_name: bool = False) -> None: |
| if require_name: |
| self.import_tracker.require_name(name) |
| if not self._indent and self._state not in (EMPTY, FUNC): |
| self._decorators.append("\n") |
| self._decorators.append(f"{self._indent}@{name}\n") |
| |
| def clear_decorators(self) -> None: |
| self._decorators.clear() |
| |
| def typing_name(self, name: str) -> str: |
| if name in self.defined_names: |
| # Avoid name clash between name from typing and a name defined in stub. |
| return "_" + name |
| else: |
| return name |
| |
| def add_typing_import(self, name: str) -> str: |
| """Add a name to be imported for typing, unless it's imported already. |
| |
| The import will be internal to the stub. |
| """ |
| name = self.typing_name(name) |
| self.import_tracker.require_name(name) |
| return name |
| |
| def add_import_line(self, line: str) -> None: |
| """Add a line of text to the import section, unless it's already there.""" |
| if line not in self._import_lines: |
| self._import_lines.append(line) |
| |
| def output(self) -> str: |
| """Return the text for the stub.""" |
| imports = "" |
| if self._import_lines: |
| imports += "".join(self._import_lines) |
| imports += "".join(self.import_tracker.import_lines()) |
| if imports and self._output: |
| imports += "\n" |
| return imports + "".join(self._output) |
| |
| def is_not_in_all(self, name: str) -> bool: |
| if self.is_private_name(name): |
| return False |
| if self._all_: |
| return self.is_top_level() and name not in self._all_ |
| return False |
| |
| def is_private_name(self, name: str, fullname: str | None = None) -> bool: |
| if self._include_private: |
| return False |
| if fullname in EXTRA_EXPORTED: |
| return False |
| return name.startswith("_") and (not name.endswith("__") or name in IGNORED_DUNDERS) |
| |
| def is_private_member(self, fullname: str) -> bool: |
| parts = fullname.split(".") |
| return any(self.is_private_name(part) for part in parts) |
| |
| def get_str_type_of_node( |
| self, rvalue: Expression, can_infer_optional: bool = False, can_be_any: bool = True |
| ) -> str: |
| rvalue = self.maybe_unwrap_unary_expr(rvalue) |
| |
| if isinstance(rvalue, IntExpr): |
| return "int" |
| if isinstance(rvalue, StrExpr): |
| return "str" |
| if isinstance(rvalue, BytesExpr): |
| return "bytes" |
| if isinstance(rvalue, FloatExpr): |
| return "float" |
| if isinstance(rvalue, ComplexExpr): # 1j |
| return "complex" |
| if isinstance(rvalue, OpExpr) and rvalue.op in ("-", "+"): # -1j + 1 |
| if isinstance(self.maybe_unwrap_unary_expr(rvalue.left), ComplexExpr) or isinstance( |
| self.maybe_unwrap_unary_expr(rvalue.right), ComplexExpr |
| ): |
| return "complex" |
| if isinstance(rvalue, NameExpr) and rvalue.name in ("True", "False"): |
| return "bool" |
| if can_infer_optional and isinstance(rvalue, NameExpr) and rvalue.name == "None": |
| return f"{self.add_typing_import('Incomplete')} | None" |
| if can_be_any: |
| return self.add_typing_import("Incomplete") |
| else: |
| return "" |
| |
| def maybe_unwrap_unary_expr(self, expr: Expression) -> Expression: |
| """Unwrap (possibly nested) unary expressions. |
| |
| But, some unary expressions can change the type of expression. |
| While we want to preserve it. For example, `~True` is `int`. |
| So, we only allow a subset of unary expressions to be unwrapped. |
| """ |
| if not isinstance(expr, UnaryExpr): |
| return expr |
| |
| # First, try to unwrap `[+-]+ (int|float|complex)` expr: |
| math_ops = ("+", "-") |
| if expr.op in math_ops: |
| while isinstance(expr, UnaryExpr): |
| if expr.op not in math_ops or not isinstance( |
| expr.expr, (IntExpr, FloatExpr, ComplexExpr, UnaryExpr) |
| ): |
| break |
| expr = expr.expr |
| return expr |
| |
| # Next, try `not bool` expr: |
| if expr.op == "not": |
| while isinstance(expr, UnaryExpr): |
| if expr.op != "not" or not isinstance(expr.expr, (NameExpr, UnaryExpr)): |
| break |
| if isinstance(expr.expr, NameExpr) and expr.expr.name not in ("True", "False"): |
| break |
| expr = expr.expr |
| return expr |
| |
| # This is some other unary expr, we cannot do anything with it (yet?). |
| return expr |
| |
| def print_annotation(self, t: Type) -> str: |
| printer = AnnotationPrinter(self) |
| return t.accept(printer) |
| |
| def is_top_level(self) -> bool: |
| """Are we processing the top level of a file?""" |
| return self._indent == "" |
| |
| def record_name(self, name: str) -> None: |
| """Mark a name as defined. |
| |
| This only does anything if at the top level of a module. |
| """ |
| if self.is_top_level(): |
| self._toplevel_names.append(name) |
| |
| def is_recorded_name(self, name: str) -> bool: |
| """Has this name been recorded previously?""" |
| return self.is_top_level() and name in self._toplevel_names |
| |
| |
| def find_method_names(defs: list[Statement]) -> set[str]: |
| # TODO: Traverse into nested definitions |
| result = set() |
| for defn in defs: |
| if isinstance(defn, FuncDef): |
| result.add(defn.name) |
| elif isinstance(defn, Decorator): |
| result.add(defn.func.name) |
| elif isinstance(defn, OverloadedFuncDef): |
| for item in defn.items: |
| result.update(find_method_names([item])) |
| return result |
| |
| |
| class SelfTraverser(mypy.traverser.TraverserVisitor): |
| def __init__(self) -> None: |
| self.results: list[tuple[str, Expression]] = [] |
| |
| def visit_assignment_stmt(self, o: AssignmentStmt) -> None: |
| lvalue = o.lvalues[0] |
| if ( |
| isinstance(lvalue, MemberExpr) |
| and isinstance(lvalue.expr, NameExpr) |
| and lvalue.expr.name == "self" |
| ): |
| self.results.append((lvalue.name, o.rvalue)) |
| |
| |
| def find_self_initializers(fdef: FuncBase) -> list[tuple[str, Expression]]: |
| """Find attribute initializers in a method. |
| |
| Return a list of pairs (attribute name, r.h.s. expression). |
| """ |
| traverser = SelfTraverser() |
| fdef.accept(traverser) |
| return traverser.results |
| |
| |
| def get_qualified_name(o: Expression) -> str: |
| if isinstance(o, NameExpr): |
| return o.name |
| elif isinstance(o, MemberExpr): |
| return f"{get_qualified_name(o.expr)}.{o.name}" |
| else: |
| return ERROR_MARKER |
| |
| |
| def remove_blacklisted_modules(modules: list[StubSource]) -> list[StubSource]: |
| return [ |
| module for module in modules if module.path is None or not is_blacklisted_path(module.path) |
| ] |
| |
| |
| def is_blacklisted_path(path: str) -> bool: |
| return any(substr in (normalize_path_separators(path) + "\n") for substr in BLACKLIST) |
| |
| |
| def normalize_path_separators(path: str) -> str: |
| if sys.platform == "win32": |
| return path.replace("\\", "/") |
| return path |
| |
| |
| def collect_build_targets( |
| options: Options, mypy_opts: MypyOptions |
| ) -> tuple[list[StubSource], list[StubSource]]: |
| """Collect files for which we need to generate stubs. |
| |
| Return list of Python modules and C modules. |
| """ |
| if options.packages or options.modules: |
| if options.no_import: |
| py_modules = find_module_paths_using_search( |
| options.modules, options.packages, options.search_path, options.pyversion |
| ) |
| c_modules: list[StubSource] = [] |
| else: |
| # Using imports is the default, since we can also find C modules. |
| py_modules, c_modules = find_module_paths_using_imports( |
| options.modules, options.packages, options.verbose, options.quiet |
| ) |
| else: |
| # Use mypy native source collection for files and directories. |
| try: |
| source_list = create_source_list(options.files, mypy_opts) |
| except InvalidSourceList as e: |
| raise SystemExit(str(e)) from e |
| py_modules = [StubSource(m.module, m.path) for m in source_list] |
| c_modules = [] |
| |
| py_modules = remove_blacklisted_modules(py_modules) |
| |
| return py_modules, c_modules |
| |
| |
| def find_module_paths_using_imports( |
| modules: list[str], packages: list[str], verbose: bool, quiet: bool |
| ) -> tuple[list[StubSource], list[StubSource]]: |
| """Find path and runtime value of __all__ (if possible) for modules and packages. |
| |
| This function uses runtime Python imports to get the information. |
| """ |
| with ModuleInspect() as inspect: |
| py_modules: list[StubSource] = [] |
| c_modules: list[StubSource] = [] |
| found = list(walk_packages(inspect, packages, verbose)) |
| modules = modules + found |
| modules = [ |
| mod for mod in modules if not is_non_library_module(mod) |
| ] # We don't want to run any tests or scripts |
| for mod in modules: |
| try: |
| result = find_module_path_and_all_py3(inspect, mod, verbose) |
| except CantImport as e: |
| tb = traceback.format_exc() |
| if verbose: |
| sys.stdout.write(tb) |
| if not quiet: |
| report_missing(mod, e.message, tb) |
| continue |
| if not result: |
| c_modules.append(StubSource(mod)) |
| else: |
| path, runtime_all = result |
| py_modules.append(StubSource(mod, path, runtime_all)) |
| return py_modules, c_modules |
| |
| |
| def is_non_library_module(module: str) -> bool: |
| """Does module look like a test module or a script?""" |
| if module.endswith( |
| ( |
| ".tests", |
| ".test", |
| ".testing", |
| "_tests", |
| "_test_suite", |
| "test_util", |
| "test_utils", |
| "test_base", |
| ".__main__", |
| ".conftest", # Used by pytest |
| ".setup", # Typically an install script |
| ) |
| ): |
| return True |
| if module.split(".")[-1].startswith("test_"): |
| return True |
| if ( |
| ".tests." in module |
| or ".test." in module |
| or ".testing." in module |
| or ".SelfTest." in module |
| ): |
| return True |
| return False |
| |
| |
| def translate_module_name(module: str, relative: int) -> tuple[str, int]: |
| for pkg in VENDOR_PACKAGES: |
| for alt in "six.moves", "six": |
| substr = f"{pkg}.{alt}" |
| if module.endswith("." + substr) or (module == substr and relative): |
| return alt, 0 |
| if "." + substr + "." in module: |
| return alt + "." + module.partition("." + substr + ".")[2], 0 |
| return module, relative |
| |
| |
| def find_module_paths_using_search( |
| modules: list[str], packages: list[str], search_path: list[str], pyversion: tuple[int, int] |
| ) -> list[StubSource]: |
| """Find sources for modules and packages requested. |
| |
| This function just looks for source files at the file system level. |
| This is used if user passes --no-import, and will not find C modules. |
| Exit if some of the modules or packages can't be found. |
| """ |
| result: list[StubSource] = [] |
| typeshed_path = default_lib_path(mypy.build.default_data_dir(), pyversion, None) |
| search_paths = SearchPaths((".",) + tuple(search_path), (), (), tuple(typeshed_path)) |
| cache = FindModuleCache(search_paths, fscache=None, options=None) |
| for module in modules: |
| m_result = cache.find_module(module) |
| if isinstance(m_result, ModuleNotFoundReason): |
| fail_missing(module, m_result) |
| module_path = None |
| else: |
| module_path = m_result |
| result.append(StubSource(module, module_path)) |
| for package in packages: |
| p_result = cache.find_modules_recursive(package) |
| if p_result: |
| fail_missing(package, ModuleNotFoundReason.NOT_FOUND) |
| sources = [StubSource(m.module, m.path) for m in p_result] |
| result.extend(sources) |
| |
| result = [m for m in result if not is_non_library_module(m.module)] |
| |
| return result |
| |
| |
| def mypy_options(stubgen_options: Options) -> MypyOptions: |
| """Generate mypy options using the flag passed by user.""" |
| options = MypyOptions() |
| options.follow_imports = "skip" |
| options.incremental = False |
| options.ignore_errors = True |
| options.semantic_analysis_only = True |
| options.python_version = stubgen_options.pyversion |
| options.show_traceback = True |
| options.transform_source = remove_misplaced_type_comments |
| options.preserve_asts = True |
| options.include_docstrings = stubgen_options.include_docstrings |
| |
| # Override cache_dir if provided in the environment |
| environ_cache_dir = os.getenv("MYPY_CACHE_DIR", "") |
| if environ_cache_dir.strip(): |
| options.cache_dir = environ_cache_dir |
| options.cache_dir = os.path.expanduser(options.cache_dir) |
| |
| return options |
| |
| |
| def parse_source_file(mod: StubSource, mypy_options: MypyOptions) -> None: |
| """Parse a source file. |
| |
| On success, store AST in the corresponding attribute of the stub source. |
| If there are syntax errors, print them and exit. |
| """ |
| assert mod.path is not None, "Not found module was not skipped" |
| with open(mod.path, "rb") as f: |
| data = f.read() |
| source = mypy.util.decode_python_encoding(data) |
| errors = Errors(mypy_options) |
| mod.ast = mypy.parse.parse( |
| source, fnam=mod.path, module=mod.module, errors=errors, options=mypy_options |
| ) |
| mod.ast._fullname = mod.module |
| if errors.is_blockers(): |
| # Syntax error! |
| for m in errors.new_messages(): |
| sys.stderr.write(f"{m}\n") |
| sys.exit(1) |
| |
| |
| def generate_asts_for_modules( |
| py_modules: list[StubSource], parse_only: bool, mypy_options: MypyOptions, verbose: bool |
| ) -> None: |
| """Use mypy to parse (and optionally analyze) source files.""" |
| if not py_modules: |
| return # Nothing to do here, but there may be C modules |
| if verbose: |
| print(f"Processing {len(py_modules)} files...") |
| if parse_only: |
| for mod in py_modules: |
| parse_source_file(mod, mypy_options) |
| return |
| # Perform full semantic analysis of the source set. |
| try: |
| res = build([module.source for module in py_modules], mypy_options) |
| except CompileError as e: |
| raise SystemExit(f"Critical error during semantic analysis: {e}") from e |
| |
| for mod in py_modules: |
| mod.ast = res.graph[mod.module].tree |
| # Use statically inferred __all__ if there is no runtime one. |
| if mod.runtime_all is None: |
| mod.runtime_all = res.manager.semantic_analyzer.export_map[mod.module] |
| |
| |
| def generate_stub_from_ast( |
| mod: StubSource, |
| target: str, |
| parse_only: bool = False, |
| include_private: bool = False, |
| export_less: bool = False, |
| include_docstrings: bool = False, |
| ) -> None: |
| """Use analysed (or just parsed) AST to generate type stub for single file. |
| |
| If directory for target doesn't exist it will created. Existing stub |
| will be overwritten. |
| """ |
| gen = StubGenerator( |
| mod.runtime_all, |
| include_private=include_private, |
| analyzed=not parse_only, |
| export_less=export_less, |
| include_docstrings=include_docstrings, |
| ) |
| assert mod.ast is not None, "This function must be used only with analyzed modules" |
| mod.ast.accept(gen) |
| |
| # Write output to file. |
| subdir = os.path.dirname(target) |
| if subdir and not os.path.isdir(subdir): |
| os.makedirs(subdir) |
| with open(target, "w") as file: |
| file.write("".join(gen.output())) |
| |
| |
| def get_sig_generators(options: Options) -> list[SignatureGenerator]: |
| sig_generators: list[SignatureGenerator] = [ |
| DocstringSignatureGenerator(), |
| FallbackSignatureGenerator(), |
| ] |
| if options.doc_dir: |
| # Collect info from docs (if given). Always check these first. |
| sigs, class_sigs = collect_docs_signatures(options.doc_dir) |
| sig_generators.insert(0, ExternalSignatureGenerator(sigs, class_sigs)) |
| return sig_generators |
| |
| |
| def collect_docs_signatures(doc_dir: str) -> tuple[dict[str, str], dict[str, str]]: |
| """Gather all function and class signatures in the docs. |
| |
| Return a tuple (function signatures, class signatures). |
| Currently only used for C modules. |
| """ |
| all_sigs: list[Sig] = [] |
| all_class_sigs: list[Sig] = [] |
| for path in glob.glob(f"{doc_dir}/*.rst"): |
| with open(path) as f: |
| loc_sigs, loc_class_sigs = parse_all_signatures(f.readlines()) |
| all_sigs += loc_sigs |
| all_class_sigs += loc_class_sigs |
| sigs = dict(find_unique_signatures(all_sigs)) |
| class_sigs = dict(find_unique_signatures(all_class_sigs)) |
| return sigs, class_sigs |
| |
| |
| def generate_stubs(options: Options) -> None: |
| """Main entry point for the program.""" |
| mypy_opts = mypy_options(options) |
| py_modules, c_modules = collect_build_targets(options, mypy_opts) |
| sig_generators = get_sig_generators(options) |
| # Use parsed sources to generate stubs for Python modules. |
| generate_asts_for_modules(py_modules, options.parse_only, mypy_opts, options.verbose) |
| files = [] |
| for mod in py_modules: |
| assert mod.path is not None, "Not found module was not skipped" |
| target = mod.module.replace(".", "/") |
| if os.path.basename(mod.path) == "__init__.py": |
| target += "/__init__.pyi" |
| else: |
| target += ".pyi" |
| target = os.path.join(options.output_dir, target) |
| files.append(target) |
| with generate_guarded(mod.module, target, options.ignore_errors, options.verbose): |
| generate_stub_from_ast( |
| mod, |
| target, |
| options.parse_only, |
| options.include_private, |
| options.export_less, |
| include_docstrings=options.include_docstrings, |
| ) |
| |
| # Separately analyse C modules using different logic. |
| all_modules = sorted(m.module for m in (py_modules + c_modules)) |
| for mod in c_modules: |
| if any(py_mod.module.startswith(mod.module + ".") for py_mod in py_modules + c_modules): |
| target = mod.module.replace(".", "/") + "/__init__.pyi" |
| else: |
| target = mod.module.replace(".", "/") + ".pyi" |
| target = os.path.join(options.output_dir, target) |
| files.append(target) |
| with generate_guarded(mod.module, target, options.ignore_errors, options.verbose): |
| generate_stub_for_c_module( |
| mod.module, |
| target, |
| known_modules=all_modules, |
| sig_generators=sig_generators, |
| include_docstrings=options.include_docstrings, |
| ) |
| num_modules = len(py_modules) + len(c_modules) |
| if not options.quiet and num_modules > 0: |
| print("Processed %d modules" % num_modules) |
| if len(files) == 1: |
| print(f"Generated {files[0]}") |
| else: |
| print(f"Generated files under {common_dir_prefix(files)}" + os.sep) |
| |
| |
| HEADER = """%(prog)s [-h] [more options, see -h] |
| [-m MODULE] [-p PACKAGE] [files ...]""" |
| |
| DESCRIPTION = """ |
| Generate draft stubs for modules. |
| |
| Stubs are generated in directory ./out, to avoid overriding files with |
| manual changes. This directory is assumed to exist. |
| """ |
| |
| |
| def parse_options(args: list[str]) -> Options: |
| parser = argparse.ArgumentParser(prog="stubgen", usage=HEADER, description=DESCRIPTION) |
| |
| parser.add_argument( |
| "--ignore-errors", |
| action="store_true", |
| help="ignore errors when trying to generate stubs for modules", |
| ) |
| parser.add_argument( |
| "--no-import", |
| action="store_true", |
| help="don't import the modules, just parse and analyze them " |
| "(doesn't work with C extension modules and might not " |
| "respect __all__)", |
| ) |
| parser.add_argument( |
| "--parse-only", |
| action="store_true", |
| help="don't perform semantic analysis of sources, just parse them " |
| "(only applies to Python modules, might affect quality of stubs)", |
| ) |
| parser.add_argument( |
| "--include-private", |
| action="store_true", |
| help="generate stubs for objects and members considered private " |
| "(single leading underscore and no trailing underscores)", |
| ) |
| parser.add_argument( |
| "--export-less", |
| action="store_true", |
| help="don't implicitly export all names imported from other modules in the same package", |
| ) |
| parser.add_argument( |
| "--include-docstrings", |
| action="store_true", |
| help="include existing docstrings with the stubs", |
| ) |
| parser.add_argument("-v", "--verbose", action="store_true", help="show more verbose messages") |
| parser.add_argument("-q", "--quiet", action="store_true", help="show fewer messages") |
| parser.add_argument( |
| "--doc-dir", |
| metavar="PATH", |
| default="", |
| help="use .rst documentation in PATH (this may result in " |
| "better stubs in some cases; consider setting this to " |
| "DIR/Python-X.Y.Z/Doc/library)", |
| ) |
| parser.add_argument( |
| "--search-path", |
| metavar="PATH", |
| default="", |
| help="specify module search directories, separated by ':' " |
| "(currently only used if --no-import is given)", |
| ) |
| parser.add_argument( |
| "-o", |
| "--output", |
| metavar="PATH", |
| dest="output_dir", |
| default="out", |
| help="change the output directory [default: %(default)s]", |
| ) |
| parser.add_argument( |
| "-m", |
| "--module", |
| action="append", |
| metavar="MODULE", |
| dest="modules", |
| default=[], |
| help="generate stub for module; can repeat for more modules", |
| ) |
| parser.add_argument( |
| "-p", |
| "--package", |
| action="append", |
| metavar="PACKAGE", |
| dest="packages", |
| default=[], |
| help="generate stubs for package recursively; can be repeated", |
| ) |
| parser.add_argument( |
| metavar="files", |
| nargs="*", |
| dest="files", |
| help="generate stubs for given files or directories", |
| ) |
| |
| ns = parser.parse_args(args) |
| |
| pyversion = sys.version_info[:2] |
| ns.interpreter = sys.executable |
| |
| if ns.modules + ns.packages and ns.files: |
| parser.error("May only specify one of: modules/packages or files.") |
| if ns.quiet and ns.verbose: |
| parser.error("Cannot specify both quiet and verbose messages") |
| |
| # Create the output folder if it doesn't already exist. |
| if not os.path.exists(ns.output_dir): |
| os.makedirs(ns.output_dir) |
| |
| return Options( |
| pyversion=pyversion, |
| no_import=ns.no_import, |
| doc_dir=ns.doc_dir, |
| search_path=ns.search_path.split(":"), |
| interpreter=ns.interpreter, |
| ignore_errors=ns.ignore_errors, |
| parse_only=ns.parse_only, |
| include_private=ns.include_private, |
| output_dir=ns.output_dir, |
| modules=ns.modules, |
| packages=ns.packages, |
| files=ns.files, |
| verbose=ns.verbose, |
| quiet=ns.quiet, |
| export_less=ns.export_less, |
| include_docstrings=ns.include_docstrings, |
| ) |
| |
| |
| def main(args: list[str] | None = None) -> None: |
| mypy.util.check_python_version("stubgen") |
| # Make sure that the current directory is in sys.path so that |
| # stubgen can be run on packages in the current directory. |
| if not ("" in sys.path or "." in sys.path): |
| sys.path.insert(0, "") |
| |
| options = parse_options(sys.argv[1:] if args is None else args) |
| generate_stubs(options) |
| |
| |
| if __name__ == "__main__": |
| main() |