1
0

captions-mssapi.cpp 4.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. #include "captions-mssapi.hpp"
  2. #define do_log(type, format, ...) blog(type, "[Captions] " format, \
  3. ##__VA_ARGS__)
  4. #define error(format, ...) do_log(LOG_ERROR, format, ##__VA_ARGS__)
  5. #define debug(format, ...) do_log(LOG_DEBUG, format, ##__VA_ARGS__)
  6. mssapi_captions::mssapi_captions(
  7. captions_cb callback,
  8. const std::string &lang) try
  9. : captions_handler(callback, AUDIO_FORMAT_16BIT, 16000)
  10. {
  11. HRESULT hr;
  12. std::wstring wlang;
  13. wlang.resize(lang.size());
  14. for (size_t i = 0; i < lang.size(); i++)
  15. wlang[i] = (wchar_t)lang[i];
  16. LCID lang_id = LocaleNameToLCID(wlang.c_str(), 0);
  17. wchar_t lang_str[32];
  18. _snwprintf(lang_str, 31, L"language=%x", (int)lang_id);
  19. stop = CreateEvent(nullptr, false, false, nullptr);
  20. if (!stop.Valid())
  21. throw "Failed to create event";
  22. hr = SpFindBestToken(SPCAT_RECOGNIZERS, lang_str, nullptr, &token);
  23. if (FAILED(hr))
  24. throw HRError("SpFindBestToken failed", hr);
  25. hr = CoCreateInstance(CLSID_SpInprocRecognizer, nullptr, CLSCTX_ALL,
  26. __uuidof(ISpRecognizer), (void**)&recognizer);
  27. if (FAILED(hr))
  28. throw HRError("CoCreateInstance for recognizer failed", hr);
  29. hr = recognizer->SetRecognizer(token);
  30. if (FAILED(hr))
  31. throw HRError("SetRecognizer failed", hr);
  32. hr = recognizer->SetRecoState(SPRST_INACTIVE);
  33. if (FAILED(hr))
  34. throw HRError("SetRecoState(SPRST_INACTIVE) failed", hr);
  35. hr = recognizer->CreateRecoContext(&context);
  36. if (FAILED(hr))
  37. throw HRError("CreateRecoContext failed", hr);
  38. ULONGLONG interest = SPFEI(SPEI_RECOGNITION) |
  39. SPFEI(SPEI_END_SR_STREAM);
  40. hr = context->SetInterest(interest, interest);
  41. if (FAILED(hr))
  42. throw HRError("SetInterest failed", hr);
  43. hr = context->SetNotifyWin32Event();
  44. if (FAILED(hr))
  45. throw HRError("SetNotifyWin32Event", hr);
  46. notify = context->GetNotifyEventHandle();
  47. if (notify == INVALID_HANDLE_VALUE)
  48. throw HRError("GetNotifyEventHandle failed", E_NOINTERFACE);
  49. size_t sample_rate = audio_output_get_sample_rate(obs_get_audio());
  50. audio = new CaptionStream((DWORD)sample_rate, this);
  51. audio->Release();
  52. hr = recognizer->SetInput(audio, false);
  53. if (FAILED(hr))
  54. throw HRError("SetInput failed", hr);
  55. hr = context->CreateGrammar(1, &grammar);
  56. if (FAILED(hr))
  57. throw HRError("CreateGrammar failed", hr);
  58. hr = grammar->LoadDictation(nullptr, SPLO_STATIC);
  59. if (FAILED(hr))
  60. throw HRError("LoadDictation failed", hr);
  61. try {
  62. t = std::thread([this] () {main_thread();});
  63. } catch (...) {
  64. throw "Failed to create thread";
  65. }
  66. } catch (const char *err) {
  67. blog(LOG_WARNING, "%s: %s", __FUNCTION__, err);
  68. throw CAPTIONS_ERROR_GENERIC_FAIL;
  69. } catch (HRError err) {
  70. blog(LOG_WARNING, "%s: %s (%lX)", __FUNCTION__, err.str, err.hr);
  71. throw CAPTIONS_ERROR_GENERIC_FAIL;
  72. }
  73. mssapi_captions::~mssapi_captions()
  74. {
  75. if (t.joinable()) {
  76. SetEvent(stop);
  77. t.join();
  78. }
  79. }
  80. void mssapi_captions::main_thread()
  81. try {
  82. HRESULT hr;
  83. os_set_thread_name(__FUNCTION__);
  84. hr = grammar->SetDictationState(SPRS_ACTIVE);
  85. if (FAILED(hr))
  86. throw HRError("SetDictationState failed", hr);
  87. hr = recognizer->SetRecoState(SPRST_ACTIVE);
  88. if (FAILED(hr))
  89. throw HRError("SetRecoState(SPRST_ACTIVE) failed", hr);
  90. HANDLE events[] = {notify, stop};
  91. started = true;
  92. for (;;) {
  93. DWORD ret = WaitForMultipleObjects(2, events, false, INFINITE);
  94. if (ret != WAIT_OBJECT_0)
  95. break;
  96. CSpEvent event;
  97. bool exit = false;
  98. while (event.GetFrom(context) == S_OK) {
  99. if (event.eEventId == SPEI_RECOGNITION) {
  100. ISpRecoResult *result = event.RecoResult();
  101. CoTaskMemPtr<wchar_t> text;
  102. hr = result->GetText((ULONG)-1, (ULONG)-1,
  103. true, &text, nullptr);
  104. if (FAILED(hr))
  105. continue;
  106. char text_utf8[512];
  107. os_wcs_to_utf8(text, 0, text_utf8, 512);
  108. callback(text_utf8);
  109. blog(LOG_DEBUG, "\"%s\"", text_utf8);
  110. } else if (event.eEventId == SPEI_END_SR_STREAM) {
  111. exit = true;
  112. break;
  113. }
  114. }
  115. if (exit)
  116. break;
  117. }
  118. audio->Stop();
  119. } catch (HRError err) {
  120. blog(LOG_WARNING, "%s failed: %s (%lX)", __FUNCTION__, err.str, err.hr);
  121. }
  122. void mssapi_captions::pcm_data(const void *data, size_t frames)
  123. {
  124. if (started)
  125. audio->PushAudio(data, frames);
  126. }
  127. captions_handler_info mssapi_info = {
  128. [] () -> std::string
  129. {
  130. return "Microsoft Speech-to-Text";
  131. },
  132. [] (captions_cb cb, const std::string &lang) -> captions_handler *
  133. {
  134. return new mssapi_captions(cb, lang);
  135. }
  136. };