refactor: 使用dataclass 设定config
This commit is contained in:
parent
ae73559fd2
commit
2a6c82b3ba
|
|
@ -1,49 +1,105 @@
|
||||||
"""Configuration management"""
|
"""Configuration management using dataclasses"""
|
||||||
import sys
|
import sys
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import List, Dict, Optional
|
||||||
|
|
||||||
from backend import load_config
|
from backend import load_config
|
||||||
|
|
||||||
_cfg = load_config()
|
|
||||||
|
|
||||||
# Model list (for /api/models endpoint)
|
@dataclass
|
||||||
MODELS = _cfg.get("models", [])
|
class ModelConfig:
|
||||||
|
"""Individual model configuration."""
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
api_url: str
|
||||||
|
api_key: str
|
||||||
|
|
||||||
# Validate each model has required fields at startup
|
|
||||||
_REQUIRED_MODEL_KEYS = {"id", "name", "api_url", "api_key"}
|
|
||||||
_model_ids_seen = set()
|
|
||||||
for _i, _m in enumerate(MODELS):
|
|
||||||
_missing = _REQUIRED_MODEL_KEYS - set(_m.keys())
|
|
||||||
if _missing:
|
|
||||||
print(f"[config] ERROR: models[{_i}] missing required fields: {_missing}", file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
||||||
if _m["id"] in _model_ids_seen:
|
|
||||||
print(f"[config] ERROR: duplicate model id '{_m['id']}'", file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
||||||
_model_ids_seen.add(_m["id"])
|
|
||||||
|
|
||||||
# Per-model config lookup: {model_id: {api_url, api_key}}
|
@dataclass
|
||||||
MODEL_CONFIG = {m["id"]: {"api_url": m["api_url"], "api_key": m["api_key"]} for m in MODELS}
|
class SubAgentConfig:
|
||||||
|
"""Sub-agent (multi_agent tool) settings."""
|
||||||
|
max_iterations: int = 3
|
||||||
|
max_concurrency: int = 3
|
||||||
|
timeout: int = 60
|
||||||
|
|
||||||
# default_model must exist in models
|
|
||||||
DEFAULT_MODEL = _cfg.get("default_model", "")
|
|
||||||
if DEFAULT_MODEL and DEFAULT_MODEL not in MODEL_CONFIG:
|
|
||||||
print(f"[config] ERROR: default_model '{DEFAULT_MODEL}' not found in models", file=sys.stderr)
|
|
||||||
sys.exit(1)
|
|
||||||
if MODELS and not DEFAULT_MODEL:
|
|
||||||
DEFAULT_MODEL = MODELS[0]["id"]
|
|
||||||
|
|
||||||
# Max agentic loop iterations (tool call rounds)
|
@dataclass
|
||||||
MAX_ITERATIONS = _cfg.get("max_iterations", 5)
|
class CodeExecutionConfig:
|
||||||
|
"""Code execution settings."""
|
||||||
|
default_strictness: str = "standard"
|
||||||
|
extra_allowed_modules: Dict = field(default_factory=dict)
|
||||||
|
|
||||||
# Max parallel workers for tool execution (ThreadPoolExecutor)
|
|
||||||
TOOL_MAX_WORKERS = _cfg.get("tool_max_workers", 4)
|
|
||||||
|
|
||||||
# Sub-agent settings (multi_agent tool)
|
@dataclass
|
||||||
_sa = _cfg.get("sub_agent", {})
|
class AppConfig:
|
||||||
SUB_AGENT_MAX_ITERATIONS = _sa.get("max_iterations", 3)
|
"""Main application configuration."""
|
||||||
SUB_AGENT_MAX_CONCURRENCY = _sa.get("max_concurrency", 3)
|
models: List[ModelConfig] = field(default_factory=list)
|
||||||
SUB_AGENT_TIMEOUT = _sa.get("timeout", 60)
|
default_model: str = ""
|
||||||
|
max_iterations: int = 5
|
||||||
|
tool_max_workers: int = 4
|
||||||
|
sub_agent: SubAgentConfig = field(default_factory=SubAgentConfig)
|
||||||
|
code_execution: CodeExecutionConfig = field(default_factory=CodeExecutionConfig)
|
||||||
|
|
||||||
# Code execution settings
|
# Per-model config lookup: {model_id: ModelConfig}
|
||||||
_ce = _cfg.get("code_execution", {})
|
_model_config_map: Dict[str, ModelConfig] = field(default_factory=dict, repr=False)
|
||||||
CODE_EXECUTION_DEFAULT_STRICTNESS = _ce.get("default_strictness", "standard")
|
|
||||||
CODE_EXECUTION_EXTRA_MODULES = _ce.get("extra_allowed_modules", {})
|
def __post_init__(self):
|
||||||
|
"""Build lookup map after initialization."""
|
||||||
|
self._model_config_map = {m.id: m for m in self.models}
|
||||||
|
|
||||||
|
def get_model_config(self, model_id: str) -> Optional[ModelConfig]:
|
||||||
|
"""Get model config by ID."""
|
||||||
|
return self._model_config_map.get(model_id)
|
||||||
|
|
||||||
|
def get_model_credentials(self, model_id: str) -> tuple:
|
||||||
|
"""Get (api_url, api_key) for a model."""
|
||||||
|
cfg = self.get_model_config(model_id)
|
||||||
|
if not cfg:
|
||||||
|
raise ValueError(f"Unknown model: '{model_id}', not found in config")
|
||||||
|
if not cfg.api_url:
|
||||||
|
raise ValueError(f"Model '{model_id}' has no api_url configured")
|
||||||
|
if not cfg.api_key:
|
||||||
|
raise ValueError(f"Model '{model_id}' has no api_key configured")
|
||||||
|
return cfg.api_url, cfg.api_key
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_config(raw: dict) -> AppConfig:
|
||||||
|
"""Parse raw YAML config into AppConfig dataclass."""
|
||||||
|
# Parse models
|
||||||
|
models = []
|
||||||
|
for m in raw.get("models", []):
|
||||||
|
models.append(ModelConfig(
|
||||||
|
id=m["id"],
|
||||||
|
name=m["name"],
|
||||||
|
api_url=m["api_url"],
|
||||||
|
api_key=m["api_key"],
|
||||||
|
))
|
||||||
|
|
||||||
|
# Parse sub_agent
|
||||||
|
sa_raw = raw.get("sub_agent", {})
|
||||||
|
sub_agent = SubAgentConfig(
|
||||||
|
max_iterations=sa_raw.get("max_iterations", 3),
|
||||||
|
max_concurrency=sa_raw.get("max_concurrency", 3),
|
||||||
|
timeout=sa_raw.get("timeout", 60),
|
||||||
|
)
|
||||||
|
|
||||||
|
# Parse code_execution
|
||||||
|
ce_raw = raw.get("code_execution", {})
|
||||||
|
code_execution = CodeExecutionConfig(
|
||||||
|
default_strictness=ce_raw.get("default_strictness", "standard"),
|
||||||
|
extra_allowed_modules=ce_raw.get("extra_allowed_modules", {}),
|
||||||
|
)
|
||||||
|
|
||||||
|
return AppConfig(
|
||||||
|
models=models,
|
||||||
|
default_model=raw.get("default_model", ""),
|
||||||
|
max_iterations=raw.get("max_iterations", 5),
|
||||||
|
tool_max_workers=raw.get("tool_max_workers", 4),
|
||||||
|
sub_agent=sub_agent,
|
||||||
|
code_execution=code_execution,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Load and validate configuration at startup
|
||||||
|
_raw_cfg = load_config()
|
||||||
|
config = _parse_config(_raw_cfg)
|
||||||
|
|
|
||||||
|
|
@ -8,13 +8,13 @@ from backend.routes.stats import bp as stats_bp
|
||||||
from backend.routes.projects import bp as projects_bp
|
from backend.routes.projects import bp as projects_bp
|
||||||
from backend.routes.auth import bp as auth_bp, init_auth
|
from backend.routes.auth import bp as auth_bp, init_auth
|
||||||
from backend.services.llm_client import LLMClient
|
from backend.services.llm_client import LLMClient
|
||||||
from backend.config import MODEL_CONFIG
|
from backend.config import config
|
||||||
|
|
||||||
|
|
||||||
def register_routes(app: Flask):
|
def register_routes(app: Flask):
|
||||||
"""Register all route blueprints"""
|
"""Register all route blueprints"""
|
||||||
# Initialize LLM client with per-model config
|
# Initialize LLM client with config
|
||||||
client = LLMClient(MODEL_CONFIG)
|
client = LLMClient(config)
|
||||||
init_chat_service(client)
|
init_chat_service(client)
|
||||||
|
|
||||||
# Register LLM client in service locator so tools (e.g. multi_agent) can access it
|
# Register LLM client in service locator so tools (e.g. multi_agent) can access it
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ from flask import Blueprint, request, g
|
||||||
from backend import db
|
from backend import db
|
||||||
from backend.models import Conversation, Project
|
from backend.models import Conversation, Project
|
||||||
from backend.utils.helpers import ok, err, to_dict
|
from backend.utils.helpers import ok, err, to_dict
|
||||||
from backend.config import DEFAULT_MODEL
|
from backend.config import config
|
||||||
|
|
||||||
bp = Blueprint("conversations", __name__)
|
bp = Blueprint("conversations", __name__)
|
||||||
|
|
||||||
|
|
@ -40,7 +40,7 @@ def conversation_list():
|
||||||
user_id=user.id,
|
user_id=user.id,
|
||||||
project_id=project_id or None,
|
project_id=project_id or None,
|
||||||
title=d.get("title", ""),
|
title=d.get("title", ""),
|
||||||
model=d.get("model", DEFAULT_MODEL),
|
model=d.get("model", config.default_model),
|
||||||
system_prompt=d.get("system_prompt", ""),
|
system_prompt=d.get("system_prompt", ""),
|
||||||
temperature=d.get("temperature", 1.0),
|
temperature=d.get("temperature", 1.0),
|
||||||
max_tokens=d.get("max_tokens", 65536),
|
max_tokens=d.get("max_tokens", 65536),
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
"""Model list API routes"""
|
"""Model list API routes"""
|
||||||
from flask import Blueprint
|
from flask import Blueprint
|
||||||
from backend.utils.helpers import ok
|
from backend.utils.helpers import ok
|
||||||
from backend.config import MODELS
|
from backend.config import config
|
||||||
|
|
||||||
bp = Blueprint("models", __name__)
|
bp = Blueprint("models", __name__)
|
||||||
|
|
||||||
|
|
@ -13,7 +13,10 @@ _SENSITIVE_KEYS = {"api_key", "api_url"}
|
||||||
def list_models():
|
def list_models():
|
||||||
"""Get available model list (without sensitive fields like api_key)"""
|
"""Get available model list (without sensitive fields like api_key)"""
|
||||||
safe_models = [
|
safe_models = [
|
||||||
{k: v for k, v in m.items() if k not in _SENSITIVE_KEYS}
|
{
|
||||||
for m in MODELS
|
"id": m.id,
|
||||||
|
"name": m.name,
|
||||||
|
}
|
||||||
|
for m in config.models
|
||||||
]
|
]
|
||||||
return ok(safe_models)
|
return ok(safe_models)
|
||||||
|
|
@ -14,7 +14,7 @@ from backend.utils.helpers import (
|
||||||
build_messages,
|
build_messages,
|
||||||
)
|
)
|
||||||
from backend.services.llm_client import LLMClient
|
from backend.services.llm_client import LLMClient
|
||||||
from backend.config import MAX_ITERATIONS, TOOL_MAX_WORKERS
|
from backend.config import config as _cfg
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -89,7 +89,8 @@ class ChatService:
|
||||||
total_completion_tokens = 0
|
total_completion_tokens = 0
|
||||||
total_prompt_tokens = 0
|
total_prompt_tokens = 0
|
||||||
|
|
||||||
for iteration in range(MAX_ITERATIONS):
|
|
||||||
|
for iteration in range(_cfg.max_iterations):
|
||||||
# Helper to parse stream_result event
|
# Helper to parse stream_result event
|
||||||
def parse_stream_result(event_str):
|
def parse_stream_result(event_str):
|
||||||
"""Parse stream_result SSE event and extract data dict."""
|
"""Parse stream_result SSE event and extract data dict."""
|
||||||
|
|
@ -385,7 +386,7 @@ class ChatService:
|
||||||
if len(tool_calls_list) > 1:
|
if len(tool_calls_list) > 1:
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
return executor.process_tool_calls_parallel(
|
return executor.process_tool_calls_parallel(
|
||||||
tool_calls_list, context, max_workers=TOOL_MAX_WORKERS
|
tool_calls_list, context, max_workers=_cfg.tool_max_workers
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
with app.app_context():
|
with app.app_context():
|
||||||
|
|
|
||||||
|
|
@ -35,28 +35,22 @@ def _detect_provider(api_url: str) -> str:
|
||||||
class LLMClient:
|
class LLMClient:
|
||||||
"""OpenAI-compatible LLM API client.
|
"""OpenAI-compatible LLM API client.
|
||||||
|
|
||||||
Each model must have its own api_url and api_key configured in MODEL_CONFIG.
|
Each model must have its own api_url and api_key configured in config.models.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, model_config: dict):
|
def __init__(self, cfg):
|
||||||
"""Initialize with per-model config lookup.
|
"""Initialize with AppConfig.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_config: {model_id: {"api_url": ..., "api_key": ...}}
|
cfg: AppConfig dataclass instance
|
||||||
"""
|
"""
|
||||||
self.model_config = model_config
|
self.cfg = cfg
|
||||||
|
|
||||||
def _get_credentials(self, model: str):
|
def _get_credentials(self, model: str):
|
||||||
"""Get api_url and api_key for a model, with env-var expansion."""
|
"""Get api_url and api_key for a model, with env-var expansion."""
|
||||||
cfg = self.model_config.get(model)
|
api_url, api_key = self.cfg.get_model_credentials(model)
|
||||||
if not cfg:
|
api_url = _resolve_env_vars(api_url)
|
||||||
raise ValueError(f"Unknown model: '{model}', not found in config")
|
api_key = _resolve_env_vars(api_key)
|
||||||
api_url = _resolve_env_vars(cfg.get("api_url", ""))
|
|
||||||
api_key = _resolve_env_vars(cfg.get("api_key", ""))
|
|
||||||
if not api_url:
|
|
||||||
raise ValueError(f"Model '{model}' has no api_url configured")
|
|
||||||
if not api_key:
|
|
||||||
raise ValueError(f"Model '{model}' has no api_key configured")
|
|
||||||
return api_url, api_key
|
return api_url, api_key
|
||||||
|
|
||||||
def _build_body(self, model, messages, max_tokens, temperature, thinking_enabled,
|
def _build_body(self, model, messages, max_tokens, temperature, thinking_enabled,
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,12 @@
|
||||||
"""Multi-agent tool for spawning concurrent sub-agents.
|
|
||||||
|
|
||||||
Provides:
|
|
||||||
- multi_agent: Spawn sub-agents with independent LLM conversation loops
|
|
||||||
"""
|
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
from typing import List, Dict, Any, Optional
|
from typing import List, Any, Optional
|
||||||
from backend.tools import get_service
|
from backend.tools import get_service
|
||||||
from backend.tools.factory import tool
|
from backend.tools.factory import tool
|
||||||
from backend.tools.core import registry
|
from backend.tools.core import registry
|
||||||
from backend.tools.executor import ToolExecutor
|
from backend.tools.executor import ToolExecutor
|
||||||
from backend.config import (
|
from backend.config import config
|
||||||
DEFAULT_MODEL,
|
|
||||||
SUB_AGENT_MAX_ITERATIONS,
|
|
||||||
SUB_AGENT_MAX_CONCURRENCY,
|
|
||||||
SUB_AGENT_TIMEOUT,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -118,7 +108,7 @@ def _run_sub_agent(
|
||||||
stream=False,
|
stream=False,
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
temperature=temperature,
|
temperature=temperature,
|
||||||
timeout=SUB_AGENT_TIMEOUT,
|
timeout=config.sub_agent.timeout,
|
||||||
)
|
)
|
||||||
|
|
||||||
if resp.status_code != 200:
|
if resp.status_code != 200:
|
||||||
|
|
@ -253,13 +243,13 @@ def multi_agent(arguments: dict) -> dict:
|
||||||
app = current_app._get_current_object()
|
app = current_app._get_current_object()
|
||||||
|
|
||||||
# Use injected model/project_id from executor context, fall back to defaults
|
# Use injected model/project_id from executor context, fall back to defaults
|
||||||
model = arguments.get("_model") or DEFAULT_MODEL
|
model = arguments.get("_model") or config.default_model
|
||||||
project_id = arguments.get("_project_id")
|
project_id = arguments.get("_project_id")
|
||||||
max_tokens = arguments.get("_max_tokens", 65536)
|
max_tokens = arguments.get("_max_tokens", 65536)
|
||||||
temperature = arguments.get("_temperature", 0.7)
|
temperature = arguments.get("_temperature", 0.7)
|
||||||
|
|
||||||
# Execute agents concurrently
|
# Execute agents concurrently
|
||||||
concurrency = min(len(tasks), SUB_AGENT_MAX_CONCURRENCY)
|
concurrency = min(len(tasks), config.sub_agent.max_concurrency)
|
||||||
results = [None] * len(tasks)
|
results = [None] * len(tasks)
|
||||||
|
|
||||||
with ThreadPoolExecutor(max_workers=concurrency) as pool:
|
with ThreadPoolExecutor(max_workers=concurrency) as pool:
|
||||||
|
|
@ -274,7 +264,7 @@ def multi_agent(arguments: dict) -> dict:
|
||||||
temperature,
|
temperature,
|
||||||
project_id,
|
project_id,
|
||||||
app,
|
app,
|
||||||
SUB_AGENT_MAX_ITERATIONS,
|
config.sub_agent.max_iterations,
|
||||||
): i
|
): i
|
||||||
for i, task in enumerate(tasks)
|
for i, task in enumerate(tasks)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,8 +7,7 @@ import textwrap
|
||||||
from typing import Dict, List, Set
|
from typing import Dict, List, Set
|
||||||
|
|
||||||
from backend.tools.factory import tool
|
from backend.tools.factory import tool
|
||||||
from backend.config import CODE_EXECUTION_DEFAULT_STRICTNESS as DEFAULT_STRICTNESS
|
from backend.config import config
|
||||||
from backend.config import CODE_EXECUTION_EXTRA_MODULES as _CFG_EXTRA_MODULES
|
|
||||||
|
|
||||||
|
|
||||||
# Strictness profiles configuration
|
# Strictness profiles configuration
|
||||||
|
|
@ -101,7 +100,7 @@ def register_extra_modules(strictness: str, modules: Set[str] | List[str]) -> No
|
||||||
|
|
||||||
|
|
||||||
# Apply extra modules from config.yml on module load
|
# Apply extra modules from config.yml on module load
|
||||||
for _level, _mods in _CFG_EXTRA_MODULES.items():
|
for _level, _mods in config.code_execution.extra_allowed_modules.items():
|
||||||
if isinstance(_mods, list) and _mods:
|
if isinstance(_mods, list) and _mods:
|
||||||
register_extra_modules(_level, _mods)
|
register_extra_modules(_level, _mods)
|
||||||
|
|
||||||
|
|
@ -144,7 +143,7 @@ def execute_python(arguments: dict) -> dict:
|
||||||
5. Subprocess isolation
|
5. Subprocess isolation
|
||||||
"""
|
"""
|
||||||
code = arguments["code"]
|
code = arguments["code"]
|
||||||
strictness = arguments.get("strictness", DEFAULT_STRICTNESS)
|
strictness = arguments.get("strictness", config.code_execution.default_strictness)
|
||||||
|
|
||||||
# Validate strictness level
|
# Validate strictness level
|
||||||
if strictness not in STRICTNESS_PROFILES:
|
if strictness not in STRICTNESS_PROFILES:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue