diff --git a/lua/model_cmp/config.lua b/lua/model_cmp/config.lua index 376a781..f8fa8c9 100644 --- a/lua/model_cmp/config.lua +++ b/lua/model_cmp/config.lua @@ -1,25 +1,16 @@ +local apiconfig = require("model_cmp.modelapi.apiconfig") + local M = {} ---@class ModelCmp.Config ----@field delay integer delay between each request in ms ----@field api table> API request or server request config +---@field api ModelCmp.Modelapi.Config ---@field virtualtext table virtual text config ----@field prompt Prompt prompts for api - ----@class Prompt ----@field basic_template string ----@field rules string ----@field language string ----@field precontext string ---@return ModelCmp.Config function M.default() return { - api = { - url = "", -- url to the server, defaults are already set, you just need to setup this up only if the url is different for your server - key = "", -- None if using local - type = "" -- EX: OPENAI, Claude, Gemini, llama.cpp(local or none are also valid) - }, + api = apiconfig.default(), + virtualtext = { enable = false, type = "inline", diff --git a/lua/model_cmp/modelapi/apiconfig.lua b/lua/model_cmp/modelapi/apiconfig.lua new file mode 100644 index 0000000..bb34a18 --- /dev/null +++ b/lua/model_cmp/modelapi/apiconfig.lua @@ -0,0 +1,30 @@ +local M = {} + +---@class ModelCmp.Modelapi.Config +---@field apikeys APIKeyHolder +---@field custom_url table + +---@class APIKeyHolder +---@field OPENAI_API_KEY string +---@field CLAUDE_API_KEY string + +---@return APIKeyHolder +local function get_apikeys() + return { + OPENAI_API_KEY = "", + CLAUDE_API_KEY = "" + } +end + +---@return ModelCmp.Modelapi.Config +function M.default() + return { + apikeys = get_apikeys(), + custom_url = { + url = "http://127.0.0.1", + port = "8080" + } + } +end + +return M diff --git a/lua/model_cmp/modelapi/common.lua b/lua/model_cmp/modelapi/common.lua new file mode 100644 index 0000000..4bb08f5 --- /dev/null +++ b/lua/model_cmp/modelapi/common.lua @@ -0,0 +1,111 @@ +local req = require("model_cmp.modelapi.request") +local virtualtext = require("model_cmp.virtualtext") +local utils = require("model_cmp.utils") +local apiconfig = require("model_cmp.modelapi.apiconfig") + +-- server channels +local llama = require("model_cmp.modelapi.llama") + +local M = {} + +vim.g.server = "url" -- llama is default available options are openai, claude + +-- 0 means not avaiable and 1 means avaiable +local available_keys = { + OPENAI_API_KEY = 0, + CLAUDE_API_KEY = 0, +} + +local function servername_to_key() + local server = vim.g.server + if server == "openai" then + return "OPENAI_API_KEY" + elseif server == "claude" then + return "CLAUDE_API_KEY" + else + return + end +end + + +--Check for availability for both apikeys and server urls +local function check_available() + for keyname, key in ipairs(M.apikeys) do + if key ~= "" then + available_keys[keyname] = 1 + end + end + if M.custom_url ~= nil then + if M.custom_url.url == "" or M.custom_url.port == "" then + M.custom_url = { url = "http://127.0.0.1", port = "8080" } + end + else + M.custom_url = apiconfig.default().custom_url + end +end + +M.requests = {} -- only store buffer id + +local function add_request(bufid) + local index = #M.requests + 1 + table.insert(M.requests, bufid) + return index +end + +local function remove_request(index) + table.remove(M.requests, index) +end + +-- we will check if there is a request already made for the given buffer +local function check_already_requested(bufnr) + for buffer in pairs(M.requests) do + if bufnr == buffer then + return true + end + end + return false +end + +function M.send_request() + local request, bufnr + + local server = vim.g.server + if server == "url" then + bufnr, request = llama.generate_request() + elseif server == "openai" then + if available_keys[servername_to_key()] then + -- Working on openai services + end + elseif server == "claude" then + if available_keys[servername_to_key()] then + -- Working on claude services + end + end + + if request ~= nil then + return + end + if not check_already_requested(bufnr) then + return + end + + add_request(bufnr) + req.send(request, + function(response) + vim.schedule(function() + local text = utils.decode_response(response) + virtualtext.VirtualText:update_preview(text) + remove_request(bufnr) + end) + end + ) +end + +function M.setup(config) + local api = config.api + M.apikeys = api.apikeys + M.custom_url = api.custom_url + check_available() +end + +return M diff --git a/lua/model_cmp/modelapi/llama.lua b/lua/model_cmp/modelapi/llama.lua index 810211e..7ca6234 100644 --- a/lua/model_cmp/modelapi/llama.lua +++ b/lua/model_cmp/modelapi/llama.lua @@ -1,27 +1,25 @@ -local curl = require("model_cmp.modelapi.curl") local context = require("model_cmp.context") local systemprompt = require("model_cmp.modelapi.prompt") -local virtualtext = require("model_cmp.virtualtext") -local utils = require("model_cmp.utils") +local apiconfig = require("model_cmp.modelapi.apiconfig") local M = {} vim.b.request_sent = false -local get_server_url = function() -- get server url from the user or from the config - return "http://127.0.0.1:8080/v1/chat/completions" +local generate_url = function(custom_url) + local url = custom_url.url .. ":" .. custom_url.port .. "/v1/chat/completions" + return url end -function M.check_server_running() -end - -function M.send_request() +function M.generate_request() + local bufnr = context.ContextEngine.bufnr local prompt = context.generate_context_text() local lang = context.ContextEngine:get_currlang() local complete_prompt = "# language: " .. lang .. prompt local few_shots = systemprompt.complete_few_shots + local custom = apiconfig.default() local messages = {} table.insert(messages, systemprompt.default) for _, msg in ipairs(few_shots) do @@ -32,7 +30,7 @@ function M.send_request() local request = { "-s", "-X", "POST", - get_server_url(), + generate_url(), "-H", "Content-Type: application/json", "-d", vim.fn.json_encode({ @@ -44,27 +42,7 @@ function M.send_request() max_token = 50 }), } - vim.b.request_sent = true - curl.send(request, - function(response) - vim.schedule(function() - local text = utils.decode_response(response) - virtualtext.VirtualText:update_preview(text) - vim.b.request_sent = false - end) - end - ) -end - ---- TEMPORARY ACTIONS - - -function M.text_changed() - if vim.b.request_sent then - return - end - virtualtext.action.clear_preview() - M.send_request() + return bufnr, request end return M diff --git a/lua/model_cmp/modelapi/managekey.lua b/lua/model_cmp/modelapi/managekey.lua deleted file mode 100644 index 186a9a7..0000000 --- a/lua/model_cmp/modelapi/managekey.lua +++ /dev/null @@ -1,49 +0,0 @@ -local os = os -local M = {} - -M.API_KEYS = { - OPENAI = "", -} - -M.API_KEY_TYPE = { - "OPENAI_API_KEY", -} - -function M.api_is_set() - if vim.g.API_KEYS == nil then - return false - end - return true -end - -function M.is_available(config) - -- check if the api key is there in config - local available = false - local next = next - local api_keys = config.api_keys - if api_keys ~= nil and next(api_keys) ~= nil then - vim.g.OPENAI_API_KEY = api_keys.OPENAI_API_KEY - available = true - end - - -- check if there is any environment variable for API KEY - for idx, key in pairs(M.API_KEY_TYPE) do - local currkey = os.getenv(key) - if currkey ~= nil then - if key == "OPENAI_API_KEY" then - M.API_KEYS.OPENAI = currkey - available = true - end - end - end - - vim.g.API_KEYS = M.API_KEYS - - return available -end - -function M.get_api_key(keytype) - return vim.g.API_KEYS[keytype] -end - -return M diff --git a/lua/model_cmp/modelapi/curl.lua b/lua/model_cmp/modelapi/request.lua similarity index 73% rename from lua/model_cmp/modelapi/curl.lua rename to lua/model_cmp/modelapi/request.lua index b98261b..f0197ba 100644 --- a/lua/model_cmp/modelapi/curl.lua +++ b/lua/model_cmp/modelapi/request.lua @@ -2,9 +2,8 @@ local Job = require("plenary.job") local M = {} -M.requests = {} -- This can store pending jobs or request data -function M.send(request_args, callback) +function M.send(bufnr, request_args, callback) local result = {} local job = Job:new({ command = "curl", @@ -20,7 +19,6 @@ function M.send(request_args, callback) end, }) job:start() - table.insert(M.requests, job) -- Optional: keep track of active jobs end return M