diff --git a/assets/docs/design.md b/assets/docs/design.md index c7d929a..5cae4e6 100644 --- a/assets/docs/design.md +++ b/assets/docs/design.md @@ -100,8 +100,8 @@ classDiagram } class MultiSegmentFetcher { - +Dict muti_fetchers - +List muti_keys + +Dict multi_fetchers + +List multi_keys +key_fetch(begin_idx, end_idx, keys) Dict +fetch_data(begin_idx, end_idx) Dict } @@ -148,7 +148,7 @@ classDiagram class Transformer { +ModelConfig config - +RotaryEmbedding rotary_embeding + +RotaryEmbedding rotary_embedding +Embedding embed_tokens +ModuleList layers +RMSNorm norm diff --git a/astrai/dataset/dataset.py b/astrai/dataset/dataset.py index 66a4b65..1a49b62 100644 --- a/astrai/dataset/dataset.py +++ b/astrai/dataset/dataset.py @@ -72,15 +72,16 @@ class MultiSegmentFetcher: Each key corresponds to a different type of data (e.g., "sequence", "mask"). """ - def __init__(self, muti_segments: Dict): - self.muti_keys = list(muti_segments.keys()) - self.muti_fetchers = { - key: BaseSegmentFetcher(segments) for key, segments in muti_segments.items() + def __init__(self, multi_segments: Dict): + self.multi_keys = list(multi_segments.keys()) + self.multi_fetchers = { + key: BaseSegmentFetcher(segments) + for key, segments in multi_segments.items() } def __len__(self) -> int: """Returns the minimum length across all fetchers.""" - len_list = [len(seg) for seg in self.muti_fetchers.values()] + len_list = [len(seg) for seg in self.multi_fetchers.values()] return min(len_list) def key_fetch( @@ -100,7 +101,7 @@ class MultiSegmentFetcher: keys = [keys] if isinstance(keys, str) else keys for key in keys: - fetcher = self.muti_fetchers[key] + fetcher = self.multi_fetchers[key] fetch_tensor = fetcher.fetch_data(begin_idx, end_idx) fetch_dict[key] = fetch_tensor @@ -108,7 +109,7 @@ class MultiSegmentFetcher: def fetch_data(self, begin_idx: int, end_idx: int) -> Dict: """Fetch all keys.""" - return self.key_fetch(begin_idx, end_idx, self.muti_keys) + return self.key_fetch(begin_idx, end_idx, self.multi_keys) class BaseDataset(Dataset, ABC): diff --git a/astrai/model/module.py b/astrai/model/module.py index 0d22166..28c64f6 100644 --- a/astrai/model/module.py +++ b/astrai/model/module.py @@ -257,7 +257,7 @@ class MLA(nn.Module): self.q_proj = Linear(dim, n_heads * self.head_dim, bias=False) self.kv_a_proj = Linear(dim, kv_lora_rank, bias=False) - self.kv_norm = RMSNorm(kv_lora_rank, eps=norm_eps) + self.kv_norm = RMSNorm(kv_lora_rank, norm_eps) # KV (k_nope, k_rope, v) self.kv_b_proj = Linear( diff --git a/scripts/tools/train.py b/scripts/tools/train.py index 5b02b3f..eb43382 100644 --- a/scripts/tools/train.py +++ b/scripts/tools/train.py @@ -234,13 +234,13 @@ def train( }, ) - toltal_steps = len(dataset) * n_epoch // (batch_size * nprocs) + total_steps = len(dataset) * n_epoch // (batch_size * nprocs) scheduler_fn = partial( create_scheduler, **{ "schedule_type": "cosine", "warmup_steps": warmup_steps, - "lr_decay_steps": toltal_steps - warmup_steps, + "lr_decay_steps": total_steps - warmup_steps, }, )