-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathpython-docstring
More file actions
executable file
·504 lines (390 loc) · 19.3 KB
/
python-docstring
File metadata and controls
executable file
·504 lines (390 loc) · 19.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
#!/usr/bin/python
import ast
import os
import requests
import sys
import json
import tempfile
import subprocess
import hashlib
import pprint
import json
class OpenAIProvider:
    """Minimal client for the OpenAI Chat Completions endpoint.

    Configuration comes from the environment: ``AI_API_KEY`` (required) and
    ``OPENAI_MODEL`` (optional, defaults to ``"gpt-4o-mini"``).
    """

    def __init__(self):
        """Read the API key and model name from the environment.

        Raises:
            EnvironmentError: If ``AI_API_KEY`` is unset or empty.
        """
        self.api_key = os.getenv("AI_API_KEY")
        if not self.api_key:
            raise EnvironmentError("AI_API_KEY not set in environment.")
        self.endpoint = "https://api.openai.com/v1/chat/completions"
        self.model = os.getenv("OPENAI_MODEL", "gpt-4o-mini")

    def improve_text(self, prompt: str, text: str) -> str:
        """Send *prompt* as the system message and *text* as the user message.

        Args:
            prompt (str): Instructions for the model (system role).
            text (str): Content to operate on (user role).

        Returns:
            str: The first choice's message content, stripped; ``""`` when the
            response carries no content.

        Raises:
            RuntimeError: If the API answers with a status other than 200. This is a
                subclass of ``Exception``, so existing ``except Exception`` callers
                still catch it.
            requests.RequestException: Network errors/timeouts from ``requests.post``
                propagate unchanged.
        """
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }
        body = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": prompt},
                {"role": "user", "content": text},
            ],
            "temperature": 0.3,
        }
        response = requests.post(self.endpoint, json=body, headers=headers, timeout=60)
        if response.status_code != 200:
            raise RuntimeError(
                f"OpenAI API call failed: {response.status_code} - {response.text}"
            )
        content = response.json()["choices"][0]["message"]["content"]
        # ``content`` is nullable in the Chat Completions schema (e.g. refusals); treat it
        # as empty so the caller's fallback docstring is used instead of crashing on .strip().
        return (content or "").strip()
def normalize_docstring(docstring):
    """Turn a raw model reply into bare docstring text.

    Handles the usual ways a chat model wraps its answer:

    * a surrounding Markdown code fence (an opening line starting with three
      backticks, optionally followed by a language tag, plus a closing fence line)
      is dropped;
    * if the text contains a triple-quoted block (double or single quotes), only the
      part between the first and last triple quote is kept;
    * the result is cleaned with ``inspect.cleandoc``, which removes leading and
      trailing blank lines and the common indentation of continuation lines while
      keeping relative indentation (e.g. nested ``Args:`` entries).

    Args:
        docstring (str | None): Raw text returned by the model.

    Returns:
        str: The cleaned docstring body without quotes, or ``""`` for empty input.
    """
    import inspect

    if not docstring:
        return ""
    text = docstring.strip()

    # Drop a Markdown fence: the opening line (which may carry a language tag) and a
    # matching closing line.
    lines = text.splitlines()
    if lines and lines[0].startswith("```"):
        lines = lines[1:]
        if lines and lines[-1].strip() == "```":
            lines = lines[:-1]
        text = "\n".join(lines)

    # Keep only the content between the first and last triple quotes, if present.
    for quote in ('"""', "'''"):
        start = text.find(quote)
        end = text.rfind(quote)
        if start != -1 and end > start:
            text = text[start + len(quote):end]
            break

    # cleandoc removes the model's indentation so the caller can re-indent uniformly.
    return inspect.cleandoc(text).strip()
def get_functions_and_classes_without_docstrings(tree):
    """List the undocumented functions and classes this tool should document.

    Only definitions sitting directly in the module body or directly in a class body
    are considered; classes nested in classes are included. Functions defined inside
    other functions are skipped. ``async def`` functions are not matched either,
    because only ``ast.FunctionDef`` is checked. Results are breadth-first: all
    module-level items in source order, then the members of each class.

    Args:
        tree (ast.Module): Parsed module to inspect.

    Returns:
        list[tuple[str, ast.FunctionDef | ast.ClassDef]]: ``('function', node)`` or
        ``('class', node)`` pairs whose ``ast.get_docstring`` is ``None``.
    """
    missing = []
    # Walking only module/class bodies replaces the old ``col_offset == 8`` guess. That
    # guess skipped methods of nested classes yet still picked up functions nested in
    # top-level functions (col_offset 4).
    scopes = [tree]
    for scope in scopes:  # appending while iterating gives a breadth-first traversal
        for node in scope.body:
            if isinstance(node, ast.FunctionDef):
                if ast.get_docstring(node) is None:
                    missing.append(("function", node))
            elif isinstance(node, ast.ClassDef):
                if ast.get_docstring(node) is None:
                    missing.append(("class", node))
                scopes.append(node)
    return missing
def generate_docstring(object_code, openai_provider, cache, object_name, object_type, filename):
    """Return an indented docstring body for one function or class, using the cache when possible.

    The cache entry key is the SHA-256 of ``"<filename>-<object_name>"``; each entry stores
    the generated docstring and the SHA-256 of *object_code* it was generated from. A
    cached docstring is reused only when that hash matches. Otherwise the model is
    queried, its reply normalised via ``normalize_docstring``, and the entry overwritten.

    Args:
        object_code (str): Source of the object (as produced by ``ast.unparse``).
        openai_provider: Object exposing ``improve_text(prompt, text) -> str``.
        cache (dict): Docstring cache; updated in place.
        object_name (str): Name of the function or class.
        object_type (str): ``'function'`` or ``'class'``; interpolated into the prompt.
        filename (str): Path of the file containing the object (part of the cache key).

    Returns:
        str: Docstring text without quotes, each line prefixed with 4 spaces when the
        object is a class and 8 when it is a function. Falls back to a TODO line if
        the model returns nothing usable.

    Note:
        Objects with the same name in one file (e.g. several ``__init__`` methods)
        share a cache entry and overwrite each other's cached docstring.
    """
    code_digest = hashlib.sha256(object_code.encode()).hexdigest()
    entry = hashlib.sha256((filename + "-" + object_name).encode()).hexdigest()

    cached = cache.get(entry)
    if cached is not None:
        if cached.get('last_hash') == code_digest and 'docstring' in cached:
            print(f"Using cached docstring for {object_type}: {object_name}")
            return cached['docstring']
        print(f"Hash for {object_type}: {object_name} changed. Getting new docstring.")

    prompt = f"""
You will be provided with a Python {object_type} (function or class),
including its code. Based on this, your task is to generate a Python docstring for it.
Use the function signature and body to infer the purpose of the function/class,
the arguments it takes, the return value, and any exceptions it may raise.
Follow these guidelines to create the docstring:
1. Summary: Provide a concise summary of the function/class's purpose.
Focus on what the function does and its main goal.
2. Arguments: List the parameters, their types, and a brief description
of what each one represents.
3. Return: If the function has a return value, describe the return type
and what it represents. If there's no return, OMIT THE SECTION.
4. Exceptions: If the function raises any exceptions, list them with descriptions.
If no exceptions are raised, OMIT THE SECTION.
5. Side Effects (if applicable): If the function has side effects
(e.g., modifies global state, interacts with external services), mention them.
OMIT THE SECTION if it is not clear in the code.
6. Algorithm or Key Logic (optional): If the function is complex,
provide a high-level outline of the logic or algorithm involved.
OMIT THE SECTION if it is not clear in the code.
Ensure the docstring follows Python's standard docstring conventions and provides
just enough detail to make the function understandable and usable without overwhelming the reader.
"""
    docstring = openai_provider.improve_text(prompt, object_code)
    normalized_docstring = normalize_docstring(docstring)
    if not normalized_docstring:
        print(f"Warning: Generated docstring for {object_type} '{object_name}' is empty. Using a fallback docstring.")
        normalized_docstring = f"TODO: Add a docstring for the {object_type} '{object_name}'."

    # object_code is unparsed, so the matching node sits at col_offset 0: the prefix is 4
    # for a class and 8 for a function. NOTE(review): the extra level for functions
    # presumably targets methods (method indent + body indent); the inserter re-indents.
    indentation_level = 0
    for node in ast.walk(ast.parse(object_code)):
        if isinstance(node, (ast.ClassDef, ast.FunctionDef)) and node.name == object_name:
            extra = 4 if isinstance(node, ast.ClassDef) else 8
            indentation_level = node.col_offset + extra
            break

    indent = " " * indentation_level
    indented_docstring = "\n".join(f"{indent}{line}" for line in normalized_docstring.splitlines())

    cache[entry] = {'docstring': indented_docstring, 'last_hash': code_digest}
    return indented_docstring
def validate_syntax(code):
    """Check that *code* is syntactically valid Python for the running interpreter.

    Compiles the source in-process with ``compile(..., "exec")``; nothing is executed.
    The previous version ran ``python -m py_compile`` on a temp file. That relied on a
    ``python`` executable being on PATH (possibly a different version), leaked the temp
    file when the subprocess failed to start, and left ``__pycache__`` in the temp dir.

    Args:
        code (str): Python source to check.

    Returns:
        bool: ``True`` if the source compiles, ``False`` otherwise (the reason is
        printed).
    """
    try:
        # dont_inherit: don't let this module's __future__ flags affect the check.
        compile(code, "<docstring-update>", "exec", dont_inherit=True)
    except SyntaxError as err:
        print(f"Syntax error: {err}")
        return False
    except (ValueError, RecursionError, MemoryError) as err:
        # e.g. null bytes in the source, or nesting too deep for the compiler.
        print(f"Error validating syntax: {err}")
        return False
    return True
def process_directory(directory, openai_provider, cache, cache_file):
    """Recursively hand every ``.py`` file below *directory* to ``process_file``.

    Args:
        directory (str): Root directory to walk (via ``os.walk``).
        openai_provider: Passed through to ``process_file``.
        cache (dict): Shared docstring cache, passed through.
        cache_file (str): Cache path, passed through.
    """
    for current_dir, _subdirs, names in os.walk(directory):
        python_sources = (name for name in names if name.endswith(".py"))
        for name in python_sources:
            process_file(os.path.join(current_dir, name), openai_provider, cache, cache_file)
def load_cache(cache_file):
    """Load the docstring cache from *cache_file*.

    A missing file yields an empty cache. An unreadable cache also yields an empty
    cache (with a warning) instead of aborting the run. This covers empty,
    truncated or invalid JSON, undecodable bytes, or a top-level value that is not
    an object.

    Args:
        cache_file (str): Path of the JSON cache file.

    Returns:
        dict: The cached entries, or ``{}``.
    """
    if not os.path.exists(cache_file):
        return {}
    try:
        with open(cache_file, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except ValueError as err:  # JSONDecodeError and UnicodeDecodeError are ValueErrors
        print(f"Warning: could not parse cache file {cache_file} ({err}); starting with an empty cache.")
        return {}
    if not isinstance(data, dict):
        print(f"Warning: cache file {cache_file} does not contain a JSON object; starting with an empty cache.")
        return {}
    return data
def save_cache(cache, cache_file):
    """Persist *cache* to *cache_file* as indented JSON, atomically.

    The JSON goes to a temporary file in the same directory, which then replaces the
    target via ``os.replace``. An interrupted save therefore leaves the previous
    cache intact instead of a truncated file, which the old truncate-then-write
    ``open(..., 'w')`` could produce.

    Args:
        cache (dict): JSON-serialisable cache contents.
        cache_file (str): Destination path.

    Note:
        The replaced file takes the temporary file's permissions, which
        ``tempfile`` creates readable/writable only by the current user.
    """
    target_dir = os.path.dirname(os.path.abspath(cache_file))
    tmp_name = None
    try:
        with tempfile.NamedTemporaryFile(
            'w', encoding='utf-8', dir=target_dir, suffix='.tmp', delete=False
        ) as tmp:
            tmp_name = tmp.name
            json.dump(cache, tmp, indent=4)
        os.replace(tmp_name, cache_file)
    except BaseException:
        # Never leave half-written temp files behind.
        if tmp_name is not None and os.path.exists(tmp_name):
            os.remove(tmp_name)
        raise
def get_next(tree):
    """Return the first undocumented function/class of *tree*, or ``None`` if none remain.

    Args:
        tree (ast.Module): Parsed module.

    Returns:
        tuple[str, ast.AST] | None: The first ``(kind, node)`` pair reported by
        ``get_functions_and_classes_without_docstrings``.
    """
    pending = get_functions_and_classes_without_docstrings(tree)
    return pending[0] if pending else None
def get_object_node(tree, object_name):
    """Find the first function or class definition named *object_name*.

    The search follows ``ast.walk`` order (breadth-first), so a module-level
    definition wins over a same-named method. ``async def`` nodes are not matched.

    Args:
        tree (ast.AST): Tree to search, typically an ``ast.Module``.
        object_name (str): Name of the function or class.

    Returns:
        ast.FunctionDef | ast.ClassDef | None: The matching node, or ``None``.
    """
    candidates = (
        candidate
        for candidate in ast.walk(tree)
        if isinstance(candidate, (ast.FunctionDef, ast.ClassDef)) and candidate.name == object_name
    )
    return next(candidates, None)
def process_file(filename, openai_provider, cache, cache_file):
    """Add generated docstrings to the undocumented functions, classes and methods of one file.

    The file is read once. Objects are then documented one at a time, and the updated
    source is re-parsed after every insertion so AST line numbers stay valid. The cache
    is saved after each object. When the loop ends, the file is rewritten only if its
    content changed and the result still compiles.

    Args:
        filename (str): Path of the Python file to update in place (read/written as UTF-8).
        openai_provider: Passed to ``generate_docstring`` (must expose ``improve_text``).
        cache (dict): Docstring cache, updated in place.
        cache_file (str): Where the cache is persisted after each object.
    """
    print(f"Processing file: {filename}")
    try:
        with open(filename, 'r', encoding='utf-8') as file:
            original_code = file.read()
    except UnicodeDecodeError as err:
        print(f"Cannot decode {filename} as UTF-8 ({err}); skipping file.")
        return

    code = original_code
    previous_code = code
    previous_remaining = None

    while True:
        try:
            tree = ast.parse(code)
        except (SyntaxError, ValueError) as err:
            # Either the file was invalid to begin with or an insertion broke it;
            # in both cases nothing is written back.
            print(f"Cannot parse {filename} ({err}); skipping file.")
            return

        pending = get_functions_and_classes_without_docstrings(tree)
        if not pending:
            print(f"All functions and classes have docstrings in {filename}.")
            break

        # Each successful insertion must shrink the undocumented set; otherwise the same
        # object would be picked again forever. Undo the ineffective edit and stop.
        if previous_remaining is not None and len(pending) >= previous_remaining:
            print(f"No progress adding docstrings in {filename}; stopping.")
            code = previous_code
            break
        previous_remaining = len(pending)

        object_type, object_node = pending[0]
        object_code = ast.unparse(object_node)
        object_name = object_node.name
        print(f"Generating docstring for {object_type}: {object_name} in {filename}")

        docstring = generate_docstring(object_code, openai_provider, cache, object_name, object_type, filename)
        if not docstring:
            print("Something wrong with the docstring generation.")
            break

        previous_code = code
        code = insert_docstring_to_object(code, object_node, docstring, object_type)
        # generate_docstring already cached the SHA-256 of the undocumented source, which
        # is exactly what later lookups compare against. Overwriting it with a SHA-1 of
        # the documented source (as before) made every later lookup miss.
        save_cache(cache, cache_file)
        print(f"Docstrings added successfully to {object_type}: {object_name} in {filename}.")

    if code == original_code:
        return
    if validate_syntax(code):
        with open(filename, 'w', encoding='utf-8') as file:
            file.write(code)
    else:
        print(f"Syntax error in {filename}, skipping save.")
def fix_indentation(code):
    """Shift a spliced function/class snippet one indentation level (4 spaces) right.

    Intended for code produced by ``ast.unparse`` (definition at column 0) with an
    inserted docstring. The definition line (``def``, ``async def`` or ``class``) and
    any decorator lines directly above it are set to exactly 4 spaces. Every later line
    gets 4 extra spaces, except inside a triple-quoted block: the opening line and its
    content lines are left untouched (the docstring text arrives pre-indented), while
    the closing line is shifted. A line that opens and closes a triple-quoted string on
    its own is treated as ordinary code.

    Args:
        code (str): Snippet to re-indent.

    Returns:
        str: The re-indented snippet joined with ``"\\n"``, or *code* unchanged when it
        contains no definition line (previously this raised ``StopIteration``).
    """
    code_lines = code.splitlines()
    header_keywords = ("def ", "async def ", "class ")
    def_index = next(
        (i for i, line in enumerate(code_lines) if line.strip().startswith(header_keywords)),
        None,
    )
    if def_index is None:
        return code

    # Decorators directly above the definition must move together with it.
    first_header = def_index
    while first_header > 0 and code_lines[first_header - 1].strip().startswith("@"):
        first_header -= 1
    for i in range(first_header, def_index + 1):
        code_lines[i] = '    ' + code_lines[i].lstrip()

    in_docstring = False
    for i in range(def_index + 1, len(code_lines)):
        line = code_lines[i]
        stripped = line.strip()
        # Only a lone opening/closing delimiter toggles; '"""Summary."""' does not.
        if stripped.startswith('"""') and stripped.count('"""') == 1:
            in_docstring = not in_docstring
        if not in_docstring:
            code_lines[i] = '    ' + line

    return "\n".join(code_lines)
def remove_duplicate_functions(code):
    """Remove repeated, identical function definitions that share one scope.

    The module body and class bodies (recursively through nested classes) are
    examined. A function counts as a duplicate only when an earlier sibling in the
    *same* body unparses to exactly the same source. Identical methods that live in
    different classes are legitimate and kept; the previous file-wide comparison
    deleted them. Code that does not parse is returned unchanged.

    Args:
        code (str): Python source.

    Returns:
        str: Source with the later copies removed via ``remove_function_from_code``.
    """
    try:
        tree = ast.parse(code)
    except (SyntaxError, ValueError):
        return code

    duplicates = []
    scopes = [tree]
    for scope in scopes:  # appending while iterating gives a breadth-first traversal
        seen = set()
        for node in scope.body:
            if isinstance(node, ast.ClassDef):
                scopes.append(node)
            elif isinstance(node, ast.FunctionDef):
                source = ast.unparse(node)
                if source in seen:
                    duplicates.append(node)
                else:
                    seen.add(source)

    # Remove bottom-up so each removal leaves the line numbers of the remaining
    # (earlier) nodes valid.
    for node in sorted(duplicates, key=lambda n: n.lineno, reverse=True):
        code = remove_function_from_code(code, node)
    return code
def remove_function_from_code(code, function_node):
    """Delete the source lines of *function_node*, decorators included, from *code*.

    Uses the node's own positions (``lineno``, ``end_lineno`` and the decorators'
    ``lineno``). Comments, blank lines or reformatting inside the function therefore
    no longer throw the range off, and the line right after the function is kept;
    the old ``unparse``-length computation plus ``end_line + 1`` dropped it.

    Args:
        code (str): Source that *function_node* was parsed from.
        function_node (ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef): Node to remove.

    Returns:
        str: Remaining lines joined with ``"\\n"`` (no trailing newline added).
    """
    code_lines = code.splitlines()
    decorator_lines = [d.lineno for d in getattr(function_node, "decorator_list", [])]
    start = min([function_node.lineno] + decorator_lines) - 1  # 0-based first line
    end = function_node.end_lineno  # 1-based inclusive == 0-based exclusive
    return "\n".join(code_lines[:start] + code_lines[end:])
def extract_main_block(code):
    """Return everything from the ``if __name__ == '__main__':`` guard to the end of *code*.

    The guard is recognised with either quote style and flexible spacing around
    ``==`` and the colon, on any (stripped) line. The previous exact-string match
    missed the single-quoted form, which this very script uses.

    Args:
        code (str): Python source.

    Returns:
        str: The guard line and all following lines joined with ``"\\n"``, or ``""``
        if no guard is found.
    """
    import re

    guard = re.compile(r"""if\s+__name__\s*==\s*(['"])__main__\1\s*:""")
    lines = code.splitlines()
    for index, line in enumerate(lines):
        if guard.match(line.strip()):
            return "\n".join(lines[index:])
    return ""
def _find_header_colon(lines, start_idx, end_idx):
    """Locate the ``:`` that terminates a ``def``/``class`` header.

    Tokenizes ``lines[start_idx:end_idx]`` and returns the first ``:`` operator that is
    not nested inside brackets. Colons in parameter annotations, dict literals or
    lambda defaults are therefore skipped.

    Args:
        lines (list[str]): Source lines of the whole file.
        start_idx (int): 0-based index of the ``def``/``class`` line.
        end_idx (int): Exclusive 0-based end of the scan; the header colon lies on or
            before the line of the first body statement.

    Returns:
        tuple[int, int] | None: ``(line_index, column)`` of the colon (column as
        reported by ``tokenize``), or ``None`` if it was not found.
    """
    import io
    import tokenize

    snippet = "\n".join(lines[start_idx:end_idx]) + "\n"
    depth = 0
    try:
        for token in tokenize.generate_tokens(io.StringIO(snippet).readline):
            if token.type != tokenize.OP:
                continue
            if token.string in ("(", "[", "{"):
                depth += 1
            elif token.string in (")", "]", "}"):
                depth -= 1
            elif token.string == ":" and depth == 0:
                return start_idx + token.start[0] - 1, token.start[1]
    except (tokenize.TokenError, SyntaxError):
        pass
    return None


def insert_docstring_to_object(code, object_node, docstring, object_type):
    """Insert *docstring* as the docstring of *object_node*, editing *code* in place.

    The docstring block goes right after the line holding the header's ``:``, indented
    like the object's body. The rest of the file keeps its comments, blank lines,
    decorators and formatting. The old approach spliced in ``ast.unparse`` output;
    it cut annotated signatures at the first ``:``, sized the replaced span by
    unparse length, and duplicated decorators.

    *object_node* must come from ``ast.parse(code)`` of this exact *code* so its line
    numbers match.

    Args:
        code (str): Full source of the file.
        object_node (ast.FunctionDef | ast.ClassDef): Node to document.
        docstring (str): Docstring text without quotes. Uniform leading indentation
            (such as the prefix added by ``generate_docstring``) is removed and
            replaced by the body's indentation.
        object_type (str): ``'function'`` or ``'class'``; only used in messages.

    Returns:
        str: Updated source, always ending with a newline. If the header cannot be
        located, the source is returned unchanged (plus a trailing newline).

    Note:
        An existing docstring is replaced when it occupies its own line(s). A docstring
        on the header line itself (``def f(): "doc"``) stays behind the new one.
    """
    import textwrap

    lines = code.splitlines()
    first_stmt = object_node.body[0]
    colon = _find_header_colon(lines, object_node.lineno - 1, first_stmt.lineno)
    if colon is None:
        print(f"Could not locate the header of {object_type} '{object_node.name}'; leaving it unchanged.")
        return code if code.endswith("\n") else code + "\n"

    colon_line, colon_col = colon
    after_colon = lines[colon_line][colon_col + 1:].strip()
    inline_body = bool(after_colon) and not after_colon.startswith("#")

    if inline_body:
        # ``def f(): return x`` -- move the body onto its own line below the docstring.
        def_line = lines[object_node.lineno - 1]
        body_indent = def_line[: len(def_line) - len(def_line.lstrip())] + "    "
        head = lines[:colon_line] + [lines[colon_line][: colon_col + 1].rstrip()]
        tail = [body_indent + after_colon] + lines[colon_line + 1:]
    else:
        # Body indentation is taken from the first body line (its decorator, if any).
        stmt_start = min([first_stmt.lineno] + [d.lineno for d in getattr(first_stmt, "decorator_list", [])])
        stmt_line = lines[stmt_start - 1]
        body_indent = stmt_line[: len(stmt_line) - len(stmt_line.lstrip())]
        head = lines[: colon_line + 1]
        tail = lines[colon_line + 1:]
        if ast.get_docstring(object_node) is not None:
            # Drop the old docstring, unless another statement shares its last line.
            # end_col_offset is a UTF-8 byte offset, hence the encode/decode.
            end_text = lines[first_stmt.end_lineno - 1]
            rest = end_text.encode("utf-8")[first_stmt.end_col_offset:].decode("utf-8").strip()
            if not rest or rest.startswith("#"):
                start = first_stmt.lineno - 1 - (colon_line + 1)
                stop = first_stmt.end_lineno - (colon_line + 1)
                tail = tail[:start] + tail[stop:]

    text = textwrap.dedent(docstring or "").strip("\n")
    # Backslashes would otherwise be read as escape sequences; a raw docstring keeps them.
    quote_prefix = "r" if "\\" in text else ""
    # An embedded triple quote would terminate the docstring early.
    text = text.replace('"""', '\\"\\"\\"')

    block = [f'{body_indent}{quote_prefix}"""']
    block.extend(f"{body_indent}{line.rstrip()}" if line.strip() else "" for line in text.splitlines())
    block.append(f'{body_indent}"""')

    return "\n".join(head + block + tail) + "\n"
if __name__ == '__main__':
    # Expect exactly one argument: a Python file, or a directory to scan recursively.
    if len(sys.argv) != 2:
        print("Usage: python add_docstrings.py <python_file_or_directory>")
        sys.exit(1)
    path = sys.argv[1]

    # Validate the path before requiring AI_API_KEY, so a mistyped path fails with a
    # clear message instead of an environment error.
    if not (os.path.isdir(path) or os.path.isfile(path)):
        print(f"The provided path '{path}' is not valid.")
        sys.exit(1)

    openai_provider = OpenAIProvider()

    # The cache lives in the current working directory and is reloaded on every run.
    cache_file = "docstring_cache.json"
    cache = load_cache(cache_file)

    if os.path.isdir(path):  # a directory is processed recursively
        process_directory(path, openai_provider, cache, cache_file)
    else:  # otherwise it is a single file
        process_file(path, openai_provider, cache, cache_file)

    # Persist the final cache state.
    save_cache(cache, cache_file)