refactor(khaosz): restructure the project layout

ViperEkura 2025-10-18 13:56:59 +08:00
parent 8434c19923
commit c51b203fde
28 changed files with 423 additions and 423 deletions


@@ -1,7 +1,7 @@
 import torch
 from typing import Dict, Any
 from dataclasses import dataclass
-from khaosz.core.transformer import TransformerConfig, Transformer
+from khaosz.model.transformer import TransformerConfig, Transformer
 @dataclass


@@ -1,16 +1,23 @@
 __version__ = "1.3.0"
 __author__ = "ViperEkura"
-from khaosz.model import Khaosz
-from khaosz.core.transformer import Transformer, TransformerConfig
+from khaosz.khaosz import Khaosz
+from khaosz.config import (
+    TransformerConfig,
+    ParameterLoader,
+    TrainConfig,
+)
+from khaosz.model.transformer import Transformer
 from khaosz.utils.retriever import Retriever
 from khaosz.utils.splitter import (
     SemanticTextSplitter,
     PriorityTextSplitter
 )
-from khaosz.core.tokenizer import BpeTokenizer
-from khaosz.core.parameter import ParameterLoader
-from khaosz.core.generator import (
+from khaosz.data import (
+    DatasetLoader,
+    BpeTokenizer
+)
+from khaosz.inference.generator import (
     TextGenerator,
     ChatGenerator,
     StreamGenerator,
@@ -18,10 +25,9 @@ from khaosz.core.generator import (
     RetrievalGenerator,
     EmbeddingEncoder
 )
 from khaosz.trainer import (
     Trainer,
-    DatasetLoader,
-    TrainConfig,
     StrategyFactory,
     SchedulerFactory
 )
@@ -44,7 +50,7 @@ __all__ = [
     # trainer
     "Trainer",
-    "DatasetLoader",
+    "DatasetLoader",  # kept in __all__, but now imported from khaosz.data
     "TrainConfig",
     "StrategyFactory",
     "SchedulerFactory",
@@ -53,4 +59,4 @@ __all__ = [
     "Retriever",
     "SemanticTextSplitter",
     "PriorityTextSplitter",
 ]
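For downstream code, the visible effect of the restructuring is a new set of import paths. A minimal before/after sketch, assuming the package is installed and the re-exports above stay stable:

# Old layout: everything lived under khaosz.core
# from khaosz.core.tokenizer import BpeTokenizer
# from khaosz.core.generator import ChatGenerator

# New layout: config / data / model / inference sub-packages
from khaosz.config import TransformerConfig, ParameterLoader, TrainConfig
from khaosz.data import BpeTokenizer, DatasetLoader
from khaosz.model.transformer import Transformer
from khaosz.inference.generator import ChatGenerator

# The top-level re-exports are unchanged, so most callers can keep:
from khaosz import Khaosz, Trainer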

khaosz/config/__init__.py Normal file

@@ -0,0 +1,12 @@
+from khaosz.config.model_config import TransformerConfig
+from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
+from khaosz.config.train_config import TrainConfig
+
+__all__ = [
+    "BaseModelIO",
+    "ModelParameter",
+    "Checkpoint",
+    "ParameterLoader",
+    "TransformerConfig",
+    "TrainConfig"
+]


@@ -0,0 +1,36 @@
+import json
+from dataclasses import asdict, dataclass
+from typing import Optional, Self
+
+
+@dataclass
+class TransformerConfig:
+    # basic config
+    vocab_size: Optional[int] = None
+    n_dim: Optional[int] = None
+    n_head: Optional[int] = None
+    n_layer: Optional[int] = None
+    m_len: Optional[int] = None
+    norm_eps: Optional[float] = None
+    d_ffn: Optional[int] = None
+
+    # GQA
+    n_kvhead: Optional[int] = None
+
+    def load(self, config_path: str) -> Self:
+        with open(config_path, 'r') as f:
+            config: dict = json.load(f)
+        for key, value in config.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+        return self
+
+    def save(self, config_path: str) -> None:
+        config_dict = asdict(self)
+        config_dict = {k: v for k, v in config_dict.items() if v is not None}
+        with open(config_path, 'w') as f:
+            json.dump(config_dict, f, indent=4)
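A minimal round-trip sketch for the extracted config class: save() drops fields that are still None, and load() fills attributes from the JSON and returns self, so the two compose (the file name here is illustrative):

from khaosz.config.model_config import TransformerConfig

config = TransformerConfig(vocab_size=32000, n_dim=768, n_head=12,
                           n_layer=12, m_len=2048, norm_eps=1e-5,
                           d_ffn=3072, n_kvhead=4)
config.save("transformer_config.json")   # None-valued fields are omitted

restored = TransformerConfig().load("transformer_config.json")
assert restored.n_dim == 768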


@@ -8,8 +8,9 @@ from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Self, Union
 from pathlib import Path
-from khaosz.core.tokenizer import BpeTokenizer
-from khaosz.core.transformer import TransformerConfig, Transformer
+from khaosz.data.tokenizer import BpeTokenizer
+from khaosz.config.model_config import TransformerConfig
+from khaosz.model.transformer import Transformer
 class BaseModelIO:
@@ -99,18 +100,18 @@ class Checkpoint(BaseModelIO):
         metadata={"help": "Transformer model."}
     )
     tokenizer: BpeTokenizer = field(
-        default_factory=BpeTokenizer,
+        default=None,
         metadata={"help": "Tokenizer for the model."}
     )
     config: TransformerConfig = field(
-        default_factory=TransformerConfig,
+        default=None,
         metadata={"help": "Transformer model configuration."}
     )
     optimizer_state: Dict[str, Any] = field(
         default=None,
         metadata={"help": "Optimizer state."}
     )
-    sampler_state: Dict[str, Any] = field(
+    scheduler_state: Dict[str, Any] = field(
         default=None,
         metadata={"help": "Sampler state."}
     )
@@ -142,10 +143,10 @@ class Checkpoint(BaseModelIO):
         # Save optimizer state
         with open(str(paths["optimizer_state"]), "wb") as f:
             pkl.dump(self.optimizer_state, f)
         # Save sampler state
         with open(str(paths["sampler_state"]), "wb") as f:
-            pkl.dump(self.sampler_state, f)
+            pkl.dump(self.scheduler_state, f)
     def load_training_state(self, load_dir: Union[str, Path]) -> Self:
         paths = self._get_training_paths(load_dir)
@@ -163,7 +164,7 @@ class Checkpoint(BaseModelIO):
         # Load sampler state
         if paths["sampler_state"].exists():
             with open(str(paths["sampler_state"]), "rb") as f:
-                self.sampler_state = pkl.load(f)
+                self.scheduler_state = pkl.load(f)
         return self
@@ -173,7 +174,7 @@ class Checkpoint(BaseModelIO):
             return
         current_iter = len(self.loss_list)
         plt.figure(figsize=(10, 6))
         plt.plot(self.loss_list)
         plt.title(f"Training Loss - Iteration {current_iter}")


@@ -1,14 +1,16 @@
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, TYPE_CHECKING
 from torch.utils.data import Dataset
 from torch.optim import Optimizer
-from khaosz.trainer.strategy import BaseStrategy
+
+if TYPE_CHECKING:
+    from khaosz.trainer.strategy import BaseStrategy
 @dataclass
 class TrainConfig:
-    strategy: BaseStrategy = field(
+    strategy: "BaseStrategy" = field(
         default=None,
         metadata={"help": "Training strategy."}
     )
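The TYPE_CHECKING guard is the standard way to break an import cycle: the guarded import only runs under static type checkers, never at runtime, which is why the annotation becomes the string "BaseStrategy". A generic sketch of the pattern, using a hypothetical module pair:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # hypothetical module that in turn imports this one; skipped at
    # runtime, so no circular import occurs
    from b import B

def use(value: "B") -> None:   # quoted, since B is undefined at runtime
    print(value)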


@@ -1,27 +0,0 @@
-from khaosz.core.tokenizer import BpeTokenizer
-from khaosz.core.transformer import Transformer, TransformerConfig
-from khaosz.core.parameter import ParameterLoader, ModelParameter, Checkpoint
-from khaosz.core.generator import (
-    TextGenerator,
-    ChatGenerator,
-    StreamGenerator,
-    BatchGenerator,
-    RetrievalGenerator,
-    EmbeddingEncoder
-)
-
-__all__ = [
-    "Transformer",
-    "TransformerConfig",
-    "BpeTokenizer",
-    "ParameterLoader",
-    "ModelParameter",
-    "Checkpoint",
-    "TextGenerator",
-    "ChatGenerator",
-    "StreamGenerator",
-    "BatchGenerator",
-    "RetrievalGenerator",
-    "EmbeddingEncoder"
-]

khaosz/data/__init__.py Normal file

@@ -0,0 +1,30 @@
+from khaosz.data.data_util import (
+    BaseDataset,
+    SeqDataset,
+    DpoDataset,
+    SftDataset,
+    PpoDataset,
+    MutiSegmentFetcher,
+    ResumeableRandomSampler,
+    DatasetLoader,
+    load_pkl_files,
+    build_attention_mask,
+    build_loss_mask
+)
+
+from khaosz.data.tokenizer import BpeTokenizer
+
+__all__ = [
+    "BaseDataset",
+    "SeqDataset",
+    "DpoDataset",
+    "SftDataset",
+    "PpoDataset",
+    "MutiSegmentFetcher",
+    "ResumeableRandomSampler",
+    "DatasetLoader",
+    "load_pkl_files",
+    "build_attention_mask",
+    "build_loss_mask",
+    "BpeTokenizer"
+]


@@ -263,58 +263,39 @@ class DatasetLoader:
         dataset.load(load_path)
         return dataset
-class RandomSampler(Sampler[int]):
-    def __init__(self, data_source, generator=None, seed=42):
-        self.data_source = data_source
-        self.seed = seed
-        self.epoch = 0
-        self.current_iter = 0
-
-        if generator is None:
-            self.generator = torch.Generator()
-            self.generator.manual_seed(seed)
-        else:
-            self.generator = generator
-        self._indices = None
-
-    def _generate_indices(self):
-        n = len(self.data_source)
-        self._indices = torch.randperm(n, generator=self.generator).tolist()
-
-    def __iter__(self):
-        n = len(self.data_source)
-        if self._indices is None:
-            self._generate_indices()
-        start = self.current_iter % n
-        for i in range(start, n):
-            self.current_iter += 1
-            yield self._indices[i]
-        self.epoch += 1
-        self._indices = None
-
-    def __len__(self):
-        return len(self.data_source)
-
-    def state_dict(self):
-        return {
-            'epoch': self.epoch,
-            'current_iter': self.current_iter,
-            'seed': self.seed,
-            'generator_state': self.generator.get_state() if self.generator else None,
-            'indices': self._indices
-        }
-
-    def load_state_dict(self, state_dict):
-        self.epoch = state_dict['epoch']
-        self.current_iter = state_dict['current_iter']
-        self.seed = state_dict['seed']
-        if self.generator and state_dict['generator_state'] is not None:
-            self.generator.set_state(state_dict['generator_state'])
-        self._indices = state_dict['indices']
+class ResumeableRandomSampler(Sampler[int]):
+    def __init__(self, data_source, start_epoch=0, start_iter=0, seed=42):
+        self.num_samples = len(data_source)
+        self.epoch = start_epoch
+        self.iter = start_iter
+
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+        self.generator = generator
+        self._indices = None
+
+    def _get_indices(self):
+        for _ in range(self.epoch):
+            _ = torch.randperm(self.num_samples, generator=self.generator)
+        current_epoch_indices = torch.randperm(self.num_samples, generator=self.generator).tolist()
+        self._indices = current_epoch_indices[self.iter % self.num_samples:]
+
+    def __iter__(self):
+        if self._indices is None:
+            self._get_indices()
+        for i in self._indices:
+            self.iter += 1
+            yield i
+        self.epoch += 1
+        self._indices = None
+
+    def __len__(self):
+        if self._indices is None:
+            self._get_indices()
+        return len(self._indices)
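The new sampler drops the explicit state_dict round-trip in favor of recomputation: with the same seed it replays epoch permutations to advance the generator, then skips the first iter indices, so resuming needs only the two counters. A small sketch of that equivalence (a plain range stands in for the dataset, since only len() is used):

from khaosz.data.data_util import ResumeableRandomSampler

data = range(100)

sampler = ResumeableRandomSampler(data, seed=42)
it = iter(sampler)
consumed = [next(it) for _ in range(30)]   # partway through epoch 0

# Rebuild from just (start_epoch, start_iter): same seed, same permutation,
# minus the already-consumed prefix.
resumed = ResumeableRandomSampler(data, start_epoch=0, start_iter=30, seed=42)
full = ResumeableRandomSampler(data, seed=42)
assert consumed + list(resumed) == list(full)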

khaosz/inference/core.py Normal file

@@ -0,0 +1,97 @@
+import torch
+from torch import Tensor
+from typing import List, Tuple, Union, Optional, Generator, Self
+from khaosz.config.param_config import ModelParameter
+
+
+class GeneratorCore:
+    def __init__(self, parameter: ModelParameter):
+        self.model = parameter.model
+        self.tokenizer = parameter.tokenizer
+        self.config = parameter.config
+
+    def compute_logits(
+        self,
+        input_ids: Tensor,
+        attn_mask: Optional[Tensor] = None,
+        kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
+        start_pos: int = 0
+    ) -> Tuple[Tensor, int]:
+        with torch.inference_mode():
+            outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
+            logits = outputs["logits"][:, -1, :]
+            cache_increase = input_ids.size(-1)
+        return logits, cache_increase
+
+    def to(self, *args, **kargs) -> Self:
+        self.model.to(*args, **kargs)
+        return self
+
+
+class EmbeddingEncoderCore:
+    def __init__(self, parameter: ModelParameter):
+        self.model = parameter.model
+        self.tokenizer = parameter.tokenizer
+        self.config = parameter.config
+
+    def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
+        with_batch = isinstance(sentence, list)
+        ids = self.tokenizer.encode(sentence)
+        batch_ids = ids if with_batch else [ids]
+        max_model_len = self.config.m_len
+
+        all_fragments = []
+        fragment_origin_idx = []
+        for i, seq in enumerate(batch_ids):
+            if len(seq) > max_model_len:
+                fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
+                all_fragments.extend(fragments)
+                fragment_origin_idx.extend([i] * len(fragments))
+            else:
+                all_fragments.append(seq)
+                fragment_origin_idx.append(i)
+
+        # if empty fragments
+        if not all_fragments or not ids:
+            return [] if with_batch else torch.tensor([])
+
+        device = next(self.model.parameters()).device
+        max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
+        padded_ids = []
+        masks = []
+        for seq in all_fragments:
+            pad_len = max_len - len(seq)
+            padded_seq = seq + [self.tokenizer.pad_id] * pad_len
+            mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
+            padded_ids.append(padded_seq)
+            masks.append(mask)
+
+        input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
+        seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
+        with torch.inference_mode():
+            outputs = self.model(input_tensor, seq_mask)["hidden_states"]
+            # [num_fragments, seq_len, hidden_size]
+            fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
+
+        sentence_embs: List[Tensor] = []
+        for i in range(len(batch_ids)):
+            indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
+            if indices is not None:
+                sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1)    # [frags, hidden_size]
+                length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1)  # [frags, 1]
+                emb = torch.sum(sum_frags / length, dim=0)                    # [frags, hidden_size]
+                sentence_embs.append(emb.flatten())
+
+        if with_batch:
+            return [emb.flatten() for emb in sentence_embs]
+        else:
+            return sentence_embs[0].flatten()
+
+    def to(self, *args, **kargs) -> Self:
+        self.model.to(*args, **kargs)
+        return self
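The fragment pooling at the end of encode() can be read in isolation: each fragment is summed over its unpadded tokens, divided by its token count, and the fragment means are summed per original sentence. A self-contained sketch with dummy tensors (shapes match the inline comments above):

import torch

# 3 fragments of length 4, hidden size 8; fragments 0-1 belong to
# sentence 0, fragment 2 to sentence 1
fragment_embs = torch.randn(3, 4, 8)
seq_mask = torch.tensor([[1, 1, 1, 0],
                         [1, 1, 0, 0],
                         [1, 1, 1, 1]], dtype=torch.bool)
fragment_embs = fragment_embs * seq_mask.unsqueeze(-1)   # zero out padding
fragment_origin_idx = [0, 0, 1]

sentence_embs = []
for i in range(2):
    indices = [k for k, orig in enumerate(fragment_origin_idx) if orig == i]
    sum_frags = fragment_embs[indices].sum(dim=1)          # [frags, hidden]
    length = seq_mask[indices].sum(dim=1, keepdim=True)    # [frags, 1]
    sentence_embs.append((sum_frags / length).sum(dim=0))  # [hidden]
assert sentence_embs[0].shape == (8,)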


@@ -1,7 +1,8 @@
 import torch
 from torch import Tensor
-from typing import List, Tuple, Union, Optional, Generator, Self
-from khaosz.core.parameter import ModelParameter
+from typing import List, Tuple, Union, Optional, Generator
+from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore
+from khaosz.config.param_config import ModelParameter
 def build_prompt(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
@@ -168,96 +169,6 @@ class KVCacheManager:
         return self._seq_mask
-class GeneratorCore:
-    def __init__(self, parameter: ModelParameter):
-        self.model = parameter.model
-        self.tokenizer = parameter.tokenizer
-        self.config = parameter.config
-
-    def compute_logits(
-        self,
-        input_ids: Tensor,
-        attn_mask: Optional[Tensor] = None,
-        kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
-        start_pos: int = 0
-    ) -> Tuple[Tensor, int]:
-        with torch.inference_mode():
-            outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
-            logits = outputs["logits"][:, -1, :]
-            cache_increase = input_ids.size(-1)
-        return logits, cache_increase
-
-    def to(self, *args, **kargs) -> Self:
-        self.model.to(*args, **kargs)
-        return self
-
-
-class EmbeddingEncoderCore:
-    def __init__(self, parameter: ModelParameter):
-        self.model = parameter.model
-        self.tokenizer = parameter.tokenizer
-        self.config = parameter.config
-
-    def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
-        with_batch = isinstance(sentence, list)
-        ids = self.tokenizer.encode(sentence)
-        batch_ids = ids if with_batch else [ids]
-        max_model_len = self.config.m_len
-
-        all_fragments = []
-        fragment_origin_idx = []
-        for i, seq in enumerate(batch_ids):
-            if len(seq) > max_model_len:
-                fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
-                all_fragments.extend(fragments)
-                fragment_origin_idx.extend([i] * len(fragments))
-            else:
-                all_fragments.append(seq)
-                fragment_origin_idx.append(i)
-
-        # if empty fragments
-        if not all_fragments or not ids:
-            return [] if with_batch else torch.tensor([])
-
-        device = next(self.model.parameters()).device
-        max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
-        padded_ids = []
-        masks = []
-        for seq in all_fragments:
-            pad_len = max_len - len(seq)
-            padded_seq = seq + [self.tokenizer.pad_id] * pad_len
-            mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
-            padded_ids.append(padded_seq)
-            masks.append(mask)
-
-        input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
-        seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
-        with torch.inference_mode():
-            outputs = self.model(input_tensor, seq_mask)["hidden_states"]
-            # [num_fragments, seq_len, hidden_size]
-            fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
-
-        sentence_embs: List[Tensor] = []
-        for i in range(len(batch_ids)):
-            indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
-            if indices is not None:
-                sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1)    # [frags, hidden_size]
-                length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1)  # [frags, 1]
-                emb = torch.sum(sum_frags / length, dim=0)                    # [frags, hidden_size]
-                sentence_embs.append(emb.flatten())
-
-        if with_batch:
-            return [emb.flatten() for emb in sentence_embs]
-        else:
-            return sentence_embs[0].flatten()
-
-    def to(self, *args, **kargs) -> Self:
-        self.model.to(*args, **kargs)
-        return self
 class TextGenerator(GeneratorCore):


@@ -1,7 +1,7 @@
 from torch import Tensor
 from typing import List, Tuple, Generator, Union
-from khaosz.core.generator import (
+from khaosz.inference.generator import (
     TextGenerator,
     ChatGenerator,
     StreamGenerator,
@@ -9,7 +9,7 @@ from khaosz.core.generator import (
     RetrievalGenerator,
     EmbeddingEncoder
 )
-from khaosz.core.parameter import ParameterLoader
+from khaosz.config.param_config import ParameterLoader
 class Khaosz:

khaosz/model/__init__.py Normal file

@@ -0,0 +1,17 @@
+from khaosz.model.module import (
+    Linear,
+    RMSNorm,
+    MLP,
+    GQA,
+    DecoderBlock,
+)
+from khaosz.model.transformer import Transformer
+
+__all__ = [
+    "Linear",
+    "RMSNorm",
+    "MLP",
+    "GQA",
+    "DecoderBlock",
+    "Transformer"
+]


@@ -1,12 +1,11 @@
-import json
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
 from torch.nn import init
-from dataclasses import asdict, dataclass
-from typing import List, Optional, Self, Tuple
+from typing import Optional, Tuple
 def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
@@ -71,89 +70,6 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
     return x_out.to(dtype)
-def process_attention_mask(
-    seq_mask: Tensor,
-    start_pos: int = 0,
-    seq_len: int = 0,
-    is_causal: bool = False,
-    device: torch.device = "cuda",
-    dtype: torch.dtype = torch.float32
-) -> Tensor:
-    """
-    Create attention mask for GQA
-
-    Args:
-        seq_mask (Tensor): A tensor indicating whether each position is valid or not.
-        start_pos (int): The starting position of the sequence.
-        seq_len (int): The length of the sequence.
-        is_causal (bool): Whether the attention is causal or not.
-        device (torch.device): The device to use.
-
-    Returns:
-        Tensor: The attention mask tensor.
-    """
-    if seq_mask is None:
-        if start_pos != 0:
-            # for single prompt chat
-            seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
-        else:
-            return None
-
-    if seq_mask.dim() > 2:
-        # shape (bsz, seq_len) or (bsz, n_heads, seq_len, seq_len + start_pos)
-        # if ndim > 2, it's 4D tensor
-        return seq_mask
-
-    batch_size = seq_mask.size(0)
-    seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
-    # (bsz, start_pos + seq_len)
-    expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len)
-    # (bsz, seq_len, start_pos + seq_len)
-
-    if is_causal:
-        causal_mask = torch.tril(
-            torch.ones((seq_len, start_pos + seq_len), dtype=torch.bool, device=device),
-            diagonal=start_pos
-        )
-        causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, start_pos + seq_len)
-        expanded_mask = expanded_mask & causal_mask
-
-    attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
-    attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
-    # (bsz, 1, seq_len, seq_len + start_pos)
-    return attention_mask
-
-
-@dataclass
-class TransformerConfig:
-    # basic config
-    vocab_size: Optional[int] = None
-    n_dim: Optional[int] = None
-    n_head: Optional[int] = None
-    n_layer: Optional[int] = None
-    m_len: Optional[int] = None
-    norm_eps: Optional[float] = None
-    d_ffn: Optional[int] = None
-
-    # GQA
-    n_kvhead: Optional[int] = None
-
-    def load(self, config_path: str) -> Self:
-        with open(config_path, 'r') as f:
-            config: dict = json.load(f)
-        for key, value in config.items():
-            if hasattr(self, key):
-                setattr(self, key, value)
-        return self
-
-    def save(self, config_path: str) -> None:
-        config_dict = asdict(self)
-        config_dict = {k: v for k, v in config_dict.items() if v is not None}
-        with open(config_path, 'w') as f:
-            json.dump(config_dict, f, indent=4)
 class Linear(nn.Module):
     def __init__(self, in_dim: int, out_dim: int, bias: bool=False):
@@ -287,60 +203,4 @@ class DecoderBlock(nn.Module):
         # feed forward
         x = self.ffn(self.norm_ffn(x)) + x
         return x
-class Transformer(nn.Module):
-    def __init__(self, config: TransformerConfig):
-        super().__init__()
-        self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
-        self.layers = nn.ModuleList([
-            DecoderBlock(
-                config.n_dim,
-                config.n_head,
-                config.d_ffn,
-                config.n_kvhead,
-                config.norm_eps
-            )
-            for _ in range(config.n_layer)
-        ])
-        self.norm = RMSNorm(config.n_dim, config.norm_eps)
-        self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
-        init.normal_(self.embedding, mean=0, std=0.02)
-
-    def forward(
-        self,
-        input_ids: Tensor,
-        input_mask: Optional[Tensor]=None,
-        persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
-        start_pos: int = 0
-    ) -> Tensor:
-        assert input_ids.ndim == 2
-        seq_len = input_ids.size(-1)
-        x = F.embedding(input_ids, self.embedding)
-        self.freq_cis = self.freq_cis.to(x.device)
-        freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
-        has_kvcache = persistent_key_values is not None
-        attn_mask = process_attention_mask(
-            input_mask,
-            start_pos=start_pos,
-            seq_len=seq_len,
-            is_causal=has_kvcache,
-            device=x.device,
-            dtype=x.dtype
-        )
-        for i, layer in enumerate(self.layers):
-            kv_cache = persistent_key_values[i] if persistent_key_values else None
-            x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos)
-        hidden_states = self.norm(x)
-        logits = F.linear(hidden_states, self.embedding)
-        return {
-            "logits": logits,
-            "hidden_states": hidden_states
-        }

khaosz/model/transformer.py Normal file

@@ -0,0 +1,119 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import init
+from typing import List, Optional, Tuple
+from khaosz.config.model_config import TransformerConfig
+from khaosz.model.module import DecoderBlock, RMSNorm, get_rotary_emb
+
+
+def process_attention_mask(
+    seq_mask: Tensor,
+    start_pos: int = 0,
+    seq_len: int = 0,
+    is_causal: bool = False,
+    device: torch.device = "cuda",
+    dtype: torch.dtype = torch.float32
+) -> Tensor:
+    """
+    Create attention mask for GQA
+
+    Args:
+        seq_mask (Tensor): A tensor indicating whether each position is valid or not.
+        start_pos (int): The starting position of the sequence.
+        seq_len (int): The length of the sequence.
+        is_causal (bool): Whether the attention is causal or not.
+        device (torch.device): The device to use.
+
+    Returns:
+        Tensor: The attention mask tensor.
+    """
+    if seq_mask is None:
+        if start_pos != 0:
+            # for single prompt chat
+            seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
+        else:
+            return None
+
+    if seq_mask.dim() > 2:
+        # shape (bsz, seq_len) or (bsz, n_heads, seq_len, seq_len + start_pos)
+        # if ndim > 2, it's 4D tensor
+        return seq_mask
+
+    batch_size = seq_mask.size(0)
+    seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
+    # (bsz, start_pos + seq_len)
+    expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len)
+    # (bsz, seq_len, start_pos + seq_len)
+
+    if is_causal:
+        causal_mask = torch.tril(
+            torch.ones((seq_len, start_pos + seq_len), dtype=torch.bool, device=device),
+            diagonal=start_pos
+        )
+        causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, start_pos + seq_len)
+        expanded_mask = expanded_mask & causal_mask
+
+    attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
+    attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
+    # (bsz, 1, seq_len, seq_len + start_pos)
+    return attention_mask
+
+
+class Transformer(nn.Module):
+    def __init__(self, config: TransformerConfig):
+        super().__init__()
+        self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
+        self.layers = nn.ModuleList([
+            DecoderBlock(
+                config.n_dim,
+                config.n_head,
+                config.d_ffn,
+                config.n_kvhead,
+                config.norm_eps
+            )
+            for _ in range(config.n_layer)
+        ])
+        self.norm = RMSNorm(config.n_dim, config.norm_eps)
+        self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
+        init.normal_(self.embedding, mean=0, std=0.02)
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        input_mask: Optional[Tensor]=None,
+        persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
+        start_pos: int = 0
+    ) -> Tensor:
+        assert input_ids.ndim == 2
+        seq_len = input_ids.size(-1)
+        x = F.embedding(input_ids, self.embedding)
+        self.freq_cis = self.freq_cis.to(x.device)
+        freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
+        has_kvcache = persistent_key_values is not None
+        attn_mask = process_attention_mask(
+            input_mask,
+            start_pos=start_pos,
+            seq_len=seq_len,
+            is_causal=has_kvcache,
+            device=x.device,
+            dtype=x.dtype
+        )
+        for i, layer in enumerate(self.layers):
+            kv_cache = persistent_key_values[i] if persistent_key_values else None
+            x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos)
+        hidden_states = self.norm(x)
+        logits = F.linear(hidden_states, self.embedding)
+        return {
+            "logits": logits,
+            "hidden_states": hidden_states
+        }
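A quick shape check for the relocated mask helper: with a two-token prefill already in cache (start_pos=2) and three new tokens, the causal branch yields an additive float mask of shape (bsz, 1, seq_len, start_pos + seq_len), zero where attention is allowed and a large negative value elsewhere (CPU chosen here instead of the "cuda" default for portability):

import torch
from khaosz.model.transformer import process_attention_mask

seq_mask = torch.ones(2, 5, dtype=torch.bool)   # bsz=2, start_pos + seq_len = 5
mask = process_attention_mask(
    seq_mask, start_pos=2, seq_len=3, is_causal=True,
    device=torch.device("cpu"), dtype=torch.float32
)
print(mask.shape)    # torch.Size([2, 1, 3, 5])
print(mask[0, 0])    # row i attends positions 0..i+2; the rest are masked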


@@ -1,6 +1,4 @@
-from khaosz.trainer.data_util import DatasetLoader
 from khaosz.trainer.trainer import Trainer
-from khaosz.trainer.train_config import TrainConfig
 from khaosz.trainer.strategy import (
     CosineScheduleConfig,
     SgdrScheduleConfig,
@@ -17,19 +15,16 @@ from khaosz.trainer.train_callback import (
 )
 __all__ = [
-    "DatasetLoader",
     "Trainer",
-    "TrainConfig",
+    "StrategyFactory",
     "CosineScheduleConfig",
     "SgdrScheduleConfig",
-    "StrategyFactory",
     "SchedulerFactory",
     # callback
     "TrainCallback",
     "ProgressBarCallback",
     "CheckpointCallback",
-    "TrainCallback",
     "SchedulerCallback",
     "StepMonitorCallback"
 ]


@@ -106,7 +106,7 @@ class CheckpointCallback(TrainCallback):
     def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
         save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.current_iter}")
-        context.checkpoint.sampler_state = context.sampler.state_dict()
+        # context.checkpoint.scheduler_state = context.sampler.state_dict()
         context.checkpoint.optimizer_state = context.optimizer.state_dict()
         context.checkpoint.save(save_path)
         self.last_ckpt_iter = context.current_iter


@@ -1,9 +1,10 @@
 from dataclasses import dataclass, field, fields
 from typing import Optional, Self, TYPE_CHECKING
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
 from torch.utils.data import DataLoader
-from khaosz.core.parameter import Checkpoint
-from khaosz.trainer.data_util import RandomSampler
+from khaosz.config.param_config import Checkpoint
+from khaosz.data.data_util import ResumeableRandomSampler
 if TYPE_CHECKING:
     from khaosz.trainer.trainer import Trainer
@@ -13,11 +14,11 @@ if TYPE_CHECKING:
 class TrainContext:
     dataloader: DataLoader = field(default=None)
     optimizer: Optimizer = field(default=None)
-    sampler: RandomSampler = field(default=None)
-    checkpoint: Checkpoint = field(default=None)
+    scheduler: LRScheduler = field(default=None)
     epoch: int = field(default=0)
     current_iter: int = field(default=0)
     loss: float = field(default=0.0)
+    checkpoint: Checkpoint = field(default=None)
     def asdict(self) -> dict:
         return {field.name: getattr(self, field.name)
@@ -27,15 +28,7 @@ class TrainContext:
 class TrainContextBuilder:
     def __init__(self, trainer: 'Trainer'):
         self.trainer = trainer
-        self._context = TrainContext(
-            dataloader=None,
-            optimizer=None,
-            sampler=None,
-            epoch=0,
-            current_iter=0,
-            loss=0.0,
-            checkpoint=None
-        )
+        self._context = TrainContext()
     def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
         if checkpoint is None:
@@ -43,32 +36,10 @@ class TrainContextBuilder:
                 model=self.trainer.parameter.model,
                 tokenizer=self.trainer.parameter.tokenizer,
                 config=self.trainer.parameter.config,
-                sampler_state=None,
-                optimizer_state=None,
-                loss_list=[]
             )
         self._context.checkpoint = checkpoint
         return self
-    def with_sampler(self) -> Self:
-        seed = self.trainer.train_config.random_seed
-        sampler = RandomSampler(
-            data_source=self.trainer.train_config.dataset,
-            seed=seed
-        )
-        if self._context.checkpoint and self._context.checkpoint.sampler_state:
-            sampler.load_state_dict(self._context.checkpoint.sampler_state)
-
-        self._context.sampler = sampler
-        self._context.epoch = sampler.epoch
-        self._context.current_iter = sampler.current_iter
-
-        if self._context.checkpoint:
-            self._context.checkpoint.sampler_state = sampler.state_dict()
-        return self
     def with_optimizer(self) -> Self:
         optimizer = self.trainer.train_config.optimizer
@@ -82,11 +53,22 @@ class TrainContextBuilder:
         return self
+    def with_scheduler(self) -> Self:
+        return self
     def with_dataloader(self) -> Self:
+        resumeable_sampler = ResumeableRandomSampler(
+            data_source=self.trainer.train_config.dataset,
+            start_epoch=self._context.epoch,
+            start_iter=self._context.current_iter,
+            seed=self.trainer.train_config.random_seed
+        )
         dataloader = DataLoader(
             self.trainer.train_config.dataset,
             batch_size=self.trainer.train_config.batch_size,
-            sampler=self._context.sampler,
+            sampler=resumeable_sampler,
             num_workers=self.trainer.train_config.num_workers,
             pin_memory=self.trainer.train_config.pin_memory,
             prefetch_factor=self.trainer.train_config.prefetch_factor


@@ -1,9 +1,9 @@
 import logging
 from typing import Optional, List
-from khaosz.core import ModelParameter, Checkpoint
+from khaosz.config import ModelParameter, Checkpoint
 from khaosz.trainer.strategy import ScheduleConfig
-from khaosz.trainer.train_config import TrainConfig
+from khaosz.config.train_config import TrainConfig
 from khaosz.trainer.train_callback import (
     TrainCallback,
     ProgressBarCallback,
@@ -39,8 +39,8 @@ class Trainer:
     def _build_train_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
         return (TrainContextBuilder(self)
                 .with_checkpoint(checkpoint)
-                .with_sampler()
                 .with_optimizer()
+                .with_scheduler()
                 .with_dataloader()
                 .build())


@@ -9,9 +9,11 @@ import pytest
 import matplotlib
 from torch.utils.data import Dataset
-from khaosz.core import *
-from khaosz.trainer import *
-from khaosz.trainer.data_util import *
+from khaosz.config.model_config import TransformerConfig
+from khaosz.data.data_util import build_attention_mask, build_loss_mask
+from khaosz.data.tokenizer import BpeTokenizer
+from khaosz.model.transformer import Transformer
 matplotlib.use("Agg")


@@ -1,8 +1,7 @@
 import torch
-from khaosz.core import *
+from khaosz.config import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
 def test_callback_integration(base_test_env, random_dataset):
     """Test that all callbacks are properly integrated"""


@@ -3,9 +3,9 @@ import torch
 import pickle
 import numpy as np
-from khaosz.core import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
+from khaosz.data.data_util import *
 def test_dataset_loader_random_paths(base_test_env):
     """Test dataset loader with multiple random paths"""


@@ -1,8 +1,7 @@
 import torch
-from khaosz.core import *
+from khaosz.config import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
 def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
     """Simulate early stopping behavior"""
"""Simulate early stopping behavior""" """Simulate early stopping behavior"""


@@ -5,8 +5,11 @@ import shutil
 import pytest
 import tempfile
 import safetensors.torch as st
-from khaosz.core import *
-from khaosz.core.generator import EmbeddingEncoderCore, GeneratorCore
+from khaosz.trainer import *
+from khaosz.config import *
+from khaosz.model import *
+from khaosz.data import *
+from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore
 from tokenizers import pre_tokenizers
 @pytest.fixture


@@ -1,14 +1,13 @@
-from khaosz.core import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
+from khaosz.data.data_util import *
 def test_random_sampler_consistency(random_dataset):
     """Test RandomSampler produces consistent results with same seed"""
     dataset = random_dataset
     # Create two samplers with same seed
-    sampler1 = RandomSampler(dataset, seed=42)
-    sampler2 = RandomSampler(dataset, seed=42)
+    sampler1 = ResumeableRandomSampler(dataset, seed=42)
+    sampler2 = ResumeableRandomSampler(dataset, seed=42)
     indices1 = list(iter(sampler1))
     indices2 = list(iter(sampler2))
@@ -20,8 +19,8 @@ def test_random_sampler_different_seeds(random_dataset):
     dataset = random_dataset
     # Create two samplers with different seeds
-    sampler1 = RandomSampler(dataset, seed=42)
-    sampler2 = RandomSampler(dataset, seed=123)
+    sampler1 = ResumeableRandomSampler(dataset, seed=42)
+    sampler2 = ResumeableRandomSampler(dataset, seed=123)
     indices1 = list(iter(sampler1))
     indices2 = list(iter(sampler2))
@@ -29,38 +28,13 @@ def test_random_sampler_different_seeds(random_dataset):
     # Very high probability they should be different
     assert indices1 != indices2
-def test_sampler_state_persistence(random_dataset):
-    """Test that sampler state is correctly saved and loaded"""
-    dataset = random_dataset
-    n = len(dataset)
-    # Create sampler and get some indices
-    sampler = RandomSampler(dataset, seed=42)
-    iter1 = iter(sampler)
-    indices1 = [next(iter1) for _ in range(min(10, n))]
-    # Save state
-    state_dict = sampler.state_dict()
-    # Get more indices
-    indices2 = [next(iter1) for _ in range(min(10, n - len(indices1)))]
-    # Create new sampler and load state
-    sampler2 = RandomSampler(dataset, seed=42)
-    sampler2.load_state_dict(state_dict)
-    # Check that new sampler produces same sequence from saved point
-    iter2 = iter(sampler2)
-    indices3 = [next(iter2) for _ in range(min(10, n - len(indices1)))]
-    assert indices2 == indices3
 def test_sampler_across_epochs(random_dataset):
     """Test sampler behavior across multiple epochs"""
     dataset = random_dataset
     n = len(dataset)
-    sampler = RandomSampler(dataset, seed=42)
+    sampler = ResumeableRandomSampler(dataset, seed=42)
     # Get indices for first epoch
     epoch1_indices = list(iter(sampler))


@@ -1,9 +1,10 @@
 import torch
 import numpy as np
-from khaosz.core import *
+from khaosz.config import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
+from khaosz.data.data_util import *
 def test_different_batch_sizes(base_test_env, random_dataset):
     """Test training with different batch sizes"""


@@ -1,9 +1,9 @@
 import torch
 import numpy as np
-from khaosz.core import *
+from khaosz.config import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
+from khaosz.data.data_util import *
 def test_multi_turn_training(base_test_env, multi_turn_dataset):
     """Test training with multi-turn conversation data"""