You cannot select more than 25 topics
			Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
		
		
		
		
		
			
		
			
				
					
					
						
							205 lines
						
					
					
						
							6.0 KiB
						
					
					
				
			
		
		
	
	
							205 lines
						
					
					
						
							6.0 KiB
						
					
					
				from .. import fixtures
 | 
						|
from ..assertions import eq_
 | 
						|
from ..schema import Column
 | 
						|
from ..schema import Table
 | 
						|
from ... import ForeignKey
 | 
						|
from ... import Integer
 | 
						|
from ... import select
 | 
						|
from ... import String
 | 
						|
from ... import testing
 | 
						|
 | 
						|
 | 
						|
class CTETest(fixtures.TablesTest):
    """Exercise common table expressions (CTEs) end to end.

    Each test builds a CTE over ``some_table`` and uses it from a SELECT,
    an INSERT..FROM SELECT, an UPDATE or a DELETE, then compares the rows
    that come back with a literal expected list.
    """

    __backend__ = True
    __requires__ = ("ctes",)

    # NOTE(review): presumably these make the fixture base class re-run
    # insert_data() and clear the tables around every test -- confirm
    # against fixtures.TablesTest.
    run_inserts = "each"
    run_deletes = "each"

    @classmethod
    def define_tables(cls, metadata):
        """Register the two tables used by the tests on ``metadata``.

        ``some_table`` is self-referential through ``parent_id``;
        ``some_other_table`` has the same three columns but ``parent_id``
        is a plain integer without a foreign key.
        """
        # Source table: parent_id points back at some_table.id.
        Table(
            "some_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
            Column("parent_id", ForeignKey("some_table.id")),
        )

        # Target table for the INSERT / UPDATE / DELETE round trips.
        Table(
            "some_other_table",
            metadata,
            Column("id", Integer, primary_key=True),
            Column("data", String(50)),
            Column("parent_id", Integer),
        )

    @classmethod
    def insert_data(cls, connection):
        """Insert five rows into ``some_table``.

        Row 1 has no parent, rows 2 and 3 have parent 1, and rows 4 and 5
        have parent 3.
        """
        seed = (
            (1, "d1", None),
            (2, "d2", 1),
            (3, "d3", 1),
            (4, "d4", 3),
            (5, "d5", 3),
        )
        rows = [
            {"id": pk, "data": data, "parent_id": parent}
            for pk, data, parent in seed
        ]
        connection.execute(cls.tables.some_table.insert(), rows)

    def test_select_nonrecursive_round_trip(self, connection):
        """SELECT from a plain CTE and filter the CTE's rows again."""
        src = self.tables.some_table

        inner = select(src).where(src.c.data.in_(["d2", "d3", "d4"]))
        some_cte = inner.cte("some_cte")

        # "d5" is excluded by the CTE itself, so only "d4" survives both
        # filters.
        stmt = select(some_cte.c.data).where(
            some_cte.c.data.in_(["d4", "d5"])
        )
        rows = connection.execute(stmt).fetchall()
        eq_(rows, [("d4",)])

    def test_select_recursive_round_trip(self, connection):
        """SELECT from a recursive CTE that follows ``parent_id`` upward."""
        src = self.tables.some_table

        anchor = select(src).where(src.c.data.in_(["d2", "d3", "d4"]))
        base_cte = anchor.cte("some_cte", recursive=True)

        cte_alias = base_cte.alias("c1")
        st1 = src.alias()

        # note that SQL Server requires this to be UNION ALL,
        # can't be UNION
        parent_rows = select(st1).where(st1.c.id == cte_alias.c.parent_id)
        recursive_cte = base_cte.union_all(parent_rows)

        stmt = (
            select(recursive_cte.c.data)
            .where(recursive_cte.c.data != "d2")
            .order_by(recursive_cte.c.data.desc())
        )
        rows = connection.execute(stmt).fetchall()

        # Anchor rows d2/d3/d4 plus every ancestor reached from them;
        # UNION ALL keeps the duplicates and "d2" is filtered out.
        expected = [("d4",), ("d3",), ("d3",), ("d1",), ("d1",), ("d1",)]
        eq_(rows, expected)

    def test_insert_from_select_round_trip(self, connection):
        """INSERT..FROM SELECT where the SELECT reads from a CTE."""
        src = self.tables.some_table
        dest = self.tables.some_other_table

        inner = select(src).where(src.c.data.in_(["d2", "d3", "d4"]))
        some_cte = inner.cte("some_cte")

        insert_stmt = dest.insert().from_select(
            ["id", "data", "parent_id"], select(some_cte)
        )
        connection.execute(insert_stmt)

        copied = connection.execute(
            select(dest).order_by(dest.c.id)
        ).fetchall()
        eq_(copied, [(2, "d2", 1), (3, "d3", 1), (4, "d4", 3)])

    @testing.requires.ctes_with_update_delete
    @testing.requires.update_from
    def test_update_from_round_trip(self, connection):
        """UPDATE rows of ``some_other_table`` that match a CTE on ``data``."""
        src = self.tables.some_table
        dest = self.tables.some_other_table

        # Copy every row of some_table into some_other_table first.
        copy_stmt = dest.insert().from_select(
            ["id", "data", "parent_id"], select(src)
        )
        connection.execute(copy_stmt)

        inner = select(src).where(src.c.data.in_(["d2", "d3", "d4"]))
        some_cte = inner.cte("some_cte")

        update_stmt = (
            dest.update()
            .values(parent_id=5)
            .where(dest.c.data == some_cte.c.data)
        )
        connection.execute(update_stmt)

        remaining = connection.execute(
            select(dest).order_by(dest.c.id)
        ).fetchall()
        expected = [
            (1, "d1", None),
            (2, "d2", 5),
            (3, "d3", 5),
            (4, "d4", 5),
            (5, "d5", 3),
        ]
        eq_(remaining, expected)

    @testing.requires.ctes_with_update_delete
    @testing.requires.delete_from
    def test_delete_from_round_trip(self, connection):
        """DELETE rows of ``some_other_table`` that match a CTE on ``data``."""
        src = self.tables.some_table
        dest = self.tables.some_other_table

        # Copy every row of some_table into some_other_table first.
        copy_stmt = dest.insert().from_select(
            ["id", "data", "parent_id"], select(src)
        )
        connection.execute(copy_stmt)

        inner = select(src).where(src.c.data.in_(["d2", "d3", "d4"]))
        some_cte = inner.cte("some_cte")

        delete_stmt = dest.delete().where(dest.c.data == some_cte.c.data)
        connection.execute(delete_stmt)

        remaining = connection.execute(
            select(dest).order_by(dest.c.id)
        ).fetchall()
        eq_(remaining, [(1, "d1", None), (5, "d5", 3)])

    @testing.requires.ctes_with_update_delete
    def test_delete_scalar_subq_round_trip(self, connection):
        """DELETE using a scalar subquery that selects from a CTE."""
        src = self.tables.some_table
        dest = self.tables.some_other_table

        # Copy every row of some_table into some_other_table first.
        copy_stmt = dest.insert().from_select(
            ["id", "data", "parent_id"], select(src)
        )
        connection.execute(copy_stmt)

        inner = select(src).where(src.c.data.in_(["d2", "d3", "d4"]))
        some_cte = inner.cte("some_cte")

        # Correlated on id: yields the CTE's data value for the row being
        # considered for deletion.
        matching_data = (
            select(some_cte.c.data)
            .where(some_cte.c.id == dest.c.id)
            .scalar_subquery()
        )
        delete_stmt = dest.delete().where(dest.c.data == matching_data)
        connection.execute(delete_stmt)

        remaining = connection.execute(
            select(dest).order_by(dest.c.id)
        ).fetchall()
        eq_(remaining, [(1, "d1", None), (5, "d5", 3)])
 |