chore: rename the data module to dataset
commit 9c31d78a22
parent bd9741dc5f
@@ -6,7 +6,7 @@ from astrai.config import (
     TrainConfig,
 )
 from astrai.factory import BaseFactory
-from astrai.data import DatasetFactory
+from astrai.dataset import DatasetFactory
 from astrai.tokenizer import BpeTokenizer
 from astrai.inference.generator import (
     BatchGenerator,
@@ -1,10 +1,10 @@
-from astrai.data.dataset import (
+from astrai.dataset.dataset import (
     BaseDataset,
     DatasetFactory,
     BaseSegmentFetcher,
     MultiSegmentFetcher,
 )
-from astrai.data.sampler import ResumableDistributedSampler
+from astrai.dataset.sampler import ResumableDistributedSampler
 
 __all__ = [
     # Base classes
@@ -7,7 +7,7 @@ from torch.optim.lr_scheduler import LRScheduler
 from torch.utils.data import DataLoader
 
 from astrai.config.train_config import TrainConfig
-from astrai.data import ResumableDistributedSampler
+from astrai.dataset import ResumableDistributedSampler
 from astrai.serialization import Checkpoint
 from astrai.parallel.setup import get_current_device, get_rank, get_world_size
 from astrai.trainer.strategy import BaseStrategy, StrategyFactory
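For context on the ResumableDistributedSampler whose import path changes above: the diff touches only the import, and the class body is not shown. The sketch below is a hypothetical reconstruction of the usual pattern (a DistributedSampler that can checkpoint and resume mid-epoch); the class name and method shapes are assumptions, not the astrai API.

# Hypothetical sketch only -- astrai's actual implementation is not in this diff.
from torch.utils.data import DistributedSampler


class ResumableSamplerSketch(DistributedSampler):
    """A DistributedSampler that can be checkpointed and resumed mid-epoch."""

    def __init__(self, dataset, **kwargs):
        super().__init__(dataset, **kwargs)
        self.start_index = 0  # samples already consumed in the current epoch

    def __iter__(self):
        indices = list(super().__iter__())
        yield from indices[self.start_index:]  # skip what the checkpoint already saw
        self.start_index = 0  # subsequent epochs start from the beginning

    def state_dict(self, consumed: int) -> dict:
        # The trainer tracks how many samples it has drawn this epoch.
        return {"epoch": self.epoch, "start_index": consumed}

    def load_state_dict(self, state: dict) -> None:
        self.set_epoch(state["epoch"])  # keeps per-epoch shuffling deterministic
        self.start_index = state["start_index"]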
@@ -8,7 +8,7 @@ import torch.optim as optim
 from torch.nn.parallel import DistributedDataParallel as DDP
 
 from astrai.config import ModelParameter, TrainConfig
-from astrai.data import DatasetFactory
+from astrai.dataset import DatasetFactory
 from astrai.parallel import get_rank
 from astrai.trainer import SchedulerFactory, Trainer
 
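Since this hunk is in the training entry point, here is a minimal sketch of how the renamed imports might be wired together. The DatasetFactory construction call and its argument are assumptions (the diff shows only the import), and the sampler presumes an initialized torch.distributed process group:

from torch.utils.data import DataLoader

from astrai.dataset import DatasetFactory, ResumableDistributedSampler

# Assumed factory call -- the real signature is not visible in this diff.
dataset = DatasetFactory.create("pretrain")
# Shards the dataset across ranks; assumes torch.distributed is initialized.
sampler = ResumableDistributedSampler(dataset)
loader = DataLoader(dataset, sampler=sampler, batch_size=8)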
@@ -1,7 +1,7 @@
 import numpy as np
 import torch
 
-from astrai.data.dataset import DatasetFactory
+from astrai.dataset.dataset import DatasetFactory
 from astrai.serialization import save_h5
 
 
@@ -1,4 +1,4 @@
-from astrai.data import ResumableDistributedSampler
+from astrai.dataset import ResumableDistributedSampler
 
 
 def test_random_sampler_consistency(random_dataset):
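One thing this commit does not include: a transitional shim at the old path. If code outside the repository still imports astrai.data, a hypothetical astrai/data/__init__.py like the one below could keep it working through a deprecation window. The re-exported names are the ones visible in the hunks above; the shim itself is not part of the diff.

# astrai/data/__init__.py -- hypothetical compatibility shim, not in this commit.
import warnings

from astrai.dataset import (
    BaseDataset,
    DatasetFactory,
    BaseSegmentFetcher,
    MultiSegmentFetcher,
    ResumableDistributedSampler,
)

warnings.warn(
    "astrai.data has been renamed to astrai.dataset; update your imports.",
    DeprecationWarning,
    stacklevel=2,
)

__all__ = [
    "BaseDataset",
    "DatasetFactory",
    "BaseSegmentFetcher",
    "MultiSegmentFetcher",
    "ResumableDistributedSampler",
]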