Skip to content

Commit 9f6cec4

Browse files
committed
feat: add input token and quota calculation for API request pre-checks
1 parent 4d841c4 commit 9f6cec4

File tree

6 files changed

+24
-16
lines changed

6 files changed

+24
-16
lines changed

addition/generation/api.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,9 @@ import (
55
"chat/globals"
66
"chat/utils"
77
"fmt"
8-
"github.com/gin-gonic/gin"
98
"strings"
9+
10+
"github.com/gin-gonic/gin"
1011
)
1112

1213
type WebsocketGenerationForm struct {
@@ -53,7 +54,7 @@ func GenerateAPI(c *gin.Context) {
5354
return
5455
}
5556

56-
check, plan := auth.CanEnableModelWithSubscription(db, cache, user, form.Model)
57+
check, plan := auth.CanEnableModelWithSubscription(db, cache, user, form.Model, []globals.Message{})
5758
if check != nil {
5859
conn.Send(globals.GenerationSegmentResponse{
5960
Message: check.Error(),

auth/rule.go

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,22 @@ import (
44
"chat/channel"
55
"database/sql"
66
"fmt"
7+
8+
"chat/globals"
9+
"chat/utils"
10+
711
"github.com/go-redis/redis/v8"
812
)
913

1014
const (
1115
ErrNotAuthenticated = "not authenticated error (model: %s)"
1216
ErrNotSetPrice = "the price of the model is not set (model: %s)"
1317
ErrNotEnoughQuota = "user quota is not enough error (model: %s, minimum quota: %0.2f, your quota: %0.2f)"
18+
ErrEstimatedCost = "estimated cost exceeds user quota (model: %s, estimated cost: %0.2f, your quota: %0.2f)"
1419
)
1520

1621
// CanEnableModel returns whether the model can be enabled (without subscription)
17-
func CanEnableModel(db *sql.DB, user *User, model string) error {
22+
func CanEnableModel(db *sql.DB, user *User, model string, messages []globals.Message) error {
1823
isAuth := user != nil
1924
isAdmin := isAuth && user.IsAdmin(db)
2025

@@ -37,21 +42,23 @@ func CanEnableModel(db *sql.DB, user *User, model string) error {
3742
return fmt.Errorf(ErrNotAuthenticated, model)
3843
}
3944

40-
// return if the user is authenticated and has enough quota
41-
limit := charge.GetLimit()
45+
// Calculate estimated input cost
46+
inputTokens := utils.NumTokensFromMessages(messages, model, false)
47+
estimatedInputCost := float32(inputTokens) / 1000 * charge.GetInput()
4248

49+
// Get user's current quota
4350
quota := user.GetQuota(db)
44-
if quota < limit {
45-
return fmt.Errorf(ErrNotEnoughQuota, model, limit, quota)
51+
if quota < estimatedInputCost {
52+
return fmt.Errorf(ErrEstimatedCost, model, estimatedInputCost, quota)
4653
}
4754

4855
return nil
4956
}
5057

51-
func CanEnableModelWithSubscription(db *sql.DB, cache *redis.Client, user *User, model string) (canEnable error, usePlan bool) {
58+
func CanEnableModelWithSubscription(db *sql.DB, cache *redis.Client, user *User, model string, messages []globals.Message) (canEnable error, usePlan bool) {
5259
// use subscription quota first
5360
if user != nil && HandleSubscriptionUsage(db, cache, user, model) {
5461
return nil, true
5562
}
56-
return CanEnableModel(db, user, model), false
63+
return CanEnableModel(db, user, model, messages), false
5764
}

manager/chat.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ func ChatHandler(conn *Connection, user *auth.User, instance *conversation.Conve
202202
model := instance.GetModel()
203203
segment := adapter.ClearMessages(model, web.ToChatSearched(instance, restart))
204204

205-
check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model)
205+
check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model, segment)
206206
conn.Send(globals.ChatSegmentResponse{
207207
Conversation: instance.GetId(),
208208
})

manager/chat_completions.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,12 @@ func supportRelayPlan() bool {
2828
return channel.SystemInstance.SupportRelayPlan()
2929
}
3030

31-
func checkEnableState(db *sql.DB, cache *redis.Client, user *auth.User, model string) (state error, plan bool) {
31+
func checkEnableState(db *sql.DB, cache *redis.Client, user *auth.User, model string, messages []globals.Message) (state error, plan bool) {
3232
if supportRelayPlan() {
33-
return auth.CanEnableModelWithSubscription(db, cache, user, model)
33+
return auth.CanEnableModelWithSubscription(db, cache, user, model, messages)
3434
}
3535

36-
return auth.CanEnableModel(db, user, model), false
36+
return auth.CanEnableModel(db, user, model, messages), false
3737
}
3838

3939
func ChatRelayAPI(c *gin.Context) {
@@ -80,7 +80,7 @@ func ChatRelayAPI(c *gin.Context) {
8080
form.Official = true
8181
}
8282

83-
check, plan := checkEnableState(db, cache, user, form.Model)
83+
check, plan := checkEnableState(db, cache, user, form.Model, messages)
8484
if check != nil {
8585
sendErrorResponse(c, check, "quota_exceeded_error")
8686
return

manager/completions.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ func NativeChatHandler(c *gin.Context, user *auth.User, model string, message []
2828

2929
db := utils.GetDBFromContext(c)
3030
cache := utils.GetCacheFromContext(c)
31-
check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model)
31+
check, plan := auth.CanEnableModelWithSubscription(db, cache, user, model, segment)
3232

3333
if check != nil {
3434
return check.Error(), 0

manager/images.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ func ImagesRelayAPI(c *gin.Context) {
5454
form.Model = strings.TrimSuffix(form.Model, "-official")
5555
}
5656

57-
check := auth.CanEnableModel(db, user, form.Model)
57+
check := auth.CanEnableModel(db, user, form.Model, []globals.Message{})
5858
if check != nil {
5959
sendErrorResponse(c, check, "quota_exceeded_error")
6060
return

0 commit comments

Comments
 (0)