From 07a2a5aee0171e46a84e4efda60165cdc1c1b79f Mon Sep 17 00:00:00 2001 From: chriselrod Date: Sat, 3 Apr 2021 13:16:02 -0400 Subject: [PATCH] Improve `Closure` hygiene, add test for updating assignment of symbol --- src/closure.jl | 8 ++++---- test/runtests.jl | 15 ++++++++++++++- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/src/closure.jl b/src/closure.jl index 87fdd71..4c470d0 100644 --- a/src/closure.jl +++ b/src/closure.jl @@ -92,18 +92,18 @@ end struct Closure{E,A} <: Function end -@generated function (::Closure{E,A})(args::Tuple{Vararg{Any,K}}, var"##SUBSTART##"::Int, var"##SUBSTOP##"::Int) where {K,A,E} +@generated function (::Closure{var"##E##",var"##A##"})(var"##args##"::Tuple{Vararg{Any,var"##K##"}}, var"##SUBSTART##"::Int, var"##SUBSTOP##"::Int) where {var"##K##",var"##A##",var"##E##"} q = Expr(:block) gf = GlobalRef(Core, :getfield) - for k ∈ 1:K - push!(q.args, Expr(:(=), A[k], Expr(:call, gf, :args, k, false))) + for k ∈ 1:var"##K##" + push!(q.args, Expr(:(=), var"##A##"[k], Expr(:call, gf, Symbol("##args##"), k, false))) end q = quote @inbounds begin $q var"##LOOPSTART##" = var"##SUBSTART##" * var"##LOOP_STEP##" - var"##LOOPOFFSET##" var"##LOOP_STOP##" = var"##SUBSTOP##" * var"##LOOP_STEP##" - var"##LOOPOFFSET##" - $(toexpr(E)) + $(toexpr(var"##E##")) end nothing end diff --git a/test/runtests.jl b/test/runtests.jl index cce83ad..b5014bd 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,6 +15,15 @@ function sin_batch_sum(v) end return sum(view(s, 1, :)) end +function rowsum_batch!(x, A) + @batch for n ∈ axes(A,2) + s = 0.0 + @simd for m ∈ axes(A,1) + s += A[m,n] + end + x[n] = s + end +end @testset "Range Map" begin function rangemap!(f::F, allargs, start, stop) where {F} @@ -69,6 +78,10 @@ end bsin!(y, x) @test y == sin.(x) @test sum(sin,x) ≈ sin_batch_sum(x) + + A = rand(200,300); x = Vector{Float64}(undef, 300); + rowsum_batch!(x, A); + @test x ≈ vec(sum(A,dims=1)) end @testset "start and stop values" begin @@ -121,7 +134,7 @@ end batch((length(x), max(1,num_threads()>>1), 2), dx, x) do (dx,x), start, stop CheapThreads.threaded_gradient!(f, view(dx, start%Int:stop%Int), view(x, start%Int:stop%Int), ForwardDiff.Chunk(8)) end; - @test dx == dxref + @test dx ≈ dxref end println("Package tests complete. Running `Aqua` checks.")