captions.cpp 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. #include <obs-frontend-api.h>
  2. #include "captions-stream.hpp"
  3. #include "captions.hpp"
  4. #include "tool-helpers.hpp"
  5. #include <sphelper.h>
  6. #include <util/dstr.hpp>
  7. #include <util/platform.h>
  8. #include <util/windows/HRError.hpp>
  9. #include <util/windows/ComPtr.hpp>
  10. #include <util/windows/CoTaskMemPtr.hpp>
  11. #include <util/threading.h>
  12. #include <obs-module.h>
  13. #include <string>
  14. #include <thread>
  15. #include <mutex>
  /* Logging helpers: every message goes through blog() with a "[Captions] "
   * prefix prepended to the caller's format string. */
  #define do_log(type, format, ...) \
      blog(type, "[Captions] " format, ##__VA_ARGS__)

  #define error(format, ...) do_log(LOG_ERROR, format, ##__VA_ARGS__)
  #define debug(format, ...) do_log(LOG_DEBUG, format, ##__VA_ARGS__)

  using namespace std;
  21. struct obs_captions {
  22. thread th;
  23. recursive_mutex m;
  24. WinHandle stop_event;
  25. string source_name;
  26. OBSWeakSource source;
  27. LANGID lang_id;
  28. void main_thread();
  29. void start();
  30. void stop();
  31. inline obs_captions() :
  32. stop_event(CreateEvent(nullptr, false, false, nullptr)),
  33. lang_id(GetUserDefaultUILanguage())
  34. {
  35. }
  36. inline ~obs_captions() {stop();}
  37. };
  38. static obs_captions *captions = nullptr;
  39. /* ------------------------------------------------------------------------- */
  40. struct locale_info {
  41. DStr name;
  42. LANGID id;
  43. inline locale_info() {}
  44. inline locale_info(const locale_info &) = delete;
  45. inline locale_info(locale_info &&li)
  46. : name(std::move(li.name)),
  47. id(li.id)
  48. {}
  49. };
  50. static void get_valid_locale_names(vector<locale_info> &names);
  51. static bool valid_lang(LANGID id);
  52. /* ------------------------------------------------------------------------- */
  53. CaptionsDialog::CaptionsDialog(QWidget *parent) :
  54. QDialog(parent),
  55. ui(new Ui_CaptionsDialog)
  56. {
  57. ui->setupUi(this);
  58. lock_guard<recursive_mutex> lock(captions->m);
  59. auto cb = [this] (obs_source_t *source)
  60. {
  61. uint32_t caps = obs_source_get_output_flags(source);
  62. QString name = obs_source_get_name(source);
  63. if (caps & OBS_SOURCE_AUDIO)
  64. ui->source->addItem(name);
  65. OBSWeakSource weak = OBSGetWeakRef(source);
  66. if (weak == captions->source)
  67. ui->source->setCurrentText(name);
  68. return true;
  69. };
  70. using cb_t = decltype(cb);
  71. ui->source->blockSignals(true);
  72. ui->source->addItem(QStringLiteral(""));
  73. ui->source->setCurrentIndex(0);
  74. obs_enum_sources([] (void *data, obs_source_t *source) {
  75. return (*static_cast<cb_t*>(data))(source);}, &cb);
  76. ui->source->blockSignals(false);
  77. ui->enable->blockSignals(true);
  78. ui->enable->setChecked(captions->th.joinable());
  79. ui->enable->blockSignals(false);
  80. vector<locale_info> locales;
  81. get_valid_locale_names(locales);
  82. bool set_language = false;
  83. ui->language->blockSignals(true);
  84. for (int idx = 0; idx < (int)locales.size(); idx++) {
  85. locale_info &locale = locales[idx];
  86. ui->language->addItem(locale.name->array, (int)locale.id);
  87. if (locale.id == captions->lang_id) {
  88. ui->language->setCurrentIndex(idx);
  89. set_language = true;
  90. }
  91. }
  92. if (!set_language && locales.size())
  93. ui->language->setCurrentIndex(0);
  94. ui->language->blockSignals(false);
  95. if (!locales.size()) {
  96. ui->source->setEnabled(false);
  97. ui->enable->setEnabled(false);
  98. ui->language->setEnabled(false);
  99. } else if (!set_language) {
  100. bool started = captions->th.joinable();
  101. if (started)
  102. captions->stop();
  103. captions->m.lock();
  104. captions->lang_id = locales[0].id;
  105. captions->m.unlock();
  106. if (started)
  107. captions->start();
  108. }
  109. }
  110. void CaptionsDialog::on_source_currentIndexChanged(int)
  111. {
  112. bool started = captions->th.joinable();
  113. if (started)
  114. captions->stop();
  115. captions->m.lock();
  116. captions->source_name = ui->source->currentText().toUtf8().constData();
  117. captions->source = GetWeakSourceByName(captions->source_name.c_str());
  118. captions->m.unlock();
  119. if (started)
  120. captions->start();
  121. }
  122. void CaptionsDialog::on_enable_clicked(bool checked)
  123. {
  124. if (checked)
  125. captions->start();
  126. else
  127. captions->stop();
  128. }
  129. void CaptionsDialog::on_language_currentIndexChanged(int)
  130. {
  131. bool started = captions->th.joinable();
  132. if (started)
  133. captions->stop();
  134. captions->m.lock();
  135. captions->lang_id = (LANGID)ui->language->currentData().toInt();
  136. captions->m.unlock();
  137. if (started)
  138. captions->start();
  139. }
  140. /* ------------------------------------------------------------------------- */
  141. void obs_captions::main_thread()
  142. try {
  143. ComPtr<CaptionStream> audio;
  144. ComPtr<ISpObjectToken> token;
  145. ComPtr<ISpRecoGrammar> grammar;
  146. ComPtr<ISpRecognizer> recognizer;
  147. ComPtr<ISpRecoContext> context;
  148. HRESULT hr;
  149. auto cb = [&] (const struct audio_data *audio_data,
  150. bool muted)
  151. {
  152. audio->PushAudio(audio_data, muted);
  153. };
  154. using cb_t = decltype(cb);
  155. auto pre_cb = [] (void *param, obs_source_t*,
  156. const struct audio_data *audio_data, bool muted)
  157. {
  158. return (*static_cast<cb_t*>(param))(audio_data, muted);
  159. };
  160. os_set_thread_name(__FUNCTION__);
  161. CoInitialize(nullptr);
  162. wchar_t lang_str[32];
  163. _snwprintf(lang_str, 31, L"language=%x", (int)captions->lang_id);
  164. hr = SpFindBestToken(SPCAT_RECOGNIZERS, lang_str, nullptr, &token);
  165. if (FAILED(hr))
  166. throw HRError("SpFindBestToken failed", hr);
  167. hr = CoCreateInstance(CLSID_SpInprocRecognizer, nullptr, CLSCTX_ALL,
  168. __uuidof(ISpRecognizer), (void**)&recognizer);
  169. if (FAILED(hr))
  170. throw HRError("CoCreateInstance for recognizer failed", hr);
  171. hr = recognizer->SetRecognizer(token);
  172. if (FAILED(hr))
  173. throw HRError("SetRecognizer failed", hr);
  174. hr = recognizer->SetRecoState(SPRST_INACTIVE);
  175. if (FAILED(hr))
  176. throw HRError("SetRecoState(SPRST_INACTIVE) failed", hr);
  177. hr = recognizer->CreateRecoContext(&context);
  178. if (FAILED(hr))
  179. throw HRError("CreateRecoContext failed", hr);
  180. ULONGLONG interest = SPFEI(SPEI_RECOGNITION) |
  181. SPFEI(SPEI_END_SR_STREAM);
  182. hr = context->SetInterest(interest, interest);
  183. if (FAILED(hr))
  184. throw HRError("SetInterest failed", hr);
  185. HANDLE notify;
  186. hr = context->SetNotifyWin32Event();
  187. if (FAILED(hr))
  188. throw HRError("SetNotifyWin32Event", hr);
  189. notify = context->GetNotifyEventHandle();
  190. if (notify == INVALID_HANDLE_VALUE)
  191. throw HRError("GetNotifyEventHandle failed", E_NOINTERFACE);
  192. size_t sample_rate = audio_output_get_sample_rate(obs_get_audio());
  193. audio = new CaptionStream((DWORD)sample_rate);
  194. audio->Release();
  195. hr = recognizer->SetInput(audio, false);
  196. if (FAILED(hr))
  197. throw HRError("SetInput failed", hr);
  198. hr = context->CreateGrammar(1, &grammar);
  199. if (FAILED(hr))
  200. throw HRError("CreateGrammar failed", hr);
  201. hr = grammar->LoadDictation(nullptr, SPLO_STATIC);
  202. if (FAILED(hr))
  203. throw HRError("LoadDictation failed", hr);
  204. hr = grammar->SetDictationState(SPRS_ACTIVE);
  205. if (FAILED(hr))
  206. throw HRError("SetDictationState failed", hr);
  207. hr = recognizer->SetRecoState(SPRST_ACTIVE);
  208. if (FAILED(hr))
  209. throw HRError("SetRecoState(SPRST_ACTIVE) failed", hr);
  210. HANDLE events[] = {notify, stop_event};
  211. {
  212. captions->source = GetWeakSourceByName(
  213. captions->source_name.c_str());
  214. OBSSource strong = OBSGetStrongRef(source);
  215. if (strong)
  216. obs_source_add_audio_capture_callback(strong,
  217. pre_cb, &cb);
  218. }
  219. for (;;) {
  220. DWORD ret = WaitForMultipleObjects(2, events, false, INFINITE);
  221. if (ret != WAIT_OBJECT_0)
  222. break;
  223. CSpEvent event;
  224. bool exit = false;
  225. while (event.GetFrom(context) == S_OK) {
  226. if (event.eEventId == SPEI_RECOGNITION) {
  227. ISpRecoResult *result = event.RecoResult();
  228. CoTaskMemPtr<wchar_t> text;
  229. hr = result->GetText((ULONG)-1, (ULONG)-1,
  230. true, &text, nullptr);
  231. if (FAILED(hr))
  232. continue;
  233. char text_utf8[512];
  234. os_wcs_to_utf8(text, 0, text_utf8, 512);
  235. obs_output_t *output =
  236. obs_frontend_get_streaming_output();
  237. if (output)
  238. obs_output_output_caption_text1(output,
  239. text_utf8);
  240. debug("\"%s\"", text_utf8);
  241. obs_output_release(output);
  242. } else if (event.eEventId == SPEI_END_SR_STREAM) {
  243. exit = true;
  244. break;
  245. }
  246. }
  247. if (exit)
  248. break;
  249. }
  250. {
  251. OBSSource strong = OBSGetStrongRef(source);
  252. if (strong)
  253. obs_source_remove_audio_capture_callback(strong,
  254. pre_cb, &cb);
  255. }
  256. audio->Stop();
  257. CoUninitialize();
  258. } catch (HRError err) {
  259. error("%s failed: %s (%lX)", __FUNCTION__, err.str, err.hr);
  260. CoUninitialize();
  261. captions->th.detach();
  262. }
  263. void obs_captions::start()
  264. {
  265. if (!captions->th.joinable()) {
  266. if (valid_lang(captions->lang_id))
  267. captions->th = thread([] () {captions->main_thread();});
  268. }
  269. }
  270. void obs_captions::stop()
  271. {
  272. if (!captions->th.joinable())
  273. return;
  274. SetEvent(captions->stop_event);
  275. captions->th.join();
  276. }
  277. static bool get_locale_name(LANGID id, char *out)
  278. {
  279. wchar_t name[256];
  280. int size = GetLocaleInfoW(id, LOCALE_SENGLISHLANGUAGENAME, name, 256);
  281. if (size <= 0)
  282. return false;
  283. os_wcs_to_utf8(name, 0, out, 256);
  284. return true;
  285. }
  286. static bool valid_lang(LANGID id)
  287. {
  288. ComPtr<ISpObjectToken> token;
  289. wchar_t lang_str[32];
  290. HRESULT hr;
  291. _snwprintf(lang_str, 31, L"language=%x", (int)id);
  292. hr = SpFindBestToken(SPCAT_RECOGNIZERS, lang_str, nullptr, &token);
  293. return SUCCEEDED(hr);
  294. }
  295. static void get_valid_locale_names(vector<locale_info> &locales)
  296. {
  297. locale_info cur;
  298. char locale_name[256];
  299. static const LANGID default_locales[] = {
  300. 0x0409,
  301. 0x0401,
  302. 0x0402,
  303. 0x0403,
  304. 0x0404,
  305. 0x0405,
  306. 0x0406,
  307. 0x0407,
  308. 0x0408,
  309. 0x040a,
  310. 0x040b,
  311. 0x040c,
  312. 0x040d,
  313. 0x040e,
  314. 0x040f,
  315. 0x0410,
  316. 0x0411,
  317. 0x0412,
  318. 0x0413,
  319. 0x0414,
  320. 0x0415,
  321. 0x0416,
  322. 0x0417,
  323. 0x0418,
  324. 0x0419,
  325. 0x041a,
  326. 0
  327. };
  328. /* ---------------------------------- */
  329. LANGID def_id = GetUserDefaultUILanguage();
  330. LANGID id = def_id;
  331. if (valid_lang(id) && get_locale_name(id, locale_name)) {
  332. dstr_copy(cur.name, obs_module_text(
  333. "Captions.CurrentSystemLanguage"));
  334. dstr_replace(cur.name, "%1", locale_name);
  335. cur.id = id;
  336. locales.push_back(std::move(cur));
  337. }
  338. /* ---------------------------------- */
  339. const LANGID *locale = default_locales;
  340. while (*locale) {
  341. id = *locale;
  342. if (id != def_id &&
  343. valid_lang(id) &&
  344. get_locale_name(id, locale_name)) {
  345. dstr_copy(cur.name, locale_name);
  346. cur.id = id;
  347. locales.push_back(std::move(cur));
  348. }
  349. locale++;
  350. }
  351. }
  352. /* ------------------------------------------------------------------------- */
  353. extern "C" void FreeCaptions()
  354. {
  355. delete captions;
  356. captions = nullptr;
  357. }
  358. static void obs_event(enum obs_frontend_event event, void *)
  359. {
  360. if (event == OBS_FRONTEND_EVENT_EXIT)
  361. FreeCaptions();
  362. }
  363. static void save_caption_data(obs_data_t *save_data, bool saving, void*)
  364. {
  365. if (saving) {
  366. lock_guard<recursive_mutex> lock(captions->m);
  367. obs_data_t *obj = obs_data_create();
  368. obs_data_set_string(obj, "source",
  369. captions->source_name.c_str());
  370. obs_data_set_bool(obj, "enabled", captions->th.joinable());
  371. obs_data_set_int(obj, "lang_id", captions->lang_id);
  372. obs_data_set_obj(save_data, "captions", obj);
  373. obs_data_release(obj);
  374. } else {
  375. captions->stop();
  376. captions->m.lock();
  377. obs_data_t *obj = obs_data_get_obj(save_data, "captions");
  378. if (!obj)
  379. obj = obs_data_create();
  380. obs_data_set_default_int(obj, "lang_id",
  381. GetUserDefaultUILanguage());
  382. bool enabled = obs_data_get_bool(obj, "enabled");
  383. captions->source_name = obs_data_get_string(obj, "source");
  384. captions->lang_id = (int)obs_data_get_int(obj, "lang_id");
  385. captions->source = GetWeakSourceByName(
  386. captions->source_name.c_str());
  387. obs_data_release(obj);
  388. captions->m.unlock();
  389. if (enabled)
  390. captions->start();
  391. }
  392. }
  393. extern "C" void InitCaptions()
  394. {
  395. QAction *action = (QAction*)obs_frontend_add_tools_menu_qaction(
  396. obs_module_text("Captions"));
  397. captions = new obs_captions;
  398. auto cb = [] ()
  399. {
  400. obs_frontend_push_ui_translation(obs_module_get_string);
  401. QWidget *window =
  402. (QWidget*)obs_frontend_get_main_window();
  403. CaptionsDialog dialog(window);
  404. dialog.exec();
  405. obs_frontend_pop_ui_translation();
  406. };
  407. obs_frontend_add_save_callback(save_caption_data, nullptr);
  408. obs_frontend_add_event_callback(obs_event, nullptr);
  409. action->connect(action, &QAction::triggered, cb);
  410. }