chore: 将data 模块命名为dataset

This commit is contained in:
ViperEkura 2026-04-04 16:16:27 +08:00
parent bd9741dc5f
commit 9c31d78a22
8 changed files with 7 additions and 7 deletions

View File

@ -6,7 +6,7 @@ from astrai.config import (
TrainConfig, TrainConfig,
) )
from astrai.factory import BaseFactory from astrai.factory import BaseFactory
from astrai.data import DatasetFactory from astrai.dataset import DatasetFactory
from astrai.tokenizer import BpeTokenizer from astrai.tokenizer import BpeTokenizer
from astrai.inference.generator import ( from astrai.inference.generator import (
BatchGenerator, BatchGenerator,

View File

@ -1,10 +1,10 @@
from astrai.data.dataset import ( from astrai.dataset.dataset import (
BaseDataset, BaseDataset,
DatasetFactory, DatasetFactory,
BaseSegmentFetcher, BaseSegmentFetcher,
MultiSegmentFetcher, MultiSegmentFetcher,
) )
from astrai.data.sampler import ResumableDistributedSampler from astrai.dataset.sampler import ResumableDistributedSampler
__all__ = [ __all__ = [
# Base classes # Base classes

View File

@ -7,7 +7,7 @@ from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader from torch.utils.data import DataLoader
from astrai.config.train_config import TrainConfig from astrai.config.train_config import TrainConfig
from astrai.data import ResumableDistributedSampler from astrai.dataset import ResumableDistributedSampler
from astrai.serialization import Checkpoint from astrai.serialization import Checkpoint
from astrai.parallel.setup import get_current_device, get_rank, get_world_size from astrai.parallel.setup import get_current_device, get_rank, get_world_size
from astrai.trainer.strategy import BaseStrategy, StrategyFactory from astrai.trainer.strategy import BaseStrategy, StrategyFactory

View File

@ -8,7 +8,7 @@ import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP from torch.nn.parallel import DistributedDataParallel as DDP
from astrai.config import ModelParameter, TrainConfig from astrai.config import ModelParameter, TrainConfig
from astrai.data import DatasetFactory from astrai.dataset import DatasetFactory
from astrai.parallel import get_rank from astrai.parallel import get_rank
from astrai.trainer import SchedulerFactory, Trainer from astrai.trainer import SchedulerFactory, Trainer

View File

@ -1,7 +1,7 @@
import numpy as np import numpy as np
import torch import torch
from astrai.data.dataset import DatasetFactory from astrai.dataset.dataset import DatasetFactory
from astrai.serialization import save_h5 from astrai.serialization import save_h5

View File

@ -1,4 +1,4 @@
from astrai.data import ResumableDistributedSampler from astrai.dataset import ResumableDistributedSampler
def test_random_sampler_consistency(random_dataset): def test_random_sampler_consistency(random_dataset):