Add type annotations
diff --git a/autoflake.py b/autoflake.py
index 7c2ca2c..2d1066a 100755
--- a/autoflake.py
+++ b/autoflake.py
@@ -20,6 +20,8 @@
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
"""Removes unused imports and unused variables as reported by pyflakes."""
+from __future__ import annotations
+
import ast
import collections
import difflib
@@ -34,6 +36,14 @@
import sys
import sysconfig
import tokenize
+from typing import Any
+from typing import Callable
+from typing import cast
+from typing import IO
+from typing import Iterable
+from typing import Mapping
+from typing import MutableMapping
+from typing import Sequence
import pyflakes.api
import pyflakes.messages
@@ -54,7 +64,7 @@
MAX_PYTHON_FILE_DETECTION_BYTES = 1024
-def standard_paths():
+def standard_paths() -> Iterable[str]:
"""Yield paths to standard modules."""
paths = sysconfig.get_paths()
path_names = ("stdlib", "platstdlib")
@@ -71,7 +81,7 @@
yield from os.listdir(dynload_path)
-def standard_package_names():
+def standard_package_names() -> Iterable[str]:
"""Yield standard module names."""
for name in standard_paths():
if name.startswith("_") or "-" in name:
@@ -107,32 +117,40 @@
)
-def unused_import_line_numbers(messages):
+def unused_import_line_numbers(
+ messages: Iterable[pyflakes.messages.Message],
+) -> Iterable[int]:
"""Yield line numbers of unused imports."""
for message in messages:
if isinstance(message, pyflakes.messages.UnusedImport):
yield message.lineno
-def unused_import_module_name(messages):
+def unused_import_module_name(
+ messages: Iterable[pyflakes.messages.Message],
+) -> Iterable[tuple[int, str]]:
"""Yield line number and module name of unused imports."""
- pattern = r"\'(.+?)\'"
+ pattern = re.compile(r"\'(.+?)\'")
for message in messages:
if isinstance(message, pyflakes.messages.UnusedImport):
- module_name = re.search(pattern, str(message))
+ module_name = pattern.search(str(message))
if module_name:
module_name = module_name.group()[1:-1]
yield (message.lineno, module_name)
-def star_import_used_line_numbers(messages):
+def star_import_used_line_numbers(
+ messages: Iterable[pyflakes.messages.Message],
+) -> Iterable[int]:
"""Yield line number of star import usage."""
for message in messages:
if isinstance(message, pyflakes.messages.ImportStarUsed):
yield message.lineno
-def star_import_usage_undefined_name(messages):
+def star_import_usage_undefined_name(
+ messages: Iterable[pyflakes.messages.Message],
+) -> Iterable[tuple[int, str, str]]:
"""Yield line number, undefined name, and its possible origin module."""
for message in messages:
if isinstance(message, pyflakes.messages.ImportStarUsage):
@@ -141,14 +159,19 @@
yield (message.lineno, undefined_name, module_name)
-def unused_variable_line_numbers(messages):
+def unused_variable_line_numbers(
+ messages: Iterable[pyflakes.messages.Message],
+) -> Iterable[int]:
"""Yield line numbers of unused variables."""
for message in messages:
if isinstance(message, pyflakes.messages.UnusedVariable):
yield message.lineno
-def duplicate_key_line_numbers(messages, source):
+def duplicate_key_line_numbers(
+ messages: Iterable[pyflakes.messages.Message],
+ source: str,
+) -> Iterable[int]:
"""Yield line numbers of duplicate keys."""
messages = [
message
@@ -178,15 +201,17 @@
yield message.lineno
-def create_key_to_messages_dict(messages):
+def create_key_to_messages_dict(
+ messages: Iterable[pyflakes.messages.MultiValueRepeatedKeyLiteral],
+) -> Mapping[Any, Iterable[pyflakes.messages.MultiValueRepeatedKeyLiteral]]:
"""Return dict mapping the key to list of messages."""
- dictionary = collections.defaultdict(lambda: [])
+ dictionary = collections.defaultdict(list)
for message in messages:
dictionary[message.message_args[0]].append(message)
return dictionary
-def check(source):
+def check(source: str) -> Iterable[pyflakes.messages.Message]:
"""Return messages from pyflakes."""
reporter = ListReporter()
try:
@@ -199,28 +224,28 @@
class StubFile:
"""Stub out file for pyflakes."""
- def write(self, *_):
+ def write(self, *_: Any) -> None:
"""Stub out."""
class ListReporter(pyflakes.reporter.Reporter):
"""Accumulate messages in messages list."""
- def __init__(self):
+ def __init__(self) -> None:
"""Initialize.
Ignore errors from Reporter.
"""
ignore = StubFile()
pyflakes.reporter.Reporter.__init__(self, ignore, ignore)
- self.messages = []
+ self.messages: list[pyflakes.messages.Message] = []
- def flake(self, message):
+ def flake(self, message: pyflakes.messages.Message) -> None:
"""Accumulate messages."""
self.messages.append(message)
-def extract_package_name(line):
+def extract_package_name(line: str) -> str | None:
"""Return package name in import statement."""
assert "\\" not in line
assert "(" not in line
@@ -239,7 +264,7 @@
return package
-def multiline_import(line, previous_line=""):
+def multiline_import(line: str, previous_line: str = "") -> bool:
"""Return True if import is spans multiples lines."""
for symbol in "()":
if symbol in line:
@@ -248,7 +273,7 @@
return multiline_statement(line, previous_line)
-def multiline_statement(line, previous_line=""):
+def multiline_statement(line: str, previous_line: str = "") -> bool:
"""Return True if this is part of a multiline statement."""
for symbol in "\\:;":
if symbol in line:
@@ -270,11 +295,11 @@
with the following line.
"""
- def __init__(self, line):
+ def __init__(self, line: str) -> None:
"""Analyse and store the first line."""
self.accumulator = collections.deque([line])
- def __call__(self, line):
+ def __call__(self, line: str) -> PendingFix | str:
"""Process line considering the accumulator.
Return self to keep processing the following lines or a string
@@ -283,7 +308,7 @@
raise NotImplementedError("Abstract method needs to be overwritten")
-def _valid_char_in_line(char, line):
+def _valid_char_in_line(char: str, line: str) -> bool:
"""Return True if a char appears in the line and is not commented."""
comment_index = line.find("#")
char_index = line.find(char)
@@ -293,19 +318,22 @@
return valid_char_in_line
-def _top_module(module_name):
+def _top_module(module_name: str) -> str:
"""Return the name of the top level module in the hierarchy."""
if module_name[0] == ".":
return "%LOCAL_MODULE%"
return module_name.split(".")[0]
-def _modules_to_remove(unused_modules, safe_to_remove=SAFE_IMPORTS):
+def _modules_to_remove(
+ unused_modules: Iterable[str],
+ safe_to_remove: Iterable[str] = SAFE_IMPORTS,
+) -> Iterable[str]:
"""Discard unused modules that are not safe to remove from the list."""
return [x for x in unused_modules if _top_module(x) in safe_to_remove]
-def _segment_module(segment):
+def _segment_module(segment: str) -> str:
"""Extract the module identifier inside the segment.
It might be the case the segment does not have a module (e.g. is composed
@@ -338,19 +366,19 @@
def __init__(
self,
- line,
- unused_module=(),
- remove_all_unused_imports=False,
- safe_to_remove=SAFE_IMPORTS,
- previous_line="",
+ line: str,
+ unused_module: Iterable[str] = (),
+ remove_all_unused_imports: bool = False,
+ safe_to_remove: Iterable[str] = SAFE_IMPORTS,
+ previous_line: str = "",
):
"""Receive the same parameters as ``filter_unused_import``."""
- self.remove = unused_module
- self.parenthesized = "(" in line
+ self.remove: Iterable[str] = unused_module
+ self.parenthesized: bool = "(" in line
self.from_, imports = self.IMPORT_RE.split(line, maxsplit=1)
match = self.BASE_RE.search(self.from_)
self.base = match.group(1) if match else None
- self.give_up = False
+ self.give_up: bool = False
if not remove_all_unused_imports:
if self.base and _top_module(self.base) not in safe_to_remove:
@@ -366,7 +394,7 @@
PendingFix.__init__(self, imports)
- def is_over(self, line=None):
+ def is_over(self, line: str | None = None) -> bool:
"""Return True if the multiline import statement is over."""
line = line or self.accumulator[-1]
@@ -375,12 +403,12 @@
return not _valid_char_in_line("\\", line)
- def analyze(self, line):
+ def analyze(self, line: str) -> None:
"""Decide if the statement will be fixed or left unchanged."""
if any(ch in line for ch in ";:#"):
self.give_up = True
- def fix(self, accumulated):
+ def fix(self, accumulated: Iterable[str]) -> str:
"""Given a collection of accumulated lines, fix the entire import."""
old_imports = "".join(accumulated)
ending = get_line_ending(old_imports)
@@ -416,12 +444,14 @@
# Replace empty imports with a "pass" statement
empty = len(fixed.strip(string.whitespace + "\\(),")) < 1
if empty:
- indentation = self.INDENTATION_RE.search(self.from_).group(0)
+ match = self.INDENTATION_RE.search(self.from_)
+ assert match is not None
+ indentation = match.group(0)
return indentation + "pass" + ending
return self.from_ + "import " + fixed
- def __call__(self, line=None):
+ def __call__(self, line: str | None = None) -> PendingFix | str:
"""Accumulate all the lines in the import and then trigger the fix."""
if line:
self.accumulator.append(line)
@@ -434,18 +464,22 @@
return self.fix(self.accumulator)
-def _filter_imports(imports, parent=None, unused_module=()):
+def _filter_imports(
+ imports: Iterable[str],
+ parent: str | None = None,
+ unused_module: Iterable[str] = (),
+) -> Sequence[str]:
# We compare full module name (``a.module`` not `module`) to
# guarantee the exact same module as detected from pyflakes.
sep = "" if parent and parent[-1] == "." else "."
- def full_name(name):
+ def full_name(name: str) -> str:
return name if parent is None else parent + sep + name
return [x for x in imports if full_name(x) not in unused_module]
-def filter_from_import(line, unused_module):
+def filter_from_import(line: str, unused_module: Iterable[str]) -> str:
"""Parse and filter ``from something import a, b, c``.
Return line without unused import modules, or `pass` if all of the
@@ -456,10 +490,12 @@
string=line,
maxsplit=1,
)
- base_module = re.search(
+ match = re.search(
pattern=r"\bfrom\s+([^ ]+)",
string=indentation,
- ).group(1)
+ )
+ assert match is not None
+ base_module = match.group(1)
imports = re.split(pattern=r"\s*,\s*", string=imports.strip())
filtered_imports = _filter_imports(imports, base_module, unused_module)
@@ -473,7 +509,7 @@
return indentation + ", ".join(filtered_imports) + get_line_ending(line)
-def break_up_import(line):
+def break_up_import(line: str) -> str:
"""Return line with imports on separate lines."""
assert "\\" not in line
assert "(" not in line
@@ -501,15 +537,15 @@
def filter_code(
- source,
- additional_imports=None,
- expand_star_imports=False,
- remove_all_unused_imports=False,
- remove_duplicate_keys=False,
- remove_unused_variables=False,
- remove_rhs_for_unused_variables=False,
- ignore_init_module_imports=False,
-):
+ source: str,
+ additional_imports: Iterable[str] | None = None,
+ expand_star_imports: bool = False,
+ remove_all_unused_imports: bool = False,
+ remove_duplicate_keys: bool = False,
+ remove_unused_variables: bool = False,
+ remove_rhs_for_unused_variables: bool = False,
+ ignore_init_module_imports: bool = False,
+) -> Iterable[str]:
"""Yield code with unused imports removed."""
imports = SAFE_IMPORTS
if additional_imports:
@@ -519,15 +555,16 @@
messages = check(source)
if ignore_init_module_imports:
- marked_import_line_numbers = frozenset()
+ marked_import_line_numbers: frozenset[int] = frozenset()
else:
marked_import_line_numbers = frozenset(
unused_import_line_numbers(messages),
)
- marked_unused_module = collections.defaultdict(lambda: [])
+ marked_unused_module: dict[int, list[str]] = collections.defaultdict(list)
for line_number, module_name in unused_import_module_name(messages):
marked_unused_module[line_number].append(module_name)
+ undefined_names: list[str] = []
if expand_star_imports and not (
# See explanations in #18.
re.search(r"\b__all__\b", source)
@@ -540,7 +577,6 @@
# Auto expanding only possible for single star import
marked_star_import_line_numbers = frozenset()
else:
- undefined_names = []
for line_number, undefined_name, _ in star_import_usage_undefined_name(
messages,
):
@@ -558,7 +594,7 @@
marked_variable_line_numbers = frozenset()
if remove_duplicate_keys:
- marked_key_line_numbers = frozenset(
+ marked_key_line_numbers: frozenset[int] = frozenset(
duplicate_key_line_numbers(messages, source),
)
else:
@@ -568,7 +604,7 @@
sio = io.StringIO(source)
previous_line = ""
- result = None
+ result: str | PendingFix = ""
for line_number, line in enumerate(sio.readlines(), start=1):
if isinstance(result, PendingFix):
result = result(line)
@@ -606,27 +642,32 @@
previous_line = line
-def get_messages_by_line(messages):
+def get_messages_by_line(
+ messages: Iterable[pyflakes.messages.Message],
+) -> Mapping[int, pyflakes.messages.Message]:
"""Return dictionary that maps line number to message."""
- line_messages = {}
+ line_messages: dict[int, pyflakes.messages.Message] = {}
for message in messages:
line_messages[message.lineno] = message
return line_messages
-def filter_star_import(line, marked_star_import_undefined_name):
+def filter_star_import(
+ line: str,
+ marked_star_import_undefined_name: Iterable[str],
+) -> str:
"""Return line with the star import expanded."""
undefined_name = sorted(set(marked_star_import_undefined_name))
return re.sub(r"\*", ", ".join(undefined_name), line)
def filter_unused_import(
- line,
- unused_module,
- remove_all_unused_imports,
- imports,
- previous_line="",
-):
+ line: str,
+ unused_module: Iterable[str],
+ remove_all_unused_imports: bool,
+ imports: Iterable[str],
+ previous_line: str = "",
+) -> PendingFix | str:
"""Return line if used, otherwise return None."""
# Ignore doctests.
if line.lstrip().startswith(">"):
@@ -648,7 +689,7 @@
return break_up_import(line)
package = extract_package_name(line)
- if not remove_all_unused_imports and package not in imports:
+ if not remove_all_unused_imports and package is not None and package not in imports:
return line
if "," in line:
@@ -662,7 +703,11 @@
return get_indentation(line) + "pass" + get_line_ending(line)
-def filter_unused_variable(line, previous_line="", drop_rhs=False):
+def filter_unused_variable(
+ line: str,
+ previous_line: str = "",
+ drop_rhs: bool = False,
+) -> str:
"""Return line if used, otherwise return None."""
if re.match(EXCEPT_REGEX, line):
return re.sub(r" as \w+:$", ":", line, count=1)
@@ -690,13 +735,13 @@
def filter_duplicate_key(
- line,
- message,
- line_number,
- marked_line_numbers,
- source,
- previous_line="",
-):
+ line: str,
+ message: pyflakes.messages.Message,
+ line_number: int,
+ marked_line_numbers: Iterable[int],
+ source: str,
+ previous_line: str = "",
+) -> str:
"""Return '' if first occurrence of the key otherwise return `line`."""
if marked_line_numbers and line_number == sorted(marked_line_numbers)[0]:
return ""
@@ -704,7 +749,7 @@
return line
-def dict_entry_has_key(line, key):
+def dict_entry_has_key(line: str, key: Any) -> bool:
"""Return True if `line` is a dict entry that uses `key`.
Return False for multiline cases where the line should not be removed by
@@ -726,10 +771,10 @@
if multiline_statement(result.group(2)):
return False
- return candidate_key == key
+ return cast(bool, candidate_key == key)
-def is_literal_or_name(value):
+def is_literal_or_name(value: str) -> bool:
"""Return True if value is a literal or a name."""
try:
ast.literal_eval(value)
@@ -742,13 +787,13 @@
# Support removal of variables on the right side. But make sure
# there are no dots, which could mean an access of a property.
- return re.match(r"^\w+\s*$", value)
+ return re.match(r"^\w+\s*$", value) is not None
def useless_pass_line_numbers(
- source,
- ignore_pass_after_docstring=False,
-):
+ source: str,
+ ignore_pass_after_docstring: bool = False,
+) -> Iterable[int]:
"""Yield line numbers of unneeded "pass" statements."""
sio = io.StringIO(source)
previous_token_type = None
@@ -799,13 +844,13 @@
def filter_useless_pass(
- source,
- ignore_pass_statements=False,
- ignore_pass_after_docstring=False,
-):
+ source: str,
+ ignore_pass_statements: bool = False,
+ ignore_pass_after_docstring: bool = False,
+) -> Iterable[str]:
"""Yield code with useless "pass" lines removed."""
if ignore_pass_statements:
- marked_lines = frozenset()
+ marked_lines: frozenset[int] = frozenset()
else:
try:
marked_lines = frozenset(
@@ -823,7 +868,7 @@
yield line
-def get_indentation(line):
+def get_indentation(line: str) -> str:
"""Return leading whitespace."""
if line.strip():
non_whitespace_index = len(line) - len(line.lstrip())
@@ -832,7 +877,7 @@
return ""
-def get_line_ending(line):
+def get_line_ending(line: str) -> str:
"""Return line ending."""
non_whitespace_index = len(line.rstrip()) - len(line)
if not non_whitespace_index:
@@ -842,17 +887,17 @@
def fix_code(
- source,
- additional_imports=None,
- expand_star_imports=False,
- remove_all_unused_imports=False,
- remove_duplicate_keys=False,
- remove_unused_variables=False,
- remove_rhs_for_unused_variables=False,
- ignore_init_module_imports=False,
- ignore_pass_statements=False,
- ignore_pass_after_docstring=False,
-):
+ source: str,
+ additional_imports: Iterable[str] | None = None,
+ expand_star_imports: bool = False,
+ remove_all_unused_imports: bool = False,
+ remove_duplicate_keys: bool = False,
+ remove_unused_variables: bool = False,
+ remove_rhs_for_unused_variables: bool = False,
+ ignore_init_module_imports: bool = False,
+ ignore_pass_statements: bool = False,
+ ignore_pass_after_docstring: bool = False,
+) -> str:
"""Return code with all filtering run on it."""
if not source:
return source
@@ -891,7 +936,11 @@
return filtered_source
-def fix_file(filename, args, standard_out=None) -> int:
+def fix_file(
+ filename: str,
+ args: Mapping[str, Any],
+ standard_out: IO[str] | None = None,
+) -> int:
"""Run fix_code() on a file."""
if standard_out is None:
standard_out = sys.stdout
@@ -908,12 +957,12 @@
def _fix_file(
- input_file,
- filename,
- args,
- write_to_stdout,
- standard_out,
- encoding=None,
+ input_file: IO[str],
+ filename: str,
+ args: Mapping[str, Any],
+ write_to_stdout: bool,
+ standard_out: IO[str],
+ encoding: str | None = None,
) -> int:
source = input_file.read()
original_source = source
@@ -981,11 +1030,11 @@
def open_with_encoding(
- filename,
- encoding,
- mode="r",
- limit_byte_check=-1,
-):
+ filename: str,
+ encoding: str | None,
+ mode: str = "r",
+ limit_byte_check: int = -1,
+) -> IO[str]:
"""Return opened file with a specific encoding."""
if not encoding:
encoding = detect_encoding(filename, limit_byte_check=limit_byte_check)
@@ -994,11 +1043,11 @@
filename,
mode=mode,
encoding=encoding,
- newline="",
- ) # Preserve line endings
+ newline="", # Preserve line endings
+ )
-def detect_encoding(filename, limit_byte_check=-1):
+def detect_encoding(filename: str, limit_byte_check: int = -1) -> str:
"""Return file encoding."""
try:
with open(filename, "rb") as input_file:
@@ -1013,7 +1062,7 @@
return "latin-1"
-def _detect_encoding(readline):
+def _detect_encoding(readline: Callable[[], bytes]) -> str:
"""Return file encoding."""
try:
encoding = tokenize.detect_encoding(readline)[0]
@@ -1022,7 +1071,7 @@
return "latin-1"
-def get_diff_text(old, new, filename):
+def get_diff_text(old: Sequence[str], new: Sequence[str], filename: str) -> str:
"""Return text of unified diff between old and new."""
newline = "\n"
diff = difflib.unified_diff(
@@ -1044,12 +1093,12 @@
return text
-def _split_comma_separated(string):
+def _split_comma_separated(string: str) -> set[str]:
"""Return a set of strings."""
return {text.strip() for text in string.split(",") if text.strip()}
-def is_python_file(filename):
+def is_python_file(filename: str) -> bool:
"""Return True if filename is Python file."""
if filename.endswith(".py"):
return True
@@ -1073,7 +1122,7 @@
return True
-def is_exclude_file(filename, exclude):
+def is_exclude_file(filename: str, exclude: Iterable[str]) -> bool:
"""Return True if file matches exclude pattern."""
base_name = os.path.basename(filename)
@@ -1088,7 +1137,7 @@
return False
-def match_file(filename, exclude):
+def match_file(filename: str, exclude: Iterable[str]) -> bool:
"""Return True if file is okay for modifying/recursing."""
if is_exclude_file(filename, exclude):
_LOGGER.debug("Skipped %s: matched to exclude pattern", filename)
@@ -1100,7 +1149,11 @@
return True
-def find_files(filenames, recursive, exclude):
+def find_files(
+ filenames: list[str],
+ recursive: bool,
+ exclude: Iterable[str],
+) -> Iterable[str]:
"""Yield filenames."""
while filenames:
name = filenames.pop(0)
@@ -1129,7 +1182,7 @@
_LOGGER.debug("Skipped %s: matched to exclude pattern", name)
-def process_pyproject_toml(toml_file_path):
+def process_pyproject_toml(toml_file_path: str) -> MutableMapping[str, Any] | None:
"""Extract config mapping from pyproject.toml file."""
try:
import tomllib
@@ -1140,7 +1193,7 @@
return tomllib.load(f).get("tool", {}).get("autoflake", None)
-def process_config_file(config_file_path):
+def process_config_file(config_file_path: str) -> MutableMapping[str, Any] | None:
"""Extract config mapping from config file."""
import configparser
@@ -1152,16 +1205,16 @@
return reader["autoflake"]
-def find_and_process_config(args):
+def find_and_process_config(args: Mapping[str, Any]) -> MutableMapping[str, Any] | None:
# Configuration file parsers {filename: parser function}.
- CONFIG_FILES = {
+ CONFIG_FILES: Mapping[str, Callable[[str], MutableMapping[str, Any] | None]] = {
"pyproject.toml": process_pyproject_toml,
"setup.cfg": process_config_file,
}
# Traverse the file tree common to all files given as argument looking for
# a configuration file
config_path = os.path.commonpath([os.path.abspath(file) for file in args["files"]])
- config = None
+ config: Mapping[str, Any] | None = None
while True:
for config_file, processor in CONFIG_FILES.items():
config_file_path = os.path.join(
@@ -1179,7 +1232,9 @@
return config
-def merge_configuration_file(flag_args):
+def merge_configuration_file(
+ flag_args: MutableMapping[str, Any],
+) -> tuple[MutableMapping[str, Any], bool]:
"""Merge configuration from a file into args."""
BOOL_TYPES = {
"1": True,
@@ -1194,7 +1249,7 @@
if "config_file" in flag_args:
config_file = pathlib.Path(flag_args["config_file"]).resolve()
- config = process_config_file(config_file)
+ config = process_config_file(str(config_file))
if not config:
_LOGGER.error(
@@ -1222,7 +1277,7 @@
"write_to_stdout",
}
- config_args = {}
+ config_args: dict[str, Any] = {}
if config is not None:
for name, value in config.items():
arg = name.replace("-", "_")
@@ -1272,7 +1327,12 @@
}, True
-def _main(argv, standard_out, standard_error, standard_input=None) -> int:
+def _main(
+ argv: Sequence[str],
+ standard_out: IO[str] | None,
+ standard_error: IO[str] | None,
+ standard_input: IO[str] | None = None,
+) -> int:
"""Return exit status.
0 means no error.
@@ -1423,8 +1483,7 @@
),
)
- args = parser.parse_args(argv[1:])
- args = vars(args)
+ args: MutableMapping[str, Any] = vars(parser.parse_args(argv[1:]))
if standard_error is None:
_LOGGER.addHandler(logging.NullHandler())
@@ -1476,7 +1535,7 @@
or standard_out is not None
):
for name in files:
- if name == "-":
+ if name == "-" and standard_input is not None:
exit_status |= _fix_file(
standard_input,
args["stdin_display_name"],
diff --git a/test_autoflake.py b/test_autoflake.py
index 4dd6271..c321dc9 100755
--- a/test_autoflake.py
+++ b/test_autoflake.py
@@ -10,6 +10,10 @@
import sys
import tempfile
import unittest
+from typing import Any
+from typing import Iterator
+from typing import Mapping
+from typing import Sequence
import autoflake
@@ -30,10 +34,10 @@
"""Unit tests."""
- def test_imports(self):
+ def test_imports(self) -> None:
self.assertGreater(len(autoflake.SAFE_IMPORTS), 0)
- def test_unused_import_line_numbers(self):
+ def test_unused_import_line_numbers(self) -> None:
self.assertEqual(
[1],
list(
@@ -43,7 +47,7 @@
),
)
- def test_unused_import_line_numbers_with_from(self):
+ def test_unused_import_line_numbers_with_from(self) -> None:
self.assertEqual(
[1],
list(
@@ -53,7 +57,7 @@
),
)
- def test_unused_import_line_numbers_with_dot(self):
+ def test_unused_import_line_numbers_with_dot(self) -> None:
self.assertEqual(
[1],
list(
@@ -63,7 +67,7 @@
),
)
- def test_extract_package_name(self):
+ def test_extract_package_name(self) -> None:
self.assertEqual("os", autoflake.extract_package_name("import os"))
self.assertEqual(
"os",
@@ -74,10 +78,10 @@
autoflake.extract_package_name("import os.path"),
)
- def test_extract_package_name_should_ignore_doctest_for_now(self):
+ def test_extract_package_name_should_ignore_doctest_for_now(self) -> None:
self.assertFalse(autoflake.extract_package_name(">>> import os"))
- def test_standard_package_names(self):
+ def test_standard_package_names(self) -> None:
self.assertIn("os", list(autoflake.standard_package_names()))
self.assertIn("subprocess", list(autoflake.standard_package_names()))
self.assertIn("urllib", list(autoflake.standard_package_names()))
@@ -85,7 +89,7 @@
self.assertNotIn("autoflake", list(autoflake.standard_package_names()))
self.assertNotIn("pep8", list(autoflake.standard_package_names()))
- def test_get_line_ending(self):
+ def test_get_line_ending(self) -> None:
self.assertEqual("\n", autoflake.get_line_ending("\n"))
self.assertEqual("\n", autoflake.get_line_ending("abc\n"))
self.assertEqual("\t \t\n", autoflake.get_line_ending("abc\t \t\n"))
@@ -93,7 +97,7 @@
self.assertEqual("", autoflake.get_line_ending("abc"))
self.assertEqual("", autoflake.get_line_ending(""))
- def test_get_indentation(self):
+ def test_get_indentation(self) -> None:
self.assertEqual("", autoflake.get_indentation(""))
self.assertEqual(" ", autoflake.get_indentation(" abc"))
self.assertEqual(" ", autoflake.get_indentation(" abc \n\t"))
@@ -101,7 +105,7 @@
self.assertEqual(" \t ", autoflake.get_indentation(" \t abc \n\t"))
self.assertEqual("", autoflake.get_indentation(" "))
- def test_filter_star_import(self):
+ def test_filter_star_import(self) -> None:
self.assertEqual(
"from math import cos",
autoflake.filter_star_import(
@@ -118,7 +122,7 @@
),
)
- def test_filter_unused_variable(self):
+ def test_filter_unused_variable(self) -> None:
self.assertEqual(
"foo()",
autoflake.filter_unused_variable("x = foo()"),
@@ -129,7 +133,7 @@
autoflake.filter_unused_variable(" x = foo()"),
)
- def test_filter_unused_variable_with_literal_or_name(self):
+ def test_filter_unused_variable_with_literal_or_name(self) -> None:
self.assertEqual(
"pass",
autoflake.filter_unused_variable("x = 1"),
@@ -145,7 +149,7 @@
autoflake.filter_unused_variable("x = {}"),
)
- def test_filter_unused_variable_with_basic_data_structures(self):
+ def test_filter_unused_variable_with_basic_data_structures(self) -> None:
self.assertEqual(
"pass",
autoflake.filter_unused_variable("x = dict()"),
@@ -161,19 +165,19 @@
autoflake.filter_unused_variable("x = set()"),
)
- def test_filter_unused_variable_should_ignore_multiline(self):
+ def test_filter_unused_variable_should_ignore_multiline(self) -> None:
self.assertEqual(
"x = foo()\\",
autoflake.filter_unused_variable("x = foo()\\"),
)
- def test_filter_unused_variable_should_multiple_assignments(self):
+ def test_filter_unused_variable_should_multiple_assignments(self) -> None:
self.assertEqual(
"x = y = foo()",
autoflake.filter_unused_variable("x = y = foo()"),
)
- def test_filter_unused_variable_with_exception(self):
+ def test_filter_unused_variable_with_exception(self) -> None:
self.assertEqual(
"except Exception:",
autoflake.filter_unused_variable("except Exception as exception:"),
@@ -186,7 +190,7 @@
),
)
- def test_filter_unused_variable_drop_rhs(self):
+ def test_filter_unused_variable_drop_rhs(self) -> None:
self.assertEqual(
"",
autoflake.filter_unused_variable(
@@ -203,7 +207,7 @@
),
)
- def test_filter_unused_variable_with_literal_or_name_drop_rhs(self):
+ def test_filter_unused_variable_with_literal_or_name_drop_rhs(self) -> None:
self.assertEqual(
"pass",
autoflake.filter_unused_variable("x = 1", drop_rhs=True),
@@ -219,7 +223,7 @@
autoflake.filter_unused_variable("x = {}", drop_rhs=True),
)
- def test_filter_unused_variable_with_basic_data_structures_drop_rhs(self):
+ def test_filter_unused_variable_with_basic_data_structures_drop_rhs(self) -> None:
self.assertEqual(
"pass",
autoflake.filter_unused_variable("x = dict()", drop_rhs=True),
@@ -235,19 +239,19 @@
autoflake.filter_unused_variable("x = set()", drop_rhs=True),
)
- def test_filter_unused_variable_should_ignore_multiline_drop_rhs(self):
+ def test_filter_unused_variable_should_ignore_multiline_drop_rhs(self) -> None:
self.assertEqual(
"x = foo()\\",
autoflake.filter_unused_variable("x = foo()\\", drop_rhs=True),
)
- def test_filter_unused_variable_should_multiple_assignments_drop_rhs(self):
+ def test_filter_unused_variable_should_multiple_assignments_drop_rhs(self) -> None:
self.assertEqual(
"x = y = foo()",
autoflake.filter_unused_variable("x = y = foo()", drop_rhs=True),
)
- def test_filter_unused_variable_with_exception_drop_rhs(self):
+ def test_filter_unused_variable_with_exception_drop_rhs(self) -> None:
self.assertEqual(
"except Exception:",
autoflake.filter_unused_variable(
@@ -264,7 +268,7 @@
),
)
- def test_filter_code(self):
+ def test_filter_code(self) -> None:
self.assertEqual(
"""\
import os
@@ -282,7 +286,7 @@
),
)
- def test_filter_code_with_indented_import(self):
+ def test_filter_code_with_indented_import(self) -> None:
self.assertEqual(
"""\
import os
@@ -302,7 +306,7 @@
),
)
- def test_filter_code_with_from(self):
+ def test_filter_code_with_from(self) -> None:
self.assertEqual(
"""\
pass
@@ -318,7 +322,7 @@
),
)
- def test_filter_code_with_not_from(self):
+ def test_filter_code_with_not_from(self) -> None:
self.assertEqual(
"""\
pass
@@ -335,7 +339,7 @@
),
)
- def test_filter_code_with_used_from(self):
+ def test_filter_code_with_used_from(self) -> None:
self.assertEqual(
"""\
import frommer
@@ -352,7 +356,7 @@
),
)
- def test_filter_code_with_ambiguous_from(self):
+ def test_filter_code_with_ambiguous_from(self) -> None:
self.assertEqual(
"""\
pass
@@ -367,7 +371,7 @@
),
)
- def test_filter_code_should_avoid_inline_except(self):
+ def test_filter_code_should_avoid_inline_except(self) -> None:
line = """\
try: from zap import foo
except: from zap import bar
@@ -382,7 +386,7 @@
),
)
- def test_filter_code_should_avoid_escaped_newlines(self):
+ def test_filter_code_should_avoid_escaped_newlines(self) -> None:
line = """\
try:\\
from zap import foo
@@ -399,7 +403,7 @@
),
)
- def test_filter_code_with_remove_all_unused_imports(self):
+ def test_filter_code_with_remove_all_unused_imports(self) -> None:
self.assertEqual(
"""\
pass
@@ -418,7 +422,7 @@
),
)
- def test_filter_code_with_additional_imports(self):
+ def test_filter_code_with_additional_imports(self) -> None:
self.assertEqual(
"""\
pass
@@ -437,7 +441,7 @@
),
)
- def test_filter_code_should_ignore_imports_with_inline_comment(self):
+ def test_filter_code_should_ignore_imports_with_inline_comment(self) -> None:
self.assertEqual(
"""\
from os import path # foo
@@ -457,7 +461,7 @@
),
)
- def test_filter_code_should_respect_noqa(self):
+ def test_filter_code_should_respect_noqa(self) -> None:
self.assertEqual(
"""\
pass
@@ -477,7 +481,7 @@
),
)
- def test_filter_code_expand_star_imports(self):
+ def test_filter_code_expand_star_imports(self) -> None:
self.assertEqual(
"""\
from math import sin
@@ -512,7 +516,7 @@
),
)
- def test_filter_code_ignore_multiple_star_import(self):
+ def test_filter_code_ignore_multiple_star_import(self) -> None:
self.assertEqual(
"""\
from math import *
@@ -533,7 +537,7 @@
),
)
- def test_filter_code_with_special_re_symbols_in_key(self):
+ def test_filter_code_with_special_re_symbols_in_key(self) -> None:
self.assertEqual(
"""\
a = {
@@ -555,7 +559,7 @@
),
)
- def test_multiline_import(self):
+ def test_multiline_import(self) -> None:
self.assertTrue(
autoflake.multiline_import(
r"""\
@@ -586,7 +590,7 @@
autoflake.multiline_import("from os import (path, sep)"),
)
- def test_multiline_statement(self):
+ def test_multiline_statement(self) -> None:
self.assertFalse(autoflake.multiline_statement("x = foo()"))
self.assertTrue(autoflake.multiline_statement("x = 1;"))
@@ -599,25 +603,25 @@
),
)
- def test_break_up_import(self):
+ def test_break_up_import(self) -> None:
self.assertEqual(
"import abc\nimport subprocess\nimport math\n",
autoflake.break_up_import("import abc, subprocess, math\n"),
)
- def test_break_up_import_with_indentation(self):
+ def test_break_up_import_with_indentation(self) -> None:
self.assertEqual(
" import abc\n import subprocess\n import math\n",
autoflake.break_up_import(" import abc, subprocess, math\n"),
)
- def test_break_up_import_should_do_nothing_on_no_line_ending(self):
+ def test_break_up_import_should_do_nothing_on_no_line_ending(self) -> None:
self.assertEqual(
"import abc, subprocess, math",
autoflake.break_up_import("import abc, subprocess, math"),
)
- def test_filter_from_import_no_remove(self):
+ def test_filter_from_import_no_remove(self) -> None:
self.assertEqual(
"""\
from foo import abc, subprocess, math\n""",
@@ -627,7 +631,7 @@
),
)
- def test_filter_from_import_remove_module(self):
+ def test_filter_from_import_remove_module(self) -> None:
self.assertEqual(
"""\
from foo import subprocess, math\n""",
@@ -637,7 +641,7 @@
),
)
- def test_filter_from_import_remove_all(self):
+ def test_filter_from_import_remove_all(self) -> None:
self.assertEqual(
" pass\n",
autoflake.filter_from_import(
@@ -650,7 +654,7 @@
),
)
- def test_filter_code_multiline_imports(self):
+ def test_filter_code_multiline_imports(self) -> None:
self.assertEqual(
r"""\
import os
@@ -671,7 +675,7 @@
),
)
- def test_filter_code_multiline_from_imports(self):
+ def test_filter_code_multiline_from_imports(self) -> None:
self.assertEqual(
r"""\
import os
@@ -709,7 +713,7 @@
),
)
- def test_filter_code_should_ignore_semicolons(self):
+ def test_filter_code_should_ignore_semicolons(self) -> None:
self.assertEqual(
r"""\
import os
@@ -729,7 +733,7 @@
),
)
- def test_filter_code_should_ignore_non_standard_library(self):
+ def test_filter_code_should_ignore_non_standard_library(self) -> None:
self.assertEqual(
"""\
import os
@@ -755,7 +759,7 @@
),
)
- def test_filter_code_should_ignore_unsafe_imports(self):
+ def test_filter_code_should_ignore_unsafe_imports(self) -> None:
self.assertEqual(
"""\
import rlcompleter
@@ -777,16 +781,16 @@
),
)
- def test_filter_code_should_ignore_docstring(self):
+ def test_filter_code_should_ignore_docstring(self) -> None:
line = """
-def foo():
+def foo() -> None:
'''
>>> import math
'''
"""
self.assertEqual(line, "".join(autoflake.filter_code(line)))
- def test_with_ignore_init_module_imports_flag(self):
+ def test_with_ignore_init_module_imports_flag(self) -> None:
# Need a temp directory in order to specify file name as __init__.py
temp_directory = tempfile.mkdtemp(dir=".")
temp_file = os.path.join(temp_directory, "__init__.py")
@@ -804,7 +808,7 @@
finally:
shutil.rmtree(temp_directory)
- def test_without_ignore_init_module_imports_flag(self):
+ def test_without_ignore_init_module_imports_flag(self) -> None:
# Need a temp directory in order to specify file name as __init__.py
temp_directory = tempfile.mkdtemp(dir=".")
temp_file = os.path.join(temp_directory, "__init__.py")
@@ -822,7 +826,7 @@
finally:
shutil.rmtree(temp_directory)
- def test_fix_code(self):
+ def test_fix_code(self) -> None:
self.assertEqual(
"""\
import os
@@ -845,7 +849,7 @@
),
)
- def test_fix_code_with_from_and_as(self):
+ def test_fix_code_with_from_and_as(self) -> None:
self.assertEqual(
"""\
from collections import namedtuple as xyz
@@ -895,7 +899,7 @@
),
)
- def test_fix_code_with_from_with_and_without_remove_all(self):
+ def test_fix_code_with_from_with_and_without_remove_all(self) -> None:
code = """\
from x import a as b, c as d
"""
@@ -911,7 +915,7 @@
autoflake.fix_code(code, remove_all_unused_imports=False),
)
- def test_fix_code_with_from_and_depth_module(self):
+ def test_fix_code_with_from_and_depth_module(self) -> None:
self.assertEqual(
"""\
from distutils.version import StrictVersion
@@ -938,16 +942,16 @@
),
)
- def test_fix_code_with_indented_from(self):
+ def test_fix_code_with_indented_from(self) -> None:
self.assertEqual(
"""\
-def z():
+def z() -> None:
from ctypes import POINTER, byref
POINTER, byref
""",
autoflake.fix_code(
"""\
-def z():
+def z() -> None:
from ctypes import c_short, c_uint, c_int, c_long, pointer, POINTER, byref
POINTER, byref
""",
@@ -956,24 +960,24 @@
self.assertEqual(
"""\
-def z():
+def z() -> None:
pass
""",
autoflake.fix_code(
"""\
-def z():
+def z() -> None:
from ctypes import c_short, c_uint, c_int, c_long, pointer, POINTER, byref
""",
),
)
- def test_fix_code_with_empty_string(self):
+ def test_fix_code_with_empty_string(self) -> None:
self.assertEqual(
"",
autoflake.fix_code(""),
)
- def test_fix_code_with_from_and_as_and_escaped_newline(self):
+ def test_fix_code_with_from_and_as_and_escaped_newline(self) -> None:
"""Make sure stuff after escaped newline is not lost."""
result = autoflake.fix_code(
"""\
@@ -995,16 +999,16 @@
autoflake.fix_code(result),
)
- def test_fix_code_with_unused_variables(self):
+ def test_fix_code_with_unused_variables(self) -> None:
self.assertEqual(
"""\
-def main():
+def main() -> None:
y = 11
print(y)
""",
autoflake.fix_code(
"""\
-def main():
+def main() -> None:
x = 10
y = 11
print(y)
@@ -1013,16 +1017,16 @@
),
)
- def test_fix_code_with_unused_variables_drop_rhs(self):
+ def test_fix_code_with_unused_variables_drop_rhs(self) -> None:
self.assertEqual(
"""\
-def main():
+def main() -> None:
y = 11
print(y)
""",
autoflake.fix_code(
"""\
-def main():
+def main() -> None:
x = 10
y = 11
print(y)
@@ -1032,13 +1036,13 @@
),
)
- def test_fix_code_with_unused_variables_should_skip_nonlocal(self):
+ def test_fix_code_with_unused_variables_should_skip_nonlocal(self) -> None:
"""pyflakes does not handle nonlocal correctly."""
code = """\
-def bar():
+def bar() -> None:
x = 1
- def foo():
+ def foo() -> None:
nonlocal x
x = 2
"""
@@ -1055,10 +1059,10 @@
):
"""pyflakes does not handle nonlocal correctly."""
code = """\
-def bar():
+def bar() -> None:
x = 1
- def foo():
+ def foo() -> None:
nonlocal x
x = 2
"""
@@ -1071,39 +1075,39 @@
),
)
- def test_detect_encoding_with_bad_encoding(self):
+ def test_detect_encoding_with_bad_encoding(self) -> None:
with temporary_file("# -*- coding: blah -*-\n") as filename:
self.assertEqual(
"latin-1",
autoflake.detect_encoding(filename),
)
- def test_fix_code_with_comma_on_right(self):
+ def test_fix_code_with_comma_on_right(self) -> None:
"""pyflakes does not handle nonlocal correctly."""
self.assertEqual(
"""\
-def main():
+def main() -> None:
pass
""",
autoflake.fix_code(
"""\
-def main():
+def main() -> None:
x = (1, 2, 3)
""",
remove_unused_variables=True,
),
)
- def test_fix_code_with_comma_on_right_drop_rhs(self):
+ def test_fix_code_with_comma_on_right_drop_rhs(self) -> None:
"""pyflakes does not handle nonlocal correctly."""
self.assertEqual(
"""\
-def main():
+def main() -> None:
pass
""",
autoflake.fix_code(
"""\
-def main():
+def main() -> None:
x = (1, 2, 3)
""",
remove_unused_variables=True,
@@ -1111,9 +1115,9 @@
),
)
- def test_fix_code_with_unused_variables_should_skip_multiple(self):
+ def test_fix_code_with_unused_variables_should_skip_multiple(self) -> None:
code = """\
-def main():
+def main() -> None:
(x, y, z) = (1, 2, 3)
print(z)
"""
@@ -1129,7 +1133,7 @@
self,
):
code = """\
-def main():
+def main() -> None:
(x, y, z) = (1, 2, 3)
print(z)
"""
@@ -1142,14 +1146,14 @@
),
)
- def test_fix_code_should_handle_pyflakes_recursion_error_gracefully(self):
+ def test_fix_code_should_handle_pyflakes_recursion_error_gracefully(self) -> None:
code = "x = [{}]".format("+".join(["abc" for _ in range(2000)]))
self.assertEqual(
code,
autoflake.fix_code(code),
)
- def test_fix_code_with_duplicate_key(self):
+ def test_fix_code_with_duplicate_key(self) -> None:
self.assertEqual(
"""\
a = {
@@ -1172,7 +1176,7 @@
),
)
- def test_fix_code_with_duplicate_key_longer(self):
+ def test_fix_code_with_duplicate_key_longer(self) -> None:
self.assertEqual(
"""\
{
@@ -1202,7 +1206,7 @@
),
)
- def test_fix_code_with_duplicate_key_with_many_braces(self):
+ def test_fix_code_with_duplicate_key_with_many_braces(self) -> None:
self.assertEqual(
"""\
a = None
@@ -1232,7 +1236,7 @@
),
)
- def test_fix_code_should_ignore_complex_case_of_duplicate_key(self):
+ def test_fix_code_should_ignore_complex_case_of_duplicate_key(self) -> None:
"""We only handle simple cases."""
code = """\
a = {(0,1): 1, (0, 1): 'two',
@@ -1251,7 +1255,7 @@
),
)
- def test_fix_code_should_ignore_complex_case_of_duplicate_key_comma(self):
+ def test_fix_code_should_ignore_complex_case_of_duplicate_key_comma(self) -> None:
"""We only handle simple cases."""
code = """\
{
@@ -1302,7 +1306,7 @@
),
)
- def test_fix_code_should_ignore_more_cases_of_duplicate_key(self):
+ def test_fix_code_should_ignore_more_cases_of_duplicate_key(self) -> None:
"""We only handle simple cases."""
code = """\
a = {
@@ -1324,7 +1328,7 @@
),
)
- def test_fix_code_should_ignore_duplicate_key_with_comments(self):
+ def test_fix_code_should_ignore_duplicate_key_with_comments(self) -> None:
"""We only handle simple cases."""
code = """\
a = {
@@ -1367,7 +1371,7 @@
),
)
- def test_fix_code_should_ignore_duplicate_key_with_multiline_key(self):
+ def test_fix_code_should_ignore_duplicate_key_with_multiline_key(self) -> None:
"""We only handle simple cases."""
code = """\
a = {
@@ -1389,7 +1393,7 @@
),
)
- def test_fix_code_should_ignore_duplicate_key_with_no_comma(self):
+ def test_fix_code_should_ignore_duplicate_key_with_no_comma(self) -> None:
"""We don't want to delete the line and leave a lone comma."""
code = """\
a = {
@@ -1411,27 +1415,27 @@
),
)
- def test_fix_code_keeps_pass_statements(self):
+ def test_fix_code_keeps_pass_statements(self) -> None:
code = """\
if True:
pass
else:
- def foo():
+ def foo() -> None:
\"\"\" A docstring. \"\"\"
pass
- def foo2():
+ def foo2() -> None:
\"\"\" A docstring. \"\"\"
pass
- def foo3():
+ def foo3() -> None:
\"\"\" A docstring. \"\"\"
pass
- def bar():
+ def bar() -> None:
# abc
pass
- def blah():
+ def blah() -> None:
123
pass
pass # Nope.
@@ -1448,28 +1452,28 @@
),
)
- def test_fix_code_keeps_passes_after_docstrings(self):
+ def test_fix_code_keeps_passes_after_docstrings(self) -> None:
actual = autoflake.fix_code(
"""\
if True:
pass
else:
- def foo():
+ def foo() -> None:
\"\"\" A docstring. \"\"\"
pass
- def foo2():
+ def foo2() -> None:
\"\"\" A docstring. \"\"\"
pass
- def foo3():
+ def foo3() -> None:
\"\"\" A docstring. \"\"\"
pass
- def bar():
+ def bar() -> None:
# abc
pass
- def blah():
+ def blah() -> None:
123
pass
pass # Nope.
@@ -1482,29 +1486,29 @@
if True:
pass
else:
- def foo():
+ def foo() -> None:
\"\"\" A docstring. \"\"\"
pass
- def foo2():
+ def foo2() -> None:
\"\"\" A docstring. \"\"\"
pass
- def foo3():
+ def foo3() -> None:
\"\"\" A docstring. \"\"\"
pass
- def bar():
+ def bar() -> None:
# abc
pass
- def blah():
+ def blah() -> None:
123
pass # Nope.
"""
self.assertEqual(actual, expected)
- def test_useless_pass_line_numbers(self):
+ def test_useless_pass_line_numbers(self) -> None:
self.assertEqual(
[1],
list(
@@ -1523,7 +1527,7 @@
),
)
- def test_useless_pass_line_numbers_with_escaped_newline(self):
+ def test_useless_pass_line_numbers_with_escaped_newline(self) -> None:
self.assertEqual(
[],
list(
@@ -1533,7 +1537,7 @@
),
)
- def test_useless_pass_line_numbers_with_more_complex(self):
+ def test_useless_pass_line_numbers_with_more_complex(self) -> None:
self.assertEqual(
[6],
list(
@@ -1550,12 +1554,12 @@
),
)
- def test_useless_pass_line_numbers_after_docstring(self):
+ def test_useless_pass_line_numbers_after_docstring(self) -> None:
actual_pass_line_numbers = list(
autoflake.useless_pass_line_numbers(
"""\
@abc.abstractmethod
- def some_abstract_method():
+ def some_abstract_method() -> None:
\"\"\"Some docstring.\"\"\"
pass
""",
@@ -1565,12 +1569,12 @@
expected_pass_line_numbers = [4]
self.assertEqual(expected_pass_line_numbers, actual_pass_line_numbers)
- def test_useless_pass_line_numbers_keep_pass_after_docstring(self):
+ def test_useless_pass_line_numbers_keep_pass_after_docstring(self) -> None:
actual_pass_line_numbers = list(
autoflake.useless_pass_line_numbers(
"""\
@abc.abstractmethod
- def some_abstract_method():
+ def some_abstract_method() -> None:
\"\"\"Some docstring.\"\"\"
pass
""",
@@ -1581,7 +1585,7 @@
expected_pass_line_numbers = []
self.assertEqual(expected_pass_line_numbers, actual_pass_line_numbers)
- def test_filter_useless_pass(self):
+ def test_filter_useless_pass(self) -> None:
self.assertEqual(
"""\
if True:
@@ -1604,7 +1608,7 @@
),
)
- def test_filter_useless_pass_with_syntax_error(self):
+ def test_filter_useless_pass_with_syntax_error(self) -> None:
source = """\
if True:
if True:
@@ -1625,19 +1629,19 @@
"".join(autoflake.filter_useless_pass(source)),
)
- def test_filter_useless_pass_more_complex(self):
+ def test_filter_useless_pass_more_complex(self) -> None:
self.assertEqual(
"""\
if True:
pass
else:
- def foo():
+ def foo() -> None:
pass
# abc
- def bar():
+ def bar() -> None:
# abc
pass
- def blah():
+ def blah() -> None:
123
pass # Nope.
True
@@ -1649,13 +1653,13 @@
if True:
pass
else:
- def foo():
+ def foo() -> None:
pass
# abc
- def bar():
+ def bar() -> None:
# abc
pass
- def blah():
+ def blah() -> None:
123
pass
pass # Nope.
@@ -1668,14 +1672,14 @@
),
)
- def test_filter_useless_pass_keep_pass_after_docstring(self):
+ def test_filter_useless_pass_keep_pass_after_docstring(self) -> None:
source = """\
- def foo():
+ def foo() -> None:
\"\"\" This is not a useless 'pass'. \"\"\"
pass
@abc.abstractmethod
- def bar():
+ def bar() -> None:
\"\"\"
Also this is not a useless 'pass'.
\"\"\"
@@ -1691,7 +1695,7 @@
),
)
- def test_filter_useless_pass_keeps_pass_statements(self):
+ def test_filter_useless_pass_keeps_pass_statements(self) -> None:
source = """\
if True:
pass
@@ -1715,7 +1719,7 @@
),
)
- def test_filter_useless_paspasss_with_try(self):
+ def test_filter_useless_paspasss_with_try(self) -> None:
self.assertEqual(
"""\
import os
@@ -1740,7 +1744,7 @@
),
)
- def test_filter_useless_pass_leading_pass(self):
+ def test_filter_useless_pass_leading_pass(self) -> None:
self.assertEqual(
"""\
if True:
@@ -1767,17 +1771,17 @@
),
)
- def test_filter_useless_pass_leading_pass_with_number(self):
+ def test_filter_useless_pass_leading_pass_with_number(self) -> None:
self.assertEqual(
"""\
-def func11():
+def func11() -> None:
0, 11 / 2
return 1
""",
"".join(
autoflake.filter_useless_pass(
"""\
-def func11():
+def func11() -> None:
pass
0, 11 / 2
return 1
@@ -1786,17 +1790,17 @@
),
)
- def test_filter_useless_pass_leading_pass_with_string(self):
+ def test_filter_useless_pass_leading_pass_with_string(self) -> None:
self.assertEqual(
"""\
-def func11():
+def func11() -> None:
'hello'
return 1
""",
"".join(
autoflake.filter_useless_pass(
"""\
-def func11():
+def func11() -> None:
pass
'hello'
return 1
@@ -1805,18 +1809,18 @@
),
)
- def test_check(self):
+ def test_check(self) -> None:
self.assertTrue(autoflake.check("import os"))
- def test_check_with_bad_syntax(self):
+ def test_check_with_bad_syntax(self) -> None:
self.assertFalse(autoflake.check("foo("))
- def test_check_with_unicode(self):
+ def test_check_with_unicode(self) -> None:
self.assertFalse(autoflake.check('print("∑")'))
self.assertTrue(autoflake.check("import os # ∑"))
- def test_get_diff_text(self):
+ def test_get_diff_text(self) -> None:
# We ignore the first two lines since it differs on Python 2.6.
self.assertEqual(
"""\
@@ -1830,7 +1834,7 @@
),
)
- def test_get_diff_text_without_newline(self):
+ def test_get_diff_text_without_newline(self) -> None:
# We ignore the first two lines since it differs on Python 2.6.
self.assertEqual(
"""\
@@ -1845,7 +1849,7 @@
),
)
- def test_is_literal_or_name(self):
+ def test_is_literal_or_name(self) -> None:
self.assertTrue(autoflake.is_literal_or_name("123"))
self.assertTrue(autoflake.is_literal_or_name("[1, 2, 3]"))
self.assertTrue(autoflake.is_literal_or_name("xyz"))
@@ -1853,7 +1857,7 @@
self.assertFalse(autoflake.is_literal_or_name("xyz.prop"))
self.assertFalse(autoflake.is_literal_or_name(" "))
- def test_is_python_file(self):
+ def test_is_python_file(self) -> None:
self.assertTrue(
autoflake.is_python_file(
os.path.join(ROOT_DIRECTORY, "autoflake.py"),
@@ -1878,7 +1882,7 @@
self.assertFalse(autoflake.is_python_file(os.devnull))
self.assertFalse(autoflake.is_python_file("/bin/bash"))
- def test_is_exclude_file(self):
+ def test_is_exclude_file(self) -> None:
self.assertTrue(
autoflake.is_exclude_file(
"1.py",
@@ -1915,7 +1919,7 @@
),
)
- def test_match_file(self):
+ def test_match_file(self) -> None:
with temporary_file("", suffix=".py", prefix=".") as filename:
self.assertFalse(
autoflake.match_file(filename, exclude=[]),
@@ -1930,7 +1934,7 @@
msg=filename,
)
- def test_find_files(self):
+ def test_find_files(self) -> None:
temp_directory = tempfile.mkdtemp()
try:
target = os.path.join(temp_directory, "dir")
@@ -1970,7 +1974,7 @@
finally:
shutil.rmtree(temp_directory)
- def test_exclude(self):
+ def test_exclude(self) -> None:
temp_directory = tempfile.mkdtemp(dir=".")
try:
with open(os.path.join(temp_directory, "a.py"), "w") as output:
@@ -2000,7 +2004,7 @@
"""System tests."""
- def test_diff(self):
+ def test_diff(self) -> None:
with temporary_file(
"""\
import re
@@ -2025,7 +2029,7 @@
"\n".join(output_file.getvalue().split("\n")[3:]),
)
- def test_diff_with_nonexistent_file(self):
+ def test_diff_with_nonexistent_file(self) -> None:
output_file = io.StringIO()
autoflake._main(
argv=["my_fake_program", "nonexistent_file"],
@@ -2034,7 +2038,7 @@
)
self.assertIn("no such file", output_file.getvalue().lower())
- def test_diff_with_encoding_declaration(self):
+ def test_diff_with_encoding_declaration(self) -> None:
with temporary_file(
"""\
# coding: iso-8859-1
@@ -2061,7 +2065,7 @@
"\n".join(output_file.getvalue().split("\n")[3:]),
)
- def test_in_place(self):
+ def test_in_place(self) -> None:
with temporary_file(
"""\
import foo
@@ -2096,7 +2100,7 @@
f.read(),
)
- def test_check_with_empty_file(self):
+ def test_check_with_empty_file(self) -> None:
line = ""
with temporary_file(line) as filename:
@@ -2111,7 +2115,7 @@
output_file.getvalue(),
)
- def test_check_correct_file(self):
+ def test_check_correct_file(self) -> None:
with temporary_file(
"""\
import foo
@@ -2130,7 +2134,7 @@
output_file.getvalue(),
)
- def test_check_correct_file_with_quiet(self):
+ def test_check_correct_file_with_quiet(self) -> None:
with temporary_file(
"""\
import foo
@@ -2151,7 +2155,7 @@
)
self.assertEqual("", output_file.getvalue())
- def test_check_useless_pass(self):
+ def test_check_useless_pass(self) -> None:
with temporary_file(
"""\
import foo
@@ -2180,7 +2184,7 @@
output_file.getvalue(),
)
- def test_check_with_multiple_files(self):
+ def test_check_with_multiple_files(self) -> None:
with temporary_file("import sys") as file1:
with temporary_file("import sys") as file2:
output_file = io.StringIO()
@@ -2198,7 +2202,7 @@
set(output_file.getvalue().strip().split(os.linesep)),
)
- def test_check_diff_with_empty_file(self):
+ def test_check_diff_with_empty_file(self) -> None:
line = ""
with temporary_file(line) as filename:
@@ -2213,7 +2217,7 @@
output_file.getvalue(),
)
- def test_check_diff_correct_file(self):
+ def test_check_diff_correct_file(self) -> None:
with temporary_file(
"""\
import foo
@@ -2232,7 +2236,7 @@
output_file.getvalue(),
)
- def test_check_diff_correct_file_with_quiet(self):
+ def test_check_diff_correct_file_with_quiet(self) -> None:
with temporary_file(
"""\
import foo
@@ -2253,7 +2257,7 @@
)
self.assertEqual("", output_file.getvalue())
- def test_check_diff_useless_pass(self):
+ def test_check_diff_useless_pass(self) -> None:
with temporary_file(
"""\
import foo
@@ -2293,7 +2297,7 @@
"\n".join(output_file.getvalue().split("\n")[3:]),
)
- def test_in_place_with_empty_file(self):
+ def test_in_place_with_empty_file(self) -> None:
line = ""
with temporary_file(line) as filename:
@@ -2306,7 +2310,7 @@
with open(filename) as f:
self.assertEqual(line, f.read())
- def test_in_place_with_with_useless_pass(self):
+ def test_in_place_with_with_useless_pass(self) -> None:
with temporary_file(
"""\
import foo
@@ -2344,17 +2348,17 @@
f.read(),
)
- def test_with_missing_file(self):
+ def test_with_missing_file(self) -> None:
output_file = io.StringIO()
ignore = StubFile()
autoflake._main(
argv=["my_fake_program", "--in-place", ".fake"],
standard_out=output_file,
- standard_error=ignore,
+ standard_error=ignore, # type: ignore
)
self.assertFalse(output_file.getvalue())
- def test_ignore_hidden_directories(self):
+ def test_ignore_hidden_directories(self) -> None:
with temporary_directory() as directory:
with temporary_directory(
prefix=".",
@@ -2382,7 +2386,7 @@
output_file.getvalue().strip(),
)
- def test_in_place_and_stdout(self):
+ def test_in_place_and_stdout(self) -> None:
output_file = io.StringIO()
self.assertRaises(
SystemExit,
@@ -2392,7 +2396,7 @@
standard_error=output_file,
)
- def test_end_to_end(self):
+ def test_end_to_end(self) -> None:
with temporary_file(
"""\
import fake_fake, fake_foo, fake_bar, fake_zoo
@@ -2422,7 +2426,7 @@
"\n".join(process.communicate()[0].decode().split(os.linesep)[3:]),
)
- def test_end_to_end_multiple_files(self):
+ def test_end_to_end_multiple_files(self) -> None:
with temporary_file(
"""\
import fake_fake, fake_foo, fake_bar, fake_zoo
@@ -2453,7 +2457,7 @@
status_code = process.wait()
self.assertEqual(1, status_code)
- def test_end_to_end_with_remove_all_unused_imports(self):
+ def test_end_to_end_with_remove_all_unused_imports(self) -> None:
with temporary_file(
"""\
import fake_fake, fake_foo, fake_bar, fake_zoo
@@ -2481,7 +2485,7 @@
"\n".join(process.communicate()[0].decode().split(os.linesep)[3:]),
)
- def test_end_to_end_with_remove_duplicate_keys_multiple_lines(self):
+ def test_end_to_end_with_remove_duplicate_keys_multiple_lines(self) -> None:
with temporary_file(
"""\
a = {
@@ -2521,7 +2525,7 @@
"\n".join(process.communicate()[0].decode().split(os.linesep)[3:]),
)
- def test_end_to_end_with_remove_duplicate_keys_and_other_errors(self):
+ def test_end_to_end_with_remove_duplicate_keys_and_other_errors(self) -> None:
with temporary_file(
"""\
from math import *
@@ -2565,7 +2569,7 @@
"\n".join(process.communicate()[0].decode().split(os.linesep)[3:]),
)
- def test_end_to_end_with_remove_duplicate_keys_tuple(self):
+ def test_end_to_end_with_remove_duplicate_keys_tuple(self) -> None:
with temporary_file(
"""\
a = {
@@ -2596,7 +2600,7 @@
"\n".join(process.communicate()[0].decode().split(os.linesep)[3:]),
)
- def test_end_to_end_with_error(self):
+ def test_end_to_end_with_error(self) -> None:
with temporary_file(
"""\
import fake_fake, fake_foo, fake_bar, fake_zoo
@@ -2619,7 +2623,7 @@
process.communicate()[1].decode(),
)
- def test_end_to_end_from_stdin(self):
+ def test_end_to_end_from_stdin(self) -> None:
stdin_data = b"""\
import fake_fake, fake_foo, fake_bar, fake_zoo
import re, os
@@ -2641,7 +2645,7 @@
"\n".join(stdout.decode().split(os.linesep)),
)
- def test_end_to_end_from_stdin_with_in_place(self):
+ def test_end_to_end_from_stdin_with_in_place(self) -> None:
stdin_data = b"""\
import fake_fake, fake_foo, fake_bar, fake_zoo
import re, os, sys
@@ -2663,7 +2667,7 @@
"\n".join(stdout.decode().split(os.linesep)),
)
- def test_end_to_end_dont_remove_unused_imports_when_not_using_flag(self):
+ def test_end_to_end_dont_remove_unused_imports_when_not_using_flag(self) -> None:
with temporary_file(
"""\
from . import fake_bar
@@ -2685,7 +2689,7 @@
class MultilineFromImportTests(unittest.TestCase):
- def test_is_over(self):
+ def test_is_over(self) -> None:
filt = autoflake.FilterMultilineImport("from . import (\n")
self.assertTrue(filt.is_over("module)\n"))
self.assertTrue(filt.is_over(" )\n"))
@@ -2713,16 +2717,25 @@
unused = ()
- def assert_fix(self, lines, result, remove_all=True):
+ def assert_fix(
+ self,
+ lines: Sequence[str],
+ result: str,
+ remove_all: bool = True,
+ ) -> None:
fixer = autoflake.FilterMultilineImport(
lines[0],
remove_all_unused_imports=remove_all,
unused_module=self.unused,
)
- fixed = functools.reduce(lambda acc, x: acc(x), lines[1:], fixer())
+ fixed = functools.reduce(
+ lambda acc, x: acc(x) if isinstance(acc, autoflake.PendingFix) else acc,
+ lines[1:],
+ fixer(),
+ )
self.assertEqual(fixed, result)
- def test_fix(self):
+ def test_fix(self) -> None:
self.unused = ["third_party.lib" + str(x) for x in (1, 3, 4)]
# Example m0 (isort)
@@ -2867,7 +2880,7 @@
")\n",
)
- def test_indentation(self):
+ def test_indentation(self) -> None:
# Some weird indentation examples
self.unused = ["third_party.lib" + str(x) for x in (1, 3, 4)]
self.assert_fix(
@@ -2888,7 +2901,7 @@
"\tfrom third_party import \\\n" "\t\tlib2, lib5, lib6\n",
)
- def test_fix_relative(self):
+ def test_fix_relative(self) -> None:
# Example m0 (isort)
self.unused = [".lib" + str(x) for x in (1, 3, 4)]
self.assert_fix(
@@ -2941,7 +2954,7 @@
"from .parent import (\n" " lib2,\n" " lib5\n" ")\n",
)
- def test_fix_without_from(self):
+ def test_fix_without_from(self) -> None:
self.unused = ["lib" + str(x) for x in (1, 3, 4)]
# Multiline but not "from"
@@ -2994,7 +3007,7 @@
"import \\\n" " lib2.x.y.z \\" " , \\\n" " lib5.x.y.z\n",
)
- def test_give_up(self):
+ def test_give_up(self) -> None:
# Semicolon
self.unused = ["lib" + str(x) for x in (1, 3, 4)]
self.assert_fix(
@@ -3030,7 +3043,7 @@
") ; import sys\n",
)
- def test_just_one_import_used(self):
+ def test_just_one_import_used(self) -> None:
self.unused = ["lib2"]
self.assert_fix(
[
@@ -3055,7 +3068,7 @@
"\tpass\n",
)
- def test_just_one_import_left(self):
+ def test_just_one_import_left(self) -> None:
# Examples from issue #8
self.unused = ["math.sqrt"]
self.assert_fix(
@@ -3089,7 +3102,7 @@
"from re import (subn)\n",
)
- def test_no_empty_imports(self):
+ def test_no_empty_imports(self) -> None:
self.unused = ["lib" + str(x) for x in (1, 3, 4)]
self.assert_fix(
[
@@ -3113,7 +3126,7 @@
"\t\tpass\n",
)
- def test_without_remove_all(self):
+ def test_without_remove_all(self) -> None:
self.unused = ["lib" + str(x) for x in (1, 3, 4)]
self.assert_fix(
[
@@ -3151,15 +3164,14 @@
class ConfigFileTest(unittest.TestCase):
- def setUp(self):
+ def setUp(self) -> None:
self.tmpdir = tempfile.mkdtemp(prefix="autoflake.")
- def tearDown(self):
- if self.tmpdir is not None:
+ def tearDown(self) -> None:
+ if self.tmpdir:
shutil.rmtree(self.tmpdir)
- self.tmpdir = None
- def effective_path(self, path, is_file=True):
+ def effective_path(self, path: str, is_file: bool = True) -> str:
path = os.path.normpath(path)
if os.path.isabs(path):
raise ValueError("Should not create an absolute test path")
@@ -3170,7 +3182,7 @@
raise ValueError("Should create a path within the tmp dir only")
return effective_path
- def create_dir(self, path):
+ def create_dir(self, path) -> None:
effective_path = self.effective_path(path, False)
if sys.version_info >= (3, 2, 0):
os.makedirs(effective_path, exist_ok=True)
@@ -3186,13 +3198,13 @@
self.create_dir(parent)
os.mkdir(effective_path)
- def create_file(self, path, contents=""):
+ def create_file(self, path, contents="") -> None:
effective_path = self.effective_path(path)
self.create_dir(os.path.split(path)[0])
with open(effective_path, "w") as f:
f.write(contents)
- def with_defaults(self, **kwargs):
+ def with_defaults(self, **kwargs: Any) -> Mapping[str, Any]:
return {
"check": False,
"check_diff": False,
@@ -3211,7 +3223,7 @@
**kwargs,
}
- def test_no_config_file(self):
+ def test_no_config_file(self) -> None:
self.create_file("test_me.py")
original_args = {
"files": [self.effective_path("test_me.py")],
@@ -3220,7 +3232,7 @@
assert success is True
assert args == self.with_defaults(**original_args)
- def test_non_nested_pyproject_toml_empty(self):
+ def test_non_nested_pyproject_toml_empty(self) -> None:
self.create_file("test_me.py")
self.create_file("pyproject.toml", '[tool.other]\nprop="value"\n')
files = [self.effective_path("test_me.py")]
@@ -3229,7 +3241,7 @@
assert success is True
assert args == self.with_defaults(**original_args)
- def test_non_nested_pyproject_toml_non_empty(self):
+ def test_non_nested_pyproject_toml_non_empty(self) -> None:
self.create_file("test_me.py")
self.create_file(
"pyproject.toml",
@@ -3243,7 +3255,7 @@
expand_star_imports=True,
)
- def test_non_nested_setup_cfg_non_empty(self):
+ def test_non_nested_setup_cfg_non_empty(self) -> None:
self.create_file("test_me.py")
self.create_file(
"setup.cfg",
@@ -3254,7 +3266,7 @@
assert success is True
assert args == self.with_defaults(files=files)
- def test_non_nested_setup_cfg_empty(self):
+ def test_non_nested_setup_cfg_empty(self) -> None:
self.create_file("test_me.py")
self.create_file(
"setup.cfg",
@@ -3268,7 +3280,7 @@
expand_star_imports=True,
)
- def test_nested_file(self):
+ def test_nested_file(self) -> None:
self.create_file("nested/file/test_me.py")
self.create_file(
"pyproject.toml",
@@ -3282,7 +3294,7 @@
expand_star_imports=True,
)
- def test_common_path_nested_file_do_not_load(self):
+ def test_common_path_nested_file_do_not_load(self) -> None:
self.create_file("nested/file/test_me.py")
self.create_file("nested/other/test_me.py")
self.create_file(
@@ -3297,7 +3309,7 @@
assert success is True
assert args == self.with_defaults(files=files)
- def test_common_path_nested_file_do_load(self):
+ def test_common_path_nested_file_do_load(self) -> None:
self.create_file("nested/file/test_me.py")
self.create_file("nested/other/test_me.py")
self.create_file(
@@ -3315,7 +3327,7 @@
expand_star_imports=True,
)
- def test_common_path_instead_of_common_prefix(self):
+ def test_common_path_instead_of_common_prefix(self) -> None:
"""Using common prefix would result in a failure."""
self.create_file("nested/file-foo/test_me.py")
self.create_file("nested/file-bar/test_me.py")
@@ -3331,7 +3343,7 @@
assert success is True
assert args == self.with_defaults(files=files)
- def test_continue_search_if_no_config_found(self):
+ def test_continue_search_if_no_config_found(self) -> None:
self.create_file("nested/test_me.py")
self.create_file(
"nested/pyproject.toml",
@@ -3349,7 +3361,7 @@
expand_star_imports=True,
)
- def test_stop_search_if_config_found(self):
+ def test_stop_search_if_config_found(self) -> None:
self.create_file("nested/test_me.py")
self.create_file(
"nested/pyproject.toml",
@@ -3364,7 +3376,7 @@
assert success is True
assert args == self.with_defaults(files=files)
- def test_config_option(self):
+ def test_config_option(self) -> None:
with temporary_file(
suffix=".ini",
contents=("[autoflake]\n" "check = True\n"),
@@ -3385,7 +3397,7 @@
check=True,
)
- def test_load_false(self):
+ def test_load_false(self) -> None:
self.create_file("test_me.py")
self.create_file(
"setup.cfg",
@@ -3400,7 +3412,7 @@
expand_star_imports=False,
)
- def test_list_value_pyproject_toml(self):
+ def test_list_value_pyproject_toml(self) -> None:
self.create_file("test_me.py")
self.create_file(
"pyproject.toml",
@@ -3414,7 +3426,7 @@
imports="my_lib,other_lib",
)
- def test_list_value_comma_sep_string_pyproject_toml(self):
+ def test_list_value_comma_sep_string_pyproject_toml(self) -> None:
self.create_file("test_me.py")
self.create_file(
"pyproject.toml",
@@ -3428,7 +3440,7 @@
imports="my_lib,other_lib",
)
- def test_list_value_setup_cfg(self):
+ def test_list_value_setup_cfg(self) -> None:
self.create_file("test_me.py")
self.create_file(
"setup.cfg",
@@ -3442,7 +3454,7 @@
imports="my_lib,other_lib",
)
- def test_non_bool_value_for_bool_property(self):
+ def test_non_bool_value_for_bool_property(self) -> None:
self.create_file("test_me.py")
self.create_file(
"pyproject.toml",
@@ -3452,7 +3464,7 @@
_, success = autoflake.merge_configuration_file({"files": files})
assert success is False
- def test_non_bool_value_for_bool_property_in_setup_cfg(self):
+ def test_non_bool_value_for_bool_property_in_setup_cfg(self) -> None:
self.create_file("test_me.py")
self.create_file(
"setup.cfg",
@@ -3462,7 +3474,7 @@
_, success = autoflake.merge_configuration_file({"files": files})
assert success is False
- def test_non_list_value_for_list_property(self):
+ def test_non_list_value_for_list_property(self) -> None:
self.create_file("test_me.py")
self.create_file(
"pyproject.toml",
@@ -3472,7 +3484,7 @@
_, success = autoflake.merge_configuration_file({"files": files})
assert success is False
- def test_merge_with_cli_set_list_property(self):
+ def test_merge_with_cli_set_list_property(self) -> None:
self.create_file("test_me.py")
self.create_file(
"pyproject.toml",
@@ -3488,7 +3500,7 @@
imports="my_lib,other_lib",
)
- def test_merge_prioritizes_flags(self):
+ def test_merge_prioritizes_flags(self) -> None:
self.create_file("test_me.py")
self.create_file(
"pyproject.toml",
@@ -3510,7 +3522,12 @@
@contextlib.contextmanager
-def temporary_file(contents, directory=".", suffix=".py", prefix=""):
+def temporary_file(
+ contents: str,
+ directory: str = ".",
+ suffix: str = ".py",
+ prefix: str = "",
+) -> Iterator[str]:
"""Write contents to temporary file and yield it."""
f = tempfile.NamedTemporaryFile(
suffix=suffix,
@@ -3527,7 +3544,7 @@
@contextlib.contextmanager
-def temporary_directory(directory=".", prefix="tmp."):
+def temporary_directory(directory: str = ".", prefix: str = "tmp.") -> Iterator[str]:
"""Create temporary directory and yield its path."""
temp_directory = tempfile.mkdtemp(prefix=prefix, dir=directory)
try:
@@ -3540,7 +3557,7 @@
"""Fake file that ignores everything."""
- def write(*_):
+ def write(*_: Any) -> None:
"""Ignore."""
diff --git a/test_fuzz.py b/test_fuzz.py
index e4af090..3d2dd93 100755
--- a/test_fuzz.py
+++ b/test_fuzz.py
@@ -5,10 +5,12 @@
done by doing a syntax check after the autoflake run. The number of
Pyflakes warnings is also confirmed to always improve.
"""
+import argparse
import os
import shlex
import subprocess
import sys
+from typing import Sequence
import autoflake
@@ -27,12 +29,12 @@
END = ""
-def colored(text, color):
+def colored(text: str, color: str) -> str:
"""Return color coded text."""
return color + text + END
-def pyflakes_count(filename):
+def pyflakes_count(filename: str) -> int:
"""Return pyflakes error count."""
with autoflake.open_with_encoding(
filename,
@@ -41,7 +43,7 @@
return len(list(autoflake.check(f.read())))
-def readlines(filename):
+def readlines(filename: str) -> Sequence[str]:
"""Return contents of file as a list of lines."""
with autoflake.open_with_encoding(
filename,
@@ -50,7 +52,7 @@
return f.readlines()
-def diff(before, after):
+def diff(before: str, after: str) -> str:
"""Return diff of two files."""
import difflib
@@ -64,7 +66,12 @@
)
-def run(filename, command, verbose=False, options=None):
+def run(
+ filename: str,
+ command: str,
+ verbose: bool = False,
+ options: list[str] | None = None,
+) -> bool:
"""Run autoflake on file at filename.
Return True on success.
@@ -123,7 +130,7 @@
return True
-def check_syntax(filename, raise_error=False):
+def check_syntax(filename: str, raise_error: bool = False) -> bool:
"""Return True if syntax is okay."""
with autoflake.open_with_encoding(
filename,
@@ -139,10 +146,8 @@
return False
-def process_args():
+def process_args() -> argparse.Namespace:
"""Return processed arguments (options and positional arguments)."""
- import argparse
-
parser = argparse.ArgumentParser()
parser.add_argument(
@@ -192,7 +197,7 @@
return parser.parse_args()
-def check(args):
+def check(args: argparse.Namespace) -> bool:
"""Run recursively run autoflake on directory of files.
Return False if the fix results in broken syntax.
@@ -267,7 +272,7 @@
return True
-def main():
+def main() -> int:
"""Run main."""
return 0 if check(process_args()) else 1