Skip to content

Commit 5e594a4

Browse files
committed
fix(Translate): don't look in arguments of free variables in shouldTranslate (leanprover-community#35549)
This PR addresses a `to_additive` limitation reported at [#mathlib4 > leanprover-community#35179 breaks to_additive call in FLT](https://leanprover.zulipchat.com/#narrow/channel/287929-mathlib4/topic/.2335179.20breaks.20to_additive.20call.20in.20FLT/with/574823087) by making sure `shouldTranslate` doesn't looks at the arguments of free variable applications This PR refactors `shouldTranslate` in a few ways - The monad is changed from `OptionT` to `ExceptT`, with the underlying type changing from `Option Expr` to `ExceptT Expr Unit`. Although this is technically a bit less efficient, it makes the program shorter and more intuitive, because a failure now actually corresponds to a monadic failure instead of a monadic success. - When visiting a `Expr.app` node, we now loop through the whole application at once, instead of doing this step by step and checking the cache at each step. This allows us to detect when the head is a free variable, and skip its arguments. It also lets us deal with `ignoreArgsAttr` more naturally. - This change removes the need for the `inApp : Bool` argument that used to be passed around in order to tell whether a constant was in an application.
1 parent 2a6bde3 commit 5e594a4

File tree

4 files changed

+44
-39
lines changed

4 files changed

+44
-39
lines changed

Mathlib/Tactic/Translate/Core.lean

Lines changed: 27 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -350,56 +350,53 @@ def ReplacementM.run {α} (dontTranslate allFVars : Array FVarId) (x : Replaceme
350350
return (a, allFVars.findIdx? relevantFVars.contains)
351351

352352
/-- Implementation function for `shouldTranslate`.
353-
Failure means that in that subexpression there is no constant that blocks `e` from being translated.
353+
Returning `none` means that `e` contains no constant that blocks translation.
354354
We cache previous applications of the function, using an expression cache using ptr equality
355-
to avoid visiting the same subexpression many times. Note that we only need to cache the
356-
expressions without taking the value of `inApp` into account, since `inApp` only matters when
357-
the expression is a constant. However, for this reason we have to make sure that we never
358-
cache constant expressions, so that's why the `if`s in the implementation are in this order.
355+
to avoid visiting the same subexpression many times.
359356
360357
Note that this function is still called many times by `applyReplacementFun`
361358
and we're not remembering the cache between these calls. -/
362359
private unsafe def shouldTranslateUnsafe (env : Environment) (t : TranslateData) (e : Expr) :
363360
ReplacementM (Option Expr) := do
364361
let visitedFVars : IO.Ref (Array FVarId) ← IO.mkRef #[]
365362
let dontTranslate ← read
366-
let rec visit (e : Expr) (inApp := false) : OptionT (StateT (PtrSet Expr) BaseIO) Expr := do
367-
if e.isConst then
368-
let doTranslate :=
369-
(t.doTranslateAttr.find? env e.constName!).getD <|
370-
inApp || (findTranslation? env t e.constName).isSome
371-
if doTranslate then failure else return e
363+
let rec visit (e : Expr) : ExceptT Expr (StateT (PtrSet Expr) BaseIO) Unit := do
372364
if (← get).contains e then
373-
failure
365+
return
374366
modify fun s => s.insert e
375367
match e with
376-
| .app e a =>
377-
visit e true <|> do
378-
if let some n := e.getAppFn.constName? then
379-
if let some l := t.ignoreArgsAttr.find? env n then
380-
if e.getAppNumArgs + 1 ∈ l then
381-
failure
382-
visit a
368+
| .app .. => e.withApp fun f args ↦ do
369+
match f with
370+
| .const n _ =>
371+
-- A constant in an application, e.g. `Prod` in `α × β`, is translated by default.
372+
let doTranslate := (t.doTranslateAttr.find? env n).getD true
373+
unless doTranslate do throw e
374+
let l := (t.ignoreArgsAttr.find? env n).getD []
375+
args.size.forM fun i _ ↦ do
376+
if !l.contains i then visit args[i]
377+
| .fvar .. => visit f -- We don't look in the arguments of free variables.
378+
| _ => visit f; args.forM visit
379+
| .const n _ =>
380+
-- A constant not in an application, e.g. `ℕ`, is not translated by default.
381+
let doTranslate := (t.doTranslateAttr.find? env n).getD (findTranslation? env t n).isSome
382+
unless doTranslate do throw e
383383
| .lam _ _ t _ => visit t
384384
| .forallE _ _ t _ => visit t
385-
| .letE _ _ e body _ => visit e <|> visit body
385+
| .letE _ _ e body _ => visit e; visit body
386386
| .mdata _ b => visit b
387387
| .proj _ _ b => visit b
388388
| .fvar fvarId =>
389389
if dontTranslate.contains fvarId then
390-
return e
391-
else
392-
visitedFVars.modify (·.push fvarId)
393-
failure
390+
throw e
391+
visitedFVars.modify (·.push fvarId)
394392
/- We do not translate the order on `Prop`.
395393
TODO: We also don't want to translate the category on `Type u`. Unfortunately, replacing
396394
`.sort 0` with `.sort _` here breaks some uses of `to_additive` on `MonCat`. -/
397-
| .sort 0 => return e
398-
| _ => failure
399-
let x ← (visit e).run' mkPtrSet
400-
match x with
401-
| some e => return some e
402-
| none =>
395+
| .sort 0 => throw e
396+
| _ => pure ()
397+
match ← (visit e).run' mkPtrSet with
398+
| .error e => return some e
399+
| .ok () =>
403400
/- In the case that we do translate, we mark the visited free variables as relevant for
404401
the translation by inserting them into the state. -/
405402
modify (·.insertMany (← visitedFVars.get))

Mathlib/Tactic/Translate/ToAdditive.lean

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -257,10 +257,10 @@ initialize ignoreArgsAttr : NameMapExtension (List Nat) ←
257257
descr :=
258258
"Auxiliary attribute for `to_additive` stating that certain arguments are not additivized."
259259
add := fun _ stx ↦ do
260-
let ids ← match stx with
261-
| `(attr| to_additive_ignore_args $[$ids:num]*) => pure <| ids.map (·.1.isNatLit?.get!)
262-
| _ => throwUnsupportedSyntax
263-
return ids.toList }
260+
let ids ← match stx with
261+
| `(attr| to_additive_ignore_args $[$ids:num]*) => pure <| ids.map (·.getNat - 1)
262+
| _ => throwUnsupportedSyntax
263+
return ids.toList }
264264

265265
@[inherit_doc TranslateData.doTranslateAttr]
266266
initialize doTranslateAttr : NameMapExtension Bool ← registerNameMapExtension _

Mathlib/Tactic/Translate/ToDual.lean

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,11 @@ initialize ignoreArgsAttr : NameMapExtension (List Nat) ←
112112
name := `to_dual_ignore_args
113113
descr :=
114114
"Auxiliary attribute for `to_dual` stating that certain arguments are not dualized."
115-
add := fun _ stx ↦ do
116-
let ids ← match stx with
117-
| `(attr| to_dual_ignore_args $[$ids:num]*) => pure <| ids.map (·.1.isNatLit?.get!)
118-
| _ => throwUnsupportedSyntax
119-
return ids.toList }
115+
add := fun _ stx ↦ do
116+
let ids ← match stx with
117+
| `(attr| to_dual_ignore_args $[$ids:num]*) => pure <| ids.map (·.getNat - 1)
118+
| _ => throwUnsupportedSyntax
119+
return ids.toList }
120120

121121
@[inherit_doc TranslateData.unfoldBoundaries?]
122122
initialize unfoldBoundaries : UnfoldBoundaryExt ← registerUnfoldBoundaryExt

MathlibTest/toAdditive.lean

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -863,3 +863,11 @@ attribute [to_additive someOtherTranslation] abstractMul
863863
-- Test that we don't blindly translate the prefix of a name.
864864
def Mul.test : Nat := 5
865865
@[to_additive] def Mul.test' := Mul.test
866+
867+
-- Test that arguments of free variables aren't considered by `shouldTranslate`
868+
@[to_additive_dont_translate]
869+
def dontTranslateId {α} : α → α := id
870+
871+
@[to_additive]
872+
theorem functionTypeMonoid {ι : Type*} {R : ι → Type*} [(i : ι) → Monoid (R i)] (i : ι)
873+
(a : R (dontTranslateId i)) : a * a = a * a := rfl

0 commit comments

Comments
 (0)