Refactor check_for_local_edits to get_last_modified and update last_modified in commit method.

This commit is contained in:
Paul Gauthier 2023-05-08 23:28:42 -07:00
parent af972538f1
commit b401c803e4

View file

@ -64,7 +64,6 @@ class Coder:
"[red bold]Will not automatically commit edits as they happen." "[red bold]Will not automatically commit edits as they happen."
) )
self.check_for_local_edits(True)
self.pretty = pretty self.pretty = pretty
def set_repo(self): def set_repo(self):
@ -168,15 +167,8 @@ class Coder:
readline.write_history_file(history_file) readline.write_history_file(history_file)
return inp return inp
def check_for_local_edits(self, init=False): def get_last_modified(self):
last_modified = max(Path(fname).stat().st_mtime for fname in self.fnames) return max(Path(fname).stat().st_mtime for fname in self.fnames)
since = last_modified - self.last_modified
self.last_modified = last_modified
if init:
return
if since > 0:
return True
return False
def get_files_messages(self): def get_files_messages(self):
files_content = prompts.files_content_prefix files_content = prompts.files_content_prefix
@ -218,7 +210,7 @@ class Coder:
self.num_control_c = 0 self.num_control_c = 0
if self.check_for_local_edits(): if self.last_modified < self.get_last_modified():
self.commit(ask=True) self.commit(ask=True)
# files changed, move cur messages back behind the files messages # files changed, move cur messages back behind the files messages
@ -506,6 +498,7 @@ class Coder:
diffs += these_diffs + "\n" diffs += these_diffs + "\n"
if not dirty_fnames: if not dirty_fnames:
self.last_modified = self.get_last_modified()
return return
self.console.print(Text(diffs)) self.console.print(Text(diffs))
@ -544,6 +537,8 @@ class Coder:
commit_message = prefix + commit_message commit_message = prefix + commit_message
if ask: if ask:
self.last_modified = self.get_last_modified()
self.console.print("[red]Files have uncommitted changes.\n") self.console.print("[red]Files have uncommitted changes.\n")
self.console.print(f"[red]Suggested commit message:\n{commit_message}\n") self.console.print(f"[red]Suggested commit message:\n{commit_message}\n")
@ -563,6 +558,9 @@ class Coder:
repo.git.commit("-m", full_commit_message, "--no-verify") repo.git.commit("-m", full_commit_message, "--no-verify")
commit_hash = repo.head.commit.hexsha[:7] commit_hash = repo.head.commit.hexsha[:7]
self.console.print(f"[green]{commit_hash} {commit_message}") self.console.print(f"[green]{commit_hash} {commit_message}")
self.last_modified = self.get_last_modified()
return commit_hash, commit_message return commit_hash, commit_message