You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					300 lines
				
				9.4 KiB
			
		
		
			
		
	
	
					300 lines
				
				9.4 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								# ext/mypy/apply.py
							 | 
						||
| 
								 | 
							
								# Copyright (C) 2021 the SQLAlchemy authors and contributors
							 | 
						||
| 
								 | 
							
								# <see AUTHORS file>
							 | 
						||
| 
								 | 
							
								#
							 | 
						||
| 
								 | 
							
								# This module is part of SQLAlchemy and is released under
							 | 
						||
| 
								 | 
							
								# the MIT License: https://www.opensource.org/licenses/mit-license.php
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from typing import List
							 | 
						||
| 
								 | 
							
								from typing import Optional
							 | 
						||
| 
								 | 
							
								from typing import Union
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from mypy.nodes import ARG_NAMED_OPT
							 | 
						||
| 
								 | 
							
								from mypy.nodes import Argument
							 | 
						||
| 
								 | 
							
								from mypy.nodes import AssignmentStmt
							 | 
						||
| 
								 | 
							
								from mypy.nodes import CallExpr
							 | 
						||
| 
								 | 
							
								from mypy.nodes import ClassDef
							 | 
						||
| 
								 | 
							
								from mypy.nodes import MDEF
							 | 
						||
| 
								 | 
							
								from mypy.nodes import MemberExpr
							 | 
						||
| 
								 | 
							
								from mypy.nodes import NameExpr
							 | 
						||
| 
								 | 
							
								from mypy.nodes import RefExpr
							 | 
						||
| 
								 | 
							
								from mypy.nodes import StrExpr
							 | 
						||
| 
								 | 
							
								from mypy.nodes import SymbolTableNode
							 | 
						||
| 
								 | 
							
								from mypy.nodes import TempNode
							 | 
						||
| 
								 | 
							
								from mypy.nodes import TypeInfo
							 | 
						||
| 
								 | 
							
								from mypy.nodes import Var
							 | 
						||
| 
								 | 
							
								from mypy.plugin import SemanticAnalyzerPluginInterface
							 | 
						||
| 
								 | 
							
								from mypy.plugins.common import add_method_to_class
							 | 
						||
| 
								 | 
							
								from mypy.types import AnyType
							 | 
						||
| 
								 | 
							
								from mypy.types import get_proper_type
							 | 
						||
| 
								 | 
							
								from mypy.types import Instance
							 | 
						||
| 
								 | 
							
								from mypy.types import NoneTyp
							 | 
						||
| 
								 | 
							
								from mypy.types import ProperType
							 | 
						||
| 
								 | 
							
								from mypy.types import TypeOfAny
							 | 
						||
| 
								 | 
							
								from mypy.types import UnboundType
							 | 
						||
| 
								 | 
							
								from mypy.types import UnionType
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from . import infer
							 | 
						||
| 
								 | 
							
								from . import util
							 | 
						||
| 
								 | 
							
								from .names import NAMED_TYPE_SQLA_MAPPED
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def apply_mypy_mapped_attr(
								    cls: ClassDef,
								    api: SemanticAnalyzerPluginInterface,
								    item: Union[NameExpr, StrExpr],
								    attributes: List[util.SQLAlchemyAttribute],
								) -> None:
								    """Record one explicitly named class attribute as a mapped attribute.

								    ``item`` identifies the attribute either as a bare name or as a
								    string.  The first assignment in the body of ``cls`` whose first
								    lvalue carries that name is located; its explicit annotation is
								    stored in a new ``SQLAlchemyAttribute`` appended to ``attributes``,
								    and the statement is then handed to
								    ``apply_type_to_mapped_statement``.

								    An error is reported through ``util.fail`` and nothing is recorded
								    when no such assignment exists or when the assignment has no
								    annotation.  An ``item`` that is neither a ``NameExpr`` nor a
								    ``StrExpr`` is ignored.

								    """
								    if isinstance(item, NameExpr):
								        name = item.name
								    elif isinstance(item, StrExpr):
								        name = item.value
								    else:
								        return None

								    # locate the first "<name> = ..." assignment in the class body
								    target_stmt = None
								    for candidate in cls.defs.body:
								        if not isinstance(candidate, AssignmentStmt):
								            continue
								        first_lvalue = candidate.lvalues[0]
								        if isinstance(first_lvalue, NameExpr) and first_lvalue.name == name:
								            target_stmt = candidate
								            break

								    if target_stmt is None:
								        util.fail(api, "Can't find mapped attribute {}".format(name), cls)
								        return None

								    if target_stmt.type is None:
								        util.fail(
								            api,
								            "Statement linked from _mypy_mapped_attrs has no "
								            "typing information",
								            target_stmt,
								        )
								        return None

								    left_hand_explicit_type = get_proper_type(target_stmt.type)
								    assert isinstance(
								        left_hand_explicit_type, (Instance, UnionType, UnboundType)
								    )

								    mapped_attribute = util.SQLAlchemyAttribute(
								        name=name,
								        line=item.line,
								        column=item.column,
								        typ=left_hand_explicit_type,
								        info=cls.info,
								    )
								    attributes.append(mapped_attribute)

								    # no separately inferred python type here; the explicit annotation
								    # is the only type information passed along
								    apply_type_to_mapped_statement(
								        api,
								        target_stmt,
								        target_stmt.lvalues[0],
								        left_hand_explicit_type,
								        None,
								    )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def re_apply_declarative_assignments(
								    cls: ClassDef,
								    api: SemanticAnalyzerPluginInterface,
								    attributes: List[util.SQLAlchemyAttribute],
								) -> None:
								    """Put the ``Mapped[...]`` left-hand types back on a class body.

								    mypy seems to reset the left-hand side types in place when a class
								    is processed in more than one pass, so every assignment whose name
								    matches an entry in ``attributes`` has its type re-applied.  When an
								    attribute was recorded with an ``UnboundType`` and a more specific
								    type can now be inferred from the right-hand side, the attribute
								    record is updated as well and the class metadata is rewritten once
								    the whole body has been visited.

								    """
								    attr_by_name = {attr.name: attr for attr in attributes}
								    metadata_changed = False

								    for stmt in cls.defs.body:
								        # on a re-apply, every statement of interest is an
								        # AssignmentStmt; @declared_attr calls will have been converted
								        # and mypy currently seems to preserve that (this may change).
								        if not isinstance(stmt, AssignmentStmt):
								            continue

								        lvalue = stmt.lvalues[0]
								        if not (
								            isinstance(lvalue, NameExpr)
								            and lvalue.name in attr_by_name
								            and isinstance(lvalue.node, Var)
								        ):
								            continue

								        left_node = lvalue.node
								        mapped_attr = attr_by_name[lvalue.name]
								        resolved_type = mapped_attr.type
								        left_node_proper_type = get_proper_type(left_node.type)

								        # an UnboundType was scanned earlier and the node now carries
								        # something more specific: re-scan, provided the right-hand
								        # side has the "Mapped._empty_constructor(<call>)" shape
								        rvalue = stmt.rvalue
								        needs_rescan = (
								            isinstance(resolved_type, UnboundType)
								            and not isinstance(left_node_proper_type, UnboundType)
								            and isinstance(rvalue, CallExpr)
								            and isinstance(rvalue.callee, MemberExpr)
								            and isinstance(rvalue.callee.expr, NameExpr)
								            and rvalue.callee.expr.node is not None
								            and rvalue.callee.expr.node.fullname == NAMED_TYPE_SQLA_MAPPED
								            and rvalue.callee.name == "_empty_constructor"
								            and isinstance(rvalue.args[0], CallExpr)
								            and isinstance(rvalue.args[0].callee, RefExpr)
								        )

								        if needs_rescan:
								            resolved_type = infer.infer_type_from_right_hand_nameexpr(
								                api,
								                stmt,
								                left_node,
								                left_node_proper_type,
								                rvalue.args[0].callee,
								            )

								            # still nothing usable; leave this statement untouched
								            if resolved_type is None or isinstance(
								                resolved_type, UnboundType
								            ):
								                continue

								            # update the SQLAlchemyAttribute with the better information
								            mapped_attr.type = resolved_type
								            metadata_changed = True

								        if resolved_type is not None:
								            left_node.type = api.named_type(
								                NAMED_TYPE_SQLA_MAPPED, [resolved_type]
								            )

								    if metadata_changed:
								        util.set_mapped_attributes(cls.info, attributes)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def apply_type_to_mapped_statement(
								    api: SemanticAnalyzerPluginInterface,
								    stmt: AssignmentStmt,
								    lvalue: NameExpr,
								    left_hand_explicit_type: Optional[ProperType],
								    python_type_for_type: Optional[ProperType],
								) -> None:
								    """Rewrite a declarative assignment as a ``Mapped[<type>]`` attribute.

								    A declarative class statement such as::

								        class User(Base):
								            # ...

								            attrname = Column(Integer)

								    is turned into one that describes the final Python behavior to
								    Mypy::

								        class User(Base):
								            # ...

								            attrname : Mapped[Optional[int]] = <meaningless temp node>

								    The explicit left-hand type is used as the ``Mapped`` argument when
								    one is given.  Otherwise the lvalue is flagged as not being an
								    inferred definition and ``python_type_for_type`` is used; if that is
								    ``None`` as well, ``Mapped`` is requested with no type arguments.

								    """
								    left_node = lvalue.node
								    assert isinstance(left_node, Var)

								    if left_hand_explicit_type is None:
								        lvalue.is_inferred_def = False
								        if python_type_for_type is None:
								            mapped_args = []
								        else:
								            mapped_args = [python_type_for_type]
								    else:
								        mapped_args = [left_hand_explicit_type]

								    left_node.type = api.named_type(NAMED_TYPE_SQLA_MAPPED, mapped_args)

								    # The right side could be skipped entirely with:
								    #
								    #   stmt.rvalue = TempNode(AnyType(TypeOfAny.special_form))
								    #
								    # Wrapping the existing expression in a new node instead keeps type
								    # checking alive for the call itself (Column, relationship(), etc.).
								    # The statement ends up as:
								    #
								    #   <attr> : Mapped[<typ>] =
								    #       _sa_Mapped._empty_constructor(<original CallExpr from rvalue>)
								    #
								    # with the original right-hand side still type checked internally.
								    original_rvalue = stmt.rvalue
								    stmt.rvalue = util.expr_to_mapped_constructor(original_rvalue)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def add_additional_orm_attributes(
								    cls: ClassDef,
								    api: SemanticAnalyzerPluginInterface,
								    attributes: List[util.SQLAlchemyAttribute],
								) -> None:
								    """Apply __init__, __table__ and other attributes to the mapped class.

								    :param cls: class definition being augmented.
								    :param api: semantic analyzer plugin interface used for lookups and
								     for adding the generated method.
								    :param attributes: mapped attributes collected for ``cls`` itself.

								    Nothing is done when ``util.info_for_cls`` yields no ``TypeInfo``.
								    Otherwise:

								    * an ``__init__`` taking one optional keyword argument per mapped
								      attribute (own attributes first, then those of the bases in MRO
								      order) is added, unless the class already defines ``__init__`` or
								      is flagged as a base;
								    * a ``__table__`` placeholder is added when the class has none and
								      ``util.get_has_table`` is true;
								    * a ``__mapper__`` placeholder is added unless the class is flagged
								      as a base.

								    """

								    info = util.info_for_cls(cls, api)

								    if info is None:
								        return

								    is_base = util.get_is_base(info)

								    if "__init__" not in info.names and not is_base:
								        mapped_attr_names = {attr.name: attr.type for attr in attributes}

								        # mro[0] is the class itself and mro[-1] is the root of the
								        # hierarchy; only the classes in between can add attributes.
								        for base in info.mro[1:-1]:
								            # Test the base being visited.  The previous check looked
								            # at ``info.metadata``, which is constant for the whole
								            # loop and therefore never filtered individual bases.
								            if "sqlalchemy" not in base.metadata:
								                continue

								            base_cls_attributes = util.get_mapped_attributes(base, api)
								            if base_cls_attributes is None:
								                continue

								            # setdefault: a name already supplied by the class or by an
								            # earlier (more derived) base keeps its type
								            for attr in base_cls_attributes:
								                mapped_attr_names.setdefault(attr.name, attr.type)

								        arguments = []
								        for name, typ in mapped_attr_names.items():
								            # attributes with no known type are accepted as Any
								            if typ is None:
								                typ = AnyType(TypeOfAny.special_form)
								            arguments.append(
								                Argument(
								                    variable=Var(name, typ),
								                    type_annotation=typ,
								                    initializer=TempNode(typ),
								                    kind=ARG_NAMED_OPT,
								                )
								            )

								        add_method_to_class(api, cls, "__init__", arguments, NoneTyp())

								    if "__table__" not in info.names and util.get_has_table(info):
								        _apply_placeholder_attr_to_class(
								            api, cls, "sqlalchemy.sql.schema.Table", "__table__"
								        )
								    if not is_base:
								        _apply_placeholder_attr_to_class(
								            api, cls, "sqlalchemy.orm.mapper.Mapper", "__mapper__"
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def _apply_placeholder_attr_to_class(
								    api: SemanticAnalyzerPluginInterface,
								    cls: ClassDef,
								    qualified_name: str,
								    attrname: str,
								) -> None:
								    """Install ``attrname`` on ``cls`` as a member typed as ``qualified_name``.

								    The fully qualified name is looked up through ``api``; when it
								    resolves, the attribute is typed as an instance of that class with
								    no type arguments, otherwise as ``Any``.  Any existing entry for
								    ``attrname`` in the class symbol table is overwritten.

								    """
								    sym = api.lookup_fully_qualified_or_none(qualified_name)

								    placeholder_type: ProperType
								    if not sym:
								        # target class not available; fall back to Any
								        placeholder_type = AnyType(TypeOfAny.special_form)
								    else:
								        assert isinstance(sym.node, TypeInfo)
								        placeholder_type = Instance(sym.node, [])

								    # the type is assigned after construction rather than being passed
								    # to Var()
								    placeholder = Var(attrname)
								    placeholder.type = placeholder_type
								    placeholder.info = cls.info
								    placeholder._fullname = cls.fullname + "." + attrname

								    cls.info.names[attrname] = SymbolTableNode(MDEF, placeholder)
							 |