@@ -141,7 +141,7 @@ bool LeanFormula::convertToLean(Expression& expr, int64_t offset,
141141 break ;
142142 }
143143 if (expr.name == " sqrtnint" ) {
144- // sqrtnint(x, n) -> sqrtnint x n (custom helper function )
144+ // sqrtnint(x, n) -> Int.ofNat (Nat.nthRoot (Int.toNat n) (Int.toNat x) )
145145 if (expr.children .size () != 2 ) {
146146 return false ;
147147 }
@@ -153,9 +153,13 @@ bool LeanFormula::convertToLean(Expression& expr, int64_t offset,
153153 if (ExpressionUtil::canBeNegative (root, offset)) {
154154 return false ; // negative root index not supported
155155 }
156- // Use helper function sqrtnint
157- expr.name = " sqrtnint" ;
158- helper_funcs.insert (" sqrtnint" );
156+ // Use Nat.nthRoot from Mathlib
157+ Expression baseToNat (Expression::Type::FUNCTION, " Int.toNat" , {base});
158+ Expression rootToNat (Expression::Type::FUNCTION, " Int.toNat" , {root});
159+ Expression nthRoot (Expression::Type::FUNCTION, " Nat.nthRoot" , {rootToNat, baseToNat});
160+ Expression ofNat (Expression::Type::FUNCTION, " Int.ofNat" , {nthRoot});
161+ expr = ofNat;
162+ imports.insert (" Mathlib.Data.Nat.NthRoot.Defs" );
159163 break ;
160164 }
161165 // Allow calls to locally defined functions and sequences
@@ -256,36 +260,6 @@ std::string LeanFormula::toString() const {
256260 auto functions = FormulaUtil::getDefinitions (main_formula);
257261
258262 std::stringstream buf;
259-
260- // Print helper function definitions first
261- if (!helper_funcs.empty ()) {
262- for (const auto & func : helper_funcs) {
263- if (func == " sqrtnint" ) {
264- // Integer nth root: floor(x^(1/n))
265- // We compute it by binary search to find largest y such that y^n <= x
266- buf << " def sqrtnint (x : Int) (n : Int) : Int :=\n " ;
267- buf << " if n <= 0 then 0\n " ;
268- buf << " else if x < 0 then 0\n " ;
269- buf << " else\n " ;
270- buf << " let xNat := Int.toNat x\n " ;
271- buf << " let nNat := Int.toNat n\n " ;
272- buf << " -- Binary search for largest y such that y^n <= x\n " ;
273- buf << " let rec search (low high : Nat) (fuel : Nat) : Nat :=\n " ;
274- buf << " match fuel with\n " ;
275- buf << " | 0 => low\n " ;
276- buf << " | fuel' + 1 =>\n " ;
277- buf << " if low >= high then low\n " ;
278- buf << " else\n " ;
279- buf << " let mid := (low + high + 1) / 2\n " ;
280- buf << " if mid ^ nNat <= xNat then\n " ;
281- buf << " search mid high fuel'\n " ;
282- buf << " else\n " ;
283- buf << " search low (mid - 1) fuel'\n " ;
284- buf << " Int.ofNat (search 0 xNat 64)\n\n " ;
285- }
286- }
287- }
288-
289263 if (functions.size () == 1 ) {
290264 buf << printFunction (functions[0 ]);
291265 } else {
0 commit comments