chore: 修改错误拼写

This commit is contained in:
ViperEkura 2026-04-06 10:37:19 +08:00
parent ace8f6ee68
commit f2ffdf60d0
4 changed files with 14 additions and 13 deletions

View File

@@ -100,8 +100,8 @@ classDiagram
  }
  class MultiSegmentFetcher {
- +Dict muti_fetchers
- +List muti_keys
+ +Dict multi_fetchers
+ +List multi_keys
  +key_fetch(begin_idx, end_idx, keys) Dict
  +fetch_data(begin_idx, end_idx) Dict
  }
@@ -148,7 +148,7 @@ classDiagram
  class Transformer {
  +ModelConfig config
- +RotaryEmbedding rotary_embeding
+ +RotaryEmbedding rotary_embedding
  +Embedding embed_tokens
  +ModuleList layers
  +RMSNorm norm

View File

@@ -72,15 +72,16 @@ class MultiSegmentFetcher:
  Each key corresponds to a different type of data (e.g., "sequence", "mask").
  """
- def __init__(self, muti_segments: Dict):
- self.muti_keys = list(muti_segments.keys())
- self.muti_fetchers = {
- key: BaseSegmentFetcher(segments) for key, segments in muti_segments.items()
+ def __init__(self, multi_segments: Dict):
+ self.multi_keys = list(multi_segments.keys())
+ self.multi_fetchers = {
+ key: BaseSegmentFetcher(segments)
+ for key, segments in multi_segments.items()
  }
  def __len__(self) -> int:
  """Returns the minimum length across all fetchers."""
- len_list = [len(seg) for seg in self.muti_fetchers.values()]
+ len_list = [len(seg) for seg in self.multi_fetchers.values()]
  return min(len_list)
  def key_fetch(
@@ -100,7 +101,7 @@ class MultiSegmentFetcher:
  keys = [keys] if isinstance(keys, str) else keys
  for key in keys:
- fetcher = self.muti_fetchers[key]
+ fetcher = self.multi_fetchers[key]
  fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
  fetch_dict[key] = fetch_tensor
@@ -108,7 +109,7 @@ class MultiSegmentFetcher:
  def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
  """Fetch all keys."""
- return self.key_fetch(begin_idx, end_idx, self.muti_keys)
+ return self.key_fetch(begin_idx, end_idx, self.multi_keys)
  class BaseDataset(Dataset, ABC):

View File

@@ -257,7 +257,7 @@ class MLA(nn.Module):
  self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
  self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
- self.kv_norm = RMSNorm(kv_lora_rank, eps=norm_eps)
+ self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
  # KV (k_nope, k_rope, v)
  self.kv_b_proj = Linear(

View File

@@ -234,13 +234,13 @@ def train(
  },
  )
- toltal_steps = len(dataset) * n_epoch // (batch_size * nprocs)
+ total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
  scheduler_fn = partial(
  create_scheduler,
  **{
  "schedule_type": "cosine",
  "warmup_steps": warmup_steps,
- "lr_decay_steps": toltal_steps - warmup_steps,
+ "lr_decay_steps": total_steps - warmup_steps,
  },
  )