refactor(khaosz): 重构项目结构
This commit is contained in:
parent
8434c19923
commit
c51b203fde
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from khaosz.core.transformer import TransformerConfig, Transformer
|
from khaosz.model.transformer import TransformerConfig, Transformer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,23 @@
|
||||||
__version__ = "1.3.0"
|
__version__ = "1.3.0"
|
||||||
__author__ = "ViperEkura"
|
__author__ = "ViperEkura"
|
||||||
|
|
||||||
from khaosz.model import Khaosz
|
from khaosz.khaosz import Khaosz
|
||||||
from khaosz.core.transformer import Transformer, TransformerConfig
|
from khaosz.config import (
|
||||||
|
TransformerConfig,
|
||||||
|
ParameterLoader,
|
||||||
|
TrainConfig,
|
||||||
|
)
|
||||||
|
from khaosz.model.transformer import Transformer
|
||||||
from khaosz.utils.retriever import Retriever
|
from khaosz.utils.retriever import Retriever
|
||||||
from khaosz.utils.splitter import (
|
from khaosz.utils.splitter import (
|
||||||
SemanticTextSplitter,
|
SemanticTextSplitter,
|
||||||
PriorityTextSplitter
|
PriorityTextSplitter
|
||||||
)
|
)
|
||||||
from khaosz.core.tokenizer import BpeTokenizer
|
from khaosz.data import (
|
||||||
from khaosz.core.parameter import ParameterLoader
|
DatasetLoader,
|
||||||
from khaosz.core.generator import (
|
BpeTokenizer
|
||||||
|
)
|
||||||
|
from khaosz.inference.generator import (
|
||||||
TextGenerator,
|
TextGenerator,
|
||||||
ChatGenerator,
|
ChatGenerator,
|
||||||
StreamGenerator,
|
StreamGenerator,
|
||||||
|
|
@ -18,10 +25,9 @@ from khaosz.core.generator import (
|
||||||
RetrievalGenerator,
|
RetrievalGenerator,
|
||||||
EmbeddingEncoder
|
EmbeddingEncoder
|
||||||
)
|
)
|
||||||
|
|
||||||
from khaosz.trainer import (
|
from khaosz.trainer import (
|
||||||
Trainer,
|
Trainer,
|
||||||
DatasetLoader,
|
|
||||||
TrainConfig,
|
|
||||||
StrategyFactory,
|
StrategyFactory,
|
||||||
SchedulerFactory
|
SchedulerFactory
|
||||||
)
|
)
|
||||||
|
|
@ -44,7 +50,7 @@ __all__ = [
|
||||||
|
|
||||||
# trainer
|
# trainer
|
||||||
"Trainer",
|
"Trainer",
|
||||||
"DatasetLoader",
|
"DatasetLoader", # 保持在 __all__ 中,但来源是 khaosz.data
|
||||||
"TrainConfig",
|
"TrainConfig",
|
||||||
"StrategyFactory",
|
"StrategyFactory",
|
||||||
"SchedulerFactory",
|
"SchedulerFactory",
|
||||||
|
|
@ -53,4 +59,4 @@ __all__ = [
|
||||||
"Retriever",
|
"Retriever",
|
||||||
"SemanticTextSplitter",
|
"SemanticTextSplitter",
|
||||||
"PriorityTextSplitter",
|
"PriorityTextSplitter",
|
||||||
]
|
]
|
||||||
|
|
@ -0,0 +1,12 @@
|
||||||
|
from khaosz.config.model_config import TransformerConfig
|
||||||
|
from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
|
||||||
|
from khaosz.config.train_config import TrainConfig
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseModelIO",
|
||||||
|
"ModelParameter",
|
||||||
|
"Checkpoint",
|
||||||
|
"ParameterLoader",
|
||||||
|
"TransformerConfig",
|
||||||
|
"TrainConfig"
|
||||||
|
]
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
import json
|
||||||
|
|
||||||
|
from dataclasses import asdict, dataclass
|
||||||
|
from typing import Optional, Self
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class TransformerConfig:
|
||||||
|
# basic config
|
||||||
|
vocab_size: Optional[int] = None
|
||||||
|
n_dim: Optional[int] = None
|
||||||
|
n_head: Optional[int] = None
|
||||||
|
n_layer: Optional[int] = None
|
||||||
|
m_len: Optional[int] = None
|
||||||
|
norm_eps: Optional[float] = None
|
||||||
|
d_ffn: Optional[int] = None
|
||||||
|
|
||||||
|
# GQA
|
||||||
|
n_kvhead: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
def load(self, config_path: str) -> Self:
|
||||||
|
with open(config_path, 'r') as f:
|
||||||
|
config: dict = json.load(f)
|
||||||
|
for key, value in config.items():
|
||||||
|
if hasattr(self, key):
|
||||||
|
setattr(self, key, value)
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
def save(self, config_path: str) -> None:
|
||||||
|
config_dict = asdict(self)
|
||||||
|
config_dict = {k: v for k, v in config_dict.items() if v is not None}
|
||||||
|
with open(config_path, 'w') as f:
|
||||||
|
json.dump(config_dict, f, indent=4)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -8,8 +8,9 @@ from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Optional, Self, Union
|
from typing import Any, Dict, List, Optional, Self, Union
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from khaosz.core.tokenizer import BpeTokenizer
|
from khaosz.data.tokenizer import BpeTokenizer
|
||||||
from khaosz.core.transformer import TransformerConfig, Transformer
|
from khaosz.config.model_config import TransformerConfig
|
||||||
|
from khaosz.model.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
class BaseModelIO:
|
class BaseModelIO:
|
||||||
|
|
@ -99,18 +100,18 @@ class Checkpoint(BaseModelIO):
|
||||||
metadata={"help": "Transformer model."}
|
metadata={"help": "Transformer model."}
|
||||||
)
|
)
|
||||||
tokenizer: BpeTokenizer = field(
|
tokenizer: BpeTokenizer = field(
|
||||||
default_factory=BpeTokenizer,
|
default=None,
|
||||||
metadata={"help": "Tokenizer for the model."}
|
metadata={"help": "Tokenizer for the model."}
|
||||||
)
|
)
|
||||||
config: TransformerConfig = field(
|
config: TransformerConfig = field(
|
||||||
default_factory=TransformerConfig,
|
default=None,
|
||||||
metadata={"help": "Transformer model configuration."}
|
metadata={"help": "Transformer model configuration."}
|
||||||
)
|
)
|
||||||
optimizer_state: Dict[str, Any] = field(
|
optimizer_state: Dict[str, Any] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Optimizer state."}
|
metadata={"help": "Optimizer state."}
|
||||||
)
|
)
|
||||||
sampler_state: Dict[str, Any] = field(
|
scheduler_state: Dict[str, Any] = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Sampler state."}
|
metadata={"help": "Sampler state."}
|
||||||
)
|
)
|
||||||
|
|
@ -142,10 +143,10 @@ class Checkpoint(BaseModelIO):
|
||||||
# Save optimizer state
|
# Save optimizer state
|
||||||
with open(str(paths["optimizer_state"]), "wb") as f:
|
with open(str(paths["optimizer_state"]), "wb") as f:
|
||||||
pkl.dump(self.optimizer_state, f)
|
pkl.dump(self.optimizer_state, f)
|
||||||
|
|
||||||
# Save sampler state
|
# Save sampler state
|
||||||
with open(str(paths["sampler_state"]), "wb") as f:
|
with open(str(paths["sampler_state"]), "wb") as f:
|
||||||
pkl.dump(self.sampler_state, f)
|
pkl.dump(self.scheduler_state, f)
|
||||||
|
|
||||||
def load_training_state(self, load_dir: Union[str, Path]) -> Self:
|
def load_training_state(self, load_dir: Union[str, Path]) -> Self:
|
||||||
paths = self._get_training_paths(load_dir)
|
paths = self._get_training_paths(load_dir)
|
||||||
|
|
@ -163,7 +164,7 @@ class Checkpoint(BaseModelIO):
|
||||||
# Load sampler state
|
# Load sampler state
|
||||||
if paths["sampler_state"].exists():
|
if paths["sampler_state"].exists():
|
||||||
with open(str(paths["sampler_state"]), "rb") as f:
|
with open(str(paths["sampler_state"]), "rb") as f:
|
||||||
self.sampler_state = pkl.load(f)
|
self.scheduler_state = pkl.load(f)
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -173,7 +174,7 @@ class Checkpoint(BaseModelIO):
|
||||||
return
|
return
|
||||||
|
|
||||||
current_iter = len(self.loss_list)
|
current_iter = len(self.loss_list)
|
||||||
|
|
||||||
plt.figure(figsize=(10, 6))
|
plt.figure(figsize=(10, 6))
|
||||||
plt.plot(self.loss_list)
|
plt.plot(self.loss_list)
|
||||||
plt.title(f"Training Loss - Iteration {current_iter}")
|
plt.title(f"Training Loss - Iteration {current_iter}")
|
||||||
|
|
@ -1,14 +1,16 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional
|
from typing import Optional, TYPE_CHECKING
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
from khaosz.trainer.strategy import BaseStrategy
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from khaosz.trainer.strategy import BaseStrategy
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class TrainConfig:
|
class TrainConfig:
|
||||||
|
|
||||||
strategy: BaseStrategy = field(
|
strategy: "BaseStrategy" = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Training strategy."}
|
metadata={"help": "Training strategy."}
|
||||||
)
|
)
|
||||||
|
|
@ -1,27 +0,0 @@
|
||||||
from khaosz.core.tokenizer import BpeTokenizer
|
|
||||||
from khaosz.core.transformer import Transformer, TransformerConfig
|
|
||||||
from khaosz.core.parameter import ParameterLoader, ModelParameter, Checkpoint
|
|
||||||
from khaosz.core.generator import (
|
|
||||||
TextGenerator,
|
|
||||||
ChatGenerator,
|
|
||||||
StreamGenerator,
|
|
||||||
BatchGenerator,
|
|
||||||
RetrievalGenerator,
|
|
||||||
EmbeddingEncoder
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"Transformer",
|
|
||||||
"TransformerConfig",
|
|
||||||
"BpeTokenizer",
|
|
||||||
"ParameterLoader",
|
|
||||||
"ModelParameter",
|
|
||||||
"Checkpoint",
|
|
||||||
"TextGenerator",
|
|
||||||
"ChatGenerator",
|
|
||||||
"StreamGenerator",
|
|
||||||
"BatchGenerator",
|
|
||||||
"RetrievalGenerator",
|
|
||||||
"EmbeddingEncoder"
|
|
||||||
]
|
|
||||||
|
|
@ -0,0 +1,30 @@
|
||||||
|
from khaosz.data.data_util import (
|
||||||
|
BaseDataset,
|
||||||
|
SeqDataset,
|
||||||
|
DpoDataset,
|
||||||
|
SftDataset,
|
||||||
|
PpoDataset,
|
||||||
|
MutiSegmentFetcher,
|
||||||
|
ResumeableRandomSampler,
|
||||||
|
DatasetLoader,
|
||||||
|
load_pkl_files,
|
||||||
|
build_attention_mask,
|
||||||
|
build_loss_mask
|
||||||
|
)
|
||||||
|
|
||||||
|
from khaosz.data.tokenizer import BpeTokenizer
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"BaseDataset",
|
||||||
|
"SeqDataset",
|
||||||
|
"DpoDataset",
|
||||||
|
"SftDataset",
|
||||||
|
"PpoDataset",
|
||||||
|
"MutiSegmentFetcher",
|
||||||
|
"ResumeableRandomSampler",
|
||||||
|
"DatasetLoader",
|
||||||
|
"load_pkl_files",
|
||||||
|
"build_attention_mask",
|
||||||
|
"build_loss_mask",
|
||||||
|
"BpeTokenizer"
|
||||||
|
]
|
||||||
|
|
@ -263,58 +263,39 @@ class DatasetLoader:
|
||||||
dataset.load(load_path)
|
dataset.load(load_path)
|
||||||
|
|
||||||
return dataset
|
return dataset
|
||||||
|
|
||||||
|
|
||||||
class RandomSampler(Sampler[int]):
|
|
||||||
def __init__(self, data_source, generator=None, seed=42):
|
class ResumeableRandomSampler(Sampler[int]):
|
||||||
self.data_source = data_source
|
def __init__(self, data_source, start_epoch=0, start_iter=0, seed=42):
|
||||||
self.seed = seed
|
self.num_samples = len(data_source)
|
||||||
self.epoch = 0
|
self.epoch = start_epoch
|
||||||
self.current_iter = 0
|
self.iter = start_iter
|
||||||
self._indices = None
|
|
||||||
|
|
||||||
if generator is None:
|
generator = torch.Generator()
|
||||||
self.generator = torch.Generator()
|
generator.manual_seed(seed)
|
||||||
self.generator.manual_seed(seed)
|
|
||||||
else:
|
self.generator = generator
|
||||||
self.generator = generator
|
self._indices = None
|
||||||
|
|
||||||
def _generate_indices(self):
|
def _get_indices(self):
|
||||||
n = len(self.data_source)
|
for _ in range(self.epoch):
|
||||||
self._indices = torch.randperm(n, generator=self.generator).tolist()
|
_ = torch.randperm(self.num_samples, generator=self.generator)
|
||||||
|
|
||||||
|
current_epoch_indices = torch.randperm(self.num_samples, generator=self.generator).tolist()
|
||||||
|
self._indices = current_epoch_indices[self.iter % self.num_samples:]
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
n = len(self.data_source)
|
|
||||||
|
|
||||||
if self._indices is None:
|
if self._indices is None:
|
||||||
self._generate_indices()
|
self._get_indices()
|
||||||
|
|
||||||
start = self.current_iter % n
|
for i in self._indices:
|
||||||
for i in range(start, n):
|
self.iter += 1
|
||||||
self.current_iter += 1
|
yield i
|
||||||
yield self._indices[i]
|
|
||||||
|
|
||||||
self.epoch += 1
|
self.epoch += 1
|
||||||
self._indices = None
|
self._indices = None
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.data_source)
|
if self._indices is None:
|
||||||
|
self._get_indices()
|
||||||
def state_dict(self):
|
return len(self._indices)
|
||||||
return {
|
|
||||||
'epoch': self.epoch,
|
|
||||||
'current_iter': self.current_iter,
|
|
||||||
'seed': self.seed,
|
|
||||||
'generator_state': self.generator.get_state() if self.generator else None,
|
|
||||||
'indices': self._indices
|
|
||||||
}
|
|
||||||
|
|
||||||
def load_state_dict(self, state_dict):
|
|
||||||
self.epoch = state_dict['epoch']
|
|
||||||
self.current_iter = state_dict['current_iter']
|
|
||||||
self.seed = state_dict['seed']
|
|
||||||
|
|
||||||
if self.generator and state_dict['generator_state'] is not None:
|
|
||||||
self.generator.set_state(state_dict['generator_state'])
|
|
||||||
|
|
||||||
self._indices = state_dict['indices']
|
|
||||||
|
|
@ -0,0 +1,97 @@
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
from typing import List, Tuple, Union, Optional, Generator, Self
|
||||||
|
from khaosz.config.param_config import ModelParameter
|
||||||
|
|
||||||
|
|
||||||
|
class GeneratorCore:
|
||||||
|
def __init__(self, parameter: ModelParameter):
|
||||||
|
self.model = parameter.model
|
||||||
|
self.tokenizer = parameter.tokenizer
|
||||||
|
self.config = parameter.config
|
||||||
|
|
||||||
|
def compute_logits(
|
||||||
|
self,
|
||||||
|
input_ids: Tensor,
|
||||||
|
attn_mask: Optional[Tensor] = None,
|
||||||
|
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
|
||||||
|
start_pos: int = 0
|
||||||
|
) -> Tuple[Tensor, int]:
|
||||||
|
with torch.inference_mode():
|
||||||
|
outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
|
||||||
|
logits = outputs["logits"][:, -1, :]
|
||||||
|
cache_increase = input_ids.size(-1)
|
||||||
|
|
||||||
|
return logits, cache_increase
|
||||||
|
|
||||||
|
def to(self, *args, **kargs) -> Self:
|
||||||
|
self.model.to(*args, **kargs)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class EmbeddingEncoderCore:
|
||||||
|
def __init__(self, parameter: ModelParameter):
|
||||||
|
self.model = parameter.model
|
||||||
|
self.tokenizer = parameter.tokenizer
|
||||||
|
self.config = parameter.config
|
||||||
|
|
||||||
|
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
||||||
|
with_batch = isinstance(sentence, list)
|
||||||
|
ids = self.tokenizer.encode(sentence)
|
||||||
|
batch_ids = ids if with_batch else [ids]
|
||||||
|
max_model_len = self.config.m_len
|
||||||
|
|
||||||
|
all_fragments = []
|
||||||
|
fragment_origin_idx = []
|
||||||
|
|
||||||
|
for i, seq in enumerate(batch_ids):
|
||||||
|
if len(seq) > max_model_len:
|
||||||
|
fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
|
||||||
|
all_fragments.extend(fragments)
|
||||||
|
fragment_origin_idx.extend([i] * len(fragments))
|
||||||
|
else:
|
||||||
|
all_fragments.append(seq)
|
||||||
|
fragment_origin_idx.append(i)
|
||||||
|
|
||||||
|
#if empty fragments
|
||||||
|
if not all_fragments or not ids:
|
||||||
|
return [] if with_batch else torch.tensor([])
|
||||||
|
|
||||||
|
device = next(self.model.parameters()).device
|
||||||
|
max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
|
||||||
|
|
||||||
|
padded_ids = []
|
||||||
|
masks = []
|
||||||
|
for seq in all_fragments:
|
||||||
|
pad_len = max_len - len(seq)
|
||||||
|
padded_seq = seq + [self.tokenizer.pad_id] * pad_len
|
||||||
|
mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
|
||||||
|
padded_ids.append(padded_seq)
|
||||||
|
masks.append(mask)
|
||||||
|
|
||||||
|
input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
|
||||||
|
seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
|
||||||
|
|
||||||
|
with torch.inference_mode():
|
||||||
|
outputs = self.model(input_tensor, seq_mask)["hidden_states"]
|
||||||
|
# [num_fragments, seq_len, hidden_size]
|
||||||
|
fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
|
||||||
|
|
||||||
|
sentence_embs: List[Tensor] = []
|
||||||
|
for i in range(len(batch_ids)):
|
||||||
|
indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
|
||||||
|
if indices is not None:
|
||||||
|
sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size]
|
||||||
|
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1]
|
||||||
|
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
|
||||||
|
sentence_embs.append(emb.flatten())
|
||||||
|
|
||||||
|
if with_batch:
|
||||||
|
return [emb.flatten() for emb in sentence_embs]
|
||||||
|
else:
|
||||||
|
return sentence_embs[0].flatten()
|
||||||
|
|
||||||
|
def to(self, *args, **kargs) -> Self:
|
||||||
|
self.model.to(*args, **kargs)
|
||||||
|
return self
|
||||||
|
|
@ -1,7 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import List, Tuple, Union, Optional, Generator, Self
|
from typing import List, Tuple, Union, Optional, Generator
|
||||||
from khaosz.core.parameter import ModelParameter
|
from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore
|
||||||
|
from khaosz.config.param_config import ModelParameter
|
||||||
|
|
||||||
|
|
||||||
def build_prompt(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
|
def build_prompt(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
|
||||||
|
|
@ -168,96 +169,6 @@ class KVCacheManager:
|
||||||
return self._seq_mask
|
return self._seq_mask
|
||||||
|
|
||||||
|
|
||||||
class GeneratorCore:
|
|
||||||
def __init__(self, parameter: ModelParameter):
|
|
||||||
self.model = parameter.model
|
|
||||||
self.tokenizer = parameter.tokenizer
|
|
||||||
self.config = parameter.config
|
|
||||||
|
|
||||||
def compute_logits(
|
|
||||||
self,
|
|
||||||
input_ids: Tensor,
|
|
||||||
attn_mask: Optional[Tensor] = None,
|
|
||||||
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
|
|
||||||
start_pos: int = 0
|
|
||||||
) -> Tuple[Tensor, int]:
|
|
||||||
with torch.inference_mode():
|
|
||||||
outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
|
|
||||||
logits = outputs["logits"][:, -1, :]
|
|
||||||
cache_increase = input_ids.size(-1)
|
|
||||||
|
|
||||||
return logits, cache_increase
|
|
||||||
|
|
||||||
def to(self, *args, **kargs) -> Self:
|
|
||||||
self.model.to(*args, **kargs)
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class EmbeddingEncoderCore:
|
|
||||||
def __init__(self, parameter: ModelParameter):
|
|
||||||
self.model = parameter.model
|
|
||||||
self.tokenizer = parameter.tokenizer
|
|
||||||
self.config = parameter.config
|
|
||||||
|
|
||||||
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
|
|
||||||
with_batch = isinstance(sentence, list)
|
|
||||||
ids = self.tokenizer.encode(sentence)
|
|
||||||
batch_ids = ids if with_batch else [ids]
|
|
||||||
max_model_len = self.config.m_len
|
|
||||||
|
|
||||||
all_fragments = []
|
|
||||||
fragment_origin_idx = []
|
|
||||||
|
|
||||||
for i, seq in enumerate(batch_ids):
|
|
||||||
if len(seq) > max_model_len:
|
|
||||||
fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
|
|
||||||
all_fragments.extend(fragments)
|
|
||||||
fragment_origin_idx.extend([i] * len(fragments))
|
|
||||||
else:
|
|
||||||
all_fragments.append(seq)
|
|
||||||
fragment_origin_idx.append(i)
|
|
||||||
|
|
||||||
#if empty fragments
|
|
||||||
if not all_fragments or not ids:
|
|
||||||
return [] if with_batch else torch.tensor([])
|
|
||||||
|
|
||||||
device = next(self.model.parameters()).device
|
|
||||||
max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
|
|
||||||
|
|
||||||
padded_ids = []
|
|
||||||
masks = []
|
|
||||||
for seq in all_fragments:
|
|
||||||
pad_len = max_len - len(seq)
|
|
||||||
padded_seq = seq + [self.tokenizer.pad_id] * pad_len
|
|
||||||
mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
|
|
||||||
padded_ids.append(padded_seq)
|
|
||||||
masks.append(mask)
|
|
||||||
|
|
||||||
input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
|
|
||||||
seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
|
|
||||||
|
|
||||||
with torch.inference_mode():
|
|
||||||
outputs = self.model(input_tensor, seq_mask)["hidden_states"]
|
|
||||||
# [num_fragments, seq_len, hidden_size]
|
|
||||||
fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
|
|
||||||
|
|
||||||
sentence_embs: List[Tensor] = []
|
|
||||||
for i in range(len(batch_ids)):
|
|
||||||
indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
|
|
||||||
if indices is not None:
|
|
||||||
sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size]
|
|
||||||
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1]
|
|
||||||
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
|
|
||||||
sentence_embs.append(emb.flatten())
|
|
||||||
|
|
||||||
if with_batch:
|
|
||||||
return [emb.flatten() for emb in sentence_embs]
|
|
||||||
else:
|
|
||||||
return sentence_embs[0].flatten()
|
|
||||||
|
|
||||||
def to(self, *args, **kargs) -> Self:
|
|
||||||
self.model.to(*args, **kargs)
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
class TextGenerator(GeneratorCore):
|
class TextGenerator(GeneratorCore):
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import List, Tuple, Generator, Union
|
from typing import List, Tuple, Generator, Union
|
||||||
|
|
||||||
from khaosz.core.generator import (
|
from khaosz.inference.generator import (
|
||||||
TextGenerator,
|
TextGenerator,
|
||||||
ChatGenerator,
|
ChatGenerator,
|
||||||
StreamGenerator,
|
StreamGenerator,
|
||||||
|
|
@ -9,7 +9,7 @@ from khaosz.core.generator import (
|
||||||
RetrievalGenerator,
|
RetrievalGenerator,
|
||||||
EmbeddingEncoder
|
EmbeddingEncoder
|
||||||
)
|
)
|
||||||
from khaosz.core.parameter import ParameterLoader
|
from khaosz.config.param_config import ParameterLoader
|
||||||
|
|
||||||
|
|
||||||
class Khaosz:
|
class Khaosz:
|
||||||
|
|
@ -0,0 +1,17 @@
|
||||||
|
from khaosz.model.module import (
|
||||||
|
Linear,
|
||||||
|
RMSNorm,
|
||||||
|
MLP,
|
||||||
|
GQA,
|
||||||
|
DecoderBlock,
|
||||||
|
)
|
||||||
|
from khaosz.model.transformer import Transformer
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"Linear",
|
||||||
|
"RMSNorm",
|
||||||
|
"MLP",
|
||||||
|
"GQA",
|
||||||
|
"DecoderBlock",
|
||||||
|
"Transformer"
|
||||||
|
]
|
||||||
|
|
@ -1,12 +1,11 @@
|
||||||
import json
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn import init
|
from torch.nn import init
|
||||||
from dataclasses import asdict, dataclass
|
from typing import Optional, Tuple
|
||||||
from typing import List, Optional, Self, Tuple
|
|
||||||
|
|
||||||
|
|
||||||
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
|
||||||
|
|
@ -71,89 +70,6 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
|
||||||
|
|
||||||
return x_out.to(dtype)
|
return x_out.to(dtype)
|
||||||
|
|
||||||
def process_attention_mask(
|
|
||||||
seq_mask: Tensor,
|
|
||||||
start_pos: int = 0,
|
|
||||||
seq_len: int = 0,
|
|
||||||
is_causal: bool = False,
|
|
||||||
device: torch.device = "cuda",
|
|
||||||
dtype: torch.dtype = torch.float32
|
|
||||||
) -> Tensor:
|
|
||||||
"""
|
|
||||||
Create attention mask for GQA
|
|
||||||
Args:
|
|
||||||
seq_mask (Tensor): A tensor indicating whether each position is valid or not.
|
|
||||||
start_pos (int): The starting position of the sequence.
|
|
||||||
seq_len (int): The length of the sequence.
|
|
||||||
is_causal (bool): Whether the attention is causal or not.
|
|
||||||
device (torch.device): The device to use.
|
|
||||||
Returns:
|
|
||||||
Tensor: The attention mask tensor.
|
|
||||||
"""
|
|
||||||
|
|
||||||
if seq_mask is None:
|
|
||||||
if start_pos != 0:
|
|
||||||
# for single prompt chat
|
|
||||||
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
|
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if seq_mask.dim() > 2:
|
|
||||||
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
|
|
||||||
# if ndim > 2, it's 4D tensor
|
|
||||||
return seq_mask
|
|
||||||
|
|
||||||
batch_size = seq_mask.size(0)
|
|
||||||
seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
|
|
||||||
# (bsz, start_pos + seq_len)
|
|
||||||
expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len)
|
|
||||||
# (bsz, seq_len, start_pos + seq_len)
|
|
||||||
|
|
||||||
if is_causal:
|
|
||||||
causal_mask = torch.tril(
|
|
||||||
torch.ones((seq_len, start_pos + seq_len), dtype=torch.bool, device=device),
|
|
||||||
diagonal=start_pos
|
|
||||||
)
|
|
||||||
causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, start_pos + seq_len)
|
|
||||||
expanded_mask = expanded_mask & causal_mask
|
|
||||||
|
|
||||||
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
|
|
||||||
attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
|
|
||||||
# (bsz, 1, seq_len, seq_len + start_pos)
|
|
||||||
|
|
||||||
return attention_mask
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TransformerConfig:
|
|
||||||
# basic config
|
|
||||||
vocab_size: Optional[int] = None
|
|
||||||
n_dim: Optional[int] = None
|
|
||||||
n_head: Optional[int] = None
|
|
||||||
n_layer: Optional[int] = None
|
|
||||||
m_len: Optional[int] = None
|
|
||||||
norm_eps: Optional[float] = None
|
|
||||||
d_ffn: Optional[int] = None
|
|
||||||
|
|
||||||
# GQA
|
|
||||||
n_kvhead: Optional[int] = None
|
|
||||||
|
|
||||||
|
|
||||||
def load(self, config_path: str) -> Self:
|
|
||||||
with open(config_path, 'r') as f:
|
|
||||||
config: dict = json.load(f)
|
|
||||||
for key, value in config.items():
|
|
||||||
if hasattr(self, key):
|
|
||||||
setattr(self, key, value)
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def save(self, config_path: str) -> None:
|
|
||||||
config_dict = asdict(self)
|
|
||||||
config_dict = {k: v for k, v in config_dict.items() if v is not None}
|
|
||||||
with open(config_path, 'w') as f:
|
|
||||||
json.dump(config_dict, f, indent=4)
|
|
||||||
|
|
||||||
|
|
||||||
class Linear(nn.Module):
|
class Linear(nn.Module):
|
||||||
def __init__(self, in_dim: int, out_dim: int, bias: bool=False):
|
def __init__(self, in_dim: int, out_dim: int, bias: bool=False):
|
||||||
|
|
@ -287,60 +203,4 @@ class DecoderBlock(nn.Module):
|
||||||
# feed forward
|
# feed forward
|
||||||
x = self.ffn(self.norm_ffn(x)) + x
|
x = self.ffn(self.norm_ffn(x)) + x
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
class Transformer(nn.Module):
|
|
||||||
def __init__(self, config: TransformerConfig):
|
|
||||||
super().__init__()
|
|
||||||
self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
|
|
||||||
self.layers = nn.ModuleList([
|
|
||||||
DecoderBlock(
|
|
||||||
config.n_dim,
|
|
||||||
config.n_head,
|
|
||||||
config.d_ffn,
|
|
||||||
config.n_kvhead,
|
|
||||||
config.norm_eps
|
|
||||||
)
|
|
||||||
for _ in range(config.n_layer)
|
|
||||||
])
|
|
||||||
self.norm = RMSNorm(config.n_dim, config.norm_eps)
|
|
||||||
self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
|
|
||||||
init.normal_(self.embedding, mean=0, std=0.02)
|
|
||||||
|
|
||||||
def forward(
|
|
||||||
self,
|
|
||||||
input_ids: Tensor,
|
|
||||||
input_mask: Optional[Tensor]=None,
|
|
||||||
persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
|
|
||||||
start_pos: int = 0
|
|
||||||
) -> Tensor:
|
|
||||||
assert input_ids.ndim == 2
|
|
||||||
seq_len = input_ids.size(-1)
|
|
||||||
x = F.embedding(input_ids, self.embedding)
|
|
||||||
|
|
||||||
self.freq_cis = self.freq_cis.to(x.device)
|
|
||||||
freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
|
|
||||||
has_kvcache = persistent_key_values is not None
|
|
||||||
|
|
||||||
attn_mask = process_attention_mask(
|
|
||||||
input_mask,
|
|
||||||
start_pos=start_pos,
|
|
||||||
seq_len=seq_len,
|
|
||||||
is_causal=has_kvcache,
|
|
||||||
device=x.device,
|
|
||||||
dtype=x.dtype
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, layer in enumerate(self.layers):
|
|
||||||
kv_cache = persistent_key_values[i] if persistent_key_values else None
|
|
||||||
x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos)
|
|
||||||
|
|
||||||
hidden_states = self.norm(x)
|
|
||||||
logits = F.linear(hidden_states, self.embedding)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"logits": logits,
|
|
||||||
"hidden_states": hidden_states
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,119 @@
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from torch import Tensor
|
||||||
|
from torch.nn import init
|
||||||
|
from typing import List, Optional, Tuple
|
||||||
|
|
||||||
|
from khaosz.config.model_config import TransformerConfig
|
||||||
|
from khaosz.model.module import DecoderBlock, RMSNorm, get_rotary_emb
|
||||||
|
|
||||||
|
|
||||||
|
def process_attention_mask(
|
||||||
|
seq_mask: Tensor,
|
||||||
|
start_pos: int = 0,
|
||||||
|
seq_len: int = 0,
|
||||||
|
is_causal: bool = False,
|
||||||
|
device: torch.device = "cuda",
|
||||||
|
dtype: torch.dtype = torch.float32
|
||||||
|
) -> Tensor:
|
||||||
|
"""
|
||||||
|
Create attention mask for GQA
|
||||||
|
Args:
|
||||||
|
seq_mask (Tensor): A tensor indicating whether each position is valid or not.
|
||||||
|
start_pos (int): The starting position of the sequence.
|
||||||
|
seq_len (int): The length of the sequence.
|
||||||
|
is_causal (bool): Whether the attention is causal or not.
|
||||||
|
device (torch.device): The device to use.
|
||||||
|
Returns:
|
||||||
|
Tensor: The attention mask tensor.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if seq_mask is None:
|
||||||
|
if start_pos != 0:
|
||||||
|
# for single prompt chat
|
||||||
|
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if seq_mask.dim() > 2:
|
||||||
|
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
|
||||||
|
# if ndim > 2, it's 4D tensor
|
||||||
|
return seq_mask
|
||||||
|
|
||||||
|
batch_size = seq_mask.size(0)
|
||||||
|
seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
|
||||||
|
# (bsz, start_pos + seq_len)
|
||||||
|
expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len)
|
||||||
|
# (bsz, seq_len, start_pos + seq_len)
|
||||||
|
|
||||||
|
if is_causal:
|
||||||
|
causal_mask = torch.tril(
|
||||||
|
torch.ones((seq_len, start_pos + seq_len), dtype=torch.bool, device=device),
|
||||||
|
diagonal=start_pos
|
||||||
|
)
|
||||||
|
causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, start_pos + seq_len)
|
||||||
|
expanded_mask = expanded_mask & causal_mask
|
||||||
|
|
||||||
|
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
|
||||||
|
attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
|
||||||
|
# (bsz, 1, seq_len, seq_len + start_pos)
|
||||||
|
|
||||||
|
return attention_mask
|
||||||
|
|
||||||
|
|
||||||
|
class Transformer(nn.Module):
    """Decoder-only transformer with a weight-tied embedding / LM head.

    The token embedding matrix doubles as the output projection
    (``logits = hidden @ embedding.T``), which is why it is stored as a raw
    ``nn.Parameter`` rather than an ``nn.Embedding`` module.
    """

    def __init__(self, config: TransformerConfig):
        super().__init__()
        # Tied weights: used both for input lookup and the final projection.
        self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
        self.layers = nn.ModuleList([
            DecoderBlock(
                config.n_dim,
                config.n_head,
                config.d_ffn,
                config.n_kvhead,
                config.norm_eps
            )
            for _ in range(config.n_layer)
        ])
        self.norm = RMSNorm(config.n_dim, config.norm_eps)
        # Precomputed rotary embedding table covering up to m_len positions.
        self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
        init.normal_(self.embedding, mean=0, std=0.02)

    def forward(
        self,
        input_ids: Tensor,
        input_mask: Optional[Tensor] = None,
        persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]] = None,
        start_pos: int = 0
    ) -> dict:
        """Run the decoder stack over a batch of token ids.

        Args:
            input_ids: ``(batch, seq_len)`` integer token ids.
            input_mask: optional padding mask forwarded to
                ``process_attention_mask``; ``None`` means no padding.
            persistent_key_values: per-layer ``(k, v)`` caches for
                incremental decoding, or ``None`` to run cache-free.
            start_pos: absolute position of ``input_ids[:, 0]``; non-zero
                only when continuing generation with a KV cache.

        Returns:
            dict with keys ``"logits"`` (``(batch, seq_len, vocab_size)``)
            and ``"hidden_states"`` (``(batch, seq_len, n_dim)``).
        """
        assert input_ids.ndim == 2
        seq_len = input_ids.size(-1)
        x = F.embedding(input_ids, self.embedding)

        # Lazily move the rotary table so the module works on any device.
        self.freq_cis = self.freq_cis.to(x.device)
        freqs_cis = self.freq_cis[start_pos:start_pos + seq_len]
        has_kvcache = persistent_key_values is not None

        # NOTE(review): the causal flag is enabled only when a KV cache is
        # present. Confirm the cache-free (training) path applies causal
        # masking elsewhere — otherwise this condition looks inverted.
        attn_mask = process_attention_mask(
            input_mask,
            start_pos=start_pos,
            seq_len=seq_len,
            is_causal=has_kvcache,
            device=x.device,
            dtype=x.dtype
        )

        for i, layer in enumerate(self.layers):
            kv_cache = persistent_key_values[i] if persistent_key_values else None
            x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos)

        hidden_states = self.norm(x)
        # Weight-tied LM head: project back through the embedding matrix.
        logits = F.linear(hidden_states, self.embedding)

        return {
            "logits": logits,
            "hidden_states": hidden_states
        }
||||||
|
|
@ -1,6 +1,4 @@
|
||||||
from khaosz.trainer.data_util import DatasetLoader
|
|
||||||
from khaosz.trainer.trainer import Trainer
|
from khaosz.trainer.trainer import Trainer
|
||||||
from khaosz.trainer.train_config import TrainConfig
|
|
||||||
from khaosz.trainer.strategy import (
|
from khaosz.trainer.strategy import (
|
||||||
CosineScheduleConfig,
|
CosineScheduleConfig,
|
||||||
SgdrScheduleConfig,
|
SgdrScheduleConfig,
|
||||||
|
|
@ -17,19 +15,16 @@ from khaosz.trainer.train_callback import (
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"DatasetLoader",
|
|
||||||
"Trainer",
|
"Trainer",
|
||||||
"TrainConfig",
|
"StrategyFactory",
|
||||||
"CosineScheduleConfig",
|
"CosineScheduleConfig",
|
||||||
"SgdrScheduleConfig",
|
"SgdrScheduleConfig",
|
||||||
"StrategyFactory",
|
|
||||||
"SchedulerFactory",
|
"SchedulerFactory",
|
||||||
|
|
||||||
# callback
|
# callback
|
||||||
"TrainCallback",
|
"TrainCallback",
|
||||||
"ProgressBarCallback",
|
"ProgressBarCallback",
|
||||||
"CheckpointCallback",
|
"CheckpointCallback",
|
||||||
"TrainCallback",
|
|
||||||
"SchedulerCallback",
|
"SchedulerCallback",
|
||||||
"StepMonitorCallback"
|
"StepMonitorCallback"
|
||||||
]
|
]
|
||||||
|
|
@ -106,7 +106,7 @@ class CheckpointCallback(TrainCallback):
|
||||||
|
|
||||||
def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
    """Persist the current optimizer state under ``checkpoint_dir/iter_<n>``.

    Records the iteration of this save in ``self.last_ckpt_iter`` so the
    callback can decide when the next checkpoint is due.
    """
    save_path = os.path.join(
        trainer.train_config.checkpoint_dir, f"iter_{context.current_iter}"
    )
    # NOTE(review): scheduler state is not captured here — confirm whether
    # Checkpoint should also persist context.scheduler.state_dict().
    context.checkpoint.optimizer_state = context.optimizer.state_dict()
    context.checkpoint.save(save_path)
    self.last_ckpt_iter = context.current_iter
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
from dataclasses import dataclass, field, fields
|
from dataclasses import dataclass, field, fields
|
||||||
from typing import Optional, Self, TYPE_CHECKING
|
from typing import Optional, Self, TYPE_CHECKING
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
from khaosz.core.parameter import Checkpoint
|
from khaosz.config.param_config import Checkpoint
|
||||||
from khaosz.trainer.data_util import RandomSampler
|
from khaosz.data.data_util import ResumeableRandomSampler
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from khaosz.trainer.trainer import Trainer
|
from khaosz.trainer.trainer import Trainer
|
||||||
|
|
@ -13,11 +14,11 @@ if TYPE_CHECKING:
|
||||||
class TrainContext:
|
class TrainContext:
|
||||||
dataloader: DataLoader = field(default=None)
|
dataloader: DataLoader = field(default=None)
|
||||||
optimizer: Optimizer = field(default=None)
|
optimizer: Optimizer = field(default=None)
|
||||||
sampler: RandomSampler = field(default=None)
|
scheduler: LRScheduler = field(default=None)
|
||||||
|
checkpoint: Checkpoint = field(default=None)
|
||||||
epoch: int = field(default=0)
|
epoch: int = field(default=0)
|
||||||
current_iter: int = field(default=0)
|
current_iter: int = field(default=0)
|
||||||
loss: float = field(default=0.0)
|
loss: float = field(default=0.0)
|
||||||
checkpoint: Checkpoint = field(default=None)
|
|
||||||
|
|
||||||
def asdict(self) -> dict:
|
def asdict(self) -> dict:
|
||||||
return {field.name: getattr(self, field.name)
|
return {field.name: getattr(self, field.name)
|
||||||
|
|
@ -27,15 +28,7 @@ class TrainContext:
|
||||||
class TrainContextBuilder:
|
class TrainContextBuilder:
|
||||||
def __init__(self, trainer: 'Trainer'):
|
def __init__(self, trainer: 'Trainer'):
|
||||||
self.trainer = trainer
|
self.trainer = trainer
|
||||||
self._context = TrainContext(
|
self._context = TrainContext()
|
||||||
dataloader=None,
|
|
||||||
optimizer=None,
|
|
||||||
sampler=None,
|
|
||||||
epoch=0,
|
|
||||||
current_iter=0,
|
|
||||||
loss=0.0,
|
|
||||||
checkpoint=None
|
|
||||||
)
|
|
||||||
|
|
||||||
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
|
||||||
if checkpoint is None:
|
if checkpoint is None:
|
||||||
|
|
@ -43,32 +36,10 @@ class TrainContextBuilder:
|
||||||
model=self.trainer.parameter.model,
|
model=self.trainer.parameter.model,
|
||||||
tokenizer=self.trainer.parameter.tokenizer,
|
tokenizer=self.trainer.parameter.tokenizer,
|
||||||
config=self.trainer.parameter.config,
|
config=self.trainer.parameter.config,
|
||||||
sampler_state=None,
|
|
||||||
optimizer_state=None,
|
|
||||||
loss_list=[]
|
|
||||||
)
|
)
|
||||||
self._context.checkpoint = checkpoint
|
self._context.checkpoint = checkpoint
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def with_sampler(self) -> Self:
|
|
||||||
seed = self.trainer.train_config.random_seed
|
|
||||||
sampler = RandomSampler(
|
|
||||||
data_source=self.trainer.train_config.dataset,
|
|
||||||
seed=seed
|
|
||||||
)
|
|
||||||
|
|
||||||
if self._context.checkpoint and self._context.checkpoint.sampler_state:
|
|
||||||
sampler.load_state_dict(self._context.checkpoint.sampler_state)
|
|
||||||
|
|
||||||
self._context.sampler = sampler
|
|
||||||
self._context.epoch = sampler.epoch
|
|
||||||
self._context.current_iter = sampler.current_iter
|
|
||||||
|
|
||||||
if self._context.checkpoint:
|
|
||||||
self._context.checkpoint.sampler_state = sampler.state_dict()
|
|
||||||
|
|
||||||
return self
|
|
||||||
|
|
||||||
def with_optimizer(self) -> Self:
|
def with_optimizer(self) -> Self:
|
||||||
optimizer = self.trainer.train_config.optimizer
|
optimizer = self.trainer.train_config.optimizer
|
||||||
|
|
||||||
|
|
@ -82,11 +53,22 @@ class TrainContextBuilder:
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def with_scheduler(self) -> Self:
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
def with_dataloader(self) -> Self:
|
def with_dataloader(self) -> Self:
|
||||||
|
resumeable_sampler = ResumeableRandomSampler(
|
||||||
|
data_source=self.trainer.train_config.dataset,
|
||||||
|
start_epoch=self._context.epoch,
|
||||||
|
start_iter=self._context.current_iter,
|
||||||
|
seed=self.trainer.train_config.random_seed
|
||||||
|
)
|
||||||
|
|
||||||
dataloader = DataLoader(
|
dataloader = DataLoader(
|
||||||
self.trainer.train_config.dataset,
|
self.trainer.train_config.dataset,
|
||||||
batch_size=self.trainer.train_config.batch_size,
|
batch_size=self.trainer.train_config.batch_size,
|
||||||
sampler=self._context.sampler,
|
sampler=resumeable_sampler,
|
||||||
num_workers=self.trainer.train_config.num_workers,
|
num_workers=self.trainer.train_config.num_workers,
|
||||||
pin_memory=self.trainer.train_config.pin_memory,
|
pin_memory=self.trainer.train_config.pin_memory,
|
||||||
prefetch_factor=self.trainer.train_config.prefetch_factor
|
prefetch_factor=self.trainer.train_config.prefetch_factor
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from khaosz.core import ModelParameter, Checkpoint
|
from khaosz.config import ModelParameter, Checkpoint
|
||||||
from khaosz.trainer.strategy import ScheduleConfig
|
from khaosz.trainer.strategy import ScheduleConfig
|
||||||
from khaosz.trainer.train_config import TrainConfig
|
from khaosz.config.train_config import TrainConfig
|
||||||
from khaosz.trainer.train_callback import (
|
from khaosz.trainer.train_callback import (
|
||||||
TrainCallback,
|
TrainCallback,
|
||||||
ProgressBarCallback,
|
ProgressBarCallback,
|
||||||
|
|
@ -39,8 +39,8 @@ class Trainer:
|
||||||
def _build_train_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
def _build_train_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||||
return (TrainContextBuilder(self)
|
return (TrainContextBuilder(self)
|
||||||
.with_checkpoint(checkpoint)
|
.with_checkpoint(checkpoint)
|
||||||
.with_sampler()
|
|
||||||
.with_optimizer()
|
.with_optimizer()
|
||||||
|
.with_scheduler()
|
||||||
.with_dataloader()
|
.with_dataloader()
|
||||||
.build())
|
.build())
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -9,9 +9,11 @@ import pytest
|
||||||
import matplotlib
|
import matplotlib
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from khaosz.core import *
|
from khaosz.config.model_config import TransformerConfig
|
||||||
from khaosz.trainer import *
|
from khaosz.data.data_util import build_attention_mask, build_loss_mask
|
||||||
from khaosz.trainer.data_util import *
|
from khaosz.data.tokenizer import BpeTokenizer
|
||||||
|
from khaosz.model.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
matplotlib.use("Agg")
|
matplotlib.use("Agg")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from khaosz.core import *
|
from khaosz.config import *
|
||||||
from khaosz.trainer import *
|
from khaosz.trainer import *
|
||||||
from khaosz.trainer.data_util import *
|
|
||||||
|
|
||||||
def test_callback_integration(base_test_env, random_dataset):
|
def test_callback_integration(base_test_env, random_dataset):
|
||||||
"""Test that all callbacks are properly integrated"""
|
"""Test that all callbacks are properly integrated"""
|
||||||
|
|
|
||||||
|
|
@ -3,9 +3,9 @@ import torch
|
||||||
import pickle
|
import pickle
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from khaosz.core import *
|
|
||||||
from khaosz.trainer import *
|
from khaosz.trainer import *
|
||||||
from khaosz.trainer.data_util import *
|
from khaosz.data.data_util import *
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_loader_random_paths(base_test_env):
|
def test_dataset_loader_random_paths(base_test_env):
|
||||||
"""Test dataset loader with multiple random paths"""
|
"""Test dataset loader with multiple random paths"""
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
|
from khaosz.config import *
|
||||||
from khaosz.core import *
|
|
||||||
from khaosz.trainer import *
|
from khaosz.trainer import *
|
||||||
from khaosz.trainer.data_util import *
|
|
||||||
|
|
||||||
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
"""Simulate early stopping behavior"""
|
"""Simulate early stopping behavior"""
|
||||||
|
|
|
||||||
|
|
@ -5,8 +5,11 @@ import shutil
|
||||||
import pytest
|
import pytest
|
||||||
import tempfile
|
import tempfile
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
from khaosz.core import *
|
from khaosz.trainer import *
|
||||||
from khaosz.core.generator import EmbeddingEncoderCore, GeneratorCore
|
from khaosz.config import *
|
||||||
|
from khaosz.model import *
|
||||||
|
from khaosz.data import *
|
||||||
|
from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore
|
||||||
from tokenizers import pre_tokenizers
|
from tokenizers import pre_tokenizers
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
||||||
|
|
@ -1,14 +1,13 @@
|
||||||
from khaosz.core import *
|
|
||||||
from khaosz.trainer import *
|
from khaosz.trainer import *
|
||||||
from khaosz.trainer.data_util import *
|
from khaosz.data.data_util import *
|
||||||
|
|
||||||
def test_random_sampler_consistency(random_dataset):
|
def test_random_sampler_consistency(random_dataset):
|
||||||
"""Test RandomSampler produces consistent results with same seed"""
|
"""Test RandomSampler produces consistent results with same seed"""
|
||||||
dataset = random_dataset
|
dataset = random_dataset
|
||||||
|
|
||||||
# Create two samplers with same seed
|
# Create two samplers with same seed
|
||||||
sampler1 = RandomSampler(dataset, seed=42)
|
sampler1 = ResumeableRandomSampler(dataset, seed=42)
|
||||||
sampler2 = RandomSampler(dataset, seed=42)
|
sampler2 = ResumeableRandomSampler(dataset, seed=42)
|
||||||
|
|
||||||
indices1 = list(iter(sampler1))
|
indices1 = list(iter(sampler1))
|
||||||
indices2 = list(iter(sampler2))
|
indices2 = list(iter(sampler2))
|
||||||
|
|
@ -20,8 +19,8 @@ def test_random_sampler_different_seeds(random_dataset):
|
||||||
dataset = random_dataset
|
dataset = random_dataset
|
||||||
|
|
||||||
# Create two samplers with different seeds
|
# Create two samplers with different seeds
|
||||||
sampler1 = RandomSampler(dataset, seed=42)
|
sampler1 = ResumeableRandomSampler(dataset, seed=42)
|
||||||
sampler2 = RandomSampler(dataset, seed=123)
|
sampler2 = ResumeableRandomSampler(dataset, seed=123)
|
||||||
|
|
||||||
indices1 = list(iter(sampler1))
|
indices1 = list(iter(sampler1))
|
||||||
indices2 = list(iter(sampler2))
|
indices2 = list(iter(sampler2))
|
||||||
|
|
@ -29,38 +28,13 @@ def test_random_sampler_different_seeds(random_dataset):
|
||||||
# Very high probability they should be different
|
# Very high probability they should be different
|
||||||
assert indices1 != indices2
|
assert indices1 != indices2
|
||||||
|
|
||||||
def test_sampler_state_persistence(random_dataset):
    """A sampler restored from a snapshot must replay the same continuation."""
    dataset = random_dataset
    total = len(dataset)

    # Draw an initial batch of indices from a fresh sampler.
    sampler = RandomSampler(dataset, seed=42)
    stream = iter(sampler)
    first_batch = [next(stream) for _ in range(min(10, total))]

    # Snapshot the sampler mid-epoch.
    snapshot = sampler.state_dict()

    # Keep drawing from the original sampler past the snapshot point.
    remaining = min(10, total - len(first_batch))
    continued = [next(stream) for _ in range(remaining)]

    # Restore a second sampler from the snapshot and draw the same count.
    restored = RandomSampler(dataset, seed=42)
    restored.load_state_dict(snapshot)
    replay = iter(restored)
    replayed = [next(replay) for _ in range(remaining)]

    assert continued == replayed
|
|
||||||
|
|
||||||
def test_sampler_across_epochs(random_dataset):
|
def test_sampler_across_epochs(random_dataset):
|
||||||
"""Test sampler behavior across multiple epochs"""
|
"""Test sampler behavior across multiple epochs"""
|
||||||
dataset = random_dataset
|
dataset = random_dataset
|
||||||
n = len(dataset)
|
n = len(dataset)
|
||||||
|
|
||||||
sampler = RandomSampler(dataset, seed=42)
|
sampler = ResumeableRandomSampler(dataset, seed=42)
|
||||||
|
|
||||||
# Get indices for first epoch
|
# Get indices for first epoch
|
||||||
epoch1_indices = list(iter(sampler))
|
epoch1_indices = list(iter(sampler))
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,10 @@
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from khaosz.core import *
|
|
||||||
|
from khaosz.config import *
|
||||||
from khaosz.trainer import *
|
from khaosz.trainer import *
|
||||||
from khaosz.trainer.data_util import *
|
from khaosz.data.data_util import *
|
||||||
|
|
||||||
def test_different_batch_sizes(base_test_env, random_dataset):
|
def test_different_batch_sizes(base_test_env, random_dataset):
|
||||||
"""Test training with different batch sizes"""
|
"""Test training with different batch sizes"""
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from khaosz.core import *
|
from khaosz.config import *
|
||||||
from khaosz.trainer import *
|
from khaosz.trainer import *
|
||||||
from khaosz.trainer.data_util import *
|
from khaosz.data.data_util import *
|
||||||
|
|
||||||
def test_multi_turn_training(base_test_env, multi_turn_dataset):
|
def test_multi_turn_training(base_test_env, multi_turn_dataset):
|
||||||
"""Test training with multi-turn conversation data"""
|
"""Test training with multi-turn conversation data"""
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue