Skip to content

Commit

Permalink
Added support for Plonk Protocol per circuit
Browse files Browse the repository at this point in the history
  • Loading branch information
KyrylR committed Oct 3, 2024
1 parent 08c4b88 commit 902925f
Show file tree
Hide file tree
Showing 18 changed files with 88 additions and 14 deletions.
12 changes: 6 additions & 6 deletions package-lock.json

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions package.json
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
{
"name": "@solarity/zktype",
"version": "0.3.1",
"version": "0.4.0",
"description": "Unleash TypeScript bindings for Circom circuits",
"main": "dist/index.js",
"types": "dist/index.d.ts",
Expand Down Expand Up @@ -49,7 +49,7 @@
"typescript": "5.5.4"
},
"peerDependencies": {
"@solarity/zkit": "^0.2.4"
"@solarity/zkit": "^0.3.0-rc.0"
},
"devDependencies": {
"@types/chai": "^4.3.12",
Expand Down
4 changes: 4 additions & 0 deletions src/constants/protocol.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
export const Groth16CalldataPointsType =
"[NumericString, NumericString], [[NumericString, NumericString], [NumericString, NumericString]], [NumericString, NumericString]";

export const PlonkCalldataPointsType = "NumericString[]";
46 changes: 46 additions & 0 deletions src/core/ZkitTSGenerator.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import {

import { normalizeName } from "../utils";
import { SignalTypeNames, SignalVisibilityNames } from "../constants";
import { Groth16CalldataPointsType, PlonkCalldataPointsType } from "../constants/protocol";

export default class ZkitTSGenerator extends BaseTSGenerator {
protected async _genHardhatZkitTypeExtension(circuits: {
Expand Down Expand Up @@ -70,6 +71,8 @@ export default class ZkitTSGenerator extends BaseTSGenerator {
circuitArtifact: CircuitArtifact,
pathToGeneratedFile: string,
): Promise<string> {
this._validateCircuitArtifact(circuitArtifact);

const template = fs.readFileSync(path.join(__dirname, "templates", "circuit-wrapper.ts.ejs"), "utf8");

let outputCounter: number = 0;
Expand Down Expand Up @@ -114,11 +117,15 @@ export default class ZkitTSGenerator extends BaseTSGenerator {

const pathToUtils = path.join(this.getOutputTypesDir(), "utils");
const templateParams: WrapperTemplateParams = {
protocolTypeName: circuitArtifact.baseCircuitInfo.protocol,
protocolImplementerName: this._getProtocolImplementerName(circuitArtifact),
proofTypeInternalName: this._getProofTypeInternalName(circuitArtifact),
circuitClassName: this._getCircuitName(circuitArtifact),
publicInputsTypeName: this._getTypeName(circuitArtifact, "Public"),
calldataPubSignalsType: this._getCalldataPubSignalsType(calldataPubSignalsCount),
publicInputs,
privateInputs,
calldataPointsType: this._getCalldataPointsType(circuitArtifact),
proofTypeName: this._getTypeName(circuitArtifact, "Proof"),
privateInputsTypeName: this._getTypeName(circuitArtifact, "Private"),
pathToUtils: path.relative(path.dirname(pathToGeneratedFile), pathToUtils),
Expand Down Expand Up @@ -150,4 +157,43 @@ export default class ZkitTSGenerator extends BaseTSGenerator {

return signal.dimension.reduce((acc: number, dim: string) => acc * Number(dim), 1);
}

private _getProtocolImplementerName(circuitArtifact: CircuitArtifact): any {
switch (circuitArtifact.baseCircuitInfo.protocol) {
case "groth16":
return "Groth16Implementer";
case "plonk":
return "PlonkImplementer";
default:
throw new Error(`Unknown protocol: ${circuitArtifact.baseCircuitInfo.protocol}`);
}
}

private _getProofTypeInternalName(circuitArtifact: CircuitArtifact): any {
switch (circuitArtifact.baseCircuitInfo.protocol) {
case "groth16":
return "Groth16Proof";
case "plonk":
return "PlonkProof";
default:
throw new Error(`Unknown protocol: ${circuitArtifact.baseCircuitInfo.protocol}`);
}
}

private _getCalldataPointsType(circuitArtifact: CircuitArtifact): any {
switch (circuitArtifact.baseCircuitInfo.protocol) {
case "groth16":
return Groth16CalldataPointsType;
case "plonk":
return PlonkCalldataPointsType;
default:
throw new Error(`Unknown protocol: ${circuitArtifact.baseCircuitInfo.protocol}`);
}
}

private _validateCircuitArtifact(circuitArtifact: CircuitArtifact): void {
if (!circuitArtifact.baseCircuitInfo.protocol) {
throw new Error(`ZKType: Protocol is missing in the circuit artifact: ${circuitArtifact.circuitTemplateName}`);
}
}
}
13 changes: 7 additions & 6 deletions src/core/templates/circuit-wrapper.ts.ejs
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,12 @@ import {
CircuitZKit,
CircuitZKitConfig,
Groth16Proof,
PlonkProof,
NumberLike,
NumericString,
PublicSignals,
Groth16Implementer,
PlonkImplementer,
} from "@solarity/zkit";

import { normalizePublicSignals, denormalizePublicSignals } from "<%= pathToUtils %>";
Expand All @@ -22,20 +25,18 @@ export type <%= publicInputsTypeName %> = {
}

export type <%= proofTypeName %> = {
proof: Groth16Proof;
proof: <%= proofTypeInternalName %>;
publicSignals: <%= publicInputsTypeName %>;
}

export type Calldata = [
[NumericString, NumericString],
[[NumericString, NumericString], [NumericString, NumericString]],
[NumericString, NumericString],
<%= calldataPointsType %>,
<%= calldataPubSignalsType %>,
];

export class <%= circuitClassName %> extends CircuitZKit {
export class <%= circuitClassName %> extends CircuitZKit<"<%= protocolTypeName %>"> {
constructor(config: CircuitZKitConfig) {
super(config);
super(config, new <%= protocolImplementerName %>());
}

public async generateProof(inputs: <%= privateInputsTypeName %>): Promise<<%= proofTypeName %>> {
Expand Down
2 changes: 2 additions & 0 deletions src/types/circuitArtifact.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,12 @@ export type CircuitArtifact = {
/**
* Represents the base circuit information.
*
* @param {string} protocol - The proving system protocol used in the circuit.
* @param {number} constraintsNumber - The number of constraints in the circuit.
* @param {SignalInfo[]} signals - The array of `input` and `output` signals used in the circuit.
*/
export type BaseCircuitInfo = {
protocol: "groth16" | "plonk";
constraintsNumber: number;
signals: SignalInfo[];
};
Expand Down
5 changes: 5 additions & 0 deletions src/types/typesGenerator.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import { CircuitArtifact } from "./circuitArtifact";
import { Groth16CalldataPointsType, PlonkCalldataPointsType } from "../constants/protocol";

export interface ArtifactWithPath {
circuitArtifact: CircuitArtifact;
Expand All @@ -16,9 +17,13 @@ export interface DefaultWrapperTemplateParams {
}

export interface WrapperTemplateParams {
protocolTypeName: "groth16" | "plonk";
protocolImplementerName: "Groth16Implementer" | "PlonkImplementer";
proofTypeInternalName: "Groth16Proof" | "PlonkProof";
publicInputsTypeName: string;
privateInputs: Inputs[];
publicInputs: Inputs[];
calldataPointsType: typeof Groth16CalldataPointsType | typeof PlonkCalldataPointsType;
calldataPubSignalsType: string;
proofTypeName: string;
privateInputsTypeName: string;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"circuitSourceName": "circuits/fixture/credentialAtomicQueryMTPV2OnChainVoting.circom",
"baseCircuitInfo": {
"constraintsNumber": 86791,
"protocol": "groth16",
"signals": [
{
"name": "merklized",
Expand Down
1 change: 1 addition & 0 deletions test/fixture-cache/Multiplier2_artifacts.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"circuitSourceName": "circuits/fixture/Basic.circom",
"baseCircuitInfo": {
"constraintsNumber": 1,
"protocol": "groth16",
"signals": [
{
"name": "in1",
Expand Down
1 change: 1 addition & 0 deletions test/fixture-cache/auth/EnhancedMultiplier_artifacts.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"circuitSourceName": "circuits/fixture/auth/EMultiplier.circom",
"baseCircuitInfo": {
"constraintsNumber": 1,
"protocol": "groth16",
"signals": [
{
"name": "in1",
Expand Down
1 change: 1 addition & 0 deletions test/fixture-cache/auth/Matrix_artifacts.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"circuitSourceName": "circuits/fixture/auth/Matrix.circom",
"baseCircuitInfo": {
"constraintsNumber": 8,
"protocol": "groth16",
"signals": [
{
"name": "a",
Expand Down
1 change: 1 addition & 0 deletions test/fixture-cache/auth/Multiplier2_artifacts.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"circuitSourceName": "circuits/fixture/auth/BasicInAuth.circom",
"baseCircuitInfo": {
"constraintsNumber": 1,
"protocol": "groth16",
"signals": [
{
"name": "in1",
Expand Down
1 change: 1 addition & 0 deletions test/fixture-cache/lib/Multiplier2_artifacts.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"circuitSourceName": "circuits/fixture/lib/BasicInLib.circom",
"baseCircuitInfo": {
"constraintsNumber": 1,
"protocol": "groth16",
"signals": [
{
"name": "in1",
Expand Down
10 changes: 10 additions & 0 deletions test/helpers/index.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
import { CircuitTypesGenerator } from "../../src";
import { findProjectRoot } from "../../src/utils";

const circuitTypesGenerator = new CircuitTypesGenerator({
basePath: "test/fixture",
projectRoot: findProjectRoot(process.cwd()),
circuitsArtifactsPaths: ["test/fixture-cache/Multiplier2_artifacts.json"],
});

// circuitTypesGenerator.generateTypes().then(console.log).catch(console.error);

0 comments on commit 902925f

Please sign in to comment.