style: 修改为显式导入
This commit is contained in:
parent
3346c75584
commit
b531232a9b
|
|
@ -94,7 +94,7 @@ class Transformer(nn.Module):
|
||||||
self.norm = RMSNorm(config.dim, config.norm_eps)
|
self.norm = RMSNorm(config.dim, config.norm_eps)
|
||||||
self.lm_head = Linear(config.dim, config.vocab_size)
|
self.lm_head = Linear(config.dim, config.vocab_size)
|
||||||
|
|
||||||
if self.config.tie_weight == True:
|
if self.config.tie_weight:
|
||||||
self.lm_head.weight = self.embed_tokens.weight
|
self.lm_head.weight = self.embed_tokens.weight
|
||||||
|
|
||||||
self._init_parameters()
|
self._init_parameters()
|
||||||
|
|
@ -103,7 +103,7 @@ class Transformer(nn.Module):
|
||||||
lm_head_key = "lm_head.weight"
|
lm_head_key = "lm_head.weight"
|
||||||
embed_key = "embed_tokens.weight"
|
embed_key = "embed_tokens.weight"
|
||||||
|
|
||||||
if self.config.tie_weight == True:
|
if self.config.tie_weight:
|
||||||
# same tensor
|
# same tensor
|
||||||
state_dict[lm_head_key] = state_dict[embed_key]
|
state_dict[lm_head_key] = state_dict[embed_key]
|
||||||
else:
|
else:
|
||||||
|
|
@ -118,7 +118,7 @@ class Transformer(nn.Module):
|
||||||
destination=destination, prefix=prefix, keep_vars=keep_vars
|
destination=destination, prefix=prefix, keep_vars=keep_vars
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.config.tie_weight == True:
|
if self.config.tie_weight:
|
||||||
lm_head_key = prefix + "lm_head.weight"
|
lm_head_key = prefix + "lm_head.weight"
|
||||||
if lm_head_key in state_dict:
|
if lm_head_key in state_dict:
|
||||||
del state_dict[lm_head_key]
|
del state_dict[lm_head_key]
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,6 @@ def test_single_process():
|
||||||
for epoch in range(3):
|
for epoch in range(3):
|
||||||
for iteration in range(10):
|
for iteration in range(10):
|
||||||
x = torch.randn(32, 10)
|
x = torch.randn(32, 10)
|
||||||
y = torch.randn(32, 5)
|
|
||||||
loss = model(x).mean()
|
loss = model(x).mean()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
@ -44,7 +43,6 @@ def simple_training():
|
||||||
for epoch in range(2):
|
for epoch in range(2):
|
||||||
for iteration in range(5):
|
for iteration in range(5):
|
||||||
x = torch.randn(16, 10)
|
x = torch.randn(16, 10)
|
||||||
y = torch.randn(16, 5)
|
|
||||||
loss = model(x).mean()
|
loss = model(x).mean()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
optimizer.step()
|
optimizer.step()
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.data.dataset import *
|
from astrai.data.dataset import DatasetFactory
|
||||||
from astrai.data.serialization import save_h5
|
from astrai.data.serialization import save_h5
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,4 @@
|
||||||
from astrai.data import *
|
from astrai.data import ResumableDistributedSampler
|
||||||
from astrai.trainer import *
|
|
||||||
|
|
||||||
|
|
||||||
def test_random_sampler_consistency(random_dataset):
|
def test_random_sampler_consistency(random_dataset):
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ def test_health_no_model(client, monkeypatch):
|
||||||
assert response.status_code == 200
|
assert response.status_code == 200
|
||||||
data = response.json()
|
data = response.json()
|
||||||
assert data["status"] == "ok"
|
assert data["status"] == "ok"
|
||||||
assert data["model_loaded"] == False
|
assert not data["model_loaded"]
|
||||||
|
|
||||||
|
|
||||||
def test_health_with_model(client, loaded_model):
|
def test_health_with_model(client, loaded_model):
|
||||||
|
|
|
||||||
|
|
@ -2,11 +2,8 @@ import os
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config import *
|
from astrai.config.param_config import ModelParameter
|
||||||
from astrai.data import *
|
|
||||||
from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore
|
from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore
|
||||||
from astrai.model import *
|
|
||||||
from astrai.trainer import *
|
|
||||||
|
|
||||||
|
|
||||||
def test_model_parameter(test_env):
|
def test_model_parameter(test_env):
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ def transformer_test_env():
|
||||||
for file in os.listdir(test_dir):
|
for file in os.listdir(test_dir):
|
||||||
os.remove(os.path.join(test_dir, file))
|
os.remove(os.path.join(test_dir, file))
|
||||||
os.rmdir(test_dir)
|
os.rmdir(test_dir)
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -58,8 +58,11 @@ def create_train_config(
|
||||||
TrainConfig instance configured for testing
|
TrainConfig instance configured for testing
|
||||||
"""
|
"""
|
||||||
|
|
||||||
optimizer_fn = lambda m: torch.optim.AdamW(m.parameters(), lr=0.001)
|
def optimizer_fn(m):
|
||||||
scheduler_fn = lambda optim: SchedulerFactory.create(
|
return torch.optim.AdamW(m.parameters(), lr=0.001)
|
||||||
|
|
||||||
|
def scheduler_fn(optim):
|
||||||
|
return SchedulerFactory.create(
|
||||||
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,19 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config import *
|
from astrai.config.train_config import TrainConfig
|
||||||
from astrai.trainer import *
|
from astrai.trainer.schedule import SchedulerFactory
|
||||||
|
from astrai.trainer.train_callback import TrainCallback
|
||||||
|
from astrai.trainer.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
def test_callback_integration(base_test_env, random_dataset):
|
def test_callback_integration(base_test_env, random_dataset):
|
||||||
"""Test that all callbacks are properly integrated"""
|
"""Test that all callbacks are properly integrated"""
|
||||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
|
||||||
scheduler_fn = lambda optim: SchedulerFactory.create(
|
def optimizer_fn(model):
|
||||||
|
return torch.optim.AdamW(model.parameters())
|
||||||
|
|
||||||
|
def scheduler_fn(optim):
|
||||||
|
return SchedulerFactory.create(
|
||||||
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -3,16 +3,20 @@ import os
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config import *
|
from astrai.config.train_config import TrainConfig
|
||||||
from astrai.data.serialization import Checkpoint
|
from astrai.data.serialization import Checkpoint
|
||||||
from astrai.trainer import *
|
from astrai.trainer.schedule import SchedulerFactory
|
||||||
|
from astrai.trainer.trainer import Trainer
|
||||||
|
|
||||||
|
|
||||||
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
"""Simulate early stopping behavior"""
|
"""Simulate early stopping behavior"""
|
||||||
|
|
||||||
optimizer_fn = lambda model: torch.optim.AdamW(model.parameters())
|
def optimizer_fn(model):
|
||||||
scheduler_fn = lambda optim: SchedulerFactory.create(
|
return torch.optim.AdamW(model.parameters())
|
||||||
|
|
||||||
|
def scheduler_fn(optim):
|
||||||
|
return SchedulerFactory.create(
|
||||||
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,7 @@
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from astrai.config import *
|
from astrai.trainer.schedule import SchedulerFactory, CosineScheduler, SGDRScheduler
|
||||||
from astrai.data.dataset import *
|
|
||||||
from astrai.trainer.schedule import *
|
|
||||||
|
|
||||||
|
|
||||||
def test_schedule_factory_random_configs():
|
def test_schedule_factory_random_configs():
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
from astrai.data.dataset import *
|
|
||||||
from astrai.trainer import Trainer
|
from astrai.trainer import Trainer
|
||||||
|
|
||||||
# train_config_factory is injected via fixture
|
# train_config_factory is injected via fixture
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue