AstrAI/khaosz/trainer/__init__.py

33 lines
744 B
Python

from khaosz.trainer.trainer import Trainer
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
from khaosz.trainer.schedule import SchedulerFactory, BaseScheduler
from khaosz.trainer.train_callback import (
TrainCallback,
GradientClippingCallback,
SchedulerCallback,
CheckpointCallback,
ProgressBarCallback,
MetricLoggerCallback,
)
__all__ = [
# Main trainer
"Trainer",
# Strategy factory
"StrategyFactory",
"BaseStrategy",
# Scheduler factory
"SchedulerFactory",
"BaseScheduler",
# Callbacks
"TrainCallback",
"GradientClippingCallback",
"SchedulerCallback",
"CheckpointCallback",
"ProgressBarCallback",
"MetricLoggerCallback",
]