import csv
from pathlib import Path

from sqlalchemy import bindparam, text

from main.models._mapped_tables import Frame
from main.models.users.region import Region
from main.services.satbots.scripts.shared import output
from setup import DB_SESSION, CONFIG, app

# Maximum number of point rows fetched per query in retrieve_batch; a batch
# of exactly this size means another page may follow.
BATCH_COUNT = 100000

def retrieve_batch(db_session, layer, layer_config, region, frames, satbot, writer, highest_id=0, count=0):
    """Write every point of ``layer`` inside ``region`` to ``writer``, page by page.

    Points are read from the ``points_<orbit>`` table named by
    ``layer_config['orbit']``, restricted to ``frames`` and to geometries
    within ``region.region_polygon``, ordered by id and fetched at most
    ``BATCH_COUNT`` rows at a time (keyset pagination on ``points.id``).
    Each CSV row is ``(point id, layer, frame name, hdf5_x, hdf5_y)``.
    The final total is reported through ``output``.

    Args:
        db_session: open session on the layer's database.
        layer: layer key; written verbatim into the second CSV column.
        layer_config: layer settings; only ``'orbit'`` is used here.
        region: object exposing ``region_polygon``, whose string form is
            passed to ``ST_WITHIN``.
        frames: rows with an ``id`` attribute; only points in these frames
            are exported.
        satbot: target for progress messages via ``output``.
        writer: ``csv.writer`` the rows are written to.
        highest_id: only points with an id strictly greater than this are read.
        count: points already written before this call; added to the total.

    Returns:
        The running total: ``count`` plus the number of points written here.
    """
    # The table name cannot be a bind parameter; it comes from CONFIG, not
    # from the request. All values are bound instead of interpolated.
    statement = text(f"""
    SELECT points.id, frames.name, points.hdf5_x, points.hdf5_y
    FROM points_{layer_config['orbit']} points
    JOIN frames ON (frames.id = frame_id)
    WHERE ST_WITHIN(points.geometry, :region_polygon)
        AND frame_id IN :frame_ids
        AND points.id > :highest_id
    ORDER BY points.id
    LIMIT {BATCH_COUNT}
    """).bindparams(bindparam('frame_ids', expanding=True))

    frame_ids = [frame.id for frame in frames]
    # str() matches how the polygon was previously formatted into the SQL text.
    region_polygon = str(region.region_polygon)

    # Loop instead of recursing: the recursive version kept every earlier
    # batch alive on the call stack until the last page was read, and failed
    # once the page count exceeded Python's recursion limit.
    while True:
        rows = db_session.execute(statement, {
            'region_polygon': region_polygon,
            'frame_ids': frame_ids,
            'highest_id': highest_id,
        }).fetchall()

        writer.writerows(
            (point_id, layer, frame_name, hdf5_x, hdf5_y)
            for point_id, frame_name, hdf5_x, hdf5_y in rows
        )
        count += len(rows)

        # A short page means the result set is exhausted.
        if len(rows) < BATCH_COUNT:
            break
        highest_id = rows[-1][0]

    output(satbot, f"   Points: {count}")
    return count

def main(satbot=None):
    """Export the points of the requested layers that lie inside a region to CSV.

    Reads ``layers`` and ``region`` from ``satbot.decoded_parameters``.
    It then writes ``<satbot.working_directory>/<satbot.file_name>``,
    replacing any existing file: a header row, then the points of each
    layer as written by ``retrieve_batch``. Progress, per-layer frame
    counts and point counts are reported through ``output``. If the
    parameters are unusable, it reports 'Invalid data provided.' and
    returns.

    Args:
        satbot: the satbot being run. Despite the ``None`` default it is
            required: ``working_directory`` and ``file_name`` are read before
            any validation.

    NOTE(review): ``start_date`` / ``end_date`` may be present in the
    parameters but were never applied to any query; they are not read here.
    """
    file_path = f"{satbot.working_directory}/{satbot.file_name}"
    Path(file_path).unlink(missing_ok=True)

    try:
        layers = satbot.decoded_parameters.get('layers')
        region_id = satbot.decoded_parameters.get('region')
    except (KeyError, AttributeError):
        output(satbot, 'Invalid data provided.')
        return

    # .get() returns None instead of raising KeyError, so missing keys must
    # be checked explicitly rather than crashing later in the export.
    if layers is None or region_id is None:
        output(satbot, 'Invalid data provided.')
        return

    # newline='' is what the csv module requires; otherwise the writer's
    # '\r\n' terminator turns into '\r\r\n' where '\n' is translated.
    with open(file_path, 'w', encoding="utf-8", newline='') as csv_file:
        writer = csv.writer(csv_file)
        writer.writerow(['id', 'layer', 'frame', 'row', 'col'])

        with DB_SESSION['users'].begin() as db_session:
            region = db_session.query(Region).filter(Region.id == region_id).one()

        # NOTE(review): region is used after its session has closed; this
        # relies on its attributes staying loaded (e.g. expire_on_commit=False
        # on DB_SESSION['users']) — confirm.
        output(satbot, f"Filtering by Region {region.id}: \"{region.name}\"\n{region.polygon_geojson}")

        for i, layer in enumerate(layers):
            layer_config = CONFIG['tiles']['layers'][layer]
            output(satbot, f"---------------------------------------\n{layer_config['group']} {layer_config['name']}", progress=i/len(layers)*100)

            with DB_SESSION[layer_config['database']].begin() as db_session:
                # Bound parameters rather than f-string interpolation: `layer`
                # comes from the request. str() keeps the previous formatting.
                frames = db_session.execute(
                    text("SELECT id FROM frames WHERE layer = :layer AND ST_INTERSECTS(geometry, :region_polygon)"),
                    {'layer': str(layer), 'region_polygon': str(region.region_polygon)},
                ).fetchall()
                output(satbot, f"   Frames: {len(frames)}")

                if frames:
                    retrieve_batch(db_session, layer, layer_config, region, frames, satbot, writer)
                else:
                    output(satbot, '   Points: 0')
