diff --git a/src/Convert/Convert.jl b/src/Convert/Convert.jl index a97f1a48..688d022e 100644 --- a/src/Convert/Convert.jl +++ b/src/Convert/Convert.jl @@ -9,6 +9,7 @@ using ..Core using ..Core: C, Utils, + Lockable, @autopy, getptr, incref, diff --git a/src/Convert/pyconvert.jl b/src/Convert/pyconvert.jl index d61268de..9aa3591e 100644 --- a/src/Convert/pyconvert.jl +++ b/src/Convert/pyconvert.jl @@ -12,8 +12,8 @@ struct PyConvertRule priority::PyConvertPriority end -const PYCONVERT_RULES = Dict{String,Vector{PyConvertRule}}() -const PYCONVERT_EXTRATYPES = Py[] +const PYCONVERT_RULES = Lockable(Dict{String,Vector{PyConvertRule}}()) +const PYCONVERT_EXTRATYPES = Lockable(Py[]) """ pyconvert_add_rule(tname::String, T::Type, func::Function, priority::PyConvertPriority=PYCONVERT_PRIORITY_NORMAL) @@ -69,11 +69,11 @@ function pyconvert_add_rule( priority::PyConvertPriority = PYCONVERT_PRIORITY_NORMAL, ) @nospecialize type func - push!( - get!(Vector{PyConvertRule}, PYCONVERT_RULES, pytypename), + Base.@lock PYCONVERT_RULES push!( + get!(Vector{PyConvertRule}, PYCONVERT_RULES[], pytypename), PyConvertRule(type, func, priority), ) - empty!.(values(PYCONVERT_RULES_CACHE)) + Base.@lock PYCONVERT_RULES_CACHE empty!.(values(PYCONVERT_RULES_CACHE[])) return end @@ -163,7 +163,7 @@ function _pyconvert_get_rules(pytype::Py) omro = collect(pytype.__mro__) basetypes = Py[pytype] basemros = Vector{Py}[omro] - for xtype in PYCONVERT_EXTRATYPES + Base.@lock PYCONVERT_EXTRATYPES for xtype in PYCONVERT_EXTRATYPES[] # find the topmost supertype of xbase = PyNULL for base in omro @@ -248,9 +248,9 @@ function _pyconvert_get_rules(pytype::Py) mro = String[x for xs in xmro for x in xs] # get corresponding rules - rules = PyConvertRule[ + rules = Base.@lock PYCONVERT_RULES PyConvertRule[ rule for tname in mro for - rule in get!(Vector{PyConvertRule}, PYCONVERT_RULES, tname) + rule in get!(Vector{PyConvertRule}, PYCONVERT_RULES[], tname) ] # order the rules by priority, then by original order @@ -261,10 +261,10 @@ function _pyconvert_get_rules(pytype::Py) return rules end -const PYCONVERT_PREFERRED_TYPE = Dict{Py,Type}() +const PYCONVERT_PREFERRED_TYPE = Lockable(Dict{Py,Type}()) pyconvert_preferred_type(pytype::Py) = - get!(PYCONVERT_PREFERRED_TYPE, pytype) do + Base.@lock PYCONVERT_PREFERRED_TYPE get!(PYCONVERT_PREFERRED_TYPE[], pytype) do if pyissubclass(pytype, pybuiltins.int) Union{Int,BigInt} else @@ -307,10 +307,15 @@ end pyconvert_fix(::Type{T}, func) where {T} = x -> func(T, x) -const PYCONVERT_RULES_CACHE = Dict{Type,Dict{C.PyPtr,Vector{Function}}}() +const PYCONVERT_RULES_CACHE = Lockable(IdDict{Any,Dict{C.PyPtr,Vector{Function}}}()) -@generated pyconvert_rules_cache(::Type{T}) where {T} = - get!(Dict{C.PyPtr,Vector{Function}}, PYCONVERT_RULES_CACHE, T) +function pyconvert_rules_cache(::Type{T}) where {T} + Base.@lock PYCONVERT_RULES_CACHE get!( + Dict{C.PyPtr,Vector{Function}}, + PYCONVERT_RULES_CACHE[], + T, + ) +end function pyconvert_rule_fast(::Type{T}, x::Py) where {T} if T isa Union @@ -351,12 +356,13 @@ function pytryconvert(::Type{T}, x_) where {T} # get rules from the cache # TODO: we should hold weak references and clear the cache if types get deleted tptr = C.Py_Type(x) - trules = pyconvert_rules_cache(T) - rules = get!(trules, tptr) do - t = pynew(incref(tptr)) - ans = pyconvert_get_rules(T, t)::Vector{Function} - pydel!(t) - ans + rules = Base.@lock PYCONVERT_RULES_CACHE let trules = pyconvert_rules_cache(T) + get!(trules, tptr) do + t = pynew(incref(tptr)) + ans = pyconvert_get_rules(T, t)::Vector{Function} + pydel!(t) + ans + end end # apply the rules @@ -418,15 +424,17 @@ pyconvertarg(::Type{T}, x, name) where {T} = @autopy x @pyconvert T x_ begin end function init_pyconvert() - push!(PYCONVERT_EXTRATYPES, pyimport("io" => "IOBase")) - push!( - PYCONVERT_EXTRATYPES, - pyimport("numbers" => ("Number", "Complex", "Real", "Rational", "Integral"))..., - ) - push!( - PYCONVERT_EXTRATYPES, - pyimport("collections.abc" => ("Iterable", "Sequence", "Set", "Mapping"))..., - ) + Base.@lock PYCONVERT_EXTRATYPES begin + push!(PYCONVERT_EXTRATYPES[], pyimport("io" => "IOBase")) + push!( + PYCONVERT_EXTRATYPES[], + pyimport("numbers" => ("Number", "Complex", "Real", "Rational", "Integral"))..., + ) + push!( + PYCONVERT_EXTRATYPES[], + pyimport("collections.abc" => ("Iterable", "Sequence", "Set", "Mapping"))..., + ) + end priority = PYCONVERT_PRIORITY_CANONICAL pyconvert_add_rule("builtins:NoneType", Nothing, pyconvert_rule_none, priority) diff --git a/src/Core/Core.jl b/src/Core/Core.jl index 886eaff9..7c95326b 100644 --- a/src/Core/Core.jl +++ b/src/Core/Core.jl @@ -11,7 +11,7 @@ const ROOT_DIR = dirname(dirname(@__DIR__)) using ..PythonCall: PythonCall # needed for docstring cross-refs using ..C: C using ..GC: GC -using ..Utils: Utils +using ..Utils: Utils, Lockable using Base: @propagate_inbounds, @kwdef using Dates: Date, diff --git a/src/Core/Py.jl b/src/Core/Py.jl index a457ce90..b2f8a861 100644 --- a/src/Core/Py.jl +++ b/src/Core/Py.jl @@ -56,7 +56,7 @@ decref(x::Py) = Base.GC.@preserve x (decref(getptr(x)); x) Base.unsafe_convert(::Type{C.PyPtr}, x::Py) = getptr(x) -const PYNULL_CACHE = Py[] +const PYNULL_CACHE = Lockable(Py[]) """ pynew([ptr]) @@ -69,12 +69,13 @@ points at, i.e. the new `Py` object owns a reference. Note that NULL Python objects are not safe in the sense that most API functions will probably crash your Julia session if you pass a NULL argument. """ -pynew() = - if isempty(PYNULL_CACHE) +pynew() = Base.@lock PYNULL_CACHE begin + if isempty(PYNULL_CACHE[]) Py(Val(:new), C.PyNULL) else - pop!(PYNULL_CACHE) + pop!(PYNULL_CACHE[]) end +end const PyNULL = pynew() @@ -119,7 +120,7 @@ function pydel!(x::Py) C.Py_DecRef(ptr) setptr!(x, C.PyNULL) end - push!(PYNULL_CACHE, x) + Base.@lock PYNULL_CACHE push!(PYNULL_CACHE[], x) return end diff --git a/src/Core/builtins.jl b/src/Core/builtins.jl index ec7eb10b..62d58309 100644 --- a/src/Core/builtins.jl +++ b/src/Core/builtins.jl @@ -1206,7 +1206,7 @@ export pyfraction ### eval/exec -const MODULE_GLOBALS = Dict{Module,Py}() +const MODULE_GLOBALS = Lockable(Dict{Module,Py}()) function _pyeval_args(code, globals, locals) if code isa AbstractString @@ -1217,7 +1217,7 @@ function _pyeval_args(code, globals, locals) throw(ArgumentError("code must be a string or Python code")) end if globals isa Module - globals_ = get!(pydict, MODULE_GLOBALS, globals) + globals_ = Base.@lock MODULE_GLOBALS get!(pydict, MODULE_GLOBALS[], globals) elseif ispy(globals) globals_ = globals else diff --git a/src/JlWrap/C.jl b/src/JlWrap/C.jl index fa96dd36..5a91d330 100644 --- a/src/JlWrap/C.jl +++ b/src/JlWrap/C.jl @@ -1,7 +1,7 @@ module Cjl using ...C: C -using ...Utils: Utils +using ...Utils: Utils, Lockable using Base: @kwdef using UnsafePointers: UnsafePtr using Serialization: serialize, deserialize @@ -16,9 +16,7 @@ const PyJuliaBase_Type = Ref(C.PyNULL) # we store the actual julia values here # the `value` field of `PyJuliaValueObject` indexes into here -const PYJLVALUES = [] -# unused indices in PYJLVALUES -const PYJLFREEVALUES = Int[] +const PYJLVALUES = Lockable((; values=IdDict{Int,Any}(), free_slots=Int[], next_slot=Ref(1))) function _pyjl_new(t::C.PyPtr, ::C.PyPtr, ::C.PyPtr) o = ccall(UnsafePtr{C.PyTypeObject}(t).alloc[!], C.PyPtr, (C.PyPtr, C.Py_ssize_t), t, 0) @@ -31,20 +29,24 @@ end function _pyjl_dealloc(o::C.PyPtr) idx = UnsafePtr{PyJuliaValueObject}(o).value[] if idx != 0 - PYJLVALUES[idx] = nothing - push!(PYJLFREEVALUES, idx) + Base.@lock PYJLVALUES begin + delete!(PYJLVALUES[].values, idx) + push!(PYJLVALUES[].free_slots, idx) + end end UnsafePtr{PyJuliaValueObject}(o).weaklist[!] == C.PyNULL || C.PyObject_ClearWeakRefs(o) ccall(UnsafePtr{C.PyTypeObject}(C.Py_Type(o)).free[!], Cvoid, (C.PyPtr,), o) nothing end -const PYJLMETHODS = Vector{Any}() +const PYJLMETHODS = Lockable([]) function PyJulia_MethodNum(f) @nospecialize f - push!(PYJLMETHODS, f) - return length(PYJLMETHODS) + Base.@lock PYJLMETHODS begin + push!(PYJLMETHODS[], f) + return length(PYJLMETHODS[]) + end end function _pyjl_isnull(o::C.PyPtr, ::C.PyPtr) @@ -58,12 +60,12 @@ function _pyjl_callmethod(o::C.PyPtr, args::C.PyPtr) @assert nargs > 0 num = C.PyLong_AsLongLong(C.PyTuple_GetItem(args, 0)) num == -1 && return C.PyNULL - f = PYJLMETHODS[num] + f = Base.@lock PYJLMETHODS PYJLMETHODS[][num] # this form gets defined in jlwrap/base.jl return _pyjl_callmethod(f, o, args, nargs)::C.PyPtr end -const PYJLBUFCACHE = Dict{Ptr{Cvoid},Any}() +const PYJLBUFCACHE = Lockable(Dict{Ptr{Cvoid},Any}()) @kwdef struct PyBufferInfo{N} # data @@ -177,7 +179,7 @@ function _pyjl_get_buffer_impl( # internal cptr = Base.pointer_from_objref(c) - PYJLBUFCACHE[cptr] = c + Base.@lock PYJLBUFCACHE PYJLBUFCACHE[][cptr] = c b.internal[] = cptr # obj @@ -195,7 +197,7 @@ function _pyjl_get_buffer(o::C.PyPtr, buf::Ptr{C.Py_buffer}, flags::Cint) C.Py_DecRef(num_) num == -1 && return Cint(-1) try - f = PYJLMETHODS[num] + f = Base.@lock PYJLMETHODS PYJLMETHODS[][num] x = PyJuliaValue_GetValue(o) return _pyjl_get_buffer_impl(o, buf, flags, x, f)::Cint catch exc @@ -209,7 +211,7 @@ function _pyjl_get_buffer(o::C.PyPtr, buf::Ptr{C.Py_buffer}, flags::Cint) end function _pyjl_release_buffer(xo::C.PyPtr, buf::Ptr{C.Py_buffer}) - delete!(PYJLBUFCACHE, UnsafePtr(buf).internal[!]) + Base.@lock PYJLBUFCACHE delete!(PYJLBUFCACHE[], UnsafePtr(buf).internal[!]) nothing end @@ -339,22 +341,29 @@ end PyJuliaValue_IsNull(o) = Base.GC.@preserve o UnsafePtr{PyJuliaValueObject}(C.asptr(o)).value[] == 0 -PyJuliaValue_GetValue(o) = Base.GC.@preserve o PYJLVALUES[UnsafePtr{PyJuliaValueObject}(C.asptr(o)).value[]] +PyJuliaValue_GetValue(o) = Base.GC.@preserve o begin + idx = UnsafePtr{PyJuliaValueObject}(C.asptr(o)).value[] + Base.@lock PYJLVALUES PYJLVALUES[].values[idx] +end PyJuliaValue_SetValue(_o, @nospecialize(v)) = Base.GC.@preserve _o begin o = C.asptr(_o) idx = UnsafePtr{PyJuliaValueObject}(o).value[] if idx == 0 - if isempty(PYJLFREEVALUES) - push!(PYJLVALUES, v) - idx = length(PYJLVALUES) - else - idx = pop!(PYJLFREEVALUES) - PYJLVALUES[idx] = v + Base.@lock PYJLVALUES begin + if isempty(PYJLVALUES[].free_slots) + idx = PYJLVALUES[].next_slot[] + PYJLVALUES[].next_slot[] += 1 + else + idx = pop!(PYJLVALUES[].free_slots) + end + PYJLVALUES[].values[idx] = v end UnsafePtr{PyJuliaValueObject}(o).value[] = idx else - PYJLVALUES[idx] = v + Base.@lock PYJLVALUES begin + PYJLVALUES[].values[idx] = v + end end nothing end diff --git a/src/JlWrap/base.jl b/src/JlWrap/base.jl index 47ffb084..f8250c52 100644 --- a/src/JlWrap/base.jl +++ b/src/JlWrap/base.jl @@ -84,6 +84,7 @@ function Cjl._pyjl_callmethod(f, self_::C.PyPtr, args_::C.PyPtr, nargs::C.Py_ssi pybuiltins.NotImplementedError, "__jl_callmethod not implemented for this many arguments", ) + return C.PyNULL end return getptr(incref(ans)) catch exc diff --git a/src/Utils/Utils.jl b/src/Utils/Utils.jl index 6e7b3f4d..972b6216 100644 --- a/src/Utils/Utils.jl +++ b/src/Utils/Utils.jl @@ -308,4 +308,30 @@ function Base.iterate(x::StaticString{UInt32,N}, i::Int = 1) where {N} end end +@static if !isdefined(Base, :Lockable) + """ + Compat for `Base.Lockable` (introduced in Julia 1.11) + """ + struct Lockable{T,L} + value::T + lock::L + end + + Lockable(value) = Lockable(value, ReentrantLock()) + + # function Base.lock(f, l::Lockable) + # lock(l.lock) do + # f(l.value) + # end + # end + + Base.lock(l::Lockable) = lock(l.lock) + # Base.trylock(l::Lockable) = trylock(l.lock) + Base.unlock(l::Lockable) = unlock(l.lock) + Base.islocked(l::Lockable) = islocked(l.lock) + Base.getindex(l::Lockable) = (@assert islocked(l); l.value) +else + const Lockable = Base.Lockable +end + end