chore: 修改错误拼写

This commit is contained in:
ViperEkura 2026-04-06 10:37:19 +08:00
parent ace8f6ee68
commit f2ffdf60d0
4 changed files with 14 additions and 13 deletions

View File

@ -100,8 +100,8 @@ classDiagram
} }
class MultiSegmentFetcher { class MultiSegmentFetcher {
+Dict muti_fetchers +Dict multi_fetchers
+List muti_keys +List multi_keys
+key_fetch(begin_idx, end_idx, keys) Dict +key_fetch(begin_idx, end_idx, keys) Dict
+fetch_data(begin_idx, end_idx) Dict +fetch_data(begin_idx, end_idx) Dict
} }
@ -148,7 +148,7 @@ classDiagram
class Transformer { class Transformer {
+ModelConfig config +ModelConfig config
+RotaryEmbedding rotary_embeding +RotaryEmbedding rotary_embedding
+Embedding embed_tokens +Embedding embed_tokens
+ModuleList layers +ModuleList layers
+RMSNorm norm +RMSNorm norm

View File

@ -72,15 +72,16 @@ class MultiSegmentFetcher:
Each key corresponds to a different type of data (e.g., "sequence", "mask"). Each key corresponds to a different type of data (e.g., "sequence", "mask").
""" """
def __init__(self, muti_segments: Dict): def __init__(self, multi_segments: Dict):
self.muti_keys = list(muti_segments.keys()) self.multi_keys = list(multi_segments.keys())
self.muti_fetchers = { self.multi_fetchers = {
key: BaseSegmentFetcher(segments) for key, segments in muti_segments.items() key: BaseSegmentFetcher(segments)
for key, segments in multi_segments.items()
} }
def __len__(self) -> int: def __len__(self) -> int:
"""Returns the minimum length across all fetchers.""" """Returns the minimum length across all fetchers."""
len_list = [len(seg) for seg in self.muti_fetchers.values()] len_list = [len(seg) for seg in self.multi_fetchers.values()]
return min(len_list) return min(len_list)
def key_fetch( def key_fetch(
@ -100,7 +101,7 @@ class MultiSegmentFetcher:
keys = [keys] if isinstance(keys, str) else keys keys = [keys] if isinstance(keys, str) else keys
for key in keys: for key in keys:
fetcher = self.muti_fetchers[key] fetcher = self.multi_fetchers[key]
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx) fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
fetch_dict[key] = fetch_tensor fetch_dict[key] = fetch_tensor
@ -108,7 +109,7 @@ class MultiSegmentFetcher:
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict: def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
"""Fetch all keys.""" """Fetch all keys."""
return self.key_fetch(begin_idx, end_idx, self.muti_keys) return self.key_fetch(begin_idx, end_idx, self.multi_keys)
class BaseDataset(Dataset, ABC): class BaseDataset(Dataset, ABC):

View File

@ -257,7 +257,7 @@ class MLA(nn.Module):
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False) self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False) self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
self.kv_norm = RMSNorm(kv_lora_rank, eps=norm_eps) self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
# KV (k_nope, k_rope, v) # KV (k_nope, k_rope, v)
self.kv_b_proj = Linear( self.kv_b_proj = Linear(

View File

@ -234,13 +234,13 @@ def train(
}, },
) )
toltal_steps = len(dataset) * n_epoch // (batch_size * nprocs) total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
scheduler_fn = partial( scheduler_fn = partial(
create_scheduler, create_scheduler,
**{ **{
"schedule_type": "cosine", "schedule_type": "cosine",
"warmup_steps": warmup_steps, "warmup_steps": warmup_steps,
"lr_decay_steps": toltal_steps - warmup_steps, "lr_decay_steps": total_steps - warmup_steps,
}, },
) )