#!/usr/bin/env python3
# -*- coding: utf-8 -*-

from time import sleep
import shlex
from getpass import getpass
import logging
import logging.handlers
import os
import sys
import json
import urllib
from urllib import request
import ssl
import certifi
import urllib.parse
import urllib.error
import pysftp
import argparse
from typing import List, Callable, Optional
from .patchdiff import Database, DiffList, FileInfo
from . import helpers, aria, version_check
from .TransferProgress import PatcherProgress
from distutils.version import LooseVersion
from io import StringIO

# Default settings. load_config() overrides BUILD_DIR, URL, FTP_DIR, FTP_USER,
# FTP_PORT, ARIA_DIR and WAIT_FOR_PROCESS with values from CONFIG_FILE when
# that file exists.

# Local directory holding the patched files; load_config() rejects absolute paths.
BUILD_DIR = "build"
# Base URL of the patch server; load_config() fails while this is empty.
# upload() also takes the SFTP host name from it.
URL = ""
# Remote directory upload() changes into before transferring; required for uploads.
FTP_DIR = ""
# SFTP user name; upload() prompts for one when this is empty.
FTP_USER = ""
# SFTP port.
FTP_PORT = 22
# Path (joined onto BUILD_DIR) of an executable; wait_for_process() blocks
# while helpers.find_process() reports it. Empty disables the check.
WAIT_FOR_PROCESS = ""

# These must be relative paths
# Directory appended to PATH by load_config().
ARIA_DIR = "aria2"
CONFIG_FILE = "config.json"
# A retrievable lock file on the server means an upload is in progress.
LOCK_FILE = "upload.lock"
PATCHER_LOG_FILE = "patcher_log.txt"
PATCHER_LOG_FILE_OLD = "patcher_log.old.txt"
INDEX_FILE = "index.txt"
PATCHER_VERSION_FILE = "patcher_version.json"
# Passed as `mode` to pysftp's makedirs(); presumably read as the octal digits
# 755 rather than decimal 755 -- confirm against the pysftp documentation.
REMOTE_PERMISSIONS = 755

# Base arguments handed to AriaDownloader.start(); load_config() may append a
# download rate limit.
ARIA_CMD = [
    "--no-conf=true",
    "--allow-overwrite=true",
    "--auto-file-renaming=false",
    "--optimize-concurrent-downloads=true",
    "--remote-time=true",
    "--console-log-level=warn",
    "--download-result=hide",
    "--summary-interval=0",

    # Always start from scratch, to prevent files being continued with a newer
    # upstream version.
    "--continue=false",
    "--remove-control-file=true",
    "--always-resume=false",
    "--auto-save-interval=0",
]


def load_config() -> bool:
    """Apply the settings from CONFIG_FILE (if it exists) and validate them.

    Overwrites the module-level settings with the values found in the JSON
    config file; keys that are missing keep their current value. A non-zero
    "limit" entry is appended to ARIA_CMD as an overall download limit, and
    a non-empty ARIA_DIR is appended to the PATH environment variable.

    Returns:
        False when the build directory is an absolute path or no URL is
        configured (an error is logged in both cases), True otherwise.
    """
    global BUILD_DIR, FTP_DIR, FTP_USER, FTP_PORT, URL, ARIA_DIR, WAIT_FOR_PROCESS

    if os.path.isfile(CONFIG_FILE):
        with open(CONFIG_FILE, "r") as config_fp:
            settings = json.load(config_fp)

        # Every key is optional; fall back to the value already in place.
        BUILD_DIR = settings.get("build", BUILD_DIR)
        FTP_DIR = settings.get("ftpdir", FTP_DIR)
        FTP_USER = settings.get("ftpuser", FTP_USER)
        FTP_PORT = settings.get("ftpport", FTP_PORT)
        URL = settings.get("url", URL)
        ARIA_DIR = settings.get("aria", ARIA_DIR)
        WAIT_FOR_PROCESS = settings.get("wait_for_process", WAIT_FOR_PROCESS)

        # Paths from the config are passed through helpers.host_path().
        BUILD_DIR = helpers.host_path(BUILD_DIR)
        ARIA_DIR = helpers.host_path(ARIA_DIR)

        rate_limit = settings.get("limit", 0)
        if rate_limit != 0:
            ARIA_CMD.append(f"--max-overall-download-limit={rate_limit}")

    # The port may have been given as a string in the config file.
    FTP_PORT = int(FTP_PORT)

    if os.path.isabs(BUILD_DIR):
        logging.error("Build directory path must be relative")
        return False

    if not URL:
        logging.error("No URL specified")
        return False

    if ARIA_DIR:
        aria_abs = os.path.abspath(ARIA_DIR)
        os.environ["PATH"] = os.environ["PATH"] + os.pathsep + aria_abs

    return True


def get_from_url(url: str, silent=False) -> Optional[str]:
    """Fetch ``url`` and return the response body as text.

    The body is decoded with the charset announced in the response headers,
    falling back to UTF-8. Certificates are verified against the certifi
    CA bundle.

    Args:
        url: Address to retrieve.
        silent: When true, an HTTP error is not logged.

    Returns:
        The decoded body, or None when the server answered with an HTTP
        error. Other exceptions (e.g. connection failures) are not caught
        here.
    """
    logging.debug("Retrieving %s", url)
    try:
        tls_context = ssl.create_default_context(cafile=certifi.where())
        with request.urlopen(url, context=tls_context) as response:
            payload = response.read()
            charset = response.headers.get_content_charset() or "utf-8"
            return payload.decode(charset)
    except urllib.error.HTTPError as http_err:
        if not silent:
            logging.error(str(http_err))
    return None


def get_server_url(relurl: str) -> str:
    """Return the full server URL for ``relurl``.

    ``relurl`` is percent-quoted and then joined onto the configured base
    URL via helpers.urljoin().
    """
    quoted_path = urllib.parse.quote(relurl)
    return helpers.urljoin(URL, quoted_path)


def get_from_server(relurl: str, silent=False) -> Optional[str]:
    """Fetch a file given by its path relative to the server base URL.

    Returns the decoded content, or None when get_from_url() returns None
    (HTTP error). ``silent`` is forwarded to get_from_url().
    """
    absolute_url = get_server_url(relurl)
    return get_from_url(absolute_url, silent)


def get_from_server_str(relurl: str, silent=False) -> Optional[str]:
    """Return the text of a server file; thin wrapper around get_from_server()."""
    content = get_from_server(relurl, silent)
    return content


def get_server_index() -> Optional[Database]:
    """Download the server's index file and parse it into a Database.

    Returns:
        A Database loaded from the remote index, or None when the index
        could not be retrieved (an error is logged in that case).
    """
    data = get_from_server_str(helpers.net_path(INDEX_FILE))
    if data is None:
        logging.error("Failed to get index file from server.")
        return None

    # The remote database is created with an empty path argument and filled
    # purely from the downloaded index text.
    remote = Database("")
    remote.load_from_string(data)
    return remote


def is_server_updating(silent) -> bool:
    """Report whether the server currently exposes the upload lock file.

    The lock file is requested quietly; being able to retrieve it means an
    upload is in progress. Unless ``silent`` is true, an error is logged
    when the lock is present.
    """
    lock_url_path = helpers.net_path(LOCK_FILE)
    if get_from_server(lock_url_path, True) is None:
        return False

    if not silent:
        logging.error("New patch is currently being uploaded.")
    return True


def wait_for_process():
    """Block until the configured executable is no longer running.

    Does nothing when WAIT_FOR_PROCESS is empty. Otherwise the user is
    prompted to press Enter for as long as helpers.find_process() still
    reports the executable.
    """
    if not WAIT_FOR_PROCESS:
        return

    # The configured path is taken relative to the build directory.
    relative_exe = helpers.host_path(WAIT_FOR_PROCESS)
    resolved_exe = os.path.realpath(os.path.join(BUILD_DIR, relative_exe))

    while helpers.find_process(resolved_exe):
        logging.info("%s is running. Please close it to continue.", WAIT_FOR_PROCESS)
        input("Press Enter to continue...")


def wait_for_server_updating(upload_mode: bool):
    """Wait while the server is locked by a running upload.

    Returns immediately when no lock is present. In upload mode the user is
    asked whether to ignore the lock: "YES" returns right away, "NO" falls
    through to waiting, anything else repeats the question. Waiting polls
    the server every five seconds until the lock disappears.
    """
    # The non-silent check makes sure the error message is shown at least once.
    if not is_server_updating(silent=False):
        return

    if upload_mode:
        while True:
            answer = input("Ignore and upload anyway? (YES/NO): ")
            if answer == "YES":
                logging.warning("Ignoring lock file and uploading anyway")
                return
            if answer == "NO":
                break

    logging.info("Waiting for update to complete...")
    while is_server_updating(silent=True):
        sleep(5)


def needs_update(diff: DiffList) -> bool:
    """Return True when ``diff`` lists at least one modified or removed entry."""
    has_modified = len(diff.modified) > 0
    has_removed = len(diff.removed) > 0
    return has_modified or has_removed


def get_full_file_list(infolist: List[FileInfo], sort: bool, reverse: bool = False) -> List[FileInfo]:
    """Return copies of the given entries with BUILD_DIR joined in front.

    Each result is a new FileInfo whose filename is
    ``os.path.join(BUILD_DIR, original_filename)``; checksum and size are
    carried over unchanged.

    Args:
        infolist: Entries to convert.
        sort: When true, the result is ordered by filename.
        reverse: Sort descending instead of ascending (only with ``sort``).
    """
    prefixed = []
    for info in infolist:
        full_name = os.path.join(BUILD_DIR, info.filename)
        prefixed.append(FileInfo(full_name, info.checksum, info.size))

    if not sort:
        return prefixed
    return sorted(prefixed, key=lambda entry: entry.filename, reverse=reverse)


def delete_old_files(diff: DiffList, remove_func: Callable, print_finished_files: bool = False):
    """Invoke ``remove_func`` once for every entry in ``diff.removed``.

    Progress is reported through a PatcherProgress bar without speed
    information. Nothing happens when there are no removed entries.

    Args:
        diff: Change set whose ``removed`` entries are processed.
        remove_func: Called with each FileInfo (filename prefixed with
            BUILD_DIR) that should be deleted.
        print_finished_files: Forwarded to the progress display.
    """
    obsolete = diff.removed
    if len(obsolete) == 0:
        return

    logging.info("Deleting old files...")
    with PatcherProgress(obsolete, BUILD_DIR) as progress:
        progress.no_speed_info = True
        progress.print_finished_files = print_finished_files

        # Process entries in reverse filename order so that files and
        # directories are handled bottom-up in the hierarchy. Otherwise
        # removing a directory would fail because it still has children.
        ordered = get_full_file_list(obsolete, sort=True, reverse=True)
        for entry in ordered:
            remove_func(entry)
            progress.mark_finished(entry.filename, False)
            progress.print_progress_bar()


def download(diff: DiffList, remote: Database, print_files=False) -> bool:
    """Bring the local build directory in line with the server.

    Downloads every modified file through aria, creates modified
    directories locally, deletes removed entries, and finally stores the
    remote index as the new local index file.

    Args:
        diff: Changes to apply (``modified`` and ``removed`` entries).
        remote: Server index fetched before the download; saved to
            INDEX_FILE afterwards.
        print_files: Print each finished file in the progress display.

    Returns:
        False when the server turns out to be locked by a new upload after
        the download (a warning is logged), True otherwise.
    """
    downloader = aria.AriaDownloader()

    if len(diff.modified) > 0:
        # NOTE: shuffling the file list to mix small files between larger
        # ones was tried and did not help: there are so many small files
        # that throughput stays permanently starved at 1-2 MiB/s.

        # Directories are created here; for files aria creates missing
        # parent directories automatically.
        for entry in get_full_file_list(diff.modified, sort=False):
            if not entry.is_directory():
                downloader.add_url(get_server_url(entry.net_filename),
                                   os.path.dirname(entry.filename))
                continue
            os.makedirs(entry.filename, exist_ok=True)

        with PatcherProgress(diff.modified, BUILD_DIR) as progress:
            progress.print_finished_files = print_files
            downloader.start(ARIA_CMD, progress)

    def _remove_local(info: FileInfo):
        # Entries that no longer exist locally are silently skipped.
        path = info.filename
        if os.path.isdir(path):
            os.removedirs(path)
        elif os.path.isfile(path):
            os.remove(path)

    delete_old_files(diff, _remove_local, print_files)

    # The local files are now assumed to be at least as new as the target
    # version. If the server was updated during the download, or is being
    # updated right now, some files might already be from the newer version,
    # so it should be safe to replace the local index with the remote index
    # fetched when the download started.
    # NOTE(review): no result of start() is checked here; presumably a
    # failed download raises instead -- confirm against AriaDownloader.
    logging.debug("Updating index file")
    remote.save(INDEX_FILE)

    # A server update that started during the download but has not finished.
    if not is_server_updating(silent=False):
        return True

    logging.warning("The upload of a new patch started during the download process.\nYour game files may be incomplete.")
    return False


def upload(diff: DiffList, print_files=True, compress=False) -> bool:  # pylint: disable=too-many-statements
    """Push the changes in ``diff`` to the server over SFTP.

    Connects to the host taken from URL, changes into FTP_DIR, creates the
    lock file, uploads modified entries, deletes removed entries, uploads
    the index file and finally removes the lock file again.

    Args:
        diff: Changes to apply on the server.
        print_files: Print each finished file in the progress display.
        compress: Enable compression on the SFTP connection.

    Returns:
        False when FTP_DIR is empty (nothing is attempted), True after a
        completed upload.

    Raises:
        Any exception raised during the transfer is re-raised after a
        warning is logged; the lock file is left on the server in that case.
    """
    if not FTP_DIR:
        logging.error("Empty FTP directory path")
        return False

    if not FTP_USER:
        username = input("Username: ")
    else:
        logging.info("Username: %s", FTP_USER)
        username = FTP_USER

    hostname = urllib.parse.urlsplit(URL).netloc
    logging.info("Connecting to '%s' as user '%s' at port %i...", hostname, username, FTP_PORT)
    remote_lock = helpers.net_path(LOCK_FILE)

    # Use a known_hosts file from the working directory when there is one.
    # NOTE(review): with an empty path pysftp presumably fails to load host
    # keys and skips host key verification -- confirm this is intended.
    if os.path.isfile("known_hosts"):
        known_hosts_path = os.path.join(os.getcwd(), "known_hosts")
    else:
        known_hosts_path = ""
    conn_opts = pysftp.CnOpts(known_hosts_path)
    conn_opts.compression = compress

    password = getpass("Password: ")
    with pysftp.Connection(hostname, username, password=password,
                           port=FTP_PORT, cnopts=conn_opts) as sftp:
        logging.info("Changing to '%s'", FTP_DIR)
        sftp.chdir(FTP_DIR)

        try:
            logging.debug("Creating lock file")
            sftp.putfo(StringIO(""), remote_lock)
            sftp.makedirs(BUILD_DIR, mode=REMOTE_PERMISSIONS)
            logging.debug("ls: %s", ", ".join(sftp.listdir()))

            if len(diff.modified) > 0:
                # Ascending filename order puts a directory ahead of its
                # contents, so the directory structure exists before the
                # files inside it are uploaded.
                pending = get_full_file_list(diff.modified, sort=True, reverse=False)

                logging.info("Uploading files...")
                with PatcherProgress(diff.modified, BUILD_DIR) as progress:
                    progress.print_finished_files = print_files
                    progress.print_progress_bar()

                    def _on_chunk(transferred: int, _total: int):
                        # Transfer callback: refresh the progress display.
                        progress.mark_in_progress([ transferred ])
                        progress.update_speed()
                        progress.print_progress_bar()

                    for entry in pending:
                        if entry.is_directory():
                            sftp.makedirs(entry.net_filename, REMOTE_PERMISSIONS)
                            continue

                        # Keep the original timestamps. Useful e.g. when
                        # moving to another server and re-uploading the
                        # project: otherwise every timestamp would become
                        # the upload date, which is undesirable especially
                        # for patch logs.
                        sftp.put(entry.filename, entry.net_filename,
                                 callback=_on_chunk, preserve_mtime=True)
                        progress.mark_finished(entry.filename, True)
                        progress.print_progress_bar()

            def _remove_remote(info: FileInfo):
                # Entries that no longer exist remotely are silently skipped.
                remote_path = info.net_filename
                if sftp.isdir(remote_path):
                    sftp.rmdir(remote_path)
                elif sftp.isfile(remote_path):
                    sftp.remove(remote_path)

            delete_old_files(diff, _remove_remote, print_files)

            logging.debug("Updating index file")
            sftp.put(helpers.net_path(INDEX_FILE))

            if not sftp.isfile(remote_lock):
                logging.error("There should be a lock file but is not!")
            else:
                logging.debug("Removing lock file")
                sftp.remove(remote_lock)

        except BaseException:
            # Covers KeyboardInterrupt as well; the error is only reported
            # here and then propagated to the caller.
            logging.warning("Upload aborted.\nLock file was not removed!\nClients cannot update until successful upload!")
            raise

    logging.info("Upload finished!")
    return True


def iter_files_progress(files: List[FileInfo]):
    """Yield each entry of ``files`` in order, logging a counter beforehand.

    Before an entry is yielded, a line of the form ``[index/total] filename``
    is logged at INFO level, with ``index`` starting at 1.
    """
    total = len(files)
    position = 0
    for entry in files:
        position += 1
        logging.info("[%i/%i] %s", position, total, entry.filename)
        yield entry


def validate_files() -> Database:
    """Rescan the build directory and rewrite the local index file.

    Creates a Database for BUILD_DIR, scans the directory with progress
    output, saves the result to INDEX_FILE and returns the Database.
    """
    scanned = Database(BUILD_DIR)
    scanned.scan_dir(print_progress=True)
    scanned.save(INDEX_FILE)
    return scanned


def setup_logger(level):
    """Set up file and console logging for the patcher.

    The previous log file is kept as PATCHER_LOG_FILE_OLD. The new log file
    receives everything from DEBUG upwards; the console (stdout) handler
    only shows records at ``level`` or above and prints INFO messages
    without a level prefix.
    """
    class MyFormatter(logging.Formatter):
        def format(self, record):
            # Plain text for INFO, "LEVEL: text" for everything else.
            if record.levelno != logging.INFO:
                self._style._fmt = "%(levelname)s: %(message)s"  # pylint: disable=protected-access
            else:
                self._style._fmt = "%(message)s"  # pylint: disable=protected-access
            return super().format(record)

    # Rotate the previous log file out of the way.
    have_current_log = os.path.isfile(PATCHER_LOG_FILE)
    if have_current_log and os.path.isfile(PATCHER_LOG_FILE_OLD):
        os.remove(PATCHER_LOG_FILE_OLD)  # windows can't rename if file already exists
    if have_current_log:
        os.rename(PATCHER_LOG_FILE, PATCHER_LOG_FILE_OLD)

    file_handler = logging.FileHandler(PATCHER_LOG_FILE, 'w', 'utf-8')
    logging.basicConfig(level=logging.DEBUG,
                        handlers=[file_handler],
                        format="%(levelname)s: %(message)s")

    # Keep the SFTP backend's debug output out of the logs.
    logging.getLogger("paramiko").setLevel(logging.INFO)

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(level)
    console_handler.setFormatter(MyFormatter())
    logging.getLogger().addHandler(console_handler)


class Runner:
    """Command-line driver for the patcher.

    Parses the command line, sets up logging and configuration, and then
    either downloads updates from the server or uploads a new patch.
    Subclasses can override ``_on_before_upload`` / ``_on_before_download``
    to hook in right before the respective transfer starts.
    """

    def __init__(self, program_name: str, version: str):
        """Store the program name and version shown at startup.

        Args:
            program_name: Name printed in the startup banner.
            version: Version string; also passed to the version check.
        """
        self._program_name = program_name
        self._version = LooseVersion(version)
        # Filled in by run() / _main().
        self._args: argparse.Namespace = None
        self._diff: DiffList = None
        self._local: Database = None
        self._remote: Database = None

    def _on_before_upload(self):
        """Hook called right before an upload starts; prints the summary."""
        self.print_summary()

    def _on_before_download(self):
        """Hook called right before a download starts; prints the summary."""
        self.print_summary()

    def print_summary(self):
        """Print a summary of the current diff; lists files in verbose mode."""
        self._diff.print_summary(list_files=self._args.verbose)

    def _main(self) -> int:
        """Run the update or upload workflow.

        Returns:
            Process exit code: 0 on success or when nothing had to be done,
            1 when the configuration is invalid, the server index cannot be
            fetched in download mode, or the upload could not be performed.
        """
        logging.info("%s - v%s", self._program_name, self._version)
        print()

        if not load_config():
            return 1

        if not version_check.check_version(PATCHER_VERSION_FILE, self._version, BUILD_DIR):
            return 0
        print()

        while True:
            args = self._args

            wait_for_server_updating(args.upload)
            wait_for_process()

            if args.verify or not os.path.isfile(INDEX_FILE) \
                    or (args.upload and not args.no_verify):
                self._local = validate_files()
            else:
                self._local = Database.load(INDEX_FILE)

            args.verify = False  # Enable only in first iteration

            self._remote = get_server_index()
            if not self._remote:
                if not args.upload:
                    return 1
                # Uploading without a usable server index: diff against an
                # empty database instead.
                self._remote = Database("")

            diff = self._remote.diff(self._local) if args.upload else self._local.diff(self._remote)
            self._diff = diff

            if not needs_update(diff):
                logging.info("Files are up-to-date. Exiting.")
                return 0

            if args.upload:
                self._on_before_upload()
                # upload() returning False means it could not even start
                # (e.g. missing FTP directory). Retrying would fail the same
                # way forever, so report failure instead of looping.
                if not upload(diff, args.print_files, args.compress):
                    return 1
                return 0

            self._on_before_download()
            try:
                download(diff, self._remote, args.print_files)
            except Exception:
                logging.warning("Errors occurred during download -> Validating game files")
                validate_files()
                raise
            print("")

            # Always run another iteration after a download to make sure
            # there was not another update in the meantime. The loop exits
            # through the "up-to-date" check above. An update that started
            # and finished during the download is handled correctly this way.

    def run(self) -> int:
        """Parse the command line, run the workflow and return the exit code.

        Any exception (including KeyboardInterrupt and SystemExit) raised by
        the workflow is logged and turned into exit code 1. Unless
        ``--nopause`` is given, the user is asked to press Enter before the
        method returns.
        """
        global CONFIG_FILE

        parser = argparse.ArgumentParser()
        parser.add_argument("--verify", action="store_true", help="Force regenerating checksums.")
        parser.add_argument("--no-verify", action="store_true", help="Force not regenerating checksums (useful for uploads if nothing changed).")
        parser.add_argument("-d", "--download", action="store_true", help="Download from server (default).")
        parser.add_argument("-u", "--upload", action="store_true", help="Upload to server.")
        parser.add_argument("-v", "--verbose", action="store_true", help="Verbose output.")
        parser.add_argument("-p", "--print-files", action="store_true", help="Print completed files during upload/download.")
        parser.add_argument("-c", "--compress", action="store_true", help="Compress files during upload.")
        parser.add_argument("--pause", default=True, action="store_true", help="Wait for user to press enter before exiting.")
        parser.add_argument("--nopause", action="store_true", help="Negate --pause, has precedence over --pause.")
        parser.add_argument("--config", help="Use the given config file instead of the default 'config.json'.")
        args = parser.parse_args()
        self._args = args

        setup_logger(logging.DEBUG if args.verbose else logging.INFO)

        if args.config:
            CONFIG_FILE = args.config

        try:
            errcode = self._main()
        except (Exception, KeyboardInterrupt, SystemExit) as e:
            # KeyboardInterrupt and SystemExit are special and not regular exceptions.
            print()
            logging.debug(e, exc_info=e)
            # Some exceptions (e.g. KeyboardInterrupt) carry no message;
            # fall back to the type name so the error line is never empty.
            logging.error(str(e) or type(e).__name__)
            errcode = 1

        # Don't put this in a finally block, because it prevents the log
        # from being flushed until Enter is pressed.
        if args.pause and not args.nopause:
            print()
            input("Press Enter or close this window to continue...")

        return errcode
