Skip to content

Commit

Permalink
feat: treesitter-based completions (#1412)
Browse files Browse the repository at this point in the history
  • Loading branch information
benlubas authored May 23, 2024
1 parent 2e4e7ec commit 79f6a49
Show file tree
Hide file tree
Showing 3 changed files with 119 additions and 107 deletions.
216 changes: 110 additions & 106 deletions lua/neorg/modules/core/completion/module.lua
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ link completions are smart about closing `:` and `}`.
--]]

local neorg = require("neorg.core")
local Path = require("pathlib")
local log, modules, utils = neorg.log, neorg.modules, neorg.utils
local dirutils, dirman, link_utils, treesitter

local module = modules.create("core.completion")

Expand All @@ -43,7 +45,10 @@ module.config.public = {
}

module.setup = function()
return { success = true, requires = { "core.dirman", "core.integrations.treesitter" } }
return {
success = true,
requires = { "core.dirman", "core.dirman.utils", "core.integrations.treesitter", "core.links" },
}
end

module.private = {
Expand All @@ -52,12 +57,6 @@ module.private = {
--- Get a list of all norg files in current workspace. Returns { workspace_path, norg_files }
--- @return { [1]: PathlibPath, [2]: PathlibPath[]|nil }|nil
get_norg_files = function()
---@type core.dirman
local dirman = neorg.modules.get_module("core.dirman")
if not dirman then
return nil
end

local current_workspace = dirman.get_current_workspace()
local norg_files = dirman.get_norg_files(current_workspace[1])
return { current_workspace[2], norg_files }
Expand Down Expand Up @@ -86,108 +85,68 @@ module.private = {
return closing_colon .. closing_brace
end,

--- Get the lines in a given norg file path.
--- @param file string file path, norg syntax accepted
--- @return table<string>
get_lines = function(file)
---@type core.dirman.utils
local dirutils = neorg.modules.get_module("core.dirman.utils")
if not dirutils then
return {}
--- query all the linkable items in a given buffer/file for a given link type
---@param source number | string | PathlibPath bufnr or file path
---@param link_type "generic" | "definition" | "footnote" | string
get_linkables = function(source, link_type)
local query_str = link_utils.get_link_target_query_string(link_type)
local norg_parser
local iter_src
if type(source) ~= "string" and type(source) ~= "number" then
source = tostring(source)
end
local expanded = dirutils.expand_path(file, true)

local lines
if expanded then
if not string.match(expanded, "%.norg$") then
expanded = expanded .. ".norg"
end
local ok
ok, lines = pcall(vim.fn.readfile, expanded)
if not ok then
lines = {}
if type(source) == "string" then
-- check if the file is open; use the buffer contents if it is
if vim.fn.bufnr(source) ~= -1 then
source = vim.uri_to_bufnr(vim.uri_from_fname(source))
else
iter_src = io.open(source, "r"):read("*a")
norg_parser = vim.treesitter.get_string_parser(iter_src, "norg")
end
end
return lines
end,

--- Find linkable headers in the given file
--- @param file string file path, norg syntax is accepted
--- @param context table
--- @param heading_level number?
--- @return table<string>
find_headers = function(file, context, heading_level)
local leading_whitespace = " "
if context.before_char == " " then
leading_whitespace = ""
end

local closing_chars = module.private.get_closing_chars(context, false)
leading_whitespace = leading_whitespace or ""
local ret = {}

local lines = module.private.get_lines(file)
for _, line in ipairs(lines) do
local heading = { line:match("^%s*(%*+)%s+(.+)$") }
if not vim.tbl_isempty(heading) and (not heading_level or #heading[1] == heading_level) then
-- remove potential GTD status from link
local stripped_heading = string.gsub(heading[2], "^%(.%)%s?", "")
table.insert(ret, leading_whitespace .. stripped_heading .. closing_chars)
if type(source) == "number" then
if source == 0 then
source = vim.api.nvim_get_current_buf()
end
-- local marker_or_drawer = { line:match("^%s*(%|%|?%s+(.+))$") }
-- if not vim.tbl_isempty(marker_or_drawer) then
-- -- TODO: how do you link to these things
-- -- what even are they?
-- table.insert(ret, marker_or_drawer[2])
-- end
norg_parser = vim.treesitter.get_parser(source, "norg")
iter_src = source
end

return ret
end,

--- Find footers in the given file
--- @param file string file path, norg syntax is accepted
--- @return table<string>
find_footnotes = function(file, context)
local ret = {}
local leading_whitespace = " "
if context.before_char == " " then
leading_whitespace = ""
if not norg_parser then
return {}
end

local closing_chars = module.private.get_closing_chars(context, false)
leading_whitespace = leading_whitespace or ""
local lines = module.private.get_lines(file)
for _, line in ipairs(lines) do
local footnote = { line:match("^%s*%^%^? (.+)$") }
if not vim.tbl_isempty(footnote) then
table.insert(ret, leading_whitespace .. footnote[1] .. closing_chars)
local norg_tree = norg_parser:parse()[1]
local query = vim.treesitter.query.parse("norg", query_str)
local links = {}
for id, node in query:iter_captures(norg_tree:root(), iter_src) do
local capture = query.captures[id]
if capture == "title" then
local original_title = treesitter.get_node_text(node, iter_src)
if original_title then
local title = original_title:gsub("\\", "")
title = title:gsub("%s+", " ")
title = title:gsub("^%s+", "")
table.insert(links, {
original_title = original_title,
title = title,
node = node,
})
end
end
end

return ret
return links
end,

generate_file_links = function(context, _prev, _saved, _match)
local res = {}
---@type core.dirman
local dirman = neorg.modules.get_module("core.dirman")
if not dirman then
return {}
end

local files = module.private.get_norg_files()
if not files or not files[2] then
return {}
end

local closing_chars = module.private.get_closing_chars(context, true)
for _, filepath in pairs(files[2]) do
local file = tostring(filepath)
local bufnr = dirman.get_file_bufnr(file)

if vim.api.nvim_get_current_buf() ~= bufnr then
local rel = filepath:relative_to(files[1], false)
for _, file in pairs(files[2]) do
if not file:samefile(Path.new(vim.api.nvim_buf_get_name(0))) then
local rel = file:relative_to(files[1], false)
if rel and rel:len() > 0 then
local link = "{:$/" .. rel:with_suffix(""):tostring() .. closing_chars
table.insert(res, link)
Expand All @@ -198,27 +157,66 @@ module.private = {
return res
end,

generate_local_heading_links = function(context, _prev, _saved, match)
--- Generate list of autocompletion suggestions for links
--- @param context table
--- @param source number | string | PathlibPath
--- @param node_type string
--- @return string[]
suggestions = function(context, source, node_type)
local leading_whitespace = " "
if context.before_char == " " then
leading_whitespace = ""
end
local links = module.private.get_linkables(source, node_type)
local closing_chars = module.private.get_closing_chars(context, false)
return vim.tbl_map(function(x)
return leading_whitespace .. x.title .. closing_chars
end, links)
-- return vim.iter(links)
-- :map(function(x)
-- return leading_whitespace .. x.title .. closing_chars
-- end)
-- :totable()
end,

--- All the things that you can link to (`{#|}` completions)
local_link_targets = function(context, _prev, _saved, _match)
return module.private.suggestions(context, 0, "generic")
end,

local_heading_links = function(context, _prev, _saved, match)
local heading_level = match[2] and #match[2]
return module.private.find_headers(vim.api.nvim_buf_get_name(0), context, heading_level)
return module.private.suggestions(context, 0, ("heading%d"):format(heading_level))
end,

generate_foreign_heading_links = function(context, _prev, _saved, match)
foreign_heading_links = function(context, _prev, _saved, match)
local file = match[1]
local heading_level = match[2] and #match[2]
if file then
return module.private.find_headers(file, context, heading_level)
file = dirutils.expand_pathlib(file)
return module.private.suggestions(context, file, ("heading%d"):format(heading_level))
end
return {}
end,

generate_local_footnote_links = function(context, _prev, _saved, _match)
return module.private.find_footnotes(vim.api.nvim_buf_get_name(0), context)
foreign_generic_links = function(context, _prev, _saved, match)
local file = match[1]
if file then
file = dirutils.expand_pathlib(file)
return module.private.suggestions(context, file, "generic")
end
return {}
end,

generate_foreign_footnote_links = function(context, _prev, _saved, match)
local_footnote_links = function(context, _prev, _saved, _match)
return module.private.suggestions(context, 0, "footnote")
end,

foreign_footnote_links = function(context, _prev, _saved, match)
local file = match[2]
if match[2] then
return module.private.find_footnotes(match[2], context)
file = dirutils.expand_pathlib(file)
return module.private.suggestions(context, file, "footnote")
end
return {}
end,
Expand Down Expand Up @@ -255,6 +253,11 @@ module.load = function()
return
end

dirutils = module.required["core.dirman.utils"]
dirman = module.required["core.dirman"]
link_utils = module.required["core.links"]
treesitter = module.required["core.integrations.treesitter"]

-- Set a special function in the integration module to allow it to communicate with us
module.private.engine.invoke_completion_engine = function(context) ---@diagnostic disable-line
return module.public.complete(context) ---@diagnostic disable-line -- TODO: type error workaround <pysan3>
Expand Down Expand Up @@ -456,7 +459,7 @@ module.public = {
{ -- links that have a file path, suggest any heading from the file `{:...:#|}`
regex = "^.*{:(.*):#[^}]*",

complete = module.private.generate_foreign_heading_links,
complete = module.private.foreign_generic_links,

node = module.private.normal_norg,

Expand All @@ -468,7 +471,7 @@ module.public = {
{ -- links that have a file path, suggest direct headings from the file `{:...:*|}`
regex = "^.*{:(.*):(%*+)[^}]*",

complete = module.private.generate_foreign_heading_links,
complete = module.private.foreign_heading_links,

node = module.private.normal_norg,

Expand All @@ -480,7 +483,8 @@ module.public = {
{ -- # links to headings in the current file `{#|}`
regex = "^.*{#[^}]*",

complete = module.private.generate_local_heading_links,
-- complete = module.private.generate_local_heading_links,
complete = module.private.local_link_targets,

node = module.private.normal_norg,

Expand All @@ -494,7 +498,7 @@ module.public = {
-- the first capture group is a nothing group so that match[2] is reliably the heading
-- level or nil if there's no heading level.

complete = module.private.generate_local_heading_links,
complete = module.private.local_heading_links,

node = module.private.normal_norg,

Expand All @@ -506,7 +510,7 @@ module.public = {
{ -- ^ footnote links in the current file `{^|}`
regex = "^(.*){%^[^}]*",

complete = module.private.generate_local_footnote_links,
complete = module.private.local_footnote_links,

node = module.private.normal_norg,

Expand All @@ -518,7 +522,7 @@ module.public = {
{ -- ^ footnote links in another file `{:path:^|}`
regex = "^(.*){:(.*):%^[^}]*",

complete = module.private.generate_foreign_footnote_links,
complete = module.private.foreign_footnote_links,

node = module.private.normal_norg,

Expand Down Expand Up @@ -560,7 +564,7 @@ module.public = {
-- If the completion data has a node variable then attempt to match the current node too!
if completion_data.node then
-- Grab the treesitter utilities
local ts = module.required["core.integrations.treesitter"].get_ts_utils()
local ts = treesitter.get_ts_utils()

-- If the type of completion data we're dealing with is a string then attempt to parse it
if type(completion_data.node) == "string" then
Expand Down
1 change: 0 additions & 1 deletion lua/neorg/modules/core/esupports/hop/module.lua
Original file line number Diff line number Diff line change
Expand Up @@ -583,7 +583,6 @@ module.public = {

_ = function()
local query_str = links.get_link_target_query_string(parsed_link_information.link_type)

local document_root = module.required["core.integrations.treesitter"].get_document_root(buf_pointer)

if not document_root then
Expand Down
9 changes: 9 additions & 0 deletions lua/neorg/modules/core/integrations/treesitter/module.lua
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,15 @@ module.public = {
descend(root)
end,
get_node_text = function(node, source)
if not node then return "" end

-- when source is the string contents of the file
if type(source) == "string" then
local _, _, start_bytes = node:start()
local _, _, end_bytes = node:end_()
return string.sub(source, start_bytes, end_bytes)
end

source = source or 0

local start_row, start_col = node:start()
Expand Down

0 comments on commit 79f6a49

Please sign in to comment.