# orm/persistence.py
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php

"""private module containing functions used to emit INSERT, UPDATE
and DELETE statements on behalf of a :class:`_orm.Mapper` and its descending
mappers.

The functions here are called only by the unit of work functions
in unitofwork.py.

"""

from itertools import chain
from itertools import groupby
import operator

from . import attributes
from . import evaluator
from . import exc as orm_exc
from . import loading
from . import sync
from .base import NO_VALUE
from .base import state_str
from .. import exc as sa_exc
from .. import future
from .. import sql
from .. import util
from ..engine import result as _result
from ..sql import coercions
from ..sql import expression
from ..sql import operators
from ..sql import roles
from ..sql import select
from ..sql import sqltypes
from ..sql.base import _entity_namespace_key
from ..sql.base import CompileState
from ..sql.base import Options
from ..sql.dml import DeleteDMLState
from ..sql.dml import InsertDMLState
from ..sql.dml import UpdateDMLState
from ..sql.elements import BooleanClauseList
from ..sql.selectable import LABEL_STYLE_TABLENAME_PLUS_COL


def _bulk_insert(
    mapper,
    mappings,
    session_transaction,
    isstates,
    return_defaults,
    render_nulls,
):
    base_mapper = mapper.base_mapper

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_insert()"
        )

    if isstates:
        if return_defaults:
            states = [(state, state.dict) for state in mappings]
            mappings = [dict_ for (state, dict_) in states]
        else:
            mappings = [state.dict for state in mappings]
    else:
        mappings = list(mappings)

    connection = session_transaction.connection(base_mapper)
    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper):
            continue

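        # reshape the rows from _collect_insert_commands() for bulk mode:
        # there is no InstanceState to track, so the per-record state slot
        # is None and the outer mapper/connection are used for every row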
        records = (
            (
                None,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_pks,
                has_all_defaults,
            )
            for (
                state,
                state_dict,
                params,
                mp,
                conn,
                value_params,
                has_all_pks,
                has_all_defaults,
            ) in _collect_insert_commands(
                table,
                ((None, mapping, mapper, connection) for mapping in mappings),
                bulk=True,
                return_defaults=return_defaults,
                render_nulls=render_nulls,
            )
        )
        _emit_insert_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=return_defaults,
        )

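    # with return_defaults and real instance states, backfill each state's
    # identity key from the primary key values now present in its dict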
    if return_defaults and isstates:
        identity_cls = mapper._identity_class
        identity_props = [p.key for p in mapper._identity_key_props]
        for state, dict_ in states:
            state.key = (
                identity_cls,
                tuple([dict_[key] for key in identity_props]),
            )


def _bulk_update(
    mapper, mappings, session_transaction, isstates, update_changed_only
):
    base_mapper = mapper.base_mapper

    search_keys = mapper._primary_key_propkeys
    if mapper._version_id_prop:
        search_keys = {mapper._version_id_prop.key}.union(search_keys)

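    # reduce each state's dict to the attributes with pending changes, plus
    # the primary key (and version id) attributes needed to target the row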
    def _changed_dict(mapper, state):
        return dict(
            (k, v)
            for k, v in state.dict.items()
            if k in state.committed_state or k in search_keys
        )

    if isstates:
        if update_changed_only:
            mappings = [_changed_dict(mapper, state) for state in mappings]
        else:
            mappings = [state.dict for state in mappings]
    else:
        mappings = list(mappings)

    if session_transaction.session.connection_callable:
        raise NotImplementedError(
            "connection_callable / per-instance sharding "
            "not supported in bulk_update()"
        )

    connection = session_transaction.connection(base_mapper)

    for table, super_mapper in base_mapper._sorted_tables.items():
        if not mapper.isa(super_mapper):
            continue

        records = _collect_update_commands(
            None,
            table,
            (
                (
                    None,
                    mapping,
                    mapper,
                    connection,
                    (
                        mapping[mapper._version_id_prop.key]
                        if mapper._version_id_prop
                        else None
                    ),
                )
                for mapping in mappings
            ),
            bulk=True,
        )

        _emit_update_statements(
            base_mapper,
            None,
            super_mapper,
            table,
            records,
            bookkeeping=False,
        )


def save_obj(base_mapper, states, uowtransaction, single=False):
    """Issue ``INSERT`` and/or ``UPDATE`` statements for a list
    of objects.

    This is called within the context of a UOWTransaction during a
    flush operation, given a list of states to be flushed.  The
    base mapper in an inheritance hierarchy handles the inserts/
    updates for all descendant mappers.

    """

    # if batch=false, call _save_obj separately for each object
    if not single and not base_mapper.batch:
        for state in _sort_states(base_mapper, states):
            save_obj(base_mapper, [state], uowtransaction, single=True)
        return

    states_to_update = []
    states_to_insert = []

    for (
        state,
        dict_,
        mapper,
        connection,
        has_identity,
        row_switch,
        update_version_id,
    ) in _organize_states_for_save(base_mapper, states, uowtransaction):
        if has_identity or row_switch:
            states_to_update.append(
                (state, dict_, mapper, connection, update_version_id)
            )
        else:
            states_to_insert.append((state, dict_, mapper, connection))

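    # process tables in dependency order; for each table, UPDATE statements
    # are emitted before INSERT statements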
    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue
        insert = _collect_insert_commands(table, states_to_insert)

        update = _collect_update_commands(
            uowtransaction, table, states_to_update
        )

        _emit_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            update,
        )

        _emit_insert_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            insert,
        )

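    # finish up: the trailing boolean on each tuple indicates whether the
    # state already had an identity (an UPDATE) vs. was newly inserted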
    _finalize_insert_update_commands(
        base_mapper,
        uowtransaction,
        chain(
            (
                (state, state_dict, mapper, connection, False)
                for (state, state_dict, mapper, connection) in states_to_insert
            ),
            (
                (state, state_dict, mapper, connection, True)
                for (
                    state,
                    state_dict,
                    mapper,
                    connection,
                    update_version_id,
                ) in states_to_update
            ),
        ),
    )


def post_update(base_mapper, states, uowtransaction, post_update_cols):
    """Issue UPDATE statements on behalf of a relationship() which
    specifies post_update.

    """

    states_to_update = list(
        _organize_states_for_post_update(base_mapper, states, uowtransaction)
    )

    for table, mapper in base_mapper._sorted_tables.items():
        if table not in mapper._pks_by_table:
            continue

        update = (
            (
                state,
                state_dict,
                sub_mapper,
                connection,
                mapper._get_committed_state_attr_by_column(
                    state, state_dict, mapper.version_id_col
                )
                if mapper.version_id_col is not None
                else None,
            )
            for state, state_dict, sub_mapper, connection in states_to_update
            if table in sub_mapper._pks_by_table
        )

        update = _collect_post_update_commands(
            base_mapper, uowtransaction, table, update, post_update_cols
        )

        _emit_post_update_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            update,
        )


def delete_obj(base_mapper, states, uowtransaction):
    """Issue ``DELETE`` statements for a list of objects.

    This is called within the context of a UOWTransaction during a
    flush operation.

    """

    states_to_delete = list(
        _organize_states_for_delete(base_mapper, states, uowtransaction)
    )

    table_to_mapper = base_mapper._sorted_tables

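    # delete from tables in reverse dependency order, so that rows in
    # dependent (child) tables are removed before their parent rows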
    for table in reversed(list(table_to_mapper.keys())):
        mapper = table_to_mapper[table]
        if table not in mapper._pks_by_table:
            continue
        elif mapper.inherits and mapper.passive_deletes:
            continue

        delete = _collect_delete_commands(
            base_mapper, uowtransaction, table, states_to_delete
        )

        _emit_delete_statements(
            base_mapper,
            uowtransaction,
            mapper,
            table,
            delete,
        )

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_delete:
        mapper.dispatch.after_delete(mapper, connection, state)


def _organize_states_for_save(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for INSERT or
    UPDATE.

    This includes splitting out into distinct lists for
    each, calling before_insert/before_update, obtaining
    key information for each state including its dictionary,
    mapper, the connection to use for the execution per state,
    and the identity flag.

    """

    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):

        has_identity = bool(state.key)

        instance_key = state.key or mapper._identity_key_from_state(state)

        row_switch = update_version_id = None

        # call before_XXX extensions
        if not has_identity:
            mapper.dispatch.before_insert(mapper, connection, state)
        else:
            mapper.dispatch.before_update(mapper, connection, state)

        if mapper._validate_polymorphic_identity:
            mapper._validate_polymorphic_identity(mapper, state, dict_)

        # detect if we have a "pending" instance (i.e. has
        # no instance_key attached to it), and another instance
        # with the same identity key already exists as persistent.
        # convert to an UPDATE if so.
        if (
            not has_identity
            and instance_key in uowtransaction.session.identity_map
        ):
            instance = uowtransaction.session.identity_map[instance_key]
            existing = attributes.instance_state(instance)

            if not uowtransaction.was_already_deleted(existing):
                if not uowtransaction.is_deleted(existing):
                    util.warn(
                        "New instance %s with identity key %s conflicts "
                        "with persistent instance %s"
                        % (state_str(state), instance_key, state_str(existing))
                    )
                else:
                    base_mapper._log_debug(
                        "detected row switch for identity %s.  "
                        "will update %s, remove %s from "
                        "transaction",
                        instance_key,
                        state_str(state),
                        state_str(existing),
                    )

                    # remove the "delete" flag from the existing element
                    uowtransaction.remove_state_actions(existing)
                    row_switch = existing

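        # for an UPDATE (or row switch), snapshot the committed version id
        # so the WHERE clause can match the expected row version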
        if (has_identity or row_switch) and mapper.version_id_col is not None:
            update_version_id = mapper._get_committed_state_attr_by_column(
                row_switch if row_switch else state,
                row_switch.dict if row_switch else dict_,
                mapper.version_id_col,
            )

        yield (
            state,
            dict_,
            mapper,
            connection,
            has_identity,
            row_switch,
            update_version_id,
        )


def _organize_states_for_post_update(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for UPDATE
    corresponding to post_update.

    This includes obtaining key information for each state
    including its dictionary, mapper, the connection to use for
    the execution per state.

    """
    return _connections_for_states(base_mapper, uowtransaction, states)


def _organize_states_for_delete(base_mapper, states, uowtransaction):
    """Make an initial pass across a set of states for DELETE.

    This includes calling out before_delete and obtaining
    key information for each state including its dictionary,
    mapper, the connection to use for the execution per state.

    """
    for state, dict_, mapper, connection in _connections_for_states(
        base_mapper, uowtransaction, states
    ):

        mapper.dispatch.before_delete(mapper, connection, state)

        if mapper.version_id_col is not None:
            update_version_id = mapper._get_committed_state_attr_by_column(
                state, dict_, mapper.version_id_col
            )
        else:
            update_version_id = None

        yield (state, dict_, mapper, connection, update_version_id)


def _collect_insert_commands(
    table,
    states_to_insert,
    bulk=False,
    return_defaults=False,
    render_nulls=False,
):
    """Identify sets of values to use in INSERT statements for a
    list of states.

    """
    for state, state_dict, mapper, connection in states_to_insert:
        if table not in mapper._pks_by_table:
            continue

        params = {}
        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        eval_none = mapper._insert_cols_evaluating_none[table]

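        # sort each attribute value into plain bound parameters vs.
        # SQL expressions (value_params), which render directly into the
        # VALUES clause rather than as executemany parameters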
        for propkey in set(propkey_to_col).intersection(state_dict):
            value = state_dict[propkey]
            col = propkey_to_col[propkey]
            if value is None and col not in eval_none and not render_nulls:
                continue
            elif not bulk and (
                hasattr(value, "__clause_element__")
                or isinstance(value, sql.ClauseElement)
            ):
                value_params[col] = (
                    value.__clause_element__()
                    if hasattr(value, "__clause_element__")
                    else value
                )
            else:
                params[col.key] = value

        if not bulk:
            # for all the columns that have no default and we don't have
            # a value and where "None" is not a special value, add
            # explicit None to the INSERT.   This is a legacy behavior
            # which might be worth removing, as it should not be necessary
            # and also produces confusion, given that "missing" and None
            # now have distinct meanings
            for colkey in (
                mapper._insert_cols_as_none[table]
                .difference(params)
                .difference([c.key for c in value_params])
            ):
                params[colkey] = None

        if not bulk or return_defaults:
            # params are in terms of Column key objects, so
            # compare to pk_keys_by_table
            has_all_pks = mapper._pk_keys_by_table[table].issubset(params)

            if mapper.base_mapper.eager_defaults:
                has_all_defaults = mapper._server_default_cols[table].issubset(
                    params
                )
            else:
                has_all_defaults = True
        else:
            has_all_defaults = has_all_pks = True

        if (
            mapper.version_id_generator is not False
            and mapper.version_id_col is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            params[mapper.version_id_col.key] = mapper.version_id_generator(
                None
            )

        yield (
            state,
            state_dict,
            params,
            mapper,
            connection,
            value_params,
            has_all_pks,
            has_all_defaults,
        )


def _collect_update_commands(
    uowtransaction, table, states_to_update, bulk=False
):
    """Identify sets of values to use in UPDATE statements for a
    list of states.

    This function works intricately with the history system
    to determine exactly what values should be updated
    as well as how the row should be matched within an UPDATE
    statement.  Includes some tricky scenarios where the primary
    key of an object might have been changed.

    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:

        if table not in mapper._pks_by_table:
            continue

        pks = mapper._pks_by_table[table]

        value_params = {}

        propkey_to_col = mapper._propkey_to_col[table]

        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            params = dict(
                (propkey_to_col[propkey].key, state_dict[propkey])
                for propkey in set(propkey_to_col)
                .intersection(state_dict)
                .difference(mapper._pk_attr_keys_by_table[table])
            )
            has_all_defaults = True
        else:
            params = {}
            for propkey in set(propkey_to_col).intersection(
                state.committed_state
            ):
                value = state_dict[propkey]
                col = propkey_to_col[propkey]

                if hasattr(value, "__clause_element__") or isinstance(
                    value, sql.ClauseElement
                ):
                    value_params[col] = (
                        value.__clause_element__()
                        if hasattr(value, "__clause_element__")
                        else value
                    )
                # guard against values that generate non-__nonzero__
                # objects for __eq__()
                elif (
                    state.manager[propkey].impl.is_equal(
                        value, state.committed_state[propkey]
                    )
                    is not True
                ):
                    params[col.key] = value

            if mapper.base_mapper.eager_defaults:
                has_all_defaults = (
                    mapper._server_onupdate_default_cols[table]
                ).issubset(params)
            else:
                has_all_defaults = True

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):

            if not bulk and not (params or value_params):
                # HACK: check for history in other tables, in case the
                # history is only in a different table than the one
                # where the version_id_col is.  This logic was lost
                # from 0.9 -> 1.0.0 and restored in 1.0.6.
                for prop in mapper._columntoproperty.values():
                    history = state.manager[prop.key].impl.get_history(
                        state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                    )
                    if history.added:
                        break
                else:
                    # no net change, break
                    continue

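            # the WHERE-clause version criteria are keyed on col._label,
            # keeping them distinct from a same-named SET parameter (col.key)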
            col = mapper.version_id_col
            no_params = not params and not value_params
            params[col._label] = update_version_id

            if (
                bulk or col.key not in params
            ) and mapper.version_id_generator is not False:
                val = mapper.version_id_generator(update_version_id)
                params[col.key] = val
            elif mapper.version_id_generator is False and no_params:
                # no version id generator, no values set on the table,
                # and version id wasn't manually incremented.
                # set version id to itself so we get an UPDATE
                # statement
                params[col.key] = update_version_id

        elif not (params or value_params):
            continue

        has_all_pks = True
        expect_pk_cascaded = False
        if bulk:
            # keys here are mapped attribute keys, so
            # look at mapper attribute keys for pk
            pk_params = dict(
                (propkey_to_col[propkey]._label, state_dict.get(propkey))
                for propkey in set(propkey_to_col).intersection(
                    mapper._pk_attr_keys_by_table[table]
                )
            )
        else:
            pk_params = {}
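            # use attribute history per PK column to decide whether the
            # row should be matched on its new or its old primary key value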
            for col in pks:
                propkey = mapper._columntoproperty[col].key

                history = state.manager[propkey].impl.get_history(
                    state, state_dict, attributes.PASSIVE_OFF
                )

                if history.added:
                    if (
                        not history.deleted
                        or ("pk_cascaded", state, col)
                        in uowtransaction.attributes
                    ):
                        expect_pk_cascaded = True
                        pk_params[col._label] = history.added[0]
                        params.pop(col.key, None)
                    else:
                        # else, use the old value to locate the row
                        pk_params[col._label] = history.deleted[0]
                        if col in value_params:
                            has_all_pks = False
                else:
                    pk_params[col._label] = history.unchanged[0]
                if pk_params[col._label] is None:
                    raise orm_exc.FlushError(
                        "Can't update table %s using NULL for primary "
                        "key value on column %s" % (table, col)
                    )

        if params or value_params:
            params.update(pk_params)
            yield (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            )
        elif expect_pk_cascaded:
            # no UPDATE occurs on this table, but we expect that CASCADE rules
            # have changed the primary key of the row; propagate this event to
            # other columns that expect to have been modified. this normally
            # occurs after the UPDATE is emitted however we invoke it here
            # explicitly in the absence of our invoking an UPDATE
            for m, equated_pairs in mapper._table_to_equated[table]:
                sync.populate(
                    state,
                    m,
                    state,
                    m,
                    equated_pairs,
                    uowtransaction,
                    mapper.passive_updates,
                )


def _collect_post_update_commands(
    base_mapper, uowtransaction, table, states_to_update, post_update_cols
):
    """Identify sets of values to use in UPDATE statements for a
    list of states within a post_update operation.

    """

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_update:

        # assert table in mapper._pks_by_table

        pks = mapper._pks_by_table[table]
        params = {}
        hasdata = False

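        # PK columns (keyed on col._label) locate the row; post_update
        # columns contribute SET values only when history shows a change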
        for col in mapper._cols_by_table[table]:
            if col in pks:
                params[col._label] = mapper._get_state_attr_by_column(
                    state, state_dict, col, passive=attributes.PASSIVE_OFF
                )

            elif col in post_update_cols or col.onupdate is not None:
                prop = mapper._columntoproperty[col]
                history = state.manager[prop.key].impl.get_history(
                    state, state_dict, attributes.PASSIVE_NO_INITIALIZE
                )
                if history.added:
                    value = history.added[0]
                    params[col.key] = value
                    hasdata = True
        if hasdata:
            if (
                update_version_id is not None
                and mapper.version_id_col in mapper._cols_by_table[table]
            ):

                col = mapper.version_id_col
                params[col._label] = update_version_id

                if (
                    bool(state.key)
                    and col.key not in params
                    and mapper.version_id_generator is not False
                ):
                    val = mapper.version_id_generator(update_version_id)
                    params[col.key] = val
            yield state, state_dict, mapper, connection, params


def _collect_delete_commands(
    base_mapper, uowtransaction, table, states_to_delete
):
    """Identify values to use in DELETE statements for a list of
    states to be deleted."""

    for (
        state,
        state_dict,
        mapper,
        connection,
        update_version_id,
    ) in states_to_delete:

        if table not in mapper._pks_by_table:
            continue

        params = {}
        for col in mapper._pks_by_table[table]:
            params[
                col.key
            ] = value = mapper._get_committed_state_attr_by_column(
                state, state_dict, col
            )
            if value is None:
                raise orm_exc.FlushError(
                    "Can't delete from table %s "
                    "using NULL for primary "
                    "key value on column %s" % (table, col)
                )

        if (
            update_version_id is not None
            and mapper.version_id_col in mapper._cols_by_table[table]
        ):
            params[mapper.version_id_col.key] = update_version_id
        yield params, connection


def _emit_update_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    update,
    bookkeeping=True,
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_update_commands()."""

    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    def update_stmt():
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses.clauses.append(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses.clauses.append(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        stmt = table.update().where(clauses)
        return stmt

    cached_stmt = base_mapper._memo(("update", table), update_stmt)

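    # group records so that rows sharing a connection, parameter key set
    # and flags can be batched into a single executemany() call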
    for (
        (connection, paramkeys, hasvalue, has_all_defaults, has_all_pks),
        records,
    ) in groupby(
        update,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # set of parameter keys
            bool(rec[5]),  # whether or not we have "value" parameters
            rec[6],  # has_all_defaults
            rec[7],  # has all pks
        ),
    ):
        rows = 0
        records = list(records)

        statement = cached_stmt
        return_defaults = False

        if not has_all_pks:
            statement = statement.return_defaults()
            return_defaults = True
        elif (
            bookkeeping
            and not has_all_defaults
            and mapper.base_mapper.eager_defaults
        ):
            statement = statement.return_defaults()
            return_defaults = True
        elif mapper.version_id_col is not None:
            statement = statement.return_defaults(mapper.version_id_col)
            return_defaults = True

        assert_singlerow = (
            connection.dialect.supports_sane_rowcount
            if not return_defaults
            else connection.dialect.supports_sane_rowcount_returning
        )

        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )
        allow_multirow = has_all_defaults and not needs_version_id

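        # records carrying SQL expressions (value_params) render distinct
        # SQL per row, so they are executed one statement at a time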
        if hasvalue:
            for (
                state,
                state_dict,
                params,
                mapper,
                connection,
                value_params,
                has_all_defaults,
                has_all_pks,
            ) in records:
                c = connection._execute_20(
                    statement.values(value_params),
                    params,
                    execution_options=execution_options,
                )
                if bookkeeping:
                    _postfetch(
                        mapper,
                        uowtransaction,
                        table,
                        state,
                        state_dict,
                        c,
                        c.context.compiled_parameters[0],
                        value_params,
                        True,
                        c.returned_defaults,
                    )
                rows += c.rowcount
                check_rowcount = assert_singlerow
        else:
            if not allow_multirow:
                check_rowcount = assert_singlerow
                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    c = connection._execute_20(
                        statement, params, execution_options=execution_options
                    )

                    # TODO: why with bookkeeping=False?
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            c.returned_defaults,
                        )
                    rows += c.rowcount
            else:
                multiparams = [rec[2] for rec in records]

                check_rowcount = assert_multirow or (
                    assert_singlerow and len(multiparams) == 1
                )

                c = connection._execute_20(
                    statement, multiparams, execution_options=execution_options
                )

                rows += c.rowcount

                for (
                    state,
                    state_dict,
                    params,
                    mapper,
                    connection,
                    value_params,
                    has_all_defaults,
                    has_all_pks,
                ) in records:
                    if bookkeeping:
                        _postfetch(
                            mapper,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            c.context.compiled_parameters[0],
                            value_params,
                            True,
                            c.returned_defaults
                            if not c.context.executemany
                            else None,
                        )

        if check_rowcount:
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )


def _emit_insert_statements(
    base_mapper,
    uowtransaction,
    mapper,
    table,
    insert,
    bookkeeping=True,
):
    """Emit INSERT statements corresponding to value lists collected
    by _collect_insert_commands()."""

    cached_stmt = base_mapper._memo(("insert", table), table.insert)

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

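    # as with UPDATE, group records by connection / parameter keys / flags
    # so each group can potentially run as one executemany() batch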
    for (
        (connection, pkeys, hasvalue, has_all_pks, has_all_defaults),
        records,
    ) in groupby(
        insert,
        lambda rec: (
            rec[4],  # connection
            set(rec[2]),  # parameter keys
            bool(rec[5]),  # whether we have "value" parameters
            rec[6],
            rec[7],
        ),
    ):

        statement = cached_stmt

        if (
            not bookkeeping
            or (
                has_all_defaults
                or not base_mapper.eager_defaults
                or not connection.dialect.implicit_returning
            )
            and has_all_pks
            and not hasvalue
        ):
            # the "we don't need newly generated values back" section.
            # here we have all the PKs, all the defaults or we don't want
            # to fetch them, or the dialect doesn't support RETURNING at all
            # so we have to post-fetch / use lastrowid anyway.
            records = list(records)
            multiparams = [rec[2] for rec in records]

            c = connection._execute_20(
                statement, multiparams, execution_options=execution_options
            )

            if bookkeeping:
                for (
                    (
                        state,
                        state_dict,
                        params,
                        mapper_rec,
                        conn,
                        value_params,
                        has_all_pks,
                        has_all_defaults,
                    ),
                    last_inserted_params,
                ) in zip(records, c.context.compiled_parameters):
                    if state:
                        _postfetch(
                            mapper_rec,
                            uowtransaction,
                            table,
                            state,
                            state_dict,
                            c,
                            last_inserted_params,
                            value_params,
                            False,
                            c.returned_defaults
                            if not c.context.executemany
                            else None,
                        )
                    else:
                        _postfetch_bulk_save(mapper_rec, state_dict, table)

        else:
            # here, we need defaults and/or pk values back.

            records = list(records)
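            # executemany with RETURNING requires dialect support and more
            # than one record; otherwise fall back to row-at-a-time INSERT
            # so each new primary key can be retrieved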
            if (
 | 
						|
                not hasvalue
 | 
						|
                and connection.dialect.insert_executemany_returning
 | 
						|
                and len(records) > 1
 | 
						|
            ):
 | 
						|
                do_executemany = True
 | 
						|
            else:
 | 
						|
                do_executemany = False
 | 
						|
 | 
						|
            if not has_all_defaults and base_mapper.eager_defaults:
 | 
						|
                statement = statement.return_defaults()
 | 
						|
            elif mapper.version_id_col is not None:
 | 
						|
                statement = statement.return_defaults(mapper.version_id_col)
 | 
						|
            elif do_executemany:
 | 
						|
                statement = statement.return_defaults(*table.primary_key)
 | 
						|
 | 
						|
            if do_executemany:
 | 
						|
                multiparams = [rec[2] for rec in records]
 | 
						|
 | 
						|
                c = connection._execute_20(
 | 
						|
                    statement, multiparams, execution_options=execution_options
 | 
						|
                )
 | 
						|
 | 
						|
                if bookkeeping:
 | 
						|
                    for (
 | 
						|
                        (
 | 
						|
                            state,
 | 
						|
                            state_dict,
 | 
						|
                            params,
 | 
						|
                            mapper_rec,
 | 
						|
                            conn,
 | 
						|
                            value_params,
 | 
						|
                            has_all_pks,
 | 
						|
                            has_all_defaults,
 | 
						|
                        ),
 | 
						|
                        last_inserted_params,
 | 
						|
                        inserted_primary_key,
 | 
						|
                        returned_defaults,
 | 
						|
                    ) in util.zip_longest(
 | 
						|
                        records,
 | 
						|
                        c.context.compiled_parameters,
 | 
						|
                        c.inserted_primary_key_rows,
 | 
						|
                        c.returned_defaults_rows or (),
 | 
						|
                    ):
 | 
						|
                        if inserted_primary_key is None:
 | 
						|
                            # this is a real problem and means that we didn't
 | 
						|
                            # get back as many PK rows.  we can't continue
 | 
						|
                            # since this indicates PK rows were missing, which
 | 
						|
                            # means we likely mis-populated records starting
 | 
						|
                            # at that point with incorrectly matched PK
 | 
						|
                            # values.
 | 
						|
                            raise orm_exc.FlushError(
 | 
						|
                                "Multi-row INSERT statement for %s did not "
 | 
						|
                                "produce "
 | 
						|
                                "the correct number of INSERTed rows for "
 | 
						|
                                "RETURNING.  Ensure there are no triggers or "
 | 
						|
                                "special driver issues preventing INSERT from "
 | 
						|
                                "functioning properly." % mapper_rec
 | 
						|
                            )
 | 
						|
 | 
						|
                        for pk, col in zip(
 | 
						|
                            inserted_primary_key,
 | 
						|
                            mapper._pks_by_table[table],
 | 
						|
                        ):
 | 
						|
                            prop = mapper_rec._columntoproperty[col]
 | 
						|
                            if state_dict.get(prop.key) is None:
 | 
						|
                                state_dict[prop.key] = pk
 | 
						|
 | 
						|
                        if state:
 | 
						|
                            _postfetch(
 | 
						|
                                mapper_rec,
 | 
						|
                                uowtransaction,
 | 
						|
                                table,
 | 
						|
                                state,
 | 
						|
                                state_dict,
 | 
						|
                                c,
 | 
						|
                                last_inserted_params,
 | 
						|
                                value_params,
 | 
						|
                                False,
 | 
						|
                                returned_defaults,
 | 
						|
                            )
 | 
						|
                        else:
 | 
						|
                            _postfetch_bulk_save(mapper_rec, state_dict, table)
 | 
						|
            else:
 | 
						|
                for (
 | 
						|
                    state,
 | 
						|
                    state_dict,
 | 
						|
                    params,
 | 
						|
                    mapper_rec,
 | 
						|
                    connection,
 | 
						|
                    value_params,
 | 
						|
                    has_all_pks,
 | 
						|
                    has_all_defaults,
 | 
						|
                ) in records:
 | 
						|
                    if value_params:
 | 
						|
                        result = connection._execute_20(
 | 
						|
                            statement.values(value_params),
 | 
						|
                            params,
 | 
						|
                            execution_options=execution_options,
 | 
						|
                        )
 | 
						|
                    else:
 | 
						|
                        result = connection._execute_20(
 | 
						|
                            statement,
 | 
						|
                            params,
 | 
						|
                            execution_options=execution_options,
 | 
						|
                        )
 | 
						|
 | 
						|
                    primary_key = result.inserted_primary_key
 | 
						|
                    if primary_key is None:
 | 
						|
                        raise orm_exc.FlushError(
 | 
						|
                            "Single-row INSERT statement for %s "
 | 
						|
                            "did not produce a "
 | 
						|
                            "new primary key result "
 | 
						|
                            "being invoked.  Ensure there are no triggers or "
 | 
						|
                            "special driver issues preventing INSERT from "
 | 
						|
                            "functioning properly." % (mapper_rec,)
 | 
						|
                        )
 | 
						|
                    for pk, col in zip(
                        primary_key, mapper._pks_by_table[table]
                    ):
                        prop = mapper_rec._columntoproperty[col]
                        if (
                            col in value_params
                            or state_dict.get(prop.key) is None
                        ):
                            state_dict[prop.key] = pk
                    if bookkeeping:
                        if state:
                            _postfetch(
                                mapper_rec,
                                uowtransaction,
                                table,
                                state,
                                state_dict,
                                result,
                                result.context.compiled_parameters[0],
                                value_params,
                                False,
                                result.returned_defaults
                                if not result.context.executemany
                                else None,
                            )
                        else:
                            _postfetch_bulk_save(mapper_rec, state_dict, table)


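# The primary key handoff above is what makes server-generated keys appear
# on instances as soon as the flush completes.  A minimal usage sketch,
# assuming a hypothetical mapped class ``User`` with an autoincrement
# integer primary key ``id``::
#
#     user = User(name="spongebob")
#     session.add(user)
#     session.flush()  # emits INSERT; result.inserted_primary_key is
#                      # copied into the instance's state dict above
#     assert user.id is not None

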
def _emit_post_update_statements(
    base_mapper, uowtransaction, mapper, table, update
):
    """Emit UPDATE statements corresponding to value lists collected
    by _collect_post_update_commands()."""

    execution_options = {"compiled_cache": base_mapper._compiled_cache}

    needs_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def update_stmt():
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses.clauses.append(
                col == sql.bindparam(col._label, type_=col.type)
            )

        if needs_version_id:
            clauses.clauses.append(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col._label,
                    type_=mapper.version_id_col.type,
                )
            )

        stmt = table.update().where(clauses)

        if mapper.version_id_col is not None:
            stmt = stmt.return_defaults(mapper.version_id_col)

        return stmt

    statement = base_mapper._memo(("post_update", table), update_stmt)

    # execute each UPDATE in the order according to the original
    # list of states to guarantee row access order, but
    # also group them into common (connection, cols) sets
    # to support executemany().
    for key, records in groupby(
        update,
        lambda rec: (rec[3], set(rec[4])),  # connection  # parameter keys
    ):
        rows = 0

        records = list(records)
        connection = key[0]

        assert_singlerow = (
            connection.dialect.supports_sane_rowcount
            if mapper.version_id_col is None
            else connection.dialect.supports_sane_rowcount_returning
        )
        assert_multirow = (
            assert_singlerow
            and connection.dialect.supports_sane_multi_rowcount
        )
        allow_multirow = not needs_version_id or assert_multirow

        if not allow_multirow:
            check_rowcount = assert_singlerow
            for state, state_dict, mapper_rec, connection, params in records:

                c = connection._execute_20(
                    statement, params, execution_options=execution_options
                )

                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )
                rows += c.rowcount
        else:
            multiparams = [
                params
                for state, state_dict, mapper_rec, conn, params in records
            ]

            check_rowcount = assert_multirow or (
                assert_singlerow and len(multiparams) == 1
            )

            c = connection._execute_20(
                statement, multiparams, execution_options=execution_options
            )

            rows += c.rowcount
            for state, state_dict, mapper_rec, connection, params in records:
                _postfetch_post_update(
                    mapper_rec,
                    uowtransaction,
                    table,
                    state,
                    state_dict,
                    c,
                    c.context.compiled_parameters[0],
                )

        if check_rowcount:
            if rows != len(records):
                raise orm_exc.StaleDataError(
                    "UPDATE statement on table '%s' expected to "
                    "update %d row(s); %d were matched."
                    % (table.description, len(records), rows)
                )

        elif needs_version_id:
            util.warn(
                "Dialect %s does not support updated rowcount "
                "- versioning cannot be verified."
                % c.dialect.dialect_description
            )


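# The UPDATE statements emitted above serve relationships configured with
# post_update=True, which break mutually-dependent foreign key cycles by
# issuing an extra UPDATE after the INSERTs.  A minimal mapping sketch,
# assuming hypothetical ``Widget``/``Entry`` declarative classes::
#
#     class Widget(Base):
#         __tablename__ = "widget"
#         id = Column(Integer, primary_key=True)
#         favorite_entry_id = Column(Integer, ForeignKey("entry.id"))
#         favorite_entry = relationship(
#             "Entry", foreign_keys=favorite_entry_id, post_update=True
#         )

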
def _emit_delete_statements(
    base_mapper, uowtransaction, mapper, table, delete
):
    """Emit DELETE statements corresponding to value lists collected
    by _collect_delete_commands()."""

    need_version_id = (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    )

    def delete_stmt():
        clauses = BooleanClauseList._construct_raw(operators.and_)

        for col in mapper._pks_by_table[table]:
            clauses.clauses.append(
                col == sql.bindparam(col.key, type_=col.type)
            )

        if need_version_id:
            clauses.clauses.append(
                mapper.version_id_col
                == sql.bindparam(
                    mapper.version_id_col.key, type_=mapper.version_id_col.type
                )
            )

        return table.delete().where(clauses)

    statement = base_mapper._memo(("delete", table), delete_stmt)
    for connection, recs in groupby(delete, lambda rec: rec[1]):  # connection
        del_objects = [params for params, connection in recs]

        execution_options = {"compiled_cache": base_mapper._compiled_cache}
        expected = len(del_objects)
        rows_matched = -1
        only_warn = False

        if (
            need_version_id
            and not connection.dialect.supports_sane_multi_rowcount
        ):
            if connection.dialect.supports_sane_rowcount:
                rows_matched = 0
                # execute deletes individually so that versioned
                # rows can be verified
                for params in del_objects:

                    c = connection._execute_20(
                        statement, params, execution_options=execution_options
                    )
                    rows_matched += c.rowcount
            else:
                util.warn(
                    "Dialect %s does not support deleted rowcount "
                    "- versioning cannot be verified."
                    % connection.dialect.dialect_description
                )
                connection._execute_20(
                    statement, del_objects, execution_options=execution_options
                )
        else:
            c = connection._execute_20(
                statement, del_objects, execution_options=execution_options
            )

            if not need_version_id:
                only_warn = True

            rows_matched = c.rowcount

        if (
            base_mapper.confirm_deleted_rows
            and rows_matched > -1
            and expected != rows_matched
            and (
                connection.dialect.supports_sane_multi_rowcount
                or len(del_objects) == 1
            )
        ):
            # TODO: why does this "only warn" if versioning is turned off,
            # whereas the UPDATE raises?
            if only_warn:
                util.warn(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched.  Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )
            else:
                raise orm_exc.StaleDataError(
                    "DELETE statement on table '%s' expected to "
                    "delete %d row(s); %d were matched.  Please set "
                    "confirm_deleted_rows=False within the mapper "
                    "configuration to prevent this warning."
                    % (table.description, expected, rows_matched)
                )


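# A sketch of the versioning check performed above, assuming a hypothetical
# mapped class with ``__mapper_args__ = {"version_id_col": ...}``: if a
# concurrent transaction has already deleted or re-versioned the row, the
# DELETE matches zero rows and StaleDataError is raised (or only a warning
# is emitted when versioning is not in play)::
#
#     session.delete(user)
#     try:
#         session.flush()
#     except orm_exc.StaleDataError:
#         session.rollback()  # row was changed or removed concurrently

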
def _finalize_insert_update_commands(base_mapper, uowtransaction, states):
    """Finalize state on states that have been inserted or updated,
    including calling after_insert/after_update events.

    """
    for state, state_dict, mapper, connection, has_identity in states:

        if mapper._readonly_props:
            readonly = state.unmodified_intersection(
                [
                    p.key
                    for p in mapper._readonly_props
                    if (
                        p.expire_on_flush
                        and (not p.deferred or p.key in state.dict)
                    )
                    or (
                        not p.expire_on_flush
                        and not p.deferred
                        and p.key not in state.dict
                    )
                ]
            )
            if readonly:
                state._expire_attributes(state.dict, readonly)

        # if eager_defaults option is enabled, load
        # all expired cols.  Else if we have a version_id_col, make sure
        # it isn't expired.
        toload_now = []

        if base_mapper.eager_defaults:
            toload_now.extend(
                state._unloaded_non_object.intersection(
                    mapper._server_default_plus_onupdate_propkeys
                )
            )

        if (
            mapper.version_id_col is not None
            and mapper.version_id_generator is False
        ):
            if mapper._version_id_prop.key in state.unloaded:
                toload_now.extend([mapper._version_id_prop.key])

        if toload_now:
            state.key = base_mapper._identity_key_from_state(state)
            stmt = future.select(mapper).set_label_style(
                LABEL_STYLE_TABLENAME_PLUS_COL
            )
            loading.load_on_ident(
                uowtransaction.session,
                stmt,
                state.key,
                refresh_state=state,
                only_load_props=toload_now,
            )

        # call after_XXX extensions
        if not has_identity:
            mapper.dispatch.after_insert(mapper, connection, state)
        else:
            mapper.dispatch.after_update(mapper, connection, state)

        if (
            mapper.version_id_generator is False
            and mapper.version_id_col is not None
        ):
            if state_dict[mapper._version_id_prop.key] is None:
                raise orm_exc.FlushError(
                    "Instance does not contain a non-NULL version value"
                )


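# The extra SELECT issued above for ``toload_now`` is what the
# ``eager_defaults`` mapper option controls.  A configuration sketch,
# assuming a hypothetical mapped class with a server-side default::
#
#     class Account(Base):
#         __tablename__ = "account"
#         id = Column(Integer, primary_key=True)
#         created = Column(DateTime, server_default=func.now())
#         __mapper_args__ = {"eager_defaults": True}
#
# With eager_defaults=True, ``account.created`` is fetched as part of the
# flush (via RETURNING where supported, else the SELECT above) rather than
# being expired and lazily loaded on next access.

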
def _postfetch_post_update(
    mapper, uowtransaction, table, state, dict_, result, params
):
    if uowtransaction.is_deleted(state):
        return

    prefetch_cols = result.context.compiled.prefetch
    postfetch_cols = result.context.compiled.postfetch

    if (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    ):
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            dict_[mapper._columntoproperty[c].key] = params[c.key]
            if refresh_flush:
                load_evt_attrs.append(mapper._columntoproperty[c].key)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if postfetch_cols:
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )


def _postfetch(
    mapper,
    uowtransaction,
    table,
    state,
    dict_,
    result,
    params,
    value_params,
    isupdate,
    returned_defaults,
):
    """Expire attributes in need of newly persisted database state,
    after an INSERT or UPDATE statement has proceeded for that
    state."""

    prefetch_cols = result.context.compiled.prefetch
    postfetch_cols = result.context.compiled.postfetch
    returning_cols = result.context.compiled.returning

    if (
        mapper.version_id_col is not None
        and mapper.version_id_col in mapper._cols_by_table[table]
    ):
        prefetch_cols = list(prefetch_cols) + [mapper.version_id_col]

    refresh_flush = bool(mapper.class_manager.dispatch.refresh_flush)
    if refresh_flush:
        load_evt_attrs = []

    if returning_cols:
        row = returned_defaults
        if row is not None:
            for row_value, col in zip(row, returning_cols):
                # pk cols returned from insert are handled
                # distinctly, don't step on the values here
                if col.primary_key and result.context.isinsert:
                    continue

                # note that columns can be in the "return defaults" that are
                # not mapped to this mapper, typically because they are
                # "excluded", which can be specified directly or also occurs
                # when using declarative w/ single table inheritance
                prop = mapper._columntoproperty.get(col)
                if prop:
                    dict_[prop.key] = row_value
                    if refresh_flush:
                        load_evt_attrs.append(prop.key)

    for c in prefetch_cols:
        if c.key in params and c in mapper._columntoproperty:
            dict_[mapper._columntoproperty[c].key] = params[c.key]
            if refresh_flush:
                load_evt_attrs.append(mapper._columntoproperty[c].key)

    if refresh_flush and load_evt_attrs:
        mapper.class_manager.dispatch.refresh_flush(
            state, uowtransaction, load_evt_attrs
        )

    if isupdate and value_params:
        # explicitly suit the use case specified by
        # [ticket:3801], PK SQL expressions for UPDATE on non-RETURNING
        # database which are set to themselves in order to do a version bump.
        postfetch_cols.extend(
            [
                col
                for col in value_params
                if col.primary_key and col not in returning_cols
            ]
        )

    if postfetch_cols:
        state._expire_attributes(
            state.dict,
            [
                mapper._columntoproperty[c].key
                for c in postfetch_cols
                if c in mapper._columntoproperty
            ],
        )

    # synchronize newly inserted ids from one table to the next
    # TODO: this still goes a little too often.  would be nice to
    # have definitive list of "columns that changed" here
    for m, equated_pairs in mapper._table_to_equated[table]:
        sync.populate(
            state,
            m,
            state,
            m,
            equated_pairs,
            uowtransaction,
            mapper.passive_updates,
        )


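# Postfetch behavior sketch: a column with a server-side ON UPDATE value is
# placed in the compiled statement's ``postfetch`` collection, so
# _postfetch() expires it and the next attribute access re-loads the fresh
# value.  Assuming a hypothetical mapped column::
#
#     updated = Column(DateTime, server_onupdate=FetchedValue())
#
#     obj.name = "new name"
#     session.flush()      # UPDATE runs; "updated" is expired above
#     obj.updated          # triggers a SELECT for the new server value

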
def _postfetch_bulk_save(mapper, dict_, table):
    for m, equated_pairs in mapper._table_to_equated[table]:
        sync.bulk_populate_inherit_keys(dict_, m, equated_pairs)


def _connections_for_states(base_mapper, uowtransaction, states):
    """Return an iterator of (state, state.dict, mapper, connection).

    The states are sorted according to _sort_states, then paired
    with the connection they should be using for the given
    unit of work transaction.

    """
    # if session has a connection callable,
    # organize individual states with the connection
    # to use for update
    if uowtransaction.session.connection_callable:
        connection_callable = uowtransaction.session.connection_callable
    else:
        connection = uowtransaction.transaction.connection(base_mapper)
        connection_callable = None

    for state in _sort_states(base_mapper, states):
        if connection_callable:
            connection = connection_callable(base_mapper, state.obj())

        mapper = state.manager.mapper

        yield state, state.dict, mapper, connection


def _sort_states(mapper, states):
    pending = set(states)
    persistent = set(s for s in pending if s.key is not None)
    pending.difference_update(persistent)

    try:
        persistent_sorted = sorted(
            persistent, key=mapper._persistent_sortkey_fn
        )
    except TypeError as err:
        util.raise_(
            sa_exc.InvalidRequestError(
                "Could not sort objects by primary key; primary key "
                "values must be sortable in Python (was: %s)" % err
            ),
            replace_context=err,
        )
    return (
        sorted(pending, key=operator.attrgetter("insert_order"))
        + persistent_sorted
    )


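# Ordering sketch for _sort_states(): pending objects flush in the order
# they were added to the Session (insert_order), while persistent objects
# are sorted by primary key so that concurrent flushes touch rows in a
# consistent order, reducing deadlock potential.  For example, persistent
# states with primary keys [3, 1, 2] and pending states added as [a, b]
# flush in the order [a, b, 1, 2, 3].

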
_EMPTY_DICT = util.immutabledict()


class BulkUDCompileState(CompileState):
    class default_update_options(Options):
        _synchronize_session = "evaluate"
        _autoflush = True
        _subject_mapper = None
        _resolved_values = _EMPTY_DICT
        _resolved_keys_as_propnames = _EMPTY_DICT
        _value_evaluators = _EMPTY_DICT
        _matched_objects = None
        _matched_rows = None
        _refresh_identity_token = None

    @classmethod
    def orm_pre_session_exec(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        is_reentrant_invoke,
    ):
        if is_reentrant_invoke:
            return statement, execution_options

        (
            update_options,
            execution_options,
        ) = BulkUDCompileState.default_update_options.from_execution_options(
            "_sa_orm_update_options",
            {"synchronize_session"},
            execution_options,
            statement._execution_options,
        )

        sync = update_options._synchronize_session
        if sync is not None:
            if sync not in ("evaluate", "fetch", False):
                raise sa_exc.ArgumentError(
                    "Valid strategies for session synchronization "
                    "are 'evaluate', 'fetch', False"
                )

        bind_arguments["clause"] = statement
        try:
            plugin_subject = statement._propagate_attrs["plugin_subject"]
        except KeyError:
            assert False, "statement had 'orm' plugin but no plugin_subject"
        else:
            bind_arguments["mapper"] = plugin_subject.mapper

        update_options += {"_subject_mapper": plugin_subject.mapper}

        if update_options._autoflush:
            session._autoflush()

        statement = statement._annotate(
            {"synchronize_session": update_options._synchronize_session}
        )

        # this stage of the execution is called before the do_orm_execute event
        # hook.  meaning for an extension like horizontal sharding, this step
        # happens before the extension splits out into multiple backends and
        # runs only once.  if we do pre_sync_fetch, we execute a SELECT
        # statement, which the horizontal sharding extension splits amongst the
        # shards and combines the results together.

        if update_options._synchronize_session == "evaluate":
            update_options = cls._do_pre_synchronize_evaluate(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                update_options,
            )
        elif update_options._synchronize_session == "fetch":
            update_options = cls._do_pre_synchronize_fetch(
                session,
                statement,
                params,
                execution_options,
                bind_arguments,
                update_options,
            )

        return (
            statement,
            util.immutabledict(execution_options).union(
                {"_sa_orm_update_options": update_options}
            ),
        )

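    # Usage sketch for the options resolved above (assuming a hypothetical
    # mapped class ``User``)::
    #
    #     from sqlalchemy import update
    #
    #     session.execute(
    #         update(User).where(User.age > 90).values(status="retired"),
    #         execution_options={"synchronize_session": "fetch"},
    #     )
    #
    # "evaluate" applies the WHERE criteria to in-memory objects in Python,
    # "fetch" retrieves the matched primary keys up front (or uses
    # RETURNING), and False skips session synchronization entirely.
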
    @classmethod
    def orm_setup_cursor_result(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        result,
    ):

        # this stage of the execution is called after the
        # do_orm_execute event hook.  meaning for an extension like
        # horizontal sharding, this step happens *within* the horizontal
        # sharding event handler which calls session.execute() re-entrantly
        # and will occur for each backend individually.
        # the sharding extension then returns its own merged result from the
        # individual ones we return here.

        update_options = execution_options["_sa_orm_update_options"]
        if update_options._synchronize_session == "evaluate":
            cls._do_post_synchronize_evaluate(session, result, update_options)
        elif update_options._synchronize_session == "fetch":
            cls._do_post_synchronize_fetch(session, result, update_options)

        return result

    @classmethod
    def _adjust_for_extra_criteria(cls, global_attributes, ext_info):
        """Apply extra criteria filtering.

        For all distinct single-table-inheritance mappers represented in the
        table being updated or deleted, produce additional WHERE criteria such
        that only the appropriate subtypes are selected from the total results.

        Additionally, add WHERE criteria originating from LoaderCriteriaOptions
        collected from the statement.

        """

        return_crit = ()

        adapter = ext_info._adapter if ext_info.is_aliased_class else None

        if (
            "additional_entity_criteria",
            ext_info.mapper,
        ) in global_attributes:
            return_crit += tuple(
                ae._resolve_where_criteria(ext_info)
                for ae in global_attributes[
                    ("additional_entity_criteria", ext_info.mapper)
                ]
                if ae.include_aliases or ae.entity is ext_info
            )

        if ext_info.mapper._single_table_criterion is not None:
            return_crit += (ext_info.mapper._single_table_criterion,)

        if adapter:
            return_crit = tuple(adapter.traverse(crit) for crit in return_crit)

        return return_crit

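    # The "additional_entity_criteria" entries consumed above typically
    # originate from with_loader_criteria() options attached to the
    # statement.  A sketch, assuming a hypothetical ``User`` class with a
    # ``deleted`` flag::
    #
    #     from sqlalchemy import delete
    #     from sqlalchemy.orm import with_loader_criteria
    #
    #     session.execute(
    #         delete(User)
    #         .where(User.name == "old")
    #         .options(with_loader_criteria(User, User.deleted == False))
    #     )
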
    @classmethod
    def _do_pre_synchronize_evaluate(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        update_options,
    ):
        mapper = update_options._subject_mapper
        target_cls = mapper.class_

        value_evaluators = resolved_keys_as_propnames = _EMPTY_DICT

        try:
            evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
            crit = ()
            if statement._where_criteria:
                crit += statement._where_criteria

            global_attributes = {}
            for opt in statement._with_options:
                if opt._is_criteria_option:
                    opt.get_global_criteria(global_attributes)

            if global_attributes:
                crit += cls._adjust_for_extra_criteria(
                    global_attributes, mapper
                )

            if crit:
                eval_condition = evaluator_compiler.process(*crit)
            else:

                def eval_condition(obj):
                    return True

        except evaluator.UnevaluatableError as err:
            util.raise_(
                sa_exc.InvalidRequestError(
                    'Could not evaluate current criteria in Python: "%s". '
                    "Specify 'fetch' or False for the "
                    "synchronize_session execution option." % err
                ),
                from_=err,
            )

        if statement.__visit_name__ == "lambda_element":
            # ._resolved is called on every LambdaElement in order to
            # generate the cache key, so this access does not add
            # additional expense
            effective_statement = statement._resolved
        else:
            effective_statement = statement

        if effective_statement.__visit_name__ == "update":
            resolved_values = cls._get_resolved_values(
                mapper, effective_statement
            )
            value_evaluators = {}
            resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
                mapper, resolved_values
            )
            for key, value in resolved_keys_as_propnames:
                try:
                    _evaluator = evaluator_compiler.process(
                        coercions.expect(roles.ExpressionElementRole, value)
                    )
                except evaluator.UnevaluatableError:
                    pass
                else:
                    value_evaluators[key] = _evaluator

        # TODO: detect when the where clause is a trivial primary key match.
        matched_objects = [
            state.obj()
            for state in session.identity_map.all_states()
            if state.mapper.isa(mapper)
            and not state.expired
            and eval_condition(state.obj())
            and (
                update_options._refresh_identity_token is None
                # TODO: coverage for the case where horizontal sharding
                # invokes an update() or delete() given an explicit identity
                # token up front
                or state.identity_token
                == update_options._refresh_identity_token
            )
        ]
        return update_options + {
            "_matched_objects": matched_objects,
            "_value_evaluators": value_evaluators,
            "_resolved_keys_as_propnames": resolved_keys_as_propnames,
        }

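    # "evaluate" sketch: simple criteria such as ``User.name == "x"`` can be
    # compiled by EvaluatorCompiler into a Python callable and applied to
    # objects in the identity map above; criteria that only the database can
    # compute (e.g. ``func.lower(User.name) == "x"``) raise
    # UnevaluatableError, which surfaces as the InvalidRequestError advising
    # 'fetch' or False instead.
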
    @classmethod
    def _get_resolved_values(cls, mapper, statement):
        if statement._multi_values:
            return []
        elif statement._ordered_values:
            return list(statement._ordered_values)
        elif statement._values:
            return list(statement._values.items())
        else:
            return []

    @classmethod
    def _resolved_keys_as_propnames(cls, mapper, resolved_values):
        values = []
        for k, v in resolved_values:
            if isinstance(k, attributes.QueryableAttribute):
                values.append((k.key, v))
                continue
            elif hasattr(k, "__clause_element__"):
                k = k.__clause_element__()

            if mapper and isinstance(k, expression.ColumnElement):
                try:
                    attr = mapper._columntoproperty[k]
                except orm_exc.UnmappedColumnError:
                    pass
                else:
                    values.append((attr.key, v))
            else:
                raise sa_exc.InvalidRequestError(
                    "Invalid expression type: %r" % k
                )
        return values

    @classmethod
    def _do_pre_synchronize_fetch(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        update_options,
    ):
        mapper = update_options._subject_mapper

        select_stmt = (
            select(*(mapper.primary_key + (mapper.select_identity_token,)))
            .select_from(mapper)
            .options(*statement._with_options)
        )
        select_stmt._where_criteria = statement._where_criteria

        def skip_for_full_returning(orm_context):
            bind = orm_context.session.get_bind(**orm_context.bind_arguments)
            if bind.dialect.full_returning:
                return _result.null_result()
            else:
                return None

        result = session.execute(
            select_stmt,
            params,
            execution_options,
            bind_arguments,
            _add_event=skip_for_full_returning,
        )
        matched_rows = result.fetchall()

        value_evaluators = _EMPTY_DICT

        if statement.__visit_name__ == "lambda_element":
            # ._resolved is called on every LambdaElement in order to
            # generate the cache key, so this access does not add
            # additional expense
            effective_statement = statement._resolved
        else:
            effective_statement = statement

        if effective_statement.__visit_name__ == "update":
            target_cls = mapper.class_
            evaluator_compiler = evaluator.EvaluatorCompiler(target_cls)
            resolved_values = cls._get_resolved_values(
                mapper, effective_statement
            )
            resolved_keys_as_propnames = cls._resolved_keys_as_propnames(
                mapper, resolved_values
            )
            value_evaluators = {}
            for key, value in resolved_keys_as_propnames:
                try:
                    _evaluator = evaluator_compiler.process(
                        coercions.expect(roles.ExpressionElementRole, value)
                    )
                except evaluator.UnevaluatableError:
                    pass
                else:
                    value_evaluators[key] = _evaluator

        else:
            resolved_keys_as_propnames = _EMPTY_DICT

        return update_options + {
            "_value_evaluators": value_evaluators,
            "_matched_rows": matched_rows,
            "_resolved_keys_as_propnames": resolved_keys_as_propnames,
        }


class ORMDMLState:
    @classmethod
    def get_entity_description(cls, statement):
        ext_info = statement.table._annotations["parententity"]
        mapper = ext_info.mapper
        if ext_info.is_aliased_class:
            _label_name = ext_info.name
        else:
            _label_name = mapper.class_.__name__

        return {
            "name": _label_name,
            "type": mapper.class_,
            "expr": ext_info.entity,
            "entity": ext_info.entity,
            "table": mapper.local_table,
        }

    @classmethod
    def get_returning_column_descriptions(cls, statement):
        def _ent_for_col(c):
            return c._annotations.get("parententity", None)

        def _attr_for_col(c, ent):
            if ent is None:
                return c
            proxy_key = c._annotations.get("proxy_key", None)
            if not proxy_key:
                return c
            else:
                return getattr(ent.entity, proxy_key, c)

        return [
            {
                "name": c.key,
                "type": c.type,
                "expr": _attr_for_col(c, ent),
                "aliased": ent.is_aliased_class,
                "entity": ent.entity,
            }
            for c, ent in [
                (c, _ent_for_col(c)) for c in statement._all_selected_columns
            ]
        ]


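# These description dictionaries back the public inspection attributes on
# ORM-enabled DML statements.  An inspection sketch, assuming a hypothetical
# mapped class ``User``::
#
#     from sqlalchemy import update
#
#     stmt = update(User).values(name="x").returning(User.id)
#     stmt.entity_description
#     # {"name": "User", "type": User, "expr": User, ...}
#     stmt.returning_column_descriptions
#     # [{"name": "id", "type": Integer(), "expr": User.id, ...}]

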
@CompileState.plugin_for("orm", "insert")
class ORMInsert(ORMDMLState, InsertDMLState):
    @classmethod
    def orm_pre_session_exec(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        is_reentrant_invoke,
    ):
        bind_arguments["clause"] = statement
        try:
            plugin_subject = statement._propagate_attrs["plugin_subject"]
        except KeyError:
            assert False, "statement had 'orm' plugin but no plugin_subject"
        else:
            bind_arguments["mapper"] = plugin_subject.mapper

        return (
            statement,
            util.immutabledict(execution_options),
        )

    @classmethod
    def orm_setup_cursor_result(
        cls,
        session,
        statement,
        params,
        execution_options,
        bind_arguments,
        result,
    ):
        return result


@CompileState.plugin_for("orm", "update")
class BulkORMUpdate(ORMDMLState, UpdateDMLState, BulkUDCompileState):
    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):

        self = cls.__new__(cls)

        ext_info = statement.table._annotations["parententity"]

        self.mapper = mapper = ext_info.mapper

        self.extra_criteria_entities = {}

        self._resolved_values = cls._get_resolved_values(mapper, statement)

        extra_criteria_attributes = {}

        for opt in statement._with_options:
            if opt._is_criteria_option:
                opt.get_global_criteria(extra_criteria_attributes)

        if not statement._preserve_parameter_order and statement._values:
            self._resolved_values = dict(self._resolved_values)

        new_stmt = sql.Update.__new__(sql.Update)
        new_stmt.__dict__.update(statement.__dict__)
        new_stmt.table = mapper.local_table

        # note if the statement has _multi_values, these
        # are passed through to the new statement, which will then raise
        # InvalidRequestError because UPDATE doesn't support multi_values
        # right now.
        if statement._ordered_values:
            new_stmt._ordered_values = self._resolved_values
        elif statement._values:
            new_stmt._values = self._resolved_values

        new_crit = cls._adjust_for_extra_criteria(
            extra_criteria_attributes, mapper
        )
        if new_crit:
            new_stmt = new_stmt.where(*new_crit)

        # if we are against a lambda statement we might not be the
        # topmost object that received per-execute annotations

        if (
            compiler._annotations.get("synchronize_session", None) == "fetch"
            and compiler.dialect.full_returning
        ):
            if new_stmt._returning:
                raise sa_exc.InvalidRequestError(
                    "Can't use synchronize_session='fetch' "
                    "with explicit returning()"
                )
            new_stmt = new_stmt.returning(*mapper.primary_key)

        UpdateDMLState.__init__(self, new_stmt, compiler, **kw)

        return self

    @classmethod
    def _get_crud_kv_pairs(cls, statement, kv_iterator):
        plugin_subject = statement._propagate_attrs["plugin_subject"]

        core_get_crud_kv_pairs = UpdateDMLState._get_crud_kv_pairs

        if not plugin_subject or not plugin_subject.mapper:
            return core_get_crud_kv_pairs(statement, kv_iterator)

        mapper = plugin_subject.mapper

        values = []

        for k, v in kv_iterator:
            k = coercions.expect(roles.DMLColumnRole, k)

            if isinstance(k, util.string_types):
                desc = _entity_namespace_key(mapper, k, default=NO_VALUE)
                if desc is NO_VALUE:
                    values.append(
                        (
                            k,
                            coercions.expect(
                                roles.ExpressionElementRole,
                                v,
                                type_=sqltypes.NullType(),
                                is_crud=True,
                            ),
                        )
                    )
                else:
                    values.extend(
                        core_get_crud_kv_pairs(
                            statement, desc._bulk_update_tuples(v)
                        )
                    )
            elif "entity_namespace" in k._annotations:
                k_anno = k._annotations
                attr = _entity_namespace_key(
                    k_anno["entity_namespace"], k_anno["proxy_key"]
                )
                values.extend(
                    core_get_crud_kv_pairs(
                        statement, attr._bulk_update_tuples(v)
                    )
                )
            else:
                values.append(
                    (
                        k,
                        coercions.expect(
                            roles.ExpressionElementRole,
                            v,
                            type_=sqltypes.NullType(),
                            is_crud=True,
                        ),
                    )
                )
        return values

    @classmethod
    def _do_post_synchronize_evaluate(cls, session, result, update_options):

        states = set()
        evaluated_keys = list(update_options._value_evaluators.keys())
        values = update_options._resolved_keys_as_propnames
        attrib = set(k for k, v in values)
        for obj in update_options._matched_objects:

            state, dict_ = (
                attributes.instance_state(obj),
                attributes.instance_dict(obj),
            )

            # the evaluated states were gathered across all identity tokens.
            # however the post_sync events are called per identity token,
            # so filter.
            if (
                update_options._refresh_identity_token is not None
                and state.identity_token
                != update_options._refresh_identity_token
            ):
                continue

            # only evaluate unmodified attributes
            to_evaluate = state.unmodified.intersection(evaluated_keys)
            for key in to_evaluate:
                if key in dict_:
                    dict_[key] = update_options._value_evaluators[key](obj)

            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            to_expire = attrib.intersection(dict_).difference(to_evaluate)
            if to_expire:
                state._expire_attributes(dict_, to_expire)

            states.add(state)
        session._register_altered(states)

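    # Post-sync sketch for "evaluate": after, e.g.,
    # ``update(User).values(age=User.age + 1)`` runs with
    # synchronize_session="evaluate", each matched in-memory ``user.age`` is
    # recomputed in Python by the value evaluator above, so the session
    # reflects the UPDATE without expiring or re-selecting the attribute.
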
    @classmethod
    def _do_post_synchronize_fetch(cls, session, result, update_options):
        target_mapper = update_options._subject_mapper

        states = set()
        evaluated_keys = list(update_options._value_evaluators.keys())

        if result.returns_rows:
            matched_rows = [
                tuple(row) + (update_options._refresh_identity_token,)
                for row in result.all()
            ]
        else:
            matched_rows = update_options._matched_rows

        objs = [
            session.identity_map[identity_key]
            for identity_key in [
                target_mapper.identity_key_from_primary_key(
                    list(primary_key),
                    identity_token=identity_token,
                )
                for primary_key, identity_token in [
                    (row[0:-1], row[-1]) for row in matched_rows
                ]
                if update_options._refresh_identity_token is None
                or identity_token == update_options._refresh_identity_token
            ]
            if identity_key in session.identity_map
        ]

        values = update_options._resolved_keys_as_propnames
        attrib = set(k for k, v in values)

        for obj in objs:
            state, dict_ = (
                attributes.instance_state(obj),
                attributes.instance_dict(obj),
            )

            to_evaluate = state.unmodified.intersection(evaluated_keys)
            for key in to_evaluate:
                if key in dict_:
                    dict_[key] = update_options._value_evaluators[key](obj)
            state.manager.dispatch.refresh(state, None, to_evaluate)

            state._commit(dict_, list(to_evaluate))

            to_expire = attrib.intersection(dict_).difference(to_evaluate)
            if to_expire:
                state._expire_attributes(dict_, to_expire)

            states.add(state)
        session._register_altered(states)


@CompileState.plugin_for("orm", "delete")
class BulkORMDelete(ORMDMLState, DeleteDMLState, BulkUDCompileState):
    @classmethod
    def create_for_statement(cls, statement, compiler, **kw):
        self = cls.__new__(cls)

        ext_info = statement.table._annotations["parententity"]
        self.mapper = mapper = ext_info.mapper

        self.extra_criteria_entities = {}

        extra_criteria_attributes = {}

        for opt in statement._with_options:
            if opt._is_criteria_option:
                opt.get_global_criteria(extra_criteria_attributes)

        new_crit = cls._adjust_for_extra_criteria(
            extra_criteria_attributes, mapper
        )
        if new_crit:
            statement = statement.where(*new_crit)

        if (
            mapper
            and compiler._annotations.get("synchronize_session", None)
            == "fetch"
            and compiler.dialect.full_returning
        ):
            statement = statement.returning(*mapper.primary_key)

        DeleteDMLState.__init__(self, statement, compiler, **kw)

        return self

    @classmethod
    def _do_post_synchronize_evaluate(cls, session, result, update_options):

        session._remove_newly_deleted(
            [
                attributes.instance_state(obj)
                for obj in update_options._matched_objects
            ]
        )

    @classmethod
    def _do_post_synchronize_fetch(cls, session, result, update_options):
        target_mapper = update_options._subject_mapper

        if result.returns_rows:
            matched_rows = [
                tuple(row) + (update_options._refresh_identity_token,)
                for row in result.all()
            ]
        else:
            matched_rows = update_options._matched_rows

        for row in matched_rows:
            primary_key = row[0:-1]
            identity_token = row[-1]

            # TODO: inline this and call remove_newly_deleted
            # once
            identity_key = target_mapper.identity_key_from_primary_key(
                list(primary_key),
                identity_token=identity_token,
            )
            if identity_key in session.identity_map:
                session._remove_newly_deleted(
                    [
                        attributes.instance_state(
                            session.identity_map[identity_key]
                        )
                    ]
                )
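

# Post-sync sketch for bulk DELETE: objects whose identity keys matched the
# statement are removed from the Session via _remove_newly_deleted(), so
# that after, e.g.,
#
#     session.execute(
#         delete(User).where(User.id == some_id),
#         execution_options={"synchronize_session": "fetch"},
#     )
#
# an already-loaded instance for that row is treated as deleted within the
# Session rather than lingering as stale in the identity map.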