Bläddra i källkod

CI: Use distributed network for service uptime check

derrod 1 år sedan
förälder
incheckning
8c3dbd9b51

+ 10 - 2
.github/actions/services-validator/action.yaml

@@ -4,6 +4,12 @@ inputs:
   repositorySecret:
     description: GitHub token for API access
     required: true
+  checkApiSecret:
+    description: Token for server check API
+    required: false
+  checkApiServers:
+    description: Servers for the check API
+    required: false
   runSchemaChecks:
     description: Enable schema checking
     required: false
@@ -46,8 +52,8 @@ runs:
           eval "$(/home/linuxbrew/.linuxbrew/bin/brew shellenv)"
           echo "/home/linuxbrew/.linuxbrew/bin:/home/linuxbrew/.linuxbrew/sbin" >> $GITHUB_PATH
         fi
-        brew install --overwrite --quiet python3
-        python3 -m pip install jsonschema json_source_map requests
+        brew install --overwrite --quiet python@3.11
+        python3 -m pip install jsonschema json_source_map requests aiohttp
         echo ::endgroup::
 
     - name: Validate Services File JSON Schema 🕵️
@@ -91,6 +97,8 @@ runs:
         GITHUB_TOKEN: ${{ inputs.repositorySecret }}
         WORKFLOW_RUN_ID: ${{ github.run_id }}
         REPOSITORY: ${{ github.repository }}
+        API_KEY: ${{ inputs.checkApiSecret }}
+        API_SERVERS: ${{ inputs.checkApiServers }}
       run: |
         : Check for defunct services 📉
         python3 -u .github/scripts/utils.py/check-services.py

+ 76 - 85
.github/scripts/utils.py/check-services.py

@@ -1,15 +1,15 @@
+import asyncio
 import json
-import socket
-import ssl
 import os
 import time
-import requests
 import sys
 import zipfile
 
+import aiohttp
+import requests
+
 from io import BytesIO
-from random import randbytes
-from urllib.parse import urlparse
+from typing import List, Dict
 from collections import defaultdict
 
 MINIMUM_PURGE_AGE = 9.75 * 24 * 60 * 60  # slightly less than 10 days
@@ -56,74 +56,6 @@ GQL_QUERY = """{
   }
 }"""
 
-context = ssl.create_default_context()
-
-
-def check_ftl_server(hostname) -> bool:
-    """Check if hostname resolves to a valid address - FTL handshake not implemented"""
-    try:
-        socket.getaddrinfo(hostname, 8084, proto=socket.IPPROTO_UDP)
-    except socket.gaierror as e:
-        print(f"⚠️ Could not resolve hostname for server: {hostname} (Exception: {e})")
-        return False
-    else:
-        return True
-
-
-def check_hls_server(uri) -> bool:
-    """Check if URL responds with status code < 500 and not 404, indicating that at least there's *something* there"""
-    try:
-        r = requests.post(uri, timeout=TIMEOUT)
-        if r.status_code >= 500 or r.status_code == 404:
-            raise Exception(f"Server responded with {r.status_code}")
-    except Exception as e:
-        print(f"⚠️ Could not connect to HLS server: {uri} (Exception: {e})")
-        return False
-    else:
-        return True
-
-
-def check_rtmp_server(uri) -> bool:
-    """Try connecting and sending a RTMP handshake (with SSL if necessary)"""
-    parsed = urlparse(uri)
-    hostname, port = parsed.netloc.partition(":")[::2]
-
-    if port:
-        port = int(port)
-    elif parsed.scheme == "rtmps":
-        port = 443
-    else:
-        port = 1935
-
-    try:
-        recv = b""
-        with socket.create_connection((hostname, port), timeout=TIMEOUT) as sock:
-            # RTMP handshake is \x03 + 4 bytes time (can be 0) + 4 zero bytes + 1528 bytes random
-            handshake = b"\x03\x00\x00\x00\x00\x00\x00\x00\x00" + randbytes(1528)
-            if parsed.scheme == "rtmps":
-                with context.wrap_socket(sock, server_hostname=hostname) as ssock:
-                    ssock.sendall(handshake)
-                    while True:
-                        _tmp = ssock.recv(4096)
-                        recv += _tmp
-                        if len(recv) >= 1536 or not _tmp:
-                            break
-            else:
-                sock.sendall(handshake)
-                while True:
-                    _tmp = sock.recv(4096)
-                    recv += _tmp
-                    if len(recv) >= 1536 or not _tmp:
-                        break
-
-        if len(recv) < 1536 or recv[0] != 3:
-            raise ValueError("Invalid RTMP handshake received from server")
-    except Exception as e:
-        print(f"⚠️ Connection to server failed: {uri} (Exception: {e})")
-        return False
-    else:
-        return True
-
 
 def get_last_artifact():
     s = requests.session()
@@ -219,7 +151,18 @@ def set_output(name, value):
         print(f"Writing to github output files failed: {e!r}")
 
 
-def main():
+async def check_servers_task(
+    session: aiohttp.ClientSession, host: str, protocol: str, servers: List[str]
+) -> List[Dict]:
+    """Ask one check host to probe a list of ingest servers.
+
+    Sends one query entry (``url`` + ``protocol``) per server to the check
+    host's ``/test_remote_servers`` endpoint on port 8999 and returns the
+    decoded JSON reply.
+
+    The caller reads ``status`` and ``comment`` from each returned entry and
+    pairs entries with ``servers`` by position -- presumably the API answers
+    with one entry per queried server, in request order (TODO confirm against
+    the check API).
+
+    Raises:
+        aiohttp.ClientResponseError: if the check host answers with an HTTP
+            error status (>= 400).
+    """
+    query = [dict(url=h, protocol=protocol) for h in servers]
+
+    async with session.get(
+        f"http://{host}:8999/test_remote_servers", json=query
+    ) as resp:
+        # An error reply (e.g. rejected API key) may still carry a JSON body;
+        # fail here so it is never mistaken for a list of probe results.
+        resp.raise_for_status()
+        return await resp.json()
+
+
+async def process_services(session: aiohttp.ClientSession, check_servers: List[str]):
     try:
         with open(SERVICES_FILE, encoding="utf-8") as services_file:
             raw_services = services_file.read()
@@ -265,25 +208,47 @@ def main():
             continue
 
         service_type = service.get("recommended", {}).get("output", "rtmp_output")
-        if service_type not in {"rtmp_output", "ffmpeg_hls_muxer", "ftl_output"}:
+        if service_type not in {"rtmp_output", "ffmpeg_hls_muxer"}:
             print("Unknown service type:", service_type)
             new_services["services"].append(service)
             continue
 
+        protocol = "rtmp" if service_type == "rtmp_output" else "hls"
+
         # create a copy to mess with
         new_service = service.copy()
         new_service["servers"] = []
 
         # run checks for all the servers, and store results in timestamp cache
-        for server in service["servers"]:
-            if service_type == "ftl_output":
-                is_ok = check_ftl_server(server["url"])
-            elif service_type == "ffmpeg_hls_muxer":
-                is_ok = check_hls_server(server["url"])
-            else:  # rtmp
-                is_ok = check_rtmp_server(server["url"])
+        try:
+            servers = [s["url"] for s in service["servers"]]
+            tasks = []
+            for host in check_servers:
+                tasks.append(
+                    asyncio.create_task(
+                        check_servers_task(session, host, protocol, servers)
+                    )
+                )
+            results = await asyncio.gather(*tasks)
+        except Exception as e:
+            print(
+                f"❌ Querying server status for \"{service['name']}\" failed with: {e}"
+            )
+            return 1
+
+        # go over results
+        for server, result in zip(service["servers"], zip(*results)):
+            failure_count = sum(not res["status"] for res in result)
+            probe_count = len(result)
+            # only treat server as failed if all check servers reported a failure
+            is_ok = failure_count < probe_count
 
             if not is_ok:
+                failures = {res["comment"] for res in result if not res["status"]}
+                print(
+                    f"⚠️ Connecting to server failed: {server['url']} (Reason(s): {failures})"
+                )
+
                 if ts := fail_timestamps.get(server["url"], None):
                     if (delta := start_time - ts) >= MINIMUM_PURGE_AGE:
                         print(
@@ -313,9 +278,11 @@ def main():
         if not new_service["servers"]:
             print(f'💀 Service "{service["name"]}" has no valid servers left, removing!')
             affected_services[service["name"]] = f"Service removed"
-            continue
+        else:
+            new_services["services"].append(new_service)
 
-        new_services["services"].append(new_service)
+        # wait a bit between services
+        await asyncio.sleep(2.0)
 
     # write cache file
     try:
@@ -377,5 +344,29 @@ def main():
         set_output("make_pr", "false")
 
 
+async def main():
+    """Read the check API configuration and run the service checks.
+
+    Expects two environment variables:
+      * ``API_KEY``     -- sent as the ``Authorization`` header on every request
+      * ``API_SERVERS`` -- comma-separated list of check server hostnames
+
+    Returns 1 if the configuration is missing or malformed; otherwise returns
+    whatever ``process_services`` returns.
+    """
+    # check for environment variables
+    try:
+        api_key = os.environ["API_KEY"]
+        # "".split(",") yields [""], which is truthy, so blank entries have to
+        # be dropped explicitly for the emptiness check below to work. Strip
+        # whitespace too, since each entry is used verbatim as a hostname.
+        servers = [
+            s.strip() for s in os.environ["API_SERVERS"].split(",") if s.strip()
+        ]
+        if not servers:
+            raise ValueError("No checker servers!")
+        # Mask everything except the region code
+        for server in servers:
+            # everything up to and including the first dot
+            prefix = server[: server.index(".") + 1]
+            print(f"::add-mask::{prefix}")
+            # everything from the second dot onwards; a hostname with fewer
+            # than two dots raises ValueError and is reported below
+            suffix = server[server.index(".", len(prefix)) :]
+            print(f"::add-mask::{suffix}")
+    except Exception as e:
+        print(f"❌ Failed getting required environment variables: {e}")
+        return 1
+
+    # create aiohttp session
+    async with aiohttp.ClientSession() as session:
+        session.headers["Authorization"] = api_key
+        session.headers["User-Agent"] = "OBS Repo Service Checker/1.0"
+        return await process_services(session, servers)
+
+
 if __name__ == "__main__":
-    sys.exit(main())
+    sys.exit(asyncio.run(main()))

+ 2 - 0
.github/workflows/dispatch.yaml

@@ -45,6 +45,8 @@ jobs:
         uses: ./.github/actions/services-validator
         with:
           repositorySecret: ${{ secrets.GITHUB_TOKEN }}
+          checkApiSecret: ${{ secrets.CHECK_SERVERS_API_KEY }}
+          checkApiServers: ${{ secrets.CHECK_SERVERS_LIST }}
           runSchemaChecks: true
           runServiceChecks: true
           createPullRequest: true

+ 2 - 0
.github/workflows/scheduled.yaml

@@ -27,6 +27,8 @@ jobs:
         uses: ./.github/actions/services-validator
         with:
           repositorySecret: ${{ secrets.GITHUB_TOKEN }}
+          checkApiSecret: ${{ secrets.CHECK_SERVERS_API_KEY }}
+          checkApiServers: ${{ secrets.CHECK_SERVERS_LIST }}
           runSchemaChecks: false
           runServiceChecks: true
           createPullRequest: true