From 5529a29327b20fd2f19e0391e15a05be600a1e4b Mon Sep 17 00:00:00 2001 From: arne314 <73391160+arne314@users.noreply.github.com> Date: Sun, 8 Dec 2024 18:39:58 +0100 Subject: [PATCH] perf: use treesitter to detect math/markup --- lua/typstar/autosnippets.lua | 20 +++++++++++------ lua/typstar/snippets/math.lua | 1 + lua/typstar/snippets/visual.lua | 2 +- lua/typstar/utils.lua | 38 ++++++++++++++++++++++++++++++++- 4 files changed, 52 insertions(+), 9 deletions(-) diff --git a/lua/typstar/autosnippets.lua b/lua/typstar/autosnippets.lua index d8bf6a4..c2b7d2e 100644 --- a/lua/typstar/autosnippets.lua +++ b/lua/typstar/autosnippets.lua @@ -1,24 +1,30 @@ local M = {} local cfg = require('typstar.config').config.snippets +local utils = require('typstar.utils') local luasnip = require('luasnip') local fmta = require('luasnip.extras.fmt').fmta local lsengines = require('luasnip.nodes.util.trig_engines') +local ts = vim.treesitter local last_keystroke_time = nil -vim.api.nvim_create_autocmd("TextChangedI", { +vim.api.nvim_create_autocmd('TextChangedI', { callback = function() last_keystroke_time = vim.loop.now() end, }) local lexical_result_cache = {} -M.in_math = function() return vim.api.nvim_eval('typst#in_math()') == 1 end -M.in_markup = function() return vim.api.nvim_eval('typst#in_markup()') == 1 end -M.in_code = function() return vim.api.nvim_eval('typst#in_code()') == 1 end -M.in_comment = function() return vim.api.nvim_eval('typst#in_comment()') == 1 end +local ts_markup_query = ts.query.parse('typst', '(text) @markup') +local ts_math_query = ts.query.parse('typst', '(math) @math') +local ts_string_query = ts.query.parse('typst', '(string) @string') + +M.in_math = function() + local cursor = utils.get_cursor_pos() + return utils.cursor_inside_treesitter_query(ts_math_query, cursor) + and not utils.cursor_inside_treesitter_query(ts_string_query, cursor) +end +M.in_markup = function() return utils.cursor_inside_treesitter_query(ts_markup_query) end M.not_in_math = function() return not M.in_math() end M.not_in_markup = function() return not M.in_markup() end -M.not_in_code = function() return not M.in_code() end -M.not_in_comment = function() return not M.in_comment() end M.snippets_toggle = true function M.cap(i) diff --git a/lua/typstar/snippets/math.lua b/lua/typstar/snippets/math.lua index 7a0ee5c..b9d5aa8 100644 --- a/lua/typstar/snippets/math.lua +++ b/lua/typstar/snippets/math.lua @@ -63,6 +63,7 @@ return { snip('rrn', 'RR^n ', {}, math), snip('cc', 'cases(\n\t<>\n)\\', { i(1, '1') }, math), snip('pi', 'pi ', {}, math), + snip('in', 'in ', {}, math), snip('(.*)iv', '<>^(-1)', { cap(1) }, math), snip('(.*)sr', '<>^(2)', { cap(1) }, math), snip('(.*)rd', '<>^(<>)', { cap(1), i(1, 'n') }, math), diff --git a/lua/typstar/snippets/visual.lua b/lua/typstar/snippets/visual.lua index 693bf8c..fef8eab 100644 --- a/lua/typstar/snippets/visual.lua +++ b/lua/typstar/snippets/visual.lua @@ -14,7 +14,7 @@ local operations = { -- boolean denotes whether an additional layer of () bracke { 'vi', '1/(', ')', true }, { 'bb', '(', ')', false }, { 'sq', '[', ']', true }, - { 'abs', '|', '|', false }, + { 'abs', 'abs(', ')', false }, { 'ul', 'underline(', ')', false }, { 'ol', 'overline(', ')', false }, { 'ub', 'underbrace(', ')', false }, diff --git a/lua/typstar/utils.lua b/lua/typstar/utils.lua index 1e378bf..4bc849b 100644 --- a/lua/typstar/utils.lua +++ b/lua/typstar/utils.lua @@ -1,7 +1,14 @@ local M = {} +local ts = vim.treesitter + +function M.get_cursor_pos() + local cursor_row, cursor_col = unpack(vim.api.nvim_win_get_cursor(0)) + cursor_row = cursor_row - 1 + return { cursor_row, cursor_col } +end function M.insert_snippet(snip) - local line_num = vim.fn.getcurpos()[2] + local line_num = M.get_cursor_pos()[1] + 1 local lines = {} for line in snip:gmatch '[^\r\n]+' do table.insert(lines, line) @@ -13,4 +20,33 @@ function M.run_shell_command(cmd) vim.fn.jobstart(cmd) end +function M.cursor_inside_treesitter_query(query, cursor) + cursor = cursor or M.get_cursor_pos() + local bufnr = vim.api.nvim_get_current_buf() + local root = ts.get_parser(bufnr):parse()[1]:root() + for _, match, _ in query:iter_matches(root, bufnr, cursor[1], cursor[1] + 1) do + if match then + local start_row, start_col, _, _ = match[1]:range() + local _, _, end_row, end_col = match[#match]:range() + local matched = M.cursor_inside_coords(cursor, start_row, end_row, start_col, end_col) + if matched then + return true + end + end + end + return false +end + +function M.cursor_inside_coords(cursor, start_row, end_row, start_col, end_col) + if start_row <= cursor[1] and end_row >= cursor[1] then + if start_row == cursor[1] and start_col > cursor[2] then + return false + elseif end_row == cursor[1] and end_col < cursor[2] then + return false + end + return true + end + return false +end + return M