refactor: 使用dataclass 设定config

This commit is contained in:
ViperEkura 2026-03-29 00:45:51 +08:00
parent ae73559fd2
commit 2a6c82b3ba
8 changed files with 128 additions and 85 deletions

View File

@ -1,49 +1,105 @@
"""Configuration management"""
"""Configuration management using dataclasses"""
import sys
from dataclasses import dataclass, field
from typing import List, Dict, Optional
from backend import load_config
_cfg = load_config()
# Model list (for /api/models endpoint)
MODELS = _cfg.get("models", [])
@dataclass
class ModelConfig:
    """Configuration for a single LLM model entry.

    Instances are built by ``_parse_config`` from the ``models`` list of the
    raw config and looked up by id via ``AppConfig.get_model_config``.
    """
    id: str       # unique model identifier used for lookups
    name: str     # human-readable display name (exposed by the models route)
    api_url: str  # endpoint URL; env-var placeholders are expanded by the LLM client
    api_key: str  # credential for the endpoint; env-var placeholders expanded likewise
# Validate each model has required fields at startup
_REQUIRED_MODEL_KEYS = {"id", "name", "api_url", "api_key"}
_model_ids_seen = set()
for _i, _m in enumerate(MODELS):
_missing = _REQUIRED_MODEL_KEYS - set(_m.keys())
if _missing:
print(f"[config] ERROR: models[{_i}] missing required fields: {_missing}", file=sys.stderr)
sys.exit(1)
if _m["id"] in _model_ids_seen:
print(f"[config] ERROR: duplicate model id '{_m['id']}'", file=sys.stderr)
sys.exit(1)
_model_ids_seen.add(_m["id"])
# Per-model config lookup: {model_id: {api_url, api_key}}
MODEL_CONFIG = {m["id"]: {"api_url": m["api_url"], "api_key": m["api_key"]} for m in MODELS}
@dataclass
class SubAgentConfig:
    """Sub-agent (multi_agent tool) settings."""
    max_iterations: int = 3   # max conversation-loop rounds per sub-agent
    max_concurrency: int = 3  # upper bound on concurrently running sub-agents
    timeout: int = 60         # per-request timeout for sub-agent LLM calls (presumably seconds — confirm)
# default_model must exist in models
DEFAULT_MODEL = _cfg.get("default_model", "")
if DEFAULT_MODEL and DEFAULT_MODEL not in MODEL_CONFIG:
print(f"[config] ERROR: default_model '{DEFAULT_MODEL}' not found in models", file=sys.stderr)
sys.exit(1)
if MODELS and not DEFAULT_MODEL:
DEFAULT_MODEL = MODELS[0]["id"]
# Max agentic loop iterations (tool call rounds)
MAX_ITERATIONS = _cfg.get("max_iterations", 5)
@dataclass
class CodeExecutionConfig:
    """Code execution (execute_python tool) settings."""
    default_strictness: str = "standard"  # strictness profile used when the caller does not specify one
    # Extra modules to allow per strictness level: {strictness_level: [module_name, ...]}
    extra_allowed_modules: Dict[str, List[str]] = field(default_factory=dict)
# Max parallel workers for tool execution (ThreadPoolExecutor)
TOOL_MAX_WORKERS = _cfg.get("tool_max_workers", 4)
# Sub-agent settings (multi_agent tool)
_sa = _cfg.get("sub_agent", {})
SUB_AGENT_MAX_ITERATIONS = _sa.get("max_iterations", 3)
SUB_AGENT_MAX_CONCURRENCY = _sa.get("max_concurrency", 3)
SUB_AGENT_TIMEOUT = _sa.get("timeout", 60)
@dataclass
class AppConfig:
    """Main application configuration.

    Aggregates all settings parsed from the raw config dict by
    ``_parse_config``; a module-level ``config`` instance is created at
    import time and imported throughout the backend.
    """
    models: List[ModelConfig] = field(default_factory=list)  # available models
    default_model: str = ""       # model id used when callers do not specify one
    max_iterations: int = 5       # max agentic loop iterations (tool call rounds)
    tool_max_workers: int = 4     # max parallel workers for tool execution
    sub_agent: SubAgentConfig = field(default_factory=SubAgentConfig)
    code_execution: CodeExecutionConfig = field(default_factory=CodeExecutionConfig)

    # Per-model config lookup {model_id: ModelConfig}; derived from `models`
    # in __post_init__, so it is excluded from repr.
    _model_config_map: Dict[str, ModelConfig] = field(default_factory=dict, repr=False)

    def __post_init__(self):
        """Build the id -> ModelConfig lookup map after initialization."""
        self._model_config_map = {m.id: m for m in self.models}

    def get_model_config(self, model_id: str) -> Optional[ModelConfig]:
        """Return the ModelConfig for *model_id*, or None if unknown."""
        return self._model_config_map.get(model_id)

    def get_model_credentials(self, model_id: str) -> tuple:
        """Return ``(api_url, api_key)`` for *model_id*.

        Raises:
            ValueError: if the model is unknown, or has an empty
                api_url or api_key.
        """
        cfg = self.get_model_config(model_id)
        if not cfg:
            raise ValueError(f"Unknown model: '{model_id}', not found in config")
        if not cfg.api_url:
            raise ValueError(f"Model '{model_id}' has no api_url configured")
        if not cfg.api_key:
            raise ValueError(f"Model '{model_id}' has no api_key configured")
        return cfg.api_url, cfg.api_key
def _parse_config(raw: dict) -> AppConfig:
    """Parse a raw config dict into an AppConfig dataclass.

    Restores the startup validation the module performed before the
    dataclass refactor: every model entry must carry id/name/api_url/api_key,
    model ids must be unique (duplicates would otherwise silently collapse
    in the lookup map), and ``default_model`` (when set) must reference a
    known model.  When ``default_model`` is unset it falls back to the first
    model's id.

    Args:
        raw: Dict loaded from the config file (e.g. by ``load_config``).

    Returns:
        Fully populated AppConfig.

    Raises:
        ValueError: on a model entry missing required fields, a duplicate
            model id, or an unknown ``default_model``.
    """
    required_keys = {"id", "name", "api_url", "api_key"}

    # Parse and validate models
    models: List[ModelConfig] = []
    seen_ids = set()
    for i, m in enumerate(raw.get("models", [])):
        missing = required_keys - set(m.keys())
        if missing:
            raise ValueError(f"models[{i}] missing required fields: {missing}")
        if m["id"] in seen_ids:
            raise ValueError(f"duplicate model id '{m['id']}'")
        seen_ids.add(m["id"])
        models.append(ModelConfig(
            id=m["id"],
            name=m["name"],
            api_url=m["api_url"],
            api_key=m["api_key"],
        ))

    # default_model must exist in models; fall back to the first model if unset
    default_model = raw.get("default_model", "")
    if default_model and default_model not in seen_ids:
        raise ValueError(f"default_model '{default_model}' not found in models")
    if models and not default_model:
        default_model = models[0].id

    # Parse sub_agent
    sa_raw = raw.get("sub_agent", {})
    sub_agent = SubAgentConfig(
        max_iterations=sa_raw.get("max_iterations", 3),
        max_concurrency=sa_raw.get("max_concurrency", 3),
        timeout=sa_raw.get("timeout", 60),
    )

    # Parse code_execution
    ce_raw = raw.get("code_execution", {})
    code_execution = CodeExecutionConfig(
        default_strictness=ce_raw.get("default_strictness", "standard"),
        extra_allowed_modules=ce_raw.get("extra_allowed_modules", {}),
    )

    return AppConfig(
        models=models,
        default_model=default_model,
        max_iterations=raw.get("max_iterations", 5),
        tool_max_workers=raw.get("tool_max_workers", 4),
        sub_agent=sub_agent,
        code_execution=code_execution,
    )
# Load and validate configuration at startup; _parse_config raises on a
# malformed config, so a bad config file fails fast at import time.
_raw_cfg = load_config()
# Module-level singleton used as `from backend.config import config`.
config = _parse_config(_raw_cfg)

View File

@ -8,13 +8,13 @@ from backend.routes.stats import bp as stats_bp
from backend.routes.projects import bp as projects_bp
from backend.routes.auth import bp as auth_bp, init_auth
from backend.services.llm_client import LLMClient
from backend.config import MODEL_CONFIG
from backend.config import config
def register_routes(app: Flask):
"""Register all route blueprints"""
# Initialize LLM client with per-model config
client = LLMClient(MODEL_CONFIG)
# Initialize LLM client with config
client = LLMClient(config)
init_chat_service(client)
# Register LLM client in service locator so tools (e.g. multi_agent) can access it

View File

@ -5,7 +5,7 @@ from flask import Blueprint, request, g
from backend import db
from backend.models import Conversation, Project
from backend.utils.helpers import ok, err, to_dict
from backend.config import DEFAULT_MODEL
from backend.config import config
bp = Blueprint("conversations", __name__)
@ -40,7 +40,7 @@ def conversation_list():
user_id=user.id,
project_id=project_id or None,
title=d.get("title", ""),
model=d.get("model", DEFAULT_MODEL),
model=d.get("model", config.default_model),
system_prompt=d.get("system_prompt", ""),
temperature=d.get("temperature", 1.0),
max_tokens=d.get("max_tokens", 65536),
@ -105,4 +105,4 @@ def conversation_detail(conv_id):
conv.project_id = project_id or None
db.session.commit()
return ok(_conv_to_dict(conv))
return ok(_conv_to_dict(conv))

View File

@ -1,7 +1,7 @@
"""Model list API routes"""
from flask import Blueprint
from backend.utils.helpers import ok
from backend.config import MODELS
from backend.config import config
bp = Blueprint("models", __name__)
@ -13,7 +13,10 @@ _SENSITIVE_KEYS = {"api_key", "api_url"}
def list_models():
"""Get available model list (without sensitive fields like api_key)"""
safe_models = [
{k: v for k, v in m.items() if k not in _SENSITIVE_KEYS}
for m in MODELS
{
"id": m.id,
"name": m.name,
}
for m in config.models
]
return ok(safe_models)
return ok(safe_models)

View File

@ -14,7 +14,7 @@ from backend.utils.helpers import (
build_messages,
)
from backend.services.llm_client import LLMClient
from backend.config import MAX_ITERATIONS, TOOL_MAX_WORKERS
from backend.config import config as _cfg
logger = logging.getLogger(__name__)
@ -89,7 +89,8 @@ class ChatService:
total_completion_tokens = 0
total_prompt_tokens = 0
for iteration in range(MAX_ITERATIONS):
for iteration in range(_cfg.max_iterations):
# Helper to parse stream_result event
def parse_stream_result(event_str):
"""Parse stream_result SSE event and extract data dict."""
@ -385,7 +386,7 @@ class ChatService:
if len(tool_calls_list) > 1:
with app.app_context():
return executor.process_tool_calls_parallel(
tool_calls_list, context, max_workers=TOOL_MAX_WORKERS
tool_calls_list, context, max_workers=_cfg.tool_max_workers
)
else:
with app.app_context():

View File

@ -35,28 +35,22 @@ def _detect_provider(api_url: str) -> str:
class LLMClient:
"""OpenAI-compatible LLM API client.
Each model must have its own api_url and api_key configured in MODEL_CONFIG.
Each model must have its own api_url and api_key configured in config.models.
"""
def __init__(self, model_config: dict):
"""Initialize with per-model config lookup.
def __init__(self, cfg):
"""Initialize with AppConfig.
Args:
model_config: {model_id: {"api_url": ..., "api_key": ...}}
cfg: AppConfig dataclass instance
"""
self.model_config = model_config
self.cfg = cfg
def _get_credentials(self, model: str):
"""Get api_url and api_key for a model, with env-var expansion."""
cfg = self.model_config.get(model)
if not cfg:
raise ValueError(f"Unknown model: '{model}', not found in config")
api_url = _resolve_env_vars(cfg.get("api_url", ""))
api_key = _resolve_env_vars(cfg.get("api_key", ""))
if not api_url:
raise ValueError(f"Model '{model}' has no api_url configured")
if not api_key:
raise ValueError(f"Model '{model}' has no api_key configured")
api_url, api_key = self.cfg.get_model_credentials(model)
api_url = _resolve_env_vars(api_url)
api_key = _resolve_env_vars(api_key)
return api_url, api_key
def _build_body(self, model, messages, max_tokens, temperature, thinking_enabled,

View File

@ -1,22 +1,12 @@
"""Multi-agent tool for spawning concurrent sub-agents.
Provides:
- multi_agent: Spawn sub-agents with independent LLM conversation loops
"""
import json
import logging
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import List, Dict, Any, Optional
from typing import List, Any, Optional
from backend.tools import get_service
from backend.tools.factory import tool
from backend.tools.core import registry
from backend.tools.executor import ToolExecutor
from backend.config import (
DEFAULT_MODEL,
SUB_AGENT_MAX_ITERATIONS,
SUB_AGENT_MAX_CONCURRENCY,
SUB_AGENT_TIMEOUT,
)
from backend.config import config
logger = logging.getLogger(__name__)
@ -118,7 +108,7 @@ def _run_sub_agent(
stream=False,
max_tokens=max_tokens,
temperature=temperature,
timeout=SUB_AGENT_TIMEOUT,
timeout=config.sub_agent.timeout,
)
if resp.status_code != 200:
@ -253,13 +243,13 @@ def multi_agent(arguments: dict) -> dict:
app = current_app._get_current_object()
# Use injected model/project_id from executor context, fall back to defaults
model = arguments.get("_model") or DEFAULT_MODEL
model = arguments.get("_model") or config.default_model
project_id = arguments.get("_project_id")
max_tokens = arguments.get("_max_tokens", 65536)
temperature = arguments.get("_temperature", 0.7)
# Execute agents concurrently
concurrency = min(len(tasks), SUB_AGENT_MAX_CONCURRENCY)
concurrency = min(len(tasks), config.sub_agent.max_concurrency)
results = [None] * len(tasks)
with ThreadPoolExecutor(max_workers=concurrency) as pool:
@ -274,7 +264,7 @@ def multi_agent(arguments: dict) -> dict:
temperature,
project_id,
app,
SUB_AGENT_MAX_ITERATIONS,
config.sub_agent.max_iterations,
): i
for i, task in enumerate(tasks)
}

View File

@ -7,8 +7,7 @@ import textwrap
from typing import Dict, List, Set
from backend.tools.factory import tool
from backend.config import CODE_EXECUTION_DEFAULT_STRICTNESS as DEFAULT_STRICTNESS
from backend.config import CODE_EXECUTION_EXTRA_MODULES as _CFG_EXTRA_MODULES
from backend.config import config
# Strictness profiles configuration
@ -101,7 +100,7 @@ def register_extra_modules(strictness: str, modules: Set[str] | List[str]) -> No
# Apply extra modules from config.yml on module load
for _level, _mods in _CFG_EXTRA_MODULES.items():
for _level, _mods in config.code_execution.extra_allowed_modules.items():
if isinstance(_mods, list) and _mods:
register_extra_modules(_level, _mods)
@ -144,7 +143,7 @@ def execute_python(arguments: dict) -> dict:
5. Subprocess isolation
"""
code = arguments["code"]
strictness = arguments.get("strictness", DEFAULT_STRICTNESS)
strictness = arguments.get("strictness", config.code_execution.default_strictness)
# Validate strictness level
if strictness not in STRICTNESS_PROFILES: