Cast `attention` to `value.dtype` and `src` to `query.dtype` inside
_carte_calculate_attention (described in the patch itself as a fix to work
on CPU and GPU), and bump the carte-ai package version from 0.0.23 to 0.0.25.

NOTE(review): this patch was recovered from a copy whose line breaks and
indentation had been collapsed. The 4-space body indentation and the
whitespace-only added lines are assumed -- confirm against blob ef4a604
before applying.

diff --git a/carte_ai/src/carte_model.py b/carte_ai/src/carte_model.py
index ff2e2b0..ef4a604 100644
--- a/carte_ai/src/carte_model.py
+++ b/carte_ai/src/carte_model.py
@@ -9,17 +9,24 @@
 def _carte_calculate_attention(
     edge_index: Tensor, query: Tensor, key: Tensor, value: Tensor
 ):
+    ## Fix to work on cpu and gpu provided by Ayoub Kachkach
     # Calculate the scaled-dot product attention
     attention = torch.sum(torch.mul(query[edge_index[0], :], key), dim=1)
     attention = attention / math.sqrt(query.size(1))
     attention = softmax(attention, edge_index[0])
-
+    
+    # Ensure `attention` and `value` have the same dtype
+    attention = attention.to(value.dtype)
+    
     # Generate the output
     src = torch.mul(attention, value.t()).t()
-
+    
+    # Ensure `src` and `query` have the same dtype
+    src = src.to(query.dtype)
+    
     # Use torch.index_add_ to replace scatter function
     output = torch.zeros_like(query).index_add_(0, edge_index[0], src)
-
+    
     return output, attention
 
 
diff --git a/pyproject.toml b/pyproject.toml
index 36fd8a3..e958cd3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "carte-ai"
-version = "0.0.23"
+version = "0.0.25"
 description = "CARTE-AI: Context Aware Representation of Table Entries for AI"
 readme = { file = "README.md", content-type = "text/markdown" }
 requires-python = ">=3.10.12"