refactor(khaosz/trainer): 重构训练器模块的导入路径和文件命名
This commit is contained in:
parent
e467420475
commit
198c1ac55c
|
|
@ -1,4 +1,4 @@
|
|||
from khaosz.trainer.dataset import DatasetLoader
|
||||
from khaosz.trainer.data_util import DatasetLoader
|
||||
from khaosz.trainer.trainer import Trainer
|
||||
from khaosz.trainer.strategy import (
|
||||
TrainConfig,
|
||||
|
|
@ -7,7 +7,7 @@ from khaosz.trainer.strategy import (
|
|||
StrategyFactory,
|
||||
SchedulerFactory
|
||||
)
|
||||
from khaosz.trainer.callback import (
|
||||
from khaosz.trainer.trainer_callback import (
|
||||
ProgressBarCallback,
|
||||
CheckpointCallback,
|
||||
TrainerCallback,
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ import bisect
|
|||
import pickle as pkl
|
||||
from abc import ABC, abstractmethod
|
||||
from torch import Tensor
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
from typing import Callable, List, Dict, Literal, Union
|
||||
|
||||
MutiSeg = Dict[str, List[Tensor]]
|
||||
|
|
@ -265,3 +265,57 @@ class DatasetLoader:
|
|||
dataset.load(load_path)
|
||||
|
||||
return dataset
|
||||
|
||||
|
||||
class RandomSampler(Sampler[int]):
    """Resumable random sampler that yields a seeded permutation of indices.

    Unlike ``torch.utils.data.RandomSampler``, this sampler tracks its
    position within the current epoch so training can checkpoint mid-epoch
    and resume the exact same shuffle order via ``state_dict`` /
    ``load_state_dict``.

    Args:
        data_source: Sized dataset whose indices are sampled.
        generator: Optional pre-seeded ``torch.Generator``. When omitted,
            a new generator is created and seeded with ``seed``.
        seed: Seed for the internally created generator (ignored when an
            explicit ``generator`` is supplied).
    """

    def __init__(self, data_source, generator=None, seed: int = 42):
        # NOTE(review): this class uses torch.Generator / torch.randperm;
        # confirm `import torch` exists at the top of data_util.py — the
        # visible diff hunk only shows `from torch import Tensor` and
        # `from torch.utils.data import Dataset, Sampler`.
        self.data_source = data_source
        self.seed = seed
        self.epoch = 0
        # Index of the next permutation slot to yield; advanced during
        # iteration so a mid-epoch state_dict() captures real progress.
        self.current_index = 0
        self._indices = None  # lazily generated permutation for this epoch

        if generator is None:
            self.generator = torch.Generator()
            self.generator.manual_seed(seed)
        else:
            self.generator = generator

    def _generate_indices(self):
        """Draw a fresh random permutation over the whole dataset."""
        n = len(self.data_source)
        self._indices = torch.randperm(n, generator=self.generator).tolist()

    def __iter__(self):
        n = len(self.data_source)

        if self._indices is None:
            self._generate_indices()

        # BUGFIX: advance current_index as items are handed out.
        # Previously it was never updated during iteration, so a mid-epoch
        # state_dict() always recorded position 0 and a resumed run
        # replayed the epoch from the beginning (and __len__ never shrank).
        # Incrementing before the yield counts the yielded item as consumed.
        for i in range(self.current_index, n):
            self.current_index = i + 1
            yield self._indices[i]

        # Epoch exhausted: roll over and force a new permutation next epoch.
        self.epoch += 1
        self.current_index = 0
        self._indices = None

    def __len__(self):
        # Items remaining in the current (possibly resumed) epoch.
        return len(self.data_source) - self.current_index

    def state_dict(self):
        """Serialize sampling progress for deterministic resumption."""
        return {
            'epoch': self.epoch,
            'current_index': self.current_index,
            'seed': self.seed,
            'generator_state': self.generator.get_state() if self.generator else None,
            'indices': self._indices
        }

    def load_state_dict(self, state_dict):
        """Restore progress previously produced by :meth:`state_dict`."""
        self.epoch = state_dict['epoch']
        self.current_index = state_dict['current_index']
        self.seed = state_dict['seed']

        if self.generator and state_dict['generator_state'] is not None:
            self.generator.set_state(state_dict['generator_state'])

        self._indices = state_dict['indices']
|
||||
|
|
@ -1,11 +1,12 @@
|
|||
import torch
|
||||
import itertools
|
||||
from typing import Optional, List
|
||||
from torch.utils.data import DataLoader, RandomSampler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from khaosz.core import ModelParameter, Checkpoint
|
||||
from khaosz.trainer.data_util import RandomSampler
|
||||
from khaosz.trainer.strategy import TrainConfig, ScheduleConfig
|
||||
from khaosz.trainer.callback import (
|
||||
from khaosz.trainer.trainer_callback import (
|
||||
TrainerCallback,
|
||||
ProgressBarCallback,
|
||||
CheckpointCallback,
|
||||
|
|
@ -42,7 +43,11 @@ class Trainer:
|
|||
def _create_dataloader(self, start_index: int = 0) -> DataLoader:
|
||||
seed = self.train_config.random_seed
|
||||
generator = torch.Generator().manual_seed(seed)
|
||||
sampler = RandomSampler(self.train_config.dataset, generator=generator)
|
||||
sampler = RandomSampler(
|
||||
self.train_config.dataset,
|
||||
generator=generator,
|
||||
seed=seed
|
||||
)
|
||||
dataloader = DataLoader(
|
||||
self.train_config.dataset,
|
||||
batch_size=self.train_config.batch_size,
|
||||
|
|
|
|||
Loading…
Reference in New Issue