@@ -88,7 +88,7 @@ def merge_text(current_text: str, parameter_text: str):
8888 for start , line in enumerate (current_text_lines ):
8989 if line .startswith ("parameters =" ):
9090 break
91-
91+
9292 # find closing bracket of the parameters list
9393 bracket_count = 0
9494 for end , line in enumerate (current_text_lines [start :]):
@@ -152,7 +152,57 @@ def as_ast(text: str):
152152 traceback_to_show = '\n ' .join (last_lines )
153153 logger .error (traceback_to_show )
154154 return None
155-
155+
156+ def expand_center_of_mass_pars (constraint : ast .Assign ) -> list [ast .Assign ]:
157+ """Expand center of mass parameters to include all components."""
158+
159+ # check if this is a COM assignment we need to expand
160+ if (len (constraint .targets ) != 1 or
161+ not isinstance (constraint .targets [0 ], ast .Name ) or
162+ not isinstance (constraint .value , ast .Name )):
163+ return constraint
164+
165+ lhs = constraint .targets [0 ].id
166+ rhs = constraint .value .id
167+
168+ # check if lhs is a COM parameter (with or without 'd' prefix)
169+ lhs_is_delta = lhs .startswith ('d' )
170+ lhs_base = lhs [1 :] if lhs_is_delta else lhs
171+
172+ if not (lhs_base .startswith ("COM" ) and lhs_base [3 :].isdigit ()):
173+ return constraint
174+
175+ # check rhs
176+ rhs_is_delta = rhs .startswith ('d' )
177+ rhs_base = rhs [1 :] if rhs_is_delta else rhs
178+
179+ new_targets , new_values = [], []
180+ if rhs_base .startswith ("COM" ) and rhs_base [3 :].isdigit ():
181+ print ("CASE DOUBLE COM" )
182+ # rhs is also a COM parameter: COM2 = COM1 -> COMX2, COMY2, COMZ2 =COMX1, COMY1, COMZ1
183+ lhs_shape_num = lhs_base [3 :]
184+ rhs_shape_num = rhs_base [3 :]
185+
186+ for axis in ['X' , 'Y' , 'Z' ]:
187+ lhs_full = f"{ 'd' if lhs_is_delta else '' } COM{ axis } { lhs_shape_num } "
188+ rhs_full = f"{ 'd' if rhs_is_delta else '' } COM{ axis } { rhs_shape_num } "
189+ new_targets .append (ast .Name (id = lhs_full , ctx = ast .Store ()))
190+ new_values .append (ast .Name (id = rhs_full , ctx = ast .Load ()))
191+
192+ else :
193+ print ("CASE SINGLE COM" )
194+ # rhs is a regular parameter: COM2 = X -> COMX2, COMY2, COMZ2 = X, X, X
195+ lhs_shape_num = lhs_base [3 :]
196+ rhs_full = f"{ 'd' if rhs_is_delta else '' } { rhs_base } "
197+ for axis in ['X' , 'Y' , 'Z' ]:
198+ lhs_full = f"{ 'd' if lhs_is_delta else '' } COM{ axis } { lhs_shape_num } "
199+ new_targets .append (ast .Name (id = lhs_full , ctx = ast .Store ()))
200+ new_values .append (ast .Name (id = rhs_full , ctx = ast .Load ()))
201+
202+ constraint .targets = [ast .Tuple (elts = new_targets , ctx = ast .Store ())]
203+ constraint .value = ast .Tuple (elts = new_values , ctx = ast .Load ())
204+ return constraint
205+
156206 def parse_ast (tree : ast .AST ):
157207 params = None
158208 imports = []
@@ -166,24 +216,37 @@ def parse_ast(tree: ast.AST):
166216 case ast .Assign ():
167217 if node .targets [0 ].id == 'parameters' :
168218 params = node
219+ elif node .targets [0 ].id .startswith ('dCOM' ) or node .targets [0 ].id .startswith ('COM' ):
220+ constraints .append (expand_center_of_mass_pars (node ))
169221 else :
170222 constraints .append (node )
171223
224+ print (f"Parsed constraints: { constraints } " )
172225 return params , imports , constraints
173-
226+
174227 def extract_symbols (constraints : list [ast .AST ]) -> tuple [list [str ], list [str ]]:
175228 """Extract all symbols used in the constraints."""
176229 lhs , rhs = set (), set ()
177230 for node in constraints :
178231 # left-hand side of assignment
179232 for target in node .targets :
180- if isinstance (target , ast .Name ):
181- lhs .add (target .id )
233+ match target :
234+ case ast .Name ():
235+ lhs .add (target .id )
236+ case ast .Tuple ():
237+ for elt in target .elts :
238+ if isinstance (elt , ast .Name ):
239+ lhs .add (elt .id )
182240
183241 # right-hand side of assignment
184242 for value in ast .walk (node .value ):
185- if isinstance (value , ast .Name ):
186- rhs .add (value .id )
243+ match value :
244+ case ast .Name ():
245+ rhs .add (value .id )
246+ case ast .Tuple :
247+ for elt in value .elts :
248+ if isinstance (elt , ast .Name ):
249+ rhs .add (elt .id )
187250
188251 return lhs , rhs
189252
@@ -215,13 +278,16 @@ def validate_imports(imports: list[ast.ImportFrom | ast.Import]):
215278
216279 def mark_named_parameters (checkedPars : list [list [bool ]], modelPars : list [str ], symbols : set [str ]):
217280 """Mark parameters in the modelPars as checked if they are in symbols_lhs."""
281+ def in_symbols (par : str ):
282+ if par in symbols : return True
283+ if 'd' + par in symbols : return True
284+ return False
285+
218286 for i , shape in enumerate (modelPars ):
219287 for j , par in enumerate (shape ):
220288 if par is None :
221289 continue
222- in_symbols = par in symbols
223- d_in_symbols = "d" + par in symbols
224- checkedPars [i ][j ] = checkedPars [i ][j ] or in_symbols or d_in_symbols
290+ checkedPars [i ][j ] = checkedPars [i ][j ] or in_symbols (par )
225291 return checkedPars
226292
227293 tree = as_ast (text )
0 commit comments