Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions data/bash.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
@role: user
@content:
<code>
local utils = require("module.utils")
local M = {}
utils.pa<missing>
return M
</code>

@role: assistant
@content:
<code>
utils.parser
</code>
Empty file added data/go.txt
Empty file.
Empty file added data/java.txt
Empty file.
Empty file added data/javascript.txt
Empty file.
Empty file added data/lua.txt
Empty file.
Empty file added data/markdown.txt
Empty file.
Empty file added data/miscellanious.txt
Empty file.
Empty file added data/python.txt
Empty file.
Empty file added data/rust.txt
Empty file.
2 changes: 2 additions & 0 deletions lua/model_cmp/commands.lua
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ local function modelcmp_stop()
vim.api.nvim_create_user_command('ModelCmpStop', function()
common.stop()
logger.debugging("Stopped api")
api.stop()
end, {})
end

Expand All @@ -103,6 +104,7 @@ local function modelcmp_logs()
vim.api.nvim_buf_set_option(newbuf, 'swapfile', false) -- No swap file
vim.api.nvim_buf_set_lines(newbuf, 0, -1, false, logger.Logs)
vim.api.nvim_buf_set_option(newbuf, 'modifiable', false) -- Make it read-only
-- Todo add logger
end, {})
end

Expand Down
28 changes: 14 additions & 14 deletions lua/model_cmp/modelapi/common.lua
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
local apiconfig = require("model_cmp.modelapi.apiconfig")
local context = require("model_cmp.context")
local logger = require("model_cmp.logger")
local preprompt = require("model_cmp.modelapi.prompt")
local prompter = require("model_cmp.modelapi.prompt")
local req = require("model_cmp.modelapi.request")
local utils = require("model_cmp.utils")
local virtualtext = require("model_cmp.virtualtext")
Expand All @@ -11,7 +11,7 @@ local llama = require("model_cmp.modelapi.llama")

local M = {}

vim.g.server = "local_llama"
vim.g.model_cmp_connection_server = nil

local available_keys = {
GEMINI_API_KEY = 0
Expand All @@ -35,25 +35,25 @@ local function check_available()
end

function M.send_request()
local bufnr = context.ContextEngine.bufnr
local currlang = context.ContextEngine.currlang
local ctx = context.generate_context_text()
local few_shots = preprompt.complete_few_shots

local prompt = prompter.generate_prompt("text", ctx)
local request

local server = vim.g.server
if vim.g.server == "" or vim.g.server == nil then
logger.error("NO server setup")
local server = vim.g.model_cmp_connection_server
if server == nil then
logger.trace("NO server setup")
return
end
if server == "default" or server == "local_llama" then
request = llama.generate_request(few_shots, ctx)
if server == "local_llama" then
request = llama.generate_request(prompt)
elseif server == "gemini" then
if available_keys.GEMINI_API_KEY ~= 1 then
logger.error("GEMINI_API_KEY is not set")
return
end
local systemprompt = preprompt.default.content
request = gemini.generate_request(few_shots, systemprompt, ctx)
request = gemini.generate_request(prompt)
end

if request == nil then
Expand All @@ -63,8 +63,8 @@ function M.send_request()
req.send(request,
function(response)
vim.schedule(function()
local text = ""
if vim.g.server == "gemini" then
local text = nil
if vim.g.model_cmp_connection_server == "gemini" then
text = gemini.decode_response(response)
else
text = utils.decode_response(response)
Expand All @@ -79,7 +79,7 @@ function M.send_request()
end

function M.stop()
vim.g.server = ""
vim.g.model_cmp_connection_server = nil
end

function M.setup(config)
Expand Down
19 changes: 11 additions & 8 deletions lua/model_cmp/modelapi/gemini.lua
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@ local function generate_url(model_name)
end

function M.start(model_name)
vim.g.server = "gemini"
vim.g.model_cmp_connection_server = "gemini"
end

local function transform_ctx_messages(ctx_messages)
-- Transforming few shot messages
---@param prompt Prompt
local function transform_fewshots(prompt)
local new_chat = {}
for _, msg in ipairs(ctx_messages) do
for _, msg in ipairs(prompt.fewshots) do
local gemini_message = {}

if msg.role == 'user' then
gemini_message = {
role = 'user',
Expand All @@ -45,17 +46,19 @@ local function transform_ctx_messages(ctx_messages)
},
}
end

table.insert(new_chat, gemini_message)
end
return new_chat
end

function M.generate_request(ctx_messages, content, mainctx)
local messages = transform_ctx_messages(ctx_messages)
---@param prompt Prompt
function M.generate_request(prompt)
local messages = transform_fewshots(prompt)
local mainmsg = {
role = 'user',
parts = {
{ text = mainctx },
{ text = prompt.context.content },
},
}

Expand All @@ -71,7 +74,7 @@ function M.generate_request(ctx_messages, content, mainctx)
vim.fn.json_encode({
system_instruction = {
parts = {
text = content
text = prompt.systemrole.content
}
},
contents = messages,
Expand Down
18 changes: 15 additions & 3 deletions lua/model_cmp/modelapi/llama.lua
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,23 @@ local generate_url = function(custom_url)
end

function M.start(model_name)
vim.g.server = ""
vim.g.model_cmp_connection_server = "local_llama"
end

function M.generate_request(ctx_messages, ctx)
---@param prompt Prompt
function M.generate_request(prompt)
local custom = apiconfig.default()

local messages = { prompt.systemrole }
for _, k in ipairs(prompt.fewshots) do
table.insert(messages, k)
end
local context = {
role = "user",
content = prompt.context.content,
}
table.insert(messages, context)

local request = {
"-s",
"-X", "POST",
Expand All @@ -21,7 +33,7 @@ function M.generate_request(ctx_messages, ctx)
"-d",
vim.fn.json_encode({
model = "llama",
messages = ctx_messages,
messages = messages,
n_predict = 128,
temperature = 0.1,
stop = { "</s>" },
Expand Down
144 changes: 61 additions & 83 deletions lua/model_cmp/modelapi/prompt.lua
Original file line number Diff line number Diff line change
@@ -1,99 +1,77 @@
local M = {}

M.complete_few_shots = {
{
role = "user",
content = [[# language: lua
<code>
local ghosttext = require("model_cmp.ghosttext")
local utils = require("model_cmp.utils")

local M = {}

vim.b.requ<missing>
M.default_systemrole = {
role = "system",
content = [[Act as GitHub Copilot. Complete the code where the <missing> token is.
Follow the instructions:
- Output only the current line after replacing the <missing> tag.
- No explanations, no comments, no full files generations allowed.
- Max code generation is 5 lines.
- Match the language and indentation.
]]
}

local get_server_url = function() -- get server url from the user or from the config
return "http://127.0.0.1:8080/v1/chat/completions"
end
</code>]]
},
{
role = "assistant",
content = [[vim.b.request_sent = false]]
},
{
role = "user",
content = [[# language: lua
<code>
function M.get_trigger_characters()
return { '@', '.', '(', '[', ':', ' ' }
---@param language string
local function fewshot_lang_parser(language)
-- need to add path for the language context file
if language == "text" then
return {
{
role = "user",
content = "nothing"
},
{
role = "assistant",
content = "ok nothing"
}
}
end
end

function M.get_capabilities()
---@param language string
---@param ctx table<string>
---@return Singlefewshot
local function generate_context_shot(language, ctx)
local langprompt = "#language: " .. language
return {
completion<missing>
triggerCharacters = M.get_trigger_characters(),
},
role = "user",
content = langprompt .. "\n" .. ctx
}
end
</code>
]]
},
{
role = "assistant",
content = [[ completionProvider = {]]
},
{
role = "user",
content = [[# language: Python
<code>
from transformers.agents import (
ReactCodeAgent,
ReactJsonAgent,
HfApiEngine,
ManagedAgent,
)
from transformers.agents.search import DuckDuckGoSearchTool

llm_engine = <missing>

web_agent = ReactJsonAgent(
tools=[DuckDuckGoSearchTool(), visit_webpage],
llm_engine=llm_engine,
max_iterations=10,
)
</code>
]]
},
{
role = "assistant",
content = [[llm_engine = HfApiEngine(model)]]
},
}

M.default = {
role = "user",
content = [[Act as GitHub Copilot. Complete the code where the <missing> token is.
---@class Singlefewshot
---@field role string<"user" | "assistant" | "model" | "system">
---@field content string

Instructions:
- Output only the code that replaces <missing>.
- No explanations, no comments, no full file.
- Limit to ≤ 2 lines of code.
- Match language and indentation.
]]
}
---@class Prompt
---@field systemrole Singlefewshot
---@field fewshots table<Singlefewshot>
---@field language string
---@field context Singlefewshot

-- After this we are going to collect and send new data
M.closecall_suggestions = {
role = "user",
content = [[You almost got the right answer, try again with a different but similar result,
]]
}
---@param ctx any
---@return Prompt
function M.default_prompt(ctx)
return {
systemrole = M.default_systemrole,
fewshots = fewshot_lang_parser("text"),
language = "text",
context = ctx
}
end

M.wrong_suggestion = {
role = "user",
content = [[You predicted wrong, try again with a completly new suggestion but with this code
]]
}
function M.generate_prompt(language, ctx)
local prompt = M.default_prompt(ctx)
prompt.context = generate_context_shot(language, ctx)
if language == "text" or language == "" then
return prompt
end
local fewshots = fewshot_lang_parser(language)
prompt.fewshots = fewshots
prompt.language = language
return prompt
end

return M