Skip to content

Commit

Permalink
Merge pull request #4 from decryptofy/model-refactor
Browse files Browse the repository at this point in the history
Model refactor
  • Loading branch information
decryptofy authored Mar 4, 2024
2 parents 57fb2db + 0c7dea0 commit d083c50
Show file tree
Hide file tree
Showing 24 changed files with 343 additions and 242 deletions.
34 changes: 17 additions & 17 deletions devtools/data_creation/correlation_data.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
import numpy as np
from src.scarr.models.utils import AES_SBOX
from src.scarr.model_values.utils import AES_SBOX

class CorrelationData:

def __init__(self, num_traces, sample_length, bytes=[0]) -> None:
def __init__(self, num_traces, sample_length, model_pos=[0]) -> None:
self.num_traces = num_traces
self.sample_length = sample_length
self.batch_size = 5000
self.tiles = [(0,0)]
self.bytes = bytes
self.model_positions = model_pos
self.key = None
self.plaintext = None
self.traces = None
Expand All @@ -19,22 +19,22 @@ def generate_data(self):
l = self.sample_length # number of points per trace

self.key = np.random.randint(0, 256, (16)) # just one key = 1x16
self.plaintext = np.random.randint(0, 256, (N, 16)) # 5000x16 plaintext bytes
self.plaintext = np.random.randint(0, 256, (N, 16)) # 5000x16 plaintext positions

self.traces = np.zeros((N, l), dtype=np.int64)

# Generate random HW traces
self.traces = np.random.randint(-128, +128, (N, l), dtype=np.int64)

# Put leakage where it is needed
for byte in range(16):
leak_plaintext = self.plaintext[:,byte]
leak_sbox_out = AES_SBOX[self.plaintext[:, byte] ^ self.key[byte]]
self.traces[:,4+byte] = np.subtract(leak_plaintext,128, dtype=np.int16)
self.traces[:,24+byte] = np.subtract(leak_sbox_out,128, dtype=np.int16)
for model_pos in range(16):
leak_plaintext = self.plaintext[:,model_pos]
leak_sbox_out = AES_SBOX[self.plaintext[:, model_pos] ^ self.key[model_pos]]
self.traces[:,4+model_pos] = np.subtract(leak_plaintext,128, dtype=np.int16)
self.traces[:,24+model_pos] = np.subtract(leak_sbox_out,128, dtype=np.int16)

def configure(self, tile_x, tile_y, bytes, convergence_step=None):
self.bytes = bytes
def configure(self, tile_x, tile_y, model_positions, convergence_step=None):
self.model_positions = model_positions
self.slices = []
batch_start_index = 0
while batch_start_index < self.num_traces:
Expand All @@ -53,17 +53,17 @@ def get_traces(self):
def get_key(self):
return self.key

def get_byte_batch(self, slice, byte):
def get_byte_batch(self, slice, model_pos):

return [self.plaintext[slice, [byte]], self.key[[byte]], self.traces[slice,:]]
return [self.plaintext[slice, [model_pos]], self.key[[model_pos]], self.traces[slice,:]]

def get_batches_by_byte(self, tile_x, tile_y, byte):
def get_batches_by_byte(self, tile_x, tile_y, model_pos):
for slice in self.slices:
yield self.get_byte_batch(slice, byte)
yield self.get_byte_batch(slice, model_pos)

def get_batch(self, slice):

return [self.plaintext[slice,self.bytes], self.key[self.bytes], self.traces[slice,:]]
return [self.plaintext[slice,self.model_positions], self.key[self.model_positions], self.traces[slice,:]]

def get_batches_all(self, tile_x, tile_y):
for slice in self.slices:
Expand All @@ -74,4 +74,4 @@ def get_batch_index(self, index):
if index >= len(self.slices):
return []

return [self.plaintext[self.slices[index], self.bytes], self.key[self.bytes], self.traces[self.slices[index], :]]
return [self.plaintext[self.slices[index], self.model_positions], self.key[self.model_positions], self.traces[self.slices[index], :]]
12 changes: 6 additions & 6 deletions src/scarr/container/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,14 @@ class Container:
- byte_positions: The byte positions to be processed by the algorithm has a default value of [0]
- tile_positions: The tile positions that are to have the byte positions processed default value of [(0,0)]
"""
def __init__(self, options: ContainerOptions, Async = True, byte_positions = [0], tile_positions = [(0,0)], filters = [], points=[], trace_index=[], slice=[], stride=1) -> None:
def __init__(self, options: ContainerOptions, Async = True, model_positions = [0], tile_positions = [(0,0)], filters = [], points=[], trace_index=[], slice=[], stride=1) -> None:
self.engine = options.engine
self.data = options.handler
self.data2 = options.handler2 # second trace (only t-test)

self.fetch_async = Async

self.bytes = byte_positions
self.model_positions = model_positions
self.tiles = tile_positions

self.filters = filters
Expand Down Expand Up @@ -82,17 +82,17 @@ def __init__(self, options: ContainerOptions, Async = True, byte_positions = [0]
def run(self):
self.engine.run(self)

def configure(self, tile_x, tile_y, bytes, convergence_step = None):
def configure(self, tile_x, tile_y, model_positions, convergence_step = None):
for filter in self.filters:
filter.configure(tile_x, tile_y)
# int() casting needed for random typing linux bug
return int(self.data.configure(tile_x, tile_y, bytes, self.slice_index, self.trace_index, self.time_slice, self.stride, convergence_step))
return int(self.data.configure(tile_x, tile_y, model_positions, self.slice_index, self.trace_index, self.time_slice, self.stride, convergence_step))

def configure2(self, tile_x, tile_y, bytes, convergence_step = None):
def configure2(self, tile_x, tile_y, model_positions, convergence_step = None):
for filter in self.filters:
filter.configure(tile_x, tile_y)
# int() casting needed for random typing linux bug
return int(self.data2.configure(tile_x, tile_y, bytes, self.slice_index, self.trace_index, self.time_slice, self.stride, convergence_step))
return int(self.data2.configure(tile_x, tile_y, model_positions, self.slice_index, self.trace_index, self.time_slice, self.stride, convergence_step))

def get_batches(self, tile_x, tile_y):
for batch in self.data.get_batch_generator():
Expand Down
33 changes: 16 additions & 17 deletions src/scarr/engines/NICV.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,27 @@
# This Source Code Form is "Incompatible With Secondary Licenses", as
# defined by the Mozilla Public License, v. 2.0.

from .engine import Engine
from ..model_values.plaintext import PlainText
import numpy as np
import numba as nb
from .engine import Engine


class NICV(Engine):

def __init__(self) -> None:
# Creating all of the necessary information containers to compute the SNR
def __init__(self, model_value=PlainText()) -> None:
self.trace_counts = None
self.means = None
self.moments = None
self.results = None

super().__init__(model_value)

def update(self, traces: np.ndarray, plaintext: np.ndarray):
self.internal_state_update(traces, plaintext, self.trace_counts, self.sum, self.sum_sq)
def update(self, traces: np.ndarray, data: np.ndarray):
self.internal_state_update(traces, data, self.trace_counts, self.sum, self.sum_sq)

async def async_update(self, traces: np.ndarray, plaintext: np.ndarray):
self.internal_state_update(traces, plaintext, self.trace_counts, self.sum, self.sum_sq)
async def async_update(self, traces: np.ndarray, data: np.ndarray):
self.internal_state_update(traces, data, self.trace_counts, self.sum, self.sum_sq)

def calculate(self):

Expand All @@ -42,18 +44,15 @@ def calculate(self):

@staticmethod
@nb.njit(parallel=True, fastmath=True)
def internal_state_update(traces: np.ndarray, plaintext: np.ndarray, counts, sums, sums_sq):
def internal_state_update(traces: np.ndarray, data: np.ndarray, counts, sums, sums_sq):
for sample in nb.prange(traces.shape[1]):
for trace in range(traces.shape[0]):
if sample == 0:
counts[plaintext[trace]] += 1
sums[plaintext[trace], sample] += traces[trace, sample]
sums_sq[plaintext[trace], sample] += np.square(traces[trace, sample])
counts[data[trace]] += 1
sums[data[trace], sample] += traces[trace, sample]
sums_sq[data[trace], sample] += np.square(traces[trace, sample])

def populate(self, sample_length):
# Count for each plaintext value
self.trace_counts = np.zeros((256), dtype=np.uint16)
# Mean value for each hex value and each sample point
self.sum = np.zeros((256, sample_length), dtype=np.float32)
# Moment value for each hex value and each sample point
self.sum_sq = np.zeros((256, sample_length), dtype=np.float32)
self.trace_counts = np.zeros((self.model_value.num_vals), dtype=np.uint16)
self.sum = np.zeros((self.model_value.num_vals, sample_length), dtype=np.float32)
self.sum_sq = np.zeros((self.model_value.num_vals, sample_length), dtype=np.float32)
38 changes: 20 additions & 18 deletions src/scarr/engines/cpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,17 +5,17 @@
# This Source Code Form is "Incompatible With Secondary Licenses", as
# defined by the Mozilla Public License, v. 2.0.

import numpy as np

from .engine import Engine
from ..model_values.model_value import ModelValue
from multiprocessing.pool import Pool
import numpy as np
import asyncio


class CPA(Engine):

def __init__(self, model, convergence_step=None) -> None:
self.model = model

def __init__(self, model_value: ModelValue, convergence_step=None) -> None:
self.trace_count = 0
self.model_sum = None
self.model_sq_sum = None
Expand All @@ -29,10 +29,12 @@ def __init__(self, model, convergence_step=None) -> None:
self.candidate = None
self.results = None

super().__init__(model_value)

def run(self, container):
self.final_results = None
self.final_candidates = None
self.populate(container.sample_length, len(container.bytes))
self.populate(container.sample_length, len(container.model_positions))

with Pool() as pool:
workload = []
Expand All @@ -47,12 +49,12 @@ def run(self, container):
for tile_x, tile_y, results, candidates in starmap_results:
if self.final_results is None:
self.final_results = np.zeros((len(container.tiles),
len(container.bytes),
len(container.model_positions),
results.shape[1],
256,
container.sample_length), dtype=np.float64)
self.final_candidates = np.zeros((len(container.tiles),
len(container.bytes),
len(container.model_positions),
results.shape[1]), dtype=np.uint8)

tile_index = container.tiles.index((tile_x, tile_y))
Expand All @@ -61,12 +63,12 @@ def run(self, container):

@staticmethod
def run_workload(self, container, tile_x, tile_y):
num_steps = container.configure(tile_x, tile_y, container.bytes, self.convergence_step)
num_steps = container.configure(tile_x, tile_y, container.model_positions, self.convergence_step)
if self.convergence_step is None:
self.convergence_step = np.inf

self.results = np.empty((len(container.bytes), num_steps, 256, container.sample_length), dtype=np.float64)
self.candidates = np.empty((len(container.bytes), num_steps), dtype=np.uint8)
self.results = np.empty((len(container.model_positions), num_steps, 256, container.sample_length), dtype=np.float64)
self.candidates = np.empty((len(container.model_positions), num_steps), dtype=np.uint8)

if container.fetch_async:
asyncio.run(self.batch_loop(container))
Expand All @@ -84,10 +86,10 @@ def run_workload(self, container, tile_x, tile_y):
converge_index += 1

# Generate modeled power values for plaintext values
model = np.apply_along_axis(self.model.calculate_table, axis=1, arr=batch[0])
data = self.model_value.calculate_all_tables(batch[:-1])
traces = batch[-1].astype(np.float32)

self.update(traces, model)
self.update(traces, data)
traces_processed += traces.shape[0]

result = self.calculate()
Expand All @@ -113,10 +115,10 @@ async def batch_loop(self, container):
converge_index += 1

# Generate modeled power values for plaintext values
model = np.apply_along_axis(self.model.calculate_table, axis=1, arr=batch[0])
data = self.model_value.calculate_all_tables(batch[:-1])
traces = batch[-1].astype(np.float32)

task = asyncio.create_task(self.async_update(traces, model))
task = asyncio.create_task(self.async_update(traces, data))
traces_processed += traces.shape[0]

batch = container.get_batch_index(index)
Expand Down Expand Up @@ -169,17 +171,17 @@ def calculate(self):
def get_candidate(self):
return self.final_candidates

def populate(self, sample_length, num_bytes):
def populate(self, sample_length, num_positions):
# Sum of the model so far
self.model_sum = np.zeros((256, num_bytes), dtype=np.float32)
self.model_sum = np.zeros((256, num_positions), dtype=np.float32)
# Sum of the model squared so far
self.model_sq_sum = np.zeros((256, num_bytes), dtype=np.float32)
self.model_sq_sum = np.zeros((256, num_positions), dtype=np.float32)
# Sum of the samples observed
self.sample_sum = np.zeros((sample_length), dtype=np.float32)
# Sum of the samples observed squared
self.sample_sq_sum = np.zeros((sample_length), dtype=np.float32)
# Sum of the product of the samples and the models
self.prod_sum = np.zeros((256 * num_bytes, sample_length), dtype=np.float32)
self.prod_sum = np.zeros((256 * num_positions, sample_length), dtype=np.float32)

def update(self, traces: np.ndarray, data: np.ndarray):
# Update the number of rows processed
Expand Down
40 changes: 21 additions & 19 deletions src/scarr/engines/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,71 +5,73 @@
# This Source Code Form is "Incompatible With Secondary Licenses", as
# defined by the Mozilla Public License, v. 2.0.

import numpy as np
from ..model_values.model_value import ModelValue
from multiprocessing.pool import Pool
import os
import numpy as np
import asyncio
import os


class Engine:
"""
Base class that engines inherit from.
"""
def __init__(self):
def __init__(self, model_value: ModelValue):
self.model_value = model_value
pass

def run(self, container):
final_results = np.zeros((len(container.tiles), len(container.bytes), container.sample_length), dtype=np.float32)
final_results = np.zeros((len(container.tiles), len(container.model_positions), container.sample_length), dtype=np.float32)
# with Pool(processes=int(os.cpu_count()/2),maxtasksperchild=1000) as pool: #used for benchmarking
with Pool(processes=int(os.cpu_count()/2)) as pool:
workload = []
for tile in container.tiles:
(tile_x, tile_y) = tile
for byte in container.bytes:
workload.append((self, container, tile_x, tile_y, byte))
for model_pos in container.model_positions:
workload.append((self, container, tile_x, tile_y, model_pos))
starmap_results = pool.starmap(self.run_workload, workload, chunksize=1) # Possibly more testing needed
pool.close()
pool.join()

for tile_x, tile_y, byte_pos, tmp_result in starmap_results:
for tile_x, tile_y, model_pos, tmp_result in starmap_results:
tile_index = list(container.tiles).index((tile_x, tile_y))
byte_index = list(container.bytes).index(byte_pos)
final_results[tile_index, byte_index] = tmp_result
model_pos_index = list(container.model_positions).index(model_pos)
final_results[tile_index, model_pos_index] = tmp_result

self.final_results = final_results

@staticmethod
def run_workload(self, container, tile_x, tile_y, byte):
def run_workload(self, container, tile_x, tile_y, model_pos):
self.populate(container.sample_length)
container.configure(tile_x, tile_y, [byte])
container.configure(tile_x, tile_y, [model_pos])
if container.fetch_async:
asyncio.run(self.batch_loop(container))
else:
for batch in container.get_batches(tile_x, tile_y, byte):
self.update(batch[-1], np.squeeze(batch[0]))
for batch in container.get_batches(tile_x, tile_y, model_pos):
self.update(batch[-1], self.model_value.calculate(batch[:-1]))

return tile_x, tile_y, byte, self.calculate()
return tile_x, tile_y, model_pos, self.calculate()

async def batch_loop(self, container):
index = 0
batch = container.get_batch_index(index)
index += 1

while len(batch) > 0:
task = asyncio.create_task(self.async_update(batch[-1], np.squeeze(batch[0])))
task = asyncio.create_task(self.async_update(batch[-1], self.model_value.calculate(batch[:-1])))
batch = container.get_batch_index(index)
index += 1
await task

def update(self, traces: np.ndarray, plaintext: np.ndarray):
def update(self, traces: np.ndarray, data: np.ndarray):
"""
Function that updates the statistics of the algorithm to be called by the container class.
Gets passed in an array of traces and an array of plaintext from the trace_handler class.
Returns None.
"""
pass

async def async_update(self, traces: np.ndarray, plaintext: np.ndarray):
async def async_update(self, traces: np.ndarray, data: np.ndarray):
"""
Function that updates the statistics of the algorithm to be called by the container class.
Gets passed in an array of traces and an array of plaintext from the trace_handler class.
Expand All @@ -91,5 +93,5 @@ def populate(self, sample_length):
"""
pass

def get_points(self, lower_lim, tile_index=0, byte_index=0,):
return list(np.where(np.abs(self.final_results[tile_index, byte_index]) >= lower_lim)[0])
def get_points(self, lower_lim, tile_index=0, model_pos_index=0,):
return list(np.where(np.abs(self.final_results[tile_index, model_pos_index]) >= lower_lim)[0])
Loading

0 comments on commit d083c50

Please sign in to comment.