"""Misere Hex in pygame.

Board encoding:
    0 = empty
    1 = red
    2 = blue

Coordinate convention:
    Strategies receive the board as a NumPy array indexed as board[y, x]
    and must return a move as the tuple (x, y).

Default strategies:
    - Red (player 1): random legal move.
    - Blue (player 2): for each legal Blue move, evaluate how close Blue
      would be to a normal top-to-bottom connection after making that move,
      then choose the move that is *worst* for Blue's own connection. In
      misere Hex, that means Blue tries to avoid helping Blue connect.

Misere rule:
    If Red connects Red's sides, Blue wins.
    If Blue connects Blue's sides, Red wins.
"""

from __future__ import annotations

import math
import random
from collections import deque
from dataclasses import dataclass
from typing import Callable

import numpy as np

try:
    import pygame
except ImportError:  # Allows import and non-GUI testing without pygame installed.
    pygame = None


# Cell values stored in the board array (see "Board encoding" above).
EMPTY = 0
RED = 1
BLUE = 2

# Display names used when building status messages.
PLAYER_NAMES = {
    RED: "Red",
    BLUE: "Blue",
}

# RGB fill colour of a cell for each cell value.
STONE_COLORS = {
    EMPTY: (232, 232, 232),
    RED: (212, 69, 69),
    BLUE: (63, 110, 214),
}

# RGB colour of the edge markings for the sides each player must join.
BORDER_COLORS = {
    RED: (180, 32, 32),
    BLUE: (35, 81, 190),
}

# General UI palette (RGB tuples).
BACKGROUND_COLOR = (247, 244, 236)
OUTLINE_COLOR = (72, 72, 72)
TEXT_COLOR = (22, 22, 22)
INFO_BG = (240, 236, 227)
LAST_MOVE_MARKER = (18, 18, 18)
BUTTON_FILL = (226, 226, 226)
BUTTON_HOVER_FILL = (210, 219, 232)
BUTTON_TEXT_COLOR = TEXT_COLOR

# Board, layout and timing configuration.
BOARD_SIZE = 7  # cells per side of the default board
HEX_RADIUS = 32  # pixels from a cell's centre to each of its corners
FPS = 60  # frame-rate cap passed to the pygame clock
START_DELAY_MS = 700  # pause before the first move of a (re)started game
MOVE_DELAY_MS = 550  # pause between consecutive moves
MARGIN_X = 80  # horizontal padding around the board, in pixels
MARGIN_Y = 80  # vertical padding above and below the board, in pixels
INFO_HEIGHT = 160  # height of the status panel at the bottom of the window
BUTTON_WIDTH = 150
BUTTON_HEIGHT = 42
BUTTON_MARGIN = 12  # gap between the Restart button and the window bottom

# A strategy is called with a copy of the board (indexed board[y, x]) and
# returns its chosen move as (x, y).
Strategy = Callable[[np.ndarray], tuple[int, int]]

# Neighbor offsets (dx, dy) for axial-like coordinates laid out as a rhombus.
# Each row is drawn half a cell to the right of the row above it, so besides
# its left/right neighbours a cell touches two cells in each adjacent row.
NEIGHBOR_DELTAS = (
    (1, 0),  # right
    (-1, 0),  # left
    (0, 1),  # lower-right
    (0, -1),  # upper-left
    (1, -1),  # upper-right
    (-1, 1),  # lower-left
)


def other_player(player: int) -> int:
    """Return the opponent of ``player``; every value except RED maps to RED."""
    if player == RED:
        return BLUE
    return RED


def in_bounds(size: int, x: int, y: int) -> bool:
    """Return True when (x, y) is a cell of a ``size`` x ``size`` board."""
    return 0 <= min(x, y) and max(x, y) < size


def neighbors(size: int, x: int, y: int):
    """Lazily yield every on-board cell adjacent to (x, y) as an (x, y) tuple.

    Cells come out in the same order as NEIGHBOR_DELTAS; off-board offsets
    are skipped.
    """
    shifted = ((x + dx, y + dy) for dx, dy in NEIGHBOR_DELTAS)
    yield from (cell for cell in shifted if in_bounds(size, cell[0], cell[1]))


def legal_moves(board: np.ndarray) -> list[tuple[int, int]]:
    """List every empty cell as an (x, y) move, scanning row by row.

    The board is indexed board[y, x] and its side length is taken from
    ``board.shape[0]``.
    """
    n = board.shape[0]
    empties: list[tuple[int, int]] = []
    for row in range(n):
        for col in range(n):
            if board[row, col] == EMPTY:
                empties.append((col, row))
    return empties


def has_connection(board: np.ndarray, player: int) -> bool:
    """Return True if `player` has connected their own normal Hex sides.

    Red must link column 0 to the last column. Any other player value follows
    the Blue rules: seeds are Blue stones in row 0 and the goal is the last
    row. The search floods through adjacent cells holding ``player``.
    """
    n = board.shape[0]
    last = n - 1

    if player == RED:
        pending = [(0, row) for row in range(n) if board[row, 0] == RED]
        goal_axis = 0  # Red finishes when x reaches the last column.
    else:
        pending = [(col, 0) for col in range(n) if board[0, col] == BLUE]
        goal_axis = 1  # Blue finishes when y reaches the last row.

    visited = set(pending)
    while pending:
        cell = pending.pop()
        if cell[goal_axis] == last:
            return True
        for nxt in neighbors(n, cell[0], cell[1]):
            if nxt not in visited and board[nxt[1], nxt[0]] == player:
                visited.add(nxt)
                pending.append(nxt)
    return False


def shortest_connection_cost(board: np.ndarray, player: int) -> float:
    """Return the fewest empty cells `player` must still fill to connect.

    This is a 0-1 breadth-first search. Entering one of the player's own
    stones costs 0, entering an empty cell costs 1, and opponent stones are
    impassable. Red searches from column 0 to the last column. Any other
    player value searches from row 0 to the last row, as Blue. The result is
    ``float("inf")`` when no path exists; a smaller value means the player is
    closer to their normal Hex connection.
    """
    n = board.shape[0]
    blocker = other_player(player)
    unreached = 10**9
    cost = np.full((n, n), unreached, dtype=np.int32)
    queue: deque[tuple[int, int]] = deque()

    def relax(cx: int, cy: int, base: int) -> None:
        # Offer (cx, cy) a path costing `base` plus the price of entering it.
        occupant = int(board[cy, cx])
        if occupant == blocker:
            return
        step = 0 if occupant == player else 1
        candidate = base + step
        if candidate >= cost[cy, cx]:
            return
        cost[cy, cx] = candidate
        # Zero-cost steps go to the front so cells leave the deque in cost order.
        if step == 0:
            queue.appendleft((cx, cy))
        else:
            queue.append((cx, cy))

    if player == RED:
        entries = [(0, row) for row in range(n)]
        exits = [(n - 1, row) for row in range(n)]
    else:
        entries = [(col, 0) for col in range(n)]
        exits = [(col, n - 1) for col in range(n)]

    for ex, ey in entries:
        relax(ex, ey, 0)

    while queue:
        qx, qy = queue.popleft()
        reached = int(cost[qy, qx])
        for nx, ny in neighbors(n, qx, qy):
            relax(nx, ny, reached)

    best = min(int(cost[gy, gx]) for gx, gy in exits)
    return float("inf") if best >= unreached else float(best)


def random_red_strategy(board: np.ndarray) -> tuple[int, int]:
    """Player 1 default: pick uniformly at random among the empty cells.

    Raises:
        ValueError: if the board has no empty cell.
    """
    candidates = legal_moves(board)
    if candidates:
        return random.choice(candidates)
    raise ValueError("No legal moves are available.")



def blue_anti_connection_strategy(board: np.ndarray) -> tuple[int, int]:
    """Player 2 default: choose the move worst for Blue's own connection.

    For each legal move, place a Blue stone on a scratch copy of the board,
    measure Blue's shortest remaining top-to-bottom connection cost, and keep
    the moves with the *largest* resulting cost. Moves that immediately
    complete a Blue connection are ranked last, because they lose at once
    under the misere rule. Equal costs are broken by preferring fewer
    adjacent Blue stones, then cells off Blue's goal rows (first/last row),
    then cells farther from the board centre (Manhattan distance in (x, y)
    coordinates). Any remaining tie is broken uniformly at random.

    The input board is never modified, so read-only arrays are accepted.

    Raises:
        ValueError: if the board has no empty cell.
    """
    moves = legal_moves(board)
    if not moves:
        raise ValueError("No legal moves are available.")

    size = board.shape[0]
    center = (size - 1) / 2.0
    # Hypothetical stones go into a private copy. Writing into `board` itself
    # would corrupt the caller's array if anything below raised before the
    # cell was restored.
    scratch = board.copy()

    best_score: tuple[float, int, int, float] | None = None
    best_moves: list[tuple[int, int]] = []

    for x, y in moves:
        scratch[y, x] = BLUE
        loses_immediately = has_connection(scratch, BLUE)
        blue_cost = shortest_connection_cost(scratch, BLUE)
        scratch[y, x] = EMPTY

        # Tie-break features are measured on the untouched input position.
        blue_neighbors = sum(1 for nx, ny in neighbors(size, x, y) if board[ny, nx] == BLUE)
        on_blue_goal_edge = 1 if y == 0 or y == size - 1 else 0
        center_distance = abs(x - center) + abs(y - center)

        if loses_immediately:
            primary_score = -1e9
        elif math.isinf(blue_cost):
            # Blue can no longer connect at all, which is ideal for Blue here.
            primary_score = 1e9
        else:
            primary_score = blue_cost

        # Tuples compare lexicographically; larger is better for every field.
        score = (primary_score, -blue_neighbors, -on_blue_goal_edge, center_distance)
        if best_score is None or score > best_score:
            best_score = score
            best_moves = [(x, y)]
        elif score == best_score:
            best_moves.append((x, y))

    return random.choice(best_moves)


# Backward-compatible alias for earlier code that referenced the old name.
# Both names are bound to the same function object.
blue_blocking_strategy = blue_anti_connection_strategy


@dataclass(frozen=True)
class BoardGeometry:
    """Pixel layout of the pointy-top hex board, drawn as a rhombus.

    Cell (x, y) is centred ``(x + y / 2)`` horizontal steps and ``y``
    vertical steps from the top-left margin, so each row sits half a cell to
    the right of the row above it.
    """

    size: int
    radius: float = HEX_RADIUS
    margin_x: int = MARGIN_X
    margin_y: int = MARGIN_Y
    info_height: int = INFO_HEIGHT

    @property
    def x_step(self) -> float:
        """Horizontal distance between centres of neighbouring cells in a row."""
        return self.radius * math.sqrt(3.0)

    @property
    def y_step(self) -> float:
        """Vertical distance between centres of consecutive rows."""
        return self.radius * 1.5

    def center(self, x: int, y: int) -> tuple[float, float]:
        """Return the floating-point pixel centre of cell (x, y)."""
        shifted_col = x + 0.5 * y
        return (
            self.margin_x + shifted_col * self.x_step,
            self.margin_y + y * self.y_step,
        )

    def polygon(self, x: int, y: int) -> list[tuple[int, int]]:
        """Return the six rounded corner pixels of cell (x, y).

        Order: top, upper-right, lower-right, bottom, lower-left, upper-left.
        """
        cx, cy = self.center(x, y)
        r = self.radius
        half_width = 0.5 * math.sqrt(3.0) * r
        half_r = 0.5 * r
        offsets = (
            (0.0, -r),
            (half_width, -half_r),
            (half_width, half_r),
            (0.0, r),
            (-half_width, half_r),
            (-half_width, -half_r),
        )
        return [(int(round(cx + ox)), int(round(cy + oy))) for ox, oy in offsets]

    def window_size(self) -> tuple[int, int]:
        """Return (width, height) large enough for the board plus info panel."""
        vertices = [
            vertex
            for row in range(self.size)
            for col in range(self.size)
            for vertex in self.polygon(col, row)
        ]
        # The leading 0 keeps the original lower bound (and handles size 0).
        right = max([0, *(px for px, _ in vertices)])
        bottom = max([0, *(py for _, py in vertices)])
        return right + self.margin_x, bottom + self.margin_y + self.info_height


class MisereHexGame:
    """Board state and turn logic for one game of misere Hex.

    Red moves first and the players alternate. Each turn the current
    player's strategy is called with a copy of the board. A strategy that
    raises or returns an invalid move forfeits. A player who connects their
    own sides loses (Red: left-to-right, Blue: top-to-bottom).
    """

    def __init__(
        self,
        size: int,
        red_strategy: Strategy,
        blue_strategy: Strategy,
    ) -> None:
        self.size = size
        # Cell values are EMPTY, RED or BLUE, indexed as board[y, x].
        self.board = np.zeros((size, size), dtype=np.int8)
        self.strategies = {
            RED: red_strategy,
            BLUE: blue_strategy,
        }
        self.reset()

    def reset(self) -> None:
        """Clear the board and return to the starting position, Red to move."""
        self.board.fill(EMPTY)
        self.current_player = RED
        self.last_move: tuple[int, int] | None = None
        self.move_count = 0
        self.game_over = False
        # Both stay None until the game ends. A full board without a
        # connection ends the game with both still None.
        self.winner: int | None = None
        self.losing_player: int | None = None
        self.status_text = "Red to move."

    def _validate_move(self, move: tuple[int, int] | list[int] | np.ndarray) -> tuple[int, int]:
        """Check a strategy's proposed move and return it as plain ints (x, y).

        Raises:
            ValueError: if ``move`` is not a length-2 tuple, list or 1-D array,
                if a coordinate is not an integer value, if the cell is off the
                board, or if it is already occupied.
        """
        if not isinstance(move, (tuple, list, np.ndarray)):
            raise ValueError(f"Strategy must return (x, y); got {move!r}")
        # A 0-d array has no len(); reject it (and any non-1-D array) explicitly
        # instead of surfacing a confusing TypeError message.
        if (isinstance(move, np.ndarray) and move.ndim != 1) or len(move) != 2:
            raise ValueError(f"Strategy must return (x, y); got {move!r}")

        coords: list[int] = []
        for value in move:
            coord = int(value)
            # int() silently truncates (1.9 -> 1) and parses strings ("2" -> 2),
            # which would place a stone the strategy never chose. Only accept
            # a coordinate if the conversion preserves its value.
            if coord != value:
                raise ValueError(f"Move coordinates must be integers; got {move!r}")
            coords.append(coord)
        x, y = coords

        if not in_bounds(self.size, x, y):
            raise ValueError(f"Move {(x, y)} is out of bounds.")
        if self.board[y, x] != EMPTY:
            raise ValueError(f"Move {(x, y)} is not on an empty cell.")
        return x, y

    def _forfeit(self, offender: int, reason: str) -> None:
        """End the game with `offender` losing and record `reason` in the status."""
        self.game_over = True
        self.losing_player = offender
        self.winner = other_player(offender)
        self.status_text = f"{PLAYER_NAMES[self.winner]} wins by forfeit: {reason}"

    def make_next_move(self) -> None:
        """Play one turn for the current player and update the game state.

        Does nothing once the game is over. The strategy gets a copy of the
        board, so it cannot alter the real position. If it raises or returns
        an invalid move, the current player forfeits. Otherwise the stone is
        placed and one of three things happens:

        - if it connects the mover's own sides, the mover loses;
        - if the board is now full, the game ends with no winner;
        - otherwise the turn passes to the other player.
        """
        if self.game_over:
            return

        strategy = self.strategies[self.current_player]
        try:
            proposed_move = strategy(self.board.copy())
            x, y = self._validate_move(proposed_move)
        except Exception as exc:  # noqa: BLE001 - deliberate strategy sandboxing.
            self._forfeit(self.current_player, str(exc))
            return

        self.board[y, x] = self.current_player
        self.last_move = (x, y)
        self.move_count += 1

        # Only the mover's stones changed, so only the mover can have connected.
        if has_connection(self.board, self.current_player):
            self.losing_player = self.current_player
            self.winner = other_player(self.current_player)
            self.game_over = True
            if self.current_player == RED:
                self.status_text = "Blue wins! Red connected left-to-right."
            else:
                self.status_text = "Red wins! Blue connected top-to-bottom."
            return

        if not legal_moves(self.board):
            self.game_over = True
            self.winner = None
            self.status_text = "Board full. No winner detected (unexpected in Hex)."
            return

        self.current_player = other_player(self.current_player)
        self.status_text = f"{PLAYER_NAMES[self.current_player]} to move."



def get_restart_button_rect(window_width: int, window_height: int):
    """Return the Restart button's rect, centred horizontally at the window bottom.

    Returns None when pygame is not available.
    """
    if pygame is not None:
        left = (window_width - BUTTON_WIDTH) // 2
        top = window_height - BUTTON_HEIGHT - BUTTON_MARGIN
        return pygame.Rect(left, top, BUTTON_WIDTH, BUTTON_HEIGHT)
    return None



def draw_restart_button(screen, rect, font) -> None:
    """Draw the rounded Restart button, highlighted while the mouse is over it."""
    if pygame is None:
        return

    mouse_over = rect.collidepoint(pygame.mouse.get_pos())
    background = BUTTON_HOVER_FILL if mouse_over else BUTTON_FILL
    pygame.draw.rect(screen, background, rect, border_radius=10)
    pygame.draw.rect(screen, OUTLINE_COLOR, rect, 2, border_radius=10)

    text = font.render("Restart", True, BUTTON_TEXT_COLOR)
    screen.blit(text, text.get_rect(center=rect.center))



def draw_borders(screen, geometry: BoardGeometry) -> None:
    """Draw the coloured edge markings for the sides each player must join.

    Red's edges are the left and right of the board; Blue's are the top and
    bottom. Polygon vertices from ``geometry.polygon`` are indexed
    0 top, 1 upper-right, 2 lower-right, 3 bottom, 4 lower-left, 5 upper-left.

    Each row is shifted half a cell right of the row above, so the left and
    right board edges zigzag. Every edge cell exposes a vertical side and a
    slanted side, and both are drawn so the Red border is continuous, like
    the Blue one. At the top-right and bottom-left corner cells, the slanted
    side shared with a Blue edge is left to the Blue loop.
    """
    if pygame is None:
        return

    red_width = 7
    blue_width = 7
    size = geometry.size
    last = size - 1

    for y in range(size):
        left_poly = geometry.polygon(0, y)
        right_poly = geometry.polygon(last, y)
        # Left edge: vertical side, then the lower-left slant toward the next row.
        pygame.draw.line(screen, BORDER_COLORS[RED], left_poly[5], left_poly[4], red_width)
        if y < last:
            pygame.draw.line(screen, BORDER_COLORS[RED], left_poly[4], left_poly[3], red_width)
        # Right edge: upper-right slant from the previous row, then the vertical side.
        if y > 0:
            pygame.draw.line(screen, BORDER_COLORS[RED], right_poly[0], right_poly[1], red_width)
        pygame.draw.line(screen, BORDER_COLORS[RED], right_poly[1], right_poly[2], red_width)

    for x in range(size):
        top_poly = geometry.polygon(x, 0)
        bottom_poly = geometry.polygon(x, last)
        pygame.draw.line(screen, BORDER_COLORS[BLUE], top_poly[5], top_poly[0], blue_width)
        pygame.draw.line(screen, BORDER_COLORS[BLUE], top_poly[0], top_poly[1], blue_width)
        pygame.draw.line(screen, BORDER_COLORS[BLUE], bottom_poly[4], bottom_poly[3], blue_width)
        pygame.draw.line(screen, BORDER_COLORS[BLUE], bottom_poly[3], bottom_poly[2], blue_width)



def draw_info_panel(screen, game: MisereHexGame, window_width: int, window_height: int, fonts, restart_rect) -> None:
    """Draw the bottom status panel: status line, rules legend and Restart button."""
    if pygame is None:
        return

    title_font, body_font = fonts
    panel = pygame.Rect(0, window_height - INFO_HEIGHT, window_width, INFO_HEIGHT)
    pygame.draw.rect(screen, INFO_BG, panel)
    pygame.draw.line(screen, OUTLINE_COLOR, (0, panel.top), (window_width, panel.top), 2)

    # (text, font, gap below): the status line uses the larger title font.
    entries = (
        (game.status_text, title_font, 6),
        ("Red: random. Blue: avoid helping Blue connect.", body_font, 4),
        ("Misere Hex: connecting your own sides loses.", body_font, 4),
    )

    cursor_y = panel.top + 12
    for text, font, gap in entries:
        rendered = font.render(text, True, TEXT_COLOR)
        screen.blit(rendered, (20, cursor_y))
        cursor_y += rendered.get_height() + gap

    draw_restart_button(screen, restart_rect, body_font)



def draw_game(screen, game: MisereHexGame, geometry: BoardGeometry, fonts, restart_rect) -> None:
    """Render one full frame (cells, edge borders, last-move dot, info panel) and flip."""
    if pygame is None:
        return

    width, height = geometry.window_size()
    screen.fill(BACKGROUND_COLOR)

    cells = ((col, row) for row in range(game.size) for col in range(game.size))
    for col, row in cells:
        outline = geometry.polygon(col, row)
        pygame.draw.polygon(screen, STONE_COLORS[int(game.board[row, col])], outline)
        pygame.draw.polygon(screen, OUTLINE_COLOR, outline, 2)

    draw_borders(screen, geometry)

    last = game.last_move
    if last is not None:
        mx, my = geometry.center(*last)
        marker_radius = max(4, int(geometry.radius * 0.16))
        pygame.draw.circle(screen, LAST_MOVE_MARKER, (int(round(mx)), int(round(my))), marker_radius)

    draw_info_panel(screen, game, width, height, fonts, restart_rect)
    pygame.display.flip()



def run_game(
    board_size: int = BOARD_SIZE,
    red_strategy: Strategy = random_red_strategy,
    blue_strategy: Strategy = blue_anti_connection_strategy,
) -> None:
    """Open the game window and let the two strategies play until it is closed.

    Moves are made automatically: the first after START_DELAY_MS, then one
    every MOVE_DELAY_MS. Clicking Restart resets the game and the start delay.

    Raises:
        SystemExit: if pygame is not installed.
    """
    if pygame is None:
        raise SystemExit("pygame is not installed. Please install it with: pip install pygame numpy")

    pygame.init()
    # Release pygame even if the loop raises (e.g. a drawing error), so a
    # caller that catches the exception is not left with a dead window.
    try:
        pygame.display.set_caption("Misere Hex")

        geometry = BoardGeometry(board_size)
        window_size = geometry.window_size()
        screen = pygame.display.set_mode(window_size)
        restart_rect = get_restart_button_rect(*window_size)

        title_font = pygame.font.SysFont(None, 30)
        body_font = pygame.font.SysFont(None, 24)
        fonts = (title_font, body_font)

        game = MisereHexGame(board_size, red_strategy, blue_strategy)
        clock = pygame.time.Clock()
        next_move_at = pygame.time.get_ticks() + START_DELAY_MS

        running = True
        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.MOUSEBUTTONDOWN and event.button == 1:
                    if restart_rect.collidepoint(event.pos):
                        game.reset()
                        next_move_at = pygame.time.get_ticks() + START_DELAY_MS

            now = pygame.time.get_ticks()
            if not game.game_over and now >= next_move_at:
                game.make_next_move()
                next_move_at = now + MOVE_DELAY_MS

            draw_game(screen, game, geometry, fonts, restart_rect)
            clock.tick(FPS)
    finally:
        pygame.quit()


# Script entry point: open the GUI with the default board size and strategies.
if __name__ == "__main__":
    run_game()
