# AstrAI/khaosz/utils/splitter.py — text-splitting utilities.
import re
import torch
import torch.nn.functional as F
from abc import ABC, abstractmethod
from torch import Tensor
from typing import List, Callable, Optional
class BaseTextSplitter(ABC):
    """Abstract base class for text splitters.

    Validates and stores the chunking settings shared by all splitters.

    Attributes:
        max_len: Maximum chunk length hint for subclasses; must be > 0.
        chunk_overlap: Number of overlapping units between consecutive
            chunks; must be >= 0.
    """

    def __init__(
        self,
        max_len: int = 512,
        chunk_overlap: int = 0,
    ):
        if max_len <= 0:
            raise ValueError("max_len must be > 0")
        if chunk_overlap < 0:
            raise ValueError("chunk_overlap must be >= 0")
        self.max_len = max_len
        self.chunk_overlap = chunk_overlap

    @abstractmethod
    def split(self, text: str, **kwargs) -> List[str]:
        """Split *text* into chunks; must be implemented by subclasses."""
        raise NotImplementedError

    def preprocess(self, text: str) -> str:
        """Normalize input before splitting: strip surrounding whitespace."""
        return text.strip()

    def postprocess(self, chunks: List[str]) -> List[str]:
        """Strip every chunk and drop those that end up empty."""
        stripped = (chunk.strip() for chunk in chunks)
        return [chunk for chunk in stripped if chunk]
class PriorityTextSplitter(BaseTextSplitter):
    """Split text by trying separators in priority order.

    The first separator that yields more than one non-empty part wins;
    if none does, the preprocessed text is returned as a single chunk.
    """

    def __init__(
        self,
        separators: List[str],
        max_len: int = 512,
        chunk_overlap: int = 0,
    ):
        """
        Args:
            separators: Candidate separators, highest priority first.
                Must be non-empty.
            max_len: Chunk-length hint (stored on the base class).
            chunk_overlap: Overlap setting (stored on the base class).

        Raises:
            ValueError: If *separators* is empty.
        """
        super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
        if not separators:
            raise ValueError("separators must be a non-empty list")
        self.separators = separators

    def split(self, text: str) -> List[str]:
        """Return the parts produced by the first separator that splits *text*."""
        cleaned = self.preprocess(text)
        for separator in self.separators:
            pieces = [piece.strip() for piece in cleaned.split(separator)]
            pieces = [piece for piece in pieces if piece]
            # Only accept a separator that produced a genuine split.
            if len(pieces) > 1:
                return self.postprocess(pieces)
        # No separator matched more than once; hand back the whole text.
        return [cleaned]
class SemanticTextSplitter(BaseTextSplitter):
    """Split text into chunks at semantic boundaries.

    Sentences are produced by a regex split (the default pattern breaks
    after CJK/ASCII sentence-ending punctuation, with a lookahead that
    avoids splitting while inside an open quotation), each sentence is
    embedded, and a chunk boundary is inserted wherever the cosine
    similarity between the windows before and after a sentence drops
    below a dynamically relaxed threshold.
    """

    # Lookbehind: sentence terminator. Lookahead: an even number of quote
    # characters remain ahead, i.e. we are not inside an open quotation.
    DEFAULT_PATTERN = r'(?<=[。!?!?])(?=(?:[^"\'‘’“”]*["\'‘’“”][^"\'‘’“”]*["\'‘’“”])*[^"\'‘’“”]*$)'

    def __init__(
        self,
        embedding_func: Callable[[List[str]], List[Tensor]],
        pattern: Optional[str] = None,
        max_len: int = 512,
        chunk_overlap: int = 0,
    ):
        """
        Args:
            embedding_func: Maps a list of sentences to one embedding
                tensor per sentence (same dimensionality across sentences
                so they can be stacked — TODO confirm with callers).
            pattern: Optional regex overriding DEFAULT_PATTERN.
            max_len: Chunk-length hint from the base class.
                NOTE(review): currently unused by this splitter — confirm
                whether chunks should additionally be capped at max_len.
            chunk_overlap: Number of trailing sentences carried over into
                the next chunk when a boundary is emitted.

        Raises:
            TypeError: If *embedding_func* is not callable.
        """
        super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
        if not callable(embedding_func):
            raise TypeError("embedding_func must be callable")
        self.embedding_func = embedding_func
        self.pattern = pattern or SemanticTextSplitter.DEFAULT_PATTERN

    def split(
        self,
        text: str,
        threshold: float = 0.5,
        window_size: int = 1,
    ) -> List[str]:
        """Split *text* at points of low inter-window semantic similarity.

        Args:
            text: Raw input text.
            threshold: Base cosine-similarity cutoff; a boundary is
                emitted when similarity falls below it.
            window_size: Number of sentences per comparison window;
                must be >= 1.

        Returns:
            Chunk strings (sentences joined with a single space).

        Raises:
            ValueError: If window_size < 1, or the embedding function does
                not return one vector per sentence.
            RuntimeError: If the embedding function itself fails.
        """
        # BUGFIX: window_size <= 0 used to yield empty slices whose mean is
        # NaN; every NaN comparison is False, silently disabling all splits.
        if window_size < 1:
            raise ValueError("window_size must be >= 1")

        text = self.preprocess(text)
        sentences = [s.strip() for s in re.split(self.pattern, text) if s.strip()]
        if len(sentences) <= 1:
            return self.postprocess(sentences)

        try:
            sentence_embs = self.embedding_func(sentences)
        except Exception as e:
            # BUGFIX: chain with `from e` so the root cause and its
            # traceback are preserved for debugging.
            raise RuntimeError(f"Embedding generation failed: {e}") from e
        if len(sentence_embs) != len(sentences):
            raise ValueError("Embedding function must return one vector per sentence")

        emb_tensor = torch.stack(sentence_embs)  # shape: [N, D]
        chunks: List[str] = []
        current_chunk: List[str] = [sentences[0]]
        for i in range(1, len(sentences)):
            # Windows: [i - window_size, i) before, [i, i + window_size) after.
            start_prev = max(0, i - window_size)
            end_next = min(len(sentences), i + window_size)
            prev_window_emb = emb_tensor[start_prev:i].mean(dim=0)
            next_window_emb = emb_tensor[i:end_next].mean(dim=0)
            similarity = F.cosine_similarity(
                prev_window_emb.unsqueeze(0),
                next_window_emb.unsqueeze(0),
                dim=1,
            ).item()
            # Relax the threshold as the compared span grows (3% per
            # sentence in the span), floored at 0.2.
            dynamic_threshold = max(threshold * (1 - 0.03 * (end_next - start_prev)), 0.2)
            if similarity < dynamic_threshold:
                # Boundary: flush the current chunk, keeping its last
                # `chunk_overlap` sentences as the start of the next one.
                chunks.append(" ".join(current_chunk))
                keep_from = max(0, len(current_chunk) - self.chunk_overlap)
                current_chunk = current_chunk[keep_from:]
            current_chunk.append(sentences[i])
        if current_chunk:
            chunks.append(" ".join(current_chunk))
        return self.postprocess(chunks)