diff --git a/vicinity/datatypes.py b/vicinity/datatypes.py index 8871609..f09db5e 100644 --- a/vicinity/datatypes.py +++ b/vicinity/datatypes.py @@ -1,13 +1,16 @@ from collections.abc import Iterable from enum import Enum from pathlib import Path +from typing import TypeVar from numpy import typing as npt +T = TypeVar("T") + PathLike = str | Path Matrix = npt.NDArray | list[npt.NDArray] -SimilarityItem = list[tuple[str, float]] -SimilarityResult = list[SimilarityItem] +SimilarityItem = list[tuple[T, float]] +SimilarityResult = list[list[tuple[T, float]]] # Tuple of (indices, distances) SingleQueryResult = tuple[npt.NDArray, npt.NDArray] QueryResult = list[SingleQueryResult] diff --git a/vicinity/vicinity.py b/vicinity/vicinity.py index 6330ff1..b1334e0 100644 --- a/vicinity/vicinity.py +++ b/vicinity/vicinity.py @@ -6,7 +6,7 @@ from collections.abc import Iterable, Sequence from pathlib import Path from time import perf_counter -from typing import Any +from typing import Any, Generic import numpy as np import orjson @@ -15,12 +15,12 @@ from vicinity import Metric from vicinity.backends import AbstractBackend, BasicBackend, BasicVectorStore, get_backend_class -from vicinity.datatypes import Backend, PathLike, SimilarityResult +from vicinity.datatypes import Backend, PathLike, SimilarityResult, T logger = logging.getLogger(__name__) -class Vicinity: +class Vicinity(Generic[T]): """ Work with vector representations of items. @@ -30,7 +30,7 @@ class Vicinity: def __init__( self, - items: Sequence[Any], + items: Sequence[T], backend: AbstractBackend, metadata: dict[str, Any] | None = None, vector_store: BasicVectorStore | None = None, @@ -50,7 +50,7 @@ def __init__( raise ValueError( "Your vector space and list of items are not the same length: " f"{len(backend)} != {len(items)}" ) - self.items: list[Any] = list(items) + self.items: list[T] = list(items) self.backend: AbstractBackend = backend self.metadata = metadata or {} self.vector_store = vector_store @@ -73,13 +73,13 @@ def __len__(self) -> int: @classmethod def from_vectors_and_items( - cls: type[Vicinity], + cls: type[Vicinity[T]], vectors: npt.NDArray, - items: Sequence[Any], + items: Sequence[T], backend_type: Backend | str = Backend.BASIC, store_vectors: bool = False, **kwargs: Any, - ) -> Vicinity: + ) -> Vicinity[T]: """ Create a Vicinity instance from vectors and items. @@ -115,7 +115,7 @@ def query( self, vectors: npt.NDArray, k: int = 10, - ) -> SimilarityResult: + ) -> SimilarityResult[T]: """ Find the nearest neighbors to some arbitrary vector. @@ -142,7 +142,7 @@ def query_threshold( vectors: npt.NDArray, threshold: float = 0.5, max_k: int = 100, - ) -> SimilarityResult: + ) -> SimilarityResult[T]: """ Find the nearest neighbors to some arbitrary vector with some threshold. Note: the output is not sorted. @@ -178,7 +178,9 @@ def save( :param folder: The path to which to save the JSON file. The vectors are saved separately. The JSON contains a path to the numpy file. :param overwrite: Whether to overwrite the JSON and numpy files if they already exist. :raises ValueError: If the path is not a directory. - :raises JSONEncodeError: If the items are not serializable. + :raises JSONEncodeError: If the items are not JSON-serializable. ``save()`` and ``load()`` + only support item types that orjson can encode (e.g. strings, numbers, dicts). + Use ``Vicinity[str]`` or another serializable type if you need persistence. """ path = Path(folder) path.mkdir(parents=True, exist_ok=overwrite) @@ -200,7 +202,7 @@ def save( self.vector_store.save(store_path) @classmethod - def load(cls, filename: PathLike) -> Vicinity: + def load(cls, filename: PathLike) -> Vicinity[Any]: """ Load a Vicinity instance in fast format. @@ -231,7 +233,7 @@ def load(cls, filename: PathLike) -> Vicinity: return instance - def insert(self, tokens: Sequence[Any], vectors: npt.NDArray) -> None: + def insert(self, tokens: Sequence[T], vectors: npt.NDArray) -> None: """ Insert new items into the vector space. @@ -250,7 +252,7 @@ def insert(self, tokens: Sequence[Any], vectors: npt.NDArray) -> None: if self.vector_store is not None: self.vector_store.insert(vectors) - def delete(self, tokens: Sequence[Any]) -> None: + def delete(self, tokens: Sequence[T]) -> None: """ Delete tokens from the vector space. @@ -309,7 +311,7 @@ def push_to_hub( ) @classmethod - def load_from_hub(cls: type[Vicinity], repo_id: str, token: str | None = None, **kwargs: Any) -> Vicinity: + def load_from_hub(cls: type[Vicinity[Any]], repo_id: str, token: str | None = None, **kwargs: Any) -> Vicinity[Any]: """ Load a Vicinity instance from the Hugging Face Hub.