127 lines
4.0 KiB
Python
127 lines
4.0 KiB
Python
import re
|
|
import torch
|
|
import torch.nn.functional as F
|
|
|
|
from abc import ABC, abstractmethod
|
|
from torch import Tensor
|
|
from typing import List, Callable, Optional
|
|
|
|
|
|
class BaseTextSplitter(ABC):
    """Abstract base class for text splitters.

    Holds the ``max_len`` / ``chunk_overlap`` configuration shared by all
    splitters and provides whitespace-trimming hooks that subclasses call
    before and after splitting. Subclasses must implement :meth:`split`.
    """

    def __init__(
        self,
        max_len: int = 512,
        chunk_overlap: int = 0,
    ):
        """Validate and store the shared splitter configuration.

        Args:
            max_len: Size limit stored on the instance; must be positive.
                This base class only stores the value.
            chunk_overlap: Overlap amount stored on the instance; must be
                non-negative. Its unit is defined by the subclass using it.

        Raises:
            ValueError: If ``max_len <= 0`` or ``chunk_overlap < 0``.
        """
        if max_len <= 0:
            raise ValueError("max_len must be > 0")
        if chunk_overlap < 0:
            raise ValueError("chunk_overlap must be >= 0")

        self.max_len = max_len
        self.chunk_overlap = chunk_overlap

    @abstractmethod
    def split(self, text: str, **kwargs) -> List[str]:
        """Split ``text`` into chunks.

        Args:
            text: The text to split.
            **kwargs: Splitter-specific options.

        Returns:
            The list of chunks.
        """
        raise NotImplementedError

    def preprocess(self, text: str) -> str:
        """Return ``text`` with leading and trailing whitespace removed."""
        return text.strip()

    def postprocess(self, chunks: List[str]) -> List[str]:
        """Strip every chunk and drop the ones that end up empty.

        Args:
            chunks: Raw chunks produced by a splitter.

        Returns:
            A new list with each chunk stripped, in the original order,
            excluding chunks that are empty or whitespace-only.
        """
        # Strip lazily so each chunk is stripped exactly once.
        stripped = (chunk.strip() for chunk in chunks)
        return [chunk for chunk in stripped if chunk]
|
|
|
|
|
|
class PriorityTextSplitter(BaseTextSplitter):
    """Split text on the first separator, in priority order, that divides it.

    Separators are tried in the order given; the first one that yields more
    than one non-empty part determines the result. ``max_len`` and
    ``chunk_overlap`` are accepted and stored by the base class but are not
    consulted by :meth:`split`.
    """

    def __init__(
        self,
        separators: List[str],
        max_len: int = 512,
        chunk_overlap: int = 0,
    ):
        """Store the separator priority list.

        Args:
            separators: Separator strings, highest priority first.
            max_len: Forwarded to the base class.
            chunk_overlap: Forwarded to the base class.

        Raises:
            ValueError: If ``separators`` is empty, contains an empty
                string, or the base-class validation fails.
        """
        super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
        if not separators:
            raise ValueError("separators must be a non-empty list")
        # ``str.split("")`` raises "empty separator" at split time; reject it
        # up front so the failure points at the misconfiguration.
        if any(sep == "" for sep in separators):
            raise ValueError("separators must not contain empty strings")
        self.separators = separators

    def split(self, text: str, **kwargs) -> List[str]:
        """Split ``text`` using the highest-priority separator that applies.

        Args:
            text: The text to split.
            **kwargs: Accepted for compatibility with the base-class
                signature; ignored.

        Returns:
            The stripped, non-empty parts produced by the first separator
            that yields at least two of them. If no separator does, a
            single-element list with the stripped text, or an empty list
            when the text is empty or whitespace-only.
        """
        text = self.preprocess(text)
        for sep in self.separators:
            # postprocess() already strips and drops empty parts.
            chunks = self.postprocess(text.split(sep))
            if len(chunks) > 1:
                return chunks
        # Route the fallback through postprocess() so empty input yields []
        # rather than [""], like every other return path.
        return self.postprocess([text])
|
|
|
|
|
|
class SemanticTextSplitter(BaseTextSplitter):
    """Group sentences into chunks, breaking where embedding similarity drops.

    The text is cut into sentences with a regular expression, each sentence
    is embedded via a user-supplied function, and a new chunk starts
    wherever the cosine similarity between the sentences before and after a
    boundary falls below a threshold. ``max_len`` is stored by the base
    class but is not consulted by :meth:`split`.
    """

    # Zero-width split point right after a sentence-ending mark, taken only
    # when the rest of the text holds an even number of quote characters
    # (i.e. the position is not inside an open quotation).
    DEFAULT_PATTERN = r'(?<=[。!?!?])(?=(?:[^"\'‘’“”]*["\'‘’“”][^"\'‘’“”]*["\'‘’“”])*[^"\'‘’“”]*$)'

    # Fraction by which the threshold is lowered per sentence covered by the
    # two comparison windows.
    THRESHOLD_DECAY_PER_SENTENCE = 0.03
    # Lower bound of the adjusted threshold.
    MIN_DYNAMIC_THRESHOLD = 0.2

    def __init__(
        self,
        embedding_func: Callable[[List[str]], List[Tensor]],
        pattern: Optional[str] = None,
        max_len: int = 512,
        chunk_overlap: int = 0,
    ):
        """Configure the splitter.

        Args:
            embedding_func: Callable mapping a list of sentences to one
                embedding tensor per sentence.
            pattern: Regular expression used to cut the text into
                sentences; ``DEFAULT_PATTERN`` is used when falsy.
            max_len: Forwarded to the base class.
            chunk_overlap: Number of trailing sentences of a finished chunk
                that are repeated at the start of the next chunk.

        Raises:
            TypeError: If ``embedding_func`` is not callable.
            ValueError: If the base-class validation fails.
        """
        super().__init__(max_len=max_len, chunk_overlap=chunk_overlap)
        if not callable(embedding_func):
            raise TypeError("embedding_func must be callable")
        self.embedding_func = embedding_func
        self.pattern = pattern or SemanticTextSplitter.DEFAULT_PATTERN

    def split(
        self,
        text: str,
        threshold: float = 0.5,
        window_size: int = 1,
    ) -> List[str]:
        """Split ``text`` into chunks of semantically related sentences.

        Args:
            text: The text to split.
            threshold: Base similarity threshold. Before each comparison it
                is reduced by ``THRESHOLD_DECAY_PER_SENTENCE`` per sentence
                covered by the two windows and then floored at
                ``MIN_DYNAMIC_THRESHOLD`` (so the effective value can exceed
                ``threshold`` when ``threshold`` is below that floor).
            window_size: Number of sentences averaged on each side of a
                candidate boundary; must be at least 1.

        Returns:
            The chunks, each being its sentences joined by a single space,
            stripped, with empty chunks removed.

        Raises:
            ValueError: If ``window_size < 1`` or the embedding function
                does not return exactly one embedding per sentence.
            RuntimeError: If the embedding function raises.
        """
        # A non-positive window yields empty slices whose mean is NaN; every
        # comparison against NaN is False, so nothing would ever be split.
        if window_size < 1:
            raise ValueError("window_size must be >= 1")

        text = self.preprocess(text)
        sentences = [s.strip() for s in re.split(self.pattern, text) if s.strip()]

        # Nothing to compare with zero or one sentence.
        if len(sentences) <= 1:
            return self.postprocess(sentences)

        try:
            sentence_embs = self.embedding_func(sentences)
        except Exception as e:
            raise RuntimeError(f"Embedding generation failed: {e}") from e

        if len(sentence_embs) != len(sentences):
            raise ValueError("Embedding function must return one vector per sentence")

        # Assumes every embedding is a 1-D tensor of the same length, giving
        # a [N, D] matrix -- TODO confirm against the embedding function.
        emb_tensor = torch.stack(sentence_embs)

        total = len(sentences)
        chunks: List[str] = []
        current_chunk: List[str] = [sentences[0]]

        for i in range(1, total):
            # Windows of up to ``window_size`` sentences on each side of the
            # boundary between sentence i-1 and sentence i.
            start_prev = max(0, i - window_size)
            end_next = min(total, i + window_size)

            prev_window_emb = emb_tensor[start_prev:i].mean(dim=0)
            next_window_emb = emb_tensor[i:end_next].mean(dim=0)

            similarity = F.cosine_similarity(
                prev_window_emb.unsqueeze(0),
                next_window_emb.unsqueeze(0),
                dim=1,
            ).item()

            covered = end_next - start_prev
            dynamic_threshold = max(
                threshold * (1 - self.THRESHOLD_DECAY_PER_SENTENCE * covered),
                self.MIN_DYNAMIC_THRESHOLD,
            )

            if similarity < dynamic_threshold:
                # Close the current chunk and seed the next one with its
                # last ``chunk_overlap`` sentences (none when it is 0).
                chunks.append(" ".join(current_chunk))
                overlap_start = max(0, len(current_chunk) - self.chunk_overlap)
                current_chunk = current_chunk[overlap_start:]
            current_chunk.append(sentences[i])

        # ``current_chunk`` always holds at least the last sentence here.
        chunks.append(" ".join(current_chunk))

        return self.postprocess(chunks)