Skip to content

Commit 8bf3194

Browse files
committed
solution update of talaia-labs#199
1 parent 47da3c8 commit 8bf3194

File tree

4 files changed

+70
-70
lines changed

4 files changed

+70
-70
lines changed

watchtower-plugin/src/convert.rs

Lines changed: 18 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -200,7 +200,7 @@ impl TryFrom<serde_json::Value> for GetAppointmentParams {
200200
if param_count != 2 {
201201
Err(GetAppointmentError::InvalidFormat(format!(
202202
"Unexpected request format. The request needs 2 parameter. Received: {param_count}"
203-
)))
203+
)))
204204
} else {
205205
let tower_id = if let Some(s) = a.get(0).unwrap().as_str() {
206206
TowerId::from_str(s).map_err(|_| {
@@ -289,7 +289,9 @@ impl TryFrom<serde_json::Value> for GetRegistrationReceiptParams {
289289
match value {
290290
serde_json::Value::Array(a) => {
291291
let param_count = a.len();
292-
if param_count != 1 && param_count != 3 {
292+
if param_count == 2{
293+
Err(GetRegistrationReceiptError::InvalidFormat(("Both ends of boundary (subscription_start and subscription_expiry) are required.").to_string()))
294+
} else if param_count != 1 && param_count != 3 {
293295
Err(GetRegistrationReceiptError::InvalidFormat(format!(
294296
"Unexpected request format. The request needs 1 or 3 parameter. Received: {param_count}"
295297
)))
@@ -303,33 +305,23 @@ impl TryFrom<serde_json::Value> for GetRegistrationReceiptParams {
303305
"tower_id must be a hex encoded string".to_owned(),
304306
))
305307
}?;
306-
let subscription_start = if let Some(start) = a.get(1).and_then(|v| v.as_i64()) {
307-
if start >= 0 {
308-
Some(start as u32)
309-
} else {
310-
return Err(GetRegistrationReceiptError::InvalidFormat(
311-
"Subscription-start must be a positive integer".to_owned(),
312-
));
313-
}
314-
} else {
315-
None
316-
};
317-
let subscription_expiry = if let Some(expire) = a.get(2).and_then(|v| v.as_i64()) {
318-
if expire > subscription_start.unwrap() as i64 {
319-
Some(expire as u32)
308+
309+
let (subscription_start, subscription_expiry) = if let (Some(start), Some(expire)) = (a.get(1).and_then(|v| v.as_i64()), a.get(2).and_then(|v| v.as_i64())) {
310+
if start >= 0 && expire > start {
311+
(Some(start as u32), Some(expire as u32))
320312
} else {
321313
return Err(GetRegistrationReceiptError::InvalidFormat(
322-
"Subscription-expire must be a positive integer and greater than subscription_start".to_owned(),
323-
));
314+
"Subscription_start must be a positive integer and subscription_expire must be a positive integer greater than subscription_start".to_owned(),
315+
));
324316
}
317+
} else if a.get(1).is_some() || a.get(2).is_some() {
318+
return Err(GetRegistrationReceiptError::InvalidFormat(
319+
"Subscription_start and subscription_expiry must be provided together as positive integers".to_owned(),
320+
));
325321
} else {
326-
None
322+
(None, None)
327323
};
328-
if subscription_start.is_some() != subscription_expiry.is_some() {
329-
return Err(GetRegistrationReceiptError::InvalidFormat(
330-
"Subscription-start and subscription-expiry must be provided together".to_owned(),
331-
));
332-
}
324+
333325
Ok(Self {
334326
tower_id,
335327
subscription_start,
@@ -354,11 +346,12 @@ impl TryFrom<serde_json::Value> for GetRegistrationReceiptParams {
354346
params.push(v);
355347
}
356348
}
349+
357350
GetRegistrationReceiptParams::try_from(json!(params))
358351
}
359352
},
360353
_ => Err(GetRegistrationReceiptError::InvalidFormat(format!(
361-
"Unexpected request format. Expected: tower_id and optional arguments subscription_start & subscription_expire. Received: '{value}'"
354+
"Unexpected request format. Expected: tower_id [subscription_start] [subscription_expire]. Received: '{value}'"
362355
))),
363356
}
364357
}

watchtower-plugin/src/dbm.rs

Lines changed: 40 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::iter::FromIterator;
33
use std::path::PathBuf;
44
use std::str::FromStr;
55

6-
use rusqlite::{params, Connection, Error as SqliteError};
6+
use rusqlite::{params, Connection, Error as SqliteError, ToSql};
77

88
use bitcoin::secp256k1::SecretKey;
99

@@ -218,28 +218,37 @@ impl DBM {
218218
user_id: UserId,
219219
subscription_start: Option<u32>,
220220
subscription_expiry: Option<u32>,
221-
) -> Option<RegistrationReceipt> {
222-
let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1 AND (subscription_start >=?2 OR ?2 is NULL) AND (subscription_expiry <=?3 OR ?3 is NULL)".to_string();
221+
) -> Vec<RegistrationReceipt> {
222+
let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1".to_string();
223223

224-
if subscription_expiry == None {
225-
query.push_str(" OR subscription_expiry = (SELECT MAX(subscription_expiry) FROM registration_receipts WHERE tower_id = ?1)")
226-
};
224+
let tower_id_encoded = tower_id.to_vec();
225+
let mut params: Vec<&dyn ToSql> = vec![&tower_id_encoded];
226+
227+
if subscription_expiry.is_none() {
228+
query.push_str(" AND subscription_expiry = (SELECT MAX(subscription_expiry) FROM registration_receipts WHERE tower_id = ?1)")
229+
} else {
230+
query.push_str(" AND subscription_start>=?2 AND subscription_expiry <=?3");
231+
params.push(&subscription_start);
232+
params.push(&subscription_expiry)
233+
}
227234
let mut stmt = self.connection.prepare(&query).unwrap();
228235

229-
stmt.query_row(
230-
params![tower_id.to_vec(), subscription_start, subscription_expiry],
231-
|row| {
232-
let slots: u32 = row.get(0).unwrap();
233-
let start: u32 = row.get(1).unwrap();
234-
let expiry: u32 = row.get(2).unwrap();
235-
let signature: String = row.get(3).unwrap();
236+
let receipts = stmt
237+
.query_map(params.as_slice(), |row| {
238+
let slots: u32 = row.get(0)?;
239+
let start: u32 = row.get(1)?;
240+
let expiry: u32 = row.get(2)?;
241+
let signature: String = row.get(3)?;
236242

237243
Ok(RegistrationReceipt::with_signature(
238244
user_id, slots, start, expiry, signature,
239245
))
240-
},
241-
)
242-
.ok()
246+
})
247+
.unwrap()
248+
.collect::<Result<Vec<_>, _>>()
249+
.unwrap_or_default();
250+
251+
receipts
243252
}
244253

245254
/// Removes a tower record from the database.
@@ -650,8 +659,8 @@ mod tests {
650659

651660
use teos_common::cryptography::get_random_keypair;
652661
use teos_common::test_utils::{
653-
generate_random_appointment, get_random_int, get_random_registration_receipt,
654-
get_random_user_id, get_registration_receipt_from_previous,
662+
generate_random_appointment, get_random_registration_receipt, get_random_user_id,
663+
get_registration_receipt_from_previous,
655664
};
656665

657666
impl DBM {
@@ -738,39 +747,34 @@ mod tests {
738747
receipt.user_id(),
739748
subscription_start,
740749
subscription_expiry
741-
)
742-
.unwrap(),
750+
)[0],
743751
receipt
744752
);
745753

746-
// Add another receipt for the same tower with a higher expiry and check this last one is loaded
754+
// Add another receipt for the same tower with a higher expiry and check that output gives vector of both receipts
747755
let middle_receipt = get_registration_receipt_from_previous(&receipt);
748756
let latest_receipt = get_registration_receipt_from_previous(&middle_receipt);
749757

758+
let latest_subscription_expiry = Some(latest_receipt.subscription_expiry());
759+
750760
dbm.store_tower_record(tower_id, net_addr, &latest_receipt)
751761
.unwrap();
752762
assert_eq!(
753763
dbm.load_registration_receipt(
754764
tower_id,
755765
latest_receipt.user_id(),
756766
subscription_start,
757-
subscription_expiry
758-
)
759-
.unwrap(),
760-
latest_receipt
767+
latest_subscription_expiry
768+
),
769+
vec![receipt, latest_receipt.clone()]
761770
);
762771

763-
// Add a final one with a lower expiry and check the last is still loaded
772+
// Add a final one with a lower expiry and check if the lastest receipt is loaded when boundry
773+
// params are not passed
764774
dbm.store_tower_record(tower_id, net_addr, &middle_receipt)
765775
.unwrap();
766776
assert_eq!(
767-
dbm.load_registration_receipt(
768-
tower_id,
769-
latest_receipt.user_id(),
770-
subscription_start,
771-
subscription_expiry
772-
)
773-
.unwrap(),
777+
dbm.load_registration_receipt(tower_id, latest_receipt.user_id(), None, None)[0],
774778
latest_receipt
775779
);
776780
}
@@ -783,8 +787,8 @@ mod tests {
783787
let tower_id = get_random_user_id();
784788
let net_addr = "talaia.watch";
785789
let receipt = get_random_registration_receipt();
786-
let subscription_start = get_random_int();
787-
let subscription_expiry = get_random_int();
790+
let subscription_start = Some(receipt.subscription_start());
791+
let subscription_expiry = Some(receipt.subscription_expiry());
788792

789793
// Store it once
790794
dbm.store_tower_record(tower_id, net_addr, &receipt)
@@ -795,8 +799,7 @@ mod tests {
795799
receipt.user_id(),
796800
subscription_start,
797801
subscription_expiry
798-
)
799-
.unwrap(),
802+
)[0],
800803
receipt
801804
);
802805

watchtower-plugin/src/main.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -141,14 +141,18 @@ async fn get_registration_receipt(
141141
let subscription_expiry = params.subscription_expiry;
142142
let state = plugin.state().lock().unwrap();
143143

144-
if let Some(response) =
145-
state.get_registration_receipt(tower_id, subscription_start, subscription_expiry)
146-
{
147-
Ok(json!(response))
144+
let response =
145+
state.get_registration_receipt(tower_id, subscription_start, subscription_expiry);
146+
if response.is_empty() {
147+
if state.towers.contains_key(&tower_id) {
148+
Err(anyhow!("No registration receipt found for {tower_id}"))
149+
} else {
150+
Err(anyhow!(
151+
"Cannot find {tower_id} within the known towers. Have you registered ?"
152+
))
153+
}
148154
} else {
149-
Err(anyhow!(
150-
"Cannot find {tower_id} within the known towers. Have you registered?"
151-
))
155+
Ok(json!(response))
152156
}
153157
}
154158

watchtower-plugin/src/wt_client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -185,7 +185,7 @@ impl WTClient {
185185
tower_id: TowerId,
186186
subscription_start: Option<u32>,
187187
subscription_expiry: Option<u32>,
188-
) -> Option<RegistrationReceipt> {
188+
) -> Vec<RegistrationReceipt> {
189189
self.dbm.load_registration_receipt(
190190
tower_id,
191191
self.user_id,

0 commit comments

Comments
 (0)