From 3bf24689051f26e3117d1f3a2a1a988df5f8f031 Mon Sep 17 00:00:00 2001
From: ViperEkura <3081035982@qq.com>
Date: Wed, 19 Nov 2025 17:47:33 +0800
Subject: [PATCH] =?UTF-8?q?fix(tools):=20=E4=BF=AE=E6=AD=A3=E8=AE=AD?=
 =?UTF-8?q?=E7=BB=83=E8=84=9A=E6=9C=AC=E4=B8=AD=E7=9A=84=E5=B5=8C=E5=85=A5?=
 =?UTF-8?q?=E5=B1=82=E5=8F=82=E6=95=B0=E5=88=86=E7=BB=84=E5=88=A4=E6=96=AD?=
 =?UTF-8?q?=E6=9D=A1=E4=BB=B6?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit
---
 tools/train.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tools/train.py b/tools/train.py
index c498d37..b87a9c5 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -80,8 +80,8 @@ def train(
     )

     param_groups = [
-        {"params": [p for n, p in model.named_parameters() if "embedding" in n], "lr": max_lr * embdeding_lr_rate},
-        {"params": [p for n, p in model.named_parameters() if "embedding" not in n], "lr": max_lr}
+        {"params": [p for n, p in model.named_parameters() if "embed" in n], "lr": max_lr * embdeding_lr_rate},
+        {"params": [p for n, p in model.named_parameters() if "embed" not in n], "lr": max_lr}
     ]

     optim = AdamW(