diff --git a/astrai/__init__.py b/astrai/__init__.py index 65ab6a6..664f129 100644 --- a/astrai/__init__.py +++ b/astrai/__init__.py @@ -6,7 +6,7 @@ from astrai.config import ( TrainConfig, ) from astrai.factory import BaseFactory -from astrai.data import DatasetFactory +from astrai.dataset import DatasetFactory from astrai.tokenizer import BpeTokenizer from astrai.inference.generator import ( BatchGenerator, diff --git a/astrai/data/__init__.py b/astrai/dataset/__init__.py similarity index 74% rename from astrai/data/__init__.py rename to astrai/dataset/__init__.py index d8df820..56735e5 100644 --- a/astrai/data/__init__.py +++ b/astrai/dataset/__init__.py @@ -1,10 +1,10 @@ -from astrai.data.dataset import ( +from astrai.dataset.dataset import ( BaseDataset, DatasetFactory, BaseSegmentFetcher, MultiSegmentFetcher, ) -from astrai.data.sampler import ResumableDistributedSampler +from astrai.dataset.sampler import ResumableDistributedSampler __all__ = [ # Base classes diff --git a/astrai/data/dataset.py b/astrai/dataset/dataset.py similarity index 100% rename from astrai/data/dataset.py rename to astrai/dataset/dataset.py diff --git a/astrai/data/sampler.py b/astrai/dataset/sampler.py similarity index 100% rename from astrai/data/sampler.py rename to astrai/dataset/sampler.py diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index abdb1fd..d689d4d 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -7,7 +7,7 @@ from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import DataLoader from astrai.config.train_config import TrainConfig -from astrai.data import ResumableDistributedSampler +from astrai.dataset import ResumableDistributedSampler from astrai.serialization import Checkpoint from astrai.parallel.setup import get_current_device, get_rank, get_world_size from astrai.trainer.strategy import BaseStrategy, StrategyFactory diff --git a/scripts/tools/train.py b/scripts/tools/train.py index ec14cff..1c8bd0f 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -8,7 +8,7 @@ import torch.optim as optim from torch.nn.parallel import DistributedDataParallel as DDP from astrai.config import ModelParameter, TrainConfig -from astrai.data import DatasetFactory +from astrai.dataset import DatasetFactory from astrai.parallel import get_rank from astrai.trainer import SchedulerFactory, Trainer diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 4b3f4e2..bb1d94c 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -1,7 +1,7 @@ import numpy as np import torch -from astrai.data.dataset import DatasetFactory +from astrai.dataset.dataset import DatasetFactory from astrai.serialization import save_h5 diff --git a/tests/data/test_sampler.py b/tests/data/test_sampler.py index 045a567..1401ecc 100644 --- a/tests/data/test_sampler.py +++ b/tests/data/test_sampler.py @@ -1,4 +1,4 @@ -from astrai.data import ResumableDistributedSampler +from astrai.dataset import ResumableDistributedSampler def test_random_sampler_consistency(random_dataset):