/* updater.cpp (48 KB) */


  1. /*
  2. * Copyright (c) 2023 Lain Bailey <[email protected]>
  3. *
  4. * Permission to use, copy, modify, and distribute this software for any
  5. * purpose with or without fee is hereby granted, provided that the above
  6. * copyright notice and this permission notice appear in all copies.
  7. *
  8. * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
  9. * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
  10. * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
  11. * ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
  12. * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
  13. * ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
  14. * OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
  15. */
  16. #include "updater.hpp"
  17. #include "manifest.hpp"
  18. #include <psapi.h>
  19. #include <WinTrust.h>
  20. #include <SoftPub.h>
  21. #include <util/windows/CoTaskMemPtr.hpp>
  22. #include <future>
  23. #include <string>
  24. #include <string_view>
  25. #include <mutex>
  26. #include <unordered_set>
  27. #include <queue>
  28. using namespace std;
  29. using namespace updater;
  30. /* ----------------------------------------------------------------------- */
  31. constexpr const string_view kCDNUrl = "https://cdn-fastly.obsproject.com/";
  32. constexpr const wchar_t *kCDNHostname = L"cdn-fastly.obsproject.com";
  33. constexpr const wchar_t *kCDNUpdateBaseUrl = L"https://cdn-fastly.obsproject.com/update_studio";
  34. constexpr const wchar_t *kPatchManifestURL = L"https://obsproject.com/update_studio/getpatchmanifest";
  35. constexpr const wchar_t *kVSRedistURL = L"https://aka.ms/vs/17/release/vc_redist.x64.exe";
  36. constexpr const wchar_t *kMSHostname = L"aka.ms";
  37. /* ----------------------------------------------------------------------- */
  38. HANDLE cancelRequested = nullptr;
  39. HANDLE updateThread = nullptr;
  40. HINSTANCE hinstMain = nullptr;
  41. HWND hwndMain = nullptr;
  42. HCRYPTPROV hProvider = 0;
  43. static bool bExiting = false;
  44. static bool updateFailed = false;
  45. static bool downloadThreadFailure = false;
  46. size_t totalFileSize = 0;
  47. size_t completedFileSize = 0;
  48. static int completedUpdates = 0;
  49. static wchar_t tempPath[MAX_PATH];
  50. static wchar_t obs_base_directory[MAX_PATH];
  51. struct LastError {
  52. DWORD code;
  53. inline LastError() { code = GetLastError(); }
  54. };
  55. void FreeWinHttpHandle(HINTERNET handle)
  56. {
  57. WinHttpCloseHandle(handle);
  58. }
  59. /* ----------------------------------------------------------------------- */
  60. static bool IsVSRedistOutdated()
  61. {
  62. VS_FIXEDFILEINFO *info = nullptr;
  63. UINT len = 0;
  64. vector<std::byte> buf;
  65. const wchar_t vc_dll[] = L"msvcp140";
  66. auto size = GetFileVersionInfoSize(vc_dll, nullptr);
  67. if (!size)
  68. return true;
  69. buf.resize(size);
  70. if (!GetFileVersionInfo(vc_dll, 0, size, buf.data()))
  71. return true;
  72. bool success = VerQueryValue(buf.data(), L"\\", reinterpret_cast<LPVOID *>(&info), &len);
  73. if (!success || !info || !len)
  74. return true;
  75. return LOWORD(info->dwFileVersionMS) < 40;
  76. }
  77. static void Status(const wchar_t *fmt, ...)
  78. {
  79. wchar_t str[512];
  80. va_list argptr;
  81. va_start(argptr, fmt);
  82. StringCbVPrintf(str, sizeof(str), fmt, argptr);
  83. SetDlgItemText(hwndMain, IDC_STATUS, str);
  84. va_end(argptr);
  85. }
  86. static bool MyCopyFile(const wchar_t *src, const wchar_t *dest)
  87. try {
  88. WinHandle hSrc;
  89. WinHandle hDest;
  90. hSrc = CreateFile(src, GENERIC_READ, FILE_SHARE_READ, nullptr, OPEN_EXISTING, FILE_FLAG_SEQUENTIAL_SCAN,
  91. nullptr);
  92. if (!hSrc.Valid())
  93. throw LastError();
  94. hDest = CreateFile(dest, GENERIC_WRITE, 0, nullptr, CREATE_ALWAYS, 0, nullptr);
  95. if (!hDest.Valid())
  96. throw LastError();
  97. BYTE buf[65536];
  98. DWORD read, wrote;
  99. for (;;) {
  100. if (!ReadFile(hSrc, buf, sizeof(buf), &read, nullptr))
  101. throw LastError();
  102. if (read == 0)
  103. break;
  104. if (!WriteFile(hDest, buf, read, &wrote, nullptr))
  105. throw LastError();
  106. if (wrote != read)
  107. return false;
  108. }
  109. return true;
  110. } catch (LastError &error) {
  111. SetLastError(error.code);
  112. return false;
  113. }
  114. static void MyDeleteFile(const wstring &filename)
  115. {
  116. /* Try straightforward delete first */
  117. if (DeleteFile(filename.c_str()))
  118. return;
  119. DWORD err = GetLastError();
  120. if (err == ERROR_FILE_NOT_FOUND)
  121. return;
  122. /* If all else fails, schedule the file to be deleted on reboot */
  123. MoveFileEx(filename.c_str(), nullptr, MOVEFILE_DELAY_UNTIL_REBOOT);
  124. }
  125. static bool IsSafeFilename(const wchar_t *path)
  126. {
  127. const wchar_t *p = path;
  128. if (!*p)
  129. return false;
  130. if (wcsstr(path, L".."))
  131. return false;
  132. if (*p == '/')
  133. return false;
  134. while (*p) {
  135. if (!isalnum(*p) && *p != '.' && *p != '/' && *p != '_' && *p != '-')
  136. return false;
  137. p++;
  138. }
  139. return true;
  140. }
  141. static string QuickReadFile(const wchar_t *path)
  142. {
  143. string data;
  144. WinHandle handle = CreateFileW(path, GENERIC_READ, 0, nullptr, OPEN_EXISTING, 0, nullptr);
  145. if (!handle.Valid()) {
  146. return {};
  147. }
  148. LARGE_INTEGER size;
  149. if (!GetFileSizeEx(handle, &size)) {
  150. return {};
  151. }
  152. data.resize((size_t)size.QuadPart);
  153. DWORD read;
  154. if (!ReadFile(handle, data.data(), (DWORD)data.size(), &read, nullptr)) {
  155. return {};
  156. }
  157. if (read != size.QuadPart) {
  158. return {};
  159. }
  160. return data;
  161. }
  162. static bool QuickWriteFile(const wchar_t *file, const void *data, size_t size)
  163. try {
  164. WinHandle handle = CreateFile(file, GENERIC_WRITE, 0, nullptr, CREATE_ALWAYS, FILE_FLAG_WRITE_THROUGH, nullptr);
  165. if (handle == INVALID_HANDLE_VALUE)
  166. throw GetLastError();
  167. DWORD written;
  168. if (!WriteFile(handle, data, (DWORD)size, &written, nullptr))
  169. throw GetLastError();
  170. return true;
  171. } catch (LastError &error) {
  172. SetLastError(error.code);
  173. return false;
  174. }
  175. /* ----------------------------------------------------------------------- */
  176. /* Extend std::hash for B2Hash */
  177. template<> struct std::hash<B2Hash> {
  178. size_t operator()(const B2Hash &value) const noexcept
  179. {
  180. return hash<string_view>{}(string_view(reinterpret_cast<const char *>(value.data()), value.size()));
  181. }
  182. };
  183. enum state_t {
  184. STATE_INVALID,
  185. STATE_PENDING_DOWNLOAD,
  186. STATE_DOWNLOADING,
  187. STATE_DOWNLOADED,
  188. STATE_ALREADY_DOWNLOADED,
  189. STATE_INSTALL_FAILED,
  190. STATE_INSTALLED,
  191. };
  192. struct update_t {
  193. wstring sourceURL;
  194. wstring outputPath;
  195. wstring previousFile;
  196. string packageName;
  197. B2Hash hash;
  198. B2Hash my_hash;
  199. B2Hash downloadHash;
  200. size_t fileSize = 0;
  201. state_t state = STATE_INVALID;
  202. bool has_hash = false;
  203. bool patchable = false;
  204. bool compressed = false;
  205. update_t() = default;
  206. update_t(const update_t &from) = default;
  207. update_t(update_t &&from) noexcept
  208. : sourceURL(std::move(from.sourceURL)),
  209. outputPath(std::move(from.outputPath)),
  210. previousFile(std::move(from.previousFile)),
  211. packageName(std::move(from.packageName)),
  212. hash(from.hash),
  213. my_hash(from.my_hash),
  214. downloadHash(from.downloadHash),
  215. fileSize(from.fileSize),
  216. state(from.state),
  217. has_hash(from.has_hash),
  218. patchable(from.patchable),
  219. compressed(from.compressed)
  220. {
  221. from.state = STATE_INVALID;
  222. }
  223. void CleanPartialUpdate() const
  224. {
  225. if (state == STATE_INSTALL_FAILED || state == STATE_INSTALLED) {
  226. if (!previousFile.empty()) {
  227. DeleteFile(outputPath.c_str());
  228. MyCopyFile(previousFile.c_str(), outputPath.c_str());
  229. DeleteFile(previousFile.c_str());
  230. } else {
  231. DeleteFile(outputPath.c_str());
  232. }
  233. }
  234. }
  235. update_t &operator=(const update_t &from) = default;
  236. };
  237. struct deletion_t {
  238. wstring originalFilename;
  239. wstring deleteMeFilename;
  240. void UndoRename() const
  241. {
  242. if (!deleteMeFilename.empty())
  243. MoveFile(deleteMeFilename.c_str(), originalFilename.c_str());
  244. }
  245. };
  246. static unordered_map<B2Hash, vector<std::byte>> download_data;
  247. static unordered_map<string, B2Hash> hashes;
  248. static vector<update_t> updates;
  249. static vector<deletion_t> deletions;
  250. static mutex updateMutex;
  251. static inline void CleanupPartialUpdates()
  252. {
  253. for (update_t &update : updates)
  254. update.CleanPartialUpdate();
  255. for (deletion_t &deletion : deletions)
  256. deletion.UndoRename();
  257. }
  258. /* ----------------------------------------------------------------------- */
  259. static int Decompress(ZSTD_DCtx *ctx, std::vector<std::byte> &buf, size_t size)
  260. {
  261. // Copy compressed data
  262. vector<std::byte> comp(buf.begin(), buf.end());
  263. try {
  264. buf.resize(size);
  265. } catch (...) {
  266. return -1;
  267. }
  268. // Overwrite buffer with decompressed data
  269. size_t result = ZSTD_decompressDCtx(ctx, buf.data(), buf.size(), comp.data(), comp.size());
  270. if (result != size)
  271. return -9;
  272. if (ZSTD_isError(result))
  273. return -10;
  274. return 0;
  275. }
  276. bool DownloadWorkerThread()
  277. {
  278. const DWORD tlsProtocols = WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_2 | WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_3;
  279. const DWORD enableHTTP2Flag = WINHTTP_PROTOCOL_FLAG_HTTP2;
  280. const DWORD compressionFlags = WINHTTP_DECOMPRESSION_FLAG_ALL;
  281. HttpHandle hSession = WinHttpOpen(L"OBS Studio Updater/3.0", WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY,
  282. WINHTTP_NO_PROXY_NAME, WINHTTP_NO_PROXY_BYPASS, 0);
  283. if (!hSession) {
  284. downloadThreadFailure = true;
  285. Status(L"Update failed: Couldn't open obsproject.com");
  286. return false;
  287. }
  288. WinHttpSetOption(hSession, WINHTTP_OPTION_SECURE_PROTOCOLS, (LPVOID)&tlsProtocols, sizeof(tlsProtocols));
  289. WinHttpSetOption(hSession, WINHTTP_OPTION_ENABLE_HTTP_PROTOCOL, (LPVOID)&enableHTTP2Flag,
  290. sizeof(enableHTTP2Flag));
  291. WinHttpSetOption(hSession, WINHTTP_OPTION_DECOMPRESSION, (LPVOID)&compressionFlags, sizeof(compressionFlags));
  292. HttpHandle hConnect = WinHttpConnect(hSession, kCDNHostname, INTERNET_DEFAULT_HTTPS_PORT, 0);
  293. if (!hConnect) {
  294. downloadThreadFailure = true;
  295. Status(L"Update failed: Couldn't connect to %S", kCDNHostname);
  296. return false;
  297. }
  298. ZSTDDCtx zCtx;
  299. for (;;) {
  300. bool foundWork = false;
  301. unique_lock<mutex> ulock(updateMutex);
  302. for (update_t &update : updates) {
  303. int responseCode;
  304. DWORD waitResult = WaitForSingleObject(cancelRequested, 0);
  305. if (waitResult == WAIT_OBJECT_0) {
  306. return false;
  307. }
  308. if (update.state != STATE_PENDING_DOWNLOAD)
  309. continue;
  310. update.state = STATE_DOWNLOADING;
  311. ulock.unlock();
  312. foundWork = true;
  313. if (downloadThreadFailure) {
  314. return false;
  315. }
  316. auto &buf = download_data[update.downloadHash];
  317. /* Reserve required memory */
  318. buf.reserve(update.fileSize);
  319. if (!HTTPGetBuffer(hConnect, update.sourceURL.c_str(), L"Accept-Encoding: gzip", buf,
  320. &responseCode)) {
  321. downloadThreadFailure = true;
  322. Status(L"Update failed: Could not download "
  323. L"%s (error code %d)",
  324. update.outputPath.c_str(), responseCode);
  325. return true;
  326. }
  327. if (responseCode != 200) {
  328. downloadThreadFailure = true;
  329. Status(L"Update failed: Could not download "
  330. L"%s (error code %d)",
  331. update.outputPath.c_str(), responseCode);
  332. return true;
  333. }
  334. /* Validate hash of downloaded data. */
  335. B2Hash dataHash;
  336. blake2b(dataHash.data(), dataHash.size(), buf.data(), buf.size(), nullptr, 0);
  337. if (dataHash != update.downloadHash) {
  338. downloadThreadFailure = true;
  339. Status(L"Update failed: Integrity check "
  340. L"failed on %s",
  341. update.outputPath.c_str());
  342. return true;
  343. }
  344. /* Decompress data in compressed buffer. */
  345. if (update.compressed && !update.patchable) {
  346. int res = Decompress(zCtx, buf, update.fileSize);
  347. if (res) {
  348. downloadThreadFailure = true;
  349. Status(L"Update failed: Decompression "
  350. L"failed on %s (error code %d)",
  351. update.outputPath.c_str(), res);
  352. return true;
  353. }
  354. }
  355. ulock.lock();
  356. update.state = STATE_DOWNLOADED;
  357. completedUpdates++;
  358. }
  359. if (!foundWork) {
  360. break;
  361. }
  362. if (downloadThreadFailure) {
  363. return false;
  364. }
  365. }
  366. return true;
  367. }
  368. static bool RunDownloadWorkers(int num)
  369. try {
  370. vector<future<bool>> thread_success_results;
  371. thread_success_results.resize(num);
  372. for (future<bool> &result : thread_success_results) {
  373. result = async(DownloadWorkerThread);
  374. }
  375. for (future<bool> &result : thread_success_results) {
  376. if (!result.get()) {
  377. return false;
  378. }
  379. }
  380. return true;
  381. } catch (...) {
  382. return false;
  383. }
  384. /* ----------------------------------------------------------------------- */
  385. enum { WAITIFOBS_SUCCESS, WAITIFOBS_WRONG_PROCESS, WAITIFOBS_CANCELLED };
  386. static inline DWORD WaitIfOBS(DWORD id, const wchar_t *expected)
  387. {
  388. wchar_t path[MAX_PATH];
  389. wchar_t *name;
  390. DWORD path_len = _countof(path);
  391. *path = 0;
  392. WinHandle proc = OpenProcess(PROCESS_QUERY_INFORMATION | PROCESS_VM_READ | SYNCHRONIZE, false, id);
  393. if (!proc.Valid())
  394. return WAITIFOBS_WRONG_PROCESS;
  395. if (!QueryFullProcessImageNameW(proc, 0, path, &path_len))
  396. return WAITIFOBS_WRONG_PROCESS;
  397. // check it's actually our exe that's running
  398. size_t len = wcslen(obs_base_directory);
  399. if (wcsncmp(path, obs_base_directory, len) != 0)
  400. return WAITIFOBS_WRONG_PROCESS;
  401. name = wcsrchr(path, L'\\');
  402. if (name)
  403. name += 1;
  404. else
  405. name = path;
  406. if (_wcsnicmp(name, expected, 5) == 0) {
  407. HANDLE hWait[2];
  408. hWait[0] = proc;
  409. hWait[1] = cancelRequested;
  410. int i = WaitForMultipleObjects(2, hWait, false, INFINITE);
  411. if (i == WAIT_OBJECT_0 + 1)
  412. return WAITIFOBS_CANCELLED;
  413. return WAITIFOBS_SUCCESS;
  414. }
  415. return WAITIFOBS_WRONG_PROCESS;
  416. }
  417. static bool WaitForOBS()
  418. {
  419. DWORD proc_ids[1024], needed, count;
  420. if (!EnumProcesses(proc_ids, sizeof(proc_ids), &needed)) {
  421. return true;
  422. }
  423. count = needed / sizeof(DWORD);
  424. for (DWORD i = 0; i < count; i++) {
  425. DWORD id = proc_ids[i];
  426. if (id != 0) {
  427. switch (WaitIfOBS(id, L"obs64")) {
  428. case WAITIFOBS_SUCCESS:
  429. return true;
  430. case WAITIFOBS_WRONG_PROCESS:
  431. break;
  432. case WAITIFOBS_CANCELLED:
  433. return false;
  434. }
  435. }
  436. }
  437. return true;
  438. }
  439. /* ----------------------------------------------------------------------- */
  440. static inline bool UTF8ToWide(wchar_t *wide, int wideSize, const char *utf8)
  441. {
  442. return !!MultiByteToWideChar(CP_UTF8, 0, utf8, -1, wide, wideSize);
  443. }
  444. static inline bool WideToUTF8(char *utf8, int utf8Size, const wchar_t *wide)
  445. {
  446. return !!WideCharToMultiByte(CP_UTF8, 0, wide, -1, utf8, utf8Size, nullptr, nullptr);
  447. }
  448. #define UTF8ToWideBuf(wide, utf8) UTF8ToWide(wide, _countof(wide), utf8)
  449. #define WideToUTF8Buf(utf8, wide) WideToUTF8(utf8, _countof(utf8), wide)
  450. /* ----------------------------------------------------------------------- */
  451. queue<string> hashQueue;
  452. void HasherThread()
  453. {
  454. unique_lock ulock(updateMutex, defer_lock);
  455. while (true) {
  456. ulock.lock();
  457. if (hashQueue.empty())
  458. return;
  459. auto fileName = hashQueue.front();
  460. hashQueue.pop();
  461. ulock.unlock();
  462. wchar_t updateFileName[MAX_PATH];
  463. if (!UTF8ToWideBuf(updateFileName, fileName.c_str()))
  464. continue;
  465. if (!IsSafeFilename(updateFileName))
  466. continue;
  467. B2Hash existingHash;
  468. if (CalculateFileHash(updateFileName, existingHash)) {
  469. ulock.lock();
  470. hashes.emplace(fileName, existingHash);
  471. ulock.unlock();
  472. }
  473. }
  474. }
  475. static void RunHasherWorkers(int num, const vector<Package> &packages)
  476. try {
  477. for (const Package &package : packages) {
  478. for (const File &file : package.files) {
  479. hashQueue.push(file.name);
  480. }
  481. }
  482. vector<future<void>> futures;
  483. futures.resize(num);
  484. for (auto &result : futures) {
  485. result = async(launch::async, HasherThread);
  486. }
  487. for (auto &result : futures) {
  488. result.wait();
  489. }
  490. } catch (...) {
  491. }
  492. /* ----------------------------------------------------------------------- */
  493. static inline bool FileExists(const wchar_t *path)
  494. {
  495. WIN32_FIND_DATAW wfd;
  496. HANDLE hFind;
  497. hFind = FindFirstFileW(path, &wfd);
  498. if (hFind != INVALID_HANDLE_VALUE)
  499. FindClose(hFind);
  500. return hFind != INVALID_HANDLE_VALUE;
  501. }
  502. static bool NonCorePackageInstalled(const char *name)
  503. {
  504. if (strcmp(name, "obs-browser") == 0)
  505. return FileExists(L"obs-plugins\\64bit\\obs-browser.dll");
  506. return false;
  507. }
  508. static bool AddPackageUpdateFiles(const Package &package, const wchar_t *branch)
  509. {
  510. wchar_t wPackageName[512];
  511. if (!UTF8ToWideBuf(wPackageName, package.name.c_str()))
  512. return false;
  513. if (package.name != "core" && !NonCorePackageInstalled(package.name.c_str()))
  514. return true;
  515. for (const File &file : package.files) {
  516. if (file.hash.size() != kBlake2StrLength)
  517. continue;
  518. /* The download hash may not exist if a file is uncompressed */
  519. bool compressed = false;
  520. if (file.compressed_hash.size() == kBlake2StrLength)
  521. compressed = true;
  522. /* convert strings to wide */
  523. wchar_t sourceURL[1024];
  524. wchar_t updateFileName[MAX_PATH];
  525. if (!UTF8ToWideBuf(updateFileName, file.name.c_str()))
  526. continue;
  527. /* make sure paths are safe */
  528. if (!IsSafeFilename(updateFileName)) {
  529. Status(L"Update failed: Unsafe path '%s' found in "
  530. L"manifest",
  531. updateFileName);
  532. return false;
  533. }
  534. StringCbPrintf(sourceURL, sizeof(sourceURL), L"%s/%s/%s/%s", kCDNUpdateBaseUrl, branch, wPackageName,
  535. updateFileName);
  536. /* Convert hashes */
  537. B2Hash updateHash;
  538. StringToHash(file.hash, updateHash);
  539. /* We don't really care if this fails, it's just to avoid
  540. * wasting bandwidth by downloading unmodified files */
  541. B2Hash localFileHash;
  542. bool has_hash = false;
  543. if (hashes.count(file.name)) {
  544. localFileHash = hashes[file.name];
  545. if (localFileHash == updateHash)
  546. continue;
  547. has_hash = true;
  548. }
  549. /* Add update file */
  550. update_t update;
  551. update.fileSize = file.size;
  552. update.outputPath = updateFileName;
  553. update.sourceURL = sourceURL;
  554. update.packageName = package.name;
  555. update.state = STATE_PENDING_DOWNLOAD;
  556. update.patchable = false;
  557. update.compressed = compressed;
  558. update.hash = updateHash;
  559. if (compressed) {
  560. update.sourceURL += L".zst";
  561. StringToHash(file.compressed_hash, update.downloadHash);
  562. } else {
  563. update.downloadHash = updateHash;
  564. }
  565. update.has_hash = has_hash;
  566. if (has_hash)
  567. update.my_hash = localFileHash;
  568. updates.push_back(std::move(update));
  569. totalFileSize += file.size;
  570. }
  571. return true;
  572. }
  573. static void AddPackageRemovedFiles(const Package &package)
  574. {
  575. for (const string &filename : package.removed_files) {
  576. wchar_t removedFileName[MAX_PATH];
  577. if (!UTF8ToWideBuf(removedFileName, filename.c_str()))
  578. continue;
  579. /* Ensure paths are safe, also check if file exists */
  580. if (!IsSafeFilename(removedFileName))
  581. continue;
  582. /* Technically GetFileAttributes can fail for other reasons,
  583. * so double-check by also checking the last error */
  584. if (GetFileAttributesW(removedFileName) == INVALID_FILE_ATTRIBUTES) {
  585. int err = GetLastError();
  586. if (err == ERROR_FILE_NOT_FOUND || err == ERROR_PATH_NOT_FOUND)
  587. continue;
  588. }
  589. deletion_t deletion;
  590. deletion.originalFilename = removedFileName;
  591. deletions.push_back(deletion);
  592. }
  593. }
  594. static bool RenameRemovedFile(deletion_t &deletion)
  595. {
  596. _TCHAR deleteMeName[MAX_PATH];
  597. _TCHAR randomStr[MAX_PATH];
  598. BYTE junk[40];
  599. B2Hash hash;
  600. string temp;
  601. CryptGenRandom(hProvider, sizeof(junk), junk);
  602. blake2b(hash.data(), hash.size(), junk, sizeof(junk), nullptr, 0);
  603. HashToString(hash, temp);
  604. if (!UTF8ToWideBuf(randomStr, temp.c_str()))
  605. return false;
  606. randomStr[8] = 0;
  607. StringCbCopy(deleteMeName, sizeof(deleteMeName), deletion.originalFilename.c_str());
  608. StringCbCat(deleteMeName, sizeof(deleteMeName), L".");
  609. StringCbCat(deleteMeName, sizeof(deleteMeName), randomStr);
  610. StringCbCat(deleteMeName, sizeof(deleteMeName), L".deleteme");
  611. if (MoveFile(deletion.originalFilename.c_str(), deleteMeName)) {
  612. /* Only set this if the file was successfully renamed */
  613. deletion.deleteMeFilename = deleteMeName;
  614. return true;
  615. }
  616. return false;
  617. }
  618. static void UpdateWithPatchIfAvailable(const PatchResponse &patch)
  619. {
  620. wchar_t widePatchableFilename[MAX_PATH];
  621. wchar_t sourceURL[1024];
  622. if (patch.source.compare(0, kCDNUrl.size(), kCDNUrl) != 0)
  623. return;
  624. if (patch.name.find('/') == string::npos)
  625. return;
  626. string patchPackageName(patch.name, 0, patch.name.find('/'));
  627. string fileName(patch.name, patch.name.find('/') + 1);
  628. if (!UTF8ToWideBuf(widePatchableFilename, fileName.c_str()))
  629. return;
  630. if (!UTF8ToWideBuf(sourceURL, patch.source.c_str()))
  631. return;
  632. for (update_t &update : updates) {
  633. if (update.packageName != patchPackageName)
  634. continue;
  635. if (update.outputPath != widePatchableFilename)
  636. continue;
  637. update.patchable = true;
  638. /* Replace the source URL with the patch file, update
  639. * the download hash, and re-calculate download size */
  640. StringToHash(patch.hash, update.downloadHash);
  641. update.sourceURL = sourceURL;
  642. totalFileSize -= (update.fileSize - patch.size);
  643. update.fileSize = patch.size;
  644. break;
  645. }
  646. }
  647. static bool MoveInUseFileAway(const update_t &file)
  648. {
  649. _TCHAR deleteMeName[MAX_PATH];
  650. _TCHAR randomStr[MAX_PATH];
  651. BYTE junk[40];
  652. B2Hash hash;
  653. string temp;
  654. CryptGenRandom(hProvider, sizeof(junk), junk);
  655. blake2b(hash.data(), hash.size(), junk, sizeof(junk), nullptr, 0);
  656. HashToString(hash, temp);
  657. if (!UTF8ToWideBuf(randomStr, temp.c_str()))
  658. return false;
  659. randomStr[8] = 0;
  660. StringCbCopy(deleteMeName, sizeof(deleteMeName), file.outputPath.c_str());
  661. StringCbCat(deleteMeName, sizeof(deleteMeName), L".");
  662. StringCbCat(deleteMeName, sizeof(deleteMeName), randomStr);
  663. StringCbCat(deleteMeName, sizeof(deleteMeName), L".deleteme");
  664. if (MoveFile(file.outputPath.c_str(), deleteMeName)) {
  665. if (MyCopyFile(deleteMeName, file.outputPath.c_str())) {
  666. MoveFileEx(deleteMeName, NULL, MOVEFILE_DELAY_UNTIL_REBOOT);
  667. return true;
  668. } else {
  669. MoveFile(deleteMeName, file.outputPath.c_str());
  670. }
  671. }
  672. return false;
  673. }
  674. static bool UpdateFile(ZSTD_DCtx *ctx, update_t &file)
  675. {
  676. wchar_t oldFileRenamedPath[MAX_PATH];
  677. /* Grab the patch/file data from the global cache. */
  678. vector<std::byte> &patch_data = download_data[file.downloadHash];
  679. /* Check if we're replacing an existing file or just installing a new
  680. * one */
  681. DWORD attribs = GetFileAttributes(file.outputPath.c_str());
  682. if (attribs != INVALID_FILE_ATTRIBUTES) {
  683. wchar_t baseName[MAX_PATH];
  684. StringCbCopy(baseName, sizeof(baseName), file.outputPath.c_str());
  685. wchar_t *curFileName = wcsrchr(baseName, '/');
  686. if (curFileName) {
  687. curFileName[0] = '\0';
  688. curFileName++;
  689. } else
  690. curFileName = baseName;
  691. /* Backup the existing file in case a rollback is needed */
  692. StringCbCopy(oldFileRenamedPath, sizeof(oldFileRenamedPath), file.outputPath.c_str());
  693. StringCbCat(oldFileRenamedPath, sizeof(oldFileRenamedPath), L".old");
  694. if (!MyCopyFile(file.outputPath.c_str(), oldFileRenamedPath)) {
  695. DWORD err = GetLastError();
  696. int is_sharing_violation = (err == ERROR_SHARING_VIOLATION || err == ERROR_USER_MAPPED_FILE);
  697. if (is_sharing_violation)
  698. Status(L"Update failed: %s is still in use. "
  699. L"Close all programs and try again.",
  700. curFileName);
  701. else
  702. Status(L"Update failed: Couldn't backup %s "
  703. L"(error %d)",
  704. curFileName, GetLastError());
  705. return false;
  706. }
  707. file.previousFile = oldFileRenamedPath;
  708. int error_code;
  709. bool installed_ok;
  710. bool already_tried_to_move = false;
  711. retryAfterMovingFile:
  712. if (file.patchable) {
  713. error_code = ApplyPatch(ctx, patch_data.data(), file.fileSize, file.outputPath.c_str());
  714. installed_ok = (error_code == 0);
  715. if (installed_ok) {
  716. B2Hash patchedFileHash;
  717. if (!CalculateFileHash(file.outputPath.c_str(), patchedFileHash)) {
  718. Status(L"Update failed: Couldn't "
  719. L"verify integrity of patched %s",
  720. curFileName);
  721. file.state = STATE_INSTALL_FAILED;
  722. return false;
  723. }
  724. if (file.hash != patchedFileHash) {
  725. Status(L"Update failed: Integrity "
  726. L"check of patched "
  727. L"%s failed",
  728. curFileName);
  729. file.state = STATE_INSTALL_FAILED;
  730. return false;
  731. }
  732. }
  733. } else {
  734. installed_ok = QuickWriteFile(file.outputPath.c_str(), patch_data.data(), patch_data.size());
  735. error_code = GetLastError();
  736. }
  737. if (!installed_ok) {
  738. int is_sharing_violation =
  739. (error_code == ERROR_SHARING_VIOLATION || error_code == ERROR_USER_MAPPED_FILE);
  740. if (is_sharing_violation) {
  741. if (!already_tried_to_move) {
  742. already_tried_to_move = true;
  743. if (MoveInUseFileAway(file))
  744. goto retryAfterMovingFile;
  745. }
  746. Status(L"Update failed: %s is still in use. "
  747. L"Close all "
  748. L"programs and try again.",
  749. curFileName);
  750. } else {
  751. DWORD err = GetLastError();
  752. Status(L"Update failed: Couldn't update %s "
  753. L"(error %d)",
  754. curFileName, err ? err : error_code);
  755. }
  756. file.state = STATE_INSTALL_FAILED;
  757. return false;
  758. }
  759. file.state = STATE_INSTALLED;
  760. } else {
  761. if (file.patchable) {
  762. /* Uh oh, we thought we could patch something but it's
  763. * no longer there! */
  764. Status(L"Update failed: Source file %s not found", file.outputPath.c_str());
  765. return false;
  766. }
  767. /* We may be installing into new folders,
  768. * make sure they exist */
  769. filesystem::path filePath(file.outputPath.c_str());
  770. create_directories(filePath.parent_path());
  771. file.previousFile = L"";
  772. bool success = !!QuickWriteFile(file.outputPath.c_str(), patch_data.data(), patch_data.size());
  773. if (!success) {
  774. Status(L"Update failed: Couldn't install %s (error %d)", file.outputPath.c_str(),
  775. GetLastError());
  776. file.state = STATE_INSTALL_FAILED;
  777. return false;
  778. }
  779. file.state = STATE_INSTALLED;
  780. }
  781. return true;
  782. }
  783. queue<reference_wrapper<update_t>> updateQueue;
  784. static int lastPosition = 0;
  785. static int installed = 0;
  786. static bool updateThreadFailed = false;
  787. static bool UpdateWorker()
  788. {
  789. unique_lock<mutex> ulock(updateMutex, defer_lock);
  790. ZSTDDCtx zCtx;
  791. while (true) {
  792. ulock.lock();
  793. if (updateThreadFailed)
  794. return false;
  795. if (updateQueue.empty())
  796. break;
  797. auto update = updateQueue.front();
  798. updateQueue.pop();
  799. ulock.unlock();
  800. if (!UpdateFile(zCtx, update)) {
  801. updateThreadFailed = true;
  802. return false;
  803. } else {
  804. int position = (int)(((float)++installed / (float)completedUpdates) * 100.0f);
  805. if (position > lastPosition) {
  806. lastPosition = position;
  807. SendDlgItemMessage(hwndMain, IDC_PROGRESS, PBM_SETPOS, position, 0);
  808. }
  809. }
  810. }
  811. return true;
  812. }
  813. static bool RunUpdateWorkers(int num)
  814. try {
  815. for (update_t &update : updates)
  816. updateQueue.emplace(update);
  817. vector<future<bool>> thread_success_results;
  818. thread_success_results.resize(num);
  819. for (future<bool> &result : thread_success_results) {
  820. result = async(launch::async, UpdateWorker);
  821. }
  822. for (future<bool> &result : thread_success_results) {
  823. if (!result.get()) {
  824. return false;
  825. }
  826. }
  827. return true;
  828. } catch (...) {
  829. return false;
  830. }
  831. static bool UpdateVSRedists()
  832. {
  833. /* ------------------------------------------ *
  834. * Initialize session */
  835. const DWORD tlsProtocols = WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_2;
  836. const DWORD compressionFlags = WINHTTP_DECOMPRESSION_FLAG_ALL;
  837. HttpHandle hSession = WinHttpOpen(L"OBS Studio Updater/3.0", WINHTTP_ACCESS_TYPE_AUTOMATIC_PROXY,
  838. WINHTTP_NO_PROXY_NAME, WINHTTP_NO_PROXY_BYPASS, 0);
  839. if (!hSession) {
  840. Status(L"VC Redist Update failed: Couldn't create session");
  841. return false;
  842. }
  843. WinHttpSetOption(hSession, WINHTTP_OPTION_SECURE_PROTOCOLS, (LPVOID)&tlsProtocols, sizeof(tlsProtocols));
  844. WinHttpSetOption(hSession, WINHTTP_OPTION_DECOMPRESSION, (LPVOID)&compressionFlags, sizeof(compressionFlags));
  845. HttpHandle hConnect = WinHttpConnect(hSession, kMSHostname, INTERNET_DEFAULT_HTTPS_PORT, 0);
  846. if (!hConnect) {
  847. Status(L"Update failed: Couldn't connect to %S", kMSHostname);
  848. return false;
  849. }
  850. int responseCode;
  851. DWORD waitResult = WaitForSingleObject(cancelRequested, 0);
  852. if (waitResult == WAIT_OBJECT_0) {
  853. return false;
  854. }
  855. /* ------------------------------------------ *
  856. * Download redist */
  857. Status(L"Downloading Visual C++ Redistributable");
  858. wstring destPath;
  859. destPath += tempPath;
  860. destPath += L"\\VC_redist.x64.exe";
  861. if (!HTTPGetFile(hConnect, kVSRedistURL, destPath.c_str(), L"Accept-Encoding: gzip", &responseCode)) {
  862. DeleteFile(destPath.c_str());
  863. Status(L"Update failed: Could not download "
  864. L"%s (error code %d)",
  865. L"Visual C++ Redistributable", responseCode);
  866. return false;
  867. }
  868. /* ------------------------------------------ *
  869. * Verify file signature */
  870. GUID action = WINTRUST_ACTION_GENERIC_VERIFY_V2;
  871. WINTRUST_FILE_INFO fileInfo = {};
  872. fileInfo.cbStruct = sizeof(fileInfo);
  873. fileInfo.pcwszFilePath = destPath.c_str();
  874. WINTRUST_DATA data = {};
  875. data.cbStruct = sizeof(data);
  876. data.dwUIChoice = WTD_UI_NONE;
  877. data.dwUnionChoice = WTD_CHOICE_FILE;
  878. data.dwStateAction = WTD_STATEACTION_VERIFY;
  879. data.pFile = &fileInfo;
  880. LONG result = WinVerifyTrust(nullptr, &action, &data);
  881. if (result != ERROR_SUCCESS) {
  882. Status(L"Update failed: Signature verification failed for "
  883. L"%s (error code %d / %d)",
  884. L"Visual C++ Redistributable", result, GetLastError());
  885. DeleteFile(destPath.c_str());
  886. return false;
  887. }
  888. /* ------------------------------------------ *
  889. * If verification succeeded, install redist */
  890. wchar_t commandline[MAX_PATH + MAX_PATH];
  891. StringCbPrintf(commandline, sizeof(commandline), L"%s /install /quiet /norestart", destPath.c_str());
  892. PROCESS_INFORMATION pi = {};
  893. STARTUPINFO si = {};
  894. si.cb = sizeof(si);
  895. bool success = !!CreateProcessW(destPath.c_str(), commandline, nullptr, nullptr, false, CREATE_NO_WINDOW,
  896. nullptr, nullptr, &si, &pi);
  897. if (success) {
  898. Status(L"Installing %s...", L"Visual C++ Redistributable");
  899. CloseHandle(pi.hThread);
  900. WaitForSingleObject(pi.hProcess, INFINITE);
  901. CloseHandle(pi.hProcess);
  902. } else {
  903. Status(L"Update failed: Could not execute "
  904. L"%s (error code %d)",
  905. L"Visual C++ Redistributable", (int)GetLastError());
  906. }
  907. DeleteFile(destPath.c_str());
  908. waitResult = WaitForSingleObject(cancelRequested, 0);
  909. if (waitResult == WAIT_OBJECT_0) {
  910. return false;
  911. }
  912. return success;
  913. }
  914. static void UpdateRegistryVersion(const Manifest &manifest)
  915. {
  916. const char *regKey = R"(Software\Microsoft\Windows\CurrentVersion\Uninstall\OBS Studio)";
  917. LSTATUS res;
  918. HKEY key;
  919. char version[32];
  920. int formattedLen;
  921. /* The manifest does not store a version string, so we gotta make one ourselves. */
  922. if (manifest.beta || manifest.rc) {
  923. formattedLen = sprintf_s(version, sizeof(version), "%d.%d.%d-%s%d", manifest.version_major,
  924. manifest.version_minor, manifest.version_patch, manifest.beta ? "beta" : "rc",
  925. manifest.beta ? manifest.beta : manifest.rc);
  926. } else {
  927. formattedLen = sprintf_s(version, sizeof(version), "%d.%d.%d", manifest.version_major,
  928. manifest.version_minor, manifest.version_patch);
  929. }
  930. if (formattedLen <= 0)
  931. return;
  932. res = RegOpenKeyExA(HKEY_LOCAL_MACHINE, regKey, 0, KEY_WRITE | KEY_WOW64_32KEY, &key);
  933. if (res != ERROR_SUCCESS)
  934. return;
  935. RegSetValueExA(key, "DisplayVersion", 0, REG_SZ, (const BYTE *)version, formattedLen + 1);
  936. RegCloseKey(key);
  937. }
  938. static void ClearShaderCache()
  939. {
  940. wchar_t shader_path[MAX_PATH];
  941. SHGetFolderPathW(NULL, CSIDL_COMMON_APPDATA, NULL, SHGFP_TYPE_CURRENT, shader_path);
  942. StringCbCatW(shader_path, sizeof(shader_path), L"\\obs-studio\\shader-cache");
  943. filesystem::remove_all(shader_path);
  944. }
/* Updates the game capture hook files. Update() calls it in its "Update
 * hook files and vulkan registry" step. Declared with C linkage; the
 * definition presumably lives in another translation unit (not visible
 * here). */
extern "C" void UpdateHookFiles(void);
/*
 * Performs the whole update. It runs on the worker thread started by
 * UpdateThread, which rolls back partial updates when this returns false.
 *
 * Steps, in order:
 *   - wait for OBS to release its update mutex;
 *   - acquire a crypto provider;
 *   - parse the command line;
 *   - load the local manifest and hash local files;
 *   - queue package updates and removals;
 *   - update the VC++ redistributable if it is outdated;
 *   - request patches for files that have a local hash;
 *   - download and install;
 *   - register the virtual camera (non-portable only);
 *   - update game capture hooks and clear the shader cache;
 *   - write the registry version (non-portable only);
 *   - delete the backups of replaced files.
 *
 * cmdLine: the updater's raw command line; may be empty. Recognized:
 *   "Portable" (legacy)             portable install; stops parsing further
 *   "--portable"                    portable install
 *   "--branch=<name>"               update branch (default "stable")
 *   "--appdata=<path>"              AppData directory holding obs-studio\
 *                                   updates\manifest.json
 *   "--portable--branch=<name>"     malformed form from older OBS versions;
 *                                   means both of the above
 *
 * Returns true when everything was installed, or when nothing needed
 * updating. Returns false on failure or cancellation. Many (not all)
 * failure paths report a message through Status() first.
 */
static bool Update(wchar_t *cmdLine)
{
	/* ------------------------------------- *
	 * Check to make sure OBS isn't running */
	/* If the OBS update mutex exists, block until it can be acquired or
	 * cancelRequested is signaled. */
	HANDLE hObsUpdateMutex = OpenMutexW(SYNCHRONIZE, false, L"OBSStudioUpdateMutex");
	if (hObsUpdateMutex) {
		HANDLE hWait[2];
		hWait[0] = hObsUpdateMutex;
		hWait[1] = cancelRequested;
		int i = WaitForMultipleObjects(2, hWait, false, INFINITE);
		/* NOTE(review): WAIT_ABANDONED_0 also grants ownership of the
		 * mutex, but only WAIT_OBJECT_0 releases it here. */
		if (i == WAIT_OBJECT_0)
			ReleaseMutex(hObsUpdateMutex);
		CloseHandle(hObsUpdateMutex);
		if (i == WAIT_OBJECT_0 + 1)
			return false;
	}
	/* WaitForOBS() is defined elsewhere; a false result aborts the update. */
	if (!WaitForOBS())
		return false;
	/* ------------------------------------- *
	 * Init crypt stuff */
	CryptProvider hProvider;
	if (!CryptAcquireContext(&hProvider, nullptr, MS_ENH_RSA_AES_PROV, PROV_RSA_AES, CRYPT_VERIFYCONTEXT)) {
		SetDlgItemTextW(hwndMain, IDC_STATUS, L"Update failed: CryptAcquireContext failure");
		return false;
	}
	/* Publish the provider through the file-scope variable of the same
	 * name, presumably for the hashing/verification helpers. The local
	 * wrapper presumably releases it when Update() returns. TODO confirm
	 * CryptProvider's ownership semantics. */
	::hProvider = hProvider;
	/* ------------------------------------- */
	/* Show an indeterminate (marquee) progress bar while the manifest is
	 * processed; the saved style is restored further down. */
	SetDlgItemTextW(hwndMain, IDC_STATUS, L"Searching for available updates...");
	HWND hProgress = GetDlgItem(hwndMain, IDC_PROGRESS);
	LONG_PTR style = GetWindowLongPtr(hProgress, GWL_STYLE);
	SetWindowLongPtr(hProgress, GWL_STYLE, style | PBS_MARQUEE);
	SendDlgItemMessage(hwndMain, IDC_PROGRESS, PBM_SETMARQUEE, 1, 0);
	/* ------------------------------------- *
	 * Check if updating portable build */
	bool bIsPortable = false;
	wstring branch = L"stable";
	wstring appdata;
	if (cmdLine[0]) {
		int argc;
		LPWSTR *argv = CommandLineToArgvW(cmdLine, &argc);
		if (argv) {
			for (int i = 0; i < argc; i++) {
				if (wcscmp(argv[i], L"Portable") == 0) {
					// Legacy OBS (any later arguments are ignored)
					bIsPortable = true;
					break;
				} else if (wcsncmp(argv[i], L"--branch=", 9) == 0) {
					branch = argv[i] + 9;
				} else if (wcsncmp(argv[i], L"--appdata=", 10) == 0) {
					appdata = argv[i] + 10;
				} else if (wcscmp(argv[i], L"--portable") == 0) {
					bIsPortable = true;
				} else if (wcsncmp(argv[i], L"--portable--branch=", 19) == 0) {
					/* Versions pre-29.1 beta 2 produce broken parameters :( */
					bIsPortable = true;
					branch = argv[i] + 19;
				}
			}
			LocalFree((HLOCAL)argv);
		}
	}
	/* ------------------------------------- *
	 * Get config path */
	/* Portable: <obs_base_directory>\config. Otherwise the --appdata value
	 * if given, else the current user's Roaming AppData folder.
	 * "\obs-studio" is appended in every case. */
	wchar_t lpAppDataPath[MAX_PATH];
	lpAppDataPath[0] = 0;
	if (bIsPortable) {
		StringCbCopy(lpAppDataPath, sizeof(lpAppDataPath), obs_base_directory);
		StringCbCat(lpAppDataPath, sizeof(lpAppDataPath), L"\\config");
	} else {
		if (!appdata.empty()) {
			HRESULT hr = StringCbCopy(lpAppDataPath, sizeof(lpAppDataPath), appdata.c_str());
			if (hr != S_OK) {
				Status(L"Update failed: Could not determine AppData "
				       L"location");
				return false;
			}
		} else {
			CoTaskMemPtr<wchar_t> pOut;
			HRESULT hr = SHGetKnownFolderPath(FOLDERID_RoamingAppData, KF_FLAG_DEFAULT, nullptr, &pOut);
			if (hr != S_OK) {
				Status(L"Update failed: Could not determine AppData "
				       L"location");
				return false;
			}
			StringCbCopy(lpAppDataPath, sizeof(lpAppDataPath), pOut);
		}
	}
	StringCbCat(lpAppDataPath, sizeof(lpAppDataPath), L"\\obs-studio");
	/* ------------------------------------- *
	 * Get download path */
	/* The file-scope tempPath becomes a fresh working directory.
	 * GetTempFileNameW (uUnique == 0) creates a uniquely named empty file;
	 * it is deleted and a directory of the same name is created instead. */
	wchar_t manifestPath[MAX_PATH];
	wchar_t tempDirName[MAX_PATH];
	manifestPath[0] = 0;
	tempDirName[0] = 0;
	StringCbPrintf(manifestPath, sizeof(manifestPath), L"%s\\updates\\manifest.json", lpAppDataPath);
	if (!GetTempPathW(_countof(tempDirName), tempDirName)) {
		Status(L"Update failed: Failed to get temp path: %ld", GetLastError());
		return false;
	}
	if (!GetTempFileNameW(tempDirName, L"obs-studio", 0, tempPath)) {
		Status(L"Update failed: Failed to create temp dir name: %ld", GetLastError());
		return false;
	}
	/* NOTE(review): CreateDirectory's result is not checked. */
	DeleteFile(tempPath);
	CreateDirectory(tempPath, nullptr);
	/* ------------------------------------- *
	 * Load manifest file */
	Manifest manifest;
	{
		string manifestFile = QuickReadFile(manifestPath);
		if (manifestFile.empty()) {
			Status(L"Update failed: Couldn't load manifest file");
			return false;
		}
		try {
			json manifestContents = json::parse(manifestFile);
			manifest = manifestContents.get<Manifest>();
		} catch (json::exception &e) {
			Status(L"Update failed: Couldn't parse update "
			       L"manifest: %S",
			       e.what());
			return false;
		}
	}
	/* ------------------------------------- *
	 * Hash local files listed in manifest */
	RunHasherWorkers(4, manifest.packages);
	/* ------------------------------------- *
	 * Parse current manifest update files */
	for (const Package &package : manifest.packages) {
		if (!AddPackageUpdateFiles(package, branch.c_str())) {
			Status(L"Update failed: Failed to process update packages");
			return false;
		}
		/* Add removed files to deletion queue (if any) */
		AddPackageRemovedFiles(package);
	}
	/* Leave marquee mode and restore the saved progress bar style. */
	SendDlgItemMessage(hwndMain, IDC_PROGRESS, PBM_SETMARQUEE, 0, 0);
	SetWindowLongPtr(hProgress, GWL_STYLE, style);
	/* ------------------------------------- *
	 * Exit if updates already installed */
	if (updates.empty()) {
		Status(L"All available updates are already installed.");
		SetDlgItemText(hwndMain, IDC_BUTTON, L"Launch OBS");
		return true;
	}
	/* ------------------------------------- *
	 * Check VS redistributables version */
	if (IsVSRedistOutdated()) {
		if (!UpdateVSRedists()) {
			return false;
		}
	}
	/* ------------------------------------- *
	 * Generate file hash json */
	/* One ("<packageName>/<outputPath as UTF-8>", local hash) entry per
	 * queued update that has a local hash; entries whose path fails UTF-8
	 * conversion are skipped. */
	PatchesRequest files;
	for (update_t &update : updates) {
		if (!update.has_hash)
			continue;
		char outputPath[MAX_PATH];
		if (!WideToUTF8Buf(outputPath, update.outputPath.c_str()))
			continue;
		string hash_string;
		HashToString(update.my_hash, hash_string);
		string package_path;
		package_path = update.packageName;
		package_path += "/";
		package_path += outputPath;
		files.push_back({package_path, hash_string});
	}
	/* ------------------------------------- *
	 * Send file hashes */
	/* The JSON request is zstd-compressed and POSTed to the patch manifest
	 * endpoint ("?branch=" appended for non-stable branches). With no
	 * hashed files the request is skipped and an empty patch list is used. */
	string newManifest;
	if (!files.empty()) {
		json request = files;
		string post_body = request.dump();
		int len = (int)post_body.size();
		size_t compressSize = ZSTD_compressBound(len);
		string compressedJson;
		compressedJson.resize(compressSize);
		size_t result = ZSTD_compress(compressedJson.data(), compressedJson.size(), post_body.data(),
					      post_body.size(), ZSTD_CLEVEL_DEFAULT);
		/* NOTE(review): this failure path sets no status message. */
		if (ZSTD_isError(result))
			return false;
		compressedJson.resize(result);
		wstring manifestUrl(kPatchManifestURL);
		if (branch != L"stable")
			manifestUrl += L"?branch=" + branch;
		int responseCode;
		bool success = !!HTTPPostData(manifestUrl.c_str(), (BYTE *)compressedJson.data(),
					      (int)compressedJson.size(), L"Accept-Encoding: gzip", &responseCode,
					      newManifest);
		if (!success)
			return false;
		if (responseCode != 200) {
			Status(L"Update failed: HTTP/%d while trying to "
			       L"download patch manifest",
			       responseCode);
			return false;
		}
	} else {
		newManifest = "[]";
	}
	/* ------------------------------------- *
	 * Parse new manifest */
	PatchesResponse patches;
	try {
		json patchManifest = json::parse(newManifest);
		patches = patchManifest.get<PatchesResponse>();
	} catch (json::exception &e) {
		Status(L"Update failed: Couldn't parse patch manifest: %S", e.what());
		return false;
	}
	/* Update updates with patch information. */
	for (const PatchResponse &patch : patches) {
		UpdateWithPatchIfAvailable(patch);
	}
	/* ------------------------------------- *
	 * Deduplicate Downloads */
	/* Only the first update with a given downloadHash is downloaded.
	 * Later ones are marked STATE_ALREADY_DOWNLOADED, subtracted from
	 * totalFileSize and counted as completed up front. */
	unordered_set<B2Hash> downloadHashes;
	for (update_t &update : updates) {
		if (downloadHashes.count(update.downloadHash)) {
			update.state = STATE_ALREADY_DOWNLOADED;
			totalFileSize -= update.fileSize;
			completedUpdates++;
		} else {
			downloadHashes.insert(update.downloadHash);
		}
	}
	/* ------------------------------------- *
	 * Download Updates */
	Status(L"Downloading updates...");
	if (!RunDownloadWorkers(4))
		return false;
	if ((size_t)completedUpdates != updates.size()) {
		Status(L"Update failed to download all files.");
		return false;
	}
	/* ------------------------------------- *
	 * Install updates */
	SendDlgItemMessage(hwndMain, IDC_PROGRESS, PBM_SETPOS, 0, 0);
	Status(L"Installing updates...");
	if (!RunUpdateWorkers(4))
		return false;
	for (deletion_t &deletion : deletions) {
		if (!RenameRemovedFile(deletion)) {
			Status(L"Update failed: Couldn't remove "
			       L"obsolete files");
			return false;
		}
	}
	/* ------------------------------------- *
	 * Install virtual camera */
	/* Runs cmd with a hidden window and waits for it to exit. Launch
	 * failures are ignored and the exit code is not checked. */
	auto runcommand = [](wchar_t *cmd) {
		STARTUPINFO si = {};
		si.cb = sizeof(si);
		si.dwFlags = STARTF_USESHOWWINDOW;
		si.wShowWindow = SW_HIDE;
		PROCESS_INFORMATION pi;
		bool success = !!CreateProcessW(nullptr, cmd, nullptr, nullptr, false, CREATE_NEW_CONSOLE, nullptr,
						nullptr, &si, &pi);
		if (success) {
			WaitForSingleObject(pi.hProcess, INFINITE);
			CloseHandle(pi.hThread);
			CloseHandle(pi.hProcess);
		}
	};
	/* Silently registers obs-virtualcam-module32.dll and
	 * obs-virtualcam-module64.dll from data\obs-plugins\win-dshow using
	 * regsvr32.exe from the system directory. */
	if (!bIsPortable) {
		Status(L"Installing Virtual Camera...");
		wchar_t regsvr[MAX_PATH];
		wchar_t src[MAX_PATH];
		wchar_t tmp[MAX_PATH];
		wchar_t tmp2[MAX_PATH];
		SHGetFolderPathW(nullptr, CSIDL_SYSTEM, nullptr, SHGFP_TYPE_CURRENT, regsvr);
		StringCbCat(regsvr, sizeof(regsvr), L"\\regsvr32.exe");
		StringCbCopy(src, sizeof(src), obs_base_directory);
		StringCbCat(src, sizeof(src), L"\\data\\obs-plugins\\win-dshow\\");
		StringCbCopy(tmp, sizeof(tmp), L"\"");
		StringCbCat(tmp, sizeof(tmp), regsvr);
		StringCbCat(tmp, sizeof(tmp), L"\" /s \"");
		StringCbCat(tmp, sizeof(tmp), src);
		StringCbCat(tmp, sizeof(tmp), L"obs-virtualcam-module");
		StringCbCopy(tmp2, sizeof(tmp2), tmp);
		StringCbCat(tmp2, sizeof(tmp2), L"32.dll\"");
		runcommand(tmp2);
		StringCbCopy(tmp2, sizeof(tmp2), tmp);
		StringCbCat(tmp2, sizeof(tmp2), L"64.dll\"");
		runcommand(tmp2);
	}
	/* ------------------------------------- *
	 * Update hook files and vulkan registry */
	Status(L"Updating Game Capture hooks...");
	UpdateHookFiles();
	/* ------------------------------------- *
	 * Clear shader cache */
	Status(L"Clearing shader cache...");
	ClearShaderCache();
	/* ------------------------------------- *
	 * Update installed version in registry */
	if (!bIsPortable) {
		Status(L"Updating version information...");
		UpdateRegistryVersion(manifest);
	}
	/* ------------------------------------- *
	 * Finish */
	Status(L"Cleaning up...");
	/* If we get here, all updates installed successfully so we can purge
	 * the old versions */
	for (update_t &update : updates) {
		if (!update.previousFile.empty())
			DeleteFile(update.previousFile.c_str());
	}
	/* Delete all removed files mentioned in the manifest */
	for (deletion_t &deletion : deletions)
		MyDeleteFile(deletion.deleteMeFilename);
	SendDlgItemMessage(hwndMain, IDC_PROGRESS, PBM_SETPOS, 100, 0);
	Status(L"Update complete.");
	SetDlgItemText(hwndMain, IDC_BUTTON, L"Launch OBS");
	return true;
}
  1266. static DWORD WINAPI UpdateThread(void *arg)
  1267. {
  1268. wchar_t *cmdLine = (wchar_t *)arg;
  1269. bool success = Update(cmdLine);
  1270. if (!success) {
  1271. /* This handles deleting temp files and rolling back and
  1272. * partially installed updates */
  1273. CleanupPartialUpdates();
  1274. if (tempPath[0])
  1275. RemoveDirectory(tempPath);
  1276. if (WaitForSingleObject(cancelRequested, 0) == WAIT_OBJECT_0)
  1277. Status(L"Update aborted.");
  1278. HWND hProgress = GetDlgItem(hwndMain, IDC_PROGRESS);
  1279. LONG_PTR style = GetWindowLongPtr(hProgress, GWL_STYLE);
  1280. SetWindowLongPtr(hProgress, GWL_STYLE, style & ~PBS_MARQUEE);
  1281. SendMessage(hProgress, PBM_SETSTATE, PBST_ERROR, 0);
  1282. SetDlgItemText(hwndMain, IDC_BUTTON, L"Exit");
  1283. EnableWindow(GetDlgItem(hwndMain, IDC_BUTTON), true);
  1284. updateFailed = true;
  1285. } else {
  1286. if (tempPath[0])
  1287. RemoveDirectory(tempPath);
  1288. }
  1289. if (bExiting)
  1290. ExitProcess(success);
  1291. return 0;
  1292. }
  1293. static void CancelUpdate(bool quit)
  1294. {
  1295. if (WaitForSingleObject(updateThread, 0) != WAIT_OBJECT_0) {
  1296. bExiting = quit;
  1297. SetEvent(cancelRequested);
  1298. } else {
  1299. PostQuitMessage(0);
  1300. }
  1301. }
  1302. static void LaunchOBS(LPWSTR lpCmdLine)
  1303. {
  1304. wchar_t newCwd[MAX_PATH];
  1305. wchar_t obsPath[MAX_PATH];
  1306. StringCbCopy(obsPath, sizeof(obsPath), obs_base_directory);
  1307. StringCbCat(obsPath, sizeof(obsPath), L"\\bin\\64bit");
  1308. SetCurrentDirectory(obsPath);
  1309. StringCbCopy(newCwd, sizeof(newCwd), obsPath);
  1310. StringCbCat(obsPath, sizeof(obsPath), L"\\obs64.exe");
  1311. if (!FileExists(obsPath)) {
  1312. /* TODO: give user a message maybe? */
  1313. return;
  1314. }
  1315. SHELLEXECUTEINFO execInfo;
  1316. ZeroMemory(&execInfo, sizeof(execInfo));
  1317. execInfo.cbSize = sizeof(execInfo);
  1318. execInfo.lpFile = obsPath;
  1319. execInfo.lpDirectory = newCwd;
  1320. execInfo.nShow = SW_SHOWNORMAL;
  1321. if (lpCmdLine[0])
  1322. execInfo.lpParameters = lpCmdLine;
  1323. ShellExecuteEx(&execInfo);
  1324. }
  1325. static INT_PTR CALLBACK UpdateDialogProc(HWND hwnd, UINT message, WPARAM wParam, LPARAM lParam)
  1326. {
  1327. switch (message) {
  1328. case WM_INITDIALOG: {
  1329. static HICON hMainIcon = LoadIcon(hinstMain, MAKEINTRESOURCE(IDI_ICON1));
  1330. SendMessage(hwnd, WM_SETICON, ICON_BIG, (LPARAM)hMainIcon);
  1331. SendMessage(hwnd, WM_SETICON, ICON_SMALL, (LPARAM)hMainIcon);
  1332. return true;
  1333. }
  1334. case WM_COMMAND:
  1335. if (LOWORD(wParam) == IDC_BUTTON) {
  1336. if (HIWORD(wParam) == BN_CLICKED) {
  1337. DWORD result = WaitForSingleObject(updateThread, 0);
  1338. if (result == WAIT_OBJECT_0) {
  1339. if (updateFailed)
  1340. PostQuitMessage(0);
  1341. else
  1342. PostQuitMessage(1);
  1343. } else {
  1344. EnableWindow((HWND)lParam, false);
  1345. CancelUpdate(false);
  1346. }
  1347. }
  1348. }
  1349. return true;
  1350. case WM_CLOSE:
  1351. CancelUpdate(true);
  1352. return true;
  1353. }
  1354. return false;
  1355. }
  1356. static int RestartAsAdmin(LPCWSTR lpCmdLine, LPCWSTR cwd)
  1357. {
  1358. wchar_t myPath[MAX_PATH];
  1359. if (!GetModuleFileNameW(nullptr, myPath, _countof(myPath) - 1)) {
  1360. return 0;
  1361. }
  1362. /* If the admin is a different user, add the path to the user's
  1363. * AppData to the command line so we can load the correct manifest. */
  1364. wstring elevatedCmdLine(lpCmdLine);
  1365. CoTaskMemPtr<wchar_t> pOut;
  1366. HRESULT hr = SHGetKnownFolderPath(FOLDERID_RoamingAppData, KF_FLAG_DEFAULT, nullptr, &pOut);
  1367. if (hr == S_OK) {
  1368. elevatedCmdLine += L" \"--appdata=";
  1369. elevatedCmdLine += pOut;
  1370. elevatedCmdLine += L"\"";
  1371. }
  1372. SHELLEXECUTEINFO shExInfo = {0};
  1373. shExInfo.cbSize = sizeof(shExInfo);
  1374. shExInfo.fMask = SEE_MASK_NOCLOSEPROCESS;
  1375. shExInfo.hwnd = nullptr;
  1376. shExInfo.lpVerb = L"runas"; /* Operation to perform */
  1377. shExInfo.lpFile = myPath; /* Application to start */
  1378. shExInfo.lpParameters = elevatedCmdLine.c_str(); /* Additional parameters */
  1379. shExInfo.lpDirectory = cwd;
  1380. shExInfo.nShow = SW_NORMAL;
  1381. shExInfo.hInstApp = nullptr;
  1382. /* annoyingly the actual elevated updater will disappear behind other
  1383. * windows :( */
  1384. AllowSetForegroundWindow(ASFW_ANY);
  1385. if (ShellExecuteEx(&shExInfo)) {
  1386. DWORD exitCode;
  1387. WaitForSingleObject(shExInfo.hProcess, INFINITE);
  1388. if (GetExitCodeProcess(shExInfo.hProcess, &exitCode)) {
  1389. if (exitCode == 1) {
  1390. return exitCode;
  1391. }
  1392. }
  1393. CloseHandle(shExInfo.hProcess);
  1394. }
  1395. return 0;
  1396. }
  1397. static bool HasElevation()
  1398. {
  1399. SID_IDENTIFIER_AUTHORITY sia = SECURITY_NT_AUTHORITY;
  1400. PSID sid = nullptr;
  1401. BOOL elevated = false;
  1402. BOOL success;
  1403. success = AllocateAndInitializeSid(&sia, 2, SECURITY_BUILTIN_DOMAIN_RID, DOMAIN_ALIAS_RID_ADMINS, 0, 0, 0, 0, 0,
  1404. 0, &sid);
  1405. if (success && sid) {
  1406. CheckTokenMembership(nullptr, sid, &elevated);
  1407. FreeSid(sid);
  1408. }
  1409. return elevated;
  1410. }
  1411. int WINAPI wWinMain(HINSTANCE hInstance, HINSTANCE, LPWSTR lpCmdLine, int)
  1412. {
  1413. INITCOMMONCONTROLSEX icce;
  1414. wchar_t cwd[MAX_PATH];
  1415. GetCurrentDirectoryW(_countof(cwd) - 1, cwd);
  1416. if (!IsWindows10OrGreater()) {
  1417. MessageBox(nullptr,
  1418. L"OBS Studio 28 and newer no longer support Windows 7,"
  1419. L" Windows 8, or Windows 8.1. You can disable the"
  1420. L" following setting to opt out of future updates:"
  1421. L" Settings → General → General → Automatically check"
  1422. L" for updates on startup",
  1423. L"Unsupported Operating System", MB_ICONWARNING);
  1424. return 0;
  1425. }
  1426. if (!HasElevation()) {
  1427. WinHandle hMutex = OpenMutex(SYNCHRONIZE, false, L"OBSUpdaterRunningAsNonAdminUser");
  1428. if (hMutex) {
  1429. MessageBox(nullptr, L"OBS Studio Updater must be run as an administrator.", L"Updater Error",
  1430. MB_ICONWARNING);
  1431. return 2;
  1432. }
  1433. HANDLE hLowMutex = CreateMutexW(nullptr, true, L"OBSUpdaterRunningAsNonAdminUser");
  1434. /* return code 1 = user wanted to launch OBS */
  1435. if (RestartAsAdmin(lpCmdLine, cwd) == 1) {
  1436. StringCbCat(cwd, sizeof(cwd), L"\\..\\..");
  1437. GetFullPathName(cwd, _countof(obs_base_directory), obs_base_directory, nullptr);
  1438. SetCurrentDirectory(obs_base_directory);
  1439. LaunchOBS(lpCmdLine);
  1440. }
  1441. if (hLowMutex) {
  1442. ReleaseMutex(hLowMutex);
  1443. CloseHandle(hLowMutex);
  1444. }
  1445. return 0;
  1446. } else {
  1447. StringCbCat(cwd, sizeof(cwd), L"\\..\\..");
  1448. GetFullPathName(cwd, _countof(obs_base_directory), obs_base_directory, nullptr);
  1449. SetCurrentDirectory(obs_base_directory);
  1450. hinstMain = hInstance;
  1451. icce.dwSize = sizeof(icce);
  1452. icce.dwICC = ICC_PROGRESS_CLASS;
  1453. InitCommonControlsEx(&icce);
  1454. hwndMain = CreateDialog(hInstance, MAKEINTRESOURCE(IDD_UPDATEDIALOG), nullptr, UpdateDialogProc);
  1455. if (!hwndMain) {
  1456. return -1;
  1457. }
  1458. ShowWindow(hwndMain, SW_SHOWNORMAL);
  1459. SetForegroundWindow(hwndMain);
  1460. cancelRequested = CreateEvent(nullptr, true, false, nullptr);
  1461. updateThread = CreateThread(nullptr, 0, UpdateThread, lpCmdLine, 0, nullptr);
  1462. MSG msg;
  1463. while (GetMessage(&msg, nullptr, 0, 0)) {
  1464. if (!IsDialogMessage(hwndMain, &msg)) {
  1465. TranslateMessage(&msg);
  1466. DispatchMessage(&msg);
  1467. }
  1468. }
  1469. /* there is no non-elevated process waiting for us if UAC is
  1470. * disabled */
  1471. WinHandle hMutex = OpenMutex(SYNCHRONIZE, false, L"OBSUpdaterRunningAsNonAdminUser");
  1472. if (msg.wParam == 1 && !hMutex) {
  1473. LaunchOBS(lpCmdLine);
  1474. }
  1475. return (int)msg.wParam;
  1476. }
  1477. }