Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions test/abstractarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1260,6 +1260,10 @@ Base.cconvert(::Type{Ptr{T}}, S::Strider{T}) where {T} = memoryref(S.data.ref, S

@testset "Simple 3d strided views and permutes" for sz in ((5, 3, 2), (7, 11, 13))
A = collect(reshape(1:prod(sz), sz))
# The following test takes pointers from A, we need to ensure A is not moved by GC.
# Furthermore, as pointer() returns the buffer address, we need to ensure the underlying buffer. We use tpin.
# If we take address from any newly allocation array in this test, it needs to be tpinned.
Base.increment_tpin_count!(A)
S = Strider(vec(A), strides(A), sz)
@test pointer(A) == pointer(S)
for i in 1:prod(sz)
Expand All @@ -1272,6 +1276,7 @@ Base.cconvert(::Type{Ptr{T}}, S::Strider{T}) where {T} = memoryref(S.data.ref, S
(sz[1]:-1:1, sz[2]:-1:1, sz[3]:-1:1),
(sz[1]-1:-3:1, sz[2]:-2:3, 1:sz[3]),)
Ai = A[idxs...]
Base.increment_tpin_count!(Ai)
Av = view(A, idxs...)
Sv = view(S, idxs...)
Ss = Strider{Int, 3}(vec(A), sum((first.(idxs).-1).*strides(A))+1, strides(Av), length.(idxs))
Expand All @@ -1282,6 +1287,7 @@ Base.cconvert(::Type{Ptr{T}}, S::Strider{T}) where {T} = memoryref(S.data.ref, S
end
for perm in ((3, 2, 1), (2, 1, 3), (3, 1, 2))
P = permutedims(A, perm)
Base.increment_tpin_count!(P)
Ap = Base.PermutedDimsArray(A, perm)
Sp = Base.PermutedDimsArray(S, perm)
Ps = Strider{Int, 3}(vec(A), 1, strides(A)[collect(perm)], sz[collect(perm)])
Expand All @@ -1303,7 +1309,9 @@ Base.cconvert(::Type{Ptr{T}}, S::Strider{T}) where {T} = memoryref(S.data.ref, S
@test Pi[i] == Pv[i] == Apv[i] == Spv[i] == Pvs[i]
end
Vp = permutedims(Av, perm)
Base.increment_tpin_count!(Vp)
Ip = permutedims(Ai, perm)
Base.increment_tpin_count!(Ip)
Avp = Base.PermutedDimsArray(Av, perm)
Svp = Base.PermutedDimsArray(Sv, perm)
@test pointer(Avp) == pointer(Svp)
Expand All @@ -1322,6 +1330,10 @@ end

@testset "simple 2d strided views, permutes, transposes" for sz in ((5, 3), (7, 11))
A = collect(reshape(1:prod(sz), sz))
# The following test takes pointers from A, we need to ensure A is not moved by GC.
# Furthermore, as pointer() returns the buffer address, we need to ensure the underlying buffer. We use tpin.
# If we take address from any newly allocation array in this test, it needs to be tpinned.
Base.increment_tpin_count!(A)
S = Strider(vec(A), strides(A), sz)
@test pointer(A) == pointer(S)
for i in 1:prod(sz)
Expand All @@ -1343,6 +1355,7 @@ end
end
perm = (2, 1)
P = permutedims(A, perm)
Base.increment_tpin_count!(P)
Ap = Base.PermutedDimsArray(A, perm)
At = transpose(A)
Aa = adjoint(A)
Expand Down Expand Up @@ -1372,6 +1385,7 @@ end
@test Pv[i] == Apv[i] == Spv[i] == Pvs[i] == Atv[i] == Ata[i] == Stv[i] == Sta[i]
end
Vp = permutedims(Av, perm)
Base.increment_tpin_count!(Vp)
Avp = Base.PermutedDimsArray(Av, perm)
Avt = transpose(Av)
Ava = adjoint(Av)
Expand Down Expand Up @@ -1915,6 +1929,7 @@ module IRUtils
end

function check_pointer_strides(A::AbstractArray)
Base.increment_tpin_count!(A)
# Make sure stride(A, i) is equivalent with strides(A)[i] (if 1 <= i <= ndims(A))
dims = ntuple(identity, ndims(A))
map(i -> stride(A, i), dims) == @inferred(strides(A)) || return false
Expand All @@ -1924,6 +1939,7 @@ function check_pointer_strides(A::AbstractArray)
for i in eachindex(IndexLinear(), A)
A[i] === Base.unsafe_load(pointer(A, i)) || return false
end
Base.decrement_tpin_count!(A)
return true
end

Expand Down
8 changes: 8 additions & 0 deletions test/arrayops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,10 @@ end
@test_throws ArgumentError PermutedDimsArray(a, (1,1,1))
@test_throws ArgumentError PermutedDimsArray(s, (1,1,1))
cp = PermutedDimsArray(c, (3,2,1))
# The following test takes pointers from c, we need to ensure c is not moved by GC.
Base.increment_tpin_count!(c)
@test pointer(cp) == pointer(c)
Base.decrement_tpin_count!(c)
@test_throws ArgumentError pointer(cp, 2)
@test strides(cp) == (9,3,1)
ap = PermutedDimsArray(Array(a), (2,1,3))
Expand Down Expand Up @@ -3045,14 +3048,19 @@ Base.:(==)(a::T11053, b::T11053) = a.a == b.a

# check a == b for arrays of Union type (#22403)
let TT = Union{UInt8, Int8}
# The following test takes pointers from a and b, we need to pin both
a = TT[0x0, 0x1]
Base.increment_tpin_count!(a)
b = TT[0x0, 0x0]
Base.increment_tpin_count!(b)
pa = pointer(a)
pb = pointer(b)
resize!(a, 1) # sets a[2] = 0
resize!(b, 1)
@assert pointer(a) == pa
@assert pointer(b) == pb
Base.decrement_tpin_count!(a)
Base.decrement_tpin_count!(b)
unsafe_store!(Ptr{UInt8}(pa), 0x1, 2) # reset a[2] to 1
@test length(a) == length(b) == 1
@test a[1] == b[1] == 0x0
Expand Down