From ef04ff30b0ce42f1148d35d5d465f45f6faced00 Mon Sep 17 00:00:00 2001 From: Simon Hong Date: Tue, 12 Oct 2021 12:04:26 +0900 Subject: [PATCH] Improved vpn connect/disconnect logic Whenever connect is asked, vpn service tries to create os vpn entry with latest hostname and user credentials from guardian service. fix https://github.com/brave/brave-browser/issues/18648 fix https://github.com/brave/brave-browser/issues/18422 --- components/brave_vpn/BUILD.gn | 3 + .../brave_vpn/brave_vpn_connection_info.cc | 7 + .../brave_vpn/brave_vpn_connection_info.h | 1 + .../brave_vpn/brave_vpn_os_connection_api.h | 17 +- .../brave_vpn_os_connection_api_mac.mm | 37 +- .../brave_vpn_os_connection_api_sim.cc | 12 +- .../brave_vpn_os_connection_api_win.cc | 19 +- components/brave_vpn/brave_vpn_service.cc | 32 +- components/brave_vpn/brave_vpn_service.h | 13 +- .../brave_vpn/brave_vpn_service_desktop.cc | 577 ++++++++++++++---- .../brave_vpn/brave_vpn_service_desktop.h | 65 +- components/brave_vpn/brave_vpn_unittest.cc | 131 +++- .../brave_vpn/resources/panel/vpn_panel.tsx | 1 - components/brave_vpn/switches.h | 4 +- components/brave_vpn/utils_win.cc | 14 +- 15 files changed, 756 insertions(+), 177 deletions(-) diff --git a/components/brave_vpn/BUILD.gn b/components/brave_vpn/BUILD.gn index 85a6fe0d454c..09f75530a14e 100644 --- a/components/brave_vpn/BUILD.gn +++ b/components/brave_vpn/BUILD.gn @@ -74,6 +74,7 @@ static_library("brave_vpn") { ":brave_vpn_internal", ":mojom", "//brave/components/resources:strings", + "//brave/components/skus/browser", "//components/prefs", "//third_party/icu", "//ui/base", @@ -112,7 +113,9 @@ source_set("unit_tests") { deps = [ ":brave_vpn", "//base", + "//brave/components/skus/browser", "//components/prefs:test_support", + "//components/sync_preferences:test_support", "//content/test:test_support", "//services/network:test_support", "//testing/gtest", diff --git a/components/brave_vpn/brave_vpn_connection_info.cc b/components/brave_vpn/brave_vpn_connection_info.cc index 9b81157778d0..6c2c1a798a49 100644 --- a/components/brave_vpn/brave_vpn_connection_info.cc +++ b/components/brave_vpn/brave_vpn_connection_info.cc @@ -15,6 +15,13 @@ BraveVPNConnectionInfo::BraveVPNConnectionInfo( BraveVPNConnectionInfo& BraveVPNConnectionInfo::operator=( const BraveVPNConnectionInfo& info) = default; +void BraveVPNConnectionInfo::Reset() { + connection_name_.clear(); + hostname_.clear(); + username_.clear(); + password_.clear(); +} + bool BraveVPNConnectionInfo::IsValid() const { // TODO(simonhong): Improve credentials validation. return !hostname_.empty() && !username_.empty() && !password_.empty(); diff --git a/components/brave_vpn/brave_vpn_connection_info.h b/components/brave_vpn/brave_vpn_connection_info.h index 25ecd9b796c2..863c95662125 100644 --- a/components/brave_vpn/brave_vpn_connection_info.h +++ b/components/brave_vpn/brave_vpn_connection_info.h @@ -17,6 +17,7 @@ class BraveVPNConnectionInfo { BraveVPNConnectionInfo(const BraveVPNConnectionInfo& info); BraveVPNConnectionInfo& operator=(const BraveVPNConnectionInfo& info); + void Reset(); bool IsValid() const; void SetConnectionInfo(const std::string& connection_name, const std::string& hostname, diff --git a/components/brave_vpn/brave_vpn_os_connection_api.h b/components/brave_vpn/brave_vpn_os_connection_api.h index 6f503b7c05d2..fa05b819c981 100644 --- a/components/brave_vpn/brave_vpn_os_connection_api.h +++ b/components/brave_vpn/brave_vpn_os_connection_api.h @@ -18,15 +18,14 @@ class BraveVPNOSConnectionAPI { public: class Observer : public base::CheckedObserver { public: - // TODO(simonhong): Don't need |name| parameter because only one vpn - // connection is managed. - virtual void OnCreated(const std::string& name) = 0; - virtual void OnRemoved(const std::string& name) = 0; - virtual void OnConnected(const std::string& name) = 0; - virtual void OnIsConnecting(const std::string& name) = 0; - virtual void OnConnectFailed(const std::string& name) = 0; - virtual void OnDisconnected(const std::string& name) = 0; - virtual void OnIsDisconnecting(const std::string& name) = 0; + virtual void OnCreated() = 0; + virtual void OnCreateFailed() = 0; + virtual void OnRemoved() = 0; + virtual void OnConnected() = 0; + virtual void OnIsConnecting() = 0; + virtual void OnConnectFailed() = 0; + virtual void OnDisconnected() = 0; + virtual void OnIsDisconnecting() = 0; protected: ~Observer() override = default; diff --git a/components/brave_vpn/brave_vpn_os_connection_api_mac.mm b/components/brave_vpn/brave_vpn_os_connection_api_mac.mm index 1d00adb99f37..b103fe74c7ad 100644 --- a/components/brave_vpn/brave_vpn_os_connection_api_mac.mm +++ b/components/brave_vpn/brave_vpn_os_connection_api_mac.mm @@ -22,6 +22,26 @@ const NSString* kBraveVPNKey = @"BraveVPNKey"; +std::string NEVPNStatusToString(NEVPNStatus status) { + switch (status) { + case NEVPNStatusInvalid: + return "NEVPNStatusInvalid"; + case NEVPNStatusDisconnected: + return "NEVPNStatusDisconnected"; + case NEVPNStatusConnecting: + return "NEVPNStatusConnecting"; + case NEVPNStatusConnected: + return "NEVPNStatusConnected"; + case NEVPNStatusReasserting: + return "NEVPNStatusReasserting"; + case NEVPNStatusDisconnecting: + return "NEVPNStatusDisconnecting"; + default: + NOTREACHED(); + } + return std::string(); +} + NSData* GetPasswordRefForAccount() { NSString* bundle_id = [[NSBundle mainBundle] bundleIdentifier]; CFTypeRef copy_result = NULL; @@ -172,11 +192,14 @@ OSStatus StorePassword(const NSString* password) { if (error) { LOG(ERROR) << "Create - saveToPrefs error: " << base::SysNSStringToUTF8([error localizedDescription]); + for (Observer& obs : observers_) + obs.OnCreateFailed(); + return; } VLOG(2) << "Create - saveToPrefs success"; for (Observer& obs : observers_) - obs.OnCreated(std::string()); + obs.OnCreated(); }]; }]; } @@ -201,7 +224,7 @@ OSStatus StorePassword(const NSString* password) { } VLOG(2) << "RemoveVPNConnection - successfully removed"; for (Observer& obs : observers_) - obs.OnRemoved(std::string()); + obs.OnRemoved(); }]; } RemoveKeychainItemForAccount(); @@ -264,25 +287,25 @@ OSStatus StorePassword(const NSString* password) { } NEVPNStatus current_status = [[vpn_manager connection] status]; - VLOG(2) << "CheckConnection: " << current_status; + VLOG(2) << "CheckConnection: " << NEVPNStatusToString(current_status); switch (current_status) { case NEVPNStatusConnected: for (Observer& obs : observers_) - obs.OnConnected(name); + obs.OnConnected(); break; case NEVPNStatusConnecting: case NEVPNStatusReasserting: for (Observer& obs : observers_) - obs.OnIsConnecting(name); + obs.OnIsConnecting(); break; case NEVPNStatusDisconnected: case NEVPNStatusInvalid: for (Observer& obs : observers_) - obs.OnDisconnected(name); + obs.OnDisconnected(); break; case NEVPNStatusDisconnecting: for (Observer& obs : observers_) - obs.OnIsDisconnecting(name); + obs.OnIsDisconnecting(); break; default: break; diff --git a/components/brave_vpn/brave_vpn_os_connection_api_sim.cc b/components/brave_vpn/brave_vpn_os_connection_api_sim.cc index d6c077f6c5a2..9859cd3dc3c3 100644 --- a/components/brave_vpn/brave_vpn_os_connection_api_sim.cc +++ b/components/brave_vpn/brave_vpn_os_connection_api_sim.cc @@ -92,7 +92,7 @@ void BraveVPNOSConnectionAPISim::OnCreated(const std::string& name, return; for (Observer& obs : observers_) - obs.OnCreated(name); + obs.OnCreated(); } void BraveVPNOSConnectionAPISim::OnConnected(const std::string& name, @@ -104,12 +104,12 @@ void BraveVPNOSConnectionAPISim::OnConnected(const std::string& name, } for (Observer& obs : observers_) - success ? obs.OnConnected(name) : obs.OnConnectFailed(name); + success ? obs.OnConnected() : obs.OnConnectFailed(); } void BraveVPNOSConnectionAPISim::OnIsConnecting(const std::string& name) { for (Observer& obs : observers_) - obs.OnIsConnecting(name); + obs.OnIsConnecting(); } void BraveVPNOSConnectionAPISim::OnDisconnected(const std::string& name, @@ -118,12 +118,12 @@ void BraveVPNOSConnectionAPISim::OnDisconnected(const std::string& name, return; for (Observer& obs : observers_) - obs.OnDisconnected(name); + obs.OnDisconnected(); } void BraveVPNOSConnectionAPISim::OnIsDisconnecting(const std::string& name) { for (Observer& obs : observers_) - obs.OnIsDisconnecting(name); + obs.OnIsDisconnecting(); } void BraveVPNOSConnectionAPISim::OnRemoved(const std::string& name, @@ -132,7 +132,7 @@ void BraveVPNOSConnectionAPISim::OnRemoved(const std::string& name, return; for (Observer& obs : observers_) - obs.OnRemoved(name); + obs.OnRemoved(); } } // namespace brave_vpn diff --git a/components/brave_vpn/brave_vpn_os_connection_api_win.cc b/components/brave_vpn/brave_vpn_os_connection_api_win.cc index b867ac2dec81..52731640c3cf 100644 --- a/components/brave_vpn/brave_vpn_os_connection_api_win.cc +++ b/components/brave_vpn/brave_vpn_os_connection_api_win.cc @@ -136,19 +136,19 @@ void BraveVPNOSConnectionAPIWin::OnCheckConnection( for (Observer& obs : observers_) { switch (result) { case CheckConnectionResult::CONNECTED: - obs.OnConnected(name); + obs.OnConnected(); break; case CheckConnectionResult::CONNECTING: - obs.OnIsConnecting(name); + obs.OnIsConnecting(); break; case CheckConnectionResult::CONNECT_FAILED: - obs.OnConnectFailed(name); + obs.OnConnectFailed(); break; case CheckConnectionResult::DISCONNECTED: - obs.OnDisconnected(name); + obs.OnDisconnected(); break; case CheckConnectionResult::DISCONNECTING: - obs.OnIsDisconnecting(name); + obs.OnIsDisconnecting(); break; default: break; @@ -158,11 +158,14 @@ void BraveVPNOSConnectionAPIWin::OnCheckConnection( void BraveVPNOSConnectionAPIWin::OnCreated(const std::string& name, bool success) { - if (!success) + if (!success) { + for (Observer& obs : observers_) + obs.OnCreateFailed(); return; + } for (Observer& obs : observers_) - obs.OnCreated(name); + obs.OnCreated(); } void BraveVPNOSConnectionAPIWin::OnRemoved(const std::string& name, @@ -171,7 +174,7 @@ void BraveVPNOSConnectionAPIWin::OnRemoved(const std::string& name, return; for (Observer& obs : observers_) - obs.OnRemoved(name); + obs.OnRemoved(); } void BraveVPNOSConnectionAPIWin::StartVPNConnectionChangeMonitoring() { diff --git a/components/brave_vpn/brave_vpn_service.cc b/components/brave_vpn/brave_vpn_service.cc index cb7086f372e1..0507d1ffd9d3 100644 --- a/components/brave_vpn/brave_vpn_service.cc +++ b/components/brave_vpn/brave_vpn_service.cc @@ -22,6 +22,8 @@ constexpr char kCreateSubscriberCredential[] = "api/v1/subscriber-credential/create"; constexpr char kProfileCredential[] = "api/v1.1/register-and-create"; constexpr char kVerifyPurchaseToken[] = "api/v1.1/verify-purchase-token"; +constexpr char kCreateSubscriberCredentialV12[] = + "api/v1.2/subscriber-credential/create"; net::NetworkTrafficAnnotationTag GetNetworkTrafficAnnotationTag() { return net::DefineNetworkTrafficAnnotation("brave_vpn_service", R"( @@ -81,12 +83,14 @@ BraveVpnService::~BraveVpnService() = default; void BraveVpnService::Shutdown() {} -void BraveVpnService::OAuthRequest(const GURL& url, - const std::string& method, - const std::string& post_data, - URLRequestCallback callback) { +void BraveVpnService::OAuthRequest( + const GURL& url, + const std::string& method, + const std::string& post_data, + URLRequestCallback callback, + const base::flat_map& headers) { api_request_helper_.Request(method, url, post_data, "application/json", false, - std::move(callback)); + std::move(callback), headers); } void BraveVpnService::GetAllServerRegions(ResponseCallback callback) { @@ -196,3 +200,21 @@ void BraveVpnService::OnGetSubscriberCredential( } std::move(callback).Run(subscriber_credential, success); } + +void BraveVpnService::GetSubscriberCredentialV12( + ResponseCallback callback, + const std::string& payments_environment, + const std::string& monthly_pass) { + auto internal_callback = + base::BindOnce(&BraveVpnService::OnGetSubscriberCredential, + weak_ptr_factory_.GetWeakPtr(), std::move(callback)); + + const GURL base_url = + GetURLWithPath(kVpnHost, kCreateSubscriberCredentialV12); + base::Value dict(base::Value::Type::DICTIONARY); + dict.SetStringKey("validation-method", "brave-premium"); + dict.SetStringKey("brave-vpn-premium-monthly-pass", monthly_pass); + std::string request_body = CreateJSONRequestBody(dict); + OAuthRequest(base_url, "POST", request_body, std::move(internal_callback), + {{"Brave-Payments-Environment", payments_environment}}); +} diff --git a/components/brave_vpn/brave_vpn_service.h b/components/brave_vpn/brave_vpn_service.h index a5e70e2921d6..05c6fbce4203 100644 --- a/components/brave_vpn/brave_vpn_service.h +++ b/components/brave_vpn/brave_vpn_service.h @@ -53,6 +53,9 @@ class BraveVpnService : public KeyedService { const std::string& product_id, const std::string& product_type, const std::string& bundle_id); + void GetSubscriberCredentialV12(ResponseCallback callback, + const std::string& payments_environment, + const std::string& monthly_pass); private: using URLRequestCallback = @@ -60,10 +63,12 @@ class BraveVpnService : public KeyedService { const std::string&, const base::flat_map&)>; - void OAuthRequest(const GURL& url, - const std::string& method, - const std::string& post_data, - URLRequestCallback callback); + void OAuthRequest( + const GURL& url, + const std::string& method, + const std::string& post_data, + URLRequestCallback callback, + const base::flat_map& headers = {}); void OnGetResponse(ResponseCallback callback, int status, diff --git a/components/brave_vpn/brave_vpn_service_desktop.cc b/components/brave_vpn/brave_vpn_service_desktop.cc index 5ddfe9e6c79f..16911d1246f2 100644 --- a/components/brave_vpn/brave_vpn_service_desktop.cc +++ b/components/brave_vpn/brave_vpn_service_desktop.cc @@ -6,7 +6,6 @@ #include "brave/components/brave_vpn/brave_vpn_service_desktop.h" #include -#include #include #include "base/bind.h" @@ -20,6 +19,7 @@ #include "brave/components/brave_vpn/pref_names.h" #include "brave/components/brave_vpn/switches.h" #include "brave/components/brave_vpn/url_constants.h" +#include "brave/components/skus/browser/pref_names.h" #include "components/prefs/pref_service.h" #include "components/prefs/scoped_user_pref_update.h" #include "third_party/icu/source/i18n/unicode/timezone.h" @@ -32,30 +32,44 @@ constexpr char kRegionContinentKey[] = "continent"; constexpr char kRegionNameKey[] = "name"; constexpr char kRegionNamePrettyKey[] = "name-pretty"; -bool GetVPNCredentialsFromSwitch(brave_vpn::BraveVPNConnectionInfo* info) { - DCHECK(info); - auto* cmd = base::CommandLine::ForCurrentProcess(); - if (!cmd->HasSwitch(brave_vpn::switches::kBraveVPNTestCredentials)) - return false; - - std::string value = - cmd->GetSwitchValueASCII(brave_vpn::switches::kBraveVPNTestCredentials); - std::vector tokens = base::SplitString( - value, ":", base::KEEP_WHITESPACE, base::SPLIT_WANT_ALL); - if (tokens.size() == 4) { - info->SetConnectionInfo(tokens[0], tokens[1], tokens[2], tokens[3]); - return true; +std::string GetStringFor(ConnectionState state) { + switch (state) { + case ConnectionState::CONNECTED: + return "Connected"; + case ConnectionState::CONNECTING: + return "Connecting"; + case ConnectionState::DISCONNECTED: + return "Disconnected"; + case ConnectionState::DISCONNECTING: + return "Disconnecting"; + case ConnectionState::CONNECT_FAILED: + return "Connect failed"; + default: + NOTREACHED(); } - LOG(ERROR) << __func__ << ": Invalid credentials"; - return false; + return std::string(); } -brave_vpn::BraveVPNOSConnectionAPI* GetBraveVPNConnectionAPI() { +bool IsValidRegion(const brave_vpn::mojom::Region& region) { + if (region.continent.empty() || region.name.empty() || + region.name_pretty.empty()) + return false; + + return true; +} + +std::string GetBraveVPNPaymentsEnv() { auto* cmd = base::CommandLine::ForCurrentProcess(); - if (cmd->HasSwitch(brave_vpn::switches::kBraveVPNSimulation)) - return brave_vpn::BraveVPNOSConnectionAPI::GetInstanceForTest(); - return brave_vpn::BraveVPNOSConnectionAPI::GetInstance(); + if (!cmd->HasSwitch(brave_vpn::switches::kBraveVPNPaymentsEnv)) { +#if defined(OFFICIAL_BUILD) + return ""; +#else + return "development"; +#endif + } + + return cmd->GetSwitchValueASCII(brave_vpn::switches::kBraveVPNPaymentsEnv); } } // namespace @@ -65,127 +79,262 @@ BraveVpnServiceDesktop::BraveVpnServiceDesktop( PrefService* prefs) : BraveVpnService(url_loader_factory), prefs_(prefs) { DCHECK(brave_vpn::IsBraveVPNEnabled()); + DETACH_FROM_SEQUENCE(sequence_checker_); + auto* cmd = base::CommandLine::ForCurrentProcess(); + is_simulation_ = cmd->HasSwitch(brave_vpn::switches::kBraveVPNSimulation); observed_.Observe(GetBraveVPNConnectionAPI()); GetBraveVPNConnectionAPI()->set_target_vpn_entry_name(kBraveVPNEntryName); - GetBraveVPNConnectionAPI()->CheckConnection(kBraveVPNEntryName); + // Load cached data. LoadCachedRegionData(); - FetchRegionData(); - CheckPurchasedStatus(); + LoadPurchasedState(); + LoadSelectedRegion(); + + // If already in purchased state, there is a high probability that + // the user has previously connected. So, try connection checking. + if (is_purchased_user()) + GetBraveVPNConnectionAPI()->CheckConnection(kBraveVPNEntryName); + + pref_change_registrar_.Init(prefs_); + pref_change_registrar_.Add( + brave_rewards::prefs::kSkusVPNCredential, + base::BindRepeating(&BraveVpnServiceDesktop::OnSkusVPNCredentialUpdated, + base::Unretained(this))); +} + +BraveVpnServiceDesktop::~BraveVpnServiceDesktop() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); +} + +void BraveVpnServiceDesktop::ScheduleFetchRegionDataIfNeeded() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + + if (!is_purchased_user()) + return; + if (region_data_update_timer_.IsRunning()) + return; + + // Try to update region list every 5h. + FetchRegionData(); constexpr int kRegionDataUpdateIntervalInHours = 5; region_data_update_timer_.Start( FROM_HERE, base::TimeDelta::FromHours(kRegionDataUpdateIntervalInHours), this, &BraveVpnServiceDesktop::FetchRegionData); } -BraveVpnServiceDesktop::~BraveVpnServiceDesktop() = default; - void BraveVpnServiceDesktop::Shutdown() { BraveVpnService::Shutdown(); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + observed_.Reset(); receivers_.Clear(); observers_.Clear(); + pref_change_registrar_.RemoveAll(); } -void BraveVpnServiceDesktop::OnCreated(const std::string& name) { +void BraveVpnServiceDesktop::OnCreated() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + VLOG(2) << __func__; + if (cancel_connecting_) { + UpdateAndNotifyConnectionStateChange(ConnectionState::DISCONNECTED); + cancel_connecting_ = false; + return; + } + for (const auto& obs : observers_) obs->OnConnectionCreated(); + + // It's time to ask connecting to os after vpn entry is created. + GetBraveVPNConnectionAPI()->Connect(GetConnectionInfo().connection_name()); +} + +void BraveVpnServiceDesktop::OnCreateFailed() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + VLOG(2) << __func__; + UpdateAndNotifyConnectionStateChange(ConnectionState::CONNECT_FAILED); } -void BraveVpnServiceDesktop::OnRemoved(const std::string& name) { +void BraveVpnServiceDesktop::OnRemoved() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + VLOG(2) << __func__; for (const auto& obs : observers_) obs->OnConnectionRemoved(); } -void BraveVpnServiceDesktop::OnConnected(const std::string& name) { - if (connection_state_ == ConnectionState::CONNECTED) +void BraveVpnServiceDesktop::UpdateAndNotifyConnectionStateChange( + ConnectionState state) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (connection_state_ == state) return; - connection_state_ = ConnectionState::CONNECTED; - - for (const auto& obs : observers_) - obs->OnConnectionStateChanged(ConnectionState::CONNECTED); -} + // On Windows, we get disconnected status update twice. + // When user connects to different region while connected, + // we disconnect current connection and connect to newly selected + // region. To do that we monitor |DISCONNECTED| state and start + // connect when we get that state. But, Windows sends disconnected state + // noti again. So, ignore second one. + // On exception - we allow from connecting to disconnected in canceling + // scenario. + if (connection_state_ == ConnectionState::CONNECTING && + state == ConnectionState::DISCONNECTED && !cancel_connecting_) { + VLOG(2) << __func__ << ": Ignore disconnected state while connecting"; + return; + } -void BraveVpnServiceDesktop::OnIsConnecting(const std::string& name) { - if (connection_state_ == ConnectionState::CONNECTING) + // On Windows, we could get disconnected state after connect failed. + // To make connect failed state as a last state, ignore disconnected state. + if (connection_state_ == ConnectionState::CONNECT_FAILED && + state == ConnectionState::DISCONNECTED) { + VLOG(2) << __func__ << ": Ignore disconnected state after connect failed"; return; + } - connection_state_ = ConnectionState::CONNECTING; + VLOG(2) << __func__ << " : changing from " << GetStringFor(connection_state_) + << " to " << GetStringFor(state); + connection_state_ = state; for (const auto& obs : observers_) - obs->OnConnectionStateChanged(ConnectionState::CONNECTING); + obs->OnConnectionStateChanged(connection_state_); } -void BraveVpnServiceDesktop::OnConnectFailed(const std::string& name) { - if (connection_state_ == ConnectionState::CONNECT_FAILED) - return; +void BraveVpnServiceDesktop::OnConnected() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + VLOG(2) << __func__; - connection_state_ = ConnectionState::CONNECT_FAILED; + if (cancel_connecting_) { + // As connect is done, we don't need more for cancelling. + // Just start normal Disconenct() process. + cancel_connecting_ = false; + GetBraveVPNConnectionAPI()->Disconnect(kBraveVPNEntryName); + return; + } - for (const auto& obs : observers_) - obs->OnConnectionStateChanged(ConnectionState::CONNECT_FAILED); + UpdateAndNotifyConnectionStateChange(ConnectionState::CONNECTED); } -void BraveVpnServiceDesktop::OnDisconnected(const std::string& name) { - if (connection_state_ == ConnectionState::DISCONNECTED) - return; +void BraveVpnServiceDesktop::OnIsConnecting() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + VLOG(2) << __func__; - connection_state_ = ConnectionState::DISCONNECTED; + if (!cancel_connecting_) + UpdateAndNotifyConnectionStateChange(ConnectionState::CONNECTING); +} - for (const auto& obs : observers_) - obs->OnConnectionStateChanged(ConnectionState::DISCONNECTED); +void BraveVpnServiceDesktop::OnConnectFailed() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + VLOG(2) << __func__; + + cancel_connecting_ = false; + UpdateAndNotifyConnectionStateChange(ConnectionState::CONNECT_FAILED); } -void BraveVpnServiceDesktop::OnIsDisconnecting(const std::string& name) { - if (connection_state_ == ConnectionState::DISCONNECTING) - return; +void BraveVpnServiceDesktop::OnDisconnected() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + VLOG(2) << __func__; - connection_state_ = ConnectionState::DISCONNECTING; + UpdateAndNotifyConnectionStateChange(ConnectionState::DISCONNECTED); - for (const auto& obs : observers_) - obs->OnConnectionStateChanged(ConnectionState::DISCONNECTING); + if (needs_connect_) { + needs_connect_ = false; + Connect(); + } +} + +void BraveVpnServiceDesktop::OnIsDisconnecting() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + VLOG(2) << __func__; + UpdateAndNotifyConnectionStateChange(ConnectionState::DISCONNECTING); } void BraveVpnServiceDesktop::CreateVPNConnection() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (cancel_connecting_) { + UpdateAndNotifyConnectionStateChange(ConnectionState::DISCONNECTED); + cancel_connecting_ = false; + return; + } + + VLOG(2) << __func__; GetBraveVPNConnectionAPI()->CreateVPNConnection(GetConnectionInfo()); } void BraveVpnServiceDesktop::RemoveVPNConnnection() { - GetBraveVPNConnectionAPI()->RemoveVPNConnection( - GetConnectionInfo().connection_name()); + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + VLOG(2) << __func__; + GetBraveVPNConnectionAPI()->RemoveVPNConnection(kBraveVPNEntryName); } void BraveVpnServiceDesktop::Connect() { - if (connection_state_ == ConnectionState::CONNECTING) + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (connection_state_ == ConnectionState::DISCONNECTING || + connection_state_ == ConnectionState::CONNECTING) { + VLOG(2) << __func__ + << ": Current state: " << GetStringFor(connection_state_) + << " : prevent connecting while previous operation is in-progress"; return; + } - GetBraveVPNConnectionAPI()->Connect(GetConnectionInfo().connection_name()); + DCHECK(!cancel_connecting_); + + // User can ask connect again when user want to change region. + if (connection_state_ == ConnectionState::CONNECTED) { + // Disconnect first and then create again to setup for new region. + needs_connect_ = true; + Disconnect(); + return; + } + + VLOG(2) << __func__ << " : start connecting!"; + UpdateAndNotifyConnectionStateChange(ConnectionState::CONNECTING); + + if (is_simulation_ || connection_info_.IsValid()) { + VLOG(2) << __func__ + << " : direct connect as we already have valid connection info."; + GetBraveVPNConnectionAPI()->Connect(GetConnectionInfo().connection_name()); + return; + } + + // If user doesn't select region explicitely, use default device region. + std::string target_region_name = device_region_.name; + if (IsValidRegion(selected_region_)) { + target_region_name = selected_region_.name; + VLOG(2) << __func__ << " : start connecting with valid selected_region: " + << target_region_name; + } + + FetchHostnamesForRegion(target_region_name); } void BraveVpnServiceDesktop::Disconnect() { - if (connection_state_ == ConnectionState::DISCONNECTING) + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (connection_state_ == ConnectionState::DISCONNECTED) { + VLOG(2) << __func__ << " : already disconnected"; return; + } - GetBraveVPNConnectionAPI()->Disconnect(GetConnectionInfo().connection_name()); -} + if (connection_state_ == ConnectionState::DISCONNECTING) { + VLOG(2) << __func__ << " : disconnecting in progress"; + return; + } -void BraveVpnServiceDesktop::CheckPurchasedStatus() { - brave_vpn::BraveVPNConnectionInfo info; - if (GetVPNCredentialsFromSwitch(&info)) { - SetPurchasedState(PurchasedState::PURCHASED); - CreateVPNConnection(); + if (is_simulation_ || connection_state_ != ConnectionState::CONNECTING) { + VLOG(2) << __func__ << " : start disconnecting!"; + UpdateAndNotifyConnectionStateChange(ConnectionState::DISCONNECTING); + GetBraveVPNConnectionAPI()->Disconnect(kBraveVPNEntryName); return; } - NOTIMPLEMENTED(); + cancel_connecting_ = true; + VLOG(2) << __func__ << " : Start cancelling connect request"; + UpdateAndNotifyConnectionStateChange(ConnectionState::DISCONNECTING); } void BraveVpnServiceDesktop::ToggleConnection() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); const bool can_disconnect = (connection_state_ == ConnectionState::CONNECTED || connection_state_ == ConnectionState::CONNECTING); @@ -194,45 +343,52 @@ void BraveVpnServiceDesktop::ToggleConnection() { void BraveVpnServiceDesktop::AddObserver( mojo::PendingRemote observer) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); observers_.Add(std::move(observer)); } brave_vpn::BraveVPNConnectionInfo BraveVpnServiceDesktop::GetConnectionInfo() { - brave_vpn::BraveVPNConnectionInfo info; - if (GetVPNCredentialsFromSwitch(&info)) - return info; - - // TODO(simonhong): Get real credentials from payment service. - NOTIMPLEMENTED(); - return info; + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + return connection_info_; } void BraveVpnServiceDesktop::BindInterface( mojo::PendingReceiver receiver) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); receivers_.Add(this, std::move(receiver)); } void BraveVpnServiceDesktop::GetConnectionState( GetConnectionStateCallback callback) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + VLOG(2) << __func__ << " : " << static_cast(connection_state_); std::move(callback).Run(connection_state_); } void BraveVpnServiceDesktop::GetPurchasedState( GetPurchasedStateCallback callback) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + VLOG(2) << __func__ << " : " << static_cast(purchased_state_); std::move(callback).Run(purchased_state_); } void BraveVpnServiceDesktop::FetchRegionData() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + VLOG(2) << __func__ << " : Start fetching region data"; // Unretained is safe here becasue this class owns request helper. GetAllServerRegions(base::BindOnce(&BraveVpnServiceDesktop::OnFetchRegionList, base::Unretained(this))); } void BraveVpnServiceDesktop::LoadCachedRegionData() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); auto* preference = prefs_->FindPreference(brave_vpn::prefs::kBraveVPNRegionList); - if (preference && !preference->IsDefaultValue()) + if (preference && !preference->IsDefaultValue()) { ParseAndCacheRegionList(preference->GetValue()->Clone()); + VLOG(2) << __func__ << " : " + << "Loaded cached region list"; + } preference = prefs_->FindPreference(brave_vpn::prefs::kBraveVPNDeviceRegion); if (preference && !preference->IsDefaultValue()) { @@ -246,12 +402,60 @@ void BraveVpnServiceDesktop::LoadCachedRegionData() { device_region_.continent = *continent; device_region_.name = *name; device_region_.name_pretty = *name_pretty; + VLOG(2) << __func__ << " : " + << "Loaded cached device region"; + } + } +} + +void BraveVpnServiceDesktop::LoadPurchasedState() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + auto* cmd = base::CommandLine::ForCurrentProcess(); + if (cmd->HasSwitch(brave_vpn::switches::kBraveVPNTestMonthlyPass)) { + skus_credential_ = + cmd->GetSwitchValueASCII(brave_vpn::switches::kBraveVPNTestMonthlyPass); + SetPurchasedState(PurchasedState::PURCHASED); + return; + } + + const std::string credential = + prefs_->GetString(brave_rewards::prefs::kSkusVPNCredential); + if (skus_credential_ == credential) + return; + + skus_credential_ = credential; + + if (!skus_credential_.empty()) { + VLOG(2) << __func__ << " : " + << "Loaded cached skus credentials"; + SetPurchasedState(PurchasedState::PURCHASED); + } +} + +void BraveVpnServiceDesktop::LoadSelectedRegion() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + auto* preference = + prefs_->FindPreference(brave_vpn::prefs::kBraveVPNSelectedRegion); + if (preference && !preference->IsDefaultValue()) { + auto* region_value = preference->GetValue(); + const std::string* continent = + region_value->FindStringKey(kRegionContinentKey); + const std::string* name = region_value->FindStringKey(kRegionNameKey); + const std::string* name_pretty = + region_value->FindStringKey(kRegionNamePrettyKey); + if (continent && name && name_pretty) { + selected_region_.continent = *continent; + selected_region_.name = *name; + selected_region_.name_pretty = *name_pretty; + VLOG(2) << __func__ << " : " + << "Loaded selected region"; } } } void BraveVpnServiceDesktop::OnFetchRegionList(const std::string& region_list, bool success) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (!success) { // TODO(simonhong): Re-try? VLOG(2) << "Failed to get region list"; @@ -274,6 +478,7 @@ void BraveVpnServiceDesktop::OnFetchRegionList(const std::string& region_list, } bool BraveVpnServiceDesktop::ParseAndCacheRegionList(base::Value region_value) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK(region_value.is_list()); if (!region_value.is_list()) return false; @@ -301,17 +506,18 @@ bool BraveVpnServiceDesktop::ParseAndCacheRegionList(base::Value region_value) { return (a.name_pretty < b.name_pretty); }); + VLOG(2) << __func__ << " : has regionlist: " << !regions_.empty(); + // If we can't get region list, we can't determine device region. if (regions_.empty()) return false; - return true; } void BraveVpnServiceDesktop::OnFetchTimezones(const std::string& timezones_list, bool success) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (!success) { - // TODO(simonhong): Re-try? VLOG(2) << "Failed to get timezones list"; SetFallbackDeviceRegion(); return; @@ -323,12 +529,12 @@ void BraveVpnServiceDesktop::OnFetchTimezones(const std::string& timezones_list, return; } - // TODO(simonhong): Re-try? SetFallbackDeviceRegion(); } void BraveVpnServiceDesktop::ParseAndCacheDeviceRegionName( base::Value timezones_value) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK(timezones_value.is_list()); if (!timezones_value.is_list()) { @@ -371,6 +577,7 @@ void BraveVpnServiceDesktop::ParseAndCacheDeviceRegionName( } void BraveVpnServiceDesktop::SetDeviceRegion(const std::string& name) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); auto it = std::find_if(regions_.begin(), regions_.end(), [&name](const auto& region) { return region.name == name; }); @@ -380,6 +587,7 @@ void BraveVpnServiceDesktop::SetDeviceRegion(const std::string& name) { } void BraveVpnServiceDesktop::SetFallbackDeviceRegion() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); // Set first item in the region list as a |device_region_| as a fallback. DCHECK(!regions_.empty()); if (regions_.empty()) @@ -390,6 +598,7 @@ void BraveVpnServiceDesktop::SetFallbackDeviceRegion() { void BraveVpnServiceDesktop::SetDeviceRegion( const brave_vpn::mojom::Region& region) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); device_region_ = region; DictionaryPrefUpdate update(prefs_, brave_vpn::prefs::kBraveVPNDeviceRegion); @@ -400,6 +609,7 @@ void BraveVpnServiceDesktop::SetDeviceRegion( } std::string BraveVpnServiceDesktop::GetCurrentTimeZone() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (!test_timezone_.empty()) return test_timezone_; @@ -412,6 +622,7 @@ std::string BraveVpnServiceDesktop::GetCurrentTimeZone() { } void BraveVpnServiceDesktop::GetAllRegions(GetAllRegionsCallback callback) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); std::vector regions; for (const auto& region : regions_) { regions.push_back(region.Clone()); @@ -420,37 +631,40 @@ void BraveVpnServiceDesktop::GetAllRegions(GetAllRegionsCallback callback) { } void BraveVpnServiceDesktop::GetDeviceRegion(GetDeviceRegionCallback callback) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + VLOG(2) << __func__; std::move(callback).Run(device_region_.Clone()); } void BraveVpnServiceDesktop::GetSelectedRegion( GetSelectedRegionCallback callback) { - auto* preference = - prefs_->FindPreference(brave_vpn::prefs::kBraveVPNSelectedRegion); - if (preference->IsDefaultValue()) { - // Gives device region if there is no cached selected region. - std::move(callback).Run(device_region_.Clone()); - return; - } + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + VLOG(2) << __func__; - auto* region_value = preference->GetValue(); - const std::string* continent = - region_value->FindStringKey(kRegionContinentKey); - const std::string* name = region_value->FindStringKey(kRegionNameKey); - const std::string* name_pretty = - region_value->FindStringKey(kRegionNamePrettyKey); - if (!continent || !name || !name_pretty) { - // Gives device region if invalid data is cached. + if (!IsValidRegion(selected_region_)) { + // Gives device region if there is no cached selected region. + VLOG(2) << __func__ << " : give device region instead."; std::move(callback).Run(device_region_.Clone()); return; } - brave_vpn::mojom::Region region(*continent, *name, *name_pretty); - std::move(callback).Run(region.Clone()); + VLOG(2) << __func__ << " : Give " << selected_region_.name_pretty; + std::move(callback).Run(selected_region_.Clone()); } void BraveVpnServiceDesktop::SetSelectedRegion( brave_vpn::mojom::RegionPtr region_ptr) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (connection_state_ == ConnectionState::DISCONNECTING || + connection_state_ == ConnectionState::CONNECTING) { + VLOG(2) << __func__ + << ": Current state: " << GetStringFor(connection_state_) + << " : prevent changing selected region while previous operation " + "is in-progress"; + return; + } + + VLOG(2) << __func__ << " : " << region_ptr->name_pretty; DictionaryPrefUpdate update(prefs_, brave_vpn::prefs::kBraveVPNSelectedRegion); base::Value* dict = update.Get(); @@ -458,11 +672,15 @@ void BraveVpnServiceDesktop::SetSelectedRegion( dict->SetStringKey(kRegionNameKey, region_ptr->name); dict->SetStringKey(kRegionNamePrettyKey, region_ptr->name_pretty); - // Start hostname fetching for selected region. - FetchHostnamesForRegion(region_ptr->name); + selected_region_.continent = region_ptr->continent; + selected_region_.name = region_ptr->name; + selected_region_.name_pretty = region_ptr->name_pretty; + + connection_info_.Reset(); } void BraveVpnServiceDesktop::GetProductUrls(GetProductUrlsCallback callback) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); brave_vpn::mojom::ProductUrls urls; urls.feedback = brave_vpn::kFeedbackUrl; urls.about = brave_vpn::kAboutUrl; @@ -471,6 +689,11 @@ void BraveVpnServiceDesktop::GetProductUrls(GetProductUrlsCallback callback) { } void BraveVpnServiceDesktop::FetchHostnamesForRegion(const std::string& name) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + VLOG(2) << __func__; + // Hostname will be replaced with latest one. + hostname_.reset(); + // Unretained is safe here becasue this class owns request helper. GetHostnamesForRegion( base::BindOnce(&BraveVpnServiceDesktop::OnFetchHostnames, @@ -481,8 +704,17 @@ void BraveVpnServiceDesktop::FetchHostnamesForRegion(const std::string& name) { void BraveVpnServiceDesktop::OnFetchHostnames(const std::string& region, const std::string& hostnames, bool success) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + VLOG(2) << __func__; + if (cancel_connecting_) { + UpdateAndNotifyConnectionStateChange(ConnectionState::DISCONNECTED); + cancel_connecting_ = false; + return; + } + if (!success) { - // TODO(simonhong): Retry? + VLOG(2) << __func__ << " : failed to fetch hostnames for " << region; + UpdateAndNotifyConnectionStateChange(ConnectionState::CONNECT_FAILED); return; } @@ -492,20 +724,26 @@ void BraveVpnServiceDesktop::OnFetchHostnames(const std::string& region, return; } - // TODO(simonhong): Retry? + VLOG(2) << __func__ << " : failed to fetch hostnames for " << region; + UpdateAndNotifyConnectionStateChange(ConnectionState::CONNECT_FAILED); } void BraveVpnServiceDesktop::ParseAndCacheHostnames( const std::string& region, base::Value hostnames_value) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); DCHECK(hostnames_value.is_list()); - if (!hostnames_value.is_list()) + if (!hostnames_value.is_list()) { + VLOG(2) << __func__ << " : failed to parse hostnames for " << region; + UpdateAndNotifyConnectionStateChange(ConnectionState::CONNECT_FAILED); return; + } constexpr char kHostnameKey[] = "hostname"; constexpr char kDisplayNameKey[] = "display-name"; constexpr char kOfflineKey[] = "offline"; constexpr char kCapacityScoreKey[] = "capacity-score"; + std::vector hostnames; for (const auto& value : hostnames_value.GetList()) { DCHECK(value.is_dict()); @@ -524,10 +762,42 @@ void BraveVpnServiceDesktop::ParseAndCacheHostnames( *offline, *capacity_score}); } - hostnames_[region] = std::move(hostnames); + VLOG(2) << __func__ << " : has hostname: " << !hostnames.empty(); + + if (hostnames.empty()) { + VLOG(2) << __func__ << " : got empty hostnames list for " << region; + UpdateAndNotifyConnectionStateChange(ConnectionState::CONNECT_FAILED); + return; + } + + hostname_ = PickBestHostname(hostnames); + if (hostname_->hostname.empty()) { + VLOG(2) << __func__ << " : got empty hostnames list for " << region; + UpdateAndNotifyConnectionStateChange(ConnectionState::CONNECT_FAILED); + return; + } + + VLOG(2) << __func__ << " : Picked " << hostname_->hostname << ", " + << hostname_->display_name << ", " << hostname_->is_offline << ", " + << hostname_->capacity_score; + + if (skus_credential_.empty()) { + VLOG(2) << __func__ << " : skus_credential is empty"; + UpdateAndNotifyConnectionStateChange(ConnectionState::CONNECT_FAILED); + return; + } + + // Get subscriber credentials and then get EAP credentials with it to create + // OS VPN entry. + VLOG(2) << __func__ << " : request subsriber credential"; + GetSubscriberCredentialV12( + base::BindOnce(&BraveVpnServiceDesktop::OnGetSubscriberCredential, + base::Unretained(this)), + GetBraveVPNPaymentsEnv(), skus_credential_); } void BraveVpnServiceDesktop::SetPurchasedState(PurchasedState state) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); if (purchased_state_ == state) return; @@ -535,4 +805,105 @@ void BraveVpnServiceDesktop::SetPurchasedState(PurchasedState state) { for (const auto& obs : observers_) obs->OnPurchasedStateChanged(purchased_state_); + + ScheduleFetchRegionDataIfNeeded(); +} + +void BraveVpnServiceDesktop::OnSkusVPNCredentialUpdated() { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + LoadPurchasedState(); +} + +void BraveVpnServiceDesktop::OnGetSubscriberCredential( + const std::string& subscriber_credential, + bool success) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (cancel_connecting_) { + UpdateAndNotifyConnectionStateChange(ConnectionState::DISCONNECTED); + cancel_connecting_ = false; + return; + } + + if (!success) { + VLOG(2) << __func__ << " : failed to get subscriber credential"; + UpdateAndNotifyConnectionStateChange(ConnectionState::CONNECT_FAILED); + return; + } + + VLOG(2) << __func__ << " : received subscriber credential"; + + GetProfileCredentials( + base::BindOnce(&BraveVpnServiceDesktop::OnGetProfileCredentials, + base::Unretained(this)), + subscriber_credential, hostname_->hostname); +} + +void BraveVpnServiceDesktop::OnGetProfileCredentials( + const std::string& profile_credential, + bool success) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + if (cancel_connecting_) { + UpdateAndNotifyConnectionStateChange(ConnectionState::DISCONNECTED); + cancel_connecting_ = false; + return; + } + + if (!success) { + VLOG(2) << __func__ << " : failed to get profile credential"; + UpdateAndNotifyConnectionStateChange(ConnectionState::CONNECT_FAILED); + return; + } + + VLOG(2) << __func__ << " : received profile credential"; + + absl::optional value = + base::JSONReader::Read(profile_credential); + if (value && value->is_dict()) { + constexpr char kUsernameKey[] = "eap-username"; + constexpr char kPasswordKey[] = "eap-password"; + const std::string* username = value->FindStringKey(kUsernameKey); + const std::string* password = value->FindStringKey(kPasswordKey); + if (!username || !password) { + VLOG(2) << __func__ << " : it's invalid profile credential"; + UpdateAndNotifyConnectionStateChange(ConnectionState::CONNECT_FAILED); + return; + } + + connection_info_.SetConnectionInfo(kBraveVPNEntryName, hostname_->hostname, + *username, *password); + // Let's create os vpn entry with |connection_info_|. + CreateVPNConnection(); + return; + } + + VLOG(2) << __func__ << " : it's invalid profile credential"; + UpdateAndNotifyConnectionStateChange(ConnectionState::CONNECT_FAILED); +} + +std::unique_ptr BraveVpnServiceDesktop::PickBestHostname( + const std::vector& hostnames) { + DCHECK_CALLED_ON_VALID_SEQUENCE(sequence_checker_); + std::vector filtered_hostnames; + std::copy_if( + hostnames.begin(), hostnames.end(), + std::back_inserter(filtered_hostnames), + [](const brave_vpn::Hostname& hostname) { return !hostname.is_offline; }); + + std::sort(filtered_hostnames.begin(), filtered_hostnames.end(), + [](const brave_vpn::Hostname& a, const brave_vpn::Hostname& b) { + return a.capacity_score > b.capacity_score; + }); + + if (filtered_hostnames.empty()) + return std::make_unique(); + + // Pick highest capacity score. + return std::make_unique(filtered_hostnames[0]); +} + +brave_vpn::BraveVPNOSConnectionAPI* +BraveVpnServiceDesktop::GetBraveVPNConnectionAPI() { + if (is_simulation_) + return brave_vpn::BraveVPNOSConnectionAPI::GetInstanceForTest(); + return brave_vpn::BraveVPNOSConnectionAPI::GetInstance(); } diff --git a/components/brave_vpn/brave_vpn_service_desktop.h b/components/brave_vpn/brave_vpn_service_desktop.h index bdf9682cb2aa..a5df2359d8ae 100644 --- a/components/brave_vpn/brave_vpn_service_desktop.h +++ b/components/brave_vpn/brave_vpn_service_desktop.h @@ -6,17 +6,19 @@ #ifndef BRAVE_COMPONENTS_BRAVE_VPN_BRAVE_VPN_SERVICE_DESKTOP_H_ #define BRAVE_COMPONENTS_BRAVE_VPN_BRAVE_VPN_SERVICE_DESKTOP_H_ +#include #include #include -#include "base/containers/flat_map.h" #include "base/scoped_observation.h" +#include "base/sequence_checker.h" #include "base/timer/timer.h" #include "brave/components/brave_vpn/brave_vpn.mojom.h" #include "brave/components/brave_vpn/brave_vpn_connection_info.h" #include "brave/components/brave_vpn/brave_vpn_data_types.h" #include "brave/components/brave_vpn/brave_vpn_os_connection_api.h" #include "brave/components/brave_vpn/brave_vpn_service.h" +#include "components/prefs/pref_change_registrar.h" #include "mojo/public/cpp/bindings/pending_remote.h" #include "mojo/public/cpp/bindings/receiver_set.h" #include "mojo/public/cpp/bindings/remote_set.h" @@ -43,21 +45,17 @@ class BraveVpnServiceDesktop BraveVpnServiceDesktop(const BraveVpnServiceDesktop&) = delete; BraveVpnServiceDesktop& operator=(const BraveVpnServiceDesktop&) = delete; + void ToggleConnection(); void RemoveVPNConnnection(); bool is_connected() const { return connection_state_ == ConnectionState::CONNECTED; } - bool is_purchased_user() const { return purchased_state_ == PurchasedState::PURCHASED; } - ConnectionState connection_state() const { return connection_state_; } - void CheckPurchasedStatus(); - void ToggleConnection(); - void BindInterface( mojo::PendingReceiver receiver); @@ -78,24 +76,33 @@ class BraveVpnServiceDesktop private: friend class BraveAppMenuBrowserTest; friend class BraveBrowserCommandControllerTest; - FRIEND_TEST_ALL_PREFIXES(BraveVPNTest, RegionDataTest); - FRIEND_TEST_ALL_PREFIXES(BraveVPNTest, HostnamesTest); - FRIEND_TEST_ALL_PREFIXES(BraveVPNTest, LoadRegionDataFromPrefsTest); + FRIEND_TEST_ALL_PREFIXES(BraveVPNServiceTest, RegionDataTest); + FRIEND_TEST_ALL_PREFIXES(BraveVPNServiceTest, HostnamesTest); + FRIEND_TEST_ALL_PREFIXES(BraveVPNServiceTest, CancelConnectingTest); + FRIEND_TEST_ALL_PREFIXES(BraveVPNServiceTest, ConnectionInfoTest); + FRIEND_TEST_ALL_PREFIXES(BraveVPNServiceTest, LoadPurchasedStateTest); + FRIEND_TEST_ALL_PREFIXES(BraveVPNServiceTest, LoadRegionDataFromPrefsTest); + FRIEND_TEST_ALL_PREFIXES(BraveVPNServiceTest, NeedsConnectTest); // BraveVpnService overrides: void Shutdown() override; // brave_vpn::BraveVPNOSConnectionAPI::Observer overrides: - void OnCreated(const std::string& name) override; - void OnRemoved(const std::string& name) override; - void OnConnected(const std::string& name) override; - void OnIsConnecting(const std::string& name) override; - void OnConnectFailed(const std::string& name) override; - void OnDisconnected(const std::string& name) override; - void OnIsDisconnecting(const std::string& name) override; + void OnCreated() override; + void OnCreateFailed() override; + void OnRemoved() override; + void OnConnected() override; + void OnIsConnecting() override; + void OnConnectFailed() override; + void OnDisconnected() override; + void OnIsDisconnecting() override; brave_vpn::BraveVPNConnectionInfo GetConnectionInfo(); void LoadCachedRegionData(); + void LoadPurchasedState(); + void LoadSelectedRegion(); + void UpdateAndNotifyConnectionStateChange(ConnectionState state); + void FetchRegionData(); void OnFetchRegionList(const std::string& region_list, bool success); bool ParseAndCacheRegionList(base::Value region_value); @@ -113,24 +120,46 @@ class BraveVpnServiceDesktop std::string GetCurrentTimeZone(); void SetPurchasedState(PurchasedState state); + void ScheduleFetchRegionDataIfNeeded(); + std::unique_ptr PickBestHostname( + const std::vector& hostnames); + + void OnSkusVPNCredentialUpdated(); + void OnGetSubscriberCredential(const std::string& subscriber_credential, + bool success); + void OnGetProfileCredentials(const std::string& profile_credential, + bool success); + + brave_vpn::BraveVPNOSConnectionAPI* GetBraveVPNConnectionAPI(); void set_test_timezone(const std::string& timezone) { test_timezone_ = timezone; } PrefService* prefs_ = nullptr; - base::flat_map> hostnames_; + PrefChangeRegistrar pref_change_registrar_; + std::string skus_credential_; std::vector regions_; brave_vpn::mojom::Region device_region_; - ConnectionState connection_state_ = ConnectionState::DISCONNECTED; + brave_vpn::mojom::Region selected_region_; + std::unique_ptr hostname_; + brave_vpn::BraveVPNConnectionInfo connection_info_; + bool cancel_connecting_ = false; PurchasedState purchased_state_ = PurchasedState::NOT_PURCHASED; + ConnectionState connection_state_ = ConnectionState::DISCONNECTED; + bool needs_connect_ = false; base::ScopedObservation observed_{this}; mojo::ReceiverSet receivers_; mojo::RemoteSet observers_; base::RepeatingTimer region_data_update_timer_; + + // Only for testing. std::string test_timezone_; + bool is_simulation_ = false; + + SEQUENCE_CHECKER(sequence_checker_); }; #endif // BRAVE_COMPONENTS_BRAVE_VPN_BRAVE_VPN_SERVICE_DESKTOP_H_ diff --git a/components/brave_vpn/brave_vpn_unittest.cc b/components/brave_vpn/brave_vpn_unittest.cc index 83b8f337c601..57ba76a8f068 100644 --- a/components/brave_vpn/brave_vpn_unittest.cc +++ b/components/brave_vpn/brave_vpn_unittest.cc @@ -12,18 +12,23 @@ #include "brave/components/brave_vpn/brave_vpn_utils.h" #include "brave/components/brave_vpn/features.h" #include "brave/components/brave_vpn/pref_names.h" +#include "brave/components/skus/browser/pref_names.h" +#include "brave/components/skus/browser/skus_sdk_impl.h" +#include "components/prefs/pref_registry_simple.h" #include "components/prefs/testing_pref_service.h" +#include "components/sync_preferences/testing_pref_service_syncable.h" #include "content/public/test/browser_task_environment.h" #include "services/network/test/test_shared_url_loader_factory.h" #include "testing/gtest/include/gtest/gtest.h" -class BraveVPNTest : public testing::Test { +class BraveVPNServiceTest : public testing::Test { public: - BraveVPNTest() { + BraveVPNServiceTest() { scoped_feature_list_.InitAndEnableFeature(brave_vpn::features::kBraveVPN); } void SetUp() override { + brave_rewards::SkusSdkImpl::RegisterProfilePrefs(pref_service_.registry()); brave_vpn::prefs::RegisterProfilePrefs(pref_service_.registry()); service_ = std::make_unique( base::MakeRefCounted(), @@ -178,9 +183,18 @@ class BraveVPNTest : public testing::Test { ])"; } + std::string GetProfileCredentialData() { + return R"( + { + "eap-username": "brave-user", + "eap-password": "brave-pwd" + } + )"; + } + base::test::ScopedFeatureList scoped_feature_list_; content::BrowserTaskEnvironment task_environment_; - TestingPrefServiceSimple pref_service_; + sync_preferences::TestingPrefServiceSyncable pref_service_; std::unique_ptr service_; }; @@ -188,7 +202,7 @@ TEST(BraveVPNFeatureTest, FeatureTest) { EXPECT_FALSE(brave_vpn::IsBraveVPNEnabled()); } -TEST_F(BraveVPNTest, RegionDataTest) { +TEST_F(BraveVPNServiceTest, RegionDataTest) { // Test invalid region data. service_->OnFetchRegionList(std::string(), true); EXPECT_TRUE(service_->regions_.empty()); @@ -216,17 +230,116 @@ TEST_F(BraveVPNTest, RegionDataTest) { EXPECT_EQ(service_->regions_[0], service_->device_region_); } -TEST_F(BraveVPNTest, HostnamesTest) { +TEST_F(BraveVPNServiceTest, HostnamesTest) { // Set valid hostnames list + service_->hostname_.reset(); service_->OnFetchHostnames("region-a", GetHostnamesData(), true); - EXPECT_EQ(5UL, service_->hostnames_["region-a"].size()); + // Check best one is picked from fetched hostname list. + EXPECT_EQ("host-2.brave.com", service_->hostname_->hostname); - // Set invalid hostnames list + // Can't get hostname from invalid hostnames list + service_->hostname_.reset(); service_->OnFetchHostnames("invalid-region-b", "", false); - EXPECT_EQ(0UL, service_->hostnames_["invalid-region-b"].size()); + EXPECT_FALSE(service_->hostname_); +} + +TEST_F(BraveVPNServiceTest, LoadPurchasedStateTest) { + EXPECT_EQ(PurchasedState::NOT_PURCHASED, service_->purchased_state_); + pref_service_.SetString(brave_rewards::prefs::kSkusVPNCredential, "abcdefg"); + EXPECT_EQ(PurchasedState::PURCHASED, service_->purchased_state_); +} + +TEST_F(BraveVPNServiceTest, CancelConnectingTest) { + service_->connection_state_ = ConnectionState::CONNECTING; + service_->cancel_connecting_ = true; + service_->connection_state_ = ConnectionState::CONNECTING; + service_->OnCreated(); + EXPECT_FALSE(service_->cancel_connecting_); + EXPECT_EQ(ConnectionState::DISCONNECTED, service_->connection_state_); + + // Start disconnect() when connect is done for cancelling. + service_->cancel_connecting_ = false; + service_->connection_state_ = ConnectionState::CONNECTING; + service_->Disconnect(); + EXPECT_TRUE(service_->cancel_connecting_); + EXPECT_EQ(ConnectionState::DISCONNECTING, service_->connection_state_); + service_->OnConnected(); + EXPECT_FALSE(service_->cancel_connecting_); + EXPECT_EQ(ConnectionState::DISCONNECTING, service_->connection_state_); + + service_->cancel_connecting_ = false; + service_->connection_state_ = ConnectionState::CONNECTING; + service_->Disconnect(); + EXPECT_TRUE(service_->cancel_connecting_); + EXPECT_EQ(ConnectionState::DISCONNECTING, service_->connection_state_); + + service_->cancel_connecting_ = true; + service_->CreateVPNConnection(); + EXPECT_FALSE(service_->cancel_connecting_); + EXPECT_EQ(ConnectionState::DISCONNECTED, service_->connection_state_); + + service_->cancel_connecting_ = true; + service_->connection_state_ = ConnectionState::CONNECTING; + service_->OnFetchHostnames("", "", true); + EXPECT_FALSE(service_->cancel_connecting_); + EXPECT_EQ(ConnectionState::DISCONNECTED, service_->connection_state_); + + service_->cancel_connecting_ = true; + service_->connection_state_ = ConnectionState::CONNECTING; + service_->OnGetSubscriberCredential("", true); + EXPECT_FALSE(service_->cancel_connecting_); + EXPECT_EQ(ConnectionState::DISCONNECTED, service_->connection_state_); + + service_->cancel_connecting_ = true; + service_->connection_state_ = ConnectionState::CONNECTING; + service_->OnGetProfileCredentials("", true); + EXPECT_FALSE(service_->cancel_connecting_); + EXPECT_EQ(ConnectionState::DISCONNECTED, service_->connection_state_); +} + +TEST_F(BraveVPNServiceTest, ConnectionInfoTest) { + // Check valid connection info is set when valid hostname and profile + // credential are fetched. + service_->connection_state_ = ConnectionState::CONNECTING; + pref_service_.SetString(brave_rewards::prefs::kSkusVPNCredential, "abcdefg"); + service_->OnFetchHostnames("region-a", GetHostnamesData(), true); + EXPECT_EQ(ConnectionState::CONNECTING, service_->connection_state_); + + // To prevent real os vpn entry creation. + service_->is_simulation_ = true; + service_->OnGetProfileCredentials(GetProfileCredentialData(), true); + EXPECT_EQ(ConnectionState::CONNECTING, service_->connection_state_); + EXPECT_TRUE(service_->connection_info_.IsValid()); + + // Check cached connection info is cleared when user set new selected region. + service_->connection_state_ = ConnectionState::DISCONNECTED; + brave_vpn::mojom::Region region; + service_->SetSelectedRegion(region.Clone()); + EXPECT_FALSE(service_->connection_info_.IsValid()); +} + +TEST_F(BraveVPNServiceTest, NeedsConnectTest) { + // Check ignore Connect() request while connecting or disconnecting is + // in-progress. + service_->connection_state_ = ConnectionState::CONNECTING; + service_->Connect(); + EXPECT_EQ(ConnectionState::CONNECTING, service_->connection_state_); + + service_->connection_state_ = ConnectionState::DISCONNECTING; + service_->Connect(); + EXPECT_EQ(ConnectionState::DISCONNECTING, service_->connection_state_); + + // Handle connect after disconnect current connection. + service_->connection_state_ = ConnectionState::CONNECTED; + service_->Connect(); + EXPECT_TRUE(service_->needs_connect_); + EXPECT_EQ(ConnectionState::DISCONNECTING, service_->connection_state_); + service_->OnDisconnected(); + EXPECT_FALSE(service_->needs_connect_); + EXPECT_EQ(ConnectionState::CONNECTING, service_->connection_state_); } -TEST_F(BraveVPNTest, LoadRegionDataFromPrefsTest) { +TEST_F(BraveVPNServiceTest, LoadRegionDataFromPrefsTest) { // Initially, prefs doesn't have region data. EXPECT_EQ(brave_vpn::mojom::Region(), service_->device_region_); EXPECT_TRUE(service_->regions_.empty()); diff --git a/components/brave_vpn/resources/panel/vpn_panel.tsx b/components/brave_vpn/resources/panel/vpn_panel.tsx index 37f198ccf966..92e306244147 100644 --- a/components/brave_vpn/resources/panel/vpn_panel.tsx +++ b/components/brave_vpn/resources/panel/vpn_panel.tsx @@ -58,7 +58,6 @@ function initialize () { render(, document.getElementById('mountPoint'), () => { getPanelBrowserAPI().panelHandler.showUI() - getPanelBrowserAPI().serviceHandler.createVPNConnection() }) } diff --git a/components/brave_vpn/switches.h b/components/brave_vpn/switches.h index 0ea39fff6a1d..3d10d7110b56 100644 --- a/components/brave_vpn/switches.h +++ b/components/brave_vpn/switches.h @@ -10,12 +10,12 @@ namespace brave_vpn { namespace switches { -// Value should be "connection-name:host-name:user-name:password". -constexpr char kBraveVPNTestCredentials[] = "brave-vpn-test-credentials"; // Use for simulation instead of calling os platform apis. constexpr char kBraveVPNSimulation[] = "brave-vpn-simulate"; // Use "prod", "staging" or "dev" constexpr char kBraveVPNAccountHost[] = "brave-vpn-account-host"; +constexpr char kBraveVPNPaymentsEnv[] = "brave-vpn-payments-env"; +constexpr char kBraveVPNTestMonthlyPass[] = "brave-vpn-test-monthly-pass"; } // namespace switches diff --git a/components/brave_vpn/utils_win.cc b/components/brave_vpn/utils_win.cc index ea8d477d11ae..6560d6b00845 100644 --- a/components/brave_vpn/utils_win.cc +++ b/components/brave_vpn/utils_win.cc @@ -26,6 +26,7 @@ HANDLE g_disconnecting_event_handle = NULL; void WINAPI RasDialFunc(UINT, RASCONNSTATE rasconnstate, DWORD error) { if (error) { + SetEvent(g_connect_failed_event_handle); internal::PrintRasError(error); return; } @@ -198,12 +199,15 @@ bool DisconnectEntry(const std::wstring& entry_name) { // If successful, print the names of the active connections. if (ERROR_SUCCESS == dw_ret) { - DVLOG(2) << "The following RAS connections are currently active:"; + VLOG(2) << __func__ + << " : The following RAS connections are currently active:" + << dw_connections; for (DWORD i = 0; i < dw_connections; i++) { std::wstring name(lp_ras_conn[i].szEntryName); std::wstring type(lp_ras_conn[i].szDeviceType); + VLOG(2) << __func__ << " : " << name << ", " << type; if (name.compare(entry_name) == 0 && type.compare(L"VPN") == 0) { - DVLOG(2) << "Disconnect... " << entry_name; + VLOG(2) << __func__ << " : Disconnect... " << entry_name; SetEvent(g_disconnecting_event_handle); dw_ret = RasHangUpA(lp_ras_conn[i].hrasconn); break; @@ -223,7 +227,7 @@ bool DisconnectEntry(const std::wstring& entry_name) { return false; } - DVLOG(2) << "There are no active RAS connections."; + VLOG(2) << "There are no active RAS connections."; return true; } @@ -260,7 +264,7 @@ bool ConnectEntry(const std::wstring& entry_name) { wcscpy_s(lp_ras_dial_params->szUserName, UNLEN + 1, credentials.szUserName); wcscpy_s(lp_ras_dial_params->szPassword, PWLEN + 1, credentials.szPassword); - DVLOG(2) << "Connecting to " << entry_name; + VLOG(2) << __func__ << " : Connecting to " << entry_name; HRASCONN h_ras_conn = NULL; dw_ret = RasDial(NULL, DEFAULT_PHONE_BOOK, lp_ras_dial_params, 0, (LPVOID)(&RasDialFunc), &h_ras_conn); @@ -270,7 +274,7 @@ bool ConnectEntry(const std::wstring& entry_name) { SetEvent(g_connect_failed_event_handle); return false; } - DVLOG(2) << "SUCCESS!"; + VLOG(2) << "SUCCESS!"; HeapFree(GetProcessHeap(), 0, (LPVOID)lp_ras_dial_params);