Moved all model names into an enum

Author: Paul Gauthier
Date:   2023-06-05 09:19:29 -07:00
Parent: efb8cad881
Commit: 95b32a74a9
5 changed files with 37 additions and 18 deletions
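
The pattern being introduced: the model-name string literals move into a single Enum in aider/utils.py, and every call site compares plain strings against the members' .value. A minimal, standalone sketch of that pattern (standard library only; the surrounding code is illustrative, not copied from the repo):

    from enum import Enum

    class Models(Enum):
        GPT4 = "gpt-4"
        GPT35 = "gpt-3.5-turbo"

    # Call sites keep passing plain strings around and compare against .value,
    # so existing string-based configuration keeps working unchanged.
    main_model = Models.GPT4.value          # "gpt-4"
    if main_model == Models.GPT35.value:
        print("falling back to the cheaper model")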

aider/coder.py

@@ -17,6 +17,7 @@ from rich.markdown import Markdown
 from aider import prompts, utils
 from aider.commands import Commands
 from aider.repomap import RepoMap
+from aider.utils import Models

 # from .dump import dump

@@ -39,7 +40,7 @@ class Coder:
     def __init__(
         self,
         io,
-        main_model="gpt-4",
+        main_model=Models.GPT4.value,
         fnames=None,
         pretty=True,
         show_diffs=False,

@@ -73,10 +74,10 @@ class Coder:
         self.commands = Commands(self.io, self)

         if not self.check_model_availability(main_model):
-            main_model = "gpt-3.5-turbo"
+            main_model = Models.GPT35.value

         self.main_model = main_model
-        if main_model == "gpt-3.5-turbo":
+        if main_model == Models.GPT35.value:
             self.io.tool_output(
                 f"Using {main_model}: showing diffs and disabling ctags/repo-maps.",
             )

@@ -106,7 +107,7 @@ class Coder:
             self.gpt_prompts.repo_content_prefix,
         )

-        if main_model != "gpt-3.5-turbo":
+        if main_model != Models.GPT35.value:
             if self.repo_map.has_ctags:
                 self.io.tool_output("Using ctags to build repo-map.")

@@ -299,7 +300,7 @@ class Coder:
         ]

         main_sys = self.gpt_prompts.main_system
-        if self.main_model == "gpt-4":
+        if self.main_model == Models.GPT4.value:
             main_sys += "\n" + self.gpt_prompts.system_reminder

         messages = [

@@ -326,7 +327,7 @@ class Coder:
         if edit_error:
             return edit_error

-        if self.main_model == "gpt=4" or (self.main_model == "gpt-3.5-turbo" and not edited):
+        if self.main_model == "gpt=4" or (self.main_model == Models.GPT35.value and not edited):
             # Don't add assistant messages to the history if they contain "edits"
             # Because those edits are actually fully copies of the file!
             # That wastes too much context window.

@@ -562,7 +563,9 @@ class Coder:
     def get_commit_message(self, diffs, context):
         if len(diffs) >= 4 * 1024 * 4:
-            self.io.tool_error("Diff is too large for gpt-3.5-turbo to generate a commit message.")
+            self.io.tool_error(
+                f"Diff is too large for {Models.GPT35.value} to generate a commit message."
+            )
             return

         diffs = "# Diffs:\n" + diffs

@@ -575,12 +578,13 @@ class Coder:
         try:
             commit_message, interrupted = self.send(
                 messages,
-                model="gpt-3.5-turbo",
+                model=Models.GPT35.value,
                 silent=True,
             )
         except openai.error.InvalidRequestError:
             self.io.tool_error(
-                "Failed to generate commit message using gpt-3.5-turbo due to an invalid request."
+                f"Failed to generate commit message using {Models.GPT35.value} due to an invalid"
+                " request."
             )
             return

@@ -590,7 +594,7 @@ class Coder:
         if interrupted:
             self.io.tool_error(
-                "Unable to get commit message from gpt-3.5-turbo. Use /commit to try again."
+                f"Unable to get commit message from {Models.GPT35.value}. Use /commit to try again."
             )
             return

@@ -715,9 +719,9 @@ class Coder:
         return set(self.get_all_relative_files()) - set(self.get_inchat_relative_files())

     def apply_updates(self, content):
-        if self.main_model == "gpt-4":
+        if self.main_model == Models.GPT4.value:
             method = self.update_files_gpt4
-        elif self.main_model == "gpt-3.5-turbo":
+        elif self.main_model == Models.GPT35.value:
             method = self.update_files_gpt35
         else:
             raise ValueError(f"apply_updates() doesn't support {self.main_model}")

aider/commands.py

@@ -7,6 +7,7 @@ import git
 from prompt_toolkit.completion import Completion

 from aider import prompts
+from aider.utils import Models


 class Commands:

@@ -118,7 +119,7 @@ class Commands:
                 "was reset and removed from git.\n"
             )

-        if self.coder.main_model != "gpt-3.5-turbo":
+        if self.coder.main_model != Models.GPT35.value:
             return prompts.undo_command_reply

     def cmd_diff(self, args):

aider/main.py

@@ -6,6 +6,7 @@ import git

 from aider.coder import Coder
 from aider.io import InputOutput
+from aider.utils import Models


 def get_git_root():

@@ -75,15 +76,15 @@ def main(args=None, input=None, output=None):
     parser.add_argument(
         "--model",
         metavar="MODEL",
-        default="gpt-4",
-        help="Specify the model to use for the main chat (default: gpt-4)",
+        default=Models.GPT4.value,
+        help=f"Specify the model to use for the main chat (default: {Models.GPT4.value})",
     )
     parser.add_argument(
         "-3",
         action="store_const",
         dest="model",
-        const="gpt-3.5-turbo",
-        help="Use gpt-3.5-turbo model for the main chat (not advised)",
+        const=Models.GPT35.value,
+        help=f"Use {Models.GPT35.value} model for the main chat (not advised)",
     )
     parser.add_argument(
         "--pretty",

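A note on the -3 flag above: action="store_const" with dest="model" writes the gpt-3.5-turbo value into the same args.model slot that --model fills, so -3 is simply shorthand for selecting the cheaper model. A self-contained sketch of that behaviour (the parser here is reduced to the two flags; it is an illustration, not the repo's exact setup):

    import argparse
    from enum import Enum

    class Models(Enum):
        GPT4 = "gpt-4"
        GPT35 = "gpt-3.5-turbo"

    parser = argparse.ArgumentParser()
    parser.add_argument("--model", default=Models.GPT4.value)
    parser.add_argument("-3", action="store_const", dest="model", const=Models.GPT35.value)

    print(parser.parse_args([]).model)      # gpt-4
    print(parser.parse_args(["-3"]).model)  # gpt-3.5-turbo
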
aider/repomap.py

@@ -14,6 +14,8 @@ from pygments.lexers import guess_lexer_for_filename
 from pygments.token import Token
 from pygments.util import ClassNotFound

+from aider.utils import Models
+
 from .dump import dump  # noqa: F402

@@ -64,7 +66,12 @@ class RepoMap:
     TAGS_CACHE_DIR = f".aider.tags.cache.v{CACHE_VERSION}"

     def __init__(
-        self, map_tokens=1024, root=None, main_model="gpt-4", io=None, repo_content_prefix=None
+        self,
+        map_tokens=1024,
+        root=None,
+        main_model=Models.GPT4.value,
+        io=None,
+        repo_content_prefix=None,
     ):
         self.io = io

aider/utils.py

@@ -1,11 +1,17 @@
 import math
 import re
 from difflib import SequenceMatcher
+from enum import Enum
 from pathlib import Path

 # from aider.dump import dump


+class Models(Enum):
+    GPT4 = "gpt-4"
+    GPT35 = "gpt-3.5-turbo"
+
+
 def try_dotdotdots(whole, part, replace):
     """
     See if the edit block has ... lines.
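
A closing note on the design: the commit keeps main_model as a plain string everywhere and compares it against Models.GPT4.value / Models.GPT35.value, so the Enum mainly serves as the single source of truth for the literals. The same class also supports reverse lookup and iteration, which this commit does not use; a small sketch of those standard Enum features (illustrative only, not part of the diff):

    from enum import Enum

    class Models(Enum):
        GPT4 = "gpt-4"
        GPT35 = "gpt-3.5-turbo"

    # Reverse lookup from the raw string, e.g. when validating user input:
    assert Models("gpt-3.5-turbo") is Models.GPT35

    # Enumerate every known model name without repeating the literals:
    assert {m.value for m in Models} == {"gpt-4", "gpt-3.5-turbo"}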