feat: apply snippets to math wrapped in brackets

This commit is contained in:
arne314
2024-11-22 23:29:27 +01:00
parent d2d91f7a01
commit b03004d44c

View File

@@ -1,6 +1,7 @@
local ls = require('luasnip') local ls = require('luasnip')
local i = ls.insert_node local i = ls.insert_node
local d = ls.dynamic_node local d = ls.dynamic_node
local f = ls.function_node
local helper = require('typstar.autosnippets') local helper = require('typstar.autosnippets')
local math = helper.in_math local math = helper.in_math
@@ -9,28 +10,60 @@ local cap = helper.cap
local get_visual = helper.get_visual local get_visual = helper.get_visual
local snippets = {} local snippets = {}
local operations = { local operations = { -- boolean denotes whether an additional layer of () brackets should be removed
{ 'vi', '1/(', ')' }, { 'vi', '1/(', ')', true },
{ 'rb', '(', ')' }, { 'bb', '(', ')', false },
{ 'sq', '[', ']' }, { 'sq', '[', ']', true },
{ 'abs', '|', '|' }, { 'abs', '|', '|', false },
{ 'ul', 'underline(', ')' }, { 'ul', 'underline(', ')', false },
{ 'ol', 'overline(', ')' }, { 'ol', 'overline(', ')', false },
{ 'ht', 'hat(', ')' }, { 'ub', 'underbrace(', ')', false },
{ 'br', 'macron(', ')' }, { 'ob', 'overbrace(', ')', false },
{ 'dt', 'dot(', ')' }, { 'ht', 'hat(', ')', false },
{ 'ci', 'circle(', ')' }, { 'br', 'macron(', ')', false },
{ 'td', 'tilde(', ')' }, { 'dt', 'dot(', ')', false },
{ 'nr', 'norm(', ')' }, { 'ci', 'circle(', ')', false },
{ 'vv', 'vec(', ')' }, { 'td', 'tilde(', ')', false },
{ 'rt', 'sqrt(', ')' }, { 'nr', 'norm(', ')', false },
{ 'vv', 'vec(', ')', false },
{ 'rt', 'sqrt(', ')', false },
} }
local wrap_brackets = function(args, snippet, val)
local captured = snippet.captures[1]
local bracket_types = { [')'] = '(', [']'] = '[', ['}'] = '{' }
local closing_bracket = captured:sub(-1, -1)
local opening_bracket = bracket_types[closing_bracket]
if opening_bracket == nil then
return captured
end
local n_brackets = 0
local char
for i = #captured, 1, -1 do
char = captured:sub(i, i)
if char == closing_bracket then
n_brackets = n_brackets + 1
elseif char == opening_bracket then
n_brackets = n_brackets - 1
end
if n_brackets == 0 then
local remove_additional = val[4] and opening_bracket == '('
return captured:sub(1, i - 1) .. val[2]
.. captured:sub(i + (remove_additional and 1 or 0), -(remove_additional and 2 or 1)) .. val[3]
end
end
return captured
end
for _, val in pairs(operations) do for _, val in pairs(operations) do
table.insert(snippets, snip(val[1], val[2] .. '<>' .. val[3], { d(1, get_visual) }, math, 1200)) table.insert(snippets, snip(val[1], val[2] .. '<>' .. val[3], { d(1, get_visual), extra_node }, math))
table.insert(snippets, snip('%s' .. val[1], val[2] .. '<>' .. val[3], { i(1, '1'), extra_node }, math))
table.insert(snippets, table.insert(snippets,
snip('(%s)([^%s]*)' .. val[1], '<>' .. val[2] .. '<>' .. val[3], { cap(1), cap(2) }, math, 1100)) snip('(.*[%)|%]|%}])' .. val[1], '<>', { f(wrap_brackets, {}, { user_args = { val } }), nil }, math, 1100))
table.insert(snippets, snip('%s' .. val[1], val[2] .. '<>' .. val[3], { i(1, '1') }, math))
end end
return { return {