From c259266922e133a4b488e4b60f63adfd71b02a3d Mon Sep 17 00:00:00 2001 From: "Samuele E. Locatelli" Date: Fri, 5 Sep 2025 09:11:28 +0000 Subject: [PATCH] Fix streamlit model, fix model selection display --- backend/api/v1/chat.py | 34 +-- backend/api/v1/sessions.py | 21 +- backend/config.py | 7 +- backend/models/chat.py | 1 + backend/models/session.py | 1 + backend/services/redis_service.py | 3 +- frontend/src/App.jsx | 5 +- frontend/src/AssistantMessage.jsx.orig | 193 +++++++++++++++++ frontend/src/ChatInput.jsx | 112 ++++++---- frontend/src/ChatLayout.jsx | 4 +- package-lock.json | 280 +++++++++++++++++++++++++ package.json | 1 + streamlit/app.py | 37 ++-- streamlit/app.py.new | 79 +++++++ 14 files changed, 679 insertions(+), 99 deletions(-) create mode 100644 frontend/src/AssistantMessage.jsx.orig create mode 100644 streamlit/app.py.new diff --git a/backend/api/v1/chat.py b/backend/api/v1/chat.py index 36cadd7..cbcae48 100644 --- a/backend/api/v1/chat.py +++ b/backend/api/v1/chat.py @@ -13,34 +13,41 @@ from config import settings router = APIRouter() +@router.get("/models", response_model=List[str]) +async def list_models(): + try: + async with httpx.AsyncClient(timeout=settings.REQUEST_TIMEOUT) as client: + resp = await client.get(settings.LM_STUDIO_MODELS) + resp.raise_for_status() + data = resp.json() + return [m["id"] for m in data.get("data", [])] + except Exception: + logger.exception("Error fetching models") + raise HTTPException(status_code=500, detail="Failed to fetch models") + @router.post("/chat", response_model=ChatResponse) async def chat_endpoint(payload: ChatRequest): try: - # Creazione sessione se non esiste session_id = payload.session_id if not session_id: meta = redis_service.create_session(payload.user_id, payload.message) session_id = meta["session_id"] - # Salva messaggio utente redis_service.save_chat(payload.user_id, session_id, {"role": "user", "content": payload.message}) - - # Prepara la history ottimizzata history_to_send = await prepare_history(payload.user_id, session_id) - # Chiamata a LM Studio + model_to_use = payload.model_name or settings.MODEL_NAME + async with httpx.AsyncClient(timeout=settings.REQUEST_TIMEOUT) as client: resp = await client.post( settings.LM_STUDIO_URL, - json={"model": settings.MODEL_NAME, "messages": history_to_send}, + json={"model": model_to_use, "messages": history_to_send}, ) resp.raise_for_status() data = resp.json() reply = data["choices"][0]["message"]["content"] - - # Salva risposta assistant redis_service.save_chat(payload.user_id, session_id, {"role": "assistant", "content": reply}) return ChatResponse(response=reply, session_id=session_id) @@ -52,20 +59,14 @@ async def chat_endpoint(payload: ChatRequest): @router.post("/chat-stream") async def chat_stream_endpoint(payload: ChatRequest): - """ - Streams model output token-by-token usando SSE, - con windowing + summarization + condensed history. - """ session_id = payload.session_id if not session_id: meta = redis_service.create_session(payload.user_id, payload.message) session_id = meta["session_id"] - # Salva messaggio utente redis_service.save_chat(payload.user_id, session_id, {"role": "user", "content": payload.message}) - - # Prepara la history ottimizzata history_to_send = await prepare_history(payload.user_id, session_id) + model_to_use = payload.model_name or settings.MODEL_NAME async def event_generator(): assistant_text = "" @@ -75,7 +76,7 @@ async def chat_stream_endpoint(payload: ChatRequest): "POST", settings.LM_STUDIO_URL, json={ - "model": settings.MODEL_NAME, + "model": model_to_use, "messages": history_to_send, "stream": True } @@ -139,4 +140,3 @@ async def delete_history( redis_service.clear_chat(user_id, session_id) return {"status": "cleared"} - diff --git a/backend/api/v1/sessions.py b/backend/api/v1/sessions.py index c9a1820..e6d02e8 100644 --- a/backend/api/v1/sessions.py +++ b/backend/api/v1/sessions.py @@ -1,6 +1,6 @@ # api/v1/sessions.py -from typing import List +from typing import List, Optional from fastapi import Body, Query, Path, WebSocket, WebSocketDisconnect from services import redis_service from fastapi import APIRouter @@ -17,11 +17,9 @@ router = APIRouter() async def sessions_ws(websocket: WebSocket, user_id: str = Query(...)): await websocket.accept() try: - # Invia subito la lista completa sessions = redis_service.get_sessions(user_id) await websocket.send_json({"type": "full_list", "sessions": sessions}) - # Sottoscrizione al canale Redis pubsub = redis_service.r.pubsub() channel = f"sessions:{user_id}" await pubsub.subscribe(channel) @@ -59,10 +57,10 @@ async def get_session_meta_endpoint( @router.post("/sessions", response_model=dict) async def create_session_endpoint( user_id: str = Query(..., description="User ID"), - first_message: str = Body("", embed=True) + first_message: str = Body("", embed=True), + model_name: Optional[str] = Body(None, embed=True) # <-- Accept model_name ): - meta = redis_service.create_session(user_id, first_message) - # Notifica WS + meta = redis_service.create_session(user_id, first_message, model_name=model_name) redis_service.r.publish( f"sessions:{user_id}", json.dumps({"type": "created", "session": meta}) @@ -73,9 +71,15 @@ async def create_session_endpoint( async def update_session_endpoint( user_id: str = Query(..., description="User ID"), session_id: str = Path(..., description="Session ID"), - session_name: str = Body(..., embed=True) + session_name: Optional[str] = Body(None, embed=True), + model_name: Optional[str] = Body(None, embed=True) # <-- Allow model update ): - updated = redis_service.update_session_meta(user_id, session_id, session_name=session_name) or {} + updated = redis_service.update_session_meta( + user_id, session_id, + session_name=session_name, + model_name=model_name + ) or {} + if updated: redis_service.r.publish( f"sessions:{user_id}", @@ -95,4 +99,3 @@ async def delete_session_endpoint( ) return {"status": "deleted"} - diff --git a/backend/config.py b/backend/config.py index af7d3f5..25f4403 100644 --- a/backend/config.py +++ b/backend/config.py @@ -1,3 +1,4 @@ +# config.py import os from pydantic_settings import BaseSettings @@ -6,9 +7,11 @@ class Settings(BaseSettings): REDIS_PORT: int = int(os.getenv("REDIS_PORT", 6379)) REDIS_DB: int = int(os.getenv("REDIS_DB", 0)) LM_STUDIO_URL: str = os.getenv("LM_STUDIO_URL", "http://10.74.83.100:1234/v1/chat/completions") - #MODEL_NAME: str = os.getenv("MODEL_NAME", "qwen/qwen3-4b-thinking-2507") + LM_STUDIO_MODELS: str = os.getenv("LM_STUDIO_URL", "http://10.74.83.100:1234/v1/models") MODEL_NAME: str = os.getenv("MODEL_NAME", "qwen/qwen3-4b-2507") - REQUEST_TIMEOUT: float = float(os.getenv("REQUEST_TIMEOUT", 30.0)) + #MODEL_NAME: str = os.getenv("MODEL_NAME", "qwen/qwen3-4b-thinking-2507") + #MODEL_NAME: str = os.getenv("MODEL_NAME", "openai/gpt-oss-20b") + REQUEST_TIMEOUT: float = float(os.getenv("REQUEST_TIMEOUT", 60.0)) settings = Settings() diff --git a/backend/models/chat.py b/backend/models/chat.py index 4125add..bd3b16f 100644 --- a/backend/models/chat.py +++ b/backend/models/chat.py @@ -7,6 +7,7 @@ class ChatRequest(BaseModel): user_id: str # identifier for the user (can be same as session if desired) session_id: Optional[str] = None # new: multi-session handling message: str # user input text + model_name: Optional[str] = None # <-- Add this class ChatResponse(BaseModel): response: str # assistant's reply diff --git a/backend/models/session.py b/backend/models/session.py index b30fbcb..6924011 100644 --- a/backend/models/session.py +++ b/backend/models/session.py @@ -9,4 +9,5 @@ class SessionMeta(BaseModel): session_name: str message_count: int = 0 history_size_bytes: int = 0 + model_name: Optional[str] = None # <-- Added field diff --git a/backend/services/redis_service.py b/backend/services/redis_service.py index 973ccb5..b4aa51a 100644 --- a/backend/services/redis_service.py +++ b/backend/services/redis_service.py @@ -76,7 +76,8 @@ def create_session(user_id: str, first_message: str) -> dict: "created_at": created_at, "session_name": session_name, "message_count": 0, - "history_size_bytes": 0 + "history_size_bytes": 0, + "model_name": model_name # <-- Add this line } meta_key = f"chatSession:{user_id}:{session_id}" index_key = f"chatSessionsIndex:{user_id}" diff --git a/frontend/src/App.jsx b/frontend/src/App.jsx index 6f13d8b..7b14961 100644 --- a/frontend/src/App.jsx +++ b/frontend/src/App.jsx @@ -10,7 +10,7 @@ import "katex/dist/katex.min.css"; // <-- IMPORTANTE export default function App() { const { messages, loading, sendMessage, stopGenerating, setMessages } = useChatStream(); const [themeName, setThemeName] = useState("light"); - const [sessionName, setSessionName] = useState(""); + const [sessionName, setSessionName, sessionModelName ] = useState(""); const theme = themes[themeName]; const userId = getUserId(); const sessionId = getSessionId(); @@ -103,7 +103,8 @@ export default function App() { onSelectSession={handleSelectSession} userId={userId} sessionId={sessionId} - sessionName={sessionName} // <-- aggiunto + sessionModelName={sessionModelName} + sessionName={sessionName} /> ); } diff --git a/frontend/src/AssistantMessage.jsx.orig b/frontend/src/AssistantMessage.jsx.orig new file mode 100644 index 0000000..4ea4529 --- /dev/null +++ b/frontend/src/AssistantMessage.jsx.orig @@ -0,0 +1,193 @@ +// src/AssistantMessage.jsx +import React, { useState, useEffect } from "react"; +import ReactMarkdown from "react-markdown"; +import remarkGfm from "remark-gfm"; +import remarkMath from "remark-math"; +import rehypeKatex from "rehype-katex"; + +// Prism.js per syntax highlight +import Prism from "prismjs"; +import "prismjs/themes/prism.css"; + +// Linguaggi base +import "prismjs/components/prism-sql"; +import "prismjs/components/prism-javascript"; +import "prismjs/components/prism-css"; +import "prismjs/components/prism-json"; +import "prismjs/components/prism-markdown"; +import "prismjs/components/prism-csharp"; +import "prismjs/components/prism-lua"; +import "prismjs/components/prism-c"; +import "prismjs/components/prism-cpp"; +import "prismjs/components/prism-python"; +import "prismjs/components/prism-basic"; +import "prismjs/components/prism-javascript"; + +// Aggiungo alias "vb" che punta a "vbnet" +Prism.languages.vb = Prism.languages.vbnet; + +export default function AssistantMessage({ + content, + theme, + timestamp, + startedAt, + endedAt, + isFinal +}) { + const [showThink, setShowThink] = useState(false); + const [fadeOut, setFadeOut] = useState(false); + + const ts = timestamp ?? endedAt ?? startedAt; + + const thinkMatch = content?.match(/([\s\S]*?)<\/think>/i); + const thinkContent = thinkMatch ? thinkMatch[1].trim() : null; + + const isComplete = isFinal || Boolean(timestamp || endedAt); + + const visibleContent = isComplete + ? content?.replace(/[\s\S]*?<\/think>/i, "").trim() + : content; + + useEffect(() => { + if (thinkContent && !isComplete) { + setShowThink(true); + setFadeOut(false); + } + if (thinkContent && isComplete) { + setFadeOut(true); + const timer = setTimeout(() => setShowThink(false), 600); + return () => clearTimeout(timer); + } + }, [thinkContent, isComplete]); + + return ( +
+
+ {showThink && ( +
+ 🤔 {thinkContent} +
+ )} + + ( + + ), + th: (props) =>
, + code: CodeWithCopy + }} + > + {visibleContent} + + + {ts != null && ( +
+ {formatDateTime(ts)} +
+ )} + + ); +} + +function CodeWithCopy({ inline, className = "", children, ...props }) { + const [copied, setCopied] = useState(false); + const codeText = String(children).replace(/\n$/, ""); + const isFencedBlock = !inline && /^language-/.test(className); + + // Evidenziazione con Prism + useEffect(() => { + if (isFencedBlock) { + Prism.highlightAll(); + } + }, [codeText, isFencedBlock]); + + if (!isFencedBlock) { + return ( + + {children} + + ); + } + + const handleCopy = async () => { + try { + await navigator.clipboard.writeText(codeText); + setCopied(true); + setTimeout(() => setCopied(false), 1500); + } catch (err) { + console.error("Copy failed", err); + } + }; + + return ( +
+
+        {codeText}
+      
+ + {copied && ( + + Copied! + + )} +
+ ); +} + +function formatDateTime(dateTime) { + const date = dateTime instanceof Date ? dateTime : new Date(dateTime); + if (Number.isNaN(date.getTime())) return String(dateTime); + return date.toLocaleString(undefined, { + year: "numeric", + month: "short", + day: "numeric", + hour: "2-digit", + minute: "2-digit", + second: "2-digit" + }); +} + + diff --git a/frontend/src/ChatInput.jsx b/frontend/src/ChatInput.jsx index e461357..cc5b19d 100644 --- a/frontend/src/ChatInput.jsx +++ b/frontend/src/ChatInput.jsx @@ -1,13 +1,16 @@ // src/ChatInput.jsx import React, { useState, useEffect, useRef } from "react"; +import axios from "axios"; -const MAX_EXECUTION_TIME_MS_DEFAULT = 2 * 60 * 1000; // 2 minuti +const MAX_EXECUTION_TIME_MS_DEFAULT = 2 * 60 * 1000; const EXTRA_TIME_MS = 60 * 1000; const CHAR_LIMIT_FOR_TEXTAREA = 80; const LINE_LIMIT_FOR_TEXTAREA = 1; -export default function ChatInput({ onSend, onStop, loading }) { +export default function ChatInput({ onSend, onStop, loading, sessionModelName = "" }) { const [inputValue, setInputValue] = useState(""); + const [modelList, setModelList] = useState([]); + const [selectedModel, setSelectedModel] = useState(sessionModelName || ""); const [timeLeft, setTimeLeft] = useState(null); const [elapsedTime, setElapsedTime] = useState(0); const [maxExecutionTime, setMaxExecutionTime] = useState(MAX_EXECUTION_TIME_MS_DEFAULT); @@ -21,7 +24,21 @@ export default function ChatInput({ onSend, onStop, loading }) { inputValue.length > CHAR_LIMIT_FOR_TEXTAREA || inputValue.split("\n").length > LINE_LIMIT_FOR_TEXTAREA; - // Focus automatico + // Fetch models and set default from session + useEffect(() => { + axios.get("/v1/models") + .then(res => { + const sortedModels = res.data.sort((a, b) => a.localeCompare(b)); + setModelList(sortedModels); + if (sessionModelName && sortedModels.includes(sessionModelName)) { + setSelectedModel(sessionModelName); + } + }) + .catch(err => { + console.error("❌ Failed to fetch models:", err.message); + }); + }, [sessionModelName]); + useEffect(() => { if (!loading && inputRef.current) { inputRef.current.focus(); @@ -30,7 +47,6 @@ export default function ChatInput({ onSend, onStop, loading }) { } }, [isTextarea, loading]); - // Auto‑resize useEffect(() => { if (isTextarea && inputRef.current) { inputRef.current.style.height = "auto"; @@ -38,7 +54,6 @@ export default function ChatInput({ onSend, onStop, loading }) { } }, [inputValue, isTextarea]); - // Gestione timer e countdown useEffect(() => { if (loading) { setTimeLeft(Math.floor(maxExecutionTime / 1000)); @@ -77,7 +92,7 @@ export default function ChatInput({ onSend, onStop, loading }) { const handleSend = () => { if (inputValue.trim()) { - onSend(inputValue); + onSend(inputValue, selectedModel); setInputValue(""); } }; @@ -113,36 +128,55 @@ export default function ChatInput({ onSend, onStop, loading }) { }; return ( -
- {loading ? ( -
- 💬 Il modello sta processando… (tempo trascorso: {elapsedTime}s) -
- ) : isTextarea ? ( -