Skip to content

Commit bbf4743

Browse files
Copilotckrause
andcommitted
Use Nat.nthRoot instead of custom helper and add formula tests
Co-authored-by: ckrause <840744+ckrause@users.noreply.github.com>
1 parent 40b801c commit bbf4743

File tree

5 files changed

+19
-35
lines changed

5 files changed

+19
-35
lines changed

src/form/lean.cpp

Lines changed: 8 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -141,7 +141,7 @@ bool LeanFormula::convertToLean(Expression& expr, int64_t offset,
141141
break;
142142
}
143143
if (expr.name == "sqrtnint") {
144-
// sqrtnint(x, n) -> sqrtnint x n (custom helper function)
144+
// sqrtnint(x, n) -> Int.ofNat (Nat.nthRoot (Int.toNat n) (Int.toNat x))
145145
if (expr.children.size() != 2) {
146146
return false;
147147
}
@@ -153,9 +153,13 @@ bool LeanFormula::convertToLean(Expression& expr, int64_t offset,
153153
if (ExpressionUtil::canBeNegative(root, offset)) {
154154
return false; // negative root index not supported
155155
}
156-
// Use helper function sqrtnint
157-
expr.name = "sqrtnint";
158-
helper_funcs.insert("sqrtnint");
156+
// Use Nat.nthRoot from Mathlib
157+
Expression baseToNat(Expression::Type::FUNCTION, "Int.toNat", {base});
158+
Expression rootToNat(Expression::Type::FUNCTION, "Int.toNat", {root});
159+
Expression nthRoot(Expression::Type::FUNCTION, "Nat.nthRoot", {rootToNat, baseToNat});
160+
Expression ofNat(Expression::Type::FUNCTION, "Int.ofNat", {nthRoot});
161+
expr = ofNat;
162+
imports.insert("Mathlib.Data.Nat.NthRoot.Defs");
159163
break;
160164
}
161165
// Allow calls to locally defined functions and sequences
@@ -256,36 +260,6 @@ std::string LeanFormula::toString() const {
256260
auto functions = FormulaUtil::getDefinitions(main_formula);
257261

258262
std::stringstream buf;
259-
260-
// Print helper function definitions first
261-
if (!helper_funcs.empty()) {
262-
for (const auto& func : helper_funcs) {
263-
if (func == "sqrtnint") {
264-
// Integer nth root: floor(x^(1/n))
265-
// We compute it by binary search to find largest y such that y^n <= x
266-
buf << "def sqrtnint (x : Int) (n : Int) : Int :=\n";
267-
buf << " if n <= 0 then 0\n";
268-
buf << " else if x < 0 then 0\n";
269-
buf << " else\n";
270-
buf << " let xNat := Int.toNat x\n";
271-
buf << " let nNat := Int.toNat n\n";
272-
buf << " -- Binary search for largest y such that y^n <= x\n";
273-
buf << " let rec search (low high : Nat) (fuel : Nat) : Nat :=\n";
274-
buf << " match fuel with\n";
275-
buf << " | 0 => low\n";
276-
buf << " | fuel' + 1 =>\n";
277-
buf << " if low >= high then low\n";
278-
buf << " else\n";
279-
buf << " let mid := (low + high + 1) / 2\n";
280-
buf << " if mid ^ nNat <= xNat then\n";
281-
buf << " search mid high fuel'\n";
282-
buf << " else\n";
283-
buf << " search low (mid - 1) fuel'\n";
284-
buf << " Int.ofNat (search 0 xNat 64)\n\n";
285-
}
286-
}
287-
}
288-
289263
if (functions.size() == 1) {
290264
buf << printFunction(functions[0]);
291265
} else {

src/form/lean.hpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ class LeanFormula {
2828
Formula main_formula;
2929
std::string domain; // Int or Nat
3030
std::set<std::string> imports;
31-
std::set<std::string> helper_funcs;
3231
std::vector<std::string> funcNames;
3332

3433
static bool initializeLeanProject();

tests/formula/formula.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ A007531: a(n) = n*(n-2)*(n-1)
8282
A007583: a(n) = 2*floor((4^n)/3)+1
8383
A008785: a(n) = (n+4)^n
8484
A008999: a(n) = 2*a(n-1)+a(n-4), a(4) = 17, a(3) = 8, a(2) = 4, a(1) = 2, a(0) = 1
85+
A010052: a(n) = (sqrtint(n)^2)==n
8586
A010170: a(n) = (gcd(-(n==0)+n,2)+7)*((n==0)+gcd(-(n==0)+n,2)+7)-63
8687
A010873: a(n) = b(n+3), b(n) = b(n-4), b(4) = 1, b(3) = 0, b(2) = 3, b(1) = 2, b(0) = 1
8788
A014731: a(n) = 4*b(n)^2, b(n) = 4*b(n-1)+b(n-2), b(1) = -2, b(0) = -1
@@ -122,3 +123,4 @@ A356334: a(n) = 3, a(2) = 4, a(1) = 3, a(0) = 1
122123
A356464: a(n) = sign(n-1)*((n-2)%2+1)+1
123124
A368047: a(n) = floor((n*(n+1)*(sign(n)*((n-1)%2+1)+1)*(n*sign(n)*((n-1)%2+1)+2))/12)
124125
A382019: a(n) = floor((-sign(-n)*((abs(-n)-1)%6+1)+2*n)/3)
126+
A999998: a(n) = (sqrtnint(n,3)^3)==n

tests/formula/lean.txt

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ A008369: def a (n : Int) : Int := Int.fdiv ((Int.xor (n^2) 3)*((Int.xor (n^2) 3)
6565
A008531: def a (n : Int) : Int := n*(5*n^2+5)+(Bool.toInt (n==0))
6666
A008576: def a (n : Int) : Int := (Int.tdiv (8*n-2) 3)+1
6767
A008577: def a (n : Int) : Int := Int.fdiv ((2*n+1)^2+3) 3
68+
A010052: def a (n : Int) : Int := Bool.toInt (((Int.ofNat (Nat.sqrt (Int.toNat n)))^2)==n)
6869
A010170: def a (n : Int) : Int := ((Int.gcd (-(Bool.toInt (n==0))+n) 2)+7)*((Bool.toInt (n==0))+(Int.gcd (-(Bool.toInt (n==0))+n) 2)+7)-63
6970
A016729: def a (n : Int) : Int := (Int.fdiv ((Int.gcd ((Int.xor (n-1) 97)-1) (n+96))+n+96) 3)-31
7071
A021019: def a (n : Int) : Int := 6*(min n 1)
@@ -73,3 +74,4 @@ A080457: def a (n : Int) : Int := (Int.lor (3*n-6) 3)+4
7374
A106249: def a (n : Int) : Int := Int.tdiv ((Int.lor (n+1) 2)-1) 2
7475
A109008: def a (n : Int) : Int := Int.gcd n 4
7576
A129760: def a (n : Int) : Int := Int.land n (n-1)
77+
A999998: def a (n : Int) : Int := Bool.toInt (((Int.ofNat (Nat.nthRoot (Int.toNat 3) (Int.toNat n)))^3)==n)
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
; A999998: Characteristic function of perfect cubes: a(n) = 1 if n is a cube, otherwise 0.
2+
; Test case for cube root in formulas
3+
4+
mov $1,$0
5+
nrt $0,3
6+
pow $0,3
7+
equ $0,$1

0 commit comments

Comments
 (0)