refactor(khaosz): restructure project layout
commit c51b203fde
parent 8434c19923
@@ -1,7 +1,7 @@
 import torch
 from typing import Dict, Any
 from dataclasses import dataclass
-from khaosz.core.transformer import TransformerConfig, Transformer
+from khaosz.model.transformer import TransformerConfig, Transformer
 
 
 @dataclass
@@ -1,16 +1,23 @@
 __version__ = "1.3.0"
 __author__ = "ViperEkura"
 
-from khaosz.model import Khaosz
-from khaosz.core.transformer import Transformer, TransformerConfig
+from khaosz.khaosz import Khaosz
+from khaosz.config import (
+    TransformerConfig,
+    ParameterLoader,
+    TrainConfig,
+)
+from khaosz.model.transformer import Transformer
 from khaosz.utils.retriever import Retriever
 from khaosz.utils.splitter import (
     SemanticTextSplitter,
     PriorityTextSplitter
 )
-from khaosz.core.tokenizer import BpeTokenizer
-from khaosz.core.parameter import ParameterLoader
-from khaosz.core.generator import (
+from khaosz.data import (
+    DatasetLoader,
+    BpeTokenizer
+)
+from khaosz.inference.generator import (
     TextGenerator,
     ChatGenerator,
     StreamGenerator,
@@ -18,10 +25,9 @@ from khaosz.core.generator import (
     RetrievalGenerator,
     EmbeddingEncoder
 )
 
 from khaosz.trainer import (
     Trainer,
-    DatasetLoader,
-    TrainConfig,
     StrategyFactory,
     SchedulerFactory
 )
@@ -44,7 +50,7 @@ __all__ = [
 
     # trainer
     "Trainer",
-    "DatasetLoader",
+    "DatasetLoader",  # kept in __all__, but now sourced from khaosz.data
     "TrainConfig",
     "StrategyFactory",
     "SchedulerFactory",
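Note: the package root keeps re-exporting the same public names, so downstream imports are unaffected by the move from khaosz.core into the new subpackages. A minimal sketch of the unchanged top-level API, assuming the refactored package is installed:

```python
# These names still resolve from the package root after the refactor;
# only their defining modules moved (khaosz.core -> config/data/model/inference).
from khaosz import Khaosz, Transformer, TransformerConfig, BpeTokenizer, DatasetLoader
```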
@@ -0,0 +1,12 @@
+from khaosz.config.model_config import TransformerConfig
+from khaosz.config.param_config import BaseModelIO, ModelParameter, Checkpoint, ParameterLoader
+from khaosz.config.train_config import TrainConfig
+
+__all__ = [
+    "BaseModelIO",
+    "ModelParameter",
+    "Checkpoint",
+    "ParameterLoader",
+    "TransformerConfig",
+    "TrainConfig"
+]
@@ -0,0 +1,36 @@
+import json
+
+from dataclasses import asdict, dataclass
+from typing import Optional, Self
+
+
+@dataclass
+class TransformerConfig:
+    # basic config
+    vocab_size: Optional[int] = None
+    n_dim: Optional[int] = None
+    n_head: Optional[int] = None
+    n_layer: Optional[int] = None
+    m_len: Optional[int] = None
+    norm_eps: Optional[float] = None
+    d_ffn: Optional[int] = None
+
+    # GQA
+    n_kvhead: Optional[int] = None
+
+    def load(self, config_path: str) -> Self:
+        with open(config_path, 'r') as f:
+            config: dict = json.load(f)
+            for key, value in config.items():
+                if hasattr(self, key):
+                    setattr(self, key, value)
+
+        return self
+
+    def save(self, config_path: str) -> None:
+        config_dict = asdict(self)
+        config_dict = {k: v for k, v in config_dict.items() if v is not None}
+        with open(config_path, 'w') as f:
+            json.dump(config_dict, f, indent=4)
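The new config dataclass round-trips through JSON: save() drops unset (None) fields, and load() only applies keys that match existing attributes. A quick usage sketch (values and path are illustrative):

```python
from khaosz.config.model_config import TransformerConfig

cfg = TransformerConfig(
    vocab_size=32000, n_dim=768, n_head=12, n_layer=12,
    m_len=2048, norm_eps=1e-5, d_ffn=3072, n_kvhead=4
)
cfg.save("config.json")                      # writes only the non-None fields

restored = TransformerConfig().load("config.json")
assert restored == cfg                       # dataclass equality: full round-trip
```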
@@ -8,8 +8,9 @@ from dataclasses import dataclass, field
 from typing import Any, Dict, List, Optional, Self, Union
 from pathlib import Path
 
-from khaosz.core.tokenizer import BpeTokenizer
-from khaosz.core.transformer import TransformerConfig, Transformer
+from khaosz.data.tokenizer import BpeTokenizer
+from khaosz.config.model_config import TransformerConfig
+from khaosz.model.transformer import Transformer
 
 
 class BaseModelIO:
@@ -99,18 +100,18 @@ class Checkpoint(BaseModelIO):
         metadata={"help": "Transformer model."}
     )
     tokenizer: BpeTokenizer = field(
-        default_factory=BpeTokenizer,
+        default=None,
         metadata={"help": "Tokenizer for the model."}
     )
     config: TransformerConfig = field(
-        default_factory=TransformerConfig,
+        default=None,
         metadata={"help": "Transformer model configuration."}
     )
     optimizer_state: Dict[str, Any] = field(
         default=None,
         metadata={"help": "Optimizer state."}
     )
-    sampler_state: Dict[str, Any] = field(
+    scheduler_state: Dict[str, Any] = field(
         default=None,
         metadata={"help": "Sampler state."}
     )
@@ -145,7 +146,7 @@ class Checkpoint(BaseModelIO):
 
         # Save sampler state
         with open(str(paths["sampler_state"]), "wb") as f:
-            pkl.dump(self.sampler_state, f)
+            pkl.dump(self.scheduler_state, f)
 
     def load_training_state(self, load_dir: Union[str, Path]) -> Self:
         paths = self._get_training_paths(load_dir)
@@ -163,7 +164,7 @@ class Checkpoint(BaseModelIO):
         # Load sampler state
         if paths["sampler_state"].exists():
             with open(str(paths["sampler_state"]), "rb") as f:
-                self.sampler_state = pkl.load(f)
+                self.scheduler_state = pkl.load(f)
 
         return self
 
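The Checkpoint fields above switch from default_factory to default=None, so a bare Checkpoint() no longer eagerly constructs an empty tokenizer and config; callers are now expected to populate those fields explicitly (as TrainContextBuilder.with_checkpoint does). The difference in isolation, with a hypothetical dataclass:

```python
from dataclasses import dataclass, field
from typing import Optional

@dataclass
class Eager:
    items: list = field(default_factory=list)    # fresh value built per instance

@dataclass
class Lazy:
    items: Optional[list] = field(default=None)  # construction left to the caller

assert Eager().items == []       # usable immediately
assert Lazy().items is None      # must be filled in explicitly
```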
@@ -1,14 +1,16 @@
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, TYPE_CHECKING
 from torch.utils.data import Dataset
 from torch.optim import Optimizer
-from khaosz.trainer.strategy import BaseStrategy
+
+if TYPE_CHECKING:
+    from khaosz.trainer.strategy import BaseStrategy
 
 
 @dataclass
 class TrainConfig:
 
-    strategy: BaseStrategy = field(
+    strategy: "BaseStrategy" = field(
         default=None,
         metadata={"help": "Training strategy."}
     )
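Moving the BaseStrategy import under TYPE_CHECKING, paired with the string annotation, keeps the type hint for checkers without importing khaosz.trainer.strategy at runtime — the usual way to break an import cycle. The pattern in isolation (module and class names are placeholders):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # evaluated by type checkers only, never at runtime,
    # so heavy_module may safely import this module back
    from heavy_module import HeavyClass

def use(obj: "HeavyClass") -> None:  # forward reference as a string
    ...
```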
@@ -1,27 +0,0 @@
-from khaosz.core.tokenizer import BpeTokenizer
-from khaosz.core.transformer import Transformer, TransformerConfig
-from khaosz.core.parameter import ParameterLoader, ModelParameter, Checkpoint
-from khaosz.core.generator import (
-    TextGenerator,
-    ChatGenerator,
-    StreamGenerator,
-    BatchGenerator,
-    RetrievalGenerator,
-    EmbeddingEncoder
-)
-
-
-__all__ = [
-    "Transformer",
-    "TransformerConfig",
-    "BpeTokenizer",
-    "ParameterLoader",
-    "ModelParameter",
-    "Checkpoint",
-    "TextGenerator",
-    "ChatGenerator",
-    "StreamGenerator",
-    "BatchGenerator",
-    "RetrievalGenerator",
-    "EmbeddingEncoder"
-]
@@ -0,0 +1,30 @@
+from khaosz.data.data_util import (
+    BaseDataset,
+    SeqDataset,
+    DpoDataset,
+    SftDataset,
+    PpoDataset,
+    MutiSegmentFetcher,
+    ResumeableRandomSampler,
+    DatasetLoader,
+    load_pkl_files,
+    build_attention_mask,
+    build_loss_mask
+)
+
+from khaosz.data.tokenizer import BpeTokenizer
+
+__all__ = [
+    "BaseDataset",
+    "SeqDataset",
+    "DpoDataset",
+    "SftDataset",
+    "PpoDataset",
+    "MutiSegmentFetcher",
+    "ResumeableRandomSampler",
+    "DatasetLoader",
+    "load_pkl_files",
+    "build_attention_mask",
+    "build_loss_mask",
+    "BpeTokenizer"
+]
@@ -265,56 +265,37 @@ class DatasetLoader:
         return dataset
 
 
-class RandomSampler(Sampler[int]):
-    def __init__(self, data_source, generator=None, seed=42):
-        self.data_source = data_source
-        self.seed = seed
-        self.epoch = 0
-        self.current_iter = 0
+class ResumeableRandomSampler(Sampler[int]):
+    def __init__(self, data_source, start_epoch=0, start_iter=0, seed=42):
+        self.num_samples = len(data_source)
+        self.epoch = start_epoch
+        self.iter = start_iter
+
+        generator = torch.Generator()
+        generator.manual_seed(seed)
+
+        self.generator = generator
         self._indices = None
 
-        if generator is None:
-            self.generator = torch.Generator()
-            self.generator.manual_seed(seed)
-        else:
-            self.generator = generator
+    def _get_indices(self):
+        for _ in range(self.epoch):
+            _ = torch.randperm(self.num_samples, generator=self.generator)
 
-    def _generate_indices(self):
-        n = len(self.data_source)
-        self._indices = torch.randperm(n, generator=self.generator).tolist()
+        current_epoch_indices = torch.randperm(self.num_samples, generator=self.generator).tolist()
+        self._indices = current_epoch_indices[self.iter % self.num_samples:]
 
     def __iter__(self):
-        n = len(self.data_source)
-
         if self._indices is None:
-            self._generate_indices()
+            self._get_indices()
 
-        start = self.current_iter % n
-        for i in range(start, n):
-            self.current_iter += 1
-            yield self._indices[i]
+        for i in self._indices:
+            self.iter += 1
+            yield i
 
         self.epoch += 1
         self._indices = None
 
     def __len__(self):
-        return len(self.data_source)
-
-    def state_dict(self):
-        return {
-            'epoch': self.epoch,
-            'current_iter': self.current_iter,
-            'seed': self.seed,
-            'generator_state': self.generator.get_state() if self.generator else None,
-            'indices': self._indices
-        }
-
-    def load_state_dict(self, state_dict):
-        self.epoch = state_dict['epoch']
-        self.current_iter = state_dict['current_iter']
-        self.seed = state_dict['seed']
-
-        if self.generator and state_dict['generator_state'] is not None:
-            self.generator.set_state(state_dict['generator_state'])
-
-        self._indices = state_dict['indices']
+        if self._indices is None:
+            self._get_indices()
+        return len(self._indices)
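Instead of pickling raw index lists, the new sampler derives its position from (start_epoch, start_iter): _get_indices burns start_epoch full permutations from the seeded generator, then skips the first start_iter entries of the current one. A standalone sketch of the resume behavior this enables, assuming the class is importable:

```python
from khaosz.data.data_util import ResumeableRandomSampler

data = list(range(100))

# consume part of an epoch, then stop
sampler = ResumeableRandomSampler(data, seed=42)
it = iter(sampler)
seen = [next(it) for _ in range(10)]

# a fresh sampler, seeded identically and told where training stopped,
# replays the same permutation and continues from position 10
resumed = ResumeableRandomSampler(data, start_epoch=0, start_iter=10, seed=42)
rest = list(iter(resumed))

full_epoch = list(iter(ResumeableRandomSampler(data, seed=42)))
assert seen + rest == full_epoch
```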
@@ -0,0 +1,97 @@
+import torch
+
+from torch import Tensor
+from typing import List, Tuple, Union, Optional, Generator, Self
+from khaosz.config.param_config import ModelParameter
+
+
+class GeneratorCore:
+    def __init__(self, parameter: ModelParameter):
+        self.model = parameter.model
+        self.tokenizer = parameter.tokenizer
+        self.config = parameter.config
+
+    def compute_logits(
+        self,
+        input_ids: Tensor,
+        attn_mask: Optional[Tensor] = None,
+        kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
+        start_pos: int = 0
+    ) -> Tuple[Tensor, int]:
+        with torch.inference_mode():
+            outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
+            logits = outputs["logits"][:, -1, :]
+            cache_increase = input_ids.size(-1)
+
+        return logits, cache_increase
+
+    def to(self, *args, **kargs) -> Self:
+        self.model.to(*args, **kargs)
+        return self
+
+
+class EmbeddingEncoderCore:
+    def __init__(self, parameter: ModelParameter):
+        self.model = parameter.model
+        self.tokenizer = parameter.tokenizer
+        self.config = parameter.config
+
+    def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
+        with_batch = isinstance(sentence, list)
+        ids = self.tokenizer.encode(sentence)
+        batch_ids = ids if with_batch else [ids]
+        max_model_len = self.config.m_len
+
+        all_fragments = []
+        fragment_origin_idx = []
+
+        for i, seq in enumerate(batch_ids):
+            if len(seq) > max_model_len:
+                fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
+                all_fragments.extend(fragments)
+                fragment_origin_idx.extend([i] * len(fragments))
+            else:
+                all_fragments.append(seq)
+                fragment_origin_idx.append(i)
+
+        # if empty fragments
+        if not all_fragments or not ids:
+            return [] if with_batch else torch.tensor([])
+
+        device = next(self.model.parameters()).device
+        max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
+
+        padded_ids = []
+        masks = []
+        for seq in all_fragments:
+            pad_len = max_len - len(seq)
+            padded_seq = seq + [self.tokenizer.pad_id] * pad_len
+            mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
+            padded_ids.append(padded_seq)
+            masks.append(mask)
+
+        input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
+        seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
+
+        with torch.inference_mode():
+            outputs = self.model(input_tensor, seq_mask)["hidden_states"]
+            # [num_fragments, seq_len, hidden_size]
+            fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
+
+        sentence_embs: List[Tensor] = []
+        for i in range(len(batch_ids)):
+            indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
+            if indices is not None:
+                sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1)    # [frags, hidden_size]
+                length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1)  # [frags, 1]
+                emb = torch.sum(sum_frags / length, dim=0)
+                sentence_embs.append(emb.flatten())
+
+        if with_batch:
+            return [emb.flatten() for emb in sentence_embs]
+        else:
+            return sentence_embs[0].flatten()
+
+    def to(self, *args, **kargs) -> Self:
+        self.model.to(*args, **kargs)
+        return self
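encode() splits over-length inputs into m_len-sized fragments, zeroes padded positions, mean-pools each fragment over its valid tokens, then averages the fragments back into one vector per sentence. The core pooling step in isolation (pure torch, shapes are illustrative):

```python
import torch

hidden = torch.randn(2, 5, 8)                 # [fragments, seq_len, hidden_size]
seq_mask = torch.tensor([[1, 1, 1, 0, 0],
                         [1, 1, 1, 1, 1]], dtype=torch.bool)

masked = hidden * seq_mask.unsqueeze(-1)      # zero out padding positions
summed = masked.sum(dim=1)                    # [fragments, hidden_size]
length = seq_mask.sum(dim=1, keepdim=True)    # valid-token count per fragment
mean_pooled = summed / length                 # per-fragment embedding
print(mean_pooled.shape)                      # torch.Size([2, 8])
```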
@@ -1,7 +1,8 @@
 import torch
 from torch import Tensor
-from typing import List, Tuple, Union, Optional, Generator, Self
-from khaosz.core.parameter import ModelParameter
+from typing import List, Tuple, Union, Optional, Generator
+from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore
+from khaosz.config.param_config import ModelParameter
 
 
 def build_prompt(query: str, history: Optional[List[Tuple[str, str]]] = None) -> str:
@@ -168,96 +169,6 @@ class KVCacheManager:
         return self._seq_mask
 
 
-class GeneratorCore:
-    def __init__(self, parameter: ModelParameter):
-        self.model = parameter.model
-        self.tokenizer = parameter.tokenizer
-        self.config = parameter.config
-
-    def compute_logits(
-        self,
-        input_ids: Tensor,
-        attn_mask: Optional[Tensor] = None,
-        kv_caches: Optional[List[Tuple[Tensor, Tensor]]] = None,
-        start_pos: int = 0
-    ) -> Tuple[Tensor, int]:
-        with torch.inference_mode():
-            outputs = self.model(input_ids, attn_mask, kv_caches, start_pos)
-            logits = outputs["logits"][:, -1, :]
-            cache_increase = input_ids.size(-1)
-
-        return logits, cache_increase
-
-    def to(self, *args, **kargs) -> Self:
-        self.model.to(*args, **kargs)
-        return self
-
-
-class EmbeddingEncoderCore:
-    def __init__(self, parameter: ModelParameter):
-        self.model = parameter.model
-        self.tokenizer = parameter.tokenizer
-        self.config = parameter.config
-
-    def encode(self, sentence: Union[str, List[str]]) -> Union[Tensor, List[Tensor]]:
-        with_batch = isinstance(sentence, list)
-        ids = self.tokenizer.encode(sentence)
-        batch_ids = ids if with_batch else [ids]
-        max_model_len = self.config.m_len
-
-        all_fragments = []
-        fragment_origin_idx = []
-
-        for i, seq in enumerate(batch_ids):
-            if len(seq) > max_model_len:
-                fragments = [seq[j:j+max_model_len] for j in range(0, len(seq), max_model_len)]
-                all_fragments.extend(fragments)
-                fragment_origin_idx.extend([i] * len(fragments))
-            else:
-                all_fragments.append(seq)
-                fragment_origin_idx.append(i)
-
-        # if empty fragments
-        if not all_fragments or not ids:
-            return [] if with_batch else torch.tensor([])
-
-        device = next(self.model.parameters()).device
-        max_len = min(max(len(seq) for seq in all_fragments), max_model_len)
-
-        padded_ids = []
-        masks = []
-        for seq in all_fragments:
-            pad_len = max_len - len(seq)
-            padded_seq = seq + [self.tokenizer.pad_id] * pad_len
-            mask = [token_id != self.tokenizer.pad_id for token_id in padded_seq]
-            padded_ids.append(padded_seq)
-            masks.append(mask)
-
-        input_tensor = torch.tensor(padded_ids, device=device, dtype=torch.long)
-        seq_mask = torch.tensor(masks, device=device, dtype=torch.bool)
-
-        with torch.inference_mode():
-            outputs = self.model(input_tensor, seq_mask)["hidden_states"]
-            # [num_fragments, seq_len, hidden_size]
-            fragment_embs = torch.mul(outputs, seq_mask.unsqueeze(-1))
-
-        sentence_embs: List[Tensor] = []
-        for i in range(len(batch_ids)):
-            indices = [idx for idx, orig_idx in enumerate(fragment_origin_idx) if orig_idx == i]
-            if indices is not None:
-                sum_frags = torch.sum(fragment_embs[indices, :, :], dim=1)    # [frags, hidden_size]
-                length = torch.sum(seq_mask[indices, :], dim=1).unsqueeze(1)  # [frags, 1]
-                emb = torch.sum(sum_frags / length, dim=0)
-                sentence_embs.append(emb.flatten())
-
-        if with_batch:
-            return [emb.flatten() for emb in sentence_embs]
-        else:
-            return sentence_embs[0].flatten()
-
-    def to(self, *args, **kargs) -> Self:
-        self.model.to(*args, **kargs)
-        return self
-
-
 class TextGenerator(GeneratorCore):
@@ -1,7 +1,7 @@
 from torch import Tensor
 from typing import List, Tuple, Generator, Union
 
-from khaosz.core.generator import (
+from khaosz.inference.generator import (
     TextGenerator,
     ChatGenerator,
     StreamGenerator,
@@ -9,7 +9,7 @@ from khaosz.core.generator import (
     RetrievalGenerator,
     EmbeddingEncoder
 )
-from khaosz.core.parameter import ParameterLoader
+from khaosz.config.param_config import ParameterLoader
 
 
 class Khaosz:
@@ -0,0 +1,17 @@
+from khaosz.model.module import (
+    Linear,
+    RMSNorm,
+    MLP,
+    GQA,
+    DecoderBlock,
+)
+from khaosz.model.transformer import Transformer
+
+__all__ = [
+    "Linear",
+    "RMSNorm",
+    "MLP",
+    "GQA",
+    "DecoderBlock",
+    "Transformer"
+]
@@ -1,12 +1,11 @@
-import json
-
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
 from torch import Tensor
 from torch.nn import init
-from dataclasses import asdict, dataclass
-from typing import List, Optional, Self, Tuple
+from typing import Optional, Tuple
 
 
 def repeat_kv(x: Tensor, n_rep: int) -> Tensor:
@@ -71,89 +70,6 @@ def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
 
     return x_out.to(dtype)
 
-def process_attention_mask(
-    seq_mask: Tensor,
-    start_pos: int = 0,
-    seq_len: int = 0,
-    is_causal: bool = False,
-    device: torch.device = "cuda",
-    dtype: torch.dtype = torch.float32
-) -> Tensor:
-    """
-    Create attention mask for GQA
-
-    Args:
-        seq_mask (Tensor): A tensor indicating whether each position is valid or not.
-        start_pos (int): The starting position of the sequence.
-        seq_len (int): The length of the sequence.
-        is_causal (bool): Whether the attention is causal or not.
-        device (torch.device): The device to use.
-
-    Returns:
-        Tensor: The attention mask tensor.
-    """
-
-    if seq_mask is None:
-        if start_pos != 0:
-            # for single prompt chat
-            seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
-        else:
-            return None
-
-    if seq_mask.dim() > 2:
-        # shape (bsz, seq_len) or (bsz, n_heads, seq_len, seq_len + start_pos)
-        # if ndim > 2, it's a 4D tensor
-        return seq_mask
-
-    batch_size = seq_mask.size(0)
-    seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
-    # (bsz, start_pos + seq_len)
-    expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len)
-    # (bsz, seq_len, start_pos + seq_len)
-
-    if is_causal:
-        causal_mask = torch.tril(
-            torch.ones((seq_len, start_pos + seq_len), dtype=torch.bool, device=device),
-            diagonal=start_pos
-        )
-        causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, start_pos + seq_len)
-        expanded_mask = expanded_mask & causal_mask
-
-    attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
-    attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
-    # (bsz, 1, seq_len, seq_len + start_pos)
-
-    return attention_mask
-
-
-@dataclass
-class TransformerConfig:
-    # basic config
-    vocab_size: Optional[int] = None
-    n_dim: Optional[int] = None
-    n_head: Optional[int] = None
-    n_layer: Optional[int] = None
-    m_len: Optional[int] = None
-    norm_eps: Optional[float] = None
-    d_ffn: Optional[int] = None
-
-    # GQA
-    n_kvhead: Optional[int] = None
-
-    def load(self, config_path: str) -> Self:
-        with open(config_path, 'r') as f:
-            config: dict = json.load(f)
-            for key, value in config.items():
-                if hasattr(self, key):
-                    setattr(self, key, value)
-
-        return self
-
-    def save(self, config_path: str) -> None:
-        config_dict = asdict(self)
-        config_dict = {k: v for k, v in config_dict.items() if v is not None}
-        with open(config_path, 'w') as f:
-            json.dump(config_dict, f, indent=4)
-
-
 class Linear(nn.Module):
     def __init__(self, in_dim: int, out_dim: int, bias: bool=False):
@@ -288,59 +204,3 @@ class DecoderBlock(nn.Module):
         x = self.ffn(self.norm_ffn(x)) + x
 
         return x
-
-
-class Transformer(nn.Module):
-    def __init__(self, config: TransformerConfig):
-        super().__init__()
-        self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
-        self.layers = nn.ModuleList([
-            DecoderBlock(
-                config.n_dim,
-                config.n_head,
-                config.d_ffn,
-                config.n_kvhead,
-                config.norm_eps
-            )
-            for _ in range(config.n_layer)
-        ])
-        self.norm = RMSNorm(config.n_dim, config.norm_eps)
-        self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
-        init.normal_(self.embedding, mean=0, std=0.02)
-
-    def forward(
-        self,
-        input_ids: Tensor,
-        input_mask: Optional[Tensor] = None,
-        persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]] = None,
-        start_pos: int = 0
-    ) -> Tensor:
-        assert input_ids.ndim == 2
-        seq_len = input_ids.size(-1)
-        x = F.embedding(input_ids, self.embedding)
-
-        self.freq_cis = self.freq_cis.to(x.device)
-        freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
-        has_kvcache = persistent_key_values is not None
-
-        attn_mask = process_attention_mask(
-            input_mask,
-            start_pos=start_pos,
-            seq_len=seq_len,
-            is_causal=has_kvcache,
-            device=x.device,
-            dtype=x.dtype
-        )
-
-        for i, layer in enumerate(self.layers):
-            kv_cache = persistent_key_values[i] if persistent_key_values else None
-            x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos)
-
-        hidden_states = self.norm(x)
-        logits = F.linear(hidden_states, self.embedding)
-
-        return {
-            "logits": logits,
-            "hidden_states": hidden_states
-        }
@@ -0,0 +1,119 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from torch import Tensor
+from torch.nn import init
+from typing import List, Optional, Tuple
+
+from khaosz.config.model_config import TransformerConfig
+from khaosz.model.module import DecoderBlock, RMSNorm, get_rotary_emb
+
+
+def process_attention_mask(
+    seq_mask: Tensor,
+    start_pos: int = 0,
+    seq_len: int = 0,
+    is_causal: bool = False,
+    device: torch.device = "cuda",
+    dtype: torch.dtype = torch.float32
+) -> Tensor:
+    """
+    Create attention mask for GQA
+
+    Args:
+        seq_mask (Tensor): A tensor indicating whether each position is valid or not.
+        start_pos (int): The starting position of the sequence.
+        seq_len (int): The length of the sequence.
+        is_causal (bool): Whether the attention is causal or not.
+        device (torch.device): The device to use.
+
+    Returns:
+        Tensor: The attention mask tensor.
+    """
+
+    if seq_mask is None:
+        if start_pos != 0:
+            # for single prompt chat
+            seq_mask = torch.ones((1, seq_len), dtype=torch.bool, device=device)
+        else:
+            return None
+
+    if seq_mask.dim() > 2:
+        # shape (bsz, seq_len) or (bsz, n_heads, seq_len, seq_len + start_pos)
+        # if ndim > 2, it's a 4D tensor
+        return seq_mask
+
+    batch_size = seq_mask.size(0)
+    seq_mask = seq_mask[:, :start_pos + seq_len].to(device=device, dtype=torch.bool)
+    # (bsz, start_pos + seq_len)
+    expanded_mask = seq_mask.unsqueeze(1).expand(batch_size, seq_len, start_pos + seq_len)
+    # (bsz, seq_len, start_pos + seq_len)
+
+    if is_causal:
+        causal_mask = torch.tril(
+            torch.ones((seq_len, start_pos + seq_len), dtype=torch.bool, device=device),
+            diagonal=start_pos
+        )
+        causal_mask = causal_mask.unsqueeze(0).expand(batch_size, seq_len, start_pos + seq_len)
+        expanded_mask = expanded_mask & causal_mask
+
+    attention_mask = torch.zeros_like(expanded_mask, dtype=dtype, device=device)
+    attention_mask = attention_mask.masked_fill_(~expanded_mask, -torch.finfo(dtype).max / 2).unsqueeze(1)
+    # (bsz, 1, seq_len, seq_len + start_pos)
+
+    return attention_mask
+
+
+class Transformer(nn.Module):
+    def __init__(self, config: TransformerConfig):
+        super().__init__()
+        self.embedding = nn.Parameter(torch.empty(config.vocab_size, config.n_dim))
+        self.layers = nn.ModuleList([
+            DecoderBlock(
+                config.n_dim,
+                config.n_head,
+                config.d_ffn,
+                config.n_kvhead,
+                config.norm_eps
+            )
+            for _ in range(config.n_layer)
+        ])
+        self.norm = RMSNorm(config.n_dim, config.norm_eps)
+        self.freq_cis = get_rotary_emb(config.n_dim // config.n_head, config.m_len)
+        init.normal_(self.embedding, mean=0, std=0.02)
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        input_mask: Optional[Tensor] = None,
+        persistent_key_values: Optional[List[Tuple[Tensor, Tensor]]] = None,
+        start_pos: int = 0
+    ) -> Tensor:
+        assert input_ids.ndim == 2
+        seq_len = input_ids.size(-1)
+        x = F.embedding(input_ids, self.embedding)
+
+        self.freq_cis = self.freq_cis.to(x.device)
+        freqs_cis = self.freq_cis[start_pos:start_pos+seq_len]
+        has_kvcache = persistent_key_values is not None
+
+        attn_mask = process_attention_mask(
+            input_mask,
+            start_pos=start_pos,
+            seq_len=seq_len,
+            is_causal=has_kvcache,
+            device=x.device,
+            dtype=x.dtype
+        )
+
+        for i, layer in enumerate(self.layers):
+            kv_cache = persistent_key_values[i] if persistent_key_values else None
+            x = layer(x, freqs_cis, attn_mask, kv_cache, start_pos)
+
+        hidden_states = self.norm(x)
+        logits = F.linear(hidden_states, self.embedding)
+
+        return {
+            "logits": logits,
+            "hidden_states": hidden_states
+        }
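For reference, process_attention_mask combines the padding mask with an optional causal band and returns an additive float mask (0 where attention is allowed, a large negative value where it is blocked). A quick shape check, assuming the refactored module is importable:

```python
import torch
from khaosz.model.transformer import process_attention_mask

# one sequence of length 4 whose last position is padding
seq_mask = torch.tensor([[True, True, True, False]])

attn_mask = process_attention_mask(
    seq_mask,
    start_pos=0,
    seq_len=4,
    is_causal=True,
    device=torch.device("cpu"),
    dtype=torch.float32,
)
print(attn_mask.shape)   # torch.Size([1, 1, 4, 4]) == (bsz, 1, seq_len, start_pos + seq_len)
```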
@@ -1,6 +1,4 @@
-from khaosz.trainer.data_util import DatasetLoader
 from khaosz.trainer.trainer import Trainer
-from khaosz.trainer.train_config import TrainConfig
 from khaosz.trainer.strategy import (
     CosineScheduleConfig,
     SgdrScheduleConfig,
@@ -17,19 +15,16 @@ from khaosz.trainer.train_callback import (
 )
 
 __all__ = [
-    "DatasetLoader",
     "Trainer",
-    "TrainConfig",
-    "StrategyFactory",
     "CosineScheduleConfig",
     "SgdrScheduleConfig",
+    "StrategyFactory",
     "SchedulerFactory",
 
     # callback
     "TrainCallback",
     "ProgressBarCallback",
     "CheckpointCallback",
-    "TrainCallback",
     "SchedulerCallback",
     "StepMonitorCallback"
 ]
@@ -106,7 +106,7 @@ class CheckpointCallback(TrainCallback):
 
     def _save_checkpoint(self, trainer: 'Trainer', context: 'TrainContext'):
         save_path = os.path.join(trainer.train_config.checkpoint_dir, f"iter_{context.current_iter}")
-        context.checkpoint.sampler_state = context.sampler.state_dict()
+        # context.checkpoint.scheduler_state = context.sampler.state_dict()
         context.checkpoint.optimizer_state = context.optimizer.state_dict()
         context.checkpoint.save(save_path)
         self.last_ckpt_iter = context.current_iter
@@ -1,9 +1,10 @@
 from dataclasses import dataclass, field, fields
 from typing import Optional, Self, TYPE_CHECKING
 from torch.optim import Optimizer
+from torch.optim.lr_scheduler import LRScheduler
 from torch.utils.data import DataLoader
-from khaosz.core.parameter import Checkpoint
-from khaosz.trainer.data_util import RandomSampler
+from khaosz.config.param_config import Checkpoint
+from khaosz.data.data_util import ResumeableRandomSampler
 
 if TYPE_CHECKING:
     from khaosz.trainer.trainer import Trainer
@@ -13,11 +14,11 @@ if TYPE_CHECKING:
 class TrainContext:
     dataloader: DataLoader = field(default=None)
     optimizer: Optimizer = field(default=None)
-    sampler: RandomSampler = field(default=None)
+    scheduler: LRScheduler = field(default=None)
+    checkpoint: Checkpoint = field(default=None)
     epoch: int = field(default=0)
     current_iter: int = field(default=0)
     loss: float = field(default=0.0)
-    checkpoint: Checkpoint = field(default=None)
 
     def asdict(self) -> dict:
         return {field.name: getattr(self, field.name)
@@ -27,15 +28,7 @@ class TrainContext:
 class TrainContextBuilder:
     def __init__(self, trainer: 'Trainer'):
         self.trainer = trainer
-        self._context = TrainContext(
-            dataloader=None,
-            optimizer=None,
-            sampler=None,
-            epoch=0,
-            current_iter=0,
-            loss=0.0,
-            checkpoint=None
-        )
+        self._context = TrainContext()
 
     def with_checkpoint(self, checkpoint: Optional[Checkpoint]) -> Self:
         if checkpoint is None:
@@ -43,32 +36,10 @@ class TrainContextBuilder:
                 model=self.trainer.parameter.model,
                 tokenizer=self.trainer.parameter.tokenizer,
                 config=self.trainer.parameter.config,
-                sampler_state=None,
                 optimizer_state=None,
                 loss_list=[]
             )
         self._context.checkpoint = checkpoint
         return self
 
-    def with_sampler(self) -> Self:
-        seed = self.trainer.train_config.random_seed
-        sampler = RandomSampler(
-            data_source=self.trainer.train_config.dataset,
-            seed=seed
-        )
-
-        if self._context.checkpoint and self._context.checkpoint.sampler_state:
-            sampler.load_state_dict(self._context.checkpoint.sampler_state)
-
-        self._context.sampler = sampler
-        self._context.epoch = sampler.epoch
-        self._context.current_iter = sampler.current_iter
-
-        if self._context.checkpoint:
-            self._context.checkpoint.sampler_state = sampler.state_dict()
-
-        return self
-
     def with_optimizer(self) -> Self:
         optimizer = self.trainer.train_config.optimizer
 
@@ -82,11 +53,22 @@ class TrainContextBuilder:
 
         return self
 
+    def with_scheduler(self) -> Self:
+        return self
+
     def with_dataloader(self) -> Self:
+        resumeable_sampler = ResumeableRandomSampler(
+            data_source=self.trainer.train_config.dataset,
+            start_epoch=self._context.epoch,
+            start_iter=self._context.current_iter,
+            seed=self.trainer.train_config.random_seed
+        )
+
         dataloader = DataLoader(
             self.trainer.train_config.dataset,
             batch_size=self.trainer.train_config.batch_size,
-            sampler=self._context.sampler,
+            sampler=resumeable_sampler,
             num_workers=self.trainer.train_config.num_workers,
             pin_memory=self.trainer.train_config.pin_memory,
             prefetch_factor=self.trainer.train_config.prefetch_factor
@@ -1,9 +1,9 @@
 import logging
 from typing import Optional, List
 
-from khaosz.core import ModelParameter, Checkpoint
+from khaosz.config import ModelParameter, Checkpoint
 from khaosz.trainer.strategy import ScheduleConfig
-from khaosz.trainer.train_config import TrainConfig
+from khaosz.config.train_config import TrainConfig
 from khaosz.trainer.train_callback import (
     TrainCallback,
     ProgressBarCallback,
@@ -39,8 +39,8 @@ class Trainer:
     def _build_train_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
         return (TrainContextBuilder(self)
                 .with_checkpoint(checkpoint)
-                .with_sampler()
                 .with_optimizer()
+                .with_scheduler()
                 .with_dataloader()
                 .build())
 
@@ -9,9 +9,11 @@ import pytest
 import matplotlib
 from torch.utils.data import Dataset
 
-from khaosz.core import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
+from khaosz.config.model_config import TransformerConfig
+from khaosz.data.data_util import build_attention_mask, build_loss_mask
+from khaosz.data.tokenizer import BpeTokenizer
+from khaosz.model.transformer import Transformer
 
 
 matplotlib.use("Agg")
@@ -1,8 +1,7 @@
 import torch
 
-from khaosz.core import *
+from khaosz.config import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
 
 def test_callback_integration(base_test_env, random_dataset):
     """Test that all callbacks are properly integrated"""
@@ -3,9 +3,9 @@ import torch
 import pickle
 import numpy as np
 
-from khaosz.core import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
+from khaosz.data.data_util import *
 
 
 def test_dataset_loader_random_paths(base_test_env):
     """Test dataset loader with multiple random paths"""
@@ -1,8 +1,7 @@
 import torch
 
-from khaosz.core import *
+from khaosz.config import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
 
 def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
     """Simulate early stopping behavior"""
@@ -5,8 +5,11 @@ import shutil
 import pytest
 import tempfile
 import safetensors.torch as st
-from khaosz.core import *
-from khaosz.core.generator import EmbeddingEncoderCore, GeneratorCore
 from khaosz.trainer import *
+from khaosz.config import *
+from khaosz.model import *
+from khaosz.data import *
+from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore
 from tokenizers import pre_tokenizers
 
 @pytest.fixture
@@ -1,14 +1,13 @@
-from khaosz.core import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
+from khaosz.data.data_util import *
 
 def test_random_sampler_consistency(random_dataset):
     """Test RandomSampler produces consistent results with same seed"""
     dataset = random_dataset
 
     # Create two samplers with same seed
-    sampler1 = RandomSampler(dataset, seed=42)
-    sampler2 = RandomSampler(dataset, seed=42)
+    sampler1 = ResumeableRandomSampler(dataset, seed=42)
+    sampler2 = ResumeableRandomSampler(dataset, seed=42)
 
     indices1 = list(iter(sampler1))
     indices2 = list(iter(sampler2))
@@ -20,8 +19,8 @@ def test_random_sampler_different_seeds(random_dataset):
     dataset = random_dataset
 
     # Create two samplers with different seeds
-    sampler1 = RandomSampler(dataset, seed=42)
-    sampler2 = RandomSampler(dataset, seed=123)
+    sampler1 = ResumeableRandomSampler(dataset, seed=42)
+    sampler2 = ResumeableRandomSampler(dataset, seed=123)
 
     indices1 = list(iter(sampler1))
     indices2 = list(iter(sampler2))
@@ -29,38 +28,13 @@ def test_random_sampler_different_seeds(random_dataset):
     # Very high probability they should be different
     assert indices1 != indices2
 
-def test_sampler_state_persistence(random_dataset):
-    """Test that sampler state is correctly saved and loaded"""
-    dataset = random_dataset
-    n = len(dataset)
-
-    # Create sampler and get some indices
-    sampler = RandomSampler(dataset, seed=42)
-    iter1 = iter(sampler)
-    indices1 = [next(iter1) for _ in range(min(10, n))]
-
-    # Save state
-    state_dict = sampler.state_dict()
-
-    # Get more indices
-    indices2 = [next(iter1) for _ in range(min(10, n - len(indices1)))]
-
-    # Create new sampler and load state
-    sampler2 = RandomSampler(dataset, seed=42)
-    sampler2.load_state_dict(state_dict)
-
-    # Check that new sampler produces same sequence from saved point
-    iter2 = iter(sampler2)
-    indices3 = [next(iter2) for _ in range(min(10, n - len(indices1)))]
-
-    assert indices2 == indices3
-
 def test_sampler_across_epochs(random_dataset):
     """Test sampler behavior across multiple epochs"""
     dataset = random_dataset
     n = len(dataset)
 
-    sampler = RandomSampler(dataset, seed=42)
+    sampler = ResumeableRandomSampler(dataset, seed=42)
 
     # Get indices for first epoch
     epoch1_indices = list(iter(sampler))
@@ -1,9 +1,10 @@
 import torch
 import numpy as np
 
-from khaosz.core import *
-
+from khaosz.config import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
+from khaosz.data.data_util import *
+
 
 def test_different_batch_sizes(base_test_env, random_dataset):
     """Test training with different batch sizes"""
@@ -1,9 +1,9 @@
 import torch
 import numpy as np
 
-from khaosz.core import *
+from khaosz.config import *
 from khaosz.trainer import *
-from khaosz.trainer.data_util import *
+from khaosz.data.data_util import *
 
 def test_multi_turn_training(base_test_env, multi_turn_dataset):
     """Test training with multi-turn conversation data"""