Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add llm generated explanations to TAGDataset #9918

Open
wants to merge 37 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
37 commits
Select commit Hold shift + click to select a range
f330b06
update
xnuohz Dec 29, 2024
6b1c556
update
xnuohz Dec 29, 2024
5caf0c3
Merge branch 'fix/glem-example' into tape
xnuohz Dec 29, 2024
f1c7931
add llm explanation token
xnuohz Jan 2, 2025
5c76117
add lm training
xnuohz Jan 2, 2025
a252312
update
xnuohz Jan 5, 2025
6da5717
Merge branch 'master' into tape
xnuohz Jan 5, 2025
e302bd4
update
xnuohz Jan 5, 2025
82c9450
update
xnuohz Jan 5, 2025
29f3881
update
xnuohz Jan 5, 2025
b139ac2
add changelog
xnuohz Jan 5, 2025
0f643d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2025
07fe2ec
update
xnuohz Jan 5, 2025
695ed01
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2025
9e65389
Merge branch 'master' into tagdataset/add-llm-exp-pred
xnuohz Jan 7, 2025
b064ab5
Merge branch 'tagdataset/add-llm-exp-pred' of github.com:xnuohz/pytor…
xnuohz Jan 7, 2025
74f2bb9
update
xnuohz Jan 7, 2025
502659a
update
xnuohz Jan 7, 2025
d09912b
Merge branch 'master' into tagdataset/add-llm-exp-pred
puririshi98 Jan 7, 2025
1b3e679
Merge branch 'master' into tagdataset/add-llm-exp-pred
puririshi98 Jan 7, 2025
ac402f3
update
xnuohz Jan 10, 2025
ac7bfed
Merge branch 'master' into tagdataset/add-llm-exp-pred
xnuohz Jan 10, 2025
fda21c2
update
xnuohz Jan 12, 2025
cb657c2
Merge branch 'tagdataset/add-llm-exp-pred' of github.com:xnuohz/pytor…
xnuohz Jan 12, 2025
3aa5e56
update
xnuohz Jan 13, 2025
85c8741
Merge branch 'master' into tagdataset/add-llm-exp-pred
xnuohz Jan 16, 2025
fbc573c
update
xnuohz Jan 16, 2025
3e4f4d9
Merge branch 'master' into tagdataset/add-llm-exp-pred
xnuohz Jan 17, 2025
46e1b36
Merge branch 'master' into tagdataset/add-llm-exp-pred
puririshi98 Jan 20, 2025
5f853ba
Merge branch 'master' into tagdataset/add-llm-exp-pred
xnuohz Jan 20, 2025
1f47744
Merge branch 'master' into tagdataset/add-llm-exp-pred
xnuohz Jan 21, 2025
1ae39ce
Merge branch 'master' into tagdataset/add-llm-exp-pred
xnuohz Jan 22, 2025
cbfe731
Merge branch 'master' into tagdataset/add-llm-exp-pred
puririshi98 Jan 23, 2025
02d7c9a
Merge branch 'master' into tagdataset/add-llm-exp-pred
puririshi98 Jan 24, 2025
075b7b2
empty
akihironitta Jan 26, 2025
86ed422
Merge branch 'master' into tagdataset/add-llm-exp-pred
xnuohz Jan 27, 2025
af6f6f7
Merge branch 'master' into tagdataset/add-llm-exp-pred
puririshi98 Jan 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added llm generated explanations to `TAGDataset` ([#9918](https://github.com/pyg-team/pytorch_geometric/pull/9918))
- Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794))
- Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
- Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666))
Expand Down
5 changes: 4 additions & 1 deletion examples/llm/glem.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def get_n_params(model):
def main(args):
gpu = args.gpu
dataset_name = args.dataset
text_type = args.text_type
root = osp.join('data', 'ogb')
hf_model = args.hf_model
pl_ratio = args.pl_ratio
Expand Down Expand Up @@ -83,7 +84,7 @@ def main(args):

tag_dataset = TAGDataset(root, dataset, hf_model,
token_on_disk=token_on_disk)
text_dataset = tag_dataset.to_text_dataset()
text_dataset = tag_dataset.to_text_dataset(text_type)
print(tag_dataset.num_classes, tag_dataset.raw_file_names)

num_classes = tag_dataset.num_classes
Expand Down Expand Up @@ -395,6 +396,8 @@ def load_model(em_phase):
help='number of iterations')
parser.add_argument("--dataset", type=str, default='products',
help='arxiv or products')
parser.add_argument("--text_type", type=str, default='raw_text',
help='raw_text or llm_explanation')
parser.add_argument("--pl_ratio", type=float, default=0.5,
help="pseudo labels ratio")
parser.add_argument('--hf_model', type=str, default='prajjwal1/bert-tiny',
Expand Down
22 changes: 22 additions & 0 deletions test/datasets/test_tag_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from torch_geometric.datasets import TAGDataset
from torch_geometric.testing import onlyFullTest, withPackage


@onlyFullTest
@withPackage('ogb')
def test_tag_dataset() -> None:
from ogb.nodeproppred import PygNodePropPredDataset

root = './data/ogb'
hf_model = 'prajjwal1/bert-tiny'
token_on_disk = True

dataset = PygNodePropPredDataset('ogbn-arxiv', root=root)
tag_dataset = TAGDataset(root, dataset, hf_model,
token_on_disk=token_on_disk)

assert 169343 == tag_dataset[0].num_nodes \
== len(tag_dataset.text) \
== len(tag_dataset.llm_explanation) \
== len(tag_dataset.llm_prediction)
assert 1166243 == tag_dataset[0].num_edges
181 changes: 150 additions & 31 deletions torch_geometric/datasets/tag_dataset.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import csv
import os
import os.path as osp
from collections.abc import Sequence
Expand All @@ -10,6 +11,7 @@

from torch_geometric.data import InMemoryDataset, download_google_url
from torch_geometric.data.data import BaseData
from torch_geometric.io import fs

try:
from pandas import DataFrame, read_csv
Expand All @@ -22,14 +24,16 @@

class TAGDataset(InMemoryDataset):
r"""The Text Attributed Graph datasets from the
`"Learning on Large-scale Text-attributed Graphs via Variational Inference
" <https://arxiv.org/abs/2210.14709>`_ paper.
`"Learning on Large-scale Text-attributed Graphs via Variational Inference"
<https://arxiv.org/abs/2210.14709>`_ paper and `"Harnessing Explanations:
LLM-to-LM Interpreter for Enhanced Text-Attributed Graph Representation
Learning" <https://arxiv.org/abs/2305.19523>`_ paper.
This dataset is aiming on transform `ogbn products`, `ogbn arxiv`
into Text Attributed Graph that each node in graph is associate with a
raw text, that dataset can be adapt to DataLoader (for LM training) and
NeighborLoader(for GNN training). In addition, this class can be use as a
wrapper class by convert a InMemoryDataset with Tokenizer and text into
Text Attributed Graph.
raw text, LLM prediction and explanation, that dataset can be adapt to
DataLoader (for LM training) and NeighborLoader(for GNN training).
In addition, this class can be use as a wrapper class by convert a
InMemoryDataset with Tokenizer and text into Text Attributed Graph.

Args:
root (str): Root directory where the dataset should be saved.
Expand All @@ -40,6 +44,12 @@ class TAGDataset(InMemoryDataset):
on huggingface.co.
text (List[str]): list of raw text associate with node, the order of
list should be align with node list
llm_explanation (Optional[List[str]]): list of llm explanation
associate with node, which should be align with node list
llm_prediction (Optional[List[str]]): list of llm prediction associate
with node, the order of list should be align with node list
llm_prediction_topk (int): Top K prediction from LLM used as
features for GNN training, default: 5
split_idx (Optional[Dict[str, torch.Tensor]]): Optional dictionary,
for saving split index, it is required that if your dataset doesn't
have get_split_idx function
Expand All @@ -51,22 +61,40 @@ class TAGDataset(InMemoryDataset):
or not, default: False
force_reload (bool): default: False
.. note::
See `example/llm_plus_gnn/glem.py` for example usage
See `example/llm/glem.py` for example usage
"""
raw_text_id = {
'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3',
'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt'
}

def __init__(self, root: str, dataset: InMemoryDataset,
tokenizer_name: str, text: Optional[List[str]] = None,
split_idx: Optional[Dict[str, Tensor]] = None,
tokenize_batch_size: int = 256, token_on_disk: bool = False,
text_on_disk: bool = False,
force_reload: bool = False) -> None:
llm_prediction_url = 'https://github.com/XiaoxinHe/TAPE/raw/main/gpt_preds'

llm_explanation_id = {
'ogbn-arxiv': '1o8n2xRen-N_elF9NQpIca0iCHJgEJbRQ',
}

def __init__(
self,
root: str,
dataset: InMemoryDataset,
tokenizer_name: str,
text: Optional[List[str]] = None,
llm_explanation: Optional[List[str]] = None,
llm_prediction: Optional[Tensor] = None,
llm_prediction_topk: int = 5,
split_idx: Optional[Dict[str, Tensor]] = None,
tokenize_batch_size: int = 256,
token_on_disk: bool = False,
text_on_disk: bool = False,
force_reload: bool = False,
) -> None:
# list the vars you want to pass in before run download & process
self.name = dataset.name
self.text = text
self.llm_explanation = llm_explanation
self.llm_prediction = llm_prediction
self.llm_prediction_topk = llm_prediction_topk
self.tokenizer_name = tokenizer_name
from transformers import AutoTokenizer
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
Expand All @@ -93,8 +121,11 @@ def __init__(self, root: str, dataset: InMemoryDataset,
"is_gold mask, please pass splited index "
"in format of dictionaty with 'train', 'valid' "
"'test' index tensor to 'split_idx'")
if text is not None and text_on_disk:
self.save_node_text(text)
if text_on_disk:
if text is not None:
self.save_node_text(text)
if llm_explanation is not None:
self.save_node_explanation(llm_explanation)
self.text_on_disk = text_on_disk
# init will call download and process
super().__init__(self.root, transform=None, pre_transform=None,
Expand All @@ -116,9 +147,19 @@ def __init__(self, root: str, dataset: InMemoryDataset,
if self.text is not None and len(self.text) != self._data.num_nodes:
raise ValueError("The number of text sequence in 'text' should be "
"equal to number of nodes!")
if self.llm_explanation is not None and len(
self.llm_explanation) != self._data.num_nodes:
raise ValueError("The number of LLM explanation should be "
"equal to number of nodes!")
if self.llm_prediction is not None and len(
self.llm_prediction) != self._data.num_nodes:
raise ValueError("The number of LLM prediction should be "
"equal to number of nodes!")
self.token_on_disk = token_on_disk
self.tokenize_batch_size = tokenize_batch_size
self._token = self.tokenize_graph(self.tokenize_batch_size)
self._llm_explanation_token = self.tokenize_graph(
self.tokenize_batch_size, text_type='llm_explanation')
self.__num_classes__ = dataset.num_classes

@property
Expand All @@ -128,7 +169,7 @@ def num_classes(self) -> int:
@property
def raw_file_names(self) -> List[str]:
file_names = []
for root, _, files in os.walk(osp.join(self.root, 'raw')):
for _, _, files in os.walk(osp.join(self.root, 'raw')):
for file in files:
file_names.append(file)
return file_names
Expand All @@ -146,6 +187,13 @@ def token(self) -> Dict[str, Tensor]:
self._token = self.tokenize_graph()
return self._token

@property
def llm_explanation_token(self) -> Dict[str, Tensor]:
if self._llm_explanation_token is None: # lazy load
self._llm_explanation_token = self.tokenize_graph(
text_type='llm_explanation')
return self._llm_explanation_token

# load is_gold after init
@property
def is_gold(self) -> Tensor:
Expand Down Expand Up @@ -194,10 +242,17 @@ def download(self) -> None:
folder=f'{self.root}/raw',
filename='node-text.csv.gz',
log=True)
text_df = read_csv(raw_text_path)
self.text = list(text_df['text'])
self.text = list(read_csv(raw_text_path)['text'])
print('downloading llm explanations')
llm_explanation_path = download_google_url(
id=self.llm_explanation_id[self.name], folder=f'{self.root}/raw',
filename='node-gpt-response.csv.gz', log=True)
self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
print('downloading llm predictions')
fs.cp(f'{self.llm_prediction_url}/{self.name}.csv', self.raw_dir)

def process(self) -> None:
# process Title and Abstraction
if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')):
text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz'))
self.text = list(text_df['text'])
Expand All @@ -212,6 +267,43 @@ def process(self) -> None:
"The raw text of each node is not specified"
"Please pass in 'text' when convert your dataset "
"to Text Attribute Graph Dataset")
# process LLM explanation and prediction
llm_explanation_path = f'{self.raw_dir}/node-gpt-response.csv.gz'
llm_prediction_path = f'{self.raw_dir}/{self.name}.csv'
if osp.exists(llm_explanation_path) and osp.exists(
llm_prediction_path):
# load LLM explanation
self.llm_explanation = list(read_csv(llm_explanation_path)['text'])
# load LLM prediction
preds = []
with open(llm_prediction_path) as file:
reader = csv.reader(file)
for row in reader:
inner_list = []
for value in row:
inner_list.append(int(value))
preds.append(inner_list)

pl = torch.zeros(len(preds), self.llm_prediction_topk,
dtype=torch.long)
for i, pred in enumerate(preds):
pl[i][:len(pred)] = torch.tensor(
pred[:self.llm_prediction_topk], dtype=torch.long) + 1
self.llm_prediction = pl
elif self.name in self.llm_explanation_id:
self.download()
else:
print(
'The dataset is not ogbn-arxiv,'
'please pass in your llm explanation list to `llm_explanation`'
'and llm prediction list to `llm_prediction`')
if self.llm_explanation is None or self.llm_prediction is None:
raise ValueError(
"The TAGDataset only have ogbn-arxiv LLM explanations"
"and predictions in default. The llm explanation and"
"prediction of each node is not specified."
"Please pass in 'llm_explanation' and 'llm_prediction' when"
"convert your dataset to Text Attribute Graph Dataset")

def save_node_text(self, text: List[str]) -> None:
node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz')
Expand All @@ -224,22 +316,42 @@ def save_node_text(self, text: List[str]) -> None:
text_df.to_csv(osp.join(node_text_path), compression='gzip',
index=False)

def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]:
def save_node_explanation(self, text: List[str]) -> None:
node_text_path = osp.join(self.root, 'raw', 'node-gpt-response.csv.gz')
if osp.exists(node_text_path):
print(f'The llm explanation is existed at {node_text_path}')
else:
print(f'Saving llm explanation file at {node_text_path}')
os.makedirs(f'{self.root}/raw', exist_ok=True)
text_df = DataFrame(text, columns=['text'])
text_df.to_csv(osp.join(node_text_path), compression='gzip',
index=False)

def tokenize_graph(self, batch_size: int = 256,
text_type: str = 'raw_text') -> Dict[str, Tensor]:
r"""Tokenizing the text associate with each node, running in cpu.

Args:
batch_size (Optional[int]): batch size of list of text for
generating emebdding
text_type (Optional[str]): type of text
Returns:
Dict[str, torch.Tensor]: tokenized graph
"""
assert text_type in ['raw_text', 'llm_explanation']
if text_type == 'raw_text':
_text = self.text
elif text_type == 'llm_explanation':
_text = self.llm_explanation

data_len = 0
if self.text is not None:
data_len = len(self.text)
if _text is not None:
data_len = len(_text)
else:
raise ValueError("The TAGDataset need text for tokenization")
token_keys = ['input_ids', 'token_type_ids', 'attention_mask']
path = os.path.join(self.processed_dir, 'token', self.tokenizer_name)
path = os.path.join(self.processed_dir, 'token', text_type,
self.tokenizer_name)
# Check if the .pt files already exist
token_files_exist = any(
os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys)
Expand All @@ -256,12 +368,12 @@ def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]:
all_encoded_token = {k: [] for k in token_keys}
pbar = tqdm(total=data_len)

pbar.set_description('Tokenizing Text Attributed Graph')
pbar.set_description(f'Tokenizing Text Attributed Graph {text_type}')
for i in range(0, data_len, batch_size):
end_index = min(data_len, i + batch_size)
token = self.tokenizer(self.text[i:min(i + batch_size, data_len)],
padding='max_length', truncation=True,
max_length=512, return_tensors="pt")
token = self.tokenizer(_text[i:end_index], padding='max_length',
truncation=True, max_length=512,
return_tensors="pt")
for k in token.keys():
all_encoded_token[k].append(token[k])
pbar.update(end_index - i)
Expand Down Expand Up @@ -289,10 +401,16 @@ class TextDataset(torch.utils.data.Dataset):

Args:
tag_dataset (TAGDataset): the parent dataset
text_type (str): type of text
"""
def __init__(self, tag_dataset: 'TAGDataset') -> None:
def __init__(self, tag_dataset: 'TAGDataset',
text_type: str = 'raw_text') -> None:
assert text_type in ['raw_text', 'llm_explanation']
self.tag_dataset = tag_dataset
self.token = tag_dataset.token
if text_type == 'raw_text':
self.token = tag_dataset.token
elif text_type == 'llm_explanation':
self.token = tag_dataset.llm_explanation_token
assert tag_dataset._data is not None
self._data = tag_dataset._data

Expand All @@ -312,7 +430,8 @@ def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]:

# for LM training
def __getitem__(
self, node_id: IndexType
self,
node_id: IndexType,
) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]:
r"""This function will override the function in
torch.utils.data.Dataset, and will be called when you
Expand Down Expand Up @@ -343,8 +462,8 @@ def get(self, idx: int) -> BaseData:
def __repr__(self) -> str:
return f'{self.__class__.__name__}()'

def to_text_dataset(self) -> TextDataset:
def to_text_dataset(self, text_type: str = 'raw_text') -> TextDataset:
r"""Factory Build text dataset from Text Attributed Graph Dataset
each data point is node's associated text token.
"""
return TAGDataset.TextDataset(self)
return TAGDataset.TextDataset(self, text_type)
Loading