captions-mssapi.cpp 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171
#include "captions-mssapi.hpp"

#include <exception>
#include <vector>
  /*
   * Logging helpers for this module. do_log() forwards to blog() with
   * "[Captions] " prepended to the format through string-literal
   * concatenation, so `fmt` must itself be a string literal. error() and
   * debug() fix the level to LOG_ERROR / LOG_DEBUG respectively.
   * NOTE(review): neither wrapper is referenced in the code shown here, and a
   * function-like macro named `error` rewrites any later `error(...)` call in
   * this translation unit.
   */
#define do_log(level, fmt, ...) blog(level, "[Captions] " fmt, ##__VA_ARGS__)

#define debug(fmt, ...) do_log(LOG_DEBUG, fmt, ##__VA_ARGS__)
#define error(fmt, ...) do_log(LOG_ERROR, fmt, ##__VA_ARGS__)
  6. mssapi_captions::mssapi_captions(captions_cb callback, const std::string &lang)
  7. try : captions_handler(callback, AUDIO_FORMAT_16BIT, 16000) {
  8. HRESULT hr;
  9. std::wstring wlang;
  10. wlang.resize(lang.size());
  11. for (size_t i = 0; i < lang.size(); i++)
  12. wlang[i] = (wchar_t)lang[i];
  13. LCID lang_id = LocaleNameToLCID(wlang.c_str(), 0);
  14. wchar_t lang_str[32];
  15. _snwprintf(lang_str, 31, L"language=%x", (int)lang_id);
  16. stop = CreateEvent(nullptr, false, false, nullptr);
  17. if (!stop.Valid())
  18. throw "Failed to create event";
  19. hr = SpFindBestToken(SPCAT_RECOGNIZERS, lang_str, nullptr, &token);
  20. if (FAILED(hr))
  21. throw HRError("SpFindBestToken failed", hr);
  22. hr = CoCreateInstance(CLSID_SpInprocRecognizer, nullptr, CLSCTX_ALL,
  23. __uuidof(ISpRecognizer), (void **)&recognizer);
  24. if (FAILED(hr))
  25. throw HRError("CoCreateInstance for recognizer failed", hr);
  26. hr = recognizer->SetRecognizer(token);
  27. if (FAILED(hr))
  28. throw HRError("SetRecognizer failed", hr);
  29. hr = recognizer->SetRecoState(SPRST_INACTIVE);
  30. if (FAILED(hr))
  31. throw HRError("SetRecoState(SPRST_INACTIVE) failed", hr);
  32. hr = recognizer->CreateRecoContext(&context);
  33. if (FAILED(hr))
  34. throw HRError("CreateRecoContext failed", hr);
  35. ULONGLONG interest = SPFEI(SPEI_RECOGNITION) |
  36. SPFEI(SPEI_END_SR_STREAM);
  37. hr = context->SetInterest(interest, interest);
  38. if (FAILED(hr))
  39. throw HRError("SetInterest failed", hr);
  40. hr = context->SetNotifyWin32Event();
  41. if (FAILED(hr))
  42. throw HRError("SetNotifyWin32Event", hr);
  43. notify = context->GetNotifyEventHandle();
  44. if (notify == INVALID_HANDLE_VALUE)
  45. throw HRError("GetNotifyEventHandle failed", E_NOINTERFACE);
  46. size_t sample_rate = audio_output_get_sample_rate(obs_get_audio());
  47. audio = new CaptionStream((DWORD)sample_rate, this);
  48. audio->Release();
  49. hr = recognizer->SetInput(audio, false);
  50. if (FAILED(hr))
  51. throw HRError("SetInput failed", hr);
  52. hr = context->CreateGrammar(1, &grammar);
  53. if (FAILED(hr))
  54. throw HRError("CreateGrammar failed", hr);
  55. hr = grammar->LoadDictation(nullptr, SPLO_STATIC);
  56. if (FAILED(hr))
  57. throw HRError("LoadDictation failed", hr);
  58. try {
  59. t = std::thread([this]() { main_thread(); });
  60. } catch (...) {
  61. throw "Failed to create thread";
  62. }
  63. } catch (const char *err) {
  64. blog(LOG_WARNING, "%s: %s", __FUNCTION__, err);
  65. throw CAPTIONS_ERROR_GENERIC_FAIL;
  66. } catch (HRError err) {
  67. blog(LOG_WARNING, "%s: %s (%lX)", __FUNCTION__, err.str, err.hr);
  68. throw CAPTIONS_ERROR_GENERIC_FAIL;
  69. }
  70. mssapi_captions::~mssapi_captions()
  71. {
  72. if (t.joinable()) {
  73. SetEvent(stop);
  74. t.join();
  75. }
  76. }
  77. void mssapi_captions::main_thread()
  78. try {
  79. HRESULT hr;
  80. os_set_thread_name(__FUNCTION__);
  81. hr = grammar->SetDictationState(SPRS_ACTIVE);
  82. if (FAILED(hr))
  83. throw HRError("SetDictationState failed", hr);
  84. hr = recognizer->SetRecoState(SPRST_ACTIVE);
  85. if (FAILED(hr))
  86. throw HRError("SetRecoState(SPRST_ACTIVE) failed", hr);
  87. HANDLE events[] = {notify, stop};
  88. started = true;
  89. for (;;) {
  90. DWORD ret = WaitForMultipleObjects(2, events, false, INFINITE);
  91. if (ret != WAIT_OBJECT_0)
  92. break;
  93. CSpEvent event;
  94. bool exit = false;
  95. while (event.GetFrom(context) == S_OK) {
  96. if (event.eEventId == SPEI_RECOGNITION) {
  97. ISpRecoResult *result = event.RecoResult();
  98. CoTaskMemPtr<wchar_t> text;
  99. hr = result->GetText((ULONG)-1, (ULONG)-1, true,
  100. &text, nullptr);
  101. if (FAILED(hr))
  102. continue;
  103. char text_utf8[512];
  104. os_wcs_to_utf8(text, 0, text_utf8, 512);
  105. callback(text_utf8);
  106. blog(LOG_DEBUG, "\"%s\"", text_utf8);
  107. } else if (event.eEventId == SPEI_END_SR_STREAM) {
  108. exit = true;
  109. break;
  110. }
  111. }
  112. if (exit)
  113. break;
  114. }
  115. audio->Stop();
  116. } catch (HRError err) {
  117. blog(LOG_WARNING, "%s failed: %s (%lX)", __FUNCTION__, err.str, err.hr);
  118. }
  119. void mssapi_captions::pcm_data(const void *data, size_t frames)
  120. {
  121. if (started)
  122. audio->PushAudio(data, frames);
  123. }
  124. captions_handler_info mssapi_info = {
  125. []() -> std::string { return "Microsoft Speech-to-Text"; },
  126. [](captions_cb cb, const std::string &lang) -> captions_handler * {
  127. return new mssapi_captions(cb, lang);
  128. }};