Skip to content

Commit 3d80751

Browse files
authored
Support rewriting of comprehension scopes (IronLanguages#1083)
* Support rewriting of comprehension scopes * Revert "Workaround for IronLanguages#817 (IronLanguages#818)" * Cleanup
1 parent 11e14a8 commit 3d80751

File tree

5 files changed

+99
-19
lines changed

5 files changed

+99
-19
lines changed

Src/IronPython/Compiler/Ast/Comprehension.cs

Lines changed: 21 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,15 @@ public override Ast Reduce() {
6060
res
6161
);
6262
}
63+
64+
internal ComprehensionScope Scope { get; private protected set; }
65+
66+
internal Comprehension CopyForRewrite(ComprehensionScope scope) {
67+
var newComprehension = (Comprehension)MemberwiseClone();
68+
newComprehension.Scope = scope;
69+
newComprehension.Parent = scope.Parent;
70+
return newComprehension;
71+
}
6372
}
6473

6574
public sealed class ListComprehension : Comprehension {
@@ -106,10 +115,8 @@ public override void Walk(PythonWalker walker) {
106115
}
107116
walker.PostWalk(this);
108117
}
109-
110-
internal ComprehensionScope Scope { get; }
111118
}
112-
119+
113120
public sealed class SetComprehension : Comprehension {
114121
private readonly ComprehensionIterator[] _iterators;
115122

@@ -154,8 +161,6 @@ public override void Walk(PythonWalker walker) {
154161
}
155162
walker.PostWalk(this);
156163
}
157-
158-
internal ComprehensionScope Scope { get; }
159164
}
160165

161166
public sealed class DictionaryComprehension : Comprehension {
@@ -207,19 +212,17 @@ public override void Walk(PythonWalker walker) {
207212
}
208213
walker.PostWalk(this);
209214
}
210-
211-
internal ComprehensionScope Scope { get; }
212215
}
213216

214217
/// <summary>
215218
/// Scope for the comprehension. Because scopes are usually statements and comprehensions are expressions
216219
/// this doesn't actually show up in the AST hierarchy and instead hangs off the comprehension expression.
217220
/// </summary>
218221
internal class ComprehensionScope : ScopeStatement {
219-
private readonly Expression _comprehension;
222+
private readonly Comprehension _comprehension;
220223
private static readonly MSAst.ParameterExpression _compContext = Ast.Parameter(typeof(CodeContext), "$compContext");
221224

222-
public ComprehensionScope(Expression comprehension) {
225+
public ComprehensionScope(Comprehension comprehension) {
223226
_comprehension = comprehension;
224227
}
225228

@@ -229,12 +232,12 @@ internal override bool ExposesLocalVariable(PythonVariable variable) {
229232
} else if (variable.Scope == this) {
230233
return false;
231234
}
232-
return _comprehension.Parent.ExposesLocalVariable(variable);
235+
return Parent.ExposesLocalVariable(variable);
233236
}
234237

235238
internal override MSAst.Expression/*!*/ GetParentClosureTuple() {
236239
Debug.Assert(NeedsLocalContext);
237-
return MSAst.Expression.Call(null, typeof(PythonOps).GetMethod(nameof(PythonOps.GetClosureTupleFromContext)), _comprehension.Parent.LocalContext);
240+
return MSAst.Expression.Call(null, typeof(PythonOps).GetMethod(nameof(PythonOps.GetClosureTupleFromContext)), Parent.LocalContext);
238241
}
239242

240243
internal override bool TryBindOuter(ScopeStatement from, PythonReference reference, out PythonVariable variable) {
@@ -269,7 +272,7 @@ internal override PythonVariable BindReference(PythonNameBinder binder, PythonRe
269272
}
270273

271274
// then bind in our parent scope
272-
return _comprehension.Parent.BindReference(binder, reference);
275+
return Parent.BindReference(binder, reference);
273276
}
274277

275278
internal override Ast GetVariableExpression(PythonVariable variable) {
@@ -281,21 +284,23 @@ internal override Ast GetVariableExpression(PythonVariable variable) {
281284
return expr;
282285
}
283286

284-
return _comprehension.Parent.GetVariableExpression(variable);
287+
return Parent.GetVariableExpression(variable);
285288
}
286289

287290
internal override Microsoft.Scripting.Ast.LightLambdaExpression GetLambda()
288291
=> throw new NotImplementedException();
289292

290-
public override void Walk(PythonWalker walker) => _comprehension.Walk(walker);
293+
public override void Walk(PythonWalker walker) {
294+
_comprehension.Walk(walker);
295+
}
291296

292297
internal override Ast LocalContext {
293298
get {
294299
if (NeedsLocalContext) {
295300
return _compContext;
296301
}
297302

298-
return _comprehension.Parent.LocalContext;
303+
return Parent.LocalContext;
299304
}
300305
}
301306

@@ -311,7 +316,7 @@ internal Ast AddVariables(Ast expression) {
311316
CreateVariables(locals, body);
312317

313318
if (localContext != null) {
314-
var createLocal = CreateLocalContext(_comprehension.Parent.LocalContext);
319+
var createLocal = CreateLocalContext(Parent.LocalContext);
315320
body.Add(Ast.Assign(_compContext, createLocal));
316321
}
317322

Src/IronPython/Compiler/Ast/PythonAst.cs

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -699,7 +699,7 @@ internal void PrepareScope(ReadOnlyCollectionBuilder<MSAst.ParameterExpression>
699699

700700
/// <summary>
701701
/// Rewrites the tree for performing lookups against globals instead of being bound
702-
/// against the optimized scope. This is used if the user compiles optimied code and then
702+
/// against the optimized scope. This is used if the user compiles optimized code and then
703703
/// runs it against a different scope.
704704
/// </summary>
705705
internal PythonAst MakeLookupCode() {
@@ -739,7 +739,7 @@ protected override MSAst.Expression VisitExtension(MSAst.Expression node) {
739739
return PythonAst._globalContext;
740740
}
741741

742-
// we need to re-write nested scoeps
742+
// we need to re-write nested scopes
743743
if (node is ScopeStatement scope) {
744744
return base.VisitExtension(VisitScope(scope));
745745
}
@@ -752,6 +752,10 @@ protected override MSAst.Expression VisitExtension(MSAst.Expression node) {
752752
return base.VisitExtension(new GeneratorExpression((FunctionDefinition)VisitScope(generator.Function), generator.Iterable));
753753
}
754754

755+
if (node is Comprehension comprehension) {
756+
return VisitComprehension(comprehension);
757+
}
758+
755759
// update the global get/set/raw gets variables
756760
if (node is PythonGlobalVariableExpression global) {
757761
return new LookupGlobalVariable(
@@ -803,6 +807,22 @@ private ScopeStatement VisitScope(ScopeStatement scope) {
803807
}
804808
return newScope;
805809
}
810+
811+
private MSAst.Expression VisitComprehension(Comprehension comprehension) {
812+
var newScope = (ComprehensionScope)comprehension.Scope.CopyForRewrite();
813+
newScope.Parent = _curScope;
814+
var newComprehension = comprehension.CopyForRewrite(newScope);
815+
816+
ScopeStatement prevScope = _curScope;
817+
try {
818+
// rewrite the comprehension in a new scope
819+
_curScope = newScope;
820+
821+
return base.VisitExtension(newComprehension);
822+
} finally {
823+
_curScope = prevScope;
824+
}
825+
}
806826
}
807827

808828
#endregion

Src/IronPythonTest/EngineTest.cs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -618,6 +618,26 @@ public static void ScenarioTryGetMember() {
618618
Assert.AreEqual(result.ToString(), "IronPython.Runtime.Types.BuiltinFunction");
619619
}
620620

621+
[Test]
622+
public static void ScenarioComprehensionScope() {
623+
var engine = Python.CreateEngine();
624+
var scope = engine.CreateScope();
625+
var source = engine.CreateScriptSourceFromString("assert [lambda: i for i in [1,2]][0]() == 2");
626+
var compiledCode = source.Compile();
627+
compiledCode.Execute(scope);
628+
}
629+
630+
[Test]
631+
public static void ScenarioComprehensionScopeGlobal() {
632+
var engine = Python.CreateEngine();
633+
var source = engine.CreateScriptSourceFromString("assert [lambda: i for i in global_list][0]() == res");
634+
var compiledCode = source.Compile();
635+
var scope = engine.CreateScope();
636+
scope.SetVariable("global_list", new[] { 1, 2 });
637+
scope.SetVariable("res", 2);
638+
compiledCode.Execute(scope);
639+
}
640+
621641
[Test]
622642
public static void ScenarioInterfaceExtensions() {
623643
var engine = Python.CreateEngine();

Src/StdLib/Lib/sre_compile.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@
6464
)
6565

6666
# Maps the lowercase code to lowercase codes which have the same uppercase.
67-
_ignorecase_fixes = {i: tuple([j for j in t if i != j]) # https://github.com/IronLanguages/ironpython3/issues/817
67+
_ignorecase_fixes = {i: tuple(j for j in t if i != j)
6868
for t in _equivalences for i in t}
6969

7070
def _compile(code, pattern, flags):

Tests/test_closure.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# The .NET Foundation licenses this file to you under the Apache 2.0 License.
33
# See the LICENSE file in the project root for more information.
44

5+
import types
56
import unittest
67

78
from iptest import run_test
@@ -322,4 +323,38 @@ class C:
322323
exec("z = 7", globals())
323324
self.assertEqual(z, 7)
324325

326+
def test_gh817(self):
327+
# https://github.com/IronLanguages/ironpython3/issues/817
328+
329+
for x in [lambda: i for i in [1,2]]:
330+
self.assertEqual(x(), 2)
331+
332+
self.assertEqual([tuple(i for j in [1]) for i in [1]], [(1,)])
333+
self.assertEqual({tuple(i for j in [1]) for i in [1]}, {(1,)})
334+
335+
self.assertEqual({i: tuple(j for j in t if i != j) for t in ((1,2),) for i in t}, {1: (2,), 2: (1,)})
336+
337+
self.assertEqual([lambda: i for i in [1]][0](), 1)
338+
339+
self.assertEqual([x * a for a in range(3) if a == 2 for x in range(5,7)], [10, 12])
340+
self.assertRaises(UnboundLocalError, lambda: [x * a for a in range(3) if x == 2 for x in range(5,7)])
341+
342+
def foo1(z):
343+
return [(x for x in range(y)) for y in range(z)]
344+
self.assertEqual([list(x) for x in foo1(3)], [[], [0], [0, 1]])
345+
346+
def foo2(z):
347+
return [(x + z for x in range(y)) for y in range(z)]
348+
self.assertEqual([list(x) for x in foo2(3)], [[], [3], [3, 4]])
349+
350+
dl = [{ y: (z**2 for z in range(x)) for y in range(3)} for x in range(4)]
351+
self.assertIsInstance(dl, list)
352+
self.assertEqual(len(dl), 4)
353+
for p, d in enumerate(dl):
354+
self.assertIsInstance(d, dict)
355+
for k in range(3):
356+
g = d[k]
357+
self.assertIsInstance(g, types.GeneratorType)
358+
self.assertEqual(list(g), [0, 1, 4][:p])
359+
325360
run_test(__name__)

0 commit comments

Comments
 (0)