Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ability to set request specific redirect policy #1204

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 12 additions & 5 deletions src/async_impl/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1088,7 +1088,7 @@ impl Client {
}

pub(super) fn execute_request(&self, req: Request) -> Pending {
let (method, url, mut headers, body, timeout) = req.pieces();
let (method, url, mut headers, body, timeout, redirect_policy) = req.pieces();
if url.scheme() != "http" && url.scheme() != "https" {
return Pending::new_err(error::url_bad_scheme(url));
}
Expand Down Expand Up @@ -1165,6 +1165,7 @@ impl Client {

in_flight,
timeout,
redirect_policy,
}),
}
}
Expand Down Expand Up @@ -1361,6 +1362,7 @@ pin_project! {
in_flight: ResponseFuture,
#[pin]
timeout: Option<Pin<Box<Sleep>>>,
redirect_policy: Option<redirect::Policy>,
}
}

Expand Down Expand Up @@ -1503,10 +1505,15 @@ impl Future for PendingRequest {
}
let url = self.url.clone();
self.as_mut().urls().push(url);
let action = self
.client
.redirect_policy
.check(res.status(), &loc, &self.urls);

// Request specific redirect policy takes precedence
// over client redirect policy
let policy = match &self.redirect_policy {
Some(p) => p,
None => &self.client.redirect_policy,
};

let action = policy.check(res.status(), &loc, &self.urls);

match action {
redirect::ActionKind::Follow => {
Expand Down
49 changes: 46 additions & 3 deletions src/async_impl/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use super::response::Response;
#[cfg(feature = "multipart")]
use crate::header::CONTENT_LENGTH;
use crate::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE};
use crate::{Method, Url};
use crate::{redirect, Method, Url};
use http::{request::Parts, Request as HttpRequest};

/// A request which can be executed with `Client::execute()`.
Expand All @@ -27,6 +27,7 @@ pub struct Request {
headers: HeaderMap,
body: Option<Body>,
timeout: Option<Duration>,
redirect_policy: Option<redirect::Policy>,
}

/// A builder to construct the properties of a `Request`.
Expand All @@ -48,6 +49,7 @@ impl Request {
headers: HeaderMap::new(),
body: None,
timeout: None,
redirect_policy: None,
}
}

Expand Down Expand Up @@ -111,6 +113,18 @@ impl Request {
&mut self.timeout
}

/// Get the redirect policy.
#[inline]
pub fn redirect_policy(&self) -> Option<&redirect::Policy> {
self.redirect_policy.as_ref()
}

/// Get a mutable reference to the redirect policy.
#[inline]
pub fn redirect_policy_mut(&mut self) -> &mut Option<redirect::Policy> {
&mut self.redirect_policy
}

/// Attempt to clone the request.
///
/// `None` is returned if the request can not be cloned, i.e. if the body is a stream.
Expand All @@ -122,12 +136,29 @@ impl Request {
let mut req = Request::new(self.method().clone(), self.url().clone());
*req.timeout_mut() = self.timeout().cloned();
*req.headers_mut() = self.headers().clone();
*req.redirect_policy_mut() = self.redirect_policy().cloned();
req.body = body;
Some(req)
}

pub(super) fn pieces(self) -> (Method, Url, HeaderMap, Option<Body>, Option<Duration>) {
(self.method, self.url, self.headers, self.body, self.timeout)
pub(super) fn pieces(
self,
) -> (
Method,
Url,
HeaderMap,
Option<Body>,
Option<Duration>,
Option<redirect::Policy>,
) {
(
self.method,
self.url,
self.headers,
self.body,
self.timeout,
self.redirect_policy,
)
}
}

Expand Down Expand Up @@ -244,6 +275,17 @@ impl RequestBuilder {
self
}

/// Enables a request specific redirect policy.
///
/// It affects only this request and overrides
/// the request policy configured using `ClientBuilder::request()`.
pub fn redirect(mut self, policy: redirect::Policy) -> RequestBuilder {
if let Ok(ref mut req) = self.request {
*req.redirect_policy_mut() = Some(policy);
}
self
}

/// Sends a multipart/form-data body.
///
/// ```
Expand Down Expand Up @@ -525,6 +567,7 @@ where
headers,
body: Some(body.into()),
timeout: None,
redirect_policy: None,
})
}
}
Expand Down
25 changes: 24 additions & 1 deletion src/blocking/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use super::body::{self, Body};
use super::multipart;
use super::Client;
use crate::header::{HeaderMap, HeaderName, HeaderValue, CONTENT_TYPE};
use crate::{async_impl, Method, Url};
use crate::{async_impl, redirect, Method, Url};

/// A request which can be executed with `Client::execute()`.
pub struct Request {
Expand Down Expand Up @@ -102,6 +102,18 @@ impl Request {
self.inner.timeout_mut()
}

/// Get the redirect policy.
#[inline]
pub fn redirect_policy(&self) -> Option<&redirect::Policy> {
self.inner.redirect_policy()
}

/// Get a mutable reference to the redirect policy.
#[inline]
pub fn redirect_policy_mut(&mut self) -> &mut Option<redirect::Policy> {
self.inner.redirect_policy_mut()
}

/// Attempts to clone the `Request`.
///
/// None is returned if a body is which can not be cloned. This can be because the body is a
Expand Down Expand Up @@ -342,6 +354,17 @@ impl RequestBuilder {
self
}

/// Enables a request specific redirect policy.
///
/// It affects only this request and overrides
/// the request policy configured using `ClientBuilder::request()`.
pub fn redirect(mut self, policy: redirect::Policy) -> RequestBuilder {
if let Ok(ref mut req) = self.request {
*req.redirect_policy_mut() = Some(policy);
}
self
}

/// Modify the query string of the URL.
///
/// Modifies the URL of this request, adding the parameters provided.
Expand Down
11 changes: 11 additions & 0 deletions src/redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ use crate::Url;
/// the allowed maximum redirect hops in a chain.
/// - `none` can be used to disable all redirect behavior.
/// - `custom` can be used to create a customized policy.
#[derive(Clone)]
pub struct Policy {
inner: PolicyKind,
}
Expand Down Expand Up @@ -209,6 +210,16 @@ enum PolicyKind {
None,
}

impl Clone for PolicyKind {
fn clone(&self) -> Self {
match self {
c @ PolicyKind::Custom(_) => c.clone(),
l @ PolicyKind::Limit(_) => l.clone(),
PolicyKind::None => PolicyKind::None,
}
}
}

impl fmt::Debug for Policy {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
f.debug_tuple("Policy").field(&self.inner).finish()
Expand Down
61 changes: 61 additions & 0 deletions tests/redirect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,3 +316,64 @@ async fn test_redirect_302_with_set_cookies() {
assert_eq!(res.url().as_str(), dst);
assert_eq!(res.status(), reqwest::StatusCode::OK);
}

#[tokio::test]
async fn test_request_redirect() {
let code = 301u16;

let redirect = server::http(move |req| async move {
if req.method() == "POST" {
assert_eq!(req.uri(), &*format!("/{}", code));
http::Response::builder()
.status(code)
.header("location", "/dst")
.header("server", "test-redirect")
.body(Default::default())
.unwrap()
} else {
assert_eq!(req.method(), "GET");

http::Response::builder()
.header("server", "test-dst")
.body(Default::default())
.unwrap()
}
});

let url = format!("http://{}/{}", redirect.addr(), code);
let dst = format!("http://{}/{}", redirect.addr(), "dst");

let default_redirect_client = reqwest::Client::new();
let res = default_redirect_client
.request(reqwest::Method::POST, &url)
.redirect(reqwest::redirect::Policy::none())
.send()
.await
.unwrap();

assert_eq!(res.url().as_str(), url);
assert_eq!(res.status(), reqwest::StatusCode::MOVED_PERMANENTLY);
assert_eq!(
res.headers().get(reqwest::header::SERVER).unwrap(),
&"test-redirect"
);

let no_redirect_client = reqwest::Client::builder()
.redirect(reqwest::redirect::Policy::none())
.build()
.unwrap();

let res = no_redirect_client
.request(reqwest::Method::POST, &url)
.redirect(reqwest::redirect::Policy::limited(2))
.send()
.await
.unwrap();

assert_eq!(res.url().as_str(), dst);
assert_eq!(res.status(), reqwest::StatusCode::OK);
assert_eq!(
res.headers().get(reqwest::header::SERVER).unwrap(),
&"test-dst"
);
}