Skip to content

Commit fc794fb

Browse files
committed
Add optional boundaries to getregistrationreceipt
1 parent 886e0ff commit fc794fb

File tree

4 files changed

+207
-43
lines changed

4 files changed

+207
-43
lines changed

watchtower-plugin/src/convert.rs

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -258,6 +258,117 @@ impl TryFrom<serde_json::Value> for GetAppointmentParams {
258258
}
259259
}
260260

261+
// Errors related to `getregistrationreceipt` command
262+
#[derive(Debug)]
263+
pub enum GetRegistrationReceiptError {
264+
InvalidId(String),
265+
InvalidFormat(String),
266+
}
267+
268+
impl std::fmt::Display for GetRegistrationReceiptError {
269+
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
270+
match self {
271+
GetRegistrationReceiptError::InvalidId(x) => write!(f, "{x}"),
272+
GetRegistrationReceiptError::InvalidFormat(x) => write!(f, "{x}"),
273+
}
274+
}
275+
}
276+
277+
// Parameters related to the `getregistrationreceipt` command
278+
#[derive(Debug)]
279+
pub struct GetRegistrationReceiptParams {
280+
pub tower_id: TowerId,
281+
pub subscription_start: Option<u32>,
282+
pub subscription_expiry: Option<u32>,
283+
}
284+
285+
impl TryFrom<serde_json::Value> for GetRegistrationReceiptParams {
286+
type Error = GetRegistrationReceiptError;
287+
288+
fn try_from(value: serde_json::Value) -> Result<Self, Self::Error> {
289+
match value {
290+
serde_json::Value::Array(a) => {
291+
let param_count = a.len();
292+
if param_count == 2 {
293+
Err(GetRegistrationReceiptError::InvalidFormat((
294+
"Both ends of boundary (subscription_start and subscription_expiry) are required.").to_string()
295+
))
296+
} else if param_count != 1 && param_count != 3 {
297+
Err(GetRegistrationReceiptError::InvalidFormat(format!(
298+
"Unexpected request format. The request needs 1 or 3 parameter. Received: {param_count}"
299+
)))
300+
} else {
301+
let tower_id = if let Some(s) = a.get(0).unwrap().as_str() {
302+
TowerId::from_str(s).map_err(|_| {
303+
GetRegistrationReceiptError::InvalidId("Invalid tower id".to_owned())
304+
})
305+
} else {
306+
Err(GetRegistrationReceiptError::InvalidId(
307+
"tower_id must be a hex encoded string".to_owned(),
308+
))
309+
}?;
310+
311+
let (subscription_start, subscription_expiry) = if let (Some(start), Some(expire)) = (a.get(1), a.get(2)){
312+
let start = start.as_i64().ok_or_else(|| {
313+
GetRegistrationReceiptError::InvalidFormat(
314+
"Subscription_start must be a positive integer".to_owned(),
315+
)
316+
})?;
317+
318+
let expire = expire.as_i64().ok_or_else(|| {
319+
GetRegistrationReceiptError::InvalidFormat(
320+
"Subscription_expire must be a positive integer".to_owned(),
321+
)
322+
})?;
323+
324+
if start >= 0 && expire > start {
325+
(Some(start as u32), Some(expire as u32))
326+
} else {
327+
return Err(GetRegistrationReceiptError::InvalidFormat(
328+
"subscription_start must be a positive integer and subscription_expire must be a positive integer greater than subscription_start".to_owned(),
329+
));
330+
}
331+
} else {
332+
(None, None)
333+
};
334+
335+
Ok(
336+
Self {
337+
tower_id,
338+
subscription_start,
339+
subscription_expiry,
340+
}
341+
)
342+
}
343+
},
344+
serde_json::Value::Object(mut m) => {
345+
let allowed_keys = ["tower_id", "subscription_start", "subscription_expiry"];
346+
let param_count = m.len();
347+
348+
if m.is_empty() || param_count > allowed_keys.len() {
349+
Err(GetRegistrationReceiptError::InvalidFormat(format!("Unexpected request format. The request needs 1-3 parameters. Received: {param_count}")))
350+
} else if !m.contains_key(allowed_keys[0]){
351+
Err(GetRegistrationReceiptError::InvalidId(format!("{} is mandatory", allowed_keys[0])))
352+
} else if !m.iter().all(|(k, _)| allowed_keys.contains(&k.as_str())) {
353+
Err(GetRegistrationReceiptError::InvalidFormat("Invalid named parameter found in request".to_owned()))
354+
} else {
355+
let mut params = Vec::with_capacity(allowed_keys.len());
356+
for k in allowed_keys {
357+
if let Some(v) = m.remove(k) {
358+
params.push(v);
359+
}
360+
}
361+
362+
GetRegistrationReceiptParams::try_from(json!(params))
363+
}
364+
},
365+
_ => Err(GetRegistrationReceiptError::InvalidFormat(format!(
366+
"Unexpected request format. Expected: tower_id [subscription_start] [subscription_expire]. Received: '{value}'"
367+
))),
368+
}
369+
}
370+
}
371+
261372
/// Data associated with a commitment revocation. Represents the data sent by CoreLN through the `commitment_revocation` hook.
262373
#[derive(Debug, Serialize, Deserialize)]
263374
pub struct CommitmentRevocation {

watchtower-plugin/src/dbm.rs

Lines changed: 60 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ use std::iter::FromIterator;
33
use std::path::PathBuf;
44
use std::str::FromStr;
55

6-
use rusqlite::{params, Connection, Error as SqliteError};
6+
use rusqlite::{params, Connection, Error as SqliteError, ToSql};
77

88
use bitcoin::secp256k1::SecretKey;
99

@@ -209,36 +209,43 @@ impl DBM {
209209
Some(tower)
210210
}
211211

212-
/// Loads the latest registration receipt for a given tower.
213-
///
212+
/// Loads the registration receipt(s) for a given tower in the given subscription range.
213+
/// If no range is given, then loads the latest receipt
214214
/// Latests is determined by the one with the `subscription_expiry` further into the future.
215215
pub fn load_registration_receipt(
216216
&self,
217217
tower_id: TowerId,
218218
user_id: UserId,
219-
) -> Option<RegistrationReceipt> {
220-
let mut stmt = self
221-
.connection
222-
.prepare(
223-
"SELECT available_slots, subscription_start, subscription_expiry, signature
224-
FROM registration_receipts
225-
WHERE tower_id = ?1 AND subscription_expiry = (SELECT MAX(subscription_expiry)
226-
FROM registration_receipts
227-
WHERE tower_id = ?1)",
228-
)
229-
.unwrap();
219+
subscription_start: Option<u32>,
220+
subscription_expiry: Option<u32>,
221+
) -> Option<Vec<RegistrationReceipt>> {
222+
let mut query = "SELECT available_slots, subscription_start, subscription_expiry, signature FROM registration_receipts WHERE tower_id = ?1".to_string();
223+
224+
let tower_id_encoded = tower_id.to_vec();
225+
let mut params: Vec<&dyn ToSql> = vec![&tower_id_encoded];
226+
227+
if subscription_expiry.is_none() {
228+
query.push_str(" AND subscription_expiry = (SELECT MAX(subscription_expiry) FROM registration_receipts WHERE tower_id = ?1)")
229+
} else {
230+
query.push_str(" AND subscription_start>=?2 AND subscription_expiry <=?3");
231+
params.push(&subscription_start);
232+
params.push(&subscription_expiry)
233+
}
234+
let mut stmt = self.connection.prepare(&query).unwrap();
230235

231-
stmt.query_row([tower_id.to_vec()], |row| {
232-
let slots: u32 = row.get(0).unwrap();
233-
let start: u32 = row.get(1).unwrap();
234-
let expiry: u32 = row.get(2).unwrap();
235-
let signature: String = row.get(3).unwrap();
236+
stmt.query_map(params.as_slice(), |row| {
237+
let slots: u32 = row.get(0)?;
238+
let start: u32 = row.get(1)?;
239+
let expiry: u32 = row.get(2)?;
240+
let signature: String = row.get(3)?;
236241

237242
Ok(RegistrationReceipt::with_signature(
238243
user_id, slots, start, expiry, signature,
239244
))
240245
})
241-
.ok()
246+
.unwrap()
247+
.map(|r| r.ok())
248+
.collect()
242249
}
243250

244251
/// Removes a tower record from the database.
@@ -725,34 +732,49 @@ mod tests {
725732
let tower_id = get_random_user_id();
726733
let net_addr = "talaia.watch";
727734
let receipt = get_random_registration_receipt();
735+
let subscription_start = Some(receipt.subscription_start());
736+
let subscription_expiry = Some(receipt.subscription_expiry());
728737

729738
// Check the receipt was stored
730739
dbm.store_tower_record(tower_id, net_addr, &receipt)
731740
.unwrap();
732741
assert_eq!(
733-
dbm.load_registration_receipt(tower_id, receipt.user_id())
734-
.unwrap(),
742+
dbm.load_registration_receipt(
743+
tower_id,
744+
receipt.user_id(),
745+
subscription_start,
746+
subscription_expiry
747+
)
748+
.unwrap()[0],
735749
receipt
736750
);
737751

738-
// Add another receipt for the same tower with a higher expiry and check this last one is loaded
752+
// Add another receipt for the same tower with a higher expiry and check that output gives vector of both receipts
739753
let middle_receipt = get_registration_receipt_from_previous(&receipt);
740754
let latest_receipt = get_registration_receipt_from_previous(&middle_receipt);
741755

756+
let latest_subscription_expiry = Some(latest_receipt.subscription_expiry());
757+
742758
dbm.store_tower_record(tower_id, net_addr, &latest_receipt)
743759
.unwrap();
744760
assert_eq!(
745-
dbm.load_registration_receipt(tower_id, latest_receipt.user_id())
746-
.unwrap(),
747-
latest_receipt
761+
dbm.load_registration_receipt(
762+
tower_id,
763+
latest_receipt.user_id(),
764+
subscription_start,
765+
latest_subscription_expiry
766+
)
767+
.unwrap(),
768+
vec![receipt, latest_receipt.clone()]
748769
);
749770

750-
// Add a final one with a lower expiry and check the last is still loaded
771+
// Add a final one with a lower expiry and check if the lastest receipt is loaded when boundry
772+
// params are not passed
751773
dbm.store_tower_record(tower_id, net_addr, &middle_receipt)
752774
.unwrap();
753775
assert_eq!(
754-
dbm.load_registration_receipt(tower_id, latest_receipt.user_id())
755-
.unwrap(),
776+
dbm.load_registration_receipt(tower_id, latest_receipt.user_id(), None, None)
777+
.unwrap()[0],
756778
latest_receipt
757779
);
758780
}
@@ -765,13 +787,20 @@ mod tests {
765787
let tower_id = get_random_user_id();
766788
let net_addr = "talaia.watch";
767789
let receipt = get_random_registration_receipt();
790+
let subscription_start = Some(receipt.subscription_start());
791+
let subscription_expiry = Some(receipt.subscription_expiry());
768792

769793
// Store it once
770794
dbm.store_tower_record(tower_id, net_addr, &receipt)
771795
.unwrap();
772796
assert_eq!(
773-
dbm.load_registration_receipt(tower_id, receipt.user_id())
774-
.unwrap(),
797+
dbm.load_registration_receipt(
798+
tower_id,
799+
receipt.user_id(),
800+
subscription_start,
801+
subscription_expiry
802+
)
803+
.unwrap()[0],
775804
receipt
776805
);
777806

watchtower-plugin/src/main.rs

Lines changed: 22 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ use teos_common::protos as common_msgs;
1818
use teos_common::TowerId;
1919
use teos_common::{cryptography, errors};
2020

21-
use watchtower_plugin::convert::{CommitmentRevocation, GetAppointmentParams, RegisterParams};
21+
use watchtower_plugin::convert::{
22+
CommitmentRevocation, GetAppointmentParams, GetRegistrationReceiptParams, RegisterParams,
23+
};
2224
use watchtower_plugin::net::http::{
2325
self, post_request, process_post_response, AddAppointmentError, ApiResponse, RequestError,
2426
};
@@ -126,22 +128,33 @@ async fn register(
126128
Ok(json!(receipt))
127129
}
128130

129-
/// Gets the latest registration receipt from the client to a given tower (if it exists).
130-
///
131+
/// Gets the registration receipt(s) from the client to a given tower (if it exists) in the given
132+
/// range. If no range is given, then gets the latest registration receipt.
131133
/// This is pulled from the database
132134
async fn get_registration_receipt(
133135
plugin: Plugin<Arc<Mutex<WTClient>>>,
134136
v: serde_json::Value,
135137
) -> Result<serde_json::Value, Error> {
136-
let tower_id = TowerId::try_from(v).map_err(|x| anyhow!(x))?;
138+
let params = GetRegistrationReceiptParams::try_from(v).map_err(|x| anyhow!(x))?;
139+
let tower_id = params.tower_id;
140+
let subscription_start = params.subscription_start;
141+
let subscription_expiry = params.subscription_expiry;
137142
let state = plugin.state().lock().unwrap();
138143

139-
if let Some(response) = state.get_registration_receipt(tower_id) {
140-
Ok(json!(response))
144+
let response =
145+
state.get_registration_receipt(tower_id, subscription_start, subscription_expiry);
146+
if response.clone().unwrap().is_empty() {
147+
if state.towers.contains_key(&tower_id) {
148+
Err(anyhow!(
149+
"No registration receipt found for {tower_id} on the given range"
150+
))
151+
} else {
152+
Err(anyhow!(
153+
"Cannot find {tower_id} within the known towers. Have you registered?"
154+
))
155+
}
141156
} else {
142-
Err(anyhow!(
143-
"Cannot find {tower_id} within the known towers. Have you registered?"
144-
))
157+
Ok(json!(response))
145158
}
146159
}
147160

watchtower-plugin/src/wt_client.rs

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -179,9 +179,20 @@ impl WTClient {
179179
Ok(())
180180
}
181181

182-
/// Gets the latest registration receipt of a given tower.
183-
pub fn get_registration_receipt(&self, tower_id: TowerId) -> Option<RegistrationReceipt> {
184-
self.dbm.load_registration_receipt(tower_id, self.user_id)
182+
/// Gets the registration receipt(s) of a given tower in the given range.
183+
/// If no range is given then gets the latest registration receipt
184+
pub fn get_registration_receipt(
185+
&self,
186+
tower_id: TowerId,
187+
subscription_start: Option<u32>,
188+
subscription_expiry: Option<u32>,
189+
) -> Option<Vec<RegistrationReceipt>> {
190+
self.dbm.load_registration_receipt(
191+
tower_id,
192+
self.user_id,
193+
subscription_start,
194+
subscription_expiry,
195+
)
185196
}
186197

187198
/// Loads a tower record from the database.

0 commit comments

Comments
 (0)