diff --git a/astrai/model/transformer.py b/astrai/model/transformer.py index 79746bb..281cd4a 100644 --- a/astrai/model/transformer.py +++ b/astrai/model/transformer.py @@ -94,7 +94,7 @@ class Transformer(nn.Module): self.norm = RMSNorm(config.dim, config.norm_eps) self.lm_head = Linear(config.dim, config.vocab_size) - if self.config.tie_weight == True: + if self.config.tie_weight: self.lm_head.weight = self.embed_tokens.weight self._init_parameters() @@ -103,7 +103,7 @@ class Transformer(nn.Module): lm_head_key = "lm_head.weight" embed_key = "embed_tokens.weight" - if self.config.tie_weight == True: + if self.config.tie_weight: # same tensor state_dict[lm_head_key] = state_dict[embed_key] else: @@ -118,7 +118,7 @@ class Transformer(nn.Module): destination=destination, prefix=prefix, keep_vars=keep_vars ) - if self.config.tie_weight == True: + if self.config.tie_weight: lm_head_key = prefix + "lm_head.weight" if lm_head_key in state_dict: del state_dict[lm_head_key] diff --git a/tests/data/test_checkpoint.py b/tests/data/test_checkpoint.py index f261a13..1285a1e 100644 --- a/tests/data/test_checkpoint.py +++ b/tests/data/test_checkpoint.py @@ -17,7 +17,6 @@ def test_single_process(): for epoch in range(3): for iteration in range(10): x = torch.randn(32, 10) - y = torch.randn(32, 5) loss = model(x).mean() loss.backward() optimizer.step() @@ -44,7 +43,6 @@ def simple_training(): for epoch in range(2): for iteration in range(5): x = torch.randn(16, 10) - y = torch.randn(16, 5) loss = model(x).mean() loss.backward() optimizer.step() diff --git a/tests/data/test_dataset.py b/tests/data/test_dataset.py index 6de18af..580b3ea 100644 --- a/tests/data/test_dataset.py +++ b/tests/data/test_dataset.py @@ -1,7 +1,7 @@ import numpy as np import torch -from astrai.data.dataset import * +from astrai.data.dataset import DatasetFactory from astrai.data.serialization import save_h5 diff --git a/tests/data/test_sampler.py b/tests/data/test_sampler.py index 8ac821c..045a567 100644 --- a/tests/data/test_sampler.py +++ b/tests/data/test_sampler.py @@ -1,5 +1,4 @@ -from astrai.data import * -from astrai.trainer import * +from astrai.data import ResumableDistributedSampler def test_random_sampler_consistency(random_dataset): diff --git a/tests/inference/test_server.py b/tests/inference/test_server.py index 1e1bb92..00bc61f 100644 --- a/tests/inference/test_server.py +++ b/tests/inference/test_server.py @@ -10,7 +10,7 @@ def test_health_no_model(client, monkeypatch): assert response.status_code == 200 data = response.json() assert data["status"] == "ok" - assert data["model_loaded"] == False + assert not data["model_loaded"] def test_health_with_model(client, loaded_model): diff --git a/tests/module/test_module.py b/tests/module/test_module.py index 19e63ca..06bf989 100644 --- a/tests/module/test_module.py +++ b/tests/module/test_module.py @@ -2,11 +2,8 @@ import os import torch -from astrai.config import * -from astrai.data import * +from astrai.config.param_config import ModelParameter from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore -from astrai.model import * -from astrai.trainer import * def test_model_parameter(test_env): diff --git a/tests/module/test_tie_weight.py b/tests/module/test_tie_weight.py index 143fd68..e15b4ac 100644 --- a/tests/module/test_tie_weight.py +++ b/tests/module/test_tie_weight.py @@ -36,7 +36,7 @@ def transformer_test_env(): for file in os.listdir(test_dir): os.remove(os.path.join(test_dir, file)) os.rmdir(test_dir) - except: + except Exception: pass diff --git a/tests/trainer/conftest.py b/tests/trainer/conftest.py index 2a086a0..4efa745 100644 --- a/tests/trainer/conftest.py +++ b/tests/trainer/conftest.py @@ -58,10 +58,13 @@ def create_train_config( TrainConfig instance configured for testing """ - optimizer_fn = lambda m: torch.optim.AdamW(m.parameters(), lr=0.001) - scheduler_fn = lambda optim: SchedulerFactory.create( - optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 - ) + def optimizer_fn(m): + return torch.optim.AdamW(m.parameters(), lr=0.001) + + def scheduler_fn(optim): + return SchedulerFactory.create( + optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 + ) return TrainConfig( strategy=strategy, diff --git a/tests/trainer/test_callbacks.py b/tests/trainer/test_callbacks.py index afd1c57..0238eab 100644 --- a/tests/trainer/test_callbacks.py +++ b/tests/trainer/test_callbacks.py @@ -1,15 +1,21 @@ import torch -from astrai.config import * -from astrai.trainer import * +from astrai.config.train_config import TrainConfig +from astrai.trainer.schedule import SchedulerFactory +from astrai.trainer.train_callback import TrainCallback +from astrai.trainer.trainer import Trainer def test_callback_integration(base_test_env, random_dataset): """Test that all callbacks are properly integrated""" - optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) - scheduler_fn = lambda optim: SchedulerFactory.create( - optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 - ) + + def optimizer_fn(model): + return torch.optim.AdamW(model.parameters()) + + def scheduler_fn(optim): + return SchedulerFactory.create( + optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 + ) train_config = TrainConfig( model=base_test_env["model"], diff --git a/tests/trainer/test_early_stopping.py b/tests/trainer/test_early_stopping.py index 58d2bb5..052fd2a 100644 --- a/tests/trainer/test_early_stopping.py +++ b/tests/trainer/test_early_stopping.py @@ -3,18 +3,22 @@ import os import numpy as np import torch -from astrai.config import * +from astrai.config.train_config import TrainConfig from astrai.data.serialization import Checkpoint -from astrai.trainer import * +from astrai.trainer.schedule import SchedulerFactory +from astrai.trainer.trainer import Trainer def test_early_stopping_simulation(base_test_env, early_stopping_dataset): """Simulate early stopping behavior""" - optimizer_fn = lambda model: torch.optim.AdamW(model.parameters()) - scheduler_fn = lambda optim: SchedulerFactory.create( - optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 - ) + def optimizer_fn(model): + return torch.optim.AdamW(model.parameters()) + + def scheduler_fn(optim): + return SchedulerFactory.create( + optim, "cosine", warmup_steps=10, lr_decay_steps=10, min_rate=0.05 + ) train_config = TrainConfig( strategy="seq", diff --git a/tests/trainer/test_train_strategy.py b/tests/trainer/test_train_strategy.py index 7138312..6e347b0 100644 --- a/tests/trainer/test_train_strategy.py +++ b/tests/trainer/test_train_strategy.py @@ -1,9 +1,7 @@ import numpy as np import torch -from astrai.config import * -from astrai.data.dataset import * -from astrai.trainer.schedule import * +from astrai.trainer.schedule import SchedulerFactory, CosineScheduler, SGDRScheduler def test_schedule_factory_random_configs(): diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c17b5c7..bb6520d 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1,4 +1,3 @@ -from astrai.data.dataset import * from astrai.trainer import Trainer # train_config_factory is injected via fixture