chore: 修改错误拼写
This commit is contained in:
parent
ace8f6ee68
commit
f2ffdf60d0
|
|
@ -100,8 +100,8 @@ classDiagram
|
||||||
}
|
}
|
||||||
|
|
||||||
class MultiSegmentFetcher {
|
class MultiSegmentFetcher {
|
||||||
+Dict muti_fetchers
|
+Dict multi_fetchers
|
||||||
+List muti_keys
|
+List multi_keys
|
||||||
+key_fetch(begin_idx, end_idx, keys) Dict
|
+key_fetch(begin_idx, end_idx, keys) Dict
|
||||||
+fetch_data(begin_idx, end_idx) Dict
|
+fetch_data(begin_idx, end_idx) Dict
|
||||||
}
|
}
|
||||||
|
|
@ -148,7 +148,7 @@ classDiagram
|
||||||
|
|
||||||
class Transformer {
|
class Transformer {
|
||||||
+ModelConfig config
|
+ModelConfig config
|
||||||
+RotaryEmbedding rotary_embeding
|
+RotaryEmbedding rotary_embedding
|
||||||
+Embedding embed_tokens
|
+Embedding embed_tokens
|
||||||
+ModuleList layers
|
+ModuleList layers
|
||||||
+RMSNorm norm
|
+RMSNorm norm
|
||||||
|
|
|
||||||
|
|
@ -72,15 +72,16 @@ class MultiSegmentFetcher:
|
||||||
Each key corresponds to a different type of data (e.g., "sequence", "mask").
|
Each key corresponds to a different type of data (e.g., "sequence", "mask").
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, muti_segments: Dict):
|
def __init__(self, multi_segments: Dict):
|
||||||
self.muti_keys = list(muti_segments.keys())
|
self.multi_keys = list(multi_segments.keys())
|
||||||
self.muti_fetchers = {
|
self.multi_fetchers = {
|
||||||
key: BaseSegmentFetcher(segments) for key, segments in muti_segments.items()
|
key: BaseSegmentFetcher(segments)
|
||||||
|
for key, segments in multi_segments.items()
|
||||||
}
|
}
|
||||||
|
|
||||||
def __len__(self) -> int:
|
def __len__(self) -> int:
|
||||||
"""Returns the minimum length across all fetchers."""
|
"""Returns the minimum length across all fetchers."""
|
||||||
len_list = [len(seg) for seg in self.muti_fetchers.values()]
|
len_list = [len(seg) for seg in self.multi_fetchers.values()]
|
||||||
return min(len_list)
|
return min(len_list)
|
||||||
|
|
||||||
def key_fetch(
|
def key_fetch(
|
||||||
|
|
@ -100,7 +101,7 @@ class MultiSegmentFetcher:
|
||||||
keys = [keys] if isinstance(keys, str) else keys
|
keys = [keys] if isinstance(keys, str) else keys
|
||||||
|
|
||||||
for key in keys:
|
for key in keys:
|
||||||
fetcher = self.muti_fetchers[key]
|
fetcher = self.multi_fetchers[key]
|
||||||
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
|
fetch_tensor = fetcher.fetch_data(begin_idx, end_idx)
|
||||||
fetch_dict[key] = fetch_tensor
|
fetch_dict[key] = fetch_tensor
|
||||||
|
|
||||||
|
|
@ -108,7 +109,7 @@ class MultiSegmentFetcher:
|
||||||
|
|
||||||
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
def fetch_data(self, begin_idx: int, end_idx: int) -> Dict:
|
||||||
"""Fetch all keys."""
|
"""Fetch all keys."""
|
||||||
return self.key_fetch(begin_idx, end_idx, self.muti_keys)
|
return self.key_fetch(begin_idx, end_idx, self.multi_keys)
|
||||||
|
|
||||||
|
|
||||||
class BaseDataset(Dataset, ABC):
|
class BaseDataset(Dataset, ABC):
|
||||||
|
|
|
||||||
|
|
@ -257,7 +257,7 @@ class MLA(nn.Module):
|
||||||
|
|
||||||
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
|
self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False)
|
||||||
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False)
|
||||||
self.kv_norm = RMSNorm(kv_lora_rank, eps=norm_eps)
|
self.kv_norm = RMSNorm(kv_lora_rank, norm_eps)
|
||||||
|
|
||||||
# KV (k_nope, k_rope, v)
|
# KV (k_nope, k_rope, v)
|
||||||
self.kv_b_proj = Linear(
|
self.kv_b_proj = Linear(
|
||||||
|
|
|
||||||
|
|
@ -234,13 +234,13 @@ def train(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
toltal_steps = len(dataset) * n_epoch // (batch_size * nprocs)
|
total_steps = len(dataset) * n_epoch // (batch_size * nprocs)
|
||||||
scheduler_fn = partial(
|
scheduler_fn = partial(
|
||||||
create_scheduler,
|
create_scheduler,
|
||||||
**{
|
**{
|
||||||
"schedule_type": "cosine",
|
"schedule_type": "cosine",
|
||||||
"warmup_steps": warmup_steps,
|
"warmup_steps": warmup_steps,
|
||||||
"lr_decay_steps": toltal_steps - warmup_steps,
|
"lr_decay_steps": total_steps - warmup_steps,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue