From 4856fe737f54738c1fc7d6dc429c3e6f7388bcf5 Mon Sep 17 00:00:00 2001 From: stepan Date: Wed, 3 Jun 2026 09:45:07 +0200 Subject: [PATCH] Preserve defaultdict subclass copy type --- .../src/tests/test_dict.py | 33 +++++++++++++++++ .../src/tests/test_set.py | 37 +++++++++++++++++++ .../objects/dict/DefaultDictBuiltins.java | 21 +++++++++-- 3 files changed, 88 insertions(+), 3 deletions(-) diff --git a/graalpython/com.oracle.graal.python.test/src/tests/test_dict.py b/graalpython/com.oracle.graal.python.test/src/tests/test_dict.py index 3f6a93465a..314855f55a 100644 --- a/graalpython/com.oracle.graal.python.test/src/tests/test_dict.py +++ b/graalpython/com.oracle.graal.python.test/src/tests/test_dict.py @@ -38,6 +38,7 @@ # SOFTWARE. import unittest, sys +from collections import defaultdict graalpy_only = unittest.skipUnless(sys.implementation.name == "graalpy", "GraalPy-specific dict storage test") @@ -577,6 +578,38 @@ def test_copy(): assert set(d1.keys()) == {'a', 'b', 'c'} +def test_defaultdict_operations_subclass_preserve_type(): + class DefaultDictSubclass(defaultdict): + pass + + d = DefaultDictSubclass(int, a=1) + copied = d.copy() + merged = d | {"b": 2} + rmerged = {"b": 2} | d + + assert type(copied) is DefaultDictSubclass + assert copied.default_factory is int + assert dict(copied) == {"a": 1} + assert type(merged) is DefaultDictSubclass + assert merged.default_factory is int + assert dict(merged) == {"a": 1, "b": 2} + assert type(rmerged) is DefaultDictSubclass + assert rmerged.default_factory is int + assert dict(rmerged) == {"b": 2, "a": 1} + + +def test_dict_operations_return_builtin_dict_for_subclass(): + class DictSubclass(dict): + pass + + d = DictSubclass(a=1) + other = {"b": 2} + + assert type(d.copy()) is dict + assert type(d | other) is dict + assert type(other | d) is dict + + def test_keywords(): def modifying(**kwargs): kwargs["a"] = 10 diff --git a/graalpython/com.oracle.graal.python.test/src/tests/test_set.py b/graalpython/com.oracle.graal.python.test/src/tests/test_set.py index af1fceac19..4fdd291dc9 100644 --- a/graalpython/com.oracle.graal.python.test/src/tests/test_set.py +++ b/graalpython/com.oracle.graal.python.test/src/tests/test_set.py @@ -686,3 +686,40 @@ def test_set_iterator_reduce(): it = s.__iter__() it.__reduce__() assert [i for i in it] == [1, 2, 3] + + +def test_set_operations_return_builtin_set_for_subclass(): + class SetSubclass(set): + pass + + s = SetSubclass([1, 2]) + other = {2, 3} + + assert type(s.copy()) is set + assert type(s | other) is set + assert type(other | s) is set + assert type(s & other) is set + assert type(s - other) is set + assert type(s ^ other) is set + assert type(s.union(other)) is set + assert type(s.intersection(other)) is set + assert type(s.difference(other)) is set + assert type(s.symmetric_difference(other)) is set + + +def test_frozenset_operations_return_builtin_frozenset_for_subclass(): + class FrozenSetSubclass(frozenset): + pass + + f = FrozenSetSubclass([1, 2]) + other = {2, 3} + + assert type(f.copy()) is frozenset + assert type(f | other) is frozenset + assert type(f & other) is frozenset + assert type(f - other) is frozenset + assert type(f ^ other) is frozenset + assert type(f.union(other)) is frozenset + assert type(f.intersection(other)) is frozenset + assert type(f.difference(other)) is frozenset + assert type(f.symmetric_difference(other)) is frozenset diff --git a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/dict/DefaultDictBuiltins.java b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/dict/DefaultDictBuiltins.java index 4bd39b2182..5c4d891d71 100644 --- a/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/dict/DefaultDictBuiltins.java +++ b/graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/dict/DefaultDictBuiltins.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2021, 2026, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * The Universal Permissive License (UPL), Version 1.0 @@ -78,6 +78,7 @@ import com.oracle.graal.python.nodes.object.GetClassNode; import com.oracle.graal.python.runtime.object.PFactory; import com.oracle.graal.python.util.PythonUtils; +import com.oracle.truffle.api.HostCompilerDirectives.InliningCutoff; import com.oracle.truffle.api.dsl.Bind; import com.oracle.truffle.api.dsl.Cached; import com.oracle.truffle.api.dsl.Fallback; @@ -149,13 +150,27 @@ static Object reduce(VirtualFrame frame, PDefaultDict self, @Builtin(name = "copy", minNumOfPositionalArgs = 1) @GenerateNodeFactory public abstract static class CopyNode extends PythonUnaryBuiltinNode { - @Specialization - static PDefaultDict copy(@SuppressWarnings("unused") VirtualFrame frame, PDefaultDict self, + @Specialization(guards = "isBuiltinDefaultDict(self)") + static PDefaultDict copyBuiltin(PDefaultDict self, @Bind Node inliningTarget, @Cached HashingStorageCopy copyNode, @Bind PythonLanguage language) { return PFactory.createDefaultDict(language, self.getDefaultFactory(), copyNode.execute(inliningTarget, self.getDictStorage())); } + + @Fallback + @InliningCutoff + static Object copyGeneric(VirtualFrame frame, Object self, + @Bind Node inliningTarget, + @Cached GetClassNode getClassNode, + @Cached CallNode callNode) { + PDefaultDict defaultDict = (PDefaultDict) self; + return callNode.execute(frame, getClassNode.execute(inliningTarget, defaultDict), defaultDict.getDefaultFactory(), defaultDict); + } + + static boolean isBuiltinDefaultDict(PDefaultDict self) { + return self.getPythonClass() == PythonBuiltinClassType.PDefaultDict; + } } @Builtin(name = J___MISSING__, minNumOfPositionalArgs = 2)