fix(tools): 修正训练脚本中的嵌入层参数分组判断条件
This commit is contained in:
parent
3c7ed84516
commit
3bf2468905
|
|
@ -80,8 +80,8 @@ def train(
|
||||||
)
|
)
|
||||||
|
|
||||||
param_groups = [
|
param_groups = [
|
||||||
{"params": [p for n, p in model.named_parameters() if "embedding" in n], "lr": max_lr * embdeding_lr_rate},
|
{"params": [p for n, p in model.named_parameters() if "embed" in n], "lr": max_lr * embdeding_lr_rate},
|
||||||
{"params": [p for n, p in model.named_parameters() if "embedding" not in n], "lr": max_lr}
|
{"params": [p for n, p in model.named_parameters() if "embed" not in n], "lr": max_lr}
|
||||||
]
|
]
|
||||||
|
|
||||||
optim = AdamW(
|
optim = AdamW(
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue