Skip to content

Commit 3cb7c3c

Browse files
committed
Add SoftmaxCrossEntropyLoss(Grad)
1 parent c5c0222 commit 3cb7c3c

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

Deeploy/OperatorDescriptor.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -708,6 +708,18 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
708708
attrDescriptors = [AttrDesc("lr", FloatUnpack)],
709709
)
710710

711+
softmaxCrossEntropyLossDesc = OperatorDescriptor(
712+
inputDescriptor = IoDesc(["logits", "labels"]),
713+
outputDescriptor = IoDesc("log_prob"),
714+
attrDescriptors = [],
715+
)
716+
717+
softmaxCrossEntropyLossGradDesc = OperatorDescriptor(
718+
inputDescriptor = IoDesc(["log_prob", "labels"]),
719+
outputDescriptor = IoDesc("grad"),
720+
attrDescriptors = [],
721+
)
722+
711723
defaultOperatorDescriptors: Dict[str, OperatorDescriptor] = {
712724
"Add": addDesc,
713725
"CLCA": clcaDesc,
@@ -749,6 +761,8 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
749761
"SGD": sgdDesc,
750762
"Slice": sliceDesc,
751763
"Softmax": softmaxDesc,
764+
"SoftmaxCrossEntropyLoss": softmaxCrossEntropyLossDesc,
765+
"SoftmaxCrossEntropyLossGrad": softmaxCrossEntropyLossGradDesc,
752766
"SoftmaxGrad": softmaxGradDesc,
753767
"Squeeze": squeezeDesc,
754768
"Transpose": transposeDesc,

0 commit comments

Comments
 (0)