You cannot select more than 25 topics.
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							242 lines
						
					
					
						
							6.7 KiB
						
					
					
				
			
		
		
	
	
							242 lines
						
					
					
						
							6.7 KiB
						
					
					
				# orm/evaluator.py
 | 
						|
# Copyright (C) 2005-2022 the SQLAlchemy authors and contributors
 | 
						|
# <see AUTHORS file>
 | 
						|
#
 | 
						|
# This module is part of SQLAlchemy and is released under
 | 
						|
# the MIT License: https://www.opensource.org/licenses/mit-license.php
 | 
						|
 | 
						|
import operator
 | 
						|
 | 
						|
from .. import inspect
 | 
						|
from .. import util
 | 
						|
from ..sql import and_
 | 
						|
from ..sql import operators
 | 
						|
 | 
						|
 | 
						|
class UnevaluatableError(Exception):
    """Signal that a SQL construct has no Python-side evaluation.

    Raised by the evaluator compiler when it meets an expression, operator
    or column it cannot translate into a callable over ORM objects.
    """
 | 
						|
 | 
						|
 | 
						|
class _NoObject(operators.ColumnOperators):
    """Stand-in column value used when there is no object to read from.

    Both ``operate()`` and ``reverse_operate()`` return ``None``, so operator
    expressions built on an instance (presumably dispatched through
    ``ColumnOperators``' operator overloads -- confirm) evaluate to ``None``,
    i.e. SQL NULL.
    """

    def operate(self, *arg, **kw):
        # Any forward operation against "no object" is NULL.
        return None

    # Reflected operations (``value <op> self``) behave identically.
    reverse_operate = operate
 | 
						|
 | 
						|
 | 
						|
# Sentinel produced by column evaluators when the object being evaluated is
# ``None`` (see ``EvaluatorCompiler.visit_column``).
_NO_OBJECT = _NoObject()

# Binary operators applied directly to the two evaluated Python operands,
# i.e. ``op(left_value, right_value)``.
_straight_ops = {
    getattr(operators, name)
    for name in (
        "add",
        "mul",
        "sub",
        "div",
        "mod",
        "truediv",
        "lt",
        "le",
        "ne",
        "gt",
        "ge",
        "eq",
    )
}

# Containment operators, which need a Python-level ``in`` / ``not in``.
# A left operand of ``_NO_OBJECT`` evaluates to ``None`` (NULL).
_extended_ops = {
    operators.in_op: (lambda a, b: None if a is _NO_OBJECT else a in b),
    operators.not_in_op: (
        lambda a, b: None if a is _NO_OBJECT else a not in b
    ),
}

# Operators known to lack a Python-side evaluation.
# NOTE(review): not referenced elsewhere in this module; possibly imported by
# other modules -- confirm before removing.
_notimplemented_ops = {
    getattr(operators, name)
    for name in (
        "like_op",
        "not_like_op",
        "ilike_op",
        "not_ilike_op",
        "startswith_op",
        "between_op",
        "endswith_op",
        "concat_op",
    )
}
 | 
						|
 | 
						|
 | 
						|
class EvaluatorCompiler(object):
    """Compile SQL expression constructs into plain Python callables.

    Each ``visit_<name>`` method handles the construct whose
    ``__visit_name__`` is ``<name>`` and returns a one-argument callable that
    evaluates the expression against a single object, e.g. an ORM instance
    affected by an ORM-evaluated UPDATE / DELETE.  SQL NULL is modelled as
    ``None`` and propagates through comparisons and AND / OR / NOT largely as
    in SQL.  Constructs with no Python equivalent raise
    :class:`.UnevaluatableError`.
    """

    def __init__(self, target_cls=None):
        """
        :param target_cls: optional mapped class whose instances will be
          evaluated.  Used to reject columns that belong to a mapper not
          related to this class, and to resolve non-mapped column
          expressions by key.
        """
        self.target_cls = target_cls

    def process(self, *clauses):
        """Return a callable that evaluates ``clauses`` against an object.

        Several clauses are combined with ``AND``.

        :raises UnevaluatableError: if no clause is given, or if a construct
          has no matching ``visit_*`` method.
        """
        if not clauses:
            # Nothing to compile; report it through the exception type that
            # callers of this compiler already handle.
            raise UnevaluatableError("Cannot evaluate an empty set of clauses")
        clause = and_(*clauses) if len(clauses) > 1 else clauses[0]

        meth = getattr(self, "visit_%s" % clause.__visit_name__, None)
        if not meth:
            raise UnevaluatableError(
                "Cannot evaluate %s" % type(clause).__name__
            )
        return meth(clause)

    def visit_grouping(self, clause):
        """Parentheses have no runtime effect; evaluate the inner element."""
        return self.process(clause.element)

    def visit_null(self, clause):
        """``NULL`` always evaluates to ``None``."""
        return lambda obj: None

    def visit_false(self, clause):
        """``false()`` always evaluates to ``False``."""
        return lambda obj: False

    def visit_true(self, clause):
        """``true()`` always evaluates to ``True``."""
        return lambda obj: True

    def visit_column(self, clause):
        """Evaluate a column by reading the matching attribute of the object.

        A column annotated with a ``parentmapper`` is resolved to that
        mapper's property key; when ``target_cls`` is set it must be
        ``parentmapper.class_`` or a subclass of it.  An unannotated column
        is accepted, with a deprecation warning, only when ``target_cls`` is
        set and the column's key names one of its column attributes.

        The returned callable yields ``_NO_OBJECT`` when given ``None``.

        :raises UnevaluatableError: for a column of an unrelated class, or an
          unannotated column that cannot be resolved against ``target_cls``.
        """
        if "parentmapper" in clause._annotations:
            parentmapper = clause._annotations["parentmapper"]
            if self.target_cls and not issubclass(
                self.target_cls, parentmapper.class_
            ):
                raise UnevaluatableError(
                    "Can't evaluate criteria against alternate class %s"
                    % parentmapper.class_
                )
            key = parentmapper._columntoproperty[clause].key
        else:
            key = clause.key
            if (
                self.target_cls
                and key in inspect(self.target_cls).column_attrs
            ):
                util.warn(
                    "Evaluating non-mapped column expression '%s' onto "
                    "ORM instances; this is a deprecated use case.  Please "
                    "make use of the actual mapped columns in ORM-evaluated "
                    "UPDATE / DELETE expressions." % clause
                )
            else:
                raise UnevaluatableError("Cannot evaluate column: %s" % clause)

        get_corresponding_attr = operator.attrgetter(key)
        return (
            lambda obj: get_corresponding_attr(obj)
            if obj is not None
            else _NO_OBJECT
        )

    def visit_tuple(self, clause):
        """Tuples are evaluated like comma-separated clause lists."""
        return self.visit_clauselist(clause)

    def visit_clauselist(self, clause):
        """Evaluate an OR, AND or comma-separated list of clauses.

        ``OR`` and ``AND`` follow SQL three-valued logic, returning ``None``
        when the result is undetermined because of NULL operands.  A comma
        list evaluates to a tuple of its elements, or ``None`` if any element
        is ``None`` / ``_NO_OBJECT``.

        :raises UnevaluatableError: for any other list operator.
        """
        evaluators = [self.process(sub) for sub in clause.clauses]
        if clause.operator is operators.or_:

            def evaluate(obj):
                has_null = False
                for sub_evaluate in evaluators:
                    value = sub_evaluate(obj)
                    if value:
                        return True
                    has_null = has_null or value is None
                if has_null:
                    return None
                return False

        elif clause.operator is operators.and_:

            def evaluate(obj):
                for sub_evaluate in evaluators:
                    value = sub_evaluate(obj)
                    if not value:
                        if value is None or value is _NO_OBJECT:
                            return None
                        return False
                return True

        elif clause.operator is operators.comma_op:

            def evaluate(obj):
                values = []
                for sub_evaluate in evaluators:
                    value = sub_evaluate(obj)
                    if value is None or value is _NO_OBJECT:
                        return None
                    values.append(value)
                return tuple(values)

        else:
            raise UnevaluatableError(
                "Cannot evaluate clauselist with operator %s" % clause.operator
            )

        return evaluate

    def visit_binary(self, clause):
        """Evaluate a binary expression such as ``col == value``.

        ``IS`` / ``IS NOT`` compare with Python ``==`` / ``!=`` so that
        ``None`` matches ``None``.  Other supported operators return ``None``
        when either operand is ``None``.

        :raises UnevaluatableError: for an unsupported operator.
        """
        eval_left = self.process(clause.left)
        eval_right = self.process(clause.right)
        # Named ``op`` so the stdlib ``operator`` module used by
        # visit_column is not shadowed here.
        op = clause.operator

        if op is operators.is_:

            def evaluate(obj):
                return eval_left(obj) == eval_right(obj)

        elif op is operators.is_not:

            def evaluate(obj):
                return eval_left(obj) != eval_right(obj)

        elif op in _extended_ops:
            extended = _extended_ops[op]

            def evaluate(obj):
                left_val = eval_left(obj)
                right_val = eval_right(obj)
                if left_val is None or right_val is None:
                    return None
                return extended(left_val, right_val)

        elif op in _straight_ops:

            def evaluate(obj):
                left_val = eval_left(obj)
                right_val = eval_right(obj)
                if left_val is None or right_val is None:
                    return None
                # Apply the operator to the operands already computed above
                # instead of evaluating each side a second time.
                return op(left_val, right_val)

        else:
            raise UnevaluatableError(
                "Cannot evaluate %s with operator %s"
                % (type(clause).__name__, clause.operator)
            )
        return evaluate

    def visit_unary(self, clause):
        """Evaluate ``NOT expr``; ``NOT NULL`` stays ``None``.

        :raises UnevaluatableError: for any other unary operator.
        """
        eval_inner = self.process(clause.element)
        if clause.operator is operators.inv:

            def evaluate(obj):
                value = eval_inner(obj)
                if value is None:
                    return None
                return not value

            return evaluate

        raise UnevaluatableError(
            "Cannot evaluate %s with operator %s"
            % (type(clause).__name__, clause.operator)
        )

    def visit_bindparam(self, clause):
        """Evaluate a bound parameter to its value.

        If the parameter has a ``callable``, it is invoked once here, when
        the expression is compiled, rather than once per evaluated object.
        """
        if clause.callable:
            val = clause.callable()
        else:
            val = clause.value
        return lambda obj: val
 |