chore: update project name
This commit is contained in:
parent 780b9e1855
commit 2e009cf59a

README.md (14 changes)
@@ -1,7 +1,7 @@
 <div align="center">
-    <img src="assets/images/project_logo.png" width="auto" alt="Logo">
+    <!-- <img src="assets/images/project_logo.png" width="auto" alt="Logo"> -->

-    <h1>KHAOSZ</h1>
+    <h1>AstrAI</h1>

     <div>
         <a href="#english">English</a> •
@@ -48,8 +48,8 @@
 ### Installation

 ```bash
-git clone https://github.com/username/khaosz.git
-cd khaosz
+git clone https://github.com/ViperEkura/AstrAI.git
+cd AstrAI
 pip install -e .
 ```

@@ -95,8 +95,8 @@ python demo/generate_ar.py
 ### 安装

 ```bash
-git clone https://github.com/username/khaosz.git
-cd khaosz
+git clone https://github.com/ViperEkura/AstrAI.git
+cd AstrAI
 pip install -e .
 ```

@@ -143,7 +143,7 @@ python demo/generate_ar.py

 ### Download | 下载

-- [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ)
+- [HuggingFace](https://huggingface.co/ViperEk/AstrAI)
 - `python demo/download.py`

 ### Lincence | 许可证

@@ -1,16 +1,16 @@
-# KHAOSZ Data Flow Documentation
+# AstrAI Data Flow Documentation

-This document describes the data flow of the KHAOSZ project (a training and inference framework for autoregressive Transformer language models). It covers the complete flow from raw data to model training and inference.
+This document describes the data flow of the AstrAI project (a training and inference framework for autoregressive Transformer language models). It covers the complete flow from raw data to model training and inference.

 ## Overview

-KHAOSZ adopts a modular design with the following main components:
-- **Data Module** (`khaosz/data/`): Dataset, sampler, tokenizer, serialization tools
-- **Model Module** (`khaosz/model/`): Transformer model and its submodules
-- **Training Module** (`khaosz/trainer/`): Trainer, training context, strategies, schedulers
-- **Inference Module** (`khaosz/inference/`): Generation core, KV cache management, streaming generation
-- **Config Module** (`khaosz/config/`): Model, training, scheduler, and other configurations
-- **Parallel Module** (`khaosz/parallel/`): Distributed training support
+AstrAI adopts a modular design with the following main components:
+- **Data Module** (`astrai/data/`): Dataset, sampler, tokenizer, serialization tools
+- **Model Module** (`astrai/model/`): Transformer model and its submodules
+- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
+- **Inference Module** (`astrai/inference/`): Generation core, KV cache management, streaming generation
+- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
+- **Parallel Module** (`astrai/parallel/`): Distributed training support

 The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.

@@ -199,7 +199,7 @@ flowchart LR

 ## Summary

-The data flow design of KHAOSZ reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache and sampling strategies. Clear interfaces between modules facilitate customization and extension.
+The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache and sampling strategies. Clear interfaces between modules facilitate customization and extension.

 > Document Update Time: 2026-03-30
 > Corresponding Code Version: Refer to version number defined in `pyproject.toml`
@@ -2,7 +2,7 @@

 There are many large language models on the market today, such as GPT, LLaMA, and others, with tens of billions or even hundreds of billions of parameters. But honestly, these models have extremely high hardware requirements, making them inaccessible for ordinary developers. I thought: **Can we create a model that is both useful and can run on ordinary computers?** This is also what most people currently hope for - a locally deployable AI project that achieves complete privatization while maintaining some level of intelligence.

-Thus, the KHAOSZ project was born - 1B parameters, Chinese-English bilingual, supporting dialogue, text generation, RAG retrieval, and the training code is open source!
+Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, supporting dialogue, text generation, RAG retrieval, and the training code is open source!

 ## 2. System Architecture

@@ -83,8 +83,8 @@
 ### Usage Example

 ```python
-from khaosz.config.param_config import ModelParameter
-from khaosz.inference.generator import StreamGenerator, GenerationRequest
+from astrai.config.param_config import ModelParameter
+from astrai.inference.generator import StreamGenerator, GenerationRequest

 # Load model
 param = ModelParameter.load("your_model_dir")
@@ -1,13 +1,13 @@
 __version__ = "1.3.2"
 __author__ = "ViperEkura"

-from khaosz.config import (
+from astrai.config import (
     ModelConfig,
     TrainConfig,
 )
-from khaosz.model.transformer import Transformer
-from khaosz.data import DatasetLoader, BpeTokenizer
-from khaosz.inference.generator import (
+from astrai.model.transformer import Transformer
+from astrai.data import DatasetLoader, BpeTokenizer
+from astrai.inference.generator import (
     GenerationRequest,
     LoopGenerator,
     StreamGenerator,
@@ -15,7 +15,7 @@ from khaosz.inference.generator import (
     EmbeddingEncoder,
     GeneratorFactory,
 )
-from khaosz.trainer import Trainer, StrategyFactory, SchedulerFactory
+from astrai.trainer import Trainer, StrategyFactory, SchedulerFactory

 __all__ = [
     "Transformer",
@@ -1,12 +1,12 @@
-from khaosz.config.model_config import ModelConfig
-from khaosz.config.param_config import BaseModelIO, ModelParameter
-from khaosz.config.schedule_config import (
+from astrai.config.model_config import ModelConfig
+from astrai.config.param_config import BaseModelIO, ModelParameter
+from astrai.config.schedule_config import (
     ScheduleConfig,
     CosineScheduleConfig,
     SGDRScheduleConfig,
     ScheduleConfigFactory,
 )
-from khaosz.config.train_config import TrainConfig
+from astrai.config.train_config import TrainConfig


 __all__ = [
@@ -5,9 +5,9 @@ from dataclasses import dataclass, field
 from typing import Optional, Self, Union
 from pathlib import Path

-from khaosz.data.tokenizer import BpeTokenizer
-from khaosz.config.model_config import ModelConfig
-from khaosz.model.transformer import Transformer
+from astrai.data.tokenizer import BpeTokenizer
+from astrai.config.model_config import ModelConfig
+from astrai.model.transformer import Transformer


 @dataclass
@@ -1,4 +1,4 @@
-from khaosz.data.dataset import (
+from astrai.data.dataset import (
     BaseDataset,
     SEQDataset,
     DPODataset,
@@ -9,8 +9,8 @@ from khaosz.data.dataset import (
     DatasetFactory,
 )

-from khaosz.data.tokenizer import BpeTokenizer
-from khaosz.data.sampler import ResumableDistributedSampler
+from astrai.data.tokenizer import BpeTokenizer
+from astrai.data.sampler import ResumableDistributedSampler

 __all__ = [
     # Base classes
@@ -6,7 +6,7 @@ import bisect
 from abc import ABC, abstractmethod
 from torch import Tensor
 from torch.utils.data import Dataset
-from khaosz.data.serialization import load_h5
+from astrai.data.serialization import load_h5
 from typing import List, Dict, Optional, Union


@@ -8,7 +8,7 @@ import torch.distributed as dist
 from pathlib import Path
 from torch import Tensor
 from typing import Any, Dict, List
-from khaosz.parallel.setup import get_rank
+from astrai.parallel.setup import get_rank


 def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
@@ -1,11 +1,11 @@
-from khaosz.inference.core import (
+from astrai.inference.core import (
     disable_random_init,
     GeneratorCore,
     EmbeddingEncoderCore,
     KVCacheManager,
 )

-from khaosz.inference.generator import (
+from astrai.inference.generator import (
     GenerationRequest,
     LoopGenerator,
     StreamGenerator,
@@ -4,7 +4,7 @@ import torch.nn as nn
 from torch import Tensor
 from contextlib import contextmanager
 from typing import Any, Callable, List, Tuple, Union, Optional, Self
-from khaosz.config import ModelParameter, ModelConfig
+from astrai.config import ModelParameter, ModelConfig


 def apply_sampling_strategies(
@@ -2,8 +2,8 @@ import torch
 from dataclasses import dataclass
 from torch import Tensor
 from typing import List, Tuple, Union, Optional, Generator
-from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
-from khaosz.config.param_config import ModelParameter
+from astrai.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
+from astrai.config.param_config import ModelParameter


 HistoryType = List[Tuple[str, str]]
@@ -1,10 +1,10 @@
-from khaosz.model.module import (
+from astrai.model.module import (
     Linear,
     RMSNorm,
     MLP,
     GQA,
     DecoderBlock,
 )
-from khaosz.model.transformer import Transformer
+from astrai.model.transformer import Transformer

 __all__ = ["Linear", "RMSNorm", "MLP", "GQA", "DecoderBlock", "Transformer"]
@@ -3,8 +3,8 @@ import torch.nn as nn

 from torch import Tensor
 from typing import Any, Mapping, Optional, Tuple
-from khaosz.config.model_config import ModelConfig
-from khaosz.model.module import (
+from astrai.config.model_config import ModelConfig
+from astrai.model.module import (
     Embedding,
     DecoderBlock,
     Linear,
@@ -1,4 +1,4 @@
-from khaosz.parallel.setup import (
+from astrai.parallel.setup import (
     get_world_size,
     get_rank,
     get_current_device,
@@ -7,7 +7,7 @@ from khaosz.parallel.setup import (
     spawn_parallel_fn,
 )

-from khaosz.parallel.module import RowParallelLinear, ColumnParallelLinear
+from astrai.parallel.module import RowParallelLinear, ColumnParallelLinear

 __all__ = [
     "get_world_size",
@@ -1,8 +1,8 @@
-from khaosz.trainer.trainer import Trainer
-from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
-from khaosz.trainer.schedule import SchedulerFactory, BaseScheduler
+from astrai.trainer.trainer import Trainer
+from astrai.trainer.strategy import StrategyFactory, BaseStrategy
+from astrai.trainer.schedule import SchedulerFactory, BaseScheduler

-from khaosz.trainer.train_callback import (
+from astrai.trainer.train_callback import (
     TrainCallback,
     GradientClippingCallback,
     SchedulerCallback,
@@ -4,7 +4,7 @@ import math
 from abc import abstractmethod, ABC
 from typing import Any, Dict, List, Type
 from torch.optim.lr_scheduler import LRScheduler
-from khaosz.config.schedule_config import ScheduleConfig
+from astrai.config.schedule_config import ScheduleConfig


 class BaseScheduler(LRScheduler, ABC):
@@ -8,8 +8,8 @@ from tqdm import tqdm
 from torch.nn.utils import clip_grad_norm_
 from typing import Callable, List, Optional, Protocol

-from khaosz.parallel import only_on_rank
-from khaosz.trainer.metric_util import (
+from astrai.parallel import only_on_rank
+from astrai.trainer.metric_util import (
     ctx_get_loss,
     ctx_get_lr,
     ctx_get_grad_max,
@@ -19,8 +19,8 @@ from khaosz.trainer.metric_util import (
     ctx_get_grad_std,
     ctx_get_grad_nan_num,
 )
-from khaosz.data.serialization import Checkpoint
-from khaosz.trainer.train_context import TrainContext
+from astrai.data.serialization import Checkpoint
+from astrai.trainer.train_context import TrainContext


 class TrainCallback(Protocol):
@@ -3,11 +3,11 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
 from torch.utils.data import DataLoader

-from khaosz.data import ResumableDistributedSampler
-from khaosz.data.serialization import Checkpoint
-from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
-from khaosz.config.train_config import TrainConfig
-from khaosz.parallel.setup import get_current_device, get_world_size, get_rank
+from astrai.data import ResumableDistributedSampler
+from astrai.data.serialization import Checkpoint
+from astrai.trainer.strategy import StrategyFactory, BaseStrategy
+from astrai.config.train_config import TrainConfig
+from astrai.parallel.setup import get_current_device, get_world_size, get_rank

 from dataclasses import dataclass, field
 from typing import Optional, Self
@@ -1,7 +1,7 @@
 import logging
 from typing import Optional, List
-from khaosz.config import TrainConfig
-from khaosz.trainer.train_callback import (
+from astrai.config import TrainConfig
+from astrai.trainer.train_callback import (
     TrainCallback,
     ProgressBarCallback,
     CheckpointCallback,
@@ -9,9 +9,9 @@ from khaosz.trainer.train_callback import (
     GradientClippingCallback,
     SchedulerCallback,
 )
-from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
-from khaosz.data.serialization import Checkpoint
-from khaosz.parallel.setup import spawn_parallel_fn
+from astrai.trainer.train_context import TrainContext, TrainContextBuilder
+from astrai.data.serialization import Checkpoint
+from astrai.parallel.setup import spawn_parallel_fn

 logger = logging.getLogger(__name__)

@@ -6,7 +6,7 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")

 if __name__ == "__main__":
     snapshot_download(
-        repo_id="ViperEk/KHAOSZ",
+        repo_id="ViperEk/AstrAI",
         local_dir=PARAMETER_ROOT,
         force_download=True,
     )
@@ -1,8 +1,8 @@
 import torch
 from pathlib import Path
-from khaosz.config.param_config import ModelParameter
-from khaosz.inference.core import disable_random_init
-from khaosz.inference.generator import GeneratorFactory, GenerationRequest
+from astrai.config.param_config import ModelParameter
+from astrai.inference.core import disable_random_init
+from astrai.inference.generator import GeneratorFactory, GenerationRequest

 PROJECT_ROOT = Path(__file__).parent.parent
 PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
@@ -1,8 +1,8 @@
 import torch
 from pathlib import Path
-from khaosz.config.param_config import ModelParameter
-from khaosz.inference.core import disable_random_init
-from khaosz.inference.generator import GeneratorFactory, GenerationRequest
+from astrai.config.param_config import ModelParameter
+from astrai.inference.core import disable_random_init
+from astrai.inference.generator import GeneratorFactory, GenerationRequest

 PROJECT_ROOT = Path(__file__).parent.parent
 PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
@@ -1,8 +1,8 @@
 import torch
 from pathlib import Path
-from khaosz.config.param_config import ModelParameter
-from khaosz.inference.core import disable_random_init
-from khaosz.inference.generator import GeneratorFactory, GenerationRequest
+from astrai.config.param_config import ModelParameter
+from astrai.inference.core import disable_random_init
+from astrai.inference.generator import GeneratorFactory, GenerationRequest

 PROJECT_ROOT = Path(__file__).parent.parent
 PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 dynamic = ["version"]
-name = "khaosz"
+name = "astrai"
 readme = "README.md"
 requires-python = ">=3.12"
 dependencies = [
@@ -23,7 +23,7 @@ classifiers = [
     "License :: OSI Approved :: GPL-3.0",
     "Operating System :: OS Independent",
 ]
-urls = { Homepage = "https://github.com/ViperEkura/KHAOSZ" }
+urls = { Homepage = "https://github.com/ViperEkura/AstrAI" }

 [project.optional-dependencies]
 dev = ["pytest==9.0.2", "ruff"]
@@ -35,7 +35,7 @@ where = ["."]
 extra-index-url = "https://download.pytorch.org/whl/cu126"

 [tool.setuptools.dynamic]
-version = { attr = "khaosz.__version__" }
+version = { attr = "astrai.__version__" }

 [tool.ruff]
 target-version = "py312"
@@ -7,9 +7,9 @@ import torch
 import pytest

 from torch.utils.data import Dataset
-from khaosz.config.model_config import ModelConfig
-from khaosz.data.tokenizer import BpeTokenizer
-from khaosz.model.transformer import Transformer
+from astrai.config.model_config import ModelConfig
+from astrai.data.tokenizer import BpeTokenizer
+from astrai.model.transformer import Transformer


 class RandomDataset(Dataset):
@@ -4,8 +4,8 @@ import torch.distributed as dist

 from torch.optim import AdamW
 from torch.optim.lr_scheduler import CosineAnnealingLR
-from khaosz.data.serialization import Checkpoint
-from khaosz.parallel.setup import get_rank, spawn_parallel_fn
+from astrai.data.serialization import Checkpoint
+from astrai.parallel.setup import get_rank, spawn_parallel_fn


 def test_single_process():
@@ -1,8 +1,8 @@
 import torch
 import numpy as np

-from khaosz.data.serialization import save_h5
-from khaosz.data.dataset import *
+from astrai.data.serialization import save_h5
+from astrai.data.dataset import *


 def test_dataset_loader_random_paths(base_test_env):
@@ -1,5 +1,5 @@
-from khaosz.trainer import *
-from khaosz.data import *
+from astrai.trainer import *
+from astrai.data import *


 def test_random_sampler_consistency(random_dataset):
@@ -5,11 +5,11 @@ import shutil
 import pytest
 import tempfile
 import safetensors.torch as st
-from khaosz.trainer import *
-from khaosz.config import *
-from khaosz.model import *
-from khaosz.data import *
-from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore
+from astrai.trainer import *
+from astrai.config import *
+from astrai.model import *
+from astrai.data import *
+from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore
 from tokenizers import pre_tokenizers


@@ -4,8 +4,8 @@ import torch
 import pytest
 import tempfile
 import safetensors.torch as st
-from khaosz.model.transformer import Transformer
-from khaosz.config.model_config import ModelConfig
+from astrai.model.transformer import Transformer
+from astrai.config.model_config import ModelConfig


 @pytest.fixture
@@ -1,7 +1,7 @@
 import torch
 import torch.distributed as dist

-from khaosz.parallel import get_rank, only_on_rank, spawn_parallel_fn
+from astrai.parallel import get_rank, only_on_rank, spawn_parallel_fn


 @only_on_rank(0)
@@ -1,7 +1,7 @@
 import torch

-from khaosz.config import *
-from khaosz.trainer import *
+from astrai.config import *
+from astrai.trainer import *


 def test_callback_integration(base_test_env, random_dataset):
@@ -1,9 +1,9 @@
 import os
 import torch
 import numpy as np
-from khaosz.config import *
-from khaosz.trainer import *
-from khaosz.data.serialization import Checkpoint
+from astrai.config import *
+from astrai.trainer import *
+from astrai.data.serialization import Checkpoint


 def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
@@ -2,9 +2,9 @@ import torch
 import numpy as np
 import pytest

-from khaosz.config import *
-from khaosz.trainer.schedule import *
-from khaosz.data.dataset import *
+from astrai.config import *
+from astrai.trainer.schedule import *
+from astrai.data.dataset import *


 def test_schedule_factory_random_configs():
@@ -2,9 +2,9 @@ import torch
 import numpy as np


-from khaosz.config import *
-from khaosz.trainer import *
-from khaosz.data.dataset import *
+from astrai.config import *
+from astrai.trainer import *
+from astrai.data.dataset import *


 def test_different_batch_sizes(base_test_env, random_dataset):
@@ -1,7 +1,7 @@
 import torch
 from typing import Dict, Any
 from dataclasses import dataclass
-from khaosz.model.transformer import ModelConfig, Transformer
+from astrai.model.transformer import ModelConfig, Transformer


 @dataclass
@@ -2,9 +2,9 @@ import torch
 import json
 import argparse

-from khaosz.config.param_config import ModelParameter
-from khaosz.inference.generator import BatchGenerator, GenerationRequest
-from khaosz.inference.core import disable_random_init
+from astrai.config.param_config import ModelParameter
+from astrai.inference.generator import BatchGenerator, GenerationRequest
+from astrai.inference.core import disable_random_init


 def processor(
@@ -6,8 +6,8 @@ import argparse
 import tqdm

 from torch import Tensor
-from khaosz.config.param_config import ModelParameter
-from khaosz.inference.core import disable_random_init
+from astrai.config.param_config import ModelParameter
+from astrai.inference.core import disable_random_init


 def compute_perplexity(
@@ -6,10 +6,10 @@ import torch.optim as optim
 from torch.nn.parallel import DistributedDataParallel as DDP

 from functools import partial
-from khaosz.data import DatasetLoader
-from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
-from khaosz.trainer import Trainer, SchedulerFactory
-from khaosz.parallel import get_rank
+from astrai.data import DatasetLoader
+from astrai.config import ModelParameter, TrainConfig, CosineScheduleConfig
+from astrai.trainer import Trainer, SchedulerFactory
+from astrai.parallel import get_rank


 def parse_args() -> argparse.Namespace:
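Nearly every hunk in this commit is the same mechanical substitution: the package name `khaosz` becomes `astrai` in import paths, config values, and URLs. As a hedged illustration only (this helper is hypothetical and not part of the commit), such a rename can be scripted with a whole-word regex pass over the tree:

```python
import re
from pathlib import Path

# Hypothetical sketch of the mechanical rename this commit performs:
# rewrite whole-word `khaosz` references to `astrai` in Python sources.
OLD, NEW = "khaosz", "astrai"

def rename_references(text: str) -> str:
    # \b word boundaries keep a token like "khaoszlike" untouched,
    # while "from khaosz.data import ..." is rewritten.
    return re.sub(rf"\b{OLD}\b", NEW, text)

def rename_tree(root: Path) -> int:
    """Rewrite every .py file under `root`; return how many files changed."""
    changed = 0
    for path in root.rglob("*.py"):
        src = path.read_text(encoding="utf-8")
        dst = rename_references(src)
        if dst != src:
            path.write_text(dst, encoding="utf-8")
            changed += 1
    return changed
```

A pass like this still leaves manual follow-ups, visible in the hunks above: the `name` and `version` keys in `pyproject.toml`, the HuggingFace repo id, and the README logo/title edits.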