Sfoglia il codice sorgente

cmd/stdiscosrv: Streamline context handling

Jakob Borg 2 anni fa
parent
commit
a80e6be353
1 ha cambiato i file con 10 aggiunte e 12 eliminazioni
  1. 10 12
      cmd/stdiscosrv/apisrv.go

+ 10 - 12
cmd/stdiscosrv/apisrv.go

@@ -111,8 +111,6 @@ func (s *apiSrv) Serve(_ context.Context) error {
 	return err
 }
 
-var topCtx = context.Background()
-
 func (s *apiSrv) handler(w http.ResponseWriter, req *http.Request) {
 	t0 := time.Now()
 
@@ -125,10 +123,10 @@ func (s *apiSrv) handler(w http.ResponseWriter, req *http.Request) {
 	}()
 
 	reqID := requestID(rand.Int63())
-	ctx := context.WithValue(topCtx, idKey, reqID)
+	req = req.WithContext(context.WithValue(req.Context(), idKey, reqID))
 
 	if debug {
-		log.Println(reqID, req.Method, req.URL)
+		log.Println(reqID, req.Method, req.URL, req.Proto)
 	}
 
 	remoteAddr := &net.TCPAddr{
@@ -154,17 +152,17 @@ func (s *apiSrv) handler(w http.ResponseWriter, req *http.Request) {
 	}
 
 	switch req.Method {
-	case "GET":
-		s.handleGET(ctx, lw, req)
-	case "POST":
-		s.handlePOST(ctx, remoteAddr, lw, req)
+	case http.MethodGet:
+		s.handleGET(lw, req)
+	case http.MethodPost:
+		s.handlePOST(remoteAddr, lw, req)
 	default:
 		http.Error(lw, "Method Not Allowed", http.StatusMethodNotAllowed)
 	}
 }
 
-func (s *apiSrv) handleGET(ctx context.Context, w http.ResponseWriter, req *http.Request) {
-	reqID := ctx.Value(idKey).(requestID)
+func (s *apiSrv) handleGET(w http.ResponseWriter, req *http.Request) {
+	reqID := req.Context().Value(idKey).(requestID)
 
 	deviceID, err := protocol.DeviceIDFromString(req.URL.Query().Get("device"))
 	if err != nil {
@@ -232,8 +230,8 @@ func (s *apiSrv) handleGET(ctx context.Context, w http.ResponseWriter, req *http
 	})
 }
 
-func (s *apiSrv) handlePOST(ctx context.Context, remoteAddr *net.TCPAddr, w http.ResponseWriter, req *http.Request) {
-	reqID := ctx.Value(idKey).(requestID)
+func (s *apiSrv) handlePOST(remoteAddr *net.TCPAddr, w http.ResponseWriter, req *http.Request) {
+	reqID := req.Context().Value(idKey).(requestID)
 
 	rawCert, err := certificateBytes(req)
 	if err != nil {