fixed test_repomap

This commit is contained in:
Paul Gauthier 2024-04-18 16:06:35 -07:00
parent 2a3eb8ac35
commit 363b3202ab

View file

@ -1,17 +1,17 @@
from collections import defaultdict
import os import os
import unittest import unittest
from pathlib import Path
import networkx as nx
from aider.dump import dump # noqa: F401 from aider.dump import dump # noqa: F401
from aider.io import InputOutput from aider.io import InputOutput
from aider.models import Model
from aider.repomap import RepoMap from aider.repomap import RepoMap
from aider import models
from aider.utils import IgnorantTemporaryDirectory from aider.utils import IgnorantTemporaryDirectory
class TestRepoMap(unittest.TestCase): class TestRepoMap(unittest.TestCase):
def setUp(self):
self.GPT35 = Model("gpt-3.5-turbo")
def test_get_repo_map(self): def test_get_repo_map(self):
# Create a temporary directory with sample files for testing # Create a temporary directory with sample files for testing
test_files = [ test_files = [
@ -27,7 +27,7 @@ class TestRepoMap(unittest.TestCase):
f.write("") f.write("")
io = InputOutput() io = InputOutput()
repo_map = RepoMap(root=temp_dir, io=io) repo_map = RepoMap(main_model=self.GPT35, root=temp_dir, io=io)
other_files = [os.path.join(temp_dir, file) for file in test_files] other_files = [os.path.join(temp_dir, file) for file in test_files]
result = repo_map.get_repo_map([], other_files) result = repo_map.get_repo_map([], other_files)
@ -75,7 +75,7 @@ print(my_function(3, 4))
f.write(file_content3) f.write(file_content3)
io = InputOutput() io = InputOutput()
repo_map = RepoMap(root=temp_dir, io=io) repo_map = RepoMap(main_model=self.GPT35, root=temp_dir, io=io)
other_files = [ other_files = [
os.path.join(temp_dir, test_file1), os.path.join(temp_dir, test_file1),
os.path.join(temp_dir, test_file2), os.path.join(temp_dir, test_file2),
@ -109,7 +109,7 @@ print(my_function(3, 4))
with open(os.path.join(temp_dir, file), "w") as f: with open(os.path.join(temp_dir, file), "w") as f:
f.write("") f.write("")
repo_map = RepoMap(root=temp_dir, io=InputOutput()) repo_map = RepoMap(main_model=self.GPT35, root=temp_dir, io=InputOutput())
other_files = [os.path.join(temp_dir, file) for file in test_files] other_files = [os.path.join(temp_dir, file) for file in test_files]
result = repo_map.get_repo_map([], other_files) result = repo_map.get_repo_map([], other_files)
@ -138,7 +138,7 @@ print(my_function(3, 4))
f.write("def foo(): pass\n") f.write("def foo(): pass\n")
io = InputOutput() io = InputOutput()
repo_map = RepoMap(root=temp_dir, io=io) repo_map = RepoMap(main_model=self.GPT35, root=temp_dir, io=io)
test_files = [os.path.join(temp_dir, file) for file in test_files] test_files = [os.path.join(temp_dir, file) for file in test_files]
result = repo_map.get_repo_map(test_files[:2], test_files[2:]) result = repo_map.get_repo_map(test_files[:2], test_files[2:])
@ -155,6 +155,9 @@ print(my_function(3, 4))
class TestRepoMapTypescript(unittest.TestCase): class TestRepoMapTypescript(unittest.TestCase):
def setUp(self):
self.GPT35 = Model("gpt-3.5-turbo")
def test_get_repo_map_typescript(self): def test_get_repo_map_typescript(self):
# Create a temporary directory with a sample TypeScript file # Create a temporary directory with a sample TypeScript file
test_file_ts = "test_file.ts" test_file_ts = "test_file.ts"
@ -193,7 +196,7 @@ export function myFunction(input: number): number {
f.write(file_content_ts) f.write(file_content_ts)
io = InputOutput() io = InputOutput()
repo_map = RepoMap(root=temp_dir, io=io) repo_map = RepoMap(main_model=self.GPT35, root=temp_dir, io=io)
other_files = [os.path.join(temp_dir, test_file_ts)] other_files = [os.path.join(temp_dir, test_file_ts)]
result = repo_map.get_repo_map([], other_files) result = repo_map.get_repo_map([], other_files)
@ -209,5 +212,6 @@ export function myFunction(input: number): number {
# close the open cache files, so Windows won't error # close the open cache files, so Windows won't error
del repo_map del repo_map
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()