refactor to use io.read_text

This commit is contained in:
Paul Gauthier 2023-07-05 20:46:15 -07:00
parent 413ddb84df
commit c8fb1ddaa7
4 changed files with 18 additions and 15 deletions

View file

@ -167,13 +167,13 @@ class Coder:
self.find_common_root() self.find_common_root()
if main_model.use_repo_map and self.repo and self.gpt_prompts.repo_content_prefix: if main_model.use_repo_map and self.repo and self.gpt_prompts.repo_content_prefix:
rm_io = io if self.verbose else None
self.repo_map = RepoMap( self.repo_map = RepoMap(
map_tokens, map_tokens,
self.root, self.root,
self.main_model, self.main_model,
rm_io, io,
self.gpt_prompts.repo_content_prefix, self.gpt_prompts.repo_content_prefix,
self.verbose,
) )
if self.repo_map.use_ctags: if self.repo_map.use_ctags:

View file

@ -141,7 +141,7 @@ class InputOutput:
return f.read() return f.read()
except (FileNotFoundError, UnicodeError) as e: except (FileNotFoundError, UnicodeError) as e:
self.tool_error(str(e)) self.tool_error(str(e))
return None return
def get_input(self, root, rel_fnames, addable_rel_fnames, commands): def get_input(self, root, rel_fnames, addable_rel_fnames, commands):
if self.pretty: if self.pretty:

View file

@ -74,8 +74,10 @@ class RepoMap:
main_model=models.GPT4, main_model=models.GPT4,
io=None, io=None,
repo_content_prefix=None, repo_content_prefix=None,
verbose=False,
): ):
self.io = io self.io = io
self.verbose = verbose
if not root: if not root:
root = os.getcwd() root = os.getcwd()
@ -130,7 +132,7 @@ class RepoMap:
files_listing = self.get_ranked_tags_map(chat_files, other_files) files_listing = self.get_ranked_tags_map(chat_files, other_files)
if files_listing: if files_listing:
num_tokens = self.token_count(files_listing) num_tokens = self.token_count(files_listing)
if self.io: if self.verbose:
self.io.tool_output(f"ctags map: {num_tokens/1024:.1f} k-tokens") self.io.tool_output(f"ctags map: {num_tokens/1024:.1f} k-tokens")
ctags_msg = " with selected ctags info" ctags_msg = " with selected ctags info"
return files_listing, ctags_msg return files_listing, ctags_msg
@ -138,7 +140,7 @@ class RepoMap:
files_listing = self.get_simple_files_map(other_files) files_listing = self.get_simple_files_map(other_files)
ctags_msg = "" ctags_msg = ""
num_tokens = self.token_count(files_listing) num_tokens = self.token_count(files_listing)
if self.io: if self.verbose:
self.io.tool_output(f"simple map: {num_tokens/1024:.1f} k-tokens") self.io.tool_output(f"simple map: {num_tokens/1024:.1f} k-tokens")
if num_tokens < self.max_map_tokens: if num_tokens < self.max_map_tokens:
return files_listing, ctags_msg return files_listing, ctags_msg
@ -198,7 +200,7 @@ class RepoMap:
with tempfile.TemporaryDirectory() as tempdir: with tempfile.TemporaryDirectory() as tempdir:
hello_py = os.path.join(tempdir, "hello.py") hello_py = os.path.join(tempdir, "hello.py")
with open(hello_py, "w") as f: with open(hello_py, "w", encoding="utf-8") as f:
f.write("def hello():\n print('Hello, world!')\n") f.write("def hello():\n print('Hello, world!')\n")
self.run_ctags(hello_py) self.run_ctags(hello_py)
except FileNotFoundError: except FileNotFoundError:
@ -237,10 +239,8 @@ class RepoMap:
return idents return idents
def get_name_identifiers_uncached(self, fname): def get_name_identifiers_uncached(self, fname):
try: content = self.io.read_text(fname)
with open(fname, "r") as f: if content is None:
content = f.read()
except UnicodeDecodeError:
return list() return list()
try: try:

View file

@ -3,6 +3,7 @@ import tempfile
import unittest import unittest
from unittest.mock import patch from unittest.mock import patch
from aider.io import InputOutput
from aider.repomap import RepoMap from aider.repomap import RepoMap
@ -21,7 +22,8 @@ class TestRepoMap(unittest.TestCase):
with open(os.path.join(temp_dir, file), "w") as f: with open(os.path.join(temp_dir, file), "w") as f:
f.write("") f.write("")
repo_map = RepoMap(root=temp_dir) io = InputOutput()
repo_map = RepoMap(root=temp_dir, io=io)
other_files = [os.path.join(temp_dir, file) for file in test_files] other_files = [os.path.join(temp_dir, file) for file in test_files]
result = repo_map.get_repo_map([], other_files) result = repo_map.get_repo_map([], other_files)
@ -65,7 +67,8 @@ print(my_function(3, 4))
with open(os.path.join(temp_dir, test_file3), "w") as f: with open(os.path.join(temp_dir, test_file3), "w") as f:
f.write(file_content3) f.write(file_content3)
repo_map = RepoMap(root=temp_dir) io = InputOutput()
repo_map = RepoMap(root=temp_dir, io=io)
other_files = [ other_files = [
os.path.join(temp_dir, test_file1), os.path.join(temp_dir, test_file1),
os.path.join(temp_dir, test_file2), os.path.join(temp_dir, test_file2),
@ -83,7 +86,7 @@ print(my_function(3, 4))
def test_check_for_ctags_failure(self): def test_check_for_ctags_failure(self):
with patch("subprocess.run") as mock_run: with patch("subprocess.run") as mock_run:
mock_run.side_effect = Exception("ctags not found") mock_run.side_effect = Exception("ctags not found")
repo_map = RepoMap() repo_map = RepoMap(io=InputOutput())
self.assertFalse(repo_map.has_ctags) self.assertFalse(repo_map.has_ctags)
def test_check_for_ctags_success(self): def test_check_for_ctags_success(self):
@ -100,7 +103,7 @@ print(my_function(3, 4))
b' status = main()$/", "kind": "variable"}' b' status = main()$/", "kind": "variable"}'
), ),
] ]
repo_map = RepoMap() repo_map = RepoMap(io=InputOutput())
self.assertTrue(repo_map.has_ctags) self.assertTrue(repo_map.has_ctags)
def test_get_repo_map_without_ctags(self): def test_get_repo_map_without_ctags(self):
@ -120,7 +123,7 @@ print(my_function(3, 4))
with open(os.path.join(temp_dir, file), "w") as f: with open(os.path.join(temp_dir, file), "w") as f:
f.write("") f.write("")
repo_map = RepoMap(root=temp_dir) repo_map = RepoMap(root=temp_dir, io=InputOutput())
repo_map.has_ctags = False # force it off repo_map.has_ctags = False # force it off
other_files = [os.path.join(temp_dir, file) for file in test_files] other_files = [os.path.join(temp_dir, file) for file in test_files]