Skip to content

Commit

Permalink
Improved vpn connect/disconnect logic
Browse files Browse the repository at this point in the history
Whenever connect is asked, vpn service tries to create os vpn entry
with latest hostname and user credentials from guardian service.

fix brave/brave-browser#18648
fix brave/brave-browser#18422
  • Loading branch information
simonhong committed Oct 26, 2021
1 parent 15d9d3b commit ef04ff3
Show file tree
Hide file tree
Showing 15 changed files with 756 additions and 177 deletions.
3 changes: 3 additions & 0 deletions components/brave_vpn/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ static_library("brave_vpn") {
":brave_vpn_internal",
":mojom",
"//brave/components/resources:strings",
"//brave/components/skus/browser",
"//components/prefs",
"//third_party/icu",
"//ui/base",
Expand Down Expand Up @@ -112,7 +113,9 @@ source_set("unit_tests") {
deps = [
":brave_vpn",
"//base",
"//brave/components/skus/browser",
"//components/prefs:test_support",
"//components/sync_preferences:test_support",
"//content/test:test_support",
"//services/network:test_support",
"//testing/gtest",
Expand Down
7 changes: 7 additions & 0 deletions components/brave_vpn/brave_vpn_connection_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,13 @@ BraveVPNConnectionInfo::BraveVPNConnectionInfo(
BraveVPNConnectionInfo& BraveVPNConnectionInfo::operator=(
const BraveVPNConnectionInfo& info) = default;

void BraveVPNConnectionInfo::Reset() {
connection_name_.clear();
hostname_.clear();
username_.clear();
password_.clear();
}

bool BraveVPNConnectionInfo::IsValid() const {
// TODO(simonhong): Improve credentials validation.
return !hostname_.empty() && !username_.empty() && !password_.empty();
Expand Down
1 change: 1 addition & 0 deletions components/brave_vpn/brave_vpn_connection_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class BraveVPNConnectionInfo {
BraveVPNConnectionInfo(const BraveVPNConnectionInfo& info);
BraveVPNConnectionInfo& operator=(const BraveVPNConnectionInfo& info);

void Reset();
bool IsValid() const;
void SetConnectionInfo(const std::string& connection_name,
const std::string& hostname,
Expand Down
17 changes: 8 additions & 9 deletions components/brave_vpn/brave_vpn_os_connection_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,14 @@ class BraveVPNOSConnectionAPI {
public:
class Observer : public base::CheckedObserver {
public:
// TODO(simonhong): Don't need |name| parameter because only one vpn
// connection is managed.
virtual void OnCreated(const std::string& name) = 0;
virtual void OnRemoved(const std::string& name) = 0;
virtual void OnConnected(const std::string& name) = 0;
virtual void OnIsConnecting(const std::string& name) = 0;
virtual void OnConnectFailed(const std::string& name) = 0;
virtual void OnDisconnected(const std::string& name) = 0;
virtual void OnIsDisconnecting(const std::string& name) = 0;
virtual void OnCreated() = 0;
virtual void OnCreateFailed() = 0;
virtual void OnRemoved() = 0;
virtual void OnConnected() = 0;
virtual void OnIsConnecting() = 0;
virtual void OnConnectFailed() = 0;
virtual void OnDisconnected() = 0;
virtual void OnIsDisconnecting() = 0;

protected:
~Observer() override = default;
Expand Down
37 changes: 30 additions & 7 deletions components/brave_vpn/brave_vpn_os_connection_api_mac.mm
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,26 @@

const NSString* kBraveVPNKey = @"BraveVPNKey";

std::string NEVPNStatusToString(NEVPNStatus status) {
switch (status) {
case NEVPNStatusInvalid:
return "NEVPNStatusInvalid";
case NEVPNStatusDisconnected:
return "NEVPNStatusDisconnected";
case NEVPNStatusConnecting:
return "NEVPNStatusConnecting";
case NEVPNStatusConnected:
return "NEVPNStatusConnected";
case NEVPNStatusReasserting:
return "NEVPNStatusReasserting";
case NEVPNStatusDisconnecting:
return "NEVPNStatusDisconnecting";
default:
NOTREACHED();
}
return std::string();
}

NSData* GetPasswordRefForAccount() {
NSString* bundle_id = [[NSBundle mainBundle] bundleIdentifier];
CFTypeRef copy_result = NULL;
Expand Down Expand Up @@ -172,11 +192,14 @@ OSStatus StorePassword(const NSString* password) {
if (error) {
LOG(ERROR) << "Create - saveToPrefs error: "
<< base::SysNSStringToUTF8([error localizedDescription]);
for (Observer& obs : observers_)
obs.OnCreateFailed();

return;
}
VLOG(2) << "Create - saveToPrefs success";
for (Observer& obs : observers_)
obs.OnCreated(std::string());
obs.OnCreated();
}];
}];
}
Expand All @@ -201,7 +224,7 @@ OSStatus StorePassword(const NSString* password) {
}
VLOG(2) << "RemoveVPNConnection - successfully removed";
for (Observer& obs : observers_)
obs.OnRemoved(std::string());
obs.OnRemoved();
}];
}
RemoveKeychainItemForAccount();
Expand Down Expand Up @@ -264,25 +287,25 @@ OSStatus StorePassword(const NSString* password) {
}

NEVPNStatus current_status = [[vpn_manager connection] status];
VLOG(2) << "CheckConnection: " << current_status;
VLOG(2) << "CheckConnection: " << NEVPNStatusToString(current_status);
switch (current_status) {
case NEVPNStatusConnected:
for (Observer& obs : observers_)
obs.OnConnected(name);
obs.OnConnected();
break;
case NEVPNStatusConnecting:
case NEVPNStatusReasserting:
for (Observer& obs : observers_)
obs.OnIsConnecting(name);
obs.OnIsConnecting();
break;
case NEVPNStatusDisconnected:
case NEVPNStatusInvalid:
for (Observer& obs : observers_)
obs.OnDisconnected(name);
obs.OnDisconnected();
break;
case NEVPNStatusDisconnecting:
for (Observer& obs : observers_)
obs.OnIsDisconnecting(name);
obs.OnIsDisconnecting();
break;
default:
break;
Expand Down
12 changes: 6 additions & 6 deletions components/brave_vpn/brave_vpn_os_connection_api_sim.cc
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ void BraveVPNOSConnectionAPISim::OnCreated(const std::string& name,
return;

for (Observer& obs : observers_)
obs.OnCreated(name);
obs.OnCreated();
}

void BraveVPNOSConnectionAPISim::OnConnected(const std::string& name,
Expand All @@ -104,12 +104,12 @@ void BraveVPNOSConnectionAPISim::OnConnected(const std::string& name,
}

for (Observer& obs : observers_)
success ? obs.OnConnected(name) : obs.OnConnectFailed(name);
success ? obs.OnConnected() : obs.OnConnectFailed();
}

void BraveVPNOSConnectionAPISim::OnIsConnecting(const std::string& name) {
for (Observer& obs : observers_)
obs.OnIsConnecting(name);
obs.OnIsConnecting();
}

void BraveVPNOSConnectionAPISim::OnDisconnected(const std::string& name,
Expand All @@ -118,12 +118,12 @@ void BraveVPNOSConnectionAPISim::OnDisconnected(const std::string& name,
return;

for (Observer& obs : observers_)
obs.OnDisconnected(name);
obs.OnDisconnected();
}

void BraveVPNOSConnectionAPISim::OnIsDisconnecting(const std::string& name) {
for (Observer& obs : observers_)
obs.OnIsDisconnecting(name);
obs.OnIsDisconnecting();
}

void BraveVPNOSConnectionAPISim::OnRemoved(const std::string& name,
Expand All @@ -132,7 +132,7 @@ void BraveVPNOSConnectionAPISim::OnRemoved(const std::string& name,
return;

for (Observer& obs : observers_)
obs.OnRemoved(name);
obs.OnRemoved();
}

} // namespace brave_vpn
19 changes: 11 additions & 8 deletions components/brave_vpn/brave_vpn_os_connection_api_win.cc
Original file line number Diff line number Diff line change
Expand Up @@ -136,19 +136,19 @@ void BraveVPNOSConnectionAPIWin::OnCheckConnection(
for (Observer& obs : observers_) {
switch (result) {
case CheckConnectionResult::CONNECTED:
obs.OnConnected(name);
obs.OnConnected();
break;
case CheckConnectionResult::CONNECTING:
obs.OnIsConnecting(name);
obs.OnIsConnecting();
break;
case CheckConnectionResult::CONNECT_FAILED:
obs.OnConnectFailed(name);
obs.OnConnectFailed();
break;
case CheckConnectionResult::DISCONNECTED:
obs.OnDisconnected(name);
obs.OnDisconnected();
break;
case CheckConnectionResult::DISCONNECTING:
obs.OnIsDisconnecting(name);
obs.OnIsDisconnecting();
break;
default:
break;
Expand All @@ -158,11 +158,14 @@ void BraveVPNOSConnectionAPIWin::OnCheckConnection(

void BraveVPNOSConnectionAPIWin::OnCreated(const std::string& name,
bool success) {
if (!success)
if (!success) {
for (Observer& obs : observers_)
obs.OnCreateFailed();
return;
}

for (Observer& obs : observers_)
obs.OnCreated(name);
obs.OnCreated();
}

void BraveVPNOSConnectionAPIWin::OnRemoved(const std::string& name,
Expand All @@ -171,7 +174,7 @@ void BraveVPNOSConnectionAPIWin::OnRemoved(const std::string& name,
return;

for (Observer& obs : observers_)
obs.OnRemoved(name);
obs.OnRemoved();
}

void BraveVPNOSConnectionAPIWin::StartVPNConnectionChangeMonitoring() {
Expand Down
32 changes: 27 additions & 5 deletions components/brave_vpn/brave_vpn_service.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ constexpr char kCreateSubscriberCredential[] =
"api/v1/subscriber-credential/create";
constexpr char kProfileCredential[] = "api/v1.1/register-and-create";
constexpr char kVerifyPurchaseToken[] = "api/v1.1/verify-purchase-token";
constexpr char kCreateSubscriberCredentialV12[] =
"api/v1.2/subscriber-credential/create";

net::NetworkTrafficAnnotationTag GetNetworkTrafficAnnotationTag() {
return net::DefineNetworkTrafficAnnotation("brave_vpn_service", R"(
Expand Down Expand Up @@ -81,12 +83,14 @@ BraveVpnService::~BraveVpnService() = default;

void BraveVpnService::Shutdown() {}

void BraveVpnService::OAuthRequest(const GURL& url,
const std::string& method,
const std::string& post_data,
URLRequestCallback callback) {
void BraveVpnService::OAuthRequest(
const GURL& url,
const std::string& method,
const std::string& post_data,
URLRequestCallback callback,
const base::flat_map<std::string, std::string>& headers) {
api_request_helper_.Request(method, url, post_data, "application/json", false,
std::move(callback));
std::move(callback), headers);
}

void BraveVpnService::GetAllServerRegions(ResponseCallback callback) {
Expand Down Expand Up @@ -196,3 +200,21 @@ void BraveVpnService::OnGetSubscriberCredential(
}
std::move(callback).Run(subscriber_credential, success);
}

void BraveVpnService::GetSubscriberCredentialV12(
ResponseCallback callback,
const std::string& payments_environment,
const std::string& monthly_pass) {
auto internal_callback =
base::BindOnce(&BraveVpnService::OnGetSubscriberCredential,
weak_ptr_factory_.GetWeakPtr(), std::move(callback));

const GURL base_url =
GetURLWithPath(kVpnHost, kCreateSubscriberCredentialV12);
base::Value dict(base::Value::Type::DICTIONARY);
dict.SetStringKey("validation-method", "brave-premium");
dict.SetStringKey("brave-vpn-premium-monthly-pass", monthly_pass);
std::string request_body = CreateJSONRequestBody(dict);
OAuthRequest(base_url, "POST", request_body, std::move(internal_callback),
{{"Brave-Payments-Environment", payments_environment}});
}
13 changes: 9 additions & 4 deletions components/brave_vpn/brave_vpn_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,22 @@ class BraveVpnService : public KeyedService {
const std::string& product_id,
const std::string& product_type,
const std::string& bundle_id);
void GetSubscriberCredentialV12(ResponseCallback callback,
const std::string& payments_environment,
const std::string& monthly_pass);

private:
using URLRequestCallback =
base::OnceCallback<void(int,
const std::string&,
const base::flat_map<std::string, std::string>&)>;

void OAuthRequest(const GURL& url,
const std::string& method,
const std::string& post_data,
URLRequestCallback callback);
void OAuthRequest(
const GURL& url,
const std::string& method,
const std::string& post_data,
URLRequestCallback callback,
const base::flat_map<std::string, std::string>& headers = {});

void OnGetResponse(ResponseCallback callback,
int status,
Expand Down
Loading

0 comments on commit ef04ff3

Please sign in to comment.