mirror of
https://github.com/Aider-AI/aider.git
synced 2025-05-24 22:34:59 +00:00
refactor to use io.read_text
This commit is contained in:
parent
413ddb84df
commit
c8fb1ddaa7
4 changed files with 18 additions and 15 deletions
|
@ -167,13 +167,13 @@ class Coder:
|
||||||
self.find_common_root()
|
self.find_common_root()
|
||||||
|
|
||||||
if main_model.use_repo_map and self.repo and self.gpt_prompts.repo_content_prefix:
|
if main_model.use_repo_map and self.repo and self.gpt_prompts.repo_content_prefix:
|
||||||
rm_io = io if self.verbose else None
|
|
||||||
self.repo_map = RepoMap(
|
self.repo_map = RepoMap(
|
||||||
map_tokens,
|
map_tokens,
|
||||||
self.root,
|
self.root,
|
||||||
self.main_model,
|
self.main_model,
|
||||||
rm_io,
|
io,
|
||||||
self.gpt_prompts.repo_content_prefix,
|
self.gpt_prompts.repo_content_prefix,
|
||||||
|
self.verbose,
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.repo_map.use_ctags:
|
if self.repo_map.use_ctags:
|
||||||
|
|
|
@ -141,7 +141,7 @@ class InputOutput:
|
||||||
return f.read()
|
return f.read()
|
||||||
except (FileNotFoundError, UnicodeError) as e:
|
except (FileNotFoundError, UnicodeError) as e:
|
||||||
self.tool_error(str(e))
|
self.tool_error(str(e))
|
||||||
return None
|
return
|
||||||
|
|
||||||
def get_input(self, root, rel_fnames, addable_rel_fnames, commands):
|
def get_input(self, root, rel_fnames, addable_rel_fnames, commands):
|
||||||
if self.pretty:
|
if self.pretty:
|
||||||
|
|
|
@ -74,8 +74,10 @@ class RepoMap:
|
||||||
main_model=models.GPT4,
|
main_model=models.GPT4,
|
||||||
io=None,
|
io=None,
|
||||||
repo_content_prefix=None,
|
repo_content_prefix=None,
|
||||||
|
verbose=False,
|
||||||
):
|
):
|
||||||
self.io = io
|
self.io = io
|
||||||
|
self.verbose = verbose
|
||||||
|
|
||||||
if not root:
|
if not root:
|
||||||
root = os.getcwd()
|
root = os.getcwd()
|
||||||
|
@ -130,7 +132,7 @@ class RepoMap:
|
||||||
files_listing = self.get_ranked_tags_map(chat_files, other_files)
|
files_listing = self.get_ranked_tags_map(chat_files, other_files)
|
||||||
if files_listing:
|
if files_listing:
|
||||||
num_tokens = self.token_count(files_listing)
|
num_tokens = self.token_count(files_listing)
|
||||||
if self.io:
|
if self.verbose:
|
||||||
self.io.tool_output(f"ctags map: {num_tokens/1024:.1f} k-tokens")
|
self.io.tool_output(f"ctags map: {num_tokens/1024:.1f} k-tokens")
|
||||||
ctags_msg = " with selected ctags info"
|
ctags_msg = " with selected ctags info"
|
||||||
return files_listing, ctags_msg
|
return files_listing, ctags_msg
|
||||||
|
@ -138,7 +140,7 @@ class RepoMap:
|
||||||
files_listing = self.get_simple_files_map(other_files)
|
files_listing = self.get_simple_files_map(other_files)
|
||||||
ctags_msg = ""
|
ctags_msg = ""
|
||||||
num_tokens = self.token_count(files_listing)
|
num_tokens = self.token_count(files_listing)
|
||||||
if self.io:
|
if self.verbose:
|
||||||
self.io.tool_output(f"simple map: {num_tokens/1024:.1f} k-tokens")
|
self.io.tool_output(f"simple map: {num_tokens/1024:.1f} k-tokens")
|
||||||
if num_tokens < self.max_map_tokens:
|
if num_tokens < self.max_map_tokens:
|
||||||
return files_listing, ctags_msg
|
return files_listing, ctags_msg
|
||||||
|
@ -198,7 +200,7 @@ class RepoMap:
|
||||||
|
|
||||||
with tempfile.TemporaryDirectory() as tempdir:
|
with tempfile.TemporaryDirectory() as tempdir:
|
||||||
hello_py = os.path.join(tempdir, "hello.py")
|
hello_py = os.path.join(tempdir, "hello.py")
|
||||||
with open(hello_py, "w") as f:
|
with open(hello_py, "w", encoding="utf-8") as f:
|
||||||
f.write("def hello():\n print('Hello, world!')\n")
|
f.write("def hello():\n print('Hello, world!')\n")
|
||||||
self.run_ctags(hello_py)
|
self.run_ctags(hello_py)
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
|
@ -237,10 +239,8 @@ class RepoMap:
|
||||||
return idents
|
return idents
|
||||||
|
|
||||||
def get_name_identifiers_uncached(self, fname):
|
def get_name_identifiers_uncached(self, fname):
|
||||||
try:
|
content = self.io.read_text(fname)
|
||||||
with open(fname, "r") as f:
|
if content is None:
|
||||||
content = f.read()
|
|
||||||
except UnicodeDecodeError:
|
|
||||||
return list()
|
return list()
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -3,6 +3,7 @@ import tempfile
|
||||||
import unittest
|
import unittest
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
from aider.io import InputOutput
|
||||||
from aider.repomap import RepoMap
|
from aider.repomap import RepoMap
|
||||||
|
|
||||||
|
|
||||||
|
@ -21,7 +22,8 @@ class TestRepoMap(unittest.TestCase):
|
||||||
with open(os.path.join(temp_dir, file), "w") as f:
|
with open(os.path.join(temp_dir, file), "w") as f:
|
||||||
f.write("")
|
f.write("")
|
||||||
|
|
||||||
repo_map = RepoMap(root=temp_dir)
|
io = InputOutput()
|
||||||
|
repo_map = RepoMap(root=temp_dir, io=io)
|
||||||
other_files = [os.path.join(temp_dir, file) for file in test_files]
|
other_files = [os.path.join(temp_dir, file) for file in test_files]
|
||||||
result = repo_map.get_repo_map([], other_files)
|
result = repo_map.get_repo_map([], other_files)
|
||||||
|
|
||||||
|
@ -65,7 +67,8 @@ print(my_function(3, 4))
|
||||||
with open(os.path.join(temp_dir, test_file3), "w") as f:
|
with open(os.path.join(temp_dir, test_file3), "w") as f:
|
||||||
f.write(file_content3)
|
f.write(file_content3)
|
||||||
|
|
||||||
repo_map = RepoMap(root=temp_dir)
|
io = InputOutput()
|
||||||
|
repo_map = RepoMap(root=temp_dir, io=io)
|
||||||
other_files = [
|
other_files = [
|
||||||
os.path.join(temp_dir, test_file1),
|
os.path.join(temp_dir, test_file1),
|
||||||
os.path.join(temp_dir, test_file2),
|
os.path.join(temp_dir, test_file2),
|
||||||
|
@ -83,7 +86,7 @@ print(my_function(3, 4))
|
||||||
def test_check_for_ctags_failure(self):
|
def test_check_for_ctags_failure(self):
|
||||||
with patch("subprocess.run") as mock_run:
|
with patch("subprocess.run") as mock_run:
|
||||||
mock_run.side_effect = Exception("ctags not found")
|
mock_run.side_effect = Exception("ctags not found")
|
||||||
repo_map = RepoMap()
|
repo_map = RepoMap(io=InputOutput())
|
||||||
self.assertFalse(repo_map.has_ctags)
|
self.assertFalse(repo_map.has_ctags)
|
||||||
|
|
||||||
def test_check_for_ctags_success(self):
|
def test_check_for_ctags_success(self):
|
||||||
|
@ -100,7 +103,7 @@ print(my_function(3, 4))
|
||||||
b' status = main()$/", "kind": "variable"}'
|
b' status = main()$/", "kind": "variable"}'
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
repo_map = RepoMap()
|
repo_map = RepoMap(io=InputOutput())
|
||||||
self.assertTrue(repo_map.has_ctags)
|
self.assertTrue(repo_map.has_ctags)
|
||||||
|
|
||||||
def test_get_repo_map_without_ctags(self):
|
def test_get_repo_map_without_ctags(self):
|
||||||
|
@ -120,7 +123,7 @@ print(my_function(3, 4))
|
||||||
with open(os.path.join(temp_dir, file), "w") as f:
|
with open(os.path.join(temp_dir, file), "w") as f:
|
||||||
f.write("")
|
f.write("")
|
||||||
|
|
||||||
repo_map = RepoMap(root=temp_dir)
|
repo_map = RepoMap(root=temp_dir, io=InputOutput())
|
||||||
repo_map.has_ctags = False # force it off
|
repo_map.has_ctags = False # force it off
|
||||||
|
|
||||||
other_files = [os.path.join(temp_dir, file) for file in test_files]
|
other_files = [os.path.join(temp_dir, file) for file in test_files]
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue