| """Tools for runtime type inference""" |
| |
| import inspect |
| from inspect3 import getfullargspec, getcallargs |
| import types |
| import codecs |
| import os |
| import tokenize |
| try: |
| from StringIO import StringIO |
| from unparse import Unparser |
| except: |
| from io import StringIO |
| from unparse3 import Unparser |
| import ast |
| |
| |
# Tuples longer than this are summarized as a homogeneous 'TupleSequence'
# instead of an exact per-item Tuple type (see infer_value_type).
MAX_INFERRED_TUPLE_LENGTH = 10
# Signatures longer than this are wrapped when pretty-printing (format_sig).
PREFERRED_LINE_LENGTH = 79


# A "funcid" key below is the tuple
#   (class_name, function_name, source_file, source_line)
# built in infer_signature().
var_db = {} # (location, variable) -> type
func_argid_db = {} # funcid -> argspec
func_arg_db = {} # (funcid, name) -> type
func_return_db = {} # funcid -> type
func_source_db = {} # funcid -> source string
#func_info_db = {} # funcid -> (class, name, argspec, file, line, source)
ignore_files = set()

# The type inferencing wrapper should not be reentrant. It's not, in theory, calling
# out to any external code which we would want to infer the types of. However,
# sometimes we do something like infer_type(arg.keys()) or infer_type(arg.values()) if
# the arg is a collection, and we want to know about the types of its elements. .keys(),
# .values(), etc. can be overloaded, possibly to a method we've wrapped. This can become
# infinitely recursive, particularly because on something like arg.keys(), keys() gets passed
# arg as the first parameter, so if we've wrapped keys() we'll try to infer_type(arg),
# which will detect it's a dictionary, call infer_type(arg.keys()), recurse and so on.
# We ran in to this problem with collections.OrderedDict.
# To prevent reentrancy, we set is_performing_inference = True iff we're in the middle of
# inferring the types of a function. If we try to run another function we've wrapped,
# we skip type inferencing so we can't accidentally infinitely recurse.
is_performing_inference = False
| |
| |
def reset():
    """Discard all inferred types collected so far.

    Metadata about already-wrapped functions (their sources and argspecs)
    is deliberately preserved, since those functions cannot be re-wrapped.
    """
    global var_db, func_argid_db, func_arg_db, func_return_db, func_source_db
    global ignore_files, is_performing_inference
    var_db = {}
    func_arg_db = {}
    func_return_db = {}
    ignore_files = set()
    is_performing_inference = False
    # Intentionally NOT cleared: func_source_db and func_argid_db.  Clearing
    # them would lose the functions we've already wrapped forever.
| |
| |
def format_state(pretty=False):
    """Render every inferred variable and function type as one text blob.

    Variables come first, then function signatures grouped under a
    ``class Name(...):`` header when they belong to a class.
    """
    lines = []
    for loc, var in sorted(var_db.keys()):
        lines.append('%s: %s' % (var, var_db[(loc, var)]))

    prevclass = ''
    indent = ''
    for funcid in sorted(set(func_return_db.keys())):
        curclass, name, sourcefile, sourceline = funcid
        if curclass != prevclass:
            if curclass:
                lines.append('class %s(...):' % curclass)
            indent = ' ' * 4 if curclass else ''
            prevclass = curclass

        lines.append(format_sig(funcid, name, indent, pretty))
    return '\n'.join(lines)
| |
| |
def unparse_ast(node):
    """Turn an AST node back into (whitespace-stripped) source text."""
    out = StringIO()
    Unparser(node, out)
    return out.getvalue().strip()
| |
| |
def format_sig(funcid, fname, indent, pretty, defaults=None):
    """Format one inferred function signature as a ``def`` line.

    funcid -- (class_name, name, sourcefile, sourceline) key into the dbs.
    fname -- fallback function name; overridden by the name parsed from the
        stored source when parsing succeeds.
    indent -- string prepended to every emitted line.
    pretty -- if True, signatures longer than PREFERRED_LINE_LENGTH are
        wrapped onto multiple lines.
    defaults -- unused; always recomputed below.  Kept for backward
        compatibility (previously a mutable default, ``[]``).
    """
    (argnames, varargs, varkw, _, kwonlyargs, _, _) = func_argid_db[funcid]

    # To recover default values, parse the stored function source, grab the
    # AST nodes for the defaults, then unparse them back to source text.
    try:
        fn_ast = ast.parse(func_source_db[funcid].strip()).body[0]

        # override fname if we parsed a different one
        fname = fn_ast.name

        defaults = [unparse_ast(dn) for dn in fn_ast.args.defaults]

        if hasattr(fn_ast.args, 'kw_defaults'):
            kwonly_defaults = [unparse_ast(dn) for dn in fn_ast.args.kw_defaults]
        else:
            kwonly_defaults = []
    except Exception:
        # Unparseable source (or unparse failure): fall back to a
        # signature without default values.
        defaults, kwonly_defaults = [], []
    finally:
        # Left-pad with None so defaults line up with the trailing args.
        defaults = ([None] * (len(argnames) - len(defaults))) + defaults
        kwonly_defaults = ([None] * (len(kwonlyargs) - len(kwonly_defaults))) + kwonly_defaults

    # Each entry is a (prefix, name, default) triple; prefix is '*' or '**'.
    args = [('', arg, default) for (arg, default) in zip(argnames, defaults)]

    if varargs:
        args += [('*', varargs, None)]
    elif len(kwonlyargs) > 0:
        # Keyword-only args without a *varargs need a bare '*' separator.
        args += [('*', '', None)]
    if len(kwonlyargs) > 0:
        args += [('', arg, default) for (arg, default) in zip(kwonlyargs, kwonly_defaults)]
    if varkw:
        args += [('**', varkw, None)]

    argstrs = []
    for i, (prefix, arg, default) in enumerate(args):
        argstr = prefix + arg

        # Omit type of self argument.
        if (funcid, arg) in func_arg_db and not (i == 0 and arg == 'self'):
            argstr += ': %s' % func_arg_db[(funcid, arg)]

        if default:
            argstr += ' = %s' % default

        argstrs.append(argstr)

    ret = str(func_return_db.get(funcid, Unknown()))

    sig = 'def %s(%s) -> %s' % (fname, ', '.join(argstrs), ret)
    if not pretty or len(sig) <= PREFERRED_LINE_LENGTH or not args:
        return indent + sig

    else:
        # Format into multiple lines to conserve horizontal space.
        first = indent + 'def %s(' % fname
        # `first` already contains `indent`, so `extra_indent` is measured
        # from the start of the line; do not add `indent` again.  (The old
        # code prefixed `indent` a second time, mis-indenting the output.)
        extra_indent = first.index('(') + 1

        decl = first
        decl += (',\n' + ' ' * extra_indent).join(argstrs)
        decl += ')\n%s -> %s' % (' ' * (extra_indent - 4), ret)
        return decl
| |
| |
def annotate_file(path):
    """Return the source of the file at `path` with inferred signatures
    spliced in over the original ``def`` headers.

    The file on disk is not modified; the annotated source is returned as
    a string.  A ``from typing import ...`` line is prepended as well.
    Only functions whose funcid records `path` as their source file are
    rewritten.
    """
    # 5 is token.INDENT in the stdlib token/tokenize modules.
    INDENT_TOKEN = 5

    with open(path, 'r') as targetfile:
        source = targetfile.read()

    # Offset of the start of each line, so the tokenizer's (line, col)
    # positions can be mapped to absolute offsets into `source`.
    line_offsets = []
    source_length = 0
    for line in source.split('\n'):
        line_offsets.append(source_length)
        source_length = source_length + len(line) + 1

    funcids = set(funcid for funcid, arg in func_arg_db)

    # list of (oldstart, oldend, replacement)
    replacements = [] # type: List[Tuple[int, int, str]]

    for funcid in funcids:
        class_name, name, sourcefile, def_start_line = funcid
        if sourcefile != path:
            continue

        func_source = func_source_db[funcid]
        tokens = list(tokenize.generate_tokens(StringIO(func_source).readline))
        assert len(tokens) > 0

        # we're making the assumption that the def at least gets to start on
        # its own line, which is fine for non-lambdas

        # A leading INDENT token carries the function's indentation string.
        if tokens[0][0] == INDENT_TOKEN:
            indent = tokens[0][1]
            del tokens[0]
        else:
            indent = ''

        # Find the first indent, which should be between the end of the def
        # and before the start of the body. Then find the preceding colon,
        # which should be at the end of the def.

        for indent_loc in range(len(tokens)):
            if tokens[indent_loc][0] == INDENT_TOKEN:
                function_is_one_line = False
                break
        else:
            # for/else: no INDENT found, so the body is on the def line.
            function_is_one_line = True

        if function_is_one_line:
            # we're also making the assumption that the def has an indent on the
            # line following the signature, which is true almost all of the time.
            # If this is not the case, we should just leave a comment above the
            # function, although I might not have time to do that now.
            continue

        # Scan backwards from the INDENT for the ':' that ends the def.
        for def_end_loc in range(indent_loc, -1, -1):
            if tokens[def_end_loc][1] == ':':
                break

        assert def_end_loc > 0

        def_end_line, def_end_col = tokens[def_end_loc][2]
        def_end_line -= 1 # the tokenizer apparently 1-indexes lines
        def_end_line += def_start_line

        def_start_offset = line_offsets[def_start_line]
        def_end_offset = line_offsets[def_end_line] + def_end_col

        annotated_def = format_sig(funcid, name, indent, True)

        replacements.append((def_start_offset, def_end_offset, annotated_def))

    # ideally, we'd put this after the docstring
    replacements.append((0, 0, "from typing import List, Dict, Set, Tuple, Callable, Pattern, Match, Union, Optional\n"))

    # absurdly inefficient algorithm: replace with O(n) writer.
    # Applying replacements back-to-front keeps earlier offsets valid.

    for (start, end, replacement) in sorted(replacements, key=lambda r: r[0], reverse=True):
        source = source[0:start] + replacement + source[end:]

    return source
| |
| |
def dump():
    """Print every inferred type (if any) and reset the collected state."""
    formatted = format_state(pretty=True)
    if formatted:
        print()
        print('INFERRED TYPES:')
        print(formatted)
    reset()
| |
| |
def dump_at_exit():
    """Schedule a final dump() of inferred types at interpreter exit."""
    from atexit import register
    register(dump)
| |
| |
def get_defining_file(obj):
    """Return the absolute path of the source file that defines `obj`.

    Compiled-file paths (``.pyc``) are mapped back to the corresponding
    ``.py`` source.  Returns None when no file can be determined (e.g.
    builtins, for which inspect.getfile raises TypeError).
    """
    try:
        path = os.path.abspath(inspect.getfile(obj))
        if path.endswith('.pyc'):
            path = path[:-1]
        return path
    except Exception:
        # Was a bare `except:`, which also swallowed KeyboardInterrupt
        # and SystemExit.
        return None
| |
| |
def infer_var(name, value):
    """Record the observed type of a (location-less) variable `name`."""
    update_var_db((None, name), value)
| |
| |
def infer_attrs(x):
    """Record the types of all data attributes of object `x`.

    Both instance attributes and class-level attributes are inspected;
    methods and dunder bookkeeping attributes are skipped.  Entries are
    stored under keys of the form ``ClassName.attr``.
    """
    if hasattr(x, '__class__'):
        t = x.__class__
    else:
        t = type(x)
    cls = t.__name__
    typedict = t.__dict__
    # Renamed loop variable (was `dict`, shadowing the builtin).
    for attrs in x.__dict__, typedict:
        for attr, value in attrs.items():
            if attr in ('__dict__', '__doc__', '__module__', '__weakref__'):
                continue
            if type(value) is type(infer_attrs) and attrs is typedict:
                # Skip methods.
                continue
            key = (None, '%s.%s' % (cls, attr))
            update_var_db(key, value)
| |
| |
def infer_method_signature(class_name):
    """Return a decorator that wraps methods of `class_name` for inference."""
    return lambda func: infer_signature(func, class_name)
| |
| |
def infer_signature(func, class_name=''):
    """Decorator that infers the signature of a function.

    Returns a wrapper that records argument types in func_arg_db and the
    return type in func_return_db on every successful call.  Returns
    `func` unchanged (unwrapped) when its source or argspec cannot be
    obtained.
    """

    # infer_method_signature should be idempotent
    if hasattr(func, '__is_inferring_sig'):
        return func

    # Never wrap this module's own functions (see the reentrancy note at
    # the top of the file).
    assert func.__module__ != infer_method_signature.__module__

    try:
        funcfile = get_defining_file(func)
        funcsource, sourceline = inspect.getsourcelines(func)
        sourceline -= 1 # getsourcelines is apparently 1-indexed
    except:
        return func

    funcid = (class_name, func.__name__, funcfile, sourceline)
    func_source_db[funcid] = ''.join(funcsource)

    try:
        func_argid_db[funcid] = getfullargspec(func)
        # Names of the *args / **kwargs parameters, if any.
        vargs_name, kwargs_name = func_argid_db[funcid][1], func_argid_db[funcid][2]
    except TypeError:
        # Not supported.
        return func

    def wrapper(*args, **kwargs):
        global is_performing_inference
        # If we're already doing inference, we should be in our own code, not code we're checking.
        # Not doing this check sometimes results in infinite recursion.

        if is_performing_inference:
            return func(*args, **kwargs)

        expecting_type_error, got_type_error, got_exception = False, False, False

        is_performing_inference = True
        try:
            callargs = getcallargs(func, *args, **kwargs)

            # we have to handle *args and **kwargs separately
            if vargs_name:
                va = callargs.pop(vargs_name)
            if kwargs_name:
                kw = callargs.pop(kwargs_name)

            arg_db = {arg: infer_value_type(value) for arg, value in callargs.items()}

            # *args and **kwargs need to merge the types of all their values
            if vargs_name:
                arg_db[vargs_name] = union_many_types(*[infer_value_type(v) for v in va])
            if kwargs_name:
                arg_db[kwargs_name] = union_many_types(*[infer_value_type(v) for v in kw.values()])

        except TypeError:
            # getcallargs raised, so the real call below should raise a
            # TypeError too (the arguments don't match the signature).
            got_exception = expecting_type_error = True
        except:
            got_exception = True
        finally:
            is_performing_inference = False

        try:
            ret = func(*args, **kwargs)
        except TypeError:
            got_type_error = got_exception = True
            raise
        except:
            got_exception = True
            raise
        finally:
            # Only commit argument types when the call succeeded, so that
            # mistyped calls don't pollute the database.
            if not got_exception:
                assert not expecting_type_error

                # if we didn't get a TypeError, update the actual database
                for arg, t in arg_db.items():
                    update_db(func_arg_db, (funcid, arg), t)

        # if we got an exception, we don't have a ret
        if not got_exception:
            is_performing_inference = True
            try:
                type = infer_value_type(ret)
                update_db(func_return_db, funcid, type)
            except:
                pass
            finally:
                is_performing_inference = False

        return ret

    if hasattr(func, '__name__'):
        wrapper.__name__ = func.__name__
    wrapper.__is_inferring_sig = True
    return wrapper
| |
| |
def infer_class(cls):
    """Class decorator for inferring signatures of all methods of the class."""
    function_type = type(infer_class)
    for attr, value in cls.__dict__.items():
        if type(value) is function_type:
            setattr(cls, attr, infer_method_signature(cls.__name__)(value))
    return cls
| |
| |
def infer_module(namespace):
    """Wrap every function and class found in a module (or plain dict)."""
    if hasattr(namespace, '__dict__'):
        namespace = namespace.__dict__
    # Iterate over a snapshot, since we mutate the namespace as we go.
    for name, value in list(namespace.items()):
        if inspect.isclass(value):
            namespace[name] = infer_class(value)
        elif inspect.isfunction(value):
            namespace[name] = infer_signature(value)
| |
| |
def update_var_db(key, value):
    """Infer `value`'s type and merge it into the variable database."""
    update_db(var_db, key, infer_value_type(value))
| |
| |
def update_db(db, key, type):
    """Record `type` under `key`, merging with any previously seen type."""
    db[key] = type if key not in db else combine_types(db[key], type)
| |
| |
def merge_db(db, other):
    """Merge the entries of `other` into `db`, combining types on collision."""
    assert id(db) != id(other)
    for key, val in other.items():
        if key in db:
            db[key] = combine_types(db[key], val)
        else:
            db[key] = val
| |
| |
def infer_value_type(value, depth=0):
    """Return a type object describing the runtime `value`.

    Containers are inspected recursively so element types are captured;
    `depth` bounds that recursion.  Returns None for None (so unions can
    later render as Optional), and Any() when nothing useful is known.
    """
    # Prevent infinite recursion
    if depth > 5:
        return Unknown()
    depth += 1

    if value is None:
        return None
    elif isinstance(value, list):
        return Generic('List', [infer_value_types(value, depth)])
    elif isinstance(value, dict):
        keytype = infer_value_types(value.keys(), depth)
        valuetype = infer_value_types(value.values(), depth)
        return Generic('Dict', (keytype, valuetype))
    elif isinstance(value, tuple):
        if len(value) <= MAX_INFERRED_TUPLE_LENGTH:
            # Short tuples get an exact per-item Tuple type.
            return Tuple(infer_value_type(item, depth)
                         for item in value)
        else:
            # Long tuples are summarized as a homogeneous sequence.
            return Generic('TupleSequence', [infer_value_types(value, depth)])
    elif isinstance(value, set):
        return Generic('Set', [infer_value_types(value, depth)])
    elif isinstance(value, types.MethodType) or isinstance(value, types.FunctionType):
        return Instance(Callable)
    else:
        # Walk the MRO for the first class not defined in an ignored file;
        # `object` (or old-style InstanceType on Python 2) degrades to Any.
        for t in type(value).mro():
            if get_defining_file(t) in ignore_files:
                continue
            elif t is object:
                return Any()
            elif hasattr(types, 'InstanceType') and t is types.InstanceType:
                return Any()
            else:
                return Instance(t)
        else:
            # for/else: every class in the MRO lived in an ignored file.
            return Any()
| |
| |
def infer_value_types(values, depth=0):
    """Infer a single type for an iterable of values.

    >>> infer_value_types((1, 'x'))
    Union(int, str)
    >>> infer_value_types([])
    Unknown
    """
    result = Unknown()
    for item in sample(values):
        result = combine_types(result, infer_value_type(item, depth))
    return result
| |
| |
def sample(values):
    """Return the values that will be inspected -- currently all of them."""
    # TODO only return a sample of values
    materialized = list(values)
    return materialized
| |
| |
def union_many_types(*types):
    """Fold any number of types into a single (possibly Union) type."""
    combined = Unknown()
    for candidate in types:
        combined = combine_types(combined, candidate)
    return combined
| |
| |
def combine_types(x, y):
    """Perform a union of two types.

    >>> combine_types(Instance(int), None)
    Optional[int]
    """
    # Unknown is the identity element; Any absorbs everything.
    if isinstance(x, Unknown):
        return y
    elif isinstance(y, Unknown):
        return x
    elif isinstance(x, Any):
        return x
    elif isinstance(y, Any):
        return y
    elif isinstance(x, Union):
        return combine_either(x, y)
    elif isinstance(y, Union):
        return combine_either(y, x)
    elif x == y:
        return x
    else:
        return simplify_either([x], [y])
| |
| |
def combine_either(either, x):
    """Merge a type (or another Union) into an existing Union."""
    other_types = x.types if isinstance(x, Union) else [x]
    return simplify_either(either.types, other_types)
| |
| |
def simplify_either(x, y):
    """Union the type lists `x` and `y`, merging compatible members.

    Generics with the same name, Tuples of the same arity, numeric types
    and subclass-related Instances are collapsed into single entries
    rather than appearing side by side in the resulting Union.
    """
    numerics = [Instance(int), Instance(float), Instance(complex)]

    # TODO this is O(n**2); use an O(n) algorithm instead
    result = list(x)
    for type in y:
        if isinstance(type, Generic):
            # Merge with an existing Generic of the same name by unioning
            # its type arguments pairwise.
            for i, rt in enumerate(result):
                if isinstance(rt, Generic) and type.typename == rt.typename:
                    result[i] = Generic(rt.typename,
                                        (combine_types(t, s)
                                         for t, s in zip(type.args, rt.args)))
                    break
            else:
                result.append(type)
        elif isinstance(type, Tuple):
            # Merge with an existing Tuple of the same length, item by item.
            for i, rt in enumerate(result):
                if isinstance(rt, Tuple) and len(type) == len(rt):
                    result[i] = Tuple(combine_types(t, s)
                                      for t, s in zip(type.itemtypes,
                                                      rt.itemtypes))
                    break
            else:
                result.append(type)
        elif type in numerics:
            # Numeric tower: keep only the widest of int/float/complex.
            for i, rt in enumerate(result):
                if rt in numerics:
                    result[i] = numerics[max(numerics.index(rt), numerics.index(type))]
                    break
            else:
                result.append(type)
        elif isinstance(type, Instance):
            for i, rt in enumerate(result):
                if isinstance(rt, Instance):
                    # Union[A, SubclassOfA] -> A
                    # Union[A, A] -> A, because issubclass(A, A) == True,
                    if issubclass(type.typeobj, rt.typeobj):
                        break
                    elif issubclass(rt.typeobj, type.typeobj):
                        result[i] = type
                        break
            else:
                # for/else: no related Instance found; keep both.
                result.append(type)
        elif type not in result:
            result.append(type)

    if len(result) > 1:
        return Union(result)
    else:
        return result[0]
| |
| |
class TypeBase(object):
    """Abstract base class of all type objects.

    Type objects use isinstance tests liberally -- they don't support duck
    typing well.
    """

    def __eq__(self, other):
        # Equal iff exactly the same class and all instance attributes match.
        if type(other) is not type(self):
            return False
        return all(getattr(other, name) == getattr(self, name)
                   for name in self.__dict__)

    def __ne__(self, other):
        return not self == other

    def __repr__(self):
        return str(self)
| |
| |
class Instance(TypeBase):
    """A concrete (non-generic) type, wrapping a class object."""

    def __init__(self, typeobj):
        # Type objects themselves must never be wrapped in Instance.
        assert not inspect.isclass(typeobj) or not issubclass(typeobj, TypeBase)
        self.typeobj = typeobj

    def __str__(self):
        # Cheat on regular expression objects, which have weird class names,
        # so the output matches the spelling used by typing.py.
        if self.typeobj == Pattern:
            return "Pattern"
        if self.typeobj == Match:
            return "Match"
        return self.typeobj.__name__

    def __repr__(self):
        return 'Instance(%s)' % self
| |
| |
class Generic(TypeBase):
    """A parameterized type such as List[int] or Dict[str, int]."""

    def __init__(self, typename, args):
        self.typename = typename
        self.args = tuple(args)

    def __str__(self):
        arglist = ', '.join(str(a) for a in self.args)
        return '%s[%s]' % (self.typename, arglist)
| |
| |
class Tuple(TypeBase):
    """A fixed-length tuple type with an individual type per item."""

    def __init__(self, itemtypes):
        self.itemtypes = tuple(itemtypes)

    def __len__(self):
        return len(self.itemtypes)

    def __str__(self):
        inner = ', '.join(str(t) for t in self.itemtypes)
        return 'Tuple[%s]' % inner
| |
| |
class Union(TypeBase):
    """A union of two or more alternative types."""

    def __init__(self, types):
        assert len(types) > 1
        self.types = tuple(types)

    def __eq__(self, other):
        if type(other) is not Union:
            return False
        # TODO this is O(n**2); use an O(n) algorithm instead
        # Order-insensitive: each side's members must all appear in the other.
        for t in self.types:
            if t not in other.types:
                return False
        for t in other.types:
            if t not in self.types:
                return False
        return True

    def __str__(self):
        types = list(self.types)
        if str != bytes: # on Python 2 str == bytes, so skip this there
            if Instance(bytes) in types and Instance(str) in types:
                # Collapse Union[bytes, str] -> AnyStr as late as possible so
                # we avoid corner cases like subclasses of bytes or str.
                types.remove(Instance(bytes))
                types.remove(Instance(str))
                types.append(Instance(AnyStr))
        if len(types) == 1:
            return str(types[0])
        elif len(types) == 2 and None in types:
            # None plus exactly one other type prints as Optional[...].
            type = [t for t in types if t is not None][0]
            return 'Optional[%s]' % type
        else:
            return 'Union[%s]' % (', '.join(sorted(str(t) for t in types)))
| |
| |
class Unknown(TypeBase):
    """Type of an empty collection or a value not yet observed."""

    def __str__(self):
        return 'Unknown'

    def __repr__(self):
        return 'Unknown()'
| |
| |
class Any(TypeBase):
    """Type about which nothing more specific can be said."""

    def __str__(self):
        return 'Any'

    def __repr__(self):
        return 'Any()'
| |
| |
# Marker classes used only for their names when formatting types.
class AnyStr(object): pass
class Callable(object): pass
import re
# Concrete classes of compiled patterns and match objects, recovered from
# instances (Instance.__str__ special-cases them because their real class
# names are implementation details).
Pattern = type(re.compile(u''))
Match = type(re.match(u'', u''))