import numpy as np
from sqlalchemy import text
from config import CONFIG
from main.models._mapped_tables import db_selector
from main.models.insar.frame import Frame, Status
from main.services.satbots.scripts.shared import output
from setup import DB_SESSION, date_bin
from app import app
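
# Satbot script: regenerate the per-point average displacement arrays for an
# InSAR frame. Displacements are grouped into date bins and each point's
# binned averages are written back to its average_disps column.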

def process_batch(database, frame, batch_size, last_id, num_bins, bin_indices):
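    """Compute and store average displacements for one batch of points.

    Fetches up to batch_size points with id greater than last_id, averages
    each point's displacement values into the frame's date bins, and writes
    the results back in a single UPDATE.

    Returns the highest point id processed and the number of points fetched,
    so the caller can page through the table and detect the final batch.
    """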
    with DB_SESSION[database].begin() as session:
        sql = f"""
            WITH 
            points AS (
                SELECT id AS point_id
                FROM {frame.partitions[1]}
                WHERE id > {last_id}
                ORDER BY point_id
                LIMIT {batch_size}
            ),
            disps AS (
                SELECT 
                    p.point_id,
                    unnest(values) AS disp,
                    unnest(dates) AS date
                FROM points p
                JOIN {frame.partitions[0]} ON p.point_id = {frame.partitions[0]}.point_id
                JOIN date_arrays ON {frame.partitions[0]}.date_arrays_id = date_arrays.id
            )
            SELECT
                point_id,
                array_agg(disp ORDER BY date) AS disps
            FROM disps
            GROUP BY point_id
        """

        results = session.execute(text(sql)).fetchall()
        num_points = len(results)
        last_id = 0

        average_disp_arrays = []

        for point_id, disps in results:
            # Track the highest point id seen so the next batch resumes after it.
            if point_id > last_id:
                last_id = point_id

            values = np.array(disps, dtype=float)
            keep = ~np.isnan(values)
            sums = np.bincount(bin_indices[keep], weights=values[keep], minlength=num_bins)
            counts = np.bincount(bin_indices[keep], minlength=num_bins)
            # Bins with no valid samples divide 0.0 by 0 and come out as NaN;
            # suppress the resulting RuntimeWarning.
            with np.errstate(invalid='ignore'):
                average_disps = sums / counts
            # Postgres parses 'nan' (case-insensitively) as NaN for real[].
            disps_str = '{' + ','.join(str(x) for x in average_disps) + '}'
            average_disp_arrays.append(f"({point_id}, '{disps_str}'::real[])")

        sql = f"""
            UPDATE {frame.partitions[1]} AS p
            SET average_disps = temp.average_disps
            FROM (VALUES 
                {','.join(average_disp_arrays)}
            ) AS temp(id, average_disps)
            WHERE p.id = temp.id
        """
        session.execute(text(sql))

        # The begin() context manager commits on exit; no explicit commit
        # is needed here.

    return last_id, num_points

def generate_average_disps(satbot, database, frame_id):
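    """Recompute the average_disps column for every point in a frame.

    Marks the frame as UPDATING, derives the frame's date bins from a sample
    of its points, then processes the points in batches via process_batch,
    reporting progress through the satbot. The frame's previous status is
    restored once all batches are done.
    """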
    with DB_SESSION[database].begin() as session:
        if frame_id.isdigit():
            frame = session.query(Frame).filter(Frame.id == frame_id).one()
        else:
            frame = session.query(Frame).filter(Frame.name == frame_id).one()
        
        if frame.status == Status.UPDATING:
            raise ValueError(f"Frame {database}-{frame.id} is already being updated, or something has gone wrong in a previous update to leave it in an incorrect state.")

        original_status = frame.status
        frame.status = Status.UPDATING
        
        # Sample 100 random points and collect the distinct date arrays they
        # reference. This assumes every point in the frame shares a single
        # date array, since bin_indices is applied positionally to each
        # point's values array in process_batch.
        sql = f"""
            WITH points_sample AS (
                SELECT id AS point_id
                FROM {frame.partitions[1]}
                ORDER BY random()
                LIMIT 100
            ),
            dates AS (
                SELECT DISTINCT da.dates
                FROM points_sample ps
                JOIN {frame.partitions[0]} d ON d.point_id = ps.point_id
                JOIN date_arrays da ON da.id = d.date_arrays_id
            )
            SELECT date
            FROM dates, unnest(dates) AS date
            ORDER BY date
        """

        dates = np.array(session.execute(text(sql)).fetchall()).flatten()
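        # date_bin (from setup) maps each date onto an integer bin number;
        # the first and last define the frame's bin range.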
        date_bins = date_bin(dates)
        frame.start_bin = int(date_bins[0])
        frame.end_bin = int(date_bins[-1])
        bins = np.arange(frame.start_bin, frame.end_bin + 1)
        num_bins = bins.shape[0]
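        # Convert each date's bin number into a zero-based index into bins.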
        bin_indices = np.digitize(date_bins, bins) - 1

        output(satbot, f"Found {len(dates)} dates")

        sql = f"SELECT reltuples::bigint AS estimate FROM pg_class where relname = '{frame.partitions[1]}';"
        num_points = max(session.execute(text(sql)).fetchone()[0], 0)
        output(satbot, f"Found approximately {num_points} points")

        session.add(frame)
        # The begin() context manager commits on exit.

    batch_size = CONFIG["satbots"]["generate_average_disps"]["batch_size"]
    # num_points is an estimate, so num_batches (and the progress figure
    # derived from it) is approximate; progress is capped at 100 below.
    num_batches = num_points // batch_size + 1
    points_processed = 0
    last_id = 0
    new_points = batch_size
    batch_num = 0
    # Keep fetching batches until one comes back short of batch_size,
    # meaning the final point has been processed.
    while new_points == batch_size:
        batch_num += 1
        last_id, new_points = process_batch(database, frame, batch_size, last_id, num_bins, bin_indices)
        points_processed += new_points
        progress = batch_num / num_batches * 100
        output(satbot, f"Batch {batch_num}: {points_processed}", progress=min(round(progress), 100))

    output(satbot, f"Processed {points_processed} points", progress=100)

    with DB_SESSION[database].begin() as session:
        # The guard at the top guarantees original_status was never UPDATING,
        # so simply restore it.
        frame.status = original_status

        # frame is detached from the earlier session; add() reattaches it so
        # the status change is flushed when the context manager commits.
        session.add(frame)

def main(satbot):
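    """Entry point: parse the satbot's frame-id parameter and run the update."""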
    try:
        # frame-id is expected to look like '<database>-<frame id or name>'.
        frame_parts = satbot.decoded_parameters['frame-id'].split('-')
        database = db_selector(frame_parts[0])
        frame_id = frame_parts[-1]
    except (KeyError, AttributeError):
        # KeyError: no 'frame-id' parameter; AttributeError: it isn't a string.
        with app.app_context():
            satbot.update(log=['Invalid Frame ID provided.'])
            return
    
    generate_average_disps(satbot, database, frame_id)
