from contextlib import contextmanager
from decimal import Decimal
import json
from typing import Any, Iterable

import mysql.connector
from mysql.connector import Error

from .config import config
from .logger import get_logger


# Module-level logger obtained from the project's get_logger helper; used by
# Database.log_api to record API-log writes that fail.
log = get_logger(__name__)


def _json_default(value: Any) -> str:
    """Fallback serializer for ``json.dumps``: render *value* as text.

    ``Decimal`` values and any other object the JSON encoder cannot handle
    natively are all converted the same way, via ``str(value)``.
    """
    return str(value)


class Database:
    def __init__(self) -> None:
        self.options = {
            "host": config.db_host,
            "port": config.db_port,
            "database": config.db_name,
            "user": config.db_user,
            "password": config.db_password,
            "autocommit": False,
        }

    @contextmanager
    def connection(self):
        conn = mysql.connector.connect(**self.options)
        try:
            yield conn
            conn.commit()
        except Exception:
            conn.rollback()
            raise
        finally:
            conn.close()

    def fetch_all(self, sql: str, params: Iterable[Any] | None = None) -> list[dict]:
        with self.connection() as conn:
            cursor = conn.cursor(dictionary=True)
            cursor.execute(sql, tuple(params or ()))
            rows = cursor.fetchall()
            cursor.close()
            return rows

    def fetch_one(self, sql: str, params: Iterable[Any] | None = None) -> dict | None:
        rows = self.fetch_all(sql, params)
        return rows[0] if rows else None

    def execute(self, sql: str, params: Iterable[Any] | None = None) -> int:
        with self.connection() as conn:
            cursor = conn.cursor()
            cursor.execute(sql, tuple(params or ()))
            last_id = cursor.lastrowid
            cursor.close()
            return last_id

    def execute_many(self, sql: str, rows: list[Iterable[Any]]) -> None:
        if not rows:
            return
        with self.connection() as conn:
            cursor = conn.cursor()
            cursor.executemany(sql, rows)
            cursor.close()

    def setting_map(self) -> dict[str, str]:
        rows = self.fetch_all("SELECT setting_key, setting_value FROM settings")
        return {row["setting_key"]: row["setting_value"] for row in rows}

    def set_setting(self, key: str, value: str) -> None:
        self.execute(
            """
            INSERT INTO settings (setting_key, setting_value)
            VALUES (%s, %s)
            ON DUPLICATE KEY UPDATE setting_value = VALUES(setting_value)
            """,
            (key, value),
        )

    def log_trade_event(
        self,
        message: str,
        level: str = "INFO",
        trade_id: int | None = None,
        pump_event_id: int | None = None,
        context: dict | None = None,
    ) -> None:
        self.execute(
            """
            INSERT INTO trade_logs (trade_id, pump_event_id, level, message, context)
            VALUES (%s, %s, %s, %s, %s)
            """,
            (
                trade_id,
                pump_event_id,
                level,
                message,
                json.dumps(context or {}, default=_json_default),
            ),
        )

    def log_api(self, service: str, endpoint: str, status_code: int | None, success: bool, message: str) -> None:
        try:
            self.execute(
                """
                INSERT INTO api_logs (service, endpoint, status_code, success, message)
                VALUES (%s, %s, %s, %s, %s)
                """,
                (service, endpoint, status_code, 1 if success else 0, message[:2000]),
            )
        except Error:
            log.exception("Failed writing API log")


# Shared module-level instance (presumably imported by the rest of the package).
# Constructing it only copies values from ``config`` into an options dict; no
# database connection is opened until one of its query methods is called.
db = Database()

