Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 8 additions & 5 deletions apps/sim/app/api/speech/token/route.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import { createLogger } from '@sim/logger'
import { eq } from 'drizzle-orm'
import { type NextRequest, NextResponse } from 'next/server'
import { getSession } from '@/lib/auth'
import { hasExceededCostLimit } from '@/lib/billing/core/subscription'
import { checkServerSideUsageLimits } from '@/lib/billing/calculations/usage-monitor'
import { recordUsage } from '@/lib/billing/core/usage-log'
import { env } from '@/lib/core/config/env'
import { getCostMultiplier, isBillingEnabled } from '@/lib/core/config/feature-flags'
Expand Down Expand Up @@ -110,11 +110,14 @@ export async function POST(request: NextRequest) {
}
}

if (billingUserId && isBillingEnabled) {
const exceeded = await hasExceededCostLimit(billingUserId)
if (exceeded) {
if (billingUserId) {
const usageCheck = await checkServerSideUsageLimits(billingUserId)
if (usageCheck.isExceeded) {
return NextResponse.json(
{ error: 'Usage limit exceeded. Please upgrade your plan to continue.' },
{
error:
usageCheck.message || 'Usage limit exceeded. Please upgrade your plan to continue.',
},
{ status: 402 }
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import {
extractContextTokens,
} from '@/app/workspace/[workspaceId]/w/[workflowId]/components/panel/components/copilot/components/user-input/utils'
import { useWorkflowMap } from '@/hooks/queries/workflows'
import { useSettingsNavigation } from '@/hooks/use-settings-navigation'
import { useSpeechToText } from '@/hooks/use-speech-to-text'
import type { ChatContext } from '@/stores/panel'

Expand Down Expand Up @@ -120,6 +121,7 @@ export function UserInput({
onEnterWhileEmpty,
}: UserInputProps) {
const { workspaceId } = useParams<{ workspaceId: string }>()
const { navigateToSettings } = useSettingsNavigation()
const { data: workflowsById = {} } = useWorkflowMap(workspaceId)
const { data: session } = useSession()
const [value, setValue] = useState(defaultValue)
Expand Down Expand Up @@ -239,12 +241,19 @@ export function UserInput({
valueRef.current = newVal
}, [])

const handleUsageLimitExceeded = useCallback(() => {
navigateToSettings({ section: 'subscription' })
}, [navigateToSettings])

const {
isListening,
isSupported: isSttSupported,
toggleListening: rawToggle,
resetTranscript,
} = useSpeechToText({ onTranscript: handleTranscript })
} = useSpeechToText({
onTranscript: handleTranscript,
onUsageLimitExceeded: handleUsageLimitExceeded,
})

const toggleListening = useCallback(() => {
if (!isListening) {
Expand Down
8 changes: 8 additions & 0 deletions apps/sim/hooks/use-speech-to-text.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ export type PermissionState = 'prompt' | 'granted' | 'denied'

interface UseSpeechToTextProps {
onTranscript: (text: string) => void
onUsageLimitExceeded?: () => void
language?: string
}

Expand All @@ -31,13 +32,15 @@ interface UseSpeechToTextReturn {

export function useSpeechToText({
onTranscript,
onUsageLimitExceeded,
language,
}: UseSpeechToTextProps): UseSpeechToTextReturn {
const [isListening, setIsListening] = useState(false)
const [isSupported, setIsSupported] = useState(false)
const [permissionState, setPermissionState] = useState<PermissionState>('prompt')

const onTranscriptRef = useRef(onTranscript)
const onUsageLimitExceededRef = useRef(onUsageLimitExceeded)
const languageRef = useRef(language)
const mountedRef = useRef(true)
const startingRef = useRef(false)
Expand All @@ -55,6 +58,7 @@ export function useSpeechToText({
const committedTextRef = useRef('')

onTranscriptRef.current = onTranscript
onUsageLimitExceededRef.current = onUsageLimitExceeded
languageRef.current = language

useEffect(() => {
Expand Down Expand Up @@ -165,6 +169,10 @@ export function useSpeechToText({
})

if (!tokenResponse.ok) {
if (tokenResponse.status === 402) {
onUsageLimitExceededRef.current?.()
return false
}
const body = await tokenResponse.json().catch(() => ({}))
throw new Error(body.error || 'Failed to get speech token')
}
Expand Down
146 changes: 1 addition & 145 deletions apps/sim/lib/billing/core/subscription.ts
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
import { db } from '@sim/db'
import { member, subscription, user, userStats } from '@sim/db/schema'
import { member, subscription, user } from '@sim/db/schema'
import { createLogger } from '@sim/logger'
import { and, eq, inArray, sql } from 'drizzle-orm'
import { getEffectiveBillingStatus, isOrganizationBillingBlocked } from '@/lib/billing/core/access'
import { getHighestPrioritySubscription } from '@/lib/billing/core/plan'
import { getUserUsageLimit } from '@/lib/billing/core/usage'
import {
getPlanTierCredits,
isOrgPlan,
isPro as isPlanPro,
isTeam as isPlanTeam,
} from '@/lib/billing/plan-helpers'
Expand All @@ -16,12 +14,9 @@ import {
checkProPlan,
checkTeamPlan,
ENTITLED_SUBSCRIPTION_STATUSES,
getFreeTierLimit,
getPerUserMinimumLimit,
hasUsableSubscriptionAccess,
USABLE_SUBSCRIPTION_STATUSES,
} from '@/lib/billing/subscriptions/utils'
import type { UserSubscriptionState } from '@/lib/billing/types'
import {
isAccessControlEnabled,
isBillingEnabled,
Expand Down Expand Up @@ -485,145 +480,6 @@ export async function hasLiveSyncAccess(userId: string): Promise<boolean> {
}
}

/**
* Check if user has exceeded their cost limit based on current period usage
*/
export async function hasExceededCostLimit(userId: string): Promise<boolean> {
try {
if (!isBillingEnabled) {
return false
}

const subscription = await getHighestPrioritySubscription(userId)

let limit = getFreeTierLimit() // Default free tier limit

if (subscription) {
// Team/Enterprise: Use organization limit
if (isOrgPlan(subscription.plan)) {
limit = await getUserUsageLimit(userId)
logger.info('Using organization limit', {
userId,
plan: subscription.plan,
limit,
})
} else {
// Pro/Free: Use individual limit
limit = getPerUserMinimumLimit(subscription)
logger.info('Using subscription-based limit', {
userId,
plan: subscription.plan,
limit,
})
}
} else {
logger.info('Using free tier limit', { userId, limit })
}

// Get user stats to check current period usage
const statsRecords = await db.select().from(userStats).where(eq(userStats.userId, userId))

if (statsRecords.length === 0) {
return false
}

// Use current period cost instead of total cost for accurate billing period tracking
const currentCost = Number.parseFloat(
statsRecords[0].currentPeriodCost?.toString() || statsRecords[0].totalCost.toString()
)

logger.info('Checking cost limit', { userId, currentCost, limit })

return currentCost >= limit
} catch (error) {
logger.error('Error checking cost limit', { error, userId })
return false // Be conservative in case of error
}
}

/**
* Check if sharing features are enabled for user
*/
// Removed unused feature flag helpers: isSharingEnabled, isMultiplayerEnabled, isWorkspaceCollaborationEnabled

/**
* Get comprehensive subscription state for a user
* Single function to get all subscription information
*/
export async function getUserSubscriptionState(userId: string): Promise<UserSubscriptionState> {
try {
// Get subscription and user stats in parallel to minimize DB calls
const [subscription, statsRecords] = await Promise.all([
getHighestPrioritySubscription(userId),
db.select().from(userStats).where(eq(userStats.userId, userId)).limit(1),
])

// Determine plan types based on subscription (avoid redundant DB calls)
const isPro =
!isBillingEnabled ||
!!(
subscription &&
(checkProPlan(subscription) ||
checkTeamPlan(subscription) ||
checkEnterprisePlan(subscription))
)
const isTeam =
!isBillingEnabled ||
!!(subscription && (checkTeamPlan(subscription) || checkEnterprisePlan(subscription)))
const isEnterprise = !isBillingEnabled || !!(subscription && checkEnterprisePlan(subscription))
const isFree = !isPro && !isTeam && !isEnterprise

// Determine plan name
let planName = 'free'
if (isEnterprise) planName = 'enterprise'
else if (isTeam) planName = 'team'
else if (isPro) planName = 'pro'

// Check cost limit using already-fetched user stats
let hasExceededLimit = false
if (isBillingEnabled && statsRecords.length > 0) {
let limit = getFreeTierLimit() // Default free tier limit
if (subscription) {
// Team/Enterprise: Use organization limit
if (isOrgPlan(subscription.plan)) {
limit = await getUserUsageLimit(userId)
} else {
// Pro/Free: Use individual limit
limit = getPerUserMinimumLimit(subscription)
}
}

const currentCost = Number.parseFloat(
statsRecords[0].currentPeriodCost?.toString() || statsRecords[0].totalCost.toString()
)
hasExceededLimit = currentCost >= limit
}

return {
isPro,
isTeam,
isEnterprise,
isFree,
highestPrioritySubscription: subscription,
hasExceededLimit,
planName,
}
} catch (error) {
logger.error('Error getting user subscription state', { error, userId })

// Return safe defaults in case of error
return {
isPro: false,
isTeam: false,
isEnterprise: false,
isFree: true,
highestPrioritySubscription: null,
hasExceededLimit: false,
planName: 'free',
}
}
}

/**
* Send welcome email for Pro and Team plan subscriptions
*/
Expand Down
1 change: 0 additions & 1 deletion apps/sim/lib/billing/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ export * from '@/lib/billing/core/organization'
export * from '@/lib/billing/core/subscription'
export {
getHighestPrioritySubscription as getActiveSubscription,
getUserSubscriptionState as getSubscriptionState,
hasAccessControlAccess,
hasCredentialSetsAccess,
hasPaidSubscription,
Expand Down
10 changes: 0 additions & 10 deletions apps/sim/lib/billing/types/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -73,16 +73,6 @@ export interface BillingData {
daysRemaining: number
}

export interface UserSubscriptionState {
isPro: boolean
isTeam: boolean
isEnterprise: boolean
isFree: boolean
highestPrioritySubscription: any | null
hasExceededLimit: boolean
planName: string
}

export interface SubscriptionPlan {
name: string
priceId: string
Expand Down
39 changes: 16 additions & 23 deletions apps/sim/lib/copilot/chat-payload.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,18 @@
*/
import { beforeEach, describe, expect, it, vi } from 'vitest'

vi.mock('@sim/logger', () => {
const createMockLogger = (): Record<string, any> => ({
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
withMetadata: vi.fn(() => createMockLogger()),
})
return { createLogger: vi.fn(() => createMockLogger()) }
})
const { mockGetHighestPrioritySubscription } = vi.hoisted(() => ({
mockGetHighestPrioritySubscription: vi.fn(),
}))

vi.mock('@/lib/billing/core/subscription', () => ({
getUserSubscriptionState: vi.fn(),
getHighestPrioritySubscription: mockGetHighestPrioritySubscription,
}))

vi.mock('@/lib/billing/plan-helpers', () => ({
isPaid: vi.fn(
(plan: string | null) => plan === 'pro' || plan === 'team' || plan === 'enterprise'
),
}))

vi.mock('@/lib/copilot/chat-context', () => ({
Expand Down Expand Up @@ -57,48 +57,41 @@ vi.mock('@/tools/params', () => ({
createUserToolSchema: vi.fn(() => ({ type: 'object', properties: {} })),
}))

import { getUserSubscriptionState } from '@/lib/billing/core/subscription'
import { buildIntegrationToolSchemas } from '@/lib/copilot/chat-payload'

const mockedGetUserSubscriptionState = getUserSubscriptionState as unknown as {
mockResolvedValue: (value: unknown) => void
mockRejectedValue: (value: unknown) => void
mockClear: () => void
}

describe('buildIntegrationToolSchemas', () => {
beforeEach(() => {
vi.clearAllMocks()
})

it('appends the email footer prompt for free users', async () => {
mockedGetUserSubscriptionState.mockResolvedValue({ isFree: true })
mockGetHighestPrioritySubscription.mockResolvedValue(null)

const toolSchemas = await buildIntegrationToolSchemas('user-free')
const gmailTool = toolSchemas.find((tool) => tool.name === 'gmail_send')

expect(getUserSubscriptionState).toHaveBeenCalledWith('user-free')
expect(mockGetHighestPrioritySubscription).toHaveBeenCalledWith('user-free')
expect(gmailTool?.description).toContain('sent with sim ai')
})

it('does not append the email footer prompt for paid users', async () => {
mockedGetUserSubscriptionState.mockResolvedValue({ isFree: false })
mockGetHighestPrioritySubscription.mockResolvedValue({ plan: 'pro', status: 'active' })

const toolSchemas = await buildIntegrationToolSchemas('user-paid')
const gmailTool = toolSchemas.find((tool) => tool.name === 'gmail_send')

expect(getUserSubscriptionState).toHaveBeenCalledWith('user-paid')
expect(mockGetHighestPrioritySubscription).toHaveBeenCalledWith('user-paid')
expect(gmailTool?.description).toBe('Send emails using Gmail')
})

it('still builds integration tools when subscription lookup fails', async () => {
mockedGetUserSubscriptionState.mockRejectedValue(new Error('db unavailable'))
mockGetHighestPrioritySubscription.mockRejectedValue(new Error('db unavailable'))

const toolSchemas = await buildIntegrationToolSchemas('user-error')
const gmailTool = toolSchemas.find((tool) => tool.name === 'gmail_send')
const brandfetchTool = toolSchemas.find((tool) => tool.name === 'brandfetch_search')

expect(getUserSubscriptionState).toHaveBeenCalledWith('user-error')
expect(mockGetHighestPrioritySubscription).toHaveBeenCalledWith('user-error')
expect(gmailTool?.description).toBe('Send emails using Gmail')
expect(brandfetchTool?.description).toBe('Search for brands by company name')
})
Expand Down
Loading
Loading