Commit

feat: add gelu non-linearity (#183)

sid-alluri authored Apr 14, 2023
1 parent af1e187 commit a3017cf
Showing 21 changed files with 321 additions and 155 deletions.
5 changes: 2 additions & 3 deletions .github/workflows/rust.yml
@@ -277,7 +277,7 @@ jobs:
- name: Install solc
run: (hash svm 2>/dev/null || cargo install svm-rs) && svm install 0.8.17 && solc --version
- name: KZG prove and verify aggr tests
- run: cargo test --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_ -- --include-ignored
+ run: cargo test --release --verbose tests_evm::kzg_evm_aggr_prove_and_verify_ -- --include-ignored --test-threads 8

examples:
runs-on: ubuntu-latest-32-cores
@@ -323,7 +323,7 @@ jobs:
- uses: jetli/wasm-pack-action@v0.4.0
- uses: actions/setup-python@v4
with:
- python-version: '3.7'
+ python-version: "3.7"
- uses: actions-rs/toolchain@v1
with:
toolchain: nightly-2022-11-03
@@ -339,4 +339,3 @@ jobs:
run: source .env/bin/activate; maturin develop --features python-bindings
- name: Run pytest
run: source .env/bin/activate; pytest

7 changes: 7 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions Cargo.toml
@@ -36,6 +36,7 @@ colored = { version = "2.0.0", optional = true}
env_logger = { version = "0.10.0", optional = true}
colored_json = { version = "3.0.1", optional = true}
tokio = { version = "1.26.0", features = ["macros", "rt"] }
+puruspe = "0.2.0"
bincode = "*"

# python binding related deps
40 changes: 40 additions & 0 deletions examples/onnx/1l_erf/gen.py
@@ -0,0 +1,40 @@
import json
import torch
from torch import nn

class Circuit(nn.Module):
    def __init__(self):
        super(Circuit, self).__init__()

    def forward(self, x):
        return torch.special.erf(x)

def main():
    torch_model = Circuit()
    # Input to the model
    shape = [3]
    x = torch.rand(1, *shape, requires_grad=True)
    torch_out = torch_model(x)
    # Export the model
    torch.onnx.export(torch_model,               # model being run
                      x,                         # model input (or a tuple for multiple inputs)
                      "network.onnx",            # where to save the model (can be a file or file-like object)
                      export_params=True,        # store the trained parameter weights inside the model file
                      opset_version=10,          # the ONNX version to export the model to
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names=['input'],     # the model's input names
                      output_names=['output'],   # the model's output names
                      dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                    'output': {0: 'batch_size'}})

    d = x.detach().numpy().reshape([-1]).tolist()

    data = dict(input_shapes=[shape],
                input_data=[d],
                output_data=[o.detach().numpy().reshape([-1]).tolist() for o in torch_out])

    # Serialize data into file:
    json.dump(data, open("input.json", 'w'))

if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions examples/onnx/1l_erf/input.json
@@ -0,0 +1 @@
{"input_data":[[0.7970018,0.14464009,0.4023286]],"input_shapes":[[3]],"output_data":[[0.734375,0.1640625,0.421875]]}
13 changes: 13 additions & 0 deletions examples/onnx/1l_erf/network.onnx
Text-rendered ONNX protobuf (PyTorch 2.0.0 export of a single Erf node, with a dynamic batch_size axis on input and output); raw bytes omitted.
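To sanity-check the export, the operators in the saved graph can be listed with the onnx Python package. A minimal sketch, assuming the gen.py above has been run so that network.onnx exists in the working directory:

import onnx

# Load the exported model and list its operators; the single-op graph
# produced by gen.py above should contain exactly one Erf node.
model = onnx.load("network.onnx")
print([node.op_type for node in model.graph.node])  # expected: ['Erf']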
41 changes: 41 additions & 0 deletions examples/onnx/1l_gelu_noappx/gen.py
@@ -0,0 +1,41 @@
import json
import torch
from torch import nn

class Circuit(nn.Module):
    def __init__(self):
        super(Circuit, self).__init__()
        self.layer = nn.GELU()  # approximate='none' (exact GELU) in our case

    def forward(self, x):
        return self.layer(x)

def main():
    torch_model = Circuit()
    # Input to the model
    shape = [3]
    x = torch.rand(1, *shape, requires_grad=True)
    torch_out = torch_model(x)
    # Export the model
    torch.onnx.export(torch_model,               # model being run
                      x,                         # model input (or a tuple for multiple inputs)
                      "network.onnx",            # where to save the model (can be a file or file-like object)
                      export_params=True,        # store the trained parameter weights inside the model file
                      opset_version=10,          # the ONNX version to export the model to
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names=['input'],     # the model's input names
                      output_names=['output'],   # the model's output names
                      dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                    'output': {0: 'batch_size'}})

    d = x.detach().numpy().reshape([-1]).tolist()

    data = dict(input_shapes=[shape],
                input_data=[d],
                output_data=[o.detach().numpy().reshape([-1]).tolist() for o in torch_out])

    # Serialize data into file:
    json.dump(data, open("input.json", 'w'))

if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions examples/onnx/1l_gelu_noappx/input.json
@@ -0,0 +1 @@
{"input_data":[[0.61017877,0.21496391,0.8960367]],"input_shapes":[[3]],"output_data":[[0.44274902,0.12817383,0.72998047]]}
Binary file added examples/onnx/1l_gelu_noappx/network.onnx
Binary file not shown.
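For reference, the exact (non-approximated) GELU is defined through erf, which is what ties the new Erf lookup to this GELU example. A minimal sketch of the identity, standard math rather than ezkl code:

import math
import torch

def gelu_exact(x: float) -> float:
    # GELU(x) = x * Phi(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

# Agrees with PyTorch's default (approximate='none') GELU:
x = 0.61017877  # first input value from input.json above
print(gelu_exact(x), torch.nn.functional.gelu(torch.tensor(x)).item())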
41 changes: 41 additions & 0 deletions examples/onnx/1l_gelu_tanh_appx/gen.py
@@ -0,0 +1,41 @@
import json
import torch
from torch import nn

class Circuit(nn.Module):
    def __init__(self):
        super(Circuit, self).__init__()
        self.layer = nn.GELU('tanh')  # approximate='tanh' in our case

    def forward(self, x):
        return self.layer(x)

def main():
    torch_model = Circuit()
    # Input to the model
    shape = [3]
    x = torch.rand(1, *shape, requires_grad=True)
    torch_out = torch_model(x)
    # Export the model
    torch.onnx.export(torch_model,               # model being run
                      x,                         # model input (or a tuple for multiple inputs)
                      "network.onnx",            # where to save the model (can be a file or file-like object)
                      export_params=True,        # store the trained parameter weights inside the model file
                      opset_version=10,          # the ONNX version to export the model to
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names=['input'],     # the model's input names
                      output_names=['output'],   # the model's output names
                      dynamic_axes={'input': {0: 'batch_size'},    # variable length axes
                                    'output': {0: 'batch_size'}})

    d = x.detach().numpy().reshape([-1]).tolist()

    data = dict(input_shapes=[shape],
                input_data=[d],
                output_data=[o.detach().numpy().reshape([-1]).tolist() for o in torch_out])

    # Serialize data into file:
    json.dump(data, open("input.json", 'w'))

if __name__ == "__main__":
    main()
1 change: 1 addition & 0 deletions examples/onnx/1l_gelu_tanh_appx/input.json
@@ -0,0 +1 @@
{"input_data":[[0.85212487,0.0874908,0.5229686]],"input_shapes":[[3]],"output_data":[[0.6819153,0.045654297,0.3659973]]}
Binary file added examples/onnx/1l_gelu_tanh_appx/network.onnx
Binary file not shown.
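The 'tanh' variant uses PyTorch's tanh approximation of GELU instead of the exact erf form. For reference, a sketch of that formula, standard math rather than ezkl code:

import math

def gelu_tanh(x: float) -> float:
    # nn.GELU('tanh') computes:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x ** 3)))

print(gelu_tanh(0.85212487))  # first input value from input.json above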
29 changes: 25 additions & 4 deletions src/circuit/mod.rs
@@ -197,6 +197,9 @@ pub enum LookupOp {
     Tanh {
         scales: (usize, usize),
     },
+    Erf {
+        scales: (usize, usize),
+    },
 }

impl LookupOp {
@@ -227,6 +230,9 @@ impl LookupOp {
             LookupOp::Tanh { scales } => {
                 Ok(tensor::ops::nonlinearities::tanh(&x, scales.0, scales.1))
             }
+            LookupOp::Erf { scales } => {
+                Ok(tensor::ops::nonlinearities::erffunc(&x, scales.0, scales.1))
+            }
         }
     }

@@ -239,6 +245,7 @@ impl LookupOp {
             LookupOp::Sigmoid { .. } => "SIGMOID",
             LookupOp::Sqrt { .. } => "SQRT",
             LookupOp::Tanh { .. } => "TANH",
+            LookupOp::Erf { .. } => "ERF",
         }
     }

@@ -549,6 +556,7 @@ impl OpKind {
"Sigmoid" => OpKind::Lookup(LookupOp::Sigmoid { scales: (1, 1) }),
"Sqrt" => OpKind::Lookup(LookupOp::Sqrt { scales: (1, 1) }),
"Tanh" => OpKind::Lookup(LookupOp::Tanh { scales: (1, 1) }),
"onnx.Erf" => OpKind::Lookup(LookupOp::Erf { scales: (1, 1) }),
"Div" => OpKind::Lookup(LookupOp::Div {
denom: utils::F32(1.0),
}),
@@ -586,22 +594,35 @@ impl OpKind {
             }
         }
     }
-    /// Identify fused OpKind
+    /// is poly type constraint
     pub fn is_poly(&self) -> bool {
         matches!(self, OpKind::Poly(_))
     }

-    /// Identify fused OpKind
+    /// is lookup based op
    pub fn is_lookup(&self) -> bool {
        matches!(self, OpKind::Lookup(_))
    }

-    /// Identify fused OpKind
+    /// is parameterized op (affine or conv)
     pub fn is_parameterized(&self) -> bool {
         match self {
             OpKind::Poly(Op::Affine) | OpKind::Poly(Op::Conv { .. }) => true,
             _ => false,
         }
     }

+    /// is rescaled op
+    pub fn is_rescaled(&self) -> bool {
+        matches!(self, OpKind::Poly(Op::Rescaled { .. }))
+    }
+
+    /// is input
+    pub fn is_input(&self) -> bool {
+        matches!(self, OpKind::Input)
+    }
+
-    /// Identify constant OpKind
+    /// is const
     pub fn is_const(&self) -> bool {
         matches!(self, OpKind::Const)
     }
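erffunc is the new entry in tensor::ops::nonlinearities; the puruspe crate added to Cargo.toml above is a pure-Rust special-function library whose erf implementation is presumably what it draws on. As a rough sketch of what such a fixed-point lookup nonlinearity computes, in Python pseudocode (the actual Rust implementation may differ):

import math

def erffunc(xs, scale_input, scale_output):
    # Dequantize each fixed-point input, apply erf, then requantize:
    # the lookup table stores round(scale_output * erf(x / scale_input))
    # for every representable input x.
    return [round(scale_output * math.erf(x / scale_input)) for x in xs]

print(erffunc([-128, 0, 64, 127], 128, 128))  # [-108, 0, 67, 107]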
5 changes: 5 additions & 0 deletions src/circuit/table.rs
@@ -52,8 +52,12 @@ impl<F: FieldExt> Table<F> {
         let base = 2i128;
         let smallest = -base.pow(self.bits as u32 - 1);
         let largest = base.pow(self.bits as u32 - 1);
+        // let smallest = -base.pow(3);
+        // let largest = base.pow(3);
         let inputs = Tensor::from(smallest..largest);
+        // println!("Are we here Tuesday input {:?}", inputs);
         let evals = self.nonlinearity.f(inputs.clone())?;
+        // println!("Tuesday If we are here then evals {:?}", evals);

         self.is_assigned = true;
         layouter
@@ -77,6 +81,7 @@ impl<F: FieldExt> Table<F> {
                         row_offset,
                         || Value::known(i128_to_felt::<F>(evals[row_offset])),
                     )?;
+                    // println!("All good here inside assign table, Tuesday");
                     Ok(())
                 })
                 .collect::<Result<Vec<()>, halo2_proofs::plonk::Error>>()?;
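The table assigns one row per representable input, covering the full signed fixed-point range determined by bits. A small sketch of that domain calculation, in Python mirroring the Rust above:

def table_domain(bits: int) -> range:
    # With `bits` bits, lookup inputs cover [-2^(bits-1), 2^(bits-1)),
    # one table row per value.
    smallest = -(2 ** (bits - 1))
    largest = 2 ** (bits - 1)
    return range(smallest, largest)

print(len(table_domain(8)))  # 256 rows for an 8-bit lookup table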
(The remaining changed files are not shown.)