diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000..083ec3c --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1 @@ +* @Simpag @nicola-bastianello diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..db3da45 --- /dev/null +++ b/.gitignore @@ -0,0 +1,11 @@ +**/__pycache__ +**/build +*.egg-info +*.so +.DS_Store +.mypy_cache +.tox +.vscode +dist +pyrightconfig.json +.claude \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..0ad25db --- /dev/null +++ b/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. 
+ + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". 
"Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. 
+ + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. 
Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. 
+ + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. 
This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. 
+ + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. 
+ + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. 
+ + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. 
If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. 
If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. 
+ + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. 
For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. 
+ + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. 
You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. 
+ + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. 
+ + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. 
+
+  If the disclaimer of warranty and limitation of liability provided
+above cannot be given local legal effect according to their terms,
+reviewing courts shall apply local law that most closely approximates
+an absolute waiver of all civil liability in connection with the
+Program, unless a warranty or assumption of liability accompanies a
+copy of the Program in return for a fee.
+
+                     END OF TERMS AND CONDITIONS
+
+            How to Apply These Terms to Your New Programs
+
+  If you develop a new program, and you want it to be of the greatest
+possible use to the public, the best way to achieve this is to make it
+free software which everyone can redistribute and change under these terms.
+
+  To do so, attach the following notices to the program.  It is safest
+to attach them to the start of each source file to most effectively
+state the exclusion of warranty; and each file should have at least
+the "copyright" line and a pointer to where the full notice is found.
+
+    <one line to give the program's name and a brief idea of what it does.>
+    Copyright (C) <year>  <name of author>
+
+    This program is free software: you can redistribute it and/or modify
+    it under the terms of the GNU Affero General Public License as published
+    by the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+
+    This program is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+    GNU Affero General Public License for more details.
+
+    You should have received a copy of the GNU Affero General Public License
+    along with this program.  If not, see <https://www.gnu.org/licenses/>.
+
+Also add information on how to contact you by electronic and paper mail.
+
+  If your software can interact with users remotely through a computer
+network, you should also make sure that it provides a way for users to
+get its source.  For example, if your program is a web application, its
+interface could display a "Source" link that leads users to an archive
+of the code.  There are many ways you could offer source, and different
+solutions will be better for different programs; see section 13 for the
+specific requirements.
+
+  You should also get your employer (if you work as a programmer) or school,
+if any, to sign a "copyright disclaimer" for the program, if necessary.
+For more information on this, and how to apply and follow the GNU AGPL, see
+<https://www.gnu.org/licenses/>.
diff --git a/benchmarks/bench_array.py b/benchmarks/bench_array.py
new file mode 100644
index 0000000..4ae8d8f
--- /dev/null
+++ b/benchmarks/bench_array.py
@@ -0,0 +1,71 @@
+"""
+Microbenchmark: ``decent_array.Array`` operator overhead vs native frameworks.
+
+Measures the wrapper cost added by routing operators through ``Array.__add__``,
+``Array.__neg__`` etc. against calling the framework's native operators
+directly. Iterates over every framework whose package is importable; missing
+optional dependencies are skipped silently.
+
+The overhead column is ``wrapped / native`` runtime — values close to 1.0x mean
+the wrapper is essentially free. Large values at small sizes are expected
+(operator dispatch dominates) and should converge toward 1.0x as elementwise
+work grows.
+ +Run with:: + + python benchmarks/bench_array.py +""" + +from __future__ import annotations + +from bench_common import ( + SIZES, + BackendCase, + activate_backend, + discover_backends, + fmt_row, + parse_backends_arg, + print_preamble, + print_size_header, + time_us_safe, +) + +from decent_array import Array + + +def _bench_case(case: BackendCase) -> None: + activate_backend(case.name) + print(f"## {case.name}\n") + for n in SIZES: + a = case.make(n) + b = case.make(n) + d_a, d_b = Array(a), Array(b) + + print_size_header(n) + rows = ( + ("add", lambda a=a, b=b: a + b, lambda d_a=d_a, d_b=d_b: d_a + d_b), + ("sub", lambda a=a, b=b: a - b, lambda d_a=d_a, d_b=d_b: d_a - d_b), + ("mul", lambda a=a, b=b: a * b, lambda d_a=d_a, d_b=d_b: d_a * d_b), + ("div", lambda a=a, b=b: a / b, lambda d_a=d_a, d_b=d_b: d_a / d_b), + ("neg", lambda a=a: -a, lambda d_a=d_a: -d_a), + ("abs", lambda a=a: abs(a), lambda d_a=d_a: abs(d_a)), + ("pow", lambda a=a: a ** 2.0, lambda d_a=d_a: d_a ** 2.0), + ) + for op, native_fn, wrapped_fn in rows: + n_us = time_us_safe(case, native_fn) + w_us = time_us_safe(case, wrapped_fn) + print(fmt_row(op, n_us, w_us)) + print() + print() + + +def main() -> None: + print_preamble("Array operator overhead vs native frameworks") + cases = discover_backends(only=parse_backends_arg()) + print(f"available backends: {', '.join(c.name for c in cases)}\n") + for case in cases: + _bench_case(case) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/bench_common.py b/benchmarks/bench_common.py new file mode 100644 index 0000000..d67726e --- /dev/null +++ b/benchmarks/bench_common.py @@ -0,0 +1,179 @@ +""" +Shared helpers for ``bench_array.py`` and ``bench_iop.py``. + +Three concerns live here so the benchmarks stay focused on the comparison logic: + +* :func:`discover_backends` returns the subset of frameworks whose package is + importable; backends with a missing optional dependency are skipped silently. 
+* :func:`is_compiled` / :func:`print_preamble` report whether the user is + running against a mypyc-compiled build of ``decent_array`` or the pure-Python + source — this materially changes overhead numbers, so the result is printed + at the top of every run. +* :func:`time_us` / :func:`time_us_safe` wrap :mod:`timeit` to take the + ``min`` of several auto-ranged repeats. ``min`` is the canonical choice: it + reports the lower bound of the machine's per-call cost and is the metric + least sensitive to background activity. A warmup call precedes timing so + JIT-style backends (JAX) don't skew the first iteration. +""" + +from __future__ import annotations + +import importlib +import timeit +from collections.abc import Callable +from dataclasses import dataclass +from typing import Any + +SIZES: tuple[int, ...] = (10, 100, 1_000, 10_000) +REPEATS: int = 7 + + +def _no_sync(_value: Any) -> None: # noqa: ANN401 + """No-op sync used for synchronous backends (numpy, torch CPU, tf eager CPU).""" + + +def _sync_jax(value: Any) -> None: # noqa: ANN401 + """Block until a JAX DeviceArray is materialized, unwrapping ``Array`` if needed.""" + # Imported lazily so the module can load even when decent_array isn't yet importable. + from decent_array import Array # noqa: PLC0415 + + raw = value.value if isinstance(value, Array) else value + raw.block_until_ready() + + +@dataclass(slots=True) +class BackendCase: + """A discovered backend plus the helpers needed to drive it in a benchmark.""" + + name: str + make: Callable[[int], Any] + sync: Callable[[Any], None] + + +def activate_backend(name: str) -> None: + """Activate ``name`` as the live backend, resetting any previously active one. + + ``decent_array`` enforces a single-active-backend invariant per execution context; + swapping between frameworks within one process requires resetting first. 
+ """ + from decent_array.interoperability._backend_manager import reset_backends, set_backend # noqa: PLC0415 + + reset_backends() + set_backend(name) + + +def discover_backends(only: list[str] | None = None) -> list[BackendCase]: + """Return one :class:`BackendCase` per importable framework, in a stable order. + + Args: + only: Optional allowlist of backend names. When provided, frameworks not in the + list are skipped entirely (their packages aren't even imported), and any + requested name that isn't a known backend raises :class:`ValueError`. + + """ + import numpy as np # always available — hard dependency # noqa: PLC0415 + + known = {"numpy", "pytorch", "jax", "tensorflow"} + if only is not None: + unknown = set(only) - known + if unknown: + raise ValueError(f"unknown backend(s): {sorted(unknown)}; known: {sorted(known)}") + wanted = set(only) + else: + wanted = known + + cases: list[BackendCase] = [] + + if "numpy" in wanted: + cases.append(BackendCase("numpy", lambda n: np.random.rand(n), _no_sync)) + + if "pytorch" in wanted: + try: + import torch # noqa: PLC0415 + except ImportError: + pass + else: + cases.append(BackendCase("pytorch", lambda n: torch.from_numpy(np.random.rand(n)), _no_sync)) + + if "jax" in wanted: + try: + import jax.numpy as jnp # noqa: PLC0415 + except ImportError: + pass + else: + cases.append(BackendCase("jax", lambda n: jnp.asarray(np.random.rand(n)), _sync_jax)) + + if "tensorflow" in wanted: + try: + import tensorflow as tf # noqa: PLC0415 + except ImportError: + pass + else: + cases.append(BackendCase("tensorflow", lambda n: tf.constant(np.random.rand(n)), _no_sync)) + + return cases + + +def parse_backends_arg() -> list[str] | None: + """Parse the shared ``--backends`` CLI flag; returns ``None`` if not given.""" + import argparse # noqa: PLC0415 + + parser = argparse.ArgumentParser(add_help=True) + parser.add_argument( + "--backends", + type=str, + default=None, + help="comma-separated allowlist of backends 
(numpy,pytorch,jax,tensorflow); default = all available", + ) + args = parser.parse_args() + if args.backends is None: + return None + return [b.strip() for b in args.backends.split(",") if b.strip()] + + +def is_compiled() -> tuple[bool, str]: + """Return ``(True, path)`` if the Array module loaded from a ``.so``/``.pyd``, else ``(False, .py path)``.""" + module = importlib.import_module("decent_array._array") + path = module.__file__ or "" + return path.endswith((".so", ".pyd")), path + + +def print_preamble(title: str) -> None: + compiled, path = is_compiled() + print(f"# {title}\n") + print(f"decent_array compiled: {'yes' if compiled else 'no'}") + print(f" Array loaded from: {path}") + print(f" timing: min over {REPEATS} repeats, iterations per repeat auto-tuned to ~0.2s\n") + + +def time_us(case: BackendCase, fn: Callable[[], Any]) -> float: + """Per-call runtime in µs; min over :data:`REPEATS` measurements with autoranged N.""" + + def runner() -> None: + case.sync(fn()) + + runner() # warmup — material for JAX's first-call compilation + timer = timeit.Timer(runner) + n, _ = timer.autorange() + times = timer.repeat(repeat=REPEATS, number=n) + return (min(times) / n) * 1e6 + + +def time_us_safe(case: BackendCase, fn: Callable[[], Any]) -> float | None: + """Like :func:`time_us` but returns ``None`` if ``fn`` raises (e.g. 
TF 1D matmul).""" + try: + return time_us(case, fn) + except Exception: # noqa: BLE001 + return None + + +def fmt_row(op: str, native_us: float | None, wrapped_us: float | None) -> str: + if native_us is None or wrapped_us is None: + return f" {op:<8} {'n/a':>13} {'n/a':>13} {'n/a':>8}" + ratio = wrapped_us / native_us if native_us > 0 else float("inf") + return f" {op:<8} {native_us:>10.3f} µs {wrapped_us:>10.3f} µs {ratio:>6.2f}x" + + +def print_size_header(n: int) -> None: + print(f"size = {n:_}") + print(f" {'op':<8} {'native':>13} {'wrapped':>13} {'overhead':>8}") diff --git a/benchmarks/bench_iop.py b/benchmarks/bench_iop.py new file mode 100644 index 0000000..5127b0a --- /dev/null +++ b/benchmarks/bench_iop.py @@ -0,0 +1,136 @@ +""" +Microbenchmark: ``iop.`` dispatch overhead vs native frameworks. + +Measures the cost of calling top-level interoperability functions (which look +up the active backend on each call and dispatch to it) against calling each +framework's native equivalents directly. Same shape and intent as +``bench_array.py`` but for the function-style API surface rather than the +operator-style ``Array`` API. Iterates over every framework whose package is +importable; missing optional dependencies are skipped silently. + +Where the two benchmarks share an op (``add``, ``mul``), differences in the +overhead column reflect dunder-method dispatch vs. module-level function +dispatch; the remaining ops (``sum``, ``dot``, ``norm``, ``mean``, ``sqrt``, +``sign``) only exist on this surface. 
+ +Run with:: + + python benchmarks/bench_iop.py +""" + +from __future__ import annotations + +from collections.abc import Callable +from typing import Any + +from bench_common import ( + SIZES, + BackendCase, + activate_backend, + discover_backends, + fmt_row, + parse_backends_arg, + print_preamble, + print_size_header, + time_us_safe, +) + +import decent_array.interoperability as iop +from decent_array import Array + + +def _native_ops(backend: str) -> dict[str, Callable[..., Any]]: + """Return the native-framework equivalents of each ``iop.`` for ``backend``.""" + if backend == "numpy": + import numpy as np # noqa: PLC0415 + + return { + "add": np.add, + "mul": np.multiply, + "dot": np.dot, + "sum": np.sum, + "mean": np.mean, + "norm": np.linalg.norm, + "sqrt": np.sqrt, + "sign": np.sign, + } + if backend == "pytorch": + import torch # noqa: PLC0415 + + return { + "add": torch.add, + "mul": torch.mul, + "dot": torch.dot, + "sum": torch.sum, + "mean": torch.mean, + "norm": torch.linalg.norm, + "sqrt": torch.sqrt, + "sign": torch.sign, + } + if backend == "jax": + import jax.numpy as jnp # noqa: PLC0415 + + return { + "add": jnp.add, + "mul": jnp.multiply, + "dot": jnp.dot, + "sum": jnp.sum, + "mean": jnp.mean, + "norm": jnp.linalg.norm, + "sqrt": jnp.sqrt, + "sign": jnp.sign, + } + if backend == "tensorflow": + import tensorflow as tf # noqa: PLC0415 + + return { + "add": tf.add, + "mul": tf.multiply, + "dot": lambda a, b: tf.tensordot(a, b, axes=1), + "sum": tf.reduce_sum, + "mean": tf.reduce_mean, + "norm": tf.norm, + "sqrt": tf.sqrt, + "sign": tf.sign, + } + raise ValueError(f"unknown backend: {backend}") + + +def _bench_case(case: BackendCase) -> None: + activate_backend(case.name) + native = _native_ops(case.name) + print(f"## {case.name}\n") + for n in SIZES: + a = case.make(n) + b = case.make(n) + d_a, d_b = Array(a), Array(b) + + print_size_header(n) + rows = ( + ("add", lambda a=a, b=b: native["add"](a, b), lambda d_a=d_a, d_b=d_b: iop.add(d_a, d_b)), + 
("mul", lambda a=a, b=b: native["mul"](a, b), lambda d_a=d_a, d_b=d_b: iop.mul(d_a, d_b)), + ("dot", lambda a=a, b=b: native["dot"](a, b), lambda d_a=d_a, d_b=d_b: iop.dot(d_a, d_b)), + ("sum", lambda a=a: native["sum"](a), lambda d_a=d_a: iop.sum(d_a)), + ("mean", lambda a=a: native["mean"](a), lambda d_a=d_a: iop.mean(d_a)), + ("norm", lambda a=a: native["norm"](a), lambda d_a=d_a: iop.norm(d_a)), + ("sqrt", lambda a=a: native["sqrt"](a), lambda d_a=d_a: iop.sqrt(d_a)), + ("sign", lambda a=a: native["sign"](a), lambda d_a=d_a: iop.sign(d_a)), + ) + for op, native_fn, wrapped_fn in rows: + n_us = time_us_safe(case, native_fn) + w_us = time_us_safe(case, wrapped_fn) + print(fmt_row(op, n_us, w_us)) + print() + print() + + +def main() -> None: + print_preamble("iop function-call overhead vs native frameworks") + cases = discover_backends(only=parse_backends_arg()) + print(f"available backends: {', '.join(c.name for c in cases)}\n") + for case in cases: + _bench_case(case) + + +if __name__ == "__main__": + main() diff --git a/benchmarks/profile_hotpath.py b/benchmarks/profile_hotpath.py new file mode 100644 index 0000000..588e494 --- /dev/null +++ b/benchmarks/profile_hotpath.py @@ -0,0 +1,111 @@ +""" +Profile ``decent_array`` hot paths to find where wrapper overhead actually lives. + +Runs a tight loop of representative operations (mix of ``Array`` operators and +``iop`` function calls) under :mod:`cProfile`, then prints the highest-impact +callees by cumulative time. The output answers "of the wrapper overhead we see +in ``bench_array.py`` / ``bench_iop.py``, where is the time actually spent?" — +turning a vague ratio number into specific functions to optimize. + +Notes: +* cProfile adds per-call overhead of ~1 µs, which dominates anything sub-µs. The + *relative* shape of the profile is still informative (which functions are + called most often, and which take the largest share of cumulative time); + don't read the absolute µs values. 
+* Repeats are deliberately on the small side so that pure-Python and mypyc- + compiled runs both finish quickly. For mypyc-compiled modules cProfile only + records entry/exit (compiled internals are opaque), so the profile is most + informative against the pure-Python source. + +Run with:: + + python benchmarks/profile_hotpath.py + python benchmarks/profile_hotpath.py --backend pytorch + python benchmarks/profile_hotpath.py --topn 50 +""" + +from __future__ import annotations + +import argparse +import cProfile +import importlib +import pstats +from io import StringIO + +import numpy as np + +import decent_array.interoperability as iop +from decent_array import Array + + +def _is_compiled() -> tuple[bool, str]: + module = importlib.import_module("decent_array._array") + path = module.__file__ or "" + return path.endswith((".so", ".pyd")), path + + +def hot_loop(a: Array, b: Array, iterations: int) -> None: + """A representative mix of operator-style and function-style calls.""" + for _ in range(iterations): + # Operator surface (Array dunders) + _ = a + b + _ = a - b + _ = a * b + _ = a / b + _ = a + 2.0 + _ = -a + _ = abs(a) + _ = a**2.0 + # Function surface (iop dispatch) + _ = iop.add(a, b) + _ = iop.mul(a, b) + _ = iop.sum(a) + _ = iop.norm(a) + _ = iop.dot(a, b) + _ = iop.sqrt(a) + + +def main() -> None: + parser = argparse.ArgumentParser() + parser.add_argument("--backend", default="numpy", choices=["numpy", "pytorch", "jax", "tensorflow"]) + parser.add_argument("--iterations", type=int, default=20_000) + parser.add_argument("--topn", type=int, default=30) + parser.add_argument("--size", type=int, default=100, help="size of the arrays to operate on") + args = parser.parse_args() + + iop.set_backend(args.backend) + + # Modest fixed-size array — small enough that wrapper overhead is the + # dominant cost, large enough that the underlying ops aren't degenerate. 
+ a = iop.uniform(0.0, 1.0, shape=(args.size,)) + b = iop.uniform(0.0, 1.0, shape=(args.size,)) + + # Warmup so first-call jit / cache effects don't pollute the profile. + hot_loop(a, b, iterations=10) + + profiler = cProfile.Profile() + profiler.enable() + hot_loop(a, b, iterations=args.iterations) + profiler.disable() + + compiled, path = _is_compiled() + print(f"backend: {args.backend} compiled: {'yes' if compiled else 'no'}") + print(f"Array module: {path}") + print(f"iterations: {args.iterations:_}\n") + + buf = StringIO() + stats = pstats.Stats(profiler, stream=buf) + stats.sort_stats("cumulative") + stats.print_stats(args.topn) + print(buf.getvalue()) + + print("\n--- top callees by total (self) time ---\n") + buf2 = StringIO() + stats2 = pstats.Stats(profiler, stream=buf2) + stats2.sort_stats("tottime") + stats2.print_stats(args.topn) + print(buf2.getvalue()) + + +if __name__ == "__main__": + main() diff --git a/decent_array/__init__.py b/decent_array/__init__.py new file mode 100644 index 0000000..9893247 --- /dev/null +++ b/decent_array/__init__.py @@ -0,0 +1,8 @@ +from decent_array import interoperability, types +from decent_array._array import Array + +__all__ = [ + "Array", + "interoperability", + "types", +] diff --git a/decent_array/_array.py b/decent_array/_array.py new file mode 100644 index 0000000..9c554f5 --- /dev/null +++ b/decent_array/_array.py @@ -0,0 +1,301 @@ +""" +Lightweight wrapper around backend-native arrays. + +The :class:`Array` class wraps a single value of the active backend's framework type. +Under the single-active-backend invariant maintained by the backend manager, every +:class:`Array` at runtime holds a value from the same framework, so operators dispatch +directly to the active backend without per-call isinstance dispatch. + +Operator contract is *strict*: binary arithmetic and indexing accept either another +:class:`Array` or a Python scalar (``int``/``float``). 
Pass other framework-native +arrays through :func:`decent_array.interoperability.get_item` and friends, not through the +operator path. + +Hot-path notes: + +* ``__add__``/``__sub__``/``__mul__``/``__truediv__``/``__matmul__``, the unary + ``__neg__``/``__abs__``/``__pow__``, the comparisons ``__eq__``/``__ne__``/``__lt__``/ + ``__le__``/``__gt__``/``__ge__`` and the bitwise ``__and__``/``__rand__`` are + inlined: every supported framework's tensor implements the equivalent operator + natively with numpy-equivalent semantics, so routing through the backend saves + nothing. +* Operators that *do* go through the backend (in-place math, indexing, properties + like ``shape``/``transpose``) read the cached ``_backend`` slot. +* Overriding ``__eq__`` makes :class:`Array` unhashable (Python clears ``__hash__`` + automatically). This matches numpy/torch/jax/tf, where element-wise equality is + more useful than identity-based hashing. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, Self + +from decent_array.interoperability._backend_manager import register_backend_listener + +if TYPE_CHECKING: + from decent_array.interoperability._abstracts import Backend + from decent_array.types import ArrayKey, SupportedArrayTypes + + +_BACKEND_INSTANCE: Backend | None = None + + +def _update_backend(backend: Backend | None) -> None: + global _BACKEND_INSTANCE # noqa: PLW0603 + _BACKEND_INSTANCE = backend + + +register_backend_listener(_update_backend) + + +class Array: # noqa: PLR0904 + """ + Wrapper around a single backend-native array. + + Storage is two slots (``value``, ``_backend``) declared via ``__slots__``; + instances have no ``__dict__``. Every operator that delegates to the backend + reads the cached slot, so dispatch is one slot load plus the backend method call. + """ + + __slots__ = ("_backend", "value") + + def __init__(self, value: SupportedArrayTypes) -> None: + """ + Wrap ``value`` in an :class:`Array`. 
+ + Args: + value: A backend-native array (or scalar) to wrap. The attribute is typed + as :class:`typing.Any` because the wrapper is intentionally type-erased + — backend code accesses :attr:`value` knowing the framework type, and + typing it more strictly forces a ``cast`` at every call site without + runtime benefit. + + Raises: + RuntimeError: If no backend is registered yet. An :class:`Array` cannot be + constructed until a backend is set; call :func:`set_backend` to initialize + the interoperability layer. + + """ + if _BACKEND_INSTANCE is None: + raise RuntimeError( + "No backend registered yet. An Array cannot be constructed until a backend is set. " + "Call set_backend() to initialize the interoperability layer." + ) + + self.value: Any = value + self._backend: Backend = _BACKEND_INSTANCE + + # Binary arithmetic ---------------------------------------------------- + + def __add__(self, other: Array | float) -> Array: + """Return the sum of the array and another array or a scalar.""" + return Array(self.value + (other.value if type(other) is Array else other)) + + def __radd__(self, other: float) -> Array: + """Return the sum of the array and a scalar.""" + return Array(other + self.value) + + def __sub__(self, other: Array | float) -> Array: + """Return the subtraction of another array or a scalar from the array.""" + return Array(self.value - (other.value if type(other) is Array else other)) + + def __rsub__(self, other: float) -> Array: + """Return the subtraction of the array from a scalar.""" + return Array(other - self.value) + + def __mul__(self, other: Array | float) -> Array: + """Return the product of the array and another array or a scalar.""" + return Array(self.value * (other.value if type(other) is Array else other)) + + def __rmul__(self, other: float) -> Array: + """Return the product of the array and a scalar.""" + return Array(other * self.value) + + def __truediv__(self, other: Array | float) -> Array: + """Return the true division of the 
array by ``other``.""" + return Array(self.value / (other.value if type(other) is Array else other)) + + def __rtruediv__(self, other: float) -> Array: + """Return the true division of ``other`` by the array.""" + return Array(other / self.value) + + def __matmul__(self, other: Array) -> Array: + """Return the matrix multiplication of the array with ``other``.""" + return Array(self.value @ other.value) + + def __rmatmul__(self, other: Array) -> Array: + """Return the matrix multiplication of ``other`` with the array.""" + return Array(other.value @ self.value) + + def __pow__(self, other: float) -> Array: + """Exponentiate the array by a scalar power.""" + # numpy/torch/jax/tf all implement ``tensor ** p`` with semantics matching the + # backend's ``pow``; routing through the backend would cost an extra method + # call for no behavioral difference. + return Array(self.value**other) + + # Comparisons ---------------------------------------------------------- + # + # Element-wise comparisons return an :class:`Array` of bools. The ``__eq__`` and + # ``__ne__`` parameters are typed ``object`` to match the LSP signature inherited + # from :class:`object`; the body still enforces the strict ``Array | scalar`` + # contract via the underlying framework's operator (incompatible operands raise + # from the backend's native comparison, matching ``__add__``). + # + # Overriding ``__eq__`` makes instances unhashable; ``__hash__ = None`` makes that + # explicit (and silences the lint that flags the dropped ``__hash__``). 
+ + __hash__ = None # type: ignore[assignment] + + def __eq__(self, other: object) -> Array: # type: ignore[override] + """Element-wise equality.""" + return Array(self.value == (other.value if type(other) is Array else other)) + + def __ne__(self, other: object) -> Array: # type: ignore[override] + """Element-wise inequality.""" + return Array(self.value != (other.value if type(other) is Array else other)) + + def __lt__(self, other: Array | float) -> Array: + """Element-wise less-than.""" + return Array(self.value < (other.value if type(other) is Array else other)) + + def __le__(self, other: Array | float) -> Array: + """Element-wise less-than-or-equal.""" + return Array(self.value <= (other.value if type(other) is Array else other)) + + def __gt__(self, other: Array | float) -> Array: + """Element-wise greater-than.""" + return Array(self.value > (other.value if type(other) is Array else other)) + + def __ge__(self, other: Array | float) -> Array: + """Element-wise greater-than-or-equal.""" + return Array(self.value >= (other.value if type(other) is Array else other)) + + # Bitwise -------------------------------------------------------------- + # + # Bitwise AND is only defined for integer/boolean dtypes. ``__and__``'s ``Array + # | int`` is a Union (mypyc keeps Union operands boxed, so a ``bool`` operand + # stays a ``bool`` and TF's strict dtype check accepts it). ``__rand__``'s + # operand is typed ``Any`` for the same reason: a single-primitive annotation + # like ``int`` causes mypyc to unbox a ``True`` to ``1`` before the call body, + # which fails TF's ``1 & bool_tensor`` rejection. Native operator semantics on + # the wrapped tensor enforce the actual dtype contract. 
+ + def __and__(self, other: Array | int) -> Array: + """Element-wise bitwise/logical AND.""" + return Array(self.value & (other.value if type(other) is Array else other)) + + def __rand__(self, other: Any) -> Array: # noqa: ANN401 + """Element-wise bitwise/logical AND with the array on the right.""" + return Array(other & self.value) + + # In-place arithmetic -------------------------------------------------- + # + # The backend handles the framework's mutability semantics: numpy/pytorch mutate + # `value` in place, jax/tensorflow rebind it. In every case the returned object is + # the same wrapper instance, so we just discard the return and yield ``self``. + + def __iadd__(self, other: Array | float) -> Self: + """In-place addition.""" + self._backend.iadd(self, other) + return self + + def __isub__(self, other: Array | float) -> Self: + """In-place subtraction.""" + self._backend.isub(self, other) + return self + + def __imul__(self, other: Array | float) -> Self: + """In-place multiplication.""" + self._backend.imul(self, other) + return self + + def __itruediv__(self, other: Array | float) -> Self: + """In-place true division.""" + self._backend.idiv(self, other) + return self + + # Unary ---------------------------------------------------------------- + + def __neg__(self) -> Array: + """Return the negation of the array.""" + # Native ``-tensor`` matches the backend's ``negative`` wrapper across all + # supported frameworks, so the indirection is not needed. + return Array(-self.value) + + def __abs__(self) -> Array: + """Return the absolute value of the array.""" + # Same rationale as ``__neg__`` — native ``abs(tensor)`` matches each + # backend's ``absolute`` implementation. 
+ return Array(abs(self.value)) + + # Indexing ------------------------------------------------------------- + + def __getitem__(self, key: ArrayKey) -> Array: + """Return the item at ``key``.""" + return self._backend.get_item(self, key) + + def __setitem__(self, key: ArrayKey, value: Array | float) -> None: + """Set the item at ``key`` to ``value``.""" + if not isinstance(value, Array): + value = Array(value) + self._backend.set_item(self, key, value) + + # Containers / iteration ---------------------------------------------- + + def __len__(self) -> int: + """Return the size of the first dimension of the array.""" + return len(self.value) + + # Coercion ------------------------------------------------------------- + + def __float__(self) -> float: + """Coerce a scalar array to a Python float.""" + return float(self._backend.astype(self, float)) + + # Repr ----------------------------------------------------------------- + + def __repr__(self) -> str: + """Show the wrapper and the wrapped value.""" + return f"Array({self.value!r})" + + def __str__(self) -> str: + """Stringify the wrapped value, not the wrapper.""" + return str(self.value) + + # Properties ----------------------------------------------------------- + + @property + def shape(self) -> tuple[int, ...]: + """Return the shape of the array.""" + return self._backend.shape(self) + + @property + def size(self) -> int: + """Return the total number of elements in the array.""" + return self._backend.size(self) + + @property + def ndim(self) -> int: + """Return the number of dimensions of the array.""" + return self._backend.ndim(self) + + @property + def transpose(self) -> Array: + """Return a transposed view of the array.""" + return self._backend.transpose(self) + + @property + def T(self) -> Array: # noqa: N802 + """Return a transposed view of the array.""" + return self.transpose + + @property + def any(self) -> bool: + """Return True if any element of the array is truthy.""" + return 
self._backend.any(self) + + @property + def all(self) -> bool: + """Return True if all elements of the array are truthy.""" + return self._backend.all(self) diff --git a/decent_array/interoperability/__init__.py b/decent_array/interoperability/__init__.py new file mode 100644 index 0000000..a1b96ed --- /dev/null +++ b/decent_array/interoperability/__init__.py @@ -0,0 +1,159 @@ +""" +Interoperability layer. + +Typical usage:: + + import decent_array.interoperability as iop + + iop.set_backend("numpy") + a = iop.zeros((3, 3)) + iop.set_seed(42) # Optional: set RNG seed for reproducibility + s = iop.normal(shape=(2,)) + +""" + +from ._backend_manager import set_backend +from ._decorators import autodecorate_cost_method +from ._iop.functions import ( + absolute, + add, + all, # noqa: A004 + any, # noqa: A004 + argmax, + argmin, + astype, + bitwise_and, + copy, + device_of, + device_to_native, + diag, + div, + dot, + eq, + eye, + eye_like, + from_numpy, + from_numpy_like, + ge, + get_item, + gt, + iadd, + idiv, + imul, + isub, + le, + lt, + matmul, + max, # noqa: A004 + maximum, + mean, + min, # noqa: A004 + mul, + ndim, + ne, + negative, + norm, + ones, + ones_like, + pow, # noqa: A004 + reshape, + set_item, + shape, + sign, + size, + sqrt, + squeeze, + stack, + sub, + sum, # noqa: A004 + to_array, + to_numpy, + transpose, + unsqueeze, + zeros, + zeros_like, +) +from ._iop.rng import ( + choice, + derive_seed, + get_rng_state, + get_seed, + normal, + normal_like, + set_rng_state, + set_seed, + uniform, + uniform_like, +) + +__all__ = [ + "absolute", + "add", + "all", + "any", + "argmax", + "argmin", + "astype", + "autodecorate_cost_method", + "bitwise_and", + "choice", + "copy", + "derive_seed", + "device_of", + "device_to_native", + "diag", + "div", + "dot", + "eq", + "eye", + "eye_like", + "from_numpy", + "from_numpy_like", + "ge", + "get_item", + "get_rng_state", + "get_seed", + "gt", + "iadd", + "idiv", + "imul", + "isub", + "le", + "lt", + "matmul", + "max", + 
"maximum", + "mean", + "min", + "mul", + "ndim", + "ne", + "negative", + "norm", + "normal", + "normal_like", + "ones", + "ones_like", + "pow", + "reshape", + "set_backend", + "set_item", + "set_rng_state", + "set_seed", + "shape", + "sign", + "size", + "sqrt", + "squeeze", + "stack", + "sub", + "sum", + "to_array", + "to_numpy", + "transpose", + "uniform", + "uniform_like", + "unsqueeze", + "zeros", + "zeros_like", +] diff --git a/decent_array/interoperability/_abstracts/__init__.py b/decent_array/interoperability/_abstracts/__init__.py new file mode 100644 index 0000000..b49360e --- /dev/null +++ b/decent_array/interoperability/_abstracts/__init__.py @@ -0,0 +1,3 @@ +from .backend import Backend + +__all__ = ["Backend"] diff --git a/decent_array/interoperability/_abstracts/backend.py b/decent_array/interoperability/_abstracts/backend.py new file mode 100644 index 0000000..72b0b5e --- /dev/null +++ b/decent_array/interoperability/_abstracts/backend.py @@ -0,0 +1,327 @@ +""" +Abstract :class:`Backend` contract. + +All abstract methods live in this single class rather than across six mixin ABCs. The +flat layout is mypyc-compatible: when this module is included in the same compilation +group as the concrete backends, ``_BACKEND.add(self, other)`` becomes a native +compiled-to-compiled call (no Python attribute lookup, no bound-method allocation), +which removes the need for a ``raw_ops`` escape hatch on the hot path. + +Section dividers in this file group related operations the way the legacy split files +did (creation, manipulation, linalg, math, operators, RNG); the only thing that +changed is that they all live on one class. 
+""" + +from __future__ import annotations + +from abc import ABC, abstractmethod +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from decent_array.types import SupportedDevices + +if TYPE_CHECKING: + from numpy.typing import NDArray + + from decent_array import Array + from decent_array.types import ArrayKey + + +class Backend(ABC): # noqa: PLR0904 + """ + Abstract base class for a backend. + + Concrete backends are bound to a single :class:`SupportedDevices` at construction + time; that device is the default for all new arrays produced by this backend. + """ + + def __init__(self, device: SupportedDevices = SupportedDevices.CPU) -> None: + self.device: SupportedDevices = device + + # Array creation ------------------------------------------------------ + + @abstractmethod + def zeros(self, shape: tuple[int, ...]) -> Array: + """Create an array of zeros with the given shape.""" + + @abstractmethod + def zeros_like(self, array: Array) -> Array: + """Create an array of zeros matching the shape and type of ``array``.""" + + @abstractmethod + def ones(self, shape: tuple[int, ...]) -> Array: + """Create an array of ones with the given shape.""" + + @abstractmethod + def ones_like(self, array: Array) -> Array: + """Create an array of ones matching the shape and type of ``array``.""" + + @abstractmethod + def eye(self, n: int) -> Array: + """Create an ``n x n`` identity matrix.""" + + @abstractmethod + def eye_like(self, array: Array) -> Array: + """Create an identity matrix matching the trailing two dims of ``array``.""" + + @abstractmethod + def device_to_native(self, device: SupportedDevices) -> Any: # noqa: ANN401 + """Convert :class:`SupportedDevices` to the backend's native device representation.""" + + @abstractmethod + def device_of(self, array: Array) -> SupportedDevices: + """Return the :class:`SupportedDevices` of the given array.""" + + # Array manipulation -------------------------------------------------- + + @abstractmethod 
+ def copy(self, array: Array) -> Array: + """Return a copy of ``array``.""" + + @abstractmethod + def to_numpy(self, array: Array) -> NDArray[Any]: + """Convert ``array`` to a NumPy array on the CPU.""" + + @abstractmethod + def from_numpy(self, array: NDArray[Any]) -> Array: + """Convert a NumPy array on the CPU to an :class:`Array` on this backend.""" + + @abstractmethod + def from_numpy_like(self, array: NDArray[Any], like: Array) -> Array: + """Convert a Numpy array to an :class:`Array` on this backend, matching shape and type of ``like``.""" + + @abstractmethod + def to_array(self, array: float | bool) -> Array: + """Convert a Python scalar to an :class:`Array` on this backend.""" + + @abstractmethod + def stack(self, arrays: Sequence[Array], axis: int = 0) -> Array: + """Stack a sequence of arrays along a new dimension.""" + + @abstractmethod + def reshape(self, array: Array, shape: tuple[int, ...]) -> Array: + """Reshape ``array`` to ``shape``.""" + + @abstractmethod + def transpose(self, array: Array, axis: tuple[int, ...] | None = None) -> Array: + """Transpose ``array``; ``None`` reverses the dimensions.""" + + @abstractmethod + def shape(self, array: Array) -> tuple[int, ...]: + """Return the shape of ``array``.""" + + @abstractmethod + def size(self, array: Array) -> int: + """Return the total number of elements in ``array``.""" + + @abstractmethod + def ndim(self, array: Array) -> int: + """Return the number of dimensions of ``array``.""" + + @abstractmethod + def squeeze(self, array: Array, axis: int | tuple[int, ...] 
| None = None) -> Array: + """Remove single-dimensional entries from ``array``.""" + + @abstractmethod + def unsqueeze(self, array: Array, axis: int) -> Array: + """Insert a singleton dimension at ``axis``.""" + + @abstractmethod + def diag(self, array: Array) -> Array: + """Diagonal: build from a vector or extract from a matrix.""" + + @abstractmethod + def astype(self, array: Array, dtype: type[float | int | bool]) -> float | int | bool: + """Cast a single-element ``array`` to a Python scalar of ``dtype``.""" + + # Linalg -------------------------------------------------------------- + + @abstractmethod + def dot(self, array1: Array, array2: Array) -> Array: + """Dot product of two arrays.""" + + @abstractmethod + def matmul(self, array1: Array, array2: Array) -> Array: + """Matrix multiplication of two arrays.""" + + @abstractmethod + def norm( + self, + array: Array, + p: float = 2, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, + ) -> Array: + """Compute the norm of ``array``.""" + + # Math reductions ----------------------------------------------------- + + @abstractmethod + def sum(self, array: Array, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: + """Sum elements of ``array`` along ``axis``.""" + + @abstractmethod + def mean(self, array: Array, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: + """Mean of ``array`` along ``axis``.""" + + @abstractmethod + def min(self, array: Array, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: + """Minimum of ``array`` along ``axis``.""" + + @abstractmethod + def max(self, array: Array, axis: int | tuple[int, ...] 
| None = None, keepdims: bool = False) -> Array: + """Maximum of ``array`` along ``axis``.""" + + @abstractmethod + def any(self, array: Array) -> bool: + """Return True if any element of ``array`` is truthy.""" + + @abstractmethod + def all(self, array: Array) -> bool: + """Return True if all elements of ``array`` are truthy.""" + + # Math elementwise — both operands may be Array or scalar (operator dunders pass + # either). ``Array | float`` covers both because PEP 484's numeric tower implicitly + # admits ``int``. + + @abstractmethod + def add(self, array1: Array | float, array2: Array | float) -> Array: + """Element-wise addition.""" + + @abstractmethod + def iadd[T: Array](self, array1: T, array2: Array | float) -> T: + """In-place element-wise addition.""" + + @abstractmethod + def sub(self, array1: Array | float, array2: Array | float) -> Array: + """Element-wise subtraction.""" + + @abstractmethod + def isub[T: Array](self, array1: T, array2: Array | float) -> T: + """In-place element-wise subtraction.""" + + @abstractmethod + def mul(self, array1: Array | float, array2: Array | float) -> Array: + """Element-wise multiplication.""" + + @abstractmethod + def imul[T: Array](self, array1: T, array2: Array | float) -> T: + """In-place element-wise multiplication.""" + + @abstractmethod + def div(self, array1: Array | float, array2: Array | float) -> Array: + """Element-wise division.""" + + @abstractmethod + def idiv[T: Array](self, array1: T, array2: Array | float) -> T: + """In-place element-wise division.""" + + @abstractmethod + def pow(self, array: Array, p: float) -> Array: + """Raise ``array`` to power ``p``.""" + + @abstractmethod + def negative(self, array: Array) -> Array: + """Element-wise negation.""" + + @abstractmethod + def absolute(self, array: Array) -> Array: + """Element-wise absolute value.""" + + @abstractmethod + def sqrt(self, array: Array) -> Array: + """Element-wise square root.""" + + # Comparisons — both operands may be Array or 
scalar. + + @abstractmethod + def eq(self, array1: Array | float, array2: Array | float) -> Array: + """Element-wise equality. Returns an :class:`Array` of bools.""" + + @abstractmethod + def ne(self, array1: Array | float, array2: Array | float) -> Array: + """Element-wise inequality. Returns an :class:`Array` of bools.""" + + @abstractmethod + def lt(self, array1: Array | float, array2: Array | float) -> Array: + """Element-wise less-than. Returns an :class:`Array` of bools.""" + + @abstractmethod + def le(self, array1: Array | float, array2: Array | float) -> Array: + """Element-wise less-than-or-equal. Returns an :class:`Array` of bools.""" + + @abstractmethod + def gt(self, array1: Array | float, array2: Array | float) -> Array: + """Element-wise greater-than. Returns an :class:`Array` of bools.""" + + @abstractmethod + def ge(self, array1: Array | float, array2: Array | float) -> Array: + """Element-wise greater-than-or-equal. Returns an :class:`Array` of bools.""" + + # Bitwise — operands may be int/bool arrays or scalars. Mirrors Python's ``&``, + # which dispatches to ``logical_and`` on bool tensors and ``bitwise_and`` on int + # tensors in every supported framework. 
+ + @abstractmethod + def bitwise_and(self, array1: Array | float, array2: Array | float) -> Array: + """Element-wise bitwise/logical AND.""" + + # Operators ----------------------------------------------------------- + + @abstractmethod + def sign(self, array: Array) -> Array: + """Element-wise sign.""" + + @abstractmethod + def maximum(self, array1: Array | float, array2: Array | float) -> Array: + """Element-wise maximum.""" + + @abstractmethod + def argmax(self, array: Array, axis: int | None = None, keepdims: bool = False) -> Array: + """Index of maximum value along ``axis``.""" + + @abstractmethod + def argmin(self, array: Array, axis: int | None = None, keepdims: bool = False) -> Array: + """Index of minimum value along ``axis``.""" + + @abstractmethod + def set_item(self, array: Array, key: ArrayKey, value: Array) -> None: + """Set ``array[key] = value``.""" + + @abstractmethod + def get_item(self, array: Array, key: ArrayKey) -> Array: + """Return ``array[key]``.""" + + # RNG ----------------------------------------------------------------- + + @abstractmethod + def set_seed(self, seed: int) -> None: + """Seed the backend's RNG with ``seed``.""" + + @abstractmethod + def get_rng_state(self) -> dict[str, Any]: + """Return a snapshot of the backend's RNG state.""" + + @abstractmethod + def set_rng_state(self, state: dict[str, Any]) -> None: + """Restore an RNG snapshot produced by :meth:`get_rng_state`.""" + + @abstractmethod + def normal(self, mean: float = 0.0, std: float = 1.0, shape: tuple[int, ...] = ()) -> Array: + """Draw normally distributed samples.""" + + @abstractmethod + def uniform(self, low: float = 0.0, high: float = 1.0, shape: tuple[int, ...] 
= ()) -> Array: + """Draw uniformly distributed samples from ``[low, high)``.""" + + @abstractmethod + def normal_like(self, array: Array, mean: float = 0.0, std: float = 1.0) -> Array: + """Draw normally distributed samples shaped like ``array``.""" + + @abstractmethod + def uniform_like(self, array: Array, low: float = 0.0, high: float = 1.0) -> Array: + """Draw uniformly distributed samples shaped like ``array``.""" + + @abstractmethod + def choice(self, array: Array, size: int, replace: bool = True) -> Array: + """Sample ``size`` elements from ``array``.""" diff --git a/decent_array/interoperability/_backend_manager.py b/decent_array/interoperability/_backend_manager.py new file mode 100644 index 0000000..3092958 --- /dev/null +++ b/decent_array/interoperability/_backend_manager.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +import importlib +from collections.abc import Callable +from contextvars import ContextVar + +from decent_array.types import SupportedDevices, SupportedFrameworks + +from ._abstracts import Backend + +_BACKEND_REGISTRY: dict[SupportedFrameworks, type[Backend]] = {} +_BACKEND_INSTANCES: dict[SupportedFrameworks, Backend] = {} +_ACTIVE_BACKEND: ContextVar[SupportedFrameworks | None] = ContextVar( + "decent_array.interoperability.active_backend", default=None +) +_BACKEND_LISTENERS: list[Callable[[Backend | None], None]] = [] +_BACKEND_INSTANCE: Backend | None = None + + +def set_backend( + backend: SupportedFrameworks | str, + device: SupportedDevices | str = SupportedDevices.CPU, +) -> None: + """ + Set the active backend (and target device) for the current execution context. + + The first call binds both the backend and the device; subsequent calls must use the + same backend *and* the same device or a :class:`RuntimeError` is raised. 
This
+    single-backend, single-device invariant lets the rest of the interoperability layer
+    skip framework dispatch and isinstance checks, and lets backends construct array
+    creation routines bound to a specific accelerator.
+
+    Backend modules are auto-imported on demand.
+
+    Args:
+        backend: A :class:`~decent_array.types.SupportedFrameworks` value or its canonical
+            string (e.g. ``"numpy"``, ``"pytorch"``). The backend module is auto-imported
+            on first activation; unknown names raise :class:`KeyError`.
+        device: Target accelerator. Accepts a :class:`~decent_array.types.SupportedDevices` value or its
+            string equivalent (``"cpu"``, ``"gpu"``, ``"mps"``). Defaults to CPU. The
+            backend's array-creation methods produce arrays on this device by default.
+
+    Note:
+        Raises :class:`ImportError` if the backend module cannot be imported (e.g. due to
+        a missing optional dependency).
+
+    Raises:
+        KeyError: If ``backend`` does not name a supported framework.
+        RuntimeError: If a different backend (or the same backend with a different
+            device) is already active in this context.
+
+    """
+    requested = _normalize(backend)
+    requested_device = device if isinstance(device, SupportedDevices) else SupportedDevices(device)
+
+    current = _ACTIVE_BACKEND.get()
+    if current is not None and current != requested:
+        raise RuntimeError(
+            f"Backend already set to '{current.value}', cannot set to '{requested.value}'. "
+            "A single execution context may only use one backend."
+        )
+
+    cached = _instantiate(requested, requested_device)
+    if cached.device != requested_device:
+        raise RuntimeError(
+            f"Backend '{requested.value}' already configured with device "
+            f"'{cached.device.value}', cannot reconfigure to '{requested_device.value}'."
+        )
+
+    if current is None:
+        _ACTIVE_BACKEND.set(requested)
+        global _BACKEND_INSTANCE  # noqa: PLW0603
+        _BACKEND_INSTANCE = cached
+        for listener in _BACKEND_LISTENERS:
+            listener(_BACKEND_INSTANCE)
+
+
+def register_backend_listener(listener: Callable[[Backend | None], None]) -> None:
+    """
+    Register a callback to be invoked on backend activation.
+
+    The callback receives the newly active backend instance — or ``None`` on reset —
+    as its only argument. If a backend is already active, it is invoked immediately.
+
+    Args:
+        listener: A callable accepting a single :class:`Backend` (or ``None``) argument.
+
+    """
+    _BACKEND_LISTENERS.append(listener)
+    if _BACKEND_INSTANCE is not None:
+        listener(_BACKEND_INSTANCE)
+
+
+def register_backend(
+    backend: SupportedFrameworks,
+    cls: type[Backend],
+) -> None:
+    """
+    Register a backend class under a :class:`SupportedFrameworks` value.
+
+    Called once per backend module *after* the class definition (rather than as a
+    class decorator). Decorator-based registration would mark the decorated class as
+    non-extension under mypyc, blocking native compiled-to-compiled dispatch on
+    ``_BACKEND.add(...)`` and friends — the call-form keeps concrete backends as
+    extension classes.
+
+    Backends are instantiated lazily on first use. Re-registering replaces the
+    previous class and discards any cached instance so the next activation
+    constructs the newly registered class.
+
+    Args:
+        backend: Canonical backend identifier.
+        cls: A concrete subclass of :class:`Backend`.
+
+    Raises:
+        TypeError: If ``cls`` is not a subclass of :class:`Backend`.
+
+    """
+    if not issubclass(cls, Backend):
+        raise TypeError(f"Registered backend must be a subclass of Backend, got {cls}")
+    _BACKEND_REGISTRY[backend] = cls
+    _BACKEND_INSTANCES.pop(backend, None)
+
+
+def reset_backends() -> None:
+    """
+    Clear the active backend and all cached instances for the current context.
+ + Intended for tests or tightly scoped execution; not part of normal use. Registry + entries (classes and aliases) are preserved. + """ + global _BACKEND_INSTANCE # noqa: PLW0603 + _ACTIVE_BACKEND.set(None) + _BACKEND_INSTANCES.clear() + _BACKEND_INSTANCE = None + for listener in _BACKEND_LISTENERS: + listener(None) + + +def _normalize(backend: SupportedFrameworks | str) -> SupportedFrameworks: + if isinstance(backend, SupportedFrameworks): + return backend + try: + return SupportedFrameworks(backend) + except ValueError as exc: + valid = ", ".join(f.value for f in SupportedFrameworks) + raise KeyError(f"Unknown backend '{backend}'. Valid backends: {valid}.") from exc + + +def _instantiate(backend: SupportedFrameworks, device: SupportedDevices) -> Backend: + if backend in _BACKEND_INSTANCES: + return _BACKEND_INSTANCES[backend] + + if backend not in _BACKEND_REGISTRY: + _auto_import(backend) + + cls = _BACKEND_REGISTRY.get(backend) + if cls is None: + raise KeyError( + f"Backend '{backend.value}' is not registered. Ensure the corresponding backend module is importable." + ) + + instance = cls(device=device) + _BACKEND_INSTANCES[backend] = instance + return instance + + +def _auto_import(backend: SupportedFrameworks) -> None: + """ + Import the backend's package so its registration side-effect runs. + + Raises: + ImportError: If the backend module cannot be imported. + + """ + current_module = __name__.rsplit(".", 1)[0] + module_name = current_module + f"._{backend.value}" + try: + importlib.import_module(module_name) + except ImportError as exc: + raise ImportError( + f"Failed to import the backend module for '{backend.value}'. Ensure the " + "corresponding backend package is installed and importable." 
+ ) from exc diff --git a/decent_array/interoperability/_decorators.py b/decent_array/interoperability/_decorators.py new file mode 100644 index 0000000..bd0df3f --- /dev/null +++ b/decent_array/interoperability/_decorators.py @@ -0,0 +1,58 @@ +""" +Decorator that bridges :class:`Cost` superclass signatures with framework-native subclass implementations. + +Single-backend semantics make this decorator dramatically simpler than the v1 version: no +framework dispatch, no cross-framework conversion, no ``to_array_like`` magic — just +unwrap input :class:`~decent_array.Array` values to their native form, call the subclass method, and +re-wrap the return if the superclass declared ``-> Array``. +""" + +from __future__ import annotations + +from collections.abc import Callable +from functools import wraps +from typing import Any, cast + + +def autodecorate_cost_method[T: Callable[..., Any]](superclass_method: T) -> Callable[[Callable[..., Any]], T]: + """ + Decorate a Cost method override so its body can use raw framework arrays. + + Each :class:`~decent_array.Array` argument is unwrapped to its underlying value before the call. + If the *superclass* method's return annotation is :class:`~decent_array.Array`, the return value + is re-wrapped in :class:`~decent_array.Array` (unless already wrapped). All other arguments and + return values pass through unchanged. + + Args: + superclass_method: The base-class method being overridden (e.g. ``Cost.function``). + Used solely to look up the declared return type at decoration time. + + Example: + .. code-block:: python + + class LinearRegressionCost(EmpiricalRiskCost): + @autodecorate_cost_method(EmpiricalRiskCost.gradient) + def gradient(self, x: NDArray[float64], indices: ...) -> NDArray[float64]: + # ``x`` arrives as a numpy ndarray; the wrapper unwraps the caller's Array. 
+ return self.A.T @ (self.A @ x - self.b) / self.n_samples + # Return value is wrapped back into Array because EmpiricalRiskCost.gradient + # is annotated ``-> Array``. + + """ + from decent_array import Array # noqa: PLC0415 + + return_is_array = superclass_method.__annotations__.get("return") is Array + + def decorator(func: Callable[..., Any]) -> T: + @wraps(func) + def wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: # noqa: ANN401 + new_args = [a.value if isinstance(a, Array) else a for a in args] + new_kwargs = {k: (v.value if isinstance(v, Array) else v) for k, v in kwargs.items()} + result = func(self, *new_args, **new_kwargs) + if return_is_array and not isinstance(result, Array): + return Array(result) + return result + + return cast("T", wrapper) + + return decorator diff --git a/decent_array/interoperability/_iop/__init__.py b/decent_array/interoperability/_iop/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/decent_array/interoperability/_iop/functions.py b/decent_array/interoperability/_iop/functions.py new file mode 100644 index 0000000..9239a3f --- /dev/null +++ b/decent_array/interoperability/_iop/functions.py @@ -0,0 +1,480 @@ +""" +Module-level interoperability functions. + +Each function delegates to the active backend cached in this module's ``_BACKEND_INSTANCE`` +slot. The slot is rebound by :func:`decent_array.interoperability.set_backend`. +Calling any of these before ``set_backend`` raises +:class:`RuntimeError`. + +When this module and ``Backend`` are mypyc-compiled in the same group, +``_BACKEND_INSTANCE.add(...)`` dispatches as a native compiled-to-compiled call — no Python +attribute lookup, no bound-method allocation per call. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING, Any + +from decent_array.interoperability._backend_manager import register_backend_listener + +if TYPE_CHECKING: + from numpy.typing import NDArray + + from decent_array import Array + from decent_array.interoperability._abstracts import Backend + from decent_array.types import ArrayKey, SupportedDevices + +_BACKEND_INSTANCE: Backend | None = None +_error = RuntimeError("No backend active: call 'set_backend' with a supported framework to activate one.") + + +def _update_backend(backend: Backend | None) -> None: + global _BACKEND_INSTANCE # noqa: PLW0603 + _BACKEND_INSTANCE = backend + + +register_backend_listener(_update_backend) + +# Array creation + + +def zeros(shape: tuple[int, ...]) -> Array: + """Create an array of zeros with the given shape.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.zeros(shape) + + +def zeros_like(array: Array) -> Array: + """Create an array of zeros matching the shape and type of ``array``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.zeros_like(array) + + +def ones(shape: tuple[int, ...]) -> Array: + """Create an array of ones with the given shape.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.ones(shape) + + +def ones_like(array: Array) -> Array: + """Create an array of ones matching the shape and type of ``array``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.ones_like(array) + + +def eye(n: int) -> Array: + """Create an ``n x n`` identity matrix.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.eye(n) + + +def eye_like(array: Array) -> Array: + """Create an identity matrix matching the trailing 2 dims of ``array``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.eye_like(array) + + +def device_to_native(device: SupportedDevices) -> Any: 
# noqa: ANN401 + """Convert :class:`~decent_array.types.SupportedDevices` to the active backend's native device.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.device_to_native(device) + + +def device_of(array: Array) -> SupportedDevices: + """Return the :class:`~decent_array.types.SupportedDevices` of ``array``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.device_of(array) + + +# Array manipulation + + +def copy(array: Array) -> Array: + """Return a copy of ``array``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.copy(array) + + +def to_numpy(array: Array) -> NDArray[Any]: + """Convert ``array`` to a NumPy array on CPU.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.to_numpy(array) + + +def from_numpy(array: NDArray[Any]) -> Array: + """Convert a NumPy array on CPU to an :class:`~decent_array.Array` on the active backend.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.from_numpy(array) + + +def from_numpy_like(array: NDArray[Any], like: Array) -> Array: + """Convert a NumPy array to an :class:`~decent_array.Array` matching ``like``'s dtype and device.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.from_numpy_like(array, like) + + +def to_array(array: float | bool) -> Array: + """Convert a Python scalar to an :class:`~decent_array.Array` on the active backend.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.to_array(array) + + +def stack(arrays: Sequence[Array], axis: int = 0) -> Array: + """Stack a sequence of arrays along a new dimension.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.stack(arrays, axis) + + +def reshape(array: Array, shape: tuple[int, ...]) -> Array: + """Reshape ``array`` to ``shape``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.reshape(array, shape) + + +def 
transpose(array: Array, axis: tuple[int, ...] | None = None) -> Array: + """Transpose ``array``; ``None`` reverses dimensions.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.transpose(array, axis) + + +def shape(array: Array) -> tuple[int, ...]: + """Return the shape of ``array``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.shape(array) + + +def size(array: Array) -> int: + """Return the total number of elements in ``array``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.size(array) + + +def ndim(array: Array) -> int: + """Return the number of dimensions of ``array``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.ndim(array) + + +def squeeze(array: Array, axis: int | tuple[int, ...] | None = None) -> Array: + """Remove single-dimensional entries from ``array``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.squeeze(array, axis) + + +def unsqueeze(array: Array, axis: int) -> Array: + """Insert a singleton dimension at ``axis``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.unsqueeze(array, axis) + + +def diag(array: Array) -> Array: + """Diagonal: build from a vector or extract from a matrix.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.diag(array) + + +def astype(array: Array, dtype: type[float | int | bool]) -> float | int | bool: + """Cast a single-element ``array`` to a Python scalar of ``dtype``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.astype(array, dtype) + + +# Linalg + + +def dot(array1: Array, array2: Array) -> Array: + """Dot product of two arrays.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.dot(array1, array2) + + +def matmul(array1: Array, array2: Array) -> Array: + """Matrix multiplication of two arrays.""" + if _BACKEND_INSTANCE is None: + raise _error + return 
_BACKEND_INSTANCE.matmul(array1, array2) + + +def norm( + array: Array, + p: float = 2, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, +) -> Array: + """Norm of ``array``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.norm(array, p, axis, keepdims) + + +# Math reductions + + +def sum( # noqa: A001 + array: Array, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, +) -> Array: + """Sum elements of ``array`` along ``axis``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.sum(array, axis, keepdims) + + +def mean( + array: Array, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, +) -> Array: + """Mean of ``array`` along ``axis``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.mean(array, axis, keepdims) + + +def min( # noqa: A001 + array: Array, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, +) -> Array: + """Minimum of ``array`` along ``axis``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.min(array, axis, keepdims) + + +def max( # noqa: A001 + array: Array, + axis: int | tuple[int, ...] 
| None = None, + keepdims: bool = False, +) -> Array: + """Maximum of ``array`` along ``axis``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.max(array, axis, keepdims) + + +def any(array: Array) -> bool: # noqa: A001 + """Return True if any element of ``array`` is truthy.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.any(array) + + +def all(array: Array) -> bool: # noqa: A001 + """Return True if all elements of ``array`` are truthy.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.all(array) + + +# Math elementwise + + +def add(array1: Array | float, array2: Array | float) -> Array: + """Element-wise addition.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.add(array1, array2) + + +def iadd[T: Array](array1: T, array2: Array | float) -> T: + """In-place element-wise addition.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.iadd(array1, array2) + + +def sub(array1: Array | float, array2: Array | float) -> Array: + """Element-wise subtraction.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.sub(array1, array2) + + +def isub[T: Array](array1: T, array2: Array | float) -> T: + """In-place element-wise subtraction.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.isub(array1, array2) + + +def mul(array1: Array | float, array2: Array | float) -> Array: + """Element-wise multiplication.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.mul(array1, array2) + + +def imul[T: Array](array1: T, array2: Array | float) -> T: + """In-place element-wise multiplication.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.imul(array1, array2) + + +def div(array1: Array | float, array2: Array | float) -> Array: + """Element-wise division.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.div(array1, 
array2) + + +def idiv[T: Array](array1: T, array2: Array | float) -> T: + """In-place element-wise division.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.idiv(array1, array2) + + +def pow(array: Array, p: float) -> Array: # noqa: A001 + """Raise ``array`` to power ``p``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.pow(array, p) + + +def negative(array: Array) -> Array: + """Element-wise negation.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.negative(array) + + +def absolute(array: Array) -> Array: + """Element-wise absolute value.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.absolute(array) + + +def sqrt(array: Array) -> Array: + """Element-wise square root.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.sqrt(array) + + +# Comparisons + + +def eq(array1: Array | float, array2: Array | float) -> Array: + """Element-wise equality. Returns an :class:`~decent_array.Array` of bools.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.eq(array1, array2) + + +def ne(array1: Array | float, array2: Array | float) -> Array: + """Element-wise inequality. Returns an :class:`~decent_array.Array` of bools.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.ne(array1, array2) + + +def lt(array1: Array | float, array2: Array | float) -> Array: + """Element-wise less-than. Returns an :class:`~decent_array.Array` of bools.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.lt(array1, array2) + + +def le(array1: Array | float, array2: Array | float) -> Array: + """Element-wise less-than-or-equal. Returns an :class:`~decent_array.Array` of bools.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.le(array1, array2) + + +def gt(array1: Array | float, array2: Array | float) -> Array: + """Element-wise greater-than. 
Returns an :class:`~decent_array.Array` of bools.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.gt(array1, array2) + + +def ge(array1: Array | float, array2: Array | float) -> Array: + """Element-wise greater-than-or-equal. Returns an :class:`~decent_array.Array` of bools.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.ge(array1, array2) + + +# Bitwise + + +def bitwise_and(array1: Array | float, array2: Array | float) -> Array: + """Element-wise bitwise/logical AND.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.bitwise_and(array1, array2) + + +# Operators + + +def sign(array: Array) -> Array: + """Element-wise sign.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.sign(array) + + +def maximum(array1: Array | float, array2: Array | float) -> Array: + """Element-wise maximum.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.maximum(array1, array2) + + +def argmax(array: Array, axis: int | None = None, keepdims: bool = False) -> Array: + """Index of maximum value along ``axis``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.argmax(array, axis, keepdims) + + +def argmin(array: Array, axis: int | None = None, keepdims: bool = False) -> Array: + """Index of minimum value along ``axis``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.argmin(array, axis, keepdims) + + +def set_item(array: Array, key: ArrayKey, value: Array) -> None: + """Set ``array[key] = value`` in place.""" + if _BACKEND_INSTANCE is None: + raise _error + _BACKEND_INSTANCE.set_item(array, key, value) + + +def get_item(array: Array, key: ArrayKey) -> Array: + """Return ``array[key]``.""" + if _BACKEND_INSTANCE is None: + raise _error + return _BACKEND_INSTANCE.get_item(array, key) diff --git a/decent_array/interoperability/_iop/rng.py b/decent_array/interoperability/_iop/rng.py new file mode 100644 
"""
Random-number coordination across backends.

The active backend handles its own RNG, but two extra concerns sit above it:

1. Python's :mod:`random` is often used incidentally and must also be seeded.
2. NumPy's RNG is frequently consulted by other frameworks (e.g. dataset shuffling
   helpers, scikit-learn pre-processing) regardless of the active backend, so its state
   must be tracked and restored alongside the active backend's state.

:class:`_RngCoordinator` owns both concerns. RNG functions exposed by ``_iop`` route
through a process-singleton coordinator.

When the active backend *is* numpy, the coordinator avoids double-seeding to keep the
RNG-state snapshot self-consistent.
"""

from __future__ import annotations

import random
from typing import TYPE_CHECKING, Any

from decent_array.interoperability._backend_manager import _instantiate, register_backend_listener
from decent_array.types import SupportedDevices, SupportedFrameworks

if TYPE_CHECKING:
    from decent_array import Array
    from decent_array.interoperability._abstracts import Backend


# Reserved keys under which auxiliary states are embedded in an RNG snapshot dict.
_NUMPY_STATE_KEY = "__numpy_rng_state__"
_PYTHON_RANDOM_KEY = "__python_random_state__"
# Mirror of the active backend, kept in sync via the listener registered below.
_BACKEND_INSTANCE: Backend | None = None
_error = RuntimeError("No backend active: call 'set_backend' with a supported framework to activate one.")


def _update_backend(backend: Backend | None) -> None:
    """Listener: track the currently active backend in this module."""
    global _BACKEND_INSTANCE  # noqa: PLW0603
    _BACKEND_INSTANCE = backend


register_backend_listener(_update_backend)


class _RngCoordinator:
    """Coordinate RNG seeding/state across the active backend, NumPy, and Python's random."""

    def __init__(self) -> None:
        # Last seed passed to ``set_seed`` with ``set_global_seed=True``; None until then.
        self._global_seed: int | None = None

    def set_seed(self, seed: int, *, set_global_seed: bool = True) -> None:
        """
        Seed Python's ``random``, NumPy's RNG, and the active backend's RNG.

        Args:
            seed: Base seed.
            set_global_seed: If False, leaves :func:`get_seed` untouched. Use this for
                trial-local reseeding where the externally observable base seed must be
                preserved.

        """
        if _BACKEND_INSTANCE is None:
            raise _error

        random.seed(seed)
        active = _BACKEND_INSTANCE
        active.set_seed(seed)
        numpy_backend = self._numpy_backend()
        # Skip the auxiliary seeding when numpy *is* the active backend, so it is
        # not seeded twice and the snapshot stays self-consistent.
        if numpy_backend is not active:
            numpy_backend.set_seed(seed)
        if set_global_seed:
            self._global_seed = seed

    def get_seed(self) -> int | None:
        """Return the seed last passed to :meth:`set_seed` (with ``set_global_seed=True``)."""
        return self._global_seed

    def get_rng_state(self) -> dict[str, Any]:
        """
        Snapshot the RNG state of the active backend, NumPy (if auxiliary), and Python's random.

        The active backend's state is returned as-is. If the active backend is not NumPy,
        NumPy's state is embedded under the reserved key ``"__numpy_rng_state__"``. The
        Python ``random`` state is always embedded under ``"__python_random_state__"`` so
        that incidental ``random.random()`` calls survive a snapshot/restore round-trip.

        """
        if _BACKEND_INSTANCE is None:
            raise _error

        active = _BACKEND_INSTANCE
        # The backend-provided dict is extended in place with the reserved keys; they
        # are popped off again by ``set_rng_state`` before handing it back.
        state = active.get_rng_state()
        state[_PYTHON_RANDOM_KEY] = random.getstate()
        numpy_backend = self._numpy_backend()
        if numpy_backend is not active:
            state[_NUMPY_STATE_KEY] = numpy_backend.get_rng_state()
        return state

    def set_rng_state(self, state: dict[str, Any]) -> None:
        """Restore a snapshot produced by :meth:`get_rng_state`."""
        if _BACKEND_INSTANCE is None:
            raise _error

        # Copy so we can mutate without surprising the caller.
        state = dict(state)
        python_state = state.pop(_PYTHON_RANDOM_KEY, None)
        if python_state is not None:
            random.setstate(python_state)
        active = _BACKEND_INSTANCE
        numpy_backend = self._numpy_backend()
        if numpy_backend is not active:
            numpy_state = state.pop(_NUMPY_STATE_KEY, None)
            if numpy_state is not None:
                numpy_backend.set_rng_state(numpy_state)
        # Whatever remains after stripping the reserved keys is the active backend's own state.
        active.set_rng_state(state)

    def _numpy_backend(self) -> Backend:
        # NOTE(review): the ``is`` comparisons above assume ``_instantiate`` returns a
        # cached instance per (framework, device) — confirm in ``_backend_manager``.
        return _instantiate(SupportedFrameworks.NUMPY, SupportedDevices.CPU)


_COORDINATOR = _RngCoordinator()


def set_seed(seed: int) -> None:
    """Seed Python ``random``, NumPy, and the active backend's RNG with ``seed``."""
    _COORDINATOR.set_seed(seed)


def _set_seed_without_global(seed: int) -> None:
    """
    Seed without changing the value returned by :func:`get_seed`.

    Used for trial-local reseeding where the externally observable base seed must be preserved.
    """
    _COORDINATOR.set_seed(seed, set_global_seed=False)


def _reset_rng() -> None:
    """Reset RNG state to a fresh state."""
    global _COORDINATOR  # noqa: PLW0603
    _COORDINATOR = _RngCoordinator()


def get_seed() -> int | None:
    """Return the most recently set global seed, or ``None`` if unset."""
    return _COORDINATOR.get_seed()


def get_rng_state() -> dict[str, Any]:
    """Return a snapshot of the active backend's RNG state."""
    return _COORDINATOR.get_rng_state()


def set_rng_state(state: dict[str, Any]) -> None:
    """Restore an RNG snapshot produced by :func:`get_rng_state`."""
    _COORDINATOR.set_rng_state(state)


def derive_seed() -> int:
    """
    Derive a new seed from the current state.

    This is useful when you want to create a new generator that is independent but reproducible from the current one.
    For example, you might use this to seed a data loader's RNG based on the main RNG to ensure that data shuffling is
    reproducible across runs, but different from the main RNG used for model initialization.

    Returns:
        An integer seed derived from the current RNG state.

    """
    current_seed = get_seed()
    if current_seed is None:
        return random.randint(0, 2**32 - 1)
    # Mix fresh random bits into the current seed. NOTE(review): this is modular
    # addition, not a hash as previously described; it still yields a reproducible,
    # distinct seed given a fixed base seed and Python-random state.
    random_data = random.getrandbits(256)
    return (current_seed + random_data) % (2**32)


def normal(mean: float = 0.0, std: float = 1.0, shape: tuple[int, ...] = ()) -> Array:
    """Draw normally distributed samples on the active backend."""
    if _BACKEND_INSTANCE is None:
        raise _error
    return _BACKEND_INSTANCE.normal(mean, std, shape)


def uniform(low: float = 0.0, high: float = 1.0, shape: tuple[int, ...] = ()) -> Array:
    """Draw uniformly distributed samples on the active backend."""
    if _BACKEND_INSTANCE is None:
        raise _error
    return _BACKEND_INSTANCE.uniform(low, high, shape)


def normal_like(array: Array, mean: float = 0.0, std: float = 1.0) -> Array:
    """Draw normally distributed samples shaped like ``array``."""
    if _BACKEND_INSTANCE is None:
        raise _error
    return _BACKEND_INSTANCE.normal_like(array, mean, std)


def uniform_like(array: Array, low: float = 0.0, high: float = 1.0) -> Array:
    """Draw uniformly distributed samples shaped like ``array``."""
    if _BACKEND_INSTANCE is None:
        raise _error
    return _BACKEND_INSTANCE.uniform_like(array, low, high)


def choice(array: Array, size: int, replace: bool = True) -> Array:
    """Sample ``size`` elements from ``array``."""
    if _BACKEND_INSTANCE is None:
        raise _error
    return _BACKEND_INSTANCE.choice(array, size, replace)
a/decent_array/interoperability/_jax/jax_backend.py b/decent_array/interoperability/_jax/jax_backend.py new file mode 100644 index 0000000..4dbd187 --- /dev/null +++ b/decent_array/interoperability/_jax/jax_backend.py @@ -0,0 +1,301 @@ +""" +JAX backend for interoperability_2. + +Importing this module registers the backend via :func:`register_backend`, so the +package can be auto-loaded on the first ``set_backend("jax")`` call. + +JAX arrays are immutable, so :meth:`set_item` rebinds the wrapper's underlying value +rather than mutating it. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from time import time_ns +from typing import Any, cast + +import jax +import jax.numpy as jnp +import numpy as np +from numpy.typing import NDArray + +from decent_array import Array +from decent_array.interoperability._abstracts import Backend +from decent_array.interoperability._backend_manager import register_backend +from decent_array.types import ArrayKey, SupportedDevices, SupportedFrameworks + + +def _unwrap(array: Any) -> Any: # noqa: ANN401 + """Return the underlying value of an :class:`Array`, or pass ``array`` through.""" + return array.value if type(array) is Array else array + + +class JaxBackend(Backend): # noqa: PLR0904 + """JAX implementation of :class:`Backend`.""" + + def __init__(self, device: SupportedDevices = SupportedDevices.CPU) -> None: + super().__init__(device) + self._native_device: jax.Device = self.device_to_native(device) + self._key: jax.Array = jax.random.key(time_ns()) + + # Array creation + + def zeros(self, shape: tuple[int, ...]) -> Array: + return Array(jnp.zeros(shape, device=self._native_device)) + + def zeros_like(self, array: Array) -> Array: + return Array(jnp.zeros_like(array.value)) + + def ones(self, shape: tuple[int, ...]) -> Array: + return Array(jnp.ones(shape, device=self._native_device)) + + def ones_like(self, array: Array) -> Array: + return Array(jnp.ones_like(array.value)) + + def eye(self, n: 
int) -> Array: + return Array(jnp.eye(n, device=self._native_device)) + + def eye_like(self, array: Array) -> Array: + v = array.value + rows, cols = v.shape[-2:] + return Array(jnp.eye(rows, cols, dtype=v.dtype, device=v.device)) + + def device_to_native(self, device: SupportedDevices) -> jax.Device: + if device == SupportedDevices.CPU: + return jax.devices("cpu")[0] + if device == SupportedDevices.GPU: + return jax.devices("gpu")[0] + raise ValueError(f"Unsupported device for JAX: {device}") + + def device_of(self, array: Array) -> SupportedDevices: + platform = array.value.device.platform + if platform == "gpu": + return SupportedDevices.GPU + if platform == "cpu": + return SupportedDevices.CPU + raise TypeError(f"Unsupported JAX platform: {platform}") + + # Array manipulation + + def copy(self, array: Array) -> Array: + return Array(jnp.array(array.value, copy=True)) + + def to_numpy(self, array: Array) -> NDArray[Any]: + return np.array(array.value) + + def from_numpy(self, array: NDArray[Any]) -> Array: + return Array(jnp.array(array, device=self._native_device)) + + def from_numpy_like(self, array: NDArray[Any], like: Array) -> Array: + v = like.value + return Array(jnp.asarray(array, dtype=v.dtype, device=v.device)) + + def to_array(self, array: float | bool) -> Array: + return Array(jnp.array(array, device=self._native_device)) + + def stack(self, arrays: Sequence[Array], axis: int = 0) -> Array: + if len(arrays) == 0: + raise ValueError("Cannot stack an empty sequence of arrays.") + return Array(jnp.stack([a.value for a in arrays], axis=axis)) + + def reshape(self, array: Array, shape: tuple[int, ...]) -> Array: + return Array(jnp.reshape(array.value, shape)) + + def transpose(self, array: Array, axis: tuple[int, ...] 
| None = None) -> Array: + return Array(jnp.transpose(array.value, axes=axis)) + + def shape(self, array: Array) -> tuple[int, ...]: + return tuple(array.value.shape) + + def size(self, array: Array) -> int: + return int(array.value.size) + + def ndim(self, array: Array) -> int: + return int(array.value.ndim) + + def squeeze(self, array: Array, axis: int | tuple[int, ...] | None = None) -> Array: + return Array(jnp.squeeze(array.value, axis=axis)) + + def unsqueeze(self, array: Array, axis: int) -> Array: + return Array(jnp.expand_dims(array.value, axis=axis)) + + def diag(self, array: Array) -> Array: + return Array(jnp.diag(array.value)) + + def astype(self, array: Array, dtype: type[float | int | bool]) -> float | int | bool: + return dtype(array.value.item()) + + # Linalg + + def dot(self, array1: Array, array2: Array) -> Array: + return Array(jnp.dot(array1.value, array2.value)) + + def matmul(self, array1: Array, array2: Array) -> Array: + return Array(array1.value @ array2.value) + + def norm( + self, + array: Array, + p: float = 2, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, + ) -> Array: + return Array(jnp.linalg.norm(array.value, ord=p, axis=axis, keepdims=keepdims)) + + # Math reductions + + def sum(self, array: Array, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: + return Array(jnp.sum(array.value, axis=axis, keepdims=keepdims)) + + def mean(self, array: Array, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: + return Array(jnp.mean(array.value, axis=axis, keepdims=keepdims)) + + def min(self, array: Array, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: + return Array(jnp.min(array.value, axis=axis, keepdims=keepdims)) + + def max(self, array: Array, axis: int | tuple[int, ...] 
| None = None, keepdims: bool = False) -> Array: + return Array(jnp.max(array.value, axis=axis, keepdims=keepdims)) + + def any(self, array: Array) -> bool: + return bool(jnp.any(array.value)) + + def all(self, array: Array) -> bool: + return bool(jnp.all(array.value)) + + # Math elementwise — JAX arrays are immutable; "in-place" ops rebind the wrapper. + # Operands may be Array or scalar (operator dunders pass either); ``Array | float`` + # covers both because PEP 484's numeric tower implicitly admits ``int``. + + def add(self, array1: Array | float, array2: Array | float) -> Array: + return Array(jnp.add(_unwrap(array1), _unwrap(array2))) + + def iadd[T: Array](self, array1: T, array2: Array | float) -> T: + array1.value = jnp.add(array1.value, _unwrap(array2)) + return array1 + + def sub(self, array1: Array | float, array2: Array | float) -> Array: + return Array(jnp.subtract(_unwrap(array1), _unwrap(array2))) + + def isub[T: Array](self, array1: T, array2: Array | float) -> T: + array1.value = jnp.subtract(array1.value, _unwrap(array2)) + return array1 + + def mul(self, array1: Array | float, array2: Array | float) -> Array: + return Array(jnp.multiply(_unwrap(array1), _unwrap(array2))) + + def imul[T: Array](self, array1: T, array2: Array | float) -> T: + array1.value = jnp.multiply(array1.value, _unwrap(array2)) + return array1 + + def div(self, array1: Array | float, array2: Array | float) -> Array: + return Array(jnp.divide(_unwrap(array1), _unwrap(array2))) + + def idiv[T: Array](self, array1: T, array2: Array | float) -> T: + array1.value = jnp.divide(array1.value, _unwrap(array2)) + return array1 + + def pow(self, array: Array, p: float) -> Array: + return Array(jnp.power(array.value, p)) + + def negative(self, array: Array) -> Array: + return Array(jnp.negative(array.value)) + + def absolute(self, array: Array) -> Array: + return Array(jnp.abs(array.value)) + + def sqrt(self, array: Array) -> Array: + return Array(jnp.sqrt(array.value)) + + # 
Comparisons + + def eq(self, array1: Array | float, array2: Array | float) -> Array: + return Array(jnp.equal(_unwrap(array1), _unwrap(array2))) + + def ne(self, array1: Array | float, array2: Array | float) -> Array: + return Array(jnp.not_equal(_unwrap(array1), _unwrap(array2))) + + def lt(self, array1: Array | float, array2: Array | float) -> Array: + return Array(jnp.less(_unwrap(array1), _unwrap(array2))) + + def le(self, array1: Array | float, array2: Array | float) -> Array: + return Array(jnp.less_equal(_unwrap(array1), _unwrap(array2))) + + def gt(self, array1: Array | float, array2: Array | float) -> Array: + return Array(jnp.greater(_unwrap(array1), _unwrap(array2))) + + def ge(self, array1: Array | float, array2: Array | float) -> Array: + return Array(jnp.greater_equal(_unwrap(array1), _unwrap(array2))) + + # Bitwise + + def bitwise_and(self, array1: Array | float, array2: Array | float) -> Array: + return Array(jnp.bitwise_and(_unwrap(array1), _unwrap(array2))) + + # Operators + + def sign(self, array: Array) -> Array: + return Array(jnp.sign(array.value)) + + def maximum(self, array1: Array | float, array2: Array | float) -> Array: + return Array(jnp.maximum(_unwrap(array1), _unwrap(array2))) + + def argmax(self, array: Array, axis: int | None = None, keepdims: bool = False) -> Array: + return Array(jnp.argmax(array.value, axis=axis, keepdims=keepdims)) + + def argmin(self, array: Array, axis: int | None = None, keepdims: bool = False) -> Array: + return Array(jnp.argmin(array.value, axis=axis, keepdims=keepdims)) + + def set_item(self, array: Array, key: ArrayKey, value: Array) -> None: + # JAX arrays are immutable; rebind the wrapper to a new array with `key` updated. 
+ array.value = array.value.at[key].set(value.value) + + def get_item(self, array: Array, key: ArrayKey) -> Array: + return Array(array.value[key]) + + # RNG + + def set_seed(self, seed: int) -> None: + self._key = jax.random.key(seed) + + def get_rng_state(self) -> dict[str, Any]: + return {"jax_key": jax.random.key_data(self._key)} + + def set_rng_state(self, state: dict[str, Any]) -> None: + if "jax_key" in state: + self._key = jax.random.wrap_key_data(state["jax_key"]) + + def normal(self, mean: float = 0.0, std: float = 1.0, shape: tuple[int, ...] = ()) -> Array: + sub = self._next_key() + sample = jax.random.normal(sub, shape=shape).to_device(self._native_device) + return Array(mean + std * sample) + + def uniform(self, low: float = 0.0, high: float = 1.0, shape: tuple[int, ...] = ()) -> Array: + sub = self._next_key() + return Array(jax.random.uniform(sub, shape=shape, minval=low, maxval=high).to_device(self._native_device)) + + def normal_like(self, array: Array, mean: float = 0.0, std: float = 1.0) -> Array: + v = array.value + sub = self._next_key() + sample = jax.random.normal(sub, shape=v.shape, dtype=v.dtype) + return Array(mean + std * sample) + + def uniform_like(self, array: Array, low: float = 0.0, high: float = 1.0) -> Array: + v = array.value + sub = self._next_key() + return Array(jax.random.uniform(sub, shape=v.shape, dtype=v.dtype, minval=low, maxval=high)) + + def choice(self, array: Array, size: int, replace: bool = True) -> Array: + v = array.value + sub = self._next_key() + indices = jax.random.choice(sub, a=v.shape[0], shape=(size,), replace=replace) + return Array(v[indices]) + + # internals + + def _next_key(self) -> jax.Array: + """Split the stored key, advance state, return a sub-key for one draw.""" + self._key, sub = jax.random.split(self._key) + return cast("jax.Array", sub) + + +register_backend(SupportedFrameworks.JAX, JaxBackend) diff --git a/decent_array/interoperability/_numpy/__init__.py 
"""
NumPy backend for the decent_array interoperability layer.

Importing this module registers the backend via :func:`register_backend`, so the
package can be auto-loaded on the first ``set_backend("numpy")`` call.
"""

from __future__ import annotations

from collections.abc import Sequence
from copy import deepcopy
from typing import Any

import numpy as np
from numpy.typing import NDArray

from decent_array import Array
from decent_array.interoperability._abstracts import Backend
from decent_array.interoperability._backend_manager import register_backend
from decent_array.types import ArrayKey, SupportedDevices, SupportedFrameworks


def _unwrap(array: Any) -> Any:  # noqa: ANN401
    """
    Return the underlying value of an :class:`Array`, or pass ``array`` through.

    Typed as ``Any`` because operator dunders may pass either an :class:`Array` or a
    Python scalar; the strict abstract signature would force a ``cast`` at every call
    site without runtime benefit.
    """
    return array.value if type(array) is Array else array


class NumpyBackend(Backend):  # noqa: PLR0904
    """NumPy implementation of :class:`Backend`."""

    def __init__(self, device: SupportedDevices = SupportedDevices.CPU) -> None:
        # NumPy has no accelerator support; reject anything but CPU up front.
        if device != SupportedDevices.CPU:
            raise ValueError(f"NumPy backend only supports CPU, got '{device.value}'.")
        super().__init__(device)
        # Backend-owned Generator; reseeded (together with the legacy global state)
        # by ``set_seed``.
        self._rng: np.random.Generator = np.random.default_rng()

    # Array creation

    def zeros(self, shape: tuple[int, ...]) -> Array:
        return Array(np.zeros(shape))

    def zeros_like(self, array: Array) -> Array:
        return Array(np.zeros_like(array.value))

    def ones(self, shape: tuple[int, ...]) -> Array:
        return Array(np.ones(shape))

    def ones_like(self, array: Array) -> Array:
        return Array(np.ones_like(array.value))

    def eye(self, n: int) -> Array:
        return Array(np.eye(n))

    def eye_like(self, array: Array) -> Array:
        # Identity matching the trailing 2 dims and dtype of ``array``.
        v = array.value
        return Array(np.eye(*v.shape[-2:], dtype=v.dtype))

    def device_to_native(self, device: SupportedDevices) -> Any:  # noqa: ANN401
        # NumPy has no explicit device management; surface the request unchanged.
        return device

    def device_of(self, array: Array) -> SupportedDevices:  # noqa: ARG002
        return SupportedDevices.CPU

    # Array manipulation

    def copy(self, array: Array) -> Array:
        # ``np.copy`` for ndarray/scalar values; ``deepcopy`` for anything else wrapped.
        v = array.value
        if isinstance(v, np.ndarray | np.generic):
            return Array(np.copy(v))
        return Array(deepcopy(v))

    def to_numpy(self, array: Array) -> NDArray[Any]:
        """Return the value of an :class:`Array` as a NumPy array."""
        v = array.value
        if isinstance(v, np.ndarray):
            return v
        return np.asarray(v)

    def from_numpy(self, array: NDArray[Any]) -> Array:
        # NOTE(review): wraps the caller's buffer without copying, so later in-place
        # edits alias the input — confirm this sharing is intended.
        return Array(array)

    def from_numpy_like(self, array: NDArray[Any], like: Array) -> Array:
        # NumPy has no device dimension, so only the dtype of ``like`` matters.
        return Array(np.asarray(array, dtype=like.value.dtype))

    def to_array(self, array: float | bool) -> Array:
        return Array(np.array(array))

    def stack(self, arrays: Sequence[Array], axis: int = 0) -> Array:
        if len(arrays) == 0:
            raise ValueError("Cannot stack an empty sequence of arrays.")
        return Array(np.stack([a.value for a in arrays], axis=axis))

    def reshape(self, array: Array, shape: tuple[int, ...]) -> Array:
        return Array(np.reshape(array.value, shape))

    def transpose(self, array: Array, axis: tuple[int, ...] | None = None) -> Array:
        return Array(np.transpose(array.value, axes=axis))

    def shape(self, array: Array) -> tuple[int, ...]:
        return tuple(array.value.shape)

    def size(self, array: Array) -> int:
        return int(array.value.size)

    def ndim(self, array: Array) -> int:
        return int(array.value.ndim)

    def squeeze(self, array: Array, axis: int | tuple[int, ...] | None = None) -> Array:
        return Array(np.squeeze(array.value, axis=axis))

    def unsqueeze(self, array: Array, axis: int) -> Array:
        return Array(np.expand_dims(array.value, axis=axis))

    def diag(self, array: Array) -> Array:
        return Array(np.diag(array.value))

    def astype(self, array: Array, dtype: type[float | int | bool]) -> float | int | bool:
        # Scalar extraction; falls back to the raw value for non-array wrappees.
        v = array.value
        scalar = v.item() if hasattr(v, "item") else v
        return dtype(scalar)

    # Linalg

    def dot(self, array1: Array, array2: Array) -> Array:
        return Array(np.dot(array1.value, array2.value))

    def matmul(self, array1: Array, array2: Array) -> Array:
        return Array(array1.value @ array2.value)

    def norm(
        self,
        array: Array,
        p: float = 2,
        axis: int | tuple[int, ...] | None = None,
        keepdims: bool = False,
    ) -> Array:
        return Array(np.linalg.norm(array.value, ord=p, axis=axis, keepdims=keepdims))

    # Math reductions

    def sum(self, array: Array, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array:
        return Array(np.sum(array.value, axis=axis, keepdims=keepdims))

    def mean(self, array: Array, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array:
        return Array(np.mean(array.value, axis=axis, keepdims=keepdims))

    def min(self, array: Array, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array:
        return Array(np.min(array.value, axis=axis, keepdims=keepdims))

    def max(self, array: Array, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array:
        return Array(np.max(array.value, axis=axis, keepdims=keepdims))

    def any(self, array: Array) -> bool:
        return bool(np.any(array.value))

    def all(self, array: Array) -> bool:
        return bool(np.all(array.value))

    # Math elementwise — operands may be Array or scalar (operator dunders pass either).
    # ``Array | float`` covers both: PEP 484's numeric tower implicitly admits ``int``.

    def add(self, array1: Array | float, array2: Array | float) -> Array:
        return Array(np.add(_unwrap(array1), _unwrap(array2)))

    def iadd[T: Array](self, array1: T, array2: Array | float) -> T:
        # Augmented assignment mutates the wrapped ndarray buffer in place.
        array1.value += _unwrap(array2)
        return array1

    def sub(self, array1: Array | float, array2: Array | float) -> Array:
        return Array(np.subtract(_unwrap(array1), _unwrap(array2)))

    def isub[T: Array](self, array1: T, array2: Array | float) -> T:
        array1.value -= _unwrap(array2)
        return array1

    def mul(self, array1: Array | float, array2: Array | float) -> Array:
        return Array(np.multiply(_unwrap(array1), _unwrap(array2)))

    def imul[T: Array](self, array1: T, array2: Array | float) -> T:
        array1.value *= _unwrap(array2)
        return array1

    def div(self, array1: Array | float, array2: Array | float) -> Array:
        return Array(np.divide(_unwrap(array1), _unwrap(array2)))

    def idiv[T: Array](self, array1: T, array2: Array | float) -> T:
        array1.value /= _unwrap(array2)
        return array1

    def pow(self, array: Array, p: float) -> Array:
        return Array(np.power(array.value, p))

    def negative(self, array: Array) -> Array:
        return Array(np.negative(array.value))

    def absolute(self, array: Array) -> Array:
        return Array(np.abs(array.value))

    def sqrt(self, array: Array) -> Array:
        return Array(np.sqrt(array.value))

    # Comparisons

    def eq(self, array1: Array | float, array2: Array | float) -> Array:
        return Array(np.equal(_unwrap(array1), _unwrap(array2)))

    def ne(self, array1: Array | float, array2: Array | float) -> Array:
        return Array(np.not_equal(_unwrap(array1), _unwrap(array2)))

    def lt(self, array1: Array | float, array2: Array | float) -> Array:
        return Array(np.less(_unwrap(array1), _unwrap(array2)))

    def le(self, array1: Array | float, array2: Array | float) -> Array:
        return Array(np.less_equal(_unwrap(array1), _unwrap(array2)))

    def gt(self, array1: Array | float, array2: Array | float) -> Array:
        return Array(np.greater(_unwrap(array1), _unwrap(array2)))

    def ge(self, array1: Array | float, array2: Array | float) -> Array:
        return Array(np.greater_equal(_unwrap(array1), _unwrap(array2)))

    # Bitwise

    def bitwise_and(self, array1: Array | float, array2: Array | float) -> Array:
        return Array(np.bitwise_and(_unwrap(array1), _unwrap(array2)))

    # Operators

    def sign(self, array: Array) -> Array:
        return Array(np.sign(array.value))

    def maximum(self, array1: Array | float, array2: Array | float) -> Array:
        return Array(np.maximum(_unwrap(array1), _unwrap(array2)))

    def argmax(self, array: Array, axis: int | None = None, keepdims: bool = False) -> Array:
        return Array(np.argmax(array.value, axis=axis, keepdims=keepdims))

    def argmin(self, array: Array, axis: int | None = None, keepdims: bool = False) -> Array:
        return Array(np.argmin(array.value, axis=axis, keepdims=keepdims))

    def set_item(self, array: Array, key: ArrayKey, value: Array) -> None:
        array.value[key] = value.value

    def get_item(self, array: Array, key: ArrayKey) -> Array:
        return Array(array.value[key])

    # RNG

    def set_seed(self, seed: int) -> None:
        # Seed both the legacy global state and our owned Generator. The legacy state is
        # important because some downstream libraries (sklearn, pandas) consult it.
        np.random.seed(seed)  # noqa: NPY002
        self._rng = np.random.default_rng(seed)

    def get_rng_state(self) -> dict[str, Any]:
        # ``legacy=False`` makes ``np.random.get_state`` return the dict form, which
        # matches this method's ``dict[str, Any]`` value type and round-trips cleanly
        # through ``np.random.set_state``. The Generator state is deep-copied so the
        # snapshot does not alias live bit-generator state.
        return {
            "numpy_bit_generator_state": deepcopy(self._rng.bit_generator.state),
            "numpy_legacy_state": np.random.get_state(legacy=False),  # noqa: NPY002
        }

    def set_rng_state(self, state: dict[str, Any]) -> None:
        if "numpy_bit_generator_state" in state:
            self._rng = np.random.default_rng()
            self._rng.bit_generator.state = state["numpy_bit_generator_state"]
        if "numpy_legacy_state" in state:
            np.random.set_state(state["numpy_legacy_state"])  # noqa: NPY002

    def normal(self, mean: float = 0.0, std: float = 1.0, shape: tuple[int, ...] = ()) -> Array:
        return Array(self._rng.normal(loc=mean, scale=std, size=shape))

    def uniform(self, low: float = 0.0, high: float = 1.0, shape: tuple[int, ...] = ()) -> Array:
        return Array(self._rng.uniform(low=low, high=high, size=shape))

    def normal_like(self, array: Array, mean: float = 0.0, std: float = 1.0) -> Array:
        return Array(self._rng.normal(loc=mean, scale=std, size=array.value.shape))

    def uniform_like(self, array: Array, low: float = 0.0, high: float = 1.0) -> Array:
        return Array(self._rng.uniform(low=low, high=high, size=array.value.shape))

    def choice(self, array: Array, size: int, replace: bool = True) -> Array:
        return Array(self._rng.choice(array.value, size=size, replace=replace))


register_backend(SupportedFrameworks.NUMPY, NumpyBackend)
+""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any + +import numpy as np +import torch +from numpy.typing import NDArray + +from decent_array import Array +from decent_array.interoperability._abstracts import Backend +from decent_array.interoperability._backend_manager import register_backend +from decent_array.types import ArrayKey, SupportedDevices, SupportedFrameworks + + +def _unwrap(array: Any) -> Any: # noqa: ANN401 + """Return the underlying value of an :class:`Array`, or pass ``array`` through.""" + return array.value if type(array) is Array else array + + +class PyTorchBackend(Backend): # noqa: PLR0904 + """PyTorch implementation of :class:`Backend`.""" + + def __init__(self, device: SupportedDevices = SupportedDevices.CPU) -> None: + super().__init__(device) + self._native_device: str = self.device_to_native(device) + self._generator: torch.Generator = torch.Generator(device=self._native_device) + + # Array creation + + def zeros(self, shape: tuple[int, ...]) -> Array: + return Array(torch.zeros(shape, device=self._native_device)) + + def zeros_like(self, array: Array) -> Array: + return Array(torch.zeros_like(array.value)) + + def ones(self, shape: tuple[int, ...]) -> Array: + return Array(torch.ones(shape, device=self._native_device)) + + def ones_like(self, array: Array) -> Array: + return Array(torch.ones_like(array.value)) + + def eye(self, n: int) -> Array: + return Array(torch.eye(n, device=self._native_device)) + + def eye_like(self, array: Array) -> Array: + v = array.value + return Array(torch.eye(*v.shape[-2:], dtype=v.dtype, device=v.device)) + + def device_to_native(self, device: SupportedDevices) -> str: + if device == SupportedDevices.CPU: + return "cpu" + if device == SupportedDevices.GPU: + return "cuda" + if device == SupportedDevices.MPS: + return "mps" + raise ValueError(f"Unsupported device: {device}") + + def device_of(self, array: Array) -> SupportedDevices: + kind = 
array.value.device.type + if kind == "cpu": + return SupportedDevices.CPU + if kind == "cuda": + return SupportedDevices.GPU + if kind == "mps": + return SupportedDevices.MPS + raise TypeError(f"Unsupported PyTorch device type: {kind}") + + # Array manipulation + + def copy(self, array: Array) -> Array: + return Array(array.value.detach().clone()) + + def to_numpy(self, array: Array) -> NDArray[Any]: + """Return the value of an :class:`Array` as a NumPy array.""" + v = array.value + if isinstance(v, torch.Tensor): + ret: NDArray[Any] = v.cpu().numpy() + else: + ret = np.asarray(v) + return ret + + def from_numpy(self, array: NDArray[Any]) -> Array: + return Array(torch.from_numpy(array).to(device=self._native_device)) + + def from_numpy_like(self, array: NDArray[Any], like: Array) -> Array: + v = like.value + return Array(torch.from_numpy(array).to(dtype=v.dtype, device=v.device)) + + def to_array(self, array: float | bool) -> Array: + return Array(torch.tensor(array, device=self._native_device)) + + def stack(self, arrays: Sequence[Array], axis: int = 0) -> Array: + if len(arrays) == 0: + raise ValueError("Cannot stack an empty sequence of arrays.") + return Array(torch.stack([a.value for a in arrays], dim=axis)) + + def reshape(self, array: Array, shape: tuple[int, ...]) -> Array: + return Array(torch.reshape(array.value, shape)) + + def transpose(self, array: Array, axis: tuple[int, ...] | None = None) -> Array: + v = array.value + dims = axis if axis is not None else tuple(reversed(range(v.ndim))) + return Array(torch.permute(v, dims=dims)) + + def shape(self, array: Array) -> tuple[int, ...]: + return tuple(array.value.shape) + + def size(self, array: Array) -> int: + return int(array.value.numel()) + + def ndim(self, array: Array) -> int: + return int(array.value.ndim) + + def squeeze(self, array: Array, axis: int | tuple[int, ...] 
| None = None) -> Array:
+        v = array.value
+        if axis is None:
+            return Array(torch.squeeze(v))
+        return Array(torch.squeeze(v, dim=axis))
+
+    def unsqueeze(self, array: Array, axis: int) -> Array:
+        return Array(torch.unsqueeze(array.value, dim=axis))
+
+    def diag(self, array: Array) -> Array:
+        return Array(torch.diag(array.value))
+
+    def astype(self, array: Array, dtype: type[float | int | bool]) -> float | int | bool:
+        return dtype(array.value.item())
+
+    # Linalg
+
+    def dot(self, array1: Array, array2: Array) -> Array:
+        return Array(torch.dot(array1.value, array2.value))
+
+    def matmul(self, array1: Array, array2: Array) -> Array:
+        return Array(array1.value @ array2.value)
+
+    def norm(
+        self,
+        array: Array,
+        p: float = 2,
+        axis: int | tuple[int, ...] | None = None,
+        keepdims: bool = False,
+    ) -> Array:
+        return Array(torch.linalg.norm(array.value, ord=p, dim=axis, keepdim=keepdims))  # dim=, not axis=
+
+    # Math reductions
+
+    def sum(self, array: Array, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array:
+        v = array.value
+        if axis is None:
+            return Array(torch.sum(v))
+        return Array(torch.sum(v, dim=axis, keepdim=keepdims))
+
+    def mean(self, array: Array, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array:
+        v = array.value
+        if axis is None:
+            return Array(torch.mean(v))
+        return Array(torch.mean(v, dim=axis, keepdim=keepdims))
+
+    def min(self, array: Array, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array:
+        v = array.value
+        if axis is None:
+            return Array(torch.min(v))
+        return Array(torch.amin(v, dim=axis, keepdim=keepdims))
+
+    def max(self, array: Array, axis: int | tuple[int, ...]
| None = None, keepdims: bool = False) -> Array: + v = array.value + if axis is None: + return Array(torch.max(v)) + return Array(torch.amax(v, dim=axis, keepdim=keepdims)) + + def any(self, array: Array) -> bool: + return bool(torch.any(array.value).item()) + + def all(self, array: Array) -> bool: + return bool(torch.all(array.value).item()) + + # Math elementwise — operands may be Array or scalar (operator dunders pass either). + # ``Array | float`` covers both: PEP 484's numeric tower implicitly admits ``int``. + + def add(self, array1: Array | float, array2: Array | float) -> Array: + return Array(torch.add(_unwrap(array1), _unwrap(array2))) + + def iadd[T: Array](self, array1: T, array2: Array | float) -> T: + array1.value.add_(_unwrap(array2)) + return array1 + + def sub(self, array1: Array | float, array2: Array | float) -> Array: + return Array(torch.sub(_unwrap(array1), _unwrap(array2))) + + def isub[T: Array](self, array1: T, array2: Array | float) -> T: + array1.value.sub_(_unwrap(array2)) + return array1 + + def mul(self, array1: Array | float, array2: Array | float) -> Array: + return Array(torch.mul(_unwrap(array1), _unwrap(array2))) + + def imul[T: Array](self, array1: T, array2: Array | float) -> T: + array1.value.mul_(_unwrap(array2)) + return array1 + + def div(self, array1: Array | float, array2: Array | float) -> Array: + return Array(torch.div(_unwrap(array1), _unwrap(array2))) + + def idiv[T: Array](self, array1: T, array2: Array | float) -> T: + array1.value.div_(_unwrap(array2)) + return array1 + + def pow(self, array: Array, p: float) -> Array: + return Array(torch.pow(array.value, p)) + + def negative(self, array: Array) -> Array: + return Array(torch.neg(array.value)) + + def absolute(self, array: Array) -> Array: + return Array(torch.abs(array.value)) + + def sqrt(self, array: Array) -> Array: + return Array(torch.sqrt(array.value)) + + # Comparisons + + def eq(self, array1: Array | float, array2: Array | float) -> Array: + return 
Array(torch.eq(_unwrap(array1), _unwrap(array2))) + + def ne(self, array1: Array | float, array2: Array | float) -> Array: + return Array(torch.ne(_unwrap(array1), _unwrap(array2))) + + def lt(self, array1: Array | float, array2: Array | float) -> Array: + return Array(torch.lt(_unwrap(array1), _unwrap(array2))) + + def le(self, array1: Array | float, array2: Array | float) -> Array: + return Array(torch.le(_unwrap(array1), _unwrap(array2))) + + def gt(self, array1: Array | float, array2: Array | float) -> Array: + return Array(torch.gt(_unwrap(array1), _unwrap(array2))) + + def ge(self, array1: Array | float, array2: Array | float) -> Array: + return Array(torch.ge(_unwrap(array1), _unwrap(array2))) + + # Bitwise + + def bitwise_and(self, array1: Array | float, array2: Array | float) -> Array: + return Array(torch.bitwise_and(_unwrap(array1), _unwrap(array2))) + + # Operators + + def sign(self, array: Array) -> Array: + return Array(torch.sign(array.value)) + + def maximum(self, array1: Array | float, array2: Array | float) -> Array: + a, b = _unwrap(array1), _unwrap(array2) + # torch.maximum requires both operands to be Tensors; lift Python scalars to + # match the dtype/device of the tensor operand so the contract matches numpy. 
+ if not isinstance(a, torch.Tensor): + ref = b if isinstance(b, torch.Tensor) else None + a = torch.tensor(a, dtype=ref.dtype if ref is not None else None, device=self._native_device) + if not isinstance(b, torch.Tensor): + b = torch.tensor(b, dtype=a.dtype, device=a.device) + return Array(torch.maximum(a, b)) + + def argmax(self, array: Array, axis: int | None = None, keepdims: bool = False) -> Array: + return Array(torch.argmax(array.value, dim=axis, keepdim=keepdims)) + + def argmin(self, array: Array, axis: int | None = None, keepdims: bool = False) -> Array: + return Array(torch.argmin(array.value, dim=axis, keepdim=keepdims)) + + def set_item(self, array: Array, key: ArrayKey, value: Array) -> None: + array.value[key] = value.value + + def get_item(self, array: Array, key: ArrayKey) -> Array: + return Array(array.value[key]) + + # RNG + + def set_seed(self, seed: int) -> None: + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + self._generator.manual_seed(seed) + + def get_rng_state(self) -> dict[str, Any]: + state: dict[str, Any] = { + "torch_cpu_state": torch.random.get_rng_state(), + "torch_generator_state": self._generator.get_state(), + } + if torch.cuda.is_available(): + state["torch_cuda_states"] = torch.cuda.get_rng_state_all() + return state + + def set_rng_state(self, state: dict[str, Any]) -> None: + if "torch_cpu_state" in state: + torch.random.set_rng_state(state["torch_cpu_state"]) + if "torch_cuda_states" in state and torch.cuda.is_available(): + torch.cuda.set_rng_state_all(state["torch_cuda_states"]) + if "torch_generator_state" in state: + self._generator.set_state(state["torch_generator_state"]) + + def normal(self, mean: float = 0.0, std: float = 1.0, shape: tuple[int, ...] = ()) -> Array: + return Array( + torch.normal(mean=mean, std=std, size=shape, device=self._native_device, generator=self._generator) + ) + + def uniform(self, low: float = 0.0, high: float = 1.0, shape: tuple[int, ...] 
= ()) -> Array:
+        rand = torch.rand(size=shape, device=self._native_device, generator=self._generator)
+        return Array((high - low) * rand + low)
+
+    def normal_like(self, array: Array, mean: float = 0.0, std: float = 1.0) -> Array:
+        v = array.value
+        return Array(
+            torch.normal(
+                mean=mean,
+                std=std,
+                size=tuple(v.shape),
+                dtype=v.dtype,
+                device=v.device,
+                generator=self._generator,
+            )
+        )
+
+    def uniform_like(self, array: Array, low: float = 0.0, high: float = 1.0) -> Array:
+        v = array.value
+        rand = torch.rand(size=tuple(v.shape), dtype=v.dtype, device=v.device, generator=self._generator)
+        return Array((high - low) * rand + low)
+
+    def choice(self, array: Array, size: int, replace: bool = True) -> Array:
+        v = array.value
+        weights = torch.ones(v.shape[0], device=v.device)
+        indices = weights.multinomial(num_samples=size, replacement=replace, generator=self._generator)
+        return Array(v[indices])
+
+
+register_backend(SupportedFrameworks.PYTORCH, PyTorchBackend)
diff --git a/decent_array/interoperability/_tensorflow/__init__.py b/decent_array/interoperability/_tensorflow/__init__.py
new file mode 100644
index 0000000..dae3831
--- /dev/null
+++ b/decent_array/interoperability/_tensorflow/__init__.py
@@ -0,0 +1,5 @@
+"""TensorFlow backend package; importing it triggers backend registration."""
+
+from .tensorflow_backend import TensorflowBackend
+
+__all__ = ["TensorflowBackend"]
diff --git a/decent_array/interoperability/_tensorflow/tensorflow_backend.py b/decent_array/interoperability/_tensorflow/tensorflow_backend.py
new file mode 100644
index 0000000..1026701
--- /dev/null
+++ b/decent_array/interoperability/_tensorflow/tensorflow_backend.py
@@ -0,0 +1,349 @@
+"""
+TensorFlow backend for :mod:`decent_array.interoperability`.
+
+Importing this module registers the backend via :func:`register_backend`, so the
+package can be auto-loaded on the first ``set_backend("tensorflow")`` call.
+ +TF eager Tensors are immutable, so :meth:`set_item` round-trips through numpy and the +in-place math operations rebind the wrapper's underlying value. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import Any, cast + +import numpy as np +import tensorflow as tf +from numpy.typing import NDArray + +from decent_array import Array +from decent_array.interoperability._abstracts import Backend +from decent_array.interoperability._backend_manager import register_backend +from decent_array.types import ArrayKey, SupportedDevices, SupportedFrameworks + + +def _unwrap(array: Any) -> Any: # noqa: ANN401 + """Return the underlying value of an :class:`Array`, or pass ``array`` through.""" + return array.value if type(array) is Array else array + + +class TensorflowBackend(Backend): # noqa: PLR0904 + """TensorFlow implementation of :class:`Backend`.""" + + def __init__(self, device: SupportedDevices = SupportedDevices.CPU) -> None: + super().__init__(device) + self._native_device: str = self.device_to_native(device) + self._generator: tf.random.Generator = tf.random.Generator.from_non_deterministic_state(alg="philox") + + # Array creation + + def zeros(self, shape: tuple[int, ...]) -> Array: + with tf.device(self._native_device): + return Array(tf.zeros(shape)) + + def zeros_like(self, array: Array) -> Array: + return Array(tf.zeros_like(array.value)) + + def ones(self, shape: tuple[int, ...]) -> Array: + with tf.device(self._native_device): + return Array(tf.ones(shape)) + + def ones_like(self, array: Array) -> Array: + return Array(tf.ones_like(array.value)) + + def eye(self, n: int) -> Array: + with tf.device(self._native_device): + return Array(tf.eye(n)) + + def eye_like(self, array: Array) -> Array: + v = array.value + rows, cols = v.shape[-2:] + return Array(tf.eye(rows, cols, dtype=v.dtype)) + + def device_to_native(self, device: SupportedDevices) -> str: + if device in {SupportedDevices.CPU, SupportedDevices.GPU}: + 
return f"/{device.value}:0" + raise ValueError(f"Unsupported device for TensorFlow: {device}") + + def device_of(self, array: Array) -> SupportedDevices: + device_str = array.value.device.lower() + if "gpu" in device_str or "cuda" in device_str: + return SupportedDevices.GPU + return SupportedDevices.CPU + + # Array manipulation + + def copy(self, array: Array) -> Array: + return Array(tf.identity(array.value)) + + def to_numpy(self, array: Array) -> NDArray[Any]: + """Return the value of an :class:`Array` as a NumPy array.""" + v = array.value + if isinstance(v, tf.Tensor): + ret: NDArray[Any] = v.numpy() + else: + ret = np.asarray(v) + return ret + + def from_numpy(self, array: NDArray[Any]) -> Array: + """Create an :class:`Array` from a NumPy array.""" + with tf.device(self._native_device): + return Array(tf.convert_to_tensor(array)) + + def from_numpy_like(self, array: NDArray[Any], like: Array) -> Array: + """Create an :class:`Array` from a NumPy array, on ``like``'s device with ``like``'s dtype.""" + v = like.value + with tf.device(v.device): + return Array(tf.convert_to_tensor(array, dtype=v.dtype)) + + def to_array(self, array: float | bool) -> Array: + """Convert a Python scalar to an :class:`Array` on this backend.""" + with tf.device(self._native_device): + return Array(tf.convert_to_tensor(array)) + + def stack(self, arrays: Sequence[Array], axis: int = 0) -> Array: + if len(arrays) == 0: + raise ValueError("Cannot stack an empty sequence of arrays.") + return Array(tf.stack([a.value for a in arrays], axis=axis)) + + def reshape(self, array: Array, shape: tuple[int, ...]) -> Array: + return Array(tf.reshape(array.value, shape)) + + def transpose(self, array: Array, axis: tuple[int, ...] 
| None = None) -> Array: + return Array(tf.transpose(array.value, perm=axis)) + + def shape(self, array: Array) -> tuple[int, ...]: + return cast("tuple[int, ...]", tuple(array.value.shape)) + + def size(self, array: Array) -> int: + return int(tf.size(array.value).numpy()) + + def ndim(self, array: Array) -> int: + return len(array.value.shape) + + def squeeze(self, array: Array, axis: int | tuple[int, ...] | None = None) -> Array: + return Array(tf.squeeze(array.value, axis=axis)) + + def unsqueeze(self, array: Array, axis: int) -> Array: + return Array(tf.expand_dims(array.value, axis=axis)) + + def diag(self, array: Array) -> Array: + v = array.value + rank = v.shape.ndims + if rank == 1: + return Array(tf.linalg.diag(v)) + if rank == 2: + return Array(tf.linalg.diag_part(v)) + raise ValueError(f"diag requires a 1- or 2-D tensor, got rank {rank}") + + def astype(self, array: Array, dtype: type[float | int | bool]) -> float | int | bool: + return dtype(array.value.numpy().item()) + + # Linalg + + def dot(self, array1: Array, array2: Array) -> Array: + return Array(tf.tensordot(array1.value, array2.value, axes=1)) + + def matmul(self, array1: Array, array2: Array) -> Array: + # tf.matmul requires both operands to have ndim >= 2; fall back to tensordot + # for the vector cases so semantics match numpy / torch / jax matmul. + a, b = array1.value, array2.value + if a.shape.ndims is None or b.shape.ndims is None or a.shape.ndims < 2 or b.shape.ndims < 2: + return Array(tf.tensordot(a, b, axes=1)) + return Array(a @ b) + + def norm( + self, + array: Array, + p: float = 2, + axis: int | tuple[int, ...] | None = None, + keepdims: bool = False, + ) -> Array: + v = array.value + # tf.norm defaults differ from np.linalg.norm on 2-D inputs (operator vs. + # Frobenius); match numpy's flat default by reducing over both trailing axes. 
+ axis = axis if axis is not None else (-2, -1) if v.ndim == 2 else None + return Array(tf.norm(v, ord=p, axis=axis, keepdims=keepdims)) + + # Math reductions + + def sum(self, array: Array, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: + return Array(tf.reduce_sum(array.value, axis=axis, keepdims=keepdims)) + + def mean(self, array: Array, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: + return Array(tf.reduce_mean(array.value, axis=axis, keepdims=keepdims)) + + def min(self, array: Array, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: + return Array(tf.reduce_min(array.value, axis=axis, keepdims=keepdims)) + + def max(self, array: Array, axis: int | tuple[int, ...] | None = None, keepdims: bool = False) -> Array: + return Array(tf.reduce_max(array.value, axis=axis, keepdims=keepdims)) + + def any(self, array: Array) -> bool: + return bool(tf.reduce_any(tf.cast(array.value, tf.bool)).numpy()) + + def all(self, array: Array) -> bool: + return bool(tf.reduce_all(tf.cast(array.value, tf.bool)).numpy()) + + # Math elementwise — TF Tensors are immutable; "in-place" ops rebind the wrapper. + # Operands may be Array or scalar (operator dunders pass either); ``Array | float`` + # covers both because PEP 484's numeric tower implicitly admits ``int``. 
+ + def add(self, array1: Array | float, array2: Array | float) -> Array: + return Array(tf.add(_unwrap(array1), _unwrap(array2))) + + def iadd[T: Array](self, array1: T, array2: Array | float) -> T: + array1.value = tf.add(array1.value, _unwrap(array2)) + return array1 + + def sub(self, array1: Array | float, array2: Array | float) -> Array: + return Array(tf.subtract(_unwrap(array1), _unwrap(array2))) + + def isub[T: Array](self, array1: T, array2: Array | float) -> T: + array1.value = tf.subtract(array1.value, _unwrap(array2)) + return array1 + + def mul(self, array1: Array | float, array2: Array | float) -> Array: + return Array(tf.multiply(_unwrap(array1), _unwrap(array2))) + + def imul[T: Array](self, array1: T, array2: Array | float) -> T: + array1.value = tf.multiply(array1.value, _unwrap(array2)) + return array1 + + def div(self, array1: Array | float, array2: Array | float) -> Array: + return Array(tf.divide(_unwrap(array1), _unwrap(array2))) + + def idiv[T: Array](self, array1: T, array2: Array | float) -> T: + array1.value = tf.divide(array1.value, _unwrap(array2)) + return array1 + + def pow(self, array: Array, p: float) -> Array: + return Array(tf.pow(array.value, p)) + + def negative(self, array: Array) -> Array: + return Array(tf.negative(array.value)) + + def absolute(self, array: Array) -> Array: + return Array(tf.abs(array.value)) + + def sqrt(self, array: Array) -> Array: + return Array(tf.sqrt(array.value)) + + # Comparisons + + def eq(self, array1: Array | float, array2: Array | float) -> Array: + return Array(tf.equal(_unwrap(array1), _unwrap(array2))) + + def ne(self, array1: Array | float, array2: Array | float) -> Array: + return Array(tf.not_equal(_unwrap(array1), _unwrap(array2))) + + def lt(self, array1: Array | float, array2: Array | float) -> Array: + return Array(tf.less(_unwrap(array1), _unwrap(array2))) + + def le(self, array1: Array | float, array2: Array | float) -> Array: + return Array(tf.less_equal(_unwrap(array1), 
_unwrap(array2))) + + def gt(self, array1: Array | float, array2: Array | float) -> Array: + return Array(tf.greater(_unwrap(array1), _unwrap(array2))) + + def ge(self, array1: Array | float, array2: Array | float) -> Array: + return Array(tf.greater_equal(_unwrap(array1), _unwrap(array2))) + + # Bitwise — TF's native ``&`` dispatches to ``tf.math.logical_and`` for bool + # tensors and ``tf.bitwise.bitwise_and`` for int tensors, matching numpy/torch/jax + # operator semantics. Calling either named function directly here would constrain + # us to one dtype family. + + def bitwise_and(self, array1: Array | float, array2: Array | float) -> Array: + return Array(_unwrap(array1) & _unwrap(array2)) + + # Operators + + def sign(self, array: Array) -> Array: + return Array(tf.sign(array.value)) + + def maximum(self, array1: Array | float, array2: Array | float) -> Array: + return Array(tf.maximum(_unwrap(array1), _unwrap(array2))) + + def argmax(self, array: Array, axis: int | None = None, keepdims: bool = False) -> Array: + v = array.value + if axis is None: + flat = tf.argmax(tf.reshape(v, [-1]), axis=0) + if keepdims: + ndim = v.shape.ndims or 0 + return Array(tf.reshape(flat, [1] * ndim)) + return Array(flat) + out = tf.argmax(v, axis=axis) + if keepdims: + out = tf.expand_dims(out, axis=axis) + return Array(out) + + def argmin(self, array: Array, axis: int | None = None, keepdims: bool = False) -> Array: + v = array.value + if axis is None: + flat = tf.argmin(tf.reshape(v, [-1]), axis=0) + if keepdims: + ndim = v.shape.ndims or 0 + return Array(tf.reshape(flat, [1] * ndim)) + return Array(flat) + out = tf.argmin(v, axis=axis) + if keepdims: + out = tf.expand_dims(out, axis=axis) + return Array(out) + + def set_item(self, array: Array, key: ArrayKey, value: Array) -> None: + # TF eager tensors are immutable; round-trip through numpy so arbitrary indexing + # patterns (slices, fancy indexing) Just Work. 
The wrapper is rebound to a fresh + # tensor on the configured device. This is correct but allocates — algorithms + # that hammer set_item in tight loops should consider numpy or pytorch. + original = array.value + np_array = original.numpy().copy() + np_array[key] = np.asarray(value.value) + with tf.device(self._native_device): + array.value = tf.convert_to_tensor(np_array, dtype=original.dtype) + + def get_item(self, array: Array, key: ArrayKey) -> Array: + return Array(array.value[key]) + + # RNG + + def set_seed(self, seed: int) -> None: + tf.random.set_seed(seed) + self._generator = tf.random.Generator.from_seed(seed, alg="philox") + + def get_rng_state(self) -> dict[str, Any]: + return {"tf_generator_state": self._generator.state.numpy()} + + def set_rng_state(self, state: dict[str, Any]) -> None: + if "tf_generator_state" in state: + self._generator = tf.random.Generator.from_state(state["tf_generator_state"], alg="philox") + + def normal(self, mean: float = 0.0, std: float = 1.0, shape: tuple[int, ...] = ()) -> Array: + with tf.device(self._native_device): + return Array(self._generator.normal(shape=shape, mean=mean, stddev=std)) + + def uniform(self, low: float = 0.0, high: float = 1.0, shape: tuple[int, ...] 
= ()) -> Array: + with tf.device(self._native_device): + return Array(self._generator.uniform(shape=shape, minval=low, maxval=high)) + + def normal_like(self, array: Array, mean: float = 0.0, std: float = 1.0) -> Array: + v = array.value + return Array(self._generator.normal(shape=tf.shape(v), mean=mean, stddev=std, dtype=v.dtype)) + + def uniform_like(self, array: Array, low: float = 0.0, high: float = 1.0) -> Array: + v = array.value + return Array(self._generator.uniform(shape=tf.shape(v), minval=low, maxval=high, dtype=v.dtype)) + + def choice(self, array: Array, size: int, replace: bool = True) -> Array: + v = array.value + n = v.shape[0] + if replace: + indices = self._generator.uniform(shape=(size,), minval=0, maxval=n, dtype=tf.int32) + else: + scores = self._generator.uniform(shape=(n,), dtype=tf.float32) + indices = tf.cast(tf.math.top_k(scores, k=size).indices, tf.int32) + return Array(tf.gather(v, indices)) + + +register_backend(SupportedFrameworks.TENSORFLOW, TensorflowBackend) diff --git a/decent_array/types.py b/decent_array/types.py new file mode 100644 index 0000000..6d62fcb --- /dev/null +++ b/decent_array/types.py @@ -0,0 +1,49 @@ +"""Type definitions for optimization variables.""" + +from __future__ import annotations + +from enum import Enum +from typing import TYPE_CHECKING, SupportsIndex, TypeAlias, Union + +if TYPE_CHECKING: + import jax + import numpy + import tensorflow as tf + import torch + +ArrayLike: TypeAlias = Union["numpy.ndarray", "torch.Tensor", "tf.Tensor", "jax.Array"] # noqa: UP040 +""" +Type alias for array-like types supported in decent-array, including NumPy arrays, +PyTorch tensors, TensorFlow tensors, and JAX arrays. +""" + +SupportedArrayTypes: TypeAlias = ArrayLike | float | int # noqa: UP040 +""" +Type alias for supported types for optimization variables in decent-array, +including array-like types and scalars. +""" + +ArrayKey: TypeAlias = SupportsIndex | slice | tuple[SupportsIndex | slice, ...] 
# noqa: UP040 +""" +Type alias for valid keys used to index into supported array types. +Includes single indices, tuples of indices, slices, and tuples of slices. +""" + + +# Its important that the enum values correspond to the folder names of the backends, +# since those are used for dynamic imports in _backend_manager.py +class SupportedFrameworks(Enum): + """Enum for supported frameworks in decent-array.""" + + NUMPY = "numpy" + PYTORCH = "pytorch" + TENSORFLOW = "tensorflow" + JAX = "jax" + + +class SupportedDevices(Enum): + """Enum for supported devices in decent-array.""" + + CPU = "cpu" + GPU = "gpu" + MPS = "mps" diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000..d0c3cbf --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/make.bat b/docs/make.bat new file mode 100644 index 0000000..747ffb7 --- /dev/null +++ b/docs/make.bat @@ -0,0 +1,35 @@ +@ECHO OFF + +pushd %~dp0 + +REM Command file for Sphinx documentation + +if "%SPHINXBUILD%" == "" ( + set SPHINXBUILD=sphinx-build +) +set SOURCEDIR=source +set BUILDDIR=build + +%SPHINXBUILD% >NUL 2>NUL +if errorlevel 9009 ( + echo. + echo.The 'sphinx-build' command was not found. Make sure you have Sphinx + echo.installed, then set the SPHINXBUILD environment variable to point + echo.to the full path of the 'sphinx-build' executable. 
Alternatively you + echo.may add the Sphinx directory to PATH. + echo. + echo.If you don't have Sphinx installed, grab it from + echo.https://www.sphinx-doc.org/ + exit /b 1 +) + +if "%1" == "" goto help + +%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% +goto end + +:help +%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% + +:end +popd diff --git a/docs/source/_static/custom.css b/docs/source/_static/custom.css new file mode 100644 index 0000000..4378c6a --- /dev/null +++ b/docs/source/_static/custom.css @@ -0,0 +1,6 @@ +/* Make the logo bigger */ +.navbar-brand img { + max-height: 65px; + height: auto; + width: auto; +} diff --git a/docs/source/_static/logo.png b/docs/source/_static/logo.png new file mode 100644 index 0000000..963c497 Binary files /dev/null and b/docs/source/_static/logo.png differ diff --git a/docs/source/_templates/module.rst.jinja b/docs/source/_templates/module.rst.jinja new file mode 100644 index 0000000..2644b73 --- /dev/null +++ b/docs/source/_templates/module.rst.jinja @@ -0,0 +1,8 @@ +{%- if show_headings %} +{{- [basename] | join(' ') | e | heading }} + +{% endif -%} +.. automodule:: {{ qualname }} +{%- for option in automodule_options %} + :{{ option }}: +{%- endfor %} diff --git a/docs/source/_templates/package.rst.jinja b/docs/source/_templates/package.rst.jinja new file mode 100644 index 0000000..8e47918 --- /dev/null +++ b/docs/source/_templates/package.rst.jinja @@ -0,0 +1,49 @@ +{%- macro automodule(modname, options) -%} +.. automodule:: {{ modname }} +{%- for option in options %} + :{{ option }}: +{%- endfor %} +{%- endmacro %} + +{%- macro toctree(docnames) -%} +.. toctree:: + :maxdepth: 2 +{% for docname in docnames %} + {{ docname }} +{%- endfor %} +{%- endmacro %} + +{%- if is_namespace %} +{{- [pkgname, "namespace"] | join(" ") | e | heading }} +{% else %} +{{- [pkgname] | join(" ") | e | heading }} +{% endif %} + +{%- if is_namespace %} +.. 
py:module:: {{ pkgname }} +{% endif %} + +{%- if modulefirst and not is_namespace %} +{{ automodule(pkgname, automodule_options) }} +{% endif %} + +{%- if subpackages %} +{{ toctree(subpackages) }} +{% endif %} + +{%- if submodules %} +{% if separatemodules %} +{{ toctree(submodules) }} +{% else %} +{%- for submodule in submodules %} +{% if show_headings %} +{{- [submodule] | join(" ") | e | heading(2) }} +{% endif %} +{{ automodule(submodule, automodule_options) }} +{% endfor %} +{%- endif %} +{%- endif %} + +{%- if not modulefirst and not is_namespace %} +{{ automodule(pkgname, automodule_options) }} +{% endif %} diff --git a/docs/source/api/decent_array.array.rst b/docs/source/api/decent_array.array.rst new file mode 100644 index 0000000..ab02904 --- /dev/null +++ b/docs/source/api/decent_array.array.rst @@ -0,0 +1,7 @@ +decent\_array.Array +=================== + +.. automodule:: decent_array + :members: + :show-inheritance: + :undoc-members: \ No newline at end of file diff --git a/docs/source/api/decent_array.interoperability.rst b/docs/source/api/decent_array.interoperability.rst new file mode 100644 index 0000000..00ead3d --- /dev/null +++ b/docs/source/api/decent_array.interoperability.rst @@ -0,0 +1,7 @@ +decent\_array.interoperability +============================== + +.. automodule:: decent_array.interoperability + :members: + :show-inheritance: + :undoc-members: diff --git a/docs/source/api/decent_array.rst b/docs/source/api/decent_array.rst new file mode 100644 index 0000000..66c4fcd --- /dev/null +++ b/docs/source/api/decent_array.rst @@ -0,0 +1,14 @@ +decent\_array +============= + +.. toctree:: + :maxdepth: 2 + + decent_array.interoperability + + +.. 
toctree:: + :maxdepth: 2 + + decent_array.types + decent_array.array \ No newline at end of file diff --git a/docs/source/api/decent_array.types.rst b/docs/source/api/decent_array.types.rst new file mode 100644 index 0000000..6a29483 --- /dev/null +++ b/docs/source/api/decent_array.types.rst @@ -0,0 +1,7 @@ +decent\_array.types +=================== + +.. automodule:: decent_array.types + :members: + :show-inheritance: + :undoc-members: \ No newline at end of file diff --git a/docs/source/author.rst b/docs/source/author.rst new file mode 100644 index 0000000..0506444 --- /dev/null +++ b/docs/source/author.rst @@ -0,0 +1,6 @@ +Contributors +============ + +decent-array is developed by `Simon Granström `_ and +`Adriana Rodriguez `_, under the supervision of +`Dr. Nicola Bastianello `_. diff --git a/docs/source/background.rst b/docs/source/background.rst new file mode 100644 index 0000000..7370504 --- /dev/null +++ b/docs/source/background.rst @@ -0,0 +1,3 @@ +Background +========== + diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000..f95654f --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,158 @@ +# Configuration file for the Sphinx documentation builder. 
+# +# For the full list of built-in configuration values, see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Project information ----------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information + +project = "decent-array" +copyright = "2026, Team Decent" +author = "team-decent" + +# -- General configuration --------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration + +import os +import sys + +from docutils import nodes + +sys.path.insert(0, os.path.abspath("../..")) +sys.path.insert(0, os.path.abspath("_extensions")) + +extensions = [ + "sphinx.ext.autodoc", # Expand rst automodule directives generated by `sphinx-apidoc` + "sphinx.ext.intersphinx", # Link to types from external packages + "sphinx.ext.napoleon", # Parse Google style docstrings + "sphinx.ext.viewcode", # View source code +] + +nitpicky = True +nitpick_ignore = [ + ("py:class", "numpy.float64"), + ("py:class", "float64"), + ("py:class", "numpy._typing._array_like._SupportsArray"), + ("py:class", "numpy._typing._nested_sequence._NestedSequence"), + ("py:class", "T"), +] + +suppress_warnings = ["toc.duplicate"] + +intersphinx_mapping = { + "numpy": ("https://numpy.org/doc/stable/", None), + "python": ("https://docs.python.org/3", None), + "torch": ("https://pytorch.org/docs/stable/", None), + "tensorflow": ( + "https://www.tensorflow.org/api_docs/python", + "https://github.com/GPflow/tensorflow-intersphinx/raw/master/tf2_py_objects.inv", + ), + "jax": ("https://jax.readthedocs.io/en/latest/", None), +} + + +# A way to link numpy.typing.ArrayLike and NDArray correctly +# Seems to be an open issue, see https://github.com/sphinx-doc/sphinx/issues/10794 +# https://github.com/sphinx-doc/sphinx/issues/10785#issuecomment-1321100925 +def _fix_missing_ref(app, env, node, contnode): + if node.get("refdomain") == "py" 
and node.get("reftype") in {"class", "data"}: + target = node.get("reftarget") + if target in {"ArrayLike", "numpy.typing.ArrayLike"}: + return nodes.reference( + "", "ArrayLike", refuri="https://numpy.org/doc/stable/reference/typing.html#numpy.typing.ArrayLike" + ) + if target in {"NDArray", "numpy.typing.NDArray"}: + return nodes.reference( + "", + "NDArray", + refuri="https://numpy.org/doc/stable/reference/generated/numpy.ndarray.html#numpy.ndarray", + ) + if target in {"DTypeLike", "numpy.typing.DTypeLike", "numpy._typing.DTypeLike"}: + return nodes.reference( + "", "DTypeLike", refuri="https://numpy.org/doc/stable/reference/typing.html#numpy.typing.DTypeLike" + ) + if target in {"DefaultContext", "multiprocessing.context.DefaultContext"}: + return nodes.reference( + "", + "DefaultContext", + refuri="https://docs.python.org/3/library/multiprocessing.html#multiprocessing.get_context", + ) + if target in {"SpawnContext", "multiprocessing.context.SpawnContext"}: + return nodes.reference( + "", + "SpawnContext", + refuri="https://docs.python.org/3/library/multiprocessing.html#multiprocessing.get_context", + ) + if target in {"TensorflowGenerator"}: + return nodes.reference( + "", + "TensorflowGenerator", + refuri="https://www.tensorflow.org/api_docs/python/tf/random/Generator", + ) + if target in {"TorchGenerator"}: + return nodes.reference( + "", + "TorchGenerator", + refuri="https://pytorch.org/docs/stable/generated/torch.Generator.html#torch.Generator", + ) + if target in {"JaxArray"}: + return nodes.reference( + "", + "JaxArray", + refuri="https://docs.jax.dev/en/latest/_autosummary/jax.Array.html#jax.Array", + ) + if target in {"TorchTensor"}: + return nodes.reference( + "", + "TorchTensor", + refuri="https://docs.pytorch.org/docs/stable/tensors.html#torch-tensor", + ) + if target in {"TensorflowTensor"}: + return nodes.reference( + "", + "TensorflowTensor", + refuri="https://www.tensorflow.org/api_docs/python/tf/Tensor", + ) + return None + + +def setup(app): 
+ app.connect("missing-reference", _fix_missing_ref) + + +# -- Options for HTML output ------------------------------------------------- +# https://www.sphinx-doc.org/en/master/usage/configuration.html#options-for-html-output + +html_theme = "pydata_sphinx_theme" +html_theme_options = { + "logo": { + "text": project, + }, + "icon_links": [ + { + "name": "GitHub", + "url": "https://github.com/team-decent/decent-array", + "icon": "fa-brands fa-github", + "type": "fontawesome", + }, + { + "name": "PyPI", + "url": "https://pypi.org/project/decent-array/", + "icon": "fa-solid fa-box-open", + "type": "fontawesome", + }, + ], + "icon_links_label": "Quick Links", +} +html_context = {"default_mode": "auto"} +html_show_sourcelink = False +html_logo = "_static/logo.png" +html_favicon = "_static/logo.png" +html_static_path = ["_static"] +html_css_files = ["custom.css"] +html_sidebars = { + "author": [], + "background": [], + "developer": [], + "user": [], +} diff --git a/docs/source/developer.rst b/docs/source/developer.rst new file mode 100644 index 0000000..053f9ec --- /dev/null +++ b/docs/source/developer.rst @@ -0,0 +1,203 @@ +Developer Guide +=============== +Want to contribute to decent-array? That's great! This guide contains useful information +about development tools, processes, and rules. + + + +Getting Started +--------------- + +Prerequisites +~~~~~~~~~~~~~ +* `Python 3.13+ `_ +* `tox `_ + +Installation for Development +~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +.. code-block:: + + git clone https://github.com/team-decent/decent-array.git + cd decent-array + tox -e dev # create dev env (admin privileges may be needed) + source .tox/dev/bin/activate # activate dev env on Mac/Linux + .\.tox\dev\Scripts\activate # activate dev env on Windows + +Optionally install development dependencies with proper gpu support, e.g. for PyTorch and TensorFlow: + +.. 
code-block:: + + tox -e dev-gpu + +It is not recommended to use the development environments for regular usage of decent-array, as they +contain additional packages that are not needed for that purpose. This may cause performance degradation +due to multiple packages competing for resources (e.g. GPU resources). + +Tooling +------- +To make sure all GitHub status checks pass, simply run :code:`tox`. You can also run individual checks: + +.. code-block:: + + tox -e mypy # find typing issues + tox -e pytest # run tests + tox -e ruff # find formatting and style issues + tox -e sphinx # rebuild documentation + +Note: Running :code:`tox` commands can take several minutes and may require admin privileges. +If you have mypy addon installed in your IDE, you can use it to get instant feedback on typing issues while coding. +If mypy fails with ``KeyError: 'setter_type'``, delete the ``.mypy_cache`` folder in the project root. + +Tools can also be used directly (instead of via tox) after activating the dev environment. Useful examples include: + +.. code-block:: + + ruff check decent_array --fix # find and fix style issues + ruff format decent_array # format code + mypy decent_array --strict # find typing issues + pytest tests # run tests + sphinx-build -W -E -b html docs/source docs/build/html # rebuild html doc files + +To verify that doc changes look good, use an html previewer such as +`Live Preview `_. +If you are running :code:`pytest tests` while using ``WSL`` on Windows and it starts to randomly fail (or if its really slow), restart your ``WSL`` instance. + + + +Compiled hot path (mypyc) +------------------------- +The :class:`~decent_array.Array` wrapper and the entire +:mod:`~decent_array.interoperability` package are compiled to C extensions with +`mypyc `_ at wheel-build time, so end users get the +speed-up automatically when installing from a wheel. During development you can +manage the compiled artifacts with two tox environments: + +.. 
code-block:: + + tox -e mypyc # compile in-place (drops .so files alongside the .py source) + tox -e clean-mypyc # remove compiled .so files so Python loads the .py source again + +When both ``module.py`` and ``module.cpython-*.so`` exist in the same directory, Python +imports the ``.so``. This means edits to a compiled module will not take effect until +either ``tox -e mypyc`` is rerun or ``tox -e clean-mypyc`` is invoked. A typical +edit-test-edit cycle therefore looks like: + +1. Run :code:`tox -e clean-mypyc` once at the start of the session. +2. Edit, run, repeat (Python loads the ``.py`` source directly — no rebuild needed). +3. When measuring performance, run :code:`tox -e mypyc` to recompile and re-bench. + +Compilation takes ~25 s from a clean state. The build emits a hash-named shared-runtime +file (``__mypyc.cpython-*.so``) at the project root that holds helpers used by +every compiled module; it must sit on ``sys.path`` because the compiled modules import +it by bare name. The file is gitignored. + + + +Performance tools +----------------- +The :code:`benchmarks/` directory contains scripts to measure wrapper overhead and +performance of the interoperability layer. + + +CUTE Design Principles +---------------------- +CUTE is a set of principles that serve as guidelines for code design. They are meant to help keep the +codebase simple and the development fast. To mitigate any conflict, the principles are ordered from most to least +important: + +1. **Correctness**: working code is the top priority. +2. **Understandability**: others should easily understand your code, avoid bloat, unnecessary indirection, and fancy + abstractions. +3. **Testability**: code allows for short and clear tests. +4. **Extendability**: code allows for future extension, but avoid premature generalization and keep YAGNI and KISS in + mind as trying to predict tomorrow's requirements can cause more problems than it solves. 
+ + + +Pull Requests +------------- +To give other contributors an opportunity to review and to run GitHub status checks, we use pull requests instead of +merging directly to main. The process is detailed below: + +1. Fork the repository. +2. Create a feature branch. +3. Make your changes. +4. Update documentation as needed. +5. Run :code:`tox` to ensure that all checks pass. +6. Submit a pull request. +7. Doc changes? Click the readthedocs link found in the status checks to verify that everything looks good. + + + +Commit Messages +--------------- +To keep the git history easy to follow, encourage well-scoped PRs, and facilitate changelog writing and versioning, we +follow certain rules for commit messages when merging pull requests into main. Each message uses this template: + +.. code-block:: bash + :caption: Template + + (): (#) + + + + closes # + +.. code-block:: bash + :caption: Example + + perf(costs): Cache m_cvx and m_smooth (#105) + + Cache the properties m_cvx and m_smooth where applicable. This led to a + 75% speed up when running ADMM on a logistic regression problem. + + closes #101 + +Notes: + - See table below for types. + - Scope can be a subpackage, module or build tool, e.g. metrics, costs, or sphinx. + - Max 72 characters per line. + - Capitalize but do not punctuate subject. + - Start subject and description with a verb. + - Use imperative mood in subject and description. + - Description explains what changes and why it changes. + - If the PR has a related issue but doesn't close it, skip the "closes"-keyword and simply reference the issue. + +.. list-table:: + :widths: 15 40 + :header-rows: 1 + + * - Type + - Description + * - feat + - New functionality + * - perf + - Performance improvement + * - ref + - Refactor + * - enh + - Small improvement that doesn't qualify as feat, perf, or ref, e.g. 
improved variable naming, additional logging, + or prettier plots + * - fix + - Bug fix + * - test + - Change to tests + * - docs + - Update to readme, comments, docstrings, rst files, or sphinx config + * - ci + - CI related change, e.g. modifying GitHub checks or tox environments + * - meta + - Update to metadata, e.g. project description, version, or .gitignore + * - license + - License update + +Inspired by `Sentry `_. + + + +Releases +-------- +1. Update the version in pyproject.toml using `Semantic Versioning `_. +2. Merge the change into main with commit message :code:`meta: Bump version to .. (#)`. +3. Create a new release on GitHub. +4. Publish to PyPI using :code:`hatch clean && hatch build && hatch publish`. diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000..f9ec474 --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,15 @@ +Welcome to Decent-Array! +======================================== + +Decent-Array is a Python library for efficient and flexible array computations. +It provides a unified interface for working with arrays across different backends, such as NumPy, PyTorch, and TensorFlow. +With Decent-Array, you can write code that is portable and optimized for performance, without having to worry about the underlying implementation details. + +.. toctree:: + :maxdepth: 1 + + background + user + API Reference + developer + author \ No newline at end of file diff --git a/docs/source/user.rst b/docs/source/user.rst new file mode 100644 index 0000000..735989b --- /dev/null +++ b/docs/source/user.rst @@ -0,0 +1,12 @@ +User Guide +========== +This user guide shows you different examples of how to use decent-array. + + +Installation +------------ +Requires `Python 3.13+ `_ + +.. 
code-block:: bash + + pip install decent-array diff --git a/docs/sphinx_theme.txt b/docs/sphinx_theme.txt new file mode 100644 index 0000000..6246669 --- /dev/null +++ b/docs/sphinx_theme.txt @@ -0,0 +1,2 @@ +pydata-sphinx-theme +sphinxcontrib-bibtex diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..38542cc --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,190 @@ +[project] +name = "decent-array" +version = "0.1.0" +authors = [{name = "Elias Ram"}, {name = "Simon Granström"}, {name = "Adriana Rodriguez"}, {name = "Nicola Bastianello"}] +maintainers = [{name = "Team Decent"}] +description = "A library of array operations and linear algebra primitives for interoperability across ML frameworks." +readme = "README.md" +requires-python = ">=3.13" +classifiers = [ + "Programming Language :: Python :: 3.13", + "Operating System :: OS Independent", +] +license = "AGPL-3.0-only" +dependencies = [ + "numpy", +] + +[project.urls] +Documentation = "https://decent-array.readthedocs.io/en/latest/" +Source = "https://github.com/team-decent/decent-array" +Issues = "https://github.com/team-decent/decent-array/issues" + +[project.optional-dependencies] +dev = [ + "hatch", + "mypy", + "pytest", + "ruff", + "scipy-stubs", + "types-networkx", + "types-tabulate", + "torch", + "torchvision", + "types-tensorflow", +] +dev-cpu = [ + "tensorflow", + "jax", +] +dev-gpu = [ + "tensorflow[and-cuda]", + "jax[cuda12]", +] +sphinx-tools = [ + "sphinx", + "pydata-sphinx-theme", + "sphinxcontrib-bibtex", +] + + +[build-system] +requires = ["hatchling", "hatch-mypyc"] +build-backend = "hatchling.build" +# Compile hot-path modules with mypyc when building wheels. End users installing from +# a wheel get a precompiled .so; ``pip install -e .`` and source builds also run this +# (a C compiler is required). The ``.py`` source remains shipped alongside the .so so +# Python keeps a fallback if the extension fails to import. 
Tox envs that only need +# the ``.py`` source (dev, dev-gpu, sphinx, mypy, ruff) opt out via +# ``HATCH_BUILD_NO_HOOKS=true`` to keep their installs fast. +[tool.hatch.build.targets.wheel.hooks.mypyc] +dependencies = ["hatch-mypyc"] +# Compile the Array wrapper plus the entire iop2 package as one mypyc group: keeping +# them in a single compilation unit means cross-module calls (Array → backend, iop +# function → backend) are direct compiled-to-compiled calls. New framework backends +# get picked up automatically — no edits needed here. ``--ignore-missing-imports`` +# treats torch/jax/tensorflow types as ``Any`` when those packages aren't in the +# build environment; the wrapper-side boundary cost is still removed. +include = [ + "decent_array/", +] +mypy-args = ["--ignore-missing-imports"] + +[tool.tox] +envlist = ["dev", "mypy", "pytest", "ruff", "sphinx"] + +[tool.tox.env.dev] +description = "Generate dev venv with all dependencies, active with `source .tox/dev/bin/activate`" +set_env = { HATCH_BUILD_NO_HOOKS = "true" } +deps = [ + ".[dev]", + ".[dev-cpu]", + ".[sphinx-tools]", + "git+https://github.com/microsoft/python-type-stubs.git@main" +] + +[tool.tox.env.dev-gpu] +description = "Generate dev venv with all dependencies including GPU support, active with `source .tox/dev-gpu/bin/activate`" +set_env = { PIP_EXTRA_INDEX_URL = "https://download.pytorch.org/whl/cu128", HATCH_BUILD_NO_HOOKS = "true" } # torch 2.11 uses CUDA 13.0 which seems to have broken linalg.norm, use cuda 12.8 for now +deps = [ + ".[dev]", + ".[dev-gpu]", + ".[sphinx-tools]", + "git+https://github.com/microsoft/python-type-stubs.git@main" +] + +[tool.tox.env.sphinx] +description = "Generate rst and html files using sphinx" +set_env = { HATCH_BUILD_NO_HOOKS = "true" } +deps = [".[sphinx-tools]"] +commands = [ + ["sphinx-apidoc", "-o", "docs/source/api", "decent_array", "--separate", "--no-toc", "--templatedir=docs/source/_templates"], + ["sphinx-build", "-W", "-E", "-b", "html", 
"docs/source", "docs/build/html"] +] + +[tool.tox.env.mypy] +description = "Run mypy (static type checker)" +set_env = { HATCH_BUILD_NO_HOOKS = "true" } +deps = [".[dev]", ".[dev-cpu]", "git+https://github.com/microsoft/python-type-stubs.git@main"] +commands = [["mypy", "decent_array"]] + +[tool.tox.env.clean-mypyc] +description = "Remove mypyc-compiled artifacts so Python imports the .py source again (use during dev iteration)" +set_env = { HATCH_BUILD_NO_HOOKS = "true" } +skip_install = true +deps = [] +commands = [ + ["python", "-c", "import pathlib; ps = list(pathlib.Path('.').glob('decent_array/**/*.so')) + list(pathlib.Path('.').glob('*__mypyc.cpython-*.so')); [p.unlink() for p in ps]; print('removed', len(ps), 'compiled artifact(s)')"], +] + +[tool.tox.env.mypyc] +description = "Compile hot-path modules in place with mypyc (drops .so files alongside the .py source)" +deps = ["mypy", "numpy", "setuptools"] +skip_install = true +# First command clears any stale shared-runtime ``.so`` at the project root from a +# previous compile (its name is hash-derived from the file list, so changing inputs +# leaves orphans); second command compiles the Array module + entire iop2 package +# (including ``_abstracts``) as one group so cross-module dispatch (Array → +# _Backend.add → concrete backend method) is native compiled-to-compiled. 
+commands = [ + ["python", "-c", "import pathlib; [p.unlink() for p in pathlib.Path('.').glob('*__mypyc.cpython-*.so')]"], + ["python", "-m", "mypyc", + "--ignore-missing-imports", + "decent_array/", + ], +] + +[tool.tox.env.pytest] +description = "Run pytest (test executor)" +deps = [".[dev]", ".[dev-cpu]"] +commands = [["pytest", "-rs"]] + +[tool.tox.env.ruff] +description = "Run ruff (format and style checker)" +deps = ["ruff"] +set_env = { HATCH_BUILD_NO_HOOKS = "true" } +skip_install = true +commands = [ + ["ruff", "check", "decent_array"], + ["ruff", "format", "decent_array", "--check"] +] + +[tool.pytest.ini_options] +addopts = "-q -W error" +testpaths = ["tests"] + +[tool.mypy] +strict = true + +[tool.ruff] +lint.flake8-annotations.mypy-init-return = true +lint.select = ["ALL"] +lint.ignore = [ + "BLE001", # blind-except, complains when catching `Exception` + "C901", # complex-structure, complains when there are many branches + "COM812", # missing-trailing-comma, may cause conflicts when used with the formatter + "CPY001", # missing-copyright-notice + "D100", # undocumented-public-module + "D104", # undocumented-public-package + "D107", # undocumented-public-init + "D203", # incorrect-blank-line-before-class, incompatible with no-blank-line-before-class (D211) + "D212", # multi-line-summary-first-line, incompatible with multi-line-summary-second-line (D213) + "D415", # missing-terminal-punctuation, D400 already enforces first line to end with period + "DOC201", # docstring-missing-returns, always documenting the return value is too verbose + "EM101", # raw-string-in-exception, complains about putting raw string as error msg + "EM102", # f-string-in-exception, complains about putting f string as error msg + "FURB140", # reimplemented-starmap, wants to replace comprehensions with the less efficient itertools.starmap + "PLR0913", # too-many-arguments, complains when there are more than 5 arguments + "PLR2004", # magic-value-comparison, complains about 
dimension checks in cost functions + "PLR6301", # no-self-use, complains when self is unused even when implementing an abstract method + "S311", # suspicious-non-cryptographic-random-usage, complains about `random.random()` + "TC001", # typing-only-first-party-import, reduced rt overhead but slower development and messy imports + "TC002", # typing-only-third-party-import, reduced rt overhead but slower development and messy imports + "TC003", # typing-only-standard-library-import, reduced rt overhead but slower development and messy imports + "TRY003", # raise-vanilla-args, complains about simple error messages like `ValueError('Matrix A must be 2D')` + "FBT001", # boolean-type-hint-positional-argument, complains about bool args with default values in functions + "FBT002", # boolean-default-value-positional-argument, complains about bool args with default values in functions + "ICN001", # unconventional-import-alias, complains about common import aliases like `import numpy as np` but it doesn't work for most libraries +] +preview = true +line-length = 120 diff --git a/readthedocs.yaml b/readthedocs.yaml new file mode 100644 index 0000000..3b75931 --- /dev/null +++ b/readthedocs.yaml @@ -0,0 +1,27 @@ +# Read the Docs configuration file +# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details + +# Required +version: 2 + +# Set the OS, Python version, and other tools you might need +build: + os: ubuntu-24.04 + tools: + python: "3.13" + jobs: + build: + html: + - mkdir -p $READTHEDOCS_OUTPUT/html/ + - python -m sphinx -T -W --keep-going -j 1 -b html -d _build/doctrees -D language=en docs/source $READTHEDOCS_OUTPUT/html + +# Build documentation in the "docs/" directory with Sphinx +sphinx: + configuration: docs/source/conf.py + fail_on_warning: true + +python: + install: + - requirements: docs/sphinx_theme.txt + - method: pip + path: . 
diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..1ca38bf --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,121 @@ +"""Shared fixtures: parametrize tests across every (framework, device) combination. + +Each test using the ``backend`` fixture runs once per (framework, device) pair from +:class:`SupportedFrameworks` x :class:`SupportedDevices`. Combinations whose backend +package is missing or whose device is not present on the current host are marked +``skip`` so the test report stays interpretable on machines with partial accelerator +support. +""" + +from __future__ import annotations + +from collections.abc import Iterator +from typing import TYPE_CHECKING + +import pytest + +from decent_array.interoperability._backend_manager import reset_backends +from decent_array.types import SupportedDevices, SupportedFrameworks + +if TYPE_CHECKING: + from _pytest.fixtures import FixtureRequest + + +def _framework_importable(framework: SupportedFrameworks) -> bool: + try: + if framework == SupportedFrameworks.NUMPY: + import numpy # noqa: F401, PLC0415 + elif framework == SupportedFrameworks.PYTORCH: + import torch # noqa: F401, PLC0415 + elif framework == SupportedFrameworks.JAX: + import jax # noqa: F401, PLC0415 + elif framework == SupportedFrameworks.TENSORFLOW: + import tensorflow # noqa: F401, PLC0415 + except ImportError: + return False + return True + + +def _device_available(framework: SupportedFrameworks, device: SupportedDevices) -> bool: + """Return True iff this (framework, device) pair can run on the current host.""" + if not _framework_importable(framework): + return False + if framework == SupportedFrameworks.NUMPY: + return device == SupportedDevices.CPU + if framework == SupportedFrameworks.PYTORCH: + import torch # noqa: PLC0415 + + if device == SupportedDevices.CPU: + return True + if device == SupportedDevices.GPU: + try: + return bool(torch.cuda.is_available()) + except Exception: + return False + if device == 
SupportedDevices.MPS: + try: + return bool(torch.backends.mps.is_available()) + except Exception: + return False + if framework == SupportedFrameworks.JAX: + if device == SupportedDevices.MPS: + return False + import jax # noqa: PLC0415 + + try: + jax.devices(device.value) + except Exception: + return False + return True + if framework == SupportedFrameworks.TENSORFLOW: + if device == SupportedDevices.MPS: + return False + import tensorflow as tf # noqa: PLC0415 + + if device == SupportedDevices.CPU: + return True + if device == SupportedDevices.GPU: + try: + return len(tf.config.list_physical_devices("GPU")) > 0 + except Exception: + return False + return False + + +def _backend_params() -> list[pytest.ParameterSet]: + params: list[pytest.ParameterSet] = [] + for framework in SupportedFrameworks: + for device in SupportedDevices: + test_id = f"{framework.value}-{device.value}" + if _device_available(framework, device): + params.append(pytest.param((framework, device), id=test_id)) + else: + params.append( + pytest.param( + (framework, device), + id=test_id, + marks=pytest.mark.skip(reason=f"{framework.value}/{device.value} unavailable"), + ) + ) + return params + + +BACKEND_PARAMS = _backend_params() + + +@pytest.fixture(params=BACKEND_PARAMS) +def backend(request: FixtureRequest) -> Iterator[tuple[SupportedFrameworks, SupportedDevices]]: + """Activate the (framework, device) backend for this test, then reset on teardown.""" + from decent_array.interoperability import set_backend # noqa: PLC0415 + + framework, device = request.param + set_backend(framework, device) + yield framework, device + reset_backends() + + +@pytest.fixture +def reset_after() -> Iterator[None]: + """For tests that touch backend-manager state directly without the ``backend`` fixture.""" + yield + reset_backends() diff --git a/tests/test_array.py b/tests/test_array.py new file mode 100644 index 0000000..164e223 --- /dev/null +++ b/tests/test_array.py @@ -0,0 +1,389 @@ +"""Tests for 
:class:`decent_array.Array` operators, properties, and dunders.""" + +from __future__ import annotations + +import re + +import numpy as np +import pytest + +import decent_array.interoperability as iop +from decent_array import Array +from decent_array.interoperability._backend_manager import reset_backends + + +def _np(arr: Array) -> np.ndarray: + """Convert a backend-native ``Array`` to a numpy array for assertions.""" + return iop.to_numpy(arr) + + +def _create_array(data: float | list[float] | list[list[float]]) -> Array: + """Create an Array from a list of floats for testing.""" + return iop.from_numpy(np.array(data, dtype=np.float32)) + + +# Construction ------------------------------------------------------------ + + +def test_init_requires_active_backend() -> None: + reset_backends() + with pytest.raises(RuntimeError, match=re.compile(r"No backend registered", re.IGNORECASE)): + Array(np.zeros(3)) + + +def test_init_records_active_backend(backend: tuple) -> None: + arr = iop.zeros((3,)) + assert arr._backend is not None + assert isinstance(arr, Array) + + +# Binary arithmetic ------------------------------------------------------- + + +def test_add_array(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + b = _create_array([4.0, 5.0, 6.0]) + np.testing.assert_allclose(_np(a + b), [5.0, 7.0, 9.0]) + + +def test_add_scalar(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + np.testing.assert_allclose(_np(a + 2), [3.0, 4.0, 5.0]) + + +def test_radd_scalar(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + np.testing.assert_allclose(_np(2 + a), [3.0, 4.0, 5.0]) + + +def test_sub_array(backend: tuple) -> None: + a = _create_array([4.0, 5.0, 6.0]) + b = _create_array([1.0, 2.0, 3.0]) + np.testing.assert_allclose(_np(a - b), [3.0, 3.0, 3.0]) + + +def test_sub_scalar(backend: tuple) -> None: + a = _create_array([4.0, 5.0, 6.0]) + np.testing.assert_allclose(_np(a - 1), [3.0, 4.0, 5.0]) + + +def test_rsub_scalar(backend: tuple) -> 
None: + a = _create_array([1.0, 2.0, 3.0]) + np.testing.assert_allclose(_np(10 - a), [9.0, 8.0, 7.0]) + + +def test_mul_array(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + b = _create_array([2.0, 3.0, 4.0]) + np.testing.assert_allclose(_np(a * b), [2.0, 6.0, 12.0]) + + +def test_mul_scalar(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + np.testing.assert_allclose(_np(a * 3), [3.0, 6.0, 9.0]) + + +def test_rmul_scalar(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + np.testing.assert_allclose(_np(3 * a), [3.0, 6.0, 9.0]) + + +def test_truediv_array(backend: tuple) -> None: + a = _create_array([4.0, 9.0, 16.0]) + b = _create_array([2.0, 3.0, 4.0]) + np.testing.assert_allclose(_np(a / b), [2.0, 3.0, 4.0]) + + +def test_truediv_scalar(backend: tuple) -> None: + a = _create_array([2.0, 4.0, 6.0]) + np.testing.assert_allclose(_np(a / 2), [1.0, 2.0, 3.0]) + + +def test_rtruediv_scalar(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 4.0]) + np.testing.assert_allclose(_np(8 / a), [8.0, 4.0, 2.0]) + + +def test_matmul_array(backend: tuple) -> None: + a = _create_array([[1.0, 2.0], [3.0, 4.0]]) + b = _create_array([[5.0, 6.0], [7.0, 8.0]]) + expected = np.array([[19.0, 22.0], [43.0, 50.0]]) + np.testing.assert_allclose(_np(a @ b), expected) + + +def test_rmatmul_explicit_call(backend: tuple) -> None: + # Python dispatch never picks Array.__rmatmul__ when both operands are Array, + # so call it directly to exercise the code path. 
+ a = _create_array([[1.0, 2.0], [3.0, 4.0]]) + b = _create_array([[5.0, 6.0], [7.0, 8.0]]) + np.testing.assert_allclose(_np(a.__rmatmul__(b)), _np(b @ a)) + + +def test_pow_scalar(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + np.testing.assert_allclose(_np(a**2), [1.0, 4.0, 9.0]) + + +# Comparisons ------------------------------------------------------------ + + +def test_eq_array(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + b = _create_array([1.0, 5.0, 3.0]) + result = a == b + assert isinstance(result, Array) + np.testing.assert_array_equal(_np(result), [True, False, True]) + + +def test_eq_scalar(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + np.testing.assert_array_equal(_np(a == 2.0), [False, True, False]) + + +def test_ne_array(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + b = _create_array([1.0, 5.0, 3.0]) + np.testing.assert_array_equal(_np(a != b), [False, True, False]) + + +def test_ne_scalar(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + np.testing.assert_array_equal(_np(a != 2.0), [True, False, True]) + + +def test_lt_array(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + b = _create_array([1.0, 5.0, 0.0]) + np.testing.assert_array_equal(_np(a < b), [False, True, False]) + + +def test_lt_scalar(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + np.testing.assert_array_equal(_np(a < 2.5), [True, True, False]) + + +def test_le_array(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + b = _create_array([1.0, 5.0, 0.0]) + np.testing.assert_array_equal(_np(a <= b), [True, True, False]) + + +def test_le_scalar(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + np.testing.assert_array_equal(_np(a <= 2.0), [True, True, False]) + + +def test_gt_array(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + b = _create_array([1.0, 5.0, 0.0]) + np.testing.assert_array_equal(_np(a > b), [False, False, True]) + + +def 
test_gt_scalar(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + np.testing.assert_array_equal(_np(a > 1.5), [False, True, True]) + + +def test_ge_array(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + b = _create_array([1.0, 5.0, 0.0]) + np.testing.assert_array_equal(_np(a >= b), [True, False, True]) + + +def test_ge_scalar(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + np.testing.assert_array_equal(_np(a >= 2.0), [False, True, True]) + + +def test_array_is_unhashable(backend: tuple) -> None: # noqa: ARG001 + """ + Overriding ``__eq__`` makes Array unhashable, matching numpy/torch/jax/tf. + + The check is on ``hash(arr)`` rather than ``Array.__hash__ is None`` because + mypyc realizes unhashability via a type-error-raising slot descriptor while + pure Python sets the attribute to ``None``; ``hash()`` raises ``TypeError`` + in either case. + """ + arr = _create_array([1.0]) + with pytest.raises(TypeError, match=re.compile(r"unhashable", re.IGNORECASE)): + hash(arr) + + +# Bitwise ---------------------------------------------------------------- + + +def test_and_array(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0, 4.0]) + mask1 = a > 1.0 + mask2 = a < 4.0 + np.testing.assert_array_equal(_np(mask1 & mask2), [False, True, True, False]) + + +def test_and_scalar(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + mask = a > 1.0 + np.testing.assert_array_equal(_np(mask & True), [False, True, True]) + np.testing.assert_array_equal(_np(mask & False), [False, False, False]) + + +def test_rand_scalar(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + mask = a > 1.0 + # Python evaluates ``True & mask`` via ``mask.__rand__(True)`` since bool's + # ``__and__`` doesn't accept Array. 
+ np.testing.assert_array_equal(_np(True & mask), [False, True, True]) + + +# In-place arithmetic ----------------------------------------------------- + + +def test_iadd_array(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + b = _create_array([4.0, 5.0, 6.0]) + a += b + np.testing.assert_allclose(_np(a), [5.0, 7.0, 9.0]) + + +def test_iadd_scalar(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + a += 10 + np.testing.assert_allclose(_np(a), [11.0, 12.0, 13.0]) + + +def test_isub_array(backend: tuple) -> None: + a = _create_array([5.0, 6.0, 7.0]) + b = _create_array([1.0, 2.0, 3.0]) + a -= b + np.testing.assert_allclose(_np(a), [4.0, 4.0, 4.0]) + + +def test_imul_scalar(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + a *= 2 + np.testing.assert_allclose(_np(a), [2.0, 4.0, 6.0]) + + +def test_itruediv_scalar(backend: tuple) -> None: + a = _create_array([2.0, 4.0, 6.0]) + a /= 2 + np.testing.assert_allclose(_np(a), [1.0, 2.0, 3.0]) + + +# Unary ------------------------------------------------------------------ + + +def test_neg(backend: tuple) -> None: + a = _create_array([1.0, -2.0, 3.0]) + np.testing.assert_allclose(_np(-a), [-1.0, 2.0, -3.0]) + + +def test_abs(backend: tuple) -> None: + a = _create_array([1.0, -2.0, 3.0]) + np.testing.assert_allclose(_np(abs(a)), [1.0, 2.0, 3.0]) + + +# Indexing --------------------------------------------------------------- + + +def test_getitem_int(backend: tuple) -> None: + a = _create_array([10.0, 20.0, 30.0]) + np.testing.assert_allclose(_np(a[1]), 20.0) + + +def test_getitem_slice(backend: tuple) -> None: + a = _create_array([10.0, 20.0, 30.0, 40.0]) + np.testing.assert_allclose(_np(a[1:3]), [20.0, 30.0]) + + +def test_getitem_tuple(backend: tuple) -> None: + a = _create_array([[1.0, 2.0], [3.0, 4.0]]) + np.testing.assert_allclose(_np(a[1, 0]), 3.0) + + +def test_setitem_with_array(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + a[1] = _create_array(99.0) + 
np.testing.assert_allclose(_np(a), [1.0, 99.0, 3.0]) + + +def test_setitem_with_scalar(backend: tuple) -> None: + # __setitem__ wraps a non-Array value in Array internally. + a = _create_array([1.0, 2.0, 3.0]) + a[2] = 99.0 + np.testing.assert_allclose(_np(a), [1.0, 2.0, 99.0]) + + +# Container / coercion / repr ------------------------------------------- + + +def test_len(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0, 4.0]) + assert len(a) == 4 + + +def test_float_coercion(backend: tuple) -> None: + a = _create_array([2.5]) + assert float(a) == pytest.approx(2.5) + + +def test_repr(backend: tuple) -> None: + a = _create_array([1.0, 2.0]) + text = repr(a) + assert text.startswith("Array(") + assert text.endswith(")") + + +def test_str(backend: tuple) -> None: + a = _create_array([1.0, 2.0]) + # ``str(arr)`` delegates to the wrapped value's stringifier; just check it succeeds + # and is non-empty rather than pin per-backend formatting. + assert isinstance(str(a), str) + assert len(str(a)) > 0 + + +# Properties ------------------------------------------------------------- + + +def test_shape(backend: tuple) -> None: + a = iop.from_numpy(np.zeros((2, 3, 4), dtype=np.float32)) + assert a.shape == (2, 3, 4) + + +def test_size(backend: tuple) -> None: + a = iop.from_numpy(np.zeros((2, 3, 4), dtype=np.float32)) + assert a.size == 24 + + +def test_ndim(backend: tuple) -> None: + a = iop.from_numpy(np.zeros((2, 3, 4), dtype=np.float32)) + assert a.ndim == 3 + + +def test_transpose_property(backend: tuple) -> None: + a = _create_array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + np.testing.assert_allclose(_np(a.transpose), [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]]) + + +def test_T_alias(backend: tuple) -> None: + a = _create_array([[1.0, 2.0], [3.0, 4.0]]) + np.testing.assert_allclose(_np(a.T), _np(a.transpose)) + + +def test_any_true(backend: tuple) -> None: + a = _create_array([0.0, 0.0, 1.0]) + assert a.any is True + + +def test_any_false(backend: tuple) -> None: + a = 
_create_array([0.0, 0.0, 0.0]) + assert a.any is False + + +def test_all_true(backend: tuple) -> None: + a = _create_array([1.0, 2.0, 3.0]) + assert a.all is True + + +def test_all_false(backend: tuple) -> None: + a = _create_array([1.0, 0.0, 3.0]) + assert a.all is False diff --git a/tests/test_backend_manager.py b/tests/test_backend_manager.py new file mode 100644 index 0000000..982812d --- /dev/null +++ b/tests/test_backend_manager.py @@ -0,0 +1,189 @@ +"""Tests for :mod:`decent_array.interoperability._backend_manager`.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +import pytest + +from decent_array.interoperability import _backend_manager as backend_manager +from decent_array.interoperability._abstracts import Backend +from decent_array.interoperability._backend_manager import ( + _instantiate, + _normalize, + register_backend, + register_backend_listener, + reset_backends, + set_backend, +) +from decent_array.types import SupportedDevices, SupportedFrameworks + +if TYPE_CHECKING: + from collections.abc import Iterator + + +@pytest.fixture(autouse=True) +def _isolate_listeners_and_backends() -> Iterator[None]: + """Snapshot+restore module-level state so backend-manager tests don't leak.""" + listeners_snapshot = backend_manager._BACKEND_LISTENERS.copy() + registry_snapshot = backend_manager._BACKEND_REGISTRY.copy() + reset_backends() + yield + backend_manager._BACKEND_LISTENERS[:] = listeners_snapshot + backend_manager._BACKEND_REGISTRY.clear() + backend_manager._BACKEND_REGISTRY.update(registry_snapshot) + reset_backends() + + +# _normalize ------------------------------------------------------------- + + +def test_normalize_accepts_enum() -> None: + assert _normalize(SupportedFrameworks.NUMPY) == SupportedFrameworks.NUMPY + + +def test_normalize_accepts_string() -> None: + assert _normalize("numpy") == SupportedFrameworks.NUMPY + + +def test_normalize_unknown_raises() -> None: + with pytest.raises(KeyError, match=r"Unknown 
backend"): + _normalize("not-a-backend") + + +# set_backend ------------------------------------------------------------ + + +def test_set_backend_with_string() -> None: + set_backend("numpy") + assert backend_manager._ACTIVE_BACKEND.get() == SupportedFrameworks.NUMPY + + +def test_set_backend_with_enum() -> None: + set_backend(SupportedFrameworks.NUMPY) + assert backend_manager._ACTIVE_BACKEND.get() == SupportedFrameworks.NUMPY + + +def test_set_backend_idempotent_same_backend() -> None: + set_backend("numpy") + # Re-activating with the same backend+device must be a no-op (no exception). + set_backend("numpy") + assert backend_manager._ACTIVE_BACKEND.get() == SupportedFrameworks.NUMPY + + +def test_set_backend_different_backend_raises() -> None: + set_backend("numpy") + with pytest.raises(RuntimeError, match=r"already set to"): + set_backend("pytorch") + + +def test_set_backend_with_string_device() -> None: + set_backend("numpy", "cpu") + instance = _instantiate(SupportedFrameworks.NUMPY, SupportedDevices.CPU) + assert instance.device == SupportedDevices.CPU + + +def test_set_backend_invalid_name_raises() -> None: + with pytest.raises(KeyError): + set_backend("not-a-backend") + + +# register_backend ------------------------------------------------------- + + +def test_register_backend_rejects_non_subclass() -> None: + class NotABackend: + pass + + with pytest.raises(TypeError, match=r"subclass of Backend"): + register_backend(SupportedFrameworks.NUMPY, NotABackend) # type: ignore[arg-type] + + +def test_register_backend_replaces_cached_instance() -> None: + # First import registers the real backend; instantiate to populate cache. + set_backend("numpy") + cached = backend_manager._BACKEND_INSTANCES.get(SupportedFrameworks.NUMPY) + assert cached is not None + + # Re-register the same class — cache should be cleared so next instantiate is fresh. 
+ from decent_array.interoperability._numpy.numpy_backend import NumpyBackend # noqa: PLC0415 + + reset_backends() + register_backend(SupportedFrameworks.NUMPY, NumpyBackend) + assert SupportedFrameworks.NUMPY not in backend_manager._BACKEND_INSTANCES + + +# register_backend_listener --------------------------------------------- + + +def test_listener_called_on_activation() -> None: + received: list[Backend | None] = [] + + def listener(backend: Backend | None) -> None: + received.append(backend) + + register_backend_listener(listener) + set_backend("numpy") + assert len(received) == 1 + assert isinstance(received[0], Backend) + + +def test_listener_called_immediately_when_backend_already_active() -> None: + set_backend("numpy") + received: list[Backend | None] = [] + + def listener(backend: Backend | None) -> None: + received.append(backend) + + register_backend_listener(listener) + assert len(received) == 1 + assert isinstance(received[0], Backend) + + +def test_listener_called_with_none_on_reset() -> None: + set_backend("numpy") + received: list[Backend | None] = [] + + def listener(backend: Backend | None) -> None: + received.append(backend) + + register_backend_listener(listener) + reset_backends() + # First call: immediate notification with active backend (because already active). + # Second call: notification with None on reset. 
+ assert received[-1] is None + + +# reset_backends --------------------------------------------------------- + + +def test_reset_backends_clears_active() -> None: + set_backend("numpy") + reset_backends() + assert backend_manager._ACTIVE_BACKEND.get() is None + assert backend_manager._BACKEND_INSTANCE is None + + +def test_reset_backends_clears_instance_cache() -> None: + set_backend("numpy") + assert SupportedFrameworks.NUMPY in backend_manager._BACKEND_INSTANCES + reset_backends() + assert SupportedFrameworks.NUMPY not in backend_manager._BACKEND_INSTANCES + + +# _instantiate ---------------------------------------------------------- + + +def test_instantiate_caches_instance() -> None: + a = _instantiate(SupportedFrameworks.NUMPY, SupportedDevices.CPU) + b = _instantiate(SupportedFrameworks.NUMPY, SupportedDevices.CPU) + assert a is b + + +def test_set_backend_device_mismatch_raises() -> None: + set_backend("numpy", SupportedDevices.CPU) + # NumPy backend rejects non-CPU devices at construction; check behavior via the + # configured-mismatch path: re-set with a different device after first activation. + with pytest.raises((RuntimeError, ValueError)): + # The same backend cannot be reconfigured to a different device. + set_backend("numpy", "gpu") diff --git a/tests/test_decorators.py b/tests/test_decorators.py new file mode 100644 index 0000000..3c75137 --- /dev/null +++ b/tests/test_decorators.py @@ -0,0 +1,107 @@ +"""Tests for :mod:`decent_array.interoperability.decorators`. + +Note: this file deliberately omits ``from __future__ import annotations``. The +:func:`autodecorate_cost_method` decorator inspects ``__annotations__["return"] is Array`` +on the superclass method, and PEP 563-style stringified annotations would break that +identity check. 
+""" + +from typing import Any + +import numpy as np + +import decent_array.interoperability as iop +from decent_array import Array +from decent_array.interoperability._decorators import autodecorate_cost_method + + +class _Base: + def returns_array(self, x: Array) -> Array: + raise NotImplementedError + + def returns_float(self, x: Array) -> float: + raise NotImplementedError + + def takes_kwarg(self, x: Array, *, scale: Array) -> Array: + raise NotImplementedError + + def passes_through_non_array(self, x: Array, n: int) -> Array: + raise NotImplementedError + + +def test_unwraps_array_args(backend: tuple) -> None: + seen: list[Any] = [] + + class Impl(_Base): + @autodecorate_cost_method(_Base.returns_array) + def returns_array(self, x: Any) -> Any: + seen.append(x) + return x * 2 + + arr = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + result = Impl().returns_array(arr) + # The decorator should have unwrapped the Array to its underlying value. + assert not isinstance(seen[0], Array) + # Annotated `-> Array`, so the return is re-wrapped. + assert isinstance(result, Array) + np.testing.assert_allclose(iop.to_numpy(result), [2.0, 4.0, 6.0]) + + +def test_does_not_rewrap_when_return_not_array(backend: tuple) -> None: + class Impl(_Base): + @autodecorate_cost_method(_Base.returns_float) + def returns_float(self, x: Any) -> float: + # Use a fixed value so the test doesn't depend on framework-specific tensor + # methods (e.g. TF eager tensors don't expose ``.sum()`` without enabling + # the numpy compatibility shim). 
+ assert not isinstance(x, Array) + return 42.0 + + arr = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + result = Impl().returns_float(arr) + assert isinstance(result, float) + assert result == 42.0 + + +def test_unwraps_kwargs(backend: tuple) -> None: + class Impl(_Base): + @autodecorate_cost_method(_Base.takes_kwarg) + def takes_kwarg(self, x: Any, *, scale: Any) -> Any: + assert not isinstance(x, Array) + assert not isinstance(scale, Array) + return x * scale + + arr = iop.from_numpy(np.array([1.0, 2.0], dtype=np.float32)) + scale = iop.from_numpy(np.array([10.0, 100.0], dtype=np.float32)) + result = Impl().takes_kwarg(arr, scale=scale) + assert isinstance(result, Array) + np.testing.assert_allclose(iop.to_numpy(result), [10.0, 200.0]) + + +def test_passes_non_array_args_unchanged(backend: tuple) -> None: + class Impl(_Base): + @autodecorate_cost_method(_Base.passes_through_non_array) + def passes_through_non_array(self, x: Any, n: int) -> Any: + assert isinstance(n, int) + return x * n + + arr = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + result = Impl().passes_through_non_array(arr, 4) + assert isinstance(result, Array) + np.testing.assert_allclose(iop.to_numpy(result), [4.0, 8.0, 12.0]) + + +def test_does_not_double_wrap_when_impl_returns_array(backend: tuple) -> None: + """If the decorated impl already returns an :class:`Array`, the wrapper should not double-wrap.""" + + class Impl(_Base): + @autodecorate_cost_method(_Base.returns_array) + def returns_array(self, x: Any) -> Array: + return Array(x * 3) + + arr = iop.from_numpy(np.array([1.0, 2.0], dtype=np.float32)) + result = Impl().returns_array(arr) + assert isinstance(result, Array) + # If the wrapper double-wrapped, ``result.value`` would itself be an Array. 
+ assert not isinstance(result.value, Array) + np.testing.assert_allclose(iop.to_numpy(result), [3.0, 6.0]) diff --git a/tests/test_iop_functions.py b/tests/test_iop_functions.py new file mode 100644 index 0000000..117daac --- /dev/null +++ b/tests/test_iop_functions.py @@ -0,0 +1,556 @@ +"""Tests for the module-level functions in :mod:`decent_array.interoperability.iop.functions`.""" + +from __future__ import annotations + +import numpy as np +import pytest + +import decent_array.interoperability as iop +from decent_array import Array +from decent_array.interoperability._backend_manager import reset_backends +from decent_array.types import SupportedDevices + + +def _np(arr: Array) -> np.ndarray: + return iop.to_numpy(arr) + + +# Array creation --------------------------------------------------------- + + +def test_zeros(backend: tuple) -> None: + arr = iop.zeros((2, 3)) + np.testing.assert_allclose(_np(arr), np.zeros((2, 3))) + + +def test_zeros_like(backend: tuple) -> None: + src = iop.from_numpy(np.ones((2, 3), dtype=np.float32)) + arr = iop.zeros_like(src) + np.testing.assert_allclose(_np(arr), np.zeros((2, 3))) + + +def test_ones(backend: tuple) -> None: + arr = iop.ones((2, 3)) + np.testing.assert_allclose(_np(arr), np.ones((2, 3))) + + +def test_ones_like(backend: tuple) -> None: + src = iop.from_numpy(np.zeros((2, 3), dtype=np.float32)) + arr = iop.ones_like(src) + np.testing.assert_allclose(_np(arr), np.ones((2, 3))) + + +def test_eye(backend: tuple) -> None: + arr = iop.eye(3) + np.testing.assert_allclose(_np(arr), np.eye(3)) + + +def test_eye_like(backend: tuple) -> None: + src = iop.from_numpy(np.zeros((4, 4), dtype=np.float32)) + arr = iop.eye_like(src) + np.testing.assert_allclose(_np(arr), np.eye(4)) + + +def test_device_to_native(backend: tuple) -> None: + framework, device = backend + native = iop.device_to_native(device) + # Just verify the call succeeds and returns a non-None value (varies per backend). 
+ assert native is not None + + +def test_device_of(backend: tuple) -> None: + _framework, device = backend + arr = iop.zeros((3,)) + assert iop.device_of(arr) == device + + +# Array manipulation ----------------------------------------------------- + + +def test_copy_independent(backend: tuple) -> None: + src = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + dst = iop.copy(src) + np.testing.assert_allclose(_np(dst), [1.0, 2.0, 3.0]) + # Mutating the copy shouldn't affect the original. + dst[0] = 99.0 + np.testing.assert_allclose(_np(src), [1.0, 2.0, 3.0]) + + +def test_to_numpy(backend: tuple) -> None: + arr = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + out = iop.to_numpy(arr) + assert isinstance(out, np.ndarray) + np.testing.assert_allclose(out, [1.0, 2.0, 3.0]) + + +def test_from_numpy_roundtrip(backend: tuple) -> None: + raw = np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32) + arr = iop.from_numpy(raw) + np.testing.assert_allclose(_np(arr), raw) + + +def test_from_numpy_like_matches_dtype(backend: tuple) -> None: + # ``like`` is int32 so the resulting Array should be int32 even though the + # source numpy array is float32. (Avoids float64, which torch MPS rejects.) + like = iop.from_numpy(np.array([0, 0], dtype=np.int32)) + raw = np.array([1.5, 2.5], dtype=np.float32) + arr = iop.from_numpy_like(raw, like) + out_np = _np(arr) + assert out_np.dtype == np.int32 + np.testing.assert_array_equal(out_np, [1, 2]) + + +def test_from_numpy_like_matches_device(backend: tuple) -> None: + # ``like`` lives on whichever device the active backend uses; the result of + # from_numpy_like must report the same device. 
+ like = iop.from_numpy(np.array([0.0, 0.0], dtype=np.float32)) + raw = np.array([3.0, 4.0], dtype=np.float32) + arr = iop.from_numpy_like(raw, like) + assert iop.device_of(arr) == iop.device_of(like) + + +def test_to_array_from_scalar(backend: tuple) -> None: + arr = iop.to_array(2.5) + np.testing.assert_allclose(_np(arr), 2.5) + + +def test_stack(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0], dtype=np.float32)) + b = iop.from_numpy(np.array([3.0, 4.0], dtype=np.float32)) + arr = iop.stack([a, b]) + np.testing.assert_allclose(_np(arr), [[1.0, 2.0], [3.0, 4.0]]) + + +def test_stack_dim(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0], dtype=np.float32)) + b = iop.from_numpy(np.array([3.0, 4.0], dtype=np.float32)) + arr = iop.stack([a, b], axis=1) + np.testing.assert_allclose(_np(arr), [[1.0, 3.0], [2.0, 4.0]]) + + +def test_stack_empty_raises(backend: tuple) -> None: + with pytest.raises(ValueError, match=r"empty sequence"): + iop.stack([]) + + +def test_reshape(backend: tuple) -> None: + arr = iop.from_numpy(np.arange(6, dtype=np.float32)) + out = iop.reshape(arr, (2, 3)) + np.testing.assert_allclose(_np(out), np.arange(6, dtype=np.float32).reshape(2, 3)) + + +def test_transpose_default(backend: tuple) -> None: + arr = iop.from_numpy(np.array([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], dtype=np.float32)) + np.testing.assert_allclose( + _np(iop.transpose(arr)), [[1.0, 4.0], [2.0, 5.0], [3.0, 6.0]] + ) + + +def test_transpose_explicit_dim(backend: tuple) -> None: + arr = iop.from_numpy(np.zeros((2, 3, 4), dtype=np.float32)) + out = iop.transpose(arr, axis=(1, 0, 2)) + assert iop.shape(out) == (3, 2, 4) + + +def test_shape_function(backend: tuple) -> None: + arr = iop.from_numpy(np.zeros((2, 3, 4), dtype=np.float32)) + assert iop.shape(arr) == (2, 3, 4) + + +def test_size_function(backend: tuple) -> None: + arr = iop.from_numpy(np.zeros((2, 3, 4), dtype=np.float32)) + assert iop.size(arr) == 24 + + +def test_ndim_function(backend: tuple) 
-> None: + arr = iop.from_numpy(np.zeros((2, 3, 4), dtype=np.float32)) + assert iop.ndim(arr) == 3 + + +def test_squeeze_default(backend: tuple) -> None: + arr = iop.from_numpy(np.zeros((1, 3, 1), dtype=np.float32)) + assert iop.shape(iop.squeeze(arr)) == (3,) + + +def test_squeeze_specific_dim(backend: tuple) -> None: + arr = iop.from_numpy(np.zeros((1, 3, 1), dtype=np.float32)) + assert iop.shape(iop.squeeze(arr, axis=0)) == (3, 1) + + +def test_unsqueeze(backend: tuple) -> None: + arr = iop.from_numpy(np.zeros((3,), dtype=np.float32)) + assert iop.shape(iop.unsqueeze(arr, axis=0)) == (1, 3) + assert iop.shape(iop.unsqueeze(arr, axis=1)) == (3, 1) + + +def test_diag_from_vector(backend: tuple) -> None: + arr = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.diag(arr)), np.diag([1.0, 2.0, 3.0])) + + +def test_diag_from_matrix(backend: tuple) -> None: + arr = iop.from_numpy(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.diag(arr)), [1.0, 4.0]) + + +def test_astype_to_float(backend: tuple) -> None: + arr = iop.to_array(3.0) + out = iop.astype(arr, float) + assert isinstance(out, float) + assert out == pytest.approx(3.0) + + +def test_astype_to_int(backend: tuple) -> None: + arr = iop.to_array(3.0) + out = iop.astype(arr, int) + assert isinstance(out, int) + assert out == 3 + + +def test_astype_to_bool(backend: tuple) -> None: + arr = iop.to_array(1.0) + out = iop.astype(arr, bool) + assert isinstance(out, bool) + assert out is True + + +# Linalg ----------------------------------------------------------------- + + +def test_dot(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + b = iop.from_numpy(np.array([4.0, 5.0, 6.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.dot(a, b)), 32.0) + + +def test_matmul(backend: tuple) -> None: + a = iop.from_numpy(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)) + b = 
iop.from_numpy(np.array([[5.0, 6.0], [7.0, 8.0]], dtype=np.float32)) + expected = np.array([[19.0, 22.0], [43.0, 50.0]], dtype=np.float32) + np.testing.assert_allclose(_np(iop.matmul(a, b)), expected) + + +def test_norm_default_l2(backend: tuple) -> None: + arr = iop.from_numpy(np.array([3.0, 4.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.norm(arr)), 5.0) + + +def test_norm_p1(backend: tuple) -> None: + arr = iop.from_numpy(np.array([3.0, -4.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.norm(arr, p=1)), 7.0) + + +def test_norm_dim_keepdims(backend: tuple) -> None: + arr = iop.from_numpy(np.array([[3.0, 4.0], [6.0, 8.0]], dtype=np.float32)) + out = iop.norm(arr, p=2, axis=1, keepdims=True) + assert iop.shape(out) == (2, 1) + np.testing.assert_allclose(_np(out).reshape(-1), [5.0, 10.0]) + + +# Math reductions -------------------------------------------------------- + + +def test_sum_all(backend: tuple) -> None: + arr = iop.from_numpy(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.sum(arr)), 10.0) + + +def test_sum_dim(backend: tuple) -> None: + arr = iop.from_numpy(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.sum(arr, axis=0)), [4.0, 6.0]) + + +def test_sum_dim_keepdims(backend: tuple) -> None: + arr = iop.from_numpy(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)) + out = iop.sum(arr, axis=0, keepdims=True) + assert iop.shape(out) == (1, 2) + + +def test_mean_all(backend: tuple) -> None: + arr = iop.from_numpy(np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.mean(arr)), 2.5) + + +def test_mean_dim(backend: tuple) -> None: + arr = iop.from_numpy(np.array([[1.0, 2.0], [3.0, 4.0]], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.mean(arr, axis=1)), [1.5, 3.5]) + + +def test_min_all(backend: tuple) -> None: + arr = iop.from_numpy(np.array([3.0, 1.0, 4.0, 1.0, 5.0], dtype=np.float32)) + 
np.testing.assert_allclose(_np(iop.min(arr)), 1.0) + + +def test_min_dim(backend: tuple) -> None: + arr = iop.from_numpy(np.array([[1.0, 5.0], [3.0, 2.0]], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.min(arr, axis=0)), [1.0, 2.0]) + + +def test_max_all(backend: tuple) -> None: + arr = iop.from_numpy(np.array([3.0, 1.0, 4.0, 1.0, 5.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.max(arr)), 5.0) + + +def test_max_dim(backend: tuple) -> None: + arr = iop.from_numpy(np.array([[1.0, 5.0], [3.0, 2.0]], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.max(arr, axis=0)), [3.0, 5.0]) + + +def test_any_true(backend: tuple) -> None: + arr = iop.from_numpy(np.array([0.0, 1.0, 0.0], dtype=np.float32)) + assert iop.any(arr) is True + + +def test_any_false(backend: tuple) -> None: + arr = iop.from_numpy(np.array([0.0, 0.0, 0.0], dtype=np.float32)) + assert iop.any(arr) is False + + +def test_all_true(backend: tuple) -> None: + arr = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + assert iop.all(arr) is True + + +def test_all_false(backend: tuple) -> None: + arr = iop.from_numpy(np.array([1.0, 0.0, 3.0], dtype=np.float32)) + assert iop.all(arr) is False + + +# Math elementwise ------------------------------------------------------- + + +def test_add_two_arrays(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0], dtype=np.float32)) + b = iop.from_numpy(np.array([3.0, 4.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.add(a, b)), [4.0, 6.0]) + + +def test_add_array_and_scalar(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.add(a, 10.0)), [11.0, 12.0]) + + +def test_sub(backend: tuple) -> None: + a = iop.from_numpy(np.array([5.0, 6.0], dtype=np.float32)) + b = iop.from_numpy(np.array([1.0, 2.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.sub(a, b)), [4.0, 4.0]) + + +def test_mul(backend: tuple) -> None: + a = 
iop.from_numpy(np.array([2.0, 3.0], dtype=np.float32)) + b = iop.from_numpy(np.array([4.0, 5.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.mul(a, b)), [8.0, 15.0]) + + +def test_div(backend: tuple) -> None: + a = iop.from_numpy(np.array([8.0, 10.0], dtype=np.float32)) + b = iop.from_numpy(np.array([2.0, 5.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.div(a, b)), [4.0, 2.0]) + + +def test_iadd_func(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0], dtype=np.float32)) + out = iop.iadd(a, 10.0) + # Returned wrapper is the same instance. + assert out is a + np.testing.assert_allclose(_np(a), [11.0, 12.0]) + + +def test_isub_func(backend: tuple) -> None: + a = iop.from_numpy(np.array([5.0, 6.0], dtype=np.float32)) + out = iop.isub(a, 1.0) + assert out is a + np.testing.assert_allclose(_np(a), [4.0, 5.0]) + + +def test_imul_func(backend: tuple) -> None: + a = iop.from_numpy(np.array([2.0, 3.0], dtype=np.float32)) + out = iop.imul(a, 4.0) + assert out is a + np.testing.assert_allclose(_np(a), [8.0, 12.0]) + + +def test_idiv_func(backend: tuple) -> None: + a = iop.from_numpy(np.array([8.0, 12.0], dtype=np.float32)) + out = iop.idiv(a, 4.0) + assert out is a + np.testing.assert_allclose(_np(a), [2.0, 3.0]) + + +def test_pow_function(backend: tuple) -> None: + arr = iop.from_numpy(np.array([2.0, 3.0, 4.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.pow(arr, 2)), [4.0, 9.0, 16.0]) + + +def test_negative(backend: tuple) -> None: + arr = iop.from_numpy(np.array([1.0, -2.0, 3.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.negative(arr)), [-1.0, 2.0, -3.0]) + + +def test_absolute(backend: tuple) -> None: + arr = iop.from_numpy(np.array([1.0, -2.0, 3.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.absolute(arr)), [1.0, 2.0, 3.0]) + + +def test_sqrt(backend: tuple) -> None: + arr = iop.from_numpy(np.array([1.0, 4.0, 9.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.sqrt(arr)), [1.0, 2.0, 
3.0]) + + +# Operators -------------------------------------------------------------- + + +def test_sign(backend: tuple) -> None: + arr = iop.from_numpy(np.array([-2.0, 0.0, 3.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.sign(arr)), [-1.0, 0.0, 1.0]) + + +def test_maximum_arrays(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 5.0, 3.0], dtype=np.float32)) + b = iop.from_numpy(np.array([4.0, 2.0, 6.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.maximum(a, b)), [4.0, 5.0, 6.0]) + + +def test_maximum_array_and_scalar(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 5.0, 3.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.maximum(a, 4.0)), [4.0, 5.0, 4.0]) + + +# Comparisons ------------------------------------------------------------ + + +def test_eq_arrays(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + b = iop.from_numpy(np.array([1.0, 5.0, 3.0], dtype=np.float32)) + np.testing.assert_array_equal(_np(iop.eq(a, b)), [True, False, True]) + + +def test_eq_array_and_scalar(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + np.testing.assert_array_equal(_np(iop.eq(a, 2.0)), [False, True, False]) + + +def test_ne_arrays(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + b = iop.from_numpy(np.array([1.0, 5.0, 3.0], dtype=np.float32)) + np.testing.assert_array_equal(_np(iop.ne(a, b)), [False, True, False]) + + +def test_ne_array_and_scalar(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + np.testing.assert_array_equal(_np(iop.ne(a, 2.0)), [True, False, True]) + + +def test_lt_arrays(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + b = iop.from_numpy(np.array([2.0, 2.0, 2.0], dtype=np.float32)) + np.testing.assert_array_equal(_np(iop.lt(a, b)), [True, False, False]) + + +def 
test_lt_array_and_scalar(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + np.testing.assert_array_equal(_np(iop.lt(a, 2.5)), [True, True, False]) + + +def test_le_arrays(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + b = iop.from_numpy(np.array([2.0, 2.0, 2.0], dtype=np.float32)) + np.testing.assert_array_equal(_np(iop.le(a, b)), [True, True, False]) + + +def test_le_array_and_scalar(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + np.testing.assert_array_equal(_np(iop.le(a, 2.0)), [True, True, False]) + + +def test_gt_arrays(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + b = iop.from_numpy(np.array([2.0, 2.0, 2.0], dtype=np.float32)) + np.testing.assert_array_equal(_np(iop.gt(a, b)), [False, False, True]) + + +def test_gt_array_and_scalar(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + np.testing.assert_array_equal(_np(iop.gt(a, 1.5)), [False, True, True]) + + +def test_ge_arrays(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + b = iop.from_numpy(np.array([2.0, 2.0, 2.0], dtype=np.float32)) + np.testing.assert_array_equal(_np(iop.ge(a, b)), [False, True, True]) + + +def test_ge_array_and_scalar(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + np.testing.assert_array_equal(_np(iop.ge(a, 2.0)), [False, True, True]) + + +# Bitwise ---------------------------------------------------------------- + + +def test_bitwise_and_bool_arrays(backend: tuple) -> None: + a = iop.from_numpy(np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float32)) + mask1 = iop.gt(a, 1.0) + mask2 = iop.lt(a, 4.0) + np.testing.assert_array_equal(_np(iop.bitwise_and(mask1, mask2)), [False, True, True, False]) + + +def test_bitwise_and_int_arrays(backend: tuple) -> None: + # Bitwise on int dtypes is 
well-defined across all backends; bool tensors are + # tested separately because TF dispatches them to ``logical_and``. + a = iop.from_numpy(np.array([0b1100, 0b1010], dtype=np.int32)) + b = iop.from_numpy(np.array([0b1010, 0b0110], dtype=np.int32)) + np.testing.assert_array_equal(_np(iop.bitwise_and(a, b)), [0b1000, 0b0010]) + + +def test_argmax_default(backend: tuple) -> None: + arr = iop.from_numpy(np.array([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.argmax(arr)), 5) + + +def test_argmax_dim(backend: tuple) -> None: + arr = iop.from_numpy(np.array([[1.0, 5.0, 2.0], [4.0, 0.0, 3.0]], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.argmax(arr, axis=1)), [1, 0]) + + +def test_argmin_default(backend: tuple) -> None: + arr = iop.from_numpy(np.array([3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0], dtype=np.float32)) + # First occurrence of minimum is index 1. + np.testing.assert_allclose(_np(iop.argmin(arr)), 1) + + +def test_argmin_dim(backend: tuple) -> None: + arr = iop.from_numpy(np.array([[1.0, 5.0, 2.0], [4.0, 0.0, 3.0]], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.argmin(arr, axis=1)), [0, 1]) + + +def test_argmax_keepdims(backend: tuple) -> None: + arr = iop.from_numpy(np.array([[1.0, 5.0, 2.0], [4.0, 0.0, 3.0]], dtype=np.float32)) + out = iop.argmax(arr, axis=1, keepdims=True) + assert iop.shape(out) == (2, 1) + + +def test_set_item_function(backend: tuple) -> None: + arr = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + iop.set_item(arr, 0, iop.from_numpy(np.array(99.0, dtype=np.float32))) + np.testing.assert_allclose(_np(arr), [99.0, 2.0, 3.0]) + + +def test_get_item_function(backend: tuple) -> None: + arr = iop.from_numpy(np.array([1.0, 2.0, 3.0], dtype=np.float32)) + np.testing.assert_allclose(_np(iop.get_item(arr, 1)), 2.0) + + +# No-backend errors ------------------------------------------------------ + + +def test_function_raises_when_no_backend() -> None: + reset_backends() + with 
pytest.raises(RuntimeError, match=r"No backend active"): + iop.zeros((3,)) + + +def test_to_array_round_trip_with_bool(backend: tuple) -> None: + arr = iop.to_array(True) + out = iop.astype(arr, bool) + assert out is True diff --git a/tests/test_iop_rng.py b/tests/test_iop_rng.py new file mode 100644 index 0000000..10bca67 --- /dev/null +++ b/tests/test_iop_rng.py @@ -0,0 +1,199 @@ +"""Tests for :mod:`decent_array.interoperability.iop.rng`.""" + +from __future__ import annotations + +import random + +import numpy as np +import pytest + +import decent_array.interoperability as iop +from decent_array import Array +from decent_array.interoperability._backend_manager import reset_backends +from decent_array.interoperability._iop import rng as iop_rng + + +def _np(arr: Array) -> np.ndarray: + return iop.to_numpy(arr) + + +# Seed management -------------------------------------------------------- + + +def test_set_seed_records_global(backend: tuple) -> None: + iop.set_seed(123) + assert iop.get_seed() == 123 + + +def test_set_seed_makes_normal_reproducible(backend: tuple) -> None: + iop.set_seed(7) + first = _np(iop.normal(shape=(4,))) + iop.set_seed(7) + second = _np(iop.normal(shape=(4,))) + np.testing.assert_allclose(first, second) + + +def test_set_seed_makes_uniform_reproducible(backend: tuple) -> None: + iop.set_seed(7) + first = _np(iop.uniform(shape=(4,))) + iop.set_seed(7) + second = _np(iop.uniform(shape=(4,))) + np.testing.assert_allclose(first, second) + + +def test_set_seed_seeds_python_random(backend: tuple) -> None: + iop.set_seed(11) + a = random.random() + iop.set_seed(11) + b = random.random() + assert a == b + + +def test_set_seed_without_global_keeps_seed(backend: tuple) -> None: + iop.set_seed(42) + iop_rng._set_seed_without_global(99) + # The "global" seed observable via get_seed() must be unchanged. 
# Fresh activation via the fixture: nothing has called set_seed yet on this + # coordinator instance, but the coordinator is process-singleton and may have + # state from earlier tests, so reset the RNG module state explicitly before + # asserting that the seed starts out as None.
test_normal_default_scalar(backend: tuple) -> None: + iop.set_seed(1) + arr = iop.normal() + assert iop.shape(arr) == () + + +def test_uniform_shape_and_range(backend: tuple) -> None: + iop.set_seed(1) + arr = iop.uniform(low=0.0, high=1.0, shape=(50,)) + samples = _np(arr) + assert samples.shape == (50,) + assert (samples >= 0.0).all() + assert (samples < 1.0).all() + + +def test_uniform_custom_range(backend: tuple) -> None: + iop.set_seed(1) + samples = _np(iop.uniform(low=-2.0, high=-1.0, shape=(50,))) + assert (samples >= -2.0).all() + assert (samples < -1.0).all() + + +def test_normal_like(backend: tuple) -> None: + src = iop.from_numpy(np.zeros((3, 4), dtype=np.float32)) + arr = iop.normal_like(src) + assert iop.shape(arr) == (3, 4) + + +def test_uniform_like(backend: tuple) -> None: + src = iop.from_numpy(np.zeros((3, 4), dtype=np.float32)) + arr = iop.uniform_like(src, low=0.0, high=1.0) + assert iop.shape(arr) == (3, 4) + samples = _np(arr) + assert (samples >= 0.0).all() + assert (samples < 1.0).all() + + +def test_choice_shape(backend: tuple) -> None: + iop.set_seed(1) + population = iop.from_numpy(np.array([10.0, 20.0, 30.0, 40.0, 50.0], dtype=np.float32)) + sample = iop.choice(population, size=3) + assert iop.shape(sample) == (3,) + + +def test_choice_values_in_population(backend: tuple) -> None: + iop.set_seed(1) + pop_np = np.array([10.0, 20.0, 30.0, 40.0, 50.0], dtype=np.float32) + sample = iop.choice(iop.from_numpy(pop_np), size=10) + drawn = _np(sample).reshape(-1) + assert all(v in pop_np for v in drawn) + + +def test_choice_no_replace_unique(backend: tuple) -> None: + iop.set_seed(1) + pop = iop.from_numpy(np.arange(20, dtype=np.float32)) + sample = iop.choice(pop, size=5, replace=False) + drawn = _np(sample).reshape(-1) + assert len(set(drawn.tolist())) == 5 + + +# No-backend errors ----------------------------------------------------- + + +def test_rng_function_raises_when_no_backend() -> None: + reset_backends() + with 
pytest.raises(RuntimeError, match=r"No backend active"): + iop_rng.normal(shape=(3,)) + + +def test_set_seed_raises_when_no_backend() -> None: + reset_backends() + with pytest.raises(RuntimeError, match=r"No backend active"): + iop_rng.set_seed(0)