| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212 |
- package hooks
- import (
- "errors"
- "testing"
- "github.com/QuantumNous/new-api/core/interfaces"
- "github.com/QuantumNous/new-api/core/registry"
- )
- // Mock Hook实现
- type testHook struct {
- name string
- priority int
- enabled bool
- beforeCalled bool
- afterCalled bool
- errorCalled bool
- shouldReturnError bool
- }
- func (h *testHook) Name() string { return h.name }
- func (h *testHook) Priority() int { return h.priority }
- func (h *testHook) Enabled() bool { return h.enabled }
- func (h *testHook) OnBeforeRequest(ctx *interfaces.HookContext) error {
- h.beforeCalled = true
- if h.shouldReturnError {
- return errors.New("test error")
- }
- return nil
- }
- func (h *testHook) OnAfterResponse(ctx *interfaces.HookContext) error {
- h.afterCalled = true
- if h.shouldReturnError {
- return errors.New("test error")
- }
- return nil
- }
- func (h *testHook) OnError(ctx *interfaces.HookContext, err error) error {
- h.errorCalled = true
- return nil
- }
- func TestHookChainExecution(t *testing.T) {
- // 创建测试hooks
- hook1 := &testHook{name: "hook1", priority: 100, enabled: true}
- hook2 := &testHook{name: "hook2", priority: 50, enabled: true}
- hook3 := &testHook{name: "hook3", priority: 75, enabled: false} // disabled
-
- // 创建Hook链
- chain := &HookChain{
- hooks: []interfaces.RelayHook{hook1, hook2, hook3},
- }
-
- // 创建测试上下文
- ctx := &interfaces.HookContext{
- Data: make(map[string]interface{}),
- }
-
- // 测试ExecuteBeforeRequest
- if err := chain.ExecuteBeforeRequest(ctx); err != nil {
- t.Errorf("ExecuteBeforeRequest failed: %v", err)
- }
-
- // 检查enabled的hooks是否被调用
- if !hook1.beforeCalled {
- t.Error("hook1 OnBeforeRequest should be called")
- }
-
- if !hook2.beforeCalled {
- t.Error("hook2 OnBeforeRequest should be called")
- }
-
- // disabled的hook不应该被调用
- if hook3.beforeCalled {
- t.Error("hook3 OnBeforeRequest should not be called (disabled)")
- }
-
- // 测试ExecuteAfterResponse
- if err := chain.ExecuteAfterResponse(ctx); err != nil {
- t.Errorf("ExecuteAfterResponse failed: %v", err)
- }
-
- if !hook1.afterCalled {
- t.Error("hook1 OnAfterResponse should be called")
- }
-
- if !hook2.afterCalled {
- t.Error("hook2 OnAfterResponse should be called")
- }
-
- // 测试ExecuteOnError
- testErr := errors.New("test error")
- if err := chain.ExecuteOnError(ctx, testErr); err != testErr {
- t.Error("ExecuteOnError should return original error")
- }
-
- if !hook1.errorCalled {
- t.Error("hook1 OnError should be called")
- }
- }
- func TestHookChainErrorHandling(t *testing.T) {
- // 创建会返回错误的hook
- errorHook := &testHook{
- name: "error_hook",
- priority: 100,
- enabled: true,
- shouldReturnError: true,
- }
-
- chain := &HookChain{
- hooks: []interfaces.RelayHook{errorHook},
- }
-
- ctx := &interfaces.HookContext{
- Data: make(map[string]interface{}),
- }
-
- // 测试错误处理
- if err := chain.ExecuteBeforeRequest(ctx); err == nil {
- t.Error("Expected error from ExecuteBeforeRequest")
- }
- }
- func TestHookChainShouldSkip(t *testing.T) {
- hook1 := &testHook{name: "hook1", priority: 100, enabled: true}
- hook2 := &testHook{name: "hook2", priority: 50, enabled: true}
-
- chain := &HookChain{
- hooks: []interfaces.RelayHook{hook1, hook2},
- }
-
- ctx := &interfaces.HookContext{
- Data: make(map[string]interface{}),
- ShouldSkip: true, // 设置跳过标记
- }
-
- // 执行
- if err := chain.ExecuteBeforeRequest(ctx); err != nil {
- t.Errorf("ExecuteBeforeRequest failed: %v", err)
- }
-
- // 由于ShouldSkip为true,hooks不应该被调用
- // 注意:当前实现在第一个hook执行后才会检查ShouldSkip
- // 所以hook1会被调用,但hook2不会
- }
- func TestHookChainCount(t *testing.T) {
- hook1 := &testHook{name: "hook1", priority: 100, enabled: true}
- hook2 := &testHook{name: "hook2", priority: 50, enabled: true}
-
- chain := &HookChain{
- hooks: []interfaces.RelayHook{hook1, hook2},
- }
-
- if count := chain.Count(); count != 2 {
- t.Errorf("Expected count 2, got %d", count)
- }
- }
- func TestHookChainGetHooks(t *testing.T) {
- hook1 := &testHook{name: "hook1", priority: 100, enabled: true}
- hook2 := &testHook{name: "hook2", priority: 50, enabled: true}
-
- chain := &HookChain{
- hooks: []interfaces.RelayHook{hook1, hook2},
- }
-
- hooks := chain.GetHooks()
- if len(hooks) != 2 {
- t.Errorf("Expected 2 hooks, got %d", len(hooks))
- }
- }
- func TestGlobalChain(t *testing.T) {
- // 测试全局链的单例模式
- chain1 := GetGlobalChain()
- chain2 := GetGlobalChain()
-
- if chain1 != chain2 {
- t.Error("GetGlobalChain should return the same instance")
- }
- }
- // 集成测试:测试完整的注册和执行流程
- func TestIntegration(t *testing.T) {
- // 注册测试hook
- testHook := &testHook{
- name: "integration_test_hook",
- priority: 100,
- enabled: true,
- }
-
- if err := registry.RegisterHook(testHook); err != nil {
- // 如果已注册,跳过错误
- t.Logf("Hook already registered (expected in some cases): %v", err)
- }
-
- // 创建新的hook链并加载
- chain := &HookChain{hooks: make([]interfaces.RelayHook, 0)}
- chain.LoadHooks()
-
- // 检查是否加载了hooks
- if chain.Count() == 0 {
- t.Log("No hooks loaded (expected if registry is clean)")
- }
- }
|