diff --git a/khaosz/trainer/callback.py b/khaosz/trainer/callback.py
index 6505833..ce3927a 100644
--- a/khaosz/trainer/callback.py
+++ b/khaosz/trainer/callback.py
@@ -154,6 +154,11 @@ class SchedulerCallback(TrainerCallback):
     def on_train_begin(self, trainer: 'Trainer', **kwargs):
         checkpoint = cast(Checkpoint, kwargs.get('checkpoint'))
         self.current_iter = len(checkpoint.loss_list)
+
+        for group in trainer.train_config.optimizer.param_groups:
+            if "initial_lr" not in group:
+                group["initial_lr"] = group["lr"]
+
         self.schedule_config.validate()
         lambda_scheduler_fn = SchedulerFactory.load_schedule_fn(
             self.schedule_config
diff --git a/khaosz/trainer/trainer.py b/khaosz/trainer/trainer.py
index 3d44968..f374a65 100644
--- a/khaosz/trainer/trainer.py
+++ b/khaosz/trainer/trainer.py
@@ -1,4 +1,5 @@
 import torch
+import itertools
 from typing import Optional, List
 from torch.utils.data import DataLoader, RandomSampler
 
@@ -38,15 +39,22 @@ class Trainer:
             SchedulerCallback(self.schedule_config),
         ]
 
-    def _create_dataloader(self) -> DataLoader:
+    def _create_dataloader(self, start_index: int = 0) -> DataLoader:
         seed = self.train_config.random_seed
         generator = torch.Generator().manual_seed(seed)
         sampler = RandomSampler(self.train_config.dataset, generator=generator)
-        return DataLoader(
+        dataloader = DataLoader(
            self.train_config.dataset,
            batch_size=self.train_config.batch_size,
            sampler=sampler
        )
+
+        # the fixed seed replays the same shuffle on every call, so the
+        # skipped prefix matches the batches consumed before the checkpoint
+        if start_index > 0:
+            dataloader = itertools.islice(dataloader, start_index, None)
+
+        return dataloader
 
     def _call_callbacks(self, method_name: str, **kwargs):
         for callback in self.callbacks:
@@ -58,34 +66,31 @@ class Trainer:
     def train(
         self,
         train_checkpoint: Optional[Checkpoint] = None
     ) -> Checkpoint:
-        assert self.schedule_config.schedule_type in ["cosine", "sgdr"]
-
+
         if train_checkpoint:
             self.checkpoint = train_checkpoint
             self.train_config.optimizer.load_state_dict(train_checkpoint.optim_state)
-
-        self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
+        else:
+            self.checkpoint.optim_state = self.train_config.optimizer.state_dict()
+
         current_iter = len(self.checkpoint.loss_list)
-
-        for group in self.train_config.optimizer.param_groups:
-            if "initial_lr" not in group:
-                group["initial_lr"] = group["lr"]
+        total_steps_per_epoch = len(self.train_config.dataset) // self.train_config.batch_size
+        total_remaining_steps = total_steps_per_epoch * self.train_config.n_epoch - current_iter
 
-        reamining_steps = self.train_config.n_epoch - current_iter
-        total_steps = len(self.train_config.dataset) // self.train_config.batch_size
-        remaining_epochs = (reamining_steps + total_steps - 1) // total_steps
+        remaining_epochs = (total_remaining_steps + total_steps_per_epoch - 1) // total_steps_per_epoch
+        consumed_steps = current_iter % total_steps_per_epoch
+
         # train
         self._call_callbacks('on_train_begin', checkpoint=self.checkpoint)
+        self.checkpoint.model.train()
 
         try:
-            for epoch in range(remaining_epochs):
-                self.checkpoint.model.train()
-
+            for epoch in range(remaining_epochs):
                 # epoch
                 self._call_callbacks('on_epoch_begin', epoch=epoch)
-
-                dataloader = self._create_dataloader()
-
+                # only the first resumed epoch skips already-consumed batches
+                start_index = consumed_steps if epoch == 0 else 0
+                dataloader = self._create_dataloader(start_index=start_index)
                 for batch in dataloader:
                     # batch
                     self._call_callbacks('on_batch_begin', batch=batch)
@@ -110,5 +115,4 @@ class Trainer:
 
         finally:
             self._call_callbacks('on_train_end', checkpoint=self.checkpoint)
-
-        return self.checkpoint
\ No newline at end of file
+        return self.checkpoint
\ No newline at end of file
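A note on the `_create_dataloader` change above: slicing the DataLoader only resumes correctly because the sampler's generator is re-seeded with the same `random_seed` on every call, so each run replays an identical shuffle and the skipped prefix is exactly the set of batches the checkpoint already consumed. Below is a minimal, self-contained sketch of that property; the `TensorDataset`, sizes, and seed are illustrative assumptions, not part of the patch.

import itertools
import torch
from torch.utils.data import DataLoader, RandomSampler, TensorDataset

dataset = TensorDataset(torch.arange(8))

def make_loader(start_index: int = 0):
    # re-seeding the generator makes every run shuffle identically
    generator = torch.Generator().manual_seed(42)
    sampler = RandomSampler(dataset, generator=generator)
    loader = DataLoader(dataset, batch_size=2, sampler=sampler)
    # islice drops the batches already consumed before the checkpoint
    return itertools.islice(loader, start_index, None) if start_index > 0 else loader

full = [batch[0].tolist() for batch in make_loader()]
resumed = [batch[0].tolist() for batch in make_loader(start_index=2)]
assert resumed == full[2:]  # the resumed loader yields exactly the unseen batches

One trade-off of `itertools.islice` is that the skipped batches are still loaded and collated before being discarded; that is fine for small datasets, but a custom `Sampler` that drops the consumed indices up front would avoid the wasted work.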
diff --git a/tests/test_trainer.py b/tests/test_trainer.py
index a94b1d2..971aaa5 100644
--- a/tests/test_trainer.py
+++ b/tests/test_trainer.py
@@ -204,9 +204,9 @@ def test_checkpoint_train(test_env):
         dataset=dataset,
         optimizer=optimizer,
         checkpoint_dir=test_env["test_dir"],
-        n_epoch=2,
+        n_epoch=1,
         batch_size=2,
-        checkpoint_interval=5,
+        checkpoint_interval=1,
         accumulation_steps=1,
         max_grad_norm=1.0,
         random_seed=42
@@ -218,13 +218,18 @@ def test_checkpoint_train(test_env):
         pad_token_id=test_env["tokenizer"].pad_id
     )
     schedule_config = CosineScheduleConfig(
-        warmup_steps=100,
-        total_steps=1000
+        warmup_steps=1,
+        total_steps=5
     )
 
     trainer = Trainer(param, train_config, schedule_config)
+    checkpoint = None
+
     try:
-        trainer.train()
+        checkpoint = trainer.train()
     except Exception:
-        checkpoint = trainer.checkpoint
-        trainer.train(train_checkpoint=checkpoint)
\ No newline at end of file
+        pass
+
+    checkpoint = trainer.train(train_checkpoint=checkpoint)
+    assert len(checkpoint.loss_list) == 5 - 1
+
\ No newline at end of file
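For context on the `initial_lr` seeding moved from `Trainer.train` into `SchedulerCallback.on_train_begin`: PyTorch LR schedulers constructed with `last_epoch >= 0` (i.e. resuming mid-schedule) raise a KeyError unless every optimizer param group already contains `"initial_lr"`, a key normally written by the scheduler's own constructor on a fresh run. Here is a minimal sketch of the failure mode and the fix; the toy model, optimizer, and warmup lambda are assumptions for illustration only.

import torch

model = torch.nn.Linear(4, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
resume_step = 10  # e.g. len(checkpoint.loss_list)

# without this loop, the LambdaLR call below raises:
# KeyError: "param 'initial_lr' is not specified in param_groups[0] when resuming an optimizer"
for group in optimizer.param_groups:
    if "initial_lr" not in group:
        group["initial_lr"] = group["lr"]

scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: min(1.0, (step + 1) / 100),  # linear warmup, illustrative
    last_epoch=resume_step - 1,  # pick the schedule up where the checkpoint left off
)
print(scheduler.get_last_lr())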