chore: 修改类名,优化导入顺序
This commit is contained in:
parent
9b22b1651e
commit
39766aa1dc
|
|
@ -5,14 +5,14 @@ from astrai.config import (
|
|||
ModelConfig,
|
||||
TrainConfig,
|
||||
)
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.dataset import DatasetFactory
|
||||
from astrai.tokenize import BpeTokenizer
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.inference import (
|
||||
GenerationRequest,
|
||||
InferenceEngine,
|
||||
)
|
||||
from astrai.model import AutoModel, Transformer
|
||||
from astrai.tokenize import BpeTokenizer
|
||||
from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer
|
||||
|
||||
__all__ = [
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
from astrai.dataset.dataset import (
|
||||
BaseDataset,
|
||||
DatasetFactory,
|
||||
BaseSegmentFetcher,
|
||||
DatasetFactory,
|
||||
MultiSegmentFetcher,
|
||||
)
|
||||
from astrai.dataset.sampler import ResumableDistributedSampler
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
"""Base factory class for extensible component registration."""
|
||||
|
||||
from abc import ABC
|
||||
from typing import Callable, Dict, Generic, Type, TypeVar, Optional, List, Tuple
|
||||
from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ from typing import Any, Dict, Generator, List, Optional, Union
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from astrai.tokenize.tokenizer import TextTokenizer
|
||||
from astrai.inference.scheduler import InferenceScheduler
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -109,7 +109,7 @@ class InferenceEngine:
|
|||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
tokenizer: TextTokenizer,
|
||||
tokenizer: AutoTokenizer,
|
||||
max_batch_size: int = 1,
|
||||
max_seq_len: Optional[int] = None,
|
||||
):
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import torch
|
|||
from torch import Tensor
|
||||
|
||||
from astrai.model.automodel import AutoModel
|
||||
from astrai.tokenize.tokenizer import TextTokenizer
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
||||
class TaskStatus:
|
||||
|
|
@ -101,7 +101,7 @@ class InferenceScheduler:
|
|||
def __init__(
|
||||
self,
|
||||
model: AutoModel,
|
||||
tokenizer: TextTokenizer,
|
||||
tokenizer: AutoTokenizer,
|
||||
max_batch_size: int = 16,
|
||||
max_seq_len: Optional[int] = None,
|
||||
device: str = "cuda",
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ from pydantic import BaseModel, Field
|
|||
|
||||
from astrai.inference.engine import InferenceEngine
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import TextTokenizer
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -96,7 +96,7 @@ def load_model(
|
|||
raise FileNotFoundError(f"Parameter directory not found: {param_path}")
|
||||
|
||||
# Load tokenizer separately
|
||||
tokenizer = TextTokenizer.from_pretrained(param_path)
|
||||
tokenizer = AutoTokenizer.from_pretrained(param_path)
|
||||
_model_param = AutoModel.from_pretrained(param_path)
|
||||
_model_param.to(device=device, dtype=dtype)
|
||||
logger.info(f"Model loaded on {device} with dtype {dtype}")
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from astrai.model.automodel import AutoModel
|
||||
from astrai.model.module import (
|
||||
GQA,
|
||||
MLP,
|
||||
|
|
@ -6,8 +7,6 @@ from astrai.model.module import (
|
|||
RMSNorm,
|
||||
)
|
||||
from astrai.model.transformer import Transformer
|
||||
from astrai.model.automodel import AutoModel
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Modules
|
||||
|
|
|
|||
|
|
@ -2,12 +2,12 @@
|
|||
AutoModel base class for model loading and saving.
|
||||
"""
|
||||
|
||||
import torch.nn as nn
|
||||
import safetensors.torch as st
|
||||
|
||||
from pathlib import Path
|
||||
from contextlib import contextmanager
|
||||
from typing import Self, Union, Dict, Type
|
||||
from pathlib import Path
|
||||
from typing import Dict, Self, Type, Union
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch.nn as nn
|
||||
|
||||
from astrai.config import ModelConfig
|
||||
|
||||
|
|
|
|||
|
|
@ -1,23 +1,15 @@
|
|||
from astrai.tokenize.chat_template import ChatTemplate, MessageType
|
||||
from astrai.tokenize.tokenizer import (
|
||||
TextTokenizer,
|
||||
AutoTokenizer,
|
||||
BpeTokenizer,
|
||||
)
|
||||
from astrai.tokenize.trainer import BpeTrainer
|
||||
from astrai.tokenize.chat_template import (
|
||||
ChatTemplate,
|
||||
HistoryType,
|
||||
MessageType,
|
||||
)
|
||||
|
||||
# Alias for compatibility
|
||||
AutoTokenizer = TextTokenizer
|
||||
|
||||
__all__ = [
|
||||
"TextTokenizer",
|
||||
"AutoTokenizer",
|
||||
"BpeTokenizer",
|
||||
"BpeTrainer",
|
||||
"ChatTemplate",
|
||||
"HistoryType",
|
||||
"MessageType",
|
||||
"HistoryType",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,9 +1,10 @@
|
|||
from typing import Dict, List, Optional, Tuple, Any
|
||||
from jinja2 import Template
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
HistoryType = List[Tuple[str, str]]
|
||||
MessageType = Dict[str, str]
|
||||
from jinja2 import Template
|
||||
|
||||
# Message type for chat messages
|
||||
type MessageType = Dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
|
|||
|
|
@ -8,10 +8,11 @@ from typing import Dict, List, Optional, Union
|
|||
|
||||
from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
|
||||
from tokenizers.models import BPE
|
||||
|
||||
from astrai.tokenize.chat_template import ChatTemplate
|
||||
|
||||
|
||||
class TextTokenizer:
|
||||
class AutoTokenizer:
|
||||
"""Base tokenizer class with automatic loading support"""
|
||||
|
||||
TOKENIZER_CLASSES = {} # Registry for auto-loading
|
||||
|
|
@ -51,7 +52,7 @@ class TextTokenizer:
|
|||
self.set_chat_template(config["chat_template"])
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "TextTokenizer":
|
||||
def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoTokenizer":
|
||||
"""Load tokenizer from pretrained directory."""
|
||||
instance = cls(path)
|
||||
return instance
|
||||
|
|
@ -235,7 +236,7 @@ class TextTokenizer:
|
|||
return rendered
|
||||
|
||||
|
||||
class BpeTokenizer(TextTokenizer):
|
||||
class BpeTokenizer(AutoTokenizer):
|
||||
"""BPE tokenizer implementation."""
|
||||
|
||||
def __init__(
|
||||
|
|
|
|||
|
|
@ -1,8 +1,8 @@
|
|||
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
|
||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||
from astrai.trainer.train_callback import (
|
||||
TrainCallback,
|
||||
CallbackFactory,
|
||||
TrainCallback,
|
||||
)
|
||||
from astrai.trainer.trainer import Trainer
|
||||
|
||||
|
|
|
|||
|
|
@ -8,8 +8,9 @@ import torch.nn as nn
|
|||
from torch.nn.utils import clip_grad_norm_
|
||||
from tqdm import tqdm
|
||||
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.parallel import only_on_rank
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.metric_util import (
|
||||
ctx_get_grad_max,
|
||||
ctx_get_grad_mean,
|
||||
|
|
@ -21,7 +22,6 @@ from astrai.trainer.metric_util import (
|
|||
ctx_get_lr,
|
||||
)
|
||||
from astrai.trainer.train_context import TrainContext
|
||||
from astrai.factory import BaseFactory
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
|
|
|
|||
|
|
@ -8,8 +8,8 @@ from torch.utils.data import DataLoader
|
|||
|
||||
from astrai.config.train_config import TrainConfig
|
||||
from astrai.dataset import ResumableDistributedSampler
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -2,11 +2,11 @@ import logging
|
|||
from typing import List, Optional
|
||||
|
||||
from astrai.config import TrainConfig
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.parallel.setup import spawn_parallel_fn
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.trainer.train_callback import (
|
||||
TrainCallback,
|
||||
CallbackFactory,
|
||||
TrainCallback,
|
||||
)
|
||||
from astrai.trainer.train_context import TrainContext, TrainContextBuilder
|
||||
|
||||
|
|
|
|||
|
|
@ -1,11 +1,11 @@
|
|||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
|
||||
from astrai.inference import InferenceEngine
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
|
||||
|
||||
PROJECT_ROOT = Path(__file__).resolve().parents[2]
|
||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||
|
||||
|
|
|
|||
|
|
@ -3,9 +3,9 @@ import json
|
|||
|
||||
import torch
|
||||
|
||||
from astrai.inference import InferenceEngine
|
||||
from astrai.model import AutoModel
|
||||
from astrai.tokenize import AutoTokenizer
|
||||
from astrai.inference import InferenceEngine
|
||||
|
||||
|
||||
def processor(
|
||||
|
|
|
|||
|
|
@ -2,10 +2,10 @@ import argparse
|
|||
import os
|
||||
from functools import partial
|
||||
|
||||
import safetensors.torch as st
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
import safetensors.torch as st
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from astrai.config import ModelConfig, TrainConfig
|
||||
|
|
|
|||
|
|
@ -11,8 +11,8 @@ from tokenizers import pre_tokenizers
|
|||
from torch.utils.data import Dataset
|
||||
|
||||
from astrai.config.model_config import ModelConfig
|
||||
from astrai.tokenize import BpeTokenizer, BpeTrainer
|
||||
from astrai.model.transformer import Transformer
|
||||
from astrai.tokenize import BpeTokenizer, BpeTrainer
|
||||
|
||||
|
||||
class RandomDataset(Dataset):
|
||||
|
|
|
|||
|
|
@ -5,8 +5,8 @@ import torch.distributed as dist
|
|||
from torch.optim import AdamW
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.parallel.setup import get_rank, spawn_parallel_fn
|
||||
from astrai.serialization import Checkpoint
|
||||
|
||||
|
||||
def test_single_process():
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from astrai.trainer.schedule import SchedulerFactory, CosineScheduler, SGDRScheduler
|
||||
from astrai.trainer.schedule import CosineScheduler, SchedulerFactory, SGDRScheduler
|
||||
|
||||
|
||||
def test_schedule_factory_random_configs():
|
||||
|
|
|
|||
Loading…
Reference in New Issue