#!/usr/bin/env python3
# -*- coding=utf-8 -*-
import logging
import os
import socket
import time
from contextlib import asynccontextmanager

import psutil
import uvicorn
from fastapi import FastAPI, Request, Response
from fastapi.encoders import jsonable_encoder
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.responses import PlainTextResponse
from prometheus_client import generate_latest, Gauge, CONTENT_TYPE_LATEST

from shared.conf import MAX_CONCURRENCY, PORT, RPC_LOCATION, UVICORN_LOG_CONFIG, VERSION
from shared.logging_util import logger as fine_ai_logger
from fine_agent.app import router as fine_agent_router, startup as fine_agent_startup
from fine_agent.multilingual.app import router as i18n_fine_agent_router
from fine_agent.utils.logging_util import logger as fine_agent_logger
from ai_core_suite.app import router as ai_core_suite_router, startup as ai_core_suite_startup
from ai_core_suite.utils.logging_util import logger as ai_core_suite_logger
from chat2dashboard.app import router as chat2dashboard_router, startup as chat2dashboard_startup
from chat2dashboard.utils.logging_util import logger as chat2dashboard_logger
from chat2dashboard.multilingual.en.router import router as en_chat2dashboard_router
from shared.model_util import load_model
from shared.status_codes import *


# Prometheus gauges for this process's cumulative disk I/O (bytes read / written).
# They are only declared here -- no value is bound at import time, so touching
# them before startup() runs raises NameError. startup() creates them and
# prometheus_metrics() sets their values on each scrape.
io_read_gauge: Gauge
io_write_gauge: Gauge


async def startup():
    """Prepare the service before it begins serving requests.

    Steps, in order: log the version banner, load the model (logging how long it
    took), create the module-level Prometheus I/O gauges, await the startup hooks
    of ai_core_suite, fine_agent and chat2dashboard, then log success and the
    version banner again.

    NOTE(review): the gauges are registered on prometheus_client's default
    registry, which rejects duplicate metric names, so running this twice in one
    process (e.g. several test clients) would raise -- confirm it runs only once.
    """
    global io_read_gauge, io_write_gauge

    version_banner = f"------------ Version {VERSION} ------------"
    fine_ai_logger.info(version_banner)

    fine_ai_logger.info("开始加载模型")
    load_began = time.perf_counter()
    load_model()
    load_elapsed = time.perf_counter() - load_began
    fine_ai_logger.info(f"模型已加载，耗时: {load_elapsed:.3f}s")

    # Extra process-level I/O metrics; their values are set by prometheus_metrics().
    io_read_gauge = Gauge("process_io_read_bytes", "IO read in bytes")
    io_write_gauge = Gauge("process_io_write_bytes", "IO write in bytes")

    # Sub-application startup hooks, awaited one after another in this order.
    # (chat2search's startup hook is currently disabled.)
    for sub_startup in (ai_core_suite_startup, fine_agent_startup, chat2dashboard_startup):
        await sub_startup()

    fine_ai_logger.info("fine_ai started successfully")
    fine_ai_logger.info(version_banner)


async def shutdown():
    """Lifespan teardown hook; it only records that the service has stopped."""
    stopped_message = "fine_ai 服务已关闭"
    fine_ai_logger.info(stopped_message)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan context: run startup() before serving, shutdown() after.

    shutdown() sits in ``finally`` so it also runs when the context is exited with
    an exception thrown in at ``yield`` (e.g. task cancellation), not only on a
    clean exit. If startup() itself fails, shutdown() is not called.

    Args:
        app: the FastAPI application (unused; required by the lifespan protocol).
    """
    await startup()
    try:
        yield
    finally:
        await shutdown()


# The FastAPI application. The OpenAPI schema and the Swagger/ReDoc UIs are
# switched off; `lifespan` runs startup()/shutdown() around serving.
app = FastAPI(
    lifespan=lifespan,
    openapi_url=None,
    docs_url=None,
    redoc_url=None,
)

# Sub-application routers, included in this order.
_ROUTERS = (
    ai_core_suite_router,
    chat2dashboard_router,
    fine_agent_router,
    i18n_fine_agent_router,
    en_chat2dashboard_router,
)
for router in _ROUTERS:
    app.include_router(router)


def get_logger(rpc_payload):
    """Pick the logger belonging to the service named in an RPC payload.

    A ``rpc_payload.serviceName`` of "fine-agent" or "ai-core-suite" selects that
    sub-service's logger; any other value falls back to the fine_ai logger.
    """
    service = rpc_payload.serviceName
    if service == "fine-agent":
        return fine_agent_logger
    if service == "ai-core-suite":
        return ai_core_suite_logger
    return fine_ai_logger


@app.middleware("http")
async def handle_response(request: Request, call_next):
    """Access-log every request and normalize JSON content types.

    For each request this:
    * measures wall-clock time and the process RSS before and after the
      downstream handler runs;
    * labels JSON (or untyped) responses as ``application/json; charset=UTF-8``,
      leaving responses that declare another media type untouched;
    * logs status code, path, elapsed time, RSS before/after and client address
      for every path except ``/prometheus/metrics`` (presumably to keep frequent
      scrapes out of the access log).

    Any exception raised downstream is logged and converted into a generic
    HTTP 500 JSON response.
    """
    try:
        start_time = time.perf_counter()
        process = psutil.Process(os.getpid())
        rss_before = process.memory_info().rss

        response = await call_next(request)

        # Only (re)label JSON or untyped responses. Forcing application/json onto
        # everything mislabels e.g. the Prometheus exposition format returned by
        # /prometheus/metrics and PlainTextResponse bodies.
        content_type = response.headers.get("content-type", "")
        if not content_type or content_type.lower().startswith("application/json"):
            response.headers["Content-Type"] = "application/json; charset=UTF-8"

        rss_after = process.memory_info().rss
        if request.url.path != "/prometheus/metrics":
            fine_ai_logger.info(
                f"code: {response.status_code} | "
                f"URL: {request.url.path} | "
                f"time: {time.perf_counter() - start_time:.3f}s | "
                f"memory: {rss_before}->{rss_after} bytes | "
                f"client: {request.client.host + ':' + str(request.client.port) if request.client else None} | "
                "\n"
            )
        return response

    except Exception as e:
        fine_ai_logger.exception(e)
        return JSONResponse(status_code=500, content={"detail": "Internal Server Error"})


@app.exception_handler(RequestValidationError)
def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Log a request-validation failure and return it to the client as HTTP 422.

    ``exc.errors()`` may contain values the stdlib JSON encoder cannot serialize
    (for example an exception instance under ``ctx`` when a custom validator
    raises). Passing them through ``jsonable_encoder`` -- as FastAPI's own default
    handler does -- keeps that from turning the 422 into a 500.
    """
    errors = exc.errors()
    fine_ai_logger.warning(f"RequestValidationError: {errors}")
    return JSONResponse(status_code=422, content={"detail": jsonable_encoder(errors)})


@app.get("/openapi.json")
@app.get("/docs")
@app.get("/redoc")
async def disable_doc_url():
    """Answer the usual API-documentation URLs with an explicit JSON 404.

    The app is created with openapi_url/docs_url/redoc_url set to None; these
    routes respond to those paths with ``{"detail": "Not Found"}``.
    """
    not_found_body = {"detail": "Not Found"}
    return JSONResponse(content=not_found_body, status_code=404)


@app.get("/prometheus/metrics")
async def prometheus_metrics():
    """OPS monitoring endpoint; returns metrics in the Prometheus text format.

    The default registry already carries CPU-time and memory metrics. This adds the
    process's cumulative disk-I/O byte counts when the platform supports them.

    Returns:
        The ``generate_latest()`` payload with ``CONTENT_TYPE_LATEST``, or a generic
        500 JSON response if collecting the metrics fails.
    """
    try:
        proc = psutil.Process()
        # psutil does not provide Process.io_counters() on every platform (e.g.
        # macOS). Skip the I/O gauges there instead of failing the whole scrape.
        if hasattr(proc, "io_counters"):
            io_counters = proc.io_counters()
            io_read_gauge.set(io_counters.read_bytes)
            io_write_gauge.set(io_counters.write_bytes)
        return Response(generate_latest(), media_type=CONTENT_TYPE_LATEST)
    except Exception as e:
        fine_ai_logger.error("获取 prometheus 监控信息失败")
        fine_ai_logger.exception(e)
        return JSONResponse(status_code=500, content={"detail": "Internal Server Error"})


@app.get(f"/webroot/{RPC_LOCATION}/alive")
def is_alive():
    """Liveness probe under the RPC web root; always returns True."""
    alive_flag = True
    return alive_flag


@app.get("/v1/{language}/version")
@app.get("/v1/version")
def get_version():
    """Report the deployed service version as ``{"version": VERSION}``.

    The optional ``{language}`` path segment is accepted but not used.
    """
    return dict(version=VERSION)


@app.get("/v1/{language}/alive")
@app.get("/v1/alive")
def alive():
    """Liveness probe; returns the text ``alive`` as a PlainTextResponse.

    The optional ``{language}`` path segment is accepted but not used.
    """
    return PlainTextResponse(content="alive")


@app.put("/loglevel/{level}")
async def set_log_level(level: str):
    """Change the level of all service loggers at runtime.

    ``level`` is stripped and lower-cased, then matched against "critical",
    "fatal", "error", "info" and "debug"; any value starting with "warn" maps to
    WARNING. The new level is applied to the fine_ai, ai_core_suite, fine_agent
    and chat2dashboard loggers, in that order.

    Returns:
        True on success; False for an unrecognized level (a warning is logged)
        or if an unexpected error occurs (it is logged).
    """
    exact_levels = {
        "critical": logging.CRITICAL,
        "fatal": logging.FATAL,
        "error": logging.ERROR,
        "info": logging.INFO,
        "debug": logging.DEBUG,
    }
    try:
        level = level.strip().lower()
        log_level = exact_levels.get(level)
        if log_level is None and level.startswith("warn"):
            log_level = logging.WARNING
        if log_level is None:
            fine_ai_logger.warning(f"未知的日志级别 '{level}'")
            return False
        for target_logger in (fine_ai_logger, ai_core_suite_logger, fine_agent_logger, chat2dashboard_logger):
            target_logger.setLevel(log_level)
    except Exception as e:
        fine_ai_logger.error("调整日志级别时出现未知错误")
        fine_ai_logger.exception(e)
        return False
    return True


def get_local_ip(probe_host: str = "8.8.8.8", probe_port: int = 80) -> str:
    """Return the local IPv4 address of the interface used to reach *probe_host*.

    Connecting a UDP socket sends no packets; it only makes the OS choose a route
    and bind a local address, which ``getsockname()`` reports. The default probe
    target is a public address, so the outward-facing interface is chosen. Pass an
    internal address in networks that have no route to it.

    Args:
        probe_host: address used only for route selection.
        probe_port: port paired with *probe_host* (no traffic is sent).

    Raises:
        OSError: if the OS has no route to *probe_host*.
    """
    # The context manager guarantees the socket is closed, even if connect() fails.
    with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as skt:
        skt.connect((probe_host, probe_port))
        return skt.getsockname()[0]


if __name__ == "__main__":
    # Pass the application object rather than the "app:app" import string. With
    # reload disabled and a single worker, uvicorn does not need to import the app
    # itself. The import string would re-execute this file's module-level code a
    # second time (as module `app`) and only works when the file is named app.py.
    uvicorn.run(
        app=app,
        host=get_local_ip(),
        port=PORT,
        reload=False,
        limit_concurrency=MAX_CONCURRENCY,
        log_config=UVICORN_LOG_CONFIG,
    )
