diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c537495 --- /dev/null +++ b/.gitignore @@ -0,0 +1,191 @@ +# Created by https://www.gitignore.io/api/python,pycharm,jupyternotebooks,visualstudiocode +# Edit at https://www.gitignore.io/?templates=python,pycharm,jupyternotebooks,visualstudiocode + +### JupyterNotebooks ### +# gitignore template for Jupyter Notebooks +# website: http://jupyter.org/ + +.ipynb_checkpoints +*/.ipynb_checkpoints/* + +# IPython +profile_default/ +ipython_config.py + +# Remove previous ipynb_checkpoints +# git rm -r .ipynb_checkpoints/ + +### PyCharm ### +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +.idea/ + +# CMake +cmake-build-*/ + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ +# +### Python ### +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. 
+#Pipfile.lock + +# celery beat schedule file +celerybeat-schedule + +# SageMath parsed files +*.sage.py + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# Mr Developer +.mr.developer.cfg +.project +.pydevproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +### VisualStudioCode ### +.vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json + +### VisualStudioCode Patch ### +# Ignore all local history of files +.history + +# End of https://www.gitignore.io/api/python,pycharm,jupyternotebooks,visualstudiocode + +/data/ +lightning_logs/ +logs/ +outputs/ +*.ckpt +*.pt[h] +*.html +*.pkl +*.pkl.bz2 +*.pkl.gz +.flake8 + + +# Created by https://www.toptal.com/developers/gitignore/api/vim +# Edit at https://www.toptal.com/developers/gitignore?templates=vim + +### Vim ### +# Swap +[._]*.s[a-v][a-z] +!*.svg # comment out if you don't need vector files +[._]*.sw[a-p] +[._]s[a-rt-v][a-z] +[._]ss[a-gi-z] +[._]sw[a-p] + +# Session +Session.vim +Sessionx.vim + +# Temporary +.netrwhist +*~ +# Auto-generated tag files +tags +# Persistent undo +[._]*.un~ + +# End of https://www.toptal.com/developers/gitignore/api/vim diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..f288702 --- /dev/null +++ b/LICENSE @@ -0,0 +1,674 @@ + GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. 
+ + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. 
+ + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. 
Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. 
+ + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. 
+ + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. 
+ + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. 
+ + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. 
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..b192afd
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,9 @@
+prune outputs
+prune lightning_logs
+prune logs
+
+global-exclude .ipynb_checkpoints
+global-exclude .git*
+global-exclude *.ckpt
+global-exclude *.pt[h]
+global-exclude *.py[oc]
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..f8283b7
--- /dev/null
+++ b/README.md
@@ -0,0 +1,116 @@
+# Don't PANIC: Prototypical Additive Neural Network for Interpretable Classification of Alzheimer's Disease
+
+[![Preprint](https://img.shields.io/badge/arXiv-2303.07125-b31b1b)](https://arxiv.org/abs/2303.07125)
+[![License](https://img.shields.io/badge/license-GPLv3-blue.svg)](LICENSE)
+
+This repository contains the code for the paper "Don't PANIC: Prototypical Additive Neural Network for Interpretable Classification of Alzheimer's Disease".
+
+```
+@misc{https://doi.org/10.48550/arxiv.2303.07125,
+  doi = {10.48550/ARXIV.2303.07125},
+  url = {https://arxiv.org/abs/2303.07125},
+  author = {Wolf, Tom Nuno and Pölsterl, Sebastian and Wachinger, Christian},
+  keywords = {Machine Learning (cs.LG), Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences},
+  title = {Don't PANIC: Prototypical Additive Neural Network for Interpretable Classification of Alzheimer's Disease},
+  publisher = {arXiv},
+  year = {2023},
+  copyright = {arXiv.org perpetual, non-exclusive license}
+}
+```
+
+If you use this code, please cite the paper above.
+
+
+## Installation
+
+Use [conda](https://conda.io/miniconda.html) to create an environment called `panic` with all dependencies:
+
+```bash
+conda env create -n panic --file requirements.yaml
+```
+
+Additionally, install the package torchpanic from this repository with
+```bash
+pip install --no-deps -e .
+```
+
+## Data
+
+We used data from the [Alzheimer's Disease Neuroimaging Initiative (ADNI)](http://adni.loni.usc.edu/).
+Since we are not allowed to share our data, you would need to process the data yourself.
+Data for training, validation, and testing should be stored in separate
+[HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) files,
+using the following hierarchical format:
+
+1. First level: A unique identifier.
+2. The second level always has the following entries:
+   1. A group named `PET` with the subgroup `FDG`, which itself has the
+      [dataset](https://docs.h5py.org/en/stable/high/dataset.html) named `data` as a child:
+      the gray matter density map of size (113, 137, 113). Additionally, the subgroup `FDG` has an attribute `imageuid`, which is the unique image identifier.
+   2. A group named `tabular`, which has two datasets called `data` and `missing`, each of size 41:
+      `data` contains the tabular data values, while `missing` is a missing-value indicator that marks tabular features that were not acquired at this visit.
+   3. A scalar [attribute](https://docs.h5py.org/en/stable/high/attr.html) `RID` with the *patient* ID.
+   4. A string attribute `VISCODE` with ADNI's visit code.
+   5. A string attribute `DX` containing the diagnosis (`CN`, `MCI` or `Dementia`).
+
+One entry in the resulting HDF5 file should have the following structure:
+```
+/1010012                  Group
+    Attribute: RID scalar
+        Type:  native long
+        Data:  1234
+    Attribute: VISCODE scalar
+        Type:  variable-length null-terminated UTF-8 string
+        Data:  "bl"
+    Attribute: DX scalar
+        Type:  variable-length null-terminated UTF-8 string
+        Data:  "CN"
+/1010012/PET              Group
+/1010012/PET/FDG          Group
+    Attribute: imageuid scalar
+        Type:  variable-length null-terminated UTF-8 string
+        Data:  "12345"
+/1010012/PET/FDG/data     Dataset {113, 137, 113}
+/1010012/tabular          Group
+/1010012/tabular/data     Dataset {41}
+/1010012/tabular/missing  Dataset {41}
+```
+
+Finally, the HDF5 file should also contain the following meta-information
+in a separate group named `stats`:
+
+```
+/stats/tabular          Group
+/stats/tabular/columns  Dataset {41}
+/stats/tabular/mean     Dataset {41}
+/stats/tabular/stddev   Dataset {41}
+```
+
+These are the names of the features in the tabular data,
+their means, and their standard deviations.
+
+## Usage
+
+PANIC processes tabular data depending on its data type.
+Therefore, it is necessary to tell PANIC how to process each tabular feature.
+The following indices must be given to the model in the config file `configs/model/panic.yaml`:
+
+- `idx_real_features`: indices of real-valued features within the `tabular` data.
+- `idx_cat_features`: indices of categorical features within the `tabular` data.
+- `idx_real_has_missing`: indices of real-valued features for which the `missing` indicator should be considered.
+- `idx_cat_has_missing`: indices of categorical features for which the `missing` indicator should be considered.
+
+Similarly, missing tabular inputs to DAFT (`configs/model/daft.yaml`) need to be specified with `idx_tabular_has_missing`.
+
+## Training
+
+To train PANIC, or any of the baseline models, adapt the config files (mainly `train.yaml`) and execute the `train.py` script to begin training.
+
+Model checkpoints will be written to the `outputs` folder by default.
+
+
+## Interpretation of results
+
+We provide some useful utility functions to create the plots and visualizations required to interpret the model.
+You can find them under `torchpanic/viz`.
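For illustration, here is a minimal sketch of how one visit plus the `stats` group could be written with `h5py` in the layout the README's Data section describes and that `torchpanic/datamodule/adni.py` reads. All concrete values (identifier, RID, image contents, feature names) are placeholders, not real ADNI data, and this is not the authors' preprocessing pipeline; assume the image shape matches your own preprocessed volumes.

```python
import h5py
import numpy as np

# Placeholder sizes; replace with your preprocessed data.
n_tabular = 41
image_shape = (113, 137, 113)  # gray matter density map

with h5py.File("0-train.h5", "w") as f:
    visit = f.create_group("1010012")      # first level: unique identifier
    visit.attrs["RID"] = 1234              # scalar patient ID
    visit.attrs["VISCODE"] = "bl"          # ADNI visit code
    visit.attrs["DX"] = "CN"               # one of: CN, MCI, Dementia

    fdg = visit.create_group("PET/FDG")
    fdg.attrs["imageuid"] = "12345"        # unique image identifier
    fdg.create_dataset("data", data=np.zeros(image_shape, dtype=np.float32))

    tab = visit.create_group("tabular")
    tab.create_dataset("data", data=np.zeros(n_tabular, dtype=np.float32))
    # 1.0 marks a feature that was not acquired at this visit
    tab.create_dataset("missing", data=np.zeros(n_tabular, dtype=np.float32))

    # Meta-information; stddev must be strictly positive because the data
    # module standardizes the tabular features with these statistics.
    stats = f.create_group("stats/tabular")
    stats.create_dataset(
        "columns",
        data=np.array([f"feature_{i}" for i in range(n_tabular)], dtype="S"),
    )
    stats.create_dataset("mean", data=np.zeros(n_tabular, dtype=np.float32))
    stats.create_dataset("stddev", data=np.ones(n_tabular, dtype=np.float32))
```

Presumably the ordering of `data`, `missing`, and `stats/tabular/columns` must be consistent, since the index lists in `configs/model/panic.yaml` refer to positions within these 41-element vectors.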
+ diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml new file mode 100644 index 0000000..13f5e67 --- /dev/null +++ b/configs/callbacks/default.yaml @@ -0,0 +1,12 @@ +model_checkpoint: + _target_: pytorch_lightning.callbacks.ModelCheckpoint + monitor: ${model.validation_metric} # name of the logged metric which determines when model is improving + mode: max # "max" means higher metric value is better, can be also "min" + save_top_k: 1 # save k best models (determined by above metric) + dirpath: checkpoints/ + filename: "{epoch}-bacc" + +learning_rate_monitor: + _target_: pytorch_lightning.callbacks.LearningRateMonitor + logging_interval: epoch + log_momentum: False diff --git a/configs/datamodule/adni.yaml b/configs/datamodule/adni.yaml new file mode 100644 index 0000000..ac4316a --- /dev/null +++ b/configs/datamodule/adni.yaml @@ -0,0 +1,16 @@ +_target_: torchpanic.datamodule.adni.AdniDataModule + +train_data: ${data_dir}/${fold}-train.h5 # data_dir is specified in config.yaml +valid_data: ${data_dir}/${fold}-valid.h5 +test_data: ${data_dir}/${fold}-test.h5 +modalities: ["PET", "TABULAR"] # 2 = ModalityType.PET; 6 = ModalityType.TABULAR|PET; 7 = ModalityType.TABULAR|PET|MRI +batch_size: 32 +num_workers: 10 +metadata: + num_channels: 1 + num_classes: 3 +augmentation: + rotate: 30 + translate: 0 + scale: 0.2 + p: 0.5 diff --git a/configs/experiment/daft_tuned.yaml b/configs/experiment/daft_tuned.yaml new file mode 100644 index 0000000..0a62a6a --- /dev/null +++ b/configs/experiment/daft_tuned.yaml @@ -0,0 +1,8 @@ +# @package _global_ + +defaults: + - override /datamodule: adni + - override /model: daft + +name: adni-daft + diff --git a/configs/hydra/default.yaml b/configs/hydra/default.yaml new file mode 100644 index 0000000..e69de29 diff --git a/configs/hydra/panic.yaml b/configs/hydra/panic.yaml new file mode 100644 index 0000000..e26a287 --- /dev/null +++ b/configs/hydra/panic.yaml @@ -0,0 +1,7 @@ +run: + dir: outputs/${name}/seed-${seed}/fold-${fold} +job: + chdir: True +sweep: + dir: outputs + subdir: ${name}/seed-${seed}/fold-${fold} diff --git a/configs/logger/tensorboard.yaml b/configs/logger/tensorboard.yaml new file mode 100644 index 0000000..607142d --- /dev/null +++ b/configs/logger/tensorboard.yaml @@ -0,0 +1,8 @@ +tensorboard: + _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger + save_dir: logs/ + name: null + version: ${name} + log_graph: False + default_hp_metric: True + prefix: "" diff --git a/configs/model/daft.yaml b/configs/model/daft.yaml new file mode 100644 index 0000000..0b07f54 --- /dev/null +++ b/configs/model/daft.yaml @@ -0,0 +1,53 @@ +_target_: torchpanic.modules.standard.StandardModule +lr: 0.003607382027438678 +weight_decay: 1.1061692016959738e-05 +output_penalty_weight: 0.0 +num_classes: ${datamodule.metadata.num_classes} +validation_metric: 'val/bacc_best' + +net: + _target_: torchpanic.models.daft.DAFT + in_channels: ${datamodule.metadata.num_channels} + in_tabular: 41 + n_outputs: ${datamodule.metadata.num_classes} + n_basefilters: 32 + filmblock_args: + location: 3 + scale: True + shift: True + bottleneck_dim: 12 + idx_tabular_has_missing: + - 3 # abeta + - 4 # tau + - 5 # ptau + - 10 # all categorical, except gender + - 11 + - 12 + - 13 + - 14 + - 15 + - 16 + - 17 + - 18 + - 19 + - 20 + - 21 + - 22 + - 23 + - 24 + - 25 + - 26 + - 27 + - 28 + - 29 + - 30 + - 31 + - 32 + - 33 + - 34 + - 35 + - 36 + - 37 + - 38 + - 39 + - 40 diff --git a/configs/model/panic.yaml b/configs/model/panic.yaml new file mode 100644 index 
0000000..475fc20 --- /dev/null +++ b/configs/model/panic.yaml @@ -0,0 +1,103 @@ +_target_: torchpanic.modules.panic.PANIC + +lr: 0.0067 +weight_decay: 0.0001 +weight_decay_nam: 0.0001 +l_clst: 0.5 # lambda to multipli loss of intra_clst +l_sep: ${model.l_clst} # lambda to multipli loss of inter_clst +l_occ: 0.5 # lambda of occurrence loss +l_affine: 0.5 +l_nam: 0.0001 # l2 regularization of coefficients of NAM +epochs_all: 20 # push prototypes every x epochs +epochs_nam: 10 +epochs_warmup: 10 +enable_checkpointing: ${trainer.enable_checkpointing} +monitor_prototypes: False +enable_save_embeddings: False +enable_log_prototypes: False +validation_metric: 'val/bacc_save' + +net: + _target_: torchpanic.models.panic.PANIC + protonet: + backbone: "3dresnet" + in_channels: ${datamodule.metadata.num_channels} + out_features: ${datamodule.metadata.num_classes} + n_prototypes_per_class: 2 # just used for init + n_chans_protos: 64 + optim_features: True + normed_prototypes: True + n_blocks: 3 + n_basefilters: 32 + pretrained_model: ${hydra:runtime.cwd}/outputs/pretrained_encoders/seed-666/fold-${fold}/checkpoints/best.ckpt + nam: + out_features: 3 + hidden_units: [32, 32] + dropout_rate: 0.5 + feature_dropout_rate: 0.1 + idx_real_features: [1, 2, 3, 4, 5, 6, 7, 8, 9] # age, edu, abeta, tau, ptau, L-Hipp, R-Hipp, L-Ento, R-Ento + idx_cat_features: + - 0 + - 10 + - 11 + - 12 + - 13 + - 14 + - 15 + - 16 + - 17 + - 18 + - 19 + - 20 + - 21 + - 22 + - 23 + - 24 + - 25 + - 26 + - 27 + - 28 + - 29 + - 30 + - 31 + - 32 + - 33 + - 34 + - 35 + - 36 + - 37 + - 38 + - 39 + - 40 + idx_real_has_missing: [2, 3, 4] + idx_cat_has_missing: + - 1 + - 2 + - 3 + - 4 + - 5 + - 6 + - 7 + - 8 + - 9 + - 10 + - 11 + - 12 + - 13 + - 14 + - 15 + - 16 + - 17 + - 18 + - 19 + - 20 + - 21 + - 22 + - 23 + - 24 + - 25 + - 26 + - 27 + - 28 + - 29 + - 30 diff --git a/configs/model/pretrain_encoder.yaml b/configs/model/pretrain_encoder.yaml new file mode 100644 index 0000000..07b09b6 --- /dev/null +++ b/configs/model/pretrain_encoder.yaml @@ -0,0 +1,18 @@ +_target_: torchpanic.modules.standard.StandardModule + +lr: 0.003 +weight_decay: 0.001 +validation_metric: 'val/bacc_best' + +net: + _target_: torchpanic.models.pretrain_encoder.Encoder + protonet: + backbone: "3dresnet" + in_channels: ${datamodule.metadata.num_channels} + out_features: ${datamodule.metadata.num_classes} + n_prototypes_per_class: 3 # just used for init + n_chans_protos: 64 + optim_features: True + n_blocks: 4 + n_basefilters: 32 + normed_prototypes: True diff --git a/configs/test.yaml b/configs/test.yaml new file mode 100644 index 0000000..a8b8e5e --- /dev/null +++ b/configs/test.yaml @@ -0,0 +1,21 @@ +# @package _global_ + +# specify here default evaluation configuration +defaults: + - _self_ + - datamodule: adni.yaml + - model: panic.yaml + - logger: tensorboard.yaml + - trainer: default.yaml + +# path to folder with data +data_dir: ??? + +# seed for random number generators in pytorch, numpy and python.random +seed: 666 + +fold: 0 + +name: "protopnet_test" + +ckpt_path: ??? diff --git a/configs/train.yaml b/configs/train.yaml new file mode 100644 index 0000000..4c0f407 --- /dev/null +++ b/configs/train.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +# specify here default training configuration +defaults: + - _self_ + - hydra: panic.yaml + - callbacks: default.yaml + - datamodule: adni.yaml + - model: panic.yaml + - logger: tensorboard.yaml + - trainer: default.yaml + +# path to folder with data +data_dir: ??? 
+ +# fold to run this experiment for +fold: 0 + +# seed for random number generators in pytorch, numpy and python.random +seed: 666 + +# default name for the experiment, determines logging folder path +name: "train_panic" diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml new file mode 100644 index 0000000..8824d53 --- /dev/null +++ b/configs/trainer/default.yaml @@ -0,0 +1,20 @@ +_target_: pytorch_lightning.Trainer + +accelerator: gpu +devices: 1 +min_epochs: 1 +max_epochs: 100 + +enable_progress_bar: True + +detect_anomaly: True +log_every_n_steps: 10 +track_grad_norm: -1 # -1: disabled; 2: track 2-norm + +# number of validation steps to execute at the beginning of the training +# num_sanity_val_steps: 0 + +# ckpt path +resume_from_checkpoint: null + +enable_checkpointing: True diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..80e7c8c --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,5 @@ +[build-system] +requires = [ + "setuptools>=45", +] +build-backend = "setuptools.build_meta" diff --git a/requirements.yaml b/requirements.yaml new file mode 100644 index 0000000..98c6ed8 --- /dev/null +++ b/requirements.yaml @@ -0,0 +1,92 @@ +name: panic +channels: + - defaults +dependencies: + - _libgcc_mutex=0.1=main + - _openmp_mutex=5.1=1_gnu + - ca-certificates=2022.10.11=h06a4308_0 + - certifi=2022.9.24=py39h06a4308_0 + - ld_impl_linux-64=2.38=h1181459_1 + - libffi=3.3=he6710b0_2 + - libgcc-ng=11.2.0=h1234567_1 + - libgomp=11.2.0=h1234567_1 + - libstdcxx-ng=11.2.0=h1234567_1 + - ncurses=6.3=h5eee18b_3 + - openssl=1.1.1s=h7f8727e_0 + - pip=22.3.1=py39h06a4308_0 + - python=3.9.12=h12debd9_1 + - readline=8.2=h5eee18b_0 + - setuptools=65.5.0=py39h06a4308_0 + - sqlite=3.40.0=h5082296_0 + - tk=8.6.12=h1ccaba5_0 + - tzdata=2022g=h04d1e81_0 + - wheel=0.37.1=pyhd3eb1b0_0 + - xz=5.2.8=h5eee18b_0 + - zlib=1.2.13=h5eee18b_0 + - pip: + - --extra-index-url https://download.pytorch.org/whl/cu113 + - aiohttp==3.8.3 + - aiosignal==1.3.1 + - antlr4-python3-runtime==4.9.3 + - asttokens==2.2.1 + - async-timeout==4.0.2 + - attrs==22.1.0 + - backcall==0.2.0 + - charset-normalizer==2.1.1 + - click==8.1.3 + - colorama==0.4.6 + - commonmark==0.9.1 + - decorator==5.1.1 + - deprecated==1.2.13 + - executing==1.2.0 + - frozenlist==1.3.3 + - fsspec==2022.11.0 + - h5py==3.7.0 + - humanize==4.4.0 + - hydra-core==1.3.0 + - idna==3.4 + - ipdb==0.13.11 + - ipython==8.7.0 + - jedi==0.18.2 + - lightning-utilities==0.4.2 + - matplotlib-inline==0.1.6 + - multidict==6.0.3 + - nibabel==4.0.2 + - numpy==1.23.5 + - omegaconf==2.3.0 + - packaging==22.0 + - pandas==1.5.2 + - parso==0.8.3 + - pexpect==4.8.0 + - pickleshare==0.7.5 + - pillow==9.3.0 + - prompt-toolkit==3.0.36 + - protobuf==3.20.1 + - ptyprocess==0.7.0 + - pure-eval==0.2.2 + - pygments==2.13.0 + - python-dateutil==2.8.2 + - pytorch-lightning==1.8.4.post0 + - pytz==2022.6 + - pyyaml==6.0 + - requests==2.28.1 + - rich==12.6.0 + - scipy==1.9.3 + - shellingham==1.5.0 + - simpleitk==2.2.1 + - six==1.16.0 + - stack-data==0.6.2 + - tensorboardx==2.5.1 + - tomli==2.0.1 + - torch==1.12.1+cu113 + - torchio==0.18.86 + - torchmetrics==0.11.0 + - torchvision==0.13.1+cu113 + - tqdm==4.64.1 + - traitlets==5.7.1 + - typer==0.7.0 + - typing-extensions==4.4.0 + - urllib3==1.26.13 + - wcwidth==0.2.5 + - wrapt==1.14.1 + - yarl==1.8.2 diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..c98acd9 --- /dev/null +++ b/setup.cfg @@ -0,0 +1,22 @@ +[metadata] +name = torchpanic +version = 0.0.1 +description = PANIC for AD diagnosis with 
XAI
+classifiers =
+    Programming Language :: Python :: 3
+
+[options]
+python_requires = >=3.9
+zip_safe = False
+packages = find:
+install_requires =
+    h5py
+    hydra-core
+    numpy
+    omegaconf
+    pandas
+    pytorch_lightning
+    tensorboard
+    torch
+    torchio
+    torchmetrics
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..6068493
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,3 @@
+from setuptools import setup
+
+setup()
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..400403e
--- /dev/null
+++ b/test.py
@@ -0,0 +1,30 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+import logging
+
+import hydra
+from omegaconf import DictConfig
+
+from torchpanic.testing import test
+
+
+@hydra.main(config_path="configs/", config_name="test.yaml")
+def main(config: DictConfig):
+    return test(config)
+
+
+if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO)
+    main()
diff --git a/torchpanic/__init__.py b/torchpanic/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/torchpanic/datamodule/__init__.py b/torchpanic/datamodule/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/torchpanic/datamodule/adni.py b/torchpanic/datamodule/adni.py
new file mode 100644
index 0000000..1d599ad
--- /dev/null
+++ b/torchpanic/datamodule/adni.py
@@ -0,0 +1,297 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+import collections.abc +import copy +import logging +from operator import itemgetter +from typing import Any, Dict, Optional, Sequence, Tuple, Union + +import h5py +import numpy as np +import pandas as pd +import pytorch_lightning as pl +from torch.utils.data import DataLoader, Dataset +from torch.utils.data.dataloader import default_collate +import torchio as tio + +from .modalities import AugmentationType, DataPointType, ModalityType + +LOG = logging.getLogger(__name__) + +DIAGNOSIS_MAP = {"CN": 0, "MCI": 1, "Dementia": 2} +DIAGNOSIS_MAP_BINARY = {"CN": 0, "Dementia": 1} + + +def Identity(x): + return x + + +def get_image_transform(p=0.0, rotate=0, translate=0, scale=0): + # Image sizes PET & MRI Dataset {113, 137, 113} + img_transforms = [] + + randomAffineWithRot = tio.RandomAffine( + scales=scale, + degrees=rotate, + translation=translate, + image_interpolation="linear", + default_pad_value="otsu", + p=p, # no transform if validation + ) + img_transforms.append(randomAffineWithRot) + + img_transform = tio.Compose(img_transforms) + return img_transform + + +class AdniDataset(Dataset): + def __init__( + self, + path: str, + modalities: Union[ModalityType, int, Sequence[str]], + augmentation: Dict[str, Any], + ) -> None: + self.path = path + if isinstance(modalities, collections.abc.Sequence): + mod_type = ModalityType(0) + for mod in modalities: + mod_type |= getattr(ModalityType, mod) + self.modalities = mod_type + else: + self.modalities = ModalityType(modalities) + self.augmentation = augmentation + + self._set_transforms(augmentations=self.augmentation) + self._load() + + def _set_transforms(self, augmentations: Dict) -> None: + self.transforms = {} + if ModalityType.MRI in self.modalities: + self.transforms[ModalityType.MRI] = get_image_transform(**augmentations) + if ModalityType.PET in self.modalities: + self.transforms[ModalityType.PET] = get_image_transform(**augmentations) + if ModalityType.TABULAR in self.modalities: + self.transforms[ModalityType.TABULAR] = Identity + + def _load(self) -> None: + data_points = { + flag: [] for flag in ModalityType.__members__.values() if flag in self.modalities + } + load_mri = ModalityType.MRI in self.modalities + load_pet = ModalityType.PET in self.modalities + load_tab = ModalityType.TABULAR in self.modalities + + LOG.info("Loading %s from %s", self.modalities, self.path) + + diagnosis = [] + rid = [] + column_names = None + with h5py.File(self.path, mode='r') as file: + if load_tab: + tab_stats = file['stats/tabular'] + tab_mean = tab_stats['mean'][:] + tab_std = tab_stats['stddev'][:] + assert np.all(tab_std > 0), "stddev is not positive" + column_names = tab_stats['columns'][:] + self._tab_mean = tab_mean + self._tab_std = tab_std + + for name, group in file.items(): + if name == "stats": + continue + + data_point = [] + if load_mri: + mri_data = group['MRI/T1/data'][:] + data_point.append( + tio.Subject( + image=tio.ScalarImage(tensor=mri_data[np.newaxis]) + ) + ) + + if load_pet: + pet_data = group['PET/FDG/data'][:] + # pet_data = np.nan_to_num(pet_data, copy=False) + data_point.append( + tio.Subject( + image=tio.ScalarImage(tensor=pet_data[np.newaxis]) + ) + ) + + if load_tab: + tab_values = group['tabular/data'][:] + tab_missing = group['tabular/missing'][:] + # XXX: always assumes that mean and std are from the training data + tab_data = np.stack(( + (tab_values - tab_mean) / tab_std, + tab_missing, + )) + data_point.append(tab_data) + + assert len(data_points) == len(data_point) + for samples, data in zip(data_points.values(), 
data_point): + samples.append(data) + + diagnosis.append(group.attrs['DX']) + rid.append(group.attrs['RID']) + + LOG.info("Loaded %d samples", len(rid)) + + dmap = DIAGNOSIS_MAP + labels, counts = np.unique(diagnosis, return_counts=True) + assert len(labels) == len(dmap), f"expected {len(dmap)} labels, but got {labels}" + LOG.info("Classes: %s", pd.Series(counts, index=labels)) + + self._column_names = column_names + self._data_points = data_points + self._diagnosis = [dmap[d] for d in diagnosis] + self._rid = rid + + @property + def rid(self): + return self._rid + + @property + def column_names(self): + return self._column_names + + @property + def tabular_mean(self): + return self._tab_mean + + @property + def tabular_stddev(self): + return self._tab_std + + def tabular_inverse_transform(self, values): + values_arr = np.ma.atleast_2d(values) + if len(self._tab_mean) != values_arr.shape[1]: + raise ValueError(f"expected {len(self._tab_mean)} features, but got {values_arr.shape[1]}") + + vals_t = values_arr * self._tab_std[np.newaxis] + self._tab_mean[np.newaxis] + return vals_t.reshape(values.shape) + + def get_tabular(self, index: int, inverse_transform: bool = False) -> np.ma.array: + tab_vals, tab_miss = self[index][0][ModalityType.TABULAR] + tab_vals = np.ma.array(tab_vals, mask=tab_miss) + if inverse_transform: + return self.tabular_inverse_transform(tab_vals) + return tab_vals + + def __len__(self) -> int: + return len(self._rid) + + def _as_tensor(self, x): + if isinstance(x, tio.Subject): + return x.image.data + return x + + def __getitem__(self, index: int) -> Tuple[DataPointType, DataPointType, AugmentationType, int]: + label = self._diagnosis[index] + sample = {} + sample_raw = {} + augmentations = {} + for modality_id, samples in self._data_points.items(): + data_raw = samples[index] + data_transformed = self.transforms[modality_id](data_raw) + if isinstance(data_transformed, tio.Subject): + augmentations[modality_id] = data_transformed.get_composed_history() + + sample_raw[modality_id] = self._as_tensor(data_raw) + sample[modality_id] = self._as_tensor(data_transformed) + + return sample, sample_raw, augmentations, label + + +def collate_adni(batch): + get = itemgetter(0, 1, 3) # 2nd position is a tio.Transform instance + batch_wo_aug = [get(elem) for elem in batch] + + keys = batch[0][2].keys() + augmentations = {k: [elem[2][k] for elem in batch] for k in keys} + + batch_stacked = default_collate(batch_wo_aug) + return batch_stacked[:-1] + [augmentations] + batch_stacked[-1:] + + +class AdniDataModule(pl.LightningDataModule): + def __init__( + self, + modalities: Union[ModalityType, int], + train_data: str, + valid_data: str, + test_data: Optional[str] = None, + batch_size: int = 32, + num_workers: int = 4, + augmentation: Dict[str, Any] = {"p": 0.0}, + metadata: Optional[Dict[str, Any]] = None, + ): + super().__init__() + self.modalities = modalities + self.train_data = train_data + self.valid_data = valid_data + self.test_data = test_data + self.batch_size = batch_size + self.num_workers = num_workers + self.metadata = metadata + self.augmentation = augmentation + + def setup(self, stage: Optional[str] = None): + if stage == 'fit' or stage is None: + self.train_dataset = AdniDataset(self.train_data, modalities=self.modalities, augmentation=self.augmentation) + self.push_dataset = copy.deepcopy(self.train_dataset) + self.push_dataset._set_transforms(augmentations={"p": 0}) + self.eval_dataset = AdniDataset(self.valid_data, modalities=self.modalities, augmentation={"p": 0}) 
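            # Note: the push dataset above is a copy of the training split with
            # augmentation disabled (p=0), so that prototype projection
            # ("push", served by push_dataloader) sees un-augmented images.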
+ elif stage == 'test' and self.test_data is not None: + self.test_dataset = AdniDataset(self.test_data, modalities=self.modalities, augmentation={"p": 0}) + self.eval_dataset = AdniDataset(self.valid_data, modalities=self.modalities, augmentation={"p": 0}) + + def train_dataloader(self): + return DataLoader( + self.train_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + drop_last=True, + pin_memory=True, + shuffle=True, + collate_fn=collate_adni, + ) + + def val_dataloader(self): + return DataLoader( + self.eval_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers // 2, + pin_memory=True, + collate_fn=collate_adni, + ) + + def test_dataloader(self): + return DataLoader( + self.test_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + collate_fn=collate_adni, + ) + + def push_dataloader(self): + return DataLoader( + self.push_dataset, + batch_size=self.batch_size, + num_workers=self.num_workers, + pin_memory=True, + collate_fn=collate_adni, + ) diff --git a/torchpanic/datamodule/modalities.py b/torchpanic/datamodule/modalities.py new file mode 100644 index 0000000..fbe5faf --- /dev/null +++ b/torchpanic/datamodule/modalities.py @@ -0,0 +1,30 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . +import enum +from typing import Dict, Tuple + +from torch import Tensor +import torchio as tio + + +class ModalityType(enum.IntFlag): + MRI = 1 + PET = 2 + TABULAR = 4 + + +DataPointType = Dict[ModalityType, Tensor] +AugmentationType = Dict[ModalityType, tio.Compose] +BatchWithLabelType = Tuple[DataPointType, DataPointType, AugmentationType, Tensor] diff --git a/torchpanic/models/__init__.py b/torchpanic/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torchpanic/models/backbones.py b/torchpanic/models/backbones.py new file mode 100644 index 0000000..90667b4 --- /dev/null +++ b/torchpanic/models/backbones.py @@ -0,0 +1,142 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . 
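# A minimal usage sketch for the 3D ResNet defined below (batch size and the
# 113 x 137 x 113 image size are illustrative):
#
#   import torch
#
#   net = ThreeDResNet(in_channels=1, n_outputs=3)
#   x = torch.randn(2, 1, 113, 137, 113)    # (batch, channels, depth, height, width)
#   logits = net(x)                         # shape (2, 3)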
+from collections import OrderedDict +from torch import nn + + +def conv3d( + in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, +) -> nn.Module: + if kernel_size != 1: + padding = 1 + else: + padding = 0 + return nn.Conv3d( + in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False, + ) + + +class ConvBnReLU(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + bn_momentum: float = 0.05, + kernel_size: int = 3, + stride: int = 1, + padding: int = 1, + ): + super().__init__() + self.conv = nn.Conv3d( + in_channels, out_channels, + kernel_size, + stride=stride, + padding=padding, + bias=False, + ) + self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.conv(x) + out = self.bn(out) + out = self.relu(out) + return out + + +class ResBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int, bn_momentum: float = 0.05, stride: int = 1): + super().__init__() + self.conv1 = conv3d(in_channels, out_channels, stride=stride) + self.bn1 = nn.BatchNorm3d(out_channels, momentum=bn_momentum) + self.dropout1 = nn.Dropout(p=0.2, inplace=True) + + self.conv2 = conv3d(out_channels, out_channels) + self.bn2 = nn.BatchNorm3d(out_channels, momentum=bn_momentum) + self.relu = nn.ReLU(inplace=True) + + if stride != 1 or in_channels != out_channels: + self.downsample = nn.Sequential( + conv3d(in_channels, out_channels, kernel_size=1, stride=stride), + nn.BatchNorm3d(out_channels, momentum=bn_momentum) + ) + else: + self.downsample = None + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.dropout1(out) + + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class ThreeDResNet(nn.Module): + def __init__(self, in_channels: int, n_outputs: int, n_blocks: int = 4, bn_momentum: float = 0.05, n_basefilters: int = 32): + super().__init__() + self.conv1 = nn.Conv3d( + in_channels, + n_basefilters, + kernel_size=7, + stride=2, + padding=3, + bias=False,) + self.bn1 = nn.BatchNorm3d(n_basefilters, momentum=bn_momentum) + self.relu1 = nn.ReLU(inplace=True) + self.pool1 = nn.MaxPool3d(3, stride=2) + + if n_blocks < 2: + raise ValueError(f"n_blocks must be at least 2, but got {n_blocks}") + + blocks = [ + ("block1", ResBlock(n_basefilters, n_basefilters, bn_momentum=bn_momentum)) + ] + n_filters = n_basefilters + for i in range(n_blocks - 1): + blocks.append( + (f"block{i+2}", ResBlock(n_filters, 2 * n_filters, bn_momentum=bn_momentum, stride=2)) + ) + n_filters *= 2 + + self.blocks = nn.Sequential(OrderedDict(blocks)) + self.gap = nn.AdaptiveAvgPool3d(1) + self.fc = nn.Linear(n_filters, n_outputs, bias=False) + +# def get_out_features(self): +# +# return self.out_features + + def forward(self, x): + out = self.conv1(x) + out = self.bn1(out) + out = self.relu1(out) + out = self.pool1(out) + out = self.blocks(out) + out = self.gap(out) + out = out.view(out.size(0), -1) + return self.fc(out) diff --git a/torchpanic/models/daft/__init__.py b/torchpanic/models/daft/__init__.py new file mode 100644 index 0000000..104830a --- /dev/null +++ b/torchpanic/models/daft/__init__.py @@ -0,0 +1 @@ +from .vol_networks import DAFT diff --git a/torchpanic/models/daft/vol_blocks.py b/torchpanic/models/daft/vol_blocks.py new file mode 100644 index 0000000..aca8c39 --- /dev/null +++ 
b/torchpanic/models/daft/vol_blocks.py @@ -0,0 +1,276 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . +# +# +# +# +# This file is part of Dynamic Affine Feature Map Transform (DAFT). +# +# DAFT is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# DAFT is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with DAFT. If not, see . +from abc import ABCMeta, abstractmethod +from collections import OrderedDict + +import torch +import torch.nn as nn +import torch.nn.parallel +import torch.utils.data + + +def conv3d(in_channels, out_channels, kernel_size=3, stride=1): + if kernel_size != 1: + padding = 1 + else: + padding = 0 + return nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False) + + +class ConvBnReLU(nn.Module): + def __init__( + self, in_channels, out_channels, bn_momentum=0.05, kernel_size=7, stride=2, padding=3, + ): + super().__init__() + self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False) + self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum) + self.relu = nn.ReLU(inplace=True) + + def forward(self, x): + out = self.conv(x) + out = self.bn(out) + out = self.relu(out) + return out + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels, bn_momentum=0.05, stride=1): + super().__init__() + self.conv1 = conv3d(in_channels, out_channels, stride=stride) + self.bn1 = nn.BatchNorm3d(out_channels, momentum=bn_momentum) + self.conv2 = conv3d(out_channels, out_channels) + self.bn2 = nn.BatchNorm3d(out_channels, momentum=bn_momentum) + self.relu = nn.ReLU(inplace=True) + + if stride != 1 or in_channels != out_channels: + self.downsample = nn.Sequential( + conv3d(in_channels, out_channels, kernel_size=1, stride=stride), + nn.BatchNorm3d(out_channels, momentum=bn_momentum), + ) + else: + self.downsample = None + + def forward(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + + out += residual + out = self.relu(out) + + return out + + +class FilmBase(nn.Module, metaclass=ABCMeta): + """Absract base class for models that are related to FiLM of Perez et al""" + + def __init__( + self, + in_channels: int, + out_channels: int, + bn_momentum: float, + stride: int, + ndim_non_img: int, + location: int, + activation: str, + scale: bool, + 
shift: bool, + ) -> None: + + super().__init__() + + # sanity checks + if location not in set(range(5)): + raise ValueError(f"Invalid location specified: {location}") + if activation not in {"tanh", "sigmoid", "linear"}: + raise ValueError(f"Invalid location specified: {location}") + if (not isinstance(scale, bool) or not isinstance(shift, bool)) or (not scale and not shift): + raise ValueError( + f"scale and shift must be of type bool:\n -> scale value: {scale}, " + "scale type {type(scale)}\n -> shift value: {shift}, shift type: {type(shift)}" + ) + # ResBlock + self.conv1 = conv3d(in_channels, out_channels, stride=stride) + self.bn1 = nn.BatchNorm3d(out_channels, momentum=bn_momentum, affine=(location != 3)) + self.conv2 = conv3d(out_channels, out_channels) + self.bn2 = nn.BatchNorm3d(out_channels, momentum=bn_momentum) + self.relu = nn.ReLU(inplace=True) + self.global_pool = nn.AdaptiveAvgPool3d(1) + if stride != 1 or in_channels != out_channels: + self.downsample = nn.Sequential( + conv3d(in_channels, out_channels, kernel_size=1, stride=stride), + nn.BatchNorm3d(out_channels, momentum=bn_momentum), + ) + else: + self.downsample = None + # Film-specific variables + self.location = location + if self.location == 2 and self.downsample is None: + raise ValueError("This is equivalent to location=1 and no downsampling!") + # location decoding + self.film_dims = 0 + if location in {0, 1, 2}: + self.film_dims = in_channels + elif location in {3, 4}: + self.film_dims = out_channels + if activation == "sigmoid": + self.scale_activation = nn.Sigmoid() + elif activation == "tanh": + self.scale_activation = nn.Tanh() + elif activation == "linear": + self.scale_activation = None + + @abstractmethod + def rescale_features(self, feature_map, x_aux): + """method to recalibrate feature map x""" + + def forward(self, feature_map, x_aux): + + if self.location == 0: + feature_map = self.rescale_features(feature_map, x_aux) + residual = feature_map + + if self.location == 1: + residual = self.rescale_features(residual, x_aux) + + if self.location == 2: + feature_map = self.rescale_features(feature_map, x_aux) + out = self.conv1(feature_map) + out = self.bn1(out) + + if self.location == 3: + out = self.rescale_features(out, x_aux) + out = self.relu(out) + + if self.location == 4: + out = self.rescale_features(out, x_aux) + out = self.conv2(out) + out = self.bn2(out) + if self.downsample is not None: + residual = self.downsample(residual) + out += residual + out = self.relu(out) + + return out + + +class DAFTBlock(FilmBase): + # Block for ZeCatNet + def __init__( + self, + in_channels: int, + out_channels: int, + bn_momentum: float = 0.1, + stride: int = 2, + ndim_non_img: int = 15, + location: int = 0, + activation: str = "linear", + scale: bool = True, + shift: bool = True, + bottleneck_dim: int = 7, + ) -> None: + + super().__init__( + in_channels=in_channels, + out_channels=out_channels, + bn_momentum=bn_momentum, + stride=stride, + ndim_non_img=ndim_non_img, + location=location, + activation=activation, + scale=scale, + shift=shift, + ) + + self.bottleneck_dim = bottleneck_dim + aux_input_dims = self.film_dims + # shift and scale decoding + self.split_size = 0 + if scale and shift: + self.split_size = self.film_dims + self.scale = None + self.shift = None + self.film_dims = 2 * self.film_dims + elif not scale: + self.scale = 1 + self.shift = None + elif not shift: + self.shift = 0 + self.scale = None + + # create aux net + layers = [ + ("aux_base", nn.Linear(ndim_non_img + aux_input_dims, 
self.bottleneck_dim, bias=False)), + ("aux_relu", nn.ReLU()), + ("aux_out", nn.Linear(self.bottleneck_dim, self.film_dims, bias=False)), + ] + self.aux = nn.Sequential(OrderedDict(layers)) + + def rescale_features(self, feature_map, x_aux): + + squeeze = self.global_pool(feature_map) + squeeze = squeeze.view(squeeze.size(0), -1) + squeeze = torch.cat((squeeze, x_aux), dim=1) + + attention = self.aux(squeeze) + if self.scale == self.shift: + v_scale, v_shift = torch.split(attention, self.split_size, dim=1) + v_scale = v_scale.view(*v_scale.size(), 1, 1, 1).expand_as(feature_map) + v_shift = v_shift.view(*v_shift.size(), 1, 1, 1).expand_as(feature_map) + if self.scale_activation is not None: + v_scale = self.scale_activation(v_scale) + elif self.scale is None: + v_scale = attention + v_scale = v_scale.view(*v_scale.size(), 1, 1, 1).expand_as(feature_map) + v_shift = self.shift + if self.scale_activation is not None: + v_scale = self.scale_activation(v_scale) + elif self.shift is None: + v_scale = self.scale + v_shift = attention + v_shift = v_shift.view(*v_shift.size(), 1, 1, 1).expand_as(feature_map) + else: + raise AssertionError( + f"Sanity checking on scale and shift failed. Must be of type bool or None: {self.scale}, {self.shift}" + ) + + return (v_scale * feature_map) + v_shift diff --git a/torchpanic/models/daft/vol_networks.py b/torchpanic/models/daft/vol_networks.py new file mode 100644 index 0000000..d927e9e --- /dev/null +++ b/torchpanic/models/daft/vol_networks.py @@ -0,0 +1,165 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . +# +# +# +# +# This file is part of Dynamic Affine Feature Map Transform (DAFT). +# +# DAFT is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# DAFT is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with DAFT. If not, see . 
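# A minimal sketch of how the DAFT baseline defined below is driven (all shapes
# and values are illustrative):
#
#   import torch
#   from torchpanic.datamodule.modalities import ModalityType
#
#   net = DAFT(in_channels=1, in_tabular=10, n_outputs=3, idx_tabular_has_missing=[2, 5])
#   tab = torch.stack((torch.randn(2, 10), torch.zeros(2, 10)), dim=1)  # values, missing mask
#   batch = {
#       ModalityType.PET: torch.randn(2, 1, 113, 137, 113),
#       ModalityType.TABULAR: tab,
#   }
#   logits, terms = net(batch)   # logits has shape (2, 3)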
+from typing import Any, Dict, Optional, Sequence + +import torch +import torch.nn as nn + +from ...datamodule.modalities import ModalityType +from .vol_blocks import ConvBnReLU, DAFTBlock, ResBlock + + +class HeterogeneousResNet(nn.Module): + def __init__( + self, + in_channels: int = 1, + n_outputs: int = 3, + bn_momentum: int = 0.1, + n_basefilters: int = 4, + ) -> None: + super().__init__() + + self.conv1 = ConvBnReLU(in_channels, n_basefilters, bn_momentum=bn_momentum) + self.pool1 = nn.MaxPool3d(2, stride=2) # 32 + self.block1 = ResBlock(n_basefilters, n_basefilters, bn_momentum=bn_momentum) + self.block2 = ResBlock(n_basefilters, 2 * n_basefilters, bn_momentum=bn_momentum, stride=2) # 16 + self.block3 = ResBlock(2 * n_basefilters, 4 * n_basefilters, bn_momentum=bn_momentum, stride=2) # 8 + self.block4 = ResBlock(4 * n_basefilters, 8 * n_basefilters, bn_momentum=bn_momentum, stride=2) # 4 + self.global_pool = nn.AdaptiveAvgPool3d(1) + self.fc = nn.Linear(8 * n_basefilters, n_outputs) + + def forward(self, batch): + image = batch[self.image_modality] + + out = self.conv1(image) + out = self.pool1(out) + out = self.block1(out) + out = self.block2(out) + out = self.block3(out) + out = self.block4(out) + out = self.global_pool(out) + out = out.view(out.size(0), -1) + out = self.fc(out) + + # cannot return None, because lightning complains + terms = torch.zeros((out.shape[0], 1, out.shape[1],), device=out.device) + return out, terms + + +class DAFT(nn.Module): + def __init__( + self, + in_channels: int, + in_tabular: int, + n_outputs: int, + idx_tabular_has_missing: Sequence[int], + bn_momentum: float = 0.05, + n_basefilters: int = 4, + filmblock_args: Optional[Dict[str, Any]] = None, + **kwargs, + ) -> None: + super().__init__() + + if filmblock_args is None: + filmblock_args = {} + + self.image_modality = ModalityType.PET + + if min(idx_tabular_has_missing) < 0: + raise ValueError("idx_tabular_has_missing contains negative values") + if max(idx_tabular_has_missing) >= in_tabular: + raise ValueError("index in idx_tabular_has_missing is out of range") + idx_missing = frozenset(idx_tabular_has_missing) + if len(idx_tabular_has_missing) != len(idx_missing): + raise ValueError("idx_tabular_has_missing contains duplicates") + + self.register_buffer( + 'idx_tabular_has_missing', + torch.tensor(idx_tabular_has_missing, dtype=torch.long), + ) + self.register_buffer( + 'idx_tabular_without_missing', + torch.tensor(list(set(range(in_tabular)).difference(idx_missing)), dtype=torch.long), + ) + + n_missing = len(idx_tabular_has_missing) + self.tab_missing_embeddings = nn.Parameter( + torch.empty((1, n_missing,), dtype=torch.float32), requires_grad=True) + nn.init.xavier_uniform_(self.tab_missing_embeddings) + + self.split_size = 4 * n_basefilters + self.conv1 = ConvBnReLU(in_channels, n_basefilters, bn_momentum=bn_momentum) + self.pool1 = nn.MaxPool3d(3, stride=2) # 32 + self.block1 = ResBlock(n_basefilters, n_basefilters, bn_momentum=bn_momentum) + self.block2 = ResBlock(n_basefilters, 2 * n_basefilters, bn_momentum=bn_momentum, stride=2) # 16 + self.block3 = ResBlock(2 * n_basefilters, 4 * n_basefilters, bn_momentum=bn_momentum, stride=2) # 8 + self.blockX = DAFTBlock( + 4 * n_basefilters, + 8 * n_basefilters, + bn_momentum=bn_momentum, + ndim_non_img=in_tabular, + **filmblock_args, + ) # 4 + self.global_pool = nn.AdaptiveAvgPool3d(1) + self.fc = nn.Linear(8 * n_basefilters, n_outputs) + + def forward(self, batch): + image = batch[self.image_modality] + + tabular_data = batch[ModalityType.TABULAR] 
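        # tabular_data has shape (batch, 2, n_features): channel 0 holds the
        # standardized feature values, channel 1 the binary missingness mask
        # (this is the order in which AdniDataset._load stacks them).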
+ values, is_missing = torch.unbind(tabular_data, axis=1) + + values_wo_missing = values[:, self.idx_tabular_without_missing] + values_w_missing = values[:, self.idx_tabular_has_missing] + missing_in_batch = is_missing[:, self.idx_tabular_has_missing] + tabular_masked = torch.where( + missing_in_batch == 1.0, self.tab_missing_embeddings, values_w_missing + ) + + features = torch.cat( + (values_wo_missing, tabular_masked), dim=1, + ) + + out = self.conv1(image) + out = self.pool1(out) + out = self.block1(out) + out = self.block2(out) + out = self.block3(out) + out = self.blockX(out, features) + out = self.global_pool(out) + out = out.view(out.size(0), -1) + out = self.fc(out) + + # cannot return None, because lightning complains + terms = torch.zeros((out.shape[0], 1, out.shape[1],), device=out.device) + return out, terms diff --git a/torchpanic/models/nam.py b/torchpanic/models/nam.py new file mode 100644 index 0000000..e9773bd --- /dev/null +++ b/torchpanic/models/nam.py @@ -0,0 +1,295 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . +from typing import Dict, Sequence, Tuple +import torch +from torch import nn +from torch.nn import init + +from ..datamodule.adni import ModalityType +from torchpanic.modules.utils import init_vector_normal + + +class ExU(nn.Module): + """exp-centered unit""" + def __init__(self, out_features: int) -> None: + super().__init__() + self.weight = nn.Parameter(torch.empty((1, out_features))) + self.bias = nn.Parameter(torch.empty((1, out_features))) + + self.reset_parameters() + + def reset_parameters(self) -> None: + init.normal_(self.weight, mean=4.0, std=0.5) + init.zeros_(self.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + assert x.ndim == 2 and x.shape[1] == 1 + out = torch.exp(self.weight) * (x - self.bias) + return out + + +class ReLUN(nn.Module): + def __init__(self, n: float = 1.0) -> None: + super().__init__() + self.n = n + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.clamp(x, min=0.0, max=self.n) + + +class FeatureNet(nn.Module): + """A neural network for a single feature""" + def __init__( + self, + out_features: int, + hidden_units: Sequence[int], + dropout_rate: float = 0.5, + ) -> None: + super().__init__() + in_features = hidden_units[0] + layers = { + "in": nn.Sequential( + nn.utils.weight_norm(nn.Linear(1, in_features)), + nn.ReLU()), + } + for i, units in enumerate(hidden_units[1:]): + layers[f"dense_{i}"] = nn.Sequential( + nn.utils.weight_norm(nn.Linear(in_features, units)), + nn.Dropout(p=dropout_rate), + nn.ReLU(), + ) + in_features = units + layers["dense_out"] = nn.utils.weight_norm(nn.Linear(in_features, out_features, bias=False)) + self.hidden_layers = nn.ModuleDict(layers) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = x + for layer in self.hidden_layers.values(): + out = layer(out) + return out + + +class NAM(nn.Module): + """Neural 
Additive Model + + .. [1] Neural Additive Models: Interpretable Machine Learning with Neural Nets. NeurIPS 2021 + https://proceedings.neurips.cc/paper/2021/hash/251bd0442dfcc53b5a761e050f8022b8-Abstract.html + """ + def __init__( + self, + in_features: int, + out_features: int, + hidden_units: Sequence[int], + dropout_rate: float = 0.5, + feature_dropout_rate: float = 0.5, + ) -> None: + super().__init__() + + layers = {} + for i in range(in_features): + layers[f"fnet_{i}"] = nn.Sequential( + FeatureNet( + out_features=out_features, + hidden_units=hidden_units, + dropout_rate=dropout_rate, + ), + ) + self.feature_nns = nn.ModuleDict(layers) + self.feature_dropout = nn.Dropout1d(p=feature_dropout_rate) + + self.bias = nn.Parameter(torch.empty((1, out_features))) + + self.reset_parameters() + + def reset_parameters(self) -> None: + init.zeros_(self.bias) + + def base_forward(self, tabular: torch.Tensor) -> torch.Tensor: + values, is_missing = torch.unbind(tabular, axis=1) + + # FIXME: Treat missing value in a better way + x = values * (1.0 - is_missing) + + features = torch.split(x, 1, dim=-1) + outputs = [] + for x_i, layer in zip(features, self.feature_nns.values()): + outputs.append(layer(x_i)) + outputs = torch.stack(outputs, dim=1) + logits = self.feature_dropout(outputs) + return logits, outputs + + def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + tabular = batch[ModalityType.TABULAR] + logits, outputs = self.base_forward(tabular) + logits = torch.sum(logits, dim=1) + self.bias + return logits, outputs + + +def is_unique(x): + return len(x) == len(set(x)) + + +class SemiParametricNAM(NAM): + def __init__( + self, + idx_real_features: Sequence[int], + idx_cat_features: Sequence[int], + idx_real_has_missing: Sequence[int], + idx_cat_has_missing: Sequence[int], + out_features: int, + hidden_units: Sequence[int], + dropout_rate: float = 0.5, + feature_dropout_rate: float = 0.5, + ) -> None: + super().__init__( + in_features=len(idx_real_features), + out_features=out_features, + hidden_units=hidden_units, + dropout_rate=dropout_rate, + feature_dropout_rate=feature_dropout_rate, + ) + + assert is_unique(idx_real_features) + assert is_unique(idx_cat_features) + assert is_unique(idx_real_has_missing) + assert is_unique(idx_cat_has_missing) + + self.cat_linear = nn.Linear(len(idx_cat_features), out_features, bias=False) + n_missing = len(idx_real_has_missing) + len(idx_cat_has_missing) + self.miss_linear = nn.Linear(n_missing, out_features, bias=False) + + self._idx_real_features = idx_real_features + self._idx_cat_features = idx_cat_features + self._idx_real_has_missing = idx_real_has_missing + self._idx_cat_has_missing = idx_cat_has_missing + + def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: + tabular = batch[ModalityType.TABULAR] + values, is_missing = torch.unbind(tabular, axis=1) + + val_categ = values[:, self._idx_cat_features] + miss_categ = is_missing[:, self._idx_cat_features] + has_miss_categ = miss_categ[:, self._idx_cat_has_missing] + + val_real = values[:, self._idx_real_features] + miss_real = is_missing[:, self._idx_real_features] + has_miss_real = miss_real[:, self._idx_real_has_missing] + + out_real_full = self.forward_real(val_real) * torch.unsqueeze(1.0 - miss_real, dim=-1) + out_real = torch.sum(out_real_full, dim=1) + + out_categ = self.cat_linear(val_categ * (1.0 - miss_categ)) + + out_miss = self.miss_linear(torch.cat((has_miss_categ, has_miss_real), dim=1)) + + return sum((out_real, out_categ, out_miss, self.bias,)), out_real_full + 
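    # forward() above is additive: every real-valued feature contributes via its
    # own FeatureNet, categorical features via one linear layer, and the
    # missingness indicators via another linear layer, i.e. roughly
    #   logits = bias + sum_i f_i(x_real_i) * (1 - m_real_i)
    #                 + W_cat (x_cat * (1 - m_cat)) + W_miss m
    # the per-feature terms (out_real_full) are returned for inspection.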
+ def forward_real(self, x): + features = torch.split(x, 1, dim=-1) + outputs = [] + for x_i, layer in zip(features, self.feature_nns.values()): + outputs.append(layer(x_i)) + + outputs = torch.stack(outputs, dim=1) + outputs = self.feature_dropout(outputs) + return outputs + + +class BaseNAM(NAM): + def __init__( + self, + idx_real_features: Sequence[int], + idx_cat_features: Sequence[int], + idx_real_has_missing: Sequence[int], + idx_cat_has_missing: Sequence[int], + out_features: int, + hidden_units: Sequence[int], + dropout_rate: float = 0.5, + feature_dropout_rate: float = 0.5, + **kwargs + ) -> None: + super().__init__( + in_features=len(idx_real_features), + out_features=out_features, + hidden_units=hidden_units, + dropout_rate=dropout_rate, + feature_dropout_rate=feature_dropout_rate, + ) + + assert is_unique(idx_real_features) + assert is_unique(idx_cat_features) + assert is_unique(idx_real_has_missing) + assert is_unique(idx_cat_has_missing) + + n_missing = len(idx_real_has_missing) + len(idx_cat_has_missing) + self.tab_missing_embeddings = nn.Parameter( + torch.empty((n_missing, out_features), dtype=torch.float32), requires_grad=True) + nn.init.xavier_uniform_(self.tab_missing_embeddings) + + self.cat_linear = nn.Parameter( + torch.empty((len(idx_cat_features), out_features), dtype=torch.float32, requires_grad=True)) + nn.init.xavier_uniform_(self.cat_linear) + + self._idx_real_features = idx_real_features + self._idx_cat_features = idx_cat_features + self._idx_real_has_missing = idx_real_has_missing + self._idx_cat_has_missing = idx_cat_has_missing + + def forward_real(self, x): + features = torch.split(x, 1, dim=-1) + outputs = [] + for x_i, layer in zip(features, self.feature_nns.values()): + outputs.append(layer(x_i)) + + outputs = torch.stack(outputs, dim=1) + return outputs + + def base_forward(self, tabular: torch.Tensor) -> torch.Tensor: + values, is_missing = torch.unbind(tabular, dim=1) + + val_real = values[:, self._idx_real_features] + miss_real = is_missing[:, self._idx_real_features] + has_miss_real = miss_real[:, self._idx_real_has_missing] + # TODO for all miss_real==1, check that its index is in self._idx_real_has_missing + + val_categ = values[:, self._idx_cat_features] + miss_categ = is_missing[:, self._idx_cat_features] + has_miss_categ = miss_categ[:, self._idx_cat_has_missing] + # TODO for all miss_categ==1, check that its index is in self._idx_categ_has_missing + + features_real = self.forward_real(val_real) + features_categ = self.cat_linear.unsqueeze(0) * val_categ.unsqueeze(-1) + + features_real = features_real * (1.0 - miss_real.unsqueeze(-1)) # set features to zero where they are mising + features_categ = features_categ * (1.0 - miss_categ.unsqueeze(-1)) + + filler_real = torch.zeros_like(features_real) + filler_categ = torch.zeros_like(features_categ) + + filler_real[:, self._idx_real_has_missing, :] = \ + self.tab_missing_embeddings[len(self._idx_cat_has_missing):].unsqueeze(0) * has_miss_real.unsqueeze(-1) + filler_categ[:, self._idx_cat_has_missing, :] = \ + self.tab_missing_embeddings[:len(self._idx_cat_has_missing)].unsqueeze(0) * has_miss_categ.unsqueeze(-1) + + features_real = features_real + filler_real # filler only has values where real is 0 + features_categ = features_categ + filler_categ + + return torch.cat((features_real, features_categ), dim=1) + + def forward(self, tabular: torch.Tensor) -> torch.Tensor: + if isinstance(tabular, dict): + tabular = tabular[ModalityType.TABULAR] + features = self.base_forward(tabular) + return 
torch.sum(self.feature_dropout(features), dim=1) + self.bias, features diff --git a/torchpanic/models/panic.py b/torchpanic/models/panic.py new file mode 100644 index 0000000..f984a59 --- /dev/null +++ b/torchpanic/models/panic.py @@ -0,0 +1,94 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . +from typing import Any, Dict +import torch +from torchvision import models + +from .protowrapper import ProtoWrapper +from .nam import BaseNAM +from ..models.backbones import ThreeDResNet + +BACKBONES = { + 'resnet18': (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1), + 'resnet34': (models.resnet34, models.ResNet34_Weights.IMAGENET1K_V1), + 'resnet50': (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V1), + 'resnet101': (models.resnet101, models.ResNet101_Weights.IMAGENET1K_V1), + 'resnet152': (models.resnet152, models.ResNet152_Weights.IMAGENET1K_V1), + '3dresnet': (ThreeDResNet, None), +} + + +def flatten_module(module): + children = list(module.children()) + flat_children = [] + if children == []: + return module + else: + for child in children: + try: + flat_children.extend(flatten_module(child)) + except TypeError: + flat_children.append(flatten_module(child)) + return flat_children + + +def deactivate_features(features): + for param in features.parameters(): + param.requires_grad = False + + +class PANIC(ProtoWrapper): + def __init__( + self, + protonet: Dict[Any, Any], + nam: Dict[Any, Any], + ) -> None: + super().__init__( + **protonet + ) + # must tidy up prototype vector dimensions! + self.classification = None + self.nam = BaseNAM( + **nam + ) + + def forward_image(self, image): + + return self.base_forward(image) + + def forward(self, image, tabular): + + feature_vectors, similarities, occurrences = self.forward_image(image) + # feature_vectors shape is (bs, n_protos, n_chans_per_prot) + # similarities is of shape (bs, n_protos) + similarities_reshaped = similarities.view( + similarities.size(0), self.num_classes, self.n_prototypes_per_class) + similarities_reshaped = similarities_reshaped.permute(0, 2, 1) + # this maps similarities such that we have the similarities w.r.t. each class: + # new shape is (bs, n_protos_per_class, n_classes) + + nam_features = self.nam.base_forward(tabular) + # nam_features shape is (bs, n_features, n_classes) + + features = torch.cat((similarities_reshaped, nam_features), dim=1) + + logits = self.nam.feature_dropout(features) + logits = torch.sum(logits, dim=1) + self.nam.bias + return logits, similarities, occurrences, features + + @torch.no_grad() + def push_forward(self, image, tabular): + + return self.forward_image(image) diff --git a/torchpanic/models/ppnet.py b/torchpanic/models/ppnet.py new file mode 100644 index 0000000..0192dee --- /dev/null +++ b/torchpanic/models/ppnet.py @@ -0,0 +1,196 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). 
+# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . +# +# Original License: +# MIT License +# +# Copyright (c) 2019 Chaofan Chen (cfchen-duke), Oscar Li (OscarcarLi), +# Chaofan Tao, Alina Jade Barnett, Cynthia Rudin +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +import torch +import torch.nn as nn +from torchvision import models + +from .backbones import ThreeDResNet + +BACKBONES = { + 'resnet18': (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1), + 'resnet34': (models.resnet34, models.ResNet34_Weights.IMAGENET1K_V1), + 'resnet50': (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V1), + 'resnet101': (models.resnet101, models.ResNet101_Weights.IMAGENET1K_V1), + 'resnet152': (models.resnet152, models.ResNet152_Weights.IMAGENET1K_V1), + '3dresnet': (ThreeDResNet, None), +} + + +def flatten_module(module): + children = list(module.children()) + flat_children = [] + if children == []: + return module + else: + for child in children: + try: + flat_children.extend(flatten_module(child)) + except TypeError: + flat_children.append(flatten_module(child)) + return flat_children + + +class PPNet(nn.Module): + def __init__( + self, + backbone: str, + in_channels: int, + out_features: int, + n_prototypes_per_class: int, + n_chans_protos: int, + optim_features: bool, + normed_prototypes: bool, + **kwargs, + ) -> None: + super().__init__() + assert backbone in BACKBONES.keys(), f"cannot find backbone {backbone} in valid BACKBONES {BACKBONES.keys()}!" 
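        # BACKBONES maps a name to (constructor, pretrained weights): the 2D
        # torchvision ResNets ship with ImageNet weights, whereas '3dresnet'
        # (ThreeDResNet from .backbones) has no pretrained weights, in which
        # case the constructor kwargs are assembled from `kwargs` below.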
+ + self.normed_prototypes = normed_prototypes + self.n_prototypes_per_class = n_prototypes_per_class + + features, weights = BACKBONES[backbone] + if backbone.startswith("resnet"): + weights = weights.DEFAULT + if weights is None: + if 'pretrained_model' in kwargs: + pretrained_model = kwargs.pop('pretrained_model') + else: + pretrained_model = None + weights = kwargs + weights = kwargs + weights["in_channels"] = in_channels + weights["n_outputs"] = out_features + else: + weights = {"weights": weights} + features = features(**weights) + features = list(features.children()) + fc_layer = features[-1] + features = features[:-2] # remove GAP and last fc layer! + in_layer = features[0] + assert isinstance(in_layer, nn.Conv2d) or isinstance(in_layer, nn.Conv3d) + self.two_d = isinstance(in_layer, nn.Conv2d) + self.conv_func = nn.Conv2d if self.two_d else nn.Conv3d + if in_layer.in_channels != in_channels: + assert optim_features, "Different num input channels -> must optim!" + in_layer = self.conv_func( + in_channels=in_channels, + out_channels=in_layer.out_channels, + kernel_size=in_layer.kernel_size, + stride=in_layer.stride, + padding=in_layer.padding, + ) + in_layer.requires_grad = False + features[0] = in_layer + self.num_classes = out_features + self.n_features_encoder = fc_layer.in_features + self.prototype_shape = (n_prototypes_per_class * self.num_classes, n_chans_protos) + self.num_prototypes = self.prototype_shape[0] + + # prototype_activation_function could be 'log', 'linear', + # or a generic function that converts distance to similarity score + + ''' + Here we are initializing the class identities of the prototypes + Without domain specific knowledge we allocate the same number of + prototypes for each class + ''' + assert (self.num_prototypes % self.num_classes == 0), \ + f"{self.num_prototypes} vs {self.num_classes}" # not needed as we initialize differently + # a onehot indication matrix for each prototype's class identity + self.prototype_class_identity = nn.Parameter( + torch.zeros(self.num_prototypes, self.num_classes), + requires_grad=False) + + num_prototypes_per_class = self.num_prototypes // self.num_classes + for j in range(self.num_prototypes): + self.prototype_class_identity[j, j // num_prototypes_per_class] = 1 + + self.features = torch.nn.Sequential(*features) + + self.add_on_layers = nn.Sequential( + self.conv_func(in_channels=self.n_features_encoder, out_channels=self.prototype_shape[1], kernel_size=1), + nn.ReLU(), + self.conv_func(in_channels=self.prototype_shape[1], out_channels=self.prototype_shape[1], kernel_size=1), + nn.Softplus() + ) + self.occurrence_module = nn.Sequential( + self.conv_func(in_channels=self.n_features_encoder, out_channels=self.n_features_encoder // 8, kernel_size=1), + nn.ReLU(), + self.conv_func(in_channels=self.n_features_encoder // 8, out_channels=self.prototype_shape[0], kernel_size=1), + nn.Sigmoid() + ) + self.gap = (-2, -1) if self.two_d else (-3, -2, -1) # if 3D network, pool h,w and d + self.prototype_vectors = nn.Parameter(torch.rand(self.prototype_shape), requires_grad=True) + if self.normed_prototypes: + self.prototype_vectors.data = (self.prototype_vectors.data / torch.linalg.vector_norm( + self.prototype_vectors.data, ord=2, dim=1, keepdim=True, dtype=torch.double)).to(torch.float32) + # nn.init.xavier_uniform_(self.prototype_vectors) + self.cosine_similarity = torch.nn.CosineSimilarity(dim=-1) + + self.ones = nn.Parameter(torch.ones(self.prototype_shape), + requires_grad=False) + + if pretrained_model is not None: + 
state_dict = torch.load(pretrained_model, map_location=self.prototype_vectors.device)['state_dict'] + state_dict = {x.replace('net.features.', ''): y for x, y in state_dict.items()} + state_dict = {x: y for x, y in state_dict.items() if x in self.features.state_dict().keys()} + self.features.load_state_dict(state_dict) + if not optim_features: + for x in self.features.parameters(): + x.requires_grad = False + + self.classification = nn.Linear(self.num_prototypes, self.num_classes, + bias=False) # do not use bias + + self.set_last_layer_incorrect_connection(incorrect_strength=-0.5) + + def set_last_layer_incorrect_connection(self, incorrect_strength): + ''' + the incorrect strength will be actual strength if -0.5 then input -0.5 + ''' + positive_one_weights_locations = torch.t(self.prototype_class_identity) + negative_one_weights_locations = 1 - positive_one_weights_locations + + correct_class_connection = 1 + incorrect_class_connection = incorrect_strength + self.classification.weight.data.copy_( + correct_class_connection * positive_one_weights_locations + + incorrect_class_connection * negative_one_weights_locations) diff --git a/torchpanic/models/pretrain_encoder.py b/torchpanic/models/pretrain_encoder.py new file mode 100644 index 0000000..4625723 --- /dev/null +++ b/torchpanic/models/pretrain_encoder.py @@ -0,0 +1,67 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . 
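# Encoder below is a thin pretraining wrapper: it reuses the convolutional
# feature extractor of ProtoWrapper, adds global average pooling plus a linear
# classifier on the PET input, and returns an all-zero "NAM term" tensor next to
# the logits so that downstream code expecting (logits, nam_terms) keeps
# working. A minimal sketch (all config values are illustrative):
#
#   encoder = Encoder(protonet={
#       "backbone": "3dresnet", "in_channels": 1, "out_features": 3,
#       "n_prototypes_per_class": 2, "n_chans_protos": 64,
#       "optim_features": True, "normed_prototypes": True,
#   })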
+from typing import Any, Dict +import torch + +from .protowrapper import ProtoWrapper +from torchpanic.datamodule.adni import ModalityType + + +def flatten_module(module): + children = list(module.children()) + flat_children = [] + if children == []: + return module + else: + for child in children: + try: + flat_children.extend(flatten_module(child)) + except TypeError: + flat_children.append(flatten_module(child)) + return flat_children + + +def deactivate_features(features): + for param in features.parameters(): + param.requires_grad = False + + +class Encoder(torch.nn.Module): + def __init__( + self, + protonet: Dict[Any, Any], + ) -> None: + super().__init__() + wrapper = ProtoWrapper( + **protonet + ) + self.features = wrapper.features + self.gap = torch.nn.AdaptiveAvgPool3d(1) + out_chans_encoder = self.features[-1][-1].conv2.out_channels + self.classification = torch.nn.Linear(out_chans_encoder, protonet['out_features']) + self.nam_term_faker = None + + def forward(self, x): + + x = x[ModalityType.PET] + out = self.features(x) + out = self.gap(out) + out = out.view(out.size(0), -1) + out = self.classification(out) + if self.nam_term_faker is None: + self.nam_term_faker = torch.zeros_like(out).unsqueeze(-1) + return out, self.nam_term_faker + + diff --git a/torchpanic/models/protowrapper.py b/torchpanic/models/protowrapper.py new file mode 100644 index 0000000..f69e787 --- /dev/null +++ b/torchpanic/models/protowrapper.py @@ -0,0 +1,93 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . +import torch + +from .ppnet import PPNet + +def flatten_module(module): + children = list(module.children()) + flat_children = [] + if children == []: + return module + else: + for child in children: + try: + flat_children.extend(flatten_module(child)) + except TypeError: + flat_children.append(flatten_module(child)) + return flat_children + + +class ProtoWrapper(PPNet): + def __init__( + self, + backbone: str, + in_channels: int, + out_features: int, + n_prototypes_per_class: int, + n_chans_protos: int, + optim_features: bool, + normed_prototypes: bool, + **kwargs, + ) -> None: + super().__init__( + backbone, + in_channels, + out_features, + n_prototypes_per_class, + n_chans_protos, + optim_features, + normed_prototypes, + **kwargs) + + def base_forward(self, x): + + x = self.features(x) + occurrences = self.occurrence_module(x) # (bs, n_prototypes) + feature_map = self.add_on_layers(x) # (bs, n_chans_proto, h, w[, d]) + + # the following should work for 3D and 2D data! + broadcasting_shape = (occurrences.size(0), occurrences.size(1), feature_map.size(1), *feature_map.size()[2:]) + # of shape (bs, n_protos, n_chans_per_prot, h, w[, d]) + + # expand the two such that broadcasting is possible, i.e. 
vectorization of prototype feature calculation + occurrences_reshaped = occurrences.unsqueeze(2).expand(broadcasting_shape) + feature_map_reshaped = feature_map.unsqueeze(1).expand(broadcasting_shape) + + # element-wise multiplication of each occurence map with the featuremap + feature_vectors = occurrences_reshaped * feature_map_reshaped + feature_vectors = feature_vectors.mean(dim=self.gap) # essentially GAP over the spatial resolutions + # feature_vectors size is now (bs, n_protos, n_chans_per_prot) + # prototype_vectors size is (n_protos, n_chans_per_prot) + if self.normed_prototypes: + feature_vectors = (feature_vectors / torch.linalg.vector_norm( + feature_vectors, ord=2, dim=2, keepdim=True, dtype=torch.double)).to(torch.float32) + + # make prototypes broadcastable to featuer vectors + similarities = self.cosine_similarity(feature_vectors, self.prototype_vectors.unsqueeze(0)) + return feature_vectors, similarities, occurrences + + def forward(self, x): + + _, similarities, occurrences = self.base_forward(x) + + logits = self.classification(similarities) + + return logits, similarities, occurrences + + @torch.no_grad() + def push_forward(self, x): + + return self.base_forward(x) diff --git a/torchpanic/modules/__init__.py b/torchpanic/modules/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torchpanic/modules/base.py b/torchpanic/modules/base.py new file mode 100644 index 0000000..b0c721d --- /dev/null +++ b/torchpanic/modules/base.py @@ -0,0 +1,131 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . 
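# BaseModule below centralizes the accuracy / balanced-accuracy bookkeeping
# shared by the Lightning modules in this package. Balanced accuracy is the
# mean of the per-class recalls taken from the confusion matrix; a small worked
# example with made-up numbers:
#
#   cmat = [[8, 2],          # rows: true class, columns: predicted class
#           [3, 7]]
#   per_class = [8/10, 7/10] = [0.80, 0.70]
#   balanced accuracy = (0.80 + 0.70) / 2 = 0.75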
+import logging +from typing import Any, List + +import pytorch_lightning as pl +from pytorch_lightning.loggers import TensorBoardLogger +import torch +from torchmetrics import Accuracy, ConfusionMatrix, MaxMetric + +from .utils import get_git_hash + +from ..datamodule.modalities import DataPointType + + +LOG = logging.getLogger(__name__) + + +class BaseModule(pl.LightningModule): + def __init__( + self, + net: torch.nn.Module, + num_classes: int, + ) -> None: + super().__init__() + + self.net = net + task = "binary" if num_classes <= 2 else "multiclass" + # use separate metric instance for train, val and test step + # to ensure a proper reduction over the epoch + self.train_acc = Accuracy(task=task, num_classes=num_classes) + self.val_acc = Accuracy(task=task, num_classes=num_classes) + self.test_acc = Accuracy(task=task, num_classes=num_classes) + self.val_cmat = ConfusionMatrix(task=task, num_classes=num_classes) + self.test_cmat = ConfusionMatrix(task=task, num_classes=num_classes) + + # for logging best so far validation accuracy + self.val_acc_best = MaxMetric() + self.val_bacc_best = MaxMetric() + + def _get_balanced_accuracy_from_confusion_matrix(self, confusion_matrix: ConfusionMatrix): + # Confusion matrix whose i-th row and j-th column entry indicates + # the number of samples with true label being i-th class and + # predicted label being j-th class. + cmat = confusion_matrix.compute() + per_class = cmat.diag() / cmat.sum(dim=1) + per_class = per_class[~torch.isnan(per_class)] # remove classes that are not present in this dataset + LOG.debug("Confusion matrix:\n%s", cmat) + + return per_class.mean() + + def _log_train_metrics( + self, loss: torch.Tensor, preds: torch.Tensor, targets: torch.Tensor, + ) -> None: + self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False) + acc = self.train_acc(preds, targets) + self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True) + + def _update_validation_metrics( + self, loss: torch.Tensor, preds: torch.Tensor, targets: torch.Tensor, + ) -> None: + self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False) + self.val_acc.update(preds, targets) + self.val_cmat.update(preds, targets) + + def _log_validation_metrics(self) -> None: + acc = self.val_acc.compute() # get val accuracy from current epoch + self.log("val/acc", acc, on_epoch=True) + + self.val_acc_best.update(acc) + self.log("val/acc_best", self.val_acc_best.compute(), on_epoch=True) + + # compute balanced accuracy + bacc = self._get_balanced_accuracy_from_confusion_matrix(self.val_cmat) + self.val_bacc_best.update(bacc) + self.log("val/bacc", bacc, on_epoch=True, prog_bar=True) + self.log("val/bacc_best", self.val_bacc_best.compute(), on_epoch=True) + + # reset metrics at the end of every epoch + self.val_acc.reset() + self.val_cmat.reset() + + def _update_test_metrics( + self, loss: torch.Tensor, preds: torch.Tensor, targets: torch.Tensor, + ) -> None: + self.log("test/loss", loss, on_step=False, on_epoch=True) + self.test_acc.update(preds, targets) + self.test_cmat.update(preds, targets) + + def _log_test_metrics(self) -> None: + acc = self.test_acc.compute() + self.log("test/acc", acc) + + # compute balanced accuracy + bacc = self._get_balanced_accuracy_from_confusion_matrix(self.test_cmat) + self.log("test/bacc", bacc) + + # reset metrics at the end of every epoch + self.test_acc.reset() + self.test_cmat.reset() + + def on_train_start(self): + # by default lightning executes validation step sanity checks before training starts, + # so 
we need to make sure val_acc_best doesn't store accuracy from these checks + self.val_acc_best.reset() + self.val_bacc_best.reset() + + if isinstance(self.logger, TensorBoardLogger): + tb_logger = self.logger.experiment + # tb_logger.add_hparams({"git-commit": get_git_hash()}, {"hp_metric": -1}) + tb_logger.flush() + + def training_epoch_end(self, outputs: List[Any]): + # `outputs` is a list of dicts returned from `training_step()` + # reset metrics at the end of every epoch + self.train_acc.reset() + + def forward(self, x: DataPointType): + return self.net(x) diff --git a/torchpanic/modules/panic.py b/torchpanic/modules/panic.py new file mode 100644 index 0000000..5d56407 --- /dev/null +++ b/torchpanic/modules/panic.py @@ -0,0 +1,527 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . +import logging +import math +from pathlib import Path +from typing import Any, List +import warnings + +import numpy as np +# from skimage.transform import resize + +from pytorch_lightning.loggers import TensorBoardLogger +import torch +from torch.nn.functional import l1_loss +from torchmetrics import MaxMetric + +from ..datamodule.modalities import BatchWithLabelType, ModalityType +from .base import BaseModule + +LOG = logging.getLogger(__name__) + +STAGE2FLOAT = {"warmup": 0.0, "warmup_protonet": 1.0, "all": 2.0, "nam_only": 3.0} + + +# @torch.no_grad() +# def init_weights(m: torch.Tensor): +# if isinstance(m, nn.Conv3d): +# # every init technique has an underscore _ in the name +# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') +# if getattr(m, "bias", None) is not None: +# nn.init.zeros_(m.bias) +# elif isinstance(m, nn.BatchNorm3d): +# nn.init.constant_(m.weight, 1) +# nn.init.constant_(m.bias, 0) + + +class PANIC(BaseModule): + def __init__( + self, + net: torch.nn.Module, + weight_decay_nam: float = 0.0, + lr: float = 0.001, + weight_decay: float = 0.0005, + l_clst: float = 0.8, + l_sep: float = 0.08, + l_occ: float = 1e-4, + l_affine: float = 1e-4, + l_nam: float = 1e-4, + epochs_all: int = 3, + epochs_nam: int = 4, + epochs_warmup: int = 10, + enable_checkpointing: bool = True, + monitor_prototypes: bool = False, # wether to save prototypes of all push epochs or just the best one + enable_save_embeddings: bool = False, + enable_log_prototypes: bool = False, + **kwargs, + ): + super().__init__( + net=net, + num_classes=net.num_classes, + ) + self.automatic_optimization = False + + # this line allows to access init params with 'self.hparams' attribute + # it also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False, ignore=["net"]) + self._current_stage = None + self._current_optimizer = None + + # FIXME only picks PET + self.image_modality = ModalityType.PET + + self.net = net + if self.net.two_d: + self.reduce_dims = (2, 3,) # inputs are 4d tensors + else: + self.reduce_dims = (2, 3, 4,) # inputs are 5d 
tensors + + # loss function + self.criterion = torch.nn.CrossEntropyLoss() + + self.val_bacc_save = MaxMetric() + + def _cluster_losses(self, similarities, y): + prototypes_of_correct_class = torch.t(self.net.prototype_class_identity[:, y]).bool() + # select mask for each sample in batch. Shape is (bs, n_prototypes) + other_tensor = torch.tensor(0, device=similarities.device) # minimum value of the cosine similarity + similarities_correct_class = similarities.where(prototypes_of_correct_class, other=other_tensor) + # min value of cos similarity doesnt effect result + similarities_incorrect_class = similarities.where(~prototypes_of_correct_class, other=other_tensor) + # same here: if distance to other protos of other class is the minimum value + # of the cos similarity, they are distant to eachother! + clst = 1 - torch.max(similarities_correct_class, dim=1).values.mean() + sep = torch.max(similarities_incorrect_class, dim=1).values.mean() + + return clst, sep + + def _occurence_map_losses(self, occurrences, x_raw, aug): + with torch.no_grad(): + _, _, occurrences_raw = self.net.forward_image(x_raw) + + if self.net.two_d: + occurrences_raw = occurrences_raw.unsqueeze(-1) + occurrences_raw = occurrences_raw.cpu() + + for i, aug_i in enumerate(aug): + occurrences_raw[i] = aug_i(occurrences_raw[i]) + + if self.net.two_d: + occurrences_raw = occurrences_raw.squeeze(-1) + occurrences_raw = occurrences_raw.to(occurrences.device) + + affine = l1_loss(occurrences_raw, occurrences) + + # l1 penalty on occurence maps + l1 = torch.linalg.vector_norm(occurrences, ord=1, dim=self.reduce_dims).mean() + l1_norm = math.prod(occurrences.size()[1:]) # omit batch dimension + l1 = (l1 / l1_norm).mean() + return affine, l1 + + def _classification_loss(self, logits, targets): + preds = torch.argmax(logits, dim=1) + + xentropy = self.criterion(logits, targets) + return xentropy, preds + + def forward(self, batch: BatchWithLabelType): + x = batch[0] + image = x[self.image_modality] + tabular = x[ModalityType.TABULAR] + + return self.net(image, tabular) + + def on_train_epoch_start(self) -> None: + cur_stage = self._get_current_stage() + LOG.info("Epoch %d, optimizing %s", self.trainer.current_epoch, cur_stage) + + self.log("train/stage", STAGE2FLOAT[cur_stage]) + + optim_warmup, optim_warmup_protonet, optim_all, optim_nam = self.optimizers() + scheduler_warmup, scheduler_all, scheduler_nam = self.lr_schedulers() + if cur_stage == "warmup": + opt = optim_warmup + sched = scheduler_warmup + elif cur_stage == "warmup_protonet": + opt = optim_warmup_protonet + sched = scheduler_warmup + elif cur_stage == "all": + opt = optim_all + sched = scheduler_all + elif cur_stage == "nam_only": + opt = optim_nam + sched = scheduler_nam + self.push_prototypes() + else: + raise AssertionError() + + self._current_stage = cur_stage + self._current_optimizer = opt + self._current_scheduler = sched + + def training_step(self, batch: BatchWithLabelType, batch_idx: int): + cur_stage = self._current_stage + + x_raw, aug, y = batch[1:] + x_raw = x_raw[self.image_modality] + aug = aug[self.image_modality] + + logits, similarities, occurrences, nam_terms = self.forward(batch) + xentropy, preds = self._classification_loss(logits, y) + self.log("train/xentropy", xentropy) + losses = [xentropy] + + if cur_stage != "nam_only": + # cluster and seperation cost + clst, sep = self._cluster_losses(similarities, y) + losses.append(self.hparams.l_clst * clst) + losses.append(self.hparams.l_sep * sep) + + self.log("train/clst", clst) + 
self.log("train/sep", sep) + + # regularization of occurrence map + affine, l1 = self._occurence_map_losses(occurrences, x_raw, aug) + losses.append(self.hparams.l_affine * affine) + losses.append(self.hparams.l_occ * l1) + + self.log("train/affine", affine) + self.log("train/l1", l1) + + if cur_stage != "occ_and_feats": + # l2 penalty on terms of nam + start_index = self.net.n_prototypes_per_class + nam_penalty = nam_terms[:, start_index:, :].square().sum(dim=1).mean() + losses.append(self.hparams.l_nam * nam_penalty) + + self.log("train/nam_l2", nam_penalty) + + loss = sum(losses) + + opt = self._current_optimizer + opt.zero_grad() + self.manual_backward(loss) + opt.step() + + self._current_scheduler.step() + + self._log_train_metrics(loss, preds, y) + + return {"loss": loss, "preds": preds, "targets": y} + + def _log_train_metrics( + self, loss: torch.Tensor, preds: torch.Tensor, targets: torch.Tensor, + ) -> None: + self.log(f"train/loss/{self._current_stage}", loss, on_step=False, on_epoch=True, prog_bar=False) + acc = self.train_acc(preds, targets) + self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True) + + def _get_current_stage(self, epoch=None): + total = self.hparams.epochs_all + self.hparams.epochs_nam + + stage = "nam_only" + if self.current_epoch < 0.5 * self.hparams.epochs_warmup: + stage = "warmup" + elif self.current_epoch < self.hparams.epochs_warmup: + stage = "warmup_protonet" + elif (self.current_epoch - self.hparams.epochs_warmup) % total < self.hparams.epochs_all: + stage = "all" + + return stage + + def training_epoch_end(self, outputs: List[Any]): + if self.hparams.enable_log_prototypes and isinstance(self.logger, TensorBoardLogger): + tb_logger = self.logger.experiment + + tb_logger.add_histogram( + "train/prototypes", self.net.prototype_vectors, global_step=self.trainer.global_step, + ) + + return super().training_epoch_end(outputs) + + def validation_step(self, batch: BatchWithLabelType, batch_idx: int): + logits, similarities, occurrences, nam_terms = self.forward(batch) + + targets = batch[-1] + loss, preds = self._classification_loss(logits, targets) + + self._update_validation_metrics(loss, preds, targets) + + return {"loss": loss, "preds": preds, "targets": targets} + + def validation_epoch_end(self, outputs: List[Any]): + # compute balanced accuracy + bacc = self._get_balanced_accuracy_from_confusion_matrix(self.val_cmat) + + cur_stage = self._current_stage + # every 10th epoch is a last layer optim epoch + if cur_stage == "nam_only": + self.val_bacc_save.update(bacc) + saver = bacc + else: + self.val_bacc_save.update(torch.tensor(0., dtype=torch.float32)) + saver = torch.tensor(-float('inf'), dtype=torch.float32, device=bacc.device) + self.log("val/bacc_save_monitor", self.val_bacc_save.compute(), on_epoch=True) + self.log("val/bacc_save", saver) + + self._log_validation_metrics() + + def test_step(self, batch: BatchWithLabelType, batch_idx: int): + logits, similarities, occurrences, nam_terms = self.forward(batch) + + targets = batch[-1] + loss, preds = self._classification_loss(logits, targets) + + self._update_test_metrics(loss, preds, targets) + + return {"loss": loss, "preds": preds, "targets": targets} + + def test_epoch_end(self, outputs: List[Any]): + self._log_test_metrics() + + def configure_optimizers(self): + """Choose what optimizers and learning-rate schedulers to use in your optimization. + Normally you'd need one. But in the case of GANs or similar you might have multiple. 
+ See examples here: + https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers + """ + + prototypes = { + 'params': [self.net.prototype_vectors], + 'lr': self.hparams.lr, + 'weight_decay': 0.0, + } + nam = {name: p for name, p in self.net.nam.named_parameters() if p.requires_grad} + embeddings = [nam.pop('tab_missing_embeddings')] + embeddings = { + 'params': embeddings, + 'lr': self.hparams.lr, + 'weight_decay': 0.0, + } + nam = { + 'params': list(nam.values()), + 'lr': self.hparams.lr, + 'weight_decay': self.hparams.weight_decay_nam, + } + encoder = { + 'params': [p for p in self.net.features.parameters() if p.requires_grad], + 'lr': self.hparams.lr, + 'weight_decay': self.hparams.weight_decay, + } + occurrence_and_features = { + 'params': [p for p in self.net.add_on_layers.parameters() if p.requires_grad] + + [p for p in self.net.occurrence_module.parameters() if p.requires_grad], + 'lr': self.hparams.lr, + 'weight_decay': self.hparams.weight_decay, + } + if len(encoder['params']) == 0: + warnings.warn("Encoder seems to be frozen! No parameters require grad.") + assert len(nam['params']) > 0 + assert len(prototypes['params']) > 0 + assert len(embeddings['params']) > 0 + assert len(occurrence_and_features['params']) > 0 + optim_all = torch.optim.AdamW([ + encoder, occurrence_and_features, prototypes, nam, embeddings + ]) + optim_nam = torch.optim.AdamW([ + nam, embeddings + ]) + optim_warmup = torch.optim.AdamW([ + occurrence_and_features, prototypes + ]) + optim_warmup_protonet = torch.optim.AdamW([ + encoder, occurrence_and_features, prototypes + ]) + + training_iterations = len(self.trainer.datamodule.train_dataloader()) + LOG.info("Number of iterations for one epoch: %d", training_iterations) + + # stepping through warmup_protonet is sufficient, as all parameters groups are also in warmup + scheduler_kwargs = {'max_lr': self.hparams.lr, 'cycle_momentum': False} + scheduler_warmup = torch.optim.lr_scheduler.CyclicLR( + optim_warmup_protonet, + base_lr=self.hparams.lr / 20, + max_lr=self.hparams.lr / 10, + step_size_up=self.hparams.epochs_warmup * training_iterations, + cycle_momentum=False + ) + scheduler_all = torch.optim.lr_scheduler.CyclicLR( + optim_all, + base_lr=self.hparams.lr / 10, + step_size_up=(self.hparams.epochs_all * training_iterations) / 2, + step_size_down=(self.hparams.epochs_all * training_iterations) / 2, + **scheduler_kwargs + ) + scheduler_nam = torch.optim.lr_scheduler.CyclicLR( + optim_nam, + base_lr=self.hparams.lr / 10, + step_size_up=(self.hparams.epochs_nam * training_iterations) / 2, + step_size_down=(self.hparams.epochs_nam * training_iterations) / 2, + **scheduler_kwargs + ) + + return ([optim_warmup, optim_warmup_protonet, optim_all, optim_nam], + [scheduler_warmup, scheduler_all, scheduler_nam]) + + def push_prototypes(self): + LOG.info("Pushing protoypes. 
epoch=%d, step=%d", self.current_epoch, self.trainer.global_step) + + self.net.eval() + + prototype_shape = self.net.prototype_shape + n_prototypes = self.net.num_prototypes + + global_max_proto_dist = np.full(n_prototypes, np.NINF) + # global_max_proto_dist = np.ones(n_prototypes) * -1 + global_max_fmap = np.zeros(prototype_shape) + global_img_indices = np.zeros(n_prototypes, dtype=np.int) + global_img_classes = - np.ones(n_prototypes, dtype=np.int) + + if self.hparams.monitor_prototypes: + proto_epoch_dir = Path(self.trainer.log_dir) / f"prototypes_epoch_{self.current_epoch}" + else: + proto_epoch_dir = Path(self.trainer.log_dir) / "prototypes_best" + if self.hparams.enable_checkpointing: + proto_epoch_dir.mkdir(exist_ok=True) + + push_dataloader = self.trainer.datamodule.push_dataloader() + search_batch_size = push_dataloader.batch_size + + num_classes = self.net.num_classes + + save_embedding = self.hparams.enable_save_embeddings and isinstance(self.logger, TensorBoardLogger) + + # indicates which class a prototype belongs to + proto_class = torch.argmax(self.net.prototype_class_identity, dim=1).detach().cpu().numpy() + + embedding_data = [] + embedding_labels = [] + for push_iter, (search_batch_input, _, _, search_y) in enumerate(push_dataloader): + + start_index_of_search_batch = push_iter * search_batch_size + + feature_vectors = self.update_prototypes_on_batch( + search_batch_input, + start_index_of_search_batch, + global_max_proto_dist, + global_max_fmap, + global_img_indices, + global_img_classes, + num_classes, + search_y, + proto_epoch_dir, + ) + + if save_embedding: + # for each batch, split into one feature vector for each prototype + embedding_data.extend(feature_vectors[:, j] for j in range(n_prototypes)) + embedding_labels.append(np.repeat(proto_class, feature_vectors.shape[0])) + + prototype_update = np.reshape(global_max_fmap, tuple(prototype_shape)) + + if self.hparams.enable_checkpointing: + np.save(proto_epoch_dir / f"p_similarities_{self.current_epoch}.npy", global_max_proto_dist) + np.save(proto_epoch_dir / f"p_feature_maps_{self.current_epoch}.npy", global_max_fmap) + np.save(proto_epoch_dir / f"p_inp_indices_{self.current_epoch}.npy", global_img_indices) + np.save(proto_epoch_dir / f"p_inp_img_labels_{self.current_epoch}.npy", global_img_classes) + + if save_embedding: + tb_logger = self.logger.experiment + + embedding_data.append(self.net.prototype_vectors.detach().cpu().numpy()) + embedding_data.append(prototype_update) + embedding_labels = np.concatenate(embedding_labels) + metadata = [f"FV Class {i}" for i in embedding_labels] + + metadata.extend(f"Old PV Class {i}" for i in proto_class) + metadata.extend(f"New PV Class {i}" for i in proto_class) + + tb_logger.add_embedding( + mat=np.concatenate(embedding_data, axis=0), + metadata=metadata, + global_step=self.trainer.global_step, + tag="push_prototypes", + ) + + self.net.prototype_vectors.data.copy_(torch.tensor(prototype_update, dtype=torch.float32, device=self.device)) + + self.net.train() + + def update_prototypes_on_batch( + self, + search_batch, + start_index_of_search_batch, + global_max_proto_dist, + global_max_fmap, + global_img_indices, + global_img_classes, + num_classes, + search_y, + proto_epoch_dir, + ): + self.net.eval() + + feats, sims, occ = self.net.push_forward( + search_batch[self.image_modality].to(self.device), + search_batch[ModalityType.TABULAR].to(self.device)) + + feature_vectors = np.copy(feats.detach().cpu().numpy()) + similarities = np.copy(sims.detach().cpu().numpy()) + occurrences 
= np.copy(occ.detach().cpu().numpy())
+
+        del feats, sims, occ
+
+        class_to_img_index = {key: [] for key in range(num_classes)}
+        for img_index, img_y in enumerate(search_y):
+            img_label = img_y.item()
+            class_to_img_index[img_label].append(img_index)
+
+        prototype_shape = self.net.prototype_shape
+        n_prototypes = prototype_shape[0]
+
+        for j in range(n_prototypes):
+            target_class = torch.argmax(self.net.prototype_class_identity[j]).item()
+            if len(class_to_img_index[target_class]) == 0:  # none of the images belongs to the class of this prototype
+                continue
+            proto_dist_j = similarities[class_to_img_index[target_class], j]
+            # similarities of all latent vectors in this batch to the j-th prototype of this class
+
+            batch_max_proto_dist_j = np.amax(proto_dist_j)  # maximum similarity of this batch's latents to prototype j
+
+            if batch_max_proto_dist_j > global_max_proto_dist[j]:  # update if this batch contains a new maximum similarity
+
+                img_index_in_class = np.argmax(proto_dist_j)
+                img_index_in_batch = class_to_img_index[target_class][img_index_in_class]
+
+                batch_max_fmap_j = feature_vectors[img_index_in_batch, j]
+
+                # latent vector with the highest similarity to prototype j
+                global_max_proto_dist[j] = batch_max_proto_dist_j
+                global_max_fmap[j] = batch_max_fmap_j
+                global_img_indices[j] = img_index_in_batch + start_index_of_search_batch
+                global_img_classes[j] = search_y[img_index_in_batch].item()
+
+                if self.hparams.enable_checkpointing:
+
+                    # original image
+                    original_img_j = search_batch[self.image_modality][img_index_in_batch].detach().cpu().numpy()
+
+                    # find highly activated region of the original image
+                    proto_occ_j = occurrences[img_index_in_batch, j]
+
+                    np.save(proto_epoch_dir / f"original_{j}_epoch_{self.current_epoch}.npy", original_img_j)
+                    np.save(proto_epoch_dir / f"occurrence_{j}_epoch_{self.current_epoch}.npy", proto_occ_j)
+
+        return feature_vectors
diff --git a/torchpanic/modules/standard.py b/torchpanic/modules/standard.py
new file mode 100644
index 0000000..7061053
--- /dev/null
+++ b/torchpanic/modules/standard.py
@@ -0,0 +1,115 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see .
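+#
+# Note: `StandardModule` below provides a simpler training setup than `PANIC`.
+# Instead of manually alternating between warmup, prototype, and NAM-only
+# stages, it wraps a single network, trains it with cross-entropy (or
+# BCE-with-logits in the binary case) plus an optional penalty on the norm of
+# the additive output terms, and optimizes all parameters jointly with AdamW
+# and a cyclic learning-rate schedule.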
+from functools import partial +from typing import Any, List + +import torch +from torch import nn + +from .base import BaseModule + + +class StandardModule(BaseModule): + def __init__( + self, + net: nn.Module, + lr: float = 0.001, + num_classes: int = 3, + weight_decay: float = 0.0005, + output_penalty_weight: float = 0.001, + **kwargs, + ): + super().__init__( + net=net, + num_classes=num_classes, + ) + + # this line allows to access init params with 'self.hparams' attribute + # it also ensures init params will be stored in ckpt + self.save_hyperparameters(logger=False, ignore=["net"]) + + # loss function + if num_classes > 2: + self.criterion = torch.nn.CrossEntropyLoss() + else: + self.criterion = torch.nn.BCEWithLogitsLoss() + + def step(self, batch: Any): + x, _, _, y = batch + logits, terms = self.forward(x) + if self.hparams.num_classes < 3: + logits = logits.squeeze() + probs = torch.sigmoid(logits) + preds = torch.argmax(torch.stack((1.0 - probs, probs), dim=-1), dim=-1) + else: + preds = torch.argmax(logits, dim=1) + + loss = self.criterion(logits, y) + if self.hparams.output_penalty_weight > 0: + output_norm = torch.linalg.norm(terms, dim=[1, 2]) + loss = loss + self.hparams.output_penalty_weight * output_norm.mean() + return loss, preds, y.long() + + def training_step(self, batch: Any, batch_idx: int): + loss, preds, targets = self.step(batch) + + # log train metrics + self._log_train_metrics(loss, preds, targets) + + # we can return here dict with any tensors + # and then read it in some callback or in `training_epoch_end()` below + # remember to always return loss from `training_step()` or else backpropagation will fail! + return {"loss": loss, "preds": preds, "targets": targets} + + def validation_step(self, batch: Any, batch_idx: int): + loss, preds, targets = self.step(batch) + + self._update_validation_metrics(loss, preds, targets) + + return {"loss": loss, "preds": preds, "targets": targets} + + def validation_epoch_end(self, outputs: List[Any]): + self._log_validation_metrics() + + def test_step(self, batch: Any, batch_idx: int): + loss, preds, targets = self.step(batch) + + self._update_test_metrics(loss, preds, targets) + + return {"loss": loss, "preds": preds, "targets": targets} + + def test_epoch_end(self, outputs: List[Any]): + self._log_test_metrics() + + def configure_optimizers(self): + """Choose what optimizers and learning-rate schedulers to use in your optimization. + Normally you'd need one. But in the case of GANs or similar you might have multiple. + See examples here: + https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers + """ + optimizer = torch.optim.AdamW( + params=self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay + ) + n_iters = len(self.trainer.datamodule.train_dataloader()) + scheduler = torch.optim.lr_scheduler.CyclicLR( + optimizer, + base_lr=self.hparams.lr / 10, + max_lr=self.hparams.lr, + step_size_up=5 * n_iters, + step_size_down=5 * n_iters, + cycle_momentum=False, + ) + + return [optimizer], [{"scheduler": scheduler, "interval": "step"}] diff --git a/torchpanic/modules/utils.py b/torchpanic/modules/utils.py new file mode 100644 index 0000000..186087a --- /dev/null +++ b/torchpanic/modules/utils.py @@ -0,0 +1,105 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). 
+# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . +import math +from pathlib import Path +from typing import Union + +from hydra.utils import instantiate as hydra_init +import numpy as np +from omegaconf import DictConfig, OmegaConf +import torch +from yaml import safe_load as yaml_load + +import subprocess + + +def format_float_to_str(x: float) -> str: + return "{:2.1f}".format(x * 100) + + +def init_vector_normal(vector: torch.Tensor): + stdv = 1. / math.sqrt(vector.size(1)) + vector.data.uniform_(-stdv, stdv) + + +def get_current_stage(epoch, epochs_all, epochs_nam, warmup=10): + total = epochs_all + epochs_nam + stage = "nam_only" + if epoch < 0.5 * warmup: + stage = "warmup" + elif epoch < warmup: + stage = "warmup_protonet" + elif (epoch - warmup) % total < epochs_all: + stage = "all" + return stage + + +def get_last_valid_checkpoint(path: Path): + epoch = int(path.stem.split("=")[1].split("-")[0]) + epoch_old = int(path.stem.split("=")[1].split("-")[0]) + config = load_config(path) + e_all = config.model.epochs_all + e_nam = config.model.epochs_nam + warmup = config.model.epochs_warmup + + stage = get_current_stage(epoch, e_all, e_nam, warmup) + while stage == "nam_only": + epoch -= 1 + stage = get_current_stage(epoch, e_all, e_nam, warmup) + if epoch != epoch_old: + epoch += 1 + ckpt_path = str(path) + print(f"Previous epoch {epoch_old} was invalid. Valid checkpoint is of epoch {epoch}") + return ckpt_path.replace(f"epoch={epoch_old}", f"epoch={epoch}") + + +def init_vectors_orthogonally(vector: torch.Tensor, n_protos_per_class: int): + # vector has shape (n_protos, n_chans) + assert vector.size(0) % n_protos_per_class == 0 + torch.nn.init.xavier_uniform_(vector) + + for j in range(vector.size(0)): + vector.data[j, j // n_protos_per_class] += 1. 
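+
+
+# Illustration of the stage schedule produced by `get_current_stage` (the
+# hyperparameter values here are chosen purely as an example): with warmup=10,
+# epochs_all=3 and epochs_nam=4, epochs 0-4 are "warmup", epochs 5-9 are
+# "warmup_protonet", and afterwards the schedule repeats in blocks of
+# epochs_all + epochs_nam = 7 epochs (3x "all", then 4x "nam_only").
+#
+# >>> [get_current_stage(e, epochs_all=3, epochs_nam=4, warmup=10) for e in (4, 5, 10, 13, 17)]
+# ['warmup', 'warmup_protonet', 'all', 'nam_only', 'all']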
+ + +def load_config(ckpt_path: Union[str, Path]) -> DictConfig: + config_path = str(Path(ckpt_path).parent.parent / '.hydra' / 'config.yaml') + with open(config_path) as f: + y = yaml_load(f) + workdir = Path().absolute() + idx = workdir.parts.index('panic') + workdir = Path(*workdir.parts[:idx+1]) + if 'protonet' in y['model']['net']: + y['model']['net']['protonet']['pretrained_model'] = \ + y['model']['net']['protonet']['pretrained_model'].replace('${hydra:runtime.cwd}', str(workdir)) + config = OmegaConf.create(y) + return config + + +def load_model_and_data(ckpt_path: str, device=torch.device("cuda")): + ''' loads model and data with respect to a checkpoint path + must call data.setup(stage) to setup data + pytorch model can be retrieved with model.net ''' + config = load_config(Path(ckpt_path)) + data = hydra_init(config.datamodule) + model = hydra_init(config.model) + model.load_state_dict(torch.load(ckpt_path, map_location=device)["state_dict"]) + return model, data, config + + +def get_git_hash(): + return subprocess.check_output([ + "git", "rev-parse", "HEAD" + ], encoding="utf8").strip() diff --git a/torchpanic/testing.py b/torchpanic/testing.py new file mode 100644 index 0000000..c726eb0 --- /dev/null +++ b/torchpanic/testing.py @@ -0,0 +1,83 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . 
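+#
+# Entry points for evaluating a trained model from a Hydra config: `test`
+# calls `trainer.test` and `validate` calls `trainer.validate`, both restoring
+# the checkpoint referenced by `config.ckpt_path`. A minimal usage sketch
+# (the paths are placeholders; in practice the config object comes from Hydra):
+#
+#     from omegaconf import OmegaConf
+#     cfg = OmegaConf.load("outputs/<run>/.hydra/config.yaml")
+#     cfg.ckpt_path = "outputs/<run>/checkpoints/last.ckpt"
+#     metrics = test(cfg)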
+import logging +from typing import List + +import hydra +from omegaconf import DictConfig +import pytorch_lightning as pl + +LOG = logging.getLogger(__name__) + + +def test(config: DictConfig): + # Set seed for random number generators in pytorch, numpy and python.random + if config.get("seed"): + pl.seed_everything(config.seed, workers=True) + + LOG.info("Instantiating datamodule <%s>", config.datamodule._target_) + data: pl.LightningDataModule = hydra.utils.instantiate(config.datamodule) + + LOG.info("Instantiating model <%s>", config.model._target_) + model: pl.LightningModule = hydra.utils.instantiate(config.model) + + # Init lightning loggers + logger: List[pl.LightningLoggerBase] = [] + if "logger" in config: + for lg_conf in config.logger.values(): + if "_target_" in lg_conf: + LOG.info("Instantiating logger <%s>", lg_conf._target_) + logger.append(hydra.utils.instantiate(lg_conf)) + + LOG.info("Instantiating trainer <%s>", config.trainer._target_) + trainer: pl.Trainer = hydra.utils.instantiate(config.trainer, logger=logger) + + # Log hyperparameters + if trainer.logger: + trainer.logger.log_hyperparams({"ckpt_path": config.ckpt_path}) + + LOG.info("Starting testing!") + return trainer.test(model, data, ckpt_path=config.ckpt_path) + + +def validate(config: DictConfig): + # Set seed for random number generators in pytorch, numpy and python.random + if config.get("seed"): + pl.seed_everything(config.seed, workers=True) + + LOG.info("Instantiating datamodule <%s>", config.datamodule._target_) + data: pl.LightningDataModule = hydra.utils.instantiate(config.datamodule) + data.setup("test") + + LOG.info("Instantiating model <%s>", config.model._target_) + model: pl.LightningModule = hydra.utils.instantiate(config.model) + + # Init lightning loggers + logger: List[pl.LightningLoggerBase] = [] + if "logger" in config: + for lg_conf in config.logger.values(): + if "_target_" in lg_conf: + LOG.info("Instantiating logger <%s>", lg_conf._target_) + logger.append(hydra.utils.instantiate(lg_conf)) + + LOG.info("Instantiating trainer <%s>", config.trainer._target_) + trainer: pl.Trainer = hydra.utils.instantiate(config.trainer, logger=logger) + + # Log hyperparameters + if trainer.logger: + trainer.logger.log_hyperparams({"ckpt_path": config.ckpt_path}) + + LOG.info("Starting validating!") + return trainer.validate(model, data, ckpt_path=config.ckpt_path) diff --git a/torchpanic/training.py b/torchpanic/training.py new file mode 100644 index 0000000..eb55a39 --- /dev/null +++ b/torchpanic/training.py @@ -0,0 +1,60 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . 
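+#
+# Note: when `config.seed` is set, `train` enables fully deterministic
+# execution via `torch.use_deterministic_algorithms(True)`. On CUDA this may
+# additionally require setting the environment variable
+# CUBLAS_WORKSPACE_CONFIG=:4096:8, otherwise some operations raise an error.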
+import logging +from typing import List + +import hydra +from omegaconf import DictConfig +import pytorch_lightning as pl +import torch + +LOG = logging.getLogger(__name__) + + +def train(config: DictConfig): + # Set seed for random number generators in pytorch, numpy and python.random + if config.get("seed"): + pl.seed_everything(config.seed, workers=True) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + torch.use_deterministic_algorithms(True) + + LOG.info("Instantiating datamodule <%s>", config.datamodule._target_) + data: pl.LightningDataModule = hydra.utils.instantiate(config.datamodule) + + LOG.info("Instantiating model <%s>", config.model._target_) + model: pl.LightningModule = hydra.utils.instantiate(config.model) + + # Init lightning callbacks + callbacks: List[pl.Callback] = [] + if "callbacks" in config: + for cb_conf in config.callbacks.values(): + if "_target_" in cb_conf: + LOG.info("Instantiating callback <%s>", cb_conf._target_) + callbacks.append(hydra.utils.instantiate(cb_conf)) + + # Init lightning loggers + logger: List[pl.LightningLoggerBase] = [] + if "logger" in config: + for lg_conf in config.logger.values(): + if "_target_" in lg_conf: + LOG.info("Instantiating logger <%s>", lg_conf._target_) + logger.append(hydra.utils.instantiate(lg_conf)) + + LOG.info("Instantiating trainer <%s>", config.trainer._target_) + trainer: pl.Trainer = hydra.utils.instantiate(config.trainer, callbacks=callbacks, logger=logger) + + LOG.info("Starting training!") + trainer.fit(model, data) diff --git a/torchpanic/viz/__init__.py b/torchpanic/viz/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/torchpanic/viz/nam_functions.py b/torchpanic/viz/nam_functions.py new file mode 100644 index 0000000..b47f506 --- /dev/null +++ b/torchpanic/viz/nam_functions.py @@ -0,0 +1,52 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . 
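+#
+# Command-line helper that plots the learned NAM shape functions of a trained
+# PANIC checkpoint. Example invocation (assuming the package is importable as
+# `torchpanic`; the checkpoint path is a placeholder):
+#
+#     python -m torchpanic.viz.nam_functions /path/to/model.ckpt -o nam_functions.pdf
+#
+# If -o/--output is omitted, the figure is written next to the checkpoint as
+# nam_functions.pdf.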
+import argparse +from pathlib import Path + +import matplotlib.pyplot as plt + +from .tabular import NamInspector, FunctionPlotter + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument("checkpoint", type=Path) + parser.add_argument("-o", "--output", type=Path) + + args = parser.parse_args() + + inspector = NamInspector(args.checkpoint) + inspector.load() + + data, samples, outputs = inspector.get_outputs() + + weights_cat_linear = inspector.get_linear_weights(revert_standardization=True) + + plotter = FunctionPlotter( + class_names=["CN", "MCI", "AD"], + log_scale=["Tau", "p-Tau"], + ) + + out_file = args.output + if out_file is None: + out_file = args.checkpoint.with_name('nam_functions.pdf') + + fig = plotter.plot(data, samples, outputs, weights_cat_linear.drop("bias")) + fig.savefig(out_file, bbox_inches="tight", transparent=True) + plt.close() + + +if __name__ == '__main__': + main() diff --git a/torchpanic/viz/summary.py b/torchpanic/viz/summary.py new file mode 100644 index 0000000..66fab51 --- /dev/null +++ b/torchpanic/viz/summary.py @@ -0,0 +1,162 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . +import matplotlib as mpl +import matplotlib.pyplot as plt +import numpy as np + +import seaborn as sns + +def summary_plot( + predictions, + feature_names=None, + max_display=20, + title="", + class_names=None, + class_inds=None, + colormap="Set2", + axis_color="black", + sort=True, + plot_size="auto", +): + """Create a summary plot of shapley values for 1-dimensional features, colored by feature values when they are provided. + This method is based on Slundberg - SHAP (https://github.com/slundberg/shap/blob/master/shap/) + Args: + predictions : list of numpy.array + For each class, a list of predictions for each function with shape = (n_samples, n_features). + feature_names : list + Names of the features (length # features) + max_display : int + How many top features to include in the plot (default is 20) + plot_size : "auto" (default), float, (float, float), or None + What size to make the plot. By default the size is auto-scaled based on the number of + features that are being displayed. Passing a single float will cause each row to be that + many inches high. Passing a pair of floats will scale the plot by that + number of inches. If None is passed then the size of the current figure will be left + unchanged. + """ + + multi_class = True + + # default color: + if multi_class: + #cm = mpl.colormaps[colormap] + cm = sns.color_palette(colormap, n_colors=3) + #color = lambda i: cm(i) + color = lambda i: cm[(i + 2) % 3] + + num_features = predictions[0].shape[1] if multi_class else predictions.shape[1] + + shape_msg = ( + "The shape of the shap_values matrix does not match the shape of the " + "provided data matrix." 
+ ) + if num_features - 1 == len(feature_names): + raise ValueError( + shape_msg + + " Perhaps the extra column in the shap_values matrix is the " + "constant offset? If so just pass shap_values[:,:-1]." + ) + elif num_features != len(feature_names): + raise ValueError(shape_msg) + + if sort: + # order features by the sum of their effect magnitudes + if multi_class: + feature_order = np.argsort( + np.sum(np.mean(np.abs(predictions), axis=0), axis=0) + ) + else: + feature_order = np.argsort(np.sum(np.abs(predictions), axis=0)) + feature_order = feature_order[-min(max_display, len(feature_order)) :] + else: + feature_order = np.flip(np.arange(min(max_display, num_features)), 0) + + row_height = 0.4 + if plot_size == "auto": + figsize = (8, len(feature_order) * row_height + 1.5) + elif isinstance(plot_size, (list, tuple)): + figsize = (plot_size[0], plot_size[1]) + elif plot_size is not None: + figsize = (8, len(feature_order) * plot_size + 1.5) + + fig, ax = plt.subplots(figsize=figsize) + ax.axvline(x=0, color="#999999", zorder=-1) + + legend_handles = [] + legend_text = [] + + if class_names is None: + class_names = [f"Class {i}" for i in range(len(predictions))] + feature_inds = feature_order[:max_display] + y_pos = np.arange(len(feature_inds)) + left_pos = np.zeros(len(feature_inds)) + + if class_inds is None: + class_inds = np.argsort([-np.abs(p).mean() for p in predictions]) + elif class_inds == "original": + class_inds = range(len(predictions)) + for i, ind in enumerate(class_inds): + global_importance = np.abs(predictions[ind]).mean(0) + ax.barh( + y_pos, global_importance[feature_inds], 0.7, left=left_pos, align='center', + color=color(i), label=class_names[ind] + ) + left_pos += global_importance[feature_inds] + f_names_relevant = [feature_names[i] for i in feature_inds] + ax.set_yticklabels(f_names_relevant) + + ax.xaxis.set_ticks_position("top") + ax.xaxis.set_label_position("top") + ax.yaxis.set_ticks_position("none") + ax.spines["right"].set_visible(False) + ax.spines["bottom"].set_visible(False) + ax.spines["left"].set_visible(False) + ax.tick_params(color=axis_color, labelcolor=axis_color) + + plt.yticks( + range(len(feature_order)), + [feature_names[i] for i in feature_order], + fontsize='large' + ) + plt.ylim(-1, len(feature_order)) + + ax.set_title(title) + ax.tick_params(color=axis_color, labelcolor=axis_color) + ax.tick_params("y", length=20, width=0.5, which="major") + ax.tick_params("x") # , labelsize=11) + ax.xaxis.grid(True) + ax.set_xlabel("Mean Importance", fontsize='large') + + # legend_handles.append(missing_handle) + # legend_text.append("Missing") + if len(legend_handles) > 0: + plt.legend( + legend_handles[::-1], + legend_text[::-1], + loc="center right", + bbox_to_anchor=(1.30, 0.55), + frameon=False, + ) + else: + plt.legend( + loc="best", + frameon=True, + fancybox=True, + facecolor='white', + framealpha=1.0, + fontsize=12, + ) + + return fig, f_names_relevant diff --git a/torchpanic/viz/tabular.py b/torchpanic/viz/tabular.py new file mode 100644 index 0000000..9eb2d51 --- /dev/null +++ b/torchpanic/viz/tabular.py @@ -0,0 +1,429 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. 
+# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . +from typing import Sequence, Union + +import matplotlib as mpl +from matplotlib import gridspec +from matplotlib.patches import Rectangle +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +import torch +from torch.utils.data import DataLoader, Dataset + +from ..datamodule.adni import AdniDataset +from ..datamodule.credit_card_fraud import PandasDataSet +from ..models.panic import PANIC +from ..models.nam import BaseNAM, NAM +from ..modules.utils import load_model_and_data +from ..datamodule.modalities import ModalityType +from .utils import map_tabular_names + + +def collect_tabular_data(dataset: Dataset) -> pd.DataFrame: + data_values = [] + data_miss = [] + for i in range(len(dataset)): + x = dataset[i][0] + x_tab, x_miss = x[ModalityType.TABULAR] + + if isinstance(x_tab, torch.Tensor): + x_tab = x_tab.detach().cpu().numpy() # shape = (bs, num_features) + if isinstance(x_miss, torch.Tensor): + x_miss = x_miss.detach().cpu().numpy() # shape = (bs, num_features) + + data_values.append(x_tab) + data_miss.append(x_miss) + + if isinstance(dataset, AdniDataset): + index = dataset.rid + column_names = map_tabular_names(dataset.column_names) + elif isinstance(dataset, PandasDataSet): + index, column_names = None, None + else: + raise ValueError(f"Dataset of type {type(dataset)} not implemented") + + data = np.ma.array( + np.stack(data_values, axis=0), + mask=np.stack(data_miss, axis=0), + copy=False, + ) + data = pd.DataFrame(data, index=index, columns=column_names) + + return data + + +def create_sample_data(tabular_data: pd.DataFrame, n_samples: int = 500) -> torch.Tensor: + samples = torch.empty((n_samples, tabular_data.shape[1])) + for j, (_, series) in enumerate(tabular_data.iteritems()): + if hasattr(series, "cat"): + uniq_values = torch.from_numpy(series.cat.categories.to_numpy()) + values = torch.cat(( + uniq_values, + # fill by repeating the last category + torch.full((n_samples - len(uniq_values),), uniq_values[-1]) + )) + else: + q = series.quantile([0.01, 0.99]) + values = torch.linspace(q.iloc[0], q.iloc[1], n_samples) + samples[:, j] = values + return samples + + +def get_modality_shape(dataloader: DataLoader, modality: ModalityType): + dataset: Union[AdniDataset, PandasDataSet] = dataloader.dataset + return dataset[0][0][modality].shape + + +def iter_datasets(datamodule): + for path in ( + datamodule.train_data, datamodule.valid_data, datamodule.test_data, + ): + yield AdniDataset(path, is_training=False, modalities=datamodule.modalities) + + +class NamInspector: + def __init__( + self, checkpoint_path: str, dataset_name: str = "test", device=torch.device("cuda"), + ) -> None: + self.checkpoint_path = checkpoint_path + self.dataset_name = dataset_name + self.device = device + + def load(self): + plmodule, datamod, config = load_model_and_data(self.checkpoint_path) + datamod.setup("fit" if self.dataset_name != "test" else self.dataset_name) + + data = pd.concat([collect_tabular_data(ds) for ds in iter_datasets(datamod)]) + idx_cat = config.model.net.nam.idx_cat_features + for j in idx_cat: + data.iloc[:, j] = data.iloc[:, j].astype("category") + + self._data = data + self._model = 
plmodule.net.eval().to(self.device) + self._config = config + self._datamod = datamod + + def _get_dataloader(self): + dsnames = { + "fit": self._datamod.push_dataloader, + "val": self._datamod.val_dataloader, + "test": self._datamod.test_dataloader, + } + return dsnames[self.dataset_name]() + + @torch.no_grad() + def _collect_outputs( + self, samples_loader: DataLoader, include_missing: bool, + ) -> Sequence[np.ndarray]: + dataloader = self._get_dataloader() + + non_missing = torch.zeros((1, 1, self._data.shape[1]), device=self.device) + img_data = torch.zeros( + get_modality_shape(dataloader, ModalityType.PET), device=self.device + ).unsqueeze(0) + + n_prototypes = self._config.model.net.protonet.n_prototypes_per_class + + model: PANIC = self._model + outputs = [] + for x in samples_loader: + x = x.to(self.device).unsqueeze(1) + tab_in = torch.cat((x, non_missing.expand_as(x)), axis=1) + img_in = img_data.expand(x.shape[0], -1, -1, -1, -1) + + logits, similarities, occurrences, nam_features_without_dropout = model(img_in, tab_in) + # only collect outputs referring to tabular features + feature_outputs = nam_features_without_dropout[:, n_prototypes:] + + outputs.append(feature_outputs.detach().cpu().numpy()) + + if include_missing: + tab_in = torch.ones_like(tab_in)[:1] + img_in = img_in[:1] + logits, similarities, occurrences, nam_features_without_dropout = model( + img_in, tab_in + ) + feature_outputs = nam_features_without_dropout[:, n_prototypes:] + outputs.append(feature_outputs.detach().cpu().numpy()) + + return outputs + + def _apply_inverse_transform(self, data): + dataloader = self._get_dataloader() + + # invert standardize transform to restore original feature distributions + dataset: AdniDataset = dataloader.dataset + if isinstance(data, pd.DataFrame): + data_original = pd.DataFrame( + dataset.tabular_inverse_transform(data.values), + index=data.index, columns=data.columns, + ) + for col in self._data.select_dtypes(include=["category"]).columns: + vals = data_original.loc[:, col].apply(np.rint) # round to nearest integer + data_original.loc[:, col] = vals.astype("category") + else: + data_original = dataset.tabular_inverse_transform(data) + return data_original + + def get_outputs(self, plt_embeddings=False): + samples = create_sample_data(self._data) + samples_loader = DataLoader(samples, batch_size=self._config.datamodule.batch_size) + + outputs = self._collect_outputs(samples_loader, include_missing=plt_embeddings) + + samples = np.asfortranarray(samples) + data_original = self._apply_inverse_transform(self._data) + samples_original = self._apply_inverse_transform(samples) + + if plt_embeddings: + samples_original = np.row_stack( + (samples_original, 99 * np.ones(samples_original.shape[1])) + ) + data_original = data_original.append(pd.Series( + 99, index=data_original.columns, name="MISSING" + )) + + outputs = np.concatenate(outputs, axis=0) + outputs = np.asfortranarray(outputs) + + # reorder columns so they are in the same order as the output of the model + nam_config = self._config.model.net.nam + idx = nam_config["idx_real_features"] + nam_config["idx_cat_features"] + data_original = data_original.iloc[:, idx] + samples_original = samples_original[:, idx] + + return data_original, samples_original, outputs + + def get_linear_weights(self, revert_standardization: bool) -> pd.DataFrame: + nam_model: BaseNAM = self._model.nam + bias = nam_model.bias.detach().cpu().numpy() + weights = nam_model.cat_linear.detach().cpu().numpy() + + dataloader = self._get_dataloader() + 
dataset: AdniDataset = dataloader.dataset + # reorder columns so they are in the same order as the output of the model + nam_config = self._config.model.net.nam + cat_idx = nam_config["idx_cat_features"] + columns = ["bias"] + dataset.column_names[cat_idx].tolist() + + if revert_standardization: + mean = dataset.tabular_mean[cat_idx] + std = dataset.tabular_stddev[cat_idx] + + weights_new = np.empty_like(weights) + bias_new = np.empty_like(bias) + for k in range(weights.shape[1]): + weights_new[:, k] = weights[:, k] / std + bias_new[:, k] = bias[:, k] - np.dot(weights_new[:, k], mean) + else: + bias_new = bias + weights_new = weights + + coef = pd.DataFrame( + np.row_stack((bias_new, weights_new)), index=map_tabular_names(columns), + ) + coef.index.name = "feature" + coef.columns.name = "target" + return coef + + +def get_fraudnet_outputs_from_checkpoint(checkpoint_path, device=torch.device("cuda")): + plmodule, data, config = load_model_and_data(checkpoint_path) + data.setup("fit") + + data = collect_tabular_data(data.train_dataloader()) + samples = create_sample_data(data) + samples_loader = DataLoader(samples, batch_size=config.datamodule.batch_size) + non_missing = torch.zeros((1, 1, data.shape[1]), device=device) + + model: NAM = plmodule.net.eval().to(device) + outputs = [] + + with torch.no_grad(): + for x in samples_loader: + x = x.to(device).unsqueeze(1) + x = torch.cat((x, non_missing.expand_as(x)), axis=1) + logits, nam_features = model.base_forward(x) + outputs.append(nam_features.detach().cpu().numpy()) + outputs = np.concatenate(outputs) + + samples = np.asfortranarray(samples) + outputs = np.asfortranarray(outputs) + + return data, samples, outputs + + +class FunctionPlotter: + def __init__( + self, + class_names: Sequence[str], + log_scale: Sequence[str], + n_cols: int = 6, + size: float = 2.5, + ) -> None: + self.class_names = class_names + self.log_scale = frozenset(log_scale) + self.n_cols = n_cols + self.size = size + + def plot( + self, + data: pd.DataFrame, + samples: np.ndarray, + outputs: np.ndarray, + categorial_coefficients: pd.Series, + ): + assert data.shape[1] == samples.shape[1] + assert data.shape[1] == outputs.shape[1] + assert samples.shape[0] == outputs.shape[0] + assert samples.ndim == 2 + assert outputs.ndim == 3 + assert len(self.class_names) == outputs.shape[2] + + assert len(categorial_coefficients.index.difference(data.columns)) == 0 + + rc_params = { + "axes.titlesize": "small", + "xtick.labelsize": "x-small", + "ytick.labelsize": "x-small", + "lines.linewidth": 2, + } + + with mpl.rc_context(rc_params): + return self._plot_functions(data, samples, outputs, categorial_coefficients) + + def _plot_functions( + self, + data: pd.DataFrame, + samples: np.ndarray, + outputs: np.ndarray, + categorial_coefficients: pd.Series, + ): + categorical_columns = frozenset(categorial_coefficients.index) + + n_cols = self.n_cols + n_features = data.shape[1] + n_rows = int(np.ceil(n_features / n_cols)) + + fig = plt.figure( + figsize=(n_cols * self.size, n_rows * self.size) + ) + gs_outer = gridspec.GridSpec( + n_rows, n_cols, figure=fig, hspace=0.35, wspace=0.3, + ) + + n_odds_ratios = outputs.shape[2] - 1 + palette = sns.color_palette("Set2", n_colors=n_odds_ratios) + + ref_class_name = self.class_names[0] + legend_data = { + f"{name} vs {ref_class_name}": c + for name, c in zip(self.class_names[1:], palette) + } + + ax_legend = 0 + h_ratios = [5, 1] + for idx, (name, values) in enumerate(data.iteritems()): + i = idx // n_cols + j = idx % n_cols + gs = gs_outer[i, 
j].subgridspec(2, 1, height_ratios=h_ratios, hspace=0.1) + ax_top = fig.add_subplot(gs[0, 0]) + ax_bot = fig.add_subplot(gs[1, 0], sharex=ax_top) + if idx == data.shape[1] - 1: + ax_legend = ax_top + + values = values.dropna() + for cls_idx in range(1, n_odds_ratios + 1): + if name in categorical_columns: + plot_fn = self._plot_categorical + coef_ref = categorial_coefficients.loc[name, 0] + coef_new = categorial_coefficients.loc[name, cls_idx] + x = data.loc[:, name].cat.categories.to_numpy() + y = (coef_new - coef_ref) * (x - x[0]) # log-odds ratio + else: + plot_fn = self._plot_continuous + x = samples[:, idx] + mid = np.searchsorted(x, data.loc[:, name].mean()) + log_odds_mid = outputs[mid, idx, cls_idx] - outputs[mid, idx, 0] + log_odds = outputs[:, idx, cls_idx] - outputs[:, idx, 0] + y = log_odds - log_odds_mid # log-odds ratio + + plot_fn( + x=x, + y=y, + values=values if cls_idx == 1 else None, + ax_top=ax_top, + ax_bot=ax_bot, + color=palette[cls_idx - 1], + label=f"Class {cls_idx}", + ) + + ax_top.axhline(0.0, color="gray") + ax_top.tick_params(axis="x", labelbottom=False) + ax_top.set_title(name) + if name in self.log_scale: + ax_top.set_xscale("log") + + if j == 0: + ax_top.set_ylabel("log odds ratio") + + legend_kwargs = {"loc": "center left", "bbox_to_anchor": (1.05, 0.5)} + self._add_legend(ax_legend, legend_data, legend_kwargs) + + return fig + + def _add_legend(self, ax, palette, legend_kwargs=None): + handles = [] + labels = [] + for name, color in palette.items(): + p = Rectangle((0, 0), 1, 1) + p.set_facecolor(color) + handles.append(p) + labels.append(name) + + if legend_kwargs is None: + legend_kwargs = {"loc": "best"} + + ax.legend(handles, labels, **legend_kwargs) + + def _plot_continuous(self, x, y, values, ax_top, ax_bot, color, label): + ax_top.plot(x, y, marker="none", color=color, label=label, zorder=2.5) + ax_top.grid(True) + + if values is not None: + ax_bot.boxplot( + values.dropna(), + vert=False, + widths=0.75, + showmeans=True, + medianprops={"color": "#ff7f00"}, + meanprops={ + "marker": "d", "markeredgecolor": "#a65628", "markerfacecolor": "none", + }, + flierprops={"marker": "."}, + showfliers=False, + ) + ax_bot.yaxis.set_visible(False) + + def _plot_categorical(self, x, y, values, ax_top, ax_bot, color, label): + ax_top.step(x, y, where="mid", color=color, label=label, zorder=2.5) + ax_top.grid(True) + + if values is not None: + _, counts = np.unique(values, return_counts=True) + assert len(counts) == len(x) + ax_bot.bar(x, height=counts / counts.sum() * 100., width=0.6, color="dimgray") diff --git a/torchpanic/viz/utils.py b/torchpanic/viz/utils.py new file mode 100644 index 0000000..52e74ed --- /dev/null +++ b/torchpanic/viz/utils.py @@ -0,0 +1,106 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see . 
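+#
+# Helper utilities for the visualization scripts: `get_tabnet_predictions`
+# runs a trained PANIC model over a dataloader and returns, per sample, the
+# bias plus every additive (NAM) term together with the logits and predicted
+# class, while `map_tabular_names` translates raw ADNI column names (including
+# SNP identifiers) into human-readable labels using a metadata CSV.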
+from typing import Any, Dict, Sequence + +import numpy as np +import pandas as pd +import torch +from torch.utils.data import DataLoader + +from ..datamodule.modalities import ModalityType +from ..models.panic import PANIC + +_NAME_MAPPER = None + + +@torch.no_grad() +def get_tabnet_predictions(model: PANIC, data_loader: DataLoader, device = torch.device("cuda")): + model = model.eval().to(device) + bias = model.nam.bias.detach().cpu().numpy()[np.newaxis] + + all_logits = [] + predictions = [] + for x, x_raw, aug, y in data_loader: + logits, similarities, occurrences, nam_features_without_dropout = model( + x[ModalityType.PET].to(device), + x[ModalityType.TABULAR].to(device), + ) + nam_features_without_dropout = nam_features_without_dropout.detach().cpu().numpy() + outputs = np.concatenate(( + np.tile(bias, [logits.shape[0], 1, 1]), + nam_features_without_dropout, + ), axis=1) + predictions.append(outputs) + all_logits.append(logits.detach().cpu().numpy()) + + all_logits = np.concatenate(all_logits) + y_pred = np.argmax(all_logits, axis=1) + return np.concatenate(predictions), all_logits, y_pred + + +def _set_name_mapper(metadata: str): + global _NAME_MAPPER + + name_mapper = { + "real_age": "Age", + "PTEDUCAT": "Education", + "PTGENDER": "Male", + "ABETA": "A$\\beta$", + "TAU": "Tau", + "PTAU": "p-Tau", + "Left-Hippocampus": "Left Hippocampus", + "Right-Hippocampus": "Right Hippocampus", + "lh_entorhinal_thickness": "Left Entorhinal Cortex", + "rh_entorhinal_thickness": "Right Entorhinal Cortex", + } + snp_metadata = pd.read_csv( + metadata, index_col=0, + ) + snp_key = snp_metadata.agg( + lambda x: ":".join(map(str, x[ + ["Chromosome", "Position", "Allele2_reference", "Allele1_alternative"] + ])) + "_" + x["Allele1_alternative"], + axis=1 + ) + name_mapper.update((dict(zip(snp_key, snp_metadata.loc[:, "rsid"])))) + _NAME_MAPPER = name_mapper + + +def _get_name_mapper(metadata: str): + global _NAME_MAPPER + + if _NAME_MAPPER is None: + _set_name_mapper(metadata) + return _NAME_MAPPER + + +def map_tabular_names(names: Sequence[str], metadata_file: str) -> Sequence[str]: + name_mapper = _get_name_mapper(metadata_file) + return [name_mapper.get(name, name) for name in names] + + +def get_tabnet_output_names( + tabular_names: Sequence[str], n_prototypes: int, config: Dict[str, Any], metadata_file: str +) -> Sequence[str]: + name_mapper = _get_name_mapper(metadata_file) + + # reorder columns so they are in the same order as the output of the model + nam_config = config["model"]["net"]["nam"] + order = nam_config["idx_real_features"] + nam_config["idx_cat_features"] + + features_names = [name_mapper.get(tabular_names[i], tabular_names[i]) for i in order] + + prototype_names = [f"FDG-PET Proto {i}" for i in range(n_prototypes)] + return ["bias"] + prototype_names + features_names diff --git a/torchpanic/viz/waterfall.py b/torchpanic/viz/waterfall.py new file mode 100644 index 0000000..6d35af3 --- /dev/null +++ b/torchpanic/viz/waterfall.py @@ -0,0 +1,342 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. 
See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see https://www.gnu.org/licenses/. +import re + +import matplotlib.pyplot as plt +from matplotlib import colors +from matplotlib.transforms import ScaledTranslation +import numpy as np + + +def format_value(s, format_str): + """ Strips trailing zeros and uses a unicode minus sign. + """ + if np.ma.isMaskedArray(s) and np.ma.is_masked(s): + return "missing" + + if not isinstance(s, str): + s = format_str % s + s = re.sub(r'\.?0+$', '', s) + if s[0] == "-": + s = u"\u2212" + s[1:] + return s + + + +class WaterfallPlotter: + def __init__(self, n_prototypes: int) -> None: + self.n_prototypes = n_prototypes + + def plot( + self, + expected_value, + shap_values, + features=None, + feature_names=None, + actual_class_name=None, + predicted_class_name=None, + max_display=10, + ): + """ Plots an explanation of a single prediction as a waterfall plot. + The SHAP value of a feature represents the impact of the evidence provided by that feature on the model's + output. The waterfall plot is designed to visually display how the SHAP values (evidence) of each feature + move the model output from our prior expectation under the background data distribution, to the final model + prediction given the evidence of all the features. Features are sorted by the magnitude of their SHAP values + with the smallest magnitude features grouped together at the bottom of the plot when the number of features + in the model exceeds the max_display parameter. + Parameters + ---------- + expected_value : float + This is the reference value that the feature contributions start from. For SHAP values it should + be the value of explainer.expected_value. + shap_values : numpy.array + One dimensional array of SHAP values. + features : numpy.array + One dimensional array of feature values. This provides the values of all the + features, and should be the same shape as the shap_values argument. + feature_names : list + List of feature names (# features). + actual_class_name : str + Name of actual class. + predicted_class_name : str + Name of predicted class. + max_display : int + The maximum number of features to plot. + """ + + # init variables we use for tracking the plot locations + num_features = min(max_display, len(shap_values)) + row_height = 0.2 + rng = range(num_features - 1, -1, -1) + order = np.argsort(-np.abs(shap_values)) + pos_lefts = [] + pos_inds = [] + pos_widths = [] + neg_lefts = [] + neg_inds = [] + neg_widths = [] + loc = expected_value + shap_values.sum() + yticklabels = ["" for i in range(num_features + 1)] + + # size the plot based on how many features we are plotting + fig = plt.figure(figsize=(8, num_features * row_height + 1.5)) + + # see how many individual (vs.
grouped at the end) features we are plotting + if num_features == len(shap_values): + num_individual = num_features + else: + num_individual = num_features - 1 + + # compute the locations of the individual features and plot the dashed connecting lines + for i in range(num_individual): + sval = shap_values[order[i]] + loc -= sval + if sval >= 0: + pos_inds.append(rng[i]) + pos_widths.append(sval) + pos_lefts.append(loc) + else: + neg_inds.append(rng[i]) + neg_widths.append(sval) + neg_lefts.append(loc) + if num_individual != num_features or i + 4 < num_individual: + plt.plot( + [loc, loc], + [rng[i] - 1 - 0.4, rng[i] + 0.4], + color="#bbbbbb", + linestyle="--", + linewidth=0.5, + zorder=-1, + ) + if features is None or order[i] < self.n_prototypes: + yticklabels[rng[i]] = feature_names[order[i]] + else: + yticklabels[rng[i]] = ( + feature_names[order[i]] + + " = " + + format_value(features[order[i]], "%0.03f") + ) + + # add a last grouped feature to represent the impact of all the features we didn't show + if num_features < len(shap_values): + yticklabels[0] = "%d other features" % (len(shap_values) - num_features + 1) + remaining_impact = expected_value - loc + if remaining_impact < 0: + pos_inds.append(0) + pos_widths.append(-remaining_impact) + pos_lefts.append(loc + remaining_impact) + else: + neg_inds.append(0) + neg_widths.append(-remaining_impact) + neg_lefts.append(loc + remaining_impact) + + points = ( + pos_lefts + + list(np.array(pos_lefts) + np.array(pos_widths)) + + neg_lefts + + list(np.array(neg_lefts) + np.array(neg_widths)) + ) + dataw = np.max(points) - np.min(points) + + # draw invisible bars just for sizing the axes + label_padding = np.array([0.1 * dataw if w < 1 else 0 for w in pos_widths]) + plt.barh( + pos_inds, + np.array(pos_widths) + label_padding + 0.02 * dataw, + left=np.array(pos_lefts) - 0.01 * dataw, + color=colors.to_rgba_array("r"), + alpha=0, + ) + label_padding = np.array([-0.1 * dataw if -w < 1 else 0 for w in neg_widths]) + plt.barh( + neg_inds, + np.array(neg_widths) + label_padding - 0.02 * dataw, + left=np.array(neg_lefts) + 0.01 * dataw, + color=colors.to_rgba_array("b"), + alpha=0, + ) + + # define variable we need for plotting the arrows + head_length = 0.08 + bar_width = 0.8 + xlen = plt.xlim()[1] - plt.xlim()[0] + fig = plt.gcf() + ax = plt.gca() + xticks = ax.get_xticks() + bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted()) + width, height = bbox.width, bbox.height + bbox_to_xscale = xlen / width + hl_scaled = bbox_to_xscale * head_length + renderer = fig.canvas.get_renderer() + + # draw the positive arrows + for i in range(len(pos_inds)): + dist = pos_widths[i] + arrow_obj = plt.arrow( + pos_lefts[i], + pos_inds[i], + max(dist - hl_scaled, 0.000001), + 0, + head_length=min(dist, hl_scaled), + color="tab:red", + width=bar_width, + head_width=bar_width, + ) + + txt_obj = plt.text( + pos_lefts[i] + 0.5 * dist, + pos_inds[i], + format_value(pos_widths[i], "%+0.02f"), + horizontalalignment="center", + verticalalignment="center", + color="white", + fontsize=12, + ) + text_bbox = txt_obj.get_window_extent(renderer=renderer) + arrow_bbox = arrow_obj.get_window_extent(renderer=renderer) + + # if the text overflows the arrow then draw it after the arrow + if text_bbox.width > arrow_bbox.width: + txt_obj.remove() + + # draw the negative arrows + for i in range(len(neg_inds)): + dist = neg_widths[i] + + arrow_obj = plt.arrow( + neg_lefts[i], + neg_inds[i], + -max(-dist - hl_scaled, 0.000001), + 0, + head_length=min(-dist, 
hl_scaled), + color="tab:blue", + width=bar_width, + head_width=bar_width, + ) + + txt_obj = plt.text( + neg_lefts[i] + 0.5 * dist, + neg_inds[i], + format_value(neg_widths[i], "%+0.02f"), + horizontalalignment="center", + verticalalignment="center", + color="white", + fontsize=12, + ) + text_bbox = txt_obj.get_window_extent(renderer=renderer) + arrow_bbox = arrow_obj.get_window_extent(renderer=renderer) + + # if the text overflows the arrow then draw it after the arrow + if text_bbox.width > arrow_bbox.width: + txt_obj.remove() + + # draw the y-ticks twice, once in gray and then again with just the feature names in black + # The 1e-8 is so matplotlib 3.3 doesn't try and collapse the ticks + ytick_pos = np.arange(num_features) + plt.yticks(ytick_pos, yticklabels[:-1], fontsize=13) + + # put horizontal lines for each feature row + for i in range(num_features): + plt.axhline(i, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1) + + # mark the prior expected value and the model prediction + plt.axvline( + expected_value, + color="#bbbbbb", + linestyle="--", + linewidth=0.85, + zorder=-1, + ) + fx = expected_value + shap_values.sum() + plt.axvline(fx, color="#bbbbbb", linestyle="--", linewidth=0.85, zorder=-1) + + # clean up the main axis + plt.gca().xaxis.set_ticks_position("bottom") + plt.gca().yaxis.set_ticks_position("none") + plt.gca().spines["right"].set_visible(False) + plt.gca().spines["top"].set_visible(False) + plt.gca().spines["left"].set_visible(False) + ax.tick_params(labelsize=13) + # plt.xlabel("Model output", fontsize=12) + + # draw the E[f(X)] tick mark + xmin, xmax = ax.get_xlim() + xticks = ax.get_xticks() + xticks = list(xticks) + min_ind = 0 + min_diff = 1e10 + for i in range(len(xticks)): + v = abs(xticks[i] - expected_value) + if v < min_diff: + min_diff = v + min_ind = i + xticks.pop(min_ind) + ax.set_xticks(xticks) + ax.tick_params(labelsize=13) + ax.set_xlim(xmin, xmax) + + ax2 = ax.twiny() + ax2.set_xlim(xmin, xmax) + ax2.set_xticks([expected_value, expected_value + 1e-8]) + ax2.set_xticklabels(["\nbias", "\n$ = " + format_value(expected_value, "%0.03f") + "$"], fontsize=12, ha="left") + ax2.spines["right"].set_visible(False) + ax2.spines["top"].set_visible(False) + ax2.spines["left"].set_visible(False) + + # draw the f(x) tick mark + ax3 = ax2.twiny() + ax3.set_xlim(xmin, xmax) + ax3.set_xticks([fx, fx + 1e-8]) + the_class = "^{\\mathrm{%s}}" % predicted_class_name if predicted_class_name is not None else "" + ax3.set_xticklabels( + [f"$\\mu{the_class}$", "$ = " + format_value(fx, "%0.03f") + "$"], fontsize=12, ha="left" + ) + tick_labels = ax3.xaxis.get_majorticklabels() + tick_labels[0].set_transform( + tick_labels[0].get_transform() + + ScaledTranslation(-10 / 72.0, 0, fig.dpi_scale_trans) + ) + tick_labels[1].set_transform( + tick_labels[1].get_transform() + + ScaledTranslation(12 / 72.0, 0, fig.dpi_scale_trans) + ) + # tick_labels[1].set_color("#999999") + ax3.spines["right"].set_visible(False) + ax3.spines["top"].set_visible(False) + ax3.spines["left"].set_visible(False) + + # adjust the position of the E[f(X)] = x.xx label + tick_labels = ax2.xaxis.get_majorticklabels() + tick_labels[0].set_transform( + tick_labels[0].get_transform() + + ScaledTranslation(-14 / 72.0, 0, fig.dpi_scale_trans) + ) + tick_labels[1].set_transform( + tick_labels[1].get_transform() + + ScaledTranslation( + 11 / 72.0, -1 / 72.0, fig.dpi_scale_trans + ) + ) + # tick_labels[1].set_color("#999999") + + the_title = [] + if predicted_class_name is not None: + the_title.append(f"Predicted: 
{predicted_class_name}") + if actual_class_name is not None: + the_title.append(f"Actual: {actual_class_name}") + if len(the_title) > 0: + plt.title(", ".join(the_title)) + + return fig diff --git a/train.py b/train.py new file mode 100644 index 0000000..3837227 --- /dev/null +++ b/train.py @@ -0,0 +1,30 @@ +# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC). +# +# PANIC is free software: you can redistribute it and/or modify +# it under the terms of the GNU General Public License as published by +# the Free Software Foundation, either version 3 of the License, or +# (at your option) any later version. +# +# PANIC is distributed in the hope that it will be useful, +# but WITHOUT ANY WARRANTY; without even the implied warranty of +# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +# GNU General Public License for more details. +# +# You should have received a copy of the GNU General Public License +# along with PANIC. If not, see https://www.gnu.org/licenses/. +import logging + +import hydra +from omegaconf import DictConfig + +from torchpanic.training import train + + +@hydra.main(config_path="configs/", config_name="train.yaml", version_base="1.2.0") +def main(config: DictConfig): + return train(config) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.INFO) + main()
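
The helpers in torchpanic/viz/utils.py and the WaterfallPlotter are designed to be used together, but this patch ships no driver script for them. The sketch below is not part of the patch; it shows one plausible way to wire them up for a single test sample, assuming a trained PANIC model, a DataLoader, the Hydra config dict, the SNP metadata CSV path, the number of prototypes, and the class names are available from the surrounding project (all of those names are placeholders here). The indexing follows the layout produced by get_tabnet_predictions, where the tiled NAM bias is the first entry along axis 1 and the remaining entries are the additive contributions of the prototypes and tabular features.

import torch

from torchpanic.viz.utils import get_tabnet_output_names, get_tabnet_predictions
from torchpanic.viz.waterfall import WaterfallPlotter


def plot_waterfall_for_sample(
    model,                      # trained PANIC model (placeholder; loaded elsewhere)
    data_loader,                # test DataLoader built by the project's datamodule (placeholder)
    tabular_names,              # raw tabular column names, in datamodule order (placeholder)
    config,                     # resolved Hydra config as a plain dict (placeholder)
    metadata_file,              # path to the SNP metadata CSV (placeholder)
    sample_idx=0,
    n_prototypes=3,             # placeholder; in practice read from the config
    class_names=("CN", "MCI", "AD"),  # placeholder class labels
):
    # predictions has shape (n_samples, 1 + n_prototypes + n_tabular, n_classes);
    # position 0 along axis 1 is the bias term tiled over the batch.
    predictions, logits, y_pred = get_tabnet_predictions(
        model, data_loader, device=torch.device("cpu")
    )

    # ["bias", "FDG-PET Proto 0", ..., <tabular feature names>],
    # same ordering as axis 1 of `predictions`.
    output_names = get_tabnet_output_names(tabular_names, n_prototypes, config, metadata_file)

    cls = int(y_pred[sample_idx])                      # predicted class for this sample
    expected_value = predictions[sample_idx, 0, cls]   # bias of the predicted class
    shap_values = predictions[sample_idx, 1:, cls]     # per-feature additive contributions

    plotter = WaterfallPlotter(n_prototypes=n_prototypes)
    fig = plotter.plot(
        expected_value,
        shap_values,
        feature_names=output_names[1:],                # drop the leading "bias" entry
        predicted_class_name=class_names[cls],
        max_display=10,
    )
    fig.savefig(f"waterfall_sample_{sample_idx}.pdf", bbox_inches="tight")
    return fig

If the model output is indeed the plain sum of these additive terms, logits[sample_idx, cls] should match expected_value + shap_values.sum() up to numerical error; checking that equality is a cheap sanity test for the indexing assumed in this sketch.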