Browse Source

frontend-plugins: Abstract captions

Currently the captioning code is a bit intertwined with the UI, and the
captioning is hard-coded towards microsoft speech API.

This patch abstracts captions to allow other APIs to be implemented
later.
jp9000 8 years ago
parent
commit
aa70704700

+ 4 - 0
UI/frontend-plugins/frontend-tools/CMakeLists.txt

@@ -45,9 +45,13 @@ if(WIN32)
 		set(frontend-tools_PLATFORM_SOURCES
 			${frontend-tools_PLATFORM_SOURCES}
 			captions.cpp
+			captions-handler.cpp
+			captions-mssapi.cpp
 			captions-mssapi-stream.cpp)
 		set(frontend-tools_PLATFORM_HEADERS
 			captions.hpp
+			captions-handler.hpp
+			captions-mssapi.hpp
 			captions-mssapi-stream.hpp)
 		set(frontend-tools_PLATFORM_UI
 			forms/captions.ui)

+ 54 - 0
UI/frontend-plugins/frontend-tools/captions-handler.cpp

@@ -0,0 +1,54 @@
+#include "captions-handler.hpp"
+
+captions_handler::captions_handler(
+		captions_cb       callback,
+		enum audio_format format,
+		uint32_t          sample_rate)
+	: cb(callback)
+{
+	if (!reset_resampler(format, sample_rate))
+		throw CAPTIONS_ERROR_GENERIC_FAIL;
+}
+
+bool captions_handler::reset_resampler(
+		enum audio_format format,
+		uint32_t sample_rate)
+try {
+	obs_audio_info ai;
+	if (!obs_get_audio_info(&ai))
+		throw std::string("Failed to get OBS audio info");
+
+	resample_info src = {
+		ai.samples_per_sec,
+		AUDIO_FORMAT_FLOAT_PLANAR,
+		ai.speakers
+	};
+	resample_info dst = {
+		sample_rate,
+		format,
+		SPEAKERS_MONO
+	};
+
+	if (!resampler.reset(dst, src))
+		throw std::string("Failed to create audio resampler");
+
+	return true;
+
+} catch (std::string text) {
+	blog(LOG_WARNING, "%s: %s", __FUNCTION__, text.c_str());
+	return false;
+}
+
+void captions_handler::push_audio(const audio_data *audio)
+{
+	uint8_t *out[MAX_AV_PLANES];
+	uint32_t frames;
+	uint64_t ts_offset;
+	bool success;
+
+	success = audio_resampler_resample(resampler,
+			out, &frames, &ts_offset,
+			(const uint8_t *const *)audio->data, audio->frames);
+	if (success)
+		pcm_data(out[0], frames);
+}

+ 67 - 0
UI/frontend-plugins/frontend-tools/captions-handler.hpp

@@ -0,0 +1,67 @@
+#pragma once
+
+#include <media-io/audio-resampler.h>
+#include <obs-module.h>
+#include <functional>
+#include <string>
+
+class resampler_obj {
+	audio_resampler_t *resampler = nullptr;
+
+public:
+	inline ~resampler_obj()
+	{
+		audio_resampler_destroy(resampler);
+	}
+
+	inline bool reset(const resample_info &dst, const resample_info &src)
+	{
+		audio_resampler_destroy(resampler);
+		resampler = audio_resampler_create(&dst, &src);
+		return !!resampler;
+	}
+
+	inline operator audio_resampler_t*() {return resampler;}
+};
+
+/* ------------------------------------------------------------------------- */
+
+typedef std::function<void (const std::string &)> captions_cb;
+
+#define captions_error(s) std::string(obs_module_text("Captions.Error." ## s))
+#define CAPTIONS_ERROR_GENERIC_FAIL     captions_error("GenericFail")
+
+/* ------------------------------------------------------------------------- */
+
+class captions_handler {
+	captions_cb cb;
+	resampler_obj resampler;
+
+protected:
+	inline void callback(const std::string &text)
+	{
+		cb(text);
+	}
+
+	virtual void pcm_data(const void *data, size_t frames)=0;
+
+	/* always resamples to 1 channel */
+	bool reset_resampler(enum audio_format format, uint32_t sample_rate);
+
+public:
+	/* throw std::string for errors shown to users */
+	captions_handler(
+			captions_cb       callback,
+			enum audio_format format,
+			uint32_t          sample_rate);
+	virtual ~captions_handler() {}
+
+	void push_audio(const audio_data *audio);
+};
+
+/* ------------------------------------------------------------------------- */
+
+struct captions_handler_info {
+	std::string       (*name)(void);
+	captions_handler *(*create)(captions_cb cb, const std::string &lang);
+};

+ 12 - 22
UI/frontend-plugins/frontend-tools/captions-mssapi-stream.cpp

@@ -1,4 +1,5 @@
 #include "captions-mssapi-stream.hpp"
+#include "captions-mssapi.hpp"
 #include <mmreg.h>
 #include <util/windows/CoTaskMemPtr.hpp>
 #include <util/threading.h>
@@ -13,7 +14,8 @@ using namespace std;
 #define debugfunc(format, ...)
 #endif
 
-CaptionStream::CaptionStream(DWORD samplerate_) :
+CaptionStream::CaptionStream(DWORD samplerate_, mssapi_captions *handler_) :
+	handler(handler_),
 	samplerate(samplerate_),
 	event(CreateEvent(nullptr, false, false, nullptr))
 {
@@ -28,8 +30,6 @@ CaptionStream::CaptionStream(DWORD samplerate_) :
 	format.nBlockAlign = 2;
 	format.wBitsPerSample = 16;
 	format.cbSize = sizeof(format);
-
-	resampler.Reset(&format);
 }
 
 void CaptionStream::Stop()
@@ -42,28 +42,16 @@ void CaptionStream::Stop()
 	cv.notify_one();
 }
 
-void CaptionStream::PushAudio(const struct audio_data *data, bool muted)
+void CaptionStream::PushAudio(const void *data, size_t frames)
 {
-	uint8_t *output[MAX_AV_PLANES] = {};
-	uint32_t frames = data->frames;
-	uint64_t ts_offset;
 	bool ready = false;
 
-	audio_resampler_resample(resampler, output, &frames, &ts_offset,
-			data->data, data->frames);
-
-	if (output[0]) {
-		if (muted)
-			memset(output[0], 0, frames * sizeof(int16_t));
-
-		lock_guard<mutex> lock(m);
-		circlebuf_push_back(buf, output[0], frames * sizeof(int16_t));
-		write_pos += frames * sizeof(int16_t);
-
-		if (wait_size && buf->size >= wait_size)
-			ready = true;
-	}
+	lock_guard<mutex> lock(m);
+	circlebuf_push_back(buf, data, frames * sizeof(int16_t));
+	write_pos += frames * sizeof(int16_t);
 
+	if (wait_size && buf->size >= wait_size)
+		ready = true;
 	if (ready)
 		cv.notify_one();
 }
@@ -316,7 +304,9 @@ STDMETHODIMP CaptionStream::SetFormat(REFGUID guid_ref,
 	if (guid_ref == SPDFID_WaveFormatEx) {
 		lock_guard<mutex> lock(m);
 		memcpy(&format, wfex, sizeof(format));
-		resampler.Reset(wfex);
+		if (!handler->reset_resampler(AUDIO_FORMAT_16BIT,
+				wfex->nSamplesPerSec))
+			return E_FAIL;
 
 		/* 50 msec */
 		DWORD size = format.nSamplesPerSec / 20;

+ 6 - 31
UI/frontend-plugins/frontend-tools/captions-mssapi-stream.hpp

@@ -1,10 +1,11 @@
+#pragma once
+
 #include <windows.h>
 #include <sapi.h>
 #include <condition_variable>
 #include <mutex>
 #include <vector>
 #include <obs.h>
-#include <media-io/audio-resampler.h>
 #include <util/circlebuf.h>
 #include <util/windows/WinHandle.hpp>
 
@@ -18,37 +19,12 @@ public:
 	inline circlebuf *operator->() {return &buf;}
 };
 
-class Resampler {
-	audio_resampler_t *resampler = nullptr;
-
-public:
-	inline void Reset(const WAVEFORMATEX *wfex)
-	{
-		const struct audio_output_info *aoi =
-			audio_output_get_info(obs_get_audio());
-
-		struct resample_info src;
-		src.samples_per_sec = aoi->samples_per_sec;
-		src.format = aoi->format;
-		src.speakers = aoi->speakers;
-
-		struct resample_info dst;
-		dst.samples_per_sec = uint32_t(wfex->nSamplesPerSec);
-		dst.format = AUDIO_FORMAT_16BIT;
-		dst.speakers = (enum speaker_layout)wfex->nChannels;
-
-		if (resampler)
-			audio_resampler_destroy(resampler);
-		resampler = audio_resampler_create(&dst, &src);
-	}
-
-	inline ~Resampler() {audio_resampler_destroy(resampler);}
-	inline operator audio_resampler_t*() {return resampler;}
-};
+class mssapi_captions;
 
 class CaptionStream : public ISpAudio {
 	volatile long refs = 1;
 	SPAUDIOBUFFERINFO buf_info = {};
+	mssapi_captions *handler;
 	ULONG notify_size = 0;
 	SPAUDIOSTATE state;
 	WinHandle event;
@@ -58,7 +34,6 @@ class CaptionStream : public ISpAudio {
 	std::mutex m;
 	std::vector<int16_t> temp_buf;
 	WAVEFORMATEX format = {};
-	Resampler resampler;
 
 	CircleBuf buf;
 	ULONG wait_size = 0;
@@ -67,10 +42,10 @@ class CaptionStream : public ISpAudio {
 	ULONGLONG write_pos = 0;
 
 public:
-	CaptionStream(DWORD samplerate);
+	CaptionStream(DWORD samplerate, mssapi_captions *handler_);
 
 	void Stop();
-	void PushAudio(const struct audio_data *audio_data, bool muted);
+	void PushAudio(const void *data, size_t frames);
 
 	// IUnknown methods
 	STDMETHODIMP QueryInterface(REFIID riid, void **ppv) override;

+ 179 - 0
UI/frontend-plugins/frontend-tools/captions-mssapi.cpp

@@ -0,0 +1,179 @@
+#include "captions-mssapi.hpp"
+
+#define do_log(type, format, ...) blog(type, "[Captions] " format, \
+		##__VA_ARGS__)
+
+#define error(format, ...) do_log(LOG_ERROR, format, ##__VA_ARGS__)
+#define debug(format, ...) do_log(LOG_DEBUG, format, ##__VA_ARGS__)
+
+mssapi_captions::mssapi_captions(
+		captions_cb callback,
+		const std::string &lang) try
+	: captions_handler(callback, AUDIO_FORMAT_16BIT, 16000)
+{
+	HRESULT hr;
+
+	std::wstring wlang;
+	wlang.resize(lang.size());
+
+	for (size_t i = 0; i < lang.size(); i++)
+		wlang[i] = (wchar_t)lang[i];
+
+	LCID lang_id = LocaleNameToLCID(wlang.c_str(), 0);
+
+	wchar_t lang_str[32];
+	_snwprintf(lang_str, 31, L"language=%x", (int)lang_id);
+
+	stop = CreateEvent(nullptr, false, false, nullptr);
+	if (!stop.Valid())
+		throw "Failed to create event";
+
+	hr = SpFindBestToken(SPCAT_RECOGNIZERS, lang_str, nullptr, &token);
+	if (FAILED(hr))
+		throw HRError("SpFindBestToken failed", hr);
+
+	hr = CoCreateInstance(CLSID_SpInprocRecognizer, nullptr, CLSCTX_ALL,
+			__uuidof(ISpRecognizer), (void**)&recognizer);
+	if (FAILED(hr))
+		throw HRError("CoCreateInstance for recognizer failed", hr);
+
+	hr = recognizer->SetRecognizer(token);
+	if (FAILED(hr))
+		throw HRError("SetRecognizer failed", hr);
+
+	hr = recognizer->SetRecoState(SPRST_INACTIVE);
+	if (FAILED(hr))
+		throw HRError("SetRecoState(SPRST_INACTIVE) failed", hr);
+
+	hr = recognizer->CreateRecoContext(&context);
+	if (FAILED(hr))
+		throw HRError("CreateRecoContext failed", hr);
+
+	ULONGLONG interest = SPFEI(SPEI_RECOGNITION) |
+	                     SPFEI(SPEI_END_SR_STREAM);
+	hr = context->SetInterest(interest, interest);
+	if (FAILED(hr))
+		throw HRError("SetInterest failed", hr);
+
+	hr = context->SetNotifyWin32Event();
+	if (FAILED(hr))
+		throw HRError("SetNotifyWin32Event", hr);
+
+	notify = context->GetNotifyEventHandle();
+	if (notify == INVALID_HANDLE_VALUE)
+		throw HRError("GetNotifyEventHandle failed", E_NOINTERFACE);
+
+	size_t sample_rate = audio_output_get_sample_rate(obs_get_audio());
+	audio = new CaptionStream((DWORD)sample_rate, this);
+	audio->Release();
+
+	hr = recognizer->SetInput(audio, false);
+	if (FAILED(hr))
+		throw HRError("SetInput failed", hr);
+
+	hr = context->CreateGrammar(1, &grammar);
+	if (FAILED(hr))
+		throw HRError("CreateGrammar failed", hr);
+
+	hr = grammar->LoadDictation(nullptr, SPLO_STATIC);
+	if (FAILED(hr))
+		throw HRError("LoadDictation failed", hr);
+
+	try {
+		t = std::thread([this] () {main_thread();});
+	} catch (...) {
+		throw "Failed to create thread";
+	}
+
+} catch (const char *err) {
+	blog(LOG_WARNING, "%s: %s", __FUNCTION__, err);
+	throw CAPTIONS_ERROR_GENERIC_FAIL;
+
+} catch (HRError err) {
+	blog(LOG_WARNING, "%s: %s (%lX)", __FUNCTION__, err.str, err.hr);
+	throw CAPTIONS_ERROR_GENERIC_FAIL;
+}
+
+mssapi_captions::~mssapi_captions()
+{
+	if (t.joinable()) {
+		SetEvent(stop);
+		t.join();
+	}
+}
+
+void mssapi_captions::main_thread()
+try {
+	HRESULT hr;
+
+	os_set_thread_name(__FUNCTION__);
+
+	hr = grammar->SetDictationState(SPRS_ACTIVE);
+	if (FAILED(hr))
+		throw HRError("SetDictationState failed", hr);
+
+	hr = recognizer->SetRecoState(SPRST_ACTIVE);
+	if (FAILED(hr))
+		throw HRError("SetRecoState(SPRST_ACTIVE) failed", hr);
+
+	HANDLE events[] = {notify, stop};
+
+	started = true;
+
+	for (;;) {
+		DWORD ret = WaitForMultipleObjects(2, events, false, INFINITE);
+		if (ret != WAIT_OBJECT_0)
+			break;
+
+		CSpEvent event;
+		bool exit = false;
+
+		while (event.GetFrom(context) == S_OK) {
+			if (event.eEventId == SPEI_RECOGNITION) {
+				ISpRecoResult *result = event.RecoResult();
+
+				CoTaskMemPtr<wchar_t> text;
+				hr = result->GetText((ULONG)-1, (ULONG)-1,
+						true, &text, nullptr);
+				if (FAILED(hr))
+					continue;
+
+				char text_utf8[512];
+				os_wcs_to_utf8(text, 0, text_utf8, 512);
+
+				callback(text_utf8);
+
+				blog(LOG_DEBUG, "\"%s\"", text_utf8);
+
+			} else if (event.eEventId == SPEI_END_SR_STREAM) {
+				exit = true;
+				break;
+			}
+		}
+
+		if (exit)
+			break;
+	}
+
+	audio->Stop();
+
+} catch (HRError err) {
+	blog(LOG_WARNING, "%s failed: %s (%lX)", __FUNCTION__, err.str, err.hr);
+}
+
+void mssapi_captions::pcm_data(const void *data, size_t frames)
+{
+	if (started)
+		audio->PushAudio(data, frames);
+}
+
+captions_handler_info mssapi_info = {
+	[] () -> std::string
+	{
+		return "Microsoft Speech-to-Text";
+	},
+	[] (captions_cb cb, const std::string &lang) -> captions_handler *
+	{
+		return new mssapi_captions(cb, lang);
+	}
+};

+ 37 - 0
UI/frontend-plugins/frontend-tools/captions-mssapi.hpp

@@ -0,0 +1,37 @@
+#pragma once
+
+#include "captions-handler.hpp"
+#include "captions-mssapi-stream.hpp"
+#include <util/windows/HRError.hpp>
+#include <util/windows/ComPtr.hpp>
+#include <util/windows/WinHandle.hpp>
+#include <util/windows/CoTaskMemPtr.hpp>
+#include <util/threading.h>
+#include <util/platform.h>
+#include <sphelper.h>
+
+#include <obs.hpp>
+
+#include <thread>
+
+class mssapi_captions : public captions_handler {
+	friend class CaptionStream;
+
+	ComPtr<CaptionStream>  audio;
+	ComPtr<ISpObjectToken> token;
+	ComPtr<ISpRecoGrammar> grammar;
+	ComPtr<ISpRecognizer>  recognizer;
+	ComPtr<ISpRecoContext> context;
+
+	HANDLE                 notify;
+	WinHandle              stop;
+	std::thread            t;
+	bool                   started = false;
+
+	void main_thread();
+
+public:
+	mssapi_captions(captions_cb callback, const std::string &lang);
+	virtual ~mssapi_captions();
+	virtual void pcm_data(const void *data, size_t frames) override;
+};

+ 129 - 192
UI/frontend-plugins/frontend-tools/captions.cpp

@@ -1,47 +1,54 @@
+#include <QMessageBox>
+
+#include <windows.h>
 #include <obs-frontend-api.h>
-#include "captions-mssapi-stream.hpp"
 #include "captions.hpp"
+#include "captions-handler.hpp"
 #include "tool-helpers.hpp"
-#include <sphelper.h>
 #include <util/dstr.hpp>
 #include <util/platform.h>
-#include <util/windows/HRError.hpp>
+#include <util/windows/WinHandle.hpp>
 #include <util/windows/ComPtr.hpp>
-#include <util/windows/CoTaskMemPtr.hpp>
-#include <util/threading.h>
 #include <obs-module.h>
+#include <sphelper.h>
 
+#include <unordered_map>
+#include <vector>
 #include <string>
 #include <thread>
 #include <mutex>
 
+#include "captions-mssapi.hpp"
+
 #define do_log(type, format, ...) blog(type, "[Captions] " format, \
 		##__VA_ARGS__)
 
-#define error(format, ...) do_log(LOG_ERROR, format, ##__VA_ARGS__)
+#define warn(format, ...) do_log(LOG_WARNING, format, ##__VA_ARGS__)
 #define debug(format, ...) do_log(LOG_DEBUG, format, ##__VA_ARGS__)
 
 using namespace std;
 
-struct obs_captions {
-	thread th;
-	recursive_mutex m;
-	WinHandle stop_event;
+#define DEFAULT_HANDLER "mssapi"
 
+struct obs_captions {
+	string handler_id = DEFAULT_HANDLER;
 	string source_name;
 	OBSWeakSource source;
-	LANGID lang_id;
+	unique_ptr<captions_handler> handler;
+	LANGID lang_id = GetUserDefaultUILanguage();
 
-	void main_thread();
-	void start();
-	void stop();
+	std::unordered_map<std::string, captions_handler_info&> handler_types;
 
-	inline obs_captions() :
-		stop_event(CreateEvent(nullptr, false, false, nullptr)),
-		lang_id(GetUserDefaultUILanguage())
+	inline void register_handler(const char *id,
+			captions_handler_info &info)
 	{
+		handler_types.emplace(id, info);
 	}
 
+	void start();
+	void stop();
+
+	obs_captions();
 	inline ~obs_captions() {stop();}
 };
 
@@ -72,8 +79,6 @@ CaptionsDialog::CaptionsDialog(QWidget *parent) :
 {
 	ui->setupUi(this);
 
-	lock_guard<recursive_mutex> lock(captions->m);
-
 	auto cb = [this] (obs_source_t *source)
 	{
 		uint32_t caps = obs_source_get_output_flags(source);
@@ -97,8 +102,19 @@ CaptionsDialog::CaptionsDialog(QWidget *parent) :
 			return (*static_cast<cb_t*>(data))(source);}, &cb);
 	ui->source->blockSignals(false);
 
+	for (auto &ht : captions->handler_types) {
+		QString name = ht.second.name().c_str();
+		QString id = ht.first.c_str();
+		ui->provider->addItem(name, id);
+	}
+
+	QString qhandler_id = captions->handler_id.c_str();
+	int idx = ui->provider->findData(qhandler_id);
+	if (idx != -1)
+		ui->provider->setCurrentIndex(idx);
+
 	ui->enable->blockSignals(true);
-	ui->enable->setChecked(captions->th.joinable());
+	ui->enable->setChecked(!!captions->handler);
 	ui->enable->blockSignals(false);
 
 	vector<locale_info> locales;
@@ -129,13 +145,11 @@ CaptionsDialog::CaptionsDialog(QWidget *parent) :
 		ui->language->setEnabled(false);
 
 	} else if (!set_language) {
-		bool started = captions->th.joinable();
+		bool started = !!captions->handler;
 		if (started)
 			captions->stop();
 
-		captions->m.lock();
 		captions->lang_id = locales[0].id;
-		captions->m.unlock();
 
 		if (started)
 			captions->start();
@@ -144,14 +158,12 @@ CaptionsDialog::CaptionsDialog(QWidget *parent) :
 
 void CaptionsDialog::on_source_currentIndexChanged(int)
 {
-	bool started = captions->th.joinable();
+	bool started = !!captions->handler;
 	if (started)
 		captions->stop();
 
-	captions->m.lock();
 	captions->source_name = ui->source->currentText().toUtf8().constData();
 	captions->source = GetWeakSourceByName(captions->source_name.c_str());
-	captions->m.unlock();
 
 	if (started)
 		captions->start();
@@ -159,205 +171,122 @@ void CaptionsDialog::on_source_currentIndexChanged(int)
 
 void CaptionsDialog::on_enable_clicked(bool checked)
 {
-	if (checked)
+	if (checked) {
 		captions->start();
-	else
+		if (!captions->handler) {
+			ui->enable->blockSignals(true);
+			ui->enable->setChecked(false);
+			ui->enable->blockSignals(false);
+		}
+	} else {
 		captions->stop();
+	}
 }
 
 void CaptionsDialog::on_language_currentIndexChanged(int)
 {
-	bool started = captions->th.joinable();
+	bool started = !!captions->handler;
 	if (started)
 		captions->stop();
 
-	captions->m.lock();
 	captions->lang_id = (LANGID)ui->language->currentData().toInt();
-	captions->m.unlock();
 
 	if (started)
 		captions->start();
 }
 
-/* ------------------------------------------------------------------------- */
-
-void obs_captions::main_thread()
-try {
-	ComPtr<CaptionStream>  audio;
-	ComPtr<ISpObjectToken> token;
-	ComPtr<ISpRecoGrammar> grammar;
-	ComPtr<ISpRecognizer>  recognizer;
-	ComPtr<ISpRecoContext> context;
-	HRESULT hr;
-
-	auto cb = [&] (const struct audio_data *audio_data,
-			bool muted)
-	{
-		audio->PushAudio(audio_data, muted);
-	};
-
-	using cb_t = decltype(cb);
-
-	auto pre_cb = [] (void *param, obs_source_t*,
-		const struct audio_data *audio_data, bool muted)
-	{
-		return (*static_cast<cb_t*>(param))(audio_data, muted);
-	};
-
-	os_set_thread_name(__FUNCTION__);
-
-	CoInitialize(nullptr);
-
-	wchar_t lang_str[32];
-	_snwprintf(lang_str, 31, L"language=%x", (int)captions->lang_id);
-
-	hr = SpFindBestToken(SPCAT_RECOGNIZERS, lang_str, nullptr, &token);
-	if (FAILED(hr))
-		throw HRError("SpFindBestToken failed", hr);
-
-	hr = CoCreateInstance(CLSID_SpInprocRecognizer, nullptr, CLSCTX_ALL,
-			__uuidof(ISpRecognizer), (void**)&recognizer);
-	if (FAILED(hr))
-		throw HRError("CoCreateInstance for recognizer failed", hr);
-
-	hr = recognizer->SetRecognizer(token);
-	if (FAILED(hr))
-		throw HRError("SetRecognizer failed", hr);
-
-	hr = recognizer->SetRecoState(SPRST_INACTIVE);
-	if (FAILED(hr))
-		throw HRError("SetRecoState(SPRST_INACTIVE) failed", hr);
-
-	hr = recognizer->CreateRecoContext(&context);
-	if (FAILED(hr))
-		throw HRError("CreateRecoContext failed", hr);
-
-	ULONGLONG interest = SPFEI(SPEI_RECOGNITION) |
-		SPFEI(SPEI_END_SR_STREAM);
-	hr = context->SetInterest(interest, interest);
-	if (FAILED(hr))
-		throw HRError("SetInterest failed", hr);
-
-	HANDLE notify;
-
-	hr = context->SetNotifyWin32Event();
-	if (FAILED(hr))
-		throw HRError("SetNotifyWin32Event", hr);
-
-	notify = context->GetNotifyEventHandle();
-	if (notify == INVALID_HANDLE_VALUE)
-		throw HRError("GetNotifyEventHandle failed", E_NOINTERFACE);
-
-	size_t sample_rate = audio_output_get_sample_rate(obs_get_audio());
-	audio = new CaptionStream((DWORD)sample_rate);
-	audio->Release();
-
-	hr = recognizer->SetInput(audio, false);
-	if (FAILED(hr))
-		throw HRError("SetInput failed", hr);
-
-	hr = context->CreateGrammar(1, &grammar);
-	if (FAILED(hr))
-		throw HRError("CreateGrammar failed", hr);
-
-	hr = grammar->LoadDictation(nullptr, SPLO_STATIC);
-	if (FAILED(hr))
-		throw HRError("LoadDictation failed", hr);
+void CaptionsDialog::on_provider_currentIndexChanged(int idx)
+{
+	bool started = !!captions->handler;
+	if (started)
+		captions->stop();
 
-	hr = grammar->SetDictationState(SPRS_ACTIVE);
-	if (FAILED(hr))
-		throw HRError("SetDictationState failed", hr);
+	captions->handler_id =
+		ui->provider->itemData(idx).toString().toUtf8().constData();
 
-	hr = recognizer->SetRecoState(SPRST_ACTIVE);
-	if (FAILED(hr))
-		throw HRError("SetRecoState(SPRST_ACTIVE) failed", hr);
+	if (started)
+		captions->start();
+}
 
-	HANDLE events[] = {notify, stop_event};
+/* ------------------------------------------------------------------------- */
 
-	{
-		captions->source = GetWeakSourceByName(
-				captions->source_name.c_str());
-		OBSSource strong = OBSGetStrongRef(source);
-		if (strong)
-			obs_source_add_audio_capture_callback(strong,
-					pre_cb, &cb);
+static void caption_text(const std::string &text)
+{
+	obs_output *output = obs_frontend_get_streaming_output();
+	if (output) {
+		obs_output_output_caption_text1(output, text.c_str());
+		obs_output_release(output);
 	}
+}
 
-	for (;;) {
-		DWORD ret = WaitForMultipleObjects(2, events, false, INFINITE);
-		if (ret != WAIT_OBJECT_0)
-			break;
-
-		CSpEvent event;
-		bool exit = false;
-
-		while (event.GetFrom(context) == S_OK) {
-			if (event.eEventId == SPEI_RECOGNITION) {
-				ISpRecoResult *result = event.RecoResult();
+static void audio_capture(void*, obs_source_t*,
+		const struct audio_data *audio, bool)
+{
+	captions->handler->push_audio(audio);
+}
 
-				CoTaskMemPtr<wchar_t> text;
-				hr = result->GetText((ULONG)-1, (ULONG)-1,
-						true, &text, nullptr);
-				if (FAILED(hr))
-					continue;
+void obs_captions::start()
+{
+	if (!captions->handler && valid_lang(lang_id)) {
+		wchar_t wname[256];
+
+		auto pair = handler_types.find(handler_id);
+		if (pair == handler_types.end()) {
+			warn("Failed to find handler '%s'",
+					handler_id.c_str());
+			return;
+		}
 
-				char text_utf8[512];
-				os_wcs_to_utf8(text, 0, text_utf8, 512);
+		if (!LCIDToLocaleName(lang_id, wname, 256, 0)) {
+			warn("Failed to get locale name: %d",
+					(int)GetLastError());
+			return;
+		}
 
-				obs_output_t *output =
-					obs_frontend_get_streaming_output();
-				if (output)
-					obs_output_output_caption_text1(output,
-							text_utf8);
+		size_t len = (size_t)wcslen(wname);
 
-				debug("\"%s\"", text_utf8);
+		string lang_name;
+		lang_name.resize(len);
 
-				obs_output_release(output);
+		for (size_t i = 0; i < len; i++)
+			lang_name[i] = (char)wname[i];
 
-			} else if (event.eEventId == SPEI_END_SR_STREAM) {
-				exit = true;
-				break;
-			}
+		OBSSource s = OBSGetStrongRef(source);
+		if (!s) {
+			warn("Source invalid");
+			return;
 		}
 
-		if (exit)
-			break;
-	}
+		try {
+			captions_handler *h = pair->second.create(caption_text,
+					lang_name);
+			handler.reset(h);
 
-	{
-		OBSSource strong = OBSGetStrongRef(source);
-		if (strong)
-			obs_source_remove_audio_capture_callback(strong,
-					pre_cb, &cb);
-	}
-
-	audio->Stop();
+			OBSSource s = OBSGetStrongRef(source);
+			obs_source_add_audio_capture_callback(s,
+					audio_capture, nullptr);
 
-	CoUninitialize();
+		} catch (std::string text) {
+			QWidget *window =
+				(QWidget*)obs_frontend_get_main_window();
 
-} catch (HRError err) {
-	error("%s failed: %s (%lX)", __FUNCTION__, err.str, err.hr);
-	CoUninitialize();
-	captions->th.detach();
-}
+			warn("Failed to create handler: %s", text.c_str());
 
-void obs_captions::start()
-{
-	if (!captions->th.joinable()) {
-		ResetEvent(captions->stop_event);
+			QMessageBox::warning(window,
+				obs_module_text("Captions.Error.GenericFail"),
+				text.c_str());
 
-		if (valid_lang(captions->lang_id))
-			captions->th = thread([] () {captions->main_thread();});
+		}
 	}
 }
 
 void obs_captions::stop()
 {
-	if (!captions->th.joinable())
-		return;
-
-	SetEvent(captions->stop_event);
-	captions->th.join();
+	OBSSource s = OBSGetStrongRef(source);
+	if (s)
+		obs_source_remove_audio_capture_callback(s,
+				audio_capture, nullptr);
+	handler.reset();
 }
 
 static bool get_locale_name(LANGID id, char *out)
@@ -455,6 +384,15 @@ static void get_valid_locale_names(vector<locale_info> &locales)
 
 /* ------------------------------------------------------------------------- */
 
+extern captions_handler_info mssapi_info;
+
+obs_captions::obs_captions()
+{
+	register_handler("mssapi", mssapi_info);
+}
+
+/* ------------------------------------------------------------------------- */
+
 extern "C" void FreeCaptions()
 {
 	delete captions;
@@ -470,37 +408,36 @@ static void obs_event(enum obs_frontend_event event, void *)
 static void save_caption_data(obs_data_t *save_data, bool saving, void*)
 {
 	if (saving) {
-		lock_guard<recursive_mutex> lock(captions->m);
 		obs_data_t *obj = obs_data_create();
 
 		obs_data_set_string(obj, "source",
 				captions->source_name.c_str());
-		obs_data_set_bool(obj, "enabled", captions->th.joinable());
+		obs_data_set_bool(obj, "enabled", !!captions->handler);
 		obs_data_set_int(obj, "lang_id", captions->lang_id);
+		obs_data_set_string(obj, "provider",
+				captions->handler_id.c_str());
 
 		obs_data_set_obj(save_data, "captions", obj);
 		obs_data_release(obj);
 	} else {
 		captions->stop();
 
-		captions->m.lock();
-
 		obs_data_t *obj = obs_data_get_obj(save_data, "captions");
 		if (!obj)
 			obj = obs_data_create();
 
 		obs_data_set_default_int(obj, "lang_id",
 				GetUserDefaultUILanguage());
+		obs_data_set_default_string(obj, "provider", DEFAULT_HANDLER);
 
 		bool enabled = obs_data_get_bool(obj, "enabled");
 		captions->source_name = obs_data_get_string(obj, "source");
 		captions->lang_id = (int)obs_data_get_int(obj, "lang_id");
+		captions->handler_id = obs_data_get_string(obj, "provider");
 		captions->source = GetWeakSourceByName(
 				captions->source_name.c_str());
 		obs_data_release(obj);
 
-		captions->m.unlock();
-
 		if (enabled)
 			captions->start();
 	}

+ 1 - 0
UI/frontend-plugins/frontend-tools/captions.hpp

@@ -17,4 +17,5 @@ public slots:
 	void on_source_currentIndexChanged(int idx);
 	void on_enable_clicked(bool checked);
 	void on_language_currentIndexChanged(int idx);
+	void on_provider_currentIndexChanged(int idx);
 };

+ 2 - 0
UI/frontend-plugins/frontend-tools/data/locale/en-US.ini

@@ -14,6 +14,8 @@ Stop="Stop"
 Captions="Captions (Experimental)"
 Captions.AudioSource="Audio source"
 Captions.CurrentSystemLanguage="Current System Language (%1)"
+Captions.Provider="Provider"
+Captions.Error.GenericFail="Failed to start captions"
 
 OutputTimer="Output Timer"
 OutputTimer.Stream="Stop streaming after:"

+ 15 - 1
UI/frontend-plugins/frontend-tools/forms/captions.ui

@@ -7,7 +7,7 @@
     <x>0</x>
     <y>0</y>
     <width>519</width>
-    <height>140</height>
+    <height>152</height>
    </rect>
   </property>
   <property name="windowTitle">
@@ -56,6 +56,20 @@
      <item row="2" column="1">
       <widget class="QComboBox" name="language"/>
      </item>
+     <item row="3" column="1">
+      <widget class="QComboBox" name="provider">
+       <property name="insertPolicy">
+        <enum>QComboBox::InsertAlphabetically</enum>
+       </property>
+      </widget>
+     </item>
+     <item row="3" column="0">
+      <widget class="QLabel" name="label_3">
+       <property name="text">
+        <string>Captions.Provider</string>
+       </property>
+      </widget>
+     </item>
     </layout>
    </item>
    <item>