fix(snip): markup/math detection edge cases

This commit is contained in:
arne314
2025-06-14 14:41:48 +02:00
parent 8da248151f
commit 7cd4751162
2 changed files with 18 additions and 8 deletions

View File

@@ -20,10 +20,10 @@ vim.api.nvim_create_autocmd('TextChangedI', {
M.in_math = function()
local cursor = utils.get_cursor_pos()
return utils.cursor_within_treesitter_query(ts_math_query, 0, cursor)
and not utils.cursor_within_treesitter_query(ts_string_query, 0, cursor)
return utils.cursor_within_treesitter_query(ts_math_query, 0, 0, cursor)
and not utils.cursor_within_treesitter_query(ts_string_query, 0, 0, cursor)
end
M.in_markup = function() return utils.cursor_within_treesitter_query(ts_markup_query, 1) end
M.in_markup = function() return utils.cursor_within_treesitter_query(ts_markup_query, 1, 2) end
M.not_in_math = function() return not M.in_math() end
M.not_in_markup = function() return not M.in_markup() end
M.snippets_toggle = true

View File

@@ -92,25 +92,35 @@ function M.treesitter_match_start_end(match)
return start_row, start_col, end_row, end_col
end
function M.cursor_within_treesitter_query(query, match_tolerance, cursor)
function M.cursor_within_treesitter_query(query, match_tolerance_l, match_tolerance_r, cursor)
cursor = cursor or M.get_cursor_pos()
match_tolerance_l = match_tolerance_l or 0
match_tolerance_r = match_tolerance_r or 0
local bufnr = vim.api.nvim_get_current_buf()
local root = M.get_treesitter_root(bufnr)
for _, match in ipairs(M.treesitter_iter_matches(root, query, bufnr, cursor[1], cursor[1] + 1)) do
for _, nodes in pairs(match) do
local start_row, start_col, end_row, end_col = M.treesitter_match_start_end(nodes)
local matched = M.cursor_within_coords(cursor, start_row, end_row, start_col, end_col, match_tolerance)
local matched = M.cursor_within_coords(
cursor,
start_row,
end_row,
start_col,
end_col,
match_tolerance_l,
match_tolerance_r
)
if matched then return true end
end
end
return false
end
function M.cursor_within_coords(cursor, start_row, end_row, start_col, end_col, match_tolerance)
function M.cursor_within_coords(cursor, start_row, end_row, start_col, end_col, match_tolerance_l, match_tolerance_r)
if start_row <= cursor[1] and end_row >= cursor[1] then
if start_row == cursor[1] and start_col - match_tolerance >= cursor[2] then
if start_row == cursor[1] and start_col - match_tolerance_l >= cursor[2] then
return false
elseif end_row == cursor[1] and end_col + match_tolerance <= cursor[2] then
elseif end_row == cursor[1] and end_col + match_tolerance_r <= cursor[2] then
return false
end
return true