Mirror of https://github.com/Aider-AI/aider.git, synced 2025-05-24 06:15:00 +00:00

black

Commit d61ab51a74 (parent eb79dd2760)
1 changed file with 84 additions and 77 deletions: coder.py (161 lines changed)
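Nearly every hunk below is a mechanical rewrite from running the black formatter: single-quoted strings become double-quoted, operators get surrounding spaces, long calls are split one argument per line, and blank lines around defs are normalized. As a rough illustration (not part of the commit), the same normalization can be reproduced with black's Python API; the snippet below assumes the black package is installed and that its format_str/Mode entry points behave as in recent versions:

# Illustrative sketch only: reproduce the kind of rewrite black applies in
# this commit. Assumes `pip install black`; format_str/Mode are black's
# public API, though exact behavior can vary between versions.
import black

old_src = "history_file = '.coder.history'\n"
new_src = black.format_str(old_src, mode=black.Mode())
print(new_src)  # history_file = ".coder.history"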
@@ -1,7 +1,5 @@
 #!/usr/bin/env python
 
-# This is a Python script that uses OpenAI's GPT-3 to modify code based on user requests.
-
 import os
 import sys
 import copy
@@ -25,7 +23,7 @@ from dump import dump
 
 import prompts
 
-history_file = '.coder.history'
+history_file = ".coder.history"
 try:
     readline.read_history_file(history_file)
 except FileNotFoundError:
@@ -38,25 +36,26 @@ openai.api_key = os.getenv("OPENAI_API_KEY")
 
 def find_index(list1, list2):
     for i in range(len(list1)):
-        if list1[i:i+len(list2)] == list2:
+        if list1[i : i + len(list2)] == list2:
             return i
     return -1
 
+
 class Coder:
     fnames = dict()
     last_modified = 0
 
     def __init__(self, use_gpt_4):
         if use_gpt_4:
-            self.main_model = 'gpt-4'
+            self.main_model = "gpt-4"
         else:
-            self.main_model = 'gpt-3.5-turbo'
+            self.main_model = "gpt-3.5-turbo"
 
     def add_file(self, fname):
         self.fnames[fname] = Path(fname).stat().st_mtime
 
     def files_modified(self):
-        for fname,mtime in self.fnames.items():
+        for fname, mtime in self.fnames.items():
             if Path(fname).stat().st_mtime != mtime:
                 return True
 
@@ -64,28 +63,27 @@ class Coder:
         self.request_prompt = prompt
 
     def quoted_file(self, fname):
-        prompt = '\n'
+        prompt = "\n"
         prompt += fname
-        prompt += '\n```\n'
+        prompt += "\n```\n"
         prompt += Path(fname).read_text()
-        prompt += '\n```\n'
+        prompt += "\n```\n"
         return prompt
 
     def get_files_content(self):
-        prompt = ''
+        prompt = ""
         for fname in self.fnames:
             prompt += self.quoted_file(fname)
         return prompt
 
     def get_input(self):
-
         print()
-        print('='*60)
-        inp = ''
+        print("=" * 60)
+        inp = ""
         num_control_c = 0
         while not inp.strip():
             try:
-                inp = input('> ')
+                inp = input("> ")
             except EOFError:
                 return
             except KeyboardInterrupt:
@@ -93,22 +91,21 @@ class Coder:
                 print()
                 if num_control_c >= 2:
                     return
-                print('^C again to quit')
+                print("^C again to quit")
 
         print()
 
         readline.write_history_file(history_file)
         return inp
 
-    def set_files_messages(self, did_edits = False):
-
+    def set_files_messages(self, did_edits=False):
         last_modified = max(Path(fname).stat().st_mtime for fname in self.fnames)
         if last_modified <= self.last_modified:
             return
-        did_edits = (self.last_modified > 0)
+        did_edits = self.last_modified > 0
 
         self.last_modified = last_modified
-        print('Reloading files...')
+        print("Reloading files...")
 
         if did_edits:
             files_content = prompts.files_content_prefix_edited
@@ -119,8 +116,8 @@ class Coder:
         files_content += prompts.files_content_suffix
 
         self.files_messages = [
-            dict(role = 'user', content = files_content),
-            dict(role = 'assistant', content = "Ok."),
+            dict(role="user", content=files_content),
+            dict(role="assistant", content="Ok."),
         ]
 
         return True
@@ -141,26 +138,26 @@ class Coder:
         self.cur_messages = []
 
         self.cur_messages += [
-            dict(role = 'user', content = inp),
+            dict(role="user", content=inp),
         ]
 
-        #self.show_messages(self.done_messages, "done")
-        #self.show_messages(self.files_messages, "files")
-        #self.show_messages(self.cur_messages, "cur")
+        # self.show_messages(self.done_messages, "done")
+        # self.show_messages(self.files_messages, "files")
+        # self.show_messages(self.cur_messages, "cur")
 
         messages = [
-            dict(role = 'system', content = prompts.main_system),
+            dict(role="system", content=prompts.main_system),
         ]
         messages += self.done_messages
         messages += self.files_messages
         messages += self.cur_messages
 
-        self.show_messages(messages, 'all')
+        self.show_messages(messages, "all")
 
         content = self.send(messages)
 
         self.cur_messages += [
-            dict(role = 'assistant', content = content),
+            dict(role="assistant", content=content),
         ]
 
         print()
@@ -181,18 +178,18 @@ class Coder:
         self.cur_messages = []
 
     def show_messages(self, messages, title):
-        print(title.upper(), '*' * 50)
+        print(title.upper(), "*" * 50)
 
         for msg in messages:
             print()
-            print('-' * 50)
-            role = msg['role'].upper()
-            content = msg['content'].splitlines()
+            print("-" * 50)
+            role = msg["role"].upper()
+            content = msg["content"].splitlines()
             for line in content:
                 print(role, line)
 
-    def send(self, messages, model=None, show_progress = 0):
-        #self.show_messages(messages, "all")
+    def send(self, messages, model=None, show_progress=0):
+        # self.show_messages(messages, "all")
 
         if not model:
             model = self.main_model
@@ -201,7 +198,7 @@ class Coder:
             model=model,
             messages=messages,
             temperature=0,
-            stream = True,
+            stream=True,
         )
 
         if show_progress:
@@ -211,7 +208,7 @@ class Coder:
 
     def show_send_progress(self, completion, show_progress):
         resp = []
-        pbar = tqdm(total = show_progress)
+        pbar = tqdm(total=show_progress)
         for chunk in completion:
             try:
                 text = chunk.choices[0].delta.content
@@ -224,18 +221,18 @@ class Coder:
             pbar.update(show_progress)
         pbar.close()
 
-        resp = ''.join(resp)
+        resp = "".join(resp)
         return resp
 
     def show_send_output_plain(self, completion):
-        resp = ''
+        resp = ""
 
         in_diff = False
         diff_lines = []
 
-        partial_line = ''
+        partial_line = ""
         for chunk in completion:
-            if chunk.choices[0].finish_reason not in (None, 'stop'):
+            if chunk.choices[0].finish_reason not in (None, "stop"):
                 dump(chunk.choices[0].finish_reason)
             try:
                 text = chunk.choices[0].delta.content
@@ -247,7 +244,7 @@ class Coder:
             sys.stdout.flush()
 
             # disabled
-            if False and '```' in resp:
+            if False and "```" in resp:
                 return resp
 
         return resp
@@ -261,12 +258,12 @@ class Coder:
         def print_lines():
             if not diff_lines:
                 return
-            code = '\n'.join(diff_lines)
+            code = "\n".join(diff_lines)
             lexer = lexers.guess_lexer(code)
             code = highlight(code, lexer, formatter)
-            print(code, end='')
+            print(code, end="")
 
-        partial_line = ''
+        partial_line = ""
         for chunk in completion:
             try:
                 text = chunk.choices[0].delta.content
@@ -274,18 +271,18 @@ class Coder:
             except AttributeError:
                 continue
 
-            lines = (partial_line + text)
-            lines = lines.split('\n')
+            lines = partial_line + text
+            lines = lines.split("\n")
             partial_line = lines.pop()
 
             for line in lines:
                 check = line.rstrip()
-                if check == '>>>>>>> UPDATED':
+                if check == ">>>>>>> UPDATED":
                     print_lines()
                     in_diff = False
                     diff_lines = []
 
-                if check == '=======':
+                if check == "=======":
                     print_lines()
                     diff_lines = []
                     print(line)
@@ -294,7 +291,7 @@ class Coder:
                 else:
                     print(line)
 
-                if line.strip() == '<<<<<<< ORIGINAL':
+                if line.strip() == "<<<<<<< ORIGINAL":
                     in_diff = True
                     diff_lines = []
 
@@ -302,13 +299,14 @@ class Coder:
         if partial_line:
             print(partial_line)
 
-        return ''.join(resp)
+        return "".join(resp)
 
-
-    pattern = re.compile(r'(\S+)\s+(```)?<<<<<<< ORIGINAL\n(.*?\n?)=======\n(.*?\n?)>>>>>>> UPDATED', re.MULTILINE | re.DOTALL)
+    pattern = re.compile(
+        r"(\S+)\s+(```)?<<<<<<< ORIGINAL\n(.*?\n?)=======\n(.*?\n?)>>>>>>> UPDATED",
+        re.MULTILINE | re.DOTALL,
+    )
 
     def update_files(self, content, inp):
-
         edited = set()
         for match in self.pattern.finditer(content):
             path, _, original, updated = match.groups()
@@ -329,7 +327,7 @@ class Coder:
 
         # does it want to make a new file?
         if not fname.exists() and not before_text:
-            print('Creating empty file:', fname)
+            print("Creating empty file:", fname)
             fname.touch()
 
         content = fname.read_text().splitlines()
@@ -347,31 +345,33 @@ class Coder:
 
         new_content = content[:where]
         new_content += after_text.splitlines()
-        new_content += content[where+len(before_lines):]
-        new_content = '\n'.join(new_content) + '\n'
+        new_content += content[where + len(before_lines) :]
+        new_content = "\n".join(new_content) + "\n"
 
         fname.write_text(new_content)
-        print('Applied edit to', fname)
+        print("Applied edit to", fname)
         return True
 
     def do_gpt_powered_replace(self, fname, edit, request):
-        model = 'gpt-3.5-turbo'
-        print(f'Asking {model} to apply ambiguous edit to {fname}...')
+        model = "gpt-3.5-turbo"
+        print(f"Asking {model} to apply ambiguous edit to {fname}...")
 
         fname = Path(fname)
         content = fname.read_text()
         prompt = prompts.editor_user.format(
-            request = request,
-            edit = edit,
-            fname = fname,
-            content = content,
-            )
+            request=request,
+            edit=edit,
+            fname=fname,
+            content=content,
+        )
 
         messages = [
-            dict(role = 'system', content = prompts.editor_system),
-            dict(role = 'user', content = prompt),
+            dict(role="system", content=prompts.editor_system),
+            dict(role="user", content=prompt),
         ]
-        res = self.send(messages, show_progress = len(content) + len(edit)/2, model=model)
+        res = self.send(
+            messages, show_progress=len(content) + len(edit) / 2, model=model
+        )
         res = self.strip_quoted_wrapping(res, fname)
         fname.write_text(res)
 
@@ -384,21 +384,27 @@ class Coder:
         if fname and res[0].strip().endswith(Path(fname).name):
             res = res[1:]
 
-        if res[0].startswith('```') and res[-1].startswith('```'):
+        if res[0].startswith("```") and res[-1].startswith("```"):
            res = res[1:-1]
 
-        res = '\n'.join(res)
-        if res and res[-1] != '\n':
-            res += '\n'
+        res = "\n".join(res)
+        if res and res[-1] != "\n":
+            res += "\n"
 
         return res
 
 
 def main():
-
-    parser = argparse.ArgumentParser(description='Chat with GPT about code')
-    parser.add_argument('files', metavar='FILE', nargs='+', help='a list of source code files')
-    parser.add_argument('-3', '--gpt-3-5-turbo', action='store_true', help='Only use gpt-3.5-turbo, not gpt-4')
+    parser = argparse.ArgumentParser(description="Chat with GPT about code")
+    parser.add_argument(
+        "files", metavar="FILE", nargs="+", help="a list of source code files"
+    )
+    parser.add_argument(
+        "-3",
+        "--gpt-3-5-turbo",
+        action="store_true",
+        help="Only use gpt-3.5-turbo, not gpt-4",
+    )
 
     args = parser.parse_args()
 
@@ -412,6 +418,7 @@ def main():
 
     coder.run()
 
-if __name__ == '__main__':
+
+if __name__ == "__main__":
     status = main()
     sys.exit(status)
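Among the reflows above, the pattern class attribute — the regex update_files uses to pull ORIGINAL/UPDATED edit blocks out of the model's reply — is split across multiple lines. For reference, a small usage sketch of that same regex follows; the sample reply text is invented for illustration and is not taken from the commit:

# Hypothetical example of what self.pattern matches; the sample reply below
# is made up for illustration only.
import re

pattern = re.compile(
    r"(\S+)\s+(```)?<<<<<<< ORIGINAL\n(.*?\n?)=======\n(.*?\n?)>>>>>>> UPDATED",
    re.MULTILINE | re.DOTALL,
)

sample_reply = (
    "foo.py\n"
    "<<<<<<< ORIGINAL\n"
    "print('hi')\n"
    "=======\n"
    'print("hi")\n'
    ">>>>>>> UPDATED\n"
)

for match in pattern.finditer(sample_reply):
    path, _fence, original, updated = match.groups()
    print(path)      # foo.py
    print(original)  # print('hi')
    print(updated)   # print("hi")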