|
|
@@ -73,8 +73,16 @@ export async function handler(
|
|
|
const stickyProvider = await stickyTracker?.get()
|
|
|
|
|
|
const retriableRequest = async (retry: RetryOptions = { excludeProviders: [], retryCount: 0 }) => {
|
|
|
- const providerInfo = selectProvider(zenData, modelInfo, sessionId, isTrial ?? false, retry, stickyProvider)
|
|
|
- const authInfo = await authenticate(modelInfo, providerInfo)
|
|
|
+ const authInfo = await authenticate(modelInfo)
|
|
|
+ const providerInfo = selectProvider(
|
|
|
+ zenData,
|
|
|
+ authInfo,
|
|
|
+ modelInfo,
|
|
|
+ sessionId,
|
|
|
+ isTrial ?? false,
|
|
|
+ retry,
|
|
|
+ stickyProvider,
|
|
|
+ )
|
|
|
validateBilling(authInfo, modelInfo)
|
|
|
validateModelSettings(authInfo)
|
|
|
updateProviderKey(authInfo, providerInfo)
|
|
|
@@ -291,6 +299,7 @@ export async function handler(
|
|
|
|
|
|
function selectProvider(
|
|
|
zenData: ZenData,
|
|
|
+ authInfo: AuthInfo,
|
|
|
modelInfo: ModelInfo,
|
|
|
sessionId: string,
|
|
|
isTrial: boolean,
|
|
|
@@ -298,6 +307,10 @@ export async function handler(
|
|
|
stickyProvider: string | undefined,
|
|
|
) {
|
|
|
const provider = (() => {
|
|
|
+ if (authInfo?.provider?.credentials) {
|
|
|
+ return modelInfo.providers.find((provider) => provider.id === modelInfo.byokProvider)
|
|
|
+ }
|
|
|
+
|
|
|
if (isTrial) {
|
|
|
return modelInfo.providers.find((provider) => provider.id === modelInfo.trial!.provider)
|
|
|
}
|
|
|
@@ -342,15 +355,15 @@ export async function handler(
|
|
|
}
|
|
|
}
|
|
|
|
|
|
- async function authenticate(modelInfo: ModelInfo, providerInfo: ProviderInfo) {
|
|
|
+ async function authenticate(modelInfo: ModelInfo) {
|
|
|
const apiKey = opts.parseApiKey(input.request.headers)
|
|
|
if (!apiKey || apiKey === "public") {
|
|
|
if (modelInfo.allowAnonymous) return
|
|
|
throw new AuthError("Missing API key.")
|
|
|
}
|
|
|
|
|
|
- const data = await Database.use((tx) =>
|
|
|
- tx
|
|
|
+ const data = await Database.use((tx) => {
|
|
|
+ const query = tx
|
|
|
.select({
|
|
|
apiKey: KeyTable.id,
|
|
|
workspaceID: KeyTable.workspaceID,
|
|
|
@@ -378,13 +391,15 @@ export async function handler(
|
|
|
.innerJoin(BillingTable, eq(BillingTable.workspaceID, KeyTable.workspaceID))
|
|
|
.innerJoin(UserTable, and(eq(UserTable.workspaceID, KeyTable.workspaceID), eq(UserTable.id, KeyTable.userID)))
|
|
|
.leftJoin(ModelTable, and(eq(ModelTable.workspaceID, KeyTable.workspaceID), eq(ModelTable.model, modelInfo.id)))
|
|
|
- .leftJoin(
|
|
|
+
|
|
|
+ if (modelInfo.byokProvider) {
|
|
|
+ query.leftJoin(
|
|
|
ProviderTable,
|
|
|
- and(eq(ProviderTable.workspaceID, KeyTable.workspaceID), eq(ProviderTable.provider, providerInfo.id)),
|
|
|
+ and(eq(ProviderTable.workspaceID, KeyTable.workspaceID), eq(ProviderTable.provider, modelInfo.byokProvider)),
|
|
|
)
|
|
|
- .where(and(eq(KeyTable.key, apiKey), isNull(KeyTable.timeDeleted)))
|
|
|
- .then((rows) => rows[0]),
|
|
|
- )
|
|
|
+ }
|
|
|
+ return query.where(and(eq(KeyTable.key, apiKey), isNull(KeyTable.timeDeleted))).then((rows) => rows[0])
|
|
|
+ })
|
|
|
|
|
|
if (!data) throw new AuthError("Invalid API key.")
|
|
|
logger.metric({
|
|
|
@@ -457,8 +472,7 @@ export async function handler(
|
|
|
}
|
|
|
|
|
|
function updateProviderKey(authInfo: AuthInfo, providerInfo: ProviderInfo) {
|
|
|
- if (!authInfo) return
|
|
|
- if (!authInfo.provider?.credentials) return
|
|
|
+ if (!authInfo?.provider?.credentials) return
|
|
|
providerInfo.apiKey = authInfo.provider.credentials
|
|
|
}
|
|
|
|