Skip to content
Draft
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ venv/
dist/
__pycache__/
Pipfile.lock
uv.lock
.ruff_cache/
.vscode
python/example/test.py
Expand Down
63 changes: 58 additions & 5 deletions python/databend_udf/udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,13 @@
from prometheus_client import start_http_server
import threading

from fastapi import FastAPI, Request, Response
from typing import Any, Dict
from uvicorn import run
import pyarrow as pa
from pyarrow import ipc
from io import BytesIO

import pyarrow as pa
from pyarrow.flight import FlightServerBase, FlightInfo

Expand Down Expand Up @@ -232,11 +239,19 @@ class UDFServer(FlightServerBase):
_location: str
_functions: Dict[str, UserDefinedFunction]

def __init__(self, location="0.0.0.0:8815", metric_location=None, **kwargs):
def __init__(
self,
location="0.0.0.0:8815",
metric_location=None,
http_location=None,
**kwargs,
):
super(UDFServer, self).__init__("grpc://" + location, **kwargs)
self._location = location
self._metric_location = metric_location
self._http_location = http_location
self._functions = {}
self.app = FastAPI()

# Initialize Prometheus metrics
self.requests_count = Counter(
Expand Down Expand Up @@ -296,16 +311,33 @@ def _start_metrics_server(self):

def start_server():
start_http_server(port, host)
logger.info(
f"Prometheus metrics server started on {self._metric_location}"
)

metrics_thread = threading.Thread(target=start_server, daemon=True)
metrics_thread.start()
except Exception as e:
logger.error(f"Failed to start metrics server: {e}")
raise

def _start_httpudf_server(self):
    """Start the HTTP UDF server in a daemon thread, if configured.

    Registers a ``GET /_is_http`` probe endpoint (returns ``1``) so that
    clients can detect HTTP transport support, then serves ``self.app``
    with uvicorn on ``self._http_location`` (``"host:port"``).

    Raises:
        Exception: re-raised if thread setup fails. NOTE(review): errors
            raised inside the already-started server thread (e.g. port
            already in use) are NOT caught by this handler — they die
            with the daemon thread.
    """
    try:
        # rsplit on the last ":" so hosts containing colons
        # (IPv6 literals) still parse; plain "host:port" is unaffected.
        host, port_str = self._http_location.rsplit(":", 1)
        port = int(port_str)
        app = self.app

        @app.get("/_is_http")
        async def is_http():
            # Probe endpoint: a databend client calls this to discover
            # that the server speaks the HTTP UDF protocol.
            return 1

        def start_server():
            run(app, host=host, port=port)

        # Daemon thread: must never block interpreter shutdown.
        http_thread = threading.Thread(target=start_server, daemon=True)
        http_thread.start()
    except Exception as e:
        logger.error(f"Failed to start http udf server: {e}")
        raise

def get_flight_info(self, context, descriptor):
"""Return the result schema of a function."""
func_name = descriptor.path[0].decode("utf-8")
Expand Down Expand Up @@ -377,6 +409,25 @@ def add_function(self, udf: UserDefinedFunction):
f"RETURNS {output_type} LANGUAGE python "
f"HANDLER = '{name}' ADDRESS = 'http://{self._location}';"
)

## http router register
@self.app.post("/" + name)
async def handle(request: Request):
# Deserialize the RecordBatch from the input data
body = await request.body()
reader = pa.ipc.open_stream(BytesIO(body))
batches = [b for b in reader]
# Apply the UDF to the data
result_batches = [udf.eval_batch(batch) for batch in batches]
# Serialize the result to send it back
buf = BytesIO()
writer = pa.ipc.new_stream(buf, udf._result_schema)
for batch in result_batches:
for b in batch:
writer.write_batch(b)
writer.close()
return Response(content=buf.getvalue(), media_type="text/plain")

logger.info(f"added function: {name}, SQL:\n{sql}\n")

def serve(self):
Expand All @@ -387,7 +438,9 @@ def serve(self):
logger.info(
f"Prometheus metrics available at http://{self._metric_location}/metrics"
)

if self._http_location:
self._start_httpudf_server()
logger.info(f"UDF HTTP SERVER available at http://{self._http_location}")
super(UDFServer, self).serve()


Expand Down
41 changes: 41 additions & 0 deletions python/example/http_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import requests
import pyarrow as pa
from pyarrow import ipc
from io import BytesIO


def main():
    """Send a sample RecordBatch to the local HTTP UDF server and print the reply.

    Builds a two-column int32 batch, serializes it with the Arrow IPC
    stream format, POSTs it to the ``gcd`` endpoint on localhost:8818,
    and deserializes the RecordBatch stream from the response body.
    """
    # Build the arrays with an explicit int32 type: pa.array() infers
    # int64 for plain Python ints, which would mismatch the int32
    # schema declared below and fail in RecordBatch.from_arrays.
    data = [
        pa.array([1, 2, 3, 4], type=pa.int32()),
        pa.array([5, 6, 7, 8], type=pa.int32()),
    ]
    schema = pa.schema(
        [
            ("a", pa.int32()),
            ("b", pa.int32()),
        ]
    )
    batch = pa.RecordBatch.from_arrays(data, schema=schema)

    # Serialize the RecordBatch using the Arrow IPC stream format.
    buf = BytesIO()
    writer = pa.ipc.new_stream(buf, batch.schema)
    writer.write_batch(batch)
    writer.close()
    serialized_batch = buf.getvalue()

    # Send the serialized RecordBatch to the server. A finite timeout
    # keeps the example from hanging forever when the server is down,
    # and raise_for_status surfaces HTTP errors instead of trying to
    # IPC-decode an error body.
    response = requests.post(
        "http://localhost:8818/gcd", data=serialized_batch, timeout=30
    )
    response.raise_for_status()

    # Deserialize the response stream back into RecordBatches.
    reader = pa.ipc.open_stream(BytesIO(response.content))
    result_batches = [b for b in reader]

    # Print the result
    for batch in result_batches:
        print("res \n", batch)


if __name__ == "__main__":
    main()
10 changes: 8 additions & 2 deletions python/example/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,12 @@
import time
from typing import List, Dict, Any, Tuple, Optional

import os
import sys

# Make the in-repo databend_udf package importable when running this
# example from the repository root without installing it first.
cwd = os.getcwd()
sys.path.append(cwd)

from databend_udf import udf, UDFServer

logging.basicConfig(level=logging.INFO)

Expand Down Expand Up @@ -314,7 +318,9 @@ def wait_concurrent(x):


if __name__ == "__main__":
udf_server = UDFServer("0.0.0.0:8815", metric_location="0.0.0.0:8816")
udf_server = UDFServer(
"0.0.0.0:8815", metric_location="0.0.0.0:8816", http_location="0.0.0.0:8818"
)
udf_server.add_function(add_signed)
udf_server.add_function(add_unsigned)
udf_server.add_function(add_float)
Expand Down
8 changes: 7 additions & 1 deletion python/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,14 @@ readme = "README.md"
requires-python = ">=3.7"
dependencies = [
"pyarrow",
"prometheus-client>=0.17.0"
"prometheus-client>=0.17.0",
"fastapi>=0.103.2",
"uvicorn>=0.22.0",
]

[dependency-groups]
dev = ["requests>=2.31.0"]

[project.optional-dependencies]
lint = ["ruff"]

Expand Down