chore: 将data 模块命名为dataset
This commit is contained in:
parent
bd9741dc5f
commit
9c31d78a22
|
|
@ -6,7 +6,7 @@ from astrai.config import (
|
|||
TrainConfig,
|
||||
)
|
||||
from astrai.factory import BaseFactory
|
||||
from astrai.data import DatasetFactory
|
||||
from astrai.dataset import DatasetFactory
|
||||
from astrai.tokenizer import BpeTokenizer
|
||||
from astrai.inference.generator import (
|
||||
BatchGenerator,
|
||||
|
|
|
|||
|
|
@ -1,10 +1,10 @@
|
|||
from astrai.data.dataset import (
|
||||
from astrai.dataset.dataset import (
|
||||
BaseDataset,
|
||||
DatasetFactory,
|
||||
BaseSegmentFetcher,
|
||||
MultiSegmentFetcher,
|
||||
)
|
||||
from astrai.data.sampler import ResumableDistributedSampler
|
||||
from astrai.dataset.sampler import ResumableDistributedSampler
|
||||
|
||||
__all__ = [
|
||||
# Base classes
|
||||
|
|
@ -7,7 +7,7 @@ from torch.optim.lr_scheduler import LRScheduler
|
|||
from torch.utils.data import DataLoader
|
||||
|
||||
from astrai.config.train_config import TrainConfig
|
||||
from astrai.data import ResumableDistributedSampler
|
||||
from astrai.dataset import ResumableDistributedSampler
|
||||
from astrai.serialization import Checkpoint
|
||||
from astrai.parallel.setup import get_current_device, get_rank, get_world_size
|
||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ import torch.optim as optim
|
|||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
||||
from astrai.config import ModelParameter, TrainConfig
|
||||
from astrai.data import DatasetFactory
|
||||
from astrai.dataset import DatasetFactory
|
||||
from astrai.parallel import get_rank
|
||||
from astrai.trainer import SchedulerFactory, Trainer
|
||||
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from astrai.data.dataset import DatasetFactory
|
||||
from astrai.dataset.dataset import DatasetFactory
|
||||
from astrai.serialization import save_h5
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from astrai.data import ResumableDistributedSampler
|
||||
from astrai.dataset import ResumableDistributedSampler
|
||||
|
||||
|
||||
def test_random_sampler_consistency(random_dataset):
|
||||
|
|
|
|||
Loading…
Reference in New Issue