# Semantic IDs with FAISS
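This tutorial explains how to generate semantic IDs using FAISS (Facebook AI Similarity Search), a library for efficient similarity search and clustering of dense vectors. The approach applies k-means level by level: each level clusters the residuals left over from the previous one, turning every item embedding into a short sequence of discrete cluster codes.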
## Overview
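FAISS-based semantic ID generation works by:

1. Clustering item embeddings with hierarchical (residual) k-means, where each level clusters what the previous level's centroids did not explain
2. Assigning each item a path through the cluster hierarchy, one cluster index per level
3. Using that cluster path as the item's semantic ID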
## Prerequisites
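Install FAISS (the example usage further down also needs `sentence-transformers`):

```bash
uv pip install faiss-cpu  # or faiss-gpu for GPU support
uv pip install sentence-transformers  # only needed for the example usage
```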
## Generating Semantic IDs
```python
import faiss
import numpy as np


def generate_faiss_semantic_ids(
    embeddings: np.ndarray,
    n_clusters_per_level: int = 256,
    n_levels: int = 3,
) -> np.ndarray:
    """
    Generate hierarchical semantic IDs using FAISS k-means.

    Args:
        embeddings: Item embeddings of shape [n_items, embedding_dim]
        n_clusters_per_level: Number of clusters at each level
        n_levels: Depth of the hierarchy

    Returns:
        semantic_ids: Array of shape [n_items, n_levels]
    """
    # FAISS needs a C-contiguous float32 array; copy so that the in-place
    # normalization below does not modify the caller's array
    current_embeddings = np.array(embeddings, dtype=np.float32, order="C", copy=True)
    n_items, dim = current_embeddings.shape
    if n_items < n_clusters_per_level:
        raise ValueError("k-means needs at least n_clusters_per_level items")
    semantic_ids = np.zeros((n_items, n_levels), dtype=np.int32)

    # Normalize embeddings for cosine similarity (in place)
    faiss.normalize_L2(current_embeddings)

    for level in range(n_levels):
        # Fit k-means on the current (residual) embeddings
        kmeans = faiss.Kmeans(
            d=dim,
            k=n_clusters_per_level,
            niter=20,
            verbose=True,
            gpu=False,  # Set to True for GPU (requires faiss-gpu)
        )
        kmeans.train(current_embeddings)

        # Assign each item to its nearest centroid
        _, assignments = kmeans.index.search(current_embeddings, 1)
        assignments = assignments.ravel()
        semantic_ids[:, level] = assignments

        # Compute residuals for the next level
        centroids = kmeans.centroids[assignments]
        current_embeddings = current_embeddings - centroids
        # Re-normalizing residuals is a choice of this recipe; standard
        # residual quantization leaves them unnormalized
        faiss.normalize_L2(current_embeddings)

    return semantic_ids
```
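Nothing in k-means guarantees that paths are unique: items with near-identical embeddings (duplicates, for instance) land in the same cluster at every level and share a semantic ID. A common remedy in generative-retrieval work is to append one extra code that counts the items sharing a path. The sketch below, with the hypothetical helper `add_collision_suffix`, does this in plain Python; for the FAISS output you would pass `semantic_ids.tolist()`.

```python
from collections import defaultdict


def add_collision_suffix(semantic_ids):
    """Append a per-path counter so every item gets a unique code tuple."""
    seen = defaultdict(int)
    unique_ids = []
    for sid in semantic_ids:
        key = tuple(sid)
        # First item with this path gets suffix 0, the next gets 1, and so on
        unique_ids.append(key + (seen[key],))
        seen[key] += 1
    return unique_ids


ids = [[23, 45, 12], [23, 45, 12], [5, 1, 9]]
print(add_collision_suffix(ids))
# [(23, 45, 12, 0), (23, 45, 12, 1), (5, 1, 9, 0)]
```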
## Example Usage
```python
from sentence_transformers import SentenceTransformer

# Load item descriptions and generate embeddings
model = SentenceTransformer("all-MiniLM-L6-v2")
# "..." stands for the rest of your catalog; k-means needs at least
# n_clusters_per_level items
item_descriptions = ["Action movie", "Comedy film", "Drama series", ...]
embeddings = model.encode(item_descriptions)  # NumPy array [n_items, 384]

# Generate semantic IDs
semantic_ids = generate_faiss_semantic_ids(
    embeddings,
    n_clusters_per_level=64,
    n_levels=4,
)

# semantic_ids[i] gives the hierarchical cluster path for item i
print(f"Item 0 semantic ID: {semantic_ids[0].tolist()}")
# Example output (values depend on the data and k-means initialization):
# Item 0 semantic ID: [23, 45, 12, 7]
```
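Because later codes refine earlier ones, a natural way to use the IDs is to group items by a shared prefix, for example to find candidates in the same coarse cluster. The following plain-Python sketch uses a hypothetical helper, `group_by_prefix`, and made-up IDs; with the FAISS output you would pass `semantic_ids.tolist()`.

```python
from collections import defaultdict


def group_by_prefix(semantic_ids, prefix_len):
    """Map each length-prefix_len code prefix to the indices of its items."""
    groups = defaultdict(list)
    for item_idx, sid in enumerate(semantic_ids):
        groups[tuple(sid[:prefix_len])].append(item_idx)
    return dict(groups)


ids = [[23, 45, 12, 7], [23, 45, 3, 1], [23, 8, 0, 2], [4, 45, 12, 7]]
print(group_by_prefix(ids, 2))
# {(23, 45): [0, 1], (23, 8): [2], (4, 45): [3]}
```

Items 0 and 3 share their last three codes but not the first, so they fall into different groups: the later codes describe residuals, which say most about similarity among items that already share the earlier codes.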
## Comparison with RQ-VAE
| Aspect | FAISS | RQ-VAE |
|---|---|---|
| Training | No neural-network training; only k-means fitting | Requires training an encoder, decoder, and codebooks |
| Flexibility | Fixed clustering of fixed embeddings | Learnable representations |
| Speed | Very fast | Moderate |
| Quality | Good for grouping similar items | Better semantic capture |
## Use Cases
- Fast prototyping: when you need semantic IDs quickly without training a model
- Large-scale datasets: FAISS k-means scales well to millions of items
- Baseline comparison: compare against learned methods such as RQ-VAE
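When using FAISS IDs as a baseline, two simple diagnostics can be computed identically for FAISS and RQ-VAE IDs: the fraction of each level's codes that are actually used, and the fraction of items with a unique ID. The sketch below assumes IDs as a list of equal-length lists (e.g. `semantic_ids.tolist()`); the helper names are hypothetical, and neither metric measures downstream retrieval quality.

```python
def codebook_utilization(semantic_ids, n_clusters_per_level):
    """Fraction of the available codes that appear at each level."""
    n_levels = len(semantic_ids[0])
    return [
        len({sid[level] for sid in semantic_ids}) / n_clusters_per_level
        for level in range(n_levels)
    ]


def unique_id_rate(semantic_ids):
    """Fraction of items whose full code path is not shared with another item."""
    return len({tuple(sid) for sid in semantic_ids}) / len(semantic_ids)


ids = [[0, 1], [0, 2], [1, 1], [0, 1]]
print(codebook_utilization(ids, 4))  # [0.5, 0.5]
print(unique_id_rate(ids))           # 0.75
```

Low utilization at a level means many centroids go unused (a known failure mode for learned codebooks), while a low unique-ID rate signals that collision handling is needed.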
## See Also
Semantic IDs with RQ-VAE: Learned semantic IDs with RQ-VAE