Merge branch 'main' into edit-formats

This commit is contained in:
Paul Gauthier 2023-06-20 15:43:14 -07:00
commit 4d3fc3de7e
11 changed files with 194 additions and 98 deletions

View file

@@ -0,0 +1 @@
__version__ = "0.6.6"

View file

@@ -51,6 +51,7 @@ class Coder:
verbose=False,
openai_api_key=None,
openai_api_base=None,
assistant_output_color="blue",
):
if not openai_api_key:
raise MissingAPIKeyError("No OpenAI API key provided.")
@@ -69,6 +70,7 @@ class Coder:
self.auto_commits = auto_commits
self.dirty_commits = dirty_commits
self.assistant_output_color = assistant_output_color
self.dry_run = dry_run
self.pretty = pretty
@@ -78,19 +80,17 @@ class Coder:
else:
self.console = Console(force_terminal=True, no_color=True)
main_model = models.get_model(main_model)
if main_model not in models.GPT35_models:
main_model = models.Model(main_model)
if not main_model.is_always_available():
if not self.check_model_availability(main_model):
if main_model != models.GPT4:
self.io.tool_error(f"API key does not support {main_model.name}.")
self.io.tool_error(
f"API key does not support {main_model.name}, falling back to"
f" {models.GPT35_16k.name}"
)
main_model = models.GPT35_16k
self.main_model = main_model
if main_model in models.GPT35_models:
self.io.tool_output(
f"Using {main_model.name} (experimental): disabling ctags/repo-maps.",
)
self.edit_format = self.main_model.edit_format
if self.edit_format == "whole":
@@ -98,6 +98,8 @@ class Coder:
else:
self.gpt_prompts = prompts.GPT4()
self.io.tool_output(f"Model: {main_model.name}")
self.show_diffs = show_diffs
self.commands = Commands(self.io, self)
@@ -106,12 +108,12 @@ class Coder:
if self.repo:
rel_repo_dir = os.path.relpath(self.repo.git_dir, os.getcwd())
self.io.tool_output(f"Using git repo: {rel_repo_dir}")
self.io.tool_output(f"Git repo: {rel_repo_dir}")
else:
self.io.tool_output("Not using git.")
self.io.tool_output("Git repo: none")
self.find_common_root()
if main_model in models.GPT4_models:
if main_model.is_gpt4():
rm_io = io if self.verbose else None
self.repo_map = RepoMap(
map_tokens,
@@ -121,8 +123,16 @@ class Coder:
self.gpt_prompts.repo_content_prefix,
)
if self.repo_map.has_ctags:
self.io.tool_output("Using ctags to build repo-map.")
if self.repo_map.use_ctags:
self.io.tool_output(f"Repo-map: universal-ctags using {map_tokens} tokens")
elif not self.repo_map.has_ctags and map_tokens > 0:
self.io.tool_output(
f"Repo-map: basic using {map_tokens} tokens (universal-ctags not found)"
)
else:
self.io.tool_output("Repo-map: disabled because map_tokens == 0")
else:
self.io.tool_output("Repo-map: disabled for gpt-3.5")
for fname in self.get_inchat_relative_files():
self.io.tool_output(f"Added {fname} to the chat.")
@@ -318,7 +328,7 @@ class Coder:
]
main_sys = self.gpt_prompts.main_system
if self.main_model in models.GPT4_models + [models.GPT35_16k]:
if self.main_model.is_gpt4():
main_sys += "\n" + self.gpt_prompts.system_reminder
messages = [
@@ -488,7 +498,9 @@ class Coder:
show_resp = self.update_files_gpt35(self.resp, mode="diff")
except ValueError:
pass
md = Markdown(show_resp, style="blue", code_theme="default")
md = Markdown(
show_resp, style=self.assistant_output_color, code_theme="default"
)
live.update(md)
else:
sys.stdout.write(text)

View file

@@ -8,7 +8,7 @@ import git
import tiktoken
from prompt_toolkit.completion import Completion
from aider import models, prompts, utils
from aider import prompts, utils
class Commands:
@@ -183,7 +183,7 @@ class Commands:
"was reset and removed from git.\n"
)
if self.coder.main_model in models.GPT4_models:
if self.coder.main_model.is_gpt4():
return prompts.undo_command_reply
def cmd_diff(self, args):

View file

@@ -4,7 +4,7 @@ import sys
import configargparse
import git
from aider import models
from aider import __version__, models
from aider.coder import Coder
from aider.io import InputOutput
@@ -30,13 +30,20 @@ def main(args=None, input=None, output=None):
default_config_files.insert(0, os.path.join(git_root, ".aider.conf.yml"))
parser = configargparse.ArgumentParser(
description="aider - chat with GPT about your code",
description="aider is GPT powered coding in your terminal",
add_config_file_help=True,
default_config_files=default_config_files,
config_file_parser_class=configargparse.YAMLConfigFileParser,
auto_env_var_prefix="AIDER_",
)
parser.add_argument(
"--version",
action="version",
version=f"%(prog)s {__version__}",
help="Show the version number and exit",
)
parser.add_argument(
"-c",
"--config",
@@ -84,7 +91,7 @@ def main(args=None, input=None, output=None):
action="store_const",
dest="model",
const=models.GPT35_16k.name,
help=f"Use {models.GPT35.name} model for the main chat (not advised)",
help=f"Use {models.GPT35_16k.name} model for the main chat (gpt-4 is better)",
)
parser.add_argument(
"--pretty",
@@ -113,6 +120,11 @@ def main(args=None, input=None, output=None):
default="red",
help="Set the color for tool error messages (default: red)",
)
parser.add_argument(
"--assistant-output-color",
default="blue",
help="Set the color for assistant output (default: blue)",
)
parser.add_argument(
"--apply",
metavar="FILE",
@@ -228,6 +240,7 @@ def main(args=None, input=None, output=None):
verbose=args.verbose,
openai_api_key=args.openai_api_key,
openai_api_base=args.openai_api_base,
assistant_output_color=args.assistant_output_color,
)
if args.dirty_commits:

View file

@@ -1,47 +1,38 @@
class Model_GPT4_32k:
name = "gpt-4-32k"
max_context_tokens = 32 * 1024
edit_format = "diff"
import re
GPT4_32k = Model_GPT4_32k()
class Model:
    """An OpenAI chat model: its context window size and preferred edit format.

    The context window (in units of 1024 tokens) is taken from the explicit
    `tokens` argument, else parsed from a "-<N>k" suffix in the name, else
    defaulted per model family: 8 for gpt-4, 4 for gpt-3.5-turbo.  (The old
    per-model classes this replaces used 8k for gpt-4 and 4k for
    gpt-3.5-turbo; a flat default of 8 would silently double gpt-3.5-turbo's
    real 4k window.)
    """

    def __init__(self, name, tokens=None):
        self.name = name

        if tokens is None:
            match = re.search(r"-([0-9]+)k", name)
            if match:
                tokens = int(match.group(1))
            elif self.is_gpt4():
                tokens = 8  # gpt-4 base model: 8k context window
            else:
                tokens = 4  # gpt-3.5-turbo base model: 4k context window

        self.max_context_tokens = tokens * 1024

        if self.is_gpt4():
            self.edit_format = "diff"
            return

        if self.is_gpt35():
            self.edit_format = "whole"
            return

        raise ValueError(f"Unsupported model: {name}")

    def is_gpt4(self):
        """True for any gpt-4 family model."""
        return self.name.startswith("gpt-4")

    def is_gpt35(self):
        """True for any gpt-3.5-turbo family model."""
        return self.name.startswith("gpt-3.5-turbo")

    def is_always_available(self):
        """gpt-3.5 models are assumed usable by every API key."""
        return self.is_gpt35()
class Model_GPT4:
name = "gpt-4"
max_context_tokens = 8 * 1024
edit_format = "diff"
GPT4 = Model_GPT4()
class Model_GPT35:
name = "gpt-3.5-turbo"
max_context_tokens = 4 * 1024
edit_format = "whole"
GPT35 = Model_GPT35()
class Model_GPT35_16k:
name = "gpt-3.5-turbo-16k"
max_context_tokens = 16 * 1024
edit_format = "diff"
GPT35_16k = Model_GPT35_16k()
GPT35_models = [GPT35, GPT35_16k]
GPT4_models = [GPT4, GPT4_32k]
def get_model(name):
models = GPT35_models + GPT4_models
for model in models:
if model.name == name:
return model
raise ValueError(f"Unsupported model: {name}")
# Singleton model instances used throughout the package.
# Token counts are passed explicitly where the name carries no "-<N>k"
# suffix: Model's no-suffix fallback is 8 (*1024), which is correct for
# gpt-4 but would wrongly give gpt-3.5-turbo an 8k window — its real
# context window is 4k (as the old Model_GPT35 class declared).
GPT4 = Model("gpt-4", 8)
GPT35 = Model("gpt-3.5-turbo", 4)
GPT35_16k = Model("gpt-3.5-turbo-16k")  # 16 parsed from the "-16k" suffix

View file

@@ -83,10 +83,12 @@ class RepoMap:
self.load_tags_cache()
self.max_map_tokens = map_tokens
if map_tokens > 0:
self.has_ctags = self.check_for_ctags()
self.has_ctags = self.check_for_ctags()
if map_tokens > 0 and self.has_ctags:
self.use_ctags = True
else:
self.has_ctags = False
self.use_ctags = False
self.tokenizer = tiktoken.encoding_for_model(main_model.name)
self.repo_content_prefix = repo_content_prefix
@@ -122,7 +124,7 @@ class RepoMap:
if not other_files:
return
if self.has_ctags:
if self.use_ctags:
files_listing = self.get_ranked_tags_map(chat_files, other_files)
if files_listing:
num_tokens = self.token_count(files_listing)
@@ -166,7 +168,7 @@ class RepoMap:
return self.TAGS_CACHE[cache_key]["data"]
cmd = self.ctags_cmd + [filename]
output = subprocess.check_output(cmd).decode("utf-8")
output = subprocess.check_output(cmd, stderr=subprocess.PIPE).decode("utf-8")
output = output.splitlines()
data = [json.loads(line) for line in output]

View file

@@ -67,7 +67,7 @@ def replace_part_with_missing_leading_whitespace(whole, part, replace):
# If all lines in the part start with whitespace, then honor it.
# But GPT often outdents the part and replace blocks completely,
# thereby discarding the actual leading whitespace in the file.
if all(pline[0].isspace() for pline in part_lines):
if all((len(pline) > 0 and pline[0].isspace()) for pline in part_lines):
return
for i in range(len(whole_lines) - len(part_lines) + 1):