-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathqa_generation.py
More file actions
191 lines (174 loc) · 8.65 KB
/
qa_generation.py
File metadata and controls
191 lines (174 loc) · 8.65 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
# qa_generation.py
import os
import json
import re
from openai import OpenAI # For DeepSeek/VolcEngine
import utils
import config
def _extract_qa_list(json_string):
    """Parse the raw model response into a list of QA dicts.

    First looks for a fenced ```json ... ``` block; if none is present,
    attempts to parse the entire response as JSON. Returns the parsed
    list, or None when parsing fails or the payload is not a JSON list.
    """
    # Non-greedy capture: if the model emits more than one fenced block,
    # grab only the first instead of spanning between the fences.
    pattern = r'```json(.*?)```'
    match = re.search(pattern, json_string, re.DOTALL | re.IGNORECASE)
    if match:
        extracted_json_str = match.group(1).strip()
        try:
            qa_pairs = json.loads(extracted_json_str)
        except json.JSONDecodeError as e:
            print(f"Error decoding JSON from API response: {e}")
            print(f"--- Raw JSON String Attempted: ---\n{extracted_json_str}\n---------------------------------")
            return None
        if not isinstance(qa_pairs, list):  # Ensure it's a list
            print(f"Error: Extracted JSON is not a list: {type(qa_pairs)}")
            return None
        return qa_pairs
    # Maybe the response is *just* the JSON without backticks
    try:
        qa_pairs = json.loads(json_string)
    except json.JSONDecodeError:
        print("Error: Could not find JSON block ```json...``` in the API response, and the response itself is not valid JSON.")
        print(f"--- Raw API Response: ---\n{json_string}\n--------------------------")
        return None
    if not isinstance(qa_pairs, list):
        print(f"Error: API response is not a JSON list: {type(qa_pairs)}")
        return None
    return qa_pairs


def _attach_metadata(qa_pairs, video_id, metadata):
    """Annotate each QA dict in-place with its video id and optional metadata.

    `metadata` is assumed to be a list of dicts keyed by
    'video_id_for_alignment' (assumption carried over from the original
    lookup — confirm against the caller). Non-dict items are skipped
    with a warning.
    """
    for qa_pair in qa_pairs:
        if not isinstance(qa_pair, dict):  # Ensure it's a dictionary
            print(f"Warning: Found non-dictionary item in QA list: {qa_pair}")
            continue
        qa_pair["video_id"] = video_id
        if metadata:
            # Find the metadata entry matching this video, if any.
            this_video_meta = next(
                (item for item in metadata if item.get('video_id_for_alignment') == video_id),
                None,
            )
            if this_video_meta:
                for key, value in this_video_meta.items():
                    if key != 'video_id_for_alignment':  # Don't add the helper key
                        qa_pair[key] = value
            # If metadata was provided but none matched, leave the pair as-is.


def generate_advanced_qa(qa_folder_path, use_gemini=True, metadata=None, file_lock=None):
    """
    Generates advanced Question-Answering pairs based on consolidated captions.

    Can use Gemini or DeepSeek/VolcEngine API. Reads the revised/consistent
    caption files from `qa_folder_path` and writes 'QAs_advance.json'
    (Gemini) or 'QAs_advance_deepseek.json' (DeepSeek/Volc).

    Args:
        qa_folder_path: Folder containing the per-video caption files; its
            basename is used as the video id.
        use_gemini: Select the Gemini backend when True, DeepSeek/Volc
            otherwise.
        metadata: Optional list of dicts keyed by 'video_id_for_alignment';
            matching entries are merged into every QA pair.
        file_lock: Optional lock guarding the output write (for
            multiprocessing safety).

    Returns:
        True on success (QA pairs parsed and written), False otherwise.
    """
    backend_name = "Gemini" if use_gemini else "DeepSeek/Volc"  # Simplified name
    print(f"\n--- Starting QA Generation for: {qa_folder_path} (Backend: {backend_name}) ---")
    video_id = os.path.basename(qa_folder_path)

    # Define input caption file paths.
    # Using revised/consistent captions as input for QA.
    video_caption_path = os.path.join(qa_folder_path, 'video_consistent_captions.txt')
    audio_caption_path = os.path.join(qa_folder_path, "audio_revised_captions.txt")
    # Assuming 'av_alignment_captions.txt' is generated by another process or needs to be added.
    alignment_caption_path = os.path.join(qa_folder_path, 'av_alignment_captions.txt')  # Placeholder

    # Output JSON path depends on the selected backend.
    output_json_file = f"QAs_advance{'_deepseek' if not use_gemini else ''}.json"
    output_json_path = os.path.join(qa_folder_path, output_json_file)

    # Check essential files exist; alignment captions are optional.
    if not os.path.exists(video_caption_path):
        print(f"Error: Required video caption file not found: {video_caption_path}. Skipping QA.")
        return False
    if not os.path.exists(audio_caption_path):
        print(f"Error: Required audio caption file not found: {audio_caption_path}. Skipping QA.")
        return False
    if not os.path.exists(alignment_caption_path):
        print(f"Warning: Alignment caption file not found: {alignment_caption_path}. Proceeding without it.")
        alignment_content = "N/A"  # Indicate missing alignment data
    else:
        # Read content safely using utils function.
        alignment_content = utils.safe_read(alignment_caption_path)
        if alignment_content is None:
            alignment_content = "Error Reading File"

    # Read main captions.
    video_content = utils.safe_read(video_caption_path)
    audio_content = utils.safe_read(audio_caption_path)
    if video_content is None or audio_content is None:
        print("Error: Failed to read required caption files. Skipping QA.")
        return False

    # Construct the prompt content.
    content = f'''
Please generate question-answer pairs for the following video description:
### Visual Caption:
{video_content}
### Audio Caption:
{audio_content}
### Audio Visual Event Happened at the same time(Audio_Event -- Visual_Event):
{alignment_content}
'''
    print("Input captions loaded. Calling QA generation API...")
    json_string = None

    # --- Call the appropriate API ---
    if use_gemini:
        response = utils.call_gemini_api(
            contents=[content],  # Text-only input
            api_key=config.GEMINI_API_KEY_QA,  # Key from original QA script
            model_name=config.GEMINI_QA_MODEL,  # QA/Thinking model
            system_prompt=config.ADVANCED_QA_SYSTEM_PROMPT
        )
        # BUG FIX: the original required the response to be a `str` AND have
        # a `.text` attribute — impossible, so every string response from
        # call_gemini_api was rejected. Accept either a plain string or a
        # response object exposing `.text`.
        if not response:
            print("Error: Failed to get response from Gemini QA API.")
            return False
        if isinstance(response, str):
            json_string = response
        elif getattr(response, 'text', None):
            json_string = response.text
        else:
            print("Warning: Gemini QA call succeeded but no text content found in response.")
            return False
    else:  # Use DeepSeek / VolcEngine (using Volc details from example)
        try:
            # Ensure you have the openai package installed: pip install openai
            client = OpenAI(
                api_key=config.DEEPSEEK_API_KEY,
                base_url=config.DEEPSEEK_BASE_URL
            )
            # Add retry logic here if needed for this API.
            completion = client.chat.completions.create(
                model=config.DEEPSEEK_QA_MODEL,
                messages=[
                    {"role": "system", "content": config.ADVANCED_QA_SYSTEM_PROMPT},
                    {"role": "user", "content": content},
                ],
                stream=False,
                temperature=1.0,
                max_tokens=8000
            )
            if completion.choices:
                json_string = completion.choices[0].message.content
            else:
                print("Error: DeepSeek/Volc API call returned no choices.")
                return False
        except Exception as e:
            print(f"Error calling DeepSeek/Volc API: {e}")
            return False

    # --- Process and Save the JSON response ---
    qa_pairs = _extract_qa_list(json_string) if json_string else None
    if not qa_pairs:
        print(f"Warning: No valid QA pairs generated or parsed for {qa_folder_path}.")
        return False

    print(f"Successfully parsed {len(qa_pairs)} QA pairs.")
    # Add video_id and metadata to each pair.
    _attach_metadata(qa_pairs, video_id, metadata)

    # Write to file, potentially using lock for multiprocessing safety.
    write_content = json.dumps(qa_pairs, ensure_ascii=False, indent=4)
    if file_lock:
        with file_lock:
            write_successful = utils.safe_write(output_json_path, write_content, mode='w')
    else:
        write_successful = utils.safe_write(output_json_path, write_content, mode='w')

    if write_successful:
        print(f"--- QA Generation Complete for: {qa_folder_path}. Saved to {os.path.basename(output_json_path)}. ---")
        return True
    print(f"Error: Failed to write QA JSON to {output_json_path}.")
    return False