feat: 优化工厂模式的实现
This commit is contained in:
parent
aa5e03d7f6
commit
3346c75584
|
|
@ -5,7 +5,7 @@ from astrai.config import (
|
||||||
ModelConfig,
|
ModelConfig,
|
||||||
TrainConfig,
|
TrainConfig,
|
||||||
)
|
)
|
||||||
from astrai.core.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
from astrai.data import BpeTokenizer, DatasetFactory
|
from astrai.data import BpeTokenizer, DatasetFactory
|
||||||
from astrai.inference.generator import (
|
from astrai.inference.generator import (
|
||||||
BatchGenerator,
|
BatchGenerator,
|
||||||
|
|
|
||||||
|
|
@ -1,105 +0,0 @@
|
||||||
"""Base factory class for extensible component registration."""
|
|
||||||
|
|
||||||
from abc import ABC
|
|
||||||
from typing import Callable, Dict, Generic, Type, TypeVar
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
class BaseFactory(ABC, Generic[T]):
|
|
||||||
"""Generic factory class for component registration and creation.
|
|
||||||
|
|
||||||
This base class provides a decorator-based registration pattern
|
|
||||||
for creating extensible component factories.
|
|
||||||
|
|
||||||
Example usage:
|
|
||||||
class MyFactory(BaseFactory[MyBaseClass]):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@MyFactory.register("custom")
|
|
||||||
class CustomComponent(MyBaseClass):
|
|
||||||
...
|
|
||||||
|
|
||||||
component = MyFactory.create("custom", *args, **kwargs)
|
|
||||||
"""
|
|
||||||
|
|
||||||
_registry: Dict[str, Type[T]] = {}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register(cls, name: str) -> Callable[[Type[T]], Type[T]]:
|
|
||||||
"""Decorator to register a component class.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Registration name for the component
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Decorator function that registers the component class
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: If the decorated class doesn't inherit from the base type
|
|
||||||
"""
|
|
||||||
|
|
||||||
def decorator(component_cls: Type[T]) -> Type[T]:
|
|
||||||
cls._validate_component(component_cls)
|
|
||||||
cls._registry[name] = component_cls
|
|
||||||
return component_cls
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def create(cls, name: str, *args, **kwargs) -> T:
|
|
||||||
"""Create a component instance by name.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Registered name of the component
|
|
||||||
*args: Positional arguments passed to component constructor
|
|
||||||
**kwargs: Keyword arguments passed to component constructor
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Component instance
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
ValueError: If the component name is not registered
|
|
||||||
"""
|
|
||||||
if name not in cls._registry:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown component: '{name}'. "
|
|
||||||
f"Supported types: {sorted(cls._registry.keys())}"
|
|
||||||
)
|
|
||||||
component_cls = cls._registry[name]
|
|
||||||
return component_cls(*args, **kwargs)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _validate_component(cls, component_cls: Type[T]) -> None:
|
|
||||||
"""Validate that the component class is valid for this factory.
|
|
||||||
|
|
||||||
Override this method in subclasses to add custom validation.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
component_cls: Component class to validate
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
TypeError: If the component class is invalid
|
|
||||||
"""
|
|
||||||
pass
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def list_registered(cls) -> list:
|
|
||||||
"""List all registered component names.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of registered component names
|
|
||||||
"""
|
|
||||||
return sorted(cls._registry.keys())
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def is_registered(cls, name: str) -> bool:
|
|
||||||
"""Check if a component name is registered.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
name: Component name to check
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
True if registered, False otherwise
|
|
||||||
"""
|
|
||||||
return name in cls._registry
|
|
||||||
|
|
@ -1,7 +1,6 @@
|
||||||
from astrai.data.dataset import (
|
from astrai.data.dataset import (
|
||||||
BaseDataset,
|
BaseDataset,
|
||||||
DatasetFactory,
|
DatasetFactory,
|
||||||
DatasetFactory,
|
|
||||||
DPODataset,
|
DPODataset,
|
||||||
GRPODataset,
|
GRPODataset,
|
||||||
MultiSegmentFetcher,
|
MultiSegmentFetcher,
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
|
|
||||||
from astrai.core.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
from astrai.data.serialization import load_h5
|
from astrai.data.serialization import load_h5
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -181,8 +181,6 @@ class DatasetFactory(BaseFactory["BaseDataset"]):
|
||||||
dataset = DatasetFactory.create("custom", window_size, stride)
|
dataset = DatasetFactory.create("custom", window_size, stride)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_registry: Dict[str, type] = {}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, dataset_cls: type) -> None:
|
def _validate_component(cls, dataset_cls: type) -> None:
|
||||||
"""Validate that the dataset class inherits from BaseDataset."""
|
"""Validate that the dataset class inherits from BaseDataset."""
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,187 @@
|
||||||
|
"""Base factory class for extensible component registration."""
|
||||||
|
|
||||||
|
from abc import ABC
|
||||||
|
from typing import Callable, Dict, Generic, Type, TypeVar, Optional, List, Tuple
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class Registry:
|
||||||
|
"""Flexible registry for component classes with category and priority support.
|
||||||
|
|
||||||
|
This registry stores component classes with optional metadata (category, priority).
|
||||||
|
It provides methods for registration, retrieval, and listing with filtering.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._entries = {} # name -> (component_cls, category, priority)
|
||||||
|
|
||||||
|
def register(
|
||||||
|
self,
|
||||||
|
name: str,
|
||||||
|
component_cls: Type,
|
||||||
|
category: Optional[str] = None,
|
||||||
|
priority: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Register a component class with optional category and priority."""
|
||||||
|
if name in self._entries:
|
||||||
|
raise ValueError(f"Component '{name}' is already registered")
|
||||||
|
self._entries[name] = (component_cls, category, priority)
|
||||||
|
|
||||||
|
def get(self, name: str) -> Type:
|
||||||
|
"""Get component class by name."""
|
||||||
|
if name not in self._entries:
|
||||||
|
raise KeyError(f"Component '{name}' not found in registry")
|
||||||
|
return self._entries[name][0]
|
||||||
|
|
||||||
|
def get_with_metadata(self, name: str) -> Tuple[Type, Optional[str], int]:
|
||||||
|
"""Get component class with its metadata."""
|
||||||
|
entry = self._entries.get(name)
|
||||||
|
if entry is None:
|
||||||
|
raise KeyError(f"Component '{name}' not found in registry")
|
||||||
|
return entry
|
||||||
|
|
||||||
|
def contains(self, name: str) -> bool:
|
||||||
|
"""Check if a name is registered."""
|
||||||
|
return name in self._entries
|
||||||
|
|
||||||
|
def list_names(self) -> List[str]:
|
||||||
|
"""Return list of registered component names."""
|
||||||
|
return sorted(self._entries.keys())
|
||||||
|
|
||||||
|
def list_by_category(self, category: str) -> List[str]:
|
||||||
|
"""Return names of components belonging to a specific category."""
|
||||||
|
return sorted(
|
||||||
|
name for name, (_, cat, _) in self._entries.items() if cat == category
|
||||||
|
)
|
||||||
|
|
||||||
|
def list_by_priority(self, reverse: bool = False) -> List[str]:
|
||||||
|
"""Return names sorted by priority (default ascending)."""
|
||||||
|
return sorted(
|
||||||
|
self._entries.keys(),
|
||||||
|
key=lambda name: self._entries[name][2],
|
||||||
|
reverse=reverse,
|
||||||
|
)
|
||||||
|
|
||||||
|
def entries(self) -> Dict[str, Tuple[Type, Optional[str], int]]:
|
||||||
|
"""Return raw entries dictionary."""
|
||||||
|
return self._entries.copy()
|
||||||
|
|
||||||
|
|
||||||
|
class BaseFactory(ABC, Generic[T]):
|
||||||
|
"""Generic factory class for component registration and creation.
|
||||||
|
|
||||||
|
This base class provides a decorator-based registration pattern
|
||||||
|
for creating extensible component factories.
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
class MyFactory(BaseFactory[MyBaseClass]):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@MyFactory.register("custom")
|
||||||
|
class CustomComponent(MyBaseClass):
|
||||||
|
...
|
||||||
|
|
||||||
|
component = MyFactory.create("custom", *args, **kwargs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
_registry: Registry
|
||||||
|
|
||||||
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
super().__init_subclass__(**kwargs)
|
||||||
|
cls._registry = Registry()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(
|
||||||
|
cls, name: str, category: Optional[str] = None, priority: int = 0
|
||||||
|
) -> Callable[[Type[T]], Type[T]]:
|
||||||
|
"""Decorator to register a component class with optional category and priority.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Registration name for the component
|
||||||
|
category: Optional category for grouping components
|
||||||
|
priority: Priority for ordering (default 0)
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Decorator function that registers the component class
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the decorated class doesn't inherit from the base type
|
||||||
|
"""
|
||||||
|
|
||||||
|
def decorator(component_cls: Type[T]) -> Type[T]:
|
||||||
|
cls._validate_component(component_cls)
|
||||||
|
cls._registry.register(
|
||||||
|
name, component_cls, category=category, priority=priority
|
||||||
|
)
|
||||||
|
return component_cls
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def create(cls, name: str, *args, **kwargs) -> T:
|
||||||
|
"""Create a component instance by name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Registered name of the component
|
||||||
|
*args: Positional arguments passed to component constructor
|
||||||
|
**kwargs: Keyword arguments passed to component constructor
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Component instance
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the component name is not registered
|
||||||
|
"""
|
||||||
|
if not cls._registry.contains(name):
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown component: '{name}'. "
|
||||||
|
f"Supported types: {sorted(cls._registry.list_names())}"
|
||||||
|
)
|
||||||
|
component_cls = cls._registry.get(name)
|
||||||
|
return component_cls(*args, **kwargs)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_component(cls, component_cls: Type[T]) -> None:
|
||||||
|
"""Validate that the component class is valid for this factory.
|
||||||
|
|
||||||
|
Override this method in subclasses to add custom validation.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
component_cls: Component class to validate
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
TypeError: If the component class is invalid
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_registered(cls) -> list:
|
||||||
|
"""List all registered component names.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of registered component names
|
||||||
|
"""
|
||||||
|
return cls._registry.list_names()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_registered(cls, name: str) -> bool:
|
||||||
|
"""Check if a component name is registered.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
name: Component name to check
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if registered, False otherwise
|
||||||
|
"""
|
||||||
|
return cls._registry.contains(name)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_by_category(cls, category: str) -> List[str]:
|
||||||
|
"""List registered component names in a category."""
|
||||||
|
return cls._registry.list_by_category(category)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def list_by_priority(cls, reverse: bool = False) -> List[str]:
|
||||||
|
"""List registered component names sorted by priority."""
|
||||||
|
return cls._registry.list_by_priority(reverse)
|
||||||
|
|
@ -6,7 +6,7 @@ from jinja2 import Template
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from astrai.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
from astrai.core.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
|
from astrai.inference.core import EmbeddingEncoderCore, GeneratorCore, KVCacheManager
|
||||||
|
|
||||||
HistoryType = List[Tuple[str, str]]
|
HistoryType = List[Tuple[str, str]]
|
||||||
|
|
@ -299,8 +299,6 @@ class GeneratorFactory(BaseFactory[GeneratorCore]):
|
||||||
result = generator.generate(request)
|
result = generator.generate(request)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_registry: Dict[str, type] = {}
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def create(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
|
def create(parameter: ModelParameter, request: GenerationRequest) -> GeneratorCore:
|
||||||
"""Create a generator based on request characteristics.
|
"""Create a generator based on request characteristics.
|
||||||
|
|
|
||||||
|
|
@ -1,12 +1,8 @@
|
||||||
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
|
from astrai.trainer.schedule import BaseScheduler, SchedulerFactory
|
||||||
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
from astrai.trainer.strategy import BaseStrategy, StrategyFactory
|
||||||
from astrai.trainer.train_callback import (
|
from astrai.trainer.train_callback import (
|
||||||
CheckpointCallback,
|
|
||||||
GradientClippingCallback,
|
|
||||||
MetricLoggerCallback,
|
|
||||||
ProgressBarCallback,
|
|
||||||
SchedulerCallback,
|
|
||||||
TrainCallback,
|
TrainCallback,
|
||||||
|
CallbackFactory,
|
||||||
)
|
)
|
||||||
from astrai.trainer.trainer import Trainer
|
from astrai.trainer.trainer import Trainer
|
||||||
|
|
||||||
|
|
@ -19,11 +15,7 @@ __all__ = [
|
||||||
# Scheduler factory
|
# Scheduler factory
|
||||||
"SchedulerFactory",
|
"SchedulerFactory",
|
||||||
"BaseScheduler",
|
"BaseScheduler",
|
||||||
# Callbacks
|
# Callback factory
|
||||||
"TrainCallback",
|
"TrainCallback",
|
||||||
"GradientClippingCallback",
|
"CallbackFactory",
|
||||||
"SchedulerCallback",
|
|
||||||
"CheckpointCallback",
|
|
||||||
"ProgressBarCallback",
|
|
||||||
"MetricLoggerCallback",
|
|
||||||
]
|
]
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,7 @@ from typing import Any, Dict, List, Type
|
||||||
|
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
|
|
||||||
from astrai.core.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
class BaseScheduler(LRScheduler, ABC):
|
class BaseScheduler(LRScheduler, ABC):
|
||||||
|
|
@ -41,8 +41,6 @@ class SchedulerFactory(BaseFactory["BaseScheduler"]):
|
||||||
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
|
scheduler = SchedulerFactory.create("custom", optimizer, **kwargs)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_registry: Dict[str, Type[BaseScheduler]] = {}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
|
def _validate_component(cls, scheduler_cls: Type[BaseScheduler]) -> None:
|
||||||
"""Validate that the scheduler class inherits from BaseScheduler."""
|
"""Validate that the scheduler class inherits from BaseScheduler."""
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ import torch.nn.functional as F
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
from astrai.core.factory import BaseFactory
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
def unwrap_model(model: nn.Module) -> nn.Module:
|
def unwrap_model(model: nn.Module) -> nn.Module:
|
||||||
|
|
@ -122,8 +122,6 @@ class StrategyFactory(BaseFactory["BaseStrategy"]):
|
||||||
strategy = StrategyFactory.create("custom", model, device)
|
strategy = StrategyFactory.create("custom", model, device)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
_registry: Dict[str, type] = {}
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def _validate_component(cls, strategy_cls: type) -> None:
|
def _validate_component(cls, strategy_cls: type) -> None:
|
||||||
"""Validate that the strategy class inherits from BaseStrategy."""
|
"""Validate that the strategy class inherits from BaseStrategy."""
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ import json
|
||||||
import os
|
import os
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, List, Optional, Protocol
|
from typing import Callable, List, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
|
|
@ -21,8 +21,10 @@ from astrai.trainer.metric_util import (
|
||||||
ctx_get_lr,
|
ctx_get_lr,
|
||||||
)
|
)
|
||||||
from astrai.trainer.train_context import TrainContext
|
from astrai.trainer.train_context import TrainContext
|
||||||
|
from astrai.factory import BaseFactory
|
||||||
|
|
||||||
|
|
||||||
|
@runtime_checkable
|
||||||
class TrainCallback(Protocol):
|
class TrainCallback(Protocol):
|
||||||
"""
|
"""
|
||||||
Callback interface for trainer.
|
Callback interface for trainer.
|
||||||
|
|
@ -56,6 +58,25 @@ class TrainCallback(Protocol):
|
||||||
"""Called when an error occurs during training."""
|
"""Called when an error occurs during training."""
|
||||||
|
|
||||||
|
|
||||||
|
class CallbackFactory(BaseFactory[TrainCallback]):
|
||||||
|
"""Factory for registering and creating training callbacks.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
@CallbackFactory.register("my_callback")
|
||||||
|
class MyCallback(TrainCallback):
|
||||||
|
...
|
||||||
|
|
||||||
|
callback = CallbackFactory.create("my_callback", **kwargs)
|
||||||
|
"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _validate_component(cls, callback_cls: type) -> None:
|
||||||
|
"""Validate that the callback class inherits from TrainCallback."""
|
||||||
|
if not issubclass(callback_cls, TrainCallback):
|
||||||
|
raise TypeError(f"{callback_cls.__name__} must inherit from TrainCallback")
|
||||||
|
|
||||||
|
|
||||||
|
@CallbackFactory.register("gradient_clipping")
|
||||||
class GradientClippingCallback(TrainCallback):
|
class GradientClippingCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Gradient clipping callback for trainer.
|
Gradient clipping callback for trainer.
|
||||||
|
|
@ -69,6 +90,7 @@ class GradientClippingCallback(TrainCallback):
|
||||||
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
clip_grad_norm_(context.model.parameters(), self.max_grad_norm)
|
||||||
|
|
||||||
|
|
||||||
|
@CallbackFactory.register("scheduler")
|
||||||
class SchedulerCallback(TrainCallback):
|
class SchedulerCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Scheduler callback for trainer.
|
Scheduler callback for trainer.
|
||||||
|
|
@ -87,6 +109,7 @@ class SchedulerCallback(TrainCallback):
|
||||||
context.scheduler.step()
|
context.scheduler.step()
|
||||||
|
|
||||||
|
|
||||||
|
@CallbackFactory.register("checkpoint")
|
||||||
class CheckpointCallback(TrainCallback):
|
class CheckpointCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Checkpoint callback for trainer.
|
Checkpoint callback for trainer.
|
||||||
|
|
@ -135,6 +158,7 @@ class CheckpointCallback(TrainCallback):
|
||||||
self._save_checkpoint(context)
|
self._save_checkpoint(context)
|
||||||
|
|
||||||
|
|
||||||
|
@CallbackFactory.register("progress_bar")
|
||||||
class ProgressBarCallback(TrainCallback):
|
class ProgressBarCallback(TrainCallback):
|
||||||
"""
|
"""
|
||||||
Progress bar callback for trainer.
|
Progress bar callback for trainer.
|
||||||
|
|
@ -169,6 +193,7 @@ class ProgressBarCallback(TrainCallback):
|
||||||
self.progress_bar.close()
|
self.progress_bar.close()
|
||||||
|
|
||||||
|
|
||||||
|
@CallbackFactory.register("metric_logger")
|
||||||
class MetricLoggerCallback(TrainCallback):
|
class MetricLoggerCallback(TrainCallback):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -5,12 +5,8 @@ from astrai.config import TrainConfig
|
||||||
from astrai.data.serialization import Checkpoint
|
from astrai.data.serialization import Checkpoint
|
||||||
from astrai.parallel.setup import spawn_parallel_fn
|
from astrai.parallel.setup import spawn_parallel_fn
|
||||||
from astrai.trainer.train_callback import (
|
from astrai.trainer.train_callback import (
|
||||||
CheckpointCallback,
|
|
||||||
GradientClippingCallback,
|
|
||||||
MetricLoggerCallback,
|
|
||||||
ProgressBarCallback,
|
|
||||||
SchedulerCallback,
|
|
||||||
TrainCallback,
|
TrainCallback,
|
||||||
|
CallbackFactory,
|
||||||
)
|
)
|
||||||
from astrai.trainer.train_context import TrainContext, TrainContextBuilder
|
from astrai.trainer.train_context import TrainContext, TrainContextBuilder
|
||||||
|
|
||||||
|
|
@ -28,13 +24,13 @@ class Trainer:
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_default_callbacks(self) -> List[TrainCallback]:
|
def _get_default_callbacks(self) -> List[TrainCallback]:
|
||||||
train_config = self.train_config
|
cfg = self.train_config
|
||||||
return [
|
return [
|
||||||
ProgressBarCallback(train_config.n_epoch),
|
CallbackFactory.create("progress_bar", cfg.n_epoch),
|
||||||
CheckpointCallback(train_config.ckpt_dir, train_config.ckpt_interval),
|
CallbackFactory.create("checkpoint", cfg.ckpt_dir, cfg.ckpt_interval),
|
||||||
MetricLoggerCallback(train_config.ckpt_dir, train_config.ckpt_interval),
|
CallbackFactory.create("metric_logger", cfg.ckpt_dir, cfg.ckpt_interval),
|
||||||
GradientClippingCallback(train_config.max_grad_norm),
|
CallbackFactory.create("gradient_clipping", cfg.max_grad_norm),
|
||||||
SchedulerCallback(),
|
CallbackFactory.create("scheduler"),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
def _build_context(self, checkpoint: Optional[Checkpoint]) -> TrainContext:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue