57 lines
2.3 KiB
Python
57 lines
2.3 KiB
Python
from qdrant_client import QdrantClient
|
|
from qdrant_client.models import PointStruct, VectorParams, Distance
|
|
|
|
from .chunk import Chunk
|
|
|
|
class Qdrant:
    """Thin wrapper around a ``QdrantClient`` bound to a single collection.

    Offers helpers to create/delete the collection, convert embedded
    ``Chunk`` objects into ``PointStruct`` records, and upsert those records
    in batches.
    """

    def __init__(
        self,
        host: str,
        port: int,
        collection_name: str,
        vector_size: int = 4096,
        distance: Distance = Distance.COSINE,
    ) -> None:
        """Connect to a Qdrant server and remember the target collection.

        Args:
            host: Qdrant server host name.
            port: Qdrant server port.
            collection_name: Name of the collection every method operates on.
            vector_size: Dimensionality used when creating the collection.
                Embeddings upserted later are expected to have this length.
                Defaults to 4096.
            distance: Similarity metric used when creating the collection.
                Defaults to cosine distance.

        Raises:
            ValueError: If ``vector_size`` is not positive.
        """
        # Validate before opening a client so bad configuration fails fast.
        if vector_size <= 0:
            raise ValueError(f"vector_size must be positive, got {vector_size}")
        self.client = QdrantClient(host=host, port=port)
        self.collection_name = collection_name
        self.vector_size = vector_size
        self.distance = distance

    def create_collection(self) -> None:
        """Create the collection with the configured vector size and distance.

        Delegates to ``QdrantClient.create_collection``. Use
        ``create_if_not_exists_collection`` when the collection may already
        exist.
        """
        self.client.create_collection(
            collection_name=self.collection_name,
            vectors_config=VectorParams(
                size=self.vector_size,
                distance=self.distance,
            ),
        )

    def create_if_not_exists_collection(self) -> None:
        """Create the collection only if the server does not already have it.

        Note: the existence check and the creation are two separate client
        calls, so this is not atomic with respect to other writers.
        """
        if not self.client.collection_exists(collection_name=self.collection_name):
            self.create_collection()

    def delete_collection(self) -> None:
        """Delete the collection via ``QdrantClient.delete_collection``."""
        self.client.delete_collection(collection_name=self.collection_name)

    def create_points(self, chunks: list[Chunk], bucket: str, object_name: str) -> list[PointStruct]:
        """Convert every chunk into a point tagged with its source object.

        Args:
            chunks: Chunks to convert. Each must carry an embedding.
            bucket: Storage bucket the chunks were read from (stored in payload).
            object_name: Object/key within ``bucket`` (stored in payload).

        Returns:
            One ``PointStruct`` per chunk, in input order.

        Raises:
            ValueError: If any chunk has no embedding (see ``create_point``).
        """
        return [self.create_point(chunk, bucket, object_name) for chunk in chunks]

    def create_point(self, chunk: Chunk, bucket: str, object_name: str) -> PointStruct:
        """Build a single point from an embedded chunk.

        The point's id is ``chunk.id`` and its vector is ``chunk.embedding``.
        The payload records the chunk text, its origin (bucket and object
        name), the chunk id, and ``chunk.size``.

        Raises:
            ValueError: If the chunk reports no embedding or its embedding
                is ``None``.
        """
        if not chunk.has_embedding or chunk.embedding is None:
            raise ValueError("Chunk has no embedding")

        embedding: list[float] = chunk.embedding

        return PointStruct(
            id=chunk.id,
            vector=embedding,
            payload={
                "text": chunk.text,
                "bucket": bucket,
                "object": object_name,
                "id": chunk.id,
                "chunk_size": chunk.size,
            },
        )

    def upsert_points(self, points: list[PointStruct], batch_size: int = 50) -> None:
        """Upsert ``points`` into the collection in batches, printing progress.

        Args:
            points: Points to write; may be empty.
            batch_size: Maximum number of points sent per ``upsert`` request.

        Raises:
            ValueError: If ``batch_size`` is not positive. Without this check,
                zero would fail on division and a negative value would
                silently skip every point.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")

        total = len(points)
        # Ceiling division: a partial final batch still counts as a batch.
        num_batches = (total + batch_size - 1) // batch_size
        print(f"Upserting {total} points in {num_batches} batches...")

        for batch_start in range(0, total, batch_size):
            batch = points[batch_start:batch_start + batch_size]
            self.client.upsert(collection_name=self.collection_name, points=batch)
            print(f"Upserted {len(batch)} points...")

        print(f"Upserted {total} points!")