diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml new file mode 100644 index 0000000..7d4ab9b --- /dev/null +++ b/.github/workflows/pr.yml @@ -0,0 +1,35 @@ +name: PR Checks + +on: + pull_request: + branches: + - main + +concurrency: + group: pr-${{ github.event.pull_request.number || github.ref }} + cancel-in-progress: true + +jobs: + checks: + runs-on: [self-hosted, linux, no-gpu] + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Install dependencies + run: | + python -m pip install --upgrade pip setuptools wheel + python -m pip install -e . + python -m pip install pytest pre-commit + + - name: Lint and format + run: pre-commit run --all-files --show-diff-on-failure + + - name: Run unit tests + run: python -m pytest -q -m "not integration" diff --git a/.gitignore b/.gitignore index 8a933d1..eca98cc 100644 --- a/.gitignore +++ b/.gitignore @@ -36,6 +36,16 @@ Thumbs.db # Sensitive local data keys/ +VECTORSTORE.md + +# External symlinks (local workspace references) +es2-msa +es2-msa/ +es2-deploy +es2-deploy/ + +# Local helper scripts +run_unit_tests.py # Jupyter .ipynb_checkpoints/ diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..00caf48 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,10 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.7.1 + hooks: + - id: ruff + - repo: https://github.com/psf/black + rev: 24.10.0 + hooks: + - id: black + language_version: python3.11 diff --git a/README.md b/README.md index 518deb7..d6b522a 100644 --- a/README.md +++ b/README.md @@ -13,12 +13,12 @@ Encrypted vector search for LangChain using Envector (ES2), powered by homomorph - `python3.11 -m venv .venv && source .venv/bin/activate` - Install runtime dependencies: - `pip install -U pip setuptools wheel` - - `pip install es2==1.1.0 langchain sentence-transformers` + - `pip install es2 langchain sentence-transformers` ## Usage Overview 1. Configure Envector using `EnvectorConfig`, pointing to your ES2 endpoint and keys. 2. Initialize embeddings (or provide pre-computed vectors). -3. Instantiate `Envector(config=cfg, embeddings=emb)` and call `add_texts` or `as_retriever`. +3. Instantiate `Envector(config=cfg, embeddings=emb)` and call `add_texts`, `add_documents`, or use `as_retriever`. 4. Run `similarity_search` or plug the retriever into your LangChain pipeline. > See `notebooks/` for end-to-end walkthroughs and the `libs/envector` package for implementation details. @@ -27,7 +27,7 @@ Encrypted vector search for LangChain using Envector (ES2), powered by homomorph Key dataclasses live in `libs/envector/config.py`: - `ConnectionConfig`: address or host/port for ES2. - `KeyConfig`: key path, key ID, optional preset/eval mode. -- `IndexSettings`: index name, dimension (16–4096), query encryption mode, optional output fields and fetch parameters. +- `IndexSettings`: index name, dimension (32–4096), query encryption mode, optional output fields and fetch parameters. - `EnvectorConfig`: wraps the above and enables auto-creation via `create_if_missing`. ## Data Model @@ -41,10 +41,66 @@ Key dataclasses live in `libs/envector/config.py`: - Manual item IDs are not accepted; returned IDs from `add_texts` are ephemeral. - Filtering happens client-side; ensure metadata is JSON for structured filters. +## Examples +- Configuration + ```python + from langchain_envector.config import ConnectionConfig, EnvectorConfig, IndexSettings, KeyConfig + + cfg = EnvectorConfig( + connection=ConnectionConfig( + address=ES2_ADDRESS, + access_token=ES2_ACCESS_TOKEN + ), + key=KeyConfig( + key_path=ES2_KEY_PATH, + key_id=ES2_KEY_ID, + preset="ip", + eval_mode="rmp" + ), + index=IndexSettings( + index_name=INDEX_NAME, + dim=vector_dim, + query_encryption="cipher" + ), + create_if_missing=True, + ) + ``` + +- Add documents (from LangChain Documents): + + ```python + from langchain_core.documents import Document + from langchain_envector.vectorstore import Envector + + docs = [ + Document( + page_content="chunk-1", + metadata={"source": "paper.pdf", "page": 1, "chunk": 0} + ), + Document( + page_content="chunk-2", + metadata={"source": "paper.pdf", "page": 1, "chunk": 1} + ), + ] + + store = Envector(config=cfg, embeddings=emb) + store.add_documents(docs) + ``` + ## Troubleshooting - Connection issues: verify ES2 address and registered keys. - Embeddings mismatch: ensure embedding dimension equals `index.dim` when supplying vectors. - Unexpected raw strings: confirm inserts used the JSON envelope. +- Key Issues: check key's metadata to sync with the registered key if facing any key issue. + +## Testing Without ES2 +- Run unit tests offline (no ES2 or SDK required): + - `python -m pytest -q -m "not integration"` + - or `python scripts/run_unit_tests.py` +- Run integration tests (requires server and keys): + - Export `ES2_ADDRESS`, `ES2_KEY_PATH`, `ES2_KEY_ID` + - Optional: `ES2_USE_EMBEDDINGS=1`, `ES2_EMB_MODEL`, `ES2_USE_HF_DATASET=1` + - `python -m pytest -q -m integration -s` ## Contributing See [`CONTRIBUTE.md`](CONTRIBUTE.md) for development, testing, and PR guidelines. diff --git a/libs/envector/examples/basic_usage.py b/libs/envector/examples/basic_usage.py index c0824ea..d70b61b 100644 --- a/libs/envector/examples/basic_usage.py +++ b/libs/envector/examples/basic_usage.py @@ -8,7 +8,12 @@ from __future__ import annotations -from libs.envector.config import ConnectionConfig, EnvectorConfig, IndexSettings, KeyConfig +from libs.envector.config import ( + ConnectionConfig, + EnvectorConfig, + IndexSettings, + KeyConfig, +) from libs.envector.vectorstore import Envector @@ -16,7 +21,9 @@ def main(): # Replace with your actual settings cfg = EnvectorConfig( connection=ConnectionConfig(address="localhost:50050"), - key=KeyConfig(key_path="./keys", key_id="example_key", preset="ip", eval_mode="rmp"), + key=KeyConfig( + key_path="./keys", key_id="example_key", preset="ip", eval_mode="rmp" + ), index=IndexSettings(index_name="demo", dim=384, query_encryption="plain"), create_if_missing=True, ) @@ -43,4 +50,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/libs/envector/examples/cipher_query.py b/libs/envector/examples/cipher_query.py index d7defc3..64fe174 100644 --- a/libs/envector/examples/cipher_query.py +++ b/libs/envector/examples/cipher_query.py @@ -6,15 +6,24 @@ from __future__ import annotations -from libs.envector.config import ConnectionConfig, EnvectorConfig, IndexSettings, KeyConfig +from libs.envector.config import ( + ConnectionConfig, + EnvectorConfig, + IndexSettings, + KeyConfig, +) from libs.envector.vectorstore import Envector def main(): cfg = EnvectorConfig( connection=ConnectionConfig(address="localhost:50050"), - key=KeyConfig(key_path="./keys", key_id="example_key", preset="ip", eval_mode="rmp"), - index=IndexSettings(index_name="demo_cipher", dim=384, query_encryption="cipher"), + key=KeyConfig( + key_path="./keys", key_id="example_key", preset="ip", eval_mode="rmp" + ), + index=IndexSettings( + index_name="demo_cipher", dim=384, query_encryption="cipher" + ), create_if_missing=True, ) @@ -38,4 +47,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/libs/envector/examples/ingest_synthetic_1k.py b/libs/envector/examples/ingest_synthetic_1k.py index dbb0e31..8b055c8 100644 --- a/libs/envector/examples/ingest_synthetic_1k.py +++ b/libs/envector/examples/ingest_synthetic_1k.py @@ -19,7 +19,12 @@ from pathlib import Path from typing import List -from libs.envector.config import ConnectionConfig, EnvectorConfig, IndexSettings, KeyConfig +from libs.envector.config import ( + ConnectionConfig, + EnvectorConfig, + IndexSettings, + KeyConfig, +) from libs.envector.vectorstore import Envector @@ -34,7 +39,12 @@ def main(): ap.add_argument("--key-path", required=True) ap.add_argument("--key-id", required=True) ap.add_argument("--index-name", required=True) - ap.add_argument("--dim", type=int, required=False, help="If omitted and --use-embeddings, infer from model.") + ap.add_argument( + "--dim", + type=int, + required=False, + help="If omitted and --use-embeddings, infer from model.", + ) ap.add_argument("--dataset", default="data/synthetic_rag_1k.jsonl") ap.add_argument("--use-embeddings", action="store_true") ap.add_argument("--model", default="sentence-transformers/all-MiniLM-L6-v2") @@ -52,7 +62,9 @@ def main(): cfg = EnvectorConfig( connection=ConnectionConfig(address=args.address), - key=KeyConfig(key_path=args.key_path, key_id=args.key_id, preset="ip", eval_mode="rmp"), + key=KeyConfig( + key_path=args.key_path, key_id=args.key_id, preset="ip", eval_mode="rmp" + ), index=IndexSettings( index_name=args.index_name, dim=(args.dim if args.dim is not None else inferred_dim or 0), @@ -76,7 +88,9 @@ def main(): if embeddings is None: # Without embeddings, require manual vectors; here we simply skip. # Users should provide --use-embeddings or adapt to their vector source. - raise ValueError("--use-embeddings is required unless you provide vectors explicitly.") + raise ValueError( + "--use-embeddings is required unless you provide vectors explicitly." + ) store.add_texts(t_batch, metadatas=m_batch) print(f"Inserted {len(texts)} documents into index '{args.index_name}'") diff --git a/libs/envector/langchain_envector/__init__.py b/libs/envector/langchain_envector/__init__.py index a92cf2b..cf7a9b6 100644 --- a/libs/envector/langchain_envector/__init__.py +++ b/libs/envector/langchain_envector/__init__.py @@ -7,5 +7,10 @@ from .vectorstore import Envector from .config import ConnectionConfig, EnvectorConfig, IndexSettings, KeyConfig -__all__ = ["Envector", "ConnectionConfig", "EnvectorConfig", "IndexSettings", "KeyConfig"] - +__all__ = [ + "Envector", + "ConnectionConfig", + "EnvectorConfig", + "IndexSettings", + "KeyConfig", +] diff --git a/libs/envector/langchain_envector/client.py b/libs/envector/langchain_envector/client.py index c8b2952..c3ac2f9 100644 --- a/libs/envector/langchain_envector/client.py +++ b/libs/envector/langchain_envector/client.py @@ -1,7 +1,5 @@ from __future__ import annotations -from typing import Optional - from .config import EnvectorConfig @@ -34,7 +32,9 @@ def init(self): else: if not (c.host and c.port): raise ValueError("Either address or host+port must be provided.") - es2_client.init_connect(host=c.host, port=c.port, access_token=c.access_token) + es2_client.init_connect( + host=c.host, port=c.port, access_token=c.access_token + ) # Key path baseline for Index from es2.index import Index as _Index @@ -79,4 +79,3 @@ def es2(self): if self._es2 is None: raise RuntimeError("Client not initialized. Call init().") return self._es2 - diff --git a/libs/envector/langchain_envector/config.py b/libs/envector/langchain_envector/config.py index b6be7c0..62e5291 100644 --- a/libs/envector/langchain_envector/config.py +++ b/libs/envector/langchain_envector/config.py @@ -39,4 +39,3 @@ class EnvectorConfig: key: KeyConfig index: IndexSettings create_if_missing: bool = True - diff --git a/libs/envector/langchain_envector/retriever.py b/libs/envector/langchain_envector/retriever.py index ee578bc..255471c 100644 --- a/libs/envector/langchain_envector/retriever.py +++ b/libs/envector/langchain_envector/retriever.py @@ -12,7 +12,9 @@ class EnvectorRetriever: - def __init__(self, store: Envector, *, search_kwargs: Optional[Dict[str, Any]] = None) -> None: + def __init__( + self, store: Envector, *, search_kwargs: Optional[Dict[str, Any]] = None + ) -> None: self.store = store self.search_kwargs = search_kwargs or {} diff --git a/libs/envector/langchain_envector/types.py b/libs/envector/langchain_envector/types.py index e4a82ae..286a24c 100644 --- a/libs/envector/langchain_envector/types.py +++ b/libs/envector/langchain_envector/types.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass -from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union, overload +from typing import Any, Callable, Dict, List, Optional, Protocol class Embeddings(Protocol): @@ -10,10 +10,14 @@ class Embeddings(Protocol): LangChain-compatible embeddings typically implement these two methods. """ - def embed_documents(self, texts: List[str]) -> List[List[float]]: # pragma: no cover - interface only + def embed_documents( + self, texts: List[str] + ) -> List[List[float]]: # pragma: no cover - interface only ... - def embed_query(self, text: str) -> List[float]: # pragma: no cover - interface only + def embed_query( + self, text: str + ) -> List[float]: # pragma: no cover - interface only ... @@ -53,8 +57,6 @@ def unpack_metadata(raw: Any) -> Dict[str, Any]: if isinstance(raw, dict): return raw - print("slafjklshglkhslafhlksadjlghsal;hf") - # Some responses wrap the payload in a single-element list. if isinstance(raw, (list, tuple)): if len(raw) == 1: @@ -94,8 +96,13 @@ def unpack_metadata(raw: Any) -> Dict[str, Any]: # --- Embeddings adaptation helpers ----------------------------------------------------- + class _CallableEmbeddings: - def __init__(self, docs_fn: Callable[[List[str]], List[List[float]]], query_fn: Callable[[str], List[float]]): + def __init__( + self, + docs_fn: Callable[[List[str]], List[List[float]]], + query_fn: Callable[[str], List[float]], + ): self._docs_fn = docs_fn self._query_fn = query_fn @@ -132,7 +139,12 @@ def query_fn(text: str) -> List[float]: return _CallableEmbeddings(docs_fn, query_fn) # Case 3: Tuple of callables - if isinstance(emb, tuple) and len(emb) == 2 and callable(emb[0]) and callable(emb[1]): + if ( + isinstance(emb, tuple) + and len(emb) == 2 + and callable(emb[0]) + and callable(emb[1]) + ): docs_fn, query_fn = emb # type: ignore[assignment] return _CallableEmbeddings(docs_fn, query_fn) diff --git a/libs/envector/langchain_envector/vectorstore.py b/libs/envector/langchain_envector/vectorstore.py index b56d0a0..0350af9 100644 --- a/libs/envector/langchain_envector/vectorstore.py +++ b/libs/envector/langchain_envector/vectorstore.py @@ -1,9 +1,6 @@ from __future__ import annotations -import json -from typing import Any, Dict, Iterable, List, Optional, Sequence -from uuid import uuid4 - +from typing import Any, Dict, List, Optional from .config import EnvectorConfig from .client import EnvectorClient from .types import Embeddings, as_embeddings, pack_metadata, unpack_metadata @@ -21,12 +18,15 @@ def _try_import_langchain(): except Exception: # pragma: no cover - optional dependency # Minimal shim if LangChain is not installed class Document: # type: ignore - def __init__(self, page_content: str, metadata: Optional[Dict[str, Any]] = None): + def __init__( + self, page_content: str, metadata: Optional[Dict[str, Any]] = None + ): self.page_content = page_content self.metadata = metadata or {} try: from langchain_core.vectorstores import VectorStore as _VectorStore # type: ignore + VectorStoreBase = _VectorStore except Exception: # pragma: no cover - optional dependency pass @@ -119,9 +119,15 @@ def similarity_search( top_k = fetch_k or self.config.index.fetch_k or k - results = self.client.index.search(query=embedding, top_k=top_k, output_fields=self.config.index.output_fields) + results = self.client.index.search( + query=embedding, top_k=top_k, output_fields=self.config.index.output_fields + ) # ES2 Index.search returns a list for each query; we passed single query - result = results[0] if isinstance(results, list) and results and isinstance(results[0], list) else results + result = ( + results[0] + if isinstance(results, list) and results and isinstance(results[0], list) + else results + ) docs = [] # Iterate from top-1 to top-k @@ -129,7 +135,7 @@ def similarity_search( # item = {"id": ..., "score": float, "metadata": [str] or {...}} score = float(item.get("score", 0.0)) md_obj_raw = item.get("metadata") - + # Metadata encryption/decryption is handled by the SDK. # Envector currently supports a single associated data field (string). # Convention: if the string is JSON like {"text": str, "metadata": {...}}, @@ -148,7 +154,10 @@ def similarity_search( if score_threshold is not None and score < score_threshold: continue - doc = Document(page_content=text, metadata={**metadata, "_score": score, "_id": item.get("id")}) + doc = Document( + page_content=text, + metadata={**metadata, "_score": score, "_id": item.get("id")}, + ) docs.append(doc) # Trim to k after filtering @@ -178,6 +187,30 @@ def similarity_search_by_vector( # ------------------------------- # Class constructors (LangChain compatibility) # ------------------------------- + def add_documents( + self, + documents: List[Document], + ids: Optional[List[str]] = None, + *, + vectors: Optional[List[List[float]]] = None, + **kwargs: Any, + ) -> List[int]: + """Insert a list of Documents. + + Mirrors LangChain's VectorStore API. Delegates to `add_texts` by + extracting `page_content` and `metadata` from each Document. + + Notes: + - Manual `ids` are ignored (ES2 does not support user-provided IDs). + - When `embeddings` is not configured, you must supply `vectors`. + - Returns ephemeral IDs as produced by the client insert. + """ + texts = [getattr(d, "page_content", "") for d in documents] + metadatas = [getattr(d, "metadata", {}) for d in documents] + return self.add_texts( + texts=texts, metadatas=metadatas, ids=ids, vectors=vectors, **kwargs + ) + @classmethod def from_texts( cls, @@ -211,7 +244,9 @@ def from_documents( ) -> "Envector": # type: ignore[override] texts = [d.page_content for d in documents] metadatas = [getattr(d, "metadata", {}) for d in documents] - return cls.from_texts(texts=texts, metadatas=metadatas, embeddings=embeddings, **kwargs) + return cls.from_texts( + texts=texts, metadatas=metadatas, embeddings=embeddings, **kwargs + ) # Optional: if LangChain is installed, this will be used; otherwise, users may call similarity_search directly. def as_retriever(self, **kwargs: Any): # pragma: no cover - wrapper @@ -222,7 +257,9 @@ def as_retriever(self, **kwargs: Any): # pragma: no cover - wrapper except Exception: # Minimal shim if VectorStoreRetriever is unavailable class _Retriever: - def __init__(self, vs: Envector, search_kwargs: Optional[Dict[str, Any]] = None): + def __init__( + self, vs: Envector, search_kwargs: Optional[Dict[str, Any]] = None + ): self.vs = vs self.search_kwargs = search_kwargs or {} diff --git a/scripts/export_hf_dataset.py b/scripts/export_hf_dataset.py index d49f299..9dff61d 100644 --- a/scripts/export_hf_dataset.py +++ b/scripts/export_hf_dataset.py @@ -17,19 +17,26 @@ import argparse import json from pathlib import Path -from typing import List def main(): ap = argparse.ArgumentParser() ap.add_argument("--name", required=True, help="HF dataset name, e.g., ag_news") - ap.add_argument("--subset", default=None, help="Optional subset/config of the dataset") + ap.add_argument( + "--subset", default=None, help="Optional subset/config of the dataset" + ) ap.add_argument("--split", default="train") ap.add_argument("--text-column", required=True) - ap.add_argument("--meta-columns", nargs="*", default=[], help="Optional metadata columns to carry over") + ap.add_argument( + "--meta-columns", + nargs="*", + default=[], + help="Optional metadata columns to carry over", + ) ap.add_argument("--size", type=int, default=1000) ap.add_argument("--seed", type=int, default=42) ap.add_argument("--out", default="data/hf_export.jsonl") + ap.add_argument("--cache-dir", default=None, help="Optional HF datasets cache dir") args = ap.parse_args() try: @@ -37,9 +44,14 @@ def main(): except Exception as e: # pragma: no cover - env dependent raise SystemExit(f"Install 'datasets' package to use this script: {e}") - ds = load_dataset(args.name, args.subset, split=args.split) - if args.size and args.size < len(ds): - ds = ds.shuffle(seed=args.seed).select(range(args.size)) + ds = load_dataset( + args.name, + args.subset, + split=args.split, + cache_dir=args.cache_dir, + streaming=True, + ) + ds = ds.shuffle(seed=args.seed).take(args.size) out_path = Path(args.out) out_path.parent.mkdir(parents=True, exist_ok=True) @@ -47,7 +59,9 @@ def main(): with out_path.open("w", encoding="utf-8") as f: for row in ds: text = row[args.text_column] - meta = {k: row.get(k) for k in args.meta_columns} if args.meta_columns else {} + meta = ( + {k: row.get(k) for k in args.meta_columns} if args.meta_columns else {} + ) rec = {"text": text, "metadata": meta} f.write(json.dumps(rec, ensure_ascii=False) + "\n") @@ -56,4 +70,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/scripts/make_synthetic_rag_dataset.py b/scripts/make_synthetic_rag_dataset.py index 8badf1c..b9f4b61 100644 --- a/scripts/make_synthetic_rag_dataset.py +++ b/scripts/make_synthetic_rag_dataset.py @@ -11,7 +11,6 @@ import argparse import json -import os import random from pathlib import Path @@ -59,7 +58,7 @@ def make_sentence(topic: str) -> str: def make_paragraph(topic: str, min_sent: int = 3, max_sent: int = 7) -> str: n = random.randint(min_sent, max_sent) - return " " .join(make_sentence(topic) for _ in range(n)) + return " ".join(make_sentence(topic) for _ in range(n)) def main(): @@ -86,4 +85,3 @@ def main(): if __name__ == "__main__": main() - diff --git a/scripts/run_unit_tests.py b/scripts/run_unit_tests.py index ecd1cdc..489698b 100644 --- a/scripts/run_unit_tests.py +++ b/scripts/run_unit_tests.py @@ -4,6 +4,11 @@ import inspect import sys import traceback +from pathlib import Path + +REPO_ROOT = Path(__file__).resolve().parents[1] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) def run_module_tests(module_name: str) -> list[tuple[str, bool, str]]: @@ -43,4 +48,3 @@ def main() -> int: if __name__ == "__main__": raise SystemExit(main()) - diff --git a/tests/__init__.py b/tests/__init__.py index c10059a..070d470 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -13,4 +13,3 @@ pkg_path = str(_PKG_DIR) if pkg_path not in sys.path: sys.path.insert(0, pkg_path) - diff --git a/tests/conftest.py b/tests/conftest.py index 47172b6..fe7954e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -24,8 +24,7 @@ class FakeIndex: def insert(self, data: List[List[float]], metadata: List[str]): self.inserted.append({"data": data, "metadata": metadata}) - batch_idx = len(self.inserted) - 1 - return [len(self.inserted)+i+1 for i in range(len(metadata))] + return [len(self.inserted) + i + 1 for i in range(len(metadata))] def search(self, query: List[float], top_k: int, output_fields: List[str]): if self.search_payload is not None: diff --git a/tests/integration/test_es2_integration.py b/tests/integration/test_es2_integration.py index c43661c..03aeb81 100644 --- a/tests/integration/test_es2_integration.py +++ b/tests/integration/test_es2_integration.py @@ -5,7 +5,12 @@ import time import pytest -from langchain_envector.config import ConnectionConfig, EnvectorConfig, IndexSettings, KeyConfig +from langchain_envector.config import ( + ConnectionConfig, + EnvectorConfig, + IndexSettings, + KeyConfig, +) from langchain_envector.vectorstore import Envector @@ -33,13 +38,17 @@ def test_e2e_vectorstore_plain_and_cipher(): key_path = _require_env("ES2_KEY_PATH") key_id = _require_env("ES2_KEY_ID") use_emb = os.environ.get("ES2_USE_EMBEDDINGS") in {"1", "true", "TRUE", "yes"} - model_name = os.environ.get("ES2_EMB_MODEL", "sentence-transformers/all-MiniLM-L6-v2") + model_name = os.environ.get( + "ES2_EMB_MODEL", "sentence-transformers/all-MiniLM-L6-v2" + ) use_hf = os.environ.get("ES2_USE_HF_DATASET") in {"1", "true", "TRUE", "yes"} hf_name = os.environ.get("ES2_HF_NAME", "ag_news") hf_subset = os.environ.get("ES2_HF_SUBSET") hf_split = os.environ.get("ES2_HF_SPLIT", "train") hf_text_col = os.environ.get("ES2_HF_TEXT_COL", "text") - hf_meta_cols = [c for c in os.environ.get("ES2_HF_META_COLS", "label").split(",") if c] + hf_meta_cols = [ + c for c in os.environ.get("ES2_HF_META_COLS", "label").split(",") if c + ] hf_size = int(os.environ.get("ES2_HF_SIZE", "200")) hf_seed = int(os.environ.get("ES2_HF_SEED", "42")) @@ -64,14 +73,17 @@ def test_e2e_vectorstore_plain_and_cipher(): except Exception as e: pytest.skip(f"Embeddings requested but unavailable: {e}") else: - dim = int(dim_env or "16") + dim = int(dim_env or "32") - if dim < 16 or dim > 4096: - pytest.skip("Envector supports dimensions in [16, 4096]") + if dim < 32 or dim > 4096: + pytest.skip("Envector supports dimensions in [32, 4096]") - base_index_name = os.environ.get("ES2_INDEX_NAME", f"inttest_{secrets.token_hex(4)}") + base_index_name = os.environ.get( + "ES2_INDEX_NAME", f"inttest_{secrets.token_hex(4)}" + ) import es2 + es2.init_connect(address=address) es2.reset() @@ -79,7 +91,9 @@ def test_e2e_vectorstore_plain_and_cipher(): cfg_plain = EnvectorConfig( connection=ConnectionConfig(address=address), key=KeyConfig(key_path=key_path, key_id=key_id, preset="ip", eval_mode="rmp"), - index=IndexSettings(index_name=f"{base_index_name}_plain", dim=dim, query_encryption="plain"), + index=IndexSettings( + index_name=f"{base_index_name}_plain", dim=dim, query_encryption="plain" + ), create_if_missing=True, ) store_plain = Envector(config=cfg_plain, embeddings=(emb if use_emb else None)) @@ -93,14 +107,14 @@ def test_e2e_vectorstore_plain_and_cipher(): if hf_size and hf_size < len(ds): ds = ds.shuffle(seed=hf_seed).select(range(hf_size)) texts = [row[hf_text_col] for row in ds] - metas = [ - {k: row.get(k) for k in hf_meta_cols if k in row} - for row in ds - ] + metas = [{k: row.get(k) for k in hf_meta_cols if k in row} for row in ds] print(texts[0]) print(metas[0]) else: - texts = ["machine learning accelerates research", "cooking recipes are delicious"] + texts = [ + "machine learning accelerates research", + "cooking recipes are delicious", + ] metas = [{"label": "A"}, {"label": "B"}] if use_emb: @@ -120,21 +134,34 @@ def test_e2e_vectorstore_plain_and_cipher(): docs = store_plain.similarity_search(q1, k=3) print("[plain] top-3 results for:", q1) for d in docs: - print(" - score=", d.metadata.get("_score"), "text=", (d.page_content[:80] + ("..." if len(d.page_content) > 80 else ""))) + print( + " - score=", + d.metadata.get("_score"), + "text=", + (d.page_content[:80] + ("..." if len(d.page_content) > 80 else "")), + ) assert len(docs) >= 1 assert all("_id" in d.metadata for d in docs) # optional filter check if 'label' is part of meta if not use_hf: - docs_f = store_plain.similarity_search("cooking", k=2, filter={"label": "B"}) + docs_f = store_plain.similarity_search( + "cooking", k=2, filter={"label": "B"} + ) print("[plain] filtered results (label=B):", [d.metadata for d in docs_f]) - assert len(docs_f) >= 1 and all(d.metadata.get("label") == "B" for d in docs_f) + assert len(docs_f) >= 1 and all( + d.metadata.get("label") == "B" for d in docs_f + ) else: # Using explicit embeddings docs = store_plain.similarity_search("q", k=2, embedding=e1) - print("[plain] results (explicit embedding e1):", [d.page_content for d in docs]) + print( + "[plain] results (explicit embedding e1):", [d.page_content for d in docs] + ) assert any(d.page_content == texts[0] for d in docs) assert all("_id" in d.metadata for d in docs) - docs_f = store_plain.similarity_search("q", k=2, embedding=e2, filter={"label": "B"}) + docs_f = store_plain.similarity_search( + "q", k=2, embedding=e2, filter={"label": "B"} + ) print("[plain] filtered (e2, label=B):", [d.page_content for d in docs_f]) assert len(docs_f) >= 1 assert docs_f[0].page_content == texts[1] @@ -143,7 +170,9 @@ def test_e2e_vectorstore_plain_and_cipher(): cfg_cc = EnvectorConfig( connection=ConnectionConfig(address=address), key=KeyConfig(key_path=key_path, key_id=key_id, preset="ip", eval_mode="rmp"), - index=IndexSettings(index_name=f"{base_index_name}_cipher", dim=dim, query_encryption="cipher"), + index=IndexSettings( + index_name=f"{base_index_name}_cipher", dim=dim, query_encryption="cipher" + ), create_if_missing=True, ) store_cc = Envector(config=cfg_cc, embeddings=(emb if use_emb else None)) @@ -158,12 +187,20 @@ def test_e2e_vectorstore_plain_and_cipher(): docs_cc = store_cc.similarity_search(q2, k=3) print("[cipher] top-3 results for:", q2) for d in docs_cc: - print(" - score=", d.metadata.get("_score"), "text=", (d.page_content[:80] + ("..." if len(d.page_content) > 80 else ""))) + print( + " - score=", + d.metadata.get("_score"), + "text=", + (d.page_content[:80] + ("..." if len(d.page_content) > 80 else "")), + ) assert len(docs_cc) >= 1 assert all("_id" in d.metadata for d in docs_cc) else: docs_cc = store_cc.similarity_search("q", k=2, embedding=e2) - print("[cipher] results (explicit embedding e2):", [d.page_content for d in docs_cc]) + print( + "[cipher] results (explicit embedding e2):", + [d.page_content for d in docs_cc], + ) assert any(d.page_content == texts[1] for d in docs_cc) assert all("_id" in d.metadata for d in docs_cc) diff --git a/tests/test_types.py b/tests/test_types.py index c85c371..65cd546 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -1,7 +1,5 @@ from __future__ import annotations -import json - from langchain_envector.types import pack_metadata, unpack_metadata diff --git a/tests/test_vectorstore.py b/tests/test_vectorstore.py index a1bb8c7..b572ee5 100644 --- a/tests/test_vectorstore.py +++ b/tests/test_vectorstore.py @@ -1,9 +1,12 @@ from __future__ import annotations -import re - -from langchain_envector.config import ConnectionConfig, EnvectorConfig, IndexSettings, KeyConfig -from langchain_envector.vectorstore import Envector +from langchain_envector.config import ( + ConnectionConfig, + EnvectorConfig, + IndexSettings, + KeyConfig, +) +from langchain_envector.vectorstore import Envector, Document as LC_Document from .conftest import FakeClient, FakeEmbeddings, FakeIndex @@ -20,7 +23,9 @@ def test_add_texts_ignores_ids_and_returns_item_ids(): client = FakeClient() store = Envector(config=_cfg(), embeddings=FakeEmbeddings(dim=4), client=client) - ret_ids = store.add_texts(["t1", "t2"], metadatas=[{"m": 1}, {"m": 2}], ids=["a", "b"]) # ids ignored + ret_ids = store.add_texts( + ["t1", "t2"], metadatas=[{"m": 1}, {"m": 2}], ids=["a", "b"] + ) # ids ignored # Returned IDs assert len(ret_ids) == 2 @@ -30,20 +35,32 @@ def test_add_texts_ignores_ids_and_returns_item_ids(): assert len(client.index.inserted) == 1 packed = client.index.inserted[0]["metadata"] assert len(packed) == 2 - assert "\"id\"" not in packed[0] + assert '"id"' not in packed[0] def test_similarity_search_with_filter_and_threshold(): index = FakeIndex() # Two items, different scores and tags - index.search_payload = [[ - {"id": "pos-0", "score": 0.95, "metadata": "{\"text\": \"A\", \"metadata\": {\"tag\": \"keep\"}}"}, - {"id": "pos-1", "score": 0.40, "metadata": "{\"text\": \"B\", \"metadata\": {\"tag\": \"drop\"}}"}, - ]] + index.search_payload = [ + [ + { + "id": "pos-0", + "score": 0.95, + "metadata": '{"text": "A", "metadata": {"tag": "keep"}}', + }, + { + "id": "pos-1", + "score": 0.40, + "metadata": '{"text": "B", "metadata": {"tag": "drop"}}', + }, + ] + ] client = FakeClient(index) store = Envector(config=_cfg(), embeddings=FakeEmbeddings(dim=4), client=client) - docs = store.similarity_search("q", k=5, filter={"tag": "keep"}, score_threshold=0.5) + docs = store.similarity_search( + "q", k=5, filter={"tag": "keep"}, score_threshold=0.5 + ) assert len(docs) == 1 assert docs[0].page_content == "A" assert docs[0].metadata["_score"] >= 0.5 @@ -52,9 +69,15 @@ def test_similarity_search_with_filter_and_threshold(): def test_similarity_search_handles_string_metadata(): index = FakeIndex() # metadata returned as a single JSON string instead of list - index.search_payload = [[ - {"id": "pos-0", "score": 0.8, "metadata": "{\"text\": \"S\", \"metadata\": {\"t\": 1}}"}, - ]] + index.search_payload = [ + [ + { + "id": "pos-0", + "score": 0.8, + "metadata": '{"text": "S", "metadata": {"t": 1}}', + }, + ] + ] client = FakeClient(index) store = Envector(config=_cfg(), embeddings=FakeEmbeddings(dim=4), client=client) @@ -67,9 +90,15 @@ def test_similarity_search_handles_string_metadata(): def test_similarity_search_uses_raw_text_when_not_json(): index = FakeIndex() # metadata is a plain string (not JSON); should be treated as page_content - index.search_payload = [[ - {"id": "pos-raw", "score": 0.6, "metadata": "Plain text content without JSON"}, - ]] + index.search_payload = [ + [ + { + "id": "pos-raw", + "score": 0.6, + "metadata": "Plain text content without JSON", + }, + ] + ] client = FakeClient(index) store = Envector(config=_cfg(), embeddings=FakeEmbeddings(dim=4), client=client) @@ -77,15 +106,19 @@ def test_similarity_search_uses_raw_text_when_not_json(): assert len(docs) == 1 assert docs[0].page_content == "Plain text content without JSON" # user metadata should be empty dict when not provided - assert all(k in docs[0].metadata for k in ["_score", "_id"]) # only system fields present + assert all( + k in docs[0].metadata for k in ["_score", "_id"] + ) # only system fields present def test_similarity_search_handles_python_literal_metadata(): index = FakeIndex() literal = str({"text": "Literal", "metadata": {"tag": "py"}}) - index.search_payload = [[ - {"id": "pos-lit", "score": 0.7, "metadata": literal}, - ]] + index.search_payload = [ + [ + {"id": "pos-lit", "score": 0.7, "metadata": literal}, + ] + ] client = FakeClient(index) store = Envector(config=_cfg(), embeddings=FakeEmbeddings(dim=4), client=client) @@ -94,5 +127,111 @@ def test_similarity_search_handles_python_literal_metadata(): assert docs[0].page_content == "Literal" assert docs[0].metadata.get("tag") == "py" - # dict-type metadata is not supported currently; only text-based + + +def test_similarity_search_by_vector_with_filter_and_threshold(): + index = FakeIndex() + index.search_payload = [ + [ + { + "id": "v-0", + "score": 0.88, + "metadata": '{"text": "Keep", "metadata": {"k": 1}}', + }, + { + "id": "v-1", + "score": 0.30, + "metadata": '{"text": "Drop", "metadata": {"k": 2}}', + }, + ] + ] + client = FakeClient(index) + store = Envector(config=_cfg(), embeddings=FakeEmbeddings(dim=4), client=client) + + # Explicit vector search (bypasses embed_query), with filter/threshold + docs = store.similarity_search_by_vector( + [0.0, 0.0, 0.0, 0.0], k=5, filter={"k": 1}, score_threshold=0.5 + ) + assert len(docs) == 1 + assert docs[0].page_content == "Keep" + assert docs[0].metadata["_score"] >= 0.5 + + +def test_from_texts_inserts_using_embeddings(): + client = FakeClient() + store = Envector.from_texts( + ["A", "B"], + metadatas=[{"m": "a"}, {"m": "b"}], + embeddings=FakeEmbeddings(dim=4), + config=_cfg(), + client=client, + ) + assert isinstance(store, Envector) + # One batch inserted + assert len(client.index.inserted) == 1 + # Two items packed + assert len(client.index.inserted[0]["metadata"]) == 2 + + +def test_from_documents_paths_through_to_texts(): + client = FakeClient() + docs = [ + LC_Document(page_content="X", metadata={"a": 1}), + LC_Document(page_content="Y", metadata={"a": 2}), + ] + store = Envector.from_documents( + docs, embeddings=FakeEmbeddings(dim=4), config=_cfg(), client=client + ) + assert isinstance(store, Envector) + assert len(client.index.inserted) == 1 + packed = client.index.inserted[0]["metadata"] + # Texts preserved + assert any('"text": "X"' in m for m in packed) + assert any('"text": "Y"' in m for m in packed) + + +def test_add_documents_with_embeddings(): + client = FakeClient() + store = Envector(config=_cfg(), embeddings=FakeEmbeddings(dim=4), client=client) + + docs = [ + LC_Document(page_content="C1", metadata={"s": 1}), + LC_Document(page_content="C2", metadata={"s": 2}), + ] + ret = store.add_documents(docs) + assert len(ret) == 2 + assert len(client.index.inserted) == 1 + packed = client.index.inserted[0]["metadata"] + assert any('"text": "C1"' in m for m in packed) + assert any('"text": "C2"' in m for m in packed) + + +def test_add_documents_requires_vectors_when_no_embeddings(): + client = FakeClient() + store = Envector(config=_cfg(), embeddings=None, client=client) + docs = [LC_Document(page_content="C", metadata={})] + try: + store.add_documents(docs) + assert ( + False + ), "Expected ValueError when embeddings is None and no vectors provided" + except ValueError as e: + assert "embeddings is None and vectors not provided" in str(e) + + +def test_add_documents_with_explicit_vectors(): + client = FakeClient() + store = Envector(config=_cfg(), embeddings=None, client=client) + + docs = [ + LC_Document(page_content="V1", metadata={"k": "a"}), + LC_Document(page_content="V2", metadata={"k": "b"}), + ] + vecs = [ + [1.0, 0.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + ] + ret = store.add_documents(docs, vectors=vecs) + assert len(ret) == 2 + assert len(client.index.inserted) == 1