Skip to content

Commit bab3747

Browse files
committed
support intermediate variables
1 parent 3fdeb35 commit bab3747

File tree

1 file changed

+28
-3
lines changed

1 file changed

+28
-3
lines changed

src/sas/sascalc/shape2sas/PluginGenerator.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,30 @@ def generate_plugin(
2626
return model_str, full_path
2727

2828

29+
def get_shape_symbols(symbols: tuple[set[str], set[str]], modelPars: list[list[str], list[str | float]]) -> tuple[set[str], set[str]]:
30+
"""
31+
Get the symbols used in the model, discarding user-defined variables
32+
"""
33+
shape_symbols = set()
34+
for shape in modelPars[0]: # iterate over shape names
35+
for symbol in shape[1:]: # skip shape name
36+
shape_symbols.add(symbol)
37+
38+
# filter out user-defined symbols
39+
lhs_symbols, rhs_symbols = set(), set()
40+
for symbol in symbols[0]:
41+
if symbol in shape_symbols or symbol[1:] in shape_symbols:
42+
lhs_symbols.add(symbol)
43+
44+
for symbol in symbols[1]:
45+
if symbol in shape_symbols or symbol[1:] in shape_symbols:
46+
rhs_symbols.add(symbol)
47+
48+
print(f"LHS: {lhs_symbols}")
49+
print(f"RHS: {rhs_symbols}")
50+
51+
return lhs_symbols, rhs_symbols
52+
2953
def format_parameter_list(par: list[list[str | float]]) -> str:
3054
"""
3155
Format a list of parameters to the model string. In this case the list
@@ -174,9 +198,10 @@ def generate_model(
174198
) -> str:
175199
"""Generates a theoretical model"""
176200
importStatement, parameters, translation = usertext.imports, usertext.params, usertext.constraints
177-
insert_delta, delta_parameters_def, delta_parameters_update = script_insert_delta_parameters(modelPars, fitPar, usertext.symbols)
178-
insert_constraint_update, constraint_update = script_insert_apply_constraints(usertext.symbols[0])
179-
insert_constrained_defs, constrained_parameters = script_insert_constrained_parameters(usertext.symbols, modelPars)
201+
symbols = get_shape_symbols(usertext.symbols, modelPars)
202+
insert_delta, delta_parameters_def, delta_parameters_update = script_insert_delta_parameters(modelPars, fitPar, symbols)
203+
insert_constraint_update, constraint_update = script_insert_apply_constraints(symbols[0])
204+
insert_constrained_defs, constrained_parameters = script_insert_constrained_parameters(symbols, modelPars)
180205
nl = '\n'
181206
fitPar.insert(0, "q")
182207
model_str = (

0 commit comments

Comments
 (0)