Skip to content

Commit f6790e8

Browse files
committed
fix(client): Handle invalid reddit response of base URL location
1 parent ea87ec3 commit f6790e8

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

src/client.rs

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use cached::proc_macro::cached;
33
use futures_lite::future::block_on;
44
use futures_lite::{future::Boxed, FutureExt};
55
use hyper::client::HttpConnector;
6+
use hyper::header::HeaderValue;
67
use hyper::{body, body::Buf, client, header, Body, Client, Method, Request, Response, Uri};
78
use hyper_rustls::HttpsConnector;
89
use libflate::gzip;
@@ -21,6 +22,7 @@ use crate::server::RequestExt;
2122
use crate::utils::format_url;
2223

2324
const REDDIT_URL_BASE: &str = "https://oauth.reddit.com";
25+
const ALTERNATIVE_REDDIT_URL_BASE: &str = "https://www.reddit.com";
2426

2527
pub static CLIENT: Lazy<Client<HttpsConnector<HttpConnector>>> = Lazy::new(|| {
2628
let https = hyper_rustls::HttpsConnectorBuilder::new()
@@ -221,12 +223,13 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
221223
if !redirect {
222224
return Ok(response);
223225
};
224-
226+
let location_header = response.headers().get(header::LOCATION);
227+
if location_header == Some(&HeaderValue::from_static("https://www.reddit.com/")) {
228+
return Err("Reddit response was invalid".to_string());
229+
}
225230
return request(
226231
method,
227-
response
228-
.headers()
229-
.get(header::LOCATION)
232+
location_header
230233
.map(|val| {
231234
// We need to make adjustments to the URI
232235
// we get back from Reddit. Namely, we
@@ -239,7 +242,11 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
239242
// required.
240243
//
241244
// 2. Percent-encode the path.
242-
let new_path = percent_encode(val.as_bytes(), CONTROLS).to_string().trim_start_matches(REDDIT_URL_BASE).to_string();
245+
let new_path = percent_encode(val.as_bytes(), CONTROLS)
246+
.to_string()
247+
.trim_start_matches(REDDIT_URL_BASE)
248+
.trim_start_matches(ALTERNATIVE_REDDIT_URL_BASE)
249+
.to_string();
243250
format!("{new_path}{}raw_json=1", if new_path.contains('?') { "&" } else { "?" })
244251
})
245252
.unwrap_or_default()
@@ -298,7 +305,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
298305
}
299306
}
300307
Err(e) => {
301-
dbg_msg!("{} {}: {}", method, path, e);
308+
dbg_msg!("{method} {REDDIT_URL_BASE}{path}: {}", e);
302309

303310
Err(e.to_string())
304311
}
@@ -312,6 +319,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
312319
// Make a request to a Reddit API and parse the JSON response
313320
#[cached(size = 100, time = 30, result = true)]
314321
pub async fn json(path: String, quarantine: bool) -> Result<Value, String> {
322+
trace!("going to get {path}");
315323
// Closure to quickly build errors
316324
let err = |msg: &str, e: String, path: String| -> Result<Value, String> {
317325
// eprintln!("{} - {}: {}", url, msg, e);

0 commit comments

Comments
 (0)