@@ -3,6 +3,7 @@ use cached::proc_macro::cached;
3
3
use futures_lite:: future:: block_on;
4
4
use futures_lite:: { future:: Boxed , FutureExt } ;
5
5
use hyper:: client:: HttpConnector ;
6
+ use hyper:: header:: HeaderValue ;
6
7
use hyper:: { body, body:: Buf , client, header, Body , Client , Method , Request , Response , Uri } ;
7
8
use hyper_rustls:: HttpsConnector ;
8
9
use libflate:: gzip;
@@ -21,6 +22,7 @@ use crate::server::RequestExt;
21
22
use crate :: utils:: format_url;
22
23
23
24
const REDDIT_URL_BASE : & str = "https://oauth.reddit.com" ;
25
+ const ALTERNATIVE_REDDIT_URL_BASE : & str = "https://www.reddit.com" ;
24
26
25
27
pub static CLIENT : Lazy < Client < HttpsConnector < HttpConnector > > > = Lazy :: new ( || {
26
28
let https = hyper_rustls:: HttpsConnectorBuilder :: new ( )
@@ -221,12 +223,13 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
221
223
if !redirect {
222
224
return Ok ( response) ;
223
225
} ;
224
-
226
+ let location_header = response. headers ( ) . get ( header:: LOCATION ) ;
227
+ if location_header == Some ( & HeaderValue :: from_static ( "https://www.reddit.com/" ) ) {
228
+ return Err ( "Reddit response was invalid" . to_string ( ) ) ;
229
+ }
225
230
return request (
226
231
method,
227
- response
228
- . headers ( )
229
- . get ( header:: LOCATION )
232
+ location_header
230
233
. map ( |val| {
231
234
// We need to make adjustments to the URI
232
235
// we get back from Reddit. Namely, we
@@ -239,7 +242,11 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
239
242
// required.
240
243
//
241
244
// 2. Percent-encode the path.
242
- let new_path = percent_encode ( val. as_bytes ( ) , CONTROLS ) . to_string ( ) . trim_start_matches ( REDDIT_URL_BASE ) . to_string ( ) ;
245
+ let new_path = percent_encode ( val. as_bytes ( ) , CONTROLS )
246
+ . to_string ( )
247
+ . trim_start_matches ( REDDIT_URL_BASE )
248
+ . trim_start_matches ( ALTERNATIVE_REDDIT_URL_BASE )
249
+ . to_string ( ) ;
243
250
format ! ( "{new_path}{}raw_json=1" , if new_path. contains( '?' ) { "&" } else { "?" } )
244
251
} )
245
252
. unwrap_or_default ( )
@@ -298,7 +305,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
298
305
}
299
306
}
300
307
Err ( e) => {
301
- dbg_msg ! ( "{} {} : {}" , method , path , e) ;
308
+ dbg_msg ! ( "{method } {REDDIT_URL_BASE}{path} : {}" , e) ;
302
309
303
310
Err ( e. to_string ( ) )
304
311
}
@@ -312,6 +319,7 @@ fn request(method: &'static Method, path: String, redirect: bool, quarantine: bo
312
319
// Make a request to a Reddit API and parse the JSON response
313
320
#[ cached( size = 100 , time = 30 , result = true ) ]
314
321
pub async fn json ( path : String , quarantine : bool ) -> Result < Value , String > {
322
+ trace ! ( "going to get {path}" ) ;
315
323
// Closure to quickly build errors
316
324
let err = |msg : & str , e : String , path : String | -> Result < Value , String > {
317
325
// eprintln!("{} - {}: {}", url, msg, e);
0 commit comments