Skip to content

Commit b3cfc4a

Browse files
committed
Add better precompilation with PrecompileTools
- Add option to skip compilation of `kgen`-generated kernels. If `compile=false`, kgen returns a symbol instead of the function itself.
1 parent 1bd0daa commit b3cfc4a

3 files changed

Lines changed: 54 additions & 51 deletions

File tree

src/SourceCodeMcCormick.jl

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ using Graphs
1010
using CUDA
1111
using StaticArrays: @MVector
1212
using MultiFloats
13+
using PrecompileTools: @setup_workload, @compile_workload
1314
import Dates
1415
import SymbolicUtils: BasicSymbolic, exprtype, SYM, TERM, ADD, MUL, POW, DIV
1516

@@ -49,8 +50,27 @@ include(joinpath(@__DIR__, "relaxation", "relaxation.jl"))
4950
include(joinpath(@__DIR__, "transform", "transform.jl"))
5051
include(joinpath(@__DIR__, "grad", "grad.jl"))
5152
include(joinpath(@__DIR__, "kernel_writer", "kernel_write.jl"))
52-
include(joinpath(@__DIR__, "precompile.jl"))
53-
_precompile_()
53+
54+
@setup_workload begin
55+
@variables x, y
56+
@compile_workload begin
57+
kgen(1 + x + y^2 + x*y, overwrite=true, compile=false)
58+
kgen(1 +
59+
(-x) +
60+
exp(x) +
61+
log(x) +
62+
(1/x) +
63+
abs(x) +
64+
2*y +
65+
(1/(1+exp(-x))) +
66+
x^3 +
67+
x^4 +
68+
x^3.5 +
69+
cos(x) +
70+
x*y, overwrite=true, compile=false)
71+
kgen(x^3)
72+
end
73+
end
5474

5575
export McCormickIntervalTransform, IntervalTransform
5676

src/kernel_writer/kernel_write.jl

Lines changed: 32 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -4,20 +4,24 @@ include(joinpath(@__DIR__, "math_kernels.jl"))
44
include(joinpath(@__DIR__, "string_math_kernels.jl"))
55

66
# The kernel-generating function, analogous to fgen.
7-
kgen(num::Num; constants::Vector{Num}=Num[], overwrite::Bool=false, splitting::Symbol=:default, affine_quadratic::Bool=true) = kgen(num, setdiff(pull_vars(num), constants), [:all], constants, overwrite, splitting, affine_quadratic)
8-
kgen(num::Num, gradlist::Vector{Num}; constants::Vector{Num}=Num[], overwrite::Bool=false, splitting::Symbol=:default, affine_quadratic::Bool=true) = kgen(num, gradlist, [:all], constants, overwrite, splitting, affine_quadratic)
9-
kgen(num::Num, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwrite::Bool=false, splitting::Symbol=:default, affine_quadratic::Bool=true) = kgen(num, setdiff(pull_vars(num), constants), raw_outputs, constants, overwrite, splitting, affine_quadratic)
10-
kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwrite::Bool=false, splitting::Symbol=:default, affine_quadratic::Bool=true) = kgen(num, gradlist, raw_outputs, constants, overwrite, splitting, affine_quadratic)
11-
function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, constants::Vector{Num}, overwrite::Bool, splitting::Symbol, affine_quadratic::Bool)
7+
kgen(num::Num; constants::Vector{Num}=Num[], overwrite::Bool=false, splitting::Symbol=:default, affine_quadratic::Bool=true, compile::Bool=true) = kgen(num, setdiff(pull_vars(num), constants), [:all], constants, overwrite, splitting, affine_quadratic, compile)
8+
kgen(num::Num, gradlist::Vector{Num}; constants::Vector{Num}=Num[], overwrite::Bool=false, splitting::Symbol=:default, affine_quadratic::Bool=true, compile::Bool=true) = kgen(num, gradlist, [:all], constants, overwrite, splitting, affine_quadratic, compile)
9+
kgen(num::Num, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwrite::Bool=false, splitting::Symbol=:default, affine_quadratic::Bool=true, compile::Bool=true) = kgen(num, setdiff(pull_vars(num), constants), raw_outputs, constants, overwrite, splitting, affine_quadratic, compile)
10+
kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}; constants::Vector{Num}=Num[], overwrite::Bool=false, splitting::Symbol=:default, affine_quadratic::Bool=true, compile::Bool=true) = kgen(num, gradlist, raw_outputs, constants, overwrite, splitting, affine_quadratic, compile)
11+
function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, constants::Vector{Num}, overwrite::Bool, splitting::Symbol, affine_quadratic::Bool, compile::Bool)
1212
# Create a hash of the expression and check if the function already exists
1313
expr_hash = string(hash(string(num)*string(gradlist)), base=62)
1414
if (overwrite==false) && (isfile(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl")))
1515
try func_name = eval(Meta.parse("f_"*expr_hash))
1616
return func_name
1717
catch
18-
include(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"))
19-
func_name = eval(Meta.parse("f_"*expr_hash))
20-
@eval $(func_name)(CUDA.zeros(Float64, 2, 4+2*length($gradlist)), [CUDA.zeros(Float64, 2,3) for i = 1:length($gradlist)]...)
18+
if compile
19+
include(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"))
20+
func_name = eval(Meta.parse("f_"*expr_hash))
21+
@eval $(func_name)(CUDA.zeros(Float64, 2, 4+2*length($gradlist)), [CUDA.zeros(Float64, 2,3) for i = 1:length($gradlist)]...)
22+
else
23+
func_name = Meta.parse("f_"*expr_hash)
24+
end
2125
return func_name
2226
end
2327
end
@@ -44,7 +48,7 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
4448
# and Software, 33:3, 563-593 (2018). DOI: 10.1080/10556788.2017.1335312
4549
# This method is also used in EAGO's `affine_relax_quadratic!` function.
4650
if affine_quadratic==true && is_quadratic(num) # NOTE: When switching to MOI variables, this will be easy to detect
47-
func_name = kgen_affine_quadratic(expr_hash, num, gradlist, constants)
51+
func_name = kgen_affine_quadratic(expr_hash, num, gradlist, constants, compile)
4852
return func_name
4953
end
5054

@@ -142,9 +146,6 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
142146
end
143147
end
144148

145-
# Include all the kernels that were just generated
146-
include(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"))
147-
148149
# We can assume that the newly created kernels have sufficiently
149150
# high register usage that they all have a max number of 256 threads.
150151
# All that we need is to figure out the maximum number of blocks
@@ -160,10 +161,15 @@ function kgen(num::Num, gradlist::Vector{Num}, raw_outputs::Vector{Symbol}, cons
160161
close(file)
161162

162163
# Compile the function and kernels
163-
include(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"))
164-
func_name = eval(Meta.parse("f_"*expr_hash))
165-
@eval $(func_name)(CUDA.zeros(Float64, 2, 4+2*length($gradlist)), [CUDA.zeros(Float64, 2, 3) for i = 1:length($gradlist)]...)
166-
return func_name
164+
if compile
165+
include(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"))
166+
func_name = eval(Meta.parse("f_"*expr_hash))
167+
@eval $(func_name)(CUDA.zeros(Float64, 2, 4+2*length($gradlist)), [CUDA.zeros(Float64, 2, 3) for i = 1:length($gradlist)]...)
168+
return func_name
169+
else
170+
func_name = Meta.parse("f_"*expr_hash)
171+
return func_name
172+
end
167173
end
168174

169175
# This is a quick function to detect if a Num object is quadratic. This function
@@ -257,7 +263,7 @@ end
257263
# A special version of kgen that only applies to quadratic functions. Instead of
258264
# doing McCormick relaxations, this returns either affine bounds or secant line
259265
# bounds, depending on where on the quadratic function the point of interest is.
260-
function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num}, constants::Vector{Num})
266+
function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num}, constants::Vector{Num}, compile::Bool)
261267
# Since it's quadratic, we can construct the kernel according to
262268
# `affine_relax_quadratic!` in EAGO.
263269

@@ -350,9 +356,6 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
350356
end
351357
close(file)
352358

353-
# Include this kernel so SCMC knows what it is
354-
include(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"))
355-
356359
# Add onto the file the "main" CPU function that calls the kernel
357360
blocks = Int32(CUDA.attribute(CUDA.device(), CUDA.DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT))
358361
file = open(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"), "a")
@@ -366,10 +369,15 @@ function kgen_affine_quadratic(expr_hash::String, num::Num, gradlist::Vector{Num
366369
close(file)
367370

368371
# Include the file again to get the final kernel
369-
include(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"))
370-
func_name = eval(Meta.parse("f_"*expr_hash))
371-
@eval $(func_name)(CUDA.zeros(Float64, 2, 4+2*length($gradlist)), [CUDA.zeros(Float64, 2,4) for i = 1:length($gradlist)]...)
372-
return func_name
372+
if compile
373+
include(joinpath(@__DIR__, "storage", "f_"*expr_hash*".jl"))
374+
func_name = eval(Meta.parse("f_"*expr_hash))
375+
@eval $(func_name)(CUDA.zeros(Float64, 2, 4+2*length($gradlist)), [CUDA.zeros(Float64, 2,4) for i = 1:length($gradlist)]...)
376+
return func_name
377+
else
378+
func_name = Meta.parse("f_"*expr_hash)
379+
return func_name
380+
end
373381
end
374382

375383

src/precompile.jl

Lines changed: 0 additions & 25 deletions
This file was deleted.

0 commit comments

Comments
 (0)