"""add frame stats

Revision ID: 8350e5a3e24c
Revises: afaee09a3222
Create Date: 2025-11-26 12:24:15.234425

"""
from alembic import op
import sqlalchemy as sa
import geoalchemy2


# revision identifiers, used by Alembic.
# ``revision`` is this migration's id and ``down_revision`` is the revision it
# is applied on top of. This migration declares no branch labels and no
# cross-branch dependencies.
revision = '8350e5a3e24c'
down_revision = 'afaee09a3222'
branch_labels = None
depends_on = None


def upgrade():
    """Add per-frame statistics columns and the ``frame_geometry`` SQL function.

    New nullable columns on ``frames``:

    * ``points_count`` -- number of rows found in the frame's partition table.
    * ``geometry_complex`` -- ``ST_Union`` of the convex hulls computed for
      each batch of points.
    * ``geometry_bounds`` -- extent of those hulls, tagged with SRID 4326.

    ``frame_geometry(partition_name, frame_id)`` fills these columns for one
    frame. It also overwrites ``frames.geometry`` with the oriented envelope
    of the hull union. It works by scanning ``partition_name`` in batches of
    1,000,000 rows.
    """
    op.add_column('frames', sa.Column('points_count', sa.INTEGER(), nullable=True))
    op.add_column('frames', sa.Column('geometry_complex', geoalchemy2.types.Geometry(geometry_type='POLYGON', srid=4326, spatial_index=False, from_text='ST_GeomFromEWKT', name='geometry'), nullable=True))
    op.add_column('frames', sa.Column('geometry_bounds', geoalchemy2.types.Geometry(geometry_type='POLYGON', srid=4326, spatial_index=False, from_text='ST_GeomFromEWKT', name='geometry'), nullable=True))

    # NOTE(review): the batch loop pages with ORDER BY id LIMIT/OFFSET. PostgreSQL
    # still walks every skipped row, so total work grows roughly quadratically
    # with partition size. Keyset paging (WHERE id > last_seen_id) would avoid
    # that if it ever becomes a bottleneck.
    # NOTE(review): table names are interpolated with format's plain-string
    # specifier, not the identifier-quoting one. That is only safe while
    # callers pass trusted, unqualified, lower-case partition names. Confirm
    # against callers before switching.
    op.execute(sa.text("""
        CREATE OR REPLACE FUNCTION frame_geometry(partition_name VARCHAR, frame_id INT) RETURNS void LANGUAGE plpgsql AS $func$
            DECLARE
                batch_count INT := 0;
                points_per_batch INT := 1000000;
                i INT := 0;
                total_points INT := 0;
                temp_table_name VARCHAR := (partition_name || '_geometry');
            BEGIN
                -- Scratch table with one row per batch: its point count and convex hull.
                EXECUTE FORMAT('DROP TABLE IF EXISTS %s', temp_table_name);
                EXECUTE FORMAT('CREATE TEMPORARY TABLE %s(total int, geometry geometry)', temp_table_name);

                -- Walk the partition in id order, points_per_batch rows at a time,
                -- so no single ST_Collect has to hold every point. Stop after the
                -- first short (or empty) batch.
                LOOP
                    EXECUTE FORMAT('
                        WITH geometry_cte AS (
                            SELECT geometry FROM %s ORDER BY id LIMIT %s OFFSET %s
                        )
                        INSERT INTO %s(total, geometry)
                            SELECT
                                count(*),
                                ST_ConvexHull(ST_Collect(array_agg(geometry)))
                            FROM geometry_cte;',
                        partition_name, points_per_batch, i * points_per_batch, temp_table_name);

                    -- Every earlier batch is full, so the minimum is the newest batch's size.
                    EXECUTE FORMAT('SELECT min(total) FROM %s', temp_table_name) INTO batch_count;
                    EXIT WHEN batch_count < points_per_batch;
                    i := i + 1;
                END LOOP;

                EXECUTE FORMAT('SELECT sum(total) FROM %s', temp_table_name) INTO total_points;

                -- Write the aggregated statistics back onto the frame row.
                EXECUTE FORMAT('UPDATE frames SET
                    points_count = %s,
                    geometry_bounds = (SELECT ST_SetSRID(ST_EXTENT(geometry), 4326) FROM %s),
                    geometry_complex = (SELECT ST_Union(geometry) FROM %s),
                    geometry = (SELECT ST_OrientedEnvelope(ST_Union(geometry)) FROM %s)
                    WHERE id = %s;',
                    total_points, temp_table_name, temp_table_name, temp_table_name, frame_id);
                EXECUTE FORMAT('DROP TABLE %s', temp_table_name);
            END;
        $func$;
    """))


def downgrade():
    """Revert this revision: remove ``frame_geometry`` and the frame stats columns."""
    # Remove the function before the columns that its UPDATE statement writes to.
    op.execute('DROP function frame_geometry(partition_name VARCHAR, frame_id INT);')

    for column_name in ('points_count', 'geometry_complex', 'geometry_bounds'):
        op.drop_column('frames', column_name)
