diff --git a/script/core/definition.lua b/script/core/definition.lua index 3619916e6..47af10fae 100644 --- a/script/core/definition.lua +++ b/script/core/definition.lua @@ -7,18 +7,70 @@ local rpath = require 'workspace.require-path' local jumpSource = require 'core.jump-source' local wssymbol = require 'core.workspace-symbol' -local function sortResults(results) +--- @param s string +--- @return string[] +local function split(s) + local r = {} + s:gsub('[^/]+', function (w) + r[#r+1] = w:gsub("~1", "/"):gsub("~0", "~") + end) + return r +end + +--- Returns the Levenshtein distance between the two given string arrays +--- @param a string[] +--- @param b string[] +--- @return number +local function levenshtein_distance(a, b) + local a_len, b_len = #a, #b + local matrix = {} --- @type integer[][] + + -- Initialize the matrix + for i = 1, a_len + 1 do + matrix[i] = { [1] = i } + end + + for j = 1, b_len + 1 do + matrix[1][j] = j + end + + -- Compute the Levenshtein distance + for i = 1, a_len do + for j = 1, b_len do + local cost = (a[i] == b[j]) and 0 or 1 + matrix[i + 1][j + 1] = + math.min(matrix[i][j + 1] + 1, matrix[i + 1][j] + 1, matrix[i][j] + cost) + end + end + + -- Return the Levenshtein distance + return matrix[a_len + 1][b_len + 1] +end + +--- @param path1 string +--- @param path2 string +--- @return number +local function path_similarity_ratio(path1, path2) + local parts1 = split(path1) + local parts2 = split(path2) + local distance = levenshtein_distance(parts1, parts2) + return distance * 2 / (#parts1 + #parts2) +end + +local function sortResults(results, uri) -- 先按照顺序排序 + -- Sort in order first table.sort(results, function (a, b) local u1 = guide.getUri(a.target) local u2 = guide.getUri(b.target) if u1 == u2 then return a.target.start < b.target.start else - return u1 < u2 + return path_similarity_ratio(uri, u1) < path_similarity_ratio(uri, u2) end end) -- 如果2个结果处于嵌套状态,则取范围小的那个 + -- If two results are nested, take the one with the smaller range local lf, lu for i = #results, 1, -1 do local res = results[i].target @@ -141,7 +193,7 @@ return function (uri, offset) local results = {} local uris = checkRequire(source) if uris then - for i, uri in ipairs(uris) do + for _, uri in ipairs(uris) do results[#results+1] = { uri = uri, source = source, @@ -230,7 +282,7 @@ return function (uri, offset) return nil end - sortResults(results) + sortResults(results, uri) jumpSource(results) return results diff --git a/script/core/find-source.lua b/script/core/find-source.lua index c5d52f3e4..265fa2ca7 100644 --- a/script/core/find-source.lua +++ b/script/core/find-source.lua @@ -13,7 +13,7 @@ end return function (state, position, accept) local len = math.huge - local result + local result --- @type parser.object guide.eachSourceContain(state.ast, position, function (source) if source.type == 'function' then if not isValidFunctionPos(source, position) then