Skip to content

Commit

Permalink
Reduced code duplication
Browse files Browse the repository at this point in the history
  • Loading branch information
KyrylR committed Aug 3, 2024
1 parent ca97f49 commit 54cd525
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 18 deletions.
19 changes: 3 additions & 16 deletions src/core/templates/circuit-wrapper.ts.ejs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
PublicSignals,
} from "@solarity/zkit";

import { flatten, reshape } from "<%= pathToUtils %>";
import { normalizePublicSignals, denormalizePublicSignals } from "<%= pathToUtils %>";

export type <%= privateInputsTypeName %> = {
<% for (let i = 0; i < privateInputs.length; i++) { -%>
Expand Down Expand Up @@ -83,24 +83,11 @@ export class <%= circuitClassName %> extends CircuitZKit {
}

private _normalizePublicSignals(publicSignals: PublicSignals): <%= publicInputsTypeName %> {
const signalNames = this.getSignalNames();

let index = 0;
return signalNames.reduce((acc: any, signalName) => {
const dimensions = this.getSignalDimensions(signalName);
const size = dimensions.reduce((a, b) => a * b, 1);
acc[signalName] = reshape(publicSignals.slice(index, index + size), dimensions);
index += size;
return acc;
}, {});
return normalizePublicSignals(publicSignals, this.getSignalNames(), this.getSignalDimensions);
}

private _denormalizePublicSignals(publicSignals: <%= publicInputsTypeName %>): PublicSignals {
const signalNames = this.getSignalNames();

return signalNames.reduce((acc: any[], signalName) => {
return acc.concat(flatten(publicSignals[signalName]));
}, []);
return denormalizePublicSignals(publicSignals, this.getSignalNames());
}
}

Expand Down
29 changes: 27 additions & 2 deletions src/core/templates/utils.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,29 @@
export function reshape(array: number[], dimensions: number[]): any {
import { PublicSignals } from "@solarity/zkit";

export function normalizePublicSignals(
publicSignals: any[],
signalNames: string[],
getSignalDimensions: (name: string) => number[],
): any {
let index = 0;
return signalNames.reduce((acc: any, signalName) => {
const dimensions = getSignalDimensions(signalName);
const size = dimensions.reduce((a, b) => a * b, 1);

acc[signalName] = reshape(publicSignals.slice(index, index + size), dimensions);
index += size;

return acc;
}, {});
}

export function denormalizePublicSignals(publicSignals: any, signalNames: string[]): PublicSignals {
return signalNames.reduce((acc: any[], signalName) => {
return acc.concat(flatten(publicSignals[signalName]));
}, []);
}

function reshape(array: number[], dimensions: number[]): any {
if (dimensions.length === 0) {
return array[0];
}
Expand All @@ -14,6 +39,6 @@ export function reshape(array: number[], dimensions: number[]): any {
return result;
}

export function flatten(array: any): number[] {
function flatten(array: any): number[] {
return Array.isArray(array) ? array.flatMap((array) => flatten(array)) : array;
}

0 comments on commit 54cd525

Please sign in to comment.