Skip to content

Commit ae50d45

Browse files
authored
fix: make score nested if loop_is_running (#1276)
1 parent a4a3e56 commit ae50d45

File tree

1 file changed

+14
-8
lines changed

1 file changed

+14
-8
lines changed

src/ragas/metrics/base.py

+14-8
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
from ragas.callbacks import new_group
1919
from ragas.dataset_schema import MultiTurnSample, SingleTurnSample
20+
from ragas.executor import is_event_loop_running
2021
from ragas.run_config import RunConfig
2122
from ragas.utils import deprecated
2223

@@ -59,8 +60,7 @@ class Metric(ABC):
5960

6061
@property
6162
@abstractmethod
62-
def name(self) -> str:
63-
...
63+
def name(self) -> str: ...
6464

6565
@property
6666
def required_columns(self) -> t.Dict[str, t.Set[str]]:
@@ -103,6 +103,15 @@ def score(self: t.Self, row: t.Dict, callbacks: Callbacks = None) -> float:
103103
callbacks = callbacks or []
104104
rm, group_cm = new_group(self.name, inputs=row, callbacks=callbacks)
105105
try:
106+
if is_event_loop_running():
107+
try:
108+
import nest_asyncio
109+
110+
nest_asyncio.apply()
111+
except ImportError:
112+
raise ImportError(
113+
"It seems like your running this in a jupyter-like environment. Please install nest_asyncio with `pip install nest_asyncio` to make it work."
114+
)
106115
loop = asyncio.get_event_loop()
107116
score = loop.run_until_complete(self._ascore(row=row, callbacks=group_cm))
108117
except Exception as e:
@@ -138,8 +147,7 @@ async def ascore(
138147
return score
139148

140149
@abstractmethod
141-
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float:
142-
...
150+
async def _ascore(self, row: t.Dict, callbacks: Callbacks) -> float: ...
143151

144152

145153
@dataclass
@@ -246,8 +254,7 @@ async def _single_turn_ascore(
246254
self,
247255
sample: SingleTurnSample,
248256
callbacks: Callbacks,
249-
) -> float:
250-
...
257+
) -> float: ...
251258

252259

253260
class MultiTurnMetric(Metric):
@@ -299,8 +306,7 @@ async def _multi_turn_ascore(
299306
self,
300307
sample: MultiTurnSample,
301308
callbacks: Callbacks,
302-
) -> float:
303-
...
309+
) -> float: ...
304310

305311

306312
class Ensember:

0 commit comments

Comments
 (0)