From e32609f9cd3a8cba3f6535a5185fa08f34820f92 Mon Sep 17 00:00:00 2001 From: arne314 <73391160+arne314@users.noreply.github.com> Date: Wed, 11 Dec 2024 22:01:42 +0100 Subject: [PATCH] feat: smart snippets using treesitter --- lua/typstar/autosnippets.lua | 3 +- lua/typstar/excalidraw.lua | 2 +- lua/typstar/snippets/letters.lua | 4 +- lua/typstar/snippets/math.lua | 5 +- lua/typstar/snippets/visual.lua | 124 ++++++++++++++++++------------- lua/typstar/utils.lua | 32 ++++++-- 6 files changed, 104 insertions(+), 66 deletions(-) diff --git a/lua/typstar/autosnippets.lua b/lua/typstar/autosnippets.lua index 1790302..ef9b780 100644 --- a/lua/typstar/autosnippets.lua +++ b/lua/typstar/autosnippets.lua @@ -43,13 +43,14 @@ function M.ri(insert_node_id) return luasnip.function_node(function(args) return args[1][1] end, insert_node_id) end -function M.snip(trigger, expand, insert, condition, priority) +function M.snip(trigger, expand, insert, condition, priority, wordTrig) priority = priority or 1000 return luasnip.snippet( { trig = trigger, trigEngine = M.engine, trigEngineOpts = { condition = condition }, + wordTrig = wordTrig, priority = priority, snippetType = 'autosnippet' }, diff --git a/lua/typstar/excalidraw.lua b/lua/typstar/excalidraw.lua index f2202da..f9d6503 100644 --- a/lua/typstar/excalidraw.lua +++ b/lua/typstar/excalidraw.lua @@ -26,7 +26,7 @@ function M.insert_drawing() local filename = os.date(cfg.filename) local path = assets_dir .. '/' .. filename .. '.excalidraw.md' local path_inserted = cfg.assetsDir .. '/' .. filename .. cfg.fileExtensionInserted - utils.insert_snippet(string.format(affix, path_inserted)) + utils.insert_text_block(string.format(affix, path_inserted)) launch_obsidian_open(path) end diff --git a/lua/typstar/snippets/letters.lua b/lua/typstar/snippets/letters.lua index 4230039..292ebc0 100644 --- a/lua/typstar/snippets/letters.lua +++ b/lua/typstar/snippets/letters.lua @@ -38,9 +38,9 @@ end local generate_index_snippets = function(letter) for _, index in pairs(common_indices) do table.insert(letter_snippets, - snip(letter .. '(' .. index .. ') ', letter .. '_(<>) ', { cap(1) }, math, 200)) + snip(letter .. '(' .. index .. ') ', letter .. '_<> ', { cap(1) }, math, 200)) table.insert(letter_snippets, - snip('\\$' .. letter .. '\\$(' .. index .. ') ', '$' .. letter .. '_(<>)$ ', { cap(1) }, markup, 200)) + snip('\\$' .. letter .. '\\$(' .. index .. ') ', '$' .. letter .. '_<>$ ', { cap(1) }, markup, 200)) end end diff --git a/lua/typstar/snippets/math.lua b/lua/typstar/snippets/math.lua index b9d5aa8..1af1fc8 100644 --- a/lua/typstar/snippets/math.lua +++ b/lua/typstar/snippets/math.lua @@ -65,8 +65,9 @@ return { snip('pi', 'pi ', {}, math), snip('in', 'in ', {}, math), snip('(.*)iv', '<>^(-1)', { cap(1) }, math), - snip('(.*)sr', '<>^(2)', { cap(1) }, math), - snip('(.*)rd', '<>^(<>)', { cap(1), i(1, 'n') }, math), + snip('(.*)sr', '<>^2', { cap(1) }, math), + snip('(.*)jj', '<>_(<>)', { cap(1), i(1, 'n') }, math), + snip('(.*)kk', '<>^(<>)', { cap(1), i(1, 'n') }, math), snip('ddx', '(d <>)(d <>)', { i(1, 'f'), i(2, 'x') }, math), snip('it', 'integral_(<>)^(<>)', { i(1, 'a'), i(2, 'b') }, math), diff --git a/lua/typstar/snippets/visual.lua b/lua/typstar/snippets/visual.lua index fef8eab..49c7676 100644 --- a/lua/typstar/snippets/visual.lua +++ b/lua/typstar/snippets/visual.lua @@ -1,72 +1,90 @@ +local ts = vim.treesitter local ls = require('luasnip') -local i = ls.insert_node local d = ls.dynamic_node -local f = ls.function_node +local i = ls.insert_node +local s = ls.snippet_node +local t = ls.text_node +local utils = require('typstar.utils') local helper = require('typstar.autosnippets') local math = helper.in_math local snip = helper.snip -local cap = helper.cap -local get_visual = helper.get_visual local snippets = {} -local operations = { -- boolean denotes whether an additional layer of () brackets should be removed - { 'vi', '1/(', ')', true }, - { 'bb', '(', ')', false }, - { 'sq', '[', ']', true }, - { 'abs', 'abs(', ')', false }, - { 'ul', 'underline(', ')', false }, - { 'ol', 'overline(', ')', false }, - { 'ub', 'underbrace(', ')', false }, - { 'ob', 'overbrace(', ')', false }, - { 'ht', 'hat(', ')', false }, - { 'br', 'macron(', ')', false }, - { 'dt', 'dot(', ')', false }, - { 'ci', 'circle(', ')', false }, - { 'td', 'tilde(', ')', false }, - { 'nr', 'norm(', ')', false }, - { 'vv', 'vec(', ')', false }, - { 'rt', 'sqrt(', ')', false }, +local operations = { -- first boolean: existing brackets should be kept; second boolean: brackets should be added + { 'vi', '1/', '', true, false }, + { 'bb', '(', ')', true, false }, -- add round brackets + { 'sq', '[', ']', true, false }, -- add square brackets + { 'bB', '(', ')', false, false }, -- replace with round brackets + { 'sQ', '[', ']', false, false }, -- replace with square brackets + { 'BB', '', '', false, false }, -- remove brackets + { 'ss', '"', '"', false, false }, + { 'abs', 'abs', '', true, true }, + { 'ul', 'underline', '', true, true }, + { 'ol', 'overline', '', true, true }, + { 'ub', 'underbrace', '', true, true }, + { 'ob', 'overbrace', '', true, true }, + { 'ht', 'hat', '', true, true }, + { 'br', 'macron', '', true, true }, + { 'dt', 'dot', '', true, true }, + { 'ci', 'circle', '', true, true }, + { 'td', 'tilde', '', true, true }, + { 'nr', 'norm', '', true, true }, + { 'vv', 'vec', '', true, true }, + { 'rt', 'sqrt', '', true, true }, } -local wrap_brackets = function(args, snippet, val) - local captured = snippet.captures[1] - local bracket_types = { [')'] = '(', [']'] = '[', ['}'] = '{' } - local closing_bracket = captured:sub(-1, -1) - local opening_bracket = bracket_types[closing_bracket] +local ts_wrap_query = ts.query.parse('typst', '[(call) (ident) (letter) (number)] @wrap') +local ts_wrapnobrackets_query = ts.query.parse('typst', '(group) @wrapnobrackets') - if opening_bracket == nil then - return captured - end - - local n_brackets = 0 - local char - - for i = #captured, 1, -1 do - char = captured:sub(i, i) - if char == closing_bracket then - n_brackets = n_brackets + 1 - elseif char == opening_bracket then - n_brackets = n_brackets - 1 - end - - if n_brackets == 0 then - local remove_additional = val[4] and opening_bracket == '(' - return captured:sub(1, i - 1) .. val[2] - .. captured:sub(i + (remove_additional and 1 or 0), -(remove_additional and 2 or 1)) .. val[3] +local process_ts_query = function(bufnr, cursor, query, root, insert1, insert2, cut_offset) + for _, match, _ in query:iter_matches(root, bufnr, cursor[1], cursor[1] + 1) do + if match then + local start_row, start_col, end_row, end_col = utils.treesitter_match_start_end(match) + if end_row == cursor[1] and end_col == cursor[2] then + vim.schedule(function() -- to not interfere with luasnip + local cursor_offset = 0 + local old_len1, new_len1 = utils.insert_text( + bufnr, start_row, start_col, insert1, 0, cut_offset) + if start_row == cursor[1] then + cursor_offset = cursor_offset + (new_len1 - old_len1) + end + local old_len2, new_len2 = utils.insert_text( + bufnr, end_row, cursor[2] + cursor_offset, insert2, cut_offset, 0) + if end_row == cursor[1] then + cursor_offset = cursor_offset + (new_len2 - old_len2) + end + vim.api.nvim_win_set_cursor(0, { cursor[1] + 1, cursor[2] + cursor_offset }) + end) + return true + end end end - return captured + return false +end + +local smart_wrap = function(args, parent, old_state, expand) + local bufnr = vim.api.nvim_get_current_buf() + local cursor = utils.get_cursor_pos() + local root = utils.get_treesitter_root(bufnr) + + if process_ts_query(bufnr, cursor, ts_wrapnobrackets_query, root, expand[2], expand[3], expand[4] and 0 or 1) then + return s(nil, t()) + end + + local expand1 = expand[5] and expand[2] .. '(' or expand[2] + local expand2 = expand[5] and expand[3] .. ')' or expand[3] + if process_ts_query(bufnr, cursor, ts_wrap_query, root, expand1, expand2) then + return s(nil, t()) + end + if #parent.env.LS_SELECT_RAW > 0 then + return s(nil, t(expand1 .. table.concat(parent.env.LS_SELECT_RAW) .. expand2)) + end + return s(nil, { t(expand1), i(1, '1+1'), t(expand2) }) end for _, val in pairs(operations) do - table.insert(snippets, snip(val[1], val[2] .. '<>' .. val[3], { d(1, get_visual) }, math)) - table.insert(snippets, snip('[\\s$]' .. val[1], val[2] .. '<>' .. val[3], { i(1, '1') }, math)) - table.insert(snippets, - snip('([\\w]+)' - .. val[1], val[2] .. '<>' .. val[3], { cap(1) }, math, 900)) - table.insert(snippets, - snip('(.*[\\)|\\]|\\}])' .. val[1], '<>', { f(wrap_brackets, {}, { user_args = { val } }), nil }, math, 1100)) + table.insert(snippets, snip(val[1], '<>', { d(1, smart_wrap, {}, { user_args = { val } }) }, math, 1000, false)) end return { diff --git a/lua/typstar/utils.lua b/lua/typstar/utils.lua index e1c8dbe..b6b5f6d 100644 --- a/lua/typstar/utils.lua +++ b/lua/typstar/utils.lua @@ -7,28 +7,46 @@ function M.get_cursor_pos() return { cursor_row, cursor_col } end -function M.insert_snippet(snip) +function M.insert_text(bufnr, row, col, snip, begin_offset, end_offset) + begin_offset = begin_offset or 0 + end_offset = end_offset or 0 + local line = vim.api.nvim_buf_get_lines(bufnr, row, row + 1, true)[1] + local old_len = #line + line = line:sub(1, col - begin_offset) .. snip .. line:sub(col + 1 + end_offset, #line) + vim.api.nvim_buf_set_lines(bufnr, row, row + 1, false, { line }) + return old_len, #line +end + +function M.insert_text_block(snip) local line_num = M.get_cursor_pos()[1] + 1 local lines = {} for line in snip:gmatch '[^\r\n]+' do table.insert(lines, line) end - vim.api.nvim_buf_set_lines(0, line_num, line_num, false, lines) + vim.api.nvim_buf_set_lines(vim.api.nvim_get_current_buf(), line_num, line_num, false, lines) end function M.run_shell_command(cmd) vim.fn.jobstart(cmd) end +function M.get_treesitter_root(bufnr) + return ts.get_parser(bufnr):parse()[1]:root() +end + +function M.treesitter_match_start_end(match) + local start_row, start_col, _, _ = match[1]:range() + local _, _, end_row, end_col = match[#match]:range() + return start_row, start_col, end_row, end_col +end + function M.cursor_within_treesitter_query(query, match_tolerance, cursor) cursor = cursor or M.get_cursor_pos() local bufnr = vim.api.nvim_get_current_buf() - local root = ts.get_parser(bufnr):parse()[1]:root() - for _, match, _ in query:iter_matches(root, bufnr, cursor[1], cursor[1] + 1) do + for _, match, _ in query:iter_matches(M.get_treesitter_root(bufnr), bufnr, cursor[1], cursor[1] + 1) do if match then - local start_row, start_col, _, _ = match[1]:range() - local _, _, end_row, end_col = match[#match]:range() - local matched = M.cursor_within_coords(cursor, start_row, end_row, start_col, end_col, + local start_row, start_col, end_row, end_col = M.treesitter_match_start_end(match) + local matched = M.cursor_within_coords(cursor, start_row, end_row, start_col, end_col, match_tolerance) if matched then return true