You can not select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					458 lines
				
				15 KiB
			
		
		
			
		
	
	
					458 lines
				
				15 KiB
			| 
								 
											3 years ago
										 
									 | 
							
								# testing/assertsql.py
							 | 
						||
| 
								 | 
							
								# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
							 | 
						||
| 
								 | 
							
								# <see AUTHORS file>
							 | 
						||
| 
								 | 
							
								#
							 | 
						||
| 
								 | 
							
								# This module is part of SQLAlchemy and is released under
							 | 
						||
| 
								 | 
							
								# the MIT License: https://www.opensource.org/licenses/mit-license.php
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								import collections
							 | 
						||
| 
								 | 
							
								import contextlib
							 | 
						||
| 
								 | 
							
								import re
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								from .. import event
							 | 
						||
| 
								 | 
							
								from .. import util
							 | 
						||
| 
								 | 
							
								from ..engine import url
							 | 
						||
| 
								 | 
							
								from ..engine.default import DefaultDialect
							 | 
						||
| 
								 | 
							
								from ..engine.util import _distill_cursor_params
							 | 
						||
| 
								 | 
							
								from ..schema import _DDLCompiles
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class AssertRule(object):
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    is_consumed = False
							 | 
						||
| 
								 | 
							
								    errormessage = None
							 | 
						||
| 
								 | 
							
								    consume_statement = True
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def process_statement(self, execute_observed):
							 | 
						||
| 
								 | 
							
								        pass
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def no_more_statements(self):
							 | 
						||
| 
								 | 
							
								        assert False, (
							 | 
						||
| 
								 | 
							
								            "All statements are complete, but pending "
							 | 
						||
| 
								 | 
							
								            "assertion rules remain"
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class SQLMatchRule(AssertRule):
							 | 
						||
| 
								 | 
							
								    pass
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class CursorSQL(SQLMatchRule):
							 | 
						||
| 
								 | 
							
								    def __init__(self, statement, params=None, consume_statement=True):
							 | 
						||
| 
								 | 
							
								        self.statement = statement
							 | 
						||
| 
								 | 
							
								        self.params = params
							 | 
						||
| 
								 | 
							
								        self.consume_statement = consume_statement
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def process_statement(self, execute_observed):
							 | 
						||
| 
								 | 
							
								        stmt = execute_observed.statements[0]
							 | 
						||
| 
								 | 
							
								        if self.statement != stmt.statement or (
							 | 
						||
| 
								 | 
							
								            self.params is not None and self.params != stmt.parameters
							 | 
						||
| 
								 | 
							
								        ):
							 | 
						||
| 
								 | 
							
								            self.errormessage = (
							 | 
						||
| 
								 | 
							
								                "Testing for exact SQL %s parameters %s received %s %s"
							 | 
						||
| 
								 | 
							
								                % (
							 | 
						||
| 
								 | 
							
								                    self.statement,
							 | 
						||
| 
								 | 
							
								                    self.params,
							 | 
						||
| 
								 | 
							
								                    stmt.statement,
							 | 
						||
| 
								 | 
							
								                    stmt.parameters,
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            execute_observed.statements.pop(0)
							 | 
						||
| 
								 | 
							
								            self.is_consumed = True
							 | 
						||
| 
								 | 
							
								            if not execute_observed.statements:
							 | 
						||
| 
								 | 
							
								                self.consume_statement = True
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class CompiledSQL(SQLMatchRule):
							 | 
						||
| 
								 | 
							
								    def __init__(self, statement, params=None, dialect="default"):
							 | 
						||
| 
								 | 
							
								        self.statement = statement
							 | 
						||
| 
								 | 
							
								        self.params = params
							 | 
						||
| 
								 | 
							
								        self.dialect = dialect
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _compare_sql(self, execute_observed, received_statement):
							 | 
						||
| 
								 | 
							
								        stmt = re.sub(r"[\n\t]", "", self.statement)
							 | 
						||
| 
								 | 
							
								        return received_statement == stmt
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _compile_dialect(self, execute_observed):
							 | 
						||
| 
								 | 
							
								        if self.dialect == "default":
							 | 
						||
| 
								 | 
							
								            dialect = DefaultDialect()
							 | 
						||
| 
								 | 
							
								            # this is currently what tests are expecting
							 | 
						||
| 
								 | 
							
								            # dialect.supports_default_values = True
							 | 
						||
| 
								 | 
							
								            dialect.supports_default_metavalue = True
							 | 
						||
| 
								 | 
							
								            return dialect
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            # ugh
							 | 
						||
| 
								 | 
							
								            if self.dialect == "postgresql":
							 | 
						||
| 
								 | 
							
								                params = {"implicit_returning": True}
							 | 
						||
| 
								 | 
							
								            else:
							 | 
						||
| 
								 | 
							
								                params = {}
							 | 
						||
| 
								 | 
							
								            return url.URL.create(self.dialect).get_dialect()(**params)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _received_statement(self, execute_observed):
							 | 
						||
| 
								 | 
							
								        """reconstruct the statement and params in terms
							 | 
						||
| 
								 | 
							
								        of a target dialect, which for CompiledSQL is just DefaultDialect."""
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        context = execute_observed.context
							 | 
						||
| 
								 | 
							
								        compare_dialect = self._compile_dialect(execute_observed)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        # received_statement runs a full compile().  we should not need to
							 | 
						||
| 
								 | 
							
								        # consider extracted_parameters; if we do this indicates some state
							 | 
						||
| 
								 | 
							
								        # is being sent from a previous cached query, which some misbehaviors
							 | 
						||
| 
								 | 
							
								        # in the ORM can cause, see #6881
							 | 
						||
| 
								 | 
							
								        cache_key = None  # execute_observed.context.compiled.cache_key
							 | 
						||
| 
								 | 
							
								        extracted_parameters = (
							 | 
						||
| 
								 | 
							
								            None  # execute_observed.context.extracted_parameters
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if "schema_translate_map" in context.execution_options:
							 | 
						||
| 
								 | 
							
								            map_ = context.execution_options["schema_translate_map"]
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            map_ = None
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if isinstance(execute_observed.clauseelement, _DDLCompiles):
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								            compiled = execute_observed.clauseelement.compile(
							 | 
						||
| 
								 | 
							
								                dialect=compare_dialect,
							 | 
						||
| 
								 | 
							
								                schema_translate_map=map_,
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            compiled = execute_observed.clauseelement.compile(
							 | 
						||
| 
								 | 
							
								                cache_key=cache_key,
							 | 
						||
| 
								 | 
							
								                dialect=compare_dialect,
							 | 
						||
| 
								 | 
							
								                column_keys=context.compiled.column_keys,
							 | 
						||
| 
								 | 
							
								                for_executemany=context.compiled.for_executemany,
							 | 
						||
| 
								 | 
							
								                schema_translate_map=map_,
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								        _received_statement = re.sub(r"[\n\t]", "", util.text_type(compiled))
							 | 
						||
| 
								 | 
							
								        parameters = execute_observed.parameters
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if not parameters:
							 | 
						||
| 
								 | 
							
								            _received_parameters = [
							 | 
						||
| 
								 | 
							
								                compiled.construct_params(
							 | 
						||
| 
								 | 
							
								                    extracted_parameters=extracted_parameters
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								            ]
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            _received_parameters = [
							 | 
						||
| 
								 | 
							
								                compiled.construct_params(
							 | 
						||
| 
								 | 
							
								                    m, extracted_parameters=extracted_parameters
							 | 
						||
| 
								 | 
							
								                )
							 | 
						||
| 
								 | 
							
								                for m in parameters
							 | 
						||
| 
								 | 
							
								            ]
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return _received_statement, _received_parameters
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def process_statement(self, execute_observed):
							 | 
						||
| 
								 | 
							
								        context = execute_observed.context
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        _received_statement, _received_parameters = self._received_statement(
							 | 
						||
| 
								 | 
							
								            execute_observed
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        params = self._all_params(context)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        equivalent = self._compare_sql(execute_observed, _received_statement)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if equivalent:
							 | 
						||
| 
								 | 
							
								            if params is not None:
							 | 
						||
| 
								 | 
							
								                all_params = list(params)
							 | 
						||
| 
								 | 
							
								                all_received = list(_received_parameters)
							 | 
						||
| 
								 | 
							
								                while all_params and all_received:
							 | 
						||
| 
								 | 
							
								                    param = dict(all_params.pop(0))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								                    for idx, received in enumerate(list(all_received)):
							 | 
						||
| 
								 | 
							
								                        # do a positive compare only
							 | 
						||
| 
								 | 
							
								                        for param_key in param:
							 | 
						||
| 
								 | 
							
								                            # a key in param did not match current
							 | 
						||
| 
								 | 
							
								                            # 'received'
							 | 
						||
| 
								 | 
							
								                            if (
							 | 
						||
| 
								 | 
							
								                                param_key not in received
							 | 
						||
| 
								 | 
							
								                                or received[param_key] != param[param_key]
							 | 
						||
| 
								 | 
							
								                            ):
							 | 
						||
| 
								 | 
							
								                                break
							 | 
						||
| 
								 | 
							
								                        else:
							 | 
						||
| 
								 | 
							
								                            # all keys in param matched 'received';
							 | 
						||
| 
								 | 
							
								                            # onto next param
							 | 
						||
| 
								 | 
							
								                            del all_received[idx]
							 | 
						||
| 
								 | 
							
								                            break
							 | 
						||
| 
								 | 
							
								                    else:
							 | 
						||
| 
								 | 
							
								                        # param did not match any entry
							 | 
						||
| 
								 | 
							
								                        # in all_received
							 | 
						||
| 
								 | 
							
								                        equivalent = False
							 | 
						||
| 
								 | 
							
								                        break
							 | 
						||
| 
								 | 
							
								                if all_params or all_received:
							 | 
						||
| 
								 | 
							
								                    equivalent = False
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if equivalent:
							 | 
						||
| 
								 | 
							
								            self.is_consumed = True
							 | 
						||
| 
								 | 
							
								            self.errormessage = None
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            self.errormessage = self._failure_message(params) % {
							 | 
						||
| 
								 | 
							
								                "received_statement": _received_statement,
							 | 
						||
| 
								 | 
							
								                "received_parameters": _received_parameters,
							 | 
						||
| 
								 | 
							
								            }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _all_params(self, context):
							 | 
						||
| 
								 | 
							
								        if self.params:
							 | 
						||
| 
								 | 
							
								            if callable(self.params):
							 | 
						||
| 
								 | 
							
								                params = self.params(context)
							 | 
						||
| 
								 | 
							
								            else:
							 | 
						||
| 
								 | 
							
								                params = self.params
							 | 
						||
| 
								 | 
							
								            if not isinstance(params, list):
							 | 
						||
| 
								 | 
							
								                params = [params]
							 | 
						||
| 
								 | 
							
								            return params
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            return None
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _failure_message(self, expected_params):
							 | 
						||
| 
								 | 
							
								        return (
							 | 
						||
| 
								 | 
							
								            "Testing for compiled statement\n%r partial params %s, "
							 | 
						||
| 
								 | 
							
								            "received\n%%(received_statement)r with params "
							 | 
						||
| 
								 | 
							
								            "%%(received_parameters)r"
							 | 
						||
| 
								 | 
							
								            % (
							 | 
						||
| 
								 | 
							
								                self.statement.replace("%", "%%"),
							 | 
						||
| 
								 | 
							
								                repr(expected_params).replace("%", "%%"),
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class RegexSQL(CompiledSQL):
							 | 
						||
| 
								 | 
							
								    def __init__(self, regex, params=None, dialect="default"):
							 | 
						||
| 
								 | 
							
								        SQLMatchRule.__init__(self)
							 | 
						||
| 
								 | 
							
								        self.regex = re.compile(regex)
							 | 
						||
| 
								 | 
							
								        self.orig_regex = regex
							 | 
						||
| 
								 | 
							
								        self.params = params
							 | 
						||
| 
								 | 
							
								        self.dialect = dialect
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _failure_message(self, expected_params):
							 | 
						||
| 
								 | 
							
								        return (
							 | 
						||
| 
								 | 
							
								            "Testing for compiled statement ~%r partial params %s, "
							 | 
						||
| 
								 | 
							
								            "received %%(received_statement)r with params "
							 | 
						||
| 
								 | 
							
								            "%%(received_parameters)r"
							 | 
						||
| 
								 | 
							
								            % (
							 | 
						||
| 
								 | 
							
								                self.orig_regex.replace("%", "%%"),
							 | 
						||
| 
								 | 
							
								                repr(expected_params).replace("%", "%%"),
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _compare_sql(self, execute_observed, received_statement):
							 | 
						||
| 
								 | 
							
								        return bool(self.regex.match(received_statement))
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class DialectSQL(CompiledSQL):
							 | 
						||
| 
								 | 
							
								    def _compile_dialect(self, execute_observed):
							 | 
						||
| 
								 | 
							
								        return execute_observed.context.dialect
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _compare_no_space(self, real_stmt, received_stmt):
							 | 
						||
| 
								 | 
							
								        stmt = re.sub(r"[\n\t]", "", real_stmt)
							 | 
						||
| 
								 | 
							
								        return received_stmt == stmt
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _received_statement(self, execute_observed):
							 | 
						||
| 
								 | 
							
								        received_stmt, received_params = super(
							 | 
						||
| 
								 | 
							
								            DialectSQL, self
							 | 
						||
| 
								 | 
							
								        )._received_statement(execute_observed)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        # TODO: why do we need this part?
							 | 
						||
| 
								 | 
							
								        for real_stmt in execute_observed.statements:
							 | 
						||
| 
								 | 
							
								            if self._compare_no_space(real_stmt.statement, received_stmt):
							 | 
						||
| 
								 | 
							
								                break
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            raise AssertionError(
							 | 
						||
| 
								 | 
							
								                "Can't locate compiled statement %r in list of "
							 | 
						||
| 
								 | 
							
								                "statements actually invoked" % received_stmt
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return received_stmt, execute_observed.context.compiled_parameters
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _compare_sql(self, execute_observed, received_statement):
							 | 
						||
| 
								 | 
							
								        stmt = re.sub(r"[\n\t]", "", self.statement)
							 | 
						||
| 
								 | 
							
								        # convert our comparison statement to have the
							 | 
						||
| 
								 | 
							
								        # paramstyle of the received
							 | 
						||
| 
								 | 
							
								        paramstyle = execute_observed.context.dialect.paramstyle
							 | 
						||
| 
								 | 
							
								        if paramstyle == "pyformat":
							 | 
						||
| 
								 | 
							
								            stmt = re.sub(r":([\w_]+)", r"%(\1)s", stmt)
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            # positional params
							 | 
						||
| 
								 | 
							
								            repl = None
							 | 
						||
| 
								 | 
							
								            if paramstyle == "qmark":
							 | 
						||
| 
								 | 
							
								                repl = "?"
							 | 
						||
| 
								 | 
							
								            elif paramstyle == "format":
							 | 
						||
| 
								 | 
							
								                repl = r"%s"
							 | 
						||
| 
								 | 
							
								            elif paramstyle == "numeric":
							 | 
						||
| 
								 | 
							
								                repl = None
							 | 
						||
| 
								 | 
							
								            stmt = re.sub(r":([\w_]+)", repl, stmt)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        return received_statement == stmt
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class CountStatements(AssertRule):
							 | 
						||
| 
								 | 
							
								    def __init__(self, count):
							 | 
						||
| 
								 | 
							
								        self.count = count
							 | 
						||
| 
								 | 
							
								        self._statement_count = 0
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def process_statement(self, execute_observed):
							 | 
						||
| 
								 | 
							
								        self._statement_count += 1
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def no_more_statements(self):
							 | 
						||
| 
								 | 
							
								        if self.count != self._statement_count:
							 | 
						||
| 
								 | 
							
								            assert False, "desired statement count %d does not match %d" % (
							 | 
						||
| 
								 | 
							
								                self.count,
							 | 
						||
| 
								 | 
							
								                self._statement_count,
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class AllOf(AssertRule):
							 | 
						||
| 
								 | 
							
								    def __init__(self, *rules):
							 | 
						||
| 
								 | 
							
								        self.rules = set(rules)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def process_statement(self, execute_observed):
							 | 
						||
| 
								 | 
							
								        for rule in list(self.rules):
							 | 
						||
| 
								 | 
							
								            rule.errormessage = None
							 | 
						||
| 
								 | 
							
								            rule.process_statement(execute_observed)
							 | 
						||
| 
								 | 
							
								            if rule.is_consumed:
							 | 
						||
| 
								 | 
							
								                self.rules.discard(rule)
							 | 
						||
| 
								 | 
							
								                if not self.rules:
							 | 
						||
| 
								 | 
							
								                    self.is_consumed = True
							 | 
						||
| 
								 | 
							
								                break
							 | 
						||
| 
								 | 
							
								            elif not rule.errormessage:
							 | 
						||
| 
								 | 
							
								                # rule is not done yet
							 | 
						||
| 
								 | 
							
								                self.errormessage = None
							 | 
						||
| 
								 | 
							
								                break
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            self.errormessage = list(self.rules)[0].errormessage
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class EachOf(AssertRule):
							 | 
						||
| 
								 | 
							
								    def __init__(self, *rules):
							 | 
						||
| 
								 | 
							
								        self.rules = list(rules)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def process_statement(self, execute_observed):
							 | 
						||
| 
								 | 
							
								        while self.rules:
							 | 
						||
| 
								 | 
							
								            rule = self.rules[0]
							 | 
						||
| 
								 | 
							
								            rule.process_statement(execute_observed)
							 | 
						||
| 
								 | 
							
								            if rule.is_consumed:
							 | 
						||
| 
								 | 
							
								                self.rules.pop(0)
							 | 
						||
| 
								 | 
							
								            elif rule.errormessage:
							 | 
						||
| 
								 | 
							
								                self.errormessage = rule.errormessage
							 | 
						||
| 
								 | 
							
								            if rule.consume_statement:
							 | 
						||
| 
								 | 
							
								                break
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        if not self.rules:
							 | 
						||
| 
								 | 
							
								            self.is_consumed = True
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def no_more_statements(self):
							 | 
						||
| 
								 | 
							
								        if self.rules and not self.rules[0].is_consumed:
							 | 
						||
| 
								 | 
							
								            self.rules[0].no_more_statements()
							 | 
						||
| 
								 | 
							
								        elif self.rules:
							 | 
						||
| 
								 | 
							
								            super(EachOf, self).no_more_statements()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class Conditional(EachOf):
							 | 
						||
| 
								 | 
							
								    def __init__(self, condition, rules, else_rules):
							 | 
						||
| 
								 | 
							
								        if condition:
							 | 
						||
| 
								 | 
							
								            super(Conditional, self).__init__(*rules)
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            super(Conditional, self).__init__(*else_rules)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class Or(AllOf):
							 | 
						||
| 
								 | 
							
								    def process_statement(self, execute_observed):
							 | 
						||
| 
								 | 
							
								        for rule in self.rules:
							 | 
						||
| 
								 | 
							
								            rule.process_statement(execute_observed)
							 | 
						||
| 
								 | 
							
								            if rule.is_consumed:
							 | 
						||
| 
								 | 
							
								                self.is_consumed = True
							 | 
						||
| 
								 | 
							
								                break
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            self.errormessage = list(self.rules)[0].errormessage
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class SQLExecuteObserved(object):
							 | 
						||
| 
								 | 
							
								    def __init__(self, context, clauseelement, multiparams, params):
							 | 
						||
| 
								 | 
							
								        self.context = context
							 | 
						||
| 
								 | 
							
								        self.clauseelement = clauseelement
							 | 
						||
| 
								 | 
							
								        self.parameters = _distill_cursor_params(
							 | 
						||
| 
								 | 
							
								            context.connection, tuple(multiparams), params
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								        self.statements = []
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def __repr__(self):
							 | 
						||
| 
								 | 
							
								        return str(self.statements)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class SQLCursorExecuteObserved(
							 | 
						||
| 
								 | 
							
								    collections.namedtuple(
							 | 
						||
| 
								 | 
							
								        "SQLCursorExecuteObserved",
							 | 
						||
| 
								 | 
							
								        ["statement", "parameters", "context", "executemany"],
							 | 
						||
| 
								 | 
							
								    )
							 | 
						||
| 
								 | 
							
								):
							 | 
						||
| 
								 | 
							
								    pass
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								class SQLAsserter(object):
							 | 
						||
| 
								 | 
							
								    def __init__(self):
							 | 
						||
| 
								 | 
							
								        self.accumulated = []
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def _close(self):
							 | 
						||
| 
								 | 
							
								        self._final = self.accumulated
							 | 
						||
| 
								 | 
							
								        del self.accumulated
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    def assert_(self, *rules):
							 | 
						||
| 
								 | 
							
								        rule = EachOf(*rules)
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        observed = list(self._final)
							 | 
						||
| 
								 | 
							
								        while observed:
							 | 
						||
| 
								 | 
							
								            statement = observed.pop(0)
							 | 
						||
| 
								 | 
							
								            rule.process_statement(statement)
							 | 
						||
| 
								 | 
							
								            if rule.is_consumed:
							 | 
						||
| 
								 | 
							
								                break
							 | 
						||
| 
								 | 
							
								            elif rule.errormessage:
							 | 
						||
| 
								 | 
							
								                assert False, rule.errormessage
							 | 
						||
| 
								 | 
							
								        if observed:
							 | 
						||
| 
								 | 
							
								            assert False, "Additional SQL statements remain:\n%s" % observed
							 | 
						||
| 
								 | 
							
								        elif not rule.is_consumed:
							 | 
						||
| 
								 | 
							
								            rule.no_more_statements()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								@contextlib.contextmanager
							 | 
						||
| 
								 | 
							
								def assert_engine(engine):
							 | 
						||
| 
								 | 
							
								    asserter = SQLAsserter()
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    orig = []
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    @event.listens_for(engine, "before_execute")
							 | 
						||
| 
								 | 
							
								    def connection_execute(
							 | 
						||
| 
								 | 
							
								        conn, clauseelement, multiparams, params, execution_options
							 | 
						||
| 
								 | 
							
								    ):
							 | 
						||
| 
								 | 
							
								        # grab the original statement + params before any cursor
							 | 
						||
| 
								 | 
							
								        # execution
							 | 
						||
| 
								 | 
							
								        orig[:] = clauseelement, multiparams, params
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    @event.listens_for(engine, "after_cursor_execute")
							 | 
						||
| 
								 | 
							
								    def cursor_execute(
							 | 
						||
| 
								 | 
							
								        conn, cursor, statement, parameters, context, executemany
							 | 
						||
| 
								 | 
							
								    ):
							 | 
						||
| 
								 | 
							
								        if not context:
							 | 
						||
| 
								 | 
							
								            return
							 | 
						||
| 
								 | 
							
								        # then grab real cursor statements and associate them all
							 | 
						||
| 
								 | 
							
								        # around a single context
							 | 
						||
| 
								 | 
							
								        if (
							 | 
						||
| 
								 | 
							
								            asserter.accumulated
							 | 
						||
| 
								 | 
							
								            and asserter.accumulated[-1].context is context
							 | 
						||
| 
								 | 
							
								        ):
							 | 
						||
| 
								 | 
							
								            obs = asserter.accumulated[-1]
							 | 
						||
| 
								 | 
							
								        else:
							 | 
						||
| 
								 | 
							
								            obs = SQLExecuteObserved(context, orig[0], orig[1], orig[2])
							 | 
						||
| 
								 | 
							
								            asserter.accumulated.append(obs)
							 | 
						||
| 
								 | 
							
								        obs.statements.append(
							 | 
						||
| 
								 | 
							
								            SQLCursorExecuteObserved(
							 | 
						||
| 
								 | 
							
								                statement, parameters, context, executemany
							 | 
						||
| 
								 | 
							
								            )
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								    try:
							 | 
						||
| 
								 | 
							
								        yield asserter
							 | 
						||
| 
								 | 
							
								    finally:
							 | 
						||
| 
								 | 
							
								        event.remove(engine, "after_cursor_execute", cursor_execute)
							 | 
						||
| 
								 | 
							
								        event.remove(engine, "before_execute", connection_execute)
							 | 
						||
| 
								 | 
							
								        asserter._close()
							 |