Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions graalpython/com.oracle.graal.python.test/src/tests/test_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
# SOFTWARE.

import unittest, sys
from collections import defaultdict

graalpy_only = unittest.skipUnless(sys.implementation.name == "graalpy", "GraalPy-specific dict storage test")

Expand Down Expand Up @@ -577,6 +578,38 @@ def test_copy():
assert set(d1.keys()) == {'a', 'b', 'c'}


def test_defaultdict_operations_subclass_preserve_type():
class DefaultDictSubclass(defaultdict):
pass

d = DefaultDictSubclass(int, a=1)
copied = d.copy()
merged = d | {"b": 2}
rmerged = {"b": 2} | d

assert type(copied) is DefaultDictSubclass
assert copied.default_factory is int
assert dict(copied) == {"a": 1}
assert type(merged) is DefaultDictSubclass
assert merged.default_factory is int
assert dict(merged) == {"a": 1, "b": 2}
assert type(rmerged) is DefaultDictSubclass
assert rmerged.default_factory is int
assert dict(rmerged) == {"b": 2, "a": 1}


def test_dict_operations_return_builtin_dict_for_subclass():
class DictSubclass(dict):
pass

d = DictSubclass(a=1)
other = {"b": 2}

assert type(d.copy()) is dict
assert type(d | other) is dict
assert type(other | d) is dict


def test_keywords():
def modifying(**kwargs):
kwargs["a"] = 10
Expand Down
37 changes: 37 additions & 0 deletions graalpython/com.oracle.graal.python.test/src/tests/test_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,3 +686,40 @@ def test_set_iterator_reduce():
it = s.__iter__()
it.__reduce__()
assert [i for i in it] == [1, 2, 3]


def test_set_operations_return_builtin_set_for_subclass():
class SetSubclass(set):
pass

s = SetSubclass([1, 2])
other = {2, 3}

assert type(s.copy()) is set
assert type(s | other) is set
assert type(other | s) is set
assert type(s & other) is set
assert type(s - other) is set
assert type(s ^ other) is set
assert type(s.union(other)) is set
assert type(s.intersection(other)) is set
assert type(s.difference(other)) is set
assert type(s.symmetric_difference(other)) is set


def test_frozenset_operations_return_builtin_frozenset_for_subclass():
class FrozenSetSubclass(frozenset):
pass

f = FrozenSetSubclass([1, 2])
other = {2, 3}

assert type(f.copy()) is frozenset
assert type(f | other) is frozenset
assert type(f & other) is frozenset
assert type(f - other) is frozenset
assert type(f ^ other) is frozenset
assert type(f.union(other)) is frozenset
assert type(f.intersection(other)) is frozenset
assert type(f.difference(other)) is frozenset
assert type(f.symmetric_difference(other)) is frozenset
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2021, 2025, Oracle and/or its affiliates. All rights reserved.
* Copyright (c) 2021, 2026, Oracle and/or its affiliates. All rights reserved.
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
*
* The Universal Permissive License (UPL), Version 1.0
Expand Down Expand Up @@ -78,6 +78,7 @@
import com.oracle.graal.python.nodes.object.GetClassNode;
import com.oracle.graal.python.runtime.object.PFactory;
import com.oracle.graal.python.util.PythonUtils;
import com.oracle.truffle.api.HostCompilerDirectives.InliningCutoff;
import com.oracle.truffle.api.dsl.Bind;
import com.oracle.truffle.api.dsl.Cached;
import com.oracle.truffle.api.dsl.Fallback;
Expand Down Expand Up @@ -149,13 +150,27 @@ static Object reduce(VirtualFrame frame, PDefaultDict self,
@Builtin(name = "copy", minNumOfPositionalArgs = 1)
@GenerateNodeFactory
public abstract static class CopyNode extends PythonUnaryBuiltinNode {
@Specialization
static PDefaultDict copy(@SuppressWarnings("unused") VirtualFrame frame, PDefaultDict self,
@Specialization(guards = "isBuiltinDefaultDict(self)")
static PDefaultDict copyBuiltin(PDefaultDict self,
@Bind Node inliningTarget,
@Cached HashingStorageCopy copyNode,
@Bind PythonLanguage language) {
return PFactory.createDefaultDict(language, self.getDefaultFactory(), copyNode.execute(inliningTarget, self.getDictStorage()));
}

@Fallback
@InliningCutoff
static Object copyGeneric(VirtualFrame frame, Object self,
@Bind Node inliningTarget,
@Cached GetClassNode getClassNode,
@Cached CallNode callNode) {
PDefaultDict defaultDict = (PDefaultDict) self;
return callNode.execute(frame, getClassNode.execute(inliningTarget, defaultDict), defaultDict.getDefaultFactory(), defaultDict);
}

static boolean isBuiltinDefaultDict(PDefaultDict self) {
return self.getPythonClass() == PythonBuiltinClassType.PDefaultDict;
}
}

@Builtin(name = J___MISSING__, minNumOfPositionalArgs = 2)
Expand Down
Loading