chore: rename class, optimize import order

ViperEkura 2026-04-05 22:27:57 +08:00
parent 9b22b1651e
commit 39766aa1dc
21 changed files with 40 additions and 47 deletions
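The substantive change renames the tokenizer base class TextTokenizer to AutoTokenizer and deletes the old compatibility alias (AutoTokenizer = TextTokenizer) from astrai.tokenize, so AutoTokenizer is now the class itself rather than an alias. The rest of the churn is isort-style import ordering: standard library, then third-party, then astrai modules, alphabetized within each group. A minimal migration sketch for downstream code, based only on the from_pretrained API visible in the diffs below (the checkpoint path is hypothetical):

    # Before this commit: TextTokenizer was the class, AutoTokenizer just an alias.
    from astrai.tokenize import TextTokenizer
    tokenizer = TextTokenizer.from_pretrained("params/demo")  # hypothetical path

    # After this commit: AutoTokenizer is the class; TextTokenizer no longer exists.
    from astrai.tokenize import AutoTokenizer
    tokenizer = AutoTokenizer.from_pretrained("params/demo")  # hypothetical path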

View File

@@ -5,14 +5,14 @@ from astrai.config import (
     ModelConfig,
     TrainConfig,
 )
-from astrai.factory import BaseFactory
 from astrai.dataset import DatasetFactory
-from astrai.tokenize import BpeTokenizer
+from astrai.factory import BaseFactory
 from astrai.inference import (
     GenerationRequest,
     InferenceEngine,
 )
 from astrai.model import AutoModel, Transformer
+from astrai.tokenize import BpeTokenizer
 from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
 __all__ = [

View File

@@ -1,7 +1,7 @@
 from astrai.dataset.dataset import (
     BaseDataset,
-    DatasetFactory,
     BaseSegmentFetcher,
+    DatasetFactory,
     MultiSegmentFetcher,
 )
 from astrai.dataset.sampler import ResumableDistributedSampler

View File

@@ -1,7 +1,7 @@
 """Base factory class for extensible component registration."""
 from abc import ABC
-from typing import Callable, Dict, Generic, Type, TypeVar, Optional, List, Tuple
+from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
 T = TypeVar("T")

View File

@@ -8,8 +8,8 @@ from typing import Any, Dict, Generator, List, Optional, Union
 import torch
 import torch.nn as nn
-from astrai.tokenize.tokenizer import TextTokenizer
 from astrai.inference.scheduler import InferenceScheduler
+from astrai.tokenize import AutoTokenizer
 logger = logging.getLogger(__name__)
@@ -109,7 +109,7 @@ class InferenceEngine:
     def __init__(
         self,
         model: nn.Module,
-        tokenizer: TextTokenizer,
+        tokenizer: AutoTokenizer,
         max_batch_size: int = 1,
         max_seq_len: Optional[int] = None,
     ):

View File

@@ -9,7 +9,7 @@ import torch
 from torch import Tensor
 from astrai.model.automodel import AutoModel
-from astrai.tokenize.tokenizer import TextTokenizer
+from astrai.tokenize import AutoTokenizer
 class TaskStatus:
@@ -101,7 +101,7 @@ class InferenceScheduler:
     def __init__(
         self,
         model: AutoModel,
-        tokenizer: TextTokenizer,
+        tokenizer: AutoTokenizer,
         max_batch_size: int = 16,
         max_seq_len: Optional[int] = None,
         device: str = "cuda",

View File

@@ -19,7 +19,7 @@ from pydantic import BaseModel, Field
 from astrai.inference.engine import InferenceEngine
 from astrai.model import AutoModel
-from astrai.tokenize import TextTokenizer
+from astrai.tokenize import AutoTokenizer
 logger = logging.getLogger(__name__)
@@ -96,7 +96,7 @@ def load_model(
         raise FileNotFoundError(f"Parameter directory not found: {param_path}")
     # Load tokenizer separately
-    tokenizer = TextTokenizer.from_pretrained(param_path)
+    tokenizer = AutoTokenizer.from_pretrained(param_path)
     _model_param = AutoModel.from_pretrained(param_path)
     _model_param.to(device=device, dtype=dtype)
     logger.info(f"Model loaded on {device} with dtype {dtype}")

View File

@@ -1,3 +1,4 @@
+from astrai.model.automodel import AutoModel
 from astrai.model.module import (
     GQA,
     MLP,
@@ -6,8 +7,6 @@ from astrai.model.module import (
     RMSNorm,
 )
 from astrai.model.transformer import Transformer
-from astrai.model.automodel import AutoModel
 __all__ = [
     # Modules

View File

@@ -2,12 +2,12 @@
 AutoModel base class for model loading and saving.
 """
-import torch.nn as nn
-import safetensors.torch as st
-from pathlib import Path
 from contextlib import contextmanager
-from typing import Self, Union, Dict, Type
+from pathlib import Path
+from typing import Dict, Self, Type, Union
+import safetensors.torch as st
+import torch.nn as nn
 from astrai.config import ModelConfig

View File

@@ -1,23 +1,15 @@
-from astrai.tokenize.chat_template import ChatTemplate, MessageType
 from astrai.tokenize.tokenizer import (
-    TextTokenizer,
+    AutoTokenizer,
     BpeTokenizer,
 )
 from astrai.tokenize.trainer import BpeTrainer
+from astrai.tokenize.chat_template import (
+    ChatTemplate,
+    HistoryType,
+    MessageType,
+)
-# Alias for compatibility
-AutoTokenizer = TextTokenizer
 __all__ = [
-    "TextTokenizer",
     "AutoTokenizer",
     "BpeTokenizer",
     "BpeTrainer",
     "ChatTemplate",
+    "HistoryType",
     "MessageType",
-    "HistoryType",
 ]

View File

@@ -1,9 +1,10 @@
-from typing import Dict, List, Optional, Tuple, Any
-from jinja2 import Template
 from dataclasses import dataclass
+from typing import Any, Dict, List, Optional
-HistoryType = List[Tuple[str, str]]
-MessageType = Dict[str, str]
+from jinja2 import Template
+# Message type for chat messages
+type MessageType = Dict[str, Any]
 @dataclass

View File

@@ -8,10 +8,11 @@ from typing import Dict, List, Optional, Union
 from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import BPE
 from astrai.tokenize.chat_template import ChatTemplate
-class TextTokenizer:
+class AutoTokenizer:
     """Base tokenizer class with automatic loading support"""
     TOKENIZER_CLASSES = {}  # Registry for auto-loading
@@ -51,7 +52,7 @@ class TextTokenizer:
         self.set_chat_template(config["chat_template"])
     @classmethod
-    def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "TextTokenizer":
+    def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoTokenizer":
         """Load tokenizer from pretrained directory."""
         instance = cls(path)
         return instance
@@ -235,7 +236,7 @@ class TextTokenizer:
         return rendered
-class BpeTokenizer(TextTokenizer):
+class BpeTokenizer(AutoTokenizer):
     """BPE tokenizer implementation."""
     def __init__(

View File

@@ -1,8 +1,8 @@
 from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
 from astrai.trainer.strategy import BaseStrategy, StrategyFactory
 from astrai.trainer.train_callback import (
-    TrainCallback,
     CallbackFactory,
+    TrainCallback,
 )
 from astrai.trainer.trainer import Trainer

View File

@@ -8,8 +8,9 @@ import torch.nn as nn
 from torch.nn.utils import clip_grad_norm_
 from tqdm import tqdm
-from astrai.serialization import Checkpoint
+from astrai.factory import BaseFactory
 from astrai.parallel import only_on_rank
+from astrai.serialization import Checkpoint
 from astrai.trainer.metric_util import (
     ctx_get_grad_max,
     ctx_get_grad_mean,
@@ -21,7 +22,7 @@ from astrai.trainer.metric_util import (
     ctx_get_lr,
 )
 from astrai.trainer.train_context import TrainContext
-from astrai.factory import BaseFactory
 @runtime_checkable

View File

@@ -8,8 +8,8 @@ from torch.utils.data import DataLoader
 from astrai.config.train_config import TrainConfig
 from astrai.dataset import ResumableDistributedSampler
-from astrai.serialization import Checkpoint
 from astrai.parallel.setup import get_current_device, get_rank, get_world_size
+from astrai.serialization import Checkpoint
 from astrai.trainer.strategy import BaseStrategy, StrategyFactory

View File

@@ -2,11 +2,11 @@ import logging
 from typing import List, Optional
 from astrai.config import TrainConfig
-from astrai.serialization import Checkpoint
 from astrai.parallel.setup import spawn_parallel_fn
+from astrai.serialization import Checkpoint
 from astrai.trainer.train_callback import (
-    TrainCallback,
     CallbackFactory,
+    TrainCallback,
 )
 from astrai.trainer.train_context import TrainContext, TrainContextBuilder

View File

@@ -1,11 +1,11 @@
 from pathlib import Path
 import torch
 from astrai.inference import InferenceEngine
 from astrai.model import AutoModel
 from astrai.tokenize import AutoTokenizer
 PROJECT_ROOT = Path(__file__).resolve().parents[2]
 PARAMETER_ROOT = Path(PROJECT_ROOT, "params")

View File

@@ -3,9 +3,9 @@ import json
 import torch
+from astrai.inference import InferenceEngine
 from astrai.model import AutoModel
 from astrai.tokenize import AutoTokenizer
-from astrai.inference import InferenceEngine
 def processor(

View File

@@ -2,10 +2,10 @@ import argparse
 import os
 from functools import partial
+import safetensors.torch as st
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import safetensors.torch as st
 from torch.nn.parallel import DistributedDataParallel as DDP
 from astrai.config import ModelConfig, TrainConfig

View File

@@ -11,8 +11,8 @@ from tokenizers import pre_tokenizers
 from torch.utils.data import Dataset
 from astrai.config.model_config import ModelConfig
-from astrai.tokenize import BpeTokenizer, BpeTrainer
 from astrai.model.transformer import Transformer
+from astrai.tokenize import BpeTokenizer, BpeTrainer
 class RandomDataset(Dataset):

View File

@@ -5,8 +5,8 @@ import torch.distributed as dist
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import CosineAnnealingLR
-from astrai.serialization import Checkpoint
 from astrai.parallel.setup import get_rank, spawn_parallel_fn
+from astrai.serialization import Checkpoint
 def test_single_process():

View File

@@ -1,7 +1,7 @@
 import numpy as np
 import torch
-from astrai.trainer.schedule import SchedulerFactory, CosineScheduler, SGDRScheduler
+from astrai.trainer.schedule import CosineScheduler, SchedulerFactory, SGDRScheduler
 def test_schedule_factory_random_configs():