chore: 更新项目名称
This commit is contained in:
parent
780b9e1855
commit
2e009cf59a
14
README.md
14
README.md
|
|
@ -1,7 +1,7 @@
|
||||||
<div align="center">
|
<div align="center">
|
||||||
<img src="assets/images/project_logo.png" width="auto" alt="Logo">
|
<!-- <img src="assets/images/project_logo.png" width="auto" alt="Logo"> -->
|
||||||
|
|
||||||
<h1>KHAOSZ</h1>
|
<h1>AstrAI</h1>
|
||||||
|
|
||||||
<div>
|
<div>
|
||||||
<a href="#english">English</a> •
|
<a href="#english">English</a> •
|
||||||
|
|
@ -48,8 +48,8 @@
|
||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/username/khaosz.git
|
git clone https://github.com/ViperEkura/AstrAI.git
|
||||||
cd khaosz
|
cd AstrAI
|
||||||
pip install -e .
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -95,8 +95,8 @@ python demo/generate_ar.py
|
||||||
### 安装
|
### 安装
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
git clone https://github.com/username/khaosz.git
|
git clone https://github.com/ViperEkura/AstrAI.git
|
||||||
cd khaosz
|
cd AstrAI
|
||||||
pip install -e .
|
pip install -e .
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
@ -143,7 +143,7 @@ python demo/generate_ar.py
|
||||||
|
|
||||||
### Download | 下载
|
### Download | 下载
|
||||||
|
|
||||||
- [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ)
|
- [HuggingFace](https://huggingface.co/ViperEk/AstrAI)
|
||||||
- `python demo/download.py`
|
- `python demo/download.py`
|
||||||
|
|
||||||
### Lincence | 许可证
|
### Lincence | 许可证
|
||||||
|
|
|
||||||
|
|
@ -1,16 +1,16 @@
|
||||||
# KHAOSZ Data Flow Documentation
|
# AstrAI Data Flow Documentation
|
||||||
|
|
||||||
This document describes the data flow of the KHAOSZ project (a training and inference framework for autoregressive Transformer language models). It covers the complete flow from raw data to model training and inference.
|
This document describes the data flow of the AstrAI project (a training and inference framework for autoregressive Transformer language models). It covers the complete flow from raw data to model training and inference.
|
||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
KHAOSZ adopts a modular design with the following main components:
|
AstrAI adopts a modular design with the following main components:
|
||||||
- **Data Module** (`khaosz/data/`): Dataset, sampler, tokenizer, serialization tools
|
- **Data Module** (`astrai/data/`): Dataset, sampler, tokenizer, serialization tools
|
||||||
- **Model Module** (`khaosz/model/`): Transformer model and its submodules
|
- **Model Module** (`astrai/model/`): Transformer model and its submodules
|
||||||
- **Training Module** (`khaosz/trainer/`): Trainer, training context, strategies, schedulers
|
- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
|
||||||
- **Inference Module** (`khaosz/inference/`): Generation core, KV cache management, streaming generation
|
- **Inference Module** (`astrai/inference/`): Generation core, KV cache management, streaming generation
|
||||||
- **Config Module** (`khaosz/config/`): Model, training, scheduler, and other configurations
|
- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
|
||||||
- **Parallel Module** (`khaosz/parallel/`): Distributed training support
|
- **Parallel Module** (`astrai/parallel/`): Distributed training support
|
||||||
|
|
||||||
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
|
The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
|
||||||
|
|
||||||
|
|
@ -199,7 +199,7 @@ flowchart LR
|
||||||
|
|
||||||
## Summary
|
## Summary
|
||||||
|
|
||||||
The data flow design of KHAOSZ reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache and sampling strategies. Clear interfaces between modules facilitate customization and extension.
|
The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache and sampling strategies. Clear interfaces between modules facilitate customization and extension.
|
||||||
|
|
||||||
> Document Update Time: 2026-03-30
|
> Document Update Time: 2026-03-30
|
||||||
> Corresponding Code Version: Refer to version number defined in `pyproject.toml`
|
> Corresponding Code Version: Refer to version number defined in `pyproject.toml`
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
|
|
||||||
There are many large language models on the market today, such as GPT, LLaMA, and others, with tens of billions or even hundreds of billions of parameters. But honestly, these models have extremely high hardware requirements, making them inaccessible for ordinary developers. I thought: **Can we create a model that is both useful and can run on ordinary computers?** This is also what most people currently hope for - a locally deployable AI project that achieves complete privatization while maintaining some level of intelligence.
|
There are many large language models on the market today, such as GPT, LLaMA, and others, with tens of billions or even hundreds of billions of parameters. But honestly, these models have extremely high hardware requirements, making them inaccessible for ordinary developers. I thought: **Can we create a model that is both useful and can run on ordinary computers?** This is also what most people currently hope for - a locally deployable AI project that achieves complete privatization while maintaining some level of intelligence.
|
||||||
|
|
||||||
Thus, the KHAOSZ project was born - 1B parameters, Chinese-English bilingual, supporting dialogue, text generation, RAG retrieval, and the training code is open source!
|
Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, supporting dialogue, text generation, RAG retrieval, and the training code is open source!
|
||||||
|
|
||||||
## 2. System Architecture
|
## 2. System Architecture
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -83,8 +83,8 @@
|
||||||
### Usage Example
|
### Usage Example
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from khaosz.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
from khaosz.inference.generator import StreamGenerator, GenerationRequest
|
from astrai.inference.generator import StreamGenerator, GenerationRequest
|
||||||
|
|
||||||
# Load model
|
# Load model
|
||||||
param = ModelParameter.load("your_model_dir")
|
param = ModelParameter.load("your_model_dir")
|
||||||
|
|
|
||||||
|
|
@ -1,13 +1,13 @@
|
||||||
__version__ = "1.3.2"
|
__version__ = "1.3.2"
|
||||||
__author__ = "ViperEkura"
|
__author__ = "ViperEkura"
|
||||||
|
|
||||||
from khaosz.config import (
|
from astrai.config import (
|
||||||
ModelConfig,
|
ModelConfig,
|
||||||
TrainConfig,
|
TrainConfig,
|
||||||
)
|
)
|
||||||
from khaosz.model.transformer import Transformer
|
from astrai.model.transformer import Transformer
|
||||||
from khaosz.data import DatasetLoader, BpeTokenizer
|
from astrai.data import DatasetLoader, BpeTokenizer
|
||||||
from khaosz.inference.generator import (
|
from astrai.inference.generator import (
|
||||||
GenerationRequest,
|
GenerationRequest,
|
||||||
LoopGenerator,
|
LoopGenerator,
|
||||||
StreamGenerator,
|
StreamGenerator,
|
||||||
|
|
@ -15,7 +15,7 @@ from khaosz.inference.generator import (
|
||||||
EmbeddingEncoder,
|
EmbeddingEncoder,
|
||||||
GeneratorFactory,
|
GeneratorFactory,
|
||||||
)
|
)
|
||||||
from khaosz.trainer import Trainer, StrategyFactory, SchedulerFactory
|
from astrai.trainer import Trainer, StrategyFactory, SchedulerFactory
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"Transformer",
|
"Transformer",
|
||||||
|
|
@ -1,12 +1,12 @@
|
||||||
from khaosz.config.model_config import ModelConfig
|
from astrai.config.model_config import ModelConfig
|
||||||
from khaosz.config.param_config import BaseModelIO, ModelParameter
|
from astrai.config.param_config import BaseModelIO, ModelParameter
|
||||||
from khaosz.config.schedule_config import (
|
from astrai.config.schedule_config import (
|
||||||
ScheduleConfig,
|
ScheduleConfig,
|
||||||
CosineScheduleConfig,
|
CosineScheduleConfig,
|
||||||
SGDRScheduleConfig,
|
SGDRScheduleConfig,
|
||||||
ScheduleConfigFactory,
|
ScheduleConfigFactory,
|
||||||
)
|
)
|
||||||
from khaosz.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
|
@ -5,9 +5,9 @@ from dataclasses import dataclass, field
|
||||||
from typing import Optional, Self, Union
|
from typing import Optional, Self, Union
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from khaosz.data.tokenizer import BpeTokenizer
|
from astrai.data.tokenizer import BpeTokenizer
|
||||||
from khaosz.config.model_config import ModelConfig
|
from astrai.config.model_config import ModelConfig
|
||||||
from khaosz.model.transformer import Transformer
|
from astrai.model.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from khaosz.data.dataset import (
|
from astrai.data.dataset import (
|
||||||
BaseDataset,
|
BaseDataset,
|
||||||
SEQDataset,
|
SEQDataset,
|
||||||
DPODataset,
|
DPODataset,
|
||||||
|
|
@ -9,8 +9,8 @@ from khaosz.data.dataset import (
|
||||||
DatasetFactory,
|
DatasetFactory,
|
||||||
)
|
)
|
||||||
|
|
||||||
from khaosz.data.tokenizer import BpeTokenizer
|
from astrai.data.tokenizer import BpeTokenizer
|
||||||
from khaosz.data.sampler import ResumableDistributedSampler
|
from astrai.data.sampler import ResumableDistributedSampler
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Base classes
|
# Base classes
|
||||||
|
|
@ -6,7 +6,7 @@ import bisect
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from khaosz.data.serialization import load_h5
|
from astrai.data.serialization import load_h5
|
||||||
from typing import List, Dict, Optional, Union
|
from typing import List, Dict, Optional, Union
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -8,7 +8,7 @@ import torch.distributed as dist
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import Any, Dict, List
|
from typing import Any, Dict, List
|
||||||
from khaosz.parallel.setup import get_rank
|
from astrai.parallel.setup import get_rank
|
||||||
|
|
||||||
|
|
||||||
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):
|
||||||
|
|
@ -1,11 +1,11 @@
|
||||||
from khaosz.inference.core import (
|
from astrai.inference.core import (
|
||||||
disable_random_init,
|
disable_random_init,
|
||||||
GeneratorCore,
|
GeneratorCore,
|
||||||
EmbeddingEncoderCore,
|
EmbeddingEncoderCore,
|
||||||
KVCacheManager,
|
KVCacheManager,
|
||||||
)
|
)
|
||||||
|
|
||||||
from khaosz.inference.generator import (
|
from astrai.inference.generator import (
|
||||||
GenerationRequest,
|
GenerationRequest,
|
||||||
LoopGenerator,
|
LoopGenerator,
|
||||||
StreamGenerator,
|
StreamGenerator,
|
||||||
|
|
@ -4,7 +4,7 @@ import torch.nn as nn
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Callable, List, Tuple, Union, Optional, Self
|
from typing import Any, Callable, List, Tuple, Union, Optional, Self
|
||||||
from khaosz.config import ModelParameter, ModelConfig
|
from astrai.config import ModelParameter, ModelConfig
|
||||||
|
|
||||||
|
|
||||||
def apply_sampling_strategies(
|
def apply_sampling_strategies(
|
||||||
|
|
@ -2,8 +2,8 @@ import torch
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import List, Tuple, Union, Optional, Generator
|
from typing import List, Tuple, Union, Optional, Generator
|
||||||
from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
|
from astrai.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
|
||||||
from khaosz.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
|
|
||||||
|
|
||||||
HistoryType = List[Tuple[str, str]]
|
HistoryType = List[Tuple[str, str]]
|
||||||
|
|
@ -1,10 +1,10 @@
|
||||||
from khaosz.model.module import (
|
from astrai.model.module import (
|
||||||
Linear,
|
Linear,
|
||||||
RMSNorm,
|
RMSNorm,
|
||||||
MLP,
|
MLP,
|
||||||
GQA,
|
GQA,
|
||||||
DecoderBlock,
|
DecoderBlock,
|
||||||
)
|
)
|
||||||
from khaosz.model.transformer import Transformer
|
from astrai.model.transformer import Transformer
|
||||||
|
|
||||||
__all__ = ["Linear", "RMSNorm", "MLP", "GQA", "DecoderBlock", "Transformer"]
|
__all__ = ["Linear", "RMSNorm", "MLP", "GQA", "DecoderBlock", "Transformer"]
|
||||||
|
|
@ -3,8 +3,8 @@ import torch.nn as nn
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from typing import Any, Mapping, Optional, Tuple
|
from typing import Any, Mapping, Optional, Tuple
|
||||||
from khaosz.config.model_config import ModelConfig
|
from astrai.config.model_config import ModelConfig
|
||||||
from khaosz.model.module import (
|
from astrai.model.module import (
|
||||||
Embedding,
|
Embedding,
|
||||||
DecoderBlock,
|
DecoderBlock,
|
||||||
Linear,
|
Linear,
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
from khaosz.parallel.setup import (
|
from astrai.parallel.setup import (
|
||||||
get_world_size,
|
get_world_size,
|
||||||
get_rank,
|
get_rank,
|
||||||
get_current_device,
|
get_current_device,
|
||||||
|
|
@ -7,7 +7,7 @@ from khaosz.parallel.setup import (
|
||||||
spawn_parallel_fn,
|
spawn_parallel_fn,
|
||||||
)
|
)
|
||||||
|
|
||||||
from khaosz.parallel.module import RowParallelLinear, ColumnParallelLinear
|
from astrai.parallel.module import RowParallelLinear, ColumnParallelLinear
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"get_world_size",
|
"get_world_size",
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
from khaosz.trainer.trainer import Trainer
|
from astrai.trainer.trainer import Trainer
|
||||||
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
|
from astrai.trainer.strategy import StrategyFactory, BaseStrategy
|
||||||
from khaosz.trainer.schedule import SchedulerFactory, BaseScheduler
|
from astrai.trainer.schedule import SchedulerFactory, BaseScheduler
|
||||||
|
|
||||||
from khaosz.trainer.train_callback import (
|
from astrai.trainer.train_callback import (
|
||||||
TrainCallback,
|
TrainCallback,
|
||||||
GradientClippingCallback,
|
GradientClippingCallback,
|
||||||
SchedulerCallback,
|
SchedulerCallback,
|
||||||
|
|
@ -4,7 +4,7 @@ import math
|
||||||
from abc import abstractmethod, ABC
|
from abc import abstractmethod, ABC
|
||||||
from typing import Any, Dict, List, Type
|
from typing import Any, Dict, List, Type
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
from khaosz.config.schedule_config import ScheduleConfig
|
from astrai.config.schedule_config import ScheduleConfig
|
||||||
|
|
||||||
|
|
||||||
class BaseScheduler(LRScheduler, ABC):
|
class BaseScheduler(LRScheduler, ABC):
|
||||||
|
|
@ -8,8 +8,8 @@ from tqdm import tqdm
|
||||||
from torch.nn.utils import clip_grad_norm_
|
from torch.nn.utils import clip_grad_norm_
|
||||||
from typing import Callable, List, Optional, Protocol
|
from typing import Callable, List, Optional, Protocol
|
||||||
|
|
||||||
from khaosz.parallel import only_on_rank
|
from astrai.parallel import only_on_rank
|
||||||
from khaosz.trainer.metric_util import (
|
from astrai.trainer.metric_util import (
|
||||||
ctx_get_loss,
|
ctx_get_loss,
|
||||||
ctx_get_lr,
|
ctx_get_lr,
|
||||||
ctx_get_grad_max,
|
ctx_get_grad_max,
|
||||||
|
|
@ -19,8 +19,8 @@ from khaosz.trainer.metric_util import (
|
||||||
ctx_get_grad_std,
|
ctx_get_grad_std,
|
||||||
ctx_get_grad_nan_num,
|
ctx_get_grad_nan_num,
|
||||||
)
|
)
|
||||||
from khaosz.data.serialization import Checkpoint
|
from astrai.data.serialization import Checkpoint
|
||||||
from khaosz.trainer.train_context import TrainContext
|
from astrai.trainer.train_context import TrainContext
|
||||||
|
|
||||||
|
|
||||||
class TrainCallback(Protocol):
|
class TrainCallback(Protocol):
|
||||||
|
|
@ -3,11 +3,11 @@ from torch.optim import Optimizer
|
||||||
from torch.optim.lr_scheduler import LRScheduler
|
from torch.optim.lr_scheduler import LRScheduler
|
||||||
from torch.utils.data import DataLoader
|
from torch.utils.data import DataLoader
|
||||||
|
|
||||||
from khaosz.data import ResumableDistributedSampler
|
from astrai.data import ResumableDistributedSampler
|
||||||
from khaosz.data.serialization import Checkpoint
|
from astrai.data.serialization import Checkpoint
|
||||||
from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
|
from astrai.trainer.strategy import StrategyFactory, BaseStrategy
|
||||||
from khaosz.config.train_config import TrainConfig
|
from astrai.config.train_config import TrainConfig
|
||||||
from khaosz.parallel.setup import get_current_device, get_world_size, get_rank
|
from astrai.parallel.setup import get_current_device, get_world_size, get_rank
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Optional, Self
|
from typing import Optional, Self
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
from khaosz.config import TrainConfig
|
from astrai.config import TrainConfig
|
||||||
from khaosz.trainer.train_callback import (
|
from astrai.trainer.train_callback import (
|
||||||
TrainCallback,
|
TrainCallback,
|
||||||
ProgressBarCallback,
|
ProgressBarCallback,
|
||||||
CheckpointCallback,
|
CheckpointCallback,
|
||||||
|
|
@ -9,9 +9,9 @@ from khaosz.trainer.train_callback import (
|
||||||
GradientClippingCallback,
|
GradientClippingCallback,
|
||||||
SchedulerCallback,
|
SchedulerCallback,
|
||||||
)
|
)
|
||||||
from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
|
from astrai.trainer.train_context import TrainContext, TrainContextBuilder
|
||||||
from khaosz.data.serialization import Checkpoint
|
from astrai.data.serialization import Checkpoint
|
||||||
from khaosz.parallel.setup import spawn_parallel_fn
|
from astrai.parallel.setup import spawn_parallel_fn
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -6,7 +6,7 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id="ViperEk/KHAOSZ",
|
repo_id="ViperEk/AstrAI",
|
||||||
local_dir=PARAMETER_ROOT,
|
local_dir=PARAMETER_ROOT,
|
||||||
force_download=True,
|
force_download=True,
|
||||||
)
|
)
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from khaosz.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
from khaosz.inference.core import disable_random_init
|
from astrai.inference.core import disable_random_init
|
||||||
from khaosz.inference.generator import GeneratorFactory, GenerationRequest
|
from astrai.inference.generator import GeneratorFactory, GenerationRequest
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).parent.parent
|
PROJECT_ROOT = Path(__file__).parent.parent
|
||||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from khaosz.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
from khaosz.inference.core import disable_random_init
|
from astrai.inference.core import disable_random_init
|
||||||
from khaosz.inference.generator import GeneratorFactory, GenerationRequest
|
from astrai.inference.generator import GeneratorFactory, GenerationRequest
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).parent.parent
|
PROJECT_ROOT = Path(__file__).parent.parent
|
||||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from khaosz.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
from khaosz.inference.core import disable_random_init
|
from astrai.inference.core import disable_random_init
|
||||||
from khaosz.inference.generator import GeneratorFactory, GenerationRequest
|
from astrai.inference.generator import GeneratorFactory, GenerationRequest
|
||||||
|
|
||||||
PROJECT_ROOT = Path(__file__).parent.parent
|
PROJECT_ROOT = Path(__file__).parent.parent
|
||||||
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
|
||||||
|
|
|
||||||
|
|
@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
||||||
|
|
||||||
[project]
|
[project]
|
||||||
dynamic = ["version"]
|
dynamic = ["version"]
|
||||||
name = "khaosz"
|
name = "astrai"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.12"
|
requires-python = ">=3.12"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
|
@ -23,7 +23,7 @@ classifiers = [
|
||||||
"License :: OSI Approved :: GPL-3.0",
|
"License :: OSI Approved :: GPL-3.0",
|
||||||
"Operating System :: OS Independent",
|
"Operating System :: OS Independent",
|
||||||
]
|
]
|
||||||
urls = { Homepage = "https://github.com/ViperEkura/KHAOSZ" }
|
urls = { Homepage = "https://github.com/ViperEkura/AstrAI" }
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
dev = ["pytest==9.0.2", "ruff"]
|
dev = ["pytest==9.0.2", "ruff"]
|
||||||
|
|
@ -35,7 +35,7 @@ where = ["."]
|
||||||
extra-index-url = "https://download.pytorch.org/whl/cu126"
|
extra-index-url = "https://download.pytorch.org/whl/cu126"
|
||||||
|
|
||||||
[tool.setuptools.dynamic]
|
[tool.setuptools.dynamic]
|
||||||
version = { attr = "khaosz.__version__" }
|
version = { attr = "astrai.__version__" }
|
||||||
|
|
||||||
[tool.ruff]
|
[tool.ruff]
|
||||||
target-version = "py312"
|
target-version = "py312"
|
||||||
|
|
|
||||||
|
|
@ -7,9 +7,9 @@ import torch
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from torch.utils.data import Dataset
|
from torch.utils.data import Dataset
|
||||||
from khaosz.config.model_config import ModelConfig
|
from astrai.config.model_config import ModelConfig
|
||||||
from khaosz.data.tokenizer import BpeTokenizer
|
from astrai.data.tokenizer import BpeTokenizer
|
||||||
from khaosz.model.transformer import Transformer
|
from astrai.model.transformer import Transformer
|
||||||
|
|
||||||
|
|
||||||
class RandomDataset(Dataset):
|
class RandomDataset(Dataset):
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,8 @@ import torch.distributed as dist
|
||||||
|
|
||||||
from torch.optim import AdamW
|
from torch.optim import AdamW
|
||||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||||
from khaosz.data.serialization import Checkpoint
|
from astrai.data.serialization import Checkpoint
|
||||||
from khaosz.parallel.setup import get_rank, spawn_parallel_fn
|
from astrai.parallel.setup import get_rank, spawn_parallel_fn
|
||||||
|
|
||||||
|
|
||||||
def test_single_process():
|
def test_single_process():
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,8 @@
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from khaosz.data.serialization import save_h5
|
from astrai.data.serialization import save_h5
|
||||||
from khaosz.data.dataset import *
|
from astrai.data.dataset import *
|
||||||
|
|
||||||
|
|
||||||
def test_dataset_loader_random_paths(base_test_env):
|
def test_dataset_loader_random_paths(base_test_env):
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
from khaosz.trainer import *
|
from astrai.trainer import *
|
||||||
from khaosz.data import *
|
from astrai.data import *
|
||||||
|
|
||||||
|
|
||||||
def test_random_sampler_consistency(random_dataset):
|
def test_random_sampler_consistency(random_dataset):
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,11 @@ import shutil
|
||||||
import pytest
|
import pytest
|
||||||
import tempfile
|
import tempfile
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
from khaosz.trainer import *
|
from astrai.trainer import *
|
||||||
from khaosz.config import *
|
from astrai.config import *
|
||||||
from khaosz.model import *
|
from astrai.model import *
|
||||||
from khaosz.data import *
|
from astrai.data import *
|
||||||
from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore
|
from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore
|
||||||
from tokenizers import pre_tokenizers
|
from tokenizers import pre_tokenizers
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -4,8 +4,8 @@ import torch
|
||||||
import pytest
|
import pytest
|
||||||
import tempfile
|
import tempfile
|
||||||
import safetensors.torch as st
|
import safetensors.torch as st
|
||||||
from khaosz.model.transformer import Transformer
|
from astrai.model.transformer import Transformer
|
||||||
from khaosz.config.model_config import ModelConfig
|
from astrai.config.model_config import ModelConfig
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
import torch.distributed as dist
|
import torch.distributed as dist
|
||||||
|
|
||||||
from khaosz.parallel import get_rank, only_on_rank, spawn_parallel_fn
|
from astrai.parallel import get_rank, only_on_rank, spawn_parallel_fn
|
||||||
|
|
||||||
|
|
||||||
@only_on_rank(0)
|
@only_on_rank(0)
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from khaosz.config import *
|
from astrai.config import *
|
||||||
from khaosz.trainer import *
|
from astrai.trainer import *
|
||||||
|
|
||||||
|
|
||||||
def test_callback_integration(base_test_env, random_dataset):
|
def test_callback_integration(base_test_env, random_dataset):
|
||||||
|
|
|
||||||
|
|
@ -1,9 +1,9 @@
|
||||||
import os
|
import os
|
||||||
import torch
|
import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from khaosz.config import *
|
from astrai.config import *
|
||||||
from khaosz.trainer import *
|
from astrai.trainer import *
|
||||||
from khaosz.data.serialization import Checkpoint
|
from astrai.data.serialization import Checkpoint
|
||||||
|
|
||||||
|
|
||||||
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
def test_early_stopping_simulation(base_test_env, early_stopping_dataset):
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,9 @@ import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from khaosz.config import *
|
from astrai.config import *
|
||||||
from khaosz.trainer.schedule import *
|
from astrai.trainer.schedule import *
|
||||||
from khaosz.data.dataset import *
|
from astrai.data.dataset import *
|
||||||
|
|
||||||
|
|
||||||
def test_schedule_factory_random_configs():
|
def test_schedule_factory_random_configs():
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,9 @@ import torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
|
||||||
from khaosz.config import *
|
from astrai.config import *
|
||||||
from khaosz.trainer import *
|
from astrai.trainer import *
|
||||||
from khaosz.data.dataset import *
|
from astrai.data.dataset import *
|
||||||
|
|
||||||
|
|
||||||
def test_different_batch_sizes(base_test_env, random_dataset):
|
def test_different_batch_sizes(base_test_env, random_dataset):
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
import torch
|
import torch
|
||||||
from typing import Dict, Any
|
from typing import Dict, Any
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from khaosz.model.transformer import ModelConfig, Transformer
|
from astrai.model.transformer import ModelConfig, Transformer
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
|
|
||||||
|
|
@ -2,9 +2,9 @@ import torch
|
||||||
import json
|
import json
|
||||||
import argparse
|
import argparse
|
||||||
|
|
||||||
from khaosz.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
from khaosz.inference.generator import BatchGenerator, GenerationRequest
|
from astrai.inference.generator import BatchGenerator, GenerationRequest
|
||||||
from khaosz.inference.core import disable_random_init
|
from astrai.inference.core import disable_random_init
|
||||||
|
|
||||||
|
|
||||||
def processor(
|
def processor(
|
||||||
|
|
|
||||||
|
|
@ -6,8 +6,8 @@ import argparse
|
||||||
import tqdm
|
import tqdm
|
||||||
|
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
from khaosz.config.param_config import ModelParameter
|
from astrai.config.param_config import ModelParameter
|
||||||
from khaosz.inference.core import disable_random_init
|
from astrai.inference.core import disable_random_init
|
||||||
|
|
||||||
|
|
||||||
def compute_perplexity(
|
def compute_perplexity(
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,10 @@ import torch.optim as optim
|
||||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||||
|
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from khaosz.data import DatasetLoader
|
from astrai.data import DatasetLoader
|
||||||
from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
from astrai.config import ModelParameter, TrainConfig, CosineScheduleConfig
|
||||||
from khaosz.trainer import Trainer, SchedulerFactory
|
from astrai.trainer import Trainer, SchedulerFactory
|
||||||
from khaosz.parallel import get_rank
|
from astrai.parallel import get_rank
|
||||||
|
|
||||||
|
|
||||||
def parse_args() -> argparse.Namespace:
|
def parse_args() -> argparse.Namespace:
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue