@@ -405,6 +405,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
405405 }
406406
407407 var one = ctx . AddInitializer ( 1.0f , "one" ) ;
408+ var oneInt = ctx . AddInitializer ( 1 , typeof ( int ) , "oneInt" ) ;
408409 var zero = ctx . AddInitializer ( 0.0f , "zero" ) ;
409410 var labelCount = ctx . AddInitializer ( ( float ) _labelCount , "labelCount" ) ;
410411 var trainingCount = ctx . AddInitializer ( ( float ) _totalTrainingCount , "totalTrainingCount" ) ;
@@ -414,108 +415,119 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
414415 var labelHistogramName = ctx . AddInitializer ( labelHistogramExpanded , new long [ ] { _featureHistogram [ 0 ] . Length , _labelHistogram . Length } , "labelHistogramExpanded" ) ;
415416 var learnedAbsentFeatureLogProb = ctx . AddInitializer ( _absentFeaturesLogProb , new long [ ] { _absentFeaturesLogProb . Length , 1 } , "absentFeaturesLogProb" ) ;
416417
417- var greaterOutput = ctx . AddIntermediateVariable ( null , "greaterOutput" , true ) ;
418+ var typeOne = new VectorDataViewType ( NumberDataViewType . Single , 1 ) ;
419+ var typeFea = new VectorDataViewType ( NumberDataViewType . Single , _featureHistogram [ 0 ] . Length ) ;
420+ var typeLabelByFea = new VectorDataViewType ( NumberDataViewType . Single , _labelHistogram . Length , _featureHistogram [ 0 ] . Length ) ;
421+ var typeLabelByOne = new VectorDataViewType ( NumberDataViewType . Single , _labelHistogram . Length , 1 ) ;
422+
423+ var greaterOutput = ctx . AddIntermediateVariable ( new VectorDataViewType ( BooleanDataViewType . Instance , _featureHistogram [ 0 ] . Length ) , "greaterOutput" ) ;
418424 var opType = "Greater" ;
419425 ctx . CreateNode ( opType , new [ ] { featureColumn , zero } , new [ ] { greaterOutput } , ctx . GetNodeName ( opType ) , "" ) ;
420426
421427 opType = "Cast" ;
422- var isFeaturePresent = ctx . AddIntermediateVariable ( null , "isFeaturePresent" , true ) ;
423- var node = ctx . CreateNode ( opType , greaterOutput , isFeaturePresent , ctx . GetNodeName ( opType ) , "" ) ;
428+ var castOutput = ctx . AddIntermediateVariable ( typeFea , "CastOutput" ) ;
429+ var node = ctx . CreateNode ( opType , greaterOutput , castOutput , ctx . GetNodeName ( opType ) , "" ) ;
424430 var t = InternalDataKindExtensions . ToInternalDataKind ( DataKind . Single ) . ToType ( ) ;
425431 node . AddAttribute ( "to" , t ) ;
426432
433+ opType = "ExpandDims" ;
434+ var isFeaturePresent = ctx . AddIntermediateVariable ( new VectorDataViewType ( NumberDataViewType . Single , 1 , _featureHistogram [ 0 ] . Length ) , "isFeaturePresent" ) ;
435+ ctx . CreateNode ( opType , new [ ] { castOutput , oneInt } , new [ ] { isFeaturePresent } , ctx . GetNodeName ( opType ) , "com.microsoft" ) ;
436+
427437 //initialize logProb
428438 opType = "Div" ;
429- var divOutput = ctx . AddIntermediateVariable ( null , "DivOutput" , true ) ;
439+ var divOutput = ctx . AddIntermediateVariable ( typeOne , "DivOutput" ) ;
430440 ctx . CreateNode ( opType , new [ ] { labelHistogram , trainingCount } , new [ ] { divOutput } , ctx . GetNodeName ( opType ) , "" ) ;
431441
432442 opType = "Log" ;
433- var logOutput = ctx . AddIntermediateVariable ( null , "LogOutput" , true ) ;
443+ var logOutput = ctx . AddIntermediateVariable ( typeOne , "LogOutput" ) ;
434444 ctx . CreateNode ( opType , divOutput , logOutput , ctx . GetNodeName ( opType ) , "" ) ;
435445
436446 //log1
437447 opType = "Sum" ;
438- var sumOutput = ctx . AddIntermediateVariable ( null , "SumOutput" , true ) ;
448+ var sumOutput = ctx . AddIntermediateVariable ( _inputType , "SumOutput" ) ;
439449 ctx . CreateNode ( opType , new [ ] { featureHistogramName , one } , new [ ] { sumOutput } , ctx . GetNodeName ( opType ) , "" ) ;
440450
441- var logOutput1 = ctx . AddIntermediateVariable ( null , "LogOutput" , true ) ;
451+ var logOutput1 = ctx . AddIntermediateVariable ( typeLabelByFea , "LogOutput" ) ;
442452 LogMul ( ctx , sumOutput , isFeaturePresent , logOutput1 ) ;
443453
444454 //log2
445455 opType = "Transpose" ;
446- var labelHistogramTrans = ctx . AddIntermediateVariable ( null , "transpose" , true ) ;
456+ var labelHistogramTrans = ctx . AddIntermediateVariable ( typeFea , "Transpose" ) ;
447457 ctx . CreateNode ( opType , labelHistogramName , labelHistogramTrans , ctx . GetNodeName ( opType ) , "" ) ;
448458
449459 opType = "Sub" ;
450- var absentFeatureCount = ctx . AddIntermediateVariable ( null , "AbsentFeatureCounts" , true ) ;
460+ var absentFeatureCount = ctx . AddIntermediateVariable ( typeFea , "AbsentFeatureCounts" ) ;
451461 ctx . CreateNode ( opType , new [ ] { labelHistogramTrans , featureHistogramName } , new [ ] { absentFeatureCount } , ctx . GetNodeName ( opType ) , "" ) ;
452462
453463 opType = "Sum" ;
454- sumOutput = ctx . AddIntermediateVariable ( null , "SumOutput" , true ) ;
464+ sumOutput = ctx . AddIntermediateVariable ( typeFea , "SumOutput" ) ;
455465 ctx . CreateNode ( opType , new [ ] { labelHistogramTrans , labelCount } , new [ ] { sumOutput } , ctx . GetNodeName ( opType ) , "" ) ;
456466
457- var logOutput2 = ctx . AddIntermediateVariable ( null , "LogOutput" , true ) ;
467+ var logOutput2 = ctx . AddIntermediateVariable ( typeLabelByFea , "LogOutput" ) ;
458468 LogMul ( ctx , sumOutput , isFeaturePresent , logOutput2 ) ;
459469
460470 //log3
461471 opType = "Sum" ;
462- sumOutput = ctx . AddIntermediateVariable ( null , "SumOutput" , true ) ;
472+ sumOutput = ctx . AddIntermediateVariable ( typeFea , "SumOutput" ) ;
463473 ctx . CreateNode ( opType , new [ ] { absentFeatureCount , one } , new [ ] { sumOutput } , ctx . GetNodeName ( opType ) , "" ) ;
464474
465- var logOutput3 = ctx . AddIntermediateVariable ( null , "LogOutput" , true ) ;
475+ var logOutput3 = ctx . AddIntermediateVariable ( typeLabelByFea , "LogOutput" ) ;
466476 LogMul ( ctx , sumOutput , isFeaturePresent , logOutput3 ) ;
467477
468478 //result
469479 opType = "Sub" ;
470- var logProb = ctx . AddIntermediateVariable ( null , "LogProb" , true ) ;
480+ var logProb = ctx . AddIntermediateVariable ( typeLabelByFea , "LogProb" ) ;
471481 ctx . CreateNode ( opType , new [ ] { logOutput1 , logOutput2 } , new [ ] { logProb } , ctx . GetNodeName ( opType ) , "" ) ;
472482
473483 opType = "Sub" ;
474- var absentFeatureLogProb = ctx . AddIntermediateVariable ( null , "AbsentFeatureLogProb" , true ) ;
484+ var absentFeatureLogProb = ctx . AddIntermediateVariable ( typeLabelByFea , "AbsentFeatureLogProb" ) ;
475485 ctx . CreateNode ( opType , new [ ] { logOutput3 , logOutput2 } , new [ ] { absentFeatureLogProb } , ctx . GetNodeName ( opType ) , "" ) ;
476486
477487 opType = "ReduceSum" ;
478- var logProbReduceSum = ctx . AddIntermediateVariable ( null , "ReduceSum" , true ) ;
488+ var logProbReduceSum = ctx . AddIntermediateVariable ( typeLabelByOne , "ReduceSum" ) ;
479489 node = ctx . CreateNode ( opType , new [ ] { logProb } , new [ ] { logProbReduceSum } , ctx . GetNodeName ( opType ) , "" ) ;
480- long [ ] list = { 1 } ;
490+ long [ ] list = { 2 } ;
481491 node . AddAttribute ( "axes" , list ) ;
482492
483493 opType = "ReduceSum" ;
484- var absentFeatureLogProbReduceSum = ctx . AddIntermediateVariable ( null , "ReduceSum" , true ) ;
494+ var absentFeatureLogProbReduceSum = ctx . AddIntermediateVariable ( typeLabelByOne , "ReduceSum" ) ;
485495 node = ctx . CreateNode ( opType , new [ ] { absentFeatureLogProb } , new [ ] { absentFeatureLogProbReduceSum } , ctx . GetNodeName ( opType ) , "" ) ;
486496 node . AddAttribute ( "axes" , list ) ;
487497
488498 opType = "Cast" ;
489- var castOutput = ctx . AddIntermediateVariable ( null , "CastOutput2" , true ) ;
499+ castOutput = ctx . AddIntermediateVariable ( NumberDataViewType . Single , "CastOutput" ) ;
490500 node = ctx . CreateNode ( opType , learnedAbsentFeatureLogProb , castOutput , ctx . GetNodeName ( opType ) , "" ) ;
491501 t = InternalDataKindExtensions . ToInternalDataKind ( DataKind . Single ) . ToType ( ) ;
492502 node . AddAttribute ( "to" , t ) ;
493503
494504 opType = "Sub" ;
495- var subOutput = ctx . AddIntermediateVariable ( null , "SubOutput" , true ) ;
505+ var subOutput = ctx . AddIntermediateVariable ( typeLabelByOne , "SubOutput" ) ;
496506 ctx . CreateNode ( opType , new [ ] { castOutput , absentFeatureLogProbReduceSum } , new [ ] { subOutput } , ctx . GetNodeName ( opType ) , "" ) ;
497507
498508 opType = "Sum" ;
499- sumOutput = ctx . AddIntermediateVariable ( null , "SumOutput" , true ) ;
509+ sumOutput = ctx . AddIntermediateVariable ( typeLabelByOne , "SumOutput" ) ;
500510 ctx . CreateNode ( opType , new [ ] { subOutput , logProbReduceSum , logOutput } , new [ ] { sumOutput } , ctx . GetNodeName ( opType ) , "" ) ;
501511
502- opType = "Transpose " ;
503- var transposeOutput = ctx . AddIntermediateVariable ( null , "TransposeOutput" , true ) ;
504- ctx . CreateNode ( opType , new [ ] { sumOutput } , new [ ] { outputNames [ 1 ] } , ctx . GetNodeName ( opType ) , "" ) ;
512+ opType = "Squeeze " ;
513+ var squeezeNode = ctx . CreateNode ( opType , sumOutput , outputNames [ 1 ] , ctx . GetNodeName ( opType ) , "" ) ;
514+ squeezeNode . AddAttribute ( "axes" , new long [ ] { 2 } ) ;
505515
506516 opType = "ArgMax" ;
507- var scoreIndex = ctx . AddIntermediateVariable ( null , "ScoreIndex" , true ) ;
508- ctx . CreateNode ( opType , new [ ] { sumOutput } , new [ ] { scoreIndex } , ctx . GetNodeName ( opType ) , "" ) ;
517+ var scoreIndex = ctx . AddIntermediateVariable ( new VectorDataViewType ( NumberDataViewType . Int64 , 1 ) , "ScoreIndex" ) ;
518+ node = ctx . CreateNode ( opType , new [ ] { sumOutput } , new [ ] { scoreIndex } , ctx . GetNodeName ( opType ) , "" ) ;
519+ node . AddAttribute ( "axis" , 1 ) ;
520+ node . AddAttribute ( "keepdims" , 0 ) ;
509521
510522 opType = "Cast" ;
511- castOutput = ctx . AddIntermediateVariable ( null , "CastOutput3" , true ) ;
523+ castOutput = ctx . AddIntermediateVariable ( typeOne , "CastOutput" ) ;
512524 node = ctx . CreateNode ( opType , scoreIndex , castOutput , ctx . GetNodeName ( opType ) , "" ) ;
513525 t = InternalDataKindExtensions . ToInternalDataKind ( DataKind . Single ) . ToType ( ) ;
514526 node . AddAttribute ( "to" , t ) ;
515527
516528 //log3
517529 opType = "Sum" ;
518- sumOutput = ctx . AddIntermediateVariable ( null , "SumOutput" , true ) ;
530+ sumOutput = ctx . AddIntermediateVariable ( typeOne , "SumOutput" ) ;
519531 ctx . CreateNode ( opType , new [ ] { castOutput , one } , new [ ] { sumOutput } , ctx . GetNodeName ( opType ) , "" ) ;
520532
521533 opType = "Cast" ;
@@ -529,7 +541,7 @@ bool ISingleCanSaveOnnx.SaveAsOnnx(OnnxContext ctx, string[] outputNames, string
529541 private void LogMul ( OnnxContext ctx , string input , string isFeaturePresent , string output )
530542 {
531543 var opType = "Log" ;
532- var logOutput = ctx . AddIntermediateVariable ( null , "LogOutput" , true ) ;
544+ var logOutput = ctx . AddIntermediateVariable ( new VectorDataViewType ( NumberDataViewType . Single , _featureHistogram [ 0 ] . Length ) , "LogOutput" ) ;
533545 ctx . CreateNode ( opType , input , logOutput , ctx . GetNodeName ( opType ) , "" ) ;
534546
535547 opType = "Mul" ;
0 commit comments