You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							1511 lines
						
					
					
						
							52 KiB
						
					
					
				
			
		
		
	
	
							1511 lines
						
					
					
						
							52 KiB
						
					
					
"""Fortran/C symbolic expressions

A minimal symbolic engine that parses Fortran (and C, as found in .pyf
files) expressions into trees of ``Expr`` instances.

References:
- J3/21-007: Draft Fortran 202x. https://j3-fortran.org/doc/year/21/21-007.pdf
"""
 | 
						|
 | 
						|
# To analyze Fortran expressions (for instance, to solve dimension
# specifications), we implement a minimal symbolic engine for parsing
# expressions into a tree of expression instances. As a first step,
# we care only about arithmetic expressions involving integers and
# operations like addition (+), subtraction (-), multiplication (*),
# division (Fortran / truncates integer operands, similar to Python //
# except that Python // floors; Fortran // is concatenation), and
# exponentiation (**).  In addition, .pyf files may contain C
# expressions; support for those is implemented here as well.
#
# TODO: support logical constants (Op.BOOLEAN)
# TODO: support logical operators (.AND., ...)
# TODO: support defined operators (.MYOP., ...)
#
 | 
						|
# Names exported by ``from <this module> import *``.
__all__ = ["Expr"]
 | 
						|
 | 
						|
 | 
						|
import re
 | 
						|
import warnings
 | 
						|
from enum import Enum
 | 
						|
from math import gcd
 | 
						|
 | 
						|
 | 
						|
class Language(Enum):
    """Language selector for rendering and parsing.

    Passed as the ``language`` argument of ``Expr.tostring``,
    ``Expr.parse`` and the ``RelOp`` conversion helpers.
    """
    Python, Fortran, C = 0, 1, 2
 | 
						|
 | 
						|
 | 
						|
class Op(Enum):
    """Node kind of an expression tree; stored as ``Expr.op``.

    Besides identifying the kind, the values order expressions of
    different kinds in ``Expr.__lt__``.
    """
    # constants
    INTEGER = 10
    REAL = 12
    COMPLEX = 15
    STRING = 20
    # array constructor
    ARRAY = 30
    # names
    SYMBOL = 40
    # compound expressions
    TERNARY = 100
    APPLY = 200
    INDEXING = 210
    CONCAT = 220
    RELATIONAL = 300
    # normalized sums and products (dict payloads)
    TERMS = 1000
    FACTORS = 2000
    # rendered as C '&' / '*' prefixes
    REF = 3000
    DEREF = 3001
 | 
						|
 | 
						|
 | 
						|
class RelOp(Enum):
    """
    Used in Op.RELATIONAL expression to specify the function part.
    """
    EQ = 1
    NE = 2
    LT = 3
    LE = 4
    GT = 5
    GE = 6

    @classmethod
    def fromstring(cls, s, language=Language.C):
        """Return the relational operator spelled ``s`` in ``language``.

        For Fortran, both the dotted form (``.eq.``, ``.ne.``, ...;
        case-insensitive) and the symbolic form introduced in Fortran 90
        (``==``, ``/=``, ``<``, ``<=``, ``>``, ``>=``) are accepted.
        For any other language, the C spelling (``==``, ``!=``, ...) is
        expected.

        Raises KeyError for an unrecognized operator string.
        """
        if language is Language.Fortran:
            return {'.eq.': cls.EQ, '==': cls.EQ,
                    '.ne.': cls.NE, '/=': cls.NE,
                    '.lt.': cls.LT, '<': cls.LT,
                    '.le.': cls.LE, '<=': cls.LE,
                    '.gt.': cls.GT, '>': cls.GT,
                    '.ge.': cls.GE, '>=': cls.GE}[s.lower()]
        return {'==': cls.EQ, '!=': cls.NE, '<': cls.LT,
                '<=': cls.LE, '>': cls.GT, '>=': cls.GE}[s]

    def tostring(self, language=Language.C):
        """Return the operator spelling for ``language``.

        Fortran uses the dotted form (``.eq.``, ...); any other language
        gets the C spelling (``==``, ``!=``, ...).
        """
        if language is Language.Fortran:
            return {RelOp.EQ: '.eq.', RelOp.NE: '.ne.',
                    RelOp.LT: '.lt.', RelOp.LE: '.le.',
                    RelOp.GT: '.gt.', RelOp.GE: '.ge.'}[self]
        return {RelOp.EQ: '==', RelOp.NE: '!=',
                RelOp.LT: '<', RelOp.LE: '<=',
                RelOp.GT: '>', RelOp.GE: '>='}[self]
 | 
						|
 | 
						|
 | 
						|
class ArithOp(Enum):
    """Arithmetic operator used as the function part of Op.APPLY nodes.

    Members, in order: POS, NEG, ADD, SUB, MUL, DIV, POW (values 1..7).
    """
    POS, NEG, ADD, SUB, MUL, DIV, POW = range(1, 8)
 | 
						|
 | 
						|
 | 
						|
class OpError(Exception):
    """Exception type for expression-operation errors.

    NOTE(review): not raised in the visible part of this module;
    presumably raised by helpers defined further down — confirm.
    """
 | 
						|
 | 
						|
 | 
						|
class Precedence(Enum):
    """Binding-strength levels used by ``Expr.tostring``.

    A smaller value binds tighter. A sub-expression is parenthesized
    when its own level has a larger value than the surrounding context.
    """
    ATOM = 0
    POWER = 1
    UNARY = 2
    PRODUCT = 3
    SUM = 4
    LT = 6
    EQ = 7
    LAND = 11
    LOR = 12
    TERNARY = 13
    ASSIGN = 14
    TUPLE = 15
    # Context with no enclosing operator: larger than every level above,
    # so nothing rendered under it is parenthesized.
    NONE = 100
 | 
						|
 | 
						|
 | 
						|
# Python types accepted as integer literals and as numeric literals.
# Expr's reflected operators (__radd__, __rsub__, __rmul__) test
# against number_types.
integer_types = (int,)
number_types = integer_types + (float,)
 | 
						|
 | 
						|
 | 
						|
def _pairs_add(d, k, v):
 | 
						|
    # Internal utility method for updating terms and factors data.
 | 
						|
    c = d.get(k)
 | 
						|
    if c is None:
 | 
						|
        d[k] = v
 | 
						|
    else:
 | 
						|
        c = c + v
 | 
						|
        if c:
 | 
						|
            d[k] = c
 | 
						|
        else:
 | 
						|
            del d[k]
 | 
						|
 | 
						|
 | 
						|
class ExprWarning(UserWarning):
    """Warning category emitted through ``ewarn``."""
 | 
						|
 | 
						|
 | 
						|
def ewarn(message):
    """Issue ``message`` as an ExprWarning attributed to ewarn's caller."""
    # stacklevel=2 points the warning at the code that called ewarn.
    warnings.warn(message, category=ExprWarning, stacklevel=2)
 | 
						|
 | 
						|
 | 
						|
class Expr:
    """Represents a Fortran (or C) expression as an op-data pair.

    ``op`` is an ``Op`` member and ``data`` is an op-specific payload
    whose expected shape is checked in ``__init__``.

    Expr instances are hashable and sortable.
    """
 | 
						|
 | 
						|
    @staticmethod
    def parse(s, language=Language.C):
        """Parse the expression string ``s`` into an Expr.

        ``language`` selects the source syntax; it defaults to C.
        """
        parsed = fromstring(s, language=language)
        return parsed
 | 
						|
 | 
						|
    def __init__(self, op, data):
        """Create an expression node from an ``Op`` tag and its payload.

        The expected shape of ``data`` depends on ``op`` and is described
        next to each sanity check below. The checks are ``assert``
        statements, so they are skipped when Python runs with ``-O``.

        Raises
        ------
        NotImplementedError
            If ``op`` has no sanity check.
        """
        assert isinstance(op, Op)

        # sanity checks
        if op is Op.INTEGER:
            # data is a 2-tuple of numeric object and a kind value
            # (default is 4)
            assert isinstance(data, tuple) and len(data) == 2
            assert isinstance(data[0], int)
            assert isinstance(data[1], (int, str)), data
        elif op is Op.REAL:
            # data is a 2-tuple of numeric object and a kind value
            # (default is 4)
            assert isinstance(data, tuple) and len(data) == 2
            assert isinstance(data[0], float)
            assert isinstance(data[1], (int, str)), data
        elif op is Op.COMPLEX:
            # data is a 2-tuple of constant expressions
            assert isinstance(data, tuple) and len(data) == 2
        elif op is Op.STRING:
            # data is a 2-tuple of quoted string and a kind value
            # (default is 1)
            assert isinstance(data, tuple) and len(data) == 2
            # The string must start and end with the same delimiter
            # ('"', "'" or '@'). The length is checked first so that a
            # one-character string fails the assertion; slicing with
            # step len(s) - 1 == 0 would raise ValueError instead.
            assert (isinstance(data[0], str)
                    and len(data[0]) >= 2
                    and data[0][0] + data[0][-1] in ('""', "''", '@@'))
            assert isinstance(data[1], (int, str)), data
        elif op is Op.SYMBOL:
            # data is any hashable object
            assert hash(data) is not None
        elif op in (Op.ARRAY, Op.CONCAT):
            # data is a tuple of expressions
            assert isinstance(data, tuple)
            assert all(isinstance(item, Expr) for item in data), data
        elif op in (Op.TERMS, Op.FACTORS):
            # data is {<term|base>:<coeff|exponent>} where dict values
            # are nonzero Python integers
            assert isinstance(data, dict)
        elif op is Op.APPLY:
            # data is (<function>, <operands>, <kwoperands>) where
            # operands are Expr instances
            assert isinstance(data, tuple) and len(data) == 3
            # function is any hashable object
            assert hash(data[0]) is not None
            assert isinstance(data[1], tuple)
            assert isinstance(data[2], dict)
        elif op is Op.INDEXING:
            # data is (<object>, <indices>)
            assert isinstance(data, tuple) and len(data) == 2
            # function is any hashable object
            assert hash(data[0]) is not None
        elif op is Op.TERNARY:
            # data is (<cond>, <expr1>, <expr2>)
            assert isinstance(data, tuple) and len(data) == 3
        elif op in (Op.REF, Op.DEREF):
            # data is Expr instance
            assert isinstance(data, Expr)
        elif op is Op.RELATIONAL:
            # data is (<relop>, <left>, <right>)
            assert isinstance(data, tuple) and len(data) == 3
        else:
            raise NotImplementedError(
                f'unknown op or missing sanity check: {op}')

        self.op = op
        self.data = data
 | 
						|
 | 
						|
    def __eq__(self, other):
        # Equal only to another Expr with the identical op and equal data.
        if not isinstance(other, Expr):
            return False
        return self.op is other.op and self.data == other.data
 | 
						|
 | 
						|
    def __hash__(self):
        # dict payloads (TERMS/FACTORS data, APPLY keyword operands) are
        # unhashable, so they are hashed through their sorted item tuples.
        op = self.op
        if op in (Op.TERMS, Op.FACTORS):
            key = tuple(sorted(self.data.items()))
        elif op is Op.APPLY:
            head, kwoperands = self.data[:2], self.data[2]
            key = head + tuple(sorted(kwoperands.items()))
        else:
            key = self.data
        return hash((op, key))
 | 
						|
 | 
						|
    def __lt__(self, other):
        # Order by op value first, then by payload; dict payloads are
        # compared through their sorted item tuples.
        if not isinstance(other, Expr):
            return NotImplemented
        if self.op is not other.op:
            return self.op.value < other.op.value
        if self.op in (Op.TERMS, Op.FACTORS):
            mine = tuple(sorted(self.data.items()))
            theirs = tuple(sorted(other.data.items()))
            return mine < theirs
        if self.op is Op.APPLY:
            mine_head, theirs_head = self.data[:2], other.data[:2]
            if mine_head != theirs_head:
                return mine_head < theirs_head
            mine_kw = tuple(sorted(self.data[2].items()))
            theirs_kw = tuple(sorted(other.data[2].items()))
            return mine_kw < theirs_kw
        return self.data < other.data
 | 
						|
 | 
						|
    def __le__(self, other):
        # Equal, or strictly less.
        equal = self == other
        return equal if equal else self < other
 | 
						|
 | 
						|
    def __gt__(self, other):
        # Strictly greater is the negation of less-or-equal.
        at_most = self <= other
        return not at_most
 | 
						|
 | 
						|
    def __ge__(self, other):
        # Greater-or-equal is the negation of strictly less.
        less = self < other
        return not less
 | 
						|
 | 
						|
    def __repr__(self):
        # e.g. "Expr(Op.SYMBOL, 'x')"; uses the runtime class name.
        cls_name = type(self).__name__
        return '{}({}, {!r})'.format(cls_name, self.op, self.data)
 | 
						|
 | 
						|
    def __str__(self):
        # Delegates to tostring() with its defaults (Fortran syntax, no
        # enclosing precedence).
        render = self.tostring
        return render()
 | 
						|
 | 
						|
    def tostring(self, parent_precedence=Precedence.NONE,
                 language=Language.Fortran):
        """Return a string representation of Expr.

        Parameters
        ----------
        parent_precedence : Precedence
            Precedence of the context the result is embedded in. When this
            expression's own precedence value is larger, the result is
            enclosed in parentheses.
        language : Language
            Syntax used for the rendering; it is propagated to every
            sub-expression.

        Raises
        ------
        NotImplementedError
            For an op without a rendering, or for a TERNARY expression in
            a language other than C, Python or Fortran.
        """
        if self.op in (Op.INTEGER, Op.REAL):
            precedence = (Precedence.SUM if self.data[0] < 0
                          else Precedence.ATOM)
            # The kind suffix is omitted for kind 4.
            r = str(self.data[0]) + (f'_{self.data[1]}'
                                     if self.data[1] != 4 else '')
        elif self.op is Op.COMPLEX:
            r = ', '.join(item.tostring(Precedence.TUPLE, language=language)
                          for item in self.data)
            r = '(' + r + ')'
            precedence = Precedence.ATOM
        elif self.op is Op.SYMBOL:
            precedence = Precedence.ATOM
            r = str(self.data)
        elif self.op is Op.STRING:
            r = self.data[0]
            if self.data[1] != 1:
                # The kind may be an int or a str (see __init__). Format it
                # rather than concatenating, which fails for int kinds.
                r = f'{self.data[1]}_{r}'
            precedence = Precedence.ATOM
        elif self.op is Op.ARRAY:
            r = ', '.join(item.tostring(Precedence.TUPLE, language=language)
                          for item in self.data)
            r = '[' + r + ']'
            precedence = Precedence.ATOM
        elif self.op is Op.TERMS:
            terms = []
            for term, coeff in sorted(self.data.items()):
                if coeff < 0:
                    op = ' - '
                    coeff = -coeff
                else:
                    op = ' + '
                if coeff == 1:
                    term = term.tostring(Precedence.SUM, language=language)
                else:
                    if term == as_number(1):
                        term = str(coeff)
                    else:
                        term = f'{coeff} * ' + term.tostring(
                            Precedence.PRODUCT, language=language)
                if terms:
                    terms.append(op)
                elif op == ' - ':
                    # A negative leading term gets a bare unary minus.
                    terms.append('-')
                terms.append(term)
            r = ''.join(terms) or '0'
            precedence = Precedence.SUM if terms else Precedence.ATOM
        elif self.op is Op.FACTORS:
            factors = []
            # Bases with small negative exponents, rendered as a C
            # denominator.
            tail = []
            for base, exp in sorted(self.data.items()):
                op = ' * '
                if exp == 1:
                    factor = base.tostring(Precedence.PRODUCT,
                                           language=language)
                elif language is Language.C:
                    # C has no power operator: small positive powers are
                    # expanded into repeated products, small negative ones
                    # go to the denominator, and anything else uses pow().
                    if exp in range(2, 10):
                        factor = base.tostring(Precedence.PRODUCT,
                                               language=language)
                        factor = ' * '.join([factor] * exp)
                    elif exp in range(-10, 0):
                        factor = base.tostring(Precedence.PRODUCT,
                                               language=language)
                        tail += [factor] * -exp
                        continue
                    else:
                        factor = base.tostring(Precedence.TUPLE,
                                               language=language)
                        factor = f'pow({factor}, {exp})'
                else:
                    factor = base.tostring(Precedence.POWER,
                                           language=language) + f' ** {exp}'
                if factors:
                    factors.append(op)
                factors.append(factor)
            if tail:
                if not factors:
                    factors += ['1']
                factors += ['/', '(', ' * '.join(tail), ')']
            r = ''.join(factors) or '1'
            precedence = Precedence.PRODUCT if factors else Precedence.ATOM
        elif self.op is Op.APPLY:
            name, args, kwargs = self.data
            if name is ArithOp.DIV and language is Language.C:
                numer, denom = [arg.tostring(Precedence.PRODUCT,
                                             language=language)
                                for arg in args]
                r = f'{numer} / {denom}'
                precedence = Precedence.PRODUCT
            else:
                if isinstance(name, Expr):
                    # Render an expression callee in the target language
                    # as well; str() would always use Fortran.
                    name = name.tostring(language=language)
                args = [arg.tostring(Precedence.TUPLE, language=language)
                        for arg in args]
                args += [k + '=' + v.tostring(Precedence.NONE,
                                              language=language)
                         for k, v in kwargs.items()]
                r = f'{name}({", ".join(args)})'
                precedence = Precedence.ATOM
        elif self.op is Op.INDEXING:
            name = self.data[0]
            if isinstance(name, Expr):
                # Same reasoning as for APPLY callees above.
                name = name.tostring(language=language)
            args = [arg.tostring(Precedence.TUPLE, language=language)
                    for arg in self.data[1:]]
            r = f'{name}[{", ".join(args)}]'
            precedence = Precedence.ATOM
        elif self.op is Op.CONCAT:
            args = [arg.tostring(Precedence.PRODUCT, language=language)
                    for arg in self.data]
            r = " // ".join(args)
            precedence = Precedence.PRODUCT
        elif self.op is Op.TERNARY:
            cond, expr1, expr2 = [a.tostring(Precedence.TUPLE,
                                             language=language)
                                  for a in self.data]
            if language is Language.C:
                r = f'({cond}?{expr1}:{expr2})'
            elif language is Language.Python:
                r = f'({expr1} if {cond} else {expr2})'
            elif language is Language.Fortran:
                r = f'merge({expr1}, {expr2}, {cond})'
            else:
                raise NotImplementedError(
                    f'tostring for {self.op} and {language}')
            precedence = Precedence.ATOM
        elif self.op is Op.REF:
            r = '&' + self.data.tostring(Precedence.UNARY, language=language)
            precedence = Precedence.UNARY
        elif self.op is Op.DEREF:
            r = '*' + self.data.tostring(Precedence.UNARY, language=language)
            precedence = Precedence.UNARY
        elif self.op is Op.RELATIONAL:
            rop, left, right = self.data
            precedence = (Precedence.EQ if rop in (RelOp.EQ, RelOp.NE)
                          else Precedence.LT)
            left = left.tostring(precedence, language=language)
            right = right.tostring(precedence, language=language)
            rop = rop.tostring(language=language)
            r = f'{left} {rop} {right}'
        else:
            raise NotImplementedError(f'tostring for op {self.op}')
        if parent_precedence.value < precedence.value:
            # If parent precedence is higher than operand precedence,
            # operand will be enclosed in parenthesis.
            return '(' + r + ')'
        return r
 | 
						|
 | 
						|
    def __pos__(self):
        """Unary plus: the identity, returning the very same Expr."""
        return self
 | 
						|
 | 
						|
    def __neg__(self):
        """Unary minus, expressed as multiplication by the integer -1."""
        minus_one = -1
        return self * minus_one
 | 
						|
 | 
						|
    def __add__(self, other):
        """Return ``self + other``.

        Numeric and complex constants are folded, with mixed INTEGER/REAL
        and real/complex operands promoted first. Two TERMS are merged
        coefficient-wise. Anything else is added as TERMS.
        """
        other = as_expr(other)
        if not isinstance(other, Expr):
            return NotImplemented
        numeric = (Op.INTEGER, Op.REAL)
        lhs, rhs = self.op, other.op
        if lhs is rhs:
            if lhs in numeric:
                return as_number(
                    self.data[0] + other.data[0],
                    max(self.data[1], other.data[1]))
            if lhs is Op.COMPLEX:
                (r1, i1), (r2, i2) = self.data, other.data
                return as_complex(r1 + r2, i1 + i2)
            if lhs is Op.TERMS:
                total = Expr(lhs, dict(self.data))
                for term, coeff in other.data.items():
                    _pairs_add(total.data, term, coeff)
                return normalize(total)
        if lhs is Op.COMPLEX and rhs in numeric:
            return self + as_complex(other)
        if lhs in numeric and rhs is Op.COMPLEX:
            return as_complex(self) + other
        if lhs is Op.REAL and rhs is Op.INTEGER:
            return self + as_real(other, kind=self.data[1])
        if lhs is Op.INTEGER and rhs is Op.REAL:
            return as_real(self, kind=other.data[1]) + other
        return as_terms(self) + as_terms(other)
 | 
						|
 | 
						|
    def __radd__(self, other):
        """Handle ``number + expr`` for plain Python ints and floats."""
        if not isinstance(other, number_types):
            return NotImplemented
        return as_number(other) + self
 | 
						|
 | 
						|
    def __sub__(self, other):
        """Return ``self - other`` as ``self + (-other)``."""
        negated = -other
        return self + negated
 | 
						|
 | 
						|
    def __rsub__(self, other):
        """Handle ``number - expr`` for plain Python ints and floats."""
        if not isinstance(other, number_types):
            return NotImplemented
        return as_number(other) - self
 | 
						|
 | 
						|
    def __mul__(self, other):
        """Return ``self * other``.

        Numeric and complex constants are folded, with mixed INTEGER/REAL
        and real/complex operands promoted first. FACTORS merge their
        exponents. TERMS are distributed. Anything else is multiplied as
        FACTORS.
        """
        other = as_expr(other)
        if not isinstance(other, Expr):
            return NotImplemented
        numeric = (Op.INTEGER, Op.REAL)
        lhs, rhs = self.op, other.op
        if lhs is rhs:
            if lhs in numeric:
                return as_number(self.data[0] * other.data[0],
                                 max(self.data[1], other.data[1]))
            if lhs is Op.COMPLEX:
                (r1, i1), (r2, i2) = self.data, other.data
                return as_complex(r1 * r2 - i1 * i2, r1 * i2 + r2 * i1)
            if lhs is Op.FACTORS:
                product = Expr(lhs, dict(self.data))
                for base, exponent in other.data.items():
                    _pairs_add(product.data, base, exponent)
                return normalize(product)
            if lhs is Op.TERMS:
                product = Expr(lhs, {})
                for t1, c1 in self.data.items():
                    for t2, c2 in other.data.items():
                        _pairs_add(product.data, t1 * t2, c1 * c2)
                return normalize(product)
        if lhs is Op.COMPLEX and rhs in numeric:
            return self * as_complex(other)
        if rhs is Op.COMPLEX and lhs in numeric:
            return as_complex(self) * other
        if lhs is Op.REAL and rhs is Op.INTEGER:
            return self * as_real(other, kind=self.data[1])
        if lhs is Op.INTEGER and rhs is Op.REAL:
            return as_real(self, kind=other.data[1]) * other
        if lhs is Op.TERMS:
            return self * as_terms(other)
        if rhs is Op.TERMS:
            return as_terms(self) * other
        return as_factors(self) * as_factors(other)
 | 
						|
 | 
						|
    def __rmul__(self, other):
        """Handle ``number * expr`` for plain Python ints and floats."""
        if not isinstance(other, number_types):
            return NotImplemented
        return as_number(other) * self
 | 
						|
 | 
						|
    def __pow__(self, other):
        """Return ``self ** other``.

        Integer exponents are expanded symbolically:
        - 0 gives the number 1 and 1 gives ``self``.
        - Positive exponents scale the exponents of a FACTORS expression,
          or otherwise multiply ``self`` out.
        - -1 yields a FACTORS node ``{self: -1}``.
        - Other negative exponents become ``(self ** -exponent) ** -1``.

        Any non-integer exponent yields an ``ArithOp.POW`` application.
        """
        other = as_expr(other)
        if isinstance(other, Expr):
            if other.op is Op.INTEGER:
                exponent = other.data[0]
                # TODO: other kind not used
                if exponent == 0:
                    return as_number(1)
                if exponent == 1:
                    return self
                if exponent > 0:
                    if self.op is Op.FACTORS:
                        r = Expr(self.op, {})
                        for k, v in self.data.items():
                            r.data[k] = v * exponent
                        return normalize(r)
                    # Build self * (self * (... * self)) iteratively. This
                    # is the same product chain as recursing on
                    # ``self ** (exponent - 1)``, without one stack frame
                    # per unit of exponent, so large exponents cannot hit
                    # the recursion limit.
                    r = self
                    for _ in range(exponent - 1):
                        r = self * r
                    return r
                elif exponent != -1:
                    return (self ** (-exponent)) ** -1
                return Expr(Op.FACTORS, {self: exponent})
            return as_apply(ArithOp.POW, self, other)
        return NotImplemented
 | 
						|
 | 
						|
    def __truediv__(self, other):
        """Return ``self / other`` as a normalized ArithOp.DIV application.

        Fortran ``/`` differs from Python ``/``: it truncates for integer
        operands, so the quotient is kept symbolic.
        """
        other = as_expr(other)
        if not isinstance(other, Expr):
            return NotImplemented
        quotient = as_apply(ArithOp.DIV, self, other)
        return normalize(quotient)
 | 
						|
 | 
						|
    def __rtruediv__(self, other):
        """Handle ``value / expr`` by converting the left operand first."""
        numerator = as_expr(other)
        if not isinstance(numerator, Expr):
            return NotImplemented
        return numerator / self
 | 
						|
 | 
						|
    def __floordiv__(self, other):
        """Return ``self // other`` as a normalized CONCAT expression.

        Fortran ``//`` is string concatenation, not Python's floor
        division.
        """
        other = as_expr(other)
        if not isinstance(other, Expr):
            return NotImplemented
        joined = Expr(Op.CONCAT, (self, other))
        return normalize(joined)
 | 
						|
 | 
						|
    def __rfloordiv__(self, other):
        """Handle ``value // expr`` (concatenation) with value on the left."""
        head = as_expr(other)
        if not isinstance(head, Expr):
            return NotImplemented
        return head // self
 | 
						|
 | 
						|
    def __call__(self, *args, **kwargs):
        """Return an application of ``self`` to the converted arguments.

        Fortran uses parentheses for both function calls and indexing;
        this method always builds an application.

        TODO: implement a method for deciding when __call__ should
        return an INDEXING expression.
        """
        operands = [as_expr(arg) for arg in args]
        kwoperands = {key: as_expr(val) for key, val in kwargs.items()}
        return as_apply(self, *operands, **kwoperands)
 | 
						|
 | 
						|
    def __getitem__(self, index):
        """Return an INDEXING expression ``self[index]``.

        Provided to support the C indexing operations that .pyf files may
        contain. C indexing takes a single expression, so receiving more
        than one index triggers an ExprWarning.
        """
        index = as_expr(index)
        indices = index if isinstance(index, tuple) else (index,)
        if len(indices) > 1:
            ewarn(f'C-index should be a single expression but got `{indices}`')
        return Expr(Op.INDEXING, (self,) + indices)
 | 
						|
 | 
						|
    def substitute(self, symbols_map):
        """Recursively substitute symbols with values in symbols map.

        Symbols map is a dictionary of symbol-expression pairs. Symbols
        that are not in the map are left unchanged. Compound expressions
        are rebuilt from their substituted parts and, except for ARRAY and
        COMPLEX, normalized.

        Raises NotImplementedError for an op without substitution support.
        """
        if self.op is Op.SYMBOL:
            value = symbols_map.get(self)
            if value is None:
                return self
            # SYMBOL data may be any hashable (see __init__); only string
            # names can be the parenthesis placeholders matched here, and
            # re.match would raise TypeError on non-str data.
            m = (re.match(r'\A(@__f2py_PARENTHESIS_(\w+)_\d+@)\Z', self.data)
                 if isinstance(self.data, str) else None)
            if m:
                # complement to fromstring method
                paren = m.group(2)
                if paren in ['ROUNDDIV', 'SQUARE']:
                    return as_array(value)
                assert paren == 'ROUND', (paren, value)
            return value
        if self.op in (Op.INTEGER, Op.REAL, Op.STRING):
            return self
        if self.op in (Op.ARRAY, Op.COMPLEX):
            return Expr(self.op, tuple(item.substitute(symbols_map)
                                       for item in self.data))
        if self.op is Op.CONCAT:
            return normalize(Expr(self.op, tuple(item.substitute(symbols_map)
                                                 for item in self.data)))
        if self.op is Op.TERMS:
            r = None
            for term, coeff in self.data.items():
                if r is None:
                    r = term.substitute(symbols_map) * coeff
                else:
                    r += term.substitute(symbols_map) * coeff
            if r is None:
                ewarn('substitute: empty TERMS expression interpreted as'
                      ' int-literal 0')
                return as_number(0)
            return r
        if self.op is Op.FACTORS:
            r = None
            for base, exponent in self.data.items():
                if r is None:
                    r = base.substitute(symbols_map) ** exponent
                else:
                    r *= base.substitute(symbols_map) ** exponent
            if r is None:
                ewarn('substitute: empty FACTORS expression interpreted'
                      ' as int-literal 1')
                return as_number(1)
            return r
        if self.op is Op.APPLY:
            target, args, kwargs = self.data
            if isinstance(target, Expr):
                target = target.substitute(symbols_map)
            args = tuple(a.substitute(symbols_map) for a in args)
            kwargs = {k: v.substitute(symbols_map)
                      for k, v in kwargs.items()}
            return normalize(Expr(self.op, (target, args, kwargs)))
        if self.op is Op.INDEXING:
            func = self.data[0]
            if isinstance(func, Expr):
                func = func.substitute(symbols_map)
            args = tuple(a.substitute(symbols_map) for a in self.data[1:])
            return normalize(Expr(self.op, (func,) + args))
        if self.op is Op.TERNARY:
            operands = tuple(a.substitute(symbols_map) for a in self.data)
            return normalize(Expr(self.op, operands))
        if self.op in (Op.REF, Op.DEREF):
            return normalize(Expr(self.op, self.data.substitute(symbols_map)))
        if self.op is Op.RELATIONAL:
            rop, left, right = self.data
            left = left.substitute(symbols_map)
            right = right.substitute(symbols_map)
            return normalize(Expr(self.op, (rop, left, right)))
        raise NotImplementedError(f'substitute method for {self.op}: {self!r}')
 | 
						|
 | 
						|
    def traverse(self, visit, *args, **kwargs):
        """Rebuild the expression tree through a visitor callback.

        ``visit(expr, *args, **kwargs)`` is applied to ``self`` first.  A
        result other than None is returned directly, and the children of
        ``self`` are then not visited.  Otherwise every sub-expression is
        traversed with the same visitor and extra arguments, and a new
        normalized expression with the same op is built from the results.
        INTEGER, REAL, STRING and SYMBOL leaves are returned as ``self``.

        Raises NotImplementedError for an op that is not handled here.
        """
        replacement = visit(self, *args, **kwargs)
        if replacement is not None:
            return replacement

        def walk(node):
            # Recurse with the same visitor and extra arguments.
            return node.traverse(visit, *args, **kwargs)

        def walk_expr_only(node):
            # Some slots may hold non-Expr payloads (an ArithOp APPLY target,
            # a numeric TERMS coefficient / FACTORS exponent); keep them as-is.
            return walk(node) if isinstance(node, Expr) else node

        op = self.op
        if op in (Op.INTEGER, Op.REAL, Op.STRING, Op.SYMBOL):
            return self

        if op in (Op.COMPLEX, Op.ARRAY, Op.CONCAT, Op.TERNARY):
            children = tuple(walk(child) for child in self.data)
            return normalize(Expr(op, children))

        if op in (Op.TERMS, Op.FACTORS):
            merged = {}
            for key, value in self.data.items():
                new_key = walk(key)
                new_value = walk_expr_only(value)
                # Distinct keys may coincide after visiting; sum their values.
                if new_key in merged:
                    new_value = merged[new_key] + new_value
                merged[new_key] = new_value
            return normalize(Expr(op, merged))

        if op is Op.APPLY:
            target = walk_expr_only(self.data[0])
            positional = tuple(walk(arg) for arg in self.data[1])
            keywords = {name: walk(arg) for name, arg in self.data[2].items()}
            return normalize(Expr(op, (target, positional, keywords)))

        if op is Op.INDEXING:
            base = walk_expr_only(self.data[0])
            indices = tuple(walk(index) for index in self.data[1:])
            return normalize(Expr(op, (base,) + indices))

        if op in (Op.REF, Op.DEREF):
            return normalize(Expr(op, walk(self.data)))

        if op is Op.RELATIONAL:
            relop, lhs, rhs = self.data
            return normalize(Expr(op, (relop, walk(lhs), walk(rhs))))

        raise NotImplementedError(f'traverse method for {self.op}')
 | 
						|
 | 
						|
    def contains(self, other):
        """Return True when ``other`` equals self or one of its sub-expressions.
        """
        hits = []

        def visit(expr):
            if hits:
                # A match was already found: stop descending further.
                return expr
            if expr == other:
                hits.append(1)
                return expr
            return None

        self.traverse(visit)
        return bool(hits)
 | 
						|
 | 
						|
    def symbols(self):
        """Collect every SYMBOL sub-expression of self into a set.
        """
        collected = set()

        def visit(expr):
            # Returning None keeps the traversal going into children.
            if expr.op is Op.SYMBOL:
                collected.add(expr)

        self.traverse(visit)
        return collected
 | 
						|
 | 
						|
    def polynomial_atoms(self):
        """Collect the expressions that act as atoms of polynomial self.

        Sums (TERMS), complex constants and non-power arithmetic
        applications are looked through; for products (FACTORS) only the
        bases are examined, and for ``ArithOp.POW`` applications only the
        first operand.  Numeric constants are never atoms.  Any other
        expression is recorded; INDEXING and APPLY atoms are not searched
        further.
        """
        atoms = set()

        def visit(expr):
            op = expr.op
            if op is Op.FACTORS:
                # Exponents are not searched, only the bases.
                for base in expr.data:
                    base.traverse(visit)
                return expr
            if op in (Op.TERMS, Op.COMPLEX):
                return None
            if op is Op.APPLY and isinstance(expr.data[0], ArithOp):
                if expr.data[0] is not ArithOp.POW:
                    return None
                expr.data[1][0].traverse(visit)
                return expr
            if op in (Op.INTEGER, Op.REAL):
                return expr
            atoms.add(expr)
            return expr if op in (Op.INDEXING, Op.APPLY) else None

        self.traverse(visit)
        return atoms
 | 
						|
 | 
						|
    def linear_solve(self, symbol):
        """Split self as ``a * symbol + b`` and return the pair ``(a, b)``.

        ``b`` is self with ``symbol`` replaced by 0, and ``a`` is the
        remainder ``self - b`` with ``symbol`` replaced by 1.  The split is
        checked by requiring the numerator of ``a * symbol - (self - b)``
        to equal 0.  Otherwise self is not linear in ``symbol`` and
        RuntimeError is raised.
        """
        offset = self.substitute({symbol: as_number(0)})
        linear_part = self - offset
        slope = linear_part.substitute({symbol: as_number(1)})

        residual, _denominator = as_numer_denom(slope * symbol - linear_part)
        if residual != as_number(0):
            raise RuntimeError(f'not a {symbol}-linear equation:'
                               f' {slope} * {symbol} + {offset} == {self}')
        return slope, offset
 | 
						|
 | 
						|
 | 
						|
def normalize(obj):
    """Normalize Expr and apply basic evaluation methods.

    Non-Expr inputs are returned unchanged.  For an Expr, the top-level
    node is canonicalized according to ``obj.op`` (children are not
    re-normalized, except the operands of TERNARY):

    - TERMS: zero coefficients are dropped, a non-unit coefficient of a
      COMPLEX term is multiplied into the term, nested TERMS are
      flattened; an empty sum becomes the number 0 and a single term with
      coefficient 1 becomes that term.
    - FACTORS: zero exponents are dropped, a TERMS base with an integer
      exponent > 1 is rewritten as ``b * b**(e - 1)`` (to expand it),
      INTEGER/REAL bases with positive exponents are folded into a
      numeric coefficient, nested FACTORS with positive integer exponents
      are flattened.  A non-unit coefficient produces ``TERMS{product:
      coeff}``.
    - APPLY of ArithOp.DIV: integer coefficients are reduced by their gcd
      (other coefficients are divided into the numerator), a division
      nested in either operand is merged into one division, and common
      factors are cancelled.  When the remaining denominator is the
      number 1, the numerator alone is returned.
    - CONCAT: a STRING literal followed by a STRING literal that starts
      with the same quote character the first one ends with is merged
      into a single literal.
    - TERNARY: operands are normalized, and an INTEGER condition selects
      ``expr1`` (nonzero) or ``expr2`` (zero).

    Any other expression is returned unchanged.
    """
    if not isinstance(obj, Expr):
        return obj

    if obj.op is Op.TERMS:
        # Rebuild the term -> coefficient mapping from scratch.
        d = {}
        for t, c in obj.data.items():
            if c == 0:
                # A zero coefficient contributes nothing to the sum.
                continue
            if t.op is Op.COMPLEX and c != 1:
                # Fold the coefficient into the complex constant itself.
                t = t * c
                c = 1
            if t.op is Op.TERMS:
                # Flatten a nested sum, scaling its coefficients by c.
                # _pairs_add (defined elsewhere) presumably accumulates the
                # value into d[key]; verify against its definition.
                for t1, c1 in t.data.items():
                    _pairs_add(d, t1, c1 * c)
            else:
                _pairs_add(d, t, c)
        if len(d) == 0:
            # TODO: determine correct kind
            return as_number(0)
        elif len(d) == 1:
            (t, c), = d.items()
            if c == 1:
                # A lone term with unit coefficient is just the term.
                return t
        return Expr(Op.TERMS, d)

    if obj.op is Op.FACTORS:
        # Numeric factors are accumulated in coeff, the rest in d.
        coeff = 1
        d = {}
        for b, e in obj.data.items():
            if e == 0:
                # b**0 contributes nothing to the product.
                continue
            if b.op is Op.TERMS and isinstance(e, integer_types) and e > 1:
                # expand integer powers of sums
                b = b * (b ** (e - 1))
                e = 1

            if b.op in (Op.INTEGER, Op.REAL):
                if e == 1:
                    coeff *= b.data[0]
                elif e > 0:
                    coeff *= b.data[0] ** e
                else:
                    # Negative powers of numbers stay symbolic factors.
                    _pairs_add(d, b, e)
            elif b.op is Op.FACTORS:
                if e > 0 and isinstance(e, integer_types):
                    # (x**a * y**b)**e -> x**(a*e) * y**(b*e)
                    for b1, e1 in b.data.items():
                        _pairs_add(d, b1, e1 * e)
                else:
                    _pairs_add(d, b, e)
            else:
                _pairs_add(d, b, e)
        if len(d) == 0 or coeff == 0:
            # TODO: determine correct kind
            assert isinstance(coeff, number_types)
            return as_number(coeff)
        elif len(d) == 1:
            (b, e), = d.items()
            if e == 1:
                t = b
            else:
                t = Expr(Op.FACTORS, d)
            if coeff == 1:
                return t
            # A non-unit numeric coefficient is represented as a TERMS
            # coefficient of the symbolic product.
            return Expr(Op.TERMS, {t: coeff})
        elif coeff == 1:
            return Expr(Op.FACTORS, d)
        else:
            return Expr(Op.TERMS, {Expr(Op.FACTORS, d): coeff})

    if obj.op is Op.APPLY and obj.data[0] is ArithOp.DIV:
        dividend, divisor = obj.data[1]
        # Split off numeric coefficients: dividend = c1*t1, divisor = c2*t2.
        t1, c1 = as_term_coeff(dividend)
        t2, c2 = as_term_coeff(divisor)
        if isinstance(c1, integer_types) and isinstance(c2, integer_types):
            # Reduce the integer fraction c1/c2 to lowest terms.
            g = gcd(c1, c2)
            c1, c2 = c1//g, c2//g
        else:
            # Non-integer coefficients: move the quotient to the numerator.
            c1, c2 = c1/c2, 1

        if t1.op is Op.APPLY and t1.data[0] is ArithOp.DIV:
            # (a / b) / t2 -> a / (b * t2)
            numer = t1.data[1][0] * c1
            denom = t1.data[1][1] * t2 * c2
            return as_apply(ArithOp.DIV, numer, denom)

        if t2.op is Op.APPLY and t2.data[0] is ArithOp.DIV:
            # t1 / (a / b) -> (b * t1) / a
            numer = t2.data[1][1] * t1 * c1
            denom = t2.data[1][0] * c2
            return as_apply(ArithOp.DIV, numer, denom)

        # Cancel common factors by subtracting the divisor's exponents
        # from the dividend's.
        d = dict(as_factors(t1).data)
        for b, e in as_factors(t2).data.items():
            _pairs_add(d, b, -e)
        # Positive exponents go to the numerator, the rest to the
        # denominator; any exponent that cancelled to 0 is skipped when the
        # FACTORS expression is normalized below.
        numer, denom = {}, {}
        for b, e in d.items():
            if e > 0:
                numer[b] = e
            else:
                denom[b] = -e
        numer = normalize(Expr(Op.FACTORS, numer)) * c1
        denom = normalize(Expr(Op.FACTORS, denom)) * c2

        if denom.op in (Op.INTEGER, Op.REAL) and denom.data[0] == 1:
            # TODO: denom kind not used
            return numer
        return as_apply(ArithOp.DIV, numer, denom)

    if obj.op is Op.CONCAT:
        lst = [obj.data[0]]
        for s in obj.data[1:]:
            last = lst[-1]
            # Merge adjacent literals delimited by the same quote character:
            # drop the closing quote of `last` and the opening quote of `s`.
            if (
                    last.op is Op.STRING
                    and s.op is Op.STRING
                    and last.data[0][0] in '"\''
                    and s.data[0][0] == last.data[0][-1]
            ):
                new_last = as_string(last.data[0][:-1] + s.data[0][1:],
                                     max(last.data[1], s.data[1]))
                lst[-1] = new_last
            else:
                lst.append(s)
        if len(lst) == 1:
            return lst[0]
        return Expr(Op.CONCAT, tuple(lst))

    if obj.op is Op.TERNARY:
        cond, expr1, expr2 = map(normalize, obj.data)
        if cond.op is Op.INTEGER:
            # Constant condition: pick the branch statically.
            return expr1 if cond.data[0] else expr2
        return Expr(Op.TERNARY, (cond, expr1, expr2))

    return obj
 | 
						|
 | 
						|
 | 
						|
def as_expr(obj):
    """Coerce a Python value into an Expr where a conversion is known.

    Python complex numbers become COMPLEX expressions, other values of
    ``number_types`` become INTEGER/REAL constants, strings become STRING
    literals (quoted via ``repr``), and tuples are converted element-wise
    (a tuple is returned).  Anything else, Expr instances included, is
    returned unchanged.
    """
    if isinstance(obj, complex):
        result = as_complex(obj.real, obj.imag)
    elif isinstance(obj, number_types):
        result = as_number(obj)
    elif isinstance(obj, str):
        # STRING data keeps its boundary quotes, which repr() supplies.
        result = as_string(repr(obj))
    elif isinstance(obj, tuple):
        result = tuple(as_expr(item) for item in obj)
    else:
        result = obj
    return result
 | 
						|
 | 
						|
 | 
						|
def as_symbol(obj):
    """Wrap ``obj`` as a SYMBOL expression.

    Used both for variable names and for unparsed expression text.
    """
    symbol = Expr(Op.SYMBOL, obj)
    return symbol
 | 
						|
 | 
						|
 | 
						|
def as_number(obj, kind=4):
    """Build an INTEGER or REAL constant expression.

    A Python int becomes INTEGER and a float becomes REAL, both tagged
    with ``kind``.  An existing INTEGER/REAL Expr is returned unchanged
    (it keeps its own kind).  Anything else raises OpError.
    """
    if isinstance(obj, Expr):
        if obj.op in (Op.INTEGER, Op.REAL):
            return obj
    elif isinstance(obj, int):
        return Expr(Op.INTEGER, (obj, kind))
    elif isinstance(obj, float):
        return Expr(Op.REAL, (obj, kind))
    raise OpError(f'cannot convert {obj} to INTEGER or REAL constant')
 | 
						|
 | 
						|
 | 
						|
def as_integer(obj, kind=4):
    """Build an INTEGER constant expression.

    A Python int is tagged with ``kind``.  An INTEGER Expr is returned
    unchanged.  Anything else raises OpError.
    """
    if isinstance(obj, Expr):
        if obj.op is Op.INTEGER:
            return obj
    elif isinstance(obj, int):
        return Expr(Op.INTEGER, (obj, kind))
    raise OpError(f'cannot convert {obj} to INTEGER constant')
 | 
						|
 | 
						|
 | 
						|
def as_real(obj, kind=4):
    """Build a REAL constant expression.

    A Python int is converted with ``float()`` and a float is stored as
    given, both tagged with ``kind``.  A REAL Expr is returned unchanged.
    An INTEGER Expr is converted to REAL using the ``kind`` argument
    (not its own kind).  Anything else raises OpError.
    """
    if isinstance(obj, Expr):
        if obj.op is Op.REAL:
            return obj
        if obj.op is Op.INTEGER:
            return Expr(Op.REAL, (float(obj.data[0]), kind))
    elif isinstance(obj, int):
        return Expr(Op.REAL, (float(obj), kind))
    elif isinstance(obj, float):
        return Expr(Op.REAL, (obj, kind))
    raise OpError(f'cannot convert {obj} to REAL constant')
 | 
						|
 | 
						|
 | 
						|
def as_string(obj, kind=1):
    """Build a STRING literal expression with character kind ``kind``.

    ``obj`` is stored verbatim; callers in this module pass the literal
    text including its boundary quotes.
    """
    payload = (obj, kind)
    return Expr(Op.STRING, payload)
 | 
						|
 | 
						|
 | 
						|
def as_array(obj):
    """Build an ARRAY constant expression.

    A single Expr is wrapped in a one-element tuple; any other ``obj``
    (presumably a tuple of items) is stored as given.
    """
    items = (obj,) if isinstance(obj, Expr) else obj
    return Expr(Op.ARRAY, items)
 | 
						|
 | 
						|
 | 
						|
def as_complex(real, imag=0):
    """Build a COMPLEX literal from its real and imaginary parts.

    Both parts are passed through ``as_expr`` first.
    """
    parts = (as_expr(real), as_expr(imag))
    return Expr(Op.COMPLEX, parts)
 | 
						|
 | 
						|
 | 
						|
def as_apply(func, *args, **kwargs):
    """Build an APPLY expression (function call, constructor, etc.).

    ``func`` is stored as given.  Positional and keyword arguments are
    converted with ``as_expr``.  The data is the triple
    ``(func, args_tuple, kwargs_dict)``.
    """
    positional = tuple(as_expr(arg) for arg in args)
    keywords = {name: as_expr(value) for name, value in kwargs.items()}
    return Expr(Op.APPLY, (func, positional, keywords))
 | 
						|
 | 
						|
 | 
						|
def as_ternary(cond, expr1, expr2):
    """Build a TERNARY expression ``cond ? expr1 : expr2``.

    Operands are stored as given, without conversion.
    """
    operands = (cond, expr1, expr2)
    return Expr(Op.TERNARY, operands)
 | 
						|
 | 
						|
 | 
						|
def as_ref(expr):
    """Wrap ``expr`` in a REF (referencing) expression."""
    reference = Expr(Op.REF, expr)
    return reference
 | 
						|
 | 
						|
 | 
						|
def as_deref(expr):
    """Wrap ``expr`` in a DEREF (dereferencing) expression."""
    dereference = Expr(Op.DEREF, expr)
    return dereference
 | 
						|
 | 
						|
 | 
						|
def as_eq(left, right):
    """Build a RELATIONAL expression with operator ``RelOp.EQ``."""
    operands = (RelOp.EQ, left, right)
    return Expr(Op.RELATIONAL, operands)
 | 
						|
 | 
						|
 | 
						|
def as_ne(left, right):
    """Build a RELATIONAL expression with operator ``RelOp.NE``."""
    operands = (RelOp.NE, left, right)
    return Expr(Op.RELATIONAL, operands)
 | 
						|
 | 
						|
 | 
						|
def as_lt(left, right):
    """Build a RELATIONAL expression with operator ``RelOp.LT``."""
    operands = (RelOp.LT, left, right)
    return Expr(Op.RELATIONAL, operands)
 | 
						|
 | 
						|
 | 
						|
def as_le(left, right):
    """Build a RELATIONAL expression with operator ``RelOp.LE``."""
    operands = (RelOp.LE, left, right)
    return Expr(Op.RELATIONAL, operands)
 | 
						|
 | 
						|
 | 
						|
def as_gt(left, right):
    """Build a RELATIONAL expression with operator ``RelOp.GT``."""
    operands = (RelOp.GT, left, right)
    return Expr(Op.RELATIONAL, operands)
 | 
						|
 | 
						|
 | 
						|
def as_ge(left, right):
    """Build a RELATIONAL expression with operator ``RelOp.GE``."""
    operands = (RelOp.GE, left, right)
    return Expr(Op.RELATIONAL, operands)
 | 
						|
 | 
						|
 | 
						|
def as_terms(obj):
    """Return expression as TERMS expression.

    ``obj`` is normalized first.  A TERMS result is returned as-is.  A
    number ``v`` of kind ``k`` becomes ``{unit(1, k): v}``, where the unit
    is an INTEGER or REAL 1 matching the number's type.  Anything else
    becomes ``{obj: 1}``.  Raises OpError when ``obj`` is not an Expr.
    """
    if not isinstance(obj, Expr):
        raise OpError(f'cannot convert {type(obj)} to terms Expr')

    obj = normalize(obj)
    op = obj.op
    if op is Op.TERMS:
        return obj
    if op in (Op.INTEGER, Op.REAL):
        make_unit = as_integer if op is Op.INTEGER else as_real
        return Expr(Op.TERMS, {make_unit(1, obj.data[1]): obj.data[0]})
    return Expr(Op.TERMS, {obj: 1})
 | 
						|
 | 
						|
 | 
						|
def as_factors(obj):
    """Return expression as FACTORS expression.

    ``obj`` is normalized first, then converted as follows:

    - a FACTORS expression is returned as-is;
    - a single-term TERMS expression ``coeff * term`` becomes
      ``{term: 1}`` when ``coeff == 1``, else ``{term: 1, coeff: 1}``
      with ``coeff`` turned into a number constant via ``as_number``;
    - a keyword-free ``ArithOp.DIV`` application ``a / b`` becomes
      ``{a: 1, b: -1}``;
    - anything else becomes the single factor ``{obj: 1}``.

    Raises OpError when ``obj`` is not an Expr (or, via ``as_number``,
    when a TERMS coefficient is not an int/float).
    """
    if isinstance(obj, Expr):
        obj = normalize(obj)
        if obj.op is Op.FACTORS:
            return obj
        if obj.op is Op.TERMS:
            if len(obj.data) == 1:
                (term, coeff), = obj.data.items()
                if coeff == 1:
                    return Expr(Op.FACTORS, {term: 1})
                # The coefficient becomes its own numeric factor.  Build it
                # with the module's number constructor: the previous
                # `Expr.number(coeff)` referenced a helper that this module
                # does not define.
                return Expr(Op.FACTORS, {term: 1, as_number(coeff): 1})
        if ((obj.op is Op.APPLY
             and obj.data[0] is ArithOp.DIV
             and not obj.data[2])):
            return Expr(Op.FACTORS, {obj.data[1][0]: 1, obj.data[1][1]: -1})
        return Expr(Op.FACTORS, {obj: 1})
    raise OpError(f'cannot convert {type(obj)} to factors Expr')
 | 
						|
 | 
						|
 | 
						|
def as_term_coeff(obj):
    """Return expression as term-coefficient pair.

    After normalization:

    - a number ``v`` of kind ``k`` gives ``(unit(1, k), v)``, where the
      unit is an INTEGER or REAL 1 matching the number's type;
    - a single-term TERMS expression gives its ``(term, coeff)``;
    - an ``ArithOp.DIV`` application pulls the coefficient out of its
      dividend;
    - anything else, including multi-term sums, gives ``(obj, 1)``.

    Raises OpError when ``obj`` is not an Expr.
    """
    if not isinstance(obj, Expr):
        raise OpError(f'cannot convert {type(obj)} to term and coeff')

    obj = normalize(obj)
    op = obj.op
    if op in (Op.INTEGER, Op.REAL):
        make_unit = as_integer if op is Op.INTEGER else as_real
        return make_unit(1, obj.data[1]), obj.data[0]
    if op is Op.TERMS and len(obj.data) == 1:
        (term, coeff), = obj.data.items()
        return term, coeff
    # TODO: find common divisor of coefficients of multi-term sums
    if op is Op.APPLY and obj.data[0] is ArithOp.DIV:
        numer_term, numer_coeff = as_term_coeff(obj.data[1][0])
        return as_apply(ArithOp.DIV, numer_term, obj.data[1][1]), numer_coeff
    return obj, 1
 | 
						|
 | 
						|
 | 
						|
def as_numer_denom(obj):
    """Return expression as numer-denom pair.

    The normalized ``obj`` is split into ``(numer, denom)`` with
    ``numer / denom`` equal to it:

    - numbers, complex constants, symbols, indexing and ternary
      expressions, and applications other than keyword-free
      ``ArithOp.DIV``, pair with the number 1;
    - ``a / b`` cross-multiplies the pairs of ``a`` and ``b``;
    - a sum is brought over the product of all term denominators, and a
      negative numeric denominator flips the sign of both parts;
    - a product combines the pairs of its bases, swapping numerator and
      denominator for negative exponents.

    Raises OpError for non-Expr input or any other expression kind.
    """
    if not isinstance(obj, Expr):
        raise OpError(f'cannot convert {type(obj)} to numer and denom')

    obj = normalize(obj)
    op = obj.op

    if op in (Op.INTEGER, Op.REAL, Op.COMPLEX, Op.SYMBOL,
              Op.INDEXING, Op.TERNARY):
        return obj, as_number(1)

    if op is Op.APPLY:
        if obj.data[0] is not ArithOp.DIV or obj.data[2]:
            return obj, as_number(1)
        # (a_n / a_d) / (b_n / b_d) == (a_n * b_d) / (a_d * b_n)
        (a_numer, a_denom), (b_numer, b_denom) = map(as_numer_denom,
                                                     obj.data[1])
        return a_numer * b_denom, a_denom * b_numer

    if op is Op.TERMS:
        numers, denoms = [], []
        for term, coeff in obj.data.items():
            term_numer, term_denom = as_numer_denom(term)
            numers.append(term_numer * coeff)
            denoms.append(term_denom)
        # sum_i n_i/d_i == (sum_i n_i * prod_{j != i} d_j) / prod_i d_i
        numer, denom = as_number(0), as_number(1)
        for i, (scaled, own_denom) in enumerate(zip(numers, denoms)):
            for j, other_denom in enumerate(denoms):
                if j != i:
                    scaled *= other_denom
            numer += scaled
            denom *= own_denom
        if denom.op in (Op.INTEGER, Op.REAL) and denom.data[0] < 0:
            numer, denom = -numer, -denom
        return numer, denom

    if op is Op.FACTORS:
        numer, denom = as_number(1), as_number(1)
        for base, exponent in obj.data.items():
            base_numer, base_denom = as_numer_denom(base)
            if exponent > 0:
                numer *= base_numer ** exponent
                denom *= base_denom ** exponent
            elif exponent < 0:
                numer *= base_denom ** (-exponent)
                denom *= base_numer ** (-exponent)
        return numer, denom

    raise OpError(f'cannot convert {type(obj)} to numer and denom')
 | 
						|
 | 
						|
 | 
						|
def _counter():
    # Used internally to generate unique dummy symbols: yields 1, 2, 3, ...
    value = 1
    while True:
        yield value
        value += 1


# Module-wide source of unique numbers for the @-marker names built by
# eliminate_quotes and replace_parenthesis.
COUNTER = _counter()
 | 
						|
 | 
						|
 | 
						|
def eliminate_quotes(s):
    """Replace quoted substrings of input string.

    Each single- or double-quoted literal, together with an optional
    ``kind_`` prefix (an identifier followed by an underscore), is
    replaced by a unique marker ``<kind>@__f2py_QUOTES_<SINGLE|DOUBLE>_<n>@``.

    Return a new string and a mapping from each marker to its quoted
    literal (quotes included, kind prefix excluded).  An AssertionError
    is raised if any quote character remains after replacement.
    """
    replacements = {}

    def stash(match):
        prefix, literal = match.groups()[:2]
        # The prefix group carries the trailing '_' separator; drop it.
        kind = prefix[:-1] if prefix else prefix
        quote_name = {"'": "SINGLE", '"': "DOUBLE"}[literal[0]]
        marker = f'{kind}@__f2py_QUOTES_{quote_name}_{next(COUNTER)}@'
        replacements[marker] = literal
        return marker

    pattern = r'({kind}_|)({single_quoted}|{double_quoted})'.format(
        kind=r'\w[\w\d_]*',
        single_quoted=r"('([^'\\]|(\\.))*')",
        double_quoted=r'("([^"\\]|(\\.))*")')
    unquoted = re.sub(pattern, stash, s)

    # A quote surviving the substitution is unbalanced/unrecognized.
    assert '"' not in unquoted
    assert "'" not in unquoted

    return unquoted, replacements
 | 
						|
 | 
						|
 | 
						|
def insert_quotes(s, d):
    """Inverse of eliminate_quotes.

    Every marker key of ``d`` found in ``s`` is replaced by its quoted
    literal, re-attaching the ``kind_`` prefix encoded before the marker's
    first ``@``.
    """
    result = s
    for marker, literal in d.items():
        kind = marker[:marker.find('@')]
        prefix = kind + '_' if kind else ''
        result = result.replace(marker, prefix + literal)
    return result
 | 
						|
 | 
						|
 | 
						|
def replace_parenthesis(s):
    """Replace substrings of input that are enclosed in parenthesis.

    Every top-level bracketed group -- ``(...)``, ``[...]``, ``{...}``
    (C literal structs) or ``(/.../)`` -- is replaced by a unique marker
    ``@__f2py_PARENTHESIS_<KIND>_<n>@``.  Nested groups stay inside the
    stored content.

    Return a new string and a mapping of marker -> enclosed content.

    Raises ValueError when an opening delimiter has no matching closing
    delimiter.
    """
    # Find a parenthesis pair that appears first.

    # Fortran deliminators are `(`, `)`, `[`, `]`, `(/`, `/)`, `/`.
    # We don't handle `/` deliminator because it is not a part of an
    # expression.
    left, right = None, None
    mn_i = len(s)
    for left_, right_ in (('(/', '/)'),
                          '()',
                          '{}',  # to support C literal structs
                          '[]'):
        i = s.find(left_)
        if i == -1:
            continue
        if i < mn_i:
            mn_i = i
            left, right = left_, right_

    if left is None:
        return s, {}

    i = mn_i
    j = s.find(right, i)

    # Advance j to the closing delimiter that balances the opening one at
    # i.  A missing closing delimiter (j == -1) must be reported here: the
    # original loop skipped the check when the first search already failed,
    # and the slice s[j+len(right):] then equalled s, recursing forever.
    while j != -1 and s.count(left, i + 1, j) != s.count(right, i + 1, j):
        j = s.find(right, j + 1)
    if j == -1:
        raise ValueError(f'Mismatch of {left+right} parenthesis in {s!r}')

    p = {'(': 'ROUND', '[': 'SQUARE', '{': 'CURLY', '(/': 'ROUNDDIV'}[left]

    k = f'@__f2py_PARENTHESIS_{p}_{COUNTER.__next__()}@'
    v = s[i+len(left):j]
    # Process the remainder after this group the same way.
    r, d = replace_parenthesis(s[j+len(right):])
    d[k] = v
    return s[:i] + k + r, d
 | 
						|
 | 
						|
 | 
						|
def _get_parenthesis_kind(s):
    # Return the bracket kind (ROUND, SQUARE, CURLY or ROUNDDIV) encoded in
    # a marker produced by replace_parenthesis.
    assert s.startswith('@__f2py_PARENTHESIS_'), s
    fields = s.split('_')
    return fields[4]
 | 
						|
 | 
						|
 | 
						|
def unreplace_parenthesis(s, d):
    """Inverse of replace_parenthesis.

    Each marker key of ``d`` is replaced by its content wrapped in the
    delimiters named by the marker's kind.
    """
    delimiters = {
        'ROUND': ('(', ')'),
        'SQUARE': ('[', ']'),
        'CURLY': ('{', '}'),
        'ROUNDDIV': ('(/', '/)'),
    }
    for marker, content in d.items():
        opening, closing = delimiters[_get_parenthesis_kind(marker)]
        s = s.replace(marker, opening + content + closing)
    return s
 | 
						|
 | 
						|
 | 
						|
def fromstring(s, language=Language.C):
    """Create an expression from a string.

    The parser is "lazy": only arithmetic operations are resolved, and
    non-arithmetic constructs are kept as symbols.  Raises ValueError
    when the parse result is not an Expr instance.
    """
    parsed = _FromStringWorker(language=language).parse(s)
    if not isinstance(parsed, Expr):
        raise ValueError(
            f'failed to parse `{s}` to Expr instance: got `{parsed}`')
    return parsed
 | 
						|
 | 
						|
 | 
						|
class _Pair:
    # Internal holder for a pair of parsed operands (left, right).

    def __init__(self, left, right):
        self.left, self.right = left, right

    def substitute(self, symbols_map):
        # Substitute inside each Expr member; other members are kept as-is.
        def swap(item):
            if isinstance(item, Expr):
                return item.substitute(symbols_map)
            return item

        return _Pair(swap(self.left), swap(self.right))

    def __repr__(self):
        return f'{type(self).__name__}({self.left}, {self.right})'
 | 
						|
 | 
						|
 | 
						|
class _FromStringWorker:
    """Parser engine behind ``fromstring``.

    ``parse`` hides quoted strings behind placeholders and ``process``
    hides bracketed subexpressions behind ``@__f2py_PARENTHESIS_...@``
    placeholders, so the operator searches in ``process`` only see the
    top level of an expression. Operator forms are tried in a fixed order;
    the first one that matches splits the string and its operands are
    processed recursively.
    """

    def __init__(self, language=Language.C):
        # Input string of the latest parse() call; quoted in warnings.
        self.original = None
        # Placeholder -> quoted string mapping filled in by parse().
        self.quotes_map = None
        # Source language: selects Fortran- vs C-specific syntax rules.
        self.language = language

    def finalize_string(self, s):
        """Return ``s`` with the quoted-string placeholders put back."""
        return insert_quotes(s, self.quotes_map)

    def parse(self, inp):
        """Parse the input string ``inp``.

        Returns whatever ``process`` produces for the unquoted input:
        normally an Expr, or a ``_Pair`` for a ``name=value`` string.
        """
        self.original = inp
        unquoted, self.quotes_map = eliminate_quotes(inp)
        return self.process(unquoted)

    def process(self, s, context='expr'):
        """Parse string within the given context.

        The context may define the result in case of ambiguous
        expressions. For instance, consider expressions `f(x, y)` and
        `(x, y) + (a, b)` where `f` is a function and pair `(x, y)`
        denotes complex number. Specifying context as "args" or
        "expr", the subexpression `(x, y)` will be parsed to an
        argument list or to a complex number, respectively.

        A list or tuple input is processed element-wise and returned as
        the same container type. For a string, the result is normally an
        Expr; a comma-separated list gives a tuple in "args" context, and
        a ``name=value`` string gives a ``_Pair``. Raises
        NotImplementedError for comma-separated input that is neither an
        argument list nor a two-element complex literal.
        """
        if isinstance(s, (list, tuple)):
            return type(s)(self.process(s_, context) for s_ in s)

        assert isinstance(s, str), (type(s), s)

        # replace subexpressions in parenthesis with f2py @-names
        r, raw_symbols_map = replace_parenthesis(s)
        r = r.strip()

        def restore(r):
            # restores subexpressions marked with f2py @-names
            if isinstance(r, (list, tuple)):
                return type(r)(map(restore, r))
            return unreplace_parenthesis(r, raw_symbols_map)

        # comma-separated tuple
        if ',' in r:
            operands = restore(r.split(','))
            if context == 'args':
                return tuple(self.process(operands))
            if context == 'expr':
                if len(operands) == 2:
                    # complex number literal
                    return as_complex(*self.process(operands))
            raise NotImplementedError(
                f'parsing comma-separated list (context={context}): {r}')

        # ternary operation
        m = re.match(r'\A([^?]+)[?]([^:]+)[:](.+)\Z', r)
        if m:
            assert context == 'expr', context
            oper, expr1, expr2 = restore(m.groups())
            oper = self.process(oper)
            expr1 = self.process(expr1)
            expr2 = self.process(expr2)
            return as_ternary(oper, expr1, expr2)

        # relational expression
        if self.language is Language.Fortran:
            m = re.match(
                r'\A(.+)\s*[.](eq|ne|lt|le|gt|ge)[.]\s*(.+)\Z', r, re.I)
        else:
            m = re.match(
                r'\A(.+)\s*([=][=]|[!][=]|[<][=]|[<]|[>][=]|[>])\s*(.+)\Z', r)
        if m:
            left, rop, right = m.groups()
            if self.language is Language.Fortran:
                rop = '.' + rop + '.'
            left, right = self.process(restore((left, right)))
            rop = RelOp.fromstring(rop, language=self.language)
            return Expr(Op.RELATIONAL, (rop, left, right))

        # keyword argument
        m = re.match(r'\A(\w[\w\d_]*)\s*[=](.*)\Z', r)
        if m:
            keyname, value = m.groups()
            value = restore(value)
            return _Pair(keyname, self.process(value))

        # addition/subtraction operations; the lookbehind keeps the sign
        # of a real-literal exponent (e.g. `1e-3`) from being split off
        operands = re.split(r'((?<!\d[edED])[+-])', r)
        if len(operands) > 1:
            # a leading unary sign leaves an empty first operand: use 0
            result = self.process(restore(operands[0] or '0'))
            for op, operand in zip(operands[1::2], operands[2::2]):
                operand = self.process(restore(operand))
                op = op.strip()
                if op == '+':
                    result += operand
                else:
                    assert op == '-'
                    result -= operand
            return result

        # string concatenate operation
        if self.language is Language.Fortran and '//' in r:
            operands = restore(r.split('//'))
            return Expr(Op.CONCAT,
                        tuple(self.process(operands)))

        # multiplication/division operations; outside C, `**` is masked so
        # that the exponentiation operator is not split as two `*`
        operands = re.split(r'(?<=[@\w\d_])\s*([*]|/)',
                            (r if self.language is Language.C
                             else r.replace('**', '@__f2py_DOUBLE_STAR@')))
        if len(operands) > 1:
            operands = restore(operands)
            if self.language is not Language.C:
                operands = [operand.replace('@__f2py_DOUBLE_STAR@', '**')
                            for operand in operands]
            # Expression is an arithmetic product
            result = self.process(operands[0])
            for op, operand in zip(operands[1::2], operands[2::2]):
                operand = self.process(operand)
                op = op.strip()
                if op == '*':
                    result *= operand
                else:
                    assert op == '/'
                    result /= operand
            return result

        # referencing/dereferencing
        if r.startswith(('*', '&')):
            op = {'*': Op.DEREF, '&': Op.REF}[r[0]]
            operand = self.process(restore(r[1:]))
            return Expr(op, operand)

        # exponentiation operations (right-associative: a**b**c == a**(b**c))
        if self.language is not Language.C and '**' in r:
            operands = list(reversed(restore(r.split('**'))))
            result = self.process(operands[0])
            for operand in operands[1:]:
                operand = self.process(operand)
                result = operand ** result
            return result

        # int-literal-constant (default kind 4)
        m = re.match(r'\A({digit_string})({kind}|)\Z'.format(
            digit_string=r'\d+',
            kind=r'_(\d+|\w[\w\d_]*)'), r)
        if m:
            value, _, kind = m.groups()
            if kind and kind.isdigit():
                kind = int(kind)
            return as_integer(int(value), kind or 4)

        # real-literal-constant (default kind 4, or 8 for a `d` exponent).
        # The significand is wrapped in a non-capturing group so that the
        # optional exponent applies to both of its alternatives; without
        # it, `.5e3` / `.5d0` were not recognised as real literals.
        m = re.match(r'\A({significant}({exponent}|)|\d+{exponent})({kind}|)\Z'
                     .format(
                         significant=r'(?:[.]\d+|\d+[.]\d*)',
                         exponent=r'[edED][+-]?\d+',
                         kind=r'_(\d+|\w[\w\d_]*)'), r)
        if m:
            value, _, _, kind = m.groups()
            if kind and kind.isdigit():
                kind = int(kind)
            value = value.lower()
            if 'd' in value:
                return as_real(float(value.replace('d', 'e')), kind or 8)
            return as_real(float(value), kind or 4)

        # string-literal-constant with kind parameter specification
        if r in self.quotes_map:
            kind = r[:r.find('@')]
            return as_string(self.quotes_map[r], kind or 1)

        # array constructor or literal complex constant or
        # parenthesized expression
        if r in raw_symbols_map:
            paren = _get_parenthesis_kind(r)
            items = self.process(restore(raw_symbols_map[r]),
                                 'expr' if paren == 'ROUND' else 'args')
            if paren == 'ROUND':
                if isinstance(items, Expr):
                    return items
            if paren in ['ROUNDDIV', 'SQUARE']:
                # Expression is a array constructor
                if isinstance(items, Expr):
                    items = (items,)
                return as_array(items)

        # function call/indexing
        m = re.match(r'\A(.+)\s*(@__f2py_PARENTHESIS_(ROUND|SQUARE)_\d+@)\Z',
                     r)
        if m:
            target, args, paren = m.groups()
            target = self.process(restore(target))
            # drop the enclosing bracket characters of the argument list
            inner = restore(args)[1:-1]
            if paren == 'ROUND' and not inner.strip():
                # `f()`: a call without arguments. Processing the empty
                # string would otherwise yield a bogus empty-named symbol
                # argument (and a spurious warning).
                args = ()
            else:
                args = self.process(inner, 'args')
                if not isinstance(args, tuple):
                    args = args,
            if paren == 'ROUND':
                kwargs = {a.left: a.right for a in args
                          if isinstance(a, _Pair)}
                args = tuple(a for a in args if not isinstance(a, _Pair))
                # Warning: this could also be Fortran indexing operation..
                return as_apply(target, *args, **kwargs)
            # Expression is a C/Python indexing operation
            # (e.g. used in .pyf files)
            assert paren == 'SQUARE'
            return target[args]

        # Fortran standard conforming identifier
        m = re.match(r'\A\w[\w\d_]*\Z', r)
        if m:
            return as_symbol(r)

        # fall-back to symbol
        r = self.finalize_string(restore(r))
        ewarn(
            f'fromstring: treating {r!r} as symbol (original={self.original})')
        return as_symbol(r)
 |