Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Solution to the ODE model in chapter 15 #3

Open
rasoolianbehnam opened this issue Jun 10, 2023 · 0 comments
Open

Solution to the ODE model in chapter 15 #3

rasoolianbehnam opened this issue Jun 10, 2023 · 0 comments

Comments

@rasoolianbehnam
Copy link

rasoolianbehnam commented Jun 10, 2023

Just FYI: I have written the following model for the ODE part of chapter 16 (Hare/Lynx). It works, but it is very slow.

def Ind(d, reinterpreted_batch_ndims=1, **kwargs):
    """Return ``d`` wrapped in ``tfd.Independent``.

    The rightmost ``reinterpreted_batch_ndims`` batch dimensions (one by
    default) are reinterpreted as event dimensions. Any additional keyword
    arguments, e.g. ``name``, are forwarded unchanged to ``tfd.Independent``.
    """
    options = dict(kwargs, reinterpreted_batch_ndims=reinterpreted_batch_ndims)
    return tfd.Independent(d, **options)
    
# Marker used inside the coroutine model below to flag distributions with no
# parents (the priors).
root = tfd.JointDistributionCoroutine.Root

# Number of solution times for the ODE (t = 0, 1, ..., N - 1 in get_HL).
# `data` is defined elsewhere; presumably one row per observed year — confirm.
N = len(data)
@tf.function
def get_HL(b_h, m_h, b_l, m_l, H1, L1, n_times=None, rtol=1e-3, atol=1e-3,
           max_num_steps=500):
    """Solve the Lotka-Volterra (Hare/Lynx) ODE and return both trajectories.

    The system integrated is::

        dH/dt = H * (b_h - m_h * L)
        dL/dt = L * (b_l * H - m_l)

    starting from ``H(0) = H1`` and ``L(0) = L1``. It is reported at the
    times ``0, 1, ..., n_times - 1``.

    Args:
        b_h, m_h, b_l, m_l: ODE rate parameters. These are tensors that must
            broadcast against ``H1`` / ``L1``.
        H1, L1: Initial hare and lynx populations at ``t = 0``.
        n_times: Number of solution times. Defaults to the module-level ``N``,
            which is read when the function is traced.
        rtol, atol, max_num_steps: Forwarded to ``tfp.math.ode.BDF``. The
            defaults reproduce the original hard-coded settings.

    Returns:
        ``(H, L)``: tensors with the batch shape of the inputs followed by a
        trailing time axis of length ``n_times``.
    """
    if n_times is None:
        n_times = N

    # No nested @tf.function: the outer tf.function already traces this, and
    # re-wrapping creates a fresh inner function on every outer trace.
    def ode_fn(t, y):
        # y[..., 0] is H and y[..., 1] is L. Each component's derivative is
        # its current value times its per-capita growth rate.
        H = y[..., 0]
        L = y[..., 1]
        growth = tf.stack([b_h - m_h * L, b_l * H - m_l], axis=-1)
        return growth * y

    y_init = tf.stack([H1, L1], axis=-1)
    # Use the state's float dtype for the time values explicitly, rather than
    # passing a Python int and an int32 range.
    t_init = tf.constant(0, dtype=y_init.dtype)
    solution_times = tf.cast(tf.range(n_times), y_init.dtype)

    # NOTE(review): BDF is an implicit (Jacobian-based) solver, and the state
    # here stacks every batch member together. That may be what makes sampling
    # slow. The Lotka-Volterra system is usually not stiff, so
    # tfp.math.ode.DormandPrince may be worth trying — verify results match.
    solver = tfp.math.ode.BDF(rtol=rtol, atol=atol, max_num_steps=max_num_steps)
    results = solver.solve(ode_fn, t_init, y_init, solution_times=solution_times)

    # results.states has the time axis first, i.e. (T, ..., 2). Move it to
    # just before the last axis, so that H and L come out as (..., T).
    # NOTE(review): `einsum` is defined/imported elsewhere — confirm it is
    # tf.einsum (or equivalent).
    HL = einsum("t...k->...tk", results.states)

    H = HL[..., 0]
    L = HL[..., 1]
    return H, L

@tfd.JointDistributionCoroutine
def m03():
    """Hare/Lynx population model as a joint distribution.

    Model structure:

    * Priors are placed on the four ODE rate parameters, the two observation
      noise scales, the initial populations ``H1`` / ``L1``, and the
      observation proportions ``p_h`` / ``p_l``.
    * The latent trajectories ``H`` and ``L`` come from ``get_HL``.
    * Each observed series (``H_obs``, ``L_obs``) is log-normal, centred on
      ``log(p * population)``, with independence over the time axis.
    """
    # Upper bound so these truncated normals only cover non-negative values.
    mx = tf.float32.max
    # In get_HL's ODE, b_* enter with a positive sign and m_* with a negative
    # sign.
    m_l = yield root(tfd.TruncatedNormal(1, .5, 0, mx, name='m_l'))
    m_h = yield root(tfd.TruncatedNormal(.05, .05, 0, mx, name='m_h'))
    b_l = yield root(tfd.TruncatedNormal(.05, .05, 0, mx, name='b_l'))
    b_h = yield root(tfd.TruncatedNormal(1, .5, 0, mx, name='b_h'))

    # Observation noise (log-scale standard deviations).
    sigma_h = yield root(tfd.Exponential(1, name="sigma_h"))
    sigma_l = yield root(tfd.Exponential(1, name="sigma_l"))

    # Initial populations, with the log-normal median at 10.
    H1 = yield root(tfd.LogNormal(tf.math.log(10.), 1, name='H1'))
    L1 = yield root(tfd.LogNormal(tf.math.log(10.), 1, name='L1'))

    # Presumably the fraction of each population that is observed (it scales
    # the latent population in the observation mean) — confirm.
    p_h = yield root(tfd.Beta(40, 200, name='p_h'))
    p_l = yield root(tfd.Beta(40, 200, name='p_l'))

    H, L = get_HL(b_h, m_h, b_l, m_l, H1, L1)
    # [..., None] broadcasts the per-draw scalars across the trailing time axis.
    yield Ind(tfd.LogNormal(tf.math.log(p_h[..., None]*H), sigma_h[..., None]), name="H_obs")
    yield Ind(tfd.LogNormal(tf.math.log(p_l[..., None]*L), sigma_l[..., None]), name="L_obs")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant