feat(vector): Automatic indexing of documents in s3 storage
This commit is contained in:
57
vector/qdrant.py
Normal file
57
vector/qdrant.py
Normal file
@@ -0,0 +1,57 @@
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import PointStruct, VectorParams, Distance
|
||||
|
||||
from .chunk import Chunk
|
||||
|
||||
class Qdrant:
|
||||
def __init__(self, host: str, port: int, collection_name: str ) -> None:
|
||||
self.client = QdrantClient(host=host, port=port)
|
||||
self.collection_name = collection_name
|
||||
|
||||
def create_collection(self) -> None:
|
||||
self.client.create_collection(
|
||||
collection_name=self.collection_name,
|
||||
vectors_config=VectorParams(
|
||||
size=4096,
|
||||
distance=Distance.COSINE,
|
||||
),
|
||||
)
|
||||
|
||||
def create_if_not_exists_collection(self) -> None:
|
||||
if not self.client.collection_exists(collection_name=self.collection_name):
|
||||
self.create_collection()
|
||||
|
||||
def delete_collection(self) -> None:
|
||||
self.client.delete_collection(collection_name=self.collection_name)
|
||||
|
||||
def create_points(self, chunks: list[Chunk], bucket: str, object_name: str) -> list[PointStruct]:
|
||||
points = []
|
||||
for chunk in chunks:
|
||||
points.append(self.create_point(chunk, bucket, object_name))
|
||||
return points
|
||||
|
||||
def create_point(self, chunk: Chunk, bucket: str, object_name: str) -> PointStruct:
|
||||
if not chunk.has_embedding or chunk.embedding is None:
|
||||
raise ValueError("Chunk has no embedding")
|
||||
embedding: list[float] = chunk.embedding
|
||||
point = PointStruct(
|
||||
id=chunk.id,
|
||||
vector=embedding,
|
||||
payload={
|
||||
"text": chunk.text,
|
||||
"bucket": bucket,
|
||||
"object": object_name,
|
||||
"id": chunk.id,
|
||||
"chunk_size": chunk.size
|
||||
}
|
||||
)
|
||||
return point
|
||||
|
||||
def upsert_points(self, points: list[PointStruct], batch_size: int = 50) -> None:
|
||||
num_batches = (len(points) + batch_size - 1) // batch_size
|
||||
print(f"Upserting {len(points)} points in {num_batches} batches...")
|
||||
for batch_start in range(0, len(points), batch_size):
|
||||
batch = points[batch_start:batch_start + batch_size]
|
||||
self.client.upsert(collection_name=self.collection_name, points=batch)
|
||||
print(f"Upserted {len(batch)} points...")
|
||||
print(f"Upserted {len(points)} points!")
|
||||
Reference in New Issue
Block a user