IntroScientificMachineLearning

Short intro to scientific machine learning. I used PyTorch as a framework. Some equations are not rendered properly in the .ipynb version on GitHub.

Introduction to Scientific Machine Learning


Scientific machine learning is an approach to solving problems in the domain of scientific computing using neural networks and other machine learning techniques. One primary objective is to use the strengths of neural networks to improve the way we examine scientific models. [Introduction to Scientific Machine Learning through Physics-Informed Neural Networks, Chris Rackauckas] By learning nonlinear relationships directly from ground-truth data, machine learning methods allow us to overcome the inaccuracies of an approximate mechanistic model. However, in order to produce precise predictions, conventional machine learning models rely on extensive amounts of training data. Scientific machine learning therefore merges physical models (e.g. differential equations) with classical machine learning techniques (e.g. neural networks) to generate predictions in a more data-efficient manner. For instance, Physics-Informed Neural Networks (PINNs) embed differential equations in the loss function in order to incorporate prior scientific knowledge. One drawback of PINNs is that the resulting models do not have the interpretability of classical mechanistic models.
Mechanistic models are restricted to the prior scientific knowledge available in the literature, whereas data-driven machine learning methods are more adaptable and do not rely on simplifying assumptions to deduce the underlying mathematical models. As a result, the main goal of scientific machine learning is to combine the benefits of both approaches and mitigate their individual drawbacks. [Universal Differential Equations for Scientific Machine Learning, Chris Rackauckas]

Physics-Informed Neural Networks (PINNs)

The following example closely follows Chris Rackauckas' course notes in Introduction to Scientific Machine Learning through Physics-Informed Neural Networks.

As mentioned above, PINNs either use differential equations in the cost function of a neural network as a kind of regularizer, or solve a differential equation with a neural network directly. Consequently, the mathematical equations can steer the training of the neural network in conditions where ground-truth data might not be present. We want to solve an ordinary differential equation with a given initial condition $u(0) = u_0$ on $t \in [0, 1]$:

$$u' = f(u, t), \qquad u(0) = u_0 \tag{1}$$

In an initial step, we represent an approximate solution by a neural network $NN(t)$:

$$u(t) \approx NN(t) \tag{2}$$

If $NN(t)$ were the actual solution, we could derive that $NN'(t) = f(NN(t), t)$ for all $t$. Hence, we can express our loss function in the following form:

$$L(p) = \sum_i \left( \frac{dNN(t_i)}{dt} - f(NN(t_i), t_i) \right)^2 \tag{3}$$

Therefore, one obtains $NN'(t_i) \approx f(NN(t_i), t_i)$ when the loss function is minimized, so $NN(t)$ approximately solves the differential equation. Within this study our $t_i$ values will be drawn randomly. There are also other sampling techniques available, for instance an evenly spaced grid. For advanced problems with a high number of input dimensions one should use sampling techniques which cover the input space more efficiently, e.g. Latin Hypercube sampling (a small sketch follows at the end of this derivation). [Global Sensitivity Analysis, Chris Rackauckas] Up to now the initial condition of our ordinary differential equation has not been incorporated. A first simple approach would be to incorporate the initial condition in the loss function:

$$L(p) = \big( NN(0) - u_0 \big)^2 + \sum_i \left( \frac{dNN(t_i)}{dt} - f(NN(t_i), t_i) \right)^2 \tag{4}$$

The downside of this method is that, by writing the loss function in this form, one still has a constrained optimization problem. An unconstrained optimization problem is more efficient to encode and easier to handle. Hence, we choose a trial function $g(t)$ that fulfills the initial condition by construction. [Artificial Neural Networks for Solving Ordinary and Partial Differential Equations, Isaac E. Lagaris]

$$g(t) = u_0 + t \, NN(t) \tag{5}$$

As mentioned above, $g(t)$ always satisfies the initial condition by construction, thus one only has to train the trial function to meet the requirement on its derivative, $g'(t) = f(g(t), t)$:

$$L(p) = \sum_i \left( \frac{dg(t_i)}{dt} - f(g(t_i), t_i) \right)^2 \tag{6}$$

Accordingly, we have $g'(t) \approx f(g(t), t)$ once the loss is small, while our neural network $NN(t)$ is embedded in the trial function $g(t)$. One must note that the loss function $L(p)$ depends on the parameters $p$, which correspond to the weights and biases of the neural network $NN$. In order to solve the given problem, conventional gradient-based optimization methods that find the weights minimizing the loss function can be used.
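
As a side note on the sampling of the $t_i$ mentioned above, a space-filling design such as Latin Hypercube sampling can be drawn with scipy. This is a minimal sketch, assuming scipy >= 1.7 is installed; the variable names are illustrative only:

from scipy.stats import qmc
import numpy as np

# Draw 100 quasi-random collocation points in [0, 1] x [0, 1] (e.g. for a 2D problem);
# a Latin Hypercube design covers each axis more evenly than plain uniform sampling.
sampler = qmc.LatinHypercube(d=2, seed=42)
t_lhs = sampler.random(n=100)            # shape (100, 2), values in [0, 1)
t_uniform = np.random.rand(100, 2)       # plain uniform sampling, for comparison

# Scale the design to another domain if required, e.g. [0, 10] in both dimensions
t_scaled = qmc.scale(t_lhs, l_bounds=[0.0, 0.0], u_bounds=[10.0, 10.0])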

In the next step we will look at a specific ordinary differential equation (ODE) with the intention of coding up the procedure in the machine learning framework PyTorch. The given ODE is:

$$u' = \cos(2 \pi t) \tag{7}$$

with $t \in [0, 1]$ and the known initial condition $u(0) = 1.0$. Thus we will use the following trial function as our universal approximator (UA):

$$g(t) = 1.0 + t \, NN(t) \tag{8}$$

In order to meet the requirement

$$g'(t) = \cos(2 \pi t), \qquad t \in [0, 1],$$

we need the loss function $L(p)$ as defined in Equation 6:

def loss_fn(self, outputs, targets):
    return nn.MSELoss(reduction="mean")(outputs, targets)

with:

X = torch.from_numpy(np.random.rand(10**5))
y = np.cos(2*np.pi*X) # see eq. 7
...
    for epoch in range(MAX_EPOCH):
        ...
        X = X.type(torch.float32).to(self.device)
        y = y.type(torch.float32).to(self.device)
        g = lambda t: 1.0 + t * self.model(t) # see eq. 5 and 8
        score = (g(X + self.eps) - g(X)) / self.eps # see eq. 6
        loss = self.loss_fn(score, y) # see Eq. 6
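
In the code above the derivative of the trial function is approximated with a forward finite difference and step size eps = sqrt(machine epsilon). As an alternative sketch (not used in the code below), the same derivative could be computed exactly with PyTorch's automatic differentiation; this assumes model is the NNfunc instance defined further down:

t = torch.rand(512, 1, requires_grad=True)              # randomly sampled collocation points
g = 1.0 + t * model(t)                                  # trial function, see eq. 8
dg_dt, = torch.autograd.grad(g, t,
                             grad_outputs=torch.ones_like(g),
                             create_graph=True)         # exact dg/dt instead of (g(t+eps)-g(t))/eps
loss = nn.MSELoss()(dg_dt, torch.cos(2 * np.pi * t).detach())  # see eq. 6 and 7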

In order to compare the result of the PINN with the true solution, we integrate both sides of our ODE and visualize both solutions in a diagram:

$$u(t) = \frac{\sin(2 \pi t)}{2 \pi} + 1.0 \tag{9}$$

In our study we use a simple feed-forward neural network with one hidden layer of 32 hidden units and a tanh activation function.


%matplotlib inline
import torch
import torch.nn as nn
import numpy as np
import time
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from torch.utils.data import TensorDataset, DataLoader
import warnings
warnings.filterwarnings('ignore')

# Hyperparameters
LR = 1e-2
MAX_EPOCH = 10
BATCH_SIZE = 512
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
SAVE_MODEL = False
PLOT = True

# Neural network
class NNfunc(nn.Module):
    def __init__(self):
        super(NNfunc, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(1, 32),
            nn.Tanh(),
            nn.Linear(32, 1)
        )

    def forward(self, x):
        return self.net(x)
    
# Wrapper class
class Engine:
    def __init__(self, model, optimizer, device):
        self.model = model
        self.optimizer = optimizer
        self.device = device
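        # Step size for the forward finite-difference derivative of the trial function;
        # sqrt(machine epsilon) is a common choice balancing truncation and round-off error.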
        self.eps = torch.tensor(np.sqrt(torch.finfo(torch.float32).eps))

    def loss_fn(self, outputs, targets):
        return nn.MSELoss(reduction="mean")(outputs, targets) # see eq. 6

    # Training Function
    def train(self, data_loader):
        self.model.train()
        final_loss = 0
        for X, y in data_loader:
            X = X.type(torch.float32).to(self.device)
            y = y.type(torch.float32).to(self.device)
            self.optimizer.zero_grad()
            g = lambda t: 1.0 + t * self.model(t) # see eq. 5 and 8
            score = (g(X + self.eps) - g(X)) / self.eps # see eq. 6
            loss = self.loss_fn(score, y) # see eq. 6
            loss.backward()
            self.optimizer.step()
            final_loss += loss.item()
        return final_loss / len(data_loader)

    # Evaluation Function
    def evaluate(self, data_loader):
        self.model.eval()
        final_loss = 0
        for X, y in data_loader:
            X = X.type(torch.float32).to(self.device)
            y = y.type(torch.float32).to(self.device)
            g = lambda t: 1.0 + t * self.model(t) # see eq. 5 and 8
            score = (g(X + self.eps) - g(X)) / self.eps # see eq. 6
            loss = self.loss_fn(score, y) # see Eq. 6
            final_loss += loss.item()
        return final_loss / len(data_loader)
X = torch.from_numpy(np.random.rand(10**5))
y = np.cos(2*np.pi*X) # see eq. 7

X_train, X_val, y_train, y_val = map(torch.tensor, train_test_split(X, y, test_size=0.2, random_state=42))
train_dataloader = DataLoader(TensorDataset(X_train.unsqueeze(-1), y_train.unsqueeze(-1)),
                              batch_size=BATCH_SIZE,
                              pin_memory=True, shuffle=True)
val_dataloader = DataLoader(TensorDataset(X_val.unsqueeze(-1), y_val.unsqueeze(-1)), batch_size=BATCH_SIZE,
                            pin_memory=True, shuffle=True)

model = NNfunc().to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9, nesterov=True)
eng = Engine(model, optimizer, device=DEVICE)

best_loss = np.inf
early_stopping_iter = 10
early_stopping_counter = 0

for epoch in range(MAX_EPOCH):
    train_loss = eng.train(data_loader=train_dataloader)
    val_loss = eng.evaluate(data_loader=val_dataloader)
    if epoch == 0 or (epoch+1)%5 == 0:
        print(f"Epoch: {epoch+1},\t Train Loss: {train_loss},\tValidation Loss: {val_loss}")
    if val_loss < best_loss:
        best_loss = val_loss
        early_stopping_counter = 0
        if SAVE_MODEL:
            torch.save(model.state_dict(), f"model.bin")
    else:
        early_stopping_counter += 1
    if early_stopping_counter > early_stopping_iter:
        break

if PLOT:
    X = torch.linspace(0, 1, 101)
    y = np.sin(2 * np.pi * X) / (2 * np.pi) + 1.0 # see eq. 9
    plt.plot(X, y, label='True Solution')
    g = lambda t: 1.0 + t * model(t) # see eq. 5
    plt.plot(X, [g(X_i.reshape(1, 1).to(DEVICE)) for X_i in X], label='Neural Network')
    plt.grid(True, which='both')
    plt.legend()
    plt.xlabel('t')
    plt.ylabel('u')
Epoch: 1,	 Train Loss: 0.2834244524810914,	Validation Loss: 0.022784735774621367
Epoch: 5,	 Train Loss: 0.0035874302525690216,	Validation Loss: 0.0028067371109500527
Epoch: 10,	 Train Loss: 0.0009066580848881062,	Validation Loss: 0.0008224665507441387

(Plot: true solution u(t) on [0, 1] versus the neural network approximation.)

The embedded trial function contains a neural network, which is a universal approximator (UA); due to the Universal Approximation Theorem (UAT) we know that it can approximate our nonlinear function. The UAT states that a sufficiently large neural network can approximate any continuous function on a bounded domain. [A Universal Approximation Theorem of Deep Neural Networks for Expressing Probability Distributions, Yulong Lu] In order to approximate these continuous nonlinear functions one could also use other approaches that satisfy the prerequisites of the UAT, for instance polynomials or a Fourier series. In the case of a function that depends on two dimensions, one has to take the tensor product of the one-dimensional UAs. Thus, we receive a higher-dimensional UA:

$$f(x, y) \approx \sum_{i=1}^{n} \sum_{j=1}^{n} a_{ij} \, \phi_i(x) \, \phi_j(y) \tag{10}$$

where the $\phi_i$ denote the one-dimensional basis functions (e.g. monomials or Fourier modes).

Since we have to include every combination of terms, this kind of approach results in exponential growth of the number of coefficients: with $n$ coefficients in each of the $d$ dimensions one needs $n^d$ coefficients in total. This kind of growth is the so-called curse of dimensionality. In contrast, the number of parameters of a neural network that approximates a $d$-dimensional function grows only polynomially in $d$. This property of polynomial rather than exponential growth enables neural networks to overcome the curse of dimensionality. [Introduction to Scientific Machine Learning through Physics-Informed Neural Networks, Chris Rackauckas]
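
A back-of-the-envelope sketch of this difference in growth (the choice of n = 10 coefficients per dimension and the small network sizes are purely illustrative):

# Coefficients of a tensor-product basis vs. parameters of a fixed-width MLP
def tensor_product_coeffs(n, d):
    return n ** d                                   # every combination of 1D terms

def mlp_params(d, width=32, depth=1):
    # layer sizes d -> width -> ... -> width -> 1, counting weights and biases
    sizes = [d] + [width] * depth + [1]
    return sum(i * o + o for i, o in zip(sizes[:-1], sizes[1:]))

for d in [1, 2, 4, 8]:
    print(f"d={d}: tensor product {tensor_product_coeffs(10, d):>12}, MLP {mlp_params(d):>5}")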

Harmonic Oscillator with deformities in the spring

The following example closely follows Chris Rackauckas' course notes in [Introduction to Scientific Machine Learning through Physics-Informed Neural Networks, Chris Rackauckas].

In the preceding example the differential equation contained all the information that described the physical model. In the following case we assume that we only know a part of the differential equation that describes the physical model. However, this time we have some actual measurements from the real physical system. Thus, our goal is to modify the loss function $L(p)$ of our neural network such that it incorporates the actual measurements and the portion of the ODE that we are familiar with. Suppose we have a spring with some deformities. The entire second-order differential equation that describes our physical model would be:

$$m \ddot{x} + d \dot{x} = -k x + 0.1 \sin(x) + F(t) \tag{11}$$

We have an undamped ($d = 0$) harmonic oscillator without externally applied force ($F(t) = 0$) and assume a mass of $m = 1$. Hence, the ODE reduces to the following equation:

$$\ddot{x} = -k x + 0.1 \sin(x) \tag{12}$$

The term $0.1 \sin(x)$ specifies the deformities in the spring.

Suppose we need to identify the force applied to the spring at each point in space. In the first stage we solely make use of the actual measurements on the physical system, with the limitation that we only know a few points (4 measurements) with information about the position, velocity and force of the spring at equally spaced times $t_i \in [0, 10]$.

Hence, we try to predict the expected force of the spring at any position by training a conventional feed-forward neural network (no physical information embedded).

import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import TensorDataset, DataLoader
from scipy.integrate import odeint
from scipy.interpolate import interp1d
from math import floor


# Hyperparameters
LR = 1e-2
MAX_EPOCH = 200
BATCH_SIZE = 4 # batch size is four, since we only have 4 measurement points
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
SAVE_MODEL = False
TEND = 10
PLOT = True


# Solve the second order differential equation numerically
def OdeModel(y, t, k):
    x, v = y
    dydt = [v, -k * x + 0.1 * np.sin(x)]
    return dydt

k = 1.0 # spring stiffness
x0 = [0.0, 1.0] # initial conditions (x: position, v: velocity)
t = np.linspace(0, TEND, 101)
sol = odeint(OdeModel, x0, t, args=(k,)) # numerical solution of the actual differential eq.


# True Force and Force measurements
def force(x, k):
    return -k * x + 0.1 * np.sin(x)

positions_plot = sol[:, 0]
force_plot = force(positions_plot, k) # force at the corresponding position points

t_measurement = np.linspace(0, 10, 4) # measurement timepoints
positions_data = interp1d(t, positions_plot)(t_measurement) # interpolated position points
force_measurement = interp1d(t, force_plot)(t_measurement) # interpolated true force


# Conventional neural network approach without any scientific prior
class NNForce(nn.Module):
    def __init__(self):
        super(NNForce, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(1, 32),
            nn.Tanh(),
            nn.Linear(32, 1)
        )

    def forward(self, x):
        return self.net(x)

# Wrapper class
class Engine:
    def __init__(self, model, optimizer, device):
        self.model = model
        self.optimizer = optimizer
        self.device = device

    def loss_fn(self, outputs, targets):
        return nn.MSELoss(reduction="mean")(outputs, targets)

    def train(self, data_loader):
        self.model.train()
        final_loss = 0
        for X, y in data_loader:
            X = X.type(torch.float32).to(self.device)
            y = y.type(torch.float32).to(self.device)
            self.optimizer.zero_grad()
            score = self.model(X)
            loss = self.loss_fn(score, y)
            loss.backward()
            self.optimizer.step()
            final_loss += loss.item()
        return final_loss / len(data_loader)
    
X = torch.from_numpy(positions_data)
y = torch.from_numpy(force_measurement)

X_train, y_train = map(torch.tensor, (X, y))

train_dataloader = DataLoader(TensorDataset(X_train.unsqueeze(-1), y_train.unsqueeze(-1)),
                              batch_size=BATCH_SIZE,
                              pin_memory=True, shuffle=True)

model = NNForce().to(DEVICE)
optimizer = torch.optim.SGD(model.parameters(), lr=LR, momentum=0.9, nesterov=True)
eng = Engine(model, optimizer, device=DEVICE)

for epoch in range(MAX_EPOCH):
    train_loss = eng.train(data_loader=train_dataloader)
    if epoch == 0 or (epoch+1)%50 == 0:
        print(f"Epoch: {epoch+1},\t Train Loss: {train_loss}")


t = torch.linspace(0, TEND, 101)
positions_plot = torch.from_numpy(positions_plot).float()

if PLOT:
    plt.plot(t, force_plot, 'b', label='True Force')
    plt.plot(t, [model(x_i.reshape(1, 1).to(DEVICE)) for x_i in positions_plot], 'r', label='Predicted Force')
    plt.plot(t_measurement, force_measurement, 'o', label='Force Measurements')
    plt.legend(loc='best')
    plt.grid(True, which='both')
    plt.xlabel('t')
    plt.ylabel('F')
    plt.show()
Epoch: 1,	 Train Loss: 0.04374217614531517
Epoch: 50,	 Train Loss: 0.0013166749849915504
Epoch: 100,	 Train Loss: 0.0005590120563283563
Epoch: 150,	 Train Loss: 0.0002354290772927925
Epoch: 200,	 Train Loss: 9.870142821455374e-05

(Plot: true force, predicted force of the conventional NN, and the four force measurements over time.)

It is noticeable that the neural network fits the given data at the measurement points, but not the true force curve over time. The reason for the poor fit of the true force curve is that the specified measurement points do not capture all the relevant underlying physics. Presumably one would obtain a better fit of the true force development if the measurement points had been chosen in a manner that covers the underlying characteristics more skillfully (e.g. measurement points at the peaks and valleys). An alternative is to gather more data. However, collecting more data is often very expensive or not even possible in real-life applications.
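
To illustrate the remark about the placement of the measurement points, the following sketch selects candidate measurement times at the peaks and valleys of the reference force curve computed above. It is only an illustration (in a real application the true force curve is unknown), and the names t_grid and t_measurement_alt are hypothetical:

from scipy.signal import find_peaks

t_grid = np.linspace(0, TEND, 101)                  # same time grid as above
peaks, _ = find_peaks(force_plot)                   # indices of local maxima of the true force curve
valleys, _ = find_peaks(-force_plot)                # indices of local minima of the true force curve
t_measurement_alt = np.sort(t_grid[np.concatenate((peaks, valleys))])
print(t_measurement_alt)                            # candidate measurement times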

Prediction using measurements and physical knowledge (PINNs)

Suppose we are not able to obtain more measurement data of the physical system. As mentioned above, one approach to alleviate this limitation is to incorporate physical knowledge into the neural network. However, in the present case we do not know the full underlying differential equation describing our physical model (see eq. 12). Nonetheless, we know Hooke's law, which describes the mechanics of an idealized undamped ($d = 0$) spring without externally applied force ($F(t) = 0$):

$$\ddot{x} = -k x \tag{13}$$

X_ode = 2*torch.rand(100) - 1 # see eq. 13
y_ode = -k*X_ode

The following diagram indicates the differences in position and velocity if we compute the numerical solutions of both differential equations (the "full" eq. 12 and the "simplified" eq. 13) with an ODE solver:

(Plot: position and velocity over time for the full and the simplified ODE.)

The solution of the simplified differential equation drifts near the end, but it is a helpful non-data prior that we can add to enhance the predictions of our neural network. Hence, our goal is to use the data from the measurements and nudge the prediction of our NN towards Hooke's law. Consequently, we develop a loss function that combines the loss on the measurement data $L_{data}$ and the loss on the simplified ODE $L_{ode}$:

$$L(p) = L_{data}(p) + \lambda \, L_{ode}(p) = \frac{1}{N} \sum_{i=1}^{N} \big( NN(x_i) - F_i \big)^2 + \lambda \, \frac{1}{M} \sum_{j=1}^{M} \big( NN(\tilde{x}_j) + k \, \tilde{x}_j \big)^2 \tag{14}$$

Here the $x_i, F_i$ are the measured positions and forces, and the $\tilde{x}_j$ are positions sampled uniformly from $[-1, 1]$.

for (X, y), (X_ode, y_ode) in zip(data_loader, data_loader_ode):
    X, X_ode = X.type(torch.float32).to(self.device), X_ode.type(torch.float32).to(self.device)
    y, y_ode = y.type(torch.float32).to(self.device), y_ode.type(torch.float32).to(self.device)
    ...
    score, score_ode = self.model(X), self.model(X_ode)
    loss = self.loss_fn(score, y, score_ode, y_ode) # see eq. 14
...
def loss_fn(self, outputs, targets, outputs_ode, targets_ode): # see eq. 14
        return nn.MSELoss(reduction="mean")(outputs, targets) + LAMBDA * nn.MSELoss(reduction="mean")(outputs_ode, targets_ode)

$\lambda$ (LAMBDA in the code) is a weighting hyperparameter that controls the regularization towards the scientific prior.

# PINN

# Hyperparameters
LR = 1e-2
MAX_EPOCH = 200
BATCH_SIZE = 4 # batch size is four, since we only have 4 measurement points
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
SAVE_MODEL = False
TEND = 10
PLOT = True

LAMBDA = 0.1 # weighting parameter in the loss function 

# Wrapper class
class EnginePhysicsInformed:
    def __init__(self, model, optimizer, device):
        self.model = model
        self.optimizer = optimizer
        self.device = device

    def loss_fn(self, outputs, targets, outputs_ode, targets_ode): # see eq. 14
        return nn.MSELoss(reduction="mean")(outputs, targets) + LAMBDA * nn.MSELoss(reduction="mean")(outputs_ode, targets_ode)

    def train(self, data_loader, data_loader_ode):
        self.model.train()
        final_loss = 0
        for (X, y), (X_ode, y_ode) in zip(data_loader, data_loader_ode):
            X, X_ode = X.type(torch.float32).to(self.device), X_ode.type(torch.float32).to(self.device)
            y, y_ode = y.type(torch.float32).to(self.device), y_ode.type(torch.float32).to(self.device)
            self.optimizer.zero_grad()
            score, score_ode = self.model(X), self.model(X_ode)
            loss = self.loss_fn(score, y, score_ode, y_ode)
            loss.backward()
            self.optimizer.step()
            final_loss += loss.item()
        return final_loss / len(data_loader)


X = torch.from_numpy(positions_data)
y = torch.from_numpy(force_measurement)

X_train, y_train = map(torch.tensor, (X, y))

train_dataloader = DataLoader(TensorDataset(X_train.unsqueeze(-1), y_train.unsqueeze(-1)),
                              batch_size=BATCH_SIZE,
                              pin_memory=True, shuffle=True, drop_last=True)

X_ode = 2*torch.rand(100) - 1
y_ode = -k*X_ode # see eq. 13

X_ode, y_ode = map(torch.tensor, (X_ode, y_ode))

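# The ODE batch size below is chosen such that both data loaders yield the same number
# of batches per epoch, since zip() in the training loop stops at the shorter loader.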
train_dataloader_ode = DataLoader(TensorDataset(X_ode.unsqueeze(-1), y_ode.unsqueeze(-1)),
                                  batch_size=floor(len(X_ode)/floor(len(X_train)/BATCH_SIZE)),
                                  pin_memory=True, shuffle=True, drop_last=True)


model_pin = NNForce().to(DEVICE)
optimizer = torch.optim.SGD(model_pin.parameters(), lr=LR, momentum=0.9, nesterov=True)
eng_pin = EnginePhysicsInformed(model_pin, optimizer, device=DEVICE)

for epoch in range(MAX_EPOCH):
    train_loss = eng_pin.train(data_loader=train_dataloader, data_loader_ode=train_dataloader_ode)
    if epoch == 0 or (epoch+1)%50 == 0:
        print(f"Epoch: {epoch+1},\t Train Loss: {train_loss}")


t = torch.linspace(0, TEND, 101)
positions_plot = positions_plot.float()
plt.plot(t, force_plot, 'c', label='True Force')
plt.plot(t, [model_pin(x_i.reshape(1, 1).to(DEVICE)) for x_i in positions_plot], 'm', label='Predicted Force PINN')
plt.plot(t, [model(x_i.reshape(1, 1).to(DEVICE)) for x_i in positions_plot], 'r', label='Predicted Force cNN')
plt.plot(t_measurement, force_measurement, 'o', label='Force Measurements')
plt.xlabel('t')
plt.ylabel('F')
plt.legend(loc='best')
plt.grid(True, which='both')
plt.show()
Epoch: 1,	 Train Loss: 0.07980692386627197
Epoch: 50,	 Train Loss: 0.00022909167455509305
Epoch: 100,	 Train Loss: 0.00021723960526287556
Epoch: 150,	 Train Loss: 0.0002150132495444268
Epoch: 200,	 Train Loss: 0.00021372048649936914

(Plot: true force, PINN prediction, conventional NN prediction, and the force measurements over time.)

From the diagram above it is obvious that informing the neural network training procedure with prior physical knowledge (PINN) improved the predictions of the model. It should be noted that the acronym cNN in the diagram means conventional neural network (not convolutional neural network). Moreover, one should note that in all the previous examples we did not attempt to find optimal configurations or study the influence of the various hyperparameters on the accuracy of the method. [Artificial Neural Networks for Solving Ordinary and Partial Differential Equations, Isaac E. Lagaris]

Identification of Nonlinear Interactions with Universal Ordinary Differential Equations

The example in this section is related to section 2.1 of the following paper [Universal Differential Equations for Scientific Machine Learning].

In this section we support the symbolic regression framework SINDy with a Universal Ordinary Differential Equation (UODE). Thereby we provide a method to reconstruct the unknown underlying terms of our dynamical system in a more data-efficient manner. We demonstrate the approach with the Lotka-Volterra system:

$$\dot{x} = \alpha x - \beta x y, \qquad \dot{y} = -\delta y + \gamma x y \tag{15}$$

We suppose that we have measurements of the state variables $x(t)$ (prey) and $y(t)$ (predator) for a short time series $t \in [0, 3]$. Moreover, we assume that we know the birth rate $\alpha$ of the prey and the death rate $\delta$ of the predator. With the given information we introduce a UODE with prior biological knowledge:

$$\dot{x} = \alpha x + U_1(x, y), \qquad \dot{y} = -\delta y + U_2(x, y) \tag{16}$$

In this manner we include the previously known information and allow the identification of the unknown interaction between predators and prey. The true values of the unknown interaction terms are $U_1(x, y) = -\beta x y$ and $U_2(x, y) = \gamma x y$. Detecting the unknown interactions between the state variables in the UODE amounts to training a universal approximator (UA) $U_\theta : \mathbb{R}^2 \to \mathbb{R}^2$. The UA is represented as a fully connected neural network with 2 input variables, 3 hidden layers with 5 hidden neurons each, a Gaussian RBF activation function and 2 outputs:

# Universal Approximator (UA) U: R^2 -> R^2
class UA(nn.Module):
    def __init__(self):
        super(UA, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 5),
            self.rbf(),
            nn.Linear(5, 5),
            self.rbf(),
            nn.Linear(5, 5),
            self.rbf(),
            nn.Linear(5, 2)
        )
    # Gaussian RBF activation function
    class rbf(nn.Module):
        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return torch.exp(-torch.pow(input, 2))

    def forward(self, t, x):
        return self.net(x)    
    

The UA $U_\theta$ is embedded in the UODE (see eq. 16):

# Universal Ordinary Differential Equation (UODE)
class UODE(nn.Module):
    def __init__(self):
        super(UODE, self).__init__()
        self.ua = UA()

    def uode_sys(self, u, x, y):
        u1, u2 = torch.split(u, 1)
        dx_dt = torch.add(torch.mul(alpha, x), u1)
        dy_dt = torch.add(torch.mul(-delta, y), u2)

        return torch.cat([dx_dt, dy_dt])

    def forward(self, t, xy):
        with torch.set_grad_enabled(True):
            self.xy = xy
            x, y = torch.split(xy, 1)
            self.xy = self.ua(None, self.xy)
            self.xy = self.uode_sys(self.xy, x, y)

        return self.xy  
    

In the following code section we use the torchdiffeq library to solve the UODE system. The code was modified in such a way that we do not have to train the model again: we load the state dict of the pretrained UODE and predict the time derivatives of the unknown interaction with the UA.

with torch.no_grad():
    # Prediction of the state variables with UODE System
    pred_xy = torchdiffeq.odeint(model, true_xy0, t)
    # Prediction of the derivatives of the unknown part 
    # state variables with the UA (!UA is part of UODE)
    pred_interaction_dot = model.ua(None, pred_xy)
import torch
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from scipy import integrate
import torchdiffeq
import pysindy as ps

# Controls
DEVICE = torch.device("cpu")
LOAD_MODEL = True
SAVE_MODEL = False
PLOT = True
OPTIMIZER = "LBFGS"
LOAD_MODEL_FILE = "trainedmodel/model_statedict.pth.tar"

# Hyperparameters
MAX_EPOCHS = 0
LEARNING_RATE  = 0.01
WEIGHT_DECAY = 0.0
TEND = 3
DT = 0.02

# Parameters
true_xy0 = torch.tensor([0.44249296,4.6280594])
t = torch.linspace(0, TEND, int((TEND-0)/DT + 1))
alpha, beta, gamma, delta = torch.tensor([1.3, 0.9, 0.8, 1.8])

# Universal Approximator (UA) U: R^2 -> R^2
class UA(nn.Module):
    def __init__(self):
        super(UA, self).__init__()
        self.net = nn.Sequential(
            nn.Linear(2, 5),
            self.rbf(),
            nn.Linear(5, 5),
            self.rbf(),
            nn.Linear(5, 5),
            self.rbf(),
            nn.Linear(5, 2)
        )

    # Gaussian RBF activation function
    class rbf(nn.Module):
        def forward(self, input: torch.Tensor) -> torch.Tensor:
            return torch.exp(-torch.pow(input, 2))

    def forward(self, t, x):
        return self.net(x)
    
# Universal Ordinary Differential Equation (UODE)
class UODE(nn.Module):
    def __init__(self):
        super(UODE, self).__init__()
        self.ua = UA()

    def uode_sys(self, u, x, y):
        u1, u2 = torch.split(u, 1)
        dx_dt = torch.add(torch.mul(alpha, x), u1)
        dy_dt = torch.add(torch.mul(-delta, y), u2)

        return torch.cat([dx_dt, dy_dt])

    def forward(self, t, xy):
        with torch.set_grad_enabled(True):
            self.xy = xy
            x, y = torch.split(xy, 1)
            self.xy = self.ua(None, self.xy)
            self.xy = self.uode_sys(self.xy, x, y)

        return self.xy
    
# Full ODE System with unknown interaction
class ODE_Full(nn.Module):
    def forward(self, t, xy):
        x, y = xy
        dx_dt = torch.Tensor([alpha*x - beta*x*y])
        dy_dt = torch.Tensor([-delta*y + gamma*x*y])
        return torch.cat([dx_dt, dy_dt])

# ODE System with only the unknown interaction
class ODE_Part(nn.Module):
    def forward(self, t, xy):
        x, y = xy
        dx_dt = torch.Tensor([- beta*x*y])
        dy_dt = torch.Tensor([gamma*x*y])
        return torch.cat([dx_dt, dy_dt])
    
# Utilities
# Load model and optimizer state_dict
def load_checkpoint(checkpoint, model, optimizer):
    print("=> Loading checkpoint")
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])

# Save model and optimizer state_dict
def save_checkpoint(state, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    torch.save(state, filename)
    
# Solving the ODE system, all parameters given
with torch.no_grad():
    # True solution for the state variables
    true_xy = torchdiffeq.odeint(ODE_Full(), true_xy0, t, method='dopri5')

# Define the Universal Ordinary Equation model
model = UODE()
optimizer = torch.optim.LBFGS(model.parameters(), lr=LEARNING_RATE, )
if LOAD_MODEL:
    load_checkpoint(torch.load(LOAD_MODEL_FILE), model, optimizer)

# Train the model
for itr in range(1, MAX_EPOCHS + 1):
    def closure():
        optimizer.zero_grad()
        pred_xy = torchdiffeq.odeint(model, true_xy0, t)
        loss = torch.sum((pred_xy - true_xy) ** 2)
        print(f"Epoch: {itr}, Loss: {loss}")
        loss.backward()
        return loss
    optimizer.step(closure)
    
# Save model and optimizer state_dict
if SAVE_MODEL:
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    save_checkpoint(checkpoint, filename=LOAD_MODEL_FILE)
    
# deactivate autograd engine; reduces memory usage, but no backprop possible and no gradients saved
with torch.no_grad():
    # Prediction of the state variables with UODE System
    pred_xy = torchdiffeq.odeint(model, true_xy0, t)
    # Prediction of the derivatives of the unknown part 
    # state variables with the UA (!UA is part of UODE)
    pred_interaction_dot = model.ua(None, pred_xy)

true_u1 = torch.mul(-beta, torch.mul(pred_xy[:,0], pred_xy[:,1]))
true_u2 = torch.mul(gamma, torch.mul(pred_xy[:,0], pred_xy[:,1]))
# True solution of the derivatives of the unknown part state variables
true_interaction_dot = torch.transpose(torch.stack([true_u1, true_u2]), 0, 1)

# Solving the unknown part of the ODE system, all parameters given
with torch.no_grad():
    # True solution of the unknown part of the state variables
    true_interaction = torchdiffeq.odeint(ODE_Part(), true_xy0, t, method='dopri5')
    
if PLOT:
    # index for scatter plot
    idx = np.arange(0, len(t), 20)
    
    fig, axes = plt.subplots(2, 1, figsize=(15, 10))

    axes[0].scatter(t[idx], true_xy[idx, 0], s=20, label="Measurements x(t)")
    axes[0].scatter(t[idx], true_xy[idx, 1], s=20, label="Measurements y(t)")
    axes[0].plot(t, pred_xy[:,0], label="UODE Approximation x(t)")
    axes[0].plot(t, pred_xy[:,1], label="UODE Approximation y(t)")
    axes[0].grid(True, which='both')
    axes[0].set(xlabel='t', ylabel='x, y', title='Approximation of the full ODE-System, x, y')
    axes[0].legend()

    axes[1].scatter(t[idx], true_interaction_dot[idx, 0], s=20, label="True Unknown Interaction x(t)")
    axes[1].scatter(t[idx], true_interaction_dot[idx, 1], s=20, label="True Unknown Interaction y(t)")
    axes[1].plot(t, pred_interaction_dot[:,0], label="UA Approximation x(t)")
    axes[1].plot(t, pred_interaction_dot[:,1], label="UA Approximation y(t)")
    axes[1].grid(True, which='both')
    axes[1].set(xlabel='t', ylabel=r"$\dot x, \dot y$", 
                title=r'Approximation of the unknown part of the ODE-System, $\dot x, \dot y$')
    axes[1].legend()
=> Loading checkpoint

(Plots: top, UODE approximation of x(t) and y(t) against measurements; bottom, UA approximation of the unknown interaction terms against their true values.)

The first chart compares the UODE approximation with some measurements of the state variables of the entire ODE system within the time interval $t \in [0, 3]$. The second diagram indicates the differences between the true unknown interactions $-\beta x y$, $\gamma x y$ and the predictions of our UA $U_\theta(x, y)$. In the following code cell we make use of the SINDy algorithm, which enables us to identify the underlying mathematical terms of the unknown interactions.

SINDy: Sparse Identification of Nonlinear Dynamical Systems

Most physical systems only rely on a few significant terms that define the dynamics. This property makes the governing equations sparse in a high-dimensional nonlinear function space. SINDy seeks to describe the time derivatives of the given measurements in terms of a nonlinear function $\mathbf{f}(\mathbf{x}(t))$:

$$\frac{d}{dt}\mathbf{x}(t) = \mathbf{f}(\mathbf{x}(t)) \tag{17}$$

The vector $\mathbf{x}(t) \in \mathbb{R}^n$ represents the state variables of our physical system at time $t$, while the function $\mathbf{f}(\mathbf{x}(t))$ denotes the dynamic constraints that define how the system changes with time. As mentioned above, the main idea behind SINDy is that the function $\mathbf{f}$ often contains only a few terms. Therefore, $\mathbf{f}$ is sparse in the space of an a priori selected set of basis functions $\theta_k(\mathbf{x})$. If we choose the set of basis functions in an appropriate manner, $\mathbf{f}$ can be formulated as a linear combination of these basis functions (e.g. polynomial basis functions) where most of the coefficients are equal to 0:

$$\mathbf{f}(\mathbf{x}) \approx \sum_{k=1}^{p} \theta_k(\mathbf{x}) \, \xi_k \tag{18}$$

The data of the measured state variables and the corresponding time derivatives at times $t_1, \dots, t_m$ are arranged in the matrices $\mathbf{X}$ and $\dot{\mathbf{X}}$:

$$\mathbf{X} = \begin{bmatrix} x_1(t_1) & x_2(t_1) & \cdots & x_n(t_1) \\ \vdots & \vdots & & \vdots \\ x_1(t_m) & x_2(t_m) & \cdots & x_n(t_m) \end{bmatrix}, \qquad \dot{\mathbf{X}} = \begin{bmatrix} \dot{x}_1(t_1) & \dot{x}_2(t_1) & \cdots & \dot{x}_n(t_1) \\ \vdots & \vdots & & \vdots \\ \dot{x}_1(t_m) & \dot{x}_2(t_m) & \cdots & \dot{x}_n(t_m) \end{bmatrix} \tag{19}$$

The library matrix $\Theta(\mathbf{X})$ consists of a set of selected candidate basis functions, for instance constant, polynomial or trigonometric terms:

$$\Theta(\mathbf{X}) = \begin{bmatrix} \mathbf{1} & \mathbf{X} & \mathbf{X}^{P_2} & \mathbf{X}^{P_3} & \cdots & \sin(\mathbf{X}) & \cos(\mathbf{X}) & \cdots \end{bmatrix} \tag{20}$$

where, for example, $\mathbf{X}^{P_2}$ denotes the matrix of all quadratic monomials of the state variables.

Thus, each column of $\Theta(\mathbf{X})$ denotes a candidate function for the right-hand side of equation 17. Since most of the coefficients of these nonlinearities in $\mathbf{f}$ are 0, we introduce sparse vectors of coefficients $\xi_k$ that define which nonlinearities are active. This leads to the following sparse regression problem:

$$\dot{\mathbf{X}} = \Theta(\mathbf{X}) \, \Xi \tag{21}$$

Each vector $\xi_k$ contains the coefficients of the linear combination of basis functions that describes the dynamics of the $k$-th state variable (see eq. 18):

$$\Xi = \begin{bmatrix} \xi_1 & \xi_2 & \cdots & \xi_n \end{bmatrix} \tag{22}$$

[Discovering governing equations from data by sparse identification of nonlinear dynamical systems, Steven L. Brunton et al.] In the following code cell we define the SINDy model with the pysindy library. The main difference of the UODE-enhanced approach compared to the regular use of the SINDy algorithm is that we do not have to approximate the derivatives numerically (e.g. with finite-difference methods). In our case we have the UA as an estimator of the derivatives of the unknown interaction terms. Thus, we perform the sparse regression on the output of the UA to obtain just the equation of the unknown interactions. In our simplified example we only used a polynomial basis of degree 2; a more elaborate hyperparameter search would be possible with the help of scikit-learn's GridSearchCV.
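
Before calling pysindy, the following toy sketch illustrates the sparse regression idea from equation 21 with a sequentially thresholded least-squares step in plain numpy. The synthetic data and all variable names are purely illustrative:

import numpy as np

rng = np.random.default_rng(0)
x, y = rng.uniform(0.1, 2.0, 200), rng.uniform(0.1, 2.0, 200)
x_dot = -0.9 * x * y                                                 # "measured" derivative, true term is -0.9*x*y
Theta = np.column_stack([np.ones_like(x), x, y, x * y, x**2, y**2])  # candidate library [1, x, y, xy, x^2, y^2]

xi, *_ = np.linalg.lstsq(Theta, x_dot, rcond=None)                   # ordinary least squares
for _ in range(10):                                                  # sequentially thresholded least squares
    small = np.abs(xi) < 0.05
    xi[small] = 0.0                                                  # zero out small coefficients
    xi[~small], *_ = np.linalg.lstsq(Theta[:, ~small], x_dot, rcond=None)  # refit active terms only
print(xi)  # ideally only the x*y coefficient (close to -0.9) remains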

# Sparse identification of nonlinear dynamics to identify the underlying mathematical
# equation of the unknown interaction

# define the parameters for the model; the proper parameters can be determined
# with e.g. a grid search (GridSearchCV from scikit-learn)
feature_library  = ps.feature_library.PolynomialLibrary(degree=2)
optimizer  = ps.optimizers.STLSQ(threshold=0.05)
# Define the SINDy Model
model_SINDy = ps.SINDy(
    feature_library=feature_library,
    optimizer=optimizer,
    feature_names=["x", "y"]
)

# Prepare variables for SINDy
pred_interaction_dot = pred_interaction_dot.squeeze().numpy()
t = t.unsqueeze(-1).numpy()
pred_xy = pred_xy.numpy()

# Fit the SINDy model
model_SINDy.fit(pred_xy, t=t, x_dot=pred_interaction_dot)
print("SINDy Predicton - unknown part of the ODE:")
model_SINDy.print()

# Param Grid for Gridsearch Method to identify the optimal parameters for SINDy Algorithm
# model_SINDy = ps.SINDy(t_default=DT)
# param_grid = {
#     "optimizer__threshold": [0.05],
#     "optimizer": [ps.optimizers.STLSQ()],
#     "feature_library": [ps.feature_library.PolynomialLibrary()],
#     "feature_library__degree": [2, 3],
#     "feature_names": [["x", "y"]]
# }
#
# search = GridSearchCV(
#     model_SINDy,
#     param_grid,
#     cv=TimeSeriesSplit(n_splits=10)
# )

# search.fit(pred_xy, t=t, x_dot=pred_interaction_dot)
# print("Best parameters:", search.best_params_)
# model_SINDy = search.best_estimator_
# print("SINDy Predicton - unknown part of the ODE:")
# model_SINDy.print()
SINDy Prediction - unknown part of the ODE:
x' = -0.885 x y
y' = 0.781 x y

The SINDy algorithm yields approximations of $-\beta$ and $\gamma$ of -0.885 and 0.781, and recovers the bilinear term $x y$ for the unknown interactions:

$$U_1(x, y) \approx -0.885 \, x y, \qquad U_2(x, y) \approx 0.781 \, x y$$


The true values of $-\beta$ and $\gamma$ are -0.9 and 0.8:

$$U_1^{true}(x, y) = -\beta \, x y = -0.9 \, x y, \qquad U_2^{true}(x, y) = \gamma \, x y = 0.8 \, x y$$

In the following code cell we compare the approximation of the Lotka-Volterra equations obtained with UODE-enhanced SINDy against the true state variables.

# ODE-System with the SINDy recovered dynamics implemented
def SINDy_recovered_dynamics(xy, t):
    xy = np.expand_dims(xy, axis=0)
    x, y = xy[0]
    
    # SINDy recovered dynamics
    u1, u2 = model_SINDy.predict(xy)[0]
    return [alpha*x + u1,
            -delta*y + u2]

# untrained t-values
t_ = torch.linspace(0, 20, int((20-0)/DT + 1))
# True solution to untrained t-values
true_xy_ = torchdiffeq.odeint(ODE_Full(), true_xy0, t_, method='dopri5')
# SINDy recovered predictions to untrained t-values
recovered_xy_ = integrate.odeint(SINDy_recovered_dynamics, true_xy0, t_)

if PLOT:
    # index for scatter plot
    idx = np.arange(0, len(t_), 10)
    
    fig, ax = plt.subplots(1, 1, figsize=(15, 5))
    ax.scatter(t_[idx], true_xy_[idx,0], label="True Solution x(t)")
    ax.scatter(t_[idx], true_xy_[idx,1], label="True Solution y(t)")
    ax.plot(t_, recovered_xy_[:,0], label="Recovered Dynamics x(t)")
    ax.plot(t_, recovered_xy_[:,1], label="Recovered Dynamics y(t)")
    ax.scatter(t, true_xy[:,0], marker='v', label="Trained x(t)")
    ax.scatter(t, true_xy[:,1], marker='v', label="Trained y(t)")
    ax.set(xlabel="t", ylabel="x, y", title='Extrapolated predictions from short traing data')
    ax.grid(True, which='both')
    ax.legend()
    
    fig, ax = plt.subplots(1, 1, figsize=(15, 5))
    ax.plot(t_,np.abs(true_xy_[:,0] - recovered_xy_[:,0]), label="x(t)")
    ax.plot(t_,np.abs(true_xy_[:,1] - recovered_xy_[:,1]), label="y(t)")
    ax.set(yscale='log', xlabel='t', ylabel='Absolute Error', title='UODE Error')
    ax.grid(True, which='both')
    ax.legend()
    
    fig = plt.figure(figsize=(10, 10))
    ax = plt.axes(projection='3d')
    ax.plot(pred_xy[:,0], pred_xy[:,1],pred_interaction_dot[:,1], label=r"Universal Approximator, $U_{\theta, 2}(x, y)$")
    ax.plot(pred_xy[:,0], pred_xy[:,1],true_interaction_dot[:,1], label=r"True missing term, $\gamma x y$")
    ax.set(yscale='linear', xlabel='x(t)', ylabel='y(t)', zlabel=r'$U_{2}(x(t), y(t))$', title=r'True and predicted missing term $U_{2}(x(t), y(t)) \, \forall t \in [0, 3]$')
    ax.grid(True, which='both')
    ax.legend()

(Plot: extrapolated predictions from short training data, true solution vs. SINDy-recovered dynamics for x(t) and y(t).)

(Plot: absolute error of the recovered dynamics against the true solution over time, log scale.)

(Plot: 3D view of the UA prediction $U_{\theta,2}(x, y)$ and the true missing term $\gamma x y$.)

In the first chart we compare the extrapolation of the knowledge-enhanced SINDy approximation with the true state variables $x(t)$ and $y(t)$. The triangles show some of the data points which were used to fit the UODE ($t \in [0, 3]$). The dots show the true solution of the ODE beyond the measurement data we used for training ($t \in (3, 20]$). The solid lines display the extrapolated solution of the equation recovered by UODE-enhanced SINDy. It is central to note that even though the measured data, which we used to fit the UODE-enhanced SINDy approach, did not include an entire period of the cyclic solution, the resulting approximation was able to extrapolate quite accurately. The second diagram displays the absolute error of the UODE-enhanced SINDy approach against the true state variables $x(t)$ and $y(t)$; we observe that the error increases for larger values of $t$. The final graph illustrates the values of the unknown interaction $U_2(x, y)$ evaluated with the UA and the true values $\gamma x y$.

Furthermore, there are many more fields of application for the UDE framework:
[Mixing Differential Equations and Neural Networks for Physics-Informed Learning, Chris Rackauckas]

  • Extensions to other Differential Equations
    • Universal Stochastic Differential Equations (USDEs)
    • Universal Differential-Algebraic Equations (UDAEs)
    • Universal Delay Differential Equations (UDDEs)
    • Universal Boundary Value Problems (UBVPs)
    • Universal Partial Differential Equations (UPDEs)
    • etc.
  • Deep BSDE Methods for High Dimensional Partial Differential Equations
    • Nonlinear Black-Scholes
    • Stochastic Optimal Control
    • etc.
  • Surrogate Acceleration Methods
  • etc.
The methods used in the scope of the previous studies incorporated data into our models using point estimates. However, in real-life applications data has uncertainty and noise. Consequently, we should extend our modeling methods in order to incorporate probabilistic estimates:
[From Optimization to Probabilistic Programming, Chris Rackauckas]
  • Bayesian Estimation with Point Estimates (point estimates for the most probable parameters)
  • Bayesian Estimation of Posterior Distributions
    • Sampling Based Approach (Monte Carlo, Metropolis Hastings Algorithm, Hamiltonian Monte Carlo)
    • Variational Inference (Automatic Differentiation Variational Inference (ADVI))
Furthermore, the question arises how the output of the model generally changes with a change in the input. For this we need an understanding of the global sensitivity of the model. This enables us, for example, to identify whether there are variables which have no actual effect on the output. We could use this insight to reduce our model by dropping the corresponding terms.
[Global Sensitivity Analysis, Chris Rackauckas]
  • The Morris One-At-A-Time (OAT) Method
  • Sobol's Method (ANOVA)
  • Fourier Amplitude Sensitivity Sampling (FAST) and eFAST
  • etc.
Moreover, we must note that our neural differential equations can lead to inexact results due to truncation, numerical and measurement errors. Thus, we want to identify the types and sources of these uncertainties. This subject is known as Uncertainty Quantification (UQ).

In the scope of this study we used PyTorch and the corresponding available libraries to solve problems in the field of scientific machine learning. For more elaborate studies in the area of scientific machine learning, it is recommended to use the open-source software SciML, based on the programming language Julia. [SciML Scientific Machine Learning Software]
