#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import sys
import os
from hashlib import md5
from typing import List, Dict, Union, Optional, ValuesView, Iterable, Tuple
from dataclasses import dataclass
import argparse
import logging
from .TransferProgress import Progress
from .FileInfo import FileInfo

# Number of bytes read per chunk when md5sum() hashes a file (8 KiB).
READ_BUFFER_SIZE = 2 * 4096


@dataclass
class DiffList:
    """Result of comparing two databases: changed entries and removed entries."""

    # Entries reported as new or changed.
    modified: List[FileInfo]
    # Entries reported as no longer existing.
    removed: List[FileInfo]

    def print_summary(self, list_files: bool = False):
        """Log a summary of the differences at INFO level.

        If list_files is true, the filename of every removed and every
        modified entry is logged before the totals. The totals line reports
        the number of entries and their summed sizes in whole MiB.
        """
        if list_files:
            sections = (
                ("Files to be removed:", self.removed),
                ("Files changed:", self.modified),
            )
            for heading, entries in sections:
                logging.info(heading)
                for entry in entries:
                    logging.info("\t%s", entry.filename)

        bytes_per_mib = 1024 ** 2
        changed_mib = sum(entry.size for entry in self.modified) // bytes_per_mib
        freed_mib = sum(entry.size for entry in self.removed) // bytes_per_mib
        logging.info("Total: %s changed, %s removed, %s MiB changed, %s MiB freed",
                     len(self.modified), len(self.removed), changed_mib, freed_mib)


class Database:
    """Index of the files and directories below a database directory.

    Entries are FileInfo objects keyed by their filename, which is the path
    relative to the database directory. The index can be built by scanning
    the directory or read from / written to an index file.
    """

    def __init__(self, path: str):
        """Create an empty database rooted at the directory `path`."""
        self._path: str = path
        # Maps relative filename -> FileInfo.
        self._namedict: Dict[str, FileInfo] = {}

    def get_path(self) -> str:
        """Returns the database directory this index is rooted at."""
        return self._path

    def get_files_only(self) -> List[FileInfo]:
        """Returns all entries that are not directories."""
        return [i for i in self.get_files() if not i.is_directory()]

    def get_directories_only(self) -> List[FileInfo]:
        """Returns all entries that are directories."""
        return [i for i in self.get_files() if i.is_directory()]

    def get_files(self) -> ValuesView[FileInfo]:
        """Returns a live view of all entries (files and directories)."""
        return self._namedict.values()

    def get_file(self, fname: str) -> Optional[FileInfo]:
        """Returns a FileInfo or None if file not found."""
        return self._namedict.get(fname, None)

    def get_checksum(self, filename) -> Optional[str]:
        """Returns checksum or None if file not found."""
        f = self.get_file(filename)
        return f.checksum if f is not None else None

    def remove_file(self, filename_or_fileinfo: Union[str, FileInfo]) -> Optional[FileInfo]:
        """Removes a file from the database and returns the corresponding FileInfo or None if it does not exist.

        Filenames should be relative to the database directory.
        """
        if isinstance(filename_or_fileinfo, str):
            fname = filename_or_fileinfo
        else:
            fname = filename_or_fileinfo.filename

        return self._namedict.pop(fname, None)

    def _check_is_child(self, path: str, is_relative: bool = True) -> Tuple[str, str]:
        """Checks if the given path is a descendent of the database directory.

        Returns a tuple (full_path, relative_path) containing the full path to
        the target and a path relative to the database directory.
        Raises ValueError if the path lies outside the database directory.
        """
        root = self.get_path()
        full_name = os.path.join(root, path) if is_relative else path
        relname = os.path.relpath(full_name, root)
        # Only a leading ".." path *component* escapes the root; a plain prefix
        # test would wrongly reject names such as "..hidden". An explicit raise
        # is used because assert statements are removed under "python -O".
        if relname == os.pardir or relname.startswith(os.pardir + os.sep):
            raise ValueError("File is not a child of the database directory")
        return full_name, relname

    def add_file(self, filename: str, is_relative: bool = True) -> Optional[FileInfo]:
        """Add a file or directory to the database.

        If is_relative is true, filenames must be relative to the database directory.
        If the target is a directory, it will be added as such, without
        recursively adding its children.
        If the file already exists in the database, it will be updated.
        Returns the corresponding FileInfo or None if the local file does not exist.
        Raises ValueError if the target is outside the database directory.
        """
        full_name, relname = self._check_is_child(filename, is_relative)
        info = None

        if os.path.isdir(full_name):
            # Directories are stored with an empty checksum and size 0.
            info = FileInfo(relname, "", 0)
        elif os.path.isfile(full_name):
            info = FileInfo(relname, md5sum(full_name), os.path.getsize(full_name))
        else:
            logging.error("No such file or directory: %s", full_name)

        if info is not None:
            self._namedict[info.filename] = info
            logging.debug("Added file %s to database", info.filename)

        return info

    def scan_dir(self, print_progress: bool = False):
        """Shortcut for add_dir(".")"""
        self.add_dir(".", print_progress)

    def add_dir(self, path: str, print_progress: bool = False, is_relative: bool = True):  # pylint: disable=too-many-locals
        """Recursively add a directory to the database.

        If is_relative is true, path must be relative to the database directory.
        If files already exist in the database, they will be updated.
        Files whose size cannot be determined (e.g. broken symlinks) are
        skipped with a warning instead of aborting the scan.
        If print_progress is true, a progress bar is printed while hashing.
        Raises ValueError if path is outside the database directory.
        """
        # NOTE: Threading does not make this faster, because it is IO bound.
        # Tests resulted in the same computation times using single and multi
        # threaded (using multiprocessing of course).
        # TODO: Test if this conclusion is also true for SSDs.
        full_root, _ = self._check_is_child(path, is_relative)
        files: List[str] = []
        total_size: int = 0

        logging.info("Calculating checksums for %s...", full_root)
        logging.debug("Creating file list...")
        for dirpath, dirnames, filenames in os.walk(full_root):
            for name in filenames:
                fname = os.path.join(dirpath, name)
                logging.debug("%s", fname)
                try:
                    total_size += os.path.getsize(fname)
                except OSError as err:
                    # os.walk also lists broken symlinks and files that vanish
                    # during the scan; one bad entry must not abort the scan.
                    logging.warning("Skipping unreadable file %s: %s", fname, err)
                    continue
                files.append(fname)

            files.extend(os.path.join(dirpath, name) for name in dirnames)

        if files:
            logging.debug("Calculating checksums...")
            progress = Progress(total_size) if print_progress else None

            if progress:
                progress.print_progress_bar()

            for file in files:
                info = self.add_file(file, is_relative=False)

                if progress:
                    # add_file returns None if the entry disappeared after the
                    # file list was created; there is no size to account for.
                    if info is not None:
                        progress.add_progress(info.size)
                    # progress.update_speed()
                    progress.print_progress_bar()

            if progress:
                progress.finish()

    def merge(self, other: "Database"):
        """Copy all entries of `other` into this database, replacing entries with the same filename."""
        for i in other.get_files():
            self._namedict[i.filename] = i

    def diff(self, other: "Database") -> DiffList:
        """Compare this database with `other`.

        Returns a DiffList whose `removed` list holds the entries of this
        database that are missing in `other`, and whose `modified` list holds
        the entries of `other` that are missing here or have a different
        checksum here.
        """
        # Find old, no longer existing files
        remove = [i for i in self.get_files() if other.get_file(i.filename) is None]

        # Find new or modified files (get_checksum returns None for unknown
        # files, which never equals a stored checksum)
        modif = [i for i in other.get_files() if i.checksum != self.get_checksum(i.filename)]

        return DiffList(modif, remove)

    def clear(self):
        """Clear all file entries, but keep database path."""
        logging.debug("Clearing database")
        self._namedict.clear()

    def save(self, filename: str):
        """Write the serialized database (see save_string) to the file `filename`."""
        with open(filename, "w") as f:
            f.write(self.save_string())

    def save_string(self) -> str:
        """Returns all entries serialized with serialize_string(), joined by newlines."""
        return "\n".join(i.serialize_string() for i in self.get_files())

    @staticmethod
    def load(path: str, print_progress=False) -> Optional["Database"]:
        """Create database from a directory or database file.

        A file is read as an index file (the resulting database has an empty
        path); a directory is scanned recursively, optionally printing progress.
        Returns None and logs an error if path is neither.
        """
        db: Optional[Database] = None

        if os.path.isfile(path):
            db = Database("")
            db.load_from_file(path)
        elif os.path.isdir(path):
            db = Database(path)
            db.scan_dir(print_progress)
        else:
            logging.error("No such file or directory: %s", path)

        return db

    def load_from_file(self, filename: str):
        """Replace all entries with the ones read from the index file `filename`."""
        with open(filename, "r") as f:
            self.load_from_string(f.read())

    def load_from_string(self, s: str):
        """Replace all entries with the ones parsed from the string `s`.

        Each record consists of three consecutive lines: filename, checksum
        and size. Trailing lines that do not form a complete record are
        ignored with a warning. Raises ValueError if a size is not an integer.
        """
        lines = s.splitlines(keepends=False)
        leftover = len(lines) % 3
        if leftover:
            logging.warning("Ignoring %d trailing line(s) of an incomplete database record", leftover)

        # zip over the same iterator three times groups the lines into triples.
        infolist = [
            FileInfo(fname, checksum, int(size))
            for fname, checksum, size in zip(*[iter(lines)] * 3)
        ]
        self._set_file_list(infolist)

    def _set_file_list(self, infolist: Iterable[FileInfo]):
        """Replace all entries with the given FileInfo objects."""
        self.clear()
        for i in infolist:
            self._namedict[i.filename] = i
        logging.debug("Set new file list")


# https://stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file
def md5sum(fname, buffer_size=None):
    """Return the hexadecimal MD5 digest of the contents of the file `fname`.

    The file is read in chunks of `buffer_size` bytes so that large files do
    not have to fit into memory. If `buffer_size` is None, READ_BUFFER_SIZE
    is used.
    Raises ValueError if `buffer_size` is not positive and OSError if the
    file cannot be read.
    """
    if buffer_size is None:
        buffer_size = READ_BUFFER_SIZE
    if buffer_size <= 0:
        # read(0) would return b"" immediately and yield the digest of an
        # empty file; a negative size would read the whole file at once.
        raise ValueError("buffer_size must be a positive number of bytes")

    md5hash = md5()
    with open(fname, "rb") as f:
        # iter() with a sentinel calls read() until it returns b"" (EOF).
        for chunk in iter(lambda: f.read(buffer_size), b""):
            md5hash.update(chunk)
    return md5hash.hexdigest()


def main() -> int:
    """Command-line entry point.

    Loads the `local` database from a directory or index file. With
    -o/--output the local index is saved to the given file. If `remote` is
    given, the differences from local to remote are summarized; otherwise,
    when no output file was requested, the local index is logged.

    Returns 0 on success and 1 if a database could not be loaded.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("local", help="Either a directory or an index file")
    parser.add_argument("remote", default=None, nargs="?", help="Either a directory or an index file")
    parser.add_argument("-o", "--output", nargs=1, type=str, help="Save index file for local dir")
    parser.add_argument("-v", "--verbose", help="Show debug output", action="store_true")
    args = parser.parse_args()

    level = logging.DEBUG if args.verbose else logging.INFO
    logging.basicConfig(level=level, format="%(message)s")

    local = Database.load(args.local)
    if local is None:
        # Database.load has already logged why the path could not be loaded.
        return 1

    if args.output:
        # nargs=1 makes argparse store the value as a one-element list.
        local.save(args.output[0])

    if args.remote:
        remote = Database.load(args.remote)
        if remote is None:
            return 1
        diffs = local.diff(remote)
        diffs.print_summary(args.verbose)
    elif not args.output:
        logging.info(local.save_string())

    return 0


# Run the command-line interface when executed as a script and use the
# return value of main() as the process exit status.
# NOTE(review): the relative imports at the top of this file presumably
# require running it as a module of its package (python -m ...) - confirm.
if __name__ == "__main__":
    sys.exit(main())
