Skip to content

Fix jl_object_id__cold segfault with race condition defence #635

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 14 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/Convert/Convert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ using ..Core
using ..Core:
C,
Utils,
Lockable,
@autopy,
getptr,
incref,
Expand Down
64 changes: 36 additions & 28 deletions src/Convert/pyconvert.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ struct PyConvertRule
priority::PyConvertPriority
end

const PYCONVERT_RULES = Dict{String,Vector{PyConvertRule}}()
const PYCONVERT_EXTRATYPES = Py[]
const PYCONVERT_RULES = Lockable(Dict{String,Vector{PyConvertRule}}())
const PYCONVERT_EXTRATYPES = Lockable(Py[])

"""
pyconvert_add_rule(tname::String, T::Type, func::Function, priority::PyConvertPriority=PYCONVERT_PRIORITY_NORMAL)
Expand Down Expand Up @@ -69,11 +69,11 @@ function pyconvert_add_rule(
priority::PyConvertPriority = PYCONVERT_PRIORITY_NORMAL,
)
@nospecialize type func
push!(
get!(Vector{PyConvertRule}, PYCONVERT_RULES, pytypename),
Base.@lock PYCONVERT_RULES push!(
get!(Vector{PyConvertRule}, PYCONVERT_RULES[], pytypename),
PyConvertRule(type, func, priority),
)
empty!.(values(PYCONVERT_RULES_CACHE))
Base.@lock PYCONVERT_RULES_CACHE empty!.(values(PYCONVERT_RULES_CACHE[]))
return
end

Expand Down Expand Up @@ -163,7 +163,7 @@ function _pyconvert_get_rules(pytype::Py)
omro = collect(pytype.__mro__)
basetypes = Py[pytype]
basemros = Vector{Py}[omro]
for xtype in PYCONVERT_EXTRATYPES
Base.@lock PYCONVERT_EXTRATYPES for xtype in PYCONVERT_EXTRATYPES[]
# find the topmost supertype of
xbase = PyNULL
for base in omro
Expand Down Expand Up @@ -248,9 +248,9 @@ function _pyconvert_get_rules(pytype::Py)
mro = String[x for xs in xmro for x in xs]

# get corresponding rules
rules = PyConvertRule[
rules = Base.@lock PYCONVERT_RULES PyConvertRule[
rule for tname in mro for
rule in get!(Vector{PyConvertRule}, PYCONVERT_RULES, tname)
rule in get!(Vector{PyConvertRule}, PYCONVERT_RULES[], tname)
]

# order the rules by priority, then by original order
Expand All @@ -261,10 +261,10 @@ function _pyconvert_get_rules(pytype::Py)
return rules
end

const PYCONVERT_PREFERRED_TYPE = Dict{Py,Type}()
const PYCONVERT_PREFERRED_TYPE = Lockable(Dict{Py,Type}())

pyconvert_preferred_type(pytype::Py) =
get!(PYCONVERT_PREFERRED_TYPE, pytype) do
Base.@lock PYCONVERT_PREFERRED_TYPE get!(PYCONVERT_PREFERRED_TYPE[], pytype) do
if pyissubclass(pytype, pybuiltins.int)
Union{Int,BigInt}
else
Expand Down Expand Up @@ -307,10 +307,15 @@ end

pyconvert_fix(::Type{T}, func) where {T} = x -> func(T, x)

const PYCONVERT_RULES_CACHE = Dict{Type,Dict{C.PyPtr,Vector{Function}}}()
const PYCONVERT_RULES_CACHE = Lockable(IdDict{Any,Dict{C.PyPtr,Vector{Function}}}())

@generated pyconvert_rules_cache(::Type{T}) where {T} =
get!(Dict{C.PyPtr,Vector{Function}}, PYCONVERT_RULES_CACHE, T)
function pyconvert_rules_cache(::Type{T}) where {T}
Base.@lock PYCONVERT_RULES_CACHE get!(
Dict{C.PyPtr,Vector{Function}},
PYCONVERT_RULES_CACHE[],
T,
)
end

function pyconvert_rule_fast(::Type{T}, x::Py) where {T}
if T isa Union
Expand Down Expand Up @@ -351,12 +356,13 @@ function pytryconvert(::Type{T}, x_) where {T}
# get rules from the cache
# TODO: we should hold weak references and clear the cache if types get deleted
tptr = C.Py_Type(x)
trules = pyconvert_rules_cache(T)
rules = get!(trules, tptr) do
t = pynew(incref(tptr))
ans = pyconvert_get_rules(T, t)::Vector{Function}
pydel!(t)
ans
rules = Base.@lock PYCONVERT_RULES_CACHE let trules = pyconvert_rules_cache(T)
get!(trules, tptr) do
t = pynew(incref(tptr))
ans = pyconvert_get_rules(T, t)::Vector{Function}
pydel!(t)
ans
end
end

# apply the rules
Expand Down Expand Up @@ -418,15 +424,17 @@ pyconvertarg(::Type{T}, x, name) where {T} = @autopy x @pyconvert T x_ begin
end

function init_pyconvert()
push!(PYCONVERT_EXTRATYPES, pyimport("io" => "IOBase"))
push!(
PYCONVERT_EXTRATYPES,
pyimport("numbers" => ("Number", "Complex", "Real", "Rational", "Integral"))...,
)
push!(
PYCONVERT_EXTRATYPES,
pyimport("collections.abc" => ("Iterable", "Sequence", "Set", "Mapping"))...,
)
Base.@lock PYCONVERT_EXTRATYPES begin
push!(PYCONVERT_EXTRATYPES[], pyimport("io" => "IOBase"))
push!(
PYCONVERT_EXTRATYPES[],
pyimport("numbers" => ("Number", "Complex", "Real", "Rational", "Integral"))...,
)
push!(
PYCONVERT_EXTRATYPES[],
pyimport("collections.abc" => ("Iterable", "Sequence", "Set", "Mapping"))...,
)
end

priority = PYCONVERT_PRIORITY_CANONICAL
pyconvert_add_rule("builtins:NoneType", Nothing, pyconvert_rule_none, priority)
Expand Down
2 changes: 1 addition & 1 deletion src/Core/Core.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ const ROOT_DIR = dirname(dirname(@__DIR__))
using ..PythonCall: PythonCall # needed for docstring cross-refs
using ..C: C
using ..GC: GC
using ..Utils: Utils
using ..Utils: Utils, Lockable
using Base: @propagate_inbounds, @kwdef
using Dates:
Date,
Expand Down
11 changes: 6 additions & 5 deletions src/Core/Py.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ decref(x::Py) = Base.GC.@preserve x (decref(getptr(x)); x)

Base.unsafe_convert(::Type{C.PyPtr}, x::Py) = getptr(x)

const PYNULL_CACHE = Py[]
const PYNULL_CACHE = Lockable(Py[])

"""
pynew([ptr])
Expand All @@ -69,12 +69,13 @@ points at, i.e. the new `Py` object owns a reference.
Note that NULL Python objects are not safe in the sense that most API functions will probably
crash your Julia session if you pass a NULL argument.
"""
pynew() =
if isempty(PYNULL_CACHE)
pynew() = Base.@lock PYNULL_CACHE begin
if isempty(PYNULL_CACHE[])
Py(Val(:new), C.PyNULL)
else
pop!(PYNULL_CACHE)
pop!(PYNULL_CACHE[])
end
end

const PyNULL = pynew()

Expand Down Expand Up @@ -119,7 +120,7 @@ function pydel!(x::Py)
C.Py_DecRef(ptr)
setptr!(x, C.PyNULL)
end
push!(PYNULL_CACHE, x)
Base.@lock PYNULL_CACHE push!(PYNULL_CACHE[], x)
return
end

Expand Down
4 changes: 2 additions & 2 deletions src/Core/builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1206,7 +1206,7 @@ export pyfraction

### eval/exec

const MODULE_GLOBALS = Dict{Module,Py}()
const MODULE_GLOBALS = Lockable(Dict{Module,Py}())

function _pyeval_args(code, globals, locals)
if code isa AbstractString
Expand All @@ -1217,7 +1217,7 @@ function _pyeval_args(code, globals, locals)
throw(ArgumentError("code must be a string or Python code"))
end
if globals isa Module
globals_ = get!(pydict, MODULE_GLOBALS, globals)
globals_ = Base.@lock MODULE_GLOBALS get!(pydict, MODULE_GLOBALS[], globals)
elseif ispy(globals)
globals_ = globals
else
Expand Down
53 changes: 31 additions & 22 deletions src/JlWrap/C.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
module Cjl

using ...C: C
using ...Utils: Utils
using ...Utils: Utils, Lockable
using Base: @kwdef
using UnsafePointers: UnsafePtr
using Serialization: serialize, deserialize
Expand All @@ -16,9 +16,7 @@ const PyJuliaBase_Type = Ref(C.PyNULL)

# we store the actual julia values here
# the `value` field of `PyJuliaValueObject` indexes into here
const PYJLVALUES = []
# unused indices in PYJLVALUES
const PYJLFREEVALUES = Int[]
const PYJLVALUES = Lockable((; values=IdDict{Int,Any}(), free_slots=Int[], next_slot=Ref(1)))

function _pyjl_new(t::C.PyPtr, ::C.PyPtr, ::C.PyPtr)
o = ccall(UnsafePtr{C.PyTypeObject}(t).alloc[!], C.PyPtr, (C.PyPtr, C.Py_ssize_t), t, 0)
Expand All @@ -31,20 +29,24 @@ end
function _pyjl_dealloc(o::C.PyPtr)
idx = UnsafePtr{PyJuliaValueObject}(o).value[]
if idx != 0
PYJLVALUES[idx] = nothing
push!(PYJLFREEVALUES, idx)
Base.@lock PYJLVALUES begin
delete!(PYJLVALUES[].values, idx)
push!(PYJLVALUES[].free_slots, idx)
end
end
UnsafePtr{PyJuliaValueObject}(o).weaklist[!] == C.PyNULL || C.PyObject_ClearWeakRefs(o)
ccall(UnsafePtr{C.PyTypeObject}(C.Py_Type(o)).free[!], Cvoid, (C.PyPtr,), o)
nothing
end

const PYJLMETHODS = Vector{Any}()
const PYJLMETHODS = Lockable([])

function PyJulia_MethodNum(f)
@nospecialize f
push!(PYJLMETHODS, f)
return length(PYJLMETHODS)
Base.@lock PYJLMETHODS begin
push!(PYJLMETHODS[], f)
return length(PYJLMETHODS[])
end
end

function _pyjl_isnull(o::C.PyPtr, ::C.PyPtr)
Expand All @@ -58,12 +60,12 @@ function _pyjl_callmethod(o::C.PyPtr, args::C.PyPtr)
@assert nargs > 0
num = C.PyLong_AsLongLong(C.PyTuple_GetItem(args, 0))
num == -1 && return C.PyNULL
f = PYJLMETHODS[num]
f = Base.@lock PYJLMETHODS PYJLMETHODS[][num]
# this form gets defined in jlwrap/base.jl
return _pyjl_callmethod(f, o, args, nargs)::C.PyPtr
end

const PYJLBUFCACHE = Dict{Ptr{Cvoid},Any}()
const PYJLBUFCACHE = Lockable(Dict{Ptr{Cvoid},Any}())

@kwdef struct PyBufferInfo{N}
# data
Expand Down Expand Up @@ -177,7 +179,7 @@ function _pyjl_get_buffer_impl(

# internal
cptr = Base.pointer_from_objref(c)
PYJLBUFCACHE[cptr] = c
Base.@lock PYJLBUFCACHE PYJLBUFCACHE[][cptr] = c
b.internal[] = cptr

# obj
Expand All @@ -195,7 +197,7 @@ function _pyjl_get_buffer(o::C.PyPtr, buf::Ptr{C.Py_buffer}, flags::Cint)
C.Py_DecRef(num_)
num == -1 && return Cint(-1)
try
f = PYJLMETHODS[num]
f = Base.@lock PYJLMETHODS PYJLMETHODS[][num]
x = PyJuliaValue_GetValue(o)
return _pyjl_get_buffer_impl(o, buf, flags, x, f)::Cint
catch exc
Expand All @@ -209,7 +211,7 @@ function _pyjl_get_buffer(o::C.PyPtr, buf::Ptr{C.Py_buffer}, flags::Cint)
end

function _pyjl_release_buffer(xo::C.PyPtr, buf::Ptr{C.Py_buffer})
delete!(PYJLBUFCACHE, UnsafePtr(buf).internal[!])
Base.@lock PYJLBUFCACHE delete!(PYJLBUFCACHE[], UnsafePtr(buf).internal[!])
nothing
end

Expand Down Expand Up @@ -339,22 +341,29 @@ end

PyJuliaValue_IsNull(o) = Base.GC.@preserve o UnsafePtr{PyJuliaValueObject}(C.asptr(o)).value[] == 0

PyJuliaValue_GetValue(o) = Base.GC.@preserve o PYJLVALUES[UnsafePtr{PyJuliaValueObject}(C.asptr(o)).value[]]
PyJuliaValue_GetValue(o) = Base.GC.@preserve o begin
idx = UnsafePtr{PyJuliaValueObject}(C.asptr(o)).value[]
Base.@lock PYJLVALUES PYJLVALUES[].values[idx]
end

PyJuliaValue_SetValue(_o, @nospecialize(v)) = Base.GC.@preserve _o begin
o = C.asptr(_o)
idx = UnsafePtr{PyJuliaValueObject}(o).value[]
if idx == 0
if isempty(PYJLFREEVALUES)
push!(PYJLVALUES, v)
idx = length(PYJLVALUES)
else
idx = pop!(PYJLFREEVALUES)
PYJLVALUES[idx] = v
Base.@lock PYJLVALUES begin
if isempty(PYJLVALUES[].free_slots)
idx = PYJLVALUES[].next_slot[]
PYJLVALUES[].next_slot[] += 1
else
idx = pop!(PYJLVALUES[].free_slots)
end
PYJLVALUES[].values[idx] = v
end
UnsafePtr{PyJuliaValueObject}(o).value[] = idx
else
PYJLVALUES[idx] = v
Base.@lock PYJLVALUES begin
PYJLVALUES[].values[idx] = v
end
end
nothing
end
Expand Down
1 change: 1 addition & 0 deletions src/JlWrap/base.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ function Cjl._pyjl_callmethod(f, self_::C.PyPtr, args_::C.PyPtr, nargs::C.Py_ssi
pybuiltins.NotImplementedError,
"__jl_callmethod not implemented for this many arguments",
)
return C.PyNULL
end
return getptr(incref(ans))
catch exc
Expand Down
26 changes: 26 additions & 0 deletions src/Utils/Utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,4 +308,30 @@ function Base.iterate(x::StaticString{UInt32,N}, i::Int = 1) where {N}
end
end

@static if !isdefined(Base, :Lockable)
"""
Compat for `Base.Lockable` (introduced in Julia 1.11)
"""
struct Lockable{T,L}
value::T
lock::L
end

Lockable(value) = Lockable(value, ReentrantLock())

# function Base.lock(f, l::Lockable)
# lock(l.lock) do
# f(l.value)
# end
# end

Base.lock(l::Lockable) = lock(l.lock)
# Base.trylock(l::Lockable) = trylock(l.lock)
Base.unlock(l::Lockable) = unlock(l.lock)
Base.islocked(l::Lockable) = islocked(l.lock)
Base.getindex(l::Lockable) = (@assert islocked(l); l.value)
else
const Lockable = Base.Lockable
end

end
Loading