refactor: gnn_dataloader#27
Add benchmark dataset
Refactor: add dataset loader and examples
| global op_num | ||
| op_num = len(self.op_types) | ||
| def download_data(self): |
Please use .bench_dataset.bench_dataset() to download and get the data.
| class GNNDataset(torch.utils.data.Dataset): | ||
| def __init__(self, data_dir='./dataset', train=True, device='cpu', split_ratio=0.8): |
data_dir should default to user_dataset_folder.
| "source": [ | ||
| "# Latency Dataset - GNN Model\n", | ||
| "\n", | ||
| "This example will demonstrate our ability to predict latency with data from NN-meter through GNN. We will first build our GNN model, which is constructed based on GraphSAGE, and maxpooling is selected as out pooling method. Next, we will start training after the data is loaded.\n", |
Change the sentence "This example will demonstrate our ability to predict latency with data from NN-meter through GNN." to "Considering the dataset is encoded in a graph format, here is an example of using GNN to predict the model latency with the bench dataset. GNNDataset and GNNDataloader in nn_meter/dataset/gnn_dataloader.py build the model structure of the Dataset in .jsonl format into our required Dataset and Dataloader."
| "\n", | ||
| "lr_scheduler = CosineAnnealingLR(opt, T_max=EPOCHS)\n", | ||
| "loss_sum = 0\n", | ||
| "for epoch in range(EPOCHS):\n", |
Please add more comments to explain the code, such as "start training", "start validation", "save the best model", etc.
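As an illustration of the kind of comments being asked for, here is a minimal sketch of a commented epoch loop. It is not the notebook's code: it swaps the GNN and the PyTorch optimizer for a hypothetical one-parameter model in plain Python, and it computes the cosine-annealed learning rate from the closed form (minimum 0) instead of calling CosineAnnealingLR. The names `BASE_LR`, `train_data`, `val_data`, `weight`, and `best_weight` are assumptions made for this example.

```python
import math

EPOCHS = 20
BASE_LR = 0.1  # assumed starting learning rate for this toy example

# Toy data for y = 2x, split into training and validation pairs.
train_data = [(x, 2.0 * x) for x in (0.0, 1.0, 2.0, 3.0)]
val_data = [(x, 2.0 * x) for x in (0.5, 1.5, 2.5)]

weight = 0.0  # single trainable parameter standing in for the model
best_val_loss = float("inf")
best_weight = weight

for epoch in range(EPOCHS):
    # Cosine-annealed learning rate (closed form with a minimum of 0).
    lr = 0.5 * BASE_LR * (1 + math.cos(math.pi * epoch / EPOCHS))

    # Start training: one gradient step per sample on the squared error.
    for x, y in train_data:
        grad = 2.0 * (weight * x - y) * x
        weight -= lr * grad

    # Start validation: mean squared error on the held-out samples.
    val_loss = sum((weight * x - y) ** 2 for x, y in val_data) / len(val_data)

    # Save the best model: keep the parameter with the lowest validation loss.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_weight = weight
```

The three phase comments map one to one onto the labels suggested in the review; in the notebook they would sit above the corresponding PyTorch blocks.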
| err_str = "Not supported device type" | ||
| assert device in hws, err_str | ||
| self.device = device | ||
| self.data_dir = data_dir |
In the setting from .bench_dataset, the dataset will be downloaded to __user_dataset_folder__ by default. Users therefore cannot access the downloaded dataset if they set the init parameter data_dir to another path. The APIs of .bench_dataset and GNNDataset.__init__ should be unified: either both take data_dir or neither does.
| self.raw_data = {} | ||
| self.name_list = [] | ||
| self.latencies = {} | ||
| self.download_data() |
I suggest using self.data_dir = self.download_data() and returning the dataset path from self.download_data.
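A minimal sketch of this suggestion, under stated assumptions: `USER_DATASET_FOLDER` and `fetch_bench_dataset` are hypothetical stand-ins for the package default folder and the downloader in .bench_dataset (whose real signature is not shown in this thread), and `GNNDatasetSketch` is a stripped-down placeholder for GNNDataset. The sketch passes data_dir through to the download step, which is one way to apply the suggestion while keeping both APIs on the same path.

```python
import os
import tempfile

# Hypothetical stand-in for the package-level default dataset folder.
USER_DATASET_FOLDER = os.path.join(tempfile.gettempdir(), "nn_meter_dataset_sketch")


def fetch_bench_dataset(data_folder=USER_DATASET_FOLDER):
    """Hypothetical downloader: ensures the folder exists and returns its path.

    A real implementation would download and unpack the dataset files here.
    """
    os.makedirs(data_folder, exist_ok=True)
    return data_folder


class GNNDatasetSketch:
    def __init__(self, data_dir=USER_DATASET_FOLDER):
        # Store whatever path the download step reports, so the two stay in sync.
        self.data_dir = self.download_data(data_dir)

    def download_data(self, data_dir):
        # Delegate to the shared downloader and hand its path back to the caller.
        return fetch_bench_dataset(data_dir)
```

Because both the downloader and the dataset class take the same parameter with the same default, a user who overrides data_dir still ends up reading from the folder the data was written to.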
| self.construct_attrs() | ||
| self.name_list = list( | ||
| filter(lambda x: x in self.latencies, self.name_list)) | ||
| global op_num |
Why should this parameter be global?
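One possible alternative, sketched here with hypothetical names, is to keep the count as an instance attribute so that any code needing it reads it from the object instead of from module state. `OpEncoderSketch` and `one_hot` are illustrative only and are not part of the PR; only `op_types` and `op_num` appear in the diff.

```python
class OpEncoderSketch:
    """Hypothetical holder for the operator vocabulary."""

    def __init__(self, op_types):
        # Sort the unique operator names for a stable index assignment.
        self.op_types = sorted(set(op_types))
        # Instance attribute instead of a module-level global.
        self.op_num = len(self.op_types)

    def one_hot(self, op_type):
        # Encode an operator name as a one-hot list of length op_num.
        vec = [0] * self.op_num
        vec[self.op_types.index(op_type)] = 1
        return vec
```

With this layout, two datasets with different operator sets can coexist in one process without overwriting each other's count.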