From c529e825545dda9e772b9e88e6654bd9c7d21648 Mon Sep 17 00:00:00 2001 From: Gustav Baumgart Date: Mon, 3 Apr 2023 13:53:29 -0700 Subject: [PATCH] updated feddyn implementation pytorch The new implementation matches the algorithm implemented in the paper. This is no longer is similar to the pseudocode they provide. The new implmentation includes an updated communication protocol where all trainers send the dataset size for that round to the aggregator at the beginning of every round. This allows the aggregator to compute adaptive hyperparameters for each trainer individually and then send ALPHA_ADPT along with the global model to begin training for the round. --- lib/python/flame/channel.py | 4 + lib/python/flame/common/constants.py | 10 +- .../flame/examples/medmnist_feddyn/README.md | 48 ++++ .../examples/medmnist_feddyn/__init__.py | 17 ++ .../medmnist_feddyn/aggregator/__init__.py | 17 ++ .../medmnist_feddyn/aggregator/main.py | 185 +++++++++++++++ .../medmnist_feddyn/aggregator/template.json | 81 +++++++ .../medmnist_feddyn/images/accuracy.png | Bin 0 -> 51388 bytes .../flame/examples/medmnist_feddyn/run.py | 87 +++++++ .../medmnist_feddyn/trainer/__init__.py | 17 ++ .../examples/medmnist_feddyn/trainer/main.py | 222 ++++++++++++++++++ .../medmnist_feddyn/trainer/sites.txt | 10 + .../medmnist_feddyn/trainer/template.json | 81 +++++++ .../mode/horizontal/coord_syncfl/trainer.py | 6 +- .../flame/mode/horizontal/feddyn/__init__.py | 16 ++ .../mode/horizontal/feddyn/top_aggregator.py | 154 ++++++++++++ .../flame/mode/horizontal/feddyn/trainer.py | 148 ++++++++++++ .../flame/mode/horizontal/syncfl/trainer.py | 7 +- lib/python/flame/mode/message.py | 3 + lib/python/flame/optimizer/feddyn.py | 85 ++++--- .../flame/optimizer/regularizer/default.py | 4 +- .../flame/optimizer/regularizer/feddyn.py | 25 +- 22 files changed, 1173 insertions(+), 54 deletions(-) create mode 100644 lib/python/flame/examples/medmnist_feddyn/README.md create mode 100644 lib/python/flame/examples/medmnist_feddyn/__init__.py create mode 100644 lib/python/flame/examples/medmnist_feddyn/aggregator/__init__.py create mode 100644 lib/python/flame/examples/medmnist_feddyn/aggregator/main.py create mode 100644 lib/python/flame/examples/medmnist_feddyn/aggregator/template.json create mode 100644 lib/python/flame/examples/medmnist_feddyn/images/accuracy.png create mode 100644 lib/python/flame/examples/medmnist_feddyn/run.py create mode 100644 lib/python/flame/examples/medmnist_feddyn/trainer/__init__.py create mode 100644 lib/python/flame/examples/medmnist_feddyn/trainer/main.py create mode 100644 lib/python/flame/examples/medmnist_feddyn/trainer/sites.txt create mode 100644 lib/python/flame/examples/medmnist_feddyn/trainer/template.json create mode 100644 lib/python/flame/mode/horizontal/feddyn/__init__.py create mode 100644 lib/python/flame/mode/horizontal/feddyn/top_aggregator.py create mode 100644 lib/python/flame/mode/horizontal/feddyn/trainer.py diff --git a/lib/python/flame/channel.py b/lib/python/flame/channel.py index 5b64a7c43..1be88d567 100644 --- a/lib/python/flame/channel.py +++ b/lib/python/flame/channel.py @@ -160,6 +160,10 @@ async def inner(): result, _ = run_async(inner(), self._backend.loop()) return result + + def all_ends(self): + """Return a list of all end ids.""" + return list(self._ends.keys()) def ends_digest(self) -> str: """Compute a digest of ends.""" diff --git a/lib/python/flame/common/constants.py b/lib/python/flame/common/constants.py index 12dcfbec8..8371dfeed 100644 --- a/lib/python/flame/common/constants.py +++ b/lib/python/flame/common/constants.py @@ -48,9 +48,9 @@ class DeviceType(Enum): CPU = 1 GPU = 2 -class TrainerState(Enum): - """Enum class for trainer state.""" +class TrainState(Enum): + """Enum class for train state.""" - PRE_TRAIN = 'pre_train' - DURING_TRAIN = 'during_train' - POST_TRAIN = 'post_train' + PRE = 'pre' + DURING = 'during' + POST = 'post' diff --git a/lib/python/flame/examples/medmnist_feddyn/README.md b/lib/python/flame/examples/medmnist_feddyn/README.md new file mode 100644 index 000000000..6282bbd30 --- /dev/null +++ b/lib/python/flame/examples/medmnist_feddyn/README.md @@ -0,0 +1,48 @@ +## FedDyn MedMNIST Example + +We use the PathMNIST dataset from (MedMNIST)[https://medmnist.com/] to go over an example of FedDyn (alpha=0.01). +Here, the alpha value can be specified in both `template.json` files in the `trainer` and `aggregator` folders. +We chose the most commonly used value in the (Federated Learning Based on Dynamic Regularization)[https://arxiv.org/abs/2111.04263] paper, along with the same `weight_decay` value used (0.001). The learning rate was chosen to be 0.001, because a larger one did not allow the models to train well. + +Since we include the `weight_decay` value as a hyperparameter to the feddyn optimizer in the config file, we recommend setting the `self.optimizer`'s `weight_decay` value in `trainer/main.py` to be 0.0, as shown below. + +```python +self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3, weight_decay=0.0) +``` + +This example is run within the conda environment, so we activate the environment first. +Once you are in the `medmnist_feddyn` directory, run the following command. + +```bash +conda activate flame +``` + +If everything in the two `template.json` files represents the desired hyperparameters you would like, go head and run the following code to run the entire example: + +```bash +python run.py +``` + +This will generate the different config files needed to run the example with 10 trainers and 1 aggregator. + +All output will be stored in the `output` folder that is generated during runtime. +This includes all log files and data that was downloaded for the trainers/aggregator. +The aggregator folder should also include the list of accuracy/loss values derived from a global test set. +This folder can be deleted and will not affect your ability to re-run the example (in fact, if you re-run the example without deleting this folder, the `output` folder will be deleted first). + +To check the progress at this level, you can run the following command to check on the global model's accuracy: + +```bash +cat output/aggregator/log.txt | grep -i accuracy +``` + +Once the model is finished you should have that the command below should return 100 (or the number of specified rounds, if that was changed). + +```bash +cat output/aggregator/log.txt | grep -i accuracy | wc -l +``` + +We compared global test accuracy values using alpha=0.01 to using mu=0.01/0.0 in FedProx (FedProx with mu=0.0 is equivalent to FedAvg). + +![acc_feddyn](images/accuracy.png) + diff --git a/lib/python/flame/examples/medmnist_feddyn/__init__.py b/lib/python/flame/examples/medmnist_feddyn/__init__.py new file mode 100644 index 000000000..506f034ea --- /dev/null +++ b/lib/python/flame/examples/medmnist_feddyn/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + + diff --git a/lib/python/flame/examples/medmnist_feddyn/aggregator/__init__.py b/lib/python/flame/examples/medmnist_feddyn/aggregator/__init__.py new file mode 100644 index 000000000..506f034ea --- /dev/null +++ b/lib/python/flame/examples/medmnist_feddyn/aggregator/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + + diff --git a/lib/python/flame/examples/medmnist_feddyn/aggregator/main.py b/lib/python/flame/examples/medmnist_feddyn/aggregator/main.py new file mode 100644 index 000000000..53b07b7c5 --- /dev/null +++ b/lib/python/flame/examples/medmnist_feddyn/aggregator/main.py @@ -0,0 +1,185 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +"""MedMNIST FedProx aggregator for PyTorch.""" + +import logging + +from flame.common.util import get_dataset_filename +from flame.config import Config +from flame.dataset import Dataset +from flame.mode.horizontal.feddyn.top_aggregator import TopAggregator +import torch + +from sklearn.metrics import accuracy_score +import numpy as np +from PIL import Image +import torchvision + +logger = logging.getLogger(__name__) + +# keep track of losses/accuracies of global model +fed_acc = [] +fed_loss = [] + + +class PathMNISTDataset(torch.utils.data.Dataset): + + def __init__(self, filename, transform=None, as_rgb=False): + npz_file = np.load(filename) + + self.transform = transform + self.as_rgb = as_rgb + + self.imgs = npz_file["val_images"] + self.labels = npz_file["val_labels"] + + def __len__(self): + return self.imgs.shape[0] + + def __getitem__(self, index): + img, target = self.imgs[index], self.labels[index].astype(int) + img = Image.fromarray(img) + + if self.as_rgb: + img = img.convert('RGB') + + if self.transform is not None: + img = self.transform(img) + + return img, target + + +class CNN(torch.nn.Module): + """CNN Class""" + + def __init__(self, num_classes): + """Initialize.""" + super(CNN, self).__init__() + self.num_classes = num_classes + self.features = torch.nn.Sequential( + torch.nn.Conv2d(3, 6, kernel_size=3, padding=1), + torch.nn.BatchNorm2d(6), torch.nn.ReLU(), + torch.nn.MaxPool2d(kernel_size=2, stride=2), + torch.nn.Conv2d(6, 16, kernel_size=3, padding=1), + torch.nn.BatchNorm2d(16), torch.nn.ReLU(), + torch.nn.MaxPool2d(kernel_size=2, stride=2)) + self.fc = torch.nn.Linear(16 * 7 * 7, num_classes) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + + +class PyTorchMedMNistAggregator(TopAggregator): + """PyTorch MedMNist Aggregator""" + + def __init__(self, config: Config) -> None: + self.config = config + self.model = None + self.dataset: Dataset = None # Not sure why we need this. + + self.batch_size = self.config.hyperparameters.batch_size + + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def initialize(self): + """Initialize.""" + self.model = CNN(num_classes=9).to(self.device) + self.criterion = torch.nn.CrossEntropyLoss() + + def load_data(self) -> None: + """Load a test dataset.""" + + filename = get_dataset_filename(self.config.dataset) + + data_transform = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize((0.485, 0.456, 0.406), + (0.229, 0.224, 0.225)) + ]) + + dataset = PathMNISTDataset(filename=filename, transform=data_transform) + + self.loader = torch.utils.data.DataLoader( + dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=4 * torch.cuda.device_count(), + pin_memory=True, + drop_last=True + ) + self.dataset_size = len(dataset) + + def train(self) -> None: + """Train a model.""" + # Implement this if training is needed in aggregator + pass + + def evaluate(self) -> None: + """Evaluate (test) a model.""" + self.model.eval() + loss_lst = list() + labels = torch.tensor([],device=self.device) + labels_pred = torch.tensor([],device=self.device) + with torch.no_grad(): + for data, label in self.loader: + data, label = data.to(self.device), label.to(self.device) + output = self.model(data) + loss = self.criterion(output, label.squeeze()) + loss_lst.append(loss.item()) + labels_pred = torch.cat([labels_pred, output.argmax(dim=1)], dim=0) + labels = torch.cat([labels, label], dim=0) + + labels_pred = labels_pred.cpu().detach().numpy() + labels = labels.cpu().detach().numpy() + val_acc = accuracy_score(labels, labels_pred) + + # loss here not as meaningful + val_loss = sum(loss_lst) / len(loss_lst) + self.update_metrics({"Val Loss": val_loss, "Val Accuracy": val_acc, "Testset Size": self.dataset_size}) + logger.info(f"Test Loss: {val_loss}") + logger.info(f"Test Accuracy: {val_acc}") + logger.info(f"Testset Size: {self.dataset_size}") + + # record losses/accuracies + global fed_acc, fed_loss + fed_acc.append(val_acc) + fed_loss.append(val_loss) + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description='') + parser.add_argument('config', nargs='?', default="./config.json") + + args = parser.parse_args() + + config = Config(args.config) + + a = PyTorchMedMNistAggregator(config) + a.compose() + a.run() + + # write records to files + alpha = config.optimizer.kwargs['alpha'] + file1 = open(f'acc_alpha{alpha}.txt','w') + file1.write('\n'.join(map(str,fed_acc))) + file1.close() + file2 = open(f'loss_alpha{alpha}.txt','w') + file2.write('\n'.join(map(str,fed_loss))) + file2.close() diff --git a/lib/python/flame/examples/medmnist_feddyn/aggregator/template.json b/lib/python/flame/examples/medmnist_feddyn/aggregator/template.json new file mode 100644 index 000000000..06c5c351a --- /dev/null +++ b/lib/python/flame/examples/medmnist_feddyn/aggregator/template.json @@ -0,0 +1,81 @@ +{ + "taskid": "49d06b7526964db86cf37c70e8e0cdb6bdeaa742", + "backend": "mqtt", + "brokers": [ + { + "host": "localhost", + "sort": "mqtt" + }, + { + "host": "localhost:10104", + "sort": "p2p" + } + ], + "groupAssociation": { + "param-channel": "default" + }, + "channels": [ + { + "description": "Model update is sent from trainer to aggregator and vice-versa", + "groupBy": { + "type": "tag", + "value": [ + "default" + ] + }, + "name": "param-channel", + "pair": [ + "trainer", + "aggregator" + ], + "funcTags": { + "aggregator": [ + "distribute", + "aggregate", + "getDatasetSize" + ], + "trainer": [ + "fetch", + "upload", + "uploadDatasetSize" + ] + } + } + ], + "dataset": "https://github.com/GustavBaumgart/flame-datasets/raw/main/medmnist/all_val.npz", + "dependencies": [ + "numpy >= 1.2.0" + ], + "hyperparameters": { + "batchSize": 50, + "learningRate": 0.001, + "rounds": 100, + "epochs": 4 + }, + "baseModel": { + "name": "", + "version": 2 + }, + "job": { + "id": "336a358619ab59012eabeefb", + "name": "medmnist" + }, + "registry": { + "sort": "dummy", + "uri": "" + }, + "selector": { + "sort": "default", + "kwargs": {} + }, + "optimizer": { + "sort": "feddyn", + "kwargs": { + "alpha": 0.01, + "weight_decay": 0.001 + } + }, + "maxRunTime": 300, + "realm": "default", + "role": "aggregator" +} diff --git a/lib/python/flame/examples/medmnist_feddyn/images/accuracy.png b/lib/python/flame/examples/medmnist_feddyn/images/accuracy.png new file mode 100644 index 0000000000000000000000000000000000000000..28236361607deca69005518a5262787de82fd1c3 GIT binary patch literal 51388 zcmeFYWmJ_-_%FO^X$7TAN$C*j5ESW_ZiG#DcO%_h0#ee_ozmUi-5pYAc>eEtKc02Y z_p=s=ZuY+Kxo58HS2G65N`FO1AwYpZAn0PEpXDJCST6_!X5|GUc!hm*VFUc&viqW9 zr(kJl=csFA0FlzQvof=^Gc(pBb1<;6HMX>1W8`3Dp(iu4v$L}0W@0k`|KGr9X#-`N zM|lwlE`n?&s%i^?VCp`9!{iI(8ABjF@5MfUQgli^SafuH{SCKub-DUtN%(~yywFZ& zz)}?B3)rkNdNTR@&tJ@5;FQV3EB^W+9+maeYhuLp*bjaTZnp{QxP7M|?N8}1E49zxdKGFZ2Zj;lID8 zv%Z78{-3w@{eZ~+_o>zQL%$l?rZmwg9Ec%nG6iU?r`Gcy>dFGqSpTYkdqTDhr_uq(wR zJQ2JIo|AG3YG~N?;}v-#yD1Jfm84*%EboXdYWuq2^Q5FCXnV}=Y?FTtFVpW1*PEx4 zS5m@xxYkL~}4So}04oQrQ^UujK_`dcj4&X+>tgKV@*VP$we= zQIVBx4SqxCv>wJG-q<1*LxuU^c-ZcbdVeu4Bbmk>SD=s=Upu_g=6eg4X{O1QBaO$= zBmeM~Uocs{7s4vZ&ZW!NK$38i%Ox`#TgT*N9E*NeWOTG%<#G@<_0CcgkF>OOVoHia zwg|eG^X41Dqca-^2SlBJ7rOx^CC1~qi>LSd)1}nZ)PbLT5l`DunNF^kJ!4W*)XiEJ z7Bq6DQUXd#nG6jLYa1HabXp-Lnhk~*8e2nY626G&&v%@iO(NlQv0Cdy7Uq#4e6n|S z%>$ni%3K-Ck!ZYG3z?apAFno7hYY3hboWG&pI%*c4Gf@haBw`2COSGgn3#i#DnMD= zJ&yLZo3!$LT$gl6^gqpg36ICoD3-1 zFD}ujd#<6!`&%!fe8~*{G_@)dT|>iw>}+!L`Dz1DAe1C`S`!5dNZ<*xui?n}`P0F@ z-Y6?84-5<#P2|67dALd$<#)Y0cyCQ3Bz8Ov##sL6&pR-t$A>$80|Wnx3by0sJ4Wub zp`Fu$f`U*YE-ZC*^@D>0JQh88@WS6Y5<yd1B&d~xgJECn`$iHrKt7nwoP=*{m+K*b5^oXgx0d@a*8ia0!%}bt0OEd?9EZb z@l2g9HJDL8et!N?VjdibSoF?xDJfXSl`xK_p8$S}v|G|&k&C?MS!(r$t8+Si2XNsT zer|4@91dnaeg96OrKOcqolPUkR4l>4!;{k4*{L^wcyJ(FxGA9o8YxKI$s%vmsI01r zeXbe|20UKpW`DvlUeQP6F+X1<%%xLaAKCX}K4h`R3$@l`#VxdV*T)GEV*n>2cjk-r zlSX&{F83(~xjpWk3c965MI#}*0N)rKT|r+^RH-pAs8#lq6$6NoZ}E5ty;85YD}WcR zFdU?HKx@)s5{srdoU4+6{RQwLzFtN}B@te<9cZu&F1lO{QU4e*KCC@FJP8v6G)6SU)rh>1hYnM@pw zYdQikab5Wx>+N?&+IEOQ1w1nAwrQ&aKa*`g8?jcVD==U&s%(ScE+_u!*?Nu_+0la)Pb9eUc%@XS)A#pks8 zsXGrQbM}uxh#c;Pw&#O?Obj+C7#fGrEVKFNe+y;O#G^OZXSm9Z-7wnP+IBpFKg1*@ ze;M@0LsGfzQ6d|%!E*^NH@iy=@Z(e&pRBYcCMA{Y{YB|;pRYD2XJI)FU++dl!}C%r zHoV-M3^2!K)OZG8GrbyGyLb=NwBz8f0;aXemf<$pRnmS}J|e1VS6&RwHc{P8KMj>6veSD*3+KI4{=E#yQ-z8o7Hv$D2+fBFbI3c_u- z(+c`5iQk>O0y>-v@a^JwIRjVQ?UjPUI2akImS<2A6B7$KwFTTwB_NPdWjf93a$)pL zRU)9$(H9LV>l_XVzWr8gYinCgpI3+mFz|XlJ~egVVt4%Zd{_XzR4xiY1_Vr#sEkaW z$HPAZl4US+&F*)9Y&OODpYCl1j~*{iPHNY^+yRl)+wHc42}Fy~Y;s{YshS%>$2?6 zvCbiKK-0c{{=6_=7t-ap)VH*>MB^3m<45b_A~7hh><5$S67!3Sf`a#;V%U!x&oN3$ zN-E7}qxs$cF}vMZfpt5Y!6~p(5flT;&A$888}vRq%8Pz-Ufxv8mWOx*RGinY9*!n? zi3s$(L7yQsFC4{6r`Olk^kUuXzm0yW&TpQHBL6y5YfTPHYCTR_YZWZkHlV`I<`;{! z6cm2z-4SqS{$~zAYca!nVTv2$w#-vwF6>R2X-BpBBH8YYyj^lV=DUA*0OToDIz68K zHNy3`J|`Gmhf@h4ij(7(r!$~na8Ogl&w$EcHkqIqO6BrxX%R4;D$0~dVC3ZD!oMxTFiEXyZR5^)K9e}_wx2aPDSZ^B}o1@x{ zprD}D{y`uSPESsLf&vRD_be|j_qzZ4JbPe1@!dTDmz0&2(VU*szlci%Mr4)o_U5L{ zc3XO`(paGcRk%N%!Fcu$CEzG zW|BZs4>DHPXmIV-_4SU??62=>XauuIWB+M@;fjD!6{xmm9o5f6eU_xCXelZETU!Nz z-nCG4Sid^@28BX_#;&zmZ6}VH&Hjp+nwF*tNFD|T<~tP1W;Ro1QIZ&j&eR zZf(^**uic1Ok!p&<&p7*cj5^G$OVKbk~fZDk7=1RX@ZWs5gQw;r>Xh)_y`H0%%!KN z_gcd}9V=2P_W^Lv^1bI}y62;F_PpzH6Os;%H>=f(AQm<@x`C9%w}}#OIHY%C@<9f2 z!B=2~+j@J0z%<}5KCMRx#1{@)vEA9-?Xy6Kh_W-{#3>){uA%E)VF9ckD65bKiPsk^LOPw-dHAp~fyGPP?Gyp1 zd5-!w4;V7(yh%WF;T9@8?@hc{tFdq&K78%hozw$CY1iBX7EQk^R2Oh=Dwk~lG~F37 zOhi~z^aWtJXJf((l8{S%7j_-L1rYE)%i@zbHsZNIn2H^M+381TYI5fB}!E2S+NYnK-Mt;Zbu6gr~}XmyMRp;6czoE$>Z?!_+Ym;(Q8#Qt>qe4+Was0`KRGva=GO) zzmKnP|3mWZ%*;r$yKD3979k1>%E`%z4gd^4fB!f0j$NVnX3NcKKv^*ZBG{WK=m6B9 z(%`@Vx{tYU1dWtGmDJ8!lX#A>r4nk${w2i%?%7%@Jd%Li0k#_t%}I#H~i06cGliWeafkz@j+CLXU7Gr;sO#Yz<(4=zCFIRVAq$NLqawY8nychIbq z;nTeKdx}B$pMi1*CJ13+9rE^j_Qsl22O zn(uFYeG4A^7chF_TU%SvRaI3JKR>}41JWJ?&r2y5{R;rBm~bFQ9FT%kNy*9LjEs!( z8T{_Qu&Jf-h_`{fkq26w5*s^)Ria%ok@*wsFHl5rb5lT*iFw&w9jJ;$Z~Va}b-##G z1SV80dsGzogE0WMoa)%9Oxv z_0G+y5%4;dy(T5qHo;E>NTX*S5*{9(?s~-4J2WKyr>ZIrundJ}i$_!EHqonBZ_!El zHGtENmz0vCw6wH*meTS-nU}9m6)MZf)5gi$+1VA8N{Ee(k3XwCJxb+Zx6{A4@*y}3 z6g)gUy}Kw!CSxU}U_%_Jyw9%ii8dR@G+bjO{Jb{7j5ml_yWaRA!U zl0dbAF~mBB`}x5eue5qc`}_NUx3pve2rE)=wJH>3phM5V5W?5I^HBdyUOpN)H0Iu} zE|KTKgX*VB<9DCiNd$cKd*6N z+2)<`UuPAuG6Q;TRxDOy7v6pQ$!Uy+q9Alv1w?2ObzNbE>nkgQIyyREIKI&2jf}|N z-`~&tF+kM>wqR%LbxcA+U_t_}Lat=pcKf`~NiDWMNZRr%pUWPUEgrv{6Zl{)D!}Q+ ze|v>)U_st2#$rAE+b=cxkSsq4KT$=b{F_YZy+eo9TNcw6*O2T_pCaiwk#pdV0YA&A90EJM!nJ!PVMbZRYL*28*I_(`|{;W!Sk)bBoHN#F^zBS?)Lg2V*{iSFMlXhENK5q z>fQJB$|$g2S2q;`g4m1 zn2YB`1CecWbFzv>`dmjq7mR)c{sn`GQ+Ch!?CcDvX^wj2|9Pg5GM^!diHSd?^FT^9 zY1B|yNC#M$l0O&V?Labn2$234E_@*7TLHDa1{SMNI zYr8%4Y!7d4Z*}$b+HX$RpF;<6gI%UxpmhY*)RI?m-n_Ydp-%rlWu?RW3g#zJi?ngU zwY8ktht}5CwRLrIH+VoweFWYL#CESY;0_h+KB7u{M1d+Dy3#-js!!9yH#OhW8>xJ1>wj&;F62|>0@0E<(KfQnb8zM z{;s9gWEACDDeqpi>x*lSqC3exQy>Twji!q5flb!|68O7;0SjOi$a7d37>ESIL--Y7 zJ%OG8VoF9*5)Gg$kYa4kXB)w81Yj<&%{@imvDLt@nwj^NV>bMhrJ=0{X#uZcdLqQC zC-N-?1qCyp0zG~G)9Y*K{#0>5(A%h}D3SF$8(Z63KbEhKGI+V$LcNB;K})JPb`XA7}-LD>U`S5~ew#V$t&2mn58tc|5gmH;iot)0xiK)eamx+6BWv8Nt%58!MpS4?~ z2j|Yzsy{iKudPn&c^^yTBFG=p9lUlBaY)i`LN%nn$Dm_~11#gc8U3vsGH@wqS4km%VMnf$73$6X_6XS~c*EInLa_Pw$m>5k*O z=gHu3TY1@CKZFVz3^~!og850t3qP%3i&t?^kutojbEXX?HQJ(OXz=2Al*V%YfPD{fr)*M%#bj}%a;Sx0bQ7h|c zqWIpHYAZ9n@ZAAQu>P)h+1Il>PUAsUL?+&Wo4CBuebR;7tEK2m9_%=m;dPnwDi_wX zjb-<*uOR%d!IqBS{DNrz>dIuN>!V@Oo%GT<>Fh9vX71?4;l!OCVUy>~Ft|x#)kB}Y0{1VPJLUOF0wf7+A?!aY_Vj=B z2rplL-!@e(yQ4L4nf<94qJEO^qOnSNADOTV5<)8B1lzQDS- z-|IzzoL6#F74Z&2Qy8sT*H#}N4>5u{?xp5jSmKk^83iKO_b#yRgf}3vnO?lbOZgO| zi!^(W-`>BONi-%Ay@S_(o(uJz3oIrE?mmf-h_TZKzIzwu&;z%_G*YJ!QU}dV6(?4u zT$6z66Zn<>!Gzv{hKBW63i2AS9He##-so@cToXzjrtb5tGX+Ot_!0r)!7<;n@;~SH z8ja~gxwafPy3^kfAMNrx8N0q7n%U8kA$_oPJV(&E3SY57!jEWhD^6d=3><$<;d5&^ zM8U7)fV+zhz#ytk76p?)o6}f1ZyESbl%Y{J*Kpf*H-L9GJEsN?^3jmMgj*fPYU`~% z-10xJpy8hN{o+rXE(^^o@a`ar0nx<4$_#CQ0yq7C|MYU{_-v z!AoEF8fSp&jl%u4_E-nsFzi=tyC89#YM9gpvEgL?~UI2oUENJIst!rhd9?~kp--*NDWf7z}W zeOiCc*zw=o@j}+QZ(rU>BjKQ1b_F)w?PILBI9N6oHTB-+ck(!HpBe47{%5?kb!p62NVs)&MxEy4r?X8?aq z##I|i5roCmvo9?IccK?!a#N=ola$Bq_LdVTK9s^6jE!;1FrfQGi6K&K3*n80&o2Pi z4hX&LtOCc{q+W4f$+-wzwjsGjdg2bv;fWJt0_0q`+*(f$6zYtfdRg*+PBwFEXA`6H zCYLAiDwI8VKJAw}@Ks9gJj~f{Zrx0EWMW$KZXOV~b^h8~(S;8Ekz;Twbv1|NKDb7; z_FG@uA7qJJ3&d!A}{hRLYAN8qBh21)C z?mHJCr6RAn7*LQmB(W^z_oZlCG5of>ee9hi^0yMY|AvGoHL^C;)k0kp#Y9QSln?U)LeH%uU+p$=A%?aH@|M zm(XR;-<8>Yx7D&#hVx?jsJ(S}sC>jp1>$ZDEvfJKeF~MSuX=D(m~1H-eS0>{n&%xX zy)C_u46o(LWq2lNDM_!+;5MdNX>Z$PO6rtdnm!AO+`@$v@iLGn5L459gac^8f2AHa^^Ji0yL0NVy-G=81kBEcO)wb!2Ls=Q5iDY5ZahyjR{kc~bG}Ga*R=3CsYv zlP^#IRFZ!?S7c)fCp%{)=-l^>UoRM-a&$jcUEAl@C8YZEzlySv(MS7&oERtr@71o5 zSE;ilZiX0GLYb-9x_{W@TxVbpyS&>o%*4?)*%e!Onk1=aTiq#O2$0TkI-Qb{FCv*y zDc~pOPX%+_`bqmQg$arOhQD;&j@zQDcycGpJk_-(W6S+TyjZrljd~&f2ld3J@~4cj zW5?A;_B0jlEUqa-ldNdl8|z_qm!U?eDD}I5*^V>gW51 z2)!XGikHB}6RdbZ{)Slz<@Zc^S}XdZMJJAR2&SeKJzEIM*-1R+bjH`bgq_8aG-(eT z++wf;8?wDOUJ{SZg}|-#aCk$!d>8%802i+dqXtxh)*C{4NkvepFEzxyv~lF%|E*%f z{5+Wu9UOx8w0H2bxssmx`%YO_9msM*7WeB`b)gtNaFMPPj?mRI!^Hf!gWA{K3#~7+ zyMg38nLDPQai@gAu8O{8`qawj0d)4o6JAMV_!LHqYx29hT> z9uW<;EcA|OhnF4b;FTM99rpK5m*+Fzag-=XP0+4uFHW(iasJcA-9k5)yml!8_XDzM zEwJ_FKPFDm3ML0;>q%pN4bOCQ+;wKT_qV5;xjB#4-)LTk?kmO&-%Wqb^b!dFJo1>b zc>fYlZ)a?9dM~Mk&JSvMtPiv^l?f$NO;J}aviU4$CPPemfmBAokto*nhn|IzhkavH zJh)$=AwHd5=~o|$J^$5zeQoB>%43}PQ3D|%5v(iSzP}6-NNOJ(Xi{xQjRB+HR$6Li-+iE^6lk7+XZnh| zHgwEnPyQNcd`rvY;rPSs5FHDh7wSwsn%2*1_2OLjN<*wwGW+||NY_o> z>}%J?JYR%$?q2+IZOniQI+}F1v~)E_C4i7=^+z9p8VVGSNvuA_DsDm9xLlSg37sX;bKec1d^VE;P%EDB*y4jk z#aDO7`so<9rRCV@96p7a2UGc*FI2x`(RAdtYPH_1FTdwO#qQY7GMX}04PS`9M3oqz zJw?^|GN(TCFkfDc%7EH%f~iTjoRIP~)bAMrmK{vAUeKLir@_hM6&*x!cNh9Wr1D2q zQUIpwB#ch4cbPm}@a*jzmUfGKjolN@`#nT{9FJ@4Prv*-xvnjuc|kEV;$o>8Ux%j@I07x6UdG z(o%xRIO+wE+cl9>+8w?k! z%4;+GE8Diyc!*4nB3>74`{1RgB{UyNrHivQi^chp&&+5=S3;>d9HBvC2!3;8f5u6K zk(@PObyDxTGOe2tR7r32GCdPgsNPqt}nAE4So?3QM%i3dp!lWx%81P38P1Bizw z+pxkvhLUs+P2=9DX(ne%32FgBs0PHALfgpzQS02u)6nz!l^YnSwV^vW^t<*|0L;Zm zxzt%KrzZpd%!wCZ3ux|I(VvR{izsrQb3gv`#$zZH^6ULw}Eu(CB<(t0*+&WT$4vmzW=IZ=-%- zv3*i!S|THt3CZ*2VlaJ;$!&_`0Jmj&N$Fq^AR@w|bjh^b6}s{@hEx{!`^z*vHY0l- z(SB5{s`;7hi@@E~d0O^a6a&G$H)+NcS=Eo)KtD^KHOp$I_3QpHA#2U88lNS-fY^7N zkLeg&_c!w6tP;OzUZe0IF+90zcyb5W%Ds>DAU><;=^H=W`-80^)1MZ*D0!oJ8^X)Czt87SW#lJfYDwJnbnbH2%@H1D?x*1|>*LT#g>uRy7@L*vK z`HwV3#+VpM{Mmu^*a4}MqKmdAy59bmy3e1X0=O4gMcg=I-ImwP-&b`pG@7`Uem4g) z_@+3I+n(RSDfMs;})04`-CX!#n*P zF31gGA5Smt9>r`pxokzWE8Op^@g=~gLQ$l!gwib?YNgaP-0CTZW1(NgQw1CyAJ=HV zpFZdcnchT39hw53?DcU8tW&kz8$db~pozvknRp{Z>|88hrm2_F>RsyQ*4E7oXpm##6})2Nf^=6NixtHY zTEa})?r{kTVN>UCai!sr!-r|?y`E+m|`Z_`$7Y827LL`A2U zP^RoiuMCFBwMJb|$j4oR+o$HivkPS>?Q?1z)u;=3l2y`}h#+Ks1G~^48)kpb?15i) zN{H07P@KdtWn(wyLJ(mKwp9M`SejEq4&&dpS8f_)wda_)4149SzNf!FZo64t7ZM(; z>B_o_qA?4r>?kzD|AQB{9Jep-kk-YW`-Zz;`&XrNs8w@c!XD*^ClFhW;B9dP7JC0B zX{cbnn;z?q!gQT(8y-ylRpiFMWCbHuy1kCR%>}%2RY7tNd2#|>qKWKe{5FB>iX^Rb ze_(qNwh;n-2J20M@DsvR&Dtr4pemSl0K3#fnI z(!oRvt*p7su)Oi`#AzN~gRJnsTURSy-XsuGt!da1K(oi0ztO3(W^qXs_Alz7`1|0)5mD{2Wn=N+QRz_GT2xw+%Z}gaJ)WiV>XdRr$-|^1AR5-H_ zx%%>!y*hTeCF^Lp#9o8>Tj7m`EIA_mj7j0$*!4NrfJKq~{n!-G*AQMn{8Yv@`p-IMyjq^<(W2RCK+DZ!1j5OYG%ty1idS#jxjGAK$T5PYI z{Y}JJ8ko0_+a+Yu)#Swg6fG!}6%b=kFgyFL`yKki_5 zj%iWZ-E0B%Fq;f*8Bbn!Jn^jD%vvDEv&E{HWE+5#s0&?NMqO2<7PuI{S`=3L^umT2 zx(?IK*IL8Vi^4O2#U!s>@H59BY3k0fV$|*}?S;`KPJR=Q;0eym#YavjUuE^|^^B&R z{`;{oyXOEscv`J5V{yT5h8f4_V=qtb$4CL*df1W8{x&J0mu7e&9@Ds~5=ZR_n)Z#u z%Rl4RU^`PGP4SMOnz5GUUF4v*_z=HX};K267Ip+d8F)RMR>m2Gq2!T_!Nf^WR@% zB+h=S84hKAe_hzpg47@J*@~7mj_IjkpfNK{VsV-ou9k=`uG{{1uK=_JJ1KP#MC}BN zc}ld2W8)FZ^Ha=y6c)U$X_(XfZOi=(g`?T&JVy?Owe!i=-1Oe7Eop`wN&CP8tQD8| z$IuUz_!T*o6`-IaTAz)5gOzXDwy)51;_|@ksj$fd$tfz!t{>I+wL>$5o=0@!qoCry z45rCiOEgF%5b}>c5@~wFQU!`9+dJ{Uwv7t5oYW&}#3FGRkaR|=H!P}3##mkv zq}yekl++i!UVZR;n6dL%&7_a`NXl^KJ8UV}n&Bt%XRwm;sI7Cvh>TnNa9c#-L7Tqq zu#@TOlX!B>{-q2YCvHdDRz@VBiQb6`kT5fRgu7AVv-zl(r;6Y;MGmpQcNUHC#+F#= z#3J$<{DRzIa3(;j6T|4Wf=J42EcjLB8b#xtP^ks26u~DuV)f>WX8*C>xdHI)AjM)a zT79b|uq_&AUiwL|?Ec=>9ZBjd9tR^(I`sY>0$Gv~tJs&vXKy(EZ80WoedA2u4V~)q zwJ1*ZDW>L}6xuu*F0Ht>W_TOYQI56f>Vb}q*wKPb_V^C8!}lk)Fg(E$SA|;%X%Czt zk^oCRGKBmj1!gZ?^P!a}Y&6%F0R5I8?{aZC@v!^SUkS@Ku~wo%UOw&Dss&V~mpqLw zwJg@EHp&)-P{dW`*gDLAm>(NIH7)2ovfV}gMc{LmsL*OIopdr{A37mQd2`hW;;DD< z*@iiNU3gHD%B;xv8oVU9od=5IWF%<54u@J@KdVmyut!kD_v&w-e$?mJ82>DH4_lSt z2$n8?1-#Sez8P7XuIt1vlKC~@#9S5M;=e2%t#eMn`4D(tq0DpQzO^(&E@pMB?MxJ_ zFJ*Nqyt~m4*#7+B9?p!HutOs3z-ae(!nOSc3w(<4EP@|ya|=43`0(Rvpej^m1cW5J(7Zu8t8x*-uDH=VHwX`)JYPHyE2|nQ zUhO&H5kYnFgP`pot<8x`b)+0t5*PY?Gb#4HoT5^;I2XA&}c2 z)J3;eV<+rx?Du#V2Ia@1Xc?22mIyhpnmLmWmDb%-JmNu0Esz@4RQW+Nv0fg}RZw;) zTw9IGYJqp0nu0R?q^7sBG&_ud9Tz>)NiVd0c`NZiBI&@%XOlq?yN493$+uKJ?@k6sfO1ai}}?XF89I+I=vSlWC2~ z2^c7xrc(W@)c8;>2lw)2dhni7t1mxF5Gpu2r9`meALzzvJwM zS*fC^s~4QiWi~7&B0Ijr=#H%@;^6Bym}GltikY$;1?L!sFbC7}>ZeI)N(8yaEDcly z+jsTgceJ(yod3Os%W(Td)7#_OujfZ|6^(eZVdcaYh+Q|AJ}h^-hoT218Ecw@YR;|? z-Xn*$+2nKT$t(LNU#v8!fDP}Pw2juNMN7>I`mNDntqc$_7m?G|hSq+6=R0Z+7a-T8z2$iOaS0W%F}jnPXT zPIhUVltrw*#Vl0|=?>bKo@7I{mBVgx-CrEe*oc4c&CA?in9yM9+WB_ovE1zwd)Cip zcYA5oO`Z6R(Ai{l#L0QrPm3P1HKf~rc;zy)^E6*yeQl0zXdEjQ4w=q8q>x?y82NC6 zwP%(lN!%3Ah;a9FFmJ@=9o8BB6cFo%AT)9ro*z!c;CCzE#ZqhfEbv6j;oEX=#&yHc zD5xaH*~-{JgP7a&*VS}bPJr@dMl3UPO(0bV*tA%a4Gs5XH^dU0pl};#<73`uEAG>XCw)rHRrSuWIVf{dJ_N)orOIz9(Zn+K z3?Fqo=yv+wR}ri3)&al$rN9f(zs;Q;2~~1Z97Ld53-zm)TW`;%!`gWFAlfGJS@XCE zivkG(HHOt^$8G*CE6b-4B5V8s$EXK9tW?3{fTOd7_;w~nxaMVHBN3dCfDhT~nspnY zjxLM`C>u6r+yI+q)Dk;T>2&3wT%u~1uoHVCdI9NgJ9R8dGWVLLXoMWjBrJGeKfdQ- zytQuH9KI-2bD<)HgyE;KFW!r$RBSw zOUd}>BE{Jrrr~+2^#zp49;>Efcz*Lv%6v7aDp48rr~9b>-W6O6zfBhxCEA7Bt!0cB zs;G{_m&XT?z_fVfTN>^>m*$WM|DMqQ8FL6feXhvHD`Nb1*?B6F6j%ynl zG)PQ8F2u06B9XYL6K&=BC`dcpTQMwQ4=(w&wSzPhp7Z^DoNK#xa>00sLs=g$pfMQ9}3+f}!ik0o3 zr|4&Wh?mVPDQD++HP;FyGg0H>H7dTzhE}w+0S)7 z;3x4U*mAovs>m32y6k!uDh%mXJKL|+r~ORwW%TK5K-F3M@_0IjaykVqSD?&C4am=s z4Af6&WP5rwmZZ(_T*jGw+jHYk6GKqGDT*^$qp7zALdJTwipLN!g-^y{m#R*nAx7&@Ss$soV z#3$<*x8CCHi99_S>|&Tw0Wte`Ec85W&5`hx+)cXeqHZpXz?0Ct#+$sdD`&|Lh@sRC zQce9Knkzx9T;|}LS-7zr($n$MIq`QZF2P0}F%*T=Oz+F5ON{#{jX-=B^ihk(p(xHb z8CS*1hPl7dwJn=U>BIGU-l8AJ&9(!PfDZ{EfRjMmnST#V6498X|srn=QAVl-ww>2t&qf=n#7_)JP z9;zCmUwzF{w5qkbvINoaUl%e^0x#ne|BjZmLteHvy$W5RhYFzn;j(53zlr^DpDJ}) zvM8Fat5N)hpbDA%5=*tAl+_fg(i(l zC5WV#&2>W?4W`3?j=q1ebwo_yL8TGHY9T?}4jD`sk+4x(OVE|FexF0Kp#L>`gGCKu zpB~55zCGQy`HJxYPLtoe{vnKwX3ov|V4|(gi^4PI$_(b{@y!4IRS|DE^sd_`LF&Pq zgS2qGeW_IE={je6=?u*arlkiL`3ZM5mmO&iTG&K7-}xc>3z1M8QVuH}$bHk>{o6r=DWf}A zZIvr9kByqauM#=@%hHpAR=p#zLp!VPngO>`ZMbJKZ(FG&s(m)Vu$tNG{lp&a(}LM+ z#nS}e{8+1-9P%5xuIgDjdba|kw?TDI^QN-eDhNnhJ@~?2zXx?kENz$(&k)QfJj-XEFb%wZ-p&y_oqcMZ%y7D6Gfl~WO@a-_y-LS-pJ6XKFpAZ=$IR2+Gg!h#=E<{;- z$&Vn2-v7<|dDmE<2~BpKc!yugye})aI|b6_d)+fA>d#?%VP!{W`=Y}E(y#Akh+}o| zOkG{AkF&5Szx-_pF@*3;oCIE}Z=|8vY|`5}=qBWwy35iqE3|_kDpb=W+H**wTe4Vh z%L_4Jb6K2Lk$B~>HKRK_xb(%^=;vmbn3j2Aqfa>f^&H&|)u(rPinNI%IXGdwhfe#$ zb;Q#fb@>u<%AwPS6T#9mrfzmMmUj*uOUSdUAvTtKxi(_MBa4hzke zKdx2l3Oe}?PH36XgmagPeQS0r@z^AI;dZNC*YM4p(J^91G0L}O@VI9&>?t&+f$_nm zpY86Tu;U+#P43H$PnU~Of&|2ZNvR(XfyHr*c879Vw`zp}qK z85bRnd)f$kV2tIMYmW(UvrN|CLLGuxMNAZS&I(e+xH-RCD17>O4|>64Foz^`fnpdC z0!eanE9bgmKo)_A`xY_CvN@|YLmM`i>S^A{mhPG3il8FZEna2>G8}ssprff7elW$3 zY?$jlh^a8Z<>*l@H*JtrX;|N%R%e_^(W-qHl*9lZ8mFBx-;HB3u(>j}NSYNfUOZfr zDmz<#E>Gm@i~W4UicX=5x65Q7$;JwC#Gc!Np=+G?oUyYSvfg#Lojg+~YkgCM zuI7?nON!w~8?l5*rhWmWPzGntR^rMt>(0-`6jbbiwe9TREt!tjO2(Q^y`d@s#bySw!7X|ntAWA9-2uO&8k^<5o0)kS~-69Pl-62Yegh+@q2uhc9 zgLHRycbC+h*Y7uLX4aaSwfskgC+>6ax%-^6@7}`^Wl(*s&rr;IdfO*m+*Aqi*MS-H zLp>JTo{hOmT2oX#Y_}pURTr}R6z-jHdbW@5v#GM>-bO{@=~0s7(R zCdN!<$rO8AslJLN{u&C*`I$i)KKBYgZvMd^Ky-#q7xZv`S7pa`&s)6XkR6KkpHq6n zpleay0XtoIxz(rph&v?zNY5exEPsCNo~B9pTD3@`uh5=2H8@|A3=<794_*bfX&qQ9 z`7k&%O|MnPGOccJ*XjODKE40=BkhAgPsIqD0KvMO=fBTLo9h?%)59^JK1_0JHjs^> z4-rN+;n-bjZ2QUaS&+#%=hfAH4DaRPse2uGHmecF%fge!ft0FAq2}hPV(kVSm-iSj z99(Nyn4~QpU|=0$w2`!U-{aL&jA@&+rX|naf3e4tzi*z@J`gimB({bV`oW*fc;W5f z`p<>k53i~ppC`ls;2EV~oOIXBoqyZF-+fZ>&g!AMv&QcdrsYAX#vHK4S9iCdSoiKf z3@EF*ch3RG?`utKu5%L(PqkY0bel4T;;5a{Xi;ytq{*33Lr8H;V+h7nIZe#i9llCd zWBr6v{Byt8*R!1Gx7ZC40~-w8Ip=4br&x7kUu zFesMMX|`L>%)N*?(7MBq%1a(CX`zvGA9Y1nA#r~A#a3Il%-8m7lc-n(oG6_4MkvGl zRao^iR20TLPdHZB88Bp^T`1AnMbEB&iLa<~^V|60+m%=G#kaS!1kKgPV_|(!b^KjAukjm5J}0V0^FT6gZulCA;RO_rsqD1DP4z<; zsp=8804QdN^gDbg1#BH)gt;cB^$1*xTxmL zDxJ2{;N3X$0^oqW7v7Bxe3d%OO#D5x6njd?))!_z&@LtGYM}-^IaM)2+1K2*7_@g% zPGs=8;!iMhSOrD4R?^|~`DT-j1@$n?Y2R;(HpCk~H49rxxehkRliAuHC5jn(Gp~8E zh~1?rQZQ_WH*3Xg#PJpqTmB>u%+-sN;znB0#>+l(T)rGqWm!kj(?#{YI;(9Wts2cw z(^n1!J{VA1Mo-HXqC7bsZbw~nWp#GbYsOosJ=I7Qc>kj}aPMq45$$tiMJMH-D8irL z_fDz=b?e@%+eC47LNpA}T2#^;vywApX_tDE9M9eW;JZw~>38 z=rM@jl*VwUXZPaFLz3LQB2VafSLjgYYefq&0$EoqU<=w{9avP|!fJW1>>Np2IS5tp z*8Vjd^y+oxeCDES!q0GCBr7IELsW_B)}xH8;T>F`mcS>u`rY|EZ-RYN?`RhNe26%( zttc@=Wv(JWb=1-vN(i1Ew^a3K?nrjX*4%p++pIpx!1dW;voQ=+AJ~`I#xIAf|4fm? z#`EMk;v}4x=UZVwE0>Fl}e?imFuOVi6Fs z7XNB8->hucx<8kH%T|(wBC|cJ{q2A&ZuzSy4$HIZFUPCr603zYgqY50ZwAwBnjWrZ z#40lyUXQySR3;!iGfWYBml^|&JZ)lQ%Du*iB{XyYUS-dR4yDW|duqMkwHIblcun`d)sEBqhKn2E3f3?jZp9x? z!s4l)yRbQ$Y97TTEA(^B-&7lq?#80`i+Y0nJY7kyz47}#n`ss(b97IeJEtWvXMD*W z+;3cxI?{tnt~`^rOo{j95Syo#!u%xFGtrJx+7#eFr+h2tP8Z^+ z!Vg>THnco|v$Q%`8<&dFP#XWe!PZcqqN&Da8-1m1WHw!% z#d{mQcIC*hdyRCujE=19Pu_k(Ec%4?qhRNXh14Vs7yidqCIqqB?WgVWQTIdK+h2t@ z=Zq>d9*-2-^e(l3IF(xC>y&t3^k8A8yN)cZVS(!LiqN}kM#qSSA5pcM1WFKMxSQwJS<$aKf~; z>TGfIjRe%DsJ|9UWnY(TAN&O4>#I-i*ni<#aXY)}?pU%$G^YrVRObtWqDn~2!CBWdU)s7mv9+3mmK|ijd+#jqIzYe^NH=N3l461wj9OLbrYej zEa^e%O~nvvM%mr|cynwEg`yD#ZZV(4b|&Em6WnlM1-(n`Cp(;N5oN9dD3-oz;Uk{< zEL$O2g<;2`_(z1}sDUK(Np%l$Wv`&?1yj4vc5$nmZ;Q+E-sg=?i^f7=Db9A-`(E07 z%&N!gX0$Zh(a+@{@aZFL;hYgUZkfy!D@5Sjv7uU#6W!V?=OjpB{}PB%F_fxBb5lit zHP>{^Fk>P}Gd%h$O&O7f*^h~$lfQ2p#-cxN*YF0+N3(9#?Ci$@g4n;|Efm^(-QjYj zHx^DwrMo27QF;N>@-SoME}iDd^i$Lm6vc;k+_-<;KJzA68BAR0w(NZ#N_|X#xaP=o z%gP{a)@Lc6>FKSexr-!U#A7bOD4?J#N&4*4g!A9a435Exdi%qcNnlj2VA0Htj<4!t zcG*%_11SS}Dpphfwjd1y ziV4%zTO#aFrM7A`*u&9mktuupuZHwz7lLj+S%J?VPb)# zxLW03x_0woP-Y^=*RBM+Xe6i~vxmBAKNU1Pc$WTf8dEA?J>SjwKq`2WN4%J8e5Lxq zaLA5$3b#{HUFNR{?gktm5n(B-7eK-wP~h3Gc=`5yS18pjl$i#5PjeGNBO)d>FF3Xf zhey%S(B5r~O&P2_%erJH+{){)pPhW8YQlZ$(w)7D($6L+WvZJ>M7*(D+xQJ-I1Hl}ru zI#^OE3kc-E^r10LV3TNkK$tON+ER`w1oflA>?2=GF@k4n6U`?zFYZ@R#f0Fx{TX+i zTDrK{=CEjgoI|5qHa{O`6XxW8w!RTCCR)rzW=>$vr?S~n(K(e`0RSf~^3-^V!)EJm#L||HUJP%&ecX_Qd31V;@rhPOXm#4;X`uZCTEvN-7 zw?#0Pmfm5lF9%haH0kJikl8VvEawDW8Kf*2@NYmEi(!UDz)7rV@H%&98+SLii(Ye4 zQ4jK5MrFkj)3$xCwuQK^dXs&Ix*N9CoO0RD*5+3_8VQ|L;g|=q;wD!*yYG)mi0&RS zCH4F^7gQS}>&gko7~poQK4{X$3XcMG+fMtOeSVSsp2U3{l-HrElqP%`6Wyw}C{zLA ztFb4O3>ket(~&(B-pS)=u755DBuy7(3%Q4P{=8{=d|osA&Aj6@eO=Z`!}aOTe)UvW z3}m>*^JQef&T%Yc>e99gG(Xh%t`c>O_eq9cumnNIffI#gju`UzgW~Yq1+Cd2q`y#H zmWeZ!iBat#6Z2;YS2}CXLk0#POvQ*t%lKVvZ~AZf4w5&Gh#Tbh+)0Rzh)?JuBVqSz>^;c~<9A%MOx;#)@ylj{eo;ZhgBBV3=JPjJZU%He z52F$FeUhGcHayk0kl+1X#X(jY0qr{`9nTAU@t`o)E}^S)4v;r6TJ1}fj%H6Cxdd&I zWDr%TJV_Q%n+92~BvC(`gO%PFV_SY=kRqeEqUdeg7_DrHcVFK+U@~Ot)-F51%ObfS zGV8i4J8#d_o%=QLBttr>JJvPEP`#??9?#Z=1`D=U(P*d5qmnS15yfv+=0|Ovi8com zyf*|(NMkRys>r=}?yhYdpM>LoF}|j0&BYh=kF?JBlyB$uF7IwA>VuNtH?15}I|YaA zE?P?BcXTOE!vwS`6g6d1eJz8+WaU0V`1rl&umDV`z+*q_r%PMxUbCjNpGVoQ8)F$~ z@ZvsTmuPSGzr#1^z-EUMWqi$_M#2Jjyrc0?V5vswoYA6hMJm5CQ|63?e0718i$wRQ zcE)e6W-QL-0z?VA9mJN^iP4&c_{Y>0tBg^G<3ew?p(Kn>_l;~z@F>;%6@(%m(?y;?&e2G}&6fp!Q-Gu5V)|T~tXF>}~ z=Z{phbo76RC|lx;#hP{SQHc1NF`KST^n{|=mA2zQtDjVSsbFd=!WaSqk098B)O-VR z6>ba4|Mc@f0#B%sq7YWpl-s^cvDNe|kYE6edeMTV;;-< zU4Mvc&fa7(W0rdg;WEeN;@Q_}gB-5r&c2|$-;eKg-cvvCcx^T|46~F=YqD0gxtjg6 zActzqzQ_JCiJW<@wqb|^Mk9yWs5KN|G^#$9i8qeQR@g8ze^lK437cVpcvqvVyn>CQ z(`rqWe_Ln;x8R_5sa-XhmBYa-R41RDYO&0?Te~WwJIF zj)(g6&v3Oa^z^#u94 z$CPo+kK<0yz45SRpRQhUcXM8l$}!gFzL%>Sq;rrvnvl?S`6p@JFOMsKX?UwfVLZt} zot2fZDmZ&qvn<>$pyl8tH9NI`o4-irFzKdo27BY+fL~^B@j%{@;gM}sx~VFGjB`8X zqz{PkSy)<9wyz(j>^(>G8qU>iZfN)b8X~ z9NcRnvqB1WM!3X-bO@XK--FlVl+AcOv2k@Bqa-4542bET%3etgCauKIveY|64q!pi ztAE@=xzl6+U5jk{fYN5jvg2~D-EJ~F{Dw1=u1qNkxdu2WLuLd5I z@)s>`zOAk-U3oZok4a4^>HhYHyWh!%)0mQ}=vl_5>Fm_CB9YFFMCJ2$e=%%+^L|SQ zvfr{gI~IFD=ch&SA)L8_`5q}3B}_urA?GjLWL2gue^CAyHxt|?X_-*8ZDnYw*Rt&5 zmzhqAa%8RM-zE3-YZ05Mf?+5+Hgp_#yN5ZPrIM(pbRjQoB2Z@%CjMo^+hXW??qY6e z%?&9mj`D8!b=FCby6MImt=!#BI~Vc?QC4#;LGw?tqMs)p;kTAGu0PC(Fw@DtVwfij zsrk7d4iXx&)`(K$ZVDZzS@i0YIV{*vOF{Vw!yMuS*1=Qzy|H4dkVhK=$~w;l zEfOlneA<+{E(Jf0xqx_pWV@8XP^-FcVhRm}xY^~I7v#-+C%h-K+Z$GQdKT9AoUQi* zwIs?!{4&VD4Lef+BlmaZlctTk4UInneccBoIp@!t*)pbppN?LmeI7*l>cUHH;%s+iHoh1{WZG*_8FIh zDKgGK=4+bzRVB?Y(0MD-`yTeu;i}X&X_a+J`TjJ&L>LdD=O_ER^%06{axZ7?rQUMM zb^fQtOmIB9{(Mr^nJtl2_b6A^$%SWnYgA&UNKwgJf=-rDOx)wTqK5QrDj)0QaM@e> zYqj(=h{`CI8PMG7Eija-zBtlsCZ z8(+#5-QK3yI7Lg6t7b{q4_Hp$R-#lKG&tgSPJlAG<_T9d-w~linN3`maP@~B+nsrh zHO`z=?FUad!#~D;d>z869NpJj_`tpLtfwgj2kRB9UaJr7NX*#2V@*v^IW5N?hm-Ac zOsoXYhmve=-J-Kke8KjggB{s+slOh}(6K+j_xkR@Nz!kEJ(*!Is;rivm^q9ehF{-H z_TrRrfi-1DI!Zf;;wd@F)^w1|t3x@)fpdplZli?nK{|c?EoHZHVoCGP_iP3*IfWe0 zq7lakOgr?-gJ~(PSJj`nGU?v(68OIitSuX|db9rZU5TCX>#qT^pZFW6R;pmtqUxpR zzftIy5ZFF*i+;cM{u^qC0AK%3<)?>)l5U&stLu|dQEQcNoPtXvJ7o;(qR}Cw{dGD{ z$}yO9jo16~B|$izrF+G_Oce;=v)Mx3P*;dAvTCVYfYTU-b0#!a_UVmR=`}lFsEyUWP3+Tj>2vslqwsvrE$fuG`4fC{S zaFwoPe0kZ>_bWrY@w#X9O&exGM&cbata#?SKlqn%R2ZHVGD)3RZ?|&gp4HxUUV3|i zVsfabca+3`G%q6f^f@gH|hGi3vd2FLzi_Sf%cHXFWue##PPx^bxX335Jjzy2|%C}=Ec1(`_AvLGyP+R%#ld9)E|F#RQ zx(qv30BfWn!-mCKl|%s4l|tKVKi_5Z2Km;$WHBs|*J%J9%Nv-O^`Nsm+u%1pJ6i`* zi73tj(~!xWfG`-gb}eWkrox7trBhX|wd|%1(hqetyFoByfwEF&c$ zsh(Xr4CNcVsI066U3L z)u;p1fX}(}$?dmoc~e~-=HUPe;-#D?)pK&c-rvilo z>b?g%MaRtfWQT5c3?GI<#0+dv+>5#W#F?9iq?Pa<%J^A1!U>aLJp zzG*Z^`CZn4pK?5JUFjiYCaZ)v`6U=knV4*IUD=|OhbPsx@lIbj%}G^4s#-D;%Z6Dhv!nIat1>b*p|Ay!;>$st@Su=mB}xN*wiUMkV;(S%SR4KBt|3HTOZ7N|Kp)2@3D z_h!3rPnjOQ3aAlXJzeeK;X_5o{9eYDQl@DukS*A|R&?&XDC@y@_BE*@U-R&-+ua@g z{Co=|GY*}cGxiT=cylX3DQM!?3;1ogA2~ldciq(G@_G9@bepQ_ETp-sqf5@I#AMJ1 zCPe*7?XJbm@%o+e2ISTsim}34^ZT9`8&30`_IF5hq}zSjhz!zh9&BgH$0SW#?% z!8Q7K43~EhF(TY3Gq1P2#Z%eB7k*jl#Yh@G1@tKxh>`C%LWWwJQ0 z#umg=fjP(O9X;8($Kowl*SFvA>^Xh*$FTW!3}Jg~>DLZHPGVwPV$1{d9}B4nxydv> z=1I+hzinnjvFcu#K8)RiF1Vref|V7=+OHZD>NN~H6URRA5P13m=v!Qb>{yzfEB@ux z_X{beW_h9_1sU6tr6=(HL&rqFIa|=xJ*snaR*J{gPh&UmSU&Z&b|4^?)dPHp{6Mk; z6dI)^GsjXjsE^BMSNMai_KOdVRGR4xCik0V;8eL?t|Ra&t_|3HX@24R$1qIm+eIOl zq7c!#qR=WdxzI-K)vIq5rLQWnN!2VGH{{euyKYXS+xA}EpQzF#_Vtvj=KD)%J!%a_ zwLP1-C|;)vs^+)%IMCk455-sZuHd)~uJ~xAN45vW>UJD7^jot#CI`JBOQVQwgEZWR zYRIrUsY4=%D+=BCF~?)Lw4ipXVD<`5MT8CBobIrYTRzMd-}n$abTN?OrGO zs&WdkyW8IxOUvkeHs!}W(e;v;Yal(_SDJRr-p*ltETuL z&vc)Wg>Z>W>C0FMd@0T1;qraE5LdO`Q*Gv2>|-UGEw}I9yd7jV5uM&?vNd#GS54JI zQkUc*io~>8-b`A|fq3i4O87bX9{23(?77p@i)USzla@#Fw-`Qw%tOY}((|7S@0g@@ zFhTn*#P((Byjk7 zoKo7AdhW3HFbH@+{N}N{xD_YBub-6f^kexIq>Yw48eH1re~NyaiytFLW96KhC{Fh| zlr8PDhX^AS#2kdVM#n4)unGdYgXtL_eEQaL7v?o$z%i@)_fS7}U_bpE`km3)ZiG|O zWu`dE6DW*k3rgy=NhZlXt~#{Cx54MJnZGe!UxKpZJl<`pTXgZk<5Ali!`I4YD!qV3 zAy(9N7;t#UB z2@d(Z<2_}EEHWv?F`&c z`$FtI)ky?2*d9nEg5pi(ur7s!_a-GO{(uQ9VTLc0g>Gr7U;jlurp=PtXdL%Q0sW;zU{slFtZ#)U5)45K4fSuakOelD!Ez-#TeR66Lgt&QPf*p6Gy5T$J0-N#_LXLW$kU>1;ek&w;d8aA;=(CGsn zsuJ==pa>PB?eQbYR+O~jiP6}ssRiz2$LACdHS*uzAwhgp_UMgyp?sp=vuqSV9!b4? zxdFTZ=T^s8@LK&4g3QK;jSPDvA+pSbMpKC zdZTXds(ors6yuHh`J{)aYR}&wC~=@$d4G;$0nszqbTEbc&W(-6uWhp2^2*V+q6zyi zvlq|>Sms<%|KPt3c_a~bLk4S6ZG#PJj?2$CZYu04K<1h9U4}D5QQ2-17+GmLryZ_r&#z$VL%^I-V=*kOeJostHOU=mMrKjgQ2F zupxgNwgc*?qL+=<)lt}b4(uBj8ke5SA$14EKJA2bND%GfjRjHaDbW_Q}noVPEz0qOvn77CUz*X6%>u~eB%|EAV%uU=B8y9 zff@)Ka-2a?hGpnD6lLNJXCV!EpeE;gA1uG^)JlXshwidweMug(1=(<64%rrma`fEx zoGexe5rBf=kT^VUcB@-oD1_gL(?G)T#MV+AdKRe^)8j)GYMWKX6qf3Pu7p;bE`jEz z!*}xy(KIBmYCDQKq_PH$)3_R>ncGlqdcQL+(#jf9RclxI-xhPrm(9y@9Q__8S~GF$ z)%GQ=|FwdUZ1eQt5&U8ws48F=I-)=&#V%(k8nAs|UsY=FT|opHMhph~Qck&yy~=@3 z9mbiE9#mQpQx*qwaT)gI+Xh(PJ9hi_RlK-z|L1LP`e@fWa^8HN>nC%CT18}be2voj zzGo=FZyplIupwoQX;sVOo!{^1wqrUL7utO-(-;)ToPVdlR`EM?Z0k!^fbJ9fn%Eg?=*8vTy5mR)?c2i5ArV{(Y2@{EJO)BAiWI9+;L3kf1g+x>dh_!Sw# zv3JQ7?k^Le&9)BIHbi=n;+~Xi2<`OAiL@f&Q_ngrKKoVy77qtXF>$2(>Ooy+0c$<# z?YP^PYELZ`SLmP+?_4!InG0I;16pt!YiKsvCk&RVmkX|>VR}<3@qo10pn-<69Y;Jv za(puZ6)!(!wvI>=f{i!JP0yq}-7}qIz{}c+zC^7GH!BU+AFCG@w$F4*iztmR+ux zpWfXN>SjM%u(*EC@5QM)B=`HVtj$|?O7im?UiXXqAha}m@X!Di-+D!9tN75TNhm4m7(UsF2k zy6ZwBx0p@0fTVeF16N0=|E|x+U2Owu&8TZfj(Hfqq}Y%3$_koo3C9mSKX`+ zu<&j1ck>TtCylnB7A!M@JhA{esCeCcr%Lb)wOM<{cr1D#+iB#tA<3L&_ zk}8*^@QWwvU9hR&&t6j*`ZV(I;w~QLy#_EZpl*Wi1z`wOh~*du-Nt7%0T}vs2P{_* zOn(WB@`kFaki+R0^6u@KDx+=vh1Wm$8tKVJ~{Ys@rKFB+ufLx;q z=SFl$dx`EZTpfQyF)GF@GmKD`Ho|Tw!CrRSu!hCOv{K*|_rYP$p}%CeV3+zuOoQ%y z&)v2m)t4O-PXylDt{8>ce0^7Fmjdnn1_x;WFWE=?1(snsq&dJ|aDP2s`vs(iaxi6K ztcu)+6}s~&ocT^K&UL?D2@Y`Df{orskkPy5aqZ7fVtbV7`v?IZ49tSsNnM@(<4568 zRDEBa*s<8!MH_1{mlHswjQ>lCupy82Rwb`55=toj(DJYyd>F1C~iSgmuyTl zYqZ)_5ZI1DqIP)elzA5o*8=BSBPB`Y{4+%5X60$hFeNg_)3Rh?9VW~yKnhpm`2ng$ za94P>uP{j0pfnS=19K(6vU*#QPByB!r&Bbb%irkb=KadmE4tz|?!rWMboCA)!Gs85 zMnEtCVAOjIh4V*CptC`=#e7n=yp3!*QFBm}6dT@cSi^q6#Ei`tM2Ko0QR%Q3S8mHG ziNiYRyn30C1&ZfX;lg$(LMUyV*LP}}dX_{6iR22`ADE5KFD7JOL12CVsxKdQ9`~VR zeS09)xiCXD;$XZdWHBl#CFW;vhz3lhMJ!Eho6FGgMqCg-~ zCV_x-d&jyx1d5`?TFv%Fc82O9hRE5x9?NR9c>Y@qvuK3*#ZfG8z^sGYZ zhc(3Wo{nRw{=F%=nk3|ico}usxiR%_jK$ZkTG(FCrjUiXz4ZIlS^^&)AHEnT> zNR@ZVk33Lam{089ugNa+1@con-ioIP4hp~#HN#^o@9yddCSI>p6J`ea{97MC-KTr6 zWMlXWA%Xo!YcGiY=^4B(e2r>eR(l$bH}h^HdyAMDljW#JP0FU&6+TaIZRJ&u@JW>E;MoRTNaCVg3E@I(J-> z*}brcn85lB3lz|yrN`s2p_YCI9%#40EOKM0!g2e>$lMvR?igC=2^CMxt4q(Hcof!| zZQD?MEYI2$$*pnChv4QSp_|hj2R)CL><6p6s(4Ticn0@2VmRd-xm*6Fdi{)j`#^vw zUi(Jvt(yDdoL5Li_F2=TKWJW=s<}w>D!0F@Kfx*M6WH7If{`1TFk6CI08%i!W36sAi2)Le92Zlg<5eS7`ipg6HEI#l@JH3ncsy&SHekT5>6HE*&a_MC^Nv~)NfGW4RF z;XIQUkv9ka$ncu>s$3B&v`Gi-vouk<|3v}{-T%8oKq(>@90I^Q$I+t0AS!PDC-mOwT+f_5ZLI7| zs+O^tqK)7&88-(3Xr2~u*w7=8v==BrK%QDLzst#h1NDA^fo0?hxe?Jza2qj|ODh!# zK0l}$zfI(qJV2HI4OnJ|{8C;tXh9I6d$}65T@9_{<%C!*WkS;da6IAw3cq%`mVRSy z6MF_vF;D!ZTBAmw=8ejRv6H_)`+&*yBF)gjw-Xm}&01vM= zHw$ng@Vq>wPjEjPCGXy^!vGPDKAMK^)zwue7ndo`6iP49BUV43|d{5Wtrb{3Us+tp=I9~g>iCV;KmvyoMx!?Zgu2rw3 zhRXHqCunX|#^{b>Th~50UI93SP`@L{mL7Fq)HgIVv=+QF0)m1Xr>7IZQiGJv_qn)u zvJz=?g=&u0ELh2bF;ErtceiOsEv^iBV=u7oFuj7>4^>rQZ0dMWtSLEii z<-pMmcLY^a}LE6H5{xk|zyo5Q1RiS|2w=cR1!p!q04ksnYTQZ}Jq z((tzewzwt{gZC8?{0T*o#_Cs4um?s7vXEUXF|hLA&O&vqwr$OfiotGUkdr2>*EQt) zKq&t7%gqY9oef{(Q3?-R?8O9OPkiVM0sriSn-xpuaH#qSaswWszn#TMyI4)fc_>Q)xBoT&MYpy`}mpJ>{o<9(rge7$_bBJ;dGDkyF zm!ps?ydv2lDe3^(u!X4DNC9l6#pC#zuU``0e=X~OU|INYsdEQl>lxBWybQ1jh=9v! zWpz@lt`2Ldqt@ZU-u~WNV$=cR^D&%7wFZTx^c$DRcS`a`{&F~pA2lm# z+>e>^HgdB47sp)Jf1=q!d7r`O6H3u_#NJpHqigOP?HxYA%68dPhhK_bAV=h51#}hK ze!V-_4PWrCDjIwcC!K%Prh(v7UH=^D6R2OIk-mgpLrkI0kGz#|sUtNa>oaDBrzowX z)|Tx#a39VI_lSW)%=4Xp&Sc5|Lo_Od3Pq`A8QErxc4#Niy%6FYZvRH?m)3KH8VRj7 z`ogZ~+u~#H%WbGfq*2D}%((;H3poB>uppkA?~@d5n3HUcLpvVH;swIP=Mg?{D99nJ zgmjVwe@_o8YL5O{*~le5C;eIQsO9~@!)ILx701eB>>0rPNnh=rSew4SbBC!P@=kg4*7UD~hmg zI_6J|cI`1A${WkP*bDQ`UQ$`}Zj7NP$YH+4fN>k4LKpT%3GJ^Ao!7$Z-oN}*)Xh(z zXX&5&jvTK)n^%NH6HZI_bHA{ct%o`6Ua$b6YdS)a9|v_EfOsZQh8DH>hy8DaFd{GL zGk=v#LyRe>IGMD8tp`o5ZamMf(T|1Cjl}`u`<}(;9#-lH{QJ_k;#pvZN@F%(`K)9@ zv34hqo7MIhN+N*adUpyVnlP<2t=4#kW*5s@g$9r*Jw)vjdY3rW1-QZ5)>kbB#@Mc)9Fq&GN_>?h24VNen~(#mG_E8 zsZ_lA?SkzL#Y8)m(=7Gn!M`Va_Z0>_@P|3#2cyNyx}{)Ry-!V7MTaJ5Ju!61`-m)3 z!|F{oxyjQ?p4fq{r|=4&_$$k`o;1lAziF=jJ_#W_Fp&?3*M}k-gxU~IXmTs12d)Pc z@r2Hc|NAEC{)oJSm*r*G^D44QjK22;@(>4I5lq$~$A&5#cFqW&BsA#=_Cim-s;*F! zJK*(#ygBV{tGfnV0x;=* zKt<&RERKjWPjEd3Q^N)@7FASK1h3Llz{wQy=G)#}lzSYAm*E!aDotpifLIVF%Bh5; zsY_p;$o`90ZHobqZk4zknjsC0!0u4`*)xRgQb!V)H))mGpc6;UW$G;~FZYA%x|Ws} zlH&^&fZ)6i?n1{aWS;cAj}R;Jp}}I&B(=#1IlKS%ZMpB5od^$r@9={Hp$47c0tHTb zabSNfU-DKP$-PJNh0Eq~M9Q^8WYo`!2 z1RM}@3ig$zgBgU#VFFSNSdDQQG-2EEX{Rr@wzoeJ6r5Z+dU?kP@%2#U*(xUjPgm&k z$B(&Htiad{NwfTK^!ebX@nEY6_N0orI)YO!>&$@fkgn#^X5= zw6wJWtuXJhB_$-nbhod!Iu8jm=DvA4*{fyiZPGLm;6 zoU_2wYcnZU^`HN1JrGPpz+Z}l{ml(9=(`PcdfW7n)0q!=FQue7CX$nrDI~`~J?cQD zy;Mnkg?Q9Md>7mf_DzmW#%TX7gjA(j6usLLVPNP+B_u>C5tdg{VhnEV09OmZ!IGVF zePQ9FWGXcaOP7hOmexIR{-b@`1k+DfS63+`20dbZAYIhLf(guV;ok=a?vRm@rHnt2 z=ehovFjIlvE7o677kQZ#k^lzV+S;nTI2wQQcH$d2mlZm0>xhVm zoB(=_=COVq-Tau1EeXg=XxFck7O(yzvBAex4hVsN!%@xE5d)9YVss3Q`K_(sfPmHy z&Ng_G`qv-J6foJJPj~spA?(tJ(2Mp0q9>Sochrd;}Z0 z>c4CM{60sFS9+4bersY^?f)0_@=&_%yw|@%B4IEg*0ib;d-v`!@)!dcSAHhr-l`Mo zJyV7u5G!C|Lkz6iqlvfGSwYSeILV593h#HBm@4*y^>oN_#IAC}?lU*`32Y!7uSM*1TghmL({2wkiwQG&>yF$Zb zJ{C27N0fNInB@B}q(x)mAcRgAzGApk?#?%Jc$}Z$Z;uxm%jR2x59u7Z1>^n3!4s25 zJXD{5W2&*S1&XJ^@#MX~{;W~{TNmNkA{Aa>!Px?fT-)LdygwTfg+DMPiiMNB(s|WE8tgjEizkTE=UeL9`YPuTJq<3DT+z>la-0(b7`(@=b z=sf}QfWDOW?au#&`1BqG#J$2lBqgLlOTF0z@+1@AQ7V}c6SvDyy^fT}4br$=7i9kX z;l1=7^;+7&lQRet>6#^P(SNV7DTRIzi9u~KKyYTjwi`x6(sUKDrTcFf+A{0u-(%l> z>;)Ehy=At>2p~qP7nu>lU$a}%1fOCSdTK~WTwf9=sfBDJnLYV2P*A5*+{KEJ4r$Qf zNE1P={hvLtQO7r=x9DST?w|S%zDPq<@GMn&(@#O^yQRzc?7fCF2Wsg9k*!Nh*q(^K zhhaVc&)jb!r#|ZQ%47w1mFscNijhDyCcW1WiAWPA%sZdvQ1nXHCK-O1{uhg05mw)c zeTU&gQhVpKF~Tp+7ymoaCyn&94?kNX{{u%Jgx^sHz+pCCFa}H}Z;_Blcz9IzUbMEh z5)KqJzYIPn{UI4?GQf>mt1VxFxbY6Jw&_or$oK#4gi1*ywX6L zeytbE!>j43I6pDv5-UWF@dz;r4vN?tc?D}!;m{|YG+)AX-GuV%(GRZjl4@Vzn^!D_8~@o%2NJmY(zmk zJ()r6h~*^!tM41QxJ~@ql@AO}%jW}bKNA(?Rr1EHnK~sz>|_X{Qh|4&a_o;nFQKyy zj@@ol4BU??DG$KD6Gc^372=+xS~RXaJgF1f&Yn=zF4ZI9x{{oruZ#>_RhWCV#}EpTloP)p(-Fogd(#mc0*;fE&tl(yC*P z)Yc-VJ+G>i{BkMKg!i9%x!CL;v07Z=-@e_uA9g@mJ5}XsaJ)6C5SciRX^Sx2?aZf4 z#!Za!3B1xmL)aB;(-;~eA0bG>1t$@TJx@;;2*4u!9=LGA;_W|i1J71`<+KMBx>rBL zRv%!9w+vRGAxz|R(br;zq6*mm&kqa^%s&>n`RGatWZ){MmBeT$j~+c5sc>Y4?UN=$(oY~B%|@0E74dClFrt7IIn?)I zPJ`2MzY}XzyL^p@yW8dIa(C=~Ek*$fluCL61i}d03x?%(75sgPH>NmF%w2T&E905J<#ge|xcyqPnqskEtY%-zuQ4aGzJGN?togbT60 z>O&w6VPdSYU(IaqLs~@(z1AVT2bgowm!)#g((e z#wg=Kl*kZ&5CSe~AqKT=a8%7T1+(82os&v0i9H{)MIhVJ=Zp%t{Ah>pin#v*@oySPkRtdUw?4PDybQ_%F`?TeBmuAD_##||(9zLf zzI^EgM#-hpEKv>J-FRTb%4A#dB7~x@xA!(KF0K>!tG|N^^YR&W#O&^kL_Gy!pB4mS zfb%ti24({>{VLPU>HcyGU_hk(sE+%gGNKkRNN99)0Mg&tRp@wo+D{>!iOt*C+&mDR z?8D1t3XMAR95yw6oAzHPAt6!wZE{=3d65uIXKzAyW3Z2bGkpO21Gk|E(5`LZBUp{0 zmli;aHu?bJMffi?nZEz=zzbj>vVeOX`>xmT-%nt)?(^{6g@T-Lrb;%$;RslWLyYsS zsNgTRTlJ3PcT5IPaJ3>c<>+y!-qnK(vKucCkCM)NSeN~uI}r%)KD4y!d`~2NkSmS3 zHGy#`PhS))(zkt%^eruM1O)|uq@|fL$;!z1#Kmcs9qH@oA@K0<&Y=u%n8RT?_9W#^n74r4< z-8ndT4mI0S=~RXLEG)S2=5a=V|O9QCBD@4)>9@5-0WI_x& zEqYFi;TDySRHXR%`R#C{Q$JoAM#jc6g7wv@9er=A?hSNW2Dl*ZnUDRD}onX-ApK7a4D+c;`a5=HUgxF?s^_FFem|p~40xsgmM};RJSJEs* zptqMUqpJuR#`VXb@$B!JZ={YwtLO`~J+LlUphxJtnyL8dZ2^7z(uy;>_EgOktvA2L z@gNVs2)&t%2h#uSTJXDvN9|FxK6(1IqPrp>Rc6NqXnzhIn!} zQ2Yzeaj)n$>I(`3N=xDj-UO}Jx1NrUPDNF<2gOZHqQd(mK^nyucUb(SZ5GxOZQ}IF zGK2{gkABv>Sy`!Tbkj0&XcVJ(7i%p{Zp&3?IGzAB?O3lY;s)s1QX^C;;^mPV)=jP`p2KpF%E(%OOup!`OMnQ843ky?}>u`sv z;^B7F`j@Lx1!?Xc-76HdorX$mT3DAH0D70jcXWg3AQQ9V>>`wNi>X_;Zgr!uj=86} zE#oN7uC7}e%Mm8C1DX~G|rV>(PS2EqP1O^ zemFlLH+N8Ot{k=p5mr~V^cQOCq|>z#{7`3s^xDNlyMA?kZcBYQkh`X?gKuK|#SI zjA_QZ+3C+Y=J)^JnIU$`A@srpoKvr`t&L$}J-u7Cd^ZQZ_N<4I5nHx(Nf;2Gc9Kyy z5QfZTNXJ-xeSKxFZB;Jtv$>^Z%<0FVvP8thbiFf#?DxbigsH<*5+5Qbe)_sYrnU&n7(=YHp+aiJ>U6}TOR`$j2Z{Y*L^czKYbQcRb9R_ zUU6;S4FH)77cShqQ#)X`N`pC^<`S8>y7wqTuEj9=&y6Gax_@nGZf3)@2|0hB^XSK9 zlMFUq-e7>rX!Ej@WBuYs))V#=fY0{rJqK4;*EL-+>9mMfA`1g4etA(jo}BsKy-tUH z!KX1YG5KRxo8pp^yknnC4jeovrl`n8v8`OmdiwNf0tGP@!8N^n(^+57Ol@Bm!@scB z+a)Zhm5PpRE-o$)3d;QH7Wbm`WL&l^SSQI3s1g%#wE=6P=VzV|--mdUncps8gDa~j zjXUEr#kJuAmQL8pJA9?;_m;z`C31e9ec*S09_GATm9R+PaPpgy-$IP|dg%bYk$PYB z&5H{WR2k8~ZehZ$pi2z5Us!$@Gj!@kH#g7E6@Bgv*?+<Bqr zXs=2xJov}4lE-3C(K7y-z{}_om+pG#OQW%H&R`tO&7-bL`(*A8p2Su)a&*z6bul!j z+!sEt#z%RY_w@0|ZJ8qnT#H>t4rrV&rrN6yBvnnVBvC(;mDO*;phym#N3ws&$Ft-l`zlG)8y^tRC-mI*9t!J5qf=|4Hufoqe-qW z$1w|s)XyKHLQB$ml;2aG;89~&{Ajw|$@HH2Ka_$ehC7&KWo6AREGilsS-}oZ^~fu6hg?Qd z?W*%N=Xvd0(v7iENzk=xE0R-EY@M9Ku>DeW(QoIRl5-korV}@v=W%>_i>Wq3C_?#g zdpAg@tls*D+k!M77UtG9C%tmElpuDqF5NtZudR5GRpU}TLy;2&-$HvAY6B1-gn;ILv zM|%^_&d=RURwYqNBL|5^SYDnzLQq`?1U*7!H*iav_35mZaWOG5QVuzc=gnipsHDyG zTF%bSgP%T$C@GoQpM;w9<;$1)4g;B(6htA4I&tE}DsgcU1%-Q!lnM$L6xN}nH020{DNz>{6QkjJ=uuK!+rCu-Bav;?0@faU-Y^=r${F?2G< zV|7UCUS#qES5^%!k4;c;X^;*LF8<5Ahlbbr+&{NUKlwPvkbyF+c<%Si)XcXsTRXe6 zNWBDi@~OBcn(j}c|8t@MebBl?qM3IZ8}$*$@ck(!?G<#%hP^*7E}l|I7wB|(ul0Fm zHD>rl$(hL$p5>j%OQjZIq9hGgjFid3@V3`sF)*L^5_80Ac&Q=5B2a5V(v9hIvV0aI zrYRHFINyBmod5NI|G2QcqWo!$f-@2pH7=kN338hGg=)rtAH^AI!%T<%g)L6wWLJxR zu~5^001y^A(5)=Rr$E?^NFI5M!Sv+So2i9>dDxC>FYS(m8lx>4mgNioYOg((l;XunZZ313KJSLGg-N(>+JsjtRjM1$ z*zA%?hMui$Qi}uIw&&m$Pn|kt5|i`_7ik*FN2wwN7_MfdMRf}6SE@|V)`Mh`KYt#6 z;^i7T#9m1*%W~XDa@ww7O~^<<2;X@HUpv9`xO(;CLx+I(P^E0OiabJYgRJC_Nv4uD zBa<(8$Zp%WG|~W45e)_6r=qs@+}CasxqXA_I#_Y(D_h!T2`)>0tW-F(!#Le(eI$P9#u2I4LAzBtS8^AI$MBJS6iG}{&(owKdsqw8Sj0|Z&`^g%BWa9ViH8b%%-p4@x) zDgj|e24AG49k~({6D)6E+npL?eYUVr6?iQeK;WWh?Zp*~QS0+9UtaP4eMz;(sUL$h zC>mTki!HA0Pup#~a4W`#n^U$d1FCf_Mk{*$sSBiX6&GI=BeSm(+oFvQz7PWvw}EWB zSKZ){&#&=4gn0ZlJ$~nh{WcL?4i40T(g)IQq(QdK21Taf`@bs#nWjeTga8@1HEb!R=ufVlp84{s<3fr&2(IunKsEh3lTzA;Wt?oqEnc&1k z;R^z4JSe^OjvrSBXEXBtMAspgho>ix#Ky&mNJ;H*$wNiPiIWEbMu8bTk3rTLpE>-l zV9C+qQHBe!nK{5ZSp@}e5ddE7wQDJ7W+qg;M!I)Rid}ii1_bU0?lU&Fyrw2quY?*w z3W&1kg31~I)}soM8vZ&*Y#@gcdb*`gW2cD;2e3*ocE#Zb;vQ_t;M{%J+#lc9L*~$t zE0K}@kZ5dxhP{4`xAlDb6pE+NQyt2w;yHBnQRnxM4f4Rb0K!*K&Z6=qmb*vTqs0Sb zcj3=uYTCg@|AC8Z+jl@rsjeOnHAG&R8trrY`Q3q3^`F~v>OhmyP?Ay1xSOD)BoSC@ z*#!hHuHAL}>ghR734Ja26tQuE?caZ&(+*&;MUln-nTXdYPq*CI;92c}TA~hD1r#5G zB(xloODkh)>X~Io%oiPyX)3QBEe(b1@w(cHrRt}CXxZ2ZA9-Ii-V%QhN;kN_9+Zw` zR|?33m#$vDmG}JBP(wpQKw8@BiQ)I>&-2RKVfyeXyYLw#Xh@YpA&93y9&ciOHTU|FoBf;i_FjwucYAo) z>+}`cIaY-nEW9#-I!RinbVzCBb^3Hn(J=I*Jh#5hR>fZIR5ykl*fy%r$BbZ#!IK|( zBzt&D_g}W(|GQe~_PLL9uQ@n5=^z-5o`FjOj%a^iT>TpxxJ5)nD2kVt7dp^lvR<*r zhU-YTHZf4lOn%z}0+5C)IL6P<&t2dlzJUr`ZAg`aL>sM_$`Qs|Tf(oX1%Zt?COE`A zI9D6!p$Md*kQ@n*LV|;v@0{mY|IR#(eCWK%m^i&Nsvi9|n50uZ!9sGl76oS?l<4xP zTHLO^d1HRpJdYN8N;nh_-GBIyJhK{4O(;)ay2^yQd}gwy9)I(ln-KzWYcvr`SN$!; zXpN1%eadKDX{iRrWHGGXLj`!rJ+e6&cJm|7&X)bkO_?7t%-o&fPDWHm=@~VE3c`Ye zu3xvWi0k+7-@r98qe5$_m8$OaM30loDW zwG~$3CR5X^Q=LF{lcbUdgjqkCDhz|j0&(Z>BICtn3TbGk1>vOZXgoFgj!N{TuP8qfI<^KVTizvMnfc@6(zgw6?6vzf_k86v>cLU4 z1!AC-aSvuDW3RXfdDrx`T5DO1=4i~9H{4{k;qs%+v6-tFvg_3is@-M&LnqTM(+mw% zmaDq2g1IF$V^^9T)+1o4wA*@HgmK(q)6L^j9KdZGy{C-q-{KjwuUK&&;FK66ko2_Y ze*ZEyG8(Du`g~|y+UgNoL$q`icD8&`|B7B2AmzC|(& zl;ex!MMt_yr3+kW9Q8%LqA_7G$8@lJIre+B`~Kl6GKQ+m{?c^?3H!lzFUF5|^S&U+ z!EvktA2(g1%9l*b7ls}~Eo(2*gkfe>R8^&P^zXXw{o)%FYeN+7H}weH9lSof(`GqDN|%>2qO>7y|fzMD30Mgr>%a*~!M$U~WB z*|IG;)+NmdR6o9T386q3XiBKVY#RNPmQV=|2>u-f(5bnuJvv}0>~L$T!&2aru_RFh zO!Y-*<~qHW{L)#n0+e|RUr?syLz(G`ujfIt5$l#?OZp2I-(?WIut2o!<3O+;AMOZ6 z>HsEz`h4H&FU8J0`^B)4HD~Nv4M?y&%`l-Bdj{3-JCB1)XNJ{?Mu03-1C28>K0cG3 zl6pyd1kVRB!XHj9kqsO8gMxyH`3dlA)wybLyo9@d^T=dieIp!xO&T+- z2-j8F`yStH-y5rZm|sCbp}OtsyN6O)67vH&HaWRPU!Qg9N)!1Z zTqjN*USmEU=v$ zc}5p5u$r5jZ#?=lHU3fG#?l>AD|VeytUH~f^kS$qG`;^xLqiov;+`*G?nXKHIXymF z`jY`bp&Ze?Ch|6~651HZ{KTa-l|ZRIhiG2nYIA2Ts@RCXpT+k)?DCNstL-tId7IwY z)%AMQMLuf{MNZPG;k>d0GODm%$gO5k|Hd<{=a9N{*F^Ybat@}1-RpSh#NX`6oZkf>1IQ`wd z;{yr4PG9*9hudDhRZ&5M?lG@^ z<%dV!>j1@oJz+-CrXR1!1ih_yfM`pg;jIBpo_uuo26_>nkaz3VBZ=Ji*LaL_+0rPw zS3%(>n#|#C6o>K7)_+we4}N*&{HaagwhYeCSHfa3`@4J1x^*o>d9U5w-E;qb#*bal zl1(P&L@~Wd(->M$L`6hs0R5`53`RjqL;goKiYr5E^aIqg0Xj1;^HHy?s;cq)J~l=i z5RI|&bii2Bo|<4+M-XBFP+H4))dg^2l3p8W=aI{S&6_u4jlS|i_Mo9AzPt?}3m-(L zjEoG3h-N<9B5a1a;lwa25YVH-LJ?)<(MubzH^hl zu*XypjLxMgorzE`r>Lk0p?>7qUtJ+Hrw4AFXa=+aCZVI4nVAs>{JXom70L#%tm#4N zz}=&5b^OBm+<_--SXm+OR8>`NVJzUE0Y&0F zygH8L0g&bd!t5<@2_-cE6kPLSFSodND^*;n7}@=`1ZlxkiTa12U6={S#<4iV^t!t#n2PjGo)nh%HqPC=+?(#zlZ8TI#l8LO7zJ_h=2T;V%tP+NnOurzc#< zzg66N=>az`($F!G7QQ$499T#l0{Y0NCXD-|np?yodiwH?j^f5Wz)m%uhBr5cA|zKe zHyc^?BNM>54D6NL@+$$XA{IW6(8J75iS_F_Q4?nxWr4^d6w}GW2e|JTN1VO04eK*` zU^ljVbYAl;PqUPTFFxw@IQD=;k#qhV9W5<<2}`lUZH8Jhz?joeDEMriUssLYcv1nu zGzf^g7bW?dGG9%8<#oTak;Qbiw6x;t>LE~oBS=h33zpee6w(mPy8JAzar76Me0ex+ z2-!suh7IkM%-Drl2V-Y@^g?NPf8A{QU99q5_~KGwK$as+LSnW3Td<3iRBJ%7zNG0{-Lb-}9~PooV27aR*MjIw#_824Q{(?`6tz4A!W=Y1#Gndq=$f3;XA&p)Aw| ze7t=5GUj+9dnP>XJJf4&KS{VrB=we@{X+1EX~X|q-2_rlvDFYB3-BRI(VnnvzBH8sH%5}QkEtC2C%8wpi6Z(0}rEe&88^VZ;uyCUpDhl=;e?#eLjJzZ$Ev_~^* zV$Q&tiII_*cIxley?9~l^1A))Tf)Wmd5oV{Q+sUk@F8$G`}TzVofd){y(lX6x36wg zFpbgTwPtg_ZHWDVTjGjY`jR{6kuHO=xJ{}=s0WTOXdfQ!+6=QIzoH{WflM&X_>410 z6(-!9I5RutmiLR5B3FpE)hVBz${?s%%;AFS%m9w4`wdbq&hAGPdGp}Vu0m(8crZ(1 z;^LdU6JRU?o8WwRS0qnvPL3>DVL<6_CqV-NLhg2&*;DciVGNQW17g413iuE&j+qMb z%325d`}@~=hMNAw6u1P2k`zPvyO2W(@IgLL!OTK=uj2cAql|TN2mvgp?!A6~j9~|u zLC>!!Ru@=^w)(u)JL3bfaLcQz=y1q9+vG7|uyUgp+q|%g@}TeWw^H4@aDhBvxC=s> zg8Pb&y>NW5@f6?;bZ-@=djqT4$8pqX_TaM^Sy=<&HY9lQ|Y!|l}}0}ftmG-Mcj z^d-1dbaizvUh+Z&ES?^^V<#1x$q*cK%&3P`_ZMhIRCyP%!xV4)3XnFQJgxQjma} zN;G%u*g>ua>RB*Q8;Q0US!b)oap0^F|3*u0-V_A)huE_i-hWns4O}R6r#@UURGj#_ zfK!7b$i)O_4 zo}Y(C%M^3AlrqI#Nb%D%Ekb`V`W^(}Y2YqZ4$o}GCCUSSlGyLAI_I{G0{ymrJ{vr8 zrJ0^?KfGPz>%o+JKE8WayU=_Vc4Nr3svX&)XR7<-%Ue$3PlO*0u|?QZb%ThS+6n;p zSEr^@4)!f?(=KLGWH)|h+CJv$d>;^Q2R0w52}}xTp}E8flnAW}5}Gi0;)i}B^P(31 z_6K@Z|JRRH`hWj6Y~hIQj@9uV<5zidPnqwnT*OE}eFf8_^B9{I$&3fYhrdemYP%t9W%OQ%)0(L`eQ>KKm z96^ap!GNdK+LiCji_{eW&W0+73jjD_DXgqR3PzWF!h) z-kybt{JTY|sXd-(qHh9yC5F<>>WHz7cfr(F_b*c-qjE}e(v#D$rojEd+qIaUz8i+H z*bSbND94IcGpJbLvBmPUG=4cbM?=rR1I+_6FSnq8fhzvlf9=VW zC*`>8>PZF$20}Md#2u~FWz1R1N|!h#vSa*5>&jxuclfFmfcJckAzYzx6?S>RFu zcbrn|>dXwMM&FN!-?eB?XCs;-OiHzgAH>%O{x-jQngm;T3272rno`MfDq%$Zlcan% z`%qi90Lt?n^t$;x(tEo>GHl(p4ZwH_MWq?scxgTW6;cPZIbh*p*i+qpg2aJr83x1& zs893}q9_2B1tJ*t;4KK;ZHHmfg>D=OP7Di?h62J%`I@t1RVcrFwJC~{z(9JpH-gjx zjpq9O!?rBa)h2`x$)8hSj#XoNEhsqH9*PWNQ8~sH&o2;o+ZH+zMFD{nV|fBUR=^bO zhD@WMps~+>9xS2ic~NlK@`2Oq?cPiHA=Cyoch_Ub=#6HV%_13zU}6F_ z^MUm2hdh2HAwE6~DJ~g2qxgh$Ov4LwQGi@$&~m%9VxyPb^90QZs67`wr6KZw8WBaV zg)<;l!htdO{%BdS@W|f;A8p6QhINCi_DmB;@86Yiq&m{bfrJ9z1yNvSjYbvA^Nvg)B@b zA;JB)GcvRFm5=fQt737-*S9UinF1k5;3`q$K4xTOkTn zz-hovvM>osrDI^VnM!fW+5Et7$&w{R2ZEU5hSmW{_*NMY8H7beC6@*1_+vG~KgAO( zXG=@o2jxBfdXGuWrXQ6~)r`Uv#CbRb>8-7;2~q&zi&9AXIWwH(|JKnBR~CBzyf-xf z&W?Bl`uqDGo^KQrO>TDp#~QowHHLqQ zCi`aiO0;N--O8wa;D9{WY##SrUx2BxnI#f}Xj!x}v#^n4>dQ z$G>K`h+GPx@Y}3T=Iq zp$LRn_=~ZHg#|#&)gsnyBgU?>TB#;P-9vHpoaMcj5)vhz0CA{fomsWhtfw%0!^SqV z;DMr{5vx5n>NmDQ&&K&+NDs8KR?9ip$JosRBn&1uvUTF1t!+h_Q6^)P=j;FaplNGi zc9xj0Q~vsi`HQui@6*uBXN7YI{YaHqDK-ZVtnyr|B=~=R5Z@?cKnv_5n(46M!1>P& z|N9)on?c57O=AO3<`^@DxKya0M1&lFBZHg&YmDOjv$UUP<#6GH-b8efyK82~FV3&w|}nba}&u z%n8c1b*gydz&JDt$c<;gKHt52msGLfoYR@hHjUR&YmyN2#gK_XuvEUie~<*mshJru z;wXNZnU9=f2O0<^6BiUwV{`KmJWDY(HA}05|Ez6`-z}O}gV~nQ8HnzV>jny2!-JOz zB?HjN$;p+&bb0}xrVoDIEf8{QJjt8h@M;Y)@aUtXqhNjl%E~mqfB(MfZt*E%yCs}Z z4ex&8wQHHi2Ab-?Oy)wibQ8GwZQu&H4S+j4K!H>r>ziZzFiHZvgJI%9cqQBW%YmiO zoMu)E(7HT2F~Nh(AXd6ZGD-(zR}iWaraPG@@v0?ys`5q37P2!M(p3e<4~l8q96wn+v(1R`71kG=QE ziSPw)@HEY?nx_D3@A3k1r5HW__1zvk3l4h!0ZU7Y5kQBdD`ueBy42f(627kLvDs<> zY$FdmsII_Cl6Ojq`WO@Sz;i)YB1z3H&PpaYAwkIC=Ek+Q4oZfbH#1O)z<3nNmsr|P z*Ju7nz#Au<4&l)ag7&UeUF4OW{|USBlb7m|P>H!pDn^2>G|PN`ZQx=Ksx2&d@E|ZE zVp-F;hu6obJ;dpV&H3r()kxy#1cao@$M#v&dsu@3VnTcYqz(N%_4B8;OpKU@Mrpls zb~emgi3LTza~hbYaK0-=zk>n|wuZyZY%JYWk}#dl?+mblYBQ>~2;Ncr zlf%@a1Sx#U!KRxvaqACm_pbHGQU@Q(i&?-1uWf|pZ$4<~6>xRhI68)0y2Li>9X)Y^ zxyiF;Cj@9<1INx{RqW4}XeudxV%Um)(IOhE>4f;&rjytz2#D*&lofg;A;t6+=Ah{! z^980v03cB>u`U!{<^Kd^L+CUj{{zW+sw%~APd$r4!=NsvNbN>|4a@@g>d-rPPLg*V z5V{a3OB*PgI0EoOFE)g)4?l17`%Bw|!o;d;XqZ#&@&x%0g!(p9vu6s}ORT!Ox*J6) z32cu#oMuZg$!KrO>?Zfwe1S2ObGG=6ITl-tsMNW>e`c@ysG5VfD~L<% z;8~D0H%@YwQ)}2!jiaL-UF-%{x92Ucq=sdZx`T~i%CQ#moZDnpZaBt>nv*}_^r=&o zSR#ELK-qmLo^c8YLw}KO2h=rO`;l+S{GMy{$PxfozRpc8h0G%KqfiuL7LgRLe+I0m zuB#if(j>z6laKLbqT_gGCK=R-_nzqQ*Z1J-Z>6P)iHK-h4g6b)%=jLEM>^wRQp?{E z!#LsH4Fq4x0y36obA+HY(jybl7YSOEou>_vfw+)?MWJ$st* zLJU++3}2_Tz$fPs>wOQV7B=&s~wPx+8(Rwaofh(Sqea$2_NSK)ECbzIMvYD$OU_+EOs9- zhXIKP=qnI;kA**j5q^I<6Zd^_(O>7S02gxiM92LZvI?&*f54s33 z09FBSGt!0t4AL?TYO(&TH#QC=L4)Wq;PJabl%f6e?$pntWxvj067`VQ0c@&3>{!Gt zGO&Y_U41dBl=q$w4fThAE-WnUQbNKq^ccyOPSK*{AtWrU64f<38yiKLmz+{Uh|Gi1 z1W7BSqYhpHq8@?vpreR!39_af;xEy*K~GmeEonxOpvqC@H>Vqmxy_(ZB?e~Lk#)dH z!?XfB`!ea9i$s&dFb$HOvT$SgMjcq!2%(I?M=%tzEJG;XaQw@fY-cn*9zA*#3a8m1 z9|qiRijFuSg@KW^Mh9&m0fQK*Zut>Wuuu*D! zYm*Q?B|?bDe9+p&g+i=FME=22#68BlAC`(~z{Z%h_h7@|_`IBt&#^e>=>gpI$Lu`x zuaeH%KW$O^xy!qY|Q%fpsAAjNH}L)a+_b|u_=FJeM1rZ=0A zPz1v1oN_MQ%wjRjN2k9l+Zz7Wog%x?mi0QIacuy$L<{Ki_k}GF-<|t z9t#SGhN?p*4ft^)#}FP9+Y*E%SXx++a0>lc8Q8FUXebQXE!m+Vdhf4cJ-v_OoWD=g wup+pyzjiEUAe|ilU(w`$2{!-m=l(gG2gj6dM9I`MQ21}#CIh`}9h log{i}.txt 2>&1 &') + os.chdir('..') + +os.chdir('../aggregator') +os.system('python ../../aggregator/main.py ../config/aggregator.json > log.txt 2>&1 &') diff --git a/lib/python/flame/examples/medmnist_feddyn/trainer/__init__.py b/lib/python/flame/examples/medmnist_feddyn/trainer/__init__.py new file mode 100644 index 000000000..506f034ea --- /dev/null +++ b/lib/python/flame/examples/medmnist_feddyn/trainer/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + + diff --git a/lib/python/flame/examples/medmnist_feddyn/trainer/main.py b/lib/python/flame/examples/medmnist_feddyn/trainer/main.py new file mode 100644 index 000000000..b4e5fe532 --- /dev/null +++ b/lib/python/flame/examples/medmnist_feddyn/trainer/main.py @@ -0,0 +1,222 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +"""MedMNIST FedProx trainer for PyTorch Using Proximal Term.""" + + + + +import logging + +from copy import deepcopy +from flame.common.util import get_dataset_filename +from flame.config import Config +from flame.mode.horizontal.feddyn.trainer import Trainer +import torch +import torchvision +import numpy as np +from PIL import Image +from sklearn.metrics import accuracy_score + +logger = logging.getLogger(__name__) + +class CNN(torch.nn.Module): + """CNN Class""" + + def __init__(self, num_classes): + """Initialize.""" + super(CNN, self).__init__() + self.num_classes = num_classes + self.features = torch.nn.Sequential( + torch.nn.Conv2d(3, 6, kernel_size=3, padding=1), + torch.nn.BatchNorm2d(6), torch.nn.ReLU(), + torch.nn.MaxPool2d(kernel_size=2, stride=2), + torch.nn.Conv2d(6, 16, kernel_size=3, padding=1), + torch.nn.BatchNorm2d(16), torch.nn.ReLU(), + torch.nn.MaxPool2d(kernel_size=2, stride=2)) + self.fc = torch.nn.Linear(16 * 7 * 7, num_classes) + + def forward(self, x): + x = self.features(x) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + + +class PathMNISTDataset(torch.utils.data.Dataset): + + def __init__(self, split, filename, transform=None, as_rgb=False): + npz_file = np.load(filename) + self.split = split + self.transform = transform + self.as_rgb = as_rgb + + if self.split == 'train': + self.imgs = npz_file['train_images'] + self.labels = npz_file['train_labels'] + elif self.split == 'val': + self.imgs = npz_file['val_images'] + self.labels = npz_file['val_labels'] + elif self.split == 'test': + self.imgs = npz_file['test_images'] + self.labels = npz_file['test_labels'] + else: + raise ValueError + + def __len__(self): + return self.imgs.shape[0] + + def __getitem__(self, index): + img, target = self.imgs[index], self.labels[index].astype(int) + img = Image.fromarray(img) + + if self.as_rgb: + img = img.convert('RGB') + + if self.transform is not None: + img = self.transform(img) + + return img, target + + +class PyTorchMedMNistTrainer(Trainer): + """PyTorch MedMNist Trainer""" + + def __init__(self, config: Config) -> None: + self.config = config + self.dataset_size = 0 + + self.model = None + self.device = torch.device("cpu") + + self.train_loader = None + self.val_loader = None + + self.epochs = self.config.hyperparameters.epochs + self.batch_size = self.config.hyperparameters.batch_size + self._round = 1 + self._rounds = self.config.hyperparameters.rounds + + def initialize(self) -> None: + """Initialize role.""" + + self.model = CNN(num_classes=9).to(self.device) + # ensure that weight_decay = 0 for FedDyn + self.optimizer = torch.optim.Adam(self.model.parameters(), lr=1e-3, weight_decay=0.0) + self.criterion = torch.nn.CrossEntropyLoss() + + def load_data(self) -> None: + """MedMNIST Pathology Dataset + The dataset is kindly released by Jakob Nikolas Kather, Johannes Krisam, et al. (2019) in their paper + "Predicting survival from colorectal cancer histology slides using deep learning: A retrospective multicenter study", + and made available through Yang et al. (2021) in + "MedMNIST Classification Decathlon: A Lightweight AutoML Benchmark for Medical Image Analysis". + Dataset Repo: https://github.com/MedMNIST/MedMNIST + """ + + filename = get_dataset_filename(self.config.dataset) + + data_transform = torchvision.transforms.Compose([ + torchvision.transforms.ToTensor(), + torchvision.transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)) + ]) + + train_dataset = PathMNISTDataset(split='train', filename=filename, transform=data_transform) + val_dataset = PathMNISTDataset(split='val', filename=filename, transform=data_transform) + + self.train_loader = torch.utils.data.DataLoader( + train_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=4 * torch.cuda.device_count(), + pin_memory=True, + drop_last=True + ) + self.val_loader = torch.utils.data.DataLoader( + val_dataset, + batch_size=self.batch_size, + shuffle=True, + num_workers=4 * torch.cuda.device_count(), + pin_memory=True, + drop_last=True + ) + + self.dataset_size = len(train_dataset) + + def train(self) -> None: + """Train a model.""" + + # save global model first + prev_model = deepcopy(self.model) + + for epoch in range(self.epochs): + self.model.train() + loss_lst = list() + + for data, label in self.train_loader: + data, label = data.to(self.device), label.to(self.device) + self.optimizer.zero_grad() + output = self.model(data) + + # proximal term included in loss + loss = self.criterion(output, label.squeeze()) + self.regularizer.get_term(curr_model = self.model, prev_model = prev_model) + + # back to normal stuff + loss_lst.append(loss.item()) + loss.backward() + self.optimizer.step() + + train_loss = sum(loss_lst) / len(loss_lst) + self.update_metrics({"Train Loss": train_loss}) + + def evaluate(self) -> None: + """Evaluate a model.""" + self.model.eval() + loss_lst = list() + labels = torch.tensor([],device=self.device) + labels_pred = torch.tensor([],device=self.device) + with torch.no_grad(): + for data, label in self.val_loader: + data, label = data.to(self.device), label.to(self.device) + output = self.model(data) + loss = self.criterion(output, label.squeeze()) + loss_lst.append(loss.item()) + labels_pred = torch.cat([labels_pred, output.argmax(dim=1)], dim=0) + labels = torch.cat([labels, label], dim=0) + + labels_pred = labels_pred.cpu().detach().numpy() + labels = labels.cpu().detach().numpy() + val_acc = accuracy_score(labels, labels_pred) + + val_loss = sum(loss_lst) / len(loss_lst) + self.update_metrics({"Val Loss": val_loss, "Val Accuracy": val_acc, "Testset Size": len(self.val_loader)}) + logger.info(f"Test Loss: {val_loss}") + logger.info(f"Test Accuracy: {val_acc}") + logger.info(f"Test Set Size: {len(self.val_loader)}") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description='') + parser.add_argument('config', nargs='?', default="./config.json") + + args = parser.parse_args() + + config = Config(args.config) + + t = PyTorchMedMNistTrainer(config) + t.compose() + t.run() diff --git a/lib/python/flame/examples/medmnist_feddyn/trainer/sites.txt b/lib/python/flame/examples/medmnist_feddyn/trainer/sites.txt new file mode 100644 index 000000000..8572a63f1 --- /dev/null +++ b/lib/python/flame/examples/medmnist_feddyn/trainer/sites.txt @@ -0,0 +1,10 @@ +https://github.com/GustavBaumgart/flame-datasets/raw/main/medmnist/site1.npz +https://github.com/GustavBaumgart/flame-datasets/raw/main/medmnist/site2.npz +https://github.com/GustavBaumgart/flame-datasets/raw/main/medmnist/site3.npz +https://github.com/GustavBaumgart/flame-datasets/raw/main/medmnist/site4.npz +https://github.com/GustavBaumgart/flame-datasets/raw/main/medmnist/site5.npz +https://github.com/GustavBaumgart/flame-datasets/raw/main/medmnist/site6.npz +https://github.com/GustavBaumgart/flame-datasets/raw/main/medmnist/site7.npz +https://github.com/GustavBaumgart/flame-datasets/raw/main/medmnist/site8.npz +https://github.com/GustavBaumgart/flame-datasets/raw/main/medmnist/site9.npz +https://github.com/GustavBaumgart/flame-datasets/raw/main/medmnist/site10.npz diff --git a/lib/python/flame/examples/medmnist_feddyn/trainer/template.json b/lib/python/flame/examples/medmnist_feddyn/trainer/template.json new file mode 100644 index 000000000..17d2ec8d2 --- /dev/null +++ b/lib/python/flame/examples/medmnist_feddyn/trainer/template.json @@ -0,0 +1,81 @@ +{ + "taskid": "505f9fc483cf4df68a2409257b5fad7d3c580480", + "backend": "mqtt", + "brokers": [ + { + "host": "localhost", + "sort": "mqtt" + }, + { + "host": "localhost:10104", + "sort": "p2p" + } + ], + "groupAssociation": { + "param-channel": "default" + }, + "channels": [ + { + "description": "Model update is sent from trainer to aggregator and vice-versa", + "groupBy": { + "type": "tag", + "value": [ + "default" + ] + }, + "name": "param-channel", + "pair": [ + "trainer", + "aggregator" + ], + "funcTags": { + "aggregator": [ + "distribute", + "aggregate", + "getDatasetSize" + ], + "trainer": [ + "fetch", + "upload", + "uploadDatasetSize" + ] + } + } + ], + "dataset": "https://github.com/GustavBaumgart/flame-datasets/raw/main/medmnist/site1.npz", + "dependencies": [ + "numpy >= 1.2.0" + ], + "hyperparameters": { + "batchSize": 50, + "learningRate": 0.001, + "rounds": 100, + "epochs": 4 + }, + "baseModel": { + "name": "", + "version": 1 + }, + "job": { + "id": "336a358619ab59012eabeefb", + "name": "medmnist" + }, + "registry": { + "sort": "dummy", + "uri": "" + }, + "selector": { + "sort": "default", + "kwargs": {} + }, + "optimizer": { + "sort": "feddyn", + "kwargs": { + "alpha": 0.01, + "weight_decay": 0.001 + } + }, + "maxRunTime": 300, + "realm": "default", + "role": "trainer" +} diff --git a/lib/python/flame/mode/horizontal/coord_syncfl/trainer.py b/lib/python/flame/mode/horizontal/coord_syncfl/trainer.py index f68d4ea56..f48419953 100644 --- a/lib/python/flame/mode/horizontal/coord_syncfl/trainer.py +++ b/lib/python/flame/mode/horizontal/coord_syncfl/trainer.py @@ -17,7 +17,7 @@ import logging from abc import ABCMeta -from flame.common.constants import DeviceType, TrainerState +from flame.common.constants import DeviceType, TrainState from flame.common.util import weights_to_device, weights_to_model_device from flame.mode.composer import Composer from flame.mode.horizontal.syncfl.trainer import TAG_FETCH, TAG_UPLOAD @@ -68,7 +68,7 @@ def _fetch_weights(self, tag: str) -> None: if MessageType.ROUND in msg: self._round = msg[MessageType.ROUND] - self.regularizer.save_state(TrainerState.PRE_TRAIN, glob_model=self.model) + self.regularizer.save_state(TrainState.PRE, glob_model=self.model) logger.debug(f"work_done: {self._work_done}, round: {self._round}") def _send_weights(self, tag: str) -> None: @@ -82,7 +82,7 @@ def _send_weights(self, tag: str) -> None: channel.await_join() self._update_weights() - self.regularizer.save_state(TrainerState.POST_TRAIN, loc_model=self.model) + self.regularizer.save_state(TrainState.POST, loc_model=self.model) delta_weights = self._delta_weights_fn(self.weights, self.prev_weights) diff --git a/lib/python/flame/mode/horizontal/feddyn/__init__.py b/lib/python/flame/mode/horizontal/feddyn/__init__.py new file mode 100644 index 000000000..e11510fe5 --- /dev/null +++ b/lib/python/flame/mode/horizontal/feddyn/__init__.py @@ -0,0 +1,16 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + diff --git a/lib/python/flame/mode/horizontal/feddyn/top_aggregator.py b/lib/python/flame/mode/horizontal/feddyn/top_aggregator.py new file mode 100644 index 000000000..6e6c23506 --- /dev/null +++ b/lib/python/flame/mode/horizontal/feddyn/top_aggregator.py @@ -0,0 +1,154 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +"""FedDyn horizontal FL top level aggregator.""" + +import logging + +from flame.common.util import (MLFramework, get_ml_framework_in_use, + weights_to_device) +from flame.common.constants import (DeviceType, TrainState) +from ...composer import Composer +from ...message import MessageType +from ...tasklet import Loop, Tasklet + +from ..top_aggregator import TopAggregator as BaseTopAggregator + +logger = logging.getLogger(__name__) + +TAG_DISTRIBUTE = 'distribute' +TAG_AGGREGATE = 'aggregate' +TAG_GET_DATATSET_SIZE = 'getDatasetSize' + +class TopAggregator(BaseTopAggregator): + """FedDyn Top level Aggregator implements an ML aggregation role.""" + + def internal_init(self) -> None: + """Initialize internal state for role.""" + ml_framework_in_use = get_ml_framework_in_use() + + # only support pytorch + if ml_framework_in_use != MLFramework.PYTORCH: + raise NotImplementedError( + "supported ml framework not found; " + f"supported frameworks (for feddyn) are: {[MLFramework.PYTORCH.name.lower()]}") + + super().internal_init() + + def get(self, tag: str) -> None: + """Get data from remote role(s).""" + if tag == TAG_AGGREGATE: + self._aggregate_weights(tag) + elif tag == TAG_GET_DATATSET_SIZE: + self.get_dataset_size(tag) + + def get_dataset_size(self, tag: str) -> None: + logger.debug("calling get_dataset_size") + channel = self.cm.get_by_tag(tag) + if not channel: + return + + self.dataset_sizes = dict() + + # receive dataset size from all trainers + all_ends = channel.all_ends() + for msg, metadata in channel.recv_fifo(all_ends): + end, _ = metadata + if not msg: + logger.debug(f"No data from {end}; skipping it") + continue + + logger.debug(f"received data from {end}") + if MessageType.DATASET_SIZE in msg: + self.dataset_sizes[end] = msg[MessageType.DATASET_SIZE] + + # record all active trainers + self.optimizer.save_state(TrainState.PRE, active_ends = all_ends) + logger.debug(f"dataset sizes: {self.dataset_sizes}") + logger.debug("exiting get_dataset_size") + + def _distribute_weights(self, tag: str) -> None: + logger.debug("calling _distribute_weights") + channel = self.cm.get_by_tag(tag) + if not channel: + logger.debug(f"channel not found for tag {tag}") + return + + # this call waits for at least one peer to join this channel + channel.await_join() + + # before distributing weights, update it from global model + self._update_weights() + + total_samples = sum(self.dataset_sizes.values()) + num_trainers = len(self.dataset_sizes) + weight_dict = {end:(self.dataset_sizes[end]/total_samples) * num_trainers for end in self.dataset_sizes} + + logger.debug(f"weight_dict: {weight_dict}") + + # send out global model parameters to trainers + for end in channel.ends(): + logger.debug(f"sending weights to {end}") + channel.send(end, { + MessageType.WEIGHTS: weights_to_device(self.weights, DeviceType.CPU), + MessageType.ROUND: self._round, + MessageType.ALPHA_ADPT: self.optimizer.alpha / weight_dict.get(end, 1) + }) + + def compose(self) -> None: + """Compose role with tasklets.""" + with Composer() as composer: + self.composer = composer + + task_internal_init = Tasklet("", self.internal_init) + + task_init = Tasklet("", self.initialize) + + task_load_data = Tasklet("", self.load_data) + + task_get_dataset = Tasklet("", self.get, TAG_GET_DATATSET_SIZE) + + task_put = Tasklet("", self.put, TAG_DISTRIBUTE) + + task_get = Tasklet("", self.get, TAG_AGGREGATE) + + task_train = Tasklet("", self.train) + + task_eval = Tasklet("", self.evaluate) + + task_analysis = Tasklet("", self.run_analysis) + + task_save_metrics = Tasklet("", self.save_metrics) + + task_increment_round = Tasklet("", self.increment_round) + + task_end_of_training = Tasklet("", self.inform_end_of_training) + + task_save_params = Tasklet("", self.save_params) + + task_save_model = Tasklet("", self.save_model) + + # create a loop object with loop exit condition function + loop = Loop(loop_check_fn=lambda: self._work_done) + task_internal_init >> task_load_data >> task_init >> loop( + task_get_dataset >> task_put >> task_get >> task_train >> + task_eval >> task_analysis >> task_save_metrics >> + task_increment_round + ) >> task_end_of_training >> task_save_params >> task_save_model + + @classmethod + def get_func_tags(cls) -> list[str]: + """Return a list of function tags defined in the top level aggregator role.""" + return [TAG_DISTRIBUTE, TAG_AGGREGATE, TAG_GET_DATATSET_SIZE] diff --git a/lib/python/flame/mode/horizontal/feddyn/trainer.py b/lib/python/flame/mode/horizontal/feddyn/trainer.py new file mode 100644 index 000000000..44b2517fc --- /dev/null +++ b/lib/python/flame/mode/horizontal/feddyn/trainer.py @@ -0,0 +1,148 @@ +# Copyright 2023 Cisco Systems, Inc. and its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 +"""FedDyn horizontal FL trainer.""" + +import logging + +from flame.channel import VAL_CH_STATE_RECV, VAL_CH_STATE_SEND +from flame.common.constants import TrainState +from flame.common.util import (MLFramework, get_ml_framework_in_use, + weights_to_model_device) +from ...composer import Composer +from ...message import MessageType +from ...tasklet import Loop, Tasklet + +from ..trainer import Trainer as BaseTrainer + +logger = logging.getLogger(__name__) + +TAG_FETCH = 'fetch' +TAG_UPLOAD = 'upload' +TAG_UPLOAD_DATASET_SIZE = 'uploadDatasetSize' + +class Trainer(BaseTrainer): + """FedDyn Trainer implements an ML training role.""" + + def internal_init(self) -> None: + """Initialize internal state for role.""" + logger.debug("in internal_init") + ml_framework_in_use = get_ml_framework_in_use() + + # only support pytorch + if ml_framework_in_use != MLFramework.PYTORCH: + raise NotImplementedError( + "supported ml framework not found; " + f"supported frameworks (for feddyn) are: {[MLFramework.PYTORCH.name.lower()]}") + + super().internal_init() + + def get(self, tag: str) -> None: + """Get data from remote role(s).""" + if tag == TAG_FETCH: + self._fetch_weights(tag) + + def _fetch_weights(self, tag: str) -> None: + logger.debug("calling _fetch_weights") + channel = self.cm.get_by_tag(tag) + if not channel: + logger.debug(f"[_fetch_weights] channel not found with tag {tag}") + return + + # this call waits for at least one peer joins this channel + channel.await_join() + + # one aggregator is sufficient + end = channel.one_end(VAL_CH_STATE_RECV) + msg, _ = channel.recv(end) + + if MessageType.WEIGHTS in msg: + self.weights = weights_to_model_device(msg[MessageType.WEIGHTS], self.model) + self._update_model() + + # adjust alpha based on aggregator computation + if MessageType.ALPHA_ADPT in msg: + logger.debug("got alpha_adpt") + self.regularizer.alpha = msg[MessageType.ALPHA_ADPT] + + if MessageType.EOT in msg: + self._work_done = msg[MessageType.EOT] + + if MessageType.ROUND in msg: + self._round = msg[MessageType.ROUND] + + self.regularizer.save_state(TrainState.PRE, glob_model = self.model) + logger.debug(f"work_done: {self._work_done}, round: {self._round}") + + def put(self, tag: str) -> None: + """Set data to remote role(s).""" + if tag == TAG_UPLOAD: + self._send_weights(tag) + elif tag == TAG_UPLOAD_DATASET_SIZE: + self._send_dataset_size(tag) + + def _send_dataset_size(self, tag: str) -> None: + logger.debug("calling _send_dataset_size") + channel = self.cm.get_by_tag(tag) + if not channel: + logger.debug(f"[_send_dataset_size] channel not found with {tag}") + return + + # this call waits for at least one peer to join this channel + channel.await_join() + + # one aggregator is sufficient + end = channel.one_end(VAL_CH_STATE_SEND) + + msg = { + MessageType.DATASET_SIZE: self.dataset_size + } + channel.send(end, msg) + logger.debug("sending dataset size done") + + def compose(self) -> None: + """Compose role with tasklets.""" + logger.debug("in compose") + with Composer() as composer: + self.composer = composer + + task_internal_init = Tasklet("", self.internal_init) + + task_load_data = Tasklet("", self.load_data) + + task_init = Tasklet("", self.initialize) + + task_put_dataset_size = Tasklet("", self.put, TAG_UPLOAD_DATASET_SIZE) + + task_get = Tasklet("", self.get, TAG_FETCH) + + task_train = Tasklet("", self.train) + + task_eval = Tasklet("", self.evaluate) + + task_put = Tasklet("", self.put, TAG_UPLOAD) + + task_save_metrics = Tasklet("", self.save_metrics) + + # create a loop object with loop exit condition function + loop = Loop(loop_check_fn=lambda: self._work_done) + task_internal_init >> task_load_data >> task_init >> loop( + task_put_dataset_size >> task_get >> task_train >> task_eval >> + task_put >> task_save_metrics) + + @classmethod + def get_func_tags(cls) -> list[str]: + """Return a list of function tags defined in the trainer role.""" + return [TAG_FETCH, TAG_UPLOAD, TAG_UPLOAD_DATASET_SIZE] diff --git a/lib/python/flame/mode/horizontal/syncfl/trainer.py b/lib/python/flame/mode/horizontal/syncfl/trainer.py index 801112405..09a0c6e33 100644 --- a/lib/python/flame/mode/horizontal/syncfl/trainer.py +++ b/lib/python/flame/mode/horizontal/syncfl/trainer.py @@ -19,7 +19,7 @@ from flame.channel import VAL_CH_STATE_RECV, VAL_CH_STATE_SEND from flame.channel_manager import ChannelManager -from flame.common.constants import DeviceType, TrainerState +from flame.common.constants import DeviceType, TrainState from flame.common.custom_abcmeta import ABCMeta, abstract_attribute from flame.common.util import ( MLFramework, @@ -134,7 +134,8 @@ def _fetch_weights(self, tag: str) -> None: msg[MessageType.DATASAMPLER_METADATA] ) - self.regularizer.save_state(TrainerState.PRE_TRAIN, glob_model=self.model) + self.regularizer.save_state(TrainState.PRE, glob_model=self.model) + logger.debug(f"work_done: {self._work_done}, round: {self._round}") def put(self, tag: str) -> None: @@ -156,7 +157,7 @@ def _send_weights(self, tag: str) -> None: end = channel.one_end(VAL_CH_STATE_SEND) self._update_weights() - self.regularizer.save_state(TrainerState.POST_TRAIN, loc_model=self.model) + self.regularizer.save_state(TrainState.POST, loc_model=self.model) delta_weights = self._delta_weights_fn(self.weights, self.prev_weights) diff --git a/lib/python/flame/mode/message.py b/lib/python/flame/mode/message.py index d83d176fa..b21d46523 100644 --- a/lib/python/flame/mode/message.py +++ b/lib/python/flame/mode/message.py @@ -42,3 +42,6 @@ class MessageType(Enum): COORDINATED_ENDS = 11 # ends coordinated by a coordinator DATASAMPLER_METADATA = 12 # datasampler metadata + + ALPHA_ADPT = 13 # adaptive hyperparameter used in FedDyn implementation + diff --git a/lib/python/flame/optimizer/feddyn.py b/lib/python/flame/optimizer/feddyn.py index 381509775..bbd74bfc0 100644 --- a/lib/python/flame/optimizer/feddyn.py +++ b/lib/python/flame/optimizer/feddyn.py @@ -18,6 +18,7 @@ """https://arxiv.org/abs/2111.04263""" import logging +from flame.common.constants import TrainState from diskcache import Cache from ..common.typing import ModelWeights from ..common.util import (MLFramework, get_ml_framework_in_use) @@ -30,7 +31,7 @@ class FedDyn(FedAvg): """FedDyn class.""" - def __init__(self, alpha): + def __init__(self, alpha, weight_decay): """Initialize FedDyn instance.""" ml_framework_in_use = get_ml_framework_in_use() @@ -43,12 +44,29 @@ def __init__(self, alpha): super().__init__() self.alpha = alpha - self.h_t = None + self.weight_decay = weight_decay + self.local_param_dict = dict() # override parent's self.regularizer - self.regularizer = FedDynRegularizer(self.alpha) + self.regularizer = FedDynRegularizer(self.alpha, self.weight_decay) logger.debug("Initializing feddyn") + def save_state(self, state: TrainState, **kwargs): + if state == TrainState.PRE: + active_ends = kwargs['active_ends'] + + # adjust history terms to fit active trainers + new_local_param_dict = dict() + for end in active_ends: + if end in self.local_param_dict: + new_local_param_dict[end] = self.local_param_dict[end] + else: + # default value for no diff history so far + new_local_param_dict[end] = None + + self.local_param_dict = new_local_param_dict + + def do(self, base_weights: ModelWeights, cache: Cache, @@ -71,35 +89,44 @@ def do(self, """ logger.debug("calling feddyn") - num_trainers = kwargs['num_trainers'] - num_selected = len(cache) - - # populate h_t - if self.h_t is None: - self.h_t = dict() - for k in base_weights: - self.h_t[k] = 0.0 + assert (base_weights is not None) + + # reset global weights before aggregation + self.agg_weights = base_weights - self.agg_weights = super().do(base_weights, - cache, - total=total, - version=version) + if len(cache) == 0 or total == 0: + return None - self.adapt_fn(self.agg_weights, base_weights, num_trainers, num_selected) + # get unweighted mean of selected trainers + rate = 1 / len(cache) + for k in list(cache.iterkeys()): + tres = cache.pop(k) + self.add_to_hist(k, tres) + self.aggregate_fn(tres, rate) - return self.current_weights - - def adapt_fn(self, average, base, num_trainers, num_selected): + avg_model = self.agg_weights + + # perform unweighted mean on all hist terms + mean_local_param = {k:0.0 for k in avg_model} + rate = 1 / len(self.local_param_dict) + for end in self.local_param_dict: + if self.local_param_dict[end] != None: + h = self.local_param_dict[end] + mean_local_param = {k:v + rate*h[k] for (k,v) in mean_local_param.items()} - # get delta from averaging which we use as (1/|P_t|) * sum_{k in P_t}[theta^t_k - theta^{t-1}] - self.d_t = {k: average[k] - base[k] for k in average.keys()} - # (num_selected / num_trainers) = (|P_t| / m) - # this acts as a conversion factor for d_t to be averaged among all active trainers - d_mult = self.alpha * (num_selected / num_trainers) - self.h_t = {k:self.h_t[k] - d_mult * self.d_t[k] for k in self.h_t.keys()} + self.cld_model = {k:avg_model[k]+mean_local_param[k] for k in avg_model} - # here h_t needs to be multiplied by (1/alpha) before it is subtracted from the averaged model - # although the averaged model is weighted by dataset, we take this to be the same as (1/|P_t|) * sum_{k in P_t}[theta^t_k] - h_mult = 1.0 / self.alpha - self.current_weights = {k:average[k] - h_mult * self.h_t[k] for k in self.h_t.keys()} + return self.cld_model + + def add_to_hist(self, end, tres): + if end in self.local_param_dict: + if self.local_param_dict[end] == None: + self.local_param_dict[end] = tres.weights + else: + # aggregate diffs + self.local_param_dict[end] = {k:v+tres.weights[k] for (k,v) in self.local_param_dict[end].items()} + else: + # case: end was not previously recorded as active trainer + logger.debug(f"adding untracked end {end} to hist terms") + self.local_param_dict[end] = tres.weights diff --git a/lib/python/flame/optimizer/regularizer/default.py b/lib/python/flame/optimizer/regularizer/default.py index 6207bb636..98c7c1ebd 100644 --- a/lib/python/flame/optimizer/regularizer/default.py +++ b/lib/python/flame/optimizer/regularizer/default.py @@ -16,7 +16,7 @@ """Dummy Regularizer.""" import logging -from flame.common.constants import TrainerState +from flame.common.constants import TrainState logger = logging.getLogger(__name__) @@ -32,7 +32,7 @@ def get_term(self, **kwargs): """No regularizer term for dummy regularizer.""" return 0.0 - def save_state(self, state: TrainerState, **kwargs): + def save_state(self, state: TrainState, **kwargs): """No states saved in dummy regularizer.""" pass diff --git a/lib/python/flame/optimizer/regularizer/feddyn.py b/lib/python/flame/optimizer/regularizer/feddyn.py index 80657d1a6..bcac00609 100644 --- a/lib/python/flame/optimizer/regularizer/feddyn.py +++ b/lib/python/flame/optimizer/regularizer/feddyn.py @@ -16,7 +16,7 @@ """FedDyn Regularizer.""" import logging -from flame.common.constants import TrainerState +from flame.common.constants import TrainState from flame.common.util import (get_params_detached_pytorch, get_params_as_vector_pytorch) from .default import Regularizer @@ -26,10 +26,11 @@ class FedDynRegularizer(Regularizer): """FedDyn Regularizer class.""" - def __init__(self, alpha): + def __init__(self, alpha, weight_decay): """Initialize FedDynRegularizer instance.""" super().__init__() self.alpha = alpha + self.weight_decay = weight_decay # save states in dictionary self.state_dict = dict() @@ -51,18 +52,18 @@ def get_term(self, **kwargs): w_vector = get_params_as_vector_pytorch(w) w_t_vector = get_params_as_vector_pytorch(w_t) - # proximal term - norm_sq = (self.alpha / 2) * torch.sum(torch.pow(w_vector - w_t_vector, 2)) + # weight decay term using alpha parameter + w_decay_term = ((self.alpha + self.weight_decay) / 2) * torch.sum(torch.pow(w_vector, 2)) - # grad term - inner_prod = torch.sum(self.prev_grad * w_vector) + # remaining loss term + loss_algo = self.alpha * torch.sum(w_vector * (-w_t_vector + self.prev_grad)) - return - inner_prod + norm_sq - - def save_state(self, state: TrainerState, **kwargs): - if state == TrainerState.PRE_TRAIN: + return loss_algo + w_decay_term + + def save_state(self, state: TrainState, **kwargs): + if state == TrainState.PRE: self.state_dict['glob_model'] = get_params_detached_pytorch(kwargs['glob_model']) - elif state == TrainerState.POST_TRAIN: + elif state == TrainState.POST: self.state_dict['loc_model'] = get_params_detached_pytorch(kwargs['loc_model']) def update(self): @@ -75,4 +76,4 @@ def update(self): w_t_vector = get_params_as_vector_pytorch(w_t) # adjust prev_grad - self.prev_grad = self.prev_grad - self.alpha * (w_vector - w_t_vector) + self.prev_grad += (w_vector - w_t_vector)