"""L2 descriptor normaliser aligning cosine similarity to FAISS inner-product (AZ-283). Public surface frozen by ``_docs/02_document/contracts/shared_helpers/descriptor_normaliser.md`` v1.0.0. Used on both the corpus side (C10 index build) and the query side (C2 runtime lookup). The two sides MUST go through the same helper so the FAISS HNSW search returns useful neighbours. """ from __future__ import annotations from typing import Final import numpy as np __all__ = [ "ALLOWED_DTYPES", "DescriptorNormaliser", "DescriptorNormaliserError", ] # Allowed input dtypes; anything else is rejected to keep the FAISS index and # query path on the same precision. ALLOWED_DTYPES: Final[tuple[np.dtype, ...]] = ( np.dtype(np.float16), np.dtype(np.float32), ) _METRIC_VALUE: Final[str] = "inner_product" class DescriptorNormaliserError(ValueError): """Raised on shape / dtype violations (AZ-283).""" def _validate_dtype(arr: np.ndarray, label: str) -> None: if arr.dtype not in ALLOWED_DTYPES: raise DescriptorNormaliserError( f"{label}: dtype {arr.dtype} not in allowed set (float16, float32)" ) class DescriptorNormaliser: """Stateless L2-normalisation helper; dtype-preserving; zero-norm safe.""" @staticmethod def l2_normalise(descriptor: np.ndarray) -> np.ndarray: if not isinstance(descriptor, np.ndarray): raise DescriptorNormaliserError( f"l2_normalise: expected np.ndarray; got {type(descriptor).__name__}" ) if descriptor.ndim != 1: raise DescriptorNormaliserError( f"l2_normalise: expected 1-D shape (D,); got shape {descriptor.shape}" ) if descriptor.shape[0] < 1: raise DescriptorNormaliserError( f"l2_normalise: dimension must be >= 1; got shape {descriptor.shape}" ) _validate_dtype(descriptor, "l2_normalise") in_dtype = descriptor.dtype # Compute norm in float32 to stabilise float16 inputs against overflow / # underflow; cast back to the caller dtype so we never silently up-cast. as_f32 = descriptor.astype(np.float32, copy=False) norm = float(np.linalg.norm(as_f32)) if norm == 0.0: return np.zeros_like(descriptor) normalised_f32 = as_f32 / norm return normalised_f32.astype(in_dtype, copy=False) @staticmethod def l2_normalise_batch(descriptors: np.ndarray) -> np.ndarray: if not isinstance(descriptors, np.ndarray): raise DescriptorNormaliserError( f"l2_normalise_batch: expected np.ndarray; got {type(descriptors).__name__}" ) if descriptors.ndim != 2: raise DescriptorNormaliserError( f"l2_normalise_batch: expected 2-D shape (N, D); got shape {descriptors.shape}" ) if descriptors.shape[0] < 1 or descriptors.shape[1] < 1: raise DescriptorNormaliserError( f"l2_normalise_batch: N and D must be >= 1; got shape {descriptors.shape}" ) _validate_dtype(descriptors, "l2_normalise_batch") in_dtype = descriptors.dtype as_f32 = descriptors.astype(np.float32, copy=False) norms = np.linalg.norm(as_f32, axis=1, keepdims=True) # Avoid division-by-zero: leave zero rows as zero. safe = np.where(norms == 0.0, 1.0, norms) normalised_f32 = np.where(norms == 0.0, 0.0, as_f32 / safe) return normalised_f32.astype(in_dtype, copy=False) @staticmethod def intra_cluster_normalise( descriptor: np.ndarray, num_clusters: int ) -> np.ndarray: """Per-cluster L2 normalisation for VLAD-aggregated descriptors (AZ-338). NetVLAD's published preprocessing chain L2-normalises each per-cluster sub-vector BEFORE the global L2 step. The input is a flat 1-D VLAD descriptor of shape ``(num_clusters * cluster_dim,)`` which is reshaped to ``(num_clusters, cluster_dim)``, normalised row-wise, then flattened back. ``num_clusters`` must divide ``descriptor.shape[0]``. Zero-norm sub-vectors are returned as zero (consistent with :meth:`l2_normalise`). """ if not isinstance(descriptor, np.ndarray): raise DescriptorNormaliserError( f"intra_cluster_normalise: expected np.ndarray; " f"got {type(descriptor).__name__}" ) if descriptor.ndim != 1: raise DescriptorNormaliserError( f"intra_cluster_normalise: expected 1-D shape (K*D,); " f"got shape {descriptor.shape}" ) if not isinstance(num_clusters, int) or isinstance(num_clusters, bool): raise DescriptorNormaliserError( f"intra_cluster_normalise: num_clusters must be a non-bool " f"int; got {num_clusters!r}" ) if num_clusters < 1: raise DescriptorNormaliserError( f"intra_cluster_normalise: num_clusters must be >= 1; " f"got {num_clusters}" ) total_dim = descriptor.shape[0] if total_dim % num_clusters != 0: raise DescriptorNormaliserError( f"intra_cluster_normalise: descriptor length {total_dim} " f"not divisible by num_clusters={num_clusters}" ) _validate_dtype(descriptor, "intra_cluster_normalise") in_dtype = descriptor.dtype cluster_dim = total_dim // num_clusters reshaped = descriptor.reshape(num_clusters, cluster_dim).astype( np.float32, copy=False ) norms = np.linalg.norm(reshaped, axis=1, keepdims=True) safe = np.where(norms == 0.0, 1.0, norms) normalised = np.where(norms == 0.0, 0.0, reshaped / safe) return normalised.reshape(total_dim).astype(in_dtype, copy=False) @staticmethod def descriptor_metric() -> str: return _METRIC_VALUE