# AstrAI/astrai/config/train_config.py (102 lines, 3.6 KiB, Python)

import torch.nn as nn
from torch.utils.data import Dataset
from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from dataclasses import dataclass, field
from typing import Callable, List, Optional
@dataclass
class TrainConfig:
    """Configuration for a single training run.

    Groups the model/data/optimizer factories with loop, checkpoint,
    dataloader and distributed-training settings. The fields ``model``,
    ``strategy``, ``dataset``, ``optimizer_fn`` and ``scheduler_fn`` default
    to ``None`` only so the dataclass can be built with keyword arguments;
    they must be supplied, and :meth:`validate` (run from ``__post_init__``)
    rejects the instance otherwise.
    """

    # basic setting
    model: Optional[nn.Module] = field(
        default=None, metadata={"help": "Model for training."}
    )
    strategy: Optional[str] = field(
        default=None, metadata={"help": "Training strategy."}
    )
    dataset: Optional[Dataset] = field(
        default=None, metadata={"help": "Dataset for training."}
    )
    optimizer_fn: Optional[Callable[[nn.Module], Optimizer]] = field(
        default=None, metadata={"help": "Optimizer factory for training."}
    )
    scheduler_fn: Optional[Callable[[Optimizer], LRScheduler]] = field(
        default=None, metadata={"help": "Scheduler factory for training."}
    )
    n_epoch: int = field(default=1, metadata={"help": "Number of epochs for training."})
    batch_size: int = field(default=4, metadata={"help": "Batch size for training."})
    accumulation_steps: int = field(
        default=1, metadata={"help": "Number of iterations between steps."}
    )
    max_grad_norm: float = field(
        default=1.0, metadata={"help": "Maximum gradient norm."}
    )

    # checkpoint setting
    start_epoch: int = field(default=0, metadata={"help": "Start epoch for training."})
    start_batch: int = field(
        default=0, metadata={"help": "Start batch iteration for training."}
    )
    ckpt_dir: str = field(
        default="./checkpoint", metadata={"help": "Checkpoint directory."}
    )
    ckpt_interval: int = field(
        default=5000, metadata={"help": "Number of iterations between checkpoints."}
    )

    # dataloader setting
    random_seed: int = field(default=3407, metadata={"help": "Random seed."})
    num_workers: int = field(
        default=0, metadata={"help": "Number of workers for dataloader."}
    )
    prefetch_factor: Optional[int] = field(
        default=None, metadata={"help": "Prefetch factor for dataloader."}
    )
    pin_memory: bool = field(
        default=False, metadata={"help": "Pin memory for dataloader."}
    )

    # distributed training
    nprocs: int = field(
        default=1, metadata={"help": "Number of processes for distributed training."}
    )
    backend: str = field(
        default="nccl", metadata={"help": "Distributed training backend."}
    )
    master_addr: str = field(
        default="localhost",
        metadata={"help": "Master address for distributed training."},
    )
    # Kept as a string (default "29500"); presumably it is forwarded as-is,
    # e.g. into an environment variable — NOTE(review): confirm with callers.
    master_port: str = field(
        default="29500", metadata={"help": "Master port for distributed training."}
    )
    parallel_wrapper: Optional[Callable] = field(
        default=None, metadata={"help": "Parallel function for training."}
    )
    state_dict_fn: Optional[Callable] = field(
        default=None, metadata={"help": "Parallel function for state dict saving."}
    )

    # others
    device_ids: Optional[List[int]] = field(
        default=None, metadata={"help": "Device ids for distributed training."}
    )
    device_type: str = field(
        default="cuda", metadata={"help": "Device type for distributed training."}
    )
    # default_factory gives every instance its own dict (no shared mutable default).
    extra_kwargs: dict = field(
        default_factory=dict, metadata={"help": "Other arguments."}
    )

    def __post_init__(self) -> None:
        """Validate the configuration right after dataclass initialisation."""
        self.validate()

    def validate(self) -> None:
        """Check required fields and the ranges of count-like settings.

        Raises:
            ValueError: If a required field is ``None``, or a count/interval
                setting is outside its valid range.
        """
        required_fields = [
            "model",
            "strategy",
            "dataset",
            "optimizer_fn",
            "scheduler_fn",
        ]
        for field_name in required_fields:
            if getattr(self, field_name) is None:
                raise ValueError(f"{field_name} is required.")

        # Sizes, step/checkpoint intervals and process counts have no
        # meaningful zero or negative value, so reject them up front instead
        # of failing later inside the training loop.
        positive_fields = [
            "batch_size",
            "accumulation_steps",
            "ckpt_interval",
            "nprocs",
        ]
        for field_name in positive_fields:
            value = getattr(self, field_name)
            if value < 1:
                raise ValueError(f"{field_name} must be >= 1, got {value}.")

        # Counters and offsets may be zero (e.g. starting from scratch) but
        # never negative.
        non_negative_fields = [
            "n_epoch",
            "start_epoch",
            "start_batch",
            "num_workers",
        ]
        for field_name in non_negative_fields:
            value = getattr(self, field_name)
            if value < 0:
                raise ValueError(f"{field_name} must be >= 0, got {value}.")

        # None means "use the dataloader default"; an explicit value must be
        # a positive number of prefetched batches.
        if self.prefetch_factor is not None and self.prefetch_factor < 1:
            raise ValueError(
                f"prefetch_factor must be >= 1 or None, got {self.prefetch_factor}."
            )