"""average displacements

Revision ID: 0f8096f0867a
Revises: 50e8b4193a62
Create Date: 2025-09-24 10:26:38.942001

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# Revision identifiers, used by Alembic to place this migration in the
# revision graph: `revision` is this migration's own ID and `down_revision`
# is the migration it directly follows. No branch labels or extra
# cross-branch dependencies are declared.
revision = '0f8096f0867a'
down_revision = '50e8b4193a62'
branch_labels = None
depends_on = None


def upgrade():
    """Apply this migration.

    * ``frames`` gains nullable integer columns ``start_bin`` and ``end_bin``.
    * ``points_asc`` and ``points_desc`` each gain a nullable ``REAL[]``
      column ``average_disps``.
    * Creates (or replaces) the PL/pgSQL set-returning function
      ``unnest_by_one_dimension``.
    """
    op.add_column('frames', sa.Column('start_bin', sa.Integer(), nullable=True))
    op.add_column('frames', sa.Column('end_bin', sa.Integer(), nullable=True))

    for table in ['points_asc', 'points_desc']:
        op.add_column(table, sa.Column('average_disps', postgresql.ARRAY(sa.REAL), nullable=True))

    # add function for unnesting an N-dimensional array by 1 dimension
    # e.g. converts a 2D array into a set of its 1D array elements
    # https://stackoverflow.com/questions/8137112/unnest-array-by-one-level/8142998#8142998
    #
    # Emitted through op.execute rather than op.get_bind().execute: in offline
    # (`--sql`) mode get_bind() returns None, whereas op.execute writes the
    # statement into the generated SQL script.
    op.execute(sa.text(
        """
            CREATE OR REPLACE FUNCTION unnest_by_one_dimension(a_nd ANYARRAY, OUT a_1d ANYARRAY)
                RETURNS SETOF ANYARRAY
                LANGUAGE plpgsql IMMUTABLE PARALLEL SAFE STRICT AS
            $func$
            BEGIN
                FOREACH a_1d SLICE 1 IN ARRAY a_nd LOOP
                    RETURN NEXT;
                END LOOP;
            END
            $func$;
        """
    ))


def downgrade():
    """Revert this migration, undoing the upgrade steps in reverse order.

    Drops the ``unnest_by_one_dimension`` function first. It then drops the
    ``average_disps`` columns from ``points_asc`` and ``points_desc``, and
    finally drops ``end_bin`` and ``start_bin`` from ``frames``.
    """
    # Qualify the function with its input signature. PostgreSQL accepts a
    # bare name only from v10 onward, and only if the name is unique in its
    # schema. OUT parameters are not part of a function's identity, so the
    # signature is just (anyarray).
    # op.execute (not op.get_bind().execute) also works in offline `--sql` mode.
    op.execute(sa.text("DROP FUNCTION IF EXISTS unnest_by_one_dimension(anyarray);"))

    for table in ['points_asc', 'points_desc']:
        op.drop_column(table, 'average_disps')

    op.drop_column('frames', 'end_bin')
    op.drop_column('frames', 'start_bin')
