diff --git a/src/core/templates/circuit-wrapper.ts.ejs b/src/core/templates/circuit-wrapper.ts.ejs index 5f0b55e..f2a372f 100644 --- a/src/core/templates/circuit-wrapper.ts.ejs +++ b/src/core/templates/circuit-wrapper.ts.ejs @@ -7,7 +7,7 @@ import { PublicSignals, } from "@solarity/zkit"; -import { flatten, reshape } from "<%= pathToUtils %>"; +import { normalizePublicSignals, denormalizePublicSignals } from "<%= pathToUtils %>"; export type <%= privateInputsTypeName %> = { <% for (let i = 0; i < privateInputs.length; i++) { -%> @@ -83,24 +83,11 @@ export class <%= circuitClassName %> extends CircuitZKit { } private _normalizePublicSignals(publicSignals: PublicSignals): <%= publicInputsTypeName %> { - const signalNames = this.getSignalNames(); - - let index = 0; - return signalNames.reduce((acc: any, signalName) => { - const dimensions = this.getSignalDimensions(signalName); - const size = dimensions.reduce((a, b) => a * b, 1); - acc[signalName] = reshape(publicSignals.slice(index, index + size), dimensions); - index += size; - return acc; - }, {}); + return normalizePublicSignals(publicSignals, this.getSignalNames(), this.getSignalDimensions); } private _denormalizePublicSignals(publicSignals: <%= publicInputsTypeName %>): PublicSignals { - const signalNames = this.getSignalNames(); - - return signalNames.reduce((acc: any[], signalName) => { - return acc.concat(flatten(publicSignals[signalName])); - }, []); + return denormalizePublicSignals(publicSignals, this.getSignalNames()); } } diff --git a/src/core/templates/utils.ts b/src/core/templates/utils.ts index d3d6807..edd1fdb 100644 --- a/src/core/templates/utils.ts +++ b/src/core/templates/utils.ts @@ -1,4 +1,29 @@ -export function reshape(array: number[], dimensions: number[]): any { +import { PublicSignals } from "@solarity/zkit"; + +export function normalizePublicSignals( + publicSignals: any[], + signalNames: string[], + getSignalDimensions: (name: string) => number[], +): any { + let index = 0; + return signalNames.reduce((acc: any, signalName) => { + const dimensions = getSignalDimensions(signalName); + const size = dimensions.reduce((a, b) => a * b, 1); + + acc[signalName] = reshape(publicSignals.slice(index, index + size), dimensions); + index += size; + + return acc; + }, {}); +} + +export function denormalizePublicSignals(publicSignals: any, signalNames: string[]): PublicSignals { + return signalNames.reduce((acc: any[], signalName) => { + return acc.concat(flatten(publicSignals[signalName])); + }, []); +} + +function reshape(array: number[], dimensions: number[]): any { if (dimensions.length === 0) { return array[0]; } @@ -14,6 +39,6 @@ export function reshape(array: number[], dimensions: number[]): any { return result; } -export function flatten(array: any): number[] { +function flatten(array: any): number[] { return Array.isArray(array) ? array.flatMap((array) => flatten(array)) : array; }