Provide interface to run Basix functions inside Numba-compiled kernels #201

Open
mscroggs opened this issue Apr 20, 2021 · 2 comments
Labels
enhancement New feature or request performance Performance related issues

Comments

@mscroggs
Member

Currently, a function is provided to allow dof transformations to be applied, but it would be good to allow more functions to be called without having to reimplement everything in Python.

@mscroggs mscroggs added the enhancement New feature or request label Sep 10, 2021
@garth-wells
Member

@mscroggs could you sketch out a use case?

@garth-wells garth-wells added the performance Performance related issues label Jan 25, 2022
@atouminet

Hello! Any update on this? By the way, is there a reason Numba cannot JIT-compile pybind11-generated functions?

@garth-wells Here is a real-world use case for you:
Suppose I need to interpolate the results of a FEM simulation at a very large set of points. This is required, e.g., when implementing digital volume correlation, where we seek to optimize a displacement field from two volumetric images. Because that interpolation is part of minimizing an objective function, we also need the Jacobian of the interpolation operator, which is essentially the associated sparse interpolation matrix. Each entry of that matrix is the value of one FEM basis function at one interpolation point.
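
To make the structure of that matrix concrete, here is a minimal NumPy-only sketch for the simplest case I can think of: P1 (hat) functions on a 1D mesh. It does not use Basix or DOLFINx, and the names `p1_interpolation_csr` and `csr_matvec` are made up for illustration. Each row of the matrix holds the two basis-function values at one point, stored in CSR form.

```python
import numpy as np


def p1_interpolation_csr(nodes, points):
    """Return CSR arrays (rows, cols, vals) of a 1D P1 interpolation matrix.

    nodes:  sorted 1D array of mesh vertex coordinates
    points: 1D array of points inside [nodes[0], nodes[-1]]
    """
    n_cells = nodes.size - 1
    # Index of the cell containing each point (last node goes to the last cell)
    cells = np.clip(np.searchsorted(nodes, points, side="right") - 1, 0, n_cells - 1)
    # Pull back to the reference coordinate t in [0, 1]
    t = (points - nodes[cells]) / (nodes[cells + 1] - nodes[cells])
    # Two P1 basis functions per cell: 1 - t and t
    vals = np.column_stack((1.0 - t, t)).ravel()
    cols = np.column_stack((cells, cells + 1)).ravel().astype(np.int32)
    # Every row has exactly two entries
    rows = np.arange(0, 2 * points.size + 1, 2, dtype=np.int32)
    return rows, cols, vals


def csr_matvec(rows, cols, vals, u):
    """Apply the CSR matrix to nodal values u (assumes no empty rows)."""
    return np.add.reduceat(vals * u[cols], rows[:-1])


nodes = np.array([0.0, 0.5, 2.0, 3.0])
points = np.array([0.25, 1.0, 2.5, 3.0])
rows, cols, vals = p1_interpolation_csr(nodes, points)
u = 2.0 * nodes + 1.0  # nodal values of a linear function
# P1 interpolation reproduces linear functions exactly
print(np.allclose(csr_matvec(rows, cols, vals, u), 2.0 * points + 1.0))
```

The matrix itself is the Jacobian of the interpolated values with respect to the nodal values, which is why it is needed in the optimization.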

A snippet of code to construct such a matrix would be the following (this code may contain some errors, I've not checked it, but it's just to illustrate the idea):

import numpy as np
import dolfinx.mesh, dolfinx.fem as fem
import basix
from basix.ufl_wrapper import BasixElement
from petsc4py import PETSc


def interpolation_matrix(
    x: np.ndarray,
    x_to_cell: np.ndarray,
    mesh: dolfinx.mesh.Mesh,
    element: basix.finite_element.FiniteElement,
):
    """Create a sparse interpolation matrix from a function space to a set of fixed points"""
    nx = np.shape(x)[0]

    # create dofmap
    ufl_element = BasixElement(element)
    V = fem.FunctionSpace(mesh, ufl_element)
    dofmap = V.dofmap
    tdim = element.dim
    bs = element.value_size

    rows = np.zeros(nx * bs + 1, dtype=np.int32)
    cols = np.zeros(nx * tdim * bs, dtype=np.int32)
    vals = np.zeros(nx * tdim * bs)

    num_cells = mesh.topology.index_map(mesh.topology.dim).size_global
    num_dofs_x = mesh.geometry.dofmap.links(0).size
    coords = mesh.geometry.x
    x_dofs = mesh.geometry.dofmap.array.reshape(num_cells, num_dofs_x)

    # loop needs to be JIT'd
    for k in range(nx):
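        if x_to_cell[k] < 0:  # point is outside of the mesh (assumes a negative cell index marks this)
            for b in range(bs):  # leave all bs rows of this point empty
                rows[k * bs + b + 1] = rows[k * bs + b]
            continue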

        cell = x_to_cell[k]

        vertices = np.array([coords[x_dofs[cell, i]] for i in range(num_dofs_x)])
        x_ref = mesh.geometry.cmap.pull_back(
            x[k].reshape(1, -1), vertices
        )  # need to use JIT'd basix.pull_back
        tab = element.tabulate(0, x_ref)  # need to use JIT'd basix element tabulation
        num_entries = np.shape(tab)[2]

        columns = dofmap.cell_dofs(cell)

        for b in range(bs):
            rows[k * bs + b + 1] = rows[k * bs + b] + num_entries
            cols[rows[k * bs + b] : rows[k * bs + b] + num_entries] = columns
            vals[rows[k * bs + b] : rows[k * bs + b] + num_entries] = tab[
                ..., b
            ].flatten()

    nnz = rows[-1]  # points outside the mesh leave unused trailing entries
    matrix = PETSc.Mat().createAIJWithArrays(
        size=(nx * bs, dofmap.index_map.size_global),
        csr=(rows, cols[:nnz], vals[:nnz]),
        comm=PETSc.COMM_SELF,
    )
    matrix.assemble()

    return matrix

where x_to_cell is a mapping from the interpolation points to the mesh cells, constructed beforehand.
As nx is typically very large, this Python loop performs very poorly, and in my experience, constructing this matrix can be slower than solving several FEM problems.
That code could benefit from JIT'd versions of core Basix functionality, such as tabulating a FEM basis and performing push-forward and pull-back transformations.
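
For reference, in the special case of affine P1 triangles those two operations reduce to a few lines of array arithmetic. The sketch below is plain NumPy, not Basix: `pull_back_affine` and `tabulate_p1` are hypothetical names, and it assumes the reference triangle has vertices (0, 0), (1, 0), (0, 1). It only shows what the per-point work in the loop above amounts to once it is batched over all points.

```python
import numpy as np


def pull_back_affine(x, vertices):
    """Map physical points to reference coordinates on affine triangles.

    x:        (n, 2) physical points
    vertices: (n, 3, 2) vertices of the triangle containing each point
    """
    v0 = vertices[:, 0]
    # Columns of each Jacobian are the edge vectors v1 - v0 and v2 - v0
    J = np.stack((vertices[:, 1] - v0, vertices[:, 2] - v0), axis=-1)
    # Solve J @ X = x - v0 for all points at once
    return np.linalg.solve(J, (x - v0)[..., None])[..., 0]


def tabulate_p1(X):
    """Evaluate the three P1 basis functions at reference points X of shape (n, 2)."""
    return np.column_stack((1.0 - X[:, 0] - X[:, 1], X[:, 0], X[:, 1]))


# One point in one triangle
vertices = np.array([[[1.0, 1.0], [3.0, 1.0], [1.0, 2.0]]])
x = np.array([[2.0, 1.25]])
X = pull_back_affine(x, vertices)
tab = tabulate_p1(X)
# Pushing forward with the basis values recovers the physical point
print(np.allclose(np.einsum("ni,nid->nd", tab, vertices), x))
```

Higher-order or non-affine elements need a Newton iteration for the pull-back and a general tabulation routine, which is exactly the part that would be nice to get from Basix rather than reimplement.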
