import hashlib
import os
import re
import shutil
import tarfile
import urllib.error
import urllib.parse
import urllib.request
import warnings
import zipfile
from urllib.request import urlretrieve

from keras.src.api_export import keras_export
from keras.src.backend import config
from keras.src.utils import io_utils
from keras.src.utils.module_utils import gfile
from keras.src.utils.progbar import Progbar


def path_to_string(path):
    """Convert `PathLike` objects to their string representation.

    If given a non-string path-like object, converts it to its string
    representation via `os.fspath`.

    If the object passed to `path` is not a `PathLike` object, it is
    returned unchanged. This allows e.g. passthrough of file objects
    through this function.
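
    Example:

    >>> import pathlib
    >>> path_to_string(pathlib.Path("/tmp/file.txt"))
    '/tmp/file.txt'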

    Args:
        path: `PathLike` object that represents a path

    Returns:
        A string representation of the path argument if it is a
        `PathLike` object; otherwise, the argument is returned unchanged.
    """
    if isinstance(path, os.PathLike):
        return os.fspath(path)
    return path


def resolve_path(path):
    return os.path.realpath(os.path.abspath(path))


def is_path_in_dir(path, base_dir):
    return resolve_path(os.path.join(base_dir, path)).startswith(base_dir)


def is_link_in_dir(info, base):
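    # Resolve the directory containing the link, then check that the link
    # target stays inside that directory.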
    tip = resolve_path(os.path.join(base, os.path.dirname(info.name)))
    return is_path_in_dir(info.linkname, base_dir=tip)


def filter_safe_paths(members):
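    """Yields archive members whose resolved paths stay inside the CWD.

    Guards against path traversal (e.g. member names containing `../`) and
    symlinks or hardlinks pointing outside the extraction directory;
    unsafe members are skipped with a warning.
    """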
    base_dir = resolve_path(".")
    for finfo in members:
        valid_path = False
        if is_path_in_dir(finfo.name, base_dir):
            valid_path = True
            yield finfo
        elif finfo.issym() or finfo.islnk():
            if is_link_in_dir(finfo, base_dir):
                valid_path = True
                yield finfo
        if not valid_path:
            warnings.warn(
                "Skipping invalid path during archive extraction: "
                f"'{finfo.name}'.",
                stacklevel=2,
            )


def extract_archive(file_path, path=".", archive_format="auto"):
    """Extracts an archive if it matches a support format.

    Supports `.tar`, `.tar.gz`, `.tar.bz`, and `.zip` formats.
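
    Example:

    ```python
    # Hypothetical archive path; returns `True` if the archive was
    # recognized and extracted into `./data`.
    success = extract_archive("dataset.tar.gz", path="data")
    ```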

    Args:
        file_path: Path to the archive file.
        path: Where to extract the archive file.
        archive_format: Archive format to try for extracting the file.
            Options are `"auto"`, `"tar"`, `"zip"`, and `None`.
            `"tar"` includes `.tar`, `.tar.gz`, and `.tar.bz` files.
            The default `"auto"` uses `["tar", "zip"]`.
            `None` or an empty list will return no matches found (`False`).

    Returns:
        `True` if a match was found and an archive extraction was completed,
        `False` otherwise.
    """
    if archive_format is None:
        return False
    if archive_format == "auto":
        archive_format = ["tar", "zip"]
    if isinstance(archive_format, str):
        archive_format = [archive_format]

    file_path = path_to_string(file_path)
    path = path_to_string(path)

    for archive_type in archive_format:
        if archive_type == "tar":
            open_fn = tarfile.open
            is_match_fn = tarfile.is_tarfile
        elif archive_type == "zip":
            open_fn = zipfile.ZipFile
            is_match_fn = zipfile.is_zipfile
        else:
            raise NotImplementedError(archive_type)

        if is_match_fn(file_path):
            with open_fn(file_path) as archive:
                try:
                    if zipfile.is_zipfile(file_path):
                        # Zip archive.
                        archive.extractall(path)
                    else:
                        # Tar archive, perhaps unsafe. Filter paths.
                        archive.extractall(
                            path, members=filter_safe_paths(archive)
                        )
                except (tarfile.TarError, RuntimeError, KeyboardInterrupt):
                    if os.path.exists(path):
                        if os.path.isfile(path):
                            os.remove(path)
                        else:
                            shutil.rmtree(path)
                    raise
            return True
    return False


@keras_export("keras.utils.get_file")
def get_file(
    fname=None,
    origin=None,
    untar=False,
    md5_hash=None,
    file_hash=None,
    cache_subdir="datasets",
    hash_algorithm="auto",
    extract=False,
    archive_format="auto",
    cache_dir=None,
    force_download=False,
):
    """Downloads a file from a URL if it not already in the cache.

    By default the file at the url `origin` is downloaded to the
    cache_dir `~/.keras`, placed in the cache_subdir `datasets`,
    and given the filename `fname`. The final location of a file
    `example.txt` would therefore be `~/.keras/datasets/example.txt`.
    Files in `.tar`, `.tar.gz`, `.tar.bz`, and `.zip` formats can
    also be extracted.

    Passing a hash will verify the file after download. The command line
    programs `shasum` and `sha256sum` can compute the hash.

    Example:

    ```python
    path_to_downloaded_file = get_file(
        origin="https://storage.googleapis.com/download.tensorflow.org/example_images/flower_photos.tgz",
        extract=True,
    )
    ```
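
    To verify the download against a known hash (hypothetical URL and hash
    value shown), pass `file_hash`:

    ```python
    path = get_file(
        origin="https://example.com/data.tar.gz",
        file_hash="e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855",
        extract=True,
    )
    ```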

    Args:
        fname: If the target is a single file, this is your desired
            local name for the file.
            If `None`, the name of the file at `origin` will be used.
            If downloading and extracting a directory archive,
            the provided `fname` will be used as extraction directory
            name (only if it doesn't have an extension).
        origin: Original URL of the file.
        untar: Deprecated in favor of `extract` argument.
            Boolean, whether the file is a tar archive that should
            be extracted.
        md5_hash: Deprecated in favor of `file_hash` argument.
            md5 hash of the file for file integrity verification.
        file_hash: The expected hash string of the file after download.
            The sha256 and md5 hash algorithms are both supported.
        cache_subdir: Subdirectory under the Keras cache dir where the file is
            saved. If an absolute path, e.g. `"/path/to/folder"` is
            specified, the file will be saved at that location.
        hash_algorithm: Select the hash algorithm to verify the file.
            Options are `"md5"`, `"sha256"`, and `"auto"`.
            The default `"auto"` detects the hash algorithm in use.
        extract: If `True`, extracts the archive. Only applicable to compressed
            archive files like tar or zip.
        archive_format: Archive format to try for extracting the file.
            Options are `"auto"`, `"tar"`, `"zip"`, and `None`.
            `"tar"` includes `.tar`, `.tar.gz`, and `.tar.bz` files.
            The default `"auto"` corresponds to `["tar", "zip"]`.
            `None` or an empty list will return no matches found.
        cache_dir: Location to store cached files. When `None`, it
            defaults to `$KERAS_HOME` if the `KERAS_HOME` environment
            variable is set, and `~/.keras/` otherwise.
        force_download: If `True`, the file will always be re-downloaded
            regardless of the cache state.

    Returns:
        Path to the downloaded file, or, when `extract` or `untar` is
        `True`, the path to the extraction directory.

    **⚠️ Warning on malicious downloads ⚠️**

    Downloading something from the Internet carries a risk.
    NEVER download a file/archive if you do not trust the source.
    We recommend that you specify the `file_hash` argument
    (if the hash of the source file is known) to make sure that the file you
    are getting is the one you expect.
    """
    if origin is None:
        raise ValueError(
            'Please specify the "origin" argument (URL of the file '
            "to download)."
        )

    if cache_dir is None:
        cache_dir = config.keras_home()
    if md5_hash is not None and file_hash is None:
        file_hash = md5_hash
        hash_algorithm = "md5"
    datadir_base = os.path.expanduser(cache_dir)
    if not os.access(datadir_base, os.W_OK):
        datadir_base = os.path.join("/tmp", ".keras")
    datadir = os.path.join(datadir_base, cache_subdir)
    os.makedirs(datadir, exist_ok=True)

    provided_fname = fname
    fname = path_to_string(fname)

    if not fname:
        fname = os.path.basename(urllib.parse.urlsplit(origin).path)
        if not fname:
            raise ValueError(
                "Can't parse the file name from the origin provided: "
                f"'{origin}'."
                "Please specify the `fname` argument."
            )
    else:
        if os.sep in fname:
            raise ValueError(
                "Paths are no longer accepted as the `fname` argument. "
                "To specify the file's parent directory, use "
                f"the `cache_dir` argument. Received: fname={fname}"
            )

    if extract or untar:
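        # Keep the downloaded archive and the extraction directory at
        # distinct paths so the cached archive and its extracted contents
        # do not collide.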
        if provided_fname:
            if "." in fname:
                download_target = os.path.join(datadir, fname)
                fname = fname[: fname.find(".")]
                extraction_dir = os.path.join(datadir, fname + "_extracted")
            else:
                extraction_dir = os.path.join(datadir, fname)
                download_target = os.path.join(datadir, fname + "_archive")
        else:
            extraction_dir = os.path.join(datadir, fname)
            download_target = os.path.join(datadir, fname + "_archive")
    else:
        download_target = os.path.join(datadir, fname)

    if force_download:
        download = True
    elif os.path.exists(download_target):
        # File found in cache.
        download = False
        # Verify integrity if a hash was provided.
        if file_hash is not None:
            if not validate_file(
                download_target, file_hash, algorithm=hash_algorithm
            ):
                io_utils.print_msg(
                    "A local file was found, but it seems to be "
                    f"incomplete or outdated because the {hash_algorithm} "
                    "file hash does not match the original value of "
                    f"{file_hash} so we will re-download the data."
                )
                download = True
    else:
        download = True

    if download:
        io_utils.print_msg(f"Downloading data from {origin}")

        class DLProgbar:
            """Manage progress bar state for use in urlretrieve."""

            def __init__(self):
                self.progbar = None
                self.finished = False

            def __call__(self, block_num, block_size, total_size):
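                # `urlretrieve` calls this hook as
                # `reporthook(block_num, block_size, total_size)`; servers
                # that send no Content-Length report `total_size` as -1.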
                if total_size == -1:
                    total_size = None
                if not self.progbar:
                    self.progbar = Progbar(total_size)
                current = block_num * block_size

                if total_size is None:
                    self.progbar.update(current)
                else:
                    if current < total_size:
                        self.progbar.update(current)
                    elif not self.finished:
                        self.progbar.update(self.progbar.target)
                        self.finished = True

        error_msg = "URL fetch failure on {}: {} -- {}"
        try:
            try:
                urlretrieve(origin, download_target, DLProgbar())
            except urllib.error.HTTPError as e:
                raise Exception(error_msg.format(origin, e.code, e.msg))
            except urllib.error.URLError as e:
                raise Exception(error_msg.format(origin, e.errno, e.reason))
        except (Exception, KeyboardInterrupt):
            if os.path.exists(download_target):
                os.remove(download_target)
            raise

        # Validate the download if it succeeded and the user provided an
        # expected hash. Security-conscious users should obtain the hash of
        # the file from a separate channel and pass it to this API to
        # prevent MITM attacks or corruption:
        if os.path.exists(download_target) and file_hash is not None:
            if not validate_file(
                download_target, file_hash, algorithm=hash_algorithm
            ):
                raise ValueError(
                    "Incomplete or corrupted file detected. "
                    f"The {hash_algorithm} "
                    "file hash does not match the provided value "
                    f"of {file_hash}."
                )

    if extract or untar:
        if untar:
            archive_format = "tar"

        status = extract_archive(
            download_target, extraction_dir, archive_format
        )
        if not status:
            warnings.warn("Could not extract archive.", stacklevel=2)
        return extraction_dir

    return download_target


def resolve_hasher(algorithm, file_hash=None):
    """Returns hash algorithm as hashlib function."""
    if algorithm == "sha256":
        return hashlib.sha256()

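    # A sha256 hex digest is 64 characters long, so a 64-character
    # `file_hash` implies sha256.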
    if algorithm == "auto" and file_hash is not None and len(file_hash) == 64:
        return hashlib.sha256()

    # This is used only for legacy purposes.
    return hashlib.md5()


def hash_file(fpath, algorithm="sha256", chunk_size=65535):
    """Calculates a file sha256 or md5 hash.

    Example:

    >>> hash_file('/path/to/file.zip')
    'e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855'

    Args:
        fpath: Path to the file being validated.
        algorithm: Hash algorithm, one of `"auto"`, `"sha256"`, or `"md5"`.
            Defaults to `"sha256"`. Note that `"auto"` falls back to md5
            when no reference hash is available to detect the algorithm.
        chunk_size: Bytes to read at a time, important for large files.

    Returns:
        The file hash.
    """
    if isinstance(algorithm, str):
        hasher = resolve_hasher(algorithm)
    else:
        hasher = algorithm

    with open(fpath, "rb") as fpath_file:
        for chunk in iter(lambda: fpath_file.read(chunk_size), b""):
            hasher.update(chunk)

    return hasher.hexdigest()


def validate_file(fpath, file_hash, algorithm="auto", chunk_size=65535):
    """Validates a file against a sha256 or md5 hash.

    Args:
        fpath: path to the file being validated
        file_hash:  The expected hash string of the file.
            The sha256 and md5 hash algorithms are both supported.
        algorithm: Hash algorithm, one of `"auto"`, `"sha256"`, or `"md5"`.
            The default `"auto"` detects the hash algorithm in use.
        chunk_size: Bytes to read at a time, important for large files.

    Returns:
        Boolean, whether the file is valid.
    """
    hasher = resolve_hasher(algorithm, file_hash)

    return str(hash_file(fpath, hasher, chunk_size)) == str(file_hash)


def is_remote_path(filepath):
    """
    Determines if a given filepath indicates a remote location.

    This function checks if the filepath represents a known remote pattern
    such as GCS (`/gcs`), CNS (`/cns`), CFS (`/cfs`), HDFS (`/hdfs`), Placer
    (`/placer`), TFHub (`/tfhub`), or a URL (`.*://`).
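
    Example:

    >>> is_remote_path("gs://bucket/file.txt")
    True
    >>> is_remote_path("/tmp/local_file.txt")
    False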

    Args:
        filepath (str): The path to be checked.

    Returns:
        bool: True if the filepath is a recognized remote path, otherwise
            False.
    """
    return bool(
        re.match(
            r"^(/cns|/cfs|/gcs|/hdfs|/readahead|/placer|/tfhub|.*://).*$",
            str(filepath),
        )
    )


# Below are gfile-replacement utils: each dispatches to TensorFlow's `gfile`
# for remote paths and to the Python standard library for local paths.
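# Example (hypothetical paths): `exists("gs://bucket/data.txt")` routes
# through `gfile.exists`, while `exists("/tmp/data.txt")` falls back to
# `os.path.exists`.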


def _raise_if_no_gfile(path):
    raise ValueError(
        "Handling remote paths requires installing TensorFlow "
        f"(in order to use gfile). Received path: {path}"
    )


def exists(path):
    if is_remote_path(path):
        if gfile.available:
            return gfile.exists(path)
        else:
            _raise_if_no_gfile(path)
    return os.path.exists(path)


def File(path, mode="r"):
    if is_remote_path(path):
        if gfile.available:
            return gfile.GFile(path, mode=mode)
        else:
            _raise_if_no_gfile(path)
    return open(path, mode=mode)


def join(path, *paths):
    if is_remote_path(path):
        if gfile.available:
            return gfile.join(path, *paths)
        else:
            _raise_if_no_gfile(path)
    return os.path.join(path, *paths)


def isdir(path):
    if is_remote_path(path):
        if gfile.available:
            return gfile.isdir(path)
        else:
            _raise_if_no_gfile(path)
    return os.path.isdir(path)


def remove(path):
    if is_remote_path(path):
        if gfile.available:
            return gfile.remove(path)
        else:
            _raise_if_no_gfile(path)
    return os.remove(path)


def rmtree(path):
    if is_remote_path(path):
        if gfile.available:
            return gfile.rmtree(path)
        else:
            _raise_if_no_gfile(path)
    return shutil.rmtree(path)


def listdir(path):
    if is_remote_path(path):
        if gfile.available:
            return gfile.listdir(path)
        else:
            _raise_if_no_gfile(path)
    return os.listdir(path)


def copy(src, dst):
    if is_remote_path(src) or is_remote_path(dst):
        if gfile.available:
            return gfile.copy(src, dst, overwrite=True)
        else:
            _raise_if_no_gfile(f"src={src} dst={dst}")
    return shutil.copy(src, dst)


def makedirs(path):
    if is_remote_path(path):
        if gfile.available:
            return gfile.makedirs(path)
        else:
            _raise_if_no_gfile(path)
    return os.makedirs(path)
