captions.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533
  1. #include <obs-frontend-api.h>
  2. #include "captions-stream.hpp"
  3. #include "captions.hpp"
  4. #include "tool-helpers.hpp"
  5. #include <sphelper.h>
  6. #include <util/dstr.hpp>
  7. #include <util/platform.h>
  8. #include <util/windows/HRError.hpp>
  9. #include <util/windows/ComPtr.hpp>
  10. #include <util/windows/CoTaskMemPtr.hpp>
  11. #include <util/threading.h>
  12. #include <obs-module.h>
  13. #include <string>
  14. #include <thread>
  15. #include <mutex>
  16. #define do_log(type, format, ...) blog(type, "[Captions] " format, \
  17. ##__VA_ARGS__)
  18. #define error(format, ...) do_log(LOG_ERROR, format, ##__VA_ARGS__)
  19. #define debug(format, ...) do_log(LOG_DEBUG, format, ##__VA_ARGS__)
  20. using namespace std;
  21. struct obs_captions {
  22. thread th;
  23. recursive_mutex m;
  24. WinHandle stop_event;
  25. string source_name;
  26. OBSWeakSource source;
  27. LANGID lang_id;
  28. void main_thread();
  29. void start();
  30. void stop();
  31. inline obs_captions() :
  32. stop_event(CreateEvent(nullptr, false, false, nullptr)),
  33. lang_id(GetUserDefaultUILanguage())
  34. {
  35. }
  36. inline ~obs_captions() {stop();}
  37. };
  38. static obs_captions *captions = nullptr;
  39. /* ------------------------------------------------------------------------- */
  40. struct locale_info {
  41. DStr name;
  42. LANGID id;
  43. inline locale_info() {}
  44. inline locale_info(const locale_info &) = delete;
  45. inline locale_info(locale_info &&li)
  46. : name(std::move(li.name)),
  47. id(li.id)
  48. {}
  49. };
  50. static void get_valid_locale_names(vector<locale_info> &names);
  51. static bool valid_lang(LANGID id);
  52. /* ------------------------------------------------------------------------- */
  53. CaptionsDialog::CaptionsDialog(QWidget *parent) :
  54. QDialog(parent),
  55. ui(new Ui_CaptionsDialog)
  56. {
  57. ui->setupUi(this);
  58. lock_guard<recursive_mutex> lock(captions->m);
  59. auto cb = [this] (obs_source_t *source)
  60. {
  61. uint32_t caps = obs_source_get_output_flags(source);
  62. QString name = obs_source_get_name(source);
  63. if (caps & OBS_SOURCE_AUDIO)
  64. ui->source->addItem(name);
  65. OBSWeakSource weak = OBSGetWeakRef(source);
  66. if (weak == captions->source)
  67. ui->source->setCurrentText(name);
  68. return true;
  69. };
  70. using cb_t = decltype(cb);
  71. ui->source->blockSignals(true);
  72. ui->source->addItem(QStringLiteral(""));
  73. ui->source->setCurrentIndex(0);
  74. obs_enum_sources([] (void *data, obs_source_t *source) {
  75. return (*static_cast<cb_t*>(data))(source);}, &cb);
  76. ui->source->blockSignals(false);
  77. ui->enable->blockSignals(true);
  78. ui->enable->setChecked(captions->th.joinable());
  79. ui->enable->blockSignals(false);
  80. vector<locale_info> locales;
  81. get_valid_locale_names(locales);
  82. bool set_language = false;
  83. ui->language->blockSignals(true);
  84. for (int idx = 0; idx < (int)locales.size(); idx++) {
  85. locale_info &locale = locales[idx];
  86. ui->language->addItem(locale.name->array, (int)locale.id);
  87. if (locale.id == captions->lang_id) {
  88. ui->language->setCurrentIndex(idx);
  89. set_language = true;
  90. }
  91. }
  92. if (!set_language && locales.size())
  93. ui->language->setCurrentIndex(0);
  94. ui->language->blockSignals(false);
  95. if (!locales.size()) {
  96. ui->source->setEnabled(false);
  97. ui->enable->setEnabled(false);
  98. ui->language->setEnabled(false);
  99. } else if (!set_language) {
  100. bool started = captions->th.joinable();
  101. if (started)
  102. captions->stop();
  103. captions->m.lock();
  104. captions->lang_id = locales[0].id;
  105. captions->m.unlock();
  106. if (started)
  107. captions->start();
  108. }
  109. }
  110. void CaptionsDialog::on_source_currentIndexChanged(int)
  111. {
  112. bool started = captions->th.joinable();
  113. if (started)
  114. captions->stop();
  115. captions->m.lock();
  116. captions->source_name = ui->source->currentText().toUtf8().constData();
  117. captions->source = GetWeakSourceByName(captions->source_name.c_str());
  118. captions->m.unlock();
  119. if (started)
  120. captions->start();
  121. }
  122. void CaptionsDialog::on_enable_clicked(bool checked)
  123. {
  124. if (checked)
  125. captions->start();
  126. else
  127. captions->stop();
  128. }
  129. void CaptionsDialog::on_language_currentIndexChanged(int)
  130. {
  131. bool started = captions->th.joinable();
  132. if (started)
  133. captions->stop();
  134. captions->m.lock();
  135. captions->lang_id = (LANGID)ui->language->currentData().toInt();
  136. captions->m.unlock();
  137. if (started)
  138. captions->start();
  139. }
  140. /* ------------------------------------------------------------------------- */
  141. void obs_captions::main_thread()
  142. try {
  143. ComPtr<CaptionStream> audio;
  144. ComPtr<ISpObjectToken> token;
  145. ComPtr<ISpRecoGrammar> grammar;
  146. ComPtr<ISpRecognizer> recognizer;
  147. ComPtr<ISpRecoContext> context;
  148. HRESULT hr;
  149. auto cb = [&] (const struct audio_data *audio_data,
  150. bool muted)
  151. {
  152. audio->PushAudio(audio_data, muted);
  153. };
  154. using cb_t = decltype(cb);
  155. auto pre_cb = [] (void *param, obs_source_t*,
  156. const struct audio_data *audio_data, bool muted)
  157. {
  158. return (*static_cast<cb_t*>(param))(audio_data, muted);
  159. };
  160. os_set_thread_name(__FUNCTION__);
  161. CoInitialize(nullptr);
  162. wchar_t lang_str[32];
  163. _snwprintf(lang_str, 31, L"language=%x", (int)captions->lang_id);
  164. hr = SpFindBestToken(SPCAT_RECOGNIZERS, lang_str, nullptr, &token);
  165. if (FAILED(hr))
  166. throw HRError("SpFindBestToken failed", hr);
  167. hr = CoCreateInstance(CLSID_SpInprocRecognizer, nullptr, CLSCTX_ALL,
  168. __uuidof(ISpRecognizer), (void**)&recognizer);
  169. if (FAILED(hr))
  170. throw HRError("CoCreateInstance for recognizer failed", hr);
  171. hr = recognizer->SetRecognizer(token);
  172. if (FAILED(hr))
  173. throw HRError("SetRecognizer failed", hr);
  174. hr = recognizer->SetRecoState(SPRST_INACTIVE);
  175. if (FAILED(hr))
  176. throw HRError("SetRecoState(SPRST_INACTIVE) failed", hr);
  177. hr = recognizer->CreateRecoContext(&context);
  178. if (FAILED(hr))
  179. throw HRError("CreateRecoContext failed", hr);
  180. ULONGLONG interest = SPFEI(SPEI_RECOGNITION) |
  181. SPFEI(SPEI_END_SR_STREAM);
  182. hr = context->SetInterest(interest, interest);
  183. if (FAILED(hr))
  184. throw HRError("SetInterest failed", hr);
  185. HANDLE notify;
  186. hr = context->SetNotifyWin32Event();
  187. if (FAILED(hr))
  188. throw HRError("SetNotifyWin32Event", hr);
  189. notify = context->GetNotifyEventHandle();
  190. if (notify == INVALID_HANDLE_VALUE)
  191. throw HRError("GetNotifyEventHandle failed", E_NOINTERFACE);
  192. size_t sample_rate = audio_output_get_sample_rate(obs_get_audio());
  193. audio = new CaptionStream((DWORD)sample_rate);
  194. audio->Release();
  195. hr = recognizer->SetInput(audio, false);
  196. if (FAILED(hr))
  197. throw HRError("SetInput failed", hr);
  198. hr = context->CreateGrammar(1, &grammar);
  199. if (FAILED(hr))
  200. throw HRError("CreateGrammar failed", hr);
  201. hr = grammar->LoadDictation(nullptr, SPLO_STATIC);
  202. if (FAILED(hr))
  203. throw HRError("LoadDictation failed", hr);
  204. hr = grammar->SetDictationState(SPRS_ACTIVE);
  205. if (FAILED(hr))
  206. throw HRError("SetDictationState failed", hr);
  207. hr = recognizer->SetRecoState(SPRST_ACTIVE);
  208. if (FAILED(hr))
  209. throw HRError("SetRecoState(SPRST_ACTIVE) failed", hr);
  210. HANDLE events[] = {notify, stop_event};
  211. {
  212. captions->source = GetWeakSourceByName(
  213. captions->source_name.c_str());
  214. OBSSource strong = OBSGetStrongRef(source);
  215. if (strong)
  216. obs_source_add_audio_capture_callback(strong,
  217. pre_cb, &cb);
  218. }
  219. for (;;) {
  220. DWORD ret = WaitForMultipleObjects(2, events, false, INFINITE);
  221. if (ret != WAIT_OBJECT_0)
  222. break;
  223. CSpEvent event;
  224. bool exit = false;
  225. while (event.GetFrom(context) == S_OK) {
  226. if (event.eEventId == SPEI_RECOGNITION) {
  227. ISpRecoResult *result = event.RecoResult();
  228. CoTaskMemPtr<wchar_t> text;
  229. hr = result->GetText((ULONG)-1, (ULONG)-1,
  230. true, &text, nullptr);
  231. if (FAILED(hr))
  232. continue;
  233. char text_utf8[512];
  234. os_wcs_to_utf8(text, 0, text_utf8, 512);
  235. obs_output_t *output =
  236. obs_frontend_get_streaming_output();
  237. if (output)
  238. obs_output_output_caption_text1(output,
  239. text_utf8);
  240. debug("\"%s\"", text_utf8);
  241. obs_output_release(output);
  242. } else if (event.eEventId == SPEI_END_SR_STREAM) {
  243. exit = true;
  244. break;
  245. }
  246. }
  247. if (exit)
  248. break;
  249. }
  250. {
  251. OBSSource strong = OBSGetStrongRef(source);
  252. if (strong)
  253. obs_source_remove_audio_capture_callback(strong,
  254. pre_cb, &cb);
  255. }
  256. audio->Stop();
  257. CoUninitialize();
  258. } catch (HRError err) {
  259. error("%s failed: %s (%lX)", __FUNCTION__, err.str, err.hr);
  260. CoUninitialize();
  261. captions->th.detach();
  262. }
  263. void obs_captions::start()
  264. {
  265. if (!captions->th.joinable()) {
  266. ResetEvent(captions->stop_event);
  267. if (valid_lang(captions->lang_id))
  268. captions->th = thread([] () {captions->main_thread();});
  269. }
  270. }
  271. void obs_captions::stop()
  272. {
  273. if (!captions->th.joinable())
  274. return;
  275. SetEvent(captions->stop_event);
  276. captions->th.join();
  277. }
  278. static bool get_locale_name(LANGID id, char *out)
  279. {
  280. wchar_t name[256];
  281. int size = GetLocaleInfoW(id, LOCALE_SENGLISHLANGUAGENAME, name, 256);
  282. if (size <= 0)
  283. return false;
  284. os_wcs_to_utf8(name, 0, out, 256);
  285. return true;
  286. }
  287. static bool valid_lang(LANGID id)
  288. {
  289. ComPtr<ISpObjectToken> token;
  290. wchar_t lang_str[32];
  291. HRESULT hr;
  292. _snwprintf(lang_str, 31, L"language=%x", (int)id);
  293. hr = SpFindBestToken(SPCAT_RECOGNIZERS, lang_str, nullptr, &token);
  294. return SUCCEEDED(hr);
  295. }
  296. static void get_valid_locale_names(vector<locale_info> &locales)
  297. {
  298. locale_info cur;
  299. char locale_name[256];
  300. static const LANGID default_locales[] = {
  301. 0x0409,
  302. 0x0401,
  303. 0x0402,
  304. 0x0403,
  305. 0x0404,
  306. 0x0405,
  307. 0x0406,
  308. 0x0407,
  309. 0x0408,
  310. 0x040a,
  311. 0x040b,
  312. 0x040c,
  313. 0x040d,
  314. 0x040e,
  315. 0x040f,
  316. 0x0410,
  317. 0x0411,
  318. 0x0412,
  319. 0x0413,
  320. 0x0414,
  321. 0x0415,
  322. 0x0416,
  323. 0x0417,
  324. 0x0418,
  325. 0x0419,
  326. 0x041a,
  327. 0
  328. };
  329. /* ---------------------------------- */
  330. LANGID def_id = GetUserDefaultUILanguage();
  331. LANGID id = def_id;
  332. if (valid_lang(id) && get_locale_name(id, locale_name)) {
  333. dstr_copy(cur.name, obs_module_text(
  334. "Captions.CurrentSystemLanguage"));
  335. dstr_replace(cur.name, "%1", locale_name);
  336. cur.id = id;
  337. locales.push_back(std::move(cur));
  338. }
  339. /* ---------------------------------- */
  340. const LANGID *locale = default_locales;
  341. while (*locale) {
  342. id = *locale;
  343. if (id != def_id &&
  344. valid_lang(id) &&
  345. get_locale_name(id, locale_name)) {
  346. dstr_copy(cur.name, locale_name);
  347. cur.id = id;
  348. locales.push_back(std::move(cur));
  349. }
  350. locale++;
  351. }
  352. }
  353. /* ------------------------------------------------------------------------- */
  354. extern "C" void FreeCaptions()
  355. {
  356. delete captions;
  357. captions = nullptr;
  358. }
  359. static void obs_event(enum obs_frontend_event event, void *)
  360. {
  361. if (event == OBS_FRONTEND_EVENT_EXIT)
  362. FreeCaptions();
  363. }
  364. static void save_caption_data(obs_data_t *save_data, bool saving, void*)
  365. {
  366. if (saving) {
  367. lock_guard<recursive_mutex> lock(captions->m);
  368. obs_data_t *obj = obs_data_create();
  369. obs_data_set_string(obj, "source",
  370. captions->source_name.c_str());
  371. obs_data_set_bool(obj, "enabled", captions->th.joinable());
  372. obs_data_set_int(obj, "lang_id", captions->lang_id);
  373. obs_data_set_obj(save_data, "captions", obj);
  374. obs_data_release(obj);
  375. } else {
  376. captions->stop();
  377. captions->m.lock();
  378. obs_data_t *obj = obs_data_get_obj(save_data, "captions");
  379. if (!obj)
  380. obj = obs_data_create();
  381. obs_data_set_default_int(obj, "lang_id",
  382. GetUserDefaultUILanguage());
  383. bool enabled = obs_data_get_bool(obj, "enabled");
  384. captions->source_name = obs_data_get_string(obj, "source");
  385. captions->lang_id = (int)obs_data_get_int(obj, "lang_id");
  386. captions->source = GetWeakSourceByName(
  387. captions->source_name.c_str());
  388. obs_data_release(obj);
  389. captions->m.unlock();
  390. if (enabled)
  391. captions->start();
  392. }
  393. }
  394. extern "C" void InitCaptions()
  395. {
  396. QAction *action = (QAction*)obs_frontend_add_tools_menu_qaction(
  397. obs_module_text("Captions"));
  398. captions = new obs_captions;
  399. auto cb = [] ()
  400. {
  401. obs_frontend_push_ui_translation(obs_module_get_string);
  402. QWidget *window =
  403. (QWidget*)obs_frontend_get_main_window();
  404. CaptionsDialog dialog(window);
  405. dialog.exec();
  406. obs_frontend_pop_ui_translation();
  407. };
  408. obs_frontend_add_save_callback(save_caption_data, nullptr);
  409. obs_frontend_add_event_callback(obs_event, nullptr);
  410. action->connect(action, &QAction::triggered, cb);
  411. }