from __future__ import annotations

import copy
import re
from dataclasses import dataclass
from enum import Enum, unique
from pathlib import Path

from cvise.passes.hint_based import HintBasedPass
from cvise.utils.fileutil import filter_files_by_patterns
from cvise.utils.hint import Hint, HintBundle, Patch


@unique
class _Vocab(Enum):
    """Hint types emitted by this pass, each stored as an ``(index, name)`` pair.

    The names, taken in declaration order, form the initial entries of the hint vocabulary built by the pass, and
    every emitted hint stores the index part of the pair as its ``type``.
    """

    # Items must be listed in the index order; indices must be contiguous and start from zero.
    # Emitted for a header line whose file exists under the test case; the hint's "extra" refers to that file.
    FILEREF = (0, b'@fileref')
    # Emitted for every header line of a module; the patch covers the whole line.
    MAKE_HEADER_NON_MODULAR = (1, b'make-header-non-modular')
    # Emitted for every "use" line of a module; the patch covers the whole line.
    DELETE_USE_DECL = (2, b'delete-use-decl')
    # Emitted for a nested module without headers, uses or submodules; the patch covers the whole module.
    DELETE_EMPTY_SUBMODULE = (3, b'delete-empty-submodule')
    # Emitted for a non-empty nested module; the patches cover only its opening line and its closing brace line.
    INLINE_SUBMODULE_CONTENTS = (4, b'inline-submodule-contents')
    # Emitted for every line that the parser couldn't classify; the patch covers the whole line.
    DELETE_LINE = (5, b'delete-line')


class ClangModuleMapPass(HintBasedPass):
    """Reduction pass that proposes removals of items in C++ header module map files.

    The module map language is specified at https://clang.llvm.org/docs/Modules.html#module-map-language.

    Note: the C-Vise JSON config should set "claim_files" for this pass, so that it isn't attempted on unrelated files
    and so that other passes don't corrupt the module map files.
    """

    def check_prerequisites(self):
        return True

    def supports_dir_test_cases(self):
        return True

    def output_hint_types(self) -> list[bytes]:
        """Return the names of all hint types this pass can emit, in the enum's declaration order."""
        return [name for _index, name in (item.value for item in _Vocab)]

    def generate_hints(self, test_case: Path, *args, **kwargs):
        """Parse every claimed module map under ``test_case`` and collect hints for all of them into one bundle.

        Args:
            test_case: Root path of the test case; file paths in the vocabulary are stored relative to it.

        Returns:
            A ``HintBundle`` holding the hints together with the vocabulary they index into.
        """
        module_maps = filter_files_by_patterns(test_case, self.claim_files, self.claimed_by_others_files)
        # The vocabulary starts with the hint type names; file paths are appended to it on demand.
        vocab: list[bytes] = [name for _index, name in (item.value for item in _Vocab)]
        path_to_vocab: dict[Path, int] = {}
        hints: list[Hint] = []

        for module_map in module_maps:
            parsed = _parse_file(module_map)
            path_id = _get_vocab_id(module_map.relative_to(test_case), vocab, path_to_vocab)

            for top_module in parsed.modules:
                _create_hints_for_module(
                    top_module,
                    test_case,
                    path_id,
                    toplevel=True,
                    hints=hints,
                    vocab=vocab,
                    path_to_vocab=path_to_vocab,
                )
            _create_hints_for_unclassified_lines(parsed.unclassified_lines, path_id, hints)

        return HintBundle(hints=hints, vocabulary=vocab)


@dataclass
class _SourceLoc:
    """A half-open byte range ``[begin, end)`` inside a module map file."""

    begin: int  # byte offset of the first byte in the range
    end: int  # byte offset just past the last byte in the range


@dataclass
class _HeaderDecl:
    """A line recognized as a header declaration inside a module."""

    loc: _SourceLoc  # range of the whole line holding the declaration
    file_path: str  # the header path exactly as written between the double quotes


@dataclass
class _UseDecl:
    """A line recognized as a "use" declaration inside a module."""

    loc: _SourceLoc  # range of the whole line holding the declaration
    id: str  # the token following "use", with surrounding double quotes stripped


@dataclass
class _ModuleDecl:
    """A module declaration together with the declarations parsed directly inside it."""

    # Range from the start of the opening line to the end of the last line read while the module was open.
    loc: _SourceLoc
    title_loc: _SourceLoc  # location of " ... module ... {"
    # Range of the closing brace line; holds the same value as title_loc until that line is parsed.
    close_brace_loc: _SourceLoc
    id: str  # the token following "module", with surrounding double quotes stripped
    headers: list[_HeaderDecl]  # header declarations directly inside this module
    uses: list[_UseDecl]  # "use" declarations directly inside this module
    submodules: list[_ModuleDecl]  # modules nested directly inside this module


@dataclass
class _ModuleMapFile:
    """The parsed contents of a single module map file."""

    modules: list[_ModuleDecl]  # top-level modules, in file order
    # Lines that matched none of: module opening, header, "use", closing brace of an open module.
    unclassified_lines: list[_SourceLoc]


def _get_vocab_id(path: Path, vocab: list[bytes], path_to_vocab: dict[Path, int]) -> int:
    if path in path_to_vocab:
        return path_to_vocab[path]
    vocab.append(str(path).encode())
    id = len(vocab) - 1
    path_to_vocab[path] = id
    return id


def _create_hints_for_module(
    mod: _ModuleDecl,
    test_case: Path,
    path_id: int,
    toplevel: bool,
    hints: list[Hint],
    vocab: list[bytes],
    path_to_vocab: dict[Path, int],
) -> None:
    empty = not mod.headers and not mod.uses and not mod.submodules
    if not toplevel and empty:
        hints.append(
            Hint(
                type=_Vocab.DELETE_EMPTY_SUBMODULE.value[0],
                patches=(
                    Patch(
                        path=path_id,
                        left=mod.loc.begin,
                        right=mod.loc.end,
                    ),
                ),
            )
        )

    if not toplevel and not empty:
        hints.append(
            Hint(
                type=_Vocab.INLINE_SUBMODULE_CONTENTS.value[0],
                patches=(
                    Patch(
                        path=path_id,
                        left=mod.title_loc.begin,
                        right=mod.title_loc.end,
                    ),
                    Patch(
                        path=path_id,
                        left=mod.close_brace_loc.begin,
                        right=mod.close_brace_loc.end,
                    ),
                ),
            )
        )

    for header in mod.headers:
        if (test_case / header.file_path).exists():
            header_path_id = _get_vocab_id(Path(header.file_path), vocab, path_to_vocab)
            hints.append(
                Hint(
                    type=_Vocab.FILEREF.value[0],
                    patches=(
                        Patch(
                            path=path_id,
                            left=header.loc.begin,
                            right=header.loc.end,
                        ),
                    ),
                    extra=header_path_id,
                )
            )
        hints.append(
            Hint(
                type=_Vocab.MAKE_HEADER_NON_MODULAR.value[0],
                patches=(
                    Patch(
                        path=path_id,
                        left=header.loc.begin,
                        right=header.loc.end,
                    ),
                ),
            )
        )
    for use in mod.uses:
        hints.append(
            Hint(
                type=_Vocab.DELETE_USE_DECL.value[0],
                patches=(
                    Patch(
                        path=path_id,
                        left=use.loc.begin,
                        right=use.loc.end,
                    ),
                ),
            )
        )
    for submod in mod.submodules:
        _create_hints_for_module(
            submod, test_case, path_id, toplevel=False, hints=hints, vocab=vocab, path_to_vocab=path_to_vocab
        )


def _create_hints_for_unclassified_lines(unclassified_lines: list[_SourceLoc], path_id: int, hints: list[Hint]) -> None:
    for loc in unclassified_lines:
        hints.append(
            Hint(
                type=_Vocab.DELETE_LINE.value[0],
                patches=(
                    Patch(
                        path=path_id,
                        left=loc.begin,
                        right=loc.end,
                    ),
                ),
            )
        )


def _parse_file(path: Path) -> _ModuleMapFile:
    """Parse the module map file at ``path`` line by line.

    Each line is classified as exactly one of: a module opening, a header declaration, a "use" declaration, the
    closing brace of the innermost open module, or an unclassified line. Header, "use" and closing brace lines are
    only recognized while some module is open.

    Args:
        path: Location of the module map file; it's read in binary mode, so all locations are byte offsets.

    Returns:
        The top-level modules (with their nested contents) and the byte ranges of the unclassified lines.
    """
    result = _ModuleMapFile(modules=[], unclassified_lines=[])
    open_modules: list[_ModuleDecl] = []  # the chain of currently open modules, outermost first
    offset = 0
    with open(path, 'rb') as stream:
        for raw_line in stream:
            line_loc = _SourceLoc(begin=offset, end=offset + len(raw_line))
            offset = line_loc.end

            # Every module that is still open grows to include the current line.
            for enclosing in open_modules:
                enclosing.loc.end = line_loc.end

            module = _try_parse_module_decl(raw_line, line_loc)
            if module is not None:
                siblings = open_modules[-1].submodules if open_modules else result.modules
                siblings.append(module)
                open_modules.append(module)
                continue

            if open_modules:
                innermost = open_modules[-1]

                header = _try_parse_header_decl(raw_line, line_loc)
                if header is not None:
                    innermost.headers.append(header)
                    continue

                use = _try_parse_use_decl(raw_line, line_loc)
                if use is not None:
                    innermost.uses.append(use)
                    continue

                if _is_close_brace(raw_line):
                    innermost.close_brace_loc = line_loc
                    open_modules.pop()
                    continue

            result.unclassified_lines.append(line_loc)

    return result


def _try_parse_module_decl(line: bytes, loc: _SourceLoc) -> _ModuleDecl | None:
    m = re.match(rb'.*\bmodule\s+(\S+).*{\s*', line)
    if not m:
        return None
    module_id = m.group(1).decode().strip('"')
    title_loc = copy.copy(loc)
    # close_brace_loc will be replaced with a real value once the closing brace line is parsed
    close_brace_loc = title_loc
    return _ModuleDecl(
        loc=loc,
        title_loc=title_loc,
        close_brace_loc=close_brace_loc,
        id=module_id,
        headers=[],
        uses=[],
        submodules=[],
    )


def _try_parse_header_decl(line: bytes, loc: _SourceLoc) -> _HeaderDecl | None:
    m = re.match(rb'.*\bheader\s+"([^\s"]+)".*', line)
    if not m:
        return None
    file_path = m.group(1).decode()
    return _HeaderDecl(loc=loc, file_path=file_path)


def _try_parse_use_decl(line: bytes, loc: _SourceLoc) -> _UseDecl | None:
    m = re.match(rb'.*\buse\s+(\S+).*', line)
    if not m:
        return None
    module_id = m.group(1).decode().strip('"')
    return _UseDecl(loc=loc, id=module_id)


def _is_close_brace(line: bytes) -> bool:
    return line.strip() == b'}'
