diff --git a/astrai/__init__.py b/astrai/__init__.py index b98a372..453eaa0 100644 --- a/astrai/__init__.py +++ b/astrai/__init__.py @@ -5,14 +5,14 @@ from astrai.config import ( ModelConfig, TrainConfig, ) -from astrai.factory import BaseFactory from astrai.dataset import DatasetFactory -from astrai.tokenize import BpeTokenizer +from astrai.factory import BaseFactory from astrai.inference import ( GenerationRequest, InferenceEngine, ) from astrai.model import AutoModel, Transformer +from astrai.tokenize import BpeTokenizer from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer __all__ = [ diff --git a/astrai/dataset/__init__.py b/astrai/dataset/__init__.py index 56735e5..c42d532 100644 --- a/astrai/dataset/__init__.py +++ b/astrai/dataset/__init__.py @@ -1,7 +1,7 @@ from astrai.dataset.dataset import ( BaseDataset, - DatasetFactory, BaseSegmentFetcher, + DatasetFactory, MultiSegmentFetcher, ) from astrai.dataset.sampler import ResumableDistributedSampler diff --git a/astrai/factory.py b/astrai/factory.py index 2109113..2fd4819 100644 --- a/astrai/factory.py +++ b/astrai/factory.py @@ -1,7 +1,7 @@ """Base factory class for extensible component registration.""" from abc import ABC -from typing import Callable, Dict, Generic, Type, TypeVar, Optional, List, Tuple +from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar T = TypeVar("T") diff --git a/astrai/inference/engine.py b/astrai/inference/engine.py index 7b0aba3..8e0a2c1 100644 --- a/astrai/inference/engine.py +++ b/astrai/inference/engine.py @@ -8,8 +8,8 @@ from typing import Any, Dict, Generator, List, Optional, Union import torch import torch.nn as nn -from astrai.tokenize.tokenizer import TextTokenizer from astrai.inference.scheduler import InferenceScheduler +from astrai.tokenize import AutoTokenizer logger = logging.getLogger(__name__) @@ -109,7 +109,7 @@ class InferenceEngine: def __init__( self, model: nn.Module, - tokenizer: TextTokenizer, + tokenizer: 
AutoTokenizer, max_batch_size: int = 1, max_seq_len: Optional[int] = None, ): diff --git a/astrai/inference/scheduler.py b/astrai/inference/scheduler.py index d972b20..6d15334 100644 --- a/astrai/inference/scheduler.py +++ b/astrai/inference/scheduler.py @@ -9,7 +9,7 @@ import torch from torch import Tensor from astrai.model.automodel import AutoModel -from astrai.tokenize.tokenizer import TextTokenizer +from astrai.tokenize import AutoTokenizer class TaskStatus: @@ -101,7 +101,7 @@ class InferenceScheduler: def __init__( self, model: AutoModel, - tokenizer: TextTokenizer, + tokenizer: AutoTokenizer, max_batch_size: int = 16, max_seq_len: Optional[int] = None, device: str = "cuda", diff --git a/astrai/inference/server.py b/astrai/inference/server.py index d7cec2d..86b25c9 100644 --- a/astrai/inference/server.py +++ b/astrai/inference/server.py @@ -19,7 +19,7 @@ from pydantic import BaseModel, Field from astrai.inference.engine import InferenceEngine from astrai.model import AutoModel -from astrai.tokenize import TextTokenizer +from astrai.tokenize import AutoTokenizer logger = logging.getLogger(__name__) @@ -96,7 +96,7 @@ def load_model( raise FileNotFoundError(f"Parameter directory not found: {param_path}") # Load tokenizer separately - tokenizer = TextTokenizer.from_pretrained(param_path) + tokenizer = AutoTokenizer.from_pretrained(param_path) _model_param = AutoModel.from_pretrained(param_path) _model_param.to(device=device, dtype=dtype) logger.info(f"Model loaded on {device} with dtype {dtype}") diff --git a/astrai/model/__init__.py b/astrai/model/__init__.py index 35d74cc..252449f 100644 --- a/astrai/model/__init__.py +++ b/astrai/model/__init__.py @@ -1,3 +1,4 @@ +from astrai.model.automodel import AutoModel from astrai.model.module import ( GQA, MLP, @@ -6,8 +7,6 @@ from astrai.model.module import ( RMSNorm, ) from astrai.model.transformer import Transformer -from astrai.model.automodel import AutoModel - __all__ = [ # Modules diff --git 
a/astrai/model/automodel.py b/astrai/model/automodel.py index b2cbad1..8e4d5e9 100644 --- a/astrai/model/automodel.py +++ b/astrai/model/automodel.py @@ -2,12 +2,12 @@ AutoModel base class for model loading and saving. """ -import torch.nn as nn -import safetensors.torch as st - -from pathlib import Path from contextlib import contextmanager -from typing import Self, Union, Dict, Type +from pathlib import Path +from typing import Dict, Self, Type, Union + +import safetensors.torch as st +import torch.nn as nn from astrai.config import ModelConfig diff --git a/astrai/tokenize/__init__.py b/astrai/tokenize/__init__.py index e98db15..f0b10f7 100644 --- a/astrai/tokenize/__init__.py +++ b/astrai/tokenize/__init__.py @@ -1,23 +1,14 @@ +from astrai.tokenize.chat_template import ChatTemplate, MessageType from astrai.tokenize.tokenizer import ( - TextTokenizer, + AutoTokenizer, BpeTokenizer, ) from astrai.tokenize.trainer import BpeTrainer -from astrai.tokenize.chat_template import ( - ChatTemplate, - HistoryType, - MessageType, -) - -# Alias for compatibility -AutoTokenizer = TextTokenizer __all__ = [ - "TextTokenizer", "AutoTokenizer", "BpeTokenizer", "BpeTrainer", "ChatTemplate", - "HistoryType", "MessageType", ] diff --git a/astrai/tokenize/chat_template.py b/astrai/tokenize/chat_template.py index 6d4b2ee..6f81f7c 100644 --- a/astrai/tokenize/chat_template.py +++ b/astrai/tokenize/chat_template.py @@ -1,9 +1,10 @@ -from typing import Dict, List, Optional, Tuple, Any -from jinja2 import Template from dataclasses import dataclass +from typing import Any, Dict, List, Optional -HistoryType = List[Tuple[str, str]] -MessageType = Dict[str, str] +from jinja2 import Template + +# Message type for chat messages +type MessageType = Dict[str, Any] @dataclass diff --git a/astrai/tokenize/tokenizer.py b/astrai/tokenize/tokenizer.py index a41b847..f717025 100644 --- a/astrai/tokenize/tokenizer.py +++ b/astrai/tokenize/tokenizer.py @@ -8,10 +8,11 @@ from typing 
import Dict, List, Optional, Union from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors from tokenizers.models import BPE + from astrai.tokenize.chat_template import ChatTemplate -class TextTokenizer: +class AutoTokenizer: """Base tokenizer class with automatic loading support""" TOKENIZER_CLASSES = {} # Registry for auto-loading @@ -51,7 +52,7 @@ class TextTokenizer: self.set_chat_template(config["chat_template"]) @classmethod - def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "TextTokenizer": + def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoTokenizer": """Load tokenizer from pretrained directory.""" instance = cls(path) return instance @@ -235,7 +236,7 @@ class TextTokenizer: return rendered -class BpeTokenizer(TextTokenizer): +class BpeTokenizer(AutoTokenizer): """BPE tokenizer implementation.""" def __init__( diff --git a/astrai/trainer/__init__.py b/astrai/trainer/__init__.py index 0e4485f..f7c5d5b 100644 --- a/astrai/trainer/__init__.py +++ b/astrai/trainer/__init__.py @@ -1,8 +1,8 @@ from astrai.trainer.schedule import BaseScheduler, SchedulerFactory from astrai.trainer.strategy import BaseStrategy, StrategyFactory from astrai.trainer.train_callback import ( - TrainCallback, CallbackFactory, + TrainCallback, ) from astrai.trainer.trainer import Trainer diff --git a/astrai/trainer/train_callback.py b/astrai/trainer/train_callback.py index c79508d..6381b31 100644 --- a/astrai/trainer/train_callback.py +++ b/astrai/trainer/train_callback.py @@ -8,8 +8,9 @@ import torch.nn as nn from torch.nn.utils import clip_grad_norm_ from tqdm import tqdm -from astrai.serialization import Checkpoint +from astrai.factory import BaseFactory from astrai.parallel import only_on_rank +from astrai.serialization import Checkpoint from astrai.trainer.metric_util import ( ctx_get_grad_max, ctx_get_grad_mean, @@ -21,7 +22,6 @@ from astrai.trainer.metric_util import ( ctx_get_lr, ) from astrai.trainer.train_context import 
TrainContext -from astrai.factory import BaseFactory @runtime_checkable diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index d689d4d..54d7319 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -8,8 +8,8 @@ from torch.utils.data import DataLoader from astrai.config.train_config import TrainConfig from astrai.dataset import ResumableDistributedSampler -from astrai.serialization import Checkpoint from astrai.parallel.setup import get_current_device, get_rank, get_world_size +from astrai.serialization import Checkpoint from astrai.trainer.strategy import BaseStrategy, StrategyFactory diff --git a/astrai/trainer/trainer.py b/astrai/trainer/trainer.py index 817f258..9831545 100644 --- a/astrai/trainer/trainer.py +++ b/astrai/trainer/trainer.py @@ -2,11 +2,11 @@ import logging from typing import List, Optional from astrai.config import TrainConfig -from astrai.serialization import Checkpoint from astrai.parallel.setup import spawn_parallel_fn +from astrai.serialization import Checkpoint from astrai.trainer.train_callback import ( - TrainCallback, CallbackFactory, + TrainCallback, ) from astrai.trainer.train_context import TrainContext, TrainContextBuilder diff --git a/scripts/demo/stream_chat.py b/scripts/demo/stream_chat.py index 1a5ce95..a87a685 100644 --- a/scripts/demo/stream_chat.py +++ b/scripts/demo/stream_chat.py @@ -1,11 +1,11 @@ from pathlib import Path import torch + from astrai.inference import InferenceEngine from astrai.model import AutoModel from astrai.tokenize import AutoTokenizer - PROJECT_ROOT = Path(__file__).resolve().parents[2] PARAMETER_ROOT = Path(PROJECT_ROOT, "params") diff --git a/scripts/tools/generate.py b/scripts/tools/generate.py index 3ab02c9..ab54528 100644 --- a/scripts/tools/generate.py +++ b/scripts/tools/generate.py @@ -3,9 +3,9 @@ import json import torch +from astrai.inference import InferenceEngine from astrai.model import AutoModel from astrai.tokenize import 
AutoTokenizer -from astrai.inference import InferenceEngine def processor( diff --git a/scripts/tools/train.py b/scripts/tools/train.py index 13a2893..5b02b3f 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -2,10 +2,10 @@ import argparse import os from functools import partial +import safetensors.torch as st import torch import torch.nn as nn import torch.optim as optim -import safetensors.torch as st from torch.nn.parallel import DistributedDataParallel as DDP from astrai.config import ModelConfig, TrainConfig diff --git a/tests/conftest.py b/tests/conftest.py index 0b6073e..01b7f67 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -11,8 +11,8 @@ from tokenizers import pre_tokenizers from torch.utils.data import Dataset from astrai.config.model_config import ModelConfig -from astrai.tokenize import BpeTokenizer, BpeTrainer from astrai.model.transformer import Transformer +from astrai.tokenize import BpeTokenizer, BpeTrainer class RandomDataset(Dataset): diff --git a/tests/data/test_checkpoint.py b/tests/data/test_checkpoint.py index 5e4d6f9..ce68447 100644 --- a/tests/data/test_checkpoint.py +++ b/tests/data/test_checkpoint.py @@ -5,8 +5,8 @@ import torch.distributed as dist from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR -from astrai.serialization import Checkpoint from astrai.parallel.setup import get_rank, spawn_parallel_fn +from astrai.serialization import Checkpoint def test_single_process(): diff --git a/tests/trainer/test_train_strategy.py b/tests/trainer/test_train_strategy.py index 6e347b0..2926f56 100644 --- a/tests/trainer/test_train_strategy.py +++ b/tests/trainer/test_train_strategy.py @@ -1,7 +1,7 @@ import numpy as np import torch -from astrai.trainer.schedule import SchedulerFactory, CosineScheduler, SGDRScheduler +from astrai.trainer.schedule import CosineScheduler, SchedulerFactory, SGDRScheduler def test_schedule_factory_random_configs():