diff --git a/dias/rewriter.py b/dias/rewriter.py index bc9e349..66397eb 100644 --- a/dias/rewriter.py +++ b/dias/rewriter.py @@ -372,12 +372,24 @@ def sort_head(called_on, by: str | None, n: int, asc: bool, orig: Callable): req_ty = pd.Series opt_func_obj = functools.partial(getattr(req_ty, func), n=n) - if type(called_on) == req_ty: - return opt_func_obj(self=called_on) + if type(called_on) != req_ty: + assert isinstance(orig, types.LambdaType) + return orig(called_on) + + # nsmallest() errors on columns of dtype object, while sort_values() doesn't + if by is None: + assert type(called_on) == pd.Series + ser = called_on else: + assert type(called_on) == pd.DataFrame + ser = called_on[by] + if ser.dtype == object: assert isinstance(orig, types.LambdaType) return orig(called_on) + return opt_func_obj(self=called_on) + + def substr_search_apply(ser, needle: str, orig: Callable): if type(ser) == pd.Series: ls = ser.tolist() diff --git a/tests/sort_head-object.ipynb b/tests/sort_head-object.ipynb new file mode 100644 index 0000000..e145608 --- /dev/null +++ b/tests/sort_head-object.ipynb @@ -0,0 +1,75 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import numpy as np\n", + "import dias.rewriter" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "df = pd.read_csv('titanic.csv')" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# DIAS_DISABLE\n", + "defa = df['Name'].sort_values().head(4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "our = df['Name'].sort_values().head(4)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Convert to list so that we don't take account of the index. See correctness note in sort_head.ipynb\n", + "comp = [x == y for x, y in zip(our.tolist(), defa.tolist())]\n", + "assert all(comp)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "base", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.13" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/sort_head-object.json b/tests/sort_head-object.json new file mode 100644 index 0000000..6f2184a --- /dev/null +++ b/tests/sort_head-object.json @@ -0,0 +1,24 @@ +{ + "cells": [ + { + "raw": "\ndf = pd.read_csv('titanic.csv')\n", + "modified": "df = pd.read_csv('titanic.csv')\n", + "patts-hit": {}, + "rewritten-exec-time": 3.460375 + }, + { + "raw": "\n# DIAS_VERBOSE\nour = df['Name'].sort_values().head(4)\n", + "modified": "our = dias.rewriter.sort_head(called_on=df['Name'], by=None, n=4, asc=True,\n orig=lambda _DIAS_x: _DIAS_x.sort_values().head(4))\n", + "patts-hit": { + "SortHead": 1 + }, + "rewritten-exec-time": 374.231 + }, + { + "raw": "\n# Convert to list so that we don't take account of the index. See correctness note in sort_head.ipynb\ncomp = [x == y for x, y in zip(our.tolist(), defa.tolist())]\nassert all(comp)\n", + "modified": "comp = [(x == y) for x, y in zip(our.tolist(), defa.tolist())]\nassert all(comp)\n", + "patts-hit": {}, + "rewritten-exec-time": 7.40225 + } + ] +} \ No newline at end of file