|
@@ -18,10 +18,11 @@ import (
|
|
|
"github.com/gofrs/uuid"
|
|
|
)
|
|
|
|
|
|
-type Service[T any] struct {
|
|
|
- userMap map[[16]byte]T
|
|
|
- logger logger.Logger
|
|
|
- handler Handler
|
|
|
+type Service[T comparable] struct {
|
|
|
+ userMap map[[16]byte]T
|
|
|
+ userFlow map[T]string
|
|
|
+ logger logger.Logger
|
|
|
+ handler Handler
|
|
|
}
|
|
|
|
|
|
type Handler interface {
|
|
@@ -30,23 +31,26 @@ type Handler interface {
|
|
|
E.Handler
|
|
|
}
|
|
|
|
|
|
-func NewService[T any](logger logger.Logger, handler Handler) *Service[T] {
|
|
|
+func NewService[T comparable](logger logger.Logger, handler Handler) *Service[T] {
|
|
|
return &Service[T]{
|
|
|
logger: logger,
|
|
|
handler: handler,
|
|
|
}
|
|
|
}
|
|
|
|
|
|
-func (s *Service[T]) UpdateUsers(userList []T, userUUIDList []string) {
|
|
|
+func (s *Service[T]) UpdateUsers(userList []T, userUUIDList []string, userFlowList []string) {
|
|
|
userMap := make(map[[16]byte]T)
|
|
|
+ userFlowMap := make(map[T]string)
|
|
|
for i, userName := range userList {
|
|
|
userID := uuid.FromStringOrNil(userUUIDList[i])
|
|
|
if userID == uuid.Nil {
|
|
|
userID = uuid.NewV5(uuid.Nil, userUUIDList[i])
|
|
|
}
|
|
|
userMap[userID] = userName
|
|
|
+ userFlowMap[userName] = userFlowList[i]
|
|
|
}
|
|
|
s.userMap = userMap
|
|
|
+ s.userFlow = userFlowMap
|
|
|
}
|
|
|
|
|
|
var _ N.TCPConnectionHandler = (*Service[int])(nil)
|
|
@@ -63,8 +67,13 @@ func (s *Service[T]) NewConnection(ctx context.Context, conn net.Conn, metadata
|
|
|
ctx = auth.ContextWithUser(ctx, user)
|
|
|
metadata.Destination = request.Destination
|
|
|
|
|
|
+ userFlow := s.userFlow[user]
|
|
|
+ if request.Flow != userFlow {
|
|
|
+ return E.New("flow mismatch: expected ", userFlow, ", but got ", request.Flow)
|
|
|
+ }
|
|
|
+
|
|
|
protocolConn := conn
|
|
|
- switch request.Flow {
|
|
|
+ switch userFlow {
|
|
|
case "":
|
|
|
case FlowVision:
|
|
|
protocolConn, err = NewVisionConn(conn, request.UUID, s.logger)
|