diff --git a/data/bash.txt b/data/bash.txt new file mode 100644 index 0000000..4031688 --- /dev/null +++ b/data/bash.txt @@ -0,0 +1,14 @@ +@role: user +@content: + +local utils = require("module.utils") +local M = {} +utils.pa +return M + + +@role: assistant +@content: + +utils.parser + diff --git a/data/go.txt b/data/go.txt new file mode 100644 index 0000000..e69de29 diff --git a/data/java.txt b/data/java.txt new file mode 100644 index 0000000..e69de29 diff --git a/data/javascript.txt b/data/javascript.txt new file mode 100644 index 0000000..e69de29 diff --git a/data/lua.txt b/data/lua.txt new file mode 100644 index 0000000..e69de29 diff --git a/data/markdown.txt b/data/markdown.txt new file mode 100644 index 0000000..e69de29 diff --git a/data/miscellanious.txt b/data/miscellanious.txt new file mode 100644 index 0000000..e69de29 diff --git a/data/python.txt b/data/python.txt new file mode 100644 index 0000000..e69de29 diff --git a/data/rust.txt b/data/rust.txt new file mode 100644 index 0000000..e69de29 diff --git a/lua/model_cmp/commands.lua b/lua/model_cmp/commands.lua index 4f037ac..28742a8 100644 --- a/lua/model_cmp/commands.lua +++ b/lua/model_cmp/commands.lua @@ -89,6 +89,7 @@ local function modelcmp_stop() vim.api.nvim_create_user_command('ModelCmpStop', function() common.stop() logger.debugging("Stopped api") + api.stop() end, {}) end @@ -103,6 +104,7 @@ local function modelcmp_logs() vim.api.nvim_buf_set_option(newbuf, 'swapfile', false) -- No swap file vim.api.nvim_buf_set_lines(newbuf, 0, -1, false, logger.Logs) vim.api.nvim_buf_set_option(newbuf, 'modifiable', false) -- Make it read-only + -- Todo add logger end, {}) end diff --git a/lua/model_cmp/modelapi/common.lua b/lua/model_cmp/modelapi/common.lua index a8418b4..526a880 100644 --- a/lua/model_cmp/modelapi/common.lua +++ b/lua/model_cmp/modelapi/common.lua @@ -1,7 +1,7 @@ local apiconfig = require("model_cmp.modelapi.apiconfig") local context = require("model_cmp.context") local logger = require("model_cmp.logger") -local preprompt = require("model_cmp.modelapi.prompt") +local prompter = require("model_cmp.modelapi.prompt") local req = require("model_cmp.modelapi.request") local utils = require("model_cmp.utils") local virtualtext = require("model_cmp.virtualtext") @@ -11,7 +11,7 @@ local llama = require("model_cmp.modelapi.llama") local M = {} -vim.g.server = "local_llama" +vim.g.model_cmp_connection_server = nil local available_keys = { GEMINI_API_KEY = 0 @@ -35,25 +35,25 @@ local function check_available() end function M.send_request() - local bufnr = context.ContextEngine.bufnr + local currlang = context.ContextEngine.currlang local ctx = context.generate_context_text() - local few_shots = preprompt.complete_few_shots + + local prompt = prompter.generate_prompt("text", ctx) local request - local server = vim.g.server - if vim.g.server == "" or vim.g.server == nil then - logger.error("NO server setup") + local server = vim.g.model_cmp_connection_server + if server == nil then + logger.trace("NO server setup") return end - if server == "default" or server == "local_llama" then - request = llama.generate_request(few_shots, ctx) + if server == "local_llama" then + request = llama.generate_request(prompt) elseif server == "gemini" then if available_keys.GEMINI_API_KEY ~= 1 then logger.error("GEMINI_API_KEY is not set") return end - local systemprompt = preprompt.default.content - request = gemini.generate_request(few_shots, systemprompt, ctx) + request = gemini.generate_request(prompt) end if request == nil then @@ -63,8 +63,8 @@ function M.send_request() req.send(request, function(response) vim.schedule(function() - local text = "" - if vim.g.server == "gemini" then + local text = nil + if vim.g.model_cmp_connection_server == "gemini" then text = gemini.decode_response(response) else text = utils.decode_response(response) @@ -79,7 +79,7 @@ function M.send_request() end function M.stop() - vim.g.server = "" + vim.g.model_cmp_connection_server = nil end function M.setup(config) diff --git a/lua/model_cmp/modelapi/gemini.lua b/lua/model_cmp/modelapi/gemini.lua index 7f8b44b..4be9d90 100644 --- a/lua/model_cmp/modelapi/gemini.lua +++ b/lua/model_cmp/modelapi/gemini.lua @@ -22,14 +22,15 @@ local function generate_url(model_name) end function M.start(model_name) - vim.g.server = "gemini" + vim.g.model_cmp_connection_server = "gemini" end -local function transform_ctx_messages(ctx_messages) - -- Transforming few shot messages +---@param prompt Prompt +local function transform_fewshots(prompt) local new_chat = {} - for _, msg in ipairs(ctx_messages) do + for _, msg in ipairs(prompt.fewshots) do local gemini_message = {} + if msg.role == 'user' then gemini_message = { role = 'user', @@ -45,17 +46,19 @@ local function transform_ctx_messages(ctx_messages) }, } end + table.insert(new_chat, gemini_message) end return new_chat end -function M.generate_request(ctx_messages, content, mainctx) - local messages = transform_ctx_messages(ctx_messages) +---@param prompt Prompt +function M.generate_request(prompt) + local messages = transform_fewshots(prompt) local mainmsg = { role = 'user', parts = { - { text = mainctx }, + { text = prompt.context.content }, }, } @@ -71,7 +74,7 @@ function M.generate_request(ctx_messages, content, mainctx) vim.fn.json_encode({ system_instruction = { parts = { - text = content + text = prompt.systemrole.content } }, contents = messages, diff --git a/lua/model_cmp/modelapi/llama.lua b/lua/model_cmp/modelapi/llama.lua index 64e1856..8c9178f 100644 --- a/lua/model_cmp/modelapi/llama.lua +++ b/lua/model_cmp/modelapi/llama.lua @@ -8,11 +8,23 @@ local generate_url = function(custom_url) end function M.start(model_name) - vim.g.server = "" + vim.g.model_cmp_connection_server = "local_llama" end -function M.generate_request(ctx_messages, ctx) +---@param prompt Prompt +function M.generate_request(prompt) local custom = apiconfig.default() + + local messages = { prompt.systemrole } + for _, k in ipairs(prompt.fewshots) do + table.insert(messages, k) + end + local context = { + role = "user", + content = prompt.context.content, + } + table.insert(messages, context) + local request = { "-s", "-X", "POST", @@ -21,7 +33,7 @@ function M.generate_request(ctx_messages, ctx) "-d", vim.fn.json_encode({ model = "llama", - messages = ctx_messages, + messages = messages, n_predict = 128, temperature = 0.1, stop = { "" }, diff --git a/lua/model_cmp/modelapi/prompt.lua b/lua/model_cmp/modelapi/prompt.lua index 1210a23..4162df0 100644 --- a/lua/model_cmp/modelapi/prompt.lua +++ b/lua/model_cmp/modelapi/prompt.lua @@ -1,99 +1,77 @@ -local M = {} - -M.complete_few_shots = { - { - role = "user", - content = [[# language: lua - -local ghosttext = require("model_cmp.ghosttext") local utils = require("model_cmp.utils") local M = {} -vim.b.requ +M.default_systemrole = { + role = "system", + content = [[Act as GitHub Copilot. Complete the code where the token is. +Follow the instructions: +- Output only the current line after replacing the tag. +- No explanations, no comments, no full files generations allowed. +- Max code generation is 5 lines. +- Match the language and indentation. +]] +} -local get_server_url = function() -- get server url from the user or from the config - return "http://127.0.0.1:8080/v1/chat/completions" -end -]] - }, - { - role = "assistant", - content = [[vim.b.request_sent = false]] - }, - { - role = "user", - content = [[# language: lua - -function M.get_trigger_characters() - return { '@', '.', '(', '[', ':', ' ' } +---@param language string +local function fewshot_lang_parser(language) + -- need to add path for the language context file + if language == "text" then + return { + { + role = "user", + content = "nothing" + }, + { + role = "assistant", + content = "ok nothing" + } + } + end end -function M.get_capabilities() +---@param language string +---@param ctx table +---@return Singlefewshot +local function generate_context_shot(language, ctx) + local langprompt = "#language: " .. language return { - completion - triggerCharacters = M.get_trigger_characters(), - }, + role = "user", + content = langprompt .. "\n" .. ctx } end - -]] - }, - { - role = "assistant", - content = [[ completionProvider = {]] - }, - { - role = "user", - content = [[# language: Python - -from transformers.agents import ( - ReactCodeAgent, - ReactJsonAgent, - HfApiEngine, - ManagedAgent, -) -from transformers.agents.search import DuckDuckGoSearchTool - -llm_engine = - -web_agent = ReactJsonAgent( - tools=[DuckDuckGoSearchTool(), visit_webpage], - llm_engine=llm_engine, - max_iterations=10, -) - -]] - }, - { - role = "assistant", - content = [[llm_engine = HfApiEngine(model)]] - }, -} -M.default = { - role = "user", - content = [[Act as GitHub Copilot. Complete the code where the token is. +---@class Singlefewshot +---@field role string<"user" | "assistant" | "model" | "system"> +---@field content string -Instructions: -- Output only the code that replaces . -- No explanations, no comments, no full file. -- Limit to ≤ 2 lines of code. -- Match language and indentation. -]] -} +---@class Prompt +---@field systemrole Singlefewshot +---@field fewshots table +---@field language string +---@field context Singlefewshot --- After this we are going to collect and send new data -M.closecall_suggestions = { - role = "user", - content = [[You almost got the right answer, try again with a different but similar result, -]] -} +---@param ctx any +---@return Prompt +function M.default_prompt(ctx) + return { + systemrole = M.default_systemrole, + fewshots = fewshot_lang_parser("text"), + language = "text", + context = ctx + } +end -M.wrong_suggestion = { - role = "user", - content = [[You predicted wrong, try again with a completly new suggestion but with this code -]] -} +function M.generate_prompt(language, ctx) + local prompt = M.default_prompt(ctx) + prompt.context = generate_context_shot(language, ctx) + if language == "text" or language == "" then + return prompt + end + local fewshots = fewshot_lang_parser(language) + prompt.fewshots = fewshots + prompt.language = language + return prompt +end return M