Skip to content

Commit

Permalink
imdsclient: return options for all IMDS fetch functions
Browse files Browse the repository at this point in the history
This commit changes the return type for the IMDS fetch functions to
options returning None when a 404 is encountered. This makes it easier
for the caller to decide what to do when a 404 occurs and is closer to
the behavior before imdsclient was written.

`fetch_identity_document` and its associated `IdentityDocument` struct
were removed in favor of separate `fetch_region` and
`fetch_instance_type` functions. Instance-type is now fetched via IMDS
meta-data as opposed to the identity document. We are returning to the
original behavior as it is possible that the identity document can be
absent in certain situations.

Changes were also made to shibaken, pluto, and early-boot-config to
accommodate the options detailed above.
  • Loading branch information
jpculp committed May 26, 2021
1 parent 0f9993e commit fde8f46
Show file tree
Hide file tree
Showing 6 changed files with 99 additions and 74 deletions.
1 change: 0 additions & 1 deletion sources/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 9 additions & 4 deletions sources/api/early-boot-config/src/provider/aws.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,10 @@ impl AwsDataProvider {
/// Fetches user data, which is expected to be in TOML form and contain a `[settings]` section,
/// returning a SettingsJson representing the inside of that section.
async fn user_data(client: &mut ImdsClient) -> Result<Option<SettingsJson>> {
let user_data_raw = client.fetch_userdata().await.context(error::ImdsRequest)?;
let user_data_raw = match client.fetch_userdata().await.context(error::ImdsRequest)? {
Some(user_data_raw) => user_data_raw,
None => return Ok(None),
};
let user_data_str = expand_slice_maybe(&user_data_raw)
.context(error::Decompression { what: "user data" })?;
trace!("Received user data: {}", user_data_str);
Expand Down Expand Up @@ -55,11 +58,10 @@ impl AwsDataProvider {
.to_owned()
} else {
client
.fetch_identity_document()
.fetch_region()
.await
.context(error::ImdsRequest)?
.region()
.to_owned()
.context(error::ImdsMissingRegion)?
};
trace!(
"Retrieved region from instance identity document: {}",
Expand Down Expand Up @@ -133,6 +135,9 @@ mod error {
#[snafu(display("Unable to read input file '{}': {}", path.display(), source))]
InputFileRead { path: PathBuf, source: io::Error },

#[snafu(display("IMDS request failed: missing region"))]
ImdsMissingRegion {},

#[snafu(display("IMDS request failed: {}", source))]
ImdsRequest { source: imdsclient::Error },

Expand Down
16 changes: 12 additions & 4 deletions sources/api/pluto/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,11 +117,12 @@ type Result<T> = std::result::Result<T, PlutoError>;

async fn get_max_pods(client: &mut ImdsClient) -> Result<String> {
let instance_type = client
.fetch_identity_document()
.fetch_instance_type()
.await
.context(error::ImdsRequest)?
.instance_type()
.to_string();
.context(error::ImdsNone {
what: "instance_type",
})?;

// Find the corresponding maximum number of pods supported by this instance type
let file = BufReader::new(
Expand Down Expand Up @@ -208,6 +209,9 @@ async fn get_cluster_dns_from_imds_mac(client: &mut ImdsClient) -> Result<String
.fetch_mac_addresses()
.await
.context(error::ImdsRequest)?
.context(error::ImdsNone {
what: "mac addresses",
})?
.first()
.context(error::ImdsNone {
what: "mac addresses",
Expand All @@ -219,6 +223,9 @@ async fn get_cluster_dns_from_imds_mac(client: &mut ImdsClient) -> Result<String
.fetch_cidr_blocks_for_mac(&mac)
.await
.context(error::ImdsRequest)?
.context(error::ImdsNone {
what: "CIDR blocks",
})?
.first()
.context(error::ImdsNone {
what: "CIDR blocks",
Expand All @@ -239,7 +246,8 @@ async fn get_node_ip(client: &mut ImdsClient) -> Result<String> {
client
.fetch_local_ipv4_address()
.await
.context(error::ImdsRequest)
.context(error::ImdsRequest)?
.context(error::ImdsNone { what: "node ip" })
}

/// Print usage message.
Expand Down
6 changes: 4 additions & 2 deletions sources/api/shibaken/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,12 @@ impl UserData {
async fn fetch_public_keys_from_imds() -> Result<Vec<String>> {
info!("Connecting to IMDS");
let mut client = ImdsClient::new().await.context(error::ImdsClient)?;
client
let public_keys = client
.fetch_public_ssh_keys()
.await
.context(error::ImdsClient)
.context(error::ImdsClient)?
.unwrap_or_else(Vec::new);
Ok(public_keys)
}

/// Store the args we receive on the command line.
Expand Down
1 change: 0 additions & 1 deletion sources/imdsclient/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ exclude = ["README.md"]
http = "0.2"
log = "0.4"
reqwest = { version = "0.11.1", default-features = false }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1"
simplelog = "0.10"
snafu = "0.6"
Expand Down
136 changes: 74 additions & 62 deletions sources/imdsclient/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@ This method is useful for specifying things like a pinned date for the IMDS sche
use http::StatusCode;
use log::{debug, info, trace, warn};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use snafu::{ensure, ResultExt};
use serde_json::Value;
use snafu::{ensure, OptionExt, ResultExt};
use std::time::Duration;
use tokio::time;

Expand All @@ -30,25 +30,6 @@ pub struct ImdsClient {
session_token: String,
}

/// This is the return type when querying for the IMDS identity document, which contains information
/// such as region and instance_type. We only include the fields that we are using in Bottlerocket.
#[derive(Clone, Eq, PartialEq, Ord, PartialOrd, Serialize, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct IdentityDocument {
region: String,
instance_type: String,
}

impl IdentityDocument {
pub fn region(&self) -> &str {
self.region.as_str()
}

pub fn instance_type(&self) -> &str {
self.instance_type.as_str()
}
}

impl ImdsClient {
pub async fn new() -> Result<Self> {
Self::new_impl(BASE_URI.to_string()).await
Expand All @@ -65,57 +46,74 @@ impl ImdsClient {
}

/// Gets `user-data` from IMDS. The user-data may be either a UTF-8 string or compressed bytes.
pub async fn fetch_userdata(&mut self) -> Result<Vec<u8>> {
pub async fn fetch_userdata(&mut self) -> Result<Option<Vec<u8>>> {
self.fetch_imds(PINNED_SCHEMA, "user-data").await
}

/// Returns the 'identity document' with fields like region and instance_type.
pub async fn fetch_identity_document(&mut self) -> Result<IdentityDocument> {
/// Returns the region described in the identity document.
pub async fn fetch_region(&mut self) -> Result<Option<String>> {
let target = "dynamic/instance-identity/document";
let response = self.fetch_bytes(target).await?;
let identity_document: IdentityDocument =
serde_json::from_slice(&response).context(error::Serde)?;
Ok(identity_document)
let response = match self.fetch_bytes(target).await? {
Some(response) => response,
None => return Ok(None),
};
let identity_document: Value = serde_json::from_slice(&response).context(error::Serde)?;
let region = identity_document
.get("region")
.and_then(|value| value.as_str())
.map(|region| region.to_string());
Ok(region)
}

/// Returns the list of network interface mac addresses.
pub async fn fetch_mac_addresses(&mut self) -> Result<Vec<String>> {
pub async fn fetch_mac_addresses(&mut self) -> Result<Option<Vec<String>>> {
let macs_target = "meta-data/network/interfaces/macs";
let macs = self.fetch_string(&macs_target).await?;
Ok(macs.split('\n').map(|s| s.to_string()).collect())
let macs = self
.fetch_string(&macs_target)
.await?
.map(|macs| macs.lines().map(|s| s.to_string()).collect());
Ok(macs)
}

/// Gets the list of CIDR blocks for a given network interface `mac` address.
pub async fn fetch_cidr_blocks_for_mac(&mut self, mac: &str) -> Result<Vec<String>> {
pub async fn fetch_cidr_blocks_for_mac(&mut self, mac: &str) -> Result<Option<Vec<String>>> {
// Infer the cluster DNS based on our CIDR blocks.
let mac_cidr_blocks_target = format!(
"meta-data/network/interfaces/macs/{}/vpc-ipv4-cidr-blocks",
mac
);
let cidr_blocks = self.fetch_string(&mac_cidr_blocks_target).await?;
Ok(cidr_blocks.split('\n').map(|s| s.to_string()).collect())
let cidr_blocks = self
.fetch_string(&mac_cidr_blocks_target)
.await?
.map(|cidr_blocks| cidr_blocks.lines().map(|s| s.to_string()).collect());
Ok(cidr_blocks)
}

/// Gets the local IPV4 address from instance metadata.
pub async fn fetch_local_ipv4_address(&mut self) -> Result<String> {
pub async fn fetch_local_ipv4_address(&mut self) -> Result<Option<String>> {
let node_ip_target = "meta-data/local-ipv4";
self.fetch_string(&node_ip_target).await
}

/// Gets the instance-type from instance metadata.
pub async fn fetch_instance_type(&mut self) -> Result<Option<String>> {
let instance_type_target = "meta-data/instance-type";
self.fetch_string(&instance_type_target).await
}

/// Returns a list of public ssh keys skipping any keys that do not start with 'ssh'.
pub async fn fetch_public_ssh_keys(&mut self) -> Result<Vec<String>> {
pub async fn fetch_public_ssh_keys(&mut self) -> Result<Option<Vec<String>>> {
info!("Fetching list of available public keys from IMDS");
// Returns a list of available public keys as '0=my-public-key'
let public_key_list = match self.fetch_string("meta-data/public-keys").await {
Err(error::Error::NotFound { uri: _ }) => {
// this is OK, it just means there are no keys
debug!("no available public keys");
return Ok(Vec::new());
let public_key_list = match self.fetch_string("meta-data/public-keys").await? {
Some(public_key_list) => {
debug!("available public keys '{}'", &public_key_list);
public_key_list
}
Err(e) => {
return Err(e);
None => {
debug!("no available public keys");
return Ok(None);
}
Ok(value) => value,
};

debug!("available public keys '{}'", &public_key_list);
Expand All @@ -132,7 +130,10 @@ impl ImdsClient {
&public_key_targets.len()
);

let public_key_text = self.fetch_string(&target).await?;
let public_key_text = self
.fetch_string(&target)
.await?
.context(error::KeyNotFound { target })?;
let public_key = public_key_text.trim_end();
// Simple check to see if the text is probably an ssh key.
if public_key.starts_with("ssh") {
Expand All @@ -149,28 +150,36 @@ impl ImdsClient {
if public_keys.is_empty() {
warn!("No valid keys found");
}
Ok(public_keys)
Ok(Some(public_keys))
}

/// Helper to fetch bytes from IMDS using the pinned schema version.
async fn fetch_bytes<S>(&mut self, end_target: S) -> Result<Vec<u8>>
async fn fetch_bytes<S>(&mut self, end_target: S) -> Result<Option<Vec<u8>>>
where
S: AsRef<str>,
{
self.fetch_imds(PINNED_SCHEMA, end_target.as_ref()).await
}

/// Helper to fetch a string from IMDS using the pinned schema version.
async fn fetch_string<S>(&mut self, end_target: S) -> Result<String>
async fn fetch_string<S>(&mut self, end_target: S) -> Result<Option<String>>
where
S: AsRef<str>,
{
let response_body = self.fetch_imds(PINNED_SCHEMA, end_target).await?;
Ok(String::from_utf8(response_body).context(error::NonUtf8Response)?)
match self.fetch_imds(PINNED_SCHEMA, end_target).await? {
Some(response_body) => Ok(Some(
String::from_utf8(response_body).context(error::NonUtf8Response)?,
)),
None => Ok(None),
}
}

/// Fetch data from IMDS.
async fn fetch_imds<S1, S2>(&mut self, schema_version: S1, target: S2) -> Result<Vec<u8>>
async fn fetch_imds<S1, S2>(
&mut self,
schema_version: S1,
target: S2,
) -> Result<Option<Vec<u8>>>
where
S1: AsRef<str>,
S2: AsRef<str>,
Expand Down Expand Up @@ -218,11 +227,11 @@ impl ImdsClient {
let response_str = printable_string(&response_body);
trace!("Response: {:?}", response_str);

return Ok(response_body);
return Ok(Some(response_body));
}

// IMDS returns 404 if no user data is given, or if IMDS is disabled
StatusCode::NOT_FOUND => return Err(error::Error::NotFound { uri }),
StatusCode::NOT_FOUND => return Ok(None),

// IMDS returns 401 if the session token is expired or invalid
StatusCode::UNAUTHORIZED => {
Expand Down Expand Up @@ -358,12 +367,12 @@ mod error {
#[snafu(display("IMDS session failed: {}", source))]
FailedSession { source: reqwest::Error },

#[snafu(display("Error retrieving key from {}", target))]
KeyNotFound { target: String },

#[snafu(display("Response was not UTF-8: {}", source))]
NonUtf8Response { source: std::string::FromUtf8Error },

#[snafu(display("404 file not found fetching '{}'", uri))]
NotFound { uri: String },

#[snafu(display("Error {}ing '{}': {}", method, uri, source))]
Request {
method: String,
Expand Down Expand Up @@ -461,7 +470,7 @@ mod test {
.fetch_imds(schema_version, target)
.await
.unwrap();
assert_eq!(imds_data, response_body.as_bytes().to_vec());
assert_eq!(imds_data, Some(response_body.as_bytes().to_vec()));
}

#[tokio::test]
Expand Down Expand Up @@ -493,8 +502,11 @@ mod test {
),
);
let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap();
let result = imds_client.fetch_imds(schema_version, target).await;
assert!(matches!(result, Err(error::Error::NotFound { .. })));
let imds_data = imds_client
.fetch_imds(schema_version, target)
.await
.unwrap();
assert_eq!(imds_data, None);
}

#[tokio::test]
Expand Down Expand Up @@ -599,7 +611,7 @@ mod test {
);
let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap();
let imds_data = imds_client.fetch_string(end_target).await.unwrap();
assert_eq!(imds_data, response_body.to_string());
assert_eq!(imds_data, Some(response_body.to_string()));
}

#[tokio::test]
Expand Down Expand Up @@ -634,7 +646,7 @@ mod test {
);
let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap();
let imds_data = imds_client.fetch_bytes(end_target).await.unwrap();
assert_eq!(imds_data, response_body.as_bytes().to_vec());
assert_eq!(imds_data, Some(response_body.as_bytes().to_vec()));
}

#[tokio::test]
Expand Down Expand Up @@ -668,7 +680,7 @@ mod test {
);
let mut imds_client = ImdsClient::new_impl(base_uri).await.unwrap();
let imds_data = imds_client.fetch_userdata().await.unwrap();
assert_eq!(imds_data, response_body.as_bytes().to_vec());
assert_eq!(imds_data, Some(response_body.as_bytes().to_vec()));
}

#[test]
Expand Down

0 comments on commit fde8f46

Please sign in to comment.