Paul Gauthier 2023-06-07 12:28:45 -07:00
parent 70b981029c
commit 9cef379abd
5 changed files with 26 additions and 34 deletions

aider/coder.py

@@ -14,10 +14,9 @@ from rich.console import Console
 from rich.live import Live
 from rich.markdown import Markdown

-from aider import diffs, prompts, utils
+from aider import diffs, models, prompts, utils
 from aider.commands import Commands
 from aider.repomap import RepoMap
-from aider.utils import Models

 from .dump import dump  # noqa: F401
@@ -41,7 +40,7 @@ class Coder:
     def __init__(
         self,
         io,
-        main_model=Models.GPT4.value,
+        main_model=models.GPT4.value,
         fnames=None,
         pretty=True,
         show_diffs=False,
@@ -78,12 +77,12 @@ class Coder:
         else:
             self.console = Console(force_terminal=True, no_color=True)

-        main_model = Models(main_model)
-        if main_model != Models.GPT35 and not self.check_model_availability(main_model):
-            main_model = Models.GPT35
+        main_model = models.get_model(main_model)
+        if main_model != models.GPT35 and not self.check_model_availability(main_model):
+            main_model = models.GPT35

         self.main_model = main_model
-        if main_model == Models.GPT35:
+        if main_model == models.GPT35:
             self.io.tool_output(
                 f"Using {main_model.value} (experimental): disabling ctags/repo-maps.",
             )
@@ -104,7 +103,7 @@ class Coder:
             self.io.tool_output("Not using git.")

         self.find_common_root()

-        if main_model == Models.GPT4:
+        if main_model == models.GPT4:
             rm_io = io if self.verbose else None
             self.repo_map = RepoMap(
                 map_tokens,
@@ -311,7 +310,7 @@ class Coder:
         ]

         main_sys = self.gpt_prompts.main_system
-        if self.main_model == Models.GPT4:
+        if self.main_model == models.GPT4:
             main_sys += "\n" + self.gpt_prompts.system_reminder

         messages = [
@@ -340,7 +339,7 @@ class Coder:
         if edit_error:
             return edit_error

-        if self.main_model == Models.GPT4 or (self.main_model == Models.GPT35 and not edited):
+        if self.main_model == models.GPT4 or (self.main_model == models.GPT35 and not edited):
             # Don't add assistant messages to the history if they contain "edits"
             # Because those edits are actually fully copies of the file!
             # That wastes too much context window.
@@ -475,7 +474,7 @@ class Coder:
         if self.pretty:
             show_resp = self.resp

-            if self.main_model == Models.GPT35:
+            if self.main_model == models.GPT35:
                 try:
                     show_resp = self.update_files_gpt35(self.resp, mode="diff")
                 except ValueError:
@@ -621,7 +620,7 @@ class Coder:
     def get_commit_message(self, diffs, context):
         if len(diffs) >= 4 * 1024 * 4:
             self.io.tool_error(
-                f"Diff is too large for {Models.GPT35.value} to generate a commit message."
+                f"Diff is too large for {models.GPT35.value} to generate a commit message."
             )
             return
@@ -635,12 +634,12 @@ class Coder:
         try:
             commit_message, interrupted = self.send(
                 messages,
-                model=Models.GPT35.value,
+                model=models.GPT35.value,
                 silent=True,
             )
         except openai.error.InvalidRequestError:
             self.io.tool_error(
-                f"Failed to generate commit message using {Models.GPT35.value} due to an invalid"
+                f"Failed to generate commit message using {models.GPT35.value} due to an invalid"
                 " request."
             )
             return
@@ -651,7 +650,7 @@ class Coder:
         if interrupted:
             self.io.tool_error(
-                f"Unable to get commit message from {Models.GPT35.value}. Use /commit to try again."
+                f"Unable to get commit message from {models.GPT35.value}. Use /commit to try again."
             )
             return
@@ -776,9 +775,9 @@ class Coder:
         return set(self.get_all_relative_files()) - set(self.get_inchat_relative_files())

     def apply_updates(self, content):
-        if self.main_model == Models.GPT4:
+        if self.main_model == models.GPT4:
             method = self.update_files_gpt4
-        elif self.main_model == Models.GPT35:
+        elif self.main_model == models.GPT35:
             method = self.update_files_gpt35
         else:
             raise ValueError(f"apply_updates() doesn't support {self.main_model.value}")
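
Note: taken together, the coder.py changes route every model reference through the new aider.models module instead of the old Models enum from aider.utils. As a rough illustration of the lookup Coder.__init__ now relies on, here is a hedged sketch; it assumes models.get_model() maps the plain string passed on the command line (e.g. "gpt-4") back to the module-level constants that the equality checks above compare against, which is how the __init__ hunk uses it:

    from aider import models

    # Assumed behavior, inferred only from the call sites in this diff:
    # --model hands Coder a plain string, and get_model() normalizes it
    # so that checks like `main_model == models.GPT4` work afterwards.
    main_model = models.get_model("gpt-4")
    assert main_model == models.GPT4
    print(main_model.value)  # "gpt-4"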

aider/commands.py

@@ -8,8 +8,7 @@ import git
 import tiktoken
 from prompt_toolkit.completion import Completion

-from aider import prompts, utils
-from aider.utils import Models
+from aider import models, prompts, utils


 class Commands:
@@ -134,7 +133,7 @@ class Commands:
         print()
         print(f"{total:6.3f} total")

-        limit = 8 if self.coder.main_model == Models.GPT4 else 4
+        limit = 8 if self.coder.main_model == models.GPT4 else 4
         remaining = limit - total
         print(f"{remaining:6.3f} remaining")
         print(f"{limit:6.3f} max context window")
@@ -182,7 +181,7 @@ class Commands:
             "was reset and removed from git.\n"
         )

-        if self.coder.main_model != Models.GPT35:
+        if self.coder.main_model != models.GPT35:
             return prompts.undo_command_reply

     def cmd_diff(self, args):

aider/main.py

@@ -4,9 +4,9 @@ import sys
 import configargparse
 import git

+from aider import models
 from aider.coder import Coder
 from aider.io import InputOutput
-from aider.utils import Models


 def get_git_root():
@@ -76,15 +76,15 @@ def main(args=None, input=None, output=None):
     parser.add_argument(
         "--model",
         metavar="MODEL",
-        default=Models.GPT4.value,
-        help=f"Specify the model to use for the main chat (default: {Models.GPT4.value})",
+        default=models.GPT4.value,
+        help=f"Specify the model to use for the main chat (default: {models.GPT4.value})",
     )
     parser.add_argument(
         "-3",
         action="store_const",
         dest="model",
-        const=Models.GPT35.value,
-        help=f"Use {Models.GPT35.value} model for the main chat (not advised)",
+        const=models.GPT35.value,
+        help=f"Use {models.GPT35.value} model for the main chat (not advised)",
     )
     parser.add_argument(
         "--pretty",

aider/repomap.py

@@ -14,7 +14,7 @@ from pygments.lexers import guess_lexer_for_filename
 from pygments.token import Token
 from pygments.util import ClassNotFound

-from aider.utils import Models
+from aider import models

 from .dump import dump  # noqa: F402
@@ -69,7 +69,7 @@ class RepoMap:
         self,
         map_tokens=1024,
         root=None,
-        main_model=Models.GPT4,
+        main_model=models.GPT4,
         io=None,
         repo_content_prefix=None,
     ):

aider/utils.py

@@ -1,17 +1,11 @@
 import math
 import re
 from difflib import SequenceMatcher
-from enum import Enum
 from pathlib import Path

 # from aider.dump import dump


-class Models(Enum):
-    GPT4 = "gpt-4"
-    GPT35 = "gpt-3.5-turbo"


 def try_dotdotdots(whole, part, replace):
     """
     See if the edit block has ... lines.
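
Note: the Models enum removed from utils.py above is not re-created anywhere in this diff, so the aider.models module the callers now import is presumably defined elsewhere. Inferring only from the usages shown (models.GPT4, models.GPT35, the .value attribute, and models.get_model()), a hypothetical minimal version of that module might look like the sketch below; the real file may differ:

    # Hypothetical sketch of aider/models.py, reconstructed from the call
    # sites in this diff rather than from the actual commit contents.
    from enum import Enum


    class Models(Enum):
        GPT4 = "gpt-4"
        GPT35 = "gpt-3.5-turbo"


    # Module-level aliases so callers can write models.GPT4 / models.GPT35.
    GPT4 = Models.GPT4
    GPT35 = Models.GPT35


    def get_model(name):
        # Accept either a model constant or its string value (e.g. "gpt-4").
        if isinstance(name, Models):
            return name
        return Models(name)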