
treigerm/Minijyro.jl


Minijyro

A simple probabilistic programming language in Julia based on effect handlers. The name reflects that this project is largely based on the ideas behind Pyro's effect handlers and their Mini-Pyro implementation.

The design goals of this language are:

  • Allow for concise definition of sample statements using ~ syntax
  • Use effect handlers to implement simple operations such as conditioning and computing the log joint probability
  • Leverage existing Julia packages such as Distributions.jl, AdvancedHMC.jl and Flux.jl

NOTE: This is not meant to be a serious PPL to be used by anyone. If you are interested in probabilistic programming in Julia, have a look at Turing.jl, Gen and Soss.jl.

Example: Bayesian Linear Regression

A simple model taken from Colin Carroll's tour of PPL APIs.

using Distributions
using LinearAlgebra: I # Identity matrix
using Random

using Minijyro

Random.seed!(42)

# Generate some data.
N = 100
D = 5
true_w = randn(D)
X = randn(N, D)
noise = 0.1 * randn(N)
y_obs = X * true_w + noise

@jyro function model(xs)
    D = size(xs, 2)  # number of features (columns of xs)
    w ~ MvNormal(zeros(D), I)
    y ~ MvNormal(xs * w, 0.1*I)
end

cond_model = condition(model, Dict(:y => y_obs))
samples, stats = nuts(cond_model, (X,), 1000)

@show abs.(true_w - mean(samples))
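For reference, after conditioning on y the sampler targets the posterior over w. Reading the second argument of MvNormal as a covariance matrix (as Distributions.jl does for a UniformScaling such as 0.1*I), the unnormalised log posterior works out as follows, where the constants collect all terms that do not depend on w:

```latex
\begin{aligned}
\log p(w \mid y, X) &= \log \mathcal{N}(w;\, 0,\, I_D) + \log \mathcal{N}(y;\, Xw,\, 0.1\, I_N) + \text{const} \\
                    &= -\tfrac{1}{2}\lVert w \rVert^2 - 5\,\lVert y - Xw \rVert^2 + \text{const}'
\end{aligned}
```

The factor 5 is 1/(2 · 0.1), coming from the 0.1 variance of the likelihood.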

Behind the Scenes

Here is a high-level overview of the inner workings of Minijyro. For more details, I recommend first reading the Pyro effect handler documentation linked above and then the full source code of Minijyro.

Minijyro models are ordinary Julia functions annotated with the @jyro macro. The macro transforms the function's source code and translates it into a MinijyroModel. See dsl.jl for the full implementation of the @jyro macro.
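The core trick behind the macro can be sketched in a few lines, because Julia already parses `w ~ d` as the call `Expr(:call, :~, :w, :d)`. The following is a hypothetical toy, not the code in dsl.jl: `@toyjyro`, `toy_sample!` and `TOY_TRACE` are made-up names, the draws go into a global `Dict` instead of a trace and handler stack threaded through each call, and plain collections such as `1:6` stand in for Distributions.jl objects.

```julia
# Toy stand-in for the trace that the real macro threads through sample!.
const TOY_TRACE = Dict{Symbol,Any}()

# Hypothetical target of the rewrite: draw a value and remember it by name.
function toy_sample!(name::Symbol, dist)
    value = rand(dist)
    TOY_TRACE[name] = value
    return value
end

# Leave anything that is not an Expr (symbols, literals, line numbers) alone.
rewrite_tilde(x) = x

# `lhs ~ rhs` parses as Expr(:call, :~, lhs, rhs); turn it into an assignment.
function rewrite_tilde(ex::Expr)
    if ex.head == :call && length(ex.args) == 3 && ex.args[1] == :~ && ex.args[2] isa Symbol
        lhs, rhs = ex.args[2], rewrite_tilde(ex.args[3])
        return :($lhs = toy_sample!($(QuoteNode(lhs)), $rhs))
    end
    # Otherwise rebuild the expression with rewritten children.
    return Expr(ex.head, map(rewrite_tilde, ex.args)...)
end

# Escape the result so the function is defined in the caller's module.
macro toyjyro(fdef)
    return esc(rewrite_tilde(fdef))
end

@toyjyro function roll_two()
    a ~ 1:6   # becomes: a = toy_sample!(:a, 1:6)
    b ~ 1:6
    return a + b
end
```

Calling `roll_two()` returns the sum of two dice and leaves both named draws in `TOY_TRACE`.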

Most importantly, the @jyro macro translates each ~ expression into a call to

function sample!(
    trace::Dict,
    handlers_stack::Array{AbstractHandler,1},
    name::Any,
    dist::Distributions.Distribution
)
    if length(handlers_stack) == 0
        return rand(dist)
    end

    initial_msg = Dict(
        :fn => rand,
        :args => (dist, ),
        :name => name,
        :value => nothing,
        :is_observed => false,
        :done => false,
        :stop => false,
        :continuation => nothing
    )
    msg = apply_stack!(trace, handlers_stack, initial_msg)
    return msg[:value]
end

Without any active handlers, sample! simply returns rand(dist). Otherwise it packs the sample site into a message and passes it to the handler stack, so any side effects of sampling (e.g. computing the log density or saving the sampled value in trace) can be conveniently defined as effect handlers. The function apply_stack! applies all effect handlers that are active at the given sample site:

function apply_stack!(
    trace::Dict,
    handlers_stack::Array{AbstractHandler,1},
    msg::Dict
)
    @assert length(handlers_stack) > 0

    handler_counter = 0
    # Loop through handlers from top of the stack to the bottom.
    for handler in handlers_stack[end:-1:1]
        handler_counter += 1
        process_message!(trace, handler, msg)
        if msg[:stop]
            break
        end
    end

    # Run the default sampler unless a handler already supplied a value
    # or marked the site as done.
    if msg[:value] === nothing && !msg[:done]
        msg[:value] = msg[:fn](msg[:args]...)
    end

    # Loop through handlers from bottom of the stack to the top.
    # If we exited the first loop early then we will start looping from the
    # handler which caused the loop to exit.
    for handler in handlers_stack[end-handler_counter+1:end]
        postprocess_message!(trace, handler, msg)
    end

    if msg[:continuation] !== nothing
        msg[:continuation](trace, msg)
    end

    return msg
end
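The two slices in apply_stack! are easy to misread. This small stand-alone sketch replays only the loop bookkeeping; the helper `visit_order` and the symbol "handlers" are purely illustrative, and `stopper` plays the role of a handler that sets `msg[:stop] = true`:

```julia
# Simulate only the traversal order of apply_stack!; handlers are plain symbols.
function visit_order(stack::Vector{Symbol}, stopper::Symbol)
    processed = Symbol[]
    handler_counter = 0
    # Same slice as apply_stack!: from the top (end) down to the bottom (1).
    for handler in stack[end:-1:1]
        handler_counter += 1
        push!(processed, handler)
        handler == stopper && break   # acts like msg[:stop] being set
    end
    # Same slice as apply_stack!: from the last visited handler back up to the top.
    postprocessed = stack[end-handler_counter+1:end]
    return processed, postprocessed
end

stack = [:bottom, :middle, :top]  # index 1 is the bottom, `end` is the top
processed, postprocessed = visit_order(stack, :middle)
```

With `:middle` stopping, `processed` is `[:top, :middle]` and `postprocessed` is `[:middle, :top]`: `:bottom` is skipped entirely, and the post-processing pass visits exactly the processed handlers in reverse order.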

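Effect handlers are subtypes of AbstractHandler. The abstract type comes with no-op default methods for the four hooks, so a concrete handler only needs to override the ones it uses: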

abstract type AbstractHandler end

function enter!(trace::Dict, h::AbstractHandler)
    return
end

function exit!(trace::Dict, h::AbstractHandler)
    return
end

function process_message!(trace::Dict, h::AbstractHandler, msg::Dict)
    return
end

function postprocess_message!(trace::Dict, h::AbstractHandler, msg::Dict)
    return
end
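As a sketch of a handler that overrides only one hook, here is a hypothetical `RecordHandler` that saves each site's final value in the trace during post-processing. It assumes the trace is a plain `Dict` keyed by site name, and it repeats the relevant no-op defaults so that it runs on its own; the message is built by hand with the field names used in sample!.

```julia
abstract type AbstractHandler end

# No-op defaults, as in the listing above.
process_message!(trace::Dict, h::AbstractHandler, msg::Dict) = nothing
postprocess_message!(trace::Dict, h::AbstractHandler, msg::Dict) = nothing

# Hypothetical handler: once a site's value is known, store it in the trace.
struct RecordHandler <: AbstractHandler end

function postprocess_message!(trace::Dict, h::RecordHandler, msg::Dict)
    trace[msg[:name]] = (value = msg[:value], observed = msg[:is_observed])
    return
end

# A hand-built message shaped like the one created in sample!.
trace = Dict()
msg = Dict(:name => :w, :value => 0.5, :is_observed => false)
postprocess_message!(trace, RecordHandler(), msg)
```

Recording in the post-processing phase means the handler sees the value no matter whether it was sampled or filled in by a handler further down the stack.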

For example, conditioning on data can be implemented as:

struct ConditionHandler <: AbstractHandler
    data::Dict
end

function process_message!(trace::Dict, h::ConditionHandler, msg::Dict)
    if haskey(h.data, msg[:name])
        msg[:value] = h.data[msg[:name]]
        msg[:stop] = true
        msg[:is_observed] = true
    end
end
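Computing the log joint probability is listed as a design goal but not shown here. A minimal sketch, assuming the handler keeps a running total in a mutable field: `LogJointHandler`, `ToyNormal` and `toy_logpdf` are hypothetical names, and `ToyNormal` stands in for a Distributions.jl distribution so that the block runs without extra packages. It relies only on sample! storing the distribution as the first entry of `msg[:args]`.

```julia
abstract type AbstractHandler end
postprocess_message!(trace::Dict, h::AbstractHandler, msg::Dict) = nothing

# Hypothetical scalar Normal, standing in for a Distributions.jl object.
struct ToyNormal
    mu::Float64
    sigma::Float64
end

# Log density of a ToyNormal at x.
toy_logpdf(d::ToyNormal, x::Real) =
    -0.5 * ((x - d.mu) / d.sigma)^2 - log(d.sigma) - 0.5 * log(2π)

# Accumulates log p(x) over every site it post-processes, latent or observed.
mutable struct LogJointHandler <: AbstractHandler
    logjoint::Float64
end
LogJointHandler() = LogJointHandler(0.0)

function postprocess_message!(trace::Dict, h::LogJointHandler, msg::Dict)
    dist = msg[:args][1]  # sample! passes the distribution as the first argument
    h.logjoint += toy_logpdf(dist, msg[:value])
    return
end

h = LogJointHandler()
trace = Dict()
# Two hand-built messages, e.g. a latent draw at 0.0 and an observed value of 1.0.
postprocess_message!(trace, h, Dict(:args => (ToyNormal(0.0, 1.0),), :value => 0.0))
postprocess_message!(trace, h, Dict(:args => (ToyNormal(0.0, 1.0),), :value => 1.0))
```

Because apply_stack! only post-processes the handlers its first loop reached, and ConditionHandler sets msg[:stop], a handler like this would only see observed sites if it sits above ConditionHandler on the stack.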

About

Probabilistic programming with effect handlers
