Skip to content

Commit 9465ec8

Browse files
authored
feat(core): add validation rule for the condition syntax check of row-level access controls (#1176)
1 parent 4158912 commit 9465ec8

File tree

12 files changed

+306
-9
lines changed

12 files changed

+306
-9
lines changed

ibis-server/app/model/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ class GcsFileConnectionInfo(BaseConnectionInfo):
357357

358358
class ValidateDTO(BaseModel):
359359
manifest_str: str = manifest_str_field
360-
parameters: dict[str, str]
360+
parameters: dict
361361
connection_info: ConnectionInfo = connection_info_field
362362

363363

ibis-server/app/model/validator.py

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
11
from __future__ import annotations
22

3+
from wren_core import (
4+
RowLevelAccessControl,
5+
SessionProperty,
6+
to_manifest,
7+
validate_rlac_rule,
8+
)
9+
310
from app.mdl.rewriter import Rewriter
411
from app.model import NotFoundError, UnprocessableEntityError
512
from app.model.connector import Connector
613
from app.util import base64_to_dict
714

8-
rules = ["column_is_valid", "relationship_is_valid"]
15+
rules = ["column_is_valid", "relationship_is_valid", "rlac_condition_syntax_is_valid"]
916

1017

1118
class Validator:
1219
def __init__(self, connector: Connector, rewriter: Rewriter):
1320
self.connector = connector
1421
self.rewriter = rewriter
1522

16-
async def validate(self, rule: str, parameters: dict[str, str], manifest_str: str):
23+
async def validate(self, rule: str, parameters: dict, manifest_str: str):
1724
if rule not in rules:
1825
raise RuleNotFoundError(rule)
1926
try:
@@ -144,6 +151,45 @@ def format_result(result):
144151
except Exception as e:
145152
raise ValidationError(f"Exception: {type(e)}, message: {e!s}")
146153

154+
async def _validate_rlac_condition_syntax_is_valid(
    self, parameters: dict, manifest_str: str
):
    """Validate the condition of a row-level access control (RLAC) rule.

    Builds a temporary ``RowLevelAccessControl`` from ``parameters`` and
    passes it, together with the named model from the manifest, to
    ``validate_rlac_rule``.

    Args:
        parameters: Must contain ``modelName``, ``requiredProperties`` (a
            list of dicts with ``name`` and, optionally, ``required`` and
            ``defaultExpr``) and ``condition``.
        manifest_str: Base64-encoded manifest, decoded with ``to_manifest``.

    Raises:
        MissingRequiredParameterError: If ``modelName``,
            ``requiredProperties`` or ``condition`` is absent or ``None``.
        ValueError: If the manifest has no model named ``modelName``.
        ValidationError: If ``validate_rlac_rule`` rejects the rule.
    """

    def _as_bool(value) -> bool:
        # JSON clients may send the flag as text. ``bool("false")`` is True,
        # so the textual spellings of "false" are recognised explicitly.
        if isinstance(value, str):
            return value.strip().lower() not in {"", "false", "0", "no", "off"}
        return bool(value)

    # Checked in this order so the first missing key is the one reported.
    for key in ("modelName", "requiredProperties", "condition"):
        if parameters.get(key) is None:
            raise MissingRequiredParameterError(key)

    model_name = parameters["modelName"]
    condition = parameters["condition"]
    required_properties = [
        SessionProperty(
            name=prop["name"],
            # ``required`` is optional; SessionProperty itself defaults it
            # to false, so a missing key is treated the same way.
            required=_as_bool(prop.get("required", False)),
            default_expr=prop.get("defaultExpr"),
        )
        for prop in parameters["requiredProperties"]
    ]

    rlac = RowLevelAccessControl(
        name="rlac_validation",
        required_properties=required_properties,
        condition=condition,
    )

    manifest = to_manifest(manifest_str)
    model = manifest.get_model(model_name)
    if model is None:
        # NOTE(review): an unknown model is a client error; consider raising
        # one of the project's HTTP-mapped errors instead of ValueError.
        # Kept as-is here to avoid changing the exception type.
        raise ValueError(f"Model {model_name} not found in manifest")

    try:
        validate_rlac_rule(rlac, model)
    except Exception as e:
        # Pass the message as text (as the other rules do) and keep the
        # original exception as the cause for debugging.
        raise ValidationError(str(e)) from e
192+
147193
def _get_model(self, manifest, model_name):
148194
models = list(filter(lambda m: m["name"] == model_name, manifest["models"]))
149195
if len(models) == 0:

ibis-server/tests/routers/v3/connector/postgres/test_validate.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,60 @@ async def test_validate_rule_column_is_valid_without_one_parameter(
119119
)
120120
assert response.status_code == 422
121121
assert response.text == "Missing required parameter: `modelName`"
122+
123+
124+
async def test_validate_rlac_condition_syntax_is_valid(
    client, manifest_str, connection_info
):
    """Exercise the ``rlac_condition_syntax_is_valid`` rule over HTTP.

    Two requests with a valid condition (the ``required`` flag sent once as
    the string ``"false"`` and once as a JSON boolean) must return 204; a
    condition that uses an undeclared session property must return 422 with
    the planner's error message.
    """
    endpoint = f"{base_url}/validate/rlac_condition_syntax_is_valid"

    async def submit(required_flag, condition):
        # All three requests differ only in the flag and the condition.
        return await client.post(
            url=endpoint,
            json={
                "connectionInfo": connection_info,
                "manifestStr": manifest_str,
                "parameters": {
                    "modelName": "orders",
                    "requiredProperties": [
                        {"name": "session_order", "required": required_flag},
                    ],
                    "condition": condition,
                },
            },
        )

    # Flag supplied as a string.
    accepted = await submit("false", "@session_order = o_orderkey")
    assert accepted.status_code == 204

    # Flag supplied as a real boolean.
    accepted = await submit(False, "@session_order = o_orderkey")
    assert accepted.status_code == 204

    # The condition references a property that was never declared.
    rejected = await submit("false", "@session_not_found = o_orderkey")
    assert rejected.status_code == 422
    assert (
        rejected.text
        == "Error during planning: The session property @session_not_found is used, but not found in the session properties"
    )

wren-core-base/src/mdl/py_method.rs

Lines changed: 37 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
#[cfg(feature = "python-binding")]
2121
mod manifest_python_impl {
22-
use crate::mdl::manifest::{Manifest, Model};
22+
use crate::mdl::manifest::{Manifest, Model, RowLevelAccessControl, SessionProperty};
2323
use crate::mdl::DataSource;
2424
use pyo3::{pymethods, PyResult};
2525
use std::sync::Arc;
@@ -49,6 +49,16 @@ mod manifest_python_impl {
4949
fn data_source(&self) -> PyResult<Option<DataSource>> {
5050
Ok(self.data_source)
5151
}
52+
53+
fn get_model(&self, name: &str) -> PyResult<Option<Model>> {
54+
let model = self
55+
.models
56+
.iter()
57+
.find(|m| m.name == name)
58+
.cloned()
59+
.map(Arc::unwrap_or_clone);
60+
Ok(model)
61+
}
5262
}
5363

5464
#[pymethods]
@@ -58,4 +68,30 @@ mod manifest_python_impl {
5868
Ok(self.name.clone())
5969
}
6070
}
71+
72+
#[pymethods]
73+
impl SessionProperty {
74+
#[new]
75+
#[pyo3(signature = (name, required = false, default_expr = None))]
76+
fn new(name: String, required: bool, default_expr: Option<String>) -> Self {
77+
Self {
78+
name,
79+
required,
80+
default_expr,
81+
}
82+
}
83+
}
84+
85+
#[pymethods]
86+
impl RowLevelAccessControl {
87+
#[new]
88+
#[pyo3(signature = (name, condition, required_properties = vec![]))]
89+
fn new(name: String, condition: String, required_properties: Vec<SessionProperty>) -> Self {
90+
Self {
91+
name,
92+
condition,
93+
required_properties,
94+
}
95+
}
96+
}
6197
}

wren-core-py/src/errors.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ impl From<serde_json::Error> for CoreError {
5151

5252
impl From<wren_core::DataFusionError> for CoreError {
5353
fn from(err: wren_core::DataFusionError) -> Self {
54-
CoreError::new(&format!("DataFusion error: {}", err))
54+
CoreError::new(err.to_string().as_str())
5555
}
5656
}
5757

wren-core-py/src/extractor.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ fn resolve_used_table_names(mdl: &WrenMDL, sql: &str) -> Result<Vec<String>, Cor
5656
tables
5757
.iter()
5858
.filter(|t| {
59-
t.catalog().map_or(true, |catalog| catalog == mdl.catalog())
60-
&& t.schema().map_or(true, |schema| schema == mdl.schema())
59+
t.catalog().is_none_or(|catalog| catalog == mdl.catalog())
60+
&& t.schema().is_none_or(|schema| schema == mdl.schema())
6161
})
6262
.map(|t| t.table().to_string())
6363
.collect()

wren-core-py/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ mod errors;
77
mod extractor;
88
mod manifest;
99
pub mod remote_functions;
10+
mod validation;
1011

1112
#[pymodule]
1213
#[pyo3(name = "wren_core")]
@@ -15,7 +16,12 @@ fn wren_core_wrapper(m: &Bound<'_, PyModule>) -> PyResult<()> {
1516
m.add_class::<context::PySessionContext>()?;
1617
m.add_class::<PyRemoteFunction>()?;
1718
m.add_class::<manifest::Manifest>()?;
19+
m.add_class::<manifest::Model>()?;
20+
m.add_class::<manifest::RowLevelAccessControl>()?;
21+
m.add_class::<manifest::SessionProperty>()?;
1822
m.add_class::<extractor::PyManifestExtractor>()?;
1923
m.add_function(wrap_pyfunction!(manifest::to_json_base64, m)?)?;
24+
m.add_function(wrap_pyfunction!(manifest::to_manifest, m)?)?;
25+
m.add_function(wrap_pyfunction!(validation::validate_rlac_rule, m)?)?;
2026
Ok(())
2127
}

wren-core-py/src/manifest.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ pub fn to_json_base64(mdl: Manifest) -> Result<String, CoreError> {
1313
Ok(mdl_base64)
1414
}
1515

16+
#[pyfunction]
1617
/// Convert a base64 encoded JSON string to a manifest object.
1718
pub fn to_manifest(mdl_base64: &str) -> Result<Manifest, CoreError> {
1819
let decoded_bytes = BASE64_STANDARD.decode(mdl_base64)?;

wren-core-py/src/validation.rs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
use pyo3::pyfunction;
2+
use wren_core_base::mdl::{Model, RowLevelAccessControl};
3+
4+
use crate::errors::CoreError;
5+
6+
#[pyfunction]
7+
pub fn validate_rlac_rule(
8+
rule: &RowLevelAccessControl,
9+
model: &Model,
10+
) -> Result<(), CoreError> {
11+
wren_core::logical_plan::analyze::access_control::validate_rlac_rule(rule, model)?;
12+
Ok(())
13+
}

wren-core-py/tests/test_modeling_core.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,12 @@
55
import pytest
66
from wren_core import (
77
ManifestExtractor,
8+
RowLevelAccessControl,
89
SessionContext,
10+
SessionProperty,
911
to_json_base64,
12+
to_manifest,
13+
validate_rlac_rule,
1014
)
1115

1216
manifest = {
@@ -298,3 +302,35 @@ def test_rlac():
298302
rewritten_sql
299303
== "SELECT customer.c_custkey, customer.c_name FROM (SELECT customer.c_custkey, customer.c_name FROM (SELECT __source.c_custkey AS c_custkey, __source.c_name AS c_name FROM main.customer AS __source) AS customer) AS customer WHERE customer.c_name = 'test_user'"
300304
)
305+
306+
307+
def test_validate_rlac_rule():
    """Check ``validate_rlac_rule`` through the Python binding.

    A rule whose condition only uses declared session properties must pass;
    a rule that uses ``@session_user`` without declaring it must be rejected
    with a message naming the missing property.
    """
    # Named differently from the module-level ``manifest`` dict to avoid
    # shadowing it.
    parsed_manifest = to_manifest(manifest_str)
    model = parsed_manifest.get_model("customer")
    assert model is not None, "Model customer not found in manifest"

    declared_rule = RowLevelAccessControl(
        name="test",
        required_properties=[
            SessionProperty(
                name="session_user",
                required=False,
            )
        ],
        condition="c_name = @session_user",
    )
    # Must not raise.
    validate_rlac_rule(declared_rule, model)

    undeclared_rule = RowLevelAccessControl(
        name="test",
        required_properties=[],
        condition="c_name = @session_user",
    )
    with pytest.raises(Exception) as excinfo:
        validate_rlac_rule(undeclared_rule, model)
    # Asserted after the ``with`` block: statements following the raising
    # call inside it would never execute. Matched as a substring so the check
    # does not depend on any prefix the error wrapper adds.
    assert (
        "The session property @session_user is used, but not found in the session properties"
        in str(excinfo.value)
    )

0 commit comments

Comments
 (0)