chore: update project name

ViperEkura 2026-03-31 09:34:11 +08:00
parent 780b9e1855
commit 2e009cf59a
51 changed files with 118 additions and 118 deletions

View File

@@ -1,7 +1,7 @@
 <div align="center">
-<img src="assets/images/project_logo.png" width="auto" alt="Logo">
-<h1>KHAOSZ</h1>
+<!-- <img src="assets/images/project_logo.png" width="auto" alt="Logo"> -->
+<h1>AstrAI</h1>
 <div>
 <a href="#english">English</a>
@@ -48,8 +48,8 @@
 ### Installation
 ```bash
-git clone https://github.com/username/khaosz.git
-cd khaosz
+git clone https://github.com/ViperEkura/AstrAI.git
+cd AstrAI
 pip install -e .
 ```
@@ -95,8 +95,8 @@ python demo/generate_ar.py
 ### 安装
 ```bash
-git clone https://github.com/username/khaosz.git
-cd khaosz
+git clone https://github.com/ViperEkura/AstrAI.git
+cd AstrAI
 pip install -e .
 ```
@@ -143,7 +143,7 @@ python demo/generate_ar.py
 ### Download | 下载
-- [HuggingFace](https://huggingface.co/ViperEk/KHAOSZ)
+- [HuggingFace](https://huggingface.co/ViperEk/AstrAI)
 - `python demo/download.py`
 ### Lincence | 许可证

View File

@@ -1,16 +1,16 @@
-# KHAOSZ Data Flow Documentation
-This document describes the data flow of the KHAOSZ project (a training and inference framework for autoregressive Transformer language models). It covers the complete flow from raw data to model training and inference.
+# AstrAI Data Flow Documentation
+This document describes the data flow of the AstrAI project (a training and inference framework for autoregressive Transformer language models). It covers the complete flow from raw data to model training and inference.
 ## Overview
-KHAOSZ adopts a modular design with the following main components:
-- **Data Module** (`khaosz/data/`): Dataset, sampler, tokenizer, serialization tools
-- **Model Module** (`khaosz/model/`): Transformer model and its submodules
-- **Training Module** (`khaosz/trainer/`): Trainer, training context, strategies, schedulers
-- **Inference Module** (`khaosz/inference/`): Generation core, KV cache management, streaming generation
-- **Config Module** (`khaosz/config/`): Model, training, scheduler, and other configurations
-- **Parallel Module** (`khaosz/parallel/`): Distributed training support
+AstrAI adopts a modular design with the following main components:
+- **Data Module** (`astrai/data/`): Dataset, sampler, tokenizer, serialization tools
+- **Model Module** (`astrai/model/`): Transformer model and its submodules
+- **Training Module** (`astrai/trainer/`): Trainer, training context, strategies, schedulers
+- **Inference Module** (`astrai/inference/`): Generation core, KV cache management, streaming generation
+- **Config Module** (`astrai/config/`): Model, training, scheduler, and other configurations
+- **Parallel Module** (`astrai/parallel/`): Distributed training support
 The data flow can generally be divided into two main lines: **Training Data Flow** and **Inference Data Flow**.
@@ -199,7 +199,7 @@ flowchart LR
 ## Summary
-The data flow design of KHAOSZ reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache and sampling strategies. Clear interfaces between modules facilitate customization and extension.
+The data flow design of AstrAI reflects the characteristics of modularity, extensibility, and resumability. The training data flow supports large-scale distributed training through chunk loading, resumable sampling, gradient accumulation, and other mechanisms; the inference data flow achieves efficient text generation using KV cache and sampling strategies. Clear interfaces between modules facilitate customization and extension.
 > Document Update Time: 2026-03-30
 > Corresponding Code Version: Refer to version number defined in `pyproject.toml`

View File

@@ -2,7 +2,7 @@
 There are many large language models on the market today, such as GPT, LLaMA, and others, with tens of billions or even hundreds of billions of parameters. But honestly, these models have extremely high hardware requirements, making them inaccessible for ordinary developers. I thought: **Can we create a model that is both useful and can run on ordinary computers?** This is also what most people currently hope for - a locally deployable AI project that achieves complete privatization while maintaining some level of intelligence.
-Thus, the KHAOSZ project was born - 1B parameters, Chinese-English bilingual, supporting dialogue, text generation, RAG retrieval, and the training code is open source!
+Thus, the AstrAI project was born - 1B parameters, Chinese-English bilingual, supporting dialogue, text generation, RAG retrieval, and the training code is open source!
 ## 2. System Architecture

View File

@@ -83,8 +83,8 @@
 ### Usage Example
 ```python
-from khaosz.config.param_config import ModelParameter
-from khaosz.inference.generator import StreamGenerator, GenerationRequest
+from astrai.config.param_config import ModelParameter
+from astrai.inference.generator import StreamGenerator, GenerationRequest
 # Load model
 param = ModelParameter.load("your_model_dir")

View File

@@ -1,13 +1,13 @@
 __version__ = "1.3.2"
 __author__ = "ViperEkura"
-from khaosz.config import (
+from astrai.config import (
 ModelConfig,
 TrainConfig,
 )
-from khaosz.model.transformer import Transformer
-from khaosz.data import DatasetLoader, BpeTokenizer
-from khaosz.inference.generator import (
+from astrai.model.transformer import Transformer
+from astrai.data import DatasetLoader, BpeTokenizer
+from astrai.inference.generator import (
 GenerationRequest,
 LoopGenerator,
 StreamGenerator,
@@ -15,7 +15,7 @@ from khaosz.inference.generator import (
 EmbeddingEncoder,
 GeneratorFactory,
 )
-from khaosz.trainer import Trainer, StrategyFactory, SchedulerFactory
+from astrai.trainer import Trainer, StrategyFactory, SchedulerFactory
 __all__ = [
 "Transformer",

View File

@@ -1,12 +1,12 @@
-from khaosz.config.model_config import ModelConfig
-from khaosz.config.param_config import BaseModelIO, ModelParameter
-from khaosz.config.schedule_config import (
+from astrai.config.model_config import ModelConfig
+from astrai.config.param_config import BaseModelIO, ModelParameter
+from astrai.config.schedule_config import (
 ScheduleConfig,
 CosineScheduleConfig,
 SGDRScheduleConfig,
 ScheduleConfigFactory,
 )
-from khaosz.config.train_config import TrainConfig
+from astrai.config.train_config import TrainConfig
 __all__ = [

View File

@@ -5,9 +5,9 @@ from dataclasses import dataclass, field
 from typing import Optional, Self, Union
 from pathlib import Path
-from khaosz.data.tokenizer import BpeTokenizer
-from khaosz.config.model_config import ModelConfig
-from khaosz.model.transformer import Transformer
+from astrai.data.tokenizer import BpeTokenizer
+from astrai.config.model_config import ModelConfig
+from astrai.model.transformer import Transformer
 @dataclass

View File

@@ -1,4 +1,4 @@
-from khaosz.data.dataset import (
+from astrai.data.dataset import (
 BaseDataset,
 SEQDataset,
 DPODataset,
@@ -9,8 +9,8 @@ from khaosz.data.dataset import (
 DatasetFactory,
 )
-from khaosz.data.tokenizer import BpeTokenizer
-from khaosz.data.sampler import ResumableDistributedSampler
+from astrai.data.tokenizer import BpeTokenizer
+from astrai.data.sampler import ResumableDistributedSampler
 __all__ = [
 # Base classes

View File

@@ -6,7 +6,7 @@ import bisect
 from abc import ABC, abstractmethod
 from torch import Tensor
 from torch.utils.data import Dataset
-from khaosz.data.serialization import load_h5
+from astrai.data.serialization import load_h5
 from typing import List, Dict, Optional, Union

View File

@@ -8,7 +8,7 @@ import torch.distributed as dist
 from pathlib import Path
 from torch import Tensor
 from typing import Any, Dict, List
-from khaosz.parallel.setup import get_rank
+from astrai.parallel.setup import get_rank
 def save_h5(file_path: str, file_name: str, tensor_group: Dict[str, List[Tensor]]):

View File

@@ -1,11 +1,11 @@
-from khaosz.inference.core import (
+from astrai.inference.core import (
 disable_random_init,
 GeneratorCore,
 EmbeddingEncoderCore,
 KVCacheManager,
 )
-from khaosz.inference.generator import (
+from astrai.inference.generator import (
 GenerationRequest,
 LoopGenerator,
 StreamGenerator,

View File

@@ -4,7 +4,7 @@ import torch.nn as nn
 from torch import Tensor
 from contextlib import contextmanager
 from typing import Any, Callable, List, Tuple, Union, Optional, Self
-from khaosz.config import ModelParameter, ModelConfig
+from astrai.config import ModelParameter, ModelConfig
 def apply_sampling_strategies(

View File

@@ -2,8 +2,8 @@ import torch
 from dataclasses import dataclass
 from torch import Tensor
 from typing import List, Tuple, Union, Optional, Generator
-from khaosz.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
-from khaosz.config.param_config import ModelParameter
+from astrai.inference.core import GeneratorCore, EmbeddingEncoderCore, KVCacheManager
+from astrai.config.param_config import ModelParameter
 HistoryType = List[Tuple[str, str]]

View File

@@ -1,10 +1,10 @@
-from khaosz.model.module import (
+from astrai.model.module import (
 Linear,
 RMSNorm,
 MLP,
 GQA,
 DecoderBlock,
 )
-from khaosz.model.transformer import Transformer
+from astrai.model.transformer import Transformer
 __all__ = ["Linear", "RMSNorm", "MLP", "GQA", "DecoderBlock", "Transformer"]

View File

@@ -3,8 +3,8 @@ import torch.nn as nn
 from torch import Tensor
 from typing import Any, Mapping, Optional, Tuple
-from khaosz.config.model_config import ModelConfig
-from khaosz.model.module import (
+from astrai.config.model_config import ModelConfig
+from astrai.model.module import (
 Embedding,
 DecoderBlock,
 Linear,

View File

@@ -1,4 +1,4 @@
-from khaosz.parallel.setup import (
+from astrai.parallel.setup import (
 get_world_size,
 get_rank,
 get_current_device,
@@ -7,7 +7,7 @@ from khaosz.parallel.setup import (
 spawn_parallel_fn,
 )
-from khaosz.parallel.module import RowParallelLinear, ColumnParallelLinear
+from astrai.parallel.module import RowParallelLinear, ColumnParallelLinear
 __all__ = [
 "get_world_size",

View File

@@ -1,8 +1,8 @@
-from khaosz.trainer.trainer import Trainer
-from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
-from khaosz.trainer.schedule import SchedulerFactory, BaseScheduler
-from khaosz.trainer.train_callback import (
+from astrai.trainer.trainer import Trainer
+from astrai.trainer.strategy import StrategyFactory, BaseStrategy
+from astrai.trainer.schedule import SchedulerFactory, BaseScheduler
+from astrai.trainer.train_callback import (
 TrainCallback,
 GradientClippingCallback,
 SchedulerCallback,

View File

@@ -4,7 +4,7 @@ import math
 from abc import abstractmethod, ABC
 from typing import Any, Dict, List, Type
 from torch.optim.lr_scheduler import LRScheduler
-from khaosz.config.schedule_config import ScheduleConfig
+from astrai.config.schedule_config import ScheduleConfig
 class BaseScheduler(LRScheduler, ABC):

View File

@@ -8,8 +8,8 @@ from tqdm import tqdm
 from torch.nn.utils import clip_grad_norm_
 from typing import Callable, List, Optional, Protocol
-from khaosz.parallel import only_on_rank
-from khaosz.trainer.metric_util import (
+from astrai.parallel import only_on_rank
+from astrai.trainer.metric_util import (
 ctx_get_loss,
 ctx_get_lr,
 ctx_get_grad_max,
@@ -19,8 +19,8 @@ from khaosz.trainer.metric_util import (
 ctx_get_grad_std,
 ctx_get_grad_nan_num,
 )
-from khaosz.data.serialization import Checkpoint
-from khaosz.trainer.train_context import TrainContext
+from astrai.data.serialization import Checkpoint
+from astrai.trainer.train_context import TrainContext
 class TrainCallback(Protocol):

View File

@@ -3,11 +3,11 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import LRScheduler
 from torch.utils.data import DataLoader
-from khaosz.data import ResumableDistributedSampler
-from khaosz.data.serialization import Checkpoint
-from khaosz.trainer.strategy import StrategyFactory, BaseStrategy
-from khaosz.config.train_config import TrainConfig
-from khaosz.parallel.setup import get_current_device, get_world_size, get_rank
+from astrai.data import ResumableDistributedSampler
+from astrai.data.serialization import Checkpoint
+from astrai.trainer.strategy import StrategyFactory, BaseStrategy
+from astrai.config.train_config import TrainConfig
+from astrai.parallel.setup import get_current_device, get_world_size, get_rank
 from dataclasses import dataclass, field
 from typing import Optional, Self

View File

@@ -1,7 +1,7 @@
 import logging
 from typing import Optional, List
-from khaosz.config import TrainConfig
-from khaosz.trainer.train_callback import (
+from astrai.config import TrainConfig
+from astrai.trainer.train_callback import (
 TrainCallback,
 ProgressBarCallback,
 CheckpointCallback,
@@ -9,9 +9,9 @@ from khaosz.trainer.train_callback import (
 GradientClippingCallback,
 SchedulerCallback,
 )
-from khaosz.trainer.train_context import TrainContext, TrainContextBuilder
-from khaosz.data.serialization import Checkpoint
-from khaosz.parallel.setup import spawn_parallel_fn
+from astrai.trainer.train_context import TrainContext, TrainContextBuilder
+from astrai.data.serialization import Checkpoint
+from astrai.parallel.setup import spawn_parallel_fn
 logger = logging.getLogger(__name__)

View File

@@ -6,7 +6,7 @@ PARAMETER_ROOT = Path(PROJECT_ROOT, "params")
 if __name__ == "__main__":
 snapshot_download(
-repo_id="ViperEk/KHAOSZ",
+repo_id="ViperEk/AstrAI",
 local_dir=PARAMETER_ROOT,
 force_download=True,
 )

View File

@@ -1,8 +1,8 @@
 import torch
 from pathlib import Path
-from khaosz.config.param_config import ModelParameter
-from khaosz.inference.core import disable_random_init
-from khaosz.inference.generator import GeneratorFactory, GenerationRequest
+from astrai.config.param_config import ModelParameter
+from astrai.inference.core import disable_random_init
+from astrai.inference.generator import GeneratorFactory, GenerationRequest
 PROJECT_ROOT = Path(__file__).parent.parent
 PARAMETER_ROOT = Path(PROJECT_ROOT, "params")

View File

@@ -1,8 +1,8 @@
 import torch
 from pathlib import Path
-from khaosz.config.param_config import ModelParameter
-from khaosz.inference.core import disable_random_init
-from khaosz.inference.generator import GeneratorFactory, GenerationRequest
+from astrai.config.param_config import ModelParameter
+from astrai.inference.core import disable_random_init
+from astrai.inference.generator import GeneratorFactory, GenerationRequest
 PROJECT_ROOT = Path(__file__).parent.parent
 PARAMETER_ROOT = Path(PROJECT_ROOT, "params")

View File

@@ -1,8 +1,8 @@
 import torch
 from pathlib import Path
-from khaosz.config.param_config import ModelParameter
-from khaosz.inference.core import disable_random_init
-from khaosz.inference.generator import GeneratorFactory, GenerationRequest
+from astrai.config.param_config import ModelParameter
+from astrai.inference.core import disable_random_init
+from astrai.inference.generator import GeneratorFactory, GenerationRequest
 PROJECT_ROOT = Path(__file__).parent.parent
 PARAMETER_ROOT = Path(PROJECT_ROOT, "params")

View File

@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 [project]
 dynamic = ["version"]
-name = "khaosz"
+name = "astrai"
 readme = "README.md"
 requires-python = ">=3.12"
 dependencies = [
@@ -23,7 +23,7 @@ classifiers = [
 "License :: OSI Approved :: GPL-3.0",
 "Operating System :: OS Independent",
 ]
-urls = { Homepage = "https://github.com/ViperEkura/KHAOSZ" }
+urls = { Homepage = "https://github.com/ViperEkura/AstrAI" }
 [project.optional-dependencies]
 dev = ["pytest==9.0.2", "ruff"]
@@ -35,7 +35,7 @@ where = ["."]
 extra-index-url = "https://download.pytorch.org/whl/cu126"
 [tool.setuptools.dynamic]
-version = { attr = "khaosz.__version__" }
+version = { attr = "astrai.__version__" }
 [tool.ruff]
 target-version = "py312"

View File

@@ -7,9 +7,9 @@ import torch
 import pytest
 from torch.utils.data import Dataset
-from khaosz.config.model_config import ModelConfig
-from khaosz.data.tokenizer import BpeTokenizer
-from khaosz.model.transformer import Transformer
+from astrai.config.model_config import ModelConfig
+from astrai.data.tokenizer import BpeTokenizer
+from astrai.model.transformer import Transformer
 class RandomDataset(Dataset):

View File

@@ -4,8 +4,8 @@ import torch.distributed as dist
 from torch.optim import AdamW
 from torch.optim.lr_scheduler import CosineAnnealingLR
-from khaosz.data.serialization import Checkpoint
-from khaosz.parallel.setup import get_rank, spawn_parallel_fn
+from astrai.data.serialization import Checkpoint
+from astrai.parallel.setup import get_rank, spawn_parallel_fn
 def test_single_process():

View File

@@ -1,8 +1,8 @@
 import torch
 import numpy as np
-from khaosz.data.serialization import save_h5
-from khaosz.data.dataset import *
+from astrai.data.serialization import save_h5
+from astrai.data.dataset import *
 def test_dataset_loader_random_paths(base_test_env):

View File

@@ -1,5 +1,5 @@
-from khaosz.trainer import *
-from khaosz.data import *
+from astrai.trainer import *
+from astrai.data import *
 def test_random_sampler_consistency(random_dataset):

View File

@@ -5,11 +5,11 @@ import shutil
 import pytest
 import tempfile
 import safetensors.torch as st
-from khaosz.trainer import *
-from khaosz.config import *
-from khaosz.model import *
-from khaosz.data import *
-from khaosz.inference.generator import EmbeddingEncoderCore, GeneratorCore
+from astrai.trainer import *
+from astrai.config import *
+from astrai.model import *
+from astrai.data import *
+from astrai.inference.generator import EmbeddingEncoderCore, GeneratorCore
 from tokenizers import pre_tokenizers

View File

@@ -4,8 +4,8 @@ import torch
 import pytest
 import tempfile
 import safetensors.torch as st
-from khaosz.model.transformer import Transformer
-from khaosz.config.model_config import ModelConfig
+from astrai.model.transformer import Transformer
+from astrai.config.model_config import ModelConfig
 @pytest.fixture

View File

@@ -1,7 +1,7 @@
 import torch
 import torch.distributed as dist
-from khaosz.parallel import get_rank, only_on_rank, spawn_parallel_fn
+from astrai.parallel import get_rank, only_on_rank, spawn_parallel_fn
 @only_on_rank(0)

View File

@@ -1,7 +1,7 @@
 import torch
-from khaosz.config import *
-from khaosz.trainer import *
+from astrai.config import *
+from astrai.trainer import *
 def test_callback_integration(base_test_env, random_dataset):

View File

@@ -1,9 +1,9 @@
 import os
 import torch
 import numpy as np
-from khaosz.config import *
-from khaosz.trainer import *
-from khaosz.data.serialization import Checkpoint
+from astrai.config import *
+from astrai.trainer import *
+from astrai.data.serialization import Checkpoint
 def test_early_stopping_simulation(base_test_env, early_stopping_dataset):

View File

@@ -2,9 +2,9 @@ import torch
 import numpy as np
 import pytest
-from khaosz.config import *
-from khaosz.trainer.schedule import *
-from khaosz.data.dataset import *
+from astrai.config import *
+from astrai.trainer.schedule import *
+from astrai.data.dataset import *
 def test_schedule_factory_random_configs():

View File

@@ -2,9 +2,9 @@ import torch
 import numpy as np
-from khaosz.config import *
-from khaosz.trainer import *
-from khaosz.data.dataset import *
+from astrai.config import *
+from astrai.trainer import *
+from astrai.data.dataset import *
 def test_different_batch_sizes(base_test_env, random_dataset):

View File

@@ -1,7 +1,7 @@
 import torch
 from typing import Dict, Any
 from dataclasses import dataclass
-from khaosz.model.transformer import ModelConfig, Transformer
+from astrai.model.transformer import ModelConfig, Transformer
 @dataclass

View File

@@ -2,9 +2,9 @@ import torch
 import json
 import argparse
-from khaosz.config.param_config import ModelParameter
-from khaosz.inference.generator import BatchGenerator, GenerationRequest
-from khaosz.inference.core import disable_random_init
+from astrai.config.param_config import ModelParameter
+from astrai.inference.generator import BatchGenerator, GenerationRequest
+from astrai.inference.core import disable_random_init
 def processor(

View File

@@ -6,8 +6,8 @@ import argparse
 import tqdm
 from torch import Tensor
-from khaosz.config.param_config import ModelParameter
-from khaosz.inference.core import disable_random_init
+from astrai.config.param_config import ModelParameter
+from astrai.inference.core import disable_random_init
 def compute_perplexity(

View File

@@ -6,10 +6,10 @@ import torch.optim as optim
 from torch.nn.parallel import DistributedDataParallel as DDP
 from functools import partial
-from khaosz.data import DatasetLoader
-from khaosz.config import ModelParameter, TrainConfig, CosineScheduleConfig
-from khaosz.trainer import Trainer, SchedulerFactory
-from khaosz.parallel import get_rank
+from astrai.data import DatasetLoader
+from astrai.config import ModelParameter, TrainConfig, CosineScheduleConfig
+from astrai.trainer import Trainer, SchedulerFactory
+from astrai.parallel import get_rank
 def parse_args() -> argparse.Namespace:
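
The commit above is a mechanical project-wide rename (`khaosz` → `astrai`, `KHAOSZ` → `AstrAI`) across 51 files plus a package-directory move. A sweep like this can be scripted; the sketch below uses a hypothetical miniature layout under `/tmp/rename_demo` (not the actual repository) and GNU `sed`/`grep`:

```shell
set -eu

# Hypothetical miniature of the repository layout, for demonstration only.
mkdir -p /tmp/rename_demo/khaosz
printf 'from khaosz.config import ModelConfig\n# KHAOSZ package\n' \
    > /tmp/rename_demo/khaosz/__init__.py

# 1. Rewrite both spellings in every file that mentions the old name.
#    Note KHAOSZ -> AstrAI is not a plain case change, so it needs its own rule.
grep -rl 'khaosz\|KHAOSZ' /tmp/rename_demo \
    | xargs sed -i -e 's/khaosz/astrai/g' -e 's/KHAOSZ/AstrAI/g'

# 2. Rename the package directory itself (sed only edits file contents, not paths).
mv /tmp/rename_demo/khaosz /tmp/rename_demo/astrai

cat /tmp/rename_demo/astrai/__init__.py
```

In a real repository one would run the sweep over `git ls-files` output instead of a bare `grep -r`, so that build artifacts and `.git` internals are left untouched.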