refactor(khaosz/trainer): 重构训练器模块的导入路径和文件命名

This commit is contained in:
ViperEkura 2025-09-29 19:35:17 +08:00
parent e467420475
commit 198c1ac55c
4 changed files with 66 additions and 7 deletions

View File

@ -1,4 +1,4 @@
from khaosz.trainer.dataset import DatasetLoader from khaosz.trainer.data_util import DatasetLoader
from khaosz.trainer.trainer import Trainer from khaosz.trainer.trainer import Trainer
from khaosz.trainer.strategy import ( from khaosz.trainer.strategy import (
TrainConfig, TrainConfig,
@ -7,7 +7,7 @@ from khaosz.trainer.strategy import (
StrategyFactory, StrategyFactory,
SchedulerFactory SchedulerFactory
) )
from khaosz.trainer.callback import ( from khaosz.trainer.trainer_callback import (
ProgressBarCallback, ProgressBarCallback,
CheckpointCallback, CheckpointCallback,
TrainerCallback, TrainerCallback,

View File

@ -3,7 +3,7 @@ import bisect
import pickle as pkl import pickle as pkl
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from torch import Tensor from torch import Tensor
from torch.utils.data import Dataset from torch.utils.data import Dataset, Sampler
from typing import Callable, List, Dict, Literal, Union from typing import Callable, List, Dict, Literal, Union
MutiSeg = Dict[str, List[Tensor]] MutiSeg = Dict[str, List[Tensor]]
@ -264,4 +264,58 @@ class DatasetLoader:
dataset = dataset_router[train_type](max_len, device) dataset = dataset_router[train_type](max_len, device)
dataset.load(load_path) dataset.load(load_path)
return dataset return dataset
class RandomSampler(Sampler[int]):
    """Resumable random sampler.

    Yields a seeded random permutation of dataset indices and supports
    mid-epoch checkpointing via ``state_dict`` / ``load_state_dict``:
    indices already yielded before a snapshot are not replayed after a
    restore.
    """

    def __init__(self, data_source, generator=None, seed=42):
        """
        Args:
            data_source: Sized dataset to draw indices for.
            generator: Optional ``torch.Generator``. When omitted, a fresh
                generator seeded with ``seed`` is created.
            seed: Seed for the internally created generator (recorded in
                ``state_dict`` either way).
        """
        self.data_source = data_source
        self.seed = seed
        self.epoch = 0              # number of fully completed epochs
        self.current_index = 0      # cursor into the current permutation
        self._indices = None        # cached permutation for the current epoch
        if generator is None:
            self.generator = torch.Generator()
            self.generator.manual_seed(seed)
        else:
            self.generator = generator

    def _generate_indices(self):
        """Draw a fresh permutation of all dataset indices for this epoch."""
        n = len(self.data_source)
        self._indices = torch.randperm(n, generator=self.generator).tolist()

    def __iter__(self):
        if self._indices is None:
            self._generate_indices()
        # Iterate over the cached permutation, not len(data_source), so a
        # restored snapshot stays consistent with the indices it saved.
        n = len(self._indices)
        for i in range(self.current_index, n):
            # Fix: advance the cursor BEFORE yielding so a state_dict()
            # taken mid-epoch records this item as consumed. The original
            # never updated current_index during iteration, so a restored
            # sampler always replayed the epoch from the beginning.
            self.current_index = i + 1
            yield self._indices[i]
        # Epoch finished cleanly: roll over and invalidate the cache.
        self.epoch += 1
        self.current_index = 0
        self._indices = None

    def __len__(self):
        """Number of indices remaining in the current (possibly resumed) epoch."""
        return len(self.data_source) - self.current_index

    def state_dict(self):
        """Serializable snapshot sufficient to resume mid-epoch."""
        return {
            'epoch': self.epoch,
            'current_index': self.current_index,
            'seed': self.seed,
            'generator_state': self.generator.get_state() if self.generator else None,
            'indices': self._indices
        }

    def load_state_dict(self, state_dict):
        """Restore a snapshot previously produced by ``state_dict()``."""
        self.epoch = state_dict['epoch']
        self.current_index = state_dict['current_index']
        self.seed = state_dict['seed']
        if self.generator and state_dict['generator_state'] is not None:
            self.generator.set_state(state_dict['generator_state'])
        self._indices = state_dict['indices']

View File

@ -1,11 +1,12 @@
import torch import torch
import itertools import itertools
from typing import Optional, List from typing import Optional, List
from torch.utils.data import DataLoader, RandomSampler from torch.utils.data import DataLoader
from khaosz.core import ModelParameter, Checkpoint from khaosz.core import ModelParameter, Checkpoint
from khaosz.trainer.data_util import RandomSampler
from khaosz.trainer.strategy import TrainConfig, ScheduleConfig from khaosz.trainer.strategy import TrainConfig, ScheduleConfig
from khaosz.trainer.callback import ( from khaosz.trainer.trainer_callback import (
TrainerCallback, TrainerCallback,
ProgressBarCallback, ProgressBarCallback,
CheckpointCallback, CheckpointCallback,
@ -42,7 +43,11 @@ class Trainer:
def _create_dataloader(self, start_index: int = 0) -> DataLoader: def _create_dataloader(self, start_index: int = 0) -> DataLoader:
seed = self.train_config.random_seed seed = self.train_config.random_seed
generator = torch.Generator().manual_seed(seed) generator = torch.Generator().manual_seed(seed)
sampler = RandomSampler(self.train_config.dataset, generator=generator) sampler = RandomSampler(
self.train_config.dataset,
generator=generator,
seed=seed
)
dataloader = DataLoader( dataloader = DataLoader(
self.train_config.dataset, self.train_config.dataset,
batch_size=self.train_config.batch_size, batch_size=self.train_config.batch_size,