Browse Source

UI: Add Auth and OAuth classes

Allows the ability to authenticate to a specific service.  Typically via
OAuth.
jp9000 6 years ago
parent
commit
08fb29a035

+ 11 - 0
UI/CMakeLists.txt

@@ -105,6 +105,15 @@ elseif(UNIX)
                 Qt5::X11Extras)
 endif()
 
+if(BROWSER_AVAILABLE_INTERNAL)
+	list(APPEND obs_PLATFORM_SOURCES
+		auth-oauth.cpp
+		)
+	list(APPEND obs_PLATFORM_HEADERS
+		auth-oauth.hpp
+		)
+endif()
+
 set(obs_libffutil_SOURCES
 	../deps/libff/libff/ff-util.c
 	)
@@ -150,6 +159,7 @@ set(obs_SOURCES
 	window-log-reply.cpp
 	window-projector.cpp
 	window-remux.cpp
+	auth-base.cpp
 	source-tree.cpp
 	properties-view.cpp
 	focus-list.cpp
@@ -198,6 +208,7 @@ set(obs_HEADERS
 	window-log-reply.hpp
 	window-projector.hpp
 	window-remux.hpp
+	auth-base.hpp
 	source-tree.hpp
 	properties-view.hpp
 	properties-view.moc.hpp

+ 71 - 0
UI/auth-base.cpp

@@ -0,0 +1,71 @@
+#include "auth-base.hpp"
+#include "window-basic-main.hpp"
+
+#include <vector>
+#include <map>
+
+struct AuthInfo {
+	Auth::Def def;
+	Auth::create_cb create;
+};
+
+static std::vector<AuthInfo> authDefs;
+
+void Auth::RegisterAuth(const Def &d, create_cb create)
+{
+	AuthInfo info = {d, create};
+	authDefs.push_back(info);
+}
+
+std::shared_ptr<Auth> Auth::Create(const std::string &service)
+{
+	for (auto &a : authDefs) {
+		if (service.find(a.def.service) != std::string::npos) {
+			return a.create();
+		}
+	}
+
+	return nullptr;
+}
+
+Auth::Type Auth::AuthType(const std::string &service)
+{
+	for (auto &a : authDefs) {
+		if (service.find(a.def.service) != std::string::npos) {
+			return a.def.type;
+		}
+	}
+
+	return Type::None;
+}
+
+void Auth::Load()
+{
+	OBSBasic *main = OBSBasic::Get();
+	const char *typeStr = config_get_string(main->Config(), "Auth", "Type");
+	if (!typeStr) typeStr = "";
+
+	main->auth = Create(typeStr);
+	if (main->auth) {
+		if (main->auth->LoadInternal()) {
+			main->auth->LoadUI();
+		}
+	}
+}
+
+void Auth::Save()
+{
+	OBSBasic *main = OBSBasic::Get();
+	Auth *auth = main->auth.get();
+	if (!auth) {
+		if (config_has_user_value(main->Config(), "Auth", "Type")) {
+			config_remove_value(main->Config(), "Auth", "Type");
+			config_save_safe(main->Config(), "tmp", nullptr);
+		}
+		return;
+	}
+
+	config_set_string(main->Config(), "Auth", "Type", auth->service());
+	auth->SaveInternal();
+	config_save_safe(main->Config(), "tmp", nullptr);
+}

+ 58 - 0
UI/auth-base.hpp

@@ -0,0 +1,58 @@
+#pragma once
+
+#include <QObject>
+#include <functional>
+#include <memory>
+
+class Auth : public QObject {
+	Q_OBJECT
+
+protected:
+	virtual void SaveInternal()=0;
+	virtual bool LoadInternal()=0;
+
+	bool firstLoad = true;
+
+	struct ErrorInfo {
+		std::string message;
+		std::string error;
+
+		ErrorInfo(std::string message_, std::string error_)
+			: message(message_), error(error_)
+		{}
+	};
+
+public:
+	enum class Type {
+		None,
+		OAuth_StreamKey
+	};
+
+	struct Def {
+		std::string service;
+		Type type;
+	};
+
+	typedef std::function<std::shared_ptr<Auth> ()> create_cb;
+
+	inline Auth(const Def &d) : def(d) {}
+	virtual ~Auth() {}
+
+	inline Type type() const {return def.type;}
+	inline const char *service() const {return def.service.c_str();}
+
+	virtual void LoadUI() {}
+
+	virtual void OnStreamConfig() {}
+
+	static std::shared_ptr<Auth> Create(const std::string &service);
+	static Type AuthType(const std::string &service);
+	static void Load();
+	static void Save();
+
+protected:
+	static void RegisterAuth(const Def &d, create_cb create);
+
+private:
+	Def def;
+};

+ 283 - 0
UI/auth-oauth.cpp

@@ -0,0 +1,283 @@
+#include "auth-oauth.hpp"
+
+#include <QPushButton>
+#include <QHBoxLayout>
+#include <QVBoxLayout>
+
+#include <qt-wrappers.hpp>
+#include <obs-app.hpp>
+
+#include "window-basic-main.hpp"
+#include "remote-text.hpp"
+
+#include <unordered_map>
+
+#include <json11.hpp>
+
+using namespace json11;
+
+#include <browser-panel.hpp>
+extern QCef *cef;
+extern QCefCookieManager *panel_cookies;
+
+/* ------------------------------------------------------------------------- */
+
+OAuthLogin::OAuthLogin(QWidget *parent, const std::string &url, bool token)
+	: QDialog   (parent),
+	  get_token (token)
+{
+	setWindowTitle("Auth");
+	resize(700, 700);
+
+	OBSBasic::InitBrowserPanelSafeBlock(true);
+
+	cefWidget = cef->create_widget(nullptr, url, panel_cookies);
+	if (!cefWidget) {
+		fail = true;
+		return;
+	}
+
+	connect(cefWidget, SIGNAL(titleChanged(const QString &)),
+			this, SLOT(setWindowTitle(const QString &)));
+	connect(cefWidget, SIGNAL(urlChanged(const QString &)),
+			this, SLOT(urlChanged(const QString &)));
+
+	QPushButton *close = new QPushButton(QTStr("Cancel"));
+	connect(close, &QAbstractButton::clicked,
+			this, &QDialog::reject);
+
+	QHBoxLayout *bottomLayout = new QHBoxLayout();
+	bottomLayout->addStretch();
+	bottomLayout->addWidget(close);
+	bottomLayout->addStretch();
+
+	QVBoxLayout *topLayout = new QVBoxLayout(this);
+	topLayout->addWidget(cefWidget);
+	topLayout->addLayout(bottomLayout);
+}
+
+OAuthLogin::~OAuthLogin()
+{
+	delete cefWidget;
+}
+
+void OAuthLogin::urlChanged(const QString &url)
+{
+	std::string uri = get_token ? "access_token=" : "code=";
+	int code_idx = url.indexOf(uri.c_str());
+	if (code_idx == -1)
+		return;
+
+	if (url.left(22) != "https://obsproject.com")
+		return;
+
+	code_idx += (int)uri.size();
+
+	int next_idx = url.indexOf("&", code_idx);
+	if (next_idx != -1)
+		code = url.mid(code_idx, next_idx - code_idx);
+	else
+		code = url.right(url.size() - code_idx);
+
+	accept();
+}
+
+/* ------------------------------------------------------------------------- */
+
+struct OAuthInfo {
+	Auth::Def def;
+	OAuth::login_cb login;
+	OAuth::delete_cookies_cb delete_cookies;
+};
+
+static std::vector<OAuthInfo> loginCBs;
+
+void OAuth::RegisterOAuth(const Def &d, create_cb create, login_cb login,
+		delete_cookies_cb delete_cookies)
+{
+	OAuthInfo info = {d, login, delete_cookies};
+	loginCBs.push_back(info);
+	RegisterAuth(d, create);
+}
+
+std::shared_ptr<Auth> OAuth::Login(QWidget *parent, const std::string &service)
+{
+	for (auto &a : loginCBs) {
+		if (service.find(a.def.service) != std::string::npos) {
+			return a.login(parent);
+		}
+	}
+
+	return nullptr;
+}
+
+void OAuth::DeleteCookies(const std::string &service)
+{
+	for (auto &a : loginCBs) {
+		if (service.find(a.def.service) != std::string::npos) {
+			a.delete_cookies();
+		}
+	}
+}
+
+void OAuth::SaveInternal()
+{
+	OBSBasic *main = OBSBasic::Get();
+	config_set_string(main->Config(), service(), "RefreshToken",
+			refresh_token.c_str());
+	config_set_string(main->Config(), service(), "Token", token.c_str());
+	config_set_uint(main->Config(), service(), "ExpireTime", expire_time);
+	config_set_int(main->Config(), service(), "ScopeVer", currentScopeVer);
+}
+
+static inline std::string get_config_str(
+		OBSBasic *main,
+		const char *section,
+		const char *name)
+{
+	const char *val = config_get_string(main->Config(), section, name);
+	return val ? val : "";
+}
+
+bool OAuth::LoadInternal()
+{
+	OBSBasic *main = OBSBasic::Get();
+	refresh_token = get_config_str(main, service(), "RefreshToken");
+	token = get_config_str(main, service(), "Token");
+	expire_time = config_get_uint(main->Config(), service(), "ExpireTime");
+	currentScopeVer = (int)config_get_int(main->Config(), service(),
+			"ScopeVer");
+	return implicit
+		? !token.empty()
+		: !refresh_token.empty();
+}
+
+bool OAuth::TokenExpired()
+{
+	if (token.empty())
+		return true;
+	if ((uint64_t)time(nullptr) > expire_time - 5)
+		return true;
+	return false;
+}
+
+bool OAuth::GetToken(const char *url, const std::string &client_id,
+		int scope_ver, const std::string &auth_code, bool retry)
+try {
+	std::string output;
+	std::string error;
+	std::string desc;
+
+	if (currentScopeVer > 0 && currentScopeVer < scope_ver) {
+		if (RetryLogin()) {
+			return true;
+		} else {
+			QString title = QTStr("Auth.InvalidScope.Title");
+			QString text = QTStr("Auth.InvalidScope.Text")
+				.arg(service());
+
+			QMessageBox::warning(OBSBasic::Get(), title, text);
+		}
+	}
+
+	if (auth_code.empty() && !TokenExpired()) {
+		return true;
+	}
+
+	std::string post_data;
+	post_data += "action=redirect&client_id=";
+	post_data += client_id;
+
+	if (!auth_code.empty()) {
+		post_data += "&grant_type=authorization_code&code=";
+		post_data += auth_code;
+	} else {
+		post_data += "&grant_type=refresh_token&refresh_token=";
+		post_data += refresh_token;
+	}
+
+	bool success = false;
+
+	auto func = [&] () {
+		success = GetRemoteFile(
+				url,
+				output,
+				error,
+				nullptr,
+				"application/x-www-form-urlencoded",
+				post_data.c_str(),
+				std::vector<std::string>(),
+				nullptr,
+				5);
+	};
+
+	ExecuteFuncSafeBlockMsgBox(
+			func,
+			QTStr("Auth.Authing.Title"),
+			QTStr("Auth.Authing.Text").arg(service()));
+	if (!success || output.empty())
+		throw ErrorInfo("Failed to get token from remote", error);
+
+	Json json = Json::parse(output, error);
+	if (!error.empty())
+		throw ErrorInfo("Failed to parse json", error);
+
+	/* -------------------------- */
+	/* error handling             */
+
+	error = json["error"].string_value();
+	if (!retry && error == "invalid_grant") {
+		if (RetryLogin()) {
+			return true;
+		}
+	}
+	if (!error.empty())
+		throw ErrorInfo(error, json["error_description"].string_value());
+
+	/* -------------------------- */
+	/* success!                   */
+
+	expire_time = (uint64_t)time(nullptr) + json["expires_in"].int_value();
+	token       = json["access_token"].string_value();
+	if (token.empty())
+		throw ErrorInfo("Failed to get token from remote", error);
+
+	if (!auth_code.empty()) {
+		refresh_token = json["refresh_token"].string_value();
+		if (refresh_token.empty())
+			throw ErrorInfo("Failed to get refresh token from "
+					"remote", error);
+
+		currentScopeVer = scope_ver;
+	}
+
+	return true;
+
+} catch (ErrorInfo info) {
+	if (!retry) {
+		QString title = QTStr("Auth.AuthFailure.Title");
+		QString text = QTStr("Auth.AuthFailure.Text")
+			.arg(service(), info.message.c_str(), info.error.c_str());
+
+		QMessageBox::warning(OBSBasic::Get(), title, text);
+	}
+
+	blog(LOG_WARNING, "%s: %s: %s",
+			__FUNCTION__,
+			info.message.c_str(),
+			info.error.c_str());
+	return false;
+}
+
+void OAuthStreamKey::OnStreamConfig()
+{
+	OBSBasic *main = OBSBasic::Get();
+	obs_service_t *service = main->GetService();
+
+	obs_data_t *settings = obs_service_get_settings(service);
+
+	obs_data_set_string(settings, "key", key_.c_str());
+	obs_service_update(service, settings);
+
+	obs_data_release(settings);
+}

+ 76 - 0
UI/auth-oauth.hpp

@@ -0,0 +1,76 @@
+#pragma once
+
+#include <QDialog>
+#include <string>
+#include <memory>
+
+#include "auth-base.hpp"
+
+class QCefWidget;
+
+class OAuthLogin : public QDialog {
+	Q_OBJECT
+
+	QCefWidget *cefWidget = nullptr;
+	QString code;
+	bool get_token = false;
+	bool fail = false;
+
+public:
+	OAuthLogin(QWidget *parent, const std::string &url, bool token);
+	~OAuthLogin();
+
+	inline QString GetCode() const {return code;}
+	inline bool LoadFail() const {return fail;}
+
+public slots:
+	void urlChanged(const QString &url);
+};
+
+class OAuth : public Auth {
+	Q_OBJECT
+
+public:
+	inline OAuth(const Def &d) : Auth(d) {}
+
+	typedef std::function<std::shared_ptr<Auth> (QWidget *)> login_cb;
+	typedef std::function<void()> delete_cookies_cb;
+
+	static std::shared_ptr<Auth> Login(QWidget *parent,
+			const std::string &service);
+	static void DeleteCookies(const std::string &service);
+
+	static void RegisterOAuth(const Def &d, create_cb create,
+			login_cb login, delete_cookies_cb delete_cookies);
+
+protected:
+	std::string refresh_token;
+	std::string token;
+	bool implicit = false;
+	uint64_t expire_time = 0;
+	int currentScopeVer = 0;
+
+	virtual void SaveInternal() override;
+	virtual bool LoadInternal() override;
+
+	virtual bool RetryLogin()=0;
+	bool TokenExpired();
+	bool GetToken(const char *url, const std::string &client_id,
+			int scope_ver,
+			const std::string &auth_code = std::string(),
+			bool retry = false);
+};
+
+class OAuthStreamKey : public OAuth {
+	Q_OBJECT
+
+protected:
+	std::string key_;
+
+public:
+	inline OAuthStreamKey(const Def &d) : OAuth(d) {}
+
+	inline const std::string &key() const {return key_;}
+
+	virtual void OnStreamConfig() override;
+};

+ 12 - 0
UI/data/locale/en-US.ini

@@ -91,6 +91,18 @@ AlreadyRunning.Title="OBS is already running"
 AlreadyRunning.Text="OBS is already running!  Unless you meant to do this, please shut down any existing instances of OBS before trying to run a new instance.  If you have OBS set to minimize to the system tray, please check to see if it's still running there."
 AlreadyRunning.LaunchAnyway="Launch Anyway"
 
+# Auth
+Auth.Authing.Title="Authenticating.."
+Auth.Authing.Text="Authenticating with %1, please wait.."
+Auth.AuthFailure.Title="Authentication Failure"
+Auth.AuthFailure.Text="Failed to authenticate with %1:\n\n%2: %3"
+Auth.InvalidScope.Title="Authentication Required"
+Auth.InvalidScope.Text="The authentication requirements for %1 have changed.  Some features may not be available."
+Auth.LoadingChannel.Title="Loading channel information.."
+Auth.LoadingChannel.Text="Loading channel information for %1, please wait.."
+Auth.ChannelFailure.Title="Failed to load channel"
+Auth.ChannelFailure.Text="Failed to load channel information for %1\n\n%2: %3"
+
 # copy filters
 Copy.Filters="Copy Filters"
 Paste.Filters="Paste Filters"

+ 8 - 0
UI/window-basic-main-outputs.cpp

@@ -651,6 +651,10 @@ bool SimpleOutput::StartStreaming(obs_service_t *service)
 	if (!Active())
 		SetupOutputs();
 
+	Auth *auth = main->GetAuth();
+	if (auth)
+		auth->OnStreamConfig();
+
 	/* --------------------- */
 
 	const char *type = obs_service_get_output_type(service);
@@ -1426,6 +1430,10 @@ bool AdvancedOutput::StartStreaming(obs_service_t *service)
 	if (!Active())
 		SetupOutputs();
 
+	Auth *auth = main->GetAuth();
+	if (auth)
+		auth->OnStreamConfig();
+
 	/* --------------------- */
 
 	int trackIndex = config_get_int(main->Config(), "AdvOut",

+ 10 - 0
UI/window-basic-main-profiles.cpp

@@ -232,7 +232,9 @@ bool OBSBasic::AddProfile(bool create_new, const char *title, const char *text,
 	config_set_string(App()->GlobalConfig(), "Basic", "ProfileDir",
 			newDir.c_str());
 
+	Auth::Save();
 	if (create_new) {
+		auth.reset();
 		DestroyPanelCookieManager();
 	} else if (!rename) {
 		DuplicateCurrentCookieProfile(config);
@@ -456,6 +458,8 @@ void OBSBasic::on_actionRemoveProfile_triggered()
 	config_set_string(App()->GlobalConfig(), "Basic", "ProfileDir",
 			newDir);
 
+	Auth::Save();
+	auth.reset();
 	DestroyPanelCookieManager();
 
 	config.Swap(basicConfig);
@@ -471,6 +475,8 @@ void OBSBasic::on_actionRemoveProfile_triggered()
 
 	UpdateTitleBar();
 
+	Auth::Load();
+
 	if (api) {
 		api->on_event(OBS_FRONTEND_EVENT_PROFILE_LIST_CHANGED);
 		api->on_event(OBS_FRONTEND_EVENT_PROFILE_CHANGED);
@@ -615,6 +621,8 @@ void OBSBasic::ChangeProfile()
 	config_set_string(App()->GlobalConfig(), "Basic", "ProfileDir",
 			newDir);
 
+	Auth::Save();
+	auth.reset();
 	DestroyPanelCookieManager();
 
 	config.Swap(basicConfig);
@@ -624,6 +632,8 @@ void OBSBasic::ChangeProfile()
 	config_save_safe(App()->GlobalConfig(), "tmp", nullptr);
 	UpdateTitleBar();
 
+	Auth::Load();
+
 	CheckForSimpleModeX264Fallback();
 
 	blog(LOG_INFO, "Switched to profile '%s' (%s)",

+ 8 - 4
UI/window-basic-main.cpp

@@ -1786,6 +1786,8 @@ void OBSBasic::OnFirstLoad()
 		}
 	}
 #endif
+
+	Auth::Load();
 }
 
 void OBSBasic::DeferredLoad(const QString &file, int requeueCount)
@@ -3656,10 +3658,6 @@ void OBSBasic::closeEvent(QCloseEvent *event)
 				"BasicWindow", "geometry",
 				saveGeometry().toBase64().constData());
 
-	config_set_string(App()->GlobalConfig(),
-			"BasicWindow", "DockState",
-			saveState().toBase64().constData());
-
 	if (outputHandler && outputHandler->Active()) {
 		SetShowing(true);
 
@@ -3688,7 +3686,13 @@ void OBSBasic::closeEvent(QCloseEvent *event)
 
 	signalHandlers.clear();
 
+	Auth::Save();
 	SaveProjectNow();
+	auth.reset();
+
+	config_set_string(App()->GlobalConfig(),
+			"BasicWindow", "DockState",
+			saveState().toBase64().constData());
 
 	if (api)
 		api->on_event(OBS_FRONTEND_EVENT_EXIT);

+ 6 - 0
UI/window-basic-main.hpp

@@ -32,6 +32,7 @@
 #include "window-basic-filters.hpp"
 #include "window-projector.hpp"
 #include "window-basic-about.hpp"
+#include "auth-base.hpp"
 
 #include <obs-frontend-internal.hpp>
 
@@ -116,6 +117,7 @@ class OBSBasic : public OBSMainWindow {
 	friend class OBSBasicStatusBar;
 	friend class OBSBasicSourceSelect;
 	friend class OBSBasicSettings;
+	friend class Auth;
 	friend struct OBSStudioAPI;
 
 	enum class MoveDir {
@@ -136,6 +138,8 @@ class OBSBasic : public OBSMainWindow {
 private:
 	obs_frontend_callbacks *api = nullptr;
 
+	std::shared_ptr<Auth> auth;
+
 	std::vector<VolControl*> volumes;
 
 	std::vector<OBSSignal> signalHandlers;
@@ -591,6 +595,8 @@ public:
 	void SaveService();
 	bool LoadService();
 
+	inline Auth *GetAuth() {return auth.get();}
+
 	inline void EnableOutputs(bool enable)
 	{
 		if (enable) {