You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							466 lines
						
					
					
						
							13 KiB
						
					
					
				
			
		
		
	
	
							466 lines
						
					
					
						
							13 KiB
						
					
					
				# testing/engines.py
 | 
						|
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
 | 
						|
# <see AUTHORS file>
 | 
						|
#
 | 
						|
# This module is part of SQLAlchemy and is released under
 | 
						|
# the MIT License: https://www.opensource.org/licenses/mit-license.php
 | 
						|
 | 
						|
from __future__ import absolute_import
 | 
						|
 | 
						|
import collections
 | 
						|
import re
 | 
						|
import warnings
 | 
						|
import weakref
 | 
						|
 | 
						|
from . import config
 | 
						|
from .util import decorator
 | 
						|
from .util import gc_collect
 | 
						|
from .. import event
 | 
						|
from .. import pool
 | 
						|
from ..util import await_only
 | 
						|
 | 
						|
 | 
						|
class ConnectionKiller(object):
    """Track pools, engines and raw DBAPI connections created by tests so
    they can be rolled back, checked in and disposed reliably on teardown."""

    def __init__(self):
        # weak mapping of pool ConnectionFairy proxies -> True
        self.proxy_refs = weakref.WeakKeyDictionary()
        # fixture scope name -> set of engines created under that scope
        self.testing_engines = collections.defaultdict(set)
        # raw DBAPI connections currently checked out
        self.dbapi_connections = set()

    def add_pool(self, pool):
        """Instrument *pool* so its connections are tracked here."""
        event.listen(pool, "checkout", self._add_conn)
        for evt_name in ("checkin", "close", "close_detached"):
            event.listen(pool, evt_name, self._remove_conn)
        # note we are keeping "invalidated" here, as those are still
        # opened connections we would like to roll back

    def _add_conn(self, dbapi_con, con_record, con_proxy):
        # "checkout" hook: remember the raw connection and weakly
        # reference the pool proxy
        self.dbapi_connections.add(dbapi_con)
        self.proxy_refs[con_proxy] = True

    def _remove_conn(self, dbapi_conn, *arg):
        # "checkin"/"close" hooks: the raw connection is no longer ours
        self.dbapi_connections.discard(dbapi_conn)

    def add_engine(self, engine, scope):
        """Track *engine* for teardown under the given fixture *scope*."""
        self.add_pool(engine.pool)

        assert scope in ("class", "global", "function", "fixture")
        self.testing_engines[scope].add(engine)

    def _safe(self, fn):
        # invoke fn(), downgrading any failure to a warning so that
        # teardown keeps going
        try:
            fn()
        except Exception as e:
            warnings.warn(
                "testing_reaper couldn't rollback/close connection: %s" % e
            )

    def rollback_all(self):
        """Roll back every live proxied connection we know about."""
        for proxy in list(self.proxy_refs):
            if proxy is not None and proxy.is_valid:
                self._safe(proxy.rollback)

    def checkin_all(self):
        """Return every tracked ConnectionFairy to its pool, then roll back
        raw connections that could not be checked in."""
        for proxy in list(self.proxy_refs):
            if proxy is None or not proxy.is_valid:
                continue
            self.dbapi_connections.discard(proxy.dbapi_connection)
            self._safe(proxy._checkin)

        # for fairy refs that were GCed and could not close the connection,
        # such as asyncio, roll back those remaining connections
        for raw_conn in self.dbapi_connections:
            self._safe(raw_conn.rollback)
        self.dbapi_connections.clear()

    def close_all(self):
        # alias retained for callers that expect a "close" spelling
        self.checkin_all()

    def prepare_for_drop_tables(self, connection):
        # don't do aggressive checks for third party test suites
        if not config.bootstrapped_as_sqlalchemy:
            return

        from . import provision

        provision.prepare_for_drop_tables(connection.engine.url, connection)

    def _drop_testing_engines(self, scope):
        """Check in and dispose every engine registered under *scope*."""
        engines = self.testing_engines[scope]
        for engine in list(engines):
            # first check in any proxy still attached to this engine's pool
            for proxy in list(self.proxy_refs):
                if proxy is None or not proxy.is_valid:
                    continue
                if proxy._pool is not None and proxy._pool is engine.pool:
                    self._safe(proxy._checkin)
            if hasattr(engine, "sync_engine"):
                # async engine: dispose() is a coroutine
                await_only(engine.dispose())
            else:
                engine.dispose()
        engines.clear()

    def after_test(self):
        self._drop_testing_engines("function")

    def after_test_outside_fixtures(self, test):
        # don't do aggressive checks for third party test suites
        if not config.bootstrapped_as_sqlalchemy:
            return

        if test.__class__.__leave_connections_for_teardown__:
            return

        self.checkin_all()

        # on PostgreSQL, this will test for any "idle in transaction"
        # connections.   useful to identify tests with unusual patterns
        # that can't be cleaned up correctly.
        from . import provision

        with config.db.connect() as conn:
            provision.prepare_for_drop_tables(conn.engine.url, conn)

    def stop_test_class_inside_fixtures(self):
        self.checkin_all()
        for scope in ("function", "class"):
            self._drop_testing_engines(scope)

    def stop_test_class_outside_fixtures(self):
        # ensure no refs to checked out connections at all.

        if pool.base._strong_ref_connection_records:
            gc_collect()

            if pool.base._strong_ref_connection_records:
                remaining = len(pool.base._strong_ref_connection_records)
                pool.base._strong_ref_connection_records.clear()
                assert (
                    False
                ), "%d connection recs not cleared after test suite" % (
                    remaining
                )

    def final_cleanup(self):
        self.checkin_all()
        for scope in self.testing_engines:
            self._drop_testing_engines(scope)

    def assert_all_closed(self):
        for proxy in self.proxy_refs:
            if proxy.is_valid:
                assert False
 | 
						|
 | 
						|
 | 
						|
# Module-wide reaper instance shared by the decorators and fixtures below.
testing_reaper = ConnectionKiller()
 | 
						|
 | 
						|
 | 
						|
@decorator
def assert_conns_closed(fn, *args, **kw):
    """Decorator asserting that all tracked connections are closed
    once *fn* has run, even if it raised."""
    try:
        fn(*args, **kw)
    finally:
        testing_reaper.assert_all_closed()
 | 
						|
 | 
						|
 | 
						|
@decorator
def rollback_open_connections(fn, *args, **kw):
    """Decorator that rolls back all open connections after fn execution."""
    try:
        fn(*args, **kw)
    finally:
        # roll back whether or not fn raised
        testing_reaper.rollback_all()
 | 
						|
 | 
						|
 | 
						|
@decorator
def close_first(fn, *args, **kw):
    """Decorator that closes all connections before fn execution."""
    # note: deliberately no try/finally here; cleanup happens up front
    testing_reaper.checkin_all()
    fn(*args, **kw)
 | 
						|
 | 
						|
 | 
						|
@decorator
def close_open_connections(fn, *args, **kw):
    """Decorator that closes all connections after fn execution."""
    try:
        fn(*args, **kw)
    finally:
        # check everything back in whether or not fn raised
        testing_reaper.checkin_all()
 | 
						|
 | 
						|
 | 
						|
def all_dialects(exclude=None):
    """Yield an instance of every dialect named in sqlalchemy.dialects,
    importing lazily, skipping names listed in *exclude*."""
    import sqlalchemy.dialects as d

    for name in d.__all__:
        # TEMPORARY
        if exclude and name in exclude:
            continue
        mod = getattr(d, name, None) or getattr(
            __import__("sqlalchemy.dialects.%s" % name).dialects, name
        )
        yield mod.dialect()
 | 
						|
 | 
						|
 | 
						|
class ReconnectFixture(object):
    """Wrap a DBAPI module, recording its connections so tests can
    simulate a database shutdown and subsequent restart."""

    def __init__(self, dbapi):
        self.dbapi = dbapi
        self.connections = []
        self.is_stopped = False

    def __getattr__(self, key):
        # delegate everything not intercepted to the real DBAPI module
        return getattr(self.dbapi, key)

    def connect(self, *args, **kwargs):
        conn = self.dbapi.connect(*args, **kwargs)
        if not self.is_stopped:
            self.connections.append(conn)
            return conn

        # "stopped" database: ensure the new connection is unusable
        self._safe(conn.close)
        curs = conn.cursor()  # should fail on Oracle etc.
        # should fail for everything that didn't fail
        # above, connection is closed
        curs.execute("select 1")
        assert False, "simulated connect failure didn't work"

    def _safe(self, fn):
        # best-effort close; failures become warnings
        try:
            fn()
        except Exception as e:
            warnings.warn("ReconnectFixture couldn't close connection: %s" % e)

    def shutdown(self, stop=False):
        # TODO: this doesn't cover all cases
        # as nicely as we'd like, namely MySQLdb.
        # would need to implement R. Brewer's
        # proxy server idea to get better
        # coverage.
        self.is_stopped = stop
        for recorded in list(self.connections):
            self._safe(recorded.close)
        self.connections = []

    def restart(self):
        self.is_stopped = False
 | 
						|
 | 
						|
 | 
						|
def reconnecting_engine(url=None, options=None):
    """Return a testing engine whose DBAPI is wrapped in a
    :class:`ReconnectFixture`, exposing ``test_shutdown()`` and
    ``test_restart()`` helpers for reconnect tests."""
    url = url or config.db.url
    if not options:
        options = {}
    options["module"] = ReconnectFixture(config.db.dialect.dbapi)
    engine = testing_engine(url, options)
    _dispose = engine.dispose

    def dispose():
        # shut down and reset the fixture before the real dispose
        engine.dialect.dbapi.shutdown()
        engine.dialect.dbapi.is_stopped = False
        _dispose()

    engine.test_shutdown = engine.dialect.dbapi.shutdown
    engine.test_restart = engine.dialect.dbapi.restart
    engine.dispose = dispose
    return engine
 | 
						|
 | 
						|
 | 
						|
def testing_engine(
    url=None,
    options=None,
    future=None,
    asyncio=False,
    transfer_staticpool=False,
    _sqlite_savepoint=False,
):
    """Produce an engine configured by --options with optional overrides.

    :param url: database URL; defaults to the configured testing URL.
    :param options: dict of ``create_engine()`` options.  The keys
     ``use_reaper``, ``scope`` and ``sqlite_savepoint`` are popped and
     interpreted locally (note the passed dict is mutated).
    :param future: if True — or the configured engine is future-style and
     this is not explicitly False — use the 2.0-style ``create_engine``.
    :param asyncio: if True, produce an asyncio engine.
    :param transfer_staticpool: adopt the configured engine's StaticPool
     contents, disabling the reaper for this engine.
    """

    if asyncio:
        assert not _sqlite_savepoint
        from sqlalchemy.ext.asyncio import (
            create_async_engine as create_engine,
        )
    elif future or (
        config.db and config.db._is_future and future is not False
    ):
        from sqlalchemy.future import create_engine
    else:
        from sqlalchemy import create_engine
    from sqlalchemy.engine.url import make_url

    if not options:
        use_reaper = True
        scope = "function"
        sqlite_savepoint = False
    else:
        use_reaper = options.pop("use_reaper", True)
        scope = options.pop("scope", "function")
        sqlite_savepoint = options.pop("sqlite_savepoint", False)

    url = url or config.db.url

    url = make_url(url)
    if options is None:
        # no options given: inherit the configured options when the
        # driver matches the configured database
        if config.db is None or url.drivername == config.db.url.drivername:
            options = config.db_opts
        else:
            options = {}
    # NOTE(review): a previous revision also built a merged copy of
    # config.db_opts + options here for a matching driver, but the merged
    # dict was never used; explicitly-passed options stand alone, so the
    # dead computation has been removed.

    engine = create_engine(url, **options)

    if sqlite_savepoint and engine.name == "sqlite":
        # apply SQLite savepoint workaround: disable the driver's implicit
        # BEGIN and emit our own so SAVEPOINT behaves
        @event.listens_for(engine, "connect")
        def do_connect(dbapi_connection, connection_record):
            dbapi_connection.isolation_level = None

        @event.listens_for(engine, "begin")
        def do_begin(conn):
            conn.exec_driver_sql("BEGIN")

    if transfer_staticpool:
        from sqlalchemy.pool import StaticPool

        if config.db is not None and isinstance(config.db.pool, StaticPool):
            # the adopted pool already holds the connection; the reaper
            # must not interfere with it
            use_reaper = False
            engine.pool._transfer_from(config.db.pool)

    if scope == "global":
        if asyncio:
            engine.sync_engine._has_events = True
        else:
            engine._has_events = (
                True  # enable event blocks, helps with profiling
            )

    if isinstance(engine.pool, pool.QueuePool):
        # fail immediately if a test leaks connections rather than waiting
        engine.pool._timeout = 0
        engine.pool._max_overflow = 0
    if use_reaper:
        testing_reaper.add_engine(engine, scope)

    return engine
 | 
						|
 | 
						|
 | 
						|
def mock_engine(dialect_name=None):
    """Provides a mocking engine based on the current testing.db.

    This is normally used to test DDL generation flow as emitted
    by an Engine.

    It should not be used in other cases, as assert_compile() and
    assert_sql_execution() are much better choices with fewer
    moving parts.

    """

    from sqlalchemy import create_mock_engine

    dialect_name = dialect_name or config.db.name

    buffer = []

    def executor(sql, *a, **kw):
        # record every statement the mock engine "executes"
        buffer.append(sql)

    def assert_sql(stmts):
        received = [re.sub(r"[\n\t]", "", str(stmt)) for stmt in buffer]
        assert received == stmts, received

    def print_sql():
        dialect = engine.dialect
        return "\n".join(
            str(stmt.compile(dialect=dialect)) for stmt in engine.mock
        )

    engine = create_mock_engine(dialect_name + "://", executor)
    assert not hasattr(engine, "mock")
    engine.mock = buffer
    engine.assert_sql = assert_sql
    engine.print_sql = print_sql
    return engine
 | 
						|
 | 
						|
 | 
						|
class DBAPIProxyCursor(object):
    """Proxy a DBAPI cursor.

    Tests can provide subclasses of this to intercept
    DBAPI-level cursor operations.

    """

    def __init__(self, engine, conn, *args, **kwargs):
        self.engine = engine
        self.connection = conn
        self.cursor = conn.cursor(*args, **kwargs)

    def execute(self, stmt, parameters=None, **kw):
        # only pass parameters positionally when present; some DBAPIs
        # reject an explicit None
        if parameters:
            return self.cursor.execute(stmt, parameters, **kw)
        return self.cursor.execute(stmt, **kw)

    def executemany(self, stmt, params, **kw):
        return self.cursor.executemany(stmt, params, **kw)

    def __iter__(self):
        return iter(self.cursor)

    def __getattr__(self, key):
        # anything not intercepted goes straight to the real cursor
        return getattr(self.cursor, key)
 | 
						|
 | 
						|
 | 
						|
class DBAPIProxyConnection(object):
    """Proxy a DBAPI connection.

    Tests can provide subclasses of this to intercept
    DBAPI-level connection operations.

    """

    def __init__(self, engine, cursor_cls):
        # obtain a raw connection directly from the pool's creator,
        # bypassing the pool itself
        self.conn = engine.pool._creator()
        self.engine = engine
        self.cursor_cls = cursor_cls

    def cursor(self, *args, **kwargs):
        # hand back the interceptable cursor wrapper instead of a raw one
        return self.cursor_cls(self.engine, self.conn, *args, **kwargs)

    def close(self):
        self.conn.close()

    def __getattr__(self, key):
        # delegate everything else to the real connection
        return getattr(self.conn, key)
 | 
						|
 | 
						|
 | 
						|
def proxying_engine(
    conn_cls=DBAPIProxyConnection, cursor_cls=DBAPIProxyCursor
):
    """Produce an engine that provides proxy hooks for
    common methods.

    """

    def mock_conn():
        return conn_cls(config.db, cursor_cls)

    def _wrap_do_on_connect(do_on_connect):
        def go(dbapi_conn):
            # unwrap the proxy before invoking the real on-connect hook
            return do_on_connect(dbapi_conn.conn)

        return go

    return testing_engine(
        options={
            "creator": mock_conn,
            "_wrap_do_on_connect": _wrap_do_on_connect,
        }
    )
 |