|
| 1 | +import { openai } from '@ai-sdk/openai' |
| 2 | +import { convertToModelMessages, stepCountIs, streamText, tool, type UIMessage } from 'ai' |
| 3 | +import { sql } from 'drizzle-orm' |
| 4 | +import { z } from 'zod' |
| 5 | +import { db, docsEmbeddings } from '@/lib/db' |
| 6 | +import { generateSearchEmbedding } from '@/lib/embeddings' |
| 7 | + |
/**
 * Route-level settings exported for the hosting framework. Nothing in this
 * file reads them; presumably they are Next.js route segment config (Node.js
 * runtime, 30-second execution ceiling) — confirm against the deployment.
 */
export const runtime = 'nodejs'
export const maxDuration = 30

/** Fallback chat model used when OPENAI_CHAT_MODEL is unset or empty. */
const DEFAULT_CHAT_MODEL = 'gpt-5.4-mini'

/**
 * Model used for the Ask AI chat. `||` (not `??`) is deliberate here: an
 * empty-string OPENAI_CHAT_MODEL also falls back to the default.
 */
const CHAT_MODEL = process.env.OPENAI_CHAT_MODEL || DEFAULT_CHAT_MODEL
| 13 | + |
/** Maximum number of documentation chunks handed back from a single search. */
const SEARCH_LIMIT = 6

/**
 * Row cap applied to the search query itself (four times SEARCH_LIMIT); the
 * caller trims the result down to SEARCH_LIMIT afterwards.
 */
const SEARCH_CANDIDATES = 4 * SEARCH_LIMIT

/** Lowest cosine similarity accepted for an English vector match (mirrors the site search route). */
const SIMILARITY_THRESHOLD = 0.6

/** Locale assumed when the request names none, or one that is not published. */
const DEFAULT_LOCALE = 'en'

/** Every locale the docs are published in (mirrors the site search route). */
const KNOWN_LOCALES = [DEFAULT_LOCALE, 'es', 'fr', 'de', 'ja', 'zh']

/**
 * Postgres full-text search configuration for each locale (mirrors the site
 * search route). Japanese and Chinese map to the `simple` configuration.
 */
const TS_CONFIG: Record<string, string> = {
  en: 'english',
  es: 'spanish',
  fr: 'french',
  de: 'german',
  ja: 'simple',
  zh: 'simple',
}
| 36 | + |
/**
 * Per-request cost ceilings. The endpoint fronts a paid LLM and is reachable
 * without authentication, which makes it a target for scripted "free
 * inference"; these limits bound what any one request can cost. Request volume
 * is bounded separately by the in-memory per-IP rate limit. A shared-store rate
 * limit, a provider spend cap, and edge bot protection are still the durable
 * controls.
 */

/** Most messages a single request may carry. */
const MAX_MESSAGES = 200

/**
 * Cap on user-authored text only. Assistant turns and other history are
 * excluded on purpose, because they legitimately grow over a multi-turn chat.
 */
const MAX_USER_INPUT_CHARS = 400_000

/** Ceiling on tokens the model may generate for one response. */
const MAX_OUTPUT_TOKENS = 4_000

/** Ceiling on model/tool steps within one response. */
const MAX_STEPS = 6

/**
 * Backstop on the serialized size of the whole messages array, so assistant or
 * tool parts cannot be stuffed past the user-text cap.
 */
const MAX_TOTAL_CHARS = 1_000_000
| 54 | + |
/**
 * Fixed-window, in-memory, per-IP rate limiting. It caps volume from one source
 * on a warm instance with no external infrastructure. On serverless this is
 * best-effort only, because the state lives in a single instance and is not
 * shared across regions or cold starts; a shared store (e.g. Vercel KV) and an
 * edge WAF remain the durable controls.
 */

/** Requests allowed per IP inside one window. */
const RATE_LIMIT_MAX = 20

/** Window length in milliseconds (one minute). */
const RATE_LIMIT_WINDOW_MS = 60 * 1000

/** Per-IP counters: hits seen in the current window and when that window ends (epoch ms). */
const rateLimitHits: Map<string, { count: number; resetAt: number }> = new Map()
| 65 | + |
/**
 * Resolve the client IP used as the rate-limit bucket key.
 *
 * Takes the first entry of `x-forwarded-for`, then `x-real-ip`, and finally the
 * shared `'unknown'` bucket. Blank values are skipped at every step, so a
 * malformed header (e.g. `", 5.6.7.8"`) can never produce an empty-string key.
 *
 * NOTE(review): the first `x-forwarded-for` entry is only trustworthy when the
 * hosting platform overwrites that header — confirm for this deployment.
 *
 * @param req Incoming request whose forwarding headers are inspected.
 * @returns A non-empty bucket key.
 */
function getClientIp(req: Request): string {
  const firstHop = req.headers.get('x-forwarded-for')?.split(',')[0]?.trim()
  if (firstHop) return firstHop

  const realIp = req.headers.get('x-real-ip')?.trim()
  return realIp || 'unknown'
}
| 72 | + |
/**
 * Record one hit for `ip` against its fixed window.
 *
 * @param ip Bucket key for the caller.
 * @param now Current time in epoch milliseconds.
 * @returns `null` when the hit is accepted, otherwise the number of whole
 *   seconds (rounded up) until the caller's window resets.
 */
function rateLimit(ip: string, now: number): number | null {
  const bucket = rateLimitHits.get(ip)

  // No bucket yet, or the previous window has ended: open a new window.
  if (bucket === undefined || bucket.resetAt <= now) {
    rateLimitHits.set(ip, { count: 1, resetAt: now + RATE_LIMIT_WINDOW_MS })
    return null
  }

  if (bucket.count < RATE_LIMIT_MAX) {
    bucket.count += 1
    return null
  }

  // Over the limit: the rejected hit is not counted.
  return Math.ceil((bucket.resetAt - now) / 1000)
}
| 86 | + |
/**
 * Evict expired rate-limit buckets so the Map does not keep growing on a
 * long-lived instance. Nothing happens until the Map holds at least 10,000
 * entries; above that, every bucket whose window has ended is removed.
 *
 * @param now Current time in epoch milliseconds.
 */
function sweepRateLimit(now: number): void {
  const overThreshold = rateLimitHits.size >= 10_000
  if (!overThreshold) return

  rateLimitHits.forEach((bucket, key) => {
    if (bucket.resetAt <= now) rateLimitHits.delete(key)
  })
}
| 94 | + |
/**
 * Type guard for a structurally valid UI message taken from untrusted JSON.
 *
 * A message passes when it is a non-null object with a string `role` and a
 * `parts` array whose every element is a non-null object carrying a string
 * `type`. Checking the elements matters: later code reads `part.type` on each
 * part, so an entry such as `null` would otherwise throw there instead of being
 * rejected here as a bad request.
 *
 * @param message Arbitrary value from the parsed request body.
 * @returns `true` when `message` has the shape described above.
 */
function isValidMessage(message: unknown): message is UIMessage {
  if (typeof message !== 'object' || message === null) return false

  const { role, parts } = message as { role?: unknown; parts?: unknown }
  if (typeof role !== 'string' || !Array.isArray(parts)) return false

  return parts.every(
    (part: unknown) =>
      typeof part === 'object' &&
      part !== null &&
      typeof (part as { type?: unknown }).type === 'string'
  )
}
| 104 | + |
/**
 * Sum the length of all user-authored text in the conversation.
 *
 * Only `text` parts of messages whose role is `user` are counted; lengths are
 * JavaScript string lengths (UTF-16 code units).
 *
 * @param messages Conversation messages to measure.
 * @returns Total character count of user text.
 */
function userInputChars(messages: UIMessage[]): number {
  return messages
    .filter((message) => message.role === 'user')
    .flatMap((message) => message.parts)
    .reduce(
      (sum, part) =>
        part.type === 'text' && typeof part.text === 'string' ? sum + part.text.length : sum,
      0
    )
}
| 116 | + |
/**
 * Reduce client-supplied history to the parts the model is allowed to see.
 *
 * Messages with any role other than `user` or `assistant` are dropped (this
 * removes client-injected `system` instructions), every non-text part is
 * removed (this removes crafted tool results posing as searchDocs output), and
 * messages left with no parts are discarded. Grounding therefore has to come
 * from the server-run searchDocs tool rather than from the request payload.
 *
 * @param messages Validated messages from the request body.
 * @returns New message objects containing only user/assistant text parts.
 */
function sanitizeMessages(messages: UIMessage[]): UIMessage[] {
  const trusted: UIMessage[] = []

  for (const message of messages) {
    if (message.role !== 'user' && message.role !== 'assistant') continue

    const textParts = message.parts.filter(
      (part) => part.type === 'text' && typeof part.text === 'string'
    )
    if (textParts.length === 0) continue

    trusted.push({ ...message, parts: textParts })
  }

  return trusted
}
| 133 | + |
/**
 * Reject obvious cross-origin calls.
 *
 * A request is allowed when:
 * - it has no `Origin` header (e.g. curl) — these fall through to the cost
 *   caps, since Origin is trivially spoofable outside a browser and is a
 *   filter, not a security boundary;
 * - the Origin's host equals the request host (`x-forwarded-host`, falling back
 *   to `host`; only the first comma-separated value is used); or
 * - the Origin's host appears in DOCS_ALLOWED_ORIGINS (comma-separated).
 *
 * Allowlist entries may be bare hosts (`docs.example.com`) or full origins
 * (`https://docs.example.com`); full origins are reduced to their host before
 * comparison, so either spelling matches.
 *
 * @param req Incoming request.
 * @returns `false` only for a present but unparseable or unlisted Origin.
 */
function isAllowedOrigin(req: Request): boolean {
  const origin = req.headers.get('origin')
  if (!origin) return true

  let originHost: string
  try {
    originHost = new URL(origin).host.toLowerCase()
  } catch {
    // Unparseable Origin (including the literal "null" of opaque origins).
    return false
  }

  const forwardedHost = req.headers.get('x-forwarded-host') ?? req.headers.get('host')
  const requestHost = forwardedHost?.split(',')[0]?.trim().toLowerCase()
  if (requestHost && originHost === requestHost) return true

  const allowlist = (process.env.DOCS_ALLOWED_ORIGINS ?? '')
    .split(',')
    .map((value) => {
      const entry = value.trim().toLowerCase()
      if (!entry.includes('://')) return entry
      try {
        return new URL(entry).host.toLowerCase()
      } catch {
        // Leave a malformed entry as-is; it simply will not match any host.
        return entry
      }
    })
    .filter(Boolean)
  return allowlist.includes(originHost)
}
| 162 | + |
/**
 * System prompt for the documentation assistant: an identity line, the rule
 * that the model must call searchDocs before answering and answer only from
 * what it returns, and a short list of style guidelines. Paragraphs are
 * separated by a blank line.
 */
const SYSTEM_PROMPT = [
  `You are the documentation assistant for Sim — the open-source AI workspace where teams build, deploy, and manage AI agents.`,
  ``,
  `Answer questions about Sim using the documentation. Always call the searchDocs tool before answering anything specific about Sim's features, configuration, or usage — do not answer from memory. Base your answer only on the returned documentation; if the docs do not cover the question, say so plainly rather than guessing.`,
  ``,
  `Guidelines:`,
  `- Be direct and concrete. Lead with the answer, then the detail.`,
  `- Reference the relevant pages by their titles so the user knows where to read more.`,
  `- When you show configuration or code, keep it minimal and correct.`,
  `- The agent is called "Sim" and the chat surface is "Chat" — never say "Mothership" or "copilot".`,
  `- If a question is unrelated to Sim, briefly say it's outside the docs' scope.`,
].join('\n')
| 173 | + |
/**
 * Column projection shared by both search queries: maps `docsEmbeddings`
 * columns onto the result field names (`title`, `url`, `content`,
 * `sourceDocument`).
 */
const SEARCH_COLUMNS = {
  // Field name in the result row → source column.
  title: docsEmbeddings.headerText,
  url: docsEmbeddings.sourceLink,
  content: docsEmbeddings.chunkText,
  sourceDocument: docsEmbeddings.sourceDocument,
}
| 180 | + |
/*
 * Retrieval overview for `searchDocs` (defined after `localeFilter`): English
 * docs are embedded, so they are matched by vector similarity; other locales
 * rely on Postgres full-text keyword search (vector search over English-trained
 * embeddings does not serve them) — the same split the site search route makes.
 */
/**
 * Build the SQL predicate that keeps only one locale's documents, so the
 * query's row limit applies to rows that already match the locale.
 *
 * The locale is read from the first `/`-separated segment of `sourceDocument`:
 * a non-default locale must equal that segment, while the default locale is
 * every row whose first segment is not one of the other known locales.
 *
 * @param locale Locale code; expected to be one of KNOWN_LOCALES.
 * @returns A SQL fragment for use inside a `where` clause.
 */
function localeFilter(locale: string) {
  const firstSegment = sql`split_part(${docsEmbeddings.sourceDocument}, '/', 1)`

  if (locale !== DEFAULT_LOCALE) {
    return sql`${firstSegment} = ${locale}`
  }

  // Each locale code is bound as its own query parameter.
  const otherLocales = KNOWN_LOCALES.filter((code) => code !== DEFAULT_LOCALE).map(
    (code) => sql`${code}`
  )
  return sql`${firstSegment} not in (${sql.join(otherLocales, sql`, `)})`
}
| 203 | + |
/**
 * Retrieve documentation chunks to ground an answer.
 *
 * The default locale embeds the query and ranks rows by vector distance,
 * keeping only rows whose cosine similarity reaches SIMILARITY_THRESHOLD. Every
 * other locale uses Postgres full-text search with that locale's text-search
 * configuration (falling back to `simple`), ranked by `ts_rank`. Both queries
 * are restricted to the locale's documents in SQL via `localeFilter`.
 *
 * A blank query returns no results immediately, without spending an embedding
 * call or a database round trip.
 *
 * NOTE(review): both queries fetch SEARCH_CANDIDATES rows and then keep the
 * first SEARCH_LIMIT. With the locale filter already applied in SQL, the extra
 * rows look vestigial — consider limiting to SEARCH_LIMIT directly.
 *
 * @param query Natural-language search query (supplied by the model).
 * @param locale Locale code that selects the search strategy and document set.
 * @returns Up to SEARCH_LIMIT chunks as `{ title, url, content }`, best first.
 */
async function searchDocs(query: string, locale: string) {
  if (query.trim() === '') return []

  let rows: Array<{ title: string; url: string; content: string; sourceDocument: string }>

  if (locale === DEFAULT_LOCALE) {
    const embedding = await generateSearchEmbedding(query)
    // The JSON array text is bound as a parameter and cast to `vector` in SQL.
    const vectorLiteral = JSON.stringify(embedding)
    rows = await db
      .select(SEARCH_COLUMNS)
      .from(docsEmbeddings)
      .where(
        sql`1 - (${docsEmbeddings.embedding} <=> ${vectorLiteral}::vector) >= ${SIMILARITY_THRESHOLD} and ${localeFilter(locale)}`
      )
      .orderBy(sql`${docsEmbeddings.embedding} <=> ${vectorLiteral}::vector`)
      .limit(SEARCH_CANDIDATES)
  } else {
    const tsConfig = TS_CONFIG[locale] ?? 'simple'
    rows = await db
      .select(SEARCH_COLUMNS)
      .from(docsEmbeddings)
      .where(
        sql`${docsEmbeddings.chunkTextTsv} @@ plainto_tsquery(${tsConfig}, ${query}) and ${localeFilter(locale)}`
      )
      .orderBy(
        sql`ts_rank(${docsEmbeddings.chunkTextTsv}, plainto_tsquery(${tsConfig}, ${query})) DESC`
      )
      .limit(SEARCH_CANDIDATES)
  }

  // `sourceDocument` is selected by the shared projection but not returned.
  return rows.slice(0, SEARCH_LIMIT).map((row) => ({
    title: row.title,
    url: row.url,
    content: row.content,
  }))
}
| 238 | + |
/**
 * Ask AI chat endpoint: validates the request, then streams a model response
 * grounded by the server-run `searchDocs` tool.
 *
 * Checks run in this order and each failure returns immediately:
 * 1. Origin filter — 403.
 * 2. Per-IP rate limit — 429 with a `Retry-After` header (seconds).
 * 3. Body must be JSON — 400 `Invalid JSON`; it must also be an object (a bare
 *    `null` or primitive is rejected rather than crashing on destructuring) —
 *    400 `Invalid request`.
 * 4. `messages` must be a non-empty array of at most MAX_MESSAGES structurally
 *    valid messages — 400.
 * 5. User text and total serialized size must fit their caps — 413.
 * 6. At least one message must survive sanitization — 400.
 *
 * An unknown, missing, or non-string `locale` falls back to DEFAULT_LOCALE.
 *
 * @param req Incoming POST request with a JSON body `{ messages, locale? }`.
 * @returns An error response, or the streamed UI-message response.
 */
export async function POST(req: Request) {
  if (!isAllowedOrigin(req)) {
    return new Response('Forbidden', { status: 403 })
  }

  const now = Date.now()
  sweepRateLimit(now)
  const retryAfter = rateLimit(getClientIp(req), now)
  if (retryAfter !== null) {
    return new Response('Too many requests', {
      status: 429,
      headers: { 'Retry-After': String(retryAfter) },
    })
  }

  let body: unknown
  try {
    body = await req.json()
  } catch {
    return new Response('Invalid JSON', { status: 400 })
  }
  // `null` and primitives are valid JSON but not a valid request body.
  if (typeof body !== 'object' || body === null) {
    return new Response('Invalid request', { status: 400 })
  }

  const { messages, locale: requestedLocale } = body as { messages?: unknown; locale?: unknown }
  const locale =
    typeof requestedLocale === 'string' && KNOWN_LOCALES.includes(requestedLocale)
      ? requestedLocale
      : DEFAULT_LOCALE

  if (!Array.isArray(messages) || messages.length === 0 || messages.length > MAX_MESSAGES) {
    return new Response('Invalid request', { status: 400 })
  }
  if (!messages.every(isValidMessage)) {
    return new Response('Invalid request', { status: 400 })
  }
  if (userInputChars(messages) > MAX_USER_INPUT_CHARS) {
    return new Response('Request too large', { status: 413 })
  }
  if (JSON.stringify(messages).length > MAX_TOTAL_CHARS) {
    return new Response('Request too large', { status: 413 })
  }

  // Only user/assistant text reaches the model; everything else is stripped.
  const modelMessages = sanitizeMessages(messages)
  if (modelMessages.length === 0) {
    return new Response('Invalid request', { status: 400 })
  }

  const result = streamText({
    model: openai(CHAT_MODEL),
    system: SYSTEM_PROMPT,
    messages: convertToModelMessages(modelMessages),
    stopWhen: stepCountIs(MAX_STEPS),
    maxOutputTokens: MAX_OUTPUT_TOKENS,
    tools: {
      searchDocs: tool({
        description:
          'Search the Sim documentation for relevant content. Use this before answering any question about Sim.',
        inputSchema: z.object({
          query: z.string().describe('A focused natural-language search query.'),
        }),
        // The locale comes from the validated request, not from the model.
        execute: async ({ query }: { query: string }) => searchDocs(query, locale),
      }),
    },
  })

  return result.toUIMessageStreamResponse()
}
0 commit comments