From a850a2df071c55d7641caf45282d37c2045c63bf Mon Sep 17 00:00:00 2001 From: RoomWithOutRoof Date: Wed, 15 Apr 2026 20:22:00 +0800 Subject: [PATCH 1/2] Add unit tests for src modules - test_msign.py: Add unit tests for msign matrix sign function - test_manifold_muon.py: Add unit tests for manifold muon optimizer - test_hyperspherical_descent.py: Add unit tests for hyperspherical descent --- pytest.ini | 5 + ...rical_descent.cpython-312-pytest-7.4.4.pyc | Bin 0 -> 22173 bytes ...manifold_muon.cpython-312-pytest-7.4.4.pyc | Bin 0 -> 18350 bytes .../test_msign.cpython-312-pytest-7.4.4.pyc | Bin 0 -> 16498 bytes tests/test_hyperspherical_descent.py | 126 ++++++++++++++++++ tests/test_manifold_muon.py | 110 +++++++++++++++ tests/test_msign.py | 88 ++++++++++++ 7 files changed, 329 insertions(+) create mode 100644 pytest.ini create mode 100644 tests/__pycache__/test_hyperspherical_descent.cpython-312-pytest-7.4.4.pyc create mode 100644 tests/__pycache__/test_manifold_muon.cpython-312-pytest-7.4.4.pyc create mode 100644 tests/__pycache__/test_msign.cpython-312-pytest-7.4.4.pyc create mode 100644 tests/test_hyperspherical_descent.py create mode 100644 tests/test_manifold_muon.py create mode 100644 tests/test_msign.py diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..38be796 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +testpaths = tests +python_files = test_*.py +python_classes = Test* +python_functions = test_* \ No newline at end of file diff --git a/tests/__pycache__/test_hyperspherical_descent.cpython-312-pytest-7.4.4.pyc b/tests/__pycache__/test_hyperspherical_descent.cpython-312-pytest-7.4.4.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c3f3d9f5a43edc1ef8b3e7e849d62aa8905cfc10 GIT binary patch literal 22173 zcmeHPTW}lKdEN!Ez~Vwuq(sL`VhNTdTd*aH04a%bB+Ig;IJFykoSQT1BnXUMil9M& zJ_|}FyoeoFZRlyHmBurXGUK*+X-6768Naj-ed$Zz@?uA->4DXe+jz#+nV=ISb?oO$Wvs;`tnZC8dC+ zzwSN9uKlroPS@8^VoCd1^pu{9OkcYn>dZyjMx##3v+eYJskR9H>e#jg-5qo6>WqW! zWIz7qp4s}my=U>6#ad%=&scrIu9d5$Q+rO;$+ONvsci3Vlq${29;l%*Q?FXio(8-* zqs8FQHny;YP1sg*%c{LJhR^ZCIXi&-r}mqdv=7XI6?603=H~O;E}M6~XAWGEH~!VU z>wg>^lP`}K^s?wdxPf(GFW>w&^d<~c{it?E|IT(zYw5jv+SUJ3Yeo9V6Kxn!KTU4L zjq)Mfin+15IDaCKEyFcj{n?$sLEMcy?M9M6o1@c>U$LVVapN<(6~!_$k=MOxXT=%= zJm1_PKa)p3v`=czs8at1nS-CXF+QG%nsbCdmq(eirbGLZ>!7>_uhlYc8K+|8hjJ=% z0;h%*e=XXC@6bwE3Cna76`fw0jZyw?%Dddb@5!TU@^|6Kmvt-oyxuZe@iY35?u%(l z$y};A5j@tj&TM_LYGoIQ?bzg`ovqXs78`K4v|x|r;uj;=@RO1uAh`%VT*J`{=lHR_ z``2)QT|+d0``F!N7|2Ea8V;C?`p1z&hjMyz_l}(lONE@1dcix%9?Ig`c+MH4?WJ1T zL+sXM&S}Siqs3~YZO~`AXgdm7kra?jHU~=%aNM9(?1TJpbKmsE48aN2Hb|*v)d;dm z?!+6ZhZmM`;ulM#JzNBgt`>>CK+5&`g%Yvb8F*ZLX0cLjRBF(4wGKI)_MJu6J2Y$Y zM7915U{I|zZ?}_0r!-?@4eh}rl#fcic8riZX>Tmf)X99QQM8|5sFrG_20Y0n33dfz zMSr@jKSn->T&1yD_8xilVuD?;iuS4(`yCwES)NE5t9gLFS?~ug@-OSa%fLr%9$OvL z)=&*jYUjs5UeewlJ$h4%C3DNiu51`Rmp+?5|IN2H+_QY7<7)JQ3r8+IcWE@g{N<~e zoh!eczh7o{UQYrqH-R(FhqDQ|L2(wO`5L5H7z#^y2Usj|R(B%;XQOU35NBgn1h6*d z8Uky7rp)vptc_Bv1>ANG1#52|f<-xkgd&*8i3p}}9fY+4!9(&%O#W>lVhm6*-ZCg2 z=7XmT$3s*fvZpGI*=+tO!^APN35E~>MN=aP%@@~%XnS)Go$?o>WHZ+Ac_0W`dp}lj zY(K-X{nO+wY{5f~`>+7WV~y58a#&PBj!AOAxWf=*pNbe$L5Q)v*Hq72MJffF53LU> zczG07&PgDk8)b(7>WNpLI3M|oCztgrzKh}eF6bBXmxgyQ$FG{(S8mbYTjq8$22BTJ zK&&qefVqt_4N^D?D0x4_0d4O8Ar2G-4xlnbPykXKAmez3;+*7>Ody#A!X`1<3#V(I zf`v~w1$lXd5yu50r=Wg5e<^d9oQ2}rHxYNNn78KdEptcrL_{5i3gZtxm5JDZ=2;3B z%n=PltPpW%5Cz%jtvKY1b5o|3a1$A=WxA&GzgEPGHc&I8X)eRhs8VMD=duSuosjPs)Vk*Lh>beKFfG5PED_*VtPqlf=sU@FUyZ$dc{G>Kt;nqWdr^vTINzbmuT+ot#H)p z4KTv$)?(StPM4;$Cri~u+Zm%4R`UhV*1=3@G9BPn6r6Zg6sQi&u!0ls7U&%87Mysu zRx!USPOVf+wRU3N6FDcBYK@g@OK=0us6oYam$+YdnIHuiwiN47l{qk-W7Zehg12#g zw+&izI|{XtyYcQeBwLZ(gJe4r^q`P?fm}>6GbupfA^YIDs`Yp}4(pgHX25xt)Hu6h zx==pvpeoG`rfim(Oxt=XK(#tpd9~)j^{GFsG``u*zX!5bOp)`ZPm-EvP&@{{&=h8shsK zH{aAE$pg!ezPs^`b5m!hUVU)+=+(^B%5D66D>HT703NzA8khYjM*j!U{uhf|x(om2 zgaaLF@e;?>CVaGTHyocs;EWge3^gmxc!_L*$P9M>T#UCZk1Yd`+Tgt4!f23H-%3EP zgvezA$|nKkQ*bxtnr_lfDJY+I(`R&Nn&P~)G6U>rULW-pp!=Doxd055@5cK9a9)-O zuwc7OP@L~sM)4IS6x+Ht=c5WA0h~+i5^%nE>#DCH5t3KPgphnfa)soRa-%noQf~Cl zC*(#*yFx~Uwe1BapHA9kh~{v*L|$NS*}J{DmT+^1`)&XJL*ERYv9W+yjDVUdeU) zPQWoThSH3=MaC)(+dE})Fb~iDf-n!F5gjNf;iv;8my!+?@(|t?4is`2F59t%rF@>! z2}<`;x_>%1sEAG~H4$T3#b2idk(+*m=RlC~q)SR2B9k1$da2y>+$sg%)uiugpHIS0 zIVc7c3Cq1M_w|yoRKT)9(mb+00jpGT5IAt&0V1U6&7uYwfxNY1%ClGlSYs*B z-%e$#8wI%nb{!^^ui#4uk2*-m&ru19D+V+LNo}8m>=}3@IpHV=5>l23kYU{=sD$iU z1~dg}3H8hQsKN&>E_EcIgzVkAPeNLzvg{rZRG{=-7zbh@A7pfI4n$7qD>EdokiAy& zHNy*i5>itsi(;Z91%^JQKsHg|dn>V&%B7o=z2y@`87+Wd%G!-m7D9m_xTj$gw$^{* zg-t(y?BAQG4}=E;_!%1)91VKltbfWSC4JL?o;(m6t(J<;M z&rb>$jZT^I1?UCxMIh5)3gN6*{Ou-hqaRt&6<%OtWiv?7hEl!lUJMd@_5EZD9tW*y z-rEb`%ZpBPw($-^TBSxfDlUs z!{4k)F!=Uc1h}%S2v<&h!lQT2)!k{HqN{0PyFnB)qjy{CzK7m9SK}&FP;nup8XVNr zh^`_0MHM#zFyAe1)+L}~Az&&^4da%H2Vp>jiU*Ms!~-r^rD4p{6yzTDnHm!21Nu(7 z7(+Qw1ynqwuY`cI5}v7%=sVO$2rWK)z&(bpFi?4x%v`d$XBDC2?XP&o&RS?$%vWj^ zr%@@75!AM5NOAM)iX6G6)&!9XJ<2g58BpbvMg!?rGMJsdBU!za47;Fi2ce6*e^|Ut zu9s!t>zU@G>oX2~l{}4HTm=GBLGHqWZCQQ;!54_k?1M`qW6O#4l=;xyvvS-2x}^pV zD=Ibq8<2S|3or^kdu|x0Nu++=K%n~TO>i@`VoH6myQwpQfgl%U3@ZjMAH$753oWDD z$N5Qbr-cPe=CJFGmQi*>tYz65=KxJXL6=hP2mFya1B?op6A&;Fl2_u7`C5dzI#8{5 zO>fVa8h9vM4N`-}!hwQDX+Wa{piw%PY3}te6VKP_&aSLgA)rIThNNtLCR?*lz~&@) zaDtSq3T#*Pb~t<$x`k=GI1Y09$NE9oy)}%U3fQWZJxEu5!Fkb6m+J%;e%TJE2y;`Q zSKGHW3w+?z%Pqi1EdREnR=vSu0khbo)J^Zf%k4;ZAi>Z;aGj&`iul#icFd~G%#f4t zMDh^)l#jsLtj80^X``3gXy?vyrw>ZE{ z+YnI4CuXRBxvK;K396y}a-{lV12}TUW>bK&-d#_pH!IRo)VaZHGko87ENoe09hS;T4!`J0{*LL-{0kSjg(vG-xr2q=tvPc z;k&?Z5!|vq5uMM!d+6-xDiJC~hohy?y{x0E`=p!hgPQ5Bm=q=pmxX7Yj2Gf}Mj3 zb`H3MoNtVN1Y>!tHTXL?8;1HGiz*rh;pTuCgKEwoV^!pYu{ZSlHXGo(dM+GgG3II5+= z&R4#hg@tDX7}l`vHueo#TdIW-i>ZnOyZ;-dTG?iLCaOU>gYMPs@U)UTJZe6Ljps$p zc;@+j^Pnly_A&6piw}Y3*g|0*&UikEwN!ba9@m^f!#>8zx1kYwTA*~2(tQMVnC6h; zG{{3t3w?HAy7`oHUvMTF6MS*D1Xq%&PM**K4Bdc%MCf78vknnnXEivb9cLX%T98@M z=>0o>+VRo?{7#G29^}83@Jzj4tuIJ*pY9&YZRphbb69Hy$s97HY=ZXo{IHkLPxyjU z_L2q6t;T@vC*SM1RTyWE?)%|Bp-v2(1l)cJANJoR^YJT65YN@y?vP`%}zIKGUx+Vncz@~*(%1x4SMX3ld z8(L}by{DBnGHxcYjRD9vz&*yTL3a?^7;=Z8jT9Jd+o*ve6{G+G&dbuUSzOAaX)eHr zBRS?h32LV-5unC(m!N8A&oa_ew1$lx>ZXSP&SgJ>%HF$m)l=X1Mvjm=r6S$CZK+5n zyanG4uz?1}N=I@x@Gk%+SCU+(#>MZ|SAJ23;P$_LNBwrU?JK`142a%e`9-1LhU629 zB$!KVBlYkqUwK%5Wu?v_BPZ5nE{#xgX~Z*^HcHb*sR@IlI+DAQx5D-^gM6s)It%sSS$op79YjSqezY+`3{mFB6$Uf zPZ9NPQ!n^b)`Bm(y5L1udpnjqb(UQf_Ntm}!54X5m~v2BacxySGR<{IFU^$>S{k2s zcztTCs9*`0aB@DKb=6<~d2{>a;r;mAKhrCJsDE#p(^OrY&TVX)MfwYK#bP^EEY8F4 zK~!xx&lHQ#!4F{KZ_mM1oc)$vJ26vv9s-zbg1<0Ab|X1}y10(#`sMr_oPl= zp0<|Dv`zjqd~}w!Hy7EFUtJoWVI`E#Bc~49Kn|A2e Q@l(3~q<%xA2U^Mh0gG9rm;e9( literal 0 HcmV?d00001 diff --git a/tests/__pycache__/test_manifold_muon.cpython-312-pytest-7.4.4.pyc b/tests/__pycache__/test_manifold_muon.cpython-312-pytest-7.4.4.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e46e48ddd5276bb8bca5d3cd2baa94e672b782a5 GIT binary patch literal 18350 zcmeHPU2GiH6`tAInVtQy6FV4U+Q1SBWJzLs?fg=bK+BJS;3q33UMU&5>6uEoz{)V{Rge~s!u1Ksq)d!OOl3+Ji?OsJACt~_*<-2^ zIVJ&P&@r$T1*#h=(72%iO&BqtNka#kGU7nHj0DhbBMEe!kpjBj=mOeP>&~om*FR}m z_TwI#k0Y0|(7yud+u?t463Anc44m5j{VDzoLIp!UFGn(U>^z{r>YuK`B)ccl>$pB{Kd$G0v`KR>n zA4vbuyXN)&vwHu@&FA$^Z|hwbf{eHHP5-g*OM7u>SZ*3mYy-!1@fjF9`26Gtq~r2; z?~tU1+7a~atI$3@oQR|F53><&Zk`Ix{L6h_vigMf*9Rqd%8Bx?hlnM6`TgJ= zlq}28f53HES(UeHs8@_r|2uEXsb3N}RSoSmr4H|*5i??j?!<~Rz0&PI{%pdt+{p8S zv!nB8;pZ1+BYsqFs15D7{Ik2FQZ1fI)bDQkveVXNWu|1LtHiWS^1PWYma8*1)v#gyN%hiCDnUd|Q^j=1B6)1{i7|G6hcisY4ZQ8_cC=Z#t zN9R%mCy=X>eAy@yhyYixGCg8rmgZ?6L4ePV<8#v6;>E$(mkh zj=xC@o&jEBePkJ8wXcLT)p`ZU52Sbd_Ag4&cxLv&3+ww%BwtRR{MH-mZ=bz?$yMLx zQ`uAbbA4N9AG(;@I)8Qk-b`)%Fb=#d0%uYac*Y=uf@hS1XKOGtWFz8(UmI|3DjN|1 zRCavR{*+6z=x8VK2kKyG)z`fESS~+12Y6I4s4hv z8-UEM3cq$`EIPsG6mk=`un`CX)UL}SYVBdDwP%!k0ekQ;0`&)SClZ7Rau*O08g>$p z#9j`=hutDR>~}AcWgkDP1yuC?vop~r zuR`<%WtWIPBl5a(27860q1Y?|G+YcuBoKoB$A@-Bm?hiLpq$1(=SIwlrKE=L=+^&? zC}`)X@X%u_#Z$qV*8)4D-$rGX9#bK9M8K;r3dI^B;%#{i)YgSwLh1xI2nf$kIqwtN z4B*5=w8MX_q1W_GyuQ6j?d5h%xnhGYQ!!==W_mO~ntncCnn3}N{Vt%!H7A2%VVhN} z{(@&mU==h+_HhdZ5}t(uf|?mEknplV!m}($cv&l$HbtwPFDDvXvF(wJl}R)P^5q(2 zK#CeGOjFRBHH8r}1YA;#LB(O;Xoguf>f7!m;GWNIK2Tk^T^Mqg(;Qsa(Ii z6HSAaMWz=|1L;`FI6Zv6X8_fV?eo|C?+tzX8YW9?n7oF`9bxh#pD)LT_rl^yiaH|gmcH&00F~VavYQdP7_A*igDoUmTfukO9BVU zHr)FD@HXXqt>ZSyl|tX-(3!oS{#*Sa#@^_`)bYN1ko z-Xs%d+ODMYY0IYm;ykIPtCeEe_B2)lx|;a;zy5IQqrd$A09dC7;M!H}N@)R?xZeNo z+=p*Ie)#JLL8*=AA-A6QbXwFo_i{}Y60J2=FiZ=jie(1ixSeXm;Ems&#&-8(d*Ey5 znyaf#o9H@ob#Nw??V?@F2x=_34|pRuf&_;Q)!4nGnQo@@25PMQq^p=UlljJV&O-;_ znn{Lfw66X%xa)D{fP4$PehSG_Cz|I{XQQ-7^q=js&`b+7T2G}7^PurJS1y>HdC-Ik z=sU390n)j0JM#Mn&iCw_)!yCEKdb+(r~kywFW-D}@@&tYv)Y9X8&B+edEd!resk#6 zL#OXQyJ_IuhQV3=qJGc(r}uYOzh}8rLW~D_BrT?{sZ?t6QUt4?;Ji1Y!o6sK#i1Sy z6{LvjC}7odSqUqm;GF;;1;~+`aP3ZTeTPV{COOh_MKs`qBMQiophk#{T3&-%{VnSR zNN&VB=DbhfBV;YN$v@G$chNrqYn@k&1Aid`dk1x=t9M|Q278@Ry7ND+O+d&SgM4@N82P!EKIy`!|xvcd;Qo! z2zXH7{sc@cI6gRd`DqXTxlaQyn|UiB@v&uqkNPo&uUXei49}^7^I{3=kIp5R06rN5 z9y@^d!uJ#L}A9VE7Lg`uKGYH&WNBqS|`>Yfpmb` zy|_gy8V3bz-;V&ePsTc|DD%I*1#BEop`K)u{at5Yy%Xof$@%( zV4N>iCzru7c>=~0pjf4MCw}!R9)CJW4nsD>zyQEZ$Qvc&cn48C47_e#8N6r*^`TY3 z7p?IS_OvbkZ~$OmIo~_924Kx=UTfMUd{yjN(wcnPMMD;mNfZdq{KvSxsWIcOW25oyG589#7THXsDO5vK-p)X|;z>zeO| zO*jcKpDR@B2+Xw*)3Qlp!w3y}Ooc3c2mJ9V3FdQ9Bg8~oUV~2JmURTciwg!|YtH)w zK0-JTLI%KiZQZ+QyvCc%*VY_}BN4*E@{zbIW9OUGZJyG1b+;+!o7ip2`EK4e<$RZ} zk$5dm&Hbc}!F==(>SAo)lEX>zua2L=+)r|IpSy+O9B(9Q$xN!g`w8l? zla+}|IbTX|EmTS~(`73Sfyykui~+`_4D$l8pv7rmfv%rXEtGr{F6+M*yuev-;xZQi zraqnzAPbS&UK^0|+6ZwiKnDZh(e>QO;C2Sjf#5*SLZAfqB~TwNL`VBkl1tnNNG}wao}GNLnHPzMMGqp zu11<-BX(Jc@`5MXb0yyuqSrHrO0H<-g7;qE+lhW3RuhN9dJjkkjtUl@=|xVw+jG;2 zjW2IJ`Ow*(%@{{OvXB>r-9+c8;I12_^l6@|F&1DP*q z5n<*0C0f`d#9YvE(Bph3wBNPl_)<=WC5z)S=>_fT`ukF@oXbDrf3xB@6*h)`N!0H? zYhTLu18peKltZ5K`@S~hFIFU0_)>g52ex4bZH8#bj$$M@1{Ecy5gAD%*eaI=e#Yg^p(OTuhe_gR|_@Ls*G#KM;wG8+PXHP(B$xvt}Hbsg`6 zpS^FjyAg%3j-D0v7v&2DvuZ<72W(DG+h*CS5HI|={^-h489Yds#a*nRT0co!OEf@- zb`wPlHZ0nCh(TbAEL!Usm2Nm9_dz;{zo4pg{TtYM)^8b8o+6LhJ?%f6#dr@|F3e~^ z^&M8u@QEl7bgvS0S%u!HcAh2AVaJ4D?-z;vcAvnIDXtx~5N!Xy@wD6i^UGe>L*-l&{O;w7>)LtYQAw8Pu zb@d$G!k^2ziCk_Pwg{C>NT+f+c&HL??uRRl{aAn-8!sNUXNXDAl|!~6*@NUBB#$6@ z97ztzB$8*5l#v`k@?#{kK;~Hg6paZk8B&6C&w3KbG3jq=@+XIWc<953qTaHoNZJh_ zVCqsrQ!|&=tLmMXx;?T3Q$XVCmoD8HQ+HnK70JO%y_!0-h&AW~({3^;SLeTNKt-I} zpJNX=z$6d+LcstrkC37bzii-VVR#SCV|$8zhFqh1;YQ)bPr5x8DGW5dEcSB?^yyy0 z!Z_mei7)H-vVlH8ehGP~vst$R`9P6n`6534kGkHLQg2Jiccg6}X%n)1=Oqa#ZQ}nu C4_^lW literal 0 HcmV?d00001 diff --git a/tests/__pycache__/test_msign.cpython-312-pytest-7.4.4.pyc b/tests/__pycache__/test_msign.cpython-312-pytest-7.4.4.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0516eb6ec937b1760ece6952d8c482a47504744f GIT binary patch literal 16498 zcmeHOUu+yl8Q;CTz1#a^r%vN0=^xl8P07;uoO2u}L3P^nFG*V?niN_?Ih}9U$vNjc zcXr+6u6L+{RJat;pa{jFfck1uUQv067kFvffTbl;K!75@Ac!R@zOOGK!S>#Ft``2Dl%&7n0pB8SV)bi~I4>EJAs3}d zneIcg;e1GzW+Rh|%+5q7qx4K{G6wrlQJsxX#%B|giP_|2QkFu}(~=QBBN>s4A=d6> zsv_&k3LZ@8VMm#@3dc*1dg!sov(Fwrdd%6tQ`w`(4n6kllX+H+N!dMEJqC&Ml1ckD z88YNE(qz~OfsEkQCKV$BGHNIwV@4FDYQ#Xs4Haaff$XpJ>HSXp zm}%M1VXyN1a(Go0{$cp9j=+t;u*gqHXXTd%B&jajUbz|fCb0tFYVr@IddOD1-)V9DjVUg8p^u$ROW4V?WtE1 zda4-F3*jm}hkDeA8EP$Bkm*s(9_05X+{;}&&)?T#{JMPq9odMVmg|waa#sGqrifID z>xt?%`l@TziSm5W(B_C~ndDVdE0pHuZOF~dnQ2|oLo4_M2@akfqPrD*?I4ZM-pyaZ zw`m1mHpow5KFnfiNxl&pAJ^sT(2kvRl@Z-ayyhmfaSc1MY~Xfix{Aji%I+g09_(QUteL&y!9EM8yyk8J6EU{hG%o7!gRT4Si?93OlIkmPtR4HyXm;l!z_OTEVy%c_|*qI;7PYH6@xxm zZ-97R`t#uUsuYgv3r}9TWAJ?PN6Ez(-oIn}!l5gJcQ0nny}0ny)zr@APv`eRYUjUV zaOqX(h|~a+*gp!Be*};_FLj1SLpDNOF93}C5a~_NdZ-o(gvhWF2ABlU^I-Cq!XyoX z$q0o>1?vbf8C^H<@nF3x_;@7&e4;|HgJ9AF-!A2RIE{H=rW)~e0<#2wSs4CGJyJ>N z$?EO~C_4f3zi4Ww3Wlk{WF&>WX=yw2<)Wn>)JXZ1rKcHqAxLo`!!$t^wX`O@+N)c1 zLNA30=w6B7m?0qTy&lVD33C+hJWKEVbX5zS#X<>G#;XMouS-7_ zRZU0&rQy9e$j)bw8Z>Qn1qnUOfec6yyicZaqm3GO57~vM5Cc#FADGZntng;c#tG^u zv}Gp@t&zJi_Z}2mP~3~+J`|rr@p%wS@gOx#cEc%ASG$^-jmdsyO}NB$$pIMWVEBXTWE{@+0Dt5ll~5TMEhdX*d<5EP+JN7PW<0t8=vEWV_--t19j_Jkksy$#MS!P8Y9Zh_ zuD2GwtY>9wvk_hc+Ywn;0Ne5V!0+s1p;4sE-1XQGcu_xFo(5jjb?brW^ytbIy^_?G z;2I+#&e_=PILYkmyX=z?8}t%a-Y5(=JL4!JW%Vn-gPh z-3(8R=!(za$d};aWw;c_&_(6$H+tzi_a+X9JNLm70$|yyWXN4tWQD1Yzwx5h4U_I^76Q^ zM+~rTD6V?53|3Q=dW_l5wP-E&ZuBDd980S;74WuF7DW4za;b&mNI+pfjANYJ-A6@=jke8JOZ8z z_)y&S(JtjYOzu+7Tb=7t&a>W)#DzH34{;j-l5RrXO+3%v*W$d-d=L5|*=+g4ajdRZ zl6tE8$Wvva<+Qo7Rj>=N2xyqopet9^qrD2!GQPpL@7r`|ZbuZ}gw#o_J3;Uf2RN77D2$iD~OG)Z*aKdmMw*2@U zj72XRDj{enBn?zKkqJnNo`!wm!GQ)(L#E+qPlMlKuZ2kF`5j!xgQ(UpxB!;_U=5Mj zl{RXvq09(0$kiGsvcqp4KAU^%*~QGI)V=Pi!8@VL1KSs(*VOUlTl@D1b-aUG`^}ov zS`s2oJ`2TGizdOf3+hdE)myBl3hFIhi{Dmn@W#^3S6V>kZ>u+4B3NJb=BceNqwA?C zPrW62R&NR3XTEQz-ay5v^*B{=k24ilGLJ)8Mo%ivE6R19pytR|;cX(%qd11*1rXJr zi@MG8cW1gfyty$WT2C})X1$af^P>B{^A5_5d>wD+MG#c0?V)n-1o;N!=*hNvioA@~ zzJ(%-b!kwI>jocbI2vtGPGc6j34G<`YNvYzql-U=EGH$^2LlbSmU8F@vwjBxjIX|p zZ@={BOJ`qs>zfPDEgrh$t!{gI>fo|^tN%Vw54Nf*FTmx@I;}po3vlsSd9W9&A^7Fr zoqaD}j9_E%i|AbnVB&)HH7K+n;UBP1*6+c5L5O_(zU(LWj*f>c(;7tPUoI( z6|Nu9Q-4%Bl!2Ytu@|j1py#5uE$A*iTL)qbBh1&-6kCYC6}kym$Ca_Bt!-=}568jC z!m$SL)4XMBQ`5zA&d!WMbRqbmN&t8jZ3@B+iD}Q1l19(E=!!ZSVg`>OD{r$vXxNoG zVDda>unpU%jqVwNHA70Nap7<&B9MBtrKq#1Ayfw}0}6o?dSA=(A-%2>a~oz0hg-1~ ze=#v&9Z?Sq_o3{bhWSvxlNjD}z(U73$3*6Z5&@A9W|j|T8$fI+_b4=LeGkMfvU=D7 z&A$JM(98(E>wDA1SZ9a>7S2EvMWA;H0{H^-74Xw5G{6s{DXf2V_IDUcEedhaQLG~b z_{9WT6yW=4z&Lcs(k&?RqJhZzB?Bnn>ft7!ap#_TC84LH zKPjlE9&t6|E~dqdgpsVp+<0j*PQt%xuW=G7^rpi<3Lf>8-dEk#fN9_knKSdYmY=um z@~lRE9U4RvLwqpc8$+@cmAm|QWtu4x_M_qHMtg| zAOW9KfZcu8_7)(6h?61L!T=MrmH-oL3DEM}516K{bO8v4Un79LwAO&hxVEK8zx|*d zwTx)j9<|hf9z{Ceg}>OlreL?uJ?9Tw;!nG4&s%EHu@vcAxA42pJ!*OXy7s802J}ej zGL{}m2KvZ?&?1yRMqH>>1k2s;Zx9jnK%0*sSnl{~iVXj7Jw!iDd8p}=EADkZkpr_m zUnT^Rw`i7*+b7aA7!>H@1fnXKxy%&|3D*-0K4BV1mZ9!l&KQB@v4W6r3xbhO3j*eX zAJ_>)`vgA#AUIw+i#0>XZ-%W@(pWRRL5jH54Wi`_*Q5sY1lY^-Sj20mXUZ6**e&{d zN0^1Z8{0#?CXa!z;42}QC^DLOzkt}eseHBw>%D;4xv8R=Wg}L7C=k1Y3x5T#@e2^I zOV`xVF?ML8Ic8@YK+QjEJa>_Q$5nytSpMgNE2W5$kmhi345uIt|&zXJg6zRq=owc%U+ZSvldE_dX+mOGA?zxrwPyXq<_ zMArCqgKU8$A7Q?4axG(&O}w4}m-2?{_6C~(zr16SHCCZ?96mrouIzqkMBqiErU3uk zD1jMh5p?*FVLzUx@y_VU)cc$K5%hz?UNs7N7#qXTKM0G~Fiq_O>KWOFhg7j_$5aNz z2#QBgJc{BF3aW9Cu5i}`?wQ^ZzdmAjPc4NuUmpo2jzV+RUqA%d3HbHVJM!g$ZR`30 z($$+D8x7IeJUn_r-{7cO`Wm;(-S%|R*-JEq&+3YJ2=!9RGX0zlAINp0ujG=fO z#nUKIBakSJIf(RoI?9#RzjdVfTD@OsJ>Lzn`Q zP rows).""" + G = torch.randn(3, 5) + result = msign(G) + assert result.shape == (3, 5) + + def test_msign_tall_matrix(self): + """msign should handle tall matrices (rows > cols).""" + G = torch.randn(5, 3) + result = msign(G) + assert result.shape == (5, 3) + + def test_msign_no_nan(self): + """msign should not produce NaN values.""" + G = torch.randn(4, 4) + result = msign(G) + assert not torch.isnan(result).any() + + def test_msign_no_inf(self): + """msign should not produce Inf values.""" + G = torch.randn(4, 4) + result = msign(G) + assert not torch.isinf(result).any() + + def test_msign_deterministic(self): + """msign should be deterministic with fixed seed.""" + torch.manual_seed(42) + G = torch.randn(4, 4) + result1 = msign(G) + + torch.manual_seed(42) + G = torch.randn(4, 4) + result2 = msign(G) + + assert torch.allclose(result1, result2) + + def test_msign_sign_property(self): + """For a positive definite matrix, result should be identity-like.""" + G = torch.eye(4) * 2 # positive definite + result = msign(G, steps=20) + # For PD matrix, sign should be identity + assert torch.allclose(result, torch.eye(4), atol=0.1) + + def test_msign_negative_definite(self): + """For a negative definite matrix, result should be negative identity.""" + G = -torch.eye(4) * 2 # negative definite + result = msign(G, steps=20) + # For ND matrix, sign should be -identity + assert torch.allclose(result, -torch.eye(4), atol=0.1) + + def test_msign_bfloat16_internal(self): + """msign should use bfloat16 internally but return float.""" + G = torch.randn(4, 4) + result = msign(G) + assert result.dtype == torch.float32 + + def test_msign_custom_steps(self): + """msign should respect custom steps parameter.""" + G = torch.randn(4, 4) + result1 = msign(G, steps=1) + result2 = msign(G, steps=20) + # More steps should give different results + # (not guaranteed, but likely) + assert result1.shape == result2.shape + + def test_msign_abc_list_stable_length(self): + """ABC_LIST_STABLE should have correct length.""" + assert len(ABC_LIST_STABLE) == len(ABC_LIST) + + def test_msign_single_step(self): + """msign should work with single step.""" + G = torch.randn(3, 3) + result = msign(G, steps=1) + assert result.shape == (3, 3) + assert not torch.isnan(result).any() \ No newline at end of file From 5d616326f4eb93972303dd2562670da156b6d05b Mon Sep 17 00:00:00 2001 From: RoomWithOutRoof Date: Wed, 15 Apr 2026 22:48:04 +0800 Subject: [PATCH 2/2] Add unit tests for src modules - Add test_main.py for hyperspherical_descent - Add test_manifold_muon_extra.py for additional manifold_muon tests Good day, Thank you for your work on this excellent library! Warmly, RoomWithOutRoof --- tests/test_main.py | 42 ++++++++++++++++++++++++++++++ tests/test_manifold_muon_extra.py | 43 +++++++++++++++++++++++++++++++ 2 files changed, 85 insertions(+) create mode 100644 tests/test_main.py create mode 100644 tests/test_manifold_muon_extra.py diff --git a/tests/test_main.py b/tests/test_main.py new file mode 100644 index 0000000..0bc7162 --- /dev/null +++ b/tests/test_main.py @@ -0,0 +1,42 @@ +"""Tests for main.py entry points.""" + +import pytest +import torch + + +class TestHypersphericalDescent: + """Test hyperspherical descent optimizer.""" + + def test_hyperspherical_descent_preserves_shape(self): + """Should preserve input shape.""" + from src.hyperspherical_descent import hyperspherical_descent + for shape in [(4,), (8,), (16,)]: + W = torch.randn(*shape) + G = torch.randn(*shape) + result = hyperspherical_descent(W, G) + assert result.shape == shape + + def test_hyperspherical_descent_unit_norm(self): + """Result should have unit norm.""" + from src.hyperspherical_descent import hyperspherical_descent + W = torch.randn(8) + G = torch.randn(8) + result = hyperspherical_descent(W, G) + assert torch.allclose(result.norm(), torch.tensor(1.0), atol=1e-5) + + def test_hyperspherical_descent_no_nan(self): + """Should not produce NaN.""" + from src.hyperspherical_descent import hyperspherical_descent + W = torch.randn(8) + G = torch.randn(8) + result = hyperspherical_descent(W, G) + assert not torch.isnan(result).any() + + def test_hyperspherical_descent_2d(self): + """Should work with 2D tensors.""" + from src.hyperspherical_descent import hyperspherical_descent + W = torch.randn(4, 4) + G = torch.randn(4, 4) + result = hyperspherical_descent(W, G) + assert result.shape == (4, 4) + assert not torch.isnan(result).any() \ No newline at end of file diff --git a/tests/test_manifold_muon_extra.py b/tests/test_manifold_muon_extra.py new file mode 100644 index 0000000..f421a61 --- /dev/null +++ b/tests/test_manifold_muon_extra.py @@ -0,0 +1,43 @@ +"""Additional tests for manifold_muon.""" + +import pytest +import torch + + +class TestManifoldMuon: + """Test manifold muon optimizer.""" + + def test_manifold_muon_preserves_shape(self): + """Should preserve input shape.""" + from src.manifold_muon import manifold_muon + # Use tall matrices (rows > cols) + W = torch.randn(8, 4) + G = torch.randn(8, 4) + result = manifold_muon(W, G) + assert result.shape == (8, 4) + + def test_manifold_muon_wide_matrix(self): + """Should handle wide matrices (cols > rows).""" + from src.manifold_muon import manifold_muon + W = torch.randn(4, 8) + G = torch.randn(4, 8) + result = manifold_muon(W, G) + assert result.shape == (4, 8) + + def test_manifold_muon_no_nan(self): + """Should not produce NaN.""" + from src.manifold_muon import manifold_muon + W = torch.randn(8, 4) + G = torch.randn(8, 4) + result = manifold_muon(W, G) + assert not torch.isnan(result).any() + + def test_manifold_muon_orthogonality(self): + """Result columns should be approximately orthonormal.""" + from src.manifold_muon import manifold_muon + W = torch.randn(8, 4) + G = torch.randn(8, 4) + result = manifold_muon(W, G, steps=50) + # Check orthonormal: Q^T @ Q should be close to identity + QtQ = result.T @ result + assert torch.allclose(QtQ, torch.eye(4), atol=0.1) \ No newline at end of file