From f330b06044e74d410cdc8a832e92bc6ce190e143 Mon Sep 17 00:00:00 2001 From: xnuohz Date: Sun, 29 Dec 2024 19:28:59 +0800 Subject: [PATCH 01/19] update --- README.md | 12 +-- examples/llm/tape/generate.py | 6 ++ test/datasets/test_tag_dataset.py | 22 ++++ torch_geometric/datasets/tag_dataset.py | 135 ++++++++++++++++++++---- 4 files changed, 149 insertions(+), 26 deletions(-) create mode 100644 examples/llm/tape/generate.py create mode 100644 test/datasets/test_tag_dataset.py diff --git a/README.md b/README.md index f96d720837d1..64a809fdcf83 100644 --- a/README.md +++ b/README.md @@ -383,9 +383,9 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` | | `cpu` | `cu118` | `cu121` | `cu124` | | ----------- | ----- | ------- | ------- | ------- | -| **Linux** | ✅ | ✅ | ✅ | ✅ | -| **Windows** | ✅ | ✅ | ✅ | ✅ | -| **macOS** | ✅ | | | | +| **Linux** | ✅ | ✅ | ✅ | ✅ | +| **Windows** | ✅ | ✅ | ✅ | ✅ | +| **macOS** | ✅ | | | | #### PyTorch 2.4 @@ -399,9 +399,9 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` | | `cpu` | `cu118` | `cu121` | `cu124` | | ----------- | ----- | ------- | ------- | ------- | -| **Linux** | ✅ | ✅ | ✅ | ✅ | -| **Windows** | ✅ | ✅ | ✅ | ✅ | -| **macOS** | ✅ | | | | +| **Linux** | ✅ | ✅ | ✅ | ✅ | +| **Windows** | ✅ | ✅ | ✅ | ✅ | +| **macOS** | ✅ | | | | **Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, PyTorch 2.2.0/2.2.1/2.2.2, and PyTorch 2.3.0/2.3.1 (following the same procedure). **For older versions, you might need to explicitly specify the latest supported version number** or install via `pip install --no-index` in order to prevent a manual installation from source. 
diff --git a/examples/llm/tape/generate.py b/examples/llm/tape/generate.py new file mode 100644 index 000000000000..a53c62e8fea7 --- /dev/null +++ b/examples/llm/tape/generate.py @@ -0,0 +1,6 @@ +def main(): + pass + + +if __name__ == '__main__': + main() diff --git a/test/datasets/test_tag_dataset.py b/test/datasets/test_tag_dataset.py new file mode 100644 index 000000000000..a816dfa77ee3 --- /dev/null +++ b/test/datasets/test_tag_dataset.py @@ -0,0 +1,22 @@ +from ogb.nodeproppred import PygNodePropPredDataset + +from torch_geometric.datasets import TAGDataset +from torch_geometric.testing import withPackage + + +# @onlyFullTest +@withPackage('ogb') +def test_tag_dataset() -> None: + root = './data' + hf_model = 'prajjwal1/bert-tiny' + token_on_disk = True + + dataset = PygNodePropPredDataset('ogbn-arxiv', root=root) + tag_dataset = TAGDataset(root, dataset, hf_model, + token_on_disk=token_on_disk) + + assert 169343 == tag_dataset[0].num_nodes \ + == len(tag_dataset.text) \ + == len(tag_dataset.llm_explanation) \ + == len(tag_dataset.llm_prediction) + assert 1166243 == tag_dataset[0].num_edges diff --git a/torch_geometric/datasets/tag_dataset.py b/torch_geometric/datasets/tag_dataset.py index f25992ced989..09444a904d48 100644 --- a/torch_geometric/datasets/tag_dataset.py +++ b/torch_geometric/datasets/tag_dataset.py @@ -1,3 +1,4 @@ +import csv import os import os.path as osp from collections.abc import Sequence @@ -10,6 +11,7 @@ from torch_geometric.data import InMemoryDataset, download_google_url from torch_geometric.data.data import BaseData +from torch_geometric.io import fs try: from pandas import DataFrame, read_csv @@ -22,14 +24,16 @@ class TAGDataset(InMemoryDataset): r"""The Text Attributed Graph datasets from the - `"Learning on Large-scale Text-attributed Graphs via Variational Inference - " `_ paper. + `"Learning on Large-scale Text-attributed Graphs via Variational Inference" + `_ paper and `"Harnessing Explanations: + LLM-to-LM Interpreter for Enhanced Text-Attributed Graph Representation + Learning" `_ paper. This dataset is aiming on transform `ogbn products`, `ogbn arxiv` into Text Attributed Graph that each node in graph is associate with a - raw text, that dataset can be adapt to DataLoader (for LM training) and - NeighborLoader(for GNN training). In addition, this class can be use as a - wrapper class by convert a InMemoryDataset with Tokenizer and text into - Text Attributed Graph. + raw text, LLM prediction and explanation, that dataset can be adapt to + DataLoader (for LM training) and NeighborLoader(for GNN training). + In addition, this class can be use as a wrapper class by convert a + InMemoryDataset with Tokenizer and text into Text Attributed Graph. Args: root (str): Root directory where the dataset should be saved. @@ -40,6 +44,12 @@ class TAGDataset(InMemoryDataset): on huggingface.co. 
text (List[str]): list of raw text associate with node, the order of list should be align with node list + llm_explanation (Optional[List[str]]): list of llm explanation + associate with node, which should be align with node list + llm_prediction (Optional[List[str]]): list of llm prediction associate + with node, the order of list should be align with node list + llm_prediction_topk (Optional[int]): Top K prediction from LLM used as + features for GNN training, default: 5 split_idx (Optional[Dict[str, torch.Tensor]]): Optional dictionary, for saving split index, it is required that if your dataset doesn't have get_split_idx function @@ -51,22 +61,40 @@ class TAGDataset(InMemoryDataset): or not, default: False force_reload (bool): default: False .. note:: - See `example/llm_plus_gnn/glem.py` for example usage + See `example/llm/glem.py` for example usage """ raw_text_id = { 'ogbn-arxiv': '1g3OOVhRyiyKv13LY6gbp8GLITocOUr_3', 'ogbn-products': '1I-S176-W4Bm1iPDjQv3hYwQBtxE0v8mt' } - def __init__(self, root: str, dataset: InMemoryDataset, - tokenizer_name: str, text: Optional[List[str]] = None, - split_idx: Optional[Dict[str, Tensor]] = None, - tokenize_batch_size: int = 256, token_on_disk: bool = False, - text_on_disk: bool = False, - force_reload: bool = False) -> None: + llm_prediction_url = 'https://github.com/XiaoxinHe/TAPE/raw/main/gpt_preds' + + llm_explanation_id = { + 'ogbn-arxiv': '1o8n2xRen-N_elF9NQpIca0iCHJgEJbRQ', + } + + def __init__( + self, + root: str, + dataset: InMemoryDataset, + tokenizer_name: str, + text: Optional[List[str]] = None, + llm_explanation: Optional[List[str]] = None, + llm_prediction: Optional[List[str]] = None, + llm_prediction_topk: Optional[int] = 5, + split_idx: Optional[Dict[str, Tensor]] = None, + tokenize_batch_size: int = 256, + token_on_disk: bool = False, + text_on_disk: bool = False, + force_reload: bool = False, + ) -> None: # list the vars you want to pass in before run download & process self.name = dataset.name self.text = text + self.llm_explanation = llm_explanation + self.llm_prediction = llm_prediction + self.llm_prediction_topk = llm_prediction_topk self.tokenizer_name = tokenizer_name from transformers import AutoTokenizer self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) @@ -93,8 +121,11 @@ def __init__(self, root: str, dataset: InMemoryDataset, "is_gold mask, please pass splited index " "in format of dictionaty with 'train', 'valid' " "'test' index tensor to 'split_idx'") - if text is not None and text_on_disk: - self.save_node_text(text) + if text_on_disk: + if text is not None: + self.save_node_text(text) + if llm_explanation is not None: + self.save_node_explanation(llm_explanation) self.text_on_disk = text_on_disk # init will call download and process super().__init__(self.root, transform=None, pre_transform=None, @@ -116,6 +147,14 @@ def __init__(self, root: str, dataset: InMemoryDataset, if self.text is not None and len(self.text) != self._data.num_nodes: raise ValueError("The number of text sequence in 'text' should be " "equal to number of nodes!") + if self.llm_explanation is not None and len( + self.llm_explanation) != self._data.num_nodes: + raise ValueError("The number of LLM explanation should be " + "equal to number of nodes!") + if self.llm_prediction is not None and len( + self.llm_prediction) != self._data.num_nodes: + raise ValueError("The number of LLM prediction should be " + "equal to number of nodes!") self.token_on_disk = token_on_disk self.tokenize_batch_size = tokenize_batch_size self._token = 
self.tokenize_graph(self.tokenize_batch_size) @@ -128,7 +167,7 @@ def num_classes(self) -> int: @property def raw_file_names(self) -> List[str]: file_names = [] - for root, _, files in os.walk(osp.join(self.root, 'raw')): + for _, _, files in os.walk(osp.join(self.root, 'raw')): for file in files: file_names.append(file) return file_names @@ -194,10 +233,17 @@ def download(self) -> None: folder=f'{self.root}/raw', filename='node-text.csv.gz', log=True) - text_df = read_csv(raw_text_path) - self.text = list(text_df['text']) + self.text = list(read_csv(raw_text_path)['text']) + print('downloading llm explanations') + llm_explanation_path = download_google_url( + id=self.llm_explanation_id[self.name], folder=f'{self.root}/raw', + filename='node-gpt-response.csv.gz', log=True) + self.llm_explanation = list(read_csv(llm_explanation_path)['text']) + print('downloading llm predictions') + fs.cp(f'{self.llm_prediction_url}/{self.name}.csv', self.raw_dir) def process(self) -> None: + # process Title and Abstraction if osp.exists(osp.join(self.root, 'raw', 'node-text.csv.gz')): text_df = read_csv(osp.join(self.root, 'raw', 'node-text.csv.gz')) self.text = list(text_df['text']) @@ -212,6 +258,43 @@ def process(self) -> None: "The raw text of each node is not specified" "Please pass in 'text' when convert your dataset " "to Text Attribute Graph Dataset") + # process LLM explanation and prediction + llm_explanation_path = f'{self.raw_dir}/node-gpt-response.csv.gz' + llm_prediction_path = f'{self.raw_dir}/{self.name}.csv' + if osp.exists(llm_explanation_path) and osp.exists( + llm_prediction_path): + # load LLM explanation + self.llm_explanation = list(read_csv(llm_explanation_path)['text']) + # load LLM prediction + preds = [] + with open(llm_prediction_path) as file: + reader = csv.reader(file) + for row in reader: + inner_list = [] + for value in row: + inner_list.append(int(value)) + preds.append(inner_list) + + pl = torch.zeros(len(preds), self.llm_prediction_topk, + dtype=torch.long) + for i, pred in enumerate(preds): + pl[i][:len(pred)] = torch.tensor( + pred[:self.llm_prediction_topk], dtype=torch.long) + 1 + self.llm_prediction = pl + elif self.name in self.llm_explanation_id: + self.download() + else: + print( + 'The dataset is not ogbn-arxiv,' + 'please pass in your llm explanation list to `llm_explanation`' + 'and llm prediction list to `llm_prediction`') + if self.llm_explanation is None or self.llm_prediction is None: + raise ValueError( + "The TAGDataset only have ogbn-arxiv LLM explanations" + "and predictions in default. The llm explanation and" + "prediction of each node is not specified." 
+ "Please pass in 'llm_explanation' and 'llm_prediction' when" + "convert your dataset to Text Attribute Graph Dataset") def save_node_text(self, text: List[str]) -> None: node_text_path = osp.join(self.root, 'raw', 'node-text.csv.gz') @@ -224,6 +307,17 @@ def save_node_text(self, text: List[str]) -> None: text_df.to_csv(osp.join(node_text_path), compression='gzip', index=False) + def save_node_explanation(self, text: List[str]) -> None: + node_text_path = osp.join(self.root, 'raw', 'node-gpt-response.csv.gz') + if osp.exists(node_text_path): + print(f'The llm explanation is existed at {node_text_path}') + else: + print(f'Saving llm explanation file at {node_text_path}') + os.makedirs(f'{self.root}/raw', exist_ok=True) + text_df = DataFrame(text, columns=['text']) + text_df.to_csv(osp.join(node_text_path), compression='gzip', + index=False) + def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]: r"""Tokenizing the text associate with each node, running in cpu. @@ -259,7 +353,7 @@ def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]: pbar.set_description('Tokenizing Text Attributed Graph') for i in range(0, data_len, batch_size): end_index = min(data_len, i + batch_size) - token = self.tokenizer(self.text[i:min(i + batch_size, data_len)], + token = self.tokenizer(self.text[i:end_index], padding='max_length', truncation=True, max_length=512, return_tensors="pt") for k in token.keys(): @@ -312,7 +406,8 @@ def get_token(self, node_idx: IndexType) -> Dict[str, Tensor]: # for LM training def __getitem__( - self, node_id: IndexType + self, + node_id: IndexType, ) -> Dict[str, Union[Tensor, Dict[str, Tensor]]]: r"""This function will override the function in torch.utils.data.Dataset, and will be called when you From 6b1c55600fd93fb9d2e9fa036c44e7266491d217 Mon Sep 17 00:00:00 2001 From: xnuohz Date: Sun, 29 Dec 2024 20:17:58 +0800 Subject: [PATCH 02/19] update --- examples/llm/glem.py | 2 ++ torch_geometric/nn/models/glem.py | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/llm/glem.py b/examples/llm/glem.py index ec76cef4c010..c6bae703fd33 100644 --- a/examples/llm/glem.py +++ b/examples/llm/glem.py @@ -371,9 +371,11 @@ def load_model(em_phase): if gnn_val_acc > lm_val_acc: em_phase = 'gnn' model.gnn = model.gnn.to(device, non_blocking=True) + test_loader = subgraph_loader else: em_phase = 'lm' model.lm = model.lm.to(device, non_blocking=True) + test_loader = text_test_loader test_preds = model.inference(em_phase, test_loader, verbose=verbose) train_acc, val_acc, test_acc = evaluate(test_preds, ['train', 'valid', 'test']) diff --git a/torch_geometric/nn/models/glem.py b/torch_geometric/nn/models/glem.py index afc8b09d77c7..d30d5f8bd062 100644 --- a/torch_geometric/nn/models/glem.py +++ b/torch_geometric/nn/models/glem.py @@ -144,7 +144,8 @@ def train(self, em_phase: str, train_loader: Union[DataLoader, acc (float): training accuracy loss (float): loss value """ - pseudo_labels = pseudo_labels.to(self.device) + if pseudo_labels is not None: + pseudo_labels = pseudo_labels.to(self.device) if em_phase == 'gnn': acc, loss = self.train_gnn(train_loader, optimizer, epoch, pseudo_labels, is_augmented, verbose) From f1c793157b3dca1cf5a549568db90bf24211406f Mon Sep 17 00:00:00 2001 From: xnuohz Date: Thu, 2 Jan 2025 22:00:00 +0800 Subject: [PATCH 03/19] add llm explanation token --- examples/llm/tape/generate.py | 22 +++++++++++- torch_geometric/datasets/tag_dataset.py | 47 ++++++++++++++++++------- 2 files changed, 56 insertions(+), 13 
deletions(-) diff --git a/examples/llm/tape/generate.py b/examples/llm/tape/generate.py index a53c62e8fea7..9064ac111b8d 100644 --- a/examples/llm/tape/generate.py +++ b/examples/llm/tape/generate.py @@ -1,5 +1,25 @@ +from ogb.nodeproppred import PygNodePropPredDataset + +from torch_geometric.datasets import TAGDataset + + def main(): - pass + dataset_name = 'arxiv' + root = './data/ogb' + hf_model = 'prajjwal1/bert-tiny' + token_on_disk = True + + dataset = PygNodePropPredDataset(f'ogbn-{dataset_name}', root=root) + dataset.get_idx_split() + + tag_dataset = TAGDataset(root, dataset, hf_model, + token_on_disk=token_on_disk) + raw_text_dataset = tag_dataset.to_text_dataset() + llm_explanation_dataset = tag_dataset.to_text_dataset( + text_type='llm_explanation') + print(tag_dataset.num_classes, tag_dataset.raw_file_names) + print(raw_text_dataset) + print(llm_explanation_dataset) if __name__ == '__main__': diff --git a/torch_geometric/datasets/tag_dataset.py b/torch_geometric/datasets/tag_dataset.py index 09444a904d48..f88e67014684 100644 --- a/torch_geometric/datasets/tag_dataset.py +++ b/torch_geometric/datasets/tag_dataset.py @@ -158,6 +158,8 @@ def __init__( self.token_on_disk = token_on_disk self.tokenize_batch_size = tokenize_batch_size self._token = self.tokenize_graph(self.tokenize_batch_size) + self._llm_explanation_token = self.tokenize_graph( + self.tokenize_batch_size, text_type='llm_explanation') self.__num_classes__ = dataset.num_classes @property @@ -185,6 +187,12 @@ def token(self) -> Dict[str, Tensor]: self._token = self.tokenize_graph() return self._token + @property + def llm_explanation_token(self) -> Dict[str, Tensor]: + if self._llm_explanation_token is None: # lazy load + self._llm_explanation_token = self.tokenize_graph() + return self._llm_explanation_token + # load is_gold after init @property def is_gold(self) -> Tensor: @@ -318,22 +326,31 @@ def save_node_explanation(self, text: List[str]) -> None: text_df.to_csv(osp.join(node_text_path), compression='gzip', index=False) - def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]: + def tokenize_graph(self, batch_size: int = 256, + text_type: str = 'raw_text') -> Dict[str, Tensor]: r"""Tokenizing the text associate with each node, running in cpu. 
Args: batch_size (Optional[int]): batch size of list of text for generating emebdding + text_type (Optional[str]): type of text Returns: Dict[str, torch.Tensor]: tokenized graph """ + assert text_type in ['raw_text', 'llm_explanation'] + if text_type == 'raw_text': + _text = self.text + elif text_type == 'llm_explanation': + _text = self.llm_explanation + data_len = 0 - if self.text is not None: - data_len = len(self.text) + if _text is not None: + data_len = len(_text) else: raise ValueError("The TAGDataset need text for tokenization") token_keys = ['input_ids', 'token_type_ids', 'attention_mask'] - path = os.path.join(self.processed_dir, 'token', self.tokenizer_name) + path = os.path.join(self.processed_dir, 'token', text_type, + self.tokenizer_name) # Check if the .pt files already exist token_files_exist = any( os.path.exists(os.path.join(path, f'{k}.pt')) for k in token_keys) @@ -350,12 +367,12 @@ def tokenize_graph(self, batch_size: int = 256) -> Dict[str, Tensor]: all_encoded_token = {k: [] for k in token_keys} pbar = tqdm(total=data_len) - pbar.set_description('Tokenizing Text Attributed Graph') + pbar.set_description(f'Tokenizing Text Attributed Graph {text_type}') for i in range(0, data_len, batch_size): end_index = min(data_len, i + batch_size) - token = self.tokenizer(self.text[i:end_index], - padding='max_length', truncation=True, - max_length=512, return_tensors="pt") + token = self.tokenizer(_text[i:end_index], padding='max_length', + truncation=True, max_length=512, + return_tensors="pt") for k in token.keys(): all_encoded_token[k].append(token[k]) pbar.update(end_index - i) @@ -383,10 +400,16 @@ class TextDataset(torch.utils.data.Dataset): Args: tag_dataset (TAGDataset): the parent dataset + text_type (str): type of text """ - def __init__(self, tag_dataset: 'TAGDataset') -> None: + def __init__(self, tag_dataset: 'TAGDataset', + text_type: str = 'raw_text') -> None: + assert text_type in ['raw_text', 'llm_explanation'] self.tag_dataset = tag_dataset - self.token = tag_dataset.token + if text_type == 'raw_text': + self.token = tag_dataset.token + elif text_type == 'llm_explanation': + self.token = tag_dataset.llm_explanation_token assert tag_dataset._data is not None self._data = tag_dataset._data @@ -438,8 +461,8 @@ def get(self, idx: int) -> BaseData: def __repr__(self) -> str: return f'{self.__class__.__name__}()' - def to_text_dataset(self) -> TextDataset: + def to_text_dataset(self, text_type: str = 'raw_text') -> TextDataset: r"""Factory Build text dataset from Text Attributed Graph Dataset each data point is node's associated text token. 
""" - return TAGDataset.TextDataset(self) + return TAGDataset.TextDataset(self, text_type) From 5c76117c06fe08d7783a021194653033e5390c49 Mon Sep 17 00:00:00 2001 From: xnuohz Date: Fri, 3 Jan 2025 00:16:09 +0800 Subject: [PATCH 04/19] add lm training --- examples/llm/tape/generate.py | 87 ++++++++++++++++++++++++++++++++++- 1 file changed, 85 insertions(+), 2 deletions(-) diff --git a/examples/llm/tape/generate.py b/examples/llm/tape/generate.py index 9064ac111b8d..4caade957384 100644 --- a/examples/llm/tape/generate.py +++ b/examples/llm/tape/generate.py @@ -1,6 +1,10 @@ +import torch from ogb.nodeproppred import PygNodePropPredDataset +from tqdm import tqdm +from transformers import AutoModelForSequenceClassification from torch_geometric.datasets import TAGDataset +from torch_geometric.loader import DataLoader def main(): @@ -10,8 +14,7 @@ def main(): token_on_disk = True dataset = PygNodePropPredDataset(f'ogbn-{dataset_name}', root=root) - dataset.get_idx_split() - + split_idx = dataset.get_idx_split() tag_dataset = TAGDataset(root, dataset, hf_model, token_on_disk=token_on_disk) raw_text_dataset = tag_dataset.to_text_dataset() @@ -21,6 +24,86 @@ def main(): print(raw_text_dataset) print(llm_explanation_dataset) + # Train LM ========================================= + lm_batch_size = 256 + train_dataset = torch.utils.data.Subset( + llm_explanation_dataset, + split_idx['train'].nonzero().squeeze().tolist()) + val_dataset = torch.utils.data.Subset( + llm_explanation_dataset, + split_idx['valid'].nonzero().squeeze().tolist()) + test_dataset = torch.utils.data.Subset( + llm_explanation_dataset, + split_idx['test'].nonzero().squeeze().tolist()) + + print('Building language model dataloader...', end='-->') + + train_loader = DataLoader(train_dataset, batch_size=lm_batch_size, + drop_last=False, pin_memory=True, shuffle=True) + val_loader = DataLoader(val_dataset, batch_size=lm_batch_size, + drop_last=False, pin_memory=True, shuffle=True) + test_loader = DataLoader(test_dataset, batch_size=lm_batch_size * 4, + drop_last=False, pin_memory=True, shuffle=False) + print(f'{len(train_loader)} | {len(val_loader)} | {len(test_loader)}') + + device = torch.device('cuda') + lm = AutoModelForSequenceClassification.from_pretrained( + hf_model, + num_labels=tag_dataset.num_classes, + torch_dtype=torch.bfloat16, + offload_folder='offload', + trust_remote_code=True, + ).to(device) + optimizer = torch.optim.Adam(lm.parameters(), lr=1e-3) + lm_loss = torch.nn.CrossEntropyLoss(reduction='mean') + # import pdb; pdb.set_trace() + + # Pretrain language model + num_epochs = 100 + patience = 3 + verbose = True + best_acc = 0 + early_stopping = 0 + for epoch in range(1, num_epochs + 1): + # ======================================== + all_out = [] + total_loss = total_correct = 0 + num_nodes = len(train_loader.dataset.indices) + lm.train() + if verbose: + pbar = tqdm(total=num_nodes) + pbar.set_description(f'Epoch {epoch:02d}') + for batch in train_loader: + inputs = {k: v.to(device) for k, v in batch['input'].items()} + out = lm(**inputs).logits + labels = batch['labels'].to(device).squeeze() + loss = lm_loss(out, labels) + loss.backward() + optimizer.step() + optimizer.zero_grad() + all_out.append(out) + total_correct += int(out.argmax(dim=-1).eq(labels).sum()) + total_loss += float(loss) + if verbose: + pbar.update(batch['n_id'].size(0)) + + all_out = torch.cat(all_out, dim=0) + approx_acc = total_correct / num_nodes + loss = total_loss / len(train_loader) + if verbose: + pbar.close() + print(f'Epoch {epoch:02d} 
Loss: {loss:.4f} ' + f'Approx. Train: {approx_acc:.4f}') + acc = approx_acc + # =================================================== + if acc < best_acc: + early_stopping += 1 + if early_stopping > patience: + print(f'Early stopped by Epoch: {epoch}, ' + f'Best acc: {best_acc}') + break + best_acc = max(best_acc, acc) + if __name__ == '__main__': main() From a25231282e00088bf8b65e0010d198d9dc388ba3 Mon Sep 17 00:00:00 2001 From: xnuohz Date: Sun, 5 Jan 2025 16:37:01 +0800 Subject: [PATCH 05/19] update --- torch_geometric/nn/models/__init__.py | 2 ++ torch_geometric/nn/models/tape.py | 9 +++++++++ 2 files changed, 11 insertions(+) create mode 100644 torch_geometric/nn/models/tape.py diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py index 9ade58cebc05..b74bc9d7a82c 100644 --- a/torch_geometric/nn/models/__init__.py +++ b/torch_geometric/nn/models/__init__.py @@ -32,6 +32,7 @@ from .git_mol import GITMol from .molecule_gpt import MoleculeGPT from .glem import GLEM +from .tape import TAPE # Deprecated: from torch_geometric.explain.algorithm.captum import (to_captum_input, captum_output_to_dicts) @@ -82,4 +83,5 @@ 'GITMol', 'MoleculeGPT', 'GLEM', + 'TAPE', ] diff --git a/torch_geometric/nn/models/tape.py b/torch_geometric/nn/models/tape.py new file mode 100644 index 000000000000..2829dd994c59 --- /dev/null +++ b/torch_geometric/nn/models/tape.py @@ -0,0 +1,9 @@ +class TAPE(): + def __init__(self): + pass + + def train_lm(self): + pass + + def inference_lm(self): + pass From e302bd487485ab128b6cc7affc02a062f35da009 Mon Sep 17 00:00:00 2001 From: xnuohz Date: Sun, 5 Jan 2025 17:33:50 +0800 Subject: [PATCH 06/19] update --- examples/llm/tape/generate.py | 109 -------------------------- torch_geometric/nn/models/__init__.py | 2 - torch_geometric/nn/models/tape.py | 9 --- 3 files changed, 120 deletions(-) delete mode 100644 examples/llm/tape/generate.py delete mode 100644 torch_geometric/nn/models/tape.py diff --git a/examples/llm/tape/generate.py b/examples/llm/tape/generate.py deleted file mode 100644 index 4caade957384..000000000000 --- a/examples/llm/tape/generate.py +++ /dev/null @@ -1,109 +0,0 @@ -import torch -from ogb.nodeproppred import PygNodePropPredDataset -from tqdm import tqdm -from transformers import AutoModelForSequenceClassification - -from torch_geometric.datasets import TAGDataset -from torch_geometric.loader import DataLoader - - -def main(): - dataset_name = 'arxiv' - root = './data/ogb' - hf_model = 'prajjwal1/bert-tiny' - token_on_disk = True - - dataset = PygNodePropPredDataset(f'ogbn-{dataset_name}', root=root) - split_idx = dataset.get_idx_split() - tag_dataset = TAGDataset(root, dataset, hf_model, - token_on_disk=token_on_disk) - raw_text_dataset = tag_dataset.to_text_dataset() - llm_explanation_dataset = tag_dataset.to_text_dataset( - text_type='llm_explanation') - print(tag_dataset.num_classes, tag_dataset.raw_file_names) - print(raw_text_dataset) - print(llm_explanation_dataset) - - # Train LM ========================================= - lm_batch_size = 256 - train_dataset = torch.utils.data.Subset( - llm_explanation_dataset, - split_idx['train'].nonzero().squeeze().tolist()) - val_dataset = torch.utils.data.Subset( - llm_explanation_dataset, - split_idx['valid'].nonzero().squeeze().tolist()) - test_dataset = torch.utils.data.Subset( - llm_explanation_dataset, - split_idx['test'].nonzero().squeeze().tolist()) - - print('Building language model dataloader...', end='-->') - - train_loader = DataLoader(train_dataset, 
batch_size=lm_batch_size, - drop_last=False, pin_memory=True, shuffle=True) - val_loader = DataLoader(val_dataset, batch_size=lm_batch_size, - drop_last=False, pin_memory=True, shuffle=True) - test_loader = DataLoader(test_dataset, batch_size=lm_batch_size * 4, - drop_last=False, pin_memory=True, shuffle=False) - print(f'{len(train_loader)} | {len(val_loader)} | {len(test_loader)}') - - device = torch.device('cuda') - lm = AutoModelForSequenceClassification.from_pretrained( - hf_model, - num_labels=tag_dataset.num_classes, - torch_dtype=torch.bfloat16, - offload_folder='offload', - trust_remote_code=True, - ).to(device) - optimizer = torch.optim.Adam(lm.parameters(), lr=1e-3) - lm_loss = torch.nn.CrossEntropyLoss(reduction='mean') - # import pdb; pdb.set_trace() - - # Pretrain language model - num_epochs = 100 - patience = 3 - verbose = True - best_acc = 0 - early_stopping = 0 - for epoch in range(1, num_epochs + 1): - # ======================================== - all_out = [] - total_loss = total_correct = 0 - num_nodes = len(train_loader.dataset.indices) - lm.train() - if verbose: - pbar = tqdm(total=num_nodes) - pbar.set_description(f'Epoch {epoch:02d}') - for batch in train_loader: - inputs = {k: v.to(device) for k, v in batch['input'].items()} - out = lm(**inputs).logits - labels = batch['labels'].to(device).squeeze() - loss = lm_loss(out, labels) - loss.backward() - optimizer.step() - optimizer.zero_grad() - all_out.append(out) - total_correct += int(out.argmax(dim=-1).eq(labels).sum()) - total_loss += float(loss) - if verbose: - pbar.update(batch['n_id'].size(0)) - - all_out = torch.cat(all_out, dim=0) - approx_acc = total_correct / num_nodes - loss = total_loss / len(train_loader) - if verbose: - pbar.close() - print(f'Epoch {epoch:02d} Loss: {loss:.4f} ' - f'Approx. 
Train: {approx_acc:.4f}') - acc = approx_acc - # =================================================== - if acc < best_acc: - early_stopping += 1 - if early_stopping > patience: - print(f'Early stopped by Epoch: {epoch}, ' - f'Best acc: {best_acc}') - break - best_acc = max(best_acc, acc) - - -if __name__ == '__main__': - main() diff --git a/torch_geometric/nn/models/__init__.py b/torch_geometric/nn/models/__init__.py index b74bc9d7a82c..9ade58cebc05 100644 --- a/torch_geometric/nn/models/__init__.py +++ b/torch_geometric/nn/models/__init__.py @@ -32,7 +32,6 @@ from .git_mol import GITMol from .molecule_gpt import MoleculeGPT from .glem import GLEM -from .tape import TAPE # Deprecated: from torch_geometric.explain.algorithm.captum import (to_captum_input, captum_output_to_dicts) @@ -83,5 +82,4 @@ 'GITMol', 'MoleculeGPT', 'GLEM', - 'TAPE', ] diff --git a/torch_geometric/nn/models/tape.py b/torch_geometric/nn/models/tape.py deleted file mode 100644 index 2829dd994c59..000000000000 --- a/torch_geometric/nn/models/tape.py +++ /dev/null @@ -1,9 +0,0 @@ -class TAPE(): - def __init__(self): - pass - - def train_lm(self): - pass - - def inference_lm(self): - pass From 82c94508fe5e86e8806d166cb5947a5aec71788c Mon Sep 17 00:00:00 2001 From: xnuohz Date: Sun, 5 Jan 2025 17:35:56 +0800 Subject: [PATCH 07/19] update --- examples/llm/glem.py | 2 -- torch_geometric/nn/models/glem.py | 3 +-- 2 files changed, 1 insertion(+), 4 deletions(-) diff --git a/examples/llm/glem.py b/examples/llm/glem.py index c6bae703fd33..ec76cef4c010 100644 --- a/examples/llm/glem.py +++ b/examples/llm/glem.py @@ -371,11 +371,9 @@ def load_model(em_phase): if gnn_val_acc > lm_val_acc: em_phase = 'gnn' model.gnn = model.gnn.to(device, non_blocking=True) - test_loader = subgraph_loader else: em_phase = 'lm' model.lm = model.lm.to(device, non_blocking=True) - test_loader = text_test_loader test_preds = model.inference(em_phase, test_loader, verbose=verbose) train_acc, val_acc, test_acc = evaluate(test_preds, ['train', 'valid', 'test']) diff --git a/torch_geometric/nn/models/glem.py b/torch_geometric/nn/models/glem.py index d30d5f8bd062..afc8b09d77c7 100644 --- a/torch_geometric/nn/models/glem.py +++ b/torch_geometric/nn/models/glem.py @@ -144,8 +144,7 @@ def train(self, em_phase: str, train_loader: Union[DataLoader, acc (float): training accuracy loss (float): loss value """ - if pseudo_labels is not None: - pseudo_labels = pseudo_labels.to(self.device) + pseudo_labels = pseudo_labels.to(self.device) if em_phase == 'gnn': acc, loss = self.train_gnn(train_loader, optimizer, epoch, pseudo_labels, is_augmented, verbose) From 29f38811d9d59c10cc4a941aea1e7c933c2ef784 Mon Sep 17 00:00:00 2001 From: xnuohz Date: Sun, 5 Jan 2025 17:42:35 +0800 Subject: [PATCH 08/19] update --- test/datasets/test_tag_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/datasets/test_tag_dataset.py b/test/datasets/test_tag_dataset.py index a816dfa77ee3..cbce973c5e11 100644 --- a/test/datasets/test_tag_dataset.py +++ b/test/datasets/test_tag_dataset.py @@ -1,13 +1,13 @@ from ogb.nodeproppred import PygNodePropPredDataset from torch_geometric.datasets import TAGDataset -from torch_geometric.testing import withPackage +from torch_geometric.testing import onlyFullTest, withPackage -# @onlyFullTest +@onlyFullTest @withPackage('ogb') def test_tag_dataset() -> None: - root = './data' + root = './data/ogb' hf_model = 'prajjwal1/bert-tiny' token_on_disk = True From b139ac2c78a258a6db7763ec8b9f14f726731179 Mon Sep 17 00:00:00 2001 
From: xnuohz Date: Sun, 5 Jan 2025 18:33:48 +0800 Subject: [PATCH 09/19] add changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 4e6789b9a86d..37e93da28c1f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Added +- Added llm generated explanations to `TAGDataset` ([#9918](https://github.com/pyg-team/pytorch_geometric/pull/9918)) - Update Dockerfile to use latest from NVIDIA ([#9794](https://github.com/pyg-team/pytorch_geometric/pull/9794)) - Added various GRetriever Architecture Benchmarking examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) - Added `profiler.nvtxit` with some examples ([#9666](https://github.com/pyg-team/pytorch_geometric/pull/9666)) From 0f643d591fe37796f498be27178b51470d269dbf Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 5 Jan 2025 10:33:59 +0000 Subject: [PATCH 10/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- README.md | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 64a809fdcf83..f96d720837d1 100644 --- a/README.md +++ b/README.md @@ -383,9 +383,9 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` | | `cpu` | `cu118` | `cu121` | `cu124` | | ----------- | ----- | ------- | ------- | ------- | -| **Linux** | ✅ | ✅ | ✅ | ✅ | -| **Windows** | ✅ | ✅ | ✅ | ✅ | -| **macOS** | ✅ | | | | +| **Linux** | ✅ | ✅ | ✅ | ✅ | +| **Windows** | ✅ | ✅ | ✅ | ✅ | +| **macOS** | ✅ | | | | #### PyTorch 2.4 @@ -399,9 +399,9 @@ where `${CUDA}` should be replaced by either `cpu`, `cu118`, `cu121`, or `cu124` | | `cpu` | `cu118` | `cu121` | `cu124` | | ----------- | ----- | ------- | ------- | ------- | -| **Linux** | ✅ | ✅ | ✅ | ✅ | -| **Windows** | ✅ | ✅ | ✅ | ✅ | -| **macOS** | ✅ | | | | +| **Linux** | ✅ | ✅ | ✅ | ✅ | +| **Windows** | ✅ | ✅ | ✅ | ✅ | +| **macOS** | ✅ | | | | **Note:** Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0, PyTorch 1.6.0, PyTorch 1.7.0/1.7.1, PyTorch 1.8.0/1.8.1, PyTorch 1.9.0, PyTorch 1.10.0/1.10.1/1.10.2, PyTorch 1.11.0, PyTorch 1.12.0/1.12.1, PyTorch 1.13.0/1.13.1, PyTorch 2.0.0/2.0.1, PyTorch 2.1.0/2.1.1/2.1.2, PyTorch 2.2.0/2.2.1/2.2.2, and PyTorch 2.3.0/2.3.1 (following the same procedure). **For older versions, you might need to explicitly specify the latest supported version number** or install via `pip install --no-index` in order to prevent a manual installation from source. 
From 07fe2ec5324ee29851ac4c96415b107a0c73f519 Mon Sep 17 00:00:00 2001 From: xnuohz Date: Sun, 5 Jan 2025 20:52:33 +0800 Subject: [PATCH 11/19] update --- test/datasets/test_tag_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/datasets/test_tag_dataset.py b/test/datasets/test_tag_dataset.py index cbce973c5e11..2c88e5e51a19 100644 --- a/test/datasets/test_tag_dataset.py +++ b/test/datasets/test_tag_dataset.py @@ -1,5 +1,3 @@ -from ogb.nodeproppred import PygNodePropPredDataset - from torch_geometric.datasets import TAGDataset from torch_geometric.testing import onlyFullTest, withPackage @@ -7,6 +5,8 @@ @onlyFullTest @withPackage('ogb') def test_tag_dataset() -> None: + from ogb.nodeproppred import PygNodePropPredDataset + root = './data/ogb' hf_model = 'prajjwal1/bert-tiny' token_on_disk = True From 695ed01d7979ec76ceaf50f7c4141ea0c31926ae Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 5 Jan 2025 12:53:45 +0000 Subject: [PATCH 12/19] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- test/datasets/test_tag_dataset.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/datasets/test_tag_dataset.py b/test/datasets/test_tag_dataset.py index 2c88e5e51a19..58b0d3ef66f4 100644 --- a/test/datasets/test_tag_dataset.py +++ b/test/datasets/test_tag_dataset.py @@ -6,7 +6,7 @@ @withPackage('ogb') def test_tag_dataset() -> None: from ogb.nodeproppred import PygNodePropPredDataset - + root = './data/ogb' hf_model = 'prajjwal1/bert-tiny' token_on_disk = True From 74f2bb9aab190971f015339260d255956f972254 Mon Sep 17 00:00:00 2001 From: xnuohz Date: Tue, 7 Jan 2025 23:40:38 +0800 Subject: [PATCH 13/19] update --- torch_geometric/datasets/tag_dataset.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torch_geometric/datasets/tag_dataset.py b/torch_geometric/datasets/tag_dataset.py index f88e67014684..557eaeeffb6b 100644 --- a/torch_geometric/datasets/tag_dataset.py +++ b/torch_geometric/datasets/tag_dataset.py @@ -48,7 +48,7 @@ class TAGDataset(InMemoryDataset): associate with node, which should be align with node list llm_prediction (Optional[List[str]]): list of llm prediction associate with node, the order of list should be align with node list - llm_prediction_topk (Optional[int]): Top K prediction from LLM used as + llm_prediction_topk (int): Top K prediction from LLM used as features for GNN training, default: 5 split_idx (Optional[Dict[str, torch.Tensor]]): Optional dictionary, for saving split index, it is required that if your dataset doesn't @@ -81,8 +81,8 @@ def __init__( tokenizer_name: str, text: Optional[List[str]] = None, llm_explanation: Optional[List[str]] = None, - llm_prediction: Optional[List[str]] = None, - llm_prediction_topk: Optional[int] = 5, + llm_prediction: Optional[Tensor] = None, + llm_prediction_topk: int = 5, split_idx: Optional[Dict[str, Tensor]] = None, tokenize_batch_size: int = 256, token_on_disk: bool = False, From 502659aeaa7dcf1e5dadb7df7d813d346e663878 Mon Sep 17 00:00:00 2001 From: xnuohz Date: Tue, 7 Jan 2025 23:58:03 +0800 Subject: [PATCH 14/19] update --- torch_geometric/datasets/tag_dataset.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torch_geometric/datasets/tag_dataset.py b/torch_geometric/datasets/tag_dataset.py index 557eaeeffb6b..0cfe912a1b5f 100644 --- a/torch_geometric/datasets/tag_dataset.py +++ 
b/torch_geometric/datasets/tag_dataset.py @@ -190,7 +190,8 @@ def token(self) -> Dict[str, Tensor]: @property def llm_explanation_token(self) -> Dict[str, Tensor]: if self._llm_explanation_token is None: # lazy load - self._llm_explanation_token = self.tokenize_graph() + self._llm_explanation_token = self.tokenize_graph( + text_type='llm_explanation') return self._llm_explanation_token # load is_gold after init From ac402f320ece574c8a93c4a6e52735a9dc3c4459 Mon Sep 17 00:00:00 2001 From: xnuohz Date: Fri, 10 Jan 2025 18:29:53 +0800 Subject: [PATCH 15/19] update --- examples/llm/glem.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/examples/llm/glem.py b/examples/llm/glem.py index c6bae703fd33..e337ffeb15af 100644 --- a/examples/llm/glem.py +++ b/examples/llm/glem.py @@ -40,6 +40,7 @@ def get_n_params(model): def main(args): gpu = args.gpu dataset_name = args.dataset + text_type = args.text_type root = osp.join('data', 'ogb') hf_model = args.hf_model pl_ratio = args.pl_ratio @@ -83,7 +84,7 @@ def main(args): tag_dataset = TAGDataset(root, dataset, hf_model, token_on_disk=token_on_disk) - text_dataset = tag_dataset.to_text_dataset() + text_dataset = tag_dataset.to_text_dataset(text_type) print(tag_dataset.num_classes, tag_dataset.raw_file_names) num_classes = tag_dataset.num_classes @@ -395,6 +396,8 @@ def load_model(em_phase): help='number of iterations') parser.add_argument("--dataset", type=str, default='products', help='arxiv or products') + parser.add_argument("--text_type", type=str, default='raw_text', + help='raw_text or llm_explanation') parser.add_argument("--pl_ratio", type=float, default=0.5, help="pseudo labels ratio") parser.add_argument('--hf_model', type=str, default='prajjwal1/bert-tiny', From fda21c21003a3f50f077f94beb5f25efa6c7548b Mon Sep 17 00:00:00 2001 From: xnuohz Date: Sun, 12 Jan 2025 23:53:36 +0800 Subject: [PATCH 16/19] update --- examples/llm/glem.py | 2 +- torch_geometric/datasets/tag_dataset.py | 19 +++++++++++++++++-- 2 files changed, 18 insertions(+), 3 deletions(-) diff --git a/examples/llm/glem.py b/examples/llm/glem.py index e337ffeb15af..ed787fb08a3c 100644 --- a/examples/llm/glem.py +++ b/examples/llm/glem.py @@ -397,7 +397,7 @@ def load_model(em_phase): parser.add_argument("--dataset", type=str, default='products', help='arxiv or products') parser.add_argument("--text_type", type=str, default='raw_text', - help='raw_text or llm_explanation') + help='raw_text, llm_explanation or all') parser.add_argument("--pl_ratio", type=float, default=0.5, help="pseudo labels ratio") parser.add_argument('--hf_model', type=str, default='prajjwal1/bert-tiny', diff --git a/torch_geometric/datasets/tag_dataset.py b/torch_geometric/datasets/tag_dataset.py index 0cfe912a1b5f..5cfa003a631a 100644 --- a/torch_geometric/datasets/tag_dataset.py +++ b/torch_geometric/datasets/tag_dataset.py @@ -160,6 +160,8 @@ def __init__( self._token = self.tokenize_graph(self.tokenize_batch_size) self._llm_explanation_token = self.tokenize_graph( self.tokenize_batch_size, text_type='llm_explanation') + self._all_token = self.tokenize_graph(self.tokenize_batch_size, + text_type='all') self.__num_classes__ = dataset.num_classes @property @@ -194,6 +196,12 @@ def llm_explanation_token(self) -> Dict[str, Tensor]: text_type='llm_explanation') return self._llm_explanation_token + @property + def all_token(self) -> Dict[str, Tensor]: + if self._all_token is None: # lazy load + self._all_token = self.tokenize_graph(text_type='all') + return self._all_token + # load is_gold 
after init @property def is_gold(self) -> Tensor: @@ -338,11 +346,16 @@ def tokenize_graph(self, batch_size: int = 256, Returns: Dict[str, torch.Tensor]: tokenized graph """ - assert text_type in ['raw_text', 'llm_explanation'] + assert text_type in ['raw_text', 'llm_explanation', 'all'] if text_type == 'raw_text': _text = self.text elif text_type == 'llm_explanation': _text = self.llm_explanation + elif text_type == 'all': + _text = [ + f'{raw_txt} Explanation: {exp_txt}' + for raw_txt, exp_txt in zip(self.text, self.llm_explanation) + ] data_len = 0 if _text is not None: @@ -405,12 +418,14 @@ class TextDataset(torch.utils.data.Dataset): """ def __init__(self, tag_dataset: 'TAGDataset', text_type: str = 'raw_text') -> None: - assert text_type in ['raw_text', 'llm_explanation'] + assert text_type in ['raw_text', 'llm_explanation', 'all'] self.tag_dataset = tag_dataset if text_type == 'raw_text': self.token = tag_dataset.token elif text_type == 'llm_explanation': self.token = tag_dataset.llm_explanation_token + elif text_type == 'all': + self.token = tag_dataset.all_token assert tag_dataset._data is not None self._data = tag_dataset._data From 3aa5e56075fa99b6b3a128d346b352dde42c04de Mon Sep 17 00:00:00 2001 From: xnuohz Date: Tue, 14 Jan 2025 01:04:43 +0800 Subject: [PATCH 17/19] update --- torch_geometric/datasets/tag_dataset.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/torch_geometric/datasets/tag_dataset.py b/torch_geometric/datasets/tag_dataset.py index 5cfa003a631a..cbc9fb70c30b 100644 --- a/torch_geometric/datasets/tag_dataset.py +++ b/torch_geometric/datasets/tag_dataset.py @@ -352,6 +352,9 @@ def tokenize_graph(self, batch_size: int = 256, elif text_type == 'llm_explanation': _text = self.llm_explanation elif text_type == 'all': + if self.text is None or self.llm_explanation is None: + raise ValueError("The TAGDataset need text and llm explanation" + "for tokenizing all text") _text = [ f'{raw_txt} Explanation: {exp_txt}' for raw_txt, exp_txt in zip(self.text, self.llm_explanation) From fbc573c17f2981f2a906a930aab8543ee58ad11f Mon Sep 17 00:00:00 2001 From: xnuohz Date: Thu, 16 Jan 2025 22:36:22 +0800 Subject: [PATCH 18/19] update --- examples/llm/glem.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/examples/llm/glem.py b/examples/llm/glem.py index ed787fb08a3c..8730fb17f9b5 100644 --- a/examples/llm/glem.py +++ b/examples/llm/glem.py @@ -40,7 +40,7 @@ def get_n_params(model): def main(args): gpu = args.gpu dataset_name = args.dataset - text_type = args.text_type + text_type = args.text_type if args.dataset == 'arxiv' else 'raw_text' root = osp.join('data', 'ogb') hf_model = args.hf_model pl_ratio = args.pl_ratio @@ -394,10 +394,12 @@ def load_model(em_phase): help='number of runs') parser.add_argument('--num_em_iters', type=int, default=1, help='number of iterations') - parser.add_argument("--dataset", type=str, default='products', + parser.add_argument("--dataset", type=str, default='arxiv', help='arxiv or products') - parser.add_argument("--text_type", type=str, default='raw_text', - help='raw_text, llm_explanation or all') + parser.add_argument( + "--text_type", type=str, default='llm_explanation', + help="type of text, support raw_text, llm_explanation," + "all for arxiv and raw_text for products") parser.add_argument("--pl_ratio", type=float, default=0.5, help="pseudo labels ratio") parser.add_argument('--hf_model', type=str, default='prajjwal1/bert-tiny', From 075b7b20329ca6c79c31301d11beb0638ed3242d Mon Sep 17 00:00:00 2001 From: 
Akihiro Nitta Date: Sun, 26 Jan 2025 00:39:52 +0000 Subject: [PATCH 19/19] empty
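
For reference, the docstring in PATCH 01 points to `examples/llm/glem.py` for usage, and the standalone `examples/llm/tape/generate.py` example introduced earlier in the series is removed again in PATCH 06. A minimal end-to-end sketch of the TAGDataset API as it stands after these patches might look as follows; it is hypothetical (not part of the series), assumes `ogb` and `transformers` are installed, uses `ogbn-arxiv` (the only dataset shipping bundled LLM explanations/predictions here), and reuses the `./data/ogb` root from the GLEM example.

# Hypothetical usage sketch, not part of the patch series: it only exercises
# the TAGDataset API visible in the diffs above (token properties and
# to_text_dataset with a text_type argument).
from ogb.nodeproppred import PygNodePropPredDataset

from torch_geometric.datasets import TAGDataset


def main() -> None:
    root = './data/ogb'               # root used by examples/llm/glem.py
    hf_model = 'prajjwal1/bert-tiny'  # tokenizer checkpoint used in the tests

    # Wrap the OGB graph into a Text Attributed Graph dataset. For
    # 'ogbn-arxiv', the raw text, LLM explanations and LLM predictions are
    # downloaded automatically; other datasets must pass them in explicitly.
    dataset = PygNodePropPredDataset('ogbn-arxiv', root=root)
    tag_dataset = TAGDataset(root, dataset, hf_model, token_on_disk=True)

    # Tokenized views for the three supported text types.
    print(tag_dataset.token['input_ids'].shape)                  # titles/abstracts
    print(tag_dataset.llm_explanation_token['input_ids'].shape)  # LLM explanations
    print(tag_dataset.all_token['input_ids'].shape)              # raw text + explanation

    # Text dataset view for LM training on the explanations, mirroring the
    # '--text_type llm_explanation' option added to examples/llm/glem.py.
    explanation_dataset = tag_dataset.to_text_dataset(text_type='llm_explanation')
    print(explanation_dataset)


if __name__ == '__main__':
    main()

Passing `--text_type llm_explanation` (or `all`) to `examples/llm/glem.py` exercises the same path during LM training; per PATCH 18, `ogbn-products` falls back to `raw_text`.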