import json
import os
import shutil
import tempfile

import numpy as np
import pytest
import safetensors.torch as st
import torch
from tokenizers import pre_tokenizers
from torch.utils.data import Dataset

from astrai.config.model_config import ModelConfig
from astrai.model.transformer import Transformer
from astrai.tokenize import BpeTokenizer, BpeTrainer


def _resolve_length(length):
    """Return ``length``, or a random length in [100, 200) when it is ``None``."""
    if length is None:
        return int(np.random.randint(100, 200))
    return length


class RandomDataset(Dataset):
    """Random dataset for testing purposes.

    Every item is a dict with ``input_ids`` and ``target_ids``, each an int64
    tensor of shape ``(max_length,)`` with values drawn uniformly from
    ``[0, vocab_size)``.

    Args:
        length: Number of items; ``None`` picks a random length in [100, 200).
        max_length: Sequence length of every item.
        vocab_size: Exclusive upper bound for sampled token ids.
    """

    def __init__(self, length=None, max_length=64, vocab_size=1000):
        self.length = _resolve_length(length)
        self.max_length = max_length
        self.vocab_size = vocab_size

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        # ``idx`` is ignored: each access returns freshly sampled tokens.
        return {
            "input_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
            "target_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
        }


class MultiTurnDataset(Dataset):
    """Multi-turn dataset with loss mask for SFT training tests.

    Every item is a dict with ``input_ids``, ``target_ids`` (int64 tensors of
    shape ``(max_length,)`` with values in ``[0, vocab_size)``) and
    ``loss_mask`` (int64 tensor of shape ``(max_length,)`` with random 0/1
    entries).

    Args:
        length: Number of items; ``None`` picks a random length in [100, 200).
        max_length: Sequence length of every item.
        vocab_size: Exclusive upper bound for sampled token ids.
    """

    def __init__(self, length=None, max_length=64, vocab_size=1000):
        self.length = _resolve_length(length)
        self.max_length = max_length
        self.vocab_size = vocab_size

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        input_ids = torch.randint(0, self.vocab_size, (self.max_length,))
        target_ids = torch.randint(0, self.vocab_size, (self.max_length,))
        # torch.randint's upper bound is exclusive: (0, 2) yields a mix of 0s
        # and 1s. The previous (0, 1) produced an all-zero (fully masked) mask.
        loss_mask = torch.randint(0, 2, (self.max_length,))
        return {
            "input_ids": input_ids,
            "target_ids": target_ids,
            "loss_mask": loss_mask,
        }


class EarlyStoppingDataset(Dataset):
    """Dataset that triggers early stopping after a specified number of iterations.

    A per-instance counter is incremented on every ``__getitem__`` call. The
    ``stop_after``-th call raises ``RuntimeError("Simulated early stopping")``.
    Because the check is ``==``, it raises exactly once; later accesses succeed
    again.

    Args:
        length: Number of items.
        stop_after: 1-based access count at which the error is raised.
        max_length: Sequence length of every item.
        vocab_size: Exclusive upper bound for sampled token ids.
    """

    def __init__(self, length=10, stop_after=5, max_length=64, vocab_size=1000):
        self.length = length
        self.stop_after = stop_after
        self.max_length = max_length
        self.vocab_size = vocab_size
        self.count = 0

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        self.count += 1
        if self.count == self.stop_after:
            raise RuntimeError("Simulated early stopping")
        return {
            "input_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
            "target_ids": torch.randint(0, self.vocab_size, (self.max_length,)),
        }


def _make_test_dir(request):
    """Create a fresh temporary directory prefixed with the requesting test's name."""
    return tempfile.mkdtemp(prefix=f"{request.function.__name__}_")


def _write_config(test_dir, config):
    """Write ``config`` as ``config.json`` inside ``test_dir`` and return the path."""
    config_path = os.path.join(test_dir, "config.json")
    with open(config_path, "w", encoding="utf-8") as f:
        json.dump(config, f)
    return config_path


@pytest.fixture
def base_test_env(request: pytest.FixtureRequest):
    """Create base test environment with randomly configured model and tokenizer.

    The model config uses a random ``dim`` from {8, 16, 32} and ``n_heads``
    from {2, 4}, with ``n_kv_heads = n_heads // 2`` and ``dim_ffn = 2 * dim``.
    The model is moved to CUDA when available, else CPU. The tokenizer is a
    freshly constructed ``BpeTokenizer`` (not trained here).

    Yields a dict with keys ``device``, ``test_dir``, ``config_path``,
    ``transformer_config``, ``model`` and ``tokenizer``. The temporary
    directory is removed on teardown, and also if setup fails partway.
    """
    test_dir = _make_test_dir(request)
    try:
        # Draw dim before n_heads (same RNG call order as before).
        dim = int(np.random.choice([8, 16, 32]))
        n_heads = int(np.random.choice([2, 4]))
        config = {
            "vocab_size": 1000,
            "dim": dim,
            "n_heads": n_heads,
            "n_kv_heads": n_heads // 2,
            "dim_ffn": dim * 2,
            "max_len": 1024,
            "n_layers": 4,
            "norm_eps": 1e-5,
        }
        config_path = _write_config(test_dir, config)

        device = "cuda" if torch.cuda.is_available() else "cpu"
        transformer_config = ModelConfig().load(config_path)
        model = Transformer(transformer_config).to(device=device)
        tokenizer = BpeTokenizer()

        yield {
            "device": device,
            "test_dir": test_dir,
            "config_path": config_path,
            "transformer_config": transformer_config,
            "model": model,
            "tokenizer": tokenizer,
        }
    finally:
        shutil.rmtree(test_dir)


@pytest.fixture
def random_dataset():
    """Provide a ``RandomDataset`` with default parameters."""
    return RandomDataset()


@pytest.fixture
def multi_turn_dataset():
    """Provide a ``MultiTurnDataset`` with default parameters."""
    return MultiTurnDataset()


@pytest.fixture
def early_stopping_dataset():
    """Provide an ``EarlyStoppingDataset`` with default parameters."""
    return EarlyStoppingDataset()


@pytest.fixture
def test_env(request: pytest.FixtureRequest):
    """Create a test environment with saved model and tokenizer files.

    Writes ``config.json``, a trained ``tokenizer.json`` and the model weights
    as ``model.safetensors`` into a temporary directory.

    Yields a dict with keys ``test_dir``, ``model``, ``tokenizer``,
    ``transformer_config``, ``config_path``, ``tokenizer_path`` and
    ``model_path``. The directory is removed on teardown, and also if setup
    fails partway.
    """
    test_dir = _make_test_dir(request)
    try:
        config = {
            "vocab_size": 1000,
            "dim": 128,
            "n_heads": 4,
            "n_kv_heads": 2,
            "dim_ffn": 256,
            "max_len": 64,
            "n_layers": 2,
            "norm_eps": 1e-5,
        }
        config_path = _write_config(test_dir, config)
        tokenizer_path = os.path.join(test_dir, "tokenizer.json")
        model_path = os.path.join(test_dir, "model.safetensors")

        tokenizer = BpeTokenizer()
        trainer = BpeTrainer(tokenizer)
        # Train on the byte-level alphabet; presumably so every byte has a
        # token. NOTE(review): the meaning of the positional args
        # (vocab_size, 1) depends on BpeTrainer.train_from_iterator — confirm.
        sp_token_iter = iter(pre_tokenizers.ByteLevel.alphabet())
        trainer.train_from_iterator(sp_token_iter, config["vocab_size"], 1)
        tokenizer.save(tokenizer_path)

        transformer_config = ModelConfig().load(config_path)
        model = Transformer(transformer_config)
        st.save_file(model.state_dict(), model_path)

        yield {
            "test_dir": test_dir,
            "model": model,
            "tokenizer": tokenizer,
            "transformer_config": transformer_config,
            "config_path": config_path,
            "tokenizer_path": tokenizer_path,
            "model_path": model_path,
        }
    finally:
        shutil.rmtree(test_dir)