Patch training bfloat16 on MPS bug #1278

Merged
mrariden merged 2 commits into main from mps_bfloat16
Jul 23, 2025
Conversation

@mrariden
Collaborator

On MPS, the network can be evaluated in bfloat16, but backpropagation does not work in that dtype. This patch converts the network to float32 before the training loop and restores the original dtype once training finishes.
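The cast-train-restore pattern described here can be sketched as a small context manager. This is an illustration only, not the code from this PR: `train_in_dtype` is a hypothetical helper, and it assumes the model exposes `parameters()` and an in-place `to(dtype)` the way `torch.nn.Module` does, has at least one parameter, and uses a single dtype for all of them.

```python
from contextlib import contextmanager


@contextmanager
def train_in_dtype(model, train_dtype):
    """Temporarily cast `model` to `train_dtype`, then restore its original dtype.

    Hypothetical helper (not part of the patched library). Assumes `model`
    behaves like torch.nn.Module: `parameters()` yields objects with a
    `.dtype`, and `to(dtype)` casts the model in place.
    """
    # Remember the dtype the model had on entry (for example torch.bfloat16).
    original_dtype = next(iter(model.parameters())).dtype
    # Cast to the dtype used for training (for example torch.float32).
    model.to(train_dtype)
    try:
        yield model
    finally:
        # Restore the original dtype even if the training loop raised.
        model.to(original_dtype)
```

With PyTorch this would be used as `with train_in_dtype(model, torch.float32): ...`, with the training loop inside the `with` block. The optimizer should be created inside the block as well, so that its state matches the float32 parameters.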

Resolves #1263

@mrariden mrariden self-assigned this Jul 21, 2025
@mrariden
Collaborator Author

Resolves #1279

@mrariden mrariden merged commit de51608 into main Jul 23, 2025
7 checks passed

Development

Successfully merging this pull request may close these issues.

[BUG] save model fail for RuntimeError: "mse_cpu" not implemented for 'BFloat16'
[BUG] error when train a model