Browse Source

win-wasapi: Schedule work on real-time work queue

MS claims it can schedule audio better if we use their API.
jpark37 4 years ago
parent
commit
24d82062ae
1 changed files with 302 additions and 8 deletions
  1. 302 8
      plugins/win-wasapi/win-wasapi.cpp

+ 302 - 8
plugins/win-wasapi/win-wasapi.cpp

@@ -14,6 +14,7 @@
 #include <cinttypes>
 
 #include <avrt.h>
+#include <RTWorkQ.h>
 
 using namespace std;
 
@@ -25,6 +26,66 @@ static void GetWASAPIDefaults(obs_data_t *settings);
 #define OBS_KSAUDIO_SPEAKER_4POINT1 \
 	(KSAUDIO_SPEAKER_SURROUND | SPEAKER_LOW_FREQUENCY)
 
+typedef HRESULT(STDAPICALLTYPE *PFN_RtwqUnlockWorkQueue)(DWORD);
+typedef HRESULT(STDAPICALLTYPE *PFN_RtwqLockSharedWorkQueue)(PCWSTR usageClass,
+							     LONG basePriority,
+							     DWORD *taskId,
+							     DWORD *id);
+typedef HRESULT(STDAPICALLTYPE *PFN_RtwqCreateAsyncResult)(IUnknown *,
+							   IRtwqAsyncCallback *,
+							   IUnknown *,
+							   IRtwqAsyncResult **);
+typedef HRESULT(STDAPICALLTYPE *PFN_RtwqPutWorkItem)(DWORD, LONG,
+						     IRtwqAsyncResult *);
+typedef HRESULT(STDAPICALLTYPE *PFN_RtwqPutWaitingWorkItem)(HANDLE, LONG,
+							    IRtwqAsyncResult *,
+							    RTWQWORKITEM_KEY *);
+
+class ARtwqAsyncCallback : public IRtwqAsyncCallback {
+protected:
+	ARtwqAsyncCallback(void *source) : source(source) {}
+
+public:
+	STDMETHOD_(ULONG, AddRef)() { return ++refCount; }
+
+	STDMETHOD_(ULONG, Release)() { return --refCount; }
+
+	STDMETHOD(QueryInterface)(REFIID riid, void **ppvObject)
+	{
+		HRESULT hr = E_NOINTERFACE;
+
+		if (riid == __uuidof(IRtwqAsyncCallback) ||
+		    riid == __uuidof(IUnknown)) {
+			*ppvObject = this;
+			AddRef();
+			hr = S_OK;
+		} else {
+			*ppvObject = NULL;
+		}
+
+		return hr;
+	}
+
+	STDMETHOD(GetParameters)
+	(DWORD *pdwFlags, DWORD *pdwQueue)
+	{
+		*pdwFlags = 0;
+		*pdwQueue = queue_id;
+		return S_OK;
+	}
+
+	STDMETHOD(Invoke)
+	(IRtwqAsyncResult *) override = 0;
+
+	DWORD GetQueueId() const { return queue_id; }
+	void SetQueueId(DWORD id) { queue_id = id; }
+
+protected:
+	std::atomic<ULONG> refCount = 1;
+	void *source;
+	DWORD queue_id = 0;
+};
+
 class WASAPISource {
 	ComPtr<IMMNotificationClient> notify;
 	ComPtr<IMMDeviceEnumerator> enumerator;
@@ -35,6 +96,12 @@ class WASAPISource {
 	wstring default_id;
 	string device_id;
 	string device_name;
+	PFN_RtwqUnlockWorkQueue rtwq_unlock_work_queue = NULL;
+	PFN_RtwqLockSharedWorkQueue rtwq_lock_shared_work_queue = NULL;
+	PFN_RtwqCreateAsyncResult rtwq_create_async_result = NULL;
+	PFN_RtwqPutWorkItem rtwq_put_work_item = NULL;
+	PFN_RtwqPutWaitingWorkItem rtwq_put_waiting_work_item = NULL;
+	bool rtwq_supported = false;
 	uint64_t lastNotifyTime = 0;
 	bool isInputDevice;
 	std::atomic<bool> useDeviceTiming = false;
@@ -43,6 +110,55 @@ class WASAPISource {
 	bool previouslyFailed = false;
 	WinHandle reconnectThread;
 
+	class CallbackStartCapture : public ARtwqAsyncCallback {
+	public:
+		CallbackStartCapture(WASAPISource *source)
+			: ARtwqAsyncCallback(source)
+		{
+		}
+
+		STDMETHOD(Invoke)
+		(IRtwqAsyncResult *) override
+		{
+			((WASAPISource *)source)->OnStartCapture();
+			return S_OK;
+		}
+
+	} startCapture;
+	ComPtr<IRtwqAsyncResult> startCaptureAsyncResult;
+
+	class CallbackSampleReady : public ARtwqAsyncCallback {
+	public:
+		CallbackSampleReady(WASAPISource *source)
+			: ARtwqAsyncCallback(source)
+		{
+		}
+
+		STDMETHOD(Invoke)
+		(IRtwqAsyncResult *) override
+		{
+			((WASAPISource *)source)->OnSampleReady();
+			return S_OK;
+		}
+	} sampleReady;
+	ComPtr<IRtwqAsyncResult> sampleReadyAsyncResult;
+
+	class CallbackRestart : public ARtwqAsyncCallback {
+	public:
+		CallbackRestart(WASAPISource *source)
+			: ARtwqAsyncCallback(source)
+		{
+		}
+
+		STDMETHOD(Invoke)
+		(IRtwqAsyncResult *) override
+		{
+			((WASAPISource *)source)->OnRestart();
+			return S_OK;
+		}
+	} restart;
+	ComPtr<IRtwqAsyncResult> restartAsyncResult;
+
 	WinHandle captureThread;
 	WinHandle idleSignal;
 	WinHandle stopSignal;
@@ -94,6 +210,10 @@ public:
 	void Update(obs_data_t *settings);
 
 	void SetDefaultDevice(EDataFlow flow, ERole role, LPCWSTR id);
+
+	void OnStartCapture();
+	void OnSampleReady();
+	void OnRestart();
 };
 
 class WASAPINotify : public IMMNotificationClient {
@@ -149,7 +269,11 @@ public:
 
 WASAPISource::WASAPISource(obs_data_t *settings, obs_source_t *source_,
 			   bool input)
-	: source(source_), isInputDevice(input)
+	: source(source_),
+	  isInputDevice(input),
+	  startCapture(this),
+	  sampleReady(this),
+	  restart(this)
 {
 	UpdateSettings(settings);
 
@@ -200,11 +324,73 @@ WASAPISource::WASAPISource(obs_data_t *settings, obs_source_t *source_,
 	if (FAILED(hr))
 		throw HRError("Failed to register endpoint callback", hr);
 
-	captureThread = CreateThread(nullptr, 0, WASAPISource::CaptureThread,
-				     this, 0, nullptr);
-	if (!captureThread.Valid()) {
-		enumerator->UnregisterEndpointNotificationCallback(notify);
-		throw "Failed to create capture thread";
+	/* OBS will already load DLL on startup if it exists */
+	const HMODULE rtwq_module = GetModuleHandle(L"RTWorkQ.dll");
+	rtwq_supported = rtwq_module != NULL;
+	if (rtwq_supported) {
+		rtwq_unlock_work_queue =
+			(PFN_RtwqUnlockWorkQueue)GetProcAddress(
+				rtwq_module, "RtwqUnlockWorkQueue");
+		rtwq_lock_shared_work_queue =
+			(PFN_RtwqLockSharedWorkQueue)GetProcAddress(
+				rtwq_module, "RtwqLockSharedWorkQueue");
+		rtwq_create_async_result =
+			(PFN_RtwqCreateAsyncResult)GetProcAddress(
+				rtwq_module, "RtwqCreateAsyncResult");
+		rtwq_put_work_item = (PFN_RtwqPutWorkItem)GetProcAddress(
+			rtwq_module, "RtwqPutWorkItem");
+		rtwq_put_waiting_work_item =
+			(PFN_RtwqPutWaitingWorkItem)GetProcAddress(
+				rtwq_module, "RtwqPutWaitingWorkItem");
+
+		hr = rtwq_create_async_result(nullptr, &startCapture, nullptr,
+					      &startCaptureAsyncResult);
+		if (FAILED(hr)) {
+			enumerator->UnregisterEndpointNotificationCallback(
+				notify);
+			throw HRError(
+				"Could not create startCaptureAsyncResult", hr);
+		}
+
+		hr = rtwq_create_async_result(nullptr, &sampleReady, nullptr,
+					      &sampleReadyAsyncResult);
+		if (FAILED(hr)) {
+			enumerator->UnregisterEndpointNotificationCallback(
+				notify);
+			throw HRError("Could not create sampleReadyAsyncResult",
+				      hr);
+		}
+
+		hr = rtwq_create_async_result(nullptr, &restart, nullptr,
+					      &restartAsyncResult);
+		if (FAILED(hr)) {
+			enumerator->UnregisterEndpointNotificationCallback(
+				notify);
+			throw HRError("Could not create restartAsyncResult",
+				      hr);
+		}
+
+		DWORD taskId = 0;
+		DWORD id = 0;
+		hr = rtwq_lock_shared_work_queue(L"Capture", 0, &taskId, &id);
+		if (FAILED(hr)) {
+			enumerator->UnregisterEndpointNotificationCallback(
+				notify);
+			throw HRError("RtwqLockSharedWorkQueue failed", hr);
+		}
+
+		startCapture.SetQueueId(id);
+		sampleReady.SetQueueId(id);
+		restart.SetQueueId(id);
+	} else {
+		captureThread = CreateThread(nullptr, 0,
+					     WASAPISource::CaptureThread, this,
+					     0, nullptr);
+		if (!captureThread.Valid()) {
+			enumerator->UnregisterEndpointNotificationCallback(
+				notify);
+			throw "Failed to create capture thread";
+		}
 	}
 
 	Start();
@@ -212,7 +398,12 @@ WASAPISource::WASAPISource(obs_data_t *settings, obs_source_t *source_,
 
 void WASAPISource::Start()
 {
-	SetEvent(initSignal);
+	if (rtwq_supported) {
+		rtwq_put_work_item(startCapture.GetQueueId(), 0,
+				   startCaptureAsyncResult);
+	} else {
+		SetEvent(initSignal);
+	}
 }
 
 void WASAPISource::Stop()
@@ -221,13 +412,19 @@ void WASAPISource::Stop()
 
 	blog(LOG_INFO, "WASAPI: Device '%s' Terminated", device_name.c_str());
 
+	if (rtwq_supported)
+		SetEvent(receiveSignal);
+
 	WaitForSingleObject(idleSignal, INFINITE);
 
 	SetEvent(exitSignal);
 
 	WaitForSingleObject(reconnectThread, INFINITE);
 
-	WaitForSingleObject(captureThread, INFINITE);
+	if (rtwq_supported)
+		rtwq_unlock_work_queue(sampleReady.GetQueueId());
+	else
+		WaitForSingleObject(captureThread, INFINITE);
 }
 
 WASAPISource::~WASAPISource()
@@ -444,6 +641,24 @@ void WASAPISource::Initialize()
 	client = std::move(temp_client);
 	capture = std::move(temp_capture);
 
+	if (rtwq_supported) {
+		HRESULT hr = rtwq_put_waiting_work_item(
+			receiveSignal, 0, sampleReadyAsyncResult, nullptr);
+		if (FAILED(hr)) {
+			capture.Clear();
+			client.Clear();
+			throw HRError("RtwqPutWaitingWorkItem failed", hr);
+		}
+
+		hr = rtwq_put_waiting_work_item(restartSignal, 0,
+						restartAsyncResult, nullptr);
+		if (FAILED(hr)) {
+			capture.Clear();
+			client.Clear();
+			throw HRError("RtwqPutWaitingWorkItem failed", hr);
+		}
+	}
+
 	blog(LOG_INFO, "WASAPI: Device '%s' [%" PRIu32 " Hz] initialized",
 	     device_name.c_str(), sampleRate);
 }
@@ -724,6 +939,85 @@ void WASAPISource::SetDefaultDevice(EDataFlow flow, ERole role, LPCWSTR id)
 	SetEvent(restartSignal);
 }
 
+void WASAPISource::OnStartCapture()
+{
+	const DWORD ret = WaitForSingleObject(stopSignal, 0);
+	switch (ret) {
+	case WAIT_OBJECT_0:
+		SetEvent(idleSignal);
+		break;
+
+	default:
+		assert(ret == WAIT_TIMEOUT);
+
+		if (!TryInitialize()) {
+			blog(LOG_INFO, "WASAPI: Device '%s' failed to start",
+			     device_id.c_str());
+			reconnectDuration = RECONNECT_INTERVAL;
+			SetEvent(reconnectSignal);
+		}
+	}
+}
+
+void WASAPISource::OnSampleReady()
+{
+	bool stop = false;
+	bool reconnect = false;
+
+	if (!ProcessCaptureData()) {
+		stop = true;
+		reconnect = true;
+		reconnectDuration = RECONNECT_INTERVAL;
+	}
+
+	if (WaitForSingleObject(restartSignal, 0) == WAIT_OBJECT_0) {
+		stop = true;
+		reconnect = true;
+		reconnectDuration = 0;
+
+		ResetEvent(restartSignal);
+		rtwq_put_waiting_work_item(restartSignal, 0, restartAsyncResult,
+					   nullptr);
+	}
+
+	if (WaitForSingleObject(stopSignal, 0) == WAIT_OBJECT_0) {
+		stop = true;
+		reconnect = false;
+	}
+
+	if (!stop) {
+		if (FAILED(rtwq_put_waiting_work_item(receiveSignal, 0,
+						      sampleReadyAsyncResult,
+						      nullptr))) {
+			blog(LOG_ERROR,
+			     "Could not requeue sample receive work");
+			stop = true;
+			reconnect = true;
+			reconnectDuration = RECONNECT_INTERVAL;
+		}
+	}
+
+	if (stop) {
+		client->Stop();
+
+		capture.Clear();
+		client.Clear();
+
+		if (reconnect) {
+			blog(LOG_INFO, "Device '%s' invalidated.  Retrying",
+			     device_name.c_str());
+			SetEvent(reconnectSignal);
+		} else {
+			SetEvent(idleSignal);
+		}
+	}
+}
+
+void WASAPISource::OnRestart()
+{
+	SetEvent(receiveSignal);
+}
+
 /* ------------------------------------------------------------------------- */
 
 static const char *GetWASAPIInputName(void *)