from datetime import datetime
import json
import numpy
from sqlalchemy import Column, Integer, ARRAY, Table, String, ForeignKey, TIMESTAMP, Float, func, JSON, select
from sqlalchemy.orm import relationship
from geoalchemy2 import Geometry
from geoalchemy2.comparator import Comparator
from geoalchemy2.shape import from_shape, to_shape
from shapely import from_wkb, to_geojson, wkb
from geopandas import GeoDataFrame

from ...services.custom_sqlalchemy_types import UtcNow
from .._mapper import schema
from .custom_group import CustomGroup
from setup import CONFIG, DB_SESSION

from main.services.displacements_helper import DisplacementsHelper
from main.services.geometry_helper import GeometryHelper

class CustomPolygon:
    """A user-drawn polygon belonging to a ``CustomGroup``, with stored metrics.

    The class is mapped imperatively onto the ``custom_polygons`` table defined
    in this module. The geometry is persisted in ``_geometry_wkb`` (a GeoAlchemy2
    WKB element with SRID 4326) and exposed as a Shapely object via ``geometry``.
    """

    def __init__(self, custom_group_id, reference, geometry):
        """Build an unsaved polygon.

        :param custom_group_id: id of the owning ``CustomGroup`` row.
        :param reference: free-text reference for the polygon (nullable column).
        :param geometry: Shapely geometry in EPSG:4326; a falsy value leaves the
            geometry unset (``validates`` then reports it as missing).
        """
        self.custom_group_id = custom_group_id
        self.reference = reference
        self._geometry_wkb = from_shape(geometry, srid=4326) if geometry else None
        # Take a single timestamp so created_at and updated_at agree exactly.
        # NOTE(review): datetime.now() is naive *local* time while the columns'
        # server_default is UtcNow(); confirm the service runs with TZ=UTC.
        now = datetime.now()
        self.created_at = now
        self.updated_at = now

    # Shared helper instances: class attributes, created once when the class is defined.
    displacements_helper = DisplacementsHelper()
    geometry_helper = GeometryHelper()

    @property
    def to_dict_minimal(self):
        """Serialisable summary: id, group output, reference, GeoJSON geometry and area."""
        return {
            "id": self.id,
            "custom_group": self.custom_group.base_output,
            "reference": self.reference,
            "geometry": json.loads(to_geojson(self.geometry)),
            "area": self.area,
        }

    @property
    def to_dict(self):
        """``to_dict_minimal`` extended with the stored ``metrics`` JSON."""
        return self.to_dict_minimal | {"metrics": self.metrics}

    @property
    def geometry(self):
        """Shapely geometry decoded from ``_geometry_wkb``, or ``None`` when unset."""
        if self._geometry_wkb is None:
            # __init__ stores None for a missing geometry; without this guard
            # ``validates`` would crash on ``None.desc`` instead of reporting it.
            return None
        return from_wkb(self._geometry_wkb.desc)

    @property
    def area(self):
        """Area of ``geometry`` as computed by ``GeometryHelper.calculated_area``."""
        # Reuse the shared class-level helper instead of building one per access.
        return self.geometry_helper.calculated_area(self.geometry)

    @property
    def validates(self):
        """Return ``True`` if all required fields are set.

        Otherwise return an error message naming the first missing field.
        """
        for required_field in ('custom_group_id', 'reference', 'geometry'):
            if not getattr(self, required_field):
                return f"Incomplete data - missing {required_field}"
        return True

    @classmethod
    def within(cls, geometry, models, layers):
        """Return polygons lying within ``geometry``, serialised with ``to_dict``.

        :param geometry: Shapely Point or Polygon, encoded as EWKB with SRID 4326.
        :param models: collection whose first key/element selects the ``DB_SESSION``.
        :param layers: unused; kept so existing callers keep working.
        :return: at most ``CONFIG['polygons']['map_limit']`` dicts.
        """
        with DB_SESSION[next(iter(models))].begin() as db_session:
            query = select(cls).where(
                func.ST_Within(
                    cls.__table__.c._geometry_wkb,
                    wkb.dumps(geometry, srid=4326)
                )
            ).limit(CONFIG['polygons']['map_limit'])
            # ``to_dict`` is a property: calling it with arguments (as before)
            # raised TypeError because the resulting dict is not callable.
            # Serialise inside the session so the lazy attributes are still reachable.
            return [polygon.to_dict for polygon in db_session.execute(query).scalars().all()]

    def ungrouped_models(self, models):
        """Flatten ``{database: [model, ...]}`` into unique ``{database: model}`` dicts.

        The first-seen order is preserved.
        """
        layers = []
        for database, database_models in models.items():
            for model in database_models:
                entry = {database: model}
                if entry not in layers:
                    layers.append(entry)
        return layers

    def points(self, models, layers):
        """Query point data for this polygon's geometry via ``GeometryHelper.point_data_within``.

        :param models: ``{database: [model, ...]}``; flattened with ``ungrouped_models``.
        :param layers: layer configuration forwarded in the ``data`` filter.
        """
        return self.geometry_helper.point_data_within(
            self.geometry,
            tables=self.ungrouped_models(models),
            # NOTE(review): rmse_range [0, 6] is a hard-coded filter whose meaning
            # is defined by point_data_within; confirm it should not be configurable.
            data={"layers": layers, "rmse_range": [0, 6]},
            limit=CONFIG['points']['default_limit'],
        )

    def points_separated_by_layer(self, points):
        """Group ``points`` into ``{point.frame.layer: [point, ...]}`` preserving order."""
        separated = {}
        for point in points:
            separated.setdefault(point.frame.layer, []).append(point)
        return separated

    def diff_calculations(self, diffs):
        """Summary statistics of ``diffs``, each rounded to 2 decimals.

        Keys: ``avg``, ``max``, ``min`` and percentiles ``p05``, ``p10``, ``p30``,
        ``p50``, ``p70``, ``p90``, ``p95``. Non-finite values (``inf`` from
        zero-distance pairs in ``time_period``, and ``nan``) are excluded from
        every statistic. Earlier code still fed ``inf`` into the percentiles.

        :return: ``{}`` when ``diffs`` is empty or has no finite values.
        """
        values = numpy.asarray(diffs, dtype=float)
        finite = values[numpy.isfinite(values)]
        if finite.size == 0:
            return {}

        output = {
            "avg": round(numpy.mean(finite), 2),
            "max": round(numpy.max(finite), 2),
            "min": round(numpy.min(finite), 2),
        }
        for percentile in (5, 10, 30, 50, 70, 90, 95):
            output[f"p{percentile:02d}"] = round(numpy.percentile(finite, percentile), 2)
        return output

    def distances(self, points):
        """Pairwise distance matrix between ``points``, labelled by point id on both axes.

        Geometries are declared as EPSG:4326 and projected to EPSG:3857 before
        measuring, so values are in EPSG:3857 units.
        NOTE(review): Web Mercator distances are scale-distorted with latitude;
        confirm that is acceptable here or switch to a local projection.
        """
        frame = GeoDataFrame(
            {'id': [p.id for p in points], 'geometry': [p.geometry for p in points]},
            crs="EPSG:4326",
        )
        projected = frame.to_crs('EPSG:3857')

        matrix = projected.geometry.apply(lambda geometry: projected.geometry.distance(geometry))
        matrix.index = projected['id']
        matrix.columns = projected['id']
        return matrix

    def time_period(self, period, layer_points):
        """Displacement statistics for one layer's points over ``period`` months.

        :param period: period length passed to ``point.disp_diff`` (6/12/18/24 from
            ``metrics_layers``).
        :param layer_points: the points of a single layer.
        :return: ``{}`` if no point has dates; otherwise a dict with these keys:
            ``avg_velocity``, ``diffs_disps``, ``diffs_abs``, ``diffs_mv``,
            ``diffs_mv_len``, ``points`` and a ``dates`` count/begin/end summary.
        """
        output = {}

        # (disp_diff, velocity, dates) per point id, computed once per point.
        disp_diffs = {point.id: point.disp_diff(period) for point in layer_points}
        distances = self.distances(layer_points)

        data = {
            "t_vels": [],
            "t_dispdiffs": [],
            "t_diffs_mv": [],
            "t_diffs_mv_distances": [],
            "dates": numpy.array([], dtype=int),
        }
        for point in layer_points:
            disp_diff, t_vel, t_dates = disp_diffs[point.id]
            # NOTE(review): .any() is also False for an all-zero array; if 0 can be a
            # valid date value, ``t_dates.size`` may be the intended "has dates" test.
            if t_dates.any():
                data["dates"] = numpy.append(data["dates"], t_dates)
                data["t_dispdiffs"].append(disp_diff)
                data["t_vels"].append(t_vel)

                # Pairwise displacement differences against every other point, and
                # the same difference normalised by the distance between the points.
                for paired_point in layer_points:
                    if paired_point != point:
                        paired_disp_diff, _, _ = disp_diffs[paired_point.id]
                        diff_mv = numpy.abs(paired_disp_diff - disp_diff)
                        data['t_diffs_mv'].append(diff_mv)
                        # Co-located points (distance 0) may yield inf/nan here;
                        # diff_calculations ignores non-finite values.
                        data['t_diffs_mv_distances'].append(diff_mv / distances[point.id][paired_point.id])

        if len(data['t_dispdiffs']) > 0:
            data['dates'] = set(data['dates'])
            output['avg_velocity'] = round(numpy.nanmean(data['t_vels']), 2)

            output['diffs_disps'] = self.diff_calculations(data['t_dispdiffs'])
            output['diffs_abs'] = self.diff_calculations(numpy.abs(data['t_dispdiffs']))
            output['diffs_mv'] = self.diff_calculations(data['t_diffs_mv'])
            output['diffs_mv_len'] = self.diff_calculations(data['t_diffs_mv_distances'])
            output['points'] = len(layer_points)
            output['dates'] = {"count": len(data['dates']), "begin": int(min(data['dates'])), "end": int(max(data['dates']))}

        return output

    def metrics_layers(self, points_separated):
        """Per-layer metrics: mean velocity plus ``time_period`` stats for 6/12/18/24 months.

        :param points_separated: ``{layer: [point, ...]}`` from ``points_separated_by_layer``.
        :return: a dict of the form below, one entry per layer:
            ``{layer: {"avg_velocity": ..., "06_months": ..., "12_months": ...,
            "18_months": ..., "24_months": ...}}``.
        """
        output = {}
        for layer_key, layer_points in points_separated.items():
            layer_output = {"avg_velocity": round(numpy.nanmean([x.velocity for x in layer_points]), 2)}
            for months in [6, 12, 18, 24]:
                layer_output[f"{months:02}_months"] = self.time_period(months, layer_points)
            output[layer_key] = layer_output
        return output

    def calculate_metrics(self, access_helper, database):
        """Compute per-layer metrics for the points of ``database`` inside this polygon.

        :param access_helper: provides
            ``tables(filters, grouped=..., check_authorised=...)``.
        :param database: database name; only tile layers configured with it are
            used to select tables.
        """
        layers = CONFIG['tiles']['layers']
        database_layers = {
            layer: layer_data
            for layer, layer_data in layers.items()
            if layer_data['database'] == database
        }

        filters = {"layers": database_layers}
        models = access_helper.tables(filters, grouped=True, check_authorised=False)

        # NOTE(review): the full ``layers`` config (not ``database_layers``) is
        # forwarded as the point filter; confirm this is intended.
        points = self.points(models, layers)
        return self.metrics_layers(self.points_separated_by_layer(points))

    def closest(self, location):
        """Build a SQL predicate selecting polygons that contain ``location``.

        :param location: Shapely geometry in EPSG:4326 (its ``wkt`` is used).
        :return: ``(predicate, None)``. The second element is always ``None``;
            presumably this is an optional ordering slot in a shared interface
            (confirm).
        """
        location_as_sql_func = func.ST_GeomFromEWKT("SRID=4326;" + location.wkt)
        query = func.ST_Contains(self.__table__.c._geometry_wkb, location_as_sql_func)
        return query, None

# Table backing CustomPolygon. The "geometry" column is exposed on the mapped
# class under the attribute key "_geometry_wkb", which the CustomPolygon.geometry
# property decodes into a Shapely object.
# NOTE(review): updated_at only has a server_default (no onupdate), so it is not
# refreshed automatically when a row is updated.
custom_polygons_table = Table(
    "custom_polygons",
    schema.metadata,
    Column("id", Integer, autoincrement=True, primary_key=True),
    Column(
        "custom_group_id",
        Integer,
        ForeignKey("custom_groups.id"),
        index=True,
        nullable=False,
    ),
    Column("reference", String, nullable=True),
    Column("metrics", JSON, nullable=True),
    Column(
        "geometry",
        Geometry(srid=4326, geometry_type="POLYGON"),
        nullable=False,
        key="_geometry_wkb",
    ),
    Column("created_at", TIMESTAMP(timezone=False), nullable=False, server_default=UtcNow()),
    Column("updated_at", TIMESTAMP(timezone=False), nullable=False, server_default=UtcNow()),
)

# Map the plain CustomPolygon class onto the table. The "custom_group"
# relationship follows the custom_group_id foreign key, is loaded with a join,
# and installs a "custom_polygons" backref on CustomGroup.
schema.map_imperatively(
    CustomPolygon,
    custom_polygons_table,
    properties=dict(
        custom_group=relationship(
            CustomGroup,
            lazy="joined",
            backref="custom_polygons",
        ),
    ),
)
