diff --git a/exir/delegate.py b/exir/delegate.py index 959bd4bb17c..076e08daf37 100644 --- a/exir/delegate.py +++ b/exir/delegate.py @@ -102,7 +102,7 @@ def fake_requires_grad(var): var.requires_grad = True return var - return pytree.tree_map(fake_requires_grad, res) + return pytree.tree_map_only(torch.Tensor, fake_requires_grad, res) return res