@@ -256,11 +256,22 @@ class GeluApprox(str, Enum):
256256 none = "none"
257257
258258
259- geluDesc = OperatorDescriptor (inputDescriptor = IoDesc ("data_in" ),
260- outputDescriptor = IoDesc ("data_out" ),
261- attrDescriptors = [
262- AttrDesc ("approximate" , GeluApprox , default = GeluApprox .none ),
263- ])
259+ geluDesc = OperatorDescriptor (  # GELU operator descriptor: one named input, one named output, one attribute
260+ inputDescriptor = IoDesc ("data_in" ),  # single input, named "data_in"
261+ outputDescriptor = IoDesc ("data_out" ),  # single output, named "data_out"
262+ attrDescriptors = [
263+ AttrDesc ("approximate" , GeluApprox , default = GeluApprox .none ),  # default is GeluApprox.none; NOTE(review): presumably applied when the node omits the attribute — confirm in AttrDesc
264+ ],
265+ )
266+
267+ iGeluDesc = OperatorDescriptor (  # descriptor registered under the "iGELU" key of defaultOperatorDescriptors
268+ inputDescriptor = IoDesc ("data_in" ),  # single input, named "data_in"
269+ outputDescriptor = IoDesc ("data_out" ),  # single output, named "data_out"
270+ attrDescriptors = [
271+ AttrDesc ("b" , IntUnpack ),  # no default given; NOTE(review): presumably a required attribute converted via IntUnpack — confirm
272+ AttrDesc ("one" , IntUnpack ),  # no default given; same IntUnpack handling as "b"
273+ ],
274+ )
264275
265276requantizedIGeluDesc = OperatorDescriptor (inputDescriptor = IoDesc (["data_in" , "mul" , "add" , "shift" ]),
266277 outputDescriptor = IoDesc ("data_out" ),
@@ -691,6 +702,12 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
691702 ],
692703)
693704
705+ sgdDesc = OperatorDescriptor (  # descriptor registered under the "SGD" key of defaultOperatorDescriptors
706+ inputDescriptor = IoDesc (["weight" , "grad" ]),  # two named inputs, in this order: "weight", then "grad"
707+ outputDescriptor = IoDesc ("weight_updated" ),  # single output, named "weight_updated"
708+ attrDescriptors = [AttrDesc ("lr" , FloatUnpack )],  # no default given; NOTE(review): "lr" is presumably the learning rate, converted via FloatUnpack — confirm
709+ )
710+
694711defaultOperatorDescriptors : Dict [str , OperatorDescriptor ] = {
695712 "Add" : addDesc ,
696713 "CLCA" : clcaDesc ,
@@ -729,12 +746,14 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
729746 "RequantizediHardswish" : requantizedIHardswishDesc ,
730747 "RequantShift" : requantShiftDesc ,
731748 "Reshape" : reshapeDesc ,
749+ "SGD" : sgdDesc ,
732750 "Slice" : sliceDesc ,
733751 "Softmax" : softmaxDesc ,
734752 "SoftmaxGrad" : softmaxGradDesc ,
735753 "Squeeze" : squeezeDesc ,
736754 "Transpose" : transposeDesc ,
737755 "Unsqueeze" : unsqueezeDesc ,
756+ "iGELU" : iGeluDesc ,
738757 "iHardswish" : iHardswishDesc ,
739758 "iLayerNorm" : iLayerNormDesc ,
740759 "iNoNorm" : iNoNormDesc ,
0 commit comments