Add llm generated explanations to TAGDataset #9918

Open
wants to merge 37 commits into base: master
f330b06
update
xnuohz Dec 29, 2024
6b1c556
update
xnuohz Dec 29, 2024
5caf0c3
Merge branch 'fix/glem-example' into tape
xnuohz Dec 29, 2024
f1c7931
add llm explanation token
xnuohz Jan 2, 2025
5c76117
add lm training
xnuohz Jan 2, 2025
a252312
update
xnuohz Jan 5, 2025
6da5717
Merge branch 'master' into tape
xnuohz Jan 5, 2025
e302bd4
update
xnuohz Jan 5, 2025
82c9450
update
xnuohz Jan 5, 2025
29f3881
update
xnuohz Jan 5, 2025
b139ac2
add changelog
xnuohz Jan 5, 2025
0f643d5
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2025
07fe2ec
update
xnuohz Jan 5, 2025
695ed01
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 5, 2025
9e65389
Merge branch 'master' into tagdataset/add-llm-exp-pred
xnuohz Jan 7, 2025
b064ab5
Merge branch 'tagdataset/add-llm-exp-pred' of github.com:xnuohz/pytor…
xnuohz Jan 7, 2025
74f2bb9
update
xnuohz Jan 7, 2025
502659a
update
xnuohz Jan 7, 2025
d09912b
Merge branch 'master' into tagdataset/add-llm-exp-pred
puririshi98 Jan 7, 2025
1b3e679
Merge branch 'master' into tagdataset/add-llm-exp-pred
puririshi98 Jan 7, 2025
ac402f3
update
xnuohz Jan 10, 2025
ac7bfed
Merge branch 'master' into tagdataset/add-llm-exp-pred
xnuohz Jan 10, 2025
fda21c2
update
xnuohz Jan 12, 2025
cb657c2
Merge branch 'tagdataset/add-llm-exp-pred' of github.com:xnuohz/pytor…
xnuohz Jan 12, 2025
3aa5e56
update
xnuohz Jan 13, 2025
85c8741
Merge branch 'master' into tagdataset/add-llm-exp-pred
xnuohz Jan 16, 2025
fbc573c
update
xnuohz Jan 16, 2025
3e4f4d9
Merge branch 'master' into tagdataset/add-llm-exp-pred
xnuohz Jan 17, 2025
46e1b36
Merge branch 'master' into tagdataset/add-llm-exp-pred
puririshi98 Jan 20, 2025
5f853ba
Merge branch 'master' into tagdataset/add-llm-exp-pred
xnuohz Jan 20, 2025
1f47744
Merge branch 'master' into tagdataset/add-llm-exp-pred
xnuohz Jan 21, 2025
1ae39ce
Merge branch 'master' into tagdataset/add-llm-exp-pred
xnuohz Jan 22, 2025
cbfe731
Merge branch 'master' into tagdataset/add-llm-exp-pred
puririshi98 Jan 23, 2025
02d7c9a
Merge branch 'master' into tagdataset/add-llm-exp-pred
puririshi98 Jan 24, 2025
075b7b2
empty
akihironitta Jan 26, 2025
86ed422
Merge branch 'master' into tagdataset/add-llm-exp-pred
xnuohz Jan 27, 2025
af6f6f7
Merge branch 'master' into tagdataset/add-llm-exp-pred
puririshi98 Jan 30, 2025
2 changes: 2 additions & 0 deletions CHANGELOG.md
@@ -7,6 +7,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Added

- Added llm generated explanations to `TAGDataset` ([#9918](https://github.com/pyg-team/pytorch_geometric/pull/9918))
- Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794))
- Consolidate Cugraph examples into ogbn_train_cugraph.py and ogbn_train_cugraph_multigpu.py for ogbn-arxiv, ogbn-products and ogbn-papers100M ([#9953](https://github.com/pyg-team/pytorch_geometric/pull/9953))
- Added `InstructMol` dataset ([#9975](https://github.com/pyg-team/pytorch_geometric/pull/9975))
- Added support for weighted `LinkPredRecall` metric ([#9947](https://github.com/pyg-team/pytorch_geometric/pull/9947))
9 changes: 7 additions & 2 deletions examples/llm/glem.py
@@ -40,6 +40,7 @@ def get_n_params(model):
 def main(args):
     gpu = args.gpu
     dataset_name = args.dataset
+    text_type = args.text_type if args.dataset == 'arxiv' else 'raw_text'
     root = osp.join('data', 'ogb')
     hf_model = args.hf_model
     pl_ratio = args.pl_ratio
@@ -83,7 +84,7 @@ def main(args):

     tag_dataset = TAGDataset(root, dataset, hf_model,
                              token_on_disk=token_on_disk)
-    text_dataset = tag_dataset.to_text_dataset()
+    text_dataset = tag_dataset.to_text_dataset(text_type)
     print(tag_dataset.num_classes, tag_dataset.raw_file_names)

     num_classes = tag_dataset.num_classes
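
The two hunks above make the text source configurable only for arxiv: any other dataset falls back to `raw_text`, and the resolved value is passed to `to_text_dataset`. A minimal sketch of that selection rule, pulled out into a standalone function so it can be checked in isolation. `resolve_text_type` is a hypothetical helper name used for illustration; the PR inlines the expression in `main`.

```python
def resolve_text_type(dataset_name: str, requested: str) -> str:
    """Mirror the fallback in main(): only arxiv honours the requested type."""
    # Hypothetical helper for illustration; not part of the PR.
    return requested if dataset_name == 'arxiv' else 'raw_text'


print(resolve_text_type('arxiv', 'llm_explanation'))     # llm_explanation
print(resolve_text_type('products', 'llm_explanation'))  # raw_text
```

One consequence of this rule is that `--text_type` is silently ignored for products rather than rejected.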
@@ -393,8 +394,12 @@ def load_model(em_phase):
                         help='number of runs')
     parser.add_argument('--num_em_iters', type=int, default=1,
                         help='number of iterations')
-    parser.add_argument("--dataset", type=str, default='products',
+    parser.add_argument("--dataset", type=str, default='arxiv',
                         help='arxiv or products')
+    parser.add_argument(
+        "--text_type", type=str, default='llm_explanation',
+        help="type of text, support raw_text, llm_explanation,"
+        "all for arxiv and raw_text for products")
     parser.add_argument("--pl_ratio", type=float, default=0.5,
                         help="pseudo labels ratio")
     parser.add_argument('--hf_model', type=str, default='prajjwal1/bert-tiny',
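
A note on the new `--text_type` option: its help text is built from two adjacent string literals, which Python joins with no separator, so it renders as "llm_explanation,all for arxiv". The sketch below reproduces that and shows one possible alternative using argparse's `choices`. The three accepted values are an assumption read off the help string, and the reworded help text is illustrative, not what the PR contains.

```python
import argparse

# Adjacent literals concatenate with no space between them, exactly as in
# the hunk above.
help_as_written = ("type of text, support raw_text, llm_explanation,"
                   "all for arxiv and raw_text for products")
print(help_as_written)

# Assumption: these three values are the full accepted set.
TEXT_TYPES = ['raw_text', 'llm_explanation', 'all']

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='arxiv',
                    choices=['arxiv', 'products'])
# With choices, argparse rejects unknown values and lists the valid ones
# in the usage message, so the help text can stay short.
parser.add_argument('--text_type', type=str, default='llm_explanation',
                    choices=TEXT_TYPES,
                    help='text to train on; only used when --dataset is arxiv')

args = parser.parse_args(['--dataset', 'arxiv', '--text_type', 'all'])
print(args.dataset, args.text_type)
```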
22 changes: 22 additions & 0 deletions test/datasets/test_tag_dataset.py
@@ -0,0 +1,22 @@
+from torch_geometric.datasets import TAGDataset
+from torch_geometric.testing import onlyFullTest, withPackage
+
+
+@onlyFullTest
+@withPackage('ogb')
+def test_tag_dataset() -> None:
+    from ogb.nodeproppred import PygNodePropPredDataset
+
+    root = './data/ogb'
+    hf_model = 'prajjwal1/bert-tiny'
+    token_on_disk = True
+
+    dataset = PygNodePropPredDataset('ogbn-arxiv', root=root)
+    tag_dataset = TAGDataset(root, dataset, hf_model,
+                             token_on_disk=token_on_disk)
+
+    assert 169343 == tag_dataset[0].num_nodes \
+        == len(tag_dataset.text) \
+        == len(tag_dataset.llm_explanation) \
+        == len(tag_dataset.llm_prediction)
+    assert 1166243 == tag_dataset[0].num_edges
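
The first assertion in the test relies on Python's chained comparison: `a == b == c == d` holds only if every adjacent pair is equal, so it checks in one statement that the raw text, explanations and predictions each have one entry per node. A small sketch of the same pattern on a stand-in object. `fake` and its contents are hypothetical, the real `TAGDataset` is not constructed here, and the node count is shrunk to 4.

```python
from types import SimpleNamespace

# Hypothetical stand-in exposing the three per-node attributes the test reads.
num_nodes = 4
fake = SimpleNamespace(
    text=['t'] * num_nodes,
    llm_explanation=['e'] * num_nodes,
    llm_prediction=[[0]] * num_nodes,
)

# Evaluated as (a == b) and (b == c) and (c == d).
assert num_nodes == len(fake.text) \
    == len(fake.llm_explanation) \
    == len(fake.llm_prediction)

# Dropping a single entry from one attribute breaks the chain.
fake.llm_prediction = fake.llm_prediction[:-1]
lengths_match = (num_nodes == len(fake.text)
                 == len(fake.llm_explanation)
                 == len(fake.llm_prediction))
print(lengths_match)  # False
```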