diff --git a/.github/workflows/release-mcp.yml b/.github/workflows/release-mcp.yml deleted file mode 100644 index 4d63a525b..000000000 --- a/.github/workflows/release-mcp.yml +++ /dev/null @@ -1,157 +0,0 @@ -name: Release MCP Package - -permissions: - contents: write - id-token: write - -on: - workflow_dispatch: - inputs: - bump_type: - description: "Type of version bump to apply" - required: true - type: choice - options: - - patch - - minor - - major - -concurrency: - group: release-mcp - cancel-in-progress: false - -jobs: - release: - runs-on: ubuntu-latest - - steps: - - name: Generate GitHub App token - id: generate_token - uses: actions/create-github-app-token@v1 - with: - app-id: ${{ secrets.RELEASE_APP_ID }} - private-key: ${{ secrets.RELEASE_APP_PRIVATE_KEY }} - - - name: Checkout repository - uses: actions/checkout@v4 - with: - ref: main - fetch-depth: 0 - token: ${{ steps.generate_token.outputs.token }} - - - name: Setup Node.js - uses: actions/setup-node@v4 - with: - node-version: "24" - registry-url: "https://registry.npmjs.org" - - - name: Calculate new version - id: calculate_version - run: | - # Extract current version from package.json - CURRENT_VERSION=$(node -p "require('./packages/mcp/package.json').version") - - if [ -z "$CURRENT_VERSION" ]; then - echo "Error: Could not extract current version from package.json" - exit 1 - fi - - echo "Current version: $CURRENT_VERSION" - - # Parse version components - IFS='.' read -r MAJOR MINOR PATCH <<< "$CURRENT_VERSION" - - # Apply bump based on input - BUMP_TYPE="${{ inputs.bump_type }}" - case "$BUMP_TYPE" in - major) - MAJOR=$((MAJOR + 1)) - MINOR=0 - PATCH=0 - ;; - minor) - MINOR=$((MINOR + 1)) - PATCH=0 - ;; - patch) - PATCH=$((PATCH + 1)) - ;; - *) - echo "Error: Invalid bump type: $BUMP_TYPE" - exit 1 - ;; - esac - - NEW_VERSION="$MAJOR.$MINOR.$PATCH" - echo "New version: $NEW_VERSION" - - # Export to GITHUB_ENV for use in subsequent steps - echo "VERSION=$NEW_VERSION" >> $GITHUB_ENV - - # Export to GITHUB_OUTPUT for use in other jobs - echo "version=$NEW_VERSION" >> $GITHUB_OUTPUT - - - name: Check if version already exists - run: | - if grep -q "## \[$VERSION\]" packages/mcp/CHANGELOG.md; then - echo "Error: Version $VERSION already exists in CHANGELOG.md" - exit 1 - fi - if git tag | grep -q "^mcp-v$VERSION$"; then - echo "Error: Tag mcp-v$VERSION already exists" - exit 1 - fi - - - name: Update CHANGELOG.md - run: | - DATE=$(date +%Y-%m-%d) - - # Insert the new version header after the [Unreleased] line - sed -i "/## \[Unreleased\]/a\\ - \\ - ## [$VERSION] - $DATE" packages/mcp/CHANGELOG.md - - echo "Updated CHANGELOG.md with version $VERSION" - cat packages/mcp/CHANGELOG.md | head -n 20 - - - name: Update package.json version - run: | - node -e " - const fs = require('fs'); - const path = 'packages/mcp/package.json'; - const pkg = JSON.parse(fs.readFileSync(path, 'utf8')); - pkg.version = process.env.VERSION; - fs.writeFileSync(path, JSON.stringify(pkg, null, 4) + '\n'); - " - echo "Updated package.json to version $VERSION" - head -n 5 packages/mcp/package.json - - - name: Configure git - run: | - git config user.name "github-actions[bot]" - git config user.email "github-actions[bot]@users.noreply.github.com" - - - name: Commit changes - run: | - git add packages/mcp/CHANGELOG.md packages/mcp/package.json - git commit -m "Release @sourcebot/mcp v$VERSION" - - - name: Install dependencies - run: yarn install --frozen-lockfile - - - name: Build MCP package - run: yarn workspace @sourcebot/mcp build - - - name: Publish to npm - env: - NODE_AUTH_TOKEN: ${{ secrets.NPM_TOKEN }} - run: | - cd packages/mcp - npm publish --provenance --access public - - - name: Push main - env: - GH_TOKEN: ${{ steps.generate_token.outputs.token }} - run: | - git push origin main - echo "✓ Pushed release commit to main" diff --git a/AGENTS.md b/AGENTS.md index 200241b78..69e435ef4 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -35,10 +35,6 @@ Standard dev commands are documented in `CONTRIBUTING.md` and `package.json`. Ke - **Build deps only:** `yarn build:deps` (builds shared packages: schemas, db, shared, query-language) - **DB migrations:** `yarn dev:prisma:migrate:dev` -### Deprecated Packages - -- **`packages/mcp`** - This standalone MCP package is deprecated. Do NOT modify it. MCP functionality is now handled by the web package at `packages/web/src/features/mcp/`. - ### Non-obvious Caveats - **Docker must be running** before `yarn dev`. Start it with `docker compose -f docker-compose-dev.yml up -d`. The backend will fail to connect to Redis/PostgreSQL otherwise. diff --git a/CHANGELOG.md b/CHANGELOG.md index bd292ce1c..785f8f20c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1020,7 +1020,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added - Added audit logging. [#355](https://github.com/sourcebot-dev/sourcebot/pull/355) - ### Fixed - Delete account join request when redeeming an invite. [#352](https://github.com/sourcebot-dev/sourcebot/pull/352) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index febcbb812..44a5ce8d9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -131,7 +131,7 @@ Then restart the dev server. Components that re-render will be highlighted in th - Keep pull requests small and focused - Explain the issue and why your change fixes it - Before adding new functionality, ensure it doesn't already exist elsewhere in the codebase -- Update `CHANGELOG.md` with an entry under `[Unreleased]` linking to your PR. New entries should be placed at the bottom of their section. If your change touches `packages/mcp`, update `packages/mcp/CHANGELOG.md` instead. +- Update `CHANGELOG.md` with an entry under `[Unreleased]` linking to your PR. New entries should be placed at the bottom of their section. ### UI Changes diff --git a/Makefile b/Makefile index 8db819228..0c58c00ac 100644 --- a/Makefile +++ b/Makefile @@ -37,8 +37,6 @@ clean: packages/db/dist \ packages/schemas/node_modules \ packages/schemas/dist \ - packages/mcp/node_modules \ - packages/mcp/dist \ packages/shared/node_modules \ packages/shared/dist \ .sourcebot diff --git a/docs/docs/features/mcp-server.mdx b/docs/docs/features/mcp-server.mdx index 1cb24b0d9..7595a3c12 100644 --- a/docs/docs/features/mcp-server.mdx +++ b/docs/docs/features/mcp-server.mdx @@ -7,6 +7,8 @@ import LicenseKeyRequired from '/snippets/license-key-required.mdx' The Sourcebot MCP Server connects AI tools to your [Sourcebot deployment](/docs/deployment/docker-compose). This gives your agents the ability to search, read files, resolve references & definitions, and more across all of your code hosted on Sourcebot. + + ## Use cases - **Context for local agents:** Plug the MCP into coding agents like Cursor, Claude Code, or Copilot to give them context across your entire codebase, not just the open workspace. @@ -16,7 +18,7 @@ The Sourcebot MCP Server connects AI tools to your [Sourcebot deployment](/docs/ Sourcebot MCP uses a [Streamable HTTP](https://modelcontextprotocol.io/specification/2025-06-18/basic/transports#streamable-http) transport hosted at the `/api/mcp` route. Two authorization mechanisms are supported: -- **OAuth (preferred)**: MCP clients that support OAuth 2.0 will automatically handle the authorization flow and issue a short lived access token. No API key or manual token management required. Only available with an active [Enterprise license](/docs/activating-a-subscription). +- **OAuth (preferred)**: MCP clients that support OAuth 2.0 will automatically handle the authorization flow and issue a short lived access token. No API key or manual token management required. Only available with a paid [subscription](/docs/activating-a-subscription). - **API key**: Any MCP client can authorize using a Sourcebot API key passed as a `Authorization: Bearer ` header. Create one in **Settings → API Keys**. You can read more about the options in the [authorization](#authorization) section. diff --git a/packages/shared/src/entitlements.ts b/packages/shared/src/entitlements.ts index b6da38fc5..bcfdac6cd 100644 --- a/packages/shared/src/entitlements.ts +++ b/packages/shared/src/entitlements.ts @@ -37,10 +37,10 @@ const ALL_ENTITLEMENTS = [ "analytics", "permission-syncing", "github-app", - "chat-sharing", "org-management", "oauth", - "ask" + "ask", + "mcp" ] as const; export type Entitlement = (typeof ALL_ENTITLEMENTS)[number]; diff --git a/packages/web/next.config.mjs b/packages/web/next.config.mjs index 6211fcfe2..c34c126d0 100644 --- a/packages/web/next.config.mjs +++ b/packages/web/next.config.mjs @@ -48,6 +48,13 @@ const nextConfig = { { source: "/register", destination: "/api/ee/oauth/register", + }, + // The MCP server lives under /api/ee/mcp so it sits in the EE-licensed + // route tree, but is exposed at the stable, public /api/mcp path that + // existing MCP client configurations point at. + { + source: "/api/mcp", + destination: "/api/ee/mcp", } ]; }, diff --git a/packages/web/src/app/(app)/@sidebar/components/defaultSidebar/index.tsx b/packages/web/src/app/(app)/@sidebar/components/defaultSidebar/index.tsx index c943e0ea5..39edf49df 100644 --- a/packages/web/src/app/(app)/@sidebar/components/defaultSidebar/index.tsx +++ b/packages/web/src/app/(app)/@sidebar/components/defaultSidebar/index.tsx @@ -13,7 +13,7 @@ import { ChatHistory } from "./chatHistory"; import { RepoVisitHistory } from "./repoVisitHistory"; import { getAuthContext, withAuth } from "@/middleware/withAuth"; import { sew } from "@/middleware/sew"; -import { isValidLicenseActive } from "@/lib/entitlements"; +import { hasEntitlement, isValidLicenseActive } from "@/lib/entitlements"; const SIDEBAR_CHAT_LIMIT = 30; export const SIDEBAR_REPO_VISITS_LIMIT = 10; @@ -23,7 +23,10 @@ export async function DefaultSidebar() { const cookieStore = await cookies(); const homeView = (cookieStore.get(HOME_VIEW_COOKIE_NAME)?.value ?? "search") as HomeView; - const chatHistory = session ? await getUserChatHistory() : []; + // Chat history is part of the Ask experience; hide it when the deployment + // is not on a plan that includes Ask. + const hasAskEntitlement = await hasEntitlement('ask'); + const chatHistory = (session && hasAskEntitlement) ? await getUserChatHistory() : []; if (isServiceError(chatHistory)) { throw new ServiceErrorException(chatHistory); } @@ -63,10 +66,12 @@ export async function DefaultSidebar() { } > - SIDEBAR_CHAT_LIMIT} - /> + {hasAskEntitlement && ( + SIDEBAR_CHAT_LIMIT} + /> + )} ); } diff --git a/packages/web/src/app/(app)/@sidebar/components/defaultSidebar/nav.tsx b/packages/web/src/app/(app)/@sidebar/components/defaultSidebar/nav.tsx index d20fe1057..62d1bbf03 100644 --- a/packages/web/src/app/(app)/@sidebar/components/defaultSidebar/nav.tsx +++ b/packages/web/src/app/(app)/@sidebar/components/defaultSidebar/nav.tsx @@ -22,6 +22,9 @@ interface NavItem { key: string; requiresAuth?: boolean; requiredEntitlement?: Entitlement; + // When true, the item is hidden entirely if the required entitlement is + // missing, instead of being shown with an upgrade badge. + hideIfMissingEntitlement?: boolean; } interface NavProps { @@ -69,6 +72,8 @@ export function Nav({ icon: MessagesSquareIcon, key: "chats", requiresAuth: true, + requiredEntitlement: "ask", + hideIfMissingEntitlement: true, }, { title: "Repositories", @@ -108,7 +113,10 @@ export function Nav({ return ( - {baseItems.filter((item) => !item.requiresAuth || isSignedIn).map((item) => { + {baseItems + .filter((item) => !item.requiresAuth || isSignedIn) + .filter((item) => !item.hideIfMissingEntitlement || !item.requiredEntitlement || entitlements.includes(item.requiredEntitlement)) + .map((item) => { const showNotification = (item.key === "settings" && isSettingsNotificationVisible); diff --git a/packages/web/src/app/(app)/@sidebar/components/upgradeButton.tsx b/packages/web/src/app/(app)/@sidebar/components/upgradeButton.tsx index a506a2763..9bba91a2d 100644 --- a/packages/web/src/app/(app)/@sidebar/components/upgradeButton.tsx +++ b/packages/web/src/app/(app)/@sidebar/components/upgradeButton.tsx @@ -2,8 +2,8 @@ import { ArrowUpCircle } from "lucide-react"; import { useState } from "react"; -import { UpsellDialog } from "@/ee/features/lighthouse/upsellDialog"; -import { useOffers } from "@/ee/features/lighthouse/useOffers"; +import { UpsellDialog } from "@/features/billing/upsellDialog"; +import { useOffers } from "@/features/billing/useOffers"; import { Skeleton } from "@/components/ui/skeleton"; export const UpgradeButton = () => { diff --git a/packages/web/src/app/(app)/askgh/[owner]/[repo]/page.tsx b/packages/web/src/app/(app)/askgh/[owner]/[repo]/page.tsx index 389cbc42f..dd1dd95c3 100644 --- a/packages/web/src/app/(app)/askgh/[owner]/[repo]/page.tsx +++ b/packages/web/src/app/(app)/askgh/[owner]/[repo]/page.tsx @@ -8,6 +8,8 @@ import { RepoIndexedGuard } from "./components/repoIndexedGuard"; import { LandingPage } from "./components/landingPage"; import { getConfiguredLanguageModelsInfo } from "@/features/chat/utils.server"; import { auth } from "@/auth"; +import { hasEntitlement } from "@/lib/entitlements"; +import { ChatEntitlementMessage } from "@/features/chat/components/chatEntitlementMessage"; interface PageProps { params: Promise<{ owner: string; repo: string }>; @@ -17,6 +19,12 @@ export default async function GitHubRepoPage(props: PageProps) { const params = await props.params; const { owner, repo } = params; const session = await auth(); + + // The askgh experiment env flag must never bypass licensing; enforce `ask` + // uniformly (the demo deployment carries a real license with `ask`). + if (!await hasEntitlement('ask')) { + return ; + } const repoId = await (async () => { // 1. Look up repo by owner/repo diff --git a/packages/web/src/app/(app)/chat/[id]/page.tsx b/packages/web/src/app/(app)/chat/[id]/page.tsx index e31e0acf4..45b67be8a 100644 --- a/packages/web/src/app/(app)/chat/[id]/page.tsx +++ b/packages/web/src/app/(app)/chat/[id]/page.tsx @@ -3,7 +3,7 @@ import { getChatInfo, getSharedWithUsersForChat } from '@/features/chat/actions' import { getConfiguredLanguageModelsInfo } from "@/features/chat/utils.server"; import { ServiceErrorException } from '@/lib/serviceError'; import { isServiceError } from '@/lib/utils'; -import { ChatThreadPanel } from './components/chatThreadPanel'; +import { ChatThreadPanel } from '@/ee/features/chat/components/chatThreadPanel'; import { notFound } from 'next/navigation'; import { StatusCodes } from 'http-status-codes'; import { Separator } from '@/components/ui/separator'; @@ -16,6 +16,7 @@ import { Metadata } from 'next'; import { SBChatMessage } from '@/features/chat/types'; import { env } from '@sourcebot/shared'; import { hasEntitlement } from '@/lib/entitlements'; +import { ChatEntitlementMessage } from '@/features/chat/components/chatEntitlementMessage'; import { captureEvent } from '@/lib/posthog'; interface PageProps { @@ -81,6 +82,20 @@ export default async function Page(props: PageProps) { const params = await props.params; const session = await auth(); + // Gate the Ask experience behind the `ask` entitlement (deployment-level). + // Viewing a public/shared chat still works on a licensed deployment; a + // downgraded deployment shows the upsell while preserving the chat data. + if (!await hasEntitlement('ask')) { + return ( + + ); + } + const languageModels = await getConfiguredLanguageModelsInfo(); const repos = await getRepos(); const searchContexts = await getSearchContexts(); @@ -122,7 +137,9 @@ export default async function Page(props: PageProps) { const indexedRepos = repos.filter((repo) => repo.indexedAt !== undefined); - const hasChatSharingEntitlement = await hasEntitlement('chat-sharing'); + // Chat sharing is part of Ask (the standalone `chat-sharing` entitlement was + // folded into `ask`). By this point the page is already gated on `ask`. + const hasChatSharingEntitlement = await hasEntitlement('ask'); return (
diff --git a/packages/web/src/app/(app)/chats/chatsPage.tsx b/packages/web/src/app/(app)/chats/chatsPage.tsx index 9fc2f987f..02a9dad02 100644 --- a/packages/web/src/app/(app)/chats/chatsPage.tsx +++ b/packages/web/src/app/(app)/chats/chatsPage.tsx @@ -27,7 +27,7 @@ import { useCallback, useEffect, useMemo, useRef, useState } from "react"; import { useHotkeys } from "react-hotkeys-hook"; import { listChats } from "@/app/api/(client)/client"; -import type { ListChatsResponse } from "@/app/api/(server)/chats/types"; +import type { ListChatsResponse } from "@/app/api/(server)/ee/chats/types"; type Chat = ListChatsResponse["chats"][number]; type SortBy = "name" | "updatedAt"; diff --git a/packages/web/src/app/(app)/chats/page.tsx b/packages/web/src/app/(app)/chats/page.tsx index d665b98ab..7dc44b3a2 100644 --- a/packages/web/src/app/(app)/chats/page.tsx +++ b/packages/web/src/app/(app)/chats/page.tsx @@ -1,6 +1,19 @@ import { authenticatedPage } from "@/middleware/authenticatedPage"; import { ChatsPage } from "./chatsPage"; +import { hasEntitlement } from "@/lib/entitlements"; +import { ChatEntitlementMessage } from "@/features/chat/components/chatEntitlementMessage"; export default authenticatedPage(async () => { + if (!await hasEntitlement('ask')) { + return ( + + ); + } return ; }); +`` \ No newline at end of file diff --git a/packages/web/src/app/(app)/layout.tsx b/packages/web/src/app/(app)/layout.tsx index 9842d0333..c05cbd38c 100644 --- a/packages/web/src/app/(app)/layout.tsx +++ b/packages/web/src/app/(app)/layout.tsx @@ -31,9 +31,9 @@ import { OrgRole } from "@sourcebot/db"; import { ServiceErrorException } from "@/lib/serviceError"; import { ConnectAccountsCard } from "@/ee/features/sso/components/connectAccountsCard"; import { SidebarProvider } from "@/components/ui/sidebar"; -import { CheckoutReturnHandler } from "@/ee/features/lighthouse/checkoutReturnHandler"; +import { CheckoutReturnHandler } from "@/features/billing/checkoutReturnHandler"; import { RoleProvider } from "@/features/auth/roleProvider"; -import { HasLicenseProvider } from "@/ee/features/lighthouse/hasLicenseProvider"; +import { HasLicenseProvider } from "@/features/billing/hasLicenseProvider"; import { tryGetLatestSourcebotTag } from "./components/banners/actions"; interface LayoutProps { diff --git a/packages/web/src/app/(app)/settings/accountAskAgent/accountAskAgentEntitlementMessage.tsx b/packages/web/src/app/(app)/settings/accountAskAgent/accountAskAgentEntitlementMessage.tsx new file mode 100644 index 000000000..94f889ae4 --- /dev/null +++ b/packages/web/src/app/(app)/settings/accountAskAgent/accountAskAgentEntitlementMessage.tsx @@ -0,0 +1,23 @@ +"use client" + +import { UpsellPanel } from "@/features/billing/upsellDialog" + +/** + * Shown in place of the per-user Ask Agent connector UI when the deployment is + * not on a plan that includes Ask Sourcebot. FSL (not ee/) so it can render for + * free-plan users as the upsell surface, reusing the shared feature-breakdown + * panel (plan comparison + trial/upgrade) without mounting any ee/ connector code. + */ +export function AccountAskAgentEntitlementMessage() { + return ( +
+ +
+ ) +} diff --git a/packages/web/src/app/(app)/settings/accountAskAgent/page.tsx b/packages/web/src/app/(app)/settings/accountAskAgent/page.tsx index 078e67288..dba12d6b7 100644 --- a/packages/web/src/app/(app)/settings/accountAskAgent/page.tsx +++ b/packages/web/src/app/(app)/settings/accountAskAgent/page.tsx @@ -1,4 +1,5 @@ -import { AccountAskAgentPage } from "./accountAskAgentPage"; +import { AccountAskAgentPage } from "@/ee/features/chat/mcp/components/accountAskAgentPage"; +import { AccountAskAgentEntitlementMessage } from "./accountAskAgentEntitlementMessage"; import { hasEntitlement } from "@/lib/entitlements"; import { authenticatedPage } from "@/middleware/authenticatedPage"; import { OrgRole } from "@sourcebot/db"; @@ -12,8 +13,14 @@ interface PageProps extends Record { } export default authenticatedPage(async ({ role }, { searchParams }) => { + // Ask Agent connectors are part of Ask Sourcebot. Gate the EE connector UI + // behind the `ask` entitlement here so it never renders or executes on a + // non-entitled deployment; show the FSL upsell panel instead. + if (!(await hasEntitlement('ask'))) { + return ; + } + const { status, server, message } = await searchParams; - const isOAuthAvailable = await hasEntitlement('oauth'); return ( (async ({ role }, { searchParams }) = callbackServer={server} callbackMessage={message} canManageConnectors={role === OrgRole.OWNER} - isOAuthAvailable={isOAuthAvailable} /> ); }); diff --git a/packages/web/src/app/(app)/settings/analytics/analyticsEntitlementMessage.tsx b/packages/web/src/app/(app)/settings/analytics/analyticsEntitlementMessage.tsx new file mode 100644 index 000000000..fdcaa4f79 --- /dev/null +++ b/packages/web/src/app/(app)/settings/analytics/analyticsEntitlementMessage.tsx @@ -0,0 +1,23 @@ +"use client" + +import { UpsellPanel } from "@/features/billing/upsellDialog" + +/** + * Shown in place of the analytics dashboard when the deployment is not on a plan + * that includes analytics. FSL (not ee/) so it can render for free-plan users as + * the upsell surface, reusing the shared feature-breakdown panel (plan comparison + * + trial/upgrade) without mounting any ee/ analytics feature code. + */ +export function AnalyticsEntitlementMessage() { + return ( +
+ +
+ ) +} diff --git a/packages/web/src/app/(app)/settings/analytics/page.tsx b/packages/web/src/app/(app)/settings/analytics/page.tsx index eb0831447..c4c9c90ca 100644 --- a/packages/web/src/app/(app)/settings/analytics/page.tsx +++ b/packages/web/src/app/(app)/settings/analytics/page.tsx @@ -1,5 +1,5 @@ import { AnalyticsContent } from "@/ee/features/analytics/analyticsContent"; -import { AnalyticsEntitlementMessage } from "@/ee/features/analytics/analyticsEntitlementMessage"; +import { AnalyticsEntitlementMessage } from "./analyticsEntitlementMessage"; import { authenticatedPage } from "@/middleware/authenticatedPage"; import { OrgRole } from "@sourcebot/db"; import { hasEntitlement } from "@/lib/entitlements"; diff --git a/packages/web/src/app/(app)/settings/layout.tsx b/packages/web/src/app/(app)/settings/layout.tsx index 604601027..951a11c76 100644 --- a/packages/web/src/app/(app)/settings/layout.tsx +++ b/packages/web/src/app/(app)/settings/layout.tsx @@ -44,7 +44,7 @@ export default async function SettingsLayout( } export const getSidebarNavGroups = async () => - withAuth(async ({ org, role, prisma }) => { + withAuth(async ({ role }) => { let numJoinRequests: number | undefined; if (role === OrgRole.OWNER) { const requests = await getOrgAccountRequests(); @@ -58,12 +58,7 @@ export const getSidebarNavGroups = async () => if (isServiceError(connectionStats)) { throw new ServiceErrorException(connectionStats); } - const hasOAuthEntitlement = await hasEntitlement("oauth"); - const hasApprovedConnectors = role === OrgRole.OWNER && !hasOAuthEntitlement - ? await prisma.mcpServer.count({ - where: { orgId: org.id }, - }) > 0 - : false; + const hasAskEntitlement = await hasEntitlement("ask"); const groups: NavGroup[] = [ { @@ -92,8 +87,9 @@ export const getSidebarNavGroups = async () => title: "MCP Server", href: `/settings/mcp`, icon: 'mcp' as const, + requiredEntitlement: 'mcp', }, - ...(hasOAuthEntitlement ? [ + ...(hasAskEntitlement ? [ { title: "Ask Agent", href: `/settings/accountAskAgent`, @@ -132,13 +128,12 @@ export const getSidebarNavGroups = async () => icon: "chart-area" as const, requiredEntitlement: 'analytics' }, - ...(hasOAuthEntitlement || hasApprovedConnectors ? [ - { - title: "Ask Agent", - href: `/settings/workspaceAskAgent`, - icon: "bot" as const, - } - ] : []), + { + title: "Ask Agent", + href: `/settings/workspaceAskAgent`, + icon: "bot" as const, + requiredEntitlement: 'ask', + }, { title: "License", href: `/settings/license`, diff --git a/packages/web/src/app/(app)/settings/license/activationCodeCard.tsx b/packages/web/src/app/(app)/settings/license/activationCodeCard.tsx index d87e6205f..63873a546 100644 --- a/packages/web/src/app/(app)/settings/license/activationCodeCard.tsx +++ b/packages/web/src/app/(app)/settings/license/activationCodeCard.tsx @@ -4,7 +4,7 @@ import { useState, useCallback } from "react"; import { Input } from "@/components/ui/input"; import { LoadingButton } from "@/components/ui/loading-button"; import { SettingsCard } from "../components/settingsCard"; -import { activateLicense } from "@/ee/features/lighthouse/actions"; +import { activateLicense } from "@/features/billing/actions"; import { isServiceError } from "@/lib/utils"; import { useToast } from "@/components/hooks/use-toast"; import { Separator } from "@/components/ui/separator"; diff --git a/packages/web/src/app/(app)/settings/license/onlineLicenseCard/removeActivationCodeDialog.tsx b/packages/web/src/app/(app)/settings/license/onlineLicenseCard/removeActivationCodeDialog.tsx index d2071c82e..49593cbb9 100644 --- a/packages/web/src/app/(app)/settings/license/onlineLicenseCard/removeActivationCodeDialog.tsx +++ b/packages/web/src/app/(app)/settings/license/onlineLicenseCard/removeActivationCodeDialog.tsx @@ -13,7 +13,7 @@ import { AlertDialogTitle, } from "@/components/ui/alert-dialog"; import { useToast } from "@/components/hooks/use-toast"; -import { deactivateLicense } from "@/ee/features/lighthouse/actions"; +import { deactivateLicense } from "@/features/billing/actions"; import { isServiceError } from "@/lib/utils"; interface RemoveActivationCodeDialogProps { diff --git a/packages/web/src/app/(app)/settings/license/page.tsx b/packages/web/src/app/(app)/settings/license/page.tsx index 97b24f1ab..b84cb0df5 100644 --- a/packages/web/src/app/(app)/settings/license/page.tsx +++ b/packages/web/src/app/(app)/settings/license/page.tsx @@ -10,9 +10,9 @@ import { OfflineLicenseCard } from "./offlineLicenseCard"; import { RecentInvoicesCard } from "./recentInvoicesCard"; import { YearlyTermSeatsUsageCard } from "./yearlyTermSeatsUsageCard"; import { SettingsCard } from "../components/settingsCard"; -import { UpsellPanel } from "@/ee/features/lighthouse/upsellDialog"; +import { UpsellPanel } from "@/features/billing/upsellDialog"; import { getAllInvoices } from "@/ee/features/lighthouse/actions"; -import { syncWithLighthouse } from "@/ee/features/lighthouse/servicePing"; +import { syncWithLighthouse } from "@/features/billing/servicePing"; import { isServiceError } from "@/lib/utils"; import { getYearlyTermStatus } from "./types"; diff --git a/packages/web/src/app/(app)/settings/license/recentInvoicesCard.tsx b/packages/web/src/app/(app)/settings/license/recentInvoicesCard.tsx index 1ea6d5db0..f94645867 100644 --- a/packages/web/src/app/(app)/settings/license/recentInvoicesCard.tsx +++ b/packages/web/src/app/(app)/settings/license/recentInvoicesCard.tsx @@ -1,6 +1,6 @@ import Link from "next/link"; import { ExternalLink } from "lucide-react"; -import { Invoice } from "@/ee/features/lighthouse/types"; +import { Invoice } from "@/features/billing/types"; import { SettingsCard, SettingsCardGroup } from "../components/settingsCard"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; diff --git a/packages/web/src/app/(app)/settings/license/types.ts b/packages/web/src/app/(app)/settings/license/types.ts index f4fbbc9c3..d7660d40a 100644 --- a/packages/web/src/app/(app)/settings/license/types.ts +++ b/packages/web/src/app/(app)/settings/license/types.ts @@ -1,4 +1,4 @@ -import { YearlyTermStatus as RawYearlyTermStatus } from "@/ee/features/lighthouse/types"; +import { YearlyTermStatus as RawYearlyTermStatus } from "@/features/billing/types"; import { License } from "@sourcebot/db"; export type YearlyTermStatus = Omit & { diff --git a/packages/web/src/app/(app)/settings/mcp/clientCard.tsx b/packages/web/src/app/(app)/settings/mcp/clientCard.tsx index 72c03bfda..66fed9c6b 100644 --- a/packages/web/src/app/(app)/settings/mcp/clientCard.tsx +++ b/packages/web/src/app/(app)/settings/mcp/clientCard.tsx @@ -5,7 +5,7 @@ import { useToast } from "@/components/hooks/use-toast"; import { Check, Copy, ExternalLink } from "lucide-react"; import Image from "next/image"; import { useState } from "react"; -import { SettingsCard } from "../components/settingsCard"; +import { SettingsCard } from "@/app/(app)/settings/components/settingsCard"; import { type McpClient, buildClientAction } from "./clients"; interface ClientCardProps { diff --git a/packages/web/src/app/(app)/settings/mcp/mcpEntitlementMessage.tsx b/packages/web/src/app/(app)/settings/mcp/mcpEntitlementMessage.tsx new file mode 100644 index 000000000..eae7dd0fd --- /dev/null +++ b/packages/web/src/app/(app)/settings/mcp/mcpEntitlementMessage.tsx @@ -0,0 +1,23 @@ +"use client" + +import { UpsellPanel } from "@/features/billing/upsellDialog" + +/** + * Shown in place of the MCP server setup UI when the deployment is not on a plan + * that includes the MCP server. FSL (not ee/) so it can render for free-plan + * users as the upsell surface, reusing the shared feature-breakdown panel + * (plan comparison + trial/upgrade) without mounting any ee/ MCP feature code. + */ +export function McpEntitlementMessage() { + return ( +
+ +
+ ) +} diff --git a/packages/web/src/app/(app)/settings/mcp/mcpPage.tsx b/packages/web/src/app/(app)/settings/mcp/mcpPage.tsx index 38eac2872..d42a3b685 100644 --- a/packages/web/src/app/(app)/settings/mcp/mcpPage.tsx +++ b/packages/web/src/app/(app)/settings/mcp/mcpPage.tsx @@ -16,28 +16,33 @@ import { useToast } from "@/components/hooks/use-toast"; import { type ConnectedOauthClient, revokeMcpClient } from "@/ee/features/oauth/actions"; import { isServiceError } from "@/lib/utils"; import { formatDistanceToNow } from "date-fns"; -import { Boxes, Trash2 } from "lucide-react"; +import { AlertTriangle, Boxes, Trash2 } from "lucide-react"; import Image from "next/image"; import Link from "next/link"; import { useRouter } from "next/navigation"; -import { CopyIconButton } from "../../components/copyIconButton"; -import { SettingsCard, SettingsCardGroup } from "../components/settingsCard"; +import { useState } from "react"; +import { CopyIconButton } from "@/app/(app)/components/copyIconButton"; +import { SettingsCard, SettingsCardGroup } from "@/app/(app)/settings/components/settingsCard"; import { ClientCard } from "./clientCard"; import { MCP_CLIENTS, matchKnownClient } from "./clients"; +import { UpsellDialog } from "@/features/billing/upsellDialog"; const DOCS_URL = "https://docs.sourcebot.dev/docs/features/mcp-server"; interface McpPageProps { mcpServerUrl: string; connectedClients: ConnectedOauthClient[]; + isMcpEnabled: boolean; } export function McpPage({ mcpServerUrl, - connectedClients + connectedClients, + isMcpEnabled }: McpPageProps) { const { toast } = useToast(); const router = useRouter(); + const [isUpsellDialogOpen, setIsUpsellDialogOpen] = useState(false); const handleCopyServerUrl = () => { navigator.clipboard.writeText(mcpServerUrl) @@ -70,33 +75,64 @@ export function McpPage({

MCP Server

- Connect AI coding tools to search and read your code through Sourcebot's MCP server. Learn more + Connect your agents to Sourcebot to allow them to fetch code context, and more. Learn more

- -

Server URL

-
-
- {mcpServerUrl} + {!isMcpEnabled && ( + <> +
+ +
+

The MCP server is unavailable on your plan

+

+ You can disconnect existing clients below, but connecting new clients requires{" "} + . +

+
- -
- + + + )} -
-
-

Install in a client

-

- Set up Sourcebot in your editor or coding agent. -

-
-
- {MCP_CLIENTS.map((client) => ( - - ))} -
-
+ {isMcpEnabled && ( + <> + +

Server URL

+
+
+ {mcpServerUrl} +
+ +
+
+ +
+
+

Install in a client

+

+ Set up Sourcebot in your editor or coding agent. +

+
+
+ {MCP_CLIENTS.map((client) => ( + + ))} +
+
+ + )}
diff --git a/packages/web/src/app/(app)/settings/mcp/page.tsx b/packages/web/src/app/(app)/settings/mcp/page.tsx index a24e7ba37..58165bd24 100644 --- a/packages/web/src/app/(app)/settings/mcp/page.tsx +++ b/packages/web/src/app/(app)/settings/mcp/page.tsx @@ -2,11 +2,13 @@ import { authenticatedPage } from "@/middleware/authenticatedPage"; import { getConnectedOauthClients } from "@/ee/features/oauth/actions"; import { ServiceErrorException } from "@/lib/serviceError"; import { isServiceError } from "@/lib/utils"; +import { hasEntitlement } from "@/lib/entitlements"; import { env } from "@sourcebot/shared"; import { McpPage } from "./mcpPage"; +import { McpEntitlementMessage } from "./mcpEntitlementMessage"; export default authenticatedPage(async () => { - const mcpServerUrl = `${env.AUTH_URL.replace(/\/$/, '')}/api/mcp`; + const hasMcpEntitlement = await hasEntitlement('mcp'); /** * @note at the time of writing (May 26, 26'), the only type of @@ -19,10 +21,22 @@ export default authenticatedPage(async () => { throw new ServiceErrorException(connectedClients); } + // The MCP server is a paid feature, but a downgraded deployment must still + // be able to revoke previously-connected clients. So render the page when + // entitled, or when there are connected clients to disconnect; otherwise + // show the upgrade prompt. The page itself hides the setup surface (server + // URL + install instructions) when the entitlement is absent. + if (!hasMcpEntitlement && connectedClients.length === 0) { + return ; + } + + const mcpServerUrl = `${env.AUTH_URL.replace(/\/$/, '')}/api/mcp`; + return ( ) }); diff --git a/packages/web/src/app/(app)/settings/workspaceAskAgent/page.test.tsx b/packages/web/src/app/(app)/settings/workspaceAskAgent/page.test.tsx index 8d5be7a56..19e8ae7c5 100644 --- a/packages/web/src/app/(app)/settings/workspaceAskAgent/page.test.tsx +++ b/packages/web/src/app/(app)/settings/workspaceAskAgent/page.test.tsx @@ -24,6 +24,9 @@ vi.mock('@/middleware/authenticatedPage', () => ({ vi.mock('./workspaceAskAgentPage', () => ({ WorkspaceAskAgentPage: () =>
Workspace Ask Agent client
, })); +vi.mock('./workspaceAskAgentEntitlementMessage', () => ({ + WorkspaceAskAgentEntitlementMessage: () =>
Upgrade to configure Ask Agent connectors
, +})); const { default: Page } = await import('./page'); @@ -38,13 +41,13 @@ afterEach(() => { }); describe('Ask Agent settings page', () => { - test('renders the client configuration page when OAuth is available', async () => { + test('renders the connector configuration page when Ask Agent is available', async () => { render(await Page({ searchParams: Promise.resolve({}) })); expect(screen.getByText('Workspace Ask Agent client')).toBeTruthy(); }); - test('renders the client configuration page when OAuth is unavailable but servers exist for cleanup', async () => { + test('renders the connector page for teardown when Ask Agent is unavailable but connectors exist', async () => { mocks.hasEntitlement.mockResolvedValue(false); mocks.authContext.prisma.mcpServer.count.mockResolvedValue(1); @@ -56,12 +59,12 @@ describe('Ask Agent settings page', () => { }); }); - test('renders an unavailable message when OAuth is not available and no cleanup is needed', async () => { + test('renders the upsell message when Ask Agent is unavailable and no connectors exist', async () => { mocks.hasEntitlement.mockResolvedValue(false); render(await Page({ searchParams: Promise.resolve({}) })); - expect(screen.getByText('Ask Agent Connectors Are Unavailable')).toBeTruthy(); + expect(screen.getByText('Upgrade to configure Ask Agent connectors')).toBeTruthy(); expect(screen.queryByText('Workspace Ask Agent client')).toBeNull(); }); }); diff --git a/packages/web/src/app/(app)/settings/workspaceAskAgent/page.tsx b/packages/web/src/app/(app)/settings/workspaceAskAgent/page.tsx index 1b4eeef14..06e1bddb2 100644 --- a/packages/web/src/app/(app)/settings/workspaceAskAgent/page.tsx +++ b/packages/web/src/app/(app)/settings/workspaceAskAgent/page.tsx @@ -2,7 +2,7 @@ import { hasEntitlement } from "@/lib/entitlements"; import { authenticatedPage } from "@/middleware/authenticatedPage"; import { OrgRole } from "@sourcebot/db"; import { WorkspaceAskAgentPage } from "./workspaceAskAgentPage"; -import { WorkspaceAskAgentUnavailableMessage } from "./workspaceAskAgentUnavailableMessage"; +import { WorkspaceAskAgentEntitlementMessage } from "./workspaceAskAgentEntitlementMessage"; interface PageProps extends Record { searchParams: Promise<{ @@ -13,13 +13,19 @@ interface PageProps extends Record { } export default authenticatedPage(async ({ org, prisma }, { searchParams }) => { - if (!(await hasEntitlement("oauth"))) { + // Adding connectors requires the `ask` entitlement. But a downgraded + // workspace must still be able to view and remove previously-configured + // connectors, so this page lives in FSL: when connectors already exist we + // render it for teardown (the page itself disables "add" and only allows + // removal in that state). We only show the upsell when there is nothing to + // clean up. + if (!(await hasEntitlement("ask"))) { const serverCount = await prisma.mcpServer.count({ where: { orgId: org.id }, }); if (serverCount === 0) { - return ; + return ; } } diff --git a/packages/web/src/app/(app)/settings/workspaceAskAgent/workspaceAskAgentEntitlementMessage.tsx b/packages/web/src/app/(app)/settings/workspaceAskAgent/workspaceAskAgentEntitlementMessage.tsx new file mode 100644 index 000000000..0fd09ef1a --- /dev/null +++ b/packages/web/src/app/(app)/settings/workspaceAskAgent/workspaceAskAgentEntitlementMessage.tsx @@ -0,0 +1,24 @@ +"use client" + +import { UpsellPanel } from "@/features/billing/upsellDialog" + +/** + * Shown in place of the workspace Ask Agent connector configuration UI when the + * deployment is not on a plan that includes Ask Sourcebot. FSL (not ee/) so it + * can render for free-plan users as the upsell surface, reusing the shared + * feature-breakdown panel (plan comparison + trial/upgrade) without mounting any + * ee/ connector code. + */ +export function WorkspaceAskAgentEntitlementMessage() { + return ( +
+ +
+ ) +} diff --git a/packages/web/src/app/(app)/settings/workspaceAskAgent/workspaceAskAgentPage.tsx b/packages/web/src/app/(app)/settings/workspaceAskAgent/workspaceAskAgentPage.tsx index 18fbc4411..8e8f08d10 100644 --- a/packages/web/src/app/(app)/settings/workspaceAskAgent/workspaceAskAgentPage.tsx +++ b/packages/web/src/app/(app)/settings/workspaceAskAgent/workspaceAskAgentPage.tsx @@ -1,6 +1,7 @@ 'use client'; import { useEffect, useMemo, useRef, useState } from "react"; +import { useRouter } from "next/navigation"; import { getMcpConfiguration, getMcpServersWithStatus } from "@/app/api/(client)/client"; import { useToast } from "@/components/hooks/use-toast"; import { @@ -24,11 +25,11 @@ import { ConnectMcpButton } from "@/ee/features/chat/mcp/components/connectMcpBu import { ConnectorCard } from "@/ee/features/chat/mcp/components/connectorCard"; import { useMcpToolMetadata } from "@/ee/features/chat/mcp/hooks/useMcpToolMetadata"; import { invalidateMcpConfigurationQueries, mcpQueryKeys } from "@/ee/features/chat/mcp/queryKeys"; -import { pluralize } from "@/ee/features/chat/mcp/utils"; +import { pluralize } from "@/features/chat/mcp/utils"; import { cn, isServiceError } from "@/lib/utils"; import { useQuery, useQueryClient } from "@tanstack/react-query"; import { AlertTriangleIcon, CableIcon, CopyIcon, Loader2, MoreHorizontalIcon, PlusIcon, Trash2Icon } from "lucide-react"; -import { PrefabConnectorPopover } from "./prefabConnectorPopover"; +import { PrefabConnectorPopover } from "@/ee/features/chat/mcp/components/prefabConnectorPopover"; import type { PrefabMcpServer } from "@/ee/features/chat/mcp/prefabMcpServers"; import type { McpConfigurationServer, ServerToolsEntry } from "@/ee/features/chat/mcp/types"; @@ -54,7 +55,7 @@ type WorkspaceConnectorStatus = { interface WorkspaceConnectorCardProps { server: McpConfigurationServer; status?: WorkspaceConnectorStatus; - isOAuthAvailable: boolean; + isAskAgentAvailable: boolean; isStatusLoading: boolean; isStatusError: boolean; toolEntry?: ServerToolsEntry; @@ -68,7 +69,7 @@ interface WorkspaceConnectorCardProps { function WorkspaceConnectorCard({ server, status, - isOAuthAvailable, + isAskAgentAvailable, isStatusLoading, isStatusError, toolEntry, @@ -80,8 +81,8 @@ function WorkspaceConnectorCard({ }: WorkspaceConnectorCardProps) { const isConnected = status?.isConnected === true; const isAuthExpired = status?.isAuthExpired === true; - const isStatusUnavailable = isOAuthAvailable !== true || isStatusLoading || isStatusError || !status; - const showConnectButton = isOAuthAvailable && !isStatusLoading && !isStatusError && !!status && !isConnected; + const isStatusUnavailable = isAskAgentAvailable !== true || isStatusLoading || isStatusError || !status; + const showConnectButton = isAskAgentAvailable && !isStatusLoading && !isStatusError && !!status && !isConnected; const serverLabel = server.name || server.serverUrl; return ( @@ -91,7 +92,7 @@ function WorkspaceConnectorCard({ serverUrl={server.serverUrl} isConnected={isConnected} isAuthExpired={isAuthExpired} - isOAuthAvailable={isOAuthAvailable} + isAskAgentAvailable={isAskAgentAvailable} isStatusUnavailable={isStatusUnavailable} toolEntry={isConnected ? toolEntry : undefined} toolUsage={server.toolUsage} @@ -160,6 +161,7 @@ function WorkspaceConnectorCard({ export function WorkspaceAskAgentPage({ callbackStatus, callbackServer, callbackMessage }: WorkspaceAskAgentPageProps) { const { toast } = useToast(); const queryClient = useQueryClient(); + const router = useRouter(); const didHandleCallbackRef = useRef(false); useEffect(() => { @@ -210,7 +212,7 @@ export function WorkspaceAskAgentPage({ callbackStatus, callbackServer, callback } return result; }, - enabled: data?.isOAuthAvailable !== false, + enabled: data?.isAskAgentAvailable !== false, }); const myStatusByServerId = useMemo(() => { @@ -222,8 +224,8 @@ export function WorkspaceAskAgentPage({ callbackStatus, callbackServer, callback }, [serversWithStatus]); const servers = data?.servers ?? []; - const canCreateConnectors = data?.isOAuthAvailable === true; - const isOAuthUnavailable = data?.isOAuthAvailable === false; + const canCreateConnectors = data?.isAskAgentAvailable === true; + const isAskAgentUnavailable = data?.isAskAgentAvailable === false; const connectedServerCount = useMemo( () => serversWithStatus?.filter((server) => server.isConnected).length ?? 0, [serversWithStatus], @@ -233,7 +235,7 @@ export function WorkspaceAskAgentPage({ callbackStatus, callbackServer, callback isToolsError, refetchTools, toolsByServerId, - } = useMcpToolMetadata(data?.isOAuthAvailable === true, connectedServerCount); + } = useMcpToolMetadata(data?.isAskAgentAvailable === true, connectedServerCount); const handleCreateDialogOpenChange = (open: boolean) => { setIsCreateDialogOpen(open); @@ -367,6 +369,16 @@ export function WorkspaceAskAgentPage({ callbackStatus, callbackServer, callback await invalidateMcpConfigurationQueries(queryClient); setServerToDelete(null); + + // When the last connector is removed, re-run the server-side gate in + // page.tsx, which swaps this page for the upsell if the deployment + // lacks the `ask` entitlement. Only refresh in that case; otherwise + // the query invalidation above already updates the list. (`servers` + // is the pre-deletion list captured in this closure.) + const isLastServer = servers.filter((s) => s.id !== serverId).length === 0; + if (isLastServer) { + router.refresh(); + } } catch (error) { toast({ title: "Error", description: `Failed to remove connector: ${error}`, variant: "destructive" }); } finally { @@ -402,12 +414,12 @@ export function WorkspaceAskAgentPage({ callbackStatus, callbackServer, callback - {/* OAuth unavailable warning */} - {!isLoading && isOAuthUnavailable && ( + {/* Ask Agent unavailable warning */} + {!isLoading && isAskAgentUnavailable && (
-

Connector OAuth is unavailable

+

Ask Agent connectors are unavailable

You can remove existing approved connectors and stored credentials, but cannot add new connectors.

@@ -418,7 +430,7 @@ export function WorkspaceAskAgentPage({ callbackStatus, callbackServer, callback {/* Connectors section */}
-

Connectors

+

Connectors

Connectors are MCP servers that let Ask Agent use approved external tools alongside your indexed code.

@@ -430,7 +442,7 @@ export function WorkspaceAskAgentPage({ callbackStatus, callbackServer, callback

Allowed connectors

- {isOAuthUnavailable + {isAskAgentUnavailable ? "Remove existing connector approvals and their stored credentials." : "Approve connector URLs that workspace members can connect to."}

@@ -468,8 +480,8 @@ export function WorkspaceAskAgentPage({ callbackStatus, callbackServer, callback

No connectors configured yet

- {isOAuthUnavailable - ? "Connector OAuth is unavailable on this Sourcebot instance." + {isAskAgentUnavailable + ? "Ask Agent connectors are unavailable on this Sourcebot instance." : "Add a workspace-approved connector so members can use it with Ask Agent."}

@@ -480,7 +492,7 @@ export function WorkspaceAskAgentPage({ callbackStatus, callbackServer, callback key={server.id} server={server} status={myStatusByServerId.get(server.id)} - isOAuthAvailable={data?.isOAuthAvailable === true} + isAskAgentAvailable={data?.isAskAgentAvailable === true} isStatusLoading={isServersWithStatusLoading} isStatusError={isServersWithStatusError} toolEntry={toolsByServerId.get(server.id)} diff --git a/packages/web/src/app/api/(client)/client.ts b/packages/web/src/app/api/(client)/client.ts index ecc95818a..6c119341d 100644 --- a/packages/web/src/app/api/(client)/client.ts +++ b/packages/web/src/app/api/(client)/client.ts @@ -2,7 +2,7 @@ import { ServiceError } from "@/lib/serviceError"; import { GetVersionResponse, ListReposQueryParams, ListReposResponse } from "@/lib/types"; -import type { ListChatsQueryParams, ListChatsResponse } from "../(server)/chats/types"; +import type { ListChatsQueryParams, ListChatsResponse } from "../(server)/ee/chats/types"; import { isServiceError } from "@/lib/utils"; import { SearchRequest, @@ -30,7 +30,7 @@ import type { SearchChatShareableMembersQueryParams, SearchChatShareableMembersResponse, } from "../(server)/ee/chat/[chatId]/searchMembers/route"; -import { OffersResponse } from "@/ee/features/lighthouse/types"; +import { OffersResponse } from "@/features/billing/types"; import { ConnectMcpResponse } from "../(server)/ee/askmcp/connect/types"; import type { GetMcpServersResponse } from "../(server)/ee/askmcp/servers/route"; import type { GetMcpConfigurationResponse, GetMcpToolsResponse } from "@/ee/features/chat/mcp/types"; @@ -215,7 +215,7 @@ export const searchChatShareableMembers = async ( } export const listChats = async (queryParams: ListChatsQueryParams): Promise => { - const url = new URL("/api/chats", window.location.origin); + const url = new URL("/api/ee/chats", window.location.origin); for (const [key, value] of Object.entries(queryParams)) { if (value !== undefined) { url.searchParams.set(key, String(value)); diff --git a/packages/web/src/app/api/(server)/chat/blocking/route.ts b/packages/web/src/app/api/(server)/chat/blocking/route.ts index 56165f2c6..0098d97a5 100644 --- a/packages/web/src/app/api/(server)/chat/blocking/route.ts +++ b/packages/web/src/app/api/(server)/chat/blocking/route.ts @@ -1,4 +1,4 @@ -import { askCodebase } from "@/features/mcp/askCodebase"; +import { askCodebase } from "@/ee/features/mcp/askCodebase"; import { languageModelInfoSchema } from "@/features/chat/types"; import { apiHandler } from "@/lib/apiHandler"; import { requestBodySchemaValidationError, serviceErrorResponse } from "@/lib/serviceError"; diff --git a/packages/web/src/app/api/(server)/ee/askmcp/callback/route.test.ts b/packages/web/src/app/api/(server)/ee/askmcp/callback/route.test.ts index 430cef51d..5b669b509 100644 --- a/packages/web/src/app/api/(server)/ee/askmcp/callback/route.test.ts +++ b/packages/web/src/app/api/(server)/ee/askmcp/callback/route.test.ts @@ -54,7 +54,7 @@ vi.mock('@ai-sdk/mcp', () => ({ })); const { GET } = await import('./route'); -const { createMcpOAuthState } = await import('@/features/mcp/mcpOAuthReturnTo'); +const { createMcpOAuthState } = await import('@/ee/features/chat/mcp/mcpOAuthReturnTo'); function createRequest(state = 'state-1') { return new NextRequest(`https://sourcebot.example.com/api/ee/askmcp/callback?code=code-1&state=${encodeURIComponent(state)}`, { diff --git a/packages/web/src/app/api/(server)/ee/askmcp/callback/route.ts b/packages/web/src/app/api/(server)/ee/askmcp/callback/route.ts index c2d15704e..0b9d5c55a 100644 --- a/packages/web/src/app/api/(server)/ee/askmcp/callback/route.ts +++ b/packages/web/src/app/api/(server)/ee/askmcp/callback/route.ts @@ -3,7 +3,7 @@ import { apiHandler } from '@/lib/apiHandler'; import { env, createLogger } from '@sourcebot/shared'; import { hasEntitlement } from '@/lib/entitlements'; import { OAUTH_NOT_SUPPORTED_ERROR_MESSAGE } from '@/ee/features/oauth/constants'; -import { PrismaOAuthClientProvider } from '@/features/mcp/prismaOAuthClientProvider'; +import { PrismaOAuthClientProvider } from '@/ee/features/chat/mcp/prismaOAuthClientProvider'; // Note: We use the raw (unscoped) prisma client here because this route handles OAuth // redirect callbacks from external providers, so it can't go through withAuth. Session // identity is verified via NextAuth's auth() instead, and all queries filter by userId. @@ -11,7 +11,7 @@ import { __unsafePrisma as prisma } from '@/prisma'; import { auth } from '@/auth'; import { NextRequest, NextResponse } from 'next/server'; import { getExternalMcpErrorLogFields } from '@/ee/features/chat/mcp/externalMcpError'; -import { getMcpOAuthReturnToFromState } from '@/features/mcp/mcpOAuthReturnTo'; +import { getMcpOAuthReturnToFromState } from '@/ee/features/chat/mcp/mcpOAuthReturnTo'; import { captureEvent } from '@/lib/posthog'; import { getMcpAuthMode, getMcpConnectorEntryPoint, getMcpConnectorFailureReason } from '@/ee/features/chat/mcp/analytics'; @@ -41,7 +41,7 @@ function redirectToCallbackError(message: string, returnTo?: string) { // eslint-disable-next-line authz/require-auth-wrapper -- OAuth redirect callback validates the active session with auth() and filters all queries by userId. export const GET = apiHandler(async (request: NextRequest) => { - if (!(await hasEntitlement('oauth'))) { + if (!(await hasEntitlement('ask'))) { return Response.json( { error: 'access_denied', error_description: OAUTH_NOT_SUPPORTED_ERROR_MESSAGE }, { status: 403 } diff --git a/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.test.ts b/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.test.ts index 1c868baae..fa2968d87 100644 --- a/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.test.ts +++ b/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.test.ts @@ -140,7 +140,7 @@ describe('GET /api/ee/askmcp/configuration', () => { }); expect(body).toMatchObject({ allowedMode: 'approved_only', - isOAuthAvailable: true, + isAskAgentAvailable: true, servers: [ { id: 'server-1', @@ -204,7 +204,7 @@ describe('GET /api/ee/askmcp/configuration', () => { expect(mocks.unsafePrisma.mcpServerToolCallCount.findMany).not.toHaveBeenCalled(); }); - test('rejects unauthenticated callers before checking OAuth entitlement', async () => { + test('rejects unauthenticated callers before checking the ask entitlement', async () => { mocks.withAuth.mockResolvedValue({ statusCode: 401, errorCode: ErrorCode.NOT_AUTHENTICATED, @@ -223,7 +223,7 @@ describe('GET /api/ee/askmcp/configuration', () => { expect(mocks.unsafePrisma.mcpServerToolCallCount.findMany).not.toHaveBeenCalled(); }); - test('allows entitled owners to list cleanup data when OAuth is unsupported', async () => { + test('allows entitled owners to list cleanup data when Ask Agent is unavailable', async () => { const prisma = createPrismaMock(); mocks.authContext = { org: { id: 1 }, @@ -237,7 +237,7 @@ describe('GET /api/ee/askmcp/configuration', () => { expect(response.status).toBe(200); expect(body).toMatchObject({ - isOAuthAvailable: false, + isAskAgentAvailable: false, servers: [ { id: 'server-1', @@ -272,7 +272,7 @@ describe('GET /api/ee/askmcp/configuration', () => { expect(body).toEqual({ servers: [], allowedMode: 'approved_only', - isOAuthAvailable: true, + isAskAgentAvailable: true, }); }); }); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.ts b/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.ts index b4d3949a6..55ad15afd 100644 --- a/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.ts +++ b/packages/web/src/app/api/(server)/ee/askmcp/configuration/route.ts @@ -5,7 +5,7 @@ import { hasEntitlement } from '@/lib/entitlements'; import { withAuth } from '@/middleware/withAuth'; import { withMinimumOrgRole } from '@/middleware/withMinimumOrgRole'; import { __unsafePrisma } from '@/prisma'; -import { getMcpFaviconUrl } from '@/ee/features/chat/mcp/utils'; +import { getMcpFaviconUrl } from '@/features/chat/mcp/utils'; import type { GetMcpConfigurationResponse, McpServerToolUsageSummary } from '@/ee/features/chat/mcp/types'; import { OrgRole } from '@sourcebot/db'; import type { NextRequest } from 'next/server'; @@ -13,7 +13,7 @@ import type { NextRequest } from 'next/server'; export const GET = apiHandler(async (_request: NextRequest) => { const result = await withAuth(async ({ org, role, prisma }) => withMinimumOrgRole(role, OrgRole.OWNER, async (): Promise => { - const isOAuthAvailable = await hasEntitlement('oauth'); + const isAskAgentAvailable = await hasEntitlement('ask'); const orgServers = await prisma.mcpServer.findMany({ where: { orgId: org.id }, @@ -111,7 +111,7 @@ export const GET = apiHandler(async (_request: NextRequest) => { return { servers, allowedMode: 'approved_only', - isOAuthAvailable, + isAskAgentAvailable, }; })); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/connect/route.test.ts b/packages/web/src/app/api/(server)/ee/askmcp/connect/route.test.ts index 6b9561ac6..2e09e53f5 100644 --- a/packages/web/src/app/api/(server)/ee/askmcp/connect/route.test.ts +++ b/packages/web/src/app/api/(server)/ee/askmcp/connect/route.test.ts @@ -46,7 +46,7 @@ vi.mock('@ai-sdk/mcp', () => ({ })); const { POST } = await import('./route'); -const { getMcpOAuthReturnToFromState } = await import('@/features/mcp/mcpOAuthReturnTo'); +const { getMcpOAuthReturnToFromState } = await import('@/ee/features/chat/mcp/mcpOAuthReturnTo'); function createRequest(body: { serverId: string; returnTo?: string } = { serverId: 'server-1' }) { return new NextRequest('https://sourcebot.example.com/api/ee/askmcp/connect', { diff --git a/packages/web/src/app/api/(server)/ee/askmcp/connect/route.ts b/packages/web/src/app/api/(server)/ee/askmcp/connect/route.ts index a6b3aff02..bcdb92708 100644 --- a/packages/web/src/app/api/(server)/ee/askmcp/connect/route.ts +++ b/packages/web/src/app/api/(server)/ee/askmcp/connect/route.ts @@ -4,7 +4,7 @@ import { withAuth } from '@/middleware/withAuth'; import { sew } from '@/middleware/sew'; import { isServiceError } from '@/lib/utils'; import { serviceErrorResponse, notFound, requestBodySchemaValidationError, ServiceErrorException } from '@/lib/serviceError'; -import { PrismaOAuthClientProvider } from '@/features/mcp/prismaOAuthClientProvider'; +import { PrismaOAuthClientProvider } from '@/ee/features/chat/mcp/prismaOAuthClientProvider'; import { NextRequest } from 'next/server'; import { z } from 'zod'; import { hasEntitlement } from '@/lib/entitlements'; @@ -15,7 +15,7 @@ import { __unsafePrisma } from '@/prisma'; import { getExternalMcpErrorLogFields } from '@/ee/features/chat/mcp/externalMcpError'; import { ErrorCode } from '@/lib/errorCodes'; import { StatusCodes } from 'http-status-codes'; -import { normalizeMcpOAuthReturnTo } from '@/features/mcp/mcpOAuthReturnTo'; +import { normalizeMcpOAuthReturnTo } from '@/ee/features/chat/mcp/mcpOAuthReturnTo'; import { captureEvent } from '@/lib/posthog'; import { getMcpAuthMode, getMcpConnectorEntryPoint, getMcpConnectorFailureReason } from '@/ee/features/chat/mcp/analytics'; @@ -43,7 +43,7 @@ function createTimeoutFetch(timeoutMs: number): typeof fetch { } export const POST = apiHandler(async (request: NextRequest) => { - if (!(await hasEntitlement('oauth'))) { + if (!(await hasEntitlement('ask'))) { return Response.json( { error: 'access_denied', error_description: OAUTH_NOT_SUPPORTED_ERROR_MESSAGE }, { status: 403 } diff --git a/packages/web/src/app/api/(server)/ee/askmcp/servers/route.test.ts b/packages/web/src/app/api/(server)/ee/askmcp/servers/route.test.ts index 42417d501..2020bef11 100644 --- a/packages/web/src/app/api/(server)/ee/askmcp/servers/route.test.ts +++ b/packages/web/src/app/api/(server)/ee/askmcp/servers/route.test.ts @@ -72,7 +72,7 @@ beforeEach(() => { }); describe('GET /api/ee/askmcp/servers', () => { - test('returns an empty array when the oauth entitlement is not granted', async () => { + test('returns an empty array when the ask entitlement is not granted', async () => { mocks.hasEntitlement.mockResolvedValue(false); const prisma = createPrismaMock(); mocks.authContext = { diff --git a/packages/web/src/app/api/(server)/ee/askmcp/servers/route.ts b/packages/web/src/app/api/(server)/ee/askmcp/servers/route.ts index 8ccb3527d..d81802a36 100644 --- a/packages/web/src/app/api/(server)/ee/askmcp/servers/route.ts +++ b/packages/web/src/app/api/(server)/ee/askmcp/servers/route.ts @@ -3,7 +3,7 @@ import { serviceErrorResponse } from '@/lib/serviceError'; import { isServiceError } from '@/lib/utils'; import { withAuth } from '@/middleware/withAuth'; import { hasEntitlement } from '@/lib/entitlements'; -import { getMcpFaviconUrl } from '@/ee/features/chat/mcp/utils'; +import { getMcpFaviconUrl } from '@/features/chat/mcp/utils'; import { getStoredMcpConnectionStatus } from '@/ee/features/chat/mcp/connectionStatus'; import type { NextRequest } from 'next/server'; @@ -20,7 +20,7 @@ export interface McpServerWithStatus { export type GetMcpServersResponse = McpServerWithStatus[]; export const GET = apiHandler(async (_request: NextRequest) => { - if (!(await hasEntitlement('oauth'))) { + if (!(await hasEntitlement('ask'))) { return Response.json([] satisfies GetMcpServersResponse); } diff --git a/packages/web/src/app/api/(server)/ee/askmcp/tools/route.test.ts b/packages/web/src/app/api/(server)/ee/askmcp/tools/route.test.ts index 79cf8164d..a409b837f 100644 --- a/packages/web/src/app/api/(server)/ee/askmcp/tools/route.test.ts +++ b/packages/web/src/app/api/(server)/ee/askmcp/tools/route.test.ts @@ -62,7 +62,7 @@ describe('GET /api/ee/askmcp/tools', () => { ]); }); - test('returns access_denied when OAuth is unavailable', async () => { + test('returns access_denied when Ask Agent is unavailable', async () => { mocks.hasEntitlement.mockResolvedValue(false); const response = await GET(createRequest()); diff --git a/packages/web/src/app/api/(server)/ee/askmcp/tools/route.ts b/packages/web/src/app/api/(server)/ee/askmcp/tools/route.ts index aea01a7e7..d792c5b97 100644 --- a/packages/web/src/app/api/(server)/ee/askmcp/tools/route.ts +++ b/packages/web/src/app/api/(server)/ee/askmcp/tools/route.ts @@ -8,7 +8,7 @@ import { getMcpToolMetadata } from '@/ee/features/chat/mcp/mcpToolMetadata'; import type { NextRequest } from 'next/server'; export const GET = apiHandler(async (_request: NextRequest) => { - if (!(await hasEntitlement('oauth'))) { + if (!(await hasEntitlement('ask'))) { return Response.json( { error: 'access_denied', error_description: OAUTH_NOT_SUPPORTED_ERROR_MESSAGE }, { status: 403 }, diff --git a/packages/web/src/app/api/(server)/ee/chat/[chatId]/searchMembers/route.ts b/packages/web/src/app/api/(server)/ee/chat/[chatId]/searchMembers/route.ts index 22847e816..871db72d0 100644 --- a/packages/web/src/app/api/(server)/ee/chat/[chatId]/searchMembers/route.ts +++ b/packages/web/src/app/api/(server)/ee/chat/[chatId]/searchMembers/route.ts @@ -41,11 +41,11 @@ export const GET = apiHandler(async ( }) } - if (!await hasEntitlement('chat-sharing')) { + if (!await hasEntitlement('ask')) { return serviceErrorResponse({ statusCode: StatusCodes.FORBIDDEN, errorCode: ErrorCode.UNEXPECTED_ERROR, - message: "Chat sharing is not enabled for your license", + message: "Chat sharing requires a paid plan", }) } diff --git a/packages/web/src/app/api/(server)/chat/route.ts b/packages/web/src/app/api/(server)/ee/chat/route.ts similarity index 91% rename from packages/web/src/app/api/(server)/chat/route.ts rename to packages/web/src/app/api/(server)/ee/chat/route.ts index 77379457d..f0c8eafb8 100644 --- a/packages/web/src/app/api/(server)/chat/route.ts +++ b/packages/web/src/app/api/(server)/ee/chat/route.ts @@ -1,9 +1,10 @@ import { sew } from "@/middleware/sew"; -import { getAskMcpAvailabilityAnalytics, getAskMcpTurnCompletedAnalytics } from "@/features/chat/askMcpAnalytics.server"; -import { createMessageStream } from "@/features/chat/agent"; +import { getAskMcpAvailabilityAnalytics, getAskMcpTurnCompletedAnalytics } from "@/ee/features/chat/askMcpAnalytics.server"; +import { createMessageStream } from "@/ee/features/chat/agent"; import { additionalChatRequestParamsSchema } from "@/features/chat/types"; import { getLanguageModelKey } from "@/features/chat/utils"; -import { getAISDKLanguageModelAndOptions, getConfiguredLanguageModels, isOwnerOfChat, updateChatMessages } from "@/features/chat/utils.server"; +import { checkAskEntitlement, getConfiguredLanguageModels, isOwnerOfChat, updateChatMessages } from "@/features/chat/utils.server"; +import { getAISDKLanguageModelAndOptions } from "@/features/chat/llm.server"; import { apiHandler } from "@/lib/apiHandler"; import { ErrorCode } from "@/lib/errorCodes"; import { captureEvent } from "@/lib/posthog"; @@ -42,6 +43,13 @@ export const POST = apiHandler(async (req: NextRequest) => { const response = await sew(() => withOptionalAuth(async ({ org, user, prisma }) => { + // Gate the generative path behind the `ask` entitlement. The client + // also gates this, but server-side enforcement can't be bypassed. + const askError = await checkAskEntitlement(); + if (askError) { + return askError; + } + // Validate that the chat exists. const chat = await prisma.chat.findUnique({ where: { diff --git a/packages/web/src/app/api/(server)/chats/route.ts b/packages/web/src/app/api/(server)/ee/chats/route.ts similarity index 86% rename from packages/web/src/app/api/(server)/chats/route.ts rename to packages/web/src/app/api/(server)/ee/chats/route.ts index 16db1352c..26178b415 100644 --- a/packages/web/src/app/api/(server)/chats/route.ts +++ b/packages/web/src/app/api/(server)/ee/chats/route.ts @@ -1,8 +1,9 @@ import { apiHandler } from "@/lib/apiHandler"; -import { serviceErrorResponse, queryParamsSchemaValidationError } from "@/lib/serviceError"; +import { serviceErrorResponse, queryParamsSchemaValidationError, ServiceError } from "@/lib/serviceError"; import { listChatsQueryParamsSchema, ListChatsResponse } from "./types"; import { isServiceError } from "@/lib/utils"; import { withAuth } from "@/middleware/withAuth"; +import { checkAskEntitlement } from "@/features/chat/utils.server"; import { NextRequest } from "next/server"; export const GET = apiHandler(async (request: NextRequest) => { @@ -22,7 +23,12 @@ export const GET = apiHandler(async (request: NextRequest) => { const { cursor, limit, query, sortBy, sortOrder } = parsed.data; - const result = await withAuth(async ({ org, user, prisma }): Promise => { + const result = await withAuth(async ({ org, user, prisma }): Promise => { + const askError = await checkAskEntitlement(); + if (askError) { + return askError; + } + const chats = await prisma.chat.findMany({ where: { orgId: org.id, diff --git a/packages/web/src/app/api/(server)/chats/types.ts b/packages/web/src/app/api/(server)/ee/chats/types.ts similarity index 100% rename from packages/web/src/app/api/(server)/chats/types.ts rename to packages/web/src/app/api/(server)/ee/chats/types.ts diff --git a/packages/web/src/app/api/(server)/mcp/route.ts b/packages/web/src/app/api/(server)/ee/mcp/route.ts similarity index 88% rename from packages/web/src/app/api/(server)/mcp/route.ts rename to packages/web/src/app/api/(server)/ee/mcp/route.ts index 9218ccd0a..15c7f032a 100644 --- a/packages/web/src/app/api/(server)/mcp/route.ts +++ b/packages/web/src/app/api/(server)/ee/mcp/route.ts @@ -1,6 +1,7 @@ import { WebStandardStreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/webStandardStreamableHttp.js'; import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; -import { createMcpServer } from '@/features/mcp/server'; +import { createMcpServer } from '@/ee/features/mcp/server'; +import { MCP_PAID_PLAN_REQUIRED_MESSAGE } from '@/ee/features/mcp/constants'; import { withOptionalAuth } from '@/middleware/withAuth'; import { isServiceError } from '@/lib/utils'; import { notAuthenticated, serviceErrorResponse, ServiceError } from '@/lib/serviceError'; @@ -43,6 +44,16 @@ const MCP_SESSION_ID_HEADER = 'MCP-Session-Id'; const sessions = new Map(); export const POST = apiHandler(async (request: NextRequest) => { + // The MCP server is a paid feature. Gate every request before touching + // sessions or auth so free-plan instances get a clear upgrade signal. + if (!await hasEntitlement('mcp')) { + return await mcpErrorResponse({ + statusCode: StatusCodes.FORBIDDEN, + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + message: MCP_PAID_PLAN_REQUIRED_MESSAGE, + }); + } + const response = await sew(() => withOptionalAuth(async ({ user }) => { if (env.EXPERIMENT_ASK_GH_ENABLED === 'true' && !user) { @@ -95,6 +106,14 @@ export const POST = apiHandler(async (request: NextRequest) => { }); export const DELETE = apiHandler(async (request: NextRequest) => { + if (!await hasEntitlement('mcp')) { + return await mcpErrorResponse({ + statusCode: StatusCodes.FORBIDDEN, + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + message: MCP_PAID_PLAN_REQUIRED_MESSAGE, + }); + } + const result = await sew(() => withOptionalAuth(async ({ user }) => { if (env.EXPERIMENT_ASK_GH_ENABLED === 'true' && !user) { diff --git a/packages/web/src/app/api/(server)/offers/route.ts b/packages/web/src/app/api/(server)/offers/route.ts index b54dde331..aff748d5d 100644 --- a/packages/web/src/app/api/(server)/offers/route.ts +++ b/packages/web/src/app/api/(server)/offers/route.ts @@ -1,4 +1,4 @@ -import { client as lighthouseClient } from "@/ee/features/lighthouse/client"; +import { client as lighthouseClient } from "@/features/billing/client"; import { apiHandler } from "@/lib/apiHandler"; import { env } from "@sourcebot/shared"; diff --git a/packages/web/src/app/onboard/components/trialStep.tsx b/packages/web/src/app/onboard/components/trialStep.tsx index 636c7f40a..960ee8791 100644 --- a/packages/web/src/app/onboard/components/trialStep.tsx +++ b/packages/web/src/app/onboard/components/trialStep.tsx @@ -6,10 +6,10 @@ import { useSession } from "next-auth/react"; import { LoadingButton } from "@/components/ui/loading-button"; import { Skeleton } from "@/components/ui/skeleton"; import { completeOnboarding } from "@/actions"; -import { createCheckoutSession } from "@/ee/features/lighthouse/actions"; -import { useOffers } from "@/ee/features/lighthouse/useOffers"; -import { BillingInterval, PlanComparisonTable } from "@/ee/features/lighthouse/planComparisonTable"; -import { CheckoutDisclosures } from "@/ee/features/lighthouse/checkoutDisclosures"; +import { createCheckoutSession } from "@/features/billing/actions"; +import { useOffers } from "@/features/billing/useOffers"; +import { BillingInterval, PlanComparisonTable } from "@/features/billing/planComparisonTable"; +import { CheckoutDisclosures } from "@/features/billing/checkoutDisclosures"; import { useToast } from "@/components/hooks/use-toast"; import { isServiceError } from "@/lib/utils"; import useCaptureEvent from "@/hooks/useCaptureEvent"; diff --git a/packages/web/src/ee/features/analytics/analyticsEntitlementMessage.tsx b/packages/web/src/ee/features/analytics/analyticsEntitlementMessage.tsx deleted file mode 100644 index a7ac09b4f..000000000 --- a/packages/web/src/ee/features/analytics/analyticsEntitlementMessage.tsx +++ /dev/null @@ -1,56 +0,0 @@ -"use client" - -import { useState } from "react" -import { Card, CardContent, CardDescription, CardHeader, CardTitle } from "@/components/ui/card" -import { Button } from "@/components/ui/button" -import { Skeleton } from "@/components/ui/skeleton" -import { ArrowUpCircle, BarChart3 } from "lucide-react" -import { UpsellDialog } from "@/ee/features/lighthouse/upsellDialog" -import { useOffers } from "@/ee/features/lighthouse/useOffers" - -export function AnalyticsEntitlementMessage() { - const [isUpsellDialogOpen, setIsUpsellDialogOpen] = useState(false); - const { data: offers, isPending } = useOffers(); - - const buttonLabel = offers?.trial.eligible - ? `Start ${offers.trial.durationDays} day trial` - : "Upgrade to Pro"; - - return ( -
- - -
-
- -
-
- - Analytics is a Pro Feature - - - Get insights into your organization's usage patterns and activity. Learn more - -
- -
- {isPending ? ( - - ) : ( - - )} -
-
-
- -
- ) -} \ No newline at end of file diff --git a/packages/web/src/ee/features/chat/actions.ts b/packages/web/src/ee/features/chat/actions.ts new file mode 100644 index 000000000..fc2b7243f --- /dev/null +++ b/packages/web/src/ee/features/chat/actions.ts @@ -0,0 +1,62 @@ +'use server'; + +import { sew } from "@/middleware/sew"; +import { ErrorCode } from "@/lib/errorCodes"; +import { notFound, ServiceError } from "@/lib/serviceError"; +import { withOptionalAuth } from "@/middleware/withAuth"; +import { StatusCodes } from "http-status-codes"; +import { checkAskEntitlement, getConfiguredLanguageModels, isOwnerOfChat } from "@/features/chat/utils.server"; +import { generateChatNameFromMessage } from "./llm.server"; + +export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageModelId, message }: { chatId: string, languageModelId: string, message: string }) => sew(() => + withOptionalAuth(async ({ prisma, user, org }) => { + const askError = await checkAskEntitlement(); + if (askError) { + return askError; + } + + const chat = await prisma.chat.findUnique({ + where: { + id: chatId, + orgId: org.id, + }, + }); + + if (!chat) { + return notFound(); + } + + const isOwner = await isOwnerOfChat(chat, user); + if (!isOwner) { + return notFound(); + } + + const languageModelConfig = + (await getConfiguredLanguageModels()) + .find((model) => model.model === languageModelId); + + if (!languageModelConfig) { + return { + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.INVALID_REQUEST_BODY, + message: `Language model ${languageModelId} is not configured.`, + } satisfies ServiceError; + } + + const name = await generateChatNameFromMessage({ message, languageModelConfig }); + + await prisma.chat.update({ + where: { + id: chatId, + orgId: org.id, + }, + data: { + name: name, + }, + }) + + return { + success: true, + } + }) +) diff --git a/packages/web/src/features/chat/agent.test.ts b/packages/web/src/ee/features/chat/agent.test.ts similarity index 98% rename from packages/web/src/features/chat/agent.test.ts rename to packages/web/src/ee/features/chat/agent.test.ts index e7984e655..a47e10434 100644 --- a/packages/web/src/features/chat/agent.test.ts +++ b/packages/web/src/ee/features/chat/agent.test.ts @@ -1,6 +1,6 @@ import { beforeEach, describe, expect, test, vi } from 'vitest'; import type { ModelMessage } from 'ai'; -import type { SBChatMessage, SBChatMessagePart } from './types'; +import type { SBChatMessage, SBChatMessagePart } from '@/features/chat/types'; const mockLogger = vi.hoisted(() => ({ debug: vi.fn(), @@ -76,7 +76,7 @@ vi.mock('@/features/tools', () => { }); vi.mock('@/lib/entitlements', () => ({ - hasEntitlement: vi.fn(() => false), + hasEntitlement: vi.fn(() => true), })); vi.mock('@/lib/posthog', () => ({ diff --git a/packages/web/src/features/chat/agent.ts b/packages/web/src/ee/features/chat/agent.ts similarity index 96% rename from packages/web/src/features/chat/agent.ts rename to packages/web/src/ee/features/chat/agent.ts index f4bd96854..b39a930e4 100644 --- a/packages/web/src/features/chat/agent.ts +++ b/packages/web/src/ee/features/chat/agent.ts @@ -18,9 +18,9 @@ import { import { z } from "zod"; import { randomUUID } from "crypto"; import _dedent from "dedent"; -import { ANSWER_TAG, FILE_REFERENCE_PREFIX } from "./constants"; -import { Source } from "./types"; -import { addLineNumbers, fileReferenceToString, getAnswerPartFromAssistantMessage, getTurnProgressState } from "./utils"; +import { ANSWER_TAG, FILE_REFERENCE_PREFIX } from "@/features/chat/constants"; +import { Source } from "@/features/chat/types"; +import { addLineNumbers, fileReferenceToString, getAnswerPartFromAssistantMessage, getTurnProgressState } from "@/features/chat/utils"; import { createTools } from "./tools"; import { getConnectedMcpClients } from "@/ee/features/chat/mcp/mcpClientFactory"; import { getMcpTools, McpToolsResult } from "@/ee/features/chat/mcp/mcpToolSets"; @@ -76,6 +76,14 @@ export const createMessageStream = async ({ userId, orgId, }: CreateMessageStreamResponseProps) => { + // Defense-in-depth: Ask Sourcebot is a paid feature. Every caller is + // expected to gate on the `ask` entitlement before reaching here (see + // checkAskEntitlement); this assertion backstops that contract so a future + // ungated caller cannot execute the agent on a non-entitled deployment. + if (!(await hasEntitlement('ask'))) { + throw new Error('Ask Sourcebot is not available in the current plan.'); + } + const latestMessage = messages[messages.length - 1]; const sources = latestMessage.parts .filter((part) => part.type === 'data-source') @@ -269,7 +277,7 @@ const createAgentStream = async ({ ).filter((source) => source !== undefined); let mcpToolSetsObj: McpToolsResult = { tools: {}, failedServers: [], serverFaviconUrls: {}, cleanup: async () => {} }; - if (userId && orgId && await hasEntitlement('oauth') && disabledMcpServerIds !== undefined) { + if (userId && orgId && await hasEntitlement('ask') && disabledMcpServerIds !== undefined) { try { const allMcpClients = await getConnectedMcpClients(prisma, userId, orgId); const mcpClients = allMcpClients.filter((c) => !disabledMcpServerIds.includes(c.serverId)); diff --git a/packages/web/src/ee/features/chat/askMcpAnalytics.server.ts b/packages/web/src/ee/features/chat/askMcpAnalytics.server.ts new file mode 100644 index 000000000..2278a0583 --- /dev/null +++ b/packages/web/src/ee/features/chat/askMcpAnalytics.server.ts @@ -0,0 +1,136 @@ +import { getStoredMcpConnectionStatus } from "@/ee/features/chat/mcp/connectionStatus"; +import { hasEntitlement } from "@/lib/entitlements"; +import type { PrismaClient } from "@sourcebot/db"; +import type { DynamicToolUIPart } from "ai"; +import type { SBChatMessage, SBChatMessagePart } from "@/features/chat/types"; +import { getTurnProgressState } from "@/features/chat/utils"; + +export type AskMcpAvailabilityAnalytics = { + hasAskMcpServersAvailable: boolean; + askMcpConnectedServerCount: number; + askMcpEnabledServerCount: number; + askMcpDisabledServerCount: number; +}; + +export type AskMcpTurnCompletedAnalytics = { + traceId?: string; + askMcpUsed: boolean; + askMcpToolCallCount: number; + askMcpToolSuccessCount: number; + askMcpToolFailureCount: number; + askMcpApprovalRequestedCount: number; + askMcpApprovalDeniedCount: number; + askMcpFailedServerCount: number; + durationMs: number; +}; + +const emptyAskMcpAvailability: AskMcpAvailabilityAnalytics = { + hasAskMcpServersAvailable: false, + askMcpConnectedServerCount: 0, + askMcpEnabledServerCount: 0, + askMcpDisabledServerCount: 0, +}; + +type AskMcpAvailabilityPrismaClient = Pick; + +export async function getAskMcpAvailabilityAnalytics({ + prisma, + userId, + orgId, + disabledMcpServerIds, +}: { + prisma: AskMcpAvailabilityPrismaClient; + userId: string | undefined; + orgId: number; + disabledMcpServerIds: string[]; +}): Promise { + if (!userId || !(await hasEntitlement("ask"))) { + return emptyAskMcpAvailability; + } + + const userServers = await prisma.userMcpServer.findMany({ + where: { + userId, + tokens: { not: null }, + server: { + orgId, + clientInfo: { not: null }, + }, + }, + select: { + serverId: true, + tokens: true, + tokensExpiresAt: true, + }, + }); + + const connectedServerIds = userServers + .filter((userServer) => + getStoredMcpConnectionStatus(userServer.tokens, userServer.tokensExpiresAt).state === "connected" + ) + .map((userServer) => userServer.serverId); + const disabledServerIds = new Set(disabledMcpServerIds); + const askMcpDisabledServerCount = connectedServerIds.filter((serverId) => disabledServerIds.has(serverId)).length; + const askMcpEnabledServerCount = connectedServerIds.length - askMcpDisabledServerCount; + + return { + hasAskMcpServersAvailable: askMcpEnabledServerCount > 0, + askMcpConnectedServerCount: connectedServerIds.length, + askMcpEnabledServerCount, + askMcpDisabledServerCount, + }; +} + +function isExternalMcpToolPart(part: SBChatMessagePart): part is SBChatMessagePart & DynamicToolUIPart { + return part.type === "dynamic-tool" && part.toolName.startsWith("mcp_"); +} + +function hasApproval(part: DynamicToolUIPart) { + return part.approval !== undefined; +} + +export function getAskMcpTurnCompletedAnalytics({ + messages, + availability, +}: { + messages: SBChatMessage[]; + availability: AskMcpAvailabilityAnalytics; +}): AskMcpTurnCompletedAnalytics | undefined { + const latestMessage = messages.at(-1); + const latestAssistantMessage = latestMessage?.role === "assistant" ? latestMessage : undefined; + if (!latestAssistantMessage) { + return undefined; + } + + const progressState = getTurnProgressState({ messages, status: "ready" }); + if (progressState.isTurnInProgress) { + return undefined; + } + + const externalMcpToolParts = latestAssistantMessage.parts.filter(isExternalMcpToolPart); + const askMcpToolSuccessCount = externalMcpToolParts.filter((part) => part.state === "output-available").length; + const askMcpToolFailureCount = externalMcpToolParts.filter((part) => part.state === "output-error").length; + const askMcpToolCallCount = askMcpToolSuccessCount + askMcpToolFailureCount; + const askMcpApprovalRequestedCount = externalMcpToolParts.filter(hasApproval).length; + const askMcpApprovalDeniedCount = externalMcpToolParts.filter((part) => part.state === "output-denied").length; + const askMcpFailedServerCount = latestAssistantMessage.parts.filter((part) => + part.type === "data-mcp-failed-server" + ).length; + + const hasMcpTurnActivity = externalMcpToolParts.length > 0 || askMcpFailedServerCount > 0; + if (!availability.hasAskMcpServersAvailable && !hasMcpTurnActivity) { + return undefined; + } + + return { + traceId: latestAssistantMessage.metadata?.traceId, + askMcpUsed: askMcpToolCallCount > 0, + askMcpToolCallCount, + askMcpToolSuccessCount, + askMcpToolFailureCount, + askMcpApprovalRequestedCount, + askMcpApprovalDeniedCount, + askMcpFailedServerCount, + durationMs: latestAssistantMessage.metadata?.totalResponseTimeMs ?? 0, + }; +} diff --git a/packages/web/src/features/chat/components/chatThread/answerCard.tsx b/packages/web/src/ee/features/chat/components/chatThread/answerCard.tsx similarity index 97% rename from packages/web/src/features/chat/components/chatThread/answerCard.tsx rename to packages/web/src/ee/features/chat/components/chatThread/answerCard.tsx index d0922053d..2aeb2ac95 100644 --- a/packages/web/src/features/chat/components/chatThread/answerCard.tsx +++ b/packages/web/src/ee/features/chat/components/chatThread/answerCard.tsx @@ -11,14 +11,14 @@ import { Toggle } from "@/components/ui/toggle"; import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; import { CopyIconButton } from "@/app/(app)/components/copyIconButton"; import { useToast } from "@/components/hooks/use-toast"; -import { convertLLMOutputToPortableMarkdown } from "../../utils"; -import { submitFeedback } from "../../actions"; +import { convertLLMOutputToPortableMarkdown } from "@/features/chat/utils"; +import { submitFeedback } from "@/features/chat/actions"; import { isServiceError } from "@/lib/utils"; import useCaptureEvent from "@/hooks/useCaptureEvent"; import { LangfuseWeb } from "langfuse"; import { env } from "@sourcebot/shared/client"; import isEqual from "fast-deep-equal/react"; -import { FileSource } from "../../types"; +import { FileSource } from "@/features/chat/types"; interface AnswerCardProps { answerText: string; diff --git a/packages/web/src/features/chat/components/chatThread/chatThread.tsx b/packages/web/src/ee/features/chat/components/chatThread/chatThread.tsx similarity index 97% rename from packages/web/src/features/chat/components/chatThread/chatThread.tsx rename to packages/web/src/ee/features/chat/components/chatThread/chatThread.tsx index f1fbb26da..0bcddfd1d 100644 --- a/packages/web/src/features/chat/components/chatThread/chatThread.tsx +++ b/packages/web/src/ee/features/chat/components/chatThread/chatThread.tsx @@ -14,18 +14,19 @@ import { Fragment, useCallback, useEffect, useMemo, useRef, useState } from 'rea import { useStickToBottom } from 'use-stick-to-bottom'; import { Descendant } from 'slate'; import { useMessagePairs } from '../../useMessagePairs'; -import { useSelectedLanguageModel } from '../../useSelectedLanguageModel'; -import { ChatBox } from '../chatBox'; -import { ChatBoxToolbar } from '../chatBox/chatBoxToolbar'; +import { useSelectedLanguageModel } from '@/features/chat/useSelectedLanguageModel'; +import { ChatBox } from '@/features/chat/components/chatBox'; +import { ChatBoxToolbar } from '@/features/chat/components/chatBox/chatBoxToolbar'; import { ChatThreadListItem } from './chatThreadListItem'; import { ErrorBanner } from './errorBanner'; import { McpFailedServersBanner } from './mcpFailedServersBanner'; import { useRouter } from 'next/navigation'; import { usePrevious } from '@uidotdev/usehooks'; import { RepositoryQuery, SearchContextQuery } from '@/lib/types'; -import { duplicateChat, generateAndUpdateChatNameFromMessage } from '../../actions'; +import { duplicateChat } from '@/features/chat/actions'; +import { generateAndUpdateChatNameFromMessage } from '@/ee/features/chat/actions'; import { isServiceError } from '@/lib/utils'; -import { NotConfiguredErrorBanner } from '../notConfiguredErrorBanner'; +import { NotConfiguredErrorBanner } from '@/features/chat/components/notConfiguredErrorBanner'; import { McpServerIconContext, McpServerIconMap } from '../../mcpServerIconContext'; import { ToolApprovalProvider } from '../../toolApprovalContext'; import useCaptureEvent from '@/hooks/useCaptureEvent'; @@ -139,7 +140,7 @@ export const ChatThread = ({ // triggered by sendAutomaticallyWhen after tool approval. // eslint-disable-next-line react-hooks/refs -- DefaultChatTransport stores the body callback and invokes it during requests, not during render. const transport = useMemo(() => new DefaultChatTransport({ - api: '/api/chat', + api: '/api/ee/chat', headers: { 'X-Sourcebot-Client-Source': 'sourcebot-web-client', }, diff --git a/packages/web/src/features/chat/components/chatThread/chatThreadListItem.tsx b/packages/web/src/ee/features/chat/components/chatThread/chatThreadListItem.tsx similarity index 99% rename from packages/web/src/features/chat/components/chatThread/chatThreadListItem.tsx rename to packages/web/src/ee/features/chat/components/chatThread/chatThreadListItem.tsx index d05508081..d9ccb3d7c 100644 --- a/packages/web/src/features/chat/components/chatThread/chatThreadListItem.tsx +++ b/packages/web/src/ee/features/chat/components/chatThread/chatThreadListItem.tsx @@ -6,16 +6,16 @@ import { Skeleton } from '@/components/ui/skeleton'; import { CheckCircle, Loader2 } from 'lucide-react'; import { CSSProperties, forwardRef, memo, useCallback, useEffect, useMemo, useRef, useState } from 'react'; import scrollIntoView from 'scroll-into-view-if-needed'; -import { Reference, referenceSchema, SBChatMessage, Source } from "../../types"; +import { Reference, referenceSchema, SBChatMessage, Source } from "@/features/chat/types"; import { useExtractReferences } from '../../useExtractReferences'; -import { getAnswerPartFromAssistantMessage, getLastStepParts, groupMessageIntoSteps, isSBChatToolPart, repairReferences, tryResolveFileReference } from '../../utils'; +import { getAnswerPartFromAssistantMessage, getLastStepParts, groupMessageIntoSteps, isSBChatToolPart, repairReferences, tryResolveFileReference } from '@/features/chat/utils'; import { AnswerCard } from './answerCard'; import { DetailsCard } from './detailsCard'; import { ApprovalRequestedToolPart, ToolApprovalBanner } from './toolApprovalBanner'; import { MarkdownRenderer, REFERENCE_PAYLOAD_ATTRIBUTE } from './markdownRenderer'; import { ReferencedSourcesListView } from './referencedSourcesListView'; import isEqual from "fast-deep-equal/react"; -import { ANSWER_TAG } from '../../constants'; +import { ANSWER_TAG } from '@/features/chat/constants'; interface ChatThreadListItemProps { userMessage: SBChatMessage; @@ -426,7 +426,7 @@ const ChatThreadListItemComponent = forwardRef - ) : (isTurnInProgress) ? ( + ) : isNetworkActive ? (
{Array.from({ length: 3 }).map((_, index) => ( diff --git a/packages/web/src/features/chat/components/chatThread/codeBlock.tsx b/packages/web/src/ee/features/chat/components/chatThread/codeBlock.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/codeBlock.tsx rename to packages/web/src/ee/features/chat/components/chatThread/codeBlock.tsx diff --git a/packages/web/src/features/chat/components/chatThread/codeFoldingExpandButton.tsx b/packages/web/src/ee/features/chat/components/chatThread/codeFoldingExpandButton.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/codeFoldingExpandButton.tsx rename to packages/web/src/ee/features/chat/components/chatThread/codeFoldingExpandButton.tsx diff --git a/packages/web/src/features/chat/components/chatThread/codeFoldingExtension.test.ts b/packages/web/src/ee/features/chat/components/chatThread/codeFoldingExtension.test.ts similarity index 99% rename from packages/web/src/features/chat/components/chatThread/codeFoldingExtension.test.ts rename to packages/web/src/ee/features/chat/components/chatThread/codeFoldingExtension.test.ts index 50b73512c..65c8203af 100644 --- a/packages/web/src/features/chat/components/chatThread/codeFoldingExtension.test.ts +++ b/packages/web/src/ee/features/chat/components/chatThread/codeFoldingExtension.test.ts @@ -10,7 +10,7 @@ import { expandRegion, FoldingState, } from './codeFoldingExtension' -import { FileReference } from '../../types' +import { FileReference } from '@/features/chat/types' import { EditorState, StateField } from '@codemirror/state' describe('calculateVisibleRanges', () => { diff --git a/packages/web/src/features/chat/components/chatThread/codeFoldingExtension.ts b/packages/web/src/ee/features/chat/components/chatThread/codeFoldingExtension.ts similarity index 99% rename from packages/web/src/features/chat/components/chatThread/codeFoldingExtension.ts rename to packages/web/src/ee/features/chat/components/chatThread/codeFoldingExtension.ts index 0229f4eae..9594375dc 100644 --- a/packages/web/src/features/chat/components/chatThread/codeFoldingExtension.ts +++ b/packages/web/src/ee/features/chat/components/chatThread/codeFoldingExtension.ts @@ -5,7 +5,7 @@ import { EditorView, WidgetType } from "@codemirror/view"; -import { FileReference } from "../../types"; +import { FileReference } from "@/features/chat/types"; import React, { CSSProperties } from "react"; import { createRoot } from "react-dom/client"; import { CodeFoldingExpandButton } from "./codeFoldingExpandButton"; diff --git a/packages/web/src/features/chat/components/chatThread/detailsCard.test.tsx b/packages/web/src/ee/features/chat/components/chatThread/detailsCard.test.tsx similarity index 98% rename from packages/web/src/features/chat/components/chatThread/detailsCard.test.tsx rename to packages/web/src/ee/features/chat/components/chatThread/detailsCard.test.tsx index 6f9c924cc..01ebbaf13 100644 --- a/packages/web/src/features/chat/components/chatThread/detailsCard.test.tsx +++ b/packages/web/src/ee/features/chat/components/chatThread/detailsCard.test.tsx @@ -2,7 +2,7 @@ import { cleanup, render, screen } from '@testing-library/react'; import { afterEach, describe, expect, test, vi } from 'vitest'; import { TooltipProvider } from '@/components/ui/tooltip'; import { DetailsCard } from './detailsCard'; -import type { SBChatMessagePart } from '../../types'; +import type { SBChatMessagePart } from '@/features/chat/types'; vi.mock('@/hooks/useCaptureEvent', () => ({ default: () => vi.fn(), diff --git a/packages/web/src/features/chat/components/chatThread/detailsCard.tsx b/packages/web/src/ee/features/chat/components/chatThread/detailsCard.tsx similarity index 99% rename from packages/web/src/features/chat/components/chatThread/detailsCard.tsx rename to packages/web/src/ee/features/chat/components/chatThread/detailsCard.tsx index cd6d8228d..cb0df256b 100644 --- a/packages/web/src/features/chat/components/chatThread/detailsCard.tsx +++ b/packages/web/src/ee/features/chat/components/chatThread/detailsCard.tsx @@ -12,8 +12,8 @@ import { useStickToBottom } from 'use-stick-to-bottom'; import { Brain, ChevronDown, ChevronRight, Clock, InfoIcon, Loader2, ScanSearchIcon, ShieldQuestion, Wrench, Zap } from 'lucide-react'; import { memo, useCallback, useEffect, useMemo, useState } from 'react'; import { usePrevious } from '@uidotdev/usehooks'; -import { SBChatMessageMetadata, SBChatMessagePart } from '../../types'; -import { SearchScopeIcon } from '../searchScopeIcon'; +import { SBChatMessageMetadata, SBChatMessagePart } from '@/features/chat/types'; +import { SearchScopeIcon } from '@/features/chat/components/searchScopeIcon'; import { MarkdownRenderer } from './markdownRenderer'; import { FindSymbolDefinitionsToolComponent } from './tools/findSymbolDefinitionsToolComponent'; import { FindSymbolReferencesToolComponent } from './tools/findSymbolReferencesToolComponent'; diff --git a/packages/web/src/features/chat/components/chatThread/errorBanner.tsx b/packages/web/src/ee/features/chat/components/chatThread/errorBanner.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/errorBanner.tsx rename to packages/web/src/ee/features/chat/components/chatThread/errorBanner.tsx diff --git a/packages/web/src/features/chat/components/chatThread/index.ts b/packages/web/src/ee/features/chat/components/chatThread/index.ts similarity index 100% rename from packages/web/src/features/chat/components/chatThread/index.ts rename to packages/web/src/ee/features/chat/components/chatThread/index.ts diff --git a/packages/web/src/features/chat/components/chatThread/linearIssueCard.tsx b/packages/web/src/ee/features/chat/components/chatThread/linearIssueCard.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/linearIssueCard.tsx rename to packages/web/src/ee/features/chat/components/chatThread/linearIssueCard.tsx diff --git a/packages/web/src/features/chat/components/chatThread/markdownRenderer.tsx b/packages/web/src/ee/features/chat/components/chatThread/markdownRenderer.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/markdownRenderer.tsx rename to packages/web/src/ee/features/chat/components/chatThread/markdownRenderer.tsx diff --git a/packages/web/src/features/chat/components/chatThread/mcpFailedServersBanner.tsx b/packages/web/src/ee/features/chat/components/chatThread/mcpFailedServersBanner.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/mcpFailedServersBanner.tsx rename to packages/web/src/ee/features/chat/components/chatThread/mcpFailedServersBanner.tsx diff --git a/packages/web/src/features/chat/components/chatThread/referencedFileSourceListItem.tsx b/packages/web/src/ee/features/chat/components/chatThread/referencedFileSourceListItem.tsx similarity index 99% rename from packages/web/src/features/chat/components/chatThread/referencedFileSourceListItem.tsx rename to packages/web/src/ee/features/chat/components/chatThread/referencedFileSourceListItem.tsx index ab8957868..e2e4f68fc 100644 --- a/packages/web/src/features/chat/components/chatThread/referencedFileSourceListItem.tsx +++ b/packages/web/src/ee/features/chat/components/chatThread/referencedFileSourceListItem.tsx @@ -15,7 +15,7 @@ import CodeMirror, { ReactCodeMirrorRef } from '@uiw/react-codemirror'; import isEqual from "fast-deep-equal/react"; import { ChevronDown, ChevronRight } from "lucide-react"; import { forwardRef, memo, Ref, useEffect, useImperativeHandle, useMemo, useState } from "react"; -import { FileReference } from "../../types"; +import { FileReference } from "@/features/chat/types"; import { createCodeFoldingExtension } from "./codeFoldingExtension"; import { createReferencesHighlightExtension, setHoveredIdEffect, setSelectedIdEffect } from "./referencesHighlightExtension"; diff --git a/packages/web/src/features/chat/components/chatThread/referencedFileSourceListItemContainer.tsx b/packages/web/src/ee/features/chat/components/chatThread/referencedFileSourceListItemContainer.tsx similarity index 98% rename from packages/web/src/features/chat/components/chatThread/referencedFileSourceListItemContainer.tsx rename to packages/web/src/ee/features/chat/components/chatThread/referencedFileSourceListItemContainer.tsx index 80ba9e063..61221bb25 100644 --- a/packages/web/src/features/chat/components/chatThread/referencedFileSourceListItemContainer.tsx +++ b/packages/web/src/ee/features/chat/components/chatThread/referencedFileSourceListItemContainer.tsx @@ -7,7 +7,7 @@ import { isServiceError, unwrapServiceError } from "@/lib/utils"; import { useQuery } from "@tanstack/react-query"; import { ReactCodeMirrorRef } from '@uiw/react-codemirror'; import { memo, useCallback } from "react"; -import { FileReference, FileSource, Reference } from "../../types"; +import { FileReference, FileSource, Reference } from "@/features/chat/types"; import { ReferencedFileSourceListItem } from "./referencedFileSourceListItem"; import isEqual from 'fast-deep-equal/react'; diff --git a/packages/web/src/features/chat/components/chatThread/referencedSourcesListView.tsx b/packages/web/src/ee/features/chat/components/chatThread/referencedSourcesListView.tsx similarity index 98% rename from packages/web/src/features/chat/components/chatThread/referencedSourcesListView.tsx rename to packages/web/src/ee/features/chat/components/chatThread/referencedSourcesListView.tsx index 3197338ba..9e7438faf 100644 --- a/packages/web/src/features/chat/components/chatThread/referencedSourcesListView.tsx +++ b/packages/web/src/ee/features/chat/components/chatThread/referencedSourcesListView.tsx @@ -4,8 +4,8 @@ import { ScrollArea } from "@/components/ui/scroll-area"; import { ReactCodeMirrorRef } from "@uiw/react-codemirror"; import { memo, useCallback, useEffect, useMemo, useRef, useState } from "react"; import scrollIntoView from 'scroll-into-view-if-needed'; -import { FileReference, FileSource, Reference } from "../../types"; -import { tryResolveFileReference } from '../../utils'; +import { FileReference, FileSource, Reference } from "@/features/chat/types"; +import { tryResolveFileReference } from '@/features/chat/utils'; import { ReferencedFileSourceListItemContainer } from "./referencedFileSourceListItemContainer"; import isEqual from 'fast-deep-equal/react'; diff --git a/packages/web/src/features/chat/components/chatThread/referencesHighlightExtension.ts b/packages/web/src/ee/features/chat/components/chatThread/referencesHighlightExtension.ts similarity index 99% rename from packages/web/src/features/chat/components/chatThread/referencesHighlightExtension.ts rename to packages/web/src/ee/features/chat/components/chatThread/referencesHighlightExtension.ts index c2467bbe4..30800ed2f 100644 --- a/packages/web/src/features/chat/components/chatThread/referencesHighlightExtension.ts +++ b/packages/web/src/ee/features/chat/components/chatThread/referencesHighlightExtension.ts @@ -1,6 +1,6 @@ import { EditorState, Range, StateEffect, StateField } from "@codemirror/state"; import { Decoration, DecorationSet, EditorView } from "@codemirror/view"; -import { FileReference } from "../../types"; +import { FileReference } from "@/features/chat/types"; const lineDecoration = Decoration.line({ attributes: { class: "cm-range-border-radius chat-lineHighlight" }, diff --git a/packages/web/src/features/chat/components/chatThread/signInPromptBanner.tsx b/packages/web/src/ee/features/chat/components/chatThread/signInPromptBanner.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/signInPromptBanner.tsx rename to packages/web/src/ee/features/chat/components/chatThread/signInPromptBanner.tsx diff --git a/packages/web/src/features/chat/components/chatThread/tableOfContents.tsx b/packages/web/src/ee/features/chat/components/chatThread/tableOfContents.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/tableOfContents.tsx rename to packages/web/src/ee/features/chat/components/chatThread/tableOfContents.tsx diff --git a/packages/web/src/features/chat/components/chatThread/toolApprovalBanner.tsx b/packages/web/src/ee/features/chat/components/chatThread/toolApprovalBanner.tsx similarity index 96% rename from packages/web/src/features/chat/components/chatThread/toolApprovalBanner.tsx rename to packages/web/src/ee/features/chat/components/chatThread/toolApprovalBanner.tsx index 636c951f9..ed0ccdecc 100644 --- a/packages/web/src/features/chat/components/chatThread/toolApprovalBanner.tsx +++ b/packages/web/src/ee/features/chat/components/chatThread/toolApprovalBanner.tsx @@ -2,8 +2,8 @@ import { Button } from "@/components/ui/button"; import { McpFavicon } from "@/ee/features/chat/mcp/components/mcpFavicon"; -import { useMcpServerIconMap } from "@/features/chat/mcpServerIconContext"; -import { useToolApproval } from "@/features/chat/toolApprovalContext"; +import { useMcpServerIconMap } from "@/ee/features/chat/mcpServerIconContext"; +import { useToolApproval } from "@/ee/features/chat/toolApprovalContext"; import { SBChatToolPart } from "@/features/chat/utils"; import { cn } from "@/lib/utils"; import { getToolName } from "ai"; diff --git a/packages/web/src/features/chat/components/chatThread/tools/fileRow.tsx b/packages/web/src/ee/features/chat/components/chatThread/tools/fileRow.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/tools/fileRow.tsx rename to packages/web/src/ee/features/chat/components/chatThread/tools/fileRow.tsx diff --git a/packages/web/src/features/chat/components/chatThread/tools/findSymbolDefinitionsToolComponent.tsx b/packages/web/src/ee/features/chat/components/chatThread/tools/findSymbolDefinitionsToolComponent.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/tools/findSymbolDefinitionsToolComponent.tsx rename to packages/web/src/ee/features/chat/components/chatThread/tools/findSymbolDefinitionsToolComponent.tsx diff --git a/packages/web/src/features/chat/components/chatThread/tools/findSymbolReferencesToolComponent.tsx b/packages/web/src/ee/features/chat/components/chatThread/tools/findSymbolReferencesToolComponent.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/tools/findSymbolReferencesToolComponent.tsx rename to packages/web/src/ee/features/chat/components/chatThread/tools/findSymbolReferencesToolComponent.tsx diff --git a/packages/web/src/features/chat/components/chatThread/tools/getDiffToolComponent.tsx b/packages/web/src/ee/features/chat/components/chatThread/tools/getDiffToolComponent.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/tools/getDiffToolComponent.tsx rename to packages/web/src/ee/features/chat/components/chatThread/tools/getDiffToolComponent.tsx diff --git a/packages/web/src/features/chat/components/chatThread/tools/globToolComponent.tsx b/packages/web/src/ee/features/chat/components/chatThread/tools/globToolComponent.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/tools/globToolComponent.tsx rename to packages/web/src/ee/features/chat/components/chatThread/tools/globToolComponent.tsx diff --git a/packages/web/src/features/chat/components/chatThread/tools/grepToolComponent.tsx b/packages/web/src/ee/features/chat/components/chatThread/tools/grepToolComponent.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/tools/grepToolComponent.tsx rename to packages/web/src/ee/features/chat/components/chatThread/tools/grepToolComponent.tsx diff --git a/packages/web/src/features/chat/components/chatThread/tools/jsonHighlighter.tsx b/packages/web/src/ee/features/chat/components/chatThread/tools/jsonHighlighter.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/tools/jsonHighlighter.tsx rename to packages/web/src/ee/features/chat/components/chatThread/tools/jsonHighlighter.tsx diff --git a/packages/web/src/features/chat/components/chatThread/tools/listCommitsToolComponent.tsx b/packages/web/src/ee/features/chat/components/chatThread/tools/listCommitsToolComponent.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/tools/listCommitsToolComponent.tsx rename to packages/web/src/ee/features/chat/components/chatThread/tools/listCommitsToolComponent.tsx diff --git a/packages/web/src/features/chat/components/chatThread/tools/listReposToolComponent.tsx b/packages/web/src/ee/features/chat/components/chatThread/tools/listReposToolComponent.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/tools/listReposToolComponent.tsx rename to packages/web/src/ee/features/chat/components/chatThread/tools/listReposToolComponent.tsx diff --git a/packages/web/src/features/chat/components/chatThread/tools/listTreeToolComponent.tsx b/packages/web/src/ee/features/chat/components/chatThread/tools/listTreeToolComponent.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/tools/listTreeToolComponent.tsx rename to packages/web/src/ee/features/chat/components/chatThread/tools/listTreeToolComponent.tsx diff --git a/packages/web/src/features/chat/components/chatThread/tools/mcpToolComponent.tsx b/packages/web/src/ee/features/chat/components/chatThread/tools/mcpToolComponent.tsx similarity index 98% rename from packages/web/src/features/chat/components/chatThread/tools/mcpToolComponent.tsx rename to packages/web/src/ee/features/chat/components/chatThread/tools/mcpToolComponent.tsx index aeca09156..c162d2841 100644 --- a/packages/web/src/features/chat/components/chatThread/tools/mcpToolComponent.tsx +++ b/packages/web/src/ee/features/chat/components/chatThread/tools/mcpToolComponent.tsx @@ -2,7 +2,7 @@ import { CopyIconButton } from "@/app/(app)/components/copyIconButton"; import { McpFavicon } from "@/ee/features/chat/mcp/components/mcpFavicon"; -import { useMcpServerIconMap } from "@/features/chat/mcpServerIconContext"; +import { useMcpServerIconMap } from "@/ee/features/chat/mcpServerIconContext"; import { cn } from "@/lib/utils"; import { DynamicToolUIPart } from "ai"; import { CheckCircle, ChevronDown, XCircle } from "lucide-react"; diff --git a/packages/web/src/features/chat/components/chatThread/tools/readFileToolComponent.tsx b/packages/web/src/ee/features/chat/components/chatThread/tools/readFileToolComponent.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/tools/readFileToolComponent.tsx rename to packages/web/src/ee/features/chat/components/chatThread/tools/readFileToolComponent.tsx diff --git a/packages/web/src/features/chat/components/chatThread/tools/repoBadge.tsx b/packages/web/src/ee/features/chat/components/chatThread/tools/repoBadge.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/tools/repoBadge.tsx rename to packages/web/src/ee/features/chat/components/chatThread/tools/repoBadge.tsx diff --git a/packages/web/src/features/chat/components/chatThread/tools/repoHeader.tsx b/packages/web/src/ee/features/chat/components/chatThread/tools/repoHeader.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/tools/repoHeader.tsx rename to packages/web/src/ee/features/chat/components/chatThread/tools/repoHeader.tsx diff --git a/packages/web/src/features/chat/components/chatThread/tools/toolOutputGuard.tsx b/packages/web/src/ee/features/chat/components/chatThread/tools/toolOutputGuard.tsx similarity index 100% rename from packages/web/src/features/chat/components/chatThread/tools/toolOutputGuard.tsx rename to packages/web/src/ee/features/chat/components/chatThread/tools/toolOutputGuard.tsx diff --git a/packages/web/src/features/chat/components/chatThread/tools/toolSearchToolComponent.tsx b/packages/web/src/ee/features/chat/components/chatThread/tools/toolSearchToolComponent.tsx similarity index 96% rename from packages/web/src/features/chat/components/chatThread/tools/toolSearchToolComponent.tsx rename to packages/web/src/ee/features/chat/components/chatThread/tools/toolSearchToolComponent.tsx index 545ed9b7f..58bcf4e90 100644 --- a/packages/web/src/features/chat/components/chatThread/tools/toolSearchToolComponent.tsx +++ b/packages/web/src/ee/features/chat/components/chatThread/tools/toolSearchToolComponent.tsx @@ -27,6 +27,7 @@ export const ToolSearchToolComponent = ({ query, results }: ToolSearchToolCompon Searched connector tools: {query} {results.length} result{results.length === 1 ? '' : 's'} +
diff --git a/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.test.tsx b/packages/web/src/ee/features/chat/components/chatThreadPanel.test.tsx similarity index 97% rename from packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.test.tsx rename to packages/web/src/ee/features/chat/components/chatThreadPanel.test.tsx index cc3391dc0..46b8d4d91 100644 --- a/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.test.tsx +++ b/packages/web/src/ee/features/chat/components/chatThreadPanel.test.tsx @@ -11,7 +11,7 @@ vi.mock('next/navigation', () => ({ useParams: () => ({ id: 'chat-1' }), })); -vi.mock('@/features/chat/components/chatThread', () => ({ +vi.mock('@/ee/features/chat/components/chatThread', () => ({ ChatThread: (props: { disabledMcpServerIds?: unknown }) => { chatThreadProps.push(props); return
; diff --git a/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.tsx b/packages/web/src/ee/features/chat/components/chatThreadPanel.tsx similarity index 85% rename from packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.tsx rename to packages/web/src/ee/features/chat/components/chatThreadPanel.tsx index 33808b486..0323c714f 100644 --- a/packages/web/src/app/(app)/chat/[id]/components/chatThreadPanel.tsx +++ b/packages/web/src/ee/features/chat/components/chatThreadPanel.tsx @@ -1,12 +1,12 @@ 'use client'; -import { ChatThread } from '@/features/chat/components/chatThread'; +import { ChatThread } from '@/ee/features/chat/components/chatThread'; import { LanguageModelInfo, SBChatMessage, SearchScope, SetChatStatePayload } from '@/features/chat/types'; import { SELECTED_SEARCH_SCOPES_LOCAL_STORAGE_KEY, SET_CHAT_STATE_SESSION_STORAGE_KEY } from '@/features/chat/constants'; import { RepositoryQuery, SearchContextQuery } from '@/lib/types'; import { CreateUIMessage } from 'ai'; import { useEffect, useState } from 'react'; -import { useChatId } from '../../useChatId'; +import { useChatId } from '@/app/(app)/chat/useChatId'; import { useSessionStorage } from 'usehooks-ts'; interface ChatThreadPanelProps { @@ -20,6 +20,14 @@ interface ChatThreadPanelProps { chatName?: string; } +const normalizeDisabledMcpServerIds = (value: unknown): string[] => { + if (!Array.isArray(value)) { + return []; + } + + return value.filter((id): id is string => typeof id === 'string'); +} + export const ChatThreadPanel = ({ languageModels, repos, @@ -45,7 +53,7 @@ export const ChatThreadPanel = ({ // Use the last user message to determine what repos, contexts, and MCP state we should select by default. const lastUserMessage = messages.findLast((message) => message.role === "user"); const defaultSelectedSearchScopes = lastUserMessage?.metadata?.selectedSearchScopes ?? []; - const defaultDisabledMcpServerIds = lastUserMessage?.metadata?.disabledMcpServerIds ?? []; + const defaultDisabledMcpServerIds = normalizeDisabledMcpServerIds(lastUserMessage?.metadata?.disabledMcpServerIds); const [selectedSearchScopes, setSelectedSearchScopes] = useState(defaultSelectedSearchScopes); const [disabledMcpServerIds, setDisabledMcpServerIds] = useState(defaultDisabledMcpServerIds); @@ -57,7 +65,7 @@ export const ChatThreadPanel = ({ try { setInputMessage(chatState.inputMessage); setSelectedSearchScopes(chatState.selectedSearchScopes); - setDisabledMcpServerIds(chatState.disabledMcpServerIds); + setDisabledMcpServerIds(normalizeDisabledMcpServerIds(chatState.disabledMcpServerIds)); } catch { console.error('Invalid chat state in session storage'); } finally { diff --git a/packages/web/src/ee/features/chat/llm.server.ts b/packages/web/src/ee/features/chat/llm.server.ts new file mode 100644 index 000000000..2880fded9 --- /dev/null +++ b/packages/web/src/ee/features/chat/llm.server.ts @@ -0,0 +1,32 @@ +import 'server-only'; + +import { LanguageModel } from '@sourcebot/schemas/v3/languageModel.type'; +import { generateText } from "ai"; +import { getAISDKLanguageModelAndOptions } from "@/features/chat/llm.server"; + +export const generateChatNameFromMessage = async ({ message, languageModelConfig }: { message: string, languageModelConfig: LanguageModel }) => { + const { model } = await getAISDKLanguageModelAndOptions(languageModelConfig); + + const prompt = `Convert this question into a short topic title (max 50 characters). + +Rules: +- Do NOT include question words (what, where, how, why, when, which) +- Do NOT end with a question mark +- Capitalize the first letter of the title +- Focus on the subject/topic being discussed +- Make it sound like a file name or category + +Examples: +"Where is the authentication code?" → "Authentication Code" +"How to setup the database?" → "Database Setup" +"What are the API endpoints?" → "API Endpoints" + +User question: ${message}`; + + const result = await generateText({ + model, + prompt, + }); + + return result.text; +} diff --git a/packages/web/src/ee/features/chat/mcp/actions.test.ts b/packages/web/src/ee/features/chat/mcp/actions.test.ts index 1009c3a4f..5558f598f 100644 --- a/packages/web/src/ee/features/chat/mcp/actions.test.ts +++ b/packages/web/src/ee/features/chat/mcp/actions.test.ts @@ -140,7 +140,7 @@ describe('createMcpServer', () => { expect(prisma.mcpServer.create).not.toHaveBeenCalled(); }); - test('owners cannot add org MCP servers when OAuth is unsupported', async () => { + test('owners cannot add org MCP servers when Ask Agent is unavailable', async () => { const prisma = setAuthContext(OrgRole.OWNER); mocks.hasEntitlement.mockResolvedValue(false); @@ -363,7 +363,7 @@ describe('deleteMcpServer', () => { expect(mocks.unsafePrisma.mcpServer.deleteMany).not.toHaveBeenCalled(); }); - test('owners can delete org MCP servers when OAuth is unsupported', async () => { + test('owners can delete org MCP servers when Ask Agent is unavailable', async () => { setAuthContext(OrgRole.OWNER); mocks.hasEntitlement.mockResolvedValue(false); mocks.unsafePrisma.mcpServer.deleteMany.mockResolvedValue({ count: 1 }); diff --git a/packages/web/src/ee/features/chat/mcp/actions.ts b/packages/web/src/ee/features/chat/mcp/actions.ts index 7ce05ede4..d206fead6 100644 --- a/packages/web/src/ee/features/chat/mcp/actions.ts +++ b/packages/web/src/ee/features/chat/mcp/actions.ts @@ -10,7 +10,7 @@ import { isServiceError } from '@/lib/utils'; import { McpServerClientInfoSource, OrgRole, type PrismaClient } from '@sourcebot/db'; import { StatusCodes } from 'http-status-codes'; import { z } from 'zod'; -import { sanitizeMcpServerName } from './utils'; +import { sanitizeMcpServerName } from '@/features/chat/mcp/utils'; import { hasEntitlement } from '@/lib/entitlements'; import { oauthNotSupported } from './errors'; import { checkMcpServerDcrSupport } from './dcrDiscovery'; @@ -145,7 +145,7 @@ async function prepareMcpServerCreate({ export const checkMcpServerDynamicClientRegistration = async (serverUrl: string) => sew(() => withAuth(async ({ role }) => withMinimumOrgRole(role, OrgRole.OWNER, async () => { - if (!(await hasEntitlement('oauth'))) { + if (!(await hasEntitlement('ask'))) { return oauthNotSupported(); } @@ -185,7 +185,7 @@ export const createStaticOAuthMcpServer = async ( return sew(() => withAuth(async ({ org, role, prisma }) => withMinimumOrgRole(role, OrgRole.OWNER, async (): Promise => { - if (!(await hasEntitlement('oauth'))) { + if (!(await hasEntitlement('ask'))) { return oauthNotSupported(); } @@ -249,7 +249,7 @@ export const createStaticOAuthMcpServer = async ( export const createMcpServer = async (name: string, serverUrl: string) => sew(() => withAuth(async ({ org, role, prisma }) => withMinimumOrgRole(role, OrgRole.OWNER, async () => { - if (!(await hasEntitlement('oauth'))) { + if (!(await hasEntitlement('ask'))) { return oauthNotSupported(); } diff --git a/packages/web/src/ee/features/chat/mcp/components/accountAskAgentPage.test.tsx b/packages/web/src/ee/features/chat/mcp/components/accountAskAgentPage.test.tsx new file mode 100644 index 000000000..bf583a64c --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/components/accountAskAgentPage.test.tsx @@ -0,0 +1,34 @@ +import { afterEach, describe, expect, test, vi } from 'vitest'; +import { cleanup, render, screen } from '@testing-library/react'; + +vi.mock('@/app/api/(client)/client', () => ({ + getMcpServersWithStatus: vi.fn(), + getMcpServerTools: vi.fn(), +})); +vi.mock('@/ee/features/chat/mcp/actions', () => ({ + disconnectMcpServer: vi.fn(), +})); + +const { AccountAskAgentEmptyState } = await import('./accountAskAgentPage'); + +afterEach(() => { + cleanup(); +}); + +describe('AccountAskAgentEmptyState', () => { + test('points owners to workspace Ask Agent settings', () => { + render(); + + expect(screen.getByText('No connectors configured yet')).toBeTruthy(); + expect(screen.getByText('Open Workspace Ask Agent to approve connectors for your workspace.')).toBeTruthy(); + expect(screen.getByRole('link', { name: /Open Workspace Ask Agent/ }).getAttribute('href')).toBe('/settings/workspaceAskAgent'); + }); + + test('tells members to contact an admin', () => { + render(); + + expect(screen.getByText('No connectors available')).toBeTruthy(); + expect(screen.getByText(/Contact your workspace admin/)).toBeTruthy(); + expect(screen.queryByRole('link', { name: /Open Workspace Ask Agent/ })).toBeNull(); + }); +}); diff --git a/packages/web/src/ee/features/chat/mcp/components/accountAskAgentPage.tsx b/packages/web/src/ee/features/chat/mcp/components/accountAskAgentPage.tsx new file mode 100644 index 000000000..40805a588 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/components/accountAskAgentPage.tsx @@ -0,0 +1,497 @@ +'use client'; + +import { useEffect, useMemo, useRef, useState } from "react"; +import Link from "next/link"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { CableIcon, ExternalLink, MoreHorizontal, SearchIcon, Settings2Icon, Unplug } from "lucide-react"; +import { getMcpServersWithStatus } from "@/app/api/(client)/client"; +import { useToast } from "@/components/hooks/use-toast"; +import { + AlertDialog, AlertDialogAction, AlertDialogCancel, AlertDialogContent, + AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogTitle, +} from "@/components/ui/alert-dialog"; +import { Button } from "@/components/ui/button"; +import { Card, CardContent } from "@/components/ui/card"; +import { + DropdownMenu, DropdownMenuContent, DropdownMenuItem, DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; +import { Input } from "@/components/ui/input"; +import { Separator } from "@/components/ui/separator"; +import { Skeleton } from "@/components/ui/skeleton"; +import { ConnectMcpButton } from "@/ee/features/chat/mcp/components/connectMcpButton"; +import { ConnectorCard } from "@/ee/features/chat/mcp/components/connectorCard"; +import { ConnectorRowInfo } from "@/ee/features/chat/mcp/components/connectorRowInfo"; +import { ConnectorToolTrigger } from "@/ee/features/chat/mcp/components/connectorToolDisclosure"; +import { useConnectMcp } from "@/ee/features/chat/mcp/hooks/useConnectMcp"; +import { useMcpToolMetadata } from "@/ee/features/chat/mcp/hooks/useMcpToolMetadata"; +import { disconnectMcpServer } from "@/ee/features/chat/mcp/actions"; +import { invalidateMcpConfigurationQueries, mcpQueryKeys } from "@/ee/features/chat/mcp/queryKeys"; +import { pluralize } from "@/features/chat/mcp/utils"; +import { cn, isServiceError } from "@/lib/utils"; +import type { McpServerWithStatus } from "@/app/api/(server)/ee/askmcp/servers/route"; +import type { ServerToolsEntry } from "@/ee/features/chat/mcp/types"; + +type FilterTab = "all" | "connected"; + +function clearCallbackParams() { + const url = new URL(window.location.href); + url.searchParams.delete('status'); + url.searchParams.delete('server'); + url.searchParams.delete('message'); + window.history.replaceState({}, '', url.toString()); +} + +interface AccountAskAgentPageProps { + callbackStatus?: string; + callbackServer?: string; + callbackMessage?: string; + canManageConnectors: boolean; +} + +export function AccountAskAgentEmptyState({ canManageConnectors }: { canManageConnectors: boolean }) { + return ( + + +
+ +
+

+ {canManageConnectors ? "No connectors configured yet" : "No connectors available"} +

+

+ {canManageConnectors + ? "Open Workspace Ask Agent to approve connectors for your workspace." + : "No connectors have been approved for this workspace yet. Contact your workspace admin."} +

+ {canManageConnectors && ( + + )} +
+
+ ); +} + +interface AccountConnectedConnectorCardProps { + server: McpServerWithStatus; + toolEntry?: ServerToolsEntry; + isToolsLoading: boolean; + isToolsError: boolean; + onRetryTools: () => void; + onReconnect: (serverId: string) => void; + onDisconnect: (server: McpServerWithStatus) => void; + disconnectingServerId: string | null; +} + +function AccountConnectedConnectorCard({ + server, + toolEntry, + isToolsLoading, + isToolsError, + onRetryTools, + onReconnect, + onDisconnect, + disconnectingServerId, +}: AccountConnectedConnectorCardProps) { + return ( + + {server.isConnected && ( + + + Connected + + )} + {server.isAuthExpired && ( + + + Authorization expired + + )} + + } + actionButtons={ + + + + + + onReconnect(server.id)}> + + Reconnect + + onDisconnect(server)} + > + + {disconnectingServerId === server.id ? "Disconnecting..." : "Disconnect"} + + + + } + /> + ); +} + +function AccountSuggestedConnectorCard({ server }: { server: McpServerWithStatus }) { + return ( + + + +
+ +
+
+ +
+
+ ); +} + +export function AccountAskAgentPage({ + callbackStatus, + callbackServer, + callbackMessage, + canManageConnectors, +}: AccountAskAgentPageProps) { + const { toast } = useToast(); + const queryClient = useQueryClient(); + const didHandleCallbackRef = useRef(false); + const [searchQuery, setSearchQuery] = useState(""); + const [activeTab, setActiveTab] = useState("all"); + const [disconnectingServerId, setDisconnectingServerId] = useState(null); + const [confirmDisconnectServer, setConfirmDisconnectServer] = useState<{ id: string; name: string } | null>(null); + const { connect: reconnectMcp } = useConnectMcp(); + + useEffect(() => { + if (didHandleCallbackRef.current) { + return; + } + if (callbackStatus === 'connected') { + didHandleCallbackRef.current = true; + toast({ description: `Successfully connected${callbackServer ? ` to ${callbackServer}` : ''}.` }); + clearCallbackParams(); + } else if (callbackStatus === 'error') { + didHandleCallbackRef.current = true; + toast({ title: "Connection failed", description: callbackMessage ?? 'Failed to connect connector.', variant: "destructive" }); + clearCallbackParams(); + } + }, [callbackStatus, callbackServer, callbackMessage, toast]); + + const { data: servers = [], isLoading, isError } = useQuery({ + queryKey: mcpQueryKeys.serversWithStatus, + queryFn: async () => { + const result = await getMcpServersWithStatus(); + if (isServiceError(result)) { + throw new Error("Failed to load connectors"); + } + return result; + }, + }); + + const connectedServers = useMemo( + () => servers.filter((s) => s.isConnected || s.isAuthExpired), + [servers], + ); + + const suggestedServers = useMemo( + () => servers.filter((s) => !s.isConnected && !s.isAuthExpired), + [servers], + ); + + const filteredConnected = useMemo(() => { + const list = connectedServers; + if (!searchQuery.trim()) { + return list; + } + const q = searchQuery.toLowerCase(); + return list.filter( + (s) => (s.name?.toLowerCase().includes(q)) || s.serverUrl.toLowerCase().includes(q), + ); + }, [connectedServers, searchQuery]); + + const filteredSuggested = useMemo(() => { + const list = suggestedServers; + if (!searchQuery.trim()) { + return list; + } + const q = searchQuery.toLowerCase(); + return list.filter( + (s) => (s.name?.toLowerCase().includes(q)) || s.serverUrl.toLowerCase().includes(q), + ); + }, [suggestedServers, searchQuery]); + + const visibleConnected = filteredConnected; + const visibleSuggested = activeTab === "all" ? filteredSuggested : []; + const activeConnectedServerCount = useMemo( + () => servers.filter((s) => s.isConnected).length, + [servers], + ); + const { + isToolsLoading, + isToolsError, + refetchTools, + toolsByServerId, + } = useMcpToolMetadata(true, activeConnectedServerCount); + + const handleDisconnect = async (serverId: string) => { + setDisconnectingServerId(serverId); + setConfirmDisconnectServer(null); + try { + const result = await disconnectMcpServer(serverId, 'account_settings'); + if (isServiceError(result)) { + toast({ title: "Error", description: `Failed to disconnect: ${result.message}`, variant: "destructive" }); + return; + } + toast({ description: "Connector disconnected." }); + await invalidateMcpConfigurationQueries(queryClient); + } catch { + toast({ title: "Error", description: "Failed to disconnect connector.", variant: "destructive" }); + } finally { + setDisconnectingServerId(null); + } + }; + + if (isError) { + return
Error loading connectors
; + } + + if (!isLoading && servers.length === 0) { + return ( +
+
+

Ask Agent

+

+ Manage your personal Ask Agent setup. +

+
+ +
+
+

Connectors

+

+ Manage workspace-approved connectors for use with Ask Agent. +

+
+ +
+
+ ); + } + + return ( +
+
+

Ask Agent

+

+ Manage your personal Ask Agent setup. +

+
+ + + +
+
+

Connectors

+

+ Manage workspace-approved connectors for use with Ask Agent. +

+
+ + {/* Search + filter bar */} +
+
+ + setSearchQuery(e.target.value)} + className="pl-9" + /> +
+
+ + +
+
+
+ + {isLoading ? ( +
+ {Array.from({ length: 3 }).map((_, index) => ( + + + +
+ + +
+ +
+
+ ))} +
+ ) : ( + <> + {/* Connected section */} +
+
+

+ Connected +

+

+ {connectedServers.length} {pluralize(connectedServers.length, "connector")} +

+
+ + {visibleConnected.length === 0 ? ( + + +

+ {searchQuery.trim() + ? "No connected connectors match your search." + : "No connectors connected yet."} +

+
+
+ ) : ( + visibleConnected.map((server) => ( + { void refetchTools(); }} + onReconnect={reconnectMcp} + onDisconnect={(serverToDisconnect) => setConfirmDisconnectServer({ + id: serverToDisconnect.id, + name: serverToDisconnect.name || serverToDisconnect.serverUrl, + })} + disconnectingServerId={disconnectingServerId} + /> + )) + )} +
+ + {/* Suggested section */} + {activeTab === "all" && ( +
+
+

+ Suggested +

+

+ workspace-approved +

+
+ + {visibleSuggested.length === 0 ? ( + + +

+ {searchQuery.trim() + ? "No suggested connectors match your search." + : "All connectors are connected."} +

+
+
+ ) : ( + visibleSuggested.map((server) => ( + + )) + )} +
+ )} + + )} + + {/* Disconnect confirmation dialog */} + { + if (!open) { + setConfirmDisconnectServer(null); + } + }} + > + + + Disconnect Connector + + Are you sure you want to disconnect from {confirmDisconnectServer?.name}? Your stored credentials for this connector will be removed. + + + + Cancel + { + if (confirmDisconnectServer) { + handleDisconnect(confirmDisconnectServer.id); + } + }} + className="bg-destructive text-destructive-foreground hover:bg-destructive/90" + > + Disconnect + + + + +
+ ); +} diff --git a/packages/web/src/ee/features/chat/mcp/components/connectorCard.tsx b/packages/web/src/ee/features/chat/mcp/components/connectorCard.tsx index 3fc7feaf7..eb2d3292d 100644 --- a/packages/web/src/ee/features/chat/mcp/components/connectorCard.tsx +++ b/packages/web/src/ee/features/chat/mcp/components/connectorCard.tsx @@ -14,7 +14,7 @@ interface ConnectorCardProps { isConnected: boolean; isAuthExpired?: boolean; - isOAuthAvailable?: boolean; + isAskAgentAvailable?: boolean; isStatusUnavailable?: boolean; toolEntry?: ServerToolsEntry; toolUsage?: McpServerToolUsageSummary; @@ -32,7 +32,7 @@ export function ConnectorCard({ serverUrl, isConnected, isAuthExpired, - isOAuthAvailable, + isAskAgentAvailable, isStatusUnavailable, toolEntry, toolUsage, @@ -68,7 +68,7 @@ export function ConnectorCard({ { test('renders unavailable state before connection-specific states', () => { renderToolTrigger({ isConnected: false, - isOAuthAvailable: false, + isAskAgentAvailable: false, }); expect(screen.getByText('Tools unavailable')).toBeTruthy(); diff --git a/packages/web/src/ee/features/chat/mcp/components/connectorToolDisclosure.tsx b/packages/web/src/ee/features/chat/mcp/components/connectorToolDisclosure.tsx index 659fc9702..f951b2978 100644 --- a/packages/web/src/ee/features/chat/mcp/components/connectorToolDisclosure.tsx +++ b/packages/web/src/ee/features/chat/mcp/components/connectorToolDisclosure.tsx @@ -3,7 +3,7 @@ import { useEffect, useState } from 'react'; import { Badge } from '@/components/ui/badge'; import { cn } from '@/lib/utils'; -import { pluralize } from '@/ee/features/chat/mcp/utils'; +import { pluralize } from '@/features/chat/mcp/utils'; import type { ServerToolsEntry, ToolMetadataErrorReason, ToolSummary } from '@/ee/features/chat/mcp/types'; import { ChevronDownIcon, RefreshCwIcon, WrenchIcon } from 'lucide-react'; @@ -30,7 +30,7 @@ function getToolCountLabel(entry: Extract { test('keeps connected and expired servers separate from connectable approved servers', () => { diff --git a/packages/web/src/ee/features/chat/mcp/components/connectorsMenu.tsx b/packages/web/src/ee/features/chat/mcp/components/connectorsMenu.tsx new file mode 100644 index 000000000..345758469 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/components/connectorsMenu.tsx @@ -0,0 +1,312 @@ +'use client'; + +import { Button } from "@/components/ui/button"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuItem, + DropdownMenuSeparator, + DropdownMenuSub, + DropdownMenuSubContent, + DropdownMenuSubTrigger, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; +import { Switch } from "@/components/ui/switch"; +import { connectMcpToAsk, getMcpServersWithStatus } from "@/app/api/(client)/client"; +import { useToast } from "@/components/hooks/use-toast"; +import { McpFavicon } from "@/ee/features/chat/mcp/components/mcpFavicon"; +import { mcpQueryKeys } from "@/ee/features/chat/mcp/queryKeys"; +import { isServiceError } from "@/lib/utils"; +import { useQuery, useQueryClient } from "@tanstack/react-query"; +import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import { AlertTriangleIcon, CableIcon, Loader2Icon, PlusCircleIcon, PlusIcon, RefreshCwIcon, SettingsIcon } from "lucide-react"; +import { PlusButtonInfoCard } from "@/features/chat/components/chatBox/plusButtonInfoCard"; +import { useRouter } from "next/navigation"; +import { useEffect, useRef, useState } from "react"; +import { useSlate } from "slate-react"; +import { Editor } from "slate"; +import type { CustomEditor, SearchScope } from "@/features/chat/types"; +import { + clearMcpOAuthDraft, + consumeMcpOAuthDraftForPath, + createMcpOAuthDraftPath, + saveMcpOAuthDraft, +} from "@/features/chat/mcpOAuthDraft"; +import { clearEditorHistory, resetEditor } from "@/features/chat/utils"; + +interface ConnectorsMenuProps { + selectedSearchScopes: SearchScope[]; + onSelectedSearchScopesChange: (items: SearchScope[]) => void; + disabledMcpServerIds: string[]; + onDisabledMcpServerIdsChange: (ids: string[]) => void; +} + +interface ChatMenuMcpServer { + isConnected: boolean; + isAuthExpired: boolean; +} + +export function splitMcpServersForChatMenu(servers: T[]) { + return { + connectedServers: servers.filter((server) => server.isConnected || server.isAuthExpired), + connectableServers: servers.filter((server) => !server.isConnected && !server.isAuthExpired), + }; +} + +function restoreEditorChildren(editor: CustomEditor, children: CustomEditor['children']) { + editor.children = children; + editor.selection = { + anchor: Editor.end(editor, []), + focus: Editor.end(editor, []), + }; + clearEditorHistory(editor); + editor.onChange(); +} + +export const ConnectorsMenu = ({ + selectedSearchScopes, + onSelectedSearchScopesChange, + disabledMcpServerIds, + onDisabledMcpServerIdsChange, +}: ConnectorsMenuProps) => { + const [connectingServerId, setConnectingServerId] = useState(null); + const editor = useSlate(); + const hasRestoredMcpOAuthDraft = useRef(false); + const isMountedRef = useRef(false); + const queryClient = useQueryClient(); + const router = useRouter(); + const { toast } = useToast(); + + const { data: servers = [], isError, isLoading, refetch } = useQuery({ + queryKey: mcpQueryKeys.serversWithStatus, + queryFn: async () => { + const result = await getMcpServersWithStatus(); + if (isServiceError(result)) { + throw new Error("Failed to load connectors"); + } + return result; + }, + }); + + useEffect(() => { + isMountedRef.current = true; + + return () => { + isMountedRef.current = false; + }; + }, []); + + useEffect(() => { + if (hasRestoredMcpOAuthDraft.current) { + return; + } + + const currentPath = createMcpOAuthDraftPath(window.location.pathname, window.location.search); + if (!currentPath) { + return; + } + + const draft = consumeMcpOAuthDraftForPath(currentPath); + if (!draft) { + return; + } + + hasRestoredMcpOAuthDraft.current = true; + + try { + restoreEditorChildren(editor, draft.children); + onSelectedSearchScopesChange(draft.selectedSearchScopes); + onDisabledMcpServerIdsChange(draft.disabledMcpServerIds); + } catch (error) { + resetEditor(editor); + editor.onChange(); + console.error('Failed to restore MCP OAuth draft:', error); + } + }, [editor, onDisabledMcpServerIdsChange, onSelectedSearchScopesChange]); + + const onToggle = (serverId: string, checked: boolean) => { + if (checked) { + onDisabledMcpServerIdsChange(disabledMcpServerIds.filter((id) => id !== serverId)); + } else { + onDisabledMcpServerIdsChange([...disabledMcpServerIds, serverId]); + } + }; + + const handleConnect = async (serverId: string) => { + setConnectingServerId(serverId); + const returnTo = createMcpOAuthDraftPath(window.location.pathname, window.location.search) ?? '/chat'; + + saveMcpOAuthDraft({ + returnTo, + children: editor.children, + selectedSearchScopes, + disabledMcpServerIds, + }); + + try { + const result = await connectMcpToAsk({ + serverId, + returnTo, + }); + + if (!isMountedRef.current) { + return; + } + + if (isServiceError(result)) { + clearMcpOAuthDraft(); + toast({ + description: `Failed to connect connector. ${result.message}`, + variant: "destructive", + }); + setConnectingServerId(null); + return; + } + + if (result.authorizationUrl) { + window.location.href = result.authorizationUrl; + return; + } + + clearMcpOAuthDraft(); + toast({ description: 'Connector is already connected.' }); + await queryClient.invalidateQueries({ queryKey: mcpQueryKeys.serversWithStatus }); + if (!isMountedRef.current) { + return; + } + setConnectingServerId(null); + } catch { + if (!isMountedRef.current) { + return; + } + + clearMcpOAuthDraft(); + toast({ + description: "Failed to connect connector.", + variant: "destructive", + }); + setConnectingServerId(null); + return; + } + }; + + const { connectedServers, connectableServers } = splitMcpServersForChatMenu(servers); + const hasServers = connectedServers.length > 0 || connectableServers.length > 0; + + return ( + + + + + + + + + + + + e.preventDefault()}> + + + + Connectors + + + {isError && !hasServers ? ( + { + e.preventDefault(); + refetch(); + }} + className="gap-2 text-destructive" + > + + Failed to load. Retry? + + ) : isLoading ? ( + + Loading connectors... + + ) : !hasServers ? ( + + No connectors available + + ) : ( + <> + {connectedServers.map((server) => { + const isEnabled = !server.isAuthExpired && !disabledMcpServerIds.includes(server.id); + return ( + e.preventDefault()} + disabled={server.isAuthExpired} + className="flex items-center justify-between gap-2" + > +
+ {server.isAuthExpired ? ( + + ) : ( + + )} + {server.name} +
+ onToggle(server.id, checked)} + disabled={server.isAuthExpired} + className="scale-75" + /> +
+ ); + })} + {connectedServers.length > 0 && connectableServers.length > 0 && } + {connectableServers.map((server) => ( + { + e.preventDefault(); + void handleConnect(server.id); + }} + disabled={connectingServerId !== null} + className="group flex cursor-pointer items-center justify-between gap-2" + > +
+ + {server.name} +
+ {connectingServerId === server.id ? ( + + ) : ( + + )} +
+ ))} + + )} + + router.push(`/settings/accountAskAgent`)} + > + + My connectors + + router.push(`/settings/workspaceAskAgent`)} + > + + Workspace connectors + +
+
+
+
+ ); +}; diff --git a/packages/web/src/ee/features/chat/mcp/components/prefabConnectorPopover.tsx b/packages/web/src/ee/features/chat/mcp/components/prefabConnectorPopover.tsx new file mode 100644 index 000000000..23241443a --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/components/prefabConnectorPopover.tsx @@ -0,0 +1,129 @@ +'use client'; + +import { useMemo, useState } from "react"; +import { + Command, + CommandGroup, + CommandInput, + CommandItem, + CommandList, + CommandSeparator, +} from "@/components/ui/command"; +import { Button } from "@/components/ui/button"; +import { Popover, PopoverContent, PopoverTrigger } from "@/components/ui/popover"; +import { getDisplayServerUrl } from "@/ee/features/chat/mcp/components/connectorRowInfo"; +import { McpFavicon } from "@/ee/features/chat/mcp/components/mcpFavicon"; +import { + getAvailablePrefabMcpServers, + type PrefabMcpServer, +} from "@/ee/features/chat/mcp/prefabMcpServers"; +import { getMcpFaviconUrl } from "@/features/chat/mcp/utils"; +import { PlusIcon } from "lucide-react"; + +interface PrefabConnectorPopoverProps { + configuredServerUrls: string[]; + disabled?: boolean; + onSelectCustomUrl: () => void; + onSelectPrefabServer: (server: PrefabMcpServer) => void; + children?: React.ReactNode; +} + +export function PrefabConnectorPopover({ + configuredServerUrls, + disabled, + onSelectCustomUrl, + onSelectPrefabServer, + children, +}: PrefabConnectorPopoverProps) { + const [isOpen, setIsOpen] = useState(false); + const [search, setSearch] = useState(""); + + const availablePrefabServers = useMemo(() => ( + getAvailablePrefabMcpServers(configuredServerUrls) + ), [configuredServerUrls]); + + const filteredPrefabServers = useMemo(() => { + const normalizedSearch = search.trim().toLowerCase(); + + if (!normalizedSearch) { + return availablePrefabServers; + } + + return availablePrefabServers.filter((server) => server.name.toLowerCase().includes(normalizedSearch)); + }, [availablePrefabServers, search]); + + const handleOpenChange = (open: boolean) => { + setIsOpen(open); + + if (!open) { + setSearch(""); + } + }; + + const handleSelectPrefabServer = (server: PrefabMcpServer) => { + handleOpenChange(false); + onSelectPrefabServer(server); + }; + + const handleSelectCustomUrl = () => { + handleOpenChange(false); + onSelectCustomUrl(); + }; + + return ( + + + {children ?? ( + + )} + + + + + + + {filteredPrefabServers.map((server) => ( + handleSelectPrefabServer(server)} + className="cursor-pointer" + > +
+ +
+
+

{server.name}

+

{getDisplayServerUrl(server.serverUrl)}

+
+
+ ))} + {search.trim() && filteredPrefabServers.length === 0 && ( +
+ No connectors found. +
+ )} +
+ + + + + Custom URL... + + +
+
+
+
+ ); +} diff --git a/packages/web/src/ee/features/chat/mcp/hooks/useMcpToolMetadata.ts b/packages/web/src/ee/features/chat/mcp/hooks/useMcpToolMetadata.ts index b86eb66b5..d0e70063f 100644 --- a/packages/web/src/ee/features/chat/mcp/hooks/useMcpToolMetadata.ts +++ b/packages/web/src/ee/features/chat/mcp/hooks/useMcpToolMetadata.ts @@ -9,7 +9,7 @@ import type { ServerToolsEntry } from '@/ee/features/chat/mcp/types'; const EMPTY_TOOL_ENTRIES: ServerToolsEntry[] = []; -export function useMcpToolMetadata(isOAuthAvailable: boolean, connectedServerCount: number) { +export function useMcpToolMetadata(isAskAgentAvailable: boolean, connectedServerCount: number) { const queryClient = useQueryClient(); const lastAuthFailureInvalidatedAtRef = useRef(0); const { @@ -30,7 +30,7 @@ export function useMcpToolMetadata(isOAuthAvailable: boolean, connectedServerCou } return result; }, - enabled: isOAuthAvailable && connectedServerCount > 0, + enabled: isAskAgentAvailable && connectedServerCount > 0, staleTime: 5 * 60 * 1000, gcTime: 30 * 60 * 1000, refetchOnWindowFocus: false, diff --git a/packages/web/src/ee/features/chat/mcp/mcpClientFactory.test.ts b/packages/web/src/ee/features/chat/mcp/mcpClientFactory.test.ts index 4333449bc..af1318dc9 100644 --- a/packages/web/src/ee/features/chat/mcp/mcpClientFactory.test.ts +++ b/packages/web/src/ee/features/chat/mcp/mcpClientFactory.test.ts @@ -17,7 +17,7 @@ vi.mock('@sourcebot/shared', () => ({ vi.mock('server-only', () => ({ default: vi.fn() })); -vi.mock('@/features/mcp/prismaOAuthClientProvider', () => ({ +vi.mock('@/ee/features/chat/mcp/prismaOAuthClientProvider', () => ({ PrismaOAuthClientProvider: vi.fn(), })); diff --git a/packages/web/src/ee/features/chat/mcp/mcpClientFactory.ts b/packages/web/src/ee/features/chat/mcp/mcpClientFactory.ts index b74d710ec..e0a7b5e55 100644 --- a/packages/web/src/ee/features/chat/mcp/mcpClientFactory.ts +++ b/packages/web/src/ee/features/chat/mcp/mcpClientFactory.ts @@ -1,5 +1,5 @@ import { createLogger, env } from '@sourcebot/shared'; -import { PrismaOAuthClientProvider } from '@/features/mcp/prismaOAuthClientProvider'; +import { PrismaOAuthClientProvider } from '@/ee/features/chat/mcp/prismaOAuthClientProvider'; import { StreamableHTTPClientTransport } from '@modelcontextprotocol/sdk/client/streamableHttp.js'; import type { PrismaClient } from '@sourcebot/db'; import { getExternalMcpErrorLogFields } from './externalMcpError'; diff --git a/packages/web/src/ee/features/chat/mcp/mcpOAuthReturnTo.test.ts b/packages/web/src/ee/features/chat/mcp/mcpOAuthReturnTo.test.ts new file mode 100644 index 000000000..8b3b8a0fe --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/mcpOAuthReturnTo.test.ts @@ -0,0 +1,36 @@ +import { describe, expect, test } from 'vitest'; +import { + createMcpOAuthState, + getMcpOAuthReturnToFromState, + normalizeMcpOAuthReturnTo, +} from './mcpOAuthReturnTo'; + +describe('MCP OAuth return paths', () => { + test('allows chat return paths', () => { + expect(normalizeMcpOAuthReturnTo('/chat')).toBe('/chat'); + expect(normalizeMcpOAuthReturnTo('/chat/thread-1?foo=bar')).toBe('/chat/thread-1?foo=bar'); + }); + + test('allows connector settings return paths', () => { + expect(normalizeMcpOAuthReturnTo('/settings/accountAskAgent?status=connected')).toBe('/settings/accountAskAgent?status=connected'); + }); + + test('rejects external and unrelated return paths', () => { + expect(normalizeMcpOAuthReturnTo('https://evil.example.com/chat')).toBeUndefined(); + expect(normalizeMcpOAuthReturnTo('//evil.example.com/chat')).toBeUndefined(); + expect(normalizeMcpOAuthReturnTo('/settings')).toBeUndefined(); + }); + + test('encodes and decodes return paths inside OAuth state', () => { + const state = createMcpOAuthState('nonce-1', '/chat'); + + expect(state).not.toBe('nonce-1'); + expect(getMcpOAuthReturnToFromState(state)).toBe('/chat'); + }); + + test('leaves state unchanged when no valid return path exists', () => { + expect(createMcpOAuthState('nonce-1')).toBe('nonce-1'); + expect(createMcpOAuthState('nonce-1', '/settings')).toBe('nonce-1'); + expect(getMcpOAuthReturnToFromState('nonce-1')).toBeUndefined(); + }); +}); diff --git a/packages/web/src/ee/features/chat/mcp/mcpOAuthReturnTo.ts b/packages/web/src/ee/features/chat/mcp/mcpOAuthReturnTo.ts new file mode 100644 index 000000000..8127abdbc --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/mcpOAuthReturnTo.ts @@ -0,0 +1,63 @@ +const MCP_OAUTH_STATE_PREFIX = 'sourcebot_mcp.'; +const MCP_OAUTH_STATE_BASE_URL = 'https://sourcebot.invalid'; + +function isAllowedMcpOAuthReturnPath(pathname: string): boolean { + return pathname === '/chat' || pathname.startsWith('/chat/') || pathname === '/settings/accountAskAgent'; +} + +export function normalizeMcpOAuthReturnTo(returnTo: unknown): string | undefined { + if (typeof returnTo !== 'string') { + return undefined; + } + + const trimmedReturnTo = returnTo.trim(); + if (!trimmedReturnTo || !trimmedReturnTo.startsWith('/') || trimmedReturnTo.startsWith('//') || trimmedReturnTo.includes('\\')) { + return undefined; + } + + try { + const url = new URL(trimmedReturnTo, MCP_OAUTH_STATE_BASE_URL); + if (url.origin !== MCP_OAUTH_STATE_BASE_URL || !isAllowedMcpOAuthReturnPath(url.pathname)) { + return undefined; + } + + return `${url.pathname}${url.search}`; + } catch { + return undefined; + } +} + +export function createMcpOAuthState(nonce: string, returnTo?: string): string { + const normalizedReturnTo = normalizeMcpOAuthReturnTo(returnTo); + if (!normalizedReturnTo) { + return nonce; + } + + const encoded = Buffer.from(JSON.stringify({ + nonce, + returnTo: normalizedReturnTo, + })).toString('base64url'); + return `${MCP_OAUTH_STATE_PREFIX}${encoded}`; +} + +export function getMcpOAuthReturnToFromState(state: string | null | undefined): string | undefined { + if (!state?.startsWith(MCP_OAUTH_STATE_PREFIX)) { + return undefined; + } + + try { + const encoded = state.slice(MCP_OAUTH_STATE_PREFIX.length); + const payload = JSON.parse(Buffer.from(encoded, 'base64url').toString('utf8')) as unknown; + if ( + typeof payload === 'object' && + payload !== null && + 'returnTo' in payload + ) { + return normalizeMcpOAuthReturnTo(payload.returnTo); + } + } catch { + return undefined; + } + + return undefined; +} diff --git a/packages/web/src/ee/features/chat/mcp/mcpToolSets.ts b/packages/web/src/ee/features/chat/mcp/mcpToolSets.ts index 4e249247b..ccc5c6f26 100644 --- a/packages/web/src/ee/features/chat/mcp/mcpToolSets.ts +++ b/packages/web/src/ee/features/chat/mcp/mcpToolSets.ts @@ -5,7 +5,7 @@ import Ajv from 'ajv'; import { jsonSchema, ToolExecutionOptions } from 'ai'; import type { JSONSchema7, JSONSchema7Definition } from 'json-schema'; import { getExternalMcpErrorLogFields } from './externalMcpError'; -import { getMcpFaviconUrl } from './utils'; +import { getMcpFaviconUrl } from '@/features/chat/mcp/utils'; import { __unsafePrisma } from '@/prisma'; import { Prisma } from '@sourcebot/db'; import { captureEvent } from '@/lib/posthog'; diff --git a/packages/web/src/ee/features/chat/mcp/prismaOAuthClientProvider.test.ts b/packages/web/src/ee/features/chat/mcp/prismaOAuthClientProvider.test.ts new file mode 100644 index 000000000..cf40926f4 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/prismaOAuthClientProvider.test.ts @@ -0,0 +1,268 @@ +import { describe, expect, test, vi, beforeEach } from 'vitest'; +import { McpServerClientInfoSource } from '@sourcebot/db'; + +const mocks = vi.hoisted(() => ({ + logger: { + info: vi.fn(), + warn: vi.fn(), + error: vi.fn(), + debug: vi.fn(), + }, +})); + +vi.mock('server-only', () => ({})); +vi.mock('@/prisma', () => ({ + __unsafePrisma: { + mcpServer: {}, + userMcpServer: {}, + }, +})); +vi.mock('@sourcebot/shared', () => ({ + encryptOAuthToken: vi.fn((text: string | null | undefined) => text ? `encrypted:${text}` : undefined), + decryptOAuthToken: vi.fn((text: string | null | undefined) => text?.startsWith('encrypted:') ? text.slice('encrypted:'.length) : text), + createLogger: () => mocks.logger, +})); + +const { + PrismaOAuthClientProvider, + clearMcpServerClientCredentialsForObservedClient, +} = await import('./prismaOAuthClientProvider'); + +function createPrismaMock() { + return { + mcpServer: { + findFirst: vi.fn(), + updateMany: vi.fn(), + }, + userMcpServer: { + findUnique: vi.fn(), + update: vi.fn(), + updateMany: vi.fn(), + }, + }; +} + +function createProvider(prisma = createPrismaMock(), allowClientRegistration = false) { + return new PrismaOAuthClientProvider({ + prisma: prisma as never, + clientInvalidationPrisma: prisma as never, + serverId: 'server-1', + orgId: 1, + userId: 'user-1', + callbackUrl: 'https://sourcebot.example.com/api/ee/askmcp/callback', + allowClientRegistration, + }); +} + +beforeEach(() => { + vi.clearAllMocks(); +}); + +describe('PrismaOAuthClientProvider modes', () => { + test('connect-mode provider exposes saveClientInformation', () => { + const provider = createProvider(createPrismaMock(), true); + + expect('saveClientInformation' in provider).toBe(true); + expect(provider.saveClientInformation).toEqual(expect.any(Function)); + }); + + test('runtime and callback providers omit saveClientInformation', () => { + const provider = createProvider(); + + expect('saveClientInformation' in provider).toBe(false); + expect(provider.saveClientInformation).toBeUndefined(); + }); +}); + +describe('clearMcpServerClientCredentialsForObservedClient', () => { + test('matching observed clientInfo clears org clientInfo and all server tokens', async () => { + const prisma = createPrismaMock(); + prisma.mcpServer.updateMany.mockResolvedValue({ count: 1 }); + prisma.userMcpServer.updateMany.mockResolvedValue({ count: 2 }); + + const didClear = await clearMcpServerClientCredentialsForObservedClient({ + prisma: prisma as never, + serverId: 'server-1', + orgId: 1, + observedClientInfo: 'encrypted-client-info', + }); + + expect(didClear).toBe(true); + expect(prisma.mcpServer.updateMany).toHaveBeenCalledWith({ + where: { + id: 'server-1', + orgId: 1, + clientInfo: 'encrypted-client-info', + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + }, + data: { clientInfo: null }, + }); + expect(prisma.userMcpServer.updateMany).toHaveBeenCalledWith({ + where: { + serverId: 'server-1', + server: { orgId: 1 }, + }, + data: { + tokens: null, + tokensExpiresAt: null, + }, + }); + }); + + test('stale observed clientInfo clears neither org clientInfo nor tokens', async () => { + const prisma = createPrismaMock(); + prisma.mcpServer.updateMany.mockResolvedValue({ count: 0 }); + + const didClear = await clearMcpServerClientCredentialsForObservedClient({ + prisma: prisma as never, + serverId: 'server-1', + orgId: 1, + observedClientInfo: 'stale-client-info', + }); + + expect(didClear).toBe(false); + expect(prisma.mcpServer.updateMany).toHaveBeenCalledOnce(); + expect(prisma.userMcpServer.updateMany).not.toHaveBeenCalled(); + }); +}); + +describe('PrismaOAuthClientProvider PKCE verifier storage', () => { + test('saveCodeVerifier encrypts the verifier before persisting it', async () => { + const prisma = createPrismaMock(); + prisma.userMcpServer.update.mockResolvedValue({ + userId: 'user-1', + serverId: 'server-1', + }); + const provider = createProvider(prisma); + + await provider.saveCodeVerifier('verifier-secret'); + + expect(prisma.userMcpServer.update).toHaveBeenCalledWith({ + where: { + userId_serverId: { userId: 'user-1', serverId: 'server-1' }, + }, + data: { + codeVerifier: 'encrypted:verifier-secret', + }, + }); + }); + + test('codeVerifier decrypts the stored verifier', async () => { + const prisma = createPrismaMock(); + prisma.userMcpServer.findUnique.mockResolvedValue({ + codeVerifier: 'encrypted:verifier-secret', + tokens: null, + state: null, + }); + const provider = createProvider(prisma); + + await expect(provider.codeVerifier()).resolves.toBe('verifier-secret'); + expect(mocks.logger.warn).not.toHaveBeenCalled(); + }); + + test('codeVerifier still accepts plaintext verifier values during migration and logs the fallback', async () => { + const prisma = createPrismaMock(); + prisma.userMcpServer.findUnique.mockResolvedValue({ + codeVerifier: 'plaintext-verifier', + tokens: null, + state: null, + }); + const provider = createProvider(prisma); + + await expect(provider.codeVerifier()).resolves.toBe('plaintext-verifier'); + expect(mocks.logger.warn).toHaveBeenCalledWith( + 'MCP OAuth code verifier was read without decryption; it may be plaintext from an older version.', + { + serverId: 'server-1', + orgId: 1, + userId: 'user-1', + }, + ); + }); +}); + +describe('PrismaOAuthClientProvider authorization redirect', () => { + test('overwrites existing prompt values with consent', async () => { + const prisma = createPrismaMock(); + prisma.userMcpServer.update.mockResolvedValue({ + userId: 'user-1', + serverId: 'server-1', + }); + const provider = createProvider(prisma); + + await provider.redirectToAuthorization(new URL('https://oauth.example.com/authorize?prompt=none&client_id=client-1')); + + expect(provider.authorizationUrl).toBe('https://oauth.example.com/authorize?prompt=consent&client_id=client-1'); + expect(prisma.userMcpServer.update).toHaveBeenCalledWith({ + where: { + userId_serverId: { userId: 'user-1', serverId: 'server-1' }, + }, + data: { + tokens: null, + tokensExpiresAt: null, + }, + }); + }); +}); + +describe('PrismaOAuthClientProvider static client information', () => { + test('clientInformation returns static OAuth client credentials', async () => { + const prisma = createPrismaMock(); + prisma.mcpServer.findFirst.mockResolvedValue({ + clientInfo: 'encrypted:{"client_id":"client-id","client_secret":"client-secret"}', + clientInfoSource: McpServerClientInfoSource.STATIC, + }); + const provider = createProvider(prisma); + + await expect(provider.clientInformation()).resolves.toEqual({ + client_id: 'client-id', + client_secret: 'client-secret', + }); + }); + + test('invalidate all preserves static client information and clears only the current user tokens and verifier', async () => { + const prisma = createPrismaMock(); + prisma.mcpServer.findFirst.mockResolvedValue({ + clientInfo: 'encrypted:{"client_id":"client-id","client_secret":"client-secret"}', + clientInfoSource: McpServerClientInfoSource.STATIC, + }); + prisma.mcpServer.updateMany.mockResolvedValue({ count: 0 }); + prisma.userMcpServer.update.mockResolvedValue({ + userId: 'user-1', + serverId: 'server-1', + }); + const provider = createProvider(prisma); + + await provider.clientInformation(); + await provider.invalidateCredentials('all'); + + expect(prisma.mcpServer.updateMany).toHaveBeenCalledWith({ + where: { + id: 'server-1', + orgId: 1, + clientInfo: 'encrypted:{"client_id":"client-id","client_secret":"client-secret"}', + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + }, + data: { clientInfo: null }, + }); + expect(prisma.userMcpServer.updateMany).not.toHaveBeenCalled(); + expect(prisma.userMcpServer.update).toHaveBeenCalledWith({ + where: { + userId_serverId: { userId: 'user-1', serverId: 'server-1' }, + }, + data: { + tokens: null, + tokensExpiresAt: null, + }, + }); + expect(prisma.userMcpServer.update).toHaveBeenCalledWith({ + where: { + userId_serverId: { userId: 'user-1', serverId: 'server-1' }, + }, + data: { + codeVerifier: null, + state: null, + }, + }); + }); +}); diff --git a/packages/web/src/ee/features/chat/mcp/prismaOAuthClientProvider.ts b/packages/web/src/ee/features/chat/mcp/prismaOAuthClientProvider.ts new file mode 100644 index 000000000..ca1b46508 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcp/prismaOAuthClientProvider.ts @@ -0,0 +1,309 @@ +import 'server-only'; +import type { + OAuthClientProvider, + OAuthClientInformation, + OAuthClientMetadata, + OAuthTokens, +} from '@ai-sdk/mcp'; +import { McpServerClientInfoSource, type PrismaClient } from '@sourcebot/db'; +import { encryptOAuthToken, decryptOAuthToken, createLogger } from '@sourcebot/shared'; +import { __unsafePrisma } from '@/prisma'; +import { createMcpOAuthState } from './mcpOAuthReturnTo'; + +type McpOAuthPrismaClient = Pick; +const logger = createLogger('mcp-oauth-client-provider'); + +interface PrismaOAuthClientProviderOptions { + prisma: McpOAuthPrismaClient; + serverId: string; + orgId: number; + userId: string; + callbackUrl: string; + callbackReturnTo?: string; + allowClientRegistration?: boolean; + clientInvalidationPrisma?: McpOAuthPrismaClient; +} + +export interface ClearMcpServerClientCredentialsOptions { + prisma?: McpOAuthPrismaClient; + serverId: string; + orgId: number; + observedClientInfo: string | undefined; +} + +export async function clearMcpServerClientCredentialsForObservedClient({ + prisma = __unsafePrisma, + serverId, + orgId, + observedClientInfo, +}: ClearMcpServerClientCredentialsOptions): Promise { + if (!observedClientInfo) { + return false; + } + + const result = await prisma.mcpServer.updateMany({ + where: { + id: serverId, + orgId, + clientInfo: observedClientInfo, + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + }, + data: { clientInfo: null }, + }); + + if (result.count === 0) { + return false; + } + + await prisma.userMcpServer.updateMany({ + where: { + serverId, + server: { orgId }, + }, + data: { + tokens: null, + tokensExpiresAt: null, + }, + }); + + return true; +} + +/** + * Prisma-backed OAuthClientProvider for connecting to external MCP servers. + * + * Stores dynamic client registration on McpServer (per-org), and per-user + * tokens + ephemeral PKCE state on UserMcpServer. + */ +export class PrismaOAuthClientProvider implements OAuthClientProvider { + private readonly prisma: McpOAuthPrismaClient; + private readonly clientInvalidationPrisma: McpOAuthPrismaClient; + private readonly serverId: string; + private readonly orgId: number; + private readonly userId: string; + private readonly callbackUrl: string; + private readonly callbackReturnTo: string | undefined; + private observedClientInfo: string | undefined; + private observedClientInfoSource: McpServerClientInfoSource | undefined; + + /** Populated by redirectToAuthorization — read after auth() returns 'REDIRECT'. */ + public authorizationUrl: string | undefined; + + /** Only present in connect mode. If absent, the SDK cannot perform DCR. */ + declare saveClientInformation?: (info: OAuthClientInformation) => Promise; + + constructor({ + prisma, + serverId, + orgId, + userId, + callbackUrl, + callbackReturnTo, + allowClientRegistration = false, + clientInvalidationPrisma = __unsafePrisma, + }: PrismaOAuthClientProviderOptions) { + this.prisma = prisma; + this.clientInvalidationPrisma = clientInvalidationPrisma; + this.serverId = serverId; + this.orgId = orgId; + this.userId = userId; + this.callbackUrl = callbackUrl; + this.callbackReturnTo = callbackReturnTo; + + if (allowClientRegistration) { + this.saveClientInformation = async (info: OAuthClientInformation) => { + const encrypted = encryptOAuthToken(JSON.stringify(info)); + if (!encrypted) { + throw new Error('Failed to encrypt OAuth client information'); + } + + const result = await this.prisma.mcpServer.updateMany({ + where: { id: this.serverId, orgId: this.orgId }, + data: { + clientInfo: encrypted, + clientInfoSource: McpServerClientInfoSource.DYNAMIC, + }, + }); + if (result.count === 0) { + throw new Error('MCP server not found'); + } + + this.observedClientInfo = encrypted; + this.observedClientInfoSource = McpServerClientInfoSource.DYNAMIC; + }; + } + } + + get redirectUrl(): string | URL { + return this.callbackUrl; + } + + get clientMetadata(): OAuthClientMetadata { + return { + redirect_uris: [this.callbackUrl], + client_name: 'Sourcebot', + grant_types: ['authorization_code', 'refresh_token'], + response_types: ['code'], + token_endpoint_auth_method: 'none', + }; + } + + async clientInformation(): Promise { + const server = await this.prisma.mcpServer.findFirst({ + where: { id: this.serverId, orgId: this.orgId }, + select: { + clientInfo: true, + clientInfoSource: true, + }, + }); + if (!server?.clientInfo) { + this.observedClientInfo = undefined; + this.observedClientInfoSource = undefined; + return undefined; + } + + this.observedClientInfo = server.clientInfo; + this.observedClientInfoSource = server.clientInfoSource; + const decrypted = decryptOAuthToken(server.clientInfo); + return decrypted ? JSON.parse(decrypted) : undefined; + } + + async tokens(): Promise { + const userServer = await this.getUserServer(); + if (!userServer?.tokens) { + return undefined; + } + + const decrypted = decryptOAuthToken(userServer.tokens); + return decrypted ? JSON.parse(decrypted) : undefined; + } + + async saveTokens(tokens: OAuthTokens): Promise { + const encrypted = encryptOAuthToken(JSON.stringify(tokens)); + if (!encrypted) { + throw new Error('Failed to encrypt OAuth tokens'); + } + + const tokensExpiresAt = tokens.expires_in + ? new Date(Date.now() + tokens.expires_in * 1000) + : null; + await this.updateUserServer({ tokens: encrypted, tokensExpiresAt }); + } + + async codeVerifier(): Promise { + const userServer = await this.getUserServer(); + if (!userServer?.codeVerifier) { + throw new Error('No code verifier found'); + } + + const decrypted = decryptOAuthToken(userServer.codeVerifier); + if (!decrypted) { + throw new Error('Failed to decrypt OAuth code verifier'); + } + + if (decrypted === userServer.codeVerifier) { + logger.warn('MCP OAuth code verifier was read without decryption; it may be plaintext from an older version.', { + serverId: this.serverId, + orgId: this.orgId, + userId: this.userId, + }); + } + + return decrypted; + } + + async saveCodeVerifier(codeVerifier: string): Promise { + const encrypted = encryptOAuthToken(codeVerifier); + if (!encrypted) { + throw new Error('Failed to encrypt OAuth code verifier'); + } + + await this.updateUserServer({ codeVerifier: encrypted }); + } + + async state(): Promise { + return createMcpOAuthState(crypto.randomUUID(), this.callbackReturnTo); + } + + async saveState(state: string): Promise { + await this.updateUserServer({ state }); + } + + async storedState(): Promise { + const userServer = await this.getUserServer(); + return userServer?.state ?? undefined; + } + + async redirectToAuthorization(url: URL): Promise { + // Force the OAuth provider to show a consent/login screen on every authorization. + // This prevents a stolen-session attack where an attacker signs into Sourcebot on + // a victim's machine and silently obtains the victim's provider tokens via an + // existing browser session. + url.searchParams.set('prompt', 'consent'); + + // Clear stale tokens before starting a new authorization flow so the UI reflects + // that the user needs to complete OAuth again. + await this.invalidateCredentials('tokens'); + + this.authorizationUrl = url.toString(); + } + + async invalidateCredentials( + scope: 'all' | 'client' | 'tokens' | 'verifier' | 'discovery', + ): Promise { + if (scope === 'discovery') { + return; + } + + if (scope === 'all' || scope === 'client') { + const didClearDynamicClient = await clearMcpServerClientCredentialsForObservedClient({ + prisma: this.clientInvalidationPrisma, + serverId: this.serverId, + orgId: this.orgId, + observedClientInfo: this.observedClientInfo, + }); + if ( + scope === 'all' && + !didClearDynamicClient && + this.observedClientInfoSource === McpServerClientInfoSource.STATIC + ) { + await this.updateUserServer({ tokens: null, tokensExpiresAt: null }); + } + } + + if (scope === 'tokens') { + await this.updateUserServer({ tokens: null, tokensExpiresAt: null }); + } + + if (scope === 'all' || scope === 'verifier') { + await this.updateUserServer({ codeVerifier: null, state: null }); + } + } + + private async getUserServer() { + return this.prisma.userMcpServer.findUnique({ + where: { + userId_serverId: { userId: this.userId, serverId: this.serverId }, + }, + select: { + tokens: true, + codeVerifier: true, + state: true, + }, + }); + } + + private async updateUserServer(data: { + tokens?: string | null; + tokensExpiresAt?: Date | null; + codeVerifier?: string | null; + state?: string | null; + }) { + await this.prisma.userMcpServer.update({ + where: { + userId_serverId: { userId: this.userId, serverId: this.serverId }, + }, + data, + }); + } +} diff --git a/packages/web/src/ee/features/chat/mcp/types.ts b/packages/web/src/ee/features/chat/mcp/types.ts index 0d1c099ae..1698fd471 100644 --- a/packages/web/src/ee/features/chat/mcp/types.ts +++ b/packages/web/src/ee/features/chat/mcp/types.ts @@ -25,7 +25,7 @@ export interface McpServerToolUsageSummary { export interface GetMcpConfigurationResponse { servers: McpConfigurationServer[]; allowedMode: McpConfigurationAllowedMode; - isOAuthAvailable: boolean; + isAskAgentAvailable: boolean; } export interface ToolSummary { diff --git a/packages/web/src/ee/features/chat/mcpServerIconContext.tsx b/packages/web/src/ee/features/chat/mcpServerIconContext.tsx new file mode 100644 index 000000000..94628f4a5 --- /dev/null +++ b/packages/web/src/ee/features/chat/mcpServerIconContext.tsx @@ -0,0 +1,10 @@ +'use client'; + +import { createContext, useContext } from 'react'; + +// Maps sanitized server name (e.g. "linear") to a favicon URL. +export type McpServerIconMap = Record; + +export const McpServerIconContext = createContext({}); + +export const useMcpServerIconMap = () => useContext(McpServerIconContext); diff --git a/packages/web/src/ee/features/chat/toolApprovalContext.tsx b/packages/web/src/ee/features/chat/toolApprovalContext.tsx new file mode 100644 index 000000000..d4379c394 --- /dev/null +++ b/packages/web/src/ee/features/chat/toolApprovalContext.tsx @@ -0,0 +1,9 @@ +'use client'; + +import { createContext, useContext } from 'react'; +import type { ChatAddToolApproveResponseFunction } from 'ai'; + +const ToolApprovalContext = createContext(null); + +export const ToolApprovalProvider = ToolApprovalContext.Provider; +export const useToolApproval = () => useContext(ToolApprovalContext); \ No newline at end of file diff --git a/packages/web/src/features/chat/tools.ts b/packages/web/src/ee/features/chat/tools.ts similarity index 97% rename from packages/web/src/features/chat/tools.ts rename to packages/web/src/ee/features/chat/tools.ts index 149748bff..aee1a5897 100644 --- a/packages/web/src/features/chat/tools.ts +++ b/packages/web/src/ee/features/chat/tools.ts @@ -12,7 +12,7 @@ import { } from "@/features/tools"; import { ToolContext } from "@/features/tools/types"; import { ToolUIPart } from "ai"; -import { SBChatMessageToolTypes } from "./types"; +import { SBChatMessageToolTypes } from "@/features/chat/types"; export const createTools = (context: ToolContext) => ({ [readFileDefinition.name]: toVercelAITool(readFileDefinition, context), diff --git a/packages/web/src/features/chat/useExtractReferences.test.ts b/packages/web/src/ee/features/chat/useExtractReferences.test.ts similarity index 95% rename from packages/web/src/features/chat/useExtractReferences.test.ts rename to packages/web/src/ee/features/chat/useExtractReferences.test.ts index 7208aba46..f02bec857 100644 --- a/packages/web/src/features/chat/useExtractReferences.test.ts +++ b/packages/web/src/ee/features/chat/useExtractReferences.test.ts @@ -1,7 +1,7 @@ import { expect, test } from 'vitest' import { renderHook } from '@testing-library/react'; import { useExtractReferences } from './useExtractReferences'; -import { getFileReferenceId } from './utils'; +import { getFileReferenceId } from '@/features/chat/utils'; import { TextUIPart } from 'ai'; test('useExtractReferences extracts file references from text content', () => { diff --git a/packages/web/src/features/chat/useExtractReferences.ts b/packages/web/src/ee/features/chat/useExtractReferences.ts similarity index 81% rename from packages/web/src/features/chat/useExtractReferences.ts rename to packages/web/src/ee/features/chat/useExtractReferences.ts index 45ef173d8..7ab3eebe0 100644 --- a/packages/web/src/features/chat/useExtractReferences.ts +++ b/packages/web/src/ee/features/chat/useExtractReferences.ts @@ -1,9 +1,9 @@ 'use client'; import { useMemo } from "react"; -import { FileReference } from "./types"; -import { FILE_REFERENCE_REGEX } from "./constants"; -import { createFileReference } from "./utils"; +import { FileReference } from "@/features/chat/types"; +import { FILE_REFERENCE_REGEX } from "@/features/chat/constants"; +import { createFileReference } from "@/features/chat/utils"; import { TextUIPart } from "ai"; export const useExtractReferences = (part?: TextUIPart) => { diff --git a/packages/web/src/features/chat/useMessagePairs.test.ts b/packages/web/src/ee/features/chat/useMessagePairs.test.ts similarity index 98% rename from packages/web/src/features/chat/useMessagePairs.test.ts rename to packages/web/src/ee/features/chat/useMessagePairs.test.ts index a44179916..77179d06a 100644 --- a/packages/web/src/features/chat/useMessagePairs.test.ts +++ b/packages/web/src/ee/features/chat/useMessagePairs.test.ts @@ -1,5 +1,5 @@ import { expect, test } from 'vitest' -import { SBChatMessage } from './types'; +import { SBChatMessage } from '@/features/chat/types'; import { useMessagePairs } from './useMessagePairs'; import { renderHook } from '@testing-library/react'; diff --git a/packages/web/src/features/chat/useMessagePairs.ts b/packages/web/src/ee/features/chat/useMessagePairs.ts similarity index 96% rename from packages/web/src/features/chat/useMessagePairs.ts rename to packages/web/src/ee/features/chat/useMessagePairs.ts index 36ce7097d..71fa6036a 100644 --- a/packages/web/src/features/chat/useMessagePairs.ts +++ b/packages/web/src/ee/features/chat/useMessagePairs.ts @@ -1,7 +1,7 @@ 'use client'; import { useMemo } from "react"; -import { SBChatMessage } from "./types"; +import { SBChatMessage } from "@/features/chat/types"; // Pairs user messages with the assistant's response. export const useMessagePairs = (messages: SBChatMessage[]): [SBChatMessage, SBChatMessage | undefined][] => { diff --git a/packages/web/src/features/chat/useTOCItems.ts b/packages/web/src/ee/features/chat/useTOCItems.ts similarity index 100% rename from packages/web/src/features/chat/useTOCItems.ts rename to packages/web/src/ee/features/chat/useTOCItems.ts diff --git a/packages/web/src/ee/features/lighthouse/actions.ts b/packages/web/src/ee/features/lighthouse/actions.ts index 4336baac4..df1b98c90 100644 --- a/packages/web/src/ee/features/lighthouse/actions.ts +++ b/packages/web/src/ee/features/lighthouse/actions.ts @@ -4,91 +4,14 @@ import { sew } from "@/middleware/sew"; import { withAuth } from "@/middleware/withAuth"; import { withMinimumOrgRole } from "@/middleware/withMinimumOrgRole"; import { OrgRole } from "@sourcebot/db"; -import { ServiceError, ServiceErrorException } from "@/lib/serviceError"; +import { ServiceError } from "@/lib/serviceError"; import { StatusCodes } from "http-status-codes"; import { ErrorCode } from "@/lib/errorCodes"; -import { encryptActivationCode, decryptActivationCode, env } from "@sourcebot/shared"; -import { syncWithLighthouse } from "@/ee/features/lighthouse/servicePing"; +import { decryptActivationCode, env } from "@sourcebot/shared"; +import { syncWithLighthouse } from "@/features/billing/servicePing"; import { isServiceError } from "@/lib/utils"; -import { revalidatePath } from "next/cache"; -import { captureEvent } from "@/lib/posthog"; -import { UpsellSource } from "@/lib/posthogEvents"; -import { client } from "./client"; -import { Invoice } from "./types"; -import { z } from "zod"; - -export const activateLicense = async (activationCode: string): Promise<{ success: boolean } | ServiceError> => sew(() => - withAuth(async ({ org, role, prisma }) => - withMinimumOrgRole(role, OrgRole.OWNER, async () => { - // Check if a license already exists - const existing = await prisma.license.findUnique({ - where: { orgId: org.id }, - }); - - if (existing) { - return { - statusCode: StatusCodes.CONFLICT, - errorCode: ErrorCode.UNEXPECTED_ERROR, - message: "A license already exists for this organization.", - } satisfies ServiceError; - } - - await prisma.license.create({ - data: { - orgId: org.id, - activationCode: encryptActivationCode(activationCode), - }, - }); - - try { - // Bind the activation code to this install. This is the only - // call that mutates the binding on the Lighthouse side; the - // subsequent ping is pure read. - const activateResult = await client.activate({ - activationCode, - installId: env.SOURCEBOT_INSTALL_ID, - }); - - if (isServiceError(activateResult)) { - throw new ServiceErrorException(activateResult); - } - - // Immediately sync license data from Lighthouse. - await syncWithLighthouse(org.id); - } catch (e) { - // If activation or initial sync fails, remove the license record - await prisma.license.delete({ - where: { orgId: org.id }, - }); - - throw e; - } - - // Invalidate the (app) layout so BannerSlot re-resolves with the - // new license. - revalidatePath('/settings/license', 'layout'); - - return { success: true }; - }) - ) -); - -export const claimActivationCode = async (sessionId: string): Promise<{ activationCode: string } | ServiceError> => sew(() => - withAuth(async ({ role }) => - withMinimumOrgRole(role, OrgRole.OWNER, async () => { - const result = await client.claimActivationCode({ - sessionId, - installId: env.SOURCEBOT_INSTALL_ID, - }); - - if (isServiceError(result)) { - return result; - } - - return { activationCode: result.activationCode }; - }) - ) -); +import { client } from "@/features/billing/client"; +import { Invoice } from "@/features/billing/types"; export const refreshLicense = async (): Promise<{ success: boolean } | ServiceError> => sew(() => withAuth(async ({ org, role, prisma }) => @@ -112,147 +35,6 @@ export const refreshLicense = async (): Promise<{ success: boolean } | ServiceEr ) ); -export const createCheckoutSession = async ({ - source, - requestTrial = false, - interval = 'year', - returnPath: _returnPath = '/settings/license', - overrideEmail, -}: { - source: UpsellSource; - requestTrial?: boolean; - interval?: 'month' | 'year'; - returnPath?: string; - overrideEmail?: string; -}): Promise<{ url: string } | ServiceError> => sew(() => - withAuth(async ({ user, org, role, prisma }) => - withMinimumOrgRole(role, OrgRole.OWNER, async () => { - if (!user.email) { - return { - statusCode: StatusCodes.BAD_REQUEST, - errorCode: ErrorCode.UNEXPECTED_ERROR, - message: "User does not have an email address.", - } satisfies ServiceError; - } - - // Validate the override on the server — never trust a client-supplied - // email. Fall back to the authenticated user's email when omitted. - let checkoutEmail = user.email; - if (overrideEmail !== undefined) { - const parsed = z.string().email().safeParse(overrideEmail); - if (!parsed.success) { - return { - statusCode: StatusCodes.BAD_REQUEST, - errorCode: ErrorCode.UNEXPECTED_ERROR, - message: "Invalid overrideEmail.", - } satisfies ServiceError; - } - checkoutEmail = parsed.data; - } - - const memberCount = await prisma.userToOrg.count({ - where: { - orgId: org.id, - }, - }); - const quantity = Math.max(memberCount, 1); - - const existingLicense = await prisma.license.findUnique({ - where: { orgId: org.id }, - }); - const existingActivationCode = existingLicense - ? decryptActivationCode(existingLicense.activationCode) - : undefined; - - // Resolve the candidate against AUTH_URL so absolute paths, protocol- - // relative paths (`//evil.com`), and bare relative paths all get - // normalized through the URL parser. Reject anything that lands off- - // origin, carries a fragment, or already uses the reserved query keys - // we append below. - let returnPathname: string; - let returnSearch: string; - try { - const candidate = new URL(_returnPath, env.AUTH_URL); - const authOrigin = new URL(env.AUTH_URL).origin; - if (candidate.origin !== authOrigin) { - throw new Error('returnPath escapes AUTH_URL origin'); - } - if (candidate.hash) { - throw new Error('returnPath must not include a fragment'); - } - for (const reservedKey of ['checkout', 'session_id']) { - if (candidate.searchParams.has(reservedKey)) { - throw new Error(`returnPath must not include reserved query parameter: ${reservedKey}`); - } - } - returnPathname = candidate.pathname; - returnSearch = candidate.search; - } catch { - return { - statusCode: StatusCodes.BAD_REQUEST, - errorCode: ErrorCode.UNEXPECTED_ERROR, - message: "Invalid returnPath.", - } satisfies ServiceError; - } - - await captureEvent('wa_upsell_checkout_started', { - source, - requestTrial, - interval, - returnPath: `${returnPathname}${returnSearch}`, - quantity, - }); - - // Build success/cancel URLs as raw strings so Stripe's literal - // `{CHECKOUT_SESSION_ID}` placeholder isn't URL-encoded by URL/ - // URLSearchParams (Stripe substitutes the raw token, not %7B...%7D). - const stripeSuccessQuery = 'checkout=success&session_id={CHECKOUT_SESSION_ID}'; - const successQuerySeparator = returnSearch ? '&' : '?'; - - const result = await client.checkout({ - email: checkoutEmail, - installId: env.SOURCEBOT_INSTALL_ID, - quantity, - requestTrial, - interval, - successUrl: `${env.AUTH_URL}${returnPathname}${returnSearch}${successQuerySeparator}${stripeSuccessQuery}`, - cancelUrl: `${env.AUTH_URL}${returnPathname}${returnSearch}`, - existingActivationCode, - }); - - if (isServiceError(result)) { - return result; - } - - return { url: result.url }; - }) - ) -); - -export const deactivateLicense = async (): Promise<{ success: boolean } | ServiceError> => sew(() => - withAuth(async ({ org, role, prisma }) => - withMinimumOrgRole(role, OrgRole.OWNER, async () => { - const existing = await prisma.license.findUnique({ - where: { orgId: org.id }, - }); - - if (!existing) { - return { - statusCode: StatusCodes.NOT_FOUND, - errorCode: ErrorCode.NOT_FOUND, - message: "No license found.", - } satisfies ServiceError; - } - - await prisma.license.delete({ - where: { orgId: org.id }, - }); - - return { success: true }; - }) - ) -); - export const createPortalSession = async (): Promise<{ url: string } | ServiceError> => sew(() => withAuth(async ({ org, role, prisma }) => withMinimumOrgRole(role, OrgRole.OWNER, async () => { diff --git a/packages/web/src/features/mcp/askCodebase.ts b/packages/web/src/ee/features/mcp/askCodebase.ts similarity index 91% rename from packages/web/src/features/mcp/askCodebase.ts rename to packages/web/src/ee/features/mcp/askCodebase.ts index 94bf4a3f1..520035435 100644 --- a/packages/web/src/features/mcp/askCodebase.ts +++ b/packages/web/src/ee/features/mcp/askCodebase.ts @@ -1,5 +1,7 @@ import { sew } from "@/middleware/sew"; -import { getConfiguredLanguageModels, getAISDKLanguageModelAndOptions, generateChatNameFromMessage, updateChatMessages } from "@/features/chat/utils.server"; +import { getConfiguredLanguageModels, updateChatMessages, checkAskEntitlement } from "@/features/chat/utils.server"; +import { generateChatNameFromMessage } from "@/ee/features/chat/llm.server"; +import { getAISDKLanguageModelAndOptions } from "@/features/chat/llm.server"; import { LanguageModelInfo, SBChatMessage, SearchScope } from "@/features/chat/types"; import { convertLLMOutputToPortableMarkdown, getAnswerPartFromAssistantMessage, getLanguageModelKey } from "@/features/chat/utils"; import { ErrorCode } from "@/lib/errorCodes"; @@ -12,7 +14,7 @@ import { StatusCodes } from "http-status-codes"; import { InferUIMessageChunk, UIDataTypes, UIMessage, UITools } from "ai"; import { captureEvent } from "@/lib/posthog"; import { createAudit } from "@/ee/features/audit/audit"; -import { createMessageStream } from "../chat/agent"; +import { createMessageStream } from "@/ee/features/chat/agent"; const logger = createLogger('ask-codebase-api'); @@ -44,6 +46,15 @@ const blockStreamUntilFinish = async => sew(() => withOptionalAuth(async ({ org, user, prisma }) => { + // Ask Sourcebot is a paid feature. askCodebase() is the single choke point + // for the programmatic ask path (the MCP `ask_codebase` tool and the + // /api/chat/blocking route both wrap it), so gating here covers both without + // double-gating at the tool-registration layer. + const askEntitlementError = await checkAskEntitlement(); + if (askEntitlementError) { + return askEntitlementError; + } + const { query, repos = [], languageModel: requestedLanguageModel, visibility: requestedVisibility, source } = params; const configuredModels = await getConfiguredLanguageModels(); diff --git a/packages/web/src/ee/features/mcp/constants.ts b/packages/web/src/ee/features/mcp/constants.ts new file mode 100644 index 000000000..95f7aa2aa --- /dev/null +++ b/packages/web/src/ee/features/mcp/constants.ts @@ -0,0 +1,8 @@ +export const MCP_DOCS_URL = "https://docs.sourcebot.dev/docs/features/mcp-server"; +export const PRICING_URL = "https://www.sourcebot.dev/pricing"; + +// Surfaced to MCP clients (and the programmatic blocking endpoint) when the +// instance is on the free plan. MCP clients render the agent-facing error +// text, so keep this human-readable and point at the upgrade path. +export const MCP_PAID_PLAN_REQUIRED_MESSAGE = + `The Sourcebot MCP server requires a paid subscription. Upgrade your plan at ${PRICING_URL} to enable it.`; diff --git a/packages/web/src/features/mcp/server.ts b/packages/web/src/ee/features/mcp/server.ts similarity index 90% rename from packages/web/src/features/mcp/server.ts rename to packages/web/src/ee/features/mcp/server.ts index 953953f14..1d3262a44 100644 --- a/packages/web/src/features/mcp/server.ts +++ b/packages/web/src/ee/features/mcp/server.ts @@ -1,15 +1,16 @@ import { languageModelInfoSchema, } from '@/features/chat/types'; -import { askCodebase } from '@/features/mcp/askCodebase'; +import { askCodebase } from '@/ee/features/mcp/askCodebase'; import { captureEvent } from '@/lib/posthog'; +import { hasEntitlement } from '@/lib/entitlements'; import { isServiceError } from '@/lib/utils'; import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js'; import { ChatVisibility } from '@sourcebot/db'; import { SOURCEBOT_VERSION } from '@sourcebot/shared'; import _dedent from 'dedent'; import { z } from 'zod'; -import { getConfiguredLanguageModelsInfo } from "../chat/utils.server"; +import { getConfiguredLanguageModelsInfo } from "@/features/chat/utils.server"; import { findSymbolDefinitionsDefinition, findSymbolReferencesDefinition, @@ -22,11 +23,19 @@ import { grepDefinition, ToolContext, globDefinition, -} from '../tools'; +} from '@/features/tools'; const dedent = _dedent.withOptions({ alignValues: true }); export async function createMcpServer(): Promise { + // Defense-in-depth: the MCP server is a paid feature. The /api/ee/mcp route + // gates on the `mcp` entitlement before calling this; this assertion + // backstops that contract so the server can't be constructed on a + // non-entitled deployment. + if (!(await hasEntitlement('mcp'))) { + throw new Error('The MCP server is not available in the current plan.'); + } + const server = new McpServer({ name: 'sourcebot-mcp-server', version: SOURCEBOT_VERSION, diff --git a/packages/web/src/features/mcp/types.ts b/packages/web/src/ee/features/mcp/types.ts similarity index 100% rename from packages/web/src/features/mcp/types.ts rename to packages/web/src/ee/features/mcp/types.ts diff --git a/packages/web/src/features/agents/review-agent/nodes/invokeDiffReviewLlm.ts b/packages/web/src/features/agents/review-agent/nodes/invokeDiffReviewLlm.ts index e43f137fe..915db6cec 100644 --- a/packages/web/src/features/agents/review-agent/nodes/invokeDiffReviewLlm.ts +++ b/packages/web/src/features/agents/review-agent/nodes/invokeDiffReviewLlm.ts @@ -1,5 +1,6 @@ import { sourcebot_file_diff_review, sourcebot_file_diff_review_schema } from "@/features/agents/review-agent/types"; -import { getAISDKLanguageModelAndOptions, getConfiguredLanguageModels } from "@/features/chat/utils.server"; +import { getConfiguredLanguageModels } from "@/features/chat/utils.server"; +import { getAISDKLanguageModelAndOptions } from "@/features/chat/llm.server"; import { env } from "@sourcebot/shared"; import { generateText } from "ai"; import fs from "fs"; diff --git a/packages/web/src/ee/features/lighthouse/CLAUDE.md b/packages/web/src/features/billing/CLAUDE.md similarity index 100% rename from packages/web/src/ee/features/lighthouse/CLAUDE.md rename to packages/web/src/features/billing/CLAUDE.md diff --git a/packages/web/src/features/billing/actions.ts b/packages/web/src/features/billing/actions.ts new file mode 100644 index 000000000..c9070b219 --- /dev/null +++ b/packages/web/src/features/billing/actions.ts @@ -0,0 +1,231 @@ +'use server'; + +import { sew } from "@/middleware/sew"; +import { withAuth } from "@/middleware/withAuth"; +import { withMinimumOrgRole } from "@/middleware/withMinimumOrgRole"; +import { OrgRole } from "@sourcebot/db"; +import { ServiceError, ServiceErrorException } from "@/lib/serviceError"; +import { StatusCodes } from "http-status-codes"; +import { ErrorCode } from "@/lib/errorCodes"; +import { encryptActivationCode, decryptActivationCode, env } from "@sourcebot/shared"; +import { syncWithLighthouse } from "./servicePing"; +import { isServiceError } from "@/lib/utils"; +import { revalidatePath } from "next/cache"; +import { captureEvent } from "@/lib/posthog"; +import { UpsellSource } from "@/lib/posthogEvents"; +import { client } from "./client"; +import { z } from "zod"; + +export const activateLicense = async (activationCode: string): Promise<{ success: boolean } | ServiceError> => sew(() => + withAuth(async ({ org, role, prisma }) => + withMinimumOrgRole(role, OrgRole.OWNER, async () => { + // Check if a license already exists + const existing = await prisma.license.findUnique({ + where: { orgId: org.id }, + }); + + if (existing) { + return { + statusCode: StatusCodes.CONFLICT, + errorCode: ErrorCode.UNEXPECTED_ERROR, + message: "A license already exists for this organization.", + } satisfies ServiceError; + } + + await prisma.license.create({ + data: { + orgId: org.id, + activationCode: encryptActivationCode(activationCode), + }, + }); + + try { + // Bind the activation code to this install. This is the only + // call that mutates the binding on the Lighthouse side; the + // subsequent ping is pure read. + const activateResult = await client.activate({ + activationCode, + installId: env.SOURCEBOT_INSTALL_ID, + }); + + if (isServiceError(activateResult)) { + throw new ServiceErrorException(activateResult); + } + + // Immediately sync license data from Lighthouse. + await syncWithLighthouse(org.id); + } catch (e) { + // If activation or initial sync fails, remove the license record + await prisma.license.delete({ + where: { orgId: org.id }, + }); + + throw e; + } + + // Invalidate the (app) layout so BannerSlot re-resolves with the + // new license. + revalidatePath('/settings/license', 'layout'); + + return { success: true }; + }) + ) +); + +export const claimActivationCode = async (sessionId: string): Promise<{ activationCode: string } | ServiceError> => sew(() => + withAuth(async ({ role }) => + withMinimumOrgRole(role, OrgRole.OWNER, async () => { + const result = await client.claimActivationCode({ + sessionId, + installId: env.SOURCEBOT_INSTALL_ID, + }); + + if (isServiceError(result)) { + return result; + } + + return { activationCode: result.activationCode }; + }) + ) +); + +export const createCheckoutSession = async ({ + source, + requestTrial = false, + interval = 'year', + returnPath: _returnPath = '/settings/license', + overrideEmail, +}: { + source: UpsellSource; + requestTrial?: boolean; + interval?: 'month' | 'year'; + returnPath?: string; + overrideEmail?: string; +}): Promise<{ url: string } | ServiceError> => sew(() => + withAuth(async ({ user, org, role, prisma }) => + withMinimumOrgRole(role, OrgRole.OWNER, async () => { + if (!user.email) { + return { + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.UNEXPECTED_ERROR, + message: "User does not have an email address.", + } satisfies ServiceError; + } + + // Validate the override on the server — never trust a client-supplied + // email. Fall back to the authenticated user's email when omitted. + let checkoutEmail = user.email; + if (overrideEmail !== undefined) { + const parsed = z.string().email().safeParse(overrideEmail); + if (!parsed.success) { + return { + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.UNEXPECTED_ERROR, + message: "Invalid overrideEmail.", + } satisfies ServiceError; + } + checkoutEmail = parsed.data; + } + + const memberCount = await prisma.userToOrg.count({ + where: { + orgId: org.id, + }, + }); + const quantity = Math.max(memberCount, 1); + + const existingLicense = await prisma.license.findUnique({ + where: { orgId: org.id }, + }); + const existingActivationCode = existingLicense + ? decryptActivationCode(existingLicense.activationCode) + : undefined; + + // Resolve the candidate against AUTH_URL so absolute paths, protocol- + // relative paths (`//evil.com`), and bare relative paths all get + // normalized through the URL parser. Reject anything that lands off- + // origin, carries a fragment, or already uses the reserved query keys + // we append below. + let returnPathname: string; + let returnSearch: string; + try { + const candidate = new URL(_returnPath, env.AUTH_URL); + const authOrigin = new URL(env.AUTH_URL).origin; + if (candidate.origin !== authOrigin) { + throw new Error('returnPath escapes AUTH_URL origin'); + } + if (candidate.hash) { + throw new Error('returnPath must not include a fragment'); + } + for (const reservedKey of ['checkout', 'session_id']) { + if (candidate.searchParams.has(reservedKey)) { + throw new Error(`returnPath must not include reserved query parameter: ${reservedKey}`); + } + } + returnPathname = candidate.pathname; + returnSearch = candidate.search; + } catch { + return { + statusCode: StatusCodes.BAD_REQUEST, + errorCode: ErrorCode.UNEXPECTED_ERROR, + message: "Invalid returnPath.", + } satisfies ServiceError; + } + + await captureEvent('wa_upsell_checkout_started', { + source, + requestTrial, + interval, + returnPath: `${returnPathname}${returnSearch}`, + quantity, + }); + + // Build success/cancel URLs as raw strings so Stripe's literal + // `{CHECKOUT_SESSION_ID}` placeholder isn't URL-encoded by URL/ + // URLSearchParams (Stripe substitutes the raw token, not %7B...%7D). + const stripeSuccessQuery = 'checkout=success&session_id={CHECKOUT_SESSION_ID}'; + const successQuerySeparator = returnSearch ? '&' : '?'; + + const result = await client.checkout({ + email: checkoutEmail, + installId: env.SOURCEBOT_INSTALL_ID, + quantity, + requestTrial, + interval, + successUrl: `${env.AUTH_URL}${returnPathname}${returnSearch}${successQuerySeparator}${stripeSuccessQuery}`, + cancelUrl: `${env.AUTH_URL}${returnPathname}${returnSearch}`, + existingActivationCode, + }); + + if (isServiceError(result)) { + return result; + } + + return { url: result.url }; + }) + ) +); + +export const deactivateLicense = async (): Promise<{ success: boolean } | ServiceError> => sew(() => + withAuth(async ({ org, role, prisma }) => + withMinimumOrgRole(role, OrgRole.OWNER, async () => { + const existing = await prisma.license.findUnique({ + where: { orgId: org.id }, + }); + + if (!existing) { + return { + statusCode: StatusCodes.NOT_FOUND, + errorCode: ErrorCode.NOT_FOUND, + message: "No license found.", + } satisfies ServiceError; + } + + await prisma.license.delete({ + where: { orgId: org.id }, + }); + + return { success: true }; + }) + ) +); diff --git a/packages/web/src/ee/features/lighthouse/checkoutDisclosures.tsx b/packages/web/src/features/billing/checkoutDisclosures.tsx similarity index 100% rename from packages/web/src/ee/features/lighthouse/checkoutDisclosures.tsx rename to packages/web/src/features/billing/checkoutDisclosures.tsx diff --git a/packages/web/src/ee/features/lighthouse/checkoutReturnHandler.tsx b/packages/web/src/features/billing/checkoutReturnHandler.tsx similarity index 100% rename from packages/web/src/ee/features/lighthouse/checkoutReturnHandler.tsx rename to packages/web/src/features/billing/checkoutReturnHandler.tsx diff --git a/packages/web/src/ee/features/lighthouse/client.ts b/packages/web/src/features/billing/client.ts similarity index 100% rename from packages/web/src/ee/features/lighthouse/client.ts rename to packages/web/src/features/billing/client.ts diff --git a/packages/web/src/ee/features/lighthouse/hasLicenseProvider.tsx b/packages/web/src/features/billing/hasLicenseProvider.tsx similarity index 100% rename from packages/web/src/ee/features/lighthouse/hasLicenseProvider.tsx rename to packages/web/src/features/billing/hasLicenseProvider.tsx diff --git a/packages/web/src/ee/features/lighthouse/licenseActivactionDialog.tsx b/packages/web/src/features/billing/licenseActivactionDialog.tsx similarity index 97% rename from packages/web/src/ee/features/lighthouse/licenseActivactionDialog.tsx rename to packages/web/src/features/billing/licenseActivactionDialog.tsx index a2f2debb2..257513c44 100644 --- a/packages/web/src/ee/features/lighthouse/licenseActivactionDialog.tsx +++ b/packages/web/src/features/billing/licenseActivactionDialog.tsx @@ -15,8 +15,8 @@ import { Input } from "@/components/ui/input"; import { Label } from "@/components/ui/label"; import { LoadingButton } from "@/components/ui/loading-button"; import { useToast } from "@/components/hooks/use-toast"; -import { activateLicense, deactivateLicense } from "@/ee/features/lighthouse/actions"; -import { useClaimActivationCode } from "@/ee/features/lighthouse/useClaimActivationCode"; +import { activateLicense, deactivateLicense } from "@/features/billing/actions"; +import { useClaimActivationCode } from "@/features/billing/useClaimActivationCode"; import { isServiceError } from "@/lib/utils"; import { useHasLicense } from "./hasLicenseProvider"; diff --git a/packages/web/src/ee/features/lighthouse/planComparisonTable.tsx b/packages/web/src/features/billing/planComparisonTable.tsx similarity index 94% rename from packages/web/src/ee/features/lighthouse/planComparisonTable.tsx rename to packages/web/src/features/billing/planComparisonTable.tsx index 7f5e6634c..d26bb1758 100644 --- a/packages/web/src/ee/features/lighthouse/planComparisonTable.tsx +++ b/packages/web/src/features/billing/planComparisonTable.tsx @@ -11,7 +11,7 @@ import { TableRow, } from "@/components/ui/table"; import { cn, formatCurrency } from "@/lib/utils"; -import { OffersResponse } from "@/ee/features/lighthouse/types"; +import { OffersResponse } from "@/features/billing/types"; interface FeatureLinkProps { text: string; @@ -123,6 +123,11 @@ export function PlanComparisonTable({ + + + + + diff --git a/packages/web/src/ee/features/lighthouse/servicePing.ts b/packages/web/src/features/billing/servicePing.ts similarity index 100% rename from packages/web/src/ee/features/lighthouse/servicePing.ts rename to packages/web/src/features/billing/servicePing.ts diff --git a/packages/web/src/ee/features/lighthouse/types.ts b/packages/web/src/features/billing/types.ts similarity index 100% rename from packages/web/src/ee/features/lighthouse/types.ts rename to packages/web/src/features/billing/types.ts diff --git a/packages/web/src/ee/features/lighthouse/upsellDialog.tsx b/packages/web/src/features/billing/upsellDialog.tsx similarity index 89% rename from packages/web/src/ee/features/lighthouse/upsellDialog.tsx rename to packages/web/src/features/billing/upsellDialog.tsx index 5984d64f9..0685d0e85 100644 --- a/packages/web/src/ee/features/lighthouse/upsellDialog.tsx +++ b/packages/web/src/features/billing/upsellDialog.tsx @@ -11,11 +11,11 @@ import { } from "@/components/ui/dialog"; import { LoadingButton } from "@/components/ui/loading-button"; import { Skeleton } from "@/components/ui/skeleton"; -import { createCheckoutSession } from "@/ee/features/lighthouse/actions"; -import { useHasLicense } from "@/ee/features/lighthouse/hasLicenseProvider"; -import { BillingInterval, PlanComparisonTable } from "@/ee/features/lighthouse/planComparisonTable"; -import { OffersResponse } from "@/ee/features/lighthouse/types"; -import { useOffers } from "@/ee/features/lighthouse/useOffers"; +import { createCheckoutSession } from "@/features/billing/actions"; +import { useHasLicense } from "@/features/billing/hasLicenseProvider"; +import { BillingInterval, PlanComparisonTable } from "@/features/billing/planComparisonTable"; +import { OffersResponse } from "@/features/billing/types"; +import { useOffers } from "@/features/billing/useOffers"; import { useRole } from "@/features/auth/useRole"; import useCaptureEvent from "@/hooks/useCaptureEvent"; import { UpsellSource } from "@/lib/posthogEvents"; @@ -23,7 +23,7 @@ import { cn, isServiceError } from "@/lib/utils"; import { OrgRole } from "@sourcebot/db"; import { ArrowUpCircle, ExternalLink, Loader2 } from "lucide-react"; import { useSession } from "next-auth/react"; -import { useCallback, useEffect, useMemo, useState } from "react"; +import { ReactNode, useCallback, useEffect, useMemo, useState } from "react"; import { CheckoutDisclosures } from "./checkoutDisclosures"; interface UpsellDialogProps { @@ -88,9 +88,13 @@ interface UpsellPanelProps { returnPath?: string; className?: string; licenseState?: UpsellLicenseState; + // Optional context-specific heading + subheading (e.g. "Upgrade to view Ask + // Sourcebot history"). Fall back to the billing-state-derived copy when omitted. + title?: string; + description?: ReactNode; } -export function UpsellPanel({ source, returnPath, className, licenseState = 'free' }: UpsellPanelProps) { +export function UpsellPanel({ source, returnPath, className, licenseState = 'free', title, description }: UpsellPanelProps) { const { data: offers, isPending, isError } = useOffers(); if (isError) { @@ -116,7 +120,7 @@ export function UpsellPanel({ source, returnPath, className, licenseState = 'fre return (
- +
); } @@ -127,9 +131,11 @@ interface UpsellPanelContentProps { returnPath?: string; variant: "dialog" | "inline"; licenseState: UpsellLicenseState; + titleOverride?: string; + descriptionOverride?: ReactNode; } -function UpsellPanelContent({ offers, source, returnPath, variant, licenseState }: UpsellPanelContentProps) { +function UpsellPanelContent({ offers, source, returnPath, variant, licenseState, titleOverride, descriptionOverride }: UpsellPanelContentProps) { const [billingInterval, setBillingInterval] = useState("year"); const [isCheckoutSessionCreating, setIsCheckoutSessionCreating] = useState(false); const { data: session } = useSession(); @@ -193,7 +199,7 @@ function UpsellPanelContent({ offers, source, returnPath, variant, licenseState }) }, [billingInterval, offers.trial.eligible, returnPath, toast, source, overrideEmail]); - const { title, description, buttonText } = useMemo(() => { + const { title: computedTitle, description: computedDescription, buttonText } = useMemo(() => { // Members can't upgrade the workspace themselves — show them the feature // table so they understand what's gated, but route them to the owner. if (!isOwner) { @@ -236,6 +242,11 @@ function UpsellPanelContent({ offers, source, returnPath, variant, licenseState } }, [isOwner, offers.trial.creditCardRequired, offers.trial.durationDays, offers.trial.eligible, licenseState]); + // Allow callers to override the heading/subheading with feature-specific + // context while keeping the billing-state-derived button text. + const title = titleOverride ?? computedTitle; + const description = descriptionOverride ?? computedDescription; + return ( <>
diff --git a/packages/web/src/ee/features/lighthouse/useClaimActivationCode.ts b/packages/web/src/features/billing/useClaimActivationCode.ts similarity index 100% rename from packages/web/src/ee/features/lighthouse/useClaimActivationCode.ts rename to packages/web/src/features/billing/useClaimActivationCode.ts diff --git a/packages/web/src/ee/features/lighthouse/useOffers.ts b/packages/web/src/features/billing/useOffers.ts similarity index 100% rename from packages/web/src/ee/features/lighthouse/useOffers.ts rename to packages/web/src/features/billing/useOffers.ts diff --git a/packages/web/src/features/chat/actions.ts b/packages/web/src/features/chat/actions.ts index c23f2404a..fccc6df8a 100644 --- a/packages/web/src/features/chat/actions.ts +++ b/packages/web/src/features/chat/actions.ts @@ -3,17 +3,20 @@ import { sew } from "@/middleware/sew"; import { createAudit } from "@/ee/features/audit/audit"; import { getAnonymousId, getOrCreateAnonymousId } from "@/lib/anonymousId"; -import { ErrorCode } from "@/lib/errorCodes"; import { captureEvent } from "@/lib/posthog"; -import { notFound, ServiceError } from "@/lib/serviceError"; +import { notFound } from "@/lib/serviceError"; import { withAuth, withOptionalAuth } from "@/middleware/withAuth"; import { ChatVisibility, Prisma } from "@sourcebot/db"; -import { StatusCodes } from "http-status-codes"; import { SBChatMessage } from "./types"; -import { generateChatNameFromMessage, getConfiguredLanguageModels, isChatSharedWithUser, isOwnerOfChat } from "./utils.server"; +import { checkAskEntitlement, isChatSharedWithUser, isOwnerOfChat } from "./utils.server"; export const createChat = async ({ source }: { source?: string } = {}) => sew(() => withOptionalAuth(async ({ org, user, prisma }) => { + const askError = await checkAskEntitlement(); + if (askError) { + return askError; + } + const isGuestUser = user === undefined; // For anonymous users, get or create an anonymous ID to track ownership @@ -128,6 +131,11 @@ export const updateChatName = async ({ chatId, name }: { chatId: string, name: s export const updateChatVisibility = async ({ chatId, visibility }: { chatId: string, visibility: ChatVisibility }) => sew(() => withAuth(async ({ org, user, prisma }) => { + const askError = await checkAskEntitlement(); + if (askError) { + return askError; + } + const chat = await prisma.chat.findUnique({ where: { id: chatId, @@ -168,54 +176,6 @@ export const updateChatVisibility = async ({ chatId, visibility }: { chatId: str }) ); -export const generateAndUpdateChatNameFromMessage = async ({ chatId, languageModelId, message }: { chatId: string, languageModelId: string, message: string }) => sew(() => - withOptionalAuth(async ({ prisma, user, org }) => { - const chat = await prisma.chat.findUnique({ - where: { - id: chatId, - orgId: org.id, - }, - }); - - if (!chat) { - return notFound(); - } - - const isOwner = await isOwnerOfChat(chat, user); - if (!isOwner) { - return notFound(); - } - - const languageModelConfig = - (await getConfiguredLanguageModels()) - .find((model) => model.model === languageModelId); - - if (!languageModelConfig) { - return { - statusCode: StatusCodes.BAD_REQUEST, - errorCode: ErrorCode.INVALID_REQUEST_BODY, - message: `Language model ${languageModelId} is not configured.`, - } satisfies ServiceError; - } - - const name = await generateChatNameFromMessage({ message, languageModelConfig }); - - await prisma.chat.update({ - where: { - id: chatId, - orgId: org.id, - }, - data: { - name: name, - }, - }) - - return { - success: true, - } - }) -) - export const deleteChat = async ({ chatId }: { chatId: string }) => sew(() => withAuth(async ({ org, user, prisma }) => { const chat = await prisma.chat.findUnique({ @@ -295,6 +255,11 @@ export const claimAnonymousChats = async () => sew(() => */ export const duplicateChat = async ({ chatId, newName }: { chatId: string, newName: string }) => sew(() => withOptionalAuth(async ({ org, user, prisma }) => { + const askError = await checkAskEntitlement(); + if (askError) { + return askError; + } + const originalChat = await prisma.chat.findUnique({ where: { id: chatId, @@ -338,6 +303,11 @@ export const duplicateChat = async ({ chatId, newName }: { chatId: string, newNa */ export const getSharedWithUsersForChat = async ({ chatId }: { chatId: string }) => sew(() => withAuth(async ({ org, user, prisma }) => { + const askError = await checkAskEntitlement(); + if (askError) { + return askError; + } + const chat = await prisma.chat.findUnique({ where: { id: chatId, @@ -377,6 +347,11 @@ export const getSharedWithUsersForChat = async ({ chatId }: { chatId: string }) */ export const shareChatWithUsers = async ({ chatId, userIds }: { chatId: string, userIds: string[] }) => sew(() => withAuth(async ({ org, user, prisma }) => { + const askError = await checkAskEntitlement(); + if (askError) { + return askError; + } + const chat = await prisma.chat.findUnique({ where: { id: chatId, @@ -432,6 +407,11 @@ export const shareChatWithUsers = async ({ chatId, userIds }: { chatId: string, */ export const unshareChatWithUser = async ({ chatId, userId }: { chatId: string, userId: string }) => sew(() => withAuth(async ({ org, user, prisma }) => { + const askError = await checkAskEntitlement(); + if (askError) { + return askError; + } + const chat = await prisma.chat.findUnique({ where: { id: chatId, diff --git a/packages/web/src/features/chat/components/chatBox/chatBox.tsx b/packages/web/src/features/chat/components/chatBox/chatBox.tsx index 51d9a3f45..ed9f46153 100644 --- a/packages/web/src/features/chat/components/chatBox/chatBox.tsx +++ b/packages/web/src/features/chat/components/chatBox/chatBox.tsx @@ -26,7 +26,7 @@ import { usePathname } from "next/navigation"; import { PENDING_CHAT_SUBMISSION_SESSION_STORAGE_KEY } from "@/features/chat/constants"; import useCaptureEvent from "@/hooks/useCaptureEvent"; import { useHasEntitlement } from "@/features/entitlements/useHasEntitlement"; -import { UpsellDialog } from "@/ee/features/lighthouse/upsellDialog"; +import { UpsellDialog } from "@/features/billing/upsellDialog"; interface ChatBoxProps { onSubmit: (children: Descendant[], editor: CustomEditor) => void; diff --git a/packages/web/src/features/chat/components/chatBox/chatBoxPlusButton.tsx b/packages/web/src/features/chat/components/chatBox/chatBoxPlusButton.tsx index 1f09388e9..aea9e849b 100644 --- a/packages/web/src/features/chat/components/chatBox/chatBoxPlusButton.tsx +++ b/packages/web/src/features/chat/components/chatBox/chatBoxPlusButton.tsx @@ -1,38 +1,9 @@ 'use client'; -import { Button } from "@/components/ui/button"; -import { - DropdownMenu, - DropdownMenuContent, - DropdownMenuItem, - DropdownMenuSeparator, - DropdownMenuSub, - DropdownMenuSubContent, - DropdownMenuSubTrigger, - DropdownMenuTrigger, -} from "@/components/ui/dropdown-menu"; -import { Switch } from "@/components/ui/switch"; -import { connectMcpToAsk, getMcpServersWithStatus } from "@/app/api/(client)/client"; -import { useToast } from "@/components/hooks/use-toast"; -import { McpFavicon } from "@/ee/features/chat/mcp/components/mcpFavicon"; -import { mcpQueryKeys } from "@/ee/features/chat/mcp/queryKeys"; -import { isServiceError } from "@/lib/utils"; -import { useQuery, useQueryClient } from "@tanstack/react-query"; -import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; -import { AlertTriangleIcon, CableIcon, Loader2Icon, PlusCircleIcon, PlusIcon, RefreshCwIcon, SettingsIcon } from "lucide-react"; -import { PlusButtonInfoCard } from "./plusButtonInfoCard"; -import { useRouter } from "next/navigation"; -import { useEffect, useRef, useState } from "react"; -import { useSlate } from "slate-react"; -import { Editor } from "slate"; -import type { CustomEditor, SearchScope } from "@/features/chat/types"; -import { - clearMcpOAuthDraft, - consumeMcpOAuthDraftForPath, - createMcpOAuthDraftPath, - saveMcpOAuthDraft, -} from "@/features/chat/mcpOAuthDraft"; -import { clearEditorHistory, resetEditor } from "@/features/chat/utils"; +import type { SearchScope } from "@/features/chat/types"; +import { useHasEntitlement } from "@/features/entitlements/useHasEntitlement"; +import { ConnectorsMenu } from "@/ee/features/chat/mcp/components/connectorsMenu"; +import { ConnectorsExplainerMenu } from "./connectorsExplainerMenu"; interface ChatBoxPlusButtonProps { selectedSearchScopes: SearchScope[]; @@ -41,272 +12,18 @@ interface ChatBoxPlusButtonProps { onDisabledMcpServerIdsChange: (ids: string[]) => void; } -interface ChatMenuMcpServer { - isConnected: boolean; - isAuthExpired: boolean; -} - -export function splitMcpServersForChatMenu(servers: T[]) { - return { - connectedServers: servers.filter((server) => server.isConnected || server.isAuthExpired), - connectableServers: servers.filter((server) => !server.isConnected && !server.isAuthExpired), - }; -} - -function restoreEditorChildren(editor: CustomEditor, children: CustomEditor['children']) { - editor.children = children; - editor.selection = { - anchor: Editor.end(editor, []), - focus: Editor.end(editor, []), - }; - clearEditorHistory(editor); - editor.onChange(); -} - -export const ChatBoxPlusButton = ({ - selectedSearchScopes, - onSelectedSearchScopesChange, - disabledMcpServerIds, - onDisabledMcpServerIdsChange, -}: ChatBoxPlusButtonProps) => { - const [connectingServerId, setConnectingServerId] = useState(null); - const editor = useSlate(); - const hasRestoredMcpOAuthDraft = useRef(false); - const isMountedRef = useRef(false); - const queryClient = useQueryClient(); - const router = useRouter(); - const { toast } = useToast(); - - const { data: servers = [], isError, isLoading, refetch } = useQuery({ - queryKey: mcpQueryKeys.serversWithStatus, - queryFn: async () => { - const result = await getMcpServersWithStatus(); - if (isServiceError(result)) { - throw new Error("Failed to load connectors"); - } - return result; - }, - }); - - useEffect(() => { - isMountedRef.current = true; - - return () => { - isMountedRef.current = false; - }; - }, []); - - useEffect(() => { - if (hasRestoredMcpOAuthDraft.current) { - return; - } - - const currentPath = createMcpOAuthDraftPath(window.location.pathname, window.location.search); - if (!currentPath) { - return; - } - - const draft = consumeMcpOAuthDraftForPath(currentPath); - if (!draft) { - return; - } - - hasRestoredMcpOAuthDraft.current = true; - - try { - restoreEditorChildren(editor, draft.children); - onSelectedSearchScopesChange(draft.selectedSearchScopes); - onDisabledMcpServerIdsChange(draft.disabledMcpServerIds); - } catch (error) { - resetEditor(editor); - editor.onChange(); - console.error('Failed to restore MCP OAuth draft:', error); - } - }, [editor, onDisabledMcpServerIdsChange, onSelectedSearchScopesChange]); - - const onToggle = (serverId: string, checked: boolean) => { - if (checked) { - onDisabledMcpServerIdsChange(disabledMcpServerIds.filter((id) => id !== serverId)); - } else { - onDisabledMcpServerIdsChange([...disabledMcpServerIds, serverId]); - } - }; - - const handleConnect = async (serverId: string) => { - setConnectingServerId(serverId); - const returnTo = createMcpOAuthDraftPath(window.location.pathname, window.location.search) ?? '/chat'; - - saveMcpOAuthDraft({ - returnTo, - children: editor.children, - selectedSearchScopes, - disabledMcpServerIds, - }); - - try { - const result = await connectMcpToAsk({ - serverId, - returnTo, - }); - - if (!isMountedRef.current) { - return; - } - - if (isServiceError(result)) { - clearMcpOAuthDraft(); - toast({ - description: `Failed to connect connector. ${result.message}`, - variant: "destructive", - }); - setConnectingServerId(null); - return; - } - - if (result.authorizationUrl) { - window.location.href = result.authorizationUrl; - return; - } - - clearMcpOAuthDraft(); - toast({ description: 'Connector is already connected.' }); - await queryClient.invalidateQueries({ queryKey: mcpQueryKeys.serversWithStatus }); - if (!isMountedRef.current) { - return; - } - setConnectingServerId(null); - } catch { - if (!isMountedRef.current) { - return; - } - - clearMcpOAuthDraft(); - toast({ - description: "Failed to connect connector.", - variant: "destructive", - }); - setConnectingServerId(null); - return; - } - }; +/** + * Entitlement-aware "+" button for the chat box. The connector machinery lives + * in ee/ and only ever renders/runs when the `ask` entitlement is present; + * free-plan users render the FSL explainer instead. Static-importing the ee/ + * component is fine — it is only invoked behind the entitlement check below. + */ +export const ChatBoxPlusButton = (props: ChatBoxPlusButtonProps) => { + const hasAskEntitlement = useHasEntitlement('ask'); - const { connectedServers, connectableServers } = splitMcpServersForChatMenu(servers); - const hasServers = connectedServers.length > 0 || connectableServers.length > 0; + if (hasAskEntitlement) { + return ; + } - return ( - - - - - - - - - - - - e.preventDefault()}> - - - - Connectors - - - {isError && !hasServers ? ( - { - e.preventDefault(); - refetch(); - }} - className="gap-2 text-destructive" - > - - Failed to load. Retry? - - ) : isLoading ? ( - - Loading connectors... - - ) : !hasServers ? ( - - No connectors available - - ) : ( - <> - {connectedServers.map((server) => { - const isEnabled = !server.isAuthExpired && !disabledMcpServerIds.includes(server.id); - return ( - e.preventDefault()} - disabled={server.isAuthExpired} - className="flex items-center justify-between gap-2" - > -
- {server.isAuthExpired ? ( - - ) : ( - - )} - {server.name} -
- onToggle(server.id, checked)} - disabled={server.isAuthExpired} - className="scale-75" - /> -
- ); - })} - {connectedServers.length > 0 && connectableServers.length > 0 && } - {connectableServers.map((server) => ( - { - e.preventDefault(); - void handleConnect(server.id); - }} - disabled={connectingServerId !== null} - className="group flex cursor-pointer items-center justify-between gap-2" - > -
- - {server.name} -
- {connectingServerId === server.id ? ( - - ) : ( - - )} -
- ))} - - )} - - router.push(`/settings/accountAskAgent`)} - > - - My connectors - - router.push(`/settings/workspaceAskAgent`)} - > - - Workspace connectors - -
-
-
-
- ); + return ; }; diff --git a/packages/web/src/features/chat/components/chatBox/connectorsExplainerMenu.tsx b/packages/web/src/features/chat/components/chatBox/connectorsExplainerMenu.tsx new file mode 100644 index 000000000..2a6f566ec --- /dev/null +++ b/packages/web/src/features/chat/components/chatBox/connectorsExplainerMenu.tsx @@ -0,0 +1,77 @@ +'use client'; + +import { useState } from "react"; +import { Button } from "@/components/ui/button"; +import { + DropdownMenu, + DropdownMenuContent, + DropdownMenuLabel, + DropdownMenuTrigger, +} from "@/components/ui/dropdown-menu"; +import { Tooltip, TooltipContent, TooltipTrigger } from "@/components/ui/tooltip"; +import { CableIcon, PlusIcon } from "lucide-react"; +import { PlusButtonInfoCard } from "./plusButtonInfoCard"; +import { UpsellDialog } from "@/features/billing/upsellDialog"; + +// TODO(ask): finalize the connectors docs URL once the page exists. +const CONNECTORS_DOCS_URL = "https://docs.sourcebot.dev/docs/features/ask/connectors"; + +/** + * Free-plan stand-in for the connectors menu. This is intentionally NOT in ee/: + * unlicensed users only ever render this explainer, never the real connector + * machinery (which lives in ee/ and runs solely behind the `mcp` entitlement). + * The "+" button stays visible so the feature is still discoverable, and the + * "paid plan" link opens the shared trial/upgrade dialog. + */ +export const ConnectorsExplainerMenu = () => { + const [isMenuOpen, setIsMenuOpen] = useState(false); + const [isUpsellOpen, setIsUpsellOpen] = useState(false); + + const openUpsell = () => { + // Close the dropdown first, then open the dialog on the next frame so the + // menu's overlay/pointer-events cleanup finishes before the dialog's focus + // trap mounts (avoids a Radix stacked-overlay race). + setIsMenuOpen(false); + requestAnimationFrame(() => setIsUpsellOpen(true)); + }; + + return ( + <> + + + + + + + + + + + + e.preventDefault()}> + + + Connectors + +

+ Connect external tools like Linear or Jira so the agent can pull in context beyond your code. Connectors are available on a{" "} + . Learn more +

+
+
+ + + ); +}; diff --git a/packages/web/src/features/chat/components/chatEntitlementMessage.tsx b/packages/web/src/features/chat/components/chatEntitlementMessage.tsx new file mode 100644 index 000000000..e59c91027 --- /dev/null +++ b/packages/web/src/features/chat/components/chatEntitlementMessage.tsx @@ -0,0 +1,39 @@ +"use client" + +import { ReactNode } from "react" +import { UpsellPanel } from "@/features/billing/upsellDialog" +import { UpsellSource } from "@/lib/posthogEvents" + +interface ChatEntitlementMessageProps { + source?: UpsellSource; + /** Context-specific heading (e.g. "Upgrade to view Ask Sourcebot history"). */ + title?: string; + /** Context-specific subheading describing the value (avoid repeating "Upgrade"). */ + description?: ReactNode; + returnPath?: string; +} + +/** + * Shown in place of the Ask experience when the deployment is not on a plan that + * includes Ask Sourcebot. This is FSL (not ee/) so it can render for free-plan + * users as the upsell surface, and it renders the shared feature-breakdown panel + * (plan comparison + trial/upgrade) without mounting any ee/ feature code. + */ +export function ChatEntitlementMessage({ + source = "chat", + title = "Upgrade to use Ask Sourcebot", + description = "Ask questions about your codebase and get answers with cited sources.", + returnPath = "/chat", +}: ChatEntitlementMessageProps) { + return ( +
+ +
+ ) +} diff --git a/packages/web/src/features/chat/llm.server.ts b/packages/web/src/features/chat/llm.server.ts new file mode 100644 index 000000000..38069dabb --- /dev/null +++ b/packages/web/src/features/chat/llm.server.ts @@ -0,0 +1,345 @@ +import 'server-only'; + +import { createPostHogClient, tryGetPostHogDistinctId } from "@/lib/posthog"; +import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock'; +import { AnthropicProviderOptions, createAnthropic } from '@ai-sdk/anthropic'; +import { createAzure } from '@ai-sdk/azure'; +import { createDeepSeek } from '@ai-sdk/deepseek'; +import { createGoogleGenerativeAI, GoogleLanguageModelOptions } from '@ai-sdk/google'; +import { createVertex } from '@ai-sdk/google-vertex'; +import { createVertexAnthropic } from '@ai-sdk/google-vertex/anthropic'; +import { createMistral } from '@ai-sdk/mistral'; +import { createOpenAI, OpenAIResponsesProviderOptions } from "@ai-sdk/openai"; +import { createOpenAICompatible } from "@ai-sdk/openai-compatible"; +import { LanguageModelV3 as AISDKLanguageModelV3 } from "@ai-sdk/provider"; +import { createXai } from '@ai-sdk/xai'; +import { fromNodeProviderChain } from '@aws-sdk/credential-providers'; +import { createOpenRouter } from '@openrouter/ai-sdk-provider'; +import { withTracing } from "@posthog/ai"; +import { LanguageModel } from '@sourcebot/schemas/v3/languageModel.type'; +import { Token } from "@sourcebot/schemas/v3/shared.type"; +import { env, getTokenFromConfig } from '@sourcebot/shared'; +import { extractReasoningMiddleware, JSONValue, wrapLanguageModel } from "ai"; + +// @note: This module resolves a configured language model into an AI SDK +// provider object. It is intentionally FSL (open source) provider plumbing — +// it contains no Ask-specific logic and is shared by multiple features (the +// Ask chat agent, the MCP `ask_codebase` tool, AI search-assist, and the +// review agent). The re-licensed Ask logic (prompts, tools, threads, chat +// name generation) lives in `@/ee/features/chat`. + +export const getAISDKLanguageModelAndOptions = async (config: LanguageModel): Promise<{ + model: AISDKLanguageModelV3, + providerOptions?: Record>, + temperature?: number, +}> => { + const { provider, model: modelId } = config; + + const { model: _model, providerOptions } = await (async (): Promise<{ + model: AISDKLanguageModelV3, + providerOptions?: Record>, + }> => { + switch (provider) { + case 'amazon-bedrock': { + const aws = createAmazonBedrock({ + baseURL: config.baseUrl, + region: config.region ?? env.AWS_REGION, + accessKeyId: config.accessKeyId + ? await getTokenFromConfig(config.accessKeyId) + : env.AWS_ACCESS_KEY_ID, + secretAccessKey: config.accessKeySecret + ? await getTokenFromConfig(config.accessKeySecret) + : env.AWS_SECRET_ACCESS_KEY, + sessionToken: config.sessionToken + ? await getTokenFromConfig(config.sessionToken) + : env.AWS_SESSION_TOKEN, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + // Fallback to the default Node.js credential provider chain if no credentials are provided. + // See: https://docs.aws.amazon.com/AWSJavaScriptSDK/v3/latest/Package/-aws-sdk-credential-providers/#fromnodeproviderchain + credentialProvider: !config.accessKeyId && !config.accessKeySecret && !config.sessionToken + ? fromNodeProviderChain() + : undefined, + }); + + return { + model: aws(modelId), + }; + } + case 'anthropic': { + const anthropic = createAnthropic({ + baseURL: config.baseUrl, + apiKey: config.token + ? await getTokenFromConfig(config.token) + : env.ANTHROPIC_API_KEY, + authToken: config.authToken + ? await getTokenFromConfig(config.authToken) + : env.ANTHROPIC_AUTH_TOKEN, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + const isAdaptiveThinkingSupported = + modelId.startsWith('claude-opus-4-7'); + + return { + model: anthropic(modelId), + providerOptions: { + anthropic: { + thinking: isAdaptiveThinkingSupported ? { + type: "adaptive", + display: "summarized" + } : { + type: "enabled", + budgetTokens: env.ANTHROPIC_THINKING_BUDGET_TOKENS, + } + } satisfies AnthropicProviderOptions, + }, + }; + } + case 'azure': { + const azure = createAzure({ + baseURL: config.baseUrl, + apiKey: config.token ? (await getTokenFromConfig(config.token)) : env.AZURE_API_KEY, + apiVersion: config.apiVersion, + resourceName: config.resourceName ?? env.AZURE_RESOURCE_NAME, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + const reasoningSummary = config.reasoningSummary ?? 'auto'; + return { + model: azure(modelId), + providerOptions: { + openai: { + reasoningEffort: config.reasoningEffort ?? 'medium', + ...(reasoningSummary !== 'none' && { reasoningSummary }), + } satisfies OpenAIResponsesProviderOptions, + } + }; + } + case 'deepseek': { + const deepseek = createDeepSeek({ + baseURL: config.baseUrl, + apiKey: config.token ? (await getTokenFromConfig(config.token)) : env.DEEPSEEK_API_KEY, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + return { + model: deepseek(modelId), + }; + } + case 'google-generative-ai': { + const google = createGoogleGenerativeAI({ + baseURL: config.baseUrl, + apiKey: config.token + ? await getTokenFromConfig(config.token) + : env.GOOGLE_GENERATIVE_AI_API_KEY, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + return { + model: google(modelId), + providerOptions: { + google: { + thinkingConfig: { + includeThoughts: true, + thinkingBudget: config.thinkingBudget, + thinkingLevel: config.thinkingLevel + } + } satisfies GoogleLanguageModelOptions + } + }; + } + case 'google-vertex': { + const vertex = createVertex({ + project: config.project ?? env.GOOGLE_VERTEX_PROJECT, + location: config.region ?? env.GOOGLE_VERTEX_REGION, + ...(config.credentials ? { + googleAuthOptions: { + keyFilename: await getTokenFromConfig(config.credentials), + } + } : {}), + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + return { + model: vertex(modelId), + providerOptions: { + vertex: { + thinkingConfig: { + includeThoughts: true, + thinkingBudget: + config.thinkingBudget ?? + env.GOOGLE_VERTEX_THINKING_BUDGET_TOKENS, + thinkingLevel: config.thinkingLevel, + } + } satisfies GoogleLanguageModelOptions + }, + }; + } + case 'google-vertex-anthropic': { + const vertexAnthropic = createVertexAnthropic({ + project: config.project ?? env.GOOGLE_VERTEX_PROJECT, + location: config.region ?? env.GOOGLE_VERTEX_REGION, + ...(config.credentials ? { + googleAuthOptions: { + keyFilename: await getTokenFromConfig(config.credentials), + } + } : {}), + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + return { + model: vertexAnthropic(modelId), + }; + } + case 'mistral': { + const mistral = createMistral({ + baseURL: config.baseUrl, + apiKey: config.token + ? await getTokenFromConfig(config.token) + : env.MISTRAL_API_KEY, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + return { + model: mistral(modelId), + }; + } + case 'openai': { + const openai = createOpenAI({ + baseURL: config.baseUrl, + apiKey: config.token + ? await getTokenFromConfig(config.token) + : env.OPENAI_API_KEY, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + const reasoningSummary = config.reasoningSummary ?? 'auto'; + return { + model: openai(modelId), + providerOptions: { + openai: { + reasoningEffort: config.reasoningEffort ?? 'medium', + ...(reasoningSummary !== 'none' && { reasoningSummary }), + } satisfies OpenAIResponsesProviderOptions, + }, + }; + } + case 'openai-compatible': { + const openai = createOpenAICompatible({ + baseURL: config.baseUrl, + name: config.displayName ?? modelId, + apiKey: config.token + ? await getTokenFromConfig(config.token) + : undefined, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + queryParams: config.queryParams + ? await extractLanguageModelKeyValuePairs(config.queryParams) + : undefined, + }); + + const model = wrapLanguageModel({ + model: openai.chatModel(modelId), + middleware: [ + extractReasoningMiddleware({ + tagName: config.reasoningTag ?? 'think', + }), + ] + }); + + return { + model, + } + } + case 'openrouter': { + const openrouter = createOpenRouter({ + baseURL: config.baseUrl, + apiKey: config.token + ? await getTokenFromConfig(config.token) + : env.OPENROUTER_API_KEY, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + return { + model: openrouter(modelId), + }; + } + case 'xai': { + const xai = createXai({ + baseURL: config.baseUrl, + apiKey: config.token + ? await getTokenFromConfig(config.token) + : env.XAI_API_KEY, + headers: config.headers + ? await extractLanguageModelKeyValuePairs(config.headers) + : undefined, + }); + + return { + model: xai(modelId), + }; + } + } + })(); + + const posthog = await createPostHogClient(); + const distinctId = await tryGetPostHogDistinctId(); + + // Only enable posthog LLM analytics for the ask GH experiment. + const model = env.EXPERIMENT_ASK_GH_ENABLED === 'true' ? + withTracing(_model, posthog, { + posthogDistinctId: distinctId, + }) : + _model; + + return { + model, + providerOptions, + temperature: config.temperature, + }; +} + +const extractLanguageModelKeyValuePairs = async ( + pairs: { + [k: string]: string | Token; + } +): Promise> => { + const resolvedPairs: Record = {}; + + if (!pairs) { + return resolvedPairs; + } + + for (const [key, val] of Object.entries(pairs)) { + if (typeof val === "string") { + resolvedPairs[key] = val; + continue; + } + + const value = await getTokenFromConfig(val); + resolvedPairs[key] = value; + } + + return resolvedPairs; +}; diff --git a/packages/web/src/features/chat/mcp/utils.test.ts b/packages/web/src/features/chat/mcp/utils.test.ts new file mode 100644 index 000000000..d3c887fc7 --- /dev/null +++ b/packages/web/src/features/chat/mcp/utils.test.ts @@ -0,0 +1,50 @@ +import { expect, test, describe } from 'vitest'; +import { getMcpFaviconUrl, sanitizeMcpServerName } from './utils'; + +describe('sanitizeMcpServerName', () => { + test('lowercases ASCII letters', () => { + expect(sanitizeMcpServerName('MyServer')).toBe('myserver'); + }); + + test('replaces special characters with underscores', () => { + expect(sanitizeMcpServerName('My Server!')).toBe('my_server_'); + }); + + test('preserves digits', () => { + expect(sanitizeMcpServerName('server123')).toBe('server123'); + }); + + test('replaces spaces and hyphens', () => { + expect(sanitizeMcpServerName('my-cool server')).toBe('my_cool_server'); + }); + + test('handles empty string', () => { + expect(sanitizeMcpServerName('')).toBe(''); + }); + + test('replaces unicode characters with underscores', () => { + expect(sanitizeMcpServerName('Ñoño')).toBe('_o_o'); + }); + + test('replaces all special characters', () => { + expect(sanitizeMcpServerName('@#$%')).toBe('____'); + }); + + test('returns already sanitized name unchanged', () => { + expect(sanitizeMcpServerName('linear')).toBe('linear'); + }); +}); + +describe('getMcpFaviconUrl', () => { + test('returns a Google favicon URL for a valid server URL', () => { + expect(getMcpFaviconUrl('https://mcp.linear.app/mcp')).toBe('https://www.google.com/s2/favicons?domain=https://mcp.linear.app&sz=32'); + }); + + test('returns a local Atlassian icon for the Atlassian prefab server', () => { + expect(getMcpFaviconUrl('https://mcp.atlassian.com/v1/mcp/authv2', 'Atlassian')).toMatch(/^data:image\/svg\+xml,/); + }); + + test('returns undefined for a malformed server URL', () => { + expect(getMcpFaviconUrl('not a url')).toBeUndefined(); + }); +}); diff --git a/packages/web/src/features/chat/mcp/utils.ts b/packages/web/src/features/chat/mcp/utils.ts new file mode 100644 index 000000000..5d1453cbb --- /dev/null +++ b/packages/web/src/features/chat/mcp/utils.ts @@ -0,0 +1,81 @@ +/** + * Sanitizes an MCP server name into a lowercase alphanumeric string suitable + * for use as a tool-name prefix (e.g. "My Server!" → "my_server_"). + * + * This is used to namespace MCP tools (mcp_{sanitizedName}__{toolName}) and + * to key favicon maps. Must be kept consistent everywhere — collisions on + * this value are prevented at server-creation time. + */ +export function sanitizeMcpServerName(name: string): string { + return name.toLowerCase().replace(/[^a-z0-9]/g, '_'); +} + +export function pluralize(count: number, singular: string, plural = `${singular}s`) { + return count === 1 ? singular : plural; +} + +const standardNumberFormatter = new Intl.NumberFormat(); +const compactNumberFormatter = new Intl.NumberFormat(undefined, { + notation: "compact", + maximumFractionDigits: 1, +}); + +export function formatCount(count: number) { + if (count >= 10_000) { + return compactNumberFormatter.format(count); + } + return standardNumberFormatter.format(count); +} + +export function formatUsageSharePercent(percent: number) { + if (percent <= 0) { + return "0%"; + } + if (percent < 1) { + return "<1%"; + } + if (percent < 10) { + return `${percent.toFixed(1).replace(/\.0$/, "")}%`; + } + return `${Math.round(percent)}%`; +} + +function createMcpIconDataUri(svg: string): string { + return `data:image/svg+xml,${encodeURIComponent(svg)}`; +} + +const atlassianIconSvg = ` + + + + + + + + + + + + + +`; + +const knownMcpFaviconUrlsBySanitizedName: Record = { + atlassian: createMcpIconDataUri(atlassianIconSvg), +}; + +export function getMcpFaviconUrl(serverUrl: string, serverName?: string): string | undefined { + if (serverName) { + const knownFaviconUrl = knownMcpFaviconUrlsBySanitizedName[sanitizeMcpServerName(serverName)]; + if (knownFaviconUrl) { + return knownFaviconUrl; + } + } + + try { + const origin = new URL(serverUrl).origin; + return `https://www.google.com/s2/favicons?domain=${origin}&sz=32`; + } catch { + return undefined; + } +} diff --git a/packages/web/src/features/chat/mcpOAuthDraft.test.ts b/packages/web/src/features/chat/mcpOAuthDraft.test.ts index 6f81f644e..030376136 100644 --- a/packages/web/src/features/chat/mcpOAuthDraft.test.ts +++ b/packages/web/src/features/chat/mcpOAuthDraft.test.ts @@ -1,5 +1,5 @@ import { beforeEach, describe, expect, test } from 'vitest'; -import { MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY } from './constants'; +import { MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY } from '@/features/chat/constants'; import { consumeMcpOAuthDraftForPath, normalizeMcpOAuthDraftPath, @@ -7,7 +7,7 @@ import { saveMcpOAuthDraft, } from './mcpOAuthDraft'; import type { Descendant } from 'slate'; -import type { SearchScope } from './types'; +import type { SearchScope } from '@/features/chat/types'; const children = [{ type: 'paragraph', diff --git a/packages/web/src/features/chat/mcpOAuthDraft.ts b/packages/web/src/features/chat/mcpOAuthDraft.ts index bbbf2a146..0c4fd655c 100644 --- a/packages/web/src/features/chat/mcpOAuthDraft.ts +++ b/packages/web/src/features/chat/mcpOAuthDraft.ts @@ -1,6 +1,6 @@ import type { Descendant } from "slate"; -import { MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY } from "./constants"; -import type { CustomText, MentionElement, ParagraphElement, SearchScope } from "./types"; +import { MCP_OAUTH_DRAFT_SESSION_STORAGE_KEY } from "@/features/chat/constants"; +import type { CustomText, MentionElement, ParagraphElement, SearchScope } from "@/features/chat/types"; const MCP_OAUTH_DRAFT_BASE_URL = 'https://sourcebot.invalid'; const MCP_OAUTH_DRAFT_MAX_AGE_MS = 30 * 60 * 1000; diff --git a/packages/web/src/features/chat/types.ts b/packages/web/src/features/chat/types.ts index 3c2619f14..9a6e970a8 100644 --- a/packages/web/src/features/chat/types.ts +++ b/packages/web/src/features/chat/types.ts @@ -4,7 +4,9 @@ import { HistoryEditor } from "slate-history"; import { ReactEditor, RenderElementProps } from "slate-react"; import { z } from "zod"; import { LanguageModel } from "@sourcebot/schemas/v3/index.type"; -import { createTools } from "./tools"; +// Type-only import: the chat message tool types are derived from the shape of the +// EE agent's tools, but no runtime dependency on ee/ is introduced (erased at build). +import type { createTools } from "@/ee/features/chat/tools"; export { sourceSchema } from "@/features/tools/types"; export type { FileSource, Source } from "@/features/tools/types"; import type { Source } from "@/features/tools/types"; diff --git a/packages/web/src/features/chat/utils.server.ts b/packages/web/src/features/chat/utils.server.ts index 9d00459f4..ffc3483a4 100644 --- a/packages/web/src/features/chat/utils.server.ts +++ b/packages/web/src/features/chat/utils.server.ts @@ -1,30 +1,33 @@ import 'server-only'; import { getAnonymousId } from '@/lib/anonymousId'; -import { createPostHogClient, tryGetPostHogDistinctId } from "@/lib/posthog"; -import { createAmazonBedrock } from '@ai-sdk/amazon-bedrock'; -import { AnthropicProviderOptions, createAnthropic } from '@ai-sdk/anthropic'; -import { createAzure } from '@ai-sdk/azure'; -import { createDeepSeek } from '@ai-sdk/deepseek'; -import { createGoogleGenerativeAI, GoogleLanguageModelOptions } from '@ai-sdk/google'; -import { createVertex } from '@ai-sdk/google-vertex'; -import { createVertexAnthropic } from '@ai-sdk/google-vertex/anthropic'; -import { createMistral } from '@ai-sdk/mistral'; -import { createOpenAI, OpenAIResponsesProviderOptions } from "@ai-sdk/openai"; -import { createOpenAICompatible } from "@ai-sdk/openai-compatible"; -import { LanguageModelV3 as AISDKLanguageModelV3 } from "@ai-sdk/provider"; -import { createXai } from '@ai-sdk/xai'; -import { fromNodeProviderChain } from '@aws-sdk/credential-providers'; -import { createOpenRouter } from '@openrouter/ai-sdk-provider'; -import { withTracing } from "@posthog/ai"; import { Chat, Prisma, PrismaClient, User } from '@sourcebot/db'; import { LanguageModel } from '@sourcebot/schemas/v3/languageModel.type'; -import { Token } from "@sourcebot/schemas/v3/shared.type"; -import { env, getTokenFromConfig, loadConfig } from '@sourcebot/shared'; -import { extractReasoningMiddleware, generateText, JSONValue, wrapLanguageModel } from "ai"; +import { env, loadConfig } from '@sourcebot/shared'; import fs from 'fs'; import path from 'path'; import { LanguageModelInfo, SBChatMessage } from './types'; +import { hasEntitlement } from '@/lib/entitlements'; +import { ServiceError } from '@/lib/serviceError'; +import { ErrorCode } from '@/lib/errorCodes'; +import { StatusCodes } from 'http-status-codes'; + +/** + * Returns a FORBIDDEN ServiceError when the deployment lacks the `ask` + * entitlement, or null when Ask is available. Gates the generative chat + * surfaces (message streaming, chat creation, sharing) server-side so the + * client gate can't be bypassed. + */ +export const checkAskEntitlement = async (): Promise => { + if (await hasEntitlement('ask')) { + return null; + } + return { + statusCode: StatusCodes.FORBIDDEN, + errorCode: ErrorCode.INSUFFICIENT_PERMISSIONS, + message: "Ask Sourcebot is not available in your current plan", + } satisfies ServiceError; +}; /** * Checks if the current user (authenticated or anonymous) is the owner of a chat. @@ -100,6 +103,7 @@ export const updateChatMessages = async ({ fs.writeFileSync(chatFile, JSON.stringify(messages, null, 2)); } }; + /** * Returns the full configuration of the language models. * @@ -129,346 +133,3 @@ export const getConfiguredLanguageModelsInfo = async () => { displayName: model.displayName, })); }; - -export const generateChatNameFromMessage = async ({ message, languageModelConfig }: { message: string, languageModelConfig: LanguageModel }) => { - const { model } = await getAISDKLanguageModelAndOptions(languageModelConfig); - - const prompt = `Convert this question into a short topic title (max 50 characters). - -Rules: -- Do NOT include question words (what, where, how, why, when, which) -- Do NOT end with a question mark -- Capitalize the first letter of the title -- Focus on the subject/topic being discussed -- Make it sound like a file name or category - -Examples: -"Where is the authentication code?" → "Authentication Code" -"How to setup the database?" → "Database Setup" -"What are the API endpoints?" → "API Endpoints" - -User question: ${message}`; - - const result = await generateText({ - model, - prompt, - }); - - return result.text; -} - -export const getAISDKLanguageModelAndOptions = async (config: LanguageModel): Promise<{ - model: AISDKLanguageModelV3, - providerOptions?: Record>, - temperature?: number, -}> => { - const { provider, model: modelId } = config; - - const { model: _model, providerOptions } = await (async (): Promise<{ - model: AISDKLanguageModelV3, - providerOptions?: Record>, - }> => { - switch (provider) { - case 'amazon-bedrock': { - const aws = createAmazonBedrock({ - baseURL: config.baseUrl, - region: config.region ?? env.AWS_REGION, - accessKeyId: config.accessKeyId - ? await getTokenFromConfig(config.accessKeyId) - : env.AWS_ACCESS_KEY_ID, - secretAccessKey: config.accessKeySecret - ? await getTokenFromConfig(config.accessKeySecret) - : env.AWS_SECRET_ACCESS_KEY, - sessionToken: config.sessionToken - ? await getTokenFromConfig(config.sessionToken) - : env.AWS_SESSION_TOKEN, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - // Fallback to the default Node.js credential provider chain if no credentials are provided. - // See: https://docs.aws.amazon.com/AWSJavaScriptSDK/v3/latest/Package/-aws-sdk-credential-providers/#fromnodeproviderchain - credentialProvider: !config.accessKeyId && !config.accessKeySecret && !config.sessionToken - ? fromNodeProviderChain() - : undefined, - }); - - return { - model: aws(modelId), - }; - } - case 'anthropic': { - const anthropic = createAnthropic({ - baseURL: config.baseUrl, - apiKey: config.token - ? await getTokenFromConfig(config.token) - : env.ANTHROPIC_API_KEY, - authToken: config.authToken - ? await getTokenFromConfig(config.authToken) - : env.ANTHROPIC_AUTH_TOKEN, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - const isAdaptiveThinkingSupported = - modelId.startsWith('claude-opus-4-7'); - - return { - model: anthropic(modelId), - providerOptions: { - anthropic: { - thinking: isAdaptiveThinkingSupported ? { - type: "adaptive", - display: "summarized" - } : { - type: "enabled", - budgetTokens: env.ANTHROPIC_THINKING_BUDGET_TOKENS, - } - } satisfies AnthropicProviderOptions, - }, - }; - } - case 'azure': { - const azure = createAzure({ - baseURL: config.baseUrl, - apiKey: config.token ? (await getTokenFromConfig(config.token)) : env.AZURE_API_KEY, - apiVersion: config.apiVersion, - resourceName: config.resourceName ?? env.AZURE_RESOURCE_NAME, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - const reasoningSummary = config.reasoningSummary ?? 'auto'; - return { - model: azure(modelId), - providerOptions: { - openai: { - reasoningEffort: config.reasoningEffort ?? 'medium', - ...(reasoningSummary !== 'none' && { reasoningSummary }), - } satisfies OpenAIResponsesProviderOptions, - } - }; - } - case 'deepseek': { - const deepseek = createDeepSeek({ - baseURL: config.baseUrl, - apiKey: config.token ? (await getTokenFromConfig(config.token)) : env.DEEPSEEK_API_KEY, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - return { - model: deepseek(modelId), - }; - } - case 'google-generative-ai': { - const google = createGoogleGenerativeAI({ - baseURL: config.baseUrl, - apiKey: config.token - ? await getTokenFromConfig(config.token) - : env.GOOGLE_GENERATIVE_AI_API_KEY, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - return { - model: google(modelId), - providerOptions: { - google: { - thinkingConfig: { - includeThoughts: true, - thinkingBudget: config.thinkingBudget, - thinkingLevel: config.thinkingLevel - } - } satisfies GoogleLanguageModelOptions - } - }; - } - case 'google-vertex': { - const vertex = createVertex({ - project: config.project ?? env.GOOGLE_VERTEX_PROJECT, - location: config.region ?? env.GOOGLE_VERTEX_REGION, - ...(config.credentials ? { - googleAuthOptions: { - keyFilename: await getTokenFromConfig(config.credentials), - } - } : {}), - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - return { - model: vertex(modelId), - providerOptions: { - vertex: { - thinkingConfig: { - includeThoughts: true, - thinkingBudget: - config.thinkingBudget ?? - env.GOOGLE_VERTEX_THINKING_BUDGET_TOKENS, - thinkingLevel: config.thinkingLevel, - } - } satisfies GoogleLanguageModelOptions - }, - }; - } - case 'google-vertex-anthropic': { - const vertexAnthropic = createVertexAnthropic({ - project: config.project ?? env.GOOGLE_VERTEX_PROJECT, - location: config.region ?? env.GOOGLE_VERTEX_REGION, - ...(config.credentials ? { - googleAuthOptions: { - keyFilename: await getTokenFromConfig(config.credentials), - } - } : {}), - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - return { - model: vertexAnthropic(modelId), - }; - } - case 'mistral': { - const mistral = createMistral({ - baseURL: config.baseUrl, - apiKey: config.token - ? await getTokenFromConfig(config.token) - : env.MISTRAL_API_KEY, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - return { - model: mistral(modelId), - }; - } - case 'openai': { - const openai = createOpenAI({ - baseURL: config.baseUrl, - apiKey: config.token - ? await getTokenFromConfig(config.token) - : env.OPENAI_API_KEY, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - const reasoningSummary = config.reasoningSummary ?? 'auto'; - return { - model: openai(modelId), - providerOptions: { - openai: { - reasoningEffort: config.reasoningEffort ?? 'medium', - ...(reasoningSummary !== 'none' && { reasoningSummary }), - } satisfies OpenAIResponsesProviderOptions, - }, - }; - } - case 'openai-compatible': { - const openai = createOpenAICompatible({ - baseURL: config.baseUrl, - name: config.displayName ?? modelId, - apiKey: config.token - ? await getTokenFromConfig(config.token) - : undefined, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - queryParams: config.queryParams - ? await extractLanguageModelKeyValuePairs(config.queryParams) - : undefined, - }); - - const model = wrapLanguageModel({ - model: openai.chatModel(modelId), - middleware: [ - extractReasoningMiddleware({ - tagName: config.reasoningTag ?? 'think', - }), - ] - }); - - return { - model, - } - } - case 'openrouter': { - const openrouter = createOpenRouter({ - baseURL: config.baseUrl, - apiKey: config.token - ? await getTokenFromConfig(config.token) - : env.OPENROUTER_API_KEY, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - return { - model: openrouter(modelId), - }; - } - case 'xai': { - const xai = createXai({ - baseURL: config.baseUrl, - apiKey: config.token - ? await getTokenFromConfig(config.token) - : env.XAI_API_KEY, - headers: config.headers - ? await extractLanguageModelKeyValuePairs(config.headers) - : undefined, - }); - - return { - model: xai(modelId), - }; - } - } - })(); - - const posthog = await createPostHogClient(); - const distinctId = await tryGetPostHogDistinctId(); - - // Only enable posthog LLM analytics for the ask GH experiment. - const model = env.EXPERIMENT_ASK_GH_ENABLED === 'true' ? - withTracing(_model, posthog, { - posthogDistinctId: distinctId, - }) : - _model; - - return { - model, - providerOptions, - temperature: config.temperature, - }; -} - -const extractLanguageModelKeyValuePairs = async ( - pairs: { - [k: string]: string | Token; - } -): Promise> => { - const resolvedPairs: Record = {}; - - if (!pairs) { - return resolvedPairs; - } - - for (const [key, val] of Object.entries(pairs)) { - if (typeof val === "string") { - resolvedPairs[key] = val; - continue; - } - - const value = await getTokenFromConfig(val); - resolvedPairs[key] = value; - } - - return resolvedPairs; -}; \ No newline at end of file diff --git a/packages/web/src/features/searchAssist/actions.ts b/packages/web/src/features/searchAssist/actions.ts index b9ebb8266..cf02e6087 100644 --- a/packages/web/src/features/searchAssist/actions.ts +++ b/packages/web/src/features/searchAssist/actions.ts @@ -1,7 +1,8 @@ 'use server'; import { sew } from "@/middleware/sew"; -import { getConfiguredLanguageModels, getAISDKLanguageModelAndOptions } from "../chat/utils.server"; +import { getConfiguredLanguageModels } from "../chat/utils.server"; +import { getAISDKLanguageModelAndOptions } from "@/features/chat/llm.server"; import { ErrorCode } from "@/lib/errorCodes"; import { ServiceError } from "@/lib/serviceError"; import { withOptionalAuth } from "@/middleware/withAuth"; diff --git a/packages/web/src/features/userManagement/actions.ts b/packages/web/src/features/userManagement/actions.ts index 3531d9e84..859617f91 100644 --- a/packages/web/src/features/userManagement/actions.ts +++ b/packages/web/src/features/userManagement/actions.ts @@ -1,7 +1,7 @@ 'use server'; import { createAudit } from "@/ee/features/audit/audit"; -import { syncWithLighthouse } from "@/ee/features/lighthouse/servicePing"; +import { syncWithLighthouse } from "@/features/billing/servicePing"; import InviteUserEmail from "@/emails/inviteUserEmail"; import JoinRequestApprovedEmail from "@/emails/joinRequestApprovedEmail"; import { addUserToOrganization, orgHasAvailability } from "@/lib/authUtils"; diff --git a/packages/web/src/initialize.ts b/packages/web/src/initialize.ts index c866aeaff..33cf1491a 100644 --- a/packages/web/src/initialize.ts +++ b/packages/web/src/initialize.ts @@ -1,5 +1,5 @@ import { __unsafePrisma } from "@/prisma"; -import { startServicePingCronJob } from '@/ee/features/lighthouse/servicePing'; +import { startServicePingCronJob } from '@/features/billing/servicePing'; import { startChangelogPollingJob } from '@/features/changelog/pollChangelog'; import { createLogger, env } from "@sourcebot/shared"; import { hasEntitlement } from '@/lib/entitlements'; diff --git a/packages/web/src/lib/authUtils.ts b/packages/web/src/lib/authUtils.ts index 1f943075e..9215cabb7 100644 --- a/packages/web/src/lib/authUtils.ts +++ b/packages/web/src/lib/authUtils.ts @@ -7,7 +7,7 @@ import { createLogger, getSeatCap } from "@sourcebot/shared"; import { createAudit } from "@/ee/features/audit/audit"; import { StatusCodes } from "http-status-codes"; import { ErrorCode } from "./errorCodes"; -import { syncWithLighthouse } from "@/ee/features/lighthouse/servicePing"; +import { syncWithLighthouse } from "@/features/billing/servicePing"; import { hasEntitlement } from "./entitlements"; const logger = createLogger('web-auth-utils'); diff --git a/packages/web/src/lib/posthogEvents.ts b/packages/web/src/lib/posthogEvents.ts index 4326036d0..6a4001f01 100644 --- a/packages/web/src/lib/posthogEvents.ts +++ b/packages/web/src/lib/posthogEvents.ts @@ -4,8 +4,12 @@ export type UpsellSource = 'sidebar' | 'analytics_settings' | 'chat_box' | + 'chat' | + 'chats' | 'onboard' | - 'license_settings'; + 'license_settings' | + 'mcp_settings' | + 'chat_connectors'; export type SourcebotWebClientSource = 'sourcebot-web-client'; export type AskMcpAnalyticsSource = SourcebotWebClientSource | 'sourcebot-ask-agent';