Skip to content

Commit 3581344

Browse files
committed
added COM support
1 parent 9808b68 commit 3581344

File tree

1 file changed

+76
-10
lines changed

1 file changed

+76
-10
lines changed

src/sas/qtgui/Calculators/Shape2SAS/Constraints.py

Lines changed: 76 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def merge_text(current_text: str, parameter_text: str):
8888
for start, line in enumerate(current_text_lines):
8989
if line.startswith("parameters ="):
9090
break
91-
91+
9292
# find closing bracket of the parameters list
9393
bracket_count = 0
9494
for end, line in enumerate(current_text_lines[start:]):
@@ -152,7 +152,57 @@ def as_ast(text: str):
152152
traceback_to_show = '\n'.join(last_lines)
153153
logger.error(traceback_to_show)
154154
return None
155-
155+
156+
def expand_center_of_mass_pars(constraint: ast.Assign) -> list[ast.Assign]:
157+
"""Expand center of mass parameters to include all components."""
158+
159+
# check if this is a COM assignment we need to expand
160+
if (len(constraint.targets) != 1 or
161+
not isinstance(constraint.targets[0], ast.Name) or
162+
not isinstance(constraint.value, ast.Name)):
163+
return constraint
164+
165+
lhs = constraint.targets[0].id
166+
rhs = constraint.value.id
167+
168+
# check if lhs is a COM parameter (with or without 'd' prefix)
169+
lhs_is_delta = lhs.startswith('d')
170+
lhs_base = lhs[1:] if lhs_is_delta else lhs
171+
172+
if not (lhs_base.startswith("COM") and lhs_base[3:].isdigit()):
173+
return constraint
174+
175+
# check rhs
176+
rhs_is_delta = rhs.startswith('d')
177+
rhs_base = rhs[1:] if rhs_is_delta else rhs
178+
179+
new_targets, new_values = [], []
180+
if rhs_base.startswith("COM") and rhs_base[3:].isdigit():
181+
print("CASE DOUBLE COM")
182+
# rhs is also a COM parameter: COM2 = COM1 -> COMX2, COMY2, COMZ2 =COMX1, COMY1, COMZ1
183+
lhs_shape_num = lhs_base[3:]
184+
rhs_shape_num = rhs_base[3:]
185+
186+
for axis in ['X', 'Y', 'Z']:
187+
lhs_full = f"{'d' if lhs_is_delta else ''}COM{axis}{lhs_shape_num}"
188+
rhs_full = f"{'d' if rhs_is_delta else ''}COM{axis}{rhs_shape_num}"
189+
new_targets.append(ast.Name(id=lhs_full, ctx=ast.Store()))
190+
new_values.append(ast.Name(id=rhs_full, ctx=ast.Load()))
191+
192+
else:
193+
print("CASE SINGLE COM")
194+
# rhs is a regular parameter: COM2 = X -> COMX2, COMY2, COMZ2 = X, X, X
195+
lhs_shape_num = lhs_base[3:]
196+
rhs_full = f"{'d' if rhs_is_delta else ''}{rhs_base}"
197+
for axis in ['X', 'Y', 'Z']:
198+
lhs_full = f"{'d' if lhs_is_delta else ''}COM{axis}{lhs_shape_num}"
199+
new_targets.append(ast.Name(id=lhs_full, ctx=ast.Store()))
200+
new_values.append(ast.Name(id=rhs_full, ctx=ast.Load()))
201+
202+
constraint.targets = [ast.Tuple(elts=new_targets, ctx=ast.Store())]
203+
constraint.value = ast.Tuple(elts=new_values, ctx=ast.Load())
204+
return constraint
205+
156206
def parse_ast(tree: ast.AST):
157207
params = None
158208
imports = []
@@ -166,24 +216,37 @@ def parse_ast(tree: ast.AST):
166216
case ast.Assign():
167217
if node.targets[0].id == 'parameters':
168218
params = node
219+
elif node.targets[0].id.startswith('dCOM') or node.targets[0].id.startswith('COM'):
220+
constraints.append(expand_center_of_mass_pars(node))
169221
else:
170222
constraints.append(node)
171223

224+
print(f"Parsed constraints: {constraints}")
172225
return params, imports, constraints
173-
226+
174227
def extract_symbols(constraints: list[ast.AST]) -> tuple[list[str], list[str]]:
175228
"""Extract all symbols used in the constraints."""
176229
lhs, rhs = set(), set()
177230
for node in constraints:
178231
# left-hand side of assignment
179232
for target in node.targets:
180-
if isinstance(target, ast.Name):
181-
lhs.add(target.id)
233+
match target:
234+
case ast.Name():
235+
lhs.add(target.id)
236+
case ast.Tuple():
237+
for elt in target.elts:
238+
if isinstance(elt, ast.Name):
239+
lhs.add(elt.id)
182240

183241
# right-hand side of assignment
184242
for value in ast.walk(node.value):
185-
if isinstance(value, ast.Name):
186-
rhs.add(value.id)
243+
match value:
244+
case ast.Name():
245+
rhs.add(value.id)
246+
case ast.Tuple:
247+
for elt in value.elts:
248+
if isinstance(elt, ast.Name):
249+
rhs.add(elt.id)
187250

188251
return lhs, rhs
189252

@@ -215,13 +278,16 @@ def validate_imports(imports: list[ast.ImportFrom | ast.Import]):
215278

216279
def mark_named_parameters(checkedPars: list[list[bool]], modelPars: list[str], symbols: set[str]):
217280
"""Mark parameters in the modelPars as checked if they are in symbols_lhs."""
281+
def in_symbols(par: str):
282+
if par in symbols: return True
283+
if 'd' + par in symbols: return True
284+
return False
285+
218286
for i, shape in enumerate(modelPars):
219287
for j, par in enumerate(shape):
220288
if par is None:
221289
continue
222-
in_symbols = par in symbols
223-
d_in_symbols = "d" + par in symbols
224-
checkedPars[i][j] = checkedPars[i][j] or in_symbols or d_in_symbols
290+
checkedPars[i][j] = checkedPars[i][j] or in_symbols(par)
225291
return checkedPars
226292

227293
tree = as_ast(text)

0 commit comments

Comments
 (0)