"""This module is responsible for parsing a stub AST into a dictionary of names."""

import logging
import ast
import sys
from typing import (
    Any,
    Callable,
    Dict,
    Iterable,
    List,
    NamedTuple,
    NoReturn,
    Optional,
    Tuple,
    Type,
    Union,
)

from . import finder
from .finder import get_search_context, SearchContext, ModulePath, parse_stub_file

log = logging.getLogger(__name__)


class InvalidStub(Exception):
    """Raised when a stub contains a construct this parser cannot handle."""


class ImportedName(NamedTuple):
    """A name that was bound by an import statement.

    ``module_name`` is the module the import refers to. ``name`` is the
    attribute taken from that module (``from m import name``), or None when
    the bound name refers to the module itself (``import m``).
    """

    module_name: ModulePath
    name: Optional[str] = None


class OverloadedName(NamedTuple):
    """Several definitions bound to the same name within a single scope.

    The ``definitions`` list is extended in place as further definitions of
    the same name are encountered.
    """

    definitions: List[ast.AST]


class NameInfo(NamedTuple):
    """Everything recorded about one name bound in a stub module or class."""

    # The identifier as it is bound in its scope.
    name: str
    # Whether the name is considered part of the scope's public interface.
    is_exported: bool
    # The defining AST node, the import it came from, or its overloads.
    ast: Union[ast.AST, ImportedName, OverloadedName]
    # For classes, the names bound in the class body; None otherwise.
    # Really Optional[NameDict], which would require a recursive type.
    child_nodes: Optional[Dict[str, Any]] = None


# Maps each name bound in a scope (module or class body) to its NameInfo.
NameDict = Dict[str, NameInfo]


def get_stub_names(
    module_name: str, *, search_context: Optional[SearchContext] = None
) -> Optional[NameDict]:
    """Given a module name, return a dictionary of names defined in that module.

    Args:
        module_name: Dotted name of the module, e.g. ``"os.path"``.
        search_context: Where to look for stubs; defaults to the result of
            ``get_search_context()``.

    Returns:
        The names defined by the module's stub, or None if no stub file was
        found for it.
    """
    if search_context is None:
        search_context = get_search_context()
    path = finder.get_stub_file(module_name, search_context=search_context)
    if path is None:
        return None
    # Named "tree" rather than "ast" so the stdlib ast module is not shadowed.
    tree = parse_stub_file(path)
    return parse_ast(
        tree,
        search_context,
        ModulePath(tuple(module_name.split("."))),
        # A package's __init__.pyi resolves relative imports differently.
        is_init=path.name == "__init__.pyi",
    )


def parse_ast(
    ast: ast.AST,
    search_context: SearchContext,
    module_name: ModulePath,
    *,
    is_init: bool = False,
) -> NameDict:
    """Build a dictionary of the names bound by a parsed stub module.

    Args:
        ast: The parsed stub module.
        search_context: Used to evaluate ``if``/``assert`` conditions and to
            resolve ``from ... import *``.
        module_name: The module being parsed, as a tuple of name components.
        is_init: Whether the stub is a package ``__init__.pyi``; this affects
            how relative imports are resolved.

    Returns:
        A mapping from each bound name to its NameInfo. When a name is bound
        more than once, duplicate imports are ignored, conflicts involving a
        class or an earlier import keep the first binding and log a warning,
        and anything else is merged into an OverloadedName. If an ``assert``
        in the stub fails, an empty dict is returned.
    """
    visitor = _NameExtractor(search_context, module_name, is_init=is_init)
    name_dict: NameDict = {}
    try:
        names = visitor.visit(ast)
    except _AssertFailed:
        return name_dict
    dotted_name = ".".join(module_name)
    for info in names:
        existing = name_dict.get(info.name)
        if existing is None:
            name_dict[info.name] = info
            continue

        # Classes cannot be overloaded. child_nodes is a dict for every class
        # (and None otherwise), so compare with None: an empty class body
        # yields {} and must still count as a class.
        if info.child_nodes is not None:
            log.warning("Name is already present in %s: %s", dotted_name, info)
            continue

        # This is common and harmless, likely from an "import *"
        if isinstance(existing.ast, ImportedName) and isinstance(
            info.ast, ImportedName
        ):
            continue

        if isinstance(existing.ast, ImportedName):
            log.warning("Name is already imported in %s: %s", dotted_name, existing)
        elif existing.child_nodes is not None:
            log.warning("Name is already present in %s: %s", dotted_name, existing)
        elif isinstance(existing.ast, OverloadedName):
            existing.ast.definitions.append(info.ast)
        else:
            name_dict[info.name] = NameInfo(
                existing.name,
                existing.is_exported,
                OverloadedName([existing.ast, info.ast]),
            )
    return name_dict


# Dispatch table from an ast comparison-operator node type to a function that
# applies that operator to two already-evaluated operands.
_CMP_OP_TO_FUNCTION: Dict[Type[ast.AST], Callable[[Any, Any], bool]] = {
    # equality
    ast.Eq: lambda left, right: left == right,
    ast.NotEq: lambda left, right: left != right,
    # ordering
    ast.Lt: lambda left, right: left < right,
    ast.LtE: lambda left, right: left <= right,
    ast.Gt: lambda left, right: left > right,
    ast.GtE: lambda left, right: left >= right,
    # identity
    ast.Is: lambda left, right: left is right,
    ast.IsNot: lambda left, right: left is not right,
    # membership
    ast.In: lambda left, right: left in right,
    ast.NotIn: lambda left, right: left not in right,
}


def _name_is_exported(name: str) -> bool:
    """Return whether *name* is public under the leading-underscore convention.

    Names without a leading underscore are public, as are dunder names such
    as ``__all__``; every other underscore-prefixed name is private.
    """
    is_dunder = name.startswith("__") and name.endswith("__")
    return is_dunder or not name.startswith("_")


class _NameExtractor(ast.NodeVisitor):
    """Extract names from a stub module.

    Every ``visit_*`` method returns an iterable of NameInfo, one per name
    bound by the visited statement. Statement kinds without a dedicated
    handler are rejected by ``generic_visit`` with InvalidStub.
    """

    def __init__(
        self, ctx: SearchContext, module_name: ModulePath, *, is_init: bool = False
    ) -> None:
        # Used to evaluate "if"/"assert" conditions and to look up the target
        # of "from ... import *".
        self.ctx = ctx
        # The module being parsed, as a tuple of dotted-name components.
        self.module_name = module_name
        # True for a package's __init__.pyi; changes relative import resolution.
        self.is_init = is_init

    def visit_Module(self, node: ast.Module) -> List[NameInfo]:
        """Collect the names bound by every top-level statement, in order."""
        return [info for child in node.body for info in self.visit(child)]

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Iterable[NameInfo]:
        yield NameInfo(node.name, _name_is_exported(node.name), node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef) -> Iterable[NameInfo]:
        yield NameInfo(node.name, _name_is_exported(node.name), node)

    def visit_ClassDef(self, node: ast.ClassDef) -> Iterable[NameInfo]:
        """Yield the class, carrying the names of its body as ``child_nodes``.

        Names bound more than once in the body are merged into an
        OverloadedName; an import inside a class body is an error.
        """
        children = [info for child in node.body for info in self.visit(child)]
        child_dict: NameDict = {}
        for info in children:
            existing = child_dict.get(info.name)
            if existing is None:
                child_dict[info.name] = info
            elif isinstance(existing.ast, OverloadedName):
                existing.ast.definitions.append(info.ast)
            elif isinstance(existing.ast, ImportedName):
                raise RuntimeError(f"Unexpected import name in class: {existing.ast}")
            else:
                child_dict[info.name] = NameInfo(
                    existing.name,
                    existing.is_exported,
                    OverloadedName([existing.ast, info.ast]),
                )
        yield NameInfo(node.name, _name_is_exported(node.name), node, child_dict)

    def visit_Assign(self, node: ast.Assign) -> Iterable[NameInfo]:
        """Yield one entry per target; only plain names may be assigned to."""
        for target in node.targets:
            if not isinstance(target, ast.Name):
                raise InvalidStub(
                    f"Assignment should only be to a simple name: {ast.dump(node)}"
                )
            yield NameInfo(target.id, _name_is_exported(target.id), node)

    def visit_AugAssign(self, node: ast.AugAssign) -> Iterable[NameInfo]:
        """Accept only ``__all__ += ...``, reported as a binding of __all__."""
        if not isinstance(node.op, ast.Add):
            raise InvalidStub(f"Only += is allowed in stubs: {ast.dump(node)}")
        if not isinstance(node.target, ast.Name) or node.target.id != "__all__":
            raise InvalidStub(f"+= is only allowed for __all__: {ast.dump(node)}")
        yield NameInfo("__all__", True, node)

    def visit_AnnAssign(self, node: ast.AnnAssign) -> Iterable[NameInfo]:
        target = node.target
        if not isinstance(target, ast.Name):
            raise InvalidStub(
                f"Assignment should only be to a simple name: {ast.dump(node)}"
            )
        yield NameInfo(target.id, _name_is_exported(target.id), node)

    def visit_If(self, node: ast.If) -> Iterable[NameInfo]:
        """Yield the names from whichever branch the condition selects.

        The condition is evaluated statically against the search context.
        """
        value = _LiteralEvalVisitor(self.ctx).visit(node.test)
        branch = node.body if value else node.orelse
        for stmt in branch:
            yield from self.visit(stmt)

    def visit_Assert(self, node: ast.Assert) -> Iterable[NameInfo]:
        """Bind nothing if the assertion holds; raise _AssertFailed otherwise."""
        visitor = _LiteralEvalVisitor(self.ctx)
        if visitor.visit(node.test):
            return []
        raise _AssertFailed

    def visit_Import(self, node: ast.Import) -> Iterable[NameInfo]:
        for alias in node.names:
            if alias.asname is not None:
                # "import a.b as c" binds c to the module a.b. Like
                # "from m import a as c", the alias is exported unless private.
                yield NameInfo(
                    alias.asname,
                    not alias.asname.startswith("_"),
                    ImportedName(ModulePath(tuple(alias.name.split(".")))),
                )
            else:
                # "import a.b" just binds the name "a"
                name = alias.name.partition(".")[0]
                yield NameInfo(name, False, ImportedName(ModulePath((name,))))

    def _resolve_import_source(self, node: ast.ImportFrom) -> ModulePath:
        """Return the absolute module path a ``from ... import`` refers to."""
        module: Tuple[str, ...] = (
            () if node.module is None else tuple(node.module.split("."))
        )
        if node.level == 0:
            return ModulePath(module)
        if node.level == 1:
            # An __init__.pyi is its own package; a plain module's package is
            # its parent.
            package = self.module_name if self.is_init else self.module_name[:-1]
        elif self.is_init:
            # One level fewer to strip, because __init__.pyi is the package.
            package = self.module_name[: 1 - node.level]
        else:
            package = self.module_name[: -node.level]
        return ModulePath(package + module)

    def visit_ImportFrom(self, node: ast.ImportFrom) -> Iterable[NameInfo]:
        """Yield names bound by ``from ... import``, expanding ``*`` imports.

        Explicit ``as`` aliases are exported unless they start with an
        underscore; unaliased names are not exported. A ``*`` import
        re-exports every exported name of the source module's stub.
        """
        source_module = self._resolve_import_source(node)
        for alias in node.names:
            if alias.asname is not None:
                is_exported = not alias.asname.startswith("_")
                yield NameInfo(
                    alias.asname, is_exported, ImportedName(source_module, alias.name)
                )
            elif alias.name == "*":
                name_dict = get_stub_names(
                    ".".join(source_module), search_context=self.ctx
                )
                if name_dict is None:
                    log.warning(
                        "could not import %s in %s with %s",
                        source_module,
                        self.module_name,
                        self.ctx,
                    )
                    continue
                for name, info in name_dict.items():
                    if info.is_exported:
                        yield NameInfo(name, True, ImportedName(source_module, name))
            else:
                yield NameInfo(
                    alias.name, False, ImportedName(source_module, alias.name)
                )

    def visit_Expr(self, node: ast.Expr) -> Iterable[NameInfo]:
        """Allow only ``...`` and string (docstring) expression statements."""
        value = node.value
        if sys.version_info >= (3, 8):
            # ast.Str / ast.Ellipsis are deprecated aliases since 3.8 and were
            # removed in Python 3.14, so inspect ast.Constant directly.
            allowed = isinstance(value, ast.Constant) and (
                value.value is Ellipsis or isinstance(value.value, str)
            )
        else:
            allowed = isinstance(value, (ast.Ellipsis, ast.Str))
        if not allowed:
            raise InvalidStub(f"Cannot handle node {ast.dump(node)}")
        return []

    def visit_Pass(self, node: ast.Pass) -> Iterable[NameInfo]:
        return []

    def generic_visit(self, node: ast.AST) -> NoReturn:
        raise InvalidStub(f"Cannot handle node {ast.dump(node)}")


class _LiteralEvalVisitor(ast.NodeVisitor):
    """Statically evaluate the restricted expressions allowed in stub conditions.

    Supports constants, tuples, subscripts and slices, single (non-chained)
    comparisons, ``and``/``or``, the unary operators ``not``, ``-`` and ``+``,
    and the attributes ``sys.platform`` and ``sys.version_info``, whose values
    come from the search context. Anything else raises InvalidStub.
    """

    def __init__(self, ctx: SearchContext) -> None:
        self.ctx = ctx

    # from version 3.8 on all constants are represented as ast.Constant
    if sys.version_info < (3, 8):

        def visit_Num(self, node: ast.Num) -> Union[int, float, complex]:
            return node.n

        def visit_Str(self, node: ast.Str) -> str:
            return node.s

        def visit_NameConstant(self, node: ast.NameConstant) -> Any:
            # True, False and None; ast.Constant covers these on 3.8+.
            return node.value

    else:

        def visit_Constant(self, node: ast.Constant) -> Any:
            return node.value

    # from version 3.9 on an index is represented as the value directly
    if sys.version_info < (3, 9):

        def visit_Index(self, node: ast.Index) -> Any:
            return self.visit(node.value)

    def visit_Tuple(self, node: ast.Tuple) -> Tuple[Any, ...]:
        return tuple(self.visit(elt) for elt in node.elts)

    def visit_Subscript(self, node: ast.Subscript) -> Any:
        value = self.visit(node.value)
        slc = self.visit(node.slice)
        return value[slc]

    def visit_Compare(self, node: ast.Compare) -> bool:
        if len(node.ops) != 1:
            raise InvalidStub(f"Cannot evaluate chained comparison {ast.dump(node)}")
        fn = _CMP_OP_TO_FUNCTION[type(node.ops[0])]
        return fn(self.visit(node.left), self.visit(node.comparators[0]))

    def visit_BoolOp(self, node: ast.BoolOp) -> Any:
        # Mirror Python's short-circuiting: return the first operand that
        # decides the result ("or": truthy, "and": falsy), else the last one.
        for val_node in node.values:
            val = self.visit(val_node)
            if (isinstance(node.op, ast.Or) and val) or (
                isinstance(node.op, ast.And) and not val
            ):
                return val
        return val

    def visit_UnaryOp(self, node: ast.UnaryOp) -> Any:
        if isinstance(node.op, ast.Not):
            return not self.visit(node.operand)
        if isinstance(node.op, ast.USub):
            return -self.visit(node.operand)
        if isinstance(node.op, ast.UAdd):
            return +self.visit(node.operand)
        # e.g. "~": reject with the same error as any other unsupported node
        raise InvalidStub(f"Cannot evaluate node {ast.dump(node)}")

    def visit_Slice(self, node: ast.Slice) -> slice:
        lower = self.visit(node.lower) if node.lower is not None else None
        upper = self.visit(node.upper) if node.upper is not None else None
        step = self.visit(node.step) if node.step is not None else None
        return slice(lower, upper, step)

    def visit_Attribute(self, node: ast.Attribute) -> Any:
        """Resolve ``sys.platform`` / ``sys.version_info`` from the context."""
        val = node.value
        if not isinstance(val, ast.Name):
            raise InvalidStub(f"Invalid code in stub: {ast.dump(node)}")
        if val.id != "sys":
            raise InvalidStub(
                f"Attribute access must be on the sys module: {ast.dump(node)}"
            )
        if node.attr == "platform":
            return self.ctx.platform
        elif node.attr == "version_info":
            return self.ctx.version
        else:
            raise InvalidStub(f"Invalid attribute on {ast.dump(node)}")

    def generic_visit(self, node: ast.AST) -> NoReturn:
        raise InvalidStub(f"Cannot evaluate node {ast.dump(node)}")


class _AssertFailed(Exception):
    """Signals that an ``assert`` statement in a stub evaluated to false.

    parse_ast catches it and returns an empty name dictionary.
    """
