diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..c537495
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,191 @@
+# Created by https://www.gitignore.io/api/python,pycharm,jupyternotebooks,visualstudiocode
+# Edit at https://www.gitignore.io/?templates=python,pycharm,jupyternotebooks,visualstudiocode
+
+### JupyterNotebooks ###
+# gitignore template for Jupyter Notebooks
+# website: http://jupyter.org/
+
+.ipynb_checkpoints
+*/.ipynb_checkpoints/*
+
+# IPython
+profile_default/
+ipython_config.py
+
+# Remove previous ipynb_checkpoints
+# git rm -r .ipynb_checkpoints/
+
+### PyCharm ###
+# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio and WebStorm
+# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+.idea/
+
+# CMake
+cmake-build-*/
+
+# File-based project format
+*.iws
+
+# IntelliJ
+out/
+
+# mpeltonen/sbt-idea plugin
+.idea_modules/
+### Python ###
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+pip-wheel-metadata/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+.hypothesis/
+.pytest_cache/
+
+# Translations
+*.mo
+*.pot
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+target/
+
+# pyenv
+.python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# celery beat schedule file
+celerybeat-schedule
+
+# SageMath parsed files
+*.sage.py
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# Mr Developer
+.mr.developer.cfg
+.project
+.pydevproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+### VisualStudioCode ###
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+
+### VisualStudioCode Patch ###
+# Ignore all local history of files
+.history
+
+# End of https://www.gitignore.io/api/python,pycharm,jupyternotebooks,visualstudiocode
+
+/data/
+lightning_logs/
+logs/
+outputs/
+*.ckpt
+*.pt[h]
+*.html
+*.pkl
+*.pkl.bz2
+*.pkl.gz
+.flake8
+
+
+# Created by https://www.toptal.com/developers/gitignore/api/vim
+# Edit at https://www.toptal.com/developers/gitignore?templates=vim
+
+### Vim ###
+# Swap
+[._]*.s[a-v][a-z]
+# comment out the next line if you don't need vector files
+!*.svg
+[._]*.sw[a-p]
+[._]s[a-rt-v][a-z]
+[._]ss[a-gi-z]
+[._]sw[a-p]
+
+# Session
+Session.vim
+Sessionx.vim
+
+# Temporary
+.netrwhist
+*~
+# Auto-generated tag files
+tags
+# Persistent undo
+[._]*.un~
+
+# End of https://www.toptal.com/developers/gitignore/api/vim
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000..f288702
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,674 @@
+ GNU GENERAL PUBLIC LICENSE
+ Version 3, 29 June 2007
+
+ Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
+ Everyone is permitted to copy and distribute verbatim copies
+ of this license document, but changing it is not allowed.
+
+ Preamble
+
+ The GNU General Public License is a free, copyleft license for
+software and other kinds of works.
+
+ The licenses for most software and other practical works are designed
+to take away your freedom to share and change the works. By contrast,
+the GNU General Public License is intended to guarantee your freedom to
+share and change all versions of a program--to make sure it remains free
+software for all its users. We, the Free Software Foundation, use the
+GNU General Public License for most of our software; it applies also to
+any other work released this way by its authors. You can apply it to
+your programs, too.
+
+ When we speak of free software, we are referring to freedom, not
+price. Our General Public Licenses are designed to make sure that you
+have the freedom to distribute copies of free software (and charge for
+them if you wish), that you receive source code or can get it if you
+want it, that you can change the software or use pieces of it in new
+free programs, and that you know you can do these things.
+
+ To protect your rights, we need to prevent others from denying you
+these rights or asking you to surrender the rights. Therefore, you have
+certain responsibilities if you distribute copies of the software, or if
+you modify it: responsibilities to respect the freedom of others.
+
+ For example, if you distribute copies of such a program, whether
+gratis or for a fee, you must pass on to the recipients the same
+freedoms that you received. You must make sure that they, too, receive
+or can get the source code. And you must show them these terms so they
+know their rights.
+
+ Developers that use the GNU GPL protect your rights with two steps:
+(1) assert copyright on the software, and (2) offer you this License
+giving you legal permission to copy, distribute and/or modify it.
+
+ For the developers' and authors' protection, the GPL clearly explains
+that there is no warranty for this free software. For both users' and
+authors' sake, the GPL requires that modified versions be marked as
+changed, so that their problems will not be attributed erroneously to
+authors of previous versions.
+
+ Some devices are designed to deny users access to install or run
+modified versions of the software inside them, although the manufacturer
+can do so. This is fundamentally incompatible with the aim of
+protecting users' freedom to change the software. The systematic
+pattern of such abuse occurs in the area of products for individuals to
+use, which is precisely where it is most unacceptable. Therefore, we
+have designed this version of the GPL to prohibit the practice for those
+products. If such problems arise substantially in other domains, we
+stand ready to extend this provision to those domains in future versions
+of the GPL, as needed to protect the freedom of users.
+
+ Finally, every program is threatened constantly by software patents.
+States should not allow patents to restrict development and use of
+software on general-purpose computers, but in those that do, we wish to
+avoid the special danger that patents applied to a free program could
+make it effectively proprietary. To prevent this, the GPL assures that
+patents cannot be used to render the program non-free.
+
+ The precise terms and conditions for copying, distribution and
+modification follow.
+
+ TERMS AND CONDITIONS
+
+ 0. Definitions.
+
+ "This License" refers to version 3 of the GNU General Public License.
+
+ "Copyright" also means copyright-like laws that apply to other kinds of
+works, such as semiconductor masks.
+
+ "The Program" refers to any copyrightable work licensed under this
+License. Each licensee is addressed as "you". "Licensees" and
+"recipients" may be individuals or organizations.
+
+ To "modify" a work means to copy from or adapt all or part of the work
+in a fashion requiring copyright permission, other than the making of an
+exact copy. The resulting work is called a "modified version" of the
+earlier work or a work "based on" the earlier work.
+
+ A "covered work" means either the unmodified Program or a work based
+on the Program.
+
+ To "propagate" a work means to do anything with it that, without
+permission, would make you directly or secondarily liable for
+infringement under applicable copyright law, except executing it on a
+computer or modifying a private copy. Propagation includes copying,
+distribution (with or without modification), making available to the
+public, and in some countries other activities as well.
+
+ To "convey" a work means any kind of propagation that enables other
+parties to make or receive copies. Mere interaction with a user through
+a computer network, with no transfer of a copy, is not conveying.
+
+ An interactive user interface displays "Appropriate Legal Notices"
+to the extent that it includes a convenient and prominently visible
+feature that (1) displays an appropriate copyright notice, and (2)
+tells the user that there is no warranty for the work (except to the
+extent that warranties are provided), that licensees may convey the
+work under this License, and how to view a copy of this License. If
+the interface presents a list of user commands or options, such as a
+menu, a prominent item in the list meets this criterion.
+
+ 1. Source Code.
+
+ The "source code" for a work means the preferred form of the work
+for making modifications to it. "Object code" means any non-source
+form of a work.
+
+ A "Standard Interface" means an interface that either is an official
+standard defined by a recognized standards body, or, in the case of
+interfaces specified for a particular programming language, one that
+is widely used among developers working in that language.
+
+ The "System Libraries" of an executable work include anything, other
+than the work as a whole, that (a) is included in the normal form of
+packaging a Major Component, but which is not part of that Major
+Component, and (b) serves only to enable use of the work with that
+Major Component, or to implement a Standard Interface for which an
+implementation is available to the public in source code form. A
+"Major Component", in this context, means a major essential component
+(kernel, window system, and so on) of the specific operating system
+(if any) on which the executable work runs, or a compiler used to
+produce the work, or an object code interpreter used to run it.
+
+ The "Corresponding Source" for a work in object code form means all
+the source code needed to generate, install, and (for an executable
+work) run the object code and to modify the work, including scripts to
+control those activities. However, it does not include the work's
+System Libraries, or general-purpose tools or generally available free
+programs which are used unmodified in performing those activities but
+which are not part of the work. For example, Corresponding Source
+includes interface definition files associated with source files for
+the work, and the source code for shared libraries and dynamically
+linked subprograms that the work is specifically designed to require,
+such as by intimate data communication or control flow between those
+subprograms and other parts of the work.
+
+ The Corresponding Source need not include anything that users
+can regenerate automatically from other parts of the Corresponding
+Source.
+
+ The Corresponding Source for a work in source code form is that
+same work.
+
+ 2. Basic Permissions.
+
+ All rights granted under this License are granted for the term of
+copyright on the Program, and are irrevocable provided the stated
+conditions are met. This License explicitly affirms your unlimited
+permission to run the unmodified Program. The output from running a
+covered work is covered by this License only if the output, given its
+content, constitutes a covered work. This License acknowledges your
+rights of fair use or other equivalent, as provided by copyright law.
+
+ You may make, run and propagate covered works that you do not
+convey, without conditions so long as your license otherwise remains
+in force. You may convey covered works to others for the sole purpose
+of having them make modifications exclusively for you, or provide you
+with facilities for running those works, provided that you comply with
+the terms of this License in conveying all material for which you do
+not control copyright. Those thus making or running the covered works
+for you must do so exclusively on your behalf, under your direction
+and control, on terms that prohibit them from making any copies of
+your copyrighted material outside their relationship with you.
+
+ Conveying under any other circumstances is permitted solely under
+the conditions stated below. Sublicensing is not allowed; section 10
+makes it unnecessary.
+
+ 3. Protecting Users' Legal Rights From Anti-Circumvention Law.
+
+ No covered work shall be deemed part of an effective technological
+measure under any applicable law fulfilling obligations under article
+11 of the WIPO copyright treaty adopted on 20 December 1996, or
+similar laws prohibiting or restricting circumvention of such
+measures.
+
+ When you convey a covered work, you waive any legal power to forbid
+circumvention of technological measures to the extent such circumvention
+is effected by exercising rights under this License with respect to
+the covered work, and you disclaim any intention to limit operation or
+modification of the work as a means of enforcing, against the work's
+users, your or third parties' legal rights to forbid circumvention of
+technological measures.
+
+ 4. Conveying Verbatim Copies.
+
+ You may convey verbatim copies of the Program's source code as you
+receive it, in any medium, provided that you conspicuously and
+appropriately publish on each copy an appropriate copyright notice;
+keep intact all notices stating that this License and any
+non-permissive terms added in accord with section 7 apply to the code;
+keep intact all notices of the absence of any warranty; and give all
+recipients a copy of this License along with the Program.
+
+ You may charge any price or no price for each copy that you convey,
+and you may offer support or warranty protection for a fee.
+
+ 5. Conveying Modified Source Versions.
+
+ You may convey a work based on the Program, or the modifications to
+produce it from the Program, in the form of source code under the
+terms of section 4, provided that you also meet all of these conditions:
+
+ a) The work must carry prominent notices stating that you modified
+ it, and giving a relevant date.
+
+ b) The work must carry prominent notices stating that it is
+ released under this License and any conditions added under section
+ 7. This requirement modifies the requirement in section 4 to
+ "keep intact all notices".
+
+ c) You must license the entire work, as a whole, under this
+ License to anyone who comes into possession of a copy. This
+ License will therefore apply, along with any applicable section 7
+ additional terms, to the whole of the work, and all its parts,
+ regardless of how they are packaged. This License gives no
+ permission to license the work in any other way, but it does not
+ invalidate such permission if you have separately received it.
+
+ d) If the work has interactive user interfaces, each must display
+ Appropriate Legal Notices; however, if the Program has interactive
+ interfaces that do not display Appropriate Legal Notices, your
+ work need not make them do so.
+
+ A compilation of a covered work with other separate and independent
+works, which are not by their nature extensions of the covered work,
+and which are not combined with it such as to form a larger program,
+in or on a volume of a storage or distribution medium, is called an
+"aggregate" if the compilation and its resulting copyright are not
+used to limit the access or legal rights of the compilation's users
+beyond what the individual works permit. Inclusion of a covered work
+in an aggregate does not cause this License to apply to the other
+parts of the aggregate.
+
+ 6. Conveying Non-Source Forms.
+
+ You may convey a covered work in object code form under the terms
+of sections 4 and 5, provided that you also convey the
+machine-readable Corresponding Source under the terms of this License,
+in one of these ways:
+
+ a) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by the
+ Corresponding Source fixed on a durable physical medium
+ customarily used for software interchange.
+
+ b) Convey the object code in, or embodied in, a physical product
+ (including a physical distribution medium), accompanied by a
+ written offer, valid for at least three years and valid for as
+ long as you offer spare parts or customer support for that product
+ model, to give anyone who possesses the object code either (1) a
+ copy of the Corresponding Source for all the software in the
+ product that is covered by this License, on a durable physical
+ medium customarily used for software interchange, for a price no
+ more than your reasonable cost of physically performing this
+ conveying of source, or (2) access to copy the
+ Corresponding Source from a network server at no charge.
+
+ c) Convey individual copies of the object code with a copy of the
+ written offer to provide the Corresponding Source. This
+ alternative is allowed only occasionally and noncommercially, and
+ only if you received the object code with such an offer, in accord
+ with subsection 6b.
+
+ d) Convey the object code by offering access from a designated
+ place (gratis or for a charge), and offer equivalent access to the
+ Corresponding Source in the same way through the same place at no
+ further charge. You need not require recipients to copy the
+ Corresponding Source along with the object code. If the place to
+ copy the object code is a network server, the Corresponding Source
+ may be on a different server (operated by you or a third party)
+ that supports equivalent copying facilities, provided you maintain
+ clear directions next to the object code saying where to find the
+ Corresponding Source. Regardless of what server hosts the
+ Corresponding Source, you remain obligated to ensure that it is
+ available for as long as needed to satisfy these requirements.
+
+ e) Convey the object code using peer-to-peer transmission, provided
+ you inform other peers where the object code and Corresponding
+ Source of the work are being offered to the general public at no
+ charge under subsection 6d.
+
+ A separable portion of the object code, whose source code is excluded
+from the Corresponding Source as a System Library, need not be
+included in conveying the object code work.
+
+ A "User Product" is either (1) a "consumer product", which means any
+tangible personal property which is normally used for personal, family,
+or household purposes, or (2) anything designed or sold for incorporation
+into a dwelling. In determining whether a product is a consumer product,
+doubtful cases shall be resolved in favor of coverage. For a particular
+product received by a particular user, "normally used" refers to a
+typical or common use of that class of product, regardless of the status
+of the particular user or of the way in which the particular user
+actually uses, or expects or is expected to use, the product. A product
+is a consumer product regardless of whether the product has substantial
+commercial, industrial or non-consumer uses, unless such uses represent
+the only significant mode of use of the product.
+
+ "Installation Information" for a User Product means any methods,
+procedures, authorization keys, or other information required to install
+and execute modified versions of a covered work in that User Product from
+a modified version of its Corresponding Source. The information must
+suffice to ensure that the continued functioning of the modified object
+code is in no case prevented or interfered with solely because
+modification has been made.
+
+ If you convey an object code work under this section in, or with, or
+specifically for use in, a User Product, and the conveying occurs as
+part of a transaction in which the right of possession and use of the
+User Product is transferred to the recipient in perpetuity or for a
+fixed term (regardless of how the transaction is characterized), the
+Corresponding Source conveyed under this section must be accompanied
+by the Installation Information. But this requirement does not apply
+if neither you nor any third party retains the ability to install
+modified object code on the User Product (for example, the work has
+been installed in ROM).
+
+ The requirement to provide Installation Information does not include a
+requirement to continue to provide support service, warranty, or updates
+for a work that has been modified or installed by the recipient, or for
+the User Product in which it has been modified or installed. Access to a
+network may be denied when the modification itself materially and
+adversely affects the operation of the network or violates the rules and
+protocols for communication across the network.
+
+ Corresponding Source conveyed, and Installation Information provided,
+in accord with this section must be in a format that is publicly
+documented (and with an implementation available to the public in
+source code form), and must require no special password or key for
+unpacking, reading or copying.
+
+ 7. Additional Terms.
+
+ "Additional permissions" are terms that supplement the terms of this
+License by making exceptions from one or more of its conditions.
+Additional permissions that are applicable to the entire Program shall
+be treated as though they were included in this License, to the extent
+that they are valid under applicable law. If additional permissions
+apply only to part of the Program, that part may be used separately
+under those permissions, but the entire Program remains governed by
+this License without regard to the additional permissions.
+
+ When you convey a copy of a covered work, you may at your option
+remove any additional permissions from that copy, or from any part of
+it. (Additional permissions may be written to require their own
+removal in certain cases when you modify the work.) You may place
+additional permissions on material, added by you to a covered work,
+for which you have or can give appropriate copyright permission.
+
+ Notwithstanding any other provision of this License, for material you
+add to a covered work, you may (if authorized by the copyright holders of
+that material) supplement the terms of this License with terms:
+
+ a) Disclaiming warranty or limiting liability differently from the
+ terms of sections 15 and 16 of this License; or
+
+ b) Requiring preservation of specified reasonable legal notices or
+ author attributions in that material or in the Appropriate Legal
+ Notices displayed by works containing it; or
+
+ c) Prohibiting misrepresentation of the origin of that material, or
+ requiring that modified versions of such material be marked in
+ reasonable ways as different from the original version; or
+
+ d) Limiting the use for publicity purposes of names of licensors or
+ authors of the material; or
+
+ e) Declining to grant rights under trademark law for use of some
+ trade names, trademarks, or service marks; or
+
+ f) Requiring indemnification of licensors and authors of that
+ material by anyone who conveys the material (or modified versions of
+ it) with contractual assumptions of liability to the recipient, for
+ any liability that these contractual assumptions directly impose on
+ those licensors and authors.
+
+ All other non-permissive additional terms are considered "further
+restrictions" within the meaning of section 10. If the Program as you
+received it, or any part of it, contains a notice stating that it is
+governed by this License along with a term that is a further
+restriction, you may remove that term. If a license document contains
+a further restriction but permits relicensing or conveying under this
+License, you may add to a covered work material governed by the terms
+of that license document, provided that the further restriction does
+not survive such relicensing or conveying.
+
+ If you add terms to a covered work in accord with this section, you
+must place, in the relevant source files, a statement of the
+additional terms that apply to those files, or a notice indicating
+where to find the applicable terms.
+
+ Additional terms, permissive or non-permissive, may be stated in the
+form of a separately written license, or stated as exceptions;
+the above requirements apply either way.
+
+ 8. Termination.
+
+ You may not propagate or modify a covered work except as expressly
+provided under this License. Any attempt otherwise to propagate or
+modify it is void, and will automatically terminate your rights under
+this License (including any patent licenses granted under the third
+paragraph of section 11).
+
+ However, if you cease all violation of this License, then your
+license from a particular copyright holder is reinstated (a)
+provisionally, unless and until the copyright holder explicitly and
+finally terminates your license, and (b) permanently, if the copyright
+holder fails to notify you of the violation by some reasonable means
+prior to 60 days after the cessation.
+
+ Moreover, your license from a particular copyright holder is
+reinstated permanently if the copyright holder notifies you of the
+violation by some reasonable means, this is the first time you have
+received notice of violation of this License (for any work) from that
+copyright holder, and you cure the violation prior to 30 days after
+your receipt of the notice.
+
+ Termination of your rights under this section does not terminate the
+licenses of parties who have received copies or rights from you under
+this License. If your rights have been terminated and not permanently
+reinstated, you do not qualify to receive new licenses for the same
+material under section 10.
+
+ 9. Acceptance Not Required for Having Copies.
+
+ You are not required to accept this License in order to receive or
+run a copy of the Program. Ancillary propagation of a covered work
+occurring solely as a consequence of using peer-to-peer transmission
+to receive a copy likewise does not require acceptance. However,
+nothing other than this License grants you permission to propagate or
+modify any covered work. These actions infringe copyright if you do
+not accept this License. Therefore, by modifying or propagating a
+covered work, you indicate your acceptance of this License to do so.
+
+ 10. Automatic Licensing of Downstream Recipients.
+
+ Each time you convey a covered work, the recipient automatically
+receives a license from the original licensors, to run, modify and
+propagate that work, subject to this License. You are not responsible
+for enforcing compliance by third parties with this License.
+
+ An "entity transaction" is a transaction transferring control of an
+organization, or substantially all assets of one, or subdividing an
+organization, or merging organizations. If propagation of a covered
+work results from an entity transaction, each party to that
+transaction who receives a copy of the work also receives whatever
+licenses to the work the party's predecessor in interest had or could
+give under the previous paragraph, plus a right to possession of the
+Corresponding Source of the work from the predecessor in interest, if
+the predecessor has it or can get it with reasonable efforts.
+
+ You may not impose any further restrictions on the exercise of the
+rights granted or affirmed under this License. For example, you may
+not impose a license fee, royalty, or other charge for exercise of
+rights granted under this License, and you may not initiate litigation
+(including a cross-claim or counterclaim in a lawsuit) alleging that
+any patent claim is infringed by making, using, selling, offering for
+sale, or importing the Program or any portion of it.
+
+ 11. Patents.
+
+ A "contributor" is a copyright holder who authorizes use under this
+License of the Program or a work on which the Program is based. The
+work thus licensed is called the contributor's "contributor version".
+
+ A contributor's "essential patent claims" are all patent claims
+owned or controlled by the contributor, whether already acquired or
+hereafter acquired, that would be infringed by some manner, permitted
+by this License, of making, using, or selling its contributor version,
+but do not include claims that would be infringed only as a
+consequence of further modification of the contributor version. For
+purposes of this definition, "control" includes the right to grant
+patent sublicenses in a manner consistent with the requirements of
+this License.
+
+ Each contributor grants you a non-exclusive, worldwide, royalty-free
+patent license under the contributor's essential patent claims, to
+make, use, sell, offer for sale, import and otherwise run, modify and
+propagate the contents of its contributor version.
+
+ In the following three paragraphs, a "patent license" is any express
+agreement or commitment, however denominated, not to enforce a patent
+(such as an express permission to practice a patent or covenant not to
+sue for patent infringement). To "grant" such a patent license to a
+party means to make such an agreement or commitment not to enforce a
+patent against the party.
+
+ If you convey a covered work, knowingly relying on a patent license,
+and the Corresponding Source of the work is not available for anyone
+to copy, free of charge and under the terms of this License, through a
+publicly available network server or other readily accessible means,
+then you must either (1) cause the Corresponding Source to be so
+available, or (2) arrange to deprive yourself of the benefit of the
+patent license for this particular work, or (3) arrange, in a manner
+consistent with the requirements of this License, to extend the patent
+license to downstream recipients. "Knowingly relying" means you have
+actual knowledge that, but for the patent license, your conveying the
+covered work in a country, or your recipient's use of the covered work
+in a country, would infringe one or more identifiable patents in that
+country that you have reason to believe are valid.
+
+ If, pursuant to or in connection with a single transaction or
+arrangement, you convey, or propagate by procuring conveyance of, a
+covered work, and grant a patent license to some of the parties
+receiving the covered work authorizing them to use, propagate, modify
+or convey a specific copy of the covered work, then the patent license
+you grant is automatically extended to all recipients of the covered
+work and works based on it.
+
+ A patent license is "discriminatory" if it does not include within
+the scope of its coverage, prohibits the exercise of, or is
+conditioned on the non-exercise of one or more of the rights that are
+specifically granted under this License. You may not convey a covered
+work if you are a party to an arrangement with a third party that is
+in the business of distributing software, under which you make payment
+to the third party based on the extent of your activity of conveying
+the work, and under which the third party grants, to any of the
+parties who would receive the covered work from you, a discriminatory
+patent license (a) in connection with copies of the covered work
+conveyed by you (or copies made from those copies), or (b) primarily
+for and in connection with specific products or compilations that
+contain the covered work, unless you entered into that arrangement,
+or that patent license was granted, prior to 28 March 2007.
+
+ Nothing in this License shall be construed as excluding or limiting
+any implied license or other defenses to infringement that may
+otherwise be available to you under applicable patent law.
+
+ 12. No Surrender of Others' Freedom.
+
+ If conditions are imposed on you (whether by court order, agreement or
+otherwise) that contradict the conditions of this License, they do not
+excuse you from the conditions of this License. If you cannot convey a
+covered work so as to satisfy simultaneously your obligations under this
+License and any other pertinent obligations, then as a consequence you may
+not convey it at all. For example, if you agree to terms that obligate you
+to collect a royalty for further conveying from those to whom you convey
+the Program, the only way you could satisfy both those terms and this
+License would be to refrain entirely from conveying the Program.
+
+ 13. Use with the GNU Affero General Public License.
+
+ Notwithstanding any other provision of this License, you have
+permission to link or combine any covered work with a work licensed
+under version 3 of the GNU Affero General Public License into a single
+combined work, and to convey the resulting work. The terms of this
+License will continue to apply to the part which is the covered work,
+but the special requirements of the GNU Affero General Public License,
+section 13, concerning interaction through a network will apply to the
+combination as such.
+
+ 14. Revised Versions of this License.
+
+ The Free Software Foundation may publish revised and/or new versions of
+the GNU General Public License from time to time. Such new versions will
+be similar in spirit to the present version, but may differ in detail to
+address new problems or concerns.
+
+ Each version is given a distinguishing version number. If the
+Program specifies that a certain numbered version of the GNU General
+Public License "or any later version" applies to it, you have the
+option of following the terms and conditions either of that numbered
+version or of any later version published by the Free Software
+Foundation. If the Program does not specify a version number of the
+GNU General Public License, you may choose any version ever published
+by the Free Software Foundation.
+
+ If the Program specifies that a proxy can decide which future
+versions of the GNU General Public License can be used, that proxy's
+public statement of acceptance of a version permanently authorizes you
+to choose that version for the Program.
+
+ Later license versions may give you additional or different
+permissions. However, no additional obligations are imposed on any
+author or copyright holder as a result of your choosing to follow a
+later version.
+
+ 15. Disclaimer of Warranty.
+
+ THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
+APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
+HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
+OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
+THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
+IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
+ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
+
+ 16. Limitation of Liability.
+
+ IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
+WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
+THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
+GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
+USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
+DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
+PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
+EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
+SUCH DAMAGES.
+
+ 17. Interpretation of Sections 15 and 16.
+
+ If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+ END OF TERMS AND CONDITIONS
+
+ How to Apply These Terms to Your New Programs
+
+ If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+ To do so, attach the following notices to the program. It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+
+ <one line to give the program's name and a brief idea of what it does.>
+ Copyright (C) <year> <name of author>
+
+ This program is free software: you can redistribute it and/or modify
+ it under the terms of the GNU General Public License as published by
+ the Free Software Foundation, either version 3 of the License, or
+ (at your option) any later version.
+
+ This program is distributed in the hope that it will be useful,
+ but WITHOUT ANY WARRANTY; without even the implied warranty of
+ MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ GNU General Public License for more details.
+
+ You should have received a copy of the GNU General Public License
+ along with this program. If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+ If the program does terminal interaction, make it output a short
+notice like this when it starts in an interactive mode:
+
+ <program> Copyright (C) <year> <name of author>
+ This program comes with ABSOLUTELY NO WARRANTY; for details type `show w'.
+ This is free software, and you are welcome to redistribute it
+ under certain conditions; type `show c' for details.
+
+The hypothetical commands `show w' and `show c' should show the appropriate
+parts of the General Public License. Of course, your program's commands
+might be different; for a GUI interface, you would use an "about box".
+
+ You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU GPL, see
+<https://www.gnu.org/licenses/>.
+
+ The GNU General Public License does not permit incorporating your program
+into proprietary programs. If your program is a subroutine library, you
+may consider it more useful to permit linking proprietary applications with
+the library. If this is what you want to do, use the GNU Lesser General
+Public License instead of this License. But first, please read
+<https://www.gnu.org/licenses/why-not-lgpl.html>.
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 0000000..b192afd
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,9 @@
+prune outputs
+prune lightning_logs
+prune logs
+
+global-exclude .ipynb_checkpoints
+global-exclude .git*
+global-exclude *.ckpt
+global-exclude *.pt[h]
+global-exclude *.py[oc]
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..f8283b7
--- /dev/null
+++ b/README.md
@@ -0,0 +1,116 @@
+# Don't PANIC: Prototypical Additive Neural Network for Interpretable Classification of Alzheimer's Disease
+
+[Preprint (arXiv:2303.07125)](https://arxiv.org/abs/2303.07125)
+[License: GPL-3.0](LICENSE)
+
+This repository contains the code for the paper "Don't PANIC: Prototypical Additive Neural Network for Interpretable Classification of Alzheimer's Disease".
+
+```
+@misc{https://doi.org/10.48550/arxiv.2303.07125,
+ doi = {10.48550/ARXIV.2303.07125},
+ url = {https://arxiv.org/abs/2303.07125},
+ author = {Wolf, Tom Nuno and Pölsterl, Sebastian and Wachinger, Christian},
+ keywords = {Machine Learning (cs.LG), Computer Vision and Pattern Recognition (cs.CV), FOS: Computer and information sciences, FOS: Computer and information sciences},
+ title = {Don't PANIC: Prototypical Additive Neural Network for Interpretable Classification of Alzheimer's Disease},
+ publisher = {arXiv},
+ year = {2023},
+ copyright = {arXiv.org perpetual, non-exclusive license}
+}
+```
+
+If you are using this code, please cite the paper above.
+
+
+## Installation
+
+Use [conda](https://conda.io/miniconda.html) to create an environment called `panic` with all dependencies:
+
+```bash
+conda env create -n panic --file requirements.yaml
+```
+
+Additionally, install the `torchpanic` package from this repository with:
+```bash
+pip install --no-deps -e .
+```
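+
+To quickly verify the installation, you can run an import check (assuming the `panic` environment is activated):
+
+```bash
+python -c "import torchpanic"
+```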
+
+## Data
+
+We used data from the [Alzheimer's Disease Neuroimaging Initiative (ADNI)](http://adni.loni.usc.edu/).
+Since we are not allowed to share our data, you would need to process the data yourself.
+Data for training, validation, and testing should be stored in separate
+[HDF5](https://en.wikipedia.org/wiki/Hierarchical_Data_Format) files,
+using the following hierarchical format:
+
+1. First level: A unique identifier.
+2. The second level always has the following entries:
+ 1. A group named `PET` with the subgroup `FDG`, which itself has the
+ [dataset](https://docs.h5py.org/en/stable/high/dataset.html) named `data` as a child:
+ The graymatter density map of size (113, 137, 113). Additionally, the subgroup `FDG` has an attribute `imageuid`, which is the unique image identifier.
+ 2. A group named `tabular`, which has two datasets called `data` and `missing`, each of size 41:
+ `data` contains the tabular feature values, while `missing` is a binary indicator that marks whether a tabular feature was not acquired at this visit.
+ 3. A scalar [attribute](https://docs.h5py.org/en/stable/high/attr.html) `RID` with the *patient* ID.
+ 4. A string attribute `VISCODE` with ADNI's visit code.
+ 5. A string attribute `DX` containing the diagnosis (`CN`, `MCI` or `Dementia`).
+
+One entry in the resulting HDF5 file should have the following structure:
+```
+/1010012 Group
+ Attribute: RID scalar
+ Type: native long
+ Data: 1234
+ Attribute: VISCODE scalar
+ Type: variable-length null-terminated UTF-8 string
+ Data: "bl"
+ Attribute: DX scalar
+ Type: variable-length null-terminated UTF-8 string
+ Data: "CN"
+/1010012/PET Group
+/1010012/PET/FDG Group
+ Attribute imageuid scalar
+ Type: variable-length null-terminated UTF-8 string
+ Data: "12345"
+/1010012/PET/FDG/data Dataset {113, 137, 113}
+/1010012/tabular Group
+/1010012/tabular/data Dataset {41}
+/1010012/tabular/missing Dataset {41}
+```
+
+Finally, the HDF5 file should also contain the following meta-information
+in a separate group named `stats`:
+
+```
+/stats/tabular Group
+/stats/tabular/columns Dataset {41}
+/stats/tabular/mean Dataset {41}
+/stats/tabular/stddev Dataset {41}
+```
+
+These are the names of the tabular features,
+their means, and their standard deviations (computed on the training data).
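+
+The following is a minimal, hypothetical sketch of how such a file could be created with `h5py`; the identifier, array values, and feature names are placeholders that you would replace with your own processed ADNI data:
+
+```python
+import h5py
+import numpy as np
+
+n_tab = 41  # number of tabular features
+
+with h5py.File("0-train.h5", "w") as f:
+    # one entry, keyed by a unique identifier
+    grp = f.create_group("1010012")
+    grp.attrs["RID"] = 1234        # patient ID
+    grp.attrs["VISCODE"] = "bl"    # ADNI visit code
+    grp.attrs["DX"] = "CN"         # diagnosis: CN, MCI or Dementia
+
+    fdg = grp.create_group("PET/FDG")
+    fdg.attrs["imageuid"] = "12345"
+    fdg.create_dataset("data", data=np.zeros((113, 137, 113), dtype=np.float32))
+
+    tab = grp.create_group("tabular")
+    tab.create_dataset("data", data=np.zeros(n_tab, dtype=np.float32))
+    tab.create_dataset("missing", data=np.zeros(n_tab, dtype=np.uint8))
+
+    # meta-information shared by all entries
+    stats = f.create_group("stats/tabular")
+    str_dt = h5py.string_dtype(encoding="utf-8")
+    stats.create_dataset("columns", data=[f"col_{i}" for i in range(n_tab)], dtype=str_dt)
+    stats.create_dataset("mean", data=np.zeros(n_tab, dtype=np.float32))
+    stats.create_dataset("stddev", data=np.ones(n_tab, dtype=np.float32))
+```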
+
+## Usage
+
+PANIC processes each tabular feature according to its data type.
+Therefore, you need to tell PANIC how to process each tabular feature:
+the following index lists must be given to the model in the config file `configs/model/panic.yaml`
+(see the excerpt below):
+
+- `idx_real_features`: indices of real-valued features within the `tabular` data.
+- `idx_cat_features`: indices of categorical features within the `tabular` data.
+- `idx_real_has_missing`: indices of real-valued features for which the `missing` indicator should be considered.
+- `idx_cat_has_missing`: indices of categorical features for which the `missing` indicator should be considered.
+
+Similarly, missing tabular inputs to DAFT (`configs/model/daft.yaml`) need to be specified with `idx_tabular_has_missing`.
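+
+For reference, the real-valued feature lists in `configs/model/panic.yaml`, as shipped with this repository, look as follows (the categorical lists are analogous, just longer):
+
+```yaml
+net:
+  nam:
+    idx_real_features: [1, 2, 3, 4, 5, 6, 7, 8, 9]  # age, edu, abeta, tau, ptau, L-Hipp, R-Hipp, L-Ento, R-Ento
+    idx_real_has_missing: [2, 3, 4]
+```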
+
+## Training
+
+To train PANIC, or any of the baseline models, adapt the config files (mainly `configs/train.yaml`) and execute the `train.py` script to begin training.
+
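+For example, assuming Hydra's standard override syntax, PANIC can be trained on fold 0 as follows (the data path is a placeholder for the folder containing your HDF5 files):
+
+```bash
+python train.py data_dir=/path/to/hdf5-files fold=0 seed=666
+
+# baselines are available as experiment configs, e.g. DAFT:
+python train.py +experiment=daft_tuned data_dir=/path/to/hdf5-files fold=0
+```
+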
+Model checkpoints will be written to the `outputs` folder by default.
+
+
+## Interpretation of results
+
+We provide utility functions for creating the plots and visualizations required to interpret the model.
+You can find them under `torchpanic/viz`.
+
diff --git a/configs/callbacks/default.yaml b/configs/callbacks/default.yaml
new file mode 100644
index 0000000..13f5e67
--- /dev/null
+++ b/configs/callbacks/default.yaml
@@ -0,0 +1,12 @@
+model_checkpoint:
+ _target_: pytorch_lightning.callbacks.ModelCheckpoint
+ monitor: ${model.validation_metric} # name of the logged metric which determines when model is improving
+ mode: max # "max" means higher metric value is better, can be also "min"
+ save_top_k: 1 # save k best models (determined by above metric)
+ dirpath: checkpoints/
+ filename: "{epoch}-bacc"
+
+learning_rate_monitor:
+ _target_: pytorch_lightning.callbacks.LearningRateMonitor
+ logging_interval: epoch
+ log_momentum: False
diff --git a/configs/datamodule/adni.yaml b/configs/datamodule/adni.yaml
new file mode 100644
index 0000000..ac4316a
--- /dev/null
+++ b/configs/datamodule/adni.yaml
@@ -0,0 +1,16 @@
+_target_: torchpanic.datamodule.adni.AdniDataModule
+
+train_data: ${data_dir}/${fold}-train.h5 # data_dir is specified in config.yaml
+valid_data: ${data_dir}/${fold}-valid.h5
+test_data: ${data_dir}/${fold}-test.h5
+modalities: ["PET", "TABULAR"] # 2 = ModalityType.PET; 6 = ModalityType.TABULAR|PET; 7 = ModalityType.TABULAR|PET|MRI
+batch_size: 32
+num_workers: 10
+metadata:
+ num_channels: 1
+ num_classes: 3
+augmentation:
+ rotate: 30
+ translate: 0
+ scale: 0.2
+ p: 0.5
diff --git a/configs/experiment/daft_tuned.yaml b/configs/experiment/daft_tuned.yaml
new file mode 100644
index 0000000..0a62a6a
--- /dev/null
+++ b/configs/experiment/daft_tuned.yaml
@@ -0,0 +1,8 @@
+# @package _global_
+
+defaults:
+ - override /datamodule: adni
+ - override /model: daft
+
+name: adni-daft
+
diff --git a/configs/hydra/default.yaml b/configs/hydra/default.yaml
new file mode 100644
index 0000000..e69de29
diff --git a/configs/hydra/panic.yaml b/configs/hydra/panic.yaml
new file mode 100644
index 0000000..e26a287
--- /dev/null
+++ b/configs/hydra/panic.yaml
@@ -0,0 +1,7 @@
+run:
+ dir: outputs/${name}/seed-${seed}/fold-${fold}
+job:
+ chdir: True
+sweep:
+ dir: outputs
+ subdir: ${name}/seed-${seed}/fold-${fold}
diff --git a/configs/logger/tensorboard.yaml b/configs/logger/tensorboard.yaml
new file mode 100644
index 0000000..607142d
--- /dev/null
+++ b/configs/logger/tensorboard.yaml
@@ -0,0 +1,8 @@
+tensorboard:
+ _target_: pytorch_lightning.loggers.tensorboard.TensorBoardLogger
+ save_dir: logs/
+ name: null
+ version: ${name}
+ log_graph: False
+ default_hp_metric: True
+ prefix: ""
diff --git a/configs/model/daft.yaml b/configs/model/daft.yaml
new file mode 100644
index 0000000..0b07f54
--- /dev/null
+++ b/configs/model/daft.yaml
@@ -0,0 +1,53 @@
+_target_: torchpanic.modules.standard.StandardModule
+lr: 0.003607382027438678
+weight_decay: 1.1061692016959738e-05
+output_penalty_weight: 0.0
+num_classes: ${datamodule.metadata.num_classes}
+validation_metric: 'val/bacc_best'
+
+net:
+ _target_: torchpanic.models.daft.DAFT
+ in_channels: ${datamodule.metadata.num_channels}
+ in_tabular: 41
+ n_outputs: ${datamodule.metadata.num_classes}
+ n_basefilters: 32
+ filmblock_args:
+ location: 3
+ scale: True
+ shift: True
+ bottleneck_dim: 12
+ idx_tabular_has_missing:
+ - 3 # abeta
+ - 4 # tau
+ - 5 # ptau
+ - 10 # all categorical, except gender
+ - 11
+ - 12
+ - 13
+ - 14
+ - 15
+ - 16
+ - 17
+ - 18
+ - 19
+ - 20
+ - 21
+ - 22
+ - 23
+ - 24
+ - 25
+ - 26
+ - 27
+ - 28
+ - 29
+ - 30
+ - 31
+ - 32
+ - 33
+ - 34
+ - 35
+ - 36
+ - 37
+ - 38
+ - 39
+ - 40
diff --git a/configs/model/panic.yaml b/configs/model/panic.yaml
new file mode 100644
index 0000000..475fc20
--- /dev/null
+++ b/configs/model/panic.yaml
@@ -0,0 +1,103 @@
+_target_: torchpanic.modules.panic.PANIC
+
+lr: 0.0067
+weight_decay: 0.0001
+weight_decay_nam: 0.0001
+l_clst: 0.5 # lambda multiplying the intra-cluster loss
+l_sep: ${model.l_clst} # lambda multiplying the inter-cluster (separation) loss
+l_occ: 0.5 # lambda of occurrence loss
+l_affine: 0.5
+l_nam: 0.0001 # l2 regularization of coefficients of NAM
+epochs_all: 20 # push prototypes every x epochs
+epochs_nam: 10
+epochs_warmup: 10
+enable_checkpointing: ${trainer.enable_checkpointing}
+monitor_prototypes: False
+enable_save_embeddings: False
+enable_log_prototypes: False
+validation_metric: 'val/bacc_save'
+
+net:
+ _target_: torchpanic.models.panic.PANIC
+ protonet:
+ backbone: "3dresnet"
+ in_channels: ${datamodule.metadata.num_channels}
+ out_features: ${datamodule.metadata.num_classes}
+ n_prototypes_per_class: 2 # just used for init
+ n_chans_protos: 64
+ optim_features: True
+ normed_prototypes: True
+ n_blocks: 3
+ n_basefilters: 32
+ pretrained_model: ${hydra:runtime.cwd}/outputs/pretrained_encoders/seed-666/fold-${fold}/checkpoints/best.ckpt
+ nam:
+ out_features: 3
+ hidden_units: [32, 32]
+ dropout_rate: 0.5
+ feature_dropout_rate: 0.1
+ idx_real_features: [1, 2, 3, 4, 5, 6, 7, 8, 9] # age, edu, abeta, tau, ptau, L-Hipp, R-Hipp, L-Ento, R-Ento
+ idx_cat_features:
+ - 0
+ - 10
+ - 11
+ - 12
+ - 13
+ - 14
+ - 15
+ - 16
+ - 17
+ - 18
+ - 19
+ - 20
+ - 21
+ - 22
+ - 23
+ - 24
+ - 25
+ - 26
+ - 27
+ - 28
+ - 29
+ - 30
+ - 31
+ - 32
+ - 33
+ - 34
+ - 35
+ - 36
+ - 37
+ - 38
+ - 39
+ - 40
+ idx_real_has_missing: [2, 3, 4]
+ idx_cat_has_missing:
+ - 1
+ - 2
+ - 3
+ - 4
+ - 5
+ - 6
+ - 7
+ - 8
+ - 9
+ - 10
+ - 11
+ - 12
+ - 13
+ - 14
+ - 15
+ - 16
+ - 17
+ - 18
+ - 19
+ - 20
+ - 21
+ - 22
+ - 23
+ - 24
+ - 25
+ - 26
+ - 27
+ - 28
+ - 29
+ - 30
diff --git a/configs/model/pretrain_encoder.yaml b/configs/model/pretrain_encoder.yaml
new file mode 100644
index 0000000..07b09b6
--- /dev/null
+++ b/configs/model/pretrain_encoder.yaml
@@ -0,0 +1,18 @@
+_target_: torchpanic.modules.standard.StandardModule
+
+lr: 0.003
+weight_decay: 0.001
+validation_metric: 'val/bacc_best'
+
+net:
+ _target_: torchpanic.models.pretrain_encoder.Encoder
+ protonet:
+ backbone: "3dresnet"
+ in_channels: ${datamodule.metadata.num_channels}
+ out_features: ${datamodule.metadata.num_classes}
+ n_prototypes_per_class: 3 # just used for init
+ n_chans_protos: 64
+ optim_features: True
+ n_blocks: 4
+ n_basefilters: 32
+ normed_prototypes: True
diff --git a/configs/test.yaml b/configs/test.yaml
new file mode 100644
index 0000000..a8b8e5e
--- /dev/null
+++ b/configs/test.yaml
@@ -0,0 +1,21 @@
+# @package _global_
+
+# specify here default evaluation configuration
+defaults:
+ - _self_
+ - datamodule: adni.yaml
+ - model: panic.yaml
+ - logger: tensorboard.yaml
+ - trainer: default.yaml
+
+# path to folder with data
+data_dir: ???
+
+# seed for random number generators in pytorch, numpy and python.random
+seed: 666
+
+fold: 0
+
+name: "protopnet_test"
+
+ckpt_path: ???
diff --git a/configs/train.yaml b/configs/train.yaml
new file mode 100644
index 0000000..4c0f407
--- /dev/null
+++ b/configs/train.yaml
@@ -0,0 +1,23 @@
+# @package _global_
+
+# specify here default training configuration
+defaults:
+ - _self_
+ - hydra: panic.yaml
+ - callbacks: default.yaml
+ - datamodule: adni.yaml
+ - model: panic.yaml
+ - logger: tensorboard.yaml
+ - trainer: default.yaml
+
+# path to folder with data
+data_dir: ???
+
+# fold to run this experiment for
+fold: 0
+
+# seed for random number generators in pytorch, numpy and python.random
+seed: 666
+
+# default name for the experiment, determines logging folder path
+name: "train_panic"
diff --git a/configs/trainer/default.yaml b/configs/trainer/default.yaml
new file mode 100644
index 0000000..8824d53
--- /dev/null
+++ b/configs/trainer/default.yaml
@@ -0,0 +1,20 @@
+_target_: pytorch_lightning.Trainer
+
+accelerator: gpu
+devices: 1
+min_epochs: 1
+max_epochs: 100
+
+enable_progress_bar: True
+
+detect_anomaly: True
+log_every_n_steps: 10
+track_grad_norm: -1 # -1: disabled; 2: track 2-norm
+
+# number of validation steps to execute at the beginning of the training
+# num_sanity_val_steps: 0
+
+# ckpt path
+resume_from_checkpoint: null
+
+enable_checkpointing: True
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..80e7c8c
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,5 @@
+[build-system]
+requires = [
+ "setuptools>=45",
+]
+build-backend = "setuptools.build_meta"
diff --git a/requirements.yaml b/requirements.yaml
new file mode 100644
index 0000000..98c6ed8
--- /dev/null
+++ b/requirements.yaml
@@ -0,0 +1,92 @@
+name: panic
+channels:
+ - defaults
+dependencies:
+ - _libgcc_mutex=0.1=main
+ - _openmp_mutex=5.1=1_gnu
+ - ca-certificates=2022.10.11=h06a4308_0
+ - certifi=2022.9.24=py39h06a4308_0
+ - ld_impl_linux-64=2.38=h1181459_1
+ - libffi=3.3=he6710b0_2
+ - libgcc-ng=11.2.0=h1234567_1
+ - libgomp=11.2.0=h1234567_1
+ - libstdcxx-ng=11.2.0=h1234567_1
+ - ncurses=6.3=h5eee18b_3
+ - openssl=1.1.1s=h7f8727e_0
+ - pip=22.3.1=py39h06a4308_0
+ - python=3.9.12=h12debd9_1
+ - readline=8.2=h5eee18b_0
+ - setuptools=65.5.0=py39h06a4308_0
+ - sqlite=3.40.0=h5082296_0
+ - tk=8.6.12=h1ccaba5_0
+ - tzdata=2022g=h04d1e81_0
+ - wheel=0.37.1=pyhd3eb1b0_0
+ - xz=5.2.8=h5eee18b_0
+ - zlib=1.2.13=h5eee18b_0
+ - pip:
+ - --extra-index-url https://download.pytorch.org/whl/cu113
+ - aiohttp==3.8.3
+ - aiosignal==1.3.1
+ - antlr4-python3-runtime==4.9.3
+ - asttokens==2.2.1
+ - async-timeout==4.0.2
+ - attrs==22.1.0
+ - backcall==0.2.0
+ - charset-normalizer==2.1.1
+ - click==8.1.3
+ - colorama==0.4.6
+ - commonmark==0.9.1
+ - decorator==5.1.1
+ - deprecated==1.2.13
+ - executing==1.2.0
+ - frozenlist==1.3.3
+ - fsspec==2022.11.0
+ - h5py==3.7.0
+ - humanize==4.4.0
+ - hydra-core==1.3.0
+ - idna==3.4
+ - ipdb==0.13.11
+ - ipython==8.7.0
+ - jedi==0.18.2
+ - lightning-utilities==0.4.2
+ - matplotlib-inline==0.1.6
+ - multidict==6.0.3
+ - nibabel==4.0.2
+ - numpy==1.23.5
+ - omegaconf==2.3.0
+ - packaging==22.0
+ - pandas==1.5.2
+ - parso==0.8.3
+ - pexpect==4.8.0
+ - pickleshare==0.7.5
+ - pillow==9.3.0
+ - prompt-toolkit==3.0.36
+ - protobuf==3.20.1
+ - ptyprocess==0.7.0
+ - pure-eval==0.2.2
+ - pygments==2.13.0
+ - python-dateutil==2.8.2
+ - pytorch-lightning==1.8.4.post0
+ - pytz==2022.6
+ - pyyaml==6.0
+ - requests==2.28.1
+ - rich==12.6.0
+ - scipy==1.9.3
+ - shellingham==1.5.0
+ - simpleitk==2.2.1
+ - six==1.16.0
+ - stack-data==0.6.2
+ - tensorboardx==2.5.1
+ - tomli==2.0.1
+ - torch==1.12.1+cu113
+ - torchio==0.18.86
+ - torchmetrics==0.11.0
+ - torchvision==0.13.1+cu113
+ - tqdm==4.64.1
+ - traitlets==5.7.1
+ - typer==0.7.0
+ - typing-extensions==4.4.0
+ - urllib3==1.26.13
+ - wcwidth==0.2.5
+ - wrapt==1.14.1
+ - yarl==1.8.2
diff --git a/setup.cfg b/setup.cfg
new file mode 100644
index 0000000..c98acd9
--- /dev/null
+++ b/setup.cfg
@@ -0,0 +1,22 @@
+[metadata]
+name = torchpanic
+version = 0.0.1
+description = PANIC for AD diagnosis with XAI
+classifiers =
+ Programming Language :: Python :: 3
+
+[options]
+python_requires = >=3.9
+zip_safe = False
+packages=find:
+install_requires =
+ h5py
+ hydra-core
+ numpy
+ omegaconf
+ pandas
+ pytorch_lightning
+ tensorboard
+ torch
+ torchio
+ torchmetrics
diff --git a/setup.py b/setup.py
new file mode 100644
index 0000000..6068493
--- /dev/null
+++ b/setup.py
@@ -0,0 +1,3 @@
+from setuptools import setup
+
+setup()
diff --git a/test.py b/test.py
new file mode 100644
index 0000000..400403e
--- /dev/null
+++ b/test.py
@@ -0,0 +1,30 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+import logging
+
+import hydra
+from omegaconf import DictConfig
+
+from torchpanic.testing import test
+
+
+@hydra.main(config_path="configs/", config_name="test.yaml")
+def main(config: DictConfig):
+ return test(config)
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.INFO)
+ main()
diff --git a/torchpanic/__init__.py b/torchpanic/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/torchpanic/datamodule/__init__.py b/torchpanic/datamodule/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/torchpanic/datamodule/adni.py b/torchpanic/datamodule/adni.py
new file mode 100644
index 0000000..1d599ad
--- /dev/null
+++ b/torchpanic/datamodule/adni.py
@@ -0,0 +1,297 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+import collections.abc
+import copy
+import logging
+from operator import itemgetter
+from typing import Any, Dict, Optional, Sequence, Tuple, Union
+
+import h5py
+import numpy as np
+import pandas as pd
+import pytorch_lightning as pl
+from torch.utils.data import DataLoader, Dataset
+from torch.utils.data.dataloader import default_collate
+import torchio as tio
+
+from .modalities import AugmentationType, DataPointType, ModalityType
+
+LOG = logging.getLogger(__name__)
+
+DIAGNOSIS_MAP = {"CN": 0, "MCI": 1, "Dementia": 2}
+DIAGNOSIS_MAP_BINARY = {"CN": 0, "Dementia": 1}
+
+
+def Identity(x):
+ return x
+
+
+def get_image_transform(p=0.0, rotate=0, translate=0, scale=0):
+ # Image sizes PET & MRI Dataset {113, 137, 113}
+ img_transforms = []
+
+ randomAffineWithRot = tio.RandomAffine(
+ scales=scale,
+ degrees=rotate,
+ translation=translate,
+ image_interpolation="linear",
+ default_pad_value="otsu",
+ p=p, # no transform if validation
+ )
+ img_transforms.append(randomAffineWithRot)
+
+ img_transform = tio.Compose(img_transforms)
+ return img_transform
+
+
+class AdniDataset(Dataset):
+ def __init__(
+ self,
+ path: str,
+ modalities: Union[ModalityType, int, Sequence[str]],
+ augmentation: Dict[str, Any],
+ ) -> None:
+ self.path = path
+ if isinstance(modalities, collections.abc.Sequence):
+ mod_type = ModalityType(0)
+ for mod in modalities:
+ mod_type |= getattr(ModalityType, mod)
+ self.modalities = mod_type
+ else:
+ self.modalities = ModalityType(modalities)
+ self.augmentation = augmentation
+
+ self._set_transforms(augmentations=self.augmentation)
+ self._load()
+
+ def _set_transforms(self, augmentations: Dict) -> None:
+ self.transforms = {}
+ if ModalityType.MRI in self.modalities:
+ self.transforms[ModalityType.MRI] = get_image_transform(**augmentations)
+ if ModalityType.PET in self.modalities:
+ self.transforms[ModalityType.PET] = get_image_transform(**augmentations)
+ if ModalityType.TABULAR in self.modalities:
+ self.transforms[ModalityType.TABULAR] = Identity
+
+ def _load(self) -> None:
+ data_points = {
+ flag: [] for flag in ModalityType.__members__.values() if flag in self.modalities
+ }
+ load_mri = ModalityType.MRI in self.modalities
+ load_pet = ModalityType.PET in self.modalities
+ load_tab = ModalityType.TABULAR in self.modalities
+
+ LOG.info("Loading %s from %s", self.modalities, self.path)
+
+ diagnosis = []
+ rid = []
+ column_names = None
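+        # Expected HDF5 layout (inferred from the reads below): one group per sample with
+        # 'MRI/T1/data', 'PET/FDG/data', 'tabular/data' and 'tabular/missing' datasets plus
+        # 'DX'/'RID' attributes, and a top-level 'stats/tabular' group holding the training
+        # set's 'mean', 'stddev' and 'columns'.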
+ with h5py.File(self.path, mode='r') as file:
+ if load_tab:
+ tab_stats = file['stats/tabular']
+ tab_mean = tab_stats['mean'][:]
+ tab_std = tab_stats['stddev'][:]
+ assert np.all(tab_std > 0), "stddev is not positive"
+ column_names = tab_stats['columns'][:]
+ self._tab_mean = tab_mean
+ self._tab_std = tab_std
+
+ for name, group in file.items():
+ if name == "stats":
+ continue
+
+ data_point = []
+ if load_mri:
+ mri_data = group['MRI/T1/data'][:]
+ data_point.append(
+ tio.Subject(
+ image=tio.ScalarImage(tensor=mri_data[np.newaxis])
+ )
+ )
+
+ if load_pet:
+ pet_data = group['PET/FDG/data'][:]
+ # pet_data = np.nan_to_num(pet_data, copy=False)
+ data_point.append(
+ tio.Subject(
+ image=tio.ScalarImage(tensor=pet_data[np.newaxis])
+ )
+ )
+
+ if load_tab:
+ tab_values = group['tabular/data'][:]
+ tab_missing = group['tabular/missing'][:]
+ # XXX: always assumes that mean and std are from the training data
+ tab_data = np.stack((
+ (tab_values - tab_mean) / tab_std,
+ tab_missing,
+ ))
+ data_point.append(tab_data)
+
+ assert len(data_points) == len(data_point)
+ for samples, data in zip(data_points.values(), data_point):
+ samples.append(data)
+
+ diagnosis.append(group.attrs['DX'])
+ rid.append(group.attrs['RID'])
+
+ LOG.info("Loaded %d samples", len(rid))
+
+ dmap = DIAGNOSIS_MAP
+ labels, counts = np.unique(diagnosis, return_counts=True)
+ assert len(labels) == len(dmap), f"expected {len(dmap)} labels, but got {labels}"
+ LOG.info("Classes: %s", pd.Series(counts, index=labels))
+
+ self._column_names = column_names
+ self._data_points = data_points
+ self._diagnosis = [dmap[d] for d in diagnosis]
+ self._rid = rid
+
+ @property
+ def rid(self):
+ return self._rid
+
+ @property
+ def column_names(self):
+ return self._column_names
+
+ @property
+ def tabular_mean(self):
+ return self._tab_mean
+
+ @property
+ def tabular_stddev(self):
+ return self._tab_std
+
+ def tabular_inverse_transform(self, values):
+ values_arr = np.ma.atleast_2d(values)
+ if len(self._tab_mean) != values_arr.shape[1]:
+ raise ValueError(f"expected {len(self._tab_mean)} features, but got {values_arr.shape[1]}")
+
+ vals_t = values_arr * self._tab_std[np.newaxis] + self._tab_mean[np.newaxis]
+ return vals_t.reshape(values.shape)
+
+    def get_tabular(self, index: int, inverse_transform: bool = False) -> np.ma.MaskedArray:
+ tab_vals, tab_miss = self[index][0][ModalityType.TABULAR]
+ tab_vals = np.ma.array(tab_vals, mask=tab_miss)
+ if inverse_transform:
+ return self.tabular_inverse_transform(tab_vals)
+ return tab_vals
+
+ def __len__(self) -> int:
+ return len(self._rid)
+
+ def _as_tensor(self, x):
+ if isinstance(x, tio.Subject):
+ return x.image.data
+ return x
+
+ def __getitem__(self, index: int) -> Tuple[DataPointType, DataPointType, AugmentationType, int]:
+ label = self._diagnosis[index]
+ sample = {}
+ sample_raw = {}
+ augmentations = {}
+ for modality_id, samples in self._data_points.items():
+ data_raw = samples[index]
+ data_transformed = self.transforms[modality_id](data_raw)
+ if isinstance(data_transformed, tio.Subject):
+ augmentations[modality_id] = data_transformed.get_composed_history()
+
+ sample_raw[modality_id] = self._as_tensor(data_raw)
+ sample[modality_id] = self._as_tensor(data_transformed)
+
+ return sample, sample_raw, augmentations, label
+
+
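+# torchio transform histories cannot be stacked by default_collate, so they are collected
+# into plain per-modality lists while tensors and labels are collated as usual.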
+def collate_adni(batch):
+ get = itemgetter(0, 1, 3) # 2nd position is a tio.Transform instance
+ batch_wo_aug = [get(elem) for elem in batch]
+
+ keys = batch[0][2].keys()
+ augmentations = {k: [elem[2][k] for elem in batch] for k in keys}
+
+ batch_stacked = default_collate(batch_wo_aug)
+ return batch_stacked[:-1] + [augmentations] + batch_stacked[-1:]
+
+
+class AdniDataModule(pl.LightningDataModule):
+ def __init__(
+ self,
+ modalities: Union[ModalityType, int],
+ train_data: str,
+ valid_data: str,
+ test_data: Optional[str] = None,
+ batch_size: int = 32,
+ num_workers: int = 4,
+ augmentation: Dict[str, Any] = {"p": 0.0},
+ metadata: Optional[Dict[str, Any]] = None,
+ ):
+ super().__init__()
+ self.modalities = modalities
+ self.train_data = train_data
+ self.valid_data = valid_data
+ self.test_data = test_data
+ self.batch_size = batch_size
+ self.num_workers = num_workers
+ self.metadata = metadata
+ self.augmentation = augmentation
+
+ def setup(self, stage: Optional[str] = None):
+ if stage == 'fit' or stage is None:
+ self.train_dataset = AdniDataset(self.train_data, modalities=self.modalities, augmentation=self.augmentation)
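+            # the "push" dataset mirrors the training data without augmentation; it is used
+            # for the prototype push/projection step (see push_dataloader below)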
+ self.push_dataset = copy.deepcopy(self.train_dataset)
+ self.push_dataset._set_transforms(augmentations={"p": 0})
+ self.eval_dataset = AdniDataset(self.valid_data, modalities=self.modalities, augmentation={"p": 0})
+ elif stage == 'test' and self.test_data is not None:
+ self.test_dataset = AdniDataset(self.test_data, modalities=self.modalities, augmentation={"p": 0})
+ self.eval_dataset = AdniDataset(self.valid_data, modalities=self.modalities, augmentation={"p": 0})
+
+ def train_dataloader(self):
+ return DataLoader(
+ self.train_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ drop_last=True,
+ pin_memory=True,
+ shuffle=True,
+ collate_fn=collate_adni,
+ )
+
+ def val_dataloader(self):
+ return DataLoader(
+ self.eval_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers // 2,
+ pin_memory=True,
+ collate_fn=collate_adni,
+ )
+
+ def test_dataloader(self):
+ return DataLoader(
+ self.test_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ collate_fn=collate_adni,
+ )
+
+ def push_dataloader(self):
+ return DataLoader(
+ self.push_dataset,
+ batch_size=self.batch_size,
+ num_workers=self.num_workers,
+ pin_memory=True,
+ collate_fn=collate_adni,
+ )
diff --git a/torchpanic/datamodule/modalities.py b/torchpanic/datamodule/modalities.py
new file mode 100644
index 0000000..fbe5faf
--- /dev/null
+++ b/torchpanic/datamodule/modalities.py
@@ -0,0 +1,30 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+import enum
+from typing import Dict, Tuple
+
+from torch import Tensor
+import torchio as tio
+
+
+class ModalityType(enum.IntFlag):
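+    """Bit flags identifying input modalities; members combine bitwise, e.g. ``PET | TABULAR``."""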
+ MRI = 1
+ PET = 2
+ TABULAR = 4
+
+
+DataPointType = Dict[ModalityType, Tensor]
+AugmentationType = Dict[ModalityType, tio.Compose]
+BatchWithLabelType = Tuple[DataPointType, DataPointType, AugmentationType, Tensor]
diff --git a/torchpanic/models/__init__.py b/torchpanic/models/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/torchpanic/models/backbones.py b/torchpanic/models/backbones.py
new file mode 100644
index 0000000..90667b4
--- /dev/null
+++ b/torchpanic/models/backbones.py
@@ -0,0 +1,142 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+from collections import OrderedDict
+from torch import nn
+
+
+def conv3d(
+ in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1,
+) -> nn.Module:
+ if kernel_size != 1:
+ padding = 1
+ else:
+ padding = 0
+ return nn.Conv3d(
+ in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False,
+ )
+
+
+class ConvBnReLU(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ bn_momentum: float = 0.05,
+ kernel_size: int = 3,
+ stride: int = 1,
+ padding: int = 1,
+ ):
+ super().__init__()
+ self.conv = nn.Conv3d(
+ in_channels, out_channels,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ bias=False,
+ )
+ self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.bn(out)
+ out = self.relu(out)
+ return out
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channels: int, out_channels: int, bn_momentum: float = 0.05, stride: int = 1):
+ super().__init__()
+ self.conv1 = conv3d(in_channels, out_channels, stride=stride)
+ self.bn1 = nn.BatchNorm3d(out_channels, momentum=bn_momentum)
+ self.dropout1 = nn.Dropout(p=0.2, inplace=True)
+
+ self.conv2 = conv3d(out_channels, out_channels)
+ self.bn2 = nn.BatchNorm3d(out_channels, momentum=bn_momentum)
+ self.relu = nn.ReLU(inplace=True)
+
+ if stride != 1 or in_channels != out_channels:
+ self.downsample = nn.Sequential(
+ conv3d(in_channels, out_channels, kernel_size=1, stride=stride),
+ nn.BatchNorm3d(out_channels, momentum=bn_momentum)
+ )
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.dropout1(out)
+
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class ThreeDResNet(nn.Module):
+ def __init__(self, in_channels: int, n_outputs: int, n_blocks: int = 4, bn_momentum: float = 0.05, n_basefilters: int = 32):
+ super().__init__()
+ self.conv1 = nn.Conv3d(
+ in_channels,
+ n_basefilters,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ bias=False,)
+ self.bn1 = nn.BatchNorm3d(n_basefilters, momentum=bn_momentum)
+ self.relu1 = nn.ReLU(inplace=True)
+ self.pool1 = nn.MaxPool3d(3, stride=2)
+
+ if n_blocks < 2:
+ raise ValueError(f"n_blocks must be at least 2, but got {n_blocks}")
+
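+        # block1 keeps the resolution; every following block halves the spatial size
+        # (stride 2) and doubles the number of filters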
+ blocks = [
+ ("block1", ResBlock(n_basefilters, n_basefilters, bn_momentum=bn_momentum))
+ ]
+ n_filters = n_basefilters
+ for i in range(n_blocks - 1):
+ blocks.append(
+ (f"block{i+2}", ResBlock(n_filters, 2 * n_filters, bn_momentum=bn_momentum, stride=2))
+ )
+ n_filters *= 2
+
+ self.blocks = nn.Sequential(OrderedDict(blocks))
+ self.gap = nn.AdaptiveAvgPool3d(1)
+ self.fc = nn.Linear(n_filters, n_outputs, bias=False)
+
+# def get_out_features(self):
+#
+# return self.out_features
+
+ def forward(self, x):
+ out = self.conv1(x)
+ out = self.bn1(out)
+ out = self.relu1(out)
+ out = self.pool1(out)
+ out = self.blocks(out)
+ out = self.gap(out)
+ out = out.view(out.size(0), -1)
+ return self.fc(out)
diff --git a/torchpanic/models/daft/__init__.py b/torchpanic/models/daft/__init__.py
new file mode 100644
index 0000000..104830a
--- /dev/null
+++ b/torchpanic/models/daft/__init__.py
@@ -0,0 +1 @@
+from .vol_networks import DAFT
diff --git a/torchpanic/models/daft/vol_blocks.py b/torchpanic/models/daft/vol_blocks.py
new file mode 100644
index 0000000..aca8c39
--- /dev/null
+++ b/torchpanic/models/daft/vol_blocks.py
@@ -0,0 +1,276 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+#
+#
+# This file is part of Dynamic Affine Feature Map Transform (DAFT).
+#
+# DAFT is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# DAFT is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with DAFT. If not, see <https://www.gnu.org/licenses/>.
+from abc import ABCMeta, abstractmethod
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.parallel
+import torch.utils.data
+
+
+def conv3d(in_channels, out_channels, kernel_size=3, stride=1):
+ if kernel_size != 1:
+ padding = 1
+ else:
+ padding = 0
+ return nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False)
+
+
+class ConvBnReLU(nn.Module):
+ def __init__(
+ self, in_channels, out_channels, bn_momentum=0.05, kernel_size=7, stride=2, padding=3,
+ ):
+ super().__init__()
+ self.conv = nn.Conv3d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, bias=False)
+ self.bn = nn.BatchNorm3d(out_channels, momentum=bn_momentum)
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ out = self.conv(x)
+ out = self.bn(out)
+ out = self.relu(out)
+ return out
+
+
+class ResBlock(nn.Module):
+ def __init__(self, in_channels, out_channels, bn_momentum=0.05, stride=1):
+ super().__init__()
+ self.conv1 = conv3d(in_channels, out_channels, stride=stride)
+ self.bn1 = nn.BatchNorm3d(out_channels, momentum=bn_momentum)
+ self.conv2 = conv3d(out_channels, out_channels)
+ self.bn2 = nn.BatchNorm3d(out_channels, momentum=bn_momentum)
+ self.relu = nn.ReLU(inplace=True)
+
+ if stride != 1 or in_channels != out_channels:
+ self.downsample = nn.Sequential(
+ conv3d(in_channels, out_channels, kernel_size=1, stride=stride),
+ nn.BatchNorm3d(out_channels, momentum=bn_momentum),
+ )
+ else:
+ self.downsample = None
+
+ def forward(self, x):
+ residual = x
+
+ out = self.conv1(x)
+ out = self.bn1(out)
+
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.bn2(out)
+
+ if self.downsample is not None:
+ residual = self.downsample(x)
+
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class FilmBase(nn.Module, metaclass=ABCMeta):
+ """Absract base class for models that are related to FiLM of Perez et al"""
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ bn_momentum: float,
+ stride: int,
+ ndim_non_img: int,
+ location: int,
+ activation: str,
+ scale: bool,
+ shift: bool,
+ ) -> None:
+
+ super().__init__()
+
+ # sanity checks
+ if location not in set(range(5)):
+ raise ValueError(f"Invalid location specified: {location}")
+ if activation not in {"tanh", "sigmoid", "linear"}:
+ raise ValueError(f"Invalid location specified: {location}")
+ if (not isinstance(scale, bool) or not isinstance(shift, bool)) or (not scale and not shift):
+ raise ValueError(
+ f"scale and shift must be of type bool:\n -> scale value: {scale}, "
+ "scale type {type(scale)}\n -> shift value: {shift}, shift type: {type(shift)}"
+ )
+ # ResBlock
+ self.conv1 = conv3d(in_channels, out_channels, stride=stride)
+ self.bn1 = nn.BatchNorm3d(out_channels, momentum=bn_momentum, affine=(location != 3))
+ self.conv2 = conv3d(out_channels, out_channels)
+ self.bn2 = nn.BatchNorm3d(out_channels, momentum=bn_momentum)
+ self.relu = nn.ReLU(inplace=True)
+ self.global_pool = nn.AdaptiveAvgPool3d(1)
+ if stride != 1 or in_channels != out_channels:
+ self.downsample = nn.Sequential(
+ conv3d(in_channels, out_channels, kernel_size=1, stride=stride),
+ nn.BatchNorm3d(out_channels, momentum=bn_momentum),
+ )
+ else:
+ self.downsample = None
+ # Film-specific variables
+ self.location = location
+ if self.location == 2 and self.downsample is None:
+ raise ValueError("This is equivalent to location=1 and no downsampling!")
+ # location decoding
+ self.film_dims = 0
+ if location in {0, 1, 2}:
+ self.film_dims = in_channels
+ elif location in {3, 4}:
+ self.film_dims = out_channels
+ if activation == "sigmoid":
+ self.scale_activation = nn.Sigmoid()
+ elif activation == "tanh":
+ self.scale_activation = nn.Tanh()
+ elif activation == "linear":
+ self.scale_activation = None
+
+ @abstractmethod
+ def rescale_features(self, feature_map, x_aux):
+ """method to recalibrate feature map x"""
+
+ def forward(self, feature_map, x_aux):
+
+ if self.location == 0:
+ feature_map = self.rescale_features(feature_map, x_aux)
+ residual = feature_map
+
+ if self.location == 1:
+ residual = self.rescale_features(residual, x_aux)
+
+ if self.location == 2:
+ feature_map = self.rescale_features(feature_map, x_aux)
+ out = self.conv1(feature_map)
+ out = self.bn1(out)
+
+ if self.location == 3:
+ out = self.rescale_features(out, x_aux)
+ out = self.relu(out)
+
+ if self.location == 4:
+ out = self.rescale_features(out, x_aux)
+ out = self.conv2(out)
+ out = self.bn2(out)
+ if self.downsample is not None:
+ residual = self.downsample(residual)
+ out += residual
+ out = self.relu(out)
+
+ return out
+
+
+class DAFTBlock(FilmBase):
+ # Block for ZeCatNet
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ bn_momentum: float = 0.1,
+ stride: int = 2,
+ ndim_non_img: int = 15,
+ location: int = 0,
+ activation: str = "linear",
+ scale: bool = True,
+ shift: bool = True,
+ bottleneck_dim: int = 7,
+ ) -> None:
+
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ bn_momentum=bn_momentum,
+ stride=stride,
+ ndim_non_img=ndim_non_img,
+ location=location,
+ activation=activation,
+ scale=scale,
+ shift=shift,
+ )
+
+ self.bottleneck_dim = bottleneck_dim
+ aux_input_dims = self.film_dims
+ # shift and scale decoding
+ self.split_size = 0
+ if scale and shift:
+ self.split_size = self.film_dims
+ self.scale = None
+ self.shift = None
+ self.film_dims = 2 * self.film_dims
+ elif not scale:
+ self.scale = 1
+ self.shift = None
+ elif not shift:
+ self.shift = 0
+ self.scale = None
+
+ # create aux net
+ layers = [
+ ("aux_base", nn.Linear(ndim_non_img + aux_input_dims, self.bottleneck_dim, bias=False)),
+ ("aux_relu", nn.ReLU()),
+ ("aux_out", nn.Linear(self.bottleneck_dim, self.film_dims, bias=False)),
+ ]
+ self.aux = nn.Sequential(OrderedDict(layers))
+
+ def rescale_features(self, feature_map, x_aux):
+
+ squeeze = self.global_pool(feature_map)
+ squeeze = squeeze.view(squeeze.size(0), -1)
+ squeeze = torch.cat((squeeze, x_aux), dim=1)
+
+ attention = self.aux(squeeze)
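+        # when both scale and shift are learned, self.scale and self.shift are both None,
+        # so the comparison below is True and the aux output is split into a scale and a shift part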
+ if self.scale == self.shift:
+ v_scale, v_shift = torch.split(attention, self.split_size, dim=1)
+ v_scale = v_scale.view(*v_scale.size(), 1, 1, 1).expand_as(feature_map)
+ v_shift = v_shift.view(*v_shift.size(), 1, 1, 1).expand_as(feature_map)
+ if self.scale_activation is not None:
+ v_scale = self.scale_activation(v_scale)
+ elif self.scale is None:
+ v_scale = attention
+ v_scale = v_scale.view(*v_scale.size(), 1, 1, 1).expand_as(feature_map)
+ v_shift = self.shift
+ if self.scale_activation is not None:
+ v_scale = self.scale_activation(v_scale)
+ elif self.shift is None:
+ v_scale = self.scale
+ v_shift = attention
+ v_shift = v_shift.view(*v_shift.size(), 1, 1, 1).expand_as(feature_map)
+ else:
+ raise AssertionError(
+ f"Sanity checking on scale and shift failed. Must be of type bool or None: {self.scale}, {self.shift}"
+ )
+
+ return (v_scale * feature_map) + v_shift
diff --git a/torchpanic/models/daft/vol_networks.py b/torchpanic/models/daft/vol_networks.py
new file mode 100644
index 0000000..d927e9e
--- /dev/null
+++ b/torchpanic/models/daft/vol_networks.py
@@ -0,0 +1,165 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+#
+#
+#
+#
+# This file is part of Dynamic Affine Feature Map Transform (DAFT).
+#
+# DAFT is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# DAFT is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with DAFT. If not, see <https://www.gnu.org/licenses/>.
+from typing import Any, Dict, Optional, Sequence
+
+import torch
+import torch.nn as nn
+
+from ...datamodule.modalities import ModalityType
+from .vol_blocks import ConvBnReLU, DAFTBlock, ResBlock
+
+
+class HeterogeneousResNet(nn.Module):
+ def __init__(
+ self,
+ in_channels: int = 1,
+ n_outputs: int = 3,
+        bn_momentum: float = 0.1,
+ n_basefilters: int = 4,
+ ) -> None:
+ super().__init__()
+
+        # forward() reads the image from this modality; assumed to be PET, mirroring DAFT below
+        self.image_modality = ModalityType.PET
+        self.conv1 = ConvBnReLU(in_channels, n_basefilters, bn_momentum=bn_momentum)
+ self.pool1 = nn.MaxPool3d(2, stride=2) # 32
+ self.block1 = ResBlock(n_basefilters, n_basefilters, bn_momentum=bn_momentum)
+ self.block2 = ResBlock(n_basefilters, 2 * n_basefilters, bn_momentum=bn_momentum, stride=2) # 16
+ self.block3 = ResBlock(2 * n_basefilters, 4 * n_basefilters, bn_momentum=bn_momentum, stride=2) # 8
+ self.block4 = ResBlock(4 * n_basefilters, 8 * n_basefilters, bn_momentum=bn_momentum, stride=2) # 4
+ self.global_pool = nn.AdaptiveAvgPool3d(1)
+ self.fc = nn.Linear(8 * n_basefilters, n_outputs)
+
+ def forward(self, batch):
+ image = batch[self.image_modality]
+
+ out = self.conv1(image)
+ out = self.pool1(out)
+ out = self.block1(out)
+ out = self.block2(out)
+ out = self.block3(out)
+ out = self.block4(out)
+ out = self.global_pool(out)
+ out = out.view(out.size(0), -1)
+ out = self.fc(out)
+
+ # cannot return None, because lightning complains
+ terms = torch.zeros((out.shape[0], 1, out.shape[1],), device=out.device)
+ return out, terms
+
+
+class DAFT(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ in_tabular: int,
+ n_outputs: int,
+ idx_tabular_has_missing: Sequence[int],
+ bn_momentum: float = 0.05,
+ n_basefilters: int = 4,
+ filmblock_args: Optional[Dict[str, Any]] = None,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+
+ if filmblock_args is None:
+ filmblock_args = {}
+
+ self.image_modality = ModalityType.PET
+
+ if min(idx_tabular_has_missing) < 0:
+ raise ValueError("idx_tabular_has_missing contains negative values")
+ if max(idx_tabular_has_missing) >= in_tabular:
+ raise ValueError("index in idx_tabular_has_missing is out of range")
+ idx_missing = frozenset(idx_tabular_has_missing)
+ if len(idx_tabular_has_missing) != len(idx_missing):
+ raise ValueError("idx_tabular_has_missing contains duplicates")
+
+ self.register_buffer(
+ 'idx_tabular_has_missing',
+ torch.tensor(idx_tabular_has_missing, dtype=torch.long),
+ )
+ self.register_buffer(
+ 'idx_tabular_without_missing',
+ torch.tensor(list(set(range(in_tabular)).difference(idx_missing)), dtype=torch.long),
+ )
+
+ n_missing = len(idx_tabular_has_missing)
+ self.tab_missing_embeddings = nn.Parameter(
+ torch.empty((1, n_missing,), dtype=torch.float32), requires_grad=True)
+ nn.init.xavier_uniform_(self.tab_missing_embeddings)
+
+ self.split_size = 4 * n_basefilters
+ self.conv1 = ConvBnReLU(in_channels, n_basefilters, bn_momentum=bn_momentum)
+ self.pool1 = nn.MaxPool3d(3, stride=2) # 32
+ self.block1 = ResBlock(n_basefilters, n_basefilters, bn_momentum=bn_momentum)
+ self.block2 = ResBlock(n_basefilters, 2 * n_basefilters, bn_momentum=bn_momentum, stride=2) # 16
+ self.block3 = ResBlock(2 * n_basefilters, 4 * n_basefilters, bn_momentum=bn_momentum, stride=2) # 8
+ self.blockX = DAFTBlock(
+ 4 * n_basefilters,
+ 8 * n_basefilters,
+ bn_momentum=bn_momentum,
+ ndim_non_img=in_tabular,
+ **filmblock_args,
+ ) # 4
+ self.global_pool = nn.AdaptiveAvgPool3d(1)
+ self.fc = nn.Linear(8 * n_basefilters, n_outputs)
+
+ def forward(self, batch):
+ image = batch[self.image_modality]
+
+ tabular_data = batch[ModalityType.TABULAR]
+ values, is_missing = torch.unbind(tabular_data, axis=1)
+
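+        # wherever the missingness indicator is 1, the value is replaced by a learned
+        # embedding before being concatenated with the always-observed features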
+ values_wo_missing = values[:, self.idx_tabular_without_missing]
+ values_w_missing = values[:, self.idx_tabular_has_missing]
+ missing_in_batch = is_missing[:, self.idx_tabular_has_missing]
+ tabular_masked = torch.where(
+ missing_in_batch == 1.0, self.tab_missing_embeddings, values_w_missing
+ )
+
+ features = torch.cat(
+ (values_wo_missing, tabular_masked), dim=1,
+ )
+
+ out = self.conv1(image)
+ out = self.pool1(out)
+ out = self.block1(out)
+ out = self.block2(out)
+ out = self.block3(out)
+ out = self.blockX(out, features)
+ out = self.global_pool(out)
+ out = out.view(out.size(0), -1)
+ out = self.fc(out)
+
+ # cannot return None, because lightning complains
+ terms = torch.zeros((out.shape[0], 1, out.shape[1],), device=out.device)
+ return out, terms
diff --git a/torchpanic/models/nam.py b/torchpanic/models/nam.py
new file mode 100644
index 0000000..e9773bd
--- /dev/null
+++ b/torchpanic/models/nam.py
@@ -0,0 +1,295 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+from typing import Dict, Sequence, Tuple
+import torch
+from torch import nn
+from torch.nn import init
+
+from ..datamodule.adni import ModalityType
+from torchpanic.modules.utils import init_vector_normal
+
+
+class ExU(nn.Module):
+ """exp-centered unit"""
+ def __init__(self, out_features: int) -> None:
+ super().__init__()
+ self.weight = nn.Parameter(torch.empty((1, out_features)))
+ self.bias = nn.Parameter(torch.empty((1, out_features)))
+
+ self.reset_parameters()
+
+ def reset_parameters(self) -> None:
+ init.normal_(self.weight, mean=4.0, std=0.5)
+ init.zeros_(self.bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ assert x.ndim == 2 and x.shape[1] == 1
+ out = torch.exp(self.weight) * (x - self.bias)
+ return out
+
+
+class ReLUN(nn.Module):
+ def __init__(self, n: float = 1.0) -> None:
+ super().__init__()
+ self.n = n
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return torch.clamp(x, min=0.0, max=self.n)
+
+
+class FeatureNet(nn.Module):
+ """A neural network for a single feature"""
+ def __init__(
+ self,
+ out_features: int,
+ hidden_units: Sequence[int],
+ dropout_rate: float = 0.5,
+ ) -> None:
+ super().__init__()
+ in_features = hidden_units[0]
+ layers = {
+ "in": nn.Sequential(
+ nn.utils.weight_norm(nn.Linear(1, in_features)),
+ nn.ReLU()),
+ }
+ for i, units in enumerate(hidden_units[1:]):
+ layers[f"dense_{i}"] = nn.Sequential(
+ nn.utils.weight_norm(nn.Linear(in_features, units)),
+ nn.Dropout(p=dropout_rate),
+ nn.ReLU(),
+ )
+ in_features = units
+ layers["dense_out"] = nn.utils.weight_norm(nn.Linear(in_features, out_features, bias=False))
+ self.hidden_layers = nn.ModuleDict(layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ out = x
+ for layer in self.hidden_layers.values():
+ out = layer(out)
+ return out
+
+
+class NAM(nn.Module):
+ """Neural Additive Model
+
+ .. [1] Neural Additive Models: Interpretable Machine Learning with Neural Nets. NeurIPS 2021
+ https://proceedings.neurips.cc/paper/2021/hash/251bd0442dfcc53b5a761e050f8022b8-Abstract.html
+ """
+ def __init__(
+ self,
+ in_features: int,
+ out_features: int,
+ hidden_units: Sequence[int],
+ dropout_rate: float = 0.5,
+ feature_dropout_rate: float = 0.5,
+ ) -> None:
+ super().__init__()
+
+ layers = {}
+ for i in range(in_features):
+ layers[f"fnet_{i}"] = nn.Sequential(
+ FeatureNet(
+ out_features=out_features,
+ hidden_units=hidden_units,
+ dropout_rate=dropout_rate,
+ ),
+ )
+ self.feature_nns = nn.ModuleDict(layers)
+ self.feature_dropout = nn.Dropout1d(p=feature_dropout_rate)
+
+ self.bias = nn.Parameter(torch.empty((1, out_features)))
+
+ self.reset_parameters()
+
+ def reset_parameters(self) -> None:
+ init.zeros_(self.bias)
+
+ def base_forward(self, tabular: torch.Tensor) -> torch.Tensor:
+ values, is_missing = torch.unbind(tabular, axis=1)
+
+ # FIXME: Treat missing value in a better way
+ x = values * (1.0 - is_missing)
+
+ features = torch.split(x, 1, dim=-1)
+ outputs = []
+ for x_i, layer in zip(features, self.feature_nns.values()):
+ outputs.append(layer(x_i))
+ outputs = torch.stack(outputs, dim=1)
+ logits = self.feature_dropout(outputs)
+ return logits, outputs
+
+ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+ tabular = batch[ModalityType.TABULAR]
+ logits, outputs = self.base_forward(tabular)
+ logits = torch.sum(logits, dim=1) + self.bias
+ return logits, outputs
+
+
+def is_unique(x):
+ return len(x) == len(set(x))
+
+
+class SemiParametricNAM(NAM):
+ def __init__(
+ self,
+ idx_real_features: Sequence[int],
+ idx_cat_features: Sequence[int],
+ idx_real_has_missing: Sequence[int],
+ idx_cat_has_missing: Sequence[int],
+ out_features: int,
+ hidden_units: Sequence[int],
+ dropout_rate: float = 0.5,
+ feature_dropout_rate: float = 0.5,
+ ) -> None:
+ super().__init__(
+ in_features=len(idx_real_features),
+ out_features=out_features,
+ hidden_units=hidden_units,
+ dropout_rate=dropout_rate,
+ feature_dropout_rate=feature_dropout_rate,
+ )
+
+ assert is_unique(idx_real_features)
+ assert is_unique(idx_cat_features)
+ assert is_unique(idx_real_has_missing)
+ assert is_unique(idx_cat_has_missing)
+
+ self.cat_linear = nn.Linear(len(idx_cat_features), out_features, bias=False)
+ n_missing = len(idx_real_has_missing) + len(idx_cat_has_missing)
+ self.miss_linear = nn.Linear(n_missing, out_features, bias=False)
+
+ self._idx_real_features = idx_real_features
+ self._idx_cat_features = idx_cat_features
+ self._idx_real_has_missing = idx_real_has_missing
+ self._idx_cat_has_missing = idx_cat_has_missing
+
+ def forward(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
+ tabular = batch[ModalityType.TABULAR]
+ values, is_missing = torch.unbind(tabular, axis=1)
+
+ val_categ = values[:, self._idx_cat_features]
+ miss_categ = is_missing[:, self._idx_cat_features]
+ has_miss_categ = miss_categ[:, self._idx_cat_has_missing]
+
+ val_real = values[:, self._idx_real_features]
+ miss_real = is_missing[:, self._idx_real_features]
+ has_miss_real = miss_real[:, self._idx_real_has_missing]
+
+ out_real_full = self.forward_real(val_real) * torch.unsqueeze(1.0 - miss_real, dim=-1)
+ out_real = torch.sum(out_real_full, dim=1)
+
+ out_categ = self.cat_linear(val_categ * (1.0 - miss_categ))
+
+ out_miss = self.miss_linear(torch.cat((has_miss_categ, has_miss_real), dim=1))
+
+ return sum((out_real, out_categ, out_miss, self.bias,)), out_real_full
+
+ def forward_real(self, x):
+ features = torch.split(x, 1, dim=-1)
+ outputs = []
+ for x_i, layer in zip(features, self.feature_nns.values()):
+ outputs.append(layer(x_i))
+
+ outputs = torch.stack(outputs, dim=1)
+ outputs = self.feature_dropout(outputs)
+ return outputs
+
+
+class BaseNAM(NAM):
+ def __init__(
+ self,
+ idx_real_features: Sequence[int],
+ idx_cat_features: Sequence[int],
+ idx_real_has_missing: Sequence[int],
+ idx_cat_has_missing: Sequence[int],
+ out_features: int,
+ hidden_units: Sequence[int],
+ dropout_rate: float = 0.5,
+ feature_dropout_rate: float = 0.5,
+ **kwargs
+ ) -> None:
+ super().__init__(
+ in_features=len(idx_real_features),
+ out_features=out_features,
+ hidden_units=hidden_units,
+ dropout_rate=dropout_rate,
+ feature_dropout_rate=feature_dropout_rate,
+ )
+
+ assert is_unique(idx_real_features)
+ assert is_unique(idx_cat_features)
+ assert is_unique(idx_real_has_missing)
+ assert is_unique(idx_cat_has_missing)
+
+ n_missing = len(idx_real_has_missing) + len(idx_cat_has_missing)
+ self.tab_missing_embeddings = nn.Parameter(
+ torch.empty((n_missing, out_features), dtype=torch.float32), requires_grad=True)
+ nn.init.xavier_uniform_(self.tab_missing_embeddings)
+
+ self.cat_linear = nn.Parameter(
+ torch.empty((len(idx_cat_features), out_features), dtype=torch.float32, requires_grad=True))
+ nn.init.xavier_uniform_(self.cat_linear)
+
+ self._idx_real_features = idx_real_features
+ self._idx_cat_features = idx_cat_features
+ self._idx_real_has_missing = idx_real_has_missing
+ self._idx_cat_has_missing = idx_cat_has_missing
+
+ def forward_real(self, x):
+ features = torch.split(x, 1, dim=-1)
+ outputs = []
+ for x_i, layer in zip(features, self.feature_nns.values()):
+ outputs.append(layer(x_i))
+
+ outputs = torch.stack(outputs, dim=1)
+ return outputs
+
+ def base_forward(self, tabular: torch.Tensor) -> torch.Tensor:
+ values, is_missing = torch.unbind(tabular, dim=1)
+
+ val_real = values[:, self._idx_real_features]
+ miss_real = is_missing[:, self._idx_real_features]
+ has_miss_real = miss_real[:, self._idx_real_has_missing]
+ # TODO for all miss_real==1, check that its index is in self._idx_real_has_missing
+
+ val_categ = values[:, self._idx_cat_features]
+ miss_categ = is_missing[:, self._idx_cat_features]
+ has_miss_categ = miss_categ[:, self._idx_cat_has_missing]
+ # TODO for all miss_categ==1, check that its index is in self._idx_categ_has_missing
+
+ features_real = self.forward_real(val_real)
+ features_categ = self.cat_linear.unsqueeze(0) * val_categ.unsqueeze(-1)
+
+        features_real = features_real * (1.0 - miss_real.unsqueeze(-1)) # set features to zero where they are missing
+ features_categ = features_categ * (1.0 - miss_categ.unsqueeze(-1))
+
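+        # for missing entries, a learned per-feature embedding (one row per feature that may
+        # be missing) takes the place of the zeroed-out contribution computed above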
+ filler_real = torch.zeros_like(features_real)
+ filler_categ = torch.zeros_like(features_categ)
+
+ filler_real[:, self._idx_real_has_missing, :] = \
+ self.tab_missing_embeddings[len(self._idx_cat_has_missing):].unsqueeze(0) * has_miss_real.unsqueeze(-1)
+ filler_categ[:, self._idx_cat_has_missing, :] = \
+ self.tab_missing_embeddings[:len(self._idx_cat_has_missing)].unsqueeze(0) * has_miss_categ.unsqueeze(-1)
+
+ features_real = features_real + filler_real # filler only has values where real is 0
+ features_categ = features_categ + filler_categ
+
+ return torch.cat((features_real, features_categ), dim=1)
+
+ def forward(self, tabular: torch.Tensor) -> torch.Tensor:
+ if isinstance(tabular, dict):
+ tabular = tabular[ModalityType.TABULAR]
+ features = self.base_forward(tabular)
+ return torch.sum(self.feature_dropout(features), dim=1) + self.bias, features
diff --git a/torchpanic/models/panic.py b/torchpanic/models/panic.py
new file mode 100644
index 0000000..f984a59
--- /dev/null
+++ b/torchpanic/models/panic.py
@@ -0,0 +1,94 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+from typing import Any, Dict
+import torch
+from torchvision import models
+
+from .protowrapper import ProtoWrapper
+from .nam import BaseNAM
+from ..models.backbones import ThreeDResNet
+
+BACKBONES = {
+ 'resnet18': (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1),
+ 'resnet34': (models.resnet34, models.ResNet34_Weights.IMAGENET1K_V1),
+ 'resnet50': (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V1),
+ 'resnet101': (models.resnet101, models.ResNet101_Weights.IMAGENET1K_V1),
+ 'resnet152': (models.resnet152, models.ResNet152_Weights.IMAGENET1K_V1),
+ '3dresnet': (ThreeDResNet, None),
+}
+
+
+def flatten_module(module):
+ children = list(module.children())
+ flat_children = []
+ if children == []:
+ return module
+ else:
+ for child in children:
+ try:
+ flat_children.extend(flatten_module(child))
+ except TypeError:
+ flat_children.append(flatten_module(child))
+ return flat_children
+
+
+def deactivate_features(features):
+ for param in features.parameters():
+ param.requires_grad = False
+
+
+class PANIC(ProtoWrapper):
+ def __init__(
+ self,
+ protonet: Dict[Any, Any],
+ nam: Dict[Any, Any],
+ ) -> None:
+ super().__init__(
+ **protonet
+ )
+ # must tidy up prototype vector dimensions!
+ self.classification = None
+ self.nam = BaseNAM(
+ **nam
+ )
+
+ def forward_image(self, image):
+
+ return self.base_forward(image)
+
+ def forward(self, image, tabular):
+
+ feature_vectors, similarities, occurrences = self.forward_image(image)
+ # feature_vectors shape is (bs, n_protos, n_chans_per_prot)
+ # similarities is of shape (bs, n_protos)
+ similarities_reshaped = similarities.view(
+ similarities.size(0), self.num_classes, self.n_prototypes_per_class)
+ similarities_reshaped = similarities_reshaped.permute(0, 2, 1)
+ # this maps similarities such that we have the similarities w.r.t. each class:
+ # new shape is (bs, n_protos_per_class, n_classes)
+
+ nam_features = self.nam.base_forward(tabular)
+ # nam_features shape is (bs, n_features, n_classes)
+
+ features = torch.cat((similarities_reshaped, nam_features), dim=1)
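+        # features has shape (bs, n_protos_per_class + n_tabular_features, n_classes):
+        # each prototype similarity and each tabular feature contributes one additive term
+        # per class; summing over terms and adding the bias yields the logits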
+
+ logits = self.nam.feature_dropout(features)
+ logits = torch.sum(logits, dim=1) + self.nam.bias
+ return logits, similarities, occurrences, features
+
+ @torch.no_grad()
+ def push_forward(self, image, tabular):
+
+ return self.forward_image(image)
diff --git a/torchpanic/models/ppnet.py b/torchpanic/models/ppnet.py
new file mode 100644
index 0000000..0192dee
--- /dev/null
+++ b/torchpanic/models/ppnet.py
@@ -0,0 +1,196 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+#
+# Original License:
+# MIT License
+#
+# Copyright (c) 2019 Chaofan Chen (cfchen-duke), Oscar Li (OscarcarLi),
+# Chaofan Tao, Alina Jade Barnett, Cynthia Rudin
+#
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+#
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+#
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+
+import torch
+import torch.nn as nn
+from torchvision import models
+
+from .backbones import ThreeDResNet
+
+BACKBONES = {
+ 'resnet18': (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1),
+ 'resnet34': (models.resnet34, models.ResNet34_Weights.IMAGENET1K_V1),
+ 'resnet50': (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V1),
+ 'resnet101': (models.resnet101, models.ResNet101_Weights.IMAGENET1K_V1),
+ 'resnet152': (models.resnet152, models.ResNet152_Weights.IMAGENET1K_V1),
+ '3dresnet': (ThreeDResNet, None),
+}
+
+
+def flatten_module(module):
+ children = list(module.children())
+ flat_children = []
+ if children == []:
+ return module
+ else:
+ for child in children:
+ try:
+ flat_children.extend(flatten_module(child))
+ except TypeError:
+ flat_children.append(flatten_module(child))
+ return flat_children
+
+
+class PPNet(nn.Module):
+ def __init__(
+ self,
+ backbone: str,
+ in_channels: int,
+ out_features: int,
+ n_prototypes_per_class: int,
+ n_chans_protos: int,
+ optim_features: bool,
+ normed_prototypes: bool,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+ assert backbone in BACKBONES.keys(), f"cannot find backbone {backbone} in valid BACKBONES {BACKBONES.keys()}!"
+
+ self.normed_prototypes = normed_prototypes
+ self.n_prototypes_per_class = n_prototypes_per_class
+
+ features, weights = BACKBONES[backbone]
+ if backbone.startswith("resnet"):
+ weights = weights.DEFAULT
+ if weights is None:
+            pretrained_model = kwargs.pop('pretrained_model', None)
+            weights = kwargs
+            weights["in_channels"] = in_channels
+            weights["n_outputs"] = out_features
+        else:
+            # torchvision backbones keep their published weights; no custom checkpoint is loaded
+            pretrained_model = None
+            weights = {"weights": weights}
+ features = features(**weights)
+ features = list(features.children())
+ fc_layer = features[-1]
+ features = features[:-2] # remove GAP and last fc layer!
+ in_layer = features[0]
+ assert isinstance(in_layer, nn.Conv2d) or isinstance(in_layer, nn.Conv3d)
+ self.two_d = isinstance(in_layer, nn.Conv2d)
+ self.conv_func = nn.Conv2d if self.two_d else nn.Conv3d
+ if in_layer.in_channels != in_channels:
+ assert optim_features, "Different num input channels -> must optim!"
+ in_layer = self.conv_func(
+ in_channels=in_channels,
+ out_channels=in_layer.out_channels,
+ kernel_size=in_layer.kernel_size,
+ stride=in_layer.stride,
+ padding=in_layer.padding,
+ )
+ in_layer.requires_grad = False
+ features[0] = in_layer
+ self.num_classes = out_features
+ self.n_features_encoder = fc_layer.in_features
+ self.prototype_shape = (n_prototypes_per_class * self.num_classes, n_chans_protos)
+ self.num_prototypes = self.prototype_shape[0]
+
+ # prototype_activation_function could be 'log', 'linear',
+ # or a generic function that converts distance to similarity score
+
+ '''
+ Here we are initializing the class identities of the prototypes
+ Without domain specific knowledge we allocate the same number of
+ prototypes for each class
+ '''
+ assert (self.num_prototypes % self.num_classes == 0), \
+ f"{self.num_prototypes} vs {self.num_classes}" # not needed as we initialize differently
+ # a onehot indication matrix for each prototype's class identity
+ self.prototype_class_identity = nn.Parameter(
+ torch.zeros(self.num_prototypes, self.num_classes),
+ requires_grad=False)
+
+ num_prototypes_per_class = self.num_prototypes // self.num_classes
+ for j in range(self.num_prototypes):
+ self.prototype_class_identity[j, j // num_prototypes_per_class] = 1
+
+ self.features = torch.nn.Sequential(*features)
+
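+        # add_on_layers map encoder features into the prototype space (n_chans_protos channels);
+        # occurrence_module predicts, for every prototype, where in the image it occurs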
+ self.add_on_layers = nn.Sequential(
+ self.conv_func(in_channels=self.n_features_encoder, out_channels=self.prototype_shape[1], kernel_size=1),
+ nn.ReLU(),
+ self.conv_func(in_channels=self.prototype_shape[1], out_channels=self.prototype_shape[1], kernel_size=1),
+ nn.Softplus()
+ )
+ self.occurrence_module = nn.Sequential(
+ self.conv_func(in_channels=self.n_features_encoder, out_channels=self.n_features_encoder // 8, kernel_size=1),
+ nn.ReLU(),
+ self.conv_func(in_channels=self.n_features_encoder // 8, out_channels=self.prototype_shape[0], kernel_size=1),
+ nn.Sigmoid()
+ )
+ self.gap = (-2, -1) if self.two_d else (-3, -2, -1) # if 3D network, pool h,w and d
+ self.prototype_vectors = nn.Parameter(torch.rand(self.prototype_shape), requires_grad=True)
+ if self.normed_prototypes:
+ self.prototype_vectors.data = (self.prototype_vectors.data / torch.linalg.vector_norm(
+ self.prototype_vectors.data, ord=2, dim=1, keepdim=True, dtype=torch.double)).to(torch.float32)
+ # nn.init.xavier_uniform_(self.prototype_vectors)
+ self.cosine_similarity = torch.nn.CosineSimilarity(dim=-1)
+
+ self.ones = nn.Parameter(torch.ones(self.prototype_shape),
+ requires_grad=False)
+
+ if pretrained_model is not None:
+ state_dict = torch.load(pretrained_model, map_location=self.prototype_vectors.device)['state_dict']
+ state_dict = {x.replace('net.features.', ''): y for x, y in state_dict.items()}
+ state_dict = {x: y for x, y in state_dict.items() if x in self.features.state_dict().keys()}
+ self.features.load_state_dict(state_dict)
+ if not optim_features:
+ for x in self.features.parameters():
+ x.requires_grad = False
+
+ self.classification = nn.Linear(self.num_prototypes, self.num_classes,
+ bias=False) # do not use bias
+
+ self.set_last_layer_incorrect_connection(incorrect_strength=-0.5)
+
+ def set_last_layer_incorrect_connection(self, incorrect_strength):
+        '''
+        Connect each prototype to its own class with weight 1 and to every other
+        class with `incorrect_strength` (e.g. pass -0.5 for a connection of -0.5).
+        '''
+ positive_one_weights_locations = torch.t(self.prototype_class_identity)
+ negative_one_weights_locations = 1 - positive_one_weights_locations
+
+ correct_class_connection = 1
+ incorrect_class_connection = incorrect_strength
+ self.classification.weight.data.copy_(
+ correct_class_connection * positive_one_weights_locations
+ + incorrect_class_connection * negative_one_weights_locations)
diff --git a/torchpanic/models/pretrain_encoder.py b/torchpanic/models/pretrain_encoder.py
new file mode 100644
index 0000000..4625723
--- /dev/null
+++ b/torchpanic/models/pretrain_encoder.py
@@ -0,0 +1,67 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+from typing import Any, Dict
+import torch
+
+from .protowrapper import ProtoWrapper
+from torchpanic.datamodule.adni import ModalityType
+
+
+def flatten_module(module):
+ children = list(module.children())
+ flat_children = []
+ if children == []:
+ return module
+ else:
+ for child in children:
+ try:
+ flat_children.extend(flatten_module(child))
+ except TypeError:
+ flat_children.append(flatten_module(child))
+ return flat_children
+
+
+def deactivate_features(features):
+ for param in features.parameters():
+ param.requires_grad = False
+
+
+class Encoder(torch.nn.Module):
+ def __init__(
+ self,
+ protonet: Dict[Any, Any],
+ ) -> None:
+ super().__init__()
+ wrapper = ProtoWrapper(
+ **protonet
+ )
+ self.features = wrapper.features
+ self.gap = torch.nn.AdaptiveAvgPool3d(1)
+ out_chans_encoder = self.features[-1][-1].conv2.out_channels
+ self.classification = torch.nn.Linear(out_chans_encoder, protonet['out_features'])
+ self.nam_term_faker = None
+
+ def forward(self, x):
+
+ x = x[ModalityType.PET]
+ out = self.features(x)
+ out = self.gap(out)
+ out = out.view(out.size(0), -1)
+ out = self.classification(out)
+        if self.nam_term_faker is None or self.nam_term_faker.size(0) != out.size(0):
+            # recreate the dummy NAM term whenever the batch size changes
+            self.nam_term_faker = torch.zeros_like(out).unsqueeze(-1)
+ return out, self.nam_term_faker
+
+
diff --git a/torchpanic/models/protowrapper.py b/torchpanic/models/protowrapper.py
new file mode 100644
index 0000000..f69e787
--- /dev/null
+++ b/torchpanic/models/protowrapper.py
@@ -0,0 +1,93 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+import torch
+
+from .ppnet import PPNet
+
+def flatten_module(module):
+ children = list(module.children())
+ flat_children = []
+ if children == []:
+ return module
+ else:
+ for child in children:
+ try:
+ flat_children.extend(flatten_module(child))
+ except TypeError:
+ flat_children.append(flatten_module(child))
+ return flat_children
+
+
+class ProtoWrapper(PPNet):
+ def __init__(
+ self,
+ backbone: str,
+ in_channels: int,
+ out_features: int,
+ n_prototypes_per_class: int,
+ n_chans_protos: int,
+ optim_features: bool,
+ normed_prototypes: bool,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ backbone,
+ in_channels,
+ out_features,
+ n_prototypes_per_class,
+ n_chans_protos,
+ optim_features,
+ normed_prototypes,
+ **kwargs)
+
+ def base_forward(self, x):
+
+ x = self.features(x)
+        occurrences = self.occurrence_module(x) # (bs, n_protos, h, w[, d])
+ feature_map = self.add_on_layers(x) # (bs, n_chans_proto, h, w[, d])
+
+ # the following should work for 3D and 2D data!
+ broadcasting_shape = (occurrences.size(0), occurrences.size(1), feature_map.size(1), *feature_map.size()[2:])
+ # of shape (bs, n_protos, n_chans_per_prot, h, w[, d])
+
+ # expand the two such that broadcasting is possible, i.e. vectorization of prototype feature calculation
+ occurrences_reshaped = occurrences.unsqueeze(2).expand(broadcasting_shape)
+ feature_map_reshaped = feature_map.unsqueeze(1).expand(broadcasting_shape)
+
+        # element-wise multiplication of each occurrence map with the feature map
+ feature_vectors = occurrences_reshaped * feature_map_reshaped
+ feature_vectors = feature_vectors.mean(dim=self.gap) # essentially GAP over the spatial resolutions
+ # feature_vectors size is now (bs, n_protos, n_chans_per_prot)
+ # prototype_vectors size is (n_protos, n_chans_per_prot)
+ if self.normed_prototypes:
+ feature_vectors = (feature_vectors / torch.linalg.vector_norm(
+ feature_vectors, ord=2, dim=2, keepdim=True, dtype=torch.double)).to(torch.float32)
+
+        # make prototypes broadcastable to feature vectors
+ similarities = self.cosine_similarity(feature_vectors, self.prototype_vectors.unsqueeze(0))
+ return feature_vectors, similarities, occurrences
+
+ def forward(self, x):
+
+ _, similarities, occurrences = self.base_forward(x)
+
+ logits = self.classification(similarities)
+
+ return logits, similarities, occurrences
+
+ @torch.no_grad()
+ def push_forward(self, x):
+
+ return self.base_forward(x)
diff --git a/torchpanic/modules/__init__.py b/torchpanic/modules/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/torchpanic/modules/base.py b/torchpanic/modules/base.py
new file mode 100644
index 0000000..b0c721d
--- /dev/null
+++ b/torchpanic/modules/base.py
@@ -0,0 +1,131 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+import logging
+from typing import Any, List
+
+import pytorch_lightning as pl
+from pytorch_lightning.loggers import TensorBoardLogger
+import torch
+from torchmetrics import Accuracy, ConfusionMatrix, MaxMetric
+
+from .utils import get_git_hash
+
+from ..datamodule.modalities import DataPointType
+
+
+LOG = logging.getLogger(__name__)
+
+
+class BaseModule(pl.LightningModule):
+ def __init__(
+ self,
+ net: torch.nn.Module,
+ num_classes: int,
+ ) -> None:
+ super().__init__()
+
+ self.net = net
+ task = "binary" if num_classes <= 2 else "multiclass"
+ # use separate metric instance for train, val and test step
+ # to ensure a proper reduction over the epoch
+ self.train_acc = Accuracy(task=task, num_classes=num_classes)
+ self.val_acc = Accuracy(task=task, num_classes=num_classes)
+ self.test_acc = Accuracy(task=task, num_classes=num_classes)
+ self.val_cmat = ConfusionMatrix(task=task, num_classes=num_classes)
+ self.test_cmat = ConfusionMatrix(task=task, num_classes=num_classes)
+
+ # for logging best so far validation accuracy
+ self.val_acc_best = MaxMetric()
+ self.val_bacc_best = MaxMetric()
+
+ def _get_balanced_accuracy_from_confusion_matrix(self, confusion_matrix: ConfusionMatrix):
+ # Confusion matrix whose i-th row and j-th column entry indicates
+ # the number of samples with true label being i-th class and
+ # predicted label being j-th class.
+ cmat = confusion_matrix.compute()
+ per_class = cmat.diag() / cmat.sum(dim=1)
+ per_class = per_class[~torch.isnan(per_class)] # remove classes that are not present in this dataset
+ LOG.debug("Confusion matrix:\n%s", cmat)
+
+ return per_class.mean()
+
+ def _log_train_metrics(
+ self, loss: torch.Tensor, preds: torch.Tensor, targets: torch.Tensor,
+ ) -> None:
+ self.log("train/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
+ acc = self.train_acc(preds, targets)
+ self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
+
+ def _update_validation_metrics(
+ self, loss: torch.Tensor, preds: torch.Tensor, targets: torch.Tensor,
+ ) -> None:
+ self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False)
+ self.val_acc.update(preds, targets)
+ self.val_cmat.update(preds, targets)
+
+ def _log_validation_metrics(self) -> None:
+ acc = self.val_acc.compute() # get val accuracy from current epoch
+ self.log("val/acc", acc, on_epoch=True)
+
+ self.val_acc_best.update(acc)
+ self.log("val/acc_best", self.val_acc_best.compute(), on_epoch=True)
+
+ # compute balanced accuracy
+ bacc = self._get_balanced_accuracy_from_confusion_matrix(self.val_cmat)
+ self.val_bacc_best.update(bacc)
+ self.log("val/bacc", bacc, on_epoch=True, prog_bar=True)
+ self.log("val/bacc_best", self.val_bacc_best.compute(), on_epoch=True)
+
+ # reset metrics at the end of every epoch
+ self.val_acc.reset()
+ self.val_cmat.reset()
+
+ def _update_test_metrics(
+ self, loss: torch.Tensor, preds: torch.Tensor, targets: torch.Tensor,
+ ) -> None:
+ self.log("test/loss", loss, on_step=False, on_epoch=True)
+ self.test_acc.update(preds, targets)
+ self.test_cmat.update(preds, targets)
+
+ def _log_test_metrics(self) -> None:
+ acc = self.test_acc.compute()
+ self.log("test/acc", acc)
+
+ # compute balanced accuracy
+ bacc = self._get_balanced_accuracy_from_confusion_matrix(self.test_cmat)
+ self.log("test/bacc", bacc)
+
+ # reset metrics at the end of every epoch
+ self.test_acc.reset()
+ self.test_cmat.reset()
+
+ def on_train_start(self):
+ # by default lightning executes validation step sanity checks before training starts,
+ # so we need to make sure val_acc_best doesn't store accuracy from these checks
+ self.val_acc_best.reset()
+ self.val_bacc_best.reset()
+
+ if isinstance(self.logger, TensorBoardLogger):
+ tb_logger = self.logger.experiment
+ # tb_logger.add_hparams({"git-commit": get_git_hash()}, {"hp_metric": -1})
+ tb_logger.flush()
+
+ def training_epoch_end(self, outputs: List[Any]):
+ # `outputs` is a list of dicts returned from `training_step()`
+ # reset metrics at the end of every epoch
+ self.train_acc.reset()
+
+ def forward(self, x: DataPointType):
+ return self.net(x)
diff --git a/torchpanic/modules/panic.py b/torchpanic/modules/panic.py
new file mode 100644
index 0000000..5d56407
--- /dev/null
+++ b/torchpanic/modules/panic.py
@@ -0,0 +1,527 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+import logging
+import math
+from pathlib import Path
+from typing import Any, List
+import warnings
+
+import numpy as np
+# from skimage.transform import resize
+
+from pytorch_lightning.loggers import TensorBoardLogger
+import torch
+from torch.nn.functional import l1_loss
+from torchmetrics import MaxMetric
+
+from ..datamodule.modalities import BatchWithLabelType, ModalityType
+from .base import BaseModule
+
+LOG = logging.getLogger(__name__)
+
+STAGE2FLOAT = {"warmup": 0.0, "warmup_protonet": 1.0, "all": 2.0, "nam_only": 3.0}
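+# Stage schedule (see PANIC._get_current_stage): "warmup" optimizes the occurrence
+# module, add-on layers and prototypes, "warmup_protonet" additionally the encoder,
+# then training cycles through "all" (every parameter group) and "nam_only"
+# (NAM and missing-value embeddings only).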
+
+
+# @torch.no_grad()
+# def init_weights(m: torch.Tensor):
+# if isinstance(m, nn.Conv3d):
+# # every init technique has an underscore _ in the name
+# nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+# if getattr(m, "bias", None) is not None:
+# nn.init.zeros_(m.bias)
+# elif isinstance(m, nn.BatchNorm3d):
+# nn.init.constant_(m.weight, 1)
+# nn.init.constant_(m.bias, 0)
+
+
+class PANIC(BaseModule):
+ def __init__(
+ self,
+ net: torch.nn.Module,
+ weight_decay_nam: float = 0.0,
+ lr: float = 0.001,
+ weight_decay: float = 0.0005,
+ l_clst: float = 0.8,
+ l_sep: float = 0.08,
+ l_occ: float = 1e-4,
+ l_affine: float = 1e-4,
+ l_nam: float = 1e-4,
+ epochs_all: int = 3,
+ epochs_nam: int = 4,
+ epochs_warmup: int = 10,
+ enable_checkpointing: bool = True,
+ monitor_prototypes: bool = False, # whether to save prototypes of all push epochs or just the best one
+ enable_save_embeddings: bool = False,
+ enable_log_prototypes: bool = False,
+ **kwargs,
+ ):
+ super().__init__(
+ net=net,
+ num_classes=net.num_classes,
+ )
+ self.automatic_optimization = False
+
+ # this line allows to access init params with 'self.hparams' attribute
+ # it also ensures init params will be stored in ckpt
+ self.save_hyperparameters(logger=False, ignore=["net"])
+ self._current_stage = None
+ self._current_optimizer = None
+
+ # FIXME only picks PET
+ self.image_modality = ModalityType.PET
+
+ self.net = net
+ if self.net.two_d:
+ self.reduce_dims = (2, 3,) # inputs are 4d tensors
+ else:
+ self.reduce_dims = (2, 3, 4,) # inputs are 5d tensors
+
+ # loss function
+ self.criterion = torch.nn.CrossEntropyLoss()
+
+ self.val_bacc_save = MaxMetric()
+
+ def _cluster_losses(self, similarities, y):
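+ """Cluster and separation costs computed from the cosine similarities.
+ clst pulls each sample towards at least one prototype of its own class
+ (1 minus the maximum similarity to same-class prototypes), while sep
+ penalizes the maximum similarity to prototypes of other classes."""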
+ prototypes_of_correct_class = torch.t(self.net.prototype_class_identity[:, y]).bool()
+ # select mask for each sample in batch. Shape is (bs, n_prototypes)
+ other_tensor = torch.tensor(0, device=similarities.device) # minimum value of the cosine similarity
+ similarities_correct_class = similarities.where(prototypes_of_correct_class, other=other_tensor)
+ # other-class entries are filled with the minimum similarity value and thus cannot win the max
+ similarities_incorrect_class = similarities.where(~prototypes_of_correct_class, other=other_tensor)
+ # same here: the sample's own-class entries are filled with the minimum similarity,
+ # i.e. they are treated as maximally distant and cannot win the max
+ clst = 1 - torch.max(similarities_correct_class, dim=1).values.mean()
+ sep = torch.max(similarities_incorrect_class, dim=1).values.mean()
+
+ return clst, sep
+
+ def _occurence_map_losses(self, occurrences, x_raw, aug):
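+ """Regularizers for the occurrence maps: an affine-consistency term (L1 distance
+ between the occurrence maps of the augmented input and the augmented occurrence
+ maps of the raw input) and a normalized L1 sparsity penalty on the occurrence maps."""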
+ with torch.no_grad():
+ _, _, occurrences_raw = self.net.forward_image(x_raw)
+
+ if self.net.two_d:
+ occurrences_raw = occurrences_raw.unsqueeze(-1)
+ occurrences_raw = occurrences_raw.cpu()
+
+ for i, aug_i in enumerate(aug):
+ occurrences_raw[i] = aug_i(occurrences_raw[i])
+
+ if self.net.two_d:
+ occurrences_raw = occurrences_raw.squeeze(-1)
+ occurrences_raw = occurrences_raw.to(occurrences.device)
+
+ affine = l1_loss(occurrences_raw, occurrences)
+
+ # l1 penalty on occurrence maps
+ l1 = torch.linalg.vector_norm(occurrences, ord=1, dim=self.reduce_dims).mean()
+ l1_norm = math.prod(occurrences.size()[1:]) # omit batch dimension
+ l1 = (l1 / l1_norm).mean()
+ return affine, l1
+
+ def _classification_loss(self, logits, targets):
+ preds = torch.argmax(logits, dim=1)
+
+ xentropy = self.criterion(logits, targets)
+ return xentropy, preds
+
+ def forward(self, batch: BatchWithLabelType):
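+ # a batch is (x, x_raw, aug, y); the network only consumes the image and tabular modalities of x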
+ x = batch[0]
+ image = x[self.image_modality]
+ tabular = x[ModalityType.TABULAR]
+
+ return self.net(image, tabular)
+
+ def on_train_epoch_start(self) -> None:
+ cur_stage = self._get_current_stage()
+ LOG.info("Epoch %d, optimizing %s", self.trainer.current_epoch, cur_stage)
+
+ self.log("train/stage", STAGE2FLOAT[cur_stage])
+
+ optim_warmup, optim_warmup_protonet, optim_all, optim_nam = self.optimizers()
+ scheduler_warmup, scheduler_all, scheduler_nam = self.lr_schedulers()
+ if cur_stage == "warmup":
+ opt = optim_warmup
+ sched = scheduler_warmup
+ elif cur_stage == "warmup_protonet":
+ opt = optim_warmup_protonet
+ sched = scheduler_warmup
+ elif cur_stage == "all":
+ opt = optim_all
+ sched = scheduler_all
+ elif cur_stage == "nam_only":
+ opt = optim_nam
+ sched = scheduler_nam
+ self.push_prototypes()
+ else:
+ raise AssertionError()
+
+ self._current_stage = cur_stage
+ self._current_optimizer = opt
+ self._current_scheduler = sched
+
+ def training_step(self, batch: BatchWithLabelType, batch_idx: int):
+ cur_stage = self._current_stage
+
+ x_raw, aug, y = batch[1:]
+ x_raw = x_raw[self.image_modality]
+ aug = aug[self.image_modality]
+
+ logits, similarities, occurrences, nam_terms = self.forward(batch)
+ xentropy, preds = self._classification_loss(logits, y)
+ self.log("train/xentropy", xentropy)
+ losses = [xentropy]
+
+ if cur_stage != "nam_only":
+ # cluster and separation cost
+ clst, sep = self._cluster_losses(similarities, y)
+ losses.append(self.hparams.l_clst * clst)
+ losses.append(self.hparams.l_sep * sep)
+
+ self.log("train/clst", clst)
+ self.log("train/sep", sep)
+
+ # regularization of occurrence map
+ affine, l1 = self._occurence_map_losses(occurrences, x_raw, aug)
+ losses.append(self.hparams.l_affine * affine)
+ losses.append(self.hparams.l_occ * l1)
+
+ self.log("train/affine", affine)
+ self.log("train/l1", l1)
+
+ if cur_stage != "occ_and_feats":
+ # l2 penalty on terms of nam
+ start_index = self.net.n_prototypes_per_class
+ nam_penalty = nam_terms[:, start_index:, :].square().sum(dim=1).mean()
+ losses.append(self.hparams.l_nam * nam_penalty)
+
+ self.log("train/nam_l2", nam_penalty)
+
+ loss = sum(losses)
+
+ opt = self._current_optimizer
+ opt.zero_grad()
+ self.manual_backward(loss)
+ opt.step()
+
+ self._current_scheduler.step()
+
+ self._log_train_metrics(loss, preds, y)
+
+ return {"loss": loss, "preds": preds, "targets": y}
+
+ def _log_train_metrics(
+ self, loss: torch.Tensor, preds: torch.Tensor, targets: torch.Tensor,
+ ) -> None:
+ self.log(f"train/loss/{self._current_stage}", loss, on_step=False, on_epoch=True, prog_bar=False)
+ acc = self.train_acc(preds, targets)
+ self.log("train/acc", acc, on_step=False, on_epoch=True, prog_bar=True)
+
+ def _get_current_stage(self, epoch=None):
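+ """Map the current epoch to a training stage: the first half of the warmup epochs
+ is "warmup", the second half "warmup_protonet"; afterwards training cycles through
+ `epochs_all` epochs of "all" followed by `epochs_nam` epochs of "nam_only"."""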
+ total = self.hparams.epochs_all + self.hparams.epochs_nam
+
+ stage = "nam_only"
+ if self.current_epoch < 0.5 * self.hparams.epochs_warmup:
+ stage = "warmup"
+ elif self.current_epoch < self.hparams.epochs_warmup:
+ stage = "warmup_protonet"
+ elif (self.current_epoch - self.hparams.epochs_warmup) % total < self.hparams.epochs_all:
+ stage = "all"
+
+ return stage
+
+ def training_epoch_end(self, outputs: List[Any]):
+ if self.hparams.enable_log_prototypes and isinstance(self.logger, TensorBoardLogger):
+ tb_logger = self.logger.experiment
+
+ tb_logger.add_histogram(
+ "train/prototypes", self.net.prototype_vectors, global_step=self.trainer.global_step,
+ )
+
+ return super().training_epoch_end(outputs)
+
+ def validation_step(self, batch: BatchWithLabelType, batch_idx: int):
+ logits, similarities, occurrences, nam_terms = self.forward(batch)
+
+ targets = batch[-1]
+ loss, preds = self._classification_loss(logits, targets)
+
+ self._update_validation_metrics(loss, preds, targets)
+
+ return {"loss": loss, "preds": preds, "targets": targets}
+
+ def validation_epoch_end(self, outputs: List[Any]):
+ # compute balanced accuracy
+ bacc = self._get_balanced_accuracy_from_confusion_matrix(self.val_cmat)
+
+ cur_stage = self._current_stage
+ # only epochs in the "nam_only" stage are considered when selecting the best checkpoint
+ if cur_stage == "nam_only":
+ self.val_bacc_save.update(bacc)
+ saver = bacc
+ else:
+ self.val_bacc_save.update(torch.tensor(0., dtype=torch.float32))
+ saver = torch.tensor(-float('inf'), dtype=torch.float32, device=bacc.device)
+ self.log("val/bacc_save_monitor", self.val_bacc_save.compute(), on_epoch=True)
+ self.log("val/bacc_save", saver)
+
+ self._log_validation_metrics()
+
+ def test_step(self, batch: BatchWithLabelType, batch_idx: int):
+ logits, similarities, occurrences, nam_terms = self.forward(batch)
+
+ targets = batch[-1]
+ loss, preds = self._classification_loss(logits, targets)
+
+ self._update_test_metrics(loss, preds, targets)
+
+ return {"loss": loss, "preds": preds, "targets": targets}
+
+ def test_epoch_end(self, outputs: List[Any]):
+ self._log_test_metrics()
+
+ def configure_optimizers(self):
+ """Choose what optimizers and learning-rate schedulers to use in your optimization.
+ Normally you'd need one. But in the case of GANs or similar you might have multiple.
+ See examples here:
+ https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
+ """
+
+ prototypes = {
+ 'params': [self.net.prototype_vectors],
+ 'lr': self.hparams.lr,
+ 'weight_decay': 0.0,
+ }
+ nam = {name: p for name, p in self.net.nam.named_parameters() if p.requires_grad}
+ embeddings = [nam.pop('tab_missing_embeddings')]
+ embeddings = {
+ 'params': embeddings,
+ 'lr': self.hparams.lr,
+ 'weight_decay': 0.0,
+ }
+ nam = {
+ 'params': list(nam.values()),
+ 'lr': self.hparams.lr,
+ 'weight_decay': self.hparams.weight_decay_nam,
+ }
+ encoder = {
+ 'params': [p for p in self.net.features.parameters() if p.requires_grad],
+ 'lr': self.hparams.lr,
+ 'weight_decay': self.hparams.weight_decay,
+ }
+ occurrence_and_features = {
+ 'params': [p for p in self.net.add_on_layers.parameters() if p.requires_grad] +
+ [p for p in self.net.occurrence_module.parameters() if p.requires_grad],
+ 'lr': self.hparams.lr,
+ 'weight_decay': self.hparams.weight_decay,
+ }
+ if len(encoder['params']) == 0:
+ warnings.warn("Encoder seems to be frozen! No parameters require grad.")
+ assert len(nam['params']) > 0
+ assert len(prototypes['params']) > 0
+ assert len(embeddings['params']) > 0
+ assert len(occurrence_and_features['params']) > 0
+ optim_all = torch.optim.AdamW([
+ encoder, occurrence_and_features, prototypes, nam, embeddings
+ ])
+ optim_nam = torch.optim.AdamW([
+ nam, embeddings
+ ])
+ optim_warmup = torch.optim.AdamW([
+ occurrence_and_features, prototypes
+ ])
+ optim_warmup_protonet = torch.optim.AdamW([
+ encoder, occurrence_and_features, prototypes
+ ])
+
+ training_iterations = len(self.trainer.datamodule.train_dataloader())
+ LOG.info("Number of iterations for one epoch: %d", training_iterations)
+
+ # stepping the scheduler of optim_warmup_protonet is sufficient, since it contains all parameter groups of the warmup optimizer
+ scheduler_kwargs = {'max_lr': self.hparams.lr, 'cycle_momentum': False}
+ scheduler_warmup = torch.optim.lr_scheduler.CyclicLR(
+ optim_warmup_protonet,
+ base_lr=self.hparams.lr / 20,
+ max_lr=self.hparams.lr / 10,
+ step_size_up=self.hparams.epochs_warmup * training_iterations,
+ cycle_momentum=False
+ )
+ scheduler_all = torch.optim.lr_scheduler.CyclicLR(
+ optim_all,
+ base_lr=self.hparams.lr / 10,
+ step_size_up=(self.hparams.epochs_all * training_iterations) / 2,
+ step_size_down=(self.hparams.epochs_all * training_iterations) / 2,
+ **scheduler_kwargs
+ )
+ scheduler_nam = torch.optim.lr_scheduler.CyclicLR(
+ optim_nam,
+ base_lr=self.hparams.lr / 10,
+ step_size_up=(self.hparams.epochs_nam * training_iterations) / 2,
+ step_size_down=(self.hparams.epochs_nam * training_iterations) / 2,
+ **scheduler_kwargs
+ )
+
+ return ([optim_warmup, optim_warmup_protonet, optim_all, optim_nam],
+ [scheduler_warmup, scheduler_all, scheduler_nam])
+
+ def push_prototypes(self):
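+ """Project each prototype onto the most similar feature vector of its own class
+ found in the push dataloader, and optionally save the corresponding inputs,
+ occurrence maps and embeddings to the log directory."""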
+ LOG.info("Pushing protoypes. epoch=%d, step=%d", self.current_epoch, self.trainer.global_step)
+
+ self.net.eval()
+
+ prototype_shape = self.net.prototype_shape
+ n_prototypes = self.net.num_prototypes
+
+ global_max_proto_dist = np.full(n_prototypes, -np.inf)
+ # global_max_proto_dist = np.ones(n_prototypes) * -1
+ global_max_fmap = np.zeros(prototype_shape)
+ global_img_indices = np.zeros(n_prototypes, dtype=int)
+ global_img_classes = -np.ones(n_prototypes, dtype=int)
+
+ if self.hparams.monitor_prototypes:
+ proto_epoch_dir = Path(self.trainer.log_dir) / f"prototypes_epoch_{self.current_epoch}"
+ else:
+ proto_epoch_dir = Path(self.trainer.log_dir) / "prototypes_best"
+ if self.hparams.enable_checkpointing:
+ proto_epoch_dir.mkdir(exist_ok=True)
+
+ push_dataloader = self.trainer.datamodule.push_dataloader()
+ search_batch_size = push_dataloader.batch_size
+
+ num_classes = self.net.num_classes
+
+ save_embedding = self.hparams.enable_save_embeddings and isinstance(self.logger, TensorBoardLogger)
+
+ # indicates which class a prototype belongs to
+ proto_class = torch.argmax(self.net.prototype_class_identity, dim=1).detach().cpu().numpy()
+
+ embedding_data = []
+ embedding_labels = []
+ for push_iter, (search_batch_input, _, _, search_y) in enumerate(push_dataloader):
+
+ start_index_of_search_batch = push_iter * search_batch_size
+
+ feature_vectors = self.update_prototypes_on_batch(
+ search_batch_input,
+ start_index_of_search_batch,
+ global_max_proto_dist,
+ global_max_fmap,
+ global_img_indices,
+ global_img_classes,
+ num_classes,
+ search_y,
+ proto_epoch_dir,
+ )
+
+ if save_embedding:
+ # for each batch, split into one feature vector for each prototype
+ embedding_data.extend(feature_vectors[:, j] for j in range(n_prototypes))
+ embedding_labels.append(np.repeat(proto_class, feature_vectors.shape[0]))
+
+ prototype_update = np.reshape(global_max_fmap, tuple(prototype_shape))
+
+ if self.hparams.enable_checkpointing:
+ np.save(proto_epoch_dir / f"p_similarities_{self.current_epoch}.npy", global_max_proto_dist)
+ np.save(proto_epoch_dir / f"p_feature_maps_{self.current_epoch}.npy", global_max_fmap)
+ np.save(proto_epoch_dir / f"p_inp_indices_{self.current_epoch}.npy", global_img_indices)
+ np.save(proto_epoch_dir / f"p_inp_img_labels_{self.current_epoch}.npy", global_img_classes)
+
+ if save_embedding:
+ tb_logger = self.logger.experiment
+
+ embedding_data.append(self.net.prototype_vectors.detach().cpu().numpy())
+ embedding_data.append(prototype_update)
+ embedding_labels = np.concatenate(embedding_labels)
+ metadata = [f"FV Class {i}" for i in embedding_labels]
+
+ metadata.extend(f"Old PV Class {i}" for i in proto_class)
+ metadata.extend(f"New PV Class {i}" for i in proto_class)
+
+ tb_logger.add_embedding(
+ mat=np.concatenate(embedding_data, axis=0),
+ metadata=metadata,
+ global_step=self.trainer.global_step,
+ tag="push_prototypes",
+ )
+
+ self.net.prototype_vectors.data.copy_(torch.tensor(prototype_update, dtype=torch.float32, device=self.device))
+
+ self.net.train()
+
+ def update_prototypes_on_batch(
+ self,
+ search_batch,
+ start_index_of_search_batch,
+ global_max_proto_dist,
+ global_max_fmap,
+ global_img_indices,
+ global_img_classes,
+ num_classes,
+ search_y,
+ proto_epoch_dir,
+ ):
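+ """Update the running per-prototype maxima (similarity, feature vector, input index
+ and class) with the feature vectors of a single batch; returns the batch's feature
+ vectors so they can be logged as embeddings."""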
+ self.net.eval()
+
+ feats, sims, occ = self.net.push_forward(
+ search_batch[self.image_modality].to(self.device),
+ search_batch[ModalityType.TABULAR].to(self.device))
+
+ feature_vectors = np.copy(feats.detach().cpu().numpy())
+ similarities = np.copy(sims.detach().cpu().numpy())
+ occurrences = np.copy(occ.detach().cpu().numpy())
+
+ del feats, sims, occ
+
+ class_to_img_index = {key: [] for key in range(num_classes)}
+ for img_index, img_y in enumerate(search_y):
+ img_label = img_y.item()
+ class_to_img_index[img_label].append(img_index)
+
+ prototype_shape = self.net.prototype_shape
+ n_prototypes = prototype_shape[0]
+
+ for j in range(n_prototypes):
+ target_class = torch.argmax(self.net.prototype_class_identity[j]).item()
+ if len(class_to_img_index[target_class]) == 0: # none of the images belongs to the class of this prototype
+ continue
+ proto_dist_j = similarities[class_to_img_index[target_class], j]
+ # similarities of all latent feature vectors of this class within the batch to the j-th prototype
+
+ batch_max_proto_dist_j = np.amax(proto_dist_j) # maximum similarity of latents of this batch to prototype j
+
+ if batch_max_proto_dist_j > global_max_proto_dist[j]: # update if this batch contains a more similar latent
+
+ img_index_in_class = np.argmax(proto_dist_j)
+ img_index_in_batch = class_to_img_index[target_class][img_index_in_class]
+
+ batch_max_fmap_j = feature_vectors[img_index_in_batch, j]
+
+ # latent vector with the highest similarity
+ global_max_proto_dist[j] = batch_max_proto_dist_j
+ global_max_fmap[j] = batch_max_fmap_j
+ global_img_indices[j] = img_index_in_batch + start_index_of_search_batch
+ global_img_classes[j] = search_y[img_index_in_batch].item()
+
+ if self.hparams.enable_checkpointing:
+
+ # original image
+ original_img_j = search_batch[self.image_modality][img_index_in_batch].detach().cpu().numpy()
+
+ # find highly activated region of the original image
+ proto_occ_j = occurrences[img_index_in_batch, j]
+
+ np.save(proto_epoch_dir / f"original_{j}_epoch_{self.current_epoch}.npy", original_img_j)
+ np.save(proto_epoch_dir / f"occurrence_{j}_epoch_{self.current_epoch}.npy", proto_occ_j)
+
+ return feature_vectors
diff --git a/torchpanic/modules/standard.py b/torchpanic/modules/standard.py
new file mode 100644
index 0000000..7061053
--- /dev/null
+++ b/torchpanic/modules/standard.py
@@ -0,0 +1,115 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+from functools import partial
+from typing import Any, List
+
+import torch
+from torch import nn
+
+from .base import BaseModule
+
+
+class StandardModule(BaseModule):
+ def __init__(
+ self,
+ net: nn.Module,
+ lr: float = 0.001,
+ num_classes: int = 3,
+ weight_decay: float = 0.0005,
+ output_penalty_weight: float = 0.001,
+ **kwargs,
+ ):
+ super().__init__(
+ net=net,
+ num_classes=num_classes,
+ )
+
+ # this line allows to access init params with 'self.hparams' attribute
+ # it also ensures init params will be stored in ckpt
+ self.save_hyperparameters(logger=False, ignore=["net"])
+
+ # loss function
+ if num_classes > 2:
+ self.criterion = torch.nn.CrossEntropyLoss()
+ else:
+ self.criterion = torch.nn.BCEWithLogitsLoss()
+
+ def step(self, batch: Any):
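+ """Run one forward pass and compute loss and hard predictions: binary problems use a
+ single squeezed logit with BCE-with-logits, multi-class problems use cross-entropy."""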
+ x, _, _, y = batch
+ logits, terms = self.forward(x)
+ if self.hparams.num_classes < 3:
+ logits = logits.squeeze()
+ probs = torch.sigmoid(logits)
+ preds = torch.argmax(torch.stack((1.0 - probs, probs), dim=-1), dim=-1)
+ else:
+ preds = torch.argmax(logits, dim=1)
+
+ loss = self.criterion(logits, y)
+ if self.hparams.output_penalty_weight > 0:
+ output_norm = torch.linalg.norm(terms, dim=[1, 2])
+ loss = loss + self.hparams.output_penalty_weight * output_norm.mean()
+ return loss, preds, y.long()
+
+ def training_step(self, batch: Any, batch_idx: int):
+ loss, preds, targets = self.step(batch)
+
+ # log train metrics
+ self._log_train_metrics(loss, preds, targets)
+
+ # we can return here dict with any tensors
+ # and then read it in some callback or in `training_epoch_end()` below
+ # remember to always return loss from `training_step()` or else backpropagation will fail!
+ return {"loss": loss, "preds": preds, "targets": targets}
+
+ def validation_step(self, batch: Any, batch_idx: int):
+ loss, preds, targets = self.step(batch)
+
+ self._update_validation_metrics(loss, preds, targets)
+
+ return {"loss": loss, "preds": preds, "targets": targets}
+
+ def validation_epoch_end(self, outputs: List[Any]):
+ self._log_validation_metrics()
+
+ def test_step(self, batch: Any, batch_idx: int):
+ loss, preds, targets = self.step(batch)
+
+ self._update_test_metrics(loss, preds, targets)
+
+ return {"loss": loss, "preds": preds, "targets": targets}
+
+ def test_epoch_end(self, outputs: List[Any]):
+ self._log_test_metrics()
+
+ def configure_optimizers(self):
+ """Choose what optimizers and learning-rate schedulers to use in your optimization.
+ Normally you'd need one. But in the case of GANs or similar you might have multiple.
+ See examples here:
+ https://pytorch-lightning.readthedocs.io/en/latest/common/lightning_module.html#configure-optimizers
+ """
+ optimizer = torch.optim.AdamW(
+ params=self.parameters(), lr=self.hparams.lr, weight_decay=self.hparams.weight_decay
+ )
+ n_iters = len(self.trainer.datamodule.train_dataloader())
+ scheduler = torch.optim.lr_scheduler.CyclicLR(
+ optimizer,
+ base_lr=self.hparams.lr / 10,
+ max_lr=self.hparams.lr,
+ step_size_up=5 * n_iters,
+ step_size_down=5 * n_iters,
+ cycle_momentum=False,
+ )
+
+ return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
diff --git a/torchpanic/modules/utils.py b/torchpanic/modules/utils.py
new file mode 100644
index 0000000..186087a
--- /dev/null
+++ b/torchpanic/modules/utils.py
@@ -0,0 +1,105 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+import math
+from pathlib import Path
+from typing import Union
+
+from hydra.utils import instantiate as hydra_init
+import numpy as np
+from omegaconf import DictConfig, OmegaConf
+import torch
+from yaml import safe_load as yaml_load
+
+import subprocess
+
+
+def format_float_to_str(x: float) -> str:
+ return "{:2.1f}".format(x * 100)
+
+
+def init_vector_normal(vector: torch.Tensor):
+ stdv = 1. / math.sqrt(vector.size(1))
+ vector.data.uniform_(-stdv, stdv)
+
+
+def get_current_stage(epoch, epochs_all, epochs_nam, warmup=10):
+ total = epochs_all + epochs_nam
+ stage = "nam_only"
+ if epoch < 0.5 * warmup:
+ stage = "warmup"
+ elif epoch < warmup:
+ stage = "warmup_protonet"
+ elif (epoch - warmup) % total < epochs_all:
+ stage = "all"
+ return stage
+
+
+def get_last_valid_checkpoint(path: Path):
+ epoch = int(path.stem.split("=")[1].split("-")[0])
+ epoch_old = int(path.stem.split("=")[1].split("-")[0])
+ config = load_config(path)
+ e_all = config.model.epochs_all
+ e_nam = config.model.epochs_nam
+ warmup = config.model.epochs_warmup
+
+ stage = get_current_stage(epoch, e_all, e_nam, warmup)
+ while stage == "nam_only":
+ epoch -= 1
+ stage = get_current_stage(epoch, e_all, e_nam, warmup)
+ if epoch != epoch_old:
+ epoch += 1
+ ckpt_path = str(path)
+ print(f"Previous epoch {epoch_old} was invalid. Valid checkpoint is of epoch {epoch}")
+ return ckpt_path.replace(f"epoch={epoch_old}", f"epoch={epoch}")
+
+
+def init_vectors_orthogonally(vector: torch.Tensor, n_protos_per_class: int):
+ # vector has shape (n_protos, n_chans)
+ assert vector.size(0) % n_protos_per_class == 0
+ torch.nn.init.xavier_uniform_(vector)
+
+ for j in range(vector.size(0)):
+ vector.data[j, j // n_protos_per_class] += 1.
+
+
+def load_config(ckpt_path: Union[str, Path]) -> DictConfig:
+ config_path = str(Path(ckpt_path).parent.parent / '.hydra' / 'config.yaml')
+ with open(config_path) as f:
+ y = yaml_load(f)
+ workdir = Path().absolute()
+ idx = workdir.parts.index('panic')
+ workdir = Path(*workdir.parts[:idx+1])
+ if 'protonet' in y['model']['net']:
+ y['model']['net']['protonet']['pretrained_model'] = \
+ y['model']['net']['protonet']['pretrained_model'].replace('${hydra:runtime.cwd}', str(workdir))
+ config = OmegaConf.create(y)
+ return config
+
+
+def load_model_and_data(ckpt_path: str, device=torch.device("cuda")):
+ '''Load the model and data corresponding to a checkpoint path.
+ You must call data.setup(stage) to set up the data.
+ The plain PyTorch model can be retrieved via model.net. '''
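+ # Example (hypothetical checkpoint path):
+ # model, data, config = load_model_and_data("outputs/run/checkpoints/epoch=12-step=3400.ckpt")
+ # data.setup("test")
+ # net = model.net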
+ config = load_config(Path(ckpt_path))
+ data = hydra_init(config.datamodule)
+ model = hydra_init(config.model)
+ model.load_state_dict(torch.load(ckpt_path, map_location=device)["state_dict"])
+ return model, data, config
+
+
+def get_git_hash():
+ return subprocess.check_output([
+ "git", "rev-parse", "HEAD"
+ ], encoding="utf8").strip()
diff --git a/torchpanic/testing.py b/torchpanic/testing.py
new file mode 100644
index 0000000..c726eb0
--- /dev/null
+++ b/torchpanic/testing.py
@@ -0,0 +1,83 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+import logging
+from typing import List
+
+import hydra
+from omegaconf import DictConfig
+import pytorch_lightning as pl
+
+LOG = logging.getLogger(__name__)
+
+
+def test(config: DictConfig):
+ # Set seed for random number generators in pytorch, numpy and python.random
+ if config.get("seed"):
+ pl.seed_everything(config.seed, workers=True)
+
+ LOG.info("Instantiating datamodule <%s>", config.datamodule._target_)
+ data: pl.LightningDataModule = hydra.utils.instantiate(config.datamodule)
+
+ LOG.info("Instantiating model <%s>", config.model._target_)
+ model: pl.LightningModule = hydra.utils.instantiate(config.model)
+
+ # Init lightning loggers
+ logger: List[pl.LightningLoggerBase] = []
+ if "logger" in config:
+ for lg_conf in config.logger.values():
+ if "_target_" in lg_conf:
+ LOG.info("Instantiating logger <%s>", lg_conf._target_)
+ logger.append(hydra.utils.instantiate(lg_conf))
+
+ LOG.info("Instantiating trainer <%s>", config.trainer._target_)
+ trainer: pl.Trainer = hydra.utils.instantiate(config.trainer, logger=logger)
+
+ # Log hyperparameters
+ if trainer.logger:
+ trainer.logger.log_hyperparams({"ckpt_path": config.ckpt_path})
+
+ LOG.info("Starting testing!")
+ return trainer.test(model, data, ckpt_path=config.ckpt_path)
+
+
+def validate(config: DictConfig):
+ # Set seed for random number generators in pytorch, numpy and python.random
+ if config.get("seed"):
+ pl.seed_everything(config.seed, workers=True)
+
+ LOG.info("Instantiating datamodule <%s>", config.datamodule._target_)
+ data: pl.LightningDataModule = hydra.utils.instantiate(config.datamodule)
+ data.setup("test")
+
+ LOG.info("Instantiating model <%s>", config.model._target_)
+ model: pl.LightningModule = hydra.utils.instantiate(config.model)
+
+ # Init lightning loggers
+ logger: List[pl.LightningLoggerBase] = []
+ if "logger" in config:
+ for lg_conf in config.logger.values():
+ if "_target_" in lg_conf:
+ LOG.info("Instantiating logger <%s>", lg_conf._target_)
+ logger.append(hydra.utils.instantiate(lg_conf))
+
+ LOG.info("Instantiating trainer <%s>", config.trainer._target_)
+ trainer: pl.Trainer = hydra.utils.instantiate(config.trainer, logger=logger)
+
+ # Log hyperparameters
+ if trainer.logger:
+ trainer.logger.log_hyperparams({"ckpt_path": config.ckpt_path})
+
+ LOG.info("Starting validating!")
+ return trainer.validate(model, data, ckpt_path=config.ckpt_path)
diff --git a/torchpanic/training.py b/torchpanic/training.py
new file mode 100644
index 0000000..eb55a39
--- /dev/null
+++ b/torchpanic/training.py
@@ -0,0 +1,60 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+import logging
+from typing import List
+
+import hydra
+from omegaconf import DictConfig
+import pytorch_lightning as pl
+import torch
+
+LOG = logging.getLogger(__name__)
+
+
+def train(config: DictConfig):
+ # Set seed for random number generators in pytorch, numpy and python.random
+ if config.get("seed"):
+ pl.seed_everything(config.seed, workers=True)
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+ torch.use_deterministic_algorithms(True)
+
+ LOG.info("Instantiating datamodule <%s>", config.datamodule._target_)
+ data: pl.LightningDataModule = hydra.utils.instantiate(config.datamodule)
+
+ LOG.info("Instantiating model <%s>", config.model._target_)
+ model: pl.LightningModule = hydra.utils.instantiate(config.model)
+
+ # Init lightning callbacks
+ callbacks: List[pl.Callback] = []
+ if "callbacks" in config:
+ for cb_conf in config.callbacks.values():
+ if "_target_" in cb_conf:
+ LOG.info("Instantiating callback <%s>", cb_conf._target_)
+ callbacks.append(hydra.utils.instantiate(cb_conf))
+
+ # Init lightning loggers
+ logger: List[pl.LightningLoggerBase] = []
+ if "logger" in config:
+ for lg_conf in config.logger.values():
+ if "_target_" in lg_conf:
+ LOG.info("Instantiating logger <%s>", lg_conf._target_)
+ logger.append(hydra.utils.instantiate(lg_conf))
+
+ LOG.info("Instantiating trainer <%s>", config.trainer._target_)
+ trainer: pl.Trainer = hydra.utils.instantiate(config.trainer, callbacks=callbacks, logger=logger)
+
+ LOG.info("Starting training!")
+ trainer.fit(model, data)
diff --git a/torchpanic/viz/__init__.py b/torchpanic/viz/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/torchpanic/viz/nam_functions.py b/torchpanic/viz/nam_functions.py
new file mode 100644
index 0000000..b47f506
--- /dev/null
+++ b/torchpanic/viz/nam_functions.py
@@ -0,0 +1,52 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+import argparse
+from pathlib import Path
+
+import matplotlib.pyplot as plt
+
+from .tabular import NamInspector, FunctionPlotter
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("checkpoint", type=Path)
+ parser.add_argument("-o", "--output", type=Path)
+
+ args = parser.parse_args()
+
+ inspector = NamInspector(args.checkpoint)
+ inspector.load()
+
+ data, samples, outputs = inspector.get_outputs()
+
+ weights_cat_linear = inspector.get_linear_weights(revert_standardization=True)
+
+ plotter = FunctionPlotter(
+ class_names=["CN", "MCI", "AD"],
+ log_scale=["Tau", "p-Tau"],
+ )
+
+ out_file = args.output
+ if out_file is None:
+ out_file = args.checkpoint.with_name('nam_functions.pdf')
+
+ fig = plotter.plot(data, samples, outputs, weights_cat_linear.drop("bias"))
+ fig.savefig(out_file, bbox_inches="tight", transparent=True)
+ plt.close()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/torchpanic/viz/summary.py b/torchpanic/viz/summary.py
new file mode 100644
index 0000000..66fab51
--- /dev/null
+++ b/torchpanic/viz/summary.py
@@ -0,0 +1,162 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+import matplotlib as mpl
+import matplotlib.pyplot as plt
+import numpy as np
+
+import seaborn as sns
+
+def summary_plot(
+ predictions,
+ feature_names=None,
+ max_display=20,
+ title="",
+ class_names=None,
+ class_inds=None,
+ colormap="Set2",
+ axis_color="black",
+ sort=True,
+ plot_size="auto",
+):
+ """Create a summary plot of shapley values for 1-dimensional features, colored by feature values when they are provided.
+ This method is based on Slundberg - SHAP (https://github.com/slundberg/shap/blob/master/shap/)
+ Args:
+ predictions : list of numpy.array
+ For each class, an array of per-feature contributions with shape (n_samples, n_features).
+ feature_names : list
+ Names of the features (length # features)
+ max_display : int
+ How many top features to include in the plot (default is 20)
+ plot_size : "auto" (default), float, (float, float), or None
+ What size to make the plot. By default the size is auto-scaled based on the number of
+ features that are being displayed. Passing a single float will cause each row to be that
+ many inches high. Passing a pair of floats will scale the plot by that
+ number of inches. If None is passed, matplotlib's default figure size is used.
+ """
+
+ multi_class = True
+
+ # default color:
+ if multi_class:
+ #cm = mpl.colormaps[colormap]
+ cm = sns.color_palette(colormap, n_colors=3)
+ #color = lambda i: cm(i)
+ color = lambda i: cm[(i + 2) % 3]
+
+ num_features = predictions[0].shape[1] if multi_class else predictions.shape[1]
+
+ shape_msg = (
+ "The shape of the shap_values matrix does not match the shape of the "
+ "provided data matrix."
+ )
+ if num_features - 1 == len(feature_names):
+ raise ValueError(
+ shape_msg
+ + " Perhaps the extra column in the shap_values matrix is the "
+ "constant offset? If so just pass shap_values[:,:-1]."
+ )
+ elif num_features != len(feature_names):
+ raise ValueError(shape_msg)
+
+ if sort:
+ # order features by the sum of their effect magnitudes
+ if multi_class:
+ feature_order = np.argsort(
+ np.sum(np.mean(np.abs(predictions), axis=0), axis=0)
+ )
+ else:
+ feature_order = np.argsort(np.sum(np.abs(predictions), axis=0))
+ feature_order = feature_order[-min(max_display, len(feature_order)) :]
+ else:
+ feature_order = np.flip(np.arange(min(max_display, num_features)), 0)
+
+ row_height = 0.4
+ if plot_size == "auto":
+ figsize = (8, len(feature_order) * row_height + 1.5)
+ elif isinstance(plot_size, (list, tuple)):
+ figsize = (plot_size[0], plot_size[1])
+ elif plot_size is not None:
+ figsize = (8, len(feature_order) * plot_size + 1.5)
+ else:
+ figsize = None # fall back to matplotlib's default figure size
+
+ fig, ax = plt.subplots(figsize=figsize)
+ ax.axvline(x=0, color="#999999", zorder=-1)
+
+ legend_handles = []
+ legend_text = []
+
+ if class_names is None:
+ class_names = [f"Class {i}" for i in range(len(predictions))]
+ feature_inds = feature_order[:max_display]
+ y_pos = np.arange(len(feature_inds))
+ left_pos = np.zeros(len(feature_inds))
+
+ if class_inds is None:
+ class_inds = np.argsort([-np.abs(p).mean() for p in predictions])
+ elif class_inds == "original":
+ class_inds = range(len(predictions))
+ for i, ind in enumerate(class_inds):
+ global_importance = np.abs(predictions[ind]).mean(0)
+ ax.barh(
+ y_pos, global_importance[feature_inds], 0.7, left=left_pos, align='center',
+ color=color(i), label=class_names[ind]
+ )
+ left_pos += global_importance[feature_inds]
+ f_names_relevant = [feature_names[i] for i in feature_inds]
+ ax.set_yticklabels(f_names_relevant)
+
+ ax.xaxis.set_ticks_position("top")
+ ax.xaxis.set_label_position("top")
+ ax.yaxis.set_ticks_position("none")
+ ax.spines["right"].set_visible(False)
+ ax.spines["bottom"].set_visible(False)
+ ax.spines["left"].set_visible(False)
+ ax.tick_params(color=axis_color, labelcolor=axis_color)
+
+ plt.yticks(
+ range(len(feature_order)),
+ [feature_names[i] for i in feature_order],
+ fontsize='large'
+ )
+ plt.ylim(-1, len(feature_order))
+
+ ax.set_title(title)
+ ax.tick_params(color=axis_color, labelcolor=axis_color)
+ ax.tick_params("y", length=20, width=0.5, which="major")
+ ax.tick_params("x") # , labelsize=11)
+ ax.xaxis.grid(True)
+ ax.set_xlabel("Mean Importance", fontsize='large')
+
+ # legend_handles.append(missing_handle)
+ # legend_text.append("Missing")
+ if len(legend_handles) > 0:
+ plt.legend(
+ legend_handles[::-1],
+ legend_text[::-1],
+ loc="center right",
+ bbox_to_anchor=(1.30, 0.55),
+ frameon=False,
+ )
+ else:
+ plt.legend(
+ loc="best",
+ frameon=True,
+ fancybox=True,
+ facecolor='white',
+ framealpha=1.0,
+ fontsize=12,
+ )
+
+ return fig, f_names_relevant
diff --git a/torchpanic/viz/tabular.py b/torchpanic/viz/tabular.py
new file mode 100644
index 0000000..9eb2d51
--- /dev/null
+++ b/torchpanic/viz/tabular.py
@@ -0,0 +1,429 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+from typing import Sequence, Union
+
+import matplotlib as mpl
+from matplotlib import gridspec
+from matplotlib.patches import Rectangle
+import matplotlib.pyplot as plt
+import numpy as np
+import pandas as pd
+import seaborn as sns
+import torch
+from torch.utils.data import DataLoader, Dataset
+
+from ..datamodule.adni import AdniDataset
+from ..datamodule.credit_card_fraud import PandasDataSet
+from ..models.panic import PANIC
+from ..models.nam import BaseNAM, NAM
+from ..modules.utils import load_model_and_data
+from ..datamodule.modalities import ModalityType
+from .utils import map_tabular_names
+
+
+def collect_tabular_data(dataset: Dataset) -> pd.DataFrame:
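+ """Collect the tabular modality (values and missingness mask) of every sample in the
+ dataset into a single pandas DataFrame backed by a numpy masked array."""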
+ data_values = []
+ data_miss = []
+ for i in range(len(dataset)):
+ x = dataset[i][0]
+ x_tab, x_miss = x[ModalityType.TABULAR]
+
+ if isinstance(x_tab, torch.Tensor):
+ x_tab = x_tab.detach().cpu().numpy() # shape = (bs, num_features)
+ if isinstance(x_miss, torch.Tensor):
+ x_miss = x_miss.detach().cpu().numpy() # shape = (bs, num_features)
+
+ data_values.append(x_tab)
+ data_miss.append(x_miss)
+
+ if isinstance(dataset, AdniDataset):
+ index = dataset.rid
+ column_names = map_tabular_names(dataset.column_names)
+ elif isinstance(dataset, PandasDataSet):
+ index, column_names = None, None
+ else:
+ raise ValueError(f"Dataset of type {type(dataset)} not implemented")
+
+ data = np.ma.array(
+ np.stack(data_values, axis=0),
+ mask=np.stack(data_miss, axis=0),
+ copy=False,
+ )
+ data = pd.DataFrame(data, index=index, columns=column_names)
+
+ return data
+
+
+def create_sample_data(tabular_data: pd.DataFrame, n_samples: int = 500) -> torch.Tensor:
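+ """Build a grid of synthetic values per feature for probing the additive feature
+ functions: all observed categories for categorical features (padded by repeating the
+ last category), and n_samples points between the 1st and 99th percentile otherwise."""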
+ samples = torch.empty((n_samples, tabular_data.shape[1]))
+ for j, (_, series) in enumerate(tabular_data.items()): # iteritems() was removed in pandas 2.0
+ if hasattr(series, "cat"):
+ uniq_values = torch.from_numpy(series.cat.categories.to_numpy())
+ values = torch.cat((
+ uniq_values,
+ # fill by repeating the last category
+ torch.full((n_samples - len(uniq_values),), uniq_values[-1])
+ ))
+ else:
+ q = series.quantile([0.01, 0.99])
+ values = torch.linspace(q.iloc[0], q.iloc[1], n_samples)
+ samples[:, j] = values
+ return samples
+
+
+def get_modality_shape(dataloader: DataLoader, modality: ModalityType):
+ dataset: Union[AdniDataset, PandasDataSet] = dataloader.dataset
+ return dataset[0][0][modality].shape
+
+
+def iter_datasets(datamodule):
+ for path in (
+ datamodule.train_data, datamodule.valid_data, datamodule.test_data,
+ ):
+ yield AdniDataset(path, is_training=False, modalities=datamodule.modalities)
+
+
+class NamInspector:
+ def __init__(
+ self, checkpoint_path: str, dataset_name: str = "test", device=torch.device("cuda"),
+ ) -> None:
+ self.checkpoint_path = checkpoint_path
+ self.dataset_name = dataset_name
+ self.device = device
+
+ def load(self):
+ plmodule, datamod, config = load_model_and_data(self.checkpoint_path)
+ datamod.setup("fit" if self.dataset_name != "test" else self.dataset_name)
+
+ data = pd.concat([collect_tabular_data(ds) for ds in iter_datasets(datamod)])
+ idx_cat = config.model.net.nam.idx_cat_features
+ for j in idx_cat:
+ data.iloc[:, j] = data.iloc[:, j].astype("category")
+
+ self._data = data
+ self._model = plmodule.net.eval().to(self.device)
+ self._config = config
+ self._datamod = datamod
+
+ def _get_dataloader(self):
+ dsnames = {
+ "fit": self._datamod.push_dataloader,
+ "val": self._datamod.val_dataloader,
+ "test": self._datamod.test_dataloader,
+ }
+ return dsnames[self.dataset_name]()
+
+ @torch.no_grad()
+ def _collect_outputs(
+ self, samples_loader: DataLoader, include_missing: bool,
+ ) -> Sequence[np.ndarray]:
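+ """Feed the sample grid through the model with an all-zero image and a 'not missing'
+ mask, and collect only the NAM outputs that belong to the tabular features (skipping
+ the prototype terms). Optionally append the outputs for an all-missing input."""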
+ dataloader = self._get_dataloader()
+
+ non_missing = torch.zeros((1, 1, self._data.shape[1]), device=self.device)
+ img_data = torch.zeros(
+ get_modality_shape(dataloader, ModalityType.PET), device=self.device
+ ).unsqueeze(0)
+
+ n_prototypes = self._config.model.net.protonet.n_prototypes_per_class
+
+ model: PANIC = self._model
+ outputs = []
+ for x in samples_loader:
+ x = x.to(self.device).unsqueeze(1)
+ tab_in = torch.cat((x, non_missing.expand_as(x)), axis=1)
+ img_in = img_data.expand(x.shape[0], -1, -1, -1, -1)
+
+ logits, similarities, occurrences, nam_features_without_dropout = model(img_in, tab_in)
+ # only collect outputs referring to tabular features
+ feature_outputs = nam_features_without_dropout[:, n_prototypes:]
+
+ outputs.append(feature_outputs.detach().cpu().numpy())
+
+ if include_missing:
+ tab_in = torch.ones_like(tab_in)[:1]
+ img_in = img_in[:1]
+ logits, similarities, occurrences, nam_features_without_dropout = model(
+ img_in, tab_in
+ )
+ feature_outputs = nam_features_without_dropout[:, n_prototypes:]
+ outputs.append(feature_outputs.detach().cpu().numpy())
+
+ return outputs
+
+ def _apply_inverse_transform(self, data):
+ dataloader = self._get_dataloader()
+
+ # invert standardize transform to restore original feature distributions
+ dataset: AdniDataset = dataloader.dataset
+ if isinstance(data, pd.DataFrame):
+ data_original = pd.DataFrame(
+ dataset.tabular_inverse_transform(data.values),
+ index=data.index, columns=data.columns,
+ )
+ for col in self._data.select_dtypes(include=["category"]).columns:
+ vals = data_original.loc[:, col].apply(np.rint) # round to nearest integer
+ data_original.loc[:, col] = vals.astype("category")
+ else:
+ data_original = dataset.tabular_inverse_transform(data)
+ return data_original
+
+ def get_outputs(self, plt_embeddings=False):
+ samples = create_sample_data(self._data)
+ samples_loader = DataLoader(samples, batch_size=self._config.datamodule.batch_size)
+
+ outputs = self._collect_outputs(samples_loader, include_missing=plt_embeddings)
+
+ samples = np.asfortranarray(samples)
+ data_original = self._apply_inverse_transform(self._data)
+ samples_original = self._apply_inverse_transform(samples)
+
+ if plt_embeddings:
+ samples_original = np.row_stack(
+ (samples_original, 99 * np.ones(samples_original.shape[1]))
+ )
+ missing_row = pd.Series(99, index=data_original.columns, name="MISSING")
+ # DataFrame.append was removed in pandas 2.0; use concat instead
+ data_original = pd.concat([data_original, missing_row.to_frame().T])
+
+ outputs = np.concatenate(outputs, axis=0)
+ outputs = np.asfortranarray(outputs)
+
+ # reorder columns so they are in the same order as the output of the model
+ nam_config = self._config.model.net.nam
+ idx = nam_config["idx_real_features"] + nam_config["idx_cat_features"]
+ data_original = data_original.iloc[:, idx]
+ samples_original = samples_original[:, idx]
+
+ return data_original, samples_original, outputs
+
+ def get_linear_weights(self, revert_standardization: bool) -> pd.DataFrame:
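+ """Return the bias and the linear weights of the categorical features as a DataFrame
+ (rows: "bias" plus feature names, columns: classes). With revert_standardization, the
+ coefficients are rescaled to the original feature scale: w' = w / std, b' = b - w'·mean."""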
+ nam_model: BaseNAM = self._model.nam
+ bias = nam_model.bias.detach().cpu().numpy()
+ weights = nam_model.cat_linear.detach().cpu().numpy()
+
+ dataloader = self._get_dataloader()
+ dataset: AdniDataset = dataloader.dataset
+ # reorder columns so they are in the same order as the output of the model
+ nam_config = self._config.model.net.nam
+ cat_idx = nam_config["idx_cat_features"]
+ columns = ["bias"] + dataset.column_names[cat_idx].tolist()
+
+ if revert_standardization:
+ mean = dataset.tabular_mean[cat_idx]
+ std = dataset.tabular_stddev[cat_idx]
+
+ weights_new = np.empty_like(weights)
+ bias_new = np.empty_like(bias)
+ for k in range(weights.shape[1]):
+ weights_new[:, k] = weights[:, k] / std
+ bias_new[:, k] = bias[:, k] - np.dot(weights_new[:, k], mean)
+ else:
+ bias_new = bias
+ weights_new = weights
+
+ coef = pd.DataFrame(
+ np.row_stack((bias_new, weights_new)), index=map_tabular_names(columns),
+ )
+ coef.index.name = "feature"
+ coef.columns.name = "target"
+ return coef
+
+
+def get_fraudnet_outputs_from_checkpoint(checkpoint_path, device=torch.device("cuda")):
+ plmodule, data, config = load_model_and_data(checkpoint_path)
+ data.setup("fit")
+
+ data = collect_tabular_data(data.train_dataloader())
+ samples = create_sample_data(data)
+ samples_loader = DataLoader(samples, batch_size=config.datamodule.batch_size)
+ non_missing = torch.zeros((1, 1, data.shape[1]), device=device)
+
+ model: NAM = plmodule.net.eval().to(device)
+ outputs = []
+
+ with torch.no_grad():
+ for x in samples_loader:
+ x = x.to(device).unsqueeze(1)
+ x = torch.cat((x, non_missing.expand_as(x)), axis=1)
+ logits, nam_features = model.base_forward(x)
+ outputs.append(nam_features.detach().cpu().numpy())
+ outputs = np.concatenate(outputs)
+
+ samples = np.asfortranarray(samples)
+ outputs = np.asfortranarray(outputs)
+
+ return data, samples, outputs
+
+
+class FunctionPlotter:
+ def __init__(
+ self,
+ class_names: Sequence[str],
+ log_scale: Sequence[str],
+ n_cols: int = 6,
+ size: float = 2.5,
+ ) -> None:
+ self.class_names = class_names
+ self.log_scale = frozenset(log_scale)
+ self.n_cols = n_cols
+ self.size = size
+
+ def plot(
+ self,
+ data: pd.DataFrame,
+ samples: np.ndarray,
+ outputs: np.ndarray,
+ categorial_coefficients: pd.Series,
+ ):
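+ """Plot, for every feature, the log-odds ratio of each class against the reference
+ class (the first entry of class_names) over the sampled feature values, together with
+ the observed distribution of that feature (a boxplot for continuous features, relative
+ category frequencies for categorical ones)."""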
+ assert data.shape[1] == samples.shape[1]
+ assert data.shape[1] == outputs.shape[1]
+ assert samples.shape[0] == outputs.shape[0]
+ assert samples.ndim == 2
+ assert outputs.ndim == 3
+ assert len(self.class_names) == outputs.shape[2]
+
+ assert len(categorial_coefficients.index.difference(data.columns)) == 0
+
+ rc_params = {
+ "axes.titlesize": "small",
+ "xtick.labelsize": "x-small",
+ "ytick.labelsize": "x-small",
+ "lines.linewidth": 2,
+ }
+
+ with mpl.rc_context(rc_params):
+ return self._plot_functions(data, samples, outputs, categorial_coefficients)
+
+ def _plot_functions(
+ self,
+ data: pd.DataFrame,
+ samples: np.ndarray,
+ outputs: np.ndarray,
+ categorial_coefficients: pd.Series,
+ ):
+ categorical_columns = frozenset(categorial_coefficients.index)
+
+ n_cols = self.n_cols
+ n_features = data.shape[1]
+ n_rows = int(np.ceil(n_features / n_cols))
+
+ fig = plt.figure(
+ figsize=(n_cols * self.size, n_rows * self.size)
+ )
+ gs_outer = gridspec.GridSpec(
+ n_rows, n_cols, figure=fig, hspace=0.35, wspace=0.3,
+ )
+
+ n_odds_ratios = outputs.shape[2] - 1
+ palette = sns.color_palette("Set2", n_colors=n_odds_ratios)
+
+ ref_class_name = self.class_names[0]
+ legend_data = {
+ f"{name} vs {ref_class_name}": c
+ for name, c in zip(self.class_names[1:], palette)
+ }
+
+ ax_legend = 0
+ h_ratios = [5, 1]
+ for idx, (name, values) in enumerate(data.items()): # iteritems() was removed in pandas 2.0
+ i = idx // n_cols
+ j = idx % n_cols
+ gs = gs_outer[i, j].subgridspec(2, 1, height_ratios=h_ratios, hspace=0.1)
+ ax_top = fig.add_subplot(gs[0, 0])
+ ax_bot = fig.add_subplot(gs[1, 0], sharex=ax_top)
+ if idx == data.shape[1] - 1:
+ ax_legend = ax_top
+
+ values = values.dropna()
+ for cls_idx in range(1, n_odds_ratios + 1):
+ if name in categorical_columns:
+ plot_fn = self._plot_categorical
+ coef_ref = categorial_coefficients.loc[name, 0]
+ coef_new = categorial_coefficients.loc[name, cls_idx]
+ x = data.loc[:, name].cat.categories.to_numpy()
+ y = (coef_new - coef_ref) * (x - x[0]) # log-odds ratio
+ else:
+ plot_fn = self._plot_continuous
+ x = samples[:, idx]
+ mid = np.searchsorted(x, data.loc[:, name].mean())
+ log_odds_mid = outputs[mid, idx, cls_idx] - outputs[mid, idx, 0]
+ log_odds = outputs[:, idx, cls_idx] - outputs[:, idx, 0]
+ y = log_odds - log_odds_mid # log-odds ratio
+
+ plot_fn(
+ x=x,
+ y=y,
+ values=values if cls_idx == 1 else None,
+ ax_top=ax_top,
+ ax_bot=ax_bot,
+ color=palette[cls_idx - 1],
+ label=f"Class {cls_idx}",
+ )
+
+ ax_top.axhline(0.0, color="gray")
+ ax_top.tick_params(axis="x", labelbottom=False)
+ ax_top.set_title(name)
+ if name in self.log_scale:
+ ax_top.set_xscale("log")
+
+ if j == 0:
+ ax_top.set_ylabel("log odds ratio")
+
+ legend_kwargs = {"loc": "center left", "bbox_to_anchor": (1.05, 0.5)}
+ self._add_legend(ax_legend, legend_data, legend_kwargs)
+
+ return fig
+
+ def _add_legend(self, ax, palette, legend_kwargs=None):
+ handles = []
+ labels = []
+ for name, color in palette.items():
+ p = Rectangle((0, 0), 1, 1)
+ p.set_facecolor(color)
+ handles.append(p)
+ labels.append(name)
+
+ if legend_kwargs is None:
+ legend_kwargs = {"loc": "best"}
+
+ ax.legend(handles, labels, **legend_kwargs)
+
+ def _plot_continuous(self, x, y, values, ax_top, ax_bot, color, label):
+ ax_top.plot(x, y, marker="none", color=color, label=label, zorder=2.5)
+ ax_top.grid(True)
+
+ if values is not None:
+ ax_bot.boxplot(
+ values.dropna(),
+ vert=False,
+ widths=0.75,
+ showmeans=True,
+ medianprops={"color": "#ff7f00"},
+ meanprops={
+ "marker": "d", "markeredgecolor": "#a65628", "markerfacecolor": "none",
+ },
+ flierprops={"marker": "."},
+ showfliers=False,
+ )
+ ax_bot.yaxis.set_visible(False)
+
+ def _plot_categorical(self, x, y, values, ax_top, ax_bot, color, label):
+ ax_top.step(x, y, where="mid", color=color, label=label, zorder=2.5)
+ ax_top.grid(True)
+
+ if values is not None:
+ _, counts = np.unique(values, return_counts=True)
+ assert len(counts) == len(x)
+ ax_bot.bar(x, height=counts / counts.sum() * 100., width=0.6, color="dimgray")
diff --git a/torchpanic/viz/utils.py b/torchpanic/viz/utils.py
new file mode 100644
index 0000000..52e74ed
--- /dev/null
+++ b/torchpanic/viz/utils.py
@@ -0,0 +1,106 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+from typing import Any, Dict, Sequence
+
+import numpy as np
+import pandas as pd
+import torch
+from torch.utils.data import DataLoader
+
+from ..datamodule.modalities import ModalityType
+from ..models.panic import PANIC
+
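+# Module-level cache of the feature-name mapping; populated lazily by _set_name_mapper().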
+_NAME_MAPPER = None
+
+
+@torch.no_grad()
+def get_tabnet_predictions(model: PANIC, data_loader: DataLoader, device=torch.device("cuda")):
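+ """Run ``model`` over ``data_loader`` and collect its additive outputs.
+
+ Returns a tuple ``(outputs, logits, y_pred)``: ``outputs`` stacks the model bias and the
+ per-feature NAM outputs for every sample, ``logits`` are the raw class scores, and
+ ``y_pred`` is the arg-max predicted class per sample.
+ """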
+ model = model.eval().to(device)
+ bias = model.nam.bias.detach().cpu().numpy()[np.newaxis]
+
+ all_logits = []
+ predictions = []
+ for x, x_raw, aug, y in data_loader:
+ logits, similarities, occurrences, nam_features_without_dropout = model(
+ x[ModalityType.PET].to(device),
+ x[ModalityType.TABULAR].to(device),
+ )
+ nam_features_without_dropout = nam_features_without_dropout.detach().cpu().numpy()
+ outputs = np.concatenate((
+ np.tile(bias, [logits.shape[0], 1, 1]),
+ nam_features_without_dropout,
+ ), axis=1)
+ predictions.append(outputs)
+ all_logits.append(logits.detach().cpu().numpy())
+
+ all_logits = np.concatenate(all_logits)
+ y_pred = np.argmax(all_logits, axis=1)
+ return np.concatenate(predictions), all_logits, y_pred
+
+
+def _set_name_mapper(metadata: str):
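+ """Build the mapping from raw tabular column names to display names, mapping SNP keys to their rsIDs."""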
+ global _NAME_MAPPER
+
+ name_mapper = {
+ "real_age": "Age",
+ "PTEDUCAT": "Education",
+ "PTGENDER": "Male",
+ "ABETA": "A$\\beta$",
+ "TAU": "Tau",
+ "PTAU": "p-Tau",
+ "Left-Hippocampus": "Left Hippocampus",
+ "Right-Hippocampus": "Right Hippocampus",
+ "lh_entorhinal_thickness": "Left Entorhinal Cortex",
+ "rh_entorhinal_thickness": "Right Entorhinal Cortex",
+ }
+ snp_metadata = pd.read_csv(
+ metadata, index_col=0,
+ )
+ snp_key = snp_metadata.agg(
+ lambda x: ":".join(map(str, x[
+ ["Chromosome", "Position", "Allele2_reference", "Allele1_alternative"]
+ ])) + "_" + x["Allele1_alternative"],
+ axis=1
+ )
+ name_mapper.update(dict(zip(snp_key, snp_metadata.loc[:, "rsid"])))
+ _NAME_MAPPER = name_mapper
+
+
+def _get_name_mapper(metadata: str):
+ global _NAME_MAPPER
+
+ if _NAME_MAPPER is None:
+ _set_name_mapper(metadata)
+ return _NAME_MAPPER
+
+
+def map_tabular_names(names: Sequence[str], metadata_file: str) -> Sequence[str]:
+ name_mapper = _get_name_mapper(metadata_file)
+ return [name_mapper.get(name, name) for name in names]
+
+
+def get_tabnet_output_names(
+ tabular_names: Sequence[str], n_prototypes: int, config: Dict[str, Any], metadata_file: str
+) -> Sequence[str]:
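+ """Return display names for each model output: bias, FDG-PET prototypes, then tabular features.
+
+ Tabular feature names are listed in the order produced by the NAM, i.e. real-valued
+ features first, followed by categorical ones.
+ """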
+ name_mapper = _get_name_mapper(metadata_file)
+
+ # reorder columns so they are in the same order as the output of the model
+ nam_config = config["model"]["net"]["nam"]
+ order = nam_config["idx_real_features"] + nam_config["idx_cat_features"]
+
+ features_names = [name_mapper.get(tabular_names[i], tabular_names[i]) for i in order]
+
+ prototype_names = [f"FDG-PET Proto {i}" for i in range(n_prototypes)]
+ return ["bias"] + prototype_names + features_names
diff --git a/torchpanic/viz/waterfall.py b/torchpanic/viz/waterfall.py
new file mode 100644
index 0000000..6d35af3
--- /dev/null
+++ b/torchpanic/viz/waterfall.py
@@ -0,0 +1,342 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+import re
+
+import matplotlib.pyplot as plt
+from matplotlib import colors
+from matplotlib.transforms import ScaledTranslation
+import numpy as np
+
+
+def format_value(s, format_str):
+ """ Strips trailing zeros and uses a unicode minus sign.
+ """
+ if np.ma.isMaskedArray(s) and np.ma.is_masked(s):
+ return "missing"
+
+ if not isinstance(s, str):
+ s = format_str % s
+ s = re.sub(r'\.?0+$', '', s)
+ if s[0] == "-":
+ s = u"\u2212" + s[1:]
+ return s
+
+
+class WaterfallPlotter:
+ def __init__(self, n_prototypes: int) -> None:
+ self.n_prototypes = n_prototypes
+
+ def plot(
+ self,
+ expected_value,
+ shap_values,
+ features=None,
+ feature_names=None,
+ actual_class_name=None,
+ predicted_class_name=None,
+ max_display=10,
+ ):
+ """ Plots an explantion of a single prediction as a waterfall plot.
+ The SHAP value of a feature represents the impact of the evidence provided by that feature on the model's
+ output. The waterfall plot is designed to visually display how the SHAP values (evidence) of each feature
+ move the model output from our prior expectation under the background data distribution, to the final model
+ prediction given the evidence of all the features. Features are sorted by the magnitude of their SHAP values
+ with the smallest magnitude features grouped together at the bottom of the plot when the number of features
+ in the models exceeds the max_display parameter.
+ Parameters
+ ----------
+ expected_value : float
+ This is the reference value that the feature contributions start from. For SHAP values it should
+ be the value of explainer.expected_value.
+ shap_values : numpy.array
+ One dimensional array of SHAP values.
+ features : numpy.array
+ One dimensional array of feature values. This provides the values of all the
+ features, and should be the same shape as the shap_values argument.
+ feature_names : list
+ List of feature names (# features).
+ actual_class_name : str
+ Name of actual class.
+ predicted_class_name : str
+ Name of predicted class.
+ max_display : int
+ The maximum number of features to plot.
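+
+ Returns
+ -------
+ matplotlib.figure.Figure
+ The figure containing the waterfall plot.
+
+ Examples
+ --------
+ A minimal, illustrative call (all values below are made up for the example):
+
+ >>> import numpy as np
+ >>> plotter = WaterfallPlotter(n_prototypes=2)
+ >>> fig = plotter.plot(
+ ... expected_value=0.1,
+ ... shap_values=np.array([0.4, -0.2, 0.05]),
+ ... features=np.array([1.0, 0.0, 73.0]),
+ ... feature_names=["FDG-PET Proto 0", "FDG-PET Proto 1", "Age"],
+ ... predicted_class_name="AD",
+ ... )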
+ """
+
+ # init variables we use for tracking the plot locations
+ num_features = min(max_display, len(shap_values))
+ row_height = 0.2
+ rng = range(num_features - 1, -1, -1)
+ order = np.argsort(-np.abs(shap_values))
+ pos_lefts = []
+ pos_inds = []
+ pos_widths = []
+ neg_lefts = []
+ neg_inds = []
+ neg_widths = []
+ loc = expected_value + shap_values.sum()
+ yticklabels = ["" for i in range(num_features + 1)]
+
+ # size the plot based on how many features we are plotting
+ fig = plt.figure(figsize=(8, num_features * row_height + 1.5))
+
+ # see how many individual (vs. grouped at the end) features we are plotting
+ if num_features == len(shap_values):
+ num_individual = num_features
+ else:
+ num_individual = num_features - 1
+
+ # compute the locations of the individual features and plot the dashed connecting lines
+ for i in range(num_individual):
+ sval = shap_values[order[i]]
+ loc -= sval
+ if sval >= 0:
+ pos_inds.append(rng[i])
+ pos_widths.append(sval)
+ pos_lefts.append(loc)
+ else:
+ neg_inds.append(rng[i])
+ neg_widths.append(sval)
+ neg_lefts.append(loc)
+ if num_individual != num_features or i + 4 < num_individual:
+ plt.plot(
+ [loc, loc],
+ [rng[i] - 1 - 0.4, rng[i] + 0.4],
+ color="#bbbbbb",
+ linestyle="--",
+ linewidth=0.5,
+ zorder=-1,
+ )
+ if features is None or order[i] < self.n_prototypes:
+ yticklabels[rng[i]] = feature_names[order[i]]
+ else:
+ yticklabels[rng[i]] = (
+ feature_names[order[i]]
+ + " = "
+ + format_value(features[order[i]], "%0.03f")
+ )
+
+ # add a last grouped feature to represent the impact of all the features we didn't show
+ if num_features < len(shap_values):
+ yticklabels[0] = "%d other features" % (len(shap_values) - num_features + 1)
+ remaining_impact = expected_value - loc
+ if remaining_impact < 0:
+ pos_inds.append(0)
+ pos_widths.append(-remaining_impact)
+ pos_lefts.append(loc + remaining_impact)
+ else:
+ neg_inds.append(0)
+ neg_widths.append(-remaining_impact)
+ neg_lefts.append(loc + remaining_impact)
+
+ points = (
+ pos_lefts
+ + list(np.array(pos_lefts) + np.array(pos_widths))
+ + neg_lefts
+ + list(np.array(neg_lefts) + np.array(neg_widths))
+ )
+ dataw = np.max(points) - np.min(points)
+
+ # draw invisible bars just for sizing the axes
+ label_padding = np.array([0.1 * dataw if w < 1 else 0 for w in pos_widths])
+ plt.barh(
+ pos_inds,
+ np.array(pos_widths) + label_padding + 0.02 * dataw,
+ left=np.array(pos_lefts) - 0.01 * dataw,
+ color=colors.to_rgba_array("r"),
+ alpha=0,
+ )
+ label_padding = np.array([-0.1 * dataw if -w < 1 else 0 for w in neg_widths])
+ plt.barh(
+ neg_inds,
+ np.array(neg_widths) + label_padding - 0.02 * dataw,
+ left=np.array(neg_lefts) + 0.01 * dataw,
+ color=colors.to_rgba_array("b"),
+ alpha=0,
+ )
+
+ # define variable we need for plotting the arrows
+ head_length = 0.08
+ bar_width = 0.8
+ xlen = plt.xlim()[1] - plt.xlim()[0]
+ fig = plt.gcf()
+ ax = plt.gca()
+ xticks = ax.get_xticks()
+ bbox = ax.get_window_extent().transformed(fig.dpi_scale_trans.inverted())
+ width, height = bbox.width, bbox.height
+ bbox_to_xscale = xlen / width
+ hl_scaled = bbox_to_xscale * head_length
+ renderer = fig.canvas.get_renderer()
+
+ # draw the positive arrows
+ for i in range(len(pos_inds)):
+ dist = pos_widths[i]
+ arrow_obj = plt.arrow(
+ pos_lefts[i],
+ pos_inds[i],
+ max(dist - hl_scaled, 0.000001),
+ 0,
+ head_length=min(dist, hl_scaled),
+ color="tab:red",
+ width=bar_width,
+ head_width=bar_width,
+ )
+
+ txt_obj = plt.text(
+ pos_lefts[i] + 0.5 * dist,
+ pos_inds[i],
+ format_value(pos_widths[i], "%+0.02f"),
+ horizontalalignment="center",
+ verticalalignment="center",
+ color="white",
+ fontsize=12,
+ )
+ text_bbox = txt_obj.get_window_extent(renderer=renderer)
+ arrow_bbox = arrow_obj.get_window_extent(renderer=renderer)
+
+ # if the text overflows the arrow then draw it after the arrow
+ if text_bbox.width > arrow_bbox.width:
+ txt_obj.remove()
+
+ # draw the negative arrows
+ for i in range(len(neg_inds)):
+ dist = neg_widths[i]
+
+ arrow_obj = plt.arrow(
+ neg_lefts[i],
+ neg_inds[i],
+ -max(-dist - hl_scaled, 0.000001),
+ 0,
+ head_length=min(-dist, hl_scaled),
+ color="tab:blue",
+ width=bar_width,
+ head_width=bar_width,
+ )
+
+ txt_obj = plt.text(
+ neg_lefts[i] + 0.5 * dist,
+ neg_inds[i],
+ format_value(neg_widths[i], "%+0.02f"),
+ horizontalalignment="center",
+ verticalalignment="center",
+ color="white",
+ fontsize=12,
+ )
+ text_bbox = txt_obj.get_window_extent(renderer=renderer)
+ arrow_bbox = arrow_obj.get_window_extent(renderer=renderer)
+
+ # if the text overflows the arrow then draw it after the arrow
+ if text_bbox.width > arrow_bbox.width:
+ txt_obj.remove()
+
+ # draw the y-tick labels with the feature names
+ # (the 1e-8 offsets on the secondary x-axes below keep matplotlib 3.3 from collapsing near-identical ticks)
+ ytick_pos = np.arange(num_features)
+ plt.yticks(ytick_pos, yticklabels[:-1], fontsize=13)
+
+ # put horizontal lines for each feature row
+ for i in range(num_features):
+ plt.axhline(i, color="#cccccc", lw=0.5, dashes=(1, 5), zorder=-1)
+
+ # mark the prior expected value and the model prediction
+ plt.axvline(
+ expected_value,
+ color="#bbbbbb",
+ linestyle="--",
+ linewidth=0.85,
+ zorder=-1,
+ )
+ fx = expected_value + shap_values.sum()
+ plt.axvline(fx, color="#bbbbbb", linestyle="--", linewidth=0.85, zorder=-1)
+
+ # clean up the main axis
+ plt.gca().xaxis.set_ticks_position("bottom")
+ plt.gca().yaxis.set_ticks_position("none")
+ plt.gca().spines["right"].set_visible(False)
+ plt.gca().spines["top"].set_visible(False)
+ plt.gca().spines["left"].set_visible(False)
+ ax.tick_params(labelsize=13)
+ # plt.xlabel("Model output", fontsize=12)
+
+ # draw the E[f(X)] tick mark
+ xmin, xmax = ax.get_xlim()
+ xticks = ax.get_xticks()
+ xticks = list(xticks)
+ min_ind = 0
+ min_diff = 1e10
+ for i in range(len(xticks)):
+ v = abs(xticks[i] - expected_value)
+ if v < min_diff:
+ min_diff = v
+ min_ind = i
+ xticks.pop(min_ind)
+ ax.set_xticks(xticks)
+ ax.tick_params(labelsize=13)
+ ax.set_xlim(xmin, xmax)
+
+ ax2 = ax.twiny()
+ ax2.set_xlim(xmin, xmax)
+ ax2.set_xticks([expected_value, expected_value + 1e-8])
+ ax2.set_xticklabels(["\nbias", "\n$ = " + format_value(expected_value, "%0.03f") + "$"], fontsize=12, ha="left")
+ ax2.spines["right"].set_visible(False)
+ ax2.spines["top"].set_visible(False)
+ ax2.spines["left"].set_visible(False)
+
+ # draw the f(x) tick mark
+ ax3 = ax2.twiny()
+ ax3.set_xlim(xmin, xmax)
+ ax3.set_xticks([fx, fx + 1e-8])
+ the_class = "^{\\mathrm{%s}}" % predicted_class_name if predicted_class_name is not None else ""
+ ax3.set_xticklabels(
+ [f"$\\mu{the_class}$", "$ = " + format_value(fx, "%0.03f") + "$"], fontsize=12, ha="left"
+ )
+ tick_labels = ax3.xaxis.get_majorticklabels()
+ tick_labels[0].set_transform(
+ tick_labels[0].get_transform()
+ + ScaledTranslation(-10 / 72.0, 0, fig.dpi_scale_trans)
+ )
+ tick_labels[1].set_transform(
+ tick_labels[1].get_transform()
+ + ScaledTranslation(12 / 72.0, 0, fig.dpi_scale_trans)
+ )
+ # tick_labels[1].set_color("#999999")
+ ax3.spines["right"].set_visible(False)
+ ax3.spines["top"].set_visible(False)
+ ax3.spines["left"].set_visible(False)
+
+ # adjust the position of the E[f(X)] = x.xx label
+ tick_labels = ax2.xaxis.get_majorticklabels()
+ tick_labels[0].set_transform(
+ tick_labels[0].get_transform()
+ + ScaledTranslation(-14 / 72.0, 0, fig.dpi_scale_trans)
+ )
+ tick_labels[1].set_transform(
+ tick_labels[1].get_transform()
+ + ScaledTranslation(
+ 11 / 72.0, -1 / 72.0, fig.dpi_scale_trans
+ )
+ )
+ # tick_labels[1].set_color("#999999")
+
+ the_title = []
+ if predicted_class_name is not None:
+ the_title.append(f"Predicted: {predicted_class_name}")
+ if actual_class_name is not None:
+ the_title.append(f"Actual: {actual_class_name}")
+ if len(the_title) > 0:
+ plt.title(", ".join(the_title))
+
+ return fig
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..3837227
--- /dev/null
+++ b/train.py
@@ -0,0 +1,30 @@
+# This file is part of Prototypical Additive Neural Network for Interpretable Classification (PANIC).
+#
+# PANIC is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# PANIC is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with PANIC. If not, see <https://www.gnu.org/licenses/>.
+import logging
+
+import hydra
+from omegaconf import DictConfig
+
+from torchpanic.training import train
+
+
+@hydra.main(config_path="configs/", config_name="train.yaml", version_base="1.2.0")
+def main(config: DictConfig):
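+ """Train PANIC using the Hydra configuration composed from ``configs/train.yaml``.
+
+ Hydra allows overriding any config value from the command line with ``key=value``
+ syntax, e.g. ``python train.py seed=42`` (the override key is only illustrative).
+ """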
+ return train(config)
+
+
+if __name__ == "__main__":
+ logging.basicConfig(level=logging.INFO)
+ main()