style: 修改为显式导入
This commit is contained in:
parent
3346c75584
commit
b531232a9b
|
|
@ -94,7 +94,7 @@ class Transformer(nn.Module):
|
|||
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||
self.lm_head = Linear(config.dim, config.vocab_size)
|
||||
|
||||
if self.config.tie_weight == True:
|
||||
if self.config.tie_weight:
|
||||
self.lm_head.weight = self.embed_tokens.weight
|
||||
|
||||
self._init_parameters()
|
||||
|
|
@ -103,7 +103,7 @@ class Transformer(nn.Module):
|
|||
lm_head_key = "lm_head.weight"
|
||||
embed_key = "embed_tokens.weight"
|
||||
|
||||
if self.config.tie_weight == True:
|
||||
if self.config.tie_weight:
|
||||
# same tensor
|
||||
state_dict[lm_head_key] = state_dict[embed_key]
|
||||
else:
|
||||
|
|
@ -118,7 +118,7 @@ class Transformer(nn.Module):
|
|||
destination=destination, prefix=prefix, keep_vars=keep_vars
|
||||
)
|
||||
|
||||
if self.config.tie_weight == True:
|
||||
if self.config.tie_weight:
|
||||
lm_head_key = prefix + "lm_head.weight"
|
||||
if lm_head_key in state_dict:
|
||||
del state_dict[lm_head_key]
|
||||
|
|
|
|||
|
|
@ -17,7 +17,6 @@ def test_single_process():
|
|||
for epoch in range(3):
|
||||
for iteration in range(10):
|
||||
x = torch.randn(32, 10)
|
||||
y = torch.randn(32, 5)
|
||||
loss = model(x).mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
|
@ -44,7 +43,6 @@ def simple_training():
|
|||
for epoch in range(2):
|
||||
for iteration in range(5):
|
||||
x = torch.randn(16, 10)
|
||||
y = torch.randn(16, 5)
|
||||
loss = model(x).mean()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
|
|
|||
|
|
@ -1,7 +1,7 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from astrai.data.dataset import *
|
||||
from astrai.data.dataset import DatasetFactory
|
||||
from astrai.data.serialization import save_h5
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,4 @@
|
|||
from astrai.data import *
|
||||
from astrai.trainer import *
|
||||
from astrai.data import ResumableDistributedSampler
|
||||
|
||||
|
||||
def test_random_sampler_consistency(random_dataset):
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ def test_health_no_model(client, monkeypatch):
|
|||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["status"] == "ok"
|
||||
assert data["model_loaded"] == False
|
||||
assert not data["model_loaded"]
|
||||
|
||||
|
||||
def test_health_with_model(client, loaded_model):
|
||||
|
|
|
|||
|
|
@ -2,11 +2,8 @@ import os
|
|||
|
||||
import torch
|
||||
|
||||
from astrai.config import *
|
||||
from astrai.data import *
|
||||
from astrai.config.param_config import ModelParameter
|
||||
from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore
|
||||
from astrai.model import *
|
||||
from astrai.trainer import *
|
||||
|
||||
|
||||
def test_model_parameter(test_env):
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ def transformer_test_env():
|
|||
for file in os.listdir(test_dir):
|
||||
os.remove(os.path.join(test_dir, file))
|
||||
os.rmdir(test_dir)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -58,8 +58,11 @@ def create_train_config(
|
|||
TrainConfig instance configured for testing
|
||||
"""
|
||||
|
||||
optimizer_fn = lambda m: torch.optim.AdamW(m.parameters(), lr=0.001)
|
||||
scheduler_fn = lambda optim: SchedulerFactory.create(
|
||||
def optimizer_fn(m):
|
||||
return torch.optim.AdamW(m.parameters(), lr=0.001)
|
||||
|
||||
def scheduler_fn(optim):
|
||||
return SchedulerFactory.create(
|
||||
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,13 +1,19 @@
|
|||
import torch
|
||||
|
||||
from astrai.config import *
|
||||
from astrai.trainer import *
|
||||
from astrai.config.train_config import TrainConfig
|
||||
from astrai.trainer.schedule import SchedulerFactory
|
||||
from astrai.trainer.train_callback import TrainCallback
|
||||
from astrai.trainer.trainer import Trainer
|
||||
|
||||
|
||||
def test_callback_integration(base_test_env, random_dataset):
|
||||
"""Test that all callbacks are properly integrated"""
|
||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
||||
scheduler_fn = lambda optim: SchedulerFactory.create(
|
||||
|
||||
def optimizer_fn(model):
|
||||
return torch.optim.AdamW(model.parameters())
|
||||
|
||||
def scheduler_fn(optim):
|
||||
return SchedulerFactory.create(
|
||||
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -3,16 +3,20 @@ import os
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from astrai.config import *
|
||||
from astrai.config.train_config import TrainConfig
|
||||
from astrai.data.serialization import Checkpoint
|
||||
from astrai.trainer import *
|
||||
from astrai.trainer.schedule import SchedulerFactory
|
||||
from astrai.trainer.trainer import Trainer
|
||||
|
||||
|
||||
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||
"""Simulate early stopping behavior"""
|
||||
|
||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
||||
scheduler_fn = lambda optim: SchedulerFactory.create(
|
||||
def optimizer_fn(model):
|
||||
return torch.optim.AdamW(model.parameters())
|
||||
|
||||
def scheduler_fn(optim):
|
||||
return SchedulerFactory.create(
|
||||
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,9 +1,7 @@
|
|||
import numpy as np
|
||||
import torch
|
||||
|
||||
from astrai.config import *
|
||||
from astrai.data.dataset import *
|
||||
from astrai.trainer.schedule import *
|
||||
from astrai.trainer.schedule import SchedulerFactory, CosineScheduler, SGDRScheduler
|
||||
|
||||
|
||||
def test_schedule_factory_random_configs():
|
||||
|
|
|
|||
|
|
@ -1,4 +1,3 @@
|
|||
from astrai.data.dataset import *
|
||||
from astrai.trainer import Trainer
|
||||
|
||||
# train_config_factory is injected via fixture
|
||||
|
|
|
|||
Loading…
Reference in New Issue