import copy
import io
import json
import math
import os
import zipfile
from pathlib import Path
from time import sleep
from typing import List

import requests
from bridgestyle.mapboxgl.fromgeostyler import convertGroup
from bridgestyle.qgis import togeostyler
from qgis.core import Qgis, QgsMessageLog, QgsProject
from qgis.PyQt.QtCore import QSettings, QThread, pyqtSignal, pyqtSlot
from threedi_mi_utils import bypass_max_path_limit

from rana_qgis_plugin.utils_lizard import import_from_geostyler

from .utils import (
    build_vrt,
    get_local_file_path,
    image_to_bytes,
    split_scenario_extent,
)
from .utils_api import (
    finish_file_upload,
    get_project_jobs,
    get_raster_file_link,
    get_raster_style_file,
    get_raster_style_upload_urls,
    get_tenant_file_descriptor,
    get_tenant_file_descriptor_view,
    get_tenant_file_url,
    get_tenant_project_file,
    get_vector_style_file,
    get_vector_style_upload_urls,
    map_result_to_file_name,
    request_raster_generate,
    start_file_upload,
)

# Size in bytes of each chunk read from streamed HTTP downloads (1 MiB).
CHUNK_SIZE = 2**20


class FileDownloadWorker(QThread):
    """Worker thread for downloading a project file and its QGIS style.

    The file is streamed to the local path returned by ``get_local_file_path``.
    For vector and raster files the ``qml.zip`` style archive is fetched as
    well and, when it is a valid zip, extracted next to the downloaded file.

    Signals:
        progress(int, str): download percentage (0-100) and a status label.
        finished(dict, dict, str): project, file and the local file path.
        failed(str): error message.
    """

    progress = pyqtSignal(int, str)
    finished = pyqtSignal(dict, dict, str)
    failed = pyqtSignal(str)

    def __init__(
        self,
        project: dict,
        file: dict,
    ):
        super().__init__()
        self.project = project
        self.file = file

    def _download(self, url: str, local_file_path: str) -> None:
        """Stream ``url`` into ``local_file_path``, emitting percentage progress.

        Raises:
            requests.exceptions.RequestException: on HTTP/connection errors.
            OSError: when the local file cannot be written.
        """
        with requests.get(url, stream=True) as response:
            response.raise_for_status()
            total_size = int(response.headers.get("content-length", 0))
            downloaded_size = 0
            previous_progress = -1
            with open(local_file_path, "wb") as file:
                for chunk in response.iter_content(chunk_size=CHUNK_SIZE):
                    file.write(chunk)
                    downloaded_size += len(chunk)
                    if total_size <= 0:
                        # No content-length header (e.g. chunked transfer):
                        # the percentage is unknown, so do not divide by zero.
                        continue
                    progress = int((downloaded_size / total_size) * 100)
                    if progress > previous_progress:
                        self.progress.emit(progress, "")
                        previous_progress = progress

    def _extract_style(self, descriptor_id, target_dir: str) -> None:
        """Fetch the ``qml.zip`` style of a vector/raster file and unpack it.

        Other data types are ignored. Empty responses and content that is not
        a zip archive are silently skipped.
        """
        data_type = self.file["data_type"]
        if data_type == "raster":
            qml_zip_content = get_raster_style_file(descriptor_id, "qml.zip")
        elif data_type == "vector":
            qml_zip_content = get_vector_style_file(descriptor_id, "qml.zip")
        else:
            return
        if not qml_zip_content:
            return
        stream = io.BytesIO(qml_zip_content)
        if zipfile.is_zipfile(stream):
            with zipfile.ZipFile(stream, "r") as zip_file:
                zip_file.extractall(target_dir)

    @pyqtSlot()
    def run(self):
        project_slug = self.project["slug"]
        path = self.file["id"]
        descriptor_id = self.file["descriptor_id"]
        try:
            # Resolving the URL and creating the target directory can fail as
            # well; keep them inside the handler so failures reach `failed`.
            url = get_tenant_file_url(self.project["id"], {"path": path})
            local_dir_structure, local_file_path = get_local_file_path(
                project_slug, path
            )
            os.makedirs(local_dir_structure, exist_ok=True)
            self._download(url, local_file_path)
            # Fetch and extract the QML style zip for vector and raster files
            self._extract_style(descriptor_id, local_dir_structure)
            self.finished.emit(self.project, self.file, local_file_path)
        except requests.exceptions.RequestException as e:
            self.failed.emit(f"Failed to download file: {str(e)}")
        except Exception as e:
            self.failed.emit(f"An error occurred: {str(e)}")


class FileUploadWorker(QThread):
    """Worker thread for uploading new (non-rana) files to a Rana project.

    Each local file is uploaded as ``online_dir + file name`` via a three-step
    flow: start the upload, PUT the bytes to the returned URL, finish the upload.

    Signals:
        progress(int, str): overall upload percentage and a status label.
        finished(dict): the project; see ``run`` for when it is emitted.
        conflict(): declared for conflict resolution; not emitted by this class.
        failed(str): error message for a file that could not be uploaded.
        warning(str): declared for callers; not emitted by this class.
    """

    progress = pyqtSignal(int, str)
    finished = pyqtSignal(dict)
    conflict = pyqtSignal()
    failed = pyqtSignal(str)
    warning = pyqtSignal(str)

    def __init__(self, project: dict, local_paths: list[Path], online_dir: str):
        super().__init__()
        self.project = project
        self.local_paths = local_paths
        self.online_dir = online_dir

    def handle_file_conflict(self, online_path):
        """Return True when ``online_path`` does not exist on the server yet.

        Emits ``failed`` and returns False when the file already exists.
        """
        server_file = get_tenant_project_file(self.project["id"], {"path": online_path})
        if server_file:
            self.failed.emit("File already exist on server.")
            return False
        return True  # Continue to upload

    @pyqtSlot()
    def run(self):
        if not self.local_paths:
            # Nothing to upload; also guards the per-file progress division below.
            self.finished.emit(self.project)
            return
        # For a single file, finished is only emitted if the upload was successful
        if len(self.local_paths) == 1:
            if self.upload_single_file(self.local_paths[0], 0, 100):
                self.finished.emit(self.project)
            return
        # For a multi upload we always emit finished
        progress_per_file = 100 // len(self.local_paths)
        for i, local_path in enumerate(self.local_paths):
            self.upload_single_file(
                local_path, i * progress_per_file, progress_per_file
            )
        self.finished.emit(self.project)

    def upload_single_file(
        self, local_path: Path, progress_start, progress_step
    ) -> bool:
        """Upload one local file into ``online_dir`` of the project.

        Args:
            local_path: file to upload; its name is appended to ``online_dir``.
            progress_start: overall progress value at which this file starts.
            progress_step: share of the overall progress assigned to this file.

        Returns:
            True when the upload completed; False otherwise, in which case
            ``failed`` has been emitted.
        """
        online_path = f"{self.online_dir}{local_path.name}"
        # Check if file exists locally before uploading
        if not local_path.exists():
            self.failed.emit(f"File not found: {local_path}")
            return False

        try:
            # The conflict check performs a server request, so it belongs
            # inside the error handling as well.
            if not self.handle_file_conflict(online_path):
                return False

            self.progress.emit(progress_start, "")
            # Step 1: POST request to initiate the upload
            upload_response = start_file_upload(
                self.project["id"], {"path": online_path}
            )
            if not upload_response:
                self.failed.emit("Failed to initiate file upload.")
                return False
            upload_url = upload_response["urls"][0]
            # Step 2: Upload the file to the upload_url
            self.progress.emit(int(0.2 * progress_step + progress_start), "")
            with open(local_path, "rb") as file:
                response = requests.put(upload_url, data=file)
                response.raise_for_status()
            # Step 3: Complete the upload
            self.progress.emit(int(0.8 * progress_step + progress_start), "")
            response = finish_file_upload(
                self.project["id"],
                upload_response,
            )
            if not response:
                self.failed.emit("Failed to complete file upload.")
                return False
            self.progress.emit(progress_start + progress_step, "")
        except Exception as e:
            self.failed.emit(f"Failed to upload file to Rana: {str(e)}")
            return False
        return True


class ExistingFileUploadWorker(FileUploadWorker):
    """Worker thread for re-uploading a file that already exists in the project.

    The local copy of ``file`` is uploaded back to its original directory.
    Before uploading, the server's ``last_modified`` value is compared with
    the value stored in QSettings under ``last_modified_key``. On a mismatch
    ``conflict`` is emitted and the thread waits until the receiver sets
    ``file_overwrite`` to something other than None; False aborts the upload.
    When ``finished`` fires, the server value read during the check is stored.
    """

    def __init__(self, project: dict, file: dict):
        local_file = Path(get_local_file_path(project["slug"], file["id"])[1])
        # Online directory including its trailing "/", or "" for root files.
        head, separator, _ = file["id"].rpartition("/")
        super().__init__(project, [local_file], head + separator)

        self.file_overwrite = False
        self.last_modified = None
        self.last_modified_key = f"{project['name']}/{file['id']}/last_modified"
        self.finished.connect(self._finish)

    def handle_file_conflict(self, online_path):
        """Return True when the upload may proceed, False when it must stop.

        Emits ``failed`` when the server file is missing or the user declines
        to overwrite a file that changed on the server.
        """
        local_last_modified = QSettings().value(self.last_modified_key)
        server_file = get_tenant_project_file(self.project["id"], {"path": online_path})
        if not server_file:
            self.failed.emit(
                "Failed to get file from server. Check if file has been moved or deleted."
            )
            return False
        self.last_modified = server_file["last_modified"]
        if self.last_modified != local_last_modified:
            # Reset to the "undecided" sentinel *before* notifying the receiver.
            # With the initial value False the wait loop would be skipped and
            # the upload aborted before the user could answer.
            # NOTE(review): assumes the conflict receiver always assigns
            # True/False, otherwise this thread keeps waiting.
            self.file_overwrite = None
            self.conflict.emit()
            while self.file_overwrite is None:
                self.msleep(100)
            if self.file_overwrite is False:
                self.failed.emit("File upload aborted.")
                return False
        return True  # Continue to upload

    def _finish(self):
        """Remember the server ``last_modified`` read during the conflict check."""
        # NOTE(review): this value was read *before* the upload; confirm the
        # server does not change last_modified on upload, otherwise the next
        # upload will report a spurious conflict.
        QSettings().setValue(self.last_modified_key, self.last_modified)


class VectorStyleWorker(QThread):
    """Worker thread for generating vector styling files and uploading them.

    Every QGIS layer whose source contains the file name gets its QML style
    saved locally. The layers are converted with bridgestyle into a Mapbox GL
    style (plus optional sprites), which is uploaded together with a zip of
    the QML files to the URLs returned by ``get_vector_style_upload_urls``.

    Signals:
        finished(str): success message; only emitted when every upload succeeded.
        failed(str): error message.
        warning(str): comma-separated conversion warnings.
    """

    finished = pyqtSignal(str)
    failed = pyqtSignal(str)
    warning = pyqtSignal(str)

    def __init__(
        self,
        project: dict,
        file: dict,
    ):
        super().__init__()
        self.project = project
        self.file = file

    def upload_to_s3(self, url: str, data, content_type: str) -> bool:
        """PUT ``data`` (str, bytes or a binary file object) to ``url``.

        Returns True on success. On a request error ``failed`` is emitted and
        False is returned so the caller can stop.
        """
        try:
            headers = {"Content-Type": content_type}
            response = requests.put(url, data=data, headers=headers)
            response.raise_for_status()
        except requests.exceptions.RequestException as e:
            self.failed.emit(f"Failed to upload file to S3: {str(e)}")
            return False
        return True

    def create_qml_zip(self, local_dir: str, zip_path: str) -> bool:
        """Create a zip of all ``.qml`` files below ``local_dir`` (recursively).

        Entries are stored relative to ``local_dir``. Returns True on success;
        on error ``failed`` is emitted and False is returned.
        """
        try:
            with zipfile.ZipFile(zip_path, "w") as zip_file:
                for root, _, files in os.walk(local_dir):
                    for file in files:
                        if file.endswith(".qml"):
                            file_path = os.path.join(root, file)
                            zip_file.write(
                                file_path, os.path.relpath(file_path, local_dir)
                            )
        except Exception as e:
            self.failed.emit(f"Failed to create QML zip: {str(e)}")
            return False
        return True

    @pyqtSlot()
    def run(self):
        if not self.file:
            self.failed.emit("File not found.")
            return
        path = self.file["id"]
        file_name = os.path.basename(path.rstrip("/"))
        descriptor_id = self.file["descriptor_id"]
        all_layers = QgsProject.instance().mapLayers().values()
        # NOTE(review): substring match; a layer whose source merely contains
        # file_name (e.g. "a.gpkg" inside "data.gpkg") is also selected.
        layers = [layer for layer in all_layers if file_name in layer.source()]

        if not layers:
            self.failed.emit(
                f"No layers found for {file_name}. Open the file in QGIS and try again."
            )
            return

        qgis_layers = {layer.name(): layer for layer in layers}
        group = {"layers": list(qgis_layers.keys())}
        base_url = "http://baseUrl"

        try:
            # Save QML style files for each layer to local directory
            local_dir, _ = get_local_file_path(self.project["slug"], path)
            os.makedirs(local_dir, exist_ok=True)
            for layer in layers:
                qml_path = os.path.join(local_dir, f"{layer.name()}.qml")
                layer.saveNamedStyle(str(qml_path))

            # Convert QGIS layers to styling files for the Rana Web Client
            _, warnings, mb_style, sprite_sheet = convertGroup(
                group, qgis_layers, base_url, workspace="workspace", name="default"
            )
            if warnings:
                self.warning.emit(", ".join(set(warnings)))

            # Get upload URLs to S3
            upload_urls = get_vector_style_upload_urls(descriptor_id)
            if not upload_urls:
                self.failed.emit("Failed to get vector style upload URLs from the API.")
                return

            # Upload style.json; every upload below stops the run on failure so
            # that `finished` is never emitted after `failed`.
            if not self.upload_to_s3(
                upload_urls["style.json"],
                json.dumps(mb_style).replace(r"\\n", r"\n").replace(r"\\t", r"\t"),
                "application/json",
            ):
                return

            # Upload sprite images if available (stops at the first failure)
            if sprite_sheet and sprite_sheet.get("img") and sprite_sheet.get("img2x"):
                sprites_uploaded = (
                    self.upload_to_s3(
                        upload_urls["sprite.png"],
                        image_to_bytes(sprite_sheet["img"]),
                        "image/png",
                    )
                    and self.upload_to_s3(
                        upload_urls["sprite@2x.png"],
                        image_to_bytes(sprite_sheet["img2x"]),
                        "image/png",
                    )
                    and self.upload_to_s3(
                        upload_urls["sprite.json"],
                        sprite_sheet["json"],
                        "application/json",
                    )
                    and self.upload_to_s3(
                        upload_urls["sprite@2x.json"],
                        sprite_sheet["json2x"],
                        "application/json",
                    )
                )
                if not sprites_uploaded:
                    return

            # Zip and upload the QML files; the zip is only a transport
            # artefact, so remove it on every path.
            zip_path = os.path.join(local_dir, "qml.zip")
            try:
                if not self.create_qml_zip(local_dir, zip_path):
                    return
                with open(zip_path, "rb") as file:
                    if not self.upload_to_s3(
                        upload_urls["qml.zip"], file, "application/zip"
                    ):
                        return
            finally:
                if os.path.exists(zip_path):
                    os.remove(zip_path)

            # Finish
            self.finished.emit(f"Styling files uploaded successfully for {file_name}.")
        except Exception as e:
            self.failed.emit(f"Failed to generate and upload styling files: {str(e)}")


class RasterStyleWorker(QThread):
    """Worker thread for generating and uploading raster styling files.

    Exactly one QGIS layer must match the file. Its style is saved as QML and
    zipped, converted via GeoStyler into a Lizard colormap (``colormap.json``),
    and both files are passed to ``get_raster_style_upload_urls``. The zip and
    the colormap are temporary and are removed afterwards; the QML is kept.

    Signals:
        finished(str): success message.
        failed(str): error message.
        warning(str): comma-separated conversion warnings.
    """

    finished = pyqtSignal(str)
    failed = pyqtSignal(str)
    warning = pyqtSignal(str)

    def __init__(
        self,
        project: dict,
        file: dict,
    ):
        super().__init__()
        self.project = project
        self.file = file

    @staticmethod
    def _drop_infinite_labels(lizard_styling: dict, warnings: list) -> None:
        """Remove label entries with an infinite quantity from ``lizard_styling``.

        Works in place; a message is appended to ``warnings`` for each dropped label.
        """
        labels = copy.deepcopy(lizard_styling.get("labels", {}))
        for language, ranges in labels.items():
            new_labels = []
            for quantity, label in ranges:
                if math.isinf(quantity):
                    warnings.append(
                        f"Label '{label}' with infinite quantity cannot be used and will be ignored."
                    )
                else:
                    new_labels.append([quantity, label])
            lizard_styling["labels"][language] = new_labels

    @pyqtSlot()
    def run(self):
        if not self.file:
            self.failed.emit("File not found.")
            return
        path = self.file["id"]
        file_name = os.path.basename(path.rstrip("/"))
        descriptor_id = self.file["descriptor_id"]
        all_layers = QgsProject.instance().mapLayers().values()
        layers = [layer for layer in all_layers if file_name in layer.source()]

        if not layers:
            self.failed.emit(
                f"No layers found for {file_name}. Open the file in QGIS and try again."
            )
            return

        if len(layers) != 1:
            self.failed.emit(
                f"Multiple layers found for {file_name}. Open the file in QGIS and try again."
            )
            return

        layer = layers[0]

        try:
            # Save the layer's QML style file to the local directory (kept)
            local_dir, _ = get_local_file_path(self.project["slug"], path)
            os.makedirs(local_dir, exist_ok=True)
            qml_file_name = os.path.splitext(layer.name())[0] + ".qml"
            qml_path = os.path.join(local_dir, qml_file_name)
            layer.saveNamedStyle(str(qml_path))

            zip_path = os.path.join(local_dir, "qml.zip")
            lizard_styling_path = os.path.join(local_dir, "colormap.json")
            # Temporary upload artefacts, removed on every path (including
            # early returns and errors) by the `finally` below.
            temporary_files = []
            try:
                temporary_files.append(zip_path)
                with zipfile.ZipFile(zip_path, "w") as zipf:
                    zipf.write(qml_path, qml_file_name)

                # Raster styling to geostyler, and then to lizard styling
                geostyler, _, _, warnings = togeostyler.convert(layer)
                rules = geostyler["rules"]
                if len(rules) != 1:
                    self.failed.emit(f"Multiple rules found for {file_name}.")
                    return
                symbolizers = rules[0]["symbolizers"]
                if len(symbolizers) != 1:
                    self.failed.emit(f"Multiple symbolizers found for {file_name}.")
                    return

                lizard_styling = import_from_geostyler(symbolizers[0])

                # Do some corrections and checks
                self._drop_infinite_labels(lizard_styling, warnings)
                if lizard_styling["type"] == "DiscreteColormap" and any(
                    isinstance(entry, float) for entry, _ in lizard_styling["data"]
                ):
                    self.failed.emit(
                        "Failed to generate and upload styling files: DiscreteColormap cannot contain float quantities."
                    )
                    return

                if warnings:
                    self.warning.emit(", ".join(set(warnings)))

                temporary_files.append(lizard_styling_path)
                with open(lizard_styling_path, "w") as f:
                    json.dump(lizard_styling, f)

                files = [
                    ("files", "colormap.json", lizard_styling_path, "application/json"),
                    ("files", "qml.zip", zip_path, "application/zip"),
                ]
                # NOTE(review): the return value is ignored, so a failed upload
                # is only detected if this helper raises; confirm its contract.
                get_raster_style_upload_urls(descriptor_id, files)
            finally:
                for temporary_file in temporary_files:
                    if os.path.exists(temporary_file):
                        os.remove(temporary_file)

            self.finished.emit(f"Styling files uploaded successfully for {file_name}.")
        except Exception as e:
            self.failed.emit(f"Failed to generate and upload styling files: {str(e)}")


class LizardResultDownloadWorker(QThread):
    """Worker thread for downloading Lizard scenario results.

    Optionally downloads the raw results archive and unpacks it (plus an
    embedded log-file archive) into ``target_folder``. Then, for every
    requested result id, either downloads the result's ``attachment_url``
    directly, or requests raster generation (tiled when the extent is large),
    downloads the generated GeoTIFF(s) and merges multiple tiles into a VRT.

    Signals:
        progress(int, str): percentage (0-100) and the item being downloaded.
        finished(dict, dict, str): project, file and target folder; only
            emitted when every step succeeded.
        failed(str): error message; the worker stops after the first failure.
    """

    progress = pyqtSignal(int, str)
    finished = pyqtSignal(dict, dict, str)
    failed = pyqtSignal(str)

    def __init__(
        self,
        project: dict,
        file: dict,
        result_ids: List[int],
        target_folder: str,
        grid: dict,
        nodata: int,
        pixelsize: float,
        crs: str,
        download_raw: bool,
    ):
        super().__init__()
        self.project = project
        self.file = file
        self.result_ids = result_ids
        self.target_folder = target_folder
        self.grid = grid
        self.nodata = nodata
        self.pixelsize = pixelsize
        self.crs = crs
        self.download_raw = download_raw

    def _emit_failure(self, error: Exception) -> None:
        """Emit ``failed`` using the message format of this worker."""
        if isinstance(error, requests.exceptions.RequestException):
            self.failed.emit(f"Failed to download file: {str(error)}")
        else:
            self.failed.emit(f"An error occurred: {str(error)}")

    def _stream_to_file(
        self,
        url,
        target_file,
        label=None,
        offset=0,
        span=100,
        previous_progress=-1,
    ) -> int:
        """Stream ``url`` into ``target_file`` in CHUNK_SIZE chunks.

        When ``label`` is given, progress is emitted as
        ``int(offset + fraction_downloaded * span)`` whenever it increases.
        Progress is skipped when the server sends no content-length.

        Returns:
            The last emitted progress value (or ``previous_progress``).

        Raises:
            requests.exceptions.RequestException: on HTTP/connection errors.
            OSError: when the target file cannot be written.
        """
        with requests.get(url, stream=True) as response:
            response.raise_for_status()
            total_size = int(response.headers.get("content-length", 0))
            downloaded_size = 0
            with open(target_file, "wb") as file:
                for chunk in response.iter_content(chunk_size=CHUNK_SIZE):
                    file.write(chunk)
                    downloaded_size += len(chunk)
                    # A missing content-length would otherwise divide by zero.
                    if label is None or total_size <= 0:
                        continue
                    progress = int(offset + (downloaded_size / total_size) * span)
                    if progress > previous_progress:
                        self.progress.emit(progress, label)
                        previous_progress = progress
        return previous_progress

    def _extract_log_files(self) -> None:
        """Unpack the log-file sub-archive contained in the raw results, if any."""
        descriptor = get_tenant_file_descriptor(self.file["descriptor_id"])
        try:
            sim_id = descriptor["meta"]["simulation"]["id"]
        except (KeyError, TypeError):
            QgsMessageLog.logMessage(
                "Subarchive info missing, ignoring.",
                level=Qgis.MessageLevel.Critical,
            )
            return
        log_zip_path = os.path.join(self.target_folder, f"log_files_sim_{sim_id}.zip")
        if os.path.isfile(log_zip_path):
            with zipfile.ZipFile(log_zip_path, "r") as log_zip_ref:
                log_zip_ref.extractall(self.target_folder)
        else:
            QgsMessageLog.logMessage(
                "Subarchive containing log files not present, ignoring.",
                level=Qgis.MessageLevel.Warning,
            )

    def _download_raw(self) -> bool:
        """Download the raw results archive and unpack it into target_folder."""
        try:
            path = self.file["id"]
            url = get_tenant_file_url(self.project["id"], {"path": path})
            local_dir_structure, local_file_path = get_local_file_path(
                self.project["slug"], path
            )
            os.makedirs(local_dir_structure, exist_ok=True)
            self._stream_to_file(url, local_file_path, "Downloading raw data")
            # unzip the raw results in the working directory
            with zipfile.ZipFile(local_file_path, "r") as zip_ref:
                zip_ref.extractall(self.target_folder)
            # a nested zip may contain log files; extract those as well
            self._extract_log_files()
        except Exception as e:
            self._emit_failure(e)
            return False
        return True

    def _download_attachment(self, result: dict, file_name: str) -> bool:
        """Download a result raster that Rana can serve directly."""
        target_file = bypass_max_path_limit(os.path.join(self.target_folder, file_name))
        try:
            self._stream_to_file(result["attachment_url"], target_file, file_name)
        except Exception as e:
            self._emit_failure(e)
            return False
        return True

    def _request_raster_tiles(self, descriptor_id, result: dict, file_name: str):
        """Start one async raster-generate task per tile of the scenario extent.

        Progress covers 0-10%. Returns the values returned by
        ``request_raster_generate`` (used as task ids) and the last progress.
        """
        previous_progress = -1
        bboxes, width, height = split_scenario_extent(
            grid=self.grid, resolution=self.pixelsize, max_pixel_count=1 * 10**8
        )
        raster_tasks = []
        for counter, (x1, y1, x2, y2) in enumerate(bboxes, start=1):
            payload = {
                "width": width,
                "height": height,
                "bbox": f"{x1},{y1},{x2},{y2}",
                "projection": self.crs,
                "format": "geotiff",
                "async": "true",
            }
            if self.nodata is not None:
                payload["nodata"] = self.nodata
            raster_tasks.append(
                request_raster_generate(
                    descriptor_id=descriptor_id,
                    raster_id=result["raster_id"],
                    payload=payload,
                )
            )
            progress = int((counter / len(bboxes)) * 10)
            if progress > previous_progress:
                self.progress.emit(progress, file_name)
                previous_progress = progress
        return raster_tasks, previous_progress

    def _download_tiles(
        self, descriptor_id, raster_tasks, file_name: str, previous_progress: int
    ) -> bool:
        """Poll tiled generate tasks, download finished tiles and build a VRT.

        Tiles cover 10-90% of the progress; 100% is emitted once the VRT is built.
        """
        rasters = {
            raster_task_id: {
                "downloaded": False,
                "filepath": bypass_max_path_limit(
                    os.path.join(
                        self.target_folder,
                        f"{file_name}{task_number:02d}.tif",
                    )
                ),
            }
            for task_number, raster_task_id in enumerate(raster_tasks)
        }
        task_counter = 0
        try:
            while not all(raster["downloaded"] for raster in rasters.values()):
                # wait between each repoll of task statuses
                sleep(5)
                for raster_task_id, raster in rasters.items():
                    if raster["downloaded"]:
                        continue
                    file_link = get_raster_file_link(
                        descriptor_id=descriptor_id,
                        task_id=raster_task_id,
                    )
                    if not file_link:
                        continue
                    self._stream_to_file(file_link, raster["filepath"])
                    raster["downloaded"] = True
                    task_counter += 1
                    # reserve last 10% of progress for raster merging
                    progress = int(10 + (task_counter / len(raster_tasks)) * 80)
                    if progress > previous_progress:
                        self.progress.emit(progress, file_name)
                        previous_progress = progress

            raster_filepaths = sorted(raster["filepath"] for raster in rasters.values())
            # Derive the VRT name from the first tile's file name only, so a
            # "_01" or ".tif" inside target_folder cannot corrupt the directory.
            # NOTE(review): tiles are named f"{file_name}NN.tif", so "_01" only
            # matches if file_name itself contains it; confirm the intended name.
            vrt_dir, first_raster_name = os.path.split(raster_filepaths[0])
            vrt_filepath = os.path.join(
                vrt_dir,
                first_raster_name.replace("_01", "").replace(".tif", ".vrt"),
            )
            vrt_options = {
                "resolution": "average",
                "resampleAlg": "nearest",
                "srcNodata": self.nodata,
            }
            build_vrt(vrt_filepath, raster_filepaths, **vrt_options)
        except Exception as e:
            self._emit_failure(e)
            return False
        self.progress.emit(100, file_name)
        return True

    def _download_single_tile(
        self, descriptor_id, task_id, file_name: str, previous_progress: int
    ) -> bool:
        """Poll a single generate task and download its GeoTIFF (10-100%)."""
        target_file = bypass_max_path_limit(
            os.path.join(self.target_folder, (file_name + ".tif"))
        )
        try:
            while True:
                sleep(5)
                file_link = get_raster_file_link(
                    descriptor_id=descriptor_id, task_id=task_id
                )
                if file_link:
                    break
            self._stream_to_file(
                file_link,
                target_file,
                file_name,
                offset=10,
                span=90,
                previous_progress=previous_progress,
            )
        except Exception as e:
            self._emit_failure(e)
            return False
        return True

    def _generate_and_download(self, descriptor_id, result: dict, file_name: str) -> bool:
        """Request raster generation for a result and download the output."""
        try:
            raster_tasks, previous_progress = self._request_raster_tiles(
                descriptor_id, result, file_name
            )
        except Exception as e:
            self._emit_failure(e)
            return False
        # multi-tile raster download
        if len(raster_tasks) > 1:
            return self._download_tiles(
                descriptor_id, raster_tasks, file_name, previous_progress
            )
        if not raster_tasks:
            self.failed.emit(
                f"An error occurred: no raster tiles requested for {file_name}."
            )
            return False
        # single-tile raster download
        return self._download_single_tile(
            descriptor_id, raster_tasks[0], file_name, previous_progress
        )

    @pyqtSlot()
    def run(self):
        if self.download_raw and not self._download_raw():
            return

        descriptor_id = self.file["descriptor_id"]
        for result_id in self.result_ids:
            # Retrieve URLS from file descriptors (again), presigned url might be expired
            results = get_tenant_file_descriptor_view(
                descriptor_id, "lizard-scenario-results"
            )
            result = next((r for r in results or [] if r["id"] == result_id), None)
            if result is None:
                self.failed.emit(f"An error occurred: result {result_id} not found.")
                return
            file_name = map_result_to_file_name(result)
            if result["attachment_url"]:
                # raster can be downloaded directly from rana
                succeeded = self._download_attachment(result, file_name)
            else:
                # raster first needs to be generated
                succeeded = self._generate_and_download(descriptor_id, result, file_name)
            if not succeeded:
                # Stop at the first failure instead of processing further results.
                return

        self.finished.emit(self.project, self.file, self.target_folder)


class ProjectJobMonitorWorker(QThread):
    """Worker thread that polls the jobs of a project about once per second.

    Signals:
        failed(str): emitted when a poll raises; the worker then stops.
        jobs_added(list): jobs not seen before (emitted every poll, possibly empty).
        job_updated(dict): a known job whose "state" or "process" changed.
    """

    failed = pyqtSignal(str)
    jobs_added = pyqtSignal(list)
    job_updated = pyqtSignal(dict)

    def __init__(self, project_id, parent=None):
        super().__init__(parent)
        self.active_jobs = {}  # job id -> last seen job dict
        self.project_id = project_id
        self._stop_flag = False

    def run(self):
        # The first iteration initializes active_jobs; no separate priming call
        # is needed (it would just repeat the same request back-to-back).
        while not self._stop_flag:
            try:
                self.update_jobs()
            except Exception as e:
                # Without this the thread would end silently on an error.
                self.failed.emit(f"Failed to update project jobs: {str(e)}")
                return
            # Process contains a single api call, so every second should be fine
            QThread.sleep(1)

    def stop(self):
        """Gracefully stop the worker and block until the thread has finished."""
        self._stop_flag = True
        self.wait()

    def update_jobs(self):
        """Fetch the project's jobs once and emit added/updated jobs."""
        response = get_project_jobs(self.project_id)
        if not response:
            return
        current_jobs = response["items"]
        new_jobs = {
            job["id"]: job for job in current_jobs if job["id"] not in self.active_jobs
        }
        self.jobs_added.emit(list(new_jobs.values()))
        self.active_jobs.update(new_jobs)
        for job in current_jobs:
            if job["id"] in new_jobs:
                # new job cannot be updated
                continue
            known_job = self.active_jobs[job["id"]]
            if (
                job["state"] != known_job["state"]
                or job["process"] != known_job["process"]
            ):
                self.job_updated.emit(job)
                self.active_jobs[job["id"]] = job
