137 lines
5.3 KiB
Python
137 lines
5.3 KiB
Python
# api/v1/chat.py
|
|
import httpx
|
|
import json
|
|
import asyncio
|
|
from fastapi import APIRouter, HTTPException, Query
|
|
from fastapi.responses import StreamingResponse
|
|
from typing import List, Dict, Any
|
|
from models.chat import ChatRequest, ChatResponse
|
|
from services import redis_service
|
|
from services.history_manager import prepare_history
|
|
from utils.logging import logger
|
|
from config import settings
|
|
|
|
# Router instance that the endpoint decorators in this module register on.
router = APIRouter()
|
|
|
|
|
|
@router.post("/chat", response_model=ChatResponse)
async def chat_endpoint(payload: ChatRequest):
    """Handle one non-streaming chat turn.

    Steps:
      1. Reuse ``payload.session_id`` or create a new session via
         ``redis_service.create_session``.
      2. Store the user message, then build the message list to send with
         ``prepare_history``.
      3. Pick the model: request value, then the session's ``model_name``,
         then ``settings.MODEL_NAME``.
      4. POST to ``settings.LM_STUDIO_URL`` and read
         ``choices[0].message.content`` from the JSON response.
      5. Store the assistant reply and return it with the session id.

    Raises:
        HTTPException: status 500 with a generic message for any unexpected
            failure (the full traceback is logged). An ``HTTPException``
            raised by a collaborator is propagated unchanged.
    """
    try:
        session_id = payload.session_id
        if not session_id:
            meta = redis_service.create_session(payload.user_id, payload.message)
            session_id = meta["session_id"]

        redis_service.save_chat(
            payload.user_id,
            session_id,
            {"role": "user", "content": payload.message},
        )
        history_to_send = await prepare_history(payload.user_id, session_id)

        # Retrieve the model stored on the session. ``or {}`` keeps the
        # fallback chain working if the lookup yields nothing.
        session_meta = redis_service.get_session_meta(payload.user_id, session_id) or {}
        model_to_use = (
            payload.model_name
            or session_meta.get("model_name")
            or settings.MODEL_NAME
        )

        async with httpx.AsyncClient(timeout=settings.REQUEST_TIMEOUT) as client:
            resp = await client.post(
                settings.LM_STUDIO_URL,
                json={"model": model_to_use, "messages": history_to_send},
            )
            resp.raise_for_status()
            data = resp.json()

        reply = data["choices"][0]["message"]["content"]
        redis_service.save_chat(
            payload.user_id,
            session_id,
            {"role": "assistant", "content": reply},
        )

        return ChatResponse(response=reply, session_id=session_id)

    except HTTPException:
        # Deliberate HTTP errors must not be flattened into a generic 500.
        raise
    except Exception as exc:
        logger.exception("Error in /chat endpoint")
        raise HTTPException(status_code=500, detail="Internal server error") from exc
|
|
|
|
|
|
def _strip_sse_data_prefix(raw_line: str) -> str:
    """Return the payload of an upstream line without its ``data:`` prefix.

    Both ``data: {...}`` and ``data:{...}`` are accepted; lines without the
    prefix are returned as-is. Surrounding whitespace is stripped.
    """
    if raw_line.startswith("data:"):
        raw_line = raw_line[len("data:"):]
    return raw_line.strip()


def _extract_stream_text(payload_str: str) -> str:
    """Return the text fragment carried by one streamed JSON chunk, or ``""``.

    Reads ``choices[0].delta.content`` and falls back to ``choices[0].text``.
    Anything that is not valid JSON or does not have that shape (for example
    an empty ``choices`` list or a null ``delta``) yields ``""`` instead of
    raising, so a single odd chunk cannot abort the stream.
    """
    try:
        obj = json.loads(payload_str)
    except json.JSONDecodeError:
        return ""
    if not isinstance(obj, dict):
        return ""

    choices = obj.get("choices")
    if not isinstance(choices, list) or not choices:
        return ""
    choice = choices[0]
    if not isinstance(choice, dict):
        return ""

    delta = choice.get("delta")
    piece = delta.get("content") if isinstance(delta, dict) else None
    piece = piece or choice.get("text")
    return piece if isinstance(piece, str) else ""


def _format_sse_error(message: str) -> str:
    """Format *message* as an SSE ``error`` event.

    Each line of the message gets its own ``data:`` field so that embedded
    newlines cannot terminate the event early. A single-line message produces
    ``event: error\\ndata: <message>\\n\\n``.
    """
    data_lines = "\n".join(f"data: {part}" for part in (message.splitlines() or [""]))
    return f"event: error\n{data_lines}\n\n"


@router.post("/chat-stream")
async def chat_stream_endpoint(payload: ChatRequest):
    """Handle one chat turn and relay the model output as server-sent events.

    The session is resolved (or created), the user message is stored and the
    message list is built with ``prepare_history`` before the response starts.
    The returned ``StreamingResponse`` then forwards every non-empty upstream
    line as a ``data:`` event until ``[DONE]`` is seen. Text fragments found
    in the chunks are collected and stored as the assistant message once the
    stream ends, including when it ends early with an error.

    Errors raised while streaming are logged and reported to the client as an
    ``event: error`` SSE event.
    """
    session_id = payload.session_id
    if not session_id:
        meta = redis_service.create_session(payload.user_id, payload.message)
        session_id = meta["session_id"]

    redis_service.save_chat(
        payload.user_id,
        session_id,
        {"role": "user", "content": payload.message},
    )
    history_to_send = await prepare_history(payload.user_id, session_id)

    # Retrieve the model stored on the session. ``or {}`` keeps the fallback
    # chain working if the lookup yields nothing.
    session_meta = redis_service.get_session_meta(payload.user_id, session_id) or {}
    model_to_use = (
        payload.model_name
        or session_meta.get("model_name")
        or settings.MODEL_NAME
    )

    async def event_generator():
        """Yield SSE frames relayed from the upstream streaming response."""
        pieces = []
        try:
            async with httpx.AsyncClient(timeout=None) as client:
                async with client.stream(
                    "POST",
                    settings.LM_STUDIO_URL,
                    json={
                        "model": model_to_use,
                        "messages": history_to_send,
                        "stream": True,
                    },
                ) as r:
                    # NOTE(review): the upstream HTTP status is not checked
                    # here; a non-2xx body is relayed line by line as data.
                    async for raw_line in r.aiter_lines():
                        if not raw_line:
                            continue

                        payload_str = _strip_sse_data_prefix(raw_line)

                        if payload_str == "[DONE]":
                            yield "data: [DONE]\n\n"
                            break

                        yield f"data: {payload_str}\n\n"

                        piece = _extract_stream_text(payload_str)
                        if piece:
                            pieces.append(piece)

                        # Hand control back to the event loop between chunks.
                        await asyncio.sleep(0)
        except Exception as e:
            logger.exception("Streaming error in /chat-stream")
            yield _format_sse_error(str(e))
        finally:
            # Persist whatever text was received, even after a failure.
            assistant_text = "".join(pieces)
            if assistant_text:
                redis_service.save_chat(
                    payload.user_id,
                    session_id,
                    {"role": "assistant", "content": assistant_text},
                )

    headers = {
        "Cache-Control": "no-cache",
        "Connection": "keep-alive",
        "X-Accel-Buffering": "no",
    }
    return StreamingResponse(event_generator(), media_type="text/event-stream", headers=headers)
|
|
|
|
|
|
@router.get("/history")
async def get_history(
    user_id: str = Query(..., description="User ID"),
    session_id: str = Query(..., description="Session ID"),
    limit: int = Query(50, description="Max number of messages to return"),
) -> List[Dict[str, Any]]:
    """Return the stored messages of a chat session.

    The request is logged, then ``user_id``, ``session_id`` and ``limit`` are
    passed to ``redis_service.get_chat``. A falsy result is returned as an
    empty list.
    """
    logger.info(f"[GET /history] user_id={user_id}, session_id={session_id}, limit={limit}")

    messages = redis_service.get_chat(user_id, session_id, limit=limit)
    if not messages:
        # Always hand the client a list, never None or another falsy value.
        return []
    return messages
|
|
|
|
|
|
@router.delete("/history")
async def delete_history(
    user_id: str = Query(..., description="User ID"),
    session_id: str = Query(..., description="Session ID"),
):
    """Clear the stored chat of a session.

    The request is logged, ``redis_service.clear_chat`` is called for the
    given user and session, and a fixed status payload is returned.
    """
    logger.info(f"[DELETE /history] user_id={user_id}, session_id={session_id}")

    redis_service.clear_chat(user_id, session_id)

    result = {"status": "cleared"}
    return result
|
|
|