chore: rename class, optimize import order
parent 9b22b1651e
commit 39766aa1dc
@@ -5,14 +5,14 @@ from astrai.config import (
     ModelConfig,
     TrainConfig,
 )
-from astrai.factory import BaseFactory
 from astrai.dataset import DatasetFactory
-from astrai.tokenize import BpeTokenizer
+from astrai.factory import BaseFactory
 from astrai.inference import (
     GenerationRequest,
     InferenceEngine,
 )
 from astrai.model import AutoModel, Transformer
+from astrai.tokenize import BpeTokenizer
 from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer

 __all__ = [
@@ -1,7 +1,7 @@
 from astrai.dataset.dataset import (
     BaseDataset,
-    DatasetFactory,
     BaseSegmentFetcher,
+    DatasetFactory,
     MultiSegmentFetcher,
 )
 from astrai.dataset.sampler import ResumableDistributedSampler
@@ -1,7 +1,7 @@
 """Base factory class for extensible component registration."""

 from abc import ABC
-from typing import Callable, Dict, Generic, Type, TypeVar, Optional, List, Tuple
+from typing import Callable, Dict, Generic, List, Optional, Tuple, Type, TypeVar

 T = TypeVar("T")

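The sorted `typing` import above belongs to `BaseFactory`, described by its docstring as a base class for extensible component registration. The diff shows only its imports, so as a rough illustration only, here is a minimal sketch of what a `Generic[T]` registration factory of this shape commonly looks like; the actual astrai implementation is not shown in this commit and will differ:

```python
from abc import ABC
from typing import Callable, Dict, Generic, Type, TypeVar

T = TypeVar("T")


class BaseFactory(ABC, Generic[T]):
    """Sketch: register implementations under string keys, build them by name."""

    _registry: Dict[str, type] = {}

    @classmethod
    def register(cls, name: str) -> Callable[[Type[T]], Type[T]]:
        def decorator(impl: Type[T]) -> Type[T]:
            cls._registry[name] = impl
            return impl
        return decorator

    @classmethod
    def create(cls, name: str, **kwargs) -> T:
        return cls._registry[name](**kwargs)


class GreeterFactory(BaseFactory[object]):  # toy subclass for demonstration only
    _registry = {}


@GreeterFactory.register("hello")
class HelloGreeter:
    def greet(self) -> str:
        return "hello"


print(GreeterFactory.create("hello").greet())
```

Judging by their names alone (an assumption), `DatasetFactory`, `SchedulerFactory`, `StrategyFactory`, and `CallbackFactory` elsewhere in this commit would be per-component subclasses of this base.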
@@ -8,8 +8,8 @@ from typing import Any, Dict, Generator, List, Optional, Union
 import torch
 import torch.nn as nn

-from astrai.tokenize.tokenizer import TextTokenizer
 from astrai.inference.scheduler import InferenceScheduler
+from astrai.tokenize import AutoTokenizer

 logger = logging.getLogger(__name__)
@@ -109,7 +109,7 @@ class InferenceEngine:
     def __init__(
         self,
         model: nn.Module,
-        tokenizer: TextTokenizer,
+        tokenizer: AutoTokenizer,
         max_batch_size: int = 1,
         max_seq_len: Optional[int] = None,
     ):
@@ -9,7 +9,7 @@ import torch
 from torch import Tensor

 from astrai.model.automodel import AutoModel
-from astrai.tokenize.tokenizer import TextTokenizer
+from astrai.tokenize import AutoTokenizer


 class TaskStatus:
@@ -101,7 +101,7 @@ class InferenceScheduler:
     def __init__(
         self,
         model: AutoModel,
-        tokenizer: TextTokenizer,
+        tokenizer: AutoTokenizer,
         max_batch_size: int = 16,
         max_seq_len: Optional[int] = None,
         device: str = "cuda",
@@ -19,7 +19,7 @@ from pydantic import BaseModel, Field

 from astrai.inference.engine import InferenceEngine
 from astrai.model import AutoModel
-from astrai.tokenize import TextTokenizer
+from astrai.tokenize import AutoTokenizer

 logger = logging.getLogger(__name__)
@@ -96,7 +96,7 @@ def load_model(
         raise FileNotFoundError(f"Parameter directory not found: {param_path}")

     # Load tokenizer separately
-    tokenizer = TextTokenizer.from_pretrained(param_path)
+    tokenizer = AutoTokenizer.from_pretrained(param_path)
     _model_param = AutoModel.from_pretrained(param_path)
     _model_param.to(device=device, dtype=dtype)
     logger.info(f"Model loaded on {device} with dtype {dtype}")
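Put together with the `InferenceEngine` signature above, the renamed loader composes as follows. A minimal sketch, assuming `astrai` is importable and a CUDA device is available; `"params/demo"` is a made-up parameter directory:

```python
import torch

from astrai.inference import InferenceEngine
from astrai.model import AutoModel
from astrai.tokenize import AutoTokenizer

# from_pretrained, .to(device=..., dtype=...), and the InferenceEngine
# keyword arguments all appear in the hunks above; the path is hypothetical.
tokenizer = AutoTokenizer.from_pretrained("params/demo")
model = AutoModel.from_pretrained("params/demo")
model.to(device="cuda", dtype=torch.bfloat16)

engine = InferenceEngine(
    model=model,
    tokenizer=tokenizer,
    max_batch_size=1,
)
```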
@@ -1,3 +1,4 @@
+from astrai.model.automodel import AutoModel
 from astrai.model.module import (
     GQA,
     MLP,
@@ -6,8 +7,6 @@ from astrai.model.module import (
     RMSNorm,
 )
 from astrai.model.transformer import Transformer
-from astrai.model.automodel import AutoModel
-

 __all__ = [
     # Modules
@@ -2,12 +2,12 @@
 AutoModel base class for model loading and saving.
 """

-import torch.nn as nn
-import safetensors.torch as st
-
-from pathlib import Path
 from contextlib import contextmanager
-from typing import Self, Union, Dict, Type
+from pathlib import Path
+from typing import Dict, Self, Type, Union
+
+import safetensors.torch as st
+import torch.nn as nn

 from astrai.config import ModelConfig

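This hunk shows the ordering rule applied throughout the commit: standard-library imports first, then third-party packages, then first-party astrai modules, each group sorted alphabetically and separated by a blank line. The grouping matches what sorters like isort or Ruff's import rules produce, though whether the project actually runs such a tool is an assumption. The resulting layout, annotated:

```python
# 1. standard library, sorted alphabetically
from contextlib import contextmanager
from pathlib import Path
from typing import Dict, Self, Type, Union

# 2. third-party packages
import safetensors.torch as st
import torch.nn as nn

# 3. first-party (astrai) modules
from astrai.config import ModelConfig
```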
@@ -1,23 +1,15 @@
+from astrai.tokenize.chat_template import ChatTemplate, MessageType
 from astrai.tokenize.tokenizer import (
-    TextTokenizer,
+    AutoTokenizer,
     BpeTokenizer,
 )
 from astrai.tokenize.trainer import BpeTrainer
-from astrai.tokenize.chat_template import (
-    ChatTemplate,
-    HistoryType,
-    MessageType,
-)
-
-# Alias for compatibility
-AutoTokenizer = TextTokenizer

 __all__ = [
-    "TextTokenizer",
     "AutoTokenizer",
     "BpeTokenizer",
     "BpeTrainer",
     "ChatTemplate",
-    "HistoryType",
     "MessageType",
+    "HistoryType",
 ]
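The `AutoTokenizer = TextTokenizer` compatibility alias disappears here because the class itself now carries the name. A minimal before/after sketch for downstream code, assuming `astrai` is installed; `"params/demo"` is a made-up path:

```python
# Before this commit, old callers went through the compatibility alias:
#   from astrai.tokenize import TextTokenizer
#   tokenizer = TextTokenizer.from_pretrained("params/demo")

# After this commit the class is named AutoTokenizer directly:
from astrai.tokenize import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("params/demo")
```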
@@ -1,9 +1,10 @@
-from typing import Dict, List, Optional, Tuple, Any
-from jinja2 import Template
 from dataclasses import dataclass
+from typing import Any, Dict, List, Optional

-HistoryType = List[Tuple[str, str]]
-MessageType = Dict[str, str]
+from jinja2 import Template
+
+# Message type for chat messages
+type MessageType = Dict[str, Any]


 @dataclass
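The new `type MessageType = Dict[str, Any]` uses the PEP 695 `type` statement, which requires Python 3.12 or newer (a floor this codebase evidently already has, given the `typing.Self` import elsewhere in the commit). Unlike the plain assignment it replaces, it creates a lazily evaluated `TypeAliasType` object rather than binding the subscripted generic directly. A standalone sketch of the difference:

```python
from typing import Any, Dict

OldStyle = Dict[str, str]          # plain assignment: evaluated eagerly
type MessageType = Dict[str, Any]  # PEP 695: a lazily evaluated TypeAliasType

def format_message(message: MessageType) -> str:
    # At runtime the alias still annotates ordinary dicts.
    return f"{message['role']}: {message['content']}"

print(format_message({"role": "user", "content": "hello"}))
print(type(MessageType).__name__)  # TypeAliasType
```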
@@ -8,10 +8,11 @@ from typing import Dict, List, Optional, Union

 from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
 from tokenizers.models import BPE

 from astrai.tokenize.chat_template import ChatTemplate

-class TextTokenizer:
+
+class AutoTokenizer:
     """Base tokenizer class with automatic loading support"""

     TOKENIZER_CLASSES = {}  # Registry for auto-loading
@@ -51,7 +52,7 @@ class TextTokenizer:
         self.set_chat_template(config["chat_template"])

     @classmethod
-    def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "TextTokenizer":
+    def from_pretrained(cls, path: Union[str, Path], **kwargs) -> "AutoTokenizer":
         """Load tokenizer from pretrained directory."""
         instance = cls(path)
         return instance
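`TOKENIZER_CLASSES` is commented as a registry for auto-loading. The wiring is not shown in this diff, so the following is only a plausible minimal sketch of such a pattern, using `__init_subclass__` so subclasses register themselves; how astrai actually selects the concrete class from a saved config is an assumption:

```python
class AutoTokenizer:
    """Sketch of a base class whose subclasses self-register for auto-loading."""

    TOKENIZER_CLASSES: dict = {}  # registry: class name -> concrete subclass

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        AutoTokenizer.TOKENIZER_CLASSES[cls.__name__] = cls


class BpeTokenizer(AutoTokenizer):
    """BPE tokenizer implementation (stub for the sketch)."""


# The base class can now resolve a concrete implementation by name,
# e.g. from a "class" field stored alongside the tokenizer files.
assert AutoTokenizer.TOKENIZER_CLASSES["BpeTokenizer"] is BpeTokenizer
```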
@@ -235,7 +236,7 @@ class TextTokenizer:
         return rendered


-class BpeTokenizer(TextTokenizer):
+class BpeTokenizer(AutoTokenizer):
     """BPE tokenizer implementation."""

     def __init__(
@@ -1,8 +1,8 @@
 from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
 from astrai.trainer.strategy import BaseStrategy, StrategyFactory
 from astrai.trainer.train_callback import (
-    TrainCallback,
     CallbackFactory,
+    TrainCallback,
 )
 from astrai.trainer.trainer import Trainer
@@ -8,8 +8,9 @@ import torch.nn as nn
 from torch.nn.utils import clip_grad_norm_
 from tqdm import tqdm

-from astrai.serialization import Checkpoint
+from astrai.factory import BaseFactory
 from astrai.parallel import only_on_rank
+from astrai.serialization import Checkpoint
 from astrai.trainer.metric_util import (
     ctx_get_grad_max,
     ctx_get_grad_mean,
@@ -21,7 +22,6 @@ from astrai.trainer.metric_util import (
     ctx_get_lr,
 )
 from astrai.trainer.train_context import TrainContext
-from astrai.factory import BaseFactory


 @runtime_checkable
@@ -8,8 +8,8 @@ from torch.utils.data import DataLoader

 from astrai.config.train_config import TrainConfig
 from astrai.dataset import ResumableDistributedSampler
-from astrai.serialization import Checkpoint
 from astrai.parallel.setup import get_current_device, get_rank, get_world_size
+from astrai.serialization import Checkpoint
 from astrai.trainer.strategy import BaseStrategy, StrategyFactory

@@ -2,11 +2,11 @@ import logging
 from typing import List, Optional

 from astrai.config import TrainConfig
-from astrai.serialization import Checkpoint
 from astrai.parallel.setup import spawn_parallel_fn
+from astrai.serialization import Checkpoint
 from astrai.trainer.train_callback import (
-    TrainCallback,
     CallbackFactory,
+    TrainCallback,
 )
 from astrai.trainer.train_context import TrainContext, TrainContextBuilder
@@ -1,11 +1,11 @@
 from pathlib import Path

 import torch

 from astrai.inference import InferenceEngine
 from astrai.model import AutoModel
 from astrai.tokenize import AutoTokenizer


 PROJECT_ROOT = Path(__file__).resolve().parents[2]
 PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
@@ -3,9 +3,9 @@ import json

 import torch

+from astrai.inference import InferenceEngine
 from astrai.model import AutoModel
 from astrai.tokenize import AutoTokenizer
-from astrai.inference import InferenceEngine


 def processor(
@@ -2,10 +2,10 @@ import argparse
 import os
 from functools import partial

+import safetensors.torch as st
 import torch
 import torch.nn as nn
 import torch.optim as optim
-import safetensors.torch as st
 from torch.nn.parallel import DistributedDataParallel as DDP

 from astrai.config import ModelConfig, TrainConfig
@@ -11,8 +11,8 @@ from tokenizers import pre_tokenizers
 from torch.utils.data import Dataset

 from astrai.config.model_config import ModelConfig
-from astrai.tokenize import BpeTokenizer, BpeTrainer
 from astrai.model.transformer import Transformer
+from astrai.tokenize import BpeTokenizer, BpeTrainer


 class RandomDataset(Dataset):
@@ -5,8 +5,8 @@ import torch.distributed as dist
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import CosineAnnealingLR

-from astrai.serialization import Checkpoint
 from astrai.parallel.setup import get_rank, spawn_parallel_fn
+from astrai.serialization import Checkpoint


 def test_single_process():
@@ -1,7 +1,7 @@
 import numpy as np
 import torch

-from astrai.trainer.schedule import SchedulerFactory, CosineScheduler, SGDRScheduler
+from astrai.trainer.schedule import CosineScheduler, SchedulerFactory, SGDRScheduler


 def test_schedule_factory_random_configs():
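Since the commit reshuffles re-exports across several package `__init__` files, a quick smoke check (assuming `astrai` is installed in the environment) is simply importing every renamed or re-sorted public name touched above:

```python
# Each of these names appears in the __init__ hunks of this commit; if the
# rename or reordering had broken a re-export, one of these would raise.
from astrai.dataset import DatasetFactory
from astrai.factory import BaseFactory
from astrai.model import AutoModel, Transformer
from astrai.tokenize import AutoTokenizer, BpeTokenizer, BpeTrainer, ChatTemplate
from astrai.trainer import SchedulerFactory, StrategyFactory, Trainer

print("all re-exports resolve")
```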