Skip to content

feat: add scalar array selection in geos-trame #103

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion geos-trame/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ dependencies = [
"trame-gantt==0.1.5",
"xsdata==24.5",
"xsdata-pydantic[lxml]==24.5",
"pyvista==0.44.1",
"pyvista==0.45.2",
"dpath==2.2.0",
"colorcet==3.1.0",
"funcy==2.0",
Expand Down
4 changes: 2 additions & 2 deletions geos-trame/src/geos/trame/app/components/alertHandler.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,8 @@ def __init__( self ) -> None:

self.state.alerts = []

self.server.controller.on_add_error.add_task( self.add_error )
self.server.controller.on_add_warning.add_task( self.add_warning )
self.ctrl.on_add_error.add_task( self.add_error )
self.ctrl.on_add_warning.add_task( self.add_warning )

self.generate_alert_ui()

Expand Down
5 changes: 4 additions & 1 deletion geos-trame/src/geos/trame/app/io/data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from geos.trame.app.geosTrameException import GeosTrameException
from geos.trame.app.ui.viewer.regionViewer import RegionViewer
from geos.trame.app.ui.viewer.wellViewer import WellViewer
from geos.trame.app.utils.pv_utils import read_unstructured_grid
from geos.trame.app.utils.pv_utils import read_unstructured_grid, split_vector_arrays
from geos.trame.schema_generated.schema_mod import (
Vtkmesh,
Vtkwell,
Expand Down Expand Up @@ -97,6 +97,9 @@ def _update_vtkmesh( self, mesh: Vtkmesh, show: bool ) -> None:

def _read_mesh( self, mesh: Vtkmesh ) -> None:
unstructured_grid = read_unstructured_grid( self.source.get_abs_path( mesh.file ) )
split_vector_arrays( unstructured_grid )

unstructured_grid.set_active_scalars( unstructured_grid.cell_data.keys()[ 0 ] )
self.region_viewer.add_mesh( unstructured_grid )

def _update_vtkwell( self, well: Vtkwell, path: str, show: bool ) -> None:
Expand Down
7 changes: 1 addition & 6 deletions geos-trame/src/geos/trame/app/ui/viewer/regionViewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,6 @@ def __init__( self ) -> None:
"""
self.input = pv.UnstructuredGrid()
self.clip = self.input
self.reset()

def __call__( self, normal: tuple[ float ], origin: tuple[ float ] ) -> None:
"""Update clip."""
self.update_clip( normal, origin )

def add_mesh( self, mesh: pv.UnstructuredGrid ) -> None:
"""Set the input to the given mesh."""
Expand All @@ -26,7 +21,7 @@ def add_mesh( self, mesh: pv.UnstructuredGrid ) -> None:

def update_clip( self, normal: tuple[ float ], origin: tuple[ float ] ) -> None:
"""Update the current clip with the given normal and origin."""
self.clip.copy_from( self.input.clip( normal=normal, origin=origin, crinkle=True ) ) # type: ignore
self.clip = self.input.clip( normal=normal, origin=origin, crinkle=True ) # type: ignore

def reset( self ) -> None:
"""Reset the input mesh and clip."""
Expand Down
117 changes: 78 additions & 39 deletions geos-trame/src/geos/trame/app/ui/viewer/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,7 @@
from geos.trame.app.ui.viewer.perforationViewer import PerforationViewer
from geos.trame.app.ui.viewer.regionViewer import RegionViewer
from geos.trame.app.ui.viewer.wellViewer import WellViewer
from geos.trame.schema_generated.schema_mod import (
Vtkmesh,
Vtkwell,
Perforation,
InternalWell,
)
from geos.trame.schema_generated.schema_mod import Vtkmesh, Vtkwell, InternalWell, Perforation

pv.OFF_SCREEN = True

Expand Down Expand Up @@ -49,13 +44,20 @@ def __init__(
"""
super().__init__( **kwargs )

self._point_data_array_names: list[ str ] = []
self._cell_data_array_names: list[ str ] = []
self._source = source
self._pl = pv.Plotter()
self._mesh_actor: vtkActor | None = None

self.CUT_PLANE = "on_cut_plane_visibility_change"
self.ZAMPLIFICATION = "_z_amplification"
self.server.state[ self.CUT_PLANE ] = True
self.server.state[ self.ZAMPLIFICATION ] = 1
self.state[ self.CUT_PLANE ] = True
self.state[ self.ZAMPLIFICATION ] = 1

self.DATA_ARRAYS = "viewer_data_arrays_items"
self.SELECTED_DATA_ARRAY = "viewer_selected_data_array"
self.state.change( self.SELECTED_DATA_ARRAY )( self._update_actor_array )

self.region_engine = region_viewer
self.well_engine = well_viewer
Expand All @@ -68,8 +70,9 @@ def __init__(
view = plotter_ui(
self._pl,
add_menu_items=self.rendering_menu_extra_items,
style="position: absolute;",
)
view.menu.style += "; height: 50px; min-width: 50px;"
view.menu.children[ 0 ].style += "; justify-content: center;"
self.ctrl.view_update = view.update

@property
Expand All @@ -88,21 +91,33 @@ def rendering_menu_extra_items( self ) -> None:
For now, adding a button to show/hide all widgets.
"""
self.state.change( self.CUT_PLANE )( self._on_clip_visibility_change )
vuetify.VDivider( vertical=True, classes="mr-1" )
with vuetify.VTooltip( location="bottom" ):
with (
vuetify.Template( v_slot_activator=( "{ props }", ) ),
html.Div( v_bind=( "props", ) ),
):
vuetify.VCheckbox(
v_model=( self.CUT_PLANE, True ),
icon=True,
true_icon="mdi-eye",
false_icon="mdi-eye-off",
dense=True,
hide_details=True,
)
html.Span( "Show/Hide widgets" )
with vuetify.VRow(
classes='pa-0 ma-0 align-center fill-height',
style="flex-wrap: nowrap",
):
vuetify.VDivider( vertical=True, classes="mr-1" )
with vuetify.VTooltip( location="bottom" ):
with (
vuetify.Template( v_slot_activator=( "{ props }", ) ),
html.Div( v_bind=( "props", ) ),
):
vuetify.VCheckbox(
v_model=( self.CUT_PLANE, True ),
icon=True,
true_icon="mdi-eye",
false_icon="mdi-eye-off",
dense=True,
hide_details=True,
)
html.Span( "Show/Hide widgets" )
vuetify.VDivider( vertical=True, classes="mr-1" )
vuetify.VSelect(
hide_details=True,
label="Data Array",
items=( self.DATA_ARRAYS, [] ),
v_model=( self.SELECTED_DATA_ARRAY, None ),
min_width="150px",
)

def update_viewer( self, active_block: BaseModel, path: str, show_obj: bool ) -> None:
"""Add from path the dataset given by the user.
Expand Down Expand Up @@ -205,7 +220,7 @@ def _update_internalwell( self, path: str, show: bool ) -> None:
tube_actor = self.plotter.add_mesh( self.well_engine.get_tube( self.well_engine.get_last_mesh_idx() ) )
self.well_engine.append_actor( path, tube_actor )

self.server.controller.view_update()
self.ctrl.view_update()

def _update_vtkwell( self, path: str, show: bool ) -> None:
"""Used to control the visibility of the Vtkwell.
Expand All @@ -219,7 +234,30 @@ def _update_vtkwell( self, path: str, show: bool ) -> None:
tube_actor = self.plotter.add_mesh( self.well_engine.get_tube( self.well_engine.get_last_mesh_idx() ) )
self.well_engine.append_actor( path, tube_actor )

self.server.controller.view_update()
self.ctrl.view_update()

def _clip_mesh( self, normal: tuple[ float ], origin: tuple[ float ] ) -> None:
"""Plane widget callback to clip the input data."""
if self._mesh_actor is None:
return
self.region_engine.update_clip( normal=normal, origin=origin )
self._mesh_actor.mapper.SetInputData( self.region_engine.clip )
self._update_actor_array()

def _update_actor_array( self, **_: Any ) -> None:
"""Update the actor scalar array."""
array_name = self.state[ self.SELECTED_DATA_ARRAY ]
if array_name is None or self._mesh_actor is None:
return
mapper: pv.DataSetMapper = self._mesh_actor.mapper

mapper.array_name = array_name
mapper.scalar_range = self.region_engine.clip.get_data_range( array_name )
self.region_engine.clip.active_scalars_name = array_name
mapper.scalar_map_mode = "point" if array_name in self._point_data_array_names else "cell"

self.plotter.scalar_bar.title = array_name
self.ctrl.view_update()

def _update_vtkmesh( self, show: bool ) -> None:
"""Used to control the visibility of the Vtkmesh.
Expand All @@ -230,21 +268,22 @@ def _update_vtkmesh( self, show: bool ) -> None:
"""
if not show:
self.plotter.clear_plane_widgets()
self.plotter.remove_actor( self._clip_mesh ) # type: ignore
self.plotter.remove_actor( self._mesh_actor ) # type: ignore
self._mesh_actor = None
return

active_scalar = self.region_engine.input.active_scalars_name
self._clip_mesh: vtkActor = self.plotter.add_mesh_clip_plane(
self.region_engine.input,
origin=self.region_engine.input.center,
normal=[ -1, 0, 0 ],
crinkle=True,
show_edges=False,
cmap="glasbey_bw",
scalars=active_scalar,
)

self.server.controller.view_update()
self._point_data_array_names = list( self.region_engine.input.point_data.keys() )
self._cell_data_array_names = list( self.region_engine.input.cell_data.keys() )
self.state[ self.DATA_ARRAYS ] = self._point_data_array_names + self._cell_data_array_names
self.state[ self.SELECTED_DATA_ARRAY ] = self.region_engine.input.active_scalars_name

self._mesh_actor = self.plotter.add_mesh( self.region_engine.input )
self.plotter.add_plane_widget( callback=self._clip_mesh,
normal=[ 1, 0, 0 ],
origin=self.region_engine.input.center,
assign_to_axis=None,
tubing=False,
outline_translation=False )

def _update_perforation( self, perforation: Perforation, show: bool, path: str ) -> None:
"""Generate VTK dataset from a perforation."""
Expand Down
19 changes: 19 additions & 0 deletions geos-trame/src/geos/trame/app/utils/pv_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,27 @@
# SPDX-FileCopyrightText: Copyright 2023-2024 TotalEnergies.
# SPDX-FileContributor: Kitware
import pyvista as pv
from vtkmodules.util.numpy_support import vtk_to_numpy, numpy_to_vtk
from vtkmodules.vtkCommonCore import vtkDataArray


def read_unstructured_grid( filename: str ) -> pv.UnstructuredGrid:
"""Read an unstructured grid from a .vtu file."""
return pv.read( filename ).cast_to_unstructured_grid()


def split_vector_arrays( ug: pv.UnstructuredGrid ) -> None:
"""Create N 1-component arrays from each vector array with N components."""
for data in [ ug.GetPointData(), ug.GetCellData() ]:
for i in range( data.GetNumberOfArrays() ):
array: vtkDataArray = data.GetArray( i )
if array.GetNumberOfComponents() != 1:
np_array = vtk_to_numpy( array )
array_name = array.GetName()
data.RemoveArray( array_name )
for comp in range( array.GetNumberOfComponents() ):
component = np_array[ :, comp ]
new_array_name = f"{array_name}_{comp}"
new_array = numpy_to_vtk( component, deep=True )
new_array.SetName( new_array_name )
data.AddArray( new_array )
Loading