AstrAI/astrai/trainer/__init__.py

31 lines
730 B
Python

from astrai.trainer.trainer import Trainer
from astrai.trainer.strategy import StrategyFactory, BaseStrategy
from astrai.trainer.schedule import SchedulerFactory, BaseScheduler
from astrai.trainer.train_callback import (
TrainCallback,
GradientClippingCallback,
SchedulerCallback,
CheckpointCallback,
ProgressBarCallback,
MetricLoggerCallback,
)
__all__ = [
# Main trainer
"Trainer",
# Strategy factory
"StrategyFactory",
"BaseStrategy",
# Scheduler factory
"SchedulerFactory",
"BaseScheduler",
# Callbacks
"TrainCallback",
"GradientClippingCallback",
"SchedulerCallback",
"CheckpointCallback",
"ProgressBarCallback",
"MetricLoggerCallback",
]