Skip to content

Commit

Permalink
Support passing pointer to Cell-to-Flux function.
Browse files Browse the repository at this point in the history
  • Loading branch information
pvc1989 committed Mar 11, 2024
1 parent b3ac68a commit 0d50acf
Show file tree
Hide file tree
Showing 4 changed files with 25 additions and 13 deletions.
6 changes: 4 additions & 2 deletions include/mini/spatial/dg/general.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,13 @@ class General : public spatial::FiniteElement<Part> {
}

protected: // implement pure virtual methods declared in Base
void AddFluxDivergence(Cell const &cell, Scalar *data) const override {
using CellToFlux = typename Base::CellToFlux;
void AddFluxDivergence(CellToFlux cell_to_flux, Cell const &cell,
Scalar *data) const override {
const auto &gauss = cell.gauss();
for (int q = 0, n = gauss.CountPoints(); q < n; ++q) {
const auto &xyz = gauss.GetGlobalCoord(q);
auto flux = Base::GetFluxMatrix(cell.projection(), q);
auto flux = cell_to_flux(cell, q);
flux *= gauss.GetGlobalWeight(q);
auto grad = cell.projection().GlobalToBasisGradients(xyz);
Coeff prod = flux * grad;
Expand Down
14 changes: 8 additions & 6 deletions include/mini/spatial/dg/lobatto.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,16 +82,17 @@ class Lobatto : public General<Part> {
return gauss.GetGlobalWeight(q);
}
using FluxMatrix = typename Riemann::FluxMatrix;
static FluxMatrix GetWeightedFluxMatrix(
using CellToFlux = typename Base::CellToFlux;
static FluxMatrix GetWeightedFluxMatrix(CellToFlux cell_to_flux,
const Cell &cell, int q) requires(kLocal) {
auto flux = Base::GetFluxMatrix(cell.projection(), q);
auto flux = cell_to_flux(cell, q);
flux = cell.projection().GlobalFluxToLocalFlux(flux, q);
flux *= GetWeight(cell.gauss(), q);
return flux;
}
static FluxMatrix GetWeightedFluxMatrix(
static FluxMatrix GetWeightedFluxMatrix(CellToFlux cell_to_flux,
const Cell &cell, int q) requires(!kLocal) {
auto flux = Base::GetFluxMatrix(cell.projection(), q);
auto flux = cell_to_flux(cell, q);
flux *= GetWeight(cell.gauss(), q);
return flux;
}
Expand Down Expand Up @@ -138,10 +139,11 @@ class Lobatto : public General<Part> {
}

protected: // override virtual methods defined in Base
void AddFluxDivergence(Cell const &cell, Scalar *data) const override {
void AddFluxDivergence(CellToFlux cell_to_flux, Cell const &cell,
Scalar *data) const override {
const auto &gauss = cell.gauss();
for (int q = 0, n = gauss.CountPoints(); q < n; ++q) {
auto flux = GetWeightedFluxMatrix(cell, q);
auto flux = GetWeightedFluxMatrix(cell_to_flux, cell, q);
auto const &grad = cell.projection().GetBasisGradients(q);
Coeff prod = flux * grad;
cell.projection().AddCoeffTo(prod, data);
Expand Down
12 changes: 9 additions & 3 deletions include/mini/spatial/fem.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -220,11 +220,15 @@ class FiniteElement : public temporal::System<typename Part::Scalar> {
}

protected: // declare pure virtual methods to be implemented in subclasses
virtual void AddFluxDivergence(Cell const &cell, Scalar *data) const = 0;
using FluxMatrix = typename Riemann::FluxMatrix;
using CellToFlux = FluxMatrix (*)(const Cell &cell, int q);
virtual void AddFluxDivergence(CellToFlux cell_to_flux, Cell const &cell,
Scalar *data) const = 0;
virtual void AddFluxDivergence(Column *residual) const {
FluxMatrix (*func)(const Cell &, int) = &GetFluxMatrix;
for (const Cell &cell : this->part_ptr_->GetLocalCells()) {
auto *data = this->AddCellDataOffset(residual, cell.id());
this->AddFluxDivergence(cell, data);
this->AddFluxDivergence(func, cell, data);
}
}
virtual void AddFluxOnLocalFaces(Column *residual) const = 0;
Expand All @@ -237,7 +241,6 @@ class FiniteElement : public temporal::System<typename Part::Scalar> {
virtual void ApplySmartBoundary(Column *residual) const = 0;

protected:
using FluxMatrix = typename Riemann::FluxMatrix;
static FluxMatrix GetFluxMatrix(const Projection &projection, int q)
requires(!mini::riemann::Diffusive<Riemann>) {
return Riemann::GetFluxMatrix(projection.GetValue(q));
Expand All @@ -250,6 +253,9 @@ class FiniteElement : public temporal::System<typename Part::Scalar> {
Riemann::MinusViscousFlux(value, gradient, &flux_matrix);
return flux_matrix;
}
static FluxMatrix GetFluxMatrix(const Cell &cell, int q) {
return GetFluxMatrix(cell.projection(), q);
}
};

} // namespace spatial
Expand Down
6 changes: 4 additions & 2 deletions include/mini/spatial/fr/general.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -228,11 +228,13 @@ class General : public spatial::FiniteElement<Part> {
}

protected: // override virtual methods defined in Base
void AddFluxDivergence(Cell const &cell, Scalar *data) const override {
using CellToFlux = FluxMatrix (*)(const Cell &cell, int q);
void AddFluxDivergence(CellToFlux cell_to_flux, Cell const &cell,
Scalar *data) const override {
const auto &gauss = cell.gauss();
std::array<FluxMatrix, kCellQ> flux;
for (int q = 0, n = gauss.CountPoints(); q < n; ++q) {
FluxMatrix global_flux = Base::GetFluxMatrix(cell.projection(), q);
FluxMatrix global_flux = cell_to_flux(cell, q);
flux[q] = cell.projection().GlobalFluxToLocalFlux(global_flux, q);
}
for (int q = 0, n = gauss.CountPoints(); q < n; ++q) {
Expand Down

0 comments on commit 0d50acf

Please sign in to comment.