Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions vicinity/datatypes.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from collections.abc import Iterable
from enum import Enum
from pathlib import Path
from typing import TypeVar

from numpy import typing as npt

T = TypeVar("T")

PathLike = str | Path
Matrix = npt.NDArray | list[npt.NDArray]
SimilarityItem = list[tuple[str, float]]
SimilarityResult = list[SimilarityItem]
SimilarityItem = list[tuple[T, float]]
SimilarityResult = list[list[tuple[T, float]]]
# Result of a single query: a tuple of two arrays, (indices, distances)
SingleQueryResult = tuple[npt.NDArray, npt.NDArray]
QueryResult = list[SingleQueryResult]
Expand Down
32 changes: 17 additions & 15 deletions vicinity/vicinity.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from collections.abc import Iterable, Sequence
from pathlib import Path
from time import perf_counter
from typing import Any
from typing import Any, Generic

import numpy as np
import orjson
Expand All @@ -15,12 +15,12 @@

from vicinity import Metric
from vicinity.backends import AbstractBackend, BasicBackend, BasicVectorStore, get_backend_class
from vicinity.datatypes import Backend, PathLike, SimilarityResult
from vicinity.datatypes import Backend, PathLike, SimilarityResult, T

logger = logging.getLogger(__name__)


class Vicinity:
class Vicinity(Generic[T]):
"""
Work with vector representations of items.

Expand All @@ -30,7 +30,7 @@ class Vicinity:

def __init__(
self,
items: Sequence[Any],
items: Sequence[T],
backend: AbstractBackend,
metadata: dict[str, Any] | None = None,
vector_store: BasicVectorStore | None = None,
Expand All @@ -50,7 +50,7 @@ def __init__(
raise ValueError(
"Your vector space and list of items are not the same length: " f"{len(backend)} != {len(items)}"
)
self.items: list[Any] = list(items)
self.items: list[T] = list(items)
self.backend: AbstractBackend = backend
self.metadata = metadata or {}
self.vector_store = vector_store
Expand All @@ -73,13 +73,13 @@ def __len__(self) -> int:

@classmethod
def from_vectors_and_items(
cls: type[Vicinity],
cls: type[Vicinity[T]],
vectors: npt.NDArray,
items: Sequence[Any],
items: Sequence[T],
backend_type: Backend | str = Backend.BASIC,
store_vectors: bool = False,
**kwargs: Any,
) -> Vicinity:
) -> Vicinity[T]:
"""
Create a Vicinity instance from vectors and items.

Expand Down Expand Up @@ -115,7 +115,7 @@ def query(
self,
vectors: npt.NDArray,
k: int = 10,
) -> SimilarityResult:
) -> SimilarityResult[T]:
"""
Find the nearest neighbors to some arbitrary vector.

Expand All @@ -142,7 +142,7 @@ def query_threshold(
vectors: npt.NDArray,
threshold: float = 0.5,
max_k: int = 100,
) -> SimilarityResult:
) -> SimilarityResult[T]:
"""
Find the nearest neighbors to some arbitrary vector with some threshold. Note: the output is not sorted.

Expand Down Expand Up @@ -178,7 +178,9 @@ def save(
:param folder: The path to which to save the JSON file. The vectors are saved separately. The JSON contains a path to the numpy file.
:param overwrite: Whether to overwrite the JSON and numpy files if they already exist.
:raises ValueError: If the path is not a directory.
:raises JSONEncodeError: If the items are not serializable.
:raises JSONEncodeError: If the items are not JSON-serializable. ``save()`` and ``load()``
only support item types that orjson can encode (e.g. strings, numbers, dicts).
Use ``Vicinity[str]`` or another serializable type if you need persistence.
"""
path = Path(folder)
path.mkdir(parents=True, exist_ok=overwrite)
Expand All @@ -200,7 +202,7 @@ def save(
self.vector_store.save(store_path)

@classmethod
def load(cls, filename: PathLike) -> Vicinity:
def load(cls, filename: PathLike) -> Vicinity[Any]:
"""
Load a Vicinity instance in fast format.

Expand Down Expand Up @@ -231,7 +233,7 @@ def load(cls, filename: PathLike) -> Vicinity:

return instance

def insert(self, tokens: Sequence[Any], vectors: npt.NDArray) -> None:
def insert(self, tokens: Sequence[T], vectors: npt.NDArray) -> None:
"""
Insert new items into the vector space.

Expand All @@ -250,7 +252,7 @@ def insert(self, tokens: Sequence[Any], vectors: npt.NDArray) -> None:
if self.vector_store is not None:
self.vector_store.insert(vectors)

def delete(self, tokens: Sequence[Any]) -> None:
def delete(self, tokens: Sequence[T]) -> None:
"""
Delete tokens from the vector space.

Expand Down Expand Up @@ -309,7 +311,7 @@ def push_to_hub(
)

@classmethod
def load_from_hub(cls: type[Vicinity], repo_id: str, token: str | None = None, **kwargs: Any) -> Vicinity:
def load_from_hub(cls: type[Vicinity[Any]], repo_id: str, token: str | None = None, **kwargs: Any) -> Vicinity[Any]:
"""
Load a Vicinity instance from the Hugging Face Hub.

Expand Down
Loading