88 lines
3.0 KiB
Python
88 lines
3.0 KiB
Python
import torch
|
|
import sqlite3
|
|
import numpy as np
|
|
from torch import Tensor
|
|
from typing import Dict, List, Tuple
|
|
|
|
|
|
class Retriever:
|
|
def __init__(self, db_path=None):
|
|
self.data: Dict[str, Tensor] = {}
|
|
self.embedding_cache: Tensor = None
|
|
self.is_caculated: bool = False
|
|
|
|
if db_path is not None:
|
|
self.load(db_path)
|
|
|
|
def retrieve(self, query: Tensor, top_k: int) -> List[Tuple[str, float]]:
|
|
if not self.data:
|
|
return []
|
|
|
|
query = query.flatten().unsqueeze(1) # [dim, 1]
|
|
norm_embeddings = self._embeddings.to(
|
|
device=query.device,
|
|
dtype=query.dtype
|
|
) # [n_vectors, dim]
|
|
sim_scores = torch.matmul(norm_embeddings, query).squeeze() # [n_vectors]
|
|
|
|
top_k = min(top_k, len(self.data))
|
|
indices = sim_scores.topk(top_k).indices
|
|
keys = list(self.data.keys())
|
|
|
|
return [(keys[i], sim_scores[i].item()) for i in indices]
|
|
|
|
def add_vector(self, key: str, vector_data: Tensor):
|
|
self.is_caculated = False
|
|
self.data[key] = vector_data.flatten().float().cpu()
|
|
|
|
def delete_vector(self, key: str):
|
|
self.is_caculated = False
|
|
self.data.pop(key, None)
|
|
|
|
def save(self, db_path):
|
|
conn = sqlite3.connect(db_path)
|
|
cursor = conn.cursor()
|
|
self._init_db(cursor)
|
|
cursor.execute('DELETE FROM vectors')
|
|
|
|
for item, vec in self.data.items():
|
|
vec_bytes = vec.numpy().tobytes()
|
|
cursor.execute('INSERT OR REPLACE INTO vectors (key, vector) VALUES (?, ?)',
|
|
(item, vec_bytes))
|
|
|
|
conn.commit()
|
|
conn.close()
|
|
|
|
def load(self, db_path):
|
|
conn = sqlite3.connect(db_path)
|
|
cursor = conn.cursor()
|
|
self._init_db(cursor)
|
|
cursor.execute('SELECT key, vector FROM vectors')
|
|
rows = cursor.fetchall()
|
|
self.data = {}
|
|
|
|
for row in rows:
|
|
key, vec_bytes = row
|
|
vec_numpy = np.frombuffer(vec_bytes, dtype=np.float32).copy()
|
|
vec = torch.from_numpy(vec_numpy)
|
|
self.data[key] = vec
|
|
|
|
conn.close()
|
|
|
|
def _init_db(self,cursor: sqlite3.Cursor):
|
|
# Create table if not exists (in case loading from a new database)
|
|
cursor.execute('''
|
|
CREATE TABLE IF NOT EXISTS vectors (
|
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
|
key TEXT UNIQUE NOT NULL,
|
|
vector BLOB NOT NULL
|
|
)''')
|
|
|
|
@property
|
|
def _embeddings(self) -> Tensor:
|
|
if not self.is_caculated:
|
|
embeddings = torch.stack(list(self.data.values()))
|
|
norm_embeddings = embeddings / torch.norm(embeddings, dim=-1, keepdim=True)
|
|
self.embedding_cache = norm_embeddings
|
|
|
|
return self.embedding_cache |