Skip to content

Commit

Permalink
gpu + cpu
Browse files Browse the repository at this point in the history
  • Loading branch information
gaetanbrison committed Jan 6, 2025
1 parent f7fccf9 commit adfef58
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 4 deletions.
13 changes: 10 additions & 3 deletions carte_ai/src/carte_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,24 @@
def _carte_calculate_attention(
edge_index: Tensor, query: Tensor, key: Tensor, value: Tensor
):
## Fix to work on cpu and gpu provided by Ayoub Kachkach
# Calculate the scaled-dot product attention
attention = torch.sum(torch.mul(query[edge_index[0], :], key), dim=1)
attention = attention / math.sqrt(query.size(1))
attention = softmax(attention, edge_index[0])


# Ensure `attention` and `value` have the same dtype
attention = attention.to(value.dtype)

# Generate the output
src = torch.mul(attention, value.t()).t()


# Ensure `src` and `query` have the same dtype
src = src.to(query.dtype)

# Use torch.index_add_ to replace scatter function
output = torch.zeros_like(query).index_add_(0, edge_index[0], src)

return output, attention


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "carte-ai"
version = "0.0.23"
version = "0.0.25"
description = "CARTE-AI: Context Aware Representation of Table Entries for AI"
readme = { file = "README.md", content-type = "text/markdown" }
requires-python = ">=3.10.12"
Expand Down

0 comments on commit adfef58

Please sign in to comment.