From 9c31d78a2238ec542716df95f82046d96edb9e93 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Sat, 4 Apr 2026 16:16:27 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E5=B0=86data=20=E6=A8=A1=E5=9D=97?= =?UTF-8?q?=E5=91=BD=E5=90=8D=E4=B8=BAdataset?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrai/__init__.py | 2 +- astrai/{data => dataset}/__init__.py | 4 ++-- astrai/{data => dataset}/dataset.py | 0 astrai/{data => dataset}/sampler.py | 0 astrai/trainer/train_context.py | 2 +- scripts/tools/train.py | 2 +- tests/data/test_dataset.py | 2 +- tests/data/test_sampler.py | 2 +- 8 files changed, 7 insertions(+), 7 deletions(-) rename astrai/{data => dataset}/__init__.py (74%) rename astrai/{data => dataset}/dataset.py (100%) rename astrai/{data => dataset}/sampler.py (100%) diff --git a/astrai/__init__.py b/astrai/__init__.py index 65ab6a6..664f129 100644 --- a/astrai/__init__.py +++ b/astrai/__init__.py @@ -6,7 +6,7 @@ from astrai.config import ( TrainConfig, ) from astrai.factory import BaseFactory -from astrai.data import DatasetFactory +from astrai.dataset import DatasetFactory from astrai.tokenizer import BpeTokenizer from astrai.inference.generator import ( BatchGenerator, diff --git a/astrai/data/__init__.py b/astrai/dataset/__init__.py similarity index 74% rename from astrai/data/__init__.py rename to astrai/dataset/__init__.py index d8df820..56735e5 100644 --- a/astrai/data/__init__.py +++ b/astrai/dataset/__init__.py @@ -1,10 +1,10 @@ -from astrai.data.dataset import ( +from astrai.dataset.dataset import ( BaseDataset, DatasetFactory, BaseSegmentFetcher, MultiSegmentFetcher, ) -from astrai.data.sampler import ResumableDistributedSampler +from astrai.dataset.sampler import ResumableDistributedSampler __all__ = [ # Base classes diff --git a/astrai/data/dataset.py b/astrai/dataset/dataset.py similarity index 100% rename from astrai/data/dataset.py rename to astrai/dataset/dataset.py diff --git a/astrai/data/sampler.py b/astrai/dataset/sampler.py similarity index 100% rename from astrai/data/sampler.py rename to astrai/dataset/sampler.py diff --git a/astrai/trainer/train_context.py b/astrai/trainer/train_context.py index abdb1fd..d689d4d 100644 --- a/astrai/trainer/train_context.py +++ b/astrai/trainer/train_context.py @@ -7,7 +7,7 @@ from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import DataLoader from astrai.config.train_config import TrainConfig -from astrai.data import ResumableDistributedSampler +from astrai.dataset import ResumableDistributedSampler from astrai.serialization import Checkpoint from astrai.parallel.setup import get_current_device, get_rank, get_world_size from astrai.trainer.strategy import BaseStrategy, StrategyFactory diff --git a/scripts/tools/train.py b/scripts/tools/train.py index ec14cff..1c8bd0f 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -8,7 +8,7 @@ import torch.optim as optim from torch.nn.parallel import DistributedDataParallel as DDP from astrai.config import ModelParameter, TrainConfig -from astrai.data import DatasetFactory +from astrai.dataset import DatasetFactory from astrai.parallel import get_rank from astrai.trainer import SchedulerFactory, Trainer diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 4b3f4e2..bb1d94c 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -1,7 +1,7 @@ import numpy as np import torch -from astrai.data.dataset import DatasetFactory +from astrai.dataset.dataset import DatasetFactory from astrai.serialization import save_h5 diff --git a/tests/data/test_sampler.py b/tests/data/test_sampler.py index 045a567..1401ecc 100644 --- a/tests/data/test_sampler.py +++ b/tests/data/test_sampler.py @@ -1,4 +1,4 @@ -from astrai.data import ResumableDistributedSampler +from astrai.dataset import ResumableDistributedSampler def test_random_sampler_consistency(random_dataset):