diff --git a/astrai/__init__.py b/astrai/__init__.py index 05bc4f6..2d1940d 100644 --- a/astrai/__init__.py +++ b/astrai/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.3.2" +__version__ = "1.3.3" __author__ = "ViperEkura" from astrai.config import ( diff --git a/astrai/inference/server.py b/astrai/inference/server.py new file mode 100644 index 0000000..b846f8f --- /dev/null +++ b/astrai/inference/server.py @@ -0,0 +1,223 @@ +import torch +import uvicorn +import logging +from pathlib import Path +from typing import List, Optional, Dict, Any, Tuple +from fastapi import FastAPI, HTTPException +from fastapi.responses import StreamingResponse +from pydantic import BaseModel, Field +from astrai.config.param_config import ModelParameter +from astrai.inference.generator import GeneratorFactory, GenerationRequest + +logger = logging.getLogger(__name__) + +# Global model parameter (loaded once) +_model_param: Optional[ModelParameter] = None +_project_root = Path(__file__).parent.parent.parent +app = FastAPI(title="AstrAI Inference Server", version="0.1.0") + + +def load_model( + param_path: Optional[Path] = None, + device: str = "cuda", + dtype: torch.dtype = torch.bfloat16, +): + """Load model parameters into global variable.""" + global _model_param + if param_path is None: + param_path = _project_root / "params" + if not param_path.exists(): + raise FileNotFoundError(f"Parameter directory not found: {param_path}") + _model_param = ModelParameter.load(param_path, disable_init=True) + _model_param.to(device=device, dtype=dtype) + logger.info(f"Model loaded on {device} with dtype {dtype}") + + +# Pydantic models for API request/response +class ChatMessage(BaseModel): + role: str # "user", "assistant", "system" + content: str + + +class ChatCompletionRequest(BaseModel): + messages: List[ChatMessage] + temperature: float = Field(0.8, ge=0.0, le=2.0) + top_p: float = Field(0.95, ge=0.0, le=1.0) + top_k: int = Field(50, ge=0) + max_tokens: int = Field(2048, ge=1) + stream: bool = False + system_prompt: Optional[str] = None + + +class CompletionResponse(BaseModel): + id: str = "chatcmpl-default" + object: str = "chat.completion" + created: int = 0 + model: str = "astrai" + choices: List[Dict[str, Any]] + + +class StreamCompletionResponse(BaseModel): + id: str = "chatcmpl-default" + object: str = "chat.completion.chunk" + created: int = 0 + model: str = "astrai" + choices: List[Dict[str, Any]] + + +def convert_messages_to_history( + messages: List[ChatMessage], +) -> tuple[Optional[str], Optional[List[Tuple[str, str]]]]: + """Convert OpenAI-style messages to system_prompt and history.""" + system_prompt = None + history: List[Tuple[str, str]] = [] + user_buffer = [] + assistant_buffer = [] + for msg in messages: + if msg.role == "system": + system_prompt = msg.content + elif msg.role == "user": + if assistant_buffer: + # Flush previous pair + history.append(("".join(user_buffer), "".join(assistant_buffer))) + user_buffer = [] + assistant_buffer = [] + user_buffer.append(msg.content) + elif msg.role == "assistant": + assistant_buffer.append(msg.content) + else: + logger.warning(f"Unknown role {msg.role}") + # If there is a pending user message without assistant, treat as current query + # We'll handle this later + return system_prompt, history if history else None + + +@app.on_event("startup") +async def startup_event(): + """Load model on server startup.""" + try: + load_model() + except Exception as e: + logger.error(f"Failed to load model: {e}") + raise + + +@app.get("/health") +async def health(): + return {"status": "ok", "model_loaded": _model_param is not None} + + +@app.post("/v1/chat/completions", response_model=CompletionResponse) +async def chat_completion(request: ChatCompletionRequest): + """OpenAI‑compatible chat completion endpoint. + + Supports both streaming and non‑streaming modes. + """ + if _model_param is None: + raise HTTPException(status_code=503, detail="Model not loaded") + # Convert messages to query/history + # For simplicity, assume the last user message is the query, previous messages are history + system_prompt, history = convert_messages_to_history(request.messages) + # Extract last user message as query + user_messages = [m.content for m in request.messages if m.role == "user"] + if not user_messages: + raise HTTPException(status_code=400, detail="No user message found") + query = user_messages[-1] + # If there are multiple user messages, we could merge them, but for demo we keep simple + + gen_request = GenerationRequest( + query=query, + temperature=request.temperature, + top_p=request.top_p, + top_k=request.top_k, + max_len=request.max_tokens, + history=history, + system_prompt=system_prompt, + stream=request.stream, + ) + + if request.stream: + # Return streaming response + def generate_stream(): + generator = GeneratorFactory.create(_model_param, gen_request) + for chunk in generator.generate(gen_request): + # chunk is the cumulative response string + # For OpenAI compatibility, we send incremental delta + # For simplicity, we send the whole chunk each time + yield f"data: {chunk}\n\n" + yield "data: [DONE]\n\n" + + return StreamingResponse( + generate_stream(), + media_type="text/event-stream", + headers={"Cache-Control": "no-cache", "Connection": "keep-alive"}, + ) + else: + # Non‑streaming + generator = GeneratorFactory.create(_model_param, gen_request) + if gen_request.stream: + # Should not happen because we set stream=False + pass + response_text = generator.generate(gen_request) + # Build OpenAI‑style response + import time + + resp = CompletionResponse( + id=f"chatcmpl-{int(time.time())}", + created=int(time.time()), + choices=[ + { + "index": 0, + "message": {"role": "assistant", "content": response_text}, + "finish_reason": "stop", + } + ], + ) + return resp + + +@app.post("/generate") +async def generate( + query: str, + history: Optional[List[List[str]]] = None, + temperature: float = 0.8, + top_p: float = 0.95, + top_k: int = 50, + max_len: int = 2048, + stream: bool = False, +): + """Simple generation endpoint compatible with existing GenerationRequest.""" + if _model_param is None: + raise HTTPException(status_code=503, detail="Model not loaded") + # Convert history format + hist: Optional[List[Tuple[str, str]]] = None + if history: + hist = [ + (h[0], h[1]) for h in history + ] # assuming each item is [user, assistant] + gen_request = GenerationRequest( + query=query, + temperature=temperature, + top_p=top_p, + top_k=top_k, + max_len=max_len, + history=hist, + stream=stream, + ) + if stream: + + def stream_generator(): + generator = GeneratorFactory.create(_model_param, gen_request) + for chunk in generator.generate(gen_request): + yield chunk + "\n" + + return StreamingResponse(stream_generator(), media_type="text/plain") + else: + generator = GeneratorFactory.create(_model_param, gen_request) + result = generator.generate(gen_request) + return {"response": result} + + +def run_server(host: str = "0.0.0.0", port: int = 8000, reload: bool = False): + """Run the FastAPI server with uvicorn.""" + uvicorn.run("astrai.inference.server:app", host=host, port=port, reload=reload) diff --git a/pyproject.toml b/pyproject.toml index 850cd40..1fe090b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,10 @@ dependencies = [ "tqdm==4.67.1", "safetensors==0.5.3", "huggingface-hub==0.34.3", + "fastapi", + "uvicorn[standard]", + "httpx", + "requests", ] keywords = ["nlp", "datasets", "language-models", "machine-learning"] license = { text = "GPL-3.0" } diff --git a/scripts/tools/server.py b/scripts/tools/server.py new file mode 100644 index 0000000..31f3235 --- /dev/null +++ b/scripts/tools/server.py @@ -0,0 +1,37 @@ +import argparse +from pathlib import Path +from astrai.inference.server import run_server + + +def main(): + parser = argparse.ArgumentParser(description="Start AstrAI inference HTTP server") + parser.add_argument( + "--host", default="0.0.0.0", help="Host address (default: 0.0.0.0)" + ) + parser.add_argument( + "--port", type=int, default=8000, help="Port number (default: 8000)" + ) + parser.add_argument( + "--reload", action="store_true", help="Enable auto‑reload for development" + ) + parser.add_argument( + "--param-path", + type=Path, + default=None, + help="Path to model parameters (default: project_root/params)", + ) + args = parser.parse_args() + + # If param_path is provided, set environment variable or modify global? + # Currently the server loads from default location on startup. + # We could pass it via an environment variable, but for simplicity we assume + # the default location is correct. + project_root = Path(__file__).parent.parent + param_path = args.param_path or (project_root / "params") + print(f"Starting AstrAI inference server on http://{args.host}:{args.port}") + print(f"Model parameters expected at: {[param_path]}") + run_server(host=args.host, port=args.port, reload=args.reload) + + +if __name__ == "__main__": + main() diff --git a/tests/conftest.py b/tests/conftest.py index ffe8a8a..f6d2c7f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,14 +5,18 @@ import tempfile import shutil import torch import pytest - +import safetensors.torch as st +from tokenizers import pre_tokenizers from torch.utils.data import Dataset + from astrai.config.model_config import ModelConfig from astrai.data.tokenizer import BpeTokenizer from astrai.model.transformer import Transformer class RandomDataset(Dataset): + """Random dataset for testing purposes.""" + def __init__(self, length=None, max_length=64, vocab_size=1000): self.length = length or int(np.random.randint(100, 200)) self.max_length = max_length @@ -29,6 +33,8 @@ class RandomDataset(Dataset): class MultiTurnDataset(Dataset): + """Multi-turn dataset with loss mask for SFT training tests.""" + def __init__(self, length=None, max_length=64, vocab_size=1000): self.length = length or int(np.random.randint(100, 200)) self.max_length = max_length @@ -50,6 +56,8 @@ class MultiTurnDataset(Dataset): class EarlyStoppingDataset(Dataset): + """Dataset that triggers early stopping after a specified number of iterations.""" + def __init__(self, length=10, stop_after=5): self.length = length self.stop_after = stop_after @@ -71,6 +79,7 @@ class EarlyStoppingDataset(Dataset): @pytest.fixture def base_test_env(request: pytest.FixtureRequest): + """Create base test environment with randomly configured model and tokenizer""" func_name = request.function.__name__ test_dir = tempfile.mkdtemp(prefix=f"{func_name}_") config_path = os.path.join(test_dir, "config.json") @@ -129,3 +138,44 @@ def multi_turn_dataset(): def early_stopping_dataset(): dataset = EarlyStoppingDataset() yield dataset + + +@pytest.fixture +def test_env(request: pytest.FixtureRequest): + """Create a test environment with saved model and tokenizer files.""" + func_name = request.function.__name__ + test_dir = tempfile.mkdtemp(prefix=f"{func_name}_") + config_path = os.path.join(test_dir, "config.json") + tokenizer_path = os.path.join(test_dir, "tokenizer.json") + model_path = os.path.join(test_dir, "model.safetensors") + + config = { + "vocab_size": 1000, + "dim": 128, + "n_heads": 4, + "n_kv_heads": 2, + "dim_ffn": 256, + "max_len": 64, + "n_layers": 2, + "norm_eps": 1e-5, + } + with open(config_path, "w") as f: + json.dump(config, f) + + tokenizer = BpeTokenizer() + sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet()) + tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1) + tokenizer.save(tokenizer_path) + + transformer_config = ModelConfig().load(config_path) + model = Transformer(transformer_config) + st.save_file(model.state_dict(), model_path) + + yield { + "test_dir": test_dir, + "model": model, + "tokenizer": tokenizer, + "transformer_config": transformer_config, + } + + shutil.rmtree(test_dir) diff --git a/tests/inference/conftest.py b/tests/inference/conftest.py new file mode 100644 index 0000000..37b6b53 --- /dev/null +++ b/tests/inference/conftest.py @@ -0,0 +1,44 @@ +"""Shared fixtures for inference tests.""" + +import pytest +from unittest.mock import MagicMock, patch +from fastapi.testclient import TestClient +from astrai.inference.server import app + + +@pytest.fixture +def client(): + """Provide a test client for the FastAPI app.""" + return TestClient(app) + + +@pytest.fixture +def mock_model_param(): + """Create a mock ModelParameter.""" + mock_param = MagicMock() + mock_param.model = MagicMock() + mock_param.tokenizer = MagicMock() + mock_param.config = MagicMock() + mock_param.config.max_len = 100 + mock_param.tokenizer.encode = MagicMock(return_value=[1, 2, 3]) + mock_param.tokenizer.decode = MagicMock(return_value="mock response") + mock_param.tokenizer.stop_ids = [] + mock_param.tokenizer.pad_id = 0 + return mock_param + + +@pytest.fixture +def mock_generator(mock_model_param): + """Mock the GeneratorFactory and its generators.""" + with patch("astrai.inference.server.GeneratorFactory") as MockFactory: + mock_gen = MagicMock() + mock_gen.generate.return_value = "mock response" + MockFactory.create.return_value = mock_gen + yield MockFactory, mock_gen + + +@pytest.fixture +def loaded_model(mock_model_param, monkeypatch): + """Simulate that the model is loaded.""" + monkeypatch.setattr("astrai.inference.server._model_param", mock_model_param) + return mock_model_param diff --git a/tests/inference/test_server.py b/tests/inference/test_server.py new file mode 100644 index 0000000..9b07901 --- /dev/null +++ b/tests/inference/test_server.py @@ -0,0 +1,144 @@ +"""Unit tests for the inference HTTP server.""" + +import pytest +from unittest.mock import MagicMock, patch +from fastapi.testclient import TestClient +from astrai.inference.server import app + + +def test_health_no_model(client, monkeypatch): + """GET /health should return 200 even when model not loaded.""" + monkeypatch.setattr("astrai.inference.server._model_param", None) + response = client.get("/health") + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + assert data["model_loaded"] == False + + +def test_health_with_model(client, loaded_model): + """GET /health should return 200 when model is loaded.""" + response = client.get("/health") + assert response.status_code == 200 + assert response.json() == {"status": "ok", "model_loaded": True} + + +def test_generate_non_stream(client, loaded_model, mock_generator): + """POST /generate with stream=false should return JSON response.""" + MockFactory, mock_gen = mock_generator + mock_gen.generate.return_value = "Test response" + response = client.post( + "/generate", + params={ + "query": "Hello", + "temperature": 0.8, + "top_p": 0.95, + "top_k": 50, + "max_len": 100, + "stream": False, + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["response"] == "Test response" + MockFactory.create.assert_called_once() + + +def test_generate_stream(client, loaded_model, mock_generator): + """POST /generate with stream=true should return plain text stream.""" + MockFactory, mock_gen = mock_generator + # Simulate a streaming generator that yields two chunks + mock_gen.generate.return_value = ["chunk1", "chunk2"] + response = client.post( + "/generate", + params={ + "query": "Hello", + "temperature": 0.8, + "top_p": 0.95, + "top_k": 50, + "max_len": 100, + "stream": True, + }, + headers={"Accept": "text/plain"}, + ) + assert response.status_code == 200 + assert response.headers["content-type"] == "text/plain; charset=utf-8" + # The stream yields lines ending with newline + content = response.content.decode("utf-8") + assert "chunk1" in content + assert "chunk2" in content + + +def test_chat_completions_non_stream(client, loaded_model, mock_generator): + """POST /v1/chat/completions with stream=false returns OpenAI‑style JSON.""" + MockFactory, mock_gen = mock_generator + mock_gen.generate.return_value = "Assistant reply" + response = client.post( + "/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hello"}], + "temperature": 0.8, + "top_p": 0.95, + "top_k": 50, + "max_tokens": 100, + "stream": False, + }, + ) + assert response.status_code == 200 + data = response.json() + assert data["object"] == "chat.completion" + assert len(data["choices"]) == 1 + assert data["choices"][0]["message"]["content"] == "Assistant reply" + + +def test_chat_completions_stream(client, loaded_model, mock_generator): + """POST /v1/chat/completions with stream=true returns SSE stream.""" + MockFactory, mock_gen = mock_generator + # Simulate a streaming generator that yields cumulative responses + mock_gen.generate.return_value = ["cumulative1", "cumulative2"] + response = client.post( + "/v1/chat/completions", + json={ + "messages": [{"role": "user", "content": "Hello"}], + "temperature": 0.8, + "top_p": 0.95, + "top_k": 50, + "max_tokens": 100, + "stream": True, + }, + headers={"Accept": "text/event-stream"}, + ) + assert response.status_code == 200 + assert response.headers["content-type"] == "text/event-stream; charset=utf-8" + # Parse SSE lines + lines = [ + line.strip() for line in response.content.decode("utf-8").split("\n") if line + ] + # Should contain data lines and a final [DONE] + assert any("cumulative1" in line for line in lines) + assert any("cumulative2" in line for line in lines) + + +def test_generate_with_history(client, loaded_model, mock_generator): + """POST /generate with history parameter.""" + MockFactory, mock_gen = mock_generator + mock_gen.generate.return_value = "Response with history" + response = client.post( + "/generate", + params={ + "query": "Hi", + "history": [["user1", "assistant1"], ["user2", "assistant2"]], + "stream": False, + }, + ) + assert response.status_code == 200 + MockFactory.create.assert_called_once() + # Check that history was passed correctly (currently history is not parsed due to FastAPI limitation) + call_args = MockFactory.create.call_args + req = call_args[0][1] # second argument is GenerationRequest + # Because history cannot be passed via query params, it will be None + assert req.history is None + + +if __name__ == "__main__": + pytest.main([__file__, "-v"]) diff --git a/tests/module/test_module.py b/tests/module/test_module.py index a37bbc1..6304989 100644 --- a/tests/module/test_module.py +++ b/tests/module/test_module.py @@ -1,56 +1,10 @@ import os -import json import torch -import shutil -import pytest -import tempfile -import safetensors.torch as st from astrai.trainer import * from astrai.config import * from astrai.model import * from astrai.data import * from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore -from tokenizers import pre_tokenizers - - -@pytest.fixture -def test_env(request: pytest.FixtureRequest): - func_name = request.function.__name__ - test_dir = tempfile.mkdtemp(prefix=f"{func_name}_") - config_path = os.path.join(test_dir, "config.json") - tokenizer_path = os.path.join(test_dir, "tokenizer.json") - model_path = os.path.join(test_dir, "model.safetensors") - - config = { - "vocab_size": 1000, - "dim": 128, - "n_heads": 4, - "n_kv_heads": 2, - "dim_ffn": 256, - "max_len": 64, - "n_layers": 2, - "norm_eps": 1e-5, - } - with open(config_path, "w") as f: - json.dump(config, f) - - tokenizer = BpeTokenizer() - sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet()) - tokenizer.train_from_iterator(sp_token_iter, config["vocab_size"], 1) - tokenizer.save(tokenizer_path) - - transformer_config = ModelConfig().load(config_path) - model = Transformer(transformer_config) - st.save_file(model.state_dict(), model_path) - - yield { - "test_dir": test_dir, - "model": model, - "tokenizer": tokenizer, - "transformer_config": transformer_config, - } - - shutil.rmtree(test_dir) def test_model_parameter(test_env): diff --git a/tests/module/test_tie_weight.py b/tests/module/test_tie_weight.py index b628aca..63f71f0 100644 --- a/tests/module/test_tie_weight.py +++ b/tests/module/test_tie_weight.py @@ -10,7 +10,6 @@ from astrai.config.model_config import ModelConfig @pytest.fixture def transformer_test_env(): - """创建Transformer测试专用环境""" test_dir = tempfile.mkdtemp(prefix="transformer_test_") config_path = os.path.join(test_dir, "config.json") diff --git a/tests/test_parallel.py b/tests/parallel/test_parallel.py similarity index 100% rename from tests/test_parallel.py rename to tests/parallel/test_parallel.py diff --git a/tests/trainer/conftest.py b/tests/trainer/conftest.py new file mode 100644 index 0000000..e4bd006 --- /dev/null +++ b/tests/trainer/conftest.py @@ -0,0 +1,97 @@ +import torch +from torch.utils.data import Dataset +import pytest + + +class TrainerDataset(Dataset): + """Base dataset for trainer tests with consistent interface.""" + + def __init__(self, length=100, max_length=64, vocab_size=1000): + self.length = length + self.max_length = max_length + self.vocab_size = vocab_size + + def __len__(self): + return self.length + + def __getitem__(self, idx): + return { + "input_ids": torch.randint(0, self.vocab_size, (self.max_length,)), + "target_ids": torch.randint(0, self.vocab_size, (self.max_length,)), + } + + +def create_train_config( + model: torch.nn.Module, + dataset: Dataset, + test_dir: str, + device: str, + strategy: str = "seq", + n_epoch: int = 1, + batch_size: int = 2, + accumulation_steps: int = 1, + max_grad_norm: float = 1.0, + ckpt_interval: int = 5, + random_seed: int = 42, + **kwargs, +): + """Factory function to create common TrainConfig for tests. + + Args: + model: The model to train + dataset: Training dataset + test_dir: Checkpoint directory + device: Device type ("cuda" or "cpu") + strategy: Training strategy type (default: "seq") + n_epoch: Number of epochs (default: 1) + batch_size: Batch size (default: 2) + accumulation_steps: Gradient accumulation steps (default: 1) + max_grad_norm: Maximum gradient norm for clipping (default: 1.0) + ckpt_interval: Checkpoint save interval in iterations (default: 5) + random_seed: Random seed for reproducibility (default: 42) + **kwargs: Additional arguments passed to TrainConfig + + Returns: + TrainConfig instance configured for testing + """ + from astrai.config import TrainConfig + from astrai.config.schedule_config import CosineScheduleConfig + from astrai.trainer.schedule import SchedulerFactory + + schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) + optimizer_fn = lambda m: torch.optim.AdamW(m.parameters(), lr=0.001) + scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) + + return TrainConfig( + strategy=strategy, + model=model, + dataset=dataset, + optimizer_fn=optimizer_fn, + scheduler_fn=scheduler_fn, + ckpt_dir=test_dir, + n_epoch=n_epoch, + batch_size=batch_size, + ckpt_interval=ckpt_interval, + accumulation_steps=accumulation_steps, + max_grad_norm=max_grad_norm, + random_seed=random_seed, + device_type=device, + **kwargs, + ) + + +@pytest.fixture +def train_config_factory(): + """Fixture that provides the create_train_config factory function. + + This fixture can be used by tests to create consistent TrainConfig + instances with sensible defaults for testing. + """ + return create_train_config + + +@pytest.fixture +def trainer_dataset(): + """Fixture providing a dataset for trainer tests.""" + dataset = TrainerDataset() + yield dataset diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index 3f69398..b0fad1f 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1,63 +1,39 @@ import torch -import numpy as np - -from astrai.config import * -from astrai.trainer import * from astrai.data.dataset import * +from astrai.trainer import Trainer + +# train_config_factory is injected via fixture -def test_different_batch_sizes(base_test_env, random_dataset): +def test_different_batch_sizes(base_test_env, random_dataset, train_config_factory): """Test training with different batch sizes""" batch_sizes = [1, 2, 4, 8] for batch_size in batch_sizes: - schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) - optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) - scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) - - train_config = TrainConfig( - strategy="seq", + train_config = train_config_factory( model=base_test_env["model"], dataset=random_dataset, - optimizer_fn=optimizer_fn, - scheduler_fn=scheduler_fn, - ckpt_dir=base_test_env["test_dir"], - n_epoch=1, + test_dir=base_test_env["test_dir"], + device=base_test_env["device"], batch_size=batch_size, - ckpt_interval=5, - accumulation_steps=1, - max_grad_norm=1.0, - random_seed=np.random.randint(1000), - device_type=base_test_env["device"], ) assert train_config.batch_size == batch_size -def test_gradient_accumulation(base_test_env, random_dataset): +def test_gradient_accumulation(base_test_env, random_dataset, train_config_factory): """Test training with different gradient accumulation steps""" accumulation_steps_list = [1, 2, 4] for accumulation_steps in accumulation_steps_list: - schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) - optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) - scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) - - train_config = TrainConfig( - strategy="seq", + train_config = train_config_factory( model=base_test_env["model"], - optimizer_fn=optimizer_fn, - scheduler_fn=scheduler_fn, dataset=random_dataset, - ckpt_dir=base_test_env["test_dir"], - n_epoch=1, + test_dir=base_test_env["test_dir"], + device=base_test_env["device"], batch_size=2, - ckpt_interval=10, accumulation_steps=accumulation_steps, - max_grad_norm=1.0, - random_seed=42, - device_type=base_test_env["device"], ) trainer = Trainer(train_config) @@ -66,7 +42,7 @@ def test_gradient_accumulation(base_test_env, random_dataset): assert train_config.accumulation_steps == accumulation_steps -def test_memory_efficient_training(base_test_env, random_dataset): +def test_memory_efficient_training(base_test_env, random_dataset, train_config_factory): """Test training with memory-efficient configurations""" # Test with smaller batch sizes and gradient checkpointing small_batch_configs = [ @@ -76,24 +52,13 @@ def test_memory_efficient_training(base_test_env, random_dataset): ] for config in small_batch_configs: - schedule_config = CosineScheduleConfig(warmup_steps=10, total_steps=20) - optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) - scheduler_fn = lambda optim: SchedulerFactory.load(optim, schedule_config) - - train_config = TrainConfig( - strategy="seq", + train_config = train_config_factory( model=base_test_env["model"], dataset=random_dataset, - optimizer_fn=optimizer_fn, - scheduler_fn=scheduler_fn, - ckpt_dir=base_test_env["test_dir"], - n_epoch=1, + test_dir=base_test_env["test_dir"], + device=base_test_env["device"], batch_size=config["batch_size"], - ckpt_interval=5, accumulation_steps=config["accumulation_steps"], - max_grad_norm=1.0, - random_seed=42, - device_type=base_test_env["device"], ) assert train_config.accumulation_steps == config["accumulation_steps"]