"""Configuration management using dataclasses."""
import sys
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple

from backend import load_config


@dataclass
class ModelConfig:
    """A single LLM model entry from the config file."""
    id: str        # unique model identifier (used for lookups / API routes)
    name: str      # human-readable display name
    api_url: str   # OpenAI-compatible endpoint; may contain ${ENV_VAR} refs
    api_key: str   # credential; may contain ${ENV_VAR} refs


@dataclass
class SubAgentConfig:
    """Sub-agent (multi_agent tool) settings."""
    max_iterations: int = 3   # tool-call rounds per sub-agent
    max_concurrency: int = 3  # parallel sub-agents
    timeout: int = 60         # per-request timeout, seconds


@dataclass
class CodeExecutionConfig:
    """Code execution (execute_python tool) settings."""
    default_strictness: str = "standard"
    # {strictness_level: [extra module names allowed at that level]}
    extra_allowed_modules: Dict[str, List[str]] = field(default_factory=dict)


@dataclass
class AppConfig:
    """Main application configuration.

    Raises:
        ValueError: from __post_init__ on duplicate model ids, or when
            default_model does not reference a configured model.
    """
    models: List[ModelConfig] = field(default_factory=list)
    default_model: str = ""
    max_iterations: int = 5    # max agentic loop iterations (tool call rounds)
    tool_max_workers: int = 4  # max parallel workers for tool execution
    sub_agent: SubAgentConfig = field(default_factory=SubAgentConfig)
    code_execution: CodeExecutionConfig = field(default_factory=CodeExecutionConfig)

    # Per-model lookup {model_id: ModelConfig}; derived state, not an init arg.
    _model_config_map: Dict[str, ModelConfig] = field(
        default_factory=dict, init=False, repr=False
    )

    def __post_init__(self):
        """Build the lookup map and validate model ids / default_model."""
        self._model_config_map = {}
        for m in self.models:
            # Reject duplicate ids instead of silently keeping only the last.
            if m.id in self._model_config_map:
                raise ValueError(f"duplicate model id '{m.id}'")
            self._model_config_map[m.id] = m
        if self.default_model and self.default_model not in self._model_config_map:
            raise ValueError(
                f"default_model '{self.default_model}' not found in models"
            )
        # Fall back to the first configured model when no default is given.
        if self.models and not self.default_model:
            self.default_model = self.models[0].id

    def get_model_config(self, model_id: str) -> Optional[ModelConfig]:
        """Return the ModelConfig for *model_id*, or None if unknown."""
        return self._model_config_map.get(model_id)

    def get_model_credentials(self, model_id: str) -> Tuple[str, str]:
        """Return (api_url, api_key) for a model.

        Raises:
            ValueError: if the model is unknown or a credential is empty.
        """
        cfg = self.get_model_config(model_id)
        if not cfg:
            raise ValueError(f"Unknown model: '{model_id}', not found in config")
        if not cfg.api_url:
            raise ValueError(f"Model '{model_id}' has no api_url configured")
        if not cfg.api_key:
            raise ValueError(f"Model '{model_id}' has no api_key configured")
        return cfg.api_url, cfg.api_key


def _parse_config(raw: dict) -> AppConfig:
    """Parse a raw YAML config dict into an AppConfig dataclass.

    Validates each model entry up front so a malformed config produces an
    actionable startup error (matching the pre-refactor behaviour) instead
    of a bare KeyError.
    """
    required = ("id", "name", "api_url", "api_key")
    models = []
    for i, m in enumerate(raw.get("models", [])):
        missing = [k for k in required if k not in m]
        if missing:
            print(
                f"[config] ERROR: models[{i}] missing required fields: {missing}",
                file=sys.stderr,
            )
            sys.exit(1)
        models.append(ModelConfig(
            id=m["id"],
            name=m["name"],
            api_url=m["api_url"],
            api_key=m["api_key"],
        ))

    sa_raw = raw.get("sub_agent", {})
    sub_agent = SubAgentConfig(
        max_iterations=sa_raw.get("max_iterations", 3),
        max_concurrency=sa_raw.get("max_concurrency", 3),
        timeout=sa_raw.get("timeout", 60),
    )

    ce_raw = raw.get("code_execution", {})
    code_execution = CodeExecutionConfig(
        default_strictness=ce_raw.get("default_strictness", "standard"),
        extra_allowed_modules=ce_raw.get("extra_allowed_modules", {}),
    )

    try:
        return AppConfig(
            models=models,
            default_model=raw.get("default_model", ""),
            max_iterations=raw.get("max_iterations", 5),
            tool_max_workers=raw.get("tool_max_workers", 4),
            sub_agent=sub_agent,
            code_execution=code_execution,
        )
    except ValueError as e:
        # Duplicate ids / bad default_model: keep the old fail-fast-at-startup
        # contract rather than surfacing an unhandled traceback.
        print(f"[config] ERROR: {e}", file=sys.stderr)
        sys.exit(1)


# Load and validate configuration at startup
_raw_cfg = load_config()
config = _parse_config(_raw_cfg)
def list_models():
    """Get the available model list.

    Only non-sensitive fields (id, name) are exposed; credentials such as
    api_key and api_url are never returned to the client.
    """
    safe_models = []
    for m in config.models:
        safe_models.append({"id": m.id, "name": m.name})
    return ok(safe_models)
import ( build_messages, ) from backend.services.llm_client import LLMClient -from backend.config import MAX_ITERATIONS, TOOL_MAX_WORKERS +from backend.config import config as _cfg logger = logging.getLogger(__name__) @@ -89,7 +89,8 @@ class ChatService: total_completion_tokens = 0 total_prompt_tokens = 0 - for iteration in range(MAX_ITERATIONS): + + for iteration in range(_cfg.max_iterations): # Helper to parse stream_result event def parse_stream_result(event_str): """Parse stream_result SSE event and extract data dict.""" @@ -385,7 +386,7 @@ class ChatService: if len(tool_calls_list) > 1: with app.app_context(): return executor.process_tool_calls_parallel( - tool_calls_list, context, max_workers=TOOL_MAX_WORKERS + tool_calls_list, context, max_workers=_cfg.tool_max_workers ) else: with app.app_context(): diff --git a/backend/services/llm_client.py b/backend/services/llm_client.py index 267a488..73e47fb 100644 --- a/backend/services/llm_client.py +++ b/backend/services/llm_client.py @@ -35,28 +35,22 @@ def _detect_provider(api_url: str) -> str: class LLMClient: """OpenAI-compatible LLM API client. - Each model must have its own api_url and api_key configured in MODEL_CONFIG. + Each model must have its own api_url and api_key configured in config.models. """ - def __init__(self, model_config: dict): - """Initialize with per-model config lookup. + def __init__(self, cfg): + """Initialize with AppConfig. 
Args: - model_config: {model_id: {"api_url": ..., "api_key": ...}} + cfg: AppConfig dataclass instance """ - self.model_config = model_config + self.cfg = cfg def _get_credentials(self, model: str): """Get api_url and api_key for a model, with env-var expansion.""" - cfg = self.model_config.get(model) - if not cfg: - raise ValueError(f"Unknown model: '{model}', not found in config") - api_url = _resolve_env_vars(cfg.get("api_url", "")) - api_key = _resolve_env_vars(cfg.get("api_key", "")) - if not api_url: - raise ValueError(f"Model '{model}' has no api_url configured") - if not api_key: - raise ValueError(f"Model '{model}' has no api_key configured") + api_url, api_key = self.cfg.get_model_credentials(model) + api_url = _resolve_env_vars(api_url) + api_key = _resolve_env_vars(api_key) return api_url, api_key def _build_body(self, model, messages, max_tokens, temperature, thinking_enabled, diff --git a/backend/tools/builtin/agent.py b/backend/tools/builtin/agent.py index ad2d59c..60933d6 100644 --- a/backend/tools/builtin/agent.py +++ b/backend/tools/builtin/agent.py @@ -1,22 +1,12 @@ -"""Multi-agent tool for spawning concurrent sub-agents. 
- -Provides: -- multi_agent: Spawn sub-agents with independent LLM conversation loops -""" import json import logging from concurrent.futures import ThreadPoolExecutor, as_completed -from typing import List, Dict, Any, Optional +from typing import List, Any, Optional from backend.tools import get_service from backend.tools.factory import tool from backend.tools.core import registry from backend.tools.executor import ToolExecutor -from backend.config import ( - DEFAULT_MODEL, - SUB_AGENT_MAX_ITERATIONS, - SUB_AGENT_MAX_CONCURRENCY, - SUB_AGENT_TIMEOUT, -) +from backend.config import config logger = logging.getLogger(__name__) @@ -118,7 +108,7 @@ def _run_sub_agent( stream=False, max_tokens=max_tokens, temperature=temperature, - timeout=SUB_AGENT_TIMEOUT, + timeout=config.sub_agent.timeout, ) if resp.status_code != 200: @@ -253,13 +243,13 @@ def multi_agent(arguments: dict) -> dict: app = current_app._get_current_object() # Use injected model/project_id from executor context, fall back to defaults - model = arguments.get("_model") or DEFAULT_MODEL + model = arguments.get("_model") or config.default_model project_id = arguments.get("_project_id") max_tokens = arguments.get("_max_tokens", 65536) temperature = arguments.get("_temperature", 0.7) # Execute agents concurrently - concurrency = min(len(tasks), SUB_AGENT_MAX_CONCURRENCY) + concurrency = min(len(tasks), config.sub_agent.max_concurrency) results = [None] * len(tasks) with ThreadPoolExecutor(max_workers=concurrency) as pool: @@ -274,7 +264,7 @@ def multi_agent(arguments: dict) -> dict: temperature, project_id, app, - SUB_AGENT_MAX_ITERATIONS, + config.sub_agent.max_iterations, ): i for i, task in enumerate(tasks) } diff --git a/backend/tools/builtin/code.py b/backend/tools/builtin/code.py index b190957..fcf52d3 100644 --- a/backend/tools/builtin/code.py +++ b/backend/tools/builtin/code.py @@ -7,8 +7,7 @@ import textwrap from typing import Dict, List, Set from backend.tools.factory import tool -from 
backend.config import CODE_EXECUTION_DEFAULT_STRICTNESS as DEFAULT_STRICTNESS -from backend.config import CODE_EXECUTION_EXTRA_MODULES as _CFG_EXTRA_MODULES +from backend.config import config # Strictness profiles configuration @@ -101,7 +100,7 @@ def register_extra_modules(strictness: str, modules: Set[str] | List[str]) -> No # Apply extra modules from config.yml on module load -for _level, _mods in _CFG_EXTRA_MODULES.items(): +for _level, _mods in config.code_execution.extra_allowed_modules.items(): if isinstance(_mods, list) and _mods: register_extra_modules(_level, _mods) @@ -144,7 +143,7 @@ def execute_python(arguments: dict) -> dict: 5. Subprocess isolation """ code = arguments["code"] - strictness = arguments.get("strictness", DEFAULT_STRICTNESS) + strictness = arguments.get("strictness", config.code_execution.default_strictness) # Validate strictness level if strictness not in STRICTNESS_PROFILES: