diff --git a/apps/sim/app/api/speech/token/route.ts b/apps/sim/app/api/speech/token/route.ts index 9e55c50084c..b4a5835b9eb 100644 --- a/apps/sim/app/api/speech/token/route.ts +++ b/apps/sim/app/api/speech/token/route.ts @@ -4,7 +4,7 @@ import { createLogger } from '@sim/logger' import { eq } from 'drizzle-orm' import { type NextRequest, NextResponse } from 'next/server' import { getSession } from '@/lib/auth' -import { hasExceededCostLimit } from '@/lib/billing/core/subscription' +import { checkServerSideUsageLimits } from '@/lib/billing/calculations/usage-monitor' import { recordUsage } from '@/lib/billing/core/usage-log' import { env } from '@/lib/core/config/env' import { getCostMultiplier, isBillingEnabled } from '@/lib/core/config/feature-flags' @@ -110,11 +110,14 @@ export async function POST(request: NextRequest) { } } - if (billingUserId && isBillingEnabled) { - const exceeded = await hasExceededCostLimit(billingUserId) - if (exceeded) { + if (billingUserId) { + const usageCheck = await checkServerSideUsageLimits(billingUserId) + if (usageCheck.isExceeded) { return NextResponse.json( - { error: 'Usage limit exceeded. Please upgrade your plan to continue.' }, + { + error: + usageCheck.message || 'Usage limit exceeded. Please upgrade your plan to continue.', + }, { status: 402 } ) } diff --git a/apps/sim/app/workspace/[workspaceId]/home/components/user-input/user-input.tsx b/apps/sim/app/workspace/[workspaceId]/home/components/user-input/user-input.tsx index e02c73bbb9f..981dc9d0e58 100644 --- a/apps/sim/app/workspace/[workspaceId]/home/components/user-input/user-input.tsx +++ b/apps/sim/app/workspace/[workspaceId]/home/components/user-input/user-input.tsx @@ -39,6 +39,7 @@ import { extractContextTokens, } from '@/app/workspace/[workspaceId]/w/[workflowId]/components/panel/components/copilot/components/user-input/utils' import { useWorkflowMap } from '@/hooks/queries/workflows' +import { useSettingsNavigation } from '@/hooks/use-settings-navigation' import { useSpeechToText } from '@/hooks/use-speech-to-text' import type { ChatContext } from '@/stores/panel' @@ -120,6 +121,7 @@ export function UserInput({ onEnterWhileEmpty, }: UserInputProps) { const { workspaceId } = useParams<{ workspaceId: string }>() + const { navigateToSettings } = useSettingsNavigation() const { data: workflowsById = {} } = useWorkflowMap(workspaceId) const { data: session } = useSession() const [value, setValue] = useState(defaultValue) @@ -239,12 +241,19 @@ export function UserInput({ valueRef.current = newVal }, []) + const handleUsageLimitExceeded = useCallback(() => { + navigateToSettings({ section: 'subscription' }) + }, [navigateToSettings]) + const { isListening, isSupported: isSttSupported, toggleListening: rawToggle, resetTranscript, - } = useSpeechToText({ onTranscript: handleTranscript }) + } = useSpeechToText({ + onTranscript: handleTranscript, + onUsageLimitExceeded: handleUsageLimitExceeded, + }) const toggleListening = useCallback(() => { if (!isListening) { diff --git a/apps/sim/hooks/use-speech-to-text.ts b/apps/sim/hooks/use-speech-to-text.ts index 76346f5faf3..d426ac42759 100644 --- a/apps/sim/hooks/use-speech-to-text.ts +++ b/apps/sim/hooks/use-speech-to-text.ts @@ -18,6 +18,7 @@ export type PermissionState = 'prompt' | 'granted' | 'denied' interface UseSpeechToTextProps { onTranscript: (text: string) => void + onUsageLimitExceeded?: () => void language?: string } @@ -31,6 +32,7 @@ interface UseSpeechToTextReturn { export function useSpeechToText({ onTranscript, + onUsageLimitExceeded, language, }: UseSpeechToTextProps): UseSpeechToTextReturn { const [isListening, setIsListening] = useState(false) @@ -38,6 +40,7 @@ export function useSpeechToText({ const [permissionState, setPermissionState] = useState('prompt') const onTranscriptRef = useRef(onTranscript) + const onUsageLimitExceededRef = useRef(onUsageLimitExceeded) const languageRef = useRef(language) const mountedRef = useRef(true) const startingRef = useRef(false) @@ -55,6 +58,7 @@ export function useSpeechToText({ const committedTextRef = useRef('') onTranscriptRef.current = onTranscript + onUsageLimitExceededRef.current = onUsageLimitExceeded languageRef.current = language useEffect(() => { @@ -165,6 +169,10 @@ export function useSpeechToText({ }) if (!tokenResponse.ok) { + if (tokenResponse.status === 402) { + onUsageLimitExceededRef.current?.() + return false + } const body = await tokenResponse.json().catch(() => ({})) throw new Error(body.error || 'Failed to get speech token') } diff --git a/apps/sim/lib/billing/core/subscription.ts b/apps/sim/lib/billing/core/subscription.ts index 916cb84f008..2e62e61206f 100644 --- a/apps/sim/lib/billing/core/subscription.ts +++ b/apps/sim/lib/billing/core/subscription.ts @@ -1,13 +1,11 @@ import { db } from '@sim/db' -import { member, subscription, user, userStats } from '@sim/db/schema' +import { member, subscription, user } from '@sim/db/schema' import { createLogger } from '@sim/logger' import { and, eq, inArray, sql } from 'drizzle-orm' import { getEffectiveBillingStatus, isOrganizationBillingBlocked } from '@/lib/billing/core/access' import { getHighestPrioritySubscription } from '@/lib/billing/core/plan' -import { getUserUsageLimit } from '@/lib/billing/core/usage' import { getPlanTierCredits, - isOrgPlan, isPro as isPlanPro, isTeam as isPlanTeam, } from '@/lib/billing/plan-helpers' @@ -16,12 +14,9 @@ import { checkProPlan, checkTeamPlan, ENTITLED_SUBSCRIPTION_STATUSES, - getFreeTierLimit, - getPerUserMinimumLimit, hasUsableSubscriptionAccess, USABLE_SUBSCRIPTION_STATUSES, } from '@/lib/billing/subscriptions/utils' -import type { UserSubscriptionState } from '@/lib/billing/types' import { isAccessControlEnabled, isBillingEnabled, @@ -485,145 +480,6 @@ export async function hasLiveSyncAccess(userId: string): Promise { } } -/** - * Check if user has exceeded their cost limit based on current period usage - */ -export async function hasExceededCostLimit(userId: string): Promise { - try { - if (!isBillingEnabled) { - return false - } - - const subscription = await getHighestPrioritySubscription(userId) - - let limit = getFreeTierLimit() // Default free tier limit - - if (subscription) { - // Team/Enterprise: Use organization limit - if (isOrgPlan(subscription.plan)) { - limit = await getUserUsageLimit(userId) - logger.info('Using organization limit', { - userId, - plan: subscription.plan, - limit, - }) - } else { - // Pro/Free: Use individual limit - limit = getPerUserMinimumLimit(subscription) - logger.info('Using subscription-based limit', { - userId, - plan: subscription.plan, - limit, - }) - } - } else { - logger.info('Using free tier limit', { userId, limit }) - } - - // Get user stats to check current period usage - const statsRecords = await db.select().from(userStats).where(eq(userStats.userId, userId)) - - if (statsRecords.length === 0) { - return false - } - - // Use current period cost instead of total cost for accurate billing period tracking - const currentCost = Number.parseFloat( - statsRecords[0].currentPeriodCost?.toString() || statsRecords[0].totalCost.toString() - ) - - logger.info('Checking cost limit', { userId, currentCost, limit }) - - return currentCost >= limit - } catch (error) { - logger.error('Error checking cost limit', { error, userId }) - return false // Be conservative in case of error - } -} - -/** - * Check if sharing features are enabled for user - */ -// Removed unused feature flag helpers: isSharingEnabled, isMultiplayerEnabled, isWorkspaceCollaborationEnabled - -/** - * Get comprehensive subscription state for a user - * Single function to get all subscription information - */ -export async function getUserSubscriptionState(userId: string): Promise { - try { - // Get subscription and user stats in parallel to minimize DB calls - const [subscription, statsRecords] = await Promise.all([ - getHighestPrioritySubscription(userId), - db.select().from(userStats).where(eq(userStats.userId, userId)).limit(1), - ]) - - // Determine plan types based on subscription (avoid redundant DB calls) - const isPro = - !isBillingEnabled || - !!( - subscription && - (checkProPlan(subscription) || - checkTeamPlan(subscription) || - checkEnterprisePlan(subscription)) - ) - const isTeam = - !isBillingEnabled || - !!(subscription && (checkTeamPlan(subscription) || checkEnterprisePlan(subscription))) - const isEnterprise = !isBillingEnabled || !!(subscription && checkEnterprisePlan(subscription)) - const isFree = !isPro && !isTeam && !isEnterprise - - // Determine plan name - let planName = 'free' - if (isEnterprise) planName = 'enterprise' - else if (isTeam) planName = 'team' - else if (isPro) planName = 'pro' - - // Check cost limit using already-fetched user stats - let hasExceededLimit = false - if (isBillingEnabled && statsRecords.length > 0) { - let limit = getFreeTierLimit() // Default free tier limit - if (subscription) { - // Team/Enterprise: Use organization limit - if (isOrgPlan(subscription.plan)) { - limit = await getUserUsageLimit(userId) - } else { - // Pro/Free: Use individual limit - limit = getPerUserMinimumLimit(subscription) - } - } - - const currentCost = Number.parseFloat( - statsRecords[0].currentPeriodCost?.toString() || statsRecords[0].totalCost.toString() - ) - hasExceededLimit = currentCost >= limit - } - - return { - isPro, - isTeam, - isEnterprise, - isFree, - highestPrioritySubscription: subscription, - hasExceededLimit, - planName, - } - } catch (error) { - logger.error('Error getting user subscription state', { error, userId }) - - // Return safe defaults in case of error - return { - isPro: false, - isTeam: false, - isEnterprise: false, - isFree: true, - highestPrioritySubscription: null, - hasExceededLimit: false, - planName: 'free', - } - } -} - /** * Send welcome email for Pro and Team plan subscriptions */ diff --git a/apps/sim/lib/billing/index.ts b/apps/sim/lib/billing/index.ts index c8dd1d6c4ff..b4260c82952 100644 --- a/apps/sim/lib/billing/index.ts +++ b/apps/sim/lib/billing/index.ts @@ -9,7 +9,6 @@ export * from '@/lib/billing/core/organization' export * from '@/lib/billing/core/subscription' export { getHighestPrioritySubscription as getActiveSubscription, - getUserSubscriptionState as getSubscriptionState, hasAccessControlAccess, hasCredentialSetsAccess, hasPaidSubscription, diff --git a/apps/sim/lib/billing/types/index.ts b/apps/sim/lib/billing/types/index.ts index 3c3f846fc8f..ac8c9736e1f 100644 --- a/apps/sim/lib/billing/types/index.ts +++ b/apps/sim/lib/billing/types/index.ts @@ -73,16 +73,6 @@ export interface BillingData { daysRemaining: number } -export interface UserSubscriptionState { - isPro: boolean - isTeam: boolean - isEnterprise: boolean - isFree: boolean - highestPrioritySubscription: any | null - hasExceededLimit: boolean - planName: string -} - export interface SubscriptionPlan { name: string priceId: string diff --git a/apps/sim/lib/copilot/chat-payload.test.ts b/apps/sim/lib/copilot/chat-payload.test.ts index 0c7b187e7fd..817ac013cbc 100644 --- a/apps/sim/lib/copilot/chat-payload.test.ts +++ b/apps/sim/lib/copilot/chat-payload.test.ts @@ -3,18 +3,18 @@ */ import { beforeEach, describe, expect, it, vi } from 'vitest' -vi.mock('@sim/logger', () => { - const createMockLogger = (): Record => ({ - info: vi.fn(), - warn: vi.fn(), - error: vi.fn(), - withMetadata: vi.fn(() => createMockLogger()), - }) - return { createLogger: vi.fn(() => createMockLogger()) } -}) +const { mockGetHighestPrioritySubscription } = vi.hoisted(() => ({ + mockGetHighestPrioritySubscription: vi.fn(), +})) vi.mock('@/lib/billing/core/subscription', () => ({ - getUserSubscriptionState: vi.fn(), + getHighestPrioritySubscription: mockGetHighestPrioritySubscription, +})) + +vi.mock('@/lib/billing/plan-helpers', () => ({ + isPaid: vi.fn( + (plan: string | null) => plan === 'pro' || plan === 'team' || plan === 'enterprise' + ), })) vi.mock('@/lib/copilot/chat-context', () => ({ @@ -57,48 +57,41 @@ vi.mock('@/tools/params', () => ({ createUserToolSchema: vi.fn(() => ({ type: 'object', properties: {} })), })) -import { getUserSubscriptionState } from '@/lib/billing/core/subscription' import { buildIntegrationToolSchemas } from '@/lib/copilot/chat-payload' -const mockedGetUserSubscriptionState = getUserSubscriptionState as unknown as { - mockResolvedValue: (value: unknown) => void - mockRejectedValue: (value: unknown) => void - mockClear: () => void -} - describe('buildIntegrationToolSchemas', () => { beforeEach(() => { vi.clearAllMocks() }) it('appends the email footer prompt for free users', async () => { - mockedGetUserSubscriptionState.mockResolvedValue({ isFree: true }) + mockGetHighestPrioritySubscription.mockResolvedValue(null) const toolSchemas = await buildIntegrationToolSchemas('user-free') const gmailTool = toolSchemas.find((tool) => tool.name === 'gmail_send') - expect(getUserSubscriptionState).toHaveBeenCalledWith('user-free') + expect(mockGetHighestPrioritySubscription).toHaveBeenCalledWith('user-free') expect(gmailTool?.description).toContain('sent with sim ai') }) it('does not append the email footer prompt for paid users', async () => { - mockedGetUserSubscriptionState.mockResolvedValue({ isFree: false }) + mockGetHighestPrioritySubscription.mockResolvedValue({ plan: 'pro', status: 'active' }) const toolSchemas = await buildIntegrationToolSchemas('user-paid') const gmailTool = toolSchemas.find((tool) => tool.name === 'gmail_send') - expect(getUserSubscriptionState).toHaveBeenCalledWith('user-paid') + expect(mockGetHighestPrioritySubscription).toHaveBeenCalledWith('user-paid') expect(gmailTool?.description).toBe('Send emails using Gmail') }) it('still builds integration tools when subscription lookup fails', async () => { - mockedGetUserSubscriptionState.mockRejectedValue(new Error('db unavailable')) + mockGetHighestPrioritySubscription.mockRejectedValue(new Error('db unavailable')) const toolSchemas = await buildIntegrationToolSchemas('user-error') const gmailTool = toolSchemas.find((tool) => tool.name === 'gmail_send') const brandfetchTool = toolSchemas.find((tool) => tool.name === 'brandfetch_search') - expect(getUserSubscriptionState).toHaveBeenCalledWith('user-error') + expect(mockGetHighestPrioritySubscription).toHaveBeenCalledWith('user-error') expect(gmailTool?.description).toBe('Send emails using Gmail') expect(brandfetchTool?.description).toBe('Search for brands by company name') }) diff --git a/apps/sim/lib/copilot/chat-payload.ts b/apps/sim/lib/copilot/chat-payload.ts index 69b1d342f17..dc82325f73f 100644 --- a/apps/sim/lib/copilot/chat-payload.ts +++ b/apps/sim/lib/copilot/chat-payload.ts @@ -1,5 +1,6 @@ import { createLogger } from '@sim/logger' -import { getUserSubscriptionState } from '@/lib/billing/core/subscription' +import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription' +import { isPaid } from '@/lib/billing/plan-helpers' import { getCopilotToolDescription } from '@/lib/copilot/tool-descriptions' import { isHosted } from '@/lib/core/config/feature-flags' import { createMcpToolId } from '@/lib/mcp/utils' @@ -57,10 +58,10 @@ export async function buildIntegrationToolSchemas( let shouldAppendEmailTagline = false try { - const subscriptionState = await getUserSubscriptionState(userId) - shouldAppendEmailTagline = subscriptionState.isFree + const subscription = await getHighestPrioritySubscription(userId) + shouldAppendEmailTagline = !subscription || !isPaid(subscription.plan) } catch (error) { - reqLogger.warn('Failed to load subscription state for copilot tool descriptions', { + reqLogger.warn('Failed to load subscription for copilot tool descriptions', { userId, error: error instanceof Error ? error.message : String(error), }) diff --git a/apps/sim/lib/table/billing.ts b/apps/sim/lib/table/billing.ts index 7183ad9b55c..a517d094c58 100644 --- a/apps/sim/lib/table/billing.ts +++ b/apps/sim/lib/table/billing.ts @@ -5,7 +5,8 @@ */ import { createLogger } from '@sim/logger' -import { getUserSubscriptionState } from '@/lib/billing/core/subscription' +import { getHighestPrioritySubscription } from '@/lib/billing/core/subscription' +import { getPlanTypeForLimits } from '@/lib/billing/plan-helpers' import { getWorkspaceBilledAccountUserId } from '@/lib/workspaces/utils' import { type PlanName, TABLE_PLAN_LIMITS, type TablePlanLimits } from './constants' @@ -29,8 +30,8 @@ export async function getWorkspaceTableLimits(workspaceId: string): Promise