Skip to content

Commit 11b5fab

Browse files
authored
Fix default session for processing vineyard operation. (alibaba#230)
1 parent c7c27c7 commit 11b5fab

File tree

3 files changed

+27
-12
lines changed

3 files changed

+27
-12
lines changed

python/graphscope/framework/graph.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -734,7 +734,11 @@ def add_vertices(self, vertices, label="_", properties=None, vid_field=0):
734734
raise ValueError("Cannot manually add vertices after inferred vertices.")
735735
unsealed_vertices = deepcopy(self._unsealed_vertices)
736736
unsealed_vertices[label] = VertexLabel(
737-
label=label, loader=vertices, properties=properties, vid_field=vid_field
737+
label=label,
738+
loader=vertices,
739+
properties=properties,
740+
vid_field=vid_field,
741+
session_id=self._session.session_id,
738742
)
739743
v_labels = deepcopy(self._v_labels)
740744
v_labels.append(label)
@@ -845,7 +849,7 @@ def add_edges(
845849
else:
846850
e_labels.append(label)
847851
relations.append([(src_label, dst_label)])
848-
cur_label = EdgeLabel(label)
852+
cur_label = EdgeLabel(label, self._session.session_id)
849853
cur_label.add_sub_label(
850854
EdgeSubLabel(edges, properties, src_label, dst_label, src_field, dst_field)
851855
)

python/graphscope/framework/graph_utils.py

+10-6
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ def __init__(
5858
loader: Any,
5959
properties: Sequence = None,
6060
vid_field: Union[str, int] = 0,
61+
session_id=None,
6162
):
6263
self.label = label
6364
if isinstance(loader, Loader):
@@ -68,7 +69,7 @@ def __init__(
6869
self.raw_properties = properties
6970
self.properties = []
7071
self.vid_field = vid_field
71-
72+
self._session_id = session_id
7273
self._finished = False
7374

7475
def finish(self, id_type: str = "int64_t"):
@@ -84,7 +85,7 @@ def finish(self, id_type: str = "int64_t"):
8485
self.loader.select_columns(
8586
self.properties, include_all=bool(self.raw_properties is None)
8687
)
87-
self.loader.finish()
88+
self.loader.finish(self._session_id)
8889
self._finished = True
8990

9091
def __str__(self) -> str:
@@ -153,7 +154,7 @@ def __init__(
153154
"Source vid and destination vid must have same formats, both use name or both use index"
154155
)
155156

156-
def finish(self, id_type: str = "int64_t"):
157+
def finish(self, id_type: str = "int64_t", session_id=None):
157158
if self._finished:
158159
return
159160
self.add_property(str(self.src_field), id_type)
@@ -165,7 +166,7 @@ def finish(self, id_type: str = "int64_t"):
165166
self.loader.select_columns(
166167
self.properties, include_all=bool(self.raw_properties is None)
167168
)
168-
self.loader.finish()
169+
self.loader.finish(session_id)
169170
self._finished = True
170171

171172
def __str__(self) -> str:
@@ -221,9 +222,12 @@ class EdgeLabel(object):
221222
src_label3 -> edge_label -> dst_label3
222223
"""
223224

224-
def __init__(self, label: str):
225+
def __init__(self, label: str, session_id=None):
225226
self.label = label
226227
self.sub_labels = {}
228+
229+
self._session_id = session_id
230+
227231
self._finished = False
228232

229233
def __str__(self):
@@ -251,7 +255,7 @@ def finish(self, id_type: str = "int64_t"):
251255
if self._finished:
252256
return
253257
for sub_label in self.sub_labels.values():
254-
sub_label.finish(id_type)
258+
sub_label.finish(id_type, self._session_id)
255259
self._finished = True
256260

257261

python/graphscope/framework/loader.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -278,17 +278,24 @@ def func(source, storage_options, read_options, sess):
278278
self.source = source
279279
self.preprocessor = func
280280

281-
def finish(self):
282-
from graphscope.client.session import get_default_session
283-
281+
def finish(self, session_id=None):
284282
if self.finished:
285283
return
286284
if self.preprocessor is not None:
285+
if session_id is None:
286+
from graphscope.client.session import get_default_session
287+
288+
sess = get_default_session()
289+
else:
290+
from graphscope.client.session import get_session_by_id
291+
292+
sess = get_session_by_id(session_id)
293+
287294
self.protocol, self.source = self.preprocessor(
288295
self.source,
289296
self.storage_options,
290297
self.options.to_dict(),
291-
get_default_session(),
298+
sess,
292299
)
293300
logger.debug(
294301
f"processed protocol = {self.protocol}, source = {self.source}"

0 commit comments

Comments
 (0)