From 41a6139ee6551aba68075dac292dc842e118f76f Mon Sep 17 00:00:00 2001
From: Joshua Lochner
Date: Sat, 28 Dec 2024 13:21:06 +0200
Subject: [PATCH] Add support for visualizing self-attention heatmaps + sequence classifier outputs w/ attentions (#1117)

* Add support for nearest neighbour interpolation
* Support bigint inputs for `min` and `max` helper functions
* Add min/max/argmin/argmax tensor ops
* Allow sequence classifier outputs with attentions
* Align depth estimation pipeline w/ python version
* Update `min` and `max` types
* Fix qwen2vl processor
* Update depth estimation pipeline unit tests
* Remove old comments
* Update types
* Fix type issues
---
 src/models.js                                 | 42 ++++++++++----
 .../idefics3/image_processing_idefics3.js     |  2 +
 .../pyannote/feature_extraction_pyannote.js   |  1 +
 .../feature_extraction_seamless_m4t.js        |  4 +-
 .../whisper/feature_extraction_whisper.js     |  2 +-
 src/ops/registry.js                           | 10 ++++
 src/pipelines.js                              | 22 ++++--
 src/tokenizers.js                             |  7 ++-
 src/utils/maths.js                            | 14 ++---
 src/utils/tensor.js                           | 52 +++++++++++++++----
 .../test_pipelines_depth_estimation.js        |  8 +--
 11 files changed, 121 insertions(+), 43 deletions(-)

diff --git a/src/models.js b/src/models.js
index 6037bfbf5..d0bc2311b 100644
--- a/src/models.js
+++ b/src/models.js
@@ -4463,6 +4463,7 @@ export class Qwen2VLForConditionalGeneration extends Qwen2VLPreTrainedModel {
         const image_nums = vision_tokens.filter(x => x == image_token_id).length;
         const video_nums = vision_tokens.filter(x => x == video_token_id).length;
 
+        /** @type {number[][]} */
         let llm_pos_ids_list = [];
         let st = 0;
         let remain_images = image_nums;
@@ -4532,6 +4533,7 @@ export class Qwen2VLForConditionalGeneration extends Qwen2VLPreTrainedModel {
             // NOTE: Each item in llm_pos_ids_list is an array of shape (3, text_len),
             // meaning to perform concatenation along dim=1, we can do the following:
             const num_items = llm_pos_ids_list.reduce((acc, x) => acc + x.length, 0);
+            /** @type {number[]} */
             const llm_positions = new Array(num_items);
             let index = 0;
             for (let x = 0; x < 3; ++x) {
@@ -4572,9 +4574,10 @@ export class Qwen2VLForConditionalGeneration extends Qwen2VLPreTrainedModel {
                 { length: 3 * data.length },
                 (_, i) => data[i % data.length]
             );
+            /** @type {bigint[]} */
             const mrope_position_deltas = Array.from(
                 { length: dims[0] },
-                (_, i) => max(data.subarray(dims[1] * i, dims[1] * (i + 1)))[0] + 1 + dims[1]
+                (_, i) => max(data.subarray(dims[1] * i, dims[1] * (i + 1)))[0] + 1n + BigInt(dims[1])
             );
 
             return [
@@ -5145,7 +5148,7 @@ export class DPTModel extends DPTPreTrainedModel { }
 *
 * **Example:** Depth estimation w/ `Xenova/dpt-hybrid-midas`.
 * ```javascript
- * import { DPTForDepthEstimation, AutoProcessor, RawImage, interpolate, max } from '@huggingface/transformers';
+ * import { DPTForDepthEstimation, AutoProcessor, RawImage, interpolate_4d } from '@huggingface/transformers';
 *
 * // Load model and processor
 * const model_id = 'Xenova/dpt-hybrid-midas';
@@ -5154,7 +5157,7 @@ export class DPTModel extends DPTPreTrainedModel { }
 *
 * // Load image from URL
 * const url = 'http://images.cocodataset.org/val2017/000000039769.jpg';
- * const image = await RawImage.fromURL(url);
+ * const image = await RawImage.read(url);
 *
 * // Prepare image for the model
 * const inputs = await processor(image);
@@ -5163,10 +5166,15 @@ export class DPTModel extends DPTPreTrainedModel { }
 * const { predicted_depth } = await model(inputs);
 *
 * // Interpolate to original size
- * const prediction = interpolate(predicted_depth, image.size.reverse(), 'bilinear', false);
+ * const prediction = (await interpolate_4d(predicted_depth.unsqueeze(1), {
+ *   size: image.size.reverse(),
+ *   mode: 'bilinear',
+ * })).squeeze(1);
 *
 * // Visualize the prediction
- * const formatted = prediction.mul_(255 / max(prediction.data)[0]).to('uint8');
+ * const min = prediction.min().item();
+ * const max = prediction.max().item();
+ * const formatted = prediction.sub_(min).div_(max - min).mul_(255).to('uint8');
 * const depth = RawImage.fromTensor(formatted);
 * // RawImage {
 * //   data: Uint8Array(307200) [ 85, 85, 84, ... ],
@@ -5216,11 +5224,11 @@ export class GLPNPreTrainedModel extends PreTrainedModel { }
 export class GLPNModel extends GLPNPreTrainedModel { }
 
 /**
 * GLPN Model transformer with a lightweight depth estimation head on top e.g. for KITTI, NYUv2.
 *
 * **Example:** Depth estimation w/ `Xenova/glpn-kitti`.
 * ```javascript
- * import { GLPNForDepthEstimation, AutoProcessor, RawImage, interpolate, max } from '@huggingface/transformers';
+ * import { GLPNForDepthEstimation, AutoProcessor, RawImage, interpolate_4d } from '@huggingface/transformers';
 *
 * // Load model and processor
 * const model_id = 'Xenova/glpn-kitti';
@@ -5229,7 +5237,7 @@ export class GLPNModel extends GLPNPreTrainedModel { }
 *
 * // Load image from URL
 * const url = 'http://images.cocodataset.org/val2017/000000039769.jpg';
- * const image = await RawImage.fromURL(url);
+ * const image = await RawImage.read(url);
 *
 * // Prepare image for the model
 * const inputs = await processor(image);
@@ -5238,13 +5246,18 @@ export class GLPNModel extends GLPNPreTrainedModel { }
 * const { predicted_depth } = await model(inputs);
 *
 * // Interpolate to original size
- * const prediction = interpolate(predicted_depth, image.size.reverse(), 'bilinear', false);
+ * const prediction = (await interpolate_4d(predicted_depth.unsqueeze(1), {
+ *   size: image.size.reverse(),
+ *   mode: 'bilinear',
+ * })).squeeze(1);
 *
 * // Visualize the prediction
- * const formatted = prediction.mul_(255 / max(prediction.data)[0]).to('uint8');
+ * const min = prediction.min().item();
+ * const max = prediction.max().item();
+ * const formatted = prediction.sub_(min).div_(max - min).mul_(255).to('uint8');
 * const depth = RawImage.fromTensor(formatted);
 * // RawImage {
- * //   data: Uint8Array(307200) [ 207, 169, 154, ... ],
+ * //   data: Uint8Array(307200) [ 85, 85, 84, ... ],
 * //   width: 640,
 * //   height: 480,
 * //   channels: 1
@@ -7747,10 +7760,17 @@ export class SequenceClassifierOutput extends ModelOutput {
     /**
      * @param {Object} output The output of the model.
      * @param {Tensor} output.logits classification (or regression if config.num_labels==1) scores (before SoftMax).
+     * @param {Record<string, Tensor>} [output.attentions] Object of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length, sequence_length)`.
+     * Attention weights after the attention softmax, used to compute the weighted average in the self-attention heads.
      */
-    constructor({ logits }) {
+    constructor({ logits, ...attentions }) {
         super();
         this.logits = logits;
+        const attentions_list = Object.values(attentions);
+        if (attentions_list.length > 0) {
+            // Only set attentions if they are not empty
+            this.attentions = attentions_list;
+        }
     }
 }
 
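With the `SequenceClassifierOutput` change above, models whose ONNX export includes attention outputs now expose them as `output.attentions`: an array with one `(batch_size, num_heads, sequence_length, sequence_length)` tensor per layer. A minimal sketch (not part of the patch) of rendering one layer as a heatmap, combining this with the new `min`/`max` tensor ops — the model id is illustrative, and the export is assumed to have been created with attention outputs enabled:

```javascript
import { AutoTokenizer, AutoModelForSequenceClassification, RawImage } from '@huggingface/transformers';

// Illustrative model id; assumes its ONNX export emits attention outputs.
const model_id = 'Xenova/distilbert-base-uncased-finetuned-sst-2-english';
const tokenizer = await AutoTokenizer.from_pretrained(model_id);
const model = await AutoModelForSequenceClassification.from_pretrained(model_id);

const inputs = tokenizer('The quick brown fox jumps over the lazy dog.');
const { attentions } = await model(inputs);

if (attentions) {
    // Last layer: (batch_size, num_heads, seq_len, seq_len)
    const attention = attentions.at(-1);
    const seq_len = attention.dims.at(-1);

    // Average over heads, then min-max normalize to [0, 255],
    // mirroring the depth-estimation examples above.
    const heatmap = attention.mean(1).view(seq_len, seq_len);
    const min = heatmap.min().item();
    const max = heatmap.max().item();
    const formatted = heatmap.sub_(min).div_(max - min).mul_(255).to('uint8').unsqueeze(0);
    const image = RawImage.fromTensor(formatted);
    // e.g. await image.save('attention.png');
}
```

If the token-level map needs upscaling before display, the newly-registered nearest-neighbour mode is the natural fit, since it keeps cell boundaries sharp instead of smearing them: `await interpolate_4d(x, { mode: 'nearest', size: [256, 256] })` on a `(1, 1, seq_len, seq_len)` view.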
diff --git a/src/models/idefics3/image_processing_idefics3.js b/src/models/idefics3/image_processing_idefics3.js
index 8864661c9..b01d62594 100644
--- a/src/models/idefics3/image_processing_idefics3.js
+++ b/src/models/idefics3/image_processing_idefics3.js
@@ -146,6 +146,8 @@ export class Idefics3ImageProcessor extends ImageProcessor {
                 const start_offset = i * pixel_attention_mask_stride + num_patches * h * w;
                 const end_offset = (i + 1) * pixel_attention_mask_stride;
+
+                // @ts-expect-error
                 pixel_attention_mask_data.fill(false, start_offset, end_offset);
             }
         }
diff --git a/src/models/pyannote/feature_extraction_pyannote.js b/src/models/pyannote/feature_extraction_pyannote.js
index a0044e159..e563bb9c8 100644
--- a/src/models/pyannote/feature_extraction_pyannote.js
+++ b/src/models/pyannote/feature_extraction_pyannote.js
@@ -52,6 +52,7 @@ export class PyAnnoteFeatureExtractor extends FeatureExtractor {
 
         let current_speaker = -1;
         for (let i = 0; i < scores.length; ++i) {
+            /** @type {number[]} */
             const probabilities = softmax(scores[i]);
             const [score, id] = max(probabilities);
             const [start, end] = [i, i + 1];
diff --git a/src/models/seamless_m4t/feature_extraction_seamless_m4t.js b/src/models/seamless_m4t/feature_extraction_seamless_m4t.js
index 8f02de062..12e8cfa45 100644
--- a/src/models/seamless_m4t/feature_extraction_seamless_m4t.js
+++ b/src/models/seamless_m4t/feature_extraction_seamless_m4t.js
@@ -133,8 +133,8 @@ export class SeamlessM4TFeatureExtractor extends FeatureExtractor {
                     'int64',
                     new BigInt64Array(numPaddedFrames),
                     [1, numPaddedFrames],
-                )
-                padded_attention_mask.data.fill(1n, 0, num_frames);
+                );
+                /** @type {BigInt64Array} */ (padded_attention_mask.data).fill(1n, 0, num_frames);
             }
         }
     }
diff --git a/src/models/whisper/feature_extraction_whisper.js b/src/models/whisper/feature_extraction_whisper.js
index f4d351f88..6f7bcdd56 100644
--- a/src/models/whisper/feature_extraction_whisper.js
+++ b/src/models/whisper/feature_extraction_whisper.js
@@ -44,7 +44,7 @@ export class WhisperFeatureExtractor extends FeatureExtractor {
         )
 
         const data = features.data;
-        const maxValue = max(data)[0];
+        const maxValue = max(/** @type {Float32Array} */(data))[0];
 
         for (let i = 0; i < data.length; ++i) {
             data[i] = (Math.max(data[i], maxValue - 8.0) + 4.0) / 4.0;
diff --git a/src/ops/registry.js b/src/ops/registry.js
index ba6dd7ecf..247a0517f 100644
--- a/src/ops/registry.js
+++ b/src/ops/registry.js
@@ -36,6 +36,16 @@ export class TensorOpRegistry {
         // executionProviders: ['webgpu'],
     };
 
+    static get nearest_interpolate_4d() {
+        if (!this._nearest_interpolate_4d) {
+            this._nearest_interpolate_4d = wrap(
+                [8, 10, 18, 0, 58, 129, 1, 10, 41, 10, 1, 120, 10, 0, 10, 0, 10, 1, 115, 18, 1, 121, 34, 6, 82, 101, 115, 105, 122, 101, 42, 18, 10, 4, 109, 111, 100, 101, 34, 7, 110, 101, 97, 114, 101, 115, 116, 160, 1, 3, 18, 1, 114, 90, 31, 10, 1,
120, 18, 26, 10, 24, 8, 1, 18, 20, 10, 3, 18, 1, 98, 10, 3, 18, 1, 99, 10, 3, 18, 1, 104, 10, 3, 18, 1, 119, 90, 15, 10, 1, 115, 18, 10, 10, 8, 8, 7, 18, 4, 10, 2, 8, 4, 98, 31, 10, 1, 121, 18, 26, 10, 24, 8, 1, 18, 20, 10, 3, 18, 1, 98, 10, 3, 18, 1, 99, 10, 3, 18, 1, 104, 10, 3, 18, 1, 119, 66, 2, 16, 21], + this.session_options, + 'y', + ); + } + return this._nearest_interpolate_4d; + } static get bilinear_interpolate_4d() { if (!this._bilinear_interpolate_4d) { this._bilinear_interpolate_4d = wrap( diff --git a/src/pipelines.js b/src/pipelines.js index fcf29c8d5..87c489c38 100644 --- a/src/pipelines.js +++ b/src/pipelines.js @@ -69,7 +69,7 @@ import { import { Tensor, mean_pooling, - interpolate, + interpolate_4d, quantize_embeddings, topk, } from './utils/tensor.js'; @@ -2901,11 +2901,23 @@ export class DepthEstimationPipeline extends (/** @type {new (options: ImagePipe const toReturn = []; for (let i = 0; i < preparedImages.length; ++i) { - const prediction = interpolate(predicted_depth[i], preparedImages[i].size.reverse(), 'bilinear', false); - const formatted = prediction.mul_(255 / max(prediction.data)[0]).to('uint8'); + const batch = predicted_depth[i]; + const [height, width] = batch.dims.slice(-2); + const [new_width, new_height] = preparedImages[i].size; + + // Interpolate to original size + const prediction = (await interpolate_4d(batch.view(1, 1, height, width), { + size: [new_height, new_width], + mode: 'bilinear', + })).view(new_height, new_width); + + const minval = /** @type {number} */(prediction.min().item()); + const maxval = /** @type {number} */(prediction.max().item()); + const formatted = prediction.sub(minval).div_(maxval - minval).mul_(255).to('uint8').unsqueeze(0); + const depth = RawImage.fromTensor(formatted); toReturn.push({ - predicted_depth: predicted_depth[i], - depth: RawImage.fromTensor(formatted), + predicted_depth: prediction, + depth, }); } diff --git a/src/tokenizers.js b/src/tokenizers.js index b45de6e24..0b9295d8f 100644 --- a/src/tokenizers.js +++ b/src/tokenizers.js @@ -533,7 +533,7 @@ class Unigram extends TokenizerModel { * Create a new Unigram tokenizer model. * @param {Object} config The configuration object for the Unigram model. * @param {number} config.unk_id The ID of the unknown token - * @param {any[][]} config.vocab A 2D array representing a mapping of tokens to scores. + * @param {[string, number][]} config.vocab A 2D array representing a mapping of tokens to scores. * @param {Object} moreConfig Additional configuration object for the Unigram model. */ constructor(config, moreConfig) { @@ -541,11 +541,10 @@ class Unigram extends TokenizerModel { const vocabSize = config.vocab.length; this.vocab = new Array(vocabSize); + /** @type {number[]} */ this.scores = new Array(vocabSize); for (let i = 0; i < vocabSize; ++i) { - const piece = config.vocab[i]; - this.vocab[i] = piece[0]; - this.scores[i] = piece[1]; + [this.vocab[i], this.scores[i]] = config.vocab[i]; } this.unk_token_id = config.unk_id; diff --git a/src/utils/maths.js b/src/utils/maths.js index e6cb2d6ca..107068dca 100644 --- a/src/utils/maths.js +++ b/src/utils/maths.js @@ -225,8 +225,9 @@ export function magnitude(arr) { /** * Returns the value and index of the minimum element in an array. - * @param {number[]|TypedArray} arr array of numbers. - * @returns {[number, number]} the value and index of the minimum element, of the form: [valueOfMin, indexOfMin] + * @template {number[]|bigint[]|AnyTypedArray} T + * @param {T} arr array of numbers. 
+ * @returns {T extends bigint[]|BigTypedArray ? [bigint, number] : [number, number]} the value and index of the minimum element, of the form: [valueOfMin, indexOfMin]
 * @throws {Error} If array is empty.
 */
export function min(arr) {
@@ -239,14 +240,15 @@ export function min(arr) {
             indexOfMin = i;
         }
     }
-    return [min, indexOfMin];
+    return /** @type {T extends bigint[]|BigTypedArray ? [bigint, number] : [number, number]} */([min, indexOfMin]);
 }
 
 /**
  * Returns the value and index of the maximum element in an array.
- * @param {number[]|AnyTypedArray} arr array of numbers.
- * @returns {[number, number]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax]
+ * @template {number[]|bigint[]|AnyTypedArray} T
+ * @param {T} arr array of numbers.
+ * @returns {T extends bigint[]|BigTypedArray ? [bigint, number] : [number, number]} the value and index of the maximum element, of the form: [valueOfMax, indexOfMax]
 * @throws {Error} If array is empty.
 */
export function max(arr) {
@@ -259,7 +261,7 @@ export function max(arr) {
             indexOfMax = i;
         }
     }
-    return [Number(max), indexOfMax];
+    return /** @type {T extends bigint[]|BigTypedArray ? [bigint, number] : [number, number]} */([max, indexOfMax]);
 }
 
 function isPowerOfTwo(number) {
diff --git a/src/utils/tensor.js b/src/utils/tensor.js
index 93a3e108e..fda0aa04c 100644
--- a/src/utils/tensor.js
+++ b/src/utils/tensor.js
@@ -9,6 +9,8 @@
 import {
     interpolate_data,
+    max,
+    min,
     permute_data
 } from './maths.js';
 
@@ -464,8 +466,6 @@ export class Tensor {
         return this.permute(...dims);
     }
 
-    // TODO add .max() and .min() methods
-
     /**
      * Returns the sum of each row of the input tensor in the given dimension dim.
     *
@@ -759,6 +759,36 @@ export class Tensor {
         return mean(this, dim, keepdim);
     }
 
+    min(dim = null, keepdim = false) {
+        if (dim !== null) {
+            throw new Error("`dim !== null` not yet implemented.");
+        }
+        const value = min(this.data)[0];
+        return new Tensor(this.type, [value], []);
+    }
+    max(dim = null, keepdim = false) {
+        if (dim !== null) {
+            throw new Error("`dim !== null` not yet implemented.");
+        }
+        const value = max(this.data)[0];
+        return new Tensor(this.type, [value], []);
+    }
+
+    argmin(dim = null, keepdim = false) {
+        if (dim !== null) {
+            throw new Error("`dim !== null` not yet implemented.");
+        }
+        const index = min(this.data)[1];
+        return new Tensor('int64', [BigInt(index)], []);
+    }
+    argmax(dim = null, keepdim = false) {
+        if (dim !== null) {
+            throw new Error("`dim !== null` not yet implemented.");
+        }
+        const index = max(this.data)[1];
+        return new Tensor('int64', [BigInt(index)], []);
+    }
+
     /**
      * Performs Tensor dtype conversion.
      * @param {DataType} type The desired data type.
@@ -892,7 +922,7 @@ export function interpolate(input, [out_height, out_width], mode = 'bilinear', a
 * @param {Tensor} input the input tensor
 * @param {Object} options the options for the interpolation
 * @param {[number, number]|[number, number, number]|[number, number, number, number]} [options.size=null] output spatial size.
- * @param {"bilinear"|"bicubic"} [options.mode='bilinear'] algorithm used for upsampling
+ * @param {"nearest"|"bilinear"|"bicubic"} [options.mode='bilinear'] algorithm used for upsampling
 * @returns {Promise<Tensor>} The interpolated tensor.
*/ export async function interpolate_4d(input, { @@ -922,7 +952,9 @@ export async function interpolate_4d(input, { } let op; - if (mode === 'bilinear') { + if (mode === 'nearest') { + op = await TensorOpRegistry.nearest_interpolate_4d; + } else if (mode === 'bilinear') { op = await TensorOpRegistry.bilinear_interpolate_4d; } else if (mode === 'bicubic') { op = await TensorOpRegistry.bicubic_interpolate_4d; @@ -963,13 +995,13 @@ export async function rfft(x, a) { * Returns the k largest elements of the given input tensor. * Inspired by https://pytorch.org/docs/stable/generated/torch.topk.html * @param {Tensor} x the input tensor - * @param {number} k the k in "top-k" + * @param {number} [k] the k in "top-k" * @returns {Promise<[Tensor, Tensor]>} the output tuple of (Tensor, LongTensor) of top-k elements and their indices. */ export async function topk(x, k) { const op = await TensorOpRegistry.top_k; - if (k === null) { + if (k == null) { k = x.dims.at(-1); } else { k = Math.min(k, x.dims.at(-1)); @@ -998,10 +1030,10 @@ const arrayToIndexTensor = (array) => new Tensor('int64', array, [array.length]) export async function slice(data, starts, ends, axes, steps) { const op = await TensorOpRegistry.slice; return await op({ - x: data, - s: arrayToIndexTensor(starts), - e: arrayToIndexTensor(ends), - a: arrayToIndexTensor(axes), + x: data, + s: arrayToIndexTensor(starts), + e: arrayToIndexTensor(ends), + a: arrayToIndexTensor(axes), t: arrayToIndexTensor(steps ?? new Array(axes.length).fill(1)), }); } diff --git a/tests/pipelines/test_pipelines_depth_estimation.js b/tests/pipelines/test_pipelines_depth_estimation.js index f0d5fe887..534e91030 100644 --- a/tests/pipelines/test_pipelines_depth_estimation.js +++ b/tests/pipelines/test_pipelines_depth_estimation.js @@ -1,4 +1,4 @@ -import { pipeline, DepthEstimationPipeline, RawImage } from "../../src/transformers.js"; +import { pipeline, DepthEstimationPipeline } from "../../src/transformers.js"; import { MAX_MODEL_LOAD_TIME, MAX_TEST_EXECUTION_TIME, MAX_MODEL_DISPOSE_TIME, DEFAULT_MODEL_OPTIONS } from "../init.js"; import { load_cached_image } from "../asset_cache.js"; @@ -25,7 +25,7 @@ export default () => { "default", async () => { const output = await pipe(images[0]); - expect(output.predicted_depth.dims).toEqual([32, 32]); + expect(output.predicted_depth.dims).toEqual([224, 224]); expect(output.predicted_depth.mean().item()).toBeCloseTo(0.000006106501587055391, 6); expect(output.depth.size).toEqual(images[0].size); }, @@ -39,10 +39,10 @@ export default () => { async () => { const output = await pipe(images); expect(output).toHaveLength(images.length); - expect(output[0].predicted_depth.dims).toEqual([32, 32]); + expect(output[0].predicted_depth.dims).toEqual([224, 224]); expect(output[0].predicted_depth.mean().item()).toBeCloseTo(0.000006106501587055391, 6); expect(output[0].depth.size).toEqual(images[0].size); - expect(output[1].predicted_depth.dims).toEqual([32, 32]); + expect(output[1].predicted_depth.dims).toEqual([224, 224]); expect(output[1].predicted_depth.mean().item()).toBeCloseTo(0.0000014548650142387487, 6); expect(output[1].depth.size).toEqual(images[1].size); },
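For pipeline users, the visible contract change in the tests above is that `predicted_depth` now comes back already interpolated to the input image's size (hence the `[32, 32]` → `[224, 224]` updates), while `depth` remains a single-channel `RawImage`, now min-max normalized rather than scaled by the maximum alone. A quick sketch of the updated behaviour, reusing the model and image from the examples above — output values are indicative only:

```javascript
import { pipeline } from '@huggingface/transformers';

const depth_estimator = await pipeline('depth-estimation', 'Xenova/dpt-hybrid-midas');

const url = 'http://images.cocodataset.org/val2017/000000039769.jpg';
const { predicted_depth, depth } = await depth_estimator(url);

// `predicted_depth` is the bilinearly-interpolated tensor (not the raw model
// output), so its dims now follow the 640x480 input image:
console.log(predicted_depth.dims); // [480, 640], i.e. [height, width]
console.log(depth.size);           // [640, 480], i.e. [width, height]
```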