captions.cpp 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354
  1. #include <obs-frontend-api.h>
  2. #include "captions-stream.hpp"
  3. #include "captions.hpp"
  4. #include "tool-helpers.hpp"
  5. #include <sphelper.h>
  6. #include <util/platform.h>
  7. #include <util/windows/HRError.hpp>
  8. #include <util/windows/ComPtr.hpp>
  9. #include <util/windows/CoTaskMemPtr.hpp>
  10. #include <util/threading.h>
  11. #include <obs-module.h>
  12. #include <string>
  13. #include <thread>
  14. #include <mutex>
  15. #define do_log(type, format, ...) blog(type, "[Captions] " format, \
  16. ##__VA_ARGS__)
  17. #define error(format, ...) do_log(LOG_ERROR, format, ##__VA_ARGS__)
  18. #define debug(format, ...) do_log(LOG_DEBUG, format, ##__VA_ARGS__)
  19. using namespace std;
  20. struct obs_captions {
  21. thread th;
  22. recursive_mutex m;
  23. WinHandle stop_event;
  24. string source_name;
  25. OBSWeakSource source;
  26. void main_thread();
  27. void start();
  28. void stop();
  29. inline obs_captions() :
  30. stop_event(CreateEvent(nullptr, false, false, nullptr))
  31. {
  32. }
  33. inline ~obs_captions() {stop();}
  34. };
  35. static obs_captions *captions = nullptr;
  36. /* ------------------------------------------------------------------------- */
  37. CaptionsDialog::CaptionsDialog(QWidget *parent) :
  38. QDialog(parent),
  39. ui(new Ui_CaptionsDialog)
  40. {
  41. ui->setupUi(this);
  42. lock_guard<recursive_mutex> lock(captions->m);
  43. auto cb = [this] (obs_source_t *source)
  44. {
  45. uint32_t caps = obs_source_get_output_flags(source);
  46. QString name = obs_source_get_name(source);
  47. if (caps & OBS_SOURCE_AUDIO)
  48. ui->source->addItem(name);
  49. OBSWeakSource weak = OBSGetWeakRef(source);
  50. if (weak == captions->source)
  51. ui->source->setCurrentText(name);
  52. return true;
  53. };
  54. using cb_t = decltype(cb);
  55. ui->source->blockSignals(true);
  56. ui->source->addItem(QStringLiteral(""));
  57. ui->source->setCurrentIndex(0);
  58. obs_enum_sources([] (void *data, obs_source_t *source) {
  59. return (*static_cast<cb_t*>(data))(source);}, &cb);
  60. ui->source->blockSignals(false);
  61. ui->enable->blockSignals(true);
  62. ui->enable->setChecked(captions->th.joinable());
  63. ui->enable->blockSignals(false);
  64. }
  65. void CaptionsDialog::on_source_currentIndexChanged(int)
  66. {
  67. bool started = captions->th.joinable();
  68. if (started)
  69. captions->stop();
  70. captions->m.lock();
  71. captions->source_name = ui->source->currentText().toUtf8().constData();
  72. captions->source = GetWeakSourceByName(captions->source_name.c_str());
  73. captions->m.unlock();
  74. if (started)
  75. captions->start();
  76. }
  77. void CaptionsDialog::on_enable_clicked(bool checked)
  78. {
  79. if (checked)
  80. captions->start();
  81. else
  82. captions->stop();
  83. }
  84. /* ------------------------------------------------------------------------- */
  85. void obs_captions::main_thread()
  86. try {
  87. ComPtr<CaptionStream> audio;
  88. ComPtr<ISpObjectToken> token;
  89. ComPtr<ISpRecoGrammar> grammar;
  90. ComPtr<ISpRecognizer> recognizer;
  91. ComPtr<ISpRecoContext> context;
  92. HRESULT hr;
  93. auto cb = [&] (const struct audio_data *audio_data,
  94. bool muted)
  95. {
  96. audio->PushAudio(audio_data, muted);
  97. };
  98. using cb_t = decltype(cb);
  99. auto pre_cb = [] (void *param, obs_source_t*,
  100. const struct audio_data *audio_data, bool muted)
  101. {
  102. return (*static_cast<cb_t*>(param))(audio_data, muted);
  103. };
  104. os_set_thread_name(__FUNCTION__);
  105. CoInitialize(nullptr);
  106. hr = SpFindBestToken(SPCAT_RECOGNIZERS, L"language=409", nullptr,
  107. &token);
  108. if (FAILED(hr))
  109. throw HRError("SpFindBestToken failed", hr);
  110. hr = CoCreateInstance(CLSID_SpInprocRecognizer, nullptr, CLSCTX_ALL,
  111. __uuidof(ISpRecognizer), (void**)&recognizer);
  112. if (FAILED(hr))
  113. throw HRError("CoCreateInstance for recognizer failed", hr);
  114. hr = recognizer->SetRecognizer(token);
  115. if (FAILED(hr))
  116. throw HRError("SetRecognizer failed", hr);
  117. hr = recognizer->SetRecoState(SPRST_INACTIVE);
  118. if (FAILED(hr))
  119. throw HRError("SetRecoState(SPRST_INACTIVE) failed", hr);
  120. hr = recognizer->CreateRecoContext(&context);
  121. if (FAILED(hr))
  122. throw HRError("CreateRecoContext failed", hr);
  123. ULONGLONG interest = SPFEI(SPEI_RECOGNITION) |
  124. SPFEI(SPEI_END_SR_STREAM);
  125. hr = context->SetInterest(interest, interest);
  126. if (FAILED(hr))
  127. throw HRError("SetInterest failed", hr);
  128. HANDLE notify;
  129. hr = context->SetNotifyWin32Event();
  130. if (FAILED(hr))
  131. throw HRError("SetNotifyWin32Event", hr);
  132. notify = context->GetNotifyEventHandle();
  133. if (notify == INVALID_HANDLE_VALUE)
  134. throw HRError("GetNotifyEventHandle failed", E_NOINTERFACE);
  135. size_t sample_rate = audio_output_get_sample_rate(obs_get_audio());
  136. audio = new CaptionStream((DWORD)sample_rate);
  137. audio->Release();
  138. hr = recognizer->SetInput(audio, false);
  139. if (FAILED(hr))
  140. throw HRError("SetInput failed", hr);
  141. hr = context->CreateGrammar(1, &grammar);
  142. if (FAILED(hr))
  143. throw HRError("CreateGrammar failed", hr);
  144. hr = grammar->LoadDictation(nullptr, SPLO_STATIC);
  145. if (FAILED(hr))
  146. throw HRError("LoadDictation failed", hr);
  147. hr = grammar->SetDictationState(SPRS_ACTIVE);
  148. if (FAILED(hr))
  149. throw HRError("SetDictationState failed", hr);
  150. hr = recognizer->SetRecoState(SPRST_ACTIVE);
  151. if (FAILED(hr))
  152. throw HRError("SetRecoState(SPRST_ACTIVE) failed", hr);
  153. HANDLE events[] = {notify, stop_event};
  154. {
  155. captions->source = GetWeakSourceByName(
  156. captions->source_name.c_str());
  157. OBSSource strong = OBSGetStrongRef(source);
  158. if (strong)
  159. obs_source_add_audio_capture_callback(strong,
  160. pre_cb, &cb);
  161. }
  162. for (;;) {
  163. DWORD ret = WaitForMultipleObjects(2, events, false, INFINITE);
  164. if (ret != WAIT_OBJECT_0)
  165. break;
  166. CSpEvent event;
  167. bool exit = false;
  168. while (event.GetFrom(context) == S_OK) {
  169. if (event.eEventId == SPEI_RECOGNITION) {
  170. ISpRecoResult *result = event.RecoResult();
  171. CoTaskMemPtr<wchar_t> text;
  172. hr = result->GetText((ULONG)-1, (ULONG)-1,
  173. true, &text, nullptr);
  174. if (FAILED(hr))
  175. continue;
  176. char text_utf8[512];
  177. os_wcs_to_utf8(text, 0, text_utf8, 512);
  178. obs_output_t *output =
  179. obs_frontend_get_streaming_output();
  180. if (output)
  181. obs_output_output_caption_text1(output,
  182. text_utf8);
  183. debug("\"%s\"", text_utf8);
  184. obs_output_release(output);
  185. } else if (event.eEventId == SPEI_END_SR_STREAM) {
  186. exit = true;
  187. break;
  188. }
  189. }
  190. if (exit)
  191. break;
  192. }
  193. {
  194. OBSSource strong = OBSGetStrongRef(source);
  195. if (strong)
  196. obs_source_remove_audio_capture_callback(strong,
  197. pre_cb, &cb);
  198. }
  199. audio->Stop();
  200. CoUninitialize();
  201. } catch (HRError err) {
  202. error("%s failed: %s (%lX)", __FUNCTION__, err.str, err.hr);
  203. CoUninitialize();
  204. }
  205. void obs_captions::start()
  206. {
  207. if (!captions->th.joinable())
  208. captions->th = thread([] () {captions->main_thread();});
  209. }
  210. void obs_captions::stop()
  211. {
  212. if (!captions->th.joinable())
  213. return;
  214. SetEvent(captions->stop_event);
  215. captions->th.join();
  216. }
  217. /* ------------------------------------------------------------------------- */
/*
 * Destroys the module-wide captions state (exported with C linkage).
 */
extern "C" void FreeCaptions()
{
	/* Order matters: ~obs_captions() calls stop(), which goes through the
	 * global `captions` pointer, so the pointer may only be cleared after
	 * the delete has completed. */
	delete captions;
	captions = nullptr;
}
  223. static void obs_event(enum obs_frontend_event event, void *)
  224. {
  225. if (event == OBS_FRONTEND_EVENT_EXIT)
  226. FreeCaptions();
  227. }
  228. static void save_caption_data(obs_data_t *save_data, bool saving, void*)
  229. {
  230. if (saving) {
  231. lock_guard<recursive_mutex> lock(captions->m);
  232. obs_data_t *obj = obs_data_create();
  233. obs_data_set_string(obj, "source",
  234. captions->source_name.c_str());
  235. obs_data_set_bool(obj, "enabled", captions->th.joinable());
  236. obs_data_set_obj(save_data, "captions", obj);
  237. obs_data_release(obj);
  238. } else {
  239. captions->stop();
  240. captions->m.lock();
  241. obs_data_t *obj = obs_data_get_obj(save_data, "captions");
  242. if (!obj)
  243. obj = obs_data_create();
  244. bool enabled = obs_data_get_bool(obj, "enabled");
  245. captions->source_name = obs_data_get_string(obj, "source");
  246. captions->source = GetWeakSourceByName(
  247. captions->source_name.c_str());
  248. obs_data_release(obj);
  249. captions->m.unlock();
  250. if (enabled)
  251. captions->start();
  252. }
  253. }
  254. extern "C" void InitCaptions()
  255. {
  256. QAction *action = (QAction*)obs_frontend_add_tools_menu_qaction(
  257. obs_module_text("Captions"));
  258. captions = new obs_captions;
  259. auto cb = [] ()
  260. {
  261. obs_frontend_push_ui_translation(obs_module_get_string);
  262. QWidget *window =
  263. (QWidget*)obs_frontend_get_main_window();
  264. CaptionsDialog dialog(window);
  265. dialog.exec();
  266. obs_frontend_pop_ui_translation();
  267. };
  268. obs_frontend_add_save_callback(save_caption_data, nullptr);
  269. obs_frontend_add_event_callback(obs_event, nullptr);
  270. action->connect(action, &QAction::triggered, cb);
  271. }