AstrAI/astrai/trainer/__init__.py

30 lines
729 B
Python

from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
from astrai.trainer.train_callback import (
CheckpointCallback,
GradientClippingCallback,
MetricLoggerCallback,
ProgressBarCallback,
SchedulerCallback,
TrainCallback,
)
from astrai.trainer.trainer import Trainer
__all__ = [
# Main trainer
"Trainer",
# Strategy factory
"StrategyFactory",
"BaseStrategy",
# Scheduler factory
"SchedulerFactory",
"BaseScheduler",
# Callbacks
"TrainCallback",
"GradientClippingCallback",
"SchedulerCallback",
"CheckpointCallback",
"ProgressBarCallback",
"MetricLoggerCallback",
]