|
|
@@ -21,6 +21,7 @@ import { oaCompatHelper } from "./provider/openai-compatible"
|
|
|
import { createRateLimiter } from "./rateLimiter"
|
|
|
import { createDataDumper } from "./dataDumper"
|
|
|
import { createTrialLimiter } from "./trialLimiter"
|
|
|
+import { createStickyTracker } from "./stickyProviderTracker"
|
|
|
|
|
|
type ZenData = Awaited<ReturnType<typeof ZenData.list>>
|
|
|
type RetryOptions = {
|
|
|
@@ -68,9 +69,11 @@ export async function handler(
|
|
|
const isTrial = await trialLimiter?.isTrial()
|
|
|
const rateLimiter = createRateLimiter(modelInfo.id, modelInfo.rateLimit, ip)
|
|
|
await rateLimiter?.check()
|
|
|
+ const stickyTracker = createStickyTracker(modelInfo.stickyProvider ?? false, sessionId)
|
|
|
+ const stickyProvider = await stickyTracker?.get()
|
|
|
|
|
|
const retriableRequest = async (retry: RetryOptions = { excludeProviders: [], retryCount: 0 }) => {
|
|
|
- const providerInfo = selectProvider(zenData, modelInfo, sessionId, isTrial ?? false, retry)
|
|
|
+ const providerInfo = selectProvider(zenData, modelInfo, sessionId, isTrial ?? false, retry, stickyProvider)
|
|
|
const authInfo = await authenticate(modelInfo, providerInfo)
|
|
|
validateBilling(authInfo, modelInfo)
|
|
|
validateModelSettings(authInfo)
|
|
|
@@ -121,6 +124,9 @@ export async function handler(
|
|
|
dataDumper?.provideModel(providerInfo.storeModel)
|
|
|
dataDumper?.provideRequest(reqBody)
|
|
|
|
|
|
+ // Store sticky provider
|
|
|
+ await stickyTracker?.set(providerInfo.id)
|
|
|
+
|
|
|
// Scrub response headers
|
|
|
const resHeaders = new Headers()
|
|
|
const keepHeaders = ["content-type", "cache-control"]
|
|
|
@@ -289,12 +295,18 @@ export async function handler(
|
|
|
sessionId: string,
|
|
|
isTrial: boolean,
|
|
|
retry: RetryOptions,
|
|
|
+ stickyProvider: string | undefined,
|
|
|
) {
|
|
|
const provider = (() => {
|
|
|
if (isTrial) {
|
|
|
return modelInfo.providers.find((provider) => provider.id === modelInfo.trial!.provider)
|
|
|
}
|
|
|
|
|
|
+ if (stickyProvider) {
|
|
|
+ const provider = modelInfo.providers.find((provider) => provider.id === stickyProvider)
|
|
|
+ if (provider) return provider
|
|
|
+ }
|
|
|
+
|
|
|
if (retry.retryCount === MAX_RETRIES) {
|
|
|
return modelInfo.providers.find((provider) => provider.id === modelInfo.fallbackProvider)
|
|
|
}
|