refactor(khaosz): 重构项目结构

This commit is contained in:
ViperEkura 2025-10-18 13:56:59 +08:00
parent 8434c19923
commit c51b203fde
28 changed files with 423 additions and 423 deletions

View File

@ -1,7 +1,7 @@
import torch
from typing import Dict, Any
from dataclasses import dataclass
from khaosz.core.transformer import TransformerConfig, Transformer
from khaosz.model.transformer import TransformerConfig, Transformer
@dataclass

View File

@ -1,16 +1,23 @@
__version__ = "1.3.0"
__author__ = "ViperEkura"
from khaosz.model import Khaosz
from khaosz.core.transformer import Transformer, TransformerConfig
from khaosz.khaosz import Khaosz
from khaosz.config import (
TransformerConfig,
ParameterLoader,
TrainConfig,
)
from khaosz.model.transformer import Transformer
from khaosz.utils.retriever import Retriever
from khaosz.utils.splitter import (
SemanticTextSplitter,
PriorityTextSplitter
)
from khaosz.core.tokenizer import BpeTokenizer
from khaosz.core.parameter import ParameterLoader
from khaosz.core.generator import (
from khaosz.data import (
DatasetLoader,
BpeTokenizer
)
from khaosz.inference.generator import (
TextGenerator,
ChatGenerator,
StreamGenerator,
@ -18,10 +25,9 @@ from khaosz.core.generator import (
RetrievalGenerator,
EmbeddingEncoder
)
from khaosz.trainer import (
Trainer,
DatasetLoader,
TrainConfig,
StrategyFactory,
SchedulerFactory
)
@ -44,7 +50,7 @@ __all__ = [
# trainer
"Trainer",
"DatasetLoader",
"DatasetLoader",  # kept in __all__, but now sourced from khaosz.data
"TrainConfig",
"StrategyFactory",
"SchedulerFactory",

12
khaosz/config/__init__.py Normal file
View File

@ -0,0 +1,12 @@
from khaosz.config.model_config import TransformerConfig
from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
from khaosz.config.train_config import TrainConfig
__all__ = [
"BaseModelIO",
"ModelParameter",
"Checkpoint",
"ParameterLoader",
"TransformerConfig",
"TrainConfig"
]

View File

@ -0,0 +1,36 @@
import json
from dataclasses import asdict, dataclass
from typing import Optional, Self
@dataclass
class TransformerConfig:
# basic config
vocab_size: Optional[int] = None
n_dim: Optional[int] = None
n_head: Optional[int] = None
n_layer: Optional[int] = None
m_len: Optional[int] = None
norm_eps: Optional[float] = None
d_ffn: Optional[int] = None
# GQA
n_kvhead: Optional[int] = None
def load(self, config_path: str) -> Self:
with open(config_path, 'r') as f:
config: dict = json.load(f)
for key, value in config.items():
if hasattr(self, key):
setattr(self, key, value)
return self
def save(self, config_path: str) -> None:
config_dict = asdict(self)
config_dict = {k: v for k, v in config_dict.items() if v is not None}
with open(config_path, 'w') as f:
json.dump(config_dict, f, indent=4)

View File

@ -8,8 +8,9 @@ from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Self, Union
from pathlib import Path
from khaosz.core.tokenizer import BpeTokenizer
from khaosz.core.transformer import TransformerConfig, Transformer
from khaosz.data.tokenizer import BpeTokenizer
from khaosz.config.model_config import TransformerConfig
from khaosz.model.transformer import Transformer
class BaseModelIO:
@ -99,18 +100,18 @@ class Checkpoint(BaseModelIO):
metadata={"help": "Transformer model."}
)
tokenizer: BpeTokenizer = field(
default_factory=BpeTokenizer,
default=None,
metadata={"help": "Tokenizer for the model."}
)
config: TransformerConfig = field(
default_factory=TransformerConfig,
default=None,
metadata={"help": "Transformer model configuration."}
)
optimizer_state: Dict[str, Any] = field(
default=None,
metadata={"help": "Optimizer state."}
)
sampler_state: Dict[str, Any] = field(
scheduler_state: Dict[str, Any] = field(
default=None,
metadata={"help": "Sampler state."}
)
@ -145,7 +146,7 @@ class Checkpoint(BaseModelIO):
# Save sampler state
with open(str(paths["sampler_state"]), "wb") as f:
pkl.dump(self.sampler_state, f)
pkl.dump(self.scheduler_state, f)
def load_training_state(self, load_dir: Union[str, Path]) -> Self:
paths = self._get_training_paths(load_dir)
@ -163,7 +164,7 @@ class Checkpoint(BaseModelIO):
# Load sampler state
if paths["sampler_state"].exists():
with open(str(paths["sampler_state"]), "rb") as f:
self.sampler_state = pkl.load(f)
self.scheduler_state = pkl.load(f)
return self

View File

@ -1,14 +1,16 @@
from dataclasses import dataclass, field
from typing import Optional
from typing import Optional, TYPE_CHECKING
from torch.utils.data import Dataset
from torch.optim import Optimizer
if TYPE_CHECKING:
from khaosz.trainer.strategy import BaseStrategy
@dataclass
class TrainConfig:
strategy: BaseStrategy = field(
strategy: "BaseStrategy" = field(
default=None,
metadata={"help": "Training strategy."}
)

View File

@ -1,27 +0,0 @@
from khaosz.core.tokenizer import BpeTokenizer
from khaosz.core.transformer import Transformer, TransformerConfig
from khaosz.core.parameter import ParameterLoader, ModelParameter, Checkpoint
from khaosz.core.generator import (
TextGenerator,
ChatGenerator,
StreamGenerator,
BatchGenerator,
RetrievalGenerator,
EmbeddingEncoder
)
__all__ = [
"Transformer",
"TransformerConfig",
"BpeTokenizer",
"ParameterLoader",
"ModelParameter",
"Checkpoint",
"TextGenerator",
"ChatGenerator",
"StreamGenerator",
"BatchGenerator",
"RetrievalGenerator",
"EmbeddingEncoder"
]

30
khaosz/data/__init__.py Normal file
View File

@ -0,0 +1,30 @@
from khaosz.data.data_util import (
BaseDataset,
SeqDataset,
DpoDataset,
SftDataset,
PpoDataset,
MutiSegmentFetcher,
ResumeableRandomSampler,
DatasetLoader,
load_pkl_files,
build_attention_mask,
build_loss_mask
)
from khaosz.data.tokenizer import BpeTokenizer
__all__ = [
"BaseDataset",
"SeqDataset",
"DpoDataset",
"SftDataset",
"PpoDataset",
"MutiSegmentFetcher",
"ResumeableRandomSampler",
"DatasetLoader",
"load_pkl_files",
"build_attention_mask",
"build_loss_mask",
"BpeTokenizer"
]

View File

@ -265,56 +265,37 @@ class DatasetLoader:
return dataset
class RandomSampler(Sampler[int]):
def __init__(self, data_source, generator=None, seed=42):
self.data_source = data_source
self.seed = seed
self.epoch = 0
self.current_iter = 0
class ResumeableRandomSampler(Sampler[int]):
def __init__(self, data_source, start_epoch=0, start_iter=0, seed=42):
self.num_samples = len(data_source)
self.epoch = start_epoch
self.iter = start_iter
generator = torch.Generator()
generator.manual_seed(seed)
self.generator = generator
self._indices = None
if generator is None:
self.generator = torch.Generator()
self.generator.manual_seed(seed)
else:
self.generator = generator
def _get_indices(self):
for _ in range(self.epoch):
_ = torch.randperm(self.num_samples, generator=self.generator)
def _generate_indices(self):
n = len(self.data_source)
self._indices = torch.randperm(n, generator=self.generator).tolist()
current_epoch_indices = torch.randperm(self.num_samples, generator=self.generator).tolist()
self._indices = current_epoch_indices[self.iter % self.num_samples:]
def __iter__(self):
n = len(self.data_source)
if self._indices is None:
self._generate_indices()
self._get_indices()
start = self.current_iter % n
for i in range(start, n):
self.current_iter += 1
yield self._indices[i]
for i in self._indices:
self.iter += 1
yield i
self.epoch += 1
self._indices = None
def __len__(self):
return len(self.data_source)
def state_dict(self):
return {
'epoch': self.epoch,
'current_iter': self.current_iter,
'seed': self.seed,
'generator_state': self.generator.get_state() if self.generator else None,
'indices': self._indices
}
def load_state_dict(self, state_dict):
self.epoch = state_dict['epoch']
self.current_iter = state_dict['current_iter']
self.seed = state_dict['seed']
if self.generator and state_dict['generator_state'] is not None:
self.generator.set_state(state_dict['generator_state'])
self._indices = state_dict['indices']
if self._indices is None:
self._get_indices()
return len(self._indices)

97
khaosz/inference/core.py Normal file
View File

@ -0,0 +1,97 @@
import torch
from torch import Tensor
from typing import List, Tuple, Union, Optional, Generator, Self
from khaosz.config.param_config import ModelParameter
class GeneratorCore:
def __init__(self, parameter: ModelParameter):
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
def compute_logits(
self,
input_ids: Tensor,
attn_mask: Optional[Tensor] = None,
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
start_pos: int = 0
) -> Tuple[Tensor, int]:
with torch.inference_mode():
outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
logits = outputs["logits"][:, -1, :]
cache_increase = input_ids.size(-1)
return logits, cache_increase
def to(self, *args, **kargs) -> Self:
self.model.to(*args, **kargs)
return self
class EmbeddingEncoderCore:
def __init__(self, parameter: ModelParameter):
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
with_batch = isinstance(sentence, list)
ids = self.tokenizer.encode(sentence)
batch_ids = ids if with_batch else [ids]
max_model_len = self.config.m_len
all_fragments = []
fragment_origin_idx = []
for i, seq in enumerate(batch_ids):
if len(seq) > max_model_len:
fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
all_fragments.extend(fragments)
fragment_origin_idx.extend([i] * len(fragments))
else:
all_fragments.append(seq)
fragment_origin_idx.append(i)
#if empty fragments
if not all_fragments or not ids:
return [] if with_batch else torch.tensor([])
device = next(self.model.parameters()).device
max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
padded_ids = []
masks = []
for seq in all_fragments:
pad_len = max_len - len(seq)
padded_seq = seq + [self.tokenizer.pad_id] * pad_len
mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
padded_ids.append(padded_seq)
masks.append(mask)
input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
with torch.inference_mode():
outputs = self.model(input_tensor, seq_mask)["hidden_states"]
# [num_fragments, seq_len, hidden_size]
fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
sentence_embs: List[Tensor] = []
for i in range(len(batch_ids)):
indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
if indices is not None:
sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size]
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1]
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
sentence_embs.append(emb.flatten())
if with_batch:
return [emb.flatten() for emb in sentence_embs]
else:
return sentence_embs[0].flatten()
def to(self, *args, **kargs) -> Self:
self.model.to(*args, **kargs)
return self

View File

@ -1,7 +1,8 @@
import torch
from torch import Tensor
from typing import List, Tuple, Union, Optional, Generator, Self
from khaosz.core.parameter import ModelParameter
from typing import List, Tuple, Union, Optional, Generator
from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore
from khaosz.config.param_config import ModelParameter
def build_prompt(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
@ -168,96 +169,6 @@ class KVCacheManager:
return self._seq_mask
class GeneratorCore:
def __init__(self, parameter: ModelParameter):
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
def compute_logits(
self,
input_ids: Tensor,
attn_mask: Optional[Tensor] = None,
kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
start_pos: int = 0
) -> Tuple[Tensor, int]:
with torch.inference_mode():
outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
logits = outputs["logits"][:, -1, :]
cache_increase = input_ids.size(-1)
return logits, cache_increase
def to(self, *args, **kargs) -> Self:
self.model.to(*args, **kargs)
return self
class EmbeddingEncoderCore:
def __init__(self, parameter: ModelParameter):
self.model = parameter.model
self.tokenizer = parameter.tokenizer
self.config = parameter.config
def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
with_batch = isinstance(sentence, list)
ids = self.tokenizer.encode(sentence)
batch_ids = ids if with_batch else [ids]
max_model_len = self.config.m_len
all_fragments = []
fragment_origin_idx = []
for i, seq in enumerate(batch_ids):
if len(seq) > max_model_len:
fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
all_fragments.extend(fragments)
fragment_origin_idx.extend([i] * len(fragments))
else:
all_fragments.append(seq)
fragment_origin_idx.append(i)
#if empty fragments
if not all_fragments or not ids:
return [] if with_batch else torch.tensor([])
device = next(self.model.parameters()).device
max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
padded_ids = []
masks = []
for seq in all_fragments:
pad_len = max_len - len(seq)
padded_seq = seq + [self.tokenizer.pad_id] * pad_len
mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
padded_ids.append(padded_seq)
masks.append(mask)
input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
with torch.inference_mode():
outputs = self.model(input_tensor, seq_mask)["hidden_states"]
# [num_fragments, seq_len, hidden_size]
fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
sentence_embs: List[Tensor] = []
for i in range(len(batch_ids)):
indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
if indices is not None:
sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1) # [frags, hidden_size]
length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1) # [frags, 1]
emb = torch.sum(sum_frags / length, dim=0) # [frags, hidden_size]
sentence_embs.append(emb.flatten())
if with_batch:
return [emb.flatten() for emb in sentence_embs]
else:
return sentence_embs[0].flatten()
def to(self, *args, **kargs) -> Self:
self.model.to(*args, **kargs)
return self
class TextGenerator(GeneratorCore):

View File

@ -1,7 +1,7 @@
from torch import Tensor
from typing import List, Tuple, Generator, Union
from khaosz.core.generator import (
from khaosz.inference.generator import (
TextGenerator,
ChatGenerator,
StreamGenerator,
@ -9,7 +9,7 @@ from khaosz.core.generator import (
RetrievalGenerator,
EmbeddingEncoder
)
from khaosz.core.parameter import ParameterLoader
from khaosz.config.param_config import ParameterLoader
class Khaosz:

17
khaosz/model/__init__.py Normal file
View File

@ -0,0 +1,17 @@
from khaosz.model.module import (
Linear,
RMSNorm,
MLP,
GQA,
DecoderBlock,
)
from khaosz.model.transformer import Transformer
__all__ = [
"Linear",
"RMSNorm",
"MLP",
"GQA",
"DecoderBlock",
"Transformer"
]

View File

@ -1,12 +1,11 @@
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import init
from dataclasses import asdict, dataclass
from typing import List, Optional, Self, Tuple
from typing import Optional, Tuple
def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
@ -71,89 +70,6 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
return x_out.to(dtype)
def process_attention_mask(
seq_mask: Tensor,
start_pos: int = 0,
seq_len: int = 0,
is_causal: bool = False,
device: torch.device = "cuda",
dtype: torch.dtype = torch.float32
) -> Tensor:
"""
Create attention mask for GQA
Args:
seq_mask (Tensor): A tensor indicating whether each position is valid or not.
start_pos (int): The starting position of the sequence.
seq_len (int): The length of the sequence.
is_causal (bool): Whether the attention is causal or not.
device (torch.device): The device to use.
Returns:
Tensor: The attention mask tensor.
"""
if seq_mask is None:
if start_pos != 0:
# for single prompt chat
seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
else:
return None
if seq_mask.dim() > 2:
# shape (bsz, seq_len) or (bsz,n_heads, seq_len, seq_len + start_pos)
# if ndim > 2, it's 4D tensor
return seq_mask
batch_size = seq_mask.size(0)
seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
# (bsz, start_pos + seq_len)
expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len)
# (bsz, seq_len, start_pos + seq_len)
if is_causal:
causal_mask = torch.tril(
torch.ones((seq_len, start_pos + seq_len), dtype=torch.bool, device=device),
diagonal=start_pos
)
causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, start_pos + seq_len)
expanded_mask = expanded_mask & causal_mask
attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
# (bsz, 1, seq_len, seq_len + start_pos)
return attention_mask
@dataclass
class TransformerConfig:
# basic config
vocab_size: Optional[int] = None
n_dim: Optional[int] = None
n_head: Optional[int] = None
n_layer: Optional[int] = None
m_len: Optional[int] = None
norm_eps: Optional[float] = None
d_ffn: Optional[int] = None
# GQA
n_kvhead: Optional[int] = None
def load(self, config_path: str) -> Self:
with open(config_path, 'r') as f:
config: dict = json.load(f)
for key, value in config.items():
if hasattr(self, key):
setattr(self, key, value)
return self
def save(self, config_path: str) -> None:
config_dict = asdict(self)
config_dict = {k: v for k, v in config_dict.items() if v is not None}
with open(config_path, 'w') as f:
json.dump(config_dict, f, indent=4)
class Linear(nn.Module):
def __init__(self, in_dim: int, out_dim: int, bias: bool=False):
@ -288,59 +204,3 @@ class DecoderBlock(nn.Module):
x = self.ffn(self.norm_ffn(x)) + x
return x
class Transformer(nn.Module):
def __init__(self, config: TransformerConfig):
super().__init__()
self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
self.layers = nn.ModuleList([
DecoderBlock(
config.n_dim,
config.n_head,
config.d_ffn,
config.n_kvhead,
config.norm_eps
)
for _ in range(config.n_layer)
])
self.norm = RMSNorm(config.n_dim, config.norm_eps)
self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
init.normal_(self.embedding, mean=0, std=0.02)
def forward(
self,
input_ids: Tensor,
input_mask: Optional[Tensor]=None,
persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
start_pos: int = 0
) -> Tensor:
assert input_ids.ndim == 2
seq_len = input_ids.size(-1)
x = F.embedding(input_ids, self.embedding)
self.freq_cis = self.freq_cis.to(x.device)
freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
has_kvcache = persistent_key_values is not None
attn_mask = process_attention_mask(
input_mask,
start_pos=start_pos,
seq_len=seq_len,
is_causal=has_kvcache,
device=x.device,
dtype=x.dtype
)
for i, layer in enumerate(self.layers):
kv_cache = persistent_key_values[i] if persistent_key_values else None
x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos)
hidden_states = self.norm(x)
logits = F.linear(hidden_states, self.embedding)
return {
"logits": logits,
"hidden_states": hidden_states
}

119
khaosz/model/transformer.py Normal file
View File

@ -0,0 +1,119 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from torch.nn import init
from typing import List, Optional, Tuple
from khaosz.config.model_config import TransformerConfig
from khaosz.model.module import DecoderBlock, RMSNorm, get_rotary_emb
def process_attention_mask(
    seq_mask: Tensor,
    start_pos: int = 0,
    seq_len: int = 0,
    is_causal: bool = False,
    device: torch.device = "cuda",
    dtype: torch.dtype = torch.float32
) -> Tensor:
    """
    Build the additive attention mask consumed by the attention layers (GQA).

    Args:
        seq_mask (Tensor): Boolean per-position validity mask, or None.
        start_pos (int): Offset of the current chunk within the full sequence.
        seq_len (int): Number of query positions in the current chunk.
        is_causal (bool): Whether to also apply a lower-triangular constraint.
        device (torch.device): Device on which to build the mask.
        dtype (torch.dtype): Floating dtype of the returned additive mask.

    Returns:
        Tensor: Additive mask of shape (bsz, 1, seq_len, start_pos + seq_len);
        the input unchanged if it is already >2-D; or None when no mask applies.
    """
    if seq_mask is None:
        if start_pos == 0:
            return None
        # single-prompt chat with a KV cache: all positions are valid
        seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)

    # already a 4-D (bsz, n_heads, seq_len, total) mask — pass through as-is
    if seq_mask.dim() > 2:
        return seq_mask

    total_len = start_pos + seq_len
    bsz = seq_mask.size(0)
    valid = seq_mask[:, :total_len].to(device=device, dtype=torch.bool)
    # (bsz, total_len) -> (bsz, seq_len, total_len)
    keep = valid.unsqueeze(1).expand(bsz, seq_len, total_len)

    if is_causal:
        tri = torch.tril(
            torch.ones((seq_len, total_len), dtype=torch.bool, device=device),
            diagonal=start_pos
        )
        keep = keep & tri.unsqueeze(0).expand(bsz, seq_len, total_len)

    additive = torch.zeros_like(keep, dtype=dtype, device=device)
    # half the dtype's max avoids overflow when masks are summed downstream
    additive = additive.masked_fill_(~keep, -torch.finfo(dtype).max / 2).unsqueeze(1)
    # (bsz, 1, seq_len, start_pos + seq_len)
    return additive
class Transformer(nn.Module):
    """Decoder-only transformer with tied input/output embeddings."""

    def __init__(self, config: TransformerConfig):
        super().__init__()
        # token embedding matrix; reused as the output projection (weight tying)
        self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
        self.layers = nn.ModuleList([
            DecoderBlock(
                config.n_dim,
                config.n_head,
                config.d_ffn,
                config.n_kvhead,
                config.norm_eps
            )
            for _ in range(config.n_layer)
        ])
        self.norm = RMSNorm(config.n_dim, config.norm_eps)
        # precomputed rotary-embedding table, one entry per position up to m_len
        self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
        init.normal_(self.embedding, mean=0, std=0.02)

    def forward(
        self,
        input_ids: Tensor,
        input_mask: Optional[Tensor]=None,
        persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]]=None,
        start_pos: int = 0
    ) -> Tensor:
        """Run the decoder stack over one chunk of token ids.

        Args:
            input_ids: Token ids of shape (batch, seq_len).
            input_mask: Optional validity mask; see ``process_attention_mask``.
            persistent_key_values: Optional per-layer KV caches for
                incremental decoding (one (k, v) pair per layer).
            start_pos: Position offset of this chunk when using KV caches.

        Returns:
            A dict with "logits" and "hidden_states".
            NOTE(review): annotated ``-> Tensor`` but actually returns a dict —
            consider fixing the annotation upstream.
        """
        assert input_ids.ndim == 2
        seq_len = input_ids.size(-1)
        x = F.embedding(input_ids, self.embedding)
        # lazily migrate the cached rotary table to the input's device
        self.freq_cis = self.freq_cis.to(x.device)
        freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
        has_kvcache = persistent_key_values is not None
        # causal masking is only enabled on the KV-cache (incremental) path
        attn_mask = process_attention_mask(
            input_mask,
            start_pos=start_pos,
            seq_len=seq_len,
            is_causal=has_kvcache,
            device=x.device,
            dtype=x.dtype
        )
        for i, layer in enumerate(self.layers):
            kv_cache = persistent_key_values[i] if persistent_key_values else None
            x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos)
        hidden_states = self.norm(x)
        # weight-tied output projection: logits share the embedding matrix
        logits = F.linear(hidden_states, self.embedding)
        return {
            "logits": logits,
            "hidden_states": hidden_states
        }

View File

@ -1,6 +1,4 @@
from khaosz.trainer.data_util import DatasetLoader
from khaosz.trainer.trainer import Trainer
from khaosz.trainer.train_config import TrainConfig
from khaosz.trainer.strategy import (
CosineScheduleConfig,
SgdrScheduleConfig,
@ -17,19 +15,16 @@ from khaosz.trainer.train_callback import (
)
__all__ = [
"DatasetLoader",
"Trainer",
"TrainConfig",
"StrategyFactory",
"CosineScheduleConfig",
"SgdrScheduleConfig",
"StrategyFactory",
"SchedulerFactory",
# callback
"TrainCallback",
"ProgressBarCallback",
"CheckpointCallback",
"TrainCallback",
"SchedulerCallback",
"StepMonitorCallback"
]

View File

@ -106,7 +106,7 @@ class CheckpointCallback(TrainCallback):
def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.current_iter}")
context.checkpoint.sampler_state = context.sampler.state_dict()
# context.checkpoint.scheduler_state = context.sampler.state_dict()
context.checkpoint.optimizer_state = context.optimizer.state_dict()
context.checkpoint.save(save_path)
self.last_ckpt_iter = context.current_iter

View File

@ -1,9 +1,10 @@
from dataclasses import dataclass, field, fields
from typing import Optional, Self, TYPE_CHECKING
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from torch.utils.data import DataLoader
from khaosz.core.parameter import Checkpoint
from khaosz.trainer.data_util import RandomSampler
from khaosz.config.param_config import Checkpoint
from khaosz.data.data_util import ResumeableRandomSampler
if TYPE_CHECKING:
from khaosz.trainer.trainer import Trainer
@ -13,11 +14,11 @@ if TYPE_CHECKING:
class TrainContext:
dataloader: DataLoader = field(default=None)
optimizer: Optimizer = field(default=None)
sampler: RandomSampler = field(default=None)
scheduler: LRScheduler = field(default=None)
checkpoint: Checkpoint = field(default=None)
epoch: int = field(default=0)
current_iter: int = field(default=0)
loss: float = field(default=0.0)
checkpoint: Checkpoint = field(default=None)
def asdict(self) -> dict:
return {field.name: getattr(self, field.name)
@ -27,15 +28,7 @@ class TrainContext:
class TrainContextBuilder:
def __init__(self, trainer: 'Trainer'):
self.trainer = trainer
self._context = TrainContext(
dataloader=None,
optimizer=None,
sampler=None,
epoch=0,
current_iter=0,
loss=0.0,
checkpoint=None
)
self._context = TrainContext()
def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
if checkpoint is None:
@ -43,32 +36,10 @@ class TrainContextBuilder:
model=self.trainer.parameter.model,
tokenizer=self.trainer.parameter.tokenizer,
config=self.trainer.parameter.config,
sampler_state=None,
optimizer_state=None,
loss_list=[]
)
self._context.checkpoint = checkpoint
return self
def with_sampler(self) -> Self:
seed = self.trainer.train_config.random_seed
sampler = RandomSampler(
data_source=self.trainer.train_config.dataset,
seed=seed
)
if self._context.checkpoint and self._context.checkpoint.sampler_state:
sampler.load_state_dict(self._context.checkpoint.sampler_state)
self._context.sampler = sampler
self._context.epoch = sampler.epoch
self._context.current_iter = sampler.current_iter
if self._context.checkpoint:
self._context.checkpoint.sampler_state = sampler.state_dict()
return self
def with_optimizer(self) -> Self:
optimizer = self.trainer.train_config.optimizer
@ -82,11 +53,22 @@ class TrainContextBuilder:
return self
def with_scheduler(self) -> Self:
return self
def with_dataloader(self) -> Self:
resumeable_sampler = ResumeableRandomSampler(
data_source=self.trainer.train_config.dataset,
start_epoch=self._context.epoch,
start_iter=self._context.current_iter,
seed=self.trainer.train_config.random_seed
)
dataloader = DataLoader(
self.trainer.train_config.dataset,
batch_size=self.trainer.train_config.batch_size,
sampler=self._context.sampler,
sampler=resumeable_sampler,
num_workers=self.trainer.train_config.num_workers,
pin_memory=self.trainer.train_config.pin_memory,
prefetch_factor=self.trainer.train_config.prefetch_factor

View File

@ -1,9 +1,9 @@
import logging
from typing import Optional, List
from khaosz.core import ModelParameter, Checkpoint
from khaosz.config import ModelParameter, Checkpoint
from khaosz.trainer.strategy import ScheduleConfig
from khaosz.trainer.train_config import TrainConfig
from khaosz.config.train_config import TrainConfig
from khaosz.trainer.train_callback import (
TrainCallback,
ProgressBarCallback,
@ -39,8 +39,8 @@ class Trainer:
def _build_train_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
return (TrainContextBuilder(self)
.with_checkpoint(checkpoint)
.with_sampler()
.with_optimizer()
.with_scheduler()
.with_dataloader()
.build())

View File

@ -9,9 +9,11 @@ import pytest
import matplotlib
from torch.utils.data import Dataset
from khaosz.core import *
from khaosz.trainer import *
from khaosz.trainer.data_util import *
from khaosz.config.model_config import TransformerConfig
from khaosz.data.data_util import build_attention_mask, build_loss_mask
from khaosz.data.tokenizer import BpeTokenizer
from khaosz.model.transformer import Transformer
matplotlib.use("Agg")

View File

@ -1,8 +1,7 @@
import torch
from khaosz.core import *
from khaosz.config import *
from khaosz.trainer import *
from khaosz.trainer.data_util import *
def test_callback_integration(base_test_env, random_dataset):
"""Test that all callbacks are properly integrated"""

View File

@ -3,9 +3,9 @@ import torch
import pickle
import numpy as np
from khaosz.core import *
from khaosz.trainer import *
from khaosz.trainer.data_util import *
from khaosz.data.data_util import *
def test_dataset_loader_random_paths(base_test_env):
"""Test dataset loader with multiple random paths"""

View File

@ -1,8 +1,7 @@
import torch
from khaosz.core import *
from khaosz.config import *
from khaosz.trainer import *
from khaosz.trainer.data_util import *
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
"""Simulate early stopping behavior"""

View File

@ -5,8 +5,11 @@ import shutil
import pytest
import tempfile
import safetensors.torch as st
from khaosz.core import *
from khaosz.core.generator import EmbeddingEncoderCore, GeneratorCore
from khaosz.trainer import *
from khaosz.config import *
from khaosz.model import *
from khaosz.data import *
from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore
from tokenizers import pre_tokenizers
@pytest.fixture

View File

@ -1,14 +1,13 @@
from khaosz.core import *
from khaosz.trainer import *
from khaosz.trainer.data_util import *
from khaosz.data.data_util import *
def test_random_sampler_consistency(random_dataset):
"""Test RandomSampler produces consistent results with same seed"""
dataset = random_dataset
# Create two samplers with same seed
sampler1 = RandomSampler(dataset, seed=42)
sampler2 = RandomSampler(dataset, seed=42)
sampler1 = ResumeableRandomSampler(dataset, seed=42)
sampler2 = ResumeableRandomSampler(dataset, seed=42)
indices1 = list(iter(sampler1))
indices2 = list(iter(sampler2))
@ -20,8 +19,8 @@ def test_random_sampler_different_seeds(random_dataset):
dataset = random_dataset
# Create two samplers with different seeds
sampler1 = RandomSampler(dataset, seed=42)
sampler2 = RandomSampler(dataset, seed=123)
sampler1 = ResumeableRandomSampler(dataset, seed=42)
sampler2 = ResumeableRandomSampler(dataset, seed=123)
indices1 = list(iter(sampler1))
indices2 = list(iter(sampler2))
@ -29,38 +28,13 @@ def test_random_sampler_different_seeds(random_dataset):
# Very high probability they should be different
assert indices1 != indices2
def test_sampler_state_persistence(random_dataset):
"""Test that sampler state is correctly saved and loaded"""
dataset = random_dataset
n = len(dataset)
# Create sampler and get some indices
sampler = RandomSampler(dataset, seed=42)
iter1 = iter(sampler)
indices1 = [next(iter1) for _ in range(min(10, n))]
# Save state
state_dict = sampler.state_dict()
# Get more indices
indices2 = [next(iter1) for _ in range(min(10, n - len(indices1)))]
# Create new sampler and load state
sampler2 = RandomSampler(dataset, seed=42)
sampler2.load_state_dict(state_dict)
# Check that new sampler produces same sequence from saved point
iter2 = iter(sampler2)
indices3 = [next(iter2) for _ in range(min(10, n - len(indices1)))]
assert indices2 == indices3
def test_sampler_across_epochs(random_dataset):
"""Test sampler behavior across multiple epochs"""
dataset = random_dataset
n = len(dataset)
sampler = RandomSampler(dataset, seed=42)
sampler = ResumeableRandomSampler(dataset, seed=42)
# Get indices for first epoch
epoch1_indices = list(iter(sampler))

View File

@ -1,9 +1,10 @@
import torch
import numpy as np
from khaosz.core import *
from khaosz.config import *
from khaosz.trainer import *
from khaosz.trainer.data_util import *
from khaosz.data.data_util import *
def test_different_batch_sizes(base_test_env, random_dataset):
"""Test training with different batch sizes"""

View File

@ -1,9 +1,9 @@
import torch
import numpy as np
from khaosz.core import *
from khaosz.config import *
from khaosz.trainer import *
from khaosz.trainer.data_util import *
from khaosz.data.data_util import *
def test_multi_turn_training(base_test_env, multi_turn_dataset):
"""Test training with multi-turn conversation data"""