From f2ffdf60d0e6b291b8aa83be53f2e6ccb553cf43 Mon Sep 17 00:00:00 2001 From: ViperEkura <3081035982@qq.com> Date: Mon, 6 Apr 2026 10:37:19 +0800 Subject: [PATCH] =?UTF-8?q?chore:=20=E4=BF=AE=E6=94=B9=E9=94=99=E8=AF=AF?= =?UTF-8?q?=E6=8B=BC=E5=86=99?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- assets/docs/design.md | 6 +++--- astrai/dataset/dataset.py | 15 ++++++++------- astrai/model/module.py | 2 +- scripts/tools/train.py | 4 ++-- 4 files changed, 14 insertions(+), 13 deletions(-) diff --git a/assets/docs/design.md b/assets/docs/design.md index c7d929a..5cae4e6 100644 --- a/assets/docs/design.md +++ b/assets/docs/design.md @@ -100,8 +100,8 @@ classDiagram } class MultiSegmentFetcher { - +Dict muti_fetchers - +List muti_keys + +Dict multi_fetchers + +List multi_keys +key_fetch(begin_idx, end_idx, keys) Dict +fetch_data(begin_idx, end_idx) Dict } @@ -148,7 +148,7 @@ classDiagram class Transformer { +ModelConfig config - +RotaryEmbedding rotary_embeding + +RotaryEmbedding rotary_embedding +Embedding embed_tokens +ModuleList layers +RMSNorm norm diff --git a/astrai/dataset/dataset.py b/astrai/dataset/dataset.py index 66a4b65..1a49b62 100644 --- a/astrai/dataset/dataset.py +++ b/astrai/dataset/dataset.py @@ -72,15 +72,16 @@ class MultiSegmentFetcher: Each key corresponds to a different type of data (e.g., "sequence", "mask"). """ - def __init__(self, muti_segments: Dict): - self.muti_keys = list(muti_segments.keys()) - self.muti_fetchers = { - key: BaseSegmentFetcher(segments) for key, segments in muti_segments.items() + def __init__(self, multi_segments: Dict): + self.multi_keys = list(multi_segments.keys()) + self.multi_fetchers = { + key: BaseSegmentFetcher(segments) + for key, segments in multi_segments.items() } def __len__(self) -> int: """Returns the minimum length across all fetchers.""" - len_list = [len(seg) for seg in self.muti_fetchers.values()] + len_list = [len(seg) for seg in self.multi_fetchers.values()] return min(len_list) def key_fetch( @@ -100,7 +101,7 @@ class MultiSegmentFetcher: keys = [keys] if isinstance(keys, str) else keys for key in keys: - fetcher = self.muti_fetchers[key] + fetcher = self.multi_fetchers[key] fetch_tensor = fetcher.fetch_data(begin_idx, end_idx) fetch_dict[key] = fetch_tensor @@ -108,7 +109,7 @@ class MultiSegmentFetcher: def fetch_data(self, begin_idx: int, end_idx: int) -> Dict: """Fetch all keys.""" - return self.key_fetch(begin_idx, end_idx, self.muti_keys) + return self.key_fetch(begin_idx, end_idx, self.multi_keys) class BaseDataset(Dataset, ABC): diff --git a/astrai/model/module.py b/astrai/model/module.py index 0d22166..28c64f6 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -257,7 +257,7 @@ class MLA(nn.Module): self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False) self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False) - self.kv_norm = RMSNorm(kv_lora_rank, eps=norm_eps) + self.kv_norm = RMSNorm(kv_lora_rank, norm_eps) # KV (k_nope, k_rope, v) self.kv_b_proj = Linear( diff --git a/scripts/tools/train.py b/scripts/tools/train.py index 5b02b3f..eb43382 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -234,13 +234,13 @@ def train( }, ) - toltal_steps = len(dataset) * n_epoch // (batch_size * nprocs) + total_steps = len(dataset) * n_epoch // (batch_size * nprocs) scheduler_fn = partial( create_scheduler, **{ "schedule_type": "cosine", "warmup_steps": warmup_steps, - "lr_decay_steps": toltal_steps - warmup_steps, + "lr_decay_steps": total_steps - warmup_steps, }, )