style: 修改为显式导入

This commit is contained in:
ViperEkura 2026-04-04 16:02:49 +08:00
parent 3346c75584
commit b531232a9b
12 changed files with 38 additions and 34 deletions

View File

@@ -94,7 +94,7 @@ class Transformer(nn.Module):
self.norm = RMSNorm(config.dim, config.norm_eps)
self.lm_head = Linear(config.dim, config.vocab_size)
if self.config.tie_weight == True:
if self.config.tie_weight:
self.lm_head.weight = self.embed_tokens.weight
self._init_parameters()
@@ -103,7 +103,7 @@ class Transformer(nn.Module):
lm_head_key = "lm_head.weight"
embed_key = "embed_tokens.weight"
if self.config.tie_weight == True:
if self.config.tie_weight:
# same tensor
state_dict[lm_head_key] = state_dict[embed_key]
else:
@@ -118,7 +118,7 @@ class Transformer(nn.Module):
destination=destination, prefix=prefix, keep_vars=keep_vars
)
if self.config.tie_weight == True:
if self.config.tie_weight:
lm_head_key = prefix + "lm_head.weight"
if lm_head_key in state_dict:
del state_dict[lm_head_key]

View File

@@ -17,7 +17,6 @@ def test_single_process():
for epoch in range(3):
for iteration in range(10):
x = torch.randn(32, 10)
y = torch.randn(32, 5)
loss = model(x).mean()
loss.backward()
optimizer.step()
@@ -44,7 +43,6 @@ def simple_training():
for epoch in range(2):
for iteration in range(5):
x = torch.randn(16, 10)
y = torch.randn(16, 5)
loss = model(x).mean()
loss.backward()
optimizer.step()

View File

@@ -1,7 +1,7 @@
import numpy as np
import torch
from astrai.data.dataset import *
from astrai.data.dataset import DatasetFactory
from astrai.data.serialization import save_h5

View File

@@ -1,5 +1,4 @@
from astrai.data import *
from astrai.trainer import *
from astrai.data import ResumableDistributedSampler
def test_random_sampler_consistency(random_dataset):

View File

@@ -10,7 +10,7 @@ def test_health_no_model(client, monkeypatch):
assert response.status_code == 200
data = response.json()
assert data["status"] == "ok"
assert data["model_loaded"] == False
assert not data["model_loaded"]
def test_health_with_model(client, loaded_model):

View File

@@ -2,11 +2,8 @@ import os
import torch
from astrai.config import *
from astrai.data import *
from astrai.config.param_config import ModelParameter
from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore
from astrai.model import *
from astrai.trainer import *
def test_model_parameter(test_env):

View File

@@ -36,7 +36,7 @@ def transformer_test_env():
for file in os.listdir(test_dir):
os.remove(os.path.join(test_dir, file))
os.rmdir(test_dir)
except:
except Exception:
pass

View File

@@ -58,10 +58,13 @@ def create_train_config(
TrainConfig instance configured for testing
"""
optimizer_fn = lambda m: torch.optim.AdamW(m.parameters(), lr=0.001)
scheduler_fn = lambda optim: SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)
def optimizer_fn(m):
return torch.optim.AdamW(m.parameters(), lr=0.001)
def scheduler_fn(optim):
return SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)
return TrainConfig(
strategy=strategy,

View File

@@ -1,15 +1,21 @@
import torch
from astrai.config import *
from astrai.trainer import *
from astrai.config.train_config import TrainConfig
from astrai.trainer.schedule import SchedulerFactory
from astrai.trainer.train_callback import TrainCallback
from astrai.trainer.trainer import Trainer
def test_callback_integration(base_test_env, random_dataset):
"""Test that all callbacks are properly integrated"""
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)
def optimizer_fn(model):
return torch.optim.AdamW(model.parameters())
def scheduler_fn(optim):
return SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)
train_config = TrainConfig(
model=base_test_env["model"],

View File

@@ -3,18 +3,22 @@ import os
import numpy as np
import torch
from astrai.config import *
from astrai.config.train_config import TrainConfig
from astrai.data.serialization import Checkpoint
from astrai.trainer import *
from astrai.trainer.schedule import SchedulerFactory
from astrai.trainer.trainer import Trainer
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
"""Simulate early stopping behavior"""
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
scheduler_fn = lambda optim: SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)
def optimizer_fn(model):
return torch.optim.AdamW(model.parameters())
def scheduler_fn(optim):
return SchedulerFactory.create(
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
)
train_config = TrainConfig(
strategy="seq",

View File

@@ -1,9 +1,7 @@
import numpy as np
import torch
from astrai.config import *
from astrai.data.dataset import *
from astrai.trainer.schedule import *
from astrai.trainer.schedule import SchedulerFactory, CosineScheduler, SGDRScheduler
def test_schedule_factory_random_configs():

View File

@@ -1,4 +1,3 @@
from astrai.data.dataset import *
from astrai.trainer import Trainer
# train_config_factory is injected via fixture