# AstrAI/khaosz/utils/retriever.py
import torch
import sqlite3
import numpy as np
from torch import Tensor
from typing import Dict, List, Tuple
class Retriever:
    """In-memory vector store with cosine-similarity retrieval and SQLite persistence.

    Vectors are stored as flattened float32 CPU tensors keyed by string. A
    row-normalized embedding matrix is cached lazily and invalidated whenever
    the stored vectors change.
    """

    def __init__(self, db_path=None):
        # key -> flattened float32 CPU vector
        self.data: Dict[str, Tensor] = {}
        # cached row-normalized embedding matrix, shape [n_vectors, dim]
        self.embedding_cache: Tensor = None
        # True while embedding_cache is consistent with self.data.
        # (Attribute keeps the original misspelling for backward compatibility.)
        self.is_caculated: bool = False
        if db_path is not None:
            self.load(db_path)

    def retrieve(self, query: Tensor, top_k: int) -> List[Tuple[str, float]]:
        """Return up to ``top_k`` (key, cosine_similarity) pairs, best first.

        The query is flattened and L2-normalized so the reported scores are
        true cosine similarities, independent of query magnitude. Returns []
        when the store is empty.
        """
        if not self.data:
            return []
        query = query.flatten()
        # Normalize the query to match the normalized embeddings; clamp
        # guards against a zero-norm query producing NaN scores.
        query = query / query.norm().clamp_min(1e-12)
        query = query.unsqueeze(1)  # [dim, 1]
        norm_embeddings = self._embeddings.to(
            device=query.device,
            dtype=query.dtype
        )  # [n_vectors, dim]
        # squeeze(-1), not squeeze(): with a single stored vector a full
        # squeeze() yields a 0-dim tensor, which breaks iteration below.
        sim_scores = torch.matmul(norm_embeddings, query).squeeze(-1)  # [n_vectors]
        top_k = min(top_k, len(self.data))
        indices = sim_scores.topk(top_k).indices
        keys = list(self.data.keys())
        return [(keys[i], sim_scores[i].item()) for i in indices]

    def add_vector(self, key: str, vector_data: Tensor):
        """Insert or overwrite the vector for ``key``; invalidates the cache."""
        self.is_caculated = False
        self.data[key] = vector_data.flatten().float().cpu()

    def delete_vector(self, key: str):
        """Remove ``key`` if present (no-op otherwise); invalidates the cache."""
        self.is_caculated = False
        self.data.pop(key, None)

    def save(self, db_path):
        """Persist all vectors to the SQLite file at ``db_path``.

        Existing rows in the target table are replaced wholesale. The
        connection is always closed, even if a write fails.
        """
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            self._init_db(cursor)
            cursor.execute('DELETE FROM vectors')
            for item, vec in self.data.items():
                vec_bytes = vec.numpy().tobytes()
                cursor.execute(
                    'INSERT OR REPLACE INTO vectors (key, vector) VALUES (?, ?)',
                    (item, vec_bytes))
            conn.commit()
        finally:
            conn.close()

    def load(self, db_path):
        """Load all vectors from the SQLite file at ``db_path``.

        Replaces the in-memory store and invalidates the embedding cache
        (the previous code could serve a stale cache after a load). The
        connection is always closed, even if a read fails.
        """
        conn = sqlite3.connect(db_path)
        try:
            cursor = conn.cursor()
            self._init_db(cursor)
            cursor.execute('SELECT key, vector FROM vectors')
            rows = cursor.fetchall()
            self.data = {}
            self.is_caculated = False  # cache no longer matches the new data
            for key, vec_bytes in rows:
                # copy() detaches the array from the read-only sqlite buffer
                vec_numpy = np.frombuffer(vec_bytes, dtype=np.float32).copy()
                self.data[key] = torch.from_numpy(vec_numpy)
        finally:
            conn.close()

    def _init_db(self, cursor: sqlite3.Cursor):
        # Create table if not exists (in case loading from a new database)
        cursor.execute('''
            CREATE TABLE IF NOT EXISTS vectors (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                key TEXT UNIQUE NOT NULL,
                vector BLOB NOT NULL
            )''')

    @property
    def _embeddings(self) -> Tensor:
        """Row-normalized embedding matrix [n_vectors, dim], computed lazily.

        Assumes ``self.data`` is non-empty (callers guard this).
        """
        if not self.is_caculated:
            embeddings = torch.stack(list(self.data.values()))
            # clamp avoids NaN rows for all-zero vectors
            norms = torch.norm(embeddings, dim=-1, keepdim=True).clamp_min(1e-12)
            self.embedding_cache = embeddings / norms
            # Bug fix: the flag was never set, so the cache was recomputed
            # on every access and never actually reused.
            self.is_caculated = True
        return self.embedding_cache