refactor: use dataclasses for config settings
This commit is contained in:
parent
ae73559fd2
commit
2a6c82b3ba
|
|
@ -1,49 +1,105 @@
|
|||
"""Configuration management"""
|
||||
"""Configuration management using dataclasses"""
|
||||
import sys
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Dict, Optional
|
||||
|
||||
from backend import load_config
|
||||
|
||||
_cfg = load_config()
|
||||
|
||||
# Model list (for /api/models endpoint)
|
||||
MODELS = _cfg.get("models", [])
|
||||
@dataclass
class ModelConfig:
    """Configuration for a single LLM backend model.

    Mirrors one entry of the ``models`` list in the raw YAML config.
    """

    id: str       # unique model identifier, used as the lookup key
    name: str     # human-readable display name (exposed by the models API)
    api_url: str  # endpoint URL; may contain env-var references resolved by the LLM client
    api_key: str  # credential; may contain env-var references resolved by the LLM client
|
||||
|
||||
# Validate each model has required fields at startup
|
||||
_REQUIRED_MODEL_KEYS = {"id", "name", "api_url", "api_key"}
|
||||
_model_ids_seen = set()
|
||||
for _i, _m in enumerate(MODELS):
|
||||
_missing = _REQUIRED_MODEL_KEYS - set(_m.keys())
|
||||
if _missing:
|
||||
print(f"[config] ERROR: models[{_i}] missing required fields: {_missing}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
if _m["id"] in _model_ids_seen:
|
||||
print(f"[config] ERROR: duplicate model id '{_m['id']}'", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
_model_ids_seen.add(_m["id"])
|
||||
|
||||
# Per-model config lookup: {model_id: {api_url, api_key}}
|
||||
MODEL_CONFIG = {m["id"]: {"api_url": m["api_url"], "api_key": m["api_key"]} for m in MODELS}
|
||||
@dataclass
class SubAgentConfig:
    """Settings for sub-agents spawned by the multi_agent tool."""

    max_iterations: int = 3   # tool-call rounds allowed per sub-agent
    max_concurrency: int = 3  # upper bound on concurrently running sub-agents
    timeout: int = 60         # timeout passed to each sub-agent LLM request (presumably seconds)
|
||||
|
||||
# default_model must exist in models
|
||||
DEFAULT_MODEL = _cfg.get("default_model", "")
|
||||
if DEFAULT_MODEL and DEFAULT_MODEL not in MODEL_CONFIG:
|
||||
print(f"[config] ERROR: default_model '{DEFAULT_MODEL}' not found in models", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
if MODELS and not DEFAULT_MODEL:
|
||||
DEFAULT_MODEL = MODELS[0]["id"]
|
||||
|
||||
# Max agentic loop iterations (tool call rounds)
|
||||
MAX_ITERATIONS = _cfg.get("max_iterations", 5)
|
||||
@dataclass
class CodeExecutionConfig:
    """Code execution (execute_python tool) settings."""

    # Sandbox strictness profile applied when the caller does not specify one.
    default_strictness: str = "standard"
    # Extra modules to allow per strictness level,
    # e.g. {"standard": ["numpy"]}. Consumers skip non-list values, so a
    # malformed YAML entry is tolerated rather than fatal.
    extra_allowed_modules: Dict[str, List[str]] = field(default_factory=dict)
|
||||
|
||||
# Max parallel workers for tool execution (ThreadPoolExecutor)
|
||||
TOOL_MAX_WORKERS = _cfg.get("tool_max_workers", 4)
|
||||
|
||||
# Sub-agent settings (multi_agent tool)
|
||||
_sa = _cfg.get("sub_agent", {})
|
||||
SUB_AGENT_MAX_ITERATIONS = _sa.get("max_iterations", 3)
|
||||
SUB_AGENT_MAX_CONCURRENCY = _sa.get("max_concurrency", 3)
|
||||
SUB_AGENT_TIMEOUT = _sa.get("timeout", 60)
|
||||
@dataclass
class AppConfig:
    """Main application configuration.

    Aggregates all settings parsed from the raw YAML config; build instances
    via ``_parse_config`` rather than by hand.
    """

    models: List[ModelConfig] = field(default_factory=list)
    default_model: str = ""   # model id used when a request does not name one
    max_iterations: int = 5   # max agentic loop iterations (tool call rounds)
    tool_max_workers: int = 4  # max parallel workers for tool execution (ThreadPoolExecutor)
    sub_agent: SubAgentConfig = field(default_factory=SubAgentConfig)
    code_execution: CodeExecutionConfig = field(default_factory=CodeExecutionConfig)

    # Per-model config lookup: {model_id: ModelConfig}. This is a derived
    # cache rebuilt unconditionally in __post_init__, so it is excluded from
    # __init__ and from equality comparisons.
    _model_config_map: Dict[str, ModelConfig] = field(
        default_factory=dict, repr=False, init=False, compare=False
    )

    def __post_init__(self):
        """Build the model-id lookup map after initialization."""
        self._model_config_map = {m.id: m for m in self.models}

    def get_model_config(self, model_id: str) -> Optional[ModelConfig]:
        """Return the ModelConfig for *model_id*, or None if unknown."""
        return self._model_config_map.get(model_id)

    def get_model_credentials(self, model_id: str) -> tuple:
        """Return ``(api_url, api_key)`` for a model.

        Raises:
            ValueError: if the model is unknown, or has no api_url/api_key.
        """
        cfg = self.get_model_config(model_id)
        if not cfg:
            raise ValueError(f"Unknown model: '{model_id}', not found in config")
        if not cfg.api_url:
            raise ValueError(f"Model '{model_id}' has no api_url configured")
        if not cfg.api_key:
            raise ValueError(f"Model '{model_id}' has no api_key configured")
        return cfg.api_url, cfg.api_key
|
||||
|
||||
|
||||
def _parse_config(raw: dict) -> AppConfig:
    """Parse raw YAML config into an AppConfig dataclass.

    Validates the model list (required fields present, unique ids) and the
    default_model reference, printing an error and exiting on bad config —
    the same fail-fast-at-startup behavior as the pre-dataclass version.
    When default_model is unset, falls back to the first configured model.
    """
    # Parse and validate models
    required_keys = {"id", "name", "api_url", "api_key"}
    models = []
    seen_ids = set()
    for i, m in enumerate(raw.get("models", [])):
        missing = required_keys - set(m.keys())
        if missing:
            print(f"[config] ERROR: models[{i}] missing required fields: {missing}", file=sys.stderr)
            sys.exit(1)
        if m["id"] in seen_ids:
            print(f"[config] ERROR: duplicate model id '{m['id']}'", file=sys.stderr)
            sys.exit(1)
        seen_ids.add(m["id"])
        models.append(ModelConfig(
            id=m["id"],
            name=m["name"],
            api_url=m["api_url"],
            api_key=m["api_key"],
        ))

    # default_model must exist in models; default to the first model when unset
    default_model = raw.get("default_model", "")
    if default_model and default_model not in seen_ids:
        print(f"[config] ERROR: default_model '{default_model}' not found in models", file=sys.stderr)
        sys.exit(1)
    if models and not default_model:
        default_model = models[0].id

    # Parse sub_agent settings
    sa_raw = raw.get("sub_agent", {})
    sub_agent = SubAgentConfig(
        max_iterations=sa_raw.get("max_iterations", 3),
        max_concurrency=sa_raw.get("max_concurrency", 3),
        timeout=sa_raw.get("timeout", 60),
    )

    # Parse code_execution settings
    ce_raw = raw.get("code_execution", {})
    code_execution = CodeExecutionConfig(
        default_strictness=ce_raw.get("default_strictness", "standard"),
        extra_allowed_modules=ce_raw.get("extra_allowed_modules", {}),
    )

    return AppConfig(
        models=models,
        default_model=default_model,
        max_iterations=raw.get("max_iterations", 5),
        tool_max_workers=raw.get("tool_max_workers", 4),
        sub_agent=sub_agent,
        code_execution=code_execution,
    )
|
||||
|
||||
|
||||
# Load and validate configuration once at import time; `config` is the
# module-level singleton consumed by the rest of the backend.
_raw_cfg = load_config()
config = _parse_config(_raw_cfg)
|
||||
|
|
|
|||
|
|
@ -8,13 +8,13 @@ from backend.routes.stats import bp as stats_bp
|
|||
from backend.routes.projects import bp as projects_bp
|
||||
from backend.routes.auth import bp as auth_bp, init_auth
|
||||
from backend.services.llm_client import LLMClient
|
||||
from backend.config import MODEL_CONFIG
|
||||
from backend.config import config
|
||||
|
||||
|
||||
def register_routes(app: Flask):
|
||||
"""Register all route blueprints"""
|
||||
# Initialize LLM client with per-model config
|
||||
client = LLMClient(MODEL_CONFIG)
|
||||
# Initialize LLM client with config
|
||||
client = LLMClient(config)
|
||||
init_chat_service(client)
|
||||
|
||||
# Register LLM client in service locator so tools (e.g. multi_agent) can access it
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ from flask import Blueprint, request, g
|
|||
from backend import db
|
||||
from backend.models import Conversation, Project
|
||||
from backend.utils.helpers import ok, err, to_dict
|
||||
from backend.config import DEFAULT_MODEL
|
||||
from backend.config import config
|
||||
|
||||
bp = Blueprint("conversations", __name__)
|
||||
|
||||
|
|
@ -40,7 +40,7 @@ def conversation_list():
|
|||
user_id=user.id,
|
||||
project_id=project_id or None,
|
||||
title=d.get("title", ""),
|
||||
model=d.get("model", DEFAULT_MODEL),
|
||||
model=d.get("model", config.default_model),
|
||||
system_prompt=d.get("system_prompt", ""),
|
||||
temperature=d.get("temperature", 1.0),
|
||||
max_tokens=d.get("max_tokens", 65536),
|
||||
|
|
@ -105,4 +105,4 @@ def conversation_detail(conv_id):
|
|||
conv.project_id = project_id or None
|
||||
|
||||
db.session.commit()
|
||||
return ok(_conv_to_dict(conv))
|
||||
return ok(_conv_to_dict(conv))
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
"""Model list API routes"""
|
||||
from flask import Blueprint
|
||||
from backend.utils.helpers import ok
|
||||
from backend.config import MODELS
|
||||
from backend.config import config
|
||||
|
||||
bp = Blueprint("models", __name__)
|
||||
|
||||
|
|
@ -13,7 +13,10 @@ _SENSITIVE_KEYS = {"api_key", "api_url"}
|
|||
def list_models():
|
||||
"""Get available model list (without sensitive fields like api_key)"""
|
||||
safe_models = [
|
||||
{k: v for k, v in m.items() if k not in _SENSITIVE_KEYS}
|
||||
for m in MODELS
|
||||
{
|
||||
"id": m.id,
|
||||
"name": m.name,
|
||||
}
|
||||
for m in config.models
|
||||
]
|
||||
return ok(safe_models)
|
||||
return ok(safe_models)
|
||||
|
|
@ -14,7 +14,7 @@ from backend.utils.helpers import (
|
|||
build_messages,
|
||||
)
|
||||
from backend.services.llm_client import LLMClient
|
||||
from backend.config import MAX_ITERATIONS, TOOL_MAX_WORKERS
|
||||
from backend.config import config as _cfg
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -89,7 +89,8 @@ class ChatService:
|
|||
total_completion_tokens = 0
|
||||
total_prompt_tokens = 0
|
||||
|
||||
for iteration in range(MAX_ITERATIONS):
|
||||
|
||||
for iteration in range(_cfg.max_iterations):
|
||||
# Helper to parse stream_result event
|
||||
def parse_stream_result(event_str):
|
||||
"""Parse stream_result SSE event and extract data dict."""
|
||||
|
|
@ -385,7 +386,7 @@ class ChatService:
|
|||
if len(tool_calls_list) > 1:
|
||||
with app.app_context():
|
||||
return executor.process_tool_calls_parallel(
|
||||
tool_calls_list, context, max_workers=TOOL_MAX_WORKERS
|
||||
tool_calls_list, context, max_workers=_cfg.tool_max_workers
|
||||
)
|
||||
else:
|
||||
with app.app_context():
|
||||
|
|
|
|||
|
|
@ -35,28 +35,22 @@ def _detect_provider(api_url: str) -> str:
|
|||
class LLMClient:
|
||||
"""OpenAI-compatible LLM API client.
|
||||
|
||||
Each model must have its own api_url and api_key configured in MODEL_CONFIG.
|
||||
Each model must have its own api_url and api_key configured in config.models.
|
||||
"""
|
||||
|
||||
def __init__(self, model_config: dict):
|
||||
"""Initialize with per-model config lookup.
|
||||
def __init__(self, cfg):
|
||||
"""Initialize with AppConfig.
|
||||
|
||||
Args:
|
||||
model_config: {model_id: {"api_url": ..., "api_key": ...}}
|
||||
cfg: AppConfig dataclass instance
|
||||
"""
|
||||
self.model_config = model_config
|
||||
self.cfg = cfg
|
||||
|
||||
def _get_credentials(self, model: str):
|
||||
"""Get api_url and api_key for a model, with env-var expansion."""
|
||||
cfg = self.model_config.get(model)
|
||||
if not cfg:
|
||||
raise ValueError(f"Unknown model: '{model}', not found in config")
|
||||
api_url = _resolve_env_vars(cfg.get("api_url", ""))
|
||||
api_key = _resolve_env_vars(cfg.get("api_key", ""))
|
||||
if not api_url:
|
||||
raise ValueError(f"Model '{model}' has no api_url configured")
|
||||
if not api_key:
|
||||
raise ValueError(f"Model '{model}' has no api_key configured")
|
||||
api_url, api_key = self.cfg.get_model_credentials(model)
|
||||
api_url = _resolve_env_vars(api_url)
|
||||
api_key = _resolve_env_vars(api_key)
|
||||
return api_url, api_key
|
||||
|
||||
def _build_body(self, model, messages, max_tokens, temperature, thinking_enabled,
|
||||
|
|
|
|||
|
|
@ -1,22 +1,12 @@
|
|||
"""Multi-agent tool for spawning concurrent sub-agents.
|
||||
|
||||
Provides:
|
||||
- multi_agent: Spawn sub-agents with independent LLM conversation loops
|
||||
"""
|
||||
import json
|
||||
import logging
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import List, Dict, Any, Optional
|
||||
from typing import List, Any, Optional
|
||||
from backend.tools import get_service
|
||||
from backend.tools.factory import tool
|
||||
from backend.tools.core import registry
|
||||
from backend.tools.executor import ToolExecutor
|
||||
from backend.config import (
|
||||
DEFAULT_MODEL,
|
||||
SUB_AGENT_MAX_ITERATIONS,
|
||||
SUB_AGENT_MAX_CONCURRENCY,
|
||||
SUB_AGENT_TIMEOUT,
|
||||
)
|
||||
from backend.config import config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -118,7 +108,7 @@ def _run_sub_agent(
|
|||
stream=False,
|
||||
max_tokens=max_tokens,
|
||||
temperature=temperature,
|
||||
timeout=SUB_AGENT_TIMEOUT,
|
||||
timeout=config.sub_agent.timeout,
|
||||
)
|
||||
|
||||
if resp.status_code != 200:
|
||||
|
|
@ -253,13 +243,13 @@ def multi_agent(arguments: dict) -> dict:
|
|||
app = current_app._get_current_object()
|
||||
|
||||
# Use injected model/project_id from executor context, fall back to defaults
|
||||
model = arguments.get("_model") or DEFAULT_MODEL
|
||||
model = arguments.get("_model") or config.default_model
|
||||
project_id = arguments.get("_project_id")
|
||||
max_tokens = arguments.get("_max_tokens", 65536)
|
||||
temperature = arguments.get("_temperature", 0.7)
|
||||
|
||||
# Execute agents concurrently
|
||||
concurrency = min(len(tasks), SUB_AGENT_MAX_CONCURRENCY)
|
||||
concurrency = min(len(tasks), config.sub_agent.max_concurrency)
|
||||
results = [None] * len(tasks)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=concurrency) as pool:
|
||||
|
|
@ -274,7 +264,7 @@ def multi_agent(arguments: dict) -> dict:
|
|||
temperature,
|
||||
project_id,
|
||||
app,
|
||||
SUB_AGENT_MAX_ITERATIONS,
|
||||
config.sub_agent.max_iterations,
|
||||
): i
|
||||
for i, task in enumerate(tasks)
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,8 +7,7 @@ import textwrap
|
|||
from typing import Dict, List, Set
|
||||
|
||||
from backend.tools.factory import tool
|
||||
from backend.config import CODE_EXECUTION_DEFAULT_STRICTNESS as DEFAULT_STRICTNESS
|
||||
from backend.config import CODE_EXECUTION_EXTRA_MODULES as _CFG_EXTRA_MODULES
|
||||
from backend.config import config
|
||||
|
||||
|
||||
# Strictness profiles configuration
|
||||
|
|
@ -101,7 +100,7 @@ def register_extra_modules(strictness: str, modules: Set[str] | List[str]) -> No
|
|||
|
||||
|
||||
# Apply extra modules from config.yml on module load
|
||||
for _level, _mods in _CFG_EXTRA_MODULES.items():
|
||||
for _level, _mods in config.code_execution.extra_allowed_modules.items():
|
||||
if isinstance(_mods, list) and _mods:
|
||||
register_extra_modules(_level, _mods)
|
||||
|
||||
|
|
@ -144,7 +143,7 @@ def execute_python(arguments: dict) -> dict:
|
|||
5. Subprocess isolation
|
||||
"""
|
||||
code = arguments["code"]
|
||||
strictness = arguments.get("strictness", DEFAULT_STRICTNESS)
|
||||
strictness = arguments.get("strictness", config.code_execution.default_strictness)
|
||||
|
||||
# Validate strictness level
|
||||
if strictness not in STRICTNESS_PROFILES:
|
||||
|
|
|
|||
Loading…
Reference in New Issue