11from datetime import datetime
2- from typing import Any , ClassVar , Dict , List , Optional , Type , TypeVar
2+ from typing import Any , ClassVar , Dict , List , Optional , Type , TypeVar , cast
33
44from loguru import logger
55from pydantic import BaseModel , ValidationError , field_validator
@@ -29,15 +29,26 @@ class ObjectModel(BaseModel):
2929 @classmethod
3030 def get_all (cls : Type [T ], order_by = None ) -> List [T ]:
3131 try :
32+ # If called from a specific subclass, use its table_name
33+ if cls .table_name :
34+ target_class = cls
35+ table_name = cls .table_name
36+ else :
37+ # This path is taken if called directly from ObjectModel
38+ raise InvalidInputError (
39+ "get_all() must be called from a specific model class"
40+ )
41+
3242 if order_by :
3343 order = f" ORDER BY { order_by } "
3444 else :
3545 order = ""
36- result = repo_query (f"SELECT * FROM { cls .table_name } { order } " )
46+
47+ result = repo_query (f"SELECT * FROM { table_name } { order } " )
3748 objects = []
3849 for obj in result :
3950 try :
40- objects .append (cls (** obj ))
51+ objects .append (target_class (** obj ))
4152 except Exception as e :
4253 logger .critical (f"Error creating object: { str (e )} " )
4354
@@ -52,15 +63,44 @@ def get(cls: Type[T], id: str) -> T:
5263 if not id :
5364 raise InvalidInputError ("ID cannot be empty" )
5465 try :
66+ # Get the table name from the ID (everything before the first colon)
67+ table_name = id .split (":" )[0 ] if ":" in id else id
68+
69+ # If we're calling from a specific subclass and IDs match, use that class
70+ if cls .table_name and cls .table_name == table_name :
71+ target_class : Type [T ] = cls
72+ else :
73+ # Otherwise, find the appropriate subclass based on table_name
74+ found_class = cls ._get_class_by_table_name (table_name )
75+ if not found_class :
76+ raise InvalidInputError (f"No class found for table { table_name } " )
77+ target_class = cast (Type [T ], found_class )
78+
5579 result = repo_query (f"SELECT * FROM { id } " )
5680 if result :
57- return cls (** result [0 ])
81+ return target_class (** result [0 ])
5882 else :
59- raise NotFoundError (f"{ cls . table_name } with id { id } not found" )
83+ raise NotFoundError (f"{ table_name } with id { id } not found" )
6084 except Exception as e :
61- logger .error (f"Error fetching { cls . table_name } with id { id } : { str (e )} " )
85+ logger .error (f"Error fetching object with id { id } : { str (e )} " )
6286 logger .exception (e )
63- raise NotFoundError (f"{ cls .table_name } with id { id } not found" )
87+ raise NotFoundError (f"Object with id { id } not found - { str (e )} " )
88+
89+ @classmethod
90+ def _get_class_by_table_name (cls , table_name : str ) -> Optional [Type ["ObjectModel" ]]:
91+ """Find the appropriate subclass based on table_name."""
92+
93+ def get_all_subclasses (c : Type ["ObjectModel" ]) -> List [Type ["ObjectModel" ]]:
94+ all_subclasses : List [Type ["ObjectModel" ]] = []
95+ for subclass in c .__subclasses__ ():
96+ all_subclasses .append (subclass )
97+ all_subclasses .extend (get_all_subclasses (subclass ))
98+ return all_subclasses
99+
100+ for subclass in get_all_subclasses (ObjectModel ):
101+ if hasattr (subclass , "table_name" ) and subclass .table_name == table_name :
102+ return subclass
103+ return None
64104
65105 def needs_embedding (self ) -> bool :
66106 return False
0 commit comments