chore: 修改错误拼写
This commit is contained in:
parent
ace8f6ee68
commit
f2ffdf60d0
|
|
@ -100,8 +100,8 @@ classDiagram
|
|||
}
|
||||
|
||||
class MultiSegmentFetcher {
|
||||
+Dict muti_fetchers
|
||||
+List muti_keys
|
||||
+Dict multi_fetchers
|
||||
+List multi_keys
|
||||
+key_fetch(begin_idx, end_idx, keys) Dict
|
||||
+fetch_data(begin_idx, end_idx) Dict
|
||||
}
|
||||
|
|
@ -148,7 +148,7 @@ classDiagram
|
|||
|
||||
class Transformer {
|
||||
+ModelConfig config
|
||||
+RotaryEmbedding rotary_embeding
|
||||
+RotaryEmbedding rotary_embedding
|
||||
+Embedding embed_tokens
|
||||
+ModuleList layers
|
||||
+RMSNorm norm
|
||||
|
|
|
|||
|
|
@ -72,15 +72,16 @@ class MultiSegmentFetcher:
|
|||
Each key corresponds to a different type of data (e.g., "sequence", "mask").
|
||||
"""
|
||||
|
||||
def __init__(self, muti_segments: Dict):
|
||||
self.muti_keys = list(muti_segments.keys())
|
||||
self.muti_fetchers = {
|
||||
key: BaseSegmentFetcher(segments) for key, segments in muti_segments.items()
|
||||
def __init__(self, multi_segments: Dict):
|
||||
self.multi_keys = list(multi_segments.keys())
|
||||
self.multi_fetchers = {
|
||||
key: BaseSegmentFetcher(segments)
|
||||
for key, segments in multi_segments.items()
|
||||
}
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Returns the minimum length across all fetchers."""
|
||||
len_list = [len(seg) for seg in self.muti_fetchers.values()]
|
||||
len_list = [len(seg) for seg in self.multi_fetchers.values()]
|
||||
return min(len_list)
|
||||
|
||||
def key_fetch(
|
||||
|
|
@ -100,7 +101,7 @@ class MultiSegmentFetcher:
|
|||
keys = [keys] if isinstance(keys, str) else keys
|
||||
|
||||
for key in keys:
|
||||
fetcher = self.muti_fetchers[key]
|
||||
fetcher = self.multi_fetchers[key]
|
||||
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
|
||||
fetch_dict[key] = fetch_tensor
|
||||
|
||||
|
|
@ -108,7 +109,7 @@ class MultiSegmentFetcher:
|
|||
|
||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
||||
"""Fetch all keys."""
|
||||
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
|
||||
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
|
||||
|
||||
|
||||
class BaseDataset(Dataset, ABC):
|
||||
|
|
|
|||
|
|
@ -257,7 +257,7 @@ class MLA(nn.Module):
|
|||
|
||||
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
|
||||
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
||||
self.kv_norm = RMSNorm(kv_lora_rank, eps=norm_eps)
|
||||
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
||||
|
||||
# KV (k_nope, k_rope, v)
|
||||
self.kv_b_proj = Linear(
|
||||
|
|
|
|||
|
|
@ -234,13 +234,13 @@ def train(
|
|||
},
|
||||
)
|
||||
|
||||
toltal_steps = len(dataset) * n_epoch // (batch_size * nprocs)
|
||||
total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
|
||||
scheduler_fn = partial(
|
||||
create_scheduler,
|
||||
**{
|
||||
"schedule_type": "cosine",
|
||||
"warmup_steps": warmup_steps,
|
||||
"lr_decay_steps": toltal_steps - warmup_steps,
|
||||
"lr_decay_steps": total_steps - warmup_steps,
|
||||
},
|
||||
)
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue