| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196 |
- package aws
- import (
- "encoding/json"
- "fmt"
- "net/http"
- "one-api/common"
- "one-api/dto"
- "one-api/relay/channel/claude"
- relaycommon "one-api/relay/common"
- "one-api/relay/helper"
- "one-api/types"
- "strings"
- "github.com/gin-gonic/gin"
- "github.com/pkg/errors"
- "github.com/aws/aws-sdk-go-v2/aws"
- "github.com/aws/aws-sdk-go-v2/credentials"
- "github.com/aws/aws-sdk-go-v2/service/bedrockruntime"
- bedrockruntimeTypes "github.com/aws/aws-sdk-go-v2/service/bedrockruntime/types"
- )
- func newAwsClient(c *gin.Context, info *relaycommon.RelayInfo) (*bedrockruntime.Client, error) {
- awsSecret := strings.Split(info.ApiKey, "|")
- if len(awsSecret) != 3 {
- return nil, errors.New("invalid aws secret key")
- }
- ak := awsSecret[0]
- sk := awsSecret[1]
- region := awsSecret[2]
- client := bedrockruntime.New(bedrockruntime.Options{
- Region: region,
- Credentials: aws.NewCredentialsCache(credentials.NewStaticCredentialsProvider(ak, sk, "")),
- })
- return client, nil
- }
- func wrapErr(err error) *dto.OpenAIErrorWithStatusCode {
- return &dto.OpenAIErrorWithStatusCode{
- StatusCode: http.StatusInternalServerError,
- Error: dto.OpenAIError{
- Message: fmt.Sprintf("%s", err.Error()),
- },
- }
- }
// awsRegionPrefix returns the part of an AWS region ID before its first '-'
// (e.g. "us" for "us-east-1"). An ID without '-' is returned unchanged and an
// empty ID yields "".
func awsRegionPrefix(awsRegionId string) string {
	prefix, _, _ := strings.Cut(awsRegionId, "-")
	return prefix
}
- func awsModelCanCrossRegion(awsModelId, awsRegionPrefix string) bool {
- regionSet, exists := awsModelCanCrossRegionMap[awsModelId]
- return exists && regionSet[awsRegionPrefix]
- }
- func awsModelCrossRegion(awsModelId, awsRegionPrefix string) string {
- modelPrefix, find := awsRegionCrossModelPrefixMap[awsRegionPrefix]
- if !find {
- return awsModelId
- }
- return modelPrefix + "." + awsModelId
- }
- func awsModelID(requestModel string) string {
- if awsModelID, ok := awsModelIDMap[requestModel]; ok {
- return awsModelID
- }
- return requestModel
- }
- func awsHandler(c *gin.Context, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
- awsCli, err := newAwsClient(c, info)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
- }
- awsModelId := awsModelID(c.GetString("request_model"))
- awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
- canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
- if canCrossRegion {
- awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
- }
- awsReq := &bedrockruntime.InvokeModelInput{
- ModelId: aws.String(awsModelId),
- Accept: aws.String("application/json"),
- ContentType: aws.String("application/json"),
- }
- claudeReq_, ok := c.Get("converted_request")
- if !ok {
- return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
- }
- claudeReq := claudeReq_.(*dto.ClaudeRequest)
- awsClaudeReq := copyRequest(claudeReq)
- awsReq.Body, err = json.Marshal(awsClaudeReq)
- if err != nil {
- return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
- }
- awsResp, err := awsCli.InvokeModel(c.Request.Context(), awsReq)
- if err != nil {
- return types.NewError(errors.Wrap(err, "InvokeModel"), types.ErrorCodeChannelAwsClientError), nil
- }
- claudeInfo := &claude.ClaudeResponseInfo{
- ResponseId: helper.GetResponseID(c),
- Created: common.GetTimestamp(),
- Model: info.UpstreamModelName,
- ResponseText: strings.Builder{},
- Usage: &dto.Usage{},
- }
- handlerErr := claude.HandleClaudeResponseData(c, info, claudeInfo, awsResp.Body, RequestModeMessage)
- if handlerErr != nil {
- return handlerErr, nil
- }
- return nil, claudeInfo.Usage
- }
- func awsStreamHandler(c *gin.Context, resp *http.Response, info *relaycommon.RelayInfo, requestMode int) (*types.NewAPIError, *dto.Usage) {
- awsCli, err := newAwsClient(c, info)
- if err != nil {
- return types.NewError(err, types.ErrorCodeChannelAwsClientError), nil
- }
- awsModelId := awsModelID(c.GetString("request_model"))
- awsRegionPrefix := awsRegionPrefix(awsCli.Options().Region)
- canCrossRegion := awsModelCanCrossRegion(awsModelId, awsRegionPrefix)
- if canCrossRegion {
- awsModelId = awsModelCrossRegion(awsModelId, awsRegionPrefix)
- }
- awsReq := &bedrockruntime.InvokeModelWithResponseStreamInput{
- ModelId: aws.String(awsModelId),
- Accept: aws.String("application/json"),
- ContentType: aws.String("application/json"),
- }
- claudeReq_, ok := c.Get("converted_request")
- if !ok {
- return types.NewError(errors.New("aws claude request not found"), types.ErrorCodeInvalidRequest), nil
- }
- claudeReq := claudeReq_.(*dto.ClaudeRequest)
- awsClaudeReq := copyRequest(claudeReq)
- awsReq.Body, err = json.Marshal(awsClaudeReq)
- if err != nil {
- return types.NewError(errors.Wrap(err, "marshal request"), types.ErrorCodeBadResponseBody), nil
- }
- awsResp, err := awsCli.InvokeModelWithResponseStream(c.Request.Context(), awsReq)
- if err != nil {
- return types.NewError(errors.Wrap(err, "InvokeModelWithResponseStream"), types.ErrorCodeChannelAwsClientError), nil
- }
- stream := awsResp.GetStream()
- defer stream.Close()
- claudeInfo := &claude.ClaudeResponseInfo{
- ResponseId: helper.GetResponseID(c),
- Created: common.GetTimestamp(),
- Model: info.UpstreamModelName,
- ResponseText: strings.Builder{},
- Usage: &dto.Usage{},
- }
- for event := range stream.Events() {
- switch v := event.(type) {
- case *bedrockruntimeTypes.ResponseStreamMemberChunk:
- info.SetFirstResponseTime()
- respErr := claude.HandleStreamResponseData(c, info, claudeInfo, string(v.Value.Bytes), RequestModeMessage)
- if respErr != nil {
- return respErr, nil
- }
- case *bedrockruntimeTypes.UnknownUnionMember:
- fmt.Println("unknown tag:", v.Tag)
- return types.NewError(errors.New("unknown response type"), types.ErrorCodeInvalidRequest), nil
- default:
- fmt.Println("union is nil or unknown type")
- return types.NewError(errors.New("nil or unknown response type"), types.ErrorCodeInvalidRequest), nil
- }
- }
- claude.HandleStreamFinalResponse(c, info, claudeInfo, RequestModeMessage)
- return nil, claudeInfo.Usage
- }
|