diff --git a/khaosz/trainer/train_context.py b/khaosz/trainer/train_context.py index 90af6c7..ad4bdf4 100644 --- a/khaosz/trainer/train_context.py +++ b/khaosz/trainer/train_context.py @@ -4,7 +4,7 @@ from torch.optim.lr_scheduler import LRScheduler from torch.utils.data import DataLoader from khaosz.data import ResumableDistributedSampler -from khaosz.data.checkpoint import Checkpoint +from khaosz.data.serialization import Checkpoint from khaosz.trainer.strategy import StrategyFactory, BaseStrategy from khaosz.config.train_config import TrainConfig from khaosz.parallel.setup import get_current_device, get_world_size, get_rank diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py index 89faf26..e1155d9 100644 --- a/khaosz/trainer/trainer.py +++ b/khaosz/trainer/trainer.py @@ -10,7 +10,7 @@ from khaosz.trainer.train_callback import ( SchedulerCallback ) from khaosz.trainer.train_context import TrainContext, TrainContextBuilder -from khaosz.data.checkpoint import Checkpoint +from khaosz.data.serialization import Checkpoint from khaosz.parallel.setup import spawn_parallel_fn logger = logging.getLogger(__name__) diff --git a/tests/data/test_checkpoint.py b/tests/data/test_checkpoint.py index 2d264c5..57071e6 100644 --- a/tests/data/test_checkpoint.py +++ b/tests/data/test_checkpoint.py @@ -4,7 +4,7 @@ import torch.distributed as dist from torch.optim import AdamW from torch.optim.lr_scheduler import CosineAnnealingLR -from khaosz.data.checkpoint import Checkpoint +from khaosz.data.serialization import Checkpoint from khaosz.parallel.setup import get_rank, spawn_parallel_fn def test_single_process(): diff --git a/tests/trainer/test_early_stopping.py b/tests/trainer/test_early_stopping.py index 7070bc5..b8893f9 100644 --- a/tests/trainer/test_early_stopping.py +++ b/tests/trainer/test_early_stopping.py @@ -3,7 +3,7 @@ import torch import numpy as np from khaosz.config import * from khaosz.trainer import * -from khaosz.data.checkpoint import Checkpoint +from khaosz.data.serialization import Checkpoint def test_early_stopping_simulation(base_test_env, early_stopping_dataset): """Simulate early stopping behavior"""