first commit
This commit is contained in:
commit
6749213f62
|
|
@ -0,0 +1,13 @@
|
||||||
|
# 忽略所有内容(从根目录开始)
|
||||||
|
*
|
||||||
|
|
||||||
|
# 允许扫描目录
|
||||||
|
!*/
|
||||||
|
|
||||||
|
!config.yaml
|
||||||
|
!pyproject.toml
|
||||||
|
!README.md
|
||||||
|
!.gitignore
|
||||||
|
|
||||||
|
!alcor/**/*.py
|
||||||
|
!docs/**/*.md
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
"""FastAPI应用工厂"""
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from alcor.config import config
|
||||||
|
from alcor.database import init_db
|
||||||
|
from alcor.routes import api_router
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
"""应用生命周期管理"""
|
||||||
|
# 启动时
|
||||||
|
print("🚀 Starting up ChatBackend API...")
|
||||||
|
|
||||||
|
# 初始化数据库
|
||||||
|
init_db()
|
||||||
|
print("✅ Database initialized")
|
||||||
|
|
||||||
|
# 加载内置工具
|
||||||
|
from alcor.tools.builtin import crawler, code, data
|
||||||
|
print(f"✅ Loaded {len(api_router.routes)} API routes")
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# 关闭时
|
||||||
|
print("👋 Shutting down ChatBackend API...")
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
"""创建FastAPI应用"""
|
||||||
|
app = FastAPI(
|
||||||
|
title="ChatBackend API",
|
||||||
|
description="智能聊天后端API,支持多模型、流式响应、工具调用",
|
||||||
|
version="1.0.0",
|
||||||
|
lifespan=lifespan
|
||||||
|
)
|
||||||
|
|
||||||
|
# 配置CORS
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"], # 生产环境应限制
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
# 注册路由
|
||||||
|
app.include_router(api_router, prefix="/api")
|
||||||
|
|
||||||
|
# 健康检查
|
||||||
|
@app.get("/health")
|
||||||
|
async def health_check():
|
||||||
|
return {"status": "healthy", "service": "chat-backend"}
|
||||||
|
|
||||||
|
@app.get("/")
|
||||||
|
async def root():
|
||||||
|
return {
|
||||||
|
"service": "ChatBackend API",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"docs": "/docs"
|
||||||
|
}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
# 创建应用实例
|
||||||
|
app = create_app()
|
||||||
|
|
@ -0,0 +1,117 @@
|
||||||
|
"""配置管理模块"""
|
||||||
|
import os
|
||||||
|
import yaml
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""配置类(单例模式)"""
|
||||||
|
|
||||||
|
_instance: Optional["Config"] = None
|
||||||
|
_config: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
cls._instance._load_config()
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def _load_config(self) -> None:
|
||||||
|
"""加载配置文件"""
|
||||||
|
yaml_paths = [
|
||||||
|
Path("config.yaml"),
|
||||||
|
Path(__file__).parent.parent / "config.yaml",
|
||||||
|
Path.cwd() / "config.yaml",
|
||||||
|
]
|
||||||
|
|
||||||
|
for path in yaml_paths:
|
||||||
|
if path.exists():
|
||||||
|
with open(path, "r", encoding="utf-8") as f:
|
||||||
|
self._config = yaml.safe_load(f) or {}
|
||||||
|
self._resolve_env_vars()
|
||||||
|
return
|
||||||
|
|
||||||
|
self._config = {}
|
||||||
|
|
||||||
|
def _resolve_env_vars(self) -> None:
|
||||||
|
"""解析环境变量引用"""
|
||||||
|
def resolve(value: Any) -> Any:
|
||||||
|
if isinstance(value, str) and value.startswith("${") and value.endswith("}"):
|
||||||
|
return os.environ.get(value[2:-1], "")
|
||||||
|
elif isinstance(value, dict):
|
||||||
|
return {k: resolve(v) for k, v in value.items()}
|
||||||
|
elif isinstance(value, list):
|
||||||
|
return [resolve(item) for item in value]
|
||||||
|
return value
|
||||||
|
|
||||||
|
self._config = resolve(self._config)
|
||||||
|
|
||||||
|
def get(self, key: str, default: Any = None) -> Any:
|
||||||
|
"""获取配置值,支持点号分隔的键"""
|
||||||
|
keys = key.split(".")
|
||||||
|
value = self._config
|
||||||
|
for k in keys:
|
||||||
|
if isinstance(value, dict):
|
||||||
|
value = value.get(k)
|
||||||
|
else:
|
||||||
|
return default
|
||||||
|
if value is None:
|
||||||
|
return default
|
||||||
|
return value
|
||||||
|
|
||||||
|
# App配置
|
||||||
|
@property
|
||||||
|
def secret_key(self) -> str:
|
||||||
|
return self.get("app.secret_key", "change-me-in-production")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def debug(self) -> bool:
|
||||||
|
return self.get("app.debug", True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def app_host(self) -> str:
|
||||||
|
return self.get("app.host", "0.0.0.0")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def app_port(self) -> int:
|
||||||
|
return self.get("app.port", 8000)
|
||||||
|
|
||||||
|
# 数据库配置
|
||||||
|
@property
|
||||||
|
def database_url(self) -> str:
|
||||||
|
return self.get("database.url", "sqlite:///./chat.db")
|
||||||
|
|
||||||
|
# LLM配置
|
||||||
|
@property
|
||||||
|
def llm_api_key(self) -> str:
|
||||||
|
return self.get("llm.api_key", "") or os.environ.get("DEEPSEEK_API_KEY", "")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def llm_api_url(self) -> str:
|
||||||
|
return self.get("llm.api_url", "https://api.deepseek.com/v1")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def llm_provider(self) -> str:
|
||||||
|
return self.get("llm.provider", "deepseek")
|
||||||
|
|
||||||
|
# 工具配置
|
||||||
|
@property
|
||||||
|
def tools_enable_cache(self) -> bool:
|
||||||
|
return self.get("tools.enable_cache", True)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tools_cache_ttl(self) -> int:
|
||||||
|
return self.get("tools.cache_ttl", 300)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tools_max_workers(self) -> int:
|
||||||
|
return self.get("tools.max_workers", 4)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tools_max_iterations(self) -> int:
|
||||||
|
return self.get("tools.max_iterations", 10)
|
||||||
|
|
||||||
|
|
||||||
|
# 全局配置实例
|
||||||
|
config = Config()
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
"""数据库连接模块"""
|
||||||
|
from sqlalchemy import create_engine
|
||||||
|
from sqlalchemy.ext.declarative import declarative_base
|
||||||
|
from sqlalchemy.orm import sessionmaker, Session
|
||||||
|
from typing import Generator
|
||||||
|
|
||||||
|
from alcor.config import config
|
||||||
|
|
||||||
|
|
||||||
|
# 创建数据库引擎
|
||||||
|
engine = create_engine(
|
||||||
|
config.database_url,
|
||||||
|
connect_args={"check_same_thread": False} if "sqlite" in config.database_url else {},
|
||||||
|
pool_pre_ping=True,
|
||||||
|
echo=config.debug
|
||||||
|
)
|
||||||
|
|
||||||
|
# 创建会话工厂
|
||||||
|
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||||
|
|
||||||
|
# 创建基类
|
||||||
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
||||||
|
def get_db() -> Generator[Session, None, None]:
|
||||||
|
"""获取数据库会话的依赖项"""
|
||||||
|
db = SessionLocal()
|
||||||
|
try:
|
||||||
|
yield db
|
||||||
|
finally:
|
||||||
|
db.close()
|
||||||
|
|
||||||
|
|
||||||
|
def init_db() -> None:
|
||||||
|
"""初始化数据库,创建所有表"""
|
||||||
|
Base.metadata.create_all(bind=engine)
|
||||||
|
|
@ -0,0 +1,140 @@
|
||||||
|
"""ORM模型定义"""
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional, List
|
||||||
|
from sqlalchemy import String, Text, Integer, Float, Boolean, DateTime, ForeignKey
|
||||||
|
from sqlalchemy.orm import relationship, Mapped, mapped_column
|
||||||
|
|
||||||
|
from alcor.database import Base
|
||||||
|
|
||||||
|
|
||||||
|
class Project(Base):
|
||||||
|
"""项目模型"""
|
||||||
|
__tablename__ = "projects"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), default="")
|
||||||
|
description: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
user: Mapped["User"] = relationship("User", backref="projects")
|
||||||
|
conversations: Mapped[List["Conversation"]] = relationship(
|
||||||
|
"Conversation",
|
||||||
|
back_populates="project",
|
||||||
|
lazy="dynamic"
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||||
|
"updated_at": self.updated_at.isoformat() if self.updated_at else None
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
"""用户模型"""
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||||
|
username: Mapped[str] = mapped_column(String(50), unique=True, nullable=False)
|
||||||
|
email: Mapped[Optional[str]] = mapped_column(String(120), unique=True, nullable=True)
|
||||||
|
password_hash: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||||
|
role: Mapped[str] = mapped_column(String(20), default="user")
|
||||||
|
is_active: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
conversations: Mapped[List["Conversation"]] = relationship(
|
||||||
|
"Conversation",
|
||||||
|
back_populates="user",
|
||||||
|
lazy="dynamic"
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"username": self.username,
|
||||||
|
"email": self.email,
|
||||||
|
"role": self.role,
|
||||||
|
"is_active": self.is_active,
|
||||||
|
"created_at": self.created_at.isoformat() if self.created_at else None
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Conversation(Base):
|
||||||
|
"""会话模型"""
|
||||||
|
__tablename__ = "conversations"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
|
user_id: Mapped[int] = mapped_column(Integer, ForeignKey("users.id"), nullable=False, index=True)
|
||||||
|
project_id: Mapped[Optional[str]] = mapped_column(String(64), ForeignKey("projects.id"), nullable=True)
|
||||||
|
title: Mapped[str] = mapped_column(String(255), default="")
|
||||||
|
model: Mapped[str] = mapped_column(String(64), default="glm-5")
|
||||||
|
system_prompt: Mapped[str] = mapped_column(Text, default="")
|
||||||
|
temperature: Mapped[float] = mapped_column(Float, default=1.0)
|
||||||
|
max_tokens: Mapped[int] = mapped_column(Integer, default=65536)
|
||||||
|
thinking_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, onupdate=datetime.utcnow)
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
user: Mapped["User"] = relationship("User", back_populates="conversations")
|
||||||
|
project: Mapped[Optional["Project"]] = relationship("Project", back_populates="conversations")
|
||||||
|
messages: Mapped[List["Message"]] = relationship(
|
||||||
|
"Message",
|
||||||
|
back_populates="conversation",
|
||||||
|
lazy="dynamic",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
order_by="Message.created_at.asc()"
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"user_id": self.user_id,
|
||||||
|
"project_id": self.project_id,
|
||||||
|
"title": self.title,
|
||||||
|
"model": self.model,
|
||||||
|
"system_prompt": self.system_prompt,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"max_tokens": self.max_tokens,
|
||||||
|
"thinking_enabled": self.thinking_enabled,
|
||||||
|
"created_at": self.created_at.isoformat() if self.created_at else None,
|
||||||
|
"updated_at": self.updated_at.isoformat() if self.updated_at else None
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class Message(Base):
|
||||||
|
"""消息模型"""
|
||||||
|
__tablename__ = "messages"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||||
|
conversation_id: Mapped[str] = mapped_column(
|
||||||
|
String(64),
|
||||||
|
ForeignKey("conversations.id"),
|
||||||
|
nullable=False,
|
||||||
|
index=True
|
||||||
|
)
|
||||||
|
role: Mapped[str] = mapped_column(String(16), nullable=False) # user, assistant, system, tool
|
||||||
|
content: Mapped[str] = mapped_column(Text, default="") # JSON: {text, steps, tool_calls}
|
||||||
|
token_count: Mapped[int] = mapped_column(Integer, default=0)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(DateTime, default=datetime.utcnow, index=True)
|
||||||
|
|
||||||
|
# 关系
|
||||||
|
conversation: Mapped["Conversation"] = relationship("Conversation", back_populates="messages")
|
||||||
|
|
||||||
|
def to_dict(self) -> dict:
|
||||||
|
return {
|
||||||
|
"id": self.id,
|
||||||
|
"conversation_id": self.conversation_id,
|
||||||
|
"role": self.role,
|
||||||
|
"content": self.content,
|
||||||
|
"token_count": self.token_count,
|
||||||
|
"created_at": self.created_at.isoformat() if self.created_at else None
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,14 @@
|
||||||
|
"""API路由模块"""
|
||||||
|
from fastapi import APIRouter
|
||||||
|
|
||||||
|
from alcor.routes import auth, conversations, messages, tools
|
||||||
|
|
||||||
|
api_router = APIRouter()
|
||||||
|
|
||||||
|
# 注册子路由
|
||||||
|
api_router.include_router(auth.router)
|
||||||
|
api_router.include_router(conversations.router)
|
||||||
|
api_router.include_router(messages.router)
|
||||||
|
api_router.include_router(tools.router)
|
||||||
|
|
||||||
|
__all__ = ["api_router"]
|
||||||
|
|
@ -0,0 +1,154 @@
|
||||||
|
"""认证路由"""
|
||||||
|
from datetime import timedelta
|
||||||
|
from typing import Optional
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer, OAuth2PasswordRequestForm
|
||||||
|
from pydantic import BaseModel, EmailStr
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from alcor.database import get_db
|
||||||
|
from alcor.models import User
|
||||||
|
from alcor.utils.helpers import (
|
||||||
|
hash_password,
|
||||||
|
verify_password,
|
||||||
|
create_access_token,
|
||||||
|
success_response,
|
||||||
|
error_response
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth", tags=["认证"])
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/auth/login")
|
||||||
|
|
||||||
|
|
||||||
|
class UserRegister(BaseModel):
|
||||||
|
"""用户注册模型"""
|
||||||
|
username: str
|
||||||
|
email: Optional[EmailStr] = None
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class UserLogin(BaseModel):
|
||||||
|
"""用户登录模型"""
|
||||||
|
username: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class UserResponse(BaseModel):
|
||||||
|
"""用户响应模型"""
|
||||||
|
id: int
|
||||||
|
username: str
|
||||||
|
email: Optional[str] = None
|
||||||
|
role: str
|
||||||
|
is_active: bool
|
||||||
|
|
||||||
|
|
||||||
|
class TokenResponse(BaseModel):
|
||||||
|
"""令牌响应模型"""
|
||||||
|
access_token: str
|
||||||
|
token_type: str = "bearer"
|
||||||
|
|
||||||
|
|
||||||
|
def get_current_user(
|
||||||
|
token: str = Depends(oauth2_scheme),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
) -> User:
|
||||||
|
"""获取当前用户"""
|
||||||
|
from alcor.utils.helpers import decode_access_token
|
||||||
|
|
||||||
|
payload = decode_access_token(token)
|
||||||
|
if payload is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="无效的认证凭证",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
|
||||||
|
user_id = payload.get("sub")
|
||||||
|
if user_id is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="无效的认证凭证"
|
||||||
|
)
|
||||||
|
|
||||||
|
user = db.query(User).filter(User.id == user_id).first()
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="用户不存在"
|
||||||
|
)
|
||||||
|
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=dict)
|
||||||
|
def register(user_data: UserRegister, db: Session = Depends(get_db)):
|
||||||
|
"""用户注册"""
|
||||||
|
# 检查用户名是否存在
|
||||||
|
existing_user = db.query(User).filter(User.username == user_data.username).first()
|
||||||
|
if existing_user:
|
||||||
|
return error_response("用户名已存在", 400)
|
||||||
|
|
||||||
|
# 检查邮箱是否存在
|
||||||
|
if user_data.email:
|
||||||
|
existing_email = db.query(User).filter(User.email == user_data.email).first()
|
||||||
|
if existing_email:
|
||||||
|
return error_response("邮箱已被注册", 400)
|
||||||
|
|
||||||
|
# 创建用户
|
||||||
|
password_hash = hash_password(user_data.password)
|
||||||
|
user = User(
|
||||||
|
username=user_data.username,
|
||||||
|
email=user_data.email,
|
||||||
|
password_hash=password_hash,
|
||||||
|
role="user"
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(user)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(user)
|
||||||
|
|
||||||
|
return success_response(
|
||||||
|
data={"id": user.id, "username": user.username},
|
||||||
|
message="注册成功"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=dict)
|
||||||
|
def login(user_data: UserLogin, db: Session = Depends(get_db)):
|
||||||
|
"""用户登录"""
|
||||||
|
user = db.query(User).filter(User.username == user_data.username).first()
|
||||||
|
|
||||||
|
if not user or not verify_password(user_data.password, user.password_hash or ""):
|
||||||
|
return error_response("用户名或密码错误", 401)
|
||||||
|
|
||||||
|
if not user.is_active:
|
||||||
|
return error_response("用户已被禁用", 403)
|
||||||
|
|
||||||
|
# 创建访问令牌
|
||||||
|
access_token = create_access_token(
|
||||||
|
data={"sub": user.id, "username": user.username},
|
||||||
|
expires_delta=timedelta(days=7)
|
||||||
|
)
|
||||||
|
|
||||||
|
return success_response(
|
||||||
|
data={
|
||||||
|
"access_token": access_token,
|
||||||
|
"token_type": "bearer",
|
||||||
|
"user": user.to_dict()
|
||||||
|
},
|
||||||
|
message="登录成功"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/logout")
|
||||||
|
def logout(current_user: User = Depends(get_current_user)):
|
||||||
|
"""用户登出(前端清除令牌即可)"""
|
||||||
|
return success_response(message="登出成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=dict)
|
||||||
|
def get_me(current_user: User = Depends(get_current_user)):
|
||||||
|
"""获取当前用户信息"""
|
||||||
|
return success_response(data=current_user.to_dict())
|
||||||
|
|
@ -0,0 +1,153 @@
|
||||||
|
"""会话路由"""
|
||||||
|
from typing import Optional, List
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from alcor.database import get_db
|
||||||
|
from alcor.models import Conversation, User
|
||||||
|
from alcor.routes.auth import get_current_user
|
||||||
|
from alcor.utils.helpers import generate_id, success_response, error_response, paginate
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/conversations", tags=["会话"])
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationCreate(BaseModel):
|
||||||
|
"""创建会话模型"""
|
||||||
|
project_id: Optional[str] = None
|
||||||
|
title: str = ""
|
||||||
|
model: str = "glm-5"
|
||||||
|
system_prompt: str = ""
|
||||||
|
temperature: float = 1.0
|
||||||
|
max_tokens: int = 65536
|
||||||
|
thinking_enabled: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
class ConversationUpdate(BaseModel):
|
||||||
|
"""更新会话模型"""
|
||||||
|
title: Optional[str] = None
|
||||||
|
model: Optional[str] = None
|
||||||
|
system_prompt: Optional[str] = None
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
thinking_enabled: Optional[bool] = None
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/", response_model=dict)
|
||||||
|
def list_conversations(
|
||||||
|
project_id: Optional[str] = None,
|
||||||
|
page: int = 1,
|
||||||
|
page_size: int = 20,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""获取会话列表"""
|
||||||
|
query = db.query(Conversation).filter(Conversation.user_id == current_user.id)
|
||||||
|
|
||||||
|
if project_id:
|
||||||
|
query = query.filter(Conversation.project_id == project_id)
|
||||||
|
|
||||||
|
query = query.order_by(Conversation.updated_at.desc())
|
||||||
|
|
||||||
|
result = paginate(query, page, page_size)
|
||||||
|
items = [conv.to_dict() for conv in result["items"]]
|
||||||
|
|
||||||
|
return success_response(data={
|
||||||
|
"items": items,
|
||||||
|
"total": result["total"],
|
||||||
|
"page": result["page"],
|
||||||
|
"page_size": result["page_size"]
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/", response_model=dict)
|
||||||
|
def create_conversation(
|
||||||
|
data: ConversationCreate,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""创建会话"""
|
||||||
|
conversation = Conversation(
|
||||||
|
id=generate_id("conv"),
|
||||||
|
user_id=current_user.id,
|
||||||
|
project_id=data.project_id,
|
||||||
|
title=data.title or "新会话",
|
||||||
|
model=data.model,
|
||||||
|
system_prompt=data.system_prompt,
|
||||||
|
temperature=data.temperature,
|
||||||
|
max_tokens=data.max_tokens,
|
||||||
|
thinking_enabled=data.thinking_enabled
|
||||||
|
)
|
||||||
|
|
||||||
|
db.add(conversation)
|
||||||
|
db.commit()
|
||||||
|
db.refresh(conversation)
|
||||||
|
|
||||||
|
return success_response(data=conversation.to_dict(), message="会话创建成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{conversation_id}", response_model=dict)
|
||||||
|
def get_conversation(
|
||||||
|
conversation_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""获取会话详情"""
|
||||||
|
conversation = db.query(Conversation).filter(
|
||||||
|
Conversation.id == conversation_id,
|
||||||
|
Conversation.user_id == current_user.id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
return error_response("会话不存在", 404)
|
||||||
|
|
||||||
|
return success_response(data=conversation.to_dict())
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/{conversation_id}", response_model=dict)
|
||||||
|
def update_conversation(
|
||||||
|
conversation_id: str,
|
||||||
|
data: ConversationUpdate,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""更新会话"""
|
||||||
|
conversation = db.query(Conversation).filter(
|
||||||
|
Conversation.id == conversation_id,
|
||||||
|
Conversation.user_id == current_user.id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
return error_response("会话不存在", 404)
|
||||||
|
|
||||||
|
# 更新字段
|
||||||
|
update_data = data.dict(exclude_unset=True)
|
||||||
|
for key, value in update_data.items():
|
||||||
|
setattr(conversation, key, value)
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(conversation)
|
||||||
|
|
||||||
|
return success_response(data=conversation.to_dict(), message="会话更新成功")
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{conversation_id}", response_model=dict)
|
||||||
|
def delete_conversation(
|
||||||
|
conversation_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""删除会话"""
|
||||||
|
conversation = db.query(Conversation).filter(
|
||||||
|
Conversation.id == conversation_id,
|
||||||
|
Conversation.user_id == current_user.id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
return error_response("会话不存在", 404)
|
||||||
|
|
||||||
|
db.delete(conversation)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return success_response(message="会话删除成功")
|
||||||
|
|
@ -0,0 +1,238 @@
|
||||||
|
"""消息路由"""
|
||||||
|
import json
|
||||||
|
from typing import Optional, List
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from alcor.database import get_db
|
||||||
|
from alcor.models import Conversation, Message, User
|
||||||
|
from alcor.routes.auth import get_current_user
|
||||||
|
from alcor.services.chat import chat_service
|
||||||
|
from alcor.utils.helpers import generate_id, success_response, error_response
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/messages", tags=["消息"])
|
||||||
|
|
||||||
|
|
||||||
|
class MessageCreate(BaseModel):
|
||||||
|
"""创建消息模型"""
|
||||||
|
conversation_id: str
|
||||||
|
content: str
|
||||||
|
tools_enabled: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
class MessageResponse(BaseModel):
|
||||||
|
"""消息响应模型"""
|
||||||
|
id: str
|
||||||
|
conversation_id: str
|
||||||
|
role: str
|
||||||
|
content: str
|
||||||
|
token_count: int
|
||||||
|
created_at: str
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{conversation_id}", response_model=dict)
|
||||||
|
def list_messages(
|
||||||
|
conversation_id: str,
|
||||||
|
limit: int = 100,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""获取消息列表"""
|
||||||
|
# 验证会话归属
|
||||||
|
conversation = db.query(Conversation).filter(
|
||||||
|
Conversation.id == conversation_id,
|
||||||
|
Conversation.user_id == current_user.id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
return error_response("会话不存在", 404)
|
||||||
|
|
||||||
|
messages = db.query(Message).filter(
|
||||||
|
Message.conversation_id == conversation_id
|
||||||
|
).order_by(Message.created_at.desc()).limit(limit).all()
|
||||||
|
|
||||||
|
items = [msg.to_dict() for msg in reversed(messages)]
|
||||||
|
|
||||||
|
return success_response(data={
|
||||||
|
"items": items,
|
||||||
|
"total": len(items)
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/", response_model=dict)
|
||||||
|
async def create_message(
|
||||||
|
data: MessageCreate,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""发送消息(非流式)"""
|
||||||
|
# 验证会话
|
||||||
|
conversation = db.query(Conversation).filter(
|
||||||
|
Conversation.id == data.conversation_id,
|
||||||
|
Conversation.user_id == current_user.id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
return error_response("会话不存在", 404)
|
||||||
|
|
||||||
|
# 保存用户消息
|
||||||
|
user_message = Message(
|
||||||
|
id=generate_id("msg"),
|
||||||
|
conversation_id=data.conversation_id,
|
||||||
|
role="user",
|
||||||
|
content=json.dumps({"text": data.content})
|
||||||
|
)
|
||||||
|
db.add(user_message)
|
||||||
|
|
||||||
|
# 更新会话时间
|
||||||
|
from datetime import datetime
|
||||||
|
conversation.updated_at = datetime.utcnow()
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(user_message)
|
||||||
|
|
||||||
|
# 获取AI响应(非流式)
|
||||||
|
response = chat_service.non_stream_response(
|
||||||
|
conversation=conversation,
|
||||||
|
user_message=data.content,
|
||||||
|
tools_enabled=data.tools_enabled
|
||||||
|
)
|
||||||
|
|
||||||
|
if not response.get("success"):
|
||||||
|
return error_response(response.get("error", "生成响应失败"), 500)
|
||||||
|
|
||||||
|
# 保存AI响应
|
||||||
|
ai_content = response.get("content", "")
|
||||||
|
ai_message = Message(
|
||||||
|
id=generate_id("msg"),
|
||||||
|
conversation_id=data.conversation_id,
|
||||||
|
role="assistant",
|
||||||
|
content=json.dumps({
|
||||||
|
"text": ai_content,
|
||||||
|
"tool_calls": response.get("tool_calls")
|
||||||
|
}),
|
||||||
|
token_count=len(ai_content) // 4 # 粗略估算
|
||||||
|
)
|
||||||
|
db.add(ai_message)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return success_response(data={
|
||||||
|
"user_message": user_message.to_dict(),
|
||||||
|
"assistant_message": ai_message.to_dict()
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/stream")
|
||||||
|
async def stream_message(
|
||||||
|
data: MessageCreate,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""发送消息(流式响应 - SSE)"""
|
||||||
|
# 验证会话
|
||||||
|
conversation = db.query(Conversation).filter(
|
||||||
|
Conversation.id == data.conversation_id,
|
||||||
|
Conversation.user_id == current_user.id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not conversation:
|
||||||
|
return error_response("会话不存在", 404)
|
||||||
|
|
||||||
|
# 保存用户消息
|
||||||
|
user_message = Message(
|
||||||
|
id=generate_id("msg"),
|
||||||
|
conversation_id=data.conversation_id,
|
||||||
|
role="user",
|
||||||
|
content=json.dumps({"text": data.content})
|
||||||
|
)
|
||||||
|
db.add(user_message)
|
||||||
|
|
||||||
|
# 更新会话时间
|
||||||
|
from datetime import datetime
|
||||||
|
conversation.updated_at = datetime.utcnow()
|
||||||
|
|
||||||
|
db.commit()
|
||||||
|
db.refresh(user_message)
|
||||||
|
|
||||||
|
async def event_generator():
|
||||||
|
"""SSE事件生成器"""
|
||||||
|
full_response = ""
|
||||||
|
message_id = generate_id("msg")
|
||||||
|
|
||||||
|
async for event in chat_service.stream_response(
|
||||||
|
conversation=conversation,
|
||||||
|
user_message=data.content,
|
||||||
|
tools_enabled=data.tools_enabled
|
||||||
|
):
|
||||||
|
event_type = event.get("type")
|
||||||
|
|
||||||
|
if event_type == "process_step":
|
||||||
|
step_type = event.get("step_type")
|
||||||
|
|
||||||
|
if step_type == "text":
|
||||||
|
content = event.get("content", "")
|
||||||
|
full_response += content
|
||||||
|
yield f"data: {json.dumps({'type': 'text', 'content': content})}\n\n"
|
||||||
|
|
||||||
|
elif step_type == "tool_call":
|
||||||
|
yield f"data: {json.dumps({'type': 'tool_call', 'tool_calls': event.get('tool_calls')})}\n\n"
|
||||||
|
|
||||||
|
elif step_type == "tool_result":
|
||||||
|
yield f"data: {json.dumps({'type': 'tool_result', 'result': event.get('result')})}\n\n"
|
||||||
|
|
||||||
|
elif event_type == "done":
|
||||||
|
# 保存AI消息
|
||||||
|
try:
|
||||||
|
ai_message = Message(
|
||||||
|
id=message_id,
|
||||||
|
conversation_id=data.conversation_id,
|
||||||
|
role="assistant",
|
||||||
|
content=json.dumps({"text": full_response}),
|
||||||
|
token_count=len(full_response) // 4
|
||||||
|
)
|
||||||
|
db.add(ai_message)
|
||||||
|
db.commit()
|
||||||
|
except Exception as e:
|
||||||
|
db.rollback()
|
||||||
|
|
||||||
|
yield f"data: {json.dumps({'type': 'done', 'message_id': message_id})}\n\n"
|
||||||
|
|
||||||
|
elif event_type == "error":
|
||||||
|
yield f"data: {json.dumps({'type': 'error', 'error': event.get('error')})}\n\n"
|
||||||
|
|
||||||
|
yield "data: [DONE]\n\n"
|
||||||
|
|
||||||
|
return StreamingResponse(
|
||||||
|
event_generator(),
|
||||||
|
media_type="text/event-stream",
|
||||||
|
headers={
|
||||||
|
"Cache-Control": "no-cache",
|
||||||
|
"Connection": "keep-alive",
|
||||||
|
"X-Accel-Buffering": "no"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{message_id}", response_model=dict)
|
||||||
|
def delete_message(
|
||||||
|
message_id: str,
|
||||||
|
current_user: User = Depends(get_current_user),
|
||||||
|
db: Session = Depends(get_db)
|
||||||
|
):
|
||||||
|
"""删除消息"""
|
||||||
|
# 获取消息及其会话
|
||||||
|
message = db.query(Message).join(Conversation).filter(
|
||||||
|
Message.id == message_id,
|
||||||
|
Conversation.user_id == current_user.id
|
||||||
|
).first()
|
||||||
|
|
||||||
|
if not message:
|
||||||
|
return error_response("消息不存在", 404)
|
||||||
|
|
||||||
|
db.delete(message)
|
||||||
|
db.commit()
|
||||||
|
|
||||||
|
return success_response(message="消息删除成功")
|
||||||
|
|
@ -0,0 +1,73 @@
|
||||||
|
"""工具路由"""
|
||||||
|
from typing import Optional, List, Dict, Any
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from alcor.database import get_db
|
||||||
|
from alcor.models import User
|
||||||
|
from alcor.routes.auth import get_current_user
|
||||||
|
from alcor.tools.core import registry
|
||||||
|
from alcor.utils.helpers import success_response
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/tools", tags=["工具"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/", response_model=dict)
|
||||||
|
def list_tools(
|
||||||
|
category: Optional[str] = None,
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""获取可用工具列表"""
|
||||||
|
if category:
|
||||||
|
tools = registry.list_by_category(category)
|
||||||
|
else:
|
||||||
|
tools = registry.list_all()
|
||||||
|
|
||||||
|
# 按分类分组
|
||||||
|
categorized = {}
|
||||||
|
for tool in tools:
|
||||||
|
cat = tool.get("function", {}).get("category", "general")
|
||||||
|
if cat not in categorized:
|
||||||
|
categorized[cat] = []
|
||||||
|
categorized[cat].append(tool)
|
||||||
|
|
||||||
|
return success_response(data={
|
||||||
|
"tools": tools,
|
||||||
|
"categorized": categorized,
|
||||||
|
"total": registry.tool_count
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{name}", response_model=dict)
|
||||||
|
def get_tool(
|
||||||
|
name: str,
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""获取工具详情"""
|
||||||
|
tool = registry.get(name)
|
||||||
|
|
||||||
|
if not tool:
|
||||||
|
return {"success": False, "message": "工具不存在", "code": 404}
|
||||||
|
|
||||||
|
return success_response(data={
|
||||||
|
"name": tool.name,
|
||||||
|
"description": tool.description,
|
||||||
|
"parameters": tool.parameters,
|
||||||
|
"category": tool.category
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{name}/execute", response_model=dict)
|
||||||
|
def execute_tool(
|
||||||
|
name: str,
|
||||||
|
arguments: Dict[str, Any],
|
||||||
|
current_user: User = Depends(get_current_user)
|
||||||
|
):
|
||||||
|
"""手动执行工具"""
|
||||||
|
result = registry.execute(name, arguments)
|
||||||
|
|
||||||
|
if not result.get("success"):
|
||||||
|
return {"success": False, "message": result.get("error"), "code": 400}
|
||||||
|
|
||||||
|
return success_response(data=result)
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""应用入口"""
|
||||||
|
import uvicorn
|
||||||
|
from alcor.config import config
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
"""启动应用"""
|
||||||
|
uvicorn.run(
|
||||||
|
"alcor:app",
|
||||||
|
host=config.app_host,
|
||||||
|
port=config.app_port,
|
||||||
|
reload=config.debug,
|
||||||
|
log_level="debug" if config.debug else "info"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -0,0 +1,11 @@
|
||||||
|
"""服务层模块"""
|
||||||
|
from alcor.services.llm_client import LLMClient, llm_client, LLMResponse
|
||||||
|
from alcor.services.chat import ChatService, chat_service
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"LLMClient",
|
||||||
|
"llm_client",
|
||||||
|
"LLMResponse",
|
||||||
|
"ChatService",
|
||||||
|
"chat_service"
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,262 @@
|
||||||
|
"""聊天服务模块"""
|
||||||
|
import json
|
||||||
|
import re
|
||||||
|
from typing import Dict, List, Optional, Any, Generator
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
from alcor.models import Conversation, Message
|
||||||
|
from alcor.tools.executor import ToolExecutor
|
||||||
|
from alcor.tools.core import registry
|
||||||
|
from alcor.services.llm_client import llm_client, LLMClient
|
||||||
|
|
||||||
|
|
||||||
|
# 最大迭代次数,防止无限循环
|
||||||
|
MAX_ITERATIONS = 10
|
||||||
|
|
||||||
|
|
||||||
|
class ChatService:
|
||||||
|
"""聊天服务"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
llm_client: Optional[LLMClient] = None,
|
||||||
|
max_iterations: int = MAX_ITERATIONS
|
||||||
|
):
|
||||||
|
self.llm_client = llm_client or llm_client
|
||||||
|
self.tool_executor = ToolExecutor(enable_cache=True, cache_ttl=300)
|
||||||
|
self.max_iterations = max_iterations
|
||||||
|
|
||||||
|
def build_messages(
|
||||||
|
self,
|
||||||
|
conversation: Conversation,
|
||||||
|
include_system: bool = True
|
||||||
|
) -> List[Dict[str, str]]:
|
||||||
|
"""构建消息列表"""
|
||||||
|
messages = []
|
||||||
|
|
||||||
|
# 添加系统提示
|
||||||
|
if include_system and conversation.system_prompt:
|
||||||
|
messages.append({
|
||||||
|
"role": "system",
|
||||||
|
"content": conversation.system_prompt
|
||||||
|
})
|
||||||
|
|
||||||
|
# 添加历史消息
|
||||||
|
for msg in conversation.messages.order_by(Message.created_at).all():
|
||||||
|
try:
|
||||||
|
content_data = json.loads(msg.content) if msg.content else {}
|
||||||
|
if isinstance(content_data, dict):
|
||||||
|
text = content_data.get("text", "")
|
||||||
|
else:
|
||||||
|
text = str(msg.content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
text = msg.content
|
||||||
|
|
||||||
|
messages.append({
|
||||||
|
"role": msg.role,
|
||||||
|
"content": text
|
||||||
|
})
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
def stream_response(
|
||||||
|
self,
|
||||||
|
conversation: Conversation,
|
||||||
|
user_message: str,
|
||||||
|
tools_enabled: bool = True,
|
||||||
|
context: Optional[Dict[str, Any]] = None
|
||||||
|
) -> Generator[Dict[str, Any], None, None]:
|
||||||
|
"""
|
||||||
|
流式响应生成器
|
||||||
|
|
||||||
|
生成事件类型:
|
||||||
|
- process_step: thinking/text/tool_call/tool_result 步骤
|
||||||
|
- done: 最终响应完成
|
||||||
|
- error: 出错时
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# 构建消息列表
|
||||||
|
messages = self.build_messages(conversation)
|
||||||
|
|
||||||
|
# 添加用户消息
|
||||||
|
messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": user_message
|
||||||
|
})
|
||||||
|
|
||||||
|
# 获取工具列表
|
||||||
|
tools = registry.list_all() if tools_enabled else None
|
||||||
|
|
||||||
|
# 迭代处理
|
||||||
|
iteration = 0
|
||||||
|
full_response = ""
|
||||||
|
tool_calls_buffer: List[Dict] = []
|
||||||
|
|
||||||
|
while iteration < self.max_iterations:
|
||||||
|
iteration += 1
|
||||||
|
|
||||||
|
# 调用LLM
|
||||||
|
tool_calls_this_round = None
|
||||||
|
|
||||||
|
for event in self.llm_client.stream(
|
||||||
|
model=conversation.model,
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
temperature=conversation.temperature,
|
||||||
|
max_tokens=conversation.max_tokens,
|
||||||
|
thinking_enabled=conversation.thinking_enabled
|
||||||
|
):
|
||||||
|
event_type = event.get("type")
|
||||||
|
|
||||||
|
if event_type == "content_delta":
|
||||||
|
# 内容增量
|
||||||
|
content = event.get("content", "")
|
||||||
|
if content:
|
||||||
|
full_response += content
|
||||||
|
yield {
|
||||||
|
"type": "process_step",
|
||||||
|
"step_type": "text",
|
||||||
|
"content": content
|
||||||
|
}
|
||||||
|
|
||||||
|
elif event_type == "done":
|
||||||
|
# 完成
|
||||||
|
tool_calls_this_round = event.get("tool_calls")
|
||||||
|
|
||||||
|
# 处理工具调用
|
||||||
|
if tool_calls_this_round and tools_enabled:
|
||||||
|
yield {
|
||||||
|
"type": "process_step",
|
||||||
|
"step_type": "tool_call",
|
||||||
|
"tool_calls": tool_calls_this_round
|
||||||
|
}
|
||||||
|
|
||||||
|
# 执行工具
|
||||||
|
tool_results = self.tool_executor.process_tool_calls_parallel(
|
||||||
|
tool_calls_this_round
|
||||||
|
)
|
||||||
|
|
||||||
|
for result in tool_results:
|
||||||
|
yield {
|
||||||
|
"type": "process_step",
|
||||||
|
"step_type": "tool_result",
|
||||||
|
"result": result
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加到消息历史
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": full_response,
|
||||||
|
"tool_calls": tool_calls_this_round
|
||||||
|
})
|
||||||
|
|
||||||
|
# 添加工具结果
|
||||||
|
for tr in tool_results:
|
||||||
|
messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tr.get("tool_call_id"),
|
||||||
|
"content": tr.get("content", ""),
|
||||||
|
"name": tr.get("name")
|
||||||
|
})
|
||||||
|
|
||||||
|
tool_calls_buffer.extend(tool_calls_this_round)
|
||||||
|
else:
|
||||||
|
# 没有工具调用,退出循环
|
||||||
|
break
|
||||||
|
|
||||||
|
# 如果没有更多工具调用,结束
|
||||||
|
if not tool_calls_this_round or not tools_enabled:
|
||||||
|
break
|
||||||
|
|
||||||
|
# 最终完成
|
||||||
|
yield {
|
||||||
|
"type": "done",
|
||||||
|
"content": full_response,
|
||||||
|
"tool_calls": tool_calls_buffer if tool_calls_buffer else None,
|
||||||
|
"iterations": iteration
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
yield {
|
||||||
|
"type": "error",
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
def non_stream_response(
|
||||||
|
self,
|
||||||
|
conversation: Conversation,
|
||||||
|
user_message: str,
|
||||||
|
tools_enabled: bool = True,
|
||||||
|
context: Optional[Dict[str, Any]] = None
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""非流式响应"""
|
||||||
|
try:
|
||||||
|
messages = self.build_messages(conversation)
|
||||||
|
messages.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": user_message
|
||||||
|
})
|
||||||
|
|
||||||
|
tools = registry.list_all() if tools_enabled else None
|
||||||
|
|
||||||
|
# 迭代处理
|
||||||
|
iteration = 0
|
||||||
|
full_response = ""
|
||||||
|
all_tool_calls = []
|
||||||
|
|
||||||
|
while iteration < self.max_iterations:
|
||||||
|
iteration += 1
|
||||||
|
|
||||||
|
response = self.llm_client.call(
|
||||||
|
model=conversation.model,
|
||||||
|
messages=messages,
|
||||||
|
tools=tools,
|
||||||
|
stream=False,
|
||||||
|
temperature=conversation.temperature,
|
||||||
|
max_tokens=conversation.max_tokens
|
||||||
|
)
|
||||||
|
|
||||||
|
full_response = response.content
|
||||||
|
tool_calls = response.tool_calls
|
||||||
|
|
||||||
|
if tool_calls and tools_enabled:
|
||||||
|
# 执行工具
|
||||||
|
tool_results = self.tool_executor.process_tool_calls_parallel(tool_calls)
|
||||||
|
all_tool_calls.extend(tool_calls)
|
||||||
|
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": full_response,
|
||||||
|
"tool_calls": tool_calls
|
||||||
|
})
|
||||||
|
|
||||||
|
for tr in tool_results:
|
||||||
|
messages.append({
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": tr.get("tool_call_id"),
|
||||||
|
"content": tr.get("content", ""),
|
||||||
|
"name": tr.get("name")
|
||||||
|
})
|
||||||
|
else:
|
||||||
|
messages.append({
|
||||||
|
"role": "assistant",
|
||||||
|
"content": full_response
|
||||||
|
})
|
||||||
|
break
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"content": full_response,
|
||||||
|
"tool_calls": all_tool_calls,
|
||||||
|
"iterations": iteration
|
||||||
|
}
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 全局聊天服务
|
||||||
|
chat_service = ChatService()
|
||||||
|
|
@ -0,0 +1,256 @@
|
||||||
|
"""LLM API客户端"""
|
||||||
|
import json
|
||||||
|
from typing import Dict, List, Optional, Generator, Any, Callable, AsyncGenerator
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from alcor.config import config
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LLMResponse:
|
||||||
|
"""LLM响应"""
|
||||||
|
content: str
|
||||||
|
tool_calls: Optional[List[Dict[str, Any]]] = None
|
||||||
|
usage: Optional[Dict[str, int]] = None
|
||||||
|
finish_reason: Optional[str] = None
|
||||||
|
raw: Optional[Dict] = None
|
||||||
|
|
||||||
|
|
||||||
|
class LLMClient:
|
||||||
|
"""LLM API客户端,支持多种提供商"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
|
api_url: Optional[str] = None,
|
||||||
|
provider: Optional[str] = None
|
||||||
|
):
|
||||||
|
self.api_key = api_key or config.llm_api_key
|
||||||
|
self.api_url = api_url or config.llm_api_url
|
||||||
|
self.provider = provider or config.llm_provider or self._detect_provider()
|
||||||
|
self._client: Optional[httpx.AsyncClient] = None
|
||||||
|
|
||||||
|
def _detect_provider(self) -> str:
|
||||||
|
"""检测提供商"""
|
||||||
|
url = self.api_url.lower()
|
||||||
|
if "deepseek" in url:
|
||||||
|
return "deepseek"
|
||||||
|
elif "bigmodel" in url or "glm" in url:
|
||||||
|
return "glm"
|
||||||
|
elif "zhipu" in url:
|
||||||
|
return "glm"
|
||||||
|
elif "qwen" in url or "dashscope" in url:
|
||||||
|
return "qwen"
|
||||||
|
elif "moonshot" in url or "moonshot" in url:
|
||||||
|
return "moonshot"
|
||||||
|
return "openai"
|
||||||
|
|
||||||
|
@property
|
||||||
|
def client(self) -> httpx.AsyncClient:
|
||||||
|
"""获取HTTP客户端"""
|
||||||
|
if self._client is None:
|
||||||
|
self._client = httpx.AsyncClient(
|
||||||
|
timeout=httpx.Timeout(120.0, connect=30.0),
|
||||||
|
headers={
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return self._client
|
||||||
|
|
||||||
|
async def close(self):
|
||||||
|
"""关闭客户端"""
|
||||||
|
if self._client:
|
||||||
|
await self._client.aclose()
|
||||||
|
self._client = None
|
||||||
|
|
||||||
|
def _build_headers(self) -> Dict[str, str]:
|
||||||
|
"""构建请求头"""
|
||||||
|
return {
|
||||||
|
"Authorization": f"Bearer {self.api_key}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
|
||||||
|
def _build_body(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
tools: Optional[List[Dict]] = None,
|
||||||
|
stream: bool = True,
|
||||||
|
**kwargs
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""构建请求体"""
|
||||||
|
body = {
|
||||||
|
"model": model,
|
||||||
|
"messages": messages,
|
||||||
|
"stream": stream
|
||||||
|
}
|
||||||
|
|
||||||
|
# 添加可选参数
|
||||||
|
if "temperature" in kwargs:
|
||||||
|
body["temperature"] = kwargs["temperature"]
|
||||||
|
if "max_tokens" in kwargs:
|
||||||
|
body["max_tokens"] = kwargs["max_tokens"]
|
||||||
|
if "top_p" in kwargs:
|
||||||
|
body["top_p"] = kwargs["top_p"]
|
||||||
|
if "thinking_enabled" in kwargs:
|
||||||
|
body["thinking_enabled"] = kwargs["thinking_enabled"]
|
||||||
|
|
||||||
|
# 添加工具
|
||||||
|
if tools:
|
||||||
|
body["tools"] = tools
|
||||||
|
|
||||||
|
return body
|
||||||
|
|
||||||
|
def _parse_response(self, data: Dict) -> LLMResponse:
|
||||||
|
"""解析响应"""
|
||||||
|
# 通用字段
|
||||||
|
content = ""
|
||||||
|
tool_calls = None
|
||||||
|
usage = None
|
||||||
|
finish_reason = None
|
||||||
|
|
||||||
|
# OpenAI格式
|
||||||
|
if "choices" in data:
|
||||||
|
choice = data["choices"][0]
|
||||||
|
message = choice.get("message", {})
|
||||||
|
content = message.get("content", "")
|
||||||
|
tool_calls = message.get("tool_calls")
|
||||||
|
finish_reason = choice.get("finish_reason")
|
||||||
|
|
||||||
|
# 使用量统计
|
||||||
|
if "usage" in data:
|
||||||
|
usage = {
|
||||||
|
"prompt_tokens": data["usage"].get("prompt_tokens", 0),
|
||||||
|
"completion_tokens": data["usage"].get("completion_tokens", 0),
|
||||||
|
"total_tokens": data["usage"].get("total_tokens", 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
return LLMResponse(
|
||||||
|
content=content,
|
||||||
|
tool_calls=tool_calls,
|
||||||
|
usage=usage,
|
||||||
|
finish_reason=finish_reason,
|
||||||
|
raw=data
|
||||||
|
)
|
||||||
|
|
||||||
|
async def call(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
tools: Optional[List[Dict]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> LLMResponse:
|
||||||
|
"""调用LLM API(非流式)"""
|
||||||
|
body = self._build_body(model, messages, tools, stream=False, **kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await self.client.post(
|
||||||
|
self.api_url,
|
||||||
|
json=body,
|
||||||
|
headers=self._build_headers()
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
return self._parse_response(data)
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
raise Exception(f"HTTP error: {e.response.status_code} - {e.response.text}")
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(f"LLM API error: {str(e)}")
|
||||||
|
|
||||||
|
async def stream(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
tools: Optional[List[Dict]] = None,
|
||||||
|
**kwargs
|
||||||
|
) -> AsyncGenerator[Dict[str, Any], None]:
|
||||||
|
"""流式调用LLM API"""
|
||||||
|
body = self._build_body(model, messages, tools, stream=True, **kwargs)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with self.client.stream(
|
||||||
|
"POST",
|
||||||
|
self.api_url,
|
||||||
|
json=body,
|
||||||
|
headers=self._build_headers()
|
||||||
|
) as response:
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
accumulated_content = ""
|
||||||
|
accumulated_tool_calls: Dict[int, Dict] = {}
|
||||||
|
|
||||||
|
async for line in response.aiter_lines():
|
||||||
|
if not line.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 跳过SSE前缀
|
||||||
|
if line.startswith("data: "):
|
||||||
|
line = line[6:]
|
||||||
|
|
||||||
|
if line == "[DONE]":
|
||||||
|
break
|
||||||
|
|
||||||
|
try:
|
||||||
|
chunk = json.loads(line)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 解析SSE数据
|
||||||
|
delta = chunk.get("choices", [{}])[0].get("delta", {})
|
||||||
|
|
||||||
|
# 内容增量
|
||||||
|
content_delta = delta.get("content", "")
|
||||||
|
if content_delta:
|
||||||
|
accumulated_content += content_delta
|
||||||
|
yield {
|
||||||
|
"type": "content_delta",
|
||||||
|
"content": content_delta,
|
||||||
|
"full_content": accumulated_content
|
||||||
|
}
|
||||||
|
|
||||||
|
# 工具调用增量
|
||||||
|
tool_calls = delta.get("tool_calls", [])
|
||||||
|
for tc in tool_calls:
|
||||||
|
index = tc.get("index", 0)
|
||||||
|
if index not in accumulated_tool_calls:
|
||||||
|
accumulated_tool_calls[index] = {
|
||||||
|
"id": "",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "", "arguments": ""}
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.get("id"):
|
||||||
|
accumulated_tool_calls[index]["id"] = tc["id"]
|
||||||
|
if tc.get("function", {}).get("name"):
|
||||||
|
accumulated_tool_calls[index]["function"]["name"] = tc["function"]["name"]
|
||||||
|
if tc.get("function", {}).get("arguments"):
|
||||||
|
accumulated_tool_calls[index]["function"]["arguments"] += tc["function"]["arguments"]
|
||||||
|
|
||||||
|
# 完成信号
|
||||||
|
finish_reason = chunk.get("choices", [{}])[0].get("finish_reason")
|
||||||
|
if finish_reason:
|
||||||
|
yield {
|
||||||
|
"type": "done",
|
||||||
|
"finish_reason": finish_reason,
|
||||||
|
"content": accumulated_content,
|
||||||
|
"tool_calls": list(accumulated_tool_calls.values()) if accumulated_tool_calls else None,
|
||||||
|
"usage": chunk.get("usage")
|
||||||
|
}
|
||||||
|
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
yield {
|
||||||
|
"type": "error",
|
||||||
|
"error": f"HTTP error: {e.response.status_code}"
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
yield {
|
||||||
|
"type": "error",
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# 全局LLM客户端
|
||||||
|
llm_client = LLMClient()
|
||||||
|
|
@ -0,0 +1,19 @@
|
||||||
|
"""工具系统模块"""
|
||||||
|
from alcor.tools.core import (
|
||||||
|
ToolDefinition,
|
||||||
|
ToolResult,
|
||||||
|
ToolRegistry,
|
||||||
|
registry
|
||||||
|
)
|
||||||
|
from alcor.tools.factory import tool, tool_function
|
||||||
|
from alcor.tools.executor import ToolExecutor
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"ToolDefinition",
|
||||||
|
"ToolResult",
|
||||||
|
"ToolRegistry",
|
||||||
|
"registry",
|
||||||
|
"tool",
|
||||||
|
"tool_function",
|
||||||
|
"ToolExecutor"
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
"""内置工具模块"""
|
||||||
|
# 导入所有内置工具以注册它们
|
||||||
|
from alcor.tools.builtin import crawler
|
||||||
|
from alcor.tools.builtin import code
|
||||||
|
from alcor.tools.builtin import data
|
||||||
|
|
||||||
|
__all__ = ["crawler", "code", "data"]
|
||||||
|
|
@ -0,0 +1,122 @@
|
||||||
|
"""代码执行工具"""
|
||||||
|
import json
|
||||||
|
import traceback
|
||||||
|
import ast
|
||||||
|
from typing import Dict, Any
|
||||||
|
|
||||||
|
from alcor.tools.factory import tool
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="python_execute",
|
||||||
|
description="Execute Python code and return the result",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"code": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Python code to execute"
|
||||||
|
},
|
||||||
|
"timeout": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Execution timeout in seconds",
|
||||||
|
"default": 30
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["code"]
|
||||||
|
},
|
||||||
|
category="code"
|
||||||
|
)
|
||||||
|
def python_execute(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
执行Python代码
|
||||||
|
|
||||||
|
注意:这是一个简化的执行器,生产环境应使用更安全的隔离环境
|
||||||
|
如:Docker容器、Pyodide等
|
||||||
|
"""
|
||||||
|
code = arguments.get("code", "")
|
||||||
|
timeout = arguments.get("timeout", 30)
|
||||||
|
|
||||||
|
if not code:
|
||||||
|
return {"success": False, "error": "Code is required"}
|
||||||
|
|
||||||
|
# 创建执行环境(允许大多数操作)
|
||||||
|
namespace = {
|
||||||
|
"__builtins__": __builtins__
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 编译并执行代码
|
||||||
|
compiled = compile(code, "<string>", "exec")
|
||||||
|
|
||||||
|
# 捕获输出
|
||||||
|
import io
|
||||||
|
from contextlib import redirect_stdout
|
||||||
|
|
||||||
|
output = io.StringIO()
|
||||||
|
|
||||||
|
with redirect_stdout(output):
|
||||||
|
exec(compiled, namespace)
|
||||||
|
|
||||||
|
result = output.getvalue()
|
||||||
|
|
||||||
|
# 尝试提取变量
|
||||||
|
result_vars = {k: v for k, v in namespace.items()
|
||||||
|
if not k.startswith("_") and k != "__builtins__"}
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"output": result,
|
||||||
|
"variables": {k: repr(v) for k, v in result_vars.items()},
|
||||||
|
"error": None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except SyntaxError as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Syntax error: {e}"
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Runtime error: {type(e).__name__}: {str(e)}"
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="python_eval",
|
||||||
|
description="Evaluate a Python expression and return the result",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"expression": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Python expression to evaluate"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["expression"]
|
||||||
|
},
|
||||||
|
category="code"
|
||||||
|
)
|
||||||
|
def python_eval(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""评估Python表达式"""
|
||||||
|
expression = arguments.get("expression", "")
|
||||||
|
|
||||||
|
if not expression:
|
||||||
|
return {"success": False, "error": "Expression is required"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = eval(expression)
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"result": repr(result),
|
||||||
|
"type": type(result).__name__
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Evaluation error: {str(e)}"
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,189 @@
|
||||||
|
"""网页爬虫工具"""
|
||||||
|
import requests
|
||||||
|
from typing import Dict, Any, List, Optional
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
|
||||||
|
from alcor.tools.factory import tool
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="web_search",
|
||||||
|
description="Search the internet for information using web search",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Search keywords"
|
||||||
|
},
|
||||||
|
"max_results": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Maximum number of results to return",
|
||||||
|
"default": 5
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
},
|
||||||
|
category="crawler"
|
||||||
|
)
|
||||||
|
def web_search(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
执行网络搜索
|
||||||
|
|
||||||
|
注意:这是一个占位实现,实际使用时需要接入真实的搜索API
|
||||||
|
如:Google Custom Search, DuckDuckGo, SerpAPI等
|
||||||
|
"""
|
||||||
|
query = arguments.get("query", "")
|
||||||
|
max_results = arguments.get("max_results", 5)
|
||||||
|
|
||||||
|
if not query:
|
||||||
|
return {"success": False, "error": "Query is required"}
|
||||||
|
|
||||||
|
# 模拟搜索结果
|
||||||
|
# 实际实现应接入真实搜索API
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"query": query,
|
||||||
|
"results": [
|
||||||
|
{
|
||||||
|
"title": f"Result for '{query}' - Example {i+1}",
|
||||||
|
"url": f"https://example.com/result_{i+1}",
|
||||||
|
"snippet": f"This is a sample search result for the query '{query}'. " * 3
|
||||||
|
}
|
||||||
|
for i in range(min(max_results, 5))
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="web_fetch",
|
||||||
|
description="Fetch and parse content from a web page",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"url": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "URL of the web page to fetch"
|
||||||
|
},
|
||||||
|
"extract_text": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Whether to extract text content only",
|
||||||
|
"default": True
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["url"]
|
||||||
|
},
|
||||||
|
category="crawler"
|
||||||
|
)
|
||||||
|
def web_fetch(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""获取并解析网页内容"""
|
||||||
|
url = arguments.get("url", "")
|
||||||
|
extract_text = arguments.get("extract_text", True)
|
||||||
|
|
||||||
|
if not url:
|
||||||
|
return {"success": False, "error": "URL is required"}
|
||||||
|
|
||||||
|
# 简单的URL验证
|
||||||
|
if not url.startswith(("http://", "https://")):
|
||||||
|
url = "https://" + url
|
||||||
|
|
||||||
|
try:
|
||||||
|
headers = {
|
||||||
|
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
|
||||||
|
}
|
||||||
|
response = requests.get(url, headers=headers, timeout=10)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
if extract_text:
|
||||||
|
soup = BeautifulSoup(response.text, "html.parser")
|
||||||
|
# 移除script和style标签
|
||||||
|
for tag in soup(["script", "style"]):
|
||||||
|
tag.decompose()
|
||||||
|
text = soup.get_text(separator="\n", strip=True)
|
||||||
|
# 清理多余空行
|
||||||
|
lines = [line.strip() for line in text.split("\n") if line.strip()]
|
||||||
|
text = "\n".join(lines)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"url": url,
|
||||||
|
"title": soup.title.string if soup.title else "",
|
||||||
|
"content": text[:10000] # 限制内容长度
|
||||||
|
}
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"url": url,
|
||||||
|
"html": response.text[:50000] # 限制HTML长度
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except requests.RequestException as e:
|
||||||
|
return {"success": False, "error": f"Failed to fetch URL: {str(e)}"}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="extract_links",
|
||||||
|
description="Extract all links from a web page",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"url": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "URL of the web page"
|
||||||
|
},
|
||||||
|
"max_links": {
|
||||||
|
"type": "integer",
|
||||||
|
"description": "Maximum number of links to extract",
|
||||||
|
"default": 20
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["url"]
|
||||||
|
},
|
||||||
|
category="crawler"
|
||||||
|
)
|
||||||
|
def extract_links(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""提取网页中的所有链接"""
|
||||||
|
url = arguments.get("url", "")
|
||||||
|
max_links = arguments.get("max_links", 20)
|
||||||
|
|
||||||
|
if not url:
|
||||||
|
return {"success": False, "error": "URL is required"}
|
||||||
|
|
||||||
|
if not url.startswith(("http://", "https://")):
|
||||||
|
url = "https://" + url
|
||||||
|
|
||||||
|
try:
|
||||||
|
headers = {
|
||||||
|
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36"
|
||||||
|
}
|
||||||
|
response = requests.get(url, headers=headers, timeout=10)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
soup = BeautifulSoup(response.text, "html.parser")
|
||||||
|
links = []
|
||||||
|
|
||||||
|
for a_tag in soup.find_all("a", href=True)[:max_links]:
|
||||||
|
href = a_tag["href"]
|
||||||
|
# 处理相对URL
|
||||||
|
if href.startswith("/"):
|
||||||
|
from urllib.parse import urljoin
|
||||||
|
href = urljoin(url, href)
|
||||||
|
links.append({
|
||||||
|
"text": a_tag.get_text(strip=True) or href,
|
||||||
|
"url": href
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"url": url,
|
||||||
|
"links": links
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except requests.RequestException as e:
|
||||||
|
return {"success": False, "error": f"Failed to fetch URL: {str(e)}"}
|
||||||
|
|
@ -0,0 +1,270 @@
|
||||||
|
"""数据处理工具"""
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import hashlib
|
||||||
|
import base64
|
||||||
|
import urllib.parse
|
||||||
|
from typing import Dict, Any, List
|
||||||
|
|
||||||
|
from alcor.tools.factory import tool
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="calculate",
|
||||||
|
description="Perform mathematical calculations",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"expression": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Mathematical expression to evaluate (e.g., '2 + 2', 'sqrt(16)', 'sin(pi/2)')"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["expression"]
|
||||||
|
},
|
||||||
|
category="data"
|
||||||
|
)
|
||||||
|
def calculate(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""执行数学计算"""
|
||||||
|
expression = arguments.get("expression", "")
|
||||||
|
|
||||||
|
if not expression:
|
||||||
|
return {"success": False, "error": "Expression is required"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
# 安全替换数学函数
|
||||||
|
safe_dict = {
|
||||||
|
"abs": abs,
|
||||||
|
"round": round,
|
||||||
|
"min": min,
|
||||||
|
"max": max,
|
||||||
|
"sum": sum,
|
||||||
|
"pow": pow,
|
||||||
|
"sqrt": lambda x: x ** 0.5,
|
||||||
|
"sin": lambda x: __import__("math").sin(x),
|
||||||
|
"cos": lambda x: __import__("math").cos(x),
|
||||||
|
"tan": lambda x: __import__("math").tan(x),
|
||||||
|
"log": lambda x: __import__("math").log(x),
|
||||||
|
"pi": __import__("math").pi,
|
||||||
|
"e": __import__("math").e,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 移除危险字符,只保留数字和运算符
|
||||||
|
safe_expr = re.sub(r"[^0-9+\-*/().%sqrtinsclogmaxminpowabsroundte, ]", "", expression)
|
||||||
|
result = eval(safe_expr, {"__builtins__": {}, **safe_dict})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"expression": expression,
|
||||||
|
"result": float(result) if isinstance(result, (int, float)) else result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except ZeroDivisionError:
|
||||||
|
return {"success": False, "error": "Division by zero"}
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": f"Calculation error: {str(e)}"}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="text_process",
|
||||||
|
description="Process and transform text",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Input text"
|
||||||
|
},
|
||||||
|
"operation": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Operation to perform: upper, lower, title, reverse, word_count, char_count, reverse_words",
|
||||||
|
"enum": ["upper", "lower", "title", "reverse", "word_count", "char_count", "reverse_words"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["text", "operation"]
|
||||||
|
},
|
||||||
|
category="data"
|
||||||
|
)
|
||||||
|
def text_process(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""文本处理"""
|
||||||
|
text = arguments.get("text", "")
|
||||||
|
operation = arguments.get("operation", "")
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
return {"success": False, "error": "Text is required"}
|
||||||
|
|
||||||
|
operations = {
|
||||||
|
"upper": lambda t: t.upper(),
|
||||||
|
"lower": lambda t: t.lower(),
|
||||||
|
"title": lambda t: t.title(),
|
||||||
|
"reverse": lambda t: t[::-1],
|
||||||
|
"word_count": lambda t: len(t.split()),
|
||||||
|
"char_count": lambda t: len(t),
|
||||||
|
"reverse_words": lambda t: " ".join(t.split()[::-1])
|
||||||
|
}
|
||||||
|
|
||||||
|
if operation not in operations:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Unknown operation: {operation}"
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = operations[operation](text)
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"operation": operation,
|
||||||
|
"input": text,
|
||||||
|
"result": result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="hash_text",
|
||||||
|
description="Generate hash of text using various algorithms",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Text to hash"
|
||||||
|
},
|
||||||
|
"algorithm": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Hash algorithm: md5, sha1, sha256, sha512",
|
||||||
|
"default": "sha256"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["text"]
|
||||||
|
},
|
||||||
|
category="data"
|
||||||
|
)
|
||||||
|
def hash_text(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""生成文本哈希"""
|
||||||
|
text = arguments.get("text", "")
|
||||||
|
algorithm = arguments.get("algorithm", "sha256")
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
return {"success": False, "error": "Text is required"}
|
||||||
|
|
||||||
|
hash_funcs = {
|
||||||
|
"md5": hashlib.md5,
|
||||||
|
"sha1": hashlib.sha1,
|
||||||
|
"sha256": hashlib.sha256,
|
||||||
|
"sha512": hashlib.sha512
|
||||||
|
}
|
||||||
|
|
||||||
|
if algorithm not in hash_funcs:
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"error": f"Unsupported algorithm: {algorithm}"
|
||||||
|
}
|
||||||
|
|
||||||
|
try:
|
||||||
|
hash_obj = hash_funcs[algorithm](text.encode("utf-8"))
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"algorithm": algorithm,
|
||||||
|
"hash": hash_obj.hexdigest()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="url_encode_decode",
|
||||||
|
description="URL encode or decode text",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Text to encode or decode"
|
||||||
|
},
|
||||||
|
"operation": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Operation: encode or decode",
|
||||||
|
"enum": ["encode", "decode"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["text", "operation"]
|
||||||
|
},
|
||||||
|
category="data"
|
||||||
|
)
|
||||||
|
def url_encode_decode(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""URL编码/解码"""
|
||||||
|
text = arguments.get("text", "")
|
||||||
|
operation = arguments.get("operation", "encode")
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
return {"success": False, "error": "Text is required"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if operation == "encode":
|
||||||
|
result = urllib.parse.quote(text)
|
||||||
|
else:
|
||||||
|
result = urllib.parse.unquote(text)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"operation": operation,
|
||||||
|
"input": text,
|
||||||
|
"result": result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
||||||
|
|
||||||
|
@tool(
|
||||||
|
name="base64_encode_decode",
|
||||||
|
description="Base64 encode or decode text",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"text": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Text to encode or decode"
|
||||||
|
},
|
||||||
|
"operation": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Operation: encode or decode",
|
||||||
|
"enum": ["encode", "decode"]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["text", "operation"]
|
||||||
|
},
|
||||||
|
category="data"
|
||||||
|
)
|
||||||
|
def base64_encode_decode(arguments: Dict[str, Any]) -> Dict[str, Any]:
|
||||||
|
"""Base64编码/解码"""
|
||||||
|
text = arguments.get("text", "")
|
||||||
|
operation = arguments.get("operation", "encode")
|
||||||
|
|
||||||
|
if not text:
|
||||||
|
return {"success": False, "error": "Text is required"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
if operation == "encode":
|
||||||
|
result = base64.b64encode(text.encode("utf-8")).decode("utf-8")
|
||||||
|
else:
|
||||||
|
result = base64.b64decode(text.encode("utf-8")).decode("utf-8")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"data": {
|
||||||
|
"operation": operation,
|
||||||
|
"input": text,
|
||||||
|
"result": result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
except Exception as e:
|
||||||
|
return {"success": False, "error": str(e)}
|
||||||
|
|
@ -0,0 +1,111 @@
|
||||||
|
"""工具系统核心模块"""
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Dict, Any, Callable, List, Optional, TypeVar, Generic
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolDefinition:
|
||||||
|
"""工具定义"""
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
parameters: Dict[str, Any] # JSON Schema
|
||||||
|
handler: Callable
|
||||||
|
category: str = "general"
|
||||||
|
|
||||||
|
def to_openai_format(self) -> Dict[str, Any]:
|
||||||
|
"""转换为OpenAI格式"""
|
||||||
|
return {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": self.name,
|
||||||
|
"description": self.description,
|
||||||
|
"parameters": self.parameters
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ToolResult:
|
||||||
|
"""工具执行结果"""
|
||||||
|
success: bool
|
||||||
|
data: Any = None
|
||||||
|
error: Optional[str] = None
|
||||||
|
|
||||||
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
|
"""转换为字典"""
|
||||||
|
return {"success": self.success, "data": self.data, "error": self.error}
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def ok(cls, data: Any) -> "ToolResult":
|
||||||
|
"""创建成功结果"""
|
||||||
|
return cls(success=True, data=data)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def fail(cls, error: str) -> "ToolResult":
|
||||||
|
"""创建失败结果"""
|
||||||
|
return cls(success=False, error=error)
|
||||||
|
|
||||||
|
|
||||||
|
class ToolRegistry:
|
||||||
|
"""工具注册表(单例模式)"""
|
||||||
|
_instance: Optional["ToolRegistry"] = None
|
||||||
|
_tools: Dict[str, ToolDefinition] = {}
|
||||||
|
|
||||||
|
def __new__(cls):
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
def register(self, tool: ToolDefinition) -> None:
|
||||||
|
"""注册工具"""
|
||||||
|
self._tools[tool.name] = tool
|
||||||
|
|
||||||
|
def get(self, name: str) -> Optional[ToolDefinition]:
|
||||||
|
"""获取工具定义"""
|
||||||
|
return self._tools.get(name)
|
||||||
|
|
||||||
|
def list_all(self) -> List[Dict[str, Any]]:
|
||||||
|
"""列出所有工具"""
|
||||||
|
return [t.to_openai_format() for t in self._tools.values()]
|
||||||
|
|
||||||
|
def list_by_category(self, category: str) -> List[Dict[str, Any]]:
|
||||||
|
"""按分类列出工具"""
|
||||||
|
return [
|
||||||
|
t.to_openai_format()
|
||||||
|
for t in self._tools.values()
|
||||||
|
if t.category == category
|
||||||
|
]
|
||||||
|
|
||||||
|
def execute(self, name: str, arguments: dict) -> Dict[str, Any]:
|
||||||
|
"""执行工具"""
|
||||||
|
tool = self.get(name)
|
||||||
|
if not tool:
|
||||||
|
return ToolResult.fail(f"Tool not found: {name}").to_dict()
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = tool.handler(arguments)
|
||||||
|
if isinstance(result, ToolResult):
|
||||||
|
return result.to_dict()
|
||||||
|
return ToolResult.ok(result).to_dict()
|
||||||
|
except Exception as e:
|
||||||
|
return ToolResult.fail(str(e)).to_dict()
|
||||||
|
|
||||||
|
def clear(self) -> None:
|
||||||
|
"""清空所有工具"""
|
||||||
|
self._tools.clear()
|
||||||
|
|
||||||
|
def remove(self, name: str) -> bool:
|
||||||
|
"""移除工具"""
|
||||||
|
if name in self._tools:
|
||||||
|
del self._tools[name]
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@property
|
||||||
|
def tool_count(self) -> int:
|
||||||
|
"""工具数量"""
|
||||||
|
return len(self._tools)
|
||||||
|
|
||||||
|
|
||||||
|
# 全局注册表实例
|
||||||
|
registry = ToolRegistry()
|
||||||
|
|
@ -0,0 +1,186 @@
|
||||||
|
"""工具执行器"""
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
import hashlib
|
||||||
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||||
|
from typing import List, Dict, Optional, Any
|
||||||
|
|
||||||
|
from alcor.tools.core import registry, ToolResult
|
||||||
|
|
||||||
|
|
||||||
|
class ToolExecutor:
|
||||||
|
"""工具执行器,支持缓存、并行执行"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
enable_cache: bool = True,
|
||||||
|
cache_ttl: int = 300, # 5分钟
|
||||||
|
max_workers: int = 4
|
||||||
|
):
|
||||||
|
self.enable_cache = enable_cache
|
||||||
|
self.cache_ttl = cache_ttl
|
||||||
|
self.max_workers = max_workers
|
||||||
|
self._cache: Dict[str, tuple] = {} # (result, timestamp)
|
||||||
|
self._call_history: List[Dict[str, Any]] = []
|
||||||
|
|
||||||
|
def _make_cache_key(self, name: str, args: dict) -> str:
|
||||||
|
"""生成缓存键"""
|
||||||
|
args_str = json.dumps(args, sort_keys=True, ensure_ascii=False)
|
||||||
|
return hashlib.md5(f"{name}:{args_str}".encode()).hexdigest()
|
||||||
|
|
||||||
|
def _is_cache_valid(self, cache_key: str) -> bool:
|
||||||
|
"""检查缓存是否有效"""
|
||||||
|
if cache_key not in self._cache:
|
||||||
|
return False
|
||||||
|
_, timestamp = self._cache[cache_key]
|
||||||
|
return (time.time() - timestamp) < self.cache_ttl
|
||||||
|
|
||||||
|
def _get_cached(self, cache_key: str) -> Optional[Dict]:
|
||||||
|
"""获取缓存结果"""
|
||||||
|
if self.enable_cache and self._is_cache_valid(cache_key):
|
||||||
|
return self._cache[cache_key][0]
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _set_cached(self, cache_key: str, result: Dict) -> None:
|
||||||
|
"""设置缓存"""
|
||||||
|
if self.enable_cache:
|
||||||
|
self._cache[cache_key] = (result, time.time())
|
||||||
|
|
||||||
|
def _record_call(self, name: str, args: dict, result: Dict) -> None:
|
||||||
|
"""记录调用历史"""
|
||||||
|
self._call_history.append({
|
||||||
|
"name": name,
|
||||||
|
"args": args,
|
||||||
|
"result": result,
|
||||||
|
"timestamp": time.time()
|
||||||
|
})
|
||||||
|
# 限制历史记录数量
|
||||||
|
if len(self._call_history) > 1000:
|
||||||
|
self._call_history = self._call_history[-500:]
|
||||||
|
|
||||||
|
def process_tool_calls(
|
||||||
|
self,
|
||||||
|
tool_calls: List[Dict[str, Any]],
|
||||||
|
context: Optional[Dict[str, Any]] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""顺序处理工具调用"""
|
||||||
|
results = []
|
||||||
|
|
||||||
|
for call in tool_calls:
|
||||||
|
name = call.get("function", {}).get("name", "")
|
||||||
|
args_str = call.get("function", {}).get("arguments", "{}")
|
||||||
|
call_id = call.get("id", "")
|
||||||
|
|
||||||
|
# 解析JSON参数
|
||||||
|
try:
|
||||||
|
args = json.loads(args_str) if isinstance(args_str, str) else args_str
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
results.append(self._create_error_result(call_id, name, "Invalid JSON arguments"))
|
||||||
|
continue
|
||||||
|
|
||||||
|
# 检查缓存
|
||||||
|
cache_key = self._make_cache_key(name, args)
|
||||||
|
cached_result = self._get_cached(cache_key)
|
||||||
|
|
||||||
|
if cached_result is not None:
|
||||||
|
result = cached_result
|
||||||
|
else:
|
||||||
|
# 执行工具
|
||||||
|
result = registry.execute(name, args)
|
||||||
|
self._set_cached(cache_key, result)
|
||||||
|
|
||||||
|
# 记录调用
|
||||||
|
self._record_call(name, args, result)
|
||||||
|
|
||||||
|
# 创建结果消息
|
||||||
|
results.append(self._create_tool_result(call_id, name, result))
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def process_tool_calls_parallel(
|
||||||
|
self,
|
||||||
|
tool_calls: List[Dict[str, Any]],
|
||||||
|
context: Optional[Dict[str, Any]] = None,
|
||||||
|
max_workers: Optional[int] = None
|
||||||
|
) -> List[Dict[str, Any]]:
|
||||||
|
"""并行处理工具调用"""
|
||||||
|
if len(tool_calls) <= 1:
|
||||||
|
return self.process_tool_calls(tool_calls, context)
|
||||||
|
|
||||||
|
workers = max_workers or self.max_workers
|
||||||
|
results = [None] * len(tool_calls)
|
||||||
|
exec_tasks = {}
|
||||||
|
|
||||||
|
# 解析所有参数
|
||||||
|
for i, call in enumerate(tool_calls):
|
||||||
|
try:
|
||||||
|
name = call.get("function", {}).get("name", "")
|
||||||
|
args_str = call.get("function", {}).get("arguments", "{}")
|
||||||
|
call_id = call.get("id", "")
|
||||||
|
args = json.loads(args_str) if isinstance(args_str, str) else args_str
|
||||||
|
exec_tasks[i] = (call_id, name, args)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
results[i] = self._create_error_result(
|
||||||
|
call.get("id", ""),
|
||||||
|
call.get("function", {}).get("name", ""),
|
||||||
|
"Invalid JSON"
|
||||||
|
)
|
||||||
|
|
||||||
|
# 并行执行
|
||||||
|
def run(call_id: str, name: str, args: dict) -> Dict[str, Any]:
|
||||||
|
# 检查缓存
|
||||||
|
cache_key = self._make_cache_key(name, args)
|
||||||
|
cached_result = self._get_cached(cache_key)
|
||||||
|
|
||||||
|
if cached_result is not None:
|
||||||
|
result = cached_result
|
||||||
|
else:
|
||||||
|
result = registry.execute(name, args)
|
||||||
|
self._set_cached(cache_key, result)
|
||||||
|
|
||||||
|
self._record_call(name, args, result)
|
||||||
|
return self._create_tool_result(call_id, name, result)
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=workers) as pool:
|
||||||
|
futures = {
|
||||||
|
pool.submit(run, cid, n, a): i
|
||||||
|
for i, (cid, n, a) in exec_tasks.items()
|
||||||
|
}
|
||||||
|
for future in as_completed(futures):
|
||||||
|
idx = futures[future]
|
||||||
|
try:
|
||||||
|
results[idx] = future.result()
|
||||||
|
except Exception as e:
|
||||||
|
results[idx] = self._create_error_result(
|
||||||
|
exec_tasks[idx][0] if idx in exec_tasks else "",
|
||||||
|
exec_tasks[idx][1] if idx in exec_tasks else "",
|
||||||
|
str(e)
|
||||||
|
)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
def _create_tool_result(self, call_id: str, name: str, result: Dict) -> Dict[str, Any]:
|
||||||
|
"""创建工具结果消息"""
|
||||||
|
return {
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": call_id,
|
||||||
|
"name": name,
|
||||||
|
"content": json.dumps(result, ensure_ascii=False, default=str)
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_error_result(self, call_id: str, name: str, error: str) -> Dict[str, Any]:
|
||||||
|
"""创建错误结果消息"""
|
||||||
|
return {
|
||||||
|
"role": "tool",
|
||||||
|
"tool_call_id": call_id,
|
||||||
|
"name": name,
|
||||||
|
"content": json.dumps({"success": False, "error": error})
|
||||||
|
}
|
||||||
|
|
||||||
|
def clear_cache(self) -> None:
|
||||||
|
"""清空缓存"""
|
||||||
|
self._cache.clear()
|
||||||
|
|
||||||
|
def get_history(self, limit: int = 100) -> List[Dict[str, Any]]:
|
||||||
|
"""获取调用历史"""
|
||||||
|
return self._call_history[-limit:]
|
||||||
|
|
@ -0,0 +1,57 @@
|
||||||
|
"""工具装饰器工厂"""
|
||||||
|
from typing import Callable, Any, Dict
|
||||||
|
from alcor.tools.core import ToolDefinition, registry
|
||||||
|
|
||||||
|
|
||||||
|
def tool(
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
parameters: Dict[str, Any],
|
||||||
|
category: str = "general"
|
||||||
|
) -> Callable:
|
||||||
|
"""
|
||||||
|
工具注册装饰器
|
||||||
|
|
||||||
|
用法示例:
|
||||||
|
```python
|
||||||
|
@tool(
|
||||||
|
name="web_search",
|
||||||
|
description="Search the internet for information",
|
||||||
|
parameters={
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"query": {"type": "string", "description": "Search keywords"},
|
||||||
|
"max_results": {"type": "integer", "description": "Max results", "default": 5}
|
||||||
|
},
|
||||||
|
"required": ["query"]
|
||||||
|
},
|
||||||
|
category="crawler"
|
||||||
|
)
|
||||||
|
def web_search(arguments: dict) -> dict:
|
||||||
|
# 实现...
|
||||||
|
return {"results": []}
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
def decorator(func: Callable) -> Callable:
|
||||||
|
tool_def = ToolDefinition(
|
||||||
|
name=name,
|
||||||
|
description=description,
|
||||||
|
parameters=parameters,
|
||||||
|
handler=func,
|
||||||
|
category=category
|
||||||
|
)
|
||||||
|
registry.register(tool_def)
|
||||||
|
return func
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
|
def tool_function(
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
parameters: Dict[str, Any],
|
||||||
|
category: str = "general"
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
工具装饰器的别名,提供更语义化的命名
|
||||||
|
"""
|
||||||
|
return tool(name=name, description=description, parameters=parameters, category=category)
|
||||||
|
|
@ -0,0 +1,22 @@
|
||||||
|
"""工具函数模块"""
|
||||||
|
from alcor.utils.helpers import (
|
||||||
|
generate_id,
|
||||||
|
hash_password,
|
||||||
|
verify_password,
|
||||||
|
create_access_token,
|
||||||
|
decode_access_token,
|
||||||
|
success_response,
|
||||||
|
error_response,
|
||||||
|
paginate
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"generate_id",
|
||||||
|
"hash_password",
|
||||||
|
"verify_password",
|
||||||
|
"create_access_token",
|
||||||
|
"decode_access_token",
|
||||||
|
"success_response",
|
||||||
|
"error_response",
|
||||||
|
"paginate"
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,97 @@
|
||||||
|
"""辅助工具模块"""
|
||||||
|
import shortuuid
|
||||||
|
import jwt
|
||||||
|
from datetime import datetime, timedelta
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
|
||||||
|
from alcor.config import config
|
||||||
|
|
||||||
|
|
||||||
|
def generate_id(prefix: str = "") -> str:
|
||||||
|
"""生成唯一ID"""
|
||||||
|
unique_id = shortuuid.uuid()
|
||||||
|
if prefix:
|
||||||
|
return f"{prefix}_{unique_id}"
|
||||||
|
return unique_id
|
||||||
|
|
||||||
|
|
||||||
|
def hash_password(password: str) -> str:
|
||||||
|
"""密码哈希"""
|
||||||
|
import bcrypt
|
||||||
|
salt = bcrypt.gensalt()
|
||||||
|
return bcrypt.hashpw(password.encode(), salt).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def verify_password(password: str, hashed: str) -> bool:
|
||||||
|
"""验证密码"""
|
||||||
|
import bcrypt
|
||||||
|
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||||
|
|
||||||
|
|
||||||
|
def create_access_token(data: Dict[str, Any], expires_delta: Optional[timedelta] = None) -> str:
|
||||||
|
"""创建JWT访问令牌"""
|
||||||
|
to_encode = data.copy()
|
||||||
|
|
||||||
|
if expires_delta:
|
||||||
|
expire = datetime.utcnow() + expires_delta
|
||||||
|
else:
|
||||||
|
expire = datetime.utcnow() + timedelta(hours=24)
|
||||||
|
|
||||||
|
to_encode.update({"exp": expire, "iat": datetime.utcnow()})
|
||||||
|
|
||||||
|
encoded_jwt = jwt.encode(
|
||||||
|
to_encode,
|
||||||
|
config.secret_key,
|
||||||
|
algorithm="HS256"
|
||||||
|
)
|
||||||
|
return encoded_jwt
|
||||||
|
|
||||||
|
|
||||||
|
def decode_access_token(token: str) -> Optional[Dict[str, Any]]:
|
||||||
|
"""解码JWT令牌"""
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token,
|
||||||
|
config.secret_key,
|
||||||
|
algorithms=["HS256"]
|
||||||
|
)
|
||||||
|
return payload
|
||||||
|
except jwt.ExpiredSignatureError:
|
||||||
|
return None
|
||||||
|
except jwt.InvalidTokenError:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def success_response(data: Any = None, message: str = "Success") -> Dict[str, Any]:
|
||||||
|
"""成功响应封装"""
|
||||||
|
return {
|
||||||
|
"success": True,
|
||||||
|
"message": message,
|
||||||
|
"data": data
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def error_response(message: str, code: int = 400, errors: Any = None) -> Dict[str, Any]:
|
||||||
|
"""错误响应封装"""
|
||||||
|
response = {
|
||||||
|
"success": False,
|
||||||
|
"message": message,
|
||||||
|
"code": code
|
||||||
|
}
|
||||||
|
if errors:
|
||||||
|
response["errors"] = errors
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
def paginate(query, page: int = 1, page_size: int = 20):
|
||||||
|
"""分页辅助"""
|
||||||
|
total = query.count()
|
||||||
|
items = query.offset((page - 1) * page_size).limit(page_size).all()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"items": items,
|
||||||
|
"total": total,
|
||||||
|
"page": page,
|
||||||
|
"page_size": page_size,
|
||||||
|
"total_pages": (total + page_size - 1) // page_size
|
||||||
|
}
|
||||||
|
|
@ -0,0 +1,21 @@
|
||||||
|
# 配置文件
|
||||||
|
app:
|
||||||
|
secret_key: ${APP_SECRET_KEY}
|
||||||
|
debug: true
|
||||||
|
host: 0.0.0.0
|
||||||
|
port: 8000
|
||||||
|
|
||||||
|
database:
|
||||||
|
type: sqlite
|
||||||
|
url: sqlite:///./chat.db
|
||||||
|
|
||||||
|
llm:
|
||||||
|
provider: deepseek
|
||||||
|
api_key: ${DEEPSEEK_API_KEY}
|
||||||
|
api_url: https://api.deepseek.com/v1
|
||||||
|
|
||||||
|
tools:
|
||||||
|
enable_cache: true
|
||||||
|
cache_ttl: 300
|
||||||
|
max_workers: 4
|
||||||
|
max_iterations: 10
|
||||||
|
|
@ -0,0 +1,259 @@
|
||||||
|
# API 接口文档
|
||||||
|
|
||||||
|
## 认证 `/api/auth`
|
||||||
|
|
||||||
|
### POST /api/auth/register
|
||||||
|
用户注册
|
||||||
|
|
||||||
|
**请求体:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"username": "string",
|
||||||
|
"email": "user@example.com",
|
||||||
|
"password": "string"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**响应:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"message": "注册成功",
|
||||||
|
"data": {
|
||||||
|
"id": 1,
|
||||||
|
"username": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### POST /api/auth/login
|
||||||
|
用户登录
|
||||||
|
|
||||||
|
**请求体:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"username": "string",
|
||||||
|
"password": "string"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**响应:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"message": "登录成功",
|
||||||
|
"data": {
|
||||||
|
"access_token": "eyJ...",
|
||||||
|
"token_type": "bearer",
|
||||||
|
"user": {
|
||||||
|
"id": 1,
|
||||||
|
"username": "string",
|
||||||
|
"role": "user"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### POST /api/auth/logout
|
||||||
|
用户登出
|
||||||
|
|
||||||
|
**请求头:** `Authorization: Bearer <token>`
|
||||||
|
|
||||||
|
**响应:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"message": "登出成功"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### GET /api/auth/me
|
||||||
|
获取当前用户信息
|
||||||
|
|
||||||
|
**请求头:** `Authorization: Bearer <token>`
|
||||||
|
|
||||||
|
**响应:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": {
|
||||||
|
"id": 1,
|
||||||
|
"username": "string",
|
||||||
|
"email": "user@example.com",
|
||||||
|
"role": "user",
|
||||||
|
"is_active": true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 会话 `/api/conversations`
|
||||||
|
|
||||||
|
### GET /api/conversations/
|
||||||
|
获取会话列表
|
||||||
|
|
||||||
|
**查询参数:**
|
||||||
|
- `project_id` (可选): 项目ID
|
||||||
|
- `page` (可选): 页码,默认1
|
||||||
|
- `page_size` (可选): 每页数量,默认20
|
||||||
|
|
||||||
|
**请求头:** `Authorization: Bearer <token>`
|
||||||
|
|
||||||
|
**响应:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": {
|
||||||
|
"items": [...],
|
||||||
|
"total": 100,
|
||||||
|
"page": 1,
|
||||||
|
"page_size": 20
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### POST /api/conversations/
|
||||||
|
创建会话
|
||||||
|
|
||||||
|
**请求头:** `Authorization: Bearer <token>`
|
||||||
|
|
||||||
|
**请求体:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"project_id": "string (可选)",
|
||||||
|
"title": "新会话",
|
||||||
|
"model": "glm-5",
|
||||||
|
"system_prompt": "string (可选)",
|
||||||
|
"temperature": 1.0,
|
||||||
|
"max_tokens": 65536,
|
||||||
|
"thinking_enabled": false
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**响应:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"message": "会话创建成功",
|
||||||
|
"data": {
|
||||||
|
"id": "conv_xxx",
|
||||||
|
"user_id": 1,
|
||||||
|
"title": "新会话",
|
||||||
|
"model": "glm-5",
|
||||||
|
...
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### GET /api/conversations/{id}
|
||||||
|
获取会话详情
|
||||||
|
|
||||||
|
**路径参数:**
|
||||||
|
- `id`: 会话ID
|
||||||
|
|
||||||
|
**请求头:** `Authorization: Bearer <token>`
|
||||||
|
|
||||||
|
### PUT /api/conversations/{id}
|
||||||
|
更新会话
|
||||||
|
|
||||||
|
### DELETE /api/conversations/{id}
|
||||||
|
删除会话
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 消息 `/api/messages`
|
||||||
|
|
||||||
|
### GET /api/messages/{conversation_id}
|
||||||
|
获取消息列表
|
||||||
|
|
||||||
|
**路径参数:**
|
||||||
|
- `conversation_id`: 会话ID
|
||||||
|
|
||||||
|
**查询参数:**
|
||||||
|
- `limit` (可选): 返回数量,默认100
|
||||||
|
|
||||||
|
**请求头:** `Authorization: Bearer <token>`
|
||||||
|
|
||||||
|
### POST /api/messages/
|
||||||
|
发送消息(非流式)
|
||||||
|
|
||||||
|
**请求头:** `Authorization: Bearer <token>`
|
||||||
|
|
||||||
|
**请求体:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"conversation_id": "conv_xxx",
|
||||||
|
"content": "用户消息",
|
||||||
|
"tools_enabled": true
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
**响应:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": {
|
||||||
|
"user_message": {...},
|
||||||
|
"assistant_message": {...}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### POST /api/messages/stream
|
||||||
|
发送消息(流式响应)
|
||||||
|
|
||||||
|
使用 Server-Sent Events (SSE) 返回流式响应。
|
||||||
|
|
||||||
|
**事件类型:**
|
||||||
|
- `text`: 文本增量
|
||||||
|
- `tool_call`: 工具调用
|
||||||
|
- `tool_result`: 工具结果
|
||||||
|
- `done`: 完成
|
||||||
|
- `error`: 错误
|
||||||
|
|
||||||
|
### DELETE /api/messages/{id}
|
||||||
|
删除消息
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 工具 `/api/tools`
|
||||||
|
|
||||||
|
### GET /api/tools/
|
||||||
|
获取可用工具列表
|
||||||
|
|
||||||
|
**查询参数:**
|
||||||
|
- `category` (可选): 工具分类
|
||||||
|
|
||||||
|
**请求头:** `Authorization: Bearer <token>`
|
||||||
|
|
||||||
|
**响应:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"success": true,
|
||||||
|
"data": {
|
||||||
|
"tools": [...],
|
||||||
|
"categorized": {
|
||||||
|
"crawler": [...],
|
||||||
|
"code": [...],
|
||||||
|
"data": [...],
|
||||||
|
"weather": [...]
|
||||||
|
},
|
||||||
|
"total": 11
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### GET /api/tools/{name}
|
||||||
|
获取工具详情
|
||||||
|
|
||||||
|
### POST /api/tools/{name}/execute
|
||||||
|
手动执行工具
|
||||||
|
|
||||||
|
**请求体:**
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"arg1": "value1",
|
||||||
|
"arg2": "value2"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
@ -0,0 +1,269 @@
|
||||||
|
# 项目架构
|
||||||
|
|
||||||
|
## 技术栈
|
||||||
|
|
||||||
|
- **框架**: FastAPI 0.109+
|
||||||
|
- **数据库**: SQLAlchemy 2.0+
|
||||||
|
- **认证**: JWT (PyJWT)
|
||||||
|
- **HTTP客户端**: httpx
|
||||||
|
- **配置**: YAML (PyYAML)
|
||||||
|
- **代码执行**: Python 原生执行
|
||||||
|
|
||||||
|
## 目录结构
|
||||||
|
|
||||||
|
```
|
||||||
|
alcor/
|
||||||
|
├── __init__.py # FastAPI 应用工厂
|
||||||
|
├── run.py # 入口文件
|
||||||
|
├── config.py # 配置管理(YAML)
|
||||||
|
├── database.py # 数据库连接
|
||||||
|
├── models.py # ORM 模型
|
||||||
|
├── routes/ # API 路由层
|
||||||
|
│ ├── auth.py # 认证
|
||||||
|
│ ├── conversations.py # 会话管理
|
||||||
|
│ ├── messages.py # 消息处理
|
||||||
|
│ └── tools.py # 工具管理
|
||||||
|
├── services/ # 服务层
|
||||||
|
│ ├── chat.py # 聊天服务
|
||||||
|
│ └── llm_client.py # LLM 客户端
|
||||||
|
├── tools/ # 工具系统
|
||||||
|
│ ├── core.py # 核心类
|
||||||
|
│ ├── factory.py # 装饰器
|
||||||
|
│ ├── executor.py # 执行器
|
||||||
|
│ └── builtin/ # 内置工具
|
||||||
|
│ ├── code.py # 代码执行
|
||||||
|
│ ├── crawler.py # 网页爬虫
|
||||||
|
│ ├── data.py # 数据处理
|
||||||
|
│ └── weather.py # 天气查询
|
||||||
|
└── utils/ # 工具函数
|
||||||
|
└── helpers.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## 核心组件
|
||||||
|
|
||||||
|
### 1. 应用工厂 (`__init__.py`)
|
||||||
|
FastAPI 应用入口,使用 lifespan 管理生命周期:
|
||||||
|
- 启动:初始化数据库、加载工具
|
||||||
|
- 关闭:清理资源
|
||||||
|
|
||||||
|
### 2. 配置管理 (`config.py`)
|
||||||
|
使用 YAML 文件管理配置:
|
||||||
|
- 配置文件:`config.yaml`
|
||||||
|
- 环境变量替换:`${VAR_NAME}`
|
||||||
|
- 单例模式全局访问
|
||||||
|
- 默认值支持
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# config.yaml 示例
|
||||||
|
app:
|
||||||
|
secret_key: ${APP_SECRET_KEY}
|
||||||
|
debug: true
|
||||||
|
|
||||||
|
llm:
|
||||||
|
provider: deepseek
|
||||||
|
api_key: ${DEEPSEEK_API_KEY}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. 数据库 (`database.py`)
|
||||||
|
- SQLAlchemy 异步支持
|
||||||
|
- SQLite 默认数据库
|
||||||
|
- 依赖注入获取会话
|
||||||
|
|
||||||
|
### 4. ORM 模型 (`models.py`)
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
erDiagram
|
||||||
|
USER {
|
||||||
|
int id PK
|
||||||
|
string username UK
|
||||||
|
string email UK
|
||||||
|
string password_hash
|
||||||
|
string role
|
||||||
|
boolean is_active
|
||||||
|
datetime created_at
|
||||||
|
}
|
||||||
|
|
||||||
|
CONVERSATION {
|
||||||
|
string id PK
|
||||||
|
int user_id FK
|
||||||
|
string project_id FK
|
||||||
|
string title
|
||||||
|
string model
|
||||||
|
text system_prompt
|
||||||
|
float temperature
|
||||||
|
int max_tokens
|
||||||
|
boolean thinking_enabled
|
||||||
|
datetime created_at
|
||||||
|
datetime updated_at
|
||||||
|
}
|
||||||
|
|
||||||
|
MESSAGE {
|
||||||
|
string id PK
|
||||||
|
string conversation_id FK
|
||||||
|
string role
|
||||||
|
longtext content
|
||||||
|
int token_count
|
||||||
|
datetime created_at
|
||||||
|
}
|
||||||
|
|
||||||
|
USER ||--o{ CONVERSATION : "has"
|
||||||
|
CONVERSATION ||--o{ MESSAGE : "has"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. 工具系统
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
classDiagram
|
||||||
|
class ToolDefinition {
|
||||||
|
+str name
|
||||||
|
+str description
|
||||||
|
+dict parameters
|
||||||
|
+Callable handler
|
||||||
|
+str category
|
||||||
|
+to_openai_format() dict
|
||||||
|
}
|
||||||
|
|
||||||
|
class ToolResult {
|
||||||
|
+bool success
|
||||||
|
+Any data
|
||||||
|
+str error
|
||||||
|
+to_dict() dict
|
||||||
|
+ok(data) ToolResult$
|
||||||
|
+fail(error) ToolResult$
|
||||||
|
}
|
||||||
|
|
||||||
|
class ToolRegistry {
|
||||||
|
+_tools: Dict
|
||||||
|
+register(tool) void
|
||||||
|
+get(name) ToolDefinition?
|
||||||
|
+list_all() List~dict~
|
||||||
|
+execute(name, arguments) dict
|
||||||
|
}
|
||||||
|
|
||||||
|
class ToolExecutor {
|
||||||
|
+registry: ToolRegistry
|
||||||
|
+enable_cache: bool
|
||||||
|
+cache_ttl: int
|
||||||
|
+_cache: Dict
|
||||||
|
+process_tool_calls(tool_calls, context) list
|
||||||
|
+process_tool_calls_parallel(tool_calls, context, max_workers) list
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 内置工具
|
||||||
|
|
||||||
|
| 工具 | 功能 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| `python_execute` | 执行 Python 代码 | 支持 print 输出、变量访问 |
|
||||||
|
| `python_eval` | 计算表达式 | 快速求值 |
|
||||||
|
| `web_crawl` | 网页抓取 | BeautifulSoup + httpx |
|
||||||
|
| `get_weather` | 天气查询 | 支持城市名查询 |
|
||||||
|
| `process_data` | 数据处理 | JSON 转换、格式化等 |
|
||||||
|
|
||||||
|
### 6. 服务层
|
||||||
|
|
||||||
|
#### ChatService (`services/chat.py`)
|
||||||
|
核心聊天服务:
|
||||||
|
- Agentic Loop 迭代执行
|
||||||
|
- 流式 SSE 响应
|
||||||
|
- 工具调用编排
|
||||||
|
- 消息历史管理
|
||||||
|
- 自动重试机制
|
||||||
|
|
||||||
|
#### LLMClient (`services/llm_client.py`)
|
||||||
|
LLM API 客户端:
|
||||||
|
- 多提供商:DeepSeek、GLM、OpenAI
|
||||||
|
- 流式/同步调用
|
||||||
|
- 错误处理和重试
|
||||||
|
- Token 计数
|
||||||
|
|
||||||
|
### 7. 认证系统 (`routes/auth.py`)
|
||||||
|
- JWT Bearer Token
|
||||||
|
- Bcrypt 密码哈希
|
||||||
|
- 用户注册/登录
|
||||||
|
|
||||||
|
### 8. API 路由
|
||||||
|
|
||||||
|
| 路由 | 方法 | 说明 |
|
||||||
|
|------|------|------|
|
||||||
|
| `/auth/register` | POST | 用户注册 |
|
||||||
|
| `/auth/login` | POST | 用户登录 |
|
||||||
|
| `/conversations` | GET/POST | 会话列表/创建 |
|
||||||
|
| `/conversations/{id}` | GET/DELETE | 会话详情/删除 |
|
||||||
|
| `/messages/stream` | POST | 流式消息发送 |
|
||||||
|
| `/tools` | GET | 可用工具列表 |
|
||||||
|
|
||||||
|
## 数据流
|
||||||
|
|
||||||
|
### 消息处理流程
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
sequenceDiagram
|
||||||
|
participant Client
|
||||||
|
participant API as POST /messages/stream
|
||||||
|
participant CS as ChatService
|
||||||
|
participant LLM as LLM API
|
||||||
|
participant TE as ToolExecutor
|
||||||
|
|
||||||
|
Client->>API: POST {content, tools}
|
||||||
|
API->>CS: stream_response()
|
||||||
|
|
||||||
|
loop MAX_ITERATIONS
|
||||||
|
CS->>LLM: call(messages, tools)
|
||||||
|
LLM-->>CS: SSE Stream
|
||||||
|
|
||||||
|
alt tool_calls
|
||||||
|
CS->>TE: process_tool_calls_parallel()
|
||||||
|
TE-->>CS: tool_results
|
||||||
|
CS->>CS: 追加到 messages
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
CS->>API: SSE Stream
|
||||||
|
API-->>Client: 流式响应
|
||||||
|
```
|
||||||
|
|
||||||
|
## SSE 事件
|
||||||
|
|
||||||
|
| 事件 | 说明 |
|
||||||
|
|------|------|
|
||||||
|
| `text` | 文本内容增量 |
|
||||||
|
| `tool_call` | 工具调用请求 |
|
||||||
|
| `tool_result` | 工具执行结果 |
|
||||||
|
| `done` | 响应完成 |
|
||||||
|
| `error` | 错误信息 |
|
||||||
|
|
||||||
|
## 配置示例
|
||||||
|
|
||||||
|
### config.yaml
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
app:
|
||||||
|
secret_key: ${APP_SECRET_KEY}
|
||||||
|
debug: true
|
||||||
|
host: 0.0.0.0
|
||||||
|
port: 8000
|
||||||
|
|
||||||
|
database:
|
||||||
|
type: sqlite
|
||||||
|
url: sqlite:///./chat.db
|
||||||
|
|
||||||
|
llm:
|
||||||
|
provider: deepseek
|
||||||
|
api_key: ${DEEPSEEK_API_KEY}
|
||||||
|
api_url: https://api.deepseek.com/v1
|
||||||
|
|
||||||
|
tools:
|
||||||
|
enable_cache: true
|
||||||
|
cache_ttl: 300
|
||||||
|
max_workers: 4
|
||||||
|
max_iterations: 10
|
||||||
|
```
|
||||||
|
|
||||||
|
## 环境变量
|
||||||
|
|
||||||
|
| 变量 | 说明 | 示例 |
|
||||||
|
|------|------|------|
|
||||||
|
| `APP_SECRET_KEY` | 应用密钥 | `your-secret-key` |
|
||||||
|
| `DEEPSEEK_API_KEY` | DeepSeek API | `sk-xxxx` |
|
||||||
|
| `DATABASE_URL` | 数据库连接 | `sqlite:///./chat.db` |
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
[project]
|
||||||
|
name = "alcor"
|
||||||
|
version = "1.0.0"
|
||||||
|
description = "Alcor - FastAPI + SQLAlchemy"
|
||||||
|
readme = "docs/README.md"
|
||||||
|
requires-python = ">=3.10"
|
||||||
|
|
||||||
|
dependencies = [
|
||||||
|
"fastapi>=0.109.0",
|
||||||
|
"uvicorn[standard]>=0.27.0",
|
||||||
|
"python-multipart>=0.0.6",
|
||||||
|
"sse-starlette>=2.0.0",
|
||||||
|
"sqlalchemy>=2.0.25",
|
||||||
|
"aiosqlite>=0.19.0",
|
||||||
|
"pyjwt>=2.8.0",
|
||||||
|
"bcrypt>=4.1.2",
|
||||||
|
"python-jose[cryptography]>=3.3.0",
|
||||||
|
"httpx>=0.26.0",
|
||||||
|
"requests>=2.31.0",
|
||||||
|
"beautifulsoup4>=4.12.3",
|
||||||
|
"lxml>=5.1.0",
|
||||||
|
"pyyaml>=6.0.1",
|
||||||
|
"shortuuid>=1.0.11",
|
||||||
|
"pydantic>=2.5.0",
|
||||||
|
"pydantic-settings>=2.1.0",
|
||||||
|
"email-validator>=2.1.0",
|
||||||
|
]
|
||||||
|
|
||||||
|
[project.optional-dependencies]
|
||||||
|
dev = [
|
||||||
|
"pytest>=8.0.0",
|
||||||
|
"pytest-asyncio>=0.23.0",
|
||||||
|
"pytest-cov>=4.1.0",
|
||||||
|
"black>=24.0.0",
|
||||||
|
"ruff>=0.1.0",
|
||||||
|
]
|
||||||
Loading…
Reference in New Issue