diff --git a/tinygrad/nn/state.py b/tinygrad/nn/state.py index 7b340c4ec71e0..100584a427653 100644 --- a/tinygrad/nn/state.py +++ b/tinygrad/nn/state.py @@ -147,7 +147,8 @@ def __init__(self): print("WARNING: unused weights in state_dict", sorted(list(state_dict.keys() - model_state_dict.keys()))) for k,v in (t := tqdm(model_state_dict.items(), disable=None if verbose else True)): t.desc = f"ram used: {GlobalCounters.mem_used/1e9:5.2f} GB, {k:50s}: " - if k not in state_dict and not strict: + if k not in state_dict: + if strict: raise RuntimeError(f"missing key in state_dict: {k!r}") if DEBUG >= 1: print(f"WARNING: not loading {k}") continue if v.shape != state_dict[k].shape: