diff --git a/apps/sim/app/(landing)/models/utils.ts b/apps/sim/app/(landing)/models/utils.ts index fdc0dc548f0..8942dfa948d 100644 --- a/apps/sim/app/(landing)/models/utils.ts +++ b/apps/sim/app/(landing)/models/utils.ts @@ -8,6 +8,9 @@ const PROVIDER_PREFIXES: Record = { bedrock: ['bedrock/'], cerebras: ['cerebras/'], fireworks: ['fireworks/'], + together: ['together/'], + baseten: ['baseten/'], + 'ollama-cloud': ['ollama-cloud/'], groq: ['groq/'], openrouter: ['openrouter/'], vllm: ['vllm/'], diff --git a/apps/sim/app/api/providers/baseten/models/route.test.ts b/apps/sim/app/api/providers/baseten/models/route.test.ts new file mode 100644 index 00000000000..8ada3c427ea --- /dev/null +++ b/apps/sim/app/api/providers/baseten/models/route.test.ts @@ -0,0 +1,252 @@ +/** + * @vitest-environment node + */ +import { createMockRequest } from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockFilterBlacklistedModels, + mockIsProviderBlacklisted, + mockGetBYOKKey, + mockGetSession, + mockGetUserEntityPermissions, + mutableEnv, +} = vi.hoisted(() => ({ + mockFilterBlacklistedModels: vi.fn(), + mockIsProviderBlacklisted: vi.fn(), + mockGetBYOKKey: vi.fn(), + mockGetSession: vi.fn(), + mockGetUserEntityPermissions: vi.fn(), + mutableEnv: { BASETEN_API_KEY: undefined as string | undefined }, +})) + +vi.mock('@/lib/core/config/env', () => ({ env: mutableEnv })) + +vi.mock('@/providers/utils', () => ({ + filterBlacklistedModels: mockFilterBlacklistedModels, + isProviderBlacklisted: mockIsProviderBlacklisted, +})) + +vi.mock('@/lib/api-key/byok', () => ({ + getBYOKKey: mockGetBYOKKey, +})) + +vi.mock('@/lib/auth', () => ({ + getSession: mockGetSession, +})) + +vi.mock('@/lib/workspaces/permissions/utils', () => ({ + getUserEntityPermissions: mockGetUserEntityPermissions, +})) + +import { GET } from '@/app/api/providers/baseten/models/route' + +const BASETEN_MODELS_URL = 'https://inference.baseten.co/v1/models' + +function jsonResponse(body: unknown, init: { ok?: boolean; status?: number } = {}): Response { + const status = init.status ?? 200 + const ok = init.ok ?? (status >= 200 && status < 300) + return { + ok, + status, + statusText: ok ? 'OK' : 'Error', + json: vi.fn(async () => body), + } as unknown as Response +} + +function setEnvKey(value: string | undefined): void { + mutableEnv.BASETEN_API_KEY = value +} + +function authHeaderFromLastFetch(mockFetch: ReturnType): unknown { + const init = mockFetch.mock.calls.at(-1)?.[1] as RequestInit | undefined + return (init?.headers as Record | undefined)?.Authorization +} + +describe('GET /api/providers/baseten/models', () => { + let mockFetch: ReturnType + + beforeEach(() => { + vi.clearAllMocks() + + mockFetch = vi.fn() + vi.stubGlobal('fetch', mockFetch) + + mockIsProviderBlacklisted.mockReturnValue(false) + mockFilterBlacklistedModels.mockImplementation((models: string[]) => models) + mockGetBYOKKey.mockResolvedValue(null) + mockGetSession.mockResolvedValue(null) + mockGetUserEntityPermissions.mockResolvedValue(null) + setEnvKey(undefined) + }) + + it('returns empty models without fetching when the provider is blacklisted', async () => { + mockIsProviderBlacklisted.mockReturnValue(true) + setEnvKey('env-key') + + const res = await GET(createMockRequest('GET')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + expect(mockFetch).not.toHaveBeenCalled() + }) + + it('returns empty models when no workspaceId and no env key are available', async () => { + const res = await GET(createMockRequest('GET')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + expect(mockFetch).not.toHaveBeenCalled() + }) + + it('fetches models with the env key and prefixes each id with baseten/', async () => { + setEnvKey('env-key') + mockFetch.mockResolvedValueOnce( + jsonResponse({ + data: [{ id: 'openai/gpt-oss-120b' }, { id: 'deepseek-ai/DeepSeek-V3' }], + }) + ) + + const res = await GET(createMockRequest('GET')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ + models: ['baseten/openai/gpt-oss-120b', 'baseten/deepseek-ai/DeepSeek-V3'], + }) + + expect(mockFetch).toHaveBeenCalledTimes(1) + const [url, init] = mockFetch.mock.calls[0] + expect(url).toBe(BASETEN_MODELS_URL) + expect((init.headers as Record).Authorization).toBe('Bearer env-key') + }) + + it('uses the BYOK key when workspaceId, session, and permission are present', async () => { + setEnvKey('env-key') + mockGetSession.mockResolvedValue({ user: { id: 'user-1' } }) + mockGetUserEntityPermissions.mockResolvedValue('admin') + mockGetBYOKKey.mockResolvedValue({ apiKey: 'byok-key', isBYOK: true }) + mockFetch.mockResolvedValueOnce(jsonResponse({ data: [{ id: 'model-a' }] })) + + const res = await GET( + createMockRequest('GET', undefined, {}, 'http://localhost:3000/api/test?workspaceId=ws-1') + ) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: ['baseten/model-a'] }) + + expect(mockGetBYOKKey).toHaveBeenCalledWith('ws-1', 'baseten') + expect(authHeaderFromLastFetch(mockFetch)).toBe('Bearer byok-key') + }) + + it('falls back to the env key when there is a workspaceId but no session', async () => { + setEnvKey('env-key') + mockGetSession.mockResolvedValue(null) + mockFetch.mockResolvedValueOnce(jsonResponse({ data: [{ id: 'model-a' }] })) + + const res = await GET( + createMockRequest('GET', undefined, {}, 'http://localhost:3000/api/test?workspaceId=ws-1') + ) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: ['baseten/model-a'] }) + expect(mockGetBYOKKey).not.toHaveBeenCalled() + expect(authHeaderFromLastFetch(mockFetch)).toBe('Bearer env-key') + }) + + it('falls back to the env key when the user lacks workspace permission', async () => { + setEnvKey('env-key') + mockGetSession.mockResolvedValue({ user: { id: 'user-1' } }) + mockGetUserEntityPermissions.mockResolvedValue(null) + mockFetch.mockResolvedValueOnce(jsonResponse({ data: [{ id: 'model-a' }] })) + + const res = await GET( + createMockRequest('GET', undefined, {}, 'http://localhost:3000/api/test?workspaceId=ws-1') + ) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: ['baseten/model-a'] }) + expect(mockGetBYOKKey).not.toHaveBeenCalled() + expect(authHeaderFromLastFetch(mockFetch)).toBe('Bearer env-key') + }) + + it('returns empty models when the upstream responds 401', async () => { + setEnvKey('env-key') + mockFetch.mockResolvedValueOnce(jsonResponse({}, { ok: false, status: 401 })) + + const res = await GET(createMockRequest('GET')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + }) + + it('returns empty models when the upstream responds 500', async () => { + setEnvKey('env-key') + mockFetch.mockResolvedValueOnce(jsonResponse({}, { ok: false, status: 500 })) + + const res = await GET(createMockRequest('GET')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + }) + + it('returns empty models when fetch throws', async () => { + setEnvKey('env-key') + mockFetch.mockRejectedValueOnce(new Error('network down')) + + const res = await GET(createMockRequest('GET')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + }) + + it('returns empty models when the upstream data array is empty', async () => { + setEnvKey('env-key') + mockFetch.mockResolvedValueOnce(jsonResponse({ data: [] })) + + const res = await GET(createMockRequest('GET')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + }) + + it('returns empty models when the upstream omits the data field', async () => { + setEnvKey('env-key') + mockFetch.mockResolvedValueOnce(jsonResponse({ object: 'list' })) + + const res = await GET(createMockRequest('GET')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + }) + + it('dedupes repeated model ids', async () => { + setEnvKey('env-key') + mockFetch.mockResolvedValueOnce( + jsonResponse({ + data: [{ id: 'model-a' }, { id: 'model-a' }, { id: 'model-b' }], + }) + ) + + const res = await GET(createMockRequest('GET')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: ['baseten/model-a', 'baseten/model-b'] }) + }) + + it('drops models removed by the blacklist filter', async () => { + setEnvKey('env-key') + mockFilterBlacklistedModels.mockImplementation((models: string[]) => + models.filter((m) => m !== 'baseten/blocked-model') + ) + mockFetch.mockResolvedValueOnce( + jsonResponse({ + data: [{ id: 'allowed-model' }, { id: 'blocked-model' }], + }) + ) + + const res = await GET(createMockRequest('GET')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: ['baseten/allowed-model'] }) + }) +}) diff --git a/apps/sim/app/api/providers/baseten/models/route.ts b/apps/sim/app/api/providers/baseten/models/route.ts new file mode 100644 index 00000000000..b73a6711421 --- /dev/null +++ b/apps/sim/app/api/providers/baseten/models/route.ts @@ -0,0 +1,93 @@ +import { createLogger } from '@sim/logger' +import { getErrorMessage } from '@sim/utils/errors' +import { type NextRequest, NextResponse } from 'next/server' +import { + basetenProviderModelsQuerySchema, + basetenUpstreamResponseSchema, + providerModelsResponseSchema, +} from '@/lib/api/contracts/providers' +import { validationErrorResponse } from '@/lib/api/server' +import { getBYOKKey } from '@/lib/api-key/byok' +import { getSession } from '@/lib/auth' +import { env } from '@/lib/core/config/env' +import { withRouteHandler } from '@/lib/core/utils/with-route-handler' +import { getUserEntityPermissions } from '@/lib/workspaces/permissions/utils' +import { filterBlacklistedModels, isProviderBlacklisted } from '@/providers/utils' + +const logger = createLogger('BasetenModelsAPI') + +export const GET = withRouteHandler(async (request: NextRequest) => { + if (isProviderBlacklisted('baseten')) { + logger.info('Baseten provider is blacklisted, returning empty models') + return NextResponse.json({ models: [] }) + } + + let apiKey: string | undefined + + const queryValidation = basetenProviderModelsQuerySchema.safeParse({ + workspaceId: request.nextUrl.searchParams.get('workspaceId') ?? undefined, + }) + if (!queryValidation.success) return validationErrorResponse(queryValidation.error) + const { workspaceId } = queryValidation.data + if (workspaceId) { + const session = await getSession() + if (session?.user?.id) { + const permission = await getUserEntityPermissions(session.user.id, 'workspace', workspaceId) + if (permission) { + const byokResult = await getBYOKKey(workspaceId, 'baseten') + if (byokResult) { + apiKey = byokResult.apiKey + } + } + } + } + + if (!apiKey) { + apiKey = env.BASETEN_API_KEY + } + + if (!apiKey) { + logger.info('No Baseten API key available, returning empty models') + return NextResponse.json({ models: [] }) + } + + try { + const response = await fetch('https://inference.baseten.co/v1/models', { + headers: { + Authorization: `Bearer ${apiKey}`, + 'Content-Type': 'application/json', + }, + cache: 'no-store', + }) + + if (!response.ok) { + logger.warn('Failed to fetch Baseten models', { + status: response.status, + statusText: response.statusText, + }) + return NextResponse.json({ models: [] }) + } + + const data = basetenUpstreamResponseSchema.parse(await response.json()) + + const allModels: string[] = [] + for (const model of data.data ?? []) { + allModels.push(`baseten/${model.id}`) + } + + const uniqueModels = Array.from(new Set(allModels)) + const models = filterBlacklistedModels(uniqueModels) + + logger.info('Successfully fetched Baseten models', { + count: models.length, + filtered: uniqueModels.length - models.length, + }) + + return NextResponse.json(providerModelsResponseSchema.parse({ models })) + } catch (error) { + logger.error('Error fetching Baseten models', { + error: getErrorMessage(error, 'Unknown error'), + }) + return NextResponse.json({ models: [] }) + } +}) diff --git a/apps/sim/app/api/providers/ollama-cloud/models/route.test.ts b/apps/sim/app/api/providers/ollama-cloud/models/route.test.ts new file mode 100644 index 00000000000..fd70d1ae8b1 --- /dev/null +++ b/apps/sim/app/api/providers/ollama-cloud/models/route.test.ts @@ -0,0 +1,238 @@ +/** + * @vitest-environment node + */ +import { createMockRequest } from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockFilterBlacklistedModels, + mockIsProviderBlacklisted, + mockGetBYOKKey, + mockGetSession, + mockGetUserEntityPermissions, + mockFetch, +} = vi.hoisted(() => ({ + mockFilterBlacklistedModels: vi.fn(), + mockIsProviderBlacklisted: vi.fn(), + mockGetBYOKKey: vi.fn(), + mockGetSession: vi.fn(), + mockGetUserEntityPermissions: vi.fn(), + mockFetch: vi.fn(), +})) + +vi.mock('@/providers/utils', () => ({ + filterBlacklistedModels: mockFilterBlacklistedModels, + isProviderBlacklisted: mockIsProviderBlacklisted, +})) + +vi.mock('@/lib/api-key/byok', () => ({ + getBYOKKey: mockGetBYOKKey, +})) + +vi.mock('@/lib/auth', () => ({ + getSession: mockGetSession, +})) + +vi.mock('@/lib/workspaces/permissions/utils', () => ({ + getUserEntityPermissions: mockGetUserEntityPermissions, +})) + +import { GET } from '@/app/api/providers/ollama-cloud/models/route' + +const OLLAMA_CLOUD_TAGS_URL = 'https://ollama.com/api/tags' + +const okResponse = (body: unknown) => ({ + ok: true, + status: 200, + statusText: 'OK', + json: vi.fn().mockResolvedValue(body), +}) + +const errorResponse = (status: number, statusText = 'Unauthorized') => ({ + ok: false, + status, + statusText, + json: vi.fn().mockResolvedValue({}), +}) + +/** + * Builds a request whose query string carries the given workspaceId. Passing + * `undefined` omits the param entirely; passing `''` produces `?workspaceId=`. + */ +const requestWithWorkspace = (workspaceId?: string) => { + const url = new URL('http://localhost:3000/api/providers/ollama-cloud/models') + if (workspaceId !== undefined) { + url.searchParams.set('workspaceId', workspaceId) + } + return createMockRequest('GET', undefined, {}, url.toString()) +} + +const fetchAuthHeader = () => { + const init = mockFetch.mock.calls[0]?.[1] as RequestInit | undefined + const headers = init?.headers as Record | undefined + return headers?.Authorization +} + +/** Grants a session + workspace permission so the BYOK lookup is reached. */ +const grantWorkspaceAccess = () => { + mockGetSession.mockResolvedValue({ user: { id: 'user-1' } }) + mockGetUserEntityPermissions.mockResolvedValue('admin') +} + +describe('GET /api/providers/ollama-cloud/models', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.stubGlobal('fetch', mockFetch) + + mockIsProviderBlacklisted.mockReturnValue(false) + mockFilterBlacklistedModels.mockImplementation((models: string[]) => models) + mockGetBYOKKey.mockResolvedValue(null) + mockGetSession.mockResolvedValue(null) + mockGetUserEntityPermissions.mockResolvedValue(null) + }) + + it('returns empty models without calling fetch when the provider is blacklisted', async () => { + mockIsProviderBlacklisted.mockReturnValue(true) + + const res = await GET(requestWithWorkspace('ws-1')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + expect(mockFetch).not.toHaveBeenCalled() + }) + + it('returns empty models when there is no workspaceId (BYOK only, no env fallback)', async () => { + const res = await GET(requestWithWorkspace()) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + expect(mockFetch).not.toHaveBeenCalled() + expect(mockGetBYOKKey).not.toHaveBeenCalled() + }) + + it('returns empty models when the workspace has no stored BYOK key (never falls back to a hosted key)', async () => { + grantWorkspaceAccess() + mockGetBYOKKey.mockResolvedValue(null) + + const res = await GET(requestWithWorkspace('ws-1')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + expect(mockGetBYOKKey).toHaveBeenCalledWith('ws-1', 'ollama-cloud') + expect(mockFetch).not.toHaveBeenCalled() + }) + + it('fetches /api/tags with the BYOK key and prefixes each model name with ollama-cloud/', async () => { + grantWorkspaceAccess() + mockGetBYOKKey.mockResolvedValue({ apiKey: 'byok-ollama-key' }) + mockFetch.mockResolvedValue( + okResponse({ + models: [{ name: 'gpt-oss:120b' }, { name: 'deepseek-v3.1:671b' }], + }) + ) + + const res = await GET(requestWithWorkspace('ws-1')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ + models: ['ollama-cloud/gpt-oss:120b', 'ollama-cloud/deepseek-v3.1:671b'], + }) + + expect(mockFetch).toHaveBeenCalledTimes(1) + expect(mockFetch.mock.calls[0][0]).toBe(OLLAMA_CLOUD_TAGS_URL) + expect(fetchAuthHeader()).toBe('Bearer byok-ollama-key') + }) + + it('does not call getBYOKKey when there is a workspaceId but no session', async () => { + mockGetSession.mockResolvedValue(null) + + const res = await GET(requestWithWorkspace('ws-1')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + expect(mockGetBYOKKey).not.toHaveBeenCalled() + expect(mockFetch).not.toHaveBeenCalled() + }) + + it('does not call getBYOKKey when the session user lacks workspace permission', async () => { + mockGetSession.mockResolvedValue({ user: { id: 'user-1' } }) + mockGetUserEntityPermissions.mockResolvedValue(null) + + const res = await GET(requestWithWorkspace('ws-1')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + expect(mockGetBYOKKey).not.toHaveBeenCalled() + expect(mockFetch).not.toHaveBeenCalled() + }) + + it('returns empty models when the upstream fetch responds non-ok', async () => { + grantWorkspaceAccess() + mockGetBYOKKey.mockResolvedValue({ apiKey: 'byok-ollama-key' }) + mockFetch.mockResolvedValue(errorResponse(401, 'Unauthorized')) + + const res = await GET(requestWithWorkspace('ws-1')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + }) + + it('returns empty models when the upstream fetch throws', async () => { + grantWorkspaceAccess() + mockGetBYOKKey.mockResolvedValue({ apiKey: 'byok-ollama-key' }) + mockFetch.mockRejectedValue(new Error('network down')) + + const res = await GET(requestWithWorkspace('ws-1')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + }) + + it('returns a validation error for an empty workspaceId query param', async () => { + const res = await GET(requestWithWorkspace('')) + + expect(res.status).toBe(400) + const body = (await res.json()) as { error: string } + expect(body.error).toBe('Validation error') + expect(mockFetch).not.toHaveBeenCalled() + }) + + it('dedupes duplicate model names from the upstream response', async () => { + grantWorkspaceAccess() + mockGetBYOKKey.mockResolvedValue({ apiKey: 'byok-ollama-key' }) + mockFetch.mockResolvedValue( + okResponse({ + models: [{ name: 'gpt-oss:120b' }, { name: 'gpt-oss:120b' }, { name: 'qwen3-coder:480b' }], + }) + ) + + const res = await GET(requestWithWorkspace('ws-1')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ + models: ['ollama-cloud/gpt-oss:120b', 'ollama-cloud/qwen3-coder:480b'], + }) + }) + + it('applies the blacklist filter to the deduped model list', async () => { + grantWorkspaceAccess() + mockGetBYOKKey.mockResolvedValue({ apiKey: 'byok-ollama-key' }) + mockFilterBlacklistedModels.mockImplementation((models: string[]) => + models.filter((m) => !m.includes('qwen')) + ) + mockFetch.mockResolvedValue( + okResponse({ + models: [{ name: 'gpt-oss:120b' }, { name: 'qwen3-coder:480b' }], + }) + ) + + const res = await GET(requestWithWorkspace('ws-1')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: ['ollama-cloud/gpt-oss:120b'] }) + expect(mockFilterBlacklistedModels).toHaveBeenCalledWith([ + 'ollama-cloud/gpt-oss:120b', + 'ollama-cloud/qwen3-coder:480b', + ]) + }) +}) diff --git a/apps/sim/app/api/providers/ollama-cloud/models/route.ts b/apps/sim/app/api/providers/ollama-cloud/models/route.ts new file mode 100644 index 00000000000..bd5673e0842 --- /dev/null +++ b/apps/sim/app/api/providers/ollama-cloud/models/route.ts @@ -0,0 +1,91 @@ +import { createLogger } from '@sim/logger' +import { getErrorMessage } from '@sim/utils/errors' +import { type NextRequest, NextResponse } from 'next/server' +import { + ollamaCloudProviderModelsQuerySchema, + ollamaUpstreamResponseSchema, + providerModelsResponseSchema, +} from '@/lib/api/contracts/providers' +import { validationErrorResponse } from '@/lib/api/server' +import { getBYOKKey } from '@/lib/api-key/byok' +import { getSession } from '@/lib/auth' +import { withRouteHandler } from '@/lib/core/utils/with-route-handler' +import { getUserEntityPermissions } from '@/lib/workspaces/permissions/utils' +import { filterBlacklistedModels, isProviderBlacklisted } from '@/providers/utils' + +const logger = createLogger('OllamaCloudModelsAPI') + +/** + * Get available Ollama Cloud models. + * + * Ollama Cloud is BYOK-only — Sim never supplies a hosted key and never bills + * usage. Models are listed only when the workspace has stored its own Ollama + * API key, which is used to authenticate against the cloud `/api/tags` endpoint. + */ +export const GET = withRouteHandler(async (request: NextRequest) => { + if (isProviderBlacklisted('ollama-cloud')) { + logger.info('Ollama Cloud provider is blacklisted, returning empty models') + return NextResponse.json({ models: [] }) + } + + const queryValidation = ollamaCloudProviderModelsQuerySchema.safeParse({ + workspaceId: request.nextUrl.searchParams.get('workspaceId') ?? undefined, + }) + if (!queryValidation.success) return validationErrorResponse(queryValidation.error) + const { workspaceId } = queryValidation.data + + let apiKey: string | undefined + if (workspaceId) { + const session = await getSession() + if (session?.user?.id) { + const permission = await getUserEntityPermissions(session.user.id, 'workspace', workspaceId) + if (permission) { + const byokResult = await getBYOKKey(workspaceId, 'ollama-cloud') + if (byokResult) { + apiKey = byokResult.apiKey + } + } + } + } + + if (!apiKey) { + logger.info('No Ollama Cloud API key available, returning empty models') + return NextResponse.json({ models: [] }) + } + + try { + const response = await fetch('https://ollama.com/api/tags', { + headers: { + Authorization: `Bearer ${apiKey}`, + 'Content-Type': 'application/json', + }, + cache: 'no-store', + }) + + if (!response.ok) { + logger.warn('Failed to fetch Ollama Cloud models', { + status: response.status, + statusText: response.statusText, + }) + return NextResponse.json({ models: [] }) + } + + const data = ollamaUpstreamResponseSchema.parse(await response.json()) + + const allModels = data.models.map((model) => `ollama-cloud/${model.name}`) + const uniqueModels = Array.from(new Set(allModels)) + const models = filterBlacklistedModels(uniqueModels) + + logger.info('Successfully fetched Ollama Cloud models', { + count: models.length, + filtered: uniqueModels.length - models.length, + }) + + return NextResponse.json(providerModelsResponseSchema.parse({ models })) + } catch (error) { + logger.error('Error fetching Ollama Cloud models', { + error: getErrorMessage(error, 'Unknown error'), + }) + return NextResponse.json({ models: [] }) + } +}) diff --git a/apps/sim/app/api/providers/together/models/route.test.ts b/apps/sim/app/api/providers/together/models/route.test.ts new file mode 100644 index 00000000000..ae801bb7c56 --- /dev/null +++ b/apps/sim/app/api/providers/together/models/route.test.ts @@ -0,0 +1,259 @@ +/** + * @vitest-environment node + */ +import { createMockRequest } from '@sim/testing' +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockFilterBlacklistedModels, + mockIsProviderBlacklisted, + mockGetBYOKKey, + mockGetSession, + mockGetUserEntityPermissions, + mockFetch, + mutableEnv, +} = vi.hoisted(() => ({ + mockFilterBlacklistedModels: vi.fn(), + mockIsProviderBlacklisted: vi.fn(), + mockGetBYOKKey: vi.fn(), + mockGetSession: vi.fn(), + mockGetUserEntityPermissions: vi.fn(), + mockFetch: vi.fn(), + mutableEnv: { TOGETHER_API_KEY: undefined as string | undefined }, +})) + +vi.mock('@/providers/utils', () => ({ + filterBlacklistedModels: mockFilterBlacklistedModels, + isProviderBlacklisted: mockIsProviderBlacklisted, +})) + +vi.mock('@/lib/api-key/byok', () => ({ + getBYOKKey: mockGetBYOKKey, +})) + +vi.mock('@/lib/auth', () => ({ + getSession: mockGetSession, +})) + +vi.mock('@/lib/workspaces/permissions/utils', () => ({ + getUserEntityPermissions: mockGetUserEntityPermissions, +})) + +vi.mock('@/lib/core/config/env', () => ({ + env: mutableEnv, +})) + +import { GET } from '@/app/api/providers/together/models/route' + +const TOGETHER_MODELS_URL = 'https://api.together.ai/v1/models' + +const okResponse = (body: unknown) => ({ + ok: true, + status: 200, + statusText: 'OK', + json: vi.fn().mockResolvedValue(body), +}) + +const errorResponse = (status: number, statusText = 'Unauthorized') => ({ + ok: false, + status, + statusText, + json: vi.fn().mockResolvedValue({}), +}) + +/** + * Builds a request whose query string carries the given workspaceId. Passing + * `undefined` omits the param entirely; passing `''` produces `?workspaceId=`. + */ +const requestWithWorkspace = (workspaceId?: string) => { + const url = new URL('http://localhost:3000/api/providers/together/models') + if (workspaceId !== undefined) { + url.searchParams.set('workspaceId', workspaceId) + } + return createMockRequest('GET', undefined, {}, url.toString()) +} + +const fetchAuthHeader = () => { + const init = mockFetch.mock.calls[0]?.[1] as RequestInit | undefined + const headers = init?.headers as Record | undefined + return headers?.Authorization +} + +describe('GET /api/providers/together/models', () => { + beforeEach(() => { + vi.clearAllMocks() + vi.stubGlobal('fetch', mockFetch) + + mutableEnv.TOGETHER_API_KEY = undefined + mockIsProviderBlacklisted.mockReturnValue(false) + mockFilterBlacklistedModels.mockImplementation((models: string[]) => models) + mockGetBYOKKey.mockResolvedValue(null) + mockGetSession.mockResolvedValue(null) + mockGetUserEntityPermissions.mockResolvedValue(null) + }) + + it('returns empty models without calling fetch when the provider is blacklisted', async () => { + mockIsProviderBlacklisted.mockReturnValue(true) + + const res = await GET(requestWithWorkspace()) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + expect(mockFetch).not.toHaveBeenCalled() + }) + + it('returns empty models when there is no workspaceId and no env key', async () => { + const res = await GET(requestWithWorkspace()) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + expect(mockFetch).not.toHaveBeenCalled() + }) + + it('fetches with the env key and prefixes each model id with together/', async () => { + mutableEnv.TOGETHER_API_KEY = 'env-together-key' + mockFetch.mockResolvedValue( + okResponse([{ id: 'moonshotai/Kimi-K2-Instruct' }, { id: 'Qwen/Qwen2.5-72B-Instruct-Turbo' }]) + ) + + const res = await GET(requestWithWorkspace()) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ + models: ['together/moonshotai/Kimi-K2-Instruct', 'together/Qwen/Qwen2.5-72B-Instruct-Turbo'], + }) + + expect(mockFetch).toHaveBeenCalledTimes(1) + expect(mockFetch.mock.calls[0][0]).toBe(TOGETHER_MODELS_URL) + expect(fetchAuthHeader()).toBe('Bearer env-together-key') + }) + + it('uses the BYOK key when a workspace, session, and permission are present', async () => { + mutableEnv.TOGETHER_API_KEY = 'env-together-key' + mockGetSession.mockResolvedValue({ user: { id: 'user-1' } }) + mockGetUserEntityPermissions.mockResolvedValue('admin') + mockGetBYOKKey.mockResolvedValue({ apiKey: 'byok-together-key' }) + mockFetch.mockResolvedValue(okResponse([{ id: 'moonshotai/Kimi-K2-Instruct' }])) + + const res = await GET(requestWithWorkspace('ws-1')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: ['together/moonshotai/Kimi-K2-Instruct'] }) + + expect(mockGetBYOKKey).toHaveBeenCalledWith('ws-1', 'together') + expect(fetchAuthHeader()).toBe('Bearer byok-together-key') + }) + + it('falls back to the env key when a workspaceId is given but there is no session', async () => { + mutableEnv.TOGETHER_API_KEY = 'env-together-key' + mockGetSession.mockResolvedValue(null) + mockFetch.mockResolvedValue(okResponse([{ id: 'moonshotai/Kimi-K2-Instruct' }])) + + const res = await GET(requestWithWorkspace('ws-1')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: ['together/moonshotai/Kimi-K2-Instruct'] }) + expect(mockGetBYOKKey).not.toHaveBeenCalled() + expect(fetchAuthHeader()).toBe('Bearer env-together-key') + }) + + it('falls back to the env key when the session user lacks workspace permission', async () => { + mutableEnv.TOGETHER_API_KEY = 'env-together-key' + mockGetSession.mockResolvedValue({ user: { id: 'user-1' } }) + mockGetUserEntityPermissions.mockResolvedValue(null) + mockFetch.mockResolvedValue(okResponse([{ id: 'moonshotai/Kimi-K2-Instruct' }])) + + const res = await GET(requestWithWorkspace('ws-1')) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: ['together/moonshotai/Kimi-K2-Instruct'] }) + expect(mockGetBYOKKey).not.toHaveBeenCalled() + expect(fetchAuthHeader()).toBe('Bearer env-together-key') + }) + + it('returns empty models when the upstream fetch responds non-ok', async () => { + mutableEnv.TOGETHER_API_KEY = 'env-together-key' + mockFetch.mockResolvedValue(errorResponse(401, 'Unauthorized')) + + const res = await GET(requestWithWorkspace()) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + }) + + it('returns empty models when the upstream fetch throws', async () => { + mutableEnv.TOGETHER_API_KEY = 'env-together-key' + mockFetch.mockRejectedValue(new Error('network down')) + + const res = await GET(requestWithWorkspace()) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: [] }) + }) + + it('returns a validation error for an empty workspaceId query param', async () => { + const res = await GET(requestWithWorkspace('')) + + expect(res.status).toBe(400) + const body = (await res.json()) as { error: string } + expect(body.error).toBe('Validation error') + expect(mockFetch).not.toHaveBeenCalled() + }) + + it('dedupes duplicate model ids from the upstream array', async () => { + mutableEnv.TOGETHER_API_KEY = 'env-together-key' + mockFetch.mockResolvedValue( + okResponse([ + { id: 'moonshotai/Kimi-K2-Instruct' }, + { id: 'moonshotai/Kimi-K2-Instruct' }, + { id: 'Qwen/Qwen2.5-72B-Instruct-Turbo' }, + ]) + ) + + const res = await GET(requestWithWorkspace()) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ + models: ['together/moonshotai/Kimi-K2-Instruct', 'together/Qwen/Qwen2.5-72B-Instruct-Turbo'], + }) + }) + + it('applies the blacklist filter to the deduped model list', async () => { + mutableEnv.TOGETHER_API_KEY = 'env-together-key' + mockFilterBlacklistedModels.mockImplementation((models: string[]) => + models.filter((m) => !m.includes('Qwen')) + ) + mockFetch.mockResolvedValue( + okResponse([{ id: 'moonshotai/Kimi-K2-Instruct' }, { id: 'Qwen/Qwen2.5-72B-Instruct-Turbo' }]) + ) + + const res = await GET(requestWithWorkspace()) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ models: ['together/moonshotai/Kimi-K2-Instruct'] }) + expect(mockFilterBlacklistedModels).toHaveBeenCalledWith([ + 'together/moonshotai/Kimi-K2-Instruct', + 'together/Qwen/Qwen2.5-72B-Instruct-Turbo', + ]) + }) + + it('filters out non-chat model types (image, embedding, rerank, etc.)', async () => { + mutableEnv.TOGETHER_API_KEY = 'env-together-key' + mockFetch.mockResolvedValue( + okResponse([ + { id: 'meta-llama/Llama-3.3-70B-Instruct-Turbo', type: 'chat' }, + { id: 'black-forest-labs/FLUX.1-schnell', type: 'image' }, + { id: 'BAAI/bge-large-en-v1.5', type: 'embedding' }, + { id: 'Salesforce/Llama-Rank-V1', type: 'rerank' }, + { id: 'openai/whisper-large-v3', type: 'transcribe' }, + ]) + ) + + const res = await GET(requestWithWorkspace()) + + expect(res.status).toBe(200) + expect(await res.json()).toEqual({ + models: ['together/meta-llama/Llama-3.3-70B-Instruct-Turbo'], + }) + }) +}) diff --git a/apps/sim/app/api/providers/together/models/route.ts b/apps/sim/app/api/providers/together/models/route.ts new file mode 100644 index 00000000000..dcaa0dc0c5a --- /dev/null +++ b/apps/sim/app/api/providers/together/models/route.ts @@ -0,0 +1,105 @@ +import { createLogger } from '@sim/logger' +import { getErrorMessage } from '@sim/utils/errors' +import { type NextRequest, NextResponse } from 'next/server' +import { + providerModelsResponseSchema, + togetherProviderModelsQuerySchema, + togetherUpstreamResponseSchema, +} from '@/lib/api/contracts/providers' +import { validationErrorResponse } from '@/lib/api/server' +import { getBYOKKey } from '@/lib/api-key/byok' +import { getSession } from '@/lib/auth' +import { env } from '@/lib/core/config/env' +import { withRouteHandler } from '@/lib/core/utils/with-route-handler' +import { getUserEntityPermissions } from '@/lib/workspaces/permissions/utils' +import { filterBlacklistedModels, isProviderBlacklisted } from '@/providers/utils' + +const logger = createLogger('TogetherModelsAPI') + +/** Together's catalog includes non-text models; only chat models work with chat completions. */ +const NON_CHAT_MODEL_TYPES = new Set([ + 'image', + 'video', + 'audio', + 'transcribe', + 'embedding', + 'moderation', + 'rerank', +]) + +export const GET = withRouteHandler(async (request: NextRequest) => { + if (isProviderBlacklisted('together')) { + logger.info('Together provider is blacklisted, returning empty models') + return NextResponse.json({ models: [] }) + } + + let apiKey: string | undefined + + const queryValidation = togetherProviderModelsQuerySchema.safeParse({ + workspaceId: request.nextUrl.searchParams.get('workspaceId') ?? undefined, + }) + if (!queryValidation.success) return validationErrorResponse(queryValidation.error) + const { workspaceId } = queryValidation.data + if (workspaceId) { + const session = await getSession() + if (session?.user?.id) { + const permission = await getUserEntityPermissions(session.user.id, 'workspace', workspaceId) + if (permission) { + const byokResult = await getBYOKKey(workspaceId, 'together') + if (byokResult) { + apiKey = byokResult.apiKey + } + } + } + } + + if (!apiKey) { + apiKey = env.TOGETHER_API_KEY + } + + if (!apiKey) { + logger.info('No Together API key available, returning empty models') + return NextResponse.json({ models: [] }) + } + + try { + const response = await fetch('https://api.together.ai/v1/models', { + headers: { + Authorization: `Bearer ${apiKey}`, + 'Content-Type': 'application/json', + }, + cache: 'no-store', + }) + + if (!response.ok) { + logger.warn('Failed to fetch Together models', { + status: response.status, + statusText: response.statusText, + }) + return NextResponse.json({ models: [] }) + } + + const data = togetherUpstreamResponseSchema.parse(await response.json()) + + const allModels: string[] = [] + for (const model of data) { + if (model.type && NON_CHAT_MODEL_TYPES.has(model.type)) continue + allModels.push(`together/${model.id}`) + } + + const uniqueModels = Array.from(new Set(allModels)) + const models = filterBlacklistedModels(uniqueModels) + + logger.info('Successfully fetched Together models', { + count: models.length, + filtered: uniqueModels.length - models.length, + }) + + return NextResponse.json(providerModelsResponseSchema.parse({ models })) + } catch (error) { + logger.error('Error fetching Together models', { + error: getErrorMessage(error, 'Unknown error'), + }) + return NextResponse.json({ models: [] }) + } +}) diff --git a/apps/sim/app/workspace/[workspaceId]/providers/provider-models-loader.tsx b/apps/sim/app/workspace/[workspaceId]/providers/provider-models-loader.tsx index f2563a2b37c..0622fafa858 100644 --- a/apps/sim/app/workspace/[workspaceId]/providers/provider-models-loader.tsx +++ b/apps/sim/app/workspace/[workspaceId]/providers/provider-models-loader.tsx @@ -5,10 +5,13 @@ import { createLogger } from '@sim/logger' import { useParams } from 'next/navigation' import { useProviderModels } from '@/hooks/queries/providers' import { + updateBasetenProviderModels, updateFireworksProviderModels, updateLiteLLMProviderModels, + updateOllamaCloudProviderModels, updateOllamaProviderModels, updateOpenRouterProviderModels, + updateTogetherProviderModels, updateVLLMProviderModels, } from '@/providers/utils' import { type ProviderName, useProvidersStore } from '@/stores/providers' @@ -31,6 +34,8 @@ function useSyncProvider(provider: ProviderName, workspaceId?: string) { try { if (provider === 'ollama') { updateOllamaProviderModels(data.models) + } else if (provider === 'ollama-cloud') { + void updateOllamaCloudProviderModels(data.models) } else if (provider === 'vllm') { updateVLLMProviderModels(data.models) } else if (provider === 'litellm') { @@ -42,6 +47,10 @@ function useSyncProvider(provider: ProviderName, workspaceId?: string) { } } else if (provider === 'fireworks') { void updateFireworksProviderModels(data.models) + } else if (provider === 'together') { + void updateTogetherProviderModels(data.models) + } else if (provider === 'baseten') { + void updateBasetenProviderModels(data.models) } } catch (syncError) { logger.warn(`Failed to sync provider definitions for ${provider}`, syncError as Error) @@ -63,9 +72,12 @@ export function ProviderModelsLoader() { useSyncProvider('base') useSyncProvider('ollama') + useSyncProvider('ollama-cloud', workspaceId) useSyncProvider('vllm') useSyncProvider('litellm') useSyncProvider('openrouter') useSyncProvider('fireworks', workspaceId) + useSyncProvider('together', workspaceId) + useSyncProvider('baseten', workspaceId) return null } diff --git a/apps/sim/app/workspace/[workspaceId]/settings/components/byok/byok.tsx b/apps/sim/app/workspace/[workspaceId]/settings/components/byok/byok.tsx index 998f1c5dcf9..d289eea58dd 100644 --- a/apps/sim/app/workspace/[workspaceId]/settings/components/byok/byok.tsx +++ b/apps/sim/app/workspace/[workspaceId]/settings/components/byok/byok.tsx @@ -17,6 +17,7 @@ import { } from '@/components/emcn' import { AnthropicIcon, + BasetenIcon, BrandfetchIcon, ExaAIIcon, FindymailIcon, @@ -29,12 +30,14 @@ import { JinaAIIcon, LinkupIcon, MistralIcon, + OllamaIcon, OpenAIIcon, ParallelIcon, PeopleDataLabsIcon, PerplexityIcon, ProspeoIcon, SerperIcon, + TogetherIcon, WizaIcon, } from '@/components/icons' import { Input } from '@/components/ui' @@ -91,6 +94,27 @@ const PROVIDERS: { description: 'LLM calls', placeholder: 'Enter your Fireworks API key', }, + { + id: 'together', + name: 'Together AI', + icon: TogetherIcon, + description: 'LLM calls', + placeholder: 'Enter your Together AI API key', + }, + { + id: 'baseten', + name: 'Baseten', + icon: BasetenIcon, + description: 'LLM calls', + placeholder: 'Enter your Baseten API key', + }, + { + id: 'ollama-cloud', + name: 'Ollama Cloud', + icon: OllamaIcon, + description: 'LLM calls', + placeholder: 'Enter your Ollama API key', + }, { id: 'falai', name: 'Fal.ai', diff --git a/apps/sim/blocks/utils.ts b/apps/sim/blocks/utils.ts index 4a17b845263..4b1d0b556ed 100644 --- a/apps/sim/blocks/utils.ts +++ b/apps/sim/blocks/utils.ts @@ -50,18 +50,24 @@ export function getModelOptions() { const providersState = useProvidersStore.getState() const baseModels = providersState.providers.base.models const ollamaModels = providersState.providers.ollama.models + const ollamaCloudModels = providersState.providers['ollama-cloud'].models const vllmModels = providersState.providers.vllm.models const litellmModels = providersState.providers.litellm.models const openrouterModels = providersState.providers.openrouter.models const fireworksModels = providersState.providers.fireworks.models + const togetherModels = providersState.providers.together.models + const basetenModels = providersState.providers.baseten.models const allModels = Array.from( new Set([ ...baseModels, ...ollamaModels, + ...ollamaCloudModels, ...vllmModels, ...litellmModels, ...openrouterModels, ...fireworksModels, + ...togetherModels, + ...basetenModels, ]) ) diff --git a/apps/sim/components/icons.tsx b/apps/sim/components/icons.tsx index 2417e6acb58..e2c0e8946a6 100644 --- a/apps/sim/components/icons.tsx +++ b/apps/sim/components/icons.tsx @@ -3998,6 +3998,41 @@ export function OpenRouterIcon(props: SVGProps) { ) } +export function TogetherIcon(props: SVGProps) { + return ( + + + + + + ) +} + +export function BasetenIcon(props: SVGProps) { + return ( + + + + ) +} + export function MondayIcon(props: SVGProps) { return ( m.id) }, ollama: { models: [] }, + 'ollama-cloud': { models: [] }, vllm: { models: [] }, litellm: { models: [] }, openrouter: { models: [] }, fireworks: { models: [] }, + together: { models: [] }, + baseten: { models: [] }, }, } diff --git a/apps/sim/lib/copilot/tools/server/workflow/edit-workflow/validation.ts b/apps/sim/lib/copilot/tools/server/workflow/edit-workflow/validation.ts index e98e39a3967..94010df4479 100644 --- a/apps/sim/lib/copilot/tools/server/workflow/edit-workflow/validation.ts +++ b/apps/sim/lib/copilot/tools/server/workflow/edit-workflow/validation.ts @@ -369,7 +369,7 @@ export function validateValueForSubBlockType( blockType, field: fieldName, value, - error: `Unknown model id "${trimmed}" for block "${blockType}". Read components/blocks/${blockType}.json (the model.options array) for valid ids; prefer entries with recommended: true and avoid deprecated: true. For user-configured models (Ollama, vLLM, LiteLLM, OpenRouter, Fireworks), prefix the id with the provider slash, e.g. "ollama/llama3.1:8b".${suggestionText}`, + error: `Unknown model id "${trimmed}" for block "${blockType}". Read components/blocks/${blockType}.json (the model.options array) for valid ids; prefer entries with recommended: true and avoid deprecated: true. For user-configured models (Ollama, Ollama Cloud, vLLM, LiteLLM, OpenRouter, Fireworks, Together AI, Baseten), prefix the id with the provider slash, e.g. "ollama/llama3.1:8b" or "ollama-cloud/gpt-oss:120b".${suggestionText}`, }, } } diff --git a/apps/sim/lib/core/config/env.ts b/apps/sim/lib/core/config/env.ts index 3b871859295..223eb519524 100644 --- a/apps/sim/lib/core/config/env.ts +++ b/apps/sim/lib/core/config/env.ts @@ -132,6 +132,8 @@ export const env = createEnv({ LITELLM_BASE_URL: z.string().url().optional(), // LiteLLM proxy base URL (OpenAI-compatible) LITELLM_API_KEY: z.string().optional(), // Optional bearer token for LiteLLM FIREWORKS_API_KEY: z.string().optional(), // Optional Fireworks AI API key for model listing + TOGETHER_API_KEY: z.string().optional(), // Optional Together AI API key for model listing and inference + BASETEN_API_KEY: z.string().optional(), // Optional Baseten API key for model listing and inference COHERE_API_KEY: z.string().min(1).optional(), // Cohere API key for reranker (rerank-v4.0-pro, rerank-v4.0-fast, rerank-v3.5) COHERE_API_KEY_1: z.string().min(1).optional(), // Primary Cohere API key for rotation COHERE_API_KEY_2: z.string().min(1).optional(), // Additional Cohere API key for load balancing diff --git a/apps/sim/providers/attachments.ts b/apps/sim/providers/attachments.ts index d1b5d48c828..d9edad96fd5 100644 --- a/apps/sim/providers/attachments.ts +++ b/apps/sim/providers/attachments.ts @@ -22,6 +22,8 @@ export type AttachmentProvider = | 'mistral' | 'groq' | 'fireworks' + | 'together' + | 'baseten' | 'ollama' | 'vllm' | 'litellm' @@ -92,6 +94,8 @@ const PROVIDER_SUPPORTED_LABELS: Record = { mistral: 'images through image_url message parts', groq: 'images through image_url message parts on multimodal models', fireworks: 'images through image_url message parts on vision models', + together: 'images through image_url message parts on vision models', + baseten: 'images through image_url message parts on vision models', ollama: 'images through image_url message parts on vision models', vllm: 'images through image_url message parts on multimodal models', litellm: 'images through image_url message parts on multimodal models', @@ -109,7 +113,9 @@ export function getAttachmentProvider(providerId: ProviderId | string): Attachme if (providerId === 'mistral') return 'mistral' if (providerId === 'groq') return 'groq' if (providerId === 'fireworks') return 'fireworks' - if (providerId === 'ollama') return 'ollama' + if (providerId === 'together') return 'together' + if (providerId === 'baseten') return 'baseten' + if (providerId === 'ollama' || providerId === 'ollama-cloud') return 'ollama' if (providerId === 'vllm') return 'vllm' if (providerId === 'litellm') return 'litellm' if (providerId === 'xai') return 'xai' @@ -248,6 +254,8 @@ function isMimeTypeSupportedByProvider( case 'mistral': case 'groq': case 'fireworks': + case 'together': + case 'baseten': case 'ollama': case 'vllm': case 'litellm': diff --git a/apps/sim/providers/baseten/index.test.ts b/apps/sim/providers/baseten/index.test.ts new file mode 100644 index 00000000000..af5fa39f61b --- /dev/null +++ b/apps/sim/providers/baseten/index.test.ts @@ -0,0 +1,232 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockCreate, + mockSupportsNativeStructuredOutputs, + mockPrepareToolsWithUsageControl, + mockExecuteTool, +} = vi.hoisted(() => ({ + mockCreate: vi.fn(), + mockSupportsNativeStructuredOutputs: vi.fn(), + mockPrepareToolsWithUsageControl: vi.fn(), + mockExecuteTool: vi.fn(), +})) + +vi.mock('openai', () => ({ + default: vi.fn().mockImplementation(() => ({ + chat: { completions: { create: mockCreate } }, + })), +})) + +vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 5 })) + +vi.mock('@/providers/models', () => ({ + getProviderModels: vi.fn().mockReturnValue([]), + getProviderDefaultModel: vi.fn().mockReturnValue('openai/gpt-oss-120b'), +})) + +vi.mock('@/providers/attachments', () => ({ + formatMessagesForProvider: vi.fn((messages) => messages), +})) + +vi.mock('@/providers/baseten/utils', () => ({ + supportsNativeStructuredOutputs: mockSupportsNativeStructuredOutputs, + createReadableStreamFromOpenAIStream: vi.fn(() => ({}) as ReadableStream), + checkForForcedToolUsage: vi.fn(() => ({ hasUsedForcedTool: false, usedForcedTools: [] })), +})) + +vi.mock('@/providers/trace-enrichment', () => ({ + enrichLastModelSegmentFromChatCompletions: vi.fn(), +})) + +vi.mock('@/providers/utils', () => ({ + calculateCost: vi.fn().mockReturnValue({ input: 0, output: 0, total: 0 }), + generateSchemaInstructions: vi.fn(() => 'SCHEMA_INSTRUCTIONS'), + prepareToolExecution: vi.fn(() => ({ toolParams: { x: 1 }, executionParams: { x: 1 } })), + prepareToolsWithUsageControl: mockPrepareToolsWithUsageControl, + sumToolCosts: vi.fn().mockReturnValue(0), +})) + +vi.mock('@/tools', () => ({ executeTool: mockExecuteTool })) + +import { basetenProvider } from '@/providers/baseten/index' +import { ProviderError } from '@/providers/types' + +const textResponse = (content: string) => ({ + choices: [{ message: { content, tool_calls: [] } }], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, +}) + +const toolCallResponse = () => ({ + choices: [ + { + message: { + content: null, + tool_calls: [ + { id: 'call_1', type: 'function', function: { name: 'my_tool', arguments: '{"x":1}' } }, + ], + }, + }, + ], + usage: { prompt_tokens: 8, completion_tokens: 4, total_tokens: 12 }, +}) + +const toolDef = { + id: 'my_tool', + name: 'my_tool', + description: '', + params: {}, + parameters: { type: 'object', properties: {}, required: [] }, +} + +const callBody = (index: number) => mockCreate.mock.calls[index][0] +const lastCallBody = () => mockCreate.mock.calls.at(-1)?.[0] + +describe('basetenProvider', () => { + beforeEach(() => { + vi.clearAllMocks() + mockSupportsNativeStructuredOutputs.mockResolvedValue(true) + mockPrepareToolsWithUsageControl.mockImplementation((tools) => ({ + tools, + toolChoice: 'auto', + forcedTools: [], + })) + mockExecuteTool.mockResolvedValue({ success: true, output: { ok: true } }) + }) + + const baseRequest = { + model: 'baseten/openai/gpt-oss-120b', + systemPrompt: 'You are helpful.', + messages: [{ role: 'user' as const, content: 'Hello' }], + apiKey: 'bt-test-key', + } + + it('throws when the API key is missing', async () => { + await expect( + basetenProvider.executeRequest({ ...baseRequest, apiKey: undefined }) + ).rejects.toThrow('API key is required for Baseten') + }) + + it('returns content and token usage for a simple request', async () => { + mockCreate.mockResolvedValueOnce(textResponse('hi there')) + + const result = await basetenProvider.executeRequest(baseRequest) + + expect(result).toMatchObject({ + content: 'hi there', + model: 'openai/gpt-oss-120b', + tokens: { input: 10, output: 5, total: 15 }, + }) + }) + + it('strips only the leading baseten/ prefix from the model id', async () => { + mockCreate.mockResolvedValueOnce(textResponse('ok')) + + await basetenProvider.executeRequest(baseRequest) + + expect(callBody(0).model).toBe('openai/gpt-oss-120b') + }) + + it('wraps API errors in a ProviderError', async () => { + mockCreate.mockRejectedValueOnce(new Error('boom')) + + await expect(basetenProvider.executeRequest(baseRequest)).rejects.toBeInstanceOf(ProviderError) + }) + + it('streams directly when there are no tools', async () => { + mockCreate.mockResolvedValueOnce({}) + + const result = await basetenProvider.executeRequest({ ...baseRequest, stream: true }) + + expect(lastCallBody()).toMatchObject({ stream: true, stream_options: { include_usage: true } }) + expect(result).toHaveProperty('stream') + expect(result).toHaveProperty('execution') + }) + + it('sends a json_schema response_format with no strict field', async () => { + mockCreate.mockResolvedValueOnce(textResponse('{}')) + + await basetenProvider.executeRequest({ + ...baseRequest, + responseFormat: { name: 'my_schema', schema: { type: 'object' }, strict: true }, + }) + + expect(lastCallBody().response_format).toEqual({ + type: 'json_schema', + json_schema: { name: 'my_schema', schema: { type: 'object' } }, + }) + expect(lastCallBody().response_format.json_schema).not.toHaveProperty('strict') + }) + + it('falls back to json_object with prompt instructions when native is unsupported', async () => { + mockSupportsNativeStructuredOutputs.mockResolvedValue(false) + mockCreate.mockResolvedValueOnce(textResponse('{}')) + + await basetenProvider.executeRequest({ + ...baseRequest, + responseFormat: { name: 'my_schema', schema: { type: 'object' } }, + }) + + expect(lastCallBody().response_format).toEqual({ type: 'json_object' }) + expect(lastCallBody().messages.at(-1)).toEqual({ + role: 'user', + content: 'SCHEMA_INSTRUCTIONS', + }) + }) + + it('defers response_format to a final call when tools are active', async () => { + mockCreate + .mockResolvedValueOnce(textResponse('intermediate')) + .mockResolvedValueOnce(textResponse('{"done":true}')) + + await basetenProvider.executeRequest({ + ...baseRequest, + responseFormat: { name: 'my_schema', schema: { type: 'object' } }, + tools: [toolDef], + }) + + expect(mockCreate).toHaveBeenCalledTimes(2) + expect(callBody(0).response_format).toBeUndefined() + expect(callBody(0).tools).toBeDefined() + expect(callBody(1).response_format).toEqual({ + type: 'json_schema', + json_schema: { name: 'my_schema', schema: { type: 'object' } }, + }) + expect(callBody(1).tools).toBeUndefined() + }) + + it('runs the tool loop and threads tool results back into the conversation', async () => { + mockCreate + .mockResolvedValueOnce(toolCallResponse()) + .mockResolvedValueOnce(textResponse('final answer')) + + const result = await basetenProvider.executeRequest({ ...baseRequest, tools: [toolDef] }) + + expect(mockExecuteTool).toHaveBeenCalledWith('my_tool', { x: 1 }, expect.anything()) + expect(result).toMatchObject({ content: 'final answer' }) + expect((result as { toolCalls?: unknown[] }).toolCalls).toHaveLength(1) + + const followUpMessages = callBody(1).messages + expect(followUpMessages).toContainEqual( + expect.objectContaining({ role: 'assistant', tool_calls: expect.any(Array) }) + ) + expect(followUpMessages).toContainEqual( + expect.objectContaining({ role: 'tool', tool_call_id: 'call_1' }) + ) + }) + + it("forces tool_choice 'none' on the final streaming call after tools run", async () => { + mockCreate + .mockResolvedValueOnce(toolCallResponse()) + .mockResolvedValueOnce(textResponse('done')) + .mockResolvedValueOnce({}) + + await basetenProvider.executeRequest({ ...baseRequest, stream: true, tools: [toolDef] }) + + expect(mockCreate).toHaveBeenCalledTimes(3) + expect(lastCallBody()).toMatchObject({ tool_choice: 'none', stream: true }) + }) +}) diff --git a/apps/sim/providers/baseten/index.ts b/apps/sim/providers/baseten/index.ts new file mode 100644 index 00000000000..a1dd2cfb7c2 --- /dev/null +++ b/apps/sim/providers/baseten/index.ts @@ -0,0 +1,653 @@ +import { createLogger } from '@sim/logger' +import { getErrorMessage, toError } from '@sim/utils/errors' +import OpenAI from 'openai' +import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions' +import type { StreamingExecution } from '@/executor/types' +import { MAX_TOOL_ITERATIONS } from '@/providers' +import { formatMessagesForProvider } from '@/providers/attachments' +import { + checkForForcedToolUsage, + createReadableStreamFromOpenAIStream, + supportsNativeStructuredOutputs, +} from '@/providers/baseten/utils' +import { getProviderDefaultModel, getProviderModels } from '@/providers/models' +import { enrichLastModelSegmentFromChatCompletions } from '@/providers/trace-enrichment' +import type { + FunctionCallResponse, + Message, + ProviderConfig, + ProviderRequest, + ProviderResponse, + TimeSegment, +} from '@/providers/types' +import { ProviderError } from '@/providers/types' +import { + calculateCost, + generateSchemaInstructions, + prepareToolExecution, + prepareToolsWithUsageControl, + sumToolCosts, +} from '@/providers/utils' +import { executeTool } from '@/tools' + +const logger = createLogger('BasetenProvider') + +/** + * Applies structured output configuration to a payload based on model capabilities. + * Uses native json_schema for supported models, falls back to json_object with prompt instructions. + */ +async function applyResponseFormat( + targetPayload: any, + messages: any[], + responseFormat: any, + model: string +): Promise { + const useNative = await supportsNativeStructuredOutputs(model) + + if (useNative) { + logger.info('Using native structured outputs for Baseten model', { model }) + targetPayload.response_format = { + type: 'json_schema', + json_schema: { + name: responseFormat.name || 'response_schema', + schema: responseFormat.schema || responseFormat, + }, + } + return messages + } + + logger.info('Using json_object mode with prompt instructions for Baseten model', { model }) + const schema = responseFormat.schema || responseFormat + const schemaInstructions = generateSchemaInstructions(schema, responseFormat.name) + targetPayload.response_format = { type: 'json_object' } + return [...messages, { role: 'user', content: schemaInstructions }] +} + +export const basetenProvider: ProviderConfig = { + id: 'baseten', + name: 'Baseten', + description: 'Fast inference for open-source models via Baseten Model APIs', + version: '1.0.0', + models: getProviderModels('baseten'), + defaultModel: getProviderDefaultModel('baseten'), + + executeRequest: async ( + request: ProviderRequest + ): Promise => { + if (!request.apiKey) { + throw new Error('API key is required for Baseten') + } + + const client = new OpenAI({ + apiKey: request.apiKey, + baseURL: 'https://inference.baseten.co/v1', + }) + + const requestedModel = request.model.replace(/^baseten\//, '') + + logger.info('Preparing Baseten request', { + model: requestedModel, + hasSystemPrompt: !!request.systemPrompt, + hasMessages: !!request.messages?.length, + hasTools: !!request.tools?.length, + toolCount: request.tools?.length || 0, + hasResponseFormat: !!request.responseFormat, + stream: !!request.stream, + }) + + const allMessages: Message[] = [] + + if (request.systemPrompt) { + allMessages.push({ role: 'system', content: request.systemPrompt }) + } + + if (request.context) { + allMessages.push({ role: 'user', content: request.context }) + } + + if (request.messages) { + allMessages.push(...request.messages) + } + const formattedMessages = formatMessagesForProvider(allMessages, 'baseten') as Message[] + + const tools = request.tools?.length + ? request.tools.map((tool) => ({ + type: 'function', + function: { + name: tool.id, + description: tool.description, + parameters: tool.parameters, + }, + })) + : undefined + + const payload: any = { + model: requestedModel, + messages: formattedMessages, + } + + if (request.temperature !== undefined) payload.temperature = request.temperature + if (request.maxTokens != null) payload.max_tokens = request.maxTokens + + let preparedTools: ReturnType | null = null + let hasActiveTools = false + if (tools?.length) { + preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'baseten') + const { tools: filteredTools, toolChoice } = preparedTools + if (filteredTools?.length && toolChoice) { + payload.tools = filteredTools + payload.tool_choice = toolChoice + hasActiveTools = true + } + } + + const providerStartTime = Date.now() + const providerStartTimeISO = new Date(providerStartTime).toISOString() + + try { + if (request.responseFormat && !hasActiveTools) { + payload.messages = await applyResponseFormat( + payload, + payload.messages, + request.responseFormat, + requestedModel + ) + } + + if (request.stream && (!tools || tools.length === 0 || !hasActiveTools)) { + const streamingParams: ChatCompletionCreateParamsStreaming = { + ...payload, + stream: true, + stream_options: { include_usage: true }, + } + const streamResponse = await client.chat.completions.create( + streamingParams, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) + + const streamingResult = { + stream: createReadableStreamFromOpenAIStream(streamResponse, (content, usage) => { + streamingResult.execution.output.content = content + streamingResult.execution.output.tokens = { + input: usage.prompt_tokens, + output: usage.completion_tokens, + total: usage.total_tokens, + } + + const costResult = calculateCost( + requestedModel, + usage.prompt_tokens, + usage.completion_tokens + ) + streamingResult.execution.output.cost = { + input: costResult.input, + output: costResult.output, + total: costResult.total, + } + + const end = Date.now() + const endISO = new Date(end).toISOString() + if (streamingResult.execution.output.providerTiming) { + streamingResult.execution.output.providerTiming.endTime = endISO + streamingResult.execution.output.providerTiming.duration = end - providerStartTime + if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) { + streamingResult.execution.output.providerTiming.timeSegments[0].endTime = end + streamingResult.execution.output.providerTiming.timeSegments[0].duration = + end - providerStartTime + } + } + }), + execution: { + success: true, + output: { + content: '', + model: requestedModel, + tokens: { input: 0, output: 0, total: 0 }, + toolCalls: undefined, + providerTiming: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + timeSegments: [ + { + type: 'model', + name: request.model, + startTime: providerStartTime, + endTime: Date.now(), + duration: Date.now() - providerStartTime, + }, + ], + }, + cost: { input: 0, output: 0, total: 0 }, + }, + logs: [], + metadata: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + }, + }, + } as StreamingExecution + + return streamingResult as StreamingExecution + } + + const initialCallTime = Date.now() + const originalToolChoice = payload.tool_choice + const forcedTools = preparedTools?.forcedTools || [] + let usedForcedTools: string[] = [] + + let currentResponse = await client.chat.completions.create( + payload, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) + const firstResponseTime = Date.now() - initialCallTime + + let content = currentResponse.choices[0]?.message?.content || '' + const tokens = { + input: currentResponse.usage?.prompt_tokens || 0, + output: currentResponse.usage?.completion_tokens || 0, + total: currentResponse.usage?.total_tokens || 0, + } + const toolCalls: FunctionCallResponse[] = [] + const toolResults: Record[] = [] + const currentMessages = [...formattedMessages] + let iterationCount = 0 + let modelTime = firstResponseTime + let toolsTime = 0 + let hasUsedForcedTool = false + const timeSegments: TimeSegment[] = [ + { + type: 'model', + name: request.model, + startTime: initialCallTime, + endTime: initialCallTime + firstResponseTime, + duration: firstResponseTime, + }, + ] + + const forcedToolResult = checkForForcedToolUsage( + currentResponse, + originalToolChoice, + forcedTools, + usedForcedTools + ) + hasUsedForcedTool = forcedToolResult.hasUsedForcedTool + usedForcedTools = forcedToolResult.usedForcedTools + + while (iterationCount < MAX_TOOL_ITERATIONS) { + if (currentResponse.choices[0]?.message?.content) { + content = currentResponse.choices[0].message.content + } + + const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls + + enrichLastModelSegmentFromChatCompletions( + timeSegments, + currentResponse, + toolCallsInResponse, + { model: request.model, provider: 'baseten' } + ) + + if (!toolCallsInResponse || toolCallsInResponse.length === 0) { + break + } + + const toolsStartTime = Date.now() + + const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => { + const toolCallStartTime = Date.now() + const toolName = toolCall.function.name + + try { + const toolArgs = JSON.parse(toolCall.function.arguments) + const tool = request.tools?.find((t) => t.id === toolName) + + if (!tool) return null + + const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) + const toolCallEndTime = Date.now() + + return { + toolCall, + toolName, + toolParams, + result, + startTime: toolCallStartTime, + endTime: toolCallEndTime, + duration: toolCallEndTime - toolCallStartTime, + } + } catch (error) { + const toolCallEndTime = Date.now() + logger.error('Error processing tool call (Baseten):', { + error: toError(error).message, + toolName, + }) + + return { + toolCall, + toolName, + toolParams: {}, + result: { + success: false, + output: undefined, + error: getErrorMessage(error, 'Tool execution failed'), + }, + startTime: toolCallStartTime, + endTime: toolCallEndTime, + duration: toolCallEndTime - toolCallStartTime, + } + } + }) + + const executionResults = await Promise.allSettled(toolExecutionPromises) + + currentMessages.push({ + role: 'assistant', + content: null, + tool_calls: toolCallsInResponse.map((tc) => ({ + id: tc.id, + type: 'function', + function: { + name: tc.function.name, + arguments: tc.function.arguments, + }, + })), + }) + + for (const settledResult of executionResults) { + if (settledResult.status === 'rejected' || !settledResult.value) continue + + const { toolCall, toolName, toolParams, result, startTime, endTime, duration } = + settledResult.value + + timeSegments.push({ + type: 'tool', + name: toolName, + startTime: startTime, + endTime: endTime, + duration: duration, + toolCallId: toolCall.id, + }) + + let resultContent: any + if (result.success) { + toolResults.push(result.output!) + resultContent = result.output + } else { + resultContent = { + error: true, + message: result.error || 'Tool execution failed', + tool: toolName, + } + } + + toolCalls.push({ + name: toolName, + arguments: toolParams, + startTime: new Date(startTime).toISOString(), + endTime: new Date(endTime).toISOString(), + duration: duration, + result: resultContent, + success: result.success, + }) + + currentMessages.push({ + role: 'tool', + tool_call_id: toolCall.id, + content: JSON.stringify(resultContent), + }) + } + + const thisToolsTime = Date.now() - toolsStartTime + toolsTime += thisToolsTime + + const nextPayload = { + ...payload, + messages: currentMessages, + } + + if (typeof originalToolChoice === 'object' && hasUsedForcedTool && forcedTools.length > 0) { + const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool)) + if (remainingTools.length > 0) { + nextPayload.tool_choice = { type: 'function', function: { name: remainingTools[0] } } + } else { + nextPayload.tool_choice = 'auto' + } + } + + const nextModelStartTime = Date.now() + currentResponse = await client.chat.completions.create( + nextPayload, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) + const nextForcedToolResult = checkForForcedToolUsage( + currentResponse, + nextPayload.tool_choice, + forcedTools, + usedForcedTools + ) + hasUsedForcedTool = nextForcedToolResult.hasUsedForcedTool + usedForcedTools = nextForcedToolResult.usedForcedTools + const nextModelEndTime = Date.now() + const thisModelTime = nextModelEndTime - nextModelStartTime + timeSegments.push({ + type: 'model', + name: request.model, + startTime: nextModelStartTime, + endTime: nextModelEndTime, + duration: thisModelTime, + }) + modelTime += thisModelTime + if (currentResponse.choices[0]?.message?.content) { + content = currentResponse.choices[0].message.content + } + if (currentResponse.usage) { + tokens.input += currentResponse.usage.prompt_tokens || 0 + tokens.output += currentResponse.usage.completion_tokens || 0 + tokens.total += currentResponse.usage.total_tokens || 0 + } + iterationCount++ + } + + if (iterationCount === MAX_TOOL_ITERATIONS) { + enrichLastModelSegmentFromChatCompletions( + timeSegments, + currentResponse, + currentResponse.choices[0]?.message?.tool_calls, + { model: request.model, provider: 'baseten' } + ) + } + + if (request.stream) { + const accumulatedCost = calculateCost(requestedModel, tokens.input, tokens.output) + + const streamingParams: ChatCompletionCreateParamsStreaming = { + ...payload, + messages: [...currentMessages], + tool_choice: 'none', + stream: true, + stream_options: { include_usage: true }, + } + + if (request.responseFormat) { + ;(streamingParams as any).messages = await applyResponseFormat( + streamingParams as any, + streamingParams.messages, + request.responseFormat, + requestedModel + ) + } + + const streamResponse = await client.chat.completions.create( + streamingParams, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) + + const streamingResult = { + stream: createReadableStreamFromOpenAIStream(streamResponse, (content, usage) => { + streamingResult.execution.output.content = content + streamingResult.execution.output.tokens = { + input: tokens.input + usage.prompt_tokens, + output: tokens.output + usage.completion_tokens, + total: tokens.total + usage.total_tokens, + } + + const streamCost = calculateCost( + requestedModel, + usage.prompt_tokens, + usage.completion_tokens + ) + const tc = sumToolCosts(toolResults) + streamingResult.execution.output.cost = { + input: accumulatedCost.input + streamCost.input, + output: accumulatedCost.output + streamCost.output, + toolCost: tc || undefined, + total: accumulatedCost.total + streamCost.total + tc, + } + }), + execution: { + success: true, + output: { + content: '', + model: requestedModel, + tokens: { input: tokens.input, output: tokens.output, total: tokens.total }, + toolCalls: + toolCalls.length > 0 + ? { + list: toolCalls, + count: toolCalls.length, + } + : undefined, + providerTiming: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + modelTime: modelTime, + toolsTime: toolsTime, + firstResponseTime: firstResponseTime, + iterations: iterationCount + 1, + timeSegments: timeSegments, + }, + cost: { + input: accumulatedCost.input, + output: accumulatedCost.output, + total: accumulatedCost.total, + }, + }, + logs: [], + metadata: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + }, + }, + } as StreamingExecution + + return streamingResult as StreamingExecution + } + + if (request.responseFormat && hasActiveTools) { + const finalPayload: any = { + model: payload.model, + messages: [...currentMessages], + } + if (payload.temperature !== undefined) { + finalPayload.temperature = payload.temperature + } + if (payload.max_tokens !== undefined) { + finalPayload.max_tokens = payload.max_tokens + } + + finalPayload.messages = await applyResponseFormat( + finalPayload, + finalPayload.messages, + request.responseFormat, + requestedModel + ) + + const finalStartTime = Date.now() + const finalResponse = await client.chat.completions.create( + finalPayload, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) + const finalEndTime = Date.now() + const finalDuration = finalEndTime - finalStartTime + + timeSegments.push({ + type: 'model', + name: 'Final structured response', + startTime: finalStartTime, + endTime: finalEndTime, + duration: finalDuration, + }) + modelTime += finalDuration + + if (finalResponse.choices[0]?.message?.content) { + content = finalResponse.choices[0].message.content + } + if (finalResponse.usage) { + tokens.input += finalResponse.usage.prompt_tokens || 0 + tokens.output += finalResponse.usage.completion_tokens || 0 + tokens.total += finalResponse.usage.total_tokens || 0 + } + + enrichLastModelSegmentFromChatCompletions( + timeSegments, + finalResponse, + finalResponse.choices[0]?.message?.tool_calls, + { model: request.model, provider: 'baseten' } + ) + } + + const providerEndTime = Date.now() + const providerEndTimeISO = new Date(providerEndTime).toISOString() + const totalDuration = providerEndTime - providerStartTime + + return { + content, + model: requestedModel, + tokens, + toolCalls: toolCalls.length > 0 ? toolCalls : undefined, + toolResults: toolResults.length > 0 ? toolResults : undefined, + timing: { + startTime: providerStartTimeISO, + endTime: providerEndTimeISO, + duration: totalDuration, + modelTime: modelTime, + toolsTime: toolsTime, + firstResponseTime: firstResponseTime, + iterations: iterationCount + 1, + timeSegments: timeSegments, + }, + } + } catch (error) { + const providerEndTime = Date.now() + const providerEndTimeISO = new Date(providerEndTime).toISOString() + const totalDuration = providerEndTime - providerStartTime + + const errorDetails: Record = { + error: toError(error).message, + duration: totalDuration, + } + if (error && typeof error === 'object') { + const err = error as any + if (err.status) errorDetails.status = err.status + if (err.code) errorDetails.code = err.code + if (err.type) errorDetails.type = err.type + if (err.error?.message) errorDetails.providerMessage = err.error.message + if (err.error?.metadata) errorDetails.metadata = err.error.metadata + } + + logger.error('Error in Baseten request:', errorDetails) + throw new ProviderError(toError(error).message, { + startTime: providerStartTimeISO, + endTime: providerEndTimeISO, + duration: totalDuration, + }) + } + }, +} diff --git a/apps/sim/providers/baseten/utils.ts b/apps/sim/providers/baseten/utils.ts new file mode 100644 index 00000000000..ca5cf5dc5c0 --- /dev/null +++ b/apps/sim/providers/baseten/utils.ts @@ -0,0 +1,41 @@ +import type { ChatCompletionChunk } from 'openai/resources/chat/completions' +import type { CompletionUsage } from 'openai/resources/completions' +import { checkForForcedToolUsageOpenAI, createOpenAICompatibleStream } from '@/providers/utils' + +/** + * Checks if a model supports native structured outputs (json_schema). + * Baseten Model APIs support structured outputs across their OpenAI-compatible inference API. + */ +export async function supportsNativeStructuredOutputs(_modelId: string): Promise { + return true +} + +/** + * Creates a ReadableStream from a Baseten streaming response. + * Uses the shared OpenAI-compatible streaming utility. + */ +export function createReadableStreamFromOpenAIStream( + openaiStream: AsyncIterable, + onComplete?: (content: string, usage: CompletionUsage) => void +): ReadableStream { + return createOpenAICompatibleStream(openaiStream, 'Baseten', onComplete) +} + +/** + * Checks if a forced tool was used in a Baseten response. + * Uses the shared OpenAI-compatible forced tool usage helper. + */ +export function checkForForcedToolUsage( + response: any, + toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }, + forcedTools: string[], + usedForcedTools: string[] +): { hasUsedForcedTool: boolean; usedForcedTools: string[] } { + return checkForForcedToolUsageOpenAI( + response, + toolChoice, + 'Baseten', + forcedTools, + usedForcedTools + ) +} diff --git a/apps/sim/providers/models.ts b/apps/sim/providers/models.ts index 6c98ea3623b..ca008901b84 100644 --- a/apps/sim/providers/models.ts +++ b/apps/sim/providers/models.ts @@ -11,6 +11,7 @@ import type React from 'react' import { AnthropicIcon, AzureIcon, + BasetenIcon, BedrockIcon, CerebrasIcon, DeepseekIcon, @@ -22,6 +23,7 @@ import { OllamaIcon, OpenAIIcon, OpenRouterIcon, + TogetherIcon, VertexIcon, VllmIcon, xAIIcon, @@ -98,6 +100,38 @@ export const PROVIDER_DEFINITIONS: Record = { contextInformationAvailable: false, models: [], }, + together: { + id: 'together', + name: 'Together AI', + description: 'Fast inference for open-source models via Together AI', + defaultModel: '', + modelPatterns: [/^together\//], + icon: TogetherIcon, + color: '#EF2CC1', + isReseller: true, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + contextInformationAvailable: false, + models: [], + }, + baseten: { + id: 'baseten', + name: 'Baseten', + description: 'Fast inference for open-source models via Baseten Model APIs', + defaultModel: '', + modelPatterns: [/^baseten\//], + icon: BasetenIcon, + color: '#1A1A2E', + isReseller: true, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + contextInformationAvailable: false, + models: [], + }, openrouter: { id: 'openrouter', name: 'OpenRouter', @@ -113,6 +147,21 @@ export const PROVIDER_DEFINITIONS: Record = { contextInformationAvailable: false, models: [], }, + 'ollama-cloud': { + id: 'ollama-cloud', + name: 'Ollama Cloud', + description: 'Hosted open-source models via Ollama Cloud (bring your own key)', + defaultModel: '', + modelPatterns: [/^ollama-cloud\//], + icon: OllamaIcon, + isReseller: true, + capabilities: { + temperature: { min: 0, max: 2 }, + toolUsageControl: true, + }, + contextInformationAvailable: false, + models: [], + }, vllm: { id: 'vllm', name: 'vLLM', @@ -2839,10 +2888,13 @@ export function getProviderModels(providerId: string): string[] { export const DYNAMIC_MODEL_PROVIDERS = [ 'ollama', + 'ollama-cloud', 'vllm', 'litellm', 'openrouter', 'fireworks', + 'together', + 'baseten', ] as const function getAllStaticModelIds(): string[] { @@ -2897,7 +2949,19 @@ export function suggestModelIdsForUnknownModel(_modelId: string, limit = 5): str export function getBaseModelProviders(): Record { return Object.entries(PROVIDER_DEFINITIONS) - .filter(([providerId]) => !['ollama', 'vllm', 'litellm', 'openrouter'].includes(providerId)) + .filter( + ([providerId]) => + ![ + 'ollama', + 'ollama-cloud', + 'vllm', + 'litellm', + 'openrouter', + 'fireworks', + 'together', + 'baseten', + ].includes(providerId) + ) .reduce( (map, [providerId, provider]) => { provider.models.forEach((model) => { @@ -3098,6 +3162,42 @@ export function updateFireworksModels(models: string[]): void { })) } +export function updateTogetherModels(models: string[]): void { + PROVIDER_DEFINITIONS.together.models = models.map((modelId) => ({ + id: modelId, + pricing: { + input: 0, + output: 0, + updatedAt: new Date().toISOString().split('T')[0], + }, + capabilities: {}, + })) +} + +export function updateBasetenModels(models: string[]): void { + PROVIDER_DEFINITIONS.baseten.models = models.map((modelId) => ({ + id: modelId, + pricing: { + input: 0, + output: 0, + updatedAt: new Date().toISOString().split('T')[0], + }, + capabilities: {}, + })) +} + +export function updateOllamaCloudModels(models: string[]): void { + PROVIDER_DEFINITIONS['ollama-cloud'].models = models.map((modelId) => ({ + id: modelId, + pricing: { + input: 0, + output: 0, + updatedAt: new Date().toISOString().split('T')[0], + }, + capabilities: {}, + })) +} + export function updateOpenRouterModels(models: string[]): void { PROVIDER_DEFINITIONS.openrouter.models = models.map((modelId) => ({ id: modelId, diff --git a/apps/sim/providers/ollama-cloud/index.test.ts b/apps/sim/providers/ollama-cloud/index.test.ts new file mode 100644 index 00000000000..c4a1f921f58 --- /dev/null +++ b/apps/sim/providers/ollama-cloud/index.test.ts @@ -0,0 +1,361 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +type StreamUsage = { prompt_tokens: number; completion_tokens: number; total_tokens: number } + +const { mockCreate, mockExecuteTool, streamOnComplete, MockAPIError } = vi.hoisted(() => { + class MockAPIError extends Error { + status?: number + code?: string | null + type?: string + constructor(message: string, opts: { status?: number; code?: string; type?: string } = {}) { + super(message) + this.name = 'APIError' + this.status = opts.status + this.code = opts.code + this.type = opts.type + } + } + return { + mockCreate: vi.fn(), + mockExecuteTool: vi.fn(), + streamOnComplete: { + current: undefined as undefined | ((content: string, usage: StreamUsage) => void), + }, + MockAPIError, + } +}) + +const mockOpenAIConstructor = vi.hoisted(() => vi.fn()) + +vi.mock('openai', () => { + const OpenAI = vi.fn((opts: unknown) => { + mockOpenAIConstructor(opts) + return { chat: { completions: { create: mockCreate } } } + }) + ;(OpenAI as unknown as { APIError: typeof MockAPIError }).APIError = MockAPIError + return { default: OpenAI } +}) + +vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 20 })) +vi.mock('@/providers/models', () => ({ + getProviderModels: vi.fn().mockReturnValue([]), + getProviderDefaultModel: vi.fn().mockReturnValue(''), +})) +vi.mock('@/providers/attachments', () => ({ + formatMessagesForProvider: (messages: unknown) => messages, +})) +vi.mock('@/providers/trace-enrichment', () => ({ + enrichLastModelSegmentFromChatCompletions: vi.fn(), +})) +vi.mock('@/providers/ollama-cloud/utils', () => ({ + createReadableStreamFromOllamaCloudStream: ( + _stream: unknown, + onComplete: (content: string, usage: StreamUsage) => void + ) => { + streamOnComplete.current = onComplete + return 'OLLAMA_CLOUD_STREAM' + }, +})) +vi.mock('@/providers/utils', () => ({ + calculateCost: () => ({ input: 0, output: 0, total: 0, pricing: null }), + generateSchemaInstructions: () => 'SCHEMA_INSTRUCTIONS', + prepareToolExecution: (_tool: unknown, args: Record) => ({ + toolParams: args, + executionParams: args, + }), + sumToolCosts: () => 0, +})) +vi.mock('@/tools', () => ({ executeTool: mockExecuteTool })) + +import { ollamaCloudProvider } from '@/providers/ollama-cloud' +import type { ProviderRequest, ProviderResponse, ProviderToolConfig } from '@/providers/types' + +interface StreamingResult { + stream: string + execution: { + output: { + content: string + model: string + tokens: { input: number; output: number; total: number } + toolCalls?: { list: unknown[]; count: number } + cost?: { input: number; output: number; total: number } + } + } +} + +type ToolCallChunk = { id: string; type: 'function'; function: { name: string; arguments: string } } + +function completion( + opts: { content?: string | null; toolCalls?: ToolCallChunk[]; usage?: StreamUsage } = {} +) { + return { + choices: [{ message: { content: opts.content ?? null, tool_calls: opts.toolCalls } }], + usage: opts.usage ?? { prompt_tokens: 5, completion_tokens: 3, total_tokens: 8 }, + } +} + +function makeTool(id: string, usageControl?: 'auto' | 'force' | 'none'): ProviderToolConfig { + return { + id, + name: id, + description: `${id} tool`, + params: {}, + parameters: { type: 'object', properties: {}, required: [] }, + ...(usageControl ? { usageControl } : {}), + } +} + +const baseRequest: ProviderRequest = { + model: 'ollama-cloud/gpt-oss:120b', + messages: [{ role: 'user', content: 'hi' }], + apiKey: 'oc-test-key', +} + +describe('ollamaCloudProvider.executeRequest', () => { + beforeEach(() => { + vi.clearAllMocks() + streamOnComplete.current = undefined + mockCreate.mockResolvedValue(completion({ content: 'hello' })) + mockExecuteTool.mockResolvedValue({ success: true, output: { ok: true } }) + }) + + it('throws when the API key is missing (BYOK is required)', async () => { + await expect( + ollamaCloudProvider.executeRequest({ ...baseRequest, apiKey: undefined }) + ).rejects.toThrow('API key is required for Ollama Cloud') + }) + + it('builds the OpenAI client with the cloud base URL and the user key', async () => { + await ollamaCloudProvider.executeRequest(baseRequest) + expect(mockOpenAIConstructor).toHaveBeenCalledWith( + expect.objectContaining({ + apiKey: 'oc-test-key', + baseURL: 'https://ollama.com/v1', + }) + ) + }) + + it('strips the ollama-cloud/ prefix before calling the API and reports the stripped model id', async () => { + const result = (await ollamaCloudProvider.executeRequest(baseRequest)) as ProviderResponse + expect(mockCreate.mock.calls[0][0].model).toBe('gpt-oss:120b') + expect(result).toMatchObject({ content: 'hello', model: 'gpt-oss:120b' }) + }) + + it('assembles system, context, then history in order and forwards params', async () => { + await ollamaCloudProvider.executeRequest({ + ...baseRequest, + systemPrompt: 'be nice', + context: 'ctx', + temperature: 0.5, + maxTokens: 128, + }) + + const payload = mockCreate.mock.calls[0][0] + expect(payload.messages).toEqual([ + { role: 'system', content: 'be nice' }, + { role: 'user', content: 'ctx' }, + { role: 'user', content: 'hi' }, + ]) + expect(payload.temperature).toBe(0.5) + expect(payload.max_tokens).toBe(128) + }) + + it('returns content verbatim (keeps ```json fences) when no responseFormat', async () => { + const fenced = '```json\n{"a":1}\n```' + mockCreate.mockResolvedValue(completion({ content: fenced })) + const result = (await ollamaCloudProvider.executeRequest(baseRequest)) as ProviderResponse + expect(result.content).toBe(fenced) + }) + + it('strips ```json fences and requests JSON mode with schema instructions when responseFormat is set', async () => { + mockCreate.mockResolvedValue(completion({ content: '```json\n{"a":1}\n```' })) + const result = (await ollamaCloudProvider.executeRequest({ + ...baseRequest, + responseFormat: { name: 'r', schema: { type: 'object' }, strict: true }, + })) as ProviderResponse + expect(result.content).toBe('{"a":1}') + const payload = mockCreate.mock.calls[0][0] + expect(payload.response_format).toEqual({ type: 'json_object' }) + expect(payload.messages.at(-1)).toEqual({ role: 'user', content: 'SCHEMA_INSTRUCTIONS' }) + }) + + it('runs the tool loop: parses string args, feeds results back, then terminates', async () => { + mockCreate + .mockResolvedValueOnce( + completion({ + toolCalls: [ + { id: 'call_1', type: 'function', function: { name: 'mytool', arguments: '{"x":1}' } }, + ], + }) + ) + .mockResolvedValueOnce(completion({ content: 'done' })) + + const result = (await ollamaCloudProvider.executeRequest({ + ...baseRequest, + tools: [makeTool('mytool')], + })) as ProviderResponse + + expect(mockExecuteTool).toHaveBeenCalledWith('mytool', { x: 1 }, expect.anything()) + expect(mockCreate).toHaveBeenCalledTimes(2) + expect(result.content).toBe('done') + expect(result.toolCalls).toEqual([ + expect.objectContaining({ name: 'mytool', success: true, arguments: { x: 1 } }), + ]) + expect(result.toolResults).toEqual([{ ok: true }]) + + const followUp = mockCreate.mock.calls[1][0].messages + expect(followUp).toContainEqual( + expect.objectContaining({ + role: 'assistant', + content: null, + tool_calls: [ + expect.objectContaining({ + id: 'call_1', + function: { name: 'mytool', arguments: '{"x":1}' }, + }), + ], + }) + ) + expect(followUp).toContainEqual({ + role: 'tool', + tool_call_id: 'call_1', + content: JSON.stringify({ ok: true }), + }) + }) + + it('records a failed tool result without aborting the loop', async () => { + mockExecuteTool.mockResolvedValue({ success: false, error: 'boom' }) + mockCreate + .mockResolvedValueOnce( + completion({ + toolCalls: [ + { id: 'call_1', type: 'function', function: { name: 'mytool', arguments: '{}' } }, + ], + }) + ) + .mockResolvedValueOnce(completion({ content: 'recovered' })) + + const result = (await ollamaCloudProvider.executeRequest({ + ...baseRequest, + tools: [makeTool('mytool')], + })) as ProviderResponse + + expect(result.content).toBe('recovered') + expect(result.toolCalls?.[0]).toMatchObject({ name: 'mytool', success: false }) + const toolMsg = mockCreate.mock.calls[1][0].messages.find( + (m: { role: string }) => m.role === 'tool' + ) + expect(JSON.parse(toolMsg.content)).toMatchObject({ error: true, message: 'boom' }) + }) + + it('executes parallel tool calls from a single response', async () => { + mockExecuteTool + .mockResolvedValueOnce({ success: true, output: { from: 'a' } }) + .mockResolvedValueOnce({ success: true, output: { from: 'b' } }) + mockCreate + .mockResolvedValueOnce( + completion({ + toolCalls: [ + { id: 'call_a', type: 'function', function: { name: 'a', arguments: '{}' } }, + { id: 'call_b', type: 'function', function: { name: 'b', arguments: '{}' } }, + ], + }) + ) + .mockResolvedValueOnce(completion({ content: 'summary' })) + + const result = (await ollamaCloudProvider.executeRequest({ + ...baseRequest, + tools: [makeTool('a'), makeTool('b')], + })) as ProviderResponse + + expect(mockExecuteTool).toHaveBeenCalledTimes(2) + expect(result.toolCalls?.map((c) => c.name)).toEqual(['a', 'b']) + }) + + it('filters out tools with usageControl "none"', async () => { + await ollamaCloudProvider.executeRequest({ + ...baseRequest, + tools: [makeTool('keep'), makeTool('drop', 'none')], + }) + const sent = mockCreate.mock.calls[0][0].tools + expect(sent.map((t: { function: { name: string } }) => t.function.name)).toEqual(['keep']) + }) + + it('never forces tools (Ollama Cloud ignores tool_choice) and keeps "auto"', async () => { + await ollamaCloudProvider.executeRequest({ + ...baseRequest, + tools: [makeTool('forced', 'force')], + }) + const payload = mockCreate.mock.calls[0][0] + expect(payload.tool_choice).toBe('auto') + expect(payload.tools.map((t: { function: { name: string } }) => t.function.name)).toEqual([ + 'forced', + ]) + }) + + it('surfaces an OpenAI APIError message through ProviderError', async () => { + mockCreate.mockRejectedValue( + new MockAPIError('model not found', { + status: 404, + code: 'not_found', + type: 'invalid_request_error', + }) + ) + await expect(ollamaCloudProvider.executeRequest(baseRequest)).rejects.toThrow('model not found') + }) + + it('streams content and usage when no tools are used', async () => { + const result = (await ollamaCloudProvider.executeRequest({ + ...baseRequest, + stream: true, + })) as unknown as StreamingResult + + expect(result.stream).toBe('OLLAMA_CLOUD_STREAM') + expect(mockCreate.mock.calls[0][0].stream_options).toEqual({ include_usage: true }) + expect(result.execution.output.model).toBe('gpt-oss:120b') + + streamOnComplete.current?.('streamed text', { + prompt_tokens: 4, + completion_tokens: 6, + total_tokens: 10, + }) + expect(result.execution.output.content).toBe('streamed text') + expect(result.execution.output.tokens).toMatchObject({ input: 4, output: 6, total: 10 }) + }) + + it('streams the final response after a tool loop and removes tools/tool_choice', async () => { + mockCreate + .mockResolvedValueOnce( + completion({ + toolCalls: [ + { id: 'call_1', type: 'function', function: { name: 'mytool', arguments: '{}' } }, + ], + }) + ) + .mockResolvedValueOnce(completion({ content: 'intermediate' })) + + const result = (await ollamaCloudProvider.executeRequest({ + ...baseRequest, + stream: true, + tools: [makeTool('mytool')], + })) as unknown as StreamingResult + + expect(result.stream).toBe('OLLAMA_CLOUD_STREAM') + expect(mockExecuteTool).toHaveBeenCalledTimes(1) + + const finalCall = mockCreate.mock.calls[2][0] + expect(finalCall.tools).toBeUndefined() + expect(finalCall.tool_choice).toBeUndefined() + + streamOnComplete.current?.('final answer', { + prompt_tokens: 2, + completion_tokens: 4, + total_tokens: 6, + }) + expect(result.execution.output.content).toBe('final answer') + expect(result.execution.output.toolCalls).toMatchObject({ count: 1 }) + }) +}) diff --git a/apps/sim/providers/ollama-cloud/index.ts b/apps/sim/providers/ollama-cloud/index.ts new file mode 100644 index 00000000000..0f782f7e24d --- /dev/null +++ b/apps/sim/providers/ollama-cloud/index.ts @@ -0,0 +1,47 @@ +import { createLogger } from '@sim/logger' +import OpenAI from 'openai' +import type { StreamingExecution } from '@/executor/types' +import { getProviderDefaultModel, getProviderModels } from '@/providers/models' +import { executeOllamaProviderRequest } from '@/providers/ollama/core' +import { createReadableStreamFromOllamaCloudStream } from '@/providers/ollama-cloud/utils' +import type { ProviderConfig, ProviderRequest, ProviderResponse } from '@/providers/types' + +const logger = createLogger('OllamaCloudProvider') + +/** Ollama Cloud OpenAI-compatible endpoint. BYOK only — Sim never hosts a key or bills usage. */ +const OLLAMA_CLOUD_BASE_URL = 'https://ollama.com/v1' + +export const ollamaCloudProvider: ProviderConfig = { + id: 'ollama-cloud', + name: 'Ollama Cloud', + description: 'Hosted open-source models via Ollama Cloud (bring your own key)', + version: '1.0.0', + models: getProviderModels('ollama-cloud'), + defaultModel: getProviderDefaultModel('ollama-cloud'), + + executeRequest: async ( + request: ProviderRequest + ): Promise => { + const apiKey = request.apiKey + if (!apiKey) { + throw new Error('API key is required for Ollama Cloud') + } + + const requestedModel = request.model.replace(/^ollama-cloud\//, '') + + return executeOllamaProviderRequest( + { ...request, model: requestedModel }, + { + providerId: 'ollama-cloud', + providerLabel: 'Ollama Cloud', + createClient: () => + new OpenAI({ + apiKey, + baseURL: OLLAMA_CLOUD_BASE_URL, + }), + createStream: createReadableStreamFromOllamaCloudStream, + logger, + } + ) + }, +} diff --git a/apps/sim/providers/ollama-cloud/utils.ts b/apps/sim/providers/ollama-cloud/utils.ts new file mode 100644 index 00000000000..3ab364c66bb --- /dev/null +++ b/apps/sim/providers/ollama-cloud/utils.ts @@ -0,0 +1,14 @@ +import type { ChatCompletionChunk } from 'openai/resources/chat/completions' +import type { CompletionUsage } from 'openai/resources/completions' +import { createOpenAICompatibleStream } from '@/providers/utils' + +/** + * Creates a ReadableStream from an Ollama Cloud streaming response. + * Uses the shared OpenAI-compatible streaming utility. + */ +export function createReadableStreamFromOllamaCloudStream( + ollamaStream: AsyncIterable, + onComplete?: (content: string, usage: CompletionUsage) => void +): ReadableStream { + return createOpenAICompatibleStream(ollamaStream, 'Ollama Cloud', onComplete) +} diff --git a/apps/sim/providers/ollama/core.ts b/apps/sim/providers/ollama/core.ts new file mode 100644 index 00000000000..3021d410919 --- /dev/null +++ b/apps/sim/providers/ollama/core.ts @@ -0,0 +1,677 @@ +import type { Logger } from '@sim/logger' +import { getErrorMessage } from '@sim/utils/errors' +import OpenAI from 'openai' +import type { + ChatCompletionChunk, + ChatCompletionCreateParamsStreaming, +} from 'openai/resources/chat/completions' +import type { CompletionUsage } from 'openai/resources/completions' +import type { StreamingExecution } from '@/executor/types' +import { MAX_TOOL_ITERATIONS } from '@/providers' +import { formatMessagesForProvider } from '@/providers/attachments' +import { enrichLastModelSegmentFromChatCompletions } from '@/providers/trace-enrichment' +import type { Message, ProviderRequest, ProviderResponse, TimeSegment } from '@/providers/types' +import { ProviderError } from '@/providers/types' +import { + calculateCost, + generateSchemaInstructions, + prepareToolExecution, + sumToolCosts, +} from '@/providers/utils' +import { executeTool } from '@/tools' + +/** + * Ollama enforces JSON mode (`json_object`) but ignores `json_schema`, so + * structured outputs use JSON mode with the schema described in-prompt. Mutates + * `payload.response_format` and returns the messages with instructions appended. + */ +function applyJsonResponseFormat( + payload: { response_format?: unknown }, + messages: Message[], + responseFormat: NonNullable +): Message[] { + payload.response_format = { type: 'json_object' } + const schema = responseFormat.schema || responseFormat + return [ + ...messages, + { role: 'user', content: generateSchemaInstructions(schema, responseFormat.name) }, + ] +} + +/** + * Per-provider hooks for the shared Ollama execution logic. The self-hosted + * `ollama` and hosted `ollama-cloud` providers differ only in client + * construction and labels; both pass those in here. + */ +export interface OllamaCoreConfig { + /** Provider id used for trace enrichment (`ollama`, `ollama-cloud`). */ + providerId: string + /** Human-readable label used in log messages. */ + providerLabel: string + /** Builds the OpenAI-compatible client (base URL + credentials per provider). */ + createClient: () => OpenAI + createStream: ( + stream: AsyncIterable, + onComplete?: (content: string, usage: CompletionUsage) => void + ) => ReadableStream + logger: Logger +} + +/** + * Shared execution logic for the Ollama-family providers, which speak the same + * OpenAI-compatible Ollama API. Ollama ignores `tool_choice`, so tools are sent + * as `tool_choice: 'auto'` (forced tools degrade to auto) and the final post-tool + * call drops tools entirely rather than relying on `tool_choice: 'none'`. + */ +export async function executeOllamaProviderRequest( + request: ProviderRequest, + config: OllamaCoreConfig +): Promise { + const { providerId, providerLabel, logger } = config + + logger.info(`Preparing ${providerLabel} request`, { + model: request.model, + hasSystemPrompt: !!request.systemPrompt, + hasMessages: !!request.messages?.length, + hasTools: !!request.tools?.length, + toolCount: request.tools?.length || 0, + hasResponseFormat: !!request.responseFormat, + stream: !!request.stream, + }) + + const ollama = config.createClient() + + const allMessages: Message[] = [] + + if (request.systemPrompt) { + allMessages.push({ + role: 'system', + content: request.systemPrompt, + }) + } + + if (request.context) { + allMessages.push({ + role: 'user', + content: request.context, + }) + } + + if (request.messages) { + allMessages.push(...request.messages) + } + const formattedMessages = formatMessagesForProvider(allMessages, providerId) as Message[] + + const tools = request.tools?.length + ? request.tools.map((tool) => ({ + type: 'function', + function: { + name: tool.id, + description: tool.description, + parameters: tool.parameters, + }, + })) + : undefined + + const payload: any = { + model: request.model, + messages: formattedMessages, + } + + if (request.temperature !== undefined) payload.temperature = request.temperature + if (request.maxTokens != null) payload.max_tokens = request.maxTokens + + let hasActiveTools = false + if (tools?.length) { + const filteredTools = tools.filter((tool) => { + const toolId = tool.function?.name + const toolConfig = request.tools?.find((t) => t.id === toolId) + return toolConfig?.usageControl !== 'none' + }) + + const hasForcedTools = tools.some((tool) => { + const toolId = tool.function?.name + const toolConfig = request.tools?.find((t) => t.id === toolId) + return toolConfig?.usageControl === 'force' + }) + + if (hasForcedTools) { + logger.warn( + `${providerLabel} does not support forced tool selection (tool_choice parameter is ignored). ` + + 'Tools marked with usageControl="force" will behave as "auto" instead.' + ) + } + + if (filteredTools?.length) { + payload.tools = filteredTools + payload.tool_choice = 'auto' + hasActiveTools = true + + logger.info(`${providerLabel} request configuration:`, { + toolCount: filteredTools.length, + toolChoice: 'auto', + forcedToolsIgnored: hasForcedTools, + model: request.model, + }) + } + } + + // With tools, defer structured output to the final call so JSON mode doesn't preempt tool use. + if (request.responseFormat && !hasActiveTools) { + payload.messages = applyJsonResponseFormat(payload, payload.messages, request.responseFormat) + logger.info(`Added JSON response format to ${providerLabel} request`) + } + + const providerStartTime = Date.now() + const providerStartTimeISO = new Date(providerStartTime).toISOString() + + try { + if (request.stream && (!tools || tools.length === 0 || !hasActiveTools)) { + logger.info(`Using streaming response for ${providerLabel} request`) + + const streamingParams: ChatCompletionCreateParamsStreaming = { + ...payload, + stream: true, + stream_options: { include_usage: true }, + } + const streamResponse = await ollama.chat.completions.create( + streamingParams, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) + + const streamingResult = { + stream: config.createStream(streamResponse, (content, usage) => { + streamingResult.execution.output.content = content + + if (content && request.responseFormat) { + streamingResult.execution.output.content = content + .replace(/```json\n?|\n?```/g, '') + .trim() + } + + streamingResult.execution.output.tokens = { + input: usage.prompt_tokens, + output: usage.completion_tokens, + total: usage.total_tokens, + } + + const costResult = calculateCost( + request.model, + usage.prompt_tokens, + usage.completion_tokens + ) + streamingResult.execution.output.cost = { + input: costResult.input, + output: costResult.output, + total: costResult.total, + } + + const streamEndTime = Date.now() + const streamEndTimeISO = new Date(streamEndTime).toISOString() + + if (streamingResult.execution.output.providerTiming) { + streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO + streamingResult.execution.output.providerTiming.duration = + streamEndTime - providerStartTime + + if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) { + streamingResult.execution.output.providerTiming.timeSegments[0].endTime = + streamEndTime + streamingResult.execution.output.providerTiming.timeSegments[0].duration = + streamEndTime - providerStartTime + } + } + }), + execution: { + success: true, + output: { + content: '', + model: request.model, + tokens: { input: 0, output: 0, total: 0 }, + toolCalls: undefined, + providerTiming: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + timeSegments: [ + { + type: 'model', + name: request.model, + startTime: providerStartTime, + endTime: Date.now(), + duration: Date.now() - providerStartTime, + }, + ], + }, + cost: { input: 0, output: 0, total: 0 }, + }, + logs: [], + metadata: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + }, + }, + } as StreamingExecution + + return streamingResult as StreamingExecution + } + + const initialCallTime = Date.now() + + let currentResponse = await ollama.chat.completions.create( + payload, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) + const firstResponseTime = Date.now() - initialCallTime + + let content = currentResponse.choices[0]?.message?.content || '' + + if (content && request.responseFormat) { + content = content.replace(/```json\n?|\n?```/g, '') + content = content.trim() + } + + const tokens = { + input: currentResponse.usage?.prompt_tokens || 0, + output: currentResponse.usage?.completion_tokens || 0, + total: currentResponse.usage?.total_tokens || 0, + } + const toolCalls = [] + const toolResults: Record[] = [] + const currentMessages = [...formattedMessages] + let iterationCount = 0 + + let modelTime = firstResponseTime + let toolsTime = 0 + + const timeSegments: TimeSegment[] = [ + { + type: 'model', + name: request.model, + startTime: initialCallTime, + endTime: initialCallTime + firstResponseTime, + duration: firstResponseTime, + }, + ] + + while (iterationCount < MAX_TOOL_ITERATIONS) { + if (currentResponse.choices[0]?.message?.content) { + content = currentResponse.choices[0].message.content + if (request.responseFormat) { + content = content.replace(/```json\n?|\n?```/g, '').trim() + } + } + + const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls + + enrichLastModelSegmentFromChatCompletions( + timeSegments, + currentResponse, + toolCallsInResponse, + { + model: request.model, + provider: providerId, + } + ) + + if (!toolCallsInResponse || toolCallsInResponse.length === 0) { + break + } + + logger.info( + `Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})` + ) + + const toolsStartTime = Date.now() + + const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => { + const toolCallStartTime = Date.now() + const toolName = toolCall.function.name + + try { + const toolArgs = JSON.parse(toolCall.function.arguments) + const tool = request.tools?.find((t) => t.id === toolName) + + if (!tool) return null + + const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) + const toolCallEndTime = Date.now() + + return { + toolCall, + toolName, + toolParams, + result, + startTime: toolCallStartTime, + endTime: toolCallEndTime, + duration: toolCallEndTime - toolCallStartTime, + } + } catch (error) { + const toolCallEndTime = Date.now() + logger.error('Error processing tool call:', { error, toolName }) + + return { + toolCall, + toolName, + toolParams: {}, + result: { + success: false, + output: undefined, + error: getErrorMessage(error, 'Tool execution failed'), + }, + startTime: toolCallStartTime, + endTime: toolCallEndTime, + duration: toolCallEndTime - toolCallStartTime, + } + } + }) + + const executionResults = await Promise.allSettled(toolExecutionPromises) + + currentMessages.push({ + role: 'assistant', + content: null, + tool_calls: toolCallsInResponse.map((tc) => ({ + id: tc.id, + type: 'function', + function: { + name: tc.function.name, + arguments: tc.function.arguments, + }, + })), + }) + + for (const settledResult of executionResults) { + if (settledResult.status === 'rejected' || !settledResult.value) continue + + const { toolCall, toolName, toolParams, result, startTime, endTime, duration } = + settledResult.value + + timeSegments.push({ + type: 'tool', + name: toolName, + startTime: startTime, + endTime: endTime, + duration: duration, + toolCallId: toolCall.id, + }) + + let resultContent: any + if (result.success && result.output) { + toolResults.push(result.output) + resultContent = result.output + } else { + resultContent = { + error: true, + message: result.error || 'Tool execution failed', + tool: toolName, + } + } + + toolCalls.push({ + name: toolName, + arguments: toolParams, + startTime: new Date(startTime).toISOString(), + endTime: new Date(endTime).toISOString(), + duration: duration, + result: resultContent, + success: result.success, + }) + + currentMessages.push({ + role: 'tool', + tool_call_id: toolCall.id, + content: JSON.stringify(resultContent), + }) + } + + const thisToolsTime = Date.now() - toolsStartTime + toolsTime += thisToolsTime + + const nextPayload = { + ...payload, + messages: currentMessages, + } + + const nextModelStartTime = Date.now() + + currentResponse = await ollama.chat.completions.create( + nextPayload, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) + + const nextModelEndTime = Date.now() + const thisModelTime = nextModelEndTime - nextModelStartTime + + timeSegments.push({ + type: 'model', + name: request.model, + startTime: nextModelStartTime, + endTime: nextModelEndTime, + duration: thisModelTime, + }) + + modelTime += thisModelTime + + if (currentResponse.choices[0]?.message?.content) { + content = currentResponse.choices[0].message.content + if (request.responseFormat) { + content = content.replace(/```json\n?|\n?```/g, '').trim() + } + } + + if (currentResponse.usage) { + tokens.input += currentResponse.usage.prompt_tokens || 0 + tokens.output += currentResponse.usage.completion_tokens || 0 + tokens.total += currentResponse.usage.total_tokens || 0 + } + + iterationCount++ + } + + if (iterationCount === MAX_TOOL_ITERATIONS) { + enrichLastModelSegmentFromChatCompletions( + timeSegments, + currentResponse, + currentResponse.choices[0]?.message?.tool_calls, + { model: request.model, provider: providerId } + ) + } + + if (request.stream) { + logger.info(`Using streaming for final ${providerLabel} response after tool processing`) + + const accumulatedCost = calculateCost(request.model, tokens.input, tokens.output) + + const { tools: _tools, tool_choice: _toolChoice, ...streamPayload } = payload + + const finalMessages = request.responseFormat + ? applyJsonResponseFormat(streamPayload, currentMessages, request.responseFormat) + : currentMessages + + const streamingParams: ChatCompletionCreateParamsStreaming = { + ...streamPayload, + messages: finalMessages, + stream: true, + stream_options: { include_usage: true }, + } + const streamResponse = await ollama.chat.completions.create( + streamingParams, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) + + const streamingResult = { + stream: config.createStream(streamResponse, (content, usage) => { + streamingResult.execution.output.content = content + + if (content && request.responseFormat) { + streamingResult.execution.output.content = content + .replace(/```json\n?|\n?```/g, '') + .trim() + } + + streamingResult.execution.output.tokens = { + input: tokens.input + usage.prompt_tokens, + output: tokens.output + usage.completion_tokens, + total: tokens.total + usage.total_tokens, + } + + const streamCost = calculateCost( + request.model, + usage.prompt_tokens, + usage.completion_tokens + ) + const tc = sumToolCosts(toolResults) + streamingResult.execution.output.cost = { + input: accumulatedCost.input + streamCost.input, + output: accumulatedCost.output + streamCost.output, + toolCost: tc || undefined, + total: accumulatedCost.total + streamCost.total + tc, + } + }), + execution: { + success: true, + output: { + content: '', + model: request.model, + tokens: { + input: tokens.input, + output: tokens.output, + total: tokens.total, + }, + toolCalls: + toolCalls.length > 0 + ? { + list: toolCalls, + count: toolCalls.length, + } + : undefined, + providerTiming: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + modelTime: modelTime, + toolsTime: toolsTime, + firstResponseTime: firstResponseTime, + iterations: iterationCount + 1, + timeSegments: timeSegments, + }, + cost: { + input: accumulatedCost.input, + output: accumulatedCost.output, + total: accumulatedCost.total, + }, + }, + logs: [], + metadata: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + }, + }, + } as StreamingExecution + + return streamingResult as StreamingExecution + } + + // Deferred structured output: one final JSON-mode call now that tools have run. + if (request.responseFormat && hasActiveTools) { + const finalPayload: any = { model: payload.model } + if (payload.temperature !== undefined) finalPayload.temperature = payload.temperature + if (payload.max_tokens !== undefined) finalPayload.max_tokens = payload.max_tokens + finalPayload.messages = applyJsonResponseFormat( + finalPayload, + currentMessages, + request.responseFormat + ) + + const finalStartTime = Date.now() + const finalResponse = await ollama.chat.completions.create( + finalPayload, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) + const finalEndTime = Date.now() + + timeSegments.push({ + type: 'model', + name: 'Final structured response', + startTime: finalStartTime, + endTime: finalEndTime, + duration: finalEndTime - finalStartTime, + }) + modelTime += finalEndTime - finalStartTime + + if (finalResponse.choices[0]?.message?.content) { + content = finalResponse.choices[0].message.content.replace(/```json\n?|\n?```/g, '').trim() + } + if (finalResponse.usage) { + tokens.input += finalResponse.usage.prompt_tokens || 0 + tokens.output += finalResponse.usage.completion_tokens || 0 + tokens.total += finalResponse.usage.total_tokens || 0 + } + + enrichLastModelSegmentFromChatCompletions( + timeSegments, + finalResponse, + finalResponse.choices[0]?.message?.tool_calls, + { model: request.model, provider: providerId } + ) + } + + const providerEndTime = Date.now() + const providerEndTimeISO = new Date(providerEndTime).toISOString() + const totalDuration = providerEndTime - providerStartTime + + return { + content, + model: request.model, + tokens, + toolCalls: toolCalls.length > 0 ? toolCalls : undefined, + toolResults: toolResults.length > 0 ? toolResults : undefined, + timing: { + startTime: providerStartTimeISO, + endTime: providerEndTimeISO, + duration: totalDuration, + modelTime: modelTime, + toolsTime: toolsTime, + firstResponseTime: firstResponseTime, + iterations: iterationCount + 1, + timeSegments: timeSegments, + }, + } + } catch (error) { + const providerEndTime = Date.now() + const providerEndTimeISO = new Date(providerEndTime).toISOString() + const totalDuration = providerEndTime - providerStartTime + + let errorMessage = getErrorMessage(error, 'Unknown error') + let errorType: string | undefined + let errorCode: string | undefined + let status: number | undefined + + if (error instanceof OpenAI.APIError) { + errorMessage = error.message + errorType = error.type + errorCode = error.code ?? undefined + status = error.status + } + + logger.error(`Error in ${providerLabel} request:`, { + error: errorMessage, + errorType, + errorCode, + status, + duration: totalDuration, + }) + + throw new ProviderError(errorMessage, { + startTime: providerStartTimeISO, + endTime: providerEndTimeISO, + duration: totalDuration, + }) + } +} diff --git a/apps/sim/providers/ollama/index.test.ts b/apps/sim/providers/ollama/index.test.ts index a6b91e9e8f5..c4d223ad13c 100644 --- a/apps/sim/providers/ollama/index.test.ts +++ b/apps/sim/providers/ollama/index.test.ts @@ -53,6 +53,7 @@ vi.mock('@/providers/ollama/utils', () => ({ })) vi.mock('@/providers/utils', () => ({ calculateCost: () => ({ input: 0, output: 0, total: 0, pricing: null }), + generateSchemaInstructions: () => 'SCHEMA_INSTRUCTIONS', prepareToolExecution: (_tool: unknown, args: Record) => ({ toolParams: args, executionParams: args, @@ -141,17 +142,45 @@ describe('ollamaProvider.executeRequest', () => { expect(result.content).toBe(fenced) }) - it('strips ```json fences and sends a json_schema response_format when requested', async () => { + it('strips ```json fences and requests JSON mode with schema instructions when responseFormat is set', async () => { mockCreate.mockResolvedValue(completion({ content: '```json\n{"a":1}\n```' })) const result = (await ollamaProvider.executeRequest({ ...baseRequest, responseFormat: { name: 'r', schema: { type: 'object' }, strict: true }, })) as ProviderResponse expect(result.content).toBe('{"a":1}') - expect(mockCreate.mock.calls[0][0].response_format).toMatchObject({ - type: 'json_schema', - json_schema: { name: 'r', schema: { type: 'object' }, strict: true }, - }) + const payload = mockCreate.mock.calls[0][0] + expect(payload.response_format).toEqual({ type: 'json_object' }) + expect(payload.messages.at(-1)).toEqual({ role: 'user', content: 'SCHEMA_INSTRUCTIONS' }) + }) + + it('defers structured output while tools run, then makes a final JSON-mode call', async () => { + mockCreate + .mockResolvedValueOnce( + completion({ + toolCalls: [ + { id: 'call_1', type: 'function', function: { name: 'mytool', arguments: '{}' } }, + ], + }) + ) + .mockResolvedValueOnce(completion({ content: 'intermediate' })) + .mockResolvedValueOnce(completion({ content: '{"a":1}' })) + + const result = (await ollamaProvider.executeRequest({ + ...baseRequest, + tools: [makeTool('mytool')], + responseFormat: { name: 'r', schema: { type: 'object' } }, + })) as ProviderResponse + + expect(mockCreate).toHaveBeenCalledTimes(3) + expect(mockCreate.mock.calls[0][0].response_format).toBeUndefined() + expect(mockCreate.mock.calls[0][0].tools).toBeDefined() + + const finalCall = mockCreate.mock.calls[2][0] + expect(finalCall.response_format).toEqual({ type: 'json_object' }) + expect(finalCall.tools).toBeUndefined() + expect(finalCall.messages.at(-1)).toEqual({ role: 'user', content: 'SCHEMA_INSTRUCTIONS' }) + expect(result.content).toBe('{"a":1}') }) it('runs the tool loop: parses string args, feeds results back, then terminates', async () => { diff --git a/apps/sim/providers/ollama/index.ts b/apps/sim/providers/ollama/index.ts index bfe7cff6134..755c11fdb36 100644 --- a/apps/sim/providers/ollama/index.ts +++ b/apps/sim/providers/ollama/index.ts @@ -1,25 +1,13 @@ import { createLogger } from '@sim/logger' import { getErrorMessage } from '@sim/utils/errors' import OpenAI from 'openai' -import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions' import { getOllamaUrl } from '@/lib/core/utils/urls' import type { StreamingExecution } from '@/executor/types' -import { MAX_TOOL_ITERATIONS } from '@/providers' -import { formatMessagesForProvider } from '@/providers/attachments' +import { executeOllamaProviderRequest } from '@/providers/ollama/core' import type { ModelsObject } from '@/providers/ollama/types' import { createReadableStreamFromOllamaStream } from '@/providers/ollama/utils' -import { enrichLastModelSegmentFromChatCompletions } from '@/providers/trace-enrichment' -import type { - Message, - ProviderConfig, - ProviderRequest, - ProviderResponse, - TimeSegment, -} from '@/providers/types' -import { ProviderError } from '@/providers/types' -import { calculateCost, prepareToolExecution, sumToolCosts } from '@/providers/utils' +import type { ProviderConfig, ProviderRequest, ProviderResponse } from '@/providers/types' import { useProvidersStore } from '@/stores/providers' -import { executeTool } from '@/tools' const logger = createLogger('OllamaProvider') const OLLAMA_HOST = getOllamaUrl() @@ -59,567 +47,16 @@ export const ollamaProvider: ProviderConfig = { executeRequest: async ( request: ProviderRequest ): Promise => { - logger.info('Preparing Ollama request', { - model: request.model, - hasSystemPrompt: !!request.systemPrompt, - hasMessages: !!request.messages?.length, - hasTools: !!request.tools?.length, - toolCount: request.tools?.length || 0, - hasResponseFormat: !!request.responseFormat, - stream: !!request.stream, + return executeOllamaProviderRequest(request, { + providerId: 'ollama', + providerLabel: 'Ollama', + createClient: () => + new OpenAI({ + apiKey: 'empty', + baseURL: `${OLLAMA_HOST}/v1`, + }), + createStream: createReadableStreamFromOllamaStream, + logger, }) - - const ollama = new OpenAI({ - apiKey: 'empty', - baseURL: `${OLLAMA_HOST}/v1`, - }) - - const allMessages: Message[] = [] - - if (request.systemPrompt) { - allMessages.push({ - role: 'system', - content: request.systemPrompt, - }) - } - - if (request.context) { - allMessages.push({ - role: 'user', - content: request.context, - }) - } - - if (request.messages) { - allMessages.push(...request.messages) - } - const formattedMessages = formatMessagesForProvider(allMessages, 'ollama') as Message[] - - const tools = request.tools?.length - ? request.tools.map((tool) => ({ - type: 'function', - function: { - name: tool.id, - description: tool.description, - parameters: tool.parameters, - }, - })) - : undefined - - const payload: any = { - model: request.model, - messages: formattedMessages, - } - - if (request.temperature !== undefined) payload.temperature = request.temperature - if (request.maxTokens != null) payload.max_tokens = request.maxTokens - - if (request.responseFormat) { - payload.response_format = { - type: 'json_schema', - json_schema: { - name: request.responseFormat.name || 'response_schema', - schema: request.responseFormat.schema || request.responseFormat, - strict: request.responseFormat.strict !== false, - }, - } - - logger.info('Added JSON schema response format to Ollama request') - } - - if (tools?.length) { - const filteredTools = tools.filter((tool) => { - const toolId = tool.function?.name - const toolConfig = request.tools?.find((t) => t.id === toolId) - return toolConfig?.usageControl !== 'none' - }) - - const hasForcedTools = tools.some((tool) => { - const toolId = tool.function?.name - const toolConfig = request.tools?.find((t) => t.id === toolId) - return toolConfig?.usageControl === 'force' - }) - - if (hasForcedTools) { - logger.warn( - 'Ollama does not support forced tool selection (tool_choice parameter is ignored). ' + - 'Tools marked with usageControl="force" will behave as "auto" instead.' - ) - } - - if (filteredTools?.length) { - payload.tools = filteredTools - payload.tool_choice = 'auto' - - logger.info('Ollama request configuration:', { - toolCount: filteredTools.length, - toolChoice: 'auto', - forcedToolsIgnored: hasForcedTools, - model: request.model, - }) - } - } - - const providerStartTime = Date.now() - const providerStartTimeISO = new Date(providerStartTime).toISOString() - - try { - if (request.stream && (!tools || tools.length === 0)) { - logger.info('Using streaming response for Ollama request') - - const streamingParams: ChatCompletionCreateParamsStreaming = { - ...payload, - stream: true, - stream_options: { include_usage: true }, - } - const streamResponse = await ollama.chat.completions.create( - streamingParams, - request.abortSignal ? { signal: request.abortSignal } : undefined - ) - - const streamingResult = { - stream: createReadableStreamFromOllamaStream(streamResponse, (content, usage) => { - streamingResult.execution.output.content = content - - if (content && request.responseFormat) { - streamingResult.execution.output.content = content - .replace(/```json\n?|\n?```/g, '') - .trim() - } - - streamingResult.execution.output.tokens = { - input: usage.prompt_tokens, - output: usage.completion_tokens, - total: usage.total_tokens, - } - - const costResult = calculateCost( - request.model, - usage.prompt_tokens, - usage.completion_tokens - ) - streamingResult.execution.output.cost = { - input: costResult.input, - output: costResult.output, - total: costResult.total, - } - - const streamEndTime = Date.now() - const streamEndTimeISO = new Date(streamEndTime).toISOString() - - if (streamingResult.execution.output.providerTiming) { - streamingResult.execution.output.providerTiming.endTime = streamEndTimeISO - streamingResult.execution.output.providerTiming.duration = - streamEndTime - providerStartTime - - if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) { - streamingResult.execution.output.providerTiming.timeSegments[0].endTime = - streamEndTime - streamingResult.execution.output.providerTiming.timeSegments[0].duration = - streamEndTime - providerStartTime - } - } - }), - execution: { - success: true, - output: { - content: '', - model: request.model, - tokens: { input: 0, output: 0, total: 0 }, - toolCalls: undefined, - providerTiming: { - startTime: providerStartTimeISO, - endTime: new Date().toISOString(), - duration: Date.now() - providerStartTime, - timeSegments: [ - { - type: 'model', - name: request.model, - startTime: providerStartTime, - endTime: Date.now(), - duration: Date.now() - providerStartTime, - }, - ], - }, - cost: { input: 0, output: 0, total: 0 }, - }, - logs: [], - metadata: { - startTime: providerStartTimeISO, - endTime: new Date().toISOString(), - duration: Date.now() - providerStartTime, - }, - }, - } as StreamingExecution - - return streamingResult as StreamingExecution - } - - const initialCallTime = Date.now() - - let currentResponse = await ollama.chat.completions.create( - payload, - request.abortSignal ? { signal: request.abortSignal } : undefined - ) - const firstResponseTime = Date.now() - initialCallTime - - let content = currentResponse.choices[0]?.message?.content || '' - - if (content && request.responseFormat) { - content = content.replace(/```json\n?|\n?```/g, '') - content = content.trim() - } - - const tokens = { - input: currentResponse.usage?.prompt_tokens || 0, - output: currentResponse.usage?.completion_tokens || 0, - total: currentResponse.usage?.total_tokens || 0, - } - const toolCalls = [] - const toolResults: Record[] = [] - const currentMessages = [...formattedMessages] - let iterationCount = 0 - - let modelTime = firstResponseTime - let toolsTime = 0 - - const timeSegments: TimeSegment[] = [ - { - type: 'model', - name: request.model, - startTime: initialCallTime, - endTime: initialCallTime + firstResponseTime, - duration: firstResponseTime, - }, - ] - - while (iterationCount < MAX_TOOL_ITERATIONS) { - if (currentResponse.choices[0]?.message?.content) { - content = currentResponse.choices[0].message.content - if (request.responseFormat) { - content = content.replace(/```json\n?|\n?```/g, '').trim() - } - } - - const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls - - enrichLastModelSegmentFromChatCompletions( - timeSegments, - currentResponse, - toolCallsInResponse, - { model: request.model, provider: 'ollama' } - ) - - if (!toolCallsInResponse || toolCallsInResponse.length === 0) { - break - } - - logger.info( - `Processing ${toolCallsInResponse.length} tool calls (iteration ${iterationCount + 1}/${MAX_TOOL_ITERATIONS})` - ) - - const toolsStartTime = Date.now() - - const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => { - const toolCallStartTime = Date.now() - const toolName = toolCall.function.name - - try { - const toolArgs = JSON.parse(toolCall.function.arguments) - const tool = request.tools?.find((t) => t.id === toolName) - - if (!tool) return null - - const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) - const result = await executeTool(toolName, executionParams, { - signal: request.abortSignal, - }) - const toolCallEndTime = Date.now() - - return { - toolCall, - toolName, - toolParams, - result, - startTime: toolCallStartTime, - endTime: toolCallEndTime, - duration: toolCallEndTime - toolCallStartTime, - } - } catch (error) { - const toolCallEndTime = Date.now() - logger.error('Error processing tool call:', { error, toolName }) - - return { - toolCall, - toolName, - toolParams: {}, - result: { - success: false, - output: undefined, - error: getErrorMessage(error, 'Tool execution failed'), - }, - startTime: toolCallStartTime, - endTime: toolCallEndTime, - duration: toolCallEndTime - toolCallStartTime, - } - } - }) - - const executionResults = await Promise.allSettled(toolExecutionPromises) - - currentMessages.push({ - role: 'assistant', - content: null, - tool_calls: toolCallsInResponse.map((tc) => ({ - id: tc.id, - type: 'function', - function: { - name: tc.function.name, - arguments: tc.function.arguments, - }, - })), - }) - - for (const settledResult of executionResults) { - if (settledResult.status === 'rejected' || !settledResult.value) continue - - const { toolCall, toolName, toolParams, result, startTime, endTime, duration } = - settledResult.value - - timeSegments.push({ - type: 'tool', - name: toolName, - startTime: startTime, - endTime: endTime, - duration: duration, - toolCallId: toolCall.id, - }) - - let resultContent: any - if (result.success && result.output) { - toolResults.push(result.output) - resultContent = result.output - } else { - resultContent = { - error: true, - message: result.error || 'Tool execution failed', - tool: toolName, - } - } - - toolCalls.push({ - name: toolName, - arguments: toolParams, - startTime: new Date(startTime).toISOString(), - endTime: new Date(endTime).toISOString(), - duration: duration, - result: resultContent, - success: result.success, - }) - - currentMessages.push({ - role: 'tool', - tool_call_id: toolCall.id, - content: JSON.stringify(resultContent), - }) - } - - const thisToolsTime = Date.now() - toolsStartTime - toolsTime += thisToolsTime - - const nextPayload = { - ...payload, - messages: currentMessages, - } - - const nextModelStartTime = Date.now() - - currentResponse = await ollama.chat.completions.create( - nextPayload, - request.abortSignal ? { signal: request.abortSignal } : undefined - ) - - const nextModelEndTime = Date.now() - const thisModelTime = nextModelEndTime - nextModelStartTime - - timeSegments.push({ - type: 'model', - name: request.model, - startTime: nextModelStartTime, - endTime: nextModelEndTime, - duration: thisModelTime, - }) - - modelTime += thisModelTime - - if (currentResponse.choices[0]?.message?.content) { - content = currentResponse.choices[0].message.content - if (request.responseFormat) { - content = content.replace(/```json\n?|\n?```/g, '').trim() - } - } - - if (currentResponse.usage) { - tokens.input += currentResponse.usage.prompt_tokens || 0 - tokens.output += currentResponse.usage.completion_tokens || 0 - tokens.total += currentResponse.usage.total_tokens || 0 - } - - iterationCount++ - } - - if (iterationCount === MAX_TOOL_ITERATIONS) { - enrichLastModelSegmentFromChatCompletions( - timeSegments, - currentResponse, - currentResponse.choices[0]?.message?.tool_calls, - { model: request.model, provider: 'ollama' } - ) - } - - if (request.stream) { - logger.info('Using streaming for final response after tool processing') - - const accumulatedCost = calculateCost(request.model, tokens.input, tokens.output) - - const { tools: _tools, tool_choice: _toolChoice, ...streamPayload } = payload - - const streamingParams: ChatCompletionCreateParamsStreaming = { - ...streamPayload, - messages: currentMessages, - stream: true, - stream_options: { include_usage: true }, - } - const streamResponse = await ollama.chat.completions.create( - streamingParams, - request.abortSignal ? { signal: request.abortSignal } : undefined - ) - - const streamingResult = { - stream: createReadableStreamFromOllamaStream(streamResponse, (content, usage) => { - streamingResult.execution.output.content = content - - if (content && request.responseFormat) { - streamingResult.execution.output.content = content - .replace(/```json\n?|\n?```/g, '') - .trim() - } - - streamingResult.execution.output.tokens = { - input: tokens.input + usage.prompt_tokens, - output: tokens.output + usage.completion_tokens, - total: tokens.total + usage.total_tokens, - } - - const streamCost = calculateCost( - request.model, - usage.prompt_tokens, - usage.completion_tokens - ) - const tc = sumToolCosts(toolResults) - streamingResult.execution.output.cost = { - input: accumulatedCost.input + streamCost.input, - output: accumulatedCost.output + streamCost.output, - toolCost: tc || undefined, - total: accumulatedCost.total + streamCost.total + tc, - } - }), - execution: { - success: true, - output: { - content: '', - model: request.model, - tokens: { - input: tokens.input, - output: tokens.output, - total: tokens.total, - }, - toolCalls: - toolCalls.length > 0 - ? { - list: toolCalls, - count: toolCalls.length, - } - : undefined, - providerTiming: { - startTime: providerStartTimeISO, - endTime: new Date().toISOString(), - duration: Date.now() - providerStartTime, - modelTime: modelTime, - toolsTime: toolsTime, - firstResponseTime: firstResponseTime, - iterations: iterationCount + 1, - timeSegments: timeSegments, - }, - cost: { - input: accumulatedCost.input, - output: accumulatedCost.output, - total: accumulatedCost.total, - }, - }, - logs: [], - metadata: { - startTime: providerStartTimeISO, - endTime: new Date().toISOString(), - duration: Date.now() - providerStartTime, - }, - }, - } as StreamingExecution - - return streamingResult as StreamingExecution - } - - const providerEndTime = Date.now() - const providerEndTimeISO = new Date(providerEndTime).toISOString() - const totalDuration = providerEndTime - providerStartTime - - return { - content, - model: request.model, - tokens, - toolCalls: toolCalls.length > 0 ? toolCalls : undefined, - toolResults: toolResults.length > 0 ? toolResults : undefined, - timing: { - startTime: providerStartTimeISO, - endTime: providerEndTimeISO, - duration: totalDuration, - modelTime: modelTime, - toolsTime: toolsTime, - firstResponseTime: firstResponseTime, - iterations: iterationCount + 1, - timeSegments: timeSegments, - }, - } - } catch (error) { - const providerEndTime = Date.now() - const providerEndTimeISO = new Date(providerEndTime).toISOString() - const totalDuration = providerEndTime - providerStartTime - - let errorMessage = getErrorMessage(error, 'Unknown error') - let errorType: string | undefined - let errorCode: string | undefined - let status: number | undefined - - if (error instanceof OpenAI.APIError) { - errorMessage = error.message - errorType = error.type - errorCode = error.code ?? undefined - status = error.status - } - - logger.error('Error in Ollama request:', { - error: errorMessage, - errorType, - errorCode, - status, - duration: totalDuration, - }) - - throw new ProviderError(errorMessage, { - startTime: providerStartTimeISO, - endTime: providerEndTimeISO, - duration: totalDuration, - }) - } }, } diff --git a/apps/sim/providers/registry.ts b/apps/sim/providers/registry.ts index 5aa48d3db3a..5e65e92796c 100644 --- a/apps/sim/providers/registry.ts +++ b/apps/sim/providers/registry.ts @@ -3,6 +3,7 @@ import { getErrorMessage } from '@sim/utils/errors' import { anthropicProvider } from '@/providers/anthropic' import { azureAnthropicProvider } from '@/providers/azure-anthropic' import { azureOpenAIProvider } from '@/providers/azure-openai' +import { basetenProvider } from '@/providers/baseten' import { bedrockProvider } from '@/providers/bedrock' import { cerebrasProvider } from '@/providers/cerebras' import { deepseekProvider } from '@/providers/deepseek' @@ -12,8 +13,10 @@ import { groqProvider } from '@/providers/groq' import { litellmProvider } from '@/providers/litellm' import { mistralProvider } from '@/providers/mistral' import { ollamaProvider } from '@/providers/ollama' +import { ollamaCloudProvider } from '@/providers/ollama-cloud' import { openaiProvider } from '@/providers/openai' import { openRouterProvider } from '@/providers/openrouter' +import { togetherProvider } from '@/providers/together' import type { ProviderConfig, ProviderId } from '@/providers/types' import { vertexProvider } from '@/providers/vertex' import { vllmProvider } from '@/providers/vllm' @@ -37,7 +40,10 @@ const providerRegistry: Record = { 'azure-openai': azureOpenAIProvider, openrouter: openRouterProvider, fireworks: fireworksProvider, + together: togetherProvider, + baseten: basetenProvider, ollama: ollamaProvider, + 'ollama-cloud': ollamaCloudProvider, bedrock: bedrockProvider, } diff --git a/apps/sim/providers/together/index.test.ts b/apps/sim/providers/together/index.test.ts new file mode 100644 index 00000000000..39e01ab7679 --- /dev/null +++ b/apps/sim/providers/together/index.test.ts @@ -0,0 +1,225 @@ +/** + * @vitest-environment node + */ +import { beforeEach, describe, expect, it, vi } from 'vitest' + +const { + mockCreate, + mockSupportsNativeStructuredOutputs, + mockPrepareToolsWithUsageControl, + mockExecuteTool, +} = vi.hoisted(() => ({ + mockCreate: vi.fn(), + mockSupportsNativeStructuredOutputs: vi.fn(), + mockPrepareToolsWithUsageControl: vi.fn(), + mockExecuteTool: vi.fn(), +})) + +vi.mock('openai', () => ({ + default: vi.fn().mockImplementation(() => ({ + chat: { completions: { create: mockCreate } }, + })), +})) + +vi.mock('@/providers', () => ({ MAX_TOOL_ITERATIONS: 5 })) + +vi.mock('@/providers/models', () => ({ + getProviderModels: vi.fn().mockReturnValue([]), + getProviderDefaultModel: vi.fn().mockReturnValue('moonshotai/Kimi-K2-Instruct'), +})) + +vi.mock('@/providers/attachments', () => ({ + formatMessagesForProvider: vi.fn((messages) => messages), +})) + +vi.mock('@/providers/together/utils', () => ({ + supportsNativeStructuredOutputs: mockSupportsNativeStructuredOutputs, + createReadableStreamFromOpenAIStream: vi.fn(() => ({}) as ReadableStream), + checkForForcedToolUsage: vi.fn(() => ({ hasUsedForcedTool: false, usedForcedTools: [] })), +})) + +vi.mock('@/providers/trace-enrichment', () => ({ + enrichLastModelSegmentFromChatCompletions: vi.fn(), +})) + +vi.mock('@/providers/utils', () => ({ + calculateCost: vi.fn().mockReturnValue({ input: 0, output: 0, total: 0 }), + generateSchemaInstructions: vi.fn(() => 'SCHEMA_INSTRUCTIONS'), + prepareToolExecution: vi.fn(() => ({ toolParams: { x: 1 }, executionParams: { x: 1 } })), + prepareToolsWithUsageControl: mockPrepareToolsWithUsageControl, + sumToolCosts: vi.fn().mockReturnValue(0), +})) + +vi.mock('@/tools', () => ({ executeTool: mockExecuteTool })) + +import { togetherProvider } from '@/providers/together/index' +import { ProviderError } from '@/providers/types' + +const textResponse = (content: string) => ({ + choices: [{ message: { content, tool_calls: [] } }], + usage: { prompt_tokens: 10, completion_tokens: 5, total_tokens: 15 }, +}) + +const toolCallResponse = () => ({ + choices: [ + { + message: { + content: null, + tool_calls: [ + { id: 'call_1', type: 'function', function: { name: 'my_tool', arguments: '{"x":1}' } }, + ], + }, + }, + ], + usage: { prompt_tokens: 8, completion_tokens: 4, total_tokens: 12 }, +}) + +const toolDef = { + id: 'my_tool', + name: 'my_tool', + description: '', + params: {}, + parameters: { type: 'object', properties: {}, required: [] }, +} + +const callBody = (index: number) => mockCreate.mock.calls[index][0] +const lastCallBody = () => mockCreate.mock.calls.at(-1)?.[0] + +describe('togetherProvider', () => { + beforeEach(() => { + vi.clearAllMocks() + mockSupportsNativeStructuredOutputs.mockResolvedValue(true) + mockPrepareToolsWithUsageControl.mockImplementation((tools) => ({ + tools, + toolChoice: 'auto', + forcedTools: [], + })) + mockExecuteTool.mockResolvedValue({ success: true, output: { ok: true } }) + }) + + const baseRequest = { + model: 'together/moonshotai/Kimi-K2-Instruct', + systemPrompt: 'You are helpful.', + messages: [{ role: 'user' as const, content: 'Hello' }], + apiKey: 'together-test-key', + } + + it('throws when the API key is missing', async () => { + await expect( + togetherProvider.executeRequest({ ...baseRequest, apiKey: undefined }) + ).rejects.toThrow('API key is required for Together AI') + }) + + it('strips only the leading together/ prefix from the model id', async () => { + mockCreate.mockResolvedValueOnce(textResponse('hi there')) + + const result = await togetherProvider.executeRequest(baseRequest) + + expect(lastCallBody().model).toBe('moonshotai/Kimi-K2-Instruct') + expect(result).toMatchObject({ + content: 'hi there', + model: 'moonshotai/Kimi-K2-Instruct', + tokens: { input: 10, output: 5, total: 15 }, + }) + }) + + it('wraps API errors in a ProviderError', async () => { + mockCreate.mockRejectedValueOnce(new Error('boom')) + + await expect(togetherProvider.executeRequest(baseRequest)).rejects.toBeInstanceOf(ProviderError) + }) + + it('streams directly when there are no tools', async () => { + mockCreate.mockResolvedValueOnce({}) + + const result = await togetherProvider.executeRequest({ ...baseRequest, stream: true }) + + expect(lastCallBody()).toMatchObject({ stream: true, stream_options: { include_usage: true } }) + expect(result).toHaveProperty('stream') + expect(result).toHaveProperty('execution') + }) + + it('sends a json_schema response_format with no strict field', async () => { + mockCreate.mockResolvedValueOnce(textResponse('{}')) + + await togetherProvider.executeRequest({ + ...baseRequest, + responseFormat: { name: 'my_schema', schema: { type: 'object' }, strict: true }, + }) + + expect(lastCallBody().response_format).toEqual({ + type: 'json_schema', + json_schema: { name: 'my_schema', schema: { type: 'object' } }, + }) + expect(lastCallBody().response_format.json_schema).not.toHaveProperty('strict') + }) + + it('falls back to json_object with prompt instructions when native is unsupported', async () => { + mockSupportsNativeStructuredOutputs.mockResolvedValue(false) + mockCreate.mockResolvedValueOnce(textResponse('{}')) + + await togetherProvider.executeRequest({ + ...baseRequest, + responseFormat: { name: 'my_schema', schema: { type: 'object' } }, + }) + + expect(lastCallBody().response_format).toEqual({ type: 'json_object' }) + expect(lastCallBody().messages.at(-1)).toEqual({ + role: 'user', + content: 'SCHEMA_INSTRUCTIONS', + }) + }) + + it('defers response_format to a final call when tools are active', async () => { + mockCreate + .mockResolvedValueOnce(textResponse('intermediate')) + .mockResolvedValueOnce(textResponse('{"done":true}')) + + await togetherProvider.executeRequest({ + ...baseRequest, + responseFormat: { name: 'my_schema', schema: { type: 'object' } }, + tools: [toolDef], + }) + + expect(mockCreate).toHaveBeenCalledTimes(2) + expect(callBody(0).response_format).toBeUndefined() + expect(callBody(0).tools).toBeDefined() + expect(callBody(1).response_format).toEqual({ + type: 'json_schema', + json_schema: { name: 'my_schema', schema: { type: 'object' } }, + }) + expect(callBody(1).tools).toBeUndefined() + }) + + it('runs the tool loop and threads tool results back into the conversation', async () => { + mockCreate + .mockResolvedValueOnce(toolCallResponse()) + .mockResolvedValueOnce(textResponse('final answer')) + + const result = await togetherProvider.executeRequest({ ...baseRequest, tools: [toolDef] }) + + expect(mockExecuteTool).toHaveBeenCalledWith('my_tool', { x: 1 }, expect.anything()) + expect(result).toMatchObject({ content: 'final answer' }) + expect((result as { toolCalls?: unknown[] }).toolCalls).toHaveLength(1) + + const followUpMessages = callBody(1).messages + expect(followUpMessages).toContainEqual( + expect.objectContaining({ role: 'assistant', tool_calls: expect.any(Array) }) + ) + expect(followUpMessages).toContainEqual( + expect.objectContaining({ role: 'tool', tool_call_id: 'call_1' }) + ) + }) + + it("forces tool_choice 'none' on the final streaming call after tools run", async () => { + mockCreate + .mockResolvedValueOnce(toolCallResponse()) + .mockResolvedValueOnce(textResponse('done')) + .mockResolvedValueOnce({}) + + await togetherProvider.executeRequest({ ...baseRequest, stream: true, tools: [toolDef] }) + + expect(mockCreate).toHaveBeenCalledTimes(3) + expect(lastCallBody()).toMatchObject({ tool_choice: 'none', stream: true }) + }) +}) diff --git a/apps/sim/providers/together/index.ts b/apps/sim/providers/together/index.ts new file mode 100644 index 00000000000..aff98633f51 --- /dev/null +++ b/apps/sim/providers/together/index.ts @@ -0,0 +1,653 @@ +import { createLogger } from '@sim/logger' +import { getErrorMessage, toError } from '@sim/utils/errors' +import OpenAI from 'openai' +import type { ChatCompletionCreateParamsStreaming } from 'openai/resources/chat/completions' +import type { StreamingExecution } from '@/executor/types' +import { MAX_TOOL_ITERATIONS } from '@/providers' +import { formatMessagesForProvider } from '@/providers/attachments' +import { getProviderDefaultModel, getProviderModels } from '@/providers/models' +import { + checkForForcedToolUsage, + createReadableStreamFromOpenAIStream, + supportsNativeStructuredOutputs, +} from '@/providers/together/utils' +import { enrichLastModelSegmentFromChatCompletions } from '@/providers/trace-enrichment' +import type { + FunctionCallResponse, + Message, + ProviderConfig, + ProviderRequest, + ProviderResponse, + TimeSegment, +} from '@/providers/types' +import { ProviderError } from '@/providers/types' +import { + calculateCost, + generateSchemaInstructions, + prepareToolExecution, + prepareToolsWithUsageControl, + sumToolCosts, +} from '@/providers/utils' +import { executeTool } from '@/tools' + +const logger = createLogger('TogetherProvider') + +/** + * Applies structured output configuration to a payload based on model capabilities. + * Uses native json_schema for supported models, falls back to json_object with prompt instructions. + */ +async function applyResponseFormat( + targetPayload: any, + messages: any[], + responseFormat: any, + model: string +): Promise { + const useNative = await supportsNativeStructuredOutputs(model) + + if (useNative) { + logger.info('Using native structured outputs for Together model', { model }) + targetPayload.response_format = { + type: 'json_schema', + json_schema: { + name: responseFormat.name || 'response_schema', + schema: responseFormat.schema || responseFormat, + }, + } + return messages + } + + logger.info('Using json_object mode with prompt instructions for Together model', { model }) + const schema = responseFormat.schema || responseFormat + const schemaInstructions = generateSchemaInstructions(schema, responseFormat.name) + targetPayload.response_format = { type: 'json_object' } + return [...messages, { role: 'user', content: schemaInstructions }] +} + +export const togetherProvider: ProviderConfig = { + id: 'together', + name: 'Together AI', + description: 'Fast inference for open-source models via Together AI', + version: '1.0.0', + models: getProviderModels('together'), + defaultModel: getProviderDefaultModel('together'), + + executeRequest: async ( + request: ProviderRequest + ): Promise => { + if (!request.apiKey) { + throw new Error('API key is required for Together AI') + } + + const client = new OpenAI({ + apiKey: request.apiKey, + baseURL: 'https://api.together.ai/v1', + }) + + const requestedModel = request.model.replace(/^together\//, '') + + logger.info('Preparing Together request', { + model: requestedModel, + hasSystemPrompt: !!request.systemPrompt, + hasMessages: !!request.messages?.length, + hasTools: !!request.tools?.length, + toolCount: request.tools?.length || 0, + hasResponseFormat: !!request.responseFormat, + stream: !!request.stream, + }) + + const allMessages: Message[] = [] + + if (request.systemPrompt) { + allMessages.push({ role: 'system', content: request.systemPrompt }) + } + + if (request.context) { + allMessages.push({ role: 'user', content: request.context }) + } + + if (request.messages) { + allMessages.push(...request.messages) + } + const formattedMessages = formatMessagesForProvider(allMessages, 'together') as Message[] + + const tools = request.tools?.length + ? request.tools.map((tool) => ({ + type: 'function', + function: { + name: tool.id, + description: tool.description, + parameters: tool.parameters, + }, + })) + : undefined + + const payload: any = { + model: requestedModel, + messages: formattedMessages, + } + + if (request.temperature !== undefined) payload.temperature = request.temperature + if (request.maxTokens != null) payload.max_tokens = request.maxTokens + + let preparedTools: ReturnType | null = null + let hasActiveTools = false + if (tools?.length) { + preparedTools = prepareToolsWithUsageControl(tools, request.tools, logger, 'together') + const { tools: filteredTools, toolChoice } = preparedTools + if (filteredTools?.length && toolChoice) { + payload.tools = filteredTools + payload.tool_choice = toolChoice + hasActiveTools = true + } + } + + const providerStartTime = Date.now() + const providerStartTimeISO = new Date(providerStartTime).toISOString() + + try { + if (request.responseFormat && !hasActiveTools) { + payload.messages = await applyResponseFormat( + payload, + payload.messages, + request.responseFormat, + requestedModel + ) + } + + if (request.stream && (!tools || tools.length === 0 || !hasActiveTools)) { + const streamingParams: ChatCompletionCreateParamsStreaming = { + ...payload, + stream: true, + stream_options: { include_usage: true }, + } + const streamResponse = await client.chat.completions.create( + streamingParams, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) + + const streamingResult = { + stream: createReadableStreamFromOpenAIStream(streamResponse, (content, usage) => { + streamingResult.execution.output.content = content + streamingResult.execution.output.tokens = { + input: usage.prompt_tokens, + output: usage.completion_tokens, + total: usage.total_tokens, + } + + const costResult = calculateCost( + requestedModel, + usage.prompt_tokens, + usage.completion_tokens + ) + streamingResult.execution.output.cost = { + input: costResult.input, + output: costResult.output, + total: costResult.total, + } + + const end = Date.now() + const endISO = new Date(end).toISOString() + if (streamingResult.execution.output.providerTiming) { + streamingResult.execution.output.providerTiming.endTime = endISO + streamingResult.execution.output.providerTiming.duration = end - providerStartTime + if (streamingResult.execution.output.providerTiming.timeSegments?.[0]) { + streamingResult.execution.output.providerTiming.timeSegments[0].endTime = end + streamingResult.execution.output.providerTiming.timeSegments[0].duration = + end - providerStartTime + } + } + }), + execution: { + success: true, + output: { + content: '', + model: requestedModel, + tokens: { input: 0, output: 0, total: 0 }, + toolCalls: undefined, + providerTiming: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + timeSegments: [ + { + type: 'model', + name: request.model, + startTime: providerStartTime, + endTime: Date.now(), + duration: Date.now() - providerStartTime, + }, + ], + }, + cost: { input: 0, output: 0, total: 0 }, + }, + logs: [], + metadata: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + }, + }, + } as StreamingExecution + + return streamingResult as StreamingExecution + } + + const initialCallTime = Date.now() + const originalToolChoice = payload.tool_choice + const forcedTools = preparedTools?.forcedTools || [] + let usedForcedTools: string[] = [] + + let currentResponse = await client.chat.completions.create( + payload, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) + const firstResponseTime = Date.now() - initialCallTime + + let content = currentResponse.choices[0]?.message?.content || '' + const tokens = { + input: currentResponse.usage?.prompt_tokens || 0, + output: currentResponse.usage?.completion_tokens || 0, + total: currentResponse.usage?.total_tokens || 0, + } + const toolCalls: FunctionCallResponse[] = [] + const toolResults: Record[] = [] + const currentMessages = [...formattedMessages] + let iterationCount = 0 + let modelTime = firstResponseTime + let toolsTime = 0 + let hasUsedForcedTool = false + const timeSegments: TimeSegment[] = [ + { + type: 'model', + name: request.model, + startTime: initialCallTime, + endTime: initialCallTime + firstResponseTime, + duration: firstResponseTime, + }, + ] + + const forcedToolResult = checkForForcedToolUsage( + currentResponse, + originalToolChoice, + forcedTools, + usedForcedTools + ) + hasUsedForcedTool = forcedToolResult.hasUsedForcedTool + usedForcedTools = forcedToolResult.usedForcedTools + + while (iterationCount < MAX_TOOL_ITERATIONS) { + if (currentResponse.choices[0]?.message?.content) { + content = currentResponse.choices[0].message.content + } + + const toolCallsInResponse = currentResponse.choices[0]?.message?.tool_calls + + enrichLastModelSegmentFromChatCompletions( + timeSegments, + currentResponse, + toolCallsInResponse, + { model: request.model, provider: 'together' } + ) + + if (!toolCallsInResponse || toolCallsInResponse.length === 0) { + break + } + + const toolsStartTime = Date.now() + + const toolExecutionPromises = toolCallsInResponse.map(async (toolCall) => { + const toolCallStartTime = Date.now() + const toolName = toolCall.function.name + + try { + const toolArgs = JSON.parse(toolCall.function.arguments) + const tool = request.tools?.find((t) => t.id === toolName) + + if (!tool) return null + + const { toolParams, executionParams } = prepareToolExecution(tool, toolArgs, request) + const result = await executeTool(toolName, executionParams, { + signal: request.abortSignal, + }) + const toolCallEndTime = Date.now() + + return { + toolCall, + toolName, + toolParams, + result, + startTime: toolCallStartTime, + endTime: toolCallEndTime, + duration: toolCallEndTime - toolCallStartTime, + } + } catch (error) { + const toolCallEndTime = Date.now() + logger.error('Error processing tool call (Together):', { + error: toError(error).message, + toolName, + }) + + return { + toolCall, + toolName, + toolParams: {}, + result: { + success: false, + output: undefined, + error: getErrorMessage(error, 'Tool execution failed'), + }, + startTime: toolCallStartTime, + endTime: toolCallEndTime, + duration: toolCallEndTime - toolCallStartTime, + } + } + }) + + const executionResults = await Promise.allSettled(toolExecutionPromises) + + currentMessages.push({ + role: 'assistant', + content: null, + tool_calls: toolCallsInResponse.map((tc) => ({ + id: tc.id, + type: 'function', + function: { + name: tc.function.name, + arguments: tc.function.arguments, + }, + })), + }) + + for (const settledResult of executionResults) { + if (settledResult.status === 'rejected' || !settledResult.value) continue + + const { toolCall, toolName, toolParams, result, startTime, endTime, duration } = + settledResult.value + + timeSegments.push({ + type: 'tool', + name: toolName, + startTime: startTime, + endTime: endTime, + duration: duration, + toolCallId: toolCall.id, + }) + + let resultContent: any + if (result.success) { + toolResults.push(result.output!) + resultContent = result.output + } else { + resultContent = { + error: true, + message: result.error || 'Tool execution failed', + tool: toolName, + } + } + + toolCalls.push({ + name: toolName, + arguments: toolParams, + startTime: new Date(startTime).toISOString(), + endTime: new Date(endTime).toISOString(), + duration: duration, + result: resultContent, + success: result.success, + }) + + currentMessages.push({ + role: 'tool', + tool_call_id: toolCall.id, + content: JSON.stringify(resultContent), + }) + } + + const thisToolsTime = Date.now() - toolsStartTime + toolsTime += thisToolsTime + + const nextPayload = { + ...payload, + messages: currentMessages, + } + + if (typeof originalToolChoice === 'object' && hasUsedForcedTool && forcedTools.length > 0) { + const remainingTools = forcedTools.filter((tool) => !usedForcedTools.includes(tool)) + if (remainingTools.length > 0) { + nextPayload.tool_choice = { type: 'function', function: { name: remainingTools[0] } } + } else { + nextPayload.tool_choice = 'auto' + } + } + + const nextModelStartTime = Date.now() + currentResponse = await client.chat.completions.create( + nextPayload, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) + const nextForcedToolResult = checkForForcedToolUsage( + currentResponse, + nextPayload.tool_choice, + forcedTools, + usedForcedTools + ) + hasUsedForcedTool = nextForcedToolResult.hasUsedForcedTool + usedForcedTools = nextForcedToolResult.usedForcedTools + const nextModelEndTime = Date.now() + const thisModelTime = nextModelEndTime - nextModelStartTime + timeSegments.push({ + type: 'model', + name: request.model, + startTime: nextModelStartTime, + endTime: nextModelEndTime, + duration: thisModelTime, + }) + modelTime += thisModelTime + if (currentResponse.choices[0]?.message?.content) { + content = currentResponse.choices[0].message.content + } + if (currentResponse.usage) { + tokens.input += currentResponse.usage.prompt_tokens || 0 + tokens.output += currentResponse.usage.completion_tokens || 0 + tokens.total += currentResponse.usage.total_tokens || 0 + } + iterationCount++ + } + + if (iterationCount === MAX_TOOL_ITERATIONS) { + enrichLastModelSegmentFromChatCompletions( + timeSegments, + currentResponse, + currentResponse.choices[0]?.message?.tool_calls, + { model: request.model, provider: 'together' } + ) + } + + if (request.stream) { + const accumulatedCost = calculateCost(requestedModel, tokens.input, tokens.output) + + const streamingParams: ChatCompletionCreateParamsStreaming = { + ...payload, + messages: [...currentMessages], + tool_choice: 'none', + stream: true, + stream_options: { include_usage: true }, + } + + if (request.responseFormat) { + ;(streamingParams as any).messages = await applyResponseFormat( + streamingParams as any, + streamingParams.messages, + request.responseFormat, + requestedModel + ) + } + + const streamResponse = await client.chat.completions.create( + streamingParams, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) + + const streamingResult = { + stream: createReadableStreamFromOpenAIStream(streamResponse, (content, usage) => { + streamingResult.execution.output.content = content + streamingResult.execution.output.tokens = { + input: tokens.input + usage.prompt_tokens, + output: tokens.output + usage.completion_tokens, + total: tokens.total + usage.total_tokens, + } + + const streamCost = calculateCost( + requestedModel, + usage.prompt_tokens, + usage.completion_tokens + ) + const tc = sumToolCosts(toolResults) + streamingResult.execution.output.cost = { + input: accumulatedCost.input + streamCost.input, + output: accumulatedCost.output + streamCost.output, + toolCost: tc || undefined, + total: accumulatedCost.total + streamCost.total + tc, + } + }), + execution: { + success: true, + output: { + content: '', + model: requestedModel, + tokens: { input: tokens.input, output: tokens.output, total: tokens.total }, + toolCalls: + toolCalls.length > 0 + ? { + list: toolCalls, + count: toolCalls.length, + } + : undefined, + providerTiming: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + modelTime: modelTime, + toolsTime: toolsTime, + firstResponseTime: firstResponseTime, + iterations: iterationCount + 1, + timeSegments: timeSegments, + }, + cost: { + input: accumulatedCost.input, + output: accumulatedCost.output, + total: accumulatedCost.total, + }, + }, + logs: [], + metadata: { + startTime: providerStartTimeISO, + endTime: new Date().toISOString(), + duration: Date.now() - providerStartTime, + }, + }, + } as StreamingExecution + + return streamingResult as StreamingExecution + } + + if (request.responseFormat && hasActiveTools) { + const finalPayload: any = { + model: payload.model, + messages: [...currentMessages], + } + if (payload.temperature !== undefined) { + finalPayload.temperature = payload.temperature + } + if (payload.max_tokens !== undefined) { + finalPayload.max_tokens = payload.max_tokens + } + + finalPayload.messages = await applyResponseFormat( + finalPayload, + finalPayload.messages, + request.responseFormat, + requestedModel + ) + + const finalStartTime = Date.now() + const finalResponse = await client.chat.completions.create( + finalPayload, + request.abortSignal ? { signal: request.abortSignal } : undefined + ) + const finalEndTime = Date.now() + const finalDuration = finalEndTime - finalStartTime + + timeSegments.push({ + type: 'model', + name: 'Final structured response', + startTime: finalStartTime, + endTime: finalEndTime, + duration: finalDuration, + }) + modelTime += finalDuration + + if (finalResponse.choices[0]?.message?.content) { + content = finalResponse.choices[0].message.content + } + if (finalResponse.usage) { + tokens.input += finalResponse.usage.prompt_tokens || 0 + tokens.output += finalResponse.usage.completion_tokens || 0 + tokens.total += finalResponse.usage.total_tokens || 0 + } + + enrichLastModelSegmentFromChatCompletions( + timeSegments, + finalResponse, + finalResponse.choices[0]?.message?.tool_calls, + { model: request.model, provider: 'together' } + ) + } + + const providerEndTime = Date.now() + const providerEndTimeISO = new Date(providerEndTime).toISOString() + const totalDuration = providerEndTime - providerStartTime + + return { + content, + model: requestedModel, + tokens, + toolCalls: toolCalls.length > 0 ? toolCalls : undefined, + toolResults: toolResults.length > 0 ? toolResults : undefined, + timing: { + startTime: providerStartTimeISO, + endTime: providerEndTimeISO, + duration: totalDuration, + modelTime: modelTime, + toolsTime: toolsTime, + firstResponseTime: firstResponseTime, + iterations: iterationCount + 1, + timeSegments: timeSegments, + }, + } + } catch (error) { + const providerEndTime = Date.now() + const providerEndTimeISO = new Date(providerEndTime).toISOString() + const totalDuration = providerEndTime - providerStartTime + + const errorDetails: Record = { + error: toError(error).message, + duration: totalDuration, + } + if (error && typeof error === 'object') { + const err = error as any + if (err.status) errorDetails.status = err.status + if (err.code) errorDetails.code = err.code + if (err.type) errorDetails.type = err.type + if (err.error?.message) errorDetails.providerMessage = err.error.message + if (err.error?.metadata) errorDetails.metadata = err.error.metadata + } + + logger.error('Error in Together request:', errorDetails) + throw new ProviderError(toError(error).message, { + startTime: providerStartTimeISO, + endTime: providerEndTimeISO, + duration: totalDuration, + }) + } + }, +} diff --git a/apps/sim/providers/together/utils.ts b/apps/sim/providers/together/utils.ts new file mode 100644 index 00000000000..22b7823f342 --- /dev/null +++ b/apps/sim/providers/together/utils.ts @@ -0,0 +1,41 @@ +import type { ChatCompletionChunk } from 'openai/resources/chat/completions' +import type { CompletionUsage } from 'openai/resources/completions' +import { checkForForcedToolUsageOpenAI, createOpenAICompatibleStream } from '@/providers/utils' + +/** + * Together gates native `json_schema` per-model, so we use the broadly supported + * JSON-object mode for all models to avoid 400s. See https://docs.together.ai/docs/json-mode. + */ +export async function supportsNativeStructuredOutputs(_modelId: string): Promise { + return false +} + +/** + * Creates a ReadableStream from a Together AI streaming response. + * Uses the shared OpenAI-compatible streaming utility. + */ +export function createReadableStreamFromOpenAIStream( + openaiStream: AsyncIterable, + onComplete?: (content: string, usage: CompletionUsage) => void +): ReadableStream { + return createOpenAICompatibleStream(openaiStream, 'Together', onComplete) +} + +/** + * Checks if a forced tool was used in a Together AI response. + * Uses the shared OpenAI-compatible forced tool usage helper. + */ +export function checkForForcedToolUsage( + response: any, + toolChoice: string | { type: string; function?: { name: string }; name?: string; any?: any }, + forcedTools: string[], + usedForcedTools: string[] +): { hasUsedForcedTool: boolean; usedForcedTools: string[] } { + return checkForForcedToolUsageOpenAI( + response, + toolChoice, + 'Together', + forcedTools, + usedForcedTools + ) +} diff --git a/apps/sim/providers/types.ts b/apps/sim/providers/types.ts index dc2f25927d6..f5ab7a812a7 100644 --- a/apps/sim/providers/types.ts +++ b/apps/sim/providers/types.ts @@ -13,8 +13,11 @@ export type ProviderId = | 'groq' | 'mistral' | 'ollama' + | 'ollama-cloud' | 'openrouter' | 'fireworks' + | 'together' + | 'baseten' | 'vllm' | 'litellm' | 'bedrock' diff --git a/apps/sim/providers/utils.ts b/apps/sim/providers/utils.ts index 5cce6ce387e..d9fe8cbab3e 100644 --- a/apps/sim/providers/utils.ts +++ b/apps/sim/providers/utils.ts @@ -131,6 +131,7 @@ function buildProviderMetadata(providerId: ProviderId): ProviderMetadata { export const providers: Record = { ollama: buildProviderMetadata('ollama'), + 'ollama-cloud': buildProviderMetadata('ollama-cloud'), vllm: buildProviderMetadata('vllm'), litellm: buildProviderMetadata('litellm'), openai: { @@ -155,6 +156,8 @@ export const providers: Record = { bedrock: buildProviderMetadata('bedrock'), openrouter: buildProviderMetadata('openrouter'), fireworks: buildProviderMetadata('fireworks'), + together: buildProviderMetadata('together'), + baseten: buildProviderMetadata('baseten'), } export function updateOllamaProviderModels(models: string[]): void { @@ -186,15 +189,36 @@ export async function updateFireworksProviderModels(models: string[]): Promise { + const { updateOllamaCloudModels } = await import('@/providers/models') + updateOllamaCloudModels(models) + providers['ollama-cloud'].models = getProviderModelsFromDefinitions('ollama-cloud') +} + +export async function updateTogetherProviderModels(models: string[]): Promise { + const { updateTogetherModels } = await import('@/providers/models') + updateTogetherModels(models) + providers.together.models = getProviderModelsFromDefinitions('together') +} + +export async function updateBasetenProviderModels(models: string[]): Promise { + const { updateBasetenModels } = await import('@/providers/models') + updateBasetenModels(models) + providers.baseten.models = getProviderModelsFromDefinitions('baseten') +} + export function getBaseModelProviders(): Record { const allProviders = Object.entries(providers) .filter( ([providerId]) => providerId !== 'ollama' && + providerId !== 'ollama-cloud' && providerId !== 'vllm' && providerId !== 'litellm' && providerId !== 'openrouter' && - providerId !== 'fireworks' + providerId !== 'fireworks' && + providerId !== 'together' && + providerId !== 'baseten' ) .reduce( (map, [providerId, config]) => { diff --git a/apps/sim/stores/providers/store.ts b/apps/sim/stores/providers/store.ts index 00896c0ba7c..d2925f8c3f6 100644 --- a/apps/sim/stores/providers/store.ts +++ b/apps/sim/stores/providers/store.ts @@ -8,10 +8,13 @@ export const useProvidersStore = create((set, get) => ({ providers: { base: { models: [], isLoading: false }, ollama: { models: [], isLoading: false }, + 'ollama-cloud': { models: [], isLoading: false }, vllm: { models: [], isLoading: false }, litellm: { models: [], isLoading: false }, openrouter: { models: [], isLoading: false }, fireworks: { models: [], isLoading: false }, + together: { models: [], isLoading: false }, + baseten: { models: [], isLoading: false }, }, openRouterModelInfo: {}, diff --git a/apps/sim/stores/providers/types.ts b/apps/sim/stores/providers/types.ts index 7022529f202..7fe7f8cdeab 100644 --- a/apps/sim/stores/providers/types.ts +++ b/apps/sim/stores/providers/types.ts @@ -1,4 +1,13 @@ -export type ProviderName = 'ollama' | 'vllm' | 'litellm' | 'openrouter' | 'fireworks' | 'base' +export type ProviderName = + | 'ollama' + | 'ollama-cloud' + | 'vllm' + | 'litellm' + | 'openrouter' + | 'fireworks' + | 'together' + | 'baseten' + | 'base' export interface OpenRouterModelInfo { id: string diff --git a/apps/sim/tools/types.ts b/apps/sim/tools/types.ts index 24501cb6d96..fad0f65c80c 100644 --- a/apps/sim/tools/types.ts +++ b/apps/sim/tools/types.ts @@ -8,6 +8,9 @@ export type BYOKProviderId = | 'google' | 'mistral' | 'fireworks' + | 'together' + | 'baseten' + | 'ollama-cloud' | 'falai' | 'firecrawl' | 'exa' diff --git a/scripts/check-api-validation-contracts.ts b/scripts/check-api-validation-contracts.ts index 4fb99356eed..6f63a524433 100644 --- a/scripts/check-api-validation-contracts.ts +++ b/scripts/check-api-validation-contracts.ts @@ -9,8 +9,8 @@ const QUERY_HOOKS_DIR = path.join(ROOT, 'apps/sim/hooks/queries') const SELECTOR_HOOKS_DIR = path.join(ROOT, 'apps/sim/hooks/selectors') const BASELINE = { - totalRoutes: 758, - zodRoutes: 758, + totalRoutes: 761, + zodRoutes: 761, nonZodRoutes: 0, } as const