From 9d13fadd4129e5c134f49ed8d9b1c331dddf6a91 Mon Sep 17 00:00:00 2001 From: Paul Gauthier Date: Tue, 19 Nov 2024 12:02:42 -0800 Subject: [PATCH] refactor: Simplify model settings configuration and remove tracking decorator --- aider/models.py | 44 +++++++------------------------------------- 1 file changed, 7 insertions(+), 37 deletions(-) diff --git a/aider/models.py b/aider/models.py index a4f90047b..afeedc955 100644 --- a/aider/models.py +++ b/aider/models.py @@ -63,19 +63,6 @@ claude-3-5-sonnet-20241022 ANTHROPIC_MODELS = [ln.strip() for ln in ANTHROPIC_MODELS.splitlines() if ln.strip()] -def track_init_fields(cls): - original_init = cls.__init__ - - @wraps(original_init) - def __init__(self, **kwargs): - self._set_fields = set(kwargs.keys()) - original_init(self, **kwargs) - - cls.__init__ = __init__ - return cls - - -@track_init_fields @dataclass class ModelSettings: # Model class needs to have each of these as well @@ -815,11 +802,8 @@ class Model(ModelSettings): self.weak_model = None self.editor_model = None - # Find default and override settings - self.default_model_settings = next( - (ms for ms in MODEL_SETTINGS if ms.name == "aider/default"), None - ) - self.override_model_settings = next( + # Find the extra settings + self.extra_model_settings = next( (ms for ms in MODEL_SETTINGS if ms.name == "aider/override"), None ) @@ -850,22 +834,13 @@ class Model(ModelSettings): def get_model_info(self, model): return model_info_manager.get_model_info(model) - def _copy_fields(self, source, skip_name=True): + def _copy_fields(self, source): """Helper to copy fields from a ModelSettings instance to self""" for field in fields(ModelSettings): - if skip_name and field.name == "name": - continue - # Only copy fields that were explicitly set in the source - if hasattr(source, "_set_fields") and field.name in source._set_fields: - val = getattr(source, field.name) - setattr(self, field.name, val) + val = getattr(source, field.name) + setattr(self, field.name, val) def configure_model_settings(self, model): - # Apply default settings first if they exist - if self.default_model_settings: - self._copy_fields(self.default_model_settings) - - dump(self.edit_format) # Look for exact model match exact_match = False for ms in MODEL_SETTINGS: @@ -875,20 +850,15 @@ class Model(ModelSettings): exact_match = True break # Continue to apply overrides - dump(self.edit_format) model = model.lower() # If no exact match, try generic settings if not exact_match: self.apply_generic_model_settings(model) - dump(self.edit_format) - # Apply override settings last if they exist - if self.override_model_settings: - self._copy_fields(self.override_model_settings) - - dump(self.edit_format) + if self.extra_model_settings: + # TODO: merge the self.extra_model_settings.extra_params dict into self.extra_params dict. don't remove existing entries, just add/update entries in extra_model_settings. be careful with sub-dicts: same things apply, preserve existing entries add/update def apply_generic_model_settings(self, model): if ("llama3" in model or "llama-3" in model) and "70b" in model: