diff --git a/tools/train.py b/tools/train.py
index c498d37..b87a9c5 100644
--- a/tools/train.py
+++ b/tools/train.py
@@ -80,8 +80,8 @@ def train(
     )
     param_groups = [
-        {"params": [p for n, p in model.named_parameters() if "embedding" in n], "lr": max_lr * embdeding_lr_rate},
-        {"params": [p for n, p in model.named_parameters() if "embedding" not in n], "lr": max_lr}
+        {"params": [p for n, p in model.named_parameters() if "embed" in n], "lr": max_lr * embdeding_lr_rate},
+        {"params": [p for n, p in model.named_parameters() if "embed" not in n], "lr": max_lr}
     ]
     optim = AdamW(
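
Review note: the hunk widens the substring check from `"embedding"` to `"embed"`, so parameters with names such as `embed_tokens` that the old filter missed now also land in the reduced-LR group. (`embdeding_lr_rate` looks like a typo of `embedding_lr_rate`, but it appears on both sides of the diff, so renaming it would be a separate change.) Below is a minimal, self-contained sketch of the two-group pattern; the model, `max_lr`, and `embedding_lr_rate` values are illustrative stand-ins, not taken from `tools/train.py`:

```python
import torch.nn as nn
from torch.optim import AdamW

# Illustrative model: one embedding-named module, one other module.
model = nn.Sequential()
model.add_module("embed_tokens", nn.Embedding(1000, 64))
model.add_module("proj", nn.Linear(64, 1000))

max_lr = 3e-4
embedding_lr_rate = 0.1  # embeddings train at 10% of the base LR

param_groups = [
    # Parameters whose names contain "embed" get the scaled LR;
    # this catches "embed_tokens.weight" where "embedding" would not.
    {"params": [p for n, p in model.named_parameters() if "embed" in n],
     "lr": max_lr * embedding_lr_rate},
    # Everything else trains at the base LR.
    {"params": [p for n, p in model.named_parameters() if "embed" not in n],
     "lr": max_lr},
]
optim = AdamW(param_groups)

# Each group's LR is visible on the optimizer.
for g in optim.param_groups:
    print(g["lr"], sum(p.numel() for p in g["params"]))
```

Because each group carries its own `"lr"`, multiplicative schedulers such as `LambdaLR` scale both groups from their respective initial values, preserving the embedding/base ratio throughout training.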