style: change to explicit imports

ViperEkura 2026-04-04 16:02:49 +08:00
parent 3346c75584
commit b531232a9b
12 changed files with 38 additions and 34 deletions

View File

@@ -94,7 +94,7 @@ class Transformer(nn.Module):
         self.norm = RMSNorm(config.dim, config.norm_eps)
         self.lm_head = Linear(config.dim, config.vocab_size)
-        if self.config.tie_weight == True:
+        if self.config.tie_weight:
             self.lm_head.weight = self.embed_tokens.weight
         self._init_parameters()
@@ -103,7 +103,7 @@ class Transformer(nn.Module):
         lm_head_key = "lm_head.weight"
         embed_key = "embed_tokens.weight"
-        if self.config.tie_weight == True:
+        if self.config.tie_weight:
             # same tensor
             state_dict[lm_head_key] = state_dict[embed_key]
         else:
@@ -118,7 +118,7 @@ class Transformer(nn.Module):
             destination=destination, prefix=prefix, keep_vars=keep_vars
         )
-        if self.config.tie_weight == True:
+        if self.config.tie_weight:
            lm_head_key = prefix + "lm_head.weight"
            if lm_head_key in state_dict:
                del state_dict[lm_head_key]
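Reviewer note: the three hunks above all guard weight tying between the input embedding and the output projection. A minimal standalone sketch of the same idea (class and dimension names here are illustrative, not astrai's API):

import torch.nn as nn

class TiedLM(nn.Module):
    # Illustrative weight tying: Embedding(vocab, dim) and Linear(dim, vocab)
    # weights have the same shape, so they can share one Parameter.
    def __init__(self, vocab_size: int = 1000, dim: int = 64, tie_weight: bool = True):
        super().__init__()
        self.embed_tokens = nn.Embedding(vocab_size, dim)
        self.lm_head = nn.Linear(dim, vocab_size, bias=False)
        if tie_weight:
            # One shared tensor from here on, which is why the state_dict
            # logic above stores only the embedding copy and drops lm_head.weight.
            self.lm_head.weight = self.embed_tokens.weight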

View File

@@ -17,7 +17,6 @@ def test_single_process():
     for epoch in range(3):
         for iteration in range(10):
             x = torch.randn(32, 10)
-            y = torch.randn(32, 5)
             loss = model(x).mean()
             loss.backward()
             optimizer.step()
@@ -44,7 +43,6 @@ def simple_training():
     for epoch in range(2):
         for iteration in range(5):
             x = torch.randn(16, 10)
-            y = torch.randn(16, 5)
             loss = model(x).mean()
             loss.backward()
             optimizer.step()
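Reviewer note: both deleted lines allocated a tensor that was never read, so dropping them changes nothing about the test. Linters catch this pattern; pyflakes/flake8 report it as F841 ("local variable is assigned to but never used"):

import torch

def train_step(model):
    x = torch.randn(32, 10)
    y = torch.randn(32, 5)  # F841: assigned but never used (dead allocation)
    return model(x).mean()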

View File

@@ -1,7 +1,7 @@
 import numpy as np
 import torch
-from astrai.data.dataset import *
+from astrai.data.dataset import DatasetFactory
 from astrai.data.serialization import save_h5
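Reviewer note: this is the pattern the whole commit targets. `import *` hides which names a module actually uses and lets a later star-import shadow an earlier one. A small illustration with two hypothetical modules that both export a `load` function:

# a.py
def load(path):
    return "a"

# b.py
def load(path):
    return "b"

# main.py
from a import *
from b import *   # silently shadows a.load

load("data.h5")   # calls b.load; with explicit imports the clash is visible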

View File

@@ -1,5 +1,4 @@
-from astrai.data import *
-from astrai.trainer import *
+from astrai.data import ResumableDistributedSampler


 def test_random_sampler_consistency(random_dataset):

View File

@@ -10,7 +10,7 @@ def test_health_no_model(client, monkeypatch):
     assert response.status_code == 200
     data = response.json()
     assert data["status"] == "ok"
-    assert data["model_loaded"] == False
+    assert not data["model_loaded"]


 def test_health_with_model(client, loaded_model):
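Reviewer note: `== False` comparisons are flagged by pycodestyle as E712; for a genuine boolean the `not` form is equivalent and reads better:

model_loaded = False
assert model_loaded == False  # E712: comparison to False
assert not model_loaded       # idiomatic, same outcome for real booleans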

View File

@@ -2,11 +2,8 @@ import os
 import torch

-from astrai.config import *
-from astrai.data import *
+from astrai.config.param_config import ModelParameter
 from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore
-from astrai.model import *
-from astrai.trainer import *


 def test_model_parameter(test_env):

View File

@@ -36,7 +36,7 @@ def transformer_test_env():
         for file in os.listdir(test_dir):
             os.remove(os.path.join(test_dir, file))
         os.rmdir(test_dir)
-    except:
+    except Exception:
         pass
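Reviewer note: a bare `except:` catches `BaseException`, so it also swallows `KeyboardInterrupt` and `SystemExit`; `except Exception:` keeps those propagating. A sketch of the fixture's cleanup with that distinction spelled out (using `shutil.rmtree` for brevity):

import shutil

def cleanup(test_dir):
    try:
        shutil.rmtree(test_dir)
    except Exception:
        # OSError and friends are ignored here, but Ctrl-C (KeyboardInterrupt)
        # still aborts the run -- a bare `except:` would have swallowed it too.
        pass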

View File

@@ -58,10 +58,13 @@ def create_train_config(
         TrainConfig instance configured for testing
     """
-    optimizer_fn = lambda m: torch.optim.AdamW(m.parameters(), lr=0.001)
-    scheduler_fn = lambda optim: SchedulerFactory.create(
-        optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
-    )
+    def optimizer_fn(m):
+        return torch.optim.AdamW(m.parameters(), lr=0.001)
+
+    def scheduler_fn(optim):
+        return SchedulerFactory.create(
+            optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
+        )

     return TrainConfig(
         strategy=strategy,
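Reviewer note: PEP 8 (and flake8's E731) prefers `def` over binding a lambda to a name: the function gets a real `__name__` for tracebacks and the body can span multiple statements. The difference in one snippet:

import torch

make_opt = lambda m: torch.optim.AdamW(m.parameters(), lr=0.001)
print(make_opt.__name__)  # '<lambda>' -- opaque in tracebacks

def make_opt(m):
    return torch.optim.AdamW(m.parameters(), lr=0.001)

print(make_opt.__name__)  # 'make_opt'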

View File

@@ -1,15 +1,21 @@
 import torch

-from astrai.config import *
-from astrai.trainer import *
+from astrai.config.train_config import TrainConfig
+from astrai.trainer.schedule import SchedulerFactory
+from astrai.trainer.train_callback import TrainCallback
+from astrai.trainer.trainer import Trainer


 def test_callback_integration(base_test_env, random_dataset):
     """Test that all callbacks are properly integrated"""
-    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
-    scheduler_fn = lambda optim: SchedulerFactory.create(
-        optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
-    )
+    def optimizer_fn(model):
+        return torch.optim.AdamW(model.parameters())
+
+    def scheduler_fn(optim):
+        return SchedulerFactory.create(
+            optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
+        )

     train_config = TrainConfig(
         model=base_test_env["model"],
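Reviewer note: `SchedulerFactory` is astrai-internal, but a cosine schedule with linear warmup like the one configured here can be sketched on top of `torch.optim.lr_scheduler.LambdaLR`. The parameter semantics below (linear warmup, cosine decay to a floor of min_rate * base_lr, then hold) are an assumption, not the factory's documented contract:

import math
import torch

def cosine_with_warmup(optim, warmup_steps=10, lr_decay_steps=10, min_rate=0.05):
    def lr_lambda(step):
        if step < warmup_steps:
            return (step + 1) / warmup_steps          # linear warmup
        t = min(step - warmup_steps, lr_decay_steps) / lr_decay_steps
        # cosine decay from 1.0 down to min_rate, then hold
        return min_rate + (1 - min_rate) * 0.5 * (1 + math.cos(math.pi * t))
    return torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda)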

View File

@@ -3,18 +3,22 @@ import os
 import numpy as np
 import torch

-from astrai.config import *
+from astrai.config.train_config import TrainConfig
 from astrai.data.serialization import Checkpoint
-from astrai.trainer import *
+from astrai.trainer.schedule import SchedulerFactory
+from astrai.trainer.trainer import Trainer


 def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
     """Simulate early stopping behavior"""
-    optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
-    scheduler_fn = lambda optim: SchedulerFactory.create(
-        optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
-    )
+    def optimizer_fn(model):
+        return torch.optim.AdamW(model.parameters())
+
+    def scheduler_fn(optim):
+        return SchedulerFactory.create(
+            optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
+        )

     train_config = TrainConfig(
         strategy="seq",

View File

@@ -1,9 +1,7 @@
 import numpy as np
 import torch
-from astrai.config import *
-from astrai.data.dataset import *
-from astrai.trainer.schedule import *
+from astrai.trainer.schedule import SchedulerFactory, CosineScheduler, SGDRScheduler


 def test_schedule_factory_random_configs():

View File

@@ -1,4 +1,3 @@
-from astrai.data.dataset import *
 from astrai.trainer import Trainer

 # train_config_factory is injected via fixture