Fix: Prevent AttributeError when checking for UnifiedSolver in transform function#60
Conversation
|
|
||
| def transform(model, data, label, **transform_kwargs): | ||
| if isinstance(model, cebra.solver.UnifiedSolver): | ||
| if isinstance(model, BaseSolver): |
There was a problem hiding this comment.
same here, just use solver and cebra
CeliaBenquet
left a comment
There was a problem hiding this comment.
You can simplify naming here, just import them as they are called, no renaming :)
30a9a88 to
57d83d4
Compare
CeliaBenquet
left a comment
There was a problem hiding this comment.
lgtm! Just suggestion to remove the prints if not used.
|
Mm actually @anandawolz let's discuss the change, it is not correct as not all solvers have the labels in the transformer (only Unified). |
|
@anandawolz, regarding my last comment, how did you handle it? the def transform() in that wrapper takes a label, but that will not work for other solvers because they don't take a label. Have you tested that? |
|
@CeliaBenquet can you finalize this PR and merge? |
this was addressed here: 08ec240 |
The previous code accessed
cebra.solver.UnifiedSolverdirectly in anisinstance()check.In CEBRA versions where UnifiedSolver is not exposed as an attribute (e.g. sklearn models) of
cebra.solver, this caused anAttributeErrorbefore the check could fall through to the next condition.Change: Removed the direct attribute lookup of
UnifiedSolverto preventAttributeErrorin environments whereUnifiedSolveris unavailable.UnifiedSolvercheck is replaced byBaseSolver, which is already imported in a stable manner from cebra.solver.base.