
Sparse Embeddings for Neural Search.


SparsEmbed - Splade

Neural search

This repository presents an unofficial replication of the Splade and SparseEmbed models, which are state-of-the-art models in information retrieval.

Note: This project is a work in progress. The Splade model is ready to use, but I'm still working on SparseEmbed. 🔨🧹

Installation

We can install sparsembed using:

pip install sparsembed

If we plan to evaluate our model during training, install the eval extra:

pip install "sparsembed[eval]"

Retriever

Splade

We can initialize a Splade retriever directly from the splade_v2_max checkpoint available on HuggingFace. Retrievers are based on PyTorch sparse matrices, which are stored in memory and can be accelerated on a GPU. We can reduce the number of activated tokens via the k_tokens parameter to lower the memory footprint of these sparse matrices.

from sparsembed import model, retrieve
from transformers import AutoModelForMaskedLM, AutoTokenizer

device = "cuda" # cpu

batch_size = 10

# List documents to index:
documents = [
 {'id': 0,
  'title': 'Paris',
  'url': 'https://en.wikipedia.org/wiki/Paris',
  'text': 'Paris is the capital and most populous city of France.'},
 {'id': 1,
  'title': 'Paris',
  'url': 'https://en.wikipedia.org/wiki/Paris',
  'text': "Since the 17th century, Paris has been one of Europe's major centres of science, and arts."},
 {'id': 2,
  'title': 'Paris',
  'url': 'https://en.wikipedia.org/wiki/Paris',
  'text': 'The City of Paris is the centre and seat of government of the region and province of Île-de-France.'
}]

model = model.Splade(
    model=AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device),
    tokenizer=AutoTokenizer.from_pretrained("naver/splade_v2_max"),
    device=device
)

retriever = retrieve.SpladeRetriever(
    key="id", # Key identifier of each document.
    on=["title", "text"], # Fields to search.
    model=model # Splade model.
)

retriever = retriever.add(
    documents=documents,
    batch_size=batch_size,
    k_tokens=256, # Number of activated tokens.
)

retriever(
    ["paris", "Toulouse"], # Queries 
    k_tokens=20, # Maximum number of activated tokens.
    k=100, # Number of documents to retrieve.
    batch_size=batch_size
)
[[{'id': 0, 'similarity': 11.481657981872559},
  {'id': 2, 'similarity': 11.294965744018555},
  {'id': 1, 'similarity': 10.059721946716309}],
 [{'id': 0, 'similarity': 0.7379149198532104},
  {'id': 2, 'similarity': 0.6973429918289185},
  {'id': 1, 'similarity': 0.5428210496902466}]]
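To see why k_tokens controls memory and what the similarity values measure, here is a minimal pure-Python sketch of sparse retrieval. It assumes each text is already encoded as a dictionary mapping tokens to activation weights and scores documents with a sparse dot product. The names prune, dot and search are hypothetical and not part of sparsembed, which relies on PyTorch sparse matrices instead.

```python
from typing import Dict, List, Union

# Hypothetical encoding: token -> activation weight (only non-zero entries stored).
SparseVec = Dict[str, float]


def prune(vec: SparseVec, k_tokens: int) -> SparseVec:
    """Keep only the k_tokens highest-weighted tokens."""
    top = sorted(vec.items(), key=lambda kv: kv[1], reverse=True)[:k_tokens]
    return dict(top)


def dot(query: SparseVec, doc: SparseVec) -> float:
    """Sparse dot product: only tokens active in both vectors contribute."""
    return sum(weight * doc[token] for token, weight in query.items() if token in doc)


def search(
    query: SparseVec, index: Dict[int, SparseVec], k: int, k_tokens: int
) -> List[Dict[str, Union[int, float]]]:
    """Rank indexed documents by similarity to the pruned query."""
    q = prune(query, k_tokens)
    scored = [{"id": doc_id, "similarity": dot(q, vec)} for doc_id, vec in index.items()]
    scored.sort(key=lambda hit: hit["similarity"], reverse=True)
    return scored[:k]


# Documents are pruned once at indexing time, queries at search time.
index = {
    0: prune({"paris": 2.0, "capital": 1.5, "france": 1.2, "city": 0.4}, k_tokens=3),
    1: prune({"paris": 1.0, "science": 1.3, "arts": 0.9}, k_tokens=3),
}
results = search({"paris": 1.5, "france": 0.5, "toulouse": 0.1}, index, k=10, k_tokens=2)
print(results)
```

Only tokens active in both the query and the document contribute to the score, so keeping fewer tokens shrinks the index at the cost of possibly dropping weak matches.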

SparsEmbed

We can also initialize a retriever dedicated to the SparseEmbed model. The naver/splade_v2_max checkpoint is not a trained SparseEmbed model, so we should train one before using it as a retriever.

from sparsembed import model, retrieve
from transformers import AutoModelForMaskedLM, AutoTokenizer

device = "cuda" # cpu

batch_size = 10

# List documents to index:
documents = [
 {'id': 0,
  'title': 'Paris',
  'url': 'https://en.wikipedia.org/wiki/Paris',
  'text': 'Paris is the capital and most populous city of France.'},
 {'id': 1,
  'title': 'Paris',
  'url': 'https://en.wikipedia.org/wiki/Paris',
  'text': "Since the 17th century, Paris has been one of Europe's major centres of science, and arts."},
 {'id': 2,
  'title': 'Paris',
  'url': 'https://en.wikipedia.org/wiki/Paris',
  'text': 'The City of Paris is the centre and seat of government of the region and province of Île-de-France.'
}]

model = model.SparsEmbed(
    model=AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device),
    tokenizer=AutoTokenizer.from_pretrained("naver/splade_v2_max"),
    device=device
)

retriever = retrieve.SparsEmbedRetriever(
    key="id", # Key identifier of each document.
    on=["title", "text"], # Fields to search.
    model=model # SparsEmbed model.
)

retriever = retriever.add(
    documents=documents,
    batch_size=batch_size,
    k_tokens=256, # Number of activated tokens.
)

retriever(
    ["paris", "Toulouse"], # Queries 
    k_tokens=20, # Maximum number of activated tokens.
    k=100, # Number of documents to retrieve.
    batch_size=batch_size
)
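Unlike Splade, SparseEmbed attaches a small contextual embedding to each activated token. Assuming the retriever follows the scoring described in the SparseEmbed paper, a query and a document are compared by taking the tokens activated in both and summing the dot products of their embeddings. The sketch below illustrates that idea in pure Python; sparsembed_score and the toy two-dimensional embeddings are hypothetical (real models use larger embeddings).

```python
from typing import Dict, List

# Hypothetical representation: activated token -> contextual embedding.
Embeddings = Dict[str, List[float]]


def sparsembed_score(query: Embeddings, doc: Embeddings) -> float:
    """Sum of embedding dot products over tokens activated in both texts."""
    shared = query.keys() & doc.keys()
    return sum(
        sum(a * b for a, b in zip(query[token], doc[token]))
        for token in shared
    )


query = {"paris": [0.5, 1.0], "capital": [1.0, 0.0]}
doc_a = {"paris": [1.0, 1.0], "capital": [2.0, 0.5], "france": [0.3, 0.3]}
doc_b = {"toulouse": [1.0, 1.0], "france": [0.5, 0.5]}

print(sparsembed_score(query, doc_a))  # shares "paris" and "capital"
print(sparsembed_score(query, doc_b))  # no shared token, score is 0
```

The lexical overlap still decides which tokens can match, while the embeddings decide how strongly they match.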

Training

Let's fine-tune Splade and SparsEmbed.

Dataset

Your training dataset must be made of triples (anchor, positive, negative), where the anchor is a query, the positive is a document directly relevant to the anchor, and the negative is a document that is not relevant to the anchor.

X = [
    ("anchor 1", "positive 1", "negative 1"),
    ("anchor 2", "positive 2", "negative 2"),
    ("anchor 3", "positive 3", "negative 3"),
]
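The training loops below hand X to utils.iter and unpack every step into anchor, positive and negative. Assuming each step yields three aligned lists holding one batch of queries, positives and negatives, a hypothetical pure-Python equivalent (iter_triples is not a sparsembed function) could look like this:

```python
import random
from typing import Iterator, List, Sequence, Tuple

Triple = Tuple[str, str, str]


def iter_triples(
    X: Sequence[Triple], epochs: int, batch_size: int, shuffle: bool = True, seed: int = 0
) -> Iterator[Tuple[List[str], List[str], List[str]]]:
    """Yield (anchors, positives, negatives) batches over several epochs."""
    rng = random.Random(seed)
    for _ in range(epochs):
        order = list(range(len(X)))
        if shuffle:
            rng.shuffle(order)  # new order at every epoch
        for start in range(0, len(order), batch_size):
            batch = [X[i] for i in order[start:start + batch_size]]
            # Transpose the list of triples into three aligned columns.
            anchors, positives, negatives = (list(column) for column in zip(*batch))
            yield anchors, positives, negatives


X = [
    ("anchor 1", "positive 1", "negative 1"),
    ("anchor 2", "positive 2", "negative 2"),
    ("anchor 3", "positive 3", "negative 3"),
]

for anchor, positive, negative in iter_triples(X, epochs=1, batch_size=2):
    print(anchor, positive, negative)
```

The last batch of each epoch may be smaller than batch_size when the dataset size is not a multiple of it.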

Models

Both Splade and SparseEmbed models can be initialized from pretrained AutoModelForMaskedLM checkpoints.

from sparsembed import model
from transformers import AutoModelForMaskedLM, AutoTokenizer

device = "cuda" # or "cpu"

# Splade model.
splade = model.Splade(
    model=AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device),
    tokenizer=AutoTokenizer.from_pretrained("naver/splade_v2_max"),
    device=device,
)

# SparsEmbed model (distinct variable name so the `model` module is not shadowed).
sparsembed = model.SparsEmbed(
    model=AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device),
    tokenizer=AutoTokenizer.from_pretrained("naver/splade_v2_max"),
    embedding_size=64, # Size of the embedding attached to each activated token.
    k_tokens=96, # Number of activated tokens.
    device=device,
)
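Starting from AutoModelForMaskedLM matters because SPLADE-style models turn the masked-language-model logits into vocabulary weights. The sketch below follows the SPLADE-max aggregation from the SPLADE v2 paper, log(1 + ReLU(logit)) max-pooled over input positions, then keeps the k_tokens largest weights. It works on plain lists for illustration; splade_max is hypothetical, and whether sparsembed's Splade and SparsEmbed classes pool exactly this way is an assumption.

```python
import math
from typing import Dict, List


def splade_max(logits: List[List[float]], attention_mask: List[int], k_tokens: int) -> Dict[int, float]:
    """Sparse vector {vocab_id: weight} from per-position MLM logits.

    logits[i][j] is the MLM logit of vocabulary entry j at input position i.
    """
    vocab_size = len(logits[0])
    weights = [0.0] * vocab_size
    for position, row in enumerate(logits):
        if not attention_mask[position]:
            continue  # padding positions do not contribute
        for j, logit in enumerate(row):
            # Log-saturated ReLU, then max-pooling over positions.
            weights[j] = max(weights[j], math.log1p(max(logit, 0.0)))
    # Keep the k_tokens highest weights and drop zeros.
    top = sorted(range(vocab_size), key=lambda j: weights[j], reverse=True)[:k_tokens]
    return {j: weights[j] for j in top if weights[j] > 0}


logits = [
    [2.0, -1.0, 0.5, 0.0],  # position 0
    [0.0, 3.0, -2.0, 0.0],  # position 1
    [9.0, 9.0, 9.0, 9.0],   # position 2 (padding, masked out)
]
vec = splade_max(logits, attention_mask=[1, 1, 0], k_tokens=2)
print(vec)
```

The logarithm dampens very large logits, and the ReLU guarantees that most vocabulary entries stay at exactly zero.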

Splade

The following PyTorch code snippet illustrates the training loop to fine-tune Splade:

from transformers import AutoModelForMaskedLM, AutoTokenizer, optimization
from sparsembed import model, utils, train, retrieve, losses
import torch

device = "cuda" # cpu or cuda
batch_size = 8
epochs = 1 # Number of times the model will train over the whole dataset.

model = model.Splade(
    model=AutoModelForMaskedLM.from_pretrained("naver/splade_v2_max").to(device),
    tokenizer=AutoTokenizer.from_pretrained("naver/splade_v2_max"),
    device=device
)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

scheduler = optimization.get_linear_schedule_with_warmup(
    optimizer=optimizer, 
    num_warmup_steps=6000,
    num_training_steps=4_000_000,
)

flops_scheduler = losses.FlopsScheduler(weight=1e-4, steps=50_000) 

X = [
    ("anchor 1", "positive 1", "negative 1"),
    ("anchor 2", "positive 2", "negative 2"),
    ("anchor 3", "positive 3", "negative 3"),
]

for anchor, positive, negative in utils.iter(
        X,
        epochs=epochs,
        batch_size=batch_size,
        shuffle=True
    ):
        loss = train.train_splade(
            model=model,
            optimizer=optimizer,
            anchor=anchor,
            positive=positive,
            negative=negative,
            flops_loss_weight=flops_scheduler.get(),
        )

        scheduler.step()
        flops_scheduler.step()

# Save the model.
model.save_pretrained("checkpoint")

# Beir benchmark for evaluation.
documents, queries, qrels = utils.load_beir("scifact", split="test")

retriever = retrieve.SpladeRetriever(
    key="id",
    on=["title", "text"],
    model=model
)

retriever = retriever.add(
    documents=documents,
    batch_size=batch_size,
    k_tokens=96,
)

utils.evaluate(
    retriever=retriever,
    batch_size=batch_size,
    qrels=qrels,
    queries=queries,
    k=100,
    k_tokens=96,
    metrics=["map", "ndcg@10", "recall@10", "hits@10"]
)
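The flops_loss_weight argument weights a FLOPS regularizer that pushes activations toward sparsity. Here is a minimal sketch, assuming the standard FLOPS definition used by SPLADE (the sum over vocabulary entries of the squared mean activation across the batch) and assuming FlopsScheduler ramps the weight up quadratically over `steps` before holding it constant, as the SPLADE paper describes. QuadraticWarmup is a hypothetical stand-in, not sparsembed's class.

```python
from typing import List


def flops_loss(batch: List[List[float]]) -> float:
    """Sum over vocabulary entries of the squared mean absolute activation."""
    n = len(batch)
    vocab_size = len(batch[0])
    return sum((sum(abs(row[j]) for row in batch) / n) ** 2 for j in range(vocab_size))


class QuadraticWarmup:
    """Hypothetical scheduler: weight * (step / steps) ** 2, capped at weight."""

    def __init__(self, weight: float = 1e-4, steps: int = 50_000):
        self.weight = weight
        self.steps = steps
        self.step_count = 0

    def get(self) -> float:
        return self.weight * min(1.0, self.step_count / self.steps) ** 2

    def step(self) -> None:
        self.step_count += 1


activations = [[1.0, 0.0, 2.0], [3.0, 0.0, 0.0]]  # 2 texts, 3 vocabulary entries
print(flops_loss(activations))

scheduler = QuadraticWarmup(weight=1e-4, steps=50_000)
print(scheduler.get())  # 0.0 before the first step
```

Starting with a small weight lets the model first learn useful matches before sparsity is enforced at full strength.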

After saving the model with save_pretrained, we can reload the checkpoint with:

from sparsembed import model

device = "cuda"

model = model.Splade(
    model_name_or_path="checkpoint",
    device=device,
)

SparsEmbed

The following PyTorch code snippet illustrates the training loop to fine-tune SparseEmbed:

from transformers import AutoModelForMaskedLM, AutoTokenizer, optimization
from sparsembed import model, utils, train, retrieve, losses
import torch

device = "cuda" # cpu or cuda
batch_size = 8
epochs = 1 # Number of times the model will train over the whole dataset.

model = model.SparsEmbed(
    model=AutoModelForMaskedLM.from_pretrained("distilbert-base-uncased").to(device),
    tokenizer=AutoTokenizer.from_pretrained("distilbert-base-uncased"),
    device=device
)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

scheduler = optimization.get_linear_schedule_with_warmup(
    optimizer=optimizer, 
    num_warmup_steps=6000, # Number of warmup steps.
    num_training_steps=4_000_000 # Total number of training steps.
)

flops_scheduler = losses.FlopsScheduler(weight=1e-4, steps=50_000)

X = [
    ("anchor 1", "positive 1", "negative 1"),
    ("anchor 2", "positive 2", "negative 2"),
    ("anchor 3", "positive 3", "negative 3"),
]

for anchor, positive, negative in utils.iter(
        X,
        epochs=epochs,
        batch_size=batch_size,
        shuffle=True
    ):
        loss = train.train_sparsembed(
            model=model,
            optimizer=optimizer,
            k_tokens=96,
            anchor=anchor,
            positive=positive,
            negative=negative,
            flops_loss_weight=flops_scheduler.get(),
            sparse_loss_weight=0.1,
        )

        scheduler.step()
        flops_scheduler.step()

# Save the model.
model.save_pretrained("checkpoint")

# Beir benchmark for evaluation.
documents, queries, qrels = utils.load_beir("scifact", split="test")

retriever = retrieve.SparsEmbedRetriever(
    key="id",
    on=["title", "text"],
    model=model
)

retriever = retriever.add(
    documents=documents,
    k_tokens=96,
    batch_size=batch_size
)

utils.evaluate(
    retriever=retriever,
    batch_size=batch_size,
    qrels=qrels,
    queries=queries,
    k=100,
    k_tokens=96,
    metrics=["map", "ndcg@10", "recall@10", "hits@10"]
)
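utils.evaluate reports ranking metrics against the BEIR relevance judgments. For reference, here is a sketch of hits@k, recall@k and nDCG@k for a single query, assuming qrels in the usual BEIR layout {query_id: {doc_id: relevance}}. The helper names are hypothetical, and sparsembed may delegate these computations to an evaluation library with slightly different conventions.

```python
import math
from typing import Dict, List

Qrels = Dict[str, int]  # doc_id -> relevance, for one query


def hits_at_k(ranked: List[str], relevant: Qrels, k: int) -> float:
    """1.0 if at least one relevant document appears in the top k, else 0.0."""
    return 1.0 if any(relevant.get(doc_id, 0) > 0 for doc_id in ranked[:k]) else 0.0


def recall_at_k(ranked: List[str], relevant: Qrels, k: int) -> float:
    """Fraction of the relevant documents retrieved in the top k."""
    positives = {doc_id for doc_id, rel in relevant.items() if rel > 0}
    if not positives:
        return 0.0
    return len(positives.intersection(ranked[:k])) / len(positives)


def ndcg_at_k(ranked: List[str], relevant: Qrels, k: int) -> float:
    """Normalized discounted cumulative gain with linear gains."""
    dcg = sum(relevant.get(doc_id, 0) / math.log2(rank + 2) for rank, doc_id in enumerate(ranked[:k]))
    ideal = sorted((rel for rel in relevant.values() if rel > 0), reverse=True)[:k]
    idcg = sum(rel / math.log2(rank + 2) for rank, rel in enumerate(ideal))
    return dcg / idcg if idcg > 0 else 0.0


ranked = ["d3", "d1", "d7"]    # retriever output for one query
relevant = {"d1": 1, "d2": 1}  # qrels for that query
print(hits_at_k(ranked, relevant, 10), recall_at_k(ranked, relevant, 10), ndcg_at_k(ranked, relevant, 10))
```

Averaging each metric over all queries gives the dataset-level figure. With binary labels, the linear-gain nDCG used here coincides with the 2**rel - 1 variant.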

After saving the model with save_pretrained, we can reload the checkpoint with:

from sparsembed import model

device = "cuda"

model = model.SparsEmbed(
    model_name_or_path="checkpoint",
    device=device,
)

Utils

We can get the activated tokens or embeddings of a sentence with:

model.encode(["deep learning, information retrieval, sparse models"])

We can evaluate similarities between pairs of queries and documents without the use of a retriever:

model.scores(
    queries=["Query A", "Query B"], 
    documents=["Document A", "Document B"],
    batch_size=32,
)
tensor([5.1449, 9.1194])
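The output above holds one score per position, which suggests that model.scores pairs queries[i] with documents[i] rather than scoring every query against every document. Assuming that reading, and assuming sparse vectors stored as token-to-weight dictionaries, the computation can be sketched with a hypothetical pairwise_scores helper:

```python
from typing import Dict, List

SparseVec = Dict[str, float]  # hypothetical: token -> activation weight


def pairwise_scores(queries: List[SparseVec], documents: List[SparseVec]) -> List[float]:
    """Score queries[i] against documents[i] only, one score per pair."""
    if len(queries) != len(documents):
        raise ValueError("queries and documents must form aligned pairs")
    return [
        sum(weight * doc.get(token, 0.0) for token, weight in query.items())
        for query, doc in zip(queries, documents)
    ]


queries = [{"a": 1.0, "b": 2.0}, {"c": 1.0}]
documents = [{"b": 0.5}, {"c": 3.0, "a": 1.0}]
print(pairwise_scores(queries, documents))
```

Scoring aligned pairs costs one dot product per pair, which is handy for re-ranking or inspecting labeled pairs without building an index.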

We can visualize the activated tokens:

model.decode(**model.encode(["deep learning, information retrieval, sparse models"]))
['deep sparse model retrieval information models depth fuzzy learning dense poor memory recall processing reading lacy include remember knowledge training heavy retrieve guide vague type small learn data']
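Decoding maps activated vocabulary ids back to tokenizer strings. The output above appears ordered by decreasing weight. Assuming that is the case, here is a toy sketch with a hypothetical five-entry vocabulary (this decode function is illustrative, not the library method):

```python
from typing import Dict, List


def decode(weights: Dict[int, float], vocab: List[str]) -> str:
    """Join activated tokens, highest weight first, skipping zero weights."""
    ranked = sorted(weights.items(), key=lambda kv: kv[1], reverse=True)
    return " ".join(vocab[token_id] for token_id, weight in ranked if weight > 0)


vocab = ["[PAD]", "deep", "learning", "sparse", "model"]
weights = {3: 2.1, 1: 2.5, 4: 1.0, 0: 0.0}
print(decode(weights, vocab))
```

Reading the decoded tokens is a quick way to check for expansion terms, such as "retrieve" or "memory" in the output above, that never appear in the input sentence.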

Benchmarks

Work in progress.


Source Distribution

sparsembed-0.1.1.tar.gz (19.0 kB)

