refactor: use dataclasses for config

ViperEkura 2026-03-29 00:45:51 +08:00
parent ae73559fd2
commit 2a6c82b3ba
8 changed files with 128 additions and 85 deletions

View File

@@ -1,49 +1,105 @@
-"""Configuration management"""
+"""Configuration management using dataclasses"""
 import sys
+from dataclasses import dataclass, field
+from typing import List, Dict, Optional
 from backend import load_config
-
-_cfg = load_config()
-
-# Model list (for /api/models endpoint)
-MODELS = _cfg.get("models", [])
-
-# Validate each model has required fields at startup
-_REQUIRED_MODEL_KEYS = {"id", "name", "api_url", "api_key"}
-_model_ids_seen = set()
-for _i, _m in enumerate(MODELS):
-    _missing = _REQUIRED_MODEL_KEYS - set(_m.keys())
-    if _missing:
-        print(f"[config] ERROR: models[{_i}] missing required fields: {_missing}", file=sys.stderr)
-        sys.exit(1)
-    if _m["id"] in _model_ids_seen:
-        print(f"[config] ERROR: duplicate model id '{_m['id']}'", file=sys.stderr)
-        sys.exit(1)
-    _model_ids_seen.add(_m["id"])
-
-# Per-model config lookup: {model_id: {api_url, api_key}}
-MODEL_CONFIG = {m["id"]: {"api_url": m["api_url"], "api_key": m["api_key"]} for m in MODELS}
-
-# default_model must exist in models
-DEFAULT_MODEL = _cfg.get("default_model", "")
-if DEFAULT_MODEL and DEFAULT_MODEL not in MODEL_CONFIG:
-    print(f"[config] ERROR: default_model '{DEFAULT_MODEL}' not found in models", file=sys.stderr)
-    sys.exit(1)
-if MODELS and not DEFAULT_MODEL:
-    DEFAULT_MODEL = MODELS[0]["id"]
-
-# Max agentic loop iterations (tool call rounds)
-MAX_ITERATIONS = _cfg.get("max_iterations", 5)
-
-# Max parallel workers for tool execution (ThreadPoolExecutor)
-TOOL_MAX_WORKERS = _cfg.get("tool_max_workers", 4)
-
-# Sub-agent settings (multi_agent tool)
-_sa = _cfg.get("sub_agent", {})
-SUB_AGENT_MAX_ITERATIONS = _sa.get("max_iterations", 3)
-SUB_AGENT_MAX_CONCURRENCY = _sa.get("max_concurrency", 3)
-SUB_AGENT_TIMEOUT = _sa.get("timeout", 60)
-
-# Code execution settings
-_ce = _cfg.get("code_execution", {})
-CODE_EXECUTION_DEFAULT_STRICTNESS = _ce.get("default_strictness", "standard")
-CODE_EXECUTION_EXTRA_MODULES = _ce.get("extra_allowed_modules", {})
+
+
+@dataclass
+class ModelConfig:
+    """Individual model configuration."""
+    id: str
+    name: str
+    api_url: str
+    api_key: str
+
+
+@dataclass
+class SubAgentConfig:
+    """Sub-agent (multi_agent tool) settings."""
+    max_iterations: int = 3
+    max_concurrency: int = 3
+    timeout: int = 60
+
+
+@dataclass
+class CodeExecutionConfig:
+    """Code execution settings."""
+    default_strictness: str = "standard"
+    extra_allowed_modules: Dict = field(default_factory=dict)
+
+
+@dataclass
+class AppConfig:
+    """Main application configuration."""
+    models: List[ModelConfig] = field(default_factory=list)
+    default_model: str = ""
+    max_iterations: int = 5
+    tool_max_workers: int = 4
+    sub_agent: SubAgentConfig = field(default_factory=SubAgentConfig)
+    code_execution: CodeExecutionConfig = field(default_factory=CodeExecutionConfig)
+
+    # Per-model config lookup: {model_id: ModelConfig}
+    _model_config_map: Dict[str, ModelConfig] = field(default_factory=dict, repr=False)
+
+    def __post_init__(self):
+        """Build lookup map after initialization."""
+        self._model_config_map = {m.id: m for m in self.models}
+
+    def get_model_config(self, model_id: str) -> Optional[ModelConfig]:
+        """Get model config by ID."""
+        return self._model_config_map.get(model_id)
+
+    def get_model_credentials(self, model_id: str) -> tuple:
+        """Get (api_url, api_key) for a model."""
+        cfg = self.get_model_config(model_id)
+        if not cfg:
+            raise ValueError(f"Unknown model: '{model_id}', not found in config")
+        if not cfg.api_url:
+            raise ValueError(f"Model '{model_id}' has no api_url configured")
+        if not cfg.api_key:
+            raise ValueError(f"Model '{model_id}' has no api_key configured")
+        return cfg.api_url, cfg.api_key
+
+
+def _parse_config(raw: dict) -> AppConfig:
+    """Parse raw YAML config into AppConfig dataclass."""
+    # Parse models
+    models = []
+    for m in raw.get("models", []):
+        models.append(ModelConfig(
+            id=m["id"],
+            name=m["name"],
+            api_url=m["api_url"],
+            api_key=m["api_key"],
+        ))
+
+    # Parse sub_agent
+    sa_raw = raw.get("sub_agent", {})
+    sub_agent = SubAgentConfig(
+        max_iterations=sa_raw.get("max_iterations", 3),
+        max_concurrency=sa_raw.get("max_concurrency", 3),
+        timeout=sa_raw.get("timeout", 60),
+    )
+
+    # Parse code_execution
+    ce_raw = raw.get("code_execution", {})
+    code_execution = CodeExecutionConfig(
+        default_strictness=ce_raw.get("default_strictness", "standard"),
+        extra_allowed_modules=ce_raw.get("extra_allowed_modules", {}),
+    )
+
+    return AppConfig(
+        models=models,
+        default_model=raw.get("default_model", ""),
+        max_iterations=raw.get("max_iterations", 5),
+        tool_max_workers=raw.get("tool_max_workers", 4),
+        sub_agent=sub_agent,
+        code_execution=code_execution,
+    )
+
+
+# Load and validate configuration at startup
+_raw_cfg = load_config()
+config = _parse_config(_raw_cfg)
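With this change, call sites import the single config instance instead of a set of module-level constants. A minimal usage sketch of the new surface (the model id "demo-model" is hypothetical and assumes a matching entry in config.yml):

from backend.config import config

# Look up the parsed ModelConfig for an id; returns None when the id is unknown.
model_cfg = config.get_model_config("demo-model")

# Fetch credentials in one call; raises ValueError for an unknown id or a missing api_url/api_key.
api_url, api_key = config.get_model_credentials("demo-model")

# Nested sections are dataclasses whose defaults mirror the old constants.
print(config.max_iterations)                      # 5 unless set in config.yml
print(config.sub_agent.max_concurrency)           # 3 unless set in config.yml
print(config.code_execution.default_strictness)   # "standard" unless set in config.yml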

View File

@@ -8,13 +8,13 @@ from backend.routes.stats import bp as stats_bp
 from backend.routes.projects import bp as projects_bp
 from backend.routes.auth import bp as auth_bp, init_auth
 from backend.services.llm_client import LLMClient
-from backend.config import MODEL_CONFIG
+from backend.config import config


 def register_routes(app: Flask):
     """Register all route blueprints"""
-    # Initialize LLM client with per-model config
-    client = LLMClient(MODEL_CONFIG)
+    # Initialize LLM client with config
+    client = LLMClient(config)
     init_chat_service(client)

     # Register LLM client in service locator so tools (e.g. multi_agent) can access it

View File

@@ -5,7 +5,7 @@ from flask import Blueprint, request, g
 from backend import db
 from backend.models import Conversation, Project
 from backend.utils.helpers import ok, err, to_dict
-from backend.config import DEFAULT_MODEL
+from backend.config import config

 bp = Blueprint("conversations", __name__)

@@ -40,7 +40,7 @@ def conversation_list():
         user_id=user.id,
         project_id=project_id or None,
         title=d.get("title", ""),
-        model=d.get("model", DEFAULT_MODEL),
+        model=d.get("model", config.default_model),
         system_prompt=d.get("system_prompt", ""),
         temperature=d.get("temperature", 1.0),
         max_tokens=d.get("max_tokens", 65536),

View File

@ -1,7 +1,7 @@
"""Model list API routes""" """Model list API routes"""
from flask import Blueprint from flask import Blueprint
from backend.utils.helpers import ok from backend.utils.helpers import ok
from backend.config import MODELS from backend.config import config
bp = Blueprint("models", __name__) bp = Blueprint("models", __name__)
@ -13,7 +13,10 @@ _SENSITIVE_KEYS = {"api_key", "api_url"}
def list_models(): def list_models():
"""Get available model list (without sensitive fields like api_key)""" """Get available model list (without sensitive fields like api_key)"""
safe_models = [ safe_models = [
{k: v for k, v in m.items() if k not in _SENSITIVE_KEYS} {
for m in MODELS "id": m.id,
"name": m.name,
}
for m in config.models
] ]
return ok(safe_models) return ok(safe_models)
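The endpoint now whitelists the two public fields explicitly instead of filtering raw dicts against _SENSITIVE_KEYS. A sketch of the sanitized payload this builds (ids and names are hypothetical):

# What ok(safe_models) would wrap for two hypothetical config.yml entries:
safe_models = [
    {"id": "gpt-demo", "name": "GPT Demo"},
    {"id": "claude-demo", "name": "Claude Demo"},
]
# api_url and api_key are never serialized, so they cannot leak through the models endpoint.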

View File

@@ -14,7 +14,7 @@ from backend.utils.helpers import (
     build_messages,
 )
 from backend.services.llm_client import LLMClient
-from backend.config import MAX_ITERATIONS, TOOL_MAX_WORKERS
+from backend.config import config as _cfg

 logger = logging.getLogger(__name__)

@@ -89,7 +89,8 @@ class ChatService:
         total_completion_tokens = 0
         total_prompt_tokens = 0

-        for iteration in range(MAX_ITERATIONS):
+        for iteration in range(_cfg.max_iterations):
+
             # Helper to parse stream_result event
             def parse_stream_result(event_str):
                 """Parse stream_result SSE event and extract data dict."""

@@ -385,7 +386,7 @@ class ChatService:
         if len(tool_calls_list) > 1:
             with app.app_context():
                 return executor.process_tool_calls_parallel(
-                    tool_calls_list, context, max_workers=TOOL_MAX_WORKERS
+                    tool_calls_list, context, max_workers=_cfg.tool_max_workers
                 )
         else:
             with app.app_context():

View File

@@ -35,28 +35,22 @@ def _detect_provider(api_url: str) -> str:

 class LLMClient:
     """OpenAI-compatible LLM API client.

-    Each model must have its own api_url and api_key configured in MODEL_CONFIG.
+    Each model must have its own api_url and api_key configured in config.models.
     """

-    def __init__(self, model_config: dict):
-        """Initialize with per-model config lookup.
+    def __init__(self, cfg):
+        """Initialize with AppConfig.

         Args:
-            model_config: {model_id: {"api_url": ..., "api_key": ...}}
+            cfg: AppConfig dataclass instance
         """
-        self.model_config = model_config
+        self.cfg = cfg

     def _get_credentials(self, model: str):
         """Get api_url and api_key for a model, with env-var expansion."""
-        cfg = self.model_config.get(model)
-        if not cfg:
-            raise ValueError(f"Unknown model: '{model}', not found in config")
-        api_url = _resolve_env_vars(cfg.get("api_url", ""))
-        api_key = _resolve_env_vars(cfg.get("api_key", ""))
-        if not api_url:
-            raise ValueError(f"Model '{model}' has no api_url configured")
-        if not api_key:
-            raise ValueError(f"Model '{model}' has no api_key configured")
+        api_url, api_key = self.cfg.get_model_credentials(model)
+        api_url = _resolve_env_vars(api_url)
+        api_key = _resolve_env_vars(api_key)
         return api_url, api_key

     def _build_body(self, model, messages, max_tokens, temperature, thinking_enabled,
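With the credential checks moved into AppConfig.get_model_credentials, the client only layers environment-variable expansion on top. A minimal sketch of the resulting call path; the ${VAR} placeholder syntax for _resolve_env_vars is an assumption, since its implementation is not part of this diff, and "demo-model" is a hypothetical id:

from backend.config import config
from backend.services.llm_client import LLMClient

client = LLMClient(config)
# Raises ValueError (from AppConfig) if "demo-model" is unknown or lacks api_url/api_key;
# otherwise returns the URL and key with env-var placeholders resolved (assumed ${VAR} syntax).
api_url, api_key = client._get_credentials("demo-model")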

View File

@@ -1,22 +1,12 @@
-"""Multi-agent tool for spawning concurrent sub-agents.
-
-Provides:
-- multi_agent: Spawn sub-agents with independent LLM conversation loops
-"""
 import json
 import logging
 from concurrent.futures import ThreadPoolExecutor, as_completed
-from typing import List, Dict, Any, Optional
+from typing import List, Any, Optional

 from backend.tools import get_service
 from backend.tools.factory import tool
 from backend.tools.core import registry
 from backend.tools.executor import ToolExecutor
-from backend.config import (
-    DEFAULT_MODEL,
-    SUB_AGENT_MAX_ITERATIONS,
-    SUB_AGENT_MAX_CONCURRENCY,
-    SUB_AGENT_TIMEOUT,
-)
+from backend.config import config

 logger = logging.getLogger(__name__)

@@ -118,7 +108,7 @@ def _run_sub_agent(
         stream=False,
         max_tokens=max_tokens,
         temperature=temperature,
-        timeout=SUB_AGENT_TIMEOUT,
+        timeout=config.sub_agent.timeout,
     )

     if resp.status_code != 200:

@@ -253,13 +243,13 @@ def multi_agent(arguments: dict) -> dict:
     app = current_app._get_current_object()

     # Use injected model/project_id from executor context, fall back to defaults
-    model = arguments.get("_model") or DEFAULT_MODEL
+    model = arguments.get("_model") or config.default_model
     project_id = arguments.get("_project_id")
     max_tokens = arguments.get("_max_tokens", 65536)
     temperature = arguments.get("_temperature", 0.7)

     # Execute agents concurrently
-    concurrency = min(len(tasks), SUB_AGENT_MAX_CONCURRENCY)
+    concurrency = min(len(tasks), config.sub_agent.max_concurrency)
     results = [None] * len(tasks)

     with ThreadPoolExecutor(max_workers=concurrency) as pool:

@@ -274,7 +264,7 @@ def multi_agent(arguments: dict) -> dict:
                 temperature,
                 project_id,
                 app,
-                SUB_AGENT_MAX_ITERATIONS,
+                config.sub_agent.max_iterations,
             ): i
             for i, task in enumerate(tasks)
         }

View File

@@ -7,8 +7,7 @@ import textwrap
 from typing import Dict, List, Set

 from backend.tools.factory import tool
-from backend.config import CODE_EXECUTION_DEFAULT_STRICTNESS as DEFAULT_STRICTNESS
-from backend.config import CODE_EXECUTION_EXTRA_MODULES as _CFG_EXTRA_MODULES
+from backend.config import config


 # Strictness profiles configuration

@@ -101,7 +100,7 @@ def register_extra_modules(strictness: str, modules: Set[str] | List[str]) -> None:

 # Apply extra modules from config.yml on module load
-for _level, _mods in _CFG_EXTRA_MODULES.items():
+for _level, _mods in config.code_execution.extra_allowed_modules.items():
     if isinstance(_mods, list) and _mods:
         register_extra_modules(_level, _mods)

@@ -144,7 +143,7 @@ def execute_python(arguments: dict) -> dict:
     5. Subprocess isolation
     """
     code = arguments["code"]
-    strictness = arguments.get("strictness", DEFAULT_STRICTNESS)
+    strictness = arguments.get("strictness", config.code_execution.default_strictness)

     # Validate strictness level
     if strictness not in STRICTNESS_PROFILES:
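For context, extra_allowed_modules maps a strictness level to a list of additional importable modules, and the module-load loop above registers each non-empty list. A sketch of the value CodeExecutionConfig would hold after parsing config.yml (level and module names are illustrative only):

# Hypothetical shape of config.code_execution.extra_allowed_modules after YAML load:
extra_allowed_modules = {
    "standard": ["numpy", "pandas"],   # extra imports permitted at the "standard" level
    "strict": [],                      # empty lists are skipped by the isinstance/_mods check
}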