From 8f1732b420ff78725832cdad54ded57abc33e936 Mon Sep 17 00:00:00 2001 From: elazar Date: Wed, 5 Apr 2017 03:30:18 +0300 Subject: [PATCH 01/18] add specialtype module; pass typecheck --- mypy/semanal.py | 955 ++------------------------------------------ mypy/specialtype.py | 951 +++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 985 insertions(+), 921 deletions(-) create mode 100644 mypy/specialtype.py diff --git a/mypy/semanal.py b/mypy/semanal.py index 8f285c135242c..c9770ade12e0a 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -43,7 +43,6 @@ traverse the entire AST. """ -from collections import OrderedDict from typing import ( List, Dict, Set, Tuple, cast, TypeVar, Union, Optional, Callable ) @@ -54,18 +53,17 @@ ImportFrom, ImportAll, Block, LDEF, NameExpr, MemberExpr, IndexExpr, TupleExpr, ListExpr, ExpressionStmt, ReturnStmt, RaiseStmt, AssertStmt, OperatorAssignmentStmt, WhileStmt, - ForStmt, BreakStmt, ContinueStmt, IfStmt, TryStmt, WithStmt, DelStmt, PassStmt, + ForStmt, BreakStmt, ContinueStmt, IfStmt, TryStmt, WithStmt, DelStmt, GlobalDecl, SuperExpr, DictExpr, CallExpr, RefExpr, OpExpr, UnaryExpr, SliceExpr, CastExpr, RevealTypeExpr, TypeApplication, Context, SymbolTable, SymbolTableNode, BOUND_TVAR, UNBOUND_TVAR, ListComprehension, GeneratorExpr, - LambdaExpr, MDEF, FuncBase, Decorator, SetExpr, TypeVarExpr, NewTypeExpr, + LambdaExpr, MDEF, Decorator, SetExpr, TypeVarExpr, StrExpr, BytesExpr, PrintStmt, ConditionalExpr, PromoteExpr, - ComparisonExpr, StarExpr, ARG_POS, ARG_NAMED, ARG_NAMED_OPT, MroError, type_aliases, - YieldFromExpr, NamedTupleExpr, TypedDictExpr, NonlocalDecl, SymbolNode, + ComparisonExpr, StarExpr, ARG_POS, ARG_NAMED, MroError, type_aliases, + YieldFromExpr, NonlocalDecl, SymbolNode, SetComprehension, DictionaryComprehension, TYPE_ALIAS, TypeAliasExpr, - YieldExpr, ExecStmt, Argument, BackquoteExpr, ImportBase, AwaitExpr, - IntExpr, FloatExpr, UnicodeExpr, EllipsisExpr, TempNode, EnumCallExpr, - COVARIANT, CONTRAVARIANT, INVARIANT, UNBOUND_IMPORTED, LITERAL_YES, ARG_OPT, nongen_builtins, + YieldExpr, ExecStmt, BackquoteExpr, ImportBase, AwaitExpr, + IntExpr, FloatExpr, UnicodeExpr, UNBOUND_IMPORTED, LITERAL_YES, nongen_builtins, collections_type_aliases, get_member_expr_fullname, ) from mypy.typevars import has_no_typevars, fill_typevars @@ -75,8 +73,8 @@ from mypy.messages import CANNOT_ASSIGN_TO_TYPE from mypy.types import ( NoneTyp, CallableType, Overloaded, Instance, Type, TypeVarType, AnyType, - FunctionLike, UnboundType, TypeList, TypeVarDef, TypeType, - TupleType, UnionType, StarType, EllipsisType, function_type, TypedDictType, + FunctionLike, UnboundType, TypeList, TypeVarDef, + TupleType, UnionType, StarType, EllipsisType, function_type, ) from mypy.nodes import implicit_module_attrs from mypy.typeanal import ( @@ -85,7 +83,7 @@ from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.sametypes import is_same_type from mypy.options import Options -from mypy import join +from mypy.specialtype import Special T = TypeVar('T') @@ -206,6 +204,8 @@ class SemanticAnalyzer(NodeVisitor): imports = None # type: Set[str] # Imported modules (during phase 2 analysis) errors = None # type: Errors # Keeps track of generated errors + specialtype = None # type: Special + def __init__(self, modules: Dict[str, MypyFile], missing_modules: Set[str], @@ -232,6 +232,7 @@ def __init__(self, self.postpone_nested_functions_stack = [FUNCTION_BOTH_PHASES] self.postponed_functions_stack = [] self.all_exports = set() # type: Set[str] + self.specialtype = Special(self) def visit_file(self, file_node: MypyFile, fnam: str, options: Options) -> None: self.options = options @@ -331,8 +332,8 @@ def visit_func_def(self, defn: FuncDef) -> None: # A coroutine defined as `async def foo(...) -> T: ...` # has external return type `Awaitable[T]`. defn.type = defn.type.copy_modified( - ret_type = self.named_type_or_none('typing.Awaitable', - [defn.type.ret_type])) + ret_type=self.named_type_or_none('typing.Awaitable', + [defn.type.ret_type])) self.errors.pop_function() def prepare_method_signature(self, func: FuncDef) -> None: @@ -351,7 +352,8 @@ def prepare_method_signature(self, func: FuncDef) -> None: leading_type = fill_typevars(self.type) func.type = replace_implicit_first_type(functype, leading_type) - def set_original_def(self, previous: Node, new: FuncDef) -> bool: + @staticmethod + def set_original_def(previous: Node, new: FuncDef) -> bool: """If 'new' conditionally redefine 'previous', set 'previous' as original We reject straight redefinitions of functions, as they are usually @@ -653,9 +655,9 @@ def check_function_signature(self, fdef: FuncItem) -> None: def visit_class_def(self, defn: ClassDef) -> None: self.clean_up_bases_and_infer_type_variables(defn) - if self.analyze_typeddict_classdef(defn): + if self.specialtype.analyze_typeddict_classdef(defn): return - if self.analyze_namedtuple_classdef(defn): + if self.specialtype.analyze_namedtuple_classdef(defn): # just analyze the class body so we catch type errors in default values self.enter_class(defn) defn.defs.accept(self) @@ -719,7 +721,8 @@ def unbind_class_type_vars(self) -> None: def analyze_class_decorator(self, defn: ClassDef, decorator: Expression) -> None: decorator.accept(self) - def calculate_abstract_status(self, typ: TypeInfo) -> None: + @staticmethod + def calculate_abstract_status(typ: TypeInfo) -> None: """Calculate abstract status of a class. Set is_abstract of the type to True if the type has an unimplemented @@ -786,7 +789,7 @@ def clean_up_bases_and_infer_type_variables(self, defn: ClassDef) -> None: except TypeTranslationError: # This error will be caught later. continue - tvars = self.analyze_typevar_declaration(base) + tvars = self.specialtype.analyze_typevar_declaration(base) if tvars is not None: if declared_tvars: self.fail('Duplicate Generic in bases', defn) @@ -814,25 +817,6 @@ def clean_up_bases_and_infer_type_variables(self, defn: ClassDef) -> None: for i in reversed(removed): del defn.base_type_exprs[i] - def analyze_typevar_declaration(self, t: Type) -> Optional[List[Tuple[str, TypeVarExpr]]]: - if not isinstance(t, UnboundType): - return None - unbound = t - sym = self.lookup_qualified(unbound.name, unbound) - if sym is None or sym.node is None: - return None - if sym.node.fullname() == 'typing.Generic': - tvars = [] # type: List[Tuple[str, TypeVarExpr]] - for arg in unbound.args: - tvar = self.analyze_unbound_tvar(arg) - if tvar: - tvars.append(tvar) - else: - self.fail('Free type variable expected in %s[...]' % - sym.node.name(), t) - return tvars - return None - def analyze_unbound_tvar(self, t: Type) -> Tuple[str, TypeVarExpr]: if not isinstance(t, UnboundType): return None @@ -872,7 +856,8 @@ def get_tvars(self, tp: Type) -> List[Tuple[str, TypeVarExpr]]: tvars.extend(self.get_tvars(arg)) return self.remove_dups(tvars) - def remove_dups(self, tvars: List[T]) -> List[T]: + @staticmethod + def remove_dups(tvars: List[T]) -> List[T]: # Get unique elements in order of appearance all_tvars = set(tvars) new_tvars = [] # type: List[T] @@ -882,63 +867,6 @@ def remove_dups(self, tvars: List[T]) -> List[T]: all_tvars.remove(t) return new_tvars - def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool: - # special case for NamedTuple - for base_expr in defn.base_type_exprs: - if isinstance(base_expr, RefExpr): - base_expr.accept(self) - if base_expr.fullname == 'typing.NamedTuple': - node = self.lookup(defn.name, defn) - if node is not None: - node.kind = GDEF # TODO in process_namedtuple_definition also applies here - items, types, default_items = self.check_namedtuple_classdef(defn) - node.node = self.build_namedtuple_typeinfo( - defn.name, items, types, default_items) - return True - return False - - def check_namedtuple_classdef( - self, defn: ClassDef) -> Tuple[List[str], List[Type], Dict[str, Expression]]: - NAMEDTUP_CLASS_ERROR = ('Invalid statement in NamedTuple definition; ' - 'expected "field_name: field_type"') - if self.options.python_version < (3, 6): - self.fail('NamedTuple class syntax is only supported in Python 3.6', defn) - return [], [], {} - if len(defn.base_type_exprs) > 1: - self.fail('NamedTuple should be a single base', defn) - items = [] # type: List[str] - types = [] # type: List[Type] - default_items = {} # type: Dict[str, Expression] - for stmt in defn.defs.body: - if not isinstance(stmt, AssignmentStmt): - # Still allow pass or ... (for empty namedtuples). - if (not isinstance(stmt, PassStmt) and - not (isinstance(stmt, ExpressionStmt) and - isinstance(stmt.expr, EllipsisExpr))): - self.fail(NAMEDTUP_CLASS_ERROR, stmt) - elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr): - # An assignment, but an invalid one. - self.fail(NAMEDTUP_CLASS_ERROR, stmt) - else: - # Append name and type in this case... - name = stmt.lvalues[0].name - items.append(name) - types.append(AnyType() if stmt.type is None else self.anal_type(stmt.type)) - # ...despite possible minor failures that allow further analyzis. - if name.startswith('_'): - self.fail('NamedTuple field name cannot start with an underscore: {}' - .format(name), stmt) - if stmt.type is None or hasattr(stmt, 'new_syntax') and not stmt.new_syntax: - self.fail(NAMEDTUP_CLASS_ERROR, stmt) - elif isinstance(stmt.rvalue, TempNode): - # x: int assigns rvalue to TempNode(AnyType()) - if default_items: - self.fail('Non-default NamedTuple fields cannot follow default fields', - stmt) - else: - default_items[name] = stmt.rvalue - return items, types, default_items - def setup_class_def_analysis(self, defn: ClassDef) -> None: """Prepare for the analysis of a class definition.""" if not defn.info: @@ -999,7 +927,7 @@ def analyze_base_classes(self, defn: ClassDef) -> None: info.fallback_to_any = True # Add 'object' as implicit base if there is no other base class. - if (not base_types and defn.fullname != 'builtins.object'): + if not base_types and defn.fullname != 'builtins.object': base_types.append(self.object_type()) info.bases = base_types @@ -1020,13 +948,10 @@ def analyze_base_classes(self, defn: ClassDef) -> None: def expr_to_analyzed_type(self, expr: Expression) -> Type: if isinstance(expr, CallExpr): expr.accept(self) - info = self.check_namedtuple(expr) - if info is None: - # Some form of namedtuple is the only valid type that looks like a call - # expression. This isn't a valid type. + typ = self.specialtype.analyze_callexpr_as_type(expr) + if typ is None: raise TypeTranslationError() - fallback = Instance(info, []) - return TupleType(info.tuple_type.items, fallback=fallback) + return typ typ = expr_to_unanalyzed_type(expr) return self.anal_type(typ) @@ -1048,7 +973,8 @@ def verify_base_classes(self, defn: ClassDef) -> bool: return False return True - def is_base_class(self, t: TypeInfo, s: TypeInfo) -> bool: + @staticmethod + def is_base_class(t: TypeInfo, s: TypeInfo) -> bool: """Determine if t is a base class of s (but do not use mro).""" # Search the base class graph for t, starting from s. worklist = [s] @@ -1156,98 +1082,6 @@ def bind_class_type_variables_in_symbol_table( nodes.append(node) return nodes - def is_typeddict(self, expr: Expression) -> bool: - return (isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo) and - expr.node.typeddict_type is not None) - - def analyze_typeddict_classdef(self, defn: ClassDef) -> bool: - # special case for TypedDict - possible = False - for base_expr in defn.base_type_exprs: - if isinstance(base_expr, RefExpr): - base_expr.accept(self) - if (base_expr.fullname == 'mypy_extensions.TypedDict' or - self.is_typeddict(base_expr)): - possible = True - if possible: - node = self.lookup(defn.name, defn) - if node is not None: - node.kind = GDEF # TODO in process_namedtuple_definition also applies here - if (len(defn.base_type_exprs) == 1 and - isinstance(defn.base_type_exprs[0], RefExpr) and - defn.base_type_exprs[0].fullname == 'mypy_extensions.TypedDict'): - # Building a new TypedDict - fields, types = self.check_typeddict_classdef(defn) - node.node = self.build_typeddict_typeinfo(defn.name, fields, types) - return True - # Extending/merging existing TypedDicts - if any(not isinstance(expr, RefExpr) or - expr.fullname != 'mypy_extensions.TypedDict' and - not self.is_typeddict(expr) for expr in defn.base_type_exprs): - self.fail("All bases of a new TypedDict must be TypedDict types", defn) - typeddict_bases = list(filter(self.is_typeddict, defn.base_type_exprs)) - newfields = [] # type: List[str] - newtypes = [] # type: List[Type] - tpdict = None # type: OrderedDict[str, Type] - for base in typeddict_bases: - assert isinstance(base, RefExpr) - assert isinstance(base.node, TypeInfo) - assert isinstance(base.node.typeddict_type, TypedDictType) - tpdict = base.node.typeddict_type.items - newdict = tpdict.copy() - for key in tpdict: - if key in newfields: - self.fail('Cannot overwrite TypedDict field "{}" while merging' - .format(key), defn) - newdict.pop(key) - newfields.extend(newdict.keys()) - newtypes.extend(newdict.values()) - fields, types = self.check_typeddict_classdef(defn, newfields) - newfields.extend(fields) - newtypes.extend(types) - node.node = self.build_typeddict_typeinfo(defn.name, newfields, newtypes) - return True - return False - - def check_typeddict_classdef(self, defn: ClassDef, - oldfields: List[str] = None) -> Tuple[List[str], List[Type]]: - TPDICT_CLASS_ERROR = ('Invalid statement in TypedDict definition; ' - 'expected "field_name: field_type"') - if self.options.python_version < (3, 6): - self.fail('TypedDict class syntax is only supported in Python 3.6', defn) - return [], [] - fields = [] # type: List[str] - types = [] # type: List[Type] - for stmt in defn.defs.body: - if not isinstance(stmt, AssignmentStmt): - # Still allow pass or ... (for empty TypedDict's). - if (not isinstance(stmt, PassStmt) and - not (isinstance(stmt, ExpressionStmt) and - isinstance(stmt.expr, EllipsisExpr))): - self.fail(TPDICT_CLASS_ERROR, stmt) - elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr): - # An assignment, but an invalid one. - self.fail(TPDICT_CLASS_ERROR, stmt) - else: - name = stmt.lvalues[0].name - if name in (oldfields or []): - self.fail('Cannot overwrite TypedDict field "{}" while extending' - .format(name), stmt) - continue - if name in fields: - self.fail('Duplicate TypedDict field "{}"'.format(name), stmt) - continue - # Append name and type in this case... - fields.append(name) - types.append(AnyType() if stmt.type is None else self.anal_type(stmt.type)) - # ...despite possible minor failures that allow further analyzis. - if stmt.type is None or hasattr(stmt, 'new_syntax') and not stmt.new_syntax: - self.fail(TPDICT_CLASS_ERROR, stmt) - elif not isinstance(stmt.rvalue, TempNode): - # x: int assigns rvalue to TempNode(AnyType()) - self.fail('Right hand side values are not supported in TypedDict', stmt) - return fields, types - def visit_import(self, i: Import) -> None: for id, as_id in i.ids: if as_id is not None: @@ -1344,8 +1178,8 @@ def visit_import_from(self, imp: ImportFrom) -> None: # Missing module. self.add_unknown_symbol(as_id or id, imp, is_import=True) - def process_import_over_existing_name(self, - imported_id: str, existing_symbol: SymbolTableNode, + @staticmethod + def process_import_over_existing_name(imported_id: str, existing_symbol: SymbolTableNode, module_symbol: SymbolTableNode, import_node: ImportBase) -> bool: if (existing_symbol.kind in (LDEF, GDEF, MDEF) and @@ -1494,11 +1328,7 @@ def visit_assignment_stmt(self, s: AssignmentStmt) -> None: for lvalue in s.lvalues: self.store_declared_types(lvalue, s.type) self.check_and_set_up_type_alias(s) - self.process_newtype_declaration(s) - self.process_typevar_declaration(s) - self.process_namedtuple_definition(s) - self.process_typeddict_definition(s) - self.process_enum_call(s) + self.specialtype.process_declaration(s) if (len(s.lvalues) == 1 and isinstance(s.lvalues[0], NameExpr) and s.lvalues[0].name == '__all__' and s.lvalues[0].kind == GDEF and @@ -1721,591 +1551,6 @@ def store_declared_types(self, lvalue: Lvalue, typ: Type) -> None: # This has been flagged elsewhere as an error, so just ignore here. pass - def process_newtype_declaration(self, s: AssignmentStmt) -> None: - """Check if s declares a NewType; if yes, store it in symbol table.""" - # Extract and check all information from newtype declaration - name, call = self.analyze_newtype_declaration(s) - if name is None or call is None: - return - - old_type = self.check_newtype_args(name, call, s) - call.analyzed = NewTypeExpr(name, old_type, line=call.line) - if old_type is None: - return - - # Create the corresponding class definition if the aliased type is subtypeable - if isinstance(old_type, TupleType): - newtype_class_info = self.build_newtype_typeinfo(name, old_type, old_type.fallback) - newtype_class_info.tuple_type = old_type - elif isinstance(old_type, Instance): - newtype_class_info = self.build_newtype_typeinfo(name, old_type, old_type) - else: - message = "Argument 2 to NewType(...) must be subclassable (got {})" - self.fail(message.format(old_type), s) - return - - # If so, add it to the symbol table. - node = self.lookup(name, s) - if node is None: - self.fail("Could not find {} in current namespace".format(name), s) - return - # TODO: why does NewType work in local scopes despite always being of kind GDEF? - node.kind = GDEF - call.analyzed.info = node.node = newtype_class_info - - def analyze_newtype_declaration(self, - s: AssignmentStmt) -> Tuple[Optional[str], Optional[CallExpr]]: - """Return the NewType call expression if `s` is a newtype declaration or None otherwise.""" - name, call = None, None - if (len(s.lvalues) == 1 - and isinstance(s.lvalues[0], NameExpr) - and isinstance(s.rvalue, CallExpr) - and isinstance(s.rvalue.callee, RefExpr) - and s.rvalue.callee.fullname == 'typing.NewType'): - lvalue = s.lvalues[0] - name = s.lvalues[0].name - if not lvalue.is_def: - if s.type: - self.fail("Cannot declare the type of a NewType declaration", s) - else: - self.fail("Cannot redefine '%s' as a NewType" % name, s) - - # This dummy NewTypeExpr marks the call as sufficiently analyzed; it will be - # overwritten later with a fully complete NewTypeExpr if there are no other - # errors with the NewType() call. - call = s.rvalue - - return name, call - - def check_newtype_args(self, name: str, call: CallExpr, context: Context) -> Optional[Type]: - has_failed = False - args, arg_kinds = call.args, call.arg_kinds - if len(args) != 2 or arg_kinds[0] != ARG_POS or arg_kinds[1] != ARG_POS: - self.fail("NewType(...) expects exactly two positional arguments", context) - return None - - # Check first argument - if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): - self.fail("Argument 1 to NewType(...) must be a string literal", context) - has_failed = True - elif args[0].value != name: - msg = "String argument 1 '{}' to NewType(...) does not match variable name '{}'" - self.fail(msg.format(args[0].value, name), context) - has_failed = True - - # Check second argument - try: - unanalyzed_type = expr_to_unanalyzed_type(args[1]) - except TypeTranslationError: - self.fail("Argument 2 to NewType(...) must be a valid type", context) - return None - old_type = self.anal_type(unanalyzed_type) - - if isinstance(old_type, Instance) and old_type.type.is_newtype: - self.fail("Argument 2 to NewType(...) cannot be another NewType", context) - has_failed = True - - return None if has_failed else old_type - - def build_newtype_typeinfo(self, name: str, old_type: Type, base_type: Instance) -> TypeInfo: - info = self.basic_new_typeinfo(name, base_type) - info.is_newtype = True - - # Add __init__ method - args = [Argument(Var('cls'), NoneTyp(), None, ARG_POS), - self.make_argument('item', old_type)] - signature = CallableType( - arg_types=[cast(Type, None), old_type], - arg_kinds=[arg.kind for arg in args], - arg_names=['self', 'item'], - ret_type=old_type, - fallback=self.named_type('__builtins__.function'), - name=name) - init_func = FuncDef('__init__', args, Block([]), typ=signature) - init_func.info = info - info.names['__init__'] = SymbolTableNode(MDEF, init_func) - - return info - - def process_typevar_declaration(self, s: AssignmentStmt) -> None: - """Check if s declares a TypeVar; it yes, store it in symbol table.""" - call = self.get_typevar_declaration(s) - if not call: - return - - lvalue = s.lvalues[0] - assert isinstance(lvalue, NameExpr) - name = lvalue.name - if not lvalue.is_def: - if s.type: - self.fail("Cannot declare the type of a type variable", s) - else: - self.fail("Cannot redefine '%s' as a type variable" % name, s) - return - - if not self.check_typevar_name(call, name, s): - return - - # Constraining types - n_values = call.arg_kinds[1:].count(ARG_POS) - values = self.analyze_types(call.args[1:1 + n_values]) - - res = self.process_typevar_parameters(call.args[1 + n_values:], - call.arg_names[1 + n_values:], - call.arg_kinds[1 + n_values:], - n_values, - s) - if res is None: - return - variance, upper_bound = res - - # Yes, it's a valid type variable definition! Add it to the symbol table. - node = self.lookup(name, s) - node.kind = UNBOUND_TVAR - TypeVar = TypeVarExpr(name, node.fullname, values, upper_bound, variance) - TypeVar.line = call.line - call.analyzed = TypeVar - node.node = TypeVar - - def check_typevar_name(self, call: CallExpr, name: str, context: Context) -> bool: - if len(call.args) < 1: - self.fail("Too few arguments for TypeVar()", context) - return False - if (not isinstance(call.args[0], (StrExpr, BytesExpr, UnicodeExpr)) - or not call.arg_kinds[0] == ARG_POS): - self.fail("TypeVar() expects a string literal as first argument", context) - return False - elif call.args[0].value != name: - msg = "String argument 1 '{}' to TypeVar(...) does not match variable name '{}'" - self.fail(msg.format(call.args[0].value, name), context) - return False - return True - - def get_typevar_declaration(self, s: AssignmentStmt) -> Optional[CallExpr]: - """Returns the TypeVar() call expression if `s` is a type var declaration - or None otherwise. - """ - if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): - return None - if not isinstance(s.rvalue, CallExpr): - return None - call = s.rvalue - callee = call.callee - if not isinstance(callee, RefExpr): - return None - if callee.fullname != 'typing.TypeVar': - return None - return call - - def process_typevar_parameters(self, args: List[Expression], - names: List[Optional[str]], - kinds: List[int], - num_values: int, - context: Context) -> Optional[Tuple[int, Type]]: - has_values = (num_values > 0) - covariant = False - contravariant = False - upper_bound = self.object_type() # type: Type - for param_value, param_name, param_kind in zip(args, names, kinds): - if not param_kind == ARG_NAMED: - self.fail("Unexpected argument to TypeVar()", context) - return None - if param_name == 'covariant': - if isinstance(param_value, NameExpr): - if param_value.name == 'True': - covariant = True - else: - self.fail("TypeVar 'covariant' may only be 'True'", context) - return None - else: - self.fail("TypeVar 'covariant' may only be 'True'", context) - return None - elif param_name == 'contravariant': - if isinstance(param_value, NameExpr): - if param_value.name == 'True': - contravariant = True - else: - self.fail("TypeVar 'contravariant' may only be 'True'", context) - return None - else: - self.fail("TypeVar 'contravariant' may only be 'True'", context) - return None - elif param_name == 'bound': - if has_values: - self.fail("TypeVar cannot have both values and an upper bound", context) - return None - try: - upper_bound = self.expr_to_analyzed_type(param_value) - except TypeTranslationError: - self.fail("TypeVar 'bound' must be a type", param_value) - return None - elif param_name == 'values': - # Probably using obsolete syntax with values=(...). Explain the current syntax. - self.fail("TypeVar 'values' argument not supported", context) - self.fail("Use TypeVar('T', t, ...) instead of TypeVar('T', values=(t, ...))", - context) - return None - else: - self.fail("Unexpected argument to TypeVar(): {}".format(param_name), context) - return None - - if covariant and contravariant: - self.fail("TypeVar cannot be both covariant and contravariant", context) - return None - elif num_values == 1: - self.fail("TypeVar cannot have only a single constraint", context) - return None - elif covariant: - variance = COVARIANT - elif contravariant: - variance = CONTRAVARIANT - else: - variance = INVARIANT - return (variance, upper_bound) - - def process_namedtuple_definition(self, s: AssignmentStmt) -> None: - """Check if s defines a namedtuple; if yes, store the definition in symbol table.""" - if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): - return - lvalue = s.lvalues[0] - name = lvalue.name - named_tuple = self.check_namedtuple(s.rvalue, name) - if named_tuple is None: - return - # Yes, it's a valid namedtuple definition. Add it to the symbol table. - node = self.lookup(name, s) - node.kind = GDEF # TODO locally defined namedtuple - node.node = named_tuple - - def check_namedtuple(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: - """Check if a call defines a namedtuple. - - The optional var_name argument is the name of the variable to - which this is assigned, if any. - - If it does, return the corresponding TypeInfo. Return None otherwise. - - If the definition is invalid but looks like a namedtuple, - report errors but return (some) TypeInfo. - """ - if not isinstance(node, CallExpr): - return None - call = node - callee = call.callee - if not isinstance(callee, RefExpr): - return None - fullname = callee.fullname - if fullname not in ('collections.namedtuple', 'typing.NamedTuple'): - return None - items, types, ok = self.parse_namedtuple_args(call, fullname) - if not ok: - # Error. Construct dummy return value. - return self.build_namedtuple_typeinfo('namedtuple', [], [], {}) - name = cast(StrExpr, call.args[0]).value - if name != var_name or self.is_func_scope(): - # Give it a unique name derived from the line number. - name += '@' + str(call.line) - info = self.build_namedtuple_typeinfo(name, items, types, {}) - # Store it as a global just in case it would remain anonymous. - # (Or in the nearest class if there is one.) - stnode = SymbolTableNode(GDEF, info, self.cur_mod_id) - if self.type: - self.type.names[name] = stnode - else: - self.globals[name] = stnode - call.analyzed = NamedTupleExpr(info) - call.analyzed.set_line(call.line, call.column) - return info - - def parse_namedtuple_args(self, call: CallExpr, - fullname: str) -> Tuple[List[str], List[Type], bool]: - # TODO: Share code with check_argument_count in checkexpr.py? - args = call.args - if len(args) < 2: - return self.fail_namedtuple_arg("Too few arguments for namedtuple()", call) - if len(args) > 2: - # FIX incorrect. There are two additional parameters - return self.fail_namedtuple_arg("Too many arguments for namedtuple()", call) - if call.arg_kinds != [ARG_POS, ARG_POS]: - return self.fail_namedtuple_arg("Unexpected arguments to namedtuple()", call) - if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): - return self.fail_namedtuple_arg( - "namedtuple() expects a string literal as the first argument", call) - types = [] # type: List[Type] - ok = True - if not isinstance(args[1], (ListExpr, TupleExpr)): - if (fullname == 'collections.namedtuple' - and isinstance(args[1], (StrExpr, BytesExpr, UnicodeExpr))): - str_expr = cast(StrExpr, args[1]) - items = str_expr.value.replace(',', ' ').split() - else: - return self.fail_namedtuple_arg( - "List or tuple literal expected as the second argument to namedtuple()", call) - else: - listexpr = args[1] - if fullname == 'collections.namedtuple': - # The fields argument contains just names, with implicit Any types. - if any(not isinstance(item, (StrExpr, BytesExpr, UnicodeExpr)) - for item in listexpr.items): - return self.fail_namedtuple_arg("String literal expected as namedtuple() item", - call) - items = [cast(StrExpr, item).value for item in listexpr.items] - else: - # The fields argument contains (name, type) tuples. - items, types, ok = self.parse_namedtuple_fields_with_types(listexpr.items, call) - if not types: - types = [AnyType() for _ in items] - underscore = [item for item in items if item.startswith('_')] - if underscore: - self.fail("namedtuple() field names cannot start with an underscore: " - + ', '.join(underscore), call) - return items, types, ok - - def parse_namedtuple_fields_with_types(self, nodes: List[Expression], - context: Context) -> Tuple[List[str], List[Type], bool]: - items = [] # type: List[str] - types = [] # type: List[Type] - for item in nodes: - if isinstance(item, TupleExpr): - if len(item.items) != 2: - return self.fail_namedtuple_arg("Invalid NamedTuple field definition", - item) - name, type_node = item.items - if isinstance(name, (StrExpr, BytesExpr, UnicodeExpr)): - items.append(name.value) - else: - return self.fail_namedtuple_arg("Invalid NamedTuple() field name", item) - try: - type = expr_to_unanalyzed_type(type_node) - except TypeTranslationError: - return self.fail_namedtuple_arg('Invalid field type', type_node) - types.append(self.anal_type(type)) - else: - return self.fail_namedtuple_arg("Tuple expected as NamedTuple() field", item) - return items, types, True - - def fail_namedtuple_arg(self, message: str, - context: Context) -> Tuple[List[str], List[Type], bool]: - self.fail(message, context) - return [], [], False - - def basic_new_typeinfo(self, name: str, basetype_or_fallback: Instance) -> TypeInfo: - class_def = ClassDef(name, Block([])) - class_def.fullname = self.qualified_name(name) - - info = TypeInfo(SymbolTable(), class_def, self.cur_mod_id) - info.mro = [info] + basetype_or_fallback.type.mro - info.bases = [basetype_or_fallback] - return info - - def build_namedtuple_typeinfo(self, name: str, items: List[str], types: List[Type], - default_items: Dict[str, Expression]) -> TypeInfo: - strtype = self.str_type() - basetuple_type = self.named_type('__builtins__.tuple', [AnyType()]) - dictype = (self.named_type_or_none('builtins.dict', [strtype, AnyType()]) - or self.object_type()) - # Actual signature should return OrderedDict[str, Union[types]] - ordereddictype = (self.named_type_or_none('builtins.dict', [strtype, AnyType()]) - or self.object_type()) - fallback = self.named_type('__builtins__.tuple', types) - # Note: actual signature should accept an invariant version of Iterable[UnionType[types]]. - # but it can't be expressed. 'new' and 'len' should be callable types. - iterable_type = self.named_type_or_none('typing.Iterable', [AnyType()]) - function_type = self.named_type('__builtins__.function') - - info = self.basic_new_typeinfo(name, fallback) - info.is_named_tuple = True - info.tuple_type = TupleType(types, fallback) - - def add_field(var: Var, is_initialized_in_class: bool = False, - is_property: bool = False) -> None: - var.info = info - var.is_initialized_in_class = is_initialized_in_class - var.is_property = is_property - info.names[var.name()] = SymbolTableNode(MDEF, var) - - vars = [Var(item, typ) for item, typ in zip(items, types)] - for var in vars: - add_field(var, is_property=True) - - tuple_of_strings = TupleType([strtype for _ in items], basetuple_type) - add_field(Var('_fields', tuple_of_strings), is_initialized_in_class=True) - add_field(Var('_field_types', dictype), is_initialized_in_class=True) - add_field(Var('_field_defaults', dictype), is_initialized_in_class=True) - add_field(Var('_source', strtype), is_initialized_in_class=True) - - tvd = TypeVarDef('NT', 1, [], info.tuple_type) - selftype = TypeVarType(tvd) - - def add_method(funcname: str, - ret: Type, - args: List[Argument], - name: str = None, - is_classmethod: bool = False, - ) -> None: - if is_classmethod: - first = [Argument(Var('cls'), TypeType(selftype), None, ARG_POS)] - else: - first = [Argument(Var('self'), selftype, None, ARG_POS)] - args = first + args - - types = [arg.type_annotation for arg in args] - items = [arg.variable.name() for arg in args] - arg_kinds = [arg.kind for arg in args] - signature = CallableType(types, arg_kinds, items, ret, function_type, - name=name or info.name() + '.' + funcname) - signature.variables = [tvd] - func = FuncDef(funcname, args, Block([]), typ=signature) - func.info = info - func.is_class = is_classmethod - if is_classmethod: - v = Var(funcname, signature) - v.is_classmethod = True - v.info = info - dec = Decorator(func, [NameExpr('classmethod')], v) - info.names[funcname] = SymbolTableNode(MDEF, dec) - else: - info.names[funcname] = SymbolTableNode(MDEF, func) - - add_method('_replace', ret=selftype, - args=[Argument(var, var.type, EllipsisExpr(), ARG_NAMED_OPT) for var in vars]) - - def make_init_arg(var: Var) -> Argument: - default = default_items.get(var.name(), None) - kind = ARG_POS if default is None else ARG_OPT - return Argument(var, var.type, default, kind) - - add_method('__init__', ret=NoneTyp(), name=info.name(), - args=[make_init_arg(var) for var in vars]) - add_method('_asdict', args=[], ret=ordereddictype) - add_method('_make', ret=selftype, is_classmethod=True, - args=[Argument(Var('iterable', iterable_type), iterable_type, None, ARG_POS), - Argument(Var('new'), AnyType(), EllipsisExpr(), ARG_NAMED_OPT), - Argument(Var('len'), AnyType(), EllipsisExpr(), ARG_NAMED_OPT)]) - return info - - def make_argument(self, name: str, type: Type) -> Argument: - return Argument(Var(name), type, None, ARG_POS) - - def analyze_types(self, items: List[Expression]) -> List[Type]: - result = [] # type: List[Type] - for node in items: - try: - result.append(self.anal_type(expr_to_unanalyzed_type(node))) - except TypeTranslationError: - self.fail('Type expected', node) - result.append(AnyType()) - return result - - def process_typeddict_definition(self, s: AssignmentStmt) -> None: - """Check if s defines a TypedDict; if yes, store the definition in symbol table.""" - if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): - return - lvalue = s.lvalues[0] - name = lvalue.name - typed_dict = self.check_typeddict(s.rvalue, name) - if typed_dict is None: - return - # Yes, it's a valid TypedDict definition. Add it to the symbol table. - node = self.lookup(name, s) - if node: - node.kind = GDEF # TODO locally defined TypedDict - node.node = typed_dict - - def check_typeddict(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: - """Check if a call defines a TypedDict. - - The optional var_name argument is the name of the variable to - which this is assigned, if any. - - If it does, return the corresponding TypeInfo. Return None otherwise. - - If the definition is invalid but looks like a TypedDict, - report errors but return (some) TypeInfo. - """ - if not isinstance(node, CallExpr): - return None - call = node - callee = call.callee - if not isinstance(callee, RefExpr): - return None - fullname = callee.fullname - if fullname != 'mypy_extensions.TypedDict': - return None - items, types, ok = self.parse_typeddict_args(call, fullname) - if not ok: - # Error. Construct dummy return value. - return self.build_typeddict_typeinfo('TypedDict', [], []) - name = cast(StrExpr, call.args[0]).value - if name != var_name or self.is_func_scope(): - # Give it a unique name derived from the line number. - name += '@' + str(call.line) - info = self.build_typeddict_typeinfo(name, items, types) - # Store it as a global just in case it would remain anonymous. - # (Or in the nearest class if there is one.) - stnode = SymbolTableNode(GDEF, info, self.cur_mod_id) - if self.type: - self.type.names[name] = stnode - else: - self.globals[name] = stnode - call.analyzed = TypedDictExpr(info) - call.analyzed.set_line(call.line, call.column) - return info - - def parse_typeddict_args(self, call: CallExpr, - fullname: str) -> Tuple[List[str], List[Type], bool]: - # TODO: Share code with check_argument_count in checkexpr.py? - args = call.args - if len(args) < 2: - return self.fail_typeddict_arg("Too few arguments for TypedDict()", call) - if len(args) > 2: - return self.fail_typeddict_arg("Too many arguments for TypedDict()", call) - # TODO: Support keyword arguments - if call.arg_kinds != [ARG_POS, ARG_POS]: - return self.fail_typeddict_arg("Unexpected arguments to TypedDict()", call) - if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): - return self.fail_typeddict_arg( - "TypedDict() expects a string literal as the first argument", call) - if not isinstance(args[1], DictExpr): - return self.fail_typeddict_arg( - "TypedDict() expects a dictionary literal as the second argument", call) - dictexpr = args[1] - items, types, ok = self.parse_typeddict_fields_with_types(dictexpr.items, call) - return items, types, ok - - def parse_typeddict_fields_with_types(self, dict_items: List[Tuple[Expression, Expression]], - context: Context) -> Tuple[List[str], List[Type], bool]: - items = [] # type: List[str] - types = [] # type: List[Type] - for (field_name_expr, field_type_expr) in dict_items: - if isinstance(field_name_expr, (StrExpr, BytesExpr, UnicodeExpr)): - items.append(field_name_expr.value) - else: - return self.fail_typeddict_arg("Invalid TypedDict() field name", field_name_expr) - try: - type = expr_to_unanalyzed_type(field_type_expr) - except TypeTranslationError: - return self.fail_typeddict_arg('Invalid field type', field_type_expr) - types.append(self.anal_type(type)) - return items, types, True - - def fail_typeddict_arg(self, message: str, - context: Context) -> Tuple[List[str], List[Type], bool]: - self.fail(message, context) - return [], [], False - - def build_typeddict_typeinfo(self, name: str, items: List[str], - types: List[Type]) -> TypeInfo: - mapping_value_type = join.join_type_list(types) - fallback = (self.named_type_or_none('typing.Mapping', - [self.str_type(), mapping_value_type]) - or self.object_type()) - - info = self.basic_new_typeinfo(name, fallback) - info.typeddict_type = TypedDictType(OrderedDict(zip(items, types)), fallback) - - return info - def check_classvar(self, s: AssignmentStmt) -> None: lvalue = s.lvalues[0] if len(s.lvalues) != 1 or not isinstance(lvalue, RefExpr): @@ -2332,139 +1577,6 @@ def is_classvar(self, typ: Type) -> bool: def fail_invalid_classvar(self, context: Context) -> None: self.fail('ClassVar can only be used for assignments in class body', context) - def process_enum_call(self, s: AssignmentStmt) -> None: - """Check if s defines an Enum; if yes, store the definition in symbol table.""" - if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): - return - lvalue = s.lvalues[0] - name = lvalue.name - enum_call = self.check_enum_call(s.rvalue, name) - if enum_call is None: - return - # Yes, it's a valid Enum definition. Add it to the symbol table. - node = self.lookup(name, s) - if node: - node.kind = GDEF # TODO locally defined Enum - node.node = enum_call - - def check_enum_call(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: - """Check if a call defines an Enum. - - Example: - - A = enum.Enum('A', 'foo bar') - - is equivalent to: - - class A(enum.Enum): - foo = 1 - bar = 2 - """ - if not isinstance(node, CallExpr): - return None - call = node - callee = call.callee - if not isinstance(callee, RefExpr): - return None - fullname = callee.fullname - if fullname not in ('enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag'): - return None - items, values, ok = self.parse_enum_call_args(call, fullname.split('.')[-1]) - if not ok: - # Error. Construct dummy return value. - return self.build_enum_call_typeinfo('Enum', [], fullname) - name = cast(StrExpr, call.args[0]).value - if name != var_name or self.is_func_scope(): - # Give it a unique name derived from the line number. - name += '@' + str(call.line) - info = self.build_enum_call_typeinfo(name, items, fullname) - # Store it as a global just in case it would remain anonymous. - # (Or in the nearest class if there is one.) - stnode = SymbolTableNode(GDEF, info, self.cur_mod_id) - if self.type: - self.type.names[name] = stnode - else: - self.globals[name] = stnode - call.analyzed = EnumCallExpr(info, items, values) - call.analyzed.set_line(call.line, call.column) - return info - - def build_enum_call_typeinfo(self, name: str, items: List[str], fullname: str) -> TypeInfo: - base = self.named_type_or_none(fullname) - assert base is not None - info = self.basic_new_typeinfo(name, base) - info.is_enum = True - for item in items: - var = Var(item) - var.info = info - var.is_property = True - info.names[item] = SymbolTableNode(MDEF, var) - return info - - def parse_enum_call_args(self, call: CallExpr, - class_name: str) -> Tuple[List[str], - List[Optional[Expression]], bool]: - args = call.args - if len(args) < 2: - return self.fail_enum_call_arg("Too few arguments for %s()" % class_name, call) - if len(args) > 2: - return self.fail_enum_call_arg("Too many arguments for %s()" % class_name, call) - if call.arg_kinds != [ARG_POS, ARG_POS]: - return self.fail_enum_call_arg("Unexpected arguments to %s()" % class_name, call) - if not isinstance(args[0], (StrExpr, UnicodeExpr)): - return self.fail_enum_call_arg( - "%s() expects a string literal as the first argument" % class_name, call) - items = [] - values = [] # type: List[Optional[Expression]] - if isinstance(args[1], (StrExpr, UnicodeExpr)): - fields = args[1].value - for field in fields.replace(',', ' ').split(): - items.append(field) - elif isinstance(args[1], (TupleExpr, ListExpr)): - seq_items = args[1].items - if all(isinstance(seq_item, (StrExpr, UnicodeExpr)) for seq_item in seq_items): - items = [cast(StrExpr, seq_item).value for seq_item in seq_items] - elif all(isinstance(seq_item, (TupleExpr, ListExpr)) - and len(seq_item.items) == 2 - and isinstance(seq_item.items[0], (StrExpr, UnicodeExpr)) - for seq_item in seq_items): - for seq_item in seq_items: - assert isinstance(seq_item, (TupleExpr, ListExpr)) - name, value = seq_item.items - assert isinstance(name, (StrExpr, UnicodeExpr)) - items.append(name.value) - values.append(value) - else: - return self.fail_enum_call_arg( - "%s() with tuple or list expects strings or (name, value) pairs" % - class_name, - call) - elif isinstance(args[1], DictExpr): - for key, value in args[1].items: - if not isinstance(key, (StrExpr, UnicodeExpr)): - return self.fail_enum_call_arg( - "%s() with dict literal requires string literals" % class_name, call) - items.append(key.value) - values.append(value) - else: - # TODO: Allow dict(x=1, y=2) as a substitute for {'x': 1, 'y': 2}? - return self.fail_enum_call_arg( - "%s() expects a string, tuple, list or dict literal as the second argument" % - class_name, - call) - if len(items) == 0: - return self.fail_enum_call_arg("%s() needs at least one item" % class_name, call) - if not values: - values = [None] * len(items) - assert len(items) == len(values) - return items, values, True - - def fail_enum_call_arg(self, message: str, - context: Context) -> Tuple[List[str], - List[Optional[Expression]], bool]: - self.fail(message, context) - return [], [], False - def visit_decorator(self, dec: Decorator) -> None: for d in dec.decorators: d.accept(self) @@ -3317,7 +2429,8 @@ def note(self, msg: str, ctx: Context) -> None: return self.errors.report(ctx.get_line(), ctx.get_column(), msg, severity='note') - def undefined_name_extra_info(self, fullname: str) -> Optional[str]: + @staticmethod + def undefined_name_extra_info(fullname: str) -> Optional[str]: if fullname in obsolete_name_mapping: return "(it's now called '{}')".format(obsolete_name_mapping[fullname]) else: diff --git a/mypy/specialtype.py b/mypy/specialtype.py new file mode 100644 index 0000000000000..3978e24714f27 --- /dev/null +++ b/mypy/specialtype.py @@ -0,0 +1,951 @@ +from collections import OrderedDict + +from typing import List, Dict, Tuple, cast, Optional + +from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError +from mypy.nodes import ( + TypeInfo, AssignmentStmt, FuncDef, ClassDef, Var, GDEF, Expression, + Block, NameExpr, TupleExpr, ListExpr, ExpressionStmt, PassStmt, + DictExpr, CallExpr, RefExpr, Context, SymbolTable, UNBOUND_TVAR, + MDEF, Decorator, TypeVarExpr, NewTypeExpr, StrExpr, BytesExpr, + ARG_POS, ARG_NAMED, ARG_NAMED_OPT, NamedTupleExpr, TypedDictExpr, Argument, + UnicodeExpr, EllipsisExpr, TempNode, EnumCallExpr, + COVARIANT, CONTRAVARIANT, INVARIANT, ARG_OPT, SymbolTableNode +) +from mypy.types import ( + NoneTyp, CallableType, Instance, Type, TypeVarType, AnyType, + TypeVarDef, TupleType, UnboundType, TypedDictType, TypeType, +) +from mypy.semanal import SemanticAnalyzer +from mypy import join + + +class Special: + """ + Groups special-cased types: + * NamedTuple + * TypedDict + * NewType + Also handles analysis of special constructs: + * Enum (functional style) + * TypeVar + + The interface consists of + * process_declarations() + * analyze_* + """ + semanalyzer = None # type: SemanticAnalyzer + + def __init__(self, semanalyzer: SemanticAnalyzer) -> None: + self.semanalyzer = semanalyzer + # Delegations: + self.fail = semanalyzer.fail + self.lookup = semanalyzer.lookup + self.lookup_qualified = semanalyzer.lookup_qualified + self.named_type = semanalyzer.named_type + self.named_type_or_none = semanalyzer.named_type_or_none + self.object_type = semanalyzer.object_type + self.str_type = semanalyzer.str_type + + def process_declaration(self, s: AssignmentStmt) -> None: + self.process_newtype_declaration(s) + self.process_typevar_declaration(s) + self.process_namedtuple_definition(s) + self.process_typeddict_definition(s) + self.process_enum_call(s) + + def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool: + for base_expr in defn.base_type_exprs: + if isinstance(base_expr, RefExpr): + base_expr.accept(self.semanalyzer) + if base_expr.fullname == 'typing.NamedTuple': + node = self.lookup(defn.name, defn) + if node is not None: + node.kind = GDEF # TODO in process_namedtuple_definition also applies here + items, types, default_items = self.check_namedtuple_classdef(defn) + node.node = self.build_namedtuple_typeinfo( + defn.name, items, types, default_items) + return True + return False + + def check_namedtuple_classdef( + self, defn: ClassDef) -> Tuple[List[str], List[Type], Dict[str, Expression]]: + NAMEDTUP_CLASS_ERROR = ('Invalid statement in NamedTuple definition; ' + 'expected "field_name: field_type"') + if self.semanalyzer.options.python_version < (3, 6): + self.fail('NamedTuple class syntax is only supported in Python 3.6', defn) + return [], [], {} + if len(defn.base_type_exprs) > 1: + self.fail('NamedTuple should be a single base', defn) + items = [] # type: List[str] + types = [] # type: List[Type] + default_items = {} # type: Dict[str, Expression] + for stmt in defn.defs.body: + if not isinstance(stmt, AssignmentStmt): + # Still allow pass or ... (for empty namedtuples). + if (not isinstance(stmt, PassStmt) and + not (isinstance(stmt, ExpressionStmt) and + isinstance(stmt.expr, EllipsisExpr))): + self.fail(NAMEDTUP_CLASS_ERROR, stmt) + elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr): + # An assignment, but an invalid one. + self.fail(NAMEDTUP_CLASS_ERROR, stmt) + else: + # Append name and type in this case... + name = stmt.lvalues[0].name + items.append(name) + types.append(AnyType() if stmt.type is None else self.semanalyzer.anal_type(stmt.type)) + # ...despite possible minor failures that allow further analyzis. + if name.startswith('_'): + self.fail('NamedTuple field name cannot start with an underscore: {}' + .format(name), stmt) + if stmt.type is None or hasattr(stmt, 'new_syntax') and not stmt.new_syntax: + self.fail(NAMEDTUP_CLASS_ERROR, stmt) + elif isinstance(stmt.rvalue, TempNode): + # x: int assigns rvalue to TempNode(AnyType()) + if default_items: + self.fail('Non-default NamedTuple fields cannot follow default fields', + stmt) + else: + default_items[name] = stmt.rvalue + return items, types, default_items + + @staticmethod + def is_typeddict(expr: Expression) -> bool: + return (isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo) and + expr.node.typeddict_type is not None) + + def analyze_typeddict_classdef(self, defn: ClassDef) -> bool: + # special case for TypedDict + possible = False + for base_expr in defn.base_type_exprs: + if isinstance(base_expr, RefExpr): + base_expr.accept(self.semanalyzer) + if (base_expr.fullname == 'mypy_extensions.TypedDict' or + self.is_typeddict(base_expr)): + possible = True + if possible: + node = self.lookup(defn.name, defn) + if node is not None: + node.kind = GDEF # TODO in process_namedtuple_definition also applies here + if (len(defn.base_type_exprs) == 1 and + isinstance(defn.base_type_exprs[0], RefExpr) and + defn.base_type_exprs[0].fullname == 'mypy_extensions.TypedDict'): + # Building a new TypedDict + fields, types = self.check_typeddict_classdef(defn) + node.node = self.build_typeddict_typeinfo(defn.name, fields, types) + return True + # Extending/merging existing TypedDicts + if any(not isinstance(expr, RefExpr) or + expr.fullname != 'mypy_extensions.TypedDict' and + not self.is_typeddict(expr) for expr in defn.base_type_exprs): + self.fail("All bases of a new TypedDict must be TypedDict types", defn) + typeddict_bases = list(filter(self.is_typeddict, defn.base_type_exprs)) + newfields = [] # type: List[str] + newtypes = [] # type: List[Type] + tpdict = None # type: OrderedDict[str, Type] + for base in typeddict_bases: + assert isinstance(base, RefExpr) + assert isinstance(base.node, TypeInfo) + assert isinstance(base.node.typeddict_type, TypedDictType) + tpdict = base.node.typeddict_type.items + newdict = tpdict.copy() + for key in tpdict: + if key in newfields: + self.fail('Cannot overwrite TypedDict field "{}" while merging' + .format(key), defn) + newdict.pop(key) + newfields.extend(newdict.keys()) + newtypes.extend(newdict.values()) + fields, types = self.check_typeddict_classdef(defn, newfields) + newfields.extend(fields) + newtypes.extend(types) + node.node = self.build_typeddict_typeinfo(defn.name, newfields, newtypes) + return True + return False + + def check_typeddict_classdef(self, defn: ClassDef, + oldfields: List[str] = None) -> Tuple[List[str], List[Type]]: + TPDICT_CLASS_ERROR = ('Invalid statement in TypedDict definition; ' + 'expected "field_name: field_type"') + if self.semanalyzer.options.python_version < (3, 6): + self.fail('TypedDict class syntax is only supported in Python 3.6', defn) + return [], [] + fields = [] # type: List[str] + types = [] # type: List[Type] + for stmt in defn.defs.body: + if not isinstance(stmt, AssignmentStmt): + # Still allow pass or ... (for empty TypedDict's). + if (not isinstance(stmt, PassStmt) and + not (isinstance(stmt, ExpressionStmt) and + isinstance(stmt.expr, EllipsisExpr))): + self.fail(TPDICT_CLASS_ERROR, stmt) + elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr): + # An assignment, but an invalid one. + self.fail(TPDICT_CLASS_ERROR, stmt) + else: + name = stmt.lvalues[0].name + if name in (oldfields or []): + self.fail('Cannot overwrite TypedDict field "{}" while extending' + .format(name), stmt) + continue + if name in fields: + self.fail('Duplicate TypedDict field "{}"'.format(name), stmt) + continue + # Append name and type in this case... + fields.append(name) + types.append(AnyType() if stmt.type is None else self.semanalyzer.anal_type(stmt.type)) + # ...despite possible minor failures that allow further analyzis. + if stmt.type is None or hasattr(stmt, 'new_syntax') and not stmt.new_syntax: + self.fail(TPDICT_CLASS_ERROR, stmt) + elif not isinstance(stmt.rvalue, TempNode): + # x: int assigns rvalue to TempNode(AnyType()) + self.fail('Right hand side values are not supported in TypedDict', stmt) + return fields, types + + def process_newtype_declaration(self, s: AssignmentStmt) -> None: + """Check if s declares a NewType; if yes, store it in symbol table.""" + # Extract and check all information from newtype declaration + name, call = self.analyze_newtype_declaration(s) + if name is None or call is None: + return + + old_type = self.check_newtype_args(name, call, s) + call.analyzed = NewTypeExpr(name, old_type, line=call.line) + if old_type is None: + return + + # Create the corresponding class definition if the aliased type is subtypeable + if isinstance(old_type, TupleType): + newtype_class_info = self.build_newtype_typeinfo(name, old_type, old_type.fallback) + newtype_class_info.tuple_type = old_type + elif isinstance(old_type, Instance): + newtype_class_info = self.build_newtype_typeinfo(name, old_type, old_type) + else: + message = "Argument 2 to NewType(...) must be subclassable (got {})" + self.fail(message.format(old_type), s) + return + + # If so, add it to the symbol table. + node = self.lookup(name, s) + if node is None: + self.fail("Could not find {} in current namespace".format(name), s) + return + # TODO: why does NewType work in local scopes despite always being of kind GDEF? + node.kind = GDEF + call.analyzed.info = node.node = newtype_class_info + + def analyze_newtype_declaration(self, s: AssignmentStmt + ) -> Tuple[Optional[str], Optional[CallExpr]]: + """Return the NewType call expression if `s` is a newtype declaration or None otherwise.""" + name, call = None, None + if (len(s.lvalues) == 1 + and isinstance(s.lvalues[0], NameExpr) + and isinstance(s.rvalue, CallExpr) + and isinstance(s.rvalue.callee, RefExpr) + and s.rvalue.callee.fullname == 'typing.NewType'): + lvalue = s.lvalues[0] + name = s.lvalues[0].name + if not lvalue.is_def: + if s.type: + self.fail("Cannot declare the type of a NewType declaration", s) + else: + self.fail("Cannot redefine '%s' as a NewType" % name, s) + + # This dummy NewTypeExpr marks the call as sufficiently analyzed; it will be + # overwritten later with a fully complete NewTypeExpr if there are no other + # errors with the NewType() call. + call = s.rvalue + + return name, call + + def check_newtype_args(self, name: str, call: CallExpr, context: Context) -> Optional[Type]: + has_failed = False + args, arg_kinds = call.args, call.arg_kinds + if len(args) != 2 or arg_kinds[0] != ARG_POS or arg_kinds[1] != ARG_POS: + self.fail("NewType(...) expects exactly two positional arguments", context) + return None + + # Check first argument + if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): + self.fail("Argument 1 to NewType(...) must be a string literal", context) + has_failed = True + elif args[0].value != name: + msg = "String argument 1 '{}' to NewType(...) does not match variable name '{}'" + self.fail(msg.format(args[0].value, name), context) + has_failed = True + + # Check second argument + try: + unanalyzed_type = expr_to_unanalyzed_type(args[1]) + except TypeTranslationError: + self.fail("Argument 2 to NewType(...) must be a valid type", context) + return None + old_type = self.semanalyzer.anal_type(unanalyzed_type) + + if isinstance(old_type, Instance) and old_type.type.is_newtype: + self.fail("Argument 2 to NewType(...) cannot be another NewType", context) + has_failed = True + + return None if has_failed else old_type + + def build_newtype_typeinfo(self, name: str, old_type: Type, base_type: Instance) -> TypeInfo: + info = self.basic_new_typeinfo(name, base_type) + info.is_newtype = True + + # Add __init__ method + args = [Argument(Var('cls'), NoneTyp(), None, ARG_POS), + self.make_argument('item', old_type)] + signature = CallableType( + arg_types=[cast(Type, None), old_type], + arg_kinds=[arg.kind for arg in args], + arg_names=['self', 'item'], + ret_type=old_type, + fallback=self.named_type('__builtins__.function'), + name=name) + init_func = FuncDef('__init__', args, Block([]), typ=signature) + init_func.info = info + info.names['__init__'] = SymbolTableNode(MDEF, init_func) + + return info + + def analyze_typevar_declaration(self, t: Type) -> Optional[List[Tuple[str, TypeVarExpr]]]: + if not isinstance(t, UnboundType): + return None + unbound = t + sym = self.lookup_qualified(unbound.name, unbound) + if sym is None or sym.node is None: + return None + if sym.node.fullname() == 'typing.Generic': + tvars = [] # type: List[Tuple[str, TypeVarExpr]] + for arg in unbound.args: + tvar = self.semanalyzer.analyze_unbound_tvar(arg) + if tvar: + tvars.append(tvar) + else: + self.fail('Free type variable expected in %s[...]' % + sym.node.name(), t) + return tvars + return None + + def analyze_types(self, items: List[Expression]) -> List[Type]: + result = [] # type: List[Type] + for node in items: + try: + result.append(self.semanalyzer.anal_type(expr_to_unanalyzed_type(node))) + except TypeTranslationError: + self.fail('Type expected', node) + result.append(AnyType()) + return result + + def process_typevar_declaration(self, s: AssignmentStmt) -> None: + """Check if s declares a TypeVar; it yes, store it in symbol table.""" + call = self.get_typevar_declaration(s) + if not call: + return + + lvalue = s.lvalues[0] + assert isinstance(lvalue, NameExpr) + name = lvalue.name + if not lvalue.is_def: + if s.type: + self.fail("Cannot declare the type of a type variable", s) + else: + self.fail("Cannot redefine '%s' as a type variable" % name, s) + return + + if not self.check_typevar_name(call, name, s): + return + + # Constraining types + n_values = call.arg_kinds[1:].count(ARG_POS) + values = self.analyze_types(call.args[1:1 + n_values]) + + res = self.process_typevar_parameters(call.args[1 + n_values:], + call.arg_names[1 + n_values:], + call.arg_kinds[1 + n_values:], + n_values, + s) + if res is None: + return + variance, upper_bound = res + + # Yes, it's a valid type variable definition! Add it to the symbol table. + node = self.lookup(name, s) + node.kind = UNBOUND_TVAR + TypeVar = TypeVarExpr(name, node.fullname, values, upper_bound, variance) + TypeVar.line = call.line + call.analyzed = TypeVar + node.node = TypeVar + + def check_typevar_name(self, call: CallExpr, name: str, context: Context) -> bool: + if len(call.args) < 1: + self.fail("Too few arguments for TypeVar()", context) + return False + if (not isinstance(call.args[0], (StrExpr, BytesExpr, UnicodeExpr)) + or not call.arg_kinds[0] == ARG_POS): + self.fail("TypeVar() expects a string literal as first argument", context) + return False + elif call.args[0].value != name: + msg = "String argument 1 '{}' to TypeVar(...) does not match variable name '{}'" + self.fail(msg.format(call.args[0].value, name), context) + return False + return True + + @staticmethod + def get_typevar_declaration(s: AssignmentStmt) -> Optional[CallExpr]: + """Returns the TypeVar() call expression if `s` is a type var declaration + or None otherwise. + """ + if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): + return None + if not isinstance(s.rvalue, CallExpr): + return None + call = s.rvalue + callee = call.callee + if not isinstance(callee, RefExpr): + return None + if callee.fullname != 'typing.TypeVar': + return None + return call + + def process_typevar_parameters(self, + args: List[Expression], + names: List[Optional[str]], + kinds: List[int], + num_values: int, + context: Context) -> Optional[Tuple[int, Type]]: + has_values = (num_values > 0) + covariant = False + contravariant = False + upper_bound = self.object_type() # type: Type + for param_value, param_name, param_kind in zip(args, names, kinds): + if not param_kind == ARG_NAMED: + self.fail("Unexpected argument to TypeVar()", context) + return None + if param_name == 'covariant': + if isinstance(param_value, NameExpr): + if param_value.name == 'True': + covariant = True + else: + self.fail("TypeVar 'covariant' may only be 'True'", context) + return None + else: + self.fail("TypeVar 'covariant' may only be 'True'", context) + return None + elif param_name == 'contravariant': + if isinstance(param_value, NameExpr): + if param_value.name == 'True': + contravariant = True + else: + self.fail("TypeVar 'contravariant' may only be 'True'", context) + return None + else: + self.fail("TypeVar 'contravariant' may only be 'True'", context) + return None + elif param_name == 'bound': + if has_values: + self.fail("TypeVar cannot have both values and an upper bound", context) + return None + try: + upper_bound = self.semanalyzer.expr_to_analyzed_type(param_value) + except TypeTranslationError: + self.fail("TypeVar 'bound' must be a type", param_value) + return None + elif param_name == 'values': + # Probably using obsolete syntax with values=(...). Explain the current syntax. + self.fail("TypeVar 'values' argument not supported", context) + self.fail("Use TypeVar('T', t, ...) instead of TypeVar('T', values=(t, ...))", + context) + return None + else: + self.fail("Unexpected argument to TypeVar(): {}".format(param_name), context) + return None + + if covariant and contravariant: + self.fail("TypeVar cannot be both covariant and contravariant", context) + return None + elif num_values == 1: + self.fail("TypeVar cannot have only a single constraint", context) + return None + elif covariant: + variance = COVARIANT + elif contravariant: + variance = CONTRAVARIANT + else: + variance = INVARIANT + return (variance, upper_bound) + + def process_namedtuple_definition(self, s: AssignmentStmt) -> None: + """Check if s defines a namedtuple; if yes, store the definition in symbol table.""" + if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): + return + lvalue = s.lvalues[0] + name = lvalue.name + named_tuple = self.check_namedtuple(s.rvalue, name) + if named_tuple is None: + return + # Yes, it's a valid namedtuple definition. Add it to the symbol table. + node = self.lookup(name, s) + node.kind = GDEF # TODO locally defined namedtuple + node.node = named_tuple + + def check_namedtuple(self, expr: Expression, var_name: str = None) -> Optional[TypeInfo]: + """Check if a call defines a namedtuple. + + The optional var_name argument is the name of the variable to + which this is assigned, if any. + + If it does, return the corresponding TypeInfo. Return None otherwise. + + If the definition is invalid but looks like a namedtuple, + report errors but return (some) TypeInfo. + """ + if not isinstance(expr, CallExpr): + return None + call = expr + callee = call.callee + if not isinstance(callee, RefExpr): + return None + fullname = callee.fullname + if fullname not in ('collections.namedtuple', 'typing.NamedTuple'): + return None + items, types, ok = self.parse_namedtuple_args(call, fullname) + if not ok: + # Error. Construct dummy return value. + return self.build_namedtuple_typeinfo('namedtuple', [], [], {}) + name = cast(StrExpr, call.args[0]).value + if name != var_name or self.semanalyzer.is_func_scope(): + # Give it a unique name derived from the line number. + name += '@' + str(call.line) + info = self.build_namedtuple_typeinfo(name, items, types, {}) + # Store it as a global just in case it would remain anonymous. + # (Or in the nearest class if there is one.) + stnode = SymbolTableNode(GDEF, info, self.semanalyzer.cur_mod_id) + if self.semanalyzer.type: + self.semanalyzer.type.names[name] = stnode + else: + self.semanalyzer.globals[name] = stnode + call.analyzed = NamedTupleExpr(info) + call.analyzed.set_line(call.line, call.column) + return info + + def parse_namedtuple_args(self, call: CallExpr, + fullname: str) -> Tuple[List[str], List[Type], bool]: + # TODO: Share code with check_argument_count in checkexpr.py? + args = call.args + if len(args) < 2: + return self.fail_namedtuple_arg("Too few arguments for namedtuple()", call) + if len(args) > 2: + # FIX incorrect. There are two additional parameters + return self.fail_namedtuple_arg("Too many arguments for namedtuple()", call) + if call.arg_kinds != [ARG_POS, ARG_POS]: + return self.fail_namedtuple_arg("Unexpected arguments to namedtuple()", call) + if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): + return self.fail_namedtuple_arg( + "namedtuple() expects a string literal as the first argument", call) + types = [] # type: List[Type] + ok = True + if not isinstance(args[1], (ListExpr, TupleExpr)): + if (fullname == 'collections.namedtuple' + and isinstance(args[1], (StrExpr, BytesExpr, UnicodeExpr))): + str_expr = cast(StrExpr, args[1]) + items = str_expr.value.replace(',', ' ').split() + else: + return self.fail_namedtuple_arg( + "List or tuple literal expected as the second argument to namedtuple()", call) + else: + listexpr = args[1] + if fullname == 'collections.namedtuple': + # The fields argument contains just names, with implicit Any types. + if any(not isinstance(item, (StrExpr, BytesExpr, UnicodeExpr)) + for item in listexpr.items): + return self.fail_namedtuple_arg("String literal expected as namedtuple() item", + call) + items = [cast(StrExpr, item).value for item in listexpr.items] + else: + # The fields argument contains (name, type) tuples. + items, types, ok = self.parse_namedtuple_fields_with_types(listexpr.items) + if not types: + types = [AnyType() for _ in items] + underscore = [item for item in items if item.startswith('_')] + if underscore: + self.fail("namedtuple() field names cannot start with an underscore: " + + ', '.join(underscore), call) + return items, types, ok + + def parse_namedtuple_fields_with_types(self, nodes: List[Expression]) -> Tuple[List[str], List[Type], bool]: + items = [] # type: List[str] + types = [] # type: List[Type] + for item in nodes: + if isinstance(item, TupleExpr): + if len(item.items) != 2: + return self.fail_namedtuple_arg("Invalid NamedTuple field definition", + item) + name, type_node = item.items + if isinstance(name, (StrExpr, BytesExpr, UnicodeExpr)): + items.append(name.value) + else: + return self.fail_namedtuple_arg("Invalid NamedTuple() field name", item) + try: + type = expr_to_unanalyzed_type(type_node) + except TypeTranslationError: + return self.fail_namedtuple_arg('Invalid field type', type_node) + types.append(self.semanalyzer.anal_type(type)) + else: + return self.fail_namedtuple_arg("Tuple expected as NamedTuple() field", item) + return items, types, True + + def fail_namedtuple_arg(self, message: str, + context: Context) -> Tuple[List[str], List[Type], bool]: + self.fail(message, context) + return [], [], False + + def basic_new_typeinfo(self, name: str, basetype_or_fallback: Instance) -> TypeInfo: + class_def = ClassDef(name, Block([])) + class_def.fullname = self.semanalyzer.qualified_name(name) + + info = TypeInfo(SymbolTable(), class_def, self.semanalyzer.cur_mod_id) + info.mro = [info] + basetype_or_fallback.type.mro + info.bases = [basetype_or_fallback] + return info + + def analyze_callexpr_as_type(self, call: CallExpr) -> Optional[Type]: + info = self.check_namedtuple(call) + if info is None: + # Some form of namedtuple is the only valid type that looks like a call + # expression. This isn't a valid type. + return None + fallback = Instance(info, []) + return TupleType(info.tuple_type.items, fallback=fallback) + + def build_namedtuple_typeinfo(self, name: str, items: List[str], types: List[Type], + default_items: Dict[str, Expression]) -> TypeInfo: + strtype = self.str_type() + object_type = self.object_type() + basetuple_type = self.named_type('__builtins__.tuple', [AnyType()]) + dictype = (self.named_type_or_none('builtins.dict', [strtype, AnyType()]) + or object_type) + # Actual signature should return OrderedDict[str, Union[types]] + ordereddictype = (self.named_type_or_none('builtins.dict', [strtype, AnyType()]) + or object_type) + fallback = self.named_type('__builtins__.tuple', types) + # Note: actual signature should accept an invariant version of Iterable[UnionType[types]]. + # but it can't be expressed. 'new' and 'len' should be callable types. + iterable_type = self.named_type_or_none('typing.Iterable', [AnyType()]) + function_type = self.named_type('__builtins__.function') + + info = self.basic_new_typeinfo(name, fallback) + info.is_named_tuple = True + info.tuple_type = TupleType(types, fallback) + + def add_field(var: Var, is_initialized_in_class: bool = False, + is_property: bool = False) -> None: + var.info = info + var.is_initialized_in_class = is_initialized_in_class + var.is_property = is_property + info.names[var.name()] = SymbolTableNode(MDEF, var) + + vars = [Var(item, typ) for item, typ in zip(items, types)] + for var in vars: + add_field(var, is_property=True) + + tuple_of_strings = TupleType([strtype for _ in items], basetuple_type) + add_field(Var('_fields', tuple_of_strings), is_initialized_in_class=True) + add_field(Var('_field_types', dictype), is_initialized_in_class=True) + add_field(Var('_field_defaults', dictype), is_initialized_in_class=True) + add_field(Var('_source', strtype), is_initialized_in_class=True) + + tvd = TypeVarDef('NT', 1, [], info.tuple_type) + selftype = TypeVarType(tvd) + + def add_method(funcname: str, + ret: Type, + args: List[Argument], + name: str = None, + is_classmethod: bool = False, + ) -> None: + if is_classmethod: + first = [Argument(Var('cls'), TypeType(selftype), None, ARG_POS)] + else: + first = [Argument(Var('self'), selftype, None, ARG_POS)] + args = first + args + + types = [arg.type_annotation for arg in args] + items = [arg.variable.name() for arg in args] + arg_kinds = [arg.kind for arg in args] + signature = CallableType(types, arg_kinds, items, ret, function_type, + name=name or info.name() + '.' + funcname) + signature.variables = [tvd] + func = FuncDef(funcname, args, Block([]), typ=signature) + func.info = info + func.is_class = is_classmethod + if is_classmethod: + v = Var(funcname, signature) + v.is_classmethod = True + v.info = info + dec = Decorator(func, [NameExpr('classmethod')], v) + info.names[funcname] = SymbolTableNode(MDEF, dec) + else: + info.names[funcname] = SymbolTableNode(MDEF, func) + + add_method('_replace', ret=selftype, + args=[Argument(var, var.type, EllipsisExpr(), ARG_NAMED_OPT) for var in vars]) + + def make_init_arg(var: Var) -> Argument: + default = default_items.get(var.name(), None) + kind = ARG_POS if default is None else ARG_OPT + return Argument(var, var.type, default, kind) + + add_method('__init__', ret=NoneTyp(), name=info.name(), + args=[make_init_arg(var) for var in vars]) + add_method('_asdict', args=[], ret=ordereddictype) + add_method('_make', ret=selftype, is_classmethod=True, + args=[Argument(Var('iterable', iterable_type), iterable_type, None, ARG_POS), + Argument(Var('new'), AnyType(), EllipsisExpr(), ARG_NAMED_OPT), + Argument(Var('len'), AnyType(), EllipsisExpr(), ARG_NAMED_OPT)]) + return info + + @staticmethod + def make_argument(name: str, type: Type) -> Argument: + return Argument(Var(name), type, None, ARG_POS) + + def process_typeddict_definition(self, s: AssignmentStmt) -> None: + """Check if s defines a TypedDict; if yes, store the definition in symbol table.""" + if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): + return + lvalue = s.lvalues[0] + name = lvalue.name + typed_dict = self.check_typeddict(s.rvalue, name) + if typed_dict is None: + return + # Yes, it's a valid TypedDict definition. Add it to the symbol table. + node = self.lookup(name, s) + if node: + node.kind = GDEF # TODO locally defined TypedDict + node.node = typed_dict + + def check_typeddict(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: + """Check if a call defines a TypedDict. + + The optional var_name argument is the name of the variable to + which this is assigned, if any. + + If it does, return the corresponding TypeInfo. Return None otherwise. + + If the definition is invalid but looks like a TypedDict, + report errors but return (some) TypeInfo. + """ + if not isinstance(node, CallExpr): + return None + call = node + callee = call.callee + if not isinstance(callee, RefExpr): + return None + if callee.fullname != 'mypy_extensions.TypedDict': + return None + items, types, ok = self.parse_typeddict_args(call) + if not ok: + # Error. Construct dummy return value. + return self.build_typeddict_typeinfo('TypedDict', [], []) + name = cast(StrExpr, call.args[0]).value + if name != var_name or self.semanalyzer.is_func_scope(): + # Give it a unique name derived from the line number. + name += '@' + str(call.line) + info = self.build_typeddict_typeinfo(name, items, types) + # Store it as a global just in case it would remain anonymous. + # (Or in the nearest class if there is one.) + stnode = SymbolTableNode(GDEF, info, self.semanalyzer.cur_mod_id) + if self.semanalyzer.type: + self.semanalyzer.type.names[name] = stnode + else: + self.semanalyzer.globals[name] = stnode + call.analyzed = TypedDictExpr(info) + call.analyzed.set_line(call.line, call.column) + return info + + def parse_typeddict_args(self, call: CallExpr) -> Tuple[List[str], List[Type], bool]: + # TODO: Share code with check_argument_count in checkexpr.py? + args = call.args + if len(args) < 2: + return self.fail_typeddict_arg("Too few arguments for TypedDict()", call) + if len(args) > 2: + return self.fail_typeddict_arg("Too many arguments for TypedDict()", call) + # TODO: Support keyword arguments + if call.arg_kinds != [ARG_POS, ARG_POS]: + return self.fail_typeddict_arg("Unexpected arguments to TypedDict()", call) + if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): + return self.fail_typeddict_arg( + "TypedDict() expects a string literal as the first argument", call) + if not isinstance(args[1], DictExpr): + return self.fail_typeddict_arg( + "TypedDict() expects a dictionary literal as the second argument", call) + dictexpr = args[1] + items, types, ok = self.parse_typeddict_fields_with_types(dictexpr.items) + return items, types, ok + + def parse_typeddict_fields_with_types(self, dict_items: List[Tuple[Expression, Expression]], + ) -> Tuple[List[str], List[Type], bool]: + items = [] # type: List[str] + types = [] # type: List[Type] + for (field_name_expr, field_type_expr) in dict_items: + if isinstance(field_name_expr, (StrExpr, BytesExpr, UnicodeExpr)): + items.append(field_name_expr.value) + else: + return self.fail_typeddict_arg("Invalid TypedDict() field name", field_name_expr) + try: + type = expr_to_unanalyzed_type(field_type_expr) + except TypeTranslationError: + return self.fail_typeddict_arg('Invalid field type', field_type_expr) + types.append(self.semanalyzer.anal_type(type)) + return items, types, True + + def fail_typeddict_arg(self, message: str, + context: Context) -> Tuple[List[str], List[Type], bool]: + self.fail(message, context) + return [], [], False + + def build_typeddict_typeinfo(self, name: str, items: List[str], + types: List[Type]) -> TypeInfo: + mapping_value_type = join.join_type_list(types) + fallback = (self.named_type_or_none('typing.Mapping', + [self.str_type(), mapping_value_type]) + or self.object_type()) + + info = self.basic_new_typeinfo(name, fallback) + info.typeddict_type = TypedDictType(OrderedDict(zip(items, types)), fallback) + + return info + + def process_enum_call(self, s: AssignmentStmt) -> None: + """Check if s defines an Enum; if yes, store the definition in symbol table.""" + if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): + return + lvalue = s.lvalues[0] + name = lvalue.name + enum_call = self.check_enum_call(s.rvalue, name) + if enum_call is None: + return + # Yes, it's a valid Enum definition. Add it to the symbol table. + node = self.lookup(name, s) + if node: + node.kind = GDEF # TODO locally defined Enum + node.node = enum_call + + def check_enum_call(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: + """Check if a call defines an Enum. + + Example: + + A = enum.Enum('A', 'foo bar') + + is equivalent to: + + class A(enum.Enum): + foo = 1 + bar = 2 + """ + if not isinstance(node, CallExpr): + return None + call = node + callee = call.callee + if not isinstance(callee, RefExpr): + return None + fullname = callee.fullname + if fullname not in ('enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag'): + return None + items, values, ok = self.parse_enum_call_args(call, fullname.split('.')[-1]) + if not ok: + # Error. Construct dummy return value. + return self.build_enum_call_typeinfo('Enum', [], fullname) + name = cast(StrExpr, call.args[0]).value + if name != var_name or self.semanalyzer.is_func_scope(): + # Give it a unique name derived from the line number. + name += '@' + str(call.line) + info = self.build_enum_call_typeinfo(name, items, fullname) + # Store it as a global just in case it would remain anonymous. + # (Or in the nearest class if there is one.) + stnode = SymbolTableNode(GDEF, info, self.semanalyzer.cur_mod_id) + if self.semanalyzer.type: + self.semanalyzer.type.names[name] = stnode + else: + self.semanalyzer.globals[name] = stnode + call.analyzed = EnumCallExpr(info, items, values) + call.analyzed.set_line(call.line, call.column) + return info + + def build_enum_call_typeinfo(self, name: str, items: List[str], fullname: str) -> TypeInfo: + base = self.named_type_or_none(fullname) + assert base is not None + info = self.basic_new_typeinfo(name, base) + info.is_enum = True + for item in items: + var = Var(item) + var.info = info + var.is_property = True + info.names[item] = SymbolTableNode(MDEF, var) + return info + + def parse_enum_call_args(self, call: CallExpr, + class_name: str) -> Tuple[List[str], + List[Optional[Expression]], bool]: + args = call.args + if len(args) < 2: + return self.fail_enum_call_arg("Too few arguments for %s()" % class_name, call) + if len(args) > 2: + return self.fail_enum_call_arg("Too many arguments for %s()" % class_name, call) + if call.arg_kinds != [ARG_POS, ARG_POS]: + return self.fail_enum_call_arg("Unexpected arguments to %s()" % class_name, call) + if not isinstance(args[0], (StrExpr, UnicodeExpr)): + return self.fail_enum_call_arg( + "%s() expects a string literal as the first argument" % class_name, call) + items = [] + values = [] # type: List[Optional[Expression]] + if isinstance(args[1], (StrExpr, UnicodeExpr)): + fields = args[1].value + for field in fields.replace(',', ' ').split(): + items.append(field) + elif isinstance(args[1], (TupleExpr, ListExpr)): + seq_items = args[1].items + if all(isinstance(seq_item, (StrExpr, UnicodeExpr)) for seq_item in seq_items): + items = [cast(StrExpr, seq_item).value for seq_item in seq_items] + elif all(isinstance(seq_item, (TupleExpr, ListExpr)) + and len(seq_item.items) == 2 + and isinstance(seq_item.items[0], (StrExpr, UnicodeExpr)) + for seq_item in seq_items): + for seq_item in seq_items: + assert isinstance(seq_item, (TupleExpr, ListExpr)) + name, value = seq_item.items + assert isinstance(name, (StrExpr, UnicodeExpr)) + items.append(name.value) + values.append(value) + else: + return self.fail_enum_call_arg( + "%s() with tuple or list expects strings or (name, value) pairs" % + class_name, + call) + elif isinstance(args[1], DictExpr): + for key, value in args[1].items: + if not isinstance(key, (StrExpr, UnicodeExpr)): + return self.fail_enum_call_arg( + "%s() with dict literal requires string literals" % class_name, call) + items.append(key.value) + values.append(value) + else: + # TODO: Allow dict(x=1, y=2) as a substitute for {'x': 1, 'y': 2}? + return self.fail_enum_call_arg( + "%s() expects a string, tuple, list or dict literal as the second argument" % + class_name, + call) + if len(items) == 0: + return self.fail_enum_call_arg("%s() needs at least one item" % class_name, call) + if not values: + values = [None] * len(items) + assert len(items) == len(values) + return items, values, True + + def fail_enum_call_arg(self, message: str, + context: Context) -> Tuple[List[str], + List[Optional[Expression]], bool]: + self.fail(message, context) + return [], [], False From 59b906a58b6db4fea30df48bdbf328c9266fd6af Mon Sep 17 00:00:00 2001 From: elazar Date: Wed, 5 Apr 2017 04:19:06 +0300 Subject: [PATCH 02/18] slight refactoring --- mypy/semanal.py | 13 +++++++++- mypy/specialtype.py | 63 ++++++++++++--------------------------------- 2 files changed, 29 insertions(+), 47 deletions(-) diff --git a/mypy/semanal.py b/mypy/semanal.py index c9770ade12e0a..a69fe1a415d64 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -1508,7 +1508,8 @@ def analyze_member_lvalue(self, lval: MemberExpr) -> None: self.type.names[lval.name] = SymbolTableNode(MDEF, v) self.check_lvalue_validity(lval.node, lval) - def is_self_member_ref(self, memberexpr: MemberExpr) -> bool: + @staticmethod + def is_self_member_ref(memberexpr: MemberExpr) -> bool: """Does memberexpr to refer to an attribute of self?""" if not isinstance(memberexpr.expr, NameExpr): return False @@ -1521,6 +1522,16 @@ def check_lvalue_validity(self, node: Union[Expression, SymbolNode], ctx: Contex elif isinstance(node, TypeInfo): self.fail(CANNOT_ASSIGN_TO_TYPE, ctx) + def store_info(self, info: TypeInfo, name: str) -> None: + # Store it as a global just in case it would remain anonymous. + # (Or in the nearest class if there is one.) + # called from specialtype + stnode = SymbolTableNode(GDEF, info, self.cur_mod_id) + if self.type: + self.type.names[name] = stnode + else: + self.globals[name] = stnode + def store_declared_types(self, lvalue: Lvalue, typ: Type) -> None: if isinstance(typ, StarType) and not isinstance(lvalue, StarExpr): self.fail('Star type only allowed for starred expressions', lvalue) diff --git a/mypy/specialtype.py b/mypy/specialtype.py index 3978e24714f27..c9023f7e6fcc9 100644 --- a/mypy/specialtype.py +++ b/mypy/specialtype.py @@ -1,6 +1,6 @@ from collections import OrderedDict -from typing import List, Dict, Tuple, cast, Optional +from typing import List, Dict, Tuple, cast, Optional, Callable from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.nodes import ( @@ -16,7 +16,6 @@ NoneTyp, CallableType, Instance, Type, TypeVarType, AnyType, TypeVarDef, TupleType, UnboundType, TypedDictType, TypeType, ) -from mypy.semanal import SemanticAnalyzer from mypy import join @@ -34,9 +33,8 @@ class Special: * process_declarations() * analyze_* """ - semanalyzer = None # type: SemanticAnalyzer - def __init__(self, semanalyzer: SemanticAnalyzer) -> None: + def __init__(self, semanalyzer: 'mypy.semanal.SemanticAnalyzer') -> None: self.semanalyzer = semanalyzer # Delegations: self.fail = semanalyzer.fail @@ -476,19 +474,25 @@ def process_typevar_parameters(self, variance = INVARIANT return (variance, upper_bound) - def process_namedtuple_definition(self, s: AssignmentStmt) -> None: - """Check if s defines a namedtuple; if yes, store the definition in symbol table.""" + def process_call(self, s: AssignmentStmt, check: Callable[[Expression, str], TypeInfo]): + """Check if s defines a legal node; if yes, store the definition in symbol table.""" if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): return lvalue = s.lvalues[0] name = lvalue.name - named_tuple = self.check_namedtuple(s.rvalue, name) - if named_tuple is None: + info = check(s.rvalue, name) + if info is None: return - # Yes, it's a valid namedtuple definition. Add it to the symbol table. + # Yes, it's a valid definition. Add it to the symbol table. node = self.lookup(name, s) node.kind = GDEF # TODO locally defined namedtuple - node.node = named_tuple + node.node = info + + def process_namedtuple_definition(self, s: AssignmentStmt) -> None: + self.process_call(s, self.check_namedtuple) + + def process_enum_call(self, s: AssignmentStmt) -> None: + self.process_call(s, self.check_enum_call) def check_namedtuple(self, expr: Expression, var_name: str = None) -> Optional[TypeInfo]: """Check if a call defines a namedtuple. @@ -519,13 +523,7 @@ def check_namedtuple(self, expr: Expression, var_name: str = None) -> Optional[T # Give it a unique name derived from the line number. name += '@' + str(call.line) info = self.build_namedtuple_typeinfo(name, items, types, {}) - # Store it as a global just in case it would remain anonymous. - # (Or in the nearest class if there is one.) - stnode = SymbolTableNode(GDEF, info, self.semanalyzer.cur_mod_id) - if self.semanalyzer.type: - self.semanalyzer.type.names[name] = stnode - else: - self.semanalyzer.globals[name] = stnode + self.semanalyzer.store_info(info, name) call.analyzed = NamedTupleExpr(info) call.analyzed.set_line(call.line, call.column) return info @@ -753,13 +751,7 @@ def check_typeddict(self, node: Expression, var_name: str = None) -> Optional[Ty # Give it a unique name derived from the line number. name += '@' + str(call.line) info = self.build_typeddict_typeinfo(name, items, types) - # Store it as a global just in case it would remain anonymous. - # (Or in the nearest class if there is one.) - stnode = SymbolTableNode(GDEF, info, self.semanalyzer.cur_mod_id) - if self.semanalyzer.type: - self.semanalyzer.type.names[name] = stnode - else: - self.semanalyzer.globals[name] = stnode + self.semanalyzer.store_info(info, name) call.analyzed = TypedDictExpr(info) call.analyzed.set_line(call.line, call.column) return info @@ -817,21 +809,6 @@ def build_typeddict_typeinfo(self, name: str, items: List[str], return info - def process_enum_call(self, s: AssignmentStmt) -> None: - """Check if s defines an Enum; if yes, store the definition in symbol table.""" - if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): - return - lvalue = s.lvalues[0] - name = lvalue.name - enum_call = self.check_enum_call(s.rvalue, name) - if enum_call is None: - return - # Yes, it's a valid Enum definition. Add it to the symbol table. - node = self.lookup(name, s) - if node: - node.kind = GDEF # TODO locally defined Enum - node.node = enum_call - def check_enum_call(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: """Check if a call defines an Enum. @@ -863,13 +840,7 @@ class A(enum.Enum): # Give it a unique name derived from the line number. name += '@' + str(call.line) info = self.build_enum_call_typeinfo(name, items, fullname) - # Store it as a global just in case it would remain anonymous. - # (Or in the nearest class if there is one.) - stnode = SymbolTableNode(GDEF, info, self.semanalyzer.cur_mod_id) - if self.semanalyzer.type: - self.semanalyzer.type.names[name] = stnode - else: - self.semanalyzer.globals[name] = stnode + self.semanalyzer.store_info(info, name) call.analyzed = EnumCallExpr(info, items, values) call.analyzed.set_line(call.line, call.column) return info From b58cae79743e477ca7a7e30636fd9806928a0260 Mon Sep 17 00:00:00 2001 From: elazar Date: Wed, 5 Apr 2017 04:47:05 +0300 Subject: [PATCH 03/18] lint --- mypy/specialtype.py | 37 ++++++++++++++++++++----------------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/mypy/specialtype.py b/mypy/specialtype.py index c9023f7e6fcc9..ecea4d3211215 100644 --- a/mypy/specialtype.py +++ b/mypy/specialtype.py @@ -28,8 +28,8 @@ class Special: Also handles analysis of special constructs: * Enum (functional style) * TypeVar - - The interface consists of + + The interface consists of * process_declarations() * analyze_* """ @@ -82,8 +82,8 @@ def check_namedtuple_classdef( if not isinstance(stmt, AssignmentStmt): # Still allow pass or ... (for empty namedtuples). if (not isinstance(stmt, PassStmt) and - not (isinstance(stmt, ExpressionStmt) and - isinstance(stmt.expr, EllipsisExpr))): + not (isinstance(stmt, ExpressionStmt) and + isinstance(stmt.expr, EllipsisExpr))): self.fail(NAMEDTUP_CLASS_ERROR, stmt) elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr): # An assignment, but an invalid one. @@ -92,7 +92,8 @@ def check_namedtuple_classdef( # Append name and type in this case... name = stmt.lvalues[0].name items.append(name) - types.append(AnyType() if stmt.type is None else self.semanalyzer.anal_type(stmt.type)) + types.append(AnyType() if stmt.type is None + else self.semanalyzer.anal_type(stmt.type)) # ...despite possible minor failures that allow further analyzis. if name.startswith('_'): self.fail('NamedTuple field name cannot start with an underscore: {}' @@ -192,7 +193,8 @@ def check_typeddict_classdef(self, defn: ClassDef, continue # Append name and type in this case... fields.append(name) - types.append(AnyType() if stmt.type is None else self.semanalyzer.anal_type(stmt.type)) + types.append(AnyType() if stmt.type is None + else self.semanalyzer.anal_type(stmt.type)) # ...despite possible minor failures that allow further analyzis. if stmt.type is None or hasattr(stmt, 'new_syntax') and not stmt.new_syntax: self.fail(TPDICT_CLASS_ERROR, stmt) @@ -496,12 +498,12 @@ def process_enum_call(self, s: AssignmentStmt) -> None: def check_namedtuple(self, expr: Expression, var_name: str = None) -> Optional[TypeInfo]: """Check if a call defines a namedtuple. - + The optional var_name argument is the name of the variable to which this is assigned, if any. - + If it does, return the corresponding TypeInfo. Return None otherwise. - + If the definition is invalid but looks like a namedtuple, report errors but return (some) TypeInfo. """ @@ -572,7 +574,8 @@ def parse_namedtuple_args(self, call: CallExpr, + ', '.join(underscore), call) return items, types, ok - def parse_namedtuple_fields_with_types(self, nodes: List[Expression]) -> Tuple[List[str], List[Type], bool]: + def parse_namedtuple_fields_with_types(self, nodes: List[Expression] + ) -> Tuple[List[str], List[Type], bool]: items = [] # type: List[str] types = [] # type: List[Type] for item in nodes: @@ -725,12 +728,12 @@ def process_typeddict_definition(self, s: AssignmentStmt) -> None: def check_typeddict(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: """Check if a call defines a TypedDict. - + The optional var_name argument is the name of the variable to which this is assigned, if any. - + If it does, return the corresponding TypeInfo. Return None otherwise. - + If the definition is invalid but looks like a TypedDict, report errors but return (some) TypeInfo. """ @@ -811,13 +814,13 @@ def build_typeddict_typeinfo(self, name: str, items: List[str], def check_enum_call(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: """Check if a call defines an Enum. - + Example: - + A = enum.Enum('A', 'foo bar') - + is equivalent to: - + class A(enum.Enum): foo = 1 bar = 2 From 2e2f2961e00c33591ee946683e3e99a6333152be Mon Sep 17 00:00:00 2001 From: elazar Date: Wed, 5 Apr 2017 08:48:16 +0300 Subject: [PATCH 04/18] more refactoring --- mypy/specialtype.py | 1837 +++++++++++++++++++++---------------------- 1 file changed, 912 insertions(+), 925 deletions(-) diff --git a/mypy/specialtype.py b/mypy/specialtype.py index ecea4d3211215..f021f164bcb36 100644 --- a/mypy/specialtype.py +++ b/mypy/specialtype.py @@ -1,925 +1,912 @@ -from collections import OrderedDict - -from typing import List, Dict, Tuple, cast, Optional, Callable - -from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError -from mypy.nodes import ( - TypeInfo, AssignmentStmt, FuncDef, ClassDef, Var, GDEF, Expression, - Block, NameExpr, TupleExpr, ListExpr, ExpressionStmt, PassStmt, - DictExpr, CallExpr, RefExpr, Context, SymbolTable, UNBOUND_TVAR, - MDEF, Decorator, TypeVarExpr, NewTypeExpr, StrExpr, BytesExpr, - ARG_POS, ARG_NAMED, ARG_NAMED_OPT, NamedTupleExpr, TypedDictExpr, Argument, - UnicodeExpr, EllipsisExpr, TempNode, EnumCallExpr, - COVARIANT, CONTRAVARIANT, INVARIANT, ARG_OPT, SymbolTableNode -) -from mypy.types import ( - NoneTyp, CallableType, Instance, Type, TypeVarType, AnyType, - TypeVarDef, TupleType, UnboundType, TypedDictType, TypeType, -) -from mypy import join - - -class Special: - """ - Groups special-cased types: - * NamedTuple - * TypedDict - * NewType - Also handles analysis of special constructs: - * Enum (functional style) - * TypeVar - - The interface consists of - * process_declarations() - * analyze_* - """ - - def __init__(self, semanalyzer: 'mypy.semanal.SemanticAnalyzer') -> None: - self.semanalyzer = semanalyzer - # Delegations: - self.fail = semanalyzer.fail - self.lookup = semanalyzer.lookup - self.lookup_qualified = semanalyzer.lookup_qualified - self.named_type = semanalyzer.named_type - self.named_type_or_none = semanalyzer.named_type_or_none - self.object_type = semanalyzer.object_type - self.str_type = semanalyzer.str_type - - def process_declaration(self, s: AssignmentStmt) -> None: - self.process_newtype_declaration(s) - self.process_typevar_declaration(s) - self.process_namedtuple_definition(s) - self.process_typeddict_definition(s) - self.process_enum_call(s) - - def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool: - for base_expr in defn.base_type_exprs: - if isinstance(base_expr, RefExpr): - base_expr.accept(self.semanalyzer) - if base_expr.fullname == 'typing.NamedTuple': - node = self.lookup(defn.name, defn) - if node is not None: - node.kind = GDEF # TODO in process_namedtuple_definition also applies here - items, types, default_items = self.check_namedtuple_classdef(defn) - node.node = self.build_namedtuple_typeinfo( - defn.name, items, types, default_items) - return True - return False - - def check_namedtuple_classdef( - self, defn: ClassDef) -> Tuple[List[str], List[Type], Dict[str, Expression]]: - NAMEDTUP_CLASS_ERROR = ('Invalid statement in NamedTuple definition; ' - 'expected "field_name: field_type"') - if self.semanalyzer.options.python_version < (3, 6): - self.fail('NamedTuple class syntax is only supported in Python 3.6', defn) - return [], [], {} - if len(defn.base_type_exprs) > 1: - self.fail('NamedTuple should be a single base', defn) - items = [] # type: List[str] - types = [] # type: List[Type] - default_items = {} # type: Dict[str, Expression] - for stmt in defn.defs.body: - if not isinstance(stmt, AssignmentStmt): - # Still allow pass or ... (for empty namedtuples). - if (not isinstance(stmt, PassStmt) and - not (isinstance(stmt, ExpressionStmt) and - isinstance(stmt.expr, EllipsisExpr))): - self.fail(NAMEDTUP_CLASS_ERROR, stmt) - elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr): - # An assignment, but an invalid one. - self.fail(NAMEDTUP_CLASS_ERROR, stmt) - else: - # Append name and type in this case... - name = stmt.lvalues[0].name - items.append(name) - types.append(AnyType() if stmt.type is None - else self.semanalyzer.anal_type(stmt.type)) - # ...despite possible minor failures that allow further analyzis. - if name.startswith('_'): - self.fail('NamedTuple field name cannot start with an underscore: {}' - .format(name), stmt) - if stmt.type is None or hasattr(stmt, 'new_syntax') and not stmt.new_syntax: - self.fail(NAMEDTUP_CLASS_ERROR, stmt) - elif isinstance(stmt.rvalue, TempNode): - # x: int assigns rvalue to TempNode(AnyType()) - if default_items: - self.fail('Non-default NamedTuple fields cannot follow default fields', - stmt) - else: - default_items[name] = stmt.rvalue - return items, types, default_items - - @staticmethod - def is_typeddict(expr: Expression) -> bool: - return (isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo) and - expr.node.typeddict_type is not None) - - def analyze_typeddict_classdef(self, defn: ClassDef) -> bool: - # special case for TypedDict - possible = False - for base_expr in defn.base_type_exprs: - if isinstance(base_expr, RefExpr): - base_expr.accept(self.semanalyzer) - if (base_expr.fullname == 'mypy_extensions.TypedDict' or - self.is_typeddict(base_expr)): - possible = True - if possible: - node = self.lookup(defn.name, defn) - if node is not None: - node.kind = GDEF # TODO in process_namedtuple_definition also applies here - if (len(defn.base_type_exprs) == 1 and - isinstance(defn.base_type_exprs[0], RefExpr) and - defn.base_type_exprs[0].fullname == 'mypy_extensions.TypedDict'): - # Building a new TypedDict - fields, types = self.check_typeddict_classdef(defn) - node.node = self.build_typeddict_typeinfo(defn.name, fields, types) - return True - # Extending/merging existing TypedDicts - if any(not isinstance(expr, RefExpr) or - expr.fullname != 'mypy_extensions.TypedDict' and - not self.is_typeddict(expr) for expr in defn.base_type_exprs): - self.fail("All bases of a new TypedDict must be TypedDict types", defn) - typeddict_bases = list(filter(self.is_typeddict, defn.base_type_exprs)) - newfields = [] # type: List[str] - newtypes = [] # type: List[Type] - tpdict = None # type: OrderedDict[str, Type] - for base in typeddict_bases: - assert isinstance(base, RefExpr) - assert isinstance(base.node, TypeInfo) - assert isinstance(base.node.typeddict_type, TypedDictType) - tpdict = base.node.typeddict_type.items - newdict = tpdict.copy() - for key in tpdict: - if key in newfields: - self.fail('Cannot overwrite TypedDict field "{}" while merging' - .format(key), defn) - newdict.pop(key) - newfields.extend(newdict.keys()) - newtypes.extend(newdict.values()) - fields, types = self.check_typeddict_classdef(defn, newfields) - newfields.extend(fields) - newtypes.extend(types) - node.node = self.build_typeddict_typeinfo(defn.name, newfields, newtypes) - return True - return False - - def check_typeddict_classdef(self, defn: ClassDef, - oldfields: List[str] = None) -> Tuple[List[str], List[Type]]: - TPDICT_CLASS_ERROR = ('Invalid statement in TypedDict definition; ' - 'expected "field_name: field_type"') - if self.semanalyzer.options.python_version < (3, 6): - self.fail('TypedDict class syntax is only supported in Python 3.6', defn) - return [], [] - fields = [] # type: List[str] - types = [] # type: List[Type] - for stmt in defn.defs.body: - if not isinstance(stmt, AssignmentStmt): - # Still allow pass or ... (for empty TypedDict's). - if (not isinstance(stmt, PassStmt) and - not (isinstance(stmt, ExpressionStmt) and - isinstance(stmt.expr, EllipsisExpr))): - self.fail(TPDICT_CLASS_ERROR, stmt) - elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr): - # An assignment, but an invalid one. - self.fail(TPDICT_CLASS_ERROR, stmt) - else: - name = stmt.lvalues[0].name - if name in (oldfields or []): - self.fail('Cannot overwrite TypedDict field "{}" while extending' - .format(name), stmt) - continue - if name in fields: - self.fail('Duplicate TypedDict field "{}"'.format(name), stmt) - continue - # Append name and type in this case... - fields.append(name) - types.append(AnyType() if stmt.type is None - else self.semanalyzer.anal_type(stmt.type)) - # ...despite possible minor failures that allow further analyzis. - if stmt.type is None or hasattr(stmt, 'new_syntax') and not stmt.new_syntax: - self.fail(TPDICT_CLASS_ERROR, stmt) - elif not isinstance(stmt.rvalue, TempNode): - # x: int assigns rvalue to TempNode(AnyType()) - self.fail('Right hand side values are not supported in TypedDict', stmt) - return fields, types - - def process_newtype_declaration(self, s: AssignmentStmt) -> None: - """Check if s declares a NewType; if yes, store it in symbol table.""" - # Extract and check all information from newtype declaration - name, call = self.analyze_newtype_declaration(s) - if name is None or call is None: - return - - old_type = self.check_newtype_args(name, call, s) - call.analyzed = NewTypeExpr(name, old_type, line=call.line) - if old_type is None: - return - - # Create the corresponding class definition if the aliased type is subtypeable - if isinstance(old_type, TupleType): - newtype_class_info = self.build_newtype_typeinfo(name, old_type, old_type.fallback) - newtype_class_info.tuple_type = old_type - elif isinstance(old_type, Instance): - newtype_class_info = self.build_newtype_typeinfo(name, old_type, old_type) - else: - message = "Argument 2 to NewType(...) must be subclassable (got {})" - self.fail(message.format(old_type), s) - return - - # If so, add it to the symbol table. - node = self.lookup(name, s) - if node is None: - self.fail("Could not find {} in current namespace".format(name), s) - return - # TODO: why does NewType work in local scopes despite always being of kind GDEF? - node.kind = GDEF - call.analyzed.info = node.node = newtype_class_info - - def analyze_newtype_declaration(self, s: AssignmentStmt - ) -> Tuple[Optional[str], Optional[CallExpr]]: - """Return the NewType call expression if `s` is a newtype declaration or None otherwise.""" - name, call = None, None - if (len(s.lvalues) == 1 - and isinstance(s.lvalues[0], NameExpr) - and isinstance(s.rvalue, CallExpr) - and isinstance(s.rvalue.callee, RefExpr) - and s.rvalue.callee.fullname == 'typing.NewType'): - lvalue = s.lvalues[0] - name = s.lvalues[0].name - if not lvalue.is_def: - if s.type: - self.fail("Cannot declare the type of a NewType declaration", s) - else: - self.fail("Cannot redefine '%s' as a NewType" % name, s) - - # This dummy NewTypeExpr marks the call as sufficiently analyzed; it will be - # overwritten later with a fully complete NewTypeExpr if there are no other - # errors with the NewType() call. - call = s.rvalue - - return name, call - - def check_newtype_args(self, name: str, call: CallExpr, context: Context) -> Optional[Type]: - has_failed = False - args, arg_kinds = call.args, call.arg_kinds - if len(args) != 2 or arg_kinds[0] != ARG_POS or arg_kinds[1] != ARG_POS: - self.fail("NewType(...) expects exactly two positional arguments", context) - return None - - # Check first argument - if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): - self.fail("Argument 1 to NewType(...) must be a string literal", context) - has_failed = True - elif args[0].value != name: - msg = "String argument 1 '{}' to NewType(...) does not match variable name '{}'" - self.fail(msg.format(args[0].value, name), context) - has_failed = True - - # Check second argument - try: - unanalyzed_type = expr_to_unanalyzed_type(args[1]) - except TypeTranslationError: - self.fail("Argument 2 to NewType(...) must be a valid type", context) - return None - old_type = self.semanalyzer.anal_type(unanalyzed_type) - - if isinstance(old_type, Instance) and old_type.type.is_newtype: - self.fail("Argument 2 to NewType(...) cannot be another NewType", context) - has_failed = True - - return None if has_failed else old_type - - def build_newtype_typeinfo(self, name: str, old_type: Type, base_type: Instance) -> TypeInfo: - info = self.basic_new_typeinfo(name, base_type) - info.is_newtype = True - - # Add __init__ method - args = [Argument(Var('cls'), NoneTyp(), None, ARG_POS), - self.make_argument('item', old_type)] - signature = CallableType( - arg_types=[cast(Type, None), old_type], - arg_kinds=[arg.kind for arg in args], - arg_names=['self', 'item'], - ret_type=old_type, - fallback=self.named_type('__builtins__.function'), - name=name) - init_func = FuncDef('__init__', args, Block([]), typ=signature) - init_func.info = info - info.names['__init__'] = SymbolTableNode(MDEF, init_func) - - return info - - def analyze_typevar_declaration(self, t: Type) -> Optional[List[Tuple[str, TypeVarExpr]]]: - if not isinstance(t, UnboundType): - return None - unbound = t - sym = self.lookup_qualified(unbound.name, unbound) - if sym is None or sym.node is None: - return None - if sym.node.fullname() == 'typing.Generic': - tvars = [] # type: List[Tuple[str, TypeVarExpr]] - for arg in unbound.args: - tvar = self.semanalyzer.analyze_unbound_tvar(arg) - if tvar: - tvars.append(tvar) - else: - self.fail('Free type variable expected in %s[...]' % - sym.node.name(), t) - return tvars - return None - - def analyze_types(self, items: List[Expression]) -> List[Type]: - result = [] # type: List[Type] - for node in items: - try: - result.append(self.semanalyzer.anal_type(expr_to_unanalyzed_type(node))) - except TypeTranslationError: - self.fail('Type expected', node) - result.append(AnyType()) - return result - - def process_typevar_declaration(self, s: AssignmentStmt) -> None: - """Check if s declares a TypeVar; it yes, store it in symbol table.""" - call = self.get_typevar_declaration(s) - if not call: - return - - lvalue = s.lvalues[0] - assert isinstance(lvalue, NameExpr) - name = lvalue.name - if not lvalue.is_def: - if s.type: - self.fail("Cannot declare the type of a type variable", s) - else: - self.fail("Cannot redefine '%s' as a type variable" % name, s) - return - - if not self.check_typevar_name(call, name, s): - return - - # Constraining types - n_values = call.arg_kinds[1:].count(ARG_POS) - values = self.analyze_types(call.args[1:1 + n_values]) - - res = self.process_typevar_parameters(call.args[1 + n_values:], - call.arg_names[1 + n_values:], - call.arg_kinds[1 + n_values:], - n_values, - s) - if res is None: - return - variance, upper_bound = res - - # Yes, it's a valid type variable definition! Add it to the symbol table. - node = self.lookup(name, s) - node.kind = UNBOUND_TVAR - TypeVar = TypeVarExpr(name, node.fullname, values, upper_bound, variance) - TypeVar.line = call.line - call.analyzed = TypeVar - node.node = TypeVar - - def check_typevar_name(self, call: CallExpr, name: str, context: Context) -> bool: - if len(call.args) < 1: - self.fail("Too few arguments for TypeVar()", context) - return False - if (not isinstance(call.args[0], (StrExpr, BytesExpr, UnicodeExpr)) - or not call.arg_kinds[0] == ARG_POS): - self.fail("TypeVar() expects a string literal as first argument", context) - return False - elif call.args[0].value != name: - msg = "String argument 1 '{}' to TypeVar(...) does not match variable name '{}'" - self.fail(msg.format(call.args[0].value, name), context) - return False - return True - - @staticmethod - def get_typevar_declaration(s: AssignmentStmt) -> Optional[CallExpr]: - """Returns the TypeVar() call expression if `s` is a type var declaration - or None otherwise. - """ - if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): - return None - if not isinstance(s.rvalue, CallExpr): - return None - call = s.rvalue - callee = call.callee - if not isinstance(callee, RefExpr): - return None - if callee.fullname != 'typing.TypeVar': - return None - return call - - def process_typevar_parameters(self, - args: List[Expression], - names: List[Optional[str]], - kinds: List[int], - num_values: int, - context: Context) -> Optional[Tuple[int, Type]]: - has_values = (num_values > 0) - covariant = False - contravariant = False - upper_bound = self.object_type() # type: Type - for param_value, param_name, param_kind in zip(args, names, kinds): - if not param_kind == ARG_NAMED: - self.fail("Unexpected argument to TypeVar()", context) - return None - if param_name == 'covariant': - if isinstance(param_value, NameExpr): - if param_value.name == 'True': - covariant = True - else: - self.fail("TypeVar 'covariant' may only be 'True'", context) - return None - else: - self.fail("TypeVar 'covariant' may only be 'True'", context) - return None - elif param_name == 'contravariant': - if isinstance(param_value, NameExpr): - if param_value.name == 'True': - contravariant = True - else: - self.fail("TypeVar 'contravariant' may only be 'True'", context) - return None - else: - self.fail("TypeVar 'contravariant' may only be 'True'", context) - return None - elif param_name == 'bound': - if has_values: - self.fail("TypeVar cannot have both values and an upper bound", context) - return None - try: - upper_bound = self.semanalyzer.expr_to_analyzed_type(param_value) - except TypeTranslationError: - self.fail("TypeVar 'bound' must be a type", param_value) - return None - elif param_name == 'values': - # Probably using obsolete syntax with values=(...). Explain the current syntax. - self.fail("TypeVar 'values' argument not supported", context) - self.fail("Use TypeVar('T', t, ...) instead of TypeVar('T', values=(t, ...))", - context) - return None - else: - self.fail("Unexpected argument to TypeVar(): {}".format(param_name), context) - return None - - if covariant and contravariant: - self.fail("TypeVar cannot be both covariant and contravariant", context) - return None - elif num_values == 1: - self.fail("TypeVar cannot have only a single constraint", context) - return None - elif covariant: - variance = COVARIANT - elif contravariant: - variance = CONTRAVARIANT - else: - variance = INVARIANT - return (variance, upper_bound) - - def process_call(self, s: AssignmentStmt, check: Callable[[Expression, str], TypeInfo]): - """Check if s defines a legal node; if yes, store the definition in symbol table.""" - if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): - return - lvalue = s.lvalues[0] - name = lvalue.name - info = check(s.rvalue, name) - if info is None: - return - # Yes, it's a valid definition. Add it to the symbol table. - node = self.lookup(name, s) - node.kind = GDEF # TODO locally defined namedtuple - node.node = info - - def process_namedtuple_definition(self, s: AssignmentStmt) -> None: - self.process_call(s, self.check_namedtuple) - - def process_enum_call(self, s: AssignmentStmt) -> None: - self.process_call(s, self.check_enum_call) - - def check_namedtuple(self, expr: Expression, var_name: str = None) -> Optional[TypeInfo]: - """Check if a call defines a namedtuple. - - The optional var_name argument is the name of the variable to - which this is assigned, if any. - - If it does, return the corresponding TypeInfo. Return None otherwise. - - If the definition is invalid but looks like a namedtuple, - report errors but return (some) TypeInfo. - """ - if not isinstance(expr, CallExpr): - return None - call = expr - callee = call.callee - if not isinstance(callee, RefExpr): - return None - fullname = callee.fullname - if fullname not in ('collections.namedtuple', 'typing.NamedTuple'): - return None - items, types, ok = self.parse_namedtuple_args(call, fullname) - if not ok: - # Error. Construct dummy return value. - return self.build_namedtuple_typeinfo('namedtuple', [], [], {}) - name = cast(StrExpr, call.args[0]).value - if name != var_name or self.semanalyzer.is_func_scope(): - # Give it a unique name derived from the line number. - name += '@' + str(call.line) - info = self.build_namedtuple_typeinfo(name, items, types, {}) - self.semanalyzer.store_info(info, name) - call.analyzed = NamedTupleExpr(info) - call.analyzed.set_line(call.line, call.column) - return info - - def parse_namedtuple_args(self, call: CallExpr, - fullname: str) -> Tuple[List[str], List[Type], bool]: - # TODO: Share code with check_argument_count in checkexpr.py? - args = call.args - if len(args) < 2: - return self.fail_namedtuple_arg("Too few arguments for namedtuple()", call) - if len(args) > 2: - # FIX incorrect. There are two additional parameters - return self.fail_namedtuple_arg("Too many arguments for namedtuple()", call) - if call.arg_kinds != [ARG_POS, ARG_POS]: - return self.fail_namedtuple_arg("Unexpected arguments to namedtuple()", call) - if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): - return self.fail_namedtuple_arg( - "namedtuple() expects a string literal as the first argument", call) - types = [] # type: List[Type] - ok = True - if not isinstance(args[1], (ListExpr, TupleExpr)): - if (fullname == 'collections.namedtuple' - and isinstance(args[1], (StrExpr, BytesExpr, UnicodeExpr))): - str_expr = cast(StrExpr, args[1]) - items = str_expr.value.replace(',', ' ').split() - else: - return self.fail_namedtuple_arg( - "List or tuple literal expected as the second argument to namedtuple()", call) - else: - listexpr = args[1] - if fullname == 'collections.namedtuple': - # The fields argument contains just names, with implicit Any types. - if any(not isinstance(item, (StrExpr, BytesExpr, UnicodeExpr)) - for item in listexpr.items): - return self.fail_namedtuple_arg("String literal expected as namedtuple() item", - call) - items = [cast(StrExpr, item).value for item in listexpr.items] - else: - # The fields argument contains (name, type) tuples. - items, types, ok = self.parse_namedtuple_fields_with_types(listexpr.items) - if not types: - types = [AnyType() for _ in items] - underscore = [item for item in items if item.startswith('_')] - if underscore: - self.fail("namedtuple() field names cannot start with an underscore: " - + ', '.join(underscore), call) - return items, types, ok - - def parse_namedtuple_fields_with_types(self, nodes: List[Expression] - ) -> Tuple[List[str], List[Type], bool]: - items = [] # type: List[str] - types = [] # type: List[Type] - for item in nodes: - if isinstance(item, TupleExpr): - if len(item.items) != 2: - return self.fail_namedtuple_arg("Invalid NamedTuple field definition", - item) - name, type_node = item.items - if isinstance(name, (StrExpr, BytesExpr, UnicodeExpr)): - items.append(name.value) - else: - return self.fail_namedtuple_arg("Invalid NamedTuple() field name", item) - try: - type = expr_to_unanalyzed_type(type_node) - except TypeTranslationError: - return self.fail_namedtuple_arg('Invalid field type', type_node) - types.append(self.semanalyzer.anal_type(type)) - else: - return self.fail_namedtuple_arg("Tuple expected as NamedTuple() field", item) - return items, types, True - - def fail_namedtuple_arg(self, message: str, - context: Context) -> Tuple[List[str], List[Type], bool]: - self.fail(message, context) - return [], [], False - - def basic_new_typeinfo(self, name: str, basetype_or_fallback: Instance) -> TypeInfo: - class_def = ClassDef(name, Block([])) - class_def.fullname = self.semanalyzer.qualified_name(name) - - info = TypeInfo(SymbolTable(), class_def, self.semanalyzer.cur_mod_id) - info.mro = [info] + basetype_or_fallback.type.mro - info.bases = [basetype_or_fallback] - return info - - def analyze_callexpr_as_type(self, call: CallExpr) -> Optional[Type]: - info = self.check_namedtuple(call) - if info is None: - # Some form of namedtuple is the only valid type that looks like a call - # expression. This isn't a valid type. - return None - fallback = Instance(info, []) - return TupleType(info.tuple_type.items, fallback=fallback) - - def build_namedtuple_typeinfo(self, name: str, items: List[str], types: List[Type], - default_items: Dict[str, Expression]) -> TypeInfo: - strtype = self.str_type() - object_type = self.object_type() - basetuple_type = self.named_type('__builtins__.tuple', [AnyType()]) - dictype = (self.named_type_or_none('builtins.dict', [strtype, AnyType()]) - or object_type) - # Actual signature should return OrderedDict[str, Union[types]] - ordereddictype = (self.named_type_or_none('builtins.dict', [strtype, AnyType()]) - or object_type) - fallback = self.named_type('__builtins__.tuple', types) - # Note: actual signature should accept an invariant version of Iterable[UnionType[types]]. - # but it can't be expressed. 'new' and 'len' should be callable types. - iterable_type = self.named_type_or_none('typing.Iterable', [AnyType()]) - function_type = self.named_type('__builtins__.function') - - info = self.basic_new_typeinfo(name, fallback) - info.is_named_tuple = True - info.tuple_type = TupleType(types, fallback) - - def add_field(var: Var, is_initialized_in_class: bool = False, - is_property: bool = False) -> None: - var.info = info - var.is_initialized_in_class = is_initialized_in_class - var.is_property = is_property - info.names[var.name()] = SymbolTableNode(MDEF, var) - - vars = [Var(item, typ) for item, typ in zip(items, types)] - for var in vars: - add_field(var, is_property=True) - - tuple_of_strings = TupleType([strtype for _ in items], basetuple_type) - add_field(Var('_fields', tuple_of_strings), is_initialized_in_class=True) - add_field(Var('_field_types', dictype), is_initialized_in_class=True) - add_field(Var('_field_defaults', dictype), is_initialized_in_class=True) - add_field(Var('_source', strtype), is_initialized_in_class=True) - - tvd = TypeVarDef('NT', 1, [], info.tuple_type) - selftype = TypeVarType(tvd) - - def add_method(funcname: str, - ret: Type, - args: List[Argument], - name: str = None, - is_classmethod: bool = False, - ) -> None: - if is_classmethod: - first = [Argument(Var('cls'), TypeType(selftype), None, ARG_POS)] - else: - first = [Argument(Var('self'), selftype, None, ARG_POS)] - args = first + args - - types = [arg.type_annotation for arg in args] - items = [arg.variable.name() for arg in args] - arg_kinds = [arg.kind for arg in args] - signature = CallableType(types, arg_kinds, items, ret, function_type, - name=name or info.name() + '.' + funcname) - signature.variables = [tvd] - func = FuncDef(funcname, args, Block([]), typ=signature) - func.info = info - func.is_class = is_classmethod - if is_classmethod: - v = Var(funcname, signature) - v.is_classmethod = True - v.info = info - dec = Decorator(func, [NameExpr('classmethod')], v) - info.names[funcname] = SymbolTableNode(MDEF, dec) - else: - info.names[funcname] = SymbolTableNode(MDEF, func) - - add_method('_replace', ret=selftype, - args=[Argument(var, var.type, EllipsisExpr(), ARG_NAMED_OPT) for var in vars]) - - def make_init_arg(var: Var) -> Argument: - default = default_items.get(var.name(), None) - kind = ARG_POS if default is None else ARG_OPT - return Argument(var, var.type, default, kind) - - add_method('__init__', ret=NoneTyp(), name=info.name(), - args=[make_init_arg(var) for var in vars]) - add_method('_asdict', args=[], ret=ordereddictype) - add_method('_make', ret=selftype, is_classmethod=True, - args=[Argument(Var('iterable', iterable_type), iterable_type, None, ARG_POS), - Argument(Var('new'), AnyType(), EllipsisExpr(), ARG_NAMED_OPT), - Argument(Var('len'), AnyType(), EllipsisExpr(), ARG_NAMED_OPT)]) - return info - - @staticmethod - def make_argument(name: str, type: Type) -> Argument: - return Argument(Var(name), type, None, ARG_POS) - - def process_typeddict_definition(self, s: AssignmentStmt) -> None: - """Check if s defines a TypedDict; if yes, store the definition in symbol table.""" - if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): - return - lvalue = s.lvalues[0] - name = lvalue.name - typed_dict = self.check_typeddict(s.rvalue, name) - if typed_dict is None: - return - # Yes, it's a valid TypedDict definition. Add it to the symbol table. - node = self.lookup(name, s) - if node: - node.kind = GDEF # TODO locally defined TypedDict - node.node = typed_dict - - def check_typeddict(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: - """Check if a call defines a TypedDict. - - The optional var_name argument is the name of the variable to - which this is assigned, if any. - - If it does, return the corresponding TypeInfo. Return None otherwise. - - If the definition is invalid but looks like a TypedDict, - report errors but return (some) TypeInfo. - """ - if not isinstance(node, CallExpr): - return None - call = node - callee = call.callee - if not isinstance(callee, RefExpr): - return None - if callee.fullname != 'mypy_extensions.TypedDict': - return None - items, types, ok = self.parse_typeddict_args(call) - if not ok: - # Error. Construct dummy return value. - return self.build_typeddict_typeinfo('TypedDict', [], []) - name = cast(StrExpr, call.args[0]).value - if name != var_name or self.semanalyzer.is_func_scope(): - # Give it a unique name derived from the line number. - name += '@' + str(call.line) - info = self.build_typeddict_typeinfo(name, items, types) - self.semanalyzer.store_info(info, name) - call.analyzed = TypedDictExpr(info) - call.analyzed.set_line(call.line, call.column) - return info - - def parse_typeddict_args(self, call: CallExpr) -> Tuple[List[str], List[Type], bool]: - # TODO: Share code with check_argument_count in checkexpr.py? - args = call.args - if len(args) < 2: - return self.fail_typeddict_arg("Too few arguments for TypedDict()", call) - if len(args) > 2: - return self.fail_typeddict_arg("Too many arguments for TypedDict()", call) - # TODO: Support keyword arguments - if call.arg_kinds != [ARG_POS, ARG_POS]: - return self.fail_typeddict_arg("Unexpected arguments to TypedDict()", call) - if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): - return self.fail_typeddict_arg( - "TypedDict() expects a string literal as the first argument", call) - if not isinstance(args[1], DictExpr): - return self.fail_typeddict_arg( - "TypedDict() expects a dictionary literal as the second argument", call) - dictexpr = args[1] - items, types, ok = self.parse_typeddict_fields_with_types(dictexpr.items) - return items, types, ok - - def parse_typeddict_fields_with_types(self, dict_items: List[Tuple[Expression, Expression]], - ) -> Tuple[List[str], List[Type], bool]: - items = [] # type: List[str] - types = [] # type: List[Type] - for (field_name_expr, field_type_expr) in dict_items: - if isinstance(field_name_expr, (StrExpr, BytesExpr, UnicodeExpr)): - items.append(field_name_expr.value) - else: - return self.fail_typeddict_arg("Invalid TypedDict() field name", field_name_expr) - try: - type = expr_to_unanalyzed_type(field_type_expr) - except TypeTranslationError: - return self.fail_typeddict_arg('Invalid field type', field_type_expr) - types.append(self.semanalyzer.anal_type(type)) - return items, types, True - - def fail_typeddict_arg(self, message: str, - context: Context) -> Tuple[List[str], List[Type], bool]: - self.fail(message, context) - return [], [], False - - def build_typeddict_typeinfo(self, name: str, items: List[str], - types: List[Type]) -> TypeInfo: - mapping_value_type = join.join_type_list(types) - fallback = (self.named_type_or_none('typing.Mapping', - [self.str_type(), mapping_value_type]) - or self.object_type()) - - info = self.basic_new_typeinfo(name, fallback) - info.typeddict_type = TypedDictType(OrderedDict(zip(items, types)), fallback) - - return info - - def check_enum_call(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: - """Check if a call defines an Enum. - - Example: - - A = enum.Enum('A', 'foo bar') - - is equivalent to: - - class A(enum.Enum): - foo = 1 - bar = 2 - """ - if not isinstance(node, CallExpr): - return None - call = node - callee = call.callee - if not isinstance(callee, RefExpr): - return None - fullname = callee.fullname - if fullname not in ('enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag'): - return None - items, values, ok = self.parse_enum_call_args(call, fullname.split('.')[-1]) - if not ok: - # Error. Construct dummy return value. - return self.build_enum_call_typeinfo('Enum', [], fullname) - name = cast(StrExpr, call.args[0]).value - if name != var_name or self.semanalyzer.is_func_scope(): - # Give it a unique name derived from the line number. - name += '@' + str(call.line) - info = self.build_enum_call_typeinfo(name, items, fullname) - self.semanalyzer.store_info(info, name) - call.analyzed = EnumCallExpr(info, items, values) - call.analyzed.set_line(call.line, call.column) - return info - - def build_enum_call_typeinfo(self, name: str, items: List[str], fullname: str) -> TypeInfo: - base = self.named_type_or_none(fullname) - assert base is not None - info = self.basic_new_typeinfo(name, base) - info.is_enum = True - for item in items: - var = Var(item) - var.info = info - var.is_property = True - info.names[item] = SymbolTableNode(MDEF, var) - return info - - def parse_enum_call_args(self, call: CallExpr, - class_name: str) -> Tuple[List[str], - List[Optional[Expression]], bool]: - args = call.args - if len(args) < 2: - return self.fail_enum_call_arg("Too few arguments for %s()" % class_name, call) - if len(args) > 2: - return self.fail_enum_call_arg("Too many arguments for %s()" % class_name, call) - if call.arg_kinds != [ARG_POS, ARG_POS]: - return self.fail_enum_call_arg("Unexpected arguments to %s()" % class_name, call) - if not isinstance(args[0], (StrExpr, UnicodeExpr)): - return self.fail_enum_call_arg( - "%s() expects a string literal as the first argument" % class_name, call) - items = [] - values = [] # type: List[Optional[Expression]] - if isinstance(args[1], (StrExpr, UnicodeExpr)): - fields = args[1].value - for field in fields.replace(',', ' ').split(): - items.append(field) - elif isinstance(args[1], (TupleExpr, ListExpr)): - seq_items = args[1].items - if all(isinstance(seq_item, (StrExpr, UnicodeExpr)) for seq_item in seq_items): - items = [cast(StrExpr, seq_item).value for seq_item in seq_items] - elif all(isinstance(seq_item, (TupleExpr, ListExpr)) - and len(seq_item.items) == 2 - and isinstance(seq_item.items[0], (StrExpr, UnicodeExpr)) - for seq_item in seq_items): - for seq_item in seq_items: - assert isinstance(seq_item, (TupleExpr, ListExpr)) - name, value = seq_item.items - assert isinstance(name, (StrExpr, UnicodeExpr)) - items.append(name.value) - values.append(value) - else: - return self.fail_enum_call_arg( - "%s() with tuple or list expects strings or (name, value) pairs" % - class_name, - call) - elif isinstance(args[1], DictExpr): - for key, value in args[1].items: - if not isinstance(key, (StrExpr, UnicodeExpr)): - return self.fail_enum_call_arg( - "%s() with dict literal requires string literals" % class_name, call) - items.append(key.value) - values.append(value) - else: - # TODO: Allow dict(x=1, y=2) as a substitute for {'x': 1, 'y': 2}? - return self.fail_enum_call_arg( - "%s() expects a string, tuple, list or dict literal as the second argument" % - class_name, - call) - if len(items) == 0: - return self.fail_enum_call_arg("%s() needs at least one item" % class_name, call) - if not values: - values = [None] * len(items) - assert len(items) == len(values) - return items, values, True - - def fail_enum_call_arg(self, message: str, - context: Context) -> Tuple[List[str], - List[Optional[Expression]], bool]: - self.fail(message, context) - return [], [], False +from collections import OrderedDict + +from typing import List, Dict, Tuple, cast, Optional, Callable, TYPE_CHECKING + +from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError +from mypy.nodes import ( + TypeInfo, AssignmentStmt, FuncDef, ClassDef, Var, GDEF, Expression, + Block, NameExpr, TupleExpr, ListExpr, ExpressionStmt, PassStmt, + DictExpr, CallExpr, RefExpr, Context, SymbolTable, UNBOUND_TVAR, + MDEF, Decorator, TypeVarExpr, NewTypeExpr, StrExpr, BytesExpr, + ARG_POS, ARG_NAMED, ARG_NAMED_OPT, NamedTupleExpr, TypedDictExpr, Argument, + UnicodeExpr, EllipsisExpr, TempNode, EnumCallExpr, + COVARIANT, CONTRAVARIANT, INVARIANT, ARG_OPT, SymbolTableNode +) +from mypy.types import ( + NoneTyp, CallableType, Instance, Type, TypeVarType, AnyType, + TypeVarDef, TupleType, UnboundType, TypedDictType, TypeType, +) +from mypy import join +if TYPE_CHECKING: + import mypy.semanal + + +class Special: + """ + Groups special-cased types: + * NamedTuple + * TypedDict + * NewType + Also handles analysis of special constructs: + * Enum (functional style) + * TypeVar + + The interface consists of + * process_declarations() + * analyze_* + """ + + def __init__(self, semanalyzer: 'mypy.semanal.SemanticAnalyzer') -> None: + self.semanalyzer = semanalyzer + # Delegations: + self.fail = semanalyzer.fail + self.lookup = semanalyzer.lookup + self.lookup_qualified = semanalyzer.lookup_qualified + self.named_type = semanalyzer.named_type + self.named_type_or_none = semanalyzer.named_type_or_none + self.object_type = semanalyzer.object_type + self.str_type = semanalyzer.str_type + + def process_declaration(self, s: AssignmentStmt) -> None: + self.process_newtype_declaration(s) + self.process_typevar_declaration(s) + self.process_namedtuple_definition(s) + self.process_typeddict_definition(s) + self.process_enum_call(s) + + def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool: + for base_expr in defn.base_type_exprs: + if isinstance(base_expr, RefExpr): + base_expr.accept(self.semanalyzer) + if base_expr.fullname == 'typing.NamedTuple': + node = self.lookup(defn.name, defn) + if node is not None: + node.kind = GDEF # TODO in process_namedtuple_definition also applies here + items, types, default_items = self.check_namedtuple_classdef(defn) + node.node = self.build_namedtuple_typeinfo( + defn.name, items, types, default_items) + return True + return False + + def check_namedtuple_classdef( + self, defn: ClassDef) -> Tuple[List[str], List[Type], Dict[str, Expression]]: + NAMEDTUP_CLASS_ERROR = ('Invalid statement in NamedTuple definition; ' + 'expected "field_name: field_type"') + if self.semanalyzer.options.python_version < (3, 6): + self.fail('NamedTuple class syntax is only supported in Python 3.6', defn) + return [], [], {} + if len(defn.base_type_exprs) > 1: + self.fail('NamedTuple should be a single base', defn) + items = [] # type: List[str] + types = [] # type: List[Type] + default_items = {} # type: Dict[str, Expression] + for stmt in defn.defs.body: + if not isinstance(stmt, AssignmentStmt): + # Still allow pass or ... (for empty namedtuples). + if (not isinstance(stmt, PassStmt) and + not (isinstance(stmt, ExpressionStmt) and + isinstance(stmt.expr, EllipsisExpr))): + self.fail(NAMEDTUP_CLASS_ERROR, stmt) + elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr): + # An assignment, but an invalid one. + self.fail(NAMEDTUP_CLASS_ERROR, stmt) + else: + # Append name and type in this case... + name = stmt.lvalues[0].name + items.append(name) + types.append(AnyType() if stmt.type is None + else self.semanalyzer.anal_type(stmt.type)) + # ...despite possible minor failures that allow further analyzis. + if name.startswith('_'): + self.fail('NamedTuple field name cannot start with an underscore: {}' + .format(name), stmt) + if stmt.type is None or hasattr(stmt, 'new_syntax') and not stmt.new_syntax: + self.fail(NAMEDTUP_CLASS_ERROR, stmt) + elif isinstance(stmt.rvalue, TempNode): + # x: int assigns rvalue to TempNode(AnyType()) + if default_items: + self.fail('Non-default NamedTuple fields cannot follow default fields', + stmt) + else: + default_items[name] = stmt.rvalue + return items, types, default_items + + @staticmethod + def is_typeddict(expr: Expression) -> bool: + return (isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo) and + expr.node.typeddict_type is not None) + + def analyze_typeddict_classdef(self, defn: ClassDef) -> bool: + # special case for TypedDict + possible = False + for base_expr in defn.base_type_exprs: + if isinstance(base_expr, RefExpr): + base_expr.accept(self.semanalyzer) + if (base_expr.fullname == 'mypy_extensions.TypedDict' or + self.is_typeddict(base_expr)): + possible = True + if possible: + node = self.lookup(defn.name, defn) + if node is not None: + node.kind = GDEF # TODO in process_namedtuple_definition also applies here + if (len(defn.base_type_exprs) == 1 and + isinstance(defn.base_type_exprs[0], RefExpr) and + defn.base_type_exprs[0].fullname == 'mypy_extensions.TypedDict'): + # Building a new TypedDict + fields, types = self.check_typeddict_classdef(defn) + node.node = self.build_typeddict_typeinfo(defn.name, fields, types) + return True + # Extending/merging existing TypedDicts + if any(not isinstance(expr, RefExpr) or + expr.fullname != 'mypy_extensions.TypedDict' and + not self.is_typeddict(expr) for expr in defn.base_type_exprs): + self.fail("All bases of a new TypedDict must be TypedDict types", defn) + typeddict_bases = list(filter(self.is_typeddict, defn.base_type_exprs)) + newfields = [] # type: List[str] + newtypes = [] # type: List[Type] + tpdict = None # type: OrderedDict[str, Type] + for base in typeddict_bases: + assert isinstance(base, RefExpr) + assert isinstance(base.node, TypeInfo) + assert isinstance(base.node.typeddict_type, TypedDictType) + tpdict = base.node.typeddict_type.items + newdict = tpdict.copy() + for key in tpdict: + if key in newfields: + self.fail('Cannot overwrite TypedDict field "{}" while merging' + .format(key), defn) + newdict.pop(key) + newfields.extend(newdict.keys()) + newtypes.extend(newdict.values()) + fields, types = self.check_typeddict_classdef(defn, newfields) + newfields.extend(fields) + newtypes.extend(types) + node.node = self.build_typeddict_typeinfo(defn.name, newfields, newtypes) + return True + return False + + def check_typeddict_classdef(self, defn: ClassDef, + oldfields: List[str] = None) -> Tuple[List[str], List[Type]]: + TPDICT_CLASS_ERROR = ('Invalid statement in TypedDict definition; ' + 'expected "field_name: field_type"') + if self.semanalyzer.options.python_version < (3, 6): + self.fail('TypedDict class syntax is only supported in Python 3.6', defn) + return [], [] + fields = [] # type: List[str] + types = [] # type: List[Type] + for stmt in defn.defs.body: + if not isinstance(stmt, AssignmentStmt): + # Still allow pass or ... (for empty TypedDict's). + if (not isinstance(stmt, PassStmt) and + not (isinstance(stmt, ExpressionStmt) and + isinstance(stmt.expr, EllipsisExpr))): + self.fail(TPDICT_CLASS_ERROR, stmt) + elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr): + # An assignment, but an invalid one. + self.fail(TPDICT_CLASS_ERROR, stmt) + else: + name = stmt.lvalues[0].name + if name in (oldfields or []): + self.fail('Cannot overwrite TypedDict field "{}" while extending' + .format(name), stmt) + continue + if name in fields: + self.fail('Duplicate TypedDict field "{}"'.format(name), stmt) + continue + # Append name and type in this case... + fields.append(name) + types.append(AnyType() if stmt.type is None + else self.semanalyzer.anal_type(stmt.type)) + # ...despite possible minor failures that allow further analyzis. + if stmt.type is None or hasattr(stmt, 'new_syntax') and not stmt.new_syntax: + self.fail(TPDICT_CLASS_ERROR, stmt) + elif not isinstance(stmt.rvalue, TempNode): + # x: int assigns rvalue to TempNode(AnyType()) + self.fail('Right hand side values are not supported in TypedDict', stmt) + return fields, types + + def process_newtype_declaration(self, s: AssignmentStmt) -> None: + """Check if s declares a NewType; if yes, store it in symbol table.""" + # Extract and check all information from newtype declaration + name, call = self.analyze_newtype_declaration(s) + if name is None or call is None: + return + + old_type = self.check_newtype_args(name, call, s) + call.analyzed = NewTypeExpr(name, old_type, line=call.line) + if old_type is None: + return + + # Create the corresponding class definition if the aliased type is subtypeable + if isinstance(old_type, TupleType): + newtype_class_info = self.build_newtype_typeinfo(name, old_type, old_type.fallback) + newtype_class_info.tuple_type = old_type + elif isinstance(old_type, Instance): + newtype_class_info = self.build_newtype_typeinfo(name, old_type, old_type) + else: + message = "Argument 2 to NewType(...) must be subclassable (got {})" + self.fail(message.format(old_type), s) + return + + # If so, add it to the symbol table. + node = self.lookup(name, s) + if node is None: + self.fail("Could not find {} in current namespace".format(name), s) + return + # TODO: why does NewType work in local scopes despite always being of kind GDEF? + node.kind = GDEF + call.analyzed.info = node.node = newtype_class_info + + def analyze_newtype_declaration(self, s: AssignmentStmt + ) -> Tuple[Optional[str], Optional[CallExpr]]: + """Return the NewType call expression if `s` is a newtype declaration or None otherwise.""" + name, call = None, None + if (len(s.lvalues) == 1 + and isinstance(s.lvalues[0], NameExpr) + and isinstance(s.rvalue, CallExpr) + and isinstance(s.rvalue.callee, RefExpr) + and s.rvalue.callee.fullname == 'typing.NewType'): + lvalue = s.lvalues[0] + name = s.lvalues[0].name + if not lvalue.is_def: + if s.type: + self.fail("Cannot declare the type of a NewType declaration", s) + else: + self.fail("Cannot redefine '%s' as a NewType" % name, s) + + # This dummy NewTypeExpr marks the call as sufficiently analyzed; it will be + # overwritten later with a fully complete NewTypeExpr if there are no other + # errors with the NewType() call. + call = s.rvalue + + return name, call + + def check_newtype_args(self, name: str, call: CallExpr, context: Context) -> Optional[Type]: + has_failed = False + args, arg_kinds = call.args, call.arg_kinds + if len(args) != 2 or arg_kinds[0] != ARG_POS or arg_kinds[1] != ARG_POS: + self.fail("NewType(...) expects exactly two positional arguments", context) + return None + + # Check first argument + if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): + self.fail("Argument 1 to NewType(...) must be a string literal", context) + has_failed = True + elif args[0].value != name: + msg = "String argument 1 '{}' to NewType(...) does not match variable name '{}'" + self.fail(msg.format(args[0].value, name), context) + has_failed = True + + # Check second argument + try: + unanalyzed_type = expr_to_unanalyzed_type(args[1]) + except TypeTranslationError: + self.fail("Argument 2 to NewType(...) must be a valid type", context) + return None + old_type = self.semanalyzer.anal_type(unanalyzed_type) + + if isinstance(old_type, Instance) and old_type.type.is_newtype: + self.fail("Argument 2 to NewType(...) cannot be another NewType", context) + has_failed = True + + return None if has_failed else old_type + + def build_newtype_typeinfo(self, name: str, old_type: Type, base_type: Instance) -> TypeInfo: + info = self.basic_new_typeinfo(name, base_type) + info.is_newtype = True + + # Add __init__ method + args = [Argument(Var('cls'), NoneTyp(), None, ARG_POS), + self.make_argument('item', old_type)] + signature = CallableType( + arg_types=[cast(Type, None), old_type], + arg_kinds=[arg.kind for arg in args], + arg_names=['self', 'item'], + ret_type=old_type, + fallback=self.named_type('__builtins__.function'), + name=name) + init_func = FuncDef('__init__', args, Block([]), typ=signature) + init_func.info = info + info.names['__init__'] = SymbolTableNode(MDEF, init_func) + + return info + + def analyze_typevar_declaration(self, t: Type) -> Optional[List[Tuple[str, TypeVarExpr]]]: + if not isinstance(t, UnboundType): + return None + unbound = t + sym = self.lookup_qualified(unbound.name, unbound) + if sym is None or sym.node is None: + return None + if sym.node.fullname() == 'typing.Generic': + tvars = [] # type: List[Tuple[str, TypeVarExpr]] + for arg in unbound.args: + tvar = self.semanalyzer.analyze_unbound_tvar(arg) + if tvar: + tvars.append(tvar) + else: + self.fail('Free type variable expected in %s[...]' % + sym.node.name(), t) + return tvars + return None + + def analyze_types(self, items: List[Expression]) -> List[Type]: + result = [] # type: List[Type] + for node in items: + try: + result.append(self.semanalyzer.anal_type(expr_to_unanalyzed_type(node))) + except TypeTranslationError: + self.fail('Type expected', node) + result.append(AnyType()) + return result + + def process_typevar_declaration(self, s: AssignmentStmt) -> None: + """Check if s declares a TypeVar; it yes, store it in symbol table.""" + call = self.get_typevar_declaration(s) + if not call: + return + + lvalue = s.lvalues[0] + assert isinstance(lvalue, NameExpr) + name = lvalue.name + if not lvalue.is_def: + if s.type: + self.fail("Cannot declare the type of a type variable", s) + else: + self.fail("Cannot redefine '%s' as a type variable" % name, s) + return + + if not self.check_typevar_name(call, name, s): + return + + # Constraining types + n_values = call.arg_kinds[1:].count(ARG_POS) + values = self.analyze_types(call.args[1:1 + n_values]) + + res = self.process_typevar_parameters(call.args[1 + n_values:], + call.arg_names[1 + n_values:], + call.arg_kinds[1 + n_values:], + n_values, + s) + if res is None: + return + variance, upper_bound = res + + # Yes, it's a valid type variable definition! Add it to the symbol table. + node = self.lookup(name, s) + node.kind = UNBOUND_TVAR + TypeVar = TypeVarExpr(name, node.fullname, values, upper_bound, variance) + TypeVar.line = call.line + call.analyzed = TypeVar + node.node = TypeVar + + def check_typevar_name(self, call: CallExpr, name: str, context: Context) -> bool: + if len(call.args) < 1: + self.fail("Too few arguments for TypeVar()", context) + return False + if (not isinstance(call.args[0], (StrExpr, BytesExpr, UnicodeExpr)) + or not call.arg_kinds[0] == ARG_POS): + self.fail("TypeVar() expects a string literal as first argument", context) + return False + elif call.args[0].value != name: + msg = "String argument 1 '{}' to TypeVar(...) does not match variable name '{}'" + self.fail(msg.format(call.args[0].value, name), context) + return False + return True + + @staticmethod + def get_typevar_declaration(s: AssignmentStmt) -> Optional[CallExpr]: + """Returns the TypeVar() call expression if `s` is a type var declaration + or None otherwise. + """ + if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): + return None + if not isinstance(s.rvalue, CallExpr): + return None + call = s.rvalue + callee = call.callee + if not isinstance(callee, RefExpr): + return None + if callee.fullname != 'typing.TypeVar': + return None + return call + + def process_typevar_parameters(self, + args: List[Expression], + names: List[Optional[str]], + kinds: List[int], + num_values: int, + context: Context) -> Optional[Tuple[int, Type]]: + has_values = (num_values > 0) + covariant = False + contravariant = False + upper_bound = self.object_type() # type: Type + for param_value, param_name, param_kind in zip(args, names, kinds): + if not param_kind == ARG_NAMED: + self.fail("Unexpected argument to TypeVar()", context) + return None + if param_name == 'covariant': + if isinstance(param_value, NameExpr): + if param_value.name == 'True': + covariant = True + else: + self.fail("TypeVar 'covariant' may only be 'True'", context) + return None + else: + self.fail("TypeVar 'covariant' may only be 'True'", context) + return None + elif param_name == 'contravariant': + if isinstance(param_value, NameExpr): + if param_value.name == 'True': + contravariant = True + else: + self.fail("TypeVar 'contravariant' may only be 'True'", context) + return None + else: + self.fail("TypeVar 'contravariant' may only be 'True'", context) + return None + elif param_name == 'bound': + if has_values: + self.fail("TypeVar cannot have both values and an upper bound", context) + return None + try: + upper_bound = self.semanalyzer.expr_to_analyzed_type(param_value) + except TypeTranslationError: + self.fail("TypeVar 'bound' must be a type", param_value) + return None + elif param_name == 'values': + # Probably using obsolete syntax with values=(...). Explain the current syntax. + self.fail("TypeVar 'values' argument not supported", context) + self.fail("Use TypeVar('T', t, ...) instead of TypeVar('T', values=(t, ...))", + context) + return None + else: + self.fail("Unexpected argument to TypeVar(): {}".format(param_name), context) + return None + + if covariant and contravariant: + self.fail("TypeVar cannot be both covariant and contravariant", context) + return None + elif num_values == 1: + self.fail("TypeVar cannot have only a single constraint", context) + return None + elif covariant: + variance = COVARIANT + elif contravariant: + variance = CONTRAVARIANT + else: + variance = INVARIANT + return (variance, upper_bound) + + def process_call(self, s: AssignmentStmt, + check: Callable[[Expression, str], TypeInfo]) -> None: + """Check if s defines a legal node; if yes, store the definition in symbol table.""" + if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): + return + lvalue = s.lvalues[0] + name = lvalue.name + info = check(s.rvalue, name) + if info is None: + return + # Yes, it's a valid definition. Add it to the symbol table. + node = self.lookup(name, s) + node.kind = GDEF # TODO locally defined type + node.node = info + + def process_namedtuple_definition(self, s: AssignmentStmt) -> None: + self.process_call(s, self.check_namedtuple) + + def process_enum_call(self, s: AssignmentStmt) -> None: + self.process_call(s, self.check_enum_call) + + def check_namedtuple(self, expr: Expression, var_name: str = None) -> Optional[TypeInfo]: + """Check if a call defines a namedtuple. + + The optional var_name argument is the name of the variable to + which this is assigned, if any. + + If it does, return the corresponding TypeInfo. Return None otherwise. + + If the definition is invalid but looks like a namedtuple, + report errors but return (some) TypeInfo. + """ + + call, calleename, name = self.get_call(expr, var_name) + if calleename not in ('collections.namedtuple', 'typing.NamedTuple'): + return None + items, types, ok = self.parse_namedtuple_args(call, calleename) + info = self.build_namedtuple_typeinfo(name, items, types) + if ok: + self.semanalyzer.store_info(info, name) + call.analyzed = NamedTupleExpr(info) + call.analyzed.set_line(call.line, call.column) + return info + + def parse_namedtuple_args(self, call: CallExpr, + fullname: str) -> Tuple[List[str], List[Type], bool]: + # TODO: Share code with check_argument_count in checkexpr.py? + args = call.args + if len(args) < 2: + return self.fail_namedtuple_arg("Too few arguments for namedtuple()", call) + if len(args) > 2: + # FIX incorrect. There are two additional parameters + return self.fail_namedtuple_arg("Too many arguments for namedtuple()", call) + if call.arg_kinds != [ARG_POS, ARG_POS]: + return self.fail_namedtuple_arg("Unexpected arguments to namedtuple()", call) + if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): + return self.fail_namedtuple_arg( + "namedtuple() expects a string literal as the first argument", call) + types = [] # type: List[Type] + ok = True + if not isinstance(args[1], (ListExpr, TupleExpr)): + if (fullname == 'collections.namedtuple' + and isinstance(args[1], (StrExpr, BytesExpr, UnicodeExpr))): + str_expr = cast(StrExpr, args[1]) + items = str_expr.value.replace(',', ' ').split() + else: + return self.fail_namedtuple_arg( + "List or tuple literal expected as the second argument to namedtuple()", call) + else: + listexpr = args[1] + if fullname == 'collections.namedtuple': + # The fields argument contains just names, with implicit Any types. + if any(not isinstance(item, (StrExpr, BytesExpr, UnicodeExpr)) + for item in listexpr.items): + return self.fail_namedtuple_arg("String literal expected as namedtuple() item", + call) + items = [cast(StrExpr, item).value for item in listexpr.items] + else: + # The fields argument contains (name, type) tuples. + items, types, ok = self.parse_namedtuple_fields_with_types(listexpr.items) + if not types: + types = [AnyType() for _ in items] + underscore = [item for item in items if item.startswith('_')] + if underscore: + self.fail("namedtuple() field names cannot start with an underscore: " + + ', '.join(underscore), call) + return items, types, ok + + def parse_namedtuple_fields_with_types(self, nodes: List[Expression] + ) -> Tuple[List[str], List[Type], bool]: + items = [] # type: List[str] + types = [] # type: List[Type] + for item in nodes: + if isinstance(item, TupleExpr): + if len(item.items) != 2: + return self.fail_namedtuple_arg("Invalid NamedTuple field definition", + item) + name, type_node = item.items + if isinstance(name, (StrExpr, BytesExpr, UnicodeExpr)): + items.append(name.value) + else: + return self.fail_namedtuple_arg("Invalid NamedTuple() field name", item) + try: + type = expr_to_unanalyzed_type(type_node) + except TypeTranslationError: + return self.fail_namedtuple_arg('Invalid field type', type_node) + types.append(self.semanalyzer.anal_type(type)) + else: + return self.fail_namedtuple_arg("Tuple expected as NamedTuple() field", item) + return items, types, True + + def fail_namedtuple_arg(self, message: str, + context: Context) -> Tuple[List[str], List[Type], bool]: + self.fail(message, context) + return [], [], False + + def basic_new_typeinfo(self, name: str, basetype_or_fallback: Instance) -> TypeInfo: + class_def = ClassDef(name, Block([])) + class_def.fullname = self.semanalyzer.qualified_name(name) + + info = TypeInfo(SymbolTable(), class_def, self.semanalyzer.cur_mod_id) + info.mro = [info] + basetype_or_fallback.type.mro + info.bases = [basetype_or_fallback] + return info + + def analyze_callexpr_as_type(self, call: CallExpr) -> Optional[Type]: + info = self.check_namedtuple(call) + if info is None: + # Some form of namedtuple is the only valid type that looks like a call + # expression. This isn't a valid type. + return None + fallback = Instance(info, []) + return TupleType(info.tuple_type.items, fallback=fallback) + + def build_namedtuple_typeinfo(self, name: str, items: List[str], types: List[Type], + default_items: Dict[str, Expression] = None) -> TypeInfo: + default_items = default_items or {} + strtype = self.str_type() + object_type = self.object_type() + basetuple_type = self.named_type('__builtins__.tuple', [AnyType()]) + dictype = (self.named_type_or_none('builtins.dict', [strtype, AnyType()]) + or object_type) + # Actual signature should return OrderedDict[str, Union[types]] + ordereddictype = (self.named_type_or_none('builtins.dict', [strtype, AnyType()]) + or object_type) + fallback = self.named_type('__builtins__.tuple', types) + # Note: actual signature should accept an invariant version of Iterable[UnionType[types]]. + # but it can't be expressed. 'new' and 'len' should be callable types. + iterable_type = self.named_type_or_none('typing.Iterable', [AnyType()]) + function_type = self.named_type('__builtins__.function') + + info = self.basic_new_typeinfo(name, fallback) + info.is_named_tuple = True + info.tuple_type = TupleType(types, fallback) + + def add_field(var: Var, is_initialized_in_class: bool = False, + is_property: bool = False) -> None: + var.info = info + var.is_initialized_in_class = is_initialized_in_class + var.is_property = is_property + info.names[var.name()] = SymbolTableNode(MDEF, var) + + vars = [Var(item, typ) for item, typ in zip(items, types)] + for var in vars: + add_field(var, is_property=True) + + tuple_of_strings = TupleType([strtype for _ in items], basetuple_type) + add_field(Var('_fields', tuple_of_strings), is_initialized_in_class=True) + add_field(Var('_field_types', dictype), is_initialized_in_class=True) + add_field(Var('_field_defaults', dictype), is_initialized_in_class=True) + add_field(Var('_source', strtype), is_initialized_in_class=True) + + tvd = TypeVarDef('NT', 1, [], info.tuple_type) + selftype = TypeVarType(tvd) + + def add_method(funcname: str, + ret: Type, + args: List[Argument], + name: str = None, + is_classmethod: bool = False, + ) -> None: + if is_classmethod: + first = [Argument(Var('cls'), TypeType(selftype), None, ARG_POS)] + else: + first = [Argument(Var('self'), selftype, None, ARG_POS)] + args = first + args + + types = [arg.type_annotation for arg in args] + items = [arg.variable.name() for arg in args] + arg_kinds = [arg.kind for arg in args] + signature = CallableType(types, arg_kinds, items, ret, function_type, + name=name or info.name() + '.' + funcname) + signature.variables = [tvd] + func = FuncDef(funcname, args, Block([]), typ=signature) + func.info = info + func.is_class = is_classmethod + if is_classmethod: + v = Var(funcname, signature) + v.is_classmethod = True + v.info = info + dec = Decorator(func, [NameExpr('classmethod')], v) + info.names[funcname] = SymbolTableNode(MDEF, dec) + else: + info.names[funcname] = SymbolTableNode(MDEF, func) + + add_method('_replace', ret=selftype, + args=[Argument(var, var.type, EllipsisExpr(), ARG_NAMED_OPT) for var in vars]) + + def make_init_arg(var: Var) -> Argument: + default = default_items.get(var.name(), None) + kind = ARG_POS if default is None else ARG_OPT + return Argument(var, var.type, default, kind) + + add_method('__init__', ret=NoneTyp(), name=info.name(), + args=[make_init_arg(var) for var in vars]) + add_method('_asdict', args=[], ret=ordereddictype) + add_method('_make', ret=selftype, is_classmethod=True, + args=[Argument(Var('iterable', iterable_type), iterable_type, None, ARG_POS), + Argument(Var('new'), AnyType(), EllipsisExpr(), ARG_NAMED_OPT), + Argument(Var('len'), AnyType(), EllipsisExpr(), ARG_NAMED_OPT)]) + return info + + @staticmethod + def make_argument(name: str, type: Type) -> Argument: + return Argument(Var(name), type, None, ARG_POS) + + def process_typeddict_definition(self, s: AssignmentStmt) -> None: + """Check if s defines a TypedDict; if yes, store the definition in symbol table.""" + if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): + return + lvalue = s.lvalues[0] + name = lvalue.name + typed_dict = self.check_typeddict(s.rvalue, name) + if typed_dict is None: + return + # Yes, it's a valid TypedDict definition. Add it to the symbol table. + node = self.lookup(name, s) + if node: + node.kind = GDEF # TODO locally defined TypedDict + node.node = typed_dict + + def check_typeddict(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: + """Check if a call defines a TypedDict. + + The optional var_name argument is the name of the variable to + which this is assigned, if any. + + If it does, return the corresponding TypeInfo. Return None otherwise. + + If the definition is invalid but looks like a TypedDict, + report errors but return (some) TypeInfo. + """ + call, calleename, name = self.get_call(node, var_name) + if calleename != 'mypy_extensions.TypedDict': + return None + items, types, ok = self.parse_typeddict_args(call) + info = self.build_typeddict_typeinfo(name, items, types) + if ok: + self.semanalyzer.store_info(info, name) + call.analyzed = TypedDictExpr(info) + call.analyzed.set_line(call.line, call.column) + return info + + def get_call(self, expr: Expression, var_name: str) -> Tuple[CallExpr, str, str]: + (call, calleename, name) = None, '', '' + if isinstance(expr, CallExpr): + call = expr + callee = call.callee + if isinstance(callee, RefExpr): + calleename = callee.fullname + if len(call.args) > 0: + name = getattr(call.args[0], 'value', var_name) + if isinstance(name, str): + if name != var_name or self.semanalyzer.is_func_scope(): + # Give it a unique name derived from the line number. + name += '@' + str(call.line) + else: + name = var_name + return (call, calleename, name) + + def parse_typeddict_args(self, call: CallExpr) -> Tuple[List[str], List[Type], bool]: + # TODO: Share code with check_argument_count in checkexpr.py? + args = call.args + if len(args) < 2: + return self.fail_typeddict_arg("Too few arguments for TypedDict()", call) + if len(args) > 2: + return self.fail_typeddict_arg("Too many arguments for TypedDict()", call) + # TODO: Support keyword arguments + if call.arg_kinds != [ARG_POS, ARG_POS]: + return self.fail_typeddict_arg("Unexpected arguments to TypedDict()", call) + if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): + return self.fail_typeddict_arg( + "TypedDict() expects a string literal as the first argument", call) + if not isinstance(args[1], DictExpr): + return self.fail_typeddict_arg( + "TypedDict() expects a dictionary literal as the second argument", call) + dictexpr = args[1] + items, types, ok = self.parse_typeddict_fields_with_types(dictexpr.items) + return items, types, ok + + def parse_typeddict_fields_with_types(self, dict_items: List[Tuple[Expression, Expression]], + ) -> Tuple[List[str], List[Type], bool]: + items = [] # type: List[str] + types = [] # type: List[Type] + for (field_name_expr, field_type_expr) in dict_items: + if isinstance(field_name_expr, (StrExpr, BytesExpr, UnicodeExpr)): + items.append(field_name_expr.value) + else: + return self.fail_typeddict_arg("Invalid TypedDict() field name", field_name_expr) + try: + type = expr_to_unanalyzed_type(field_type_expr) + except TypeTranslationError: + return self.fail_typeddict_arg('Invalid field type', field_type_expr) + types.append(self.semanalyzer.anal_type(type)) + return items, types, True + + def fail_typeddict_arg(self, message: str, + context: Context) -> Tuple[List[str], List[Type], bool]: + self.fail(message, context) + return [], [], False + + def build_typeddict_typeinfo(self, name: str, items: List[str], + types: List[Type]) -> TypeInfo: + mapping_value_type = join.join_type_list(types) + fallback = (self.named_type_or_none('typing.Mapping', + [self.str_type(), mapping_value_type]) + or self.object_type()) + + info = self.basic_new_typeinfo(name, fallback) + info.typeddict_type = TypedDictType(OrderedDict(zip(items, types)), fallback) + + return info + + def check_enum_call(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: + """Check if a call defines an Enum. + + Example: + + A = enum.Enum('A', 'foo bar') + + is equivalent to: + + class A(enum.Enum): + foo = 1 + bar = 2 + """ + call, calleename, name = self.get_call(node, var_name) + if calleename not in ('enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag'): + return None + items, values, ok = self.parse_enum_call_args(call, calleename.split('.')[-1]) + info = self.build_enum_call_typeinfo(name, items, calleename) + if ok: + self.semanalyzer.store_info(info, name) + call.analyzed = EnumCallExpr(info, items, values) + call.analyzed.set_line(call.line, call.column) + return info + + def build_enum_call_typeinfo(self, name: str, items: List[str], fullname: str) -> TypeInfo: + base = self.named_type_or_none(fullname) + assert base is not None + info = self.basic_new_typeinfo(name, base) + info.is_enum = True + for item in items: + var = Var(item) + var.info = info + var.is_property = True + info.names[item] = SymbolTableNode(MDEF, var) + return info + + def parse_enum_call_args(self, call: CallExpr, + class_name: str) -> Tuple[List[str], + List[Optional[Expression]], bool]: + args = call.args + if len(args) < 2: + return self.fail_enum_call_arg("Too few arguments for %s()" % class_name, call) + if len(args) > 2: + return self.fail_enum_call_arg("Too many arguments for %s()" % class_name, call) + if call.arg_kinds != [ARG_POS, ARG_POS]: + return self.fail_enum_call_arg("Unexpected arguments to %s()" % class_name, call) + if not isinstance(args[0], (StrExpr, UnicodeExpr)): + return self.fail_enum_call_arg( + "%s() expects a string literal as the first argument" % class_name, call) + items = [] + values = [] # type: List[Optional[Expression]] + if isinstance(args[1], (StrExpr, UnicodeExpr)): + fields = args[1].value + for field in fields.replace(',', ' ').split(): + items.append(field) + elif isinstance(args[1], (TupleExpr, ListExpr)): + seq_items = args[1].items + if all(isinstance(seq_item, (StrExpr, UnicodeExpr)) for seq_item in seq_items): + items = [cast(StrExpr, seq_item).value for seq_item in seq_items] + elif all(isinstance(seq_item, (TupleExpr, ListExpr)) + and len(seq_item.items) == 2 + and isinstance(seq_item.items[0], (StrExpr, UnicodeExpr)) + for seq_item in seq_items): + for seq_item in seq_items: + assert isinstance(seq_item, (TupleExpr, ListExpr)) + name, value = seq_item.items + assert isinstance(name, (StrExpr, UnicodeExpr)) + items.append(name.value) + values.append(value) + else: + return self.fail_enum_call_arg( + "%s() with tuple or list expects strings or (name, value) pairs" % + class_name, + call) + elif isinstance(args[1], DictExpr): + for key, value in args[1].items: + if not isinstance(key, (StrExpr, UnicodeExpr)): + return self.fail_enum_call_arg( + "%s() with dict literal requires string literals" % class_name, call) + items.append(key.value) + values.append(value) + else: + # TODO: Allow dict(x=1, y=2) as a substitute for {'x': 1, 'y': 2}? + return self.fail_enum_call_arg( + "%s() expects a string, tuple, list or dict literal as the second argument" % + class_name, + call) + if len(items) == 0: + return self.fail_enum_call_arg("%s() needs at least one item" % class_name, call) + if not values: + values = [None] * len(items) + assert len(items) == len(values) + return items, values, True + + def fail_enum_call_arg(self, message: str, + context: Context) -> Tuple[List[str], + List[Optional[Expression]], bool]: + self.fail(message, context) + return [], [], False From 47dafec0f1b2bdeab36c61164c825d094d264663 Mon Sep 17 00:00:00 2001 From: elazar Date: Wed, 5 Apr 2017 17:54:40 +0300 Subject: [PATCH 05/18] remove staticmethod and process_typevar_declaration --- mypy/nodes.py | 57 +++++++++++++ mypy/semanal.py | 190 +++++++++++++++----------------------------- mypy/specialtype.py | 60 +++++--------- 3 files changed, 142 insertions(+), 165 deletions(-) diff --git a/mypy/nodes.py b/mypy/nodes.py index 4584245b9904d..758bb163d263e 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -561,6 +561,21 @@ def name(self) -> str: def accept(self, visitor: StatementVisitor[T]) -> T: return visitor.visit_func_def(self) + def set_original_def(self, previous: Node) -> bool: + """If 'self' conditionally redefine 'previous', set 'previous' as original + + We reject straight redefinitions of functions, as they are usually + a programming error. For example: + + . def f(): ... + . def f(): ... # Error: 'f' redefined + """ + if isinstance(previous, (FuncDef, Var)) and self.is_conditional: + self.original_def = previous + return True + else: + return False + def serialize(self) -> JsonDict: # We're deliberating omitting arguments and storing only arg_names and # arg_kinds for space-saving reasons (arguments is not used in later @@ -2056,6 +2071,33 @@ def get_method(self, name: str) -> FuncBase: return None return None + def calculate_abstract_status(self) -> None: + """Calculate abstract status of a class. + + Set is_abstract of the type to True if the type has an unimplemented + abstract attribute. Also compute a list of abstract attributes. + """ + concrete = set() # type: Set[str] + abstract = [] # type: List[str] + for base in self.mro: + for name, symnode in base.names.items(): + node = symnode.node + if isinstance(node, OverloadedFuncDef): + # Unwrap an overloaded function definition. We can just + # check arbitrarily the first overload item. If the + # different items have a different abstract status, there + # should be an error reported elsewhere. + func = node.items[0] # type: Node + else: + func = node + if isinstance(func, Decorator): + fdef = func.func + if fdef.is_abstract and name not in concrete: + self.is_abstract = True + abstract.append(name) + concrete.add(name) + self.abstract_attributes = sorted(abstract) + def calculate_mro(self) -> None: """Calculate and set mro (method resolution order). @@ -2116,6 +2158,21 @@ def direct_base_classes(self) -> 'List[TypeInfo]': """ return [base.type for base in self.bases] + def is_base_class(self, s: 'TypeInfo') -> bool: + """Determine if self is a base class of s (but do not use mro).""" + # Search the base class graph for t, starting from s. + worklist = [s] + visited = {s} + while worklist: + nxt = worklist.pop() + if nxt == self: + return True + for base in nxt.bases: + if base.type not in visited: + worklist.append(base.type) + visited.add(base.type) + return False + def __str__(self) -> str: """Return a string representation of the type. diff --git a/mypy/semanal.py b/mypy/semanal.py index a69fe1a415d64..ce85f72e4b333 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -293,7 +293,7 @@ def visit_func_def(self, defn: FuncDef) -> None: if defn.name() in self.type.names: # Redefinition. Conditional redefinition is okay. n = self.type.names[defn.name()].node - if not self.set_original_def(n, defn): + if not defn.set_original_def(n): self.name_already_defined(defn.name(), defn) self.type.names[defn.name()] = SymbolTableNode(MDEF, defn) self.prepare_method_signature(defn) @@ -303,7 +303,7 @@ def visit_func_def(self, defn: FuncDef) -> None: if defn.name() in self.locals[-1]: # Redefinition. Conditional redefinition is okay. n = self.locals[-1][defn.name()].node - if not self.set_original_def(n, defn): + if not defn.set_original_def(n): self.name_already_defined(defn.name(), defn) else: self.add_local(defn, defn) @@ -313,7 +313,7 @@ def visit_func_def(self, defn: FuncDef) -> None: symbol = self.globals.get(defn.name()) if isinstance(symbol.node, FuncDef) and symbol.node != defn: # This is redefinition. Conditional redefinition is okay. - if not self.set_original_def(symbol.node, defn): + if not defn.set_original_def(symbol.node): # Report error. self.check_no_global(defn.name(), defn, True) if phase_info == FUNCTION_FIRST_PHASE_POSTPONE_SECOND: @@ -352,22 +352,6 @@ def prepare_method_signature(self, func: FuncDef) -> None: leading_type = fill_typevars(self.type) func.type = replace_implicit_first_type(functype, leading_type) - @staticmethod - def set_original_def(previous: Node, new: FuncDef) -> bool: - """If 'new' conditionally redefine 'previous', set 'previous' as original - - We reject straight redefinitions of functions, as they are usually - a programming error. For example: - - . def f(): ... - . def f(): ... # Error: 'f' redefined - """ - if isinstance(previous, (FuncDef, Var)) and new.is_conditional: - new.original_def = previous - return True - else: - return False - def update_function_type_variables(self, defn: FuncDef) -> None: """Make any type variables in the signature of defn explicit. @@ -678,7 +662,7 @@ def visit_class_def(self, defn: ClassDef) -> None: # Analyze class body. defn.defs.accept(self) - self.calculate_abstract_status(defn.info) + defn.info.calculate_abstract_status() self.setup_type_promotion(defn) self.leave_class() @@ -721,34 +705,6 @@ def unbind_class_type_vars(self) -> None: def analyze_class_decorator(self, defn: ClassDef, decorator: Expression) -> None: decorator.accept(self) - @staticmethod - def calculate_abstract_status(typ: TypeInfo) -> None: - """Calculate abstract status of a class. - - Set is_abstract of the type to True if the type has an unimplemented - abstract attribute. Also compute a list of abstract attributes. - """ - concrete = set() # type: Set[str] - abstract = [] # type: List[str] - for base in typ.mro: - for name, symnode in base.names.items(): - node = symnode.node - if isinstance(node, OverloadedFuncDef): - # Unwrap an overloaded function definition. We can just - # check arbitrarily the first overload item. If the - # different items have a different abstract status, there - # should be an error reported elsewhere. - func = node.items[0] # type: Node - else: - func = node - if isinstance(func, Decorator): - fdef = func.func - if fdef.is_abstract and name not in concrete: - typ.is_abstract = True - abstract.append(name) - concrete.add(name) - typ.abstract_attributes = sorted(abstract) - def setup_type_promotion(self, defn: ClassDef) -> None: """Setup extra, ad-hoc subtyping relationships between classes (promotion). @@ -798,13 +754,13 @@ def clean_up_bases_and_infer_type_variables(self, defn: ClassDef) -> None: all_tvars = self.get_all_bases_tvars(defn, removed) if declared_tvars: - if len(self.remove_dups(declared_tvars)) < len(declared_tvars): + if len(remove_dups(declared_tvars)) < len(declared_tvars): self.fail("Duplicate type variables in Generic[...]", defn) - declared_tvars = self.remove_dups(declared_tvars) + declared_tvars = remove_dups(declared_tvars) if not set(all_tvars).issubset(set(declared_tvars)): self.fail("If Generic[...] is present it should list all type variables", defn) # In case of error, Generic tvars will go first - declared_tvars = self.remove_dups(declared_tvars + all_tvars) + declared_tvars = remove_dups(declared_tvars + all_tvars) else: declared_tvars = all_tvars for j, (name, tvar_expr) in enumerate(declared_tvars): @@ -838,7 +794,7 @@ def get_all_bases_tvars(self, defn: ClassDef, # This error will be caught later. continue tvars.extend(self.get_tvars(base)) - return self.remove_dups(tvars) + return remove_dups(tvars) def get_tvars(self, tp: Type) -> List[Tuple[str, TypeVarExpr]]: tvars = [] # type: List[Tuple[str, TypeVarExpr]] @@ -854,18 +810,7 @@ def get_tvars(self, tp: Type) -> List[Tuple[str, TypeVarExpr]]: tvars.append(tvar) else: tvars.extend(self.get_tvars(arg)) - return self.remove_dups(tvars) - - @staticmethod - def remove_dups(tvars: List[T]) -> List[T]: - # Get unique elements in order of appearance - all_tvars = set(tvars) - new_tvars = [] # type: List[T] - for t in tvars: - if t in all_tvars: - new_tvars.append(t) - all_tvars.remove(t) - return new_tvars + return remove_dups(tvars) def setup_class_def_analysis(self, defn: ClassDef) -> None: """Prepare for the analysis of a class definition.""" @@ -959,7 +904,7 @@ def verify_base_classes(self, defn: ClassDef) -> bool: info = defn.info for base in info.bases: baseinfo = base.type - if self.is_base_class(info, baseinfo): + if info.is_base_class(baseinfo): self.fail('Cycle in inheritance hierarchy', defn, blocker=True) # Clear bases to forcefully get rid of the cycle. info.bases = [] @@ -973,22 +918,6 @@ def verify_base_classes(self, defn: ClassDef) -> bool: return False return True - @staticmethod - def is_base_class(t: TypeInfo, s: TypeInfo) -> bool: - """Determine if t is a base class of s (but do not use mro).""" - # Search the base class graph for t, starting from s. - worklist = [s] - visited = {s} - while worklist: - nxt = worklist.pop() - if nxt == t: - return True - for base in nxt.bases: - if base.type not in visited: - worklist.append(base.type) - visited.add(base.type) - return False - def analyze_metaclass(self, defn: ClassDef) -> None: error_context = defn # type: Context if defn.metaclass is None and self.options.python_version[0] == 2: @@ -1156,7 +1085,7 @@ def visit_import_from(self, imp: ImportFrom) -> None: existing_symbol = self.globals.get(imported_id) if existing_symbol: # Import can redefine a variable. They get special treatment. - if self.process_import_over_existing_name( + if process_import_over_existing_name( imported_id, existing_symbol, node, imp): continue # 'from m import x as x' exports x in a stub file. @@ -1170,7 +1099,7 @@ def visit_import_from(self, imp: ImportFrom) -> None: elif module and not missing: # Missing attribute. message = "Module '{}' has no attribute '{}'".format(import_id, id) - extra = self.undefined_name_extra_info('{}.{}'.format(import_id, id)) + extra = undefined_name_extra_info('{}.{}'.format(import_id, id)) if extra: message += " {}".format(extra) self.fail(message, imp) @@ -1178,27 +1107,6 @@ def visit_import_from(self, imp: ImportFrom) -> None: # Missing module. self.add_unknown_symbol(as_id or id, imp, is_import=True) - @staticmethod - def process_import_over_existing_name(imported_id: str, existing_symbol: SymbolTableNode, - module_symbol: SymbolTableNode, - import_node: ImportBase) -> bool: - if (existing_symbol.kind in (LDEF, GDEF, MDEF) and - isinstance(existing_symbol.node, (Var, FuncDef, TypeInfo))): - # This is a valid import over an existing definition in the file. Construct a dummy - # assignment that we'll use to type check the import. - lvalue = NameExpr(imported_id) - lvalue.kind = existing_symbol.kind - lvalue.node = existing_symbol.node - rvalue = NameExpr(imported_id) - rvalue.kind = module_symbol.kind - rvalue.node = module_symbol.node - assignment = AssignmentStmt([lvalue], rvalue) - for node in assignment, lvalue, rvalue: - node.set_line(import_node) - import_node.assignments.append(assignment) - return True - return False - def normalize_type_alias(self, node: SymbolTableNode, ctx: Context) -> SymbolTableNode: normalized = False @@ -1244,7 +1152,7 @@ def visit_import_all(self, i: ImportAll) -> None: existing_symbol = self.globals.get(name) if existing_symbol: # Import can redefine a variable. They get special treatment. - if self.process_import_over_existing_name( + if process_import_over_existing_name( name, existing_symbol, node, i): continue self.add_symbol(name, SymbolTableNode(node.kind, node.node, @@ -1455,7 +1363,7 @@ def analyze_lvalue(self, lval: Lvalue, nested: bool = False, elif isinstance(lval, MemberExpr): if not add_global: self.analyze_member_lvalue(lval) - if explicit_type and not self.is_self_member_ref(lval): + if explicit_type and not is_self_member_ref(lval): self.fail('Type cannot be declared in assignment to non-self ' 'attribute', lval) elif isinstance(lval, IndexExpr): @@ -1495,7 +1403,7 @@ def analyze_tuple_or_list_lvalue(self, lval: Union[ListExpr, TupleExpr], def analyze_member_lvalue(self, lval: MemberExpr) -> None: lval.accept(self) - if (self.is_self_member_ref(lval) and + if (is_self_member_ref(lval) and self.type.get(lval.name) is None): # Implicit attribute definition in __init__. lval.is_def = True @@ -1508,14 +1416,6 @@ def analyze_member_lvalue(self, lval: MemberExpr) -> None: self.type.names[lval.name] = SymbolTableNode(MDEF, v) self.check_lvalue_validity(lval.node, lval) - @staticmethod - def is_self_member_ref(memberexpr: MemberExpr) -> bool: - """Does memberexpr to refer to an attribute of self?""" - if not isinstance(memberexpr.expr, NameExpr): - return False - node = memberexpr.expr.node - return isinstance(node, Var) and node.is_self - def check_lvalue_validity(self, node: Union[Expression, SymbolNode], ctx: Context) -> None: if isinstance(node, TypeVarExpr): self.fail('Invalid assignment target', ctx) @@ -1572,7 +1472,7 @@ def check_classvar(self, s: AssignmentStmt) -> None: node = lvalue.node if isinstance(node, Var): node.is_classvar = True - elif not isinstance(lvalue, MemberExpr) or self.is_self_member_ref(lvalue): + elif not isinstance(lvalue, MemberExpr) or is_self_member_ref(lvalue): # In case of member access, report error only when assigning to self # Other kinds of member assignments should be already reported self.fail_invalid_classvar(lvalue) @@ -2411,7 +2311,7 @@ def check_no_global(self, n: str, ctx: Context, def name_not_defined(self, name: str, ctx: Context) -> None: message = "Name '{}' is not defined".format(name) - extra = self.undefined_name_extra_info(name) + extra = undefined_name_extra_info(name) if extra: message += ' {}'.format(extra) self.fail(message, ctx) @@ -2440,13 +2340,6 @@ def note(self, msg: str, ctx: Context) -> None: return self.errors.report(ctx.get_line(), ctx.get_column(), msg, severity='note') - @staticmethod - def undefined_name_extra_info(fullname: str) -> Optional[str]: - if fullname in obsolete_name_mapping: - return "(it's now called '{}')".format(obsolete_name_mapping[fullname]) - else: - return None - def accept(self, node: Node) -> None: try: node.accept(self) @@ -2563,7 +2456,7 @@ def visit_func_def(self, func: FuncDef) -> None: # Ah this is an imported name. We can't resolve them now, so we'll postpone # this until the main phase of semantic analysis. return - if not sem.set_original_def(original_sym.node, func): + if not func.set_original_def(original_sym.node): # Report error. sem.check_no_global(func.name(), func) else: @@ -2848,6 +2741,34 @@ def builtin_type(self, name: str, args: List[Type] = None) -> Instance: return Instance(sym.node, args or []) +def process_import_over_existing_name(imported_id: str, existing_symbol: SymbolTableNode, + module_symbol: SymbolTableNode, + import_node: ImportBase) -> bool: + if (existing_symbol.kind in (LDEF, GDEF, MDEF) and + isinstance(existing_symbol.node, (Var, FuncDef, TypeInfo))): + # This is a valid import over an existing definition in the file. Construct a dummy + # assignment that we'll use to type check the import. + lvalue = NameExpr(imported_id) + lvalue.kind = existing_symbol.kind + lvalue.node = existing_symbol.node + rvalue = NameExpr(imported_id) + rvalue.kind = module_symbol.kind + rvalue.node = module_symbol.node + assignment = AssignmentStmt([lvalue], rvalue) + for node in assignment, lvalue, rvalue: + node.set_line(import_node) + import_node.assignments.append(assignment) + return True + return False + + +def undefined_name_extra_info(fullname: str) -> Optional[str]: + if fullname in obsolete_name_mapping: + return "(it's now called '{}')".format(obsolete_name_mapping[fullname]) + else: + return None + + def replace_implicit_first_type(sig: FunctionLike, new: Type) -> FunctionLike: if isinstance(sig, CallableType): return sig.copy_modified(arg_types=[new] + sig.arg_types[1:]) @@ -2929,6 +2850,14 @@ def remove_imported_names_from_symtable(names: SymbolTable, del names[name] +def is_self_member_ref(memberexpr: MemberExpr) -> bool: + """Does memberexpr to refer to an attribute of self?""" + if not isinstance(memberexpr.expr, NameExpr): + return False + node = memberexpr.expr.node + return isinstance(node, Var) and node.is_self + + def infer_reachability_of_if_statement(s: IfStmt, pyversion: Tuple[int, int], platform: str) -> None: @@ -3219,3 +3148,14 @@ def find_fixed_callable_return(expr: Expression) -> Optional[CallableType]: if isinstance(t.ret_type, CallableType): return t.ret_type return None + + +def remove_dups(tvars: List[T]) -> List[T]: + # Get unique elements in order of appearance + all_tvars = set(tvars) + new_tvars = [] # type: List[T] + for t in tvars: + if t in all_tvars: + new_tvars.append(t) + all_tvars.remove(t) + return new_tvars diff --git a/mypy/specialtype.py b/mypy/specialtype.py index f021f164bcb36..3e8cff2c1dd39 100644 --- a/mypy/specialtype.py +++ b/mypy/specialtype.py @@ -22,8 +22,9 @@ class Special: - """ - Groups special-cased types: + """Handling of special-cased types. + + Special-cased types include: * NamedTuple * TypedDict * NewType @@ -111,11 +112,6 @@ def check_namedtuple_classdef( default_items[name] = stmt.rvalue return items, types, default_items - @staticmethod - def is_typeddict(expr: Expression) -> bool: - return (isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo) and - expr.node.typeddict_type is not None) - def analyze_typeddict_classdef(self, defn: ClassDef) -> bool: # special case for TypedDict possible = False @@ -123,7 +119,7 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> bool: if isinstance(base_expr, RefExpr): base_expr.accept(self.semanalyzer) if (base_expr.fullname == 'mypy_extensions.TypedDict' or - self.is_typeddict(base_expr)): + is_typeddict(base_expr)): possible = True if possible: node = self.lookup(defn.name, defn) @@ -139,9 +135,9 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> bool: # Extending/merging existing TypedDicts if any(not isinstance(expr, RefExpr) or expr.fullname != 'mypy_extensions.TypedDict' and - not self.is_typeddict(expr) for expr in defn.base_type_exprs): + not is_typeddict(expr) for expr in defn.base_type_exprs): self.fail("All bases of a new TypedDict must be TypedDict types", defn) - typeddict_bases = list(filter(self.is_typeddict, defn.base_type_exprs)) + typeddict_bases = list(filter(is_typeddict, defn.base_type_exprs)) newfields = [] # type: List[str] newtypes = [] # type: List[Type] tpdict = None # type: OrderedDict[str, Type] @@ -297,7 +293,7 @@ def build_newtype_typeinfo(self, name: str, old_type: Type, base_type: Instance) # Add __init__ method args = [Argument(Var('cls'), NoneTyp(), None, ARG_POS), - self.make_argument('item', old_type)] + Argument(Var('item'), old_type, None, ARG_POS)] signature = CallableType( arg_types=[cast(Type, None), old_type], arg_kinds=[arg.kind for arg in args], @@ -342,13 +338,12 @@ def analyze_types(self, items: List[Expression]) -> List[Type]: def process_typevar_declaration(self, s: AssignmentStmt) -> None: """Check if s declares a TypeVar; it yes, store it in symbol table.""" - call = self.get_typevar_declaration(s) - if not call: - return - + if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): + return None lvalue = s.lvalues[0] - assert isinstance(lvalue, NameExpr) - name = lvalue.name + call, calleename, name = self.get_call(s.rvalue, lvalue.name, fresh=False) + if calleename != 'typing.TypeVar': + return None if not lvalue.is_def: if s.type: self.fail("Cannot declare the type of a type variable", s) @@ -394,23 +389,6 @@ def check_typevar_name(self, call: CallExpr, name: str, context: Context) -> boo return False return True - @staticmethod - def get_typevar_declaration(s: AssignmentStmt) -> Optional[CallExpr]: - """Returns the TypeVar() call expression if `s` is a type var declaration - or None otherwise. - """ - if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): - return None - if not isinstance(s.rvalue, CallExpr): - return None - call = s.rvalue - callee = call.callee - if not isinstance(callee, RefExpr): - return None - if callee.fullname != 'typing.TypeVar': - return None - return call - def process_typevar_parameters(self, args: List[Expression], names: List[Optional[str]], @@ -700,10 +678,6 @@ def make_init_arg(var: Var) -> Argument: Argument(Var('len'), AnyType(), EllipsisExpr(), ARG_NAMED_OPT)]) return info - @staticmethod - def make_argument(name: str, type: Type) -> Argument: - return Argument(Var(name), type, None, ARG_POS) - def process_typeddict_definition(self, s: AssignmentStmt) -> None: """Check if s defines a TypedDict; if yes, store the definition in symbol table.""" if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): @@ -741,7 +715,8 @@ def check_typeddict(self, node: Expression, var_name: str = None) -> Optional[Ty call.analyzed.set_line(call.line, call.column) return info - def get_call(self, expr: Expression, var_name: str) -> Tuple[CallExpr, str, str]: + def get_call(self, expr: Expression, var_name: str, *, + fresh: bool = True) -> Tuple[CallExpr, str, str]: (call, calleename, name) = None, '', '' if isinstance(expr, CallExpr): call = expr @@ -750,7 +725,7 @@ def get_call(self, expr: Expression, var_name: str) -> Tuple[CallExpr, str, str] calleename = callee.fullname if len(call.args) > 0: name = getattr(call.args[0], 'value', var_name) - if isinstance(name, str): + if isinstance(name, str) and fresh: if name != var_name or self.semanalyzer.is_func_scope(): # Give it a unique name derived from the line number. name += '@' + str(call.line) @@ -910,3 +885,8 @@ def fail_enum_call_arg(self, message: str, List[Optional[Expression]], bool]: self.fail(message, context) return [], [], False + + +def is_typeddict(expr: Expression) -> bool: + return (isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo) and + expr.node.typeddict_type is not None) From 0604ee97fb01c5cc35cdadac6d18af7d2646cc4a Mon Sep 17 00:00:00 2001 From: elazar Date: Wed, 5 Apr 2017 18:25:18 +0300 Subject: [PATCH 06/18] inherit AbstractNodeVisitor --- mypy/nodes.py | 6 ++--- mypy/semanal.py | 61 +++++++++++++++++++++++++++++++++++++++------ mypy/specialtype.py | 5 ++++ mypy/visitor.py | 13 +++++----- 4 files changed, 69 insertions(+), 16 deletions(-) diff --git a/mypy/nodes.py b/mypy/nodes.py index 758bb163d263e..1c467a7a0a0b3 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -8,7 +8,7 @@ ) import mypy.strconv -from mypy.visitor import NodeVisitor, StatementVisitor, ExpressionVisitor +from mypy.visitor import AbstractNodeVisitor, StatementVisitor, ExpressionVisitor from mypy.util import dump_tagged, short_type @@ -144,7 +144,7 @@ def get_column(self) -> int: # TODO this should be just 'column' return self.column - def accept(self, visitor: NodeVisitor[T]) -> T: + def accept(self, visitor: AbstractNodeVisitor[T]) -> T: raise RuntimeError('Not implemented') @@ -271,7 +271,7 @@ def name(self) -> str: def fullname(self) -> str: return self._fullname - def accept(self, visitor: NodeVisitor[T]) -> T: + def accept(self, visitor: AbstractNodeVisitor[T]) -> T: return visitor.visit_mypy_file(self) def is_package_init_file(self) -> bool: diff --git a/mypy/semanal.py b/mypy/semanal.py index ce85f72e4b333..0efd3a8d619fb 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -66,8 +66,9 @@ IntExpr, FloatExpr, UnicodeExpr, UNBOUND_IMPORTED, LITERAL_YES, nongen_builtins, collections_type_aliases, get_member_expr_fullname, ) +import mypy.nodes from mypy.typevars import has_no_typevars, fill_typevars -from mypy.visitor import NodeVisitor +from mypy.visitor import AbstractNodeVisitor, NodeVisitor from mypy.traverser import TraverserVisitor from mypy.errors import Errors, report_internal_error from mypy.messages import CANNOT_ASSIGN_TO_TYPE @@ -150,7 +151,7 @@ FUNCTION_SECOND_PHASE = 2 # Only analyze body -class SemanticAnalyzer(NodeVisitor): +class SemanticAnalyzer(AbstractNodeVisitor[None]): """Semantically analyze parsed mypy files. The analyzer binds names and does various consistency checks for a @@ -655,7 +656,7 @@ def visit_class_def(self, defn: ClassDef) -> None: self.analyze_metaclass(defn) for decorator in defn.decorators: - self.analyze_class_decorator(defn, decorator) + decorator.accept(self) self.enter_class(defn) @@ -702,9 +703,6 @@ def unbind_class_type_vars(self) -> None: if self.bound_tvars: enable_typevars(self.bound_tvars) - def analyze_class_decorator(self, defn: ClassDef, decorator: Expression) -> None: - decorator.accept(self) - def setup_type_promotion(self, defn: ClassDef) -> None: """Setup extra, ad-hoc subtyping relationships between classes (promotion). @@ -1622,7 +1620,7 @@ def visit_if_stmt(self, s: IfStmt) -> None: def visit_try_stmt(self, s: TryStmt) -> None: self.analyze_try_stmt(s, self) - def analyze_try_stmt(self, s: TryStmt, visitor: NodeVisitor, + def analyze_try_stmt(self, s: TryStmt, visitor: AbstractNodeVisitor, add_global: bool = False) -> None: s.body.accept(visitor) for type, var, handler in zip(s.types, s.vars, s.handlers): @@ -1734,6 +1732,9 @@ def visit_exec_stmt(self, s: ExecStmt) -> None: if s.variables2: s.variables2.accept(self) + def visit_pass_stmt(self, o: 'mypy.nodes.PassStmt') -> T: + pass + # # Expressions # @@ -2107,6 +2108,52 @@ def visit_await_expr(self, expr: AwaitExpr) -> None: self.fail("'await' outside coroutine ('async def')", expr) expr.expr.accept(self) + def visit_type_alias_expr(self, o: 'mypy.nodes.TypeAliasExpr') -> None: + pass + + def visit_typeddict_expr(self, o: 'mypy.nodes.TypedDictExpr') -> None: + pass + + def visit_int_expr(self, o: 'mypy.nodes.IntExpr') -> None: + pass + + def visit_complex_expr(self, o: 'mypy.nodes.ComplexExpr') -> None: + pass + + def visit_float_expr(self, o: 'mypy.nodes.FloatExpr') -> None: + pass + + def visit_bytes_expr(self, o: 'mypy.nodes.BytesExpr') -> None: + pass + + def visit_str_expr(self, o: 'mypy.nodes.StrExpr') -> None: + pass + + def visit_unicode_expr(self, o: 'mypy.nodes.UnicodeExpr') -> None: + pass + + def visit_enum_call_expr(self, o: 'mypy.nodes.EnumCallExpr') -> None: + pass + + def visit_namedtuple_expr(self, o: 'mypy.nodes.NamedTupleExpr') -> None: + pass + + def visit_newtype_expr(self, o: 'mypy.nodes.NewTypeExpr') -> None: + pass + + def visit_type_var_expr(self, o: 'mypy.nodes.TypeVarExpr') -> None: + pass + + def visit_ellipsis(self, o: 'mypy.nodes.EllipsisExpr') -> None: + pass + + def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> None: + # this method is actually visited + pass + + def visit_var(self, o: 'mypy.nodes.Var') -> None: + assert False + # # Helpers # diff --git a/mypy/specialtype.py b/mypy/specialtype.py index 3e8cff2c1dd39..7db3b3bdf7094 100644 --- a/mypy/specialtype.py +++ b/mypy/specialtype.py @@ -1,3 +1,8 @@ +"""Special case semantic analysis for type-expressions, such as namedtuple. + +This module is used only by the SemanticAnalyzer, and is tightly coupled with it. +""" + from collections import OrderedDict from typing import List, Dict, Tuple, cast, Optional, Callable, TYPE_CHECKING diff --git a/mypy/visitor.py b/mypy/visitor.py index 6bd7520f4fb65..4574694399631 100644 --- a/mypy/visitor.py +++ b/mypy/visitor.py @@ -301,7 +301,13 @@ def visit_exec_stmt(self, o: 'mypy.nodes.ExecStmt') -> T: pass -class NodeVisitor(Generic[T], ExpressionVisitor[T], StatementVisitor[T]): +class AbstractNodeVisitor(Generic[T], ExpressionVisitor[T], StatementVisitor[T]): + # Not in superclasses: + def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T: + pass + + +class NodeVisitor(Generic[T], AbstractNodeVisitor[T]): """Empty base class for parse tree node visitors. The T type argument specifies the return type of the visit @@ -311,11 +317,6 @@ class NodeVisitor(Generic[T], ExpressionVisitor[T], StatementVisitor[T]): TODO make the default return value explicit """ - # Not in superclasses: - - def visit_mypy_file(self, o: 'mypy.nodes.MypyFile') -> T: - pass - # Module structure def visit_import(self, o: 'mypy.nodes.Import') -> T: From b9eec15f3dd0da451b0c5e9c174036db4420c3c6 Mon Sep 17 00:00:00 2001 From: elazar Date: Wed, 5 Apr 2017 18:49:33 +0300 Subject: [PATCH 07/18] fix typo --- mypy/semanal.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mypy/semanal.py b/mypy/semanal.py index 557a1b9016bfa..e19197a91a910 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -700,7 +700,7 @@ def visit_class_def(self, defn: ClassDef) -> None: def analyze_class_body(self, defn: ClassDef) -> Iterator[bool]: self.clean_up_bases_and_infer_type_variables(defn) if self.specialtype.analyze_typeddict_classdef(defn): - yield Fals + yield False return if self.specialtype.analyze_namedtuple_classdef(defn): # just analyze the class body so we catch type errors in default values From eb38a3ab99d0c21f84ca5734ac7ba24e68df9950 Mon Sep 17 00:00:00 2001 From: elazar Date: Thu, 6 Apr 2017 00:36:16 +0300 Subject: [PATCH 08/18] refactoring++ --- mypy/specialtype.py | 590 ++++++++++++++---------------- test-data/unit/check-newtype.test | 2 +- 2 files changed, 274 insertions(+), 318 deletions(-) diff --git a/mypy/specialtype.py b/mypy/specialtype.py index 7db3b3bdf7094..0df955a4045b7 100644 --- a/mypy/specialtype.py +++ b/mypy/specialtype.py @@ -5,7 +5,7 @@ from collections import OrderedDict -from typing import List, Dict, Tuple, cast, Optional, Callable, TYPE_CHECKING +from typing import List, Dict, Tuple, cast, Optional, TYPE_CHECKING from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.nodes import ( @@ -26,6 +26,15 @@ import mypy.semanal +class DeclInfo: + var_name = None # type: str + name = None # type: str + fullname = None # type: str + calleename = None # type: str + call = None # type: CallExpr + is_def = None # type: bool + + class Special: """Handling of special-cased types. @@ -54,11 +63,98 @@ def __init__(self, semanalyzer: 'mypy.semanal.SemanticAnalyzer') -> None: self.str_type = semanalyzer.str_type def process_declaration(self, s: AssignmentStmt) -> None: - self.process_newtype_declaration(s) - self.process_typevar_declaration(s) - self.process_namedtuple_definition(s) - self.process_typeddict_definition(s) - self.process_enum_call(s) + if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): + return + var_name = s.lvalues[0].name + is_def = s.lvalues[0].is_def + call, calleename, name = self.get_call(s.rvalue, var_name) + node = self.lookup(var_name, s) + if node is None: + return + fullname = node.fullname + info, tvar = self.dispatch_call(call, calleename, name, + var_name, fullname) + if tvar is not None: + node = self.lookup(name, s) + node.kind = UNBOUND_TVAR + node.node = tvar + tvar.line = call.line + call.analyzed = tvar + if (info or tvar) and not is_def: + tname = calleename.split('.')[-1] + if tname == 'TypeVar': + tname = 'type variable' + if s.type: + self.fail("Cannot declare the type of a %s" % tname, s) + else: + self.fail("Cannot redefine '%s' as a %s" % (var_name, tname), s) + if info is None: + return + # Yes, it's a valid definition. Add it to the symbol table. + node.kind = GDEF # TODO locally defined type + node.node = info + + def dispatch_call(self, call: CallExpr, calleename: str, + name: str, var_name: str, fullname: str) -> Tuple[TypeInfo, TypeVarExpr]: + tvar = None # type: TypeVarExpr + if calleename == 'typing.NewType': + info = self.check_newtype(call, var_name) + elif calleename == 'typing.TypeVar': + tvar = self.check_typevar(call, name, fullname) + info = None + elif calleename in ('collections.namedtuple', 'typing.NamedTuple'): + info = self.check_namedtuple(call, calleename, name) + elif calleename in ('enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag'): + info = self.check_enum_call(call, calleename, name) + elif calleename == 'mypy_extensions.TypedDict': + info = self.check_typeddict(call, name) + else: + info = None + return info, tvar + + def check_typevar(self, call: CallExpr, name: str, fullname: str) -> Optional[TypeVarExpr]: + """Check if s declares a TypeVar; it yes, store it in symbol table.""" + if not self.check_typevar_name(call, name, context=call): + return None + + # Constraining types + n_values = call.arg_kinds[1:].count(ARG_POS) + values = self.analyze_types(call.args[1:1 + n_values]) + + res = self.process_typevar_arguments(call.args[1 + n_values:], + call.arg_names[1 + n_values:], + call.arg_kinds[1 + n_values:], + n_values, + context=call) + if res is None: + return None + variance, upper_bound = res + return TypeVarExpr(name, fullname, values, upper_bound, variance) + + def check_newtype(self, call: CallExpr, var_name: str = None) -> Optional[TypeInfo]: + """Check if s declares a NewType; if yes, store it in symbol table.""" + # Extract and check all information from newtype declaration + + # This dummy NewTypeExpr marks the call as sufficiently analyzed; it will be + # overwritten later with a fully complete NewTypeExpr if there are no other + # errors with the NewType() call. + + old_type = self.check_newtype_args(var_name, call, call) + call.analyzed = NewTypeExpr(var_name, old_type, line=call.line) + if old_type is None: + return None + + # Create the corresponding class definition if the aliased type is subtypeable + if isinstance(old_type, TupleType): + newtype_class_info = self.build_newtype_typeinfo(var_name, old_type, old_type.fallback) + newtype_class_info.tuple_type = old_type + elif isinstance(old_type, Instance): + newtype_class_info = self.build_newtype_typeinfo(var_name, old_type, old_type) + else: + message = "Argument 2 to NewType(...) must be subclassable (got {})" + self.fail(message.format(old_type), call) + return None + return newtype_class_info def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool: for base_expr in defn.base_type_exprs: @@ -206,62 +302,6 @@ def check_typeddict_classdef(self, defn: ClassDef, self.fail('Right hand side values are not supported in TypedDict', stmt) return fields, types - def process_newtype_declaration(self, s: AssignmentStmt) -> None: - """Check if s declares a NewType; if yes, store it in symbol table.""" - # Extract and check all information from newtype declaration - name, call = self.analyze_newtype_declaration(s) - if name is None or call is None: - return - - old_type = self.check_newtype_args(name, call, s) - call.analyzed = NewTypeExpr(name, old_type, line=call.line) - if old_type is None: - return - - # Create the corresponding class definition if the aliased type is subtypeable - if isinstance(old_type, TupleType): - newtype_class_info = self.build_newtype_typeinfo(name, old_type, old_type.fallback) - newtype_class_info.tuple_type = old_type - elif isinstance(old_type, Instance): - newtype_class_info = self.build_newtype_typeinfo(name, old_type, old_type) - else: - message = "Argument 2 to NewType(...) must be subclassable (got {})" - self.fail(message.format(old_type), s) - return - - # If so, add it to the symbol table. - node = self.lookup(name, s) - if node is None: - self.fail("Could not find {} in current namespace".format(name), s) - return - # TODO: why does NewType work in local scopes despite always being of kind GDEF? - node.kind = GDEF - call.analyzed.info = node.node = newtype_class_info - - def analyze_newtype_declaration(self, s: AssignmentStmt - ) -> Tuple[Optional[str], Optional[CallExpr]]: - """Return the NewType call expression if `s` is a newtype declaration or None otherwise.""" - name, call = None, None - if (len(s.lvalues) == 1 - and isinstance(s.lvalues[0], NameExpr) - and isinstance(s.rvalue, CallExpr) - and isinstance(s.rvalue.callee, RefExpr) - and s.rvalue.callee.fullname == 'typing.NewType'): - lvalue = s.lvalues[0] - name = s.lvalues[0].name - if not lvalue.is_def: - if s.type: - self.fail("Cannot declare the type of a NewType declaration", s) - else: - self.fail("Cannot redefine '%s' as a NewType" % name, s) - - # This dummy NewTypeExpr marks the call as sufficiently analyzed; it will be - # overwritten later with a fully complete NewTypeExpr if there are no other - # errors with the NewType() call. - call = s.rvalue - - return name, call - def check_newtype_args(self, name: str, call: CallExpr, context: Context) -> Optional[Type]: has_failed = False args, arg_kinds = call.args, call.arg_kinds @@ -341,44 +381,24 @@ def analyze_types(self, items: List[Expression]) -> List[Type]: result.append(AnyType()) return result - def process_typevar_declaration(self, s: AssignmentStmt) -> None: - """Check if s declares a TypeVar; it yes, store it in symbol table.""" - if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): - return None - lvalue = s.lvalues[0] - call, calleename, name = self.get_call(s.rvalue, lvalue.name, fresh=False) - if calleename != 'typing.TypeVar': - return None - if not lvalue.is_def: - if s.type: - self.fail("Cannot declare the type of a type variable", s) - else: - self.fail("Cannot redefine '%s' as a type variable" % name, s) - return - - if not self.check_typevar_name(call, name, s): - return + def check_namedtuple(self, call: CallExpr, calleename: str, name: str) -> Optional[TypeInfo]: + """Check if a call defines a namedtuple. - # Constraining types - n_values = call.arg_kinds[1:].count(ARG_POS) - values = self.analyze_types(call.args[1:1 + n_values]) + The optional var_name argument is the name of the variable to + which this is assigned, if any. - res = self.process_typevar_parameters(call.args[1 + n_values:], - call.arg_names[1 + n_values:], - call.arg_kinds[1 + n_values:], - n_values, - s) - if res is None: - return - variance, upper_bound = res + If it does, return the corresponding TypeInfo. Return None otherwise. - # Yes, it's a valid type variable definition! Add it to the symbol table. - node = self.lookup(name, s) - node.kind = UNBOUND_TVAR - TypeVar = TypeVarExpr(name, node.fullname, values, upper_bound, variance) - TypeVar.line = call.line - call.analyzed = TypeVar - node.node = TypeVar + If the definition is invalid but looks like a namedtuple, + report errors but return (some) TypeInfo. + """ + items, types, ok = self.parse_namedtuple_args(call, calleename) + info = self.build_namedtuple_typeinfo(name, items, types) + if ok: + self.semanalyzer.store_info(info, name) + call.analyzed = NamedTupleExpr(info) + call.analyzed.set_line(call.line, call.column) + return info def check_typevar_name(self, call: CallExpr, name: str, context: Context) -> bool: if len(call.args) < 1: @@ -394,23 +414,23 @@ def check_typevar_name(self, call: CallExpr, name: str, context: Context) -> boo return False return True - def process_typevar_parameters(self, - args: List[Expression], - names: List[Optional[str]], - kinds: List[int], - num_values: int, - context: Context) -> Optional[Tuple[int, Type]]: + def process_typevar_arguments(self, + args: List[Expression], + names: List[Optional[str]], + kinds: List[int], + num_values: int, + context: Context) -> Optional[Tuple[int, Type]]: has_values = (num_values > 0) covariant = False contravariant = False upper_bound = self.object_type() # type: Type - for param_value, param_name, param_kind in zip(args, names, kinds): - if not param_kind == ARG_NAMED: + for arg_value, arg_name, arg_kind in zip(args, names, kinds): + if not arg_kind == ARG_NAMED: self.fail("Unexpected argument to TypeVar()", context) return None - if param_name == 'covariant': - if isinstance(param_value, NameExpr): - if param_value.name == 'True': + if arg_name == 'covariant': + if isinstance(arg_value, NameExpr): + if arg_value.name == 'True': covariant = True else: self.fail("TypeVar 'covariant' may only be 'True'", context) @@ -418,9 +438,9 @@ def process_typevar_parameters(self, else: self.fail("TypeVar 'covariant' may only be 'True'", context) return None - elif param_name == 'contravariant': - if isinstance(param_value, NameExpr): - if param_value.name == 'True': + elif arg_name == 'contravariant': + if isinstance(arg_value, NameExpr): + if arg_value.name == 'True': contravariant = True else: self.fail("TypeVar 'contravariant' may only be 'True'", context) @@ -428,23 +448,23 @@ def process_typevar_parameters(self, else: self.fail("TypeVar 'contravariant' may only be 'True'", context) return None - elif param_name == 'bound': + elif arg_name == 'bound': if has_values: self.fail("TypeVar cannot have both values and an upper bound", context) return None try: - upper_bound = self.semanalyzer.expr_to_analyzed_type(param_value) + upper_bound = self.semanalyzer.expr_to_analyzed_type(arg_value) except TypeTranslationError: - self.fail("TypeVar 'bound' must be a type", param_value) + self.fail("TypeVar 'bound' must be a type", arg_value) return None - elif param_name == 'values': + elif arg_name == 'values': # Probably using obsolete syntax with values=(...). Explain the current syntax. self.fail("TypeVar 'values' argument not supported", context) self.fail("Use TypeVar('T', t, ...) instead of TypeVar('T', values=(t, ...))", context) return None else: - self.fail("Unexpected argument to TypeVar(): {}".format(param_name), context) + self.fail("Unexpected argument to TypeVar(): {}".format(arg_name), context) return None if covariant and contravariant: @@ -461,117 +481,25 @@ def process_typevar_parameters(self, variance = INVARIANT return (variance, upper_bound) - def process_call(self, s: AssignmentStmt, - check: Callable[[Expression, str], TypeInfo]) -> None: - """Check if s defines a legal node; if yes, store the definition in symbol table.""" - if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): - return - lvalue = s.lvalues[0] - name = lvalue.name - info = check(s.rvalue, name) - if info is None: - return - # Yes, it's a valid definition. Add it to the symbol table. - node = self.lookup(name, s) - node.kind = GDEF # TODO locally defined type - node.node = info - - def process_namedtuple_definition(self, s: AssignmentStmt) -> None: - self.process_call(s, self.check_namedtuple) - - def process_enum_call(self, s: AssignmentStmt) -> None: - self.process_call(s, self.check_enum_call) - - def check_namedtuple(self, expr: Expression, var_name: str = None) -> Optional[TypeInfo]: - """Check if a call defines a namedtuple. + def check_typeddict(self, call: CallExpr, name: str) -> Optional[TypeInfo]: + """Check if a call defines a TypedDict. The optional var_name argument is the name of the variable to which this is assigned, if any. If it does, return the corresponding TypeInfo. Return None otherwise. - If the definition is invalid but looks like a namedtuple, + If the definition is invalid but looks like a TypedDict, report errors but return (some) TypeInfo. """ - - call, calleename, name = self.get_call(expr, var_name) - if calleename not in ('collections.namedtuple', 'typing.NamedTuple'): - return None - items, types, ok = self.parse_namedtuple_args(call, calleename) - info = self.build_namedtuple_typeinfo(name, items, types) + items, types, ok = self.parse_typeddict_args(call) + info = self.build_typeddict_typeinfo(name, items, types) if ok: self.semanalyzer.store_info(info, name) - call.analyzed = NamedTupleExpr(info) + call.analyzed = TypedDictExpr(info) call.analyzed.set_line(call.line, call.column) return info - def parse_namedtuple_args(self, call: CallExpr, - fullname: str) -> Tuple[List[str], List[Type], bool]: - # TODO: Share code with check_argument_count in checkexpr.py? - args = call.args - if len(args) < 2: - return self.fail_namedtuple_arg("Too few arguments for namedtuple()", call) - if len(args) > 2: - # FIX incorrect. There are two additional parameters - return self.fail_namedtuple_arg("Too many arguments for namedtuple()", call) - if call.arg_kinds != [ARG_POS, ARG_POS]: - return self.fail_namedtuple_arg("Unexpected arguments to namedtuple()", call) - if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): - return self.fail_namedtuple_arg( - "namedtuple() expects a string literal as the first argument", call) - types = [] # type: List[Type] - ok = True - if not isinstance(args[1], (ListExpr, TupleExpr)): - if (fullname == 'collections.namedtuple' - and isinstance(args[1], (StrExpr, BytesExpr, UnicodeExpr))): - str_expr = cast(StrExpr, args[1]) - items = str_expr.value.replace(',', ' ').split() - else: - return self.fail_namedtuple_arg( - "List or tuple literal expected as the second argument to namedtuple()", call) - else: - listexpr = args[1] - if fullname == 'collections.namedtuple': - # The fields argument contains just names, with implicit Any types. - if any(not isinstance(item, (StrExpr, BytesExpr, UnicodeExpr)) - for item in listexpr.items): - return self.fail_namedtuple_arg("String literal expected as namedtuple() item", - call) - items = [cast(StrExpr, item).value for item in listexpr.items] - else: - # The fields argument contains (name, type) tuples. - items, types, ok = self.parse_namedtuple_fields_with_types(listexpr.items) - if not types: - types = [AnyType() for _ in items] - underscore = [item for item in items if item.startswith('_')] - if underscore: - self.fail("namedtuple() field names cannot start with an underscore: " - + ', '.join(underscore), call) - return items, types, ok - - def parse_namedtuple_fields_with_types(self, nodes: List[Expression] - ) -> Tuple[List[str], List[Type], bool]: - items = [] # type: List[str] - types = [] # type: List[Type] - for item in nodes: - if isinstance(item, TupleExpr): - if len(item.items) != 2: - return self.fail_namedtuple_arg("Invalid NamedTuple field definition", - item) - name, type_node = item.items - if isinstance(name, (StrExpr, BytesExpr, UnicodeExpr)): - items.append(name.value) - else: - return self.fail_namedtuple_arg("Invalid NamedTuple() field name", item) - try: - type = expr_to_unanalyzed_type(type_node) - except TypeTranslationError: - return self.fail_namedtuple_arg('Invalid field type', type_node) - types.append(self.semanalyzer.anal_type(type)) - else: - return self.fail_namedtuple_arg("Tuple expected as NamedTuple() field", item) - return items, types, True - def fail_namedtuple_arg(self, message: str, context: Context) -> Tuple[List[str], List[Type], bool]: self.fail(message, context) @@ -587,8 +515,9 @@ def basic_new_typeinfo(self, name: str, basetype_or_fallback: Instance) -> TypeI return info def analyze_callexpr_as_type(self, call: CallExpr) -> Optional[Type]: - info = self.check_namedtuple(call) - if info is None: + call, calleename, name = self.get_call(call, '') + info, tvar = self.dispatch_call(call, calleename, name, '', '') + if info is None or info.tuple_type is None: # Some form of namedtuple is the only valid type that looks like a call # expression. This isn't a valid type. return None @@ -683,45 +612,7 @@ def make_init_arg(var: Var) -> Argument: Argument(Var('len'), AnyType(), EllipsisExpr(), ARG_NAMED_OPT)]) return info - def process_typeddict_definition(self, s: AssignmentStmt) -> None: - """Check if s defines a TypedDict; if yes, store the definition in symbol table.""" - if len(s.lvalues) != 1 or not isinstance(s.lvalues[0], NameExpr): - return - lvalue = s.lvalues[0] - name = lvalue.name - typed_dict = self.check_typeddict(s.rvalue, name) - if typed_dict is None: - return - # Yes, it's a valid TypedDict definition. Add it to the symbol table. - node = self.lookup(name, s) - if node: - node.kind = GDEF # TODO locally defined TypedDict - node.node = typed_dict - - def check_typeddict(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: - """Check if a call defines a TypedDict. - - The optional var_name argument is the name of the variable to - which this is assigned, if any. - - If it does, return the corresponding TypeInfo. Return None otherwise. - - If the definition is invalid but looks like a TypedDict, - report errors but return (some) TypeInfo. - """ - call, calleename, name = self.get_call(node, var_name) - if calleename != 'mypy_extensions.TypedDict': - return None - items, types, ok = self.parse_typeddict_args(call) - info = self.build_typeddict_typeinfo(name, items, types) - if ok: - self.semanalyzer.store_info(info, name) - call.analyzed = TypedDictExpr(info) - call.analyzed.set_line(call.line, call.column) - return info - - def get_call(self, expr: Expression, var_name: str, *, - fresh: bool = True) -> Tuple[CallExpr, str, str]: + def get_call(self, expr: Expression, var_name: str) -> Tuple[CallExpr, str, str]: (call, calleename, name) = None, '', '' if isinstance(expr, CallExpr): call = expr @@ -730,6 +621,7 @@ def get_call(self, expr: Expression, var_name: str, *, calleename = callee.fullname if len(call.args) > 0: name = getattr(call.args[0], 'value', var_name) + fresh = (calleename is None or not calleename.endswith("TypeVar")) if isinstance(name, str) and fresh: if name != var_name or self.semanalyzer.is_func_scope(): # Give it a unique name derived from the line number. @@ -738,6 +630,50 @@ def get_call(self, expr: Expression, var_name: str, *, name = var_name return (call, calleename, name) + def parse_namedtuple_args(self, call: CallExpr, + fullname: str) -> Tuple[List[str], List[Type], bool]: + # TODO: Share code with check_argument_count in checkexpr.py? + args = call.args + if len(args) < 2: + return self.fail_namedtuple_arg("Too few arguments for namedtuple()", call) + if len(args) > 2: + # FIX incorrect. There are two additional parameters + return self.fail_namedtuple_arg("Too many arguments for namedtuple()", call) + if call.arg_kinds != [ARG_POS, ARG_POS]: + return self.fail_namedtuple_arg("Unexpected arguments to namedtuple()", call) + if not isinstance(args[0], (StrExpr, BytesExpr, UnicodeExpr)): + return self.fail_namedtuple_arg( + "namedtuple() expects a string literal as the first argument", call) + types = [] # type: List[Type] + ok = True + if not isinstance(args[1], (ListExpr, TupleExpr)): + if (fullname == 'collections.namedtuple' + and isinstance(args[1], (StrExpr, BytesExpr, UnicodeExpr))): + str_expr = cast(StrExpr, args[1]) + items = str_expr.value.replace(',', ' ').split() + else: + return self.fail_namedtuple_arg( + "List or tuple literal expected as the second argument to namedtuple()", call) + else: + listexpr = args[1] + if fullname == 'collections.namedtuple': + # The fields argument contains just names, with implicit Any types. + if any(not isinstance(item, (StrExpr, BytesExpr, UnicodeExpr)) + for item in listexpr.items): + return self.fail_namedtuple_arg("String literal expected as namedtuple() item", + call) + items = [cast(StrExpr, item).value for item in listexpr.items] + else: + # The fields argument contains (name, type) tuples. + items, types, ok = self.parse_namedtuple_fields_with_types(listexpr.items) + if not types: + types = [AnyType() for _ in items] + underscore = [item for item in items if item.startswith('_')] + if underscore: + self.fail("namedtuple() field names cannot start with an underscore: " + + ', '.join(underscore), call) + return items, types, ok + def parse_typeddict_args(self, call: CallExpr) -> Tuple[List[str], List[Type], bool]: # TODO: Share code with check_argument_count in checkexpr.py? args = call.args @@ -758,6 +694,87 @@ def parse_typeddict_args(self, call: CallExpr) -> Tuple[List[str], List[Type], b items, types, ok = self.parse_typeddict_fields_with_types(dictexpr.items) return items, types, ok + def parse_enum_call_args(self, call: CallExpr, + class_name: str) -> Tuple[List[str], + List[Optional[Expression]], bool]: + args = call.args + if len(args) < 2: + return self.fail_enum_call_arg("Too few arguments for %s()" % class_name, call) + if len(args) > 2: + return self.fail_enum_call_arg("Too many arguments for %s()" % class_name, call) + if call.arg_kinds != [ARG_POS, ARG_POS]: + return self.fail_enum_call_arg("Unexpected arguments to %s()" % class_name, call) + if not isinstance(args[0], (StrExpr, UnicodeExpr)): + return self.fail_enum_call_arg( + "%s() expects a string literal as the first argument" % class_name, call) + items = [] + values = [] # type: List[Optional[Expression]] + if isinstance(args[1], (StrExpr, UnicodeExpr)): + fields = args[1].value + for field in fields.replace(',', ' ').split(): + items.append(field) + elif isinstance(args[1], (TupleExpr, ListExpr)): + seq_items = args[1].items + if all(isinstance(seq_item, (StrExpr, UnicodeExpr)) for seq_item in seq_items): + items = [cast(StrExpr, seq_item).value for seq_item in seq_items] + elif all(isinstance(seq_item, (TupleExpr, ListExpr)) + and len(seq_item.items) == 2 + and isinstance(seq_item.items[0], (StrExpr, UnicodeExpr)) + for seq_item in seq_items): + for seq_item in seq_items: + assert isinstance(seq_item, (TupleExpr, ListExpr)) + name, value = seq_item.items + assert isinstance(name, (StrExpr, UnicodeExpr)) + items.append(name.value) + values.append(value) + else: + return self.fail_enum_call_arg( + "%s() with tuple or list expects strings or (name, value) pairs" % + class_name, + call) + elif isinstance(args[1], DictExpr): + for key, value in args[1].items: + if not isinstance(key, (StrExpr, UnicodeExpr)): + return self.fail_enum_call_arg( + "%s() with dict literal requires string literals" % class_name, call) + items.append(key.value) + values.append(value) + else: + # TODO: Allow dict(x=1, y=2) as a substitute for {'x': 1, 'y': 2}? + return self.fail_enum_call_arg( + "%s() expects a string, tuple, list or dict literal as the second argument" % + class_name, + call) + if len(items) == 0: + return self.fail_enum_call_arg("%s() needs at least one item" % class_name, call) + if not values: + values = [None] * len(items) + assert len(items) == len(values) + return items, values, True + + def parse_namedtuple_fields_with_types(self, nodes: List[Expression] + ) -> Tuple[List[str], List[Type], bool]: + items = [] # type: List[str] + types = [] # type: List[Type] + for item in nodes: + if isinstance(item, TupleExpr): + if len(item.items) != 2: + return self.fail_namedtuple_arg("Invalid NamedTuple field definition", + item) + name, type_node = item.items + if isinstance(name, (StrExpr, BytesExpr, UnicodeExpr)): + items.append(name.value) + else: + return self.fail_namedtuple_arg("Invalid NamedTuple() field name", item) + try: + type = expr_to_unanalyzed_type(type_node) + except TypeTranslationError: + return self.fail_namedtuple_arg('Invalid field type', type_node) + types.append(self.semanalyzer.anal_type(type)) + else: + return self.fail_namedtuple_arg("Tuple expected as NamedTuple() field", item) + return items, types, True + def parse_typeddict_fields_with_types(self, dict_items: List[Tuple[Expression, Expression]], ) -> Tuple[List[str], List[Type], bool]: items = [] # type: List[str] @@ -791,7 +808,7 @@ def build_typeddict_typeinfo(self, name: str, items: List[str], return info - def check_enum_call(self, node: Expression, var_name: str = None) -> Optional[TypeInfo]: + def check_enum_call(self, call: CallExpr, calleename: str, name: str) -> Optional[TypeInfo]: """Check if a call defines an Enum. Example: @@ -804,9 +821,6 @@ class A(enum.Enum): foo = 1 bar = 2 """ - call, calleename, name = self.get_call(node, var_name) - if calleename not in ('enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag'): - return None items, values, ok = self.parse_enum_call_args(call, calleename.split('.')[-1]) info = self.build_enum_call_typeinfo(name, items, calleename) if ok: @@ -827,64 +841,6 @@ def build_enum_call_typeinfo(self, name: str, items: List[str], fullname: str) - info.names[item] = SymbolTableNode(MDEF, var) return info - def parse_enum_call_args(self, call: CallExpr, - class_name: str) -> Tuple[List[str], - List[Optional[Expression]], bool]: - args = call.args - if len(args) < 2: - return self.fail_enum_call_arg("Too few arguments for %s()" % class_name, call) - if len(args) > 2: - return self.fail_enum_call_arg("Too many arguments for %s()" % class_name, call) - if call.arg_kinds != [ARG_POS, ARG_POS]: - return self.fail_enum_call_arg("Unexpected arguments to %s()" % class_name, call) - if not isinstance(args[0], (StrExpr, UnicodeExpr)): - return self.fail_enum_call_arg( - "%s() expects a string literal as the first argument" % class_name, call) - items = [] - values = [] # type: List[Optional[Expression]] - if isinstance(args[1], (StrExpr, UnicodeExpr)): - fields = args[1].value - for field in fields.replace(',', ' ').split(): - items.append(field) - elif isinstance(args[1], (TupleExpr, ListExpr)): - seq_items = args[1].items - if all(isinstance(seq_item, (StrExpr, UnicodeExpr)) for seq_item in seq_items): - items = [cast(StrExpr, seq_item).value for seq_item in seq_items] - elif all(isinstance(seq_item, (TupleExpr, ListExpr)) - and len(seq_item.items) == 2 - and isinstance(seq_item.items[0], (StrExpr, UnicodeExpr)) - for seq_item in seq_items): - for seq_item in seq_items: - assert isinstance(seq_item, (TupleExpr, ListExpr)) - name, value = seq_item.items - assert isinstance(name, (StrExpr, UnicodeExpr)) - items.append(name.value) - values.append(value) - else: - return self.fail_enum_call_arg( - "%s() with tuple or list expects strings or (name, value) pairs" % - class_name, - call) - elif isinstance(args[1], DictExpr): - for key, value in args[1].items: - if not isinstance(key, (StrExpr, UnicodeExpr)): - return self.fail_enum_call_arg( - "%s() with dict literal requires string literals" % class_name, call) - items.append(key.value) - values.append(value) - else: - # TODO: Allow dict(x=1, y=2) as a substitute for {'x': 1, 'y': 2}? - return self.fail_enum_call_arg( - "%s() expects a string, tuple, list or dict literal as the second argument" % - class_name, - call) - if len(items) == 0: - return self.fail_enum_call_arg("%s() needs at least one item" % class_name, call) - if not values: - values = [None] * len(items) - assert len(items) == len(values) - return items, values, True - def fail_enum_call_arg(self, message: str, context: Context) -> Tuple[List[str], List[Optional[Expression]], bool]: diff --git a/test-data/unit/check-newtype.test b/test-data/unit/check-newtype.test index 25adf9885d0db..ba4e67832fef0 100644 --- a/test-data/unit/check-newtype.test +++ b/test-data/unit/check-newtype.test @@ -302,7 +302,7 @@ c = NewType('c', str) # type: str main:4: error: Cannot redefine 'a' as a NewType main:7: error: Cannot assign to a type main:7: error: Cannot redefine 'b' as a NewType -main:9: error: Cannot declare the type of a NewType declaration +main:9: error: Cannot declare the type of a NewType [case testNewTypeAddingExplicitTypesFails] from typing import NewType From df66304845e2367b72d4dd2340fb06698e0ca400 Mon Sep 17 00:00:00 2001 From: elazar Date: Thu, 6 Apr 2017 04:20:02 +0300 Subject: [PATCH 09/18] refactor classdef processing --- mypy/specialtype.py | 266 ++++++++++++++++++-------------------------- 1 file changed, 111 insertions(+), 155 deletions(-) diff --git a/mypy/specialtype.py b/mypy/specialtype.py index 0df955a4045b7..027340d37a868 100644 --- a/mypy/specialtype.py +++ b/mypy/specialtype.py @@ -5,7 +5,7 @@ from collections import OrderedDict -from typing import List, Dict, Tuple, cast, Optional, TYPE_CHECKING +from typing import List, Dict, Tuple, cast, Optional, Union, Callable, TYPE_CHECKING from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.nodes import ( @@ -121,12 +121,13 @@ def check_typevar(self, call: CallExpr, name: str, fullname: str) -> Optional[Ty n_values = call.arg_kinds[1:].count(ARG_POS) values = self.analyze_types(call.args[1:1 + n_values]) - res = self.process_typevar_arguments(call.args[1 + n_values:], - call.arg_names[1 + n_values:], - call.arg_kinds[1 + n_values:], - n_values, - context=call) - if res is None: + res = self.parse_typevar_args(call.args[1 + n_values:], + call.arg_names[1 + n_values:], + call.arg_kinds[1 + n_values:], + n_values) + if isinstance(res, str): + for msg in res.split('\n'): + self.fail(msg, call) return None variance, upper_bound = res return TypeVarExpr(name, fullname, values, upper_bound, variance) @@ -139,7 +140,7 @@ def check_newtype(self, call: CallExpr, var_name: str = None) -> Optional[TypeIn # overwritten later with a fully complete NewTypeExpr if there are no other # errors with the NewType() call. - old_type = self.check_newtype_args(var_name, call, call) + old_type = self.parse_newtype_args(var_name, call, call) call.analyzed = NewTypeExpr(var_name, old_type, line=call.line) if old_type is None: return None @@ -156,27 +157,24 @@ def check_newtype(self, call: CallExpr, var_name: str = None) -> Optional[TypeIn return None return newtype_class_info - def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool: + def lookup_base(self, defn: ClassDef, p: Callable[[RefExpr], bool] = lambda _: False) -> Optional[SymbolTableNode]: + res = None for base_expr in defn.base_type_exprs: if isinstance(base_expr, RefExpr): base_expr.accept(self.semanalyzer) - if base_expr.fullname == 'typing.NamedTuple': - node = self.lookup(defn.name, defn) - if node is not None: - node.kind = GDEF # TODO in process_namedtuple_definition also applies here - items, types, default_items = self.check_namedtuple_classdef(defn) - node.node = self.build_namedtuple_typeinfo( - defn.name, items, types, default_items) - return True - return False + if p(base_expr): + res = self.lookup(defn.name, defn) + return res - def check_namedtuple_classdef( - self, defn: ClassDef) -> Tuple[List[str], List[Type], Dict[str, Expression]]: - NAMEDTUP_CLASS_ERROR = ('Invalid statement in NamedTuple definition; ' - 'expected "field_name: field_type"') + def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool: + node = self.lookup_base(defn, lambda x: x.fullname == 'typing.NamedTuple') + if node is None: + return False if self.semanalyzer.options.python_version < (3, 6): self.fail('NamedTuple class syntax is only supported in Python 3.6', defn) - return [], [], {} + node.kind = GDEF # TODO in process_namedtuple_definition also applies here + NAMEDTUP_CLASS_ERROR = ('Invalid statement in NamedTuple definition; ' + 'expected "field_name: field_type"') if len(defn.base_type_exprs) > 1: self.fail('NamedTuple should be a single base', defn) items = [] # type: List[str] @@ -186,8 +184,8 @@ def check_namedtuple_classdef( if not isinstance(stmt, AssignmentStmt): # Still allow pass or ... (for empty namedtuples). if (not isinstance(stmt, PassStmt) and - not (isinstance(stmt, ExpressionStmt) and - isinstance(stmt.expr, EllipsisExpr))): + not (isinstance(stmt, ExpressionStmt) and + isinstance(stmt.expr, EllipsisExpr))): self.fail(NAMEDTUP_CLASS_ERROR, stmt) elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr): # An assignment, but an invalid one. @@ -211,64 +209,47 @@ def check_namedtuple_classdef( stmt) else: default_items[name] = stmt.rvalue - return items, types, default_items + node.node = self.build_namedtuple_typeinfo(defn.name, items, types, default_items) + return True def analyze_typeddict_classdef(self, defn: ClassDef) -> bool: # special case for TypedDict - possible = False - for base_expr in defn.base_type_exprs: - if isinstance(base_expr, RefExpr): - base_expr.accept(self.semanalyzer) - if (base_expr.fullname == 'mypy_extensions.TypedDict' or - is_typeddict(base_expr)): - possible = True - if possible: - node = self.lookup(defn.name, defn) - if node is not None: - node.kind = GDEF # TODO in process_namedtuple_definition also applies here - if (len(defn.base_type_exprs) == 1 and - isinstance(defn.base_type_exprs[0], RefExpr) and - defn.base_type_exprs[0].fullname == 'mypy_extensions.TypedDict'): - # Building a new TypedDict - fields, types = self.check_typeddict_classdef(defn) - node.node = self.build_typeddict_typeinfo(defn.name, fields, types) - return True - # Extending/merging existing TypedDicts - if any(not isinstance(expr, RefExpr) or - expr.fullname != 'mypy_extensions.TypedDict' and - not is_typeddict(expr) for expr in defn.base_type_exprs): - self.fail("All bases of a new TypedDict must be TypedDict types", defn) - typeddict_bases = list(filter(is_typeddict, defn.base_type_exprs)) - newfields = [] # type: List[str] - newtypes = [] # type: List[Type] - tpdict = None # type: OrderedDict[str, Type] - for base in typeddict_bases: - assert isinstance(base, RefExpr) - assert isinstance(base.node, TypeInfo) - assert isinstance(base.node.typeddict_type, TypedDictType) - tpdict = base.node.typeddict_type.items - newdict = tpdict.copy() - for key in tpdict: - if key in newfields: - self.fail('Cannot overwrite TypedDict field "{}" while merging' - .format(key), defn) - newdict.pop(key) - newfields.extend(newdict.keys()) - newtypes.extend(newdict.values()) - fields, types = self.check_typeddict_classdef(defn, newfields) - newfields.extend(fields) - newtypes.extend(types) - node.node = self.build_typeddict_typeinfo(defn.name, newfields, newtypes) - return True - return False + node = self.lookup_base(defn, is_typeddict) + print(node) + if node is None: + return False + if self.semanalyzer.options.python_version < (3, 6): + self.fail('TypedDict class syntax is only supported in Python 3.6', defn) + typeddict_bases = [cast(RefExpr, expr) for expr in defn.base_type_exprs if is_typeddict(expr)] + if typeddict_bases != defn.base_type_exprs: + self.fail("All bases of a new TypedDict must be TypedDict types", defn) + typeddict_bases = [expr for expr in typeddict_bases if expr.fullname != 'mypy_extensions.TypedDict'] + newfields = [] # type: List[str] + newtypes = [] # type: List[Type] + for base in typeddict_bases: + assert isinstance(base, RefExpr) + assert isinstance(base.node, TypeInfo) + assert isinstance(base.node.typeddict_type, TypedDictType) + tpdict = base.node.typeddict_type.items + newdict = tpdict.copy() + for key in tpdict: + if key in newfields: + self.fail('Cannot overwrite TypedDict field "{}" while merging' + .format(key), defn) + newdict.pop(key) + newfields.extend(newdict.keys()) + newtypes.extend(newdict.values()) + fields, types = self.check_typeddict_classdef(defn, newfields) + newfields.extend(fields) + newtypes.extend(types) + node.node = self.build_typeddict_typeinfo(defn.name, newfields, newtypes) + node.kind = GDEF + return True def check_typeddict_classdef(self, defn: ClassDef, oldfields: List[str] = None) -> Tuple[List[str], List[Type]]: TPDICT_CLASS_ERROR = ('Invalid statement in TypedDict definition; ' 'expected "field_name: field_type"') - if self.semanalyzer.options.python_version < (3, 6): - self.fail('TypedDict class syntax is only supported in Python 3.6', defn) - return [], [] fields = [] # type: List[str] types = [] # type: List[Type] for stmt in defn.defs.body: @@ -302,7 +283,7 @@ def check_typeddict_classdef(self, defn: ClassDef, self.fail('Right hand side values are not supported in TypedDict', stmt) return fields, types - def check_newtype_args(self, name: str, call: CallExpr, context: Context) -> Optional[Type]: + def parse_newtype_args(self, name: str, call: CallExpr, context: Context) -> Optional[Type]: has_failed = False args, arg_kinds = call.args, call.arg_kinds if len(args) != 2 or arg_kinds[0] != ARG_POS or arg_kinds[1] != ARG_POS: @@ -414,73 +395,6 @@ def check_typevar_name(self, call: CallExpr, name: str, context: Context) -> boo return False return True - def process_typevar_arguments(self, - args: List[Expression], - names: List[Optional[str]], - kinds: List[int], - num_values: int, - context: Context) -> Optional[Tuple[int, Type]]: - has_values = (num_values > 0) - covariant = False - contravariant = False - upper_bound = self.object_type() # type: Type - for arg_value, arg_name, arg_kind in zip(args, names, kinds): - if not arg_kind == ARG_NAMED: - self.fail("Unexpected argument to TypeVar()", context) - return None - if arg_name == 'covariant': - if isinstance(arg_value, NameExpr): - if arg_value.name == 'True': - covariant = True - else: - self.fail("TypeVar 'covariant' may only be 'True'", context) - return None - else: - self.fail("TypeVar 'covariant' may only be 'True'", context) - return None - elif arg_name == 'contravariant': - if isinstance(arg_value, NameExpr): - if arg_value.name == 'True': - contravariant = True - else: - self.fail("TypeVar 'contravariant' may only be 'True'", context) - return None - else: - self.fail("TypeVar 'contravariant' may only be 'True'", context) - return None - elif arg_name == 'bound': - if has_values: - self.fail("TypeVar cannot have both values and an upper bound", context) - return None - try: - upper_bound = self.semanalyzer.expr_to_analyzed_type(arg_value) - except TypeTranslationError: - self.fail("TypeVar 'bound' must be a type", arg_value) - return None - elif arg_name == 'values': - # Probably using obsolete syntax with values=(...). Explain the current syntax. - self.fail("TypeVar 'values' argument not supported", context) - self.fail("Use TypeVar('T', t, ...) instead of TypeVar('T', values=(t, ...))", - context) - return None - else: - self.fail("Unexpected argument to TypeVar(): {}".format(arg_name), context) - return None - - if covariant and contravariant: - self.fail("TypeVar cannot be both covariant and contravariant", context) - return None - elif num_values == 1: - self.fail("TypeVar cannot have only a single constraint", context) - return None - elif covariant: - variance = COVARIANT - elif contravariant: - variance = CONTRAVARIANT - else: - variance = INVARIANT - return (variance, upper_bound) - def check_typeddict(self, call: CallExpr, name: str) -> Optional[TypeInfo]: """Check if a call defines a TypedDict. @@ -500,11 +414,6 @@ def check_typeddict(self, call: CallExpr, name: str) -> Optional[TypeInfo]: call.analyzed.set_line(call.line, call.column) return info - def fail_namedtuple_arg(self, message: str, - context: Context) -> Tuple[List[str], List[Type], bool]: - self.fail(message, context) - return [], [], False - def basic_new_typeinfo(self, name: str, basetype_or_fallback: Instance) -> TypeInfo: class_def = ClassDef(name, Block([])) class_def.fullname = self.semanalyzer.qualified_name(name) @@ -752,6 +661,45 @@ def parse_enum_call_args(self, call: CallExpr, assert len(items) == len(values) return items, values, True + def parse_typevar_args(self, + args: List[Expression], + names: List[Optional[str]], + kinds: List[int], + num_values: int) -> Union[str, Tuple[int, Type]]: + has_values = (num_values > 0) + upper_bound = self.object_type() # type: Type + variance = INVARIANT + for arg_value, arg_name, arg_kind in zip(args, names, kinds): + if arg_name in ('contravariant', 'covariant'): + if variance != INVARIANT: + return "TypeVar cannot be both covariant and contravariant" + if isinstance(arg_value, NameExpr) and arg_value.name == 'True': + if arg_name == 'contravariant': + variance = CONTRAVARIANT + else: + variance = COVARIANT + else: + return "TypeVar '{}' may only be 'True'".format(arg_name) + elif arg_name == 'bound': + if has_values: + return "TypeVar cannot have both values and an upper bound" + try: + upper_bound = self.semanalyzer.expr_to_analyzed_type(arg_value) + except TypeTranslationError: + return "TypeVar 'bound' must be a type" + elif arg_name == 'values': + # Probably using obsolete syntax with values=(...). Explain the current syntax. + return ("TypeVar 'values' argument not supported\n" + "Use TypeVar('T', t, ...) instead of TypeVar('T', values=(t, ...))") + else: + res = "Unexpected argument to TypeVar()" + if arg_name: + res += ": " + arg_name + return res + if num_values == 1: + return "TypeVar cannot have only a single constraint" + return (variance, upper_bound) + def parse_namedtuple_fields_with_types(self, nodes: List[Expression] ) -> Tuple[List[str], List[Type], bool]: items = [] # type: List[str] @@ -791,11 +739,6 @@ def parse_typeddict_fields_with_types(self, dict_items: List[Tuple[Expression, E types.append(self.semanalyzer.anal_type(type)) return items, types, True - def fail_typeddict_arg(self, message: str, - context: Context) -> Tuple[List[str], List[Type], bool]: - self.fail(message, context) - return [], [], False - def build_typeddict_typeinfo(self, name: str, items: List[str], types: List[Type]) -> TypeInfo: mapping_value_type = join.join_type_list(types) @@ -841,6 +784,16 @@ def build_enum_call_typeinfo(self, name: str, items: List[str], fullname: str) - info.names[item] = SymbolTableNode(MDEF, var) return info + def fail_typeddict_arg(self, message: str, + context: Context) -> Tuple[List[str], List[Type], bool]: + self.fail(message, context) + return [], [], False + + def fail_namedtuple_arg(self, message: str, + context: Context) -> Tuple[List[str], List[Type], bool]: + self.fail(message, context) + return [], [], False + def fail_enum_call_arg(self, message: str, context: Context) -> Tuple[List[str], List[Optional[Expression]], bool]: @@ -849,5 +802,8 @@ def fail_enum_call_arg(self, message: str, def is_typeddict(expr: Expression) -> bool: - return (isinstance(expr, RefExpr) and isinstance(expr.node, TypeInfo) and - expr.node.typeddict_type is not None) + if not isinstance(expr, RefExpr): + return False + if expr.fullname == 'mypy_extensions.TypedDict': + return True + return isinstance(expr.node, TypeInfo) and expr.node.typeddict_type is not None From a4b78d0960b1431ed858a0ff8b8e1a1015cd7b3f Mon Sep 17 00:00:00 2001 From: elazar Date: Thu, 6 Apr 2017 06:01:48 +0300 Subject: [PATCH 10/18] refactor classdef contd. --- mypy/semanal.py | 6 +- mypy/specialtype.py | 134 ++++++++++---------- test-data/unit/check-typeddict.test | 4 +- test-data/unit/lib-stub/mypy_extensions.pyi | 2 +- test-data/unit/lib-stub/typing.pyi | 2 +- 5 files changed, 71 insertions(+), 77 deletions(-) diff --git a/mypy/semanal.py b/mypy/semanal.py index e19197a91a910..b5be9d439d7be 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -702,11 +702,6 @@ def analyze_class_body(self, defn: ClassDef) -> Iterator[bool]: if self.specialtype.analyze_typeddict_classdef(defn): yield False return - if self.specialtype.analyze_namedtuple_classdef(defn): - # just analyze the class body so we catch type errors in default values - self.enter_class(defn) - yield True - self.leave_class() else: self.setup_class_def_analysis(defn) @@ -718,6 +713,7 @@ def analyze_class_body(self, defn: ClassDef) -> Iterator[bool]: for decorator in defn.decorators: decorator.accept(self) + self.specialtype.analyze_namedtuple_classdef(defn) self.enter_class(defn) yield True diff --git a/mypy/specialtype.py b/mypy/specialtype.py index 027340d37a868..f6ea312e98f26 100644 --- a/mypy/specialtype.py +++ b/mypy/specialtype.py @@ -14,7 +14,7 @@ DictExpr, CallExpr, RefExpr, Context, SymbolTable, UNBOUND_TVAR, MDEF, Decorator, TypeVarExpr, NewTypeExpr, StrExpr, BytesExpr, ARG_POS, ARG_NAMED, ARG_NAMED_OPT, NamedTupleExpr, TypedDictExpr, Argument, - UnicodeExpr, EllipsisExpr, TempNode, EnumCallExpr, + UnicodeExpr, EllipsisExpr, TempNode, EnumCallExpr, Statement, COVARIANT, CONTRAVARIANT, INVARIANT, ARG_OPT, SymbolTableNode ) from mypy.types import ( @@ -35,6 +35,12 @@ class DeclInfo: is_def = None # type: bool +NAMEDTUP_CLASS_ERROR = ('Invalid statement in NamedTuple definition; ' + 'expected "field_name: field_type"') +TPDICT_CLASS_ERROR = ('Invalid statement in TypedDict definition; ' + 'expected "field_name: field_type"') + + class Special: """Handling of special-cased types. @@ -166,37 +172,67 @@ def lookup_base(self, defn: ClassDef, p: Callable[[RefExpr], bool] = lambda _: F res = self.lookup(defn.name, defn) return res + def analyze_typeddict_classdef(self, defn: ClassDef) -> bool: + node = self.lookup_base(defn, is_typeddict) + if node is None: + return False + if self.semanalyzer.options.python_version < (3, 6): + self.fail('TypedDict class syntax is only supported in Python 3.6', defn) + fields, types = self.analyze_typeddict_bases(defn) + newfields = [] # type: List[str] + newtypes = [] # type: List[Type] + for stmt in defn.defs.body: + if not isinstance(stmt, AssignmentStmt): + # Still allow pass or ... (for empty TypedDict's). + if not isinstance(stmt, (PassStmt, ExpressionStmt, EllipsisExpr)): + self.fail(TPDICT_CLASS_ERROR, stmt) + elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr): + # An assignment, but an invalid one. + self.fail(TPDICT_CLASS_ERROR, stmt) + else: + name = stmt.lvalues[0].name + if name in fields: + self.fail('Cannot overwrite TypedDict field "{}" while extending' + .format(name), stmt) + if name in newfields: + self.fail('Duplicate TypedDict field "{}"'.format(name), stmt) + if stmt.type is None or hasattr(stmt, 'new_syntax') and not stmt.new_syntax: + self.fail(TPDICT_CLASS_ERROR, stmt) + elif not isinstance(stmt.rvalue, TempNode): + # x: int assigns rvalue to TempNode(AnyType()) + self.fail('Right hand side values are not supported in TypedDict', stmt) + type = AnyType() if stmt.type is None else self.semanalyzer.anal_type(stmt.type) + newfields.append(name) + newtypes.append(type) + + fields.extend(newfields) + types.extend(newtypes) + node.node = self.build_typeddict_typeinfo(defn.name, fields, types) + node.kind = GDEF + return True + def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool: node = self.lookup_base(defn, lambda x: x.fullname == 'typing.NamedTuple') if node is None: return False if self.semanalyzer.options.python_version < (3, 6): self.fail('NamedTuple class syntax is only supported in Python 3.6', defn) - node.kind = GDEF # TODO in process_namedtuple_definition also applies here - NAMEDTUP_CLASS_ERROR = ('Invalid statement in NamedTuple definition; ' - 'expected "field_name: field_type"') - if len(defn.base_type_exprs) > 1: - self.fail('NamedTuple should be a single base', defn) - items = [] # type: List[str] - types = [] # type: List[Type] + fields, types = self.analyze_namedtuple_bases(defn) + newfields = [] # type: List[str] + newtypes = [] # type: List[Type] default_items = {} # type: Dict[str, Expression] for stmt in defn.defs.body: if not isinstance(stmt, AssignmentStmt): # Still allow pass or ... (for empty namedtuples). - if (not isinstance(stmt, PassStmt) and - not (isinstance(stmt, ExpressionStmt) and - isinstance(stmt.expr, EllipsisExpr))): + if not isinstance(stmt, (PassStmt, ExpressionStmt, EllipsisExpr)): self.fail(NAMEDTUP_CLASS_ERROR, stmt) elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr): # An assignment, but an invalid one. self.fail(NAMEDTUP_CLASS_ERROR, stmt) else: - # Append name and type in this case... name = stmt.lvalues[0].name - items.append(name) - types.append(AnyType() if stmt.type is None - else self.semanalyzer.anal_type(stmt.type)) - # ...despite possible minor failures that allow further analyzis. + if name in newfields: + self.fail('Duplicate NamedTuple field "{}"'.format(name), stmt) if name.startswith('_'): self.fail('NamedTuple field name cannot start with an underscore: {}' .format(name), stmt) @@ -209,17 +245,21 @@ def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool: stmt) else: default_items[name] = stmt.rvalue - node.node = self.build_namedtuple_typeinfo(defn.name, items, types, default_items) + type = AnyType() if stmt.type is None else self.semanalyzer.anal_type(stmt.type) + newfields.append(name) + newtypes.append(type) + fields.extend(newfields) + types.extend(newtypes) + node.node = self.build_namedtuple_typeinfo(defn.name, fields, types, default_items) + node.kind = GDEF return True - def analyze_typeddict_classdef(self, defn: ClassDef) -> bool: - # special case for TypedDict - node = self.lookup_base(defn, is_typeddict) - print(node) - if node is None: - return False - if self.semanalyzer.options.python_version < (3, 6): - self.fail('TypedDict class syntax is only supported in Python 3.6', defn) + def analyze_namedtuple_bases(self, defn: ClassDef) -> Tuple[List[str], List[Type]]: + if len(defn.base_type_exprs) > 1: + self.fail('NamedTuple should be a single base', defn) + return ([], []) + + def analyze_typeddict_bases(self, defn: ClassDef) -> Tuple[List[str], List[Type]]: typeddict_bases = [cast(RefExpr, expr) for expr in defn.base_type_exprs if is_typeddict(expr)] if typeddict_bases != defn.base_type_exprs: self.fail("All bases of a new TypedDict must be TypedDict types", defn) @@ -239,49 +279,7 @@ def analyze_typeddict_classdef(self, defn: ClassDef) -> bool: newdict.pop(key) newfields.extend(newdict.keys()) newtypes.extend(newdict.values()) - fields, types = self.check_typeddict_classdef(defn, newfields) - newfields.extend(fields) - newtypes.extend(types) - node.node = self.build_typeddict_typeinfo(defn.name, newfields, newtypes) - node.kind = GDEF - return True - - def check_typeddict_classdef(self, defn: ClassDef, - oldfields: List[str] = None) -> Tuple[List[str], List[Type]]: - TPDICT_CLASS_ERROR = ('Invalid statement in TypedDict definition; ' - 'expected "field_name: field_type"') - fields = [] # type: List[str] - types = [] # type: List[Type] - for stmt in defn.defs.body: - if not isinstance(stmt, AssignmentStmt): - # Still allow pass or ... (for empty TypedDict's). - if (not isinstance(stmt, PassStmt) and - not (isinstance(stmt, ExpressionStmt) and - isinstance(stmt.expr, EllipsisExpr))): - self.fail(TPDICT_CLASS_ERROR, stmt) - elif len(stmt.lvalues) > 1 or not isinstance(stmt.lvalues[0], NameExpr): - # An assignment, but an invalid one. - self.fail(TPDICT_CLASS_ERROR, stmt) - else: - name = stmt.lvalues[0].name - if name in (oldfields or []): - self.fail('Cannot overwrite TypedDict field "{}" while extending' - .format(name), stmt) - continue - if name in fields: - self.fail('Duplicate TypedDict field "{}"'.format(name), stmt) - continue - # Append name and type in this case... - fields.append(name) - types.append(AnyType() if stmt.type is None - else self.semanalyzer.anal_type(stmt.type)) - # ...despite possible minor failures that allow further analyzis. - if stmt.type is None or hasattr(stmt, 'new_syntax') and not stmt.new_syntax: - self.fail(TPDICT_CLASS_ERROR, stmt) - elif not isinstance(stmt.rvalue, TempNode): - # x: int assigns rvalue to TempNode(AnyType()) - self.fail('Right hand side values are not supported in TypedDict', stmt) - return fields, types + return newfields, newtypes def parse_newtype_args(self, name: str, call: CallExpr, context: Context) -> Optional[Type]: has_failed = False diff --git a/test-data/unit/check-typeddict.test b/test-data/unit/check-typeddict.test index f8e4f721ad7e3..813f614c447c7 100644 --- a/test-data/unit/check-typeddict.test +++ b/test-data/unit/check-typeddict.test @@ -181,7 +181,7 @@ class Bad(TypedDict): x: str # E: Duplicate TypedDict field "x" b: Bad -reveal_type(b) # E: Revealed type is 'TypedDict(x=builtins.int, _fallback=__main__.Bad)' +reveal_type(b) # E: Revealed type is 'TypedDict(x=builtins.str, _fallback=__main__.Bad)' [builtins fixtures/dict.pyi] [case testCannotCreateTypedDictWithClassOverwriting2] @@ -209,7 +209,7 @@ class Point2(Point1): x: float # E: Cannot overwrite TypedDict field "x" while extending p2: Point2 -reveal_type(p2) # E: Revealed type is 'TypedDict(x=builtins.int, _fallback=__main__.Point2)' +reveal_type(p2) # E: Revealed type is 'TypedDict(x=builtins.float, _fallback=__main__.Point2)' [builtins fixtures/dict.pyi] diff --git a/test-data/unit/lib-stub/mypy_extensions.pyi b/test-data/unit/lib-stub/mypy_extensions.pyi index 6e1e3b0ed2853..2ab1dbee8bc83 100644 --- a/test-data/unit/lib-stub/mypy_extensions.pyi +++ b/test-data/unit/lib-stub/mypy_extensions.pyi @@ -3,6 +3,6 @@ from typing import Dict, Type, TypeVar T = TypeVar('T') -def TypedDict(typename: str, fields: Dict[str, Type[T]]) -> Type[dict]: pass +class TypedDict: pass class NoReturn: pass diff --git a/test-data/unit/lib-stub/typing.pyi b/test-data/unit/lib-stub/typing.pyi index 01ac7b14f7b9e..316944b15203d 100644 --- a/test-data/unit/lib-stub/typing.pyi +++ b/test-data/unit/lib-stub/typing.pyi @@ -14,7 +14,7 @@ Tuple = 0 Callable = 0 builtinclass = 0 _promote = 0 -NamedTuple = 0 +class NamedTuple: pass Type = 0 no_type_check = 0 ClassVar = 0 From d60803316cdfce8ba17b57af7403767028a14831 Mon Sep 17 00:00:00 2001 From: elazar Date: Fri, 7 Apr 2017 00:20:02 +0300 Subject: [PATCH 11/18] Lint --- mypy/specialtype.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/mypy/specialtype.py b/mypy/specialtype.py index f6ea312e98f26..95a56473dda11 100644 --- a/mypy/specialtype.py +++ b/mypy/specialtype.py @@ -163,7 +163,8 @@ def check_newtype(self, call: CallExpr, var_name: str = None) -> Optional[TypeIn return None return newtype_class_info - def lookup_base(self, defn: ClassDef, p: Callable[[RefExpr], bool] = lambda _: False) -> Optional[SymbolTableNode]: + def lookup_base(self, defn: ClassDef, + p: Callable[[RefExpr], bool] = lambda _: False) -> Optional[SymbolTableNode]: res = None for base_expr in defn.base_type_exprs: if isinstance(base_expr, RefExpr): @@ -260,10 +261,12 @@ def analyze_namedtuple_bases(self, defn: ClassDef) -> Tuple[List[str], List[Type return ([], []) def analyze_typeddict_bases(self, defn: ClassDef) -> Tuple[List[str], List[Type]]: - typeddict_bases = [cast(RefExpr, expr) for expr in defn.base_type_exprs if is_typeddict(expr)] + typeddict_bases = [cast(RefExpr, expr) for expr in defn.base_type_exprs + if is_typeddict(expr)] if typeddict_bases != defn.base_type_exprs: self.fail("All bases of a new TypedDict must be TypedDict types", defn) - typeddict_bases = [expr for expr in typeddict_bases if expr.fullname != 'mypy_extensions.TypedDict'] + typeddict_bases = [expr for expr in typeddict_bases + if expr.fullname != 'mypy_extensions.TypedDict'] newfields = [] # type: List[str] newtypes = [] # type: List[Type] for base in typeddict_bases: From 7b01f1ce1d17fac8e261d53c018888ae0cff15aa Mon Sep 17 00:00:00 2001 From: elazar Date: Fri, 7 Apr 2017 00:33:51 +0300 Subject: [PATCH 12/18] remove unneeded class --- mypy/specialtype.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/mypy/specialtype.py b/mypy/specialtype.py index 95a56473dda11..a3b43af5dde00 100644 --- a/mypy/specialtype.py +++ b/mypy/specialtype.py @@ -9,13 +9,16 @@ from mypy.exprtotype import expr_to_unanalyzed_type, TypeTranslationError from mypy.nodes import ( - TypeInfo, AssignmentStmt, FuncDef, ClassDef, Var, GDEF, Expression, - Block, NameExpr, TupleExpr, ListExpr, ExpressionStmt, PassStmt, - DictExpr, CallExpr, RefExpr, Context, SymbolTable, UNBOUND_TVAR, - MDEF, Decorator, TypeVarExpr, NewTypeExpr, StrExpr, BytesExpr, - ARG_POS, ARG_NAMED, ARG_NAMED_OPT, NamedTupleExpr, TypedDictExpr, Argument, - UnicodeExpr, EllipsisExpr, TempNode, EnumCallExpr, Statement, - COVARIANT, CONTRAVARIANT, INVARIANT, ARG_OPT, SymbolTableNode + TypeVarExpr, NewTypeExpr, NamedTupleExpr, TypedDictExpr, EnumCallExpr, + TypeInfo, SymbolTableNode, SymbolTable, Context, TempNode, + Var, Argument, NameExpr, RefExpr, + AssignmentStmt, FuncDef, ClassDef, Block, + Expression, EllipsisExpr, ExpressionStmt, PassStmt, + TupleExpr, ListExpr, DictExpr, CallExpr, Decorator, + StrExpr, BytesExpr, UnicodeExpr, + COVARIANT, CONTRAVARIANT, INVARIANT, + ARG_OPT, ARG_POS, ARG_NAMED, ARG_NAMED_OPT, + GDEF, MDEF, UNBOUND_TVAR, ) from mypy.types import ( NoneTyp, CallableType, Instance, Type, TypeVarType, AnyType, @@ -26,15 +29,6 @@ import mypy.semanal -class DeclInfo: - var_name = None # type: str - name = None # type: str - fullname = None # type: str - calleename = None # type: str - call = None # type: CallExpr - is_def = None # type: bool - - NAMEDTUP_CLASS_ERROR = ('Invalid statement in NamedTuple definition; ' 'expected "field_name: field_type"') TPDICT_CLASS_ERROR = ('Invalid statement in TypedDict definition; ' From 7a40ac42e22c2569292e29920f9934397cfd2908 Mon Sep 17 00:00:00 2001 From: elazar Date: Fri, 7 Apr 2017 03:48:26 +0300 Subject: [PATCH 13/18] moving namedtuple to use semantic info. NewNamedTuple tests pass --- mypy/nodes.py | 7 +- mypy/semanal.py | 31 +++-- mypy/specialtype.py | 127 ++++++++++++++++++++- test-data/unit/check-class-namedtuple.test | 22 ++-- 4 files changed, 155 insertions(+), 32 deletions(-) diff --git a/mypy/nodes.py b/mypy/nodes.py index d9c6b1ca2eb9e..00c5ecba7dec6 100644 --- a/mypy/nodes.py +++ b/mypy/nodes.py @@ -2,9 +2,10 @@ import os from abc import abstractmethod +from collections import OrderedDict from typing import ( - Any, TypeVar, List, Tuple, cast, Set, Dict, Union, Optional + Any, TypeVar, List, Tuple, cast, Set, Dict, Union, Optional, MutableMapping ) import mypy.strconv @@ -675,9 +676,9 @@ class Var(SymbolNode): type = None # type: mypy.types.Type # Declared or inferred type, or None # Is this the first argument to an ordinary method (usually "self")? is_self = False + is_inferred = False is_ready = False # If inferred, is the inferred type available? # Is this initialized explicitly to a non-None value in class body? - is_inferred = False is_initialized_in_class = False is_staticmethod = False is_classmethod = False @@ -2400,7 +2401,7 @@ def deserialize(cls, data: JsonDict) -> 'SymbolTableNode': return stnode -class SymbolTable(Dict[str, SymbolTableNode]): +class SymbolTable(OrderedDict, MutableMapping[str, SymbolTableNode]): def __str__(self) -> str: a = [] # type: List[str] for key, value in self.items(): diff --git a/mypy/semanal.py b/mypy/semanal.py index b5be9d439d7be..1554772296cf1 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -58,7 +58,7 @@ GlobalDecl, SuperExpr, DictExpr, CallExpr, RefExpr, OpExpr, UnaryExpr, SliceExpr, CastExpr, RevealTypeExpr, TypeApplication, Context, SymbolTable, SymbolTableNode, BOUND_TVAR, UNBOUND_TVAR, ListComprehension, GeneratorExpr, - LambdaExpr, MDEF, Decorator, SetExpr, TypeVarExpr, + LambdaExpr, MDEF, Decorator, SetExpr, TypeVarExpr, TempNode, StrExpr, BytesExpr, PrintStmt, ConditionalExpr, PromoteExpr, ComparisonExpr, StarExpr, ARG_POS, ARG_NAMED, MroError, type_aliases, YieldFromExpr, NonlocalDecl, SymbolNode, @@ -713,7 +713,6 @@ def analyze_class_body(self, defn: ClassDef) -> Iterator[bool]: for decorator in defn.decorators: decorator.accept(self) - self.specialtype.analyze_namedtuple_classdef(defn) self.enter_class(defn) yield True @@ -725,6 +724,8 @@ def analyze_class_body(self, defn: ClassDef) -> Iterator[bool]: self.unbind_class_type_vars() + self.specialtype.dispatch_classdef(defn) + def enter_class(self, defn: ClassDef) -> None: # Remember previous active class self.type_stack.append(self.type) @@ -1217,13 +1218,12 @@ def visit_import_all(self, i: ImportAll) -> None: pass def add_unknown_symbol(self, name: str, context: Context, is_import: bool = False) -> None: - var = Var(name) + var = Var(name, AnyType()) if self.type: var._fullname = self.type.fullname() + "." + name else: var._fullname = self.qualified_name(name) var.is_ready = True - var.type = AnyType() var.is_suppressed_import = is_import self.add_symbol(name, SymbolTableNode(GDEF, var, self.cur_mod_id), context) @@ -1258,7 +1258,8 @@ def anal_type(self, t: Type, allow_tuple_literal: bool = False, def visit_assignment_stmt(self, s: AssignmentStmt) -> None: for lval in s.lvalues: - self.analyze_lvalue(lval, explicit_type=s.type is not None) + self.analyze_lvalue(lval, explicit_type=s.type is not None, + is_initialized=not isinstance(s.rvalue, TempNode)) self.check_classvar(s) s.rvalue.accept(self) if s.type: @@ -1354,7 +1355,7 @@ def check_and_set_up_type_alias(self, s: AssignmentStmt) -> None: def analyze_lvalue(self, lval: Lvalue, nested: bool = False, add_global: bool = False, - explicit_type: bool = False) -> None: + explicit_type: bool = False, is_initialized: bool = True) -> None: """Analyze an lvalue or assignment target. Only if add_global is True, add name to globals table. If nested @@ -1372,6 +1373,7 @@ def analyze_lvalue(self, lval: Lvalue, nested: bool = False, v = Var(lval.name) v.set_line(lval) v._fullname = self.qualified_name(lval.name) + v.is_inferred = not explicit_type v.is_ready = False # Type not inferred yet lval.node = v lval.is_def = True @@ -1389,6 +1391,7 @@ def analyze_lvalue(self, lval: Lvalue, nested: bool = False, # Define new local name. v = Var(lval.name) v.set_line(lval) + v.is_inferred = not explicit_type lval.node = v lval.is_def = True lval.kind = LDEF @@ -1399,8 +1402,9 @@ def analyze_lvalue(self, lval: Lvalue, nested: bool = False, # Define a new attribute within class body. v = Var(lval.name) v.info = self.type - v.is_initialized_in_class = True + v.is_initialized_in_class = is_initialized v.set_line(lval) + v.is_inferred = not explicit_type lval.node = v lval.is_def = True lval.kind = MDEF @@ -1432,7 +1436,7 @@ def analyze_lvalue(self, lval: Lvalue, nested: bool = False, self.analyze_tuple_or_list_lvalue(lval, add_global, explicit_type) elif isinstance(lval, StarExpr): if nested: - self.analyze_lvalue(lval.expr, nested, add_global, explicit_type) + self.analyze_lvalue(lval.expr, nested, add_global, explicit_type, is_initialized) else: self.fail('Starred assignment target must be in a list or tuple', lval) else: @@ -1452,7 +1456,7 @@ def analyze_tuple_or_list_lvalue(self, lval: Union[ListExpr, TupleExpr], star_exprs[0].valid = True for i in items: self.analyze_lvalue(i, nested=True, add_global=add_global, - explicit_type = explicit_type) + explicit_type=explicit_type) def analyze_member_lvalue(self, lval: MemberExpr) -> None: lval.accept(self) @@ -2544,7 +2548,8 @@ def visit_block(self, b: Block) -> None: def visit_assignment_stmt(self, s: AssignmentStmt) -> None: if self.sem.is_module_scope(): for lval in s.lvalues: - self.analyze_lvalue(lval, explicit_type=s.type is not None) + self.analyze_lvalue(lval, explicit_type=s.type is not None, + is_initialized=s.rvalue is not None) def visit_func_def(self, func: FuncDef) -> None: sem = self.sem @@ -2696,9 +2701,11 @@ def visit_try_stmt(self, s: TryStmt) -> None: if self.sem.is_module_scope(): self.sem.analyze_try_stmt(s, self, add_global=self.sem.is_module_scope()) - def analyze_lvalue(self, lvalue: Lvalue, explicit_type: bool = False) -> None: + def analyze_lvalue(self, lvalue: Lvalue, explicit_type: bool = False, + is_initialized: bool = True) -> None: self.sem.analyze_lvalue(lvalue, add_global=self.sem.is_module_scope(), - explicit_type=explicit_type) + explicit_type=explicit_type, + is_initialized=is_initialized) def kind_by_scope(self) -> int: if self.sem.is_module_scope(): diff --git a/mypy/specialtype.py b/mypy/specialtype.py index a3b43af5dde00..1748c217ecd05 100644 --- a/mypy/specialtype.py +++ b/mypy/specialtype.py @@ -212,7 +212,9 @@ def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool: return False if self.semanalyzer.options.python_version < (3, 6): self.fail('NamedTuple class syntax is only supported in Python 3.6', defn) - fields, types = self.analyze_namedtuple_bases(defn) + if len(defn.base_type_exprs) > 1: + self.fail('NamedTuple should be a single base', defn) + fields, types = [], [] newfields = [] # type: List[str] newtypes = [] # type: List[Type] default_items = {} # type: Dict[str, Expression] @@ -249,10 +251,125 @@ def analyze_namedtuple_classdef(self, defn: ClassDef) -> bool: node.kind = GDEF return True - def analyze_namedtuple_bases(self, defn: ClassDef) -> Tuple[List[str], List[Type]]: - if len(defn.base_type_exprs) > 1: - self.fail('NamedTuple should be a single base', defn) - return ([], []) + def dispatch_classdef(self, defn: ClassDef) -> bool: + node = self.lookup_base(defn, lambda x: x.fullname == 'typing.NamedTuple') + if node is None: + return False + self.analyze_namedtuple_classdef_1(defn.info) + return True + + def analyze_namedtuple_classdef_1(self, info: TypeInfo) -> None: + if self.semanalyzer.options.python_version < (3, 6): + self.fail('NamedTuple class syntax is only supported in Python 3.6', info) + if len(info.direct_base_classes()) > 1: + self.fail('NamedTuple should be a single base', info) + fields = [] # type: List[str] + types = [] # type: List[Type] + default_items = {} # type: Dict[str, Expression] + for name, sym in info.names.items(): + node = sym.node + if name.startswith('_'): + self.fail('NamedTuple field name cannot start with an underscore: {}' + .format(name), node) + if isinstance(node, Var): + if node.type and not node.is_inferred: + fields.append(name) + types.append(node.type) + if node.is_initialized_in_class: + default_items[name] = EllipsisExpr() + elif default_items: + self.fail('Non-default NamedTuple fields cannot follow default fields', + node) + else: + self.fail(NAMEDTUP_CLASS_ERROR, node) + self.update_namedtuple_typeinfo(info, fields, types, default_items) + + def update_namedtuple_typeinfo(self, info: TypeInfo, items: List[str], types: List[Type], + default_items: Dict[str, Expression] = None) -> None: + default_items = default_items or {} + strtype = self.str_type() + object_type = self.object_type() + basetuple_type = self.named_type('__builtins__.tuple', [AnyType()]) + dictype = (self.named_type_or_none('builtins.dict', [strtype, AnyType()]) + or object_type) + # Actual signature should return OrderedDict[str, Union[types]] + ordereddictype = (self.named_type_or_none('builtins.dict', [strtype, AnyType()]) + or object_type) + fallback = self.named_type('__builtins__.tuple') + # Note: actual signature should accept an invariant version of Iterable[UnionType[types]]. + # but it can't be expressed. 'new' and 'len' should be callable types. + iterable_type = self.named_type_or_none('typing.Iterable', [AnyType()]) + function_type = self.named_type('__builtins__.function') + + info.bases += [fallback] + info.is_named_tuple = True + info.tuple_type = TupleType(types, fallback) + + def add_field(var: Var, is_initialized_in_class: bool = False, + is_property: bool = False) -> None: + var.info = info + var.is_initialized_in_class = is_initialized_in_class + var.is_property = is_property + info.names[var.name()] = SymbolTableNode(MDEF, var) + + vars = [Var(item, typ) for item, typ in zip(items, types)] + for var in vars: + add_field(var, is_property=True) + + tuple_of_strings = TupleType([strtype for _ in items], basetuple_type) + add_field(Var('_fields', tuple_of_strings), is_initialized_in_class=True) + add_field(Var('_field_types', dictype), is_initialized_in_class=True) + add_field(Var('_field_defaults', dictype), is_initialized_in_class=True) + add_field(Var('_source', strtype), is_initialized_in_class=True) + + tvd = TypeVarDef('NT', 1, [], info.tuple_type) + selftype = TypeVarType(tvd) + + def add_method(funcname: str, + ret: Type, + args: List[Argument], + name: str = None, + is_classmethod: bool = False, + ) -> None: + if is_classmethod: + first = [Argument(Var('cls'), TypeType(selftype), None, ARG_POS)] + else: + first = [Argument(Var('self'), selftype, None, ARG_POS)] + args = first + args + + types = [arg.type_annotation for arg in args] + items = [arg.variable.name() for arg in args] + arg_kinds = [arg.kind for arg in args] + signature = CallableType(types, arg_kinds, items, ret, function_type, + name=name or info.name() + '.' + funcname) + signature.variables = [tvd] + func = FuncDef(funcname, args, Block([]), typ=signature) + func.info = info + func.is_class = is_classmethod + if is_classmethod: + v = Var(funcname, signature) + v.is_classmethod = True + v.info = info + dec = Decorator(func, [NameExpr('classmethod')], v) + info.names[funcname] = SymbolTableNode(MDEF, dec) + else: + info.names[funcname] = SymbolTableNode(MDEF, func) + + add_method('_replace', ret=selftype, + args=[Argument(var, var.type, EllipsisExpr(), ARG_NAMED_OPT) for var in vars]) + + def make_init_arg(var: Var) -> Argument: + default = default_items.get(var.name(), None) + kind = ARG_POS if default is None else ARG_OPT + return Argument(var, var.type, default, kind) + + add_method('__init__', ret=NoneTyp(), name=info.name(), + args=[make_init_arg(var) for var in vars]) + add_method('_asdict', args=[], ret=ordereddictype) + add_method('_make', ret=selftype, is_classmethod=True, + args=[Argument(Var('iterable', iterable_type), iterable_type, None, ARG_POS), + Argument(Var('new'), AnyType(), EllipsisExpr(), ARG_NAMED_OPT), + Argument(Var('len'), AnyType(), EllipsisExpr(), ARG_NAMED_OPT)]) def analyze_typeddict_bases(self, defn: ClassDef) -> Tuple[List[str], List[Type]]: typeddict_bases = [cast(RefExpr, expr) for expr in defn.base_type_exprs diff --git a/test-data/unit/check-class-namedtuple.test b/test-data/unit/check-class-namedtuple.test index 46cb87d9a018b..2efa464c2839d 100644 --- a/test-data/unit/check-class-namedtuple.test +++ b/test-data/unit/check-class-namedtuple.test @@ -321,7 +321,7 @@ class Y(NamedTuple): x: int y: str -reveal_type([X(3, 'b'), Y(1, 'a')]) # E: Revealed type is 'builtins.list[Tuple[builtins.int, builtins.str]]' +reveal_type([X(3, 'b'), Y(1, 'a')]) # E: Revealed type is 'builtins.list[Tuple[builtins.int, builtins.str, fallback=typing.NamedTuple]]' [builtins fixtures/list.pyi] @@ -345,7 +345,7 @@ from typing import NamedTuple class X(NamedTuple): x: int y = z = 2 # E: Invalid statement in NamedTuple definition; expected "field_name: field_type" - def f(self): pass # E: Invalid statement in NamedTuple definition; expected "field_name: field_type" + def f(self): pass [case testNewNamedTupleWithInvalidItems2] # flags: --python-version 3.6 @@ -354,16 +354,14 @@ import typing class X(typing.NamedTuple): x: int y = 1 - x.x: int + +class Y(typing.NamedTuple): z: str = 'z' aa: int [out] main:6: error: Invalid statement in NamedTuple definition; expected "field_name: field_type" -main:7: error: Invalid statement in NamedTuple definition; expected "field_name: field_type" -main:7: error: Type cannot be declared in assignment to non-self attribute -main:7: error: "int" has no attribute "x" -main:9: error: Non-default NamedTuple fields cannot follow default fields +main:10: error: Non-default NamedTuple fields cannot follow default fields [builtins fixtures/list.pyi] @@ -390,7 +388,7 @@ def f(a: Type[N]): main:8: error: Unsupported type Type["N"] [case testNewNamedTupleWithDefaults] -# flags: --fast-parser --python-version 3.6 +# flags: --python-version 3.6 from typing import List, NamedTuple, Optional class X(NamedTuple): @@ -430,7 +428,7 @@ UserDefined(1) # E: Argument 1 to "UserDefined" has incompatible type "int"; ex [builtins fixtures/list.pyi] [case testNewNamedTupleWithDefaultsStrictOptional] -# flags: --fast-parser --strict-optional --python-version 3.6 +# flags: --strict-optional --python-version 3.6 from typing import List, NamedTuple, Optional class HasNone(NamedTuple): @@ -449,7 +447,7 @@ class CannotBeNone(NamedTuple): [builtins fixtures/list.pyi] [case testNewNamedTupleWrongType] -# flags: --fast-parser --python-version 3.6 +# flags: --python-version 3.6 from typing import NamedTuple class X(NamedTuple): @@ -457,14 +455,14 @@ class X(NamedTuple): y: int = 'not an int' # E: Incompatible types in assignment (expression has type "str", variable has type "int") [case testNewNamedTupleErrorInDefault] -# flags: --fast-parser --python-version 3.6 +# flags: --python-version 3.6 from typing import NamedTuple class X(NamedTuple): x: int = 1 + '1' # E: Unsupported operand types for + ("int" and "str") [case testNewNamedTupleInheritance] -# flags: --fast-parser --python-version 3.6 +# flags: --python-version 3.6 from typing import NamedTuple class X(NamedTuple): From 191a5d790e5789f1d6080048a8dadf622a1d213f Mon Sep 17 00:00:00 2001 From: elazar Date: Fri, 7 Apr 2017 04:06:55 +0300 Subject: [PATCH 14/18] some support for generic namedtuple --- mypy/typeanal.py | 3 --- test-data/unit/check-class-namedtuple.test | 13 +++++++++++++ 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 45772f4728884..47faff259a7fa 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -226,9 +226,6 @@ def visit_unbound_type(self, t: UnboundType) -> Type: if tup is not None: # The class has a Tuple[...] base class so it will be # represented as a tuple type. - if t.args: - self.fail('Generic tuple types not supported', t) - return AnyType() return tup.copy_modified(items=self.anal_array(tup.items), fallback=instance) td = info.typeddict_type diff --git a/test-data/unit/check-class-namedtuple.test b/test-data/unit/check-class-namedtuple.test index 2efa464c2839d..d3a0061755513 100644 --- a/test-data/unit/check-class-namedtuple.test +++ b/test-data/unit/check-class-namedtuple.test @@ -480,3 +480,16 @@ Y(y=1, x='1').method() class CallsBaseInit(X): def __init__(self, x: str) -> None: super().__init__(x) + +[case testNewNamedTupleGeneric] +# flags: --python-version 3.6 +from typing import NamedTuple, Generic, TypeVar +T = TypeVar('T') +class A(NamedTuple, Generic[T]): + x: T + y: str + +a : A[int] +reveal_type(a.x) # E: Revealed type is 'builtins.int*' +b : A[str] +reveal_type(b.x) # E: Revealed type is 'builtins.str*' From b09df603c6d882b881cf83df601042e212897b0b Mon Sep 17 00:00:00 2001 From: elazar Date: Fri, 7 Apr 2017 04:32:29 +0300 Subject: [PATCH 15/18] fix non-named generic tuple --- mypy/typeanal.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/mypy/typeanal.py b/mypy/typeanal.py index 47faff259a7fa..f23dcb4ed3a18 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -226,6 +226,9 @@ def visit_unbound_type(self, t: UnboundType) -> Type: if tup is not None: # The class has a Tuple[...] base class so it will be # represented as a tuple type. + if t.args and not info.is_named_tuple: + self.fail('Generic tuple types not supported', t) + return AnyType() return tup.copy_modified(items=self.anal_array(tup.items), fallback=instance) td = info.typeddict_type From 6bdbaa7a764eca9af50d16775875bd470f24cd41 Mon Sep 17 00:00:00 2001 From: elazar Date: Fri, 7 Apr 2017 14:13:21 +0300 Subject: [PATCH 16/18] don't test for generics; define macro expansion --- mypy/semanal.py | 6 +- mypy/specialtype.py | 67 +++++++++++++++++++++- mypy/typeanal.py | 2 +- test-data/unit/check-class-namedtuple.test | 13 ----- test-data/unit/check-tuples.test | 4 +- 5 files changed, 72 insertions(+), 20 deletions(-) diff --git a/mypy/semanal.py b/mypy/semanal.py index 1554772296cf1..4f987aea70511 100644 --- a/mypy/semanal.py +++ b/mypy/semanal.py @@ -86,6 +86,7 @@ from mypy.sametypes import is_same_type from mypy.options import Options from mypy.specialtype import Special +import mypy.specialtype T = TypeVar('T') @@ -2548,8 +2549,7 @@ def visit_block(self, b: Block) -> None: def visit_assignment_stmt(self, s: AssignmentStmt) -> None: if self.sem.is_module_scope(): for lval in s.lvalues: - self.analyze_lvalue(lval, explicit_type=s.type is not None, - is_initialized=s.rvalue is not None) + self.analyze_lvalue(lval, explicit_type=s.type is not None) def visit_func_def(self, func: FuncDef) -> None: sem = self.sem @@ -2607,6 +2607,8 @@ def visit_overloaded_func_def(self, func: OverloadedFuncDef) -> None: sem.function_stack.pop() def visit_class_def(self, cdef: ClassDef) -> None: + for expr in cdef.base_type_exprs: + expr.accept(self) kind = self.kind_by_scope() if kind == LDEF: return diff --git a/mypy/specialtype.py b/mypy/specialtype.py index 1748c217ecd05..8689368ceb2cd 100644 --- a/mypy/specialtype.py +++ b/mypy/specialtype.py @@ -35,6 +35,61 @@ 'expected "field_name: field_type"') +# TODO: this is problematic due to the requirement in ClassDef to do type analysis before constrution + +def build_namedtuple_classdef_from_call(call: CallExpr, fullname: str + ) -> Union[str, ClassDef]: + # TODO: Share code with check_argument_count in checkexpr.py? + args = call.args + if len(args) < 2: + return "Too few arguments for namedtuple()" + if len(args) > 2: + # FIX incorrect. There are two additional parameters + return "Too many arguments for namedtuple()" + if call.arg_kinds != [ARG_POS, ARG_POS]: + return "Unexpected arguments to namedtuple()" + String = (StrExpr, BytesExpr, UnicodeExpr) + if not isinstance(args[0], String): + return "namedtuple() expects a string literal as the first argument" + typename = args[0].value + typedecls = [] # type: List[Type] + if not isinstance(args[1], (ListExpr, TupleExpr)): + if (fullname == 'collections.namedtuple' and isinstance(args[1], String)): + str_expr = cast(StrExpr, args[1]) + names = str_expr.value.replace(',', ' ').split() + else: + return "List or tuple literal expected as the second argument to namedtuple()" + else: + listexpr = args[1] + if fullname == 'collections.namedtuple': + # The fields argument contains just names, with implicit Any types. + if any(not isinstance(item, String) for item in listexpr.items): + return "String literal expected as namedtuple() item" + names = [cast(StrExpr, item).value for item in listexpr.items] + else: + # The fields argument contains (name, type) tuples. + names = [] + for item in listexpr.items: + if isinstance(item, TupleExpr): + if len(item.items) != 2: + return "Invalid NamedTuple field definition" + name, type_node = item.items + if not isinstance(type_node, RefExpr): + return "TEMP: cannot parse complex annotations for namedtuple" + if isinstance(name, String): + names.append(name.value) + typedecls.append(UnboundType(type_node.fullname)) + else: + return "Tuple expected as NamedTuple() field" + if not typedecls: + typedecls = [AnyType() for _ in names] + return ClassDef(typename, + defs=Block([AssignmentStmt([NameExpr(name)], NameExpr('None'), + decl, new_syntax=True) + for name, decl in zip(names, typedecls)]), + base_type_exprs=[NameExpr('typing.NamedTuple')]) + + class Special: """Handling of special-cased types. @@ -103,7 +158,17 @@ def dispatch_call(self, call: CallExpr, calleename: str, tvar = self.check_typevar(call, name, fullname) info = None elif calleename in ('collections.namedtuple', 'typing.NamedTuple'): - info = self.check_namedtuple(call, calleename, name) + if call.analyzed: + raise Exception(str(call.analyzed)) + defn = call.analyzed.defn + if isinstance(defn, str): + self.fail(defn) + else: + self.semanalyzer.analyze_class_body(defn) + self.analyze_namedtuple_classdef_1(defn.info) + info = defn.info + else: + info = self.check_namedtuple(call, calleename, name) elif calleename in ('enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag'): info = self.check_enum_call(call, calleename, name) elif calleename == 'mypy_extensions.TypedDict': diff --git a/mypy/typeanal.py b/mypy/typeanal.py index f23dcb4ed3a18..45772f4728884 100644 --- a/mypy/typeanal.py +++ b/mypy/typeanal.py @@ -226,7 +226,7 @@ def visit_unbound_type(self, t: UnboundType) -> Type: if tup is not None: # The class has a Tuple[...] base class so it will be # represented as a tuple type. - if t.args and not info.is_named_tuple: + if t.args: self.fail('Generic tuple types not supported', t) return AnyType() return tup.copy_modified(items=self.anal_array(tup.items), diff --git a/test-data/unit/check-class-namedtuple.test b/test-data/unit/check-class-namedtuple.test index d3a0061755513..2efa464c2839d 100644 --- a/test-data/unit/check-class-namedtuple.test +++ b/test-data/unit/check-class-namedtuple.test @@ -480,16 +480,3 @@ Y(y=1, x='1').method() class CallsBaseInit(X): def __init__(self, x: str) -> None: super().__init__(x) - -[case testNewNamedTupleGeneric] -# flags: --python-version 3.6 -from typing import NamedTuple, Generic, TypeVar -T = TypeVar('T') -class A(NamedTuple, Generic[T]): - x: T - y: str - -a : A[int] -reveal_type(a.x) # E: Revealed type is 'builtins.int*' -b : A[str] -reveal_type(b.x) # E: Revealed type is 'builtins.str*' diff --git a/test-data/unit/check-tuples.test b/test-data/unit/check-tuples.test index 08345706f0d94..e65e1500f6984 100644 --- a/test-data/unit/check-tuples.test +++ b/test-data/unit/check-tuples.test @@ -721,10 +721,8 @@ y() # Expected: "str" not callable from typing import TypeVar, Generic, Tuple T = TypeVar('T') class Test(Generic[T], Tuple[T]): pass -x = Test() # type: Test[int] +x = Test() # type: Test[int] # E: Generic tuple types not supported [builtins fixtures/tuple.pyi] -[out] -main:4: error: Generic tuple types not supported -- Variable-length tuples (Tuple[t, ...] with literal '...') From d43a606a8250eb42d47a129c91cbbe98658beaea Mon Sep 17 00:00:00 2001 From: elazar Date: Wed, 12 Apr 2017 17:44:44 +0300 Subject: [PATCH 17/18] remove comment --- mypy/specialtype.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/mypy/specialtype.py b/mypy/specialtype.py index 8689368ceb2cd..88c40063e417e 100644 --- a/mypy/specialtype.py +++ b/mypy/specialtype.py @@ -35,8 +35,6 @@ 'expected "field_name: field_type"') -# TODO: this is problematic due to the requirement in ClassDef to do type analysis before constrution - def build_namedtuple_classdef_from_call(call: CallExpr, fullname: str ) -> Union[str, ClassDef]: # TODO: Share code with check_argument_count in checkexpr.py? From c6a2039a2057dfa3284637f47f81285e56d53cf2 Mon Sep 17 00:00:00 2001 From: elazar Date: Wed, 12 Apr 2017 17:48:49 +0300 Subject: [PATCH 18/18] make treetransform implement abstract visitor --- mypy/strconv.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/mypy/strconv.py b/mypy/strconv.py index 169d44bdf9aa4..97b0449bb48e5 100644 --- a/mypy/strconv.py +++ b/mypy/strconv.py @@ -7,10 +7,10 @@ from mypy.util import short_type, IdMapper import mypy.nodes -from mypy.visitor import NodeVisitor +from mypy.visitor import AbstractNodeVisitor -class StrConv(NodeVisitor[str]): +class StrConv(AbstractNodeVisitor[str]): """Visitor for converting a node to a human-readable string. For example, an MypyFile node from program '1' is converted into @@ -503,6 +503,9 @@ def visit_slice_expr(self, o: 'mypy.nodes.SliceExpr') -> str: def visit_backquote_expr(self, o: 'mypy.nodes.BackquoteExpr') -> str: return self.dump([o.expr], o) + def visit_temp_node(self, o: 'mypy.nodes.TempNode') -> str: + return self.dump([o.type], o) + def dump_tagged(nodes: Sequence[object], tag: str, str_conv: 'StrConv') -> str: """Convert an array into a pretty-printed multiline string representation.