You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					517 lines
				
				17 KiB
			
		
		
			
		
	
	
					517 lines
				
				17 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								# ext/mypy/decl_class.py
							 | 
						||
| 
								 | 
							
								# Copyright (C) 2021 the SQLAlchemy authors and contributors
							 | 
						||
| 
								 | 
							
								# <see AUTHORS file>
							 | 
						||
| 
								 | 
							
								#
							 | 
						||
| 
								 | 
							
								# This module is part of SQLAlchemy and is released under
							 | 
						||
| 
								 | 
							
								# the MIT License: https://www.opensource.org/licenses/mit-license.php
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
from typing import List
from typing import Optional
from typing import Union

from mypy.nodes import AssignmentStmt
from mypy.nodes import CallExpr
from mypy.nodes import ClassDef
from mypy.nodes import Decorator
from mypy.nodes import IfStmt
from mypy.nodes import LambdaExpr
from mypy.nodes import ListExpr
from mypy.nodes import MemberExpr
from mypy.nodes import NameExpr
from mypy.nodes import PlaceholderNode
from mypy.nodes import RefExpr
from mypy.nodes import Statement
from mypy.nodes import StrExpr
from mypy.nodes import SymbolNode
from mypy.nodes import SymbolTableNode
from mypy.nodes import TempNode
from mypy.nodes import TypeInfo
from mypy.nodes import Var
from mypy.plugin import SemanticAnalyzerPluginInterface
from mypy.types import AnyType
from mypy.types import CallableType
from mypy.types import get_proper_type
from mypy.types import Instance
from mypy.types import NoneType
from mypy.types import ProperType
from mypy.types import Type
from mypy.types import TypeOfAny
from mypy.types import UnboundType
from mypy.types import UnionType

from . import apply
from . import infer
from . import names
from . import util
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def scan_declarative_assignments_and_apply_types(
							 | 
						||
| 
								 | 
							
								    cls: ClassDef,
							 | 
						||
| 
								 | 
							
								    api: SemanticAnalyzerPluginInterface,
							 | 
						||
| 
								 | 
							
								    is_mixin_scan: bool = False,
							 | 
						||
| 
								 | 
							
								) -> Optional[List[util.SQLAlchemyAttribute]]:
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    info = util.info_for_cls(cls, api)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if info is None:
							 | 
						||
| 
								 | 
							
								        # this can occur during cached passes
							 | 
						||
| 
								 | 
							
								        return None
							 | 
						||
| 
								 | 
							
								    elif cls.fullname.startswith("builtins"):
							 | 
						||
| 
								 | 
							
								        return None
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    mapped_attributes: Optional[
							 | 
						||
| 
								 | 
							
								        List[util.SQLAlchemyAttribute]
							 | 
						||
| 
								 | 
							
								    ] = util.get_mapped_attributes(info, api)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    # used by assign.add_additional_orm_attributes among others
							 | 
						||
| 
								 | 
							
								    util.establish_as_sqlalchemy(info)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if mapped_attributes is not None:
							 | 
						||
| 
								 | 
							
								        # ensure that a class that's mapped is always picked up by
							 | 
						||
| 
								 | 
							
								        # its mapped() decorator or declarative metaclass before
							 | 
						||
| 
								 | 
							
								        # it would be detected as an unmapped mixin class
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if not is_mixin_scan:
							 | 
						||
| 
								 | 
							
								            # mypy can call us more than once.  it then *may* have reset the
							 | 
						||
| 
								 | 
							
								            # left hand side of everything, but not the right that we removed,
							 | 
						||
| 
								 | 
							
								            # removing our ability to re-scan.   but we have the types
							 | 
						||
| 
								 | 
							
								            # here, so lets re-apply them, or if we have an UnboundType,
							 | 
						||
| 
								 | 
							
								            # we can re-scan
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            apply.re_apply_declarative_assignments(cls, api, mapped_attributes)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return mapped_attributes
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    mapped_attributes = []
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if not cls.defs.body:
							 | 
						||
| 
								 | 
							
								        # when we get a mixin class from another file, the body is
							 | 
						||
| 
								 | 
							
								        # empty (!) but the names are in the symbol table.  so use that.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        for sym_name, sym in info.names.items():
							 | 
						||
| 
								 | 
							
								            _scan_symbol_table_entry(
							 | 
						||
| 
								 | 
							
								                cls, api, sym_name, sym, mapped_attributes
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								    else:
							 | 
						||
| 
								 | 
							
								        for stmt in util.flatten_typechecking(cls.defs.body):
							 | 
						||
| 
								 | 
							
								            if isinstance(stmt, AssignmentStmt):
							 | 
						||
| 
								 | 
							
								                _scan_declarative_assignment_stmt(
							 | 
						||
| 
								 | 
							
								                    cls, api, stmt, mapped_attributes
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								            elif isinstance(stmt, Decorator):
							 | 
						||
| 
								 | 
							
								                _scan_declarative_decorator_stmt(
							 | 
						||
| 
								 | 
							
								                    cls, api, stmt, mapped_attributes
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								    _scan_for_mapped_bases(cls, api)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if not is_mixin_scan:
							 | 
						||
| 
								 | 
							
								        apply.add_additional_orm_attributes(cls, api, mapped_attributes)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    util.set_mapped_attributes(info, mapped_attributes)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    return mapped_attributes
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def _scan_symbol_table_entry(
							 | 
						||
| 
								 | 
							
								    cls: ClassDef,
							 | 
						||
| 
								 | 
							
								    api: SemanticAnalyzerPluginInterface,
							 | 
						||
| 
								 | 
							
								    name: str,
							 | 
						||
| 
								 | 
							
								    value: SymbolTableNode,
							 | 
						||
| 
								 | 
							
								    attributes: List[util.SQLAlchemyAttribute],
							 | 
						||
| 
								 | 
							
								) -> None:
							 | 
						||
| 
								 | 
							
								    """Extract mapping information from a SymbolTableNode that's in the
							 | 
						||
| 
								 | 
							
								    type.names dictionary.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    value_type = get_proper_type(value.type)
							 | 
						||
| 
								 | 
							
								    if not isinstance(value_type, Instance):
							 | 
						||
| 
								 | 
							
								        return
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    left_hand_explicit_type = None
							 | 
						||
| 
								 | 
							
								    type_id = names.type_id_for_named_node(value_type.type)
							 | 
						||
| 
								 | 
							
								    # type_id = names._type_id_for_unbound_type(value.type.type, cls, api)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    err = False
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    # TODO: this is nearly the same logic as that of
							 | 
						||
| 
								 | 
							
								    # _scan_declarative_decorator_stmt, likely can be merged
							 | 
						||
| 
								 | 
							
								    if type_id in {
							 | 
						||
| 
								 | 
							
								        names.MAPPED,
							 | 
						||
| 
								 | 
							
								        names.RELATIONSHIP,
							 | 
						||
| 
								 | 
							
								        names.COMPOSITE_PROPERTY,
							 | 
						||
| 
								 | 
							
								        names.MAPPER_PROPERTY,
							 | 
						||
| 
								 | 
							
								        names.SYNONYM_PROPERTY,
							 | 
						||
| 
								 | 
							
								        names.COLUMN_PROPERTY,
							 | 
						||
| 
								 | 
							
								    }:
							 | 
						||
| 
								 | 
							
								        if value_type.args:
							 | 
						||
| 
								 | 
							
								            left_hand_explicit_type = get_proper_type(value_type.args[0])
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            err = True
							 | 
						||
| 
								 | 
							
								    elif type_id is names.COLUMN:
							 | 
						||
| 
								 | 
							
								        if not value_type.args:
							 | 
						||
| 
								 | 
							
								            err = True
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            typeengine_arg: Union[ProperType, TypeInfo] = get_proper_type(
							 | 
						||
| 
								 | 
							
								                value_type.args[0]
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								            if isinstance(typeengine_arg, Instance):
							 | 
						||
| 
								 | 
							
								                typeengine_arg = typeengine_arg.type
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            if isinstance(typeengine_arg, (UnboundType, TypeInfo)):
							 | 
						||
| 
								 | 
							
								                sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
							 | 
						||
| 
								 | 
							
								                if sym is not None and isinstance(sym.node, TypeInfo):
							 | 
						||
| 
								 | 
							
								                    if names.has_base_type_id(sym.node, names.TYPEENGINE):
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								                        left_hand_explicit_type = UnionType(
							 | 
						||
| 
								 | 
							
								                            [
							 | 
						||
| 
								 | 
							
								                                infer.extract_python_type_from_typeengine(
							 | 
						||
| 
								 | 
							
								                                    api, sym.node, []
							 | 
						||
| 
								 | 
							
								                                ),
							 | 
						||
| 
								 | 
							
								                                NoneType(),
							 | 
						||
| 
								 | 
							
								                            ]
							 | 
						||
| 
								 | 
							
								                        )
							 | 
						||
| 
								 | 
							
								                    else:
							 | 
						||
| 
								 | 
							
								                        util.fail(
							 | 
						||
| 
								 | 
							
								                            api,
							 | 
						||
| 
								 | 
							
								                            "Column type should be a TypeEngine "
							 | 
						||
| 
								 | 
							
								                            "subclass not '{}'".format(sym.node.fullname),
							 | 
						||
| 
								 | 
							
								                            value_type,
							 | 
						||
| 
								 | 
							
								                        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if err:
							 | 
						||
| 
								 | 
							
								        msg = (
							 | 
						||
| 
								 | 
							
								            "Can't infer type from attribute {} on class {}. "
							 | 
						||
| 
								 | 
							
								            "please specify a return type from this function that is "
							 | 
						||
| 
								 | 
							
								            "one of: Mapped[<python type>], relationship[<target class>], "
							 | 
						||
| 
								 | 
							
								            "Column[<TypeEngine>], MapperProperty[<python type>]"
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        util.fail(api, msg.format(name, cls.name), cls)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        left_hand_explicit_type = AnyType(TypeOfAny.special_form)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if left_hand_explicit_type is not None:
							 | 
						||
| 
								 | 
							
								        assert value.node is not None
							 | 
						||
| 
								 | 
							
								        attributes.append(
							 | 
						||
| 
								 | 
							
								            util.SQLAlchemyAttribute(
							 | 
						||
| 
								 | 
							
								                name=name,
							 | 
						||
| 
								 | 
							
								                line=value.node.line,
							 | 
						||
| 
								 | 
							
								                column=value.node.column,
							 | 
						||
| 
								 | 
							
								                typ=left_hand_explicit_type,
							 | 
						||
| 
								 | 
							
								                info=cls.info,
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def _scan_declarative_decorator_stmt(
							 | 
						||
| 
								 | 
							
								    cls: ClassDef,
							 | 
						||
| 
								 | 
							
								    api: SemanticAnalyzerPluginInterface,
							 | 
						||
| 
								 | 
							
								    stmt: Decorator,
							 | 
						||
| 
								 | 
							
								    attributes: List[util.SQLAlchemyAttribute],
							 | 
						||
| 
								 | 
							
								) -> None:
							 | 
						||
| 
								 | 
							
								    """Extract mapping information from a @declared_attr in a declarative
							 | 
						||
| 
								 | 
							
								    class.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    E.g.::
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        @reg.mapped
							 | 
						||
| 
								 | 
							
								        class MyClass:
							 | 
						||
| 
								 | 
							
								            # ...
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            @declared_attr
							 | 
						||
| 
								 | 
							
								            def updated_at(cls) -> Column[DateTime]:
							 | 
						||
| 
								 | 
							
								                return Column(DateTime)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    Will resolve in mypy as::
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        @reg.mapped
							 | 
						||
| 
								 | 
							
								        class MyClass:
							 | 
						||
| 
								 | 
							
								            # ...
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            updated_at: Mapped[Optional[datetime.datetime]]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    for dec in stmt.decorators:
							 | 
						||
| 
								 | 
							
								        if (
							 | 
						||
| 
								 | 
							
								            isinstance(dec, (NameExpr, MemberExpr, SymbolNode))
							 | 
						||
| 
								 | 
							
								            and names.type_id_for_named_node(dec) is names.DECLARED_ATTR
							 | 
						||
| 
								 | 
							
								        ):
							 | 
						||
| 
								 | 
							
								            break
							 | 
						||
| 
								 | 
							
								    else:
							 | 
						||
| 
								 | 
							
								        return
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    dec_index = cls.defs.body.index(stmt)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    left_hand_explicit_type: Optional[ProperType] = None
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if util.name_is_dunder(stmt.name):
							 | 
						||
| 
								 | 
							
								        # for dunder names like __table_args__, __tablename__,
							 | 
						||
| 
								 | 
							
								        # __mapper_args__ etc., rewrite these as simple assignment
							 | 
						||
| 
								 | 
							
								        # statements; otherwise mypy doesn't like if the decorated
							 | 
						||
| 
								 | 
							
								        # function has an annotation like ``cls: Type[Foo]`` because
							 | 
						||
| 
								 | 
							
								        # it isn't @classmethod
							 | 
						||
| 
								 | 
							
								        any_ = AnyType(TypeOfAny.special_form)
							 | 
						||
| 
								 | 
							
								        left_node = NameExpr(stmt.var.name)
							 | 
						||
| 
								 | 
							
								        left_node.node = stmt.var
							 | 
						||
| 
								 | 
							
								        new_stmt = AssignmentStmt([left_node], TempNode(any_))
							 | 
						||
| 
								 | 
							
								        new_stmt.type = left_node.node.type
							 | 
						||
| 
								 | 
							
								        cls.defs.body[dec_index] = new_stmt
							 | 
						||
| 
								 | 
							
								        return
							 | 
						||
| 
								 | 
							
								    elif isinstance(stmt.func.type, CallableType):
							 | 
						||
| 
								 | 
							
								        func_type = stmt.func.type.ret_type
							 | 
						||
| 
								 | 
							
								        if isinstance(func_type, UnboundType):
							 | 
						||
| 
								 | 
							
								            type_id = names.type_id_for_unbound_type(func_type, cls, api)
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            # this does not seem to occur unless the type argument is
							 | 
						||
| 
								 | 
							
								            # incorrect
							 | 
						||
| 
								 | 
							
								            return
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if (
							 | 
						||
| 
								 | 
							
								            type_id
							 | 
						||
| 
								 | 
							
								            in {
							 | 
						||
| 
								 | 
							
								                names.MAPPED,
							 | 
						||
| 
								 | 
							
								                names.RELATIONSHIP,
							 | 
						||
| 
								 | 
							
								                names.COMPOSITE_PROPERTY,
							 | 
						||
| 
								 | 
							
								                names.MAPPER_PROPERTY,
							 | 
						||
| 
								 | 
							
								                names.SYNONYM_PROPERTY,
							 | 
						||
| 
								 | 
							
								                names.COLUMN_PROPERTY,
							 | 
						||
| 
								 | 
							
								            }
							 | 
						||
| 
								 | 
							
								            and func_type.args
							 | 
						||
| 
								 | 
							
								        ):
							 | 
						||
| 
								 | 
							
								            left_hand_explicit_type = get_proper_type(func_type.args[0])
							 | 
						||
| 
								 | 
							
								        elif type_id is names.COLUMN and func_type.args:
							 | 
						||
| 
								 | 
							
								            typeengine_arg = func_type.args[0]
							 | 
						||
| 
								 | 
							
								            if isinstance(typeengine_arg, UnboundType):
							 | 
						||
| 
								 | 
							
								                sym = api.lookup_qualified(typeengine_arg.name, typeengine_arg)
							 | 
						||
| 
								 | 
							
								                if sym is not None and isinstance(sym.node, TypeInfo):
							 | 
						||
| 
								 | 
							
								                    if names.has_base_type_id(sym.node, names.TYPEENGINE):
							 | 
						||
| 
								 | 
							
								                        left_hand_explicit_type = UnionType(
							 | 
						||
| 
								 | 
							
								                            [
							 | 
						||
| 
								 | 
							
								                                infer.extract_python_type_from_typeengine(
							 | 
						||
| 
								 | 
							
								                                    api, sym.node, []
							 | 
						||
| 
								 | 
							
								                                ),
							 | 
						||
| 
								 | 
							
								                                NoneType(),
							 | 
						||
| 
								 | 
							
								                            ]
							 | 
						||
| 
								 | 
							
								                        )
							 | 
						||
| 
								 | 
							
								                    else:
							 | 
						||
| 
								 | 
							
								                        util.fail(
							 | 
						||
| 
								 | 
							
								                            api,
							 | 
						||
| 
								 | 
							
								                            "Column type should be a TypeEngine "
							 | 
						||
| 
								 | 
							
								                            "subclass not '{}'".format(sym.node.fullname),
							 | 
						||
| 
								 | 
							
								                            func_type,
							 | 
						||
| 
								 | 
							
								                        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if left_hand_explicit_type is None:
							 | 
						||
| 
								 | 
							
								        # no type on the decorated function.  our option here is to
							 | 
						||
| 
								 | 
							
								        # dig into the function body and get the return type, but they
							 | 
						||
| 
								 | 
							
								        # should just have an annotation.
							 | 
						||
| 
								 | 
							
								        msg = (
							 | 
						||
| 
								 | 
							
								            "Can't infer type from @declared_attr on function '{}';  "
							 | 
						||
| 
								 | 
							
								            "please specify a return type from this function that is "
							 | 
						||
| 
								 | 
							
								            "one of: Mapped[<python type>], relationship[<target class>], "
							 | 
						||
| 
								 | 
							
								            "Column[<TypeEngine>], MapperProperty[<python type>]"
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        util.fail(api, msg.format(stmt.var.name), stmt)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        left_hand_explicit_type = AnyType(TypeOfAny.special_form)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    left_node = NameExpr(stmt.var.name)
							 | 
						||
| 
								 | 
							
								    left_node.node = stmt.var
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    # totally feeling around in the dark here as I don't totally understand
							 | 
						||
| 
								 | 
							
								    # the significance of UnboundType.  It seems to be something that is
							 | 
						||
| 
								 | 
							
								    # not going to do what's expected when it is applied as the type of
							 | 
						||
| 
								 | 
							
								    # an AssignmentStatement.  So do a feeling-around-in-the-dark version
							 | 
						||
| 
								 | 
							
								    # of converting it to the regular Instance/TypeInfo/UnionType structures
							 | 
						||
| 
								 | 
							
								    # we see everywhere else.
							 | 
						||
| 
								 | 
							
								    if isinstance(left_hand_explicit_type, UnboundType):
							 | 
						||
| 
								 | 
							
								        left_hand_explicit_type = get_proper_type(
							 | 
						||
| 
								 | 
							
								            util.unbound_to_instance(api, left_hand_explicit_type)
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    left_node.node.type = api.named_type(
							 | 
						||
| 
								 | 
							
								        names.NAMED_TYPE_SQLA_MAPPED, [left_hand_explicit_type]
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    # this will ignore the rvalue entirely
							 | 
						||
| 
								 | 
							
								    # rvalue = TempNode(AnyType(TypeOfAny.special_form))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    # rewrite the node as:
							 | 
						||
| 
								 | 
							
								    # <attr> : Mapped[<typ>] =
							 | 
						||
| 
								 | 
							
								    # _sa_Mapped._empty_constructor(lambda: <function body>)
							 | 
						||
| 
								 | 
							
								    # the function body is maintained so it gets type checked internally
							 | 
						||
| 
								 | 
							
								    rvalue = util.expr_to_mapped_constructor(
							 | 
						||
| 
								 | 
							
								        LambdaExpr(stmt.func.arguments, stmt.func.body)
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    new_stmt = AssignmentStmt([left_node], rvalue)
							 | 
						||
| 
								 | 
							
								    new_stmt.type = left_node.node.type
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    attributes.append(
							 | 
						||
| 
								 | 
							
								        util.SQLAlchemyAttribute(
							 | 
						||
| 
								 | 
							
								            name=left_node.name,
							 | 
						||
| 
								 | 
							
								            line=stmt.line,
							 | 
						||
| 
								 | 
							
								            column=stmt.column,
							 | 
						||
| 
								 | 
							
								            typ=left_hand_explicit_type,
							 | 
						||
| 
								 | 
							
								            info=cls.info,
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								    cls.defs.body[dec_index] = new_stmt
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def _scan_declarative_assignment_stmt(
							 | 
						||
| 
								 | 
							
								    cls: ClassDef,
							 | 
						||
| 
								 | 
							
								    api: SemanticAnalyzerPluginInterface,
							 | 
						||
| 
								 | 
							
								    stmt: AssignmentStmt,
							 | 
						||
| 
								 | 
							
								    attributes: List[util.SQLAlchemyAttribute],
							 | 
						||
| 
								 | 
							
								) -> None:
							 | 
						||
| 
								 | 
							
								    """Extract mapping information from an assignment statement in a
							 | 
						||
| 
								 | 
							
								    declarative class.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								    lvalue = stmt.lvalues[0]
							 | 
						||
| 
								 | 
							
								    if not isinstance(lvalue, NameExpr):
							 | 
						||
| 
								 | 
							
								        return
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    sym = cls.info.names.get(lvalue.name)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    # this establishes that semantic analysis has taken place, which
							 | 
						||
| 
								 | 
							
								    # means the nodes are populated and we are called from an appropriate
							 | 
						||
| 
								 | 
							
								    # hook.
							 | 
						||
| 
								 | 
							
								    assert sym is not None
							 | 
						||
| 
								 | 
							
								    node = sym.node
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if isinstance(node, PlaceholderNode):
							 | 
						||
| 
								 | 
							
								        return
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    assert node is lvalue.node
							 | 
						||
| 
								 | 
							
								    assert isinstance(node, Var)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if node.name == "__abstract__":
							 | 
						||
| 
								 | 
							
								        if api.parse_bool(stmt.rvalue) is True:
							 | 
						||
| 
								 | 
							
								            util.set_is_base(cls.info)
							 | 
						||
| 
								 | 
							
								        return
							 | 
						||
| 
								 | 
							
								    elif node.name == "__tablename__":
							 | 
						||
| 
								 | 
							
								        util.set_has_table(cls.info)
							 | 
						||
| 
								 | 
							
								    elif node.name.startswith("__"):
							 | 
						||
| 
								 | 
							
								        return
							 | 
						||
| 
								 | 
							
								    elif node.name == "_mypy_mapped_attrs":
							 | 
						||
| 
								 | 
							
								        if not isinstance(stmt.rvalue, ListExpr):
							 | 
						||
| 
								 | 
							
								            util.fail(api, "_mypy_mapped_attrs is expected to be a list", stmt)
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            for item in stmt.rvalue.items:
							 | 
						||
| 
								 | 
							
								                if isinstance(item, (NameExpr, StrExpr)):
							 | 
						||
| 
								 | 
							
								                    apply.apply_mypy_mapped_attr(cls, api, item, attributes)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    left_hand_mapped_type: Optional[Type] = None
							 | 
						||
| 
								 | 
							
								    left_hand_explicit_type: Optional[ProperType] = None
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if node.is_inferred or node.type is None:
							 | 
						||
| 
								 | 
							
								        if isinstance(stmt.type, UnboundType):
							 | 
						||
| 
								 | 
							
								            # look for an explicit Mapped[] type annotation on the left
							 | 
						||
| 
								 | 
							
								            # side with nothing on the right
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            # print(stmt.type)
							 | 
						||
| 
								 | 
							
								            # Mapped?[Optional?[A?]]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            left_hand_explicit_type = stmt.type
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            if stmt.type.name == "Mapped":
							 | 
						||
| 
								 | 
							
								                mapped_sym = api.lookup_qualified("Mapped", cls)
							 | 
						||
| 
								 | 
							
								                if (
							 | 
						||
| 
								 | 
							
								                    mapped_sym is not None
							 | 
						||
| 
								 | 
							
								                    and mapped_sym.node is not None
							 | 
						||
| 
								 | 
							
								                    and names.type_id_for_named_node(mapped_sym.node)
							 | 
						||
| 
								 | 
							
								                    is names.MAPPED
							 | 
						||
| 
								 | 
							
								                ):
							 | 
						||
| 
								 | 
							
								                    left_hand_explicit_type = get_proper_type(
							 | 
						||
| 
								 | 
							
								                        stmt.type.args[0]
							 | 
						||
| 
								 | 
							
								                    )
							 | 
						||
| 
								 | 
							
								                    left_hand_mapped_type = stmt.type
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            # TODO: do we need to convert from unbound for this case?
							 | 
						||
| 
								 | 
							
								            # left_hand_explicit_type = util._unbound_to_instance(
							 | 
						||
| 
								 | 
							
								            #     api, left_hand_explicit_type
							 | 
						||
| 
								 | 
							
								            # )
							 | 
						||
| 
								 | 
							
								    else:
							 | 
						||
| 
								 | 
							
								        node_type = get_proper_type(node.type)
							 | 
						||
| 
								 | 
							
								        if (
							 | 
						||
| 
								 | 
							
								            isinstance(node_type, Instance)
							 | 
						||
| 
								 | 
							
								            and names.type_id_for_named_node(node_type.type) is names.MAPPED
							 | 
						||
| 
								 | 
							
								        ):
							 | 
						||
| 
								 | 
							
								            # print(node.type)
							 | 
						||
| 
								 | 
							
								            # sqlalchemy.orm.attributes.Mapped[<python type>]
							 | 
						||
| 
								 | 
							
								            left_hand_explicit_type = get_proper_type(node_type.args[0])
							 | 
						||
| 
								 | 
							
								            left_hand_mapped_type = node_type
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            # print(node.type)
							 | 
						||
| 
								 | 
							
								            # <python type>
							 | 
						||
| 
								 | 
							
								            left_hand_explicit_type = node_type
							 | 
						||
| 
								 | 
							
								            left_hand_mapped_type = None
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if isinstance(stmt.rvalue, TempNode) and left_hand_mapped_type is not None:
							 | 
						||
| 
								 | 
							
								        # annotation without assignment and Mapped is present
							 | 
						||
| 
								 | 
							
								        # as type annotation
							 | 
						||
| 
								 | 
							
								        # equivalent to using _infer_type_from_left_hand_type_only.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        python_type_for_type = left_hand_explicit_type
							 | 
						||
| 
								 | 
							
								    elif isinstance(stmt.rvalue, CallExpr) and isinstance(
							 | 
						||
| 
								 | 
							
								        stmt.rvalue.callee, RefExpr
							 | 
						||
| 
								 | 
							
								    ):
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        python_type_for_type = infer.infer_type_from_right_hand_nameexpr(
							 | 
						||
| 
								 | 
							
								            api, stmt, node, left_hand_explicit_type, stmt.rvalue.callee
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if python_type_for_type is None:
							 | 
						||
| 
								 | 
							
								            return
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    else:
							 | 
						||
| 
								 | 
							
								        return
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    assert python_type_for_type is not None
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    attributes.append(
							 | 
						||
| 
								 | 
							
								        util.SQLAlchemyAttribute(
							 | 
						||
| 
								 | 
							
								            name=node.name,
							 | 
						||
| 
								 | 
							
								            line=stmt.line,
							 | 
						||
| 
								 | 
							
								            column=stmt.column,
							 | 
						||
| 
								 | 
							
								            typ=python_type_for_type,
							 | 
						||
| 
								 | 
							
								            info=cls.info,
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    apply.apply_type_to_mapped_statement(
							 | 
						||
| 
								 | 
							
								        api,
							 | 
						||
| 
								 | 
							
								        stmt,
							 | 
						||
| 
								 | 
							
								        lvalue,
							 | 
						||
| 
								 | 
							
								        left_hand_explicit_type,
							 | 
						||
| 
								 | 
							
								        python_type_for_type,
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								def _scan_for_mapped_bases(
							 | 
						||
| 
								 | 
							
								    cls: ClassDef,
							 | 
						||
| 
								 | 
							
								    api: SemanticAnalyzerPluginInterface,
							 | 
						||
| 
								 | 
							
								) -> None:
							 | 
						||
| 
								 | 
							
								    """Given a class, iterate through its superclass hierarchy to find
							 | 
						||
| 
								 | 
							
								    all other classes that are considered as ORM-significant.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    Locates non-mapped mixins and scans them for mapped attributes to be
							 | 
						||
| 
								 | 
							
								    applied to subclasses.
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    """
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    info = util.info_for_cls(cls, api)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    if info is None:
							 | 
						||
| 
								 | 
							
								        return
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    for base_info in info.mro[1:-1]:
							 | 
						||
| 
								 | 
							
								        if base_info.fullname.startswith("builtins"):
							 | 
						||
| 
								 | 
							
								            continue
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        # scan each base for mapped attributes.  if they are not already
							 | 
						||
| 
								 | 
							
								        # scanned (but have all their type info), that means they are unmapped
							 | 
						||
| 
								 | 
							
								        # mixins
							 | 
						||
| 
								 | 
							
								        scan_declarative_assignments_and_apply_types(
							 | 
						||
| 
								 | 
							
								            base_info.defn, api, is_mixin_scan=True
							 | 
						||
| 
								 | 
							
								        )
							 |