| """Utilities for mypy.stubgen, mypy.stubgenc, and mypy.stubdoc modules.""" |
| |
| from __future__ import annotations |
| |
| import os.path |
| import re |
| import sys |
| import traceback |
| from abc import abstractmethod |
| from collections import defaultdict |
| from collections.abc import Iterable, Iterator, Mapping |
| from contextlib import contextmanager |
| from typing import Final, overload |
| |
| from mypy_extensions import mypyc_attr |
| |
| import mypy.options |
| from mypy.modulefinder import ModuleNotFoundReason |
| from mypy.moduleinspect import InspectError, ModuleInspect |
| from mypy.nodes import PARAM_SPEC_KIND, TYPE_VAR_TUPLE_KIND, ClassDef, FuncDef, TypeAliasStmt |
| from mypy.stubdoc import ArgSig, FunctionSig |
| from mypy.types import ( |
| AnyType, |
| NoneType, |
| Type, |
| TypeList, |
| TypeStrVisitor, |
| UnboundType, |
| UnionType, |
| UnpackType, |
| ) |
| |
# Fully qualified names of modules that are never imported for introspection, because
# importing them may fail or may have side effects.
NOT_IMPORTABLE_MODULES = ()

# Maps legacy typing constructs to the builtin that replaces them. The same aliases
# are covered for both `typing` and `typing_extensions`, so the table is generated
# from a single list of (alias, builtin) pairs.
TYPING_BUILTIN_REPLACEMENTS: Final = {
    f"{typing_module}.{alias}": f"builtins.{builtin}"
    for typing_module in ("typing", "typing_extensions")
    for alias, builtin in (
        ("Text", "str"),
        ("Tuple", "tuple"),
        ("List", "list"),
        ("Dict", "dict"),
        ("Set", "set"),
        ("FrozenSet", "frozenset"),
        ("Type", "type"),
    )
}
| |
| |
class CantImport(Exception):
    """Signals that a module could not be imported.

    Attributes:
        module: the module name passed by the raiser
        message: the accompanying error text (may be an empty string)
    """

    def __init__(self, module: str, message: str) -> None:
        self.module, self.message = module, message
| |
| |
def walk_packages(
    inspect: ModuleInspect, packages: list[str], verbose: bool = False
) -> Iterator[str]:
    """Yield the names of the given packages and everything nested inside them.

    Discovery relies on runtime imports (performed in another process), which lets it
    find both Python and C modules. A Python package exposes __path__, so its whole
    content (subpackages and modules) can be obtained via pkgutil.walk_packages().
    Packages from C extensions lack that attribute, so for those we recurse ourselves
    over the modules imported in the package that have matching names.

    Packages listed in NOT_IMPORTABLE_MODULES, and packages whose inspection fails,
    are reported on stdout and skipped.
    """
    for package_name in packages:
        if package_name in NOT_IMPORTABLE_MODULES:
            print(f"{package_name}: Skipped (blacklisted)")
            continue
        if verbose:
            print(f"Trying to import {package_name!r} for runtime introspection")
        try:
            prop = inspect.get_package_properties(package_name)
        except InspectError:
            if verbose:
                sys.stderr.write(traceback.format_exc())
            report_missing(package_name)
            continue
        yield prop.name
        if not prop.is_c_module:
            # The reported subpackages are yielded as they are.
            yield from prop.subpackages
            continue
        # C module: descend into each subpackage ourselves.
        yield from walk_packages(inspect, prop.subpackages, verbose)
| |
| |
def find_module_path_using_sys_path(module: str, sys_path: list[str]) -> str | None:
    """Locate the source file of `module` by probing the directories in `sys_path`.

    For every directory, in order, a plain module file (`a/b.py`) is tried before a
    package initializer (`a/b/__init__.py`). The first existing file wins; None is
    returned when nothing matches.
    """
    module_as_path = module.replace(".", "/")
    suffixes = (module_as_path + ".py", os.path.join(module_as_path, "__init__.py"))
    candidates = (os.path.join(entry, suffix) for entry in sys_path for suffix in suffixes)
    return next((found for found in candidates if os.path.isfile(found)), None)
| |
| |
def find_module_path_and_all_py3(
    inspect: ModuleInspect, module: str, verbose: bool
) -> tuple[str | None, list[str] | None] | None:
    """Locate a Python 3 module and work out its __all__.

    Returns:
        None for a C or pyc-only module, otherwise a (module_path, __all__) pair.
        When runtime inspection fails but the file is found through sys.path, the
        __all__ part of the pair is None.

    Raises:
        CantImport: if the module is listed in NOT_IMPORTABLE_MODULES, or if the
            import failed and the module could not be found via sys.path either.
    """
    if module in NOT_IMPORTABLE_MODULES:
        raise CantImport(module, "")

    # TODO: Support custom interpreters.
    if verbose:
        print(f"Trying to import {module!r} for runtime introspection")
    try:
        props = inspect.get_package_properties(module)
    except InspectError as err:
        # Inspection failed; try to find the source file through sys.path instead.
        fallback_path = find_module_path_using_sys_path(module, sys.path)
        if fallback_path is not None:
            return fallback_path, None
        raise CantImport(module, str(err)) from err
    return None if props.is_c_module else (props.file, props.all)
| |
| |
@contextmanager
def generate_guarded(
    mod: str, target: str, ignore_errors: bool = True, verbose: bool = False
) -> Iterator[None]:
    """Context manager that ignores or propagates errors raised while generating a stub.

    With `ignore_errors` set, an exception raised in the managed body is swallowed and
    a failure notice for `mod` is written to stderr; otherwise it is re-raised.
    With `verbose` set, progress is printed before the body runs and success
    (naming `target`) is printed after it completes without error.
    """
    if verbose:
        print(f"Processing {mod}")
    try:
        yield
    except Exception as err:
        if not ignore_errors:
            raise err
        # --ignore-errors was passed
        print("Stub generation failed for", mod, file=sys.stderr)
        return
    if verbose:
        print(f"Created {target}")
| |
| |
def report_missing(mod: str, message: str | None = "", traceback: str = "") -> None:
    """Print a notice that `mod` could not be imported and is being skipped.

    Args:
        mod: name of the module that failed to import.
        message: optional error description; appended to the notice when non-empty.
            None is treated like an empty string.
        traceback: not used by this function; kept so existing callers keep working.
    """
    # Build the suffix separately so that a None message never leaks into the output
    # as the literal text "None".
    suffix = f" with error: {message}" if message else ""
    print(f"{mod}: Failed to import, skipping{suffix}")
| |
| |
def fail_missing(mod: str, reason: ModuleNotFoundReason) -> None:
    """Abort with a SystemExit explaining why module `mod` could not be found.

    The exit message carries a hint that depends on `reason`; reasons other than
    NOT_FOUND and FOUND_WITHOUT_TYPE_HINTS are reported verbatim as unknown.
    """
    if reason is ModuleNotFoundReason.FOUND_WITHOUT_TYPE_HINTS:
        hint = "(module likely exists, but is not PEP 561 compatible)"
    elif reason is ModuleNotFoundReason.NOT_FOUND:
        hint = "(consider using --search-path)"
    else:
        hint = f"(unknown reason '{reason}')"
    raise SystemExit(f"Can't find module '{mod}' {hint}")
| |
| |
@overload
def remove_misplaced_type_comments(source: bytes) -> bytes: ...


@overload
def remove_misplaced_type_comments(source: str) -> str: ...


def remove_misplaced_type_comments(source: str | bytes) -> str | bytes:
    """Strip comments that the parser could mistake for misplaced type comments.

    Ordinary comments sometimes look like type comments in the wrong place, which
    causes blocking parse errors, so they are removed up front. The result has the
    same type (str or bytes) as `source`.
    """
    # latin1 maps every byte to exactly one character, so bytes input survives the
    # decode/encode round trip unchanged.
    text = source.decode("latin1") if isinstance(source, bytes) else source

    # Applied in this order; each entry is (pattern, replacement).
    substitutions = (
        # A variable-style type comment alone on a line, which will often generate
        # a parse error (unless it's # type: ignore).
        (r'^[ \t]*# +type: +["\'a-zA-Z_].*$', ""),
        # A function-style type comment following a docstring, which will result in
        # a parse error.
        (r'""" *\n[ \t\n]*# +type: +\(.*$', '"""\n'),
        (r"''' *\n[ \t\n]*# +type: +\(.*$", "'''\n"),
        # A badly formed function type comment.
        (r"^[ \t]*# +type: +\([^()]+(\)[ \t]*)?$", ""),
    )
    for pattern, replacement in substitutions:
        text = re.sub(pattern, replacement, text, flags=re.MULTILINE)

    return text.encode("latin1") if isinstance(source, bytes) else text
| |
| |
def common_dir_prefix(paths: list[str]) -> str:
    """Return the deepest directory that contains every file in `paths`.

    Paths are normalized with os.path.normpath() before comparison.

    Args:
        paths: file paths; may be empty.

    Returns:
        The common directory, or "." when `paths` is empty, when the common directory
        is the current directory, or when the paths have no directory in common
        (for example a mix of absolute and relative paths).
    """
    if not paths:
        return "."
    cur = os.path.dirname(os.path.normpath(paths[0]))
    for path in paths[1:]:
        if not cur:
            # Already reduced to "no common directory"; no later path can improve it.
            break
        candidate = os.path.dirname(os.path.normpath(path))
        # Walk up from the file's directory until we reach an ancestor of `cur`.
        while True:
            if not candidate:
                # Ran out of relative components: nothing below "." is shared.
                cur = ""
                break
            # A filesystem root already ends with a separator; don't double it.
            prefix = candidate if candidate.endswith(os.sep) else candidate + os.sep
            if (cur + os.sep).startswith(prefix):
                cur = candidate
                break
            parent = os.path.dirname(candidate)
            if parent == candidate:
                # dirname() reached a fixpoint (root or drive) without matching, so
                # there is no common directory. Stop instead of looping forever.
                cur = ""
                break
            candidate = parent
    return cur or "."
| |
| |
class AnnotationPrinter(TypeStrVisitor):
    """Visitor that renders the annotations already present in a file as strings.

    Compared to TypeStrVisitor, it mainly treats unbound types better.

    Notes:
    * Imports needed by the rendered annotations are not emitted here; this visitor
      only records them with the stub generator's ImportTracker.
    * Every kind of type can be printed, but the result is the same string that
      reveal_type() would show, so it is not always valid source (notably for
      callable types).
    * Instance types are printed with their fully qualified names.
    """

    # TODO: Generate valid string representation for callable types.
    # TODO: Use short names for Instances.
    def __init__(
        self,
        stubgen: BaseStubGenerator,
        known_modules: list[str] | None = None,
        local_modules: list[str] | None = None,
    ) -> None:
        """Create a printer bound to `stubgen`.

        `known_modules`, when given, enables module lookup for dotted names.
        `local_modules` (default: just "builtins") are modules whose prefix is
        stripped from names and which are never imported.
        """
        super().__init__(options=mypy.options.Options())
        self.stubgen = stubgen
        self.known_modules = known_modules
        self.local_modules = local_modules or ["builtins"]

    def visit_any(self, t: AnyType) -> str:
        """Render an Any type and mark the rendered name as required."""
        rendered = super().visit_any(t)
        self.stubgen.import_tracker.require_name(rendered)
        return rendered

    def visit_unbound_type(self, t: UnboundType) -> str:
        """Render an unbound type, recording whatever imports the result needs.

        Union/Optional from typing become `X | Y` syntax, legacy typing aliases are
        replaced by their builtin equivalents, and type arguments are appended in
        square brackets.
        """
        rendered = t.name
        fullname = self.stubgen.resolve_name(rendered)
        if fullname == "typing.Union":
            return " | ".join(item.accept(self) for item in t.args)
        if fullname == "typing.Optional":
            if len(t.args) != 1:
                # Optional with zero or several arguments cannot be expressed.
                return self.stubgen.add_name("_typeshed.Incomplete")
            return f"{t.args[0].accept(self)} | None"
        if fullname in TYPING_BUILTIN_REPLACEMENTS:
            rendered = self.stubgen.add_name(TYPING_BUILTIN_REPLACEMENTS[fullname], require=True)

        if self.known_modules is not None and "." in rendered:
            # See if this object is from any of the modules that we're currently
            # processing. Known modules are reverse sorted so that subpackages come
            # before parents: e.g. "foo.bar" before "foo".
            search_order = self.local_modules + sorted(self.known_modules, reverse=True)
            owner = next(
                (mod for mod in search_order if rendered.startswith(mod + ".")), None
            )
            if owner is None:
                # No known module matched: everything before the last dot is the module.
                owner = rendered[: rendered.rindex(".")]
            elif owner in self.local_modules:
                rendered = rendered[len(owner) + 1 :]
            if owner not in self.local_modules:
                self.stubgen.import_tracker.add_import(owner, require=True)
        elif rendered == "NoneType":
            # When called without analysis all types are unbound, so this won't hit
            # visit_none_type().
            rendered = "None"
        else:
            self.stubgen.import_tracker.require_name(rendered)

        if t.args:
            rendered += f"[{self.args_str(t.args)}]"
        elif t.empty_tuple_index:
            rendered += "[()]"
        return rendered

    def visit_none_type(self, t: NoneType) -> str:
        """Render the None type."""
        return "None"

    def visit_type_list(self, t: TypeList) -> str:
        """Render a type list as a bracketed, comma-separated list."""
        return f"[{self.list_str(t.items)}]"

    def visit_union_type(self, t: UnionType) -> str:
        """Render a union with the `X | Y` syntax."""
        return " | ".join(item.accept(self) for item in t.items)

    def visit_unpack_type(self, t: UnpackType) -> str:
        """Render an unpack; the star form is used when targeting Python 3.11+."""
        if self.options.python_version < (3, 11):
            return super().visit_unpack_type(t)
        return f"*{t.type.accept(self)}"

    def args_str(self, args: Iterable[Type]) -> str:
        """Render each argument and join the results with commas.

        Unlike list_str, arguments that were originally string literals keep
        their quotes.
        """
        quoted_fallbacks = ("builtins.bytes", "builtins.str")
        rendered_args = []
        for arg in args:
            text = arg.accept(self)
            if isinstance(arg, UnboundType) and arg.original_str_fallback in quoted_fallbacks:
                text = f"'{text}'"
            rendered_args.append(text)
        return ", ".join(rendered_args)
| |
| |
class ClassInfo:
    """Plain record describing a class.

    Attributes:
        name: the class name
        self_var: name associated with the class's self/cls argument (presumably the
            first parameter of its methods; confirm against callers)
        docstring: the class docstring, if any
        cls: the runtime class object, if available
        parent: info of the enclosing class when nested, otherwise None
    """

    def __init__(
        self,
        name: str,
        self_var: str,
        docstring: str | None = None,
        cls: type | None = None,
        parent: ClassInfo | None = None,
    ) -> None:
        self.name, self.self_var = name, self_var
        self.docstring, self.cls = docstring, cls
        self.parent = parent
| |
| |
class FunctionContext:
    """Describes a function being processed: where it lives and how it is declared.

    Attributes:
        module_name: name of the module containing the function
        name: the function's own name
        docstring: the function docstring, if any
        is_abstract: whether the function is abstract
        class_info: info of the enclosing class for methods, otherwise None
    """

    def __init__(
        self,
        module_name: str,
        name: str,
        docstring: str | None = None,
        is_abstract: bool = False,
        class_info: ClassInfo | None = None,
    ) -> None:
        self.module_name, self.name = module_name, name
        self.docstring = docstring
        self.is_abstract = is_abstract
        self.class_info = class_info
        # Cache for the `fullname` property; filled in on first access.
        self._fullname: str | None = None

    @property
    def fullname(self) -> str:
        """Dotted name: module, then enclosing classes (outermost first), then the name.

        The value is computed once and cached.
        """
        if self._fullname is None:
            scope = self.module_name
            if self.class_info:
                # Collect the class chain innermost-to-outermost, building it in
                # outermost-first order.
                class_names: list[str] = []
                info: ClassInfo | None = self.class_info
                while info is not None:
                    class_names.insert(0, info.name)
                    info = info.parent
                scope = f"{scope}.{'.'.join(class_names)}"
            self._fullname = f"{scope}.{self.name}"
        return self._fullname
| |
| |
def infer_method_ret_type(name: str) -> str | None:
    """Infer the return type of a known special (dunder) method.

    Returns the type as a string, or None when `name` is not a dunder name or is not
    one of the recognized special methods.
    """
    if not (name.startswith("__") and name.endswith("__")):
        return None
    stem = name[2:-2]
    # Conversion methods return the type they are named after.
    if stem in ("float", "bool", "bytes", "int", "complex", "str"):
        return stem
    known_returns = (
        # Note: __eq__ and co may return arbitrary types, but bool is good enough
        # for stubgen.
        (("eq", "ne", "lt", "le", "gt", "ge", "contains"), "bool"),
        (("len", "length_hint", "index", "hash", "sizeof", "trunc", "floor", "ceil"), "int"),
        (("format", "repr"), "str"),
        (("init", "setitem", "del", "delitem"), "None"),
    )
    for stems, ret_type in known_returns:
        if stem in stems:
            return ret_type
    return None
| |
| |
def infer_method_arg_types(
    name: str, self_var: str = "self", arg_names: list[str] | None = None
) -> list[ArgSig] | None:
    """Infer argument types for known special methods.

    Only `__exit__` is currently recognized.

    Args:
        name: method name.
        self_var: name of the receiver argument; used for the first returned ArgSig.
        arg_names: names of the method's arguments, with or without the leading
            receiver. When None, the default `__exit__` names are used.

    Returns:
        The full argument list (receiver first, without a type) when the method is
        `__exit__` and has exactly three arguments besides the receiver; otherwise None.
    """
    if not (name.startswith("__") and name.endswith("__")):
        return None

    # Drop the receiver so that only the real arguments are counted. The receiver
    # may be spelled "self" or whatever the caller passed as `self_var` (e.g. "cls").
    if arg_names and arg_names[0] in ("self", self_var):
        arg_names = arg_names[1:]

    if name[2:-2] != "exit":
        return None
    if arg_names is None:
        arg_names = ["type", "value", "traceback"]
    if len(arg_names) != 3:
        return None

    arg_types = [
        "type[BaseException] | None",
        "BaseException | None",
        "types.TracebackType | None",
    ]
    typed_args = [
        ArgSig(name=arg_name, type=arg_type) for arg_name, arg_type in zip(arg_names, arg_types)
    ]
    return [ArgSig(name=self_var)] + typed_args
| |
| |
@mypyc_attr(allow_interpreted_subclasses=True)
class SignatureGenerator:
    """Abstract base class for objects that supply FunctionSigs for functions."""

    def remove_self_type(
        self, inferred: list[FunctionSig] | None, self_var: str
    ) -> list[FunctionSig] | None:
        """Clear the type annotation of the leading self/cls argument.

        Signatures are modified in place when their first argument is named
        `self_var`; `inferred` itself is returned (including when it is None or empty).
        """
        for sig in inferred or ():
            params = sig.args
            if params and params[0].name == self_var:
                params[0].type = None
        return inferred

    @abstractmethod
    def get_function_sig(
        self, default_sig: FunctionSig, ctx: FunctionContext
    ) -> list[FunctionSig] | None:
        """Return the signatures found for the given function.

        Return None when no signature can be found. The default_sig is used when
        every SignatureGenerator registered with the stub generator returns None.
        """

    @abstractmethod
    def get_property_type(self, default_type: str | None, ctx: FunctionContext) -> str | None:
        """Return the type of the given property."""
| |
| |
class ImportTracker:
    """Keeps track of the imports that a generated stub needs."""

    def __init__(self) -> None:
        # module_for['foo'] is the module that 'foo' was imported from, or None when
        #   'foo' is itself a directly imported module;
        # direct_imports['foo'] is the module path that was used when the name 'foo'
        #   entered the namespace;
        # reverse_alias['foo'] is the original name of 'foo' when it was imported
        #   under an alias. Examples:
        #     'from pkg import mod'      ==> module_for['mod'] == 'pkg'
        #     'from pkg import mod as m' ==> module_for['m'] == 'pkg'
        #                                ==> reverse_alias['m'] == 'mod'
        #     'import pkg.mod as m'      ==> module_for['m'] == None
        #                                ==> reverse_alias['m'] == 'pkg.mod'
        #     'import pkg.mod'           ==> module_for['pkg'] == None
        #                                ==> module_for['pkg.mod'] == None
        #                                ==> direct_imports['pkg'] == 'pkg.mod'
        #                                ==> direct_imports['pkg.mod'] == 'pkg.mod'
        self.module_for: dict[str, str | None] = {}
        self.direct_imports: dict[str, str] = {}
        self.reverse_alias: dict[str, str] = {}

        # Names that are actually used in a type annotation.
        self.required_names: set[str] = set()

        # Names that should be reexported if they come from another module.
        self.reexports: set[str] = set()

    def add_import_from(
        self, module: str, names: list[tuple[str, str | None]], require: bool = False
    ) -> None:
        """Record 'from {module} import {name} [as {alias}]' for each (name, alias) pair.

        With `require` set, every bound name is also marked as required.
        """
        for name, alias in names:
            bound_name = alias or name
            self.module_for[bound_name] = module
            if alias:
                # 'from {module} import {name} as {alias}'
                self.reverse_alias[alias] = name
            else:
                # 'from {module} import {name}': forget any alias held by this name
                self.reverse_alias.pop(name, None)
            # require_name() consults direct_imports, so it must run before the
            # entry for this name is dropped.
            if require:
                self.require_name(bound_name)
            self.direct_imports.pop(bound_name, None)

    def add_import(self, module: str, alias: str | None = None, require: bool = False) -> None:
        """Record 'import {module}' or 'import {module} as {alias}'.

        With `require` set, the bound name is also marked as required.
        """
        if alias:
            # 'import {module} as {alias}'
            assert "." not in alias  # invalid syntax
            self.module_for[alias] = None
            self.reverse_alias[alias] = module
            if require:
                self.required_names.add(alias)
            return

        # 'import {module}'
        if require:
            self.required_names.add(module)
        # Register the module itself and each of its parent packages.
        package = module
        while package:
            self.module_for[package] = None
            self.direct_imports[package] = module
            self.reverse_alias.pop(package, None)
            package = package.rpartition(".")[0]

    def require_name(self, name: str) -> None:
        """Mark `name` as required.

        For a dotted name, trailing components are dropped until the remainder is a
        directly imported module or contains no dot.
        """
        candidate = name
        while "." in candidate and candidate not in self.direct_imports:
            candidate = candidate.rpartition(".")[0]
        self.required_names.add(candidate)

    def reexport(self, name: str) -> None:
        """Mark a non-qualified name as needed in __all__.

        If the name comes from a module, it is then imported with an alias even
        when the alias equals the name.
        """
        self.require_name(name)
        self.reexports.add(name)

    def import_lines(self) -> list[str]:
        """Return the required import lines, as strings of Python code.

        A module only shows up in the output if an identifier is both 'required'
        via require_name() and 'imported' via add_import_from() or add_import().
        """

        def sort_key(name: str) -> tuple[str, str]:
            # Aliased names are ordered by the name they were imported as originally.
            if name in self.reverse_alias:
                return (self.reverse_alias[name], name)
            return (name, "")

        plain_imports: list[str] = []
        # Names imported from one module are merged into a single statement, so they
        # are grouped by module path first. Entries may take the form
        # 'original as alias'.
        from_imports: Mapping[str, list[str]] = defaultdict(list)

        for name in sorted(self.required_names, key=sort_key):
            # Names never seen in an import statement are ignored.
            if name not in self.module_for:
                continue

            source_module = self.module_for[name]
            if source_module is None:
                # The name came from 'import ...': its line can be emitted right away.
                if name in self.reverse_alias:
                    plain_imports.append(f"import {self.reverse_alias[name]} as {name}\n")
                elif name in self.reexports:
                    assert "." not in name  # Because reexports only has nonqualified names
                    plain_imports.append(f"import {name} as {name}\n")
                else:
                    plain_imports.append(f"import {name}\n")
                continue

            # The name came from 'from ... import ...': collect it for later.
            if name in self.reverse_alias:
                spelling = f"{self.reverse_alias[name]} as {name}"
            elif name in self.reexports:
                spelling = f"{name} as {name}"
            else:
                spelling = name
            from_imports[source_module].append(spelling)

        # The collected 'from ... import ...' statements follow the plain imports.
        grouped_imports = [
            f"from {module} import {', '.join(sorted(names))}\n"
            for module, names in sorted(from_imports.items())
        ]
        return plain_imports + grouped_imports
| |
| |
@mypyc_attr(allow_interpreted_subclasses=True)
class BaseStubGenerator:
    """State and helpers for generating the text of a stub file.

    The class accumulates output text, tracks indentation and top-level names,
    records required imports through an ImportTracker, and decides which names are
    private or re-exported.
    """

    # These names should be omitted from generated stubs.
    IGNORED_DUNDERS: Final = {
        "__all__",
        "__author__",
        "__about__",
        "__copyright__",
        "__email__",
        "__license__",
        "__summary__",
        "__title__",
        "__uri__",
        "__str__",
        "__repr__",
        "__getstate__",
        "__setstate__",
        "__slots__",
        "__builtins__",
        "__cached__",
        "__file__",
        "__name__",
        "__package__",
        "__path__",
        "__spec__",
        "__loader__",
    }
    TYPING_MODULE_NAMES: Final = ("typing", "typing_extensions")
    # Special-cased names that are implicitly exported from the stub (from m import y as y).
    EXTRA_EXPORTED: Final = {
        "pyasn1_modules.rfc2437.univ",
        "pyasn1_modules.rfc2459.char",
        "pyasn1_modules.rfc2459.univ",
    }

    def __init__(
        self,
        _all_: list[str] | None = None,
        include_private: bool = False,
        export_less: bool = False,
        include_docstrings: bool = False,
    ) -> None:
        """Set up an empty generator.

        Args:
            _all_: best known value of the module's __all__, if any.
            include_private: treat private names as public (see is_private_name).
            export_less: disable implicit exports of package-internal imports.
            include_docstrings: pass docstrings through to formatted signatures.
        """
        # Best known value of __all__.
        self._all_ = _all_
        self._include_private = include_private
        self._include_docstrings = include_docstrings
        # Disable implicit exports of package-internal imports?
        self.export_less = export_less
        self._import_lines: list[str] = []
        self._output: list[str] = []
        # Current indent level (indent is hardcoded to 4 spaces).
        self._indent = ""
        self._toplevel_names: list[str] = []
        self.import_tracker = ImportTracker()
        # Top-level members
        self.defined_names: set[str] = set()
        self.sig_generators = self.get_sig_generators()
        # populated by visit_mypy_file
        self.module_name: str = ""
        # These are "soft" imports for objects which might appear in annotations but not have
        # a corresponding import statement.
        self.known_imports = {
            "_typeshed": ["Incomplete"],
            "typing": ["Any", "TypeVar", "NamedTuple", "TypedDict"],
            "collections.abc": ["Generator"],
            "typing_extensions": ["ParamSpec", "TypeVarTuple"],
        }

    def get_sig_generators(self) -> list[SignatureGenerator]:
        """Return the signature generators to use; none by default."""
        return []

    def resolve_name(self, name: str) -> str:
        """Return the full name resolving imports and import aliases."""
        if "." not in name:
            real_module = self.import_tracker.module_for.get(name)
            real_short = self.import_tracker.reverse_alias.get(name, name)
            if real_module is None and real_short not in self.defined_names:
                real_module = "builtins"  # not imported and not defined, must be a builtin
        else:
            # Only the first component can be an import alias.
            name_module, real_short = name.split(".", 1)
            real_module = self.import_tracker.reverse_alias.get(name_module, name_module)
        resolved_name = real_short if real_module is None else f"{real_module}.{real_short}"
        return resolved_name

    def add_name(self, fullname: str, require: bool = True) -> str:
        """Add a name to be imported and return the name reference.

        The import will be internal to the stub (i.e don't reexport). If the short
        name clashes with a defined name, underscores are prepended until the alias
        is free, and the alias is returned instead.
        """
        module, name = fullname.rsplit(".", 1)
        alias = "_" + name if name in self.defined_names else None
        while alias in self.defined_names:
            alias = "_" + alias
        if module != "builtins" or alias:  # don't import from builtins unless needed
            self.import_tracker.add_import_from(module, [(name, alias)], require=require)
        return alias or name

    def add_import_line(self, line: str) -> None:
        """Add a line of text to the import section, unless it's already there."""
        if line not in self._import_lines:
            self._import_lines.append(line)

    def get_imports(self) -> str:
        """Return the import statements for the stub.

        Explicitly added import lines come first, followed by the lines produced by
        the import tracker.
        """
        imports = ""
        if self._import_lines:
            imports += "".join(self._import_lines)
        imports += "".join(self.import_tracker.import_lines())
        return imports

    def output(self) -> str:
        """Return the text for the stub.

        Non-empty sections (imports, __all__, body) are joined with a newline.
        """
        pieces: list[str] = []
        if imports := self.get_imports():
            pieces.append(imports)
        if dunder_all := self.get_dunder_all():
            pieces.append(dunder_all)
        if self._output:
            pieces.append("".join(self._output))
        return "\n".join(pieces)

    def get_dunder_all(self) -> str:
        """Return the __all__ list for the stub, or "" when there is none."""
        if self._all_:
            # Note we emit all names in the runtime __all__ here, even if they
            # don't actually exist. If that happens, the runtime has a bug, and
            # it's not obvious what the correct behavior should be. We choose
            # to reflect the runtime __all__ as closely as possible.
            return f"__all__ = {self._all_!r}\n"
        return ""

    def add(self, string: str) -> None:
        """Add text to generated stub."""
        self._output.append(string)

    def is_top_level(self) -> bool:
        """Are we processing the top level of a file?"""
        return self._indent == ""

    def indent(self) -> None:
        """Add one level of indentation."""
        # One level is four spaces; this must match dedent(), which strips four
        # characters per level.
        self._indent += "    "

    def dedent(self) -> None:
        """Remove one level of indentation."""
        self._indent = self._indent[:-4]

    def record_name(self, name: str) -> None:
        """Mark a name as defined.

        This only does anything if at the top level of a module.
        """
        if self.is_top_level():
            self._toplevel_names.append(name)

    def is_recorded_name(self, name: str) -> bool:
        """Has this name been recorded previously?"""
        return self.is_top_level() and name in self._toplevel_names

    def set_defined_names(self, defined_names: set[str]) -> None:
        """Store the names defined in the module and register default imports."""
        self.defined_names = defined_names
        # Names in __all__ are required
        for name in self._all_ or ():
            self.import_tracker.reexport(name)

        for pkg, imports in self.known_imports.items():
            for t in imports:
                # require=False means that the import won't be added unless require_name() is called
                # for the object during generation.
                self.add_name(f"{pkg}.{t}", require=False)

    def check_undefined_names(self) -> None:
        """Append a comment listing names in __all__ that were never recorded."""
        undefined_names = [name for name in self._all_ or [] if name not in self._toplevel_names]
        if undefined_names:
            if self._output:
                self.add("\n")
            self.add("# Names in __all__ with no definition:\n")
            for name in sorted(undefined_names):
                self.add(f"# {name}\n")

    def get_signatures(
        self,
        default_signature: FunctionSig,
        sig_generators: list[SignatureGenerator],
        func_ctx: FunctionContext,
    ) -> list[FunctionSig]:
        """Return the signatures from the first generator that yields any.

        Falls back to a list holding only `default_signature`.
        """
        for sig_gen in sig_generators:
            inferred = sig_gen.get_function_sig(default_signature, func_ctx)
            if inferred:
                return inferred

        return [default_signature]

    def get_property_type(
        self,
        default_type: str | None,
        sig_generators: list[SignatureGenerator],
        func_ctx: FunctionContext,
    ) -> str | None:
        """Return the property type from the first generator that yields one.

        Falls back to `default_type`.
        """
        for sig_gen in sig_generators:
            inferred = sig_gen.get_property_type(default_type, func_ctx)
            if inferred:
                return inferred

        return default_type

    def format_func_def(
        self,
        sigs: list[FunctionSig],
        is_coroutine: bool = False,
        decorators: list[str] | None = None,
        docstring: str | None = None,
    ) -> list[str]:
        """Format each signature, preceded by the decorators, at the current indent.

        The docstring is only passed on when docstrings were requested at construction.
        """
        lines: list[str] = []
        if decorators is None:
            decorators = []

        for signature in sigs:
            # dump decorators, just before "def ..."
            for deco in decorators:
                lines.append(f"{self._indent}{deco}")

            lines.append(
                signature.format_sig(
                    indent=self._indent,
                    is_async=is_coroutine,
                    docstring=docstring if self._include_docstrings else None,
                )
            )
        return lines

    def format_type_args(self, o: TypeAliasStmt | FuncDef | ClassDef) -> str:
        """Render the type parameter list of `o` in brackets, or "" if it has none."""
        if not o.type_args:
            return ""
        p = AnnotationPrinter(self)
        type_args_list: list[str] = []
        for type_arg in o.type_args:
            if type_arg.kind == PARAM_SPEC_KIND:
                prefix = "**"
            elif type_arg.kind == TYPE_VAR_TUPLE_KIND:
                prefix = "*"
            else:
                prefix = ""
            if type_arg.upper_bound:
                bound_or_values = f": {type_arg.upper_bound.accept(p)}"
            elif type_arg.values:
                bound_or_values = f": ({', '.join(v.accept(p) for v in type_arg.values)})"
            else:
                bound_or_values = ""
            if type_arg.default:
                default = f" = {type_arg.default.accept(p)}"
            else:
                default = ""
            type_args_list.append(f"{prefix}{type_arg.name}{bound_or_values}{default}")
        return "[" + ", ".join(type_args_list) + "]"

    def print_annotation(
        self,
        t: Type,
        known_modules: list[str] | None = None,
        local_modules: list[str] | None = None,
    ) -> str:
        """Render type `t` as an annotation string using an AnnotationPrinter."""
        printer = AnnotationPrinter(self, known_modules, local_modules)
        return t.accept(printer)

    def is_not_in_all(self, name: str) -> bool:
        """Is `name` a non-private top-level name that a non-empty __all__ leaves out?"""
        if self.is_private_name(name):
            return False
        if self._all_:
            return self.is_top_level() and name not in self._all_
        return False

    def is_private_name(self, name: str, fullname: str | None = None) -> bool:
        """Should `name` be treated as private?"""
        if "__mypy-" in name:
            return True  # Never include mypy generated symbols
        if self._include_private:
            return False
        if fullname in self.EXTRA_EXPORTED:
            return False
        if name == "_":
            return False
        if not name.startswith("_"):
            return False
        if self._all_ and name in self._all_:
            return False
        if name.startswith("__") and name.endswith("__"):
            # Dunder names are public unless explicitly ignored.
            return name in self.IGNORED_DUNDERS
        return True

    def should_reexport(self, name: str, full_module: str, name_is_alias: bool) -> bool:
        """Should `name`, imported from `full_module`, be re-exported from the stub?"""
        if (
            not name_is_alias
            and self.module_name
            and (self.module_name + "." + name) in self.EXTRA_EXPORTED
        ):
            # Special case certain names that should be exported, against our general rules.
            return True
        if name_is_alias:
            return False
        if self.export_less:
            return False
        if not self.module_name:
            return False
        is_private = self.is_private_name(name, full_module + "." + name)
        if is_private:
            return False
        top_level = full_module.split(".")[0]
        self_top_level = self.module_name.split(".", 1)[0]
        if top_level not in (self_top_level, "_" + self_top_level):
            # Only imports from the same package are exported implicitly, since for
            # those we can't reliably tell whether they are part of the public API.
            return False
        if self._all_:
            return name in self._all_
        return True