39 changes: 37 additions & 2 deletions duckdb/experimental/spark/sql/dataframe.py
@@ -569,6 +569,22 @@ def columns(self) -> list[str]:
        """
        return [f.name for f in self.schema.fields]

    @property
    def dtypes(self) -> list[tuple[str, str]]:
        """Returns all column names and their data types as a list of tuples.

        Returns
        -------
        list of tuple
            List of tuples, each containing a column name and its data type as strings.

        Examples
        --------
        >>> df.dtypes
        [('age', 'bigint'), ('name', 'string')]
        """
        return [(f.name, f.dataType.simpleString()) for f in self.schema.fields]

    def _ipython_key_completions_(self) -> list[str]:
        # Provides tab-completion for column names in PySpark DataFrame
        # when accessed in bracket notation, e.g. df['<TAB>']
@@ -982,8 +998,27 @@ def groupBy(self, *cols: "ColumnOrName") -> "GroupedData":  # type: ignore[misc]
    def write(self) -> DataFrameWriter:  # noqa: D102
        return DataFrameWriter(self)

    def printSchema(self) -> None:  # noqa: D102
        raise ContributionsAcceptedError

    def printSchema(self, level: Optional[int] = None) -> None:
        """Prints out the schema in the tree format.

        Parameters
        ----------
        level : int, optional
            How many levels to print for nested schemas. Prints all levels by default.

        Examples
        --------
        >>> df.printSchema()
        root
         |-- age: bigint (nullable = true)
         |-- name: string (nullable = true)
        """
        if level is not None and level < 0:
            raise PySparkValueError(
                error_class="NEGATIVE_VALUE",
                message_parameters={"arg_name": "level", "arg_value": str(level)},
            )
        print(self.schema.treeString(level))

    def union(self, other: "DataFrame") -> "DataFrame":
        """Return a new :class:`DataFrame` containing union of rows in this and another
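For reviewers who want to try the two new `DataFrame` members end to end, here is a minimal sketch. It assumes DuckDB's experimental Spark session entry point `duckdb.experimental.spark.sql.SparkSession` (not part of this PR), and the exact type strings in `dtypes` depend on how DuckDB maps Python values.

```python
# Minimal usage sketch for the dtypes property and printSchema method added above.
# Assumes DuckDB's experimental Spark session entry point; not part of this PR.
from duckdb.experimental.spark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
df = spark.createDataFrame([("Alice", 25), ("Bob", 30)], ["name", "age"])

# dtypes returns (name, type-string) pairs, e.g. [('name', 'string'), ('age', 'bigint')]
print(df.dtypes)

# printSchema prints the tree rendering of df.schema; a non-negative level
# truncates nested structs, and a negative level raises PySparkValueError.
df.printSchema()
df.printSchema(level=1)
```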
66 changes: 66 additions & 0 deletions duckdb/experimental/spark/sql/types.py
@@ -894,6 +894,72 @@ def fieldNames(self) -> list[str]:
        """
        return list(self.names)

    def treeString(self, level: Optional[int] = None) -> str:
        """Returns a string representation of the schema in tree format.

        Parameters
        ----------
        level : int, optional
            Maximum depth to print. If None, prints all levels.

        Returns
        -------
        str
            Tree-formatted schema string.

        Examples
        --------
        >>> schema = StructType([StructField("age", IntegerType(), True)])
        >>> print(schema.treeString())
        root
         |-- age: integer (nullable = true)
        """

        def _tree_string(schema: "StructType", depth: int = 0, max_depth: Optional[int] = None) -> list[str]:
            """Recursively build tree string lines."""
            lines = []
            if depth == 0:
                lines.append("root")

            if max_depth is not None and depth >= max_depth:
                return lines

            for field in schema.fields:
                indent = " " * depth
                prefix = " |-- "
                nullable_str = "true" if field.nullable else "false"

                # Handle nested StructType
                if isinstance(field.dataType, StructType):
                    lines.append(f"{indent}{prefix}{field.name}: struct (nullable = {nullable_str})")
                    # Recursively handle nested struct - don't skip any lines, root only appears at depth 0
                    nested_lines = _tree_string(field.dataType, depth + 1, max_depth)
                    lines.extend(nested_lines)
                # Handle ArrayType
                elif isinstance(field.dataType, ArrayType):
                    element_type = field.dataType.elementType
                    if isinstance(element_type, StructType):
                        lines.append(f"{indent}{prefix}{field.name}: array (nullable = {nullable_str})")
                        lines.append(f"{indent} | |-- element: struct (containsNull = {field.dataType.containsNull})")
                        nested_lines = _tree_string(element_type, depth + 2, max_depth)
                        lines.extend(nested_lines)
                    else:
                        type_str = element_type.simpleString()
                        lines.append(f"{indent}{prefix}{field.name}: array<{type_str}> (nullable = {nullable_str})")
                # Handle MapType
                elif isinstance(field.dataType, MapType):
                    key_type = field.dataType.keyType.simpleString()
                    value_type = field.dataType.valueType.simpleString()
                    lines.append(f"{indent}{prefix}{field.name}: map<{key_type},{value_type}> (nullable = {nullable_str})")
                # Handle simple types
                else:
                    type_str = field.dataType.simpleString()
                    lines.append(f"{indent}{prefix}{field.name}: {type_str} (nullable = {nullable_str})")

            return lines

        lines = _tree_string(self, 0, level)
        return "\n".join(lines)

    def needConversion(self) -> bool:  # noqa: D102
        # We need to convert Row()/namedtuple into tuple()
        return True
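`treeString` can also be exercised directly on a `StructType`, without a session. The sketch below builds a nested schema from the same types module this PR touches; the commented output follows the implementation above (one extra space of indent per nesting level), so treat the exact spacing as approximate.

```python
# Sketch: calling StructType.treeString directly; import path matches this PR.
from duckdb.experimental.spark.sql.types import (
    ArrayType,
    IntegerType,
    StringType,
    StructField,
    StructType,
)

schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("person", StructType([
        StructField("name", StringType(), True),
    ]), True),
    StructField("tags", ArrayType(StringType()), True),
])

# Approximate output per the implementation above:
# root
#  |-- id: integer (nullable = true)
#  |-- person: struct (nullable = true)
#   |-- name: string (nullable = true)
#  |-- tags: array<string> (nullable = true)
print(schema.treeString())
```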
154 changes: 154 additions & 0 deletions tests/fast/spark/test_spark_dataframe.py
@@ -427,3 +427,157 @@ def test_cache(self, spark):
        assert df is not cached
        assert cached.collect() == df.collect()
        assert cached.collect() == [Row(one=1, two=2, three=3, four=4)]

    def test_dtypes(self, spark):
        data = [("Alice", 25, 5000.0), ("Bob", 30, 6000.0)]
        df = spark.createDataFrame(data, ["name", "age", "salary"])
        dtypes = df.dtypes

        assert isinstance(dtypes, list)
        assert len(dtypes) == 3
        for col_name, col_type in dtypes:
            assert isinstance(col_name, str)
            assert isinstance(col_type, str)

        col_names = [name for name, _ in dtypes]
        assert col_names == ["name", "age", "salary"]
        for _, col_type in dtypes:
            assert len(col_type) > 0

    def test_dtypes_complex_types(self, spark):
        from spark_namespace.sql.types import ArrayType, IntegerType, StringType, StructField, StructType

        schema = StructType([
            StructField("name", StringType(), True),
            StructField("scores", ArrayType(IntegerType()), True),
            StructField("address", StructType([
                StructField("city", StringType(), True),
                StructField("zip", StringType(), True)
            ]), True)
        ])
        data = [
            ("Alice", [90, 85, 88], {"city": "NYC", "zip": "10001"}),
            ("Bob", [75, 80, 82], {"city": "LA", "zip": "90001"})
        ]
        df = spark.createDataFrame(data, schema)
        dtypes = df.dtypes

        assert len(dtypes) == 3
        col_names = [name for name, _ in dtypes]
        assert col_names == ["name", "scores", "address"]

    def test_printSchema(self, spark, capsys):
        data = [("Alice", 25, 5000), ("Bob", 30, 6000)]
        df = spark.createDataFrame(data, ["name", "age", "salary"])
        df.printSchema()
        captured = capsys.readouterr()
        output = captured.out

        assert "root" in output
        assert "name" in output
        assert "age" in output
        assert "salary" in output
        assert "string" in output or "varchar" in output.lower()
        assert "int" in output.lower() or "bigint" in output.lower()

    def test_printSchema_nested(self, spark, capsys):
        from spark_namespace.sql.types import ArrayType, IntegerType, StringType, StructField, StructType

        schema = StructType([
            StructField("id", IntegerType(), True),
            StructField("person", StructType([
                StructField("name", StringType(), True),
                StructField("age", IntegerType(), True)
            ]), True),
            StructField("hobbies", ArrayType(StringType()), True)
        ])
        data = [
            (1, {"name": "Alice", "age": 25}, ["reading", "coding"]),
            (2, {"name": "Bob", "age": 30}, ["gaming", "music"])
        ]
        df = spark.createDataFrame(data, schema)
        df.printSchema()
        captured = capsys.readouterr()
        output = captured.out

        assert "root" in output
        assert "person" in output
        assert "hobbies" in output

    def test_printSchema_negative_level(self, spark):
        data = [("Alice", 25)]
        df = spark.createDataFrame(data, ["name", "age"])

        with pytest.raises(PySparkValueError):
            df.printSchema(level=-1)

    def test_treeString_basic(self, spark):
        data = [("Alice", 25, 5000)]
        df = spark.createDataFrame(data, ["name", "age", "salary"])
        tree = df.schema.treeString()

        assert tree.startswith("root\n")
        assert " |-- name:" in tree
        assert " |-- age:" in tree
        assert " |-- salary:" in tree
        assert "(nullable = true)" in tree
        assert tree.count(" |-- ") == 3

    def test_treeString_nested_struct(self, spark):
        from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType

        schema = StructType([
            StructField("id", IntegerType(), True),
            StructField("person", StructType([
                StructField("name", StringType(), True),
                StructField("age", IntegerType(), True)
            ]), True)
        ])
        data = [(1, {"name": "Alice", "age": 25})]
        df = spark.createDataFrame(data, schema)
        tree = df.schema.treeString()

        assert "root\n" in tree
        assert " |-- id:" in tree
        assert " |-- person: struct (nullable = true)" in tree
        assert "name:" in tree
        assert "age:" in tree

    def test_treeString_with_level(self, spark):
        from spark_namespace.sql.types import IntegerType, StringType, StructField, StructType

        schema = StructType([
            StructField("id", IntegerType(), True),
            StructField("person", StructType([
                StructField("name", StringType(), True),
                StructField("details", StructType([
                    StructField("address", StringType(), True)
                ]), True)
            ]), True)
        ])

        data = [(1, {"name": "Alice", "details": {"address": "123 Main St"}})]
        df = spark.createDataFrame(data, schema)

        # Level 1 should only show top-level fields
        tree_level_1 = df.schema.treeString(level=1)
        assert " |-- id:" in tree_level_1
        assert " |-- person: struct" in tree_level_1
        # Should not show nested field names at level 1
        lines = tree_level_1.split('\n')
        assert len([l for l in lines if l.strip()]) <= 3

    def test_treeString_array_type(self, spark):
        from spark_namespace.sql.types import ArrayType, StringType, StructField, StructType

        schema = StructType([
            StructField("name", StringType(), True),
            StructField("hobbies", ArrayType(StringType()), True)
        ])

        data = [("Alice", ["reading", "coding"])]
        df = spark.createDataFrame(data, schema)
        tree = df.schema.treeString()

        assert "root\n" in tree
        assert " |-- name:" in tree
        assert " |-- hobbies: array<" in tree
        assert "(nullable = true)" in tree
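One branch of `treeString` the tests above do not pin down exactly is array-of-struct, which emits an extra `element: struct (containsNull = ...)` line. A small sketch, with the commented spacing approximate:

```python
# Sketch of the array-of-struct branch of treeString (spacing approximate).
from duckdb.experimental.spark.sql.types import ArrayType, StringType, StructField, StructType

schema = StructType([
    StructField("addresses", ArrayType(StructType([
        StructField("city", StringType(), True),
    ])), True),
])

# Approximate output per the implementation in types.py:
# root
#  |-- addresses: array (nullable = true)
#  | |-- element: struct (containsNull = True)
#    |-- city: string (nullable = true)
print(schema.treeString())
```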