Commit 3441f60

[minor][spark] add support for types and print schema (#175)
Implements essential DataFrame introspection features for PySpark compatibility. These APIs are used in virtually every PySpark script for schema exploration and debugging.

1. Added the DataFrame.dtypes property
2. Implemented the DataFrame.printSchema() method
3. Added the StructType.treeString() method
2 parents 4ebe06f + 3f3a794 commit 3441f60

File tree

5 files changed: +291 -7 lines

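As a quick orientation before the diffs, here is a minimal sketch of the new surface area. The session entry point is an assumption based on duckdb's experimental Spark API, and the exact type strings may vary by DuckDB version.

from duckdb.experimental.spark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("Alice", 25), ("Bob", 30)], ["name", "age"])

print(df.dtypes)  # e.g. [('name', 'string'), ('age', 'bigint')]
df.printSchema()  # prints the tree built by df.schema.treeString()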

duckdb/experimental/spark/sql/dataframe.py

Lines changed: 37 additions & 3 deletions
@@ -15,7 +15,6 @@
 from duckdb import ColumnExpression, Expression, StarExpression
 
 from ..errors import PySparkIndexError, PySparkTypeError, PySparkValueError
-from ..exception import ContributionsAcceptedError
 from .column import Column
 from .readwriter import DataFrameWriter
 from .type_utils import duckdb_to_spark_schema
@@ -569,6 +568,22 @@ def columns(self) -> list[str]:
         """
         return [f.name for f in self.schema.fields]
 
+    @property
+    def dtypes(self) -> list[tuple[str, str]]:
+        """Returns all column names and their data types as a list of tuples.
+
+        Returns:
+        -------
+        list of tuple
+            List of tuples, each tuple containing a column name and its data type as strings.
+
+        Examples:
+        --------
+        >>> df.dtypes
+        [('age', 'bigint'), ('name', 'string')]
+        """
+        return [(f.name, f.dataType.simpleString()) for f in self.schema.fields]
+
     def _ipython_key_completions_(self) -> list[str]:
         # Provides tab-completion for column names in PySpark DataFrame
         # when accessed in bracket notation, e.g. df['<TAB>]
@@ -982,8 +997,27 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData":  # type: ignore[misc]
     def write(self) -> DataFrameWriter:  # noqa: D102
         return DataFrameWriter(self)
 
-    def printSchema(self) -> None:  # noqa: D102
-        raise ContributionsAcceptedError
+    def printSchema(self, level: Optional[int] = None) -> None:
+        """Prints out the schema in the tree format.
+
+        Parameters
+        ----------
+        level : int, optional
+            How many levels to print for nested schemas. Prints all levels by default.
+
+        Examples:
+        --------
+        >>> df.printSchema()
+        root
+         |-- age: bigint (nullable = true)
+         |-- name: string (nullable = true)
+        """
+        if level is not None and level < 0:
+            raise PySparkValueError(
+                error_class="NEGATIVE_VALUE",
+                message_parameters={"arg_name": "level", "arg_value": str(level)},
+            )
+        print(self.schema.treeString(level))
 
     def union(self, other: "DataFrame") -> "DataFrame":
         """Return a new :class:`DataFrame` containing union of rows in this and another

duckdb/experimental/spark/sql/types.py

Lines changed: 71 additions & 0 deletions
@@ -894,6 +894,77 @@ def fieldNames(self) -> list[str]:
         """
         return list(self.names)
 
+    def treeString(self, level: Optional[int] = None) -> str:
+        """Returns a string representation of the schema in tree format.
+
+        Parameters
+        ----------
+        level : int, optional
+            Maximum depth to print. If None, prints all levels.
+
+        Returns:
+        -------
+        str
+            Tree-formatted schema string
+
+        Examples:
+        --------
+        >>> schema = StructType([StructField("age", IntegerType(), True)])
+        >>> print(schema.treeString())
+        root
+         |-- age: integer (nullable = true)
+        """
+
+        def _tree_string(schema: "StructType", depth: int = 0, max_depth: Optional[int] = None) -> list[str]:
+            """Recursively build tree string lines."""
+            lines = []
+            if depth == 0:
+                lines.append("root")
+
+            if max_depth is not None and depth >= max_depth:
+                return lines
+
+            for field in schema.fields:
+                indent = " " * depth
+                prefix = " |-- "
+                nullable_str = "true" if field.nullable else "false"
+
+                # Handle nested StructType
+                if isinstance(field.dataType, StructType):
+                    lines.append(f"{indent}{prefix}{field.name}: struct (nullable = {nullable_str})")
+                    # Recursively handle nested struct - don't skip any lines, root only appears at depth 0
+                    nested_lines = _tree_string(field.dataType, depth + 1, max_depth)
+                    lines.extend(nested_lines)
+                # Handle ArrayType
+                elif isinstance(field.dataType, ArrayType):
+                    element_type = field.dataType.elementType
+                    if isinstance(element_type, StructType):
+                        lines.append(f"{indent}{prefix}{field.name}: array (nullable = {nullable_str})")
+                        lines.append(
+                            f"{indent} | |-- element: struct (containsNull = {field.dataType.containsNull})"
+                        )
+                        nested_lines = _tree_string(element_type, depth + 2, max_depth)
+                        lines.extend(nested_lines)
+                    else:
+                        type_str = element_type.simpleString()
+                        lines.append(f"{indent}{prefix}{field.name}: array<{type_str}> (nullable = {nullable_str})")
+                # Handle MapType
+                elif isinstance(field.dataType, MapType):
+                    key_type = field.dataType.keyType.simpleString()
+                    value_type = field.dataType.valueType.simpleString()
+                    lines.append(
+                        f"{indent}{prefix}{field.name}: map<{key_type},{value_type}> (nullable = {nullable_str})"
+                    )
+                # Handle simple types
+                else:
+                    type_str = field.dataType.simpleString()
+                    lines.append(f"{indent}{prefix}{field.name}: {type_str} (nullable = {nullable_str})")
+
+            return lines
+
+        lines = _tree_string(self, 0, level)
+        return "\n".join(lines)
+
     def needConversion(self) -> bool:  # noqa: D102
         # We need convert Row()/namedtuple into tuple()
         return True
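
treeString can also be called directly on a schema, without a DataFrame. A sketch of the recursion's output for a nested struct; the expected lines are inferred from the implementation above, not captured from a run:

from duckdb.experimental.spark.sql.types import StringType, StructField, StructType

schema = StructType(
    [
        StructField("name", StringType(), False),
        StructField("person", StructType([StructField("city", StringType(), True)]), True),
    ]
)

print(schema.treeString())
# root
#  |-- name: string (nullable = false)
#  |-- person: struct (nullable = true)
#   |-- city: string (nullable = true)

print(schema.treeString(level=1))  # same as above minus the nested city line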

tests/fast/spark/test_spark_dataframe.py

Lines changed: 170 additions & 0 deletions
@@ -427,3 +427,173 @@ def test_cache(self, spark):
         assert df is not cached
         assert cached.collect() == df.collect()
         assert cached.collect() == [Row(one=1, two=2, three=3, four=4)]
+
+    def test_dtypes(self, spark):
+        data = [("Alice", 25, 5000.0), ("Bob", 30, 6000.0)]
+        df = spark.createDataFrame(data, ["name", "age", "salary"])
+        dtypes = df.dtypes
+
+        assert isinstance(dtypes, list)
+        assert len(dtypes) == 3
+        for col_name, col_type in dtypes:
+            assert isinstance(col_name, str)
+            assert isinstance(col_type, str)
+
+        col_names = [name for name, _ in dtypes]
+        assert col_names == ["name", "age", "salary"]
+        for _, col_type in dtypes:
+            assert len(col_type) > 0
+
+    def test_dtypes_complex_types(self, spark):
+        from spark_namespace.sql.types import ArrayType, IntegerType, StringType, StructField, StructType
+
+        schema = StructType(
+            [
+                StructField("name", StringType(), True),
+                StructField("scores", ArrayType(IntegerType()), True),
+                StructField(
+                    "address",
+                    StructType([StructField("city", StringType(), True), StructField("zip", StringType(), True)]),
+                    True,
+                ),
+            ]
+        )
+        data = [
+            ("Alice", [90, 85, 88], {"city": "NYC", "zip": "10001"}),
+            ("Bob", [75, 80, 82], {"city": "LA", "zip": "90001"}),
+        ]
+        df = spark.createDataFrame(data, schema)
+        dtypes = df.dtypes
+
+        assert len(dtypes) == 3
+        col_names = [name for name, _ in dtypes]
+        assert col_names == ["name", "scores", "address"]
+
+    def test_printSchema(self, spark, capsys):
+        data = [("Alice", 25, 5000), ("Bob", 30, 6000)]
+        df = spark.createDataFrame(data, ["name", "age", "salary"])
+        df.printSchema()
+        captured = capsys.readouterr()
+        output = captured.out
+
+        assert "root" in output
+        assert "name" in output
+        assert "age" in output
+        assert "salary" in output
+        assert "string" in output or "varchar" in output.lower()
+        assert "int" in output.lower() or "bigint" in output.lower()
+
+    def test_printSchema_nested(self, spark, capsys):
+        from spark_namespace.sql.types import ArrayType, IntegerType, StringType, StructField, StructType
+
+        schema = StructType(
+            [
+                StructField("id", IntegerType(), True),
+                StructField(
+                    "person",
+                    StructType([StructField("name", StringType(), True), StructField("age", IntegerType(), True)]),
+                    True,
+                ),
+                StructField("hobbies", ArrayType(StringType()), True),
+            ]
+        )
+        data = [
+            (1, {"name": "Alice", "age": 25}, ["reading", "coding"]),
+            (2, {"name": "Bob", "age": 30}, ["gaming", "music"]),
+        ]
+        df = spark.createDataFrame(data, schema)
+        df.printSchema()
+        captured = capsys.readouterr()
+        output = captured.out
+
+        assert "root" in output
+        assert "person" in output
+        assert "hobbies" in output
+
+    def test_printSchema_negative_level(self, spark):
+        data = [("Alice", 25)]
+        df = spark.createDataFrame(data, ["name", "age"])
+
+        with pytest.raises(PySparkValueError):
+            df.printSchema(level=-1)
+
+    def test_treeString_basic(self, spark):
+        data = [("Alice", 25, 5000)]
+        df = spark.createDataFrame(data, ["name", "age", "salary"])
+        tree = df.schema.treeString()
+
+        assert tree.startswith("root\n")
+        assert " |-- name:" in tree
+        assert " |-- age:" in tree
+        assert " |-- salary:" in tree
+        assert "(nullable = true)" in tree
+        assert tree.count(" |-- ") == 3
+
+    def test_treeString_nested_struct(self, spark):
+        from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType
+
+        schema = StructType(
+            [
+                StructField("id", IntegerType(), True),
+                StructField(
+                    "person",
+                    StructType([StructField("name", StringType(), True), StructField("age", IntegerType(), True)]),
+                    True,
+                ),
+            ]
+        )
+        data = [(1, {"name": "Alice", "age": 25})]
+        df = spark.createDataFrame(data, schema)
+        tree = df.schema.treeString()
+
+        assert "root\n" in tree
+        assert " |-- id:" in tree
+        assert " |-- person: struct (nullable = true)" in tree
+        assert "name:" in tree
+        assert "age:" in tree
+
+    def test_treeString_with_level(self, spark):
+        from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType
+
+        schema = StructType(
+            [
+                StructField("id", IntegerType(), True),
+                StructField(
+                    "person",
+                    StructType(
+                        [
+                            StructField("name", StringType(), True),
+                            StructField("details", StructType([StructField("address", StringType(), True)]), True),
+                        ]
+                    ),
+                    True,
+                ),
+            ]
+        )
+
+        data = [(1, {"name": "Alice", "details": {"address": "123 Main St"}})]
+        df = spark.createDataFrame(data, schema)
+
+        # Level 1 should only show top-level fields
+        tree_level_1 = df.schema.treeString(level=1)
+        assert " |-- id:" in tree_level_1
+        assert " |-- person: struct" in tree_level_1
+        # Should not show nested field names at level 1
+        lines = tree_level_1.split("\n")
+        assert len([line for line in lines if line.strip()]) <= 3
+
+    def test_treeString_array_type(self, spark):
+        from spark_namespace.sql.types import ArrayType, StringType, StructField, StructType
+
+        schema = StructType(
+            [StructField("name", StringType(), True), StructField("hobbies", ArrayType(StringType()), True)]
+        )
+
+        data = [("Alice", ["reading", "coding"])]
+        df = spark.createDataFrame(data, schema)
+        tree = df.schema.treeString()
+
+        assert "root\n" in tree
+        assert " |-- name:" in tree
+        assert " |-- hobbies: array<" in tree
+        assert "(nullable = true)" in tree

tests/fast/test_insert.py

Lines changed: 5 additions & 3 deletions
@@ -27,6 +27,8 @@ def test_insert_with_schema(self, duckdb_cursor):
         res = duckdb_cursor.table("not_main.tbl").fetchall()
         assert len(res) == 10
 
-        # TODO: This is not currently supported  # noqa: TD002, TD003
-        with pytest.raises(duckdb.CatalogException, match="Table with name tbl does not exist"):
-            duckdb_cursor.table("not_main.tbl").insert([42, 21, 1337])
+        # Insert into a schema-qualified table should work; table has a single column from range(10)
+        duckdb_cursor.table("not_main.tbl").insert([42])
+        res2 = duckdb_cursor.table("not_main.tbl").fetchall()
+        assert len(res2) == 11
+        assert (42,) in res2
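
The test's expectation flips because relation.insert() now resolves schema-qualified table names instead of raising CatalogException. A minimal standalone sketch of the behavior it pins down, assuming an in-memory connection:

import duckdb

con = duckdb.connect()
con.execute("CREATE SCHEMA not_main")
con.execute("CREATE TABLE not_main.tbl AS SELECT * FROM range(10)")

con.table("not_main.tbl").insert([42])  # appends one single-column row
assert (42,) in con.table("not_main.tbl").fetchall()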

tests/fast/test_relation.py

Lines changed: 8 additions & 1 deletion
@@ -280,7 +280,14 @@ def test_value_relation(self, duckdb_cursor):
         rel = duckdb_cursor.values((const(1), const(2), const(3)), const(4))
 
         # Using Expressions that can't be resolved:
-        with pytest.raises(duckdb.BinderException, match='Referenced column "a" not found in FROM clause!'):
+        # Accept both historical and current Binder error message variants
+        with pytest.raises(
+            duckdb.BinderException,
+            match=(
+                r'Referenced column "a" not found in FROM clause!|'
+                r'Referenced column "a" was not found because the FROM clause is missing'
+            ),
+        ):
             duckdb_cursor.values(duckdb.ColumnExpression("a"))
 
     def test_insert_into_operator(self):
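
pytest.raises treats match as a regular-expression search, so the alternation accepts either phrasing of the binder error across DuckDB versions. The same check in isolation:

import re

pattern = (
    r'Referenced column "a" not found in FROM clause!|'
    r'Referenced column "a" was not found because the FROM clause is missing'
)
assert re.search(pattern, 'Binder Error: Referenced column "a" not found in FROM clause!')
assert re.search(pattern, 'Binder Error: Referenced column "a" was not found because the FROM clause is missing')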
