Add support for visualizing self-attention heatmaps + sequence classifier outputs w/ attentions #1117

Merged: 14 commits merged into main from support-classifier-attentions on Dec 28, 2024
Conversation

xenova (Collaborator) commented Dec 26, 2024

This PR allows users to visualize per-layer, per-head attention maps. Based on this notebook by @NielsRogge.

Example usage:

```js
import { AutoProcessor, AutoModelForImageClassification, interpolate_4d, RawImage } from "@huggingface/transformers";

// Load model and processor
const model_id = "onnx-community/dinov2-with-registers-small-with-attentions";
const model = await AutoModelForImageClassification.from_pretrained(model_id);
const processor = await AutoProcessor.from_pretrained(model_id);

// Load image from URL
const image = await RawImage.read("https://huggingface.co/datasets/Xenova/transformers.js-docs/resolve/main/cats.jpg");

// Pre-process image
const inputs = await processor(image);

// Perform inference
const { logits, attentions } = await model(inputs);

// Get the predicted class
const cls = logits[0].argmax().item();
const label = model.config.id2label[cls];
console.log(`Predicted class: ${label}`);

// Set config values
const patch_size = model.config.patch_size;
// pixel_values has shape [batch, channels, height, width]
const [height, width] = inputs.pixel_values.dims.slice(-2);
const w_featmap = Math.floor(width / patch_size);
const h_featmap = Math.floor(height / patch_size);
const num_heads = model.config.num_attention_heads;
const num_cls_tokens = 1;
const num_register_tokens = model.config.num_register_tokens ?? 0;

// Visualize attention maps
const selected_attentions = attentions
    .at(-1) // we are only interested in the attention maps of the last layer
    .slice(0, null, 0, [num_cls_tokens + num_register_tokens, null]) // CLS-token query row; drop CLS + register keys
    .view(num_heads, 1, h_featmap, w_featmap);

const upscaled = await interpolate_4d(selected_attentions, {
    size: [height, width],
    mode: "nearest",
});

for (let i = 0; i < num_heads; ++i) {
    const head_attentions = upscaled[i];
    const minval = head_attentions.min().item();
    const maxval = head_attentions.max().item();
    // Min-max normalize to [0, 255] and convert to an 8-bit image
    const heatmap = RawImage.fromTensor(
        head_attentions
            .sub_(minval)
            .div_(maxval - minval)
            .mul_(255)
            .to("uint8"),
    );
    await heatmap.save(`attn-head-${i}.png`);
}
```
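
For context, `attentions` is a list with one tensor per layer, each of shape `[batch, num_heads, num_tokens, num_tokens]`; the snippet above keeps only the CLS-token query row of the final layer and drops the CLS and register key tokens before reshaping to the patch grid. A minimal sketch of the same selection for an arbitrary layer (the `layer_index` variable is hypothetical, not part of this PR):

```js
// Hypothetical variation: inspect an earlier layer instead of the last one.
// Assumes the same shapes as above: attentions[i] has shape
// [batch, num_heads, num_tokens, num_tokens] for layer i.
const layer_index = 3; // any value in [0, model.config.num_hidden_layers)
const layer_attentions = attentions[layer_index]
    .slice(0, null, 0, [num_cls_tokens + num_register_tokens, null])
    .view(num_heads, 1, h_featmap, w_featmap);
// `layer_attentions` can be upscaled with interpolate_4d exactly as above.
```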

Attention head heatmaps for the last layer:

[Images: attn-head-0 through attn-head-5]
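
The `"nearest"` mode preserves the blocky patch grid in these images; for smoother heatmaps, a bilinear upscale of the same per-head maps should also work (a sketch, assuming `interpolate_4d` accepts a `"bilinear"` mode, as it is used elsewhere in the library for mask post-processing):

```js
// Sketch: smoother upscaling of the same selected attention maps.
// Assumption: interpolate_4d supports mode: "bilinear".
const smooth = await interpolate_4d(selected_attentions, {
    size: [height, width],
    mode: "bilinear",
});
```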

@HuggingFaceDocBuilderDev commented:

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

xenova merged commit 41a6139 into main on Dec 28, 2024
4 checks passed
xenova deleted the support-classifier-attentions branch on December 28, 2024 at 11:21