Skip to content

Commit 8420538

Browse files
committed
test s3 on cube
1 parent 92ed5a7 commit 8420538

11 files changed

Lines changed: 528 additions & 53 deletions

File tree

s3/llm_agent/generation_s3.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -696,9 +696,15 @@ def _passages2string(self, retrieval_result):
696696
format_reference = ''
697697
for idx, doc_item in enumerate(retrieval_result):
698698

699-
content = doc_item['document']['contents']
700-
title = content.split("\n")[0]
701-
text = "\n".join(content.split("\n")[1:])
699+
if "cube" in self.config.output_context_dir:
700+
content = doc_item['document']
701+
title = content['title']
702+
text = content['text']
703+
else:
704+
content = doc_item['document']['contents']
705+
title = content.split("\n")[0]
706+
text = "\n".join(content.split("\n")[1:])
707+
702708
format_reference += f"Doc {idx+1}(Title: {title}) {text}\n"
703709

704710
# if "mirage" in self.config.output_context_dir:
Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
#!/usr/bin/env python3
2+
3+
import pandas as pd
4+
import requests
5+
import json
6+
import argparse
7+
import os
8+
9+
def search(query: str, endpoint: str):
10+
payload = {
11+
"queries": [query],
12+
"topk": 12,
13+
"return_scores": True
14+
}
15+
try:
16+
response = requests.post(endpoint, json=payload)
17+
response.raise_for_status()
18+
results = response.json()['result']
19+
# print(results)
20+
except Exception as e:
21+
print(f"[ERROR] Retrieval failed for query: {query}\n{e}")
22+
return ""
23+
24+
def _passages2string(retrieval_result):
25+
format_reference = ''
26+
for idx, doc_item in enumerate(retrieval_result):
27+
content = doc_item['document']
28+
title = content['title']
29+
text = content['text']
30+
format_reference += f"Doc {idx+1} (Title: {title}) {text}\n"
31+
return format_reference
32+
33+
return _passages2string(results[0])
34+
35+
def main():
    """CLI entry point: retrieve context for every question and dump one JSON file per data source."""
    parser = argparse.ArgumentParser(description="Run retrieval and save JSON outputs.")
    parser.add_argument("--input_parquet", required=True, help="Input .parquet file with QA data.")
    parser.add_argument("--output_dir", required=True, help="Directory to store output JSON files.")
    parser.add_argument("--endpoint", required=True, help="Retrieval API endpoint URL (e.g., http://127.0.0.1:8000/retrieve)")
    parser.add_argument("--data_sources", required=True, help="Data sources to process (e.g., hotpotqa,2wikimultihopqa)")
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    frame = pd.read_parquet(args.input_parquet)

    for source in args.data_sources.split(','):
        print(f"[INFO] Processing: {source}")
        # Restrict to the rows belonging to this data source.
        subset = frame[frame['data_source'] == source]

        retrieval_info = {}
        for _, record in subset.iterrows():
            question = record['question']
            # One retrieval call per question; keyed by question text.
            retrieval_info[question] = {
                'golden_answers': [record['answer']],
                'context_with_info': search(question, args.endpoint),
            }

        out_path = os.path.join(args.output_dir, f"{source}_output_sequences.json")
        with open(out_path, 'w') as f:
            json.dump(retrieval_info, f, indent=4)
        print(f"[INFO] Saved: {out_path}")

if __name__ == "__main__":
    main()
Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# python scripts/baselines/e5_retrieval_cube.py \
2+
# --input_parquet data/cube/test_e5_cube_2wikimultihopqa_pre.parquet \
3+
# --output_dir data/cube/rag_e5_cube \
4+
# --endpoint http://127.0.0.1:3000/retrieve \
5+
# --data_sources 2wikimultihopqa
6+
7+
python scripts/baselines/e5_retrieval_cube.py \
8+
--input_parquet data/cube/test_e5_cube_hotpotqa_pre.parquet \
9+
--output_dir data/cube/rag_e5_cube \
10+
--endpoint http://127.0.0.1:3000/retrieve \
11+
--data_sources hotpotqa
12+
13+
Lines changed: 163 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,163 @@
1+
# Copyright 2024 Bytedance Ltd. and/or its affiliates
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Preprocess the QA dataset to parquet format
16+
"""
17+
18+
import re
19+
import os
20+
import datasets
21+
import json
22+
from verl.utils.hdfs_io import copy, makedirs
23+
import argparse
24+
25+
def make_prefix(dp, retriever):
    """Build the search-copilot prompt for one QA example.

    Args:
        dp: Example dict; must provide 'question' and
            'initial_searched_results' (pre-fetched retrieval context that is
            embedded verbatim, stripped of surrounding whitespace).
        retriever: Retriever name; "bm25" appends a note asking for
            Boolean-operator queries, any other value (e.g. "e5") adds nothing.

    Returns:
        The full prompt string, ending with the question and its initial
        searched results wrapped in <question>/<information> tags.
    """
    # The prompt literals below are the behavior of this function — keep byte-exact.
    input_str = """You are a search copilot for the generation model. Based on a user's query and initial searched results, you will first determine if the searched results are enough to produce an answer.
If the searched results are enough, you will use <search_complete>True</search_complete> to indicate that you have gathered enough information for the generation model to produce an answer.
If the searched results are not enough, you will go through a loop of <query> -> <information> -> <important_info> -> <search_complete> -> <query> (if not complete) ..., to help the generation model to generate a better answer with more relevant information searched.
You should show the search query between <query> and </query> in JSON format.
Based on the search query, we will return the top searched results between <information> and </information>. You need to put the doc ids of the important documents (up to 3 documents, within the current information window) between <important_info> and </important_info> (e.g., <important_info>[1, 4]</important_info>).
A search query MUST be followed by a <search_complete> tag if the search is not complete.
After reviewing the information, you must decide whether to continue searching with a new query or indicate that the search is complete. If you need more information, use <search_complete>False</search_complete> to indicate you want to continue searching with a better query. Otherwise, use <search_complete>True</search_complete> to terminate the search.
During the process, you can add reasoning process within <think></think> tag whenever you want. Note: Only the important information would be used for the generation model to produce an answer.
"""

    # BM25 benefits from Boolean query syntax; dense retrievers (e.g. e5) do not.
    if retriever == "bm25":
        input_str += """Note: The search query should use Boolean operators (AND, OR) and parentheses for grouping terms appropriately."""

    input_str += """
For a question and initial searched results:
<question>
[user's question]
</question>
<information>
[initial searched results]
</information>

If the initial searched results are enough to produce an answer, you should output:
<search_complete>
True
</search_complete>

If the initial searched results are not enough to produce an answer, you should output:
<query>
{
"query": "[search query]"
}
</query>
<information>
[top searched results based on the above search query]
</information>
<important_info>
[doc ids]
</important_info>
<search_complete>
False
</search_complete>
<query>
{
"query": "[search query]"
}
</query>
...... (can be several turns until <search_complete> is True)

<search_complete>
True
</search_complete>

Now, start the loop with the following question and initial searched results:
"""

    input_str += f"""
<question>
{dp['question']}
</question>
<information>
{dp['initial_searched_results'].strip()}
</information>
"""
    return input_str
if __name__ == '__main__':
    # Preprocess cube-corpus QA data into the parquet prompt format expected
    # by the training/eval pipeline (one row per question, chat-style prompt).
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_dir', default='./data/nq_search')
    parser.add_argument('--hdfs_dir', default=None)
    parser.add_argument('--data_sources', default='hotpotqa')
    parser.add_argument('--retriever', default="e5")
    parser.add_argument('--initial_searched_results_dir', default="data/cube/rag_e5_cube")
    args = parser.parse_args()

    data_sources = args.data_sources.split(',')
    all_dataset = []

    for data_source in data_sources:

        # NOTE(review): corpus path is hard-coded to a cluster-specific
        # location; consider exposing it as a CLI flag.
        with open(f"/shared/eng/pj20/cube_data/{data_source}_with_index.json", "r") as f:
            test_dataset = json.load(f)

        # Use a context manager instead of json.load(open(...)) so the file
        # handle is closed promptly.
        results_path = os.path.join(args.initial_searched_results_dir, f'{data_source}_output_sequences.json')
        with open(results_path, "r") as f:
            initial_searched_results = json.load(f)

        # Process each item in the list of dictionaries
        processed_data = []
        for idx, example in enumerate(test_dataset):
            example['question'] = example['question'].strip()
            # Keep only the top-5 docs of the pre-fetched context (truncate at "Doc 6").
            example['initial_searched_results'] = initial_searched_results[example['question']]['context_with_info'].split("\nDoc 6")[0] + "\n"
            question = make_prefix(example, args.retriever)
            # Guard the key once and reuse it: the original guarded the access
            # for `solution` but read example['supporting_facts_index'] directly
            # in `data`, which raised KeyError for examples lacking the key.
            supporting_facts_index = example.get('supporting_facts_index', [])
            solution = {
                "question": example['question'],
                "target": example['answer'],
                "gt_docs": supporting_facts_index
            }

            data = {
                "question": example['question'],
                "answer": example['answer'],
                "supporting_facts_index": supporting_facts_index,
                "initial_searched_results": example['initial_searched_results'],
                "data_source": data_source,
                "prompt": [{
                    "role": "user",
                    "content": question,
                }],
                "ability": "fact-reasoning",
                "reward_model": {
                    "style": "rule",
                    "ground_truth": solution
                },
                "extra_info": {
                    'split': 'test',
                    'index': idx,
                }
            }
            processed_data.append(data)

        all_dataset.append(processed_data)

    local_dir = args.local_dir
    hdfs_dir = args.hdfs_dir

    # Flatten the list of lists into a single list
    all_test_data = []
    for dataset in all_dataset:
        all_test_data.extend(dataset)

    # Convert to Dataset and save as parquet.
    # NOTE(review): the filename uses `data_source` left over from the loop
    # above, so with multiple --data_sources every row is written under the
    # LAST source's name — confirm this is intended before passing several.
    all_test_dataset = datasets.Dataset.from_list(all_test_data)
    all_test_dataset.to_parquet(os.path.join(local_dir, f'test_{args.retriever}_cube_{data_source}.parquet'))

    if hdfs_dir is not None:
        makedirs(hdfs_dir)

        copy(src=local_dir, dst=hdfs_dir)

0 commit comments

Comments
 (0)