From 148055504afa9b95afdd42bc932c39d64257f826 Mon Sep 17 00:00:00 2001 From: Haotian Tang Date: Wed, 6 Sep 2023 17:21:32 -0400 Subject: [PATCH] [Major] Add new implementation of TinyChat. --- README.md | 5 +- awq/kernels/csrc/attention/README.md | 8 + .../csrc/attention/cuda_bf16_fallbacks.cuh | 257 +++ .../csrc/attention/cuda_bf16_wrapper.h | 23 + .../decoder_masked_multihead_attention.cu | 154 ++ .../decoder_masked_multihead_attention.h | 184 ++ ...er_masked_multihead_attention_template.hpp | 1608 +++++++++++++++ ...decoder_masked_multihead_attention_utils.h | 1786 +++++++++++++++++ awq/kernels/csrc/attention/ft_attention.cpp | 182 ++ awq/kernels/csrc/attention/ft_attention.h | 15 + awq/kernels/csrc/attention/setup.py | 152 ++ awq/kernels/csrc/pybind.cpp | 7 + awq/kernels/csrc/quantization/gemv_cuda.cu | 247 +++ awq/kernels/csrc/quantization/gemv_cuda.h | 9 + awq/kernels/setup.py | 31 +- awq/quantize/pre_quant.py | 2 +- awq/quantize/qmodule.py | 70 +- awq/quantize/quantizer.py | 4 +- tinychat/README.md | 59 +- tinychat/benchmark.py | 123 ++ tinychat/demo.py | 181 +- tinychat/figures/4090_example.gif | Bin 2165864 -> 2673510 bytes tinychat/figures/orin_example.gif | Bin 4502760 -> 6669106 bytes tinychat/models/__init__.py | 3 + tinychat/models/falcon.py | 296 +++ tinychat/models/llama.py | 311 +++ tinychat/models/mpt.py | 301 +++ tinychat/modules/fused_attn.py | 247 ++- tinychat/modules/fused_mlp.py | 52 +- tinychat/modules/fused_norm.py | 21 +- tinychat/stream_generators/__init__.py | 2 +- .../stream_generators/falcon_stream_gen.py | 31 +- tinychat/stream_generators/stream_gen.py | 77 +- tinychat/utils/__init__.py | 3 + tinychat/utils/constants.py | 7 + tinychat/utils/load_quant.py | 59 +- tinychat/utils/prompt_templates.py | 150 +- tinychat/utils/tune.py | 61 + 38 files changed, 6458 insertions(+), 270 deletions(-) create mode 100644 awq/kernels/csrc/attention/README.md create mode 100644 awq/kernels/csrc/attention/cuda_bf16_fallbacks.cuh create mode 100644 awq/kernels/csrc/attention/cuda_bf16_wrapper.h create mode 100644 awq/kernels/csrc/attention/decoder_masked_multihead_attention.cu create mode 100644 awq/kernels/csrc/attention/decoder_masked_multihead_attention.h create mode 100644 awq/kernels/csrc/attention/decoder_masked_multihead_attention_template.hpp create mode 100644 awq/kernels/csrc/attention/decoder_masked_multihead_attention_utils.h create mode 100644 awq/kernels/csrc/attention/ft_attention.cpp create mode 100644 awq/kernels/csrc/attention/ft_attention.h create mode 100644 awq/kernels/csrc/attention/setup.py create mode 100644 awq/kernels/csrc/quantization/gemv_cuda.cu create mode 100644 awq/kernels/csrc/quantization/gemv_cuda.h create mode 100644 tinychat/benchmark.py create mode 100644 tinychat/models/__init__.py create mode 100644 tinychat/models/falcon.py create mode 100644 tinychat/models/llama.py create mode 100644 tinychat/models/mpt.py create mode 100644 tinychat/utils/__init__.py create mode 100644 tinychat/utils/constants.py create mode 100644 tinychat/utils/tune.py diff --git a/README.md b/README.md index 59bb310..9dab7aa 100644 --- a/README.md +++ b/README.md @@ -12,14 +12,15 @@ The current release supports: - Efficient CUDA kernel implementation for fast inference (support context and decoding stage). - Examples on 4-bit inference of an instruction-tuned model (Vicuna) and multi-modal LM (LLaVA). -![TinyChat on RTX 4090: W4A16 is 2.3x faster than FP16](./tinychat/figures/4090_example.gif) +![TinyChat on Orin: W4A16 is 3.2x faster than FP16](./tinychat/figures/orin_example.gif) -Check out [TinyChat](tinychat), which delievers 2.3x faster inference performance for the **LLaMA-2** chatbot on RTX 4090! +Check out [TinyChat](tinychat), which delievers **30 tokens/second** inference performance (**3.2x faster** than FP16) for the **LLaMA-2** chatbot on the resource-constrained NVIDIA Jetson Orin! It also offers a turn-key solution for **on-device inference** of LLMs on **resource-constrained edge platforms**. With TinyChat, it is now possible to run **large** models on **small** and **low-power** devices even without Internet connection. ## News +- [2023/09] ⚡ Check out our latest [**TinyChat**](tinychat), which is ~2x faster than the first release on Orin! - [2023/09] ⚡ Check out [**AutoAWQ**](https://github.com/casper-hansen/AutoAWQ), a third-party implementation to make AWQ easier to expand to new models, improve inference speed, and integrate into Huggingface. - [2023/07] 🔥 We released **TinyChat**, an efficient and lightweight chatbot interface based on AWQ. TinyChat enables efficient LLM inference on both cloud and edge GPUs. LLama-2-chat models are supported! Check out our implementation [here](tinychat). - [2023/07] 🔥 We added AWQ support and pre-computed search results for Llama-2 models (7B & 13B). Checkout our model zoo [here](https://huggingface.co/datasets/mit-han-lab/awq-model-zoo)! diff --git a/awq/kernels/csrc/attention/README.md b/awq/kernels/csrc/attention/README.md new file mode 100644 index 0000000..ec0aae5 --- /dev/null +++ b/awq/kernels/csrc/attention/README.md @@ -0,0 +1,8 @@ +# Attention kernel from FasterTransformer + +This CUDA extension wraps the single-query attention [kernel](https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp) from +FasterTransformer v5.2.1 for benchmarking purpose. + +```sh +cd csrc/ft_attention && pip install . +``` diff --git a/awq/kernels/csrc/attention/cuda_bf16_fallbacks.cuh b/awq/kernels/csrc/attention/cuda_bf16_fallbacks.cuh new file mode 100644 index 0000000..f5641f6 --- /dev/null +++ b/awq/kernels/csrc/attention/cuda_bf16_fallbacks.cuh @@ -0,0 +1,257 @@ +// Downloaded from from FasterTransformer v5.2.1 +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cuda_bf16_wrapper.h" +#include + +namespace fastertransformer { + +#ifdef ENABLE_BF16 +inline __device__ float2 bf1622float2(const __nv_bfloat162 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = __low2float(val); + f_val.y = __high2float(val); + return f_val; +#else + return __bfloat1622float2(val); +#endif +} + +inline __device__ int16_t bf1622int16(__nv_bfloat162 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float2 f_val; + f_val.x = max(min(__low2float(val), 127.f), -128.f); + f_val.y = max(min(__high2float(val), 127.f), -128.f); + union { int8_t int8[2]; int16_t int16; }; + int8[0] = static_cast(static_cast(f_val.x)); + int8[1] = static_cast(static_cast(f_val.y)); + return int16; +#else + val = __hmin2(val, make_bfloat162(127., 127.)); + val = __hmax2(val, make_bfloat162(-128., -128.)); + union { int8_t int8[2]; int16_t int16; }; + int8[0] = static_cast(static_cast(val.x)); + int8[1] = static_cast(static_cast(val.y)); + return int16; +#endif +} + +inline __device__ __nv_bfloat162 float22bf162(const float2 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __floats2bfloat162_rn(val.x, val.y); +#else + return __float22bfloat162_rn(val); +#endif +} + +inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + __nv_bfloat162 val2; + val2.x = val; + val2.y = val; + return val2; +#else + return __bfloat162bfloat162(val); +#endif +} + +inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl + fyl, fxh + fyh); +#else + return __hadd2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16( __bfloat162float(x) + __bfloat162float(y) ); +#else + return __hadd(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl - fyl, fxh - fyh); +#else + return __hsub2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16( __bfloat162float(x) - __bfloat162float(y) ); +#else + return __hsub(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + return __floats2bfloat162_rn(fxl * fyl, fxh * fyh); +#else + return __hmul2(x, y); +#endif +} + +inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) ); +#else + return __hmul(x, y); +#endif +} + +inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh, fyl, fyh, fzl, fzh; + fxl = __low2float(x); + fxh = __high2float(x); + fyl = __low2float(y); + fyh = __high2float(y); + fzl = __low2float(z); + fzh = __high2float(z); + return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh); +#else + return __hfma2(x, y, z); +#endif +} + +inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16( __bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z)); +#else + return __hfma(x, y, z); +#endif +} + +inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fxl, fxh; + fxl = __low2float(x); + fxh = __high2float(x);; + return __floats2bfloat162_rn(expf(fxl), expf(fxh)); +#else + return h2exp(x); +#endif +} + +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800) +inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); }; +inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); }; + +inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) +{ + __nv_bfloat162 t; t.x = x; t.y = y; return t; +} + +#endif + +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c)); +#else + return a + b + c; +#endif +} + +inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d)); +#else + return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d); +#endif +} + +inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch); +#else + return a + b + c; +#endif +} + +inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c)); +#else + return a * b * c; +#endif +} + +inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch); +#else + return a * b * c; +#endif +} + +inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 + float fal, fah, fbl, fbh, fcl, fch, fdl, fdh; + fal = __low2float(a); + fah = __high2float(a); + fbl = __low2float(b); + fbh = __high2float(b); + fcl = __low2float(c); + fch = __high2float(c); + fdl = __low2float(d); + fdh = __high2float(d); + return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh); +#else + return a * b * c + d; +#endif +} + +#endif // ENABLE_BF16 + +} // namespace fastertransformer diff --git a/awq/kernels/csrc/attention/cuda_bf16_wrapper.h b/awq/kernels/csrc/attention/cuda_bf16_wrapper.h new file mode 100644 index 0000000..efb6e79 --- /dev/null +++ b/awq/kernels/csrc/attention/cuda_bf16_wrapper.h @@ -0,0 +1,23 @@ +// Downloaded from from FasterTransformer v5.2.1 +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h +/* + * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#ifdef ENABLE_BF16 +#include +#endif diff --git a/awq/kernels/csrc/attention/decoder_masked_multihead_attention.cu b/awq/kernels/csrc/attention/decoder_masked_multihead_attention.cu new file mode 100644 index 0000000..e5a6690 --- /dev/null +++ b/awq/kernels/csrc/attention/decoder_masked_multihead_attention.cu @@ -0,0 +1,154 @@ +// Adapted from from FasterTransformer v5.2.1 +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "decoder_masked_multihead_attention.h" +#include "decoder_masked_multihead_attention_utils.h" +#include "cuda_bf16_wrapper.h" +#include +#include +#include + +#include "decoder_masked_multihead_attention_template.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \ + size_t smem_sz = mmha::smem_size_in_bytes(params, THDS_PER_VALUE, THDS_PER_BLOCK); \ + auto kernel = mmha::masked_multihead_attention_kernel; \ + if (smem_sz >= 48 * 1024) { \ + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \ + } \ + dim3 grid(params.num_heads, params.batch_size); \ + kernel<<>>(params) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// !!! Specialize the launcher for Cross attention +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + constexpr int THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16; + constexpr bool DO_CROSS_ATTENTION = std::is_same>::value; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep; + // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION); + if (tlength < 32) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream); + } + else if (tlength < 2048) { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream); + } + else { + MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#undef MMHA_LAUNCH_KERNEL + +template +void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream) +{ + switch (params.hidden_size_per_head) { + case 32: + mmha_launch_kernel(params, stream); + break; + case 48: + mmha_launch_kernel(params, stream); + break; + case 64: + mmha_launch_kernel(params, stream); + break; + case 80: + mmha_launch_kernel(params, stream); + break; + case 96: + mmha_launch_kernel(params, stream); + break; + case 112: + mmha_launch_kernel(params, stream); + break; + case 128: + mmha_launch_kernel(params, stream); + break; + case 160: + mmha_launch_kernel(params, stream); + break; + case 192: + mmha_launch_kernel(params, stream); + break; + case 224: + mmha_launch_kernel(params, stream); + break; + case 256: + mmha_launch_kernel(params, stream); + break; + default: + assert(false); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream) +{ + multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream); +} +#endif +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream) +{ + multihead_attention_>(params, stream); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream) +{ + multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/awq/kernels/csrc/attention/decoder_masked_multihead_attention.h b/awq/kernels/csrc/attention/decoder_masked_multihead_attention.h new file mode 100644 index 0000000..292fa34 --- /dev/null +++ b/awq/kernels/csrc/attention/decoder_masked_multihead_attention.h @@ -0,0 +1,184 @@ +// Downloaded from from FasterTransformer v5.2.1 +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cuda_bf16_wrapper.h" +#include +#include +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define CHECK_CUDA(call) \ + do { \ + cudaError_t status_ = call; \ + if (status_ != cudaSuccess) { \ + fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \ + exit(1); \ + } \ + } while (0) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// The structure of parameters for the masked multihead attention kernel. +// +// We use the following terminology to describe the different dimensions. +// +// B: Batch size (number of sequences), +// L: Sequence length, +// D: Hidden dimension, +// H: Number of heads, +// Dh: Hidden dimension per head - Dh = D / H. + +template +struct Multihead_attention_params_base { + + // The output buffer. Dimensions B x D. + T* out = nullptr; + + // The input Qs and the associated bias. Dimensions B x D and D, resp. + const T *q = nullptr, *q_bias = nullptr; + // The input Ks and the associated bias. Dimensions B x D and D, resp. + const T *k = nullptr, *k_bias = nullptr; + // The input Vs and the associated bias. Dimensions B x D and D, resp. + const T *v = nullptr, *v_bias = nullptr; + + // The cache for the Ks. The size must be at least B x L x D. + T* k_cache = nullptr; + // The cache for the Vs. The size must be at least B x L x D. + T* v_cache = nullptr; + // The indirections to use for cache when beam sampling. + const int* cache_indir = nullptr; + + // Stride to handle the case when KQV is a single buffer + int stride = 0; + + // The batch size. + int batch_size = 0; + // The beam width + int beam_width = 0; + // The sequence length. + int memory_max_len = 0; + // The number of heads (H). + int num_heads = 0; + // The number of heads for KV cache. + int num_kv_heads = 0; + // The hidden dimension per head (Dh). + int hidden_size_per_head = 0; + // The per-head latent space reserved for rotary embeddings. + int rotary_embedding_dim = 0; + bool neox_rotary_style = false; + float rotary_base = 0.0f; + // The maximum length of input sentences. + int max_input_length = 0; + // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention? + int timestep = 0; + // The current timestep of each sentences (support different timestep for different sentences) + + // The 1.f / sqrt(Dh). Computed on the host. + float inv_sqrt_dh = 0.0f; + + // Used when we have some input context like gpt + const int* total_padding_tokens = nullptr; + + const bool* masked_tokens = nullptr; + const int* prefix_prompt_lengths = nullptr; + int max_prefix_prompt_length = 0; + + const T* relative_attention_bias = nullptr; + int relative_attention_bias_stride = 0; + // The slope per head of linear position bias to attention score (H). + const float* linear_bias_slopes = nullptr; + + const T* ia3_key_weights = nullptr; + const T* ia3_value_weights = nullptr; + const int* ia3_tasks = nullptr; + + const float* qkv_scale_out = nullptr; + const float* attention_out_scale = nullptr; + int int8_mode = 0; +}; + +template +struct Multihead_attention_params: public Multihead_attention_params_base { + // output cross attentions + float* cross_attention_out = nullptr; + int max_decoder_seq_len = 0; + bool is_return_cross_attentions = false; + + // allows to exist attention eary + bool* finished = nullptr; + + // required in case of cross attention + // will need it here till if constexpr in c++17 + int* memory_length_per_sample = nullptr; + + // required in case of masked attention with different length + const int* length_per_sample = nullptr; +}; + +template +struct Multihead_attention_params: public Multihead_attention_params_base { + // output cross attentions + float* cross_attention_out = nullptr; + int max_decoder_seq_len = 0; + bool is_return_cross_attentions = false; + + // allows to exist attention eary + bool* finished = nullptr; + + // required in case of cross attention + int* memory_length_per_sample = nullptr; + + // required in case of masked attention with different length + const int* length_per_sample = nullptr; +}; + +template +using Masked_multihead_attention_params = Multihead_attention_params; + +template +using Cross_multihead_attention_params = Multihead_attention_params; + +template +struct outputCrossAttentionParam { + // max decoder output length + int max_decoder_seq_len = 0; + T* cross_attention_out = nullptr; + bool is_return_cross_attentions = false; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); +void masked_multihead_attention(const Masked_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream); +#endif +void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); +void cross_multihead_attention(const Cross_multihead_attention_params& params, const cudaStream_t& stream); +#ifdef ENABLE_BF16 +void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params, + const cudaStream_t& stream); +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/awq/kernels/csrc/attention/decoder_masked_multihead_attention_template.hpp b/awq/kernels/csrc/attention/decoder_masked_multihead_attention_template.hpp new file mode 100644 index 0000000..84eda81 --- /dev/null +++ b/awq/kernels/csrc/attention/decoder_masked_multihead_attention_template.hpp @@ -0,0 +1,1608 @@ +// Downloaded from from FasterTransformer v5.2.1 +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#pragma once + +#include "decoder_masked_multihead_attention.h" +#include "decoder_masked_multihead_attention_utils.h" +#include "cuda_bf16_wrapper.h" +#include "cuda_bf16_fallbacks.cuh" +#include +#include +#include + +// #define MMHA_USE_HMMA_FOR_REDUCTION + +// Below are knobs to extend FP32 accumulation for higher FP16 accuracy + +// Does not seem to affect the accuracy that much +#define MMHA_USE_FP32_ACUM_FOR_FMA + +// Seems to slightly improve the accuracy +#define MMHA_USE_FP32_ACUM_FOR_OUT + +#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT) + // Does not seem to improve the accuracy + //#define MMHA_USE_FP32_ACUM_FOR_LOGITS +#endif + +namespace mmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// We use the following terminology to describe the different dimensions. +// +// B: Batch size (number of sequences), +// L: Sequence length, +// D: Hidden dimension, +// H: Number of heads, +// Dh: Hidden dimension per head - Dh = D / H. +// +// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use +// 64, 128 and 256 threads per block. +// +// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to +// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The +// cache buffer helps with memory accesses and contains keys with bias. +// +// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and +// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The +// values for x are chosen to create chunks of 16 bytes. +// +// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs +// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At +// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an +// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32. +// +// After that loop, a parallel softmax is computed across the different Q * K^T values stored in +// shared memory. +// +// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many +// timesteps are computed by loop iteration. As with the keys, the values are read from a cache +// except for the current timestep. The layout of the cache buffer for the values is much simpler +// as it is [B, H, L, Dh]. +// + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_vec_ { +}; + +template<> +struct Qk_vec_ { + using Type = float; +}; +template<> +struct Qk_vec_ { + using Type = float2; +}; +template<> +struct Qk_vec_ { + using Type = float4; +}; +template<> +struct Qk_vec_ { + using Type = float4; +}; +template<> +struct Qk_vec_ { + using Type = uint32_t; +}; +template<> +struct Qk_vec_ { + using Type = uint32_t; +}; +template<> +struct Qk_vec_ { + using Type = uint2; +}; +template<> +struct Qk_vec_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct Qk_vec_<__nv_bfloat16, 32> { + using Type = __nv_bfloat162; +}; +template<> +struct Qk_vec_<__nv_bfloat16, 64> { + using Type = __nv_bfloat162; +}; +template<> +struct Qk_vec_<__nv_bfloat16, 128> { + using Type = bf16_4_t; +}; +template<> +struct Qk_vec_<__nv_bfloat16, 256> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_ { +}; + +template<> +struct K_vec_ { + using Type = float; +}; +template<> +struct K_vec_ { + using Type = float2; +}; +template<> +struct K_vec_ { + using Type = float4; +}; +template<> +struct K_vec_ { + using Type = uint32_t; +}; +template<> +struct K_vec_ { + using Type = uint2; +}; +template<> +struct K_vec_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct K_vec_<__nv_bfloat16, 4> { + using Type = __nv_bfloat162; +}; +template<> +struct K_vec_<__nv_bfloat16, 2> { + using Type = bf16_4_t; +}; +template<> +struct K_vec_<__nv_bfloat16, 1> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct V_vec_ { +}; + +template<> +struct V_vec_ { + using Type = float; +}; +template<> +struct V_vec_ { + using Type = float2; +}; +template<> +struct V_vec_ { + using Type = float4; +}; +template<> +struct V_vec_ { + using Type = uint32_t; +}; +template<> +struct V_vec_ { + using Type = uint2; +}; +template<> +struct V_vec_ { + using Type = uint4; +}; +#ifdef ENABLE_BF16 +template<> +struct V_vec_<__nv_bfloat16, 2> { + using Type = __nv_bfloat162; +}; +template<> +struct V_vec_<__nv_bfloat16, 4> { + using Type = bf16_4_t; +}; +template<> +struct V_vec_<__nv_bfloat16, 8> { + using Type = bf16_8_t; +}; +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA +template +struct Qk_vec_acum_fp32_ { +}; + +template<> +struct Qk_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float4; +}; +// template<> struct Qk_vec_acum_fp32_ { using Type = float; }; +template<> +struct Qk_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct Qk_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct Qk_vec_acum_fp32_ { + using Type = Float8_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct K_vec_acum_fp32_ { +}; + +template<> +struct K_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float4; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float8_; +}; +template<> +struct K_vec_acum_fp32_<__nv_bfloat16> { + using Type = float; +}; +template<> +struct K_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct K_vec_acum_fp32_ { + using Type = Float8_; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT +template +struct V_vec_acum_fp32_ { +}; + +template<> +struct V_vec_acum_fp32_ { + using Type = float; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float4; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#ifdef ENABLE_BF16 +template<> +struct V_vec_acum_fp32_<__nv_bfloat162> { + using Type = float2; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float4_; +}; +template<> +struct V_vec_acum_fp32_ { + using Type = Float8_; +}; +#endif // ENABLE_BF16 +#endif +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N]) +{ +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = K_vec; +#endif + // Compute the parallel products for Q*K^T (treat vector lanes separately). + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } + + // Finalize the reduction across lanes. + float qk = sum(qk_vec); +#pragma unroll + for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(uint32_t(-1), qk, mask); + } + return qk; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Qk_dot { + template + static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N]) + { + return qk_dot_(q, k); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b) +{ + float4 c; + float zero = 0.f; + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n" + " {%0, %1, %2, %3}, \n" + " {%4, %5}, \n" + " {%6}, \n" + " {%7, %7, %7, %7}; \n" + + : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w) + : "r"(a.x) "r"(a.y), "r"(b), "f"(zero)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N]) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using K_vec_acum = typename K_vec_acum_fp32_::Type; +#else + using K_vec_acum = uint32_t; +#endif + K_vec_acum qk_vec = mul(q[0], k[0]); +#pragma unroll + for (int ii = 1; ii < N; ++ii) { + qk_vec = fma(q[ii], k[ii], qk_vec); + } +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + uint32_t qk_vec_ = float2_to_half2(qk_vec); + return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x; +#else + return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x; +#endif +#else + return 0.f; +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +struct Qk_dot { + template + static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N]) + { +#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION) + return qk_hmma_dot_(q, k); +#else + return qk_dot_<4>(q, k); +#endif // defined MMHA_USE_HMMA_FOR_REDUCTION + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float block_sum(float* red_smem, float sum) +{ + + // Decompose the thread index into warp / lane. + int warp = threadIdx.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + +// Compute the sum per warp. +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Warp leaders store the data to shared memory. + if (lane == 0) { + red_smem[warp] = sum; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The warps compute the final sums. + if (lane < WARPS_PER_BLOCK) { + sum = red_smem[lane]; + } + +// Parallel reduction inside the warp. +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + sum += __shfl_xor_sync(uint32_t(-1), sum, mask); + } + + // Broadcast to other threads. + return __shfl_sync(uint32_t(-1), sum, 0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float& dst, float src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint16_t& dst, float src) +{ + dst = float_to_half(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint32_t& dst, float2 src) +{ + dst = float2_to_half2(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(__nv_bfloat16& dst, float src) +{ + dst = __float2bfloat16(src); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(__nv_bfloat162& dst, float2 src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst = __float22bfloat162_rn(src); +#else + dst = __floats2bfloat162_rn(src.x, src.y); +#endif +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint2& dst, Float4_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint2& dst, float4 src) +{ + convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(uint4& dst, Float8_ src) +{ + dst.x = float2_to_half2(src.x); + dst.y = float2_to_half2(src.y); + dst.z = float2_to_half2(src.z); + dst.w = float2_to_half2(src.w); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ void convert_from_float(bf16_4_t& dst, Float4_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(bf16_4_t& dst, float4 src) +{ + convert_from_float(dst, Float4_{make_float2(src.x, src.y), make_float2(src.z, src.w)}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(bf16_8_t& dst, Float8_ src) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + dst.x = __float22bfloat162_rn(src.x); + dst.y = __float22bfloat162_rn(src.y); + dst.z = __float22bfloat162_rn(src.z); + dst.w = __float22bfloat162_rn(src.w); +#else + dst.x = __floats2bfloat162_rn(src.x.x, src.x.y); + dst.y = __floats2bfloat162_rn(src.y.x, src.y.y); + dst.z = __floats2bfloat162_rn(src.z.x, src.z.y); + dst.w = __floats2bfloat162_rn(src.w.x, src.w.y); +#endif +} +#endif // ENABLE_BF16 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float2& dst, float2 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void convert_from_float(float4& dst, float4 src) +{ + dst = src; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(float4 u) +{ + return u.x; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float convert_to_float(uint4 u) +{ + float2 tmp = half2_to_float2(u.x); + return tmp.x; +} + +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float cast_to_float(float u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(float2 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 cast_to_float(float4 u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(Float4_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(Float8_ u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 cast_to_float(uint32_t u) +{ + return half2_to_float2(u); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ cast_to_float(uint2 u) +{ + Float4_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + return tmp; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ cast_to_float(uint4 u) +{ + Float8_ tmp; + tmp.x = half2_to_float2(u.x); + tmp.y = half2_to_float2(u.y); + tmp.z = half2_to_float2(u.z); + tmp.w = half2_to_float2(u.w); + return tmp; +} + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float float_from_int8(int8_t u) +{ + return u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 float_from_int8(int16_t u) +{ + union { + int16_t int16; + int8_t int8[2]; + }; + int16 = u; + return make_float2(int8[0], int8[1]); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 float_from_int8(int32_t u) +{ + union { + int32_t int32; + int8_t int8[4]; + }; + int32 = u; + return make_float4(int8[0], int8[1], int8[2], int8[3]); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// clang-format off +inline __device__ Float8_ float_from_int8(int64_t u) +{ + union { + int64_t int64; + int16_t int16[4]; + }; + int64 = u; + return Float8_ {float_from_int8(int16[0]), + float_from_int8(int16[1]), + float_from_int8(int16[2]), + float_from_int8(int16[3])}; +} +// clang-format on + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int8_t cast_to_int8(float val) +{ + union { + int8_t int8[2]; + int16_t int16; + }; + asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=h"(int16) : "f"(val)); + return int8[0]; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int32_t cast_to_int8(float4 val) +{ + union { + int8_t int8[4]; + int32_t int32; + }; + int8[0] = cast_to_int8(val.x); + int8[1] = cast_to_int8(val.y); + int8[2] = cast_to_int8(val.z); + int8[3] = cast_to_int8(val.w); + return int32; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ int64_t cast_to_int8(Float8_ val) +{ + union { + int8_t int8[8]; + int64_t int64; + }; + int8[0] = cast_to_int8(val.x.x); + int8[1] = cast_to_int8(val.x.y); + int8[2] = cast_to_int8(val.y.x); + int8[3] = cast_to_int8(val.y.y); + int8[4] = cast_to_int8(val.z.x); + int8[5] = cast_to_int8(val.z.y); + int8[6] = cast_to_int8(val.w.x); + int8[7] = cast_to_int8(val.w.y); + return int64; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ __host__ T div_up(T m, T n) +{ + return (m + n - 1) / n; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline size_t smem_size_in_bytes(const Multihead_attention_params& params, + int threads_per_value, + int threads_per_block) +{ + // The amount of shared memory needed to store the Q*K^T values in float. + const int max_timesteps = min(params.timestep, params.memory_max_len); + size_t qk_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; + + // The extra memory needed if we are not using floats for the final logits. + size_t logits_sz = 0; +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if (sizeof(T) != 4) { + // TDOD + logits_sz = (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 4 * sizeof(T) : + div_up(max_timesteps + 1, 4) * 4 * sizeof(T); + } +#endif + + // The total size needed during softmax. + size_t softmax_sz = qk_sz + logits_sz; + + // The number of partial rows to reduce in the final reduction. + int rows_per_red = threads_per_block / threads_per_value; + // The amount of storage needed to finalize the outputs. + size_t red_sz = rows_per_red * params.hidden_size_per_head * sizeof(T) / 2; + + size_t transpose_rotary_size = 0; + if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { + transpose_rotary_size = 2 * params.rotary_embedding_dim * sizeof(T); + } + + // The max. + return max(max(softmax_sz, red_sz), transpose_rotary_size); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ constexpr uint32_t shfl_mask(int threads) +{ + return threads == 32 ? uint32_t(-1) : (1u << threads) - 1u; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // The type of the inputs. Supported types: float and half. + typename T, + // The hidden dimension per head. + int Dh, + int Dh_MAX, + // The number of threads per key. + int THREADS_PER_KEY, + // The number of threads per value. + int THREADS_PER_VALUE, + // The number of threads in a threadblock. + int THREADS_PER_BLOCK, + bool DO_CROSS_ATTENTION> +__global__ void masked_multihead_attention_kernel(Multihead_attention_params params) +{ + + // Make sure the hidden dimension per head is a multiple of the number of threads per key. + static_assert(Dh_MAX % THREADS_PER_KEY == 0, ""); + // Make sure the hidden dimension per head is a multiple of the number of threads per value. + static_assert(Dh_MAX % THREADS_PER_VALUE == 0, ""); + + // The size of a warp. + constexpr int WARP_SIZE = 32; + // The number of warps in a threadblock. + constexpr int WARPS_PER_BLOCK = THREADS_PER_BLOCK / WARP_SIZE; + + // Use smem_size_in_bytes (above) to determine the amount of shared memory. + extern __shared__ char smem_[]; + + // The shared memory for the Q*K^T values and partial logits in softmax. + float* qk_smem = reinterpret_cast(smem_); + + // The shared memory for the logits. For FP32, that's the same buffer as qk_smem. + char* logits_smem_ = smem_; +#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS + if (sizeof(T) != 4) { + // TODO - change to tlength + const int max_timesteps = min(params.timestep, params.memory_max_len); + logits_smem_ += + (DO_CROSS_ATTENTION) ? div_up(params.memory_max_len + 1, 4) * 16 : div_up(max_timesteps + 1, 4) * 16; + } + T* logits_smem = reinterpret_cast(logits_smem_); +#else + float* logits_smem = reinterpret_cast(logits_smem_); +#endif + + // The shared memory to do the final reduction for the output values. Reuse qk_smem. + T* out_smem = reinterpret_cast(smem_); + + // The shared memory buffers for the block-wide reductions. One for max, one for sum. + __shared__ float red_smem[WARPS_PER_BLOCK * 2]; + + // A vector of Q or K elements for the current timestep. + using Qk_vec = typename Qk_vec_::Type; + + // Use alignment for safely casting the shared buffers as Qk_vec. + // Shared memory to store Q inputs. + __shared__ __align__(sizeof(Qk_vec)) T q_smem[Dh_MAX]; + + // This is one of the reasons we should have a separate kernel for cross attention + __shared__ __align__(sizeof(Qk_vec)) T bias_smem[DO_CROSS_ATTENTION ? Dh_MAX : 1]; + + // A vector of Q or K elements for the current timestep. + using Qk_vec = typename Qk_vec_::Type; + // The number of elements per vector. + constexpr int QK_VEC_SIZE = sizeof(Qk_vec) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh_MAX % QK_VEC_SIZE == 0, ""); + // We will use block wide reduction if needed + // static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, ""); + // The number of vectors per warp. + constexpr int QK_VECS_PER_WARP = Dh_MAX / QK_VEC_SIZE; + + // The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread + // owns x elements, we have to decompose the linear index into chunks of x values and the posi- + // tion of the thread in that chunk. + + // The number of elements in a chunk of 16B (that's the x in the above formula). + constexpr int QK_ELTS_IN_16B = 16 / sizeof(T); + // The number of K vectors in 16B. + constexpr int QK_VECS_IN_16B = 16 / sizeof(Qk_vec); + + // The batch/beam idx + const int bi = blockIdx.y; + if (params.finished != nullptr && params.finished[bi] == true) { + return; + } + // The beam idx + const int beami = bi % params.beam_width; + // The "beam-aware" batch idx + const int bbi = bi / params.beam_width; + // The head. + const int num_kv_heads = params.num_kv_heads; + const int kv_rep = (params.num_heads / num_kv_heads); + const int hi = blockIdx.x; + const int hi_kv = hi / kv_rep; + + // Combine the batch and the head indices. + const int bhi = bi * params.num_heads + hi; + const int bhi_kv = bi * (params.num_heads / kv_rep) + hi_kv; + // Combine the "beam-aware" batch idx and the head indices. + const int bbhi = bbi * params.beam_width * params.num_heads + hi; + const int bbhi_kv = bbi * params.beam_width * (params.num_heads / kv_rep) + hi_kv; + // The thread in the block. + const int tidx = threadIdx.x; + + const bool handle_kv = !DO_CROSS_ATTENTION || (DO_CROSS_ATTENTION && params.timestep == 0); + // Every kv_rep threads have the same kv_cache values. So only the first one writes back. + const int write_kv_cache = handle_kv && (hi % kv_rep == 0); + + // While doing the product Q*K^T for the different keys we track the max. + float qk_max = -FLT_MAX; + + float qk = 0.0F; + + // int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh; + const int q_base_offset = bi * params.stride + hi * Dh; + const int k_base_offset = bi * params.stride + hi_kv * Dh; + const int v_base_offset = k_base_offset; + + const size_t bi_seq_len_offset = bi * params.memory_max_len; + + // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep; + int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 : + (params.length_per_sample == nullptr) ? + params.timestep : + params.length_per_sample[bi] + params.max_prefix_prompt_length; + const int first_step = max(0, tlength + 1 - params.memory_max_len); + const int tlength_circ = tlength % params.memory_max_len; + + // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep. + const bool is_masked = tidx >= QK_VECS_PER_WARP; + + // The offset in the Q and K buffer also accounts for the batch. + // int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE; + int q_offset = q_base_offset + tidx * QK_VEC_SIZE; + int k_offset = k_base_offset + tidx * QK_VEC_SIZE; + int v_offset = k_offset; + + // The offset in the bias buffer. + // int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; + int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE; + int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE; + int v_bias_offset = k_bias_offset; + + const bool do_ia3 = handle_kv && params.ia3_tasks != nullptr; + const int ia3_task_id = do_ia3 ? params.ia3_tasks[bbi] : 0; + + // Trigger the loads from the Q and K buffers. + Qk_vec q; + zero(q); + if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto q_scaling = params.qkv_scale_out[0]; + const auto q_quant = + *reinterpret_cast(&reinterpret_cast(params.q)[q_offset]); + + convert_from_float(q, mul(q_scaling, float_from_int8(q_quant))); + } + else { + q = *reinterpret_cast(¶ms.q[q_offset]); + } + } + + Qk_vec k; + zero(k); + if (DO_CROSS_ATTENTION) { + // The 16B chunk written by the thread. + int co = tidx / QK_VECS_IN_16B; + // The position of the thread in that 16B chunk. + int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; + + // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. + int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + + // params.timestep*QK_ELTS_IN_16B + + tlength * QK_ELTS_IN_16B + ci; + k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ? + *reinterpret_cast(¶ms.k_cache[offset]) : + k; + } + else { + if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) { + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto k_scaling = params.qkv_scale_out[1]; + const auto k_quant = + *reinterpret_cast(&reinterpret_cast(params.k)[k_offset]); + + convert_from_float(k, mul(k_scaling, float_from_int8(k_quant))); + } + else { + k = *reinterpret_cast(¶ms.k[k_offset]); + } + } + } + + // Trigger the loads from the Q and K bias buffers. + Qk_vec q_bias; + zero(q_bias); + q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ? + *reinterpret_cast(¶ms.q_bias[q_bias_offset]) : + q_bias; + + Qk_vec k_bias; + zero(k_bias); + if (handle_kv) { + k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ? + *reinterpret_cast(¶ms.k_bias[k_bias_offset]) : + k_bias; + } + + // Computes the Q/K values with bias. + q = add(q, q_bias); + if (handle_kv) { + k = add(k, k_bias); + } + if (do_ia3 && !is_masked) { + k = mul( + k, + *reinterpret_cast( + ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE])); + } + + // Padded len + const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi]; + if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) { + if (handle_kv) { + apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); + } + else { + apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); + } + } + else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) { + const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim; + + T* q_smem = reinterpret_cast(smem_); + T* k_smem = q_smem + params.rotary_embedding_dim; + + const int half_rotary_dim = params.rotary_embedding_dim / 2; + const int half_idx = (tidx * QK_VEC_SIZE) / half_rotary_dim; + const int intra_half_idx = (tidx * QK_VEC_SIZE) % half_rotary_dim; + const int smem_pitch = half_rotary_dim; // TODO: adjust for bank conflicts + + assert(half_rotary_dim % QK_VEC_SIZE == 0); + + if (do_rotary) { + *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx) = q; + + if (handle_kv) { + *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx) = k; + } + } + + __syncthreads(); + + const int transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2; + constexpr int tidx_factor = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1; + if (do_rotary) { + mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch); + + if (handle_kv) { + mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + + mmha::apply_rotary_embedding( + q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base); + + mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch); + } + else { + mmha::apply_rotary_embedding( + q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base); + } + mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch); + } + + __syncthreads(); + + if (do_rotary) { + q = *reinterpret_cast(q_smem + half_idx * smem_pitch + intra_half_idx); + if (handle_kv) { + k = *reinterpret_cast(k_smem + half_idx * smem_pitch + intra_half_idx); + } + } + + __syncthreads(); + } + + if (!is_masked) { + // Store the Q values to shared memory. + *reinterpret_cast(&q_smem[tidx * QK_VEC_SIZE]) = q; + + // Store Dh values of k_bias into smem, since will need to add later + // if params.timestep == 0 + if (DO_CROSS_ATTENTION && params.timestep == 0) { + *reinterpret_cast(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias; + } + + // Write the K values to the global memory cache. + // + // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory + // system. We designed it this way as it allows much better memory loads (and there are many + // more loads) + the stores are really "write and forget" since we won't need the ack before + // the end of the kernel. There's plenty of time for the transactions to complete. + + // The 16B chunk written by the thread. + int co = tidx / QK_VECS_IN_16B; + // The position of the thread in that 16B chunk. + int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE; + + // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements. + int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B + + // params.timestep*QK_ELTS_IN_16B + + tlength_circ * QK_ELTS_IN_16B + ci; + + if (write_kv_cache) { + // Trigger the stores to global memory. + if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) { + *reinterpret_cast(¶ms.k_cache[offset]) = k; + } + } + + // Compute \sum_i Q[i] * K^T[i] for the current timestep. +#ifdef MMHA_USE_FP32_ACUM_FOR_FMA + using Qk_vec_acum = typename Qk_vec_acum_fp32_::Type; +#else + using Qk_vec_acum = Qk_vec; +#endif + qk = dot(q, k); + if (QK_VECS_PER_WARP <= WARP_SIZE) { +#pragma unroll + for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) { + qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask); + } + } + } + + if (QK_VECS_PER_WARP > WARP_SIZE) { + constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE; + qk = block_sum(&red_smem[WARPS_PER_RED], qk); + } + + // Store that value in shared memory. Keep the Q*K^T value in register for softmax. + if (tidx == 0) { + // Normalize qk. + qk *= params.inv_sqrt_dh; + if (params.relative_attention_bias != nullptr) { + // TODO (Haotian): check whether we should replace hi with hi_kv, + // although params.relative_attention_bias is usually not used. + qk = add(qk, + params.relative_attention_bias[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + (tlength - padd_len) * params.relative_attention_bias_stride + + (tlength - padd_len)]); + } + // Add alibi positional encoding + // qk += (alibi_slope != 0) ? alibi_slope * (params.timestep - params.memory_max_len) : 0; + // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0. + + qk_max = qk; + qk_smem[tlength - first_step] = qk; + // qk_smem[params.timestep] = qk; + } + + // Make sure the data is in shared memory. + __syncthreads(); + + // The type of queries and keys for the math in the Q*K^T product. + using K_vec = typename K_vec_::Type; + // The number of elements per vector. + constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T); + // Make sure the hidden size per head is a multiple of the vector size. + static_assert(Dh_MAX % K_VEC_SIZE == 0, ""); + // The number of elements per thread. + constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY; + // The number of vectors per thread. + constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE; + + // The position the first key loaded by each thread from the cache buffer (for this B * H). + int ko = tidx / THREADS_PER_KEY; + // The position of the thread in the chunk of keys. + int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE; + + static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD); + + // Load the Q values from shared memory. The values are reused during the loop on K. + K_vec q_vec[K_VECS_PER_THREAD]; +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + q_vec[ii] = *reinterpret_cast(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); + } + + K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1]; + if (DO_CROSS_ATTENTION && params.timestep == 0) { +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + k_bias_vec[ii] = *reinterpret_cast(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]); + } + } + + // The number of timesteps loaded per iteration. + constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY; + // The number of keys per warp. + constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY; + + // The base pointer for the key in the cache buffer. + T* k_cache = ¶ms.k_cache[bhi_kv * params.memory_max_len * Dh + ki]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + T* k_cache_batch = ¶ms.k_cache[bbhi_kv * params.memory_max_len * Dh + ki]; + + // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync). + // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP; + int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step; + + // prefix prompt length if has + const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi]; + + // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values. + const bool has_beams = params.cache_indir != nullptr; + const int* beam_indices = has_beams ? ¶ms.cache_indir[bi_seq_len_offset] : nullptr; + + for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) { + const int ti_circ = ti % params.memory_max_len; + + // The keys loaded from the key cache. + K_vec k[K_VECS_PER_THREAD]; + K_vec k_vec_zero; + zero(k_vec_zero); +#pragma unroll + for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) { + int jj = ii * params.memory_max_len + ti_circ; + // if( ti < params.timestep ) { + const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len); + if (ti < tlength) { + if (!within_bounds) { + k[ii] = k_vec_zero; + } + else { + if (has_beams) { + const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh; + k[ii] = *reinterpret_cast(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]); + } + else { + k[ii] = *reinterpret_cast(&k_cache_batch[jj * QK_ELTS_IN_16B]); + } + } + // add bias and update k_cache + if (DO_CROSS_ATTENTION && params.timestep == 0) { + k[ii] = add(k[ii], k_bias_vec[ii]); + + if (do_ia3) { + k[ii] = mul( + k[ii], + *reinterpret_cast( + ¶ms.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki + + ii * THREADS_PER_KEY * K_VEC_SIZE])); + } + + if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) { + *reinterpret_cast(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii]; + } + } + } + } + + // Perform the dot product and normalize qk. + // + // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!! + float qk = Qk_dot::dot(q_vec, k) * params.inv_sqrt_dh; + bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; + + // Store the product to shared memory. There's one qk value per timestep. Update the max. + // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) { + if (ti < tlength && tidx % THREADS_PER_KEY == 0) { + if (params.relative_attention_bias != nullptr) { + qk = add(qk, + params.relative_attention_bias[hi * params.relative_attention_bias_stride + * params.relative_attention_bias_stride + + tlength * params.relative_attention_bias_stride + ti]); + } + if (params.linear_bias_slopes != nullptr) { + // Apply the linear position bias: (ki - qi) * slope[hi]. + // The padding token locates between the input context and the generated tokens. + // We need to remove the number of padding tokens in the distance computation. + // ti : 0 1 2 3 4 5 6 7 8 9(tlength) + // token: i i i i p p p o o o where i=input, p=pad, o=output. + // e.g. ti = 2, dist = (9 - 3) - 2 = 4. + int max_context_length = params.max_prefix_prompt_length + params.max_input_length; + float dist = (ti < max_context_length ? ti + padd_len : ti) - tlength; + + qk += mul(params.linear_bias_slopes[hi], dist); + } + // Add alibi positional encoding + // qk += (alibi_slope != 0) ? alibi_slope * (params.timestep - params.memory_max_len) : 0; + qk_max = is_mask ? qk_max : fmaxf(qk_max, qk); + qk_smem[ti - first_step] = qk; + } + } + +// Perform the final reduction to compute the max inside each warp. +// +// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the +// group so it's not needed to run the reduction inside the group (again). +#pragma unroll + for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Decompose the thread index into warp and lane. + const int warp = tidx / WARP_SIZE; + const int lane = tidx % WARP_SIZE; + + // The warp leader writes the max to shared memory. + if (lane == 0) { + red_smem[warp] = qk_max; + } + + // Make sure the products are in shared memory. + __syncthreads(); + + // The warps finalize the reduction. + qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX; +#pragma unroll + for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) { + qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask)); + } + + // Broadcast to all the threads in the warp. + qk_max = __shfl_sync(uint32_t(-1), qk_max, 0); + + // Compute the logits and start the sum. + float sum = 0.f; + // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { + for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti]; + float logit = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max); + sum += logit; + qk_smem[ti - first_step] = logit; + } + + // Compute the sum. + sum = block_sum(&red_smem[WARPS_PER_BLOCK], sum); + + // Normalize the logits. + float inv_sum = __fdividef(1.f, sum + 1.e-6f); + // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) { + const size_t cross_attention_out_offset = + params.is_return_cross_attentions ? + bhi_kv * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len : + 0; + for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) { + float logit = qk_smem[ti - first_step] * inv_sum; + if (params.is_return_cross_attentions) { + params.cross_attention_out[cross_attention_out_offset + ti] = logit; + } + convert_from_float(logits_smem[ti - first_step], logit); + } + + // Put Values part below so we leverage __syncthreads + // from the previous step + + // The number of elements per vector. + constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE; + // A vector of V elements for the current timestep. + using V_vec = typename V_vec_::Type; + + // The value computed by this thread. + int vo = tidx / THREADS_PER_VALUE; + // The hidden dimensions computed by this particular thread. + int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE; + + // The base pointer for the value in the cache buffer. + T* v_cache = ¶ms.v_cache[bhi_kv * params.memory_max_len * Dh + vi]; + // Base pointer for the beam's batch, before offsetting with indirection buffer + T* v_cache_batch = ¶ms.v_cache[bbhi_kv * params.memory_max_len * Dh + vi]; + + // The number of values processed per iteration of the loop. + constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE; + + // One group of threads computes the product(s) for the current timestep. + V_vec v_bias; + zero(v_bias); + // if( vo == params.timestep % V_PER_ITER ) { + if (Dh == Dh_MAX || vi < Dh) { + if (handle_kv) { + if (vo == tlength % V_PER_ITER) { + // Trigger the loads from the V bias buffer. + if (params.v_bias != nullptr) { + v_bias = *reinterpret_cast(¶ms.v_bias[hi_kv * Dh + vi]); + } + if (DO_CROSS_ATTENTION) { + *reinterpret_cast(&bias_smem[vi]) = v_bias; + } + } + } + } + + // From previous, before values, step + // Also make sure the logits are in shared memory. + __syncthreads(); + + // Values continued +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + using V_vec_acum = typename V_vec_acum_fp32_::Type; +#else + using V_vec_acum = V_vec; +#endif + // The partial outputs computed by each thread. + V_vec_acum out; + zero(out); + + // Loop over the timesteps to compute the partial outputs. + // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) { + if (Dh == Dh_MAX || vi < Dh) { + for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) { + const int ti_circ = ti % params.memory_max_len; + + // Fetch offset based on cache_indir when beam sampling + const int beam_src = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0; + const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh; + // Load the values from the cache. + V_vec v = *reinterpret_cast(&v_cache_batch[beam_offset + ti_circ * Dh]); + if (DO_CROSS_ATTENTION && params.timestep == 0) { + v = add(v, *reinterpret_cast(&bias_smem[vi])); + if (do_ia3) { + v = mul( + v, + *reinterpret_cast( + ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); + } + *reinterpret_cast(&v_cache[ti * Dh]) = v; + } + // Load the logits from shared memory. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + float logit = logits_smem[ti - first_step]; + out = fma(logit, cast_to_float(v), out); +#else + T logit = logits_smem[ti - first_step]; + + // Update the partial sums. + out = fma(logit, v, out); +#endif + } + } + + // One group of threads computes the product(s) for the current timestep. + // if( vo == params.timestep % V_PER_ITER ) { + if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) { + + V_vec v; + if (DO_CROSS_ATTENTION) { + v = *reinterpret_cast(&v_cache[tlength * Dh]); + } + else { + // Trigger the loads from the V buffer. + const auto v_offset = v_base_offset + vi; + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + using Packed_Float_t = typename packed_type::value>::type; + const auto v_scaling = params.qkv_scale_out[2]; + const auto v_quant = + *reinterpret_cast(&reinterpret_cast(params.v)[v_offset]); + + convert_from_float(v, mul(v_scaling, float_from_int8(v_quant))); + } + else { + v = *reinterpret_cast(¶ms.v[v_offset]); + } + // Trigger the loads from the V bias buffer. + // V_vec v_bias = *reinterpret_cast(¶ms.v_bias[hi*Dh + vi]); + } + + // Compute the V values with bias. + v = add(v, v_bias); + if (write_kv_cache) { + + if (do_ia3) { + v = mul( + v, + *reinterpret_cast( + ¶ms.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi])); + } + + // Store the values with bias back to global memory in the cache for V. + //*reinterpret_cast(&v_cache[params.timestep*Dh]) = v; + *reinterpret_cast(&v_cache[tlength_circ * Dh]) = v; + } + + // Initialize the output value with the current timestep. +#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS) + // out = fma(logits_smem[params.timestep], cast_to_float(v), out); + out = fma(logits_smem[tlength - first_step], cast_to_float(v), out); +#else + // out = fma(logits_smem[params.timestep], v, out); + out = fma(logits_smem[tlength - first_step], v, out); +#endif + } + + // Make sure we can start writing to shared memory. + __syncthreads(); + + // Run the final reduction amongst the different groups computing different partial outputs. + if (Dh == Dh_MAX || vi < Dh) { +#pragma unroll + for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) { + + // The midpoint in the number of active groups. + int midpoint = active_groups / 2; + + // The upper part of active threads store to shared memory. + if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + convert_from_float(*reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]), out); +#else + *reinterpret_cast(&out_smem[(vo - midpoint) * Dh + vi]) = out; +#endif + } + __syncthreads(); + + // The bottom warps update their values. + if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) { + out = add(*reinterpret_cast(&out_smem[vo * Dh + vi]), out); + } + __syncthreads(); + } + } + + // Output the final values. + if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) { +#ifdef MMHA_USE_FP32_ACUM_FOR_OUT + if (params.int8_mode == 2) { + using Packed_Int8_t = typename packed_type::value>::type; + out = mul(*params.attention_out_scale, out); + *reinterpret_cast(&(reinterpret_cast(params.out)[bhi * Dh + vi])) = + cast_to_int8(out); + } + else { + convert_from_float(*reinterpret_cast(¶ms.out[bhi * Dh + vi]), out); + } +#else + // TODO: support int8_mode? + *reinterpret_cast(¶ms.out[bhi * Dh + vi]) = out; +#endif + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace mmha + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream); diff --git a/awq/kernels/csrc/attention/decoder_masked_multihead_attention_utils.h b/awq/kernels/csrc/attention/decoder_masked_multihead_attention_utils.h new file mode 100644 index 0000000..fe46257 --- /dev/null +++ b/awq/kernels/csrc/attention/decoder_masked_multihead_attention_utils.h @@ -0,0 +1,1786 @@ +// Downloaded from from FasterTransformer v5.2.1 +// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h +/* + * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include "cuda_bf16_wrapper.h" +#include "cuda_bf16_fallbacks.cuh" +#include + +using namespace fastertransformer; + +namespace mmha { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Float8_ { + float2 x; + float2 y; + float2 z; + float2 w; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct Float4_ { + float2 x; + float2 y; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +struct bf16_4_t { + __nv_bfloat162 x; + __nv_bfloat162 y; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct bf16_8_t { + __nv_bfloat162 x; + __nv_bfloat162 y; + __nv_bfloat162 z; + __nv_bfloat162 w; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct num_elems; +template<> +struct num_elems { + static constexpr int value = 1; +}; +template<> +struct num_elems { + static constexpr int value = 2; +}; +template<> +struct num_elems { + static constexpr int value = 4; +}; +template<> +struct num_elems { + static constexpr int value = 4; +}; +template<> +struct num_elems { + static constexpr int value = 8; +}; + +template<> +struct num_elems { + static constexpr int value = 2; +}; +template<> +struct num_elems { + static constexpr int value = 4; +}; +template<> +struct num_elems { + static constexpr int value = 8; +}; + +#ifdef ENABLE_BF16 +template<> +struct num_elems<__nv_bfloat162> { + static constexpr int value = 2; +}; +template<> +struct num_elems { + static constexpr int value = 4; +}; +template<> +struct num_elems { + static constexpr int value = 8; +}; +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct packed_type; +template +struct packed_type { + using type = T; +}; +template<> +struct packed_type { + using type = int16_t; +}; +template<> +struct packed_type { + using type = int32_t; +}; +template<> +struct packed_type { + using type = int64_t; +}; + +template<> +struct packed_type { + using type = float2; +}; +template<> +struct packed_type { + using type = float4; +}; +template<> +struct packed_type { + using type = Float8_; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float add(float a, float b) +{ + return a + b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 add(float2 a, float2 b) +{ + float2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 add(float4 a, float4 b) +{ + float4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ __nv_bfloat16 add(__nv_bfloat16 a, __nv_bfloat16 b) +{ + return a + b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ __nv_bfloat162 add(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hadd2(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_4_t add(bf16_4_t a, bf16_4_t b) +{ + bf16_4_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_8_t add(bf16_8_t a, bf16_8_t b) +{ + bf16_8_t c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} +#endif // ENABLE_BF16 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint16_t add(uint16_t a, uint16_t b) +{ + uint16_t c; + asm volatile("add.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t add(uint32_t a, uint32_t b) +{ + uint32_t c; + asm volatile("add.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint2 add(uint2 a, uint2 b) +{ + uint2 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint4 add(uint4 a, uint4 b) +{ + uint4 c; + c.x = add(a.x, b.x); + c.y = add(a.y, b.y); + c.z = add(a.z, b.z); + c.w = add(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint16_t float_to_half(float f) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better? + float zero = 0.f; + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f)); +#endif + return tmp.u16[0]; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t float2_to_half2(float2 f) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(f.y), "f"(f.x)); +#else + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[0]) : "f"(f.x)); + asm volatile("cvt.rn.f16.f32 %0, %1;\n" : "=h"(tmp.u16[1]) : "f"(f.y)); +#endif + return tmp.u32; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float half_to_float(uint16_t h) +{ + float f; + asm volatile("cvt.f32.f16 %0, %1;\n" : "=f"(f) : "h"(h)); + return f; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 half2_to_float2(uint32_t v) +{ + uint16_t lo, hi; + asm volatile("mov.b32 {%0, %1}, %2;\n" : "=h"(lo), "=h"(hi) : "r"(v)); + return make_float2(half_to_float(lo), half_to_float(hi)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float add(float a, uint16_t b) +{ + return a + half_to_float(b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ float add(float a, __nv_bfloat16 b) +{ + return a + __bfloat162float(b); +} +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 add(uint32_t a, float2 fb) +{ + float2 fa = half2_to_float2(a); + return add(fa, fb); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ add(uint2 a, Float4_ fb) +{ + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ add(uint4 a, Float8_ fb) +{ + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t h0_h0(uint16_t a) +{ + uint32_t b; + asm volatile("mov.b32 %0, {%1, %1};" : "=r"(b) : "h"(a)); + return b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float fma(float a, float b, float c) +{ + return a * b + c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(float2 a, float2 b, float2 c) +{ + float2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(float a, float2 b, float2 c) +{ + float2 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 fma(float4 a, float4 b, float4 c) +{ + float4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float4 fma(float a, float4 b, float4 c) +{ + float4 d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(float a, Float4_ b, Float4_ c) +{ + Float4_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(float a, Float8_ b, Float8_ c) +{ + Float8_ d; + d.x = fma(a, b.x, c.x); + d.y = fma(a, b.y, c.y); + d.z = fma(a, b.z, c.z); + d.w = fma(a, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ float2 add(__nv_bfloat162 a, float2 fb) +{ + float2 fa = bf1622float2(a); + return add(fa, fb); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ add(bf16_4_t a, Float4_ fb) +{ + Float4_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ add(bf16_8_t a, Float8_ fb) +{ + Float8_ fc; + fc.x = add(a.x, fb.x); + fc.y = add(a.y, fb.y); + fc.z = add(a.z, fb.z); + fc.w = add(a.w, fb.w); + return fc; +} +#endif // ENABLE_BF16 + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t fma(uint32_t a, uint32_t b, uint32_t c) +{ + uint32_t d; + asm volatile("fma.rn.f16x2 %0, %1, %2, %3;\n" : "=r"(d) : "r"(a), "r"(b), "r"(c)); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint32_t fma(uint16_t a, uint32_t b, uint32_t c) +{ + return fma(h0_h0(a), b, c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint2 fma(uint2 a, uint2 b, uint2 c) +{ + uint2 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint2 fma(uint16_t a, uint2 b, uint2 c) +{ + uint32_t s = h0_h0(a); + uint2 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint4 fma(uint4 a, uint4 b, uint4 c) +{ + uint4 d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ uint4 fma(uint16_t a, uint4 b, uint4 c) +{ + uint32_t s = h0_h0(a); + uint4 d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float fma(uint16_t a, uint16_t b, float fc) +{ + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb + fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(uint32_t a, uint32_t b, float2 fc) +{ + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return fma(fa, fb, fc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(uint16_t a, uint32_t b, float2 fc) +{ + return fma(h0_h0(a), b, fc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(uint2 a, uint2 b, Float4_ fc) +{ + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(uint16_t a, uint2 b, Float4_ fc) +{ + uint32_t s = h0_h0(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(uint4 a, uint4 b, Float8_ fc) +{ + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(uint16_t a, uint4 b, Float8_ fc) +{ + uint32_t s = h0_h0(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +#ifdef ENABLE_BF16 +inline __device__ __nv_bfloat162 fma(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hfma2(a, b, c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ __nv_bfloat162 fma(__nv_bfloat16 a, __nv_bfloat162 b, __nv_bfloat162 c) +{ + return bf16hfma2(bf162bf162(a), b, c); +} +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_4_t fma(bf16_4_t a, bf16_4_t b, bf16_4_t c) +{ + bf16_4_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_4_t fma(__nv_bfloat16 a, bf16_4_t b, bf16_4_t c) +{ + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_8_t fma(bf16_8_t a, bf16_8_t b, bf16_8_t c) +{ + bf16_8_t d; + d.x = fma(a.x, b.x, c.x); + d.y = fma(a.y, b.y, c.y); + d.z = fma(a.z, b.z, c.z); + d.w = fma(a.w, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ bf16_8_t fma(__nv_bfloat16 a, bf16_8_t b, bf16_8_t c) +{ + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t d; + d.x = fma(s, b.x, c.x); + d.y = fma(s, b.y, c.y); + d.z = fma(s, b.z, c.z); + d.w = fma(s, b.w, c.w); + return d; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float fma(__nv_bfloat16 a, __nv_bfloat16 b, float fc) +{ + return __bfloat162float(a) * __bfloat162float(b) + fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(__nv_bfloat162 a, __nv_bfloat162 b, float2 fc) +{ + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return fma(fa, fb, fc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 fma(__nv_bfloat16 a, __nv_bfloat162 b, float2 fc) +{ + return fma(bf162bf162(a), b, fc); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(bf16_4_t a, bf16_4_t b, Float4_ fc) +{ + Float4_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float4_ fma(__nv_bfloat16 a, bf16_4_t b, Float4_ fc) +{ + __nv_bfloat162 s = bf162bf162(a); + Float4_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(bf16_8_t a, bf16_8_t b, Float8_ fc) +{ + Float8_ fd; + fd.x = fma(a.x, b.x, fc.x); + fd.y = fma(a.y, b.y, fc.y); + fd.z = fma(a.z, b.z, fc.z); + fd.w = fma(a.w, b.w, fc.w); + return fd; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ Float8_ fma(__nv_bfloat16 a, bf16_8_t b, Float8_ fc) +{ + __nv_bfloat162 s = bf162bf162(a); + Float8_ fd; + fd.x = fma(s, b.x, fc.x); + fd.y = fma(s, b.y, fc.y); + fd.z = fma(s, b.z, fc.z); + fd.w = fma(s, b.w, fc.w); + return fd; +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ Acc mul(A a, B b) +{ + return a * b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float mul(float a, float b) +{ + return a * b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(float2 a, float2 b) +{ + float2 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(float a, float2 b) +{ + float2 c; + c.x = a * b.x; + c.y = a * b.y; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float4 mul(float4 a, float4 b) +{ + float4 c; + c.x = a.x * b.x; + c.y = a.y * b.y; + c.z = a.z * b.z; + c.w = a.w * b.w; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float4 mul(float a, float4 b) +{ + float4 c; + c.x = a * b.x; + c.y = a * b.y; + c.z = a * b.z; + c.w = a * b.w; + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(float a, Float8_ b) +{ + Float8_ c; + c.x = make_float2(a * b.x.x, a * b.x.y); + c.y = make_float2(a * b.y.x, a * b.y.y); + c.z = make_float2(a * b.z.x, a * b.z.y); + c.w = make_float2(a * b.w.x, a * b.w.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint16_t mul(uint16_t a, uint16_t b) +{ + uint16_t c; + asm volatile("mul.f16 %0, %1, %2;\n" : "=h"(c) : "h"(a), "h"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint32_t mul(uint32_t a, uint32_t b) +{ + uint32_t c; + asm volatile("mul.f16x2 %0, %1, %2;\n" : "=r"(c) : "r"(a), "r"(b)); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint32_t mul(uint16_t a, uint32_t b) +{ + return mul(h0_h0(a), b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint2 mul(uint2 a, uint2 b) +{ + uint2 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint2 mul(uint16_t a, uint2 b) +{ + uint32_t s = h0_h0(a); + uint2 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint4 mul(uint4 a, uint4 b) +{ + uint4 c; + c.x = mul(a.x, b.x); + c.y = mul(a.y, b.y); + c.z = mul(a.z, b.z); + c.w = mul(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ uint4 mul(uint16_t a, uint4 b) +{ + uint32_t s = h0_h0(a); + uint4 c; + c.x = mul(s, b.x); + c.y = mul(s, b.y); + c.z = mul(s, b.z); + c.w = mul(s, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float mul(uint16_t a, uint16_t b) +{ + float fa = half_to_float(a); + float fb = half_to_float(b); + return fa * fb; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float mul(uint16_t a, float b) +{ + return half_to_float(a) * b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(uint32_t a, uint32_t b) +{ + float2 fa = half2_to_float2(a); + float2 fb = half2_to_float2(b); + return mul(fa, fb); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(uint16_t a, uint32_t b) +{ + return mul(h0_h0(a), b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float4_ mul(uint2 a, uint2 b) +{ + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float4_ mul(uint16_t a, uint2 b) +{ + uint32_t s = h0_h0(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(uint4 a, uint4 b) +{ + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(uint16_t a, uint4 b) +{ + uint32_t s = h0_h0(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +template<> +inline __device__ __nv_bfloat16 mul(__nv_bfloat16 a, __nv_bfloat16 b) +{ +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 + return __hmul(a, b); +#else + return bf16hmul(a, b); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ __nv_bfloat162 mul(__nv_bfloat162 a, __nv_bfloat162 b) +{ + return bf16hmul2(a, b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ __nv_bfloat162 mul(__nv_bfloat16 a, __nv_bfloat162 b) +{ + return mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ bf16_4_t mul(bf16_4_t a, bf16_4_t b) +{ + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ bf16_4_t mul(__nv_bfloat16 a, bf16_4_t b) +{ + __nv_bfloat162 s = bf162bf162(a); + bf16_4_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ bf16_8_t mul(bf16_8_t a, bf16_8_t b) +{ + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.x, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.y, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) +{ + __nv_bfloat162 s = bf162bf162(a); + bf16_8_t c; + c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x); + c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y); + c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z); + c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w); + return c; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) +{ + float fa = (float)a; + float fb = (float)b; + return fa * fb; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float mul(__nv_bfloat16 a, float b) +{ + return __bfloat162float(a) * b; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) +{ + float2 fa = bf1622float2(a); + float2 fb = bf1622float2(b); + return mul(fa, fb); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) +{ + return mul(bf162bf162(a), b); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) +{ + Float4_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) +{ + __nv_bfloat162 s = bf162bf162(a); + Float4_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) +{ + Float8_ fc; + fc.x = mul(a.x, b.x); + fc.y = mul(a.y, b.y); + fc.z = mul(a.z, b.z); + fc.w = mul(a.w, b.w); + return fc; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template<> +inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) +{ + __nv_bfloat162 s = bf162bf162(a); + Float8_ fc; + fc.x = mul(s, b.x); + fc.y = mul(s, b.y); + fc.z = mul(s, b.z); + fc.w = mul(s, b.w); + return fc; +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(float v) +{ + return v; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(float2 v) +{ + return v.x + v.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(float4 v) +{ + return v.x + v.y + v.z + v.w; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#ifdef ENABLE_BF16 +inline __device__ float sum(__nv_bfloat162 v) +{ + float2 vf = bf1622float2(v); + return vf.x + vf.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(bf16_4_t v) +{ + return sum(v.x) + sum(v.y); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(bf16_8_t v) +{ + return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); +} +#endif // ENABLE_BF16 +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(uint16_t v) +{ + return half_to_float(v); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(uint32_t v) +{ + float2 tmp = half2_to_float2(v); + return tmp.x + tmp.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(uint2 v) +{ + uint32_t c = add(v.x, v.y); + return sum(c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(uint4 v) +{ +#if 1 + uint32_t c = add(v.x, v.y); + c = add(c, v.z); + c = add(c, v.w); +#else + uint32_t c = add(v.x, v.y); + uint32_t d = add(v.z, v.w); + c = add(c, d); +#endif + return sum(c); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(Float4_ v) +{ + return v.x.x + v.x.y + v.y.x + v.y.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float sum(Float8_ v) +{ + return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float dot(T a, T b) +{ + return sum(mul(a, b)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ float dot(T a, T b) +{ + return sum(mul(a, b)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ void zero(uint16_t& dst) +{ + dst = uint16_t(0); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +inline __device__ void zero(T& dst) +{ + constexpr int WORDS = sizeof(T) / 4; + union { + T raw; + uint32_t words[WORDS]; + } tmp; +#pragma unroll + for (int ii = 0; ii < WORDS; ++ii) { + tmp.words[ii] = 0u; + } + dst = tmp.raw; +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step, const float base) +{ + const float inv_freq = t_step / pow(base, zid / (float)rot_embed_dim); + return {cos(inv_freq), sin(inv_freq)}; +} + +inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef) +{ + float2 rot_v; + rot_v.x = coef.x * v.x - coef.y * v.y; + rot_v.y = coef.x * v.y + coef.y * v.x; + return rot_v; +} + +inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef) +{ + float2 fv = half2_to_float2(v); + float2 rot_fv = rotary_embedding_transform(fv, coef); + return float2_to_half2(rot_fv); +} + +#ifdef ENABLE_BF16 +inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef) +{ + float2 fv = bf1622float2(v); + float2 rot_fv = rotary_embedding_transform(fv, coef); + return __floats2bfloat162_rn(rot_fv.x, rot_fv.y); +} +#endif + +inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + return; +} + +inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + return; +} + +inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); + q = rotary_embedding_transform(q, coef); +} + +inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); + q = rotary_embedding_transform(q, coef); + k = rotary_embedding_transform(k, coef); +} + +inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + + Float4_& q_ = *reinterpret_cast(&q); + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); + q_.x = rotary_embedding_transform(q_.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); + q_.y = rotary_embedding_transform(q_.y, coef1); +} + +inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + + Float4_& q_ = *reinterpret_cast(&q); + Float4_& k_ = *reinterpret_cast(&k); + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); + q_.x = rotary_embedding_transform(q_.x, coef0); + k_.x = rotary_embedding_transform(k_.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); + q_.y = rotary_embedding_transform(q_.y, coef1); + k_.y = rotary_embedding_transform(k_.y, coef1); +} + +inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); + q = rotary_embedding_transform(q, coef); +} + +inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); + q = rotary_embedding_transform(q, coef); + k = rotary_embedding_transform(k, coef); +} + +inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); + q.y = rotary_embedding_transform(q.y, coef1); +} + +inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); +} + +inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (8 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); + q.y = rotary_embedding_transform(q.y, coef1); + const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); + q.z = rotary_embedding_transform(q.z, coef2); + const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); + q.w = rotary_embedding_transform(q.w, coef3); +} + +inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (8 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); + const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); + q.z = rotary_embedding_transform(q.z, coef2); + k.z = rotary_embedding_transform(k.z, coef2); + const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); + q.w = rotary_embedding_transform(q.w, coef3); + k.w = rotary_embedding_transform(k.w, coef3); +} + +#ifdef ENABLE_BF16 +inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); + q = rotary_embedding_transform(q, coef); +} + +inline __device__ void +apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (2 * tid >= rot_embed_dim) { + return; + } + const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base); + q = rotary_embedding_transform(q, coef); + k = rotary_embedding_transform(k, coef); +} + +inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); + q.y = rotary_embedding_transform(q.y, coef1); +} + +inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (4 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); +} + +inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (8 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); + q.x = rotary_embedding_transform(q.x, coef0); + const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); + q.y = rotary_embedding_transform(q.y, coef1); + const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); + q.z = rotary_embedding_transform(q.z, coef2); + const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); + q.w = rotary_embedding_transform(q.w, coef3); +} + +inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base=10000.0f) +{ + if (8 * tid >= rot_embed_dim) { + return; + } + const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base); + q.x = rotary_embedding_transform(q.x, coef0); + k.x = rotary_embedding_transform(k.x, coef0); + const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base); + q.y = rotary_embedding_transform(q.y, coef1); + k.y = rotary_embedding_transform(k.y, coef1); + const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base); + q.z = rotary_embedding_transform(q.z, coef2); + k.z = rotary_embedding_transform(k.z, coef2); + const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base); + q.w = rotary_embedding_transform(q.w, coef3); + k.w = rotary_embedding_transform(k.w, coef3); +} +#endif // ENABLE_BF16 + +template +__device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch); + +template<> +__device__ __inline__ void vec_from_smem_transpose(float& vec, float* smem, int transpose_idx, int smem_pitch) +{ + return; +} + +template<> +__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u16[0] = smem[transpose_idx]; + tmp.u16[1] = smem[smem_pitch + transpose_idx]; + + vec = tmp.u32; +} + +template<> +__device__ __inline__ void vec_from_smem_transpose(uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp_1, tmp_2; + tmp_1.u32 = *reinterpret_cast(&smem[transpose_idx]); + tmp_2.u32 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); + + union { + uint2 u32x2; + uint16_t u16[4]; + } tmp_3; + tmp_3.u16[0] = tmp_1.u16[0]; + tmp_3.u16[1] = tmp_2.u16[0]; + tmp_3.u16[2] = tmp_1.u16[1]; + tmp_3.u16[3] = tmp_2.u16[1]; + + vec = tmp_3.u32x2; +} + +template<> +__device__ __inline__ void vec_from_smem_transpose(uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint64_t u64; + uint16_t u16[4]; + } tmp_1, tmp_2; + tmp_1.u64 = *reinterpret_cast(&smem[transpose_idx]); + tmp_2.u64 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); + + union { + uint4 u32x4; + uint16_t u16[8]; + } tmp_3; + tmp_3.u16[0] = tmp_1.u16[0]; + tmp_3.u16[1] = tmp_2.u16[0]; + tmp_3.u16[2] = tmp_1.u16[1]; + tmp_3.u16[3] = tmp_2.u16[1]; + tmp_3.u16[4] = tmp_1.u16[2]; + tmp_3.u16[5] = tmp_2.u16[2]; + tmp_3.u16[6] = tmp_1.u16[3]; + tmp_3.u16[7] = tmp_2.u16[3]; + + vec = tmp_3.u32x4; +} + +#ifdef ENABLE_BF16 +template<> +__device__ __inline__ void +vec_from_smem_transpose(bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + __nv_bfloat16 bf16[2]; + } tmp_1, tmp_2; + tmp_1.u32 = *reinterpret_cast(&smem[transpose_idx]); + tmp_2.u32 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); + + vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]}; + vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]}; +} + +template<> +__device__ __inline__ void +vec_from_smem_transpose(bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + union { + uint64_t u64; + __nv_bfloat16 bf16[4]; + } tmp_1, tmp_2; + tmp_1.u64 = *reinterpret_cast(&smem[transpose_idx]); + tmp_2.u64 = *reinterpret_cast(&smem[smem_pitch + transpose_idx]); + + vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]}; + vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]}; + vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]}; + vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]}; +} +#endif // ENABLE_BF16 + +template<> +__device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch) +{ + vec.x = smem[transpose_idx]; + vec.z = smem[transpose_idx + 1]; + vec.y = smem[smem_pitch + transpose_idx]; + vec.w = smem[smem_pitch + transpose_idx + 1]; +} + +template<> +__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + half u16[2]; + } tmp; + tmp.u16[0] = smem[transpose_idx]; + tmp.u16[1] = smem[smem_pitch + transpose_idx]; + + vec = tmp.u32; +} + +#ifdef ENABLE_BF16 +template<> +__device__ __inline__ void +vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + vec.x = smem[transpose_idx]; + vec.y = smem[smem_pitch + transpose_idx]; +} +#endif + +template<> +__device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch) +{ + vec.x = smem[transpose_idx]; + vec.y = smem[smem_pitch + transpose_idx]; +} + +template +__device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch); + +template<> +__device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch) +{ + return; +} + +template<> +__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint64_t u64; + uint16_t u16[4]; + } tmp_1, tmp_2; + + union { + uint4 u32x4; + uint16_t u16[8]; + } tmp_3; + tmp_3.u32x4 = vec; + tmp_1.u16[0] = tmp_3.u16[0]; + tmp_2.u16[0] = tmp_3.u16[1]; + tmp_1.u16[1] = tmp_3.u16[2]; + tmp_2.u16[1] = tmp_3.u16[3]; + tmp_1.u16[2] = tmp_3.u16[4]; + tmp_2.u16[2] = tmp_3.u16[5]; + tmp_1.u16[3] = tmp_3.u16[6]; + tmp_2.u16[3] = tmp_3.u16[7]; + + *reinterpret_cast(&smem[transpose_idx]) = tmp_1.u64; + *reinterpret_cast(&smem[smem_pitch + transpose_idx]) = tmp_2.u64; +} + +template<> +__device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp_1, tmp_2; + + union { + uint2 u32x2; + uint16_t u16[4]; + } tmp_3; + tmp_3.u32x2 = vec; + tmp_1.u16[0] = tmp_3.u16[0]; + tmp_2.u16[0] = tmp_3.u16[1]; + tmp_1.u16[1] = tmp_3.u16[2]; + tmp_2.u16[1] = tmp_3.u16[3]; + + *reinterpret_cast(&smem[transpose_idx]) = tmp_1.u32; + *reinterpret_cast(&smem[smem_pitch + transpose_idx]) = tmp_2.u32; +} + +template<> +__device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + uint16_t u16[2]; + } tmp; + tmp.u32 = vec; + + smem[transpose_idx] = tmp.u16[0]; + smem[smem_pitch + transpose_idx] = tmp.u16[1]; +} + +template<> +__device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch) +{ + smem[transpose_idx] = vec.x; + smem[transpose_idx + 1] = vec.z; + smem[smem_pitch + transpose_idx] = vec.y; + smem[smem_pitch + transpose_idx + 1] = vec.w; +} + +template<> +__device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) +{ + union { + uint32_t u32; + half u16[2]; + } tmp; + + tmp.u32 = vec; + smem[transpose_idx] = tmp.u16[0]; + smem[smem_pitch + transpose_idx] = tmp.u16[1]; +} + +#ifdef ENABLE_BF16 +template<> +__device__ __inline__ void +write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + smem[transpose_idx] = vec.x; + smem[smem_pitch + transpose_idx] = vec.y; +} + +template<> +__device__ __inline__ void +write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + write_smem_transpose(reinterpret_cast(vec), reinterpret_cast(smem), transpose_idx, smem_pitch); +} + +template<> +__device__ __inline__ void +write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) +{ + write_smem_transpose(reinterpret_cast(vec), reinterpret_cast(smem), transpose_idx, smem_pitch); +} +#endif + +template<> +__device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch) +{ + smem[transpose_idx] = vec.x; + smem[smem_pitch + transpose_idx] = vec.y; +} + +} // namespace mmha diff --git a/awq/kernels/csrc/attention/ft_attention.cpp b/awq/kernels/csrc/attention/ft_attention.cpp new file mode 100644 index 0000000..bef83ec --- /dev/null +++ b/awq/kernels/csrc/attention/ft_attention.cpp @@ -0,0 +1,182 @@ +// Adapted from NVIDIA/FasterTransformer and FlashAttention + +#include +#include "ATen/cuda/CUDAContext.h" +#include + +#include "ft_attention.h" +#include "decoder_masked_multihead_attention.h" + +#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA") +#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")") +#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous") + +#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...) \ + if (TYPE == at::ScalarType::Half) { \ + using scalar_t = at::Half; \ + __VA_ARGS__(); \ + } else if (TYPE == at::ScalarType::BFloat16) { \ + using scalar_t = at::BFloat16; \ + __VA_ARGS__(); \ + } else if (TYPE == at::ScalarType::Float) { \ + using scalar_t = float; \ + __VA_ARGS__(); \ + } else { \ + AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \ + } + +template +void masked_multihead_attention(const Masked_multihead_attention_params& params, + const cudaStream_t& stream); + +template +void cross_multihead_attention(const Masked_multihead_attention_params& params, + const cudaStream_t& stream); + +template +struct SATypeConverter { + using Type = T; +}; + +template<> +struct SATypeConverter { + using Type = uint16_t; +}; + +template<> +struct SATypeConverter { + using Type = __nv_bfloat16; +}; + +template +void set_params(Masked_multihead_attention_params ¶ms, + const size_t batch_size, + const size_t nheads, + const size_t nheads_kv, + const size_t memory_max_seqlen, + const size_t headdim, + const int timestep, + const int rotary_embedding_dim, + const float rotary_base, + const bool neox_rotary_style, + const int qkv_batch_stride, + T *q_ptr, + T *k_ptr, + T *v_ptr, + T *k_cache_ptr, + T *v_cache_ptr, + int *length_per_sample, + float *alibi_slopes_ptr, + T *out_ptr) { + // Reset the parameters + memset(¶ms, 0, sizeof(params)); + params.q = q_ptr; + params.k = k_ptr; + params.v = v_ptr; + params.q_bias = nullptr; + params.k_bias = nullptr; + params.v_bias = nullptr; + params.k_cache = k_cache_ptr; + params.v_cache = v_cache_ptr; + params.linear_bias_slopes = alibi_slopes_ptr; + params.out = out_ptr; + params.cache_indir = nullptr; + params.stride = qkv_batch_stride; + params.batch_size = batch_size; + params.beam_width = 1; + params.memory_max_len = memory_max_seqlen; + params.num_heads = nheads; + params.num_kv_heads = nheads_kv; + params.hidden_size_per_head = headdim; + params.rotary_embedding_dim = rotary_embedding_dim; + params.rotary_base = rotary_base; + params.neox_rotary_style = neox_rotary_style; + params.timestep = timestep; + params.inv_sqrt_dh = 1.f / sqrt(float(headdim)); + params.total_padding_tokens = nullptr; + params.masked_tokens = nullptr; + params.prefix_prompt_lengths = nullptr; + params.max_prefix_prompt_length = 0; + params.relative_attention_bias = nullptr; + params.relative_attention_bias_stride = 0; + params.cross_attention_out = nullptr; + params.max_decoder_seq_len = 0; + params.is_return_cross_attentions = false; + params.finished = nullptr; + params.memory_length_per_sample = nullptr; + params.length_per_sample = length_per_sample; +} + +torch::Tensor single_query_attention(const torch::Tensor q, + const torch::Tensor k, + const torch::Tensor v, + torch::Tensor k_cache, + torch::Tensor v_cache, + c10::optional length_per_sample_, + c10::optional alibi_slopes_, + const int timestep, + const int rotary_embedding_dim, + const float rotary_base, + // neox_rotary_style = not interleaved + const bool neox_rotary_style) { + CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache); + int batch_size = v_cache.size(0); + int nheads = q.size(1); + int nheads_kv = v_cache.size(1); + int memory_max_seqlen = v_cache.size(2); + int headdim = v_cache.size(3); + CHECK_SHAPE(q, batch_size, nheads, headdim); + CHECK_SHAPE(k, batch_size, nheads_kv, headdim); + CHECK_SHAPE(v, batch_size, nheads_kv, headdim); + CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim); + // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32 + int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8; + CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize); + TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim); + TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim); + TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim); + // TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0)); + CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache); + + if (length_per_sample_.has_value()) { + auto length_per_sample = length_per_sample_.value(); + CHECK_DEVICE(length_per_sample); + CHECK_SHAPE(length_per_sample, batch_size); + CHECK_CONTIGUOUS(length_per_sample); + TORCH_CHECK(length_per_sample.dtype() == torch::kInt32); + } + + if (alibi_slopes_.has_value()) { + auto alibi_slopes = alibi_slopes_.value(); + CHECK_DEVICE(alibi_slopes); + CHECK_SHAPE(alibi_slopes, nheads); + CHECK_CONTIGUOUS(alibi_slopes); + TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32); + } + + // Otherwise the kernel will be launched from cuda:0 device + // Cast to char to avoid compiler warning about narrowing + at::cuda::CUDAGuard device_guard{(char)q.get_device()}; + + torch::Tensor out = torch::empty_like(q); + + DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] { + using DataType = typename SATypeConverter::Type; + Masked_multihead_attention_params params; + set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim, + timestep, rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0), + reinterpret_cast(q.data_ptr()), + reinterpret_cast(k.data_ptr()), + reinterpret_cast(v.data_ptr()), + reinterpret_cast(k_cache.data_ptr()), + reinterpret_cast(v_cache.data_ptr()), + length_per_sample_.has_value() + ? length_per_sample_.value().data_ptr() : nullptr, + alibi_slopes_.has_value() + ? alibi_slopes_.value().data_ptr(): nullptr, + reinterpret_cast(out.data_ptr())); + auto stream = at::cuda::getCurrentCUDAStream(); + masked_multihead_attention(params, stream); + }); + return out; +} \ No newline at end of file diff --git a/awq/kernels/csrc/attention/ft_attention.h b/awq/kernels/csrc/attention/ft_attention.h new file mode 100644 index 0000000..df67d53 --- /dev/null +++ b/awq/kernels/csrc/attention/ft_attention.h @@ -0,0 +1,15 @@ +#pragma once +#include + + +torch::Tensor single_query_attention(const torch::Tensor q, + const torch::Tensor k, + const torch::Tensor v, + torch::Tensor k_cache, + torch::Tensor v_cache, + c10::optional length_per_sample_, + c10::optional alibi_slopes_, + const int timestep, + const int rotary_embedding_dim = 0, + const float rotary_base = 10000.0f, + const bool neox_rotary_style=true); \ No newline at end of file diff --git a/awq/kernels/csrc/attention/setup.py b/awq/kernels/csrc/attention/setup.py new file mode 100644 index 0000000..dc479f2 --- /dev/null +++ b/awq/kernels/csrc/attention/setup.py @@ -0,0 +1,152 @@ +# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py +import sys +import warnings +import os +from packaging.version import parse, Version + +from setuptools import setup, find_packages +import subprocess + +import torch +from torch.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension, CUDA_HOME + + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_cuda_torch_binary_vs_bare_metal(cuda_dir): + raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) + torch_binary_version = parse(torch.version.cuda) + + print("\nCompiling cuda extensions with") + print(raw_output + "from " + cuda_dir + "/bin\n") + + if (bare_metal_version != torch_binary_version): + raise RuntimeError( + "Cuda extensions are being compiled with a version of Cuda that does " + "not match the version used to compile Pytorch binaries. " + "Pytorch binaries were compiled with Cuda {}.\n".format(torch.version.cuda) + + "In some cases, a minor-version mismatch will not cause later errors: " + "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " + "You can try commenting out this check (at your own risk)." + ) + + +def raise_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + raise RuntimeError( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pytorch/pytorch, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def append_nvcc_threads(nvcc_extra_args): + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.2"): + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args + + +if not torch.cuda.is_available(): + # https://github.com/NVIDIA/apex/issues/486 + # Extension builds after https://github.com/pytorch/pytorch/pull/23408 attempt to query torch.cuda.get_device_capability(), + # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). + print( + "\nWarning: Torch did not find available GPUs on this system.\n", + "If your intention is to cross-compile, this is not an error.\n" + "By default, Apex will cross-compile for Pascal (compute capabilities 6.0, 6.1, 6.2),\n" + "Volta (compute capability 7.0), Turing (compute capability 7.5),\n" + "and, if the CUDA version is >= 11.0, Ampere (compute capability 8.0).\n" + "If you wish to cross-compile for a single specific architecture,\n" + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', + ) + if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.8"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6;9.0" + elif bare_metal_version >= Version("11.1"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0;8.6" + elif bare_metal_version == Version("11.0"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5;8.0" + else: + os.environ["TORCH_CUDA_ARCH_LIST"] = "6.0;6.1;6.2;7.0;7.5" + + +print("\n\ntorch.__version__ = {}\n\n".format(torch.__version__)) +TORCH_MAJOR = int(torch.__version__.split(".")[0]) +TORCH_MINOR = int(torch.__version__.split(".")[1]) + +cmdclass = {} +ext_modules = [] + +# Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h +# See https://github.com/pytorch/pytorch/pull/70650 +generator_flag = [] +torch_dir = torch.__path__[0] +if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + generator_flag = ["-DOLD_GENERATOR_PATH"] + +raise_if_cuda_home_none("--ft_attention") +# Check, if CUDA11 is installed for compute capability 8.0 +cc_flag = [] +_, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) +if bare_metal_version < Version("11.0"): + raise RuntimeError("ft_attention is only supported on CUDA 11 and above") +cc_flag.append("-gencode") +cc_flag.append("arch=compute_70,code=sm_70") +cc_flag.append("-gencode") +cc_flag.append("arch=compute_80,code=sm_80") +if bare_metal_version >= Version("11.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + +ext_modules.append( + CUDAExtension( + name="ft_attention", + sources=[ + "ft_attention.cpp", + "decoder_masked_multihead_attention.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-DENABLE_BF16"] + generator_flag, + "nvcc": append_nvcc_threads( + [ + "-DENABLE_BF16", # TODO + "-O3", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + ] + + generator_flag + + cc_flag + ), + }, + include_dirs=[this_dir], + ) +) + +setup( + name="ft_attention", + version="0.1", + description="Attention for single query from FasterTransformer", + ext_modules=ext_modules, + cmdclass={"build_ext": BuildExtension} if ext_modules else {}, +) diff --git a/awq/kernels/csrc/pybind.cpp b/awq/kernels/csrc/pybind.cpp index cda4a6c..220911a 100644 --- a/awq/kernels/csrc/pybind.cpp +++ b/awq/kernels/csrc/pybind.cpp @@ -1,12 +1,19 @@ #include #include +#include "attention/ft_attention.h" #include "layernorm/layernorm.h" #include "quantization/gemm_cuda.h" +#include "quantization/gemv_cuda.h" #include "position_embedding/pos_encoding.h" PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel"); m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel."); + m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel."); m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key"); + m.def("single_query_attention", &single_query_attention, "Attention with a single query", + py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"), + py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"), py::arg("rotary_embedding_dim")=0, + py::arg("rotary_base")=10000.0f, py::arg("neox_rotary_style")=true); } diff --git a/awq/kernels/csrc/quantization/gemv_cuda.cu b/awq/kernels/csrc/quantization/gemv_cuda.cu new file mode 100644 index 0000000..3a55e66 --- /dev/null +++ b/awq/kernels/csrc/quantization/gemv_cuda.cu @@ -0,0 +1,247 @@ +// Inspired by https://github.com/ankan-ban/llama_cu_awq +/* + +@article{lin2023awq, + title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration}, + author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song}, + journal={arXiv}, + year={2023} +} + +*/ + +#include +#include +#include +#include "gemv_cuda.h" +#define VECTORIZE_FACTOR 8 +#define Q_VECTORIZE_FACTOR 8 +#define PACK_FACTOR 8 +#define WARP_SIZE 32 + + +// Reduce sum within the warp using the tree reduction algorithm. +__device__ __forceinline__ float warp_reduce_sum(float sum) { + #pragma unroll + for(int i = 4; i >= 0; i--){ + sum += __shfl_down_sync(0xffffffff, sum, 1<(zeros + oc_idx * zeros_w + packed_group_idx * 2); + uint32_t packed_weights[4]; + // use float4 to load weights, each thread load 32 int4 numbers (1 x float4) + *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4)); + // load scaling factors + // g64: two threads -> 64 numbers -> 1 group; 1 warp = 16 groups. + float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 16 + (threadIdx.x / 2)]); + float current_zeros = (float)((packed_zeros >> (threadIdx.x / 2 * 4)) & 0xF); + int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4; + const float4* inputs_ptr = inputs + inputs_ptr_delta; + // multiply 32 weights with 32 inputs + #pragma unroll + for (int ic_0 = 0; ic_0 < 4; ic_0++){ + // iterate over different uint32_t packed_weights in this loop + uint32_t current_packed_weight = packed_weights[ic_0]; + half packed_inputs[PACK_FACTOR]; + // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8) + if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) { + *((float4*)packed_inputs) = *(inputs_ptr + ic_0); + #pragma unroll + for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){ + // iterate over 8 numbers packed within each uint32_t number + float current_single_weight_fp = (float)(current_packed_weight & 0xF); + float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros); + //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros); + psum += dequantized_weight * __half2float(packed_inputs[ic_1]); + current_packed_weight = current_packed_weight >> 4; + } + } + } + } + psum = warp_reduce_sum(psum); + if (threadIdx.x == 0) { + outputs[oc_idx] = __float2half(psum); + } +} + + +/* +Computes GEMV (group_size = 128). + +Args: + inputs: vector of shape [batch_size, IC]; + weight: matrix of shape [OC, IC / 8]; + output: vector of shape [OC]; + zeros: matrix of shape [OC, IC / group_size / 8]; + scaling_factors: matrix of shape [OC, IC / group_size]; + +Notes: + One cannot infer group_size from the shape of scaling factors. + the second dimension is rounded up to a multiple of PACK_FACTOR. +*/ +__global__ void gemv_kernel_g128( + const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs, + const int IC, const int OC){ + const int group_size = 128; + float psum = 0; + const int batch_idx = blockIdx.z; + const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y; + const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR; + half* outputs = _outputs + batch_idx * OC; + const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR); + const int weight_w = IC / PACK_FACTOR; + // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address + const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR); + // consistent with input shape + const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR; + //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w); + // tile size: 4 OC x 1024 IC per iter + for(int packed_group_idx = 0; packed_group_idx < num_groups_packed; packed_group_idx++){ + // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros. + uint32_t packed_zeros = *(zeros + oc_idx * zeros_w + packed_group_idx); + uint32_t packed_weights[4]; + // use float4 to load weights, each thread load 32 int4 numbers (1 x float4) + *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4)); + // load scaling factors + // g128: four threads -> 128 numbers -> 1 group; 1 warp = 8 groups. + float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]); + float current_zeros = (float)((packed_zeros >> (threadIdx.x / 4 * 4)) & 0xF); + int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4; + const float4* inputs_ptr = inputs + inputs_ptr_delta; + // multiply 32 weights with 32 inputs + #pragma unroll + for (int ic_0 = 0; ic_0 < 4; ic_0++){ + // iterate over different uint32_t packed_weights in this loop + uint32_t current_packed_weight = packed_weights[ic_0]; + half packed_inputs[PACK_FACTOR]; + // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8) + if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) { + *((float4*)packed_inputs) = *(inputs_ptr + ic_0); + #pragma unroll + for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){ + // iterate over 8 numbers packed within each uint32_t number + float current_single_weight_fp = (float)(current_packed_weight & 0xF); + float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros); + //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros); + psum += dequantized_weight * __half2float(packed_inputs[ic_1]); + current_packed_weight = current_packed_weight >> 4; + } + } + } + } + psum = warp_reduce_sum(psum); + if (threadIdx.x == 0) { + outputs[oc_idx] = __float2half(psum); + } +} + + +/* +Computes GEMV (PyTorch interface). + +Args: + _in_feats: tensor of shape [B, IC]; + _kernel: int tensor of shape [OC, IC // 8]; + _zeros: int tensor of shape [OC, IC // G // 8]; + _scaling_factors: tensor of shape [OC, IC // G]; + blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC; + blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC; + +Returns: + out_feats: tensor of shape [B, OC]; +*/ +torch::Tensor gemv_forward_cuda( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int group_size) +{ + int num_in_feats = _in_feats.size(0); + int num_in_channels = _in_feats.size(1); + // int kernel_volume = _out_in_map.size(1); + auto in_feats = reinterpret_cast(_in_feats.data_ptr()); + auto kernel = reinterpret_cast(_kernel.data_ptr()); + auto zeros = reinterpret_cast(_zeros.data_ptr()); + auto scaling_factors = reinterpret_cast(_scaling_factors.data_ptr()); + // auto out_in_map = _out_in_map.data_ptr(); + auto options = + torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device()); + // kernel is [OC, IC] + at::Tensor _out_feats = torch::empty({num_in_feats, _kernel.size(0)}, options); + int num_out_feats = _out_feats.size(-2); + int num_out_channels = _out_feats.size(-1); + auto out_feats = reinterpret_cast(_out_feats.data_ptr()); + int blockDim_z = num_out_feats; + dim3 num_blocks(1, num_out_channels / 4, num_out_feats); + dim3 num_threads(32, 4); + if (group_size == 64) + { + gemv_kernel_g64<<>>( + // pointers + in_feats, kernel, zeros, scaling_factors, out_feats, + // constants + num_in_channels, num_out_channels + ); + } + else if (group_size == 128) + { + gemv_kernel_g128<<>>( + // pointers + in_feats, kernel, zeros, scaling_factors, out_feats, + // constants + num_in_channels, num_out_channels + ); + } + return _out_feats; +;} + diff --git a/awq/kernels/csrc/quantization/gemv_cuda.h b/awq/kernels/csrc/quantization/gemv_cuda.h new file mode 100644 index 0000000..748abc5 --- /dev/null +++ b/awq/kernels/csrc/quantization/gemv_cuda.h @@ -0,0 +1,9 @@ +#pragma once +#include + +torch::Tensor gemv_forward_cuda( + torch::Tensor _in_feats, + torch::Tensor _kernel, + torch::Tensor _scaling_factors, + torch::Tensor _zeros, + int group_size); diff --git a/awq/kernels/setup.py b/awq/kernels/setup.py index 78e938e..88e3095 100644 --- a/awq/kernels/setup.py +++ b/awq/kernels/setup.py @@ -1,9 +1,31 @@ from setuptools import find_packages, setup from torch.utils.cpp_extension import BuildExtension, CUDAExtension, CppExtension + extra_compile_args = { - "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"], - "nvcc": ["-O3", "-std=c++17"], + "cxx": [ + "-g", + "-O3", + "-fopenmp", + "-lgomp", + "-std=c++17", + "-DENABLE_BF16" + ], + "nvcc": [ + "-O3", + "-std=c++17", + "-DENABLE_BF16", # TODO + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT16_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "-U__CUDA_NO_BFLOAT162_OPERATORS__", + "-U__CUDA_NO_BFLOAT162_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--threads=8" + ], } setup( @@ -15,8 +37,11 @@ sources=[ "csrc/pybind.cpp", "csrc/quantization/gemm_cuda_gen.cu", + "csrc/quantization/gemv_cuda.cu", "csrc/layernorm/layernorm.cu", - "csrc/position_embedding/pos_encoding_kernels.cu" + "csrc/position_embedding/pos_encoding_kernels.cu", + "csrc/attention/ft_attention.cpp", + "csrc/attention/decoder_masked_multihead_attention.cu" ], extra_compile_args=extra_compile_args, ), diff --git a/awq/quantize/pre_quant.py b/awq/quantize/pre_quant.py index cf24d8c..5af81a5 100644 --- a/awq/quantize/pre_quant.py +++ b/awq/quantize/pre_quant.py @@ -20,7 +20,7 @@ def get_named_linears(module): def get_blocks(model): - if isinstance(model, LlamaForCausalLM): + if model.__class__.__name__ == 'LlamaForCausalLM': layers = model.model.layers elif isinstance(model, OPTForCausalLM): layers = model.model.decoder.layers diff --git a/awq/quantize/qmodule.py b/awq/quantize/qmodule.py index e40f107..e77555a 100644 --- a/awq/quantize/qmodule.py +++ b/awq/quantize/qmodule.py @@ -4,6 +4,25 @@ import awq_inference_engine # with CUDA kernels +def make_divisible(c, divisor): + return (c + divisor - 1) // divisor + + +def calculate_zeros_width(in_features, group_size=128, pack_num=8): + if group_size >= 128: + size_multiplier = 1 + elif group_size == 64: + size_multiplier = 2 + elif group_size == 32: + size_multiplier = 4 + else: + raise NotImplementedError + + base_width = make_divisible(in_features // group_size, pack_num) + base_width = make_divisible(base_width, size_multiplier) * size_multiplier + return base_width + + class ScaledActivation(nn.Module): def __init__(self, module, scales): super().__init__() @@ -25,13 +44,15 @@ def __init__(self, w_bit, group_size, in_features, out_features, bias, dev): self.out_features = out_features self.w_bit = w_bit self.group_size = group_size if group_size != -1 else in_features + self.split_k_iters = 8 # quick sanity check (make sure aligment) assert self.in_features % self.group_size == 0 assert out_features % (32 // self.w_bit) == 0 - - self.register_buffer('qweight', torch.zeros((in_features, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev)) - self.register_buffer('qzeros', torch.zeros((in_features // self.group_size, out_features // (32 // self.w_bit)), dtype=torch.int32, device=dev)) - self.register_buffer('scales', torch.zeros((in_features // self.group_size, out_features), dtype=torch.float16, device=dev)) + pack_num = (32 // self.w_bit) + # TODO (Haotian): a function for buffer shape calculation + self.register_buffer('qweight', torch.zeros((out_features, in_features // pack_num), dtype=torch.int32, device=dev)) + self.register_buffer('qzeros', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size)), dtype=torch.int32, device=dev)) + self.register_buffer('scales', torch.zeros((out_features, calculate_zeros_width(in_features, self.group_size) * pack_num), dtype=torch.float16, device=dev)) if bias: self.register_buffer('bias', torch.zeros((out_features), dtype=torch.float16, device=dev)) else: @@ -46,24 +67,31 @@ def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, ze # need scales and zeros info for real quantization assert scales is not None and zeros is not None scale_zeros = zeros * scales - - awq_linear.scales = scales.clone().half() - if linear.bias is not None: - awq_linear.bias = linear.bias.clone().half() pack_num = 32 // awq_linear.w_bit + qscales = torch.zeros( + (scales.shape[0], calculate_zeros_width(linear.in_features, group_size) * pack_num), + dtype=torch.float16, + device=scales.device + ) + qscales[:, :scales.shape[1]] = scales + # awq_linear.scales = scales.clone().half() + awq_linear.scales = qscales + if linear.bias is not None: + awq_linear.bias = linear.bias.clone().half() intweight = [] for idx in range(awq_linear.in_features): - intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[idx // group_size]) / awq_linear.scales[idx // group_size]).to(torch.int)[:, None]) + intweight.append(torch.round((linear.weight.data[:, idx] + scale_zeros[:, idx // group_size]) / awq_linear.scales[:, idx // group_size]).to(torch.int)[:, None]) intweight = torch.cat(intweight, dim=1) - intweight = intweight.t().contiguous() + # intweight = intweight.t().contiguous() intweight = intweight.to(dtype=torch.int32) qweight = torch.zeros((intweight.shape[0], intweight.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=intweight.device) for col in range(intweight.shape[1] // pack_num): if awq_linear.w_bit == 4: - order_map = [0, 2, 4, 6, 1, 3, 5, 7] + # order_map = [0, 2, 4, 6, 1, 3, 5, 7] + order_map = [0, 1, 2, 3, 4, 5, 6, 7] else: raise NotImplementedError("Only 4-bit are supported for now.") for i in range(pack_num): @@ -72,25 +100,35 @@ def from_linear(cls, linear, w_bit, group_size, init_only=False, scales=None, ze awq_linear.qweight = qweight zeros = zeros.to(dtype=torch.int32) - qzeros = torch.zeros((zeros.shape[0], zeros.shape[1] // 32 * awq_linear.w_bit), dtype=torch.int32, device=zeros.device) + qzeros = torch.zeros( + (zeros.shape[0], calculate_zeros_width(linear.in_features, group_size)), + dtype=torch.int32, + device=zeros.device, + ) - for col in range(zeros.shape[1] // pack_num): + for col in range((zeros.shape[1] + pack_num - 1) // pack_num): if awq_linear.w_bit == 4: - order_map = [0, 2, 4, 6, 1, 3, 5, 7] + # order_map = [0, 2, 4, 6, 1, 3, 5, 7] + order_map = [0, 1, 2, 3, 4, 5, 6, 7] else: raise NotImplementedError("Only 4-bit are supported for now.") for i in range(pack_num): + if col * pack_num + order_map[i] >= zeros.shape[1]: + continue qzero_col = zeros[:, col * pack_num + order_map[i]] qzeros[:, col] |= qzero_col << (i * awq_linear.w_bit) awq_linear.qzeros = qzeros - return awq_linear @torch.no_grad() def forward(self, x): out_shape = x.shape[:-1] + (self.out_features, ) - out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8) + # out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.split_k_iters) + # print(x.shape, self.qweight.shape, self.scales.shape, self.qzeros.shape, self.group_size) + out = awq_inference_engine.gemv_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, self.group_size) out = out + self.bias if self.bias is not None else out + #print(out) + #assert 0 return out.reshape(out_shape) def extra_repr(self) -> str: diff --git a/awq/quantize/quantizer.py b/awq/quantize/quantizer.py index 2670607..0437fc0 100644 --- a/awq/quantize/quantizer.py +++ b/awq/quantize/quantizer.py @@ -127,8 +127,8 @@ def real_quantize_model_weight( else: module.cuda() module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config) - scales = scales.t().contiguous() - zeros = zeros.t().contiguous() + # scales = scales.t().contiguous() + # zeros = zeros.t().contiguous() q_linear = WQLinear.from_linear( module, w_bit, q_config['q_group_size'], False, scales, zeros) module.cpu() diff --git a/tinychat/README.md b/tinychat/README.md index e0edc40..7657618 100644 --- a/tinychat/README.md +++ b/tinychat/README.md @@ -29,7 +29,7 @@ The current release supports: ## Examples -Thanks to AWQ, TinyChat can now deliver more prompt responses through 4-bit inference. The following examples showcase that TinyChat's W4A16 generation is 2.3x faster on RTX 4090 and 1.4x faster on Jetson Orin, compared to the FP16 baselines. (Tested with [LLaMA-2-7b]( https://huggingface.co/meta-llama/Llama-2-7b-chat-hf ) model.) +Thanks to AWQ, TinyChat can now deliver more prompt responses through 4-bit inference. The following examples showcase that TinyChat's W4A16 generation is up to 3.7x faster on RTX 4090 and 3.3x faster on Jetson Orin, compared to the FP16 baselines. (Tested with [LLaMA-2-7b]( https://huggingface.co/meta-llama/Llama-2-7b-chat-hf ) model.) * TinyChat on RTX 4090: @@ -44,7 +44,7 @@ Thanks to AWQ, TinyChat can now deliver more prompt responses through 4-bit infe We benchmark TinyChat on A6000 (server-class GPU), 4090 (desktop GPU) and Orin (edge GPU). -We use the default implementation from Huggingface for the FP16 baseline. The INT4 implementation applies AWQ and utilizes our fast W4A16 GPU kernel. Please notice that the end-to-end runtime for INT4 TinyChat could be further improved if we reduce the framework overhead from Huggingface (e.g. utilizing the implementation from TGI). We are working on a new release with even faster inference performance, please stay tuned! +We use the default implementation from Huggingface for the FP16 baseline. The INT4 implementation applies AWQ and utilizes our fast W4A16 GPU kernel. We also apply additional optimization techniques in the latest release. For example, we fuse all the operations in MHA/GQA/MQA into a single kernel, and fuse positional embedding kernels into the attention kernel. We also pre-allocate key-value caches to avoid the online memory allocation overhead from Huggingface. The latency reported in all tables are per-token latency for the generation stage. @@ -52,37 +52,40 @@ The latency reported in all tables are per-token latency for the generation stag | Model | FP16 latency (ms) | INT4 latency (ms) | Speedup | | ----------- |:-----------------:|:-----------------:|:-------:| -| LLaMA-2-7B | 27.14 | 12.44 | 2.18x | -| LLaMA-2-13B | 47.28 | 20.28 | 2.33x | -| Vicuna-7B | 26.06 | 12.43 | 2.10x | -| Vicuna-13B | 44.91 | 17.30 | 2.60x | -| MPT-7B | 22.79 | 16.87 | 1.35x | -| MPT-30B | OOM | 31.57 | -- | -| Falcon-7B | 39.44 | 27.34 | 1.44x | +| LLaMA-2-7B | 27.14 | 8.71 | 3.12x | +| LLaMA-2-13B | 47.28 | 14.64 | 3.23x | +| Vicuna-7B | 26.06 | 8.39 | 3.11x | +| Vicuna-13B | 44.91 | 13.46 | 3.34x | +| MPT-7B | 22.79 | 7.99 | 2.85x | +| MPT-30B | OOM | 28.15 | -- | +| Falcon-7B | 39.44 | 11.71 | 3.37x | ### 4090 Results | Model | FP16 latency (ms) | INT4 latency (ms) | Speedup | | ----------- |:-----------------:|:-----------------:|:-------:| -| LLaMA-2-7B | 19.97 | 8.66 | 2.31x | -| LLaMA-2-13B | OOM | 13.54 | -- | -| Vicuna-7B | 19.09 | 8.61 | 2.22x | -| Vicuna-13B | OOM | 12.17 | -- | -| MPT-7B | 17.09 | 12.58 | 1.36x | -| MPT-30B | OOM | 23.54 | -- | -| Falcon-7B | 29.91 | 19.84 | 1.51x | +| LLaMA-2-7B | 19.97 | 6.02* | 3.31x | +| LLaMA-2-13B | OOM | 10.35 | -- | +| Vicuna-7B | 19.09 | 5.33 | 3.58x | +| Vicuna-13B | OOM | 9.17 | -- | +| MPT-7B | 17.09 | 6.18 | 2.77x | +| MPT-30B | OOM | 20.60 | -- | +| Falcon-7B | 29.91 | 8.02 | 3.73x | + +*: The reason why LLaMA-2-7B is slower than Vicuna-7B is because we need a longer prompt (with > 500 tokens) to prevent the model from talking with itself. If we use the benchmarking strategy from exLLaMA (i.e. only 4 context tokens), our speed is around 195 tokens / second. ### Orin Results | Model | FP16 latency (ms) | INT4 latency (ms) | Speedup | | ----------- |:-----------------:|:-----------------:|:-------:| -| LLaMA-2-7B | 104.71 | 75.11 | 1.39x | -| LLaMA-2-13B | OOM | 136.81 | -- | -| Vicuna-7B | 93.12 | 65.34 | 1.43x | -| Vicuna-13B | OOM | 115.4 | -- | -| MPT-7B | 89.85 | 67.36 | 1.33x | -| Falcon-7B | 147.84 | 102.74 | 1.44x | +| LLaMA-2-7B | 104.71 | 33.07* | 3.17x | +| LLaMA-2-13B | OOM | 58.20 | -- | +| Vicuna-7B | 93.12 | 30.73 | 3.03x | +| Vicuna-13B | OOM | 54.98 | -- | +| MPT-7B | 89.85 | 31.22 | 2.88x | +| Falcon-7B | 147.84 | 45.10 | 3.28x | +*: We can similarly achieve 33 tokens / second on Orin if we use the benchmarking strategy from exLLaMA. ## Usage @@ -153,11 +156,21 @@ python demo.py --model_type llama \ --precision W16A16 ``` +5. (Optional) Run the benchmark script: + +```bash +cd tinychat +python benchmark.py --model_type llama \ + --model_path /PATH/TO/LLAMA2/llama-2-7b-chat \ + --q_group_size 128 +``` +Note: The kv caches in the current implementation are pre-allocated. So if you run out of memory, it might be the case that the kv cache is too large. To solve the problem, you may pass in `--max_seq_len [a smaller number]`. ## Reference -TinyChat is inspired by the following open-source projects: [FasterTransformer](https://github.com/NVIDIA/FasterTransformer), [vLLM](https://github.com/vllm-project/vllm), [FastChat](https://github.com/lm-sys/FastChat). +TinyChat is inspired by the following open-source projects: [FasterTransformer](https://github.com/NVIDIA/FasterTransformer), [vLLM](https://github.com/vllm-project/vllm), [FastChat](https://github.com/lm-sys/FastChat), [llama_cu_awq](https://github.com/ankan-ban/llama_cu_awq). + diff --git a/tinychat/benchmark.py b/tinychat/benchmark.py new file mode 100644 index 0000000..5ff87ce --- /dev/null +++ b/tinychat/benchmark.py @@ -0,0 +1,123 @@ +# Usage: +# Please first install awq/kernels +# then directly run CUDA_VISIBLE_DEVICES=0 python benchmark.py +import argparse +import torch +import time +import numpy as np +from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, modeling_utils +import tinychat.utils.constants +from tinychat.utils.load_quant import load_awq_model +from awq.quantize.quantizer import real_quantize_model_weight +from tinychat.utils.tune import tune_all_wqlinears, device_warmup +from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp + + +def skip(*args, **kwargs): + pass + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--model_type", type=str, default="LLaMa", help="type of the model" + ) + parser.add_argument( + "--model_path", + type=str, + default="/data/llm/checkpoints/vicuna-hf/vicuna-7b", + help="path to the model", + ) + parser.add_argument("--q_group_size", type=int, default=128) + parser.add_argument( + "--verbose", + default=False, + action="store_true", + help="Wheter to print more information.", + ) + parser.add_argument( + "--max_seq_len", + type=int, + default=2048, + help="maximum sequence length for kv cache" + ) + parser.add_argument( + "--max_batch_size", + type=int, + default=1, + help="maximum batch size for kv cache" + ) + args = parser.parse_args() + + tinychat.utils.constants.max_batch_size = args.max_batch_size + tinychat.utils.constants.max_seq_len = args.max_seq_len + from tinychat.models import FalconForCausalLM, LlamaForCausalLM, MPTForCausalLM + + modeling_utils._init_weights = False + torch.nn.init.kaiming_uniform_ = skip + torch.nn.init.kaiming_normal_ = skip + torch.nn.init.uniform_ = skip + torch.nn.init.normal_ = skip + + device = "cuda:0" + # exLLaMA benchmarking parameters. + context_length = 4 + gen_length = 200 + input_ids = [1 for _ in range(context_length)] + + model_type_dict = { + "llama": LlamaForCausalLM, + "falcon": FalconForCausalLM, + "mpt": MPTForCausalLM, + } + + config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True) + assert args.model_type.lower() in [ + "llama", + "falcon", + "mpt", + ], "We only support llama & falcon & mpt now" + model = model_type_dict[args.model_type.lower()](config).half() + real_quantize_model_weight( + model, w_bit=4, q_config=dict(q_group_size=args.q_group_size, zero_point=True), init_only=True + ) + model = model.to(device) + + # tune_all_wqlinears(model) + make_quant_attn(model, device) + make_quant_norm(model) + make_fused_mlp(model) + device_warmup(device) + + print("huggingface ckpt loaded") + print(model) + + time_lis = [] + + start_pos = 0 + + print("Benchmarking...") + with torch.inference_mode(): + for i in range(gen_length): + torch.cuda.synchronize() + t_st = time.time() + + if i == 0: + inputs = torch.as_tensor([input_ids], device=device) + else: + inputs = torch.as_tensor([[token]], device=device) + out = model(inputs, start_pos=start_pos) + start_pos += out.shape[1] + + torch.cuda.synchronize() + t_ed = time.time() + time_lis.append(t_ed - t_st) + token = out[:, -1].max(1)[1].unsqueeze(1) + if args.verbose: + print(i, np.median(time_lis)) + + print(f"Speed: {1 / np.median(time_lis)} tokens per second.") + + +if __name__ == "__main__": + main() diff --git a/tinychat/demo.py b/tinychat/demo.py index 966adac..93d095b 100644 --- a/tinychat/demo.py +++ b/tinychat/demo.py @@ -6,38 +6,46 @@ from transformers import AutoTokenizer, AutoModelForCausalLM, AutoConfig, modeling_utils from attributedict.collections import AttributeDict from tinychat.stream_generators import StreamGenerator, FalconStreamGenerator +import tinychat.utils.constants from tinychat.utils.load_quant import load_awq_model, load_awq_llama_fast from tinychat.utils.prompt_templates import get_prompter, get_stop_token_ids +from tinychat.utils.tune import device_warmup, tune_all_wqlinears import os + os.environ["CUDA_VISIBLE_DEVICES"] = "0" # opt_params in TinyLLMEngine -gen_params = AttributeDict([ - ("seed", -1), # RNG seed - ("n_threads", 1), # TODO: fix this - ("n_predict", 512), # new tokens to predict - ("n_parts", -1), # amount of model parts (-1: determine from model dimensions) - ("n_ctx", 512), # context size - ("n_batch", 512), # batch size for prompt processing (must be >=32 to use BLAS) - ("n_keep", 0), # number of tokens to keep from initial prompt - ("n_vocab", 50272), # vocabulary size - - # sampling parameters - ("logit_bias", dict()), # logit bias for specific tokens: - ("top_k", 40), # <= 0 to use vocab size - ("top_p", 0.95), # 1.0 = disabled - ("tfs_z", 1.00), # 1.0 = disabled - ("typical_p", 1.00), # 1.0 = disabled - ("temp", 0.70), # 1.0 = disabled - ("repeat_penalty", 1.10), # 1.0 = disabled - ("repeat_last_n", 64), # last n tokens to penalize (0 = disable penalty, -1 = context size) - ("frequency_penalty", 0.00),# 0.0 = disabled - ("presence_penalty", 0.00), # 0.0 = disabled - ("mirostat", 0), # 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 - ("mirostat_tau", 5.00), # target entropy - ("mirostat_eta", 0.10), # learning rate - ]) +gen_params = AttributeDict( + [ + ("seed", -1), # RNG seed + ("n_threads", 1), # TODO: fix this + ("n_predict", 512), # new tokens to predict + ("n_parts", -1), # amount of model parts (-1: determine from model dimensions) + ("n_ctx", 512), # context size + ("n_batch", 512), # batch size for prompt processing (must be >=32 to use BLAS) + ("n_keep", 0), # number of tokens to keep from initial prompt + ("n_vocab", 50272), # vocabulary size + # sampling parameters + ("logit_bias", dict()), # logit bias for specific tokens: + ("top_k", 40), # <= 0 to use vocab size + ("top_p", 0.95), # 1.0 = disabled + ("tfs_z", 1.00), # 1.0 = disabled + ("typical_p", 1.00), # 1.0 = disabled + ("temp", 0.70), # 1.0 = disabled + ("repeat_penalty", 1.10), # 1.0 = disabled + ( + "repeat_last_n", + 64, + ), # last n tokens to penalize (0 = disable penalty, -1 = context size) + ("frequency_penalty", 0.00), # 0.0 = disabled + ("presence_penalty", 0.00), # 0.0 = disabled + ("mirostat", 0), # 0 = disabled, 1 = mirostat, 2 = mirostat 2.0 + ("mirostat_tau", 5.00), # target entropy + ("mirostat_eta", 0.10), # learning rate + ] +) + def stream_output(output_stream): print(f"ASSISTANT: ", end="", flush=True) @@ -57,40 +65,77 @@ def stream_output(output_stream): total_tokens = timing["total_tokens"] generation_time_list = timing["generation_time_list"] generation_tokens = len(generation_time_list) - average_speed = (context_time + np.sum(generation_time_list)) / (context_tokens + generation_tokens) + average_speed = (context_time + np.sum(generation_time_list)) / ( + context_tokens + generation_tokens + ) print("=" * 50) print("Speed of Inference") print("-" * 50) # print(f"Context Stage : {context_time/context_tokens * 1000:.2f} ms/token") - print(f"Generation Stage : {np.average(generation_time_list) * 1000:.2f} ms/token") + print( + f"Generation Stage : {np.average(generation_time_list) * 1000:.2f} ms/token" + ) # print(f"Average Speed : {average_speed * 1000:.2f} ms/token") print("=" * 50) # print("token num:", total_tokens) # print("Model total Time = ", (context_time + np.sum(generation_time_list))*1000, "ms" ) return " ".join(output_text) -def device_warmup(device:str): - warm_up = torch.randn((4096,4096)).to(device) - torch.mm(warm_up,warm_up) -if __name__ == '__main__': +if __name__ == "__main__": parser = argparse.ArgumentParser() - parser.add_argument('--model_type', type=str, default='LLaMa', help='type of the model') - parser.add_argument('--model_path', type=str, default='/data/llm/checkpoints/vicuna-hf/vicuna-7b', help='path to the model') - parser.add_argument('--precision' , type=str, default='W4A16', help='compute precision') - parser.add_argument('--device' , type=str, default='cuda') - parser.add_argument('--q_group_size', type=int, default=128) - parser.add_argument('--load_quant', type=str, default='/data/llm/checkpoints/vicuna-hf/vicuna-7b-awq-w4g128.pt', help='path to the pre-quanted 4-bit weights') + parser.add_argument( + "--model_type", type=str, default="LLaMa", help="type of the model" + ) + parser.add_argument( + "--model_path", + type=str, + default="/data/llm/checkpoints/vicuna-hf/vicuna-7b", + help="path to the model", + ) + parser.add_argument( + "--precision", type=str, default="W4A16", help="compute precision" + ) + parser.add_argument("--device", type=str, default="cuda") + parser.add_argument("--q_group_size", type=int, default=128) + parser.add_argument( + "--load_quant", + type=str, + default="/data/llm/checkpoints/vicuna-hf/vicuna-7b-awq-w4g128.pt", + help="path to the pre-quanted 4-bit weights", + ) + parser.add_argument( + "--max_seq_len", + type=int, + default=2048, + help="maximum sequence length for kv cache" + ) + parser.add_argument( + "--max_batch_size", + type=int, + default=1, + help="maximum batch size for kv cache" + ) args = parser.parse_args() - assert args.model_type.lower() in ["llama", "falcon", "mpt"], "We only support llama & falcon & mpt now" + assert args.model_type.lower() in [ + "llama", + "falcon", + "mpt", + ], "We only support llama & falcon & mpt now" assert args.precision in ["W4A16", "W16A16"], "We only support W4A16/W16A16 now" gen_params.n_predict = 512 gen_params.n_vocab = 32000 + tinychat.utils.constants.max_batch_size = args.max_batch_size + tinychat.utils.constants.max_seq_len = args.max_seq_len + # TODO (Haotian): a more elegant implementation here. + # We need to update these global variables before models use them. + from tinychat.models import FalconForCausalLM, LlamaForCausalLM, MPTForCausalLM def skip(*args, **kwargs): pass + torch.nn.init.kaiming_uniform_ = skip torch.nn.init.kaiming_normal_ = skip torch.nn.init.uniform_ = skip @@ -99,38 +144,63 @@ def skip(*args, **kwargs): config = AutoConfig.from_pretrained(args.model_path, trust_remote_code=True) if "mpt" in config.__class__.__name__.lower(): # config.init_device="meta" - tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + config.tokenizer_name, trust_remote_code=True + ) else: - tokenizer = AutoTokenizer.from_pretrained(args.model_path, use_fast=False, trust_remote_code=True) + tokenizer = AutoTokenizer.from_pretrained( + args.model_path, use_fast=False, trust_remote_code=True + ) modeling_utils._init_weights = False torch.set_default_dtype(torch.half) model = AutoModelForCausalLM.from_config(config, trust_remote_code=True) + model_type_dict = { + "llama": LlamaForCausalLM, + "falcon": FalconForCausalLM, + "mpt": MPTForCausalLM, + } + if args.precision == "W4A16": if args.model_type.lower() == "llama": - model = load_awq_llama_fast(model, args.load_quant, 4, args.q_group_size, args.device) + model = model_type_dict["llama"](config).half() + model = load_awq_llama_fast( + model, args.load_quant, 4, args.q_group_size, args.device + ) else: - model = load_awq_model(model, args.load_quant, 4, args.q_group_size, args.device) + model = ( + model_type_dict[args.model_type.lower()](config).half() + ) + model = load_awq_model( + model, args.load_quant, 4, args.q_group_size, args.device + ) else: - model = AutoModelForCausalLM.from_pretrained(args.model_path, config=config, torch_dtype=torch.float16, trust_remote_code=True).to(args.device) + loaded_model = AutoModelForCausalLM.from_pretrained( + args.model_path, + config=config, + torch_dtype=torch.float16, + trust_remote_code=True, + ) + model = model_type_dict[args.model_type.lower()](config).half().to(args.device) + model.load_state_dict(loaded_model.state_dict()) # device warm up device_warmup(args.device) + # autotune split_k_iters + # tune_all_wqlinears(model) - if args.model_type.lower() == 'falcon': - stream_generator = FalconStreamGenerator - else: - stream_generator = StreamGenerator + # TODO (Haotian): Verify if the StreamGenerator still works for the unmodified falcon impl. + stream_generator = StreamGenerator # Optimize AWQ quantized model - if args.precision == "W4A16" and args.model_type.lower() == 'llama': + if args.precision == "W4A16" and args.model_type.lower() == "llama": from tinychat.modules import make_quant_norm, make_quant_attn, make_fused_mlp + make_quant_attn(model, args.device) make_quant_norm(model) make_fused_mlp(model) - model_prompter = get_prompter(args.model_type, args.model_path) - stop_token_ids = get_stop_token_ids(args.model_type, args.model_path) + stop_token_ids = get_stop_token_ids(args.model_type, args.model_path) count = 0 while True: # Get input from the user @@ -139,7 +209,14 @@ def skip(*args, **kwargs): print("EXIT...") break model_prompter.insert_prompt(input_prompt) - output_stream = stream_generator(model, tokenizer, model_prompter.model_input, gen_params, device=args.device, stop_token_ids = stop_token_ids) - outputs = stream_output(output_stream) + output_stream = stream_generator( + model, + tokenizer, + model_prompter.model_input, + gen_params, + device=args.device, + stop_token_ids=stop_token_ids, + ) + outputs = stream_output(output_stream) model_prompter.update_template(outputs) count += 1 diff --git a/tinychat/figures/4090_example.gif b/tinychat/figures/4090_example.gif index 64857e6174ea698ccdcb013eb6d9dd8ff046dd3c..6a3609075f5eaf871f7ae9b55bf87b5ba8120c02 100644 GIT binary patch delta 2477307 zcmV((K;XaVn1li5mZJf%H3NSW07;NUAVGr#2PQnI5TV0{5E)8bNb%uBixnegyr>bQ z$BrO5irh%@<4BVwL#8~b5~a(QFj>l6N%Q4Qn>AzRyr~nX&z?Ye3f)Qc=TM_XgC;$y z6sgmuP?<_yO7-bft5u_By{Z+f*REi>irq@~>sYg8!=^o}7OmU1aM^##T}$`vT)TDS z=Dn*Iuiw6a`3l}k`0rrDg##x(tQfK5#*i6HUQGG%WXm)GB+#sRAm_~nKX(TG*)!?T zr%jhWJz6yC)vi^8mfbov>({ey)5cvJw(Z!ug99J_ySVS;xrrn1o!ofy<-DWso-W)u z^Xt;9L#M9%xp(Z`t!ID#?+#x3_wnMxZ#Q3_Jo@$S)t{H&K0W*Q^Y7EgUmw2x`1%D1 z;DG-X=-+|)5r|-a2^!eog8W6e--H4>$l!$%RtO=53O?u|h8$X$VTT}U_~D2mhPa`L zC6b8ZiY}^nqKqx3$l{AL-bf>kHNt3Pj5-DhEn?(5{Z9gk4YNYMoQ_WmSzg+rj>Gvsi&Ie<=UvXk!otG zsGh27s;sW+YO9km1!!0BX`gm#aKQ+NnlPvgAFS}h4KM6)#1KzxvBVfxoH4~5U##)R z9dGP$$RLkwvdAcxoHEHQpRDrBEwAiy%oPu8bHF&?taHvh@9cBWK>sXs&_oYybkURW z1sZ>E?RD2+-%Rt^GnXy%*)W@(Hriyft#;dM#|`(~Z`bYi-E!NVH{Nvf|E+i5dj}5q z;C~nH_u+yYp7>W^Hy-xmjzpj{54V zw@$BDLENO-LP#J^xb2GHzWDCC_n!OjxCejFJMhB`@B5jvI}|(d#~1H>^UyOdz46mO zUw!n{XK%gsLv?v-m|JT3WtQHPZ>9O+mk<8<>w_u2`|7`sKKt^wKR^8E+pmB9{KxNq z{`~`B`AD@N0tPUB0!$zR7f8SVF%W+b#MS;B*gyqVFoF}bAOtVS!3}!wejv;s1`mJO z!4Q)0gCZ4M6H<_c18m`L4Co#jYS4u@Tp|zp`NXAJK5ePs8q7}#J#4@t+ zjZciD8|B!>IvVDSU)%~F!5Ef3?ooe?cBCU5S)wf|L?Q|^WSt^i$H+xGvXPK{q$DFr z$w^wWl1R`(64Y49PG0hpo(v@@M=8oulJbBmLKN%}$2u~B zNCc!3l?XkkLPwX#8Yc9i3-#6(T44){wxgmLwJ1g}YEe=YWLI0*s7N(RQd?Yf6B6*% zL`^zUmX?&GWcVmVXBx>Uj3R#%qEN>=#vuwMAOSlOrKwPZO4OkmwWyJlLMuv2(U*#p z6e;c1N;&G%s#2Ast&r)hq+r#SHnpmxkOESV3N*3AE{F(7LKggq({-#u9jg$BEY;W- zVoD^6Gt1(f>{`yF@D*c!rD9(Pn@@83B(Rr3?4%T1*o_2n3+-qHDad~?SyHqDrj@0v zW-Z%}m>OjhOO@jwk0)rLrqVIpc-MSwkNAcrd{oZwZcm(dJlmYtckL2Bn~ zg3!{Q%JfLA&#LTFH$*c*V0Enn5bGeCTBiK1_H~6#>M5#G*K#s78soz3antkKF41;0 zBdilLv||fN5Vr(gjPB(1w;gQ&ZxvOHT6McR0ZN6ct&xA7gg*;bnSYM=?B=cLqo|k^ zTi8UzxA1U-SXSZ~mn?<%EAEa58`Z*w7=QQsCvdwO0sjB=XE;I;inn&U%VN-VWU_r~ zX-DkU*0%Ry{2daSlO^N=0s6tT?Q@9WArKmwb74$lcopI3vW7$Py&I9r$#n*R1U7%+*&a63I1=FeM=1_y8$xMTrrAVT zQB^?JP&Usn6p!&iqf{VAA415Nq4xjx``<@!`X2;;|CGev{8xGZJmi;O`W@<`^#jX( z;&(p}b7=bnyt&2nj`5O*%pZ3bzgBT}AD=;IKK5wTw`9;KeFSk?&=-JrHe>-fe-c3n|7BUEkb#n=ftF=`J2WG$kbCE$RC1<6tA{FBhGCc{P~xBt>7Wjp0A^U?Do*BE z3GjTi7eljhgG{DWB6xkQG9AQseAI_$%(s28@=U}>Xyd1Tv|@rtIDr`Ae)rTv`xJ!@ zQiVzQgjhI()B=U(*M%PlgjaHTZPtNswp4#i7#;+8D|eQF+c9~n@@Ho_XcppT1~@xH z)?X?nQx7PADkO&T5pa@Kh!~iG?3ICA0EP}SS;qf|dw|FsJYh{@5DDFsTuIPRp^y!m zh-vH4X`qM+-85V|7FtUsaC?Sl0Kj)rU|Wj;TBf*Es1*?k!GMh?WCPbnp+<40SX+O{ zR*Uwqh<0&O2j(UV(SbrYQSg?Ev_=qW){BRSbSqJ5o6rWm6#!-z5+)ZC%vKRfheowi zI|*@Y$dv%hCXT}NYs0pVL$r<&Q9IZuM&T$C-DnW?s1c;V2?>CWv}h2MfM)@4k1{b! z^C(2!RE}_`jxtqlW>*mLco6MKJ1&2*5JV)A=opb6(TP$D z!Wcw6sbAM+aXnUEOC}J8XAppwi|iJ3MJ8ohnOCC1lf@5C`SC~5Z2|A*(iSzy4H@fxsAJtkjOS!erQipP!9hGoDvBU-86P$#B7z; z5U~l42f>=M$&$DAizo?}7%81n(RfC9n8N?$oBOz(McE2#6q~}Sj!HL=&xw*NA#;!g zZ_=1<37~j*Ws9j*l|}h(RR@+znUkT3Y>H=kph207)^~!*kv1tsOSyl53M!RqSf3pM z0DIPz80wY3@LnRfik5kBo`I9LFpF$4D@ZU3|MX>_(RwPF30VLMakUPyrw*ZTezpQx zT!~c$5pbK~W=1$JN$_MWby1yVDtO3Z6V-dWM+ON%EjM^o74?6oa(~29eY0U>|6*vQ zFe-x9TOp-)Hxp7&nqPk?1_Z1Sg4p;j{B)*k<{7-`UVt_bS`dpY=A(pWSyQkoE=C6Y z7hj}M3r+ef@fBICvRc9TD?!Q%$vFmp|GH*!Hm9xN3Q$!VTF|IJiieWLgswuEgchWD z=41`$Ec0fEx3^RSHU-QAd5PC!X=-PdHDD08D>ZhgWTs+2xPX7EsDDg`d{~+*5r`{o zSSm+)8cDE+cD7Vf>VT3(Wxq;PXb53R0Id+lRbbjP@?{IR+G6P?rQ`wxuCxl7kO{!H z84(Z(t``b8bx6%Mqv{~8&r+;Xc2QCmf)!Yj1ktPGvI)NcrPYdm^a^kzig1%kEj$*l zTYVx zrEp^bkf^ml3VF&aQ81LQfdn>&vr^Ei%aQ;LXeyOzs-=Ix1+A*JxuO81a0>--gJ`M% zz7hqnIAAZEeGr?iA*N=1#u-kQhpf@8bM~=rn18k+hs^h6$x2b03M*8qhm(qjeh9cm z#bdqIMbetK0rxZTdbg}GZ@%gep_YoMwWDMJWK0=s6dQH*1ie~h zaJXr8#duFwcb~Y35N6p9#Y=Y!{~8##$Y+rYaFo{^gW_~tm%ipWY6vh2c5scX2Atu$ zzTf*{nt-xWz)GpOR2Q|S1u+UH6=SxqGY@u&suAgQ+73X;&Y zTQCZ>AiF^jkdvrX^)LYp%#Yoh07@WP3NW2M5p1HSs&}Qmk+oFrPzC{j08((it-7G2 z(T9Jl=x#D>p3yrIYSgk*Y^tOH#Ldff9-G5DSHIyjJN~C|5{k6VD_@wGos?MF zAPVAOx|#q9X#|&?mea zC!WDdW6%y85UAhTYzz%j8Zy-usNdR?9fNxB>bw6yVaoWTPAPKP5u&&_0wonQwl~!Y*!j=_V zy;W)c^_hUS3F+(#S^JPA|7WY^xk&&J1u7h^q)-MD%#b>#oK;jCXoy!pRKf&YQ5M&G z*9&$@Pzv_IvPL-3QHPiCj08A}Z%cp0!s_YFK)?mG`p%JDbp^V2mzB&B0eDF zy1S|*+K-MWvuDjt)El6lXS*IbY7D1uk*t^sTFD>hp~TyvTy@E)!N~s$$x|^aP%sLb z77AJh1uTlLlTE>_(1NkoVba2>re-R6i?G~cfA1zMpKmx37}8zyg1{MI&g{ zdV}_`rJM2DJi}xEdS*Z{3a+pL+tPhnP!D&D)G{RnT%Z|Ruz#Jg1uT}@$Ql3&kVwT! z3JN_NQos%tEJ3LuU#zr);IOg@=WQBNumuwS zJ%r1C2<6&ucq%f+f3l*PzB-||;;^^j;Rv`Zd?%%Z*`#;o+|K=s?Ch+vtEYY&+wJfV zKiJ8bt1eQ2Q<9wxq5uU@V6GU3t}$p(5zt(k-L5joEQ)Ni_*dHKg5rOt!nIaDYMil( z%ndDu_Tx{;-ZtuIYbF}Fy)R+z8Wb+EMdNOIJ7V|5V+($S*FwQjIkGEm-?_(H5Q9_aBbfcx>~KE-na^Amrl~0 zfu|XsDr4Xd?{ylQU=M#>pn!sktfNB1NQ>N>5e0ePD~yXQt7XD}P-Ck0(})-r!Paa*w0*~@QSzL~oh zak-_RN#NWBL2)Lf@M+SE%na7;G z6Q`4UYKu9%|UOz>nrT~U}`Wf7GwBOZebo+gG#cs!(G92$G&#qiLCe?M z$U;GxB*&j{|7?^|jFeM^{XY4WRH<>#6iV;OXZ9lUz7 zmYKHhrR#t00FfyR5(GdHBw001-NJOD2m(ovBmqQ;6emW!SP`Q|AQIZOk^%q-0Fo59 zNJ#)70V!K3Awrb0Wz)%&8*OUDnNcT2GWNENEO}GrNeL`nO4uX?CYg^jUEbUYm1E6T zwlB_w}A`c^B_v}0pREW0I1)|^xON=>lI z7OAmGTDAMh^(k7jXrIDOy7g}1sDJx=|1i+dg$k?C|k)#Z;8vt(7S~rh`xLbKaE%nH#1kjVq@RtD$3Pu^=Y=Ty; zP8WY9Q(HUkwfX*Y35E#XkDQ*KG2~ygb000NcNTQ$wM#)UJ45bLJwVeQnPSFjJkhlK?DR73%QE`kfI$j*we1J03$SxBnjR5 zO0(6pvd+AzfQZr+76&YEK?Ju0$~rE560*LneiN-uz+Ai2MLY52Q_sLMTkSBVqJpWr zF2CG!|H3{wvT(6Hw_=jei^3{QMnoTtsHemzvalivs=FPe1V#l1<7M#guERw%l|xj* zl_@-#!kPrL=i`0xTgj?c0uFUpiFYPfiu7Do$!Ee;+CtcxB>XqveI*iX=!kA^1;2O; zzPJBsbD1KHXR>YvSYQ&hwmNK?zD8CGGNh=OtC_+EyKN}9AY+rSn{DN!xJiG}7+x|+ ziMwJcW=dl0Efi38Y`{rTg79*IXfEWFLoRm{do~=`8=$nNF4vyqsahXh2(}lLj{d&1R9|3$dx&tSj}X=n;mo_U^clap=H|I8yP6)HV$sD z|AQa&pw(Csz@v31g6b053AaZu*KKcV^J^P+B$hnUr6+A3G|cys^+JEv#bh!e1YrJI}*tu&`j1agyl_*^t;}&;)S!pbcaH{B1{xT1+aA8(TM+vm|^&+!{{JJ zfOO2`990%7lZ~ojz!QHA34!3RuH`IQAKF{P+Cd5!-Z5V9g3iIbLz07$vL;)Aq*of* ziq&oG7fy?#A&EAa8}=({^RkKxrx-B`s&ET!nV$sUX2)T&&3#Exi6&|>1!a;L6wV%y)7BABSF?W&1;8z#gANb1Ly;Rr zFec#48v?eV1u0NaflqTeG9rICwt-ZILJVZmQ>s>tMx|F7 z(0!t0mA5z%BqM?4t&lR?$CUOoVq{p|_LAG$c|}71sTG_$#1!4_5n_A9n*`#8HnaHz zZ4yx{OAv!LSi&h&$?Yw}ph#L{{pGg*lkO(pxfco5acgj@9UMPpi*`T^e}OcvPBsQ8 z(5OT({+s5>H)8t$(X+fGK`d5+}D2|w}Kq;^DwkWw;{@)7eP0TnqGuc`ghq5t&jZs_rm#x?WWY5}1-^9-F#T z#5sR15F&6f54QrPBpG&^YaJ5V?AXnNQ5=o~ykTl<*xI?;b8L_Ese;uhN>;Y^do#?M zpdF9Fgt|3D8LX-0Mugkk)D@B>lb$RIs6*4fwxmpr?*`}lNa-TC6wKRi|AGg&LJT9A zbjQu_(gFbqnz(GlEsKdLWa}jdS;ab3VMc!~{Dw~cPEDhLRPI~VjM?sTF1nZva-5`F zwLJ3KJJE(eO0sgkT=|G3wy~mF zyl#ad@=H7WsnO5DsBawrL`o%ZoNmf)f!Y8Fe3b@f7nOQ8osqq<2mG|&sX3I|sxO9WG;S*{QvtvxVK0oJlmQ7d9@Lc}M3el& zi2c|P8^kx7Ij&i`K^n}7`(O&%vJ6|vyDhXrJ5eRx)0ne5piLtfSHcniv?EP9E!XiG z3|om?njIiGnL<)BQWV90L6}-PB><6}VPUIvL9JYRo%Co4k#NNYx~hL_6d6WRGA#TV z@WDo5^u|FPMn%e_O$0K>5-(s3#cxbUa2%JS@G8Qg2_P7Ssu-NWDaMp&3%ST7!WlAN z3WAB^MEe4Q@UR6c5CD9d2rYOgWYGkKImMUQwR$N;f*BY(OOO94z`eaX8bd4vS0I?w zX$RFY7`%arTEH-O`JR6{YLX}^h5Kj=cWW~FvIR2u3X|~#NGunP`UwYd8AigvD2jAxNLkV@S@9SD0l)w{CMm#+*%BZ55k|ItZB-_*@v9L?fX$0(T> z$U#oYY);ioNIQSD!#SiDY*7{oAR|9Sqiork=iEqy%O7QgJ7=^%SQ(|>^iIL3&4QuC ztl=vl6B^AJ9lSwL*%X~F+o;na7_F?CVcAdAOeMGrt`n-IX|$dIIwd>W#M1c1ylEO? z@=nhbI7N(zFH5CSf+Mi}#CPGCP1GS+!pY-Yt0gqfQyPC3l1R|*v`^KPBw&oW(o_jb z)X@5*(Zyi}pTLVLFbHXSrvy-dT$~Y2Bb&Kmk0db3?0Fg~^1rf704;M~?46;WlGn)STRe*ISC zG*^IKN5d*nw8Gbb%~yeak`8H@$l^_f6~u5ufJm?w|A~t@OehEwu_AI%1$F3HRlryx zDZ_u%SQkrDGg-5;NmB|UB#nbvFd7su{-U*z@Ijap6ac&xK-!KiG_WpH#9rwyIQo-7 znh=$OBAoM%{SsQ7P?Sk|mw}wSu|Sd_oHt!OH);dgL3y+YaapEe7yYV=p|gto85VfwBZpI;^O#E;vXLGzdu;+kIj^pIs00AU#&_l~COcEjx_T4as_VQ)#nY z1V94J&>VL-3|3<`#aWHyak8!}i727IS{a(GKwSTqbu?%z2@AIb8&QwuO%<9k-T<+YcHkTU*sF?pv@U+tXm#OOZ9#WEG=RmF;b*E^!rh zlN73bHPx{4^5H@x!MA}|7R<%vUOOj}mmgQ6U6^rMJ<9G>LOM!|=kvmB%J5`&y{(O?!aTLNk zwP=*D2!n|A?c}A!UEmFjR`{>;a+#|Q;^_LL@N+@6up!l|FHcrtA&%a$bK8b;ttiIg zu>jpOv5whr4$Yva21Z=6@JN5c!HhEjfC*NlKQgf069DB{GBR-v=qQN#omnh1he?kMyV1>*$=Qtp9Fv1!1yMsS;n3aH(T;c}jaG06+Hnp&+uVBS66|2zRsd@1 zU_ExxUhxgU2DaP+X2EMNlYiEdmbngF&;%WkXc6w^DcJhMEzqdhYirs+hGECj zQdK@KNX|PK7McE1>%D&iJa_YBcrFgEHo|qW<0H1?Rf=o84mb03LcieMu9o7>NEurG zVKM}wNDu`~$b>kgmN}$OCrEu4Lpq5mf5C5;$YUe?|k8COH%ondLD6jNWdl;4p zsgnBj4%R~$hJp?*5W`oxo(6};2R}Yn;O?y{F;&gVT5=a|N-}wDhw8aT|B4URtA$Yz zomwazgp5ynIhpCWtAue`WOXDyq}Knq~=ZSZDu#kzkNw<;l4?U}x!R2B*$6a_^N zg7bFF%veTOKhjrFbQepaP&bd5wOpN=Ky-03bi(2(V#=d7r*8gQHc4k2-_QbU@ggjm zY&x{gl!$`t?3V3REH)z6)`ruRTo~LA@UhBn6)*Hu|D!Ii*EKJzJzAw7WA60mJ0|`f zl-zZtL-l`g0Zk3W$8xkD|3vm&OwAx_#D%PxRP4mnQP13#St#NqI<*|Md(>uf7ip)+ zZNwTz7gbxg^h9s;4(f@mqmBd!sZTzUe7qlxVC0KL4W#Anf9Fd{Vemsfc;L8go1q4Wm2$1-& zWsgU1!#OgxJ*UP zf_ZW?&U3m~ z^u~t>aa#1bFQUG0#~){W!rHHkIc>o@m5H^M$j$;txT1&yhs&xs_xLQ2^_I~BYybh$ z+_2@ZNrE=`EzD>Ox#63L0&M?ryON>p{n>wit>kr@>Bd63-L^E;YjDba-9c^(;hX0V ztCJ44O0I1qk|Jcx3`M3H%$?t_3v9==GXE<+u83N;$e`R8LK6}m?72FZ;n|wWS(^=7 z>bIVf^L|afQm^FBG6=OQ-CB00I&qWbCa!734%Oii699Xdu4(o zP|~Uu%cwvQT$~b9YJ!<6lSpXSCdt*QNm|ugsi{^+Qd=%c3E&i}&z4V<02oVdBs+{6 z$+RPb1OQM;K6?UfTM@}h{~%J@VU&L$Wy(cTQj#ntGm|IMc2k1XU8_;cz=q{o&D05C zOJB}NSxq3yRORVv#R+O+7ZP7()PfGH7TX=F-$B;8DCq_#gkkZu2dpym6hhAi7G|5g%n#z0;G^c4jH5*QUnPUDH8B^qKL=2x5bkj z!q|?CH}ZBFlq_<{R)_ym5VwC`Mr9h;C7LUy*&&UIxoMr3YsRT2ojh_$N^f-na1xbr z`dR0kc3$U`nD$WF)lC2j8YrVDB7j5|nq)F*CQ(2VN;si_gJ~#`>{r(+;-nIeI_YS! z;i7)Fn2%VbWjp)^p(-AY8?bCDOeq+5>VK(l$5nMMTMEMGhK@vZ>{ngC{L2u(u!kG z!8aI_StTWHnHs@2Yk8KD(F%~Qgv%_$xFuy7fIgLHOGEv_#!|G(&bu*}PvRIQavZ0+ z;!&vCf@J{vjWLmL{|bNEqKUu$kut_6wXlMa1SNs>5J(FXS65xPa3X+AG$BD!N{Q{1 zm3AHPmjIIrFjWeMgE7#S@-7DRvP`U?tZMDKkE3+{D}_wRHkiH}oW8uGZ@czWsS5 zQtMg$*3P4&)KlHutk|X0a4IGCPFbOBp?9@QOJZSFcAuDC^3H%OoRd}NTo$g6tRm#{Gu06 zGnc^#QB2_ifE2)FMh=D%jBacrPZAKLE^c8iLrf#ZibcEch_N?td?O&c=#zgXK|8*2 zBbZPJy&{@%kzXv#5a&3nF9Lxe^O@w|>gdNna?*>a+d{b>sV*Sk>5^V7*TFJWM~&F% zkFFeI9rJ(qH~&JlB9Hwd;={P~%3XG{m%RLCAdf=1du)Lv1KEHOoi_#lkwO7Jkx(>r08b=zFEe9*n^@JwMRRCfu%ELQ!Uk1XeT&2vln*P>RXAkA8E7L<$l+g&2{I+Crr`188SRVUH$>kqiJRN+>`1 zQvmSgodE4*K6}bO3E+v3_`KOjV<^M``Y#GKsfVssa!!1ib2?K9Bf5x4Kmd?$eTfm_ zd=`IGiw3^`RHD9oD=^~<*SWS4OjFfY8#zi>zSzP-^Ap|4`YKq!79^&UE;}+o~)^6Ln~L<)TOUJ9Z5U>Vbjfu7KxgWVkDw4 z6)1=ziBwEtB%t6#ajZg%>aZefiaFsI&C0P3+riLiA zj*x{f<4*Qpp=y$y**znGt*c~M{Nh1-C;WX`F#bul$pE6ZZF!n}C;G7+P|Rw;Pa zO@*=VjqtjXG!_J;W`uH2ov2tb(HMUKN$n&6E>cWr##zWrwFP1{QIB#WBmx)FO5doY zg;ublh!@F^rTxL^B5uMTeU7tGq;L-j1NA)#Kmyptu;yYWxYY}OH9{0e#6xcyjO)xS zP-+R-zY3Pu7LrU{9~MtAhINP&Gps9V{l(#O6^~q=tgE99Ygj`Pj<8DFw55MNl2&BX zH^UYGr?%tRup~r56IsYa7Cbc-BtStEoZ3`4vVaIEj6)S#R0Sp;HJ%oNNG02ap2wQD zybMijf;|cm2}Q6aJl3$;ZrIR zn37Z24l*V~zKV*&0e-_Mk!gQA9NbE~mb%wDq? z_S9u%@QFgQ9CDPL9jkw6QvnJqzR48X-7Z|5VG2+<0ho|H#EImf*ilSJb%e}ui3S-;m3WZQ@kq+ST+mh4 z?=ViRaEJmnN(uHG0=}TC1RDAMN>9K>S0v2=_Mn_loCw~Lh46n+l_X&PI7k$n+oVMQ z!SNwL6of)HXo?hK0u&&ED2Ri-r9(Jq!toj5hr~jfw6x?9@fRdZP&rzjLiIp3mPGdao2J11UUs*(m9Nwd|VyMpeuSI#-WNX zwcdMd-bApE%ishgafZD_Qv~P?_Mq5sJ={(_h|mZC^%;d?)Wc~J#fMM@Xd$06f?Y%% zmr)==)=&hLB_oR5SVbUS*^r-vu){Y^Nkrj^yo3jQBmsXp^hR`~jK#FeLpUL{&Ho>6 zC`su}fh0WC;(#1~JYN9Bolazq97Y8xxPpW!1m_H1)M3R;h#pChLRA!nk6?&^Fko#k zp64)$6M~I19Ys=XM(^a24AvcTgwIw1hU{EJ|c>kj9+PB zS7u;WuAEk41y!QSA|m3$-4MEXj0a&5dDIeA`e0m2Wt_kUUjQ5dI?LsCU{uZ}Ph?cM zj0+H6(pc^fFF6sqL70g>ktk5nZG{3R*qb_NQ6_&(#j7=q8!catSy(GE4YGy*Rubrx zdQnzna8RwpVL@clDZUy3FH( za#@%u^$msv8DZYhQ(&5gtreNc9vQOKW<`GhJ8&q6eHdayfovrJ0vybIl>Y~gijglI zM213-7T5wp0s!2whZ{{nmQ2@#M%GHmSan_mTs&tvSts;tf-hJV0I7rNlIM69st^gSm()>KC1%`j+ z$ynA~RofdW+f?CcW(gII=83{A(ybAys2VC#y;ojA=cuyQw~c0hu4t;d2^45SZkbbU zO%Y^5W+sG#WRePSMH{ohSRa+fpKX^ZH4wF->RA;8iIT!U@agj5lXUqiMr3I$>56W; z=CSx$v=R`uIR*!17s9ASpOzDH`p=2Ua2P4X z>70d58I;%;ML|6v*cNO-_yibNl`MgJs;QD8kR<>YNXB~72m#?(-RW36L_vR@E-7mG zf>pfOy;|vB9x43b(@q_SlksUp*n*C_QUExaO);tjAl9%&!Jlzzb)6Y&ol`s5!)}eb{j`FD7SFm8x)kRX=%XH{VVfN2O_{USo%v*m>>QjU-mF(Xc zVXt2u3k^!iaQH`sKn#n4@AkT!B{3idK1E**OlFYpD2<=K3=0Yt1pJOkwgh2uWTjLF z)DRLNCZgf&kS{?TOY%_%j0|C4N=0({)%L0l^qeO6YU*5$a0bswa&3zxrOZNX9YiQz zTB(f)O~=@o5j_zl+LeD+LEuR{ut9CWgiOri+y$f!(yy`tKoDgAP5CKUD6?az#C~0Z1K204^ST(r?ST#dIXbg8s%`#D?p2a7ygoCweg8q|j8VOMB!E zo{-p3c$5w!%#E1L9k<2lumLMb08_YMPqc}bJj%-4!aYRUC49~IA1;NVb#Z!Qr zjI}};lyCG2;I6=sD}-?3%7mJJDp4{SAiUsJn+VVILE?IZZlrD?HJ4PXV5AS- z2!JSa4%vtel_7t`L4^c*&P(nbG8;1qLtP>v96KvU5T^4n7yn8cTL)!fumGX6hdhe# zO2wn(D)LHUHrSi;rh^qZBm%BVcUlQn%F4SGM?r*dD+tCtmyr5uFZ@ExSUmLhDp%k@ zOa%uW^d^9=b@0!)hyPmedKer~Xf$IK6S?$N`L3e%su+LP1R=>0vC$=8Rv~mcMx~Ej z*=b$O#ARpr8i)FNv;Sa;T(+UE`L8{+fg?t9Oe6JDSM@o+2ybM>9aW-&5-_Mm!2?obe)#CkEeyPV8(OmH$jqe*`g@yD`Y4r|Lz%%i>3gDP9n>HJGZ z$ZhsrwR^*}o{3gj&AjWm-cGeCbW`g9RH~PW>tJIfccs!)apzMYxkYqlu#UKnD|v51*B8OcYT;2W*Q*!}C*}y0 zh~v1MqIn)Sc>sTTk84>QwK$-k`JE5Cofm)ln-Kb;Q|#W$7DPce)o@~jlX-FZy_5o%5D=q>CwNh6nwEl9OzlUx(O7Uojd>Iw zMS*k&N!2iwoEI6vNhFQL(sCO_v8Y9MQD>X|xqtEby>Nk3g-7`^no1x5m;d@oVH$s; z*LjY?N+}eVuOo&e=&!$O6_~SzDqrb$&iY#iUU?<>0;AXhB9u)~0t9N%4`JbB+{{m|LBgiL!lQ zg+A|tO_W#u%~6kyAO*tA>`i|K>5Y)WAHzs3Crm?xX6SopX-MR}UtP&&gufvSv19v) zgLP*DK_NYbD?BWlcE9zjy5$1^8??fgkv>2GNFb1)6jH6E3{;qe6e$Tnw6Lnx&Xz-h z2{yHAWoy$GQe=Q6KtN_=t1VIx4s0mpBEcyKOQs~Kg;Xt*AyH~1<*|QN0EIbQ?aY`Y z0gx>w2};;x(-zXA3R@-!;IU~_Qc4!Gika~LB1(!oQF2X?gk)1!nkGT z)TeCSs*U^BZQF}>72=f(mv7yfaNLeM&SWdGR5GJWU@GNl^qqpkxY_4HULe z;ZlW57fuzX_u~O9FTnE%B+9@64@~f>kOWlFyay4CutD(-WRSuLE4&fL-fBj zVI*qO1U?P2P&`Ey(aRM9O_fx-6g2e202tg9Rap=GR7oN$T~b$EcfFOtMnwZaSYah- zC$lcq8WwnPR?ju<^m({(S2 zOyt9+zWnyfPYWdelT+UxEd|mvk#5hYLG5I5$~dtku)UDNa}iXp@9zMG7a!=D9{$MK+Mo z3?pZeYONb*(q_T;t{ms2E8aTmn=?I>?5sWgkZsGKX4>_CqML5Zuz8hrw4JoRRkqPb zA9^~lYk_BXT6c#u4>H=xT}#*mW*=x++k5{n8VCtC626@+7~F!QYOmL^>F-}}9d-Z# z;OXD<0X}Hn@!5B}JTkId`q!|+Cr`8cL8~hYNJPOT6l~NHjsU776Hvh40rQ7I1tySz z3v^%uA^1Rl2}Y2D6SQCjF?c}@W{`s$^k4@;_(2hdkc65s%oP5jg?3~oHzRam31OHR zgJ4D@s)->7TXPU4DBy)XyrB<)_(LHEk%&W_3uM4@k}Vvi6=aZt5}z2wEg-Rl6)a0h zD$^Y1U~w|d;@}o40!1-Cadu(^qeVa{1u|@+A!~GhqFdDXkxjg@h;=lg9r1WaJ?0S! zT9A+b1Smjphyn>n$m1dRm`Fq}Qjv{h|KuYb8A(WLkQ8CkOcd(j5YtRjk|bn^u{2T= zGIXtz%SwV)roc&6mXejLbmb69f)Fuoq$5(aWh^&IK`>GVi@c1?FVQH$E!cw;wjAT5 zoH$E=FKSQibsIu|fY+ojE@5dxHbL-PdjNm{Gkqm5BrQ(> z_$fM&J#s;ma;Bi zJr$`uDXLL64HXDT^(pzpn$d~56r%*SDm^D^&s*hGJx8TydXiN!i8Az~9rbJY6lRNm zfBeE8zxcwp|JFUW5b!EyWg=vxarfW*Pg7Rzy#qd!1@= z{%TRxdUUm^W$kOFgjv}3RXtG3LMAYQjr^!099ckwtXj)ewMa^Ym#SgaD@b)TQP%?1q1aV_=*$3j+!zg1q^Aw;dBwXViJH}`(v4_z!Aq%S zsT2rOpvdA1`W`80^=$TQD|t)D|6TVA*i_KnXj#8H;FG>pX?A^zB-ZHEldwv-Y5}fT zvrE!t)x;i(3xJv;W?uL1b-ETRXg&RQSpXAvp}pCVKd1U#>iUSRXl*V;MYdvp+M;*6 zxB^#a63bb{HWsx1ARfXpOIani!!>;EZn#3*St)&}?RJ5H~cd)#H| z^>u@784e^gfeBUM7G%T8=9{$%XF122&U3c29^E{Uz>r~yqp*c)ssn^5M3E(RHne8) zjKAnCFwqGhG+qxK={rwa(m#ZM%@{Fl%#?1BshZBT7;`y6PmdakRs3M18%=3dS9;Zq z_Rd_85C#7Y5)x1tRI6E?Yh3^DTGzeib#GpHyfq%ghHyaw6!VvA4k{bY%!Z(|5!mUI z@DPO1e)X>_ZEb9GvNfb8b*M4u?N4tygyS|hxG!jC>1f;6?cR1Lou(y!(LrcE-qb=Q z-eb-$tITL%K5NVN?a|MLS!e+#xU(*cB7h&<-(ohn>+FU#k=+e*>=_VBs>KQk_++E& zAs0P$nmG^6dqAI>IK6CzI-*{xqn-;C#m$p)RizGZWK-7IJN7~!;NIxo0AyBA=cg3!H#A1O4pzA)E%-; zwOZwtIGPp+eXkT;$3FR%E@mzFvY zl0^3r_Z*%oa&fQ>o`aS#oW zI1Vuo!6Om>uo3<65d{$uB~cLzaS|Od5-E`pA#o8caT6;s6FU(TFOd^Lu@pz~6DbhU zPEix9Z52a*krP$X6Il^~Krt3YQ5Ihj6>-rNchMDj@fLxR6n$|Rd$AaUQ5T<$7>{ul zkx>|p(HWO97nyMy|D~}Se-Rq7u^F|I8n1C1tI-y-ksB$f8^!S%y^$QPQ5nNA9nG;E zqY)k9Q60}w9^G*r3tAp_DO z>+vBQav~+NA|vu5`_UpZk{RFbMn@`b#zD+QA)36m-dGcXTxFb#7t5tA_$voRs_F(orGDbq14Gcqr8 zGA(m6F_SYjvok^SGet8qNz*e+Gc-?gH1qI(Fa5GLS@ShrGd5v!Hf6ImY4bL1GdFQ_ zH+8c&dGj}YGdO{BIEAw~!7??G(=?S6HJ4L0ne*_z4B#SfC##S;pVK+9GdYjaIea169&C@#(?mEv?J-smTmNF*66Fu2;KAkf@ZLmD$vpeb2C5f{?jq^W$ z{WCxTbU+2PKne6f4KzU!bU_ugK^gQx9W+8AbVA{BKl5`v^)o{;bVDz+LoM_}HPk~x z6ht|cL`Aej|3@@MQFKL5v_(zyMOD;AV-!YNltyLL2PHH|DRf75v`2aLM}0I%fpkcP zv`C5cNR2c}k#tFwv`Gu_Z&n6Mr8G)esdP%Mv`Vq`O0_ggxpYgtv`fMCOT{!y$#hK3 zv`o?TOw}|^*>p|av`yjkP31IB>2yx*v`+E#PW3cT`E*bHv`+!`PX#ql33X5nwNMfD zP!%;%8Ff(|wU_d|0x5qLQVoPPT(wo}4`O1q{>I`}Wz|+;^;K(iR(DlbXZ2Qj^;dJ1 zR)w`!i4|9k60S&Ow zUJVvt{S{#gwqY6eVFMOo7dB!kc495|U@I15FIHhSwqrT=V+R&w6EcXBKC7R%czdW_k8!b5>@Bwr7dfW{nnTeHLkp zwrQF6X@eGOmo|TDsdj3u_GqgXYp+&mwYF=y_G^b0Y?C%@$#!hb)@id=Y|(aW%XV$O zwr!twRc0Zi5S9ymQd5f2Mj~9BAS981c zd8t=VeK%*UvUplAZ|P-u&D46qcYK>ydd0VVR~LQBH+s{Teb;w--}ijoSAFSse(e{2 z>lc6TSAKu>w}1Kff1Njg%~yci*MA8(eh)Z*0eFBJxB(%?Ng*hXP^g9#!jWy#r?X!yE*omcBj^mh$>G+QAIFIL8kL!4k{kV?- z8IS*1kOP^I3Hgu>IgtmM3KfzfsWB;%s*yGFASc-%Em@K=l9Ds|k|A;|ChOzqdAzNxtgWfnyLAkwKk` zY5Jz2`l+Sbsj0fAtvaf!8mqYppR;6bzZ&ZdnyJ-Vso8q1-MX#e zx-E(Nt?4?h?Yge<`mX;on4`^Ftoxd*&pNRETCf3ounoJg5&N$do3I!Aup65W=R~$1 zTah8l-zwX&PY|5X`GL=QoV!`GKihw^%XykdJDfqgv^(3JLtC{`d$mdXv|(GcW1ALR zyS8x=AWd7gqY<}L`?YnOwP!n;f!Vfy+qZ$cxO4lsmD{+PJGq;?x1Bq>p?kVrTe_>8 zx`~^%gWI~Z+Z%Zsl)rl%yPLbW8@n$eu^BtPF?+qyo4xJ0>OvAf)%yyQ&+>oT$M_~D zbtc=Nxq80$+r9lezybWGQ_fl7Tfl8HMqGuy?VJ3B&m$4Mq0(Bf^SZ(*{K73f!!dk> z=6b_5{KGvw#6f(-|Iw18<$J)}JH`K7#Uq?XPkd}Ho5izJzGHmFZG3W3{Kj*<#do~N z1^U2!JdwQ*$;7MbfLz8oDyM%*e91+;$(LHE>6+sr4a%LIEjygavAoK&e9N`G%jqoO zdU)`39LdML%z2#5XPQBjLc*Sk&DlIRUVO~UW5NpjKF7k%+dR*kYBi>&wA>>K+_OIPov>z|ux=rkd^!S*MF6NRn?|A?{-M@qowII% zvxYs_O|9B0Y1sc`eb`As%3Yn?SiRb(-P*G_3f6_NV%-;t-PVN++3lefu!5nHz!qE~ zQYtMr?7`WI{n){M*vWtWhBjf?hn=wep%j*#p+QhQR_8`)&Z7Ez@=iVYD$FEnT)6r; z;k%6CzYN4~1>yq?K%%bXT*u)#GvVv}31Moy939p1DRHI%Hb`CyU?ZhUKII{U6Fwab z%kRRBoSUKwd^WyyT%;CAK>>td$d?cRBA_0~J>^Z_FeX6)lz@NaTjO|~gy(f-!9rpq zO#bL|B2vuousi-It0qbXL{iWShYNnT?n<(~vgRq9UNGLWU(dPFo9xlPs7z1n6+Q09 zVhq}W3yxk1|4`oK^*oqf{^RRagz)Pn;bR=apNzIz!4^TloVK30NWB1K`&qNDJJd z466Tuc)!emU%R0nG$KGFe4V_FBMEGw6s(^hSiv14W&MA*z!k_~`@c~nP)8#Q4XbQn zhY~;m03t~O011E$JV?+W!G#489#rTMB14G}9VWEckl{s*86#q(xX~lWkr+vaocJ*% zN01#SZrT4t2Bu6hTVmRTsidUHoeTx8RLQfVNu4N%79}ba=}?v|Niqd`G^$dlNS_)_ zdev&slURSNQoRbcWmm3b#g+|w_Nz{A6%y5(~B?9&)h+?ah0>?=o|8;7&vDdIoMRP?^OLKtDh#?zR9L1HZ z02nt-4!{vJCk3=Qb3~E=fdo>M8zyPhYRc*$ycK_&e(l@2YtPZ?PIs6Y+pmz{ns3Cv~Lg-+D9YUC)f+>QyV0Z?xc;Soxb;uo! zB9?!sVvaF}2ONaJrP!m09U3&@j}75?p^Y*Ms3V3=uIMCwGNM>te-q-!V3hwMsilli z7U<=TO@_%GcoV{?<(E&ciKd!uvY95FZpQg$opjzgCwC4SM2ag)1V9@Q0lWx+j0qJA zfE2cHA^?>!7Mg&Pc)G}Ar9?)kprm?oC?bEMOkTR{*AV7ta+6g5P5!6!# zpn(EF5OIIvmJkApZX!UcHP)vc0!h?jk1fR(>yELu)T8XNu1FC8jvKNnXSGYNiY=HA zx`_WkhqRGOt({0{pT#`DfwOUFV=Cx&d*lDDgCVB6+@!H!UrQv4l@0Nf0 zswq&16aM;UyYjl5VwMI6jBvebo(b^3@zLjSi=a|CB)=YK*)NOp&TFx|`<^`Jkt*|7 zF@@8r8u7aR8Mz~V4=4CCxGn3$>I4}rNsViY~!B)EzNSz79)RPP&=+9(Ol3y2b+W~t;iY#qCk*gPtv#*M5_d_ z(ncG_^leB`JH9Y`Y_p2nbBjAkC!p@oX1AQR+p2#1wY@LP45MXXMM26aF>O&3={0Ra z|9VWMn74@(p6iqVUa!wCd+oH}ZhP*y@BW$YzVim#>AcYhl_&X_#f)uDLUFuAy zy49H{C%l^=1|5jP6RI$UEL?w~3!#UA7RvC2G@PLgYly=e>M(~&i9i6CP?_2(suR3< z-%yf(J|3nJi92Kn5|G22y)|(l5opD1jFv_cfEnwFh0sAm%r2*Tdp1~Y$mbtnoFh=LTP zAOZLY>yZ+8Wcub40Y^$f0U_E>B_A1qAc-#ucx(tG9|_42G4hhe$z*oWCah0h&~ns5 zVk$8d$&ckye5jPADPkfNIsI3_)(}ffuqS>ItY@`G|nz9Dov{qshA4QWkqdS%~Gav|4hz}Fsg<9iWl?{H0SExPsamT5pb1-S zqLlwC8WNm8DKPCVs=wd3sgo@ah-g2OJ1t!R{qfOS9Q=1T3C4kVD9o06I zkE1kg(U3T=h@$qenBrw4J=z^oVl0=v6mD*vTQqAzuCC9FfI}UM*e6jlV<7(x5!Nnq z%!0VHnGsE8p{%=G;wFe7;{56((-{OQNMZ`$BqV&^tJQxCrVpGs+3bBHpa}^eh^4&! zDpMx3;=&EjDi)UD8uS5*po5Lde)?Xb)*DI z>jG|}Hvk+%9Sa9O}*kj*Pu*Xt$sK?M6KX^Y?(iO30H5{aNr4pp_;~_G;S0Mb#OS`rHsHx1AwAjyaf=GLN%ZJ( zWi8curSN0N$cJO~0<;SOjSaSN{K9~KYaj$w0Q3JOpH8I+)aDGZwo?ihbI z7FCUR9!7y!CdZ>=TS3C)k9~E5cL)N?GX^QJhZNDLUG<3<_qmI33ge5y(dF)L;7ooC z=q5$i#jOQ@yT95<7zGys4um9ZLjJ7Sf_uo_br_LsSr>lp*K5ev|2*2abf|M|l(laO z5EBo#bof?p1sHMFws^;eHBPs9_*Q>F<&k&}lt-*XYhy4$QlJG8=YHYD5L(cFVdpvk zbZ+i9Yd`lA!6tSpbR5M}5BAV;$&wS~@DCkFa_#|m3M6`D_I!ee0Gsd+tPo59KpZ~7 zay#b`wvlZBAPG7bf`--#2&g{Jwn5xx9@(~d?Iv`%wsi!@9Wa3sMi+hOvvq$2l!GHt zNepK}cywqLjb$SPh(kF@0BN;%%5Q@@u zLxCAY6h}xwadZ?sV#W$Sfj9kwY^_iaiKi%}a1a#7EbY*8Wzhe9I$?CMh;jBn3I;K4 z8aNeT_&TVe8qya5vf)1mVH_hGU>y@%Alm)HV&lcG=-k-^dD4^&$6oYfGbW_GMR5wq_p^DT{v>PASEYOVuCR zM@kZ6G)n~^K12o@<|^(6A9^GOMMEBu<&O%dUktgBWu_u85lPk+fzYID0jWU}XAqju zk4Pjs7jtg`Ia=Q4k;#-uHuXt;(jJiHPBGU}B(y(VpipgK{}1%Ym9{`wof0*s(h%hU zas|P1WZ6ff^(=e?00@7;4z0jHR<=Jv8IN%?Wg^Lt;B-aNAZ;a&^!RG_I;y)q*z5<>`(RPZDlN?`wJQxF7hrT~Ja9D7zdcsVKI zi7hXJT-s8T5~_bD0ir98C2_bxSuJ-#j$)g3qJ5rq5!Z5Fi#K^%GGqQc@> zY@i*b&<^+m0JdQdZ9tc~=M!i{3gqAltcV*}wThWD|C?hIh?HrV*+B-HkW$7;9(Agv z^R!gzBA9=pfIk3onfg^p>4c*Em`caWTjc>}uyvFqB?-%>PoC5|*5r=7gg(Voh73uR zl?kWe@ueekn^hXC=UT4lS~6scG^yfWP?~XFG$c_oH9*xdQX*tYHZ1OyNJBMIblmrAhFiFp7b!0hY6H=>t=N2FiybvIZGd)2 z?VxfFp#@;MHx$>6iGpeR$`H2E3T=QOt_y#eo6vG#i4bHUjy+KeXBcq1l0&jKx>)0q z1OYnLmw^YNx$7o*$VPT}>v#;2xOE#7&3lEUU_xYyxzkz`CzylWW@T`Te`Skv1{YOX z>rIhJzuH*8$zwg(Q;Z)mjIG53h85hmM3;G1VH0xLkM7v zn9CXL%eL_qq#lgsjLX?x6o2q@XqjOc2(XZBdYW$EZK9`?Y^6WPX6? zq;V&12~oDt)}CQZ%3kE8{J4lEA%iBg9d8uOaVM1LXt>&^9mQJZ3dENlYtF;ZkJc?vr^fX7kCJdmI!QLe;cyV4 zu4B`L@%#`;P!6Qp)uLPvk7Jp;%Uj5S9K-sSNzewUJP?viss(X#&^f`?mPZOJ32lih zhGv(znP6}YVkHSEYN~%&!CFWKp-0A8xlFCb5@&S^ASfHPG5#gk^Ad zXsmo^RBm=9&e$&jXWu9V9kY`5wy;G z>{Y2J7dkbZ5Yt8mrBDVh*AAt(G`0~1%x{VSK)g7Ye|xleoZ54eU%)d`XsI!6B3lIPj{{YdlNdf=~o3?)l6bK{%!-oSOM%XkaKtzZLwrq+5v7r=& z4yBMv0&?NQjZy?kA|(amLz5RhCWtahNdhDrBX~49v*F91GaU*ADv{{XmqU{lRSL0b z#HT@*B6Z5tDpjagr*5r^RccqRS-l$FiiDM`2>}$$Ow!6$tD7FRxGJO8XTrC>NF9IN zQi{wLTkULZDe6kV7SULXLw$+rc zhruSXRZ4-(hzMA@lVJZB>K0dI%twFi4RlL8%&xpoxiv*WDOWMCx`-F+z&AP)dr~5}1&;B*<_Hh?IotDFH3UKmvaO06@}0 zAS6cF;vlG!s3-`aBp5}DEsAVti!H|l0D&k{BE*VvQ(fl^@Q`su^x?Rlfg`L zh@4g$1z-yYVIu9q|IWy0$D)6MsvL|o34j1#HR>poLJC{_1a&uzctwEM-()j2&JIs( zv&S1x+}( z0#M?zk!9N%}>j5Rk&35Z@}Uo~UneCAqC_;g7@4-kgY(T6lxjrwCkeWxZ3K697DcJmw_i ziBTkWi~C-Y)D{0O0%(6_zWU~txTsC+2`!?w6sH!}V*R69MT#~yn<0()&YNJKmhR5% zgt0PGBF7vmeRPMAAYCHW9dez;(@mGibTVkPz4qGEN&rmRdq;it-%l5Qb>D~g9f%-J zHwg%!gvLk$fF#E7^P<=j&{Cj^qKH6>F*JDy0ZBK(TcUr8q>uo5>ckv&-ZhDg zLV}ijm6Vh*mJFWw{fp<{_}w7@2sK&X?q*jdHaJFpz)O;!T$g|wl}>;W&=L>=5QLFz zVM>xH(nXjDuLJ}@fDqim${6@R6#7qvM%oDJT3EiU$W9@S(Zsw;q9qARY81UfOzlc| z!W9Bgh$);M{{nwT!JiO*OGN5ZA9dct1tmy860K+jmpY>bcC_JimXSgmEChrkkVA_D zp-=?c(vI>#ut3c7n83z>Mi0K{UqcKc5$V`G$M6m(TG&ef=(mOWu}^?n*kMU#I68$~ zXuBfOEHTF0g_NFo3i(IXH* zqPvO+k9^i!UjbD|$NuFKm%FrMFYEYAUHUSZz~m({g;~sDCKH#MbOpQ*1{Y6vB#PF%Ebu3 zSS426PNab;gI`2iI|X!8ng~EbJd5N_HlXn*#&lf->R|>U81S3UiRS5uwFvRT_|{Mu`GDgu{hy0swzi^-xSykw6Bnl?HpqB%V7h>KG4N zjha97C|Ok~Kmaf(I4)tOGTmxd*-0>x&D&vhR_B--os<-1m_jtYNmhVP6_@R-p7>11 z1qWjEhvnnaAzMkka#7ZhKja=l)2dks=xKUMk_1R3z&ip~Pk%KuNcM<&*CF0gp&-;_ z=|+EY5dR8FfSDwWohM!APVq^mT&~=q2MMaB3yCl%*JB-Q@9K~l0sy*2iXcIS*TR+- zGL*yZ;Ug152Js#I^N459>~xdcq4V8wyj5wZ!PkrZPDLMiSW z0IecxHBwmRdxRw22>A0MH>0brz9?0^Mi_s9NWdr$Gub?sD4|ca%p(Fg$ixO>EG?|J`w z-v=M~!xw(>i4SnIm|_9u@kkU<s01=r!tcwn7GLn$G zyujlQ@0dOyIE(U-oL7Mi0^~hdiH}wwHe0ZdQfLJMC3JH+*A>)DS_G+q@m5I8Cs; z%-gWKA%&cqk$S@kwmA zyom`B6{ujBsiPL>%OCZ!5iU^|fJv0BD3kwdL!v7>g}{``kS~8T3=>I^5*nM4KL6Q0 zCYiXxQ@ADMJp}j}Rhf``a0|D%n^!?6-B=>(vk+8q4ed(+FOs|PbBAED2kiK@p#cou zIK{PaDVcZ=86?EO76hEk7JDH&g{kt5Z>kQKX$OMoMl>m(rnw8E_iC^1Ai1;b>lZ`ZL&(^V=@lPI8uv| zIr^Uf^p!QixG4Y=b>pSPTDSh8Ah2{72|_vvC=*I5%kVKEEE$LhNDb~8gZYsHLMkEZ z;Up4+EFS|x+cql0B`?!13CWdl!8Fk#oy?ld8#$!` z8U@^GtpKXAF$1m;VkrnhsLh-*#9|SfF%1~1tqGd38j^x3xDokcoJo10cA~ISN-1@c zO<-c4pQ)rUq71~Gg1C^Q#soGap&p2wl_0Q{_3?kAhe(2QDjn@ylZx>ZnE#wib~r$R z$Ra0V!R)zA*utYE8YO}dEBbjC#Iz+sl7J{eI`laS`9y#kQXp@ftue42dua$ZAq7=Z zH^c0a>k^+JPz}>?r0tS|5&EDV!?0)yG66EJi#R0V2_OA5BTi#W_e4$;WwH}ZQ50QK z7v+B+XM!e`njV3um?~1KAkY;~7&G7jr-Qhl0FZ=qlD{bMqBIg0QY)bmqK-@~5cxYI z0)vtIL&yNo&dJCpS4zs^u}C5@&D4lc>SRzOvIW{$lT+Ls9lH@Q8>O~jhxk-c#?e0o z5Q51Z9zrW9BN-&fAQG8cH=8npx_Y41fTn-3ni-&?)2U*fpSq8vatA8^XetG`$-)YP z0C+OhiGnSlBLCRdWDNG)ol0fVC=-*f z6shCfpN5dX$s7g3;!$H-%NH%4%EO%|>QF7g6xQocv_c(s@{{thPyXzzJOWOGNX&oH zA_YyT5+xZ-XuV4;p&mY(mr+GgqH-s?k|z3DABvEI*E%E#A*8CQ9xqj|*}1Kms)-FC zvE|CuVNIqRsn=wjjgY)9QwmH(Qlo$PwYu2Q z>YNaZ5|v8*wHBEmR%inPyI5BWEIVu^10%~f+E=Zr1xkWgE>c|DXp4M0+s7b-$bBe3 zu}F0r&H?Et;P^2ByS|hf?Sj0JzE4V})AKu*Hdj1|bd7cnbxlnIdi(LaIG5Bn6K9MJ;sDP7#n9 z2F9I`te=Yl5jGHIgb05_Bs`~Bo;aR}$}l8pjNCDZ+#384oByDfO#tMFfZ;Tz;YqG! zO1@-E9u>1d3j(AF#W6^OlD?wU7}r^%QUJzan)IzI|D$0l5&(a#L^&pL3m4`=2HeSq z8;=NFk@P^qelxHDkPi*y!2H;c{um1mQD;nM9SUg)Dq7?vd^X3SWfuImOoj@Os1cjl zU(>{A5-}Yr1ZRL_!zIkaB)DNtS(1%_*-3GdER2#-D!m8R643FI2DZI4ayek%EGDuls$hI2tJA6iln-MfMMuql#yskE7TYCp~!as7K)UBo}hT>pvjbT>=sHK zIca1FehJ_r*2JG+6<4`H08nN$HWFsB3@uOsS1=Y{@fBcs2vR_100;o za>19FwdLoS= zhF>bO47YGIgDPdpSPNLmpAFNZv5OGqf>9?7zKvYoz!+{8$G-`nB$?OFj!T%u~``~Wh;~XATio^-0r#=zk z+sNwxDZ8WL2qUTjovcHH9U#gjfhGl8(<$c?*1KvYAB!=C`miL5An9s3tQ~CbQBsk} zAFDMV<7!Z`SYYVV!@8s(cU>_Lo0mnxTmDHM2r7S)?3p2Sb=JEMpz4Ao9|KuzqEI|x z6LJ;JR3-4a6(Sh>D~wXuRtwu(_A`X2pa9aLj*43JL8x-QTjn_;!tfyPJV45Ahr&8l zI5H9=btQJF9Xd+_?gS$-qK&-kQzn57E)pP!GVY@>7CFjWyBej=tS#``opo)i?7}Y+ zny!Cpost0Eq&wqK$h4Q*sytN#9tv%x?dk|7v5+&? z8wc@B-*hGKbT9l?OBb>tgNf-83}uE+yIya%6JF>C{k8PBP|1PNcuJ6 zHEzKu<^;&1l0}FLVU=`6n{?eGxHz;=rQLsOb&M%ll@ZBNTf9m0aSVAPpH^UpK_9PE zGo6hQfSlr`OfvGw<&q@F%s&+>bsJQ|6^UXADrM(205U4`$zEg@g29PYsG`~x)gwbc z9!)o^z$hWYm3PFNEd+7WQ;pRb)pWZmC1tmTgc(Eodd32v4HO1fHOG z9eMpN4X&$?=%YXh`Li9d=GuASiKWTZqe5>khD`!Ovdx(9E+fyOjo5aFH=goRunWTW z`}&C-)dE}4gjI@w{`<}SQi;)^k=B1`CjRoT@(8d76H*#c3|~PNG^ZjKluTlIu(slK zJ)*F^$a509rFq#apIszQx*sFYF7zRB(&41U|M6yL9Sesr#pfZP6|z2gxx=TiBH8oU z0Z>0HvJwT+u#bJHpM6i4{n`)KD+^S=rL#esoi7VB?qwdue~YwGTv?BREJ}aZyi)cc zQ>f7p5j0;S>S2W}zl_mp1#-Bu98D7VpcT;nL?eI-f0P}oM$^t*gvB6Gw%K?dGGGOX zVD=<3_~aa+M^gf@?_C4hwjqPC%?Boc2!bGqtyTg8BnXlK2m&Nz>`g$Z5C{MuQtFkF zxG({djY&Li?8wk!J1$AexKe*G#T6F?rCcfDVvmbPAPGL&#Ca2B&V)N{0`=K*C`X<} ziRw(capJ~Nwd&d0lmI|Tqe_t)9hqR1t#%Ru001DRoh?!d5ZYvV^(Wh_RNqdG3wLhZ zQLLoMjVp!K7729!=C!zq)D#7P59dvg7ct|-ag&h2OBaZiO;X-g-a>yh38}twrEJoC zbBrdrTP#!EG)6I8&Qlz7mJFH{YK{PTK!(4hBrpHgGD-@<#=adNmeQH--j4txgl0W; z%k8;iq*Bd!8@O=M(V6>=i)73e;NE|OPq2kdN4k(VWq%8rj7=n4JTE+bSuYvmj@`o# zuirb6QLSh`)($D@#id0pg;Aw{i7i_gCR{5YA;C&3ws4UhDQ$?`6>#^_2$0w7jnE9_12pL(RUqJW%S|GfqmC$>ly;($iB!cHxdF4qn#Ic{PL zCwTs~g@(#SH)kv9jp5D|O1RS8O=T?l~yS`e=S%IJruJ}5IErpe0P#NxjC`&A_*RDAbDSXk2?Pjgu`XO9V)M5`Aq}G})TnLacigl)u zQlFdB4VP|?-j4b1xJ&j3KopzE+Zria#HGm?_DMlT6dZ9!;Bm<<7KOV21aRJx6Fa7* zwd$&=C6l%FmtU4k4oQlAC5G1RV-=M2agokz&sjaCKS=q*0$rD-xGpChC+k z|0ZjWLmM~dd^A=zcfKW~IOZ7YVGP$Odw;wdY-Csoz?kv^fgN1c5j|JkVufbYxRL2~ za9-O%S+wSCZ{9h7=bwAtd8bPddt^PeY&otiot(8XTv~X#;({QyG69Taw4CLx3(h8|rjz6Jnj_UcB^iYv$F^{5g7 zs2qi-zn013CN8~_#7bNQvHb*7+(L!x--Axq@0qhL0;8~h2cP63g$bDjTUJ?MPas2z z4U{S;%&AXU+QPsDNP-Nch)4nuvX&3-{{)0s8KDJ%AiymUg&_&!jR#L70TmKX0v}0& zo(=|<1cYXW`3h6bWXP5r?nER4lmZd}5QHs|V+%MDpcZT*1u4uh0`Zz50!S#r5@Hc2 z%u(V_gk!URBvfaHVH64u5m!awa8NBb!HyU^W1o*~!X%Hm;1%l#lA6siC<#N44gq!< z4#`m<=Xs0^fifl^g3(=Xq$4A(xDkdNt%D0mLOEK27aAI83+t(8m@h`mj@*2!Iq^0UZQHp$)9~N*=LDkx;e+PZQNbJ1Dq`R+uXZ zuyjg2xPzL|f*PLP+I>sPloPQ&MorP$6TFRSpDPJ`Bo2d$O5J zDL@DnX((A0dd~=8L?fjlg)N}kPuXok0Ah82C{SV<6_>zd3#wAW_7uR0TX6D~bS*1G zyZX_w8up4W?b^~N$`P@YtCxkH|142PL5A*ZQ={!_DG2?!iBe3$qlu+#FbA@Q(!BAE zqhRSrx+xh6bY@{6i9)-e*~1ya6Q9alBOgNuN-_uu0h*xALVhcWDsI7@HB*2Z2W!}W zZfYU|Gpr{UH9OBezLlM$cq{;bikF}IGeMLiXKp=m8qSi|j=C+O6fCP;V{WE0EIp1Q zeX>HtCbyVU@nCxYc~4DpK`~(hf>tqk6X9Btxu7IPJ+f1ZRfdx+-9g9{7eWfJgh{G5 zG$A`|Q4cKuKnwMdR}dSzZ$5#b>6aJX3^Ka+rf49fO!vMu9|{dkj# zY|)pr{h~+2xnn7CagR{^re*(ilEE;QY7?wjtxY5m;~WCnjrx6ze0`}9Ao`NO&h77* z%S>1^xA@FzPBWY7h`R|0z$Y=o6>q4*6eQepEEx%aK!5_uC$6z1Bo*{RBEX7&Hq_am zRIy${!V*hSY7$o6bq`tIhZY=mMTRoeea5z@YnO!Xd>H4zdw*+v;TJ z-Z#GU(yy0IrEZs|ZRR)=|F6N+*GSGJ94RV%98I*FaF>;ULk`}qgP~O2K*CLI{(Dn^ zHT+)u4$R1dHks4Vrs0%>i+irALCei$3)>Y~>K^0Fc9SCE+9|KeN#5^&o&!Bk07KU{ zA!FIWXxT{bArO)1^F|G(T<7pEIlzBjb>0CeC#`TtK?*Wd1vP3ptzbn0B*18~Aj&=$ zQVbBaVEP`yXPZ`DKlJT^JX{@m2P>QAqVFGOK#wO zZu_9u9`}RSeY|sD_1pjdp7&JuJ@I!R{NoqD_r_2D@R#p=^71RAOnxEiy+LvTO-5eeVhZbys|UrFDHmQbB@37)@*hydUT0s&eS-DIVcmhMrhQds&B`fSAv`9^@V2 z86KbJ(MG#<$Caqu|6quW6OCSUNR0ih8|Yb<$fO3;OdSsTA#XX1ynWv0SqYY0N4%wqpnkV4kf0>bzMUX)Zj0MGU)Po@+E4;Ens)&lnsp**sqxrq-4#zm*} zjOu)f$`m1sw8R!z%@hb>0Fc3_I7?yRq)Qs5ItI#rQTm&R*x@!&1}nG>QX1cy?VwHo z#ROU9KwVTKG+Y{y7FSgpNOc5LwUXsz`#SwZeCVGS>9>t7u zwUZ?E=0;2c{}%zl6i7u-a0nLx5n1Ap5*^MK9_J&CCPCTd6S<}yRcC&g5M0tGa&D$} zN@rQZW?ERq6D5~quE=SDCs5EOz*X6jUFH^l=w*1)6L%g{P*|TNv{Qba=6+h)M>N-T z9_E0`W`Sy(fEuWpg=W^kghw&w2-y`U2ta6xM0zS-Ttu5>m=?j*#6WeYbkb%IDTW}O z#;@F2zJvxT)Iv!BXjU{tmxP9E!XvGs6bVJ zs0pQ&M{N@+qR&{VP*iwSg%*@ZHL2t@DaHI1Cv_-+9_T_TCW3zHVrJAwWyww`1eY4r zg$~6qtrq`2UJymV%qWyWU}dI5eJO#0(ZZZ588OvvmgZURr)!?61aN15l}Z$7L7`p} zUD~B4ROUcvfu5>mQbej-%IA6(X<3r4K0tl4;R# ziUo<~+HI9+O*F)L)eNx$5s{*)6P;E&?P^)d>UmCAV`gQ17UsDY8j@)NcK z*41f}t8{`9gcKQHe&?%tL=~kczm^KI=_g?>*c|G!#hsJ4T#MQM>Oo^dFEZf?Ykl96J`%}1cci=qcYyvKKqY;}R?X%K6r#ad>5 zDUy9GJ?+;*98$B&Y_wrk#mKCGJ2}*X{+La)15Er%rxMnk9#mQqL_45Q0QecO%~x?P zW?K2go(dFY^i-2l?G1$m_{{`DXoSMzn$t3@BdM*mQIUZm8_%|AfOb`y*hS6lg&vxT zU=YUs{KoR|PvG{C;1cfO8gAht?%}41@}-GCI$s1TZsJOA%`S$A!BtC{psw*U?(teL@?!7xYVYAzFXwdc_Db&e zX7BidFZr5p`J%7?)B@>$=m$!zkK>-N?!lX=lm#=U@80%RlcFzMaY7lm++ZM(>~0}j z0g#-KODT{-@D7FoOPtt*UHslJ`7%^a>|g|)6VVh@1`&or#4X!8tWlut2a`|;Lu`9Q zENPK&!=~`sif{{;unV8C2|E(k!Y~WJ@D0as4eKxt+i(u^@D4wJ)F?J+iA9IMI+jz4 zg{M$k4>Per(U8mjFoAxRS;!(gM|3#R0mV? zss^ALLxmZ0u~cC4LX8qFd$B8Y$Ge8|D^D_;rHPc)9J{&Oepp1q)f@c2$T2}MwxlG- zh%ezfB{fH|G^@!p!z1}d^KWRgIX-jfeDgSGvx|IlV{|irIA3!(v-3J@PU6g;-?hR> z%5y>N+DS@fn~<~jy0dvm5(FnetO>LZ%}~vj-Si@}V#tu9>_h@cS3g5^J5NT#@gV;- zmLodf1V!sE%~-Qk?(^tI^t?3R%$##bf3*BsZ%Hq1_=a>J?yWqAhf33a+5J3qCeAd+-R z^-$yVO{WP}Y8i53Wo)hma~hmirfsHft=Fma+{{TCa7aq1$&cnnuYhfVzy|4c6YD!cN>LoC$?MTH*C*0YJ+!w z7xyD$_n30GQaCqRg>YbpcX$(MgDXXFXE$@B_HiHh74flz%e9ANMP;uwhuMH1uxtZy>ar>p7eA`JVfEp9A`z3wodv`k)(np(FaCBkp85&;~Gn zZ&F|-MlFydD}n}bkSl{w+cNp4pLmmVI*Nn3r*}Goi+ZSkda0xOsjK>^2eDF&GJ2Uh ztH=7RBj~HYdJo@vs>`~r3ltn@)G(RTB6}8S-SMi^y0Pe&|d$U9PvrD_O zN4vGzx&`IBvs3%Dw{fs@pFGHCe8{W($hW-7zx>NoaP#t^z6ard5Tc{V z!Fy|fCpFTuywb1y)4ROWM}5*$ebZMx)DykcYrR2! zywqns)*JoSPrcZG{n$gjOJBX&t9{wCePSdQH4*0{jWufr`B`u21dDs#hx^`dd*2^6 z-)no|>%HF(zTg8sv}=0dEB@hsKlkDjKI518(HTVsJ(HUTRC?R7b=y(mL%!iZ{^xr> z#8SB9gV5-AKIx8n2I=P=!~{_2N*>c9T%^RSfTZIvT4mKQH#Y`OoUC;IREe((eT z@C$$O6aVoWfAS;$@+-e%!1?lw^?um-?!!j&GynBlfA(Yl_G^FlbN}}lm?+1thaaU!@lgpKkdtZ{m1|P(?9;t|Nh&5{sTmUBng0%o0l?w%ZsDP zY!f037TYvS-0jWu^oO@KWNym@AngZMxKHRH##BKCLSCDpjjnv2xYg6)e@C zS;=}W8}{s1wQ1R!b&K}x+PH1u)|G3Q?p?ik`Ofu=_wU}oeE}C1Y?$z2#fks(_H@h> z@?*)6H?fiu8L|{sn=Ttn;2HFPz|TOWkSe7F^JLYLB@|r!I(2N=vt`$&ecN_!+`D!6 z=Kb4uaNxs*7bkw)cyi>+l{aVp+=l}yZ6c}ZUk;WG{QQ)W`7#oa7#Clxp z#~pnf5=bKXQgOc|kvM3{6`4H9$tIyp(#e9PjB-jVv!rs$E1#^AODBsQ|5HdZgFMqn zG|fzt%{ARD6HYbbY;(?kH|4yOPCV_*6Db!ragU)BXB@P~Ldj4f03FS;PtQ6XZ4^>I z_nh=mq;y15OE5JJ^UE%`^fFUV!334lP&bv5(nwR4w9-{wZ538kV`Wv=T4}xYRa6w2Ky+ZMD%}+wHdDp8M^(<+eL+HfgfyVTgMETZ;L}m}2I?fu{TJ z#PddcalG$}V(G}Ko?LRVEgu{6%QerOX{OIsobk{dA02dm(i<-w_0v@sJ$2Ss&uTIW z0iEnu273cyK2j7=fHv8Kip@0%9`@bA2!TI7cI0hOp84gSH@}ln{P5L3|9bM>k4y?pM9F0S`su&l{`~RZ|6l+9`TyU40R-Rx z>!*np`n-b)|?(O13(zHfmJY~T7GNWlq45P}uNpam7EK@VQ=gB=9n2t`;z29ofE zCX}EGGl;?!rVxfNlwk{5$iW!W@P;;=p$>7l!wvp0ggFGF3u*Ws6cML1oR9!-Hux|h z2GNI5?BNu7Sj8dM@QP0)q7|jc#VuxWhe_B10m&$TKr@yRjb~IN{i1+?FLn`(RE(n+ zu?R;yz7dag%%c|d*hfD8F^YQ(q#p+fNIBk-kcSkcA`?kRLprjMk!0i}BWNNNVJayX z*#s$Wlp}*I=`dn4o$E+BN>jEjEoMVP*;uK{R}dCF!k)0xbarbuY=C$0cN3#2FmqAq0+Vy@+I(G({% z$@xsTDI}P@)a5MK2}^apGoA4S<~-TSOId0uoaLORKF4X!efqPX{}gCI3Cd5lA!ihp zNe&P;Nd_SZfIHGs*+f@HQHv6XTLf{VMpv|d(TsYGqZ$p$wnQ3IkbcyoB>gCqJj&9M zcJ!qem1#w3I#Zk06sI@UX-;`sn|}REqPQuDW;h{*0PwCmPTc8Id)m|vom8bQg{etX zI#ra$6slH@DpspnQj|{BsbOX6SjAdavYyqfY2DXZ5(1xq|B13BNC1E)O2G*nC4iWJ zpTkN%0cy~`{&k>j)9YXh3fRI1*06p>>|zhwSj9Rvv61b_-lPHm+>zr@iDC?28T(kz zN*1ym(JNF&8(PzTwzHr`?P*op+SI<5wXgM+LV<%j`s{8ytwHN<)!JL${?@m_1@3T# zTU_EE*SG{SB7m5f07)cl0sufNb(ETaT;yKYy4l6d!UCU&gcE2k zva2Tt0?OFFceU|#?N!dEBl~(2zXR#!Lhj37yh^0M`2FU82OMDi{@1?)mM?4<4Bz=S z7{U*hFoP#d&}u>kDz@FGW)=Li3P+g2AvUFf2@GHa3)ug}1wQeCRovnfn>ZqW`~C2Y zMGRsG*OcN=y=6jWFjBg$Vo=>l9k+KCO?@#KV`0N0Pq+ow4$ie z#ZGmh++`oEbfTXHsw>37R$Zqz6rDL0|gNn6`ALIlXC5R~pm;-J;8g zRqA53p;@Ue7$>E$NY)Gwk(7+JBq2FTTF(*Jv$plFYu)QyyZYCS47RL+t!rN!d)T)| z_OgrJY+*C|+0G_*w5dI9Yb#sY*siv=ug&XZdmA&k|63TsNP%i{Gm8v=T;?^GzkTd% zgZtg?-Zs4FE$?pIo7?%O_r2TQZ)o%T-}=V)z6I{@cpJRl2nYDV15R*<4;}f~)+ST56w!b|)aOz%zoYdjH@Vw|q*SpgBp7*{d-S2@1{NVQ< zYL#MmN^^%3-USY zc=wJ%Qu3K^{OVKR`q{_+_Lr~y?jtUzW4HbIXXum)E!;$9F3 zTabgojt75GV02IhbC3p!&<2fg2ZitkhY$#h|F8*<@ClQD&tvq!f{K$k*zL* z9N&TuZ)Nu^%OIvt`o_^A&QTo+0|=^YN<#l>@$Qj-E;dCTfpHgo5g`9jAbW8k9;(k62fCnd%QrEw>95-5LCD1(wHhw?l=F%)Yf|DvKE zf9INi65v3jv0u)fBKu znldBzBLd1Y3G6W;GH@_{f~2AV@ye2B^kXZNMgpQhc<_=-fTA$q5&*Q)B-~P?K&deG z?X=x+aQPVV26F~n{K$~d}v6C3Nu-P*3eI}znV@L~Zp&h0yC_Zs`qQGXPCknJc zT9SqvX=XJ%&@qq_6nQ5AlE4;PVGE!okfeYw-G>b7!4`bR1t-xlwxAvU!5&&cS^(fd zOF=e2REi>i6zoAmdtwuC^dD@&7N-Ay08o%%Nd3VcAXGxt2S~pNDO;^MkYEg2;TP;7 zNy#9F#OOkuM@+SV-GYcmZNW!JZ3>_f8UoiZ7kWZ_yPqFY% zz3@);lu!eeP7T$2<}@}EATn!%eQY#FZ6^t|U`)5tOo7K{el!8zlrgI`H8>Q18hcX$ zT46~YH8>^}>?q?zr<963HADmFVWhwoymU=@!bV*|se%m&Siu&+lqbl*OQWj~0Cg#Xhw?|A+|&K|?4BS^-97G&mAs78^(kC{@?CpdM<0CLu^fx0Miw$`(ig zOZU`I*A-CNRZ!KnUE}p#3)NkJ5C64Z2NhA_wc$F{S;=P%9r7-P(J)&=Tcn^CB9u3w zBm&r>OUZyIv_%tqfgK>DNpRv|mk0o^%pQIrNn2qGfIzd_%MitgRPUinZGjVbLIRHU zGPKe+>rqqKkyoR@Wvj$&NC8G8tvKU~O5#!>*pX2HU=wEb7g~))>cLrm*uWmrjXGPw z9wxw(cmyYGffbxnSmwhPTww!pgruZk3^u79T0s&djY>#LKXr~ErqeZ>(mW|5X@K<~ z{^1uQG}5{$6l}pC;x;xtE@B-w@W=1GJfwq4B>0AyMt^l*KDHv!n8REf!TU4k@_ zw$*eba9pGv?7$|D!b+}o6j~udD-&i9;%=S9DGzT&LUU~sV0_~?O8+Ebn<9Zhwjy?q zF{{Kbf%Ym2f)sqA9keqZM>k8TlNG*KIxC{ehSnZ#!KgebuSOy?AGj`MwK-jub{(=6 z;>FoS^sx+lbDH5w^Ca2A=7rO!c`!~l^ZJ&!8b zo7jOBrtZp)H~?hT9%^A&ZWKj-V@UsjSk^d3OH@f&by1IJSL;CwdO}4<>U~LpI06PY zO!Qb=0ZTXG7d8O^qM#OZ6aavAi_61FmsT>2msFWlHIk-(9uT%UsCYG|HA(|#gx$f8 zXXDCHcDlglF_2(B{GpIp)kfi%3?xxgo2!o#Kod^1mWv}NI_^Zy)mG)SD2Lf7iP@Ny z5sQ1mO*=GQbwfz);YNam9qhp#TmdzDf{{rf0wM$f7}=1QlZ`WOR!fyPHUUb{mCou| zx>~^(Fu6H@hSz(oRWgMC^juT9x!Q-GNflVepdA4Dct69Obq72Gnm93IT6^_Fqkt9K z*FuC8j=A8Xr>e3)sw)Q$q{bp%&yhGP)T;vv!w-s-PufeQiOaKvXs`uH#z4 zrH3jiH_sP^ zho^(xXh>VZ9TxXJ+7}2Gwid+5bL-YP@1n1(buk@63fw?zRkR_HKngfo6V8(>=Q?dY zCN6V-^IPoFJ%sd|0slY(cvY(hyLq)`Xxrhgf}n$?GA@Hc3+lliP$a1lk4IrtiK3Hn zpTcoZ${uH1a_y2aAyOcc#%WVjXm!~u3xau_Sh<&Yxs}+tn|QB~f|iwHW}>kqh-qkj z)RQVUm)pS=?12jcU|Wc{7NWo|-RO34bz#qcbqWSTV+|R5o3FTsdKfn~MX0kDs#hz` zdOuVOJ~@heS9pocQvv%rw#`>12pl049G%AZSlib~2Uj3obY^sUG4Ycr@w*eD41HD%tTY)z)WE!l+eTAjFb? zuJy7W8JUGOTR!vuIKSr@!>4k^bJ)UnGrHR=0(jM$Zuf|N8^NLbxtm+dyPV5cA`U;I z7*(1yVAY^Ou`^8hl4U~z4tER4z#U3KfFHXSt~oN8bBtfKqdV1G8TvM^?6qZsj!{$! z%D`AQK~rtZ)Zm&YN~0B6;b2Y8Sw3}td(cT zq%_iv)*Uu8HlwO#U1f)qZ=E9oh6UJxg-2__N3YCSwjh;NS~Y}PM+ssR4)PCx9K?f1+SofgS1r*kh!+NFi+19hZrvnXjxAn9E{Ov=rt{=S0!l%`jiSXDI^D4MZV!yKZGna3+b71l7D(Z{$GK`l6h>i`eVO*k zzV%y=^}#Q`qez2gODx7t%wUH)I>-7KN&!dQXsm5P6INEQo1gcw_D=-(Fxf&;gd4(7 z%v=}y#G-oYIeKir>7+e>{aEXv^pyZY1-IWHycH%l2+CjeRloG*6E0U>)ueK$nr&Xm%TD2?F${|GIJpn|C z`O=;(QVPVkt5uKv%Z5CE_C%1q-4?z;B((e0Dm{Sg_JB8160Fp7Un}-(0RVjnf`rr- zGSQR(5|^n&VM<4B5kL^n<=2iKTNrm00#Puug?t4`p;JSD=EbJ~d^;g%ODoQ)SWkK9 zrDxFpP$NlkV|B1m=-EYEJj5bIGbOOeEu|>%nn_A3BvJ%9+2hK3=52%`M;vmP1bt(y z*N}aqLD))_JN2}MnC^84-$e)j*H9E%G=v*$SzUeQl$lxb#ZrI=Qj>1LXC%IT+#n-?%ImMb1`F)22@G3N5?g@gs}%PcOHhGh_zJ`r^{|3bcT@Bv1v&Oq zHy$Yw)YJ+aQn+D@g4QLsi9e)BA?$)k6v#@kz0wzdlP{$xxMqUZNztx#n?R)DgcBC1 zMgK+g{t5(wWWe_Bu3A9EMMb0dL*uO#^OS@u3EJxKnaSq(@W{rR=`gQSV5@KfQPg9f zuKrdBfCK@zqF~DRu(CzXw6)?2x-s_@!6v@6GDWa9QRl43t^8x2D-PXa4}b3NYfozl zMC=ZK7bhRvPRHqPGFuX&NkKNr7~7NYvB~yZiz~m7A^{Sl*ts;Y2xxM(b2YKpNLn=9+iT`RAU84*KY#mrnZWrl*ej z;t~d1TC=()0h+tR0ujY1_H5*DE-J>djgDd!uS^5anTc_whUw|?yUp?pcHWuL<%Xi zvG2-6rVEsgV&^jsNGl5LsE7Hsr-foAfebgALI5_lo=t>lMmw`b_<+MbppfAq*;*EV z1_u|kiV)3c;_6R&%+?d)v?3)LqaA_jF@~-=Dj3XT9D91X|@s4)PBOdoSI$5oMN-&)<3)w<5nyG+BG+mL+4gmlg2?2phNdH-s zOej(xEdXIPi&POQH~=*X;t4rsDoB@v0)Pk{Wg~d{jnEoGBT6=JCL@89d`NK%PB;lU zrmRI2tRtbIAOQhOLQf;_f`FWaWD5$J0EZT&30u@d0`mtc50GS;)>vpcSkyz#gJHK zg3}e@bITGLC<*Me%YYE%D3b(#EJDY*=s{`a$DZ=^r?U#`ABXBwp%#^>Kt-zmQImSq zrZV-ZOO2{hr|MLxRy9+{Dh)KlGCQw`rxga`=0OPrw&n1IiIKF?&FWExEtspE1fixR z6e5M>Xfu!lB1Ju5DOQ*Q2t3I#Yr5F^9|V;_0S2MZDxQe|uYfWl1A&WwPcB%SbEPXI z(gYB??24QOxCE6u>g$fABHF4f=z)_gC~3N*p<&{XL!yxW(h(CN)T@`UxCtlvlOuS+g zw^+q2o|z+mg{;u@>UGakv5HthyA-pm7Nww_B(CNze<`hE_8^6SxX3UC44y0{L%tZU zigK%X?G?tzRY+?q!n^V2Do!IIP8kH?A1g;;WLt2BD-0jE;mRc0u|*kftE)}? z!)ApWNs;JU*)6}?&Fma zYawOb;LgBUPIl5-w+S%fYB{a2A%A1+=4vOq*Vmr5wXw}@ZD*S!)c&@Qz1?kaZ=2lc zHn+I_obGX}yWH-MH@n~6?s>~Q-0-%yzVTgee$)Hk_dYj&!1pcifUn!%2G2Lb34ZW` zGhE>gU%130PVt6Y9O4)6cf~QzagAqu;1mCN!aZKlaDxVRyOJGfPCRZcw1SF;PT3A~ z<8tauFn9nUjz2b{A}(+Mo*_(Eoovy%{OPJ<$mW#Lq}9P`p=mR2!F5*Q+PL;r~ra#p~`K&=IK z(4rs_9rV0z=8~bx{U$>V{k{4%3PpP*f8L8&$>?n zFxPg_j)}hzysgqHbX6zxd`KU-=?`zVmy}eCb=C`q1xq_N$Nm=`Vl#;P-y?yMO)fi(mWa=f3&NKmPTT z|NR^HaE1@X5oQtbIV#^OT2>>RlvV{;Z)(qmEQafkTFB?1T$t-uLWP!h{z6Uy}(sS#Uxk{1FMU5o@%0n-ltp%l&3 z4$vWDo=9S#D1Rd+il0b|p*V`DSc<85imKR(toVwi7>j@rNLle2gp`5IavBJrTh z5#jWOIYLP^g;twDJ?!x}5)ctn@K&$EAxN+kAyGwRK!G7R1^?KwXE23Baa9u8VIUNd z9L*FTvBWGPXjlUgTuHDB^gv&Ulo4)_Ef7W)g=HS4z){WAAHb6!EGHgzNRG`ye^LR5 z{8bf?0!)7*Ru`flAi{tIlN>H}B3xyEg_~elWHuX;u#2>?4r)>lTvHVOIFS*7Sa?!} zlj4gUp(j)U07x)~J2M74vYY)k@Pe4$%dWR#6atzk~z~ zV@?+#Mt?IOm$V+VD2uKrm$0~sba|I{nU`~^mvy<9e%Y7*fazg;WI9VkW3%%-D#t>? z5@t~3UE>lqX5gnEN7Av%u;FdBshN-Jr?LRpm`5`rjiB16L^yZ#HJtsr8lDx zF4hE^PcxacRv>4#G+Ll$N^p38ce5~wM`iiNO9G=Stl$q)AeBy2Xq`i7Zg)jz^eZis zX#8b>uxSgf&}cpr1qhG{e%L}k0~^u_QcuG*ZN^=EMrDK(Cq?24N+1BsvnFlHoRODu zU&CuiuvkZEE9^sbh!X`j^$)WlKLK!$cX&JeVmNw38@UonxPp6yIhcBX^q>wJp%6Nu z5?Y}YdZ8AYp%}WM8rq>8`k{2RoGF(qg64Z|^Dcyj8~?gOH?}YnWy2^KRfb#eEy1HA zQ!oKQhhjGIHx$7ru=x*Xl>ix2qPFr*$axW)P%She1Ul)O91JAR1*j3Zu|JGO-D+5IzjG z38uy;IF%w=MK%($9P_T2KuaWu91w7VrZyG^*(C9377KD} zHxeKNff{6Jc5tVvwZp8l>cVZr}=VrA_FhL+f-b z3=$xObg%I=Pg|%3U5Hj8a#HnbPgv+A$FT`9!B-8rUmUR{t*`*Fk%pxQP1Piz3*r;7 z7D+y#90y4fbr_ARL{dP(P7=_h@5z5kP)$~LA2F2{&j=uY4Py{a8!+=y5+zYH&c&B_ zS*&W?wru;hf$6q!3%6@Kw{}~%cq_MZd8}q(fB%S-A|%SUKtUoP!F91=N(w<;+7S~A zs9j;Q1sXdOpkAKoHnh=m&I_)I?0 zwL@VOZV*|2fl(68WFFPEPt&2e)MbdJV1qT0WCmCo3ZWxFxD}m6A!3CZ!7C9vg0YXp zS{yrs4jEemiMTo=1rZEwb~H z6GAOpz&{gdAH!1+_$mM;2wt{;fo$=N3la(1VU4kW1O!vC9`52njbgLlyRHc=q~Ecy z!!stWpd}}PlG~(`z6U`w!nwuMO2~l#VrgLiZ_!x;6<4&jg8(EQV?!x)tG6^f!!~@k zHH^bIoVPoi!#u3Rq!O%!vsDjMW65%Ixp+fA25RQUJghV$wSYuJ6eauvE;~kD5>yJT zu|%$aN;q}0bBY!PPIeJUz&pBQ6J}=;t)QiBwtMx^MekySoQ630Igz%|2K2H8yx}G@ z@)FebdbANS8<8FEfE$NHg9;^Wz9VbGaxF)yVu-d`6JjRvQ&`m*0Q$j$*(qtU9H&D1>2*6b@Hnk%nke@I*%OCtp?L3aEV zG1N+Iy$fn(*b4hAb2$+{>mp?%kv*^+#xb$6=2@gJu|XHnJW)p?W2BlUWDi|nbwWdb zKH>~LNv50cQin3}Yt9h~!$WLf_ch&7k8(nEYocWLnrbAZGYU;5{t$H740nngdIue4 zyylpC=L$A+NvyyMun{p^5Sh2o4seVdFNwdQlLYxI>NEHsp1pWZ)z)lLUhg(JtJgrVu%cp@4`N}kefBZ8oF07lxjF^ z$!oBRsy%y<*F4wNT-S7c*LI!Pc)i#EdY#NaELGcl#8M%+06bnn`&~H!THIJ%5aLYe zk`q6YUF^YEOUE38Cm*Bt6Xc2%e_I+F_#?)lk{$?+BJm$g06p|Ama*3slEZMC-uyIqcM3Arz7AABHKYQ3EF|;OuS^6{- zn589v0#>{=5|9+6z?gWWMG%O8G!jO63lenJv+IakdzGHy!R!z>aY%)i7#{2!bUu9H zJ)Ge{{MQ)1;T_K58V=$f9^xOKivNNu6!w=afjfu<2_}0r5nYm$sfA4!r2xouBb7>4 z!8vpfOIEbeubOR=XeAQ|>wBW{A2GoJ`;<3-f)SBjMdqmB&C?pyJtAIz*C<2C7AdjY z65+15pfR0=A_yYZh^rvsI6=fA0Zn)oggYtlJ;H)0B#gI#8Zr_=i)q#yORB*d?!sjh zl>oS<3FhTd#sMMjV}xARl@o zH#QHKXExaF%DbQa2LQHU2^hQGdVqk{$t|TfQ8n&ExT;Dn^`{D^h7QX2U-J zz|(vk_t!r6a$om!fA@Bu_y2gm_Y<1UhABIKuVcdpr|qnx&Xb>s7BNq*S4p<>=Aoj5 zQzB0DGhE>C=#tN2qogVs6a7O9PkPw`)<4REGE;n|ST8?+WKDJ{L0~Cr9p$bqFK3y? zyfjHOBPPZuWeTVEM_pwg1yf+wETaWI_*t{FuCLJ#yPH`egs_=hT6GU6`XnA85kE3sGNJJw#nqozrmg&bc-_ zSjLO7HxieB`1Syi0D!;*Nm`NGB1K^J z4Y??kQp+Bh03$oPg!xkcW=xthZPvtjQ)f<|JAL*9`cr67qC<@qMS4_eQl?9tHii0B zYE-IItyaZ)RcltRTfO2;0Hk2pv0{^uJ!=;2TD5I|%f3wum#tj4anIhJYnLn#00YI= z&5IxkLoO|Y1V}>0Uz;QW0`vs{q@=A)TkQRE*jI`ze%tMYWB<2!plV999E`cBjl7t}}e<>qK#Cn*sC&7Pe>207*R6@|S z!wf@UGkaVSDJ=k`F!H@7jcX~QEv0x2fh1V5XNxH5f~%Fq8cPiU&B9a=$@u`|6He;> zq_0mv1r5~CLI*XJP(>4E6j4VHg>=zJC5^PvN+-3HQcW}E6jM(x1$EO;MGdvnQb#qF zR8>=De-%|%Pla{WS7nX0)mmq@l~%y~`s}Pd>570IB&|!p$5teG47xA9+pM`LC1@wR z?*PaRI5Ar>Z@)s3urmQt|3q7HB_#kzLbMGY0|J3f{9#2P1d5!IG*Za0rwk>846oQQ z{|PWIzP=@ephp4F%@!2<0)pOKP()22a&5Wif4c-W3lL$;@C)fNM!7YpGZcsPON!dS zdkY8!;>B;dJjX5fJ%4-|jo%2Qq*7jbSyI3q$`H_`7IsoxDS=`W>}=k8d)urtl!e=l ztY3XiG1~W9{C1PDpygLU&|of1GhwIJBD7Hi4itd@m;-k~1%R;3Gqc*bXF83qb5B8j ze;pPN02yz|cC0Tivu6si{7Rq|1?^s`u^_%=Q9#@fV83G70PriBP0gwP{<(($p`RWJKq@gJ`f0BX>+mNqt3Mnc~h49g&DF?zOz@5d;BSs7yzsP?_%$j(l2^ z!Wf`75Cd*OYYf!M1h-J8ijZO@#v6$NJLo?tY%c*1R3Qrw7#98|#5@2rj{*#Mf0qNc z@Pk|EpF)5j0S~4If-+Pg{e(iq5mJwB?Vt<`M?yX%&X9vroQVoy5<@L+q(a{l%MRVQ zKP`-5fz&f%45NU&|K~Ljg>8JJ8`&sF5Y91$pcW>amY}oL(Tu2*^SH(U5|K zqaP7T$3qr!k$5CzA|uJjM=H{hf0vYGBQsgaNkUSSl*A+`HyKJ#a`Ka@Y?Fc#Ng~4yE~^zT^yEe)VZ{{1FN@Yxq4E6Jf6GaJ5hm;6 z&`c6g#Ux&li#EjO#H6@Skpv+j0*XTajCe)^22p4Slt2(=7{7q}M4SGTj1V6hKkL*` zdNovtMd~pM66*gX3jb7Li0YR@E}W5jEJWW+!uP@bZ2_GGA*c*H!o+|gK?|CQsU^u5 z)Q&`ei$ta6DRrsTT{6|Ff2M?LDw`TrST2>SNxdpArz+K_ay6?~^=eqFy4A6EwX9zy zD^|s-*0Zivt7c_uT=D2T-mU5{jJhjzb|jX)k_BKvie>#z7W=7h(%cFD65gr3`s4Ub?jfDH7rg+reu(@tYmn_S-xb}!Eri*rq@TK@p(vX1p^ zaCzI7&E~Z%|E;xb-hO)*<$8%Pnkp|_wo6y`0?)m_gKu5i8{hfL_r3O|?|tjLU;Xws zzyHN=fcXpH0Snl_e}KVDpA-wKyQqm?jPYA%PdlLBa*MKi?Tc;Ebeq^JSht%M?1vGv zmtK81v4&MIW%CMv!K@@llKSmml*`xe5=g@7QgNq#o82U+L_e;bFlaF+-W(S;v6E%4 zZlilx6PvcOLh|j6q5E6}G+D_J#_WjsleZ|aqZVT$+=3Mwf92Gyhr8f$u8cSP+y-Y> zFVhVjbJJVc({9SN)FCpP34GuJ_nFUs-gBT0ENDLy8qkAAbfFV{Xhk=AcVB@tQ!3gO zmo(EJriuTwiHc($2qHDpSpkzT@=57WA4Lnc7X3S0il&NdB1Va zs#U#fUH6*Tf4}y1uz@{nVHca&$2NAdk-cnXH=Ehdc6PL(J#D`g+aL|0Csr)d5Ro*q zg3>0+wM#K+YI_CO-2MjmzXcv} zf)Cu_1xNV76`pX0N7~T;P4Iq^AO$FH;xpc*v8ei@e+4Mo;}%W&Shoh2xwVA1WCT$y zpu+rE1cattlp_?(4JGq8n`$5(u6D}Tvhv~iYmRY=EzNWLS?cwM=x-($LqxOmCN1;2N~|9$X=FMQw=e|W_s-tmfObmSkt=*c^}@{*@~ z<}FXCh`+Z^EtZ(h-2yT=dAC4HZ#oy=gsI{672`mb-du3r^RVb9*P_R(>}9+ARUMaA z!!!M3uby!@`5yOBOG1B2UlHU*%jwS-4H+CJfBe-n{>ZEFhb9t04yJd%N$%jnWyNFn z@bslWVc+d;7r%bHpP%t!SHIokZ~gbPpZ@l*|Ksz2fBoa%|NRF*{u99cBf$R~Km$Cy z%TxbA&BMISbHE0Kz{^`WsW^>d8yqdz1e$_0DHw%=z?N0xw@eD5$pOJaI=B#|3E=Z3 ze@X*_DA)uKBn23R!5;&pW+H`bn?R;e!AaUdCz_)k1i~K_LIw&#A{;^=G{PiA!X+$1 zCOpC>gu*A3LM4nsDx5+mw8AXJ!Y!;qF1*4n1j8>BLoEzLGF-tNM8k$t!!%^WHFU$L zz`|kkuFwILG&7XWI-`$c z)IH)ul-RN?ZE*``EIweI3!S6JtUE6e|N; zTSh%gNFsyCsFcd7q{^za%B#f6tklY_?y#!Ns45H(LY? zwoFU61UM2B8nLLdxBRpUk*F!CqWJkIzMQDN1kAn+%)b=O!X(VYj3K{F%)wkt!)#2% z^dH5HOvao{$E-}qTp`QMOw24G&Xml}e9X|iOwk0*(KOA{M9t15P1QV2)=bUMbj{d= z&Do^P*Q`z19DmFrf=j#H&EDkA-}FtSpu^xCh_wXH<0MX|kV~90q?g#IYSK)bx=g>U zChFv;)`U&Hj85(3Oq7aE?~Ebt1kc;#PVR)w@kGz-OwaXH&-P@`_jJ$rgwOeu&-$d# z`?Sye#LxZI&;I1k|MbuPjGy5|&H_zN14aMPVB5u@vwszQEK7~^i;x^k`VhJarIZRi z%k~017{ky9&Aq`Xi97+E1Zg^q8PO9p(G)Gw6(!LYMNt-gQ5RKF8HLdqjnNym(HyPO z9i`D8#ZeypQ6JS&AqCPQ4bmeu(j+a?B_+}(MN%exQYTeXDTUH0jnXUq1&ixYvpgNj z*-{J5P=Dm((wuZacN|kN6;m{Yl$HC)ZmBvo-4%%pQwp80FSQjr<+>yZJNJW&ZUGr+ z!4E4nQST#E=)=+~MN}$P)I(iVK}}Rfg;Ymv)JK)nNTt+EozzUVR86haPQ}zs1=UXd z2Z?z-K*Q5Dom191RaH&ZR`pQ#YOAuE{|ZE6w|}&%!=FGar~0K>l`F8~)wd!yIfb~~_1=xW#*mRYc zoqx#HTwT^PqcZCPFAs%tCQnf=$AC0Lt1Se&KVo8{S@?b)66*`Nj5q1D-<{aK?GTBIe~qh;Ep zZQ7-s*Y?3xkbPNGs#>5(S(4?%i!J{|kblkEmE~HP4co6(s{N8i51XA*#T^HQuckcP zpFyoFor5{St#)IxRK#)aI+ zmE6ds+{v}v%EjEv)!fYG+|Bje&IR4i72VJ!-O)AO(na0VRo&EOUAOrz&GSaJ?SIC- za8t*M3*6Jjn&cG2Z5`enotZSb+~qUjEfq6uMlyxTu{g@14BUx4Q&XYae2hud(OciC z-qYd9xA5IYtG8My$v(q5do{FimF)e+K?yCgxR2I#-_?cR_m$uHrQi9r-}=Sh z`_&6NI7J_H_L*y*e7#gqmX$=`imT=B+_6w1F1 zM(TPldxYRvNzmdv&=C${5;ox+gf)#ojHlIEkZ{@;hT0fb+8Jiy8irvSzF`}dVI9Wd z9nN8={T$9}asPxjCXlHY7=x@$P>Dn4r5T`Hx5N77+XLja~Zsw5gXpkOhk@jelCh3zd>6KRL zmQLxHM(LPlX_%gAnRaQLrs-2A)ny4B+=5U>`P*~Wl;Fyj7b`n=);e~+=W0V>l`H27 z<6~5%J)(PPLaXXU0e|Pj#s6mN?Njh@>Ue&WtLCpyJDwz%AW}fY7_1Be4T=inwP01X ziHpM?1U9&?#h%l^g8Pg*7C16+2bORcia}#GKI6hxx1x~%w`IaMAbH8<>{8U z9v>a#Y~dM$cHj+oR=ULoZpRMp#wKpa7VhIN?%_7>N-BDIj7NM2cYj zqIt^g?bhsUV(se5ZZSf({R*QDG4FXnZ}Xl4 zj0{xA{8{hDpyU^k5Aq0QP}b;9PVs16ag9cCRc>(>KXDkR@fWx88NYEAuW=mbaUIWb z8UJw??{Od>av}HeBA4+aS8^g>awBK*CwKBFPjV=qa(^k0@hji)Ca>}>$MPWO@+_zF zE=O`PZxWF>FN&M)5IbJyT@AsQAq$7|1NW9&xQD~B8lN^~ZUindgT_ixtZ}yIsAh9R zemy>K>M2{W@XB-cp0DJcUaZ~=cy_IP;d57+Aa+x&sjEQK3o&{;M)vpx3Av0-u;vQ( zb8_K8Hh)O(RnHJ>;{WPHXXr(ffKAu{00%Bj01$4{-3DCi--DDSz(B3ZaNF5}`u-jX z{`0qmot`9gM0fTD({oKAgA6PaHG!ZcQ5ReInT>E`Ft>7V|8g%U^DqbZasPI4FZXmu z_jMn0b#Hfe*K%_Y_jjlFcz1Vuhj((%_j9jzeShzFe#iHJpZ9<#cQ#{4hJ1*H)KztmtUy!xQV^HLZRbob~Z1)_2oHzCfrju^gTCZcI3TgKifq|cG?-7 z^@VjHd*_hPERScsiGgtI=#OHzZ)dM?fw&l#A%ms9uU#{sgFx%^G3)NZ!&0!Rt1!;0 zIDhL!>@|b9hgMMkfRP{tm>K$nqd0q@`mN}(HsLiSV2Fi*ZK9Zh=~y)b!4_h(`ZCZ% zKiqRa50WBA;9+P=2q^;XZ*)!?#Op;$CrG{ zpL~Ip>?p7WSGcL6U>iL29-P=V>2n|mN`LzA_99Ynhcaj(`*{hbKM2Xz2nsHtVPZPF?LTgsYK&8kZ;0Jtx@%wo} z{g?ogfA}_os0V+LT8`KPPC%BJIHrXN;v-IiF#!H)w;}S$klLq;?LP{$_(|@vE>$WUhwQ<|deJl4aUA%Vl^4;sVE?~WZ`wspq z_%LC_h7&Vh?6@#w#gQ9Nek}PiWz3c{bKdNj>qmrA!cLL-P)f}$5>y%`wfdlxEu=!5 zE^3R@l-NBPq6SLlX2T?8f^ymdQ)R;_!*yD8n8Y{o)0=euc1~s#(B{z-3V#o}NjvQ6 zJAv0OuTskY@Q>ybpa;5hx54&6b&}fjdi%kwTe%!JNa|=xK9Qk>0ud!4MQo(lLK}V! zW|RO)Jw*tiL=;L`;e;7hSP&2rY%&}fs$CeN0__bq0Zc(yxT1y^_Vgl)7|Q5kj5LxY zNh|GaK~jq=A~9o(NFY$qg?}pYSkf(}D3BroQ$2Olkv|>@V~kE3+;8k;>v;fg&dirHq55h4JD6!o<6Qk^nNxhF{~ zX33)jt=z&%0x%UiX`73($3{K2*t3D81RQ`Lx`m3hB3Ja{S!4f;HvB?&@EV9fV`>e9fN(-&C(Na6Dwb@p? zEwnR!iVLo{;gUP9x#^a>F1qZV`>wj}ip!NGSV6d}EmH*0X`dUiM+T~(8AKxh zQnb^Gn(S>A?5LC;Nq@>M`tGYJga~NDu%`UE3BXjDq6({#2^>nWL|gF6)h4$33&a+p z21s(Pev-s0s_l>wL;!D&ak4A0S~@a?LMA~K!$J{cFU`hsfwQF!BPt^h005zdR~~~U zL4QEEs0Aw+w?k*a3%jE!!fsL+a-mVPIuMb6meiFL7fOk*6n_eosbSOwp$TgMZae)q zpHDmg4Fm$+rKbd`QcL6^h&NFTc)ySX$)UlmDo*LofxDO$z-1`0ui~ug$h3@)$2g^y zg8Q91NnSxIvQ!kJ9=X}O70yaM#tmFyPd}o)yUbyd39u`U18??^Mc(VEzQ5D{yvU!D z;xQR{L&&++qkooJ(5b;HwXYPmxVAPyTGZ3Z=Zdac3o8X0lVTxK3Or|MUe6DE*nP596 z@VG66#U5?2(xf;BF#ha|RDaSSsvs2vq%F)>JW*NN)_;epdIV^OGh0|Yjwd(ZH3^2n zV*gu(Hc^Y3^)N^=N}3ImfR!Si>^4Hw92WaEvm?sxX@t{6EwVQ#i3C76Tp63F%I3ky z4as~bG#g2fAgeU(j!12=2Ni_S!p&htI6Ul+q!U_aHj=VMkkf0 z42cv;LsfHBOI?V zNPm?$$&^x^k}rcuVyH=wz<-p|&|7WQsSv@cKU@O<68b`{UcIDf6l9aJZpExxdCH4O zQ492;RVMnW2v&x%;XDx^EJxlcxZ0+ddIB$cgX z>|!(P)ClDWJP!ILMmF(_QV5fnx9lo2lYig}9i3JtQRoP5Kub*7sw9$HS!g>^oBv5k zzP2mSY?U;LwuPE5j64->#m#<_6rzy9W&n@?D{7GdtWFh@qo_)xhy+A`a(AHJHRyK_ z8s71Gx4hsruXxc*-t)3Iz3pAEd*kaC1~n>Q$xLiPra-o^e6y6zZAE{!@XW6Dw|}w} zy$2PRv=FMb^h?eCh0oYSww}-|K(&={J6J(sTiC+FJA2AjTJc#egvF)Jy=hRpLW--h z3c2_ZDoR-Fi&h+!AR636O#!x`j4)HT4MD;vymG2+?)JxH7=;#kG8$XFu^^kc2ThR) z6WU^yBOX!aL9iOJ`<`_a6)7@i6@M${PYMo;R+?FQdSbdG#qlS@dCly~%w{?N3(lRL zUE_+9P}7{jk#`kJ5sv6 zxgc+<*rVYXfv&xzZxOO23jM_)PQ&O{k6v_CeurrqDcKe7m;mBhZK%`b*?&;G8l$dh z?Pntyde(n-)9SA3C@v_N(h^y8qzR|zo&frJ0ocT9x1fYOs)t8ZOHzp5sdY~ z8X}p>ky;QXd^CNPK_~_>86q-ZUx8n;@Leo@-y2!{#y7D3-S2(-yWjygc)k%%@Pjiv z;SFE7!y^vyiTC^A6_@zMDSr-djBi}y1J5|fJ09|ni@f9{Ke@nDuJDno-2dVxZ#lG$>Btp~3{kk=XexrUPnSt6I=YZY#fYpK?(MbrGxQ)=xlM37j*T|9 zr$Zm=MRAOG05d}ed4G5Mv$aXjWKz1bVHE4ubBv;k1z5-I!wH4+)1XcQUL!#Xw=D^MeZ zaalx_qglmLJ#TIOVJ#u70Y(Y7? z!Yv?$H-A2bGa`o;oC-k#(`<+qQfSq{{Np=nm{j0S(``jI5?NaR{ewSzqbq1kK@21{ z{v!J!L2;Q!Y^(-!RoB(shWikXmk}BekU~2Q1V*{S3GEC|KvPGG50?Qz0Un1t!IxR8 z7h1koTFTd2o~2s8C0oL!TgD|_+68~zTS9>&Lw`NvbNEAqL6XS;7>`APQw$e=JjXro z$9-Mi(Tr6)*uzH6)l#I`f-nRc{zFqm!IQ9`LTSg+`9dlD#4QkCM1&7E0S5`y!bAZ| zrrbpM9AbQc1$dZ|QwG&jFqFzDlrYUv^i2&!woq%emR3~eD-}#FRAy0KSwaX{Dcr|T zP=5qLLRLbI0%j@%b^atpXu(0W3{c*}NhXDKHc4mAV{ks>GyiVEGD;^%NY)MA(p8mb zOc|YL`6EASBT)>DRlcJdNd;A0q%9~(J86PHV5eU;mObQ7_#nkd0o62S<5B!Wa~`Hn z+NX9Y1vY-qK}0}BQR9j=6Djz@J?;ia^nXKWh+D9{h&||tfFT8I!Nz}JCr=W@L;gZg z@j28Gqe@$Try^ur& z=FnW_Yu+a;QH50;&0Ec-n6@RE%4M0><(ZDDnXW0CvZc zSaofIb^^h1A%I3KnM}UsUaHDOXcg748*I1&oyNzdLX!F9C?av#OPDJAX{I~OCwyjv z96^awbZBrwszGeREifB-;wLr!!+$-*C?t&~sW^p$+JoCA=s%3jcm9JdTvkDN*+uoL zgZy6bf#D+I4DAq!Y*dFRR^RN2YXabij>H)fd>e8Fg*_C{#z;a}kimzL#cJFfS&#uI zG}kQ)ky&LeXhI(XfD+`0rnJFy+^ggjj4kj`zk)>>Q0f!T$e$F^2Y>SG9Z5h` z7@elD!Nl&8L`31h3Ypy$E7|M>)Y(b|pb6&T3slG=n*W$k%2w>f@a&28m-)%Y7$`_z z)!}oUPl=I1*L=-$$*iW7!NBfBX*gq@RL)LJEl)^|!TeS!FzwZ#>-os+5}lCLK)@)} z0-D&2z0!)*%xb+7Y_3?$5`R$#al{)|%qq1?(Yj&@0L;)^amoTxP^2l*y()ne*sS?s z4_HXVLxzZ#SPDtZ8*t*w0v*uhHW21I5a({L=V~tKX0GUVZs?Nk=zi|$#tX{WONUU? z9{nDX+!eV&4H*Jgr)I)~d>A)PDnn*^1ko${Wey zb(O)Y2_R^gT`S~H)#VjQ{H)vd#M-qG5>!q*@Q3mL)?v`O!beC*Q=U)ENJR8vg{t6) zzziZ*2$kI~3`6x%h|LBm;ECth3gZL?_OdSEK`hh?UEXHOPK?6!iq8g= zEDQsHzV7gF>{#nSY`u1c1&bTm3G6@86^Fnr8drp*954}^kbe|~aloF6akNpU)N66s z3wqQb6Y;n{Ik^j_JBGOq0OGE)Cp+w+LN+8V=(UF31 z0_eyQ{9p~h*kdb5LL2NV>HLVQC=my~kLkWL>cVpA#xgC>vMr}BF3WN**YYmkvM&EJ zFSkns-BktE9)Ih~3J2kkg{YxL`EDUm#T)(2;BkQfm<>|F74ve`XkMEwSVjGgnx+H@ z1_6=X2r{c!g%4E@wIW1T06;}$9qqD=kinVPDI!O?0;HU?r!ero0+Nowvs897sq6}f zT7=h4j1JoZfA~-h502tU;6OBDr99!%wLvSO$TTA~^M6H@PJFOVDA=Lk9jz+ibyy{E z@rZ^lhdW5^17o4>_>oaIwElwYLO2p-QQ33h6|hLDuu<6LWnbMBS-Nhf)E8( zkPw}kN=qU2Tb02mvGd#w7CK%l84$!V;hqcN=Mq|mqhZZaR|9X^cT?n7v%p zJVwDLbL=TKWi{0);sA<*ctu}LP=T3jE3j$?>3@b*>JCJx6FZeSTw{u4~ zbWiu2R`+wSMPB08Ga8IN(00YVrh0A?ssEKnNG<4nL91gv#iJ;IoJ82gWC{xLL0l787vfzqo3|Q-fVIP1x!rSc1*A-;6*<|( zN`GboTvJLg#dUoJv(dDJ;3}3i#A^x@d;5bYPA^#8CuJk`b8#z#Ol5x*(#^30WTKTZ z`PCsI0D$cNxqq`^syXu^pAS}>~ojQ`kKi^3w^vICqMXFY7ds`27;@Q@Tq!T=62_asTP z{mGfw0>gTm4!7s;7Hom|FSyxC;O>h=^9TaS2|Fm*BbZ{poc5 z1OnR5O%CrWKs7%MK-k&@APE_J0u7c?V2}%d2&8Pa+cL`5s02VRCVxn>MT>$Y7ZD&~ zRgVk+2@oYxB$WVwlzOeC5Nu_yWJVwlBW?(gAV@}X0$Y-)htOUbC2UAKL^-iy#*Z5% z*+h_JWJ8=OYi&x<4k%T7SraM*lm$gU>?YpqDo}e(|>6mf$A-n?_vu$AjYVU zr;H>^f-CvR8lxx$5>TR*b|x6hJV_#w0HaN8X+?sadKxT^@W4n4FG01WM^u&#?FXvdEK>>{WZl2`(c!xniYMKs>>+QKr=Xv54nAY`IxvZ*%P zqNSxssYR0>ZGVxZrI}jUqzressEa=54s#4OBWqy?zyw-}$j*bf%F`A+?c|D2J^#G3 zPeJwcb5KJI-SbdI6Acvq(Lo($6w*Z_t(4MB8^v_eOgY`uQ$RHZwNp_)CA8E>Nln#M zRZ(5l)mCAZlvP<@r8QPdN39jtQ*G^a*I9qnHB@19y?<5HgY4f`;0pKvF(> zECDS{j7}wj3=+seHp=P_fKBXa2QKbTwB18I{u zC7{ndwW|CNH6)~qg8v z2yW->A&w=a6$&7UA+>a@uowwv8+6VzZ`*KtlvtY#SI7uJfi2>T`Z8wQiJIA8Oplvv z;eH`Sc(ws6eeEskv`k6K2TP*PrX8=dZ`q^h&~6|S^G7i8O1VV|uw%bliVj=(Uha8? z2Y(_8#m-m072bD3VwHNvpuEzMA}N>Ao-KN}odWU9Zi91R`$pFaUnGxNwul1txTZZ< zOh5{?fC=#0K?WrM?5!3QY)1+I))&nsp+k166IWi?LKwP`hBLIG3~?Ak9o8_1Jlx?8 zf7nAH`jCi2G@=lZ7(^u&F^NoE;t`+NM1Lqck&08aq7<!hPP)oB)U z4o(w1OXf&wk~hp@O%wh=(qK68i+@rqz;;m>N3zz2rDOrXO~5OKG$p_ae@qMj<&>B@ z*|h~=x=le$Ya8&oz)Rji<} zSHaO-wQbq|^w8kO>Pf*Nz@pDXCMRw|1+k2jU9Cis8kSAa6wyWdw)vWB@ zLV!}T!mX*~G709H=}h)CBoTtAa>LyiX|`)Wc6jMwW0DA!2BR?sfhR+QLF9+r6eaIO z;Bep4%m0Mtdl`ywLM&V`RML=ji!fNL-Kk&ZILo<^*d- z^gf_A?!+!>gNyv0WaH)f-BsiqY;HTBOF3jg+Zb>Y{_UuiQ$ZKq(BBk9VATw zfB;1-a^cBDKsAEt&Pi^;iMmXsK#Fo6v^33=sfoB;NT`xdG~(k4aSc0ZbWSVkl&m~Q zB1-{Cc9=m;E>wBr?{adKJ)vo)>cnb38A zc+#1JfW47Ins--A(8zE;yW`#OIlIVs2I3Y{7w|DeCV#noV?30VC{l#c!g%gR^c}L9 zXlmpO41nxo6p9w;m=QA_9chClv?m!gl}U2wA!KMsgh)zr8oU8r24oVw|?k< z01!_9fI!k3u0uRWW)eVN{DhRcqd*XbWPmKzx_{+PF7A~U!`DJ$VeY|U9-~Uc3kcj~ zGK>j^JZBVc43YLETS9_mwrd0dgvcze#fZ=^qVVRd#0-%@*p{Zh3hwwaqRg^J?#?iKW`YXR<==85(h^M zSbyQW9_Y6QM`%Qft^Ov;`~?z0@srr8f=uVPM4}Z4EUiYN7X)MKaE)kG!y*!=9!3Fh z41!~TN%io^aMtQ{Y@rNX@I~6;Am*<_;7gcRrxvupc3O%n(r$Np=Xai|6qu_OsIe4` zXBe3YacZkh4(jl@0=l3_H9jJJ#3`qmrGJlrNwVgr6@seJG-++5VEvR}q%d(#@&*B- zz`7L0{1PC1E{bhtiZd>QL)yU=hz26~r!+>Q7T-qf00X(rLEkQ8EVlp&JZKCOz$5m?9z?BQxQV|6N8?(dDGh8i7HoUck`-zp75{K@ z`~fWgQgmi5FDc3-*v0Me5hEc}DSspLbcO<+G%|&%=P`-t!~g(zoM#COLN(gqA3k$> z01PtqrklQr5I1O?F0-5>QZr@q6DN~4CG$3Cb1>OvB1l0TbMrFmLpRcnDszheIDeBW z^~sEECkZ%2@EEclNMSJK41NH!9asS`3hEm-L81DN_$Wv4WI{O;z@lK%r+;i={{G=0 zRx>j9u@{-~weSddxL`iGpiWA`6|S)wb0@p-XgO)BOFD%6vL_;X3fGpgSpo|C_GmsO zk^>?rM%)CdpcD1N$*R@~H~~bS;xVn*rQZxFbSwkLOv-h<$4xq@ZtP|Drt)$Ysrs@@ zC}A`xWfUlBlqhG^Mr$-ibAR+kbreT^v`2xIN9_uGB5AUW(>BMmLp+N^d+{=9#IrCo zZZ_c_M9Vdklq>wPm>8^>h7o}@hjQSFtrUYL6ZAX1k6l_Jz_e#5){{nT!T%qDfE0KK z02HS~7$+krB7bBvPR}wgzZ9Yv1}%$Yf3UKBqO>+YNxhELZ^%a_G=J-DcE=SI0GK3T z6KoDe0?el_D-zi&w(_J@v^4oTxXbd1NAka%KZR)tI~n2fo~pcKd;(`rUGFk)*o zVMibXEnZ7Q(1Js#EPnt(k0o5>Tq9&Abi`ja&5@$uSGn~@T8>IEqK>Y+{YOhR`9BbX)7#=>kcBjHq#GYi8W+9U`B_Rhv)#U8EyC5BB_jqH?I z0t(PiD?DRBB0vsB_R~V6(zwMu^sD9A^*|UF@AQ$)=)x?FX@5QJ;#yik36kI^$V}B} zWIh0(y|zGJHURq?T8#mTIv!YPA+vxHfCS_G`tq zYR7hK$@Xl`Hf_BYZP!+9iN$TNRz2aO)))=LK!V}ukM#!RTE?Xgk+$I&@#D}&NsdYg zq{7e|XlQ3bjemp<`Ba4E{$pF9B+;gBAROZ+DkBQ~p=EGk*LlG zYbGur>Oio@JPM=9q=U%ju2!i;Gu-1%KqFLR7BxD}b${K^R>haqBJKL@4q>acem+nI z$M-c5Z&t4)49~(LI;~3X%>J|q2*lteS}em1!VGX0E;jEyl;HD_HT3Y~waP#@gp=oP zb@P(LCKLpA%mVfj;7G*cT+M@AJ`WK=f^0TG3S6Nmk|X!J!*veFUU*_V6xSYh_c0=% z6@G*_^nWpQ-tjv)#C$=`-x~H>M;2Tj&ec?IKsrh>Qsi=}jnf!z+c=Kjn2hDP zZ0UH8(b$gjSdHsAS*Qd!682fh#b%UaZ?WYg%zx%aJm^EpMUk>)M-uKg_=3R<*K8VB zO|lgzBtgdLj6d53Ydk|T{+971YJ3#NARi`nM}lisNZ8)&Tb_gALTE7ZY)$P(a^VhX zx00f2IZHZI-&6#1SEDC#3_;L>n8R1e>cyW%FeIXkX1euCro%waVp})_8xxr^8*Y3q z!hbQO72*`JY>qCO-SEL4Xm23t!I&kU*D1i8EY_DSxEJ!+RnMq)W1VwkkB_=bkm_bOEw{?k64z zgrCo7e;Cn!VpE`supm{mcH}WLKUoVx#-4}ib#~gNOLJ)bFiwe@HOZi6iAD@vp))jt zr5?j&8Ycj_>_>%kN8OsO-#SR+x~<{5uI1XU>AJ7P^L^6jyvk}S;YOa^(``B^FMp|8 zm=cGNEObl`j3pf7zgP-B+XkS!Q#>|dWZK4|D~o<`DEMWycaC1Neb8Sh(l6}vDZ^XNWmS_ zRHhW;KjdsGbxJ(l$hCndw|BEn20VL|im7J07bTQmrYag8ND6{2ys@fhOMi7Rq-(b0 zkG@f;nR>}Fl3NVhVH5)3BwS$|i7=zh<3;g$u8q9Nk^HWeJg=Gj$d}y7n>@**oQ!fU zc^TVC&&#rwbb!L@us?@9&kMAwbfh#WK><1@_7S2weE+Zu!V&>QYt)=2FzB0rptlf! zowRZj$b7y1CBDUrn+S3@X@BpVa6_~&LUUah^lZTv8iR#qJGP$kF{IjRv>OuRlxJA{ zm2j=W-G-VB938zBZxZ8?$m_H3_&t1dGoyHX>ipEF#p%JcLc0fZ$ha z5T;O#g~P`lZn(}C?P4XA&py~7+F(7LFzS$F3W$UMAcQR@g22w~9b?aqUmwVvv?{_3es-#QOWc-Uvk?1GEk z-n~`TrjL_mzPv|Y|l1GFhC?wA`0j7o%B7f+QC!}ME+u>QTso_rd z2dN|oLg{6(b5F=%KSP3m?o1;9V>K}1Cs^TG{$)sFs$UF@CQvOSP(nwZrXBi-c`a6T z4OoEXf|X@`HnOlDvce5@PTD!x9p=Lz|i;LgI; zraeo3uz#1hMt^wjTjGvuBDsAXQGDZ!Zq!%8jKj0RmFAD5*JEGF-j_RI*j}QB*h#o0 z5Pv5k4cQ%H=PyDJjY8FiyhIegfC0kRrUU>I0NG-%j1mA+YSr74BtZ}W1%nt&5D^H* zix~|{X=jU+1ONz8qp2}mg%x71UJkyX7mEP((p$$yfhlvW9)P)x8%LI8}b9tJ?5 z$(B7b0RjMt0&){Y0GfDm32>8VnIr*%XkisH!BCq3B7oc?_Ggqs7 zYBh!I#)3#9kWxyfV8_863rF0!G-5}vtVRO=wbgD5#UxsNnyhC_62*_LSZ=B{cV{aF z8cBX;rGJ2AL69KrL00# z(9VAJZFFA}WXz<*Mnl<0AVxQlVGkDqNW>8UaeqZN6k=pV1i%>f><5GdV38u$dq6-K zRuEVrfW#Q})RNN{Ad%RgVnEr`js!MpgaiUWXfdM%r6pijD-b21p6h*-fVE;+{#9*3*-B(}|e~GjsK_4l#6`KC}2mlaL zY=6?qLlM#UlN1nkwxuLeWP|`>TP#SGEsY6w5fTIdi9mM0`4~Z&n{3IUK`y%D697(? zh^T?}Fa@QWY&K{rsue<6P^+xI>T0a7O2Nt|KEaBstp9C7$`+~R3haMF>KbgZuIdRA zTcotIEGf52@!GQO*dvmslwnp%E6K_QWn9XJM23GWWr&3YDXg>t){LzL#n4Kul3~lZ z0U*)mU9L2{p|P)S0o$^z*izG+nVx71Q>lIRA+v1#l9NHCB&%?HeDS13t7N!i3tH{o zrSP;G6XasTf5qFX1kJL=aVexGahF@&`ZJKL|IoU&tja?bG%Tzp+mfxnws7)m6n8lW zKoozYwCB;kZ2b$8D}g0|giNClThJy6|4RxRu3UxD(#wM74ymq2QjcpNQw^&Wq_}c! zvaYl;Zj(x*Y^1AN*y7H;ujNE5zy!%yOPB6hyi@?G1pw`~ir@7b(~K|fc;t;w4mstI zUoLs&nqzKx=A3WtdFY*w4m#$zEY&1ZKYJiQg&CS^oN9U_9> zqOa?fj1~T9Mc}B{FaXHNHDrL^09oOS0A*`=N(g{~eCC=AnyePeVv8@TB|g9DN^Koe zUcOqjg}E)uSF)m4J;orx2nsQXL}Zl~M&T@AjAnQQQeHARx*Yr|_~ zA{p{8DQpP|qo@Zf2Bt+fuFH?z3ri6h|3Zpdd{23;C?jIp*NOusAPTjhS-pQWW-_EP zWJCr$Rx4kLspRd>chu~iHK%#a-)Xa&*zD#uv&qeGdNZ8l1gAO2c}{Vvvz+K`=Q`7Q zww>5w3fmK64wuHbR&c6Y)T*be_=L};?X!rSxCa#!U^rQgkCVMDRRUY_LR^6nGnqml zW~4=}cd>$njbomy^wg!|5hQ;Kr0~VJ+|df|C2&S*BgHT5;R;#R(r5Z3sY%7?zneiR zYM5z-A8!e)awQRQq`$a)e--V=xG#W&;o_tx+Vf8%86*nD_np53Rt`f79)sD zo@)-MUvUcY_4<> zAoxfE|8h;;2m&A29M*r95`bx^SmdF6@d!zV4agHhe*hi zi25agxGbA7txMm=6iJ4!tdzT{(NF!Ym%3a+`AXIL5Hl;|7kNQS42X_Q9 zN@+t!{t`1EIq4x{34)Kfbc+S5|L8vx{tat(^J3k^c#JNt@oj%*{Nfwi7{)u6agS-N z;~?``$UY`AkcW)qA}9ICN^UZfpPb_*N7>0!hBA$*d}S-+ILlS;GMB%+@R%Yf8m+FMLP0IYRwNyadnfv{Gz zxH*zW0Cgcp|F?g#kQwBI<=VAQqL(9W;DiYca%j~mxetYy2rc8`gd3tcgBn6s{NNAY%MT?M}<35eFNV;PkcpU$U3yH2gO zEqm%Y9d;|gg%VgUb>6||8p8kME;pzJHHtxgYTO}uV$6U3``0*`TihPubER>H>`KR# zOSq8&m`D^_CnY2a-EG4{k{}anbJnK-P(W0$dys@IBnXTZj)NgpAOPo^xL(||-|B;q zGAIBeT(ym|rMnZ(Vs;`3e+Ux=EE$y`7%u}8t6?4MDUDS3+8sW)LtG}^xabdR7^V|f z26UmkL<)b`C5i3K9Rq{|TUQD|_>^=FgxKUhW=eU=Qew6g^jfgvRsX`6nbUmsGD{il zYgap#-L7Rd%RTLLZ+qR}e)qW3eeZVXJKp`S_r44M?}HCK;SX>4#Ge`Piw917LW@dc ztH>gHPE(#x#A~HUCM$lT|Im&wkKp$FiQMXqkPv^G0<}puy5=THG?Xw604-o2k}H16h&rLfM~Ns4;VcX2viMcHHoz)Gb1(>SS&hKfyJ^!Q!p)3 z;2#_qL=pIbDfoH`Sb{3}fgWgrE9ink1O!&Hfh@yL@8nK7$b;$BgFEPhK?sCHsDnlL zgGMNXNJxafV@L%-OvL{KR4Di=Rs%G|B87iIRWA3SJu}FKG3Y8Vs2W~afUXuSO?Opd z*o9%Gf>$^x2-hB9h=x0*EsGL|X-I=nunD#DQ$j>5oM1gdR8e(E7h=eUA6SSel0iH1 z3#Y_^QV54G*n+r{1n$s;g7{IEh>5WhiH}$;E#gWZlwB`&grk^*q_~8oh>EAEib{XT zil*3#srZVk=!&xli>_lftRheJWQRB-BMr5SxWWmya8L=AiN*MZ^;C$)IEcohHol-n zzX*-M;($~WiLnv^R>MOD$ce;qG+dyJvXXrPh$=|oMBs;MsS5?eE0kq`-yc!e`vcNPy;lw(mTbw9R;hqwz!DzG zAST9a2soE@S(X*4Dpy{}=^`)=z&)+Mi9zqeg0oR*EaM!hlkWp;jsdU%C@p!=+S;r8qjKW=f_q zHA+xqh0(JG@HwSuN}gxxpmJKL8)}IclQC0}1YgPuuF-KCs-z9trq{zVm1w6k$fN=~ zsAbxSj%tN20!`*AO)+YzE{drznyE6%sh6s$pW3OL>ZzfcP4RzJ3iISoX=;FEz%FCx zsQNXee7dBLx~PxZs3%B+CMX3}@gZDst8sd#R0OHR`cYcYE8`ig5?CKv5CvNRVzdga z9EzRd=B%s=D*TaG8X+ie1ZUxSYST)3lJcD3%ADgRT0C(J+f{yInV#Z!j{g;&836!b zVh>Q)uI2wqu5o`huIA>B{mPN=N|(u*SHB6c5ay79rebnJXHmwmN7k?p>#z|Eu@ft? z6-%)fYq1%Nu^X$g9m}yF>#-pVvLh?9B}=j=YqBYevMZ~ySymHrf*GnLu-^I;=|yku z7@5Io59xXfXaRuUYJA_iZ$P`0Wr?#8R^imgimh^6 zZdu2Oa_%d#x1WQ(h~kL$RR3%Qdkxs^+~mutD1i@BStxt+VY zht+C||Al`TQ!sg7i=BXyapvc?Sf>=}DqilYw?`Yd`RcF#D7*L}R}oMTzd&4P@@RpJ zt+xBN0jOAfn{lOq5Zc8xzdN{K3!xY>N~9ob`l2?Yw51J*r_ZasrLqN5(7i{pz2TcH z67Z#gD8A>Lz3Gd-F8HOi#5zhdIq?TYg!;?B_Upgp)V~1yzq1p- z=2XC?bHLw}z@oFj)}+6+)4-Jz!MroU`8&Yal)#nPr&5rr>#M8YD+O9Wq#jJd>KjE- zpap?y!X?bY?45px1 zOBR3s%*63~!TI#W`0KTKvUg9L8Ks#%FBCX?(?K zyvAzWz;A5DaJ(Pc|6BfV;}5+{&+f z%COAJ`HF9uXUn07%eSn{xy;ME?906j%)czm!A#7jn5cZOY{Qx6gZkWPn7QaL5t52~LzmE_hp$ ztkb}{(>;yI%KFnj9fd%RqeEiU`+0xVMUB);t<+6Dh)Mm_OdZuu4b@dG)mBZ_SFP1q z&DB%w)mZ)2TOHP24c28X)@DuCXRX$0&DLY>)@c3KYt5h(8q~9jt3v(Lgj&~leb;-v zqL86-u6x3|DE6WE#Lso-|-#b z1m55Jo!|q0;0<2j{{6cvZO)XZm7~ng6^_!(4Vcwk+#R0bc(rdIUd~9@$u<2%y+Wjl zu_wGOY0^F45^k;@-p$7?v>GnrH_p%QD6a8s-0->)I}YCqj^GCV;0k|E5neykxuE8Zt0bd>6d@5>6y;yicTvveHb<^*z2;UIV@5$c-Wn;>YL8$tM2Np4(qQj z>#Dq}Z)mnuZLB70Q>a41t`y8)yj>_CQ;|r4(HE~@2{<7=;9-i9H z&-Q*QE*{PhdF|=$@A-`HsxqAaUhV~NZq%;t1h4J}ukZ~|?+pL&3or2w5AhX0@v{7@ zbb1h@zzUpjRHXu@V5*;t3at4F)fykIbDi<3+NDZbs4CCXDUb3iZ>P0t^D@u!IbZWc zUGlj~^GbyCJO6+4H$U`2f2Twb^E2P`MbGq0@AOI!^*Vp_MvwGSFZ53@^;MtsRL}KI zZ}n5}^;;kISO4`|PxfMu^=Gg4V4wDEU-n(!_Ga()X#e(1RJ}KSJ|TZ7AHVhy$is49 zt1n;oG2hpN-|~bX_i%6cbkF#GZ}fuS_;c_0KM(nnpZI@c5A>Ap^Ow*0i~n!=o-g^K z5BQ27`J<2drr-IcPx_y)`lrA8qW}7+5Bjp-`l-+Qw-5WSFRevMA0;u*jzO?KKD+QZ z{2TeL6b{PYdforU;h2_Zsh98z24Dlr<8l7u`~9wVzRn@7;l*F$-*0?X{^CJg<7>P*0N}mihU_C>`%A|1FvN(4Y0GWkxPuQDUOYi@ z;>nROS57=Y@#oE%PoEtfw{_{+sS8J--TU_M;lYm=U!MGVfzs2TH{YK9d-w6-&zE1H z{(bxT@$c8)pZ|aV0t9eC0ShGXKm!v*a6ttdWbnTvb|Phpe+QK#$idsNn+^fo*2A#E z`+(>U#1Bg(@j(uQ6X>|vJUr1q7fqD$MH*+cu|^zk)GA5Gxx~VI$!=@OyDmmF$w~KBu1%+4A#<2 zZO%C9WQ(LGf49`=uP(u?6Dc4fXyq0r&GZO?t6)0_Q6vA%OM+HxnIeGakRYR-R!EUR z(MCI!)6q{owKG&wM>Q4HQ$9c3f9qSTvLvWWf%+VlTaf@rq7{3% zpfV*S^t@0mdC%2$POsSAf=UE7v8A0hd;)jdd(kcUV1g6>o)+PS6*kymh97QtVu>r} zRX*J|VPzC1gOnh~0sZAf0qfRvg$voF@Hf(SY#~L8VBNIGm49uClt}<2AjKAUc((bs z0ZI8~f8^-cti_#4c?2SgQTo}0$RU3OS-d1zv8NU#Xj|8wR+0c9idyQixn^5zfjNp+ z#AewRnp^g1#w2R7#oguB+Zt%tfY_pzTIQV^>gl9ln{L=PsihsR;RHPD#T#e*amOQv zd~(Svr~Goui;GanEjFF*Z~5l+4~a2cX(a#!e^Msp7ES_Cl$5nwY31z%G^uCpvh5)Q z0Cm~**&bV-whsX^xK<$b*mRZ@c@vRXr5;v9&m8l|ng8}@cGZ8uUT1q+ZkgrZxR1T& zR<;1WLJ3Owg$w38tSt&s{K2^(1O&%^`nqG2qLf>vV4jS%X^MLI8(;wmct8Ut5P=I+ ze_#WF6dHkSVkcYRi$c)!B@1TIVp=I)8H5mktU$^ZtqXt_zK1K>!AwM>s0BOPz!Mm5 zs&`!DNu+d_gaG&@Oj1cn5Y$x$dz3*5lcHDvs33wMrzSc%9#0G@_2++j8#7M@taHNZsDI-vfC@D(O z4+s*l)(t@zF;qfKl@v=QD_6NeSH=>Svux#DHYv+i((;uV3Ck9EvI&wQ@rudnf0Un$ z6v11RBA7l*LL2PSh380tU3r9EJAQ*GG1Q|LFvOJHCV9J8P!J%BT43ugiXo&%4vCsu4$>|A9M)V6Nf)L3^O9y{i@Hbwyr;QN zn!tlu62{QSf3bll6}{*kXI6@FfA-{)^grK(5dqXky*M0zt~ z(SrQd36iCbO|5W;yQZ+Sm&rx~Qh=FU+!VcU+Gc-*a-9PFrQus>siAJ=fMP;@NAhA z*VuG6w&Ce4hT~I5;Q2O(m*UzKWYe87K+m$p*@DZ$TAm7@M}6#Te~$I?ivX~-g1eS2 zE_h&TVEszixB!^zfIGB>R`?>n)@AZ_oBU)aM;Xdfma+jslM`HJjiCfHrfS&KLGrFC zXM>v5BwEp0MSvBo3+0-0Ly9qU*;(r9;1GQ(LT?6OODgMThNVCX zP2?tr36R1VuvHWyf8%NtX;Ykt^u|Xxo$^^3v7^8MAR>z{DG3dh1g-?Q(#?4;CuJy9 zOipWR3BUz5rH&z|HbFsOZUG5(6lu;xt^czLdX1hs;ifn9MoBWnm)Igr&z49T>{EV3 zM2~DmJ6v(LYP+HxA8W8h-z4lyZsHb0UCXjJ`c2(%EC~{jf8dfhW!#M6Dv^tf)*&Bh zi%0??hm@=ZD+Z#jEnxc8%-jhheW^!%C!rMNJ_xV{#%NQM!Osejpx4?Br(6A03aLH7F5E#iOAaWrH zQfQeZ3Olb+eQ$D`Sa+rwmLUkyt{)jT5o@j4bQH;4fB^ozwF&q%(lDID&9$KJ zfqe5O706K8G9i0U2ls2DC(l-%T4X9p5b(N({-N)r%1h)1Yd68I9Vj) zZVH_hL_Pc5Yl_g*B>1$=g?%7@N4y;wJL-7X`b>xlfB2>jJwckUGzz%7y1sEfkZyyP zZFFQuwar#o5gCphZT~kAxDZ4o2*A#+j^o8o?KPb|O$oSKG7{=e3Rl>J1MjkG)pVPp zg21$Iw2?$dH#)4sW%}aA26on}w3Dz0APHJ9zUM~W`PgiHQ|Nwv^rtVzFWaLHO_F)b zbKZZLfBQcH^gjX&zybt71ROw_xDAl`n?qBQtWlBTXut`qz~PVpnPNZ-bf4S!suvrG zRydg;;DUCj1=?v5w^A|W;2xY&KmT?i1JqF~QuqhmK|u|i4Waof-dPu4YqB%__3vBFl^9sQvw z+DHI0xQ8-Gt_%D^8vKV=FtOxeLwn#Flo<${*%|HeC4b|>f53i>M|^ZgfuzTQ3`l}JNP|pBek{m^M979z zNPlF=h;+z_gh+sVNQ|t=jJ!yV>`0FM$bIz4kOawp97&TbNtBdGlGL9hIW^+g$dk;v zm88gEDINa`rmtBYP7naWu_oehFPFrP`D#U9`-hR4jj9nsPN+!PD4y{uMddNFe@0YA zGB^%N@Q0v;jU3xChfG2ttOt9jhd~^zCrrkw%thLWpShb2u(1cA`N2pW!p|GV+1P}C zAcOWfpQ_=;Av89*fSpu?sQ=_?g<9Z*TL2y;gF*q4rdrqpb{RsKl%_o#!Z3`EI;sZ= zY6n%!t7rto^3jl~fuC9^1tMg$f8YTvmjVDI6vw)>g;H367_`XRK(alo!k>wbEr1?Q z2+S*#jU<>#A#?>r%1h+1o=D1@uc?JrhydP^ggI(CqExe!g9 z%n-wyC|HGlkOI)6%(rBk!1RUaY@gZKf@I`Mm5fQ6luy}A$@#R;`ixKfe}u{X6v_Ua z$^5)e{TxsMEzkfxQ2r!P1ock^O;7`M&;!vS1*}5|*O|_#BeGQtrJs_fa_T#2<1^qBK14Db z5GtN4j1a_Rv=)lB7yT*2fAKfGdx%C<97$k>Z(|cukivnmn$jdHwaC4HQMG=t8RQ8k z3W|U){0$=e7nmVDom?g?N{Ddc0)m)=R+xlE+nMrNH!}r*2=Jbr_?K2ll6H`SQEC~% z^cQRTDJ0ND3Su;Nxi;Q`f^i5No58?`>KcHX8G-39Z(_WskOCLte}YWWnO2YlQdrba zLlifG6T*XvturUDLnD#^004T_iE@P@;w?m-9W)Ko9ek)G)kH-SxsSNFL^B9VvYCH_ zKFlLJX5B&G38ExmBmb9Czv5{F)3QD>seS%S2(5)o+)pj`!Y1@*V4YVA}%{EX1e;_?lVcmorN&>ezx-AGI zl8KgpGSaz2Ja^kX!9>y{O&eDr1$J$Riwd6lvDC6cJS~L{DQHX#RGxO20OM*CdT^uJ zQGjrgfajQmpWT=L1W;7XODDpmpLJ;m5-=T|Xos`vP5OCKhzp2W<`!2K8IQVvqM z2Teg84v{8TFtI5(qP)pcTM&p6oLzIZRZK|$_S77GvCPHFJ_Y_uxk+E(kfwiV!J#xw zSHK$((=j6Jng2G84KlceG#n2tpv&U0A7bnfSut4R~Q**>_nhs&QNrRG9X1Y^s%=sVM@^i zxv5>=0@=6GMC%N!S0EhOLV~dAn7j!p-!e!yrk(pLD_|5Gx4j-!gVqR~VYZ=Lxv+*}g>%YHTbLIma0PFijRat~<*G6&+hk4de`HSnWKRxdP$m&}LI$Y{pBa z-m#xeF`+k3fNK8VnT-_xS(>J)%=n8Ge+VwrlkOj!)EWTlH5*y$F9-o{cH*%xKOr38+I>!Lg!~C7N z&YIm^OaG_?n|i>&YTJ}CxEhNKo?!ZjM&%kmW7R9Q7Es(#C-fXpU5Ne?HrNHye=^YJ zf{=pI4FDyGSY#ran0P@P5<5jHo=o#8qPsQ)xS2I=1+gw3^YQFSq`9i?8ujCyE#VzG zE7kyjfN{7M%NixvQG#XNy%SrRHe-wKAi6S807%BGAh;A)t=km4?r$bc2+RbJs@vm(>JcBa1IaO4fpU4pObc@pt{aG zR`Tlq(hxUg@;S%=<&F= zhpkCgpeJ~Jw_&++)#I57NUz^W;8qwSe|Z{d`pj?YNzH>d-bPX( zQWr1JAwEN4*oE{=)W0Dbf4+<8Z*xMjmT4VMu!lA8mlm>e-uXQt5&$9a7I>k%9XzSj zdE0sjyBA&PP3VqMFr|2szLRAZg~Ga zssCPA9A+wokh3V=@lce#KRZ)r;1OOr$8ZoIcX2QG4>$J^KX-LcfA@$X-!a&Nl@Xo8f6gIiJe}isLEGRS?N0*J$t%0;0(5o}9L9W0CaYPDLx8X~MFJ^-38dI6 zr635Wc4PoZX;n)~0ss{w+M;L?qLeKiOWES0z@n&rqbj1@m_SoKRvuq&vein0O8=EW zQbuf*L;)#(rKSKN(~eYvlw(>Jt%)*bsRaQ>ri`j|f1{L+9b4TJ31Evv1t@{Asy7Be zgi-#sq?Ax9E7Ykeql6_ziW8Ypa%Iv)nZ%6!?qIRn~HS5)`Tf=@WJ2vgvwrk_QtvfgG-M)JR z|1CT?fAQhQiz7d-JUR2_&YMGjE* zgaAPxfP{==c*sshE@B5`SY!0aUIa>RVqGitY~kR9S$>ETR$K&tP&-m^NXjjoC}0?f z=7}ig1QzM%-Hlx8$KNemNZ?8<6m8*3Dfw~X&K5{WP){3D{D++s?YM%dVWeF7%Rj%^ ze-lzg5Uf*6l^jB-;9r!2YR@n2tMVs70 z$^rs(8Rd2W5ESV@w%C(ZMOzqpT7pP~m1Pu4YQY_`?I^2ee7a6+OSKM)sQ-yfNkt*l zE&upaX=Wl$$iaVxdDtmJ_Pi-YJ+52_e_*kq9&6wf|3Q&yi|Pe2Z7orBxgWMuFnLo< zqy%u2wl~GQt;ND#ET6`vX3R0hA8*{T$RK+hGRY&GoN~!4qs;QjFRR?L%rLtgGtD#G zoO8`PT#bfKnze?>|# z3bJL$ii8;T=;R=F0jOlAe>B?DMiTjj*kQ^R-l7oOT3$Ip5Tw+@MMa&8eNs%D@>fxGALf=Z7Wc(vH01sD$7y_Hr!U|Mu)~-YXfLF+L zM<8se4_ShPNG(Y9qRnl>Ntpqc)+ltl?z}H+Nr=`)5TgaOZKy1fvWGwFe^HAT;f)ha zk{S>eB@hYO!hu~X7mg-XDR;OFMysI5pb{1=4x%IqTLV*~fKa`}T>K&##R$eRnz4-2StD_hLjN7n*aA&YNswS% z0j9QCiwt_A2of;EiVaCce+r~q7GO)g1O7|qe?(LYPOaQrtr`ff-DNrhp?dC58;Q_+D*Xh^VQg zNS0P|#TJsFr!4geiAop^MWS$^yjY}_WFiP@Oj8wG*b!b&a*L5#e^M2zd~=hlL5M<# zRh3=Y149K zQ=D5M#Y5g`Q#RuCf2KOkDNlFm)1CtLr$P-XQHN^Oq9XOEN=+(LsRKwhB7h9;_)}X5 zVkx(50h?T6Cn?B;63j{HN`*xN|Dt)X6DLiBrVXD+N2+#y+B#;D|ya_vH#R(~Xe=!9CJhmMx-NG-lLYv1T zkDIM1_h?-DdB~uX3mL#r9DG4cFB84e*#96V)i2@YB zO6|NVZHYMwE8M7*E#%E9czg|Qw=ji@u0<$Ce=lPh<{jPQD$;BPgnDr#@_@SqkW>Z;phRI*1k8%gAm%f-*eFUOqT8iH zxTzA3Foh><;R<8;!Wzynhd1ou4pXO$XvEod&X@!kk?3{~T9hmP1xOLK39<J)kzmYpm2Ks$<_W2bmK}}t2;mLlo9J; z1`{Q5e?+j7m4j+kk6VqJJG>&rok`Z|nK%d{*jk_#+R+9RJu6y3fB*zj^1*CbY&-zy zR<$a<*#2Rsc87A%me-|Ycli*INMM#u9Lcx`b($3tfUyK{#IPU)?1U@qP`5O3)>1Dx zQ56fah)rDYd#5qqBeu7``Mqy^|GVD+_cy@@e{OJrBRt>=FL=TkzVMvwDj_P?@k&9h zushL86fKcerEx_Tx1mJIq*9{ph9$C_j3RHhV^`mTAifPbs9O|42q{Y8jtD&Lc#bt$ z^Ca#yCAI41*onE4DDpIo>H0DCsp(D3L=?7($eHv?I{z*uG}2TxOw{MO$zc2gpe~K) ze?kp|q4fmQjty+Zq1SO<{N6-uNyxc{Ce%c(;g%{1$w+i1r`1thnQI`NHA)tNyY%)p ztC*1EE6--fr@iwaF4#v2LlO|4quq{))SY-E5pz4z^fj?X5D9>bRqEnw4-Q~$V^*Ia zwSn-#v`u=>0b)g7QIg#|$NI$LYDWZue`k?xqpj*mCQn#v+P6OUv)_H}gMa(pAAk72?=z+7s=%m8zfIQfe)PXz{i@kHhv092 z`A4OdD;*6#;VBan4b~rp+D1^5R)k3mtyIK?)>p*;pN0H_9E}8wbPrH4%S5C^e_BD5 zOn6!=JyZL^pEUj6jnrP+)x%Sb)I;H)2UZP8$plcKL{_zf3tGe|_yVn63r3xl*tkMq z*;iE1U@3T2`(;%txF8UOPYoKEOd!+p1c3sKMzuIo|6RnZh!SR8Su|nOSXs^}aMVWh z8b*a$7qUi7bb)zsLOMl-D1g*Ef61RzmBQ=XhO6k;I@Llp1;AoRp%JPQMF4>nZU{zX z6Kx=YLRH-B&EPNCV6+)XMy&-En3qZ*K?6a`6;=-5u!2g6iz{Fo6tF@&yj1&%9@+p1 zAl{L#NWmr~UKgHefIU zV=oq?FtQ>s4x=$HBQnb3F+yWAMq@HEBQ-kXEl%SzO5-+KV>V(VH)>-yg5x)eV>yQ7 zIg%qfDx*23V>_;+I#we*UgJC7;#1M%JGx#W0Z2pb32O|Y3$l~AO@)qq58L$GSx5y} zEQvL4L|Vxc0Rq_bpw0f-e?$M3TP?VQD;&gNERQ|%qA#I@*q~toL;~7;V%0QO4^D+D zEzd$O%PpN15~S7KY)L~J$lV2_MI=CdWW-t3#1>?vXUHUJ>;&yGjA}82L&TN^2F6H{ z!A&q0D|l5thQvrjKuL6tQba*LU`hLN6g2k#NgF(59uXzw0RVG3e^*p61>^ir{5ecI z1tC|k#a491S8UT_rKE&lijtJeKA>n&v&G=4zs5YkH$~deZ}KK^;^uA+XK$XSaQbF&Dra#b=W#BlDo#;}m0y{WQJfTn zW+j0HOd6ML0vTw+&Jl=#Y>KC_!_G-3-%yC3xfdyD(%$Hqs9=ch%m=#=*#IF%x%Ffc zutB9%3Ea3Oj$nyl0I1%HjFO03^@x?MZ2wTDpoJh+3PrGxf5ia62O&#gL!Z2WXZJ| zM&oD)02G9HCJK>k=a?F%K>UZn3=Dio!e7M7n>wnI-9$V5!c2VKcg)*#f~xq1s`-&B z`HAYOmg=aQYO12@s=BJG!fLA;oOJg8o07W1CKy*kf8`ae%+wZWsa3%W+n7!8frr+) z&57t|+$e$CScH511K9|V#iXB^G{pe1oIU&t0u%(6Bth>P<^(9HBc<7gT#2Wd$I%7I z@qOrZn(OF@DX?w;4sOjl8@D$MF+Va2*p{3ZtDn+ z&Y9F&h#UpSHjEaK#?1m~4jo?)z5fpq0IdXYe}zNsSeh zFcz+fS5Ih0R@npaZctW;VNERmn?xnxe^JyzQ9z5RAn)@0Lq@fNW!bOf4)9}sg@Yzb zr4$58RgOZs}(=7!ZhK+KIv!SPmzUW^p`>INy?gIWmyRkXrCe}qQZ zG%iW0#DR5X)liD)%*J~TF)9*eMF1;I;yJX_x*A7g%5}A*IVL)Po@X z!WP^zRUDWsA%GO5q`M;TflStf6viz}ok0v#C&n+@tc5Z!+krh8BsNqrBTr==vJrmq z0_t!(^YA;nb3DWIJj-)E)AK#svs>V8bp9KMXhL0(!6wk(^K3zP(giC1e-iwR!6*bY z8PJqKv&&3r626k`&Y9vURII$F2j%SrClpf=RgG!v&0ch8Z+*wEj6s#GjuJSud76xX zoMJ$i6ym@fNjr2##|(@a>Mh96iC9O1b!SjRMVv^%LI*T`nDWO62+lRMKr0{R1$96V zb=7bIMJNFmsPq(Y!noQ0e@iDjbX6}z)=2_DOMwJPK~>-BKYNCH(9sl>G#TJYzFZ9L zaKiLFjdoUScaB0z12jah2Z1oy$v~w;-_t=9HnE`Mh7^ew(6ri&HBFHTfh0B3g|wC` zbbm_7OvyAW33T#p!B*pr^hL}<<1~kqG+LiEgs#LGAWVBu-al^we`?>8ygl<#hxXj! zw%z9KZs+zs>vnJR_HPTfZxeTL7x%-_>WH!FUNCfb3S|Un0$&GpUPPbvN_Ru|t8;hB zt`v6ktyfArR^SzoVdE-~WHxA14`D;}$1?UuBX(h*Oha=_E6Q{Oz!b12p+0yLq_`9SUSi-Jlq^M4e;c-EXt#qq%#oXHiEY7< zXQx0Xz+#Vre)mXB3-nez`A@fyNoU(}54UmKIdJ27o##27^ZA|Y`JeYWnBsG7==X#j zqiQ6W?M_9a6UW!(b8ok6qOM}RP7 zbcniQO?st^e<7$3|N5iLdaTp>tlN66w-zd^@l2`nZdGxs&_3n|r#W`?{Yp zp!2rT7Aw7N$BNrc^8J~5NR7Scdxo6c5}EA0&-qgWf4rw|UBA=&!F!IE|8#;#!j;Gk zO4NAMc#e9B2L^$i!53@E)lJ0@eB~6%p%jJ1c!<4!Equ5Sl4g9ouZO;0e3Ya-0ldbue8p_X%O|}L*}Kqle9pxD#dkc(75v#BJ;8H~y%)X6 zZ~TL}P)`qi%0p>pmlZ~g7p z{_g93@8|ySTOYgYq`DLTghj;D`g*!M6MG7O@jHL>L;v$jfAmxT^jm-RW54w$fA({K zZwS(j{DTHVe@?*G^_ze9qyPD|>tU1$WPMkYs_Pps+Cs3Y0i3Tk?)M!$qONBN)>QpIIrcbFxtvc0eR;*j4 zcD?FVD_E{y$%ZXE)@)j|YsI!b>sD<`e}in_ew_=Kr9mJC6Y^Df_oYE8w(Qv=l>l(u zx{4D>Y1L|i!NhbKQ${$sE#}3R@pj(qxiV3eKJ8jHY}T)7$F4nF zr)0zob?e=&`|oew!h-`B?z{JJy7lGPnPd0f9XxmOf7{7V z51$-8_4NPQtG8$W{`+|I^0m(|pT2&6_wn!VKW`sE{owmgK>h;sk3IzZ8!$lz7c{Uy z0u?;aK?)(9Fv12etT4j|CFBsp5J4Pq!xBF@wpy{f2>hGADIMkJtK?6uSqJWd{Rm$t<>_!C6&z5%Pzqj zlSnSF{Ibk2&lEFFGO=tE&Nt;mb51tvToX?-^;C0DH|hM-PC)bIlTbecJ#_b0 zTdft>|66sXbyr(?{ngiCa|L!-VLQXL*J4wZv^J7}1r^#(ff{wxVp&6WRFbqcme^#y z?G{{a#kKa^amyt)Ty)J%_uO{beRo}Mx16`idg;BFx&-Xq7vFyQ{ny`s0UlUjQtTDC61Wle~B%w7yzq4RKMekHO`phjXmxd2TnqQoPE~$XZ%{R=a+^TChV7X{t+o?pq*|S>ZhfSn(C>o zt{UsBwa%LBt-bCV?61WR`)7AU!!}-Z)dmZQHW)rS;ka#)2;Q~pHrwvB@qga??!NW* zJMX~x{+n?B!2uuK@WK(-745_wU)&)8NP(O3hNY+$a>WndT=UK!=lpZeGY_5g(M<>a zbkR`4iOdep`7P`-6XBy87n5?|&Zr@5K+F{PE2%AN}*yPv3ae-LI6{pc-e$*@^0>yQlf* zFG|x!hgMi_aw{FLD8-Rwt?qrP8(;zvxIhOg5P}ViAOk7*KnY$@arnDdS}<3YWsyWx zFzH}YFt?75|%D?GnaBxakgwQB2|#rC3EOUeStK#NrmU*hMaa4|-nA;`NT^L??1jj9*OS z8r9fFHonn~afIU><><8`w#0wSn;ZR2NYV4n zy>t|$+RTz<$Y{oKqOqY1m1#_A`cRun6sG~rX+dXtQ=R^lr#0OvLVG$?hZ?o0KqabB zm3q{sCN-!&HGdQ@sihhTY+?%{`6B{=aMdkfHLG+3Koq`OR*?iC07;MnR@qusC5)sm zWNoVy6d;(gZlSDfl|cW@3RfT?K&}#S76Ou>g-sk}t$9`KS(DN~#QZOSloQD=b4gXp zRu;3k)GTE;d)dxrmb0PNTOQ$}5Gj0Bm|{)BVj%^9u79=_tIs(tY*DK%u6mWN8z6{U zY5$8?tWK68vlZ+}xEcjbaDlN3k;E1dGF!#^Hns7CVMkDiJ>0Rds-X3(Xtmqj&wiJ? z;N9+cy&GQhO7Vy@lL9IJVUJd%*C0q3Lo3?*5Bs)a3%c1vE$m_6_S%AJj;U{d?_u9E z0N|qg?SC(U0Stg7*uo$FC2)NYBN+kzqZI*&&M4~fk5beXF(2^LwKisQ;Qu6mU1&ld+R$;-ENM3b04da>9=3ptiU~W0NRvVWo!SH$ zfIAApa5~dlAVp*xoaw7BYzj(9^+hLVk1|Lh(`3N3U?dQQwMsgxrY1;1HbKUANI=7X z)qez}k<85~u%aHo?o_QU-3m`dS0URfvk3pEN&&jCyyO*cwz<7+Z<|-!;O@4#za8$g zk;>8Ss6{(mkz539;?uA`Eh)CJ7a)+KF!YA?CfYCszQ}tE?4H5|=*mqXj3M6=P{LYY zeUMCdD&Vfz#KYgJmyZyp-NXfe7U)5kEq|bOtn5Ityw|$!T*G?W29dT*Qn|_%UjMG# z<2JX?q*|Xhn}|(9W5lR=(>?Du2qZ}V1INA_Hcy)xOvgR z*n$fnM-0_}6V0gGLJ1cI4k?C?^bOOaAGRoiilxAgdN3`dns9HEsJEqXN_x+d#(#67 z5#4vb|K0C_2mIg#U--dCv}BU72P=k_;@w1H-Hw+G&>da7E!34SVRW=DToDUIeC9ttjXXC?1lUoFO@APKAPG_P!dhtWO+~%kXmz1R;N6M^pfv#hO8)x8 z^upDNC}b12aK-rpP(L6b(FC`W)?OZsZA=XMSQFHi$JFatl313_>DS%uL?iLlt99=1Re3?s=7Fn`h>=C6E1 zAy%*+JkL7lMhc`Z>@+M0-ic(QA;513TPq2WO$#Qc!nlg#^o(huK;(dHFNUoi zBH%6tE6dm+B>XBS2;%M#sbRX$14EDuM{o?yPz29V4bzYf%g`mz%Vr?26|^7Y*O| za1m?gu&iYBTww|Vp!5PR^!9}hCqU~Y<=-fKnDeYBKP756Rsdg;x6#e9$sMy#IFb8j|I()3kc!> z>0t~WE!jd2$^OFsehqV|=_+Zjlvu9ST+Vi2u8(?>FL}~0b$_xi0rM|?k}v~vRhCMs z@I#nx%F!lg7?Xeu+<_FDE((Umy^yju*6Y7G;p5bU5eYFf*{d)(v+0WCGX)^{B%t^L z00f|shBFNkP0~wQ#ljU zsgmkjYJE2oMsdGD@lRLlDJHvB2#dA5yQ%52ypZ|c&9f~pmHo+ch zFT_ZJ%?zXHgu~2!p%qBM_ymLVl(F;(E67>_GpU0Qhfx9C639Z&AV$F#;8XTw(>)IZ z8}SP`wm=hSF$zdw6JAU|Sn3>I%J{MqJjatn%dI)J0`9NA*R!IteV000AED*Z6bNv~RTluqoRQ{)VP#UacTmPtxdbup-h) z?O_XpjrDH99Tc($kpKWB!4}v7^Wd#W{ZAAUU<(#TD3eWJ3=jb~@lp~XQhY=OcmsMGtHCdmvyno;@W#p?JA~D2jXBab63K$PR^UFR( z4deW53tXWQ$5Adt?Of=I+k% z6ycFzp4VDDS+QX`YYF_eZvg(De>u{WYA(}ohm zynj_UK6Vt>tik{Q32H$>rLiA40WnyiVvleR0nr;zM$bOk(F=<*KnP+a1r-# z<04VXg0RS77v(Jp4dXy}156D<652rOER{~ZNv%Y#NO^+@Rh3dKa4(SH==APX0xnhW z)WKRIavu&8H*QwBNlfQ33VFl-*r8Y_;&QRgE{9c>9L87$S8-*uy;}` z%+f{yty;EPH8D4-X*XHnUy0-Q24*;kZ|{nhfREOI3mAb9Sb-C`YaZ5PI0avg276yn z5xt2IP4+l;&tW{aKbKKGQWvHQ6n|txtR2{_XzZaLD3#0-K=#OBHzsEk7!QN9>_RU@ z_;^;Fe3pS1n1^@RhkF=^fB1Zim#|Ed6v}`U+JF>bH}pa(BQ|f+p7J1EYyB(`A|{*N>6+c>(#4nOBetnUHm^aY-lvRKQ+nB^pQz{3~v${v@$;Z zL7V56n@dl`lCB^dPrmAbei4KE(ltrc%pZF56I=FRE)QQy@f_oZLxb3#h4`NX8lVSS zpcUAGE9NRK7QvR56Yc@i)ao)$&(m6=XPm&j%W!V&^@{Xs#|SKVfD6xd9mtw(PEh?9E>+NTTpr-K@(hgztqBxyS% zzJ?Tz6QlK{U=;W>yJh*h=V!A$ZG$vkBcqs>Pk@=W~*{~BEu@}3T z6}z!7Yf-CDB(NpwoK4Hf;HnwMTR9;=@#?GoN2?n{1+m%{ICJ7Oj1*jf6u5w`(s=wn zqIDDE;xyqRN)k&8YgVrr-oAJwTx)mts98N0xtH6y9s9W% zJGvbwMj6xTo+F>`4PN2vT5UlKw9L+krX5HjH~-gE%$C&Ywk&Dn>>e!Eya&uMbjLKg zc>#x#%oNcWqhNl6Ey@Z5dgP9tv4bfyc0vO9LydDsS+qwLJi%*}M;m;>9h||dTf!?m z!Y>@cGn~RPe1F3=97jD|M>|}^?X#d%wtowv4h7=T&Raj%RcT7B6-uGvUU$JbufAlq zBwL|q>Z@9f5A}v|H#Wf^__w3#F!XlwrN^p-ZD9idG19PmIy`MXCB#F`(aLXe#6$eU zv;4uie8RUJ%(Yz1yFARve9XW6%q<+m&795Ayv@}-oqr~2go<}7q@b-Af@h>40`lC{ zP|gBLLf-m(H}1603tg83UEX$MTO2}Kw#C-}OoFg7ZO?_Nexe}Le77-xn^z$v(v!=8 z@))`02Dq9Vx=UTUO`X(FUDZ>a)w?2*aiV^ZAOc`*x-1Jt+i`rrW*2)J~-QS(4;~n1TUEb@xcMkey{x^2?{bTqYV|Mn+d)D3up56E>#BK&uUd}Qz{;K9e*N9M`e#W?(*bEdF5r^<)t3xqu%PL-m#;%?rdi3x1{S;=If(oR0rE~5+32t-t5yJ?blxI zm4D87!k&EKK4AJbfZS_)*dFiOKJWKl@B3bPQlqLQiH7(mek9_6xG+*q>EjoIlos87 z;0o3+Vp|$NB^;|J@^nuLyN|EF>apJQJ0J9^Ui3M?E>A0be3)4g3IAn?;` zBK1+FbEgX@Y`-NGKlLrb3n{R2emV3>Uw`;NfAooe_=7*((w*>^6%$C=LNG`G0DloF zg1i_KqsfveNvh1a@?^`DEMvNq33H~+njtqz*||j$&z(|mz8p%E=uM+Siz+?3G%3@k zPNO27+LS6(t5U0Ky}C6k*RNi~VjbI+ELgK*%c?!QHZ9w?ZsVez+mqdFZy=e?I^J{(l81-+=KI z2;YJC^+zCr`2o0Kf({x8p@8v8!5A4q?IxjkoHa-xg9>gKB8U-^xEo|5mN?;q9O5N$#RDPMAwkU(TY6n{it43&x3nrODk7hGk=Ii{OjviWA6blQpMoqG0}=bwB6 zy62#N25Kmwf)Z+HX%4(^tzUpeMw8CoZthdUlYpuBUx@)ey{_1P6!~$E=iMb^!*%Bsn{bB>XT0#n8h6a`#~z0a^2j2WO!CPlr;PH- zD!0t?%Pz+Z^UN~mnie=RrHbh}-5+;`7y_uhE( zt#{vm|BZCNKrqHd018-0N-K?7{I3MA)lBo`l2=ao<(6lT`R1B;&iUt_hYtGaqL)tk z=>-cUNn;8GvBfBimlAEj7^7~x>A2snd+xmN?tAaR|1Ny+#1C(L@hT6x$;1i-pwSp* zi%r>?e&=mHf1FLBUH$ffU*A3Vf_oo6_~MsO{`udhKfe0r+X}>DTqqGh5E6I^<|Y!H zNqqwQ)1@~1=v$uv+c&@f67YZqOdtRkSU?6M(18kUpageCIA)O{00J1v>ljB0DM%n- zjziw@MkvDn5|Z$QCQKm;SE#}kvham2jG^rW@-R*`e~&U!_|mh))UVpf@P{xAA`pit z#3B;$h(=5z5|^k%R-WN6{Y|$W#F#e$$-dG$uI5xlVPWf0Lc>bmuwenNDrilbrN~XFlb5&vx3A zmC9q@{>X%YBxnj+cT6Y}+M+#2Hu9m2jOZdK>QISRl%f#Ds6{h+(T@5lQ_hs1^MEi% zhi!@!6G&4+7RnVbZuF%Zjj2arI#Zc)^rke$sZDdr)1BI+f?%3tZnO}wR%}6w@d{xR zf0@YCr8>2#P<^UYqe|7OTD7W3q+u&;C{)8F;Y*AoLng~L)~lKot7lCsTGy)9wzBoD zZjGzxZWhHUhNLr^=-CvK@R2Uo^shPvY)=L2Q^FcHu!tROVhg+2!zP4!V2j`UG6Dnu z7_@Bv7@`bWT*_F*dbYEmb!=!GOWMzhfA+Mc9qmR5XE*^U09jG^5;euTF38FCws6fY zZg;EO-tzXhz6~yLwbjgLI`2dqp=M){YOb;h_qoG`u5_bI-RWAly4byLwgPI`Gfjkg z1=Rv(D+Ut%q1LpiJ+ErltKRme_r2(iFMHuz*!P8WvLq=}CBwwmh7yId?v-zVf9ES; z`VRQO22OB+7fj%o0`-C!ycQ6W%NWO`OR3t;u7)$r;SGDZ!yx{!h?}ZauDl|RiTtr5+@Z?|C9n?1vePM3`X*ijjZG*Gx^C$ zhW~N`BP+ilV`L;)1XIQQ8|5x9f7#1U2D6mI{N*x}`OIS;3oU`L;tQS_JR^Z#g-1D? zP44rb{M2VZ}|TGFVdbgC~+>Q$S%)U8&vtXZ9DTAzB>uXc5;e{BuxTIU+q zyZ-epxm2d*br&TpEl5fCI|YgqnX+Pzwwco`Z8TS#+S9(awzGX?`;tMV5J~n$5P*a% zZ(*QGc{aDLy>4u?``zx2_qO3ZZ+X9oa8j7U;l8z#W^KZn;=*u_2Y%y%6a3%?PdLK= z7p_}8PT|eky6jd%?PBICfBfPM&p5_6uJMj@{NoDi3lac87%M8oc`qSp60|Uio@rJq z>ZW(Q>rL~S+g#>2w|CBKzVn+Pp`>bRokF~uk%fW;$Ym1~moXA^ojaZ9PrrH8drtMK zTOI0Qq}NV3wS_GJyjx$57;VvITacUG<7hv-+SAVVwYxpyThPQ9f4bh{>D15PQaG-- zUi|jHxBc&a4?N%pFL)8g<%;$Ta!2O$5P{@1=^e6i)v+#h%1^%Xn7_Q{ubz1WlT->S zJ}IXXFj9q%si3cOITg1^m8M&s^P7LY>^ndE*wg;@wyz1iq)7@d`F_sQn(#C$Z2xg{ z3B2%xPd@UOul(jSfB*TP`|65+52j`#OZpe$JBpppedv3?`{4h+_`~lRtpH6EVzQr@ zk|2ealOlgx=u~-2df83w@#QVeJ?!^id;a6z|N7T|0{DM07k>kxDMv*R`xO)$v3MY{ zQ2RwONkDP|NPretffxvY8aRL%$bn7=AyU8!6q7@oum#bye_~x%VVehGE$MhHUtT za=0Z(!BWyBe|r4|DhYrD2lXWr;1Tqbc=)q;M29^WxPgo4fsWXSkl2Bd$cP9Qi0>yE zP$YWySAPiA8U7=Qph$_L_=urMilaD*rpPMgB7CgC2`$KeK(P_MM+zkf7`*2!QUG!F za(v`Be!kd?!1#;87>o-Sb}R-I>=y-5*ogzd7$WF`fAx|C#W#%9NQ~8Zjnx z5e12NKRt6&Qt&egfOHhcCa#By>R5{Fc#7=!j_w$bsmLdRXo#X`eFI@3U$RMO0yJ!q zj_^2-0ojfMDUSqskONtD=|Y7(#! zJ3x>Gf6U^FJd-mK=_R$e5RLJ5`f`yW36e4ylQcP#Hd&L}5*5iN5MWY*6!Vku7k~R$ zUIURo(^HU0iI4`FkW1N=Ny(H?i53R&C0qcMR;ejL5df35i9?q*P^pwr8J1ufX$Lw6EPB^XGfBUj%F#Amr0gqnVFoqnV!i=2#`l%awaH85$?B-=JlDH ziJ7X|nXLJms~MXJ1|h>&VM1jYx9CjMcuX)zhjHkecj$(5XotXAoWyzmoWq%%W%!%O zf4Q8&S%=CgoX%;S(FvW$lt+y5aDoXE3GgLWXir=?oy{qo8`k@XQf1(YskhDk!$M-sY-8kq*4l{ zor#q%r$NxtSB~POQ~IS=TBTx2rDO`G`$wW$R(zmVcLM{1K>DUa8vmzoI;V0Pe?&l_ zjGIzB@3UehxSfMpr*s;qfI6sxTBsxA6$lUo(33@@l4bTcCky4JVOpk>YNnJrrk0wi zoJm8P$axC5yoaE=Rt=)=I2vAAsCwgVG8Tz;ha}uuZ3a;=vuHrhc@493VVo>b*fbpkX7KuZopr^3} z7`=&~?wOwjd$93IunEha_Gz#SOP>zgum~%$4||^stFRQCu?&*`Lmg5{f3x*(nF6x> z6rL6fu@j52D0{IgYqAv^u_?>48oRP3JF_k;vo#B|FH3}JVP27{5DY*9+^ZlYYrf%|z2~dG?^J+s+q}pdzyv(N2AsO3ySxRgzy~bA4E(@@)&d=P=`c%HZO{Ro(1Lr&0(&-+?8%%A%7|Rii!9L=e;xnP6n)Vft#S#8u>t+YYi)m$CcL=Dzstyt4c)o5MS zX}#7oe~Hg*{nlz7*Kn;lTV2+6UDtR`)_JWLDy7$cE!KN|*MIF(NiElgJ=cew*v#kF zirv_V{n(CeFjqa;fnC{@ec6DG*@B(foW0qeZP^Rq(1?M@Dc#aEZQ3)P+BRL&tc}{M z?b@yV(y?9IvVF)U{o1vi+qAvgF3sDgt=qjFf84+A|J$lP+rqulI1SpLt=Y=m+{`U+ zU)|ZxE#1#e-PY|8g$>!-9ogI6-P8ru-W}fDJ>KFS!W9+W%YEI`o!;!d-hzeR?p@vO zE#L9Y)n`rK_+8%lz29)f*!un7{2kzIUElMq-t>Lo25!WCo!|xD-~|5Q2#(z=)U^U$ zf8hXr;TCQ?{+;0%-r*dM&2+=0tq0)@PT~$8;U-=xB3|Ms-r_9&;wO$Kwbm^XLfXfT z+r*9IH*Va+z2mU`>2%KG zR-}(S1CFB35tFszOn2v-p6Z>x>d~w*kl~zo-%@srp6iaj>jADBf70|U zvL5WSK6Wmq>&DLO$PUKtlc-cV1wjdvT}Lqz!>*+Vw5ZPNoPO=vuG#I!brcir&mMQ# z8nkjg=GlJks=n>%Uazj6NAfcm6i66kb2z$qEfk*Y_W=N}{_!3!*aCkTYYq_= zSq5y0xg7$C_f?u9-|;Yy?l1qqD?(?bj4=VL_dhfwMK>g>Ct>a&AM=4P^MWs%Wb*bTC-;bNa&CX} zDL+{|SrAk2r@T`x;Ne~5{ zujf*r=d3}6q(JqfBnhSPe@x*5-mCxntsnpV&RQ0hg|5cmuSxJ@f>?SNL6QRHhyQi? znm_Z?Fa4GZ5|nPOF4yH2g7R8_4{SI8Jj zp!2O!qo_|Ug)9905Bva;0DwS(0}U2Lcu-+Nh6^1wg!oWmM2ZtFe^$hJQDa7q8$EUe z`B7v@k|Rx)M0rwW%9If{ks?I~=1iC}Yi=P3Bta5VoTda&@DNCnNv$|30QxZr%au); zK6M%u>Qt#!sb00373)^5UAcbs8W!wWv1Q4gHJcXgTD5K2zI7WH?p(P}cP7xeG%wz} zI{j{mDMj#L!h*HPf0Qr~2m(nkCj>|k|AenzzMD#ILK=DU=DnT!ex@9nGw9KxJ)0(d zy7X$*ty8~79h)`m*|J^RrhU8iZr#0e|Hd7hH}K)YeH$l!y!dkE&67Vz9-TS#>C&BB zr+&TqcJ1A>f5#r4JNWV9y_+Y0zWjRi?bE+UAD=z^`SRTxe@CE+%&0Ae2}_UyFfydr z2>>9ZAf=3-B$EsYzWySI3nc=OOu41(qwl`@Hq@}g3_rv$L=Hh5QA7_-B=N)&S5&dZ z6kkL!Miyb5QAQVSr18cYchs@R9Dl?yNX|}r2`7aRFem^3Bp8Ji0u=*{6eWU!2nd+C zND8C`=;E@je=fiDGE6YX6the+&ouutO*Ge3vrRVNbTdvk=ajQfItvmI06nXOX$!@6 zf@ul>s9Z=VNlM{@P>z88Ni6lsM86=J`e~`7rkXFKt#-QVow0_R>#4ogTIjF0CR^;VuPz&Hv%zLN?X25g+wHF99y@Nk z)yDg6y3f}8?!NP$+wZ^uubXhb2md>9#1}7IfAPi{Km2jYBc~j0!!38b^1U&Kob$;& z*Ie+=Hz!?mg#_Gkq6B2}bFx6QEHsHtjLK~DR68&I^4v2Az4zU5H{JK(e;<8#<8?ot zcj7}wzIo-72j2PUpO?P)>Y1lL`{}p0Ui^Yaj$4Xt4=Kkb)A# z-~}~kK@L{XgBk1~2tT;N4>|}~1|-8K@~1$Z1O_kfN{|#Nq$gViP*3=Z(1aL79AS}A zgg^9Q5JM=$5f0IaMMNSHkEp~X5|N2de{^CLm)N0`XySeqVG=1ucA%F`VFF2z0vN>@ zMky$vPp})6s6w?yHMWtBZ**fE;W$S*){%~Pv|}FexJRZ4=ynW20_zUK6H-WsjAI;Q zky?h4Kl0I$dW7U7B^gOdPSTQ<#N;J4nMqDQaym!=%ot>|DpF`+V@c4$D9R9%f0@X{ zm82-e{gj{?3r*3MO}r%(|8prtT=r6zzU1XEg&9m@?$VgRJmxTWQm%8AiZ3Tz%DT3g z04cokl{IPQCM3`+9wxJx$ONY{!}(2dmXn<4^kzE8sZMjU6P@df&q+G=F^PnX4I?wi z7G4sQeWs8U>*53%8krH8b<&`le?;g(6`D|nF4Un7h3G>i8c|SfVO^<{S`q|_kdpb+ zT>a$4{Srfxfl_p$CpBqGQMyuGC*QU)dgWH&fe;w+khgF5D zRnxjwsj?NVZH4Pw<*NVIxYBj5Zq-W_uV@hoyab zQK+K`68a>Uk}}q^j>YU{HJe$^Zq~D%1+Aw@wUbaLVNWq>(ssALf9)-9qlJJiG!tv6U01t!284SN#6f@ht#G9qUFrt6y1>Qm zbgjGH>|R&9RwdjP*q}n7+4HYe6Uy2QQc(Fp4}SA=Z+poH-}%Nje)pxX`S5Gs{I-w3 z|E+I;_1oX~0(ig%_OE~wd|(9^xWV#eaD*2;VFydN!W4!uhAn4Ne=-D?A*;LJNZdD2 z0%!QbCpK}4_xs=#|65qaC2p~XVLW35qd3MkrZJ0i?BW~0xW+!#@sE2vac$VJBS zk(C_egVYpX_d_iM^`z0AE;oNVfh}3bI-Q<=*UR1oGk3o%<}i1c%x5-pn#ZhW&tNr$ zSAI|dUC3N1=T9=wf4xgtyUAuWui4LK{&SxJJ!nA}n$L$0w2kz-iT((tKaU=aKJVA) z096{&TY!vJJRDk1f7a8T2KA>!9cogK+SDEqP^2v_1uay%CaQL{rDtgspO)Izr^fZI zb)9Qo@7mYBUehN9#7b)tJ0|NgcCq^;+wCx$(1?C^qNN>ee`*_g+SaZ%w6Trm{1n^U z$p*JIadOILZ(IM{=e{<&vz=~suiM?`Za2IiCQml3>!YN&_r2|%LXY-)pRJ3E1h)2T zf&;tY2RC@a5x#JRH$0!CsL4z8?QeWjoS!`*13g(%S%!Dq;UD*S$U#1Gk(c~UP96i4 z?;~&vz+=D2* z9ur!~oFufgH;=tUUNVo73}ode-^fgEKJ=d-edodRm$j`p_xj%mK zmmmDwe>eZ~H+)fwfEeHqk-~WY9y}5|pZw?lzWViV|Mc@8|M%a&@cY016TkpOzX3!* z1I$1FQ$PS@zy~Zq2t2?BJRc{!sn^gco*E6EI1HM@GxX@P+~d976TuSfJrX>@5mZ4F z)Hc}4st`;7B(N2(8juybGA)P+(}F?i$h;Lye?b=XK@=3iAS6N{G{PUu7QNCqhsdJy zN{HfninaesyeU*XDzw5X#KJ4oLOX)OhX^FSn+U$sCaDO#Ei}U{M8h*w!!%^WH5>_H zBLz`f4ag%m&(IHT+QH@M!6O_(J}g2c>_b5O!$2g&K`fST^T5|oy&80;y(kseBLmoL zf5g}lL_f5|O2ouMyhKgRL_^#}dRxMTFfX*rL8!1iH*CXHM8#E9#a3j+SHu(cn+OEh zs4!%OO^^aIbj4kS#a`saU-ZRb^fU`}ybat5{;@#@(t>)cxlY_fPjtp;M}Ij991KN;Fgqtaol=wv zDFnuS{}jf4+Pl7+zm+IMfMiI4bjXH;$cL23vtT@n!k@=GMlGtt zyBLT&JVo*-M|Bj*b{t7|G|6`)$&@@vBtk^Vo5W~h7uNd~1R0RiQc09t$(t-mf1KpW zo%Bhatd`#+FQpMaCzKh2teo|V83|m#2z1H`WJ(8o%BF^vvee77?8~|Q zOT6^Uz`V=BL`=in%fe(#!Xd_se}uEepfU|)AiqF>Y8uKsbPmnzNubQh&+JUl{7KRD zOrHOIL0q{qIjccBOU)+}M;JUIXj+YsEX~mj&D*ri+^o&rBu#fa2?W`<;`*o*8Ux|1 zICBwBenU>ySxYU6$mg8M=!DMcl+Nk|i&RoheoIc{Pok!8bI))yHvAz++~iIA>`nW`Pu;vv{oFZQ2}k&RPxr)(jpV`ntWW+#&;wP_ z{ajE3{ickToJB?F3ozpz!(>?W5a1u=VIzGh&%*H%SLM6;VHPppaf7C`j)I}{+M`hGU zmDEV3)JvVzOtn-^O;p~IOHAF=Q0-JsO-x7~RZ%U~Nu#oDc%TCMHctOZ-H zt=h2tTCpYCvt8P6Gp)75e^uPVW!%MeT$NJUy_MX)o!q~*T))NK9l=}5<=o4y+|2D<(3M7d{~fr; zHQmNV-P2WFD)iUXb=}p4-Pe`fG4b5c{oLEN-Q2C+-BmM))m_mAUfktf;SJua3*6a7 z-q=;%CMNs+T7v&UE;Ogmj{$La4kmIG^7M9)@ zcHtO)f272P;To3V8@A!(joc7MVIEH59#&x=&dCw};UX5|AVy*!R^rm2*Vg~V;U~^v zD3;D=KpK4UfZ<1q&0J`Q9)CgeXBf8;|hWC^O^gjin+c4P~F_LMAqd&-sNAOm@?+&VFu=2HfCZj!ZyxhJZ9!)c4q0d;%AoTW~Sz82DUveW@J9* ze{AMvU-o7l5$0~*=5N;Kau(;%Yhr7z=56LbAf0BM_SB~kEo@rXHX_e;LQBLWdZfThIX_@|M znht849%`Hh>Z2BFq$X+$b6b#h>W_wMk&bF%{pY8a>Z=avthQ>lI_Qa(=&$zbu)abX z7VEMG>$5g%OR4CsX6vbLYpi~2bggQ)rfayaYpuTP+aTSvM(e)j>%V3=hxY5ie+KNr zChUB)>%_)u#g6O7&g(r}?8tU(#(r$do}jKiY{S;<%;xOH670g87M<~C~Pes1UH>6@PJqONYGe~#|#M!@4<8=dZH?7nX34sYrHZt*T}@}}QB!^Eb|0&QSC-Mbdq{`@-#orHP3RGh-MiV^Ee-KIiKF| z_^1hx0trw80O9jI_wzmnfAl~1DL;q0T#oWXck)Dsaz)Q60ht0okAgt|sXx#2K?ii7 z%5e=WP(x>QL|^nzmvUov@Q~`Tttgo;-v|$%^HryFR!8exDTOUS^^`!$MQQb0cXeEs zV#Y%uPLO~E;)?>mj#RTrbp>@#4|QZ$_Fb-!RNIpIcyMIw^l3kKe`-&5YljWOv67fj z0F)@*_#4A2DZ8b#ZuU0!^hWn`Pj~P>cXfAnc3*e%j(7CFnm_+hnF;8MD4Dg-@$Pwd z?{}AXfTwqO7x;fC_=694gdg~XZ(HoBym=vwHq8qGQ33!Eb|+i*YS;K{-*_~JEl_eH zV0Vvds-?ad8n6qmB(3U-xBwT_LApxn2UO<_xh9vdoXtJnt*wZ z_;P{}L!p0}voHF#H+r|X;an*spHE+rpm_0o`@DyHy&qXSe@~%b-;A5vOwx)89SZxY zKm4!ndVqFK0)Y(K^bV5ObOYb|#kYLKzkJ2t)0XFUQ-=w)|Az>h-}}+$d(t=ATxo@# z*9xG=d($ud*LQts<#luMb)_Gr&$u(qVEVu)cF4~B#0UQ07k=+;b{D#O_^uhrcY4fc ze$99O;Qvu=eIFT|<>6GTvq)?|yt=g38RjX06a^30`>sPH{$BO?g+m-BDvtiS+UF#O@TeWcK z%B|a$?p?ca^YY#67w=!afCmdM+?ViS!+{esUhEk0W5tjsORn6Q@@31BGjrbT8S`h& zc!S*Be?p3klms!?lAd{zr$X3COiQTvw`Azfpm&Szt=o6*;Jkkm7Y=;5@#4srA8)SQ zdGhGYpHr6(eY*AP*tcKruHAd~@Z6DNEUyvBrqc>Enm3`1-I@lnD4@wvze}xlTSRsZNYM3F18+zCwh#!g=B8elK zSR#ris+b~+E4tVsj4#RZ+=i+A6HC%F61juF`s|t+~dE>#n+b^&k+OlCj=`NsvLxnl1@T zN-1tyC;>=eO8X_X(^^|Cw%2N#Ew|fxf7>m%--;V9x#OB!E{G5blpYs3CMd-wtguN+ zOX;1a1cfD?yDz%@`r9wS{|X#1!2=syFv15boG|~x-!fLOnKO~Ll|WJ)d+ZidNT4a1 z$9|eXYrTqG>&U#4e6qJxuE4$n>%qYJsbIdH$e6!6t`vyd%sjY_9#8Py+e~FtL zS9;o;gDF+>%sMmuG|o*+4YkuzQ$4lSSy#O^)?8oxwbqpZo8GVpM!{*DLl%_87^85y z#f17bJomzN(|xzydE>pe-hK1^x8H#Sj_$gx;HV(H!s^SA1X>82bg_dAKDp$TQ+~PS znPa}W=ACo?dFE4vI#o-Yn7I}be+U?R+7==0l+$c%+PBnOa&~F#a;>5@iLy1Vld>F)0C&Y_j=?(XhpfVt=Yto3{W-oTp0xo7{5eclGH z;=k77kiyyvM6P=CNU}S|MwqGY z;K2+Hd`p^&_K`Cp#e5q04&AK}(4n9}Uy>4+iG#!@MeLH(hYpE2PRC{z7*mNsYld;x zacKgDG=ilvKcc1+N{#pEWI;okt#R?t-Q-@wZGpr96a>uAfyLHFSd0P(u(QI|c9LWLT_a=|>Hw7}$l>qE_HSe3{{X|5!Y2-NRW>Th|i?}Lj73|7p zGMKF8h{+KwHnHOtbE)33#`Z1;n9{zJf=Nk-yj#*)J(?)pzsXiTGgGU!oU3reDc4&& z)@U3q$z!n-f<>fLl`IZrOQ2OHL5788m{*7Y?uU}UDT)>0iC5|h;8ghHSn0zM&Ua(E z)&#v;8DM&8KCseg9>YwwYSz97Bb>>L7>G~DY;v)|i@+fLnO6`SQJIQ!Z75=7v{Wo# zo~?5gJ`N+MUFQ6nC5Zd3H-lK3rD%kv*XmQ2_$09g{pviHTT4G9lVf!G>Mk3%Z8)uN z|7t_vk!4w{_7Bl_z@J|POf=mkUCyDR31Oob&s{sj}PSnxP7cp{GI6>Rom!{+>!P7rN^lh%<>*eE0*uzb==qoWB)Do*Ur$XX zy_V)(n;PnU&Mb|*R{Hn3TDll*&7;$oH-aY0W^8mLOMwIxY5t=MNwbxi-WGrWu?jR) zdEtHIwfr{k+6+g0?gLp_f&ci}4UW8Y#M9YCf!@udR@sMg3vA&4Jbk1>4#AQcJA`iA zJ|FWwz1!@~K>Wq+q3@j**9%XL+@#7TD{uK@LT8q(i+9Yy_9o-T`@pOf@m3})0I?v* zhTg2K7)XkXcX`DwbNpLBVa4QoL)dT?7Fhcf{n$s}c@+T& zgt+D-wzU>qOX};Usgc>meHjzVYvnrx5kFVP1H>nVBppX8(a#&1@ig>}0(gU}7)Oj% z$qhn9GY@5+Q#Wt79UC1N{^&1T)q1xiqVIT<`)HdwpRJ2w+P|y28ER7_EW~tm?$k%W zo>_H)S0SAe9GH>~k&nKYVLQZb4YY9|hQQ|^g6_Kg z`2H1177h>KfqD&4_Q_cuSC}b-3b1}sdRl(;#487OAoY$kT=;t};_g^uvMxzQo zqUJ^Z*d0NqGQeQU#o{9UB$WGUJvcK!<^yn?ooaxhd^UXm`h`@NgXc>6x$4(v$K20J zq?fmAQ#vS93Gf!bmOuOE;{PMXAJ6?VzeTui@WoZU-vK_lJD4Od6Ju++&n2VJV)d)< zmwt~`Ql>mou5DtvJaW-(a@=hSDl$qLGAg}1DlIZapx=r1=P1R~hF-}H1Z^*Ds#@MvYc%8?vpT~S`$oy}c32TRCd7BxRoGodx zgl0`1C75jq8WxZfZ5wbZ=X2=gL%2+Lrg?UHX(BmI^Lc{Ec@2$tZFhL1j20hJ(IiN* zJM;Pb$pyyq1?G1mK4%H;lM9~b3*7_c!f!hppTls^Kg_1)i{$Nyl#z=v6^L>f^DQrU z*yM^)?TRTEh_e}s8}5oLP?&qmHf#ppJ>^TmQbttuX_@}g zD>5}Tt*Xs4HZ(Qh+cyo`H;XC)43!l=yRs4T1YE&l9V6`5MJ7FpI&SnZBrbORaNvkW?bP3=8n>sZQ0%Nu?V?xgaxST1O6}@O z?dDtT7F4{RV&X(~;KWDmnP%phMD3-V=#W+H*>C1mMeRLa?5$Z4VY5a&w{HSPx7ODN`i~Xmm1J=y~xJtM>4+4Zr0!(QFEop*m552{y{{)r% zi7E+6q6rn9{hd`3T2&I(L<0#|p!RSriO{2obTp6jJd9W_iP}Dl+^2~?FNyw8nzKbk z@kALrUSu@7(Z(y3Ca539N*pDw7sZ>F##z!P*;>?j>?cQ+CMQ`W zXVE64S)}FBCORIaby}qVp-rtTO&l&w>p#jUDz&|*$%eJaexu1jD$S-V2t>2Yc`MB! zF3Y8&%V+x-(n?d%X^xGR^&?e&_|OObetCkl;tH^=10*LOCEUUVwt9mP|{$M51M^}z}Qr<*YM?hb@OcOwB zRYzQ2FG1g+VC5l1SEE-BsWGK*vOQ71rfc>sZwjJsi8|@}P1h&{U4*7@ttxLQvudk5 z0dtTkzm?Ys9IA#-%)4U*Q?sU zF6ys~=}fZ*hT=~9qN@6mY!8B%hU3nU%Blv!s*aj$k2{$T>zED|tWT?|PFtDI&zVm1 zsxGFh&hMEn*Q+k~nXbmGu7awDZL9WVZ1?$?Z-kf)lx%L)t8ew}ZcUlnpKZ2XFSdQF zAA*>}K36>^F+XNiKV2K`XjR|WG2eDpgZrx=e=y$z|Lk6e?ZE37d9F6kdDW16=64Xg zTSWVJcq}j^ET8tPVOK6-cbMUX?E{6#;ZN+~R4(B)YJ41-VeM*Q+*m#W^iRGuAERp! zk}oYQtKak2AQiJfkc}?icU>a=N2l5&%Q9Y&O&)w;zFx0Ed$_!MzPzBixctD14ynQX zV+EID2-jkv*J6t~d}6P~fxE(%WW~|H!ZdThF}kv1Vn&4mtP?f($*ctB z@UL;KgweHx*$#v~$sgibh|+6_7FgXoXNfLYiEnC22C$I^90k~oL+~MQwGj_zRQ ziya$-TOGrIN z04VxQ&CW(xFL=N%BxXW(!p?eeBmD7Jkh-7#pibn1O_b!;eyL8Bx1qy{`UVm&XQ`eBIsaQZyc3B9 zXStaM1%N{WuR*q%;|HYS2M{jvu@O*0=TuI=_3!=ZX{t0L>7%I%`&WW&XN zvz>4wknJSU%&rl9#~*&D;npY+&&eI#sFlpA6?~&r%&FbgsL{izJJG0{-KewUqI2P* zd(xOWe*;m~=2Sy(`i1ApzQ$qjvB^-M)99vBk%ZHjy3ttjzEGjbMCIND`4Uj(ZZdGX zH>7SdJNaPf#$}$&Ws%*q7UgVe#AQ|8WTh{q)YW9&%w^+dX`X&>mVa+^;(8V7Y74oy zh2yq`92!kDIV^BFYP(ufb32hVJ2@SgqBlF=aJeA6xqReyRcv;#b6b&cVo&GPJaBQp zxzmky(@%Hv@Nf2PZqzU4_5$)Bs6%hOs+-+sxILP=eS4aHR~|fWn!OetyaD!KW)Ch_ z+{X{|hXRul28KnGXgc>!8{wbW3$!Rv5!rUcB_0N2*=@C<2y z@HOhUHCk0CPk3bA@HWGFW(R|t!MvbwaAq>NCA}?c0N`yYZfnbKYxUx9?*eDJwRJRu z+Yj16h@Op-{7qo5rslo+2mZehufHD!3=sJ1ncM5Qy`_&?dnDU?6a^$5_`3C9{u;Ia z{Y3ZA&U*l8@7D+}lNA_JX&<^FuS;$p&K4L6MjSK~81?cV?7FHNco`XKADa*`q46Gh zco~Lx1B~GINjSkN@`%0zf$5$0=?#IIlb30JpBa_4rjxez1^!NL!B&;m4l}{|0r0%n z>zsyQ1|ZlGF1TdnvzYC(kleAH?z7zGqkjXQq4t@b=vbSX>D~~Wy%1dE7lcgU32vfy zY(9MMMep4DD70-gG0I_N55>D`EJ{FMisX2140MEZ&Ip4 zd;Gpf+MOk`HAnG6C)J%h5W&NN&cg}ceRrX=jm}f#*X?Yfi)Nu?i0?(4&?VUSlDcoU zT5#E^V}%}A;TB%?7rYS`UXcZ^4`|yYh400>?iHbQD&$_H>mgY9F}my8a?TbitOCWt zSJe<#DIqYr@N*YLxJT%G0XScQ{KFP{+Yp9a2!DGg^_2xC23{J8ST1Dpt5_f^oml7po={Rb!yf3M7dosyX+W)*ul}+W-xFJT{8L9CH3fWv< zu&<>m=1O^z{$xaH2}jDs66vBoWq_)sTA4<^VP;z5iF&nq+n?7mHS3>sW<7aC=}Bi= z&892=Tvn89bsJrG*5}fbFIGBy!Q1ZzH0`&#g5XHhGg4T$`l1NQg9Ls$?hPc7>(6JT zvYTYUilAc?rQMp(6e?w_XQtg9`(dFpC7pS?+AKHQ9Q1c+yVV}AM3iy?hm^BqH84zmwO zQngkbSJ8Pm_p^m_t;A=WQ7C$0!FscI6QU(CUBY>5K9x%g~pZQMM_8o3@8pxtTGpEx=o4oaZQGW}NS;U1dVxX?14e zIYTv^i4xA_dQcQoP*q73k27IPk|>fxNs2nLVM>OzLsdnFyD?!#Q4p0=<%c9r&lq%rSK>o65bUWvlQz+!Y(h&{@-}T_5Vhs>7HK-kQstOY)lAUS#sR*TL9= zo=$|=9}NKVAa_GQqH#^bNXGq??QrD#&pVOQ z=Pm!~r%zD*)E0j)QOz}V??3p|)Ne5Wv<;-2oiCYYyL}+o%>kI81VOL@pGiS7v=3Ws zW%_(nRGPJXT#~!LbXXBsMPOCknDl&7-m~v%Rntbre^$H1M0i#|7u06sYCtrcqpng( zc+qipuC3k)`42utBw4lZM#uYd)t5-nVBb%op>sV*tJv;1$T;!kW{i5F-D#W=cxC6H z6~CCdomLXoy_@~1zjikdF#vS$7cH`VS~6^^bstteHr5{2{UExJo1rA@kMM#6gIBw$ zA3L5Va=Qec4+^WZo=-9-d_B)9Ar2nrwHtab=L1mq^kz(vHh6KRg539^Dj73sqB0COCN02?DuPKf01l| z_!Et8eq{YDipF>0MRuEmD8eI(UG*c7Q6LvtNq@f*>8Ofdhg0lMF# z80m*vA^0!bn4uxxDXC&ZWt|N_<+^vDBw_?HWaQ#lgorb$$A)X18sc`f0OHW0y$EA$ zBm60M3C^I{NK3uzKn7iTVl!XF1;+R zjLg5d#0F;*1`Ur9t?jx*pnkWI(FY`Ja2%KNm!XIuHdWS$i9J;#eRFO8P{QJ$L)uiq zK4)d>7>{jz+G+#_P8+_0wWmt@Hp2nWn6!dRQG6z4tpWDd(|3>l_^g`-Gr=nlMW5FG zkf+-NA(&@HzYhsHaGZx2#qJ7$Ri83ZbkTrN#p1Nf(;!G@+cYung`w7K-I7f=oNV72?5=+E5 zkENpNR11ia@?~5sm6D%l3q;k*lnRg4Y9Ge4Y+Z_V_l`9>2-It$Tm(6APJRxhsW()q zRoa-8r(%Vvl`AAx0ZxrpYFBg%?ei|xZgVI4Pe__wAChYPIj!{#=H!0Cc1LwNpR%Nu z%Jm7|r^Gc{GgBikjY9F$WZeogTEeCAsQdZ?)H73YZLPU1+=et8Ml+3-dX=K2#v+$W z8w2f?#r^xHwn7^_7ebv)wB+W#*GdPa^wkC8WYBP*jZ< z`3ytsYRx{tvW_Lt;KdN%$h{x0$sxjGl>Fn)V?XEqRkY8S-7gQCHK!(5vDy*Fl+>w1 zvg-z+nXAUMlBvTgO^%6hD7&Oqse`)v*U5pWdo2DQBSu_KsZ$L59PvNL_#Mtyzs+cF zi#~wtaG4zQ(+~JgmYQ9&>a(xf%|sqjC)|qIp$RBp{*6RyD)6m7--~laT#;Zp!q+*E zInrF7n_vdwx9(iT`)Z*WoHm=obyuQjW1(98JUQvVAF(`Y`ZJq%z9H)_&4y^72zt=p zQPo&&2aU5ori)5+Ue}gqSQ|5IFOA(}C8QGBm?>&6&lTMZ)Q_GRSfwwn-gDJ-b=cgV z5v&4xxDTKoT-lU=PFW zX+)R!CUK;FkG>e~O{tDd4z;(=4Pol6um zkCY`L=>_F_w2>~31pxP z@IS44OaZ>PY0Mo5^3}O+Dp_}lcCRP8Bw(_CS@(6#K4%V!>eClOcO8l)u5 zJDG(pg8}b7>THUjuGbqU5}!jwlI!T=x4Xh@pHsVR_hkLft5kH~D{|mvgBtQQmi=~_ z54;|m$t^Hanbo#sp4k;cUe2-s$nypC6iCqv0&@w0$9TJCeup>-LBp7jf z_a+!Y8H!-*7_j~22N#UKxNoqImml3&5d7>Bj(WkSm+zxlKB)XajGcS{H0=@NS&(w= zke`CXd=bZW^i~%DCOudt={UlxsCxVHiI0dH`i>*C{ts9GDhJPX!HVlPh38d^D?5ejSNl0I_AA}t9Lp3IWi3{w1AaDZeijBnX)S(vEPn0OttAE~ zdMzR7im*dT))V8)5bKwbt1r?DnANNV^$tX9Qv_Q|_C>73OSQxsKrPWxE#i(D##k-M z)D_7q>(@20uZUP*QR}`wp_1TCLtpr$c-O?i3}{4gq$G}Hq}ODLc7&MMUwIwL1+l(7 z)sl-}lS?~Nz{wCXvQg4IQqtB@vc^&BPgANXQyIh&+r%C?vVIP|`s_6Q+3lJ-n3YB! zlO|Z1#w(5*Dpa{l12k8WG||(v@mO@pbu^{dw5g7Cm34Ia*K|H{^odwa-(#ryAZ%1a zj#MLcRGfyyQ;rN{b&Lzw4AL^>@^utPj!b89_o8)7ch^i$j?AeDlpE8GOKdENPAqsz zjHsWOa8y|E>RHAn5YW_WV9se&o# zQ=RA=vFJ1Fd2^lk+OB!aRrspe`5NQ-y6X5}tzl|s_&b#8>f`x`R0JwN2||}!mYf9E z*abOR8KUcXfFmcNvv{6BRMv}6yifH)S2x1wVoXk-xDlPdp*9qn*MGyg{e};57MX|% zcy_8+Wj8bAmw)}1zb3Cn>b6`B}lu>O1ojp6gEhg z-b%w9Q1jLbZrunj#LG6<3;b0T>{68*n~@!HmYr&lox#@ccKWY0x-~05#v!+%s<3b? zw}Y*C#-Vt03%ge@laKwQcvc2cO`QeGl5+mQRQs_|BN=9=vaXA|L8Ceda-k}VqvqhE zA>SBj*`VQqqv_WbDRe@-xYdcNh9NF!xCU z%)RE*XmHIExy+JriL9C|Qn@U1aVv`M%);+2qvtK_=gk|NtU&ixDi8!6Fz2rU^b;2QhjE@|r;-|UgN z;6Zxm6x!@os^RYB=2_Y7S?lISq+$J$+ir!+{_ld_Q{QDEc{d)wSXZC|15zj_H z>mSmVKeV623N(IWYWm}~_~#G(7S#0Pbq^3vVbogqqxASk#eFEnEl}+-P?twe%H3Z! zCCFwG6L1R>dkk_(aVF#mGT;dT{8B2kG(#fYL!#Yf*jj?zT0*^2LbFqXbKS!VHOJbM zeV-nD!3(}K58k!z;dgiu^)2C`$MA!FpB~Ndxs-^Rl*oa_h&9cKJ@?41mZ+VoUzsUk zH!Wcgnp=4-(GZ@P)W_iX&!MQG7!%7s=pM0{sj>COp--AIPc#FpX zLQp)>QhZxkh_FY3NNR!#Ts$>zBFxYD1E%1g_=y@n6UCEa3TdscMtRUBwlo$ESA-t4y0=;zhsm2P{f*&qvU5$0Eg;#l!*pj7R43jX>+ z?fOFgno@9U@e1eGGDteTE!N8{85+Q8w-2T7-Sf2pwfr4Z+N85>Q20V^0NggR(sr~W z_D{R>s;%>Y@NeEq$1AwQW~mkCr3HYVi?Jv&i-{|kW`1Cp!+D+xH zQ1Y}T5ZuhY+AJv0!mMLs4esIp(#!bL%bw9ArqidS)1#50r^85|!1~uB<1fjo&_!E6 z>1sdV(q8lI)$P~b9jnhAfr9QmkoeN?C7|JyF_7y$l&%vn-ZohIG6?r&q}qGL ziG4VAb)@C($B_@Es>GN&_A?>)AbF~-j}G~_+rlkp!FF(NR5sr|2>qfh#! zPkwa*Zmmx{qwjr3FJ?!Cr1vCh<|I1NJ54bE;5hL zn`ACV3d#($FFL%=r@p31=`Q9HEmGbsJo*cdYzf6ITRJ(E@_g zL$9kuf@_36YmB;UE4qAmYpZJ=(?`1NOC9T^f*T{c89U;T{uGpqf(z;@wZ*U(H(ld zJ!T=+_)KWdxM$Mofla(;v%Y%(ooTf5J#bj>qUqf86WaGD-VgQN4;1<|P(N|7w!YDE zbmg;=OT2cGd6cbpOhUi@OHB6*uIDS3l>v^P05l77lVNO>)N!6g+198(2T?%R0$_yR*>0 zHy}ARgRIvIJ#+}2IzaAygdaVG@4Yt4Ue|6_gr79Bf4mVtB}1OTLN{W-vk1w9K=yNO z*K=V=z1qf8M;G{2?+NJAe-U}jt%tlcc9B#IzfN_%z6QT+kv#tuUQGku_JrSJwV#&& z$QsG+J2Gk5cW}s90_H5F{~&=#glyi2EM&c*SoF&6=B(uXkp%pXFNdrYgRx|?aRL@> zl*5U1+I8MXY*eGEY*y3l7VOmHnS5T?FGuV&lewbdSb~-uw19E~;%^F!L+s<3B9(mQ z#IUUqZ&#`5kQ&m615!+YKdF#an>xEWvusY)^$nm79&}+eVwcY#Sxeu(ty!rPJ z@&l#!qWxxX1RlS{ize&tP#isXk@u40?q~+LSO3e;|4#a4-Kcz)oeyU#G}dc{jb{p-Sme)Q2DO9U+#`3^Gm+z{CCq|&Ry)g=6QR%+6*n^>Oyb& zP4M|RH!i(J?+}(Cc+QF#LTb04C=nvEg@j-4o}Vu4Hb{O$I`E*)(`~m+?>9!8ul^tS zLcRtdaY=3Y|CbU5;M#_e1{1g z0>yceY)HlVv7(j-1%Z04G)2*tmKH@>w$B#DIZt9M_|-Y+N#wWo%w{ZL?|F zi~`$$*0boVKv|51s$z%@Berzym*ezy9k=UNcAbys<#t^!Zzp!$@6Z|S|H6}4+y6sg zuCVWcYd*5?MM=)BguwwA9ET{H1DI?liqh*^cbibyN7%Xq>PNT+ zCPX3aE5qMO&YY%2$U7RQ#rUC0s=Okj%bb!vs>{48xy`jC*5Jb3kU=z~D-m=-kfyxeXG18?^yidhB4`xeK|s>k3bq7QUv|O+#S%502Vm| zR5Noif@79of!L*^UGk-Qv(b|&F^Gn4xs>qL2er}eXy5MNGVLA!Vo){r_XIF zTJ7$H?Ct)e1oQPpxFHHTK9pGGNK=Ze7+ELS(Cp{Q!;ErDASXNt8Q?#Tjq-v=Ju8#! zfQAAwe;6o8(Ji1^s$*;zd>9Jh(4ds4V_aMX1tqh^kep$h&5=_+(G$Y3Vi;R=F2x=V z7%(4J&2vmFL`kAnTN>8PQckMLC}gw>l~rF?P6naKZgdt6{jya~>HPw|;0_%%qGL}T z@}gi*MjSJfVNaXQDB>&+9rHhnOKLx*;_d>C+21Q?tW!|)41_AIx7TGH2~hKI9F4oD zIc1#{Pz#(LB@3H6WiOmk3qy`3e#1KF0I(g9bauqB$v;%ixkxrNVkF0tkuF%-*p$I0`7%N~A?Rr(=s63QpVeFh62-C%3AG{-~hEF-sFqUB(Vq zUoBOzJ61`9btw_(vrr6up2_>bSuCnXr<|TPTO#UGCRZ7YQtqi*4E@$fle0`C8AwyB z3Q8<9I-}F#sjrXx92%m)9umAf{_-xkkqTKbj*Gs z@@saHxz>i+oEi|NYqA`(=V4Sbm@r!{jT*Yvrv*oMg|TQ+8b z&#iS58m3lTt(_Vs)Arz0xp^S3Zw58D?YvgGPardtCb)GRax#0YjI2KlFLYF7oVqrK zZ-9@vyI$6+0Y$2{H>Afdcw%Ul8nG!>mehG=!s1OLu<=3m^FOpGmOyI2dJ9Fvz321( zWuWkvRuCFXk+W`j$cmQ6r{c#-OB2>GH54PfZ1;ZpB&$eSh8@D@#eUX(hbXTIBa-05 zGUm71*uXEl6c?I9;u6GoJ^4WL?J=R_xGd@Q$dv)?r*68=$(mkA6j~PaR^| zXKVY^Y+-woI)`G$o&PijIKGwpphrLs;7@obyw~fqp9nXFky|GMh#PW|Bh3pK+{YuL z*fUAi4*6sWrlU;p(`f~dq>Mag($1X=g?;+BsXS+cnK(-oZH`qXnI}tZ8_V?1jx~Fh zK{@-*m4cZ^+9!Mq?MQ$w}Q znKiU!I;ibj=M`yf540_AGjW3^h-{ss`JsM7bK6+u3X^15uXa&g(gov%r~c0ky>aeN z5Q@DQ*lQDJY2U=9CB#JD+B&<{hzlbIcfc*$3%zH3L;=IdDFQq8oNt6hEBv z=mn4D^jl{*_s=qJbdFUwJmy@Y?n>)}_O&LS=i1QjtHalh(>}S(OuXGT?)TWIq z+~$9 zx_y_~?rC&~`#6-u`vCdkY0}^KHU!V-NE-4qQ!aG3o5d4%XmcHtyn06-?Q>=!3=Ufn zdLTS_9r9rYuOX8>^ehNo#oyZR(04ubnY~>lL%{ozkf-rr-^)^ApOXzeaQB99pmLPw zOugOhX7)d&F8w#KvLd77sJfOB1z+0NrAb3N8v;c8^>Ni9ADyKBbAi;E6sr zkG_(M-roth(rk>CO`!76T~cjo(#ruO2@N>Cdig1XeTE~E+zb+O;p5pHI+=zpiTH0nppA} zuXU8D{gi01q^LTSpz<@x7(dY*B+Q(eWOkHfgP3gkq{~g4??6H&# zc<@VgVj}2vnG$oJIW=`;C&Mk5M@yO^4&FCM>TsiXiR|;fqrDpCdWp1Yu zsas?oNoPV%O0yntGEWh+&+xNvO3inVGWR{Q??bcu5VM|-viF{{pAd8E@-nV1a&AL& zKDNqTl;$8F=b$3x68UAq(dFQ@X2W>qVyETehUF5KzUh*@Y zMk|lbGmj}Puk0z0%@c?zmd$MJ<6va&gW0d7m+Ctvn)6_h=U@lvSkJG zt+{{avgMu&h>r_Y%QD|T=lmqd)h;U>n=RBqD$=Jb`pTNCMpvjFR%BLIOO2GWDBhLB(@E`*Rtz&L?{&fV3*wEGve>Dn=}gr9I0gEX$^}%4S*#Y>+Bv zJuCC+D!m^{faK%KRHUj6Ei5mBsy(`@eM=y{6^a;Eog7!mT2@|^RW8z1pDla-hLbz= ztOA!+zb<3$4p+mRRKxK*-}6;}w5mbWuI?(Txjn8yXse-Ltb~`Xg|({1M5@6luf-dw z^)ISIU#Y=BuEV6Sqa>`OmaU^X(Tja5Z(Xh%c&-};2hwa{5N9^r9=iMmpI>4Z_Av&RU|AMndE! zx71pF*(L+4ru1eZD4uE+-e}X-cwF3MTi)n2(qw0)01)Z{e7)Ksj{GSdU=wQmI!c5HZmxmzO|6B zwHUdz@o$~nNUQWpt0ZAV4S$1#Y?~B)o4i+BwRWnIRa+xrTPJ^82Yq`NfBRqA_5rK* z(dUMO@Yd1v*75M);m93D;PzSmjyrNtS~<`G#XOUPL2JmJEA*YSJD@SEjsvSs$dRn} zba>}ndFR4OXV<^*E!nOetFHTpj)(HDyOFN8o6a-jE*O;V`H@bDRreKrHzd8|b)*}7 z((Rwrh05?3UGA@eN+%R}`IyoDx#BPW=wA|O$E&5aO78%;%9U;00WcHrOO z%0Cu4x_0^=>eGK*Iz3bsJ#-O0Lxep-FaN$pIJO@BCDrMbT>TrR-u*SAS1zOXb9*-q zL*FO4zDJzDDiM8Z6@AMIO^VjNdaJ$q-paBJ{U%@f%~1NCS-XCY_G_Q^e{k)C`lvQ? z1K(r&6(R;qy$32J`^_^3d@BZkFMWWM+@Oo~AhPFRsLo(GgOcvcfEUVOEXoih{$()n z%V4scBzb!ecl!|UXip|XZ_Z0k{+Hf@)t(}qSdK5l*`vc5tHW8R!?h?Q_3guT746^U zhEg&{;;e_1X9n6(MmxTYPVqp=)zOHE(Qxd6fzi>S)6t9V!B&RR8HTaB3}0Y&bYxy& z)HkAe>~w4bWjwKAbW3i0Bcg95Vtg@Ud^tn5t780Obo}bH#9xJfUuXQmdSZ@!{3&AM zm7x=lVd5lX432T~J?aF!&typXP_@oTL&Px3>txfHDKy<-9G_ua#$k4psm~o#cxO}i zuTw;*)8vfPl#Ellb?pd5lZe3BH2s;(hqGy@j`~4+^v!zaC1M7*eu8^!2Er@f^iFVw zk!Y3`b!J9$R+4d6iZOI;bw>VdMxm0GAG*_wItNoeD;7DYQ8}lzGyXkuR{w17Msr40 zeqPOHo|kmata9FjvFKaJoI&Q?FQ}3FI`3q&U<{h1CYtt?pZ2zy#s)eTey=V3d0jyD z5i1jz3d&pzADi+$TMT$zgi@(7y7siX3(1TNDX7(HJ`1U5ONR0j)*bWtjLQ#k^M&%u zrKs7quS=Qo3zci_x@XJvV@r*h26Bs7Yl53Av@7qUHestapKn&@&o&n_ zCuUKhV8v!%%G$R4CUVqPp4TQp)+RpM_Ek`IzxWoV-WE3_uti<9P30R#(z(Tiw(~Wj zg!z1nlX%-}VUwYAn_sVwwrWRse22h&i`93BOJPSmYFDyq_q!hZ?%Iap_=eKk#+l&W zh3?*!{N7L7J;Crj?ejfdwEf4*J%jTNBc}blFPqZiyH>WlB@0`Y>-!F92gHdxB51pA zdLXdbfmhT4aLl{w*tz?Ae1D4n&`;gZVIAddJbSD_{v?YNNXSoq-}sq-jv{HXl=xD4%tBdhEkOP;drzRBC3nZo`L z-@VrJla|gCF^-Mq@snQL6A;>Im)_}M=jmVKf;E2)V2J3XLgB2M$;od0c*gc@PKZC( z=6J*Rc&+?kQ{j9|A*?j(>}dS#Sjcs;^ZcUo?8+A>%=UcG_5w!s_^#^Wp(^Tn{CpYh zvg+*o4ebI(|5E4b2wZjfB9#3!eu>_7`4V{kQSl1i?&~1JSNAUHpAfG30j<_xbyKV@pZ_;8f zG3;(JnUm0>Z=tL9vz%9A=(j(ZZwc>jRP63VqVJ@?UC2+|>P{5OOxzh<+@*68O^u6d zFyH+oxg$2cpSJy?R()^Yb??~4v@UdS82xYwAl|!8Ja{PnIKU@)hw(V2hS3D`U=jW3 zMDmcr`Y2EOXj}als`wBN2;oq?SN`@PJ?1IB>%DfjrTO+174tJAa8n@sOlkLA8vR@@ zoIln0J`EjgT=U-8{)xgstpajgEPP!U4Q8JJw^f5d7vFmP;8k;<3ch_9tbQ4qcq!Nf zo)9tMC#!L$m#5-rcAH^MCi_D0q=~dd#T)JiDIEdA>W*(T7^cV>0axcl}3$zpR4>k3K7|Q^NA$g^rT9ydgZ|Y6p%$2 zHlp-9gaBF-S?N}fPw*bvi)UL+R@!AI);tbIy)h*GNjj<5dxI&|X8Y|+^?Tz$j<6N3 zZrbh9WQk<4X~%No(R_`5IrK-~pDnhS9qf0kG@q?Qx`7j1y^P1J&B0)5v(8n})&68Q zeX?HW^WEWMvDrcATHD?EW|!4M`tHsuNV`DF0@Pf%A#ZdT7t6MobJe$d7rY^y1py$i z*mm7-2=J5NV31WWyWf8@z63-*l7F}V_klKGQWTMX_VOPp`|hO}x>&EhXu(TbcmiXS zxn16!0*Zxr4h0K61HnLDt;83*SyoA6H|RC?R{)z$iY%C;PKqL0l}(x|*|~0rYGxU> zAZ8x9JCCB_Gh%^!0S!_=bx(k@9NPd#y&T7cD!V+_f^)q*&qe~f0v~YDFhh3nJLDb9 z^9SVLEG>(o+m$zEDBTS8ABdcL(eU3-DPa;O-a@9JO6sk$EPJ5yj0|tztcpB8&fScX z80DR+ipqPJIb}*QSU-v)IXm%x7$zh0Et4x0MJLx{iF5keI9!^BYDDvkhJMhVs!6)w zJT#w7?+N`G1mdHwgT%rW*WTO*J&JHIf;O02fi%2rmlz?kmOQ zF7w?vyZI#JPo!^^B2G6r+9|n_Pgh-7H!067Wy>dxn&(lhcH_x-ei7cv!Vw#))+;7-*yR- zQL#mZW$cg|DlNXwc!#&sV_V?EAO2LzOBA1k5b!%O!4i25Eac%|J( zb1$ROkCjxPV8UqOA*;QQmE6`qx&0#{h5|@Y*73axrz)9c{vQA{LCn5|crqm4@|P~1 z$xCQHQ<~AF<}|HYO>ACMn`b&s5M;Fl`~P?>0bA-55cYBwZq0>b?Lhz%_9ha8Jk2zS z++;l~X-`bzQ1mJG^8IDX-P?XQjex|q?D=2jIM~H2?|6CtN;l!7s4%?ZsAG=31&@q z;<#-FHJd{fYEg-LRHG&psY_MrylRI8aoLn0Q{vkc8lp@xY~g_iK^!NnsuRUEm8oMD zYgx&9RsM~Wvyvj zi(1#l7Phk`R!)E*I5J3eAXL?AM1TOW`lPZI;4D{L*?L^#CKtKORc>>c`&{RG#Vk{@ ze+w|h)yjs<^q0v*4|ZP?UFe1vyyF#bdC7ZT^QL!Rx0;K z4Q5CqTF#G_G@U1XX-j9?(U{(}rX#)RPd7T$lm>OBJALX=b2`A_ zW5`W5rpX*8|ec*9gZ_ z;R%0v%V!?*o7a5jDejwEf62$!cI{)ZJ(nn1aU$0lv+O2H(ewD;m+fk&{p@X@_S>%= z_q*49@7bs^c;c#y3q>fdI{&-ObMqAlQ|fy1!TxRE=f3v5pMLbIFMaG^U;As_H$zK_NUrBK*N26sYde z0#@KR((#DF$tsl~f1LFb2pRh;=OG9qd$m&ALQ!iqFXTc}>q0P8wJ#JyGYmsCB*QUe zwKdejGgQMjWWzXg!!(rt!#boxIJ`qwbG4}wobXc!0?I2cQ3*lxKFDz|xbrQCIJ7q0 z!!p#vN6bS>w8KfnLrT0vOUy(^Y{W>!#7&&TPlQBH97RywF@%oJ)1oWeo3wg7-Y^CP-2Hwd&W zWE8%1y9^KXMQjwtV$4Qj5(MsJKhe-W4RQ<}e6LXc<$Qcy>d5SJ+sKqX{CnV3Qa ztVagCM|;die|+3WeRPSB8wrAQh>s8)PAG$QT*q#iF~y;(KgqmZ!NQj(BiL6M9 zJTLY_0KnL*`-njytN#!4z#8W(}vl2(w%y}`4e zsvV2$%8T?$u>4A~3`?TgzsFOWSfMEtk%HRcFLk8Kf~Z2g1FNyDOR}^}yu3@j%u6}} zy|+LQufsw9k}E;6D~S?`fGdfp^EReTO2$ZX8WY|6)!Opb~gl+mllnLuy?0-571 zLPHO|e*l1$krgTM$>At1%cRW7tW4HiP1lUg)|^e`=!u!)1RBG?jl+_3?7#gZ16Jb8 zy&O*AEKcG)PU9T^k^(}Jt%?G*>_3X&7In-`QbRl8BCDHkl7KCk&ho1tMwHPP#nB{< ze^Dj1Q5wBbCXLb@ZBi(mQYl?hE6q|Y-BKq-(k`XaCk@l{k)pF`8qWeF(Fhs3;Lp<| z(*%9f1%*=wmH*QRrBgfg!3oSP|NJk~Dhp=R714A)+fd0nozplyR6|YFIz`k)RimKj zBX$f(ndqr>^iiRpQ18rC?%Y&O?Nm;!f5<&iOH{IntbEIc9J#ssR8L)1Rc%#PeO2)a zBW8Rk036M5`iye|41?qq^BhV=E!0L0R$v`gMkUr?6(|G%fIUGE)icn}n32nRr<*(u zLM>KnJyvWL)@#+)Y&BN$_>#HEKbzRUs}z_R;YwIdS65wEb!}I6eJ-5hzf@(se|Y^j zqX?^abyt1OSAN}Bf4!4DVa6B?u5m zoPzcMnZz*vCz%>pyXsZV^MH?ZL=jA#RH7YPsf}8yom#6k3i7xt8uTL} zOA)#l94h{*y$Y+@)RFt=-t&-QewA z;MHBx4gV0n+Q|HS;l?ZLFTqA^(B4=Wkx)sM2p-s{a? z@}=JLO<(k7B(h)&b3xane|Xxeo!_giU;4dY`wfa#lHV_x*Zjp_0RCSA4&dZekPef* zj#ZHHHD3l+-}8Oo23}tX4lL7I5ev32>sVk1j^GKFU=IG^4Gv)vQJB3W-~m2i6HZ|i zUSX%eR~3F?7LH*Up5c^>kf&+S1O{Ob7U2%=VIB5i9tKto-eDpJf8rwkVI&S>B?b+M z;a?hlVi}HND4t?=bzv#KVk*vJEZ$;iDyShgVlYnPBVOV#CSx%!<1tQSG+yH~Zeuh4 zncuDdp2dx0;Wb{~z2o7`}cKHg(L_T4%LWa6dcK@Q~N6=Xd|>T4rqd|rgAQ5gdS*xPH0~C;(dN-dhX|lhG>1BXp63BjF#w( zrvGS-ew0wo9H0_qk5)!pm?r6$X6cibokp%{oxbUr zmT8-gX_w~df1jplq26hr?&+XD>Y?^&qYfZa-sp?&XpWxhsHWwQlQIb?B}3YPq)RxsGePhBvIfYrEd-ujcEzhH7qx zU$!1>w=QhLKI~3HY{fQg#$N1B)$6{F?7yCD$fj(+f6Z&Zu58WDY|fVKAP#HCZfwyG zZPH%rv@UJb9&Oc5?ZAU;&+cr?er?%~ZQ9=K*|zQ4hHc#5?do9a*9mE&R_dnqZQ(Be zYU3_$gGmj=w52)W^VY&>EnKG>5lHI-@A@V}&|Yo+X6^p|?-@RA03UGwE^q=TE7!*F z`_69$UvLMv?+$Em2v_h4mv9R2#=tgk4L|S=@9?P=a1Rf04j*w5f4K+0a0^fI6IbyH z*QEJI@fDBp7N2n#SE>CjaT3pQ9N%$^1aTeze{miUav&d*1ZVLYKXM~aav4YRC0BA9 zcXB3w@@Ap#ISy~~9`EnA@+?R1^Tu-X*77gsaxlm4D=+gc_i|IS?l4F5E;sWvcmH!T zKXW*5b2+c`H@EXRzjHmWUG;`?KW}nCmvTWj-xvS$K{xa$Pjp20j~gFyA%Ap7kMzLv ze{o2!bVNX;I^Xkk&vSQgcX(HKc}Mp-pZ9bZB61fZ zHGg+|uXlRK^LYPvcJFt9C-{LEcz#FtfR{vTp(kZ7<#6Bji1+r12ltg4^@`8`c!}S5 zikE9-pLURsc99SHHFP}>nFn{5cXNe!4I`h`?kM)%TN1U!8i|tsKfLUzPKV4;aJRP z`^-0e)c^AnTE08$NW% zP$I;M3olyCSaBo8juAg<^cWH(N0K5-j!cPi<;j*NU#fH&6DCWVGHcGviF4=8o;H8# z^cfT=PohGL4o!-5>CvV}f1fIK8Wk#2s#2>?&5CvF)vi{*YV{fxELXB(%Z^QpcJ0}= zX5XrH8yD_I0152ky{k7b-@RK7ZX&aZRNyTJNe#4{V9Vewj7b58Oh(H8-pYM1>rG&J zGv>~mKYI=hdbH@$q)(epje51})~sK3>| zPmX-K^5)E+JC6>1y7cPQuUq$AaB$$ki2oj3Xjh8lf`AWiQn)*bgzVVgyN?flzWn<1 z@7vFhf4~0z{QvtGAbBU}#6eb}hVqk^1IG4|kH5DHe_UJ)J^-d!9TTAKt7e<>L0f|61?sil=-da0(F za=NLfor3zQsG*WNs;QEiN0=0LW$B$@TjZ5QqJpjJkd_18Cf=#Lmf9 z?5f)?yzk0ee{a0^(u*&j5?~1cL>vm*Z2}!`@*%Yz3Jg_5^D=C2z7IF7ufq{TJh8#&(?I&R}Y;vf7V-T-8I)=dmT2|V~bt3W@Q09 zbHYJ2)M1wjj~utibAv2*-E@b1x88a4ZFk>z`yDvogYPZ)!&z3l_S+$83mCUm7Ebuz zkyAdo<%VA_Ip&*d-g)Jmf9`qap_5)xU(@D=x?YezmY(W%y}nthP{Y1i?089_YwWw* z-uvmke^c7u^{O&8?zWwvxUqAl)>)(I%{-ck7|Nk3c009U;1IiD9^dsN_ z8JIu^F0gb{6`Z18A6J+TK3CKYrvXF^9q#_&1$VdJtBGIBuAMe-^Afzo@Vp$v{7iq{(dNP!b z1m!42*~wC(@|3D19wo%xf+)f1B9crZ&6D&2M@$oZwv6Q~gXG@hmM+6S7vf5Je$MR_Rvh zS=qvb!byT+)Sv;)=tebK(2j!iqax+#NIgnYugE5rk*Su`X2l&?txig=YK(eXe^WSq?yAzccF3{7ihp{UtXS0Qy}uCXRjt18s1S~aU!-Kti* z%GIyRrYR7aQie8BM`%sysAtV3S*cXJl7{r8CY@_s>HiAXyUMk$dc~_<^NQD_Vhf^} z1Z;)`8_z2#b&ol@XkQz($q~Hny;xt!!&c+g~Cssg79)Yp3ZU05DZ`3CNveP>V~^sP?wDU9NJQ z%iQNWH@eWBE@q}8B}rIlkJ*(_EwdY~B&@Tdz)f6I7>ilUYPPcKJ@0zcf6LzYx;MV) zoiDy5M8osxDz^n=6~7F^(2z_ryzr&(fZI!8`5HLE1zs?M8GPUezl4D@+$CmY<;+2e z)-y?vj5Tk!DBW36jB#}1h~Y@$6Pq~2B1Z9wRUG2~7l#qXXlyZ!W&Fk($2i9`wsDDZ zyki~Tn8z#@vWtP-V9IEGA?Talwx(5B@tYkBb8O=;qbCKB`<~5%=&S{o2oab!kI^(&`bk_5o`JCrB`#I2n9yFQ1Clw2; z(~=dD(0{*cN9y&rV3ssW9~Io-OGDVfmd-S$H|=RngSyjov#*6Xe==5TcB{XnSnI;f ziI9o{#k``PHK%L+Xsi zwzR8F?Q3fj(Gxu`K{#w72`szAQQ{7_HH__ZTmL)V=w7$F+s*EGOUJs1-L|fxHBk)J zvP$4q=#yBzSY87>f7b%{HNksr@PHQ_;RjzYVAGRYeN)IW9g;-yT;}j8<@etMUwFbd zuJDh0Jmei8Imb(Wr`Wg*wX=${VIVUh!qk)@T9S;blFON!!aL{g-nq_u&hwx9yyq6C zxyv7p7x0jx4K8=9UQ)GET1_?RK%ctQt4{T+TRrPy^Gcuqe=FF5kyVpzOY3JVYI2dE zo#Y@_JKEQt_P4t|?r2JA!&#k9!0hqv^2~~?<<9oM!@cc-2fW_{Px!$L6qe|zmD3!f zGR8+umoc}Ev$1Y@tz-W3n$JAvH!rMKy$~<|=$!nbf?e*_*xd7-U%l#E&-&N9KK5sO zmW6JW@rtl2efis}-vRC)0rsD1^_ypS7yuv*rT+3SkirVG<(Y64u1FDB7ceSmQKb5dX?V-znh{VxbXg;T3Kn6Lw)1dQ`~>3d~Ry zgv=6of3!k+RNiZ>q3pRK?7`t1%3&PR;nl?qg;j|JMwHc%)_T<8AI@PQ0^%SFVj&_^ zRxsawS%ggCmrR|NIXNNZAcYlrAs1R=C1N5NX5tre;w65R5Bke?CBzRd#3V6g21ZP;LL%?x2_Qag>}!AP1!P~t4U zqC3VSEz09N&SN~%W4H(m!LU_XeTb_6^_wkWiDFG+J;Gx_668VNqdh95K_(;+=mz2t ze+FtrWMgDT>&!~a$&zlg;Wu(*IC^A9g5*bv{Uz}sWsM=a}2NZ_Ie^hEE zRdVHS`ette=T-`*E9IkLJPJHb$m1XyCFx{h=_XFJ<8c0_bOtAHQs;GAXLcTmaM)7Q zbk<}zMnqZ-dkDuy-lR>U=Xt7Uda~!vEzK;+RB@3T84gD_uIGKWXMW=6e(LAfm`Lui zir(Sfv`h5_V>lXfYXe}1X)4UKsM z>6r#8nyP7mY@H= z$TgX2_z>=Zm}$1=V5X*|y5^)}=4euCre3P0LMoZi6KX0#Xu zp6aNEDyqV&tD0)0#%inDe=4oYYOUhxt$J#%`YNxI>aGfF9JvXeas`GOYN8_Qpek## zGV8NarLjWmvL@@aI%~CJE498$na*jq*6Fu;Yq)}|q>!h$n(MfpYr3MV?A@ugx@)#x ztG34LyuK^T=zWVD=2+fMB>%gun!5VDABCKg0f9%30Y{N3_!&=q9 z_A9^Y>%>}Y#bT_)q6Dm6tj2mQ#eQtZg6x=btHYWs#G>rUs%(tDY0A3n%EIi+%50~! z>&S}i&1S64<}A*J?9TdZ&jPK^1})DDEt>qQ`Z#K@0_&?Dtawn<)NJR%?&fALf9J|>><+Eu)^6_7?(WWR z@7k_C1?(`k?(nW|@e=Rx235-<@A4XN^D^)ARz_0c#Zh zg#af20UK}te++N{PlN$0Fn|`YBtCEgOR)1rFa>V}0-p~9Pw)Xpj|Nw;2Xk-+H}D4s zFb0G02%qoWa;3K#4z-|Zw-4d`~t-c-!MHA@S~FO1W!u}|F8&eFa`(l z_Vh61M2`>qa1fYq3IDh75PNWeZch`luoPdh5L59Ie^>Did$Id|F%E|@@SYEWZjTv9 z4;q*88MAR3w=tygtsBen8lRCGt8w<&@f+vy9Pcq5`|%y;?es9}YT%Em26EjZGTz$l zAuDnsb1fr3vR6X#+BR|}NAmUbaVG2WCi5{TTd*f=YEZ^-b%W^H}GAr+L0w4eo{Bke@^Dql@F%$DK8*?%v^D--QGcz+1 zICC^J^E69yHB<97TXQyJ^ENv(<9KtDg7YPPvp9#d*^%=(o3lATVLIdHq_y*wW@$OU zb3CK-JdZOwyK_GO!}C4cvpnlFJ>&B}`?EmTfAc@{vq29uK=*S&6Ldf;G(t1!YV6l}I;i*{(E_F{K7e-xayYq$1m%eH90b`-=mXRkJ9%l2(+HgD@T zZ3A~`*LGzGw`(hQ9Balzu-5D?w{mBOb1MdPCwD_MH+4U^bw@XLPj_->H+OG$b$fU1 zfcJP`w|6sld6Rc{r}uZGcYCY1d$ad@%XfU!_j%8^eb={iv)gz_>XtEk9T+gFvMj4wrYR2 zaSL~GQ}%HG=Qe03d6O%4BsjT}e;;{fH+B~Iww6CRaW8pq=QfvHxs-#sm5X_mqdAj% z`IVD-M4We!1G$OE`HBm$g!x~E?{sV91=k9w->f4E(QLO8Vg ztGjxv!}_etdacv?t=oF80)%Cr2+yL`*T{L9OH%+vhL+kDOA z{LSlp&hvb>V*@Dse9!~^&b|TYc7J zJ1HRw${of0I;1mAg4}RhszTy{t<0Jm!L%!oP{^Lu2{tHmyMFB3KJCkX?c@IK>wfR&{_pcX z@bCWc`+o2nKk*BH@gx87D}VDR|MT~;DmyrITe;0N@14NPpKoSTfc(C9= zg#-;QOo))7!-xrq(PZ3b*dC#|ws-5swW}7d-nl#jAjt#@8>YmI6*p%5*l}dY zlNBSDY?*Rq$(J>E-s~ClXVIZak2YQUaTX*MSLcQMnyX*guwBQdJ==Ef+PiV{zWrOb z?47`U2Ol1scyHy&nIm^z9QyL-(xW$z9vFMy?0mTcf7brJ`(E$i2X7yLUVQrU;L)!q z|5rVH_wD7^7i`~r{I%UF0>OqDGXVcvOuzyMG!Vf96-mfCEutrpvDwe1#MZpHm}TyM<{ zm)vyEEf?K&)omAEcIEALM0b-T=-D9ZjaN2(=S>2E!*U%sR)S}h)mDQMR+wQ|KmSX( zf8dB|ZTMh{G1H=1aG$j|C>2nmf2;RVV)W1 znsdHc=bm}~+2)`d@)f8LgFv9=XYmbIT}%*kdO@fQl=?uZI|Ptwte@UGYOVjv`fIMI z7Q1S)y&gO4w9#I>?5^Er`)sMf21V*je==EGM7&4T`$N7JveECrmlWJY!cR0j?-dc> zTk*yL-*NH29Y1{X#0#&S@yj)*d~?e=|J?J?GY5Ti(IxM^^wdu$eRb3SX8m>7BadBm z+FP%E_S`YQop;-P*B$lVfrou~*?0f__~LO-p84RBcV2nnnT1ZI3ku8pBHEHyj(AKy!Hjc54afIU>-&jXC z-cgQr%p)G@$VWZ$%wHb)V$HxPNOT2KAb}(Z60*prMS`m^nW*94Ee^~#rl9|Nh zCZ(vFE*guAh=iBE6seF>O3sTc$rv3GvKmsl<&}Nh<1A%a%Rk=IkGZ5JE_bO*Tl$ih zzwG5OeG*GUnnYi24CE;jVoYKUB#RJ<<|nN~r#4|C6xhfFHnN!~YkCu$8tUdZ$(g6Q ziL+{$h~TA$lmrO~BvJ5We`luXIZu1m6QB3gXFmD4Pk;6kp#KzTKm+B@BuK%a2z{qQ z1yUwI9aNr9DiaXS6ix2ps5+67Wt5B;ZjJ zAZP*{y3(26)TTMb=}vXpQ=b0Rr$GhkP#2m|h$0lJFn#DdBj6>8e?jaq|Er*fDg@>* zCQgh~x4Pu4wE8NGb!K5t{pwap)zz>LXjjxUP8Q|4v@imIAe(4ID_YSGx{^T(BNFAj zfPjEav_caBI9V_sf`BA!VG}k1KoZE1*uer5vcx28WQAGT$x4>8m$htWIg44(Zq_+{ zgKI)W8xf-s5&%-De=A*COIN$PwiTpMKoA5Qv!nG@3S@vSDYlRV06Z2s0Xr*HGocQ8 zti!BVeb`mYx~i@wSFBm37<4OURkDt?x}CzY<5P&kB3jj!qzfNDM(NDfS!DLB-b=bln6QNMVb1*n$)?6;u># zQIA$Qd7Xsf{nhpvrk2s!-XyW;eGP&Tp1;oasF0N*(sEyk!WLc`F5d^(nrj zkir%QU1&ldnyyM1U6_Bgi73aG6|AswDa?GN^I;=>I3{w{LLrWR)MFKxkR*=xXC?GT$6$nWsrgb03d-)u!0nL zHPp2Gd}u=#x)s&-HISl=#E9x?GB9 zG>dx`e@?D#VUN1j!XCERM6dh1%5?RWu^vw>MpSF%34CaxxXB1yj+O*cc)1;1C{rnK zBgSk2Ap&-lf)q9}+l7=@3R-AK8`fKbvjidsl|6nc$~NoMaYAS|}76BzV*IkJhN*P<09-$^pn@^!60-}%w^ zfBN*dPJQTGU;5PtZ6eF-d9laa`yp=yf0@r$DYWChqYZpt{himCZwqN#@G|WEh2|)e zJdl{z4M^{Vi5B-`95kUr6y&6^sY5i%06RpIBq^H?P}KU5ZaUVXwt@Iss>1s70mLg5e%@C7eW25WEw53m5YX#z7O6v9CfIjU;H0TaKeLkO{(67T_I z>}wRU1vf?1Xs{9|#MX9X+U#lzrXUHl;IG`l7TEAh94*UQgv@SnOp>4$1kLtv5mN#U z+vLkcu&ot{?)ARV8N(17pHUj4e~}ue(HgA;2}q$GJTCvX)bPpx?%%dx53fYEi1Bbn zVHdeX>Aq}5Hi7KeZAzjGG!U*IMuFkHic?f!9ado<$3YPlBNM`bA6H=<5Kl2gK^;UP z@hT%6Q~|~e1``BQ92#&jI&BrQK^_mTAJ!r9$|@Y_fgVHwyBvcb71A;Ef3YMNQkw*F zA0N^&LO~q{QYEPiG17`?;6ewNabJwE9VEc5(50_{U<^LbAiC@Kek<%+LHCwo(2Q@r zG!7&@CIQ%C83Ew!`s=Zf#;~+N69`RYzU4#s#OB&T3N%6FMnn>1u^{3Lu}rKckRSq@ z@|~_v`{>dxv+pjkPcQ2de=qxTFY_{GeoqoQF6~C7|H?d03fLcRiYwt%({YX^N2Uh?ZGy^H;Hl0#1} z0n%)5jPgf)u}7I>edvd7M(u0V!No-3OiJx(=Aj<^A>k4Z#;|D}P?JjiK@eeZ9sGeE zNRSlO0X3uZNXfJxqVqrkaTFd9)+8hpvXmozjBY}K9{M31e@Jno`foWE?nyC0O1Bi$ zQZY`oG~og;1&wPC14c%;qM}>6*yrEN&!{zCf9ty9{!;fbdHBCwTO;2 zh>F#Rk`;)If0bEr=vk4~S*LYbsa1%qRavjK|5~?ITe~$|zqMJp6!w!g3Vq zp%ywUMB?imMMM+G4hY4yM2HUF`0H}wu&siK%*en+uF_o11k;XmY*;XC0+AI`5Pk}^ zN~v@eEYSa^v|%|C60)=(8nv53p&vrhP1^<>#K9jpe>N0VVM@i}OvY)*3K8LG0Re?9 z6X;=3TaDJf1``aA6`+9uEp|&0wL=0m;RMlc5+fA&p<%xJ>~Ch$5lm=U<~5T_*4W5RzcAA4SYVNRC%!-NC9?} zXo`$W9A1(i{J|av@HA+lAMAl0Mxh|-K^;28e>g#{AX|(hF#+LfK>!IP9X2);W)d=t zlOFzYF-Bo{Hz5;tvx*STAT=W#>Y*M;u!=;19{6??BtbC{t{%og6q+r4?PhsbK^zdy zcNuUVz*H1G&Do}RiLj_P{AIDY!2PD_Al5Un-Va*{qBHfbRng2R1xxcBO9Jd_y@aJ$ ze{``2iL$@U2%3&jrVs$W(68=X<3$gx|H|O5{GtT`A|bX$jOh|!3S!GYH$XwlU<&4n zaSkE^*g@oQ&FnV7!+efJm-s|W6h)huiJw@CotTPA^tKRS6lx*&+S4jQ4)8ooMd_h* zV;EQt0?OzLF=^`Pgpd{RPbA3LIP}XGe>Vnc0MwbBBa*5}73M(|fEOH4vwTkjCb^Uq z#$goJfgg~Qe~b5c=^+!cV0kqm5>Aq66|yrBj$>K#eCy$NugEx&b2RLfIj0C5YQZ1; zfgHjY6Y^IaSRvtlmmab6OZ}mIO}Q2fxqaKAQAy+Sc26NjF6R0}{v4C9cKIgayoz@3fO@aqJRwA zH~`L%6-L1oHUXL0VFGwn2}TQnlVAwBKnc$8ES;{Jp;(}&IEt%ypbeU!54xZinwC&_ z42Y16SF{z_Edcr|5^A9ywm^d8e;8h7yxOmuNBHLkLjL%!WI&UeAy`_h2~G19 zn=*kPx-$FbQ2y>dC|8~9kKt&v=)N)YDD4F zMgeCH@d8oN6fsHyG)q@Cq8EU(sI2jXgxa7DN*9hLD7 z%cox@RHu<3+hk-cxZu1h4PYsiJ+INWs}Z+v zTeowYw|Co~kRV-IA)+4Ie;1liazM8ow4e>?3T{VxMdEM{tAx3~Hnx_e59!3y=pob? za7q`hWgkSQ1u+sB@B_iRt%X_?%(STMArl1fYKE3$Yq~&0A!J3t0>4HSrc@IIa0UMn z0khf^za~p*0TQywzhm%a4^I^jFtGoZZ-Lrw=)oQaF~a@fsoB?0e_^nwi#KBHVZR|E zzx}~vAyE@+jZc1L976+&kcS^8L(6YGoY{9UHRY7I742=75{4=yW`mlYwx3^r&yPV6v+{?qfaFRf( zNg)bC1OT+a7TjUSfBQxY*r9Z_M9DRpwyUjh4BHN;#L>)Aa%@kjo9|D$@hSCTMsHyuuUDH1WH+8!hwrH0x}8{-3` z)6}aI8^|~DE>R)VyD`Fn7D6%;t~b5!vEtNW(QS1U#+yuSf4xJZHxs;<76>>TP*WlA zlocdlo3h}4gL7eQ^1<*7|3Nh(gp2Chi^I*uxr606%NVO3n65pq5`8siaK;~6^RIez0Wp7yYqc-alEBtX*Y!4*7Ju)@~MqCjS{XhA^) zuGq5l?$9#xy<0{rwt0esw|zNCQyyXR)5=#BP;=oreKoUy9{P7C&$k}2x45{pJF!6( zDp>-s14DgL3W+7P z@V#&1OD7$T72pvFq<|xGu0gt4~sD)4K?chDRaP+h%x zA`#LPW$u6>Q*N+}Oi%2tuI=OD-x_(9zii8x3 zUcYeykM(-LD*qTx~wk=Ym zY_UmVc}r>DtRv)>65z5Fdt?H*uAM;krUXdZZUG?NmXzh7Ntaq(SvKw0*u%FbRV#{b zE&#T$vh}-ptxdkkNMU6{s8_bV1NV;XkdW|9;je~3@+$kgAhI_ zVT2S;XkmpIUa0?Jh8%9_VTT}oC}M~tj%Z?dK#b9jJzGeS#1y%4=1ON$B%uW<{*WTU ziIgGHijLW}qJ?j*w75wt_Ot>Vh(MIa9c-fgh#?Y-HD)DR5}6f>CSXT99 z^QBi;WkpL*SJBhcJW%;V&nOIK@)Rp48AXqmHIdm-S2@;W4=E!3#Lgy!rsGv6qts(3 zM`rr7ikbD0LMSNe^pg@cpL|J|yeiMK5&lySKQ0-BjCo&@DyciFbdrTZZeX28h6a`#~z0a^2j2WO!CPlr)(htf0C&BtQ1nP z!dWZMCLuu-Z7E6OXyrni04ZchQrfd*u;NQ8-;O9*v-U10S-JOp*yhx8Qf>8|4z;;T zLNsj@Y(JwERQ1?e8wID>Q|HuAEn%aL&DOV#GR`RCtildZBB_E^*H-7m)+lnz&5A$v zY?4Vfv+zUDI<%PG>PIiJf8&)u;~c*Au9yVt6gw8xlTkLb_yg49P=!tQ*p0KIiAAGm zlF3QzXcGTPo1q-0Nv*Z;9gaYwa9PhL8F}5$I<8-B{F;~`AXyNi3EVtyNsv*GE%w;* zoN~NLmi@N2OphFI%rl?8^V&aezGUcQD*^fd0D(ZW_Oy{i0Lldcf55-3;P-qh?R@Kg z6t_`8{h&j@|M71!2pB~zGyyLuP-}kTJD&vFA-@0w0bX0E2N~|eyz+UFd>{m00R@Jy zB(On93WQHPwx9$jk-`;I2!ID!Rtr*01^|>f&2z#R!t?d;gX0UL4}}QCAsVrWNc^D^ ziMYfgHZh4zjG_~ze+b1XTCs{)e4-YqxWy}WF^gObqZh>p#xa_)jA;C#8p*iEGqy2} zY>cBD|K$kBIoh#~czmNC>9|K9cJO_*@E-N3M}{L+?~v4!UjlfT$MHSwS@ct00*WMs z%~Z=3K)c=|?+C)DSx0}2OaQl*lo>My0SRt9o+{V$HL|oxe-%ZkTsBsjiJa&pY#f;r zEP>KGs?`!GtN`3NL?MpBc`kEUQAIAPwl!Y_b8d6Hn>|JmiCQ*=7VL0bCJ60B%xMg0008QMG66K5{RYoUjigiL!UPCq(s$Gxg@czEgWeR z3MfS@q(FuhDhGX{aiIKgh=3$CH3$|2Ax8N~R!)}nf2e59s9MjuRQgA_{YDHgFC2$|Zo)ieBXomvsQ%XFVB}=ebyl$C@+6)VZa zDqi9`F8}F7B%V62Cd65gaEL;kdI?G-L~*z7io_yzlI=ulyS7Xi${*Ws-Cq(@7*K#> zA2ZPeRuYp5NE`+*wE|3mkx3W*S{NdjFi=%X!=BEdh<|(TC`HsbgsrAwW(@WCg$-Gw zh4!#U0yK+4ZR(Rh6pg_h=z&gKa;jr_fhP&x0?og;pf3&$3}$1fkB%O;g(mZ~$e7tg zJ6ORQGPokgD9f_SwyfnYbNS0&4l|g?Eaozk`OIccGlVGGVRzhSg9*SH&1P1!65(cC zT1aUUYJb6*i4|vOPWn%iM0}yE7Wl%5#*%K4n>mY!ski#gt#*$)rbZtnopCA1NtVk_ z>;}dEB=q13q(Na`=3cdnBX(F&TQjizIILhvTh4K~P#Je2fHD#j4~d1T>+Z zS|p*8Q{}=Gh~zWLNMco?aSoa{uI7$&{No-EImkyY@{*JM4T*#sCDI*`wS5_%p>eVGk3B=4l4UP@A%5mZf9u=~rV#LV}0PvAB)t zRR0sjq9`}qjf7)z3K84jY{80EtYa1GSj9Ym;-{`VJ#D}|d)n7AID^hI8)m0P=g1<8 zkF;%@Kq1TJ#I{XVu%i})Gn_V37d&guM1O2pe;e1ZkPY5RLKA<%$SjE!?bl(Y7KgWp zKi&%z*-b4SO>rI5SO4?|6VPGi1NgJ}#|6elv3(_MD`v0-Nrz?dN4F6TSHW#VtZI}r z{lg3Z0BBL{Jx_TPRHXY(qY1A5%Z67CAxV>fL+Bv(SqX6dQ7NN|+e82fzXuv;Vt?4e zc1(aW7Zi@TObMFVn0YA1@IvS-19vzbyqC&BkPj@N;Vk|WIdaNe`{q@J#$iyBuYBN zf^mcZnZSB)*K|6_biWi(W%Ctkl7E9d7!F%=6qqw7qF`Gn(O^Bu3gcjQXxBLXkP*vt zI_AX=tMG(pcM{W7P1ukH#dLV0GYVKB4o5*vYQiO7kp-}mb#zA$)H4c_parzX6WGvq z$E1T(qg!!;DzY>Q)jq|+Qf;6FKjl0+H848@07!sSnE%02 zN)m$2g9HI!i=y~s;wOPwHA)*bbJapU(-Bpl#5|(7a`mtq?Sp&vzy+ZMf!e5m-1vaq z7=hd9jp7K7-x!YMIF9C+j(_F2j_Ama5eNiQzz*F*1|A3>+4FBE*koF8B-2w_x+fYW zWiMLAeK?~*^FxDfbz=Q986Si`2ULt?Wo%dyX%mTI*9HwEfe-{y60x;kTk=gC$q^7C z4)w(f3;{bTaVIX(4?IzkW|0%O!V*7WU*fP9M}b}Cl@LPF5?DbPCx4+11(6Eb)J=r5 zY@^^6VgeDeGKGh>UoPPfqo4(_W)Yj91?j~K%oY(uF(}pqU4aB)0znuhArH6zrWn}3 z7#(RPk>NCbcz@OzP*#_BX_u|B1!S;+xltN(MhXF-0Pt~`qJdcKVi{FsZ)}N~ z4Re{U*>0=(nym?&tO=R5X_>Sso3^Q&xH+4?xtqBeoWI$dz&V`r!JEe^oW;4E#Mzw5 zX`HdCoX#1Y(0QBB$(+|Yoz_X6*vXyRiJab$}?kq}s! z1(~2KSg8rD@DJp$2^xWba8hZ4M`VAJ#~L@_L*CXuh@ZTJ}7gVlW%FBvRlH6Jwe6mL2#em!r@MWB?lY)HFz-B7drBn|Z1;^U-4YB9{_68BV2R z?NAGz89)I*27}rfQP3-)k(l$Ds&`7C^0}(2ilMF=tFSt&>)EQbnyIq?TC1$OtFNl7 zy^5>1nxMeytGPO?#wx7HYOBN=tjc<<&6=#dTCC9etM|z)Gq)<=(j2l#88Cy2>#=@z z%B=T7Gk@j683IWPW1t+;ik$i3aKrhS)LIz|bvniq4z++7jJOm*NUvm*4aUTXv{P*& zu{8-*HK<1qZO5?!||z6BLp#HS*vO#Iv%e(_oV5 zL@PB);Wvxuvr}8Z8`;NJ9)$qb0svZ24^#DLxj}#HK^!#5KhsehJO8DE17uhT06_le zJ`H7B@JKe!Bkdya@(xQXkI zihtX$s5%xr;lwma9fj15lxebHNt?S~a*?AdmHM3o+!0oa8 zXb)&&3(bWlTO17v(z z1yuvYS1D2tTTmR~l6*ELJ!QasNzfKTg;lw+39h(O^-uiIQ;`&jYbzRO+kXnI zP{R8ISvVva<+lkh#eY`|e!THQw80yuGNm8pkGRG~Y5kIT(a1f|b zQSc*1$ktU8TTO1KA zX`&_9CKu=RC1I%rq1+Os%wB6M2H};6Mv0V#(aH4?Y7>SQnE#RxK(So-+ZLWfg~0+R zrD71aMhnblPZ?1VMUlvi>|282pTf*>K`;# zXrj@Sg&b+9LOdk8%$!^k9Dng&iZNUgktcDYT%Wvjp%A0IrO2ZuC~>lMiDD6{=4!y4 z2|e9iIm#wOTBMmXD^hVc;~-#&@hf?ja z9;e}`?7T4&FwZxX9`+HT_FNfsjnJH}(4O7dp#9mR9onQl+NO;%uGWtc6K*@=8=nM^ zos_TPddSfpm4B$a(|>Ki*-4HuvJ(q21ONFuoBfK3iGuVhBHfkBX}O039z(4Sh_WZ zqY2u8cnM_?(*}9+#1Pq#1r-66lCVwp8a7yf35|1vS^sBGS%3+%ur^faJ6&_!tmk@1 zJVlmtkRa5@8oVSJ3@@yJ9IW+c9==n2mVvQ%93k935#S^wUVj-#uzgFrG*boRHQWxF z*gQQ9E`Ch7b7eo$*ao@{Jr3jwvIvbJq%6^+!8`nY+=q}~+{ISx#Z*4!RxagO9_3nI z%UJd@dpZ z3THEa?&lHE$AA3;=U`-m6mHyO1G4PECXU$L@)Qmv(I$Y$u-QNo`#Qavz=j-4HLw)D zn4q(c&P-fmH5+RW=&fAJEqPE`hJiObcW4yi3p-fcE26UU}(1DLSbVEFq<6!RP6hHA6-{lwoE9P-4u536`vdq;;?GKCA7y( z6l+sA7E#r~-1HE!mItCeGBZ^!A}}!HEWu48d1(vg1|I*iLjR!|-y?IQS{wQCNw%8+ z2+%yGRev7#8mtd9rP0nNBG2^u3-)ji_TUTj3v~REmqo(cGezfZr1B>dDP{KRkM6hb9u zY4k>K6JIB^GJ5m?fejQUPyz`G%*2o0Dsij-#`UYm3R9?g{2Gc{SI^(NOpw<3tCuO$ zL~P(f^`j>Wld69G!eug85Su7}DItX_xqnilHcjFnvI;j(W4%$3MAhpj3MM8|5R0xt zi76Gme(Og0^yhD!rcCs*vRfs}t3Pp?NDeGG5a&QlUG1@gG;&_caic!cE1D^kI&q?W zq7*sNE4xum{Yiw0Hs<2Ak&pmL09kTm%98;^7-g^3mLw7qNJz4!m90&cO~CqDA_{F zmK5Bwy_;SRd3Zfv_HNnxWzU~ITb#PjX8%Q6itFOv?Z3aDAOC*+{rUg*FF*hX6tF-7 z4>T}A1Q%4WK?WamFhU3?l(0eyFMqT!Lku_6utN@0)2*G|B!Fy--C}di74BMN#U6KT znIbj{F>8gp)wF1b6zM#iY`pUBVb7j_g0u&{B3qe2LQJq(vPmYNbTZ0tXz7QQk)Xsf z$}F*IDXXmd83!hcy!>*NE1{~#tS3Q92QW;aN=i#i)Tt>HOq^^g%}lbP@_$RJY*K}v zdX&lrlW-tX3aV8mQl}p;Ytbl7ZK$fn(X8q*Q_d|7MQW5Q$MljOH8J^!Om+fe5-Tq` zx^vSlJ3?T&Ae2iW30je$)m9(?NMa2C5!-2llvdayg|h^t(=`EHmrJ(UT%To@Rv-|N zq!xR~uu<9sHgUzaS--7ST7O<^vxU9Vj*~S3`p6)KU3c4s_uW$5X~Wi8lV!17TQ;*- zTW6yMINX4BMZmSqHrWDUgQbz0YgtT6z{{uMZ5j5mf>v(xKdNcu^IE!#$DAkl(D?TbR=-7BGaE!+oE&O z$phtQq(2d<+NiDaMq?RAV^?-??Pa_;I*&z;VS`Iqkq_Kkaa8f;md%n_}kyw zQ894!FMjOMS38mv0QN!6Y9};d3M;rmy|EC5E>vL)W%$Aw#*l_Lv|$cmxWgLq5Qje0 zVGo7)!y&5hTd|4)05_8x3EXS{KDn`0dlaxf8UjEHwJ1#qm4krcutNfQBVxvO6&i*4 z3WM3%O~TZexPPe$#EofrV;l>Dley4BByyDFLFUM$iHKqqqhQ52SRo3t+=(xL45X)) z5{h+*VKHfKNTuYhAz~h4Tb1KB^ptRPSm0m#pp#fno*8!)T830 zPCfR8Eq{MPCIJa+MLYbl#0_nt7V!j+Xat&3Qq*D#TPVfdE|fP-;PIM2ZPYqcvKF5f zbs(bHWg>eS4rvO~k|XhpN)lrbzPz#_zqPdDhp4A|s0AyJAYScG6#z3K| zC&Cg6QQFudM>j=h`9!n@8Qlk-+`&x}*mY0R=F^*GNI+=M`8A2^rl6;R&V5?**^bus zqqntfZgIO?-S(EZzx8eZaDh8q;T9J>QW%9TL|UxLQbr2utcOiY3L%-wMq)7XC=i5& zt$$WD%tK2m2zb{tUhx`g6@_wRdDFYz@L~yiPHHdo+6!Oux@x}ly>FD3M>+m_uSsor zl3G`>jsXvN9;~zPlir&p0Bb3}{LQZ$VI{tv;i@wOOlT0pl{d0&0ffIwDF(loH2=Bm z!nFboir1Fn6PFBEkeSuRd`P}BNFi)5?tfdc!jrPOl~~4d17eViDC96|*vLda@{pBW zWF$9v$w_ANlbsx8Dnr@IQ@(P|<`x>WS=L>g7;TqH(ZnK?p+-Ay762r%ov>xOZ}F)| zjUW~)93mhS?ZtC_o&U5KJJHv{Tgq>P-3w?!KQ7OIe%yln>ncSLx=Gq%9x$ivw)Tj0KcfgGbaK;Ke z-~_iO!vDQ*b?~DlRC)M10BrD**n8mD1joYrZIDdtcrqSmPHI4Y9Mp`QRwe%rvsMev z=8_ydB8RztOm1@j>&E=$D93rYTrQai%=|$<&-uq~F7%*>yy!bO`puP2^nax%o#{_^ zdeo6V^^-T9>Oi-8)uDd%p-WxsUdQ^@zs~iekKOEH2Yc7ao^`aZUF>Z?yV~PkcDUDF z?slj9-SwV#zO#Mre((F+|1S8zr@ioQPdX5yK%aeCw(%mePPXvg^w6~Y(TUTd4J};RUgNUQ@kLgUp*XKzj4_UeD<<0IAd?0d$j%uCBI*X zUo~EPj1S-93FkLj96R>~g?{bK!u)DG)6T=2-9K=@{@~Mntk4=B@ZaZr_`NTF#FwA> z6(lrh`<8GzSP(zaH78YLybifK}{RM5+uPBG{F=^K?M^D zkuau8TR|06L0C~60ty#5qcCuZ8+4l?_URuR93N)mn_+vHqA4*4lD2A7LTt0PBxJ%R zbiycX!YPD8Dty8#lz+l1)WVKoA0(Kk4eBoIn1qb!LR!%x9ef}DahVQM!!ks}emTQ& zBQYLqpcM)L#*x7o)WbaF!#(uFMO&p76vRIy#1RA^Iz*ZyWHv>Fv1fV18+?`l3YZe} zK{K?(XnDl{Oq4@7bi_?mLrnCw#bYGKU{uCpM8;lxMqzZuWlY9sY({3Z z#%!cUXxv6=#KvyC#%csdZJfq({6=#WM|9*yb@aw`JjZWzpePsvVNo@Dk~cY2L~`sH zBIHC1dq;O$&_Ttm2}CL{12WgIE91An#9PPOh}8|Bb(gGq6#KT zqDV2p$^W55J~vv6<$KAbgh{1j%B6J5ri99;l**{2%73Y}%BnPwjl@cpRIfu@FF#{1 z{1VHq9LumI%d<4gwDiigEK9aTw6p|ExBSYsq)W8SvyMPSEcBUN%e5>NnZC3_zZA^C zB+M-wOv5Zpy+q8#RLsUaOvg;j$CS*-d?CPO%z?2pxs1zAlgrMmOSSw=&D68d98J3{ z%gr>+)PDp`v;<2}gUrjMOxTpo*QCwKoK4%T&DhM%+tf|m^v&J`PTmyGDdWcqIz7fX zPQ_5p0D0ftWN5D&g)dp;mc0$w9f2g&W`iC>@zN{)Ibot%JkgI z^;FOQ_GHiZbkF#N&-wf<^Q6y}M1s)+&*VG4{eQ&H?*!2Q6j1H-PXZ0l?TozwRnP-X zPJuv5`GnB>jL->{&7Es}(iFwgE7j60HNiuaQaUWo!6d>j{Zb+|(;-FEBRx|!P17W8(=~NdHkH#j zWilrPG8dzuCHqoEEH=KOQ4=aio+Qda71Tl{)QGIeLp4;M%q9+X)DC^rNQKl%mDEb5 z)G|x8)J&yRnWQ+K^hrep)kO`}q4ZQyH9b`aeGE-iRZL~oRdv-?h1FM;)mWv~Sxt?A z001HR1O*BJ0{|=l00;nk3orx#2>$>F2pmYTpuvL(3j#@(;RysY10gQVmyu)xD+BV} zDVNbn1RVummq`8!m*HdsA_Gg34VN*V2NMK)q`KOd(SrsPf2Af#&b+zWR?ef#L?AFQ z*U<{4Yv0bjyLaN_1SbFZPQJYP^XSv7U(de1`}goaE*FYjF@Ui3!EFt)~Ge~dKt23?If=E&i2GwSFgO{9P(3Skin_={>PG8Lp`J#N&bM9T$HP?8Vj zXCa6%Nl8(D0c=F&k6|6$ah6D=32swLe=Ye`%fq$x{X5+vzBR}!1it`}*sk)CWEWXnJa zG#SMJm?m(mW~}0>?Eu%VgzUCXISN~_Gfj}ym}~+eAf`BxVbB&Lb!$+QQlhBk7*JC5 zD^Bo&f4f?#*>V;Mw>HV^@4=9L`El5*yG$uJwPXu&LP0HNWe|%GG9Nrh9?g7OMefbKt;-?6u423e`L4fVP1Y0-F#{O z(9#_7{Omw!V}4WV1N6)DN^4&x^Z=71w0c6QOH?~VA6InaKql9m5UsDzHoMBEhZMI% zf0fR(P!gUOAJacym7$S|7-kkZQ26TtN)8SDW;iXaLl zK&T)>a#tJaAp~>=453OYhMwJBDqnJsT|HLQ5;rLL5Xo1*GR9F%w8tWw{WYPdoNE#3h1c4@DAVE+UJCF>ke?XDM z1}d81^Uh#FLRB!1(;VoCIv6E(VJa4g|j6#rrie=AO; zerYL`+2F{?MmjPg-pfj-f-)+KM3O0!v@*U$ z;gPQ_U7{l=UE@c*d@?7ZjLaD!QpJW(zBzU=mnKB&$$cJPyA&4Z1`G9gCf0kMfI2zrI zN<-pO_E?lJ3nfodm%_|u4nzSZRSQ64lCzV3Ws%ccSWtzM(^>MfMz0i0QAILVf_!x? z@k~fPX(y1V4q%)u2`fSV#K-?Nbt*U7Nlb&{R>d(sperp$!Th zD+?7&|FH?H1Boq0R2x{HZPt^5k?T`j`$VF6G$*GFj#Vp?*{%R!v;?W7N_w?QkKlFy z=K5Vse43gkMeA%+B8;pMOPI)#w=ArP3uBpb+_OP9I~GB$G0#h1e}@dUz79?6Z+%kJ zo2B<=g}LuV;Dt@Km>0jhxhEWX8(y$5*d*_LO1A1VwFlQkMf3s-MOA_#?hfr{8%A&% zk7MEY6sp8JLUH;W5@M~@sv{n5h>C}@9271Wz7dhqjYYEK)yxf3JXzBA_4t#L9Jv@L9^41ah_b!d_u5HxFX4{0KfIU_;N8~Ji) z6#N$~*ZI=K@v~aQ9C7_xDbFqmm0tajJ5gK6rhbc+QFv|S zt`ux9w#cq!9o`mNm{u4*;Q*p75M{^gxWp$;@yUcL;uX($BlE*V_dfqlumk>i zU4Gx@@4QL2ufy*q@` zk$Wq{e&|PkAaPC=X=zhkdfcDpc7&w6z2!arJ5)NpB9Qc3%=zq0W zfBAQQD42d#2pEIXSA#ZqgE*LjI=F*8*n>X!gFqOBLO6s(ScFD+gh+@Kua|^O*o02_ zgisiTQaFWFScO)2g@QqSSh$5;*o9vBg|EYL_sz6CHnk7=>f9 zgnw9f#|9(CB8Y>Sh>Ey~jM#{d|M-ZIm^fM(iFB85EmDY-n2DOWiJaJpp7@EN7>c53 zcP2NArg(~|n2LE}3;m@Uho_3L_=)o*5J7f|YE+7M@j8gO61a$HBr%E3b5lH}T}8Kx zo3|Oj_#Vla8@L#7mWPQo0gZnxk!sS$gsaw!Wr2-ap^al^5UXa44>x8A2pZqyG-{%I{WkzPGmkT975vCbR6~sLHWy^kVbRzScfuzx zrH>zFA(nV7Bk?+cK`W;Rk!VC0lS3sPf;!}gkGVJi2C^K|7+@K}Qh|T*Mk6T`_a!B8 zraT5%f=+=#3z3o*nHD;BkGuc&kULjEC%F^~c^y6DkuCv_X61=Q8H>KLUJpTMNZFJ& z$dudxX(^}@Qppxai4-(xKq!+u{}@8j#eyp#Xoc~Wy#$X4tlW5!W&r*#(w9nKJScm!YAQT+xwvv0h6v zm>sbY^=NclH)U7pZLC3D!r_=L0hOc?Wz{DX|Ai2sIT54w5J4iEAu*RUXG6E}8?z}G z{tz4rqLj)aW+UgCRRS7+*$}*0Up%x)9&sOHsS;m8jSay-zBqpp0Dg5k!Aw+7jS70DCGDH2M@;>J))m zfD>_=y}_eK#z6Le|Js7H~f6TzrJv7X-9 zf&|BTp_h(0ckCOn;OZ4uE)0tb~;*cax2(@1X^aa3{z5D3uKHZ zqC9^=B>;f2%>k}j5wbGzu~{LsB;l|>At(DqE;I`(bzzs95pjoGUxtbjBzt6hI}%|l zag}Qkg4-En3lWdY5r9i|HlqZvYDOf1G_f^16009KdliXrAr_&mdSOE$wxsr@o$yg6 zPjb1Kwy`dOVhz*>$MRfUCQg0WJViq`x5LzyAe5W zAECPr6`cc^6yN^>y9t3M4V6RDn-bN#bblhgkb%8Ip|k@*yixJ8C(*e`!M&Ser^uTx zSVM<_>bFJd5P0hr_?r@+Yrvt=H>nH26mh_wmYrx(wU5ERPByOch7kiZ08?QQ6x4IdjC{;TO_8XMj z{db_l=@t=PU(Z(zBpcTWxT_L}N1`~?QNq9UI_r{+>xfI2O z#cF)Wh>RFp|D4E->=aWB5%ilJs+GZkqQkX8!i}6}doj8Qaa{^mh?smZmRyI63=)m& z6t(G4QJl%q$Es*i%6m7;ZPAVoY-kYy%X)WevaDy$v9MN#%WHHlBZZeRwjzI(7LGyj zP9q7-7f3QQ>=VWoug6${I32>sJSJ=8>9)M4n+MtyWkoYYL+)T|NIP94>f zT+G~>5~nQHSe?~cJ)&FP)m|OD<%}JF{MB=Z)MS0uXr0z-z1D2q)^30O)^Huyn4!{g zUDtMf*La=RdcD_t-PeB2$oE0le?8cQUD$}X5TOz@hP~K~-Pn%(*pMCBlD%M)UD=j> z*_fT#n!VYa-PxY~*|j0qpgr28UD~F7+Nhn{s=eB*-P*4G+OQqlvOU|hUE8*O+Y8q= zW{ulGCwWccdov;0yuE+ZdOY05ecZ@RabpdvP*L0=i9Wq90>N8W8|p3#1wm#wDJ1EB?J4k>M#C6K2Ve*<#}c zLE<-_bQklA1cHq`UJxeqG=gLjG&3FrygH=_D$sJlK3?QqH=Pcllh-NDZy~NrJ``U} zxoFYfWai*nelUO4$`cVC5-}W6VnHrQyyh?c<}E?stl1d(LEkZP;S7Q1B#6MCW#?2U z;1BQFu&|MJ1Iwo*6IqZh& z+rDsko+~JYP(C4N(M}bLy%A^ps4lSzqW%!3o)Xyp?lf0ai=GnN4ilak(i!LxM$F`t zHIxVu=LCPhbDah51CTcYaaZ$n?hN7TB9Y@5A@VYjWa-$ZCk_!gc%%sdEo?gY~g`U#|&~05@+GJMriTvD87!5PKRwB<~Rw)AJo6 zN+`0zgAa7*Q}!l7P17DOd_nOD;rN{&g;7Xajv{09Gl zA+?w`@lFuCKM=D|Di87V8R72o#LE=%pj(yoPY)C2kNy7UnUV(Hu#Xm%g#R}&{{S&S z;6Q=}2i7at&fr3Z4IMr#C_!K;h!CYz#ApW(<3^4GwRD#;tso+QWg;1nMBo;ILT6qC zNy=wU01N-|Y&8?1K#fa9B6x^(=RKMU2m%=p#p+nHWzC*Nn^x^wwr$p z9d7{mKw6Hi_SQmgvH{IX&_4$sgfK!0C#0~#l339#HVZd@JKC#4f^t?Fr6f`Of%0!v#kgOk}*Wdn#`$9ncz&bPCM_!GfzFQ zQ>rI9&!SU*zCa`6sg3&hlI;J_Jr`xPQAc%Cw4z8wnsgwcpo9sm15X4>Q%jdB4^#FA zXfxDL;j*$Lpm-GZ%Ai`Rl%YQZjP+4lZ^boNT}i5yAfZylsaM>vEQwfyOvQ-Rwv64i zS!bVxHd<+?rM6mYuf;Z7ZF6mI0Hn(L_OEWYCAVCEbBjz)*x=MXmqljVg*RS#rDb>C z`|3rhIDGBpw_i&4Y;({S1e-B1DVTBiFSS^AZCKO_cCr-;?7cHg_z$?W)CV23qRvYA)rxnC%EUJ==XfL2j8p#bbT zX{DEcX1ZyopN2YWsn1eK|DmaqthqCy@64LfkF5$qD>kntyKJ*5+xl#^*JfL-OL6_W zX^IBu62G|H#yfAlpYEEWzV`+^aKZiiS~L>xuJG@{%^G7{u?G^lDA^ugyz|83|nw9fAc z{EFB~>F-e*N>2Gk1FdyHGi*YB=#l#0NaG}R78+5mA@duFSBU351}e;e4y%shQU#WO zs=%aA4UC^}qJj|sQV0JCNw^>eMc5+4HRORa$=+ou*bwfGaB1Z$2pLwgK^N9gIvN_l z+*a~6TM&eXe}SP5|Kl9@u@5Eg;|aS`7^yJ%2#K#Ui2dgB#E1M)ASo1L7I#8LgJ5U? z3ye;HLdY8#jY}j*?9Oj z)pE%kTB{YAWS|>IIZMod@rD5iWrs>gCRtwQku2dPLrQ6oQ-U%ltRu_oC^F1{0l2~) zF=?Tf2B1oaaG@f$|1?Q2X}Qh+b@3+FbjU3PL^q$5Czxj(KzX8q%93zkE{EJ9LGFUh z0mv{o)U1$g$nr^yuyY~)1c@KVxf^_P)1Xkh8K!=Dl5-w3PAAEv=R%@G%HdF&OQB~) zLKhNPl0=^>*&|t!z|5OWz@aC9)r=W`GsXdQbT`FeY0hW^P@K5bAq!PZjnoGd51nwuDlmHY}Plu_%hfhEskOG)$_@B?mDw$&(0lA?4|tEr!aGFJ-hQ;Do?Syn2)4 zjD(*>MJsN`D%FrMbs%nyPDbyTB5e-jtT!oSS>^Imj_{Q@X$5SvY-$;Q4B_&sEF_6Q z*P5i6MKxZGfs+649`sT5ltiUFp=w90I+H5$4Q*~^;#b8Wuput3T}Em$#@Xm_BwHN_Zz(&Oqzc5IjM@TkEpt?u zz;!dv)E?)^gA?hpB%~{UaRWqfE0FH$R%)~PZdoE_T$zlut_6XkP8=(ezhMcs3R|o} z(pwSgnxwyPInaAGXmrgmF`@*SlBsHa{;WKkr`ygzM(Qs zq8uDAmn5kRL~dnevRRw#XCxAZNSq0>+?Hb5&sEHknPoK>AXh}TG)B>fu|-lP2k9hC zo)DrDi=syVSzD5Sjb}xkyb?&~C(SFVG)EptVlh|b%w_uYZxr2Uf*@DYmWK7NF58k+ zGbE{@j`d3DOuv`l`60Gws83!!nbWGS)YfS(u5(>BVjbweG>My$F|yNvq6B~s&Iwjb z?2ra)VZ1iMHj8R45I)yO*$aUmL(1UdDx<&BnqH47xO@m|QL$D$*Aw*)BU zxXH5@S7`sX2s|eh?nt=XvTKtlTrmb5_q=K8+RFsD$Tt}_bSpye2L*fMxzP&4B|dVt z==mga)Xb9&E)ZW}Hzph(3H-jDS`5ORBsOPBryJt&73w^>D`CvLF%q70lYHJZ$9PNx zP?2MUMpQ35=t{`%;0?QhqsW=oul{|>xS|4i8zEJFTw!v z#kYuuS^{Y7<;iE0dd=&4!tSW|40=YNYWHsSFu(eLa|xaGWeGd=NYZ)zHZAn8|K05e zSo@(v{~q(Y$9?b(+BJg#-#txtkL!a&`j#kvVKb{B*@fRQuBoUo=MVLBz_aMW`vpG@&+7JjQ6$ z43E2wX4{BS%#37w5nY4{px_wAn3VXE#&PV7GQ5wzLqj;BJGpp2a&*UB!LakIiKRQh zSwp$GPz@ZSzjx%vxOlod!AAMmjz|%I#X9knhM+jlkgkX1D3$JhmeKX0K%oZ}(N4|NH zd!)&oR2y4BzL&74fb58PW66^c5|=b42^)x_TtXsY%5O4@r{u}0oEk4n44w3U3{<2E zt2-crh`*P>v@!b{g{zx(q{_1tj_2|qG(pO<_(F(as&$MP+Cs~^9389-39jr6bxBB> zsK}H+vXh|48EPip7%Zu5h`g}N#k3P`e2LDJ$cEz#tI;#G1BnG7lb>7?#^g+y0gl03 zl<%@clQ759D+sSx$U2jbXDkSRc|;3$>de@z%F5u2z$_9KbP~p(1&|mi+4RkYDGmR0 zR0++j7Pb71j%mw~up;67&E_N&$Ha+SKr*?kv9Y<1$ixVTB#56v$$_BDfgr8N3m~ho zP2a$`;t)#bR8Rk4&4Tz#n_$imgo$Qsi?DQv1!)rW_)3UK&FFZ~jc_=B9%-Tbg3k3c zP+h^zn6ONP*n;G=i^USZg4oZRKn=vy4GZmw4LHyZbrggv(33Du`FP8OQ;YZnjwmpr zQj?MYe256;&=>U&zk;-zVV|2!39}?Gmk`f|kR^SKl;B9EpFZ9k;?!|AzPy zR_dm3DUK;Al_jOq;8@U-7*Vz;&yE1imypN)+=)CDrjrPl=OR<;a8Xlo2rGRE)>NN0 zt5u(3v>rS1hR>I&2t=x%Xwa&LY*LvL$ zRdon>O`0xn2})gmi4Uy_C|C(hwbz2hj~?>K_<&NGa7~k7)Ma@KNWBfMGuZ!%)eVUy z2ynFueT|TZT?ma8laz2%lW5q?uu+Re**%G~bp?()B?y<>Q->&@lU)jyg^9ru)16Qx zg-B3S!q=1q+6p`gtHU>nrLXjB*N{C4Z{^irZCZjjkw_7L71RiuEe+bMT?<_-SDHv# zUfbB6z>>z^?Hb| zt+p&t3pEUvx^lg=O~T7STy)ag#;u9P?FqoOD?f}28hHuH5E&?elT~?%Q}jVa<5(>9 zAA@1-QD%w|KA1P;1yo0 zP}1Qw-s4>g*)5W{Mc(Fh-sgqh=#}2-rQYhb-s{EQ?2TUL)!yzkFBbLQ@YTTL{nTh7 zU-3m>+=-s_W#9I7-}i;z_?6%N3*AMW-}}wm`b^zdJX~+U_WQw*QDYdr_bxioVi>)P z-U))}Nt6(y8@(GXg6O^X2%<%cUZeM3B0BT>KhJ*m-lywyt&{t+?(gqPk%HfES?I?s zK4Pvu9Mm^%R+psg`TR=``IGHLllQGrm#sw_jsd*$lnIv-=hmoM z*W0K9IP@YW(yTJHemT&9oZ&p{>>4$f)vX|{FE#}cw6c~Q+t!k;OV3^}^OA0qY135& zaK0=awN&b{%xx z)^ocS8uu8_fO!CtqK)W!%1O^`hn7UyxkQ{RZ99|hIKAAC7{I3z-g4J}>!AEL5V0EW z9r`hKXT(7@Kc1gk1CA|{{WLa1dc0<%gG1_1YVU}>6DYDcy{6QImcoP9#&^^| z=h|itzHcA!UzF|O1SA|EbkY7yXFX)p-}UG)hWB98PB6tGE*t`tR-pKz548c8cB~O~@E}A#!?0dxH7w zWZcuOyYd87bQ&~pg42JBH+za-b$SEM!QMH&n;9isajRr?AK^_8+dUzEdq(PYMizcX z{^L*@;zZGOM%8~ty>swy+WjuURvH=S**p7}=i&a^Ie_dYxxr1;zTG za}=F(HhGkV9dgZ{^X#1SUY);q_Lq;Xg^mo#r%3W1^P}_Vxgiqz55kL_Q6zZ|_k4G7 z>H6%-jqK{FY@5&ECC>Wyy}{K-P2yK}Ub_B!US2D>$g59HSH7roSl89(ohzq^EB|NL z0d%psppYoE8B+X(L2|%24Esw^e=+m zaZSrktT!tSn@MoHaD$svFQ0>JILEyI&DAte{Ej2$s}JMMX4cBaguPq#v)ddx!Cbc6 zJf!5eD)U{3h%P0se6QQ0@Y`bKr#I@s*w1Wwj9<7hq0XJB$VuV5uB~; z-(tjX+4BiF1~yZ5wfxmXM~EC>537a+cGn5OGyB zEPUdsY}xkAul#;{%T?WlBf(wMyUxBZ>qOqhT{j|r$6aqMz?w{@?yjkoQn^^UjwY|w-TFg5D?qQi53kty}&_U^!tcSv)j-0L-Oy41WNK^TopdE zGid^sx}1OaL+WbD_vH1z736b&f{#Eej23ES)f;pp^VkIW`0_<)%zJ`<@fGpR2s_SE0EPG&y}70Z3$6pU(}DwGy3eO|@Jld}tCtw$NYiW}C9}ZT zLrPA?9vx04J#J$TZU;S{M?avC9)BhWf0>?OJIARwhX9$84OE7EP(N>Epzd3v|DH=d zW2HNuOA|JaeZ2H*li^ED5Z%K35s;CAG>?JS+M5{-{cvqrn?Qaww=ZH4j{rcgF|-L~ z9Q?;1d#b+ZjK1KWx9G4A)m@hn%bT=?D5&MllrS=eQLsyoLaQ0aKSpym=J8x6FC&2H zS)&Uik*t!mOqFm`h#Qb&;tL0-84o2)VA(n=cYPf5+thb_#po`s%5@m^#;o4PJkQ3| z%Wq?N0T3}J!N&Q0p)X>w-!ieIJ9EAvt-j4#teoFLi)h zZf{1;@?<(Upoed!9lpsZ$w1lj7gIK7c<|>*R0?I76=dK{6An#JMCas&a2q#-g3W>n2I%En*VjTseXDSGK`7J6-#9I$0Ho*r(hJu++6j`1nL zDB{1UW7#HG>-7K_UQi$@kq>vS`$?QncnD_T}- z)XKYhh9SaK` zD_P&MPQPv3hJRsoDm75}yy3LW>atzva?I*_SqOJM{?kWoE7x8){BOgRmd%5;$U~aa zoww+{^xHZu)(;wNAN6-m*785PAl(T%Rb#(!=H-2`EAsBseji)plg;+2FvP8p&DVeY zQzM(7`GjX5+xySkzS&EkC)m7F-j4Ix2! zV#QSl#a#}?(>>UAi2}e;0puvZ49qp}$u)-M>G$Lbz`nWkeB44=czzVii3e5NA z>`Ud`9HDBZssWCw2u{`beU-FQl^jPq@6vCSrDY4iGJocBo&B=deP|D7q*}@T9-MP> zZcpyIR1R?9wj5GDn4dz#Rl5#^ki$s1-$8{AWW?UdDjmq!bjiX3>vhUx49n!)-f3Bu zsoJ113x!2x=(OZT>KlhN>W4OGFSOW|7_K|Y4xkRw#tt-#IaTJ$jDpbbAqQI5WnO^8 zG|O%%VYwoCxm~;iH*dLsfD@D*UM@dX^0OA$L-4Lg<4{SL8=6(p7rUqYp16B~T@D%Yr0t#n{t);xmZc_Or<_{TUrB5kJ^0?V+5XvmBxV?(XCaORIRX zS7r@6D$6-i9I-8oj$OpUTukko3WBvR_ojGG{VDyToQ>HXCM$1RVHICte1GJj%%s7` zu2ml7wRxlru~ogOFp029YlbaJ1J_jKw--=7+vu2EsXf07VY`C{J2UcL=Aot~Jh2p& zD$U2%e!F@DmF6jx>3=HV_WZ}a1R?9Bq2kSbdmXU;VO##u5VpqhU!FY>?w)gFC+Qeo zoqU(~z!$Pwu5y0I=6Ecz*e-7jU+B&KRG$k+EaQD*J{7Ea;i92?G=I#Y z`_jhD@uy@e6s&AokS| zz&F^d=&&P#fMZg~K?Vp|_d@x?Rg<3gSS5s0|IIVO5D76i6R4}EO^_9zn_@Dbn(oGw z4gu%|aFek&OIq)FI81^Y^Dq)i!dqUt4+SGc0UsK5aVF(44S>KV;9@r)6y$Ph{gYS_ z#*7z&C)Wea#6*Wb?;NMEG`K%ih;aKHfT8`t-OR5@GmT&MUXNIm_<9{k?b?g^BJ66r z8j5hey~PCJVp3*SL-A{k{+#LJoxRSgWb~5fA%qeAfHAYd=&9t%xnNjyQ_M*)4r z(qpBmGbp7*>E zoqACi%CIO8R)lfCkrlH6l6>poLxpl`_F}z&4Z8*65>GNuxac(hlrgQdAgop##?Tcy zQHg)AuxZVw8Y22SWGR zWF-aE+TrhAgzBuS-3%dhY6E|v^a2v|?~5a`7Gr}I@87GyYAj_wDANZ#+D;L??g8jr zSmmC-A-|9-4l=f9;Ru&kvq3{e`l9zB*9YBS#ok(jeC^=MsGYu&Ghxt5*OuYFGv_*43CngrSEY zV4LfzR!>)UMU*byD~|<15vcbtNnv^NMisw*YD&19`5YF+FxAIXHLelGc{Eu=-nSPY z;84GF<*{?6VvIMzA$YDeusKgPdq-ui9tWxr^xu4HN|kE-w&;LzGmKEXq#kLB5c}Z1 z4@Foe@23Q08=-$>q4%KgKA zdz#wUimPI#S-gR$K1VrJh%S0y+$6Enosuo$D7>vV39gioNo-pupRCyv?;Dl`M~23% zq}cp@=55mZIWDE&xF;>DCw*mtdE&1UFe1H$J4EysV8%M5E_qSorJAYFwEWAcyPh38 zAI22}WE5r&5<6RSD#uGc7q|CoRJuNEYqHFzqHM#js#agZLpWU2vMeyCKHSQ9x5@+y7?ukm99Ht)*>n@v0;YC}6++BmFZ=-=kqgPv_uS65PtvOtxh3#%B zI<)w`(hAX@o>Qx<9+r4}hQ7@I4{ANthX9WMFVw>MPin~kFu=4z|4A(ojF$iq@@o)? z9K%?P0okPtWMVDR4216QFAX(St^5eUc%!a!TDkfYLCj@7++4HKAD>IEnTU`%>%G$tqP63R|?2>&hMjDG_6DB~wVf0RBrtgicLTa*2l9mF8t)^_AfsLYa z*xXnco`bShA_u{ytrdpZymk!j&6HUhg;WG2lRdp9F;i~vWq}+(X)IRGv$UrG4kh#` ziO@D7X9cm^zK) zoEcruY*|)_;^gwqo3e@1&mZ$VC^uA&pJZ;w;wm@XP@COT=wD7|*C{qO*p>N*N5};m z|5VSNfDw307KERYI9aKK_Up(2?(D6wUrk`fWmx;3NO-<%x3G;~&8Oj0LL)^@*|^!C z)=6!i=j!jYDOi`*wOVAJ(ydV#*iK{lV_j4QRW|cfC%w5W9Q+YNBp`PssyEyM$!nJR z8v(y#qjd&Y0?WvUAT?`=O^PvWQxLqEGLKG)7oO}kxL-+yEW|}aT=l4J;Tay#H>beA zn&o22T=LqlZCGXHGCv3@5WscoVRTpAjaohO`J*~UY?xIY(1P!J=q@j|w`L6^>-nuf z@_`dGlsu^0577NWh|KC6GMDQYYT9^2>&^2-rh_8z2Ki0;2Ylki5BLit+hMOR-tRy{ zbwu=3xrR(7%z|l&)-X7_;M#jc+63MT6>1gfTbws}x5tt(8u7~*?W9+F&Efqmj>ft+ zH89X3tFkJD?CZD`(VqU1RmtCFk=TjMSYkpPnv+Dm>c2g$$%h_L?Mo`#5osuaDqg#) zn88jhthm&IaAo06q}_CmTnQl8pLw?@?}_jJ#|v+c5RNZ{MxCe+m!lU8IQVGu=7bBo zGfU@HC3jV@A&14?tzviLhm(L`R`j}~vt)#gTOpX*^zs5JUA4LuOb}fhT-(~8M7Py1 z1n$S%Q5_RMr^cK;i38`s{N1q$ofSbM>1tqX>)ajed?Q?n09@SjmUQlM`XO9wq>DE= zd~=zQJ|L&-g?HsCV|7TP)xe`5-V!Yy;mjD4jNRQ7l)o}r0fVvpw?ja!7%&= z-aMbf!0;f1R5+t!HpUL!ZvTLv~ za;p>K%Sr|MX$uz!fhgvii6q3CQ}7}-**S=YL@ARP%Qlta;*aRVZ~6+N@c3=C688u- z%n8}3bO*WWK;OeZGowoi{)#ByQ=7JurPmURW-QSiEw`Cu4C+D1TF(5`uPsneze5ZI z&rz^PZD;zHo=HUkqhu5Pkc^`btkAu6)jT0x#^(hHDJc?S`$S-JoE5b~WJz78!xT%> ze1KvzgbN7A>B%FYuw}CF_^D?#!kwh{3)W?m7lHR$7=($h@V@qS>Si(*cuAe#CsAN@ zq*f9uW6 zmdbP1CB5za$n>e5l#yz&KoH+D_g>f0;#uuXM#wmW=%&7<;^fTey`DH`EtkBt3P(K3 zX8J|d$e=T(G4fla5Oc_WxYdfCGsEAhvGaWFBp}xxS55fZgrzGQmxhRwg^@LW#A;!nV(Xia;_n`S{PHb(`ijOvmRLf^|qa>IXw zjPt&v>0Z$6@m*uCW{$NE(h`bI)8&W(*|P*;Q%mS}s>J0HELOPGzQXs=fO#Q+vj z9D8rXNx;wLR_obPu_zkpr!syT+XQdlQ9t>#c(;7rCDQT3W8w3V{K_}~@%HaM;%@!F z8aKqp#p|%#QrkZng@LE5FbV8M+3afrj>c-sk5b`&H=)XZ%9>ZG(rn6lH|tP0Wd&>k z9%1%}>q-Zuk~Go~dW!pDN^A>_a)C*bs;~{ImKGJt*+x@&os~`n#SVf8{-998p!if`phDt;MI#P6KeiFo zE>>5Y+_A6c`@7G_cC|zjdU-6`6%MD2lWiXnF1Z~=VTa*?ehwgF_jKwwS0h~FEpSz; zmnQRz5Q5ip6UXPO342FY$8pk|FY`WDmUnzIiTi%|R?{~g^!>8=_xEG3-*+_s91q>__?^Sw-Ud>9tPxNUBX_@OR~|5EBn6ZxS=`yu3o z>5vbh3AGW_#C%UpUTR_Ml*qI60*@eaoisIf>G&IZA69$YJ(65(;+oVCnG|42J6AZ% zoINqrX=j^|D;yH(^@X7(D?E@e$rJL&0n(15a2tbU6LyK?SY|gjg1LN1t){aj!K_@^ zU%}dG?!;gEr#yu*!Taf9GzpRy^F(bzSy(jyHhwGB8jkB5FL`ZxK5WwGm`qJV9Jj=D zo(g0ltaN09ggIeB&xO;d?X&R5bKz7^`?9~n1E6_U2~w?sgxx7%3QiJlPcUaBL|gD% z+p|28Gu#(gkX8=Jt5RyvAcZcgWf3%|7ESKL%2*$q=f%Bo-C1PF0K9M zNNDjvTDZFmcL(gD5KAHo0TBy{6LU4`3R%f92{Tj~)D)^FFoHVaj`ui_P=VT!J(B)g za0`(bTAL%^c5EU`yL?JklrQ*ShMgQIgqjNSM;siH`c~Kdi-c`8`#8??R&q4U zGMR)rt{Ucj6%r7GjfM$?X8sF@RlS6JokugVL8ca86#0S!C%QnRN^S;AuIWy2ss-(c2R*@`-bHoCxXa?mn3`8YR-ug*cc!GH`rc0d?) z;5X?D);Ip1*Op9^+hm(hR)~`E#!rxZT;*9lNb&7z`w-YpVO8;*p`bleq-;n&y1V0< zDq$>D07`{)Co?$2{d+KyjX+zvwM4Ws7_9j8;jVSaK-HUkNlXtY=zaIC%enwA{0szpFTR$T?;l>n|7avuHNIMw+!7kKcAr z)FCqhgtYwhQX+{G{dLW4Q z%(!j3|ItUiHJ5(-Ce<3nZ);z}?IlUZua`L4pZ%OamMNMS#}E8Rz_95lU*nTNlMrhR z?wFeh-IaLvB9QiK`UgE)i9;(XixZ*Yr(bf=_LA^wtiUpNNdQ{)kNU#|LjHcHgnmI7 zF$J8J>9LgM&v`FNLQ#MB@Yt_(!rTh`wth@^?)U@3o
bwmDEjST!PKPZL;82W<7 z`z{kk)))zQO$kl%%ZJXZ@9ait+>JXNA?1H)*(Q4ZYwF)v4SB4MFs~IFwv$y#lIr=9 z{NQ%Ix(AP^3tvc5h5UI{-`j1<@H+_FW?ci>J0= z%Hd3t@w%^_f81BQt{jy+`a8VTBhPKphvcmqA^r&X-juj3Wo&t^-55MXXf-*1$?!W* z@k8B6e&AHbKnGNcGToGuGxALk+Qjg_LIgS?0G*QFEFccl#zFh}oMBjC+@@`hDEM<#Q_Dij9mM&896YiRS~#=y zwjD!ufl+QzG-1JOj=KM2QmeWb{oDlno%^(D?);zdn(eN|R|bI>$vHJ=3(N_>JFKQY z)_kx83w6LXD76&M5MsOfDe}i?(E=P)h5e2F+yxUJ;WV zkSWyU*A<$^7n%C5=%$I%ih)dd#&>d!%q4j2A>D--X`joI>E+ZkB#MYw)ijtDCAAf$ zonu%N=^N>>iJ;9DW|T-n{Eu)}Yi?r-zNM+Cu!FueShS^Y9;LFh zWl6hjExm1H-c}q<5@fRN*s<*dU*2|k+;*eg@sQs6V7}w&zhlRb`53+9x4h&3xD!xG z!cMyzV!j*Zzl%U+NAV?XK5LeX=Oc{ECNa(4P41utM9MI439wUp)4_XOW0^bh5?KtR!(;jw9ANFJi zR8;yEm>&-MGrTY_R7f}+qCJ}6-Jg;=n&D-s`FylsuA$g*v^+|;C2+7Jef*R0c+>xQ zH~aV~ZF_(DxWMD!nD*q?8u2dA$z>(onE%P`Bb9gE$)hya>Ej9dt5e`ffRyj&} z&6zX%lmK-~NOwl`>I`B5KO=>ok)JSJk(^SjoYA7r=;_XPybc*H&ROB->|f70f1Go# zob#g2`RM)%pj>VG3kL<^e_a^{39@OuzE6q4BmX`Am3?&q`{8d&3s{m)e%WPqC|#v$ zaY0!8S9)iSN+UrHt{wE4Oqy#GJU`L(22Jek}Qcw&%qZlasCzJk~tQ}9XT)0@kx>@FO9X2o>t~_eh@vlWe*#H zxkDrvTkATFDQi%-@UdqcT;26B>6`4{me*e)!N*nS=WLa9E*NKZX*Xc%le=N0F4n9a zs#Q)RV#NjZa78t)k`(JP;tFn5e*g@8*@6@01xa+^B$y^J9{qvDIX%uP-LsBQN+#U0 zsgm?$fB|63&L_NuPYX-aVu0%45OeE_VXea=NqEeA(zKB(${;#8hYAPB8WR9{%sNx* zRXy{InUu;OVxMHEQvgYqb(+@AXAOcLDfedR9q3n@efP(t(m(tkzO{7=v62Xt!J?-q zes4#W*a#Zr8Pv;rDwYQ0?Y5nl%C}O`F=h#@p2##&iRKF1d*%K~aboztd`k+BMPn>v z2uN0hQ}rmv7wPjnQ+1FoJ)P+W;GhU%^Te(C*NnqW!;o>qMAO-V@`q7pA5p zvRaV+dmZq@UUeMdrjjJ4FyjgHUm|q4CIc{!9Dy55J6!eu@vWW=!cKxV9HRx-TzbN5 z_TqTTh4okQpe$vHNPQ=VUQ&`oenDFG|I4>nHr}vs5#TNi<+0VFTfe#{tOXARyP4W6 zrPo4!4c+bMIx0W$EkSkRS(~J~r&%xrM(E6)JAjrhcz_Nb%gR^$RjAi=Kq9u&23rH+ z6bwR*@h;jt@h!SgAsuD5BX%@^CAU%qhbs)B8oLh}QBc-G|0j;4w~(x(XJFf)XJC=S z*_|k^S6?fjbz|*k940D(aczwXf!GQZnh%J2_4sbOsdQS7CDAs+HL1Z`gu(fjtPGHtW)~W;7JlW1(x!hDlQE*YdGa}bDus` zRZU#z*X#0KA|awsZL13eIao(ACM_Wb4z!`iq^JZzSg<_;H|R-PloUiJ66{qrnOI&b zs|58%(XM5qKkF`|J4;~9fJD%~U{p!X`TaV(eV~wcD^I0yBBcsK#__yc3VW0-5!8~; z-K0dQ8@ZK;5zr~mX+W%H4TnL0nCi+$VyoJTOs2E76v%rQs}gG`(f>gC+KDu@jN}{D zNiPc#i*eM7W#%KevUMQddB4;>8lL6@Q&E19{2u9n{#7H;E?pf>m;>L_JhocVo3)jGrV{JEBz=dh5=_ z_Af2t{An@(Jo@Fet)o6wiEf&uUfJ}<{!#a>Ug4){?Ki#AMz4zWtE*GWT|?iQ&LdP3 zA~*BwS`N$#@(g0^;kuP@^NW zrEr^a3IwxMVo&Q$$uzRpWv0GJ-q}$vvwm3C(S6g@*~#{R;pyUQ39zN2jBspWo;R@x zxRzGx@0+o9>1oj(S^i&oqFGH3+AF+ng?&B0K6S|qHFbe|HY{rF>aaB}V!C?wat&~& zRz#aCyBNF(S~_b)x~|wxvSt8@$Ne$Uz$SX9dAs0!5lpqo53T`38rKbZSAMMI_<>CwOQNHslo;CD$CK_IC>+7}T z4gC-_k_ok$ytavI{;q!@k@Pw%2BeV~W+mueEa;oPrvSr*1d+ol^u-WFi=e$rnFtpm zs|}x?lN$&~85vf%EZwA`zo~}WDOKRPz`Dii>fpV*S>r7 z-M#g*&m}5lYs?fTL7G86bRwO8@qi;SxPqHhLWaS8Fh=vsm`Sq)4oMb(#tUqn1D^56 zJ#i+^l{uI$uUoXA^R(4`&iJ%{`eyDOm(GmRxV(pC)o-@%Ohv!C`7z!m7>e)eDBr3M z@Mh5~2A+se+%Nq4WP4wl|DJlUPGx}LLtQ@4{#sC*>cnUQ%wm@GwpU53l_0585w3Ms z^mT3<-@|7>Mp7`Re_-uhfm)DLjMUd(nT!8e2+PS(u#>6Fx>VNnY}H6bMg?JhMSFc1 zlqSkWS8V{h$-f@0nD*bS;5+Tpj#kU05KWUn!9xp=jdz+M#{?}1V&dOInmpgR%|yA7 z`V998W7U6-$yB$iR#g5?D_;RrBH)(4xC5X>bq-(hJ4LgvzmFPrj$JoxbtYTgA}}x{ zQt3mFRbz3OBe4FwO3$dmFF`2Ybq2EkC?xM9aEi#91kJY5P=H@SLb|F>Q=eUcW=ueo zkO+PD=pg&{PhD!K?N2q#qm5{c$q2IY268+acer?4g5fI! zAoqeF5oSq2kX#r%2Q5g;N)*KA=O_-{{hlbpilX<9XCX(>;ia)iA~<{yC$&)S1w>ao z;M;@@?kz|=M{K>CfK#X^rdoOO6e#pI$ zSpdV9V7kvXr3^rHeLG)>{8&MdSn05s76rPDpZ<1j*vb|_^F;cdCK)E(5JtR7|2JgpfQsrE^ z{YlpNWQx$$U!jEXfgnSbdU}=GfnPaOgS7yaRVXN*I^d~&H(cB+6j;R&5?#j3n%F7- z=Pyn&H}-pO$`Ja_vcyF){H{0vZGPo@3K6Uf>o?2o~#70ho9|+3-$+?Q(3&5&WnC?WnQwD$9|Sc#Wqi zlC3$NcP$zS&?M~x6*WE!OUo#Vcbav^Nx9&9VtlpqN@!A;b8=l~l5xDu@*+q4bIM||0GHZVCxLBCGi*~I;*9y zU?+c}tGBOoMl;$*%w*6PfYRsJ*=NxvYJ4PXvha)Ei%&$cEaV9<1cizTN-h-G!i`hV z7nY*+5TCmX3l`i;i~^rfV+y2XTDNj~p%l~PP)vh}f){Xr%>G7DiMiNNz|rtMyWV|~ z_E&V3?P9=DRGHMbw@b6HqzK+y6Ds0^<32`wk&1vj;x19g<~8$oRxI|#u-KE}b_7LP z+3G~NlXj_Nw~pWpfBZ9g6pfOYg(GmxgelB^3oj8E%_*BrPuea~>KJdyFm%7w~m$yG2_e|;y#_GucjN? zolhA#V`oZ&)X*k&FfCLPJBceeyybEK;ICYm#V)Z#e29sdJwg}-Lq>Z=;TO(0HuYSY zZsz~CRzzAf_@fr7Xcjh)gq;ouRfZ8^10%|ZIARQ|kD-4^_DD|<^cdiYU**j0QN zlOJeO)jgoQ5R!|n*90O~6FI`MZWieu`p)bXuLbn+X~BW<7Q^?+BZ0U>K)|TfTJ^tp z-c_+k#Os6yyr!;r8=gWQ`1N1I6+p ztROon#98PDOqu{p8eK>j#~_`EvtFWRyGcrg=o7HGkr|5&uqcIhS<7e74%}=aGdpNi zI}%RKe;^Dw5+aB6a40UJHK%Xl*z@BobM+_cjdfsCl+cLFZsHM49r#zCKFOrq_nYcgNn zp05Yy=)}jjq3~!Bu7$)pcxFdy z&(Rfc@0B)2ypZp2HJ;EiQkax8ttp;jNsgKt+oSmxjW~{ONr*0rk3J2I-GK5k<#U%6 zz8h8&e`{solq$(PH5@aLnXf>LB-|Kkj(5+%KBBQrfs95d;T)A6^rvFi_kp-BK~Ha` zlvX%8XB@6FN`M$02-ve`{sOvgmPIksh=(3dh2ceoMXc`jOnO^xf_4<;qfD@!quBW- zBAN=lol(=5M-7suNOoxkMy5!Bl&&wrn-y{Pk)2*CW;l5|C?J(QDD%N`|DkkzlNum? ziBKnvx`}j-w-4`8JzNoSg?`5yJ!aT)Nu~ey4*n)YUyXrqhSVBYR_l`B*EwwjHw~=+ zNbKem*6t!X9iybUQ_Fa@p5rfxd@Dl(E0MjyTo*4q|Y#2fhYP^+o?!WhZ`i9z!od z0$;>)Yr*MedH#63zn(M67U)vmT#em?{eBsQo%O0k8gwX&PW$|vmkmK$6tS!^7i=h; z6NZfvt(3A~K%@#0c0TUw=enkolMY{=iz_1V=s@@RFL_vw3EiLhS{_$dpCUWoCmHSq z6S~24f5+v#m+-ns7)V^yvChE5eQ-&)gKHqlDmsLJyeNoUJAZ_I)3c&UzwRYCM6rcbzxexFz|${ zyasFcRcHJgJbLYCe(AVF);hRy%DV0!yN%vqqdVi7Y-t+X=Tb6>_{GEG;pz%pfETQ( zE1IvUCSJS0E4nJ99Ue}(4C!2NGaCU(H9z!09qA=GS*`q^q8nE!^Z=REw$6i#pu`y>Ih2LCKqj zDe*!7nX~>+1v|HApKhyudl$Q2d3@~*TgQnr*s1|}c{ND{jop?b11e5#Tj1oXHFq*U zecPTRne8!suQ7eAF`vxYkEYQ1!n>bzrELcH?JsVvHhh1O`I(Tr4!rP#_cZzGVEa~k z-Nq>1u}`w81MkmBA3B>t_2UELf80;amJh#u81ed4Zs>dJfjfr0U-)`IfednjEpRdE=7~B8oS{T{d`YXs$1r}bR01z_< z77@4oDZOln420Eb_mrUrLG+xL+OTF#UZ?SJ-f{Mfso#?bfY&T`h8({nVx+`@$e`z| zBniknUU2A!D;tN_Pw#I`dsgJbkPMdSWbicg9bOiSE8gYyPajzs^n}Z%V@JQU|;zIKChvJCV9PLsZ=22sdomdCI%oJ}sk*UF4!zV%x7<8=%Ja-b@l zpjxSzRz8`k_Kh|MKpm&R{ES!H+Y1gDVm_1UOoe-Nz8LZ21RQQeLPlABwD|=FE}Wbz zewX<1Fe+i|e+l}gU>}ac^#8b^T+QWTnl!^rQUT!++4#pC=bB%o6|C?CWAuN}m+?Ro z_RfW(2IZ~kfA5U zKQ1Wo;`_gacS%(l^){PYMFh#t=zWg0@^7I%pSZhn z9Qd8f3|MR2%F0y^a^VfcK2bo5%jP^_gp130Ruk=tidG}YOx65Xp@K*;6xB=+5&&tO z5)8Ck%(&O`d|WpuhukLG)MYAX8Ss9<+Vk527~>5))y*zReuLt+{BoYEpLrc6WztFd zp_fI$Ctl{YLA1xT`K$bI?~6W3$UrC`{tmrr^t(4vz^pW}RDws3etNO;KVHzS!Gr9Z z99qr6tIPWh&2mn7p`j(4`@j)NlK_IRZ<;0dg7K8jjuOy`|1Af&3-hS zbRS?g6Ylc-8udr+c9kQhiz^hwEHnKolP$18>UBVIDx_j+f=d^rQ*HhH@fic3-uIpGhi{A2UYt7wn7~d04&;(I;k3#~;qe*2c%r319E?8+^l(Y1a2&LeVa;K@HRPJRCBg^DGl(K3{irhvzt z5{N8HK<7;1tG$hf(bKFCA`6XpI7DpVuXaC#zBkp6H!aN9_)(tZOh_MLT>l0a?GsV! zsUbbE1rbm1`H5{BlTZC00C7N$zpZ3ct|*(KT2hacj4CDx;)*?za2l9ljSK<6pt~AU zzk7V>mIR4{RVMgPzn(THdH2MMJnxFaf@dG|6mjY0H4+hvW2Uq z3`;R{1&G)fk>|PafqJymCU9hnDF}cSu@Kq$*Hy=rp+=PN2kAyQe#+QNSo69@tWD;U`*#mpTNxD#oAb64=hmW_vH z5@#twyH@=6QLT^^a;-u_t}?f|&V6q4_UXT-u8=JHBnWCP7J&NMVwW z88l*rU_+}PDI~B7s<7q|YQYX7mIwd@+?Z@<3bT%kWN3B_rQAr;3J%(X z1J&HZdO0ZDM`4*#T*&7@A~0c()T4hDb0f%@W!4H#5VIuT`qKnkA>JrTp%$&!z!Y#7 z;V2%iQY5JjmQlydfiO5E6b%4|H-R;)jEN-bl^KWG%MllHw?OW7uRUNFo7%uRB%xwW zHVJ}&P1wS%km)n*y4(&jfEj6>?nv~EEVGYXdLt$+2^q{gWjcr+kZ2lp;)|AOn&rH|48nL@SX*$dtL=*=E_omKYp^ zJ6s_Eu@XcR^~kP*OG3a}_}hOLwP1tLfPm#GE0^GftKUDSOo|PE5QMl(?z)zH?wQlP z<~F~1LWw_Ig{cL|xijvzS!g#kw zI}VwfUvOf6RU`!~SYeC86Xc7&P`Rxs)V44+y%chmGa3I5*Kt==`0}vkiKJUfzF|D2bUp&dw3q|4RzsN@EhFaPa2%C3lpPInL9s{`Cu~k@o3APi9mw@m^4 zS&tib&n8I0Y|MY#)o6j87$8-&0apNl8PQ*zaYQNj#tBB5FnL5G8eqoUnrYEt$B@zv zcAGHi7c&tHxIoj2xQlFjVGxx>0~O00Zp9`(T{a=XsQH7bnW7A4T5Jqp5~T(yXd+2S zfdn9lw^&a)RZKI@#tYU16p;xNrBgNS;gapqpd|r0*n)q)1js$*6QU%+{;*nVbXl() z6Ekub>{ud4d<^!r4u_G>FYpSYL`jJCU`$~l5!U~sJ>H{m_(M+MUqiT6K}Z0`wTFFx z6su@LxHt|{1lAs16REiZF-buw+`SVYj+ zVHhQFC4gor0P3)UE3^R+Y=S$Gf{p+I*VV$oX@M=6g48KMS=pqeK*m&T#ja=sr8S)C z^~EiG-9YdaL8J)D_=VLWRnKWhxM&*2;7GL?3OaxJ!=B+~J@i!);0Ph2nOY8j0>EWL zz*IpM1gSk@<7rk6LEUGG)$I7gPOiipvO-b;|9~WbPf3!(Ke(l_X=XCvU27$g>6L_f zDPtdX7G8lCnQ-OtMTYVf#BR~1U38fOC_vhE6E1h`Skh>d@ncBZI`t|*H-h7tAzgdRi|v;=rg1ZeCU zcijRf5NK{B1%n1zBK-`@7@5XYON;`S@rD1~pd^7-#84h-4FQ$G8m7pKY-kxUP!OD; zoO#zm+=(W1LvlLdQ&ZW#K(l9nbcW3&59M-}pcvonRjYN!^quQlxfMY1`r7gvVe#(t^ z3F(R?s+yw5iVRdc0OFva#x#wfqF#wj-qs``!8lmsuqYxbvFb<-gu;a)fm($zVFUnv z)`~>n1YHubD4e%BMBS(h-gJaawJ3kIMr(@tV?hMu4B-X$K*iy_m4`&Z7VOWwAq9B# zPi#!ZJvdzIMC1|_mJL}BWTK{O$*FUY|7Kr}0ZkSJ8ARN~C4|LU>qK-&n2bW=)xzA2 zTwgv4RGO3ofEO1isN~VkW>U_}1%~dRpR`^q#%8R>{t=SY!m~V2D@=gM4M2Y>EL}Tj zrAn=)x~Wv(&E3EB9aqq$bHof8FbV>!3Mtr~Y4$~pm|0~?7!0x+LR?$Uf>JwVYhg;? z(Av-BXb0_(0uTg16yQh=aY9?p511&S=^LyLeA!FZ#3RIz-B@y0q0#N<%G-Q ziHXZv4C*BbXi{xKkisvxWdMI9pVo?!%Pv_dFi>k^<<(fHY*M69*%|V&!ecgGYXL6# zwH7IqLBK)~s>GB*aNS^GE$ShmRFwZ8b?(l_hOX!iAyW2HjdGMnA%rBbgi?lx{=ter zD!}?6(a+RLJ6tU@9pF7k-GVCOeeNmkl%so4kv))u3}G1=kOa3e(4c<>*x!yII~9wT zrr<>I-m-02hmph!ibT|sC>Rca610IkuqQLgQHY8MZp0&R>=8YVuKT_({J!tX0)@Z2 zL?Y2du}oPKAult1OrShR1PYl1@RP{3!xnB?W@sFazy zQXEVXpG9CKyl#fP9)!KNK~jyuNh)a7sH94ave2du?TksE{p6#Nf-9Jm0z{-~;v_4a zvQ+7j`MI7WC$oPtFZ0dG-&@to`n1ACF2z)molx1&R(9oAE>dS%Y!cAnK}dlwMA-Y3 zS+!_hVn&VaF&blm@Dt&k;*tM?r6gZHWS#^q?$U0S)7HZ%P)O1u5UH8wK7H#3F^VPt zrgh?w6f_;-swPjh=D&RF*qYD9*+*!~9K%u^|FPcO?IwR|y`~iS++8v6lr0R|MFhd9z1N)}s0d1`SMw=?&6YzDP#etCgr4vA z>P>Fw=%aruHF|w@hhcS*9kPk~Ad3F9U;i~=|4|v7ge3n+$mG-n2tceHwI+>4K^jOwOP}H>r_S|3s8N4AlYv{aGo9Kx09mY1gF2qtpgfY$~ZOoKE5stq6BrX(5_R1bBoNI9i_$&2@h~qxr_wzjO@%hT)@}HnB(s0G3Lq zX$Pgcl}(S_GS-HAx@9Ta#&m;Buy&!Dn5scZ7CJqValcH=v_ouRrD+qyc-Mz>kXljs za}|#hb)jW7d5bp1hO5nWe8l<0bi8(Igpf-g9OqjEch)P-;(jz~{!nG}U&)<0;Z z-THr>ufUWmT#mbB#b*iPW=Vg*mky>d%HZe4sf^T_dJhDc45po6Ip|Tv z-GT^kmI_IA*uPK(t2s#!a8f};fPVU&YajvZ;gk(!5_@CKVWPxoRS+@)UJcpp&RY7F zXaUyL&{PcPsAxuQoS;2Kfon{KRgm|ih01(sPIKU1syE(M)cJBxYL^`nsk;LL0YQH+ zqWHghZcfKskQX_#H+!bA5DXQM3r9u$w#2eCL}=?28AuCNq|Tt}Mj5b5JN&}n;0BRe zQLDhybZhV&Y8sk&|I3$ln|AzJyw@J9tQ{9ni!mBllg1`>rH&Jk!m7NS2fp>WXoR2< zNIP&r!I8pazT*n(_4@vxvsb*uUwnV22%V2`#otO#XP|~8Oc`kOOa`tRtLVt3I-`{v z#Hc+USh`u#)B_uki)_g0KzIZ@417pPsR2F{87#~GY()b_K$tK~4O$zAwE?SS4vu^b zYJ?EFmtjGnh7u4IiM?Sh`r6Oq&U!YG5|9Er{JhS$LC6tn)33@ElP{u)T)=-f3&jM} z0G5HY1fwt|p4d;AJ&=8B=3lF5M9w#hZH%0FBEc0~878#?E8IO*lz}#0i8p1M#L&I6 zY=JkK{M*~c0FwW~6s*c^4MfHetg3x4D|u5rN4~?HPt@aZI3}>4Oxmj0numHdu@JH< zW}NyS2&yhJ#_vAwFUILsmvDaw(_0mgpD4lO02tOJj@-=8sm)hAijONK#VEtjLWsC( zZTdkNU-uLcrk&2I4ZxRH`&U_n^K0$T_juU7^P%5s>S)kC{DM4hf!v_Vap`nb{eJJ~ zzy9z4@bm*f3=lYwU_pZi5hf(K#a=4`3IHUC2w;oABm-Lgv*)j!Eqj0e{*fUFV89kz zt|kN+KvLmB39Rf@WH~^SKT<0TO0m^SLdjPCKDHyXbgIycGo`d* zI`?V8BxBl5ZJ3~--N9?5Bmj^ScW^sWj7KWmQct6=c1wTVlPRfT&w-@=Z8=;(Ns}78f5J3`BiS7Mz`gR}? zQuWNX%$%f6$5yt?qetKBs(6>bhYTiv!08f%Dxr&jNFWI*q>v!P43hvLq9AaJLPHI4 z0-`$v1{g6zC2-5qB;qp7$6B1i4+LLAYdFILnGk> zF~to{yUN51+Z$j?(*i)i#0^_)$wat33bUgzcZ$qP2hmJZ%{AF<)6F;Gj8o2+990Jl ze{2b}`glCZN%1U{K#2r`NT9+*hm;}#B#4BPARLvH;G#j}1A-x2qzFJI6GgP~Oo3bk zLdFq6ltKZO28cjJNm&B1Q5A8*R6{9(yl_$&izJ9dOn0;r#Y!{WQUb=3aFxg;0vHsC zA0-_CQve`^^3v`Gcy&aJPK6TFg93dLf5Tj=bd@(+2@~u4qE7;6a15Ce*jfZ zdgb|Gt`7;AS~0rmqYavxYNoH=&$^j+Qtv#QtIS&LvB@sm?6cA4dCIIa)-Q>af$&h# zl;1WHM3Ul$a3s7vlup`$)@q5yzz;$I?-Kh4A^->tx9?fLZ``{e#iz^rOuaKRapH>A zvzVv0(+*wq(Md1e^wUvK{Y}gPe*}H!xht%&ZbcWjuyM=3Jm_EgM2$kQHU?OHRNeiy zJczz0Ohra?Qh3*K{~qxk*|&Vx^F@7+|L@c19?gk~fS!{=+*Fr91ul?*4Rqk1e_R(I*0Cvq z4^r9%>1RJWQL2J^l34-6WU+8Pkc1^P;R#WgLKUu%h48uH$|OP{4kkw{6=~t{CRM#R zb#P-m%%PiPwiJq?M~Gp&;SrITL?tfqV-*b4MI;o(^}WZ3^cmE*`h!KoSjj?jo89aN zV6^zqsDJKyj}|E%DKDmoa&4$Pm_c4)=ee~lW0U1a^7Kn`l zDP(}0Xdu~nac+rP|7000k`PG3hekL;fE2c%l0i-KhXvuJmjw4WCTa0V11Y6J;AqMz zs*;CftWiPg#XD6J1Vo>_&m1c=yYd(v}rMUa`T(- zwC6qXnNNN4XG)A*4AsW$;{j-YzWzz&DdP9uX z2a^XKXhA!=(2S0hf21Wf=>s#!y?_$$q%C#nOJN$*n5I;wHSLeyHfmFy?v$tdnTaiX z3NAwqHIn6=VdVcJcgUbFm8ng2s>iDMRDwvcsB=V9RHe$t+#OY}4^V+5O~Ivz_(qXF+Sz$%dA+r8VtoQJY%Tu9mf}-5F}e9yyZ1-U!{9q^{$t_bscSc;TvE1GIqT4weNlLE7AJq zm%sh>?|&gm|6KqNn7{=#@PQF5-&Q1;!3}osGLd0V5dZ8Z0wXXsBF#MnfC4pe0|~H0 zIPe2o;s8N#`$~b;M(_ktP=5jcFE9mJum#g^|LP+HU6B14Z3c@E0^rcSM5d131s>(18)o=~juF2!9Rr@cJrfZT@ig z-fs}UuMiC}5fyO}8L<%^an>B}5uNW1CDHZ-u@WIq6EJZTIZ^FuuoHDp3qWxcNwE}7 z(ci-G6hDs$@kbS1@fD{o6=6~JFvHts@fLA07j?1GC~+5k@fQJ(41qE8Y;hQk@feXY z8I^GvnXwt2QTd!?3?`mY_5YS|8m;jfu`wI9aT`mB54rL4it!sa?+Obd{>0JmaPY*= zm)@Yg1vn4&SI+LwJuv6FYay#9OH@#EYRxmtg%{a;PJkPN;6N1#xvpxB74&Bqz7A!vL zlOAhxAlS1$UriG3!)WePKfSERWFsLE5kPZoIt|p!5K$Bp^w;3CLEj8IA#_5gu{nK3 zAuNFY(9nCyRw8{S@R3IYpL{CjUSro`bMG~a(MWM|^X>`O~ z^hR~m7;SV%q3l6@G)Q+5JcabpdbCK9G)a|oNtv`JfwW1XG)kqEXJS*Rrnf+V2yX!s zmL`vMX`S|Ip*Cuzc50~>QLD6nhA7gNZQ#&vwh0fa}ko5%O^!lzfrW7oi;*T&`7 z$bsA}`zOb-SH=YdvKd#&c~_~3{ChJLd$pj%n-|I@i$J=waznh6o;SnY!^>Yc!UuxO z6{0}tJ(uly1{HrjOK1JswSC*U zz1weD*S$U5#eLk#z1+?H+|fPV9Sz*oz1=f9%-y{tA>{qwJ>Kp8-bEJG@qOQ|%iDf@ zfc5=zQ;psEz2FW0;1S+zvOVD$zPSkAd8=sQ8-C&q`rl{w(CcH*bJyWEo#H*7!z^BV zG5+H_a?wft><^KW@MkQ1cz|O+4Si<0SO06!b^`bjEx@ zvfv=1015`c_46bPAOTKfzwa>L`mzA_@9FlziudQk_kZOB_}PB=5peP!4ipqZ_`Mbr zMv(Ng%K78YI3<7a@nibo#Q3=jtDt|XP9ILTpYPkm${B>J)`a};82Jwb{V!Pk34-}e zC-p@}`2i(Q-G2b;KS0=jAnxC1_8*-90?2@Y1PdBGh%lkTg$xlUX$Uc*#EBFuTD*uc zqsEOKJD1>l1{Qzwqriv(V%oeZ5gX2(JbNMp;IkvpphSxrJ&H7`(xptBI(-V&=FF&6 zTk6bs6=c?iTM2e8`L(LSp<>IXoLUgA*|lujx_t{buH3nF6{@X^H?Q8k1Ca3j3plXg z!GsGNK8#qQ-;WQ2-Yu9QvE<1Z1vaL*SY+m?k12cp3_5?b=+UH;UM^h~6H(O!K`{)) z)-~$cv}@bGjXSq)+Y)g0Zt7b0?%~9X8$XUbx$@2U?fdKK$@8^6T5bk3YYDkL?JyGT7fslkA67SOgBF-GK@& z$Y6sGJ_vvRAzSa|6k$=NRES}Q8g9s;O_6nITRnsYIKV^cfrw&?Dy~T2W^+lwVup($ zP|=B#1rcM8I_}70M^lxj+i5sam&st^#FPmn8O6uehd2q@W0X=(Nu^o#RJmkz158+z zlJ31%WI>>81)EAhVu@y&YOa}~FY+<@5;z=*0}_9ippm!`jy9U*qDD?4)Tc&)vPo#6 zh8~JxHXl(3=Rne_laXp+wnmVhHm!1yoI3&VCrLCK6u_RHUNkB~r4n>vf-lu#(5EF0 zifF8|&dMiw8sT`7TDZowRihJ4`Vl`iS?5nnbkd4!vdS(CA(W$SLYJ@>`C3sw6Itrk zC^3Jjs{hx2B5^wsj07pckh9F5i*C9-o%vE{TMYE>bLWLe4{6RB#I8o*#&xB-{{9Pa zP$K%|>SUq-yB4;D{S%V83l(e=7Y-lPuR*&G^w7izovAHP+`<}g$Rdy2(ZvIGe38Z* zkr>3OB@L`l6fr5x5W>}Vq@;lhQM(Y$D&c=Qm@OCD=IuZ(lZ98M<{F;{!5U+-w zr-b`-B)?BtLWPI}8e-C1L?)Pp7=&=J?La;%d zdH*d0Qqn8Qq#~ga>vgAe#seYYx`@X-#%d~{d&(U<62ZiH$slIri4zA>$dgrsa}Z&q zLDV>qSv2L3D=C!~GeSa)LBxj*p<+c;B}Iae1#AQHq#jR+O4{g83$uYCaSjQQLZ%{> zJozFmZ;4CM%&U(Fan|@)LP>v;0P!et+e;RR0!xF05;a^f+5k)`HH*0NC9Xh-0&NOE4wj_EQ(IQLyrxYwEOrp*t1{{A$AZO;DX$1xU zo*xkd4dpYjUs$A>8yO};6k5@k$`m}jgvcv6L{Ke#W{@f2glDj~6jA`SAVr+1QIAR& zeKtf~)l7*$>+`&J9>jbF*`OSA6GohXu&ID?!{ADZ)Ub-RE?!L#OH1-lmqdoEX36PC z36Ub)Rn!kUF&<>hCf$3QP3fQB-but=V3Enn>sYD(`u{<4xTO&eJgB*4M zgC&Sum6FV!WR)b-Bneq08raZ^mO0x)&`Hv|5x@SWvIJ3VLoQnoY-KiYE{Uo_CTkOy z8fBg!`B-z3U#QHy1yjAiX6*&~;sa+bdg<}l~-$YL(Dna_;oG^=^d->I>eN%Cf+0GZ8n zuCtx*e33Yt4$pkcg zW6PdRJ1ffCxO3~?&AYen-&tRxCYr$EaNx*c6Hb?*{|Xs@Xa6T7zP$PK=*316h=6Ll z0Zs>q#$6n0`uq6v>)+46zyJRL1}GqA_6%s?fe0oz)N%?oC>DbbMkwKg6jsRJd=_S? z;f5S)G{HcmbSUD8B$j9*SoI+$#fdDo=;Dho#wg>AG}f3|igz8ynU0m!W8;rN1}Wr_ z3Hd_gkw_+g*;9N>1<(+WPm0H+l2le{<&`wOBoKt*VK@*%WuB0jFoIV8xC7pOW#!6nCZz!PwPLE~t;*_utFXrZDr={d&T4B?fZh7qop0Ud zQ-KcoI*_Zn7HjOW$mUjSvdpITX|vEqEA6z@R{N8s)@G}1s=N-BK)2h5EAF`DmTT_0 z=sJY$Q(>m-?z`~%2X15Y%2sc@#=ePfr1k&03eSnjkcnus8%54I}h#=A3u#`RAZlsB@PAj}vJFw!+?*ke>@Ae?gCo zu7ViMpuslhk^E`!gCGo{2pf2t3!a5R9vor+OlXkFq%c+}8wl5?$3alBP=z?mp$>P* z!%%gsheRV*WB(HH!`QS*Btkr%lMXXO+T1W!M+DfOfJntETCr>ExuO=`mBlT35p%Wp zq8Rg1!`Qs(Jsv6I7}xliGET^ie{A#{9M6J;em&%gf10CQHo>MiW(5Fv^rIjL$+svH z@{p=w%o6{C$l0t5XN{De9nI3jBn2{(A%SGj;uID=8tYk@45gfam7pYwGF;1|mKskq z$;Pbml;3$Jq}WlCm%R{{0+9eNXM@IFrbi0{FeJ~&gArM}%SgPr<%m4lf1eT_gpUGI zrqZODv^7=oJf^Hn5L)TW{m>F3)O=SeyV*7cB*jPMOs6{6$xc*Vvz_par#$m!Oxyu; ziu25;KCyEDNy=$5WPIu3X8yU)f*SOo2u-L$7doy|OoE{ZWhO)?O3{j1^r9FA**hf^ z&WtK%qaF>ZNJmQ2lA5$56CgdQN>|EKM2)khFpa59XO+#Fm*J=$K9^sS6B?JF4H+$$ zJFycLmw>Px3x5)JO#~%dKV6%rdjEDtn6%BUZl8uv@$|N)&q?V3kh$A6+BPM}EiS$m zS)=4C7m?0IPjsJKyV+8AY1X9)bB`j?>_n!!*A1_Dee+Uif!8tJGUs@Mx83z7O_`(` zz+_^}6iyO^zA|gCd-uy<-01fp>g}(9)iT`yi?UBTIe)K#8Ar;eWOu<3j_}y{^1=d_ z2Dk&@riAB&;dhN8As-qDG&kJT|B{%E!d;6!Fm+-b_OiwLD}z$#p_5t&R3`00CyZ4X z#vS|k$3PCUPj3905gP@+Lbj4FsTN|2lsCyL?I@I|wcXxM`Jx(z*^95tvLV7{~K;LNNC>lto*{Q+;4%&v~Lta&qv`bl}#>Us*P=>}>N?Ru@frhkx*LltaXOSZ>9{k22RwC0w;Iwt`R z%>ahownYQ^*)rXLU^@I)B;Q5Gb+NXwv4q>onb;^+_Q<)d-4TH}Z7Z%B@NzsM5dQ-C#6^qXh{Tsd@4oAk=i&-*69kwm&-r#}%oH!3{3Fqh zrlCcC5sd4+vOhPF(eF|mh95m(TG05@!zFdJEL|w{)=JVP!kAUCgy>VBh`j;E@33zc z;b#Ad!w+Wgv`4n(TA6vL!R~gF*1GQUTYtO8G-zLG6pMDKo2{Hz48 zc-5f2@tE6s}OB=+H+;~w$HupcQ0OR4&WuOSJKs`g0quLocG92zVer^)Zq`WiMW5dQ-o<>=6_rN z`qVBeeeW+l|^CLA?9Tf`svtDCP$8MfK;; ze*mVe{W21CEJ1kvr!X&9fK#)19P)Sxf`16uI_CF))G~GfGHhc;7)=C#ZWetKI6#ufri2tKcs*v*f;RU5Xg3eLSuu6k%K%q zC+SxbG8lvjLKr?66dGs~ONdet*c0*x7j!0oM%W!mScTC8giJw#Pbd^7h=uEuf)YW3 zVCWV0#(?QUg=NSZLkJvS*c?W-hMaX~XDE~JR9aV&mI3P)2i25Okcrl2RxQU$DiTJ`7gaL}*G613Y5Q``u z(g6UGL4)$8imm905_2}92#WGWilbN>rxjkTzV0>K9KXbblE zkoVXMr6`cCQI6pfiq3eD8Htb-i68&?kpbx^{GH2H`OL6$tBmXY(40GX8A@s34l3#~Ah zq_B@N={s)ek9i3hrhm8+CWK@JxsilfiRgHIFF|UNfdsCo6pd*RK1Yff2S@|)62bxm zs+gIw6pRC)1Z~ieMtPF6mHOpD$HkesCJ}Uq5jMG&JOOKV zG&;8yXj=&p7&(rINsd%W5t&(=3gL4Ipo*BOiX^cvb`)Lx5`O^5DH6?j5s?`zs7MoD za-H`T5%u&EFaLvx7l8m!2~OKt5k=%BO~{<-sSqBNn913ga2XQwVO!=|oR#^WAEANu ziHot}CI5mU|5*_0sETiip!fNgvq_cfxex)h1#c3Y9J3Kgumvm9FhD>Gtw08~S&bNG zjTnld{3#H!X@8&uF_I@)m-fh_wm=Ho_SQ%1jT603s>{7%G4#N)Wa+k`hUb z1K_2e6sAV%qYF`_9wDM-gdyA{8frQ$q+knE5EPLilz**oM2hm5VtS=0+Mo=9XxF%= zgNhIu#ur}ys)=qIs0{(A(k2j6&7x~a{B?z0Y>wn4~oA3**AOVpfEXAj-)T*r$>lYFr zuKzHPZvU_%`h~0zI}q)`3Zp3>PI7f5;jan7m{M?$6A-ivp#^+oC$eftz!$LwajW|Z zp9#@XLmRMsF$qZPBra670 zl7IQD126{lK$qXD0Gl8IS!)ncOJN0ZtlO#(6RE9lLNlVl4llYygmJeAn_&Z>1w=-Q zLrDhRSrCJJwE=q)>B^>rTd&{hvIHTkf(y1hI{;*`Vj+Sp$ZD~qND6%$A{Fboxr(|B zA-BzX5G_kM08pv_zzGV#j#Kap?trOM5q|}vP!BkF6h})6YRI!m5RuH}tnIL7NT3B$ zparA<@S=FB5ON8TE^3c=*_SGD4?JVDc3a&5()ypRJ@P7{w zf&ioN5BAr-9PGgju?2n0C8OXECv3S}Ag)~S!8gIB_eciAu?hcB!vpZZLV*OV;0hEt z5kT;~MGOQj{2tQUospph_0S5LtDL2~pabxkCK;MWS*E3FdcxJk z{h1RF1rfO-5L(C*d&;^*I1pb85!+=kHMEKdfv2{doH-#Hw>YQ>z$w@)RtDE70hJiCXneTQ>BpKtm&ZG)yQ>hB!JLom26@_@*~~~8F}%XM1wjx1t0XMG z{Iceptw-w)6L7!hrq1=CvVIHzd`kxHu}uUKiWq_t1bPs+F&`#ox0<52-b}mQ+6qa_ z5L{Xd?7+{Mi3CWTBGMX7j(<{(5)j30IuJp7(D$Xs0UMdUdZsw1D6h)}cKpzSN=;G! zYodZmy1#221HcNlun7R*B=U*Q(#FuEPz!vdD6QZQoO{j>vBv?6wq@zg($>ffQ3m$F z1thUZi>WEvIH6HujVyzrc+1C$i;5@bC_ru2xNIFgS{hp1x*O3jYJbFs1}Xui@DEgQ zyy>_iu7Is;jnhKS7bdo=xofrHECsAk3(=dFYmIhe$*=xU3O0PYflbn3yS;$@*8?C4 z^}q^F2PFunqAj|jCz-z1$GNE+5*b_&bZrvF5yLRyBZeo_m-K7@TTgxwaQ(}54>8NN z2uB?g8O1nVRd=B12!D^AaJ=qZ%)Mw3ThI!vP#i<*u$hezTVM+l49?VDi|qm01phI` z{-BZ)+rM1@ur0Tp|kn zFiu$(80$6}+cn*j^w{^!5S(xe#v1_9=G{77-ImSR6oCM((0>lZ^wi5_R^NXylD|;HAI_{8PL<8w;P##4I+_sr6(*_tC))*pMLYoAwSNv_ zo=(vP0^rVb+zxg8(-ANV8y=BI?aooIk8(MhM=b^HK$`3!1#Qr$@{t5&0Fha|905S) z?H~b4Q08=Lhks>x%mX0h_D~D$u+B7|w_>X1CfTP+keNwP25T+_WZ(`C3FiU-Am`z` zVy!R*?c91dDqo!3wq(GfWI(lK-nZKeqe`1$NU#Mw-M7%q!EMs51=}018TPAEPzxy=&xf1l?MjrPIhwY&1n&>hmX z2mlcPxlMaV9>1dOT&q*O>Iy&#s^0K%4g?WE^qOu9_N=ri1OSpCtSYOD$IQ3-PNwzP z1{o?J5&+#vFa;)F>+4{WDvDx@x!=jCs#T9*t4#v{48S-5K@>6t+kdqJ*^W#S1e#jatMzfq03aEG07%kt;wA|J z0sx3Ws+E)h0T3YBA~GX_0bAH5BuK_$#sNVFAYqE}VykvifZQCg$<-zW0V1$!BSp(V zdtq4)Nb)r5TDEQ7zJ(iC?p(Tc?cT+E5$|5Uef|Cg%(PZuw}T0<1t6u=Eh=3T0Dnlp zv(%J}n{0v9Y3~-Nj7V73V>Pf8ReKdZ+anWz=@^h{fsiVwla(!7Wd9Oilt9Xs?vJxk z-ySgbALf*{wA3YKua<0uN^$Gf4RUtvj8X>b)a}lB+?G;&jkecHcX+0o#Di?5AS#eC zv0Ke9NM=LPkGez0UOzySKTq>d?buB4G+Xf#Q}sg zu^pyB%7{q{m6R*TxFEwa0aAYHq$>rm93YBT!qcgs88g}ttZ#PgYLbDZLunX0eDb$Ng3wHeBG^0)UJBys*_#8B&Bn&F$Al_`@Ns5Gc zbg4ih$2^bOGi9YEI*@#6Wkw*jsOJ|}U2$azSlP8VSB2sFNNnXsvYa=Z~x4<-C zdoY#@h$Pz4lUHKVTxiVQYEeiDTf-C7mU7_&A~bfXvV~)`OfJ)(Y6p;j;ZY*`6g#Vu zAV<%M=p+3)@TS}5GNNQ;mf*QhT%v;=9en~rVn(c6hB>8M5*WbDb0Oiwih_{BO z(ySef5wP#9>~9+W`K8=y2>>k=LQNyf>RAd11TK=qWsFBD#o1yLH_3@DSkaTLAO%%q zkcjM5<3A*5p&qSBf5Pm(7o9KNME?Q>sYvIdqL(({PeBSPNO*Wsx6`%Ig)f9*3}rY& z8rIN;H^gBMG4eUUAb}`^37xOLM;9WMXl)tM3k)q6n541gWpO!E6yAa~n!fT2fU zMp2GH}FI+;yYFnE`rG^ z+$1QOyvUYb{HVeFw!^{rBa0~ zk-}`!lTDXy;(PUro<}ACfRmO&lSy1lEj4zT8G*DUBYja~6jI6{Ak|H-i@;Ixs>ph5 z!F50|e}DN9 zTrVNIOdIoRZ+Y9F23z!7rEcxKrKRFfo(5-mx+jgbWYJvi8v)jHx09;&Bp>a5$kd)w z@69zD&NnQG<+08)>d*$BHmNZMMu+bxq&fsW#f(mdR4t?qa_4kt*xdftqs*|+1UZJH zxotg6eztvPZAccn;vK0HVnF~8=VF}1s3S9chGZXnCOaSh)RMUtM zPY*3!xVY06_*g^~fRX&^{_LymV(494n zsoDC8f(N9q{dB~=v4mb$Za3%JY4u^9JVDy;gdo`OCqL0-J?Z`|JITmc*(IpAG`;WE zi`u>ZRTkmgYxWpjvkbmwt}lIZI@B$%V^KH7=lFijV=v1oGhtT^`cU-uw$+z>y2`2P zby4Z$K{)Y6D8I;Wl+7=}@1$6V3qGfJ$(!Ifd?EMvWCwG&> zH=g(!H7K|Ra)X=y{2ph^=kwV7=dAO}XdO*l;K}t{3+DE@l&+nZtL)AB-`~Ft=FPPd zjGQu(3fdg%`4sF5jRgb${rZH*lk@{<`++xaN9E&En@K@C;;0_;{$AWesA0$5F4Dm} zEZ0e~BkoH5HKd@@Sg=Ce@PhQATUVt^f8VTl+iGz)xW0RVONnLrTaJLe$l{ zeJ59HC{G+lkWh>tGbOOco@GLM5(sus3+idi|Aa5 zH3)U@iYWxW6@KNDjGhM{m*n`}FzN00;KxfazP@0-5T4qjp;kCFS(WE-aI~9!vd8I?TL{l=V9&pCZTolT6|+ZVKKi~w&10TMtVqryJT&{ ztYCo>ffx*1HZFJ~Tcg=CP*$n2>R#w@j)F=Ql~<|?6VOieTQxOKgXzQ?f%{RNnq)?+ zaeTbJoLDr*7Tk-i4~)q`t_f#LC{&j;O!21g1hx;X5#%uvnQn>Tvu_(`redQ-Udc3A z{1C=xlfGsm9CDCBri6^xPS@&Sjd%%r;uy{|NrlYmGCI@+_%KtNQ@oiu1F|tFoh+64 zaJ>_Lv2s}JKl50H#tQ47BumC6c)OyMn+3gQ$(b*RnH}SqU5A;yja+6vWvG}0^$twS zC5g8t*06Bk-9+)!+6<(Wjbj1>#5K`RP>-)B%Pc56u|>?rF*s?(rX0O%TSZ$|GFn38 zg*t>3*Rq*j(S^riWPBpVhruYJ4l;yt10gx`amwANkd~aePD`r$k>G|<^WxT3YH>}S zPS~A1TjT|tZY^#+RffW%iiFdyk=&8g;Jo7*G$T`{GP)*NR4#^kc-@#3RAu7WQ_0zV_KnYIq7FSwGf9j9X4w|r35OP881A%^ zm~M4Q&gINBP9doh#Tg@drVA3X;#-x`VslLa&B`0OZDH$O7x3(PURtl0_rJ|D=A4!3 z25HNtHF>*6`E!H?qUwerv>=vp1{7`KN%(U%ubq#^p3LkF!#U8YUyu?+$bpa{TJ%0BQo>=)oR8 ziVj1zF?6P%n)9==3nq6gl^Hxpb z`F_*|xu~pmbN6B)w3uj9Y;g$UZH_$@4{1n?EihS!mj-?eU>tsVIGPzGV5gup_M9jq zm=tx7x{^_XH=LcYcs5>Wy-Ki<@nJ}){-=ffDPJ+%b|u6T5^D79f~uhvrhiG=r(@O8VN8D0u)!EIiW;Y$ zFt3rZpbMI)@BMk(0Mdz3ZIMGXSXPqj$23@GCNa5_v8+)OY&QHWj%Ao~{K*d^kbnEc zDOSi%m+HK~6-$6Zs(u*!a{G=;9(3OKdpFllW|cpf7~LW@z<6<06!>)TVg&3skjHJb z!9+90?1$AU{~FZx_0{Ot*W6TJ22o`;7qN(?0R zOy)lCOp&_OVe?vQ@miav3UgUP-k3kk(PmRpmVg-#)pp~?2;WzStzPnFf$^3dMu90? zfV%jWI1L;6lm;n70ie8(-6%g(V%MJFDqKqP%$2I?t?Qy`O;#*{-Do;slS@ZLGu5eU zFg1PF668X3!iec5^aN+mpz5^-Dx!CDbkHgbUs6Z6MMxWkB23bG}l(&vBe`H9Qi^+C=7SD-x>tiFS{VI02K!N$91#$xU3{nf>Iqf9W~>IUCP8 zyR>gGtIGof{~(fS2?hVqAFbbgHF5(>8WzyG#%N15eQ3iDK%*4i{v_4cZ?#5S6>7`k zD=XUPjpEX=0JVa z=JPGWRQNWm^@gIQT%rVZv9PXQWwZGA5Y>Wf?!4YU9e3|iq9ak<}VoHClx z663<%$Bc+ygcN?{8a9zbG}nBL92`vy8rEGY9r?p@R>ZSEQSJ_%SojIq$k#I5P-yjz zw;W5a$M97RdVh%@u?T1N--u3A%qJX41(Q`;ICL*?@I*8bf%-1sFD{*}6D_4r&uf&) zztYxfa+1%WE(ariHTMhM|LNbt80BYR`Dbc~OWWadGnr8c{F$%+>@4g9+j zpLE8LgzJWvZ%CtS2UIg~K%tMRuq`0?w8l4I$QN@}uhH~*TqJG}yTN!uh&>dac$M)Z zp4YtNvzn|yae-~vr$X)tKbDKuos|^9{2s+4nF%NeWrko;hP>da!h&87$CR9MnzQqe zjeRn{CYsrQ(Z8hrS5O5Vz9ZffQ_xY)@!r`Chn4#p+XwnaYM3lbLRNfDWbG-DzZ$dB zFFn>m_F<>KCMtHvJ7Pm6SMs^z>i)^Qp9gLG6<7-FdUkAm5t5{ae3qsjO2vAbzK~(Tvnpk2x z7!jX-3&X!{o`0XE-`st*-;Lz}?UJ_M3)vB&LU~cuk^g0R={s!h(xJJ+@R9T&Dz=Gp zMMPUiX8s2*g6;bl)g7M?JTGh8c=44k`+YJ7D52j`GJlfC;B`!$;8kXsTM(3wkeE9u z*64!0|h?ruw*jr}^$fwtqU6{ONx?OZOtLa0szi7tz4YcYPF6}j=2T0rth3TlNOoPS-bQeu zF4kR{WF~C{Aac)qjhRg<=>z`dQy@eq3zmR&$Bo*v$jO-i#wKS#CmK-L4vaUEE38H!hx>H^^uXQp@C2ix#*I8j;KG) zn`e4zAGyidZl5Sm5O#yNZH~yqgkv{79K`o~>@iB7NgAnEx3ids1immiGLeEPjOShjE(P*QJGNky$N-B|FE>O74G$g&Hr8!o)FPU!^3pqV8Fz| z?XwVj(n$TF)NA?x}F9d97Sl z!3_q3^$V6?rM8XJ~)#!&7Rrxop<98DN4Ou34df1Oy_v)kCqbDZ;$ zJ1n;C69Ro3fn=p%MxDRJM_UH^6bXJhxHEo{>@byoGQC#mZK2 ztJ5pxRr03&vRDjJxx}6=M)iAP=MBg5kA5jDYvM*)J-P=Ws4vv1!DT{cZHK>#jT+p9 zzU=JL)2`pTrE0J_OX+Pu>4rVsyJu{M`-}q?-`wOc4c$#T?3+{@Y3mxBM$}7N=8RP9 zzVfOGU694e0UA^0JX}snbs`o2wWv1QUrU9DBf&fBz57!J@_==~y?X44;p|`nEe&}r z;(KHYkHQ%OMZf+{3!fA6$qrfECapvtz(527C=&qPRpr&OcA!5O^K$}_Q^nLYlLYFo9)(}3HMT5#G6*k-92jUdJKWUs3*`;DQSlIs1;{4 z5NY;3DaJE=iB_7|4$xT>yT;9t^4TLa$P=FV*P#Oy!I$ffb~-X$uOEq=1HifNRTEhf z(AR_lbl63&_4zzpBqBGxMYwvFZM+t0DVEW6)ZU*CpsNW_G!rp>;>C|G4O!^g_=O&G zn%_9dj7ZdD@3$(YgB2$bsGc{Ja&9q|GKS4qsQIK9r(Y(kZ1&bWR>#=%(1Vm)Pv8cB=mF5eqXr$>4fsT~hQoN)usZ}qL=s)w% z02!aej1kwolbbR5lhCBIo8n010R>+t1U2h3sRcdj{Ge3aPZ3J5Q&bOo_i$MiW*f6t zP{IUem>nz-kX5JtbrYp{Y)F~U6CKT>C=QsnzV0+Pb_^Llqiz<@$fQ`RN^ev8a2CM zQgTk3%7>52{MPJrU+XMUQlyU^>UH^+trvNG$`#!vbHnEyg)qSiVWTBc*{#vol01A%zWJO zWpmmFc^Erldw}DdQ_(%1n>h>^>vS;ZNb+r7B}lP{LJhriw12RTHA$xGp?2RJeO#F1 z5>DjAv+tD$rR;E}zkwpsj`UTm*&e}k909j!bDc7t2}oKcy~IB%y-wM?8IqjnyAPIW z+m>olK9-1Ydqg}$x+MlYhu)rNq$_?avK7qY1(<#IlGo4v{L8rWZY`EC9T-$MOuXwW z>!n7;p_{zYhVJJ7nUl84iK6{{O;+{hpPz9irebH7Fa3EkD6be$SUJf@fT>?dR9XNEV#RLxu2Q=y7sTDH_Rn#IEQsvhMW2GUp;v#J*IBJl&_;^uif^bEi} zc$n){zHpr6HYj8XXNwEk^6EvOCrFW`cPP%iW-(I zyI%4l^7+W|+52tJVTr1k@bK(~AQjvj^2+viy}Y)=U)0Smb#&kSa3tj-Ab-8v?EWp! zuwiaCayg)vqUplvLsv-OKO`+{aDMX#Wy^p1}TjKRCjN2Cpm}D7JI- zSqun|H%qV=Iel`wlx63z{qiq!`cvP-k8LeO{_*Vt*g4~zv8D|I5LVKSsk;`AB@U zqQ?NmZ(2$>4gVV=ioH~78#emG2^uK$k(hPv^D%m1qLOzLjJ{;edgCjl+vSh5=c`!i zuWZ0VBZs|2v}u!fq*$~vXI=*8e}89}>2a4ed0l`D^Gi4SJGma)AnXob-q);z46HI$ zrwvP0_%`b9lU+KpcT{meZGgZAn$BXAl_Gp-mV$58c-euCzr2{go?W1=Sb#vA$!Yto z|5!ozcH!%3!F9WQrA(p8`P{#dqD^h8Wi64<+jz(6KWh_=!g0m8xx8FW#?RwS)w(O3 zN|V8QE#@WnJ(&l7ZbQ5tOQYFvyXtJ>X(Ja_GV3VTflPgtK^V(nk!z^o)dfjv1ipY4 zM{qAK(|BKVe|ic&3E>tuSoVw^Y5gK5tTxSzv|SI!YXk%eTmWDzf33VoW{@+E4M zmF4;)vk<0bOH_qgPUc6pp8%w=?np zqiC(^A%Jl}BbPmQxxFy=Q>pT&zFUJGgQ(HfGNj?2^3tj*@gwVhv=UN|@N7JpHHGck z&jQX6!>GTi^4LM;tKHzcPN)|0Zj17hi&T&0^8E2~Q4`(o-WhbP@W+Sj1q$gpX_jMK zPK8S=F8&M3!rPw3yHA_AChtGVUELcL8qYn8xkrI`%apzTnA7KDnGZ#Tzj%%mdnk&z zyl6%DsjrEB=tdvagf*2oADKql;9h-K{H@mGHp zH9;l_+7*ge;x=41cGuk>#!mU!@J2dSMtbo^jcWBd@(fUg=RdC)yUAp@gv3lv9PVx! z-K-wgRid0&94BdC&wq8s@I_^26sdF2v@L!+XJH<|J zn_S&VRasBjo49XN^0M4{7~&c%UUh^;y#1RwF~M`#!BgdvnB-Q$2IUQ{cTIk*5oS5A zEnMk~+f1%c()FoGOmNL7K1d3nLuR?SX1W&s`x>Oqhx=WW`EmF4Ow7G}S639Eaxo5F z-mi)2Q`?k{iofK`&~`2la40h6FIun4Z>pU5`yr(5eSR{Rw;X@=B=@Vn86_8$Q0mHl zP{hur7N+l0OnLR$#6(#DFOI<7WzZ4lJMFYS(HKPH{$ScxJ>8vNphdl=MPH!Rw5Ijn zR%_P>Toar)AnJ_My)DumCye4jGQm4A0$t@bOxh7r)f+GQtGfPKlzHMqImzx2uGRD( zyY+k%czap%_Mbq0h>Kr z^s|5(=gZnziWAhV;2f(EPQ4-;!&`%{?MsN0_2Qq`7g{o{TO#0@DCJEgz*L|Q^Y=`sH(Ivajq% z;k0b5tr+>#;&3UCy|((JKH=lh`r{elljZu8kKaZRp5sd&_Pf|A`0Djc7`{=6oY6O& zv3?hguPdq-I+tqr@t|Qx`~z+xu0lMGQ!}I@-*fcwvkCrZ1>OcI!js|`#KqMY!t3eP z`odtsyPIQDZkd5LKYE3_-hJOt5y8LXO;~UEcfPprtpWc}1VDlVP@Z~0#`Im$6}fcz zoU!ql(+E-TrRQWO9x7iZbyuwoz9hTn#=%@~G(HnQBl;!?OcaG=;2;H}&4z!$Clq8K1e21xZ*Fn%2Bj_Hf*ALnW&R?~F{>){fp$BPMb* zPE+mq60g2W`DsqQZ%@~n#O!>wae0bcVJ6yIjc?IlFM8Lp4NAP7ghL{3E$d#7GqdV5 zc6ALV+f6DqhnxY=t~#&%Cfq%|7@VXDPASgDaNb8zfkee4G4HgSzhXf0&;zed7EO|&;q(Zfr@AKRTPSO zWA0kI;3(1e%(pg+xF|`pD5ZoL!;cAKPBqi%(QP)MPw5|^qr_fOv1GGOcIh8FC@`Wi z8cmLnL9wwgXes^6F5%DH0lFxm1-K)9xadE-^WGJHb?`tRA9J;(ePh1wKTC@K$iH5# ze=Fw0u;H{FXnEv7bZUZW)@++&E*|usgi@&kx1mC*t77ny>bdZ zBi`Y~2C+m*t&A3}0!i(P7VQsrZZt#b$#QhLPy<}Raomy!Y}Ne=sy4mLd`iS}x>iHo zF@o!Dxox+iMGU+`AZiz%k8H8%w}TsRp6iQ5l#u%qj@a8TN)i&WtZj}NP`)$dxlhQa zY$!6Gx{J^ovxwJf{Z~Z7*P)sX_U9*6g6W1Pl>5h8Uc^p+!fk>Wcn4^){z%G*UTz>N zurp5?@LoZ!5XlM9I5}rEJJmLjax`mmd?MxKeCf!6SAMjSvj3WHA!d?kg4bwJ@p$KgEql=W&c~6l4Hfi;4{ueHX{) zA(SyM4ddcg8)qTuG3mex1nudhyR;pO>stk_aiU^sviH%n+vCK4tEv^o6!n=j{uui76nHG`EN>5E#gAdsM z9P1HsD)3DFPH%#8%_+@Xh{j{Yxq5^V9TN!y%9=Ln zC)Bb7@x0LVE0%zVHyt{%`cf9un?NA&Q>CX83pK%$^wFy)mdvDC4`(N3B%6Z>Jwr<` zzbADQkO+<_E)U*lcJh8&Ph=RuYRD=Vj?|>0>2WHDGdhSRw7lp`%C=U!hv zu#WO)auqMSDo}Ej_kKh!5i;aPBTNzyQRJaI$uxD1w*`>tBjPW3M@2kTnkz9ZL=kq~ z8~_9WQIN(!H=QIJOKxl_h$*9_1DBm=(`-e$da7^|u^4v{J&yo8|@9*1bQZ5?E*>5~|GA556J{QbD zycapM?flp{8&0q){#4aOpj0>_Qq=IO;vUeS7n)T5m$eUOVk?$pp^2Z9b`_XAs0 zlO1=Cdw+Rs{YEK~gam(5KdY2e1-fuDrXx7*XALp@q{=GlZ(jWx>m>^jVN>#*ZTu97 zc(LcRVR^;#ulMh#hyVUVg`z48h?Q19NQS@xXlhLeXC{_x0BB{o*5p5xUl-KI?*+Nt zM6EN`2FcLgrm^|r-de_MiO$<5M>|L{KWB6_dC14%gX(%t^oSE-RnKdw@G`}rCi>Fa zGbtp6?@=D76n!dn67PWFg(jElBrP7Sm!p5uy3X#q@K6{^?Cleg#IE*=uEy&wHYBy% z2slgMdS5Lx8E54{vFdQ-u?SQu#nmvZ6wN>aNhy?+M)QUX=>*(*<(An7t(bgP%hDz-@nT3Rnh%{%YvW~fi_?QxZ%u!YWTPW3hZc*iVy->fp^O!ngP4}oBKRItzusg9jTP$yV8B15jsrFj%w|yS2xt`lXrJ@j^&$~ zJ}Qb%6%|Az?M=_EmZxu)(!XM6p5U?yWadNPQbQgBc}-j{q%vvnanWTKx5T*Ut86_@ zvkp-Y03Wgz8dIZb0cqpyaMYE26@GHSFx!C-kqC{SNqZpxr*kz9j8I%B^a5Cq@2_g~P6^mFCpaO>JGHWVB4W>nL93&mpYyfpBadf~jqd+v|`?Ozd z7+ej=Z&VKZ)S?Cjgk(|~5te~XIo^Ix zxT8Iwz;86bX(dN>s`qDAhhRPp4W`lrZz$l|gGW_o_uO0OsLn zo`WJg%EDZlE(ZgBUQwuXGRev2`x&znpFRU(Fj+lUdC$@PRq}42+}p8G;?r+6Hd+* zYVUNEr8)U5@P|c?w{@5rI|zN_#jAF1Nr?o2MDyJj_}>jxV7kMP?&tMW({4dZG@r1Jd}_cK`r7iY0Pe)Duhr{bMRJ6 zXe)+9I~luk1?{_vGmj-91~{40DwWj_lM<^J*aQ{?XjEa@cKUEFYwU>OnJObNGMV&Z8&5JLdnxb{^YBgoS^JkJr^rj?<)VV})wX zG`?}S6lzBcKS-NlBcDFDV~>HqFBVVE2#&WWw75^_pKu?nROeax^?~+F)7IuBkEd@T zTw03sE$oKeisUF?oiU@soXNO2t?CyB$4m^3|8+a>JU-CR_#e0P56j*)y|m8{FSJYZ z+7v-f-`J6TJd@FP@gKG`=gHr1w`}Kqx&otBJl`obqN76M zv=oMTzoDMkLF$eRq1>Pae_0-oF+A0T_HcZ(mO;t1v4}=kN!TSh$P6zu+ms8J8Q|;x z*X?9^%Je>S^nTl!xAnRBUH8c+(zk49D2oHz|7AOMS*Rar1aXIAeRWcLISsOI+0OH{ z##~3CXKSz6y}Sc@6mmtaoTuyf{a$(gv`H;4AF~elk+aR>Ec-rd#{b)QgoW^k-Fl!* zxULq?<7W?xLdMm7=k%#$NZi^3(!g0=ExSUy*tLf%qakTtana14&h15+q;o*7A2~fo zFYtf1@1g#KTV9aJl%5oz{(irAa{U>P2Z{pP6KFKmGuWoU(iu4PE#endN0-?y77T6! zq6Tn(cbbGx&V|p!+0`P2fBcWd2$T>d@)I>kq62?+089L@#d!CJ*zZ#>_^rj*EG{HL zFL5G%Yca|;O9;*36(l6d=W`JM-(qxYWUsN@Rzyx!*K)Aha%}BS|`HlqV$?ymxCKY zxVwKE+5Sru{zW$M{b>@DZ@&MkS+=`b@lUgIhDac>DabuXq%{MHBxTdN!%F4Xck-w2 zO#ba!Yu`op+s4(Vl=SYuulnuc-a=lp5Z(0CDd1qQ28zT6?kWs&y&hz8j^dLzkP|i$)@d^`ZuP3s4ChRPR@_H1-*d`zOO!^SeR`$@= z=+s#Xk(8->d+6(KouAa^&wmHS5_rIZYcZ0XMF*A-%b&ts9nNX6m&M)rn zn9tu`T1cvGO;_-f}o1o$Q-o zoaHhs7=G??`&$kt6G=_P1^*Jin1ue4P~STEzZ?$2_tGv&)9s2qxSjK0Ck3g_x0`C5 zul`gm&D^PKH{CjbiCVxWl5a1QOGs-k%cZ<(PusTsz{$dWB4Q)w`LgTRyuiJxulcV> zSH2d65cBUBhSR(47e#YVI>P)OS7jIqLuokWOsOH62M3j<>27Z3Wck&H)wOMhhyMj| zjx-ZKbMsrq&1Cr0(Nt!Lia=#Ew)luSP${PFpW~#zU46l7JUFd$-jf5Ar0dMM?hy*s z-)b{a!q@wy84uDxd(}7JpLmUAY(wAOfv(iN0);#@;Z|*<2Rd_H0?T1!=P-qr4$_F7uu|FtHNkWa)L{>j*T4!#A)9Lt^dfUYT?KO@ZzF?+pe1>dG(w_~CYwNoVG?N+UsME(uWZ<7%0g>rY~iVC37da5XKc%X_)*Lz zJc4qSou9wXlvwg)=*PPSl}0@RXkAn<9dMhrPcGG)jv{QxYs{6yx*D_Z3+O4j>!(&P z2Q`_99ni+ANmsvbTfe9;cBQfTWtp48PE-Aw+~b)`scsC?7#FQ~Y`(V+v-@^Y_f@vm zo~Ftqab4A5Emx4^B73eyoTw?A+>^q%l7D?Z`?94{ug?FlI)kl~jwS5_9Vv>?LC`m2 zYGY6CSLbYe{qqW)&H8N3qNR8B0qs+p--3R1Z#^~+bmdq(t3PXE1J=A;nfiE{uhID} zUPM2O?+F1xTOT3j$ag>V)4xH|-hYbKPdarz0q$J&_pPaTUx#iHr%<#|=$pP2NVZt- zmmZ|WEQCViwx1|>3@O?TDG8II>}C0{I4FhAOk_xz$RDa{;D9*H*cbCq!1GZV?t>`5 z>8C2+0;LJRomrR*?6BYYJ^qZ$k8 zj5etZZZrz3rshHbbg+3AeyIi~yp)shY*;cpLhi%wnE6Fo93iq!pWyYg*h*PnJ4a44*c4J^ zCbVxat_!WZQY%(_=kf}1z1Z<_G4~?FRrgQ!knnGBBk8;}AlT zKugO*&m$nV^W_Y;Nt<6Vire(nhb3dS^J-2dC^Dy;ZqHPbkMrG0fW#VqeqHpNhv(bw z0&2#}uy)O3MXcrd{)X|FGPlHDt+7h&7&y0olGXq(kM@I9p-nvc|c2Az}SSyw8I z30P9A8#aoE;74qqT`ene7Qf?Jw}Q7-yrW&LIau(T)5;6yhl{xqD4nK{Q;*oIv&M*8 znQcsA?q<)1T9myD%UA}FbhK%WQ^8M)luv2BlDx??oU zi7Te#Ne2(T-SD?U{+uLZe<38k;}^+dO=`5N^6}XN`=pv)4#UbH@46qnNAsG!0)$+2 zGJN0HTzPDl_z4`tQGlYJe6b;!UXeKWA^9azTf4b$J&G^(Ir>)uowBP z?6vpCk4UIrp}lAB-Vq%S32ub_`1_P?d4V+>6%Udk34K>50xp^F@;l;E{tm_E9+5P~ z+;&OkbLtDgV7M4V9Ux%Rv=4tw^LZv$TDSS~6`zrQx( zyPtdi>RlDzt;d44GEnb%h=8`TR4c(JTO3dlAL?Hr;86D4rb)|N`(%sqmO>dHfbToM z%6J5mACt}GP@pP=P>a4I9+3U}^>(`5i@1^V+|X-919sN-8lDN89d+5GcPnx7G~omc8DiMpZde2}r}VIy|tUXfDH7B1MYYyWt^*L0@^^kWszLr3u$Ir?S=z8|J1S zei|Xw=uWWBYaB=|*c9IXE0i@e;@)_qBwG}ba->})+uL@!mteY|i_cXWp&>~DtY@Ad z1E|j%iN4K3f!^W7K=*%moEHt7g5@n~0f;dT1sLL7P!0eOF`wv*(`aJ0dk)@|RD;n( zhM|3@5qyRrR!FbvO6r^~O0)|VjSTT|KUDhbn?BK4CS|X-5U=?5ur}Kmp-1LY#&*t) zAac8?d}T|x-GiQ0f#R=J;+K8~fi$gymaZ}|u&8&7G|h)hlpHW(^77dTr+0R_Ph2wp z8AmV#MfGq~13Kp%NF4>xL^?x71D?6)-Fh^Yst*mz4LGEAw!wb>fv|t$FV1go$k+r+@HVMWmqBuS=|WoSSMsI)Ltl0DdDPyc*CmQO3uRui`WH1R2t z(kdkR&96v$SI^t|X=rQzt?3E_D>*+mO97$8o&`$-xH7zW7<^bIfFwWz#6TQUIn_Xm zA@}OV`(^s)(e%hbm|7sUh7p#7P07LnrnZ_!YnKiKrkVVT7dE4=`fN{2<*?vF#&M4M`@g^%g z)`V#(r+T(Q>TnS1Q-#Yd=(&tz$IAf2t=P8m+jyEnrY7Zrt_eqB~*8faKf)Gh79TODN z_^#qRyfDGyib8c_lr}8t7D%Z*bv;26Yr_CPB2j8?i4(so3FhLF%$}*t0Lpofit7?n zW&opwSG+p$ClUa4KgDuJjRK?VJswD1C$OU zlMc${Tzr?fA5a=SB}vNTo6EPWp#D8XJ@$)wWi6f8Fg^mMMdyj6svN2tQZ_i?hv_h8 zQd2d|ALa;E(H>PhXb%(+Y}c^@^Eb?4xh4gcb{Kj)9Qbw?WxLKWy`aW0qZ3#len!f? zT~AnKqdV}5wXq&t)Cg8?OxB<(Y9eY|=AK_DgusYog93*3U02mAT5a8D4qe3XI#~&O zwKJH2SX3oZ9nuS64kf#8=2_03$*wAeTR^U3A0Cz88bOlm$upA_mr8MqE0s-dq^{D? zKETuKlv+P&Shg)`>uk&ZUX+MTa%mc5WlU z-HrM1&;m=grgb3BFDdd*SYMW1gZdiXhnU8IN*te2J^gkgj;GuM*rYP>7PL*lvE4_0 z)<&<^9}rc+a0WMOgZvULcinqKC`$B*p6GasbVZs7s8r_&0|5&`Pp;+8#fUF$%K$K7 zTs4OYH;~K1(SbiPfAaS92_zQdi%NeKe1an~Qm;e4lBLHCtuPRov;{vENFiV(imxwZQZMuL1AMKUgSuJrGbJ*P*yjCr%S1Kw!edEFd+a0BNkooMKs=cfbAjN)9GsNYH z89>Jmk-p_w0Dvss6chSlGBCj6^i2R1usBI{UIdbOHS)YVNfpmhY&)2M5|K+HiT3&H zlo(3dS}g}l#+F!#I)=MZI@)r?T9Us1F(}PbeE}POkvm%0=8sDs%(E==`7p$uTYiUs zJjmHICd$PzWJ;7TMbmruL1l;*|9G7HMC$K1C1mZbN`0hfQ=0wIiXNFFDTB;t7o>s{ zflsP}%scH)rP>#@?2XRHs1Nidvs7z4z4W0o>!q+1+_MLE+}Bfd^dAa5i~69UxDa5v zO=E({fITVI`|>PN>KwXDGJvb}ENkfdQldiA&_>!K2#OfQmR%lpIc;L4i&UbVjOlW{ zrgV}681wpZ(f+MCFS$4%_;J@JT2%6;Hlzu?6W)VPdye-U~Q7-UO_sr4e1iK>NY~V z-n!W`&LI0FB;{llNMjI!$BhyPx8+-wk|cpj02trk$dapUVR8@^a-|dotYS~TrH87L zV_bOdWg#g~lu!m?RS!E|DR*^YJn!NKt>#~LO=7d>i5rRfkkDuHfX+3_xvNp;ao#}= zA|rAD6bZG|flYMgQUC`GMJX;$G0{$0W=CsASyOSpV!x&@#6RuVGqn4uW!NDjd4DAR#(CWRy(%1py3LKOR(>RBpMNg$3Oth|fxaKZjFT`7hlqfTA7n1_S zeoy$u0Z>WUlA+^;r-Bot1MF70pUqY%ys&IQCE|ZdF2GLK^7YrmIY8=s{e%{4(V*329zx2r^T_tMQX zcE{RotEJ=q-azFh^r!5XY}~XoeP|>@vR&>&O8b%$!P0uN{&KClY)tAB_uF2=*=#); z3cdG)aC`2?BT3+9H*iMCPVulVNxtzZSP!=Sm*(chF=+d3hMi(u#L<`4+xbLlzme?E zsJ4gK9&CeN3lY!Vf-lz&)oaF>Vo{zjSpy8t0`>faa-V!%sHQnhr`S&=tA7=m3;P5O z&GE0p7*V1dgcSTeO~7mlRyatSPH)|xwhlgMvEGHn>_wUAX>7TzHL@HEM=4QDL)*>$ zlxDNLEnqh5Lm;YinYxUv4E{K0^^TJNn6=$y!W zJh<=OvulMHulq%svw5jDrOKiT(>)unGQ^LP6JH7 zJ+J%QKxu&Fqzel)R!`%$q4`x zjFA8qe!8+b_L>x6sU;|TxO6SkDV5WD{rQw>C}EIWk(p_^PK#0bNB%N68vvB;7N3ms z|8`hb`)!nqQSP9qU+5A)$s5?gdi-W@_xZ@?8@CNSB1U4x z-o&if)T~k>NbF6G+ItkGMft_vBlg}gi=t+$YS&)1M{BQAlp3wp<9Y7q#r+>#ug>eb z&hK%2zHpFc!@xFM^NF0vm_T?3CgnWW-e-4 z@fGXyD_c0{^4gGCCYcE_2c(QmB)wxh#wMCQ%K!cE_pd28|Kn`+O!#4_d99kTV@J7x zQP57_nrnwWeVn#tC)}NG)rFnieYvJ%m;HQD)~TQ@Eo&lr(uiTdLs$F9@Fx8tE zCiUE$%6^p_gfNGrH2~|-gU(12Hvl|lh7+&};#3iZ4Hg^2QEFx^fX%2kolo8eJwK-(v5!5XZ?5hCtfQsjD%gI z#Tou#2oNu~S>>uCQD6HE;)8yHF^H|rjd;df+B%Gcl{II4uc`hngpVM}!eJpH#Le_E zZ;eZ#Y>kEFZ39auuYI=~mvaKiED4)ZZ@tOI+(=^OZV3Ru)JhY_A*!uK}8?IdmKZ&ZMR} zPa}ne&F?~n=~xhDPZDi)FPk*rtf}hp3_o?0l+v0UCF{41^ph~JnIc8Z^w?rM(^7ZP z2>B|NJmcFkbA{=MIsAlou%kl7Tqe5!4#0#f5#xQerIN4OgVY|pBr~7J=3-A`&M>0XR7&s?x6{z% zOLtcd(TEdrwPIjimN-VyLE&4wO5K+ATs5PjC0Fqn^ml zkO;nX4;DRD+cd5GSklZ4A$k^ZB;Ia$Vcq(hWU~^}$Zfy&>aw9i{?P>)?~MleQIDA0 zUg}p&FR7b?NCBdD@HCHJTHLje-*IDPx^8t-ev7-?t1o%i!f*loDM5Ehl8l9sh^Cm9 zXn}uCnEU6v`XC%Y>3tmeZKM@w^oJ*rI}Cp{+a18ic&7JX5h8z;jV^-umQ`taBR=S< zu@nw38^i9F$8GOGB3yFhG3jPR;DI3dIR3`ri6ObU-~ksy%lIyh1tT=>R64>u$28-v z)1KELI$ptBElo_S?~a>MR!WkBsfe@@U+g9l^Ud3Yea)FB#TleF@(Ow!KrAb+&wm*9T!TTFjm1)iMl63g$l=KB(F*1 zK3Y)IzRrjOlw7=<5nl=`JiJNP;kx4U)=MZ>2>z&r&7Gz2Ysqast=9R#)KTo$qPcp8 zNygkFtTHYxo_HRo3Hj$54-n4X2LfEo-Ck7CryxYpS`byETJOEGS$2BEIGsOyymt)* z^lTTlRx-I?Jcj)Q+N(ZWmrw#m|6PdieNQvS{9dUH(#2S=*~QQPVHZBQjMgvsaYqjZ zX}4VJ6%xs7n&9Wp4y#s~qbz?oRm=Fa)*#&vq6Gi>Pf(f*Cd8v3^Z&F228fN8kLq%V`U$)&oFLCK5LKzY$nUY(WqOC<;;l97fCXhNVV5TYDNCbS6vqiStJOJ=QrP-zCJvr+`fEaNBKFP z9sZs&_ayct@RR~@#`{+OM~+qCJazO%_4eNms5RzfMd239{;9n9v0_TtIQspIGx-zd z?$MmII^D;3wmCn5N#7Vaqy;{Z^mmi?zZZXz)MiIcasj3YO2Roq`|LjKU}TgItRKPO zu*U@OlDQ}0`k=~U6pW8kWYH4DcuGeN?@zjGe(dMaX!0|I5DXH%T*%&{bE3UP$~5Q5 zyxy|W^L}HtnPb^*;d}`R^D3do9pHe2vFWPRie64W(OrJ+EefaCLs+uB-|pc zoXM4|%j$Z{uDZ*(_jA`9DXNF!BmW4j803b^B%o8`^!nw`d}6h_9|$otR#w$D03b|y z0|xf;#fn)~sMsbw^1aZw0X#NyC*s@`fMTLi7Ffr`&X@fNJ*xqlOy%PL!fa3@ERCf{9J_rLc+(gBnC zOH*Fol3ULzzvaZB4nk6&m=DzhDWo6u2m20Z9hKgo(4}sOui0Gigx-XNq#VIYF=qlR z#dfw)x$OH*#CQzq=Vb3rQ%debI6~8`xFqyPb6?OwF`P1#YAAoN*Tn+I9uI*{TcJR85%G>T%Qv0?J)Ohy&r}G zXa}MAY~L-14{e;q49><4g4nPA2wxyX(&UCRm6F;t$!R&1cFZ*fB1Wfhqg(qq#K8RK zh_S@Sn!GuMKI-EGK}mIh@y75Eo9js&n`133J!I2bJMH;jp?!Dj$m9m@(M=kOafyiz z|J*`^8EMtP=aYz6%KoM@l6^iyHU>91kDJ_neTuZl4d=IIYuUjnFo?DqdqAi4jIcYvgKc* zF<_nLPDaK$gW98Hg|#nFGAQI7nto|gC?zUgZfhZEDX0X;>N&<(g*0QH&fMtEuHfg_ zHBGl0YM+Q|`Boe5>Nd1V7w}mMb#I(+i+G0xQLJlp2ZTm=d%?>$v!U*vhu`qkr*g->o~_EH%t%mwI?sdnH%< zR4lN5i~kZh2fS3MLs?s4_#`dl!$wbUx<$nOd-$AbGryC!LaU>^S7l3RqWH;Mok=ci zbA|3jCg*r98M)NB*3MFk?U03ZCChI$mVsrhlIL7^|BW(sZHBEC#Cs^!V@^?FVz9^^ zq4IveuX~YS_cOk3avN^Wp#Xo4kZW2ix7`=(-O}PTc`+eh#i^SCc_#e{u|8oNDe)AAwvi_4NRWPw%XT_{F8olg(QT8e6oGv1U^7~xuGw+ zV1rX7M5}m2*rX|iLAlIUnq!xt)r7JG0;=S99F6&rr9yZ~HXN8Wt7n30qo+0esy#2< zkxPNG1g(mzuZ6Urra5w88;cyFsNz$O7K;EofoBUV5ABI~m;ct;RTo2yYe=P&)&Z2n zd*HZX$d{*@XhNSv`6v>8@~u@xJHv9j*WRBE3cs2}?U-inJdMh!+D0`0){Zwt7;5!| zT8c}o#N(YQ!{8XH%*{kRESFT&II1Piya+bld+ zm;oy!EuZd2{+Wk&k%u%!!iyP=T?lpq$bZFVeUStUTaU1EV=t64I3P|w)zL5W()pb8 zb@7L?`681Py98xfUa#XmpYPNm)qFG6?hGT5)*m!pBw?KEaS-hxB`;$KiJg&sT#84D zB>^f-XrEZ(LvlsCv${CArQ0U9l~{;4TWBTuA8caK8=qi*KggI|by`C5<5zTJyqUGr zcKlj&#ZI~7Va3bCN-56K?A>@iIxKGapj*rPrM6bR+24KQ2a=CSJ$sINz7<--%&Le|yUj28R(yKhCf2?mbE-U%f3ncYyc)m0 zVP+l=bBY)m*OOotX##^oV(FO4{&OIsBuR-N&R=^+V2ow^B=bqZPi?K=yqGekZ*T+h z?3=)@5!~_E*ZV_EX5KyvXGPvo5rEBw>= z&VAYJ<2Bc6h-7@fSy({wSDC#rbhCHq$>a~>O%6di^E~|nFi0`4$jeUKG5JR$ctdY4 z^qOR#f3O?PRIieyrTY2r8`3{x&;QJmEs-h?%$t4UP4j+CPDR7IBD0Q97bD0>H+2y| zackCuQ7KBMJJ}a>+j*^O==(r?asJ7fA~5-ZqC&zqi?`3X2T3ZcPcBIYc~67b)Hct? zi&`=oj$upxoR`Vjr?6~&KUi*zPw_F(rr;ki6U{pLCt3q$X56s98x9WK!v7pOZxr0% zFSX$H9wm{7!dE{ca z=-apS=IJJo$NEADh5++*dU!{obsk?AAFwM;6Q#%Ze+e^c%N@<7f)`(Gi+)|p`<0T4 z@inW&jc)RF=Y`pF$fGbxyO`WzRX{hN{1FsoccuR zPl@~u zos)~w5zjAtpImoX$78Jb7?8dTFYL?e@FMj89M^u4Zi#=`a_K{_##+6^p}0^WqMXT@oB&wHpXx|(ct-SjCtja z58{=&*{ks9(_s|n?OA_1D*s^Dc0>!AYQ7}OVaV0{lf=x)x*pK2S&wv_zJ4E)%S@xh{p??& zNn#p^wY=i3uA|M{iO@=3PgscWFz}%Uwh{>wlqmK(a&5*e)+aMiAYStcAk;ubQELqW^Goh!*UuanK(Fz?TP=GY zych%s-P)XXJz7*JM2y8Y?!A3h7sz}3cNqjwjClFUdD*S3e_aby=AIJf+>`+V5K^wQ zxSo9L*CICWy8j5TFf1o&BL`Gh&-cgvGD9CcDlSlr^`Ii{joa6(d2WWSuAK<_UIn+69B@kWvPq+;-b}|d1msg+1%YdE5xBlXM&c^y4P*~p6%e& zpA)@aWS)&`nF$}P?wUtEq@yzWh4L)41mndC8-|hTNI4y1jIri~+L)+SF9w3fHV6 z|EW2R2AF=P$V-d58^sHWpxd7_%?u{~I0uF5H5kOxgdWUclN=$lBQ(DfI|$?+rtsu| zj#rx^?FAAySS`rVDcet2Ice;2v&4o-xU!&GAUaH3wOoCY_#vnys$(j&EhY7^$;sP6 zr+$2QI78?iFeQZMQkOvy2=LA8ZsBqfhX}N|H$~EzYl_`Z-8WeC9o3;o9LUVAoqae6pu)*z3s`9O zs9@eZtMW(E9WGJ*9J(_oj>pm+&r(fy2y(XM>+QKximE3u_r%FqtrQ~HE0l!f<90}EfzS{AMo0Q2MeYb+(L{KHzdUZNv3%bx8xR#0uSX*5^)Z?eX7+!dk z{2A{Y`fZzEByy)BG}3ay)4|#Gh2~*MPdk|uC5H;Lg%V+_u7_g)7e0K@i5-i?6O;;E zMGkpEzch0Cy$OevaA(wu&*4A&%})XiYW+~E8$}9!nj(J@{Roc}_sotjQm7rxzGnN%J2s-Xtu9Wf^luVrx+X z`k<(;uSlGZI0mMIwg>sYJSqi*hfLRQN6ADWLb%=al0DQ@3RV+}c)5RaCG=~C`2E=h z2#DVYAFDaJ*Q)HL;^;%>xFX*5J5iduGK6%5-wTfg&$>!WIbPOsXPuQl6IV(~Pc&I^w9h};l`9m}5 z;@QSXRc`9iWc^fF=(hENPHD!0gB^k?uvJA?QA_)o>-*YvfVC<$hGRfVO>Qb}{uCEt z6DvQWSnRzS8G>F)d$oDLlVY`(_2YLB2T6uDpJj<+O^Bk#fY~bQfB?KCG~~7e3E`mp z!QX}*QF82{<>qq`$wkdV|H>MH*Gg$Ts^Kor!?;ga)b1l1B&w-QX`DSg58BovEV+3? z!}bS?7;LrCsFDz?C{jf{0@)Z>b^@l&F$4EI(B&BKZ{{34Qk6cSCz=tGD(g$2`Sv`5 z1ohiKyVI7Y9n4t}MBO`9!D!4|g{H8d?gMuc!{5t}d)Zf9jw>#4>J-Y7DgE=|Tlo4= z?aggw{{3W3=!H@0N(h-5ZJ+%Ub%Ig-CP|L^*>XfDK?olyi5Wt@x2>5o(r?yo(jpcB z_bgFfNFU82PU$!0d-|oE!=5JgA+0sH4F&xi9iADqYT2<>`P@^JL*4Xg`a@b>6QpCb zXnf4DseES`JjwsY+`=neHpJkYyS{0uO_Bs}e;ojGM5n0FEC))5s`>;1IxDsT%(Czu zp(RNt^Sd)2i47?$eUH~6|6IYA_s6TKh!8?7h?m&dxavbvtryYN zEc*Ftfrz!aXWBG|qnz4fbA6gG{HY{VA#+hOTw;%*P?SYqM#_-XlB==Tqb`<2ic@yV z3n|P%(nzK1H%H}5=*!QhvvOT)a|Wn=i{EsmkJsnvI15_*aD;u}F8&BtfKyc%s!l3N zO{QFbDWC~HEA~!EuvqHRt6W>UlL*CScImi$I(4xs@N0&Bm-hj0;$`L;et3~Y8iYin zyS&Q7-xYf+bT zL`Jd+G&IDfb3hBzHN~BInNAiXf)N2hn)!j`<>q^cF(hbbaI!Z08EGVMEpiXqFK;t! zb~J3-JrLqO37|_cBQQ1L?vYvS4IaP@%=+iaTSVxysN4QB9tuzeC(&o7<5}dJL#<78 zT<#(dTM&9n21$H#YEV6s60~(K7-)|W5+t!^Nz=EcwVGS7c3#A^qdS9*k+kcvvp{or zQLv&h`WEz;Op$I%S?FRK;j6|BT|;zdLucs_>+^i41h8I$3~)+;uD4d*Ri`XCamL0QNgHxv}vXN-+Tg= zNZ6=>F=0=fdj+M&6>Ezy`lzrt0zgKF%9PC9@2AtDb6@Py9FK&)NirFnw+y;8-vKQ^ zc(F<)S#W~j4J)ers5vMpGjE)8;ufVJ9ZVfipJ=46nJCb%tnd3d5A8`KUJgF>jFYN{0m7Td^-xdHW%xyL)UDk!dfq&ns?Q<2tJZtY#1_785I zOPtO>Ze7Ci+H0UmJAy+s?mCf?Xg{tbSp7n%U#}W6xEiKMu4tgD-XddI8XSMwN&dKH znBSe%1TgkPv%h~ITUv)NuV_RSgOEQ;5?T#Q`8)QNAyR%(`P`xyP6Y$L+#C}nGfkP7 zcNuHACa>cttHty%+0K8S0E}DL6Ta|G3bGg=f~>k~766VN*q)8q37hve-}T9BB5jKk z^-V1ctqp-+SNJZ%gEXU2I~1*OBDJ{k<#p(?U_|@JaRjp4da+le9Y1*85ge+8N3P1; zZeF?6Yqm-)kNxJiUcNLy40T3;jnos?S2WAXXe27>FcRrKq0Gt&bK!N_moB6~dkTe6 z5pG|3N=dul&<@)wjQnd<3mn<8)K+_^$T}hv@?4WxAbYd)SGZhzMao|@GJ>h+Yel$z z;^iHSH8hMxCE{L&GY;DDAFNzPjrrE@S-ywv{?+8=?|(raAJEvh$H(nrWVw5r0PeC` z%YY6(#h7od+j#r~E?ca76t4x)SvIf|f9O!MA6pjc-$^0s_Z(rF>bP?ddhC$fKa_w( z$~%wJsR9s3+9#l8#Msb}S>^D4-Ab}b1UA&CNkws~4?Du6V!i^Oqaf zlt1IIaIC4QiIevnz`~K0#{16{_)Q}C>lz|iEvhS!R$8{OisqVN(l6x~If(u(^QU|E zDkJ8|BqeS|Xo$wM3uD`-TERAHzZvP{jgX93AgJ*(Zn~(EpYLx8E0YqoMO+|~aT9Ye z^f{7J5B7DSt;_lolp&3P*OJvPm@(aEdy|lDsBU5*SaDV;?5^J&L?aNwOg7(3+_Y#a z`sS0FTk!4FOowR}T#eLc_~Wl(vY)ncf&w5)a!Y^0c+|kUM>fh?v*n+g-t~;fM&yIk zLG6|ly}8UVp$zYxK?WHu!&$n&E*u#1@-JdGv|nZktP^a~?rDtT*!w}G{J6g198eCb zhyD27hIk}`H=8k6Gw6V0DiY1Os7yb`?n$a*Gw6nH4(t2pW-PjSzx7XFmvP?ps0S8} z9C`3`!cdEBG*p-2_nb0wjECOoyfPNCVuxsPDgBty8&~D3{7prz}vC`OHkH`gD`FN0HSxP8uATzmWki~V2^^U zGENcD-_2W!O~Zi1vK6t(BMcj9kz&!WV%88ea12vw)B1;?bN5z<*5NP2xEPH2mk0eJ!t1u1lBsjb>CU?)+A_Zz>bw18ynO*S2fxF~c^e`OMjkEZGV-7?0 z$V7-{8uL+EW1}q9!+%pfP^@A{kAZagOh^eCvkT^g@`yhmV)RSovCc%nOuRC2><0Kv=V%^FVQwGiSU`1;0j&0?cConOrl{!(`vzPES|g!-^0l!PIIj)S@WjRLvzY-~4%f&4k2?y-t^=WH(3`|CfgUW_NxhKD;^%GelvuHYRBJ}+NT`@Jt62H+WIV6Fb#wPCaehKru1hFMm);$=1fpR$ZW<@$3E8{O#lCe&xA9-ms1b9W7_hk7 zt_3Jz{|AljL4Lnz3}VZq8UK?nm+;};##j?>Jkw@eKD}7cG$Fl}{XE{dK|)zrgQ`~& z@_ml2>-_j-4}0(Avnmh&f4XU;dKRFOoemWgw)d^KclN2r0SdIN9_B6ydhRj+b;aFc zfJ!CU*cB#goP<^UO;k{;t@W*kWkk!fzp?UK=zL+d&}-YRI6i;DcwOE4ry`NrL6?5# zv!{IO0*Yr6Ka-$~11IV)wFv3}^F*;?e(J#{kMFnbaSi>eD~nZZRTYSKu0-lK^h5qD zNt-SL?51l1!2HIUXXLr=c!R!keKJXyZzfqU|48Ae9&hWQaO=E+0U)pwk6)6g7bTAe z_SRP(W>00N>0}d?Jk{=7$sK5aM*xWUqA_vu5}9jweSxxmb8Fk2t|Oaa(ve{$o9Pq) zSSh>usBZL%sj>y?z=+(RdrowDHhs zBH#UPqZgN{P`h0puJ;7ea=I)*aJ~5d6?_xP!-?pchXTuNYw-^cJ~dry~WEML=!wSfake|yY|w&L|LBpk>mMnY4R&|bg; z!MI8F*)kf^&i3>-cek2Jg;M=wy9siD;Y)+uhddP)^A8QcJ8AC@DU*;!$78f zQQfW!3OL_yI_IfP=P(+;mn6-M@@;~NrPSv=`Ffgs1_j=-^waCu*BCMv)sDBCLKM^1 z>rJuRmT|&_iBv}29RW@JRNV~GJXv9h6MxEOG6gGj8)#Q2trOb!pa-zw+EELs1$?21 zjs*X68YhN~A6DI(@u2px!5s4_RrusMYBPOcs!t4CiJZ0*0oJ;l@No}-0Qsta?HidG zL1}UXPMwBv&lp^DFd83s6(ZEmGCTnfa*y6;G%l~~tD+4(-iMHV7!WCUz-CH!Z9P%g zR_of9u;-?7 z&SyqQKvMDi>?bb1&UxiLQ+KlQoZM;Ie-+9inSexP@>{ow10Ddp-EiU@U9$f zs~id0K*XyX%3kWmjRR^&~=9ef{U)e-dudFsiur~y0%H)(Bc*p5JD5i_g#KX zqPV)R^3p=9M4kE$kSZj*(!o^K;vx{*dbH%TQ!__c!Hgy&ZV<%@p$)lGo8BALBSH9s zA;C3%wp7Ik7fw`ITA5;+epN3rry|p@(me1CWIC^= z0iaO_+CTuxAXNS>S8MhHooOpIAc(@<@M>Pj)Y9&bxsa5K_{cT3w^N@y4E>;aIfS&V z2s5jJl4)3W8P8H_xrfy~$=f=m7?pRC64Kz8VWpY;& z^cr6?Q>Y3Ji;{9>5nfMawu_cIE0brPekGE{v?{T}Q%{W?w+HdFeNIeeiZA{ACf z8~EVvV}j+wZv`~sviA*P1F^N8m<&uKbvBoN*kUPdus~dCMuhB-*{xU7+Nk*CHw1B* z^i()7*oHa2CmCVAzygoFgO)HpYW@=4|H4Wd#uTeNr<_w9zH z%=}4=iT_<&23!7Uqve@hDNXL|v5h&VVVNG}kI;)QwA=1n-A0+?i1UA=T9H zWmjIkYjWlirqQJ!)<4WKmjCaYNApfHhRjo!M8E0f+ms=d^g{hRI{4(` zMbgmSdI~2vBCpg-FQuXVYh)zJ&2H}f;YcH_Px$-jE#-1pK`eSe{&L_*GH}Oz%cyeb zZ&+G3%@SU-_0A?Yw^wI+0FS4+%*2LSPb40yn5JGA0l>ldim@E34F=h_DBfnB#4!)V z0n;~r{p@lhN59gQXCMv#hpja3o{T97A;xTx+e^~|o${Z_)B$9CYgHrUDZHtO&~4pKF10uqcLqlS7Y|QI2I5~M z*47?kA_yNvKl+)TeoF!jjAc+6X-cMc)n!&&8d7O($Y=pU_!z+x(m%xv8U-ggD9=mn zk`7bf2GPkb&<@huQ0Ho14992QRJ(XmiVe&fDk#Y|Yio#6%UwU_RrSBo(2tzS3*N3& z&S@{R;8ZG}RCkn&3}B7gOjVu+(aOkzbgdl?o}8wmgxPr>0A_iFeUY14-#!!4V>Ev0 z+NMxf)lePQ-_2$SmsjXj?{`*J#N=vM2l4~+vUAjOhu~>VFN|(&wd+eq2C5N$wK|gD z1*K_?!yMJ?oCd1QwdRa$L%6WY90#>vFqz$Gc4ZOu{v(2T5KQi{#ZG4pr;fSQ~+M26lR}tH>Tz^^9 zdd2hnzq<1~>V17s`(J{$)t!H=enrrIn927HdgZnnY9?(6j%ui4V%boQx_ZEP$^BNa zcH??}TW;eGQJX?umRR7u?*mQC-KwY%hGNK5@pNy*=)bt_?g7@w0` z!r}n`-*iksv~t*SmXrA)i)yzXS-Cq6x9)DAPtLf<{Lz1F?gDXjvJ-x-=LKi^PVy%? zlK~YEL!4^ne+-)Tyq`HM2@q5z`fd;%DeSBJHg6SQ024SP6YQHWVtJ6|m>Zp6-&E;M z;QHq>XSQ7GSG8V&tI^xv1GT?O-YBW`ypyq?ZzV2>F(7aozB|mFrgYrLk|u1ZghRLr z545J#dV)+(JZc*PoTxd{UD0w|S&*hV=@4GKy{|KT=I}+a%E?zs=v3 zY`YovE__jNlD{W^bu$^uK%?-_m>hjziOR=!wNcoJL|%W zn&$=IEw6qr^sv2bOfERGe*22};eT|AuR4!BUj13wV0+auS#bBmuT3tYZ2d`{)8UTk z)rX!FHvb-q!qZ6U+YQR1>&WVhHwW|fukWuZ^ho(I9Gr?L6^e2bCo?_qGZ%VdC71>352!AQh;H@u3a^E4zsg z?qBaWjv}}LafES!9K!gM25q%Pl~R-A`E<6NkMYDr!5LoQoCa{gDY&G;lOJEaSyAYM z2NT8y(o-xbQmWF$?@PCi4^7nFjctE??`&}(!%<9;x^y>eDhn% zQx(!k8R^Xs+sKf3#*jj}KGe<1GNj+E7_L1yUs%C$M=BibW9w&D<(aSQ>t3hSx=ycd zgGYs+{$0YgEJAgMBQ<>k@CYRfwWc`?4F-zsGU@(GJ!3l)WB=vNe$&YA=FWww-Oj-3 zP!C&^v|3ZQTGN<;IJ>yu#ZLSdo0Kah|CezkbAYfB;6)K~wDhr!DB=w%lkS(0ZTH3| z!8sv7&aHkL`ZBh_O?*ikeW}Qw&f`CC&v?DXANct^2xxlnoLG9kR2Hgi=h?&dUk~5$ z2K!NoWP`F@(hQw?kH*YrL3hjsUxq$>B?XYT%}-KF+TTd_r{uqRF3c*2PlPSSnm5gI zC{6M41+n&Skht)UkMP004K*YDZi^0>nknH8yh~RE;$Y+40C&C!OY2g;N*E{3k))e2 zWNJ)waf_o{6{7D=i&sp`TG`iXo3;&1hxXpUMDVhK;6gizTC%NL%nAL{58W{R-@22 z@!camsm6zLEidFc-pOr|!y;3aI5}CxMn-@0agKCzJ#CVGt}KQn0D#a;T%{;OG47K7 zJaI~)R+tln3nOVpB_Fn~YM)vcce?+6T8>geiPl$%u~`W%Oun*yzd0M?qzb{oz!JP# zvrtDtn#VJ?kY`+B{IdvxDAH9v@SXL(1+JLSr<648UMtn-3+rQB4J8%4u%{ywdb`KE zkot((3(DH36zPkKZAK+tpwr7|Vz7qD}@!Nx>{FfWmQOUP0c+mN^W zGH6LY&G37g)AF>SrD;|C`(2@*S#^t9-7ho7izJrnPlM>1AlhPNO>wi35vAa<>FMtV zNy|Au%cT~}bIOKW{`#hoi^nZ&DRY)*l2(qt{H$(%ollIfHEXO8yA!(*0IabDl$Z2~ z=(FD~HmsL698$Icsp4%RoC0(lB9~e`FYRPo?c^`*(!6hwQVLiKd#y`*Jt+sHmkv)` z9V{;$Y@{3uwBGl6SRarWP9(+V?_VeBIJQ00BthYUzGrcZW}OwYJpmPAmt)$FwMf$N zY~K&oI@i?wu4#PqhAh|Sf9S9trhHod+F_pv-QCF??Jnl>Wp>O5mUt|J-Stl%nfmxZMgc=GYDsg-2uinqr)%(4TUpsPr z^c{W0M%n!wl9nm(zF?x82PH6=^ETKr;M+yW^AN}MjW4D9?6a#1WIv%3RMaJor?E=5#|G4xhRy``Ok|8Z}ROXUF;8RGCJR_Q&^#$@zC~-zGkOxH?(O z3BP04>FT^We-K`5HQDv?&(-PHXps*4apvzI6kAz(%d}4jhUfr*_O_B8uqa-Zq6dfu zqsFOpJ}$FYdp_~_>gaqjr-tZa3Q6yAF^zH}GMmwq)}n=Ks|E8~>zjD|nlp8(`_-oB ze*9~}I$G#*(LT$g^|fwMaErfd+wtX!7lp3+s&AXqi#7kPCGD?4S9ix(-$ICluh#)5 zTGtyf59+Ts6U_at02CD`N8jt92vuj5*zM+h=P}1e=Q&e29CO-Lzq8qS?&26Ps4_rE)%84YowO2?!#5GT zxCQslpWhm}c_(#$sdAK1`FIblQ|*5a+pelqxB}nhz4)1iy1I3u{Q{o?+f)`4-xmj{ zN2a14U6erOEhV|sr~mSV=F^T>sg~i!5(6*FU}S{VsGYv?lq_!sd}`j`aPZVch0J)h z_N(88AwMrF70GIJqt$0inT&+LKd#ZIe3Tz6{3}cGii|$1i<=gV$xs-OePgyTGT-|C zSKSl>V6j)Dl{fRN-ak@P?R>=Kh_(6crVo)dk;Y;#(`7@<0O>=@2v7&_Wg|AR*8af< z?kdaKCUCdl>tZ@CShq7Ezr5D@>7K#GksqK^{@B?>PId~wftMBO;pp(k?BWi!(b|>) z@O=J(AS^SprQ5>-0|bb!Zc<&f4@l}zx!WXLErxeYdDP3BH7>53WqG^gdAPUbJsnCU zBBpXC@LjPX`I&mvF?Lq#`gZhd598IxA)b1l9Jl1X!mO?@b)K)Um}M4bEW2b7Gqew(Q2zD@MWLw~THSqSMU;;pQCL%$GGo`Z;q27u8^>dN)Q6UN|3;aqhve-pRc#M5Aww<$R z#3N)pl*%E#(}BE^G|3!EwPOATc>sf$!}e06A8&-Ox2rgv+8x0`OLjC7v}q+q{-9&mBUxHGhegQk9?T*{%WcmMCv-a$m+3EyC`R$ zyr;a|c2h4MkRDBS_=n>4?bc@9A|g>t9D+fGj=S}UrXH-?$UQ-WweF(wK9V$f7l~4B zk55Zuloh$oQ@U9=7d?l^_1qkt=|Nuu$!~Bv1(f=B{grsZel-N*7GUfqoq?_*6nbhl zNnSk9JZ0;gwyfuOcMH?y5`p`Eh$R0O2sp$;OvQUZ*`%ewwN;-97ytbV1;JVtb~}(k zn%ECl`7l@dnDKpuke0GwE)_cI)af_J59=_mUrS;}Cy07a@{!NDlq!d}zU$w}d*t$6 z755rEV$X-stBT>Ghvn60ehjbpbLX(Ddy93R0{CEE`2OwPC8`rh6Z?+i;$;9%5n`T> zz$4WA1!%sOsp3$Cl`?A!+9{srV0}{{$mb9qG1|z0HcT~MVE7wdz`Az-K$MhTl%Dn> zl%cG^eM(S^bQq5*{?t7Y#O1wU&UhJ2RqTgf0&X#WW)EOG?C}7mfVNw}0Bb*cYYNX| zRo5(#o%ORiMF^x=WvCgMAvV+N{(t_Hy$p~-3T16Af zGV%s#I|uD#fj(44FiKgT@A)>PAT7>LUTToqZt@o0_%RgZ0UuO<3i1ySF(7SuaDQ!T z6f`bl$qb;>*R`r&#E^dyeOW{oUZ)Fbn@afCorDK?wQm%!er@mWoUq3C9|MQa3py;D zu~_K-BRiPGa*A=_Mxd8%9P+0M1xM_|#(<~WwE5SR^wm@vWmwh}5R`#3<=lk5F+gC! zpSICcHyrdT3!o!S0EB`>Q$xgjBF|(9ltd%C?nEK4q5*ljm}mWj08TYRvlwW>Kv?T~ zm>4HPy|o*jZgQ2Q#+R?3agQFV>snI2!l7eiEE*;%Cx)G#PKqRH4QyL@`R21vS;r~D9bKX7g zpEvjO+~4cEF1)zhs9dn<;bW>1ULXny-dFdy{rbIi6cF9vL}O!^Ad-r|tnEOcFn6-M z@gOHami7g1=P>T=vX%8}j%5xau_F|tSs9~qC^XPXTg|}G2j!r?ZR+H~i%3~I@C>Y1 zX&QU@Xg6wx6!=&8RHw}3!KG5Mf1rqJB37Pc}&WQETS<)q8L#GCu>kzT1)|nsAG6J{I_};(@y-! zyL3ZdK**>|>qw@24Qb}B?aV&ruCx;ZWot!vY~O)93i6*`ge6%}xxX4aaCT!ESvp;%3pat0c}2K?qa~S|sC&;BTdq7jvT-*n^1!h{plS2C4@wvR)n( zvDJW)GEM*_`2CfCruk!03ljOU2v2$+QmMRy$LV9Z(y^T~3CEI+<`P4qx(7~mJac(+ zPiaa%XH=4w$3 zws~DgGKjl1{lK*StnQA1$24J$BJ~gB8;J>kXEia-fC5_qm|q}#R=WVyTy+ItKp<^R znl%0~7D@7gowpmx{9<;2Akh^P{a?2cOK>|O#$_PInlG1#*-LgTr32eIm;?v#y9P+% zr;Y2`c1mrC%c4qNXEx`vRG63v&xp zBFK0Ec60!%lMB)wsCx1<#NV~(Wkqm?j>oW88^NV5v)e}$2lAA+)wCchAIa%3yd6hd zU?NCu8j(IOd6{=RL(89rf{a|!Wb^h%YKco*UMtsq-08yqqO1G#`T$s9w_|bHfe=+8 zgGAMJh}JI{rT@WoC+(EpQ!Z6-cr6Jmg^ISYuv@@?x#}a57)#1mG~q3pur^-6Wn1C# zFZbmM7{R~Z#`9T`yX70V$CC%JR!i7t%#8@kYjnhCk7uh|l|HnjQO7OFFwV9%Hc(~9=p%g07*f3>?({861r*b2@--EEH?x75L-DVK_uPH9O*(F zgNji-7R_~Fm9+YpK_w#$?_kf&rA*GzpiCXz`{6;dW(}S*gk;V`5(R*}Y@jJdvqFvm z@cLbhxjB`qd8oNyj9N7*0AB^%4%%4@)%b*qg#ZwQ9N=umc~dk#8en8-%M_2k&*0&W z&DLiG1YBjQFnT~yxmPHgwC@$T#A?fq>N*{`#zsdmnu^B0JYA>eVO1wi8gA?tGjoW3 z84HTiF7}hFt;i4Aa_`Ye#-PJ)a$^eqxDOCvjhb84AaoRe<4Kjvg@8i?PTdadLsTV$ zLeag9z+S$&p$vyo`<)(MD|pr!wS+mmP=TJRWSj!n17)93a`1EVw{{*#GRFE*9+54A zKm~>)lfkh(Ihh3i1R*!7AzrZgr^40hn73nvXj*`=A@;3#M4<>qKYyF50%ZuYpaK7N#Fd|Km@HRfil)}ruJouR}HEm_*%KVuog|71*s@^Id$VUrmhJlr_fLj{z^Y^6&R-4ZUqrRc%(@ z{Tl$?T}3fkavH8wde&zMp*1?uO+J(dynuq6{x40LBAjQFML{0keHH*Qvb&|E>W-*| zQW0NV7U+IbH3E{Uq~ANii*7aLjM?k+u>tGj*2DTM=a#`^n=qyo&vK*$K?gW6YKF&JV(yo6YIas!dI87%DUuNGM*OOP$w;ffg61uc#t7Gn@@u%4R9QtxWKq zC3DoxGW|xjT_fxI`N|1f5K+Umd5ub1G}y|z_-aDl(n z?wVGp8Y^V+t>DP&sbWO%*c0SEFtvR=Oa1OWOn;UauQ|VuiNw+~{zxl@4$o)u0K~;^ zKUlVlPBXN2z!)_fzzw!3`{2q{#ZM`V6aw;9sW0$^K0nx#W#d$&++A@qM?(NZ=kIxo zL!L);Rl~ZpAN%~ojX7<=CEIsacMq3xm>3(G+w_^hEGjwY{I*Lji&gYdRkUc4H&khu z*=^0&e|=>SXvwx#a#QGC3{mkYoT{1ub?&s^6AZM zTjqNugG0x6N{)5XK1uAH9Sxq{S3XB@Vb1j&&gfn{hAUao$GM%g-zh$*SrCu5nzNjH&)5_e^=_{_4x+#^11cBYBaky7f zR+83lv$qO-BH-FC9)F*If+$~-ragfUI{Eg#qH_?Z@$DM@5fnoF3{$j65f!gdX^(*| z*&~d}y?t840+?)N(A^S2!SrvHM*;BS_HwO-dvw|HYd$Jy`X% z2K|O(J=Nz@#7=-webuY9y9jit;4l8D(;B{4Fg__?Sk~)?3j)0yRVOKSt zE$40-`F&*@fp_I9)us)$!a6~yaUeP)paDi-XKbuQXF)3;OfU-}@AENb?l`dO0)1lg z>XB%e>j_s3;NmlS+>jRn@u?}uA@JWeM-Inr|B2N@EskBY7$up4-Me{Cx=Ts7cK@bJ zOsA3Ds0x6$XG=~SfUqTvW2<9tf~%_&3t_F|BF>XBU&)8r(VMIOPF%{g;l*GI5c}i- zL%MVT4wWmd>1p(9qG!*cSktXfaQTcZ7K1rvn0W5^G`Sl`qUH<0>=y2g z3+~c$L5Je|=gS3g-&LXYfArSs`d5@dqH2i$tMXD4Ri`s|ZdR}S8x0j5PD(x>A*Py^7m-9j@*AxMRPaoi7;=&d z{Q0UQKBFeCFS+d8qA&ds6C-XgdRFu4Wp=0Aw>z@pENo?mzyPDDXB%yi%ImLV&udY1$J~ubFq?SqX_dxYBZ4$@NIT+Z# zbFN2EO6e+RCdvcYmD!nQjVyk0wOcs7P?xfF$#H47bbT57qeKkifuZBq3sIP?5i%WW zQaAPdohc?}!6MCVKvqgyVHo|Jt|Bm+_6NbPzSUh1v+`DyTR~AEr@Ch%NyeUt2ZR_Ud`@Izqs;7HLPCO6iwv6Gr~&9Rkhp^2*-Bs#4wfmy_{I8Jg<$d zX3ba;{a(=rxj5KJj_hA2j0h_nn-abn^aM=aA60lVC4JQUA|=da+Du{Wg%=37 zA~`K0FDM;Gr^1&f{Q^pQpJ;%u5tPj*%Tp8Dp(?Hc;&`w=T?GJ8k5#Bi2HvQZC{NRL z;I7r5H-KDa-pLvSWX~6}*00@@ADSwS5a-2mC5cd#YBp7*i@7H{d;PV@;$KpW($Y0G zvQJRNwZhrYtP|KFgmGN%XC%Hq`l>e}-3vrY-IKLS$#>Kbk& zE5b&IJ3(*__@)@`TuE&C2?h7a@S@B>p0!qK8~-+XoxND^s_uyn2{37sdRLJ2 z@ls4@EmxlU_=!#mQ+Zc@OtIP;B8KZJzF&!kkA&OKimu0EL~5z|g>c%9+#4wik<0zY z%-Huf;!Mr%TWXT4@_T;+Xe!?gKT>lme6$cjz4QjP)b=XY+sj4XDo2a{iBzliQ$0uY znqzF6z=XH2kM*Hc2+|Tm<|Df)zBEpd;-|t*^0LQl9+AB7CQH6& zBvNT`qQdj+Ajs$GVvp^Q_T|1WvAz)pJ$N4X;QpWQd?T-V?EXW?b^fD6JOz@cT}00G z6KnY4=z8r*X5QVS^6`riRDC(|c$p8(!y6Z0>QMLBeCRHTe}cM#<2lsLjmH3x1nxD$ zKpl>I(0jL$=3czi%3FG=Vvpy(8!-M*8E!$J2_H97=Tg5=ew!@vH18W{=L)}=ajv+4 z%nF5r6e1+hY~*WJo0p3~NS;u8u7{u(`(x4gu$Q(PPx2RgA4_#G8VD@8B<*?alK<}H z4W`!tYrcQY)k-ufO(IS>A9PaaS)Gl?Xatth^||#A%TAiHG39K!Y)H|ed9x!uROE2( z<$g>K#l_X2NrgjH#xG66?Q6zx?a^sH|%N>Du+Q{ z3Dz%pPlpvt^DZCS3un08ADO`xpS*-@?m73Ytu2oPx3>u&X!@?MXY>DP-{c#_h~#uH zxYCDoFZTHacl7l6We4@@bAZM4zIHqsPoULOCW%@YLH$Y_=+BeQyGcP}IR_0g%o3uJ z^M0VkR%NX=y5NgB0Ac0>Wov5HKfjQWjlH{&urbM;<;2jQ_4LoHExeXA#@BD@DE45~ zrC5HXS`_U~=&yMEK+;ngS_4)L$GewSK=BwZ;O47LEH$kk8<^qWtrm5N9&JRnRm!zp zelMgeOeml+Wb@Px5}rS#zYXnEn^sl*vud(|h}W8B6F#9;R_&<`F4Ci+WX|K125o8?u!^s03zYLX|kZO$e(2s^gEG%!M{8aTGsc+WBp1Kc- z{v8lsP_AWQd#J#y6yEMO*~xk;O5!~7IB>)@W+ckK`g}mhN)Y;=CA5^dc$hO5gvaW; z#Xh5fYpvsqSBL;BM8;6fo)$PibTp)pKJbxdP$sMaGYc}q-#Qz>L1&1t2u-IVKgj%H0XQk$E5BAF7QpTLXFu9h*XQn*l<)I&h@NtREdRH zK#wxET=VXwR%i;`7*E6x&>VKt3J%mfs_8XDk=M==8QX)S2uY%yV~?80-b_f~@N)57 z&}W#_5i;O}ujK3Kp7okA!sI7ES@mLH%lZXGsa#c+eMyO7C7!vJLdO(QU^GJQD}?0 z9tqiqVP7535sD=|(yBt}7E17)0&q8En1SjIO|c7@SgxkA{qwk(F=JFw9dp|fOUp>g z$V%qaYWb9R2Ky_dtO;X2Nj~nShicsQ;Ej?KM2wGGt{GA>R}uRfY=9&Vfw%e_B?wnT zP<4>bCzQoB5?09Aius;MWVF$k!I4qlS!)kKMu9(f`swl8Goj?m1(`u7%6JRRZRrzo zFBC48JN;8A*{V*MAEm)u2Z_;Byou2Kx1a>j>OFh(@REfGFgyAe)XG-CBaBo+WvCFH z_bV6ZFnh)_x4tC;w;Ef>0WA80@qll7lmQoD{#F?&tReGW%nT=VwT=QgJ9_37zw~5Q zRNkl-39-TteZ>R9o1*!12Qk$x@k-VY|D{YQtCnz8i~gB}7O&&T!4kvO(1Gx|hPOQ* z6%yy3sB!l7TzX9|=n|QQ&baF;-T=1yi=Xr`KZz}TlrS(a!>0Q#c4Y{d`%rm)Y;e6d z7^$z61=CF-lF~Aff5-LGe^i1_A-iPq2GmkmsAlX3J`w_pB}j=isbP?KfZ)4knyU}W zt9@GqtPecF{QRcQ^^kj623AnIZ}pU<0?A(nUceTc-8fsD>!xfD7HzKVFpr@e7%z05sgQgz z<_hGf5$(2u|2u8`iK%&TT`g;N5zCIK2PP2xRu#a_;zSG5u1*o3>SKjv$iR4_lMK?3 zzX*N?FlJypgGi!~94DMBaKp)(pwH0yEDfim@aFlk_!W6p68TCpdE)cike3NjZ7Wz9 zN$e}BNh)T{j~Bpz6&@5IYFc9Zfkeem_>2u7@{eT;BPXIUv!lbxDQ@1eTryR%Eq6OG%FVz|HL|__GI5L(#36QIp+-YrkzOQ9A`$p($wcHd>Xo+>{>3mk9>qUY6 zs=EH1MIr>ejs_S%mssRVSSJ*$x7d`&BK3vg2@wtJ)hw}1;PGUWDXH)4>FV<)=Ltmx z?bRDN+d<4~11sR!pZJC{i}27@QBSb;_FNPa%*m92CEmPrO87}Opu{Xr>J@7$Y+36C zjsg(`Z|_1y5=_dNfm?UC@0^6z`U<`$ttw}4x)??UNxELbtVWze;J z@##D~;%>N`;*GF{X^n_RI5PS_G91xiOnLH>ud^b+&`G|ivR%JVDs`w4g{98IPCM)V zWCf_Q=EzeBTg6ab5x3NA_L@#&j;DQnOy>|XL> z-85@^{6A)_3JF&GOi?riTjMzE%CvaP?~P}vCmqnnjxGPzQ+(vJ-QZ@*rJO2dFml7y zy61BEIQY=jM~IJICy|g*GDRv#Y8K+Y1kA+=ml=jzFd>n|<|r$x4ejV!Szk;8{=PcB9-k@f3>qNm$LCxZzA`-2Uu;K@_gpv z9+L^075U;Q?L{s}+%-$xyz%@vb}>*BNiXv91Mr1TLMO{%kw z#X~y5;hiVp10EqLt7@qP{fqeX;bqg~_$+ygmf@AGGF{FI{!~297 z`x?@=1!`Y{uCz_!w%^K}^XLS{R(r;-S4?U?0w9sD{3w9L`JGr4X1->GsV;Rx7>MIy zi0#$5Lz-k%p)${wW-5x1d@|o%7Jay2cE0%d+`sg$U%75Fbc5-DCa-hrX)f9(z#%+$ z?FP6VV+A)tQVt(=dv`MLJ>ABs=qW-iAV^6WBSJhsT3J)rLBlIve%p?T1yH=cZSmy2 z*^1qRz7^3uqLLzvQa}MweMGx11)U)6i|t9oeKn%o4V*p$@wJlzQP!+%kkdyb_+79g z)UcQ9Yv0j|#YW{?mE6iF;?aQEotl6%E#wZ9#%I8`OaB$4LgNJAz+R!!94^}Na>M(4 zSczf@-R;d2pwwsd~6;wRk7J{EibZllwtvXYG}tJ zK~<||om&6p_haIAmR^rE;irgCLw-%>`E<4wb<-SU%_U9F0tfIK=ue5u0~5+cm{w{E&1n@S+$uuCE9_63(q1*tG8W7JMt1A~wUm@PX8XzA=SEFUYg5{vd8Zox52h zmPJ>VFHgtJ03qjU=VXl=vG(ZU$a}>+j5S&2%Ll0?Y!Z!EYXkL=&^1_-Hw4oJ@rC@f z!;BrWeA*rq*9o9Zj4S$;fAb?BlFMfw{LBB>Q^3!P4#@3)<>yUbOFBf$*-p{eun7H) z64&1t!cjkEck0VMdd%-0T5)cCx0DvPYP#&*p$Gq%t8d3IznQ6OVw5QJA4j|R;1_>V zAq-Mi%dC%e&4_I0i$g4>|Gw4yJN$($F&R@54<~<{%W(Q{Wh1gQf)%^P<^qi?yqjsi zrR{=bffjQ{^KZal?O=Tw2P3Ej!9l*mgBD8yWz5DT{cO8 z5;|DO?1{B=qU6TL| zV1ZB)x8@^_2`qq=h9ki^d20};EJl>}q0v+YKsMgFM-5P>HoIy^<|T+gb+VjGweTG5 zi7`NGE&~=U2sl)%63YSNC|MN77y!9^Hk`?vC!RX{33TzZPYfmrSdvh58CKGQWxRLE zWdp0TgWI)~S&;9tae)orh=;e;K;#)2tj!_eA`5$Qe%0gkkGJ)i}sY@v)hqhQ6!QR}fmE12I? z?5CAB&`$>SqJg1ZJ2BM>M)ywO9|J8HTpOQIr#rQ(qvP(AwOdQX)<}I|`C;v?cH08D zB2Um^Aj=F4Ds=wV*)V~l+8x+@i*UTNzE^s1ONZn%#}bzg`;klyZNcFuf`Z&! zaMdeXE*I+jwzT0)$A!0XDvFGU)NllC;wN28cIBi;#+2C-4rH~lpNthG`+Upn1Z6F} zi2<(0AGFkF8{hV=7U4}ObhBgVjrW$)&|}(C*vHOX@Bv!PfB(~|8DxCaZcoM?ak;0% zfgR33g4+@W3&7iH!p2BElMldcL>chTQQzXw8l|&6%wP@&7?N<2&XT*PZ$IBd4)C9kCFaFc z^$zql(x1y{p_ix;GpUB0e3F9Mg|c$ffH zSnzm0J%NYYSwaAT>`Sz_vV9}_aR&%Tw4j89PzjeZ%xHxuL8>`cqx%J2eg@H(-5j`4 z>#>BfTRHu0u0Tx70TxV?!1H&G(pX2CbX^z9Bf6S&VnCPfdli zcPOo5YIr}8O|xdr3t$XDuj$U4Mg^!!W2zR?Xxfj{4YZ6(g|BN^8!alPpYHW-`?Yv9 zbdnvWjM3XBe$b-*&7*eFnnZl5ZQAGMe26X)_s+%@OxMTuN6f>tN2p~Q84AU$E?=1W zDdP(O^J=Z***>s%pWI{|@}38IrQkoqQ=kxY{EkWvro3q<`xGVD#?WU$Kg|GPy%;8ZtLcn#@RXZfcuzAs?FO_ann! z9IEI?Cz3wjA(^Gzh8@ulzDr+Ydup@ehF~ZBx))}|eajRJah!{v>sAEYh+#+6jc2#~ z>Vh#;<5Le|XV>o@6u$;n3B@Jr53w+w4BAL5PV?{4-DfnD)>|Gl0KPWaq)p;1e4x+m z4iGd547;KVPz>q488iI?0n>oSwW3&tc6{EN{oSEhk2Zg^QV2RjY=Unb-)KFV4D0I2 zed+A;wqU@4$Wf(|HUk#0&N_{bEWbS{?%%;+#V*`aK0GfM;&xYR6;Yw?-R3)0=h-HM zJJj;nu~Yda;214JnukfZ3BLD3c5ELGn74i8XqoCuBtIQt)_<YVVh9)aR-BVgF8^%maR>b6lTH^%zz0_oY3k#>;`sm=A3srE+)G;=o&(?ETgd zZQN!Cm#`JyuIDDWgE8#f#4%vQwaY{>B2Dig5VR)gKB?>N_+EG;fa74qxu;5)mUxtT zJ%d)K)9&eHsOEsGac6CZ`KW`7bo(Y|BPM~CYc3mUEx!)+rqsxzK*#5N5RQu3qF3oZ z^Lp5`9&{$HX6w-*xW3!LoJ2DIDBrsdmva6qwLuXz=*(9j;~kP|7l z)lIF(g6?E2Ndue@!?J${{b|H~`}x|4Pf0mxBNvyMFah~?(I32?IVF0X|JfDe?-6fh z4((ic!qIM#qBlZAw1c62F!q?@%YbNY6UKY)HR_x(;#yZP4dyM77h3+LXn zHM~2BJ}#Z<6=uU2f74YsnikIdplTP%w@>lw<&V^f z5p3+#JSEikZrEw0N@$eWX*C>1uMt9YRx#ENGVbv^MS3#~lMyS0;aa!?Eq2CWO=@Tc zk0QizZgXSRhOwWWWu$~<8pF=IRKmK&&Njis@L&9r3uK8)S(toa#$t@#n{AGPA%hnD zbBwS*g(S&NCiYjmW=XMJc*!N;$k9)T|253jpHQ8;qoIY~Yi7%@`Dn;BPOhRUAZN%c zRr>IDPg)bZhGl<9*xg@sBz4ex!;&Hl8{ybp<>Mg;X0NTr#!s>CVLnOb;h<##W|M3r zPFyOZXWV7t#(RcZXhEgXzofBhg{zsVYV@1xd}5jo&>cGE)q?#JxkgU;_A>c?PKA*& zh3UNjg&WRMkwnFOvwD_zI|T6CrWM^!>25W$L$+(OgKAccfC`VGDB%-syRYb zhjwY7zE_jde)gyw6Iwo@nnuLGz#KSCn$4H3y=J5Fp-iWdOSipTx1URIq+D+rtrVyG z#M(NZq4bbD>cQ82gCAKCW5^%LQO*F(Bs!PC-khM$4Gfa|3pWL6<4e*Cys;U?KqcP9 zfIGzcAlR8(R0m@IwBnJ9=GRBuL_-PsunI<)#(tYx63~GTqnu~$!w|5rueZf*dst!n zncMD1h24J@cN|LWe{kBas*jWSO*PVdMRZ;yEI zABR39ReqGb{<`#4{%pJfQd|+8EOXcUf&ru9g9i(Jr4}S{fyVfdM^$5f2Zq+J!M!=b zr4^z19EOd0p&AfB9G=03@b%fKkrHo2zKcoPgao~PK+aq~aear#8!#gf@1yHdF0nE@mXLi|h*n}aZe z$mQ5mLgYklZ4QoKCq-whzfH-0+4wf2Ol9*u2EecxSg3QKFJay?M=){8;f2kQRrgn$ zKUj%7I~-bI%0N(82kE06!VcsPEWk>Nd4QgTpktdexC801d6<^}=)R*$js5|fNfDE& zW~D)`X%L%fL;hk+luDfqc%R<3L$3a5R=tblspVY08Et^1bc5lk<;rifuIa4Cpte(+ z%Y5^`ZnE0&tJ4oeL;&_*ATSI-3=rZ0(B#ZMp;gVX{{%u1KNSP!7)6tqLE8(QU%O;0 z-^+=Dff#bOHOXPbe!%UVk}6UTVhr(k&ZPuY7Nla@&YI(3L@5L+LuXCpq}@Idqp$39 z9KKO3=e#%t@#Wb4qh-cmNCv*mpdiBEXCr*NHnY`K^DS=~p7ZbXhMJpjYGWz}<;7Hu zsGu<|-~Z3>U>-TGlH5=IvF7!^49}?W&kYDWfq65uU{w8o86N3==wIQ^_z)_e(7-m9 zMw#e#q0iTQc+FdehpSEHmf>-|C$yL!^7+>hA*TXM?S5OfhU(U2mLiYe4LZViWjW! z^EHD&UcskHs!EFgnwjGiO3RyroudWk;QzZzqw2wvw)BWfs%2dN55DAoP#ziOmMQ=I zZ-vL-*SV5|+5@}46&ZN&DjfZ46d#((`2PMDp=8J#dB04UWf+Qlbgd7JQ8Uc*izm?W{x2cMm5C50R@k_0eHnu-%u8GvA6*Lej=uuNmfwf0tx^$NdaCx6L!3Gho74b z_XebX4v>xA8nF)ne2LWx4|+Wv)W6)%4Mc%%58$AU?6yi%EtWRVEy${O-zWAjZ9X1XlM7pQHjkTH20Q$@>{7HBy84KhJ+LL{dI=0c`_30D2Bh<{AA|B+8SK$yrw z*1@-=$GAydzDWpf{1-b5t3?>C<_Ry8zpkx_S&i6B0%IGomsICd-KB4$YInBJr zhstjYspBRS4@GuDl~4K*PJ-_aA$eRc!#>!S5*M-i(uXOM7{W*i`K-i9^;Z4ojx?N6 zF0a!?0z`c%K6N9Jb=3XT)5m-U}<;!Cram%yI1$#5pFfiZ`hil zU%Zdg#1%S6#(@DRaB|O4uclLz5FQVZ?GM*CLHG$2fEX$iDeqDzJ**?~h2XYNMde>r zn5KuRX4s`rK!gpHiYA8nKVhXWGr~TfIgm1QVsAQ=j!23AL|rRR-9tp^ae4oWp}i~Q zVlZ$aNE8;IYjGmL7Wlf#&#^J&4#2KM^M)~#@DJF;@fq{cw@qbX$%d=YZu--g>`mTu z&!orAK_uoV^!I*dT+8rH+@-J8Hs!Ol&|ozzwP={8GVHH3z{}N-z65h8&6&8#d5#Be z3xpE(4tg zNfp}}s4;G?rcqdC_CU%@oNiP!$smja5wu0(?Q#d4{r1@h4eyHzW_wK%7%CXB?R>0> z@Z${$yu&gT&(!;vO%6)!G!{PQY@#zp_YV~#;owMdsV^7`7Z`)nx}Jqa!<^j-Gn-)fHA2ZDQCc}1oEE20l4$ARKmrVl ziT5%#B%2_y550g9DF2q#PcqFC){YlWZalmU=FU{TMUNU+o_d5-H1gn zu>YFEhzC-LMPd1_I%A;%XBP0xYKCvscUGaimu_5IZpCSuki+O;+ zgG&4uOQQ8GsGOU?w_(?<6_{@&Gx-XoGb-(BJ9Vx7d2kDOR|7`V1s=!^yx>kh;7Q~8 zWiT;{4`)xW9Sp8de__HME^>p2c*~7}i@4MLTY-D?!Ly_QA6DQtWp)LIWIrybbY%g8 z=4iwKR`*xD6*8gf$jNFJruU2Lr+pgnoUN7!$@w2k~f1%pX=uZTp_F z?O5)P9p-G36@uuKbs+Y4ih=lMx;PTRMGdAh-R$u-+VDpv zMcGwFi-{~fHSRQn$;~%ki!cjSd2?z-Hi3m(j0FS|;r&n2<6+`XpganP#k(@jK^Ir9tO&_xw*LcDzxgwWII?6iGR%ObT5liD7bfI6&^GyK+sG zY7>_Qu$S|>!}r{%-k^MRJg9U$$R7NDMDD?%)-7tXhLVy3Fis%ghXnBwTc?Gc}AmZbp&n>{+BOs#I-vCyQ?G{}ce`(r{ zm`$vZ9C5W=qwT)Iq*dy#RYtB+2GCanjH@$G`2V>h_ByK~kE_B@vT~SesppG8{#6m5 z93Yj<9`kO=I@GoK`T6*qB=g*gH4p9C^uvV53U|V2qAERLHO~ZU^43fqK4p*RuVqdL zG4to}l+>3E)XPQztWfDLQm{ywXKX-Fxrq1d$>WCGqAU}p29s3(-l?XO3&?YuY&VRA zHV1dJTi=tb17I5)w89+R;PKt8x=&uIb|hW|*;9_yF@)XH2>AfG%$6IO9J+4fR~ED> zINx42I=E7a%b#^)x;`^fogsC1eC{-9JJdcluIJea$ye(`idJ@K(_xYs zV3;_N)XhqSbX*_nT(;UotDQT|`#QKQjh$zfma6BS8Rj=GJFP$wpg*G{8VUX@7^;l~ zvn;)F=V`;_0;u6&G{6FZ9r#<2m)iygo(-^1yb;z$+^i2YY{ML@0Oj;K{``Y78iRks zrTlA1w@3zQ$*uqPyfN9<7ORvb#ScL->_Y&db~FKOl05X+#Vc&BW*gSHzuwz^Spb@A zkZ~MlP~Uhhc*`J7zyc1%x3v6D zVRpyA$)WoQ zx@oVxn;Q$oKG*MVs?G7x@}Ro&N1yjG)qmS$A1ljaywbM7!B4!%+M9q(bt6A!VQ7)D z{X;cf8W~G>vKxGPHZdS$o2p=AJk*QIyU8WB_VvK~Dq+FlQp7NINe2C3v_dH+BAiaK z>u#VHTUTmi3!t{?ZN2n_)4BO__$r!HrXS$OgaXi1*tSfYZ5D~sC1Du zrk|~!H)bsvoe?5)Jl0fB19r%N1IGA(c2tb^Yz@tl5h!&Vb)g&Gz8oyS(m94pgX-vM7|&l;?+oIIW9{_q|@XK;IYKCeKB&06<# zcpsehj-GvybkK6svzRWQ0h%xs-@Gn*=^I7HR3H$1%MD+_4bhm3fR5~}Mts!TP5y*r zx(D|;{C2;;APg_qrV7dmB=sAyyxwfE*9`Uc2rX(CNQ~~6_8!O@L{GL^psCI;TFGUW zND)qMqxBr-A0`;UV)H82{icP%Nx z5P)0OI^i>+@@k-5NEmo;Dq7hPX6acXYb}#MOH~dBhKM(K%FVzVFqtWSDEpNjd4Lm& z>S`MUyOVsnFBRr+8p$?guZ@B+vCK8Zfhxy_o4v*{d;}-}-0FvWlifvX5JO%(i^79G zu{s+Yo-(?boEngMTX_XAH$K>ykH5!N)W~9P4L9Cjh#0*1;D5p5kEudzen~rFT4J`a z`EbhQgq~fLwIw_AE0dCg8B>BPyk-3gqcBRUp_;bdQd^5=T0Rvw_GR>fE4O=1HOnqY zLFSnuiOkcozT{U{ks61`x2XV^`_jJk$(f>xJ}xMN7G_m+?G?H0gi=}a_u8+bQZP3C zu-jE88?yk zqr7-NrX{|)riucU8&Z|wHTS9qNeL-hfNhnAr_}I-r~PK9&x3_xK1b<;z zfpriXDJK3<&%F&K8}@o||K2^sOHBX0s2OG*3IzQN`X>(%aW`!A-;-SKzX=$pUvp0Z zx$D8pMEtipj~dUfvgDhs@BgRGU%wgK*s;COU`bZi^q;%+7QpBs5f%t0BVmLOY^y{= zNszOTl8Zpm&L^9@BlxXIPxr+^Z;Ozyi*w2I7t|Xm%qAYW^vM9N7 z3i%^E!P z#el@L*yFigYjM6%6Vv8i_NX8n)62VS^H)W9k_&{FA5YAA48Cux7YTgNCOjHZqi`&R zQE?{O%z?~HM!|;MEh6KHEPX` zaRN&1C$bft!pEoRSy@{jww7)<3Ynj=JJDE)!H`78`(E7DGVAQb{Zb?8;`J7u0Z<^B zO_q3#{DF{&(1GT3IKEM? z4(lf~>e4nr%Pt)@A<48)#-zwSzS^{%qzMVNSr49#vyHY%+ecZp_}j;gxEyS(z%-IS zb^_Q4wv!Wvjv{fNa4Yt0ku+b*6I85cjSkQ|vv$IV;it=Ht?=nuT)IV&Glsj}EtdQY zCKe=L&#|KBtXUP>B>Xi~H|cS~3#V8?&8+qX!mWz-02ODyFK{y~mBnoNsf^Ui1s5yT z4B5x=nRaLC-%+~PtmO6cUzKqEG_WeBU@ePSunfDz8PFiRsLNr-M z*Qlkx`i&5`xQNu;eL$VCfT^Xxmqoj)@U=cCK}Ijc&_5)D0z3pa0&sGpI#22adOnrU z{x|%B2Wa#_(R<!1dzqUl$wI^9p8`X(7$HV$2_g_RKy876vRyJCh!j>orA*U95A=31 zRl?_pbqggn5*EcD1{L*kWve4`hdr z^q&y3oC9?ig@q2{-O;9|~@xJ^BmNi7(+c2@e?aqq?X`&ZXsSCsFWK^W8Ro8y!& zYQr6>5i;{dG&omvegz{{9=;5^{4z_aSoCM@Cau~b7xdQ-y$2J{67_3B8m67b8Lim) zvYR+JJtr%1+~S$@tzkC<-%gW)^Z5!DD!uJ19OFw+OMw=n+pVA`aw#~h@*eKd{W1m9 zYO{sv?D#pkmg<=eG%(B@^4MZn!K^XEG$t$#F{C5Rsx8Xry(j#r>!S{{pVF=bLfYBp zWLz)sttSrqm%;Au@5cUO4GL9m4LDXdr?ZvdS?`Q1VzLro`vjwvvYqAi?3ri&=T7UE2e@Xn=uGc7{%O?0v^Raj)$qM-KN2JF$4 zH6PV8at7^bxTr*Ko_oe)y#{t#e@+c?Izw>xuX!Py*!C9`$slO@m94P(@`U+_6e+L! zjhp=Bt-;%KAdN$ok0Ut8tZmA>GtrrA_Is@cow~n2W)9=-mrs(w#n<6dYWwr4hvXsCvxea zxVsu5_PmCJU+o^hXOMc3)Ay6nEK9+NApEjeTQluws#7HYeE!v-n#v=m<)u+gzsB^0 zbcqjQHi_2uVg<+bF3h`c-J+&29;p(u_y35}*u*Oen6D9)0p*D`5OT=viP@AU>`OL* zuhya1N@LncYm>lOl_}`Vm`qL2)a}((=j0 zFN^^v;pbxFIjl;$c3sUYS>7D94Fg;tuD+>qWlyF{68vP;^C_o0&*pyHYmt1RvZ5;H z=?=b>_K{f<$_co!d}9#-LpE(@98H%Y_K1eDU@r3J9VV3Mc+oyCMP)UD-P{XZX`L*(z*3_mZ9oMml@P3qd@Bqd zx2ymn$kl210M!v6OD$#-aY3>BnWMq`(za%WB^vFyZPS1^RGIJC_L&1?VpZ6G?Dy<4 zi#;G-i&%~xvD#o0{psaTvi-bi`w7C1zw2X(%r54du&&#D%js2Hvj#4 zN9mx@n=}5<@1~Ig;k~QK_*jifr2ja>0iEAv*{^f zrH^8by_8nuxv%aH-o8I5?KpTnC!~B#OYfmeQDxa929#}H0`yuyS`MRY3G%%GcW0CC zj?T!pWGO4%iWjR9=?8>DAL0OG%C&K+xv`N<$$;m)I8#MfhU6VL`p{qzb$!n}v=3v6 zGn(R!(R{gg%nzFtgWF5C?nD(b1YO{;a?6pT58-Im&_Y%cLzD1b3O2AC#4H$&K+6W< znYbPgIywHA<@^gobtb9w^Pp1al4lg?ZW`lV_2QN1<4A{MlUWkH=mB#Mc!lFQ(jcBr z!-TuFZ#pjo0B{8}{AjmSGlOWHntZK*+ijTu{O$}m9jIvBFix*JpT2*P_n^J{tlE8<5lppvD+nb zJS%NH+p!e2d()Sa(w%`Wmw^S|kUJxRYP~4GTZJoI6+R_I&wPKQOx_cQ!3Ay&?U0P9 z@j`89I)i%v%^vj#O9nB&%XNCX+=tz>03MwYNOnsEE z{pD5Bc=!C|o83vikK-ahej^EdZz<(4ie(I;C!9?MJMMV1rCH0;)@K|y;mO)5$9{b) z{ui=}i8zszH_`-nwG4-YfE*=x6~$xpFY_eTS3$K+WHNZw&}MAE!K7-^G0bF@8w2mG zZi3cd-KG_mxrX1)kA`Ze1klYkPn*r)Yr@hdEwzzo24m#y$xAi6XX6%ZBx(_{*0Ygj z+OhABCr=sPedBxg4LI7=*P1s^H=-0r)lHGF9*5bATVCoN^}09E&RW!UYf5OS53je_ zcjt{Acb+J;C>2X|4rS?y;{ffFUUNwrgK`)nsNVcgAe4G1SHiHA25E?%MdD|;Ss7Gs z@Abi7dqWD^2ZfSAVt5dfks)%7j^a`knFTUwfSeiP)8lXab0vE;F| zTmEMSf8+A}S4$t1ce+fr3X6PgEwN3Mq^Zj2>V=>TpWfOtkOMgww5q-a&gF;bysb&5?j`q|gWV;pv;RX5J7n%7^X# z4`vSWDUfxdUmTu3*zPuHzMljLTsLpR;zM8jT!_$EiZr!Q(4ja_gT3maz&i>3k*Bzh z7OU+TOD}oWRJ)`Z6)=@B9w%^YT$DzEPhMiTf`(hL->lNwMgw}gB zM3pn$yG=(KNr?UmK9N=J?i;dlE-cT7C zlgWuuw7Yrzqx8m(hK+X-+xQ+^oJXPSKzEW_xl@jfivqREH;Xjfq5$P|Z6Sw6VEZY5 z^|+h?VYRqPN4Gv2~Z6SztCx2VszXh^r=Oxv^q+jLCq z)h+RWxCcGx2MjODdS`6uGqzdFw{Ns=Cp&3%9ICe4ad$b!v5Zh6dbYU*KJv(X+zAm& zv=C<=*uWiba5+#ge7o`Sasw+6RlKOJ>QiTtI*_*a-9G%#f-S6l{*uYy_tuXg(dcnwJDraa6RGfDrBWk26WZ(PAn8zY3@6k?K_o^7aP)>d8TF) z%%@rQ_f~K1R61|*ltDe)W+Wr9&ti58@ekjwC9*>YEME$Db%Y~s2uCly$yIk-W; zbxHmjg$@|OetNHYGH>P)Hu2UBO>>!aza-sm()82q;6x#h>fiIaNQlzE;#$v%qfyjqGc03bXsi^qpJqJI{5o&rP{Zd z=Cfdatc4P%k)9|R_~a*)aZ?6LK?{R;rq*03AeO-X2wZsb6RtX@>qe82_CnSlhNchu z;o`w`x&O!txfR97UjH!oROXy0hG?lYvgRc_LouJJy6{T7rM!;;{6*)wzO6SN$n$YX z4-vsj9+Jsy;SmTtdJzx2nf&HS63{LpknjWB%4Q5WJhgu=UVF@v1C620u^@nZUYmrA zZa1wBHQ%S;B~D=FNew1d$uE>`TNdw(qyT4l3&iO$$kL(oGNeseV79W;QHJf0eV*W zDx}CaGiciSkrp)>80Si0!br$)Xv zXWU|;ZAooWE&Ev@y$>FWTnl~|Z?u!oY$zyV$ImoGw6A?PU`7?sdE-xWUHWZV_s#;L zBPM?*lfh!{afL6bkJNs|l865gXL|fFw!D*F-|inUOCbdvLi3bpU`Up4g(n@nFcI~L zA^ZE&<@($o+lTVrs!ZU&iN#ZVNTpC@T=wgb0W=o^@PTzH02V@7ftYa}_(`Lw4k#30 zWVNh|)+`c}A|Xv`Y;i)-Vd_c&%X5kIn);=h4Ub%2FK9?mckA772 zo7Dv{frYCUb&XlHwNKsNGge8J{tH>F9M0%7YF#m!hc+^r=fP~DZ_l#9Rmy3oOi4SC zE|!|3R4^r?mK3}M1p@)(C+;4_@~8=C6L>iW{f8bE8@l=_yP2G7;}N)z~c@ojCN%{w(oJ9HeH@_{h8U! zW5ASFa){pFHu8uQw3p2@_T^CykJrj%?F&q(zQr4vVfU!{j5GMTRV`J3wEZugwZAze zDbqgZ3r0vU&=V8E4%(zU+@XLZy~S1|I$A!|p58Tz&!{vN4SFtMsLeiJ`hZ;Nh>6bF zY&;{|@rJf@ic|4&M!s^UQ6LwawwntllM{b<&-(uRvQ|b$TUE~UNRGtCe0k*c*UfUXrLFJegG5jxiF$eCEIPRs1CrDoEtYc38K z5pUni2@>x-s_qi+I_dl)-u-Pt@#Y)Fe_CJe-8lF8YIOnHdwTCy|D#2#TYpD0dH^JT zr`9AlS2Loi$+FC|a)4vgCAXfBRWmgS@h#AO{6_l2gP_0COAq6eWtKyQ^}I?fxz4$~nLEXbMw}*< zW!HUILu5Cee|{sodCF6?VRn8Z(=KLw79#gCMx;k>CqeO_+^+1M_4;U!u7eGJ6mesdQLw>++#k9ZYoNd+}g8LHI z1lZNKCZ(=G{eGCQ&-1MamF26G1q@dCeED&&@~`zje8Wr}6S-jIcv;WtCu`O(%70F} zuJ2rZn^eJ;z^EliO|EJ6|Ct#<<+}ONpw^o}tanUkt#=%8z=pB#_+^;d4M1c$ia5=r zX+DyoTV0GIE%Gdry?C@`#je@CVQJubKE4J#su*BwB;dq@xhWl3SiChHNdr0L zYBYc*h^khHy!IOSLh9bN(Z--Anoxkfc*y{gTzk38(hy8#7%=UCIRm-eX}{Ea`8HVK zb_sIY&G?pcj_W%X#1JIeQBoXrT86IXf&4L41z52VFaYrGyZ(CwidI@kXD{qSKhTRKt|Raxa6AU@Y#TN{@J0wJqu- zd|Ky2e^I2(aGM*5iTNy@Q#rvm)R9NUQu`1fr=yJLy?_62c_pUs$guYYfFICO@~0+` zeA=9rLEyT%?n#J<^^Wj|4s41hzTOiLK1c!J2x>wqbQ4VOvycLJWIjD{%ra!|t^bAw z{W)P&<#?M0`424jr9E2tg(}mV)`0>mQEibIs<48?awd%HUBxRS+LGUn6zG`Rn?yP6 zrrQA5qpIT|wYkJRlnn3bRlfad|B}6jL}SjLW8Z{%EV3Qc$T{dx2aFSO?4rHy*5J-{)xHR50tVf82|pB_p@S=;fC~eqNWWq-oif~Z=LTM;B`S$6yp`cY z0k6ME3VXL2vVFoP15=MHaVYdGDGm(;pjiezO966>%)8H9Jx9&QPw75isM{-l`zBQ) z`{s*O$GE`@skD;rH_+RWjFa1rBG%iC?Pot~$Z2)L=o^POmj)Q4IiB5XKK;09d>dAq zC2AR8F>BHBcT{dPn7wH6& znTe9uC52wC4*7Wj_tcec$0qL*A|El5Ya6?MPI%V6!+9+?rmT6FE%B2WCmqjPq&tOJ z(k1t%$dBv~jTSP4{IFfZC@8he!M44a670N)wOR>CNtDfn7rnE`TlPq>TVs1RzzIeBY>;B@` z`#*h+OLmZa&+G?G3W_@RWvWU?l@{9}yD*XgijNUX6skB8#si1VwwAv)UMKSAVX()G zscDakbSzBX1*J^09RxhWS1ks2*e^>*(f1!p|CsSqrLpOefqPfPNvsqEFI%y zp3Q?tUISu$YY!~)q#m-LHy{wC#RzI1?wrEH1~o81W7`*VuhzDAD$izV?R`H4xkf@Z zF?-XJl+8sbV_D#J;GI{K3X@Yid9ZuG$dB~srSe1|+4*fX394 zMlcvYpaDMFVj5f0;_KHy6@?N&4+WC!nAWQR z{w>xqb~tq0oS!^p)yUrJXHn$Uw$*<5lP>~K&jr@;I@Lb36+%Ux8yYRs`!Ll!_&uqZL34>GB=tzdLl64twdvIRgmFS3Aq4<%?__sm{ ziJCv`TpRLIFo$twUxvReX#{>ev|4gGn}!TRHSa0rT&Fy+NHM%!Wq3b9ccC)P0GV!9 zm2QK~aIVUDg3Rn$@bW`uMO9@bBC{tZGc%Am zjIgL!&gx=u;S!nZ67P}R2^Tx)QArzWgB%PZe>$_l^ryAQkJwh>kbJj~qQ0rx7+`!Tsgp}naN#2?&7MzKQ~ut z(xg<^+4A`V!C|sMzpM3RwY@(j=kd+9ubX`l6kG<~{{!}j*gjzd{2#D~vw`vF=cWI^ zo^?Xnm#x9L{{?#v-`D&H_M8=}t`WhW-k!gIE`EG|=N2NK{kN^q3#JYHe-a-m5|bQM zCUNAD1>@88jic`D|IT~Fe(D2ngNNq-BYGCC{r}5*Y$1IM4>Q*{u+9~A=PT@rnefN% zqjkK?-X9VkVEqn49G(XLx?bm4fAJsDbJ%+KKcc5V+}NC?J~K>$=6>c(J-uM&3^zSE z)9@di8{}i&buf44tSJSnbh37E*Z&*xQ1WWA3;vJCiRSpw4Wp z)mbqMQPh#0FjBB`q2S1T#6luMYpV+QA2_G6l0j4d6YtZij4Dp9Oxt97ChQ<%TE_~J z$oUCf8*-rNb$>5H$Ef{){`-f9W9ir4S6UzO_@xo!V?urQ@;yWi!>Htfc-!uP)3C`> zh3{g+46U%kxZr$&p$tEapzoHOW;N5vB&fGui{bl2%yC}qG4Zb^*rZ7$_19~)wdR*> ztNARE!*`qaEbFR+6|?)EV23*1#fCdN#z|~F`Q9WGCO*;pncly9xMPHm+7_E8vSCFtTKan*3R@ znVk36z4dt%z|(gN0u2y&BJ%aFg(Sz-c@5(`b#niow?67IX7moaMWC?rlk`lo0Se4a z_u?s}ck*1fxn*s#e6K-s3`nLdai&=&y0t2fXM#O$4^QlN?39RPhlZ-S*o(Kl7jZMjSrUUIHPNW^R|>ZTf3R>@ZjG6MgY*L zA;F%fv5MA6W?S?Kgl9E@rf-J@*n#Fl!Gf+jI0UZShF_h0i!+N%-Imx)?kyXTd$D@8 z@i#Hn(1nxU>=jlXc8XU|F*Ma({hCqS zxyAqGyev;y-167&MaHKTuX)){lPccYJg262veo$X#!q1t`=`!N_S{b}@p0Zk3>~81IgCz+B67}bdRQ$JKdb$*wfk$xXH1*cdm66G4PN|r`=-v**9P#(&%``x3CtSW znF z&QL?54R!t=n~!)BbPf2ur1ckL)^8w)@z}+3`3I6nz}!#0d=C6`3w`NOp``sUKt)&v z)v&sK_HSWfWZzx@;E-BF8pNJIrBNV&0O_dkiIxd_p(T47=JYG`A-x6smkS&<^Wu+S z1Q3sfcu7-K(gHI51p~r61tB~aA&31zX~hp4G8wR|U`t&V*3(7W9dl8&1*hefFgCeh zMi*ZYH^im_3`=t2AoHT&_AoM}&~J+nSc26SJy<$`JGjEzz+lcLnjBP=w-B5z)TfN# zVoY#Rpp7;-pp6*~Y4>-Zb%_KhRYRXTJPZ6I6Z@SRQolt_e;RltE%W4l}IC1mtFi3mO42LOtZSObNY~gKOKRV9{{2X9@`+)tCTD;!nfCA+p$Pun}?4 zzL4A%ohC9a_ZZKhmYquCm3lWOh1M;N|HxKW9zbW9PR0$wbosGJOp1bK3OFoS`AAbBCN1c>wbfRp;u7vf2gHsgnY<_@!3CdSo_xzG-GD9y zwv*&il#ylcSj8Rl{UCz80qX`01-r-tLkCUY%3t*tz_Zob9eR%BGxJQYZTT-fSd~5Z*ik(3oXh8yu6eH zbXNJi){X-*7hp5O8NK{lnO_BKMUp)?pugn5wUbQBTjDKgC8n00OY4lUf=?~d+qIR@ zCqD%wIee8*!;)0IAet}6C{mkIa%&05z3OZtbMrGas+9mabD>#Z#93*PYF3v8`9;}$ zfH@XI(B>36Gn%=DK{g7cwKD;R@@WRkFb9dWyU6r-tz7;d_ zv9pvIV1>xRX;H9$TQD2gs|rcSIus^=xq3$q_Fg;J7=>j}z*`mi>U2VC3(KCUSG8=_ zKr=I$jPR>%AulWNWX0KuQm&7`71I=zG8Nh1Pv!ZyP+}dHeP~MSFGtG|<-`Rng%!m~ zWBu*Qj$ZtZDh)WVqS%V{Ap@4!5nsdt16fC6PoNX(W1R$c-~~Pnv>`%1@3r$(34}{> ze`9!qXB*__7EIm3pV;2@L*s;9kJRv|hWY>;@?3{QynwZa;7K#TLSkj~{ z6fc92MDBdi z$e$Kbl)?gNkk6gTdR_&!H?i~MeN2(U0%rT%;`_m>`qy*%IvYd@?|Gm*me?0f__^R z>k+#{VI>Ko9APSLBmq00acFTPfV`sWlfn&A6e&5GEmhgTq?+)Ymt6P`w(YR&->)&h zyjwBSR7=%0x@9@PWgd1^*kyqqX1_#hk#0+q4$9v`pLS=v*3>4~uzPyUY1gArRkhb} zo7yed-!L%Op?+k0{%j{UEAAl6N0BoiNLcO>3-kU18p&Sa|?s5Y`Bpo z+&9B3i&XYWp_h0v7#i9(*g^iA6zd^K{IrD8BfzptfkleclM49Eme-^?&vk=x>1L=y zqT|UdXdFVp6hZM0ym`RY7Zq&{9Dg2>*u|fmhQ4Uaw>B)03TncNUpN~4%5J`BT5)aK zWGN&;6d}sM?&A;;Vo-Bxpmp(889WMV)w=oZh1BOtIs)N65;Aeoa+jj$$z2#frs$7C z-ewqsMWi;FDlT98<*gW5jB6 z@3U{cG^+Jd{nS_|;5Z+5DY--gXq#Gm6vt^r46r(`oO(_qNTFfiZmVV;a~fnO7>Ix* zw7K}WcK&1pmg~Yy;GT6`IdyyNuPH(@mMR4D-t34}J`&EK%8sJl#$jahPYfyK_p971 zQ|sW>b}=I+2mt*>(}8dbs6V3)0qBWwGLp!%J=V0s_Q6_Ek;b9bL9Ldpu{8^#l`x_- z99k(9o%i89Eat{5teFeGI-sr(6J8k+!&((@S@lK4FRq(|!UG;2&rG`xPhl0o)pDH& zan3u^C3eHtMQtEn(dn)X(*FJqh9PfgUw@8vsc8qStyJ8dFe%! zHTdgb9wh1PWIJu2d>0k^)+t2{%r^ss!W z12SxxY>1+%E3ArFn3}JxA=?h?jdSo+fY+^ep$Oz}*B}322k>3H(A+n4wF=D`?ZU#h zMBl>vq7W8vrl6XlxKTFOnuW#nS$$YbpWQx&=JuPaT7oaAJhM1tCVo6%;;LuGXjW-< z_IL4C7hNx{E@3{It+1MU+2hdP2lHX;MRX#qE9v05 zRnl8O2sTinN0UHv;bj>_q0A0;oO!MNnH|}+(b>9qlWGUq#isRThXj$!tuxF{HtDPc zcu+6`Jk5T+zxwX}fv)bjP{B*~&TI~0YziXZLMhDnlxUAnz=b}Iss#*S#UT%lbQob( z91L$GyX5Fn#rX@U=(26-=8g~sQLM_O+iQn#Wh>H)$91%vjIlMPw1w^ES1X`L6lH~< z{?rc43P4sb3wGKed4<^o(W)kQnuCS4&uiTk!L^VtUGLW(4_iwEwNl>@&l_Y0dij7C z!DR*&>(iWR>whN_J`Ccp7etqHBM~~tEeb4Oh4m=|=7QW`vA-+Bv_`zewra~d6YQ#tSkHR_)u;0oGiv zw9t6<;a*6nb(?RlA*>>Ik*Zu|%P#`%G_|MY7&geewI6e= z1lcUG(x0EoUtuEU*dIuh3wxzRwc&OvS>W@NR#IS_3(ZjO3v-e@MEV!TS2qQU-Z3Ix z#WYt6irrvj=Ag^}>O@Xx|Fi}Zx(O*b_UY`&a=ft|!u2`#Q$j+;lncv1O|i zB}sQbXOJcTO)#ezc(<`|5YQj!p|=v?2mtG?IH+^|0X(cR6+x0va;B;26rzlENWm;g zk@stGB&qB}ct->)o<;dZS|MS5Aeljz=_)I;-v)dmlW?064(U`M_;i$+oDX5eGP#V# zKXL*|=JElPIskwveNXdxRX(yXgapg(Hh?eGb^sy)sGk_ZQ4N@)ABC>>=D07frm*a6 z4#k;+$P1SxQWEhlOka*ayxyNqrMe-bd{J)KFMRD#0k$`9cpQ;HOE3^!`sJ|EZW<@> z_D0F~^(hY$^xQ;2UGlu<1RTsBt$-}k0pX}z+kaQS(%-W$zj`!D$S z&l4A4Ut6J|W}mPBMk!DS)He((fb_b>8G$2*R0J_0;mW=|zNn|vZ---n)85%b8Q^GNeM zeOLI_$}M!}bTK_MhlXxnw+y31jgU;$FsBig80a7w`p%S+g<;8s_cICYBn()UIOuD> z=zNiK`3R|FZSn$kRJ*OO{j_W$mL6?;t(d315Ls)lQGQ}iGd%qassCYG zx5npz@GYC2q`gJp zDuPBzewB$~?r|USghqN3$C!JX@=odI&K*vf*;1EI*)EylInW!Ko*g9gtPLN(J^KHU zHlqST-C2_buOBvN)^%}mWWq}687Oi5sm9)uuN8s>T5e`0i4Y&tP1zr;53u_8zAn*C z$nT|i81#ZkmmX}MHts)qwiqqg6R`7A>&|7dD@~VwPOIUaz@v))!WRywjcT*VMkkHK z-pSWhfo2mnVqmv{klu)Lh8IFxH_u%EWIgHVdGY7#sLHi&*G*e&@OaaxMR90>(J^1J z()Pgp-OLn7B95LDiyYPk&Ggb;z6)m_BI-0k?S5N)*CC`{1*M>|u2*O$VSwwadun5J zVtA)W%w-{F;{)TwG#_wEcbn9={VwPEOtg;Wc9Sz&c4=XZgX*@fSW0+x8s)d>dy)1G zWpWz7B%Z=4lj>u zVoUDuT}|Qqe=`{E46x2rq}(U_Aok)KSE{+>mywQq{PN>9N(9$`r`om`hYX5l{LyJd zokl}m=ekJJo!%}N`E*XLTvQ+mNv?LuGsD22GMCXd0S-q-e^YwCiRo8w2%-+)NGR&p@;-{z%@=g73B)3n~Ng`Dcp|*-%!%|41~yEH|;wh!l;Y?eFi^LBiu6aiq=bAAB;OPvXb{V;EZy^HLv~ zIECLh7BAj38AuHng_f3LxOeu9+9GX zGr5$%>n}vY>6ue0>Ll*JJZUQvo6@i3ER`u=7UE&t+dTK6c&f_KZ_c90SjtSz%jqUB z@BOb+kDbe`k0m>orCvdlUDcRZW}-ElAAO#x)o_`Zs?s(n=2vzb<%u!@cKT!w9Ty3N zKbNFiQWp`BdSsLtNH2t4zFlVe=F=d=%7~xdtTv|f;H>_c@pqkt0`Bt9)oI*@y(1Yl zmdft-U1<+K3!j_NsR@TYsNp-+_iHBBxet4-B+&zt$x$sOS7kXa(87$q>b;8NfS{EF zP5-+uL)@NKD0admwQJl}CHz0EZbZJXwa{YX>o^NCS^4syr(%E$du{V-_+5{pTn=@? zX^7&4L6YW$Ddv-BYovL)cXq?YD zd-!DSpj~~NMNG3*NzGubpf-RHl_G|KNEwE+4YW4Lfy^q&P9yjO=Q=suQ@iM;1(e!d zV>03$N@%Lp_Kn%Nj+wP*<0dk5A)J{IP9_Ni zMunSJMw{8;WU}yPLJ@puO8#bIK!`Z%qRs}~RZ4C0OQHWD*Pi&WgyT$2{ztrggW5uX zoIG!$kG#9>hLZd&zZ{3KI^z~pZrie{Kc+P>7T;HcycML&Wys^r8QbXNT3XwB$38Q$O z7rsbIH4JELw$>+K%Fl4nZNs`V4W3SjKA*hC7wSNgCKTgAbCGyN`33oqF+jBupm19jc@7$n({Sp!tYN7u;X4>%tD3IiAT_f$?pwIdIi!!SQfR|cH^thc)l~TtknPF7vnJ= zo8}aed}`xRnF$hOGjc)TsmX!z>9iT|PYB>P^xYtxC6!UU=c3)TVu!A_Ss_c1 z^c?ni4Z^0*EU29fGd6cEMu>iXhF|3!%fDczu9XkGHJij}=`|}*q!(imNiNb9?*xRgzGr!$xu{;yeAp144M>ds`2g1SlDXBHQD$6s<&)->e%# z{o&B+T{$O)cp*c1$R+3dt%j0^&A5i2TKB08`VZ^ZK}Kc08~6Nm=f&X!e(k{ zfC!}Vfwd|n=sKiyqAxA+N%|d#7mhoR`bvb6RwG<8OQkm-WN^XRzo#w??Dn%8Sm_(Q z^8)Aa>4I1t5fC(uaSO;prxz-{FrOn!A#^+7$Z%V4OwqXHK01L#HCC1p*F6h+u zD5!NQU-p;iU2vbGF+meZ0wQN2H!IW$5q$`Z3$JjrGOUW0@`Rec8S@L&;WVXxdMxKj zX~b}2nOF0pW(=XxLMvyiK!u|&!mvLgOW#(_D9b`Q8(+bSA=j_MZgYf;UK_?o3|VyO zjue0euite++}3qI*_Ac30mSU0182&?ThC5 z`E;u^0PVQlggPU(Wkb>l(~gf9O$L^;hvaDu&Va(U$~}LWL?rc^L7kDOl*pMoXsY*4 z>*1KbjziCM`<7^Jzqg|-5q^di3Nh5T!E0%rGuoy~;fzq9+m3fOCagjgW;_iAr0;&d zi_#T;A7$cf9tIUF#53)XCr!y~&-BnO?}+P+z~(T3tw`);4azz{QNvpSv$WnU;{2pj z`C3L3#Of!XR!fE}%a7R?_*a4#3&l&0xgIQU-IxPomn??IOuk`X940j&zfa zWILVf*kM8=R2jo~;z=BU4d9QC$$b$IHZ%ZuK75s~(2Y_vDG{qXA4;R1T==KP*HXD zb@@rGBQ0zex8x2ZsQDcu;s_MqZw(t3B}!T7Lf!Sa8kHWiaU*gUvhJZD~DWhJcBtXvxF)A@~sts}9hw#C>a`5DVyR^?mSXScxG8aBXI& z*(2vpyioqEMXHulpYDWjZ~)$xY4JDzUrgQiJ6nOn_wghWl2C$BTM)YxwMQE}HnB(T zJzCVL)!0InPut|4V{5C;$<4au8>JAO#?cvjiO(0X8R8ewckn^jyl@Br0`(k$i(9v)D z*MMHO)w=hQ+p-eASw+=ve)rKU2FOd`0oSG;&i5JahOR{77OA#mKVHP8OcU;(Odnh+ zT0uGE3yh=d?o?^_RIzGX6dhyh*yC&NK4`$o$iQ>Z{I$m-F*^ZUnx?u3_ZUm#s}I0R zh?$h!-4l*CX%0T~MZT-4vM$im!m{Z^E_|YvMv~76LBjs_8X0lV7eWpFSgL@>%D?D+ zH}mpkGAywdnewjkqY2``OIgj;Kj#mFH&t&_Jv0Ps2%wTXOp!1 z5azop4%glgp=RCGdk!gj0*fQBgkgsX)5207nUR+|fa^tf`WI|A%1m4$ERJqLVf6`& zj?}j820lsA-ySEm|Q)o4$wtyem z1(7qqEjVi&>G9ywEj3$_g#j76B{bYx4&+Tj2G`3$DE+S+pI zRa00z@{yFj?-KX~T#Y47rMbEf^Ia={k*Y4m`S#*TRobB?$o-WJv*MkiF(@xx1AD;t zQ~`CGA*fj1b~qe!ncNpc;QD7ZcsAL&Zt7XQ!4Ad_Rf!@#>)5)7DNk59(Z3$MkqmB# zlNF}@ur&I1MAkW?)&U>pe)cureEf_Ok)|I}nq~I$OKef!!16ogN7{!g;PP<-G5Cq0 zs_L5vDn)PI{YN%RY-DIa`v^v@qbQ90@X+lKVT}mN42KTv-C5xPD3!lD=wCA%ct2wC zg|Z1Eve&1SJY(4Q>-6=8rfuJ`T6F)7cl*bIdSdskes@KE6D05ms@IvZZMwL_Zf^58 z8t*WhA=hr`*3YC9?vwhQ;=EEd+$B}7O2)~D8(#p@7M&~5g98@PjoHn=)hS^_$wwc* znThBmW>n4i0giW0L#m1%-SLR&Yfe8Bq#+jz6utS;<3)4=hk_wFXaib2@fuKO8g4V~cQ zFTnsJNubVCYmQlYiV1haTJSQrJfltmj@$#6#@+275_`ayy-tQeueR1?>fq>-K$!X7CEKU@Uo&xH_%z3D zqrWB}G92q>Yz3W#$hekLkOuiUXZ@Rk6=uS)t=}HKFm$KLs1XwK=!yEVq43N4ML08G zi|e$Htkz)CO)AeVwnt5)%i45gSNs4!cjmnQeWcXSb!{$C99Zz2n8t2Ezm~oZaB3fR zFtEZ%M#WTiv_P(=fj~Qq7;rN_QK=kBv572rVnh@4PhDJFshlXTPP8U+ZQ`&3CftrB zqexfYG$(*yURcK_7z_{@a(m{Xqm!1vc2z7S9-yV2?Ze(qv(MvZ4*|BW{ffY~M*G0-YJh#|ZwjSIDI0fj;=Hx4cpP$}BRL05 zNn>BHC;5O%3_;uM^q3Etlg$Yb@}pIw+W`|SrZ8ffo=Gc6~qi8~{+Qn|&Wu#nzFCQc?)U zbIiPNO8fNw_5+bKfyfXopWYn#2RB-7e_I>UnB&(=N8SVuWZkhGQ_B**mpb!lEu#Md ztEd&r*pA=O`vT3DI^n~oce$t;GHz{ z=zD7EViod};N+4vcVf||da%tPTQWy$#!IGnsm4tJ<+M~qD3M=dz$3l&fS>v6r%J!F zG?06e6TKc_yAAAtQ0^akX5d9e!Qu_Qzkk0M^mB+X7Ey9r5=jXlIPXRpBGRe zhxXJxYOpY;PAjH64FBxVG#70zuXan{##_w~ZCLYYJJ(EX*MeUe{p5B>B7A#Cl~!(& zY4%|~NrgH|t00$n(^EDNh^j%>-$=C|JFY6e@TPhmjt%w)scDAoNhHTfCW4)B@msI# zD_|pWHptyn8M!A6gW|BO4e)C8x*(|hI6>5zn@vrC$&8SxC01_e?XlHBC6$S5l#6}q8k0tu$lT+5IL(V7GhFLWW#%v&)dGPCiN z2aCxNHc>T7@^;kW^~+ZiVi}bWO|dO}ig?*M%k}V@Jx`~Fi;F;B_W;5Tplt}7-Z|?@ zl%jUgEt~`_8mCs<>Uh8(dg^xj=7=L+sl~RbyZCqO;Dx0~m8vb4upuEKt+3-`1(2n> zQnDN%>5()2Gx{sld0x8WiC5W7Z<_lnSXzg&7sKg_3YMCJE^0NJf$r+KwD-9)W~oYb z<1wHAva34l#_L>_>MjpC+ZUyPz;?kj5uI1-j+B71dEvS_ zhk81SV?u4=fPIzNzPhCglW;Ueq*|&X! z*M+9gzhCI-ABWWuR$%(vzF4Gy&JeSzlu-RLz5Jipubx&9SjC26SMDLG(YkZ!so!urJYW|L4B7qsaJYWg zklDYfGSM&uaIID}>)>|h-fV&j61?ZWmmH*VXy^`Oos?yHh7x>I{oUH&PvjIsr&!Sa zLybLC15H1YyKjEX!mucLem;j#yU2P}AwDI~t&Z-Dq$Qqt=Akrx3pnKvr?cL2kdwB( z`a=8(-L5l@*?<5nWQuQRg1vAX1vjzgHBC%`9^k?zc*i<_F!EKZv2LVkeri!Keer}| zP2S4H?~nM{t;Uo8SIel1eouSO3$Hr3i$kZR=N_1(-G&Kj50sk~Gz0l>;1iPC>(!D9 zfA&!QwAt6~npKxDZLgI7_8{jnEv!F%E4>JobDN-(LOr>r2;gb8_Vw}K& zU76Vjj(!Oh`mEXCe^^Rhk^^RP<*l_?ZrGmZel|<8{VvcNhdb13KbH6~m#Uv8*zQQL z_}oq}f3=;YR`o`u_k({rf*GG_?S;f$PY02rr6VIchCfGW$f)skCQ5sYHxpCM*Vvl! zB1Sv=#+}0K-VS$Ea~3{e3|Lz?aeCXO@Zx}BV+|AV^+$5?%$K2;pMQ%nyLIPRrYaSP1D6V?ccJ$t3IK2~e}uia*)zeDd}ddQdI7pG zUCQ3}pH8H$hjQGN`UwB_d#>o?Um8;Xv!%irKAYIjfpam5ulx~M1xe`6$)c+7t5-(4 zbbAK6hpD6iw#f(-G`<*`f`Vlg|934@l+^#w-WdGEVZ@M%Lt|kP6}ZmPqdO`6X(M7M z)3Do~X+v#$Wz!$OsUP-cC;iBDl`BjG{*75elT}5aDa-Mz0t` zORjM+Gf~G~8os&qf{i-G0l`;7+vvcqq$Of##QK_@RUc&l*fQMQVzY8kdXhceOi=0u zF(lH{_3M;Oa2b?e%QdyB+~0!V=06X%}RM~x%u$llb9MO+;egI;b3(mE+eKa zeZwt24X)%kZgGaVc?hHfOL?Vq(4ix>jC6#5jfkAb3Uu0htbWaGbp7oa>x9Wte|RZu zBR0u0gDJWcwIA=&%&nGKqBo!`@fGK`xbUQzP5kSkK5tn^cPwcMeA9*Xq{UoVT3@tO zO*sDQ1|}ZTj?ib798+s9`&vx;Xv^#r>#?YnV;k$oBQaH|6>CUfe{2s-emF?+s@*AD zs_bv%`1He42blA%&=IYikDka;Msn^6bP6Gk!;fLzC^750s|B`_N4uG+dWG4K;KcVo z8Vy=+xm-_lR4}N3>MQ!%3kWT2jPDR%V2{R)O;SW))yXw9nZdkGC3&^Y8VVUhqF<-t z(OT&YCRs)sIE{^m4UBa#E8&i48QKB^%+yVAUSZt4UL&t{XNC0(UYoHBo7cP+0_6=k zK&v2dB&b zFTm}g6$D{Osi9#A4j|xzzAuIb#fMPGaXVVyaml;CTov)2|Iu#Mqp$qt7qR zK6Y6!+D&lVDN^wc8XL9?2dD z${hk_*-$sfAlwYs(znjV;_-y3*dm?tx@a7HZj8n#5VV zT3BED83%B&y%A?u73W?S=lJYLGkpT6CUB6%QD2Vz$tAe1w{mexaPzlbTIF+h9UH@v zVO`BaP)Z>1S@zZ^u$_?unTiuve$%){R%>R5w}K2Y5(u^jJncVpQ)acAOUY^%@YM7$ z%C}l-^r5c{Jo0j2wO-YRAlpo-0N!=aBL7#2!vkOe+@hrVPl%&N%>bI5Y^!?n3=;#7 z=y{OMg`_}{6?yY!&uM|Vvj+^R>B;^`!LqmL$(R_fE%+NP|4*9p;LH2|F81sbO)6(H z3rmeI1OtW0e}y%y*Kd%E6UR9o-1t7R5wH7wQlP5-`xIDbAUjQvoH1MMCY+Y15KX7A zDRQ=GV(xRcWa0LN7GU*XjZK2c`HHhDRR@MKD|;{xh0bo41-oZsW|8LEzDd`ZUGE9` zh$eGIqPd1y){f(4Upah33}nCEJ5Io`A!$f4D>BxIGy2vLpkqt<3qG7V*dN&^t zt$mWj;}M5!{*RSo6s`KFlxZA0Cef=24E$cgY$qAVe|60J@5H9WyCqvWlR-H9CW%fI z-mAd+pQ`4hyCr*TgMlx5RuRdMFFGmX{%^|W$1}jSY7B+;Y&7*K16fYA^7X0Nc=^*3 z#FOg(+t8SV{OzvRSj~-m5_(#Ox>h52CODs-cUmssQ_f;MHJ{acS|RqNMt4vnD|hYx z@i)e6Y3)@Q^-l&d+KYw>8zNP(6`)x0`J1mIT*rww9bYH9=*^rKk-G;=9Oht!?YCi!$jxnQSNfZ28 z4Ek#$-PCoPza``h=k(Kcb=RA#4YGE_-Dq?qoh_)nx?jtIw9gv1Xo7OE(ek2iHGfDd zTt;;W=z1GZTjQs~vrt_RpDJKtSK}7Ty6kS;E9j1oT%)6ZGNHHQRCjg8AL-1{b7qDl zKVH*q2vRDOE%J9mMt$PSKz-KJ3Z_rLo3S1w`?`N4xa_$WFJUJILHW(iwJ!ey2XFT& zWZFjQ;(an(mjl`PVmR5K){WXd4#b>Wh!c2&ay4eqQyAZ@POX>;ez(*4hhZ|BzSWlM zM`Iam;HIdY!}TwsJ44EXN80jU+bqNUKJQ#;T_{jMnwPvIinJm+v5LeC+sW|+)}|-f zubg|EOD7FV{H%+-vuS^iIBRgg7=U6Xf>j2SWe=M529g~ln&Ob4OLf89Up~cgGn1wc z7ksUx#fkXBzFZ0atVxH3F}GiR@e@`pxPl+oZF*6ujoo`UZ*-LBQK{5ZC(m+Q3_eT@j!>KEo|CqhpRvKNm=ZCXIUhv+FKZFcH{T^S1)8NV?~Z0J@$GF_!ihW#{)&XTRRBq(SXT zD*w_gZ=_F;zM4Si$1W%Q3ut_6y#ssu6L7Kdqcz*5`UZ3w!80Bdy5=%{f+c! zVZ|-EmkMuY9$oMK@mJzbs#xy$Z#8kwovN}Rypv7KUst)N#V1kbSULm%kl)1(2SQZ< z{tby{$=pETo-CsKRbaI5kyP2u#@)%HU-L9Or-6L^%KA5d6K4ONCvb+^olS=90hxq_dtbLN5DX;@_WlNW5vDUpL{vV!$s+^QMwV zWDdddZzJJ9eV4A^gkLbdUX*OP6RcSTL{&ARhfh)G=5mLcogUHf|L*Nz0a5 zv=WNtQS?ACZj%)^;F8Z5-Z4PkZHB;?xgi>^gH>ZAReq{%xJH+DhD-#<+K|PTX2rHp zgmJXp)4#F-j$_D<)zppRGO|Nj_&rS)Lz|+&$`%%C(%}uo0CNg3)?GqrFVqnVm@##3 z{^YdjdjFCy;GJw94nLPcWXdLyb3v>(#kGndnOG+Y_oPe#f-OJfiAACr8e)vN zYJrmKCSo$XApb@^^j_;%u(gK4WxpTs>HE-@@KB8atE+?R69`Q!*zR=pU%3s!bIy)a`U z&^C(4bQs1c54gxelWbzZ)>l=5JfTelj)HSRrU?yn^lwH=5K@CYl~zPb{5T;!KMw`Es7n9|WgI6tAG06MV( za}q0p;12kwY%Ph$Z@n;AaS25znQxZ$X8O5wuz`s;JeabanA>Sc!zmX~cAS~(alaf9 zt&D*mVc?ar;a4WfeJM!Aac+rRILDL zU153u^`#Crvan(`mcrq|!?(|_QAB-~l`B9MZQU|?l&iK6p+96YFXYXkC`pkdPo!}0 zh+Faq?Ln9_K@P6rfgj}ws5CmJDuT+@@m#KmY zR!RsyFmOviCnWri_YY@g_`An%h7TYFlcI8{2dmpYW^-Pho)z1%4?1veXv0O6Q`o9=Am_231rbHvYCl8cXk-s0j zxB^!mfKX-=$oY#Xv+G;n8J*hY*JxAz)dT~Q+>O@+6NH#CW(?_Y2Hqf^=I$o&T$7%1 zv;LfsA&Qo9yk?RD$>?tz8%+N#umm|r&LN%+C@$%NM69mD&kV1=NFI7}SIkVNi&&8M zCj7h^s3xn11pw^E@sO-cuq`P!-S{EG5sS#;g4GXDS6PHy0G}F7%a**d1=6F4wybH7sZfqvbbe4j)a*bzu zKwG_(x!x0Fh0ct?{8@O1Sl&6JwZGT5L&4WP|(^d@lI@#Bw(!ZI;5@9-MFf6y@_Hdz$3^4(pMAv5Cmw z&yjC5x)Y+h>*$-0(w^OR=oSMK9K|z1X~py{B{z*Mrd$RXl?N~CJWh7944=l$c!&(=jB`nB9!!kZ)R{bxG!GvE4M0NAgUb96S*dFrEw# zNDBIC{j%Jko#geLO&k6KQ%u_ed%Zi#Ry#UMnK5dC__JTkTvB=_12Dqiiz&el)(tzW z{aRx1zo@JIC*L#IGD$Yt7SF~_!73N;>8*p9ddy4QRnYVWae5F*vK-%cY>{$!wYxoj zmlj~kL`P(yXR)E5TSt(7IfQ$}(%_>8$hHP(wpci9VpxQSU}EGUP{2L4(7dY0zgI~$akW{S}e5xkovX_!tvs@UYJ=Kd7D#LN9O?C+pW_#EYvbQ z6S&zy774CHb&_E^y>-j$s>1HnBOI&qeGj~?W0SlwdF)beOeWHg49OcK@hzb%i2H)$ z{+2VRq?lKkj?9urV4)vV00k2Ja0kX6)^o`2=+95j??eQqDbplTwi%h|=zB^Xf}}RkE^59*kbre}^yw9Np-EdsIrTdipwzUWo@+JlR(88Reyb&n#edmZs- z@})v>6GeGV^Y>S75;GFwiyMREa&mx5C4{mM>Wg24r})|RW=!g!ZwL#O zVk0R6v^@(nlOA|yxAi*5YJ_eaHJ9snj*jk9AF{=%$Sfa3^xqzf-B)`5pijMAlYIMc z*U@pAwqTQk*XR+d@^dcnz~{b{i-#*AK1yhQLT0(K_=7vS#Eg2bwZ}7t>sygwF9LLx z-4l3A@6rdqY7UU-#tRM&d^fTf?`JNEphrtBdIsam`)PTL=rtnfZi>=%H6RWy5J%aC zHIKcV2H?dPW=6z9_-A;G((|s#N%FL!?dn1WTZC0}Rzoukcl^BH3jU}TVQ2LzAMue~ zd^h;#3VDfVa6(3NUZEi?G0HQ=vS`<51JU+JN2jRJ!`RyyPwlhWS#Evt6AfUC@^h^AduF!JEE7=PMNbGW8orzthaObQAbvok?7~y2#N$oX z-jv!RbXROzGAC4n!%-J_vw(xj#xRPsv{M=PH`lYHbRq~Z%to(?xCTYs=g$?J*z2i( z)*wo|^%?&Q-}Q*}vm8!6c^3X$#%(X;=k)NuqfS5n*VJp7$}M-R_QXy+^s@a9-Uo*7 z9g?4HC_WxM<~TsM1iI7H^0j?S*Z(G!jNrfb?N$%Wxp~ay0a21a!0zYPCFAzyzvEGE z-}^t=WHhr}QxK=}De`V9hz}<%%3oba0}fyNp$v|+2l0o@bJ5o+gZyqBa*)2u;J#Cn z=E6j`vL~s1 p^Z#HL_xw|H694){?9M6Df__hw`SDwNJo)S|+Y$X#T2HXl)1XrW z8gn@>b>JIB*Huzqq7jkxP0`>qEd4B6<;)R(?*8oD^X0j>0lmYAb6=9M^L-Xq&YwZ} zpAVn?41J0J8UEqt)khG?c@cf*;`iVJHrhwc^&y;*aveiB-u(5&^tcWGI~Vu! zlb`+R1>%0{?*Yy~Loa2E?)(}1AUTR>d-;Lu+32?ZT=)%m`mgajf8DrC62yUv&;DAJ zU9OQn{C(sqB(2x~{-zf3gXE`FrpcFdYxNf&vke8vdf0G4$#^!)@6J=1B6ed1mOory z7b@Y&udz%F&y{3<%OOWA=W$-wEV8-u-<>JElfZHJSHQvDRP*toyTAW`<#JIDBPT~# zSn+_q|F6pB!~M&frP0IA$JtSn4A-G5F#d8wu75(=b9P`BOq%eSRjjxMvs-zK2J8J1 z@d~v6{}?hzZzT7B88X?qZq{ZiJp_2{(;M=!kUfw7$^m&3^69!HGA1Zl-wa zrokN&Skh3Lg15a%h*DjWWTbxGrQHZk6DWz(;;(4O>D+99;&f#`wHq7UfpokQP~|mk zFf!1yZ#U9$=`c07i<2_5bgSqvv%WtnWnN>#?yjjvQ-CvfNZ{?XaPm+TT>=9vIxSIo zF4Fyn(Wc33PEDiI);=$Pnc~|_jZCByea9ott@lw!K^*~I|B+%o#J%Nr|LoHF{`SXF z8M}y|>T_0)#Cc?Ze$%{NWr1$;Ic>4|prH@=|CB9>{O1JQ$7L0=PETZZEvDlfeqRim zW=3>6rw?ynmC!-Bo6Z$B9Dt3)c|92dki(i;&`wJMwwr<{&s)CYd;DeH4R3a?@LWw- z{Hz3buBb=}7B35t62;5KfY{>`}#hc{O5*Ut)Ly;Hs_O?amSoyspHn2}qTcGE;# z^SGp80r&lW-ZQg_^?uz`JOBfMfO_673Hv~A(caQASSw?s$XZKWMowG$5Caa0vgQJt zHRcx&aB|wabW-9&lQX0%EeiUhH5B^TVH{53<(=a2<$(B!>hY|bgD>u)!07u`3?Hp? z0sWf{ECpVlpUL>$d(sW9De67UO%BLsJEsxa?@`(dOXVi%gk|8CqBE_ixGOB-y z#SbnMGFE9MK7_|^a@y06SbJFgyFGMgLTN|Vxj9Z4SsaNU13FM0yaTUBHpiJQa(kW( ze*jd@CJ7RcNb=a~IFSiM$t$Lhv!eUMg#+IYw zo>;u$io$zY(^W5Z4ixHJ>Gu_=z7WtL9(tgZQP#(`69|d)@>EK}0<^nPsu{7A*dPE| z036?tq%L=X`SE~CG4W09$;=KLV zk8)udt?nrAZN`Sp^=%bPA_`?sf8`{z>W@?7t=pYqi^H8m5k#=WN(|-H^cFy%D-Hmj z_JWcCDqstvp8WlJo_@&*A>-8<=V)AkG5|w2Wltl7d78P;Gb|{_+&55j_%w=a0}$j+ zLtQ-Xpm)O&1DQgb2i!utKYlDio|JWUg&lEN9ET+GrN90(5@@}u0=hndKwm90cHn&> zXXFS-nBi9)1Xr9$H!2cnf{e|yfrhnXdYHC%O5uWAB9*hVG3!X4IE4vnpO%HVtA-|V z-cC0=^VI}O zKztcQp|aT|T3i^0TZsjeqoVM7-ds)2tVQU*@+&Vij>IF>55d27dTWl{JtIcfNPh7$ z06i+{*7r!L?O!2a`$-^3dGm|Ci!4CIgt3);u+CQM_uFX2pk~9zd1g($xmv_5}en^NkTI8-M$KF+%*rb+ts0VtAst4=;Truk|qaxrxMLSenszxMr z=0VtEl&ng?0Ae6MVUL_J#cQsWvWZ#3k6Ao{`w8+K6~~2lNTpvkc}DgAlexVfobc}C<^-wOT+@t5 z^{*{~sK>mVLA}ax$GH`gK9Cf&KE(0UJ#JgV>oLx^Wa@=jWk@6{F6g3+5^7E`o0?DU5bTE z@QQ=>J^0R;*|}!U%v12l>M;A&CZtmRA)FTsfQ|8xm3>i8s#Gs|W3Vw1J#6#I4*AU_ zsgbLwi$`3B`5R!DRGO& z4kTc_eFPCdKjY05e++$#mr0qgrL`f0)U#=oGDBL`^9c+h|IT6TsVZiqNE-FgXRolu zWtWH`l-gETvCD|e{K$=%olRk*3qZ5EudHJSSKNY{ja#rp0N>UaL3nskz^$iCyX4Ku zj_~#b%I6dd9P(+_WR?xq0++W3mk$WX{2jJ+JP7fBfq%a06npQ1z=J3s%{x;3$M9E_ z@1@{4lxEdOi@UjXe;5ng1wj{wx|P!fyEzS5aI(jbin}570L&X{{Oh;bcU)&OK5j1J zl5{8%QoCR9#0YnaLGJO1!uABqeJVD8h#X!YQZWa`a;@t*HnKHI;Hi$tK{i=s8$PCu zHna5E0eHb?(5}ztdjKl)$5Zb*bJdqWf1DHz@p!p(a7(?no#R$g;}22?gB>Um>>hgS z{FQ;R5E~;yFNVl?2=rb?f!u|-bgIyq4H_r@cx^TQE@HISem0|JPh)iik76V z`zQU2{ryCE2ubC@XCR?@h0kg4pUU<(Vl;K+JT*SC_UVNfu~KBE6BY)oy%HvQiQq`C ze*^&TBB4sCfZChF>PI=NQ=@N2>6URpN?t$<=3rNw{O+5!b_7Gf>yx;a^E<2@^0IYRJa zDqj{x^mHlpB2$woTb>@P#WolGwI&#{;h{1X$zOX2k13PRVxZYZ)x3&OX*T}6$lN1q z*{Gb_@yj?%9O0}5Dq8?`0FT{+4`hfnIU}KWkx<8VA;&U;i8z~CaYB^uRht~+DIs^F zEf0{D*ioi=e~a!s@0p0>O^6p?p;zHz4I8=8O%YVNz(XpUbx?rgP4cozprH(Y)4Cjp z;@f-&Vsb9DRf(R|YFdtJ+LlP~ghjLdNc|=+fyDL8*YT0OJuAN|kIwrQ#>>@j)?)9~ z@}Y?1N>ofxi8tIF^zys?u)~17A!d>{D#Sx=fu{8CrkKJ2q6H)YR~^V zCRkmFow`>_yGuRi;qUle@>mTdZYWx6B1Zm&|CLa|ch4N4Sm(!&BZKHy!hVIe1g!s??ZaZum!Af^1cxm6#?039(nZ^E^5_@1`X8K8o`0afuQX z>x|nYs5CQ^d1G3K*b3UR!}X#ca+%RXX`CTHu zFv(o5X*pzYYFQVL;rB7-D}bXN(VK_Ewjs9#H!34$=4NQEU|P9VWX<$k z>?0I}nH5^1(iQg@K+S8&_GDxTJr;C4s4!neUt!>S&$L#J|43ze`MAI_W+ZJ)My#!yu7mnxn zzJ}I$gBqc=J<`Q&+Wk}Y$T%ypbg_f-bZh?B(>^?UpihCMN*mOaI(YEgb~dmQKvO?v|M z819bS)XfyYlpA4IyyKV_t9<+^RC(fbb!&K8pP%SA$kev#(BB9Mk7K1SUS@=kW|=9fw7;_i~w zKXv8p4oYmMmRvq*?K&y>w@%sLg;vFK6`$JmZ-{1DmwUT3A#acsqghLIl>jTIW%Ud4 zHcMgs1-=lADqW>>`&-t;=tuo{o%3y_dK0CqPQtJHiZU>xzbkt_`&KWrggga1piR?L z&Ue0DNQj+=cr{N0$KQ+BlQxVS3t{FPfgfeYX#BB>SDjc6(>12y&1CmDStUYd<~0E` z=pDwEA12UdVdw;Mp;=5l8e{pm`40GXeFJ7hU6G(PIWihR2n{QkJ;W2TI`Mo8j6DYc}KVI}Q3nCm8mOEY=u3En^S_B}MHCjaMJE)T5Es z+4#893O=B}QnoK8XVbXX#^?#j+g9-wVh~RQkIPS=zi|29>Y4JeA#usf!Ad&ruKQ)F z@G25S#2P08OCRt+ubwsHOfP2(>7A9{OnswN`mwZq{-WS-L%O`lh8x=HaiR2WHIawZ z!YA8Kg7yi13r4b8d!)o=;-<&Pwn*FOo9`#P_sV8LPc4{&Hvw=d!tE&aDhz8wuYi70 zZk`xxZM*UDCsw~z<6-NoI{4Ga(#@-B8qQa89y%)@;m@Ub%JDowLeq7oYF>`^+C5D| zB_H$IWb6wecSkY>4Ye5wjA5avgzLaR2Njz$b_P*x zz>vfH+}4azHoo(|XfYXcMD!*bVS6wy^TE4ZXR`2qWxT0R%%F)9A=!!%gRqIU! zFBclO~2LtL!cO+@>>5(Iyc^mOfG8KCtHNNgVg2@+^K1S@MizK4u_meO75@x%e z6v0f4p+D^RqcA%qIrq2=padT1VU{z1X>A_;^4x*(EnafRtvb?6LwRQxi2I;P>Z_M9$=&%5TQ=pBf6=flR)C`Hoov z(N4yuiIEbpR7-d=vnvMydz`SnD8aIHA2%Bft}W#m-i=?|&948JS;t15ad*lO0>CK| zRcxkM@G_Cs>kj`+Wq5Yy$s2Z_nM48ZNH~b+Kf@2|#d!4z_pE zTUyZ~-K#(To^rfDw0JBf!Vf7{JG1a2Gcsp~9zByLu=rZ_edKYpHs}*@xM6iYFeG#? zclyge=0NK=;^L}j)w%T*%98n8{pw+XPbH@N+9|&Exi^ZES&!#Q7e(-XnS6sHR>-*u{Z&7AYwD9fa zKg@5BG`=R9g6|2tp%$o)Vl-+8vSHA->kEs42rYqOL$vl zVo;UFd1Y`YoU5FvII`61U=`k87UG{KzkY(3(1Qi*Cn{3Ii|t^~NK`JDJLOu0ap*Xu zP}4oOn?awSv6kLpBp`8{R?)M?8m?80-MMS(O-aB0X460&9J`e1MFq^P>5&?&Z4a zvU7gysS50*8|h=$mdbAkt-NhJ`ZgsvT~pcRDKA7Fv2n?L#>6Y1=Q1g*c?^HLR^_Yt5AjS1D&J=fgR( zn|>l;`0HNj3D~dS8+y%Abl?aI`cBGC6oS-r#ch~XwVHb# zAYK=#3xlIkFy_j0?^Ul3EtTy_%AZqw7WtwGK(%RGkO*SHq4KEUeItw97r(YjyL&9Y z3SF9EJ}G=|ObQ@jYh!W~w^jCA7suPLq7(3o9448!xtOjQFUqU7MFJ+fv zt|fKPC%^(9mc&S{J>*G|JLu*~Rr+!2G`MrGr{7hV+m6$H z?_NPP#&?#l0<3~AE082QG#6hkS%%!~QxnHaV@KOx&Yx_(R7QAJKSao7_ywnB!^lYw2X zMUzKM8)fd)&TeKJrAc?+%4_x3bUV(J#_A`_Q$`cnTMi%w zy!Db#LrPeq+0(R^HP5~pa*>9;|0pzPYhgEETpm2X`SS0t-#68MGFz++crj3k4e?1$ zA8JvRsjUn`KQpzAFj$9+k1>1oiI1~~g>%$d`K+1Qw!VNL*z8iT7Eh5lZ<-va_%u$l z%=C5eU5%%Pz}lvje*Z6_27g7=P&V8crlfSE6r2-Ml3Fsgj*wcm{0~uc#v%RPK*U_9 zW#0L74y)=x>$LCles)OETfhHB)F@RUo%l*3&2X;*yzWLnRk^UdndbPk=;e205t`;( zlG+Da*>c}u{b%Sv2j6dh|68c3$bS@SnLpddave{msBh{hcZ7`me9DeiPTfyMxb0WrnIBpW509uGbrzC&zE3Z2BVVk|0NYX&s-58VL9qGt+o^g?(zlwTd*V2kwL5$RfXUV zYxsqN(dO!grf9}K6#1>{!E64_ZbHA;09e1r%l#=_yt-KXf2EKE+0QOQX-{f6{8a{+ zb|kGao65=P{FozNquPTE<$v5jJzdnzw4T%{y14E{s5cjxDo1^!ECPTnVZR!YJdw>* zm@G(+x1gnhL%8GKTJR;o_o+wu$TV=4_)oWQd2=b-^D-SQsn8235Y{{^$Mcn&VjlhB zL6HegDATnPzqQQ!DAK*8mrmn8)B32&Vu{ddRFh6kUPbZz%NbMT!%NUc z1JLeK>`s^=MvJ*5GophDS33k|jfI;_e)dby(vx!HV-jh164de2_9n+ry9-Vi$f`}) zAy9GBx({Tv#ER}|a{~p-D;XZgSUP6l-gtC6B9lEU%expJxO6Wz(pBZ;FPy}Mqq&qDw)7uQby`rr|2vAHCFK{AVFk(D~V`@mO&Nn23Sqp zdAPqQ;|ex)2IU!kK$5O>6zm@3ja$ow@n^BKH{9e>{h5S1Rm-h?-SKv-OyEH^K(q1^ z71Oj;lHjrK<7Yx%3Tw@CTacjAi2l!hV5e2HMFV=^fNi#GW&ctjP1dBMh^KMg+EQz0 zhth(Y5;iR_zNd`A7+3zsN-b$gAPOaVHX6jn9cOiQhBu%)q*bY9{ZLsgP+c%d3A@D! zYc6Hv^B6VrF@o2?@S;60koy3%vd#8dDrzEg_krcdr1uA%maCuClr{X)4>hiygju)?L#a!gBT1+p3l)Q~nCD_#I8+|FRKgwQ4q|8<+ z7jJ6@C2|7X@zIuanJJ`jf_eC?*Xm;(y9|kF{C|r z%^@Mff*deiN3s(Z^*Pj={hCMDa;8W+$lUzGj17QCfBFaFnat2NU?0ecB3Ystx;Mw5MXVegf8g6k%)4jI&)V?vN*C_ zb`Gya=c=?L2F$Cs11b2p+PDtF45?o-zJQ80(<9EH*pUy~sN+^TaMXa3aUm7?B=XD1egsk*iQP)w)fJa-u)nW=-w2L+ z!#Cel7a&Mt8i(6CzfhkgZ_KCN};83MwXJZ#l<{+X0*tt-61Bb z+*y2gEU}bxQTO;xB_^cWw-_b5p=5vjc2iVbM`%1)BA+}mWYrE!AI?T#3c@=&IGqq}wy6AO&ChqO9%t|X@nsf2gsim?PGt|f8n75Zx;@>|6;o%_-Nh)Y7C;=r z!jZZJ>O;W z%@6jbd1^bXZ?2gIY+AXtSRKIB>BL5Z0ON+fT${RsT8mc9Ld6OzVEQ#PW}MN#$U*^d z$rs%*CM^J_aI3S2`lw*6Un|eUBv1Jo1Iz6^P03k@g3Q3Zymkjh<%0Mmtzr+n;qos* zjKX3x8p)<0JI;u*8Zct#)K*riw;Rl~0bprB0yg~V8Q^x+ABqMM)hYA5rLF85^MV5= z9jTbeRAqF+kgxLVDB^=|6EdZxt*9;W?>1uOi#H2cwT#xZ0(dw>tm!hMsx3!+r3?|W zBx~;*3ru@zRr@wx0oJEcCCvDwLZ*lU<5f0AP(F0fTn?8ivO>$staiSaHWOe#V^*$i zi`H;0*YHOq4w(HPs?PeY$p>!R12(piqZ>yG(kLA=a&(7uOACU42yToT-6fYd&jY3_i=r$^Eww08vfULzRDvl2-01t!Hc*sIbqhE zmVsUhAUtCL_y~yaI53iof4_@dr(_*qt0n+>yT58RcADS`4JXK9z1(2SSy7Vlo~FH4 zqyv!AF`F@Xpa%Z%M-Xdljx=J7_e#_|@CnEiesmYsE-0eYlYZM#*ea}lieP2`Y-z0> z@mv2Mf+yAhKQhJ>LD_*h5ToLwFa3))BmH-;%LOm7@$DZ37Qe{=koEbWkD{?Pm!Tw$ z(ea?VV)&NJ+s8hL*UV3b1aRT_Z5Rq7W&A!Mud?%#PpH(%#R7;A-j;BoKp+&iMW@q~ zh^*l-MuyTR_4C9jb2D1`sgqI?ZN_sk;t9-HaVzJ#k{RmFsLJ2>Q!V$?x&MM&TUeAz z?$&78x-W z;e+6|fw@}lX{}+zQV#yyx-UOfi?&M`t<&=zRi00aCym41i^=*KZ@mIpz(6#xMNw3d znI)cuj_rW8k#x=Y{(ssU`3EJcRqc~`YLUz_Z9DARb7$(Inlz;xT%W%{RTwkiw&NvA zC`Zf>*Z$4xc>HU%_jtRzk`cPT?PFmes}U$dgsVfXzR$`#<{ z^@7SmQ(Nj?*(IO(r}t{6zw>{ri>kb}_1}wjRp1c6Sy(=ZJm<$2y0gL zr(2#r3t@Y$jQAH!djbHf)WWw(#q)(?xVWy}UOl9%v&#k(@zvIBrmtTqpI5|~L zDmr54W`_l5*am-AMCthY<%0C0Ld(4J3$L?&Z0_)W zywNoCmrYB4d}Z9-z?KUj&@y2UWz2A8$8>Q&BWQhGgc$~OGrUdcDzxZGFTB{%3Q(-H zQg^MDRXvGZ>qpSj=%{{{FS$S}EWM_k2NiA1ee*ZZ`;t#6`@FB}gfWl(O*`q($P32g zBAQz7Lk;0#I2k}`O&^9tc+w0a&ch+lp_e|+xu{;OwFpzv*LrJJx81orFA2;<)++LD@UQ8IPHW3*+1 ze%pXtUln1|AY*@nXZW7h^gTnrGldI283T{Fkb=i87+M9$D=W6T$C!jepd@0y@e9=0cZKkriLen8WGlqe7C)%jbe`waOme>?_~ z(+3E(5M~%X9m$$S#&BvRhL&e&r0{&5M1eMjX0fN-G5qADlW=^NoX++Qqj`9{4|}aq zqDP^qU>FgFo1(xb6N5)TDO5T1oO!KgV*PnkW1;e&EM<)y)l&5+F^kd)(AVZCJdC6t zX2k*SXs@KA{h7e5LJ#-LWN9J~8UZb?V!6$K-`(Qolcs zG*j(!Wfw8=_D z5BT~h&U|g8ngt`5NvBBePx?xgJ>uQai~W|K?X|D+^&Vga^@ssEU}L3Eoa|?8?#lpS zSuwZ#i)G|3`+_y`m3>I^RjQsMFY@hqIAw6S`^CIqYTm+ZebL@?o?bD_kygu(V&oZV zk1tLU$|H|gJ_wzP*%0d?c8aqRK)&Nk>*is*MC#}Ijr#u-_s6hjA%oVxF2D9$^8qsT zI;_D;xB|Be)TP}+YhXjz-KQTfgvljcxhagf%FJ-zDoHTydKHQ<1o~9G^mh*~8*XYq z@LyTiM?U4)Dr1dtE>DWdYxBHk@qN(OFFF^21hg_-sXaKb<1qUNHP;YkWulOw#}(8?2Y*^sHiavK36i^Y_8#MDd1{-OA&tdO+>EuX%YUktRCd?P*W^_7Qecd{e#RXx4bcLyGY ztTfy#%42~mo=HFR*f4FpklO5lVz+7*5@D(-5|0b~bx0nkiKb13WDqR49#fTzPTTgo zP$%86FytjJ2WA<4m5*L@mF-kWN_B5|sUa~!k-cBgPjXH_ZlEpoRjW8rX5_248zm$% z7mOD-|ENa2^VjoY(vN|e%ETqOo2<-PJ8WkON;jQSOJPx(wg_ET*89?)pONy9q;~sE zdb3pKFPkIl8y0oNUT#84VX2k;pp?zow8YL7O}o-X0=hW%`IsMZ5`+xxQYmV@oI%h2 zg`VC{Klz(u600lIYWlv_=GC&d!KUIkpnj0i^-m~c_$$Nk&<@rokqtT}EXL?R#6GsV z-%&VPlsuWbAT%QWS|jciDJhUXVlP?-_`ba8RmbeS*x(S}nJxd*JuSUn2bm8NT^}CH z4%%pwtDx_?Ay>mgo>>A)ejDJ;v7sS5ZfK1IFILQ0h5wb`h=i(!!>pt5RA+>&W}xnE z42|wV#XpzZ2x>3b-#_7W$|kx)Bcc?VEhu|jtukeBm-1Hdn(}F(KM^Ne)oSZygsSYK zY4@Uy+)~uUd?M@98v+ztbj#i4kD!}d1uvc}SijG0hC_KK>hIcp;8@XL-2`iQk8@p* z?si5Sci`5_uWrSy5Iw-Z%>aYRk171~BwBx?w&_~#3n=H_qPZw*0f|7di!2I<2s9aM%X1t<)d^DAI6FXe!-=D8kIuT}8K^{qaRfz>LVjl^zs<-Yqs#$6a0uERNPD}QO)duj9d>Id@YQ(T>2 zhsTkGD<_!7T+@|Jc*PDzG)J- z_mhC5@DMr6yd`yOht-YK=0~*;Q!q#M{bUx$n*gTurOcr3ol3^& z@9&FufIqtLafSWp#Ur=$e!Iik?wv+_OqiMUhLtHZ0YPGlr<>Ur2BOuNO;NC8Uae%F z5@A{NVWHrqUO8jg7Gd2L!TY4YBy7iecBZ*PICsH*@9)_?r`mgX&*7aw2?LJ*E#v!- z0Tv-9{*M9H#(%+`sgO&o0bF@dIgmL^MZgI%PbPt2+7UiP75`6_$3W`aV|jf1^P>Xz z)101SCPf5SeJ*vhi3-Ecu@fHWCu`Y&f+NVXQB$J zL76Y<3FuIUhZTG3z>~^5#hp>1D|YfEm&nWx%2ajTxDx-3ZMIY{hDzhoKv@B_K zEw30a@zexii^dZBZt7$GcnyZ}@g*MhP&rf2YK0kIQ@4=Ty1=PwPYluhp2N#}wAi%X zT7OA9;gzhP6`4Eoz_g*2qcJMJ#!-dN^nzbpH2{@Y<19hk)%<+PgYlN1t@C#lFZQaX z!fC=upsn?70%vpO;+&hqd(JuB+*KPc)}6fC@W}=lyoUeiMAPnQ{Y#8sb9VQ!2a9N0 ziOzo0pje$>I+tarQNX*EPoLPS0O3H+(D$D~KLU#1d^&H^se}Rmfr}HLFWzOs`%00) zQ<^9nmShQyt&{HoBgCKa0^6%WPrfTzzS<0jm~n$yvF=9`wrQR1-LTis!tUQnpT*GO z9jYh>fg9E%`Kcj*EFnZ2w>{AtNpEr;pKyf17S;PeFQrxUxDV@}7=0xGcpYv|T;4`I zWR7nFQ~9E4oS080P^;ZTat+3Q^am@9b=Y z8=o+QrRcA9__kDS4QhNT!ZvjYMYLDg76EfMJWAqllFOMEIZHMW0C@Q!KEex2g^#k| zdKU!0r!yL5TAO3q{u|z$(25n|ZMTdU5v$Ve_N%UI|`PfwqH3jMkIxXD^0; z&t+k}4Fj}?{8;MaSu%qLZm7jdZILG;=Cy_bk*6TC8PxQBQ@&*I#Ov3azdZJwD@zK* zWa1Y1EGfVP8SYJ%(cWf7kJt?v9uzA7p}MojR}m8U)4g{I{fQ6k34|=p;eO!azEyTw8 z{Cl`4+4x+rNRB z4ic~1jsTMTH*AXGJHHKDOi3uI9x>WZ{I+fQf$BxZObExbr=yunzOmIFR&4rpqtEA= zLWS19-oy6+SNe2$qEKjqEcS1gqPpxdd9!SwmdoTx$&<@X_f0tLc= z0D6Z^0I~tZ=0YBYfZ4S{OZ0-B+MwSg5EgwMqN?lhbOz{hlb0-vOg4D6EpFEx8j~DM=iCm8ny_6swMo z*^N}fnZ``e2Y3Pk?MeZ!!~qTsA#u*=HyN>Gl6Y+5ejMB($t{fJF-n7$s^G7a4=aLt7Op^{zSst<~~;Er=PzI$3GsLb1bz^Cv#jN&F=$B zttZr>ZRzxO9>&oyAnkw5sXp4prR z&kgfs2pB03uI0g-2KnXQr0Jd;r()sy3=~v5>3@CFKlLlVU<0Tk-t5d0Du_6u?%asZ zl}T?UJ;^(u3Kd@X+o*pHpe|>a$V>7OWWW^|ROLec9T0of7h?s1-Wka9!IpYB1@gj- zylfNl?2M#4F6^)bBwHEl15Q5XhRJ^;YJkM%h$Dz;QcoK4tRoZqs6i<-(T-*;r}Wgb z5m5qmZ+McQ;Pbu>KFa~)6`)A;hyAf4gnzbvZAD(7exbOUlvP9gWgGISKDQ0siVbto<~3N zT;eGalLW<<7Q5006hu2zu4iZZEX3PY)KwMK24C<-eL*mtZWR4*f$ z!Jb*l&3(+uv(hIGMQG*wBzQ#4Zw3sYYzUqjVI+2Xh6|vtr;U+~$#O)JwZV)>*No$@ zLLt#g>AX<&cVan!9RKXB(@LM$yRILk0Z~_>!;s9y->*vmY5vWHJ_eye$yI$Ni25q# ztV6o|H44}}+7b;~{7wqTKB#iCb;(vd6n=;70csTzx+74zIBCEnZ@C0r-ux4oS~TX{ zRa~_MAcUyiWjOqet=Y#>ke_JG{@0@hm{UW!xoy~kr+j$z2K@yvyF}Nsr}%hP%s$2sr6~M{i~;Jj?gw|{2hV_2 zRm`*YI+`e&+95kDu^fU>ojdTjHhJFg;K>`n7*FU?V5{9ygxW0a@(~S(Ykc(t9!7NXHvZukRUg zZAO zETL~;2q%By;!@{!vkXAeFj4Ca(V|873nUD;@9(=th+&+CP;R!n{mQ1^_UCybp6q#iq)eq zYeR+FCln)|KP>bhZnSISbLaBOq)WEJ5_TS5T8(KqR(A#x<8`W*<-!SlAN+pavDD5X z`xe+gAhMCo#XTUj=F_k=P?N2m4SL2$V0ra`vf~Z)L8^@i0ykS{$-|NVYN4rTI_x)*-H41bJ?H%N6zk9p+q z+{2K2#Q#{5c4;Umc;tafzjD+NoM05QZLOv?@`!DeJ#;h}7mS2o#bpnNC-I|pOx-Y- z%r--e=9bhkaa7s-l;R_>H?kCkvLBpoM_i65o>@}5*4Wl~q!6}^Ha;AmX2TPwAG1OY zhXZflH*`$|P?=O*kLC=Ar(BuvPEPnd8ng(XB+#FD!#&n6J2}ie^=f{!mwB=%qkArQ zD)7T(2mjc~De+3p)E8NaMgGYYlBp)yX+z~<%a*jfz4~e9kAE4a4=q2cC`Awl534Z7 zCQrsZt(d%Go?#CI7&JhO*FEvvVJ#b|$xjT?`zRHal1KH!$>IxU`)i)_0$(b4Exky~toYbzR!Ny#!`kiP#w(t1k zJFWTaFLQ>;GmXhF?iap)X5_@Oxr44R3(C~Q1tm|v4dn{YS?oGkoW?>XO?B>%e-@4J{z zu`Jk#{+ryVMX*%jUPseE%NjC`dU~=_60*+ivpjFG%wzTx)D6ys-;d% zU6J$BV9l3Df)7dmP4ZZOm6jKc_b8QT`O4b!RWV=m2RO{G10;gosgc{!nii?$_M`c? z)4sH0WW8%5ZzsTNBJg_G67_l4`rocC>z=**o}=|%t;iSW{5`jxJ>BRfkAHjKtozU9 z_x-H*{imf4`RFuX?*~uc&;PWKP7tzWpYsN+vxp4~p=F95Y(Z!IW#~fE+40tEihaQofs~^n$HJp(h6)|ZbFv@0 zsVA0m-(TN;cbxt1#x~j{i=qbDP`o=C<0^b9&pk}sO9OuV-n#0CPD;A)635Ln-7lT9 zg49!=i_LcC4pZz&Wm-3-)!G5m_Ju0j>^R>GXyKI#3Pz=HIjwlU#$Nxp|3WtOLVo+FdH*lq z8M@?dHudZ(82gLgr}g_^D&Ja4Jf^ZSw_mL1OPjOT@}ZZac3^&yt5UWk){n5#tbJ3v z`g~nVd=xc?oO+b08il64fVe!|Bc`B$A5cKwDA*7;Y~yq zRnNyiQTY8S&;J~l(AnarM-dk11YU{T;ujlf?Dw%I+@^ zalxz~YAr-{NPa_H_=_m?25MJjY16-0MtQPH`R(`uP-xqyfFT66QY&ZaWi}j>2xkfx zlRzc!gz0 zM+LX(Bka2~&u*YJsB0d&_|nsNFUAn3_U|1T)GoMM+_ixLweowC*)+~+2FyCN3`2Vv zohOQ_3Ou2O0@Jlty^$;yR~~DvIM2=DJd0o7qaMfq%OA`6<^6??qSplDQ~hbTfnX-Y z`m95#1t?LfH7@afEQkQXcl)`qY59~zo9<45_c!S3|H8qj|8d}L-j`(g;-8pL_p`oQ z_7L{$PLGOWNKZfNwS7zZLyl5ps`z8+=kTrH8s4Q3!Q%^s9A8OZ6zP&j zVD&ZV9=t#^o6J zc61(%&Hyd$AK zoMbBRj^DOriSR=^ha%DTG2n{p_j+`yyhiR1k29;@qL`PkJ277zFTNVF8N5`hGS4vM z3Rn6Ztrt!#O1&3Mlaprq7DXuSQW(Yk&Sjkv0B|i$bvAb`ej6$6R#{&54(C=~U-!{5 zt7X`{E@b^O;Y&0X{i`oo?^R_yTbG^Rd$w;Tq&rvcM9TEyrIvNLcIDN*f7Xj5l=bch z-_O;}Ye)Ih6upTu4BRLmo$fdu<86qB$^_`N~@ACcl7%aD3s#lY^F~k8vDy`z{qBHG6j(lRI|_VwRJ_Nc^C(*5X}$-qAQ&1gf?OT& z`%*t5|LUZDzvmBCKq5`h)!n>f@+DNpMOLi^W49PZA_jHYfw z&GcbBX4Dz}N5}WT#7FL1llrI!mhZ2mKF`&xT;1P=asc}CVq|qNs?i1US zA)&^%G@B4G=)ihOpDgjNBvFAEJ`J1WqrdCQ%7{!VYJDw^N^&;NVdx_idXX=VP*BUYR|F%@NOd=AL!U`RLd5|rRf0UPQU@-iVa1sfdmNP#Jzb`bYP$?e(MKp4jV2=ed$Td8f4^8g8yIo;z*#tC9;n zN9>kwitAiTo4qOnON79Rerym-pUu1?LPkrv*MyoUy5~_PpH7uPfcRpiZiJf=Uirr~ zbW+vRW+b4cAR|M8ix>6#_Nx*;oHm}!>0yB*dDerOtgI;;WeBVw!+3%s7P-J zTT315;9+AtSsZ7_(7@rT@KvpWP;oqfRntY;o2<>d#Q|3Xs4xgXVH+h-_kr4AgT3UG zct4P+I{L<#maU-%%pjY(qRoRw6TnsE%o^e9$gL`aAy2OtsP`71V~*(Pe{yOKOD27X zJJw4^A9>zTE@O|4(d7E>22&;fWZAnhMv^HUogZvoPy)w6nW;bq?>LzKr#n$PJ1z|phxj6YNETk; z?FfnGTP{h&?&?}VYLxkZbJxI>|&8TqRc?LWRnS(%uH1G`JD{s3Rbv+>lBz3sR zN->E>k|?j11|u&W@Bw5-Szx7bO3C%ECz^#3@kLPsfwrCH%Xk=(xAb=Mzg9n@2wOr# zDF!Mo8~`p6_YfVLlmM&y1u~UqNKEc{%qTqvjw>v2BXlt?PQB2PN52G$0knL1$oAXP zTgi-<5*j$166z_c2svM5f*xV!XE-If?mB;C9;1lcx$#z>|~JW zkII1wDNIXZ2n*G>T%XwHgyhYmx||zjpN80!J)k`F(9W=5Z-YS7-+k!z&HuHr z|5zS+@?JX~$x2wIKQ`fHLtnZqGLFlCHl(og1h8Gt?sN^IXsFyXXyb=dk{yE>}S zhy<+G3Ag}9X2=ZuQg}zhJ+!8d_grbi<4dzn-=x@Tqk_dd)WJbppvQp&bn-l?YF2e6 ziZ9DN)Lk4G){reOMdPt@STw}WL+N=ofNOcm6B`DP=c6t|!G$a;z=B9gxUh`qN;p@M)K4O`^g6!w3?}zC7uPunK8EXRJ6vD| z;SG2W^7SePMDo(R+m z=H*NHmB@qZ$_SqpHBd3n`ET&6)2N_^C~ptmhDPyM)FT#*kU!^OObc%@K@+O)y-oNi zHf=O|L&}je8dnP-C5XfRvVeS?hdj*1?rtGtD7zyi+og=KsU#$L`snkm(KxfDjv)0& zbQVw&)x?#8&wIByswWyVJGPL74+>1+`4bInCMtuZ=Nf51XV^X)j+ZZ}7sJZ&Tv~G- z#}z9dR>X~;eISxn!kAEvpB0Zba!xconjkSyPeE#KbTs7hsJ}DQ%nqC=t`K97X(Gsh zj0TNnAtx+lCwhUCebkdvK8pS3*bZPc|0D~F*){rHA&ZGEeQ^A@@ zz2dlRTD9H)rc-F;sON6kDg?ALGBU|c$5`uAZwPb3oMSDSnb?iO$v#wad zrmn05I$V+^w3ibapvjSVkRZ-Xlo(U+oX7@JX<(Jb~!^1()S;q`EsU( z$E;hVUPEB4uQ~aM){|)cjP0I=HULFD%h)>0G&IZn3Hy5|nj-#gO=INdMlIv_h)u!0 z^O;#r<2f!T1E0G%xG4JK7BbWo9iKg$6|J2p`Eeo+E3g}Ty`wurY$zDd7M4o!szT~U z-w>BkC6fx^5Lp~x!l%jE!Tzb4=N_7uU7VNOov(yFys3=3vC~F)&dZA}D4|pqluuF{ z##PQPAZ1&#*?saygV1|4LPGdm#j#YHCnC}ajC<<6N2H56z8orIi-szTMxUrOx*;ZF z(URvJ%x=cV9#GR#D)V&3Z1}X5F;ql`_vU7?@PUyr*OI;1l7j%*KoON0>DZY}dR#>+Pa;;#*R>DH!ymBMKrze3BYhL!+hlyT@xf+I~%5}XptXCZKSVlXn}l^*}DFIXua#xY^64a z6J5rc1Q3Kx<0$}8q8(8t5Hw>l+QeI}Ljj{-0XCq2vK@obKLx+mz`sKDd5QI#ct=Sp zMf>05k|F_1jXF+L+`Ga^P$qyrzf>#NS{I27_@!{qeiFR8SAYJ37Wi@oDj${Tj+}{7p5 zBjWI6N1|Sz{I{fLpE()#Lk3rTts{cWQ0x#&HG3kPNH{#PXa8>IG`Vb<(>B77a8w*} zA22H9zXA~X6XjJ_#=*Mp;LIy{5W_!3j^`Co&=yfhIG`6J%8NcI>~TVp%4ZNFGoPmk zv^qsjmxYJ9cCaimlv(KH?Y24}wSCv&r3YP!@Fq?ewE` zwV^hxV*W-m893+Qmn(-4{t?qx#yChrkCKXsB=4e=5vkUW3Hk(Rs0PrhH)B^8KZybR ze?7H03D{r*=oJviwV7Ole#Ss-xE4s;t=jJRcVsCGE zjH<->Z6v(Kmgwv%LgKZ@D;~AO_X&+e_*rgYJGMl5xkkLB{IY-c0wW({GVb9{O%p*q zx#QB&qf(PiAiC&pxn6HME7ZQkm+;tvZTj~9?H~?S@Nw5KmU-OEcp%0C1n5w$N|#rHLo3=J{)R z_m^9OoBXW%a%Fy|cx~Y1t8eC8&Z5Qg=P!Rwiq`YQ;%`!o?-tBa>{Tk)-@lQ<+OVpq zKjm$ZYgu@wAceAuFa&(G-BM#7;^V^7XRM2rZo{5lIo};CTT#>xBg=qUXalq?n46&R zLBw5?!{PyNumhdWb%YrcTI(RHGQ&zu`qVPRNDv%swIlS(5`};T)7V-L8?7CA3i!382q(QNOAdS9Sv=HFW2tKKwa%#;^ z_oPH7b`n;$`S38gb;GZlPE*<~4Gtd1Bh4Y^zawpTNx@5|udizNf7j_g-t-=(6#SST z{eG&hQ{e@e&t16-!}}2M5Dp^2KPB4OtBAb@>OQjYj|{n$=7;tuI4ExnnxqMmt*D0JJ+2E=NYq>%9Z2Nz?Yqbu&E%6R>F1{$k5MFqB$ zqO~5L8)j~b2<^W}b)<_PJ}=6G5Wwb?h|)rb8Hg{sJ&m=h4Jx6e=5(l@Qc4w^fWY)% zV--O)>^UBXehPEx;N+SrH9{=4Nm~g$FvV}nFkGRi=5R#JE+x`{*z5PmE<2qm^Zr^( zyBPnl$w>!X+4acN+SCx=GHtbLaTSM$e1XqUP-*?%oXNB8cSN89Q2(#^cF2QLb>rH> zShX=uv6q5dyvgbG$^%ub-yp@mYgxAFy!WRq6zH69a$5R7vUqmBi(nef^8a_7Qf1X} zW_Y(Ou-e4X`8jK2LD#oFUs57-c7~ol6L%k2qm^9Otg$#zOfFl&(-f0K~x_wNb zlm|8s0FcbY$XNnmecWx2l!+M1RvjWKxnl~ABCTk`(SD?M&hxZ*007!AZ3syu)0SlL z1ptqsPmTIWsxU5a_sdR;ELy?OwP#7$HkhoqVDMX}}p0kO&bJLRC4S>F&0_;3|%k8?;>g;s*Srd;XpK}<>W5@TN^L7d?P zH$;zyn-Y|5rW(UByujaT?N5m2PT_jThW~YCAUj%kVR1gPQ1E$w>Kg?lTu)B0ZrrgM zifCt3YD-fb=BZyTf@;RHE<{@SFYXq_W?BVihGqiUP&N>tQZ3Nd2tkqde8Lu|-T|St z=uK-51wRB+=))quDQ3O%lt##l%3N%P(X!rOOP}pWWkJ zciwkZ!^?3gX(8B{77}^6uJU$qU`@ZIZ)s(rRoig`3#H-v?A3NbIOy&G2(6VniJQFw zl$LZqN;U9^YWDJ1FZKAgW~@f{gE00s^R$+X5mo(MoS|qQ0Z1wiboIeZ_(n!HnPel9F;unNe15NSh#Vswy1FD;H9)tWHl*psLK*JjBV)8Dwlm=v1zY z>aV;DMo3;r_A6xj#h9%VFA|o{2x`k7!q+qvjYpHhEh0v=`5fnyQeY3JFDhub7*nVL zSG0Dk!rY7+<2>Sz2q9P1Dk+TUs4F$Xx;`SI=Bnwo9+$+b&Jr4mH2j8h+VGF*NPO{a zBJdjwTKjq+CV-u-LwdCP6Ex1&KUtCBPsRPr0yr5RDuJ!Q5&s|_XtE^wO419UZ>u8n z?~DRR6&+6Ppaj7K59eHc8>8dPUaVtOgaP!xJQ2~H&UY~ppFjyqO7z|L34SNjStM#4+%srnDOG_n*U{hgRW{(ii!pqM0;pSBky&D+Z zw~BY+b*dj0=NoL@<9i8p@orP-#B+&!bOqFvuyBytLbAOuxY6`4%;7ztS8 z1E9xl92O*G{a)YFu1?Bmi+}PQF7YOtbcFeWW-+u`b44Z_Fd-fiQq8skn&X2ynY3-J zhFWE2XbzAX1*H}`KPG>rCrWW^ubAhbR%MH>ngw_{j=a{!4L#XNT_w|2j`U|4mezPM z3-68@78K*oWJc)oN&1+YvY~ZO9t;7@qUcCN_EdNbC~w*_sdPIlRlp=A6u+eTj+i4Y zU#`t34Q378J*#>`>lpXi@^CoZu6WV)gQ|53Q0oL$(RjtaT=kh=K}eBHM<>{W&(>X# zg1a&3xuU|~rSRTo(vR?WDyHq2E7=9ts^A~BOBjlYi>W)OVqZU^Lzt{nd-Nxp-Rl;V zkdn!k<%S1$i)wppUh0Lr)Hg=G~+m4cvfDhR4_A)Qfpj_ z?JQw9Fa)~PS_Gfg3I+p4UBL|0>kKs>`}N0~hQ4G?5xvyT{Exn2vHhtECxm011KF|Z z9uADNaiKKT98A#?bIeP#C2qUPc|VfSC3o%>?0~rIv{xHU%Eb03PqB-NKS@*c!E^5E z2F--W-URudYA&T>9)8kR54E!pTxD;AE_$c%=kwj;hFW6J4Tsm>f4Z6O&;R$RLP_9I zKlR;0nY>t)q1+<-T)T}{G-q_$VrluKSYAx(<+2u~HNtsGsnd&uPqF30S1&QOe-BWS}~>dnRlaUa{CP8|#GzJt%FTq4S*PTcPS zJKeqzSAQ57F^iH}`yeQ;xsb`^t|=27fBWk@yz(%Vs%wkzvE%?%-D$+bu5Ge?a_3nk zK2xJkIf}%1X$rILt4N%8!@?QLrqUQlH^C|ENz<%#pFG;D=exqW zW%JKh+t$TB%YN;;p8{j%mbKG(C(m7u_l6A}1_eD?#`yBQz=})1)=V2u&!pHlWc*N^ z@aWFh?kkho1x5r@*YDPcY{5$Rx`gfMvd%O4e@$dQif15J9`;6;u zJ0c8c*P1m;j_)6Av~ zGddr51vAHH8*_*|7mIswNSM#zy~y0`0=>ut762b-Ze9}vx0CmP#_=;zy`+%#`b{V) z@Xn3}ro=0|tN#PtHQEKZAH8dmmoCA1f^>rFOx~3;@$fjl`|Qt01Tj#**HzbK)}KsohhP5!T`-t zE)eCIz(bm6Hil4YWX?DefC-=HVl(PLXC{Xc@`GL5@N^yn3o?_-q3o226G5Fb@T|M? zr%v~~%cJtY_L&q&5^B`rm%q_^w2O^zpXtTu{JH3EJ6e67Cr zTA~Lu2#;%RS~esRvy*NZy>rEOB(wiP;R~q3uWc*{FByQqk;37sf0NB}fP%w3U*jH4 z?^wr6?wKX=NM9Xrc=G+@*-28Sb=^ruVwUYw;-RA&c7G_SIZ>%bCa$~bX3mwn8J6~- zVrwcRKX(06l0Mr^m$J&e9FMkd+{kZTuB?7DpODFUK_QhH{#K=Nw5q1<@UXgJ z7)P|8#@X$cO#7_Yn8)3gYKeBYu8)xKxIw#A(c^Wss(*!^2wvN9V;dLoNy~zX`$_AH zbM;BvhX2t?yJ{$g-;_^&{Y}|cS@rjK$8AS-4%uBYOtrMxd_TIc_o{#N+%7rty}UA~ zIrS!x;VcTAQGc_bO!VuI`3JI0yJ-kR=FBu@Nahwd8E8mNVy-!A&|VBTtHi zlo>C>K&0;4d4@o?lapRV-hHNA)s;@ZD z)7HAAQ>?LAWqx@_i&UAkX9j1`5CebD7NlfS_!H6cq?=21S6lxYU+(#R2Zn@XKOxXu z8DP9>z-rgnpAn@FNnbT4wN6TKWa^g_@MlgZQxRuwzNDtg)Z3&1X2QIwKqzO4kUGC| z%}8GN;!St2&PoT|Lfl+EO z`*xQ%Jtymu7G=z>!=q2lV|v5T)Z&d!YD1}-$27tw`N6w_uNa;orbpA72fhlu>Evpa z{F42c#yJntjFWRuXwHo;b%5>yqM37k zcDk+2xbwv86_ZcPzqXE0l;MfbtFnzfo%VP0)W__+uRd83UIK#go>3OIf4 zBE@Xp@0;0s_Agj#=Dudu%)PGjJaXZzBU6O__ToQp4l*CbuA-=jE+e%Ggt(IKddOlc z-d_cfCPlmP_F}&(^R^FcQ;291ulToseuXQWUVqCNcZbcvjEl8#K>{hwJsjKEUF(r9 zYQ!;P0AJ*NnkH}MYRA~%?uxqe$>g&&htt9i`KVQZe)DG93q2XLTp^LzH*A*S0q=gK z%b_Y5U{lYR>C^Z+S;6aJ>#FMk1943bMYyg_mL7f*Ok2=z5DUzXB*Vi zA0y^34^{vC{i3S*$z1o}f7Rv>C@n8f7XJ9}KYxGztj{Z;OiNuZ9!PK)qKVNxIIZMO`U!znIMT3Y+e^Ae}?dS+d+EF6& zu}x__9Hb<}3Jg#R5)G5%Lv(uJn@KFkB=4$w0d!=+1zW;osBx*x4CcY*3~qi-1e5(m zYaLm>)3#lN-BJWjj)&j95RvSowT^u{O4wd1%IIaw2(`%g${Ohs52EqMF?M>Pjt0n% zNhl}6f`OAwlp>@KiFFNA-g?dEehioHglCGtm5>IvygmMmz-JJ&tz|CNd2ok&h#%#+ zbS0$IHX9Eb{gV#hP!rFd5r55|0=^sPOQNt?66yP`*r+yerER#3of|z~n<~R*MKAyb zVO!TC;*Jvxh&hpJKpkKS1&pB3;YAD=;91^@pmk>O)O7?V5#Q_@)f~;geoyd#Qcr+O zKoa}5Rol|TZ4Ln3rQ(z=XB z_baFRW_fu{#fF5jc~Rg^Us1Kz@#71K8@}4(!+em-u_8|gTy{n?viX2r&>dfcU zSrVz3JdV-=%c(+8ODJfpeQbt zca0Rskjh;!6}rF*4PrzE2E{6npq12ILJ&jH;PcI|R9ZV$+D25`7gjoUSC&>i#4T3P z?58<*!{rpJyd$c7$E~fnpV0K*y*0(Fe**uZ4lTan54RUIEv0f)J9Q+QTGn7PtPY1ChF-ZxDX7vSVrpZvFX}Bz=Kiv#M!?B9 z)^IQB))uYe4{HpD6#eEIe_1gB`dXomYCSm%ew-3_B>$fZW1phgCFsl!6}URX_X1meEkGnKjq z;eIn1Z7nwIC?BtJ!)L=PhSCawD{$G2nlXREY&go& z=h4h4%Dsj2kg5VW#v~aX6g%VVoevq_pTPNdh-@W}$k9bi?qZ8s>2v+x^)o>cQGPlK z>m19q97@C7(^x)4Ts^yNEA5p2JF(ildvJ*plNWVx1qMWaDCU_TTq(!k?k`zgS zLffK&9s@P;mSP%mQF{tk%WV&9M_x<18gl|?^*5kVt^XVmRps9)=^hk1ZOFUd@f3j| zGJj;Ae+U#UXxm9Y^d6K**(lbIf@5U^Yjg(udvphl%yrwd&&I;O<%S^ zf4Qc6*Ci@Kk8n?*#?NfQNp9_F>paku7>Y$V3O+~m_g4L_L1a8;$w;%jf_$myiXDg3 z&l&*wUeIIs{6kxzvXM&e%_yLLTCngtN94n4Lr4!CDEy4I)H2f#kp)L634wu4h8Z<& zbU>3=hfmmmei2_?&Pf(bQ|+9kN$-N%QM`zBkynyPL1{!AZp#RkYznmS=p|Y;sk*=V z%g_m0PqBQ2ps8}(PePENBB!}?PzXI~C+vdl@b}Y;yXqk^=N>!!5W|7EVLg2H5o>vH zwyJE56kNDQDBqZRM!N%WI0X>@ zddb}&%TfJiF;^$nWMQ3KRlW42LI^_UAFndGe*6+L;)pD6--CL&jQpDne=Kj9N;is- zZ3RTRs{a}Tn!@khhqs-<1xMbg00oqu<}%{-ZwU`xmYt+SjEC{QA_c(z){nAK53YV; z7rZJb^C&$FjwdGVjIqk6z$F5G{@WSA9!%29(k0JN()7)MgUqLRMDda$3=gWbP2SV2 z-y^(sN2N?~ug^~JNE@gDx)_nKB8$PLQ_%I%dz<_@hnFBCk1-kC`lgwOY%VDUfGz7p zu#XH3F%HvVKB{D$e*wqyh5j0Fski=M@}`!+mvpIXinw>i`s+5L91eW5+t}zVLzVMX z-wd*stX3RA07qu&*I}mhJ#```z++tyDwSgp>cO1}2@a{= zUN~@t$dBn0Y-ax5*K$1szb+}UeQqTotLGKKfBW8wo~(lt4?-u`d_ABcAb~Ym5izlA zm8fQ}VbI$qH*nhl1zlTV>le}UTxBGMiGw0X1tqCIh>w9pdiSUTb$G*D_#=Jk(=*oL zA|1os0+D8&8qo!ERroWJ1x%Yd$)y)dEmQ5w?6*H^A_uG1NPds_tMNZVZ&;^Cq0Q9nn)d@N=HvAg^A*?st4 z?l(m9Tq*^-*JH?lKt$9k5$<;r{>2qRWVj|J*rNy(I2NA{z6sw=g;#Uz^7Vfa z>c8`An%9%QeuVEDcM#@CVw@Jg6R(7<(}1h^?L#eFuf$(URl6(yj^p=Zq{%$y1Yltg|D13w4@OT0{ zAOU9t&C?zn(ls02atL=|Mhd8`N+F){TW-_#zxetPUU_uKfEl^k2E3$R<80hclQDFj z@nvz@SEN}C5|H}e&3IbbL;mJ(_~LB9_YZX6;Rc#Wh*D(w>LGtH7P9t5BkQ;2`X9;C zU#xEs!vQQuuJzpRVG^1E;-3;$%Lgb+BFZ=w!H~VA%7SA8{t{T%__GG%Qu#v@bS&KO zOX)bi@q@kKgG^N#+(190P!pN)9zplv(woFTeP3b}h^!31BMoI*dF+2j*U13ZP9HQp zI%GY{1&YSMOqS_vxGJ%7b)5{dZ2RG5`^WBp1rt6h4aZWegPB9$sVJ6u8%3S#NY7*! zqvGZ4B@;WY1{^rfcf!^y(r%r`h5cEO+hcCzg%Mb@%&7obtu-w&n28@nfc-DsfhGFM zlJqYym#Ub8W-mONFWTBfPUdSr0CbQ0J{POIdAcrmqTYU{bQ(x_py6qRgie>RKd$N# zv)&mf(5l0}H3|{8h}IgYdK(_vXP)5ArGFu?ioYTG)8F@^tf5z%A24Ctl4PGwB?mmR z+T6)=L(>Vhy<+NKFSpTu(@8Q13D$oAE{733a zMC;A$XQld^E=x)`H=akU-r-QiZ?q2&Yx6cLbccMOYq<8>Y%+K3+hQINP0LCYc>AqW zRj)J7{W|mCkbgr(M>_LoxRt+{sdkPUKL^4=2R%`b||M(Z$Si-?J|T`CcRO?-tFn}oaG z$@1>@tGdlWN-PzVYnp(eaJu*Z=+y+wvaD6{?ab-G^NzcQoqC zB6wq38hy$_PiDkyNU2QAWTzj_>~Q zL86jI_!nYcamuZYBF!4Ml#*9WP57_R-`|dFz4zZ`g)R|-=~siFkh8d&YF-Rj4K~r- z^pctu{`YJ3tB0pQU%x_^QDE@!^tc~Sa>bPT$`Npk@^dcx(ap^2m}h8Y*%IMlrdD!{ zn0SPr;dTBytN^p0R8q+Ls22g11qZtsseI6OPVd#S^y#CUZ zUz_C6m&=_>_8DRZ2%_OaXk7+NnbDJ0bfz?ONeZf~oY0CckKb=bKjaokG#f-zuIAO^Y#x@N|Gi zL@+jE2*&LNeX^n_2zox`C3Uy!P3H1-X&MX4zg>Gy(r67kTi240*$~38$PaMjJBov1 z-#t4Uz`9UR#`Z)Hb>nBP&=lS%@z3+byhuiL4i`>G=eiHv&o#z+=f5f$@r_r&R@N+p z`|BlUtMTZ7z?e%x!+>@VQ~Gc9sZ1(K%J18>;4 zYWZDi@_n6NHEqEsC_K!4)t{Jjm2NE;T)K9hbE(hcqR)%?E=kYKKb@(C&1T~UWv>W}8?K#>9b0eVnN>q0v6!XSqGX`~R$^?!QMFfo9 zfpC8$nQWkc`~C=`m$8dEw#6p$C0=Ylz)j&-n||s+BG!|hOip}w`G*Le_u*)HJgQ&O z+h5P;n_pPhnHXoCw?YWyX@RAMvG!%MLg?XY-2Y{Q>#Moy0v+-5QxP$j!HQ@d;SZmH z_0po>JlC{*ITGSqY1)J?Hla|-*SPPsjVO&YN4OS!*8Vn6QB8&35yxIGyMMQs|0N}c zj+%FsWNnX@m6B?`u1JhBRHPjjz!ZWasHb}LtxIL9hi(f zT$s)q`Q_@8|Lp;^3nKy>h`zkxNj{d-1Fe%SxA7eoFu`YNqPx7of&~1>_aO9+7nTRH z1b5bK?V*JVUvg-kNa=Az?*$$b{NH6w)>E8K%u;IEW^OG1ibih=5!6eV5Z5w0m2E8j z)hX%#b_36kn8k_CZv%dz9UoVI#@%a)I^#0GrHtsuCx{rD^Cfs;hI#b7#)czdSWWT`AU*)A+>A0?ei;r2N%&&s4?{ay1`#dD zZ@A`9kp9jf9X1Of#5so{p1`vh^dn{V14hK8_}0PU*TkJJf#y|yJYDtZnqUMPn9`*q z_O?6B?s2;9>xlW&VS=*p1qctlGs1xsGd(=J96<~cLZ+|6Q!ddif(w)+>jr=NSXGR7 zl0wygG;oV_l&^Y7b-T51SY^TUuDqdB^UE&-1H$g3gWs=f$+b$&$te1rMf7EG*-oo6*ab^Rul2kR73b_Y{ z`N_mXrgytjO*w{v$nE0?gYUjRyE{6l>{uDOWMXKy`)GmYXMf?Z&7ICmQOzOe@U64~ zQNf>&!Z}dvR_l(^kcTgIyx$2>TH=1${yaNpt=BncdB5-7ql(X{iV{d@ya9ebNa1=z zIrD%Rh(pR~5&+iXF>Xzn$FDN4Lm(zl@Gb+<_lBe&CV@d7LN5rl{0M1C6eW+A_m(sN z8ho50BNxD$=*11;E`oHkQrfd9f3KDT3T20OUZGLiDe3(UlN}gc?dF$llVezl0Oyh@>#EzzQODMSwA z!9;Xoq(IlgxaR7#8fgtrsd@f&Avwt&3eC_?tN}~)v;)R@bp|@fXmJELE67w91|x|9^0D@NqFbNTY+JTg!^cyKyF5fWSuOc^pp6yEd|xiLufJh? zFe^%bXaJf5*G!)jrj>u0D5Dzd53L%<@$mv|DGCdhp~1<+;mLYF6$ZKf!;umyrvemt z3Bi=_io=L+nk};01xKZZxQ3O>q=v@R!Vu2`2sOX{Pkv)a+DTV{ANh7r z_wA7N+hI?W>4lMz7?WWgrZ9=|bVB}xPyTlH?YqslVbA#!SdFfXq4nw$atOwUkkqCc z6KyCm;#dC6?%Rc3&i}2V zC&v7*Apns7{-omlX}vkYFpBAUERC)iRes`o@_V4rBuHVBM$h6gto7otpElpMxe=yi zipHQt2SyXb&pj4jvZLlq{%?L`3eGmw-1i>&&h%tx^iSA3Tfihc(}&m$kk_=ZDYwwwY0(@e*2-x}3hO`J z3F?pqFlSoo_p~&4MrK&{u!n&^ka5Lxa!-PpO56dmttFSCwBJ8MNi#@J2sB)k{%9fl zfASj|zpwUTXR1a;aRd;Q0Mb^NJ*!XEj?K~uEC*Ob>3CV|+JM0O8oD{wIyqz`7Z|&O z8Hb*k)cK61(45r@ySRdl9m?j)!MmA49$%|E)8lty%_)ND=#A_Qv{ux`;&((RScnJ# z2>HgUHq2bhC%f^@Gt9ML>5a7>Z+wNe5*)T*x|?L(Hh*JyK45M>X{@KH^!-(Tu-pPj zYJnV_141=0ZwS(qO@cyeE>_lU!Ra8W{L6Mtv8HPvDSLr%AsW*K@}gg?EawMF3z>hk za~v!4F*XaaUJG%(oRu9)D?n>L%}!+TDe6J;z1o zxj?=*1#gNT(cxZ$@zg|cMBoTDEsN^oJRVDE{_MME|Y4Q>jI zZDgxxP*{rhg*JGBBhwKrbB>V}1RaYviR4@?Y|aiLmWL`x@~tU6;Qo2(gI7Hus6K$+nxU&beM6>FkaWeQem zu~y^5tGjcnS?iN~ImNg5qO2AyPK(TGEnaNMyI!2B?LJuXv&}D;B+0cH55A9%zOlVF zML@K(s15+P-y+GWORjvOVr*u$uZ#@@OD;5t=JPqV#oql=O_aS2?(=dj)B>39)BAfu zB+h3gf3LF8t-~jn5LL^rvPm=CQHuE>XgY{pakg*W!6`ihwgQU04Nh!uccKXH06|c# ziMI-$WNYGk{1cocqXnB9n(oTAdQ>A>^8VtztN_(64PZNE zAut+zZH<<1clqD4lM)`p__D?HA}Z&?4t3K&6Ud>yv4RUA>RWNSFgF4pMfUh(Ujv_M z4gFvs!|>jFX68BN&@XoF@~>mRINZz*f^+hz3)vCZVuDx57lSIez3e0ujV^J(_JYl{ zYSs}KR|uyBi%*> zn~2xW8lVHeKT}ERQKmKE$YXkw_K}XxO%pl*fmzu6Ce1*)+Xch}~uS!9wD(G|MH`~ISY@GnMxy`am#_l)*fNR6KC3eLtl=IS7z2_?EdVA*} zAPREq{sy{%N#)5Fx8ArDTOh%t5bXQ1ac3UePCxuT!D(*&276%LOS`-_ml1eoU%`+Qr2_1}*vM|q zo*3guX~trG($#N#X_E@lV6;nFeJPQ2w|(Zjsny50pQ4p}$Md+Pik}V!vN%H9)=j*X~^1s6Y#^}cN88w>j28CB^?M!QAI zxN-Y^u3^|H!e!-?LmHS%+c`kZ^yXC&~;Y5M`shnc@42XA9u$9BFb8{t8Z zKK-OvO85Bv1lz{7&d-D=>*Y;z*S&&Szb|+H25+!$RcZOF+6Fs`g6$)CD~Jz&Om@3D zq5Gddc|;W)wNM@LIJwIj&lOEby|G)qb2LBsV)a1v^?zy~LU_x~O>;s0Lf78SC8;lw zamQbxwLp@SAo2=Xb7oWT)kn#EL7dwiqgxlhn%CBXSaBA=qV4y4b3Pb-^&$(gX%uc=ex%j(gRXFsDzJ@?SCvCsVjB89eW?mh{3?gX+<>L<^D~Us=J< zl8~*M-XCgY-f=?papI>uu{mcw)2|ge2^Xl_hlhUQ9mg3We>0WBugRLz$R0kpXYsh1 z>&g0r)&t{a=VRN2!}w=2_cZ=K{`&XHpYX{3u+!{pS9ZG~NxO@jZyCqmVxN<$dneVm zvU7dH2NF@99~QENLTf&pG<@YhY5a5YjQ+HV+fPs`Mk*`;-x&T zY+6*_WH>Vm9ZOoJ!P#)$K`k?n#aHo%MtnT~olS|HPb;0z7@W`AJs4lvp1XJcA^v<@%W0fQ8 zZHzDL5`Gle{a3iCAV=;h(x3aAo0SU4CdndXd<)HC)Ql7GLkj=}50ALp z#@Tu%dWkuvKhE>-##KMYZf|GbXKzoab_fcd6s;5%jH+wsJ`t;LCVpS^sBasA_=iMr zx0W~bo}7p`wPzxC@lFA=`gxX)o>&WNYekNVH9n{(l9cKMUjh~LbaKN=V`)eeU5)vfQPPsv3!U*n;w{hG4- zA*GGjRcyWn)6`Yhes$IJ+)O?~I_n7JC#QAj$q>_rch-=9`Rm|~rslb%UOjsQ)d}|k zfy@2z?CL+0;O>MWDS_xRVnXis>OQbs6N#L>b{_$fQn{Z z^X4^UOF_%ks|GO@?*&3(i#*cu0Nk=Au2h185vm=xX$kgmQ8VOu6m3W)Y283uHdR=8 z`(<2hZ^?UkMy%v!oiwCuX2@TKGfI(a-+IYCFQ^I<^tqFO5W4cbaK1-a5!^Cd&-ffu zIAq8!F4u^W)kQ9JsLHbajJ`$ylxD3D90q{mY4MQAGau+rxLvO`SbTN&c-B!s<99 z^UU%bW-U85Q=Ey1*wr@~C^apD7BU8BLc1(qt@IdY5iK@bilhTTWyGw~>a5F=ZM9|( z#6EA97p0LsD`+NdO%s$cZE1c|-P%&yIHdN*&-7iU}3LTzif`P;+lbN6iC z%eKbQUgA^lFla8?L|*klFTqwXlfOw!p+2rsoIp;meH9VYED7a`idtL9AmoZ>tw3yk z9$M6;6mfx_js+y+1r2y>R}1 zy#urSxFL+~y6)!p_TI1=TW@u%s^|WCKFkp;&+g2xWS;#OkF+#%I0%~~EgzvzFJPWP zB(1xm7?4?`{=P<`TYMNhU=L@tce0bdG~9t^jWbMtI9r2wQ8h?)PNBpGb2y@{JG~BZ z#Qh{8H#5}-VFF@Xn|eKL>txjA5G~_g9E;VO+}^@#S^0hZZAAe~EFPq8C7BVWmM-)& z@%6tzf&`zn0|Pz4dN<@S2P3zk4+i)c^T^`2PvJO%mde-C!)(?v41Ee0z)Ono^Lm1J zQS;*HkMqkV&EVLhf?B1M*&abeOv_{vjU#CxR=%<7 z*riIz)h*@7g>9olyvx@wF!bMQ8@^*Xpvi`|8yzVZn0AUuF=3}H2!#=FX&#Vt8TyI#q#J+ ze8sUI4h7c*2m)u@Z`*a0HypWcT=oD=UNrb*YE!|v#wYqd;fXxkJ- z%1^Oa$;b;?)xXzmPDCbbB& zGKgMhI2_-4b_-B&xXS-xQhSR93{-CkFSoxxb5MS3C#N?!B5Yja3HJi!S6s`k`oxTA z(;)j-p1lwu2I4Owv_`$hvG8PUb$*XCi3wmnmP`G83B;WbnxAJWwLRNQ|L0P0T|Mtv zb2M!Umu~0s-!A>AMXp{bZ8Wdj**Xl_los3XB>4*x?@S3`GX$P|06r`d#Tleex{r^!Ua9jVtwGK$JKd5F-mv{MCV7NeM0wo*u*iVqCIuE7U|0cf|=02EYT#5C*QutnS1RlXN_&-Q+QFaR*e0B!=0JbIV!N2D`NdhVllenXt3Ux7+CR+!$*HUuK=D^Ou28r(5>x-Ueb&HjQ4{8GZ2senc|u>u9J zv=}c}==M~==IZEwG5xW5wQ&02PgK(5oy26h~;|Yfz z4qO$RIe3898?-3Dk2e+VeOk--1r60hnTjH(X4$^VXnN}7F>aI$wx%2==g`*sUp15>pELSfMN*_l{x~-bhO^M<}`sQ>b zf0=sqI`*CGsC&8=V6;%%d?sYa$sp(-6^`)lE9_xU)X!)z6O=!tBE-qquhtl?=ieO1 z)mLaYM?;tZT&;t_u*Z^bKi-|@?aN5K@z5?4`J% zVg_m8?>4)#n>K*jM-+DR-30T6q~@bqb;Jokj?Tg0%yHkbqO^8Uu8})R*@dKEghDw= zy17l%)&Q|SiPxXJ2KPsaJ<^;VWGD2eb`PL9uOpLv`_2b6O};Vn7Zqx6-{DUB%T$}j z{x!off6mk|N2A)gA|vxZ0q*n7GWA94 ze=Uh90?!YKN~}IfE#EZ+-1!$R*-29-yAEY{B(hf*FFD}&zp)tN->vXw^wp79$<-}F zu`TA}or8}<3tH)-=Gc5Oitf+Pqo1u-l+%g|VwVI#ke2c4ar{ci13&e%==P=P>EZN; zOXNFZ&Uq0;s`Fxs$8*E#db#z%#jnyOPJZDx=7FLox`6tzVxdCuaHJn;piE+`NK<0F zIIe0)Z1VYC-JJs6&s;+Qn_b;b^JjCk_(QZ#`c$Z??h?sP|NcjS;FIuvm*t%A^`->6 z4Qp}2H*zL<)AGIWLbP~xe?=Q;;hX_izsw#$${i=JX!+v%fIQ8^kDKJo)fWbZ1Nwz$ z$rA@fRdYJyZ6K~qUR*N6oNm9^EJ*y!@xzwrdzHtHTzk`IytTtCq)hxGNSH4J^uny+ zt>HIzLeve>K-Fz-u2!;U{wlv-pJ+ZamoE_4I;c5x-Htsl>3iT5|40A!=$jj&7_rL5 zAZ_B*KPorARa@NoCh_5(mi!}CfD~?Mg%43%YuRv7D0R9UFrEThZrA3C`r!@g->h=x z60BM@4GeA(j(~1ou<^@VyINXB;=7_V?a}J;X{DH(n|gsI00K^~?k!2`QGIJ?!T$LzX^z|4;wGQt*l+J=I3w*F zjtd*u*J@Pd)1FSzNVaN!9tnUsNzc(WO6I%bb#WUSHMLAJE5dbJ0sV*bndm1@|A-Tg z^v|eV9KIl?5D-TIp7oO>QeKwH?q}@gl4iCxD+`t;IP6%77_nYgSYU{XN|?A`+{9nOnRhQAxW)1{ zyihB7z%s|buZ@Iide2#36s%QLpZPeMImy~EDWKp7E$vuHFt8i|JYT#NrTdK(>QG85 zjn2=Ko%_>@@iGCxI}vro`HSF=q)n zs@8+l&y;YH!Sm4{2^VE@JL`#&r29sGoBonuqfEZm!r1jYS)W z>0^7}$u4c%o1_6#QmmntsvFT|3rJbb0S4#(H^0(So3mJkZIvMIg&2Y)O)igBI;*8+ z(xOSaAGH@2x5@f_x~hxM68J9j7?8wD^}z$$ zM8~ky?{f6Qi=-u-2AX*J(#(7c=wbJUiLv}3ITYr7Pw3MhCk)TnlNfnure&DO>y640 z?9hSX8=)*(w%!=oh-Ho-eycrKaazE;?u+4H2AoNEyEjStdIzPER2K$*-d~c`6^}74 z)6(j7+TE( zZGbRiaXr?!2#J%0z?ctr!Z(67wPRElS5aCX;-fsezN7;)f$iJ@Z`csEeUPI|iQaIv zc2}YsNQ%6azKH|=2m~(letD|?Oiy_QRtJSMmYr(xHZCOy16FO+pYnY`BBKiP#ruTe zCt3kO5WS^7vAA_Nu>V;~7aq_oIr>+4na0g%zsarr9S$&bLABU-*U&+07ip?+w=i{3 zOqLOHfh=C+Smq-I`PfYA-IEa;lC)E1In{}BdMv)3z&S}%$y4Iv(XCa%%eD68@W~Q{ zbps$8wQ`eOUedkrtugc{!b2awIk%4z?Kj-8YMdj=q)naF>h;$lY3rA$``L5a4XiQ3 zzLLiQ{axXz`hB-lEi$VYCBs;TI6Fbj7j)X$aawO+EWYBx@YGX*%SfL~ zyXfg^3dvS=+}d;{>H#-3)ZuwE2uK4&CJ)X;7R|Mhkk0PUNyQTMXHk1!W7_YJ$X@Pn zh@hb=w7G9FGC0J2;L*tfpMxD!SGO-|7-%h3F;`h1^#lBy%A3q0)<;hSR4 z$R(v1)dEwsr9H_$bM;1cw#FTJNgxC;%B#B8!`9vT`^&oMZYaS1VJmDj9iVlUXTvi7 zF_tv&@`Eh7MbZ_nqn)dk0mR4>BEx6f@=pPe>>wG z5&;SM5%mzzn*UQa4k224G*@@tt!t%_C^W|B2CFOC&~fpRc1Tu(6^lTVfLfG)KZzz9 zOF`zmA><&sZ+?U58mr5>D=cE!I@W4(SssHZ*si8)A3K7+&i6jLcPmd!li{L}Wrd!D z*m3$DIm>?1yLT&1?BPA>$(N1GUpttzj2mpBNC3J|5(UEohn&~30Pv9!xv@SbvyAN- zC$H83ROC+VfS|MCZG3EciCz`3jI*-1QFynG{hDJ&lrAfjmYE`I`%N$FqZGBZQk<#U z?RJvwKL!}ob}0eTY&%NO20>R4&{87wTVl)8EX9_FD*E^(vH=~9*7fWNTL!7ue1Vqa ze%bGS;11r)zl;DvSBlmKfh8r?Y?9ud!-)VzZBe^v;xaDU_||HpXfEwxy`; zbnAQcC#_&Z(R4?LNxYq6y1aHI|C<(_%at{ij|KB8#7BUwZt?t()65M6#z11w>l|I0 z9F_!|A8#Q9A=;*NROD`{xZx+yK=8059}EDotvED>b#HFoV3o4n z#8?u^1Y*?kkmaiWK|E`q`Jw!xm#SxaYlLs@w4&9+b*xjI^C4;mjiPdfR+Y5&*pPyA z)N6AjTjrP`P~0zTQE1<9Yb9+b<*GJ5;_mwOnp>ymADZLY&9t)~5)!m9)Amq1o8UB6 z-_cNihVM?1ZhUTxcH?1?W=Gq{+0u!ahktR@>$F%b@s1Gq?JF(V6(_-kS|7&q&b-^h zbX-J5DPZjP1fC8V+#872+>Rqul!gk9U&)xVwqGK7Yu2@@;ZEsqBFXIxLp(P30 z8eTnG_xXy2E}~borQvKH@BBi%sM9xYRnt?5faOp}E>{C)w10ATA}aBEk>gEg!>#M4 z#sdz?DMGGC0}`Ki(0~b}8@Lj9ddyNKK+6r~wCRB;RW=ii$_caquC=u^|!{8bfrJV}+f==jq~gB77qBR#vKv(LxyFJ_4Cs3uov7@3(NGbNac!$4QU3aXR6^38z> zT>cP%U`Ya#M*pZV=OV1lS)bh9`cA7x>5(HU?l#zRp1o;XeUzZjASg5?1kd$jVSfA0 z4}dPS+t_&i>$io*z68XW0DPn};VlS^8-6yQ#&kuFq*l3<6Duy3>Mffko26kRl}FJ8f80mP-~F52+6wF-u-y;&7I zBQZ=gCR`)Gxdr+?Q*TydaHWCX{NCSJirS~Kn)T@=;RfH~0j4I5bH zvMk#7lPsWm{%y+NZp979wQ$lpdRY*h^;l&YEHnKNth7p24dNt!CCl%Y=)I4dma^*e zul63dakcM<>8vz5CCvCqZrAZx9eZ@R8Rk0@>Cg(c_d`9~Z-W`14=t4d-0cIIjPQc{ z0+HFke8O>t^y)CrkaDPFG!slb|Bchmk}<=tQM;RRBbBjHb#cf28uS2SHVlGcXMX1?#|d-a9>S?(pq)w1mDo-W<_o-*ZZ>4*u8(KKL|h* z^>1{!Uz={-sGi!pr;zo)I0rshqY?{-5{9?aNW7jHmV!||#lwt@B2OXdzok`7t-`NxS#r_GY6v<>26=X2iS|IGr+HH{@3v_EJW$E1;t?c{-ciX`zp}e zg@SaGWf=era1&q+VjRl6wQz4WRvt%Tmy>x)}Zo1xSzMCPal_WE*W-ML?i0X9^ zj(tGSFxz1PTN8MZA(hqOYme!SP-0*Y7xy zb={s-q0f0@rH?w*R3a;0T)N3f_l3M1{XiM=Dw>^3X2Ad2))f2CGFVbz2Ad#(*zm@t zCD>~bP%nPhHQx%niPMEKp^0v*gO zBj`7eo*j~-2$N*~HGxELWJr_n5A!%f?Y`qyG!I&#Dfs+N6R1x8>t5n=v)I&ch8K3X z?rfdR8@A>k80>HB>=Y$%;$T>!d&o``z>N`uSqtpBx7YQ+X-UclDk`Ln=&ex`t+s|( zq|4+P_E>_5eESj2>iwb16<-$>@AA}4M9WRk&Q?cxr-n)y zeJj@81>$quJcV8@+*{3rk4YYHDqZGIULUs?%`4^=rFz8Dq90DSg%G`O4yQ7@Pu3*l z7tsMuZW!nV^u{*R&&Kg%whan+BNyL4Icn=T-wP^^T>1IrPw%G*%_FPCuN4nO_t%*o z9Bs(0)yyf4q*+|s>M0DBs20}+rb&kl79WXXw4e}cecBi_xHi|uWnEXVVlp3#(`DUG z%wSK)Ubz&zWw_Vrk0)kEIR|d&!CtpaUOB;3Q*KGWt-R#+EMbFTNVxXd{W>E;@(BBT zOC#97ZYJ$a0!)Hg$@+bBk5|_+gYNG4H=&;=&q`4(+8{-s5%mN{*~L6~d)-EVUC@~p z{Mj7F>6y}^t-brR+V&|Q0fT8tWF(^jSaE659g6^rWz2Qfe?<4(qcRsW@a=YVb&0vt zYT17;hdK6z8?VHkX@QDk^en`*p}99YGS}uB>up>srDgOPzhH-^H`;!aD0eho&0ixL6azJ8giYpD1ai9$%k_4XsOXD|$k6d+KwO%!5jEU4s5WM)*x`2R!HT{c7+ z^lby*rC|ZL=#XwClm?|^3F%g(JC%;58)=r7?(XjHln{_mL_$g_fxW%1`-ykwCzx|) z{&V~eI$#4c{$c-?pEDhP10B%?9fY5r#EqV!;d&}i=t7O{-h~HHI=L>-)1pEjS}Q86 zE+XnCv`Z+=Yj#KNAQ&(>?S=W=Wm33qcBR9dQIvI-seICEcgn?j&WD?APF5jsh}2@> zG|iF-1+Zz~J5GIxq{hu!dZgk{$#JUJo!jz!br-{Q;b$BTr*l$2zn3o_$M@l5<0uNZDpK&*eQ>= zzD)RX{9sW~Bnlq#q2d`oKA&W7DRTc(&{+wO9$%)m;5lXm@?hE&{h>(QXva&1NHLhx zc>P`07FWmPAs>zlk6;i9)6d$e=Sl!ce178Hr6?xCMXBe=U%X>^#Ge*OUi@AL0L~17 zL61jXNj9Qrd&QuC4n?B^YZHC`MR&{zI4xq!>K2#+=M#-KMR~-9L!rlovM&{di{8QQfa;59J{-K*6m=vMeL$rklA#Lispp8U z8fisyDd+dh)k9tSI+{^t@B~?ThA7Tffcd(j1_K>7-#5(UNImqjTvCGb0+cYPq11X+ zR$#*hz(ns$Xh!J6+|k4is43>L4Wl3{U=145d)P&SY?%!hi%1YsZ>|TFb5o%q!CSck zaefPxn5w?cxJI41?G zX$(uGv({UGOmnbuvlofX8&kP3S>sDBvm88Y=gT68R3nIti8V6aL`^ zC*I(xWbiIs7Hs5CH3h z!|M<|;jph&>dcIJ9IZTvw(wJTwqvABxV>;tGg+A2>K$9D8&h#4MQc>fO;o;cbg@@- zd0TX~7Zz$S)YlxKOz@h~Z&8ePq3!fIiuOdB+eCViBxaE$*e%k2p;kQQ z#6yTobfPWE+dMUUFUKjw%`as7_iv$>Zm%x*StS@!IlYrxJsAPhGug!-Zl@%PF?L-- z0*VQFZHReFexOXOP`d4Y^lqr^ACrN7FJ-+lBv$CqXShCwQm_}UAk8x>Y}?d!+sUnQ zHb6*VgviI-cAl0D_7+gKc_A`uBO1lWkluV7oHLvqv6h*xkexG}mE#@$=2?JS$S2`~ zoJ5hAt;!kp1zTg1uwqFNVnazF&i#m_qxgjf9ImILW-<6}XVy%n1y zlTP<*Ee~&Y=f!EUzoD)0E*qDYPx{Yf zLu!UoI3IyZqz3i$Vhu=U-v;*12JXKN<~(0;oEsM@8l6L25U3BgDoq-mP3MLMO{*`S zGvq@$l8j+VB=r^c`|M!Ke9JG5{?1CqOe-#(t^R*ogT(%6-fh;eziUzzz992>eb^L~ zH9N4{-dTGco8=Ytg{P#mqh74D*|)R3v$OkeC(mK+ny{~q_A^H0V%Mfl8%2D9vQii7 z_+u7h^6*l!6so)5_v+wp&mZ63+s@vHzr8>|$G&3yG)v({u|9}+KZ##IMOQz~eLubU z05i2vPn|y3z28gT&c>A)aPXjM_U;NeOVgmYpflp#l zz<*Jsd+~X9N{R0%pS5)FcPVu0hfu>(Rv+JZ#tYHna#R0G)}^9B3f}PIFMZOLt&=bF zV7ndO#j^&>%baO@Zex7elM)hZss3vj-D~Kahqd?jo-9M1Lm|Psb0XRMTJ3>q^Wi$p z)N%%;!Yjq&b{>4%N^ZWm`qAC9(s5wQe`;9x7R(9QCHA)OX)M%MLwEsMMABB21B zC)MV9VfzQomkcY4J{dB?<;pXgVt?n~8wLyO3vpy~hIJ(_NbCs&?1}X3Nj&a_seshK zSJjQXCgEiX2Eo;7XIMSfnZsaO0pEcN1j6+7KbpQF;|7ffxVz>Mj*wuGpp?=0JM1R~ zX!U>t61n{OXn0%TR4@a^3_?kf7V-)5J8U~&EyI7x9P{@ z`Ijfl0Viv;qD7l8aH!=s54IsLUcfm-L?2y4_^Tj_Wqg4Bm2*YCeqBXAP3ZaTXw#fxDZnI|#S)gl`J3pc z=Sbdfd#*fBimf&%A9vUBm&YsDSJ&>yYwu4R%@x)Hd~zmMtc_+*f6fZI#jekLvKC36 zqDSkwETkRkJd~c+MPojFko}N`)&B9jVsr`LGwCT^4d(NjZdR@$Qt@wQX5r>Op;dWb zx#&zTZV}e^JK8Bs=dqt_@nje+uyX)rfqg+(Aj%ht6S`tyErxs_RuL2ZsKf_SI@E7p z91bMs_)=rDB_I%3=#WSWu#g;31xygzSjMX)0HnB_h%^N;Dvb;<31v0Gl2{&Zv0Lnf z=mSGN9GeiDWq-DH2MZA*kDi-%>CGlVU3rSnyYeu&ESmtXV510dYnW z#C56J)f$Csx9dgC=j6VQDGm=)9QZ~HQBYWtCb+L=eAwc6&8-J~cNwXTs`bO6L%x(} z$JYtjTzDoN$b{5!W_JG~8S~v8jBH8B-;nKHUL<&nFU5kElS$YaMPM+GTTa3W@UTMR z#O!>M(3+;buK=Uffnf{`+ktJ;f@2lv5IfO;f{I|A%vn=$`&%9+MYLgo2aw(x$^8c_EsLcQ zzX&@dQ}L_}ezV4a4G15hgjA0QkyN}da-g5r%_R;*qa-)o>^1b2cqwRP<~4AwI}U3m zojQN@fSkK7W|W<~ueM^Gd+st7p6a%0-t3YRzu~NtsO}m|48kJ~4egsh?zl_7y{k|v1cH@}Z*=8K1am)RG5@u$bodD?yDEt1G%)hu4sW6kPQ zoX5JDz=(W~)V=|T9TM6+ZZgXbUw5IPhmd`AcU(@_d23o3=3oLuYH(4wuWYIni9Weu zS1@{Lmcp?7r+eF%66YP787(zt6qvHdxJuMmfY784rNx4hMln&D|H>32`MIEGylvxH z6A1pWhN8G0Y?pk~H?qgH4sBzZwKzGnTiNxc8n`e^IP)fVzez|C&_ z>wsH05HcLnYL|pGfx5K}M8{|{(T9amIaZ!c)6$-Bx%~R8NboPJC0K2PZ9x#s`_ZCC z>D~N5g!IyQ{*7^p-Y7LasCq|-VhyuYvcxgaQ1pj(3^f_91VxPf-D9AYgIT#<;&gY zR;;Gfb_+{RIVNSj)Sm43|BLPAr1@gy)JaED?uc6#2r95qf6C&m8iME@77= z9?Fc3-GHwt9vyISzfX}I$dws3l7YcqI@2}$98`|Hv zlKv0m`_Rt`y{kn+&F9N z@jC3nOt#!yRcW|ze^{VvaE79J7+g>m%$|T8d|KRhBA--;{kb451(QC@%NlH@k7q%%YdbGJa5u!p0Bnz5svx$h`e>E4)UQ)J{X7isH{ zPk8qHuCDbzjrn?i=-IoEAERBG9;!@380*6~Q}Aq|OpIjKo0@;J zmI1X(bF)OZRMfS^3USMeLvEP_Dy~kCY8j`P#_`LT&)8~JW!cN>?zR9&!Q+Y?HCx}{*VuxWteOHOm+YRfc+5PqAc{$WmS1K1!!+ zZyNSy1@Yw!4B9EC8bQTsDYv5-710%J4QUxg(^ntcRtoh&&Gqqd4o0lLYg0}wR*iB# z6H+g4yv1t9sBj1NopW3gGz4}o9sVTB>ivzMe>v91!%%_|q_qjFXJ}KS7?eFS92279 zDESn3>UZS(_5N3DFPg}u43E$p%i0iNRFBV%>$gMb+&1`>_&ieCZ}Htw{qU_g%J(VG zDawEgF4tO1hr^GFB#X|csK}?9$fr5Xe@ak5&r!gLC}6fLV2v(duPET0DBwOU;2|jF z<0yQNC={?O6pAhssVEeiD3mxXd__%31k0^Dw zEA@;n^{y!OohbD`MENn{SU;&6L{tsiRgFehjaO7nPM~&p;l?O6W*lCb$_CEK2JWK< z9#A76XXA6_MgjXqp_oRI%0{utMv0@wSD+>-&L$bCe@>+*Pvz% z&Sov;W*vJ}9UmNljEHHAs%(pyY>PW;iwCtQa<(Tcx2M{-qhi`KD%-Or+jEZEKY=>( zIXeoKJBsZ)N@F_8D?2JDJF1U5YC)a#oSluzoz3>0tudYLm7SfFo!v*By`Zjs&aOe_ zu3`JG(U`9B%C5=DuIZz$Sy1;pXZNCV_p*KWN=)}!W%n24WcTJ#_g7HQ4rkAva?f}B zo*ywiKP!8VCVNhfdQL&T=bXKl%DsQ=d#_`9Z!3HMPWC<=^`b$2K(0Pa6?7lALmzHz zAAVIIXsVCsxbF#JKZL6vrqWO1&`%cIPf^uRHPugZ-2aqtfSzlBQDuPHVSqJufW2yf zb83M5cz_2Y;&|{C;gA&9kc`TZoWqbp?2uB`kn+@! z>haKP!eI@rVJ($m9fx7P*kOaJVWX+xH^;*!gd?V0Bjze2mJTDP(8!m%K(u@DvHSeV0DMC@2p)mY5bSlsd0*40=d z*LbqZc&fuVDt0`hYCLOdJm+})6X8TY*F>SpM6ts}Y3xLK)kNjgMD;OxqLy&7o@=sE zWwO~}vNd+Hy=t;^YO?!yvX^kGpKEGRWop=AYBY9gylQH4YHIp;YL;+%o@;thWqR3R zdL3=i(=D@ z5+{qVh?bo_j!J>0gxwSocTmfxH# zn-G0AmC*ddt7Jx#Qw7Gh%85Cwb1=+-I?+IjiaxtX;*-ulAq^th;v-hLVj#FqQ^RCI zwn;TAZ+iB7d{RlFy^!y*IDvReD7)oAD2ix{c~@-MmA=`MB}hO5RA*csU{~sEf@8}T zg9Ivx_`LnAI1qeUKE*tumqHBb!Eb0uLmls@Z1QvBKl;O)`emBBGs3(3^XR}e56oyf zK_U%sOCC*6+?VAte73}OCVCu9mlaGkI{f1Sd~CjTHrFrHCz~_(&zUt=FkRQ}-@|fg zprH?&k_MYE>-91EH;z=Z54)kiiAbWkznYK}+;(RHf&brlrvMNGBtQ%(`kz=_IQai= zbCa1+uGw%Fl)4;@2M>(D+397DF@|!z0g++`#SR^X4lovVoRFN`pD$BDM+ZU<)hj@=hO-Po}ugm}{ zEGgza&C+c5GyX5}zMthXvIdWH`;4^xkrU?Fu9ccJ{iEcQ7tTN9J(Sj(3=zh1XTTOM zcz95l5G`j_l%i{AtNy0v!?)s?`JbuW=rbHJt+N}DR$U;`uO(@DTs4>zT%8b zc(qlFe!Rc#O~v9F&|ND@AbCl|gfv(ixMb4KcyoIign=R182~K_#hX(jGbVfZ#`N+W z4F2c5=O>&kif1BK>Uhm1KjW_~^rKb@c=3*v{s?P=AnOS)IlJgb)&8Ut=8gr467Ghg zk2O;9ht8y#j}f5K_e1eH&hK7wDZ=?Kr0EUMeYLdbq4h~JobDPNs=x8Toyo(Ak_s$T zT?LdezvZl4gXnN%6V5A=hIqHQqDT>2q$^{iFG;_J5m?x&UveCzO5B~ZTZHtS{5L~L*lx$KzpXf7-x%MOuq1k+kcuRfybX7p4KFzS5heh_ov4PlX4Qqo=Q<7|^pt|Y9Q>&9eS8n5o-SmS$@^azdYZ4%pZwAO` z>cWMlc8Gm%2I+|FBcJMifn?nbv8kQ+$Bu6iDYVLcD<%%nQQ51=RD8i*Q~#0ba(5fr z@SOB&*}{{$n`%(5Q&z1Z={#w#I<|rcPl?GX(U=@V(|cT>_?h!_EoL=~+ewRzbBCJy zb@GtgDH|3P_Z4&V5Wv_%nrer0wz$fPcs!uO>r%_PTkH@opTu+RzV~%@8%-a zw8_4frJRM5(qkCEFJM*uS;O10kUrDohS)HXw!d5C_r7%cB@WYbSgS0>X%6{vX49r< zxJ)Un8vjo9$nf8hudRfz^!@bF+rM`!eS(WEyVUeyH&AUnr*AQ|7(mX1(o1^4?xj!80a-@p^>zbijI7nzB8 z6hPU_9fSirRE8elnXreS9D*W>cRM5>WgZScX^BX&Z6)|wf144nRqS)bu#6&?U@*pw zTA>+ur)mChvfha+r#nx3YN+715!Sht@1gfO`-#yH9MIzOOO{}ym(OV=+2NX)+dq%` zfrsapulz#p&BUkVKtt9s|1q$k#w&|sH4(T*!dBvx(c^E3L@KmOjr1zB>|E+x{Iz~< zLS5nGReO_!LsB}{W~0SP+mirs+*c&KAoNYf@9vV9WWvOy!Pnb=(&fy;i+KZKj>Q>-rd z`F=o3S)je(zVrl`^x<|p5p?-Mm_d;Y^$~daQJ!AnYMvyUX`a%iZWsw*(Nbvqs-fqO zVbB4@$qQWR#g!5XFS@409fgRb$v59bVKMu@@q+TYnx9BVW6s40Njadxql|DNIt(97 zK7>+9aV4!l`y#?P{aeq7s1V~PA)!GL2P@r@XhK-~{|Ez+E@JW<^CuH}# zsU+Gt2F3mw6QRV1bkjwp%zr!#a{c8Y@kKh33KGp;76Wok@G}kk#+*bB)SbeEA|KnJ zY3Q*xeGy6ciGXZg8PISqRXY_EXotu9!I9n!u^f`JbqzXe~6fdv)7rkb>( z^qkMyNsk4w7uu8p2B_VHNPaA)XycPAD1oz!Q!)80MBE@Vi3x;37CxU7iTjP=*D1KJ zadP8oN>MN+FZi1%63YxQe>)iYO-G*LXB4zv35x>iRSvO<0*fhQiI*%moqVa%rmkTb&=(9@!9bsRNUDPdD4-o$z6$9O zA*(cikg`Mj*WkVFq;2fb<|ufV2+4uV$q2_a8+atcf{jngRtbm%^1h+29Y|Gb4ywB~Z}=-RS=qMDb*DvHOS z&!*Zj0sH$Cno2=hy-hsqP5R^z!r@IpkO}@HiHh4XQf?Q*K;(?LW)uOj zUbd4&8GvudGoQ0(79v+-=UHQ+8UnYs8+&b`HWy} zcr^-UWUTg$P>n`WpdZ1HFr|9wj8L>G!bbL)5zRi(|4V^L#`hV4Rv*=BX zC=9o6$2alD%%>hIDkkEY^Y5*1FS3Alt0ut`nGyyweB{zxDB*kYf2whv<4sX zA_K(Q!ee-)z->>@!*GK*V|`h*#YmAF+NCW>K766wLCj7 zOI~f`S`Ss%Xlt+mCyTWIT^HBkZygbZj4V24r&SenLO9ygtHhj|sJLIl+jEBA`kI1x z5TxWHQXZZ4AV?Sd+^+WvN7~<@(l6~BNHSZxda{#X$sx_|BJ?^~52usXS9k^Z;Fr5V{TGQp=%U2n#$D4p*eW;MB?+lb05KyefRPVn++B;ME zwR1(5(aiV3>noXD%V!ZVgQ6Ni6|lqx$q4h{R~xYK9YpwIKpt2Ck{V>{9|E7-?3@q! zauhxR#%r^=UIY6cONaXM2Sn?O9m>JE4eU=ihG_Xp@#-ry`t#~5Mi>&i2AO%TE|Np55pd8DAsVQ6Ha@_l@uYDGOqLDyiUvnl zL4`^f!@8ia;Q+&IXqF!sFB__+l7V&ZOc6SUc{YA&HSjF4AK_tLQw3h^qsK3(gzTkEk##W6jE8z9$tmwJWIWyA9^aCr^1oaN_O>8ZU-qI(Q z?tl=9+HkeE?IhMv>KA{H(?5mVKb1cAw!^_+;(VIG6fvgAdTR?-N)cy;&%el&l`|$0 z-iSAu_gp3`$4i_?Vy?{e3G}C7jM#jE?cI~9qZ3HZsbQq4;l=mlXa^X54)mvz`?uKE z!6~qSEMiIA9}3ieP6WICyi|t=?Mx_) z&++cCA8t_jZbGr&)UCv%%RIb*$B72DdC`$msB@1B1#ypmSW6EosRk?(?hv$b!JLOQ ze&94dOPmwn`;57>mS74l$blfI!3`I+qaGZ(Dv)RzlHZe@Zq*v?Ik46oN#c3paU!U| z(uNTI7g{7_btem|sIp24*-*>hyu)PcMZZ}qcUz{3Oc#rT?-G#~Qj>Oj!rl}Nb8&)Q zECW}QwuU$T?TNOIGdZr~zTUs#Yjf+2(D-)^cSzh`dyQQ>V#wc<7(K6H42^A!Umfe3 zdr_4I-4T018rxePBZ8D<&#?f|ccUloKym@!mEl5UEI+_#TA&+wM7u2wiz=-#NKNOHTheOcRA~|=Y^!US!_IU8w2_XI{ zxf&Sr6`)|^plFWCX#0DZEtzF2nMQb0?M!lh0)qW7vKWi|W!5{;pE-geNt_~=Q??nfv^_+lk2q^fvGy5Q@ATr0MYYx^ut{zBC>}sV9c0#X z34@0vZ|-r`9AQ9!ImD7Dlvg6+-blZ)Q^m)l{+VuG^tGteX2kre4qxfKXx80aPu}Jro5*J`vVUC!e!Uyx zyU0+S46f>s^6hPT0Ka|Mm!y*_XO|NAh~+dB0SE+izG}b(W0Oo^OMwAANX>hUAiyBH zxBTb55h`=&>DU~Pl)K5*2KuOOxb}fD25`3VTVwmdC;ef`q*rmfPHOIN(A4g)+56ev z3CX!X1G9e<&L4i(f{DZ)0Kk8_rC>}j0OF zX1dsu?y!r&s?5X0X!RH4=W`a?WGQv!sYyUBO`{DC+TRWrm_Z+$nk#h0b_}X}$OwD3 z9#q7;ahES;9 zG6fiTw!kQFK#{=CpJ|Ag2mT%+h{|gzDg=Sxjr>7Y7rU4G}oxkcqIg!vJ6aBtgFGlDaxEO?^u`#aPtu^#49D5jv90AQk7}J z8YDq^X{qbr$UeqBLwZI4d#0Q9tJ+U9hk`Wr?e$I?z%tiY?{irA*gc(gWP`iqP7b@N zki~6l$3P#a^E#34EwiG6_>XH(=keCD=Fory=|d^z$Z0*=*Ku}|x;U!YjH&?lGl>s! zM6QrHkDoA}$2I&yn^7#> z{ky?#sQxdF0nRJG^HJ7bzsp&LSN^Fox=0h=sr86g0k=PpRRdM4U&}e~V;a74E4-#m z)nO?~onEzijvIv`^@2VD3k&G1Dh4TS!3d0Q!gwvUD%~NNUs`xmfrdX+;Bo$Y1^n7J zm>ude3t+LNGOA<5$S%jFz*!Je(~fq>NR-sV8cKL1IaIHHC&%fx9L--f%3uyifg&Y? zg$$eN@qy|X)eagm1etZTnrmOx!3!crI~>dxc5-lw1sKGQ9yoPD08m*ahFgWg1G7GB zD&_-l<)tDWE#Opq7`%E*;jFmx80Yd!VCR3ozW}>Y?%dB@Yt6XMR|Wt44!)=O!v7YG zcM;rm*kSHBvGg^K!zE+G!vPW%$aii{ct3o?28@-e@U8*Y-^NTM|NC@pWsi@ z_5kNDSa<lu@7(V>BD?PU7u56Xk(&COv~A)mDb(fmS<|xihRYqgEiI`ii4XS1-$nI0K9f{` z;e~R?XbwhgSq5T^J-z5j3S2)X)Ra_w$IT?QibKnSk#0ALX1Rk`RxJHY475XNdul@P zLlFdm_|Y0l!hzn742(EY!FcvKHKl8nq}GTDd(l#b-(B_LLWv|ANN{kd&{2 zej#OrEuUyt)EvO!mT<4}n+g7H7mHY?251pP8%jm+gAVc1_%ay57rXW^@HjoVGpP-elwAH0fV{`zj+fk+4 zzy&qpKTI@E#3LNJi{@(B{a}uWgdfN|1FemA21lLlXm{kVp9O5HfajX;bOiUqxKsloZlIC*3vPmNX(?A)sPhc z&6@#|*1Az1a_$##bpu0E9@y-73h|;h1z!; znP;9!agBu)b=lMX`dQ^KDHg~V>%v%^I?i?CC5NtT4gqD?g?SnFV|%}xH*&~U)coz> z@KHUXH6KyT9;9pDv_7<)D; z+aC%q!!2%vcrano&-Lc1*)9J+-XcLF zCJf0RvT_n%(C%bsN1sXJ<6XRQu2u4iOcKG^ywvj??Ao$Bsx)RS`es@!SbegH<6f7& zLy%3+_|*6Hbg5*!Sx)eksxevWkUDx%T_OfOOz&TFzbmcNm@D}CtwP=v zg$SbgqYlQZH9?%Xq##@7Q4JiG3V%B>Ui_P!b9PJjVO#X!eCWfi>g>BDprLB0z$+eO z;}Q&tWV49?wny@p6HQ7Np(&#hS0kZDLca9G*mRtv=4^lNxV}qc%yZC}tOnn?kt3xw zvV2^y<&=UYlPLKQ(qM@EqPA*MDj1+m_^T-X%R8>1y28NvNXnnVUlXzilp+rr@Si23 zLS^FN>m^^6idNX9D%A^3*60VQ@X6gU(b(^5ch)doo2H=qBP}yvp>;g79e9eDq#nG? z!DqFl_yFtvOlkU51D`hhGT4{ALZsKvq8eCKE+D$@vV!mnApwFLOchx~m!8=+2!t>Y z1_>!0XGMsR*WC%?X%FjI4C{Uv)(aihPm!|)@weoqaIq5yr43O(Pe#1&E_Eb^G{vK@!1{7@B$Pg+F6c46{qUGQZSe^Z3QDARCEvT^f^; zdwsw!01ae%yqQ`MlJK#{fkguLgJHh5MDBD9_U(kIi5!om;Q0(9_VnNa5gwrp0()(^ z&jyiE2B}i=OM8p*DLVxpY7(Y{VRQ`;UQ-TxT{%XzgAu+>>&K+H=?%tY#AMD>;2#qR z50t4?SLDALVP9=upn?}76Htf{JWUtYaZN>2>ga$@`5MWVddf=6tl>i0i6ZTZ;{Vm$ zyP)qK2p5iGl;}VcyeU3s*hq4KFI}Vwd0=NbL|dzQ)z-zRp)XoxE-=_9;otjNV#B9?PvrYP>tLVFhT>X>1&LV{2qP~&>#p%t0 zXU|d$U+75*@Vn3{-9(^?Rq}Oj!e|Vql*?7M-Qqohk_|X9Ukxys1P&4Qt7v0V2T2wP zvUR$xrGmwfaXyM_#eLHXwozJ2ZK-oTL!HGDiHzle6W5`$Hz~8XMYDHJvwyWg*wh3M z#{?HH&}l`0Lkb2^AQOXf4pSi0eed_4<8jSApg|3OXIZ~d|b3Qw%j+yM@ zkbjb?WF#NrUV9NOV7Vyhz975b! zqy&}BE!t69mP^rrcssk=*8rVxgPN$b$XAAG{IbDe008mtfpp&d8|bpA@lq5bG7c=I z#6_$!%`G0QQ<2WBLL;Z*sH5*K@>^Cn3OS~1q)!(s3Ec>B;eKulGO<|anOW9F)Bbu#d_-oa?_}>|$;mC!8 zT^Z?Q^5rRBW)w5X_(a<(Y%Tf={zNw9qN;s6s}l=Ea%U6v#sDRqh@UApP!TC!TFj&M zW+9WCActl>Ge*1+o6Or#EX=h*e*9BvPFbrMIhE_{hKJQ=%P%eNUs}Vyw51vi&c*V3 zi`zClZ59A^pMbLCKt10H`XoT{eMs85gwi!(!v1A~;ncNZA`q#_1}umqb>m4-aU$S1 zZDZz_>F*mee>P^ZH|LgDMH5-TDJ35kGbCtYppA>wd5g7Kpiq*S6`J|9`pwUA1~qK; z@&dAa`J3qK?KhRoBt3X=&lwQsQnp_>P(TR)4p4x6`*9*|=G%Wp@9 z-x6*FtI}bPN0$L(*v9WRf60CQt@HJd<<~3suc=a-Hxh4XLS7r57_iwcl{v0Ym_HkS zh(gEV|MLF|WZA|L+{TpO#u|CMW@&)!v5gzPjhD8KU$T8pkqQdWES-MGDDss}L<+w! zW;afD<%peWiOEOq!>+&SRX9g=r1tfS zR$Z_GZ-Lc%T0%nIqn4h7%x9onk0Tq}3Asocr5LByySSI<(N`z>q*b?^#1Y8)$y!+5{Rj!HtqES&dH+u z$&%W;blsEB9w&k$^DAj5Yb7V^HD-e?sJToy3DHNsO2J>>bbo#CG0m_-y_P2-L+(mp-ZN}#7s{St)}MXBUZ*{vpH0HB1eQqrl^6mA79!`fC%CS-f{p&qVco|JEFIYS;SR=lzT4ewuE;w4u+fxyx%ek9|&rW zWKmD&p3)=Zo%zWAXbSz&Quw2-=dMXh#}O~oy!~F*N11?2%9zXEORsyjm<(jn2i6ReS%Rm?+q_o32(p;;Sx~WE zOtcA7uE>)x_cfyCI6BdUrQno+@-<^6o?RmzU0o<)1z^d&*0q1k6-)|xMRwO1&skY} z+ob=umKyN>6q09Bc@6wq7x;I$>~E-EegW$z2rD!~{VbRKFTJUEb=zNNzPn}{-!Ag| z5%rr2eZQAkNu%*G0}6j9^)mR!K)tLWA3YGV_?o!~kt8KNISrPN+H0Yh@{)buK`TfGWHiu zRZ>uLXY2c^+Uap{R(}mm(B4bHJ0CsN8{PK_^d%xD;4CH!H|AKFvZC((TLfBx<&rFR zC$O3hU03E_^UlZ2?XDC%5)0Fghz)hdf!l=jK zwlKJs8~!+xgRkrNNpzhnYL1$i6zOdpI2%iHP&MwxGt&NK6Bt)+_j zy>2-zDO&V8F!_lp4+%h1Xzi<1;;GZGdy(G<3#D4MP7k72$LslfA1kOkyKA1e^89*S zynoOAb-YpGE;Bzpub8kES=sjT@&4}m^7It_DiBB@oF{KZE&C*ha=CEHhYHCIz`|gT z+k?Nyf-20DY5_gw{Bho}X@srGDSpvcvN9XA)E;>Jr{~_Gf?|7JfHOnbqijmmY0-QbZ}JuuenaCJ@U~|f|A3- z{6&Wbr>>RKp>*qnXkt3x8D3?%M*9_Mx{>O*b_NN)DA#|;b#7$KnqB9KCOvVRz0CFn ziY=jpKuFhyMY^O${s;fiGp$|2Z@;8}yLE!w5$<*gt#UgpwN_^c=Ckg9<~r{HK4({5 z;W*{o+_s&d~nrD_ZRfLo~5g(ZvS}>lB9` z;_<2iRW?jkJt-vd9@cZHH9^sH1zq>@v|;Z0?73;qUaqAvzW$9xPp1DN>b~FM3cJ05 zk1->RV2mMpMvW3Rh;A@M??jItJ$eu_j1t}Gozc7KUG(02OGJ$lBoZx&#PK}mJ=gmW z?4R~t*S+tx*5`Z3K2xOXxhN#xtzMnbZ~S@5Q#PMGaQ4|}cLu@#n(H6k$Nyult1fTL z_N;ZCto|>#PAo0Dm63C?2~bAV_&$r-Nz0dg@mWvQoS;oj>`R8z2eJK}H~+*A2w$Tc zVTY;K-hK!FHQB2>{)r#gjmMQnB3XN`PsqNwcC0k0{xkg6@L!Xi4lVg(^br{nI3e*_ z@_d3@M(SeG1TA&B>eMZDMSd73aUJnmM*7!Y7FzoEVP*H*%IS`Wfs^&`C#C;htfOWA z-TwO8jW%ui_D|;8OHihSrpO&1O^wU@VpW5msqmPFg<1|e(oLjOi=pa75Ijk*y>jP& zOM8PLl;`LpPrNhP7a~BiC~`uo-N(w-+Om?OfCZX8x6TmrYu6#UFoZ0yuvQ(w`;I>N z<`f}>QVK0XM)CN5jWE-VXwCM8@1FCRmNabdlGgWsc6J9OBIL~!AM~OD77lzmfSvpK z5m<)Y+E~{B7Wy5IXU@uJ%$mkZkkM3f_dV@L3;~}lsbFB8>q5$A5@P!p<5`Va$zck@r)Q5;8pBe4vx7~r+eVrbkx&!Mcw;y>6Ok5RtN=H5YtZNs z*WAg*>sdZ-W>LP8a>RK8(;^9f-k6 z5DHBDbP`YDJ+iGx0Wd}2TReaQupon#ekv1$HU4wv>^P--k|KL~i6_at(a4d^@p^VK z;Wn5)CNbSkI}SL1T@Y&AkbmFHUaoB<@vXEF2Hcn-u0I-vo*{f^#ws>S^*C}QjI^QP zzSO?Lo{^4ifmb*HiZZ9tAHYl_(IEDFdo{nKu*yp z3JgLRJD~%<3glHRcl~I3|5li(b0UQK(c2=*7O0(hp&aZkSu+eGF#QJlot9#2ppGV( z5$17@m=ur8+Dd*DHD?4(Jho+10ssJ;k-1)p6iu6OQh>kXy+C8C_-!wjS)lRSwBUTx zKCD!3QVg)!yp%2`j&P;q(%53I4;toqX~>djgbu`#VI>94idX=WtgwRHQEBLD|;YX?NbwNGb^miWU1D*(&S!bjH zJJT?}mRy{&tYrwj*dP(PAg7Wn!7EPp5x!h|BQyY2Z_0U8|)h-z5@n3XP<@>&vrX3F!MkvA{kOhQJrpSPtNL+fG^3U|qq`lxE| z3Y5~;ww1Enkn5#E+A!`AgThU-81#i$pC}v6w=XrH8nX9>W4RR5?=pOwWSOjteg+V# z2@XUoz@Uxl3(S%!T=KyB$RMueeyWqkUG^zQsG!h~wMMVo6ZR^e)7L)x)1S~(uO(Ws zAm3&!VD!RE)ra+o0@)ooK@KkP5v>!a4mn88$$1lTRIb~-n*=F*K2Ac%Z*^d%W_Q?Y ziDD?qJI?a+QpuhRJJ>B^-Y(A~2{cfB5=8^X!_)Z==7x{DOL<7_o^Ar3kzl3~F|HZa_5dOHsd(}1{sKr1FaLP=4T{OYL>|sY3 zOD-HFAzs9rsWnnZ3uA0vqhz}4*{I6(OLD?JCSXc^akrE`o!-qmm5JyYk@$q53w6`C z#MxL+*!8i4Wj)QljuHm0s6EnM>U}}P*h7Wn3kFHPC4QLU!|fISi%r}mISr-M`j7?$ zW__WSSoVzNq0yc~6emEL@GXyh%r70a82E)nF%YYzY|_(?$Vv}XyBj@^?Gz@0w7!+_ zPI_Y^sspLJCmz@}J%*lJZ#yH=b2=B8y{@J^ZIa_AG|skB(iDTyxM5hDFznR0^xbTD zvN>B*Uorr}Nn{~$2oF?{h0V}$nuRRqO5< zQZX49fzu)OvQicw-NJBE{GU{UEP0R{<}t(SUF#Yd3Lq%Mv=OOVn{I7BCFgZtS1^Uf zgh|t4SjV}gb5jUIXkxHmiqy3^tv9U&pcgJQcX!<*6=>Fvha6l5B1D!ER#t8{$x@&- zm0Ai*)@AgGdg$zGKL|9-5yn)7*E8c@(B{sp;gI-ia54_SjxjP#}fK?jTYoP?n?-bJi5vjA2A zjiHWN5tSTAg3OG$tLU5vV>TW;z*Wn7znX5O4L>ZQ&oUt#UMr{OXe0aBhBKA~;EZ8I zt4AKI8`l!CYA6AWKG6s-0e_8}0$p>YY;276i<`wlb0fC^O32brMtI?A21_A8URYHI zf|COPO_p0=nstjbH}fY~)z~v~%#8%E64!j@$iO7-+JF>Tj51$5Xa4*saR~7XUr#O+ z`J2Vopv2ak-3}P0eKBB1+EeUxOFW6r2$c^VVF*IS5}7ERlZxQ=&XDv_8Q^@uJtqT1 z501gm0ZSfwb65ZitNw$#7!S*J_#IG-MNbCJo>re#4bDF#_$_DVN z5FTjZSuo{;R(TpMNUVohY3`9ny}YYNj?xyDw9CR) zC`Kc6fQUN(KxJ->Ms6FMc#EEWc6W3ZDNmB-Ojazn#YLGRP}`$-t%)#wQO#rFb-qSm zqyQPG%|pk~qMV)uLg|4E7fOG%ygNKxW+6nXF(jw6sI)6d0fonX=YBWNC)1#W?DpiT zKTD*Hp3i>f>nxYLf&T*c7&gPSJn$Txdzt-IpV?csupMP>KTO8^h?WT|*O49YU6U{H zU*o~qo-*JJmTW--M7woHFSetM2a9Fz>&Ow z?H(A}q}ojihI$*wXmS?odsG zu6eF$iv5pl~hq{-_3T#uU1^;ykX2l)jNinl; zFR7>;4L6s%y0gHlFcPl4?ykM8$Bg)Y_lu3_=6~4*3R9ub1h~K8GAYS}cbU4d7Y5xj*wceGT64&A|VEO`!l706dON8H=x z(WPm8>L>3o#nX(Px0$J)7iCN8S=fygIPP(@SuP`9RTFrPmdCPPns-;RW>>>i?pjV) ztGpW#T5E%LwidWWLAop)+S4G?-li-GiclJDjV+PJn=I-S&pv_>1-UUlu$kH zM#fMBb)-DU>U~BO%AYNX76r1G)1^`;)8Pd@s|9Jgg~(beNPd$yI+cUwGRB5k-l!_z z3g=lvEwI^|7P>R%UwsF+0M2p&7^b0%kAQ=$gm9dMAeL|SIYS~{D_)Kh_~?1|pzq0d z^a@sQ8F!odc(os%t~)rJ=%*Abw`gg$Tl9RfU*to$Mg?0aYj9Y2&Xw4iIo1HjXzoaZ zUBu~4cO?sPrR6Tk(;(~Yxk7D%(^N@%rg={M#bJ+KcrY+)iSRU=#jWu{dc0o4Xxf2j zf1Z)C?*}vt1VF%p75UXvd$WN3uHaqtP5Ha|MUbOQ+TvPa%$&VC1C*S_Q1t8$g|-(X zNlR>Fc{BUbs%gQO$xbzVKBnGALA#dsn;_rS4|9gBY^y;BjnG?tRM(eDRLlfCWZ*z! z?Q4_Kq{~LG-RC_^v$~!<2_r;&*TcLf3^ynf7!dmWBpWoalbuMoP5W@=moo~jx60A8 zh-kZ&*GN%^y?pY8##Gjnk?gVtSo?(3e=WT$kC}~|mkw2`e00CAR7$r(nn{*h2=>Ty zBi7x@D6WsXrv&y3NWdKG_kr_kn2Pf+`31IKxLzXnG2dkcSKtZfuIXnD9Y>ytd&UKe zRno+4pDJJ{8P8>{!jF$`wsYU3flg+bIEgch?}wXzYxL^^<_1{t2%~Ut6f!I#SL%Fy zug-JZG?y5Y_eVI6s||C>IjhN+snGxB1+s*5%T^+=k*}r6q+8h3WSX37Al#5&-k#sa z)Im`wTFEd}oB?E>rY<3I%1`7fi*=M$da*(Rpj)7uDo)_#GA($UMVp2`YWgHcGo_UR zzBRoIO<|dWsu?vm{q2+Ltdnb1edbJGzAfmWP&#ZpqEeHkH7#N=fo7W{WrZ*LCtuIMG|H+~o8I4eSL@qNT6fOVR$=m}@Jjh>C}+oYxomgC;NU zsKL%>H=EXl)ak7=emGTO(e}Qlu24i7-KbujZ!gEJOfnGrvUBXcLWO7~ z)iwhF^_fHv_^Jm@X<0^}M3e))5YqjpnS293gX$YGr z;W4p#1I7-dcJaLBWdEh{<eo@WLe=D5}?}556GzpQ+*oU+;P{g8O#7^8ayJIvjywO#7pd`lTR64 zr8OzM>O4RBbSp-8O6;&|eoA6}8?CCMQ>67xmF#uE4;+$vELtnn+4*HO1q)1TmX#G- zkYeDRVKov;s4fQq^HLtn$jb6?0M&EOy#RQhW_5z)3nQ2%sn3q)aGGaTi%TZ%llseM zekY5|FM>&YN)b>cOtSxon>NtEnTZWhoL!;PxHtokRVyvguS3k%>$~94R~4`k;|1qN zg+$B@>Y&EMQ&PWQB$*mFYOl#azRvu2`|aHKo^p#)v?4{!J8IiCzuFj>Y&$}6R z9FpHBTTx_j?iJ`R5Kw1LCkRZN{+ty1_VMSmEScH)DE|Wg>dSlAXpe5h5cgWM0VXx* z34r;!o?(5>9=!1~g=;NU=fL$Duo!?Xa1&H2AIvpC>N*}Pf4a?p%lP7zna|xQ)%kUw z0badJNwa-1GDbJ;EdaDDMgk(Wv4I(s70)T@e;FLru{b@5&8v^uJy~TaS9jHp5<$Jv z`Kt;6nD#?_Rtgx=Cbc>+!?P}2N${-}S=qQkK%q+_^wqLk1f zc;7-5D2s#>prU3cvEjmvTO{u3@(icgcTDq1eAJJ?CcY!iEzz-Xr7dC&Qx%GF6ZlJJ zZ%#z*H@9rr0+|v1A^tpn9A{OoW@`4D!lQR{C*Q0~>=Sp)E2U8y=nKuo23`d3m$b#I&ySJ-%QX zHA~WYZuEc(CL9ym$>|;(2KdezTvTwE+Aariq?ea6c1OBQo@Qhr834`LhnzT9K5a*; zm?1kd-r6{yU>pFSW(=6>(PavaFwjz|D9 zUw73(YnJ!yj)?KBDK?G)K~LX=ep5YFT5n+PLC3x5(N)CvW%}C((jvk#tMc5WacZt1 zx9lg~5($yY#1)|8+V|3GvXGZih0?=DlU4&lN7}aunR>;%Ddp83rVuPDL#>+C= zhywm;HZ~7!LaZE5j4zhsx-d333d3VMJItnD$5;iuPZK^O^g)RUX#V5gBd8#wG|Yzh z&N9oEuVv?LH#&0qi%>A@ZVF??emALhQ&=PPcqRzK=_=jav)o#Pv`$ji^_>+-K_Ar^ zbYdWQ(4clT4qc3W0i&%1iJvj~h+#8^G9P@P>z-xu=`^*a>*C~9a8~+U^+eq7)x_VX zy2J*JOmCPz?#c}}=lftSvyYQYJWQC-J;yHg+DNS(s?0whJjK6JRa;akCvBVIEn4H7 zFsHV(zGGS7V9U0p*3o`x+`7DX&-O_E0370WOFJJF;X`hw2c)VHr{wVh*lO)Nh>`dF zv{DLu&PAnC&6=IktmmREi)7VHog-$FB@Q)p^|X9f7}o5f^6YDE6^725W;%TcE!=z- z>z-}>36+8bQI;VE@s^l2Vsub9slJ~qquGFMjB^YgGxP^)Ox;zZt|ZpyqJw34N3WN#|)ga&N;YIDLCo9R20FhX8{G++?SkgE*fAw2^P~4zi<)zS)B`iyOCe>L( zT}JwP((?LWyqw zmdg6u$WUzvdAIa@Oof!u>vdwjvnhWOP|QTCk2?u3)|p`CGIlUjBY``f4;unC9T~&nRJP%ZKOn7UDZTal){s9b3HPj zc706Y0XN2*0r~L(FEh`3)g7JY46puv3w{`|MV5Chb%)Jq7q)WM#KwNoF1Zh`k2J>i z{Zf9+{qxq6?liz*bo_?`ce}hyUt7~_!!K=aEBhng-w^f!e`uKxT`tca*XPPbAl7Is9AO#~bGwGSHmn7keFt0e6WBsEW0dQ)1~R_y*xP?oJr1sL zBk-;nec%s$)Aw@0ph7a>yNK^gBkYW1?mb1Ntp_S(r~rJlJXa0Tx)LWMmXnmv9A|6d ze+P!{C~SPqtasR33$6&Lz4xt`GI`Soilq3$AyY%*T~Pbe+!2sG{O~P`5S`D!yvfUt zBA)xsSQh1PF~t~)mG`e@ASEQimJwy`b;cD-y5hZ|<}M?!$P*bYJ_@t6Syq;Kp1*U(X_~dQd0ql0P~meLxk9<@Ih3T3Hw&$$PIiQb7vDb_lDmfYBO!7-#b$G zfkgJ&Om6NEm~0!&XQ|{_0&byTs%L|=OnZo$Gn(!&D$6=ei&;do06k|kWsU$77D2D= z&ahg?@ac?UOMuZ9NJM?CLdB(E7=J;~mrNCPfo2~nnyY!J%Y)E<#pSB0#$SxkReyL= zca=Jj497f_VSD(ej*%#F?SA5#xf=7Z27)GxIirA?Qs*fu(H5C_WmUwu=gwwS&t~?6 z%~Fs(L1T{jWWpIXAX;=!3w&*|j6igVes z<5v-_*e|J=>rkBfY0k=3I`@;>ith`b0lgKs^=|Wd1UDV2nK|b1`#wP#;1fE}Ct?jxq<%h; zP2yJyx&JSiT>Xg*YcC2bfaFMG!qxN86&Tmpi73lG(XK@f6GOv`+7P{iFYxt73L%JzA{#r+D2Hi6&2poPH*rqtTw92d$ox zqYaXiKP6{9xjl4w3xA})4V8zylds^EI&6?W`6+$&L`JzlkKSrx`=gc6y4&>!JF+)K?qT zKb@-|7KuQHDJ$6VfCz;d8==A{uMo)`=&f^2JYg-cx0V14@J9Dm(nje|lG}_Bis(X{ zNmz%?TZhD2;fYOXQIHgjjL9A@y{D_lDXgd9t*6|i=f^Dn)C*7nXOcWuVp!C(6gIH+ zHgH&!j#9X_*nRkXN=iy!d~j6LN!T#V+b~jJ?^lD;uX6`%qhUsqQO<=Ck+<+II`)T( zl(%O2yhy&VNt3rp+ahXuK_9=-WYF7mw8`|40_ZB^Hg>2~DnqExsX;_afB zFJC>Xa=m;REMgz#V;|XUAA4zEon-O>?2ysykaOvfFXCA2<5=!FgO>H%dhKu2xL zZ{I+^FI6-*Uc>*eByXdZL4u3vt=LgbrP`Qmb*ILg6Z4Qc^RFsIHh=s(*Dv0D>$6Uc z4R7ug9zs^oq4mbU>kpjjx)Je)7ntu&u(Z~D6Y%A_|} zy`%YLz18#U#$-p!x6eHxSs+Ab>)Fmo68DR#&bFWXGk53}uK~on5D#Yhd8(`9>f2_2 z^zB2<51qfp@dzDd((2xkaG)bnq(~JxwnE_s{~@lp2vFhlJ_lRx{<#n`M=;0P{~vL6 z%5Dl|3Vi<(TpDzp z5{|Fnt(>N<3OfIfEd@mf>ZSPXS6^)eT-}~reGZaW;cS$M`o9x`5U$2>i~-oZZO z6&qM0LG10OzpN3{tZLjJ`+|GN`}@ExL0jywthDC$z+v_2@2~4zTCL^Dbe-n1IUy$F z0GpoEj+ny7s6XEZ9=y6a8$pW&ybXEuzjD>zb4~Sp(Tk1kNy|cSSF1a^YRHK!^_nsT z`0tm&Q@&3Z2O%rJgVcCG|GQasl)e3Xy?>j1@%!rA?QKZnL%=DA_}4Jp@oVY_lKZB~gQEe%Nux|hax3!_0-&D_)d ze*>K4py}--c{o!H41`*B`vwf9CKPvdK*mP3tsodtFSF zN?+03ihJvTvSC|aOb$jlr!llbdL2JM%6@I=)fn~6r*wRTdZaQ7SFgPTrOLX$D;L>2 zo)^~N+)Hz&#(#J(Y(w62Y(6&)FNTB)z2Q)*!9b>WZ+mtJAXH&ESz*vJwpinTQTcg` z!*JL?y*yC?hJUcZ8v+Poh|L$Epa#f}sw}r(9%uMWR+`JO%y#`wvEs9W+!f&E?mo|& zTv=y#-gnp!w3l`26AV@pkkv8cREe+)l^2!%t5g1$4LFi&`)8_aIc54nCh44LPL*-v z$g1Whu`Op)8{&1x+q-S$*2!wuzXdCwep97uUst=mlYqNihU+x4)p(?h{|B3Dn^8Z~ zFI4?>WPH`WGFju(GXCjw$gw>h@BQmxI+3xtvCduf_x{V{ci7Zb7ndqez#TS)V|~DL zQ2qXG@XFNm$yGOiao!*E;I7W`d1M0~^hC98dhME_aQWV0;M|M4xO{a1&5 z&9TEo_X}$ZC5$87jtgGE=i%EPNGNqyQpxvay(+Fx?KZZGg8PkI_}HM+p8VMJgRqWI ziM&k#dXsf{3M%0<}rf%Q)Ia)EU_D zK&MqzVqi&Gr%4$e3pGX44eE886Bb*z99gNd&Z{Vw)w#5V+Q2&W=!kq&usIJv$?dG| zK4n(M`(G>Bm`3owHk*R@gPgy)%?Alwl@>`AJMk%g4Dh_FZ^L(;7+F^y(jQ5B0D489 zzs@mG_rmyf|14K8UH4DMi1v^uiw2=KmW8;5e*N?9P?G>v{Bn1+8ka^=nC|Z9@H0`x z+`p;OV_~~wx4Vyr`OLY|^`W1zoZ76%CRu^6QLF1-5toWFW05SQUmj1VVJMkB!m=9< zirQM_VE#LJ5uh(bvflHadari2Bo8V*+SHi;ITd|3rC|6SrqJ4Eog`$+ZM;0T(4>?W zU~ML|n>(@aX^K1&4Ges$^ju?AfEiX?>a+*Ki0_yuzi!_q`CgT9p?nj~YtXoJyfNrz zcku7>nh>u6OmvF$=S~b6Mty(NTl<-z&$~A{mx+UK^zbD|LFIGC=kBpYdr$B^AS4>c zBbUFF4AA+69|s=lS_=Eo{?6~%Cpu#*h7UjM*4yr8e9a_x@G_IoC!fTeXO{5FwzkyIiY%zb}&_M5WPS!etz)N;9vr3j5l)- zvmr)mCzw4QgXU_*JYK=@USs&FLj*-bgbYJOu8lZ+L&UFNEVYDyb-mGhXnZ#Z)|H1z z&1A+G1p$ds1dMjb1%pAxSM_z&Fy4MVG#bNZb zYHQ@->Ak@ZgViN;!u4gt^?_l~t?=@lSLPh-nXhRA(Pa;(mMPd@!%l+2H00h21c$#_ z(LH2xt+{>uh|umGE1sW0egu8K?<=NvcdYf$zEP?#G|%2WUzBI<9o;cM#mYShgML-=S@CKhh$`2eVXOp; z9`_};KtG6OnT>oH9HXBA>aeG3xPFi~M8H9A^*-HFF&K364yJ4f8}o&BHN`EoLSIpG z9bSt7Uz3xi!hWt$nI1e$Rw6GyfZf_tTtcx#v+>beaU8+1q2UqWC6}Et`1jXQo!X=x zc!1HcpgneIW&}Bz5i~y`nrG|Hi(o9U7tDHy#{4<32L9G13uHE=7>Tlx;zr zn5hhoC3*75fbq%YBBM1V)D*jGF&#iT1L*#W5U9cS@gTkhbu4@HO)Re5Z?fTHO5x8>2lu91{7N_UV^V@9=~;3~{E#LL3i>MhAkSAc zd6fbX5p)o4sI-}mjZ|J@vP~SY3^aPsKc87jgDIIyFu+S66UcI^V>+IfY;}0s><3F5 z%3>7DuFFujt7DvA!CT9NkJ0xIerjRq(Qoi9j5;DOO0|PR;K~N$gla@*8)-nLj*IF#VE?dMsCm+=^O-g1G{^ z`OJxCHTwexbZ<3Z&p7vGD#f=9I2*7Sm{5c`g8pF5J0C5){^iSFk!Q|T#GGEJZdK^2 zf`irT60a3{ITBX`iX74Y?i86=?cNfttx~e!bdkDz{fc}E*_2!nK$aH4(rs{=NiYC+ z55VeJ@@NgmnOQnC^yaSPzu*835|<{Vt$5N-u8<14GYzuD;dIR4=1tOSi(Ktskns#@ zla{{3QK456)vwCzm+wkNb>DpO?3lwIgj_ZVkf}v{T$|S?UtmM~zB#C*!{5DqnI`V{ z!&gToqIKZQ8HgOVLT*+~xfX1K0RO6>tT!fWT7$Mg@_md;+Ku1ieW}#g(ym=2ha}q~ zv9+-D1d0SSz%=UVGJYxDS%w)Gv>IQdurxn|2=MLRosp5dN>0mn%iR#!&*tLYARhgc z-^qo4-2l|vfPdC7Fl%7K|9*ANKu#4ZK-?h91$8OTk{nNix;JQ5N(rntD9|;k(E(t> z8M@q!Q4LK3CefN}!D<~%8Xe8*S>(Q+4NvHrhSSka<|a_9j)ugZnzVLkk|XgC$8xGa z`lc#%eD3dG2((gdr677+N%C8Pbv|9z!H@v_=cEJl19ceW#2h zKsTuptgwQMkPtjvTed2}6k`;cXaA@sJ{b@uD2D=KrZ;}h}rb8RXz zp@cZ0!7riCcaeZFUsThFmDH5nm4+N0Rt~MFq@~Zy__4>n!qwOSY32GzYSLpMrcCN0 z_}*p{cZux6>mm9%hT}3NGetUCTaZam0q%4Z?H-(t909vy!KoYZ6|S#bc%U_FAZVR{ zaNVG4jr2)X-|bo%e!l%i=H2WKS^1(Ge+M|(zeJXHkc+-X{=kbcsN(`R*p61o%TX`4FQ3PQ@(uRw#U5BYGPNC=J_ z&ji^860dA~3@#mq{BzU%jLsal=7%)RKwN6U{oHJ((0mZf z7|HS;#O8Q&s}BiI#@HcOQNTB5=$jVO^b&Ba$4hV(STdLYd~EJ<2<<-!E`&J5L? z5W?tw;rkBcNOYbk5FiuNeP}#z;Z{ypPt?*gWU!ucIYXd=7G$%Iz(Cd} zaSp!5UCK773oKXFWH33rnY`Sr1zA;r+Gvrv032SaEdeb@POvsgPGqk9>S6dbUL7m| zen_)hlrA@3{8O-+3%NcfSAuvwr{AFPJDJtZh=~>%x(2#3YSMq0Q1jJzl$fdf++jHV zBP8(abB#d8c={0N{SqEoXQ?vs-rI&TB67SG2|%MH!0Q*(%q~LbFBJK`pG^R#17I5F zzDt$|j|W(vn2}XML3a0x3be>X>v7c<>$6prTP9>hIEG9MVjDNHHB++dM)0W?1PF~d z`{;!(+$2+N29t_|2cCDEXjvUqjT7dBeJWCbXVZA<&{rH?hW}z5XiV^PT=iai_=c(4(^lWF3C8QfWEG+y;RU82*N5Y6AP`mA#X2- zIz4iD+64{Mx7k#Onzq4=P8;$0Ny$5RfatL1uW!ix{hUnY`tZN6y0FzVa_zn(0W=iW z1V}9H-XQ`=vSB*BDmTrix6Z4yZA48_Ne?r!8?}L|9@4%Eb+O25_ow^XrUV2XL{H_W zoNcR5-FD<8_d&hVHA`y!W_dQXM4VhJ_!difB=s*GkD@9XN2OLydk&xeIK)mFc%dZ$ zhfH}Vn-B^(plEZIYdv_ReX{)PVeZtLny525j;O0<;>4|-pSQ8m>?0mO_#&ga=`I1J zr{yE$eMC+Lq=IELlyNuAnB@>7T{0+5kAb?!jiZ9&!(_~U)? z6k~3%F%t5@qM%x98vm9b!W!#uV?oSqK1X^4VSDl8p9R_V1Y7Y8X(2s0BX9>FdQ6B1 zPP+NobMo_>S_M66>#Y{qKdlI?$Xww>MK$#K)G-7h1AyT6)p?zHV=HCW(5F1bmuBfj zx4aK3{NMu>=MEQQbd;X(Q8#HIVnIi@#X8X#320zjU`M3kw zqx~uzAdL}EiJu@;gT4s26#2n8J_eBXHpWhU8Ra~%T)phv`62p#2gIUcJ{weTy;2=@ zZcl3NTnQFZBj|Tss@Fb3`@I-XuU_Rg8F*X;f_?uDn|XK6>DQRE*0-_Rh0chl2d{S$ zRikgnEh>&o46ANO+5T}&D?`CC8#UkF{Z4Wb8)RI0-nlw9+46kvUtr*ozyS5W?r%e$t z5d)PFM7vbUJQTPXoytY3(88lSNPV#xH!D7Bu`r`ZpPLhJq>&{h2mv&>7y(h&jxuH> z;(IS8>bOOB7(n3J$#P>sPVpyYL~I{{lxlb6~_S1$a`wa z9W+Gc{?)f?ui%c;1|97kZ@vkUWc*CSw%}6}n1BBf)U}$1hXB!|KMy}NOZ3I!cUX7Z z0-6;I2nl1U#i99=xw;YU>uNCGB5l^O(_Xv^W^5rrn*wuK(%Oyl6%yFs#K4nYuc=4 z;DpN|6F)T*2B8%~2x$9rMe2XR_(4T(gAAmSpt+(DMjLIRH>P#>gE&^Sf7A)yqb4Hc zqhV191`r5G#BlJijNCRR4CqvA60M2N)f4*aPHKh%pp^_n7!KXw1i_C*sXncu9Q{D@ zk=XluEF+CviX?)nQT`vbfLsP8v_g^muTAQh5x?sdTF_j9^Xh{nKjlS?Y}0I;jb7#j z5*s^|7B?F^z2|wQ>D;o{Z1Sob{M6KaxXTCV5#PXH;Qv9_1+v}BUoIG~#t@{@;Lm*U zK7xvs7jkV}Fq((gtYk&3<7}-RZJ@)F9&He?wafiT{SRkZK09F$@tcT5<|u)=HB5Cn zHgTKdGZp&&c%Du()(XP5GUjqeTS#AvaLyTzi-wG1t-i#gw^;K9j?Uh}Q%RPb5Jdi@ z+3LFQl!`YrKeuLFv$@uIgJDb0&LJfvszB;91jd?FGBXP5U z0Uad4(}Xd08vKaY$dR>g;Nv)$;zb9GYH5UDpXO;AJ+6d-@@#rv~W zJ_Emd3h>xTxqpGB~QlJzgw9cDO0npKi&wc|?KimBP^ z&^X4?SzWO)Ex~xO}cNhG%(Co12@h}EW=g^ zRIML9{@X)OOUaQ-2tdxH&wM`UXz;mEv)X5|aw7ET`2#_N8#W1}F@vT0v_WS(;WrHR z8kG&u%qi37@>rD&P(B~5ELzKB?Vi-R4oQ9)iUglI{H&zRkSDuEwkb@%6_gCITG7hQ>aLpCgMZ(Z}|~6iUWKwrT0)+$^gS4I3l|Ppb&spjOK9N!E*>L#TWf%0x5FySIishs78Da?3Ln#t@b;8nMe}Pq&vRKKy+qYq&M6WQ*JwL*q^SBYKD6x7fbFamVe>yiiu5ZWmOnV zNaL~NY!+C`Uhcu#jN&^*uL4RQ142GppH12NLN+M}mFH`7X+vHm!X$%;X``tU+lW&4 zq_QrdX-OXiZ-w#(>!~7RetM<`-SQ`M=dH%_P^Qz8v{^oUX<(;R8aY2cbsA3;B^@YU zHH#!I;RF6t(0i*Q_cW<@6bztgpThh!^LG@=Aepu)@-)OK1q<~RSMGV~v%n+$ZvUm0 zjf?m10*`Ic=H=6AwbT@h=vNB!e4l?!zk7RLEE9P2{PXV*x0>1V%tv3}2-|ROoDf$p z6K4(TcCrkysFcDF?IBD?I{Nomnxqj8^xR)St`o=nUFjhM^+Msjg35F~SNXGW3%aF zzb}2_*~d{u%95NWJ4OK~32puh>lYO2f0EouiAp<(pb$c%x=eZ$LkEsn&kV zwvLL|*P|2P3-Yr^3sEYQeu{A)aYbG#${>=hpTng^D(BIo!4lYqjY>gPV>KX#!YUEt zF;d@*s?99k`;79)aLy9Z$D*Iah#Jb4NUH! zG_sw>J<-9FyIvZuh41A;%D-^PVvxu&qp6pFCii!zw04Kw^VMeCRN(Z}+qL7TsOhbc z;b6p+gX1(tR8#hiCg0KI_K1c<`}Fz3w5Qkbjb7WY$Anu&q@xQ${Wn~%KTd;UnE`_$ zK2VWcH8Sz7<|DlsNcD${j|JNNT0d8&wtHq0l2o*=?WeHUTExP$6k*bYF7Z?`>{JS~ z55lxxJ(;2|)|L!N3K(OiB;ir8o;f?1y&@(fUY!XhoZjJ|V5W>`(yy}xJ}RHr=FXg@ zO_ROH#DrFu|L!|?D|F`>?X?Sg&x@waixtm{OUdR`YEIdY|H!5noox8xrhRxlyN48f zD<~!Ni4fH^uMn$fZAzx_iTSTi?Y>IgS&&{7RO?qyZ_65w;>ChOC+*RJ)=N2HaS$YS zM7QZY?oYPLjbX1C_q1wQwb7piF=uAQ>}ZV>-t>j;i&(jb3QD=>dl z1eA_t>>J_2bK>}ylVs9irQrvWCu7nZvx_c^1K-s6jBLpq14$0dF#d4U-B?N!|N zmV9z#*?0^tA+vw%_5UE3zCoAWIMu@u%NWFz|D?cc$`ucx@k``##I-cP7+K%VBIaOu z-yXfPHKIr2q;|X4id>*yiJ=^@Y+wFCy$~2lT9&x5PSmnMkJm7zRRc7eOX=Ga+aZpuPCr21m9T$^HRvb7) z>=hsI7gnJj*CT2O7MT0+d-eS!YL)f9_b8h7@W65 zOPaOgAE}K^>@ZrumR>*SK}K-vl#=6mpZBsKL#v|BgD(+={obY|A6G|Y=aMMnN18}| z)$f)Yssd1CL0}AJv`N(U|1fpm?`*~Y-@t=}Bv!<(6?>~Hs#aqsR&8oVi=rs4U2S5= zju}N;dsAw5Pzvo?@qA9zx6H>0EyN+pzP+5E zkWursG2I%`Y_bsylbhReo{yf^!)m3{mQ&Lf$w02s3fTRs&D#Fb`c0cmYS_}p>Mv0z z^l~KBt+P+~Fm))E<(blY%G(YhM%LDC2R~a~dfMK3G3_Wqc7$217h3ZAZF@Mqc;K_t z|JkNnsx8|c+z+C)%%s`Yqq&%<9(cpFATQzW$TX8mGxc(bjOH_!m<7;j8B0!nC$@=+ zRG$e_nRPDy9!^JlP-9E6f!Ud}todN~Wq+AY!2T)ZI*^!9r%21m7DcgQw=219{LNVC z1Nytay-PLb#%>{eV}dg)N4_N1=o}-77vfpo;eY?U78Lp1l`$07_V$D$bo0~%?w#F+CTxhUugh5CS9uI7cJyK#XSO-tU;-$LQpA1d6f1(Is8I$1z zR2RfM*n(H|-c00(<}kwC%r<%J9SeO30g$>w7!A(4?0~QlJ_0QHOnuLG^Oe)LF+&b9 z8VH!2cMEK@i*X3^$5Q-+Ib(gdfQ%b`IU`rf@%XUC3JH|SvL!xOoMw(4kFO*~cy>p; z+Dft=zY1BDdFyzLC~HBkAIpiE%Wl7(9HW(`;XNL9Y^n>%-r#DlG*?dhEuE@>!KW6& zkd3wbyx_$9#?!psh}~k}wVVYj#utYdcU&xh%i_V$B*3j6PLUdRjb?jL zo7ELqmD)OMY6lT>_u&&rj|zZ0ui&FFs~op;>-K_cn~FL3VdNQ#d+ga?$Bv!}nnG zOkP2)Lg9%UU;bO=b`NBk$JWNdlg}T%SiYZQ`doE4{dg$nrvRw(;NY~YK_Qfm{&S6Y z_{vo3kDcXP-z9{<+>UuD@&Vcxz3=n`@VNd*^G)@E6zn33vhxEa$zxBPF?c)T=_+T{ z4*bVTxzirZT*^BaI-EA2+milU;wLw>*!>XReJ900{Bo>J^Hn?*tvS`N!`?3)v=49;eyxQMsjIV zfccO(R+4CoytgC>GDbSx<&UgJ7?6=OnP0Ux6%Cu2Gno;hpQ~ZG7y|<4AEDjTieJS# zngiU->kN^S@!6*&=_#KG0Kj`UgbJWo6(?koF1zLO^Dvoj`5hnWdpf~iMil-wYPbY@ zY+O@mtXXT&y5%FCcmaNkz9R?q`=N z!~&9H5fqt8r*@{T2KxZHPH_zv4rJKGH{OXW_JK@jobR$5(HuuMMeTRJKy%N@kd>Xi z6fPlI$5#XcjsT;d`b9{p2BAq&+$ZZ@%s|8sQ|*;JEynn6md)8}t!lCi=OQZmOUb7{ zAN>pW&Au^702EADiupZgyjgmq`SOA($cHQ3GM%uGYa`DgK3Ny+_=A-#5tjeA z?Y(_DZtnW7koa1g6tK*81v$*wdH5HX`h)i8R8(WKCCmIjj&Sx>ag5yo?dKr1ioarC zh377+#Y#YCc)+?qNt9f)cgo)_?5k-CNn}we^v21}qmwdotPL5dls`w97DU+j9Z5DR z2^dhnoEav#Rpw8i27!TOW!k-`;jXjXs2he?+8 zJtkr#k+QJE#7u_Zv-0Z& z5yKePtC_!fc7m3TJpl+I^pF$*3xN zQaDX^IMWG9I3SqAqcvJQ@9@_fEix4M0s}1x%+NDwt60tYl1`zzDx;NG24>gg;}w6yItW(dV7eqh*@mrfT!A{ QMDrO8<> z4~S6a@>#rLA@GlQ<4UvwalV{}9;@R)>QL3z2dGF-|Dk3hv@B98s&|;JYB^%81Apq6 znR)Nu<)egqR-lqZCE9!C?{(3&V|`5dmqyj14zY400#|a@2}lY9dAr}3-c-f-rtfC;6W zMBP4y8(2R*5rO^fZ*H#NqV-N9G(UdSXH6mCNYKG@I*93Dr_fT%k-P`Yic!z=CHela z$0-2_q$52DHDN-le267#|6)Aeay&E;k~`T#ZdAznZ^->)&ZKufCUeQXM|pf~WkF6V z1zg)sWBmz#GMm^zfqufMXD6jl-L2Kq{wOGJ4v1!Ka1K|rL#5)K zlP$>$Vrr`*h9U5}VofCxqvX>635F6e$rG6u$MR);_Q}CIR@S`>&(L@g0}>cbnsnn zt--kfUrL2rZlZ;xN!uS+neNndGh9eBqo6JfXF{EU+S5=eaM}IcQA>LC8Rchf3`4`` zAq{Od^Q1=8*JO!E(aY*6tRuf*QY!ei6&zPvAOWy^u{a|80aTsFSNTf_^bL>=x`Nd5 z3%i6X^x$B8#rt9kt=J+>x5Nkdfiwg!2l$ueQOuDhhT|{ljpw{QCI6+)i|-tk7LQG6 z?NHPJLqJBK2oD>Ojt}!DIZTQ0Vj@m=#csv@5b3dM{uTg8K4p2U0xMtci!zt{E|V;u~rnwomz|JnfL2GC~!1g^8=_YPG#J zht)fUcmvhkEvz9dbsBo}lZXNeXU3MFEA&a11Q6W(B=m-LX=0p8Mh}lMU3I6?i6 z@^du=xOva}Ggy}ULs6hxw$2Sp=bT4yPim5u`TCJIw|b?husJHv(_}$7qAK3fgdpg6 z>0Fc^0^HYIMUi^h4135C&rGLRGf)uyfgpC)c4-J{0|m1Q5VsO0f44_7O>=&<2IZ=1 zUR{a>M;d&3%YsH2xx~Rp^BYDH(x@pZ*U&L8@13MDxFz`;vpQ5ml|)U`z6Vy}zk{Kg z4t&2VEP{+Fqo^j7w&{(e_PtC^sIOa$fKQUdttXq*x!;pa;$?cX!y4KgS}IWcrdRHx z;h=LMIPRUK3Y-lgRO0@{XnAx6clZQr7AtmAw>)E#i_A(xt<+(ZFUK+j_?+EyMeqE! znQ*6D9(h`U7^`_QfJfE*lt1t_>Y)ue8;{;4v~m<6ES_o^&TgE2M$ppW;2Wxn`r|Ua zh#9=^8zC5i(8}bV{w0ba`MnstDPisW@-^$7v3PIz)k}CHQwk~WQs+9C=^l2^G^;wY zesJt1e+QGbg$sI`#UMP2PX*XBz)9EBNS9zOZ9$4LAI^#kqvjHGX@!9@mr@Oc*=hI} zNW+vcqOma6zl&-p72CTH7*O8=-P-mC7=_h!Rb$6dUMyjQnw-^d3t}_hQC|)2q6g+0 zLqe%$1nP{15v-WK1eMc-1nfEovu_UbD{hvE9F}-)*0dbfx7>CF1{jG7oItki9CDMY zB^-rR`obkkq5}EY2KgQK+_~qOoK>L=a}=fnu-w6diCWs{th-cIn76x5g39CQ$q7g; zXW-eY7`0Y6RPWWc;~%ZFx8!+7SLCw@Ahe5*43>M5BpESRAmnU370Hvud2P zVikZyR_;x>@9C=oRVqo4FW4=gAj0^H<8E5$!u)H!JbU|x-s$j(EBs(-Pu=kw^cO$;WL#w zzAIXqmFEK^M<3OBVuP>8m(mmM+~h5^y1Y_00k5k2!K_VlE&?IKFKq5?pQ88xX{ zRXS4`chqcz^S6h>Em+-P02sF!Yu)?XPDPB5^m2vY>;8_3g?0F4uDp<6!odMY5 ziH}u1gG{UNdt&kp{4?*D7|r^0;Bs=ThxlzK3v3qoZ8r*R5BTkDBiuaAMHw7{J>4%D z7U9QnqG1KU6H|5i!ZxGmNdE&irJWXw&?bnbsFF^~)=_Z%nR32jpfsiZ=IaWYYY zN}?MHYfm>rhqh>=Dk~7Dodnj8ZJREj6=w}&FtN9YSYk@zI(`$@6{h=Yfl-+7oyoH` zP^u_sd8r_armkbz4N2Ku^_W&>^Dpa^@C+^O-$)y*29AV09XSu%M|F5lr~(h1bPbH6 z5`(=JNRKs=CCLmyCE;^>v0wQq+nS#DwnRsx!&ao=N_J-GAjmrRWQ% zV$DiYZH3cvV{SC%*&+xC%avK6oOkWlg zpkOteDNWMMuI#<|4rq5W_V7#^0HO~PeJmO-}W>q6E3aGtd$fFbc45>nc zlvXjybF*@C>={Qb#Vn&Vb{OyQXiQFMv{PHa{ih@%-dcZ%~GqJ%7V-*j+jn0%e^GVF) z4CBkF(LLQ>3*s+rM8(Y7R$l8C_x#CTm>Jxxb}SiO6dT$oA36{lJ}VzivUh7SI@S{d z@F7Q;Ds0&W+~02Mh>WeahEo>~VDAg5AIENfI-Xc}yv`jj{LZS*hr8dJwyO!Fb-h>` z0t^`(X1k|PPt@;^MNfFU=P;d&NYS#paA^n2Uq_An!J()#(2ca)jB88`GZF z?dzObd?(;vYDb-eATpNXzh>IJ1!k^o@xg=6go-Qkzr7n9Qo5o9)>W!yERqG|@Jhq_ z`W`7`#QjsleqUGi72UryCx=4L*@wadF{h)O03j9r#B6jrB7#vJztc%=_?=SwTZnAp z6OMWJzu7a^9ovQpS9`lZjr8L`MKU>ipwG9AoTwAVEw@Va2p}8-&r?53ZsQ)R1MuM9 zBq({Y^R#kdc=6`Hjraes2RE;F46*6SQC`543z-!dGi`?Q6$V*_IRndamxUQ3|3-z zv3@zV)uB@yFGndn{XKjAmg(!&YuEWFh9 zGS(gzqqb3LHw3UM-PCZpM0J0<3`?dkWon*!4w{K^2ceZUdc7r$!?OF9ET`5KHjZp& z2$lP2nrA)HmmFK@L6FI_T)BdAQ1`lN&Nv^JwhCIm1QjgK2p>%x_qjf=f3ov7v*{kD z7?V!SKo|k&`_iB?m12T)xqeE3haUzz#6+~eI62IYPMrL5f1T0_ZTUq z3~RF+H!tU@I5H!bOheEEgjABpZEQAF(gyYeKYD*?y2aZ1e_ufaKYpWeL8uX*AD763 zNS)Q_qJou01^jX?YE$lCG-(?z9JH7S#d4umS znF49JrLjY`$;%!f{))8U;$gp!g=UMXc^`V~bTDfE8q~kcs>^zl=;>qZF>o`*`Nejc z&UxbHn8tVK;Sf$UU6c3X0~!`O;s^R%D)SrSDb5~z>G!WcTQRh{-tSd%25fX`LVjDA zG=!Z^TgK{V$1>{GGI04_n2>c12oGM>K8s4-!<=B(f{bl0Rc%C>Rn$LR&Z-^#lW{qA zN`gX0yjkvx3ICOyQCO1o*sS;XA^Y@K{ZnIBxO(tk{0&kzaekTOZ8s|6l-90V(&e7$ zu&hOak7(n=8@pAi`VGO)ZiCqP{q)(Q{&vAVlSrA%Z?I;u`9AvP+Z`()>c`f3Ms)fn z@LQJ^>u#{0MYHSw%bA-}$~|Phk+6EL z17r4E@a>!DYbmV4ZpbG%l>Xc=zFO6a7GP-=iSOpG^_e@Iv-evU6Y^YvCw-b`0b1dP zW0`*p`bA7cW81#HiB8=7rRJ?$Wx!8cJ=DdJp_vjO3jW=N{|FRp$ogT`Ju<0LYgmQB zCWjfA^x-|xE#)&MNmiu`Pm4a4zLzg+_*C{)zWh14S{HtF@?vx>5Um@Y?X6+Nruja! z^ZmQZ{f}FySgQA5JJAjlq&Tn0Rrf6cOc8w=dN4-Oq<@+Nx;msy81U2k+xr9@08Qt$3NpY)6)NTeh~TM)MXfFQ!52+a ze->0XjeDVf8abck8rBuZ_IQ6NEN(U}{!m)_)wHy^ph(eNVdVozv``?Z#2HYCo~teV z>+GvO{c`m;^*H0D<|~a6Un{$-_t>?k%A0=8n^@(o$bX-OJX?tWzUIC7tR5zS|JEH8 zE9%)+L-v_RNtgTT)8FV}a`e-xk>eu^we;U62dGz(( z<&URd-!*S1znZ(LDLSW^q28r>07L>=C1JhDPzpL4p>!CS6hMY!iA^Uw)QDle!ySWIB88YnBTC_cJs%l#muT+>HDUUm zMjuWSwwGR_*GIhdm@~5ju^IeUhKMksV32?5D>XJ-^LVMNDbw%i($`=LX5}n@ukYQ7 zg641U@;I-MKE6>a_VR9Tx*16SeY3Q@&H~%rxiJ0F-@Zh`WAfV#Bkkn3D2l3ruT!Ys zn|gvC{BQ2VFItB41Md1*az{yug>41zDhf3Lp*3EHAWv{q6*%q?9DoMkWcZD zkYk_(Pw0I|N`W|vm^?0WQZ{Wp1zbFYf|TOgfWg}siwPx_-^EE0#Z`w?6Csr3xg78| zCf8`H+k=ponuxwf2;2y@>N^qY6;KA~>e<7+T1jprd)R~J=r%d5wPn_f=@0r-DrnFGLjZI&5{+}l7QDD^@xJNiSuDnx+jmBk=HG|-^urk6<0 zw&}00m}~I--uFhQp>hH2OC)(q!3;d=1H==L9LN~S;vSF?k7gJ-xcL*Qz@+~}mbLm_ z|A#E=H+?$YovwKj`R6Hd)SE$(teS+Rq53U_ISCP|2$JKT6SW=}np8cVuy|Z?I%(~< ze>!ztVOSplh;=`kaW2xaCaQ#k^@;Csa62M)7_4Y--&T)kgJ&E*K9kFVn)kFZeL4)6 zjLo5XjB3B}$r&yymWX=R)6RKu7>T6HgGXCs`u-jrB=ljh&|r^d^@Qy zmMp;0ybE=Wc9SH887L*N|AThDsYvx7#gL#{(t2YHrHjydtQt8P(*uRF4ZVXNfj={s z*=K-l=%HsBKeP5rv@H70dVRBrEi^zrB8;7j3R}mez^gqPZF%w^j;nb`B+<Bnkbz5!|nQpsrR)s&s%w)fw?qW)=j*ly6S`f~!);{}?K3MK_3%|-ue z0zW(m$`U0Uw*6`jxj3jusT!MKol;Es`yRJxJl4zNXZ7naNrGYkBk5GvD#9iaCaAk| zo4>VrBF8;Q+IoeiJ-_9a^!rGy%XZOc{)!d+)zRhyt4yEIJ0;_*V{av`vd*1%$qiwv z(+dl~bN)H+M$&vBMCz{bd?Pfk-Q|Ie$w#+LpL9qEf3SgjYS29z>U&Ts64Au7&Jx_% zr5RyDwCY}z%=GLJVtvjV5Y+70H+3}te;{dHJzqdU%gPTqzM1uPjDsauj=yi=wn|VU z@g7lEP7oBEMHynNGS$O%KjpBo&v^^Gy8N2kc|-X62bzl(ttS=K<=M8>zKclWg-j}E z58aYjx~5g3smv+saZ3 zH8#@J)WlWf85_S)ul$Fq{T&U^>o~hpWlL3+x7zXa+jVsmq(LX>rM11Cm1Kt#*3XGx zO{etrvrd1Pn=5@2omG1m;{>Gg(=gAEN{dX19U1r5r&9ZR!M_*9a_%_Myn1@9`o*cy zT23Pg=kcNH=WsbE(sCmAC??bT0sobGY{31jrY;P=r4*=1L{PYd-Oe3F7$ybF!3;?e>t$hA>*XHY| z>6}MCrng>h>)rb^nW+rOOz2p8l<>#eCh^ zJ&U#hG|dHD=~r7ZHx}t&6<^skQa@) zRLAfBsPg3l4{3uCs1fKHhV~g&le_{$d0k9iA?d-yO`@F*m`Md15wLTSs5eD)9o&JT zE53?e*9DUgzxxHk=mHx3eD$@U^~=@>6uS%~9enHwxC&B_q3xppyzjJ-YxHl{BH#C< zya{5R5e%Uw(E}U2I=G`r81)oufIsjHrCkYwt^{*&gCITXv5v+<(^#5MeB8rYPzMZc zJyI1d2YYlDLWYDkv_N2T>;MF)a+#)u8PL%h%s6dzdB#A|#j?Qs!ciT1fOti&5)9@B zU6*O!V4(vOfEAwDT1`0oDhc?g3Is>s7mx7E-8fPbpivbh0{@@Fi60H16n_}H1^Nqz zwc!U*Y<_8t1W2LeJKH;s%B6QGobql+Tv4t)m*$ z7fJp^h;HpXPF;(AGCh9%)Nj`r=AscVo1Q>wABP2dCN!DZ%HreBgkb&N1bPjNe_&vv z#mkzlxV3aYFeIL=%9Up`VRFj%_c?*(W8yR_{+WbxgOkd?qvopk9O zXW~qmRF$xLg-RF^f`x$qa00H25pc}UPWw7h%|F6L2!Nc9xGhJ5JlEmYVleD|*x?Jq zZZoqwb}`L?T(!LLAg|OvM8K{f=;6$DFFMrWqtD}M8fyZe`Ei<&fh@WOQ~|>RW;7~V zK=2lQcooFSC9Urc4DFeQ{jF`x4+`53uc(W{mZVbz1Isa#jw5)A(TG=QDDJ@{D@ka0 z)MI!2>ullZOS$wMNOHa~2z{9L6)WgZS`EtnhzpKL$Fh1D64I-BVL^AH`2^t0*RO#q z8P<}3N=Q0L090;};RgY>|IWyVWszk=%8{z2_}B2M@NhINzY5Z{ZH+d74n8omY{`0a z{W&`%} zqw5cGb<@%A9gwdN=#7QbDHF4`NYd`AyarVVDb8fK_rSoOdSDfckEbS*zZr}jpDHrP z$`bL{%#2(PCR_^|Ob4!EDDy|iG2n=ONsuFN=>BpF`CR}hid@A&){20;Q;K^r6Dza* zO1j4jAqTz9QXp-Cje{G}7|Dmax2QYTSfC&Y_h(>W98e>$paS<8m7S}3qks>XXyjQy zm0|NLAYU3x*gqs{Sr!t!;!0i@Vo-5vB3VJ&d8`5InFf^4cA%Gzj`pBerA}4YYBZPR5!5TkwaJopg>DrbFmozOdFdAln0n;LnACH^9;z(JSml+ z0D}Fiw-<>qT6FfMpfo6My|b(W5$qHFI`312rZeS%Ijt2g^d9T=5YvH_yVU4L7M+W+ zqEm!(aRl8=#8WvE6@#SC9zlj^MyAqiqldZ7!Ifk=c9QlFY)fGMeBoxTl`>u~Y>Doq zD~9=!x>!FvEVw>K_Esh7s557V=L9nestW?o#v@C;Rgh4O1rZH-=EMu_d2`}4q@7!>*Xl{}2R^o=7Q==+Rs zJ^=?Ui^{JGk1?<5yc=3EUU-`lxLkybXar%)F|^hrA5e!i!soj9-`bA?F}s_T?-@X6 zuve~$Kr2bmRVgfS8)`2^V(5u$ttBoFHBM@D=OMD?fm=&dNZBZQ|GGl}m zf=^qgh|pZq&4+!E2S*q$e;`c{dGsIjOw7ViTJOH9;3HA|QyDOoQ`Z8um6icxvq5#& zxH%gSX~YryaRI37#bHV^W>5~e$VW5_T>OsmR=TkWK<;z&V^N)RZQVP&MgsP^XQhm9 zMVWqYz`Qref!wZABsL!`UjMlZ!%fq}SAD5mVLKP;(5qH&QF%aSiw_t9;RNMD2;sS? z{Ig|p?ZYU$X>!+}#7~A(7t0T9@{uv0Y|CxAGv3 z5D`fJX$uG|0zH^=xU~Z{`crV>T!<5`{BN8)PIA-V!4T=mso1GrmUQ12K zW!(bLj_LF8dqOEw)AfDVD14$^dT351J%3%je&^`Z-U>B&wBui~Ow#}>J3FL3JETm3j&762 zY{=7LO<7qBbeU9$O(JxaT-;ucHNBHouy)|X2X|q53$WL*Ogw49KsP%3iHSiSi+?UK zGd+4cXcL*ejg^3lNJgjCo&1u!fO8_?c4^2&PuZbW9GPTZeqR|rI_8+IdPFOmzc2Wip#d&Bl3K0HN-)PghD8ysHow? zVAKO}P^(;I8@{BV1;l)iQIBa`>RaKK`O?%$f_j7a+5#ijVJx5m(2Y1)0bUFhM5rA0 ztxgqIO6rpsyzqw~Ay6m4!3TMi%uyN;qw7MpvVjM#SI5@Ng`cKm1wE^w^GOdg9I1mN%W*hgp z=%+CpCVw-Mzy=_#09;)`vsdl_`uIa_FgJ`naklB<)rzp(5!!ql^Yv5J8|blST=t1C zAGAjpm7kxz8xDLl0@YNmWv?!>T8?NEBM~}p6j|61dP3M-K zuQLoY6tLQ#OF{ssZZBYu8k!})SeVFlL@d;cy^QLyO;luC|75lJ?~f60R7$=p%rSXs z_AcPI2jw21b^0Ry2uxRls)rPZ`z|b7Tgob?5y~+XH?quS`U!t6Qjr=%a%QMhPdTst4_O7v3d^u`t~paaxN*dy)UjFWxOF?Z)ase)tc zSJ_Bhm|M>}v4*n#i5P_(2!&fkv3g ze)vcG-5S=jNsB5xf>{liI%Zdd4Vc=(2*R&_6HV}}Giz?Nuk{~KyAi*NeqP}x62Zwe zxBA5+5Sh@{ zN9NHqc)4NID-fI6p2Y{p!M7TjW))MwC&Z&_25zE3!}Xrxo>|*eV4}EPR~GA9PmNlR z>*SnS_Et;&2|24yh9ta78(IGo3aogX=yi2Ft?smO!UBly8RsQK0C2!e^4cT-0&>@? z%kg{pW|t2x)@Wzq;S_tR+mPzbB;jH{fGYw(AbijwfB2gO%s2C)yHlYD0g#>j32u4i zYX*0??)<~2SoPp`N=Pi!_C=2a$?i9i;6`#e_Gjsx^MPFKI0h>@R-9mB=`$E^k!Q;Q z0lA}sYZBjcc_awl9o4^gAUQBpiy&P-$R?3Uxus@>g0!?GNhs=b{{@bXbOpN!I~I|t zB1qPv!fq*{MHJHQabote0LEJH<&Ml|n_+s?RW8EOdVUYKPf)x;N_ z+)@$0Fl{vnN8V%6Kk#P-=!|q-{VOPB0vOU7B9dF#BxG6gn0pwIf@~9)2WBe@yP}`+ z9iV`_U#Xt1I26U~J@0%2@lcPG8F4D8D_3~ZDPwNqw_ay%hXu)6*hli$TRck8ke!Og zBB3JJS~vWxBH`0+*^Kjh{_S^l&tm##-^(!=S;eu$gh&8*B;!-WQwVL1X7vb(_7l~l z4=E5;)sWP{OJPST0x2StK_Z{P12vkjc8nf`E|Mzw%<#L@+szAyFUJ`%;#2L$-Ag!L z?73(oAY!jdsD%ZVUBwVxh;XWyUINxL+HpA9V1ch*C_H8}9~H7p?30_3304NptcqK5 zAGVgtIx>yY{0vY0o|oL!d-W)@peHK1NZQ^pMm&%+T)43NVF8no=Hf#J!t1laR5J#P zdcF{A!>3TApa63O-;PUBs&IYPw3XoncldZx5-x$Uyt-DaNRjadk}?vGL}GcO0UPO< zlQF(h5xdyiI?$8(#QUVw#y7UTGLdabxzIZN>edGM%$zsmx$4E*K~-(1Lx{}{w(_6l z(N<+C)f&nH>%_;+zuQ-Rl>@hnikbs=-{0qyZEob>p4c%eGZ9zClLp*Y0gnKK#&2^L zOCB)C)kKQMKIc0;C%Xzg{2i{^_naRl8?SvDg%8{FYBkG}xT|c*O^)@VO*<48v;VU^ ztD2N}s~DyX3Dp+W2Df27D-Dy1wILsYg}#>$Lxmq^`^!zMznIVU!XA#%!fQvMqW`Km zjQ#sqcDz(>UU#Ajpb5`}PTRW@AChvyN9I-|xw`_qNSPs|q@F!p!gM`~d0SdNH&RoA zO~PZvMNwAzYi})UdGCsfv}>VMx>zz|?=vS-)ATPhoS_|!;*67@nr*Rb%&5Y(ZUFrI zvVj005|bcn``1!FSQYJmt6c6Zy{WdWY7%d5iF~+=5C$G5UZ8q8;pur7DIay6{(rvv zErahWXVAJlwfTwfZSmpPws{`OWmf_85>R?en=hM0eVP`ofFm=L$-7Le%lBQ4Q2M0! zc34RGnv)UEc8$q2Y0jSd$$EPR6QE%vGzs_wYx_i~hU!8l&P_!c@X;{h(mj5es-!_6H7P2B z`Wa7vFuFxJbNVlJ_Eiu;GZ+REb`Dn9g|kj(X&WX>cW~u{i!B~D=)M2F`)<*Q)Qln< ziz7JIiVM1zp&smvesH#jVq)25H(d0a&nDhuL2ek3f<+tCWYtobuEg9J6L*wI;J#w{ zRi7_EDb~%$3A6(!bwp+a9q=N3!e^@U@9P3B}+${p(RuAaH zRZf>fJiV=ui?1}zz%UwD1Nv;PY9>DHn+*3T)iEIeKT$@f37j7GB#};e3kA`p!01n{7lq z{bKjcmL|VOgNkF`kTI>b>>W~>Kup4yq=jW@cnDmxApG9LrN62W-gkU3P;Zjpl8)~V z`e6Tur%<}qU)MYB;%Rg*n&xt7?240z_t7tI4=nYJxz;KXy_(4Av0vA_V*)ni9-9u6 zc&fZJ7tl}B;fIN5cR$6wc_hWbs~*0t;U1kw6nCTh(#fY{NBi40d0p|b{M5{lU)S@u zr;6oV_X38oX@&1z)=jk|*KUR=&sTivw-hFe!8W7UJ%0b%B7Jq%NE?ldxbdr!t;APB ztF$60_=!!4Mc~J|qeHCMtrBEyx=4MISp{jb6?`LOwvi|S5qR2UeFBF!;{~p5lo2gp z&T`_<_o=}x4UTgn{!I<^V>3eky zLzDLhbRbW#-j|NmEKJQq*^pX|+wS_w3Wu!hv7+i^bFUibx5D zTw7kXu3yUjZK8Mkfr9=kj9>k`?kpZ9LpAG^2_m(dB-hHy1uS`GiW|vv$hHZsU%@PKbMAqyB?IrEXhuS%A z9^AQyaendge^GOC(e^y0*DvE)p}+JkB-f}}x`j-=1!!{p!pDy8j+7Q!S6SuURvzj94>gxFmZJEo8Be2!EJ`%rA zb|*@K*PH$hA4q~Tg6dUt63M?`m^vtqy~QtVOrtTqhdS$=>hBrJEp(WLN|j$XifP1a5t zbf>{y6a!sUujCsmK-vA2#_hZoD>ruS@oN#?X?xOXhXr=Q^5ur3GL~DD^a!N-=iEkM z=(nhZA#?%)4M_i;4hv}{V-5FZAzN~$poOv$kj%=@!d0nagiX}1C|fYsBHE$b*Sp&< zOzZOt+Q&~`I~mJKct(a{;OW)j|G*))gfb9lI&mDG3<-|WB?m{sC^1T_)J6llD@UM9 z1v_JT9+x#<`65%_z;4Txx5U?i1)2_mhDhSNl|X?o!RV&-LOHW;T^9gdRf&{6C7w|kFKy|*O4ue7$W ztfQ~|OKc-O^kA8L83r|c7Os~=bEk}9RSHY}H$A0hK9PzN_Fx_onNN`bxG9kl1~^X% zY~-b)BRn|aSIubQzC={;4YWz?eyn=eZZgox&|A`mzN7<-Z8Kw6S%arTLN386&f(tj z>?FdW@6<@M9Y7N-uPJ`Vjn?7;axBk)l+SfLL>vaDQ}kx$<7R3H=VH}?n7EjrD6A!b zrU+GVb?BMBW7NJK%FIPVG;EIw-DQgfiqf%$Te5)4(<^vKVV5KX^k%oxWEaxQZ+mH}I8J}n7@^1!ic{%PIS|bdDC!=~V zSwoM%!6q-Gi}KsN)rZM_MkvBZu;A2@Zn6-tb8k%)!>wapL#gY}wOZztBt{o9($5L7 z7yS>^0fvO3N(&LCW6<6nePe#7b&UqsagKG?2>Zk+$NDJesUecLuR1EEDtNfZTH8VU zQ?esHPJ zMT(XVA4jE*qYK8h>!=DY6OB+EHPyOBtpb7eAiF9GL=N*!N){^|ogcwS#1Z7dlheH2 zpNi2|IieGZ)dts&TVp3CY}O}i_f01Wq_Q|V!VEBF+Ca*{bZRUA0!fM+k3WOI;R<9H z3($|u8cHgzP?Dk+KI->f?V0~P+Ho@B#Xr?a_U4(!l()|vgpkG%xD%R)18x}!mFOl{ zslUBdHe1~x1iS6OuYAW%&r_*exLVamW;#@3I?QysG(|=ij!nw3FnccgX%5zYMg639 zTCS5>w65iHNXuqA0rWtLx0b*^lhoOORHpl z&AlcsO3zj88*h+1^yV4-v$-;}c^b2rwXfWYDrtJNMd7o>sk0?|ZBMDE?&i;yPs~=V z&sORo4_04S?K8gCVV`^PbGAlfuGVy}E?m+4r{uFo#KcvE61MzH>%jW_(CPdz$Yw8U zW`uuXOlIMmX_~QG@1VxQq|d@s_`-DR!c4)!>|ZPERHZ&k%E`Kg#nXkWCD7tBqLp#!usMNC!33M>AIYrC;Go=c!A-7Zi31mi}}uT{<@WoLKsIx^xA?>?S4v1eQrs+SZzzo=YrIBEo2d&W2*iZ`)NB=T1GmhX@S&*FreDOSuJXEn;iiqx(i3=`(LXi6%FS12o7h6TQj)FU$v#Datx(M^R@trX z;PtRdg;%~?c=DO6oJMAhy1YWDbJXNk9Qks@+i|6hEdrBE(g;u>sHm7@Inm5zo(gI% z7n`_lscO6Zy3nO#SJw4QxUSdGR-Th9&^o+5 zW3!N4EI4haG(yX)YNveCbw!-H4($4&Wv63nr{dMe_X4{$PEL#N0(BB!ikq7)k-Lp) zPilFWq|$dkD_f~-%qMm2HlIDI133(T+x;rA*H(Ei#br#~#gt zb*F;|GkZI?o^X}LWOl`D9;B{%eA<{i*xxufIP(ZM^&o3cJN$QWEc@ewP|YXaAENWy zNkZX2Bc7VYd-zlCo!9^PQ*X81`Qz`#k3VYeU5Bxx`#(qo57+ty`Z%SE#1BD_4_*59 zICf(4K7&DR9l{5}$J@ z_4|8!5tOBzADQ0Qsf0*`@ZhqHKOLL z=;%B2f;YABm*b|-k@Od3p}6?=3zfAC?Sl)Q>kD1HOFf!PeQrO&nF|BeOGD$!XO4aW zV&_I^xE2W&5@V0TB@(;7wC467@_g5r>3iGKZR6S#`(AO>6^4%2&5s5BiUj}}U-Yc;yG~akJ&Qbll4PJcT%$V42 z@o0Fa^YyIBo0Wr>xq!bBIrl@f3L7+Wn;63>RU0IxAU*D~rn|pFreVUgyysO5H{qUT z$E$wd8l1Q7l0BtQtK(makHi19!2i^azb)|jqW_LRCqklPK7C^TaqUyd%Ip2qc_$Z< zrtSAcO-fBhuKaBevxek7eX=fywn2uz>vS~*-nhq~;#L1Im7=@$y7QIkSA}6N&ELOlZgC_*&yjJa=L%m0x z;dt;tyjxl0c@(_eZE_0|O5Ka#)A;WCXh##n5%BkyVN3}6>sWM5TcPbwGbVTGU zwLSY?+cyc6P;p{yw?_&PPQL&7Q%;A+=OiqIvwcHC3KoGh8$4`VVl>oz;^FJ>QUv=$ zj5xJnzx*+^!JDGgV}ku#>R0V*$_+in@_)XHNDL^t(19S2n}F}QvAK`~iB$-KxuHI&z$i}FnA~91oosOe-8XFU zB2!ZA3F4c9d!|>lH>8R3Z_Dy@C|7}YFA(P=)F~b-xZ%`SHy@@R){MzOFgvG?$tTAV zoTEG?7nwj9#+0iqM*gaGXwx+3#>T;2%y}@@T0Pk-iLV?f-Wv0yuGJl@YCG7Y8QPS)*{$fO{VkMrTUtkO)>;$TWgLX!eyLSzvb$}95QfA{n3Ml z6QLq&p*4cE)WCPzC2R>!wk{7?B%sq*WA3BpSHn#Pg>>{ciozk$(hLU56hV|9`fx9M zU=GzkEi$!u7F_g|f_8wdF9E3EjiJeHZq@<`%D6fVBq31WC_hGmQ-FC)q$mIsOI?|d8ddQ>W*J?zO`z=U?)7`b$7@X6!`e|b2MU_%g|EL9k zS?eS3u}A1*y#ON4i;EdzdnSqnHRC>g%40yZh!TbRpctO2>wa=m$CSL}C%>&XlAnWq z^htj4=U#UBB9+&(ydC07^+VU=x7e6*EU9?sHVIun1t6WTE8{Dv^shYX%NbdS4OMa# z=b@%|qqi+bsE$}iF?>U-Tb^ec+9GD+Zb z|C06LPECB9gN1M7Pt6W2<^R^FhQ0SJa=w~8`9~I3{p@A0?g87PhWvt3c!5Axd-&V! z%Pr)Sf0)}J@1J1)rf$yZ0X=yKg>gm9oYXE6NCOH>VorEla7+t|Mn0jjhV$Hq;3m+R zQk&zDbU}vKgJ-{sl{cZs9)9ND1 zR$;Jt_WI72pA~fB!Oc`NN=wakt3Ki20X)AOR3ViLF$F%?(?_BWG5zr zn6JvJQE>z@oYqxpC4+O)>LIK|G)Nx?*;V}{N@o&Uh5ZL35*{Tbwgv!7t7~83p;F)` zd}Sun0*xt@4-3GoI7$;dBlV97tnIDap;%qHA2L5Be`EWU0a1A2jZP;tQ_NI}#A71H z3Lz>H>q~7APZ>!ZvcTuRbOrV2}=@5FhrDD2pdG@=zZ3Q)^dqHn2t~NQeH+`_}AX}#9(s6I?qNf{b~)B zmau*^;MYpBokcB@XsXH&hzNxgEcUjTz3^;+ZC%t zu#%r+g|8behY`gQ*IUe`=l@6<^4fQ2t(OIj@EmrxdmaJCwxt)S zuiT0`MzcR25gJ>)1eo8{$lr&WpNB*@2Sn7kTn)tnzwEs)zQ>FH=|w2WNJ@Zdi(a-- zvWSf#N~Qda!vYyY>X{I@Iv0;^40V1CCe(2PcN6rVQh$u()CoVJfqINm0}26cZ2kZs z7Y-77z+!ZRe&{gx;a60upVC;&w7{MOG$VUYq}Wl1RQ5Iqs{SV4(lzVx_EGNMjATu9 zC9`T#i&EGf9N#7Z7F|NX?SMUy`+APE<(L!!&Pz4=hz&yqzK!|gGDKhQFwa?8)cSP& z_hEuEhcw9!uP*R>Sie<>V7`hGfb!hpKrY7axx916KBd)Fv8J@`-DT#+2D(x4yft@< z61duL@BC};lGDgm;xWklgqaW^SpJLIi$CHUVj#BNA%a_8ec*1hm`UOl<}fDfhIS8n zTLj`Ae)k!<2Yy6_yFRS~aGkUo)m|=UI`{v?IxgNEMIRL-R@1>JwHWKa%=Y9;$D{9k zm&)D*t~HVulJ>Tb`74LrcEFB7XyI#JqkHVh-~HD?DOSp-V|cc;_*c)FgI{L;)dCMh zrc#cHdzq#dKj>2t)Dnv)6yDXW?A%cTy}rzZ$=K`2D;}jDtPvJnu*l(t4yVng$d{pI zrSHwFBeaj0eJ+nB@yVQ`gbMpWMz4LI{9_E7#R)HNXmA9xpcNIfmpIaB$~A;@Iv)SQ zZ%OUhN~9YPn7EwOEx=nW{TW0At7fjM8`nKaIlOv4V4lQb+rm`u86|*`(*Neat-mYB z5WTE|%y2+V10K;6XUiBBqfWQ-O&WkD~wOq;rZuOx#3aDdwPBhbzCfYR)6foOyP z*K#AQVsJzzeJ*T04BJi*aYM1clT*)XhOGj8`AIQ1KJ1CpWI-$T-Ab~+df1IMfj=?g z`%ZjF)6hF2Ac6mv`Np{{#sm&Zi?l|adtY7HHcFn68r#h{rV^U+5Pr?doovgDAzzO{ zP6Vj}QBVsQ>#=m!DL`SjDa?Xsikxk%NPMKA@r7!9XeI&fpITK5;zjFZ%Qt|~A22gz zjwv(Z5dvo7VHEw|$H+rU>+kV#e1*8Fa#Q3oSzc+O#2HYVWUP0gy46kC9`cA(r;Ww< zk+e?p)S{TUm@kUC>%r?W+4yLsprvahRNJSM8MdwIb^lHlZVG@7PNrkE*drA6Vy0+E zf;Pgede!ftHl?%(V7yLhTlg{rO(^hOr8h>N7I#RE4J;GR-=3s0Fi$y zuSD`KyHsR}RANcC^xR)(${6K41@{oqsXOR@Yl*?d9tuS45qoyFRjN>hD`>0sO_j^>WvbKCHwsh!MB`HkQ@3fD~97r1m)^N(SJQh~? zlqX6qVQb&{gFKa5O5MIIZph&=?Zcv_U2`^knMewzLm96m059H~hE9yW3uL}2^a)H%CscALnX4M>ik>G(WLwsXi&g1f{` z;U5_p^B}{YQfB>MBuEPb3T1e-=_Yo&0x>{M@lAC^#LVaLL(ov=a!lv6yuvK+vXZ+y zy|fNbgi&S03!X>^CxcU4ZKoBvJ5Y!Kh(twM-pJ0>s!%`PX}@zQUu_4F$$}_!LU>lr zkA}xHg9xq7v^;HfYQHGm99}7eU(AMIt?28IXhch3Krhl1k=TgAn7kP0&MJEpS+n!8P!@ZpDGJ5~J{ zn8(J4#|P;UlME2oN9F^!Wnc>AzI1?K9RU9Z5(H=K#If+Q4r0>g7hEj%#`Y2)!4br| z;TV(&O+G+wM?9fZ-TJBF<6C3_Sq~lotq*K3(~@MxD*}l^^5Izu2y73mDwR?;k2ZBC z3HEvSLH11o^3$R;HEn_Y%?eR76qSv@gffI@;KlO0MBMd84`7vmWq$}8cOguTh>;{E zGOQ7LlbI<1=Lik5Ve(^(J+DUJe0;@3g~Z43Ng@5JDGl}8$DXB+JyV7%04+vDWD4^Q zQT4Q{rS9QB6!rmnF;C=B{J~i&p5H-qBj2ML`>$|dGu4qE{Nw%)S3UlT$lBAvJ)P~( z8N?$}7b6LZqxyNaaIQp!kpit=H!XZ%Gyf|fv3P5!I!Id8P7U~ zKyIJm&wubI(xwdeSf&BfN7@O(=At*%$Fko;;arWS@GinU>XFVarN!3?c(+lDm-HLm z#3s0*Pee2tvUYZ?^UycdN5%nLv5-@;KUFDeK^!iWBzI=Rb|wIm^>PRrLOgMw*yf zk`rOl>qb(H=NG$De0P{5gj)+afZ<61gDS<9+vt@EWoVyLy6ab#sZ{=eQ?6*99$H67 zP*W%u${|u@5OYJ}GwtEA^rIY%5o$ULFvA&uCl*boA`A#f-&AB^EPMfMnEAJJKA1#V zTl+!oJC~b46ZXxfkTO$Uq&({3>;T+`0y+ZA zh~^2t4!F^%29yIj`c%&aMM)jz|Jahh~a#B{G)buq0;_)TTjNKkUT2peI9$0@>{R{Y;A z5IP(lZZ<~F;-_0m*YxvHGYy6Px@qvAC1V){+9z>V zZ*br=N0~Wj0y;lt0bCZ&3=|i>a)I{br6smW8s4tk9o%z1E~Af*uVC5vYCSf7$~cm) z&HX!AT)lSR+W=ABr0_@-;TyAT%dt(jQhNMB@FLLx`RMcW#^ef@T`|Y;PgjBmx<(|^ z(L(G$<8HjK$m&msx=uWUC4Dn1MrNC~UVK}lzVT6+GSYC|D>4Hbe)3zLmM3^&>wRS* zSnkfn)3{lFry{ItU=gC=L~gNx`|@A_5Hj7>9#h~8jJEfkiaPjrMMO3p@~<@1gHtW= zMx!X35^4A;Jn&bXDn?w7MLLDMD@8~;HEC~~wyOeu8>i)`E9Rw1`F-*iFr7wON3u>+ z%v*~Rmia_yu`P^`F*-wR%^WdKEl0rYXmW5`{DNrh-0M5*G&Sr-NrpF>&lgr#sY=O) z`-8)EiDW8AkXJz7G)>x6fZx|Uy-;gw2%9HF4av3f?c#%b#3Sb!D{SE-KQ3%ezQt1d zaE{CvYb##qIJ!H|!I(F6g?Bi7_$rZug;{nGLDKd6!!6oKqDFc>dS2DKZsPi>Is$PI z#MaM(k2(&KAK~W9)U03{%>%g8Y z7G$)7lVbir-h30L2GOXSu(J57Y@fD~$0fu%7f!<-<{KlW6PbP+xK`T3k>92iegK-K zT~-v}5h&@HFO~(~Yq^PK#i-f7B5-j_eJM8{&^wOo)nykj97m)@%T=t5xFZN=9oP|R zJWqvx-G4=tfJUDw0jXspCUHiCB!L;K$dP%2->Qw@by3IK!fjg1_Zm{+3S1y|g3IZH zk;MUGMP))ollU}jc*vm|fNKN+_~5l20Z`!%O=eSGfQ^z2*nX*4lZ*VJuRVjcwK?>1 zTy9p0Kn3FYJwJLAo}1~`cu#|sGsQr(x`iFSN;aJC5Nzei*`+8;Icu*o!0b1J{B;El zKc5k9cKMo(8Tp4__nk0M zh@m@v<=`{(G6ay))0G%cbK=e2@bUd@WPl~)V(S9e+`{1F9#ubeetfuEO=^}g%Xi73 z+#`qG2=9JL6757owM;EFDxZk!eed3x(_0t zU3FAk7!CjkTcy@L{76)Da*SIJ;V@g@XjXLq066T9)AdE610hgmyl^1)L((#sh|~8T zpTkZQEuh}syyl{*Mo?OAs5RtLo`y1~_NbNu5a`1cEF-l<3rpn%38YAk(!33vQUu7j zBN5c@H!X7OZJHM$1ap^#AoTyliWd_Cs?b-Svyv3y-}(GTC#%Cp{6> z47TL}O6X_?3|p_bB92=0v?nX2@*&=4w#n9a)$TNkj?meyNnMn%Hts zdSZrIES8Rm{~@B%;%_!TKB1Tr6Q5|30vtOYV=EgwSWOKpBaHx`N&ib@8_mBrqKrxV z6y6%uf)#>XFU~j1IOhd3n|DQ|3AlOL!+>Vea~TYDQ##WZ#~cxAW!*XIT}Qa6Q(5!3 zA`9t8G6+SkOL$3Te;BnEN$cN<`wyK0QH<~yqkqmxGa|cMPyo=$kVN;%A3poAMmAdI zV)iAMBN|2|nRu+KWNVp5)X_&ujWop+Pu5o+`mJE%r({;j6ATKq+UTNPApOJU=j*t# zQru>~w&Bf8z8Bq)K!WNv6@|F2=Jmtse$Op|X8w%F7#HXG`d@!t-eDd2eYf--bXn(z zdeB3~LbvY3w;RZrL-O8ku+4L-&up4-ptby+-;r3dUGSlMf+b6%1~rQJO2Pl88~_kh zIFV61K6VI)6UVSata?>TLMhVRxg?z!8{Y~Fg`kEuPR)|-2K;&0C_*`=csR?NP|%q8 z@-K`bL{DTjZaGsjjm&s5u$k4!B^o`3xl1n->AC7~KLAgRG&YKd{X@aoT2ZSi(a^n3 z9{ZoF3{$53Fs_%u6a3b1l-Ho z554u4j^8+)N8*i)HcfEI;qWV;DrmLQ8=mPp*{y#Zb1Y72mv!TFNdw>b|#H@>{kctrV_LB34ttaek?hL zIaZLt#N*)6#84uP-SdAFXtnw#K4LzSFtZe(Ko_>AKbGQ~hjVDA#tTc$Ej8E4qOD%B zPZChKH)U4OFcwS{QmY*GzWOj_P_L|5O16gNKB3jJCP~d475gTVoTX(in3^@)%Vs>4 zUfw%HE-8qWWmc)KnKuf`eotSYaw^fxC9E4S?of!fRR(YyeK zJ(jOGF&&KwuC1Syf4x9rl!yQudO`*}k8l?JWktiBb{W%*$<5kLX=%N$DhWc{E2mSImO{ZrCvHDHtt<{3+ zoCf2(%c1 zT-^5ERNy`olhXYaJpht84JFI9dMJe-BrdPta_XcxlgS*i%H<4N`eAm_e)WQ4l^>f8 z1%z=ZV1M9fFbC-EJ`rdeG9;o!LR4mK+&C3RxIE4hymD=sPBO>kug;RHei$|2`@uZ5 z`BBLql}9!?wf|`ey^?(0INmbGARw0e#nWONWo%9^p4_?>ldUTICt zCohHgowLK}C_SGw=hyx_?=06*{-kSe?8CfQm_6yh+x?N(5*I=h!OcoS$n3=g|BDh< zCTDfs><>1EK1qpwZ01BwAM&m)+M4g3bz`%a%OozV^z&Sv;c+eRVuyrB+E?6YIIpy2 zl-i_LAL3wsFRk`nT{eU=xmlCutj`l0m7)5++X&@s%=%xo6fwCw>*j2J;59G$Nv`YJ zPKXuw#@x6uk4ABg3~=L)Mi#8iuq*gtE)KZE z1nN1?YSfpH%0Pvz)s=M0CRod5PwE!JaRN2GYtu`hqj0y6vR?==R4B@6(00=_34 zstr;l4%0-(od??~33GCY6|oF77)X_}Co1YCM(9VD>l8Oodm^5EC`F_VQ)dXKCb~a0 zv3J7oI00Eu{D9IOHCAm7=v%1rHZJtxDkSDH_+4jyr#ul5DvTfTKjj29cN9UaGsM&$ zB1LQGg|sVUvKoGM%=!XXkw~YQ8!#XpIPI2O@rontF^ZJHr08e+bg_dcUWrXq;1HEd zvDW6!e&jb33u~|%8`0S`Cj<&EcjU?4gGrpKxF`~$!W_ZHOlIs{hq|Y+Wv3TH{Uj&& zoL=OxrvEA@WG%F)&em<&aCWeZ0@!*t#R{oJF&9p3cd8U9%&~c>6s+`HrHI7^JK-h@ zxb=;jI&A54#jH#*Q++9?#2_r+<4bQCYpqqTs=4Z=gjKm-(JMKy!+qjcfeK9e=uv*ZY;gxElF6}b*=0pO+vEsbyI3*IPQ z_rrOcBPT?i&EBsEX_@i7i5idKbs}6=)AKP1p(fypuOO)Hsn|nCB(RSI4L|UIJ3ib|y;cOpACBMJ#VQIaGFNx&dzn*pBEMg)j{%+*pr*^{*hJr*S)WZTER=r;I!b~>mUCg7C=hF_ zDM}2!R0(quRRr)GY#71Vadi3nxt{IDJ(3JV9$90f!a1}*2qrYY@aOq9Tq4?^Zb+57S~dGFGzZ1BDyF#~O{$w%!~B(1LcO@6IbawA@Ua4=QZYxM z?dY#PRbMW?0}^<9h@4^-cItrC$OoE<6n?hBcExb+)nnu=)| z@2`QDxZhLhwx{yQ%cot@j<|4wf|S3)Hp+KLUR;VJ@nwX`0789{#@1b)bD+nC~DEEcW z$NoA!#^lpgV#85Ci^(>N*-MM>)101-vi}B|-xhIaOi6Ddsr1biDegn5`j^r*_MdOH z5rGxmUbtEQ++q1ONS!EzWmo(>;hiNWmYs10flK`&-A40tiN$RpNdx%g-yoc&S(J7q zJ^PIAfO7R;C|BtK%LII9WLM#j0mI^@y|sjcy}yI4M2Cbmp+`o!&X_}o6*(Zz5d{T{ zVLb%6b8xYdzdWfW+P9xwpkN+X4$=vZ1=Nn*R&P^Yx)^(VafC|qs8D1Yt$p@zlC%bA z*idk7zi5RMa%{i2f^;~+ut(a%JyFhahDLqjpB%_9ocv#6Zrfk}y?TiyX@0TFcSHUZ z&F@Jj=|vUbMc2WA=a#SMaB?)o_sav#W!yOD@vWB)=`%n1b-?D-!{^bkpC?0yulBXC zzNFvJIxfO^d!}DP&m{fr1N@yk{GZL12%>$K>MI7Yiyu$bjuktZ>imqRar9l>?i!A;kiLfTwGK>wEOSA&wTM+08JuiX}n6`onkCkHxK zM7=3!Cv{yyzy2EV_Itq|wzh?_`ggu>AQ z?>eey>I2tyzZ#++OM!CHokh$#9q)Ar@3A^A0&&Z5JF+eaWwDS7LqH`3-0-M><3;3k zm10;{i}41K68fK&Yw(>#6uFN*p*k%bf;*s}Ti87{y(74EWJ%IjC9s606E|m%GQXE{e=QyG#)oQ z(cJYnyYb`S&Cn1@d=_+l2)WgN#4HUo$^%G)RKj*HdJwT}S}^ylBW(0%(Q3Zt3n>#L zTryrkYJtU+$|W(d$CXS*GsV=>$|C~b5?;O_Dxqb1MjFO3prtOgM@tOs6z@Juh-3pX zZsCqeZ=NLv&Xn}Kkb&gy&XPjH&6sf1Clu+=QQG?KY;ENe$};E4sQ6M2KkiUq*Rf27eP;piqTrQI?_T+W7;k&5iDzo(T zvEy&gIv%bh!Rb!*km$$n(ThctH0V5HI$Gr03pRbR<3|)V>lo*@CX^r4mR)mW%B7(? zBOU)K=zM)fkQ4vpe$|1UqIE9#s@49M+r^l#*o|bhwS|Mx)lqXR8WS%eWdMznIdM$i z6eMathq0>-nN3jd;K%hpo4}rtm}tJrD6x%30mL2|WrRz(^-z_#qI&dmGn!Iz_>nl? zozu@V#3j29wHvg^-bE-T)LDmt5rs9ejY9P^H*{5wlDiZD(b#} z?aZ>uABEpT1k|2*C4ocGhWgKrhuh2PC` z?-sE2=kv{6NB{Ng$=g4Fz5w2D)ZrlKL;}adn|^b34$;hcPLuTRK!hfT>+W`9RPP^% z``@M(HD2F*@1ZQ?of9SG=gB#8*^U^kn@Qu(EOENyUnvKRy-f3uagdRW*CL&cLjsVz zrQ8YOHR6Ur%8LR9owYJ>NY|4TKstN*q1Iyce8ss=89S#F`h|=m9;Bm6s1&(kqtdNC zMiBCiP2t#r1s_{OP_zNh^^=vY?6=}sGErL4FY&5YFM6#EUXs#JY{{>bB~U1xM~G7u zT{+{a8nU|U(eI2R0%=}Tk?~p|e|>bgS6g_ToW%z`#rKMWm&!ppo=_Zdrq}H~$97H@ zraqy<1~i}T1|`Ww{<0&JH@rZ`xUCZ+ev>!#D1^VK5GH-Y-hvksNYQ-vOWg8Io<{9m z!93NMB)f%o&NvFye2@v2$Un9UuQ3kMe5ieP#EHvf-7S>?E`Tq*3Af>m0;I1fI`gyH zfNIhQLA#qqueZ-I#uN(hH(uWMgynRH*JR^p3XHEyy%AaE8W~kP^G96d8}n~(^w*eD zwkd!n5WR1HODmA#LP}Z_h(e?FfCM0#`|Z1-u)CBnr6xB{P_THDC5+W#Rh84c@Krk% ztz0JAISsj}8oj~|gi1|Vx{?gfL_oq(?0THsIz1e+7@RE{-bogkw)56G8HmN>b-7CY zoR*yF5VDSE$>bIGh&2L9&BJHQLq1X==^jmMz0P2?L}4-kk#_PTSa%}X3Hu-t_o9ydws#QzIO47BE^cX6$ZUC2~--4p2Ldv!2#ypAk;1+_0A*W ztlvGr&SvUZC5oKuF0C_Xt3f(09sUGHKDSPSDkR2c0I(;?b)9T3C1Z$L zlW;}U{3u;(kd6+_>wdr;C(VoSpZd|ZsYxYXsIP|M?mLNWwuvXvGJD@p*jOpCWRmpX zQanq-ZqJg6i<0+jlDRRHBBmhk3`qvDD*o8aRPr`v0+h;`Q<`#+gYqpIAuM%KmCWvn z;yi;46G&qpkS-LIZhxJcGL-&E0>IcoUHYIpPC{vr`WcG1LI5k2TUUmMGisg{WG;;M zodkJzf$X3l!pw}`oD^h;%%*d?FI#l8TmP*oK%9}~=m(p48RVfV$ z;PwTQ_62ZOS5{tEW-?ti-VkeOP#Ur-jd9RM8%9r`0gQ;rY;-1#Bgje@&0yFuL#xp_ zsUvB7VA&|8tO4h&8dj{#s)vwI)&mh#I*{4Fky{%><>;GDPo3oIoMx2;L3PEp20_pg zYWgE7cBsrejB{o-EXM!FE>9c zcPuFXhjd~2pozN1sT&Rs!qSblk%^4FQrq7X12KvsgvGKYc;Jd1%lDS>9md@+?o zm_f;%piE)2{KBk!eY2{zA|2LF$S;DT$46w$w~(|WNPQ5YJp^D_385hawJnfv(2@U( z$<@9q&5)@XTlCU+ct7MQU{^MPj6Nx5RI*zJ(A8ap=jD$B!BT@rl0j;nJ5sQdYtro! z%4o2jcZ3`gWTdCn9L4}*66*vIfAE5xX zLeS*`SRGo2En<2|FKM3&tdOaF`rZ10s!KX5j)s-AcGGPk2|TSY3E6JRFjn#9H)O6= zx&q`zPZz4x;ov$z(*xmWC+wjH2VWL~n9NIKGPNIAHvV(MpK#n|O3)Mx&5vQrd6)%A zGPEQq;a*w=nmrVoLLPv>XahJd%MX%5bqI{Gsokpb=4&&nYi2#9w4xhgphsA%SAX*? zQ-`o^2pOC8R2YPCA}bPlLi9&M*2>bF9|*rf9R@(T zYZO3#0en5xQMwH-1vErO0ZQ};OSZvRa4<7p4l5WyzR>>0Jkw~WoZPcvTpc&tl$3ls z#r(AR%`}_)K!hNl>SS;^ppc-wu}uS&l`V_xs`q@TqU&rU?4H`FRnE!EGADfvZZP8p znL!C`JsYy2Sl(IYUhHjk10C6>VDNUO#t)LI$sX@;Kd@Kho^*Z-hOXLEw%_zjmyCgI z^16yy4U%``EVvDq4D?U0^s=#9#ig_zWnOL8mec~u-LQH^*i^6g zY|@;tiCVJZE~@}mxZgASM!y7!PnIEh=|(l(M!2ZG7y~5m^fv%>F+z$8N6Tn(>wT&r z5~H1M10?1U$&8*70|>yBuiG965Z{xv@~W$kZ-SGrnFN!aJcxJA;HozUFA#J>64h#ghYVH}GuG4EP2a}mjgSqG)gnI}-<0w{R>6@um z1C+fcxQ6S~!kiXlSB|z_z3}B~EKaihtYd@&CJkbiJqCa{BDV^yOuI$87g(#S=Egez zp1yd!SH9Iv(6-$rQ(eL@+Ko0Rd=p=B#o%=8IkHMSBaE~l1sOfqW>yVVx|ti}ErgST z(g5@f0V{oeJ+#$hmaSV3u&+m$<5_3)n()okl%-=sA4sMCd)aFo$;fy^jqMftftoV$ z@tf~`$@C>%oc%DS4dp@BUAI6zV=!&#D=vwwGs3|$)3-T=CH)Iv_CJtq3(^v(bH#Jg zvh7vHAIK)h<+)zmd62|>IxbMZ5&Z8|3C$0B3eAbB1S;m3 zju6Mza6AR=oU$}-6Myw>-1gIs!woh1_8|l>5-aw?`NMkW$z=oYOJfbduhYa+ z?S?B^T>TM}{zd||I_qlQZY<#`i|uB5#ukb^U^Dlml^8g%Z?&JXxa=}LQyOd{4LI+a zg7i07R{3Av4dO0>1#5lHO%0;&V%?=BXuH`N}HJ;UB1nvbgxH^Q=Q z@N&-(a8?r6olxOuykyqrPj>Ikg9s?7x%fC2lI=N1m#{je9$~)yptIFwN`EFRYB%X^ z_xQ74_>YbNl^9#sZA{2RrYE)ec?>V_D&jtH?t%4IAM{y&!Wd zjkPY5pXHoCNEe??Y4%M*YP0Tu05WJV_Ra6=bp7kSJXeUQ~nVI$*;;*D( zZF3-O&{1i4{k&evzTxZRUu~x=0Ek%R&kEA<^@-XURD#!M#v@l&WQ;3>2ZfKZD?lAN^4-0I7S z**N~Mw6czERfZ~2hU1^Y`(1IsN!I|8DuM=B&TRcpA+x>|k`wOFBqXzvq1CH_)h)3b z$2dmqv@=8T$UXO@UvfGSN17pF@W5?w#?D5`-6n)Two$lx27p9fR_jtT&5$Q@C9Y)E z7!=mDKIwLBt<~52W19&~{x`c&U2`K zh~S!;I@QR9Cbmxl0}ufW0K$NwKqxIRbP!7c7l0+)xIG+$%|HT|vSJ+#r{>0P3e>Yv zPN5=;=ZM+g2XL@zk|S=L3lXAMZX0d^OGY}c>PkJ+(1VUN-1z@ht^^n;(AN7DG|+z_ zeoA=QU<3gRb1#+M+Nc_X7B`tI`Qm+^x4&c1*D{>41>g<|1xG{SM*AMq=UkV~Bk5pYez9!ZO*50*N2fd1B<;4-B^?Ssm#Mx%}R99L2*)wDNwr9qBM~O8imS}fg94@wQmGpKq z*M?CEbN7y~?dC5hchAke#yLAIeCCxUpZk4{^+EsZ#GyeZ{HXW~;Oew`N2(%a9r2sZ z);fwMQpzTV+oP2!_Pgz=O@idN&WAorm1Ihk^fYaQ*%dEn{;1P)DnHpWQraQQ@x`MU zPv`rGRARnbhann}Vctqyvel7b$w)b4(CSkDwOctg zn2n&TFBMia;s{4hh=x)x7vd%8lD)}La~Gj3B8_lDaLVY%R?mCu;W#Oy1d#IR2z@-mLA5S0hE_K;8_=wQ5RkHtExOI^{~g| z-A3*0D|Ig;_*{?l$)aSw)|O_uT3mJ;cZFUG(bZ`_zm+i=(wVBUJFhVz*RA#Ov`zN0)3Y8ku8fSfh6wi;Qs!xO z^f3jN8t!g>JzG6IP@$y1M4o4+oH%ZVwv<|FDIr|6#_83jZe)ERnQ zikkoMU+eh-dKnmMW_p%koV)afF6ejJi{I@5WC*Q;y&*P}1rUZN)jaX67T z&g(%5uLXLoNdne~S-8sl736Hl52C0zRlH(J{l%#|)<;k7#$8`)GFbJmk4gG#tE|65 z(<1EvGoxOp2qKQKy8|gD74H;HOxjFQr`Duy+)8pc z^r)Oeg%+!UESFFAO0R=N$ApjEGM?#f6S#`$H&(j^TY&n^4S5Ssu7hk#|(`p1|B=sCLP0tJP%C<6aWY4$!k^TidctVi_-_x}v(r%e7Z6^(qGPnY~RtMx0xvw;e98tzTlo zTtS}<2?7mVEcl67l%?O}8uThiX9T#G?fOh!jS@U~c<)O;fh=1`Qd7+RY<>0qt|oOY z8k~#eby`D(yvExirnO*Lh2n23%!(9!%3G%k0+RnT zbu2RWpIC=Bbk>LlHEVSc>60}fEj`H6=Oy!ninPijby?D*8NZ$yn6iGWz)tDS#@uRy zvi|I?nA+vR#4a_veBOi!l`-?#|FoaXlr5G?8(-zjWx13eX|rJvN~hz#p)3Bk~t_c4_UPj&_uJO_$R@P2xB+rx5F%pr)q5{N z2oXdQEeUp&W%VU`@15wq_ZBUBCy}U$9+Jpj-}}3F=H8j}2b?+c%yVYWd0y|2e49Jc zdAF-bx3Rky%J#N1Z~@6?*jp*}h`^sue`Y}Yz$Jp#BOKsU1hOgk+td}FtytF7Q;d3? zi&FgPBVTdZTgi-5v{Xz8moH%MA!;ItTj>4hqeL+bUI#07NcH}m6um>I)REj>%G|Sq zlrLt+>Zjv`eZv(z$kFpleTzd%HA5A(sY?B(eZS)hP0AU@Uyz?Je*nhQ{=@g|x02;@ z^b#K{KS-gpP7jwa2isUe4OyYw;GXhidFeJ~t(u3`%qpw8(7EOm^}2A5u7S9nUh-I# z4!HcQ^zaISSOu|0H4&JjIJr?i6aXKLLp=`JiZOo_cMQipMF;g(!Yae@(s+(e7nj1> z{fOZ=DG+9ieL2*CSj;Inir~>FH(oU)&Jwib)8qMEDe$tejkcSOg?BX~y!mxyo)d}2 zQWS$Agyr{8laoRrP9~5Uh3(6W)vdf6rc7-sT4*Y76&dbTI$VQPy9NhlaYOV`s8-ZS zEo?BZFCq^;7G6|`-?14f$Ela@jD#(y=Ld(!PpP%=KLSx6SU7GLzXkI43p%Eb+@!>dR{jVxr zNIH}o2L`xt_Pee`(E^m3&ddVOg5I3ug>Lm8ACi12~N8^QKx zll3?ZQl0Od_ER=91dk+!U^IE70tI3L=)W#g{GT!%YdcAV6i1!OLr&`R;3~Vg$zPAC z!1bCA`DN~-Q3PpuOd={L(dIL^4JZwk!Z5AV-uJnvxZ$B04s1jA#HlFOo&Wb zFnIz-D|t>_%dmtwKe_k=7@a(e^aHaeL`PgI0l~&znEAL3DWMD`$7EnSjyPU7BH@y% z-Y1eklMrlvqjX_mQC}2DXH0U2>YNzN01sQQbdCB%pl9x6o)x%$T%eh2E3+Sz+Xic0 z-?4ZiZ(Iy=Rc->{14n%3)La+dftO+>miF{2DU+xC_``O))br7}(Rvv3yRs(l-?uJV zsQ}18Ar8<}SFDy)BsT|%J|1wR34xQrAJ6q9MdrqimW-W}!5koCr@41N%f_&jfQ4XV z8&I22P!>G5Ist&-2@F)q^p(io>Ob3ar(wDe!{D zQy!4y&eG>;w0Wh3vDhV~?_5RwT<^6L$g3FCDg^qt%A?MPs{>gf07m%Bm0=+qLC7i8 zxHzto2m8EVRdRs`VH4D`z2d`$S(9A>;23Z?U?f9w|Iqk%s}+Q)`T>Wjm`!NEYmjj= zfX4<@vpQdS4bhcdx^Tz_){PWg)&vTIYv{v;;b>?N5^%;5CIkim@PJX%TO7=w_Z*YD919qYXR)GF#=mFi-3V7Ki#Ro4x?;pgusZ6jgd)6({+)0)3)!aNKnKvf4a2Y|d#-0#M@ifF$$-%W zVy^{ABF{?Jk$k38P}Ty_r0N?Xfmlx%i(m^%hd)CMMVf1EryZ?_Au$nX2Qma8>?9I1 zhM6C(3RUF&dWPHzNdElt_qu9@&5v@3aK|TUubodzb~GP! z1W||V#uP4E7tE#ovmrwB;_fJayZs{!`7UzEb1n|E5m|w>kXreE2njIL&#DJ`%Diy*hnG_9AF+< z^;{++BnNZ!wJY{r)3X31 zu~rq#_nW5yVIGEoAx>m6mtP4{WzOlDu5?s}t6nj+Y=CDVH;Vn4 zV5iRu^w!IXfkn@@^>qQ0AsD@4EYj1_k!|WrLnvZ$dE{uDM&hYY{nJ0UUCiB|Q7PlT zHCO~N>+!pxK|Ec9(~rL{K0|I8JsorYELKG(Eb@IpXc>opVNf|k7{JwnFbhiGR4c7> zg>)1$4TMpF1S{M^IVkPQEQsf&pT&Vsr2#?^MIZy-GdeI%8>kdFe(si z|8dBu7LdBHNuAECQ^>x7zgZ_QuqN5W@!z80rdYJU&8~6RZ1|)NVhHS<|9r>BSolK5 zYaGY1aI+JUyqkTUISXv%2|{cH3O?9u7X}eb;6Z?)-ds06ne>zOdI(RKpjg$i&;~JL z=(YY3B$Ee3qX%=A;)Fm`x^k!LJPt z_qU8VXwnFRwciYQ6qSux$iS!?yP`qF%@y`o=lN96?KrT_nzc#P&aM~AD-pRo+W?ss(QLAduGlPew!JnC0~i|L|1tZrXw!Wn?A!Tw3sY)n%0--0 zSS{t;IaQ1#8`g{PE6qK|j`A7g?>B<-ov>8+3*Wwlnor}z>%rZ^uBVT~;FHh%pv1(D zj=tO$@#Q%4F~K}^v6}X~ z!jg@AmbChPdwaG}ZT7}a{Kdfo(dV8CsJmnyDNef-g=Yl;CnY+8W1bc^uk{5Gf8uDc z{v047RA9}^KMG>I;!jRf&y5my%p(i%1(5gIp8AbRfT&aX1=plHK2RN)cXvZb& z;A(!I^>DfpzJt6!Gzy^zU#?vzM2wH+xDQ0->uqJ*PmBvr`O{8UlBqJ{5|h5(l>Wy< zX59*0ql`>fza29EIy85}QUbQMhTN>&LVaqNLfv0x;jQQEz}j7Vd(mica;25GWYONv zj2QK0w20!fO0o@=-U`2;r(`%HVlp5R3~SUbL_$VP(^Ozp#tk6lK2A3JlPmLDay}2GJ24N> zq))|H)e#{}xJVb{Jfa(>BkgLZBV$g%4MKFd*XW|m>kM@K!7KCOIO5nH9WN3*jUh4o z9SA5;QMaY8^yInir~&2cn-P+UKrZwxQl;n!o~Ea9)N!CC1x+EtXWBO}t!#SgS5voP z`M?4!a70#lL`Q@&ERL2Im4l2nw!O^6V0j&kXm5}{IUgJ7;b)FhEpA zwlDYA4!j}zY%j`y6{M6k{aSB_msWUtMgyQ}e^(`k_A0gptTf61pm~&_YoJsAG#*LA zP1DayMwMi!IKmU(r`S-Ji!feW!1M+xvIy+jb- zLJztX8jUIx#Md92$R;>UI9ex2W1d@QTrp0_-sOV{xNWxGrp5fTn~-~?kwc?jQAa1C zW{H3Bs=Zw0I!2Y8M_ye(M^(s>cG}LbUbZt>sN3Kh=P?^XTJoBH=Iv?dsi}E=hPS|N zD7~@(W`36n0p^%ITYu6y&q4c?zOoVHs}&)kM{8f@cQJAG?V4QacC<$B(od{xycT)v zNd%Vpv8SyjUlbADZhTsivaH^ppkIq5YDo3nJ}UZv%qT@a7Zbgzr!fM#2%fyUqwYsZ z0m(oH^ac5S`{pO583j)3wk6Mv6`f1zm56CMIs17D7CUkh*R(|GXm52aL%O&F`PFhC z@HPF1{>=#taeVAQRv#oJ{yDktYwdsr7?y}bUfCjSgLi0BkS2#(OV8crBhZc^B}Sd6fsILo32UxvTtktD22(&*pcJs`_X&-xpB?$2nb~73GtBqn9{BL{`>nQ z10gI-qeehZmklPS6cg3@4KbtdWehCbXxbvtg%LUaMdrIk4Dbe&!Sji!DL~fSI;^HW z7y35@aErL@ZS^a_Q=buD;x07{+PZ8~S-1jS=_FRHE{DnwJ~qw5Xzh>9wX5EV^(dKA zer=fR6i2S&IZX2w5{ddtM4_5`9=|V6ZkFEdly*tb@OiZDB*ub^ku`+Sf$hhRL*hVWPN9mo@<#^ z1g9vtMf0JlkHGs_G~7UNXgwi5Uw~mD)JX6P%jc#6$qv!zH<2^7Ha>l_b8Gzv0`@O= zn}$>#i^Y8?|H^a!t-9Q|W1TGY->FAGYWt0OUw^1A-xhk)JZk-~Mu=#CCVkBm^d=r%5~6!^nGpAhJ3SK;uemi~N*% zHmtt)-y5jOe^aq8abmS`FZ1 zM`e?m+%G@wPp_Ib8i*x(=+Ko%*!yjpT5c%%0ZedpJr@yla@~o>JSdfm%Z^NQSFc#(yQV zz$at1R!8zxg#M&PI5y1)P<~=uP^FXlo=?L(=biY)?lu3k#|O<^5Pv}Ma;z66QgP{3 z=m9IX%p;-Jgm842z%}wyfIk+pxxt`QDeM*P+|wwYrPal&;3XWTis=^TQ6&!KiFWte zdea5Ez!v)pqCb{d`k^OakT$-@iu z`KL~axLYH(OuajX;@3tKK*bQM|B%nW=oY3?ni-34Ld&9X$iH8hh^iO&uhTj!h~cwC zw#xb9Gz$2@QHG5_l$7t9dNMFKg(4^kS3hQhF%4wXZZjK=LKDf%C`NjKh!@B~RE9)c z;b&nRlfl{getqWU*`t&Imzf@|jkRy>S^X6z#0wncw;~&HHh+oO|FwSz(91SF-e?W` z2D zx$DnfFs~zTaduto@LL`kc2>|vkyfs3Ws0{a#DHgO)Vku(XA9^ zTKAqd(eTs0KAk4=Mjp;0d+pj86PB+iZ5O3Z+p!R(EH58stym@69w4z2v4Ts$80(hb z`K%Ef8Zm2f8)LPtjI5iXXK_I>giQ+YVnESDd}8WodLG-?AaYe!ay}&;fwtw+6(++- zPHU~;b>v-jAm)c?M4e9rsKm`tIjo}QarC=kTrLcDmXAP`T>(i9W} z1Jy1nA}GM{W(Ba`?H-_ca341U9?D=oIL3qX6hUc@j10*Fbslf4irMM97&kuZpBBQN zEcB950e{pH^Y!Jcl2R#ZLJ^M&$~XG_(fJO>q-wm<4vlr=X7mM`4T1OK2hChm)d~tS zgQFCKVHKPxleWB~SS_WsKChW2V;4q#D|(FPAZ94ID1AJ%XfYNTBaMtJlU&-9eF?sSV^{%TSeV=K`EjF zf%$i3lDC-EA_U+-<{niPuE?omv*tdZlU z2-(*t5KIa*fC4s0erFD}7Qmr1y|_VQEyODDv>@a0L5A4nO@p1-02Q3MY?VBHRjdXT z4n98Rrw#nE&@)jHs^S`~I-M58o7Q_X+sQIi!A5syby%rRNkB>7b4yFgmC*NqH{2Ag zvGUdGRxXfFy<2O+tjn+j36B7ZlXpo|Q2Q%A=IlaMjRNxdTH(8fBIsSHQ9QjkdCG4s z9g6-hg4l3WjFQ%!+~Ej$D!#0XlB|D-?92>f6j6tvWynrs=}Nv(yaEDGxoA0PTsnSl zG6QL^AjdH=7~3Ofipm7`2jPbdDQwc{RY%#dxruSNwUiBzrGuQ^{Fp)VmRav$1Mb;C z5?&N)iPy9ZT+A|rR*O|$s#9Ez#B~X>!Qe625KzRF1IC4}4-y&6Oi{MUg7G?WKz%oV_jJW}T$H zc|1MsJf?pJM;m5F%>@BxK6|ghnjVE~ub@<%*jT!S{1^$?t$#x6HD&kk&-q%&WY!X8 zXzk>}N+N}+YWtp70RrA{sAP*6ryDx?4+AOCElRYPcOTh}qyAb(EP-<=SzdYxp4^Z> z;d`gQP3ZBT$lcQWMM1lo-B=8;*6bix7$9-cD{ah#01oQ!sI^fkgxs%zc}&8?u$lK zR&1;9cWDW^Y!5gEYIrDB8g9~O-|qhO?Rfh0Gp|w4{FgG3#PgBL6#M~i#IG;nm5toLfu9 zyQWcfZ3jwSIl4Y^H6r9r^lEb5lLj010!SCX=&6<2RaQNeUVfG){E(|uJ%g4}AfK<8 zV-0{}-*!kyjMMV#-__<;lD!7*kKYZ7JIrP)^VfC?*r?=&1~*`iql=|_PgRuT8m@!tkLTR`Um7}pG`Ud9rJtm~DX)%6RsZvUPNU4uc zMFe4WSVR6Rn5-`GzczNzz$w4<@ zRS*SLe-{P_}7y?c)k^Kib)zS(V&YRNqGBc;wBCFZ+=+39xX ztiku{Z~=8a|GEdd9=Y^|A17>o>({S)_L3m!=-uFWs>hPp!$PYRzZUksOZBl;rED_M z3Rl3`+UC?4lOBEaJM>CL_ewcZ>$V>7a775f`arqkp<@@fS4|%~Iq#a=7m5`_J9Jkt zw>I#mj>sseZMyZLrVqX#dec7H$QR!KS{-NN_Rw zhGw6-BsN6&r!IEK>3#LL1;WXSQHGn)h9>TADUtinZ6ad+p^K)x66Z=M7xBp#tbF7A<&m-XLWq!mEek+ITOAhs;69l<;RirT zK7jn58HN4vvkS*xO)&&1JlWNdce>w67oCP4<($pRp=1jGmd8 zO#MmTP<3YC4_oX*h&vVgFjS_LWj z%pab$Q8#0F4XuaO6tGJQ8q?ZX;WoeOh-dPDC(!G|Ox6?~)NP1=!Q54jh`#WUAA)vP zrbUnbwq@WrPqOv$N23Jl@On@3EM(Tcx3B$_T_+Dw&U!M)lHoCmI>u;~XqFB>50tA4 zRHSo(Qz|LbQ_anU6+!z3UcI%yCVXB5P-v6h%G2q14^&&)iIT|{T%Z^V{YtqId0x8r zNgGWh5X(Hc_*FH=UrTn#fBS(l`%;h`YXtpYg`Ag#U+MA|@KKAJae|DYdpi0e!oW@D z==XTy5Ub~j*8Av~$8+LTo{!?{k~?1c@b77Y_uGD&2EP>WDD?9A`t8HVX1xE*4yWSI z@VBQw-oeJdv1**g;C|mpzv>`Eax;s-AZ0g$*GPE{oAB`!H(w2{=KRglGJk?&~&6dGJUze|6AT)7b7 zx7=cdOM1FSq|e3mchnPU-DOSpFywa1@zX$iKI#>9kvD)|?ESPy%Tk*=gC@1nfpE?j z^I8oDt~70Lb2jk7>%wkrOf_!Poqk_vf4J7%b6t0qiwkpl$#uHTP^OZU6-vQo$iDM^ zg?zOv@{ZhB-)O7#ynZJF7y9y#sCrMPG5u&v%|C#WZa3j^E;TpdVc-$6$-Fu=5>Nwu z5dkG%8R{mbyNTXjBOj^t#N7(t;bh<|-BG--NrK@hSkMzBxs}6_ZD6^9o+alpo_k`^ zqMt;}Y1Q%NnmZ#gEzlREi0*u$(SVi~XH5|qVq#H?eANCZclV=HP3*UsY3+N5+;2Gf zs>clSEA!(|$BFb;6v7oBv=uPSk#5`{QTbFlJ#laM%s^9)iOz6BORLmC2Qj||hlr)8 z!E+6A?h@%I*&2~C-g-eVB56$*w49L}$!iKt9B0w4t8g^eH@?Iv^ew|2<} z3DJ?mwo$x^%mv|LCo_mh1FmhR_sD-u3o&E|x|<*Lw9U2?%-j~YlWe2Sp5^*{(BO@> z;xDs~UvM4PLw)3rki7q=&W(I*%Buss+dZ?r^F2<4_g#L*+x_OuPkD7F8Qjs4k4L~VK04Y{{+)Jrol ze%R-2m3}%5k+k3B# z|9eVN_ap^rQB@#_h-f~kItWyZqcAzx>sb~brZC*$cXDox@e<=C^$qPnE z(o{OVILYsl_9RhZY9_7@AMWEXDN%QyjJGwPexJHe*8-Wy(#rgtND)7x}3 zjheqxm2Ma-5bh1cSp7ap>Y^wV302cPI6IN868s<4EOyM8blk`wp7gJm1DeQs=RV)A zMFYolTvawn}abFytC(`e5M_oy<^umHgI_mNNlB{buX;4{K7c z?srQXUHSG6lr$by=~1MyI@>*We;#Y*RW@q%F}Xo{(CPAcA~;yV`l51~RGYi&eGuvh7=DiW7C5W6kgeMZaf{H^s^wD-x8^_ef{j zUwg{CXROl^=!Tpi2VqV4B&nCcv1g(eI&LcF<}xyZL&9-2vSNpchONAl@yS)SuMbUU zrL+C#rW?aZT*h21?qZ{OKU#c&1~|W6H4XT+^h|?><8v!t2jf$po@2e*E$o zK_L?^3~-=FBOZ zLnsd&ROUM|&ursc9Y}a_+J1t+c3tU}$GDLl{cOkh%5bl%_sN`|1_HBpB&eNg89dx- z;uQo(GLR}my2p(GZ*D#9cf1yEx3$~g%bY~-wHE|0B2Ak`36l?5f~;Y@Xm8R1$_I4g zvu>@WPx!wv=DgiodY979mGJ{ZwynFuet9oi2qLEZck}A?%jmASeeJ9FYE`%9sukD; zz3#Qi-FwMBez*UWE%vX5(#N-1&mIfoA7~TPcivc!ob!)z?Rxfi=!iaNiXs7+AKnx1 ztFa7c(Q*@oWPrcZNpt`RvIBkcVZ!(8OEswdLg3jT$F-0Wnp|J1^|$9kc~$5& z!3Op57Jxg4;i5|U_xwly^F{?|%lD4IMKynaKX~#}Hx7VsnSxr>1rh#{Ve(<#rzRLC zJzvu9cOQTGQ7pW3e)XPY!w3ear$aG(3-v!7^xyyG-^aw<+vpw_fFVUMX#H|Tj~4Oh zSD^{>XaQyniBup!+5h6bg0A{1e$FA46`~j`lCCueJ>?sy_JP7lr31+r^ zt19&lA{xw+6HIpehAG*f;D}|p^w$Chu_YQwT*9h&uy9eyizFbe4Vs@OgfjrB53p@h zG0O>l@!Hg!Vk(3o2OWZ>F&?bLiqok7Cd1ss?e7g?2{-4ZcIdQz(S_n@gZW^Lgb+*V zOB(2qZ;&#S76y=MvTG)@h)Bkey$XunHj>8x)R3g-NKdvMoQD$(d6Ve8k*9zNO;Ri| zO?ZVdj6NiyT?txLO!KV4JpvWc2nW?LM*u~Ai8%uxn`l$e>mbk+>47(v$cdWu5*zFj z#qy0>6qwx_gjvg7ie7=lB@%MQ3)xTPvQh56cUr>LErX zm{2f_qdhK8+0@huvExxgDVMSX5cm;s~hnslD9;AH~P zCe^P>7U&#&!q%h!dN+&{<1%MTlg0}|*A+>EVGyfbTK5k?4{1jCJjg$M4Amx<{W$d6 zCz)^=ZHLeazhVRzNp4EP>=*$~Gp___Qb>pb%NoNUAZX;8-qG5G;6lUJlYr*B)b)6q zNl%|fhE$TtNHTd2kct_NO$aef7|EnY4#ZN^DXp^Eq<-<$-D)HM&9nxVcpy3efOJna zPmZLX3Lk)2A_~(6n~fS+(#dktNo>-f_z2);DDpD3k1($>Q{uN-oyLvGEFH?|Wl5tU zN^1@SjxENL>7{i^gon<*T0_7(pVMpx1eji;O-!RAp6l=wzLt&kTzvqGx?<4yPBW$o z-IoGXZa`gezDN<$^9|~AwC;xv-RO%Wa_hnDKV-QLA~BF05YKP&-J5s?#$WJVuH&!V zugNzC4={}LXLY6q3}=^5qB)nK_U?9IJmmZ+kzPJ;`;xq}3#5FJ%ejvK-&y zZW`l?LqdAj8-=60L0WqDmti4;$yNHjr;lxGB#)0aNDR5@R5#IppgUP6<= z07ThumPa6mEfCeBbPJn;>W9!4b<6RlNO&0V)e;oV_B@!RI*vG^bnkn9t~q(jc3kdF zbZ$S(91%6WkZ|{H@v4LeIL9`lK7))v#0dnO0gGrOlD&v+6Jgy?UzUcSH-cbdptS{y zdVGv615y44`8^*AKZvvTXkMdd<_IUKegj&J13aBXTVw+^bN&;!#Zmx$yx;n6_}j2_ zw8u!>Zz}BXM`W-&5~@x3tED@nD0aN ze^neZHL(0fk5ofud>Wq16lg4)nbd*^e3p^>)aggI27hQ?IT!aivC@b@FLa@_0AO3P zxt<8=-B@6GRaro3{NH9Y4uMl7uO&fyntpsG!BCb2lBY7IakRaHZ5zXc>sz^E!#=9s zsi-0%y$yt5fbH?1GhK2{!Z)Vo#wf?z@piQ^rx}g+GWjHvjgs|oE2%*OyZ<5FWm8hf zhIu?9Q>H)tY}5M>V?6C7RU@Vn*P8Bgo95Ooc{x_)m<-S!?C;<%F`^#-5B(g?56rKZEj(x$h=rC}Cbammcx7f8|;vH~uRI0Ke{vl1@64U{McKL&9#4Q)Rk+HmfD#Sudw{7OvUH{Oy0&nAMY;_R{lz=(o)iO4O;0;0@~Ll zKX#djaw%YWor-pKzfUFuIB5Yr?H%`Q&D{A;8pS&Og9#Xa^=UV>$oF!BBk4V}#RvUi z?@6+$I~C&w1db|N5lFpF!QwphnM)OaM1Xo*fc4u^0zIHwlWxQprfLNX;-PV_g-Q}V z?bC!E8pAv#bItVXR0SHw9!T2w4uQLykkB4l&bbd7Md)kQA(us{*29~Jg;vvbefp-+ zMi%z`Z%mX6`CChvCq=)rI`o1$d8odpqN+G{gW9vAu~-etd7W3y53`RbAGrlO8R@gw zpsL7r3E724t`0>;%@T7$qI}6OY;rhhiaetl_Lr#AGv+OB-7xqe#KO>_2&utQ%y2kO z$Z$*>+I3W^LPxEFZg;L|xP+#1+f>Vn=2Js4)l>@QakdmQO-%^e@Zf>L!`N2G_R^)G zwY~NQ!lSt6(76KHnM|KI&C-&(&N)PYg1wb0B%tN&PwzjnrhCKTf6A*zAOM19`MHP( zN#q15trA>b%`ps<_6bfb?*`aRu9m!CTFiu@A8oLO8G%Kp;H_ECr&SEA9r?*qKljkh zl<_11{+5%Cj?G4!vVb&<|M6kQl@JXm_dSVQ$4D69fT9t}LF7^~uC6;_sS9o5uX5HyU_P4q&?k&eFAAj!94$#)M!Kvh>lyx^s*&a`+u_2w<#DP z9OwRXAAyQW#$guBzYR*Ly|D@}k7fpaMGp7<%I)mhA*+Vy-llQatWXsqGPh}EB1)DG zZ}lC`{_8ygMV1$EBE2;+ix4XGUIrYivGt9G-y!s3U4_SNKP=0}F9A zIpUz+MUDFKRK?ZsGOKa@F4A+%1ZmnjpetQJWs54O;dpR8iJ4vCZLXb0IJ&Q3VqHCo z&1}8Wl>Bhg?3f8;Fd0ejI<;p(ZK2Dd2Ah1v%HHhOamP3Bn2$)AUfoZ+gZxP+KcQml z2UGm??KSG#3CsR9iOlS~luwE9KMu1?1^({2Iu1V?J*}yD*z|d?DjMMpH2JJKe4x%(9G*SqbuS=TmJ^~|cNol>v9qrDq z+9EK0&rS?Xk1CVfE2xeHMqbCbJZU&ci+ulr9Qhb3V_j9ybisdFouCpP9LX^c3q;;sou@C~UB4n)9ZCn_$9`A;L4T*h{28jI0N#2| zT^bfpzD{+gFsFl~N`Pvf*tU*1?bkU)kL3B<;RbapKaF13Rp$Cov-wrX$H0#Se+hey zsH^3I4Qp7c&GE$V41b`w8MtFkfa0$>hhov_L7N@Qp8ZG zC|1|{r8ppueB@K>CFfr>_Mf8V4rA4b0bo6gd~SS=1Smi8o#gT?7b2bmmS9PwWb0Ua zS33|u3V`fR^C!gvip?1`LImWsDoUeQx!!2BOOlXrpn5b)Gvh0sqZ^(^#64Sx3;y(w zd$>B`ONJAP_X<4+k36>kxp0(=M}$LRej)y~1P>A*bXzACyEo&q4A&*Bx`H>ZvKJ+D zHkZrWMe1=ef-y2!qhwPp`@b8Ht0z z2MM2yP~eJaV4%|@FizC6S{eBz zALE>7jA;qe)?Z;2Npfi-;XwZ>k3YM=n%O6!wF?8dKL`}@{Gp*mz9-(&%&0KS#X(#J z$A#0@%m@1;7)$hU|_W>@gnja`?R36b^RR)Co`Hd(9v@I zY@|H$W-7au!G-xs(; zZrz^Dt{1QWWON#Mx-I zcIru=pXa!_4&>)25w6~>Lf)>{+V(UpQ^D3wSp7PyP@mOt20pdZBseHf2QTpN0R`g2eYef{fuCz?PYvjKplkRQ#c zACrnxl>!^PuwCl!O-SO)Staj4y!O156-@2v=aM0`lmDN&=2NnT8cE&9<6=1Py;+6c zCzZFi9a4C3Fdb(9|CnnjB`(IPrzkB? zy{00ro^LTFt+kNAdFcPzFtqax$uUmjLzV2g{v+9v`Z|`2tTVzp10pSh?tLZscasEK zNqB&_oa5?~j+hx)fb}SpL(;2O64s=DL39vc33tErhs^;-^!g>t@8WY70qBGz6hi{M zT>L=FP6`Yyx*?^pDU|;QO&il^A;R*NC0g{Ny=L<6fS#k2AL_>?&gAn&eNohF(NAoc zDY%DyuIcd5Aa!o0@WDBs$}4Y;^oyAyIgqn1_Mu@e;|;S`cGgd5F)U=~%hSR-8x}mA zEoqNQ$_~CE>mtG=AP_@5L7g+HgMp4ws&(chYObzG+0{A1+wr#q=UF#m}0^2q0reXee@ zG@UfyQ8-Luspb8wrk`BA9E_O`I3p?i%RWa=c4{K5uCqlgH3S?cKO_rCw9iP@Gb=%Ux(VvVY z2h;st`u}RJt-kQ?0*y-$FhvsUL2)Nt3?1$v4^}p*=gk{9AHNFk{amW>$Y;Lk?;}f- zl`nMne7w0IE13AMeC1PpJ7U!F`mucK_Px+%ug;z!eCz3!=-}JBBnyShJ;!yj-M8b| z#~!&%t2?T9YzEHF-sC@6-TiWX^|`>{vF9I=j6IXV7SjsttYV$=eKWH63;mD1YTg_1 z3$(xG|7`8m@KNg1*1vbl+dr0SnS$x|NY7W!sg`4ZuKq-m`EG#Ze13RX<`1NCy^Uu2 zbp+2iN{TI^lX-If<;ChR!rHGo!QP)(IbY+8jLSh)f3Bae@ELsTR}&KpiAsz>Ci(T> z^GT`c>kGaI-s^uzpAY`(m-}}wWNO=K<>A$Ar~gIc`gNhl-#^=O=->0}|MJfIug^OP zoRzqwu3 z`$N=iOnn-)s09r{Aj0H|bN4Qlnklt7@HV2nhagFT<~V#oQas$*mwPkFy=Mz7cPb_K zKuP*OeQDe&7x&lV(5Bv~;T{WJh@?19fsXnxh@SClFDjVoiy~OvgZ^Gyj4TC*rZ`xc zm0IrW9pOl7@Cn&bQ!5cdNzG6>#12a&rF?I?kNp-_M1Z<|iZTO_JHqr49rcx)VU2(T zLg7I4_GfmnH!;Np&O8XB+5JHq?)-pKgpK%RuR_|dE6 zhmixX57;^L3x#Zyj(u9c6{;k0>fEtg zVm3qgEtG8R@TSuU9x{^AR+nI3KOj*pT6#V+DQ%!08A)A8;2~1XN6o3D@l>mARBhF92QHJ;4#fa06-5*~bpkc9 z1UOWaO62a89;%-o+NqCDrt%We1rtr(o~`%>Er8sZivYEJQk93ao3` zlj%ZbNk>!PJq70}dGa+A;igWw())%|L+}_RIrocPC$ncn6MV__ ze=kfwvwbAbH7-_Quo@^{WjJ*ps&nn68subf?tYI(g32v^Y7@PqGg9(dslAx8qVKVu zHbH*K6PTd7O$9fa(TL)wtB(=Rp8Pq)qV7BLGkb=}O#Jc@m~OjHwm$mNtf8i-QAGHd zT-eCzB*bVo`jOI<=wbB6)9Bi#`tr|bO&I3^lxRUg4q^AKUH}N-3`Mo5W*d# zQ}nVUdD89bQ-5ZgMu5!dPN7VE~X^$hHkX2%9-mXO0cl^Y9Z$woZsiz)K3 zs85SsSL_;A(SfUXGt9$yS{55DA@XNtyRR0G0OqEMrNo5p|3%eZwnh1cZNndi8JHo5 zZUm$T1OWj-L}uuryStwK|} zb?m=And7}%=yK&8cN#}wNeOMPZGM?#;XuQ3{T)l>ppo#o(NV^7GDksE+@l!F6>7?r zgoi0sJZTDPhIM_}acfH-=9V)WD5-(a(Ml#XK?6z<0i4Jj6&f4h8 zJ?(ooGSQTg?}l1ER-}$a#`Jg_Pu;Gz=SrIT=QVS@N1%ZGF#X%5U_OiaJ;h)>d$3UC zT8u-~JUHw2Z1x~y4ZF<29?=`Ml12*mHvPnOKYuq}lcd~U0||MXZh;fH-5$)!@u88I zHTtqgi(Z>7(OQFX0VZnlmVw!Oa#qsqcu~7#^fJlIrBC}t6 znAV2Hp`u0lt|kZ(tXYm=Dd*uQ-J2t=9<=(X2|f2=rP11W z>}vjyAGMhX!ur*}r%xskaSsz+l&iE4W5E&YRm>14f%&n}ioUSkGia7Vb6RirT0Gga z4p{_=l&&LwSO+Htd)+5Q%ReE_cG4n&2^LEl!)*?{f z4Gh-MTU!i82iio+0+;7L33Bf2-tM60*JZrG#Tvlv7FjK38WX&|G}wKbFzl{++9-TW zyeWh8yovIXh<%xf$p;<3dA;feQGVQPb}PbmV} z?Z3a}%mE-^lS>uu;p{7#IX&Q&PP&%?-RQ`&uZ-f&=$M=c0njnX>b;vpH97ve`5>8G31>gfxPEq2%za)pR?CFPRLu%rA$L5RltU zhK2o}gtbF}s;18-@#nZzgtSn5#2^HN2{^1_CaJXm8!v(?slH=QzT3wivCP0VN8^TO zD9h7L((?|`)(3G3_J%;|1-*J_e?ZnkLs#$7|H$>)l6p-C58^!g-MWJ;-9PId>S?J}KwN#!BAF zmY7TbBw5;Y)S0}__6$tJ5V_B6L|j10vFx%q4{iDeLW@CTH(YlBK!+iw3o}=Yhl@&( z>yqow^Up`CUMF2q!EP^fq3z~s8Gu7&z$-Z*NrCHHB<$Li@u!`PVxcQXzy?Tx4EL4J zGMc35J2li;(e`tA3VLK`BkyM$+1XgqE{@yJw7?(K4wC?lz@Uw24#S^+TPYuunJ6KD z?zu&l8xWtzpJbA_aaOKwxUMcz|G1$)?e-yq(}Jz3Pd43LW=oHM08!fyi>FFSuqLLd z{B}mch`qy8jb0cL)Gaa?5v_k_QMGvg!njE=0wd&aKN=2!@_Od51X@CXIoXV1J z(FT1>>y?PGf*Uv?zsngq*z87mvKU;7E(MD>;Oj4v2aImK?}Y-yZG<*1@;Bt<7(S#d!j;t8<+1F@rqnnQ?e_s_AtMV zNOSuLD8ReAKFUK$wAUh<_#c`ncroHI4q%|1|$OY>mwZNsL_qcK+1TcGK1uF zZ;#BWe7@Z=F1(@K>B9xw_R}H zLd93B^9`=lSKZ%f?PSp$McrxhrK233ZPS*0+9|6s#3V0Jev!0>+HIzun_b`Lah6 z8+R)w`uS{%85!$G&3N?)Q7wfA4Qz=CEKT*5bM=&&nk1t-XvQy}=lfM#jT)-G8+5bd zYDxncbxAqtGJiORv)d-ZDqYoAT<{iZ(df#(eK%uQ^`Y8my;+XVbd0JS>@I zssOMpur+&oaLO9f&XG*_*P6)`6R!v|CCakO7~8N%rq05ud$AaQd@+}W!v;i^0f#+3 zM?Eud2>$(zj$L<5uR@5NhEt4Nj-p~bQ;v;QyKeqdW(G|Ir;RBWsR2zEBa{TDl1B|E8A{)KXc>d0$;vW7Dpo||SJ z74fp{L12E`_En9F{L9rbeIH{}-5Kg9sI(N_TJC~$1CTr`N1_Zfaj&tjmNP4uPS3r6 zh|*XENo(B#YlitWp)?8Q8?J0KB*{q-pJbMCENSfr1ZBg$42hc=Wv)|!R;Jar+RVJw zx6(?WtElEMU2d1)LOy;T4AK(Fj(PEpgc2=o$6%xMt1OL@M1vl4ykBbT>b`fIRNzc9 zMon%4WYLlcjlL{T&aCAaRe;!_7D7$fo*^Vn@Pc|hDq7i)0I2ipL?BYJ_ZtAf!ZM_1 z3<4mkY}j#nOHcs#ka82bU!mkj5?ClDB@PaN+`5|+3_z^<#z?wY81o@JNb2R#o?G<1 z{2l5p4V5Y1NacXN-YPZ@okYWaX%E`{}smR8f_08p=%!9R#U`c-LsibSKru9Ji8 zN`=B+b-oLt__~p=m>Ylq`ONuFm)LNxHGZa^1Ra&k%l|r}M)@&11OF|TX&F$etz^w2 ziT=9FK_+6_+b`mkHKs7GObyE|ry`etjy~g=5wjCdHKr-*Bm2|n>;~_O|Ou-l>0~#Wy|20DU$}l%eO^z3?S4?GcoFZAd zL&@ftq+D33MLS#^c2Zoagoy0t{(bhb-vt?jQVi7odqXxRSSp$=wrwOFM89P_P z{QY+!RjhEdpiChohG>);Ww8@b>M%lCy~u(23Y0{+B_i+lrlJ)$o@B~KQVUSD^5OJx zp~05QK}s`1ShYz-MJ)h(dY}bOH(=DYRY1|D|2*l)9Vww@p%7D!SFy|__f)SyeW#G6 zrFyZS+O~HQAiPPcrvfpYghCW$!HN3vi!+EZIhWUc>wyWZ_xVqFb^`k*TqKj6{Pat-_T1$gP3}o&%QJs$+;M&h8M?y)h= zZjufi*v6#9!lal_up)pxWubFJjNHvt#wyFX*#ojL#uV8S10@Im>oHGrUve@e)Y=g5 z|H$Sc@6xDNYh(I{M-=1SoVJZ6aoTJlbw2oft+c9>%(sW8ToGg4C@=;{1Akm(1cQEDYo@t2~%DWPbpgsvh)c;JCriJQFTbrV{#>SA*UM5zk z`+=LSVqke{vP>lyC;0b6>W~w!NqQw$L)y7{<WCu3AUy!F$L zKamIPz71!M1@i5a2_EQ#j=&#G4Qc|K8v)Kw|8QR4F?@O>YEi~W2hy-=0; zf!0SZtKpSK+azN*VzUma(XA0T(j7@c9P$h1H6^)d`wgN(*rv=oV-Juk@4_Q43 z_v7yxItp5xR$i=GP81Z^6XQ9Yl^D1as(}InuH6QLrAMsU?U57c#o#l5nGp0?%fU#~ z4I3u;ZSHunqPX|0F+Vb9D^EE$+cjp#Q&H5N(D0uUg@o2TlH4TbNE)LLBLsAe!A`c~ zgPBFhJO+{lb|~`NnA%2)s|0z^qO`lWsC&ddJ*hp5d(*wm{C%r9Ud%*@>Qc2QjzYL8 zbs6Vtp%Ua}tQ`iu59jc4m&ls1#mw)Q?DjP<8*wvE<7l))+8e1gm!_1>N{nw~!zd8J z)>6xs9{R#BUwJ%Hd82eL-Xk0KwSF)X!0-&9lq@YR{LFTswHd5W`{$t()(r4spKKuO zTr;bl7h{A<+^nsmB#eY-w`{o$tv)}2i~=Pb$#6Ab^IJIE3!~8{Q>G?Cx_9v{W8TfmnCLkx`-r&e=y;^u8w5kK(SQvA^_O0 z4V%ik`g*CFZ^rGoOGrA$O!e;04R>`#mz zpE`rw>}D6LEOO5Mj~`K10H~97q74(<^^5lVw3+KPQtmno29`rRHcEhu&iBs=}&H z>)fi>#HvEzSo(-%Zs0i_PJ~)RfAYattQ$!G6x;wQ{4T5za5*0H zR~8)*8Ck$Isi~TFdSdyZJ9;oM^eE7@wSKI~y?6hVWaL(kxMaBUDsZX#?pB!` z5DRyV?Q{FW)k((9LmQbI|27lZ9AtFE&9)pH@9HxANp>l}dQ6>c&>B|AdZ;2u5ZU#d zF5*CeJ$xE92?{XSYGIw!UBv7btL27i7^7K0Jz^PXlZR)i3H=RINYIv77;(3@`-z|MvT5yFXN(s%B}41Q=?Z=vN8hM@h`Q(f~7ljf70vKsd-(|iSn!0k6 zBZ<7c9NK<}Q(t&;1N+M03N*_GP4a=!E`eQAYdq)J1%nU2ye zn`cmd9rHBDI9sL{&yXSu(m@M{E>I_uz}sDviy4%^&n_p{E+a613d}SelrPn6*5;Bo zca`JbKdb&I+&wo@uN`>4YO=v?=9`!sssURLiZW4MVCa;Aht6s;DzCAkLuogdcspj@ z9D*KCXD74t?Gd92Sj;3<#u$>MCG&?<{|;Tb#kn+wNbP?&qEzMLU}tNwP(^(wZ3Q6* zO}g&}UQ4f&=TM}X48EQXf8I&K7n-BE$Xq5ilpw3-)II?G_XTew^}HiXy2i|%iddRy zkl)?j-ZqlY4e6r0F29zJmY93#Zy+b7D@(Znmg`pIt0t!_r%ZRl_CKclLTWa6%eWKj zJT$8Tz3a17ZRx*T4D)Q9y{ZTK@{RnS7z>z5)%}(LRKc+q+b|MNThb^^I~^P7qdeH) zg4>hDndpLa?=#qn&n0F5hEXg0y>}x?+}bxeS3@41kUc)wkaFusEA;B|9WMt5aF*y* zmYDIB+Etdi@|2k!rk0jw-tMvgLLF2uWylf?6q{u9ZsNWk4KqSHq^tXP^Sl?UED3Z> zYi6_!J}i?Xpbe)^Z!Z6LyUE0=PfodTL>cN}Nh8vjU9Q-tZZW=li+N$G-`xyjlSek@ z=yT>xyZ%rbRw0x|tI$?X(nxRsKIx1}kCY1qB+dZ{#oHs^{=#A~EdzYv(?@*twXVY> z(=mLHXcb{G23mv8mJB?1ve4RsHdliPJ{9-Pm%L(XExyFDGjh@_&W&?23LK!Ig5MHCsRU7bg#^oifGfV?1@&{ zm77>6Gc0M#Gp%K6F{AFs~U!u5=jkg?|wbzdwHz?PL*Hg!xut1 zU02bAvG`RubaSB#C{|K{Vn2dWe=~sS?(4ud`FP66c2I-djT>E~IDem4KD&^YJe=m8g*5` zV4Tq)(quj6-Qdl?(aTRW_Zz>{P$o8Ud4|`;4ze-rhjFt@r6#?(4v-k5#keia*d%9D zQeP%xA{|2_LTgyW{Hn~72$<77nwtIes#Y$DyK0^&*&z|E%^j$cqY(d<9GS-QcrP0L zjSA$5Qn+9AZZIXNkX6PTRl6yDSU$t`{xF<95etTg#!-5&IqbwHPJ-4n$W=p`ti9Lt zxSd)6_Pp=%WNwu4Z6IuKV}-DhcZ3Uj|)M=K2Z~K zru)ZHDb3e4)4YmAkB+hm8&8KJ z2h;n>n`xqXiWzHIT!S9qA%W;#;n!iZA1pxdW~5H&O3DG{_tBw*b5a2eP;DHXT=K}k zVEeVrcVF#NO>nyQFga|45agg`=QpPnHX~b%Tir) zxZ#WNwraK>dTKZ&YD*kXfihiuSy-)S{P^5McOavV5lUmgy9ffnk{o576R|>ieVZ5Q? zY_Ah)SBpRsdX=y5`Ey%I_#K6^S<>nASPj!j*zR`fe5_=%;;MJ>;A=zVA!X&W4 zzUX%ixg;=`Nj#7iqspW$88!koNQ#A0f|GYJ}s_)TV3(n(Yhcs@3l(t?q&Uu4jw``KwZyw zLNCs1_Aa3bxq_sZ?K$VrZ3g*pndKVCtD*E5&>ktk=ywTiX2&r%N)9<` zPnB~2cuaCXx6=8pU)xP&D)p(+2kT^K4AyEt{`kxbi1nTY3a2V`q^gUj>35`=i>Lh! z$qDtaowXb;c}3yVk=*@QAqB-}pqcsellrd^&eO-!nViagx-;Wc4r-rCl9%$2XN7XPT-f!Hhw{DjCkiwf^3XH$O z#vCI0U-%w`f;2=6Q{11gQz#ma$}IwccLQt@ar`Kz2y{3_{x}#Yxte45hg(FBrm9tQ zwL?-?LJ0@nsO;K^)A_g=w^kYZ$C(n+g<&#XY-$_+8PJbksh|JNyDN0(wYEx%*=PqZ zZH!KP8ZDh%>s0OrYiEsF6!juWp-c)+-ZOo%T!!8AR!AbHTfN7iPe*&MSVe`Ul)6Oz zZ_1a@95nTfB|70d4M0k#pvlC`5#ce zSQyYR*?+t8;6-Xc+&9AItuE|`g$dYei@B&1-wvI4M-#EvHtuaJ(Z+F zrT9n!t#;1v!uapMfYQRN?HBcjy9z|uBF>L%NB)Ck0sJNZRzhBUt&{rG*7K)F>hG|W z-POmxf1==T?64$d*>5FJ>7@`vFflg(AQ+dJ`B8adL2oGj1M*lg0JOrZCRz9^xBElc za0(_bO6cnX>11}f*D8jyMWYFiO+Pt{zu6fJ$`(0xk%m0sP|D-+TYmoQO_@qIGLBkO z>5aokv063{_*u*LbQyh((Axe5)^VU(r(nK&jLJa=^GzF{2W=pb*o^n2*5lE%JM0rv zl8G5g@4^)9_>dH`*0F7qT?#J)C9|M{=W0Ugy(WG09lo}>m+9F{+yK6HDc^8<8@K(0 zB#(S;;;;?eE$tzde*1j6?exosGX2M{&)%GEE^utOKFVeNb+Rk*>?UsGxm!?r1>uBP zZmgsb0JzFLW%jJ)>Ut*RPaW%4GjAf|f=J~EaZLM{wDV{GMQvXRb50^6Y!9~%0r``X z$OjUabwiR;yd#w)m#a){#`NjnvoIVf0RRt^o@=4lquA|J(Z>#99h?=%&M9cc3Je)B zKCKWX>cg#9w?HLa&gnutd!xCjX2~MuL}Y<{`?&fCd!lS`?AMDb3CT@zk6K7^BFm!+ zTFjXX0R?(D`H6He&|)-S7xmsEnnGhKT{na{qq%DD)u4uW!cFms} z%b{y1k#gd>^s*pMZ`q-uW`5bJA(vNuj)-l$P_LKxrZ8q-XQWZ~PT9#AnjPn8ES_T8 zR`0UA#w`>4nF;b)ZpOCi8vutL@;T@`mhIG>Yz6ImBkR9F@18IBvt-{g@MNuCEnp!! zV<3*}(_6U5pSgYWcoAyxS{E5J_-;7WUuDZS$cswTB&GQ(I|4c*kC`-~mCUVj4>L!b z0={qdd#QS#^$tohU6&8a3gV3%AD|kB^URMMIF7uU!y8vC`mU1~>i=(ydC7GJeV=2N zZ(6Btn_s4`ZL<5Z-~Q{z_~ys|C&#RAayGP9sauYt*s+qX$6l6mq> zlA2!5pdPPXe$jk;BnkkO$6B@o3X~IfatyEhsME$5I#~7jxpuQ2L~3@s8TB~t=7UT0 z-}6!49)`Eu1=s|$-{^yix^KUa8m7$tlv!h9u5h|rR=ifOzJ2?9F{5sNceV6t{qAN7 zUli)rmapuy*U|d!{QKt5bu1Q#7J);O7zh^5kLO{pI|Bfq#umfWL{<#&=56Y12m(Zt zUj3sb*vqqs>5cGP1X+AyD2MPSR5L@GD(^8t(Z0eROERdY53RwURF(bj(QpsdkHiSc z%T3ZCT3J>SP1S}ciZBC=Y(YH<)VdrWlwv;!UG5J?fe<6Hi_g+6!0(zF)=#44Wd-At zVr_-D>1HzIMI2=bY(_??N9v>aY-BZ1i@}_VebGjVRzz;8*5m*IDZj6;2;GVcdRyzG z#15w)-V9(Z(Yr!z^7npfa$55ahhQe^E2~Ao6^qQyzn~T&A{Ic!VI<|@87X>ES{&=P zD<$v}RvTCNqUSb(Pq83Si&~a|v3d9LavxZWUe(JibroMEQ^l4)C3Dha?@2)v94`Br ziYhgj+5m+&s>lj8q8*|j9a4o3x(1xoXEC%*N1IQh1qfQ|rx@;wKV@8Z0)9C$?-gdM z`~OIJe|Py^3dfP=x}MePncxVw8=ZU#?O>Ef>;|~VP23W%@rA8M|7!EE52`AzIi)7t zw;%WWcDaS~IYbUnnps+DwH{%JNa-ChO$NTNAyfr8>B>D7;4ly(aankm`k!Vrl*oc0 zH6>U9jV&7)2@*Tc`(u5eT|lo}p_W#uJ823N8x7}@hv8(EgL&By$(O0 zk6FE^aeV|`DlCw1)rBxAI+fFCCLX>M^NitmL$Hrlx{4&dK1}4;k}TUmS@jFGFsPq; zS@TPgERyoEFiM=XBg{B;>xoJJ>fA97$BY`9l|$6ok%TJSXs$NBrN;Bb$w7FUtCRp> zx^|z^xh1>KpbDmu9ijBJl5I;zVi;V&G^GMyID2F5`z0lsA)2-BB+AygmLHxC*`_D` z*Se2ezz*=HsF-_ukMA-!14cY^A=B3iNO*#0St)rzGkWp^({p$A6OG21 zz^YesBC8lSDmk6|;#ae(#36|Orz?PwpJ&NV*^!0DYi=`noPBM{a61o|r)z+Oj6n|3 zvnb7X(x0ADHzv10fyHj0W7&~fK@(1PXiWkXF@oZ|POyC|BAPF@=HZnM^R7m=702Vu zQB|+=6rW*clGj14kj2o)i$TVX7HGF358P4^K%mbY)MktwaYI&Di&P^k8rbj5Ze^8T z97<&wKYI7GQ zg1HW=ud=?_S*`Ko_cHSpmRo;p*a~IZ%EGCBamiV{n?={(sQ4H7ISH`6T3WllZtG)m zbsrx4SQb{?G?eS=S(m+b^6%xu)_GePapXMu&Z%K7_m%(8?DaeTGfpF9bRen826(uw zhb+%6>~YSfR@~Nr;mg5rb-J7_s9P%d_cUei z01fsIgl23ZEOn1C*0}$V$v_4m1vp^xML^GR2l)6g_5S2Y0b7Vo%}Sfw_E_#8YQS2T z_r>Abcw60iPhi+Gu}XXW=0GF`uh~R;@jRm>D_2lJ$Q`!~RAHEPBU8T}#%! zK#Rhxd^4Q!zs7qfIlL`Lx=$Qik%b%r*}>ec`&%&rpW4Y|c3RlBolaKEi}b~R>_3zj z?Z4bkl%;mqNxF~5Y9}b-mF+llxKi7uXyq#IrWx2dFa{fW9qbwk2f+6-Y|=7<6%X4Y2{&Mb?f1wruIiH℞rd?!s=>s}?VX8VuBqq<%^-sAegLn=p% z1}l{rW}}9JpJ-(GTlK}~=QoV6>Rw<39}|dSHsE-Q=x2QhKjx!@YHhkf07NhS%~7-U z;rHW~@3nt^bl?2=^RpNK!OT$=;j8Bn!OZ3k_e$T-zZV&0c|JB1GOzBg7GKugU9Z0V zd3RI)Mz1??Hssa6U;EeyiNJvL_y7K!Iypa#0IZ#5r|yLeqI>|pVlR+h z01^7kVvx~HpAOu~@th zP^$GD^;wG(d!yIoldirO*&9SN(0ZvR3TdAHoN+=U+58KTkKGr$O-O6qsL*h zSG8fwJZ3kCtv~Poly}6{770xgee~D$k*dy$W@U??o?Tmj{Htxal$XYgt{X`-;vU(Z zSuafbUN@~}0WmCZsF*2Rzq+p*xR0-NcC@thPENjZ5L(^9_%~3nqZ5uFd4I_6KFhl4 z-dC;jT^Rp%-Ko#_#=q{}cjfipt*H&ylXU?%ZTS#a6)hRp-#Uk6Z&okHhEd^W#f(1Wi#Mih^vn8^RsEQ@9SThO&M<%c{j9jhi3Q47rFur&WZ zUtQ3G-g^g!3;4aISNq)L5WM-lz*~4BfRZ;0WYL@2&)W4dvC+qW;_*Xctu^M1257UZ zD%fIm2hYA%JQZy}GvSByjDKsP?Jwd*WI;G6imC(t!#ECnUxhe4H7I4G@iq?K{Ef%h zF9$q}H17Vi#IG@=T2TnH*ITdqfeK-QM zKfah9Q+|Gy^(RM4^1<@wDxphO-gn`0!{<|&&zG0r4+m&PzfP+XlB>%1d~QN^01(K{IvxNgFjEA|Wm}M%Q4g^o=jl9K~0?b$Z)Sz7vmO*dfLJJ)QmU z^_5>-pQR4^ri%ODWxbUB2N~}-6lD@2ndp@x>z<^p5F4U9UdTG}7vFCT>%tEd=0=`8 zdsi!4V45ox@b+C@^4?RMn{iI{%aW-Pw>N0}uj?7ut6YP3gKnD&i&BtH|CE1iCQtqj zl6~FRq`mkV;db#Z6MgVniV1B#Mx~bY`d_kp;*P)jCK^4h>Bq{(1b-TE*yb-g6L9@B z2q+dr?Hxp$5k#NCNYWm}43TEK4Pp}uW;Y1#1qE?t1oN~9O99ci4}%`jh6t5E=NAhR zHPGhv4iTUX7QziVbbs``JS31jRN5d^-hfutJ5)kDKxtp{esNTTHcWjpL{KqI%R9_q zkwLFr9F3D7xV;l*m|;q_NPk2Vz9Z^)n*tK&eqalEF&;tnCjI^^Aro4@3epjFVmxg$ z%A;J*T{SfRmOaGFAPdTfxGO0MYOSh3r7F7}khi+oA#k;I#k=@6!MQCOt6 zPD)19iB_awP=wq?6qRK({EXenfa*dc$3U12_(-b)(1@XDjqVg<9Bk(XESluD^Qx4;N~Iy26^2bcjb!MHVd{^y z*E9eL#&owtC}UzWv{|cO^4E-L&)xDq1<8&=m{M_Qqq#ZjaIHQ=9&CvL7Q|q8`!Fmz z!wT#}hRA3CMN5TqV2}SRc27HjT|STyYFjVEtp|)J6r&Bn0g+~s(?gTk47td~lR2Sc zs6+rfGZ{LX%!A@4r%jf!V`^HaG*y`6Id^iv`!PW24Z}dNNQNpC|KE?57I-UC<907?vwTg6&cXcjHsyC zO&M~nrx~Cb>Y5EgyUa|#igZI{X4tPx7&)d;98;o+v0lRXbYN_LVIDxyS+yM$9^M%j z4Qb8ykE~F#no*b&l$Sg(%R375wj;CSSJohkf@3SQR50-&8J$~Z;%mcHkPmEI-fkiZ zG$E7wK$C6G0aF>3<^KyaL}yKmHS|ZZW+G`2)0LSXk0nTfn}o(>BShoKK;30hgW6`ujr4;{ck zpz)D>)n}~C4p}yU(8J~uQ=?2v0xo7a%*hDmL|1xx<7#nG;O&daffptZmOZbe{z)s& z@jHhrt5BQ1JPk!~4U5CaYMRW8=aFodua08G4l#Sb%1XdNKTtFnM@mIlG1;s;9~=jT zs5Hn-bO~ZXhaf9WlwhBeK=(;>R~Dn+9`-JMP(%96VtN2SHS@g*qX@$R_93nxRu>ca z5`iT8m)}zwy@qI1;`Ubl87Tp1RBF~&xy;y|dQ=ZM!d5E*e5CT%zfUo$@arhjprpgJ)s6>qFNqd2SjPGZ=0M02$Z#rA4SA&K zg21klW)sa+R4?B3G(j-}PY~a2L<}!LJiWM=;EAK2pkgBoczLWi01Kr`Qj8`K7eiaQUs>f6)-y`UN%}qbRbhbA^PGx4g zN{r-`LUT-rCF@J_rCVCs-Dbj|cqvUZP&09YE)|@W41nFs9i%1s>Ck#;2D}qD5EM<2 zhsEMG7;ornTa@DjmZz+cVDsDKvN6&^1Rpib5G>#m4Z;bTo`01901ORaekzf@L0SyM znZN;&bhbzbmz3o9;06{R^TJ$?2u!4Vu34P4+$8R47VG*RAS*b(EZ1PJ>D#ZS+)l}T z!tUd(GVi|b=Y1R+f2gID$Wh>GXi)gLQXR8HHr076PJXMX1Ymrm@hU&PHoHnh9Q;F| zD=Qm@fC1igfkngt=xPf>d#BEKM*|QTxTTjQKcbK#XOJQp)cXwYO5im>pv{j5=X!bo zXS(pMll_zEqTDHFhKXo@#FPL%Ni zM0$=bdT@hE;9W0zDxonv??^hMfCV6URy}!O3GF;}bk=Lk5`hVex}pZ*Rerl-SB#`` z#k0!r$YXmvdrfOv@_|1L77!#~H%0UW)fWTuvO{;94NQ&2!=@|%lI2rHYb<(&;FtxB zfPotr!2xJOkVVA+0c%@Fbg@iziSKw@0Aah)@xbc|5~j*_aop;njgH~q5@Nh`R}EIN zBZ7mR&mPLpZ~d~2VLT8P=&A#$i-lgDjGLEZZCG_#m*M1gTHa+ZT|j3OC$^y9t#4e> z@(VL-s+NYCi-Z&fEcI>-T<6@-0t4%t5DIkl?V-m<2v_Q(j7ZHT8ciI({smPXiF+?UAE%V1^YP!X^X`ubKlx#6OemPl8uaX;GX7NJnzBHAF!&fGw~fdH(}M( zK=)9Qml`L2YVL#$stle!nt4&(mT8>wXuPZD8NmBFz*TiND|^bS+r?68rJA_xo;UC^ z9J?5;a-EpFb0l?{FmMM;SmjjJy?GRDQrp!veUK$5EV;OI9;#W;t&Mf}+bJh+Jta^O ztP^n`+3x=0DgNm%Lu%nwYsT>yZr^C0GrIMNZldSjmQf*4Ig zC`wzJ$N5%?Kx|Mc&wu~3LfQbmnI^zBhB>YfyewUK;ZzY(y~eY@#?@Eu21&#_>$+Tq z72qL}`E09k-O#gs_J&EiO~#d5m>itDIKPMa9nSqc9@B_O{6nKK_kqNBJa< zgeHyPgVh$PlS4ukf+CGpV`YMDi@h?FTpJe}+!sZrNAtd;Q2|j$O=gnq69m&J|&5IqS!?e)cqh(D%^3SJVxQ%t?N}R7>9#~{2hVwx z=0?6F6mY>XWbC9@JdHk7yu+iqUo`VKJXr2yK?=o?05B>;d9$j|9 zOQ&T3fU3HF2%L>`&u028yqG94ayvET!SQm21*2+%UDqC9Kb|l^=F86}^pX(s-~CQP zw$A9t*y3SwhEpXxA4<_$G=)*y;^SS ziX29#{<8+C1k+u=G5}u>CtuD-U6b@}`kibFOWj1_SZ%71v02?PyKChVs9Q6)HQm9? zy((3r(~t+lNa!cf<|O}mDc5@U8sV))XSUa-RqBJ=#Pt$m9yBbiuRT|_H17AU^4Q!O zLD?hz-)@y^k)$Jp?p-{&W&KgVJf!}R3tl`GDtG%u=N5dDw!EM!Jlc%r{1VMh!tIQp z{S8T~5#QwEDtYUs2OEC<*XIN5wnbt0IQcj)?Kr1F3TysekoONkQo+{P+@{iRGIy;@ z@qa;5wsgNrw&@>MFz}=^Rc1Ca{k^V(i~fYQtEZOX+yE@A(rUa~k6V<07QL8;W1`0x zfg&Q-2bE*&&_MQbUlCrtDw7};a2G5vGzfe`C+pXm3ktKuoHbCgm^CN>p!DW0j7x`B zxrkyDH8zf@(23j<_IyU8OZ(Xp)%WtbMpusWHKsk+zo@RAms_wd>(n_mz4~i}o)@jO zHi2%dlvHEiBsaX>_a<|tBA88odlZjUm=EXq+rK(VX8|!0uMfXB#|HenJHfRdm%_xu zacQCb71h@9UE#vb7|SgPJ^E9;SZaej!>d6-1*W$<4GIsiKo?d}0DGh2!vt9&6bHuU zH9##ZR7eS4R@|t>TQ?}A%-6Ql$V$W|&DQt0ODE~mdjr-dsv-+EO{!0sVl7zTek)EH z1JtGzs!87LG^tBrLBbkHe2!)f8B)1aED}Xdbxi#?z*1NX#Z}U*g;st%q$4FH6sL@% zR>Qcft*OA#qNA9A!=o-fgo8A(a=i45KL>Ne>6viCIN)m9|^OGzi~V zy-H%RC5n^66W>|hr}`ye%f;FLIfcun$OdB`E^Y@xr*+s3*|&I?6ND3 z=2va=4w8T7P=c}RbdY}P`OLARj2eo4Sy}gg)rxPVDEsYMK$-matZ1LS7wQY5{%6BJ(fd=TPSqF`cTjv@+VeQgWe{U4Cxu~1q5WDSOBTo z5qP50L8PG;AddEYv86usAK4aoBJD^Lum1%pP{lQH0h_mM|A7?k($wkUlzgFXU)ak&BwisWV`nTb@8wIWxymf{ediMsHkp?EJQ$D=nBtv`70QJ}dS z*j#5~%t6+)Lt^qm=`*q9No+&2-twZYGjXp%tr>U46rRn`#J?Tf_KCi!=HHu1K!bMD z0-gkls1+w3ZhmI<7al=~&L$;<+OTtoqm=Y!lQDxfoFWI9e=f5r@869uW6#qB;K@h; z1&2ID^tndsY+BupG8Zugmw}lA5Zy9p%a0vY%KKU*)!m#V!mEo)$%(ydrA38pp1i!J;gvH&6 zMUZ*$gYs@yI3=z+raE@NNc?-AATF1tG=qqHkFT2tryH6GGib6|0!2y4==X<7b^r@# zhV1446EO}|3Yl_AkXWcza_e*fhn$0wZl@N`OiTgg1+%Z!J>C1Ch%w}#<{2#HH^R6v#%WBq-HZSQt8LVK(@1IWRq;6fpEr^CES!H(iSk0oA10#WRC` z>iPNzIp@wo-_9Y<`G$lrXRHl}gdtLLj_om3qF5$SbIfbLDUg?Sz)XxOB;el8I7zbA z>@u34oy+8Iuy?w*Go*lsrVYnOF86kZry+jG@0c$&JX;G%t0`^6a<060;w19`(T=Gw z*EeS4+-OVrh0XehD|S97d~h@;>V8o6oDYu zA8v+zCbL{Idxk4?RsVVw0M40NCuy&27Ctr3^<%%%A`mBJCVolr+%lHDg*7|hEnKAA zlH~2;fFI#Wgxqs0v2e#2V$KF_mYl8B!i=HE;h3-0RJ_6(jXzULp4;-=qZ|64UZ=No z+w#Bl>4#ju&KUYX=*GlL`8QcuJVBARrO9NXQyEZJ&QlCgIR!mhbY5!1n$(7Ec_tzJ z{|9Vj?pUe5UHKi*ZxPj6ZgJ` zE3x0`!%g-o>cwGsug0{v7uJtgmZrQ$YmY|k?e$k&-?)0omw1Gk^aO zK$XWX+ErDy2qfxrI(O*^_13sG^Zt}D5|8th!5{|tCGS&>d$I)Yn$R2#TffM{ zRGc@Xw%@Yu_-%GRSBmIJSxu*!gPdG=Ht|h=ai)^!fUbOZN7^wPs|!HlC`X)Q26pjYEfC_u>w-M|G?`_)56-1x5E-; zJE~Vajl{=Ezx}@56C8_>d)|`f^&~#x?I0$;f`SJ9&w!EO{$xn`&SC;ynduYQFT48A zw|!(Ycs}3OL8C|NX9jHCF!&`%$)jha}ZLTbDu`JbRD4 z@Q6a+9rU38-K=AON`J`S7r$(ez&`1Xy=ceP=>^;L;t3aRO2pMA_7u>5gnIN6>aY<` z#9t5ek}z`qXqO}lLXfiekqh@x$Ray{A1Q76s6CK)&tI(2^%Z6y(+&D)2KyMl_A$MK z5%2Up$f19*lSzMXc~sM}701)E7XcUvScTF3T%rBkv2?x?Y#gaElCS$yd!>srn!-n< zmm`rRHUyP~$hc{#?DDe5sr|^RJ}fXKM-bUxXx}O3aj#BAJrO28FeZ3rB18IUAiKW* zJpkFJ@fdE@kP^hb;U>L#8(ktpva|6_xE})gErSuuVTeV79IL%I;~CyTH_HcO0O-WH zQ(5VLh>-9QaY2(#D5Od}6LBn))i2W<*C`=D={ z>a!&WQO8gMcOmHXMMcbz{V_$PUsPpz)OFQUKn7B>I4V|R0(dw zNwgzT8dRajQ6NU5NXFh|gAQyG1a?sMBC=Q^o?d}vYoM=pIHFxK=JaU{tlzv^K0C2c zRjea1HqzWsJ{%?=oq>wB@5oLp2@566I39V=PN54QE@1hD=P=R*tBtQ$K*tz##fT~8 zhmK_nbZ@jn3tyMRhy+Kfj8STY z(%PF=joQ0vuhLSZDB7Z}%I)+0+`s$2fB)s2yid-_KRH)kujk`=m4q->zlCbmVXOcu zY!#%*7h{OIJEDZVv7wqWLpKnTTnJ0Q$2RgP2^J+oRr3~8Pr%QTu_kB$)hIVL45p*R zT`Nb3tF#`nF2s_AseR*)ty3AZ%GaM?ziX|ilM`rAfdOR=jiV9ct!=}^boIL_=i?;a zGD&~}&iE}!_t-+!nc?3OWrwi;AC9sZM zFk<&e6Y&|?e@h&_qR?r&-p>Tc!v-Tr=Oo>iQP0w}J;uZnnRg`>x z!yWP%NPL`Q*jm?y2^6GAhMq_27oJbx7bf|h@NY~}(^b>a_$LyLpUd|^j0zwOatE|B zs+q5+OtWxi&j8hwZnzTvM=w|7<{RB^i_Pl)%qsQxX{5|&jSV{644vweg(A!o!_0_L zhGj7rl~pWK6=j}G`Oy$O`!~Gr_J?FW13E*0El8S)^@o{p*letHTa7bhs%X~X;`yxI z4Cct}`GSSjutmxFEOuwss({D0sphbNu}t&U%=fm;6S8>a3^9MAI&VDl zDjM<_+m{3vT~*+_$q{9_N*nxZb~(!3Tm3AtsNR9` zsm~1$wL4s&E0(nSz-oZhdqof%}_ViS1jOG3I5j+ISxf z`;#uAQq>VRHI_n|#!aY(9rX=9gh28Zc@-EJU-_pz?q9kKhDmf@v(zy4Zd$5yfGan{ zKy8V@5v&%ngvWf0XHCCKRfq|_0B-lU9n^HIleRV2rMZpPqA}cQF$N2O-@z~dfp>Xe z!kD0ZtXdn!cR+Pwd=UGGOV;(xEAy#%lP)7YAIceg)wLnm z&hS|s;A4$#9$MUV?&A)3vJ0a(FsB)@A8iTj!#7)ygknrM*7a!_c13$x=AAH+pG$7$bv zks;eTUeXSV09zTyP|TrhSmIsSLpS0v>kQ@VmKETLKr-R5^?cRO^o}kF06IR!J4#Q40%!?Q%$|w6JQSXkL%a?DH>{tjIhFYGpooQ+m$*)L8XeO~Uu_ z#f$=~t(~mrOqM& zcujVI`YQmZp*-;Lxwmz(wPk+kVMmTLH?(I+kBX_&t68Ns&)!m6*#ZJQ?uox}oLg*J z|LWYb;3!v&+En)%{MGAO7|1}1vb9K!2_(T0Qz|cdu zj>5{=>1dY8jD~)Q0eZCH2M3jwO1e(pTH+i|~AH_eup~KP|;cK>+9!Y3aaTytjO*09c$ASv?0bb*4Ze zz%}pgEkahL7D!*y=xls(Y=1h?-F&aX5c_H^u1j1}jCKDFr>gpVNrR4a%zMK$K*sn3nuD<1Wr{;aRqu<5rd)vquDb{@zZq^Bb{fC@Ug95R9 z>t!X5ev`M9gpCinzS0l@Z~2nE{pTQ{U;A6M^=7#qsp$|y3++C${WvN~RwE+H9F?&7b`{$_Vn{O@4(1HF#;*xy8ftbQa65XiNmV2(| zm=ex?dKi3ToMF|`KX_)ZKl^?wN7T81Pr}swwSir7qr}H|C!Up?=RSe>Zmz)uVuOBD zXyyA4vPQ*p*R`;cUA5pmxR+pwK-(Ts*iC_?>k^y-DHJU+@U zw4gomQg)4I<3izuN!6ZasLqo=TGC zhtt9d2-tDwrc^ugLY6gh|2|Q!MIE!4ZoR6Z(4G9DDj2MFx{ZJo8MLiCwNXSDt%xaL zZng&5<^0Gf54qclMzpw|2T; znZn#E#rGY$NBe;ZmzHV5Pt^J;h;~P8Xh$VYC>TsC_vrYs*P})CDEXQY6$)LT-IE99 zit?POt@}k3Xa#Sp(y-_RU2{5CF>(5uS5P=xW!aLGa6lxP7UiX7}%!B@KqY*`)lPIq?-{Dq6Y#NbFL z6dzk1J%9!xwVWUmZB&GW(ES-Iu?jb_DXIzGwlR#O)jE(M6lNDl@w??u+%5ShXHL~! z-(qBaQ47c8tbr36q^1=V=8yhdQrExh3L0M^H4T(+d<1SnVboA-p!b4c)Q$s~!d?6m zq+McjOkvOW;L{ja>@XKl{lb0pKEwbH`OWNT;sik|9NUFC@C$(}S{3;q8?WZ?O|gXA zvF;H!Z&5CJ8XDj(7PA5Vz&-u)m!Yh&<2zw?JINBYV7sX&SN$+f-r$F%q^8}@FSmoV zJl*AQVLp1Tuc79dU@7;iJ;8fj{NSOa7mv`w#xc-^B*gFVxGAQygWo`r%OgohT$o!M zIfjH(esPJ1f0)_wQWpu|35;U2s}@kH^8Ub0Z+dELkcQJ7DHo&7Xlk&!3)7P(5EOnw zrY(gNnRNb%$Zqjgq6mysfO_3(r>R{@re9W{OQ!89yB{(Qfo8lDt!wRXUJ|u40MH}wVzl&3X0`JP!eQ-bkbHnAj&yY!n8W0 z$xR7>M3LC-XVFO(g}#QE+RpNpzHwyZCu=$#Q!PqX6(2#)HOiK4``VR0UL(g{5npkp zAfswQZux-i@#u?p{EohG5Z^%>6dv_Y$d1sKFV^y3%$5V>6Z!i1n%08Iky#tI%!I z#^GfLuT0+NNSqUET@U=_=i(Ib?|`C@`Cw%_1w;ZMz#;a!(5UE^-E)^`VA-qFJYu_v z7CHTzV&C&S`yVU>0_YZRMrQ`Jph{LUo9jWE`1HVqcx6i(q$T{d*lZ%mzr%Hs{Agpr zA$E3k*j2ekh)z9qhr0-?NnB+HD__&Yeil1q#V@m7*V5Q6QSub(f5I-eOa9us1gwtZ zAvx!cWo+qv_1KoXek~wCU&GZXeSx$L>kv?pui7fvS0Z-Uox7UTD7_MgUu77si_MZ;~uYJ5rbsn)X<)ygNvgZsyJ%JDKM| zo?DX}BQVgxJ8jt&Ckj@)oO<$M3hBG`QLQ>;Y*vrm);+{wA}e^5QN@2lNV|%4*zuFK z#9Esr#)%KvESAB?9`d2;~`Jlu2 zjx)tu{EHkA8_Fps__Hf*9{3&~R2}}1eyiXvbrd7rd5YLiJ{)4doJvJFAs&un4E?zo zME5t{thYBmrMZM7aC)Sy=|R4XRuNNZSlSQrI@pi>008`!niPOXhrSI|9OK4V<;A2M zWFD!Um?S=>N`KF!hxRlur=;(gp{ds`(iux*XA8cnz(Ht6aWqfGk^9mqiz%gHh4zYT zS^3#tdKa55n@nQuWe84X9(w#ACK-3)ZVS(pa54bEgCy<2W$*7wxxPhD1rVB0zh0)H zd+7wBtZT3!KfNa^w67}UJ4k0^nQps7+4eSIsmNlbexJlgEdzZmb%ToBGn&X7Y;;6R zAwZi9buB@xI|nrEPok<2V>s&9_BL52LBk@37GunVmb>SzYKEC1Fv=?F(~amqAMr4{ zy;;ajy@MUrgx<91>l^I3rz;O)q=7c18Q2f%w-tX4kMl|6lzROfbv^g#Z45pCs;7LJ zbxN)+_f_`6#H_k+L)@s+2-eI^0&zBd=qiXt5#g_s^8RvZ-+c8UIk>5m^PiCJog33> z=}iOy)vKP~FD!x0u7H{Mcn^KNy3}RL!{^58a6KV<;jm*N+X}6r5VLgD!WXXQ-Cj~n z?S|ZfB&OJXW3eW=LZPbLo?f{Qk{r_HbWF{!HUCk(R_aR(h0DH(NW1+saf!whM3f{Y zpep(q0k-yO4FZQClG22rL@(X+nd=hqO0C>Gt%_q;Oq2VYsoV#6w0E(bit^o~l$m-c zc89a;XwTHTU-7o;<^%4Ji)hZ?=X|UcoK9kUYNV|Tb-zn0su6eE4|{d(!e*I^lnq7n zc(dvgH$#8r*etXpFrVu7_|OEs4h+&vM7;}@t?PP*5x&RPF{{hhL?)EF?-$?rjL)Ur zsanVCsUU8BxKC`A=arSEa=hU{i|925$3@MM#gFAJOd1u1L4)}!g*$q!X0kXdvu{gH zTbs3x#}axBRqNoJZ9OzkCF6KSH=t&1z0B|G1V*e`X84LzVeH279vRn^rC!AlRZvJe zQXY+~Hk$1qX3z_<#G7a_znI(or8Lh2E3meHFL{)yjI=&Cjy~-(!-@L!b2!f2Km@o9 zaHo+bx0mwbjCYWEO566BH|R@q5!tsPhBR-^wJ=lWHJoJ2Sj`Z^Sok7m&3c?hlX`>v z+E=q+^D>(;PxLROx+e9v!zBO4<)h^+S{_k}5Yow+kEtM5N5k+>@#wcP7V-5$m zK}|VF-Kz@Qz*snsxQC!54@5)~TL=^n&#KTqT|Krx z0BC)UHj)p`Lb-Jc$s@i%Z+Tf3HS@6Q$bE@AbZmj4tY2#?B9Jg&&tq;zD1B6Z*~E|AM=}gtq)MK;MsC z!8RS1gYr??G@(>FHh+hH)Dr*w=ptSc)d2D$fCfS02nu0j&K~a9s-L#SX{Ux*s|w+K zFdG@uk91Vy4wWs`;xU(e$M7>#Nl^Okls z0y~N93*XdstHk9+VV>c1i!SoMjx9$UQzL6K1 zS}6oQ8iV2bZfs=I@ISE;&_jwpi|^Y>=*jPK{}SWPz?AsH(aOeAnNi3TIEGqEmyr|P z0Ukk=B}_`vk^04$PH|l_6)C^tRT)5)s!Z^z&Jjyh*LZI^YV)R}ublhG@sWczMtzbV zfjX90NhdS7=EFd$ai-+{soOpmjDtDUPa^w_{Pa!14tcB);i)v8NN4jj4N~1G)Q0w! zScy(WlGy^4UmD?#`I@f{_%@DoH6*!#L+g=8c9kZH|01d2+kSznxVw@z35=+I=CC;J zu+$G?{U#HeYeIlERe}hHK)S;Wk>h7;idBVVs}nX?lilY@c-ca(8Jc4LbdD*7#+kKf z(nl6))@kb{>OXyJX*aQH?Kk&j*WPFdYa0-4AUYVys9CI@O9D+tEOBzx}EwiRF ziyNRjA*0%uxJYl7LKDn)d-P@;{*HPXsBR*#s0*+0`~rOt0In*D1ZbQJX!&@W*$#5}rHfYhJS4URE60@*$Rqf-2Gw^aN3%you0 zOP{O>KK=Ug>7gL%^d;)j>+k*-r;FPiK{@J?)=HOT%1zq$GW$Gs6OyyY6SaEzW!;V| z%#2#dLRMJ)w-%zNK#R@qJya`HFM8G5qJyW*7y}nEEsT`uJD9atLJrb4+1oQ0>7Th{ z%9p*thfG?udX?Mnk2>`B?c9=HOG%Lt6 z$x6FwIkw)@L#it=DSF(rq4}|gLT2iE`0O3koNFvWY`M2~3-$-{lr*nj+zGaaiW=-D z=Rwi99T7dKCFB*f80(wqtU3effL5jo>;|vKW`Uv@FG!0~=Nmv>tb`wA>!-V`iK&nM zYK-J)9W?}bK{f%}Gi`cr?Y~!9?1R00?}6o$m?O)crbzZ9rJnxW>Gqg!pTe)}3guQ-?n4tEYKWa_n-PhOe4n?|7Uw2)KbkoPvfXnEq2+T>O zzB_YqfushHL2NQK8F8R*{xSa+SKfz?na$+58t8t(=qIBmjuBPEQ>-EBF$2?^A@vE& z>VhH6<2#vE0}XyXV}(KF9?4*1wF1B7X8drINn{O*-qzi3n6pd^0Sq7R&mlm9b^Bxb zx?^*!EcRD>`fiWf6871vhtqqe@4c|$1Cl)sCW7Z#7pkl?sYlcOlOKv6MO8q)aBZy@zfP~&#Jmq=ySa$WREb6&<(4?*3 z*n~dgk=RVf;V4zjc#HqSg&}CgPm6VXI{rYH{os8~HlsJ`f}XeyXlsSM>ZhT^FU#6r zjuNMywWp7@t;t8=Em}!4A5?;@7BdHW!Uj&Kz0b3MnGI2=)JM}?wZG<`(LvlauhqD$l$$?EKlEkh+#icOU*+zt4 zmz*-=@1;b^lElrYx?r;Ovk_gAT2sYDNw%yy=u0Gw0@fk!_5kCObWuzGSAzOM9R$U4 zkW#g18(I9foZ{6GBIg$|Ze}O&#uHZ?D&EfoQxg8bn6NPf%#jSmOP*Q>P@0kI)x#<3 zD9Mh_`d-jvxaH~AKVETymgW;JrM@EVC?f7tQ8Iw9OUhbGW*2{>M14O!t>+?2At)LU zw9?#mwNa;Fpxc~BBA8SrnCID?9jbypZ^+TsSlZ4ySgb_r(c;w-argQFe1q60Y4Pm_@$b?SsWT!BG2^BEQsrMUQctW+;vo{vdgPX zF=Ium#+mLqU%s14+e8Oxe82M9cgXt(G=J%#FTgn#jcYzq2?{!j4~DIMiouOw@TZ#( z%=N@dYzPknl}BD4gPMi|8$VM2RQ;&0Nx7!F+$eot#%3&ErSzuqdnYZU#yeazuZF&< z;4#eo542-LBqn7H=9>&Y$rx^ve~v))~>bLj^c#$ zX|o$&^zUg|sTbr9%gvOZV0O=sbU~6fV~4EU2&bW6C2>Erj-;|xE$pc1QQ5imL9MIhL1l&boHE%1}FyTl-ZYi>TK= zb_nH_^Xh62_%3(`4FAf5F`1AR%cwq zCWqIxzguF84u!DHdUl8w)SGitOR~^z%-Or77Te4-uEy4fUGkBGt&x-RPv#$WNdAz? z#%Z#TJ|#;K+iV#~4z-tvfBao5nsOlet|SR(aLjURteQfcPFtLj0{U-1f|Dx6`4Koc z>t$%0YQ>2opx|!Z5>6TKBv0;37LUmY^>%FWiv0|pY>v5^<9VOl3nd?h-fm0zu8?}x zmRkDo{ZT9C{E>eZ@jR7FFEf_Xjz2zo_4o!xtsRe2 z%t>g^Nmb0nx4WqSGPn9IlPGJybA~%RdwRV+cT%xnzP;d+V&V4h2+CuiRqpYpFwfT0 z^O9CeuW*0%j%*3$6x27xQu&TjH6?;h2Z8v~?C15(SI?d@D56+?6CQQEL@8AybX27N z85~YnQt){beYRix^G_Nb^(4YvEem;bq_ViK*h>hX(NXhVsrIa+mZ((smofsOF7yyQ z`YJb~tG}caolp6@ocO3t=S`E5atL%zv2yG_>0i{T~ z-Xo{3KW8d?jrWPa6aJ23-IP1$-*kRbez&d6eK$&blg5T>(CgliHJ^qsE{0hM%LoTM z$axDF#pS<;cTplInr?fEW9F|Zb(d-3eE;N+w!cHE3?;lBN>v&D8ye7Ux{tuFY4yd5 zvIO>Jj6O3FsYAs)oumeYx|e?@D1(0IgEG_Xkua&a4xcy?nWXGjf3(QDtKTd8ml+MK zsm|zh&KRlAQk?Kw;%y|;{Y#u$f=J6E8=ub#e8UVzuEe}QQvS1$CLxrn&N>ua&b7=N zmnN?~cqd}^lj_=b=h}DGPb`QCznI1uQ&->mtn40>BL-9r69u3T+y?;@U7ruD^kOGm zVnC{o7&A+7L^`)xC7w#*l5KDNS?5mbtzG=P-50lX3^2+w6Cq->`s?PxQJ==ONTk-x z>EE)i&h@Tf^nC){1$Qm~1d%Q`YL&lNEQwIr6YM&XQ2Qa@_2U8g`kU(<51N9~WArjL zkAEiA2=t&53v|HXD0+ym$V0kqRUK99VN#cobn|4JuaUp0o2S1heH@HvQ0w~lN$qmG z>+-vr4T3}+S3j#a48&5<>2Z6EUd1KJ{Ig8qge+xovXZLO{|pv(1V+9fkNcxN0Md5x zhH!LNwJ=4r->dl>P<=7V18P1q1{b5aYR+=1JrGXF=mq`*B~%2D^MmnN*s}})1}!TF z5J|*P&j>k7%MC>Ylr6%8k;=Las=;Wv&j(B$DMrO42{t`4C%N;zxF`n2HhC_PB+2MDWgQ6zs%?SS^`V^*!A&5BCG?<52D`U)@KOOVMtL@^ z31$@(T(#Em4`hpWQNWJZ^;qE2k6wzs#4esyI^Y4%>HYmJi>cjwkXD2VUJ?EVx{_#;ghF`fT(4 zJ>7r0IN0v591n7*+6PzG+}+V<6@!xuUVV>tH0)5 zQXNhgeY)tPrI6eos%f#;o(rw|{l^VZN}>&+xf&i2IqidzTn^gQ^j%6 zF8qD%@AsZFhCdaeYqw!PKBy5R-@pG)7sp1b)8+5oT2%Py$KH!o#rhU}{W^sl7^oWF z)@w~kr20xi{dfA9Kf6BpVHI$EW{r!QKubW@EyuR$h~S6$?RU>nIk z+hY8oI+Zv`=e5FfJVJwN3UtiO;EeBz1$&^<7mbDZkPUnWCwwLi?+u0cp9I!Nj`G`_ zWH?I;-1X=4IS~kw7JN9(tkKhE7%2Giv*0&fq38yo3=e^1Y2mD3;oJt{f)nB5hB!yb z8;l_CS zMrnfgUXl5vr{{IR9=m1qb*Dgp8^n({yh%b(TT%*d-CFg4& zX?|^0NdFFB|D4eU#GDb=k1sx<7KUgtL=m~vEL8APCN#|cWaN@lg_lGT|KfQ|fZ2OJ z8gkXnqYV2W(624Tt?y5T=#OLd{NJEIg6gTzXAEDq_h+n7xk6OA)UietXFhHY1$two*@4;~c;-k)vx6vq;9zJU>}IsbfPxf=Whrz8G*Gt(;I_g2ndfB90B z=i%?|BEh$c(Z%?HKf6!BBno>~f>w9G)eVaOwTFxa{5@#EU)S%l(Vv zA;H>81Ogs-_0tlgeQ2Nf(Vx-{3gA()WLkQ=_N|2#Vk8cJLD%=B+hA zKRl$ZbdrWkBJc?!hu?PAq>p^#6i6LPG$^m)1hp3n>L9P8t0$QiZ!rR9!y}e0epxk1 zWujilQQI5v?3S@I@khg>j!Gt(@9;-OjErQ4k0bHNrd&5sW`*3>x_|036UQU|Fa*JH zdTK=ER7|{yIvHz)#(zwH3V#s|nw@?C{^8k1LkTPY2MW63tJ#ezlXP+0B=pON^gJ24 zONC0kVrg_%$Io2YlS+NB8$x$z)=OjzV7d=RJ{JD@FVgQM9oy~;sn9ieW##v3rd;WN zkUkz#$;kxr)eO9WKQ?6Iz;_uUp!jNPHVCQ)%Mfo+<;pVpA9r7mOBNC4`LDZAP`TkL zKe{;aST5W>7_BbwcX8U)GQ(ZA`kr9U(rk%;T8#z65h!MAfoNG>+OnGF`QWac^wDW+ zH+8^k+tKtYKl3)$y?Oq$(Uq@CXYEtTHT4*g)xGrK*17ST(ALq_!^dZD)~QjoVaX3x z{%C5Md@qvQ^{r}H3iU078L#;az13C_|4OYf zRZ=Nudz(80!txm^&(#hUZG8vc`7yOp0QkAY8m(hF6nvk(0x@chrhMm`y2|V!fDeAT z_XeFcC_K&mgSM6L$4Peb;1%o)AM-YR@Id#~SGhO8iQ|DajQRYP8w$&dwrA9}Ip3*@ zRDS%PjG}2Oa$ESf8Q;-X&YOIgVn^UI*wwT9Go3SmYATl>-!+Wr5GnoJ zRH1Y54(T*DLkP9VG7~FME6$t!@%W(>EwwL-`Hxxt&_RlCvZM`a;(g|0=~~&gZ|uF+ zQi`0XuA{~(i1TpiW{Z|4BeH3JqrXcBB65~!4pTpo`;)i$>syhh`wAZ)uk7)zw%dLF zb}`nXw4~koc00L_0mbV+dzcg|`EYZ&@-%;KGWk)DZJIklKSsX0=AocnOF$-{#yXIC zZv0a$$Y{EHgK4sDKz<@J)((df6Tnz%sH>x(VP)b=qi!ykFe z7X#<-!Z+<}eiZ(_*jA_g189o;ED=fB(zd!joiXvVocMM4`vhur`dP-0n5^Yp>odxJ zog5=fNWg{-EjJu-tKhX0;r4Y!pY4Fvry4&%+(7vA&cjEO5wigy`?M5KNe2ugpIrNNBkIhh zM1SqCB`JbrGWFKamt6n8+P!dlZ=`d!W@@?;tH=jSBq>dV!fHWG=Sw&AHP^kfw%@*(!CSQ9Ij=ZTxuQ|bkHMVTH2 zT{O8yG)0@CgBMS&Jcsm*-1RUP3vHzf3Wn>yo~}NM>6V8Juox7}8S{q0p)HC758)vn zAG3tT9%KE6rwZxjQtysq z$T0?-(enH#xP<~B{nd{=8swya2G=X98^=57!mWhj)Xw9(UE;{N6Ldese%+1_gqVj^ zaJuPmXJiGe5QL?Gc9 zx;m!6nM@`?JZU{P%iG@N113&E4V9DdnIW8)1Qpf=Fj`wOPU3y->cAOWT>0ux?9>C1 z-QemY+AF}@#zb<(icgOy51f<0X&}NBO_wq4f6)3AO^alri819BbhRr=nq~F3HNJX= z_6odTF|=o(t0!T0lCb_Zu;~=6zXEhafxP;Na@#uTWING3GdzOZmhg}se8!0)Lo#!P z;h-8YV3)aWh`w})2D(M>>;*5`hOcJ>y6WWp3bq6-7!(e=y2zimChuQJ*}lq~v655; zCc8(ax70yN>U2%iQ>(k7?!>9|UsL4mb<^ZMY zNzBrbgmo^x?l`+C8lp3$$GXPPh5J{|_fQCq?1l+rBixN?t8t)bd|-@l25B1fmwGbc zDC>teB}Nv2A;5zIV8Z&9?gTUf@lw#dIN{V5$vmJA+>COj6wC&hP9>V%$#)x;0+}j% zuw6g3w=eg$xJ|q|^(Z)JZ>Gw;RI(rR$Ug z$@HJznPfyb(2-m4OcN$G%+-AJ%0rWD7%QvUPA-mH&>QYdq%687Q0nVu=HZbQQ zm>voW&`r_&V*@^U(=x!Zc47Jc1RNV#cBF~&s7L-?sb?S|{3VbMbjLx}VVB)deFboh zGi5l`#RLv^see&}6C@nb0@nccaNMeJ2`L@)XE&^kKVJ`cL{|JC(a+GGasa)WP+`qQ z=_Yk~B>W_E3o!es)QHA>nm^YLyHPCv5DZo~*ywHfyV~i4+seqL!g;wl-w&p_OJ*4C z*q!A9e8Pk}3p9X>&kv-$8b+JngHN5wNo1jh!te#*LNXlq=M>Cjsz`ULC`$p*Cn5vK zhhAJ_Qy@{Gs%{!5VG`v^;HEHm!JD=Qmm=>GUCvu>apO9(tpsg5oRElO)})|wt)zF- zgPV8IInSFSLZ6e96hoy#i-(~f2VZv~N=Vd8T8e;MO2A*6zx< z(^uZ$T0f2S8U}Zb+%he~^LNycr{!uX!)1La_jC&m)N9G$P}DE=YDv~YP+C&CB;nET zYLhw`5(eD0h0z19qe^+;<7X4rN-}nhRYA#DEjP(&jZU{|SfyN8M_71=c&bO>!Pn`o zVas`tMfA}0iTDD@n|dEea(@^%D6x==r)3|WLDs`%2LV_qGjAYOw3XT)ja2;9Z6G+m z6tttw$*p8YO=EPBsb`f&b`=dH6}%36ElM&9#D1po?8Y;I>wPYBd`Q*vdXm87R!*tV z)kWBDPR8YN)>JE)f3M9dB88w1t6HZ-=Ngwf=MKHebeBxbuxY+JTLuxFbKRXYUSsQS z-Xjoj1c+oyp6$D#t8M60utR4`XYGR1Bid>J+k{fq{vPUg7*w06(K@hEOPBWCy{DG+ zLo0(uIkgn#VPidf?tb?G6jq0JyIMt#YYSnB2aI3L2z7eo8$*hqldzQj3NR1_vhB_> zfdf3?`hq=JejT3gdRBg5!#=LzpxY0!2m9+ysTa=eJydb~Em-zE>EGS^Fcp@yH*DR$ zD5k#IK0%VE>whqb-cLKYJ+JBnI?_O1Ey)LxVzoo6HDgQ$lv|DBkk*M@vDz< z@4X`HG~;_0(V>mtq+HO0rJ~aHdcL=)VOUx`HyZ7yn`qO6_H-5S@y3=@#XCWU<})Ms z-=m+p48e+#1K%j^0uyFhU(?b-Pm3}mM8SE1Db&04Z}(*zeMeK$nBe+M;c*e{`Q!A* zV?8G0$PaHAdd7`R(&*nZnkK^Ow3TY*2Wjmm{B+4LaphfqP}~(Ui|p}NC?56m>KKqU zW%X+ep)K-sAI{-DX+ttG{d_87qZ6lc4YFSDCYc{vL5IXWvy%j~thT_aO`k8LKR=iG zV9(IqiUu$-v{!#9?xF24U_^Rk-}Q<`ydF!+K9~AFVj%V$)~A z_2|3fv);+CT|4KxgR31nXJ^BboyX@w40tW5>5xxG9<+quJES!4G-l==KB74I0HV8 zvsKUZpGp}I;~sqpuA1ufZG*%2qjL`6rqnLc#D5D8$;k;~E1T6^35zgyASVgi#72^P zeaXZW{)vKnh(hBn9cp9Y-D+i;!;QR^!r=N581(=JG3`a-~nRGx`#^wmna#&H4As1mTL7cR`Tf} zNAmE+Vrhi})9#Ih9AEp2>!t`27jdxPdZpI0_VmC$7tUqgq!E*;N)O;kextK2>jLks zh0mRvc4nUbzn+~g(+!_5`oC#I+H4NpOKv8jyWE{Q>3~#1pyxg&xpNranY=mD_UUoq zE-fjDek%P%K8mjU=EuL*d9WWUe*VcSnI~?cHc^q zU*cW+R$jdu)Rd@t{Ar9VuaSCXh$h-rid^z_?k&9nnHqR;drK2j3T$eCkF9x_hE9ib z9!aymeuy}atY*I2j7Bm=cgdhyllvvy%m}mBQn$x3{uOmY71*y#YK8X+SGH|jtU5f| zZjoHuczR8Z`-n{9YSJq#bKbdj99KTPc5pYvPH-J{Rn>Br+^qXQ^-8 z<#+AWWRyUcw!5mDspJR&=!T!xzNXIK@Po52)`6r7OvDl2wJZDElQ25`WXj)HV0OVuut z|FUc+^qX@pnwm@dx%IK@Kr)9~JiFp6%Z_w0r-A1-KRhQ2R8i#AwjwK2WmdOAqI3_P zpyhaRVX|weI5&(0)w(J7;0PZHuyFUoeW%31>MOwS@Hg~uwG&0b<#-WxvNirf$0`CIKDOlskmL*I=(G%}q=%R7c|mt8^p<|+>$ z2>SW6BZ>%CaE_)pWr!|TSTTXduz0q)ByfewD#q}wL0!xlqOLBm5bXROvgaPTUghnc z`S+R46Cr_L+r3_($d$Z{69hWEOH;4J{q?&hcu^5DeyD%*yyqMLMzK)*+9d@S==jCD zC7y(Zk`Zs0PYy+OB_B60^ZXx6pP(8#Xn_A8ORsuX;xqq0mj2XcJwfo@e=YqM-s@fX zF-`g$|5{JzBm1fiJu&rvEWNA2lOW#l4oA!PoXr2V z^jU41MGtYpUN3^SJ@tqR)sjSezZ)Fyi=%l|d;?MQaI zw2T)Z4%Yihi>NpQk55>Dz$~(LDio`!`zssYA!*_u^C5dy_ zj)gr?y_%lVS(T+i;vaJ-;{+a6u^8O5yqR)9gcZv+w;Kf9NM~9Fc`3-x>%WXPdA_ar za^HP{*D$zb7{43mrLyJ2W;F4!@JvPUEU*t6+2L96544k1k!cY1PcUmm_yeC5Kb`_S zQ%sBY*5p?*Nwl0P`Tt1#pWTAx-`+am4B#~Uhkgbjyde{VCTPAtE)%aYaOH)aRCipM z-dIn}W?WHrQR4dkHJcI3xdwc{pQr2_n~5vR{sLV6g#EhhboLcx*Ydxo^H&#LKGz(g zzwg`AXuCLPdAC~jz`fqhaQSquwF?{&h*q}$ye8O^R;45L?2Y}`M175MbU?_%A)_aH z^BpwcFKTI;%QA__9bHd6*rBMsJgz@}d7rdux8LSVA(GDYewU0la)w=9*%-GQ)44wW z6)}0TcO_{==flsGotIDMZtf!jW0?G^@4v&okP6beWL9y>3dAQdx2!xCe!DSliGR;^ zc|E?d+y#zbs2tqN?ae!NW@v4*PCtTT9ZKw2<`)L!2ZBcWG!UboH`Cs9vXlYUTo-)32wO3)4|Ve0hOmS6hN`gp#SU)6`3 zJ7*gIE?5eQx>r=Q!}fm_{Iw6H;sz+NtB+ah(rj(~!zvF|U)%Ymxwe(XT;nIcj%U}T zc}^eJp;i4{i7gua+NaLZ1%94)&rAzBr}HRcJ3Qok?J~75iV!bEt zQID{&%V%9XZGua~eVX3;hf{JJ^+919x9=&n8ryGZQAu<#B06GGE1P$NC>{ZR+v#M3 zomxJ>habIiaDSDDNGo5dF$i|bZ9V>+hHd*AQ&?y|M<&nqJ$%}3>2{fXXDGSbV}~eI z(@76F5Fpz|j(Ej9=-XM@b)uadg7!X8PWz`IB|pX#Arn&9+Nq{iJ6!L7``3gshiVoZ zIwjcSNA{G%9XS_e!}G(_)MbtlA|kTI$xwjldjJr-miWo6B=qDW@EAMzc7@!j>*aj^~5_aq-`EQzbtAo=kXp)s@Nf#i;AU+xx<&)kgE(vBPGwgOnP1Tr0=-|Ks9O zd9-AgKU=w6_bHHB;S|IZZxruA0};3LJ?-MnCt{6o+Q|c|5ghynCih=PM+z#C>kQ8M zSvhGmf(kA9ZsD$aT449l-!7_)+{;or5EV6KrW?L_4R*nFo~DvDrAlY6;R{v#tkAG4JX6%Etu~ElKzCl}e zzXhPz`+7vE5)qnIV)}GqtSE^17(?FsUa3s_bNWP<8K@)(O790HF5Q-r>w~8wcxc_# zsvd|wxP7NUomAkWmu3TkK9aAVx+Tl^{&o`865X5YK0VnBP3M0h@D&WJeG)f26|Ecg z@C$^ZQasK_C8n+~58%Bm(`OV zBpe!SD;O?=eW(dVJ@pfc3FJu%7t{wq?~7u*^artgIEP3mF*cGeP7`W5ih=GyY%U;J zdn$oG-I$}H5}X|BdDQDh_0+nAx1~clP1Ff*lc)px^Sm`v0HiJ-CS4SIM&BQEq>-H? zfD6fv85+)5M(+5fly}m6>TT${eoaOAn1q{|Z!a)H-*R54c!`pe7dQk@OJd$O48rFfi&wV zRv-N|)X`Y4Oj4x^ZS|<;q#GtHR(ED|Jgx3_8saW-etj%mH|6cjAgGNpNCuKF^vqLQ zFRo}f1(yxk=`YnROM@dRcKa{{C4{;ZaP?ydBgGQN*JO;!74Bvq4fkO>lGMlW)|!K1 zM%f76;)R|{U@)srKM@t$wz&u@xJK$p*fWnbvI{7`vWNDf*s=^qn+e! zlmlqT7(+N2XjwC`Cz_Dps|pWBOlZtA#u?+Rin~ciCAhG0UOvo1oMA83hop+mXx$GO zcL>A*bDaq@r9Z_rj`YF%kN(Hfw{joI*hZRFZ%P z=7NEf`LA+}w@DmsI!Sl(5t|u8%PhzDJyHMXR9e^sUWk%>2J-pY#E>@&_4U?YpT&SW z$)S*s?=lNvnRO8_$na-(2Nr*ONjNHFz8T3r)xk2*ObpEDV_88>jN}8T(>LPVj0#XB+zD!`amDuAE%yUJ${=n!fX(3aR;1e!C@kLYzWp=Hh6M7 zZe9Hy->j=D)o(k@mURkpi_X&_ow^?CB|GOMn@Day3(_U^>ztG5P>%~tb>@J~xL}}t zkSCy&fKHs*L&h*(h>F2b2F~(cmQ_kl&totU(L5hMJOA_^+<$A%cia*cP~9ttoqr;0 zUDh)nVX+XYx)87RvISVcjL-R%X%p2}F#j5HO8<+h`*5c^;Qv4Vah$_B$2vInI)oe| zgp|nMa_pTwqp0la9OH28V;&^gnVOhpsbQ#fmJMmXhD zp@YuZi|w5&TVYSO%Ekq)Ar5^ZR*Fl( zNt9FP*mJeZTY>f|i2RL^I+9}Z(K>%PiW)sWBkW{1$_vD);q>)%E$gGkGZDh7IAJGz z7!hYsx%izAdN`d?cz>J@mJl4RxIMk{_R9MRET{da~-$#cXzB2N1 zva{}bSDGyRcgb>Y?C-YhC~hfFz}_?){y&*ffheS!K}00i0L| zBcgLB%JI8aBoqa)gsz1_nN7X$!Bl{V7p+w*0i_5fQ&MADgpy1}EK?|fDa#o`To*<` z!YK;^L|4-S*BCfx_Y0jYVVO^K<8{D@$_&BoC&B2@IYk%P6oJ63tqY#?LU9DKX`)F5 zt>xS*^gSxVe{sX%Can=D+zT5kX`O6zOsk**HQAxT&uHDCZuu+pJEqRp(K-8+Hy{4gtu!H4sCJWjDBbW8W!K_roJXudqMfjK<7y2gJTl zYf9AV4y4mNTu&foi&a3Z>h&Ti5VanoqwG%doB-Y$6SK2YtQq-Jgc>?CukH6tppI7T z5uogf~>f1NO_G?`EO49I!_ls zr?ZV)vV6O#REnK=<0|Rz&_qh=u7~lxyQmI2tihLH&n`X~+ZuebW4Y=P-@7eb;G@UB zO^Nf-jE?di?!v3PuHQYDj$FuT>$ZQi_NJbES<8B?eMI(==+LN?hI{3_J3vNeiYJWt1?Np zE?38D5#0D)l0poj0gK=ejxW<6d^17K@|=ZPbV@OsZSG>VbM7%hPY)%REp=~+JI4k+ zLd1^mQu%+pGt5@o>aX%Bxfz3-;e(nUL&8-zgYgIK#ymbIP2))E4R&Zx&2aBq5;p|6 z5a_nacu2+hIhg80RKa=>D(anItgS!g*BYDHZM0_T6-9Y9(bNRcm~XH4?GoEo!^)fq zVxKHkyaj8&2K9XWvJ4`}u7Z#CA}8uxID+rJh~rx2QH+s*qV}Ih>?BaFjxQ$C`Fyb& zUQVFO79D-kOT`JXu3wqptm*DnE`Rf7u&gI~r#h<(*b^9KUKxMr;=S(dlFu*ywA-oX?eH?hih*Q5q z=Tx5JC;)Y33UalIj8OVeY@A&B0qrHAw(FQ*~vD{8{j7HGc zJbwGqK&(>g*9+^h40k&0L)(p#Us$0T1e#`Zde5MRZIn8Cu}xX^!7Ti+>w2@aSnKAm zd4^y3i(HLn(2qCYldJFTlNiinS}0pNgH!IyMuO^vZilU-tveHb-egz^)KPx^&KubA4CC(i^9C3Xpatr{{-a-my?<7$;!4%n&9lTzcX z(5G%sRwm_52$wG2TB-fr!2h}5{deu`bq@TA!Tn#^mxGgI=it0dwNJHjQAuXl`#LGm ztA~WWNr%3pS12yc>YBVZJXfs1n5E{ zx8K{J`b5q4h9AMnCEI^5gN<=nfQ2zkv5&$L{YgG%I;RjfUdD1mZ)MUis`K6TA5Fh5 zdE%dULw|`qS^?kPJv=?&IGy1j^Eju70UysP-GiSe$T;CGiQ9&>%+g*oA5H!759imV z2e8sK%rVKgA^<3?*4Ky&Z~|UM)3AwN8$1;3Bf$A&OT@kk4aBkwf_ht-3q>tWWp5f5 z;Vse3q0<<=Yy(n29srbj^3ZG2Kd28fVdk280am z=Odtz!mA>@a@To=vV#1ywQr)69Tc@US)QHyG1f4xtEXF!42F)1Iz+&NqukfDlqkQqe>Zmtu;P@ApLXThU=v zT!i~oA?`a*>{K2@51dL`84MgI7#uQ1I8b&Vcjy<3E zq`p#Od)hfHM_@^=aM9tf*!qTfa z`gwLd(sVrRv3mPf)h&=QLv}TIL9l{ZomrPfP3L99>{AVDADKzD`OT}UEH9Iz^Mw2F z#!`&1t#&VEJH-wpNLO?9a}$L^Db&*aN$jn|zXtsw1~=YZ6jz^RMD-)vwTG7~Ym}9E zSxy>8B$_QfB>V%J&nbD|Kx$hDDBTrmC*4MaIIwfg>+zDM)+gzhntB&A1ZdOCfCK)+ zCqaT$ZU}*}ZB>xgW2rYCVOKqide5%HbVdJ$@uE814dO$!cJb&`BOYudqBd+K37gLl zJ@ah{ym7WsK@4lU@5eL))2m-VYS?t?sq4*EjEjNRTAEZ8)Sz+Hq1YqJ1Ld(Y_IEMLDX7PFBbmy+2(EDvE_>Tsc{{(Bv@@ zl)SqZpURuy^@&?cnbQJI8?Q}8rP{N<)!!Uc|3)c!Qtt!71ucm!{e=3kX3zEDsLlxI zd6VkWta`?jYVkD#mcv~^s`SAp!oG!Zwe?ZtaZLNM1eNTC?*7 z-Xt*sStIybbns9$O4ym+qSK@dm=f1-qE6lnUykURON~;6x+2pMUn|ZJsW3CAc zu~rUh$u^`WKPbUk3i;<;T_?~z{iDaCZ=z!!!KA9c9zFFNA}5MB6|{At&B$kgSRO~} zJktP4+-JE!jusXK!BRmI#WZPQ#8>LN5=_k{Jp$I7&(K8`vA2X20ZDo{VO=K}+5;V7 zGCt}FqOwg%`~GxRL{09@V37S$FCL@qjO^Yq)%3|hcwmo=1vinb+D@fc(I%Bx)2K`( z*;vWfP2U|$>-!b;)Jw5bulY6Cp3cfD;O6J==>{080s}K`rWg9}uVhU9jE>DvRd4wJ z8+Ww6(k_BFIhwu(WzTaui^{(r&Fj)`C%G|Q`AzFES_zT~Xv?QYR{EsWMP_0zP~-TG z?EM5)$LTRdoW@fsHG07k%fN*<>511YWT6GzX!g6zW`b(;KLDZ!<0MVEAz2V=f{GU| zu$5yFU$e*=_V}C8rw-ID5b{cl(~W8Nj}T-Q_2tGeq{xP0^x9wN9M?W0xB^p93_0z--^{ z%Sq(KgPT*bXFE3P!!}IGG!jsWHZ9e18snJWJ$K%ndV?LCQYx`RF*3u4NJ&+jvh124 zUOQSo3U_Qu0pSArO}Z?GbM?IJOhCw;-vO6$BsK(D+Xl5&WYf}XHbgGjT--6C_qg1U z=PDAx>opyH@kl%(4BIFexmUURyOEps2N+|q2fK4SsF({ZEnnjhXkR8SY7NU5ACsJq z1|gcCZb%1-T^aH=Ys;;3O;fw~-Y8OR$C)lVx6I{Hf6{G6m#XC^3z$6bJ9S9M+`*Iv z^>i5FqrDpU`wik4#{Ww%3RHLAGRx)hxKGf`E*E;^QUb2u8H+EMU{)yKt*43WQAI_m zbohwNR?7=!J^GAVzUixu^ObEZmfZ6;>sXlSR%m|VpesD5Nk?8F_p#^4FV`>M%~+Nf z5y!@*<{~tGROd96$CPj~2t3k#4m=OE0|GKKQC8=E8B4jQ(?+I{c4G*Sdf8JAqls+Z z8(=PNHTVr|bZtI+Lx`qP4bk7j+Ms$THc!Dy%arJcLfjZrx?-Eo`*l&Y;IoBS*?9KL zxL>kd9>c-3G4E=Yzde;@dUbuofB$Lykdf|uMRR)HkJ3wBpKf5UP8CZ9R)*+_YT6ZK zBm6x8PCpg$AMc9jVP&(zstTj$)wdT~rpiF8wV9@>SgZ1?H&*@SPo?$LZ_tf;)~EHE zQ|tb?;SvijZd(7mQ#n(8_iDrufRXz`^}_K^?I%|DAi6S_X%hK!?L3v(QFKzNPOF_3 zAGL8aYJy?ZQzsB1-c(%6=Bs7T7y0@-^Or`Ct>nMrBE9kkuh{l*dy>`$dENKQQ^p23 zb;|2bH?fyqo{V8teiUE)n}$|u(PSCYFf3obd#@!UZK$oDUhLjb_GJ;19v5|?p0p!v zFqB4EgSI@1c%B)W)tTr1=&ai@4pEuRa;0AVlvY!4T8lCh!=MUIj9gIFT>CMm4`g)c zajsn8xjXoNL~r0rqk4rC>Z4m8If6AN)KMG59D;p9yY00a>T9N%8L6~EgnC4$@bTfn z^FO8w7(SpFw+a~dQA|fD`?uyG232G~s81O>vAvjoYn}(S$%N zfSm_Pyj}6yK-?)wC~u4`ZHx`(b(VA!j6Pn2^}ew{4~E*O*Gp;|%|*>hX3%TV)cTD( zcERgXhqe6S&zAG8_QoFYM)Ul9#|v@vSl0O?sVN}8C?N8bX-r6RrbzOGkknR@l=dQ` zAyU#v6N$tlp;|*!NfM{&QaO`8ym^{a*!OP%wUfVW5?)BK5E#mkXH>%3ld9=87j5DO zM@yg65MEM{?W1mfT`b%`;=roF(jUi%KY(H6@9L1*$XZ7`PyRX>DKPe}-nV9(kN(iB zFVp>Yq|fRlQtvHc9K8R2acb!r3m@#v6j;+l6LR+_CRL$)=E!CGAh+BvLb@^765V8B zy{rSW)ER&o=;p%3wJ9RveKMtpDwSlN(=g1rSrD&6QL! zrt2%jW1C{9nRMrPqn&sI3b1&H);5b_1H>4Lf0UZ+IM^w;t2Mq^E|}IYEp>o%XC7{p zC%HVn7EjFMLbty1Ok$VjTDfhZ`?FBkGpNk-5n0qLrpzl@)X{p&SCgb3^2TC)reSEs zS`YE{Nx#Q#t~C|v7x3QCTJ$ca&sPDTQXm=(cxqdhmH4u0YR05sz}UN2wJ^sy^snMB zq?O&jaX!$pJn-iEpsm&3QDpF>s3xlCl-|+*8aRwoS6$d}XS^&QqT{OGpL0d4-ANO1 z#K6hWgEZ|6sPJ@`i>;d2+DYy6M?s@jV?&#_3T2`>(>xy*tu<_#_|NIwQzi)TJymxv z$KMo-@u-NoEfyQJ#a|_R;0dFVhkiY-{Eej z{te%r@U*_WQx&$yUe-U#EHU&qkd=Af9too4$}}l-xTlsSSoaOEX{8Cm}Gp+y(wPiQB`(ZygW!; z7p|)VHqa+jm1k8|=8IRARaMpPrpkPnSr)4vsHzzgubru?{UHABkh5-I{Mk{}vorB} zn$MPpmA={)_538F0WCl0t;V49-$SacaS|=b5^Y)4ZTS-IK0MFMTGaBW4L~jR%naUO z>^b@J2gGeTK+#-;QLz1AIht=-K$J0 zYaHEc0xIi|m6j>f^}^jBud8h6sCdpsZEQCypul40s*j|CD!0ys5Q0>8V!NGNqINF- z_RKshq!!=(b4BZLcK4;qmp9#CB;Kj)&35noY2RJx-ak+|_|<(tR(T`!$O;Sw1W&*a zI6w_tU8@P08!==4fI+s|NjW5gyYz$V9pjPHI_mC8aA5h^gQZT7sG!-or?d+o8~x1?TU$e zES6qt`2U7F`I|RlQ8E7~)X5^K#lx2AzPK@H*W8L+J1^lgE;?7bHvOa@>lH?B`WOh6ccm% zJxIi+?S}|8vVD=QT1gfcm7(z(Mlb!{_Arl)TSOD zuP55oYW}Cs8IOCtym&j^v+iUoFZf@d6Dt*#8kbDzb1L4AczmjC+5i2YKIcp0?Cq-M z=KqUoz8`QS)#IT8)LM-1NG8EZ<8#X|IA;6~qfqpPP#xV@EV+l#Vupoqt3h?vuzI|7 zVj&Q4o+P-;W@B4%VS<<(^0t`&*I*I$@V^O;abSWP_ctQ!JHHUe|C8Xfy|Bc5OUr&+ zBDx}1VHEZ)y@0+{jPDBGoaN;gPNNWX1cChVX@&GCIOyIKB zsi^G7|ALsuZ_8DF>5&^u4k;ieLxn2s?m0WSn-n&UHjogwG)$ddm-ad{SOT>m(jAp{%!&tXg*pe#%h?WaNs#l8sK4LE}eec@|Npy zjhBvrt&&Auqgk%Gqshp^O#4w&eUe`ey~N@I8J^bMoH@b9hW+aO<^HvVrRMv0A4EJ6 zxt&tHw3VCKhI)PbT)9LB=w!70&2Om?*7Y}EXUf~FzS4dde{9jui0hDWHoN~Ed%3}eJAd!A^hvB<_}0<`6&>-p*gpwY%2xgkS-B@*xF#px`bzTq#Tf5W zXO$SrRn;`YEp-kXBjavvU8*<{-PBwZ=Bs#@!DuDk@gIEI>Pc0{rDXg1Cd+863Sam0 z_d@g;FX1PHx->?~>2Ehgl}?6)U-6!U(Qk_1NUUt=U%jGoOcbK6_Cx#Id8+@wq1^X6!NR zzm&#S2>rZjJA(C^|Te+Tb6j9#oy_gr~kBSI(D zR8z<@VYA2`v@c>*>nOivclmq6^MBuLjtFNSJI7}#%?++UJ~f_&%nSK8>)SdMx@r7b zm>WA%GJc3Xf8oG(adkiCdhf`De=+5~D-2XRI=Ua!@n`ukxus-Q@*wOWY2j$BrTf?D z!7IIv6@Zqfo0|Us%X+#BHEw;yF?JZEbQ(s3ZS6Zh*5GM&y3YBoH7~q2poO9HUrD2< zc+)IF;O&MG**MfjyHky$`t(BDLy=*lyUtMCjuor`wrzHt;dcJ#PW84rlQ!&)`%gzZ zckFE7LtU^!!a3!OH?Y6`???rfY}wB(dB}E&l;f+yl|9{o(D9eab*wy@>TcJa>uLQ{ zp0dmC5lh--A2J57@hS)2y)!t&!vH|tyY_KRovXCQ$FF$K4v%d3*8jPeN@rh*Dc5?7 zrLZT%c4yzR+OKY1HdY_L-ufl%kepx>1IX`y5=K6cp4Ed8q9kg#&GEbntE}e1-8cbSAEoIgG*(uX~l#~8( ziwSkG$P_}3yXd-cl<0DLSop`_OmZ`l*e!&8$Mo_v;p5$?kf#!>#zXEB4djzQFAeyk zA3#m$sMQ|QGh>k^G(oIhamCIt^4CDe*e-Y6qcp5`FWjCPh4w()(qS%BMY!Q`F;NjL zUg!UENQwE@vgYmtuDtok8rG)+;0CXYOHjjUr38Cg*~&aq}Jv2tpW)M^B+=Xg%%hq5O3 zLOBiX$(9H{FuUbk8g*5_fgQP*ToDh;3+t{0qQPs>Qt;fXu zsq=HONehK32mLAWuchg{Q}dB&AQKeDvpk=ZCf1gwx*DH9kpz;t%u)}jRfWrkrYn%W z(;uS}gls&7Ls1W3Uxs1F(=>A^i_sa{S2CU#rgBmN6uy$u`j}kFG(~6DVp5f_c$RPq zi_d3Dgr*$`Qk7sb_b8Q2XjZd-rcGN`c0`8GeA=Cy49OFIDRmlin=F$!lFv!H?}?87 zklanKY$-IKi+7f{cXll<+j^dqsK=>>Mg-Voq*iA9wNA_M!scYb93Om2qBY4m$YfDn z%?R?&4%N%`na}nZ$XJnyQ54L+$(7|Am*tg{T~eQ2F{B@LBK3sbYo$NnAGt^bdbz+7 zp_DIzW8!Ca&(P(!TmLA73F_wkEC-W_^+HuC(akv`_0;~#^kacO6=AV?H zsMf8Si*2yEUqW+Ra%9`|v~jJ>a!BvGSofv^aDQ9uK--hpob0$Z9^KWaKv*^ zU&J98{G%`AVIBYlGd(H!TM-6jbAa5nb)5-kk@0vK;Ny~u_(sS%faKX0l$hdx zhm#!kyKo)^D(FQqBqHyUI0Cz!QPPuN;+@(2sp!b;Ni*wHZn8`PwG-2gu(Tot0Na9s z7)AoQKqa#`!flnY4gn6P1>5--WM7K&Vut-3hg<6+q_GrV06+ z<``2f`vlKRMFh;7nOWhU?hs?veLAz%j9RA}4ShsWZq?WoN1z_ovdRcYpenY7>=zhw zoBrXd9cx0F=O8P#bvxeJP3&&%voM6oYVAIE9g84U?rNP~fepoTTfGXH5}zHnr_Wsl z_HyeExzq1+L%79ZYp^&v=(44>3JyY*vMTVSA!F0;!d5JS%7wUUspj!2O&t-i7KKE$-;$7}sf>Rp^1HCM9so%LaWC!{4%K_Y z>L@W<9XGIhRQR}_$Fv9Tp;vOh9byui&-$mfBpgw~Q@(0yN8NMr z$FY^IpbI^+*H}-qSZ0DEFJ}XL<21*%rDD9`yj4b|P`+)R$)Ij+oquiBX9N`rR);Gn z{3KN3PnYLk3uPOmbUig>4AVOk6831EW-Yg#s9g?I8l=FGbgAGV8=4f>{PnPBb%+)@ zrFLp*cr(J}qeT81nyS{H&c5IcHIQfOU!S%VxqFMQZneJZvrzZQK>mZ46vHn4`XPNp zJ`@4*?Sa1^Lh#b{@be6)_re}=BBvhvFO^0jQ(Y3dtyzd*hM5A6PFNR~P^n;72%4xw z)d1|ai0jqhDA_D`1R%%`(HkGNm4=@o?)~O2Qk8%-kich&ofmlC1tlQpd5T&(knD?$ z!%9F!d#r}<-3dU2D-|${+Eo*jn+nk)Q9z>Lt|&K8R|mtA{wGv0Qu2F z^U!&?;u_#)lu4>--;^i^#?O!EWQB4Au#xBHu*&{C@X-hsV)(Z_Ss4-e`=#~YN$EY- zBE=NTd&$F51VgnR_UA&vd;n_!#MVEJ!w}A8KL%avtwVu@z95S9t00OM91LJ7`bN6~ zh}-I`-@GHb^$8U(-BZF+w( zR2rWSmXN&`s*u5gVY(DjyLKgmIn(6G`puY4yZ*!bYi+gu4nyXAfGMiO`N?2O&)`Hn z0MrW=$JcxuK|t$k?GS^Qut}X+4vcJ-5I2C8m}cw{ob_E53Q`8row7vjyQFSjZZI4P z4;yJqto`LfqO`F2yubw&z}<<--M`j(?yVbF&k-M2mA9vKB?9~b(Kd{TIYxAtz-Oz* zG_KTQuD-B!qO0lYrH&YbW5zqbzPMqX<{CQTw_hSM|Eha@F_egw>YIn-_BJ^C`{$9NmSdAoBc}ogcB7`qOzx zy%V1gRQsFZh3+Uwej#uH9IevQu-qfYi;7>GT)I;6+rB_jd6!1HqH?U^@fa5Ijcix2 zWOM5g-LkyFvIeiod{VZw0LmEMK`Z`4HKAE@4<741;aU64bv{2trniir`P-A-A6Ee-$Of3MK8zKaBz#P7IG}n=CilHfu4vS&>mGP3NplZnTbDQfp z$72jlR?6(1gJa^$AH|19Ch%W&1-G+}PwqO>h#HoMFud@82caH;lwfWJTam^dY`Try zw)>Fn+r#x3<}IO@7wJ7V=vZW2G&yAa3lYXV%XR-V{ruZ!r;IXxnR+L&4dP4Z$UZ+8 zj{caK>|ITwJQr=7FGnyRWj~EW=v(jU>gD;PN%~w4ci0_%TxaI>Ou3>T=l44YNOc2< zR1gM!IWgjDFAB=1!i%aeP>*7u6H+4ttO1D-Nyg(=W)~eIl;@E!EOqrwq>vy;SA8I# zqGfu_zLBo`(XufojnX=Z_aqG{C)+eRDhXRFXO-bj;&nl54ljlgOXlkU{;k9_i^^0P0 zMM_?;8Gq5mqnIwmdsU`C&=8711~1mOr1O9Kq*bo{C0QdK!EJJ#>kd-r396q_X>^Gc;RHTU!iiI}F>l_}xWyXwx%I|~(2)K{XQ={7wV0KIw0DYv(KNXr zYlVN2FlV%E*lcS*t{>>B=7l&#*QCo^t*$8)rE<)+NNCeB%036>mTBeD=1x&6xjYhF z|I^G2P?SnnMpdhDnru+ve9-@HW=dGS;8dkgWBw_XW$k1ik0;GnZ`HPJwcmR7l19n* z`G1j0ZKeMvH7zSDx{*}Al|*6h)`J11J6*rF7B|H<_Lcl8bY-XiYbKp*0sS1FodIv` zL^lIS{3WbDv`lYw&wVpeA*r-e^I(G&Uc5(atAjNEfs#ER=93~magS3|{%HxH9EiAe zGGDK`(Vb!OKcHmqdV=%wXOfQ*|3JxC=1KI1qQdszw7Y8TGn5F1-uUIz)FlF1&|0p~ z8G%>SlM%_YYqIcTa_074mE7bl_l-HIpEHgqWzq`&y;J+!;Gp?g?&`i;7hOB4!{dB^n0 z(Ny_$zdX%n=4J)#(-k@*NO)?yd6`1yRUX0Z#3Bvza?9x|t9~ELrgn?Epy_Hy$Sv#s ztCrnOX*Hf&w``}|ElJzX^xPz)0ah83u@tpD-F}2ej{TOx=I@yn_m9)VG_Ru_zZkvkCPVyjkZVpLj=7$b#}^e~ zp-wcPWqaqQJflQ9XO2Cdy<#2kk11pU&e(Tfanl44Y?_zmGF&d`KUui0+3E7Fcgl~C zpW*zYOxd32lyRN!Z;tHO^jWYSMgx~=Wba(R)+|UnKbZg6h;N!aj6ZkTrbIdj!?H1V zjqS?(u)q9fW>e>8@wMy`NB$ZTNoXtkEAp);V>L$aJyD32zTb-yy0C3BUYTuct??Q) zioidANt_?U(%dhAD0_8wz7UHM{b5{?_n*)uB}T6MQ2-IDvoJ{-Ai34DvC&daL)nvj zEJ?UpHg<11DHiGww`c!C7LLXLa*kcT9+y9ldidOXQL60cW7RPxl}Bi5>hin&pO7WE zm7~NhfVB93LzczmTYECq`(JPW3t4^*1b6@JI;hutyh5f4>49_)kr@OROds)+$hB-@ z*lPx73T?{F`FFE;=3b5p>6MPO+$C}GUaP*{R`xMz`N{Uu_I29dS4{Zt7jgVe4m(>N zN$UQQnELpW@d1CEcYbAltf$w{;exk&C2l1ee+z7q9h2xH2dwru#iecxkN*N6(OOJ> zzT^0Inzm>ATZMyM$XN1BQqiv_^zzmFu}aslPi{BYF%=v}=sXD7<1ef@gVfckx*yqtx4~4&87;Nrj`G=o)eMd8Pa@fM`hocq> z-Wd2Y{KGHu*~P+N4-fIzAH-e1X*odB|M;+pg@^fnB9-e8m&Q~qhOrqxD_{Qg{)8o2 zfz-|(e_?-gQjb4f(!q{Zk)dY2Mk!dD}b~rC`0vKt2FpNLe!3q&Nl%)0%!cf`us7RCa;HCjwq`L z4fgMlF?qhnd2P(;>CdPnXd_*Xpc0m)vG?0*=Mxjv2`}Z!!QKw1j;Q0_V(VTWxmOp} zlqP{|2Y}j~^-Jz(DES=S%XR&YtoREE~x0!MfT5eLZWqF7g{U#^r(_kdrj_Z zl}V#YRyz%s68Ly=f^tmDV$2`y2$vG!adVo!CmeQP;NpGB^sxR4pxerApcShA%qe!9 z2+b!7fm_W^v74{dHnm2jg$3#s$?I-^5&Cl9kY9<$~LqCt9ux zl3h0v;(ipvdv&XKbYsDKKgF1sg`pU6owsp=#Kt$lOZ1@MgClxF<5-MPN2p$A<4|<` z(BrA0YhSO&Pw^Zs4izAW3x$V^{l|t4F6~aAau8@N=j^2|=1w zePk3l`i{zQj`(ImbJYGW?2fZRMRpPIj~=o6Psalbjj_Jo*9xjoP4U>s ziV8>qWucjlljrrmn=zHk&{4GUmrrfur=1af3Db%1PQk(}Q}51BxqmC_m04f;J*atC zKMiO?_^ca~-NtDGF`t;39lj(%9mZD+-#r6BP+lm*C}zE4oR+PcPB|MNlCmtF*sF_V z-ZZf0)wyUgfxR3$QAL26qNq*lC%C&zxY-Dt5R;x)*GYgrSf0o<4ZlT|e9kZRS0}xo zQIcRpJkgpf(mydTgkIba0z*ymk9M)*nfg#=!d{SVFUG-fnCP_W`nrjM34y$#JE^eA z;0yxD1*Tj6j4<8zmyBvJNU}+7n!P%vR~F%yW;4SMRlwB!AeZ{zw|u*#*`Kl%J_so! zz}(F!AMAT;Dhea;<3~I_8|Ke(*3)r<5Vd*n_|!AP+g%jF%mR*nNPqQVq6FWxrhD_q zg#yaw6`2RwKBR*YBi%BHH#c>PfSFPM8&^{!B{8NVu=g?|Q*WHiZI)|2AqnGV#VLxF zaIA#oYihbUf7GiQI6wsP{Y~m6nrq(6-NZ4<=PENX%+3?9F!2Fzyl9n{*$XLn+#DD( zUTow9;Vw1F2TB9k6501sfeJaMB;XAS5*7%h&`c(1sK15v1uZD~;kReEAQmVYXcToM zloKKj1ByT(*s(SDCK%?FOJ_D4i{-;b@xkvnXJkvlAHIS<60M43W%~kZj*pb96apnm@A?) z0GTS<6j1KqD`^L)w)Wk82(WF0w_sa!lq$vt9Cz0^_TktPJAC=`So9-=TK&oTrMd;DTG$~LMHUC#vL_Hw+~`o1TIz2QGO11 ze%Ft^?hG&~{5sOoP7Du$(t{S3#xfVJ`=6+7Y^ud<=ce&BT z=$l=hd_n|Qt!-@bD(to=8J!Ihm=qKy;(`X9>~6X;DC`LDSio;D%Wq8+u`}hd1B)?4 z5ZB%9m>H+_3^1$*dOz!N;Mm8YM_ z1ZnQYLhD|%G=)4s;ERB3exmU*a>galW}&=z1{sf%!}m~EES^^h?!fm3OxX}y1abYz zu?(@h&Ie31(}o}8gT!*8yddGnu{q zj;RwjITE)?iQG(qyPPfS(IF0xv+p|kyC+>BU{^Llq{(tguSLcdJbsU8vUB%)K5T^Z zdyaoB@6))Y#y(1eC4Uiees(^J2!0&m{{$7q92dEHe4CG?guVqi{_wSx`D>0p-eG^Q zh)DVCq>Y2mp1&NRPQG`EZWcWYyK?y;NV?PnxG5X5pR)>fguM3<3CdNyxsYE|-UW6hm@m$0aLhpOPYt#O@So*&10hT48>nDRQZ*P5(^vaxKKN zz*;}?SCG?hN&nyCHz;Y{r%D0HWSvaCC%)jXlW4%LSE`ZJN1-t3u$CK0QM zOi-2MpM0BxZ?_@xDeGAl$y7{1k7{E@v|Ueef4C*@l%?1|fWW{`_?O#ZJxO_RQT)I% zxUf6ma*$UJ;8S5aiVEw<0<#VBLmM(80lhdTu&mSjmzHf5 zUC@vJ2mnAjcz$I7Ry|&ouB*E+{|&H>7UuM-?X*~Ck=01u&JF!h8V%+uJ{G2uXo6~u zi#BICuc$&%g|lKFsV-Kx-J)pDRXfkGRv8X14alI0Hw)^P}O`L z+O?@NoG-?zA}-4J6CdaiT@yqhf)SNGW?!xXAnxhc5Islvs!w@WrnxD&ERp%?1Kqy! zUTxbT%;(q@{k3GtG*p3}WKEx^P<6RJm1>p1!uyHPW+KS<0d=B_{{vx@%>ko2mirei zp{&th3y%xTxKbQE8r=Oww_dcu)O?c##HEi^9mbOBY3-Ob`5$YEvhdT{PMH*_XW9ZE zMf-d-TU7gG5hnHb%-_lSxEXWli}|OUyA9LhgzH*g?V3Kh&0O>L>LtnOva8Irp9!_7 zOewCdG-QpNBH2#NF$?tkmR6~#GN9D#DYI~a)Z`X80%Cr9&um@cTJTJxq^og9WQAy)3pHR6BSC>b zan@wlyuPdl_78uXOZXI`PN&Xk4p+tcghdI>e(Ceoj~l!w)u4c3gG9iTk{$umJnDXh z>`gn2RCezlI54o4o!O>?Gtd0Y;YJ6L=R2zFvo-){}e1ywmBaS?X1E zYpS$PS5BYq-eR=!{gB*()n`{{Y&no#shP9`6+-rkX0x~5Ugy)HFUZvbnAx(`>_gz2 zsp(gOd3zu~b}&F8i<&DpS939_DLB5U z(!cBVymBzW(337$x5|jcE0$-Y1oZ9BoEcLz-UraG z#)}#T>Bc9$(l~|pRI)VZ6gfS+C|@{x(`OPcK|@}1=1HhjEzy=qO|3EFTFma7Vyhy~ zdm8rNP6A-Nc-2G7g+(Ow{qRItXjB*#Grgmt4gk{Gid~kvNL7KDg zq`nB+(!7vPy%_Q(EN#lcD6P1n|I=W>guGfQy*ku5)KYvR$XZ5-#L$!b!<0tvb`e{6 zplR{?;MOCzDOBXV`0qfUYG?m9LT1 zHNORVkY;LIbZN#NEVTy5sX9pR(ZUZci<5%{O%2-smew8`73s+TUu;}X5-c^bJ)a&K z&-!^n3}kfyiWVO#kr{-BlI!$J6Dx2XH0{qU<7MoC!2*8M3};qB(SFgGzK= z9j9}ql`PXRwEQO4WhO?(OX!*;NDgu)g5y+U6wDD!C$gG~pYoLHf2Xr^zi3Qlh?8|N z&y+K6T{%0u+(RvO7{S`PN^@QD6)X1afJ*684joH|YO51_>_Nrby26`ogP&6Or&x8^ zA2~@J_8AM==(EVFe@*&u&)mE#LHx)7c>5~7u`QI|GMoFtamF;G9zI^BQK-5Avd8*e zTA}m#bQ+GnW=ab+0ZtbGi>mv6XESd2{~shsC=q+d zUZr-8(oiGz-qdc*+G@2y?Y*})h*^76)ZTkj)TU~$st#YhuKT_}$M-nSKOpCCIo{{{ z`Fa@7$`0pfHx5N%F#D+R>~FebEQ~Fc7zPu~Uxqtglw7WJAFuZf?2-Wdy#i|7{Ld)M zldhXWXK}1PmHqVlx!ex0ysY?{qB_B;5SPgGp1qbWn)S|dwR#|F=unv?^=-0V$yC^t znO5k^3llej8K@e8ucS!~(kh=1I<5_M5N(ArH-qO@rIYyiu3n&s$a_?#_bCGoB3430 z8@S3vJFx+<%G*XA`3IZIEu7|lxF26XVE?pIOg6d3GfQpnMY{3Uing;I9GB+^INFhJ zbht9caeu0JDoyE{+VHDJ%BJb>AM!q~oPS^=1saP1j?Iq7=7Zyi!h@L*tiR=CEa-es zvmKi!V|}RNTu~V|lbMFivZ+3V!!5XxBw)#PB=$;_%woiC_5H4E{B1}CAFndOPMWe? zXai@2627KyO7th^;HQHou^~js4?>d7LNuikK_qt}?7%!nI~mtN%@+PJ-O2mxwNqPxWWxUW4k@dlSBzSKZgIg?Bf||7e@Go=?;fU&9*WCVy`W}3kqP%BN zc@U%xocFh0vAr-V&}qY-5Wk>wh*b~faML?2ya^zSum}WJCQL-s%aHadQ>FDOFUs<| zW-C`?qqtmY+*q;FOYtZ}l6a`_)k{>9{#X{Aeu=qX~9eqosivj=yMWgI50CEaa5J;mQ1o<(wMQ@xY zWUE5EuB-ro%nA2>#{cj(#-^;??5)@A3&PI2K@&YCfL4&%V!hI`C!~g&{2e}|z?52^ zT~(`D)x5q32qu{!&SE=;OENJEqLaA}s;Cn;c1u?c_J{K!VZp~7ieJcAgCn57IAEMZRB8Ci zQ;C-lfR78#rH&M*2m3w%A}B(7!OU*1a1ceWq4+&|+b=3{EC$#)(!x?>fm--xTZ%jBst!@^58&-}W;rh*K&8gt^$1ahmms8n< zhpe=CCpxB4mP^wzUD+h@6-p*mA76;Pz#vA(dvT)|KO&aBG^ScHx;la^Bb@a)Czk>j zS#B?@u=P6y-10a12&63Uq)I9csPf1n92Zv|^B_rK3Xz!Q=wJs`>w7NefFun69fK6q zEA@-PGoB3zY29_S*hktfYvq8A4?`9D29Hv+xaYzwRD=A0cpasRDDJ#$DR_8DYpDyE z$#7*_VC{?G2JJy7{u^nYdWDbm>O2kl9~++YG#((CB_cIS658ZF0qcs;XFoShTY}-Q%Zs{WXq*v4~3WK+1ci+q}@DE#oW!Gi+TM{&|Mt zqRShBzni_#yE95`L3;e8GH#Ycj=uox~J}>;e3f^^P~%I4#N+6Mci#>PZ^t( z?|MOUodRsp=QmSZ@Ty54<(f-GO2@a#zzTlU|A>u0Y9HMR@=tJF5qVr0+a)CgKON0{ z>F3PzBs*5lmH#IWPfQnuVQ$PH?D`EP%n1uqt(&eYOXqCG-;-1k9J!GH@MNB-g7YhvpNiEb#;S3Qo&V0WiS4WNtHB_f0s7iknDq%&ySkGiGlw@f=QSe9A zMgci!xj@Um@1*E5XQavFx92=$OdsC`KfRH)ds4yu=#lOXHm}lUx{3|D z8X|fI4fP$DYR@i`?5>(5Jl~a0T%ss0_({)&e${&-2awX$*I_42pZkq2@IKeLn`Gm` za;_Su@eS&(%9ZdEwy%so?fIX2rQ_FII7XT?CW{I>;@{QlJ#MsqDr%$HXrs}XT(Xi@ zl7p4KYH58p1f3CvPpGnfHc^kMmlhJWPxjJGZt!&ymFB&aeN^LwI72!LTcD&FI!Bq% z2#=qqUs#Dx*V;!`IxksN2=@U|M_iFF7raXVO)-Xx(7VETL{saixBbS6CyNq;Nk*?- zBPrt#nWRy#uSoYj9T;a3n`D}XqYDFL(~@{oKES{Z$IW9e#G^p;@B1d@MBbUTFJ0B5 zws(&M@tOmP#eyiCgPwv3*{xKKn6H8!Hy4jtg}rPJdv@hdC%}R; zYeoRpXN+8SzL84zQA@Jf$P$ZgYL5Oa7BkQs6JB7-auf5_2XE_t!NxnWgbSk20-~qu zcduWA#KBc;=_mMjoU|vks1K~H=`C4R;@M3t*`LL827C{ln(FZ}IZgChczmEH3nDx= zY*>^S9GG%1$L-6IWNC)hMv)UV)ePWjVqq<$*PuYkDD7naz@W8^fdwo?d=Q-t6LIFC zyy$siqN$c9Yvz!ex8VHS-!@ztnZtUG=czS0w!E~p0NX$1W3W8)}j}gN94A` zRZD{6{e{_|n2AIuGRM6tb)^n-UfwJ|n;;$!LOlp)SCaA3d8BgwV9seab4;$5Gw z_Y8v6V5H%bEV?XQvihbKA_@om|HDCf0*>^vdkYi?5(G~V(j?-8d?z0YWc9o_ILb_4 z7s2LBKV2ec2d*M+8PZ9#h}d!B=>h5MWO$USa4%)B?nEH|i|sWN5c~f6JE#I@-t%kK z{|f`X_c#6<2ENrp{u8`PHlrRkiqT>7-7cvASMUHd~xy{V~NL~L*K!!(OjPd_h>WX=Me6#gr9Q;q}lIHvJ+UUO}pm{4<=X$lxn8M^|Gt=1eDvr4Y_4{LP?z+x zKp1^`{9}AZAYyX8FkK)Ij~rK%iFkAcFCG1t>{&Wg44($|wpgZDsDNf+uwT@Dkn&Yb z0o@)_QSnd7eIIDuc3AZPh%Ti;+2C@Du$jI&3y@q#+?_#U z!KmG$_9FK*rufO;ziyD!20_&$^<(I6;_dCl2l=ET&hK#Rk zGwwmk!=$CZ3E;%s2Nn0-;PQ0Vb$y9ig`q4GI$!oLWvg&{sdC>9E?+k`zlpT*82Pex zd)*Yj%hZI4c$l!K!2U3tbM*b)tD9a<^)o^gsBDws z(@h_L95mYcUr0#hf4szpOWxLx|Mg}7K~%daTVh>sb0a5DQJa+h-;hxAxmi*{z!Zzl z&tb#y+O(!oTc`8p5exmb6v?_Ff$*QBc8xWeuZvT}-(=n;IHpZ7E19-DX89=Nt?toy zo}(iD{by!I-T4PLvUM3+)a_(_&1ndy`AzO;X1_$@QQ{LGm&# zhv+I_P%W7=!gy3F@}ksW>_Fedc+wG)8?NMFVL+?$A%9Ko6}Q88pKlVi;!3i4{%VI- z*6wxv>W#JOV@!JQlE+7JF?9uMM-DFb%k$&j{`F%=!S0P;JVQj9Q3Gnn?%BUq_SKtI zdn;a**L=DU34aVA??b{S!mjJ_ACuJ;E_rA7AtA9C{#N(0B|(6|rF3(5AHG%iL)ce> z_?xaLQh_MO+b`})&ApF_xJjDix9C0)48FEGO~`&?&T{v(WjL4cDD~6tod&0tLLH(D z^RZSghN_nFAieKJ^{Bixwy!Od;p&(9w0{&JVz)}I#h%ZmOxa7eBl>c4=z=ucz8Oxm z&Nq!;Rr>$=?r0C~?^3_6&Hi&}r}1go{?~ECr$0xoKc*MA)L&;ww;vZeiG|y|E%$54 zWeynl^e-e-0MRR=6bTuJJblsu8^@u=+wa8je-6@m9-Nj+B%=5~)9Fm)qEedzekbuI zl$hL~@#-ND@GK0bgc1op2HoVns!Z#oEIQ)=w|{@F!8UX8_eZ?)6Qt7}$*RfrUgu*g zu9Vkav+L47orSd=S1ccmN#8OZklvoP@uM%{Lxb$ zsd>+zG6lihJi;3SO%>Viim6!p-zrQ65No}CM&f6S4&uC_qR01v=h@I7Mi0N@0=sKtory9_`Sk_@~MY7-QWViJc{gYWyg z?8V5O!AMlL=DTp8m%n8qCoLlZ=Mmj$bRRPS75L#ScJyy=sO>x{D=df`1EX6rqFbmW zio~MINF#W0B765E`ZM0;F=JsY2#z{q#usDmlZ1Ty81@CqGN0J;<_Kn1&Xh&mPx!Ip zcK5(d^sJprpHJL?7-xwcRJ%H6Y!Nb!A3w$!yMr>++4h0{KvH`&lT<~-R+dCL){!)D z!hV&&e($P(kATWIL$wECR!uM*Bj^lnn38Ms#!XB+GUCcd3Y^Kmf0LMq~W*20cy$CW%8m zuz?tOoYLuw3L|KX5>#svQjw5(-)r6DVXAg%xHZ|e;Ix6#92};9%9YOjXz{l!U=Tpt zf*9iqZI4W^Q-T~JAV?)B=iSe^hd<-ksWX0EzI)I@7C4kZzM;OThk9hYNa9-{IvGZJ zvKxeV;!Ff>SYL1{r5O>MSR~`+k)4S{H5YTZN&&J^KqDR?sWaxDcJ4+( z?d!utxv-9038w5Hf4z2gF@$~%6Gs48j5Ue3kamXyV#8(qRTmg@g_ePUtkZ9j0 zg4J0h0!fWaBqet(polUOt@V5YW!xkFIAjJHH3Ce z%ke+HIaE&iF`Tm_fqSJ%FgaWWL#Of%RV?|{1P)eDwt^w+3=T6Dv9{IA9MAEgRZqD= zWHfbhcg6%Klz@A+jh&~)BME2%4Md>6RKcJ?END$DuJX}sC9{96qEX>};fFm^52&ub z&u(liN?43ZWT168>*cDnM)Ts;&m?P|?J0MZaru^OU_8)m3xd|&T1X>&Z7}+6ONr~U zrkNjY(<$lR0<<>+$nb*rhXwKHYH+j|?$J6iEkW#O3*x)Z)a<}$Cf&rQP8290C9>s) zpu~W)(${y6QHHh)uT%|t13IX(2*bt z1Uwr2G=;QxL6-itXl^zv)6pdE2n@_#CpfnF7>LBZ?rhFJAQ=2b+=>g#UMH$RM<(1w zM~gRsjskPz9f|qGp|ihGt(DM@55J&Ux7l=MHE4;Xo#?iGN$c(JpKK?IYO^~EepZ7> zJ{5LanZVmVueTRV0uqNICe<+;wCKRWNQ@(K(p_RwZ;SOp~vtIXTPqSyCg&jirnR-6MuG!{u6N5TFbMgM-W-Y6AERjuL7uk&yH+X$74z_nENv zy=oVT20#Jtqq@vSA@_-(@?w7be#K!z65~ikR&|!-#D2x>og=}f!{hRV zq_x)1D2K$C7ELEP0f<&P1hiC^=J_8}~gl<-EZ7c{^@vc_(Y9Y^X`DlU-yl`+xB)&0LdIw}LYlQ6y4 z0Dqd{*0rq2CZAwg$Y=wk!xRcUwjB8SrEgZUu;~6LD4#WESrpJL#e`nMc)32uWqAdas5+aS@Sr(#EHER6MLX)FcNqf zA(kS@((kFP|BNaOY6le`g(j55bX0@(A^`E6P)!6uIiO^lr>pU9ACf`)&zGuPFVRL3 z42Oc!PY5DT4UGV3P$Rk%I}dc4;i0__xli{f$DmQ%am@OOP1&;@{6j9{t<=blu9X~# zqY}t^4oU)Qtl1%$1AWkjrs0Wus8IIg0Z^&aIO^Bj3T>>*gQ53Dr8yGmmtQE|W5c3o zd$hkmC`Uno?HzAwz-1#=*Z0@8pP;<&Il+6jk57%@{m5Hjj;Bd=4TEA?StbeZpCt9UyN;{K~Zo0L{c7&N`#Xt0Ana1XshLBRmpoG=9+;Dx*03 zxYw>FI0^kH9OIfxG0_x2lF5N==tz@lF)O?Hmt4@>DB!x(sMYV98%Vs4=0ve|_*qn? z34Cqzrsj|ri+g+nOj3vHJZu?tBB0H#=ZTfGUepR8B^E;h{2x-`eA}LDS!tBjQXxAI^D?FA^ zahR_{OzaiQHMdi|IN;G3#%LT1Sj7J8}s8kF6U{qkSdS>3y8pB)gA zm67d9Fj9v=N!F1UFRqU7TE&**=9M;CY1if0BXT(K7r3fyP|j2HzCwb`?@p0q8jqw4CA?xU9LSDU&Uh9v)RltG6c>?A z_xD?|+9BHV$<{l&{tVi&YXlZC`*?5`X6v-StvR}5G2a;tJbaKMT~R5kJP#Wl0e6Yg zL=A3tW|^5_jsj=i#%y)c=10#HHCQ+tIBhZ0jNo^s9BG~-oV`F(<~$Mr<~7KHakP+e zfxSIykwqdUY9yL=gh&TEV5c4-;D5-|p>)`=+WBzw+-x*WBW0cHXZL7D2l&XK483Kx zp4)z9wV&xIv5N)Y^sXr2HU42l1n7txm*7U@*M!+f0Wj62AHILLO{m5R+J#I~+U^%x za2Mf_0Bp+MAOQK!>V;uGlN3K=4em~&EioujJ{O7xiofl(`&6kx1a9S2XH$Au{zR)J zdt&w1Y_UvW@FjGnLSiP>{{SoD0FY zdL~XU7)#5z#;IiJ;c}3+gE|*jPkEXbqOi1;WH|Cw+XmyLbl@Nxh^4vTP#X=CN%5-) zq?m6BsxzJEe9)x58KqmQ5504C<4}H*5K|!=5CNH|y! zH;WlC9yn-4Uyt4F$_x-SQV)i+Cs|2^q`+|mv21{dWPKtUK#*2S7AF9%-|Lpodtmktsf2?0}?P2$pD{PGQe7`VndL0+jJ@lkT0e}x+y(t)yjp7zu6<8 zeDRS1?7k9rH}X(AI)6B5dDE?w94nzygL=S1DN3o})IFLjoB1U(Fc3YIM#U2X!HOE6 z_8jJYTq@s;73K5{LWi}UlFJ(mKtK#u@&I}%xNO$7-lCsQl@5K58t5PYZM%4=#4c#B z`@}iZTP0QE&}~Pz1lr(=)?YTVk6svd4ULAEc}5Os5B}!VWUWU#z%7S@YKJlp+)`J?N_|utMaF|g26BI`)LC`!(P@8QV&@klZ&ht-EQ;#|Am?+D59a`;+ zIA(T6R}a_924;iK%L-qnU{Cgb(67Z|jvA2VIlSZFYc?q*|JoVsh%Q1Z5w?Dbpe<&TzX!zkCt zLqPc>^o6VS3~|Ex>i*!59bLw+>eBO#%OwtqlGRHOkwxiln0vo!fMe@xM8o1F20Ddf z*7>x+5AtNQ@4GDE>s#UE22#=lY&q715TM+Ke6>c^2tOk!d36N_%yG3D| zf&9?_BE!46`;Ot0Lv*@n4HzMBL2BC!22{?5k-3-Z@Y$=RH}8;4s_nD>_9x9M6RQ85 z67*;+X$@Gw0tKRwdBL!pm1b48f{=&R!HtV@fZLbPXca=~rQ8YJGOVICgGzOayHo5R z?iC)`IYcKpj8JaaCjAxRaws>RdM`}!F*0O|MH{|VvwTiNcDOP~>q?8`V z>AlqZP?jeoyPO;{fcT6QIr&>h5hcq=7~LEV99U_00_j9;duJ3>aH?PjY>YVMyXNL$ zaRZd@1ux{zLzIksnki{U7sQ+2A$Q_G7FlRZ4P#3$n={?LMb@Hiy^g91oa^qk0pr_%KlmrQYlQSWu4mJ?w}Rdbl*XhGl6wFaq{y zvxn6LrF_6NGm7j;7_`pMRle?Xe16qx&~ZIiMLa^n<#4+IQp+O?l+Yj_yAZsf6MeqR zMRI32yEPgg_)*&RSA?vZa3ZTk7$N4S+!i*t2Q8Je1RSvkq{!$aHFc3Uhp*IN2<>|8`r?|0%ERYITd$bjf)hJAT~%;b|IV;x!~5Ja zwZoeEzR7Eo#oI@#L~+&2!!U3rs4I~E7~4VE#zLMNsA2*RiDa6Uh;qzaq?u93tHntG z(u&Fy{RsB%$VimU2>vyoNDXa>5J^Un9r~Jroq`D;W61&)EHg_2O1kr|B0?R-6X&jo zGM}1_zQJ0;-QZm8Ea!Li=b2riv>tGrl|NKP_$QOzy#}DDGcN9@fx9V|1{D?DlTz>6 zY}o(9pVj&+klAvhM$-3Vj7Bwkna@}kM)Vbi(z*lK1wHmgtcfmcjQ!rzp`^6=hGK`) z7nW`LKLE#V8!vL^L-xKyQpWAZh0s6B#pmlYeN}l6+$4v@Sn)K1{gt-nQ@X-|{QgVpg4)zs?tK(cjs(l{kHb*hGItqnLxqbtly>=+ zCFIjc#ngtC-M|qbsHq#ioWM%ct?h^Q7>)q>>Ek-=wL3wHSBvg8Uzcs%&YVbjgBuDw` zk&JD*xzt606Ad@dOO27E*o(iXoII?vEhzoJ=SQb)mVxC&vZ7$7d~-V`nKj$4MZv-p zgVh6r0l)c%4zJ4BtKZiNBshN3eF!eVTZ1S{roU@(if!)QT?=|$)1c$sGD&|-CaKJ= z!&|d^HgjV5+Yhx4ikJFCymlI{7_dq1VsApZCaqBQsra+d<++&j`HI`c7dDqyb;RLM zM2dZ*WnG@>I&!SC*Zi!MVoVuB$x|q|kV#p~Who8l9G8eP!59`vu zR*_Sh$H?3Z^Ng7mJz^<6;+P(Z`X0&Ot;yXzQrkV!mpxB`a@|&V7ml63JjxMry>hy} z@;1E+P7fR%dW~JV6fwQZ^}Q;+z2cbiQ@0+~%U%s&pC(zKRsa=W3eT%$o-melel8$i zIx!tKINgGV90IgScvhc|We@a4DJ9f(rq~Rnc@CwA3}vJYWnzZ1fX%Y?LpifUx!XhU zFNZz=hta}Asbv4{B$aJfdNDS`MgQp}#Ym-ZI+g6?*k5-RmCO!T&?&e{bpJR*BI+f@Orei`zz$Mu^6kp{Tp8kt*RiGgu`fA1PK1h&pY>_5j+JK%-mQ?0Z?KMU3XgBesoyM) zZrhCSdXDdr4bzr~jDN$7A8g;>b_O(U_mBQ2KEf!TJW2RkKb}TME`!oER2{`jodg$7 zLK-Fs`X&jz;+LIz3nDZhfTp12QzUFtFcGa@(g{+%DRSE>nACfU&?&0ash@`9)D2U# zeN%r~M(K8@7_O%7`v^F>HclUY4c#=0{4^`C_^IwRyVo>F=rm{Q^k3CVn%rrgzUhZ- z4+mEI)@P?5fo5E6rmao3r};%@gyd(0^=9rwY-cXFCPhPM#2O^N0Tm}wW+eM&Bz3c1 zg=9Zw%M?6{qZyu&VVg}z9>3+9mD8J*x1Ck+npIq7kdz-G3!PPNm<{mPa=Fw=WKura zpH%_PY5w6+CQt9WU?fXFVZbIzU&mv*uT@@>nv=@QX_GGoix@hu%)ZlG4E5>{u}wdF zNlJ!#E7B2rwHOatN+36K^IH6KfAlNAl<{!s3ci@)wUjRZY;R;SHFPPfVJW+BDQ9je zcgHBmap}4~9JRlc$2&%Rq*1`OjL}<8XJtZ~E*I-f7q&@cIjbW=mp}5V^2{*|w-Fah z_0ab%SI-S3_dP4)d0Hv*;c4W@i^Z%ok}plDLcRL-c!74R>X0v=`b?-Ur`vYEu*kQB zJ&C)0ru(BB)N%Nw#|&XkB`ktjw!^OUqa*K4H5K(L9%?Xk{k=3Yw=zcFH3}l0Ag7)L zX|?XGOiOjg(X&g4tj^i4&U>vckRzt_pi7~v%MGhv`c_xyyZh%>kJrcV)2HuXDom5gez)GBO530=+UO+Qpy}VBo8O?vUD(~(V8q>oQ*1J^Z!(K+vM6jO zU2d?x*kt$KaI{N6GPXxreqwu)Zed_=J&z`k|xz*?T(^e(i1 zS@6Y{B+w+~$(CrEHNNhOc;l9zsrAjwvgGborOv8>$aG88R;7QawCFZLOykdlFGq=y zPj){(EP9HRwvqMvqSUyp^8TY#+-E7XezyK?jj+Ln)&6$Sjy6hNL4)rD?%&Uh)HXWO z+KR8{%n?M>pdxcBYGVp^L+>3kK7{_(jydkG1;vXQxgAT(>Iz_I>VlNo>yo#QBzSw&mH}m2` z#iwY`uW`@6e=lGjwfAOsugA98MS(a_AMu)<*z|bs9Yv;mw(z^-$4VlC;S^9Wh5ft8 zd4${J_w?F%qH{S<>9~9uOac8zCxP42hL)l}(929iEo&B_4!RZ=~ zJt(hyn((4_&^@7e{-CsEM)Z3$mgDuh-LL9St0IQ!%h-B#vRk`@y2kHD!rwm4 zfB#(m{ZGjO>*4oy3McZE?_HvYI9}g-^bh-99QJ!34ul;JrX8L=JRE8~9O-uod~i6r zdpLf5IMES5hI@pXA{uL8KbmY`)DD|S*Y;aVv+NcaeTcNp68`Us|t#D zwQ*D_yGQF3n&?*#zET`-UCXUv7(!n_cNLEJ!jAXZEi*gOw!`+{`j5MRn_MCf53i4< zw+~JzT=Yeb&qPnq!bcbSC+TEISKcmOPA8|$$3Kc(P%$T{+kThWmBY62lfQ2VhW|DK zn)CrbPOu-J;<#8D;zW}goo+unHETb~Bj4v$5rc$h&VQ8s`}XvX)io*d^v8{>TMFA$ zatq8Q@yhVv_}gi$Z)LMi0{Pcds`#SN0=NF8Gg@LL(hm*vFHiY(o$2u0Kb5)B?6pD% zPT?AZ$Lvq}QI_m1Cv9!qCnTd!OidcjsN5J9&MpEicfT-DrAvic?&ucepCZpl^shwOw5- ze(6ts6jO2D=xdE10mSuh81bJmgFZ^t;X^~;cE-^qGI7$x^V#DKF2)6MfYzUT->ul} zj$NPQo^3}}5$_!nYYm7{UfbRn#do;D!_RcYJQn}9xCBd{UsO-`mXCE`v=?sd3pTNB zA`|aqNHSUAyoanWbQ>{O^w5@%hl-o0xl|KvI&VZ?9OZ?F z$C*gxMR_%O$;7Rv$|(QqMpce@R*45=)}O??qWv^WG?d;8`aYDQ7tshOZ;)sjmJO~o zqIQYD`B>p;zv=&-dp|iN__Q+K<*#G3{l!C;AGJ%zf2c0gk=X9lc;!SN_`QABrE@SE z@hGiRysD#L`Z3;ed!7Yqv~TrYtRA!4cUfM`Q#0RoOkZcIo`vI9tkyqqpC zJud_X(UErdioq<$&_X1viOgKqlli#Gh>%4$b{IgJskTg!y-+0r0xPGR)TOix%)rS< zKB3A(&sQ_x$>nkv9?TV!g{H$l5=o$zYo%H%C_DsmZ-ozM?g%xEsQhbqdvnekJr4=a z54u|(NUyE_SQ|BoH#^FeLenc{**LsvY5zf~wl$ZVyi9OU(hF7`F8xaNiv|MLFR&A9 zhtk?n*RL7`IfHUP`!y~nKEA8c^TJt@9VWwc?rHGYxE5-XmDgLnmjYDUu2Z0_>o?=2 z8eN%#WFNvNgE5n`)h zp;Y^uv>xmiJ4F<66r+_)wk-)jDoFmYX^zAemY0F#UIXljxkxy-tpk|vCz0eazn?)c z2mn-`_geLFCp#5;3kVb3!TI6Q6}%PWm5x>SfEMMqnA@g&?)ZZBZz0IKcd8dsBLd>k z4fL9`qOJ|l`$Q=jDlcjHs!-eZ+KFkE^myq^5)HsXn;Mdv}+9=Mm042u90$f%Ijly)dK?LlxfGi5&-a;omf1h zdU>f_99N3%ip6djLon8cTT$FIzLu1+VQI;uhA(d=k6XTXNS?I+`YUYp2}NulBTv|&=kx0ggPEHTrnaL z6e~od(8sDILhyZ)yaictck;K!Il0$PT1*A|c+lM90jcwv__x1V?(Nk7;jwI9SW(hH z0T;fFC?~(BE9k-MRq~EgjbZ);d)jL=EwubG@@K}oOwB1a-#zA06#G?DmhSn3N97o$ zPZBNZcQ>)JkE(KpIKEBevJGl;dT4=RP&5N8>$+sHBWS>oqRHD+ohh5Jh|?ht?Bj)e z<^6qwN;otGoI#{WFSN`n+EuG9%*Vz@kcHjtF5J~X2$o}cCc9ji&_FeYG>%pR@->11 zRQSP;wRn`KMWw7L__|X{lD4!#BQzk9l__h|LL!%o)M?&Y1iX<`t8JginWN_Vz59L~ z;I&Ft`ElM6QbUz5wKP}06a8&=g{s_rsFf86)`H_Lb58dLc>}Rd#7oFB>=WRZP?Q^q zE-hdrGO;IBl%ASt|AM;|dV#fT(u`4I<@FfuZJQHLMi5oq3{zuqzVN9XWi#8} zp=_UKeK?+{#QQB6*tRL30;Vc|q;^TAeZW@d3wCPh)ENaucGJRt z^TQvB9HFtfV4(u~cNxdGR_}&he@S% z09+3X)Ru&v4Ilim>J~gQPC~2XplPv9# z0&cD5Ave`<*&IS9l0yXiY$prxa%k=6NV4O#f#6TJREH%&%lV`vAgjz19JB$K7zKYD z3T~lESr^VDT9f!*pEYadVKRTyc35?vI@?B#BDH5X#-z=~^m;Kk4>TL)$&4UvNZ{64 zn}MW+8L`VWgi0b-KR-S5KwxRg(ylOt!a1?D$c%iXe?nvIX~KzvmuV-&96XCB+*kuC1J5De!Np3NI>hg zbdT}L(Gy5w|42#WC7b5?`At7UANP6Kml9UJ0&T~%gfls%c{?49d{We$|F8+Z4iC4K z=$s9wP}>F(O01WK?<`WMZ8KnW%aB}>RMLI>R$ z`}CJB{gQ@Gx<6C8WAE530=dMr=y5`S-0%0gyoM&M<^Dsf5Znql%zd25 zvudz`2dVnQKGvo4XL*;^`(fP`sON!!D_IY3BAk{;Kz;dOxo=I603FPHcXcHX;ZBAv z**74LX)qTJ4|l{@cHUxKOd#v5RRr%_GBqUDZ3K5nKgDOMV~?s|s}~(VI_}%?E+h%@f{UNYxtb z)H98oA=q8>vu!;&vcP4mzExGR1&o(L7O?+mVZLm=h7a}NQl>l;0ctmnUQufk|lX7V|N8IVvCJ(hOxHhp+-{) za#2!)1k}y4)QE+7wNBMjFtzjMK(nD_BQ_6$Nr@{2nDZRE6wCw1$ex1pQy7@uvLeI& zt7EY;;xhJX0hQzCe)HmrcNSPUB{5&WZ(o$Nut1qY1HW4C+o;uVn-dbYJOG;Q6RXwB zgo}N(bYe~~+_b1r;m(^^)(;!R0$E^GNFQ|SHIEj-NjpP~l$-IwPkX0A-lxC=S>;l@ z=VH)qA&SngLjKUjqTDrZl)_Deo&no+#gqwFqR_9eq%STE4_a_?xO$xXGOw z0as@g4aqi?5tar!R%v!??K)YRNMl9 zvkRX$ktnb57Fh56bb`yDh!Fh_!L zP-ZV340X5Tab3(mESYfXmWr|RL9v(R^Klo5mK8{s@myJQ&?r#f3>RwRnsef$Y*|C-~5-AW-@GH>QYANc_#TOrY4RR*JuhRG>lE3ANA9X9 z>dy6yvUFsOBg|61JyEM2n^q8Yiay-gFTpq1S=U+(?ed|ll8*gVh?IneE8nb!FEP%W zg4zU0^!XfXvYh%{850pfNMpuWT&tDs8v5E4ceH?#C8LG`E7nd@!#?NF`e;?ssE8 zA@<4_t`y(Zw@oXSUT+^e>|kzn3o|xAh7GotsYB+}MqTB-c@+vua5>=UvTYOyX~-+p zwWV4nr+9_7Yal>5&i7u9lRvD;R^xVN2|an(&3fhj$bVg1M|9^z|BErzQskjrgPx1Y zt@HA%+>v_a-{VJj*H!nkodCS*S=`g|gi3QNSCiBV?Cc}I+2KcN-4w1}a{zF#D6XFw z3)ahOTo-=4*VT9^eteeiMui^=P0`?J(4K7mkq@64IqVNA>STdApOE0B408;4jh}sV z>`*#!ORn>x^VJ5oob|9`lEu&i$2-_UuYt9ktu8~2Ex-+5v4%`TfZU*l+`Q%^{B%Y7 z$axe)`mFre-pwjPz{;uer?&vrKtH7tMUwMMNSzxVMO06ZnDpS!pQO!B$hnfazTz$+}MApi72<`J((9Ezcrux%FDsz> zO%f~RE{0XAqLj_4E5$0DLy1d5FDpPBpyDJyHbHG~oRZtTAmc{2*dn++bWKVnSL!10U<# zzEp1#FMJZa;9yMz3Dr!~LXsJ<@)@rrQEYn<}MeY8b{6HQyv z|7;Y}pO&QbCtr}484QnHXo408*^Yf)HJc&hDH&@tRZ~dV^+uOZfMM*+v#f7Uki-?1 z>Gs%m%{OAhC3DzuY2D!+&PU$m&g(xLh0kr8B2$PR1y2{N}H<^e~889O^}{%dukI?qs9>oYe3=r4ecAfG+zdHv4}bzFlO6XK8Mdh||g8 zY&D*fZ~Y~2URPTDx3+ZKF>WqSZdSezY^#InuP|oT0cLloXmjX)bGdG4yv0#O6J9i< zsp>7~Z2Q}N;8*=!u!HVBLIPMf?({L2a-%?Jql9wP_u``4b4DTPW9-l;2EESe=%Jap z_?9+Y%TMUWDX48E&7@Z5e%hRFsu`@j4lNsgD)hRe?n>2l0JhA^bl=Y@(A5wQyEO79 zQzkrc^7};0n}|Yp!cL-%`Ktbl9GGhAZX zFeYRSe!e&_qEA3aYg!Jm!DDmac1rX>R@%4ID*I-1n1M)6<%5sIf*QjQi$qrkM~NEv z?G-J-$?1Ae#uPFkKa8^v`E{9m9n-|cnMptjSxG48MafEQGu5Pok=CcTR)CDfOl@d0 z!?@G(tK~m`u=gjvPb}QNdJ(y1(fiPM$v!}{d$qzr%<4XxeqLoVT198Ucje9_r>~y| z<_12DsNkY$$9J6`2);LSA7_!iArAUFgeh#XOB&N9J+e9&+-`oPw;uk%VQ4|%#@vmG zuNmZl)xeK>e+~A5TOT#+LrxPDetA~553i{mZvVIS_>?l&2r51VQ#YJ0`m2>8d5%3D zNTqni^hg%H#sDhK&?97|7rYoP65Eto{UA2&&F_=wsJzK$C-j_+cY7j8}drt5j}ly+|TKU--zdP1$hRwjk|^*#VZo|(gY znokREzcinXmEJJ6af}f+(u^-rF12894Q4eXRb`Rc2h!E>&9J;h01%muFEZxT#i=O) zV5OPWH;FpeSN&wUh2PfzWy!xPGL*Uyr3NUq(2J?AfWmNx7;ku^IFdvHy%f`_ldc7UtlnWKeRAH&hY;r_4ZC zjiF{4UF2H=6Px^WPO5ySN(%zHiVBY9?@4QLbDsWQrQ zzt|d8sqj#N$2m8va)2WzFm{Qm@aJHOst5dR5@28s(@*6c7Iu0tL?(Dw00BkvjIlB; zyYmQu(xm8fsH5*rA#JZNTeo+O52a?R`qwlh)_&r!eduy&RtcUbh1dCE`_vnG}~dPz`p zK&;&pMQ;5EBE=nYP%R|lFuEhcZT`^BX2Q%-(Hns9AWagx5UAskH86^d=m1xPDQuV{ z0+VRrmap?SVha72DiJt)UyPq()RF^Q`|iGdHMPmRn-Ks{PK=v_!A-zz#);D@=iCa{ zk(-%zzz-!Y@wskDQ zu)tg5S|d!&w0z{?`+(0~Hz**FhWMO1N{ra=5Bk%1daGPY%Z4gCw04J|KcnJs?JQgl z$~_UJ&RTz4Vo%D113HTGVgTJa0xiGD0)c>cG+Sjm4|=o&{_;mVrn7{blSIpk9$>!aWp7~G z!eZVahSJccd;*v^vCbR;XQZXQuzs|kDX}E0!;36r;jsHk9Vg^LQwt`|aBPDa9U`|L z;1TZxmt#kDKT~Cyv&HFlHEL^q+QcB&wUtLRFt^nj)3Z*-Gd3$SO1f@yukMY^v+>HK z?!_|mLEh=~GsC2NU(vR@32Fh_@i?=&JFqwvqo|s*EL-v+7SyFa%(_7)$GeQ{=k6Dl z5@O>Ep&^OMXzJLN_enybSumg?&5ytjHG(x1{(}03lV?_M!>_@qGv;uQhF#mb6|#jO z?Toj9S}eWQnZ!WHYcMo>vchAE?BM(_96LrjJmEs4yNy~mfuY;V7WdX~bs~6OmIk=U zV;KRk`W~BHA&AMBiKG~>EG2yalH9@1+;W zNr2Bj^wNC7ye_fy5I5-7aoXVl!3AUedTwBL*X33HX0>X~%+VhJy4<|a7eZqLpXDG# zPp!}UhP=z67=WcxQqLhw3QPcsEH!{we-O+C_S2(Hg>9?@16BU+zL%EIvU=Vl81_X2 zgRt~o6!O`<`vB`;NBo^$%TEu;e(bTg&)z0DjcpqCT|hfBu+F3}CqhHBXNu9$iLcBJU8=w+mcZ?|?qzK5+Uch_4i+HRTr>6moG&S&f{P zFa7hOE4THeLaNBe`2OSfMd-7WN)3Or$s%8xq9I>W*&Pi-H6uNNZ{sG=L=OwgFWstrl6 zmh}-%#bd2Jhls?y_ji2E>iB%0htfq=44-EIvH1Y6%~`p*fg{^sMYEosGACQUuYF`- z+()v!SDrtgbN3AehIwc!j+7%NZ3wiJa2W4-5Y2 z*SpP3q~+6CaV#_|Eg^Q3DOB@_O)8W&e{l$-3evXLt_r*4yjMzJw4)>cdnU6uv_hh2 zH?QuMX(oGr_0-#mSoC<_!y8$Ws?z*gyG4bnGKvfoA_#Ju;cSWF7=6%Ub>&6yRU#wM zNm>d=5>5*>U!lq9idMx#se}T;OT=WRc2o{5 z^uQ%aa;48-d40I^Ej7z5t?q=Tyo1W$%*_NIww-grAFl=3atLEx4n4U4^zzgBzw znB(-Pe;)5zBvswwiOt|4p9;?kDSMkWX@M&k>W*--=;Q zOV!$QqfpRW8bBJYhw2V6_cslea;i~$=#l-wAAa(GbD{ z(~g9AMmdduk=hk1?18(xDH!GG4~X$#5~MY^)wbG`z}-cV2?GU20pvu~eBH>K)?Usd z3KIPOj3quvG_!SFf=M=rNwI-R^_WRR3Z`Qgg#GHc-_MHby8{>0*zXsY{LCsie0Q0= z_9hlxeHieGO_L^rhR1vcPJO3hxtjUu#NFjSu2~vz&lVi3t!qESlE24QmXK-R$3>c^ zX6H-Zu96zF-^so^E*`^ucYTQO33iA)9cJxIyWDq!izg39c}z!`z%tbSyDLSFx~ae6 zXio+vFPVVecVPVd&FCHzu!0vB=Zni^cREHFp{|mV{ahQAZqxUgLp{Xj8khyg;4{?R zk2001H95OyxO1zSamOqu4Uf(E>sy&wqGfJ4GuKrE1ovl!6-sfKtGJ((?p}py`8qf5 z_&VpTB>e5~Km3=v4T7mZ1T&?Da)X5m8ikZI6?!30U^+~UWw#wAg=+^aei*?!q=nH9 zJUfUw;a!3Vvkv1ADoAsQ2+>`6HA%^__GungXINUeGx+RKn`tNbyY==1xmX%fS6M9K zi*Uzm6#0-?r}pjNRl*a`L{90@Kt?R#4p5+%*MLrJK6vTSYw2qvPlmR~&1^iSuC!1{ zMH!>k=f(zTxrhXlpvo4vQ=s%CUkUVXaN7D^6}cu^DK9yR_eP=FG;(g_PDsO+P z%-m9)4^jQ_KAA09;Wy9d3r6beylPR?5@m*lKP*T`A$mtI)l=QoSl!jb65i#bvB>LR zte`OJI6sY$Y9rcaB)yFE4+}MJ8SW=z8rLTodh3=3FKmK5M+QKpXGftoz&q+4%I#q( z^%%zI@fP*>mLBUfy~?U2sDXqVI@GYE*|7V>uusNl=wvoW2BqGt#rMpJ%1CO?QnS($ zwK!)yV`=iP*@W|%@ixq4ZO-KSoDu1k$M%=RI zd|~E7Emumgc`}H*946IoE{#33-qX7#duO}YQi?DuQNLqB|HaZOj9V@YZRKUmZ6Hg4 zez7blv-nGnKBCgoVdlGQ<^M0;%*Inis>L?-)HYMrE;r1spvA76CS$IFg;YIZUDRS< zdurbx>(CPBkR7ZqFz%3W)sND0)E_V!b!8gGz-Ggo=36e>Pn|YoowvfAcUznf!piKK zZS1|JelTkqAGmOxI)lOmT_LTm*UnrS7HrIAZ`p?3`fX)%^URHNG!P+oD<%xZyJfYV z_Pt8htxepBCKX6LiKMItK9bQEY^#c*%pgNQt;}vdA$#L)hCa1>bLvDobyGj{L@!CX zo%xf_g46>o)z72eoIfI+$Esi8(HDg)ECkX;ruwg3^nixjuOiZ~5$|)eHO`~22ePeA z6>HZ0WWPJs?lyV6?{pcu$Z+iX?9~^!|7op#fNB~N~bqtWjdp#D+0(8OBS~t7?M;+S0!L0wQ z4(;TrvgeY~ut3VZ^}7V0nbS`ZGtKDWk~k?fp{*QCYoD!L+q>0Ud5+<`d>F(dp|ANK z|MiCW$DIGy8yXf8c;Gfg+%TidpqSB3BuY(z^@&G@z^nbm#l%xh-++VDMc0Jx?ALoq zSN4#1W$oFrU)w>eu?Fq*>BrDTu~u04_LV*4>`M|0&yV~mcCn6Q{eSHtv0US-wG=(^ zUz^un(kO1_d)0Ng0^s$(wo4MO?4gJ${YvC>bJih$jqa9dv{2FZg}fPyL`QpD1CU}l=0ufaEIOTn^ioeae39A<4Iw1kFP zw~eJq6G?dm36bH)@Jbcxbg5O^0tPclSCEkEuE2jFApv7zrU~

d)UA?O=<8TCxynL)@#Gnv)AX{*Eu8}`nTgHUqtj6UcJJ| z{2uzg{JjTcj!|gyEghg_$h33UsWpgQqg1}vM2AkR$5DihKm?j-FK-1Tm|Hdzd6p{XG-j-)hK6Fxup~ z@oN+n@GU+2fAk>p6{(!cNj0G>J!s28I&YVGQ!M(|MB?D#u)<_l`~29rD+|Zt_mf}I z-@*z=c2_HIPpkyj74H;imB$`d`aR6HZ-%~0X$kpCF1Dkomok|%Bl-NZ-;S>GYr}t! zn#%<^W!No#&njxlJeOBhYJ}itzR${3>KX5uPvpTUs$NywbfeTnCYftuxU23k0?kH# zFSJj!)J44g=JfIPVlT@|P280pbmMZV1nqFEk@!f?^H%5LR7`7g~aJjq`Fj3g_ zw$(f2^5c5EUHO1x5UpZ?L1B`*&9H+1bWYeU7;+{pNEmhHOU)#YVx(;#eCxi=vPCR>aXdw$(@3^kXUcn2~_Inq>%gv`l}x4 zQmvqy6lRc492VI;UH^V}-#;`Ot8({muF%HMHO)Us1ogic61NXhUS&+a8zKElu={C) zNH2fD))09_gz4Bo=g*neHDaI|?$_8bgEAQ00Gd2D_TRHu*pf8IC==HX1GhKJYAZ%R z0}zvqQS0MeAp?qm0pBB_!b`De8QVCyd=-R7IK{e;o|^@NK|Vd`Zg94U(X@z&n>awxB!Ev`W(4n`i|CKXHE@Iht(1R~>oMsZ#W z0J9}VbN>IhJ!i;-aNC5irK@jatTRep(c9dFjbPf2o0z8#TcT@4V+aYTgj#s~Bhm$d zvLENGkSGXd2EY>%y%Uo16CW($GYAOD(IBis5)RC)?@D-*pL8%ycsgptd%-1DCQeD< zNK;_*NQlqMPxilv7m76<0s-^e3H{zQ<@E`*ktsJdlb+cUqN$>Y%R-l{*iIv$kLBcvOdekd3;;Gpx2$9hO!B=d+_xbNe=H!= z6|dC}&2LZq=S@?@mO3kw3M~ajVpAD|6#Wx2!ZUG(k@0bCghlT(*eOAT*SOhyE(Z+FPW7 zH{woF3`|sLjfE`wnTHAy$x#vUsNQ&^rJU;uP|*f{ky1ubX&m`R@`@Eyla1G25dbcM zwDi(3PMh>9BqS3Eg|;c*1hX)1Ijp#c(!l{h7WTkhsJ|VP$K5Q3(i8?r9)H7Kw(3ld zM065YbzVVx62ixhA<)?+oMtENMsNYtw=tig56|Wnqp_iY^UN4H%>`vBn@WZ$zrL3qAVEFC@?Wx zGcjPP0%nE+>mZ_`l+9xy=uQE{;dKTkQ%1f)B2%9$YzF7yNq=R%1kO1#_s z-oNP2OsV&HK((gyjK66P>hTmB)Nf25#QD&8=uw|tnN-LG#${@QX$K@6gl4e6B4J8i zBtkiNQxdLcnfW;O=~=E;vSdjzMCrr2E}lt~l~6OtH%3B#oiWrVuu5q@rFF}oFU0t0KOsU6@TMx-qTl9Fuwu^p1W z1w(iOtqg4ki3@^STn{K^y%Kp=@wex)zJsM25~+XrM^ zp+3+TMZ;B4uj<4y`vgCT^Z7Pv|4yoIP=Klu@6j%!@^CRtPz5{pW3ibRp))vrD4D5X zrDp-aY?_|k$F=;3zKRl9V@h9h1h+4#j@zQgZWW@g@b~&mVYEisnW;@2C?l*D^3OEo zdVc1@I8^pQKA<39LAW?syLkFatkQkS=T;0PdqvSup?976mEcRZ% zme7X_Trb9<2d=NEj?;Xp;|`TTXY17FM}-mgG*foIfO^nDF9?9FZN=Z zUr!5D4f2psG~kJAP4JgG#*BKA7r^*-Pz&i!VS!}c##s8Q;)^HkO&d61JG(HTP`k2{ zi|QrQfNmgEqv8(Pm7w6Z7sbGC2ffsX#tzV3bF+|ge`&kdOnn?J$z4Ge6!7;o!#)E+ zA5NwZD5O}YX6h}WdjSf0205Fiu2O_PE2$xuqV6d{*9vRYE>Hszk0aBsbSpp`g(Qej z1qhR?JwejocQfIM1=*pIdrDz6opIN-KceF<3`$3%D5P)Jfw3q#2;ne$3zXuhPD#E(xMN-D^au5pTWG(kxO?>f0BbzgFN z$XUy6`s&0o+nCARR8ol4eLb*DU9}Xb^KXVX0s~x`PoXoWrwOltvCko9j@BymwjZJQ zN2tpTp^!h&=qNKtc5Z5!|wBlAs0k3=G&<&PMC^S{(F9ZmpUqTYw z6Ef_}hJ3oEZvY72=$HbVz$9ojG8y{Gt|!BBAoY(i`9~o`TC16HX6pJ#8b@fO0W~=u zs$U2#Re=`$VRBP}?kQrf!q75u*SeaJN{*Vl1Rc?U@#%NEgS9o?450MK_RTN#xR0x;+7~=wI9KN?2ZiC6sN`zEVhF@ zGrRWT)Gdv4%FHQ9gL~EU0pxF}S02EAg4(YW1kUS59u2Q+s3dbd^MLYZxe`Kuvkm&v z;H;_>6>yO?j0qpdT56jLhp@V}HTE<0hLyU?tgj#(aLZ#1Q5 zugc@pZoN!o@9DYUK}eOLdjH<^L?Xwj2?X~`&s^ZlgQ|kWLdlVet&iL z=g{o?@Xl0*PovDV(Vub@($g{-iKEn;m+535lzURwc4pU{gVG+2utE>x848w76*^Zw zF*QA^YxbX~TeU)`F(o_eecI}NJN}oU^zX`y7*5nISE-|eO`}L&=+lkA&^Cqi52?ul z?Vq~3p^az+cdqoWv;dbv|4Vr*|1FG=Bmo*U=6qX zx8-{CtGn7)59hA|QE%@Sf4$f9b#Z3hso`V5)Yo9O?NBu#U$j3oNSsmHY5QRhRQ!O) zRD~h=H8kntcD(aVTT%~4^}`+YwjFfP_PaiAqlq0;ss;a`df%e$oSR_{6uXtj+wnoO zg=%}{&U?>e_OS2s!!dg`t9y0-_UZ%t(r$ihcK+5H^R2!3+v}chU8~=E{(bAcx!XAtNw?F56un=>wR4i>>d~hyvAa9oYc6!ZiluEHGD2;ICv6!c%h~bYHOvkXYo= zC+a$Hj7TR*nLp37PE1`+tWzHyBPZ>K7W^r#SS5ywomsCSNz@BJB-R z4C$iPc|40NppNC&YTwV?H@{lR_MSX-%I6ZKBK?|wa#W7~&1rU+!|=OG!hmM&47u`K zw)E`At>0?=m$SHs?x~kuYnMG}SdzL-Jn6Efxm<@g96az;ry2cw>^2cDNO6Oi^mm=> z-)HrIn=b#pqW|qY{kQk#-#*b)kfNAEDSM81EJ$HNQFjAH%xI(F0jePY$HclODps0D z7n8utpCsbGIDnWgL?H;$|KX|4MEt(Z#e&W(0|ak2=y9Qe(kN~(nlhyy@!Hb*d_<2% zeSxR@!es3RK0Zu>Qopwi#M(9+4JCv^w*cWBE1Qtt)yZv#DOv zWm!t}>M~IA-TDTUG1Z{rYc^ddmuV)It1x*so9S`kEohUh-gws$tw6rOxsXTHEiC-t2=DbB|$a( zgI8sDc)nt*YQ(3A?e2&*G*wTR+52lmr*bBFw0g|1?aDVXm)}R!(SOe)3S3qRtc(D= zxK0b^c+4LySMs=pdKBIumcE>mf?Gj@+PYrnUZX3eY`X>%CzRxKQ#7?gK|zx-ZyG0{ z?M7|2lN$}*3>Cp~Z?r*@$F!fH3sW zb3L=X*_>fsPlvt8BFp00+*dFwjd+C&>tepl)oMP`u#SnXnyf2*<~jRcR066ua4_%; z8%=U-d^V}iu_!tp-2Tcxlr8fW{r>5f&ATB>&1?ZCS>l&6iK?Uy(|EbLQhL0g2kah; zWnT!KoZzi1yT^ZvN3ZzYZb?^1$-Lj&O@?b8EK|x{nwT|Lv9sx}b6Aygy=U&Zyx7$l z=Im3lm7y|a4Il0J5F*B8q(nhSSk2Cinao%tEwxkIdr!6#B>iFNY z>a<^IU1*sjKk)Tbc356OFSJhXZ~KRJe&ptht?U8?1mcw)KeMp%?#l!O=U!IiiZ6D& z!v};`C_C+0FTQs9wHwyb>2w&r*trI}A6e>~o{_ZJwf*;YI0K`_N$X3qcMM#UCCLOdZ}~+vCIz}mYy5*(!>Yi&>BBn_k|-Y0eOM( zvR!U;1K+ARDT4?)D=gIM?Kyl?bnzd5xEYt(3Tw#*r3}qFf?u@T^5-@L+O05ms!b1P z|Ejkad;Nu{^kRh4r(P$E)Rn+B*omj50nM28a+U_hDF6!%<`7NM zJHS`w0h?(JR$pr`(A)h~yHdJK``LBL#2rro2eUm2bx$DEJUvW!wsN7e2csM#*g|30+t8r5QS8 zyu&a4n3+XEuvUB#xM2u+0OT({?F6p_U@?rbeXlx%MKoK+XpG zSRmjwf;`g~Rms|&NunQc;2Zrn+QiIr7$l&*ad=mBpbs0^wpl3MlV;zbh`p2Ut<1Zx zOKAF>R2}Pxjz(O?`t&|4XBP0Vy3EA4-}GI!>BIb7xkdKN{%~}M%=L-I`#PR?QjLe> zD_dXL3$VR+7=;;LY;HJd90#(~iVBQ4j_XN`zxSR0o(GN=VE;~2C7~%Ak><-?yMi5p zkmGVgbZpz@45n(;BI;zE;MDUETyvbWqrd3vT+56$L~9pix?k$qbL@8N#VEt-1O~Th z$@)2dd7j6^suSIbqie?I6N5}ZGuujK>Ndcc^|fNjC#K-)XsyHP`hQ{_R?#}NF_~rx z{>hcJrS3*+J-|(C!Uxm!JJ?+q-(G4LCqGcp4y<}Gf4_Xt5Tp+q zHM7U=qg8n}728zs@1lQHf-czX{@tBQXGxo=TqUoLu^4{CSbw=H;E2$9v*_|~a(RtE zxvcBulKN-(^NFL9u5WS~;9j|xgikj*Q@wk#P?(r|aq-CBd_wORY!U#OgE&(EM^V1> zj74JnTB0HCVqD^{p9K2Jqz?ns5l0GmFo2mi0nO%@d@mZzj3XreP=lUmr0-~CBGot* z-vHQ<@yNb1g6heU(%@+CbChOEq$Xm&=e|MGfQ5XlWGC00_}U$fY6ZQQ#EvJthy!(<^#uv!jS^9ER52EH-G#L}l_ zI`~R&khY$J+rS8-Sq`>=y@+E=3=NJByE!D6?xh2( zv?Cvmpq@aYg9kzeN9Ke^W^G3@kDiZGj06BtL7PLs8t{G1(TA3(dB|}0(Nmfy_0B3i z8YeoQqNAkr`r&-IaMVd{EtR8Ihc()TPn%CxbvRHp+;gEiLi{hy{C7IaDQx zA;-sWD3tPMTh~+B1is+Bj;2#f!dx&;yz-cw57H>8Qr*A7l-#arg*7PH1Orjz4sZ#E z8iJrV9_@`k=OqZb;_Xq`0>8<5OJcgYah$iB6$QK*25t$KB8G5d1*K~W5ulQxv-fXF z$t8NY`ca&%O3aw}>pwbFp?a}6!zhk%)*sJ9Az1h_P&&W635EvqBF0>v7?>|%-gDC< zkO2KEkYSYZg5~Xdc9RuLZ_X!;2^m>d3Zv2&mgR@fKeu)D^S>W`T=T|H5d5Sa<<1Y+YB?4oH8wA^_b*f+w0p5 zzG0cPGTSUCn;6yfE*q*`H=r2@?#Hc{J-En8yU;}grVV226I@D|*2pnp@K+mdplC^E z0E)aX)XZUYVv2)cR-f=-+R!0TR&2{?>JPrwF$dv_H*5h#T#eq^pXu{SIJD!fd*QK5 zz^kmcam$>Ee?m?+HvA~z!+YS>MWeNP@1Dx4)O|Z=x(^3dab?KxFED*p52JaTq}bnO ziMGUeJ|dK6&H4M_*&dlgP{u;_csE&1I18K6EFxkMYK1CP^{EdiatDp zyqcm3>25^e@tIS_<}j;I`5W#QA<2UYm5>~zL}rCj&U@W0uDqi;Sa^91|jul0_RnVM3n zf`h{ns%Ej)M{^lg9|=+v1*510?*wHijcD|fc<#IbBZB0h`SHuQE*1n;Sv1(+3f!4B zMFj*v*2|1D%y^&ZX^PoE)~7;SKt67Z4wBg+seoIlvZ1NChq>hCb4I0~u%Toa;MN=h z&SZ|V^`05abX0}}Qpp2mV{ZU}JC0o+mXfbb9&L;7^1S`1qOBJwCCN?axIj>hE(N9bwyV8Ma_*xK$@QzrB^@tCA!pQrKNMEYecSVP8ql5VW4QV_WQ~xJt7Ceo zV~7qQloRA&yGSO8x3U7<6$O75dS~yN@JLn@xZ{!}#u%511y!sR^Trwa<9SYROS@Tq zsBn^O`ecu4Ua#QeWijg6w{mEoE%d()k^3S?1|T8WT%J}LrIc7FF0D5|1d9{y6x1`$ zg?#9Xwu=safmXV?YLN{rJS`;~y^R-i9p$x*y;vzWwk=_{FVk_524l8V)x^9LDl7P= zjXutvti<-33r){jAK_(`SjbyH*87{{rssHyun~MMWU60m#q;Rs5}D~FW9#2Tb%(If zu{c{K$CEewjH_3Xv@IR5Hbb67DB&xfppvMYbyXci-vnzKuXLw&c^Vn(*ISONu60ei znYV*XcHCTY2H-b8f1&x*ZR8HLox@qV9XZU1jT(AqY7ORoN>6Q>s-=PeFe)F?*;_q0 z0w|)F07soDw^FV`JpGQz@Re(>;_RIg@$xh33?>Tp^FQK)&;G9b1yE?Edv>`yLir+Qf5+`)#VU0%@#WJu- z-4kd&+pehzYj57S2)x-ex#4z)yfu{1R7``{C{^X4`j^(m$0+t4s^^_B)64rrF23zw^-9%jrV zL2MGW99~Y#^H>g>ps9daFyIIC8o3Z>d*s*SPAwZmiDfkB$6GC)6qP^p3Id0A0&fS< zdiH?-&FN33v?RLgURajsRA4yZt)zo(NO-gujoMOd}uvXH{N>!NkTobY;k$x zsY(pv)VU|O6I9MteBK>qhW0u`W`#<6rPhWY+^sxgjSWQxjqJ>YM>gjwgOY;vS-Td1 zL#x;z)`+3Hv(QCeA)lad^b65++N)SqbpAP!KQL@0f>vX!i6+row~Lw=;#Dgt9Q_r~ zc9*_#^iB4Ve`8pQLJ5{+Pui$~0PZbzHD5TU?q7R%o+?m>7=7c2-kM9knp=Op$nv6} z&Kl2J8pAF5`n(I`7Nc#FehMTor(_+*aFB~~}@`^B~^k&Z}N{Ry$&L?51 zGk?E3JUimIGx@2H&9AC+{TJuqWyq+PYZ`|4!P8AYoqA5x^djRB5ZlAwc}EwcQ&!EX z`=m#|+bYxl9!1&apzq>(weIwD8hNk*;J2R2^8>Pe1@kxK2pCmXz#x0X#5}r?RY!k za*3EKI87dTpBfJE6vCiLU9EqI0iMAw`_kLsuHnEucuejE`qA}QYAm_OV)YH+Uf&IY zu9_atHD<3|)MZP=-*Wz_59cw@#`i3?wqLq{*;}%Nbb>mI%`NGpvqev41oVrvF+prw z;3X-6PXUMo{;~UsBIIBo1tp`XcLqF*3_vGsG-|qmF>Brb+vw#xfdPQI)i~hwU;6NH zDsr-`H4{}^X$U3o`tFa?fi%AC_TM8qcP5BR%HoKv&(;$&v#FU9ynn3PXvdo zML-IA(*D%aUy%Uv7H@Xbc{mhk#=rcS5k9I%w5`qJKqmPqw|h9~{~A9dLSZ05A~Kqe z0X<`s?Wd82vwVr5ZgqW_h?(rRCWAq9rpNie&rxKiyciKqG=bFIgADoFM6Peg$D%SD zJ+sAZIAr#6=xqBP#LdqitO=FQi5;&n+blMM_Hd^!TePpwrztv!)!Nr09E-W3_M!=n zdgd8FZVNF=w%XqaYH?4FgKDB!Lz(R(iI#al4)kLJlJU-*RuaskS-Q`m9C@;aEL<@k zkA#*E>Oj4sPBTNnIogcmG+Ud@7@kL0kG@soF$N6_9(a=2Grze^B&l|e>1hstx*TEX zTJ&o^+gPe)Is|B*%U9(80kRcnTozTC=ihR&X%FJ20!DcMO5DzjU<3jHjDz@0T`6O9 zhjSheb2Gv=4BIlT>ACc5)zZfEIwm`}zuL#zIYv&vQqWg3?T7nor1k$3vo!Bea=6oKHD{ubB2yJiS!2p2zv8%AY zT7C^i#*+k7n1~atH}FPVt~L8a}zN`dcWI$@zetAmH86EwUzCjceh^Pflrul$K&aM|R~1$x!3aTeQWT{0T_D6~ch zu+eGxC;GaDp1}i+<-G8BMno0SPxP5aoDcTw!Ww2U>NsS>4OdQAOF_yPwRflI7)RNGAF#qlA@7YDL^t9XSi0JK|1X@Vk49bP zxW%xvPUNmz(LNfWmPamP$bU_L58zOx%-~ecgTb+#W|%!vR&U`R(yS>jpRQKaY=#(k zWHNG{F8TA233%{T%@~muvaN0tDS!{4DfTmI5WhZV79SmdZRz-LQMUyzZb7bJcrirh zxbbbdk7_Q-Yd8psTC<9dYTixATJr1TDK|wzS=`}`qJfcY_jEvRbW(H(P#yd*g<3=^> ziO6O0bLe^`J*8%*Y?9W3EcZu?1A8rC&$X&hC8zza%*=9;AFxdD5NIkVfg3Ozg5kTe}i6PnNx2pFGgCZfH_G7MsTPRo!TQtEMB zCiK?Fyy@{m7r$In^;`eorovU)VxXKEjSUj5!)HrA=D%{>wT_z=WWLc0`>dH1wZJ0? zozYfzZQ@$OFep9-+zIKIW{KFjV!#1+*z9JgU79s*D7fR4&j$Le^CdK_e7S_iC6nyi zr4%QkGWr5^OJn^?nOpL=JM2qwjfvIF<$@p~M-9sG9&Kvyya1;KC@Xuv()E%|XbLS; zh7ZSN(h524BDJM>$zIQ~Fq|$?n~l%q+TN=q$}EWmKIl)z1~Ju_mYGl;wG^^xGs@@g z2wYR@2WF5A)NglrYb%o7aM=L|^AuHT4_R1cIOf81hi=B!Bv>c5mN@~Y>9vlRliE)+ z;Yzg6`ut+&$Y_~_fB(Vv)0wv?z|nN-ly)VCq&+(6K)rg12nfY07ms_(6n*pMZB<0$ zZL9ci1lN;q4t{3wd2HS=YE>a;eX|=2y}IR0(78C21N7OTZk6_szamjj$$XSfxcHmyDFpJW;qJ%u&UPF_WKyos| zxj*@SIbT=E6zY`{cQ0RM`B*(LQJ}aOA?aNaKy>og`U+rTa*`tvg$Jv4^*@Gr$~I!w#gYctr1=bg}6Zb6v$WUHETnQYn!44y6KZ`m=2CifEIw5S>Vs zF4X)FmZ@(Br;a*q?-|^GXvgF#^x(Quii(a(Ry|0>Tb>2T-@;W>m1;}!H+5pwrU{+r z*h(?aYYfWSrng^dy!6R3M*)QkTxnA_Hr!YHEm#%s@pfobo|45A{2C40&sAe`QLd zr8<`@b-^Gaaos%WJlYx-9$#cG3b(dr&=b>@3Xv=-|AV-B_mUa%1PKKlQ+^JF)|;jm z)iK)J&Jwxdb z+D$5pg3J_QJR7nO{c~EN3!8{6hceL_jwp^<&t+v6%(D5dBU4J)f76?x`cAxW^P?D3 zAGuUm(sLFM6$BP&PMR7Z{vWdL`=8D5?;n0dLhOh=8@2bUmf8^{_NcA)ELvLCmR1NM zcI>TcYpq(PsJ*w^MQz%mv?!|b^|{6m*M0vD=XoCI@i>n6^Y!XLgupM(3@I8qg&G}X z2YX4>KOZ#|h<>^$)>j}lc2j(|Kzw^FWm?KuvYUwwRl%Sc|8?4_ zymxq%koXvV2H zq~u2g_~{1ymW^l=7NLw+#=Nl$Af)C_ILlcNvl>N@!f|cjC4P0B=xT~4ykF?JTO&H@ z>wb?{&$Qv_@~avK=s;{bvP0-Bj2WQ+)o8P12`-xzZmR5Z9h1|M$#%~h;7*gOz>H)p z<27hwkc*~Xb{&~vLo(LqG6U3i9Id9A(vc*#A=Cq?qP*|wI3qQTlTnnbUxl-l%CZcN zZkVOT)BP3k(_;H8eh0D=2U-C~<`T!OLF>bTf*uxvZ$CoMu!xHQV1u5Xvvko6)xe{f zd6v|ahT>lmt&scVL-2PdLSKGts-wZ>%t z=BqMd*RRb=97Z5Th`|yGhn__92_q@|LSX|VYU>n?6&m8+U5PIe>KVh;yk9q5-s`OI zK4Tpuwo+7&vgIN(6yemSOtD?z3$zDr%-_bJMWelF9T@8kD_<0vmhZgWUIJ8LO^nP+ zUs67*?|FVV4WN(zxCHvL<^aeUl=oi)sJS;OQQe@`#OQ^yJ+U(`-*vD>snfkcGO|wS z8zp2&o3QIdmp62$z3NSB?s{g>!>tS!t)Ml&w7n-33P5vM<6U=ABvrzi0V{)?y=97f zlVR*c4x|z~yTXe11PX)nG*}77Q?kVZJf4lY3-nICK3e?vJbRz3BqLrpGp#%`PdKZj zJnOpwraOdQp0?*pW-l6$;!M8;WVKbnd7`OvpUz_!qL@T`ax_Rn<=r^%`SNP0F%pom z!~tsI*g|yzXJI1JPw9J&C6%Zi^YCcV_pLM&EXX#C6K1>A0fA~kP(x8OnkY^y+DODq zgwW>aEBbIm5`uBJ*9+wW^mooaaN2Oyp2OrWKWWf~9lBt==|p>Q;(Bi_&YV=V!c~m3?sjyM~2!}d)q3oQfz6*(Hb#I1MqHmYldkx zc-l)FP+8vhP^>?oqCZ${AfjU63Jkk~952lsqezH^aXnl6Qy=?3R_k0_Fkwa$Oq(Oe zpqr(<1KB^kH#dkTY0$A37jcB_IC(urX-Ygc5np~V%$yGw7f`%J_@L4o>y2sCjQ<{m zqgh#jNT9g)%8u=7RzsH%8Kf%_Lq~o8-ByOfX%VFje?dkT0fjIpghaZjfEZ-N|_OMJ7&wOltG19!TNoI z8yr)dLn-Pq^Qa`@n@Sl%Y@zB_l%$-=bWU>l8jrDF)dB*<_mR}<+%Z!Dy&6|BaaQneV=PLe>jXW1KMYADehI@fESk5ElXhOx*vp3`M_rgUGQ` zwh{rjY}_{Cvo+y7oA1VKv5f`^lh?tt6<0R^NMP4ib=;YTqC`0L(BLZrTW8mW>CQy? zQIzGkxbjn&=mhZn8BmLlB;3RIWGS|QMy=u|pGFR{Q_oyVC4Ys0RjeoSYKru4?OO?% zXuK9_*h)FxKlbPHQM4cQDNJNp_=ys_#8!IEd*&7zxwVbE`ZL+SC9dZ{HwO?Wt66z?ehq= zp$)g9&()sDtr~DuHS%8YMvTaj+uDH?Q;P=M=DM&(+oAN9XzK6`(;0~M?5f?~?pCea z-rp4&iJe*r%ugA~Kedv{_}`#G=cF|o7nrI(lFfDz;6E-AZ!lA)L&3nQoigrIT%@Fq zmMS#1%&)OuW+r#lkJJa430KBT;qYv-I@6{Q7E#3%l80}4)~>6hUPoQmiT}PMDqxDr zPa>0IQ%SQ2Tm6Z43XPpmY&S{t_cUx>oyfQNNFWi!6cu&LB^VdGv{9rZQhYsn!9=zB zBL{OP^(2wRAriMK4U9)#!{hj4HXAk}jryJy^Lp`IV}MP}w;T!u2ODXknbxbJlz%k= zi)pXl^3s+Vbsy6CHxLj)fcBXR%@nM$kJ^Ubu}#jZ5P?K!__MwKL@WZY$dTHx$1dBD zsF;kw8Z-3%oYh)-8*SMO7Il%^44jCTJ3&(6GQ^*XD98z_HT-_N0NP&+JgYMzdb@R5 zNDudZU@~v|SL-Bs+xbqTv+DJ(Iw0-rX0aZ%>FFAjXTp?`$mzOOuIr7}%yC)Zdk5N} zmV@CXZ;FJ(0{cWluY=EZQ6jSo`NZ&`(+T+znexgZ55p}pp()UKD&g&q}bE?ezG)(?H^2uOKlMVIKOyQe~4mTkhRdy zvWRm`oF{T3Q-Yg6Ot`e5EVg?T(#rJ5EAC~sW#rM_{d+$CHYKV2`A*wsRpn{zfBq!RdywX7d*xa8w%JF@bI&ksW*N1@>8Sai zJ@dnNn_~AmUu!M=9yz0JvB9>?^|vjJD=*KrEmLhTl9J8Wp;n$`8|=OMd)nrC+}4!* zx5nn*vf!t6NtKN|pEh`da$@`Bw_i3cy>RPN{%J(KicnbJ_EOoweA)t}dSNvsFWcX-AA+jC@3j9`RXx_me3m4-8)3eA^?lR*aD7E< z4+c%f^Q(Wrh(cbIU|O}~fjooyX*7TwPFG>W%ptSJU#ahNwuG^W;rpk*(%Hydr4lkn zTyOu*oXl~Pt7dyq>F_&yjniFWJag=6 zQ!=X~(;-HgknFD2eTxG~(OWhk!74>hK#}n& z4AvG8r%O2x4}WwVx%DqYUv$}?yA}zL5L(CcCC%uh^OR9AlQQES)88yI zG0EvYBnyw*ll%Xm)l8=pyNGMF`feY!o#|16UPyk>_-#Aew@-K{=Wq{x!} z{W^D;!6RKtl9Cq#dO!RhuqxhYpiX`7bpC(A>ggBh%*<0AyKbL9%g_G0cfV2v%SI8oBdRNh42H^55~a#iM^G^X}FEPM4njv|U|E z_j^sH4Au)uI)bMsvXC=Tb97e?#~p?XEv@dGq_16vPI;6u1n3jud}9kacbfR@CyuWD zY9aSC*BIrCqXc>?0Gu2>fMEYc!2iduddg2q!(NEKETXmI8BCf9gHOII61X!sXb|`- znKz10NEk7KnraL}`GlICk$MvNrUfsCmAnP;m zvQ+8IpP?6vzw<0p#sTY5#PQ&3*9u^QL$&B9ujq(l_p%?T*ge$HkK>JASa;pP{)=7fb|+fA1W~b?3Xt67;K}csrWm z`F5?x<#BNpKzG4~28!BRrKueoMZ^NhE<>GGGSI_wvuUz1j=Ds@@_|K=$Uq+-KZR>0 zwsjk|KwLa+Y#Ob6_>4+y?$_yuJZdrDNZqA%_CHOS0sX4S>BCE#J&ny1aaI1M!#BRd zw40`TW1iM2DyJE^wupeMJ>L(bem(xvD*c;0@ccI&);3LU|9vDGVcJ6*TAA3Glmv`>gto-+U95Wsy+YO@e(A#gC_)Nm^hJLwHlcjDXR-iOP zxPf9JSbi#_k0Dl8zeD#6v^z2rP^x`jT3dCyuE^lT>7lKSSrJKl9D{J{{M#H05&VrE)`{ySdWbue5vLsj)R+`L zl+Bk_+CMVGeO`Xqf0uTdVKrZx5Bxc5mk;7Uw9+(ud-3J(%9*)a^GyBEmvpsXuV|Y( z$PEs$9G9!9B=TMTC5Jf0%Qa}iS>8MGqv){873Ox1R@rO#I$pHqQ2wFDy?ig&XJJE# z_>AhU?DGx0T6~f)-B0~RK&S2_&N&Qg>G0@h=gka5_pNVNezpMzzuR(JcTwZcmC02K z1bY!3e)o?q1=sM^rtP(Xh1^VAtA9@F+IQo@yYr;&WCJN zOrzRcx9^OfAIk|LBkawWuU5{dvagPhvRZ17s~Gt5xlU@xXC|HLrU&1V|5VRv(=apq zOCyo_r`J=vF^F!Uc8C9cj50+pljf7GzTC5k8!U%~7^5kclJjXv=U?S_{%K37{;l+j zB-&ksG8}~wC0?txU#M#ujekORV}2y4HZ+N22EE(mDuYuFO1m!HqCZ}JNW!E@D>kxl zAEz@aCRqZSLS8)%(I9ye{2oJ{5g_ZzAaYt}nCG%%6%U%{krjfkL_Ax$jYnkwLU$Fg*RfOm+?5V=jl@+v z;S8TJPif8?_N+`Y4P}XO%ZSLHk22u_7duex!=sDjqYLLL1D;2{ExU=e#7CY*qppc8 z0ThRdj+c)nJiqy(DY|w3#^rQW4r^=+D+3G~U8xq`|2(=;jk0_`2F)LCHyK&X6Zd(| zCpYZb{BCqrQz#RU`cfHG*8{pX0F7mhD)ER(X^I$YisuE!9$?BMjCsI=9#p^O6N>V& zdk}4mJfPa*L9Tin#1W6eUz1Kk&Gq@90(2G(A~@rWu>c3e4HAF=eK~=qlTiOWCJ7&t z3^D1>NYJZKAoG9h0RdU(Ta4^R?eQc`3vg9|0cCaIBDI9;jh=u|LMahNCea*f(?MA@ zKw)xf3B>S_$XX=~G{ufWk_0uMPGqpUe@<)G71i%ScikmT0tM|Ir5(>xJ{Bgq zUk2T3N^F#mN-aw<_e+JJB?hQRWINEE97dcHFiD~shK+YYsHUy9j66+zsVZyIXh{<3i5;|Z3*pyLvP=QIH;ExT zRU-?&t~t)9D3Avm$|R?fy%vZ{V?NQTvy9;!OmrFyk5w=!Zf5NJa0AE+qaC0>{J`as z0W!J1m4s6eLo(fjz{B6HoY3AT@H9U`aN{Z^QXs35@P-5cvXP2r1{ai}pzsHDNONeP zbDCCp>`0jzQW0FvwIvXHd?0WgD9NemOovjy5~-&+B7`t>~fR%SM-h=w91JXC}k zuzKBA#-OXiuq_Bxe8JEX&Cg*R-=HSfTTd(#9Ns~Y6XAJifeS< z*1V;N%ZRtBiDe+v-Uy22`J4G(Bp+;y)kMFOApR}=5bY81rxZOYls;YQk@dU zz7;n~nnxF)N;a(BEKIo@J}!9BWOVN52gNSdhI9%5YpLiz3W>&mEK+#FNe9>!^4j?D ztql}pqSpG90p>j~{qd#T z2e^7?7LPx5z$Yq4g_0|- zL9ZHR{^j6){K{D|){@kuHYBb)|(S*XK1C8R6>5@#>rU{=^;7K>w4h}Isf}}0k7ts=guJ%km+Si^G{+U1;7TIh3MadS~GmgO$KZCkAPr=GpM zSQ7^~Ssi$P|4ZYmFB10_n$jO>gM=}R#*67DA7M-A@NUJh8_b@=6dT94X>k)T{?+@#*F?t^>w&MpE7?*mn7raD&Y!ME zfDsVC%6`0vm0sc7g;9w}f81Z4cAAIrpkQ{)U^~V#FIH&$8XtErpTu+eQ77r-SLpL4?`?@3&?+bHrGqD`3%ja!`|6W^YSik$pl<8^%p*1czw9#C!#wNRnNJpVST$ruN z6}$h^tIft6F2UO#dE1_A8xK~tA6{+ya_?O0)uE|X|KOcjM}|LfJ1dEeQ)kZ$9^@C$d#% z@x(L2jq@bc^Ae4s6wGgX79B$Aa~5(~Zk6L$wyw%RbCyklAHX$>qS;%Q52RguOZ= z+ds8vlFMn#zxmhy7TfN!$>+ZO3CCekg$x3c zj`U_*{IvV@>faTZ)AG=v3rA%ronv`q0|xtu3oKFA+V&)~e5%c{`t3HBtBfJzvOe~h zD%O8pLIFB?Fl+bY0nIO?hJ}DiqR-d9T$?`+SK3|?$+$Tkfor{SH?;F?&mV7(=Bh?= z-@owR@hrEu@6WscH{f8o4MTpz?(+GstEC~DizqD54opmzt4U%qV z-%#rY?v@E!u-LzLKCmKPSAs4wD89xLG#)FcC*jkL0E^YmXm5L|JTj_ zH=8YCF!iRXrTrf^`?MksJt)v8?;Lk$TV5C=jODd33RVzF5J`c1mY0@%BtCdYBc@Rr zRJHxltGb>0_eQ!1sKvW(TuX6(^5r&qEnF*Y*ez0o&)TQuXP@HFNVDW?08AuP^6dbx z*?oA_nGP*t5XR6o%zZTeKRmZ%>6GHxm;d3pD$my<$(N0nzX+(j*h+W*q}81rsS^0D z)ZXXm_v(I?p#2X!pMrk2ldHZw>KAOkOrMuleRVSD-u~*(*U10tdGFkQa783{0D$Gy zfXoviX4dx98t1H{uN5>W#EkKirJIWcNNFgFs+X?U}#T;(;` zbq-$0^;`?yGs&*~iAYtD6BH+pyvaO?*G4+QjAt}#myNCguM?wkFH$%yGTQ7K@HRAQ zN(2}?i5xqz_R4F?q)o=U7mXiH4BP}(PR2cQNrLsTg5_r?G4T^58(f(6p78qjDle62 z$ByMui1_e$caV!-mF0f~uS!e>o?SUf+Z8RtoANUH``rR|RYVdz4M{PvEheQD z=_-Q3q0LmjN#uJEH1y2sSc?oZv2{)}!+&Rq{38|1fl!P({K@zi`-V4xM?*ZdSx*H zD_V0j%sA8c9uKF72kOXe8Sw{hJ{W|mjR|-b;y62{<;@YUPAR&^Ch!c4k{ULsQkcubc?$pPv-6I{b3`H zf_pjs4I+w6_^g4qW{u1vZ^>SG*-J3zM>Z8!>RNds9$K-ttXzBES18@#WBub;1l*5l zyb*wDnddZV;E#9pCOVF=TMbD(s`Z<6vHsdMRBdZ6aLo*=m29l0(oS6na@U~uFMNVB zzE1Z7{8FpMdU4iVqOQT`;nQJ-iS@KQ)%qhIuZtCHOd^pzR_`>C*&N2T;dQ;sKL71` zE%22M4+H+k^Ae~Q)oR``&71{b0Nei_2F|qDonC`pDUF}kp!X{qd&+-6FKO}@4!Nzl zcH+Lk0}>?%;17XMfh4@ZJ`VZ`CkrIW>RhmLt;2h@JG^>Dz6M=34z$lKwG73-+klqx z>4WBb1aBKgwA?fYrCRb)N(|^4D><`syyccTrLrb>pzHaxNie4Hi^E}*T%M|IrojSq z2vNjV7XgluUez_%JW$Z6fW^X4)_Wbhz-CD-=(+4`EoL%Zxm18>Xb!1V7RB;BHKq$XlQYuWM=DBo{IQj*B8%8Xw45%KYR}2y%c5g}=wol6R735} z!8x*LP2_~{g+jLlzmmn&I2Ih>2{yvJ%+FePTK2yYM&BQv(l=~j+8#+W#I305TfctO z&{*1-IK|?X2C< zRyf@aSOM|zkf?c_4?C~;-Mp?@u$k{+E~>$ycR6=ds;!1_RjL5K{9Wxkfl~CPRECN~S5h=dp z7}M1M>OkqC#O#TLYSEwj-?U0yMZFypIzKNT?l)vqZe61q9>*Yz;5Y1H?D zQZu4}GCatM(L-ADVo}pRiQfGaAM<@icQoINxFs7=WiE=SyBLKK0>w+-67r$%EKWg* ze8ef^_PLiZYxT1xhW5kLdC>!o_RG;4m5rY(_}@>vko^*^f|`-dU+d`3^0SfWydR^w zZ~Xe#ACF1cCU4jH)OM9{c6BM?*RBo|e%{&NcGc+7EAF8cyltbfm%dcg0WBte7K?t~ z`i?zXa+o>FnH7`QA@QYKJ_64EI=Ywy-t-I6lEKiHR(A`hgKyR7Fl2}5>Hr~iY?VJc zu&A~V-|+oknnGEjR4D*H2;f6h2jiv^%a@vHct-@F(~lH!M;33#(;2xcf5W8{yDKcX zcsH`6iH`utd#@6IrXJ@YLN;AN9mWSl1Hr!8+%LuEG0yl8qMy2xTROdPB75y}-W0@@ z7*>^9g|okk?KXG6{Z-5TLb>KeZWjg-CSZI2#^;iD;|N%551fZ+9D)z$rV%aFp%K~C zj=H7P>#ULTL?fw2tEzTlhm1&DG!1{4c<+cjOdJ@*sGkeg@X!!8O^y**hcv zLZovBG($Tyqa9?GVg*li+vzz!B~K|Q0Dx}CnwN@_P8Gc|5X}X7h~5je91@{14XnRQ z0uoInJP)FH&Yw5I22jJA?UD~dHNtC}*JLqD8@#RdMZG$kHMhMOkb0B>dk|vGt*D2j z@^^DII!J-n9_>~Ie%K^Vza~;53?fk5%~u8va=^JKVlhHsFjUB%FAVd80t>_%myvrV zQO>}xQRkf`0`Ma*>Aho0{nht6F@WoVU}1QLz6~Cy3pCy(H*v_q5kmgfwWIH0wDHGH zxtSD2msGVQ_|#;G6s*ka8@B9a6!+wZEY`ORcpii(gg$CG(L*!yiwYl{7BmoUJCHSZ+E2R)yf&Np3Y|N;ouOr&#p@!~43bo>Qr*IOp=u#d!n>ft8i$unLg4KB))L8_Ba_p5LOr?MW+Cj#cuntP@sEMFdKWR$Jjf6ZtEi-51FKnCFuCJg}3u)#seS1Zsz~S_1%Ab0;KH#|kQ*xFLNB{;cq!hT! ztY@{9E!qv~?X71GlR++qbA_bjVu`{Xwl94aul1SGz;c&{H9+mY%*1_HPzXFkG2)dF zg^s3nF9sDOv-RwCy+$91 z^-YU+2TCpHu`dKn06X|UVQ8KgtDcWWRLr^D9SB1WkXJ6pbfVmPtA$P5^QTbpWvdB< z@K802qKcOLhGTY84>}y^)nT(dcxE4Q=>z)@-QysX+S9z^2 zvY^cp>j8c?56u1oGOuw?obD5cgUyW=6$Ef5*Z7%GNOIWAAE2)<;fOPu$zO6*r2V>P z2BBaDbBl)cVVGJ?;9-3p8;Ag0UxP)a{{2fYoN~P4dkTQ^wBv2SR}ByVJOMIi1#2+s zNvwn?psDhf)&=7Y#2v^x>s{Rqii9Od0m1X!w9XcgLD>j{MQ5DU%BEt8{bhIiZQt#G z9+AnLPF+t*6n)&dm^t83bImGTfV>WwCLkQ9CnJw>a^*tcHl6I{k)di>sVZlcl7SF` zR-*uYR>oal0S}@#(DmBBdon9otO7lkw)rEhSoPKuS!I2)P~pD_14v_1PZx z#zS{~3EaJs&4gv3jS#rhQC7otS1V$j>)~LH@-?M&;?52Ub+Ej4`nOqe<_=()V93R9 z%uv;mDx~py`feTZha`f3jnMHHQy&$SXQU7X< za0C=q{<});tqg{qJRQKP2aZC`rh`$>;nDyw5#Grg7S{kaP+Lfvv9Kv4PfN52I@$R7 z%45FUcudq`ZFJq188XFWYk4^ukdMXyEWgBY-`7ay&GLz~3DFZqQH4JsE&TCSLqadj zYyp1DfN01Kxn7?fK2FZ>-+#%U*mbNQr1tvvh|iC$@h&}m5G&YTs(xr-4a}uZwS=V( zw)y7@(OU?=cI1PaKH(n+8O)Y@$wGvr9$mG$FL8%lhX>R=HPD~H^6d&!0KQ^`W;35q z|6Fw$n*gCLLM&l}i<|p4LHc$S&wbWP7KOm0J|SgcB;<@*wF z*2Mc$Y;61J%ZC#YdA4M953}@Qj>8vW5p6OPhF3+J(*o{ZNZ|QHQiWQ+u@;5QnYl-H z4s$uSDKrP=#ctU)xWF^ZV4m)-cP%(ZG>gj***-w2=J|<9aZt-O*n&mG#JHbUx9OH#0zEo%rk&^mj1F_N2loR5i^xM zZG3Su@N9qg1j>H*9M{1~vQg}rb*c|Fcw^Am=q=xMyaww~V%sT6$adgrcv*8yz8Fqb z(F6T8A%2cu=6YEHJ#^3kw*6Zn4FC18zoX(}8jV!ewWmfpK$R6YFJW zVR*2d?G;NuZ{xG&)*Zo;X9i@gcdt>wNMp zA$u`WOxCSum|C5R@SgJ9J1)z|Gwu|N^^{`a6xN0|s_$b$)f!U>92qUl#HGTbH;d(o z#?n9V*e|M(x_&(IuTuLx#c2r_T1W`UatT$2dnVjJ5PSg_`hER>e=KB0_HzTTtgS!M z&V*U-B_k@RE`Bz*Kkc}(ryd}C3hp#GvYA-IkXG6+{toW(HAs~EB}>LfGLVb;6qrJ5tA!Eu{@jVg4Hu%6lB$U)6%$EDF1ENJEbJ69F?8WhW%)8$eOyJH3 zCNNqu^Qpen6u9XkKc}zqp_i7`^qjp7fM9A`EYy*R1RUPVa-oV!L9(41b*}2c&?dTy z46%G@(Mrwj+{ZV0=$Y8i@|JL2{lwaQU00zaW|-4N82cl2dXEHg$0NvYq4;TiLE5u) zDMz3j=3F4h&&f(FWqD8m;7zqav2CZX^2|M1LnL!Hi8EJtH=82=TeJp7t9&r&Ue1Dj zA}@)W8PJ5FK18ULx{Ct_cy`i?21(Rwiu#1q7V3gGGUUj9$ zq;;L!`VCX82r~K(fm5T8UZRb9eu<@xDeIJ&(oxECckwD;hcb=e=9`tG74?syUQ0L2 zlQarvlS0f+n{KMZPQ*B-gtRjQzbjLAZ9>_G9#d&5?|X#u0d4Mmp;Xl7Z{Ym+Jsu2t zOAx>H&IXEJD;ND7@jO?#OxoYCp|NRw2=;N^>L)8SF7Tcxu{f4W)PcuS-B6rLvUza; z!yd`X_rT?nz2jvQS$=pe8n^pCyd;r_LHIPMcY>LBM<=JV5+B4RY>Z-+PAKvYtHu>DVJRx-D4b%MW&ZDj?O7Fd+ z!xp@N%>s>%3E`BV@T>{5zy#hz)mm$7smCkdrh`5e6_DjSWpyZ2{3?D=>XLCi#GY*(S5y1^_jgkbm-*U6VD{zI;18;(DLUckAeI*ilXtQ?yh5zgNrGwFF`v>J zm5@RyYC?Cn+lFy}_`Qk8(`a9#Q!#Q(Qr>AvpvRz5z}v3?+xM)>XwP6Yf>C3Mgg%~i zh$%i0mDdGo#GO_1Kwb`fO+q#j_~1Nq*cpIzCkY9knp9>PZi!unezvE`-WH+;8wzE7 zCnTi^x-M-)ZZr3I1I^V7%nUGkNrVX{SB4KlI5k^f1x4pur($TPC}NX7A)H+7!*k{c zK2Tetr!;cIDHfsSQB;b2!*qfNl%o8ou6O>_Z@r>l>`f6^DR^$tsE6pMl_j6@xrRa6 zP%xA+-E4=5WIg3j9u$iEHs8hd!cPZvx4prIX7Pf{+x)0wz!qGU`H5?+k== zJc=hRu0p#Ufq=??lt#rlQv)+(;)il2ZmX?s-;{$fY9w8sEWoq93%baXY$H>!nWM&T z*wU?I7Y1@`w6`meA|*5jWtaS3`}?9P;Y=t*h13WuxaIb^w2C2GscWp!S|i%<7<$1L}-3KgZ|((0#0@MsLn+kWE#hsSfm!o3E`x4 z9<=tRZkrYB((GBF^M|_Bdg{llHY z0?ii*;dLD#oN1$1*#thVf$=JqC)lr_Ib*wC+0B^b<74Mh)-DC}Rhoddbq;ioR`d^p zy$;g+1BoACOeYTXkO86wST>OXp#8^CY7St`Y&i-7Zi?od9_B8EbEOofH#X?$+Z-Iz>Y#Xg#ty5s?l37yL1--@9<{VJx0U%vZk_$2l4 zZ5u~DzxgAA%@Ze3c-RYb{S6o1L}|B1K#e>!c8e?S!Gb~9Hv?{|`i;k;BPmd<0Gcc9 zo>r=UxLlX^%|t#u^0{s72l3)VkBD>fRSY%#xYWNSHkW{HSpxk?l|g38RFAiSx)=`R zFjY~v34$1x(-@^}_y{5ic)S^#^vd#!IipBzR7!ZJY6aK#5+ zI*Lw}X`x1Zd+g~h39-y|r+g$T0|C3eOK(8`aw78PjT;kont4gl5xaxD1GSxvm?O&6 zPJiB|Mq9=~ToY`NN8D<2})-kJ@3t0A}PV-vQ?AeVb5d}Rl%=arqr^rKH2Z4 z)fY2$pH8YU%6ra|e--~`>Vd+&jhN;CE^C)p%{xn_BMg1s+ELcq4c-&wgNU^BJf!mS zp>WdP@gh0HnL9TzJzl6y=hGwM!;tH!wcEbm%9=-*Q4dyueBV|?;F<3gdd)KTNW4pk z^0pED>ReZ{x_Qn(>2!kU2$v&+pKH|d^$MJ6ST@<$%u~JAB=I`bw>}q6935Yd4mZ)+{>fvrVQe2)vdUfWyrO|bRdly@Q_-Yuamh3_(@z3 zhUz5wrzyOr>wVUL@ic~Ian61M4S-Q_o{XVyO?wA>R1o-`4>^Y_IwRN z&_K1UghX(R)7XctJZv=@OJ@coWaEco5H0_ao?{dnGrBG%Q^MSE<0wz^>ZTM~p%g8z zG;^UeH?NEU?;i*y5c0LBhfVQYD!F*nt-DB$dmD(&UIgDF0t&@w+=C&{#;c3Xvh+gI z%RoJ1AYpGEjhZY6KJ|O>o)Q8{`zSFB6!Nul%i1{(**@T0y=h@!JJNLRw2WY$fRHE9 z&5~otSIt>%Mo_ucEHpE#T4+<4KawRhQ6kaW@4Rq4gh`2>Y*RYYB!mWVyg>kj7|$b= zsS`_LyF^Y0FiH}7@|BEFeY;9FIpJ$I5BMxBQR@1)$xJhl$|)qJhxjgly@q_hTvDfi zyz9&nK`YS2{@7V~#96UEbeAbc>AO>}x9w-6VZ8_j&i5{+*nVscOu2}Scq^{n!8Ka! zvp%&N2P?y{@MrZ)N`vhJ5J5k5bE;x{7*Z#YJ6=MQOQ^;9b-2Ml8Rv;$~@^keB4y2T0Wh*@9ea0WX!1;l4*)|mk0*F<R|XW27MfFk{QEl=zajQY z6CFG>#JrPHV#t>hcjd@>Z$|x5rifID(&Hn4`-a`K!DxU??Op0;l7;kFnX#B$C0rh3 z#v6Y1&`{8H^T3D%$Vh5Jqkd0iuwu-dqgQ89Gb6S+fBL4lR{w#w_nBCObmeE5?K8sc z)ZK(h{?Om@g|C+P#?wkb(+r-*H);H0dWW=T*Tw8zjD{qZHOnfjJNiD$T<70_dSnK_ zBw2?>@#JUXoGQcv-iqbV7{V|$-n}b(U&ktDY4?Y}e_l%4)Lg3Ic^Ww(@zC@3!EBn) zn#^8UvG~_2akL)MdFw!IGEIEzvBc1N+_V-_UQ?4}1T|GqQDrW+cPdW$YL7k6E-1@R z^6K5efyB?#s-IU9hh)`F_bGon>HL4!(?0NCg2rD$7BB({?DT%cqd zkEp52_5D$-i9{D8&=FO4@^{G5qRcXlyBoEluiN(F0!g8YMxp1D!mW+MFC|473w}%={%9Oyt=AZS zrhWBM;%@_{LFtLQ^B7O?Zq7ITDdFTU_fTz)~2Hk6Vve<5QlC42A3 zA6MUOBZ)$^_-KmPXt_ANWF=b*n7U@0+dIPSEL4PjZmN9&ESUC{pYY7nR+abtk=o=d zKj53Qt5pRWYc|cf&eEiENm^B)N%e}fc%YP;NvzsL?CI^ebH)4D)Sl;_`DQscnX<8K z-tZGCz;;uPT|tlQYlchfR5a;4m)33d6UouRxfwwp?(e#Y70w=|3^rwKNHe9pFub&z z;z;DEAunW-IN6sz?Gnc=vbobV>#AAtKYj}KgWNQ1zWES>y+OKfNQEW`L``1$0}7gK z7MpLa%iP{>zWu#9_?3+1S83aAOFMd5dzKdaOR{$aWV6x7$KtYM?JzrAdd?qTKSb%X zt!L@CHaR@Jn7J$58tM!`r)=H$0R99dK&kry03(yU+w@R{6Pj-UyRF> zvrRACt5)gze=#W%s?XO*ufEy*~&soXa+AewJQ$ zHP_X8Q(g|Uby#N$)3iCrHt*g`e;(X@*<=w<|B~(mzk+~bq12rKOh&`DmAeCv35VsF zfvnKC!d<7%Vc4W+i>vR)-FsyjCZR-7Q<=;+7vmhOH_tMf>-@jD+3QNr3}>`7H-Gc+ z&aqzGTPX9N(k9P$cz%K*Z~z5BgaLp6Xi#&L4)`f=QovKuq);O|QP@z$j9P77ZOL%O0f4%$5?X!C$qNq9^+q6QgyV6eTGqFn^ed|l|a#>ewP&b6)a1#{ev=g zt{~*+oh*~1m?x{+v`tsu`YryCG1RBJkmy|rq7svPtwZxq6qDjx>vF4h^4SOCKT*sy zbI+e2Gaa&=Kjl8$d(`FV`S&N;7<&FUidhGmCL6;MvERE*i=L0~O*!@@Sm7Tz&kJ|>KF!i3YiX`+wT2K2Q5_GnE1Zb9sI`*YK_RR93)M5Lby*y zHACy_pN7zs{;wU}vJ^&+jcUjrCVPxC{@cE^I!uYpC6mtqvI*QGf$${nuuDwf8#}|d zNDTvMIa6QWVucF%#j)uJ{{}DxMuySk0OpVUv&o7?jQ9od#uqXayqRzVq%)i6UQ`Gr z2QXR2jgCYbi`1*A{wILRI#uZ(ST&av*u<~nHKa-MV5Q&yH-&xfN=GnaNpw}*3?$oR zMxRXjO%8O!n+@+{y_!H!S%Vo%4@k|&ATHz`$_)mrRiHUp88Ton_RNhm)T>yF1=U_k>9(YnRod~M~>xtjZ?i5 zrTa~zHMcJ}ws!CHQQ!NZYun1#+H>Y#t@T5=5ay2UKEKHn@)>J+txi(a25^U?<;ch; zy*T-DigO6#c02)Qeb&!9?PXj3Rq{dZrw&ez`Y0vyMFw*CGHX*j&YT>+>0bQE5cwqekN$p-9nq{G8`0G_mJfQq)x0tH$uh)gcCY5T= z+o{kFm3;2zYV)Zt?)lmJvQd6D_oniVb6dHZ-fi$V<)or(uPye{wbVzG!^4|FyPRR($)%|M6(Fb}UYLvXeU3 z^c`fow)MYPWnZ;Beb(C3mtsf#K5*S(Sw`?jV#*)T`3CepJce=F>xXY-WJSKQ}Q^^Q1rICAVAy4)y*jV!-GD~xb7RUfPHDL0s*Im zRorfLqaB+L6y#BU%?-ARl(-Ik2|&ZRZp5%|Lg^@%nv*=J?;9<<>qtoSXpxqO~C91YlfQZ zuHW(aTcxh~YNhC2-LrucjiAL!B0 zAN|}mI2$SayV0idWW7^q@8Qhvcd5EZMD{>KmG6&>^WMtEUQ$l*Z#;W%eey{r)!*RG z_wz3|%9ttpbPsR64yxu{zxueD&~tHi^-l3{j^m8=a*qWkRp8Oni$C@Z*EJTRFvn6A zq)*W=|GxS0=TCuu7{$ibH32ls)sQ~>K=x#Ulzj-&o)E3VucPv?eIVp#C<{Lbe7sFy zZwuvY6Z~_^$RGzmVbRQT09Gu5ZYfmkICLdBOv#&r467tE0m|axlGt#9j5A6^j{QGf z%Nwi_%FqZ$@i3t?4uwoJu}maRH5^6+8$ct;GcR6N9%VGJj$E`XjI?Qsv?doV>mxL= z9CtG9WdowF%7wG0fa^9QTW0CV{*N4)(~ygL(1tY5j5gVcpfK|L!EhT1i1uoW3}=mw zPmAEu3{SloH9-iCQM(9@B4Wc4jW!<6F&RsAegV;D>ip15ZvPmRep__$S;UhaM4mH4 zLM0PC2I2fL(o!qN8yiEmefE{HEm%}CG}dhi?v)wYjg8IwQVS#VWoapjvO zbeSRe1R}|TaCRv=porx)C74`_thIBaY=|SgKs+HZCea_beNLBM2tbU*Aau>1Oj{uM zv=Dx+|1|q>DM`S9r(g!Y6D$G{66tgH)XUjnBp@C%kU+JShy@0L=j>jM#DS9IU9hnf zYihy4$#L2*uGFIK85*7@&`i2}O(NqJD4b0FzaG9tD=$x0B>+w2&~ z5EihUu-At0)6BT@*_m8hbSeb6T_cd-Krlluyj2_$E10w|hsZp(c*YuUgN?U`rFAaS zscIuKW#9#D3|cM>3|AQtkD|!cz`ylDPh9*smXkSw0QUfT8w9{mn^pW#=7-8us`MO| zzNnXSxmIVXCnYWvg^%r>h_|=0*;O#cg$yOv5YsLQlURUuJtU_vkDP`iNL9jC9;G>6 z)aR|TIpAmW*5v@mdPv-Z(D8~qD`ohL%>1UYyqpa=UbL*~-da(WV5BLAI)# zfgykdSpL={3$G85!P5U8_xmER|SOAyHa{_1vB(1svl*M%(f@^p#{snQF4913>W z__hThz(Vlb&vM?^k_lIYGM8KmJVtY%tcAeXVgdyK8UjWt7bmF|QO`Y9gvBdl<>!tm z|A~^Gga%VDBW7D=rLd5AKl+n>YTr!ok|lYXNPkSGmxW~)^^r**a=;Ccmqzde2>75X z@&a9+ug!1V8-GG9`=*>iksQgp47YNnZ>kJ>(h>R0i~58RQuY!Mg>8czUbxUZb(9Zo zBT!jo59L$94Y1of<#o!G$D{F%v+y2bMXp_Ci&wd?7G+{}*$u(mht*{*^^nJAnRVij zXB(B)P!~@fc#mDp3oC?~YfYysb=Mfe%&NS_s4_ac%sr0e0H|(prOu5;$Ullz`WZEx zbwk12bdnH@N+EK+7DT;SM(bH32ilmDcCKC|LS)uKf zW);(L7*NmK_S_^SBUcQUz{Ut75tT*Ya^=y9po^l0lUd4NLMb_RywJji zLH|*I_@6F~l3PAt> zFx}&N?Ib>hrfkVT6$r>vs)$P01ow$P0r1&@_;i5f)5E6 z-?-C=i&e~SEqxoB>#fS4T+P5liEQ@DhBJgyy@9`DOO}FywR^MQuz|aUh>#47;a-)A{)bywie$oqX56>&Jqi_A0mc2n=TOV|V1#-K( z69f#cLo`t^0CU8n$vgkkSs?^`FI__o0KZRdQwY$$*noRMYd}SizlD|j&`6&TRTRdV z2Z0vaM%2%R5TzURvQebc`Z6-`ozaC0h_!4KGzfm-cpL`)IkUJTh zmnGEPsAEB`vopmi2&|A&{gxSxu;HLqh<0JKNmYziJ^Yy1%zu_h2vDKpjmI1=QAS>d zyjK8zI8JY&1c+K!|Ey|u`hBab6aDjNx_wo1)$dHuh8v4BFS+WQ?G~mXg5YaUt}6s< zBqV-Z9E{#6Nf@K!{n}qzI8duh&jFPs86dl5*6JaVxKM4QOgS^*8=Ou3?Pz>G-4P(KXTMt|WU?^|+-z5spg48qaT>Yy+t{Jw z_5yNjL~VjDdd=CGR03MC4R^xn8+ zahiby;T4zhGYXZ*`UlT$jaJx0iqWYFz$UFVQ zUBP~G`1Oz1IzwV6eP0x1!job*dcveI0uSh2SMZ~sOw|mmQ}SKszs2@L=Ji!?-`;jc zp<4{S>0rJpZnhw%C_aYv*SBB@up#S+jC0vAC%7t)rA~p;o78bd1i>LV02GHPwF9o| zjhwoT0z*P|YnMHY=vxIxP?m%IcN)>22%X8n+70Rw@?9z#kS&^T@J_>wus6sb%W^x& z=y3UkaW8x~B-0A8q;=8xlE4T&-k=n5EjwQ;@ukTSrb$gwLhM;~dhM5> zdeAQAEj=abrLcGRJG(f1C>!v|9D(d5{I!VjI$Mcel2V{cYn6Cg)zErBbTyPG%19aA zplS0r{R6~)6|{{o7;VoXc+uzDjkSN64TYyeYCAvWMGoygU4%4cjlSf+?o%5wg{_?0 zp#>jW4&Alf_}SS;Qv~zPOs6R9no&Yn*>%{D_8gqA|Jb8H(c7RXYM2m8VS0o<@0N`s3al zuQoTGqJ~FbN)F-)$8uN76GZnto@HKcOn!j{vdiZ~bw1Xji`3WYqO&p!+ZXsVkM0`q z$7a2J=4BMy6`73*iTj-NmwbIW&Qz{{fJ6qvV|g2QYoHR(EDagT#&!YwRg}gNuGp|r zueQF-_lUP}(2RwDb(VF=#0zWcXLhkU^*V5LK^uF{Z(iU%KgDy)m>xg9+&OLry!hVB z(V6jwM^QU3`7Z(*##`adg(%?P?oH~$_{Sqrqw#h1jTE%GT1LM|Tw{nWkPmz3Z;Yjwq{a~6Wp&UkJ_-O$h;H?KIZ#K-*kDvz@{S+{2$P_|1HBDSu`t#s7C$a; zkKv-9RW%n|n|8jq^w{wiA;9DhrkO7LVkXTI+@>r%04=>e@B)sicI~jnA%O z_)bIKECL@=Dk}SO8052+n=Q4{-ZC)ZT)(ZjxOYeF?i@evCBq}Pj_+m+9fv{2E7x;JAS&Z|OY8ryM7#l(z|}7y8gK+gL!5*D;7JBrMthvkano_MyqR1`$XyZ*}^@n6zEa;~_hA--v8l{XtS z>}2=cezS1L>8{0Q1f#O*Y~$5$*k--dCX{*SBmEsbkt-w3x_FDZ17EIICR3`bTlV;V zCMr>;)ZwEmpU)Uifv{ISDNdMz2Yiw}m_gDmWz;39bx;UF3lHy{FgB{sl6dZ2LfIoT zy!`LIxTT=?4@R7gyb@tgq;;{1+p(Nv_cT2P3Ap;rQjhNsZL- zbY_!Fm}YZv=7Ysw+0BqKz~>T{V2A%=L^Yu%bCxMmkq8Bs7_QjP-Tt{K=oK*<=g3RW zhQ3i(5l@J4#PD6y)-ijPLK<}Bmuc5Fcz{&9bKpqY6`}Bf%{;BFj+r0iQ6d4$8Mz3e z?E-P!YHiVG^TjMfSWfD!xm{RH$t8%3R&3*3(UMxAjlwpadIvtEC9SY0=JxfcVQi#@ ze*&JCw2-L&>l%d^%#e#wq`YZn-EJ)HIq%WAx@o!Wn<}T2&vD8y(?l(AV?!lIP)MOh z?SnTpw*F&a;PB0y06!~jr9z9o4rjBLVbgGa^^e9}OK~j=Z=SgakKY=S*6Kca^E|Nh zi>)S`aS7Id9Pr$Yw9B>_?#I<7M!H`(@>!=Tym4r}TJ+UPrqgoz!NQ9Ih==o#;9^L! z*}nnL7k6&HAqPC4M~coi9xkO*%96e_F$$B!D;OCsznf4`@7h+s|37x`j^yq?G0tBL zi)6cZ|H;ieR)0u|wM=A_f6vVQ=Mghk`FtFYp^G;ryk=%(U(iqz(TjR}$1n_YL_ z@j9jo-t|s!>Y8L!OR1F__D*{F`3~3V(#VZ+-pGTS4rrF;QB68l{p5d)-c?KC+Y_+O zm|aI9ve7di_Q`1OauR#EJYh>VdPAIw{YlG{_cVO7-*nxTZ6%4=Gg^^?q8NDySF#Qp z(!1U5GP0IuqE1AnI6Pd*WNvZzK}&PDhv&o9)pg2Ktz0foAJXdEZH?2)5sx+hiq*AT z{gcj*O5Z6Ae)AvDUNQXKegDbn>iPv`KrbYJ{31a2VFQKcnR4p5Ewhy3CVg}pl$zT+ z&Rh|2Lbo<3+~b|ZINNeXXvi>^;5$h0d%HCs(ZI2~R&*$SKl-WSqsePK^d23&SfEz` zr?r@Ge)O*op7Q}Skv&??`^^(Jtbwnn5)S-$?5lL51Lr?hI98b}eG+aBT#(^9Vmh?= zYCF?FRL2?9C`;NWYQKs6~wY`dk)3xKC(|em77=|a?3a@`& z59nukw;w%|Kc^5y;u&LPqr7dM4C>k@DUUI-5DZ^x(eyUHAX7X4myh3_Uc5VfN%dqu z@`lU{`t`GsOUrMQUj`L@mA*e2{p6_ZM(`K&^2;H3HhCmOV|H=k0zXdyg~Stb7Sv#gz)$Hf=zoN z=T`^-$6l(4UT9h`th5(SBc9*TOEcd~yVXl~)JqTUL)h?Aq56=x)~40Y`SrLL2(S7j{PHS(Tvz<8!th8z%5vc1sm?XYb1`vn z2?qy_T$7rI?m%E{e{KJU?WI8n#dCWx)(P;D#MdHLMTS9bJc=B8HvQhFv1yR(hA*N{8J` zS=`euOUw`7$JO1#FlvEEe9}5prAPdZE{h!@HIATmwak8{Fi+Ewz^0MM(pS9)xdL*T zpAU=>nyx(Af+>SX!%$)myI|VNcP|A>k4EhhucBZCztPyiD|@dsFV?*|YB%8#rK3sG zmmlmCyS;F32gB)^h=Rk>G(W*>VAjUoI5pF;OdFmiURbFK>_OUC?zwoBBaAalTf)0V z;FoqDW<2u!STRm#jD1|YW3-eM`kIM}FentdHU1ntQIDEvz)YOFj5p#YnoK8}9Vc4+ zCR!tOgRf1rl}>atO>}a`$_`F+Z%y> zMSg0zNuTG=)av}y+LnF==hQp!^u-2h`d5wq1EJ|n-1L^I!R*heH`k^wxf)0czSyS8 zdTTTN>xBVYVf?F~z?Za&-E(r2dHO(vdjMnTh#`JU=t;?i9V^b%k{=Va2RjEF zT{XXKl>g*rfaL~KL$uDju=?TzwguOL-zcsO)GS1aQB(^WnpQdNuP$y4yE#AZgCTM{W&>;!Je2Un@dP zBP3rdbKQ{Dhl|L-)o8g{mDtqz=1u(OHJx;EO66MV^SWjE4>(NA>gQ=!O!a7G%6I0i zO;e56h_d#k8RBndgoy<1Wq4v!uGBRHV_IgTp^2N#Z*1PYx%KhQ?J_x?p_{f$3-){q zcVre`&Cfa9T#)EM>D*elw}xSkTyXJ^?e;Qrg_*g&F)KPV{rITZeQm+(4>}FC=yMYd zlUek;x#)kh>4`-$+G#PsEL*y$iLr2QuaA3U@rF(N@ihx^w%3WGW+92OxCBs|_JRwK zW%8S$^hh4TeWOZ%Sw@*<5*m>dWU<@+(z9+k|9H6ovZBb7UC325ki`Y?S#B9nPTV&y z_Fwr8M2YCd#0Rfb6TV%FkjQKvTd`_bc@|vgTg3MKBl>nJ%;KZf=ZK=Ro6RsFs(H*R zyj`dveH8*)?I>GmA6hkeq`54}+uy(P@@CG9DeK32E27`1+vHO?!KW78&NU^-K8i+k}a4Wi`H(ja85V{SD?nV^EDfHkt^k~UDeDzLA$ z`ptig?mR*$UqkikXSU_Ny)gc;93pSI<88$2LJh88&HlMID`YcWhY0)#-4BW2ir`+MjMDW(v>6W?HiTAueTAszwoS zmmq8bApl}p*iqyvVAxic{oQIq>)+$p3yTyq-<7hKtSGAVB<|fj;Hq!!z5B@;m z$FZD$VLsh(7L(0>9RRk__DCNDv!nhfy7?h6A?;%ffm+R6>Ze>@^7|Bd#vh!i?b5e0 z&~~Xbu*fvH25#e+3wAM{%V@n}&dm)&O&UCjgzLph_9XP$Q=|CP<8@&D+*82E zBrk1(o*B$Ybt^K{Hd)z8=Py(_LUnj8MyR|HC@xq~to^k9WA={t-!A*4dPj37C~OIg zXN_dfjz6|QejwVYJh-f$4-=z|yXUxB+*T2u3WqW?deX&Zs>4jLuTdr2le|7WmD?sc zrzsf3Xvo66qS6uz?ggFTbH?3Nx$(s6loUCG>~~P+sR6O#ktF@t7lM&^Gw9YE=mXTaS#wHS8w#McJsk-v zGKdx5(@uzS)3WLO-Db@Z8scx^+{%hl{`;AsBa)yITRfcw?0ucT|0xjWa*~BEr;8mu zfu`c!%^omvuDzCFT~CpM_gSdD{zY)6X4p=Wb4{vu&(r!=4cI=@gAQiEewMB0Unljq zlqAc(Jqi7CQQl+6?Bbqn;Zb^KXvGyf(tq!9IUK0nzA>D-*%JHw=e4hX#z~mu96Reo zFP9%3(1+MD1Ox>L3G=Uv#3D8m20d`TpLe(bJHs(DJ^2ig5t5eQ-`s(o4ioMsEye#z zigmjV+s;Q?$j0JCs1m*IrO$p|X<0qWDA?*LYoqaF2AkH&#%BGFx&Eqh134* zO69vxJOPxMB9!r(EihJ-8p16~+xl6Ls&DXikOvXrj<1Y1gKn-yeMYksl$IT2qYvFw zF1J>(Tu_-qG277ZQeW?1+x|)!D-jI`m~ZEAxaI{Y04rdGh69SH2Ucd#3(wdaL$R;_ z#>i^KF;2r=xSSgER@iPhvp$C|E3d~M9_sx1K^pag|AYh?|0pWlCGbe`EDy(&yuV*q z_#BsbFC+}6P-Hp`A448|| zt9`DJTbgT$F+vWuF@Mz`Bq{$uJi^%dog^w}*2!n`{XM|GJbr*Y4S@fc7rlkO`uTRL zQ+6SnFWW<*U-M-mr?!M$`z!5?>Fmm961A` z-1s_kT-|wciueJ^`xIqp-ihi9kvJZu0RZfJ%P&HaAvC?%x&4q7r!X3$B?DEIS$L{@ zgQkWy$3sStmU1RIq%@qOxarX5Y`-!bNh8p*sG+k7-DL(TH&(x61~7m#QV^J4``dNKqTT!dVoBR-bVg zQ*!^ZijE#Q1hvn{nx*ut=_&wO+Jt-!C^>o&2`*hSd22XU)Jj6wgE)*<8P1B3NG%4H z$P@5YfP?l@M5!a(^zH;dz;SrmJukfor)>k6%@V!HMHkNT^Uz86YAFMVIF1u938X;F z#??z8Eiy#5D;pK;3sPKW!;tK1=~&c+P$28x?Qzk)CH#yAwZ17&5Px5E!INiB8WS#H zAD%w`9(R1P=w4drAnP~x*8%{{@GNN_KDxs21l)3j79xiz zN~+3C&8BwvEk4Mgt0*zVWE~_i%;wf5F~TKP3kl-*G-#zziV>e<_W5KUS9ej<5H1u6 z5voV4Tq;+D3)d}~z9Oh!z3Dtb{v@+DBu$<{w^+3mUdE!V|#DN7q7BP7xvgy8P#|t2Lr*tUQH~ zRwUmXyD)30;+OIAM!aq{h2VaNa<2HYjuQD4FeJ9li(yYAA{i_)H_(rn3zS~=Gf_YB z@@xFO?DkUWbnWql(%EMGjh>SS3mIM&PP{Y8%E7A(vgvLMjn6VJN3wGQE}A`JZNTs> z4Dp`0PbLAP=S^O41ri>l1ROVgU`cU6Mq2sC|3xGOF?XRUr7HVqed@!x&(YBDhYa+| z#V8Y_khrE^IK8(j+X8@L+gCoPBPT;+Z!kjKtt_M17oD7ccXR${|sohWOiQOB-*c2n^#!{0VwGce45MvJyHGlF{~h zY`_<7*Rqo(Jho=p?&DP&#}4=(vSsE-YU)X0&wiiP66q zw>fMi@${VYiH8eL+9O{JJ4*AJ;vS6@#Ne|M*)AIhE$QXSy~reOW`SA7Dgth|qx_3s zLXU9!Rp>90TRGz$7r zv;ODqe8|K>sd04{F24Ij4AL?lN+1;T+vK(OaTv2Zo(XTlOfvmo{1d`bwdvF|kpqRO z_~>xtou-5{>;>P+c%(@#v&Ay0=b^%Aq=C&Q+v~PrdSQpal*=1MyUgjhz6Pu?bfbjR zB;o#~#K*S|tbNBYAOP{!9ucyt_`X3v$zYsVWvLT&?c=xs!;WE`J1QZ7@Kexuct(}y-N&aG01?5Qefl{bB0DOK9kUd4NWnLt8WyQvXmx4n z(#xWsoyHoH7&NH;yPI!GXW{uAmInPR7Exhbce_iU9Hvc_UqPd{`SDSD1DRX4&%B}- zOH@7?nBEdFF=LMls5-7A{0L$`Zb}j~WPV{^y6+PDt<0eI>webN1K?BvhuNHk|uEpYj9zZ+acO;*?ax=?ckU%}QlP3Yr^6%a& zSDhPO2>B#KGEyeNdBk_z{hwurcSE^a6+cNlw-eRmfoE^zpkD2kV*Q3`+w7g*|M;sZ zOCNLJRI$AN^!0?3W!x8)C0p9dWjPHPMLMvv`9WGoI59xr4aQ1uk4iX_QteB}T# zAbFLA&+Kr8+A#nP{HHiUIl&#{tYdwx1M0W|0s8lWBg8tJ1aZ2I0pvAG}Z0F&UYY@TRD5V5SP(X|_ z=3>;|arJAn8l$#}y}OE1S6GDm>gX2QQ_moe6_#x`Ox3mg zkoS^U5@&4j(Z|9F$0SZ2+d#B3W|E%ibF#fwh|&@_yD9@8w_U-89knM9x}AI7xsT=x z0c|Lf*~(DxNQD*ppycY4%eJgEr)xMA7A%l9dShe#LwAx=J?I3d;?ZV3(?&Iez36+= zW*vmm0BqcXtdy3yA2(>1c6IC#yd0bp*9r2Z|OO4i?6;mc7&l(U)pmL1pH=%Jp5M2pHxC4dHKG4}iuXpq-y$$b94ojO%LgBcttCBklK{);4 z9A$0Jh$$+2ItSi6d{HdC;rXEV>xI>*H8? zRw92gI4i>*M6N~ zgJ1%YUJ(N@hN)(D&-5sIpA`Z(FzPO1*wp}+R}WGkstC9B1KdLu3i;NQ-Z6>|bC%hu zcl9H2d8ordp%sv@BPn=_(yj>vNblpX=fY=~y>14z62lB%T{^GY=;O?q=q(ELifC<$ zdZl;!j$g$+@q5DO96^V$4VQ>bgjP`4%GxBQ`)qhH9heOYzWpU0ML60fct(VnvBm$L z3%_fw07&Qs3}ZtowGm+NLL;y`_TiW5jr1~{UZ^_J`6;)3c0ym`gA4v}Ar*iYu&J5c z&;zk}_516G2qB?>TP)!zPVADi%6{=|@YysytOthX4s#iSmqxmLxGez~v_bt`yZ?Of z-Mn{<1n-C5dHxE5$UQfU?T0>tiAg&ug~W-sq8g*bz!2+6(-F7^?6#TsVd+6boCn$?=rGb@x3^`Iq@j1I39pL~yHGpeR4FM##o*G0r@j%4%2R0f>-#6tN3R%_8O|Ch zwP7sA9Ubbwxb7Al`sI|f(jUpKJQU+WyJ#H!q`I^~738im7jvO{rT^)wfr$ z*gPm3tx?fP8=bMkglq(OZ0M*opvzX$w%Cw($Y)(U)P-(e9YQ$7aeioJW6*=~*kg8U z*b}PdL_^7qe-_xD>?0wzhV9EnP&2X(e~qC8IIDaVGDqTWuVKWnQ+xKfcR96V_oXWd zmMu|lQPh*wqA4w*za=I;GsZYz!x8VOsXrOkT3``7`t?xd0(7)SM4k}$uxDCqefK?u zoyu(YFp6#sYae*!94)sF&c9eYyNMlnwim&xnWy(0cf=)*+{6A6lvLdV(0=jpzHjY1 z?#{27(-jN2*ME-dO%q(ocKopKaT!mzE$}8u9QN3ujBYce@N46^rZxNRj!z#dA`)f+ z)TtessN8kIWZH&7y4nX7U5R=ss-h-u4U(~gRiF!?K$Dtatu+oV3 z0}<4MfGGZ1umKjQLkVe+URg7`oGw9$tp7T;5ht7mL>}iVGsMj4a}Z*(a|s7uc{Kpt z8Fw9;Ms(8d#vTddmm1PGyl*%bd~J&6pI0->ubK7|&+xJNzE=Ro^e!dQj%F5r1xz^O zz{m4pJMGbT6I6222azT-$8T{v`GXTG^edG}-Exr%i74jEC{`fQ^AHFARG`{b{Wu{^ z-DsHTK67y&*A_f}=7Dzl+Q{=ZUz;K)qCA9S z`R_z-+QA11CE!heexsee@;1GJz*gHc=*)$?aDs+5IR3P{WTlHcV=@aHJhu6)BJ-O? zV49ZOnI#*v*B&qfZsI{Y4`SN}?V5_-)$K_B;nnq|L~8>>w(e^26KLcY`_-OCsBSp1 zH_RRDfcF#Z))|hj)>H~`>-~MdX>WS3#rxU*As4!#lAGa(O5py7+32qIU6^-dBY{H3WNl?Q&Vpe51h)(*l)A6X zo56dRliUy9pR%|gK@grT!cPb)%(@t$`xf4RW(&pHT+pj~t@|{@DtYy472XnWR0GrG zZwr2T6r58s%_eOq>lE&g8nMT^g&R-1=%1d;G3lYBdQzd51{$DEcA3z0Y*Xz$bF1ce z>@AX_4|h7QLaHpdnr_jUeTeG_Eas!;X5iH?RxtX#-K?gI1vZ;0w z4qXvE=EKeGr~1w{+~KlpEJ-el)OX7z*lN8&XKt0Z2bUGmsp!wj*w#rcaJND0AMI2B zRMnRYwt>5Buuc8P^u3RNL|{?zF59QPzas#iTHim@DA3tQz6D8KjpC;Q2$aygGkwdp zIKXOSXURSoNT5EWUv*bF3s-v|@g3B9NgcRH84$89qV>Kp_iXoTn`&AMFxn&v5Vvy4 zq)e1T>p1&U&^3=>n70#EJMl`v=Z=fHo;FIKTw0wa{z*Ezs6edmP_~@FwD5IG2hXw9 z;W`lOdq9(_-(PGi%-()CqYVOirw35IlyjCWUqp!TwAj9`Zx41R7Y~KmsX|J{HNx5gg@$>$! z5vDHoBN2#HW>cVhzG$(mtv0MLd_}`>yP$y5SINt;YwpoY%2cDGhsdYP4cqG>&)YedL5^ zOh4Y+Wy^u_jyK9XqVnbp8>Oi(KLc|2k_?_F&^+L%%6qM{^d>$wI1K7Ao8u@Xa`V+! zwU@~5hOo*v<4-OInQO{XX&^MCWRPJ5=3LBRi&_YxJT_9WXQw6XT)q_lc|7W#_Z2;m zm>nT7qS2`Ew~()(Ae8=PSb11|n8$>r2B(u3DMi$Ja2n75^3q_^>q(PCw{QVo8v01d z-)*k%6}B{=l=2w$rC?FV+!cx!vsFK6KF#w>BaO_qmts4{#)5G2m$Gh-2v^K$GsgZ5 z!dQaV!oH(C*Vg8UBFeZW2G{*W${QMw%JBrYvhBcWUZdVUcgj9ZIj&(L;J&*uN0N4d zly=@}^ssl_R=V`ZH`<|}=HBp@&E^+=R1R#UC^XnmJksc#%Em3|xF1%&Z&p)G&GkEO zCCvZ4bwUuRC{66|(5~^9SEfP8p+}X`3`@(rd^}L{SWi|0BL^5$j;CmC^5t^^5DCB> z_vIY2*hsTF8N~vLn36hCFa=nJybDVh2BehJ11QJWnOy|}RK%wj%0rp1Sotl%`81>P zEtt>+t7MDfrUKu1+XunnB~BiYdFdI+|%X1Wj=x z+g;@;H4MDs*lX6*@sy@Ygm;L02pJ+;V)m=kIo`}YSx5D}xBq#1Bc-YxwY7~U8gOyD zW><ubjZX$VD%Y3?^zeQm!!dWcLH{CGn%sZ2Fm_i( zRL_qDd)!v1)t7%-1&roG2B7|B3xpq)>6Tk_oO3Tq= zALV%pI4|ln2$E3X0FZ3h?v471;hH*6P1e-+{Tx<4J0LJf0a-j-uW}3P_BAlh_ zspyte4x3Aql$KzwvZ`jEWFKW1*6%IsifUh(|9+(QCK~1;m%}_?)Y>RZ^gT^`!|ZMb+x;~2S!O**Pq*O0+qttWgp*yb`d zNX_(Aj094uI$yRCT1QVu{k*d2Ivg1ux^-C{;gbBF=a8sQaS=9YbLWL}BR`S5&&XWP zp$p^Lhz_yQS5bVWrJW(34C>oLOO#5t;;Paf<4T381QCYtQZ@dgEsXU~9W%;ZVw-dk zpFr`T>9`NCf(@TUDle0{St+TGp-fp?d(y#E1{wCc7HR;xHdSiDn2ZPW35SxG{Uaqa z5^$xG*V?tcg(T9%2g}r(KI43f?>;#$SwHNL)g&1SliLSYy_f_QrQ7{1tc>i&KK@Aw zo7J-|m~7)j&$JS*;~h;Ec6@SDS}9+;vswn^M(VTqq^A_Ae-x>*eK1q=w zACa*}k*PxwAQUz0Bf8uunwKKHCMiZ<0r9y}?AxK(f#j8ApUc+vSM2{BUI9vpLtY3; z*o)Kt5ND8*VES?CBF=WuUV=wTlFt|M)m~Ec#p3R!q|A%DMisQGz0{AyIAveyL~1eZ zAJWEBGU6pqxlCW*oMY45d2Zyxzf>+!3F03r*JNtw|1$P=*Q$-b z89O=4&6~eA?%mm>vcUc3h{pfexc9#|BVm-W+sIyKGBTt#CO!|aoVG^?Dfb$W0>oj@ zMEMxWSU_&>sC4(3E8Oibjtf;Lyf6L-$Nio0p84b?A$Zns1M`1y++k??Tv1Zn7kV~p zXtYYuPbJKx683-NxZ!mF;J9)BhvObQ{|CqYLY((+99I(@|1TU@_mxVff7#P+s)@fi zF1mS_{ymJwaX%B_79W?%;;5o=+;f#`G>%I(G8uu!aUpT#3Y^RQ{7flD61{9_92ecZ zpBFO`zDMJ@x=j2o`0N_UAXc&mKt1O9dZZ2?_5&pNJp(z~nPmTjoGK)zg8if6?s6yc z0)-7aZZ4{W30*Hqc@`_y>S}nl;j*GzBq~MZwX&>?b}9|!A$PCq3|X-p#J((z&#Q9# zoK(Oj&!JOUmks5AHQX=D^KI#TY3xu}OSP5d|7y5jR@O&$in<7^{%W{t=1nc05)WU- z8A?rIQq{mF2<=`Y)y(6ri`x!baeT}dS9j-1>I&C>S~t@w?Is(lk4XtMIec8dZ2R4L z%s*%Q>0s?@SuA3izQ@n7Wa>}BudX{|K=mbU$NBg2u6hQ4gc~em2a2lpBla{z4>1n?pIRwv|jZ+;jez9 z{d%f<0R$?KCzMgi;9qpvrE>)m*>B8zh#c~ca>gc@Gbc!d1q*OAoR)Tc@nG*2yb7tt z4vv1d3M1SbN;vPkl=W`V8g51r`}aY0{LyI-khezB=C}RlmlKb@UcE=x@CNNqG+9W! zR2c7E9OHUtVh6@TTnGjAyV-EzdyO@bV=9xRw~IrXO%eKmGf#sQMYHCeQlYwPTK@cV9J*at|BVuPm~f&oo5c z&TcmVGDBTUq$^f90YXL6Eq!w>KM+viKKul~PkV>I3=Qx2aFH#abtt|y*#&%kbM2Q% zKy;dnQrqD`X}w8Zs60(?H$zf6KK^29!EY=!(EoViO|8>fs_0~P4eSYV+hnBXdC>wC zLrDhrYw|i*>H_cpk4FT6E%EBze`@&giz;NP9My6pR(PL7q`Gk-PQP8Dld z`|7@?Z(ow%v*)n+TpNquUAOl~)hurk<{wg(6$8N$I8QnG)%*H+^9sWfv|jH0x*A-m zXIN+qpr-f;dw`I1A(vk0I8fGl+jc4?Tqw{a;;=Hs9DTrcA!?{?qXuHRow#@y?9 zgxs!wb?NG&QkS*VTTUS>`8<&$d#pF-Jj%s++W9^}XS?fep3>5@VeidFb2ZBW73lja zF$@)xc$5ZKh`8JC{ShZ#;(Ztiy^HyJf4;II4!h`v8K?E3Fm}x4uaUGELu9_f z!8@m7zJ+vWWiW!pV4@*eFxVBdHkjB?AKOR_kO-j3prcs~c~BnW?HdBzRs=(z^rX{q zW&pTb;Bs;y6kOD*TX2DLN(Q4)#>J3F8KJM>96T8Su7gl11RRSA{$x>i+E?s_PpGO< zsG=WCkzHJc8v&HffDbc=DV@MHpj5&c;m+2f7GSuUpYZH&W`zS`BO}tFODcOVQ`}x# z8EPOdAW{bGN)8QtlkO=S3_!P!S%Ls%>WF&b2(?+W=0O%wN0JxMu&ZRTF$w_mC@xUR zCY~-JBHbH;^o~r=h$6xOpo7l}ZQ|ld7!?eRQ1Omf4+$y`Q8%R~-?}ivmBnlB!Nr3B zc)$SNNxWZzAWwvqs*o7iipoPDq6LLiUO{@fAU?IWk1;WROyXlSMhLSbN}vpcVvsIb zHgpJX#{`H+tG+#N^)?Yi7(x77Q3Wm0AiFYh1i&b3gW`JM z5EKHr`XSZSF?dWsFQ*umq8KzSOImQaYEmz)@Xyi!&sDHZH}UM=q$h)brs#@VHvIxE3PYklyh?$B;+}@QqKB)r8nEy9 z(FI&T2&SRz*oW8-7=bCpUY#Y>L1W}(6MqR}U_b!blX%5UAYE7is3aJf@r|Kambg)y zSQ9s@X%df=3F(V2$^if{0BCDPT-lR=c5Dx2vC)>MtusC=P1a#e7Vj_~I&Jn#%Jx;k z>uLoFUB|vFPRNEQh`eVok$_Zgz7g`uAf0kPFDG4~qFF)>A?nLROabxn6Q|fiC4sDAfQSO7XSh(yG{+TCvNV^!6X4SY2(MT zV%)TXPMqUyk8|f##2>$Z{40YLJQ+zS6rik=fj1>4>*zVlO+tS|x~u}>Z-XQmtJHHs z9`Qih4q(rQqJAaC*%tt*CGaj(Txv6cc9TS8n=pn#6}?hUq96V-H{?tL?+nLH-ZChObF&0ZwzIbLF|1OO~ZWTTqwhx5_D_bnz#=mPv*M-yHS zMp>SGP25%)mX+t*CluX1Bp;Npf63-0fa7sg@O+08O=YnQY2KeM5wk=?(9-fKH|S$j z0Er-8({MWen=IP`xN|Ph$CDzgpDI$EGaN&JW=NQ%q*dTwhWoTW%n*&~qoJ zq6)baRqO+Y`Z+IhYdaR-GQf|78m^dk(SFStF*Xu8u>(<9J;l;lKv`96V)P?iCeYy~ zUbbM~%|KMP75+tmQ4Ru7xYT&^y$V!FUR+a9e-KG>oO7I2UwK{&_CnYRf5a4OpsTGB zCUJ+pcu*o+|Lj{_g;-Q_Fftidnnfr1KJ6_8+=F+RnA@WQvDL=UdIH5ZA-2oTv=6Qn zQiGUo5@kJs6mr)M@U-~9YL_bv)GBO$Vx7Q1wCxZ7S-==%A`COg7F_vC7ng z0Q7nqF@>-{MOBu(Z%hAB-46z2d1m3GzDL#%mzV`pWqW2dSv4O%iD`!IFg*YEiZmLD zPWi7sFnifj`GjnTWj3cbqkwmz&m#gLnpetMt3)LMH& zla|^B(e1(eV?}J{sml2RpXq6i$%ER5FYw>HG=XL7NOHQ&3gWO=yOWb7O3=&@iDQH$ zBrCjTio~?MoVC}5Phb@Wky|B^;vF&mJ}czQZ*bHEMQhUcvv=SNMtNtef4|?^zD5 z;xoztvN;!c@sYx}&I3w74i|?FgKdVznd4}LOef2` zfjSXUW)!av`aQYAhnXXJ#39$bVyi17i8P-vrz+uy?u4f#)pTFFrz%$#V38eUmIvXT zPg$DhsS9Iu2Fxh(R!L-9Mscl1-ggZ0IS%!zD-NzQPOGzH-+fUt`f~5mwC)PI(;XWU z8hb@z(Ss=Vn#qP4FO~oNnJvOGz7jGli7~;SB%hWxighaCZ#dCUM-mqcx%!Fz-;fOB zkc3w}96TnlZa@BEbYycCLQN^wvulo}F)m(|b}E_k;6BZeyNvRFN+LOK{%z;!i)=pJ zX)Z$Mc+Q-U+}$@zL>#%(ab;@MDDLTRCexNncnthAWDOoL&={1ub4nl%Nk2BDM92iK z`M_X3BOciKWNb{Ce(o#;Sa?IiCp9ZTKcm|@cM$?KwIYUmu&i7iBaVY0luN`qr@nod zSs0x=e)j2Y38<+D58OsV9XI!|GfVPtK0u1WA~wt44Kh76%XT;`C_fRn_8^*;LWciA zR)9e!E9er|gQvS`JLNt1hHyqJcQGfB0a^emRCAG&1k5hu6|UfknGHDwR?_6ogsd%D zvkVannI`0t;Z<1QB?C(!y;TMI1S1f++$Dw1r5ZP3CvKw`#n~?S)p^EVVN%mogmVLH z^gn|c+{WlTOjSK@SMvGF`%t-a{ZeZ-@-0bl7EPF%YQ8O1Z6oMn6K_39(ipfO%F>ww^&t~xp| zPRDIN=%W9XlvFC4^D1JAw_>TfbBlw4o>YVNBM(L?BqsBHb90VQ|s{v zjunuRL@Vf+2u4nbIv$zcU6^J9Kjp?}r1jxDq1!m2n>qTFcEW(t-j_PvrwDy;b=Y5b z#9DZ|%5n<9IO9OcoISQY6NR0@#PiRdbe}!lIFm-5$ugeH%bY97oE}@8L-Ws71s;%e zv&VFtYcpP4aGdE{UUVg|8C1VTBi?35q(9d$tQaqCwAda}9@@e#*|$!ex-VT(Hps>c zH^$$jLPy1QCoi?kT?Ba0)Na7W?;y(_*pUXi7C#2#m7jP1%(R&XF_}LUH4#Cqv9Hmc zAak8+dCjA9rTFY3y<5%bh%e@d^tk*mTShfY=HJNfhwht-jhiaeO%3Dif05mm{M)wf z+YTfS7Z!&+s{Z%9z)5L7+R)v&Uq{_sgVu8oJYXv*(Y^dX^Yw%IM}M{t zs|(0cr!dr2KI$f)>C!_4V12X!KthO_9wKEcnP`cHLdX@f^uOu8)5s-KJk>{97m~4J z8-9O>QM`jZ=E9=Dp|%GO=q89F6Cq?2u-6vAx2ByXk|y4&{vXY|Bj`;pU6&jGp>nfd zlCE{CaBv^e?mPGBVMB2<3pV+;{zK({pQ}Vmx%aFmNWk>`48y;rT;+U=`?r^eYdH27 z7XOxV^DR+9*S1%vjra7rfS-tzWLM@;*?&v9VJ3ooWdTA%lDJs@d`)b3?Px0ZZ=g=1 z%vP&Frq$q%PySDzZW$(|MB`ZJsYL71FQiQGyHl&o2ul@wh72R9P-05>$q!&d+SRFW z=7$N7#K5V0pnb6^$SzYdYV4n1MylYpNb43XDtunnm)g&WcM&L61>_SREMBa7c~fL< z3l=B#`>$gKPWecIV)V}6;T;n$lDHL%ALwDXZ;K2@Fi|0ic(HJ>Gg}^U+2tZ>d9i4T zvQR=Fsy{z@QP~_M$Ww*PYiFQ=ja|MCb|x$pK5G5F*Zwb>wiidN*1eBd`Ki|cm1Bq3 zAY-+m&JgF|clQy2!Kbgr#P&O0O~~L#y`EC$yhqpRd{W+XQpo|{3zm^mK1+^eojxlb zgY@#dFZb8nX8m!{c&%46-H>63^q;+*NXM=>`}zOiwPjszf1H@Z{=6N<9$j)d?AB`W zJ05i8^*=?QCA`F&`K#B?E#E7NGpiks&SPtz2HyWpP*MET`q(lm<4gvVGV$fp;?W~K zR_W1I)`tXqYLW!vy%B_oaH@VQY>|VU>tQGgmlkLZRuq@NTNMyEQBnH?{hN>k2yK&<{-+2}2jKU^IP_>|Z(T(l?`DCh_T*+J;8 z(UvWyT4MM=S$WpBqIFWjzp`@J|JRk*MV;`189cn9TItiNa%8*Q&s$Xu7=Wqb7d8PMODLPii$3}-S4zZ(Fr?!=)PX*ljk1{<~|q>xEM5M z>EXJg6ZS(22&6!>CVj5LQSpVz+*^*=s<}#!zzpE#6T>!i!R|Zo!s4QBp;#!j8U}K; z0!c15Hl?8pc2^sUcB2G_wAzGl*x!U*()h7(S{-u0wR0#7ov^2?>!d~5N=cebr_cYB zus_G2AFDzq?CiF-Y;{9XfWg0ba6F5^m(W}wt%Qe*Va1V z^aq1sxe_S_wwLNcO#8QJ4xZnHUeTGN+x4l`>a^py(a)dS(l5#P6@5QeG8+}6{5l|e zVC<=8r*6wq&5k^fguSfvLmAm|;FCYiz41sXHOptlsptU>yCd2j6tZb8eR-GTp3wbt zFN6O2vmL{5=2VA6^*fHiQ=LcIyH6dTEL^{!8A;5T(c6{Sy?*w#=~cnnXGb}zQ}nkzxw;GMyWgk_m__-7k8C>+vYmmeP|t42$Ou!{k_NAh3t2qDao5s zY(0W3#n#zQU=4}4Ql2R9;osI6WhBtBa}T)KDmuWwvLB@Pihhy2aIZK8=WYHLpVZfL!ju*sCU<^yx*()GB|H1!=+#U z8NFdse)D`hzvy+pVD5nyy1y@#_D{^1{)gCgn|@zjdab3@wkENg z=HwrWFKPFpS6$GxejMMLRXUc$JcHBIE=c>v&htlN{+na7zFMRf$FIC^J>KG`#y?pVj^LniJzZ7M~UG@(0%+{HSx< zU|}OL0UP-8D=s95gfvH%-9*y+y5O5&ONCQ!rL0_-c5e)q>_e3<3UU@)a!^K#oC=K$ zVu%qpEE)a;uen~*6w#Y4-reXa6ED~sn;}2$C&19(_aX{-c-+JHvw}vSrX3J>`V4nB zxC^F5Bj2mdEX?-|CBXNN>t__l>yIWwa^f*DoUflPV|=QHqK&xt+6P%tD_K~hpPt)% zvPt?hCjM#m{ih7SIwih-HB7iPFb>Z(_U>C>Ymq|O;wNas2PSagB5@SnbGiZJk^z&30n`2gv)KWpdkqB*1W0LM_u28EAYZpO=RO;m7=_+Z zAdgXej7vxmrT%!(hp(V68xYp>%(rAv5dJANQ8BZ;H_UI~?YZ)@N`DeG7jh+mK1Ky!RGn8*ahvL*3g#vJgy&{>u4ke|sur}ZZ)H5aP z5C2{o#BfrHTkKCuRly`J4XB6;${Eg)nH|pE8Lq;|{iHUOM?Ny!J6!l!mGpa0sBBx7 zB%|qnj!r(! zbQWr3LzF-4tLLE!o+baMpMRXL@>pX@9k84_#@8mz)FKU-99!#0}L4u`_T zSwp0yQ{?tX%}HL^_|?<#4Dv!Yr|>gZ9NL))G_UF71o|GJb?L8lm^t7TGNdvpMq(B$D?7NCKy97hlAM$5GkCqt}lpkmw0TnAFp0%yB#(LjYcug+mAQMHJzQ<+}92{2#KlvGgJ# zk96+AN8(d={S&;;b#RY~=;o%=qE#kBr|G<-cC#l=QsEAzT8Koz_f!CsJQ`dxUyxw{ zc`pz6ihp}I$1|se0#4ePQ;9{**FzT)6LXL-1;h%QK{R3^nW@zKBIYNM2!II2Ucvxk1Q^ZXB?XhwC$j0QhCH93ePJ9bn@k-zwa>me?; ztk#JNW_6`;m=(z1DLk7Hm-{#p?cbuYM4~jf23R6T8&fzDe%-Dm++v!+6zX3SPE~=+ zW)=Q+Q7-NUh&SAfkAI=aKI54P4A*}xs*9{&#|>}LX3X$VI~r^(a0nL9S|x5r{K!!V zdJfpS)6+9k4d&4gorL>%g4nf<;FGg~X+}Y-m?-}i^YC1AHB*hFtF=vj0VYIn^jOWr zz$^nlV#$9N0UbNyD+Ui~rNpgIU#$rjkW%1b0tAijwi==Vjd19Cm^O@DlXa=&Ie49^ z`e1i3Um;kd#%L~QZU25Y!Y{2nI=Z0|eW6@1FBeQ!iqBqa7VBwo9d5B^zbUvDX6h8) zSc)*+ZN~{B08)pWP$xckKZg_XycB3nz_8_;yMaw<9d=-h?3lhIShgxPOU4~)&af^z zz-)pJ*CCXvGoYD(cvw`h0b8)a+$y-9Kita&u`O&<<&9wf%&47<2R7LTSbb;;G=)=} zBZjs^!SNIf!Tgutb-K{Eo|c)5Vg8$#s2(>^@4$8;6V~Q)0EY!G)1s9y(6%OHnwMgM zXKb3fE9>fO6wxcKbTFnHrrkI_ZVFxSvw5(@AH;RYW2Ja740lvu+Z)k4Fc;2lb3V77 z#18If2kQtk9(*KGkw(linYM>avLyy{Wf;sIj$Y$*snLhp1a#o!E9INVi}m$>+?WafWgDwtXnBFX$z|=dz+?ggLQ*X82w3U4lMBw zt(X9_^?4|ae&38}AF(*kWX0*yG{gmU_`S1>;rG=*8xsWHOQ6_T4}aeJ*74-QzVkc# zP+20!N_*d?#5dA=;kgE_y1{HlD=5$L{E&5vS9@`whQ!D}rqi)-!e0w8(F1da;6Q9A z{F8VWiaqn7eaN5vQ1`FXn@)~@n9PE{3SX^Q{b90zHG)u?`_9tfcs%tmli{MF)aw2eF#pY~9VhlrUT0I^8*Q5~eTy;I%X~7&B3FDn27( zgMKp`cLX>&%Rb$6Dcuy!`{pdS>E?d$U`iwQkCZym4B7o$FT--mg)Mwo@aU6S5*KJM zqjW$ChR`YBNLKWSa3|amv9A$WF>oAW4A*Si>+Y)&}j(BWg zXVBsX^IU0X2CX)M3|l~}e?TZ2tme3UwsKhbTWyj(X)WKYNPnXvm^R6>w7Y-JFKJN( zxW@8u9n;#}-T!#yw!`g)1MzX-aK*vRxpe{;3}C8DyrqMjcDPwunXMan9vXODQX+nH zgqv8QD@~7+ATWNbBl}4_7eoxa5(b+zU6GzhBs%tKaq}uZdN2j#2=Fo-TE6E9hjqlm z-#SI!Ph&oqIMn9=+3_n%z3 zYQoX)C{_a_(h8YFfjVc?FkBT!Vw0rv*4V<#7`YqLa#L4cQG;l ztirhqPTn$wP8%JtP^}!s+a^mA!4w@ouvzu-AICR%pIoqo-^*hPm!I#5#y4_=`^Yjc zzx&0?iSY4!$tR=b4cxr=6Rt+Q%dC73v+^<;^AyH8=8}D_i=@uJ zl+oTfG$Kncel8su5XIPzAo`3r{DIha;aK8W!{GAN3T2+=h#HeZJWB{z<6mI+Z=tJ=IXa?T5oXd9Ub5s-2bz!jdnyB=nF6-9Ww-=US< z>=%m@tka@vSbkkr`pOjB^gUyu@8y*V0*+p@*l>FbpO~mj*xR-Hy-q(5_q--=76LK? zZX59v6)_Wzu3L^K5Y>^%fi@9F#+2*V8{=nelkhM{^q)#3Vga_owmHUKWB$iAyg|&v z2XoWF)1Q-ggX!x0))k~X`MS%Ilc+gT)bsW&SLoe3e79plxcq7_^&{r`#0Rk>LzyWN zO+Q^v??)F}sEzdNvmDG2{!NpD>8ynF2i73V#WxxWF6=5BNsVjPe>9l10VpWlR z`GBKI!{*)lTRt0Oo_Q;xKVueiM;M!TKVV;v$^v+OZW%Vlb{lK=@7%iCUz~OS+2soj zqBmH#58nlXMK8@%&_nUEsRoxOZp(G=3AoWOXVnRWH(5So$O^w41zQ2QO%Gj{$edr= z@!~Ry`5-4h5IuqE`lE!EwuVS6)HRI{H8cPexF(rJu9^)lFM5C(RYC5M`|6)XuyX4Mogo`-@9;@q`W+Gc_zs-MW!tE^4rhgD` z>)Pn$PlDU-6DULEN8k2vu)p|wqiUbdOy3%KD7OEpl{)OBW#tW#Y~HC&vdQ5PK#Qf2 zgT`=`0?KSl2jyL(32|j{r(=bWmcU-I=}pX^$7J7`m#KWVpzV>9%b~ovub;HTx^q-> zrD7LVlB(#4B(bX`hi2NVaRGodjh>M@Y`~EPJO>N_hB<%b)MFEV+Z<6@QWC3klCa{I z$0jN{0WeMwZ;sA9ku9~Gya~KS;j;;P>U>+1w>D{eT_J1fS#DyJr<|oyg^@w@tKvm@ zJS9}|uHi!1WA@sCFcq#@JaqK)RZW2^ZksschFz>j*=}(=8k4XqEX>}FprD7+&L|Za zloD5)70nbFMM7$t8g9;8)gVj)DQ@T+QOw>glS;}$HRwfqxlB{`D4-||4WZtb5tf!@ z+=7KR7lNNnG4pX%_U$m3dPm+r;D4~A*owKIlEq$Vmu&wrjhh-d6-4y__7(5#jD=t= zxuv@6GWym((k zp&6DW$MozhsnPyvWk;e+Wm87x@rWK@j}#Y>1RzC4Fc%D4>H@qHM@$JCqU5L=t%<3y z-y2U*Lo+g7DJ3=8Mv#`niXwocJ5a!pf+(w0u;O!8K{i)j#ol=StWA|nDJX9s$q;s@ zE54)_BRM1Q1@&790D8B8o%d}1QOLZR7IrW>t2d>>_-)(ki(;%AVfH*()!v7s0xMs= zhN-yGXLka+Kv`&|x)nseiGzy_U;(F*{cO7N1i%T-LIC2kIr(7wNG+IDkEb0V>5L=2 zO}ZqGjJejER~L}jpJEq4-hYgD@stc=vf8Y;$vzm_!_~M5K9c)u@|s{XV-%jm`(XQS zG~Kz`9lv-1&;P+3m(UaCuBX2MGh-uMRYp~+lINe1y6klDU%pEknzC~M1Tz9ey<#|$ z9~p>2(}7MGJO3WWhKncoG8Iy##$P=qUPlK#%JsghK(VR<45>e2oI*G zn~Bpt?P;S`oV2)6bx@n+do@(wJ*-M|4+$8%Qg zQ82U$Qo}FV`la!7F3IstDZ@jdMctbVGAL1SGK~aMi|!RQlastxyfKX(^h9FJj6*cr z2+S+JkhH+Z1PC(G;*9zF05#p9E81}KwR)C}!z9mD>QhqbPJ>mZ5{5e2>tvGG4?*0A zW?*mybrhFB4J>7^5w1ZR{0=iwJ7)7*SU#4nhun1v7j&LXzCuS!s<6Z8A8^2;Q?4vY z+K~T5k^5QDNUE3+CXl;^Rgva>EPpyv@UGub8f_FNzwQ#P&+I1&swi2gAa{!0k`4aG zH}GA=*CJz=DkmLJHFB3NyaQ`C<0h;qsVAXHBP}e>5wMRq+`(U^<6XBX4ZoMBFG~bN zqLOvlYs+c8iln5?`-LggsI^!ppE6Ckh$`gbzUr}i^ox)Q?~+Lo*q@3_cGoNVK@-Y* zwZp615K3VZ{9JiCq$Mo^HOFr<2K7;j&NU(EFRHS}fqg!%C?(~6l%6>%?f*Ep^N)hY zmnVEG!Ea!L1Hw(U@qv7Hth`IZx=nTIoqP^a+aj4pryQXrChBy!Rq~Xj^&+?*G#;mu z&AyEgrq2WqTAbEgI4M|Hgnv=$dPC{EqtGaM%?GjJO(C>U0DvD64dt9B1*qt8BJLcN z`RYh!w%J;!Prhp!Py*9B)g^g!GNW8Szw3v3Inu#;$PfElaxL!9oK)1RT#CW<4fZ82 zy8hKlq@rxPiD#_uSVu?Uo*QIAu7DOP+CVZafC?fkr;$Dw7~Z^r-e1cPPGflm-Ht{; zk*u+HxjuqUF}G}6xu|G<2mT#z%5MOD>dc|$5jse9Giu0O*OC@1apEmYZ`YCG6qMnFbzLH10 zO2J$ezZ)rDOc3f()3#qqkotys$LbTplSb>GR3;NSv{h_j^T4h1a6@`<^0>cU4ig!u ze-Lw5b#)Ao)S$K(|8nJ3bE{3!{AbI#O9r-%z@z~$$MwdNgTc5dOg}{jMrpqI>;Xe4 z>}Y(Q9HoUTL?)zP@`Y=kut>)iUS95e=g&T3t`9BY%~byKuJ<=1d4anY0Eaq>EZ0T>R+7jk-_%A_NbISG*pc~|LMH*eLHDn#3Zsrk2^N8TCX3Yh<{*)H(IjMAN_amG21jSsqcbI6 zbDS*ojjk}~Hd%Hu{Qf>PY2G{{RoO}C!U^(@Ezu=7;ePy-WtjgbGktc<)w}A7y;CB* znF)JWf=3BNvnWG5(#9j)$5_w_B{9;GI@xlHVBm`_`{@oa^OoA&Y#(~p52mPgA9y%R zGv+e=aJtXYB|8&Kteo%BDSt~Jq5H1?p$hsK*@R(9Y77b#)fwy08+>DEE_vzb{&mF2 zfq8NaSCM3;f~>jyZcj$$W~TbjMSb`EFPMyOmT78Gv)!?BQn$vRxD@ugutvctnLl&= z`vlW#B8nJI{c9A{yx0A%&af42;AHIUeu2q(HCZdO4pZHqYJistJzo8oL2-Y2(?)7y zZ@4l?Cu4&vKPP;PmGs2uyBe1x2xe>0xH*n+Y3aS+q4MyX2pAqeq0?Dxh)_Jn22>Ek zb~P48#IARqAZE z$-d+hCG5j^{`2^d6TKUKV<~;wRlA~)&VCERiJ~29Je`Cr^T~w<Oh+)PZ8RWhe~~J%(gJfCWoFljj_*Go~0}RAV_=SDDNf} zM#)dhv%+2@2P-zYWNj-5HKC@MOb-17YGt11g>IMV8rjx0Tq&fRM0}-Lsl4odO0C545>bRqCNqD_Wb@9 zQyfO3v>o}ghf&$5L^+5J8c_m`XH!Wpq1U0JFH9GB9nx{4;64zR)dIZ72#>AM0TcD8 zxAj|E8R~BiYC=9YaOG;mZgP~G*}9T`GN0p;Tg74N7nQ(CoXZI{icfS}U60dePRnB) zQyJ}y=7lwwD{rRQOf5PrGD5HVX2&9g*RUUTD%9+NZpwwrjaUJnP)^@wUw@M!PaUlW(NY})t2IVs*;xQ#f5@Dn3d;H^nkT6|isP_Uc}PHk@p+t7 z`82E~PbW=X_X4{nHktpYO8#AG&b1i;l=)gy$)};O3$zafC+sB3pEP1a5-G}{!Seko zLGR~|!=F1bMe~PgDc{j@uRf#MBvwxB6DTK_Vb*l)>UUUN{jHx8r$Ig@f*%x1a-9id zgL9O_A996?mWN7ng}sdNlv2cD8TkFifmEGKiF{QT5o;QN!$Oa=GSgD;36EkJ(BM!s zhdwB42+P-JQx~5CcycM{jP_WkYO0U)JP!BT+)~R(qGkK4h?v)$Rg97TBG!viOo+Gg zbes1b8{HvSw&_8F~iiD+IsgB~*k2$iG2l$BFbpX16l?jKAsidK~! zLgX%lk`R^sR0GP+I76&IHn7HhBP>}pA~e>P`!K#kN0CTUw7&K&jDPC483!yOG3m!u zCP>Q623M;xLo3uZgq*Iptp{4DYiILxl|NR~t(KcACsI0d+d0RWko{v6K@a6mY|>Or zNp_4ky+s>)cxjqk|NzD#N&`jaHmW2%>;ZZeip`iR2olOzAkP4POaLi{3rl+1J3$ zH5ioI&TzNlEj;B>g9a>7&unu*? z-K5ZnxNzGWQxP#%+KV;wi&x;bvV`}nEzg;3=;Dx4NBwAo3-srFxj>aZj8eHQq&P!l zXmySFL`3Y2uLhRTe}zd${g84`;#2TtGsg%4!qhi*hi7D%9OhJbH`~_FboBT1Q+WA$ zBAh-S@>Cs=602HSis0_3yJ7^13j_t*lRI&wl$gWAQc_wK&ZDMw164He)@7bYM<-%z zIBcAy&2DwL)0TtXi;&NyEU|SwxGNIBPsK~^c z0FF~h{W{EpvWwWZ)xE z{jWE4*rx2UOr{&rers5*f!ro-z)ChQpEeRtEJ-Aj&Dj-Cr7Q?u9 z@WTukx31uN1N4dK+Ut`{%>F1mzJ+j9`_1o1OO$CMq=0#eV*2*w_=X>D3cPzqz~7{ z&M^yG8mae3k|%zG_(oh}6u#-v>R#p$j}R6T(DVcuOSQ;zAER7E+~UQs&RkskKFw`z z*kqPwSi{KL*avp=$Z%P)L0*I(KBu!3OFzfqU$Mz+2#u-GNKEVdLoh6_PgXmi>PGU_ zr6A00+{{wlb|W#7RurGvJtMR=H|Mo*b>E|RbN0YFES2VZdDlfd*9?WbM?5@GNr-~Y z@eGoSuw3Nz)4p)uqOg+1<+-bzCF(fd@^I@Z)=#iylOPq#2B9$0@ny&>hv)6MYTY3N_kliIlAvT?ne9ohfkvLcqDM6_?-ZOR@x&4pD7h0}_K zM-YWmXt6g_(?(y0S#_UukbfCrDz#5wGIMO&eU182?a)8$*@Pv`0}}qQIr+sACh*#* z=<=XVfko>9%g0Nn(90L5cHZwrow7xpKV8ae0i~By9O^0EyJOr@P3~XvT~{vEHW^fQ z=d+$VdZ7;NaHgF#d|s4YI_LX1+Jjih07ovV&!>lQ|Y))QSR*U&J zH~V*s1q?RZay#p(17k_U=OaYbcVNHWj%&&A2LeBeCifwNkop>Te6bvc!TSO|i5lg} z%^{NF84iYi%6@A4Vw|K`EcRY%J7Vtq4es0xY60S5FtV3xTk1W?@Py`&+Sr)tSo&`9 z5Rb5uIInWe7MH0nQJ*NHZx5pSlwl~I*bkJ}yQq~SJk2<&#iG4tbYDM|1LKPK*hbo$ zcBQGOMDj#ci{jX{$_rd`?EN-20{+U~ONbGT-lxqN>`%)F1#(^&eQrvfP{@o8`RzsY z^B9>m?HBZ3JX@dv{^)zQ#~UVr4@r+);s@UZh8QFVU1vIG=aIF!biYXOX=!)2DZsfY zW)v^}I3Fdg7C>QCG^0@PMFBSHW2zF69Z=Txq58?k=C+UBZ*PSJD<*CX_*Q$DZmeh8 zp4aI2OUyeVBFWX|J`oK=ty`)%6|V}#-jbcvinZ6vNH!?9H)yrL)|0eTY`Pt2X-e2^ z8U!^}i-md%*t&&rFl9@zts}J~9|uY@^Pf=RqTH0ylNu)^I~UqJHzZ+Q`|Vv?k~-Tr zmCvra7(s3Ba#h77&g>zSk1P7V9B9_u>L#Xt;yrkiT^I7uugCJ~XUC4u9#01a=52iM zKEG-o`qVNqDn7&-(?oOI((%!|am&;Zzl>1`u;|h}_jD@n>G*!f_|eme-yIXEJN+SM z@OG57r}!k}J-jX4i%Y78%;f2`jMS`h=d700oMGqO%Iz0RWMMBhBgs^J8y)u#mnuI4 z@j(CVmdy6%sa|X-5KG#WrkUJXi9! zKo?YwX6K>+6?%t?sV-fy-Y% zU~BIGm6RzL)A_ME{AdF3MEyuYiIlju!pIH8Hz_IqO3IPDgGf0MFcp{xsSx}43CJ25 z0wrLgdMII=s&C-Dn`Y$jemC9B=i4sQDtgEQ=}Cc8hV?`<3Mn)WRzV)4fBe8GQo_8S z*}5NIELcV6rnG&JcDr11DyzKwo0~SF*Ssb#(77pL6S>cRAyK;)hdbz z^%54UVJP&ZHYqW-!i^}F~QaE-G?t>-rD!;Lh7Qa7y^f_&GLEZg`5mYctSUo6ycoS)*#uhU0*{Kq< z^`q^3(THJ_%esuQpSL5ThGWQ@Br;Z5x_#mquP2+oY5VNM04B3i_vig9>kO}3HXCxh zkSqKH!oH1a8TMYGScTuDsiHBd0e012)*dl8w{H?2nGG>Pdf59r9eOx4g*?P7gzr9}3sPx{0NVh;ZqRR@@1?cz zzN23mZvPgf-{V5ncUWkzho0mJB_j!>*&kgGFVG67mZ_r)(#%KK=hWdyDzR>f>k;=` z;f%qY4mvTl5j)L#?A!NyjQ&nA)Z=kbT8x)&4CUMMS~m;2ytC1LxZ4|H($C~sx9SJ- zk}BO1nhizvpSKky+Z1hr8cNuyoHGb+%wc&A@5IKOZ!#q`OQ?I$2SVSUj!RfqaW&#% zt)(#8natJ1H&!@m{4iO$nQs|ytnwZEVc~Hz*F_{c^N|0i4edV#X&nqX_S0b|Y;kl4nY`EJ7OXDFi1ien@Z`}7?xfN#Z zY7Wu7FM8`2wsAw$($lMy>_!_pVS=P$ko_KZa%U9!PdAE1_vgm@8 z{^xNa?fn7NQEw_`{M<^Djzi1itw5_eOFE~kk^vI!DUV+WW8|0!9zA)s#CKVjz`IQ@ z-?tD*?%itF$?hT6zUo%LA3Kpd3@*A+Sr*LH&W`NxsJ%6YomTd6t_BUSLy!wNeXB<@B;c)MkBDR2k`ss1?p0XBK_p2jzR?g1vZUf)vs8o((lb`BA z*8FVdS1u$>{X(e$G}FX6DIE3=7M-vTLW$KHtxDdEVR6+ zhC1HY-+%X)fmJM_f>0xU`ShnS;Q6*z_szn@pUZaCyEAe3_kl5OllOeZWsJeK7_lM;e3JWcNvMf zC;*xWK+c5JhKEoh=+I4`83o2}n!y9G=m8aL8xAYX>k}zg23%C%@ks`=@40nyFiVaY z;1?y7TMQ!P1bKEK9qh`*&Ek2$=%)=!&DfCRysi4>ApBr)frurv%M zPYTNvikOE3iTrq+4iuPu8MP4P9vRd~PBLXK&U@SUYRqz}wVY~MszuB3T0lLu{+i$fel-@#Zd z9GC>1skl&?5TeJn65o<#FcYL{QspI56^v6!y;8y&B9)fp;zLu_u2Z#X(sU)#^shNC zjM8Xl8=u~r!B1ITqJg$ z)>TZ=rD>6`j0HTe>8h`g#Le+xTZ-vTY2GhYGG0u_yXwRuBmnl@@v+XPwPcwFEQzin znQ2RzdDodgY%&O@7(JTO3uxFLI^mSHW>qX@Rb6MTv1biDwta_-w3Y*M!Bo6pIk+B^ z#<}xvau+jmiD#0RjWt8i#*#$t8cHJXawlU);@Jilnd}?JgsrswCEioW+5{y#N*YDqNA(Qh$WmWBoHoQPc8IJ2Hd4D3LAXw?_H=-z@88N{uZq`ymVTPqXZix)kGh>=wF?|KutRw3Klt2&P%XkKyDT2=Ln z#NNu}%X`&Kl+~l{)st)0)4!`{8EfWcY8EYPmIG>5a%$GvYc|$uHh{h-3+36|22d0z~8FC4QHAvght8$`zw4W%)##Tm|Jxsm0x( z>ba~bIaCT(joLKt9s3&$Wo7~2+R90>K4mGS=5%p%sJI9J(JMB+Pv$y;mjZ^Aekv~-@gD1z&hoI%oN;{Fu5Ht-40u2%_D();hX2*1IoO?H8y$49ugMafpDEq-nlU`)ukDlTF z-e+@-V1C~$#=fsl(9k@w9^6iT5EX+v{X4HdctlZD?mhHiC$hdR7_Se;EX~x}fB#88 z;J6@=2(L^@oHN9K)3{Nf$ouQB0uj;J$shHg9vm}F3qrH|}{Qe*QdU&)Oo-i)3r$pv@P(SpZh z(g!lEOT3v-*`5008<2EusLlG=M?`pv|G4J)NSEJu3^Svr_ITFgNK<4V%rA8qIafl{ z2a6_~P3x<>0q9e|1cVL?$i_6KfEzH9WOp9$2)p6xoADL95l6(H?`I| zwPDS4ygs$XJWahgzH{QTU_D)OId#}MeX=qAwVLDLX8OC_%ocUih4#}3jcMG|;tjt> z>c-4;SqngZ7HBg|D1c;n7c>j*nkD-<8*5JycQPy6+>~KTS-?N1NIQdHR74-_$i_?a zdt;7~cAn#PR=6J}+-9D)Yo7nMgGV0D_;Fq{eby4ozaSC3Fr~a8)wLie&oUR!D&qka z`nVul3RJn{pqQY`?#HMcyoi^wxJHe0HMw-|P&XY@naJN$=|s{p%F>XxcN@mK!`i&aVY|acemC9b)4lE;{>`1bKOc|! zS_B`smdJmW2>Co=@R@XbLu%TztsUTEKOrBYxiv@rMs9bRc}qiA1S&72*}bdJx@YK_ zq36jEBfsa5Z2kl;+A;rXy!LU=j&Wm!-;67M-YV@8k%ltjOH!v5w zbC|}ehnj{y+GMiGAR!Ms(o@_V5s?Hh9c72`%IUuPh;#@q&Obg0KhCC%_V0$&Umb&d z_G}EkjL15K2c@Oj!fLwcG$lSv^?^&K8UHSSBv7<{=DZ%rF;Pd(7^-U%Ik&seEtFHE zuzUShAbVnHy4D$VH};kRC)q;n(8Ct(PwK@uL)D zkt_STzrQr7J$}8Wp4bw+cYFQp=HMj+fM?gM5KaIjrd7$pp~exUnTk44=FHg>GDxQI z8#1s1Cm}nDW9V$0iX#3`*9DhhIO6rb^>5?jv2vy=C3hL&s{We0OtkI+FOL=KA5W?A z2VOdchifeUWu7&8Y#FDN;F<20;i=X^90~{E?)(5SyuVzC&2_bl3HHDGMfL1L|X_kuwR}ZyIws4LI9ljJZaC^)bp*cefIw+5uaFADQ6h|Pv?d9&R-A7-tIR9^ zKar@g63M-P5#!k=Q=|lOOefjHAdNVr0$hZ{w0y#Y)X;R-aY$5Sl>^XwvEV>`!Hgdc z4fM=b3QN^EA7MSwgg0p1O;Qm>aQ+jC=B>p|AkwMed%q$9c4Xr%YBt9mkP7Wi`Pr)ID){*;Dl)51my24-kRcql>w z@>K@lwjfuP9!8IPjnHrg@srG&CaFqTIAuCDTReH2`y3!Z33pfy!le19nS2tpI@l z3Sz*lmGLkPqQMakFSZ-Wz}l(4FhS0B;6 z&a*?p%h-0)!ii^TC|WawE85eWYPq}2Hy;EaXK{GKIZCYtGXeJ_-RCveQ)d`=S zs^eLlj+*^j+7`+4XMW*?q=}ki6Y0vuT4{_ycIh}orA?lb;?P6r_F87%VW43}eH$A0kj3f31$lF@+5?>Jw z!*GykTPqI!=&Qbta;r5D2RNemY?dPQYnf)aj}1txTvfqLj>J??`>L=s|65juY|k26 z12~=w9}H%j*NXqj(fFn>iGw2GX4}~@3NmOal1=yvFb|CGm5*A2oh}@yPC7eiR~p<>O$Q!a$&? zuK^jX*S9&ZB96{0MruX6oJMN3ra$u;iuGll=Sab}uhY_hW{20kVeLzEk{XiM^~qSs zn9EPD|_()ZLKN&A$&k>gWc0uL)Ov!8tr(x@sGXGIB zUJ5FcC1x2G=%5v+As%9?y{oW%1PKEuBr?=KQtNab)k}NzSMyEC|6T7e}?eoD5_8d0Mew zg-pSCa`DPuYRB>QXDDkb!EIAm11hmk*toQ29`C$~>x3KR6+sT#`in>nZt#ss|4T?; zkE{v`(Oc9nX)CcQ?V}GR?y~K}-tO>HN)5oP$+NwmpJ}CvlRHied;d`Qu3ZPbT55aG ztIDTI`KCiNj^^)G)pHMVN}HyEK;-HBp_1uz6it(oenv)Pma6ps_A zD{@Ay3Erq0sbVx{$Gh_orwV%RH>-M;)wvcrsPmn~tvDm;54u)4D>rQ4}UQz!vvyYlE?A^>oY8qWEU(v{>_HBX! zwJsWU(h`0#yH6KtYGOAZRB9d4BI<*7aLw$blb~*gvBv5>#LH-xwY3}MJyB-e7tutqT@rhvU)O0EF5%$@!}~VPkicdUvtRJ$*=h9VXD~0AT8X% z`$^G3yPOaG$IUjsCx5UUYx4D2zg^^VZHV5P}tPkr(vK$|--WKZN`bY`r_kV7F9YbeXEOd|;-p&bKtDHt^~1uYY0d8ymYe$7yf`^fQy%T5!u&dzZd$!*8-n@Vw!vDh7ilCtGoGb;-WV$0G z^f0MRc=>}5`qWSqhE5wT(aVF^Spl$S%GAopacfOC>V^ROVbFedct6J^N#0m5d{N=9 zV?Wo6{(C5;GzCsGt0c-7#7?9FKilb$!}`=MD_vRvGQ5@j^J1y%VyUKY(EJx@(f%Ur zp+k%iGYMZcIx18#2culz*L}!@7UET2H6(uTAb=(4#bfU6;OXU78B{SIRCQE21xTQT zl{JC}>PNd!xTp{CHac`b(0g#tUj?X4!p4ji+8BhU3`Q5mX-<+j%gBfOHzC`5iEEzb z<{P2Gg9@aM4>j=y)OLq#FNPj5_t44>42a3qajOkjb?*MC3A7)kAsY-Oh^=uM4p-aKd9Y}CXb6^!e+~}c`CKyIf@bgbq<;Fb-0%3KS{1%N*a8#d2oZC?^v2)&H1=Ox^ z9Pfiq_KQynznAGlt41;nqrAufOX4v%tP`)2%#Gp#RkwFCiE{J*^gS1YwfZ5n!268%eNhcH(w0d z%k|A`A?Fiu%QZJo?(7trd$}Cyfj!ox+hh2h@YbWD5BFLjx+aK$BU} zDO(s^F?^Q2Z;%+M5o+p1RAZg$cJt7zS3fv8Cf15Bk6kOH1WYV$V4O#MaiNdVpKE&$ zj>ClK8EVCjiXLB}1x72n`NzZW^r z?NJ*1g$5}QO8M|^ge@gnhd!bVz=lheIHl0hiugQNhN22d&(iR*qG#FH-b@0)<0AAR z1~-D1Yu2D6KGC`b+lr){Syna!N^C35HfF*1eWW3bKzG(iY+_8QiA}}Cm_DpSRNe?^ zukC1Q)WQJDCmC68T;M6!{>^|^xWf{rfze(vLW-%{k0#dB+z95fKEHNL$W zHBwcjxUw+1P*}2&M-ZdvrR7yh(15c(1bQyq^kuF%OYCFL2-fdLhGz^;ERlLC;ifE+ zuKpqhU4ydsm*iEJu;>bF3)2;I(+_#HMG8yKHE9VIJkBO2{9+5XV@8PkWRTn?f1Ne~ zWF3*Llm?@RVBoD`Xxy*>eI#*t#kb^8puL${Lb@FbK#({CTLw|v@zPN&9fK>FM{+Ld z{m%3@!M))x*eWhDf!ke|j0Ee9ZxzCEOntP6j)#d!xj`ztoD$t+9{+R6_~kOMtC?yM z;3|1K)p%Kq^TTbBCR!EzP<0K>Nji#DF~_c?mm5o+xh~E2HSEWRUbij7X7kzcz|}UZ zavbYe*xdc9KK!AUH$e!V$o1Cr^|sg9ShoVxv&hRnhP>bG9ZL(jzatTIRB#wUZ$tDW zFCwcg;90L`pt-o23sKx~r5+vL9}IWNWs=+Qg-LFVZG`Y`VZ-Pn5UcA<&;sGFl6jie?CHeHCR^fM z#C#qOW49n;#z!5}i?=Mq7a|+IQ(O?Y$BQ5e<~_+rUb&{Nk+XYwmUTprG&T%-g=p}s z^YBGcFpb_DEt1l8qyg8Sp@b7pA49{Aq?s2JJ6^13ye&ebzz45Hq_<*diuupnG|O}& zi3Yw^b0kyIYa2W%Q>!dXx7e;oLBLf{NW?^4TS0_N?d(SEhj1E=uas3|U}vs~kac_$ z5#yREea&%6R}qc=l@hMyGUAo>1_x&bG27f@rlfh4HWO7MQoy7u)yZ~y2KR&zpdp1Q zR9vf0!bg0@-_{?ya!hSYB@p35k+0xk9Iz2~l7XQU_lGKU(0V+OjiN1Q_L!OV;^=Vx z?2A;u_F3*@wa>Qo6+t1Q4<{N^r;|+Rk8s&v%Jke-P($?!eqPW)hpQ1|<4#S5qUNd< z<|E?(r@trnEiZu|AO7OI0>T*7|-CkLgM@LQ5v$zs)o@Kty z$jI+Dj`GutU?xV(E&e*n;(B15Vh;)gj8ERQpb<#uFO*)7!K4A^1W#(q3a|q>BX#j4 zN+sDseIDjXBt;l_1pan4YfZjS?u;FSMNhJZ*6JST&4+(=jTY^lGu_h7-*>N~ObX~w zxH#mS{_WaK^SMcS?_~Mw>Hb#=S7P1whujL>kBDE>zbwFNSxYV9e}1iX&JZzpTQpe@ zT8cs@6mStFF@^#j?40wX|K5Ju3K4(JNM=({L{3Y5k(g}duX3?GZolKwK1jF8C4!#C zk09OVW92wUK$}CJ#)Syue_&>ztcDs9qYSH-MWwtPoTuOuTf^Dg?AEX7y?vj9hB!0e zxZ~_hV&-Q=jO35=>np_X0Q$R_hIGS~m$#Aia9Z3ySiqKD3}H2zoYFV6k&F>csTYoW z0wtPwmd74p6|)s`c)lnNvl8~fGQ6>;!kybV9P?)a$6iKOV?qsEqyN;TL`W5o3DGnB zRGCiSnDY37|0|i>93;Y)QiE4r7W&C|Hif&FK^GmzDgPxe_POizPFQ)bDAx<+q#Y#V zv)CK?oJjw3gHuUuR!7n&a+GE#EdoET65}rVaH9y{#0#b$2%$yaO$d}ng330S53+wQ z4?Ql^a-bYA(AluJch>cJWU|XBrpKD6`P%kfS=cPF8z|EIJzJ7(B5b{=3{Y*~_Hs)N zJ>|=`BR5p0*~2;2YxnJ|@2g+0qF)^$x73DMoC6M-kOBl&JVZvKXyVZ0x7F_FQUw=b z#efwrwF+oY1N^4!NT5zbKTiiMLQ@j3h5xbe;Dk9AvpmT9B;$3i{iwIP> z3L+y#uhOhaSi)!7GaeSiN~^s7?BU+kCTvJo`w}%= zWaJvwo4q8MChcK?E-YDeI=ul|N*v6Qp2-?`pbPADo*Y7tiY6CySWcXbG5Q>`TyZnC0{@x^wwk zpuT^oTKI5g_-xBPtK~p!dXCP=@`DfL!{PY*o*Lg~Q!vtu7%QZN?oiGVNr|DRS*v<=piKe=z>FzYMJmgF7;oA5?6C8?$ zW>Vh6S41Xcql{$*2}-JqX+U96s}d;Ob?Ir`AiVl%ALrww_B<5qHr24i-Cld#G@FC~TGx=i0)y!MH0ShRxn#)%;M+C7l;l*{jxOEU#x0+w3Gr6Y zyl;>ks(aW=OV3L$J`k~pf329-C;eha*93O?0XQ!!<LTz&9~fQ4J;yb2pKFtrp;^=_s+7A91~G_tpx~jEan|A?v_P;R zuvFIBrBLm}{=G+5i}RZJ{^x}I=75yN3_cI-B2M#~7ZRc98cgf97J`3yh5q)|q&%PE ze3{((-Rpy#B7PE?&46a@WM_HzeDMk?T0@LI_v<@*Adq~GivcI!U-#lY+9dmXF_r`{ z`#f60Hq|T@OD~c`<&s6CiEu&rvShD8ANUiqO#N#p8&c^{x=6P0Dnc2uy`gzP3?1OOO?EA`6#W!JMTHkS$h$b_2I-RLs#jHKYO7M|sabsA*R236 zL{#7D@tc6+??wW&E>iFt@D;T$R=BVPQcxiNqt%5W7j9L1kzoS<>i3FW9*WNfh~pn? zg?oq~cQM1v_i%2F0<_IQ7XE!w7(e-nY_Jo~q^_nyx^6Q3h+*ydgsRhH@vyO{I?{8~ zT(0~dY{=`4;%I-6!#247j9b>r7a`cq9mv-s>AK#`zF`Ag&~7v%9xTr%0yW<3#mX`< z3X1H*?$nKM3x3;Z_F+h#V}SJC8F^+EOjH+%v@5wAs=`#MY+RzsbX$q&y3^uTvB~{% z%AuH%qENLAf!SE6k7G*vS^d+_Q(~aOXD*RX^WtwnKYg-5ef|MgCy2_oYVrawET3e+ zC-ddENWUX?STVmvLbvZXzE&aZvwex>L8NIvL7lpVwNPIrSKy0ht&FtyVM< zfTUzJHAWgxRM5Y{x1%!zPJ{q|lQOs)bp6sO(LiUThc_Kbb=}#)HSMbR;CI8ut-)f$ z+wo5vHi~eB3<B@9x4$`6KV0yJl0!j}pc_@N1SNy_A&sGQcaLRr@z8ne~u`VGafFne&KW7Mt?;B-krV6YHE5PeM3(RH5#qE!& zr)U(>3LLwZET(ZE77kV5MpsL9f3TTZ&uh6=nn0Hgl1$Q{lTui6`R2T3+TOxe+Cl|D zT6hbI7k#MI32CHh{nw{@qTid|h1~Ml5%8__Q8m{^$@4oE1qMeDmvt(OoYv~#3_bga_}(ny}wApW8Cu?VBcv( zp!$xO<_*0yW|^<)=#Bs_N%1&Yal>woJ`Sp(b9V48okJe+dmCVq+=ccn;M{lGA^rkL zM>_F?zmq5pdS&^QPz-1L1UuA26`Qr(6wp@L5*tAL+?%n$2NXrtro-gx!dK!e?$`tK z_c$5Fgob!~zQD5+d`@bbQmLfdtE+iPE4tT)K?^zOX>F>&ChO4dJFQr9`|Jt136aq~ zePM$_B;D9}VLQYJ=<%+GOfqorUSKPpJSQ{zIt*WK=Dr7-1f4wtZr-5mChsEGWnS>f zeI(dAK%?YWr>L>Nd#tqTBz*VTRmyd94f}2rWQ=iSOX~r{+D^mV+^8$+i){xxMd(ujzJeKkt}V!n2xx%<+Wfh?_5K^ju) zXOk5<`$aq1+WTY8@5I{FgVJ;PB1|mkvwxp%;J{qDt|yS!wJv3nut@g^RYK z_D^kzFN6Ok_0Z&slUT>Qc{5GPQQ!>eeFW~RmMmzwYJBtYVIBJOT z0U4|0bl%KKtbNI~Yv;MQhiLT&{f43YN2QU8a1IV6sp{c~1#iL-t5s4US%Ps2`txhO z2}F2q(0;?BNKo-K(;fnnPZgb5Axs%-Toj}uY}y6OSH=R(CMQnWJYU za02+vo>1HA#DAjAzjA?=$$!dUv0c*-XvigJ9Iz6wwPw6`!E`1LjBHq zc5(G(K}CLIT9^aETw^+Ax&9l{=R`2jFtqRWAbxUg$OQN7u1_WLp1}?NbZsyHbBAreOqqwn(!FDPxP;%r`lT0nC>7f@< zZ~ah(t6Gb4qXUgf${6i_82JmYlttq4It{>bMDJ&I2BLi?4uSV%vY9F z8z-AtHdK3u0P=;;mR`$O82-jkd`Hcs_)pdm&zigASy#j>!|=W|u8#snIBN596f>To zC_YO~aab1JB{WdrQ5Zm~0HimxJS>(xB}jpcoRy10j%}1&l&IDxhdxfcHjsKDSa2yFd^wtE%uQXdNCT1TwD>ANArBamqBk z@#9!)pxKm`$>&PMh|givTR8-zZkn9)Glf64NcjXIg|ysaE6%ipNAb_oz(lo= znJ4Pgr>G6sl(5^6tywewEM_ufd1CCR*ksYAebDSK)eV%zKJc zycL7}C|!MCM{aLyW+~mXZC{_0v)lf%n`$BoCq#u5R`F;ypY6;K_ReKm7 zPyDW$F>6~Rk+=S03rxriCKJ(JZ1I=1b7Uh-~BfFv`pt8mv zOuIf6quN>WBh)eW1xmhk!O8LCz|ZlAH=O&`*gcy=WV~2*g4#;IkXJu~j6KdpWe}qj zS{PT!O$bQ7JCJOy+C^mKjIuA*bk6`9$t1~I%p{U9&l!w_i>J$O46ID1 zTA4>7+w2F0CezE(bCcxM99USHMhjHX-&s@VY`^%@VII=|((1g>J502u<*E8o1S6@=MixOV;K<`my=+hVO{m-h``jW#-Vqxjshm@v!&IuKn zDN6Zeygc@VajoHTe|fBDC-UdZp)R~H$Y!#R)ll}tbT5We5QAycJa=;{OL;S=#=tyH z$JNIPZ^f%#YUubBsa7?*JO*UfoFlP_^c1{jcvH>^ie1M=RLR_!O~NRZX_AU6MVEeh z>nNeWJi^1Dxuqbb21^5zM^WN2BFn{Xn!V3w}OhlC>&hXr>K{ru58+4Ji zDgmO2-|Gek1UC`YpSBEDkh?qBwBe^f-$^s&E?uFN!T!9OJwtXMojY2Ef|-I}w`>Iu z5>X?CD!W|K=26EfEP|>oPj?z>wa#jHQ@-pIX78~MZHJJz`PQCa^7Wl*_m#7uwFOVf zzny@E^xc`LS>mW&6F-FrT}8dSN))=zcz2!u&Y;x;wyecuar#W>3-Y&a$#2FXx7Hzf z?YCN!-*ovuFQ?3{8ss$`e)}UD2Z%n^hMqMkl%p6YY>VCzJQN11*8%l}L1w}#FpU>v&)pqvw9enP)XBh20Z@J|lajc1H_H zp85q)bhZr^!k2`;L^g*freJ4+r@rm6){-xH2~p*U2-czv22;iaVIS+B@P8$Jo=?=_ zLD4aV@vm#$SuHCOW{~yx=IgChaY_^S{9VyW{ z;XUe74-xP{J(PN#HB%Uqe8!YH2st>pQzVmlll-bWBX{0~g7=I?QMv)cfHcpwqqjw@y$9tLwQI*!_0yJL zYTCLJUf1@KTb9K~=iA>?g<9X5s_?e-3P`XTgRca^DZx5>6(#&Ki5kEl)wL)eu zl!%Ge4J169|N9rBWy^*HSM~k5`tfCC#WV99N<jTF{$EkwJvEa3EXRy7%cTF{5&Z;xeqe*vmxKDvdfASF*daiyQY<%&E=1heY+-_q}L z-Rlc#e~mAg<=02Qe@efgU)a-M^Db`Zzy1qL{!{wJ2cnvbJtc6ZU)sYPso?Y{I9M_Q zBqDiGXl15_gC#u!?1=Mfe;?yYzkguKX}jZw4V$tmX?@am{24#ktu*!pO)-{FcGVvTpXvjtxCRYMQqQ6>u=W8lh& zal=~ov@-2dUbC@rQzb}xZG0*J6U2n2p+Rc%WU1h*u?gGX-!ps3%7mj4lMaljhF)@+ zMb&Kg9a6NoyLf44;?%j8j^{RB*l`Tuf3 zLeLszvgprDg~e;CStGQkWn@3ECunANU$$8eTno#fDE8{+}r7XDk2Uc!u`#~<%P@b6vq)trM+-!4Ci@4V`tquHSU!uJMhPwz-Q zyQF0lU~eXUHGs6ERCpnO^N{w9p8=SDG~|)?XT~S>B@)x#J-=l`sTbZgGXKEMW{s@k z4f>60==wZkuH5C4y&ls~q)vVFPej-zrqnGOGFi1J7uJB~_xW_M+T{OS&nyOt4FZvO4t394XBWJwizo$ZRr%ge z#_Jwr@$IdCNR*n-`nnJ=X49*6pg({n5qYZ~ETXQi)>xwZ^08))l}QIM5^>(^b zj;7x-?!>s{`D=Nj)Ir?ExyMU~sW1OTgz$~~xrT#lMG{P_Nu0iYmfJgJ+DejSQNDe= zQ!@o{#Wmy_c}r`qtAtY$j-k|*Om5B7v#SbB-VbT%Jn*Q{S@^SgHQgfeajMks%R1)f zZ%Z$c?pZj~%@(obWD#|(BvR~V8-tXw7vQG7%r?F%gN4 z`^bjXleW_D>LXHx`=lHD`#+nH0uQfk9}l-#NFo2`-<`;l;n#Z=x!wPx;1_65O)r4e ztp>M36l7#1>c7Bw_h-ew3?*H5?R!y|TONa}_V#>%e;G=>jr3elxPt%C-MtDBaQn!RLdodYS&7q4hd<=w_~lW&fUIk~#~)Lf7LnMrX*3PZbl& zFJ3+2f&R#*=sHwS5@_u?JPK0`(84JA%y)Df6(Z%e9v85}vNly4^+=Cf`~p|qW)Et5 zbY&Y(10STw3&p=YK~#1L)e6m@4kdNB1PI+me0?Bt|J?(I^k?$N&r+_}%o4D&7xNT( z@egkf^?C->DAzJKly1*5O=mwNXXwASlr1Iu{53Py!~VSo#HJOo-s8ocwd4z;tgp?Ey@0FQQISif^p-asvLX1)|N2&RtG_sJWv>y?gD=6J7 z)YM!=3ox=79%<|sNwnZUR~)&${oJH4(%c+kN+0!x&X)ia>Zrk`On+aA-q1BmY2<@! zlR}VWp9jRr$gwHP#2|Fh9;qqvf(AbXwjB&1h$1p}%uzx=Yen0aqF-Vi(Xt|vRl;8XN&D>=mOz9>nyIiNkWAaDCCg4>F0p zdl7qNZ}e7!idOQ3}YUM zBxEVE^_jtk(IOsFii_szQ-{#{7094eV&G@4lCy*mJoRvMS7Zpc-Gs)AmGVTY@1fjA ziC4rnH{~>cO7;2Xo&eN148AAbt9N}Afgvn^NYDluMZ*lJlDA)BZvB$chHpP)r`Lk5vnca|@FeW)R)Mjxz9gyP!ZR@G&mj z5UQednvz8tQu#Vng8?pZ?a@4vM$((cYe1Efo@Q(TQ}a((iAgZBNVm?03AhI8WCM7P z(m2c|K}GtF_W?;pev+%{o@2}wG3jd8885Cmv{qAi7%~i$QprU)`w6v>7MbDKEItgG zs4F$n?bL|2taplRulyMe3qd$W@)K9uCsT}J9q1>r9HPVx%9qY|rq02q#4E6{97&bQ zugI2{00fO?)s8VKXXB(s@TDCEb{6cYpUsF5v{C@$T+AOU5F8GAx{Wn{Rugj-a(N8%i3>>@EeMV3KsAN=U`*m_Hp5*W_L>-LrPTsM z;k0xK&|WsMMxUY@M$(AzO^qz-tziELA{59+S3O9RWlrXb6!r5BF}O84gYF}X6ZY3Vi0sBT&=mvl-S6BMVGy{k zEQ68$fSdJ>C2PR3wof(?$C!?}z7@@Q3!Kv%a3#gwC;YPtuI&YpY=M8m*ua?5Qw4B% zC5RGMPN%9*qzOi@9T%s|R9A}0@JFL(#1&tWWDsvze6oC%Z-m(eq?>rk`!oiDG#0IUAt9>H*8C zCSu16AFnl7E9o8s;cbDZEx?voOUtQa3)SNSzGT1}?|bC8_Id|JJMcIqTNTaFZ+byj zD_({&Vn_i2CKnHo!K+wVwr1vqlC28J``ku#k{q-B`2Ek&Jl^rwIKQ@ays6qGpK2D| zlT(kMOf|yW_`;3ez%WfuHoC;F;W-baVgL;2fV3{u3CdI=>_}0(CC{I}F%o?fD+91t z1GmdS29QVKUq_K}$(ENLZGqcq-$|moy1i>{-NL!b$-! z5);d2FUB^KN|5bwyFIv+wXB`Ew$2*V$D(IRs}zI+YnZH^b@YWl^rfckByuz$1RlOeju-gU>s;%1rzscJ>e1 z)pn$EK`uVB5KnAFy8p&#aZ!`c2D0T%EadDux0j)pux}K|ZpEcNMlm;*$UIIZVk4am z$^Vmuyh$tD@Bs8SVB;$9oZl}Ag1CYy(H0oOzQN+o2`b7li&fgV>e2qmiNQRY5bX?J z2Fh2tG=H_gSS`ZojmdY_6GPTh)N^;Wv}kJeb1%0@HzxB&EGCxarpb$)d08-bL-{D>C{<>6sCRnI263PTtHivk%XCp_4Xc^nU)kFH~G7 zxwqOgb(S*|{Ik$oO5XA@>b|jpAPA!erIAverOeD3^E{r-96>O|x_lU<=*Jibxo7N2 zVx9j|-7Xw!TIP8&O+`9~{J}hD8au(&1^;!zQ{(qwWC8;HGMT49n0o}R*&^q-T?E`N z7}`);3WIll_fXkM4CX=%I7rM+C$}Rsod;MI3ro5+->ztl5EPO)LVT%%q2yCDJ|3mt zq?ha)oKev=J~rSprU@Wrd4~1C%r%5WA={Vv1u@{wPry?ayZ{(dejan%>N};ZXmda5 zu|%!{IHi!vs;e#lWDY{uK}HaT8*lW$Nr#VoAhLS2UI9-AP( z>R>xkdC1F}Q6sJPE*lJZDx30rT1H^z)_o?Gg^~02`B#S#Edk0BuiVsiJM|2JAq=zb z0a?O6eM3Nrhd@HIndJymr8l!aXG2wn@!3HJ3O#9Ko6hl@`L5$1St$SHBWaq>?S}g%6=-W(;>#Y;IBosj0DgFxgFI-dfY*AqS zTQBvS4H`&V>`I7*!A&?&+V1^CO*=k$>xKJkIn(;)I^=2t%?G(Vb5v)`e{Us~1*p;! zv_Xo|iX2>)3LdjhY!zVrHCB)i*RGT=TseR|N{IBQ+wx@TUK3561HGsDadUuf8DEjG z?6w#uI9G2iPLb`D+jsT0fx28NDiik3w87TD=6>HSe7@n1 z#RryKAV&wj42w%I(@lh=>}@SpZN&)myG^X`xKX?hUH}vxwz7b(u8&UoSbGrrN&%#a z$0fXf8*Ua?n+t(^eC0)fwav4K_{rt@rSo9#LJF3-iPMD;TY)G~2DIQwa>~^6`PU}_ zgpO5kqTkz~zi z8u>N>puVkQl*W^k8M7?H0xDuPB%OMVp4i~(KIX9{J^4^+q?CPaNxGc<#SIrwVnex* zdkg?OVt`ET!o=-5-ZSoB0L2kc_Lj+eP?HO6op@vtppYGU_F2v;T_w~g32MNHf z>6J9ueW3mR;Y1Fk1H1qNy4sQw$;|UJM;xTV{jS8&{8-?f_d}uRCp262Ru(<(6_$)StX< z?@j;|>yw@Uri#B!1`l12vbI#6e32FSaJPHigMzA`m$D4YI4%gyccFjwP^a)`<7F=x zB(+J!Ubd(;Mj+S*oU0DG`be0(Vupl#{<_}n|5c;#nqNhqQIl<_X^C*|TQU~eb@`5s z`$MpNcrCa8vs*egUbNv3;VaUBq5rvUYB zR=cB+1yP_fyX3-wAu3uOR(VhbW|J>=7!(1J17lp34@Kb-!bGG)?}p+jG-Egx4E85c zCGM;5m%chrzdnTn82%Skcj49q`|y2uzy@s8=!SuGNGJ#*64G7LF-p1-5RlQ`H9AL1 zOG~%3bPGsJD+r44?*83Ry??^u*p6MhzMr!vc~1kGYzh-VyywC=HEw@?*?KH??OVoseo`g(?vHahf= z%iVX_j_gDrv0y+N#Mj|`GYP)}lwrMnahk5{3W_z&Fj>0nD~Sj{XMynHVS5e>dQQn{ z|7E`1Nzf2{&Zehi7m|Vj!|dzBaN?BhnKoR({3QI;a*!g1|8f^Ln0Ua^n`3y9 zz?>{1jTy+n-x!vPtuA;Zl;f#Y$KHm&R;$EaJ0Pgc+p^PuQ0D6fYS-!Zv3yjygwPhy zKUuUQ<$SuDqogRZZ^Z=$n{kKXuaaO!3t;?mRD$R>D7(k(7x~3Kka22py8^(DzN$S6 zWAM7-Z@8?okAqAbWF@BzF-}neOVlMZg{UP|x`I*WHgl%ddTa%GQ}KNI>A`@Io{9Hv zlb#s@B&?4xk6>-qxBSm|@xxI(Yk{FH}Z1myd;paS*ikpGo+2Av?_rG>qzSgQ-J%b0~u(iU?@}KB`eRi8_zg13m ze>8c*7VykGhHps3GVa-4t7U>jfD0s1o~>=?ju}A$nee2}^ZJ}AVozP@Yb34zCMezU z-JT7BvFIQqs62BtJv2)g5Sbu%^YrO!HCtae>xiybLsqV)$^RMxYU=u5&=Bn9Fm`Xp*uT6(tNtla$vnmRKcU^&S30QuQIrkiiLOpTY0ylG_pV z-(+2U=au=zCTmiR-9VHJVrIV=Tw1nuoZnx4GIO6y?|uNC%l)8*V^U_smk$<^;CH~W zr!+`@WIBfYXI{6|M>*tbc$#fa=)Kg>i0u|B`yvDPQOUp~hUGxb4hxDF{-r^_@Ot%X zzig_-rJ-j-Zt-u8zY?4+4NG9VCx*SSVy0Xkk(YB%PJ8Isv3(p{r;=VxPLGf{>$T_d z$eQnXBiyn);S%nV^B;OLv(ml6Bk!tX`N!$<6k_jNEDolvG@71_kn=1e@3fVpr0#gV z;w=Ar%~nZoWd=FqSt`_Nr{=vfn?LCcRQqMCfnL#3o?KM?k39Ld%39pPaB-z{_4&ou zJi5G>CKC(Y*j|XbF!AHyy{v!e$%eE}M^@VG<*9J5#u9NSd%e|_#UZceCZ`!4m(|s; z*xu-Wu$!_2+3f=;ist_ucen zA2y#pJ)bk~LE0R`MSXFymu}IYCMn5>Thl=9Zf*#>YX)3V?;%;!{M|s0uV7+NQRSEJ z>9ildu3L$GTUi|n^k&)HiSitK+trbHhNdTf8GEGh2o_MTAAEcz^4>JSvlL5ow^Z3z zQ%(D&vps_}O-9KWT(m)BfgxyLs0ek=1`Ud=7*CoU0wYgeO56=UnC z-ot)~jjNYw*=Or#x~9HgP1tWbiSN%MhB?N2B7BCPZCw1H8C|!bZwg=ce`a)!lXtiH z5ASBg+<%tHzd!oy?K|JHaYgR``@6Nef##>PHVgB=$8Wp+Hhw*HB|*D$m)}d=nHOi? z49mZ?B3bYNhQGX>;`@eKAo0FgQ-0<4C-Lt(%pQ_KyzpBk&D`?}iGT|dalO4^h?~87 z!1bBv^_~5zKfCIyhgrLSeheeNIowlN*DRxSv}DhK1HJT(ABlH* z8UJT8XZ%6&Ngs<&AM1Zj=K45Nr6~nlIgyi*q!fr(e+}FE6I%s!UWs`0JP3LAKmXDD zEVcj2zF)MlUv;)uZ0;3)@$iN?FnW}mbodG$|0r}7Y*`t01 z)PPdsfbzhA%G`kJj!aqB0QuOfUhtqA%V4=J8`-M?EgjiX1zF?;M9%S%K8mPdDP7lC zw$~TYZ8RqvogzE@1Waoc{d;wA!9MeeBPm)CM-(Ld8ja2&t>B1`GEbE&bC-Lc*=FbX z)N)|RDO3RdlZKc7kF7vRbGQ%uCK;OP+ zxet)&l}Eq(N#&A5U;Hk?w2JYzU7n3L#Dqx*v25P8>Q`v2G7=?In_WNrK;148P1Z6j z@+^RbC1h?SF*G{$eB_I?Wa10zB%XmSuCTcea%nt$$g2F5vFwmQnyxRv44D^!_;`7- zLiuxww%m#kX~mYCfl;ZHBLA!ct5HhK(b7;QxTcaJjnc|3sO#Nm`J7V2h?2q4Xm#iS zB9Akn(vhwq6&%D^x;V&Nzp?Ne0TJ@%7Kt3LTD*hx;JO?j06_6K`@y@Tg*rEirzuZcr@g#Y zOE3<_B_xoP{JDKv^#|@WSCE#NQEgnGsv~+~$S;4$h1huo9XX@71abWY{eJ|!AvGr48E80@ZA`(=$a?N>WBwFQ~hO7++LeAfo7xkF3V z5Z14uJFuRxgyBdDy{xs4MD5nPwXu>sn|-(7trGu#xk6(`%x0NzL6zu)B0{>o=(Tm;iB z?U`O4j3%5U!EvUkKkqJweIUB-uU~?H;x{gXZQ|F|Qoh6^9zyyYn+)0wcKb2 z(9A*M#XaU#a47A4{Cn%pw1XV#`>i<0wx8YCo@KW(_WAJ6@I(Ha~i2xV0I0# z{1s@{ITHKO-TIXyFZVG;0#-B8TuHG&eV8iTyi6b7IbwP4kWnS6ubrE3gke<=Gq2^a z5FQ$F=Y6yt2I_4+p1+WW z{A0wVpzdyv5VJ`xkQZa(r{uJp^X{im^)Qz>%$Nndk$r7W%?eSD7j2Zat$}qU=S6uR9u#ua>y<{Ecx+eP z?7u&?33$BkS`(e-)=DX7=b;%L(7zLFs$9#VQ5XQRW#S3ElPO?}{_c=Dy-E>OV}*!r zrjMXVFlbIN=;Zt06_a+|RI1M8E7J^tn{G5PEvmgbB7mq!?ZiMp=wSORAY%}z^*4YFJf z5_ZxoRjQ;vrqxu6$yRI-imI0N>-KMT5xlk~+meyIJz!Q*T4ZyS>Osv)yz^g%=wE7w ze@`O?%FhQd&XMK^M{<#R)-{@~U_phzNGi(v&wSYOE?N4#Kvy@v&0sSbbf(!kEAhMGf~~l*%}xc_^NEB29CoY{ z2Ky0$K;DDY2z?SY9D-!YqX7&`I~S1sFI<2RnRp@Cws0@E!nsf~UVz_wB^>N1ad1K# z7|fpvbA3x^0~v%u5b24S8MdXDYS+a}P5w0}AD1)NvC$MTrr>7i9|}lNWT0C)=pk7@ z;~fy72n*3n3?kadEze4sgXHHVxmIwRn6uXnq!Cu@KW% z%1$87cwB#dGY$6meIMgXNmWTZ$@d=e!-Se4{df6I*8*;hs!;S3?B9!fu2nZ0wRt>^ zI)4hf@dg1inK!C!64ArbfgDsJn&s|-0D>+;0Cp5$IuI-t@1UySKNSA!855SDZAc7% z)U##Gfdqdc1Rorz5Zn`g4rT-;NyTUQfCrI*8OvnTi;l$xk;p{ARHXj^Q}RRo8}KyX zuie*k)gbV$&lr{>@mlR=n%G?oNhN(P_oe=mJ@cjOXqqZS#8h=|h6cpqP$gc;=j7=X))5sVKL z*@X3l08d3B$;`!WOi-Zai0ARmISd+Z^^`!Lie;NZi<~%A zg(L;P42*@zLkE+60|gUUvvEdhj84RBdr^=ny&?My#vo*_6Sc<|Gomjfd40>z+LvKO zlFUmZc40SHfsKPo6)Z))Ds>nul7S5tnj~{C&maYUk%%Dv-V|->;y!R*hfVA5XoaY! zH|=NfpMH&d&u$fVo_-$MODs=a;t!y1sJ8l3vc8h{LC~|}!@aC<(D_01s+0UdeBFmp zOVSlJj)0!dP;Js=%C#<`Ky)B{`N!_34M3|RMjlz!lEqYz;e~Jy$2H7$if@*eYJtz% zQAU@ny731hM!W?kQfQ`Cr%e<|oM5CWNfgK@JGs$L2&oE8gC;%v)jWhGE7=#8 zgq$Q|no6@bE3gKmcv=;MakcDXcr|~&eUta5$zlxg%Gv&)AQ;z9b1MCM?v;{R_FO|@ z$~Kb#eN6M*6KifIsPJJ5`Zh8VHwo(JJ{wJ!!e|&3lYN4OJzJB z=0|IiFP~XMM}N~$kreYB*@}HD%AiByTVT>6B81H>>j)C_ECCD3Iel0T7nHOldd9ky zIih+i(u_8`%FNL};7@4qUNyQw6#y`)aB(LNehp>HdM$NsB37m_Ic+B5C61F=a^g9~My)AqjB|uY z$t#l9cTEaH>4ls%*OG0}c^ql#r<{5%JCsm7(E1g?{p(D$j^qHOw^7U!DC8l{3^5Brzi!mb%V=xFMKpr_ndERrl_&$ zG~?C=UNmo*X)BIO+7iLH_r;m&porX)wG&cBfJ)ly9OC0L?=wV+)h64ZoM*Wso&fk4 zz~5bq=ZJYl@Cozu*$|BW8lW8mIo; zK=(B>+&qx8xNlP-oi9mN(Gn2b&+%<&=q{Gbn0w(tcq>6@*M=-a1-E7|jHJP~k5Vp# zMp}3a5ATW)ccyYM^m+-IYTL)~B;<80$#69K9y|t@Jt(ziDP9+W>x^$4LvpcD%>3YO zd>Cz|q-D_I1ASG*#bZUX)|?#3O64TwXFp5pVYbF^NK5-Lh5>NKYdQQwsEZwAQORO6cydzZvcqGW-q1B|p z@=CX{;;@4qtf`MkBFB-8_iGHsyY$2nH=`hSOlAe|q1N}*bg=8ia}yV-6r^oa>@vrQ$9xvhy7^Y}v>}=9Px1)zO(R3qH0gobL5*o5GbXPhCZf6E zGpArvh1z>Aono%}3azF}-A*pu?}NERwpD_j{@c2&DYGF6o+gebAzb3WnPta19#s5s z%u~dV>oQBEZTPPQAO^7^QeFIVnL&pcW&%I@YXidT-m&0Om{%qOl0c%NOg!cRHw}yo zp$&nYIUcx*DZC;lHk!68KuLFE?Ie@^6@uLMJMNkLOaA4&Ml;IZYA)D4OJ{lj|e^&6c69^G@MjdJ;#O&Ri&( zwVOR}S?k;Lc20p2S?H$dekhA4wczSI1i5iB%8IV)6|O#mUMc`EZL2Ms-;tI4cB56*8!foZ+oF{%6vz57GvSL}A)BU2 zuP(%sydPEj;y>D^0`8v_NPJxAMtzAgCulUo{EiG|lmC)0z?FeJtD7hA#yXvN+J@x2 z?%0}u0Hu+N#1ROfw;&CcP!1FYw}pA7W+lB;eGM!s0HJJy9rCfY!xmN)t(~MGk_6+y znk7@O1Du{?@RzU?-?Fpf0$K?Vo5)PgNE2EGY5(7Wef`MVW3xBrTJ_vj}iq_xxcLvAvWDi7ZQ1Tt*IqY!4Rhi6Bh)=i}BD ze=X$%T^VyGLlI5H@t4M$lUgrrKi-ABs;DbBr#CTET8kq+M@MQ>+*n)0L09Y`J)LK3 zeU}=!{w^+^O#`mOyNd8v+wVlmc8Yt`dM?h)tC#7mdc^9hY!{VZv3?RQ2R=iXAGkW0 zrGK>En*E{ZjgiRI>12+FEocb!krv5=(J9en#cQBp@7mkO_j+C;ntn zi$S6L7v|X}OBP$ROK%G)1DL473#sCnsMDE>I&rKOOqb%;;Lu_01B31PpPO&{R$+ckCEnVjtNK4!~8nWor>&D=jfFj;O3wsI}`&W+F6D zNDlicOJW;ToQP)*T%|}2<`#iaCI$^N87-Rwn@?8H3rQxJ!gx$^6RIq_mN6e+yf)+C z{?r!C%@Fl!&bao9hIa$2-EFO-8f@l-fzxRu)Tt!6KpQAPmt0PgmjjtPOx0Dv;LRPO zNW_SaHRjc_d|VnTju=$QUh4nzQ+Bw}oTUmERX^rg72(MAFejCWg`xEgm1}9BPx65k z;~HaBvj>-QWSd{9H*8vwGe=EAZG*yx@LguuGm`ZMTNwUPw>q@mdEFhf_`rXH9Gw9uO77 z<4N!*rc@@SpAzJhGE0Z!>wjNB_QS@vUp(x^tL{nHv3hdENb--{l`N)- z!rg?RVpe$4IU~|)AVgT`%|%HxG2H1234V*hL|Ho(PFi14?X3rs9!hB_5XYD5cVnnpxEjb_-JuYX@w;bn6Ms5Jb?^ZY1STYu`Z=&- zCG#_5BSMl<;UFKQt=CQoE=@K`D)@PTT*gKFCWG&>xnHi#;wQGQy8#6gL*1Bx2G~j znCJ)dc|btHSlAX_4evTz6H(%wZ940Orz?h-W(p5$cKSAR+9WC6YuXIB5hqxVY+7fT z0P2KF?NJa_OGRYeC_Ud=^<)8y8L*h&<#Ve5#Ck^*V=eVL z#=4}6C6fk&+g=3>`#)LdweOb$Z^&uOHG ziB@mV5||kDD2b6(wsGJITydD-=E|po(zWdCYiMqY#N{F}v{UAu9Dz<3?21Cg?YcCbJ|Vi23*0Tr(2 zF5}VXQE3q>YO2jgr%?5SkvXow_yfilNf;YeL{f=VX{Ny6I^1AYMcKsOk}AXu4hhUm zsTWr%@h-N{5Q88Grubbm{%cOC{(-uk>BE$tDdp9`#E;N1?4N3o!?@A|FlGNF_M=Kq9S9w|kvMaCxfNi1zQn6M_~jPeH9TQ*|_F9qWmX|v2nH6%(Q?xk>$EY4|F zIB~c!x;xo&n1JKC^HdJyv&Z?o8(P2C9~7Y>%#^mVm%M8<7EJ6-y{_Q z>&0;SU0iwb)mEkMz%)w15-&(JlfVBskIWqyA*2t5e53V(L!r~A6dxhq3)SNM)eLy3 z^LpS6&V3Tc#uB$l2z>}bsv&BvBa*w0KhFy07e1D>7M#;Dy80q!2SZiisE)VQ5wJvB z`1G8jY7qI8w8_4fGhdLatbwamklSG>_CD9VfJw_eJYGvKXoyWZ1N~`p1pWNj0@uOj zj^-EJK*#+cKSE9;+V)idHowabKwtAo%Mio-fdzi8KZ;TaPwo`KKO`V;Ov)zb#2V=F zROheIK>YFRZG2hdST-Y(FqIcFujU@o*ABv-N5n0@WKr{xIn8lOryr!Ul~D@Zp_rRi zhbQ4rL?~DR70GtgL>6Bu^)3M$<$I_QZ`h4=232$BrECb~%01rk&X{=0>`5iF#A4_? z9ng46Etmh3Tzu>^T6&V_@zjqEuDuvkkV)*c3w!xYMrYok)nQ@z+h|Fu~80><}I z14fB5wP04xjy3plnBwe5AQgwg-V=72CO>8wMVGgJ7(KVD2zqOfq-l)j%LLE^N25ZF zlI*JmWtl!IAE*h32e+B<8=K}gE)T*s-pv;JA8qtJY`w8Uy{dp5^Vw{wQ z37E$ZtS}(|Nb%+NYQK5lz#d1=DL=f8IcWu|?0$T0OWESa_{^QX#a%2(@3)|v4eJ|~ zxAR%ye0(hqI*4CgFIudoGrf>QbJn#MFbuRZO4P9Kq9^CAjQy@wzFX-HW?es^L;t>- z4DQb^4f8Zu@~$tchx-t#K3}oJNmI5e*W`~4@W`>|87Mc$Nm4BZ*7Y$3JeijYS6lfN z_zi9W^f8xVM7Wrz+qD!QqPF-IOYsZM$89s--hd6BA!1ruW4c9Rhi_bLTVu1g-u>El zh!u{*BGzZ=5m?T~qReri{6Qo5!1>&SBAM^U4y!^;<4At!ob_63m%^=uBt;r`xJ>Z| zl3X{=yA6pD&4_5rpd3mI`S_S^C`esSDTl(W5ebKT+oaw$XE~&zMBleLTaCO!e-CcW zfO4jY5eDb)WDS1WW_&lC-l02c14awx5j`iA`}wv+gdMv*31^R4HtKb0TSBd&w8(h{ ztF^N%*ykb9nq1boM5H3W)$9)Hv(K`HTgF%1Zn5iaO^Yqn=xcyWmD5udeCYwNYd3>G z*Dza55%+Bq9MXIJy+ZqW%|1QK@+OmZFzdCC1cDKirb2tSBr1^*XFdBnuU4ccp{+GR ztPT08Enlp?>`!aAS6Z{!#q~pT`|ziZ{{|sVfH(1sVZcPAZu!;jMZqqj7fWG3yq3rB zZ(1h;agwr|7jN0kj>Rs2Kc0Yg4F2muK1v$k#s}PtH0h0nEQR#AbbLO`EJ7H04?j8^ z&F>iNelb4WG5#y&^~}3Qy`gGtxvfyEk%$+cr~JE6FQ%UjY0Ll712x9tbzXD*oqpCi z`;P}HUSOvXeC#`d?3lCZTyPOz^zK|lh%XJlAZgq%$5pp&yz9z~oY(ss)Z?(!ExtC~ zxi%%fzSy}wC8`_rB35TC{|2$N{vf`I)3r(Ta*MKSOFQic412)n_XJ~G;LK8$^3deR z&fK}ERmPWlE)Uq*HuqHbqWg%K2N7Kddh>^}*5HKb&t|K+h|=k=7t-^K2~UnTDM z5eu(nySqgt@u5_}oD-fSUVRV5F1fhO8*Xjn70VwLvKS}{gN`))1 zFiU1$6uv?d7)kH6664|nqOTpB#@a~B#5tcJxB9SExOn)PgNu0><;<-eMyFYD+JG9< zI*?Dk5j#D zQ$%1^WEtoprJxlDd7(x^qm#;8bXtn#L>iBRu}b`%#*2%^U@j<6&3XCoqa&7{5`T9xzi1Xjpu1!q#k;U%nWU<_ zY|RK{j@CRj(0s&^O);)i?VGlbV0Il3Qhl&o*bE_7)87iCc3#?wV2an@j(Sw%5~wM2#K9ab zoLNiY*UAyK609^cK;gx5=>kw>KjZ@h$4d8ga!964bw&j+?J@)GgWa*h06f_RP}XUp-1M-~4)TXuMnALmufvOu6$z$^Jd>-i7^npW`DCvGA49(7WR zVR4{`1%-@g9JHlLIj7+pt!VTZ$HZX5#=i{kczhqEc3Ha^W=b%={LERq=57&kkdo@# z4Ci1zXg8YsV9WYm-^K>gTdOMI9)oVr8-;EhmuET8 zyp_s-tZsU{o9VJ~yI=T_VEX$o#hb#=j!~L>a@}n%LeQPRe zD>~xZb_FwtN`{FrDN^?7HgQshjQRUX|Kf9;Yc*6(sE9l>Q4)aLAEVACiAX@NA&l(j z8#1=ZX!_3bRRF37QNr~399Gh*(|jM{0(=_-RtY`=8$!F{l-!yh_v! zZ5P=Z`=eh$SRgwi4`6D-wlH$0N`jo#acsOD;;0^WiKHEs6)XKGMAGt3*MK}0qU~rM z0+=m78LpBBGuiglw9iYr9bUIgWs}onWarOaK2Ek=#L_oLL?JVq8mi}`X5d~%f8%gs zyLM$`(D$c>w0jbU7yU&roOe%_r8#SERs0P&C6xH}w1DX%4;U%AC!YIaTh!|BCR6fi z6oY}NK=t=|q3@gYzlh_&_{`uofE&6V)|efZIDHjN&7Lw$^oxPZkt)t9`sL-}BPh6z zMO#L5dd9T@@xeDf6~!uasHmBxmDb?lVzwG8feYc}NZQtlTv1SF4A$nNku8%t_XzY1RKizl%7x735nyZ+^MHf_*%_qD0uG{a@)<+sZ5Fm2`v0iw{MPY|^mr zn+m81wDc>GZ7|`S{xO*DgNDNO(i+_ko_1hB%b6TTi!N<(?#<((a zB7juS#-csH(T6}x<4k^Q#X;4@%nM)SK34sY^m{vr?7DIqvpGT5^5Aj->s+3Aw45XH9I=pz*$lME zcqx1V#Kjl7zD}Qb+_0+inGyGfBC2lm874+v8=;z(&m+R@*FX*=f7*?k#ck3DSoZBLyv^J0E zj_Vo!UNB+^!a!Pt#l+G_<}hNx&y7+rNhYQW%g04@g3P?UEi4VpyqC-Z$$BeHk3=uG z8cH5X&1^byvnXn9UbB@d39`O+XEkVGHMwLp7i6<`XR}Mblxtvf6J&qu&i<}}{rx3- z&|L1MUZ!nn#a@>_XWHCTr>K}{XGBcD4omzY_5ZXPzw8lLX)^CWFa|D3>L!&TV&98pzuHS0}VT` zbkDrp>t(cUe7%Yj*T%Qzx#;v?#REB<2Z5dyf1JlYGfMc*H(Qavj$4*n6q9bgCH`X@ zr|LQ1?q=|TZTE7#Puo6OI^W;^Q#|-D+qevK@}J_t<3GoZi_U&duQreuB0vVW+0Hut9(hNCzlO>>zTi584)7NTXJU!nzLtBjv4pOp+RlQYJxn92%1tNC806yTR| z$fHXJS(mBh=#s&DN#JXw`4=*BEbR7rqC~Hr}(s!&WW-M>Y;GDoY|1QW9hSS2m`aD3|+CJQ?Tv zv-qiIFt;SqnF@;zi~K258Zjw5lCXk-7-^P7N{XOYF>z#qsx>O`xlMd{kSWvd{GImx z;f{L|lXHob$rIppq5u=kWcDM|l_v4wiKJj(u8k{1E5=oCmSvVFm=!fWuc~b2l~-o6 zt{aVxxb5{m$j>)iH`T>eJK89(t~{(?lM>1%f443G?3z>;&Bq;!}{geUVtf8@Kv7!*(K(j8{&Hv-G$}GsPM0L!_mk7 z)vW8ca`m0p6E6DFAHLo@axEP!IL{%dOkd_qv!ix8Xz5sC+YYq9TxpT>xGug zl_dW*ZsotkadV8n$x!^PC(;svAQadq`LIp#iR#ItfcjV)be_Nxw^MksKK|9K=-Htq z2OrHdJOutnva_2Jg)W`Q{O5J+5us z);TYAc6?9x)9u5=gxHfDb9wf=Npzk-I5sz>Y5&EPCaLS&y8O%z#QMx#pQ!N=L6TnI zS1GX-j)?b-D6wBA(=u5)fV`&y^kRW>4!>ucVgTjx)itsyzvp1dAN8YR>YqG;stdJE`-Uanv#wtgXR!yLP#1*y zJ%3csYbBS0sUQ333%iX56^Hgf&c1d&s%au!r5PsU(0G%f=d9?Mp&gVtRGxb^dsozgzDR>YroyT_8?;@-=_?$DugIyAM zb#JDBwX2hsh190UlRn&ZH9WpVx>U_Sd62Ju=h8@&I%SDymRU4< zNB|4opJg+0$>hV|5=xM7y_CMSw8>}WZPTg<$Fg^fq8P_NStxCw0 zwx0dPffc0$c5C|SV{l+81d@mZk?EuJ1lqISLD6JE)`NjGOF|4cLChk-EC-LlJwaMw z!K|&phJs!^H^KZ0T3VD(nK)q7BK-4|AqhGm&zC|ji-OOreB`l1UmnnPx%%qJA`leJ zx`YbPR9l&!;W7mjQ=+>CXCAIsA019ZJ!M{pNiczzG5qC&+@M-`>{|gO{jRiRp4k=% zT?NM7MaY+7Cg`tl3zl$(Z()oV4o*Sgnf>4`N-{-gn2UaZhd_wv1p?~}Q7|zm;^9ez zO^UBqE3snK8V8j^Bz1I#D1t%``r<&apWUDqNnRong=ZdF&O!b5fGo(CYA6jT zTF&5$OU|?JUq~G7rVVZNdC_hg)^ZVDnFI(uAS-K)Ch>^PDv8Fth_T;}C>Dsp4bqsg zbsh1ELngUG-J+J3gzp3(_-rqdsf0%lD5L;!q}#C=BiQRp@dicl)fRzv2#f@Uq^Kv_ zvB6Rar?~XYEeZEa?2b1Hnln*o@fe%)@ihnC(LQMtF0ESW6TD?=X!t=%xFHRChPg`NRg3ep{-INwCY>v*L?3d697 z)8D;gQa+^e=QND*%HS~fj~2~{kD?B^C9c+ozPZgv2UA;d5*zU&qLFrtNuab?&CEHu zR4t8++%`^4*EE{k>@zBwODiS5APOxyq;mTCY)vtD(bkxYnoRcr31rUz?1<*E zivVhFDFBJ+)Bys59$;Gm@-M;p=gYYpnfVk_v3Fp=-{t%wCJfX$(e*jrwQB*ohky?1 zR!^t=zAeBv_t|}ywn42cmO<0)gR(MF@o! zWB?a{zev~;SbbAS#;^*3Xow_nK_?A~s?Y6%DxtWrP)#wYj}+<;JF$TgGX{VI5D3uP zgq$SSX=G8Zy=y+<$H3SuH~5ej`5a*Wr`-4xf>pZ6_yd$Cx&rpQkned>R(nF25h<%h z`5qDwH>a?_tk8sh+kA&xZ6XM3U;t+kK-MOa4=%A56qL^mQVwt6%x-ul-(KVgm8=8R zY(gI6MHhVnSw9D9!$3{k7(rwhlVnp4s8uOP0>*OJ9wmIC2VZvpOQ{JLq~de`+$P(D0dMvibXv z7`D~p3S|0ummrR(A=w&uRfzzjO*}10?B*VHNCIvNZGy-t2KUku--S7sx zk3A>w24nsn{9mS7up&tnNcA~U$)*Sd&|%aAVH^b+^Ml-YFcf=0Mv~apkYOYEa8fhq z1Wtsi17YzLHHZKLUVOa#48ebhHZ?*6a9JOO#`+MnLo5x4EM;RPrRjE+BaxuqFf2eP zcF;SBl$X*9q1Loz$ub1Y4{50bomb|``E?!sDf7ewG7Dl~O=J5e32Q7=T;KsU&kM8a z5a{CwT0?9lK*Y2 zfPxA(xMv_ercI%;?WzZIM@Q_lHGWd(fccvUrcA8B4KJAEv*Y(%C85<5?jB-#+>?GN@cS;V|X5RU1-R$*(1HMcj~JF~_v^ z?9QC%`5KhJ3W~AVmOGT0Vn3$zg>bJT?f0tXSu~@yLzmizuABNXgFju+61W5HcoPhK zvsI2+pUHAp>@WuP?S*`6Z*D#wzmhE93csKy zYlML&%qUkpuU%Jmx}u!FC>h?lRE$aqfdFN1yQ24e`JLnjc+=xd-I!aE*%-ih?Iv9Q z0+eK|Wx;IfQ|#(L1DL&x%^ht`x6mXFB5L|k%kyD6nvu=QJoi=yO-aaAPgb~N8w3Wx z(afreSqDi$pbvj$ZK_HGj1f&ipCL7Mv-nH}%ty0!G;;;HunKX?DskddUP^`aM;Gf% zK5t=!B40@{Rw4DFkf13~Kj?FBDEC|$*xaW{k}%YF)JOw`@4f0yKc@C)Ce?ZjD{!p3 zZVX)CWE%(6OswfQh8l2VT;3L**1<3U)|$$bO$8u8oyf#EHVn?bmwgwL&R4-_TiG$*8)@~Qzs zxYZ|--TOMG$SQF#NT&_yfbb~Q72mxIVVP*E`m|oP z&@d3T&G@jYRMp&$IN5|Ko~tXsHx4?X1Kc0eO~0P4(*)${c|A+Q3Yv>pHu!e4#VRE( zEw|oeGX_^6yi?56FjO}e|AqzcI&fGx&R#f=$tMZ|C?xix`wMAai)*)qU}Z1gt$NDX z@{!5hfN%AiYdhUe&B~Pn$J|i+>PCNbh`_l<`DImOEGirEz82Zu_fBz-I_)b1J$;5L zY5xQi1Da>Z@a5Eq2<2MYs#@ycOjH8});exyJIoUkgN9V=K(cg1$gD6ow zkgB>12Bqr5DipK*kTs+W)BL!j3v}grWo<|VP#I&rJg@a%xZmjtUd2XDW1I7w(5*n6 zG)hkky0HRxOq9F54^J-cb^^bnxL)jLs`+7IAVvj#M9E405OD9O>Ak@^8McI<0brl1 zpn!N|NR37*u>({~1Megcs(3zbl{Fa21BKLqb0t2GS7&lCjwo<_s_r?%CA<)60tY165HZ6tnqfxG3>dhnV$;>VpRj%i zs?aW>pAJAEi-^An*kXwK4@DPuM!`q|E{qi0Rc}WYfDt%1mHCb^_}H?^%VtIl|t5Y zVo^%C2?mXe;de&QWxJMlDOgktrzp5_$}vU!e~>bJeyHH+;p0jIp{2PQ1;JA}(#WyA zYfNoEPNYUF7OSME7#A@upidfXefc0LD~8{wDbE=8R7;V)q)|)xA1FD$X7(Q_86>2m zDbL!Zqpcp-s1)Cb?cl#wRrvDQWFk|li@eFyEiA(kWxVr1Nruzo&{~K(;MhzHUqdO&4y1S8< zR7wOCB;DwSkt1cuNCA=VlF`yFC?F`PAOZrSr+nvmUDtgd_xB&zPdoPc?0CN)uO|jo zPDg>c)kzFTs>ZNMMFaf;I`0)d98tVq^2AtL=NXdZre|3Kf0t)@hW1Ub3d}vbx&>dJ z_T0*<%C2kg7cWQbM8Qq(^t>DTNpAVR8hy8Xzv*t+m2d0Hz3vC?C7V|dIzLtZyA5E^ zm7gGU>tX+Yy2f`N(G5}S+zudY12_m};+76Xr-3XO1dfi2*xMrrv)#+YiE+cxNntd) zi2dmkhmI-!Z2?pL%G2bRn@eJk-K4`HMG30kT`ip;_1##OI<(~oQlG+pht-2e=jO86 zBeRyjQh&6#;FX})2RC&GI)c=nM<-~ybJPDv>hHdGV~_l9L~Y?le)HL- zAij|uCVlg~`5fSJIZoS4|8MvATzKcHegORS(pS!x!ye*3d$KPo0&f+QA(ufU2%saw zLjK_Qgj0u1cp{oIe>C=)BT%E zozoUxaV6hk1(-W%CLwfjZ41;>oScW$)Mi= z;u_bb7$W2rpEKe;wPm}^SfRD#{2Y6Nl@rY7S$T^%0WB|mE0*lLFkH3>=9L&R0vu91 zubY8E$I`JY%X;#Xm{K?9E!ExE^08a$vTkczs(%>g1`eAV&86T6o?3)ufxP9GDx(*wX5Dtnf0qT` z>^P}j)g%vLfI|>23^fBCIYzRI4vc(cS>!msoY(yx{B4XxPfVisEbN`^nlBA*@wn~6 zqL@-VV)O5Az)D~0+oQj`0e|`)_51_&XJ;nt2CxPXShM6sh~1}-%5NW_rs9!;3^{Ag zQw%XagMc63KcUkpcS=JkOxzg>%xmrO5dUwM zx7@%(3bQtKU;9yEQE=vb_5a-ym|U9}=sLxnINVo$jn+yasXe5Oe>(g96Twf~A*Ft6~qby$g!3 zG)T{r2x~(=P8aJd@3LiHJ$}Op{Js9~p1?(IoFL^}euWY1#$2TbUEfc5OgK_*)$|^o zcd-oq4hTk)1u@u*RE#fi-)LVdkq?%__|cya``MxcvMz2eR#D$yeZy`|lV_oonTvo zxmiZV>HG;Yq`aTzz5Ro075VK(X)je!&HxorA57?Z z+bPp-(SSrzG%&FrIG{3{6ivxHh$S1OHtU~-s2}SHMY1h7B86HG7 zc@MJ5bgse@p>QgMXgYfxrWY{;R2Yz& z6^AoROsPv5qzKE%X>?nzDQ;+K<2im=sT-0C!zE(Q> z-@*VzsWR`dhUR}7yu--969Y0#D`Oh}P7G-3ig3&H3>!^7quU)e0pU&Icr#wDiJSfA zns`e~ycH-<-W`7@7HSq;BM&)4LHHIr33LUGTJPk2iz=2 zK1&T(H7Z9j4SM8^aQKXPwP-7wrMZeGI!7=#^^80Ojrt?uqXE35#nsxbgZMzp(IEHH z;K!p+Vn;*13IvKIJ{nGbECY+k8I9Z>eR?*E1dZ*UXoHx?P;9cdK+2JexnJ^!*F)K2 zV#gA5bfbiHBMA!wNyB5Qi(|z!9bR6#M047P-?I$!>CE$aO>1GR(7pjOahx1@$ zYBF*rURdUl<3NFh8i$@m(+x?`2EgcrVyH{v!BFyPVLxE;C=KG?=%TWFk&6#$unltyA^T#t6vHxxc3@0UM7+wv}{9c@qaMhh(#YkvT zDPv}@=p~Vq1!EiRv%`6BH0K387;VkW1t}Q?jL5W_JLkVLh*ivs|Crx$Hmt@fki6h$(y2zTzo15HDT^662;#|Iq zwtY5G-78cvDg=HRJDdd+m03hPFE**7l&ls-3j52E2U9SXDbt^ehD-UT?1_4Z8Na3q z_1Q6+^h?(Nz)<8g!k)=}l1taHiKcz!J4;N2H93k@E;(Wrt2Yol9A>_>^!$6FoO5Pn zU9x>i(MdU$^!74Q;KIyLi={AQ-bAca{2EKU1~h(bL`znRy!KVID09Rn7-Q%9wr_i(}v-cL}aCg@et=+n&SMXULzr@ZW{(+m3#Cl*~3uRR)=V2 zJ0GC^6_cZ&;-<(y$mB(?Z1$CklGX4*sJH zu%;XUHY6cwtNG^N<6s1!=!8}6K=!_$v3;iT>2TXGhKH0x`G83pmu;*l&rY%EJLrRbMyJdU>X+Ls7Z?dT=vf&NUtN^7>+l>`!YFEH9n$ zUB>kugyc^WyK$C@Ry0aB@iZ1HpGT#_;TBc1D!{;m3TL63rBb|?#XOS4C7#j+V3N}& z)3T-tr;H{OCvWWCDL{Q_fPGLk&P!EwVe+bsPuxb92sy?B+@Eh*39J$6zoV1=L#BQN zy=k2GUX1JnyjSSrUOl$JiPPZM^AH$xpPeJp`v_Sr+e_cdB0Eb}6wZlFtUmKt{uJr> z(z=B=1w)_m5abk8Rz3@hM0E2bn{&TLt3Nb%724(~nSPYqI+l zeR5{*Z1SKgH(nM^EQjlH_w|#o`cxR~7ONitRUop_kFO`JE`7m6NUCCA*wM z3jUCpH+rPP5}nea-qOzS@q5`H^VCG{-E`S;LAp3j$|P4>CtB}2IR$u?AiOJ-i!+1B zaOnE#V^0^Z-s(1}$=PPK`(phF^qC)3SLA1k@(kh3wt8dlhND_GDIc86my+r)I3^eJ z5g$vl76o@Q1a7py>j0-a)IFV(+~d>kmx8_Q9#&4iz2Q@hmZ=by=WL-au-X(A%DQEM ztc`AbJ;m{L(fWZ$6dL4txH8S>rFOV9@_g+miKQ&_6#IZK*Y|o+q4ieeAqmaV>(#@Z zN=sHXzYiQ{Z_{%%T+p9_AJkeORV!tDKKgQyZhQ-uG_>%Mh4pCv9@^ks`s0Y-?6`NB zhTo?6Lniix(M>XRAnZ8r8$K`TmD4v^FEX(_eeLteVN%E5ucs^f;X+=_8Y8OFBQQUW zgDlT<%0mXa6O!ZA)Gv+xum7|ik*%1Xvo2J+OQ@u?-gr*?=;&3a0kBUS{I(*%P4PBk_l`5}e7LDv5E%WXT$-iUKhJRSKcTJ|R;8 zU^O~BaYer!%%{qivsz?y3>IGvM5l`q7)uhC|JfGHzc4@ikQa; zCa}HdS%N3{$yW5YM1c7vsi_Ik6VEs?(=?;wQ$1wxy(fguAfKahqCRr78OP@WX^_Nt!19UT?h!e=Z zyah#)U+ljFWv>SDA|D5BCI*2K!4-E;5@u425?AdLLJBUXAy1OqZoHI3E$L2BCCw(= z{hfdyN-}*EzC{E{xiS|#CHFc5*qnro1@dJRZC-*kMz5$N6|~cMX-QF?{a|ylkmGmJ zVV);q*s!0oj~5iK!Vod|hjC$-r`_1c!xyJ=Z^Ms$yS?5c$z~j>l2ie~PryMKvY}NX z^J79Bos>1Pof>)!(E=e7$DEa$8zj2Rg?a1zj`a-^BD@_mi6DNu5~|bh zH7N*(prh@nQ7H8^K|tseOf-(uI|_};Uy0m*Y_mZ^^n@tntRpObEPX-n-4Vtl?OWm> zw_J&U2arUX=ntF#ifOoj zlQFWBxvOT06^-da5VCqi2nc|cBCFAR6VsSW7^yrRtp?z-Dp^{$9A|6qH2WQ1$$njJ zy1&ry;FsLdM%R;5hWslR+HS@edpO1F<<{;n>2EpF*~ZhD_)H4ayI(iYzpl4G`gZis zt;^$`(QJu_&%&dl4pyjjk$;qaU!Htgc=_;J`OojGi_>os=j})!Q~?nMW(h$=LxejJ zF_h0(NU{FnYWs*ddb<#=cxLYot_1e5ORhv#F-FWQH#jQ>YeL0Ex~D-$CBvPI5wD{f z#lT#NpvZ@5R=}uLMA0D3L18DdO6#Zsf@6sXly{GEERwT%;NDv4EKpdWFLbBe%-@%4lwY!ENef zpy1qC*n<_#|*6+agE9OVVY z@mU6|*`(K-J;quvIslgVK(QbkTFvx;G_4O7v{;c2%bTeYpFOCaduMuB_JZoBMzb5- z(Z3n_DSZVn_{DGbg!aul#bZjdhdmmW25J)wlO`*_bBS@rw|9tJUP*0=zx{B!Rn&d0 z{QEm;ZY&W{P}P+7nzrM9Hc>Szua-!JR9rlIh~pgb`cQ!uoQtpHE|kQoe0IqYOLeEa z_Q%6%$ufU?HFBSQaW=vYBtwY!pkQAxn2sA6XV>x?)vvp)p03@=Y-jvbu7`{2(SDZE z#&bAyf8%EdEEa$^@nyAr3ow<7v0c2d< zac1afHWNdx4=h`O;MHLr~Vo3fvsyeq@|(L{(K^Idi#q6yP4z5$4O zPf{lGo9KPhmgkLb#*CbW>Rwk1@~n{`08eda<5{$w4Flq6Psl8`X*AOj$XxDo!7#?p zFqx!={Ymw=Jbfj`QfPoKrG^A_wn~myA=xmR6PPrg@yhbKdR&bmu{Y!J@R6GpWSW7} z!aV!v`lYjc5YK+HO|G2HyvJj=ySGnkS9-mvAa4FTLLh&Ys21@Q&>E4Dof#BPiwmkJ zo94I6T)kf?{VJQ@&DH87lIl**5Pi~Kw0Nix7^0a+k7J^7wJ%#Y_XWlVoAP~~@Y(OO%W7+qaI#mf z*Y4G$)4rJw{Wf*%?xxHmpMlMDzXQn+G^B)-&pTpzg58cf;zrgHamh+&J;AI^2~X~k zia`@HZ;QmRmL?$4Ic=}U&O=;THxl!uW)G8uorHI)E9Xx(E7%)1ZA5s{UA0&WJc~%o zrTHi8@U5Jh7bzbm_y^j|svmm1WJR%k`Iugk>7w}Xm#2_qyi@Kt(>U?gjVmVsv$$ia z7v*Y1%#zC(-$KP#2RvH;X!C%BajSTfS73pA;h)*sXPgC*3#(>>zL@sUGSe@30yhIo zUUdg_tn~f)G5lM20&)8yC%Z#}*(PsD3Itmu_GBNS=pOjdd7oS`?`7IQYJ#Sk z*TG!jokPS=StW!uu>Q}TGlE<46z%h)K~$p#cM`jcMOp(SLRklZP^ymeS@`2GhkY;m z1l}=}UozKN95D7T^Gv=rsPVTiJkQ04%Pfz0SqaWSQ{Gp^gjYQ>emUu;*rq-M4w%+y zEky3_a#sc|raJ0agbnZ-I?vX)y9={l={yy70)}+ga;_YmKd2HvRPv3aJz)lDw`sh& z`SH;$I2N>2TOgG6G4_)fhtW{mYf9pt5;v=E+ohAcL@}iA;RV!J{!@;jhVr58MEqEy zXmX++n~~q@7OSr%DAoZ=rh#j7qP@lsei9+W2`(Rq>S}$>6=?_$VGNhj6lq?^_2&N7D&ax(rleLjbU*QU}6q8{+gspLeD=BVFN)zUn6f-{2v zhQQ${TY9KwTs_^q=O)opZozohHc8qjqiCmnEs+j$n9bWSGCADe*8~Z9@XhQEKo$dP zTd;V?sc{Mbj5?w>{K5;TQ%n#=YSYkg0XFyY2U6ofEw8i2*Bw>uM0|+|Cq^1r4yH0o zX=p6dtcFM`5u@n`@$-2Y5cP48ttt{b0KnX&0w_++$pll=9OOTeyvDl`E}02tnf!bv zMhOr~-w@(3>ao)K(i5!?#A<-<0t`{!Nz`XQ%S#`qsWv%T`DQAfXQm=4?PS-r*4cA4 zJy?i@RE4lv#*`F*wV2hPW&sCNfXtyOHPDyG8~>D%)bq_Um}NCM(!NaMHaW+cNRz{2 z%|$%XkTP0L?f0sN928WlAVD(|MGxLZ4W&zUkouMi6NCTb4G|+Ii6z=cV;p%lsbVl! z6UQmB{DM2_E26TqyRArKAyXax+1HlV6+j3bO;e1Lr`moUMflS5hrCgfvT@#CB`h~L znA=_(uz;8;3dI0xO1jue$0Wl-E{%{a0R25!$2PwVVsbUwYWZ84k)}Wy8_T-Op z6D7I{kSh_+wVN^)l)6is7Ok{jR#0|_vvq;Fub7!=04|>GQi-Hnwn`WHgA3pg2wb1I z{lip<#_7_?6qX9?RH7<xyIiU!fDG^Hl#%3W&B61T-I@7#91*9^f?K@q_9r3Pl_ zKS7N3!>wgC zYaiH*K^$Qer5?qqXD*aL*MBIUlThZ1fkuh77=(0zJh_$4TF(Wgw$*H0W6dhneV@-i z=FU5>eTn>2TNT3e2&$GysmfgKdV{4xkqY&AdKG02FN$OaXaZ`Y?2I6w#M~k^%S`@S zn`*9jIyFmx;JnvkWliOL;l!DnCMcQWrb2wR}FZ12PvxX#W1n9jozTie%UFn|XmJV9t@{!BL%s5g5X3&J-hZ*%%D%7eLW>)GrPB) z>b2+^eiRtrqiqxyTabHDPWD!io^@h(6Tn;aNpJAnIWb~G9}q^F3@Nat=5+ZM!`es9T5X?eI~W7G>XH|lb1CM4HCcF9wxFm zn}`a~bEj{h{XTE8dVbze`8g-HK|Bh3uk@W^Q`LFGaPZp20>AZ5;&%+%Dy~;{5y}0N zNQA=u7OZguV4G)4Bh-BM9&e)O?Zx)lwrVB>Sl$PRA5y@ zICK@SiU2{GGhI2pdftbtSgmV4ME2-KRVrEGQ*D@c=UUWRfW(sPR+7WYb|diB3(Gf* zcE`l;G+0-UMthc6>*pOPog4UG{o;`37LYE|MZw+C$b;?AM1>rmw-aO81TO2}+immy zdVm_@U3P7PQ!wzo)q*=e@Qyg)*Eq@bR1%uQ8%p}xAQ-n%Acwwknhs_nm(z})2|HYs9hT-FbqpXF)x|s9$B7k>C2=67YdT; zU2i|uytv^Ff1rsC(6S$0ygaREPQ@KH%W_XYw9M2|0U4LaYj(Q+bfq-T@G*MbLaUOi zZ@{LOCZbPXmm>r+=040+3{WM79pu=|R)!K;iN(x& zN%{nvPj0y_vN9DvY!X1>U9I1GT3PCU=v32P)u#H{wcfp!b^b7B^g-4^WjdIj-XM*V z)^_}Cbf_X>2(^S64~?-s$^`}&6BUNw z%ja_U+Jhb3vTYdYTy7x`NHJ~ldbW>>MIFAiX>7GQsthA^ytpj9od3G|3v|AzYv0Jz zPLYhM(>NjdpgZkBAINBz;`uxA!7uR^r-sQwhd~+gp_q$bkL4Zqw+RC{vii92(>Rf$ z_J|qzNX|sNHEo+wyp{loCSwlr)FP6!14XS6#n=(WsSqt7&#wFFakP9#j7mrT?h{81 zg}6H%aROD|D%#1~@v(Q?)x4lyM9>$1yx#X6eY+j@4PzEh0jDm42#WTfs;HdlT+dzf zonh3-N$G@;>MhZinEs!oDpytgS1(l}or!++tNb2S>3!Vsdj;8Y|MI;`^PR+)e=v>l z6z-%VUmTaZa>p>+TOfAg;`6T<;5LXswa`g+V?mFDr;pn5{-Aw{^@t#*(n0FDtJYNo1XWoqv5ln82v=?6Q7w|7U3e|7B@W|6ys+ z7xd#O-p8k(4+}6`vP9yT<28DI(3wlT30iC?W4r^+mQen?aJSp6aWcHcb% z43{rR1XN%Vi=udtR$cP-|8-bN;hGEMC5eb0Y^r^|JT_?%I!_-Nc=yY zUALeR%yt7E17?pw>A zldjIfKh=MYTF-IiU!#_RqQL!Y)B@`m;#1iVX$#cNC28Gmu#9nY>ix&474vZ37#}y* z_sxFxU!!LJFzbJ{U1M0z|7yGICwU$JtLc=8A2v zbJJyO-2nO-Yhj9A+LEA;Wxek!slYTn-Ev?CT=AiHrGcG0~>`LZ0^jo(kc=n-~50yKm_*Zp(< z@jfL-_@j#r)`AtQ(~5GYofo+7Y3jb%uLddr0aR3RUO$;}Bgy8AEfP18==2Ka7@e5a z)V-BKrs>A`hk~<|;g_!wQ~~ifMR#b6E{C@sAx)mhRy)GGFY%J;QfTcC*K#%{E!n(R z=})FUaQwO)RrPaBk1^ckX8lRX>o#R!rU-dmMY|%llfQlqX0Y}|f7>F3GIZ~!=t*Dx zTj?jH+#iJ>y_o9MY%VCP>12JWc&+{;l91P>VEtjtmBU!s{OqERUG~n$=?JP8?3+)Y zkZP<~v(%H#>faLgBZF+~;@b4!ENMMDrNi?%(=CmsEGT~%}jR` z`Zhe3E_jT4GVeyJi!n#r;Mb)u#Bn<}uS}mAC#mFkPveD>wd8YtL?=*WX20B0jBWm3hxL-?1`k zeBP8Gy{A}s$Z?JAEeLO8p)yZ&4UHP zzyHt0%=I5+{GUIJC?W+E$Oc8l^AwSVJUNLVNz@@*MZuViq16)Sj8P1iHraB3AS>%7 z9Xa-pD1;zg6Js<$IvT++&NDQ~3C#)l`byzUUIw%EC zMz1Tt9x|80x1FjPm8wKz$g`NtO`I0@J#oAvZDTd<-DTRApg=WK`mP{{-u<+_{PfQq z>0ejVk1o@-@kz%D$PB5rkTXwqq#*r%K1&msn@30nfEi5_6U!1{_fH}nVaA+U5KZew zelmuBJ4&@LmA>N=IvJY=v*iij7fd`uGao$6B4Y&U-JieA#0H>+0uyzuv-sDt#D8WL zH)TjFX3N-S%Y|kuT-!ci7`G5D$X5QDt;(FE-l=xaBIoOKqy=0tCpyTjYAL(LHK(#7 zr$fs$(k%CQhr>qip1F`c1kLvHW3FQX;I3Eheiq0c8{(-);?wO#SpH`FSIze-;{anOAL<>s*lU%ACifCrY|Z=;V@Tu2HqA7UXzQ z*P}t?^ZCg5pVp6m76c1nLe`-D{sp;01=V`_CF5q&<1|0XXx*5LQknB~XLCu_{Ld1y zV1quk*0jc6ulL{Vm7Lx7dF+`P6(j*2K zu$3|A4TREkT!+$nR+h*}J=;Ejee9&!`$^M!@eD~taB8T6sx{~jFY;R3lI0|`-A{?6 zG$0;kx@Jh88w5lK0K^5SbGkT9dOD#!ovJDzR|lmOWofn{STg8SxxpgH+&u`FS%03} z$BHzf%y|4B77_)bMcEW)r_h&ald+T^shK+86I@b&1!q1EOJJ0;gY8@tXA6~N6%+uF z$3@Ecf_qF^#9KD@ro{xPVk(T$YfLKLhZzgg01>fVT&4N873`TH_W=m4e?`fl6QL{I z*OaUnoWvM3M& zvb#2=ZbN{9s2t@mn%5P^4vi2f5(dh=Dx;6Wnmd!H3xk! z=?dZ3kf*YjPkk+FA7sO_Se}`wlGvkbmTVpT+94WtQOr>w072`(3eF|-p|H54xk4f! z$R1We0B0+?DKpF^YPEIc&ChuE>LZ^d*>y6SK`RJX@+*$3Zj^x=oRQ@E*a|K+g0g|? z-=v*BfU=@1^0++QrR(cqfWRj(bHq!I0vbI8*!eT`0XB>CdrgBR@F%_^^5&~kyL{A5 zvw4LuZc7-8$TL_}zDog{Dy+sa36N?>?S}*+3?JQNTcRjqTG$8lZyyU*B|T6TCet{ZLYA)pt~J_ zXtooo!yPwk8Nggp7-c`VHGFQ@`%8Kd~rC?S^ujW{Hklx2S%DA-~Iw@*4Aa`-zeYKH7wHe=2kaDN^i%t4|xNu>bMKm z_oA0fsvFt{Cu!3?o!B$qi_Jru0>6b^n;mX%?xYW_^V- zExZgZjhN{3w0im3x;l9f^;s5BbV03L{ftPBiWvEw8a!wZh63&>@?0E~8x(@1Uw{v1 zNeBil%8UBJ!1HGlR9T&v({-0oZ+wSOAr1i)<_Fdx3Qz_35Q`VYjS%!tLecc)Bn0pA z7t{?;e%2*}g6?0DUmuhI*s{rl0`7ZLTVO{fy{X-ukzWT!u-E$e<6!VN& z;Ygb)G>e=@7xgCVV4{Dj#eH`4r|(!|IM!=pa`pg<4y)JvU0kVHvbmlnu1~vlQ~2W^ zRo{J)=qm>An}DiGSol8qtt8k+0YDdkKL=EK-v(G*A9uAnmo}J^c&8#uc?iiIeW>RZ z{&6?7eioWDURl3I9n(E|I`P7Nmg*ec0Fr_}U4s^4^V()p3B_gA2js;AV3Qs&iiCg; ztBV@J_UO8b0UhFBGZ%ygdLv|9pM<<=$xU_Xr&pB)VA&o4c5^NTu-AJ$AirO-9B&9K zL=+SyE;S4QuhwD6ldhz~*`+_S(DVt2rx8g7I)NhUjn*a<=f7}zH6v98ASAo}>%q9K z=@=0hp~(7!zY{p!wR@9h zeaAj84iY*ulylJj)Mag>X+uY7msnlfaRAPL({5lDHVLEsG9+dQrC*+FZHcl_kL9L~ z=TxJf9Bfp3T_*7TEeyB~i}H80#l^Kk0`!K-SZfS7FzsTHR#R%Du40j{R$|nf;%+C? zMk*}wCjOxVbpsXzgn^`0?y*Ce?%w1I+ka)r&+UmT`G50DnKnH9HqnrJ9g9fwQyZF(L z@EhiBrn&?Ue;L36Fl~e$4nTQ!8<&#b`4oXjIIGgh3UC8XGJhIKUP1mKmL`(`1dHy- zS9)#eYUn^tV@DmT1dF(i1?T~gGX(mk-|HlLYb!k#l=lW&6#Knl7I%IKlU36D) zmzz}Os-_mI>mRP(z%%5>?Dn8h07?fak?EPJ$O*0yO;k3qTMWxyJV6HOz8BWJMRr_M z@4w~1(LLuq-YV&(ItwkFTHXoY`q(`bQM;8>YH5zhWHY`|^`5Obk}j5-&+pU$mqK?e zoS6aqCjIi1Y%rKKKJ<|gL?QJfuSk>Q_%Q8edEM5|7p?JN8nu&9a^0sNS`lSBw*Y-% zdpy3a4zqPQikhYCzE))V29zA}RnOp@BCUMn9SO}MNr>QAnd-wQ8`Y^4N`46-v@OZG ze8rxGdRrep1L_C>0#LERZIu*clBxv1MoG_aCRnjlF9ak)FnPRSXBHv`fXH@Cc~QVZ z-`o^{CPl9luijb~cbS!#5D^fdhOJ?^m0`BJG#_h`fJu>f$;=Ku-%IXpQh&;Zvi^S!Dh9gF1FlH+4C#}1R0 z`i^q8p{px6#74aF^i)KEEcqtM2wE+$=HwVhxE+_9-g~mgRk^${Sj96jm?}2c8t0GB ztq8Sv+)T495yrlxJ+iuXU9*QA!y%ONsj4b6G)=Y8H9M%o7Und1jiU%Z+ z#Vm8rq(Ly0R;J7=WI_&^tzniJ7k#oS-7VSLro7)Q+^_TG_#d>C@s9XQzlP9;9jGvr zdsgMiS5#M&IP$C?+r0C<7y`1pEiiX~}P%p{&|_8Q%UMIG2|F!zuD2 zPMqj-w6r|#=W3d;Znd3THRV_!kDn@tml?~BgfLKWt071hpFfkM>Jfx!w5vw*zVQ`` zqc9)y{KOy*i*{wl#cCpwGEI~f_i}+_t8a=#TVMcEds~E6l%DtNrYtYHQ1T4t=7bna znx|hRi%G^huPYe>h*ZV$0Wh1g)pb=CueslvWI%+w^w@LYR!l1Wn&-U=1mQyfKp~`I z<(oWc*)uEJwBPh`&%_)8wUQz0snI@J9L5x=^5`aguI3EaQTIuc2DHmqM+xRtUCr_u^3*gfJzE!R>v>>ddH?DA1eK>h zKUJyxY4|eA*)eqT>005kKAfvT&;fq%oqqpCpeTX_%%STHZuqPuj7BPTW ztx8cp;c+a=xeB6C&0g%9X=me-soMHFDPGUzAof)737VU%K6LROoEB`?NSscoMkJJe zsVQQO<2`rAzrd=;%k9CW6_LzK$779xWA5%+ymRW!-U=wG%VcLeU#(k5yo#i`Lw3wu zF>yH6wr|XkNc(DQ<}!7j2Oj&{S(eU{af*hn1VK2eSsCyLV|Y()p0T>`6UNcv|sT2MK+g30KU zh9QJAn`frb_AG^#`T9b2^df>?MP znIqbXwrjLgs|bqpN(RxIo9`dQOOXZ{eD^Jr0eHsozloXj6kcRJeZ`*>X0ZS5?2^$( z%k)`KGhB1{q;YjNN9!A#s&B`g_8I1br%DwQZckr}_DZ>D``MVo+H^*`B1d!|O6_0s z`>Go+>8d?w2rau>9qqdjG+p>FKqP-jl}-SNS`WOfx{-E(z8A_39uIuu7C2{#ZYamN z%;tVIg9cprlGFLj#yH7xAyb3k&)HB_rY&u^)%hguT8T4}qMEF=_Wk&*$b_^#-hTV( zZStfjNAsx~7}2WxNVCa9Lz|0F?_*{6M~6hXrEY8VV~`v;08i4?(b$TIa=Ry($supkX==`8K$i}Tfm`ZIs!W-*>_QfSCm(G)E6@mE7pOVkbF^TRb}3jw9gWoy zp-LJB31BFy%e~6no55hTCe&WAW=3XLxV2W_^7gV@(wL+W^|x8Z+j#?@a;|sLMr}I- zw>`{+=LB!i8KJ6i$-2!0HDfQ;$t}C!kT#iJq`_Uk)#fIx|F?O}Cb2pldh4C#p+%|1 z{fpf%FNa@^s-d&7U`4*$^`6AUKGDr#2nJq>Q~>~QRG0ZMzn%KjeHZpXjNTpUpi9I3 zG0PgRaLEi_qDXYf`hXQ+UsEoa`^T`f54EJRELKLp&*GuCHZ4v?W zoT_*jS0_1bOXY-q>39>@_h5~P!kfB~_DZK9`3V0V?`yZw5&Bw}v=1A3pUZEK2m@4y z#+)7F)7{hHo$uO8I+l=((eUnfzEWw+UoBx-Hy+aU^M|s+-q63BYS!;mw(y`CieHg} z3Zz?6Xx%j`YJwjKCYx|)S~bWTl^*(^y4y}@E+YvBw!<6 zor3#1VPZ(8e}V+<)SK&g#OtSPmhKP@ zip#Lx8P1AjETt668=Dl04MAf#e)x*!S4%w{JC4!(onBJyOPT9R4heb1(<-Rlc^ZkX z!}b(JKGII~jAkJ2>2Rf0uM?>0JFM&Y<<|n>zvXgWDX+R3D{;RGqNg22kFyDQwzW^13=yZBAEeV%>j_*0C+=AG6>xO;X4A4P6 z2v$Xo8vvoGvlLPzCSS#MBvQquwsUjAG;zCjiTznsuaJ_1AtHku8fr|t1Dm0POf7?1 zSE`H-$-m0fsBWlj1FATBsib>uW4%~jNd7E09=N9EWHf0V=vQhpgKvO?8-Qaic-mGRf6P1o$ zRqPtVaV>Y>r^_2# zOt;M85bf~+aA!(N4HK_IV@*`bi#fLDC?rM5+tL?Fg3rw8lgx7y9;x4S&+LMAd|VUi zt0woyjrd1A?@nZ+`dYA1ozlQe5i~!tIi^Mqyiru!9C512c_V*&O)xvFc;!-uGgjMQ zM*ZDZaT!@wzjtQD!f04aiqDjGjpFF_GK{IYQ#)R0s5`mX@AoJJ4O5uqsF2K1;%%f> z65H6Lv!OG{O(oH`)irzpZoDuyeZ>@}*)m>omgeG1U9vll<;7~jC!Skknqr2VZs+IP1LqbycnLSTb!uhooIkprTo+Qwai@q5n=Yv^9H#!g`}S?Sn}>n2_rHZ(07ZMbU&_tab`q!_&C zo!*g|{-8O%YdO6)oIcw_^~wD|nsNGb%k-Dw>E7S?tH!DQeM5=ghVL$oIVooDM;lkN zYIjadW292(IEho6Co$+ndR7Q!a12caB-? zrfB&bD}Ih`X^wr*Ok=T`13ZtQpXcJ6=aw~(Bg^2yl!<`o>C(*cSTWJPzs6Alydb%x zg@t2Yu@SZ`EtVZBInx$ywl3VlFWg?TI2Rnb(b}a1UR0%DRO4Ixh**$4Tu`@K)bd!= z4qDWSThz^6)T>z37c(VqB}m3aqrFABZoH9IzcKxi8Q+q*qls*dW$@tQrQRaXl(t3O zl1=WCtrC`SQteT;iJe>jj1 zZ_M&-&dY8tXl*W9Z7z9iE(bXt%x<**e^lM|TN8fV`1=7Hj8UUOiP4>s5{@3--Q5TX zsDO+v0qN1*B_ItVAdNw{bayKP?(n(Kea`uwzhFP?+O>B*UeD$z!Ho zj;67Sb8ByvmPOyKVkxcxH9kSK*RcJpuwsAVq_5$YSCv{msyi`uRX~dxaddqZoujm?V>*sCqI&2$8ZX2a<8<%aHv}~IWZkx^8mqvUTsr~ugY{Q&l$C`D=Mr0@4 zciYxr$KGMb!OuaEaK|xy$GL3BrDex;aK~+S$9-qVV|KymX2+9a*Nb)6TV(f@;;xSe zl9ke6*U!&!|7v@(Yu!qH=XJ|&fWwcsgS+o`c7twq7iS!)W-YzkXOJR$Go=~%<4F2y+x((DJ|?tf5xj5Y12QOw1z zsOk~uc{=Q8mN`S>EpM%Nvu5`h2Dc7_k+Wt;!?|?3F(Z7lUNBssz0}e-nen&%*N5eNWNZqW% z>?X~g^rLSY7T;!%ek8b*V)ql?98Itun=MaHDjuuU9vP`^-+OgcZ5>UgA1{<0FSc0B zd+)pWyD#k=qe-RrSFlgk(mR%K2&0KlHtAMD{FL;vCp&&86Dk$wK<)0JlLKr6KWBG` z(Y>t#kHZ@e>7YGv-Hg9`C#NE(Oaq|fOM{o46z&%p%NK+6-*d_SSe*VX+ni}Rz31JT zchI??Me$#e!I1Rs|H{trhR*QkJegb02!4CodYlqbo-5Vu5{sUbDxH%Vo|8MC56x^* zM4dAZW>J-&KWROGI=6E*cus3*zqD~qk8=T|ykKCvU_|Lxi(D`pUa&Y`RL`8VMqRL@ zE;wE+5SCwX4P9{0T|5SCd85?OHy8YTwp^5#f})q`XGo<>A;U}ImoSmnuSm;ZmIqvl zmtQ__y_6Wbl$^ViVl!es)fU{ngwva3P+Z9l?TYPc%kR#Si26*6xY1!>DLQ&8&{Ha& z68t&4+>E~BDD?pxT%CfgmFbtl(>ISPu5~HPmglc9-+7mPzgqpXvI}^nr(~n!sIgdE zR+??RQ--{j)Rb;|a2M*_w`euPPJhnrdi}yDLgbmpj+DC%{gx)>qA?DSNBE7XrSD(m zn?Hq*L#A#jQ&dIXoV)(sz`^J@%rA4~OYqP%@=+=X!Ime4{eJdu-$Cd1$Irjt&iy8_ zu!KDLZh+6haaqTbs0-#!kftaHDe?Z3{NyPf-H<+pm9 z&!(ZIJ(-!0${teb2>By|>=SwKQ-t%kJnC-+>ThNF->O#sL6+g_xxcl$f9rn#{fKi{ zPZ>aLe~WAoz5AqeH$$^X0cde+q_VT+9q}Kb^yDmz7+nKxW-*;bd0_*Vq z^swFci3ZY$-hMRH-LVlWx(NEe_#(U{gz=|0%(xohI4|DM+(%nf_~Nd}D1Dya?=MMnmrL|z2@ znUN3UhwAd~jFXf`^Sb{PI+*@ntLK(iV*8)f^O;I9^|(Aaz7suL>x&}g^SBqg*h&&* ze@S=y@#`{a>2KiieTT~7{Qqb5RMlkvN&3fxg)HK|-x936`u|xyda@3Lg3N5}O8~_c z51lVl<}Q60!Sg?`aAFlPhX32@rO71!uhnCS6=RHH$;y~w|J>k>f?8wJ7xK7 zgZ*SrGE$S(pWia4YO9E|q_Ni;ZQC9{L;}*y{cl+^(AKfyteH9(GIp6fxPL6N+&gYr zbG$~y*>Zgo(n@mQ9vD;S1^m5bD+qy9&4=Q$h@sZgfgZqPg@{m;*Ay-_ipU1@rIgeH zPbrvrZA#zto$#!&>u2FP^8kXYLd*A3A`AA}??e`zt3Qj>T2qFzE_sbfiT?6kdMCQ_ z_V8U{8`jllR6#A_G}6-mi8q6*TG2=Tci5dfkO6;lTQru|0&*_~TzWZw_Q_);pr1$(iw7@DyP6?1et3dSW* z+l}z_3Sm2Yd+4gYZ^~evC`uP@3}Xt9n-uP>{Bc6P?|Cc<=J3xXwyY^hlXSr<+iL&u z?%#fHL8x6M*zUc$&@=8xJ3JWmre-*ID%(QpgxoF1M2rSkF4a^)Y?ZvPE*aY_xDEhgA@5jzDmB=-+8w==X*^C^Pr<# z)}+<*^QBG=3GcY=n7EAJ(l$x7R-h;jriMqnDTG5xQMB}os>Ojf%){Zza>(%cse5EI ztpAAHZF_po6BJn;IYZ=gWQ|jpO2L#}8?Cr1o3x-)()vW;Fy0^=^-XrGX_djlt=Kh_ zBd09Ko5aH?=Y6vKMs?%l122*8Qo%%?d^>+AuWhiES=vz!&R9qY-U8oh-u%yz7<~WKcZyL zVn;+Vaod2M4NcLwf6k(HYc?XLi~HdMd?lPm8BpX_%YD*$?5nP30joJXsi3$}O`or; z2W`W=OpQq?HFobA+I{)`94e{0wm5X#RW<7b)A3mCGdvbxy0iepnKyIwK1lH~2kB>7 zQ}@rWs@lJM1bsSj!rv7S+a1Lc_>jB*zZKX)ST>q_PlmK(Jkr5NcQj{L?o;WvxFDet zCLZV7k++lS-B({Y6ShURbvjj9m3(IDWv)-DJ^H#f&-K|4@d?Fq?0h*ifrBg5{HfL* zs{Uqa;lrdU4{?|F=F*9OHGvB76`Ug?EEpLxh43O5(%f;7+E z@Q_&E3zzF;iwY2P8$5@D;wQUF&7Ajn>Eph zZqP-_%P2B5Q1j3CO^=>D&7R#CUy7NVifCIN?Y!Ct|CyzY-W;UV^}oqK3(uq{L|jz# z8DHLx>UPtTRH%f~5#FtT;k_z&(T&{mO$8d5{ky6}21%H#`)}yd%f9>x@ey84d?N}R znL|YnMSRiWoU}L}Qj*;~*Ki99M6=AZBJ@{~&_j6ZgLh=$>E(;J`>tR9{(Bs1&_plK zZGw&nZ==pX&&9O^uPHR8KYY($P0hP($lF(6xq6_-Rz&Cy=i7R&`uK$7C?bbs!Uny>d!2&zK|kSc7D`t>k8jC6eX#lP8jTY6I!(?ra&z71d@Ej1gYA5Q3aA59U!4AIcDf5V8oCoEP=z| z;)=37HDKv(CM1UgleCMC9Ak+jc2DQJEupzJMcu~sg9W(`lbTIbM>~*81&k%#-@{>R zf{@`tSnerfa3EGKVQ`X1FOJQ^tc=0cBvT0MWv>ry5D2H3mV4^atIbLx1Krd_$%UZgtN_#aw2hU}62{Co} zt6wnrM3gi`I7x80Mpb_RZof09JY#pfn!FuCvIyve&A|)@008FQ5vLQS<|iv#pB12z zlPpMVfTGPhoFMoyZKPnbq0@o6J%dGEzsWi;`@VVJlaS$H!9*K?foWFmPy>4E#OB9a z&US7~qrtZQ83dRv=1G*LM(l3CBL)B+89^ARxDYA+Qe|Fw@r$h!ZTo;E8=mQ}As2a` z?lM>Q*6-i4WhlTYI^+&$sWOz|Ty3&iI>^0$?!g{6%AC@xoEQ0y=}u{8Mh0EKzDxjP zse!&B@_^{E_uZ^|002i*6bTJIl%JkD&X!~v^J(%RZOe+0qNNN22G=wi>$Lz7a8|_A z7Rr?|B+$|)YW(=&bP-uffW#1qKa9{uS&4CClI1-!LI6o($bKnn3M5?Hg2Q1E8A?J1 z3R~>k&m4M`J8-wq7l<7bL{GeT2aQ%Hz=hRZu*woriRlXFkeuFES6CIW9Cq8 znF6Mfi(pj*h1$WXa>WaImO!0-z`QLF0G!OZFQ;bDQ&E{2>Rd}0p1+7mE;N<9p*@r? z`)t4RyN67mj=VlXh66c|Sy0CkI=GM0;dyD}K)v^<`F722(sdTGELDaZ=0`#)qaK9b%$D9f@N)Q;;NjwSEbEz+ z{>?>hja!3%j@7DEJ}3KNwBJTr!IBKbz&XJ5WQBq`g6zy85NUPtu0>YLEyN##{>yCN zsR$RV4KrrkS8owE7Qxd`tQc288ICo_6mFsV9ib9db;>s$sSx?M9w9-jy`0lnT%LS) ztjTskIo{YU-5a!g>)@JVDaOxe4r$h*#FESY`3yWV-a z?lm)qY6gRPJ#REa{}Dz1SS+wc8OUtIN)6~CgGQ4DqgMd+m`OLZm~afM6YR%dK@#hJ zvWh*QihJ@CQ}9wB-+Y)9f^E~!p6du(v#$~^-@`tQe8jJzQp;gd^$jn8$0i3Bly_KE zW6qlY2CyBxRyj&}jugq(oaF!*nw+PCgf0^NbNr36>vZAPLArB-y!-8aJo9pta{^DZ zhL~+n-Qk98O%PP*9F$n#R-U%evbxf*xlqy*P(~Cp z3^vMJ@yKR)$i2y<0I~`^Zl*ch1J>)vfF}SW9!u%toMoRGy>BTk4Lq|T5gj<*v%5R-0+kI3_OeV=&SGymu zln4ii^`C^YK@&0c8?AE)si27)9o=j}g6Q40*{`Si1KAyHZgep$IvgaXFL?@#PDzH5 z7P}9RTf<5RrS{3OIlHAQ}BUb*MERyhX zOL;P%d0CD*bMS(-(=CK8%XM^X;9X0@f>*K?jub6!c&#?(`LAsQDQMk|WP!`ZHG6Ve zGOC_JN==V2`bw$iLb0AWP_&=u8St>1Q<;9_9ODFV|CM5I&sTGW0@2$Rg!C#*41EFL zZquodC=iM+tVLNgpG)%;PcoP`u5UblWi3r_@!ajb7oUVuo-$v*T5A}G$u=e5iBTg1 zh3)l9yo^L(1BYBcZ9)Uk%#rmpcV`eSKq=N1COW_j!Pyv$k%dKVyd1qY5a&8!?`~csG1qn z836HFEyUm61Q*(V<#eS~rK2ukJ*%vKQIe;>(r3Dr7^Kq6C+6zi7{+Hcmuu1}HQmQ< zwUd_b^&Z3bMVO9UHe17!eKH~$r^X-@`8GB&Z`=V7~pp4|bpm$f8 zV;zNB$mfsKy9q7_->+N0bQ!-9A}{J+!fi1yTWwOm`XRw8K41&{H> z?kM;$^*ffo$Y@bNSM|_Sl#ASMgdVMhF>`9JigA!KsX7kJqw0L zU;=P!Y}5q-;F&@IT@isGpmyPdz48O2*xc;0<#9=>FRHK}0;ARX;&9V(%QZX3w4{jt zoytA*anh8Vcx@KG%Z%*{l=W9^%BSbU^*Bbf_4-E!+k^D!3e}>5k1x#|9C6^Z{wQbz zgq>8LuV$-`F|iYnpLO!Qk)p{%2>4%Xm5E558&7=2LM>q+m(ZtTkh*n3562&))rt_O z0|Q^$gNG<*nHik=Hhj%>>1`If2STlc@6 zNL)#+{*=7yrd_FXYNgk1u>6PRTwwxILL^~o!O@f|7bI(3?nbVVz}Kn#ptl$xnOa0N z9VxftEWoQr=8n6zrzDe=5)=8qnMjjrzmim&ANT`+oDrSBUb1Dv&CdRbp1QdwdXj|< zAkB~`dQydiX%+u5g&k4sh{*TI_>u06H@I&+er!x3Iq|>ok_WvcFdmHc5jKQ06LVcwQ zs1O7`GnpEO%yycQL*;h5ncu#>#y+myHVWlMjzfkuF`$I51zL|x&P-mYi}0a?wB{{Wbo^Kmr!1O5LuK*7$H7qxAAi0oX6d9+iu<6N!^$4 zcc(o#eD`PllrR5BW_G#haG%~^jx$BWQH9IoRHzr2ctSvabzA@c6s9mTCk96NdiX2d zti6Xcei0P?QBGvds506{#zrwUnuL{bAVw~(h^!njz-e+GYxJ#%qIGzHH*_}|FQ$Od zkg-LO7l?~h@JnvI0?JNd7l&WeWUr7P35(7W52^1@>i2-C|1>*29 zN)+(lCV>Z)l45A0VMEB9B=G2YBJ!$;RrHZ>`ZwD9xq+boXc-XDn`#|v>OXju!+%W- z22*WwDsw<}DFmeiRr;A|V5g^}63yMP!<1Q$AEl$Fab85qa%D4c5WWh|VXcj4Z zW3*=b{;PC=)UL!f?sut35{7p*Wn^T~;7C|VI!FeR%xXIx5UG2fGp5&)7%HZh$IMvF>yV5W* z%1o(x$gq)uy0U{QW!&)xvsO<4An4;09w&R5^^r+Y$g`>=ZZL_WHj$2}55vN;iE8oU zH(hCACRclVvv<1&m5R)H*7nibohpf~vdj=T`ObXHc$Y*USp}S>hB+d3W{{*=2s7r; zfHaw|SOf?lVISaWvzmznv_vb9_aO=){asJwtB@7$mCKx?#2ik3c)7FyuA}u22Ixc? zw;%;)i+_ctQ}Lw5$=rLa9@>yd*2fnU2wAS9i3os@9$rcPz>BiG_x^4efN2^zslfFQ zmTJH0zwjTb7+d27yfK8?C(6Dcuo;mh1jLt@uqzy@#9ERWs*RYRP$Wtun$vByE#!QJ z7bC^BZm}_? zk$m^$844f!Q;_8MPiojgfM%4GR4~KlsXd9&I*}=&xBd1Z>H`KW!bEM3Zu|O#ug-Mq zH2pWl8t9ZYRMkpsvKjU)NCWeE&E!@~FCEhCjfUKN@ap~F%?_Oz%;ZAISiQ#|bUJ1T zUkrwswTp*w|EsCZYyysntH4Y=0{05CL8y5^(?;=Ubs_w@(Y;TMR`_-75dgAr9b9=7 z9wwlV!ekv!2kM8H?KGg%P#xo6Z-XEItB-;iuCm=BH?2$y)>u$om~KqNkE93&U?Xju zP1P7$MpGNinYfP9I7*4wqnW-_V{vLGj0+6 zN%`q=XxgC6 zxjth5aj(XuP)JdmIMVZKY)0{!e=~A$Wmuc&pR0g6A11bvpcuKjH=%*?q6eDB!z;o{ z1<1TU_B{8VZmwyQsoO-=<=Yh>s8C{1l_I+!b>%TB z;UbhTQ8&b6BD0T?LpyT7Qko{RPUJU?xGq@^l^QJo!petG*iyA=P`-JXPjg*@T|F;X zj%9!+pah~!#8sm@!S7s1hitaGyrM!YnlR3l2DVB^)Dju|R%|-A=5#>&>tzm-9lmKF za`eo6??&W{x_@Bpdh-k`2+9oaAV?8uAwziKX@KF7mT^Lpai`@n+y^)$kwt%?+`~i> zg6_fX5dn}9+(&0)jN|qS;Rb+WJgSirbiM8=20&C68}}@vRUs0K zKaBNpq*ydmyT?|-Ne^)5SqsJZ5)OPR27UKNsYDAxk{BT;<{%p#Zdnnkz=GlYC;E>A z193fWO(O6!I4WX0JZm1(JsX4BfZGlAI-855zcS57vblpB{QK?AhOfP?C;&R+gpY+c z?zGq2^6EBPL{dtszRf!KL%3L;R-?iJGQpaedw6I#_V9MVOGDf00c0uMw_h(%7pl{- zM~HnMC{?cMaUIwatQ%NvzfI;=P8rM-?U&SSix`YYJZ`)P5=|)75Iwe1N@_T)pdwMw zH5n2k!oXr%A6pvdQ7qaeC(=#HVgJ^`#^3g{R)UhC;TGh^7VoHl;xWlsS{AKobOH0& znCpag^{~h?i|qa9bFlJL_VnMT8NZY?-O}n!<8x z%lpy*yUx?}SkcL1L8Zh1FSx=#w!-<7;p5~nsOnSdH-8kO4;x@WFRSojh&wD%<*h0F zCMIMd&59L(9I_(zi64t{mIzV&+i}C0GU`Sl85M7lwrhmR zE#c|2NX!DWag^SS0Rn6lTmr?QXbnCb_Mkz(4n6)u)G?eh742GE;F^{OpqbbDq>uNN z9hw7B+p4hWDY3wU$0lB@U7_My#vx`<`usMx)8;@gCuHqe@HgiWB1fY1Ou(ZIsbhzh zXqn1a0D_INUy}vH+d$yC8012`Zp|Z9E6Xw#iz`{fD^1iV0c4m|0JiK+O`*Y2ijY8r z-QXUi4hF&n1RwVVQ*=O!#))iVf{ml?nsI>E3y;SI&e8@Ud+o^PdXR?}By-PJY!q~i z3q(>yV`~@R8R@@^@XBb*EDm^^m73SHz(Bg(eyT-(V2>vc0`w9{h_%O0jLxJ;82jnP<(Yduw;#} z4yse@nIF;Q+bIg8Re^3n4PUJ9B76GxVjWX> z9mLxS48v`%hb7Mwx(DmF&Y^5KyvLd!|LXv_>3%=i@+*LVMe@?_RvqdA6_&rzM{Q;* zIKemk7@NPdPI_FH#O$6&x(c=XJgHQwrVJ*$^Z7G~Yyffg8))h_v9Otr42lM`Th(16*C$&7k2<(RQ3sSi**c6lH@IKyGr?mRr>tM+2cW6dXr(nL)J4EyRa zt_asFXaQxR_w$5ALE-a+H4bV^WHwI5_CJ~ew8_me|JKToyWE_8DA6K1P*mZ0?Wap& zzGCFmI#ZD$Vui{nt$X4$SY?G{lEsXTu{j?9%AgumQw)tY4`;V9?!X;HUUg%gIbsw5 zmn-tagad@%rgTQV%zLh8Kq*tfJyIVYZ&ms)E#HN`&mQ+F2^8r zG|VF#puGRGuHYpJbewiL97756i4xFxV_`P{acc6^?Fk!Ij9FwwVqpxJMs_;gbR9cF z?}v<)*9tYLib9Via%n$%zVmf|>y^!eF?ClQ5&g<+oUl09c>HMvH+fC`F;w9o9xV>Q zHX`~5@9%LhLPpqox!HRzNfd5G6qlSOC|;MA@<2`befek>;mk1+WrKqiv4g`ut5d(` zpbvs|q#`x2SYV1O9r10&VZiv&5>yKYK(wUuW#zY*m zoRc4hbco>=eUch(hV`11YUOt8t9pT?P++)vOuu~85wvuJ`@d#qT|KVmLcYQK>U+wu z#kYRWhMCXbS!vflE+N9i9Y^+ns+ihDi1OZZ&XEZPA4>B;y@L2LGtFzehFpb@sarbB zZMf6Ou4Md{?WaLOH}x1Q^-)i)vP#ziHgW(=GZ14X)&C-`XCJU9@>5%W6di5UqNhJc zv0|v4#nN5J8a^EYpX~-F6qS(yNL2`{L8@6=cp7)3(L*2|)n)?P{nKeljOQ3V>>x$s zJ`22&=k9*LD$*TA)I6oz1GnCJ7YELQx*m{!5m;74?l2H>B2c@B5Y(`MPwwXAh8Ilr zEtH*9Xm>B!Gb3l)dsK1yWl8!qfAFGs_lw%M7BmyQ*K5$Alk3gd;%YV1p5kVlFN5q~ zRu?1@zg3s_JK|b`4T%I1-3Ld!Dn48J;;NmVz;=8GQaAmvciGpaC6tCstSoeN2DswPV*_}Vps0y(n zuabKGl>cgPSC9Y)Aodu zP1I7*w+@}~;aRqLb7hkpb{=e<1>FZHx074ELw2VL(Nw2z9e2Lg<6YW&V!p`4xA9{v zU$pCCD^%?!^Ucc8q|!5*G5h|Z67?IZPo-5p$B0*u{2ixXpM>nwes^gPqpyc`Voqz! ziLg{uSd|~jlwWPpr~c?#-z!n{zZ%9`Z0rS=)Fl{CTt|-%L}Bn_Tn*L$8b(&ifZ8t2 zS&uihP>y3aVxS;=Lp~OZ41h!kx&Shn0nDt_P+|lyDjS>q3M@Dr48u)kc3L1oghpq( zrRhL}Sup`%w{(?Yl3;ZM8Op{I#9-c277R>$I3qEuI!X@*M#W!d!IcOgna>exGLXQ< zfC6OTC+j0T*KW$^6{Wub0cK?QZ(_p_zD3LMH(s>b!$Yg#P1Dsc()!E!VmO)?T_l62 zc{Ie6baJSh>3)fD5;IeW-E!Fx8m81#iLxH3%#tj%gl3o2)^=#3zcVgO#>VxR({uex zm%@IEvf#ktLS_em&f36k-lg~)EwF)f1E4BA#Xw|17J_TqFqF;{>QYaqgDs{d%fXC@ zkrnj0a3MpgBn1~Sn>Y{A>zovUU z%Nb{{9|EJ2k(%p?LN8BCeQB=2G6NH1ER3d6mZcB*e7}U_DPC0}WGQ2!So`1l1zhJo z#GUb}Yy0G&fm;Kf`p0n9;8bLuq-P2U_h}Eh-)}f^ijENje!EVdRYymkkfnvw&Ge)$U#dnnkAjFXs;1tT!) zK@Xd4wjW+#i`ykBv9{Z#XgUyTry4r6+apnyk>U=Sj%ELoUZn1fwzv>vUUw-vMj?Zo z9q~TdRqy5(rnQvE_hJG7G1>>`18Dlvi1cqv$w3S(7$l|jyOb%xXNbnCku%rQUca(f z$122aX<7>B&9++NOXO0F1|uKZ%C@QZFG-ZHQas#N=q@=3vw$o(J5^prTR> zdCRO7q*|b9Ylu);0ynw#lXJfG>F@M)iGB9oE&r2|)C{*FD-~1VM1AXcjFTE18NW(? z10z)LXQ7GJdV>3``0)-K&zXkFOxC-!(;6&lsKXGoLRA_oFK<#OfZJ2zT_A9bFy*@? zeUh?Ap6F(+M|Pe*+D8&WIN$|w#fQ*Prw*@vD{-Te9b?t->EY)TlpW8CcJsJS(k-e? zB{?UF%_tIg+xf%|!=tC++%7X!NRUkG-y+MGsU0TcgP`!V9EV#Ye#tv7^@%}N@ znzFHRv1lLlD7ws1-5bTs$<)$JKj8@L3oihP7}Q#_5)|F5^U4&qs|=1Ab^50Wv4z_$ zMH*+Am!6|oCuYe;=aiojczkF{Np>qD3lYg84R8kHGH@|wGY3r+p+4|nQ6Skji{;)k zkrc86Ii86ttG<{cc?}tei#HzDBO{%Pkq5D;Tz$7%Hs}>3VExc#9nVZItJj#MUSXFU zZNWL;UBUUI%+0F3RKR1z6oL5pah6TR3Gc^?{>QkmZW>L zKAOt&*+uE5VO3vocV)A4K?|dSh^@ensg*m4lkGomr- zoIm9JWir?Od*Efze@V)lxi4stXD~?8>ft8BCtJ7dr{v_b*-+o!${d95ZPd(no$JqE zo$KW~7nO1}uKJQE&_+N_F0D?--ATW@2w{`kFOlu^)m&{ z^cD350`$iaJ)`NC8(X#Xr|aVq0qQ<;pqGbU>M==a z(7HIIxqRYzz;#IV;3fHq*UihOHiwmI7gZu=*qbrs3RR7o@0;NHAcnaby5nL#FevO) zN}zxH+|)!URHV7w+zsLKRBS1TuWjJr_?N%L(z1wXBXUaWuzU568aw0V}8ImEWIny5Bx~5x7CQ`e#Tf=pqemS;iH0*>d@;<7Ir# zPPL(!tjJ-&n{FDEh2wtkte{Lv+?xx^mHSvF?>~OI@B9LJ+EdQeb14pQ)vfsr7T=whrroT-NhKWmlz>e}jceA8KixyH z)kApIljJX5bXR{0?gaxskhl@z0~4IGh$(<9zh3nsT+*mYVyS&Z9xIhkr)8dQ$#@?0 zQY_2RlgYyFdnvj5=Er)OYI+%jWLZ;X*vw=Z!;yW=IeoCyJ}x&|?m$_dIN6C}S>F3T zZnCdDG+%j{zdpO_;+L0u;wFU_4r~`~>YfZGBq!_r`-g|Dgi-RoM{4Wqf8G<=k13bX zW+ILs#-CUZ-x(HPE9hC}7qs21-x8~mz2YaDK+aH!C5nEWvJ^ORNF z;k8%8k>SHpsl(AF!!b?6=-B?@_tV31Tf_0dZgTImaI%pk=8ZUG_!vsA)sC(S%A@18Re~ffZ+NToW(dgk}f1r08~*>qyHAIj5LZJ znfOb{M>=qj%^r8zUUT0jdN=^|8N}6%#=C%6<-BleDkwE+w7XYXbsn^pmxP3k)%%QD zxT$vBcc!0>w%4d~g(+@K;C{YW=_%>JV9vWLt`BUKsB2Q$Ol+Jy=ytlx*_~JEx6W2G z8`^*hH1`R(fBUpH{Eck6J=*!(A#(idYi9h$uum#wm4i4a|Zqx*Ohz54Q>nWbIty- zw}lW{{*iF>V_FsSaQeserUvuXkMq>e7e1dMMJ)5x&7PBlxY5-l-0G&qAGVQXGWU`{ zPPH^R_?1Fs0k>u!672}@YQAGBj05$@L7&;>O~AiK+lB7BE0VQvS!%#sqbD*vUqdUU zlgAt1)=F>il)#^L4NKST4Q4@bvuU_9D;nknwM)R53naLc{-tET&%kc|S#o%=XtK7m z;)wBT(nH1RDnTvD1X|z5a&M&k2UBj3#qjqCmNr3p@n@XOsyR%D&z*g0KUF}Cw(!1X z#ad23*W2Q5yvq9o=0d+bpKQ8(Zg4;O8LG=$Tc0@(eaxD5=S+Pr)|Fq>?ak?|Si%Hj zq}XqbaXn@c1;^l0HE&`P1!|}C!UtukrwQhBxaFqgqz39XpZAEDw=5!6l^^0{?CD*$ z#&Sm4gqY~y1o|~hGhe1TDlR88=3~BTsanei^CuOwHWo;K*5`C@Uw4X*s#ow8(pHX* zl?NtpFMTVTYU2p2M@l9LmM6S~KJH-GSG#^YpAHp)9yrBL82GT)Lu2D^g zOQGzG!E_ic;l9jhGlsBm7LedtxDFLyC5buaSE5LT(O2t;7+<5fU!nZ6(3QgIHyjIi z`;`LFwD=w)9jb_w?Oe--((Uok_>&DR=5Jf7OB-jW!nKFXsh2S|dP<{HL$%qBvulmd?S^q; zbx6EK5)n_@;7gl)v^tKvcK>;M5^x=hs;a! zSyuXv(~s49Tr^Io=67o{j)xERGrwxCNw*~Q42ouX|_0`ab)j!BbG+SzU_Ll~+K&s>%yYdI$ zXaoh}$jSo*58kc@S||VFSc9}j5esmy6&6A=F+VuvqCZDKDiaBe`Efm?k$d>xOuzzf z^4^9#!MBNO-J8esU|5VcSH$dCA1*WnmvrDh6^GpEYcgbNT)s>q1V zTm8Ru)pVjQt3WaH_&VpH&pf$e@18$@SIg`TVyPgzYM3E##*x5e5fmXYvVS4u5Z%sC z;vxx`J;@hoR#0>(VRi;7N9QZO>&)c;aX|~v+eaGF?`!|doWKAoHU(1|{2}kpk*YCWd0!O7!J)s`=vPIptu$?8~*bF7UV&Pjr+Y3wjxVPI5)z8yuaD7=be!5dl5! zLtR+9eXES_O~4D?7D07SQ22O^%8wZg@Ft#YsJL}l4Hw)(5kT)EP+SpfysR_u%>W!1 zhg3T*W?(R&xFIQpg~3@xGY&ufFq!0#R%JO9#&+%esSMR{s%EbV7j6ycLYy2Lemljn2o@_PIDfYYPGh8%0H?Q-}lD)>Eis- z-rH6&+zD27RkC+hk0Cz&z247%oSDb@9a+Fz`Kuo#WelYXjkRq^l-!0YCPjm(KuX3f zK&olmEWtpph|UpQDh80{c>GL9GQpZypchUZVX@iwWF}53;4rYYErKVQnAr*>X`C*# zUR?fWg?!63CgX{zX9Rl^9?JrA)_81xn^sLUn0w*CaYX(57F3B?1m>AHZ45&qJ)=3G zfY0eHEBJ)NW(HWoq#|Yo1{ENB4ATrjxx{7jYd2Dkio=A&^q<3ab#+qL-|&q%Vh;nQ z(2IT%IR$1yZlpje($|2K$pD@w4YoG|nk@1sU&Y)X*Atq_0G&1*M{NMGK;!6Ml6yUN z%>h9vwEqnxH>2vo+h2fj1~Nbv{x}kP>6(IRNEJNs>!e^D)9)p7H>t1mjf=e6XT zZ?;o%^DypO6ggb5zP@0y9A9#JpWF1+z4kh_#pSeGQM7|J^6X!{i~_SG1^5_pBor&aQ_k5M$-E<4Bg4t9XW&;2IGk@F&Ly)Xu9M}_ z$wE_YpS~I z17znCuY=@fI5;akOQDLWmm5Gl@2rvY?S&Va=lfN^4_1A6ZqaUw^g_mb;r`vD@QK;! zw>jo(Nxz$(V?UH9SL2<6;tGqvHeNmzz>h!nr2Vy%$J0`FEC&8f`z}3`V>{7YZFCBh z{ATmohxMjWHrV*;=Cc;K&)WX=7nrLWreWC|=Q-b3uW@giUgynIBIYE*UQd1GdL0+7 z8;XvrIUy|F#Qd= z4cM;mb<>T1PJ+~O;)gs-cXNI`7f>wOW3!EjVyBxqsGCs1yhR=nw0|Ut_nb=LdE&-B ztJ=l$e*tSil)sr~n|zv4Bq;ZcH~V7!9tc}H-O}~YD($*h{awck(qApuPwSs^s9lnZ zq2kd0FGZ@wg8h?2lzRM}w99=Gy~nHE>2vZ8kbri4Ju)SQ^5~cT>8E}UTrgvKi`U>I zN!W_%{jM#j1;NQ$h0y9?eaSPNY(Bzc}V8-Q&Sp0I?3TW?%Oy~z-+n4~L>sHVL zNCJp0>}n-R03gAD0S6Bb@Mnw5!h{J8QjC~zR4r1swkSZ7k(50y2`X-BN3fzo1OO07 z8Q^L`fB+IFE<`EO;FfwPXC`UYj+H%y3R<<=WQ&Rd018MAu=mg2e=kx}0w5d!M9L^z z1D^gPrGO_$uJ&3bnV2N3)3Id}F4RKm-@T_~5lU%QOA48^t)x(tVz1qnnr}IMwJPOp z+W&uqYbF8u@87RYNfSu&XL&%a#A>&sb@?lgt!%xW@r-9-x3?PkVuzyDT2M;$eSI z2Z&UiIyI0jTgQ_rh+p;2b^-MApul|^4SQX zJ)NCq5J+UyrI&ksWyV!2xlLxHL0YilPZ9tugn%X!_SH&a1CUWuOD(qK4{m`a;9!gY zuz@r{5&=-Rq>dG7K@UOd3HKY8w?Wm4D@z zq&u2ICu&P?wIabxGpbdpYBL>VRAM7iH9!Je{8Neqo7^&Ga*kUOFW2jqTxR@q)>#hT6?5aIi(#zIQM8-XBB*3yQBcAHih#i+veGW8#_ti*@LO0%d>%BMM zee2z1OS|?|Phc~%V+)`MC81IRK>#2{Emj&71a^5^PMtljSW3B*kO$S|XbLNoXGyOt z5mN$-$0iAah}Hs_<9gCVhEc326vg2!0bvl7uS>8y5M-tl1uN~?Qp(V&Ta~)y1?|50 zjO~y?rxKiPVU%ruNI2O$I}#uf!1K$eSC8qaJJ0;{&F4vG0vb&xJhJQzFp5tRWV*=) z{~;M9g)i*TK7jFBF#)Pv#dxG311HFEzyZVrQ&eij0Z9Qtg^+0qV)oZMbR{Pm(I0=TqVp48DbIyo)wLD(`%1^LLOHn8}@8E&B|ac>Vd!^ z&MOIH7zG&;LJBSxi1L)d^AR8v69aW5g{5a7#CyDbol%t$~_K4XI30zW@m`p$mBbdZnkRm1s zpdV8RM2g6jfhs}RA~iR`FqaIla_M{CU!6sZ*BWD4N`z|jHB z%S;6!#{f710IJw7j%)}DC;S-+R)5;RTO@En90VToydc5#<)8P+&rM<;qPi zajW)-6y<_<&mf4At2G>G$yV~X07lgoo8TsYgyL!o|J5~;Wb(-aNin8Pu*@nG314Bk z`M?V%(H8w20QUxE8!sYwlD=ua2%}BSp-uxj+VlTbUlB9v0^X$fyZN^ z^Oqh{m{Ckf3QiJfl@s8r={(MMaB>py2*OU(aC11iF%#nsQes6vE=Wi^4}4f*AcFuv z5;-a~PFit`DNKeGt$u&?AS=>qSK!A?=5hbhhU&2#ZI^}HD?N9Tini<-)uYtBOM)xR z?d^p~fGKctojYEJr_rbf+HI!sl`QxaYNh2DdpY)eIB}2ryy$TWxi_a?rm;d#CJ8L= zH+kAzixC|}6FE|Dv_VcCSLCKa#%y@922hF)Rb&kAD7mSIlhc0yMEuntX+?gvI+uft z+7%ij71087P7QZ6!--)T?uED z!B@Omo^VrM#L`8$hC(bB1#wnZc!f-S9sEc{EA7r!dC^vh1|Q}LQ3(;~{1f8rPhc$* zDwfR9U>Pui506O;tL%lDEE8w6!RC+`EzkpI^`c0<(G&=P7A{U>0OK-w+c(-%7Hv@% zjRb!m_S?L`UmJvkN!d;{0RcJk{~@nvjv19>D@j7+Xuijk!36aXne5C|Gp08l6i8ly4KMCw=b zq@oa@&hFU8uGGXR@RLDEK+mYtQ~e{9JY0WF00HhWB7ir}l4aWc> zR6E2aP|A-K{0A!t(LXwzjn&gcPUJ&1TM$9ebFqR12@w=ABk-WnRBWS1I?4F70Rn%p z1c!_U8@>~>(c@dXLtI9|HMakUR@w&gC;{PJ3HTt2_%#$u?qgBu(OU^o1m4X|F4Km% zkS|6PDQaD0!Cw-AW;ll0t7QjRX+jWzCU!XsQv?8K5R9a7BK#mkE!<6lwGtBK1s4dw z6lNima1}lMN5F7FLR`v5u1{I)j`n{@OC5Pic`j!=0zh?^3@(~hOe_;l+QK9N0Of4f zGLcfzB!!azmT3v5f4y6A+95_1k>%uPD^M2$wUQLvLMz}$xH!**K2Q$1l2!!99R`M0 zG9`0~-bzFPcZSt!{SJ@mp;d5G8Neeig-a_P;Nf(a_6X-n1dm2^7jV@jdI5hA_qe7c zAX8vOfbA5eUm<}Qh~o-E|5|=bDDzNgsr1(td|j%^B7oAGqv4PdW~r8LDVKJsmwqXj zhN+m2DVdh3nc@w0jao3x%%$~MZD>l-aKW)Cj0F8ct=wBgWC~veK#<8w@iYYkG9TX6 z*o`p+d1)d8i73edSCZL=no57cjI=_OErdW21o_z>m&sT&VhbxA-&JM~7nmL?7#pZq zs<8A0p>A3E-GsgP1@xT^nfR6z3=RWY8l!NTNO6KKkYYltT?>(l(v-yUd=YFoJU<~g}2#FJ&=O+ah0pl3n$zHpe2MPoC~?|7)5A8v@RZ^?B;*|tOGraTWa*g zo(e!}fa{~7>%QTM_ObuMi>!_&3dS!0nv;0kyV3_`Lgmm&3ag0Qd_WAh&8M1-#k$#A zoE+?(pcw4c>s}F^^7tHT1W2Vw%D*fd#ioWvFkHpa2*+y7T+|$WY{ABr$w5RYg};hjtT9A`k9>b#=c4)lCo1 zQ3Mj$k`_D;JJ^HKQEHJ?9n)Uj#uZaS*izA^#-l3E&ccS#S&p^QmPmmF($Xy7v8xuC zE4&s&089Zp5Ukc33Bt(^pHPIK9EXSrk=ng$YS=~@h-oTD+&$!&@9{2d%Y2Zs4Qx_0~O5Y zKV`%zDvmj>|D`6!k~N*MOGpGbtK9eYtQ)FjrsoN{T1(86KaX${mBC2pqb;v2goKMm z|C4`5`c+-t1p-JyUV6_dLbQ#^5!qSX09XJ* zqAeE|<81$ePuN0IgAQkGaq?jHXxReAc($ehAIq2{L$r^45|9}26G{2jN~O+!zy}*R z4+4lnbVeF>_(XpUDpaDL=xEl58Y(Gi#}~7LiAoEw5$6*QOszyWhptK?!;T>H*ouG3 zX#)u$5(g!YKAho3)AVuHEJndNpqK4s#dVR%e1^{5$K~*kOQk7>2C%A$yID0S#N+mq@XrtzBH z4NObqX{z=qow?^(!K{|69k!iJNvUGO+XxYE2!7Nx1=8Bwh|T*-k<*Zr_?mvT{G$Uz{0D}s=C0co7Co|zNI>TOvdw&!ON8b96^w;L(zOx`dT1fTOd#!)5^;_QodA73y zEhKy34PGt85Jg0SmLIM!ByMU%+@!cL09@UpNqZ~w8L9Z~<%UaoGDWDYsV&}bJMf+X z7TPT!#ps@3=phJTr0??dc#>4Fm><1wjh&?lkahH!oyA+p`+C#a%VJT4^Z9O15Qvq; z2dpOA=s@cja9JTS1>1i4WHME@#TvIcO+g{wpmx>|LG%&abPV@ma^ zU@GDV1OuW33*PNhxZh}ns}Vo)C%^J9Kl3+#*7S_?NB;$L){lQQMY_OgJ0PBuuCgt; zh!$dEvo(djT}9@cj#>^PHvxc%rXcW?m@)0UQ)tbB19k$TeX@H06<3|05;ibr!c@rO|?c2EST^jfGzc2 zDM=6{qD6@ZA<}MghGLEL?DTkK`9Av zv+TP&2?8JpHub8D8#K=WKbay(QpO&atqwneZB;KVQ@h8NY@Mjn_iVsYQ1b;M1$zSM z4-t@WCE+}|0eVrZ9w5`9#RR0@SN+Vp;%WJ;p+i6ugS3lllv-G^?jo9mvd=;58p}!u zt|r2ds?UEqDygTRf;ug<0ZB0ph$z;nMH7X3YNa!Xl4xT94g!#%6;}ZKaDY;@|ND)n zdfd1TGyt?%t+lu|8GsZ8jBAAyGM3btu)YJORpAoL`+~njq2-<#nAX9wK&EE1uaVRd}H${ z1sQwmF&(YLZ<8Q7jEm2D6vMD10Y{Z@A%m7>sH4W(xdp$jg6M88{F*f=!Y-7cFxvrG z$qRom6$c2yAUwT1%dCT18DPDEfUHMF6nz4!EdLVY=(B^ktckDA77BM0uB`tW0Iy9z zvL!*uenWzd&5Dz&7Q)O8FNt=yV79GRwpy2=br-sLfGt|`cmQNcJ?J3RkZngGV?FLD z0#a;g4qoM$8slw<|%Bv_?sb>oSA_%ui z*90xdASAvzjErc>>Ix#9A~Flhc1T`o6lu4L0Iw-%MPiiAyf#^v5pN1Ue*w>|o%hJt zBz7`dv4snztXo$KTgGYNSBR8MWv`x;+R4kL zA5P5Ly`5T)GU&%G+RR&g2uIn%wB&u?TZ)Xb``w`6=N3T>)LvieET~fzRqw0TVRKtTsITUCb6_@F5Rzb+LNvP>4;` z9}w=AAY4fzQ;O-;45Rl%07Rh{_JH9-jzg9#4i5-gsE3=HSP)lme*uZ{ss(WpP^6LR zPd#5OPR8bSL?@<^748_Jb&x;`ceDbIP^^a(+qg$0lHwPgDF09b5P+3d$Rv?OBts%y zVv9E54VmQ*shC85_U^hKNBQBB3&a0;QAoI39N$s1)lH zBBSD`NChQ<6#lpne^a*Lid!rJd9-W8EUB}@Db)j)7i*WJFzGuju)-gl{EaA#i4l7k z(K?!_B50Iw9G#7khj=U!j*2rSf8dfQyi?pPRP=^KkJZ8!t~AId5~!L4Y~dEAfZjp>el);w`fy!5Tvz*mfI&KD zDq~vsUa=NcJJu8@iUoOM63Af>IeJV1G0WY$MCMZ+8YG=S5W^w!_CzVL!;98QN?`al z3V(3VED|8aK|Mm$Hzn33q(b^3q|Zq5uJiXCSUvPjmplU;R+zj_r8GdQ(+k4TEu6_nC@+ znyOXCdayw^D>LIJrEo@sjZ!3!NrDz6p;@cu1pr-PE`igMrK=4<61HdsS%_D)iFI!P zN&=V4ZSDh~LM@Jy>#W7og$KHH@1PLyJ5o5IAW-Py?Izg~X0&4TFrO+fUA0dy+ zY8;+~BszMrVQ3VlW6QBR?IP_8QEWid8%%HJ%ETKlg!)rp&_s6toPm|14E>t}B`DH# zwzbf7C4dh9#pQkD#Hx7hQ>1oq0btD)vw(~gy(C0o9410ht68L+R)sH5vmVQ>kikqtn7S|&2w9*<#TOHyGN_v=j-Ji{tmcqd zZwNm7oNDPY*4}G^5et4hdjJtyG6q_;^E9HWwT^J3U&n#QktFjeZ1q0L7JgvZNQB@Pw19V%#ALXF zc2vk+WJpsct561je_WvyXk;i1&>*g-9sJ>4oW~a2fi*ek{-y zHekE|f~OW(33s-rG6spMT(F5oW{?Dcog&K?yhcg-#K$rswX7q509|1L%EGgd=pq7$ z&?w0|BI=mVXLow(V(_9Arp?D}ffLHgufWGhG9tmWB_ovNri@~x!o|}J#`%!Y+AQOG z8nA=XLU=4lEdD6VsL%{4NG&p=0;ghqw#gpcg$*0UchaNCZbV5?X?X}S47UJz!lD+a zkPXF(pfD*cc<_vW5U^ONsNAlIhA0Dk!m!8ys+|lXBd9PG{fY+@aY;mR6t5=~MDBT|Hu@U4J=Dnx$7uNozKTI&->Mzq9nMyL>kM!^q%!;p6nN>p6I7L~+~K*!egw6 zJX}GHvS|~1fjYWHe=@Ng2S5Q7q9M_!lCq>03WZ7Lqtb*4)0*++`VS#!=AjZGqBQY+ z0Fe}JPJal0BU6@8C8kFJuV;|Vql`w1QN(UtRLB`u<9C>(T~J~wl!Q?LtZH;|k~}6P zd<>>eh^EX!EL4aZr-D%+$WD^VET%*x!Z1rp5R=wrI;fHqy67N4>a=WuDqqejDQbHE z^03Q*fQVKMMffM3dPjgXrW`Wm<;12urr`p928{)Kn?0i-$u<{T8*zq}~b2_QBI<50Mu`@fhb33`SJ4NRj`GV0B?^oQ5 zI!tDNAf})d+`%lU$ACNxIcvuGz$-odus&l`K~7^szRM~kiTfJsFdFbzHbZI#f-=eD zcaWtXA|PG3!zJ2b+YF4(-UI9U0>MVYU8EvIYvw%;EziosTneH=D&p^yW@w-d-v6+R z7+0z!iBbD3qlz*Hm2w3y{7_zsglA~xMb{#KL>0mtN72gCq<7xOAl9%d{-Wd-BLU`u zFrMWlYUU_>D%;|ds)9x#W`;fQQ|Tt6*TV8NTtlJibLet~NPm<+Rz<@Oz0GDFTcTs|dtGql{jsK3(lE*c9oy6ucrONbK|wo2(hZEI$?mAqW*Y$qC7S zri~CuVG0-|;p`(kG$Bg{z#VL^Qh%dJXzfnw%|pDzF~BRs(j!d5l2Scn6Ew9wiKNTo zi5#4bCo-ioxWq*OY9d^{BS+h@lZ&@ zIMO6Mv}MZN;>cP?S*p}Q*%L7Lf@T^}L#s$>;*?=WaRP?Iq`SfJ7=B|3R$YN9M_CmuVV<<#}4PM6}*oONeCPLy0jb%Sf__U5hzM|K-WcDtg24Cq_{h;{?Vcg2Tqz!GqB zw{KP|r_Snl_~v*`#xK_5Z)_L0KuY}p=PTMrIE@iM^N|;=cYCwfdoihhd%t&l!Iyl^ zw|vo8h{QL2^~WLfh<&s7KgU;o;g=UFW1fabuObL6+(jM}3S-XKeBU>F2e^I%*nsm< zS_Ala5cq%{cv9Tk2n=j*Hi^F_iTCD+J;sR($18irWb&p<1jqdDD(~1McdgiCxcH3CMRw0Pge3=y z2WP`~X$CRg*ceklpM{nBLYNcjL&q6k$AWSkycFCE8{f3G8 z1DylsTk3fy*e>n-`Ek~{hzYu&4f>!FI-wPMp&7cN_c@I9M*X^3iSIe0c~^TQ_&hB- zawHmYp5&pInOUa97uF+mY8YfT+M`oichj(eUAmxP+JNbQ_mIgbrfd2L=QpQu+IwfZ zr)|0yclxH|SEq-Xr-?d&6^Sw&2=%ON3|awy*(_sj)rle?9S6Xdd^)I++NiY}qk&qe zw;HU4DXYc$NxC|l$Qpjr+N+^g@~g6Zx^CNI-8Zbgb}B_ zyeP0QIhFe%%#EUXofuuMZd2=|9DLAhyr=uM@vo(9OIlHqx`?Eniv}KEn z1xsZTXiK8C{H-kD>dRaaoi}HM*BOVON`fKVX9*E;YKHJG-U(m=)H!H!PP& zwO@ppyvO@A!op687ND-e9)1CDIz-i^?W%idmG^kPZ+Vm3+q}(mHy|az!8^MBdymB% zZF^Zn1)RP27;VmU?GmS&Z98q%FGRhRyow+!gO(wWedd2;^ z#aUc`Z?Z)4;wy_kTwT0wH;|AaV4JUlW+~&;V>{gV4!c>tn8%5`cXr&k{|3W%Lnjzw zpb@9YwPo={yUMNn%CS7lwS3FDyvq;AiO$ck6()jZAFe9hgw&Efpb zM6`2&3Zs&*wbQ0e#Sa1-*^T>ZCnqA?)1BikF5hJ9YhHyc9>w zIcGc?8}(RY{R|y)Bp?YooghBl(_4Fny}Z;-{nSxC)m44fSzWj$SgzlX#C;g^3>?A_ z+}3fu)^+{Ccm39TJ=b|X*nM5tfql9`*^$xv*Oi^vmwnj(oxRwhz1e4ZrC)1=i(J-! zuM^uz2g<=F#XT5ww0(>9nsKCj?RLD?(LLSOecjo;-QE4&OG-PLc(x`Q(F@1k4?BOJ zIe7U!-}l|$`(57y{@= z;}KhuWr? zslKDXrHO@`%J-&f|3<|Phl`>9aG-p1)GzJ@XPv43bN_5!V`7_dsF>Pqd++ss-pI$@ zpYz7*yV$p6gCuegF4?Klp`z_=&&xjsN(O zKl$@d)PuhHJze>sKl-JA`l-Kv`s2^5R44C`8}W@en|J=U=>F~*XPf~W^Ff}*iCg=1 ztDXyo?gPjDHRt>3PyT01{QJiKtzUBRzjnu;`vGErz<~q{8a#;bAdrL%8x92EFrvhX z6ca+M*boWEjT{Lk8Te6P$N?ienmmazrOK5oTe^Jz2{We5nKWzKyoocGdQltwF zHm&)T;Khzp9ae34m0{GSK(ksr=oMjDgh5hYNI{Jn1p)~mP&n_~zb7NCOYviF+$5&{AUU@0$$J}f7+pgjp}Ew6z|n?ho_g-dXPuJEU#OWr%#*`Ummw6`bv({dVZMNEO%l~b+ z-hK;ixZXmvi7oXYOMhLm2>_7*l2}UVUyW!XfW#Q}+``EK%T7oymdKG(k1bqy3GKS= zQWpezq}XFirF5c%M85|~(M~H;KpB&A*d4U5L!{W!&K5v~Mdn5TBvFPvWN4ajLz>=6 zEJ0B|b=G=L)|7<7Id5qYL_kR4j>Q-4EN?{*Pei8_wzQ(aVSnrry>q+k2KXTYS}Z+) z&!S1uN;?|>3joDSOK@Ugj!B%*$Wqi|PaCkNiYks>PK2Ho>{yH+00>0Eb{hii8*!@@ zw=?#c>uu7txp+T|-WCCqLiS4IkKp?;=;e?Z-K*tCTJJm%1i=4Gd2x#K}i^`sl zc5}IOISU^O!3REpL|}zcfLDqz(!QQSC~W@oRruHyof3AgB2P(z5IeSo z6243#IfP+Fw&1^HeMl!6=D$1I0&;)S3}>YE%FZ-zmb9#;EpLg- zTK<;TeD<->;U`(kIb2{v5rhmbn_Q@W7NmgW_)qm8J1PE>+YC!Z3 zHEFfFPwvP?C{iZXenyIFmM%Rq`{Y9OBb>`+vmO$-%FhTV35*emV7OyQcQ!~00dz+S zu5bqx&P28JEMy^4Aj2ru384ixaAOYYjatx@vNCiIITDJRv;3D%OtymmY{29jJ)akZ zXMZNsnZ?-~1nY^PixzDp+fhqMG^Mi`EfjqKsFVqDqs#d$f<_s=m49Z9V?yG2vf7nF2vCI+ylP6?cU>qVmD8=f zYUhe6$gVm^6YKyGFbZGXjz_Jq#}*1JPIxH?J;Hd zh`nvChBQcqJdTgMy+IIi?Mk+jc~&5}9jkr!^I3SMRwckoXfOW@-~bDFzyvO^fe$<= z?L7%S5CPwl#5XTU<&dY)Lrx4$*0TqXq8~nVVP4U*M1h1Dq9bM^_O=HYO~p=!W-CaE zzgLhM`a=o~_J}<~78!<=cv9{t1AhVsQpG6<08}F`@^LL<#NY-%01y>1i4(FOQv9NW z5lH~wRMx_FN#TSLkpesFk%T9GPJ6OjB6>o|LcF}#Ieo>7^-=@?N%XLE6Na&R1Y`?W zB(SP3en=AXJj6B+Knt~KMTlPv;uk)zXW2W<7M2!fK^w$Yka3Cst%7_-^ncVCMmp3& zJqX)mz+9xNcD z5cP_#sge*4-eU`gq#F>*_sFU&M3S3E_CPvDY05~%*pVTFEA-s#QV-OGHLZ%wBZ z*aT!U2v-RJS9rUgBmrlIbPFs)Ktn?+@pVR^;hAj($Q3?-j{_N}-$Me%8=l@b%TWqr z0C}rHlFX4b^r2WxfX^>7250GY6bVNkG?|$IB+O8lt*CV*6aNl9v46p;;Urm;dH{f4 zhg;pM%cRZ*OEF}j?yaB)RRAZp8G!)^bDL9<0@2g?s(}LBsii1`+9}mSjJ>tzTQ>+P zTwxC?x6B~dbDRkqoMj}4iWH>CY#C0z&Wu-l%U#|?%`-~|P5|6fnu@dgmIyrr{7mSP zoInyl6}8k)%|HvELVpw`PpOnbq4K|~T2?{z~M z3h@C0CUb#6^(O$n;-$bRS;v)gw$~ndj%^K z2sz^BR!nzPrGJolnP*)ilLS~;DKe-*p3-pJ#3rs10XruU1gCvlm`-O`5lXpZMd8tkG4{$LO6XbTpXKPt#dxN$OUb3Z_kM5*%+ll4775I?Oj34phc z0z(i^P9jwrEJmHouUNf^=ft!!ADKkOOd3To8QW*9x_OIQ6iP_ArkNhbqKp3+;H31F?@k zH4wsNKL5}Tq~LP{u?6k$kPIV1Kh!-5KnezgK!4g3|6~H71pcFx_V5c$>2nOXj~3ZL zZ|0E<;|jNM4;lG7{g^~7DG;_G1uv6762vf6_I?Sl3H~57S<^l1=noc&ku7OxZ`ca@ z$S@&;9GYN07FiEsiI_K4E-cnA6!Z%N@e77z9557FfmvpRP z4}Xs|leXXv|8NhR`45;hKNJQ44#}3k@RnkUNqSX6TL5@axkUO?5T!tg?J|`2Q%MZt zNevN|t-u9Cd6ETL4~)4qo+4Q)QxH&O3rVF%<9u z7vq+08Cef&qy-T1XDw+9<5du{0{~KHlg>yG=m9}ICm_w0krlH= zq632)lVXx|x-#n&q&?{hZSbOvT7Q6`(lw>f|De6I38`b3$g!4!>YM~Ia);@t5`-12 zhBy0H6`ty*1=x8qM69QYiX93-e;J&(F$NNPn@Cy{(-JM4u^Bn@8R*EZ-RiC13a;TQ zuHO0)$doft05I_bOh6DudW9XL=Qyr#F=0bE&N7De3M%Eoo-vvP_F|WI_B3VP(Km#xeLkWBnIa0E7Ie+6|3z|?J_li^r z6t2;P^7AMHl2iXRCZJ`q!0i}R6-Y7F(p)hmxBb&#i;_p!s^#lvS|zX!m(s<5pF9G zO3++nfE@V9Fm00=479+*Ww9eQ02*X>GNim%#0uUkwgUh`Zhu1(qGDY~1AQSeL*Cnp zlk&sy!!=0TS93!WvgJvYl|noG5eFPUPZE`&>n!TyT)VkBW55cS#|f7j01^c`X?(vD z(Kwp}M;XgGa-29Hd`GU}Fn?$xoRM6lhh47OoGBNJA}l*mu)|oi!&>tZ*rzforXKGb zFgOODn*g^Av42XQHM-}bHRM?aX)OOt+yS5__k^N5l+Ohb7RySqB8y@hNTB+>aooMK z+BL3_PJzc+#SsOB*o(qJy^)-x&Qc|GVxSj%ktKA(SvtCb#~JiVq21fDT%&|B>_AU! zFUYcgxLh5U8(lKtjpORh?+nlJEYI^y&)V`sDr659_kXgkX~WPnLq1u1q~H&nfR||d z5xeOQ;zfQ*z&4(u3I33U!?h_4#KR{#gd4^_^8-C6V@EvPJ&M>gY(feZb3(O)l!WJD zxv4#vQwy$OlW!&gO^nF;f&~2Bc`~PHB9=lJEtnoQ02G{v8T2up5&%fRM>mCHfWc!v z=0Rb@Kz~9;D;0x6;!9=O<7Et+|2;OUWVN?wK~`Z#El7Mb5Sju<2#ag$G79bB50R8* zDVcJL1-4_sm8H<00)b}ZVaY42#NPWo0$n|dhEqOWX9~eX40F(t1knkk1*?fuHC+#n zB$!G7XtCo$S^bg(p)`87mD!^OLJU(zR0~-u5Py(6zZct6g!48Dur@I|5gEIvd=uIU zL7Wq%R{`^I%vYx==Rc^Jazl-&kOY~y$EbB}9BJJOHRKTt1gadvuAe5FZ|DmP8HMfO z3$g~RzEaaJd2N5w+^leH5dn*dNwXNZfmgGrz}823b_;EJM+3kG40OSD1iis&!Q#C) z+kYt%pnKa$2+?mCMWDhrGmZa*|KJW4tyk#}34HwJVbU+-aZ}5D2rNS?_{O{K*oH<>OjEW?vw;(_NBLZk zsNH)@12XE(o*f$zqDO0xZNBE?oMaF*pnt+gZ6;^+JEc=L9QLwCh+9Ag{(c8By3ZN~ zd4}b{vQq>>3e`Qo_K+|7;*Y7CKtRCHZT3}UL%{yl3iy(8Gb|M=To0o#VoC!)?ZhDv zREUN4;|!F0wZPf~@l_VY=gYm_4BapU@zC?Q3BNFpJ^^o+ZmrZ>&zsKao$l$M4u9&Q zt`j?hIqCXBYCFDX%&>My|GG~yGBsom7UV1gDx!XAF{1agS;Ho!3v0%Zx=$hrATuqCYkKq;R^~Lj4(reDV6r4J5!MA&<--ba({F@A`s#w0PULqO5h3ua8N+g3e{t^vc)pM6bW=B zGC6g`T`K@WAbJMFyX41w?T|pEGBnEfl%jcYf&;pzD^zk6Li8ZN$|z!-g))K1GAYCg zBEN*qjg@Cr#)wSN#swC1EG z6!W^6Tjzq?!?e5!{ZjuAxWOh*5IbiLuYSTZDd!RrXFPuh<-Dyx4pMN0z23x-)5BusB&Smgn3N$xOlbG*Nuv_mg?K zYN7z3%-j*fq^viJ#|rmDr6PJcAHxZfANH4H$?X7qj*t9X2Yw2pS<^>d=OVV?&L~*# zDCmLVweV2jha9!T^?DQZr90$mH}^=77FgG^>IV_D4+K*{IKs925Px1NN?yRoRSTK` z0Fpgk0WbyZ2@qQ({{#U5V8B6x2Nfj;X9+w~i0w^5G z&?EvN2t@hPG*MwfQufG%G$_edDFq{2+1gSN0FVLgY;6J%WPs5mWMqysI>2Eqc2f)* z+|+8*QLH8vBA^JQAb-j&^%R|jTN7*>#kVnFq%cO;1_L(c=th)o5Jq=5l8Pw43`VCg z>5%Sj1XQ{u1d*1MMnEx9KEA)3(MCCieg(jxfewrS+AsJV+9n6mW zlmVJ)8Cn5pHEBT|<%tprK}aIu0q^mu7(I|EdO7+qdG%04X}Ta?*e$fNEf5?v%`ANr z+(RRVzdi$KDk{Q@S%zNfks12|(B_*rs#iiTt=PI$iemz5c+o=&Fxuz>XobMAd`PyI zm^3WypF~lsQPe>pkmgu}l9Uwb*E?kCKQzSF;a--3#SI#$&@r$m60)p4RN_aNs%J-& zJQLr;hpji^VWjL5x(K*>RZQC6%T{^~5Z=Z3d<=LzJI*acyVajzT26#l4kZdn&g-#XF{3ru791Ijc-t@ z`z>j9w{7#gKJ{?q7TKfm{!~wrbYFjl=`;-GZ?b568%L$km2p|J*}Z&J=RIsdm71h6 z`Enq;M)D7d#=qB$Nnv=-0Ve6%m={!59z||f&5)wxNKC8dmJbz0=AzU{;h<)sJS0{b zo8n50IvAIw4M>GS=a4S|g#0{ZxUI8OEw09k1z=f>n|HhH*;XG;~md%&`wnRGPt#AUq)~tW=X9_>lc3&CBJv z)#^6$cb07y)}Zud2Q-Aj&yfKm79Eu6|Cj_pWdLYoSEaE>ie9}F+;1WUHyWnQ1n|%W z@m8y$!HJHoF~Xw|P8F^ii<3%LQ^>Euc4K-}ri3%*xG^fGi9VsNGG<)RZVvZSiK?#`y|Zg_58 z`;(0>MI_nMz4$o;pOdhF9BV(MOrR>KOkVFj%(&e#BmGq(|FW`@-r5#-fHx_#b#YWL zU;}SWkA^@^i@M7*X1Rqi@g^3JhftAq0_Nv0?_S!ejrW2)G} zwv1vWVKE;WY9l&7R2OTjVy2;?^FATH+>Zj7X*pQo%Ih8B_xf*ygtE7 zy$fhhb?P_ZWLw&nm4c8ss#JsR$TusDC%zQ)URAcG08D0Ns`kZkD2=O=hFA*V`MO;w zz6H!TRf+5~c~}bxk1Ve4-Edo+Mc!+IsnV-$B{2oRlAe3}-_ls__kF>t=rkG6I8St^ z%4~|ccZJy5ytCU>yPaT1_AKH0?khWb=al3NG5dqAc#ild2_fur?PfV0Z}_$1Aa)7F zDer4CB%AsqRnNzlLT&*ycfEwgl-AC_?bb=TCeeE*{S zVHx&64_!CQxm>BOEXtj-`QDC|UqEWeipG_^Q(5V!74J}B%Maxd% z1BFwaI**T8VB>oGTYr|Ra4)pr)f!^d@k>+ftgp^X8N1Rq11WzQ0n!Y1>?kYU;)Btcq3S4J+)bQ&A=Xhy}fEI@l{_@Wk@rfhtz zw#!(d3k^Rn(xi}bqI-w`m;}zXV6Yl1UJy)O<7hz5$Uq!sV&+W~ELC$Ax1PqmnlV|Y z>+&)+w9M_Ad?5TR>V5+`{dd&-<1DJA9;_(Z(%Bqw^7-4b3*CwwsaJR%jIVD*5VNcX zyUqM8X;9H_nE*A<=RsfP%F&3?^himp2)(N9$9|h-@kkWk5Us$w-6R1h4;La)2ms}! zyN0%;!>w{OImGd7C1{~NZ2Hyful%x^!*?Uj8=7Rjr(fh@I|ZUBREr*l-HV0JJStm~ z8~GS|FBcrv9EUg7e;{AwVR037m#j%8zSi~8K`B7|6}R?d=tGVjdspdQqe?~w$o1X5 zfT5={ACAAY^DHgnbv>ZneaVc>-;7Q+_()v}*nYEkk1Vl@^=!$$3^90l;YI6TqHxzL zE2{9Ng$}0>{BG|HLcbgY_cKcB2|W`2`eHWzuT^$dBdNZ5!EZ!%er&O zr}MYbv29X~oaB*;xv8Kxn}xg%Zva_fuzxNY4|8xF1>IpXo?g4E?iVRVf3dSz4Vq4Q z*Phx|JWCjTTU{&NhQ0rLN)d}^(a?r6>z0HFwTdW*$`uP*?2d|ru8O3!ij=pC^lFhF z8{O^k{3bL+*O9Jh@#a(!H0qnP_jx*AW}02j)|mOvQ<* zB4t^CeJ-4oxg$pMHwLPd% z_%JO!5l9b1A_0|)38yo5j2d8zZ-}MZ+^2?&(X{`h#=pnRo7Z4u;~Rn@e-N+4O5%SN zQ=rA~R`&=4TgatHDGJORxK<=)%}OF`ZFQ+Fd9m$#(n(>P~U)#4Er^ zCMxG@ORIyG)DC{bpMg|uY2;2ORTA+@e{c2;?3AGo6EvO`jIg(vy2l(YgE>@^Gi6X{ zL5Q6aM|0EiY7Hj)eXnq(Ef27mr>zx_eBK|&3S{==wQ=-ZJ_QI+&ULk}+ ziTajW1TFi2Dje?&IoAxiP7Lwfe+}U@I0WbW-)9meN-3)ExikCHA4R~Wlz13N+wLqQhFYqd?>1qg?kYLz@ zgfb?j_UGF)yL57lT8J|TT2FtYz2yp!LsWcjSH#ad&)Mm?b&gm z>D&HlhJg(oN11GKVOgEx-d=2N{Y;F87;+%romD7hYfO;4oY6MdZZOepyj}2y9?A}= zoyhFsg@4T(4zMutz+)CBkk zklcjojh?}s&vYW6C(Q0B=>E(Vj;3QPRsuEcGxLF8e?T)UAT}=30a9#{yPj1S=_kW`ojv;+A7(bt5=FZ4g6yDgx|bzv z*KBtN^u0P6N8d5J#8#BV`We%kY|_ndv78WVvCJd)!-}(T100faZ)vR0VlF-y69rDO zO#DDR``UC<;&u?d$DIZFPa4|Pm~A`~2@c6$DBhg=zd9C)olni`(KaLN%TSSw+jGv1 zRD=fWW%RA#nG(;lW5U_}9D1qKaiSrjJkcNXjFnTH6G_!lIF3zk>a5e|Jf!FMv>Y^CB8 ziwf=Z2sC8BC|`xdAtAK=9|%&5k*&e@x`d4g7AobUPZ2Uxtul$GDUrpRQg^|lf<&+v z^b?-w%xC3JE)5tONM=7zjhZ`{K#3xX_klz#(OW;A3 znj#Tqy&?f2z~~=Fs@wmRS&G=XG8kX6-`j&n=ykpUH=-n)Xbg#w9Xecb#{%&?bFHgT zR&DBMMa1D?y*j#K?DA%)p+j-_C);ly2JzP$%-qRR6*?vDu>{etf+}K89Qv|EirZba zlbD4-rgEAi7i~hH+{^TteG#^+hdFR+1C7s(4B)Gzc1I4-s$ZafFcC}x2 z&<23_Niqtq7u0#Thk~W8XsVz`+l8@QEh7Pxn;bXVm?5@n^t+hPg4&7UE6mveylu;A(r?{Xv1KN9Vi?H7A*Hu^!qR*xy$4 zO4uKNH&ZKAscS~D|LLxhU3A~ujw=(+DjorkRv*^`_8Ip%`@J$rAX&g5tiNYr8V&+ z7Pl9p+F?LS24p%jiHrj`<<-x3`g9C>4JmSCu4^QK?BLJ9=nR4ikK*y!UP8! z6~ul@Uokk0PTElnWF!T_2e16D*`cs`yxoa%6{4 zL`|h|S@Lb)|FFq=)Q^LidQP)a4eBR0Hh2H*so$j*3#&=t0tjtH8kkVZdr;1MQ19^^ zjZP)AwKb?1*{-}ff&V36U4O8C`ry;Q2OC1ShB*Ov`wXC8DJ^FT{yYZwv%ZQX`Ce)z zPY6YsX6BMHs$0mDlPT#an0DF=ogpB8QG}PmDMtweqVhj`Gpy7qj%C z`bbu9#G4Z--eDq^y4_BC`%lz)#wR~{I%e;#K$3QWV)pQUu`YsfXkR-6TeoRS$sn=t z>ZRPGW3EQ}f--%I-cQm7S)xQ$%$EKO(Rd5?R8RoDkR&4e;ybzWwj`e0;gp9r0;_3} z|7dBKpSQCsZatW2{7 zi;->qm7)TqT8|%u&{c+}Hk|PDUXzvZ_gO@lGg+zK$N+>}RdPj_Pfx3oh2{K>7(NQ; zG@kkRp7KWF&+?Q^KS11K`R`pCetGfzZu7KaZV0>;qVdM%2ZKBw$mCLTR+_{*n9~>p z7kWcM#Zu5{3MbX^O)j-(y8r3a#)-wr?mJMNx~gq_*y@q{iP$F5ch@SHCir!?2fRir zAAiL9-}C@QEt88^;E+9}<1cyFO0sZu5Ymj^$!;Zhz7QKMZ-@^bEIhkm;##@;wgPCA z>P}j{NHvyEixlx0ceSXwDzY!G3A`*C7Y_wizxZUFeV{L%)zm$NbP{3;et_*z9gDAS0{{iBW|(rGoZP*{|TI%fQF^X&yZ% z2jDR=a(Rzd<<&pBV^kQ6h~fl5EO#_znTS+^hJ%2B#rLNlnW74vCf8&wnMJ6-03ah(We{P*Q7r6^94UZNeSEUBdMaz9g99{z!kuznReccIqf{Co{dx)^DgJwP)8N7O~|g_lnN5#->p&V_oV2+DH; zO`=$7Bpkk1>@Dv(*;lD6>Z7#y*hr(-6{xxdO`m(8EFuu-afdk^CT0abU6UwBEY#p} z%lzJ1W^m|pr(&*aU5w1OMI>JOFX;uVH0x9@9sa`|K4IKzOS$PF^FO>Z5k`VelbnUG zRKT7~QtKDyg=vu>Up8Y!1_%K8fC6vxM|KCis>W4S?C?;y)NJT!~-;#x``DoT&+Pbmw2`~p3zsSTP-i=zi(C8!fhTF3Cf&~SkCMM=R@uUlyyy22q^C?HKI+Zz}iQ~#G+mPyx2Erc$ z_5En=9Bg}D`-b+}A@YI!&!F)y`QzFSI+)9Uxwl_*v?pvf>79J==r6mQSl36|z~(=2joziVRSFN)lqXY~+FDK~s! zbE$1xVD?Bw=2czL2P_5Ij}`m1{~oIw-GBE&kNO1ZqCR_ol~#9}u5(_G-{uqQRGy%~ z`^?+^;cQ=$^mx#@;opFDQHkxSyS-{uQ0Hu`Klb0R4P%WT!5_>g-nBpU$H?DYuomak zRm`T6<(l0E*5y8D7W`{F15t4cd8|6**z4Bt;iX`i9jHC{@ zdck>CIMVsMO4vW2P%y)RN~!JJ_cQ&Os-|DrV&y!&9s>G6cS&+u%F zHSR+gN6`e!RUYQ<{UKq3XrdzrFYCl;LBNBf_+2rlI*B>Tki~vwZ*T(4QYrqh%~3vZ zm6!X^{>VL0fGLhe4U;pPW;*NbDyD;+1RacOD2ow^c}PBt$C!?z*b8Ip%m$f*#{9KX zfz=uwsH$u)-;?*xZHsK>36Ai`czSI54Xy_|3>XQ5LE>{@8DkkABrJ#lnf*v!fp8B& zfM)U22Pw*cL<&K4xo_6{JbuXoYKXJj3;P>^R*)4zu^J|xgin3O`w+-#<%$xHHEcR zapAu41%d}Y`oW5JQ4M9LJ5^Pn&eDQ6g81+%59*fa`Sb9@kg%mB+@9OKvBG6l2hU+6 zT_&!=MzUJ(PqoQ4B+?bH&-kQ*P<}0QxX{rrSs##(u{rYm*u5p$@N^Ah_vi4VgJ?jJ z?~7hLDzC)>9;v2yo7+z8M~frMQq9Et>)Y&0<0j*dQY|@ax83BAEQC&t91HkNl?zNQ zUzWY_sCZ!MnK{2a+rLmx{d!4-{pr%;R#21kiMvn8(aI-iO{UBpyvMcos1ggKkg+f| z)8W$tqfE!~Jn^K@Rcu+RegwYS25eDX-p7S6!-9?$yNrgL9Ua&24%;O@RMqNN+UDyq zzph@c4|%urrf>C*oy zl<46e(US~y1YW4xjs1So@Jx0}kj~;F`TCsfhQBLo9QJhu`Z#zUm`Ufi-($90>D`3P ztrjdB46qj(nC5yu{wVL8PXiaTOfw=(QYLgy{Fk&5pyu%1{UE$p#@Ll(Ow`F+eOg`p zr^;>^GL!LzvigeH!@X}zS+BWYekH7zTXk2~p$ZtKG%UfWooW1lg#DBf> zlfZmm>>?*>Ct+?E{Gqy_Ey}bNwD3hapwZ#`_0UHaTs z;re>m&8?Bw-CMGwsi*$sZTZ`VDIg_krt8zt)ST`{XScVx9G3sBcZ+G+5A{Oz>mNGh z&5_4cym`lXXd%6t%W0Kf^h(b?&^8L0G1W0DUXM-b?#|r6;bm)bR|QiF{+JNj9y|~~ z38ZL9iMhuuH%p($fE}q9DC1}$S}N?KyXypOLc8>1i(RS`#EuMktI&1&yZiowyn-u= zAi&-u3S3=}*{!lU%*Xh{)Np|S{BLE#%p;cHzFU6Cp&*dn1j?Pz=z7ZQ+Pzln>&Mh}S&f9}pC+AX*2+}l z@mJf-e$RM}QwyVzLyaNpOn-N5GZ!9S^PUGcmWOQc8zLru_S9(vfq1qt4PC>yFZ!FXqa=#@k)q`x2lqSH7cQeeTcr^7rqG zcfai$x~V#5R82L*S~|Ha4wkRVY!-@G|EVogzv5U=pE|o_nNrLLVW$f&Z@sknly5xr zDeDsjMd8nUn)%(VN2!H$sk(ap1z`1!)NI^f;`ooaokpApwEMO9W>N;c?Z<7iPU9Eh12KzVO=t0Fp+5H%@9 zbkP>j42wb8$kqpaOpr-%m>R4#Lbw#i{5_rt&mWFmC1G|SWa>49O3cB(BLvqRCd8OJNr#Z)f#2_y2Ke`oGb->x!|z&Y0^e|ob6n6LMsgq zm0_~{fX+vehMAY@DtIf3pGU4evHs!0JW%H#CH866h5{=Z3DyLv0B%~aHRd&fY2G*u_8Pa;s=*pnu=ZR9JsaK5 zNQe!1e%8GD#EII3>V6syU>>QjjUUm{Eua}+)_s{GSC-eE zT^zK3=5C-;|LMQ*4YDfBKeAEa)_$xXfooT-x0UcT4Q%*5Qnnb9AD7jgMzr$GPgB+S z6->1?nb(zMQTPFz^#l>G`VY=_Yue^xz5(b;HBd*uI%;KeaXNZQ2Xz7bsWQgf6ei%dWiPBppEPw4bdo zKTT+S#0COb6IppM?^i8bS2R_Dm{`QS|FnM=QfZwz->Y7A?dKuNR@!+|-v@`i`c(vL z>}L;J6_x|*)y8PvAPR1NFrLrxmd0EQ11zZ6j@8lFP^*=h)xNmgyZ9t)^M{sq1E`;G zx#Q3F%K0qE`NAxG!n!$Ny0cg#>fpabw*%DQ#&q; z^~x96mW)<5CM?git_LCR_@IP}QukqsjmV7rzv&kC?n;<>P}e?=aFHvtgnq@b+JSJ- zI~_J@ta3Ybf4s;ofT!5t4z(!KKJc^kNEc-<=NByY$<^aOHI7x~zu)6u@{}Zdz7)!<;tpi})@LoquU>EOlEh%PHtuFwBnT*^s zyp?Kvw2mOVwvsSZXa!k0>w0>7_YlJcs@0&0N(y&5KAg=qH9f3YK@~G$oqxsg_Q++dylT?U@B`SyB>^Z0L3MZIDz(n zyg`)n18?;kq6qWvA;e6X1s@s zF)je0QDiMD_2vP$Clf*PngEMljF90JV3oducJc@HIweot-Uk8m^?sOlOt`_cMX?EZ zmHxV?y+7dLpm`(ckZGu`S*V*9aAx}a=wH;@Bi{3O>O>;2gHm;3#uV>DWU9bGDz8QZl7 zb{i6(H91OEOKTSBOfLae4#$zIouCzfVp}^1Y?l0Oyg;LL>jre+TT6~DD<&bg3Ut>r z7F1;gi-@z538hCHS9_Zst3KhdmB-&U3gr)<%m7-aDSu}|J~4MTZU=p4wA|ziXIho-JCMr8~MFDu(q;W#&Wu&g760U77>5BRRV zc`|2d>suW8+3cLw+}%zbOUzu^xsLn;r{Ki0b4jRO|J3>H$vO9te+T2mcRXJ{B0-cq zIY-749jxt8<-pHtRuNFP_fzzUKQKTbs_GD&?ZNU&vCt@FZkWBDH3ibrY|`^gYHF$-{4zPH+pf!ktzonRd6*ta&*v|)@3 z%4i66J;h|}uZk85Rj@UFlXH^|Gav98060xiMz8i`J0k9Z*|kM#`xX#SS_#67VqAwZP3OeU}qii`!neSKOC?Zo%5Jg8OiJzj}wh0Qnj zz476oP@VVTrHlQ_61f8LJqB#Bu^Q~7eW;XDF#5Ud0t@|8pNkYK!EES*z7O=DXKO{= zTJ`t8NyJ!@+zNZu=5Wtj8S8!9%>34q&H#mOWrIe~!EA{RWO!RZc;>%WOe|Wj^GH7; z&|pTDeSrk6T83BYB$*if3P|f8DaPLS2@KC@`RyO2(yhD|j%sYHRoRB(jt@-JKr+5c zfmsjLH;yc~*>qljFXu0?1BPm1Q$P(g{4*q)&-lKHs3`i zE+yVx-W^DlD2eP%1Ab?M`UVcdSs%*ZG$Uez`^(BsD&mjd4x)^wZ!(a4SJW28uzJti z-!u`GrltM{SzdN`UEF&JJkJiNbD}osQ(%DKFwa#i8uTr)h#ld_ZLz8)gxR< zrjkuiut>Wrq=RYH10td-BWzYlib1R5-hWZEm*pr1cLb1zq};y#YkI~``VA0E_-%1( z+jLqN?7gVn@}5b(cVPhSRx>Gd2E{cm3@$yx-xAhNA=yeD6zM+I)C6l!PHL^TjkT8% z-ANDYzq6+UG0&BU$w)zsL4&7p;MM`oyy*3K(i|)2_ES|3GRf0qsrxyqN{a)u+@%FI zH7YU^Fjh9!X8IfQlznl){XVdxp4w>MFmEr2IU$N2M`mgqG$04n1aQb@ zit~4@gmAs(mpcWY>i3_j$M;r}o0sFgriL*N7Ts#5{>8qemI0(w_3f6^k<-OXf&5`e zl7lqh7#b{{Z|5 zldivPp3AGUs54k-I!cLk6mrW?`pK6DhDv=Y^aF$J*GY2+j~q3?=D*riTv2)Ag^F*tbtww4gyi=BH=uy!g>rfNF-s6R3-JI>w^? z;V-DGZnn6G;FUm%vtGXZlN=-k9GC)@VE}My4yV^)DZ_X?T%Jz?C`ykXP*^@pP_V** zsNqfxrbK{Z5@lNdc&16AV)HT%m*Ne=y_G}LRaV*6RFVKR-7kG9=h|`cv~->lpk*BN zdLfyR<7|I}5hG<-vAW}@BX5-ggqv7O0Ew2@?K`j=rKy4*spOKE0AD?mRb4p*mY%7r z|2FK`g&!%!X(mFgI8|1ke4=XEg zDdWUNgFxYW-mkn9Og#+&5?9|CATkY_Y}#@&Q{wjvD=f7Dt%UZ<3P!rzabM~!Xo_Py zjV=Qv1(qXHw-P|enc~9PlVJOM$znHTOtrTsx~$fT;JXxL@o;6_Gb_xUF-3;!7qyqG zc(ejpxS*|X5HLuUe-$r50tIF^Ot+yp8!6No_r&oqKjBIV6i?liS_tj}Ymdz27FMsVi>6kRQe0yVJ;+ly^YO3Xv zTQ=RReeTWbebGZUCsCDNBIjN<;=OieTm9Or1@jn2;CH~ z{D~FcVdkzG1Pz6flHzJVxwsWaDfNCh_vNq{0CHpJFNUP<#X3ut;!K0@GQ+m` zim1PDC<7{a718sKH&fKFahP}u>l-)#AWlgIXf;RxSlj|`Hx70F+$szDIqj|Kj-Zco zW_qWK04_K=KmqW$rIgbk59^X@F%&Rl5%^raKl#ADFS!^DfTTzy!rZ^fCaI;80B05j z_eG<}u3SNFTKz1CN@Pg7=h!oPoo|j1xeH19jBOR=RwN#NT@!M(m2v$u##9?Yhs%Mc zTX>Lgm!xq2?GvQ1N+AR$Z@UV>a5qz^sMDgn%fD#YCCdPs-M!}y zC8>O4PrL!od@|OLU)@AE*+J|f)puP7sr{wm)YaBWNK!&DVE_HcJQdZ4{Z&lIXA|Ep zMwMmX$K!*aJ<+E9Xr%q`P878|&W%jvMHn2QsKU2%pN!t})^Hd-P!rCF!O7}^6Z@HR zM%R1hGa0YZ{%MyelB|H3KGb{(6H93_#*-nERb7z*3c24ZCZUi;>QaQ;;>~h*xn;Jf zVIopyOZR&u^P{2We(3p}Jt_PYLqY~mq@0pg<0g{heA@x+k53&=2Ndtnu3#Re)_b3JU@|aQVqCm^~G{=)JGx; z%c1Fppt5}1WzZPKdE0e!+$@rXrUj{eVin#--<{0Vku!JVBgq3;|HstX$9U=&XJx(3 z#f*P7W;0rvwqiY+RB_MBPmR4=BZGb+9_DiPo_{egVHr|t^PL5N_IV0>=bXt#!V*`R$-W152WvSt4*+0feqFQ6R*sreCM{iN?c&jdjSzSHba^5wQk89#HUIo#g znv9eR&jli=Xo*g9r5k(2eq&7GzvfZDi=q_nFLRg9Jj6&~G2hQ*Jj5n0RMSOB(Rpm> zbJnPP=9^`viOO&0oU~tFNK1ZXt>CJ4e}vv)FPeya>hg98gx2nhd!iP-V7b8t7`~CG zwmqGD#1qv!PLRm<2p4>;zhy>(FWnRA>Sxrx-sta%E=Jq|^>GF~O~%L++30s`Y8c%V z%sjP|LRBSrsXIwiT)r=?&)AG4t3US>@e_Ivd9Uv0m)XIry)$Q?kMzC7s4&00yT8hq zwCQ^J|0ykW!Y7qYRajfiluBG^CmExMbB;VG8m(2Roqb=vy6whoTK^Gn>IN2M`(hfK3p>u2Ak`{kL> z_RWpw~wz^q*7}R)l5-r&7Qc{_q1j+f?{s7lldXM zPYH&jOcGETyT+1KV*^vzA-vH$9##nRK^WEwP!0nV?s!H^VI+(z0HkC%vCl#; z9Rb>39+$_;?d*i`G>RWO#Fs0I|BRKbo{UHB{I(tETGWh|0R9&`!g!~?5RmZ9)m3EQ z_m8`;R5g`~J>}S40X$}PV@DA%Dpkl{wnlrS5CZ7Uz^*m}O zqPw-9gj6{<{($mmCS9T`>#&asAM^TZgcfK>snS5gs0(jZ!Ba#|$4UTKk${4Se86_p z)88IFWGUb58=~cEV%$?Wts^noZh*5SE?eJQn<~OGydaJyi6D1N+(cj(2-vmrvr*u& z6rw1HT7}2u3GoK-;Rk!BR8H(o^rRGga|;qm@o(7va=8((yCWU#N!jHFutd{|M;ncX zqZMyyzvTIW@8p3r^W=cwPs5Z|w67h2X@xcB^)m|s6r&sO26Y50`KZHZ)2SWnr7-oN ziMq(>HvXfQJLIlR#Vr}UEm^Vc-{N`1#dz_)P5*>_R<+wz#Spp^@aZNVS)F=U43e*= z8A;V7b>H1l>lqs{1tx|t)FHP)AuNed;pR56t#&+$ibr%B^&EH7m{(eftV1O$lPIg$w`lsV#^32UAn$RkX(L+l>ihx8-)!{Iq zgDLY!r0C|827%sY4O%_T!#?|Ye1rzr1xqZ{a(%bui4iDNwH6E%4HUog{i`Q6 z!@soKKv5@Kq9V}r0RvB=M`al<tw)72CI3EWUf=N9?2luLGm>L==(wd34m#UG%a@tKgpp}Mpv zq$>v_pSR$Y1GTp?rP?>`BFwUpVivEmtr}1ULH4%PGy|ZQKchpZ<{iLendox7(GM?o zT{0CvZq)sv+k#m18Orc6BnRUsV8!>6)eZlM8x>87VA~u=>kYn~3JeY$6D7_-jO}tp zhDI%?ii~Hb9}Zahjei`UN;&%ejz07I;9E*G3;AG%XEcv6F^@?%w*lc$)#mxQ7C{`H z#)zhbc65EE;D)fo@9COG0gB}fuJMf$_Yn=s_`G5@Dvnf<`jOx)5Coq#B0F4MA~XKG zFny`?)^FHBi1)Pv%rpR*5MVp&|G2o~>q8d;5-PT%8L}St(zu!!`#f>{(gWmt zVvX~Z0OEA%KfJ-)oNeO^H&<%jdjSLGg-uMqca*J)$|0bKFF_Kw_!1=I+;+jl!Bh^X z%5I98qBsh^!HLA1lC}G^Rr6;l=c=yLRp4owB*TuF*VHAV%SCBuo z2}^+8wY%f_B|pB|yiqz%1Jh*r(;!{O?tTKFnOCwW(Yr6AEjg?1jw8xT#$sYIahfG* z&OGU3R?^Di4X5eQnQYh^sM?4VdbWxC?W=#T{?;`H_77nwr@ln-${YY5b z0Tc$p56FjU1;BLTV<`WD-e^Mdc@9a;9+IeN5ey@v5lona`=1A^cstBnwz; zH4dd*oA3rB{;!X-pS+jIUrK~1lq$ZnRME%Z^iQFi6Da5G)-bL#apf6rsly;D@Z{~O zYc(wL7n|Y3dd*^$@&Utj{kyF<-iTY06(w?w#hh0_!C0rDXW^l@>ayCuxrFiBjeomD z*5M=&CLbR^8iwCV9K5^bNcHSy2~W8D*GJ(Lp(t*M8FG9i7H{(c&eXU2#{MYBb)XDS zp{!}AK5v@zi)Tj|ds|8kt6yAl4Jo;==XA$Dhp?=If$w74K@#mlaz3Yea>9mTfV-Y0 zrm(<>_5ORT^_-p#sLCeaFI&-9M`cZ(aOZZ$P=r)pGV0%i^puOEbHcGz&RTq6S$kwT$=R9psw4&wqcA9G)?ti1lj{$|_i&7M^ti`dn{YHqi7q58SZ{4^s;xzW01 z&voQCX7L0wd_E}8vm9i#`#AYINmXTmx*sfyK&>kRFqpxQmKBcP&;7IE z#PFt=X46z>Iqb80J*Us61;@m*rpLdB6GqD(r&bpG7vy-66F&Di0<3RdVy?1zmi~`* zHcMIl2Il*sSG-sTEvHT^C+C}G0mq1~sv1IxBP5oJ_b=R?VRW|A)Q!_2{v)~JF59_{ zj~jW?u*<}a$KrJn%6|f%bVu3SS7zT0P*11r|5z3uWAijWG{PPlo{F^mHvj872l)Jp zavOmm`eb^CirZL-{mP9fp*!UPvnthMIBGX@`YM_m0$taFya?iHH-?nvjg^X2D@|Sf^+sy=iG0TazCHQMeOAXp`wl8r@Er_*DIW0$INna zsYbG|@*-OBB?G#gJ`0XLsG_#rpD)V1<0syk`z|OZ!EEVY4v34BHo*-Da53=v(359+%VWZ%wTguqbnwNqC|;mKr^|60|p6-6uV zbacyI2pOcIYYA;MfD#A_RF-vjAvjf%G?5}fjDZnm&~egH=2Uc+F*YTb3$CS*mS}Ei zTbNO8sT-OMk(m*im#IbHe03@aCUo-Md7DCF#+f@NVHp(%sF_)6IUSp~eQ9W-cFO4| zqiacOX>n8HX(ynX;_2m}buvmRbZ-u7Dq#}rmqe-&B+)8?te%&#Yy>!eZJPTTMPImt zEmY8clLe6_5@I5-5RZ&K2WOEDx+`W(2|P!x6cQNp9KY(i8bzvvX$1uT0!XkIffNiA zm@lkSpq4LBJ!)w%fdul@N;_Qe%dbPQl_Y@@14!B6!{o-qqgnv0YEX4`A`!vBNDLrv zgE)IE!N)sGu$93N^~~^p#UJ;`Gh_C``%ula9{giH&Y|ouSM;fMa1sm?%$J+|8eCFm zojJtQR7pG}1-V{(tyawlRqT{e2=Jxzy9)e!HcwG#6~xX805C;6DD}Kw%q^2$IA$^J z%rHZPmo)BeOgGr_&SXn2dBYz7G)KoAJeOdq@F+PqVm@~V#U?F6!yxc?M&haEuiQK+tc zz@(=9yv>i}UUf-EJB-lag?rpt00^!c@!r(`OK|8&2;HypX=j=CR8jCLx1B2Uo$?Mg zfC#k17E%~qq`v2WwSsUdCLTkLP8OoGvK2&rNf?_^B+$PgY{4FJ37!QV*S4N*aAkV4 z&ESGJFrDoVf~F$i=n$f>e$j_}nBn04sP~Ze?SyF$JJ+~QmA~XoM|?=MQ{yhS#PBhY zh$myB{!XSQwAsc{OsrKClV_ADjxSFRLt>!3hdZ3jDtwJN3S&v2D8?0Haf*alV-mOM zxObtkh_-2?{%)hTd)-csV|3yQy=O5ee#v!=5#qasWyLN!@>`Jqm!aJi8-Get`jL`b ziJPxf`8Hgw27$9OS+7nx6;__=H2@TrEqA4b1BORhv;<}?wIm6gxb0fOj8-yt=}Bfz z%bCK$)SupjDJ>bwjNDSCNHnyoWYS7X1Vj^txCu&V7R#H@!c8f=H6Ksn@@>^zr?OV1 zDuU4yR`SfJKKIGbe){vD0DldrKnF_Df*SOo2!)nb(iAG&_^^=Efzwo^l(IyLDFDFB z2yBj%M{Ht7e+ZbG@TTTYBCl7@RA@9i>%3ijJa4$ahB69!DgWyZ495v27l2oY6+(UDi=z% zddPznAQi|$gbJ@z0d6u&rjof8%76iJGfq++4?3U^Z zQ=MqVWDWJJMiCpR(aP1jr_3T?>!jV37PoS9eXDq7``P7^Hjk#&)p`+^-n-UUR#}{D zdxtAT%zpH~kuoh~q5p^9J*qdVkUDS|`|H&rCbEuegx-KVI)C9q4zgSg&Q*S02iivh ztC3GksE9{Q;u4$q#3)X&idW3y7Q6VxfK^mUA&FQ};YtEDIi{>?tSbQX*de9j6@YD3 zI`~*glQm9quR;Chl?<%GSj*10eAGLS1!$%As?fS(m83@Swcs#!;)eU}_n|C?xP# z<(`VXv$`2=WtAuY0gD=3!+oo;Z6yIxkfAzTt}U-Y+ka{qchS=pvvFTdE#o?^(AJ@j ziLt$BYiIjUc{j5)Z7&>c_427^s5a`im#ylz>NnV#E%&LqExV~i&Wc7{x4xH4?=;m+ z%ojw{&B$$E2&+~r2*=s9skUp!M!UWs1~r%!g|Lv2VG}x8Dzyc^Z+%B@$RYXVRMKrK zPuDDRCx2JAtG`X~<+^<1RmB*+$+&8G(|1z@m~ikihR z&;v|!TP=OTDW&<-W&6kdHJs-)A9lMtuGdy$9oSpP6^MP#>9U*s>}XHB+SktZw!8i9 zC^n}g(FylH*Z=+QU7nIeL5XN=^&RhkmlfSF8Gk5E-}CVHykFV>9{9#P{_&8HyyPcO z`N~)FIj|*s*nbu2K?R|%K{rbztNu`IoxU6nX%2@z*Yl=u+b3C(f3%lHTSO#cqg(ZyGCxeBwhdrlzItXk# zXKJhmjsH<6fL$kQf+$}axGR7-dw(xTfd^J}(Nba65{3lVN8o64Lj+)HSaVV)hzli0 zQ$vq>C|u-7HGXJfjnjILh+p~ybofYWZ1jQLh*{1khPG#a2Z=~8mwXJCj)AC+*64~B zd65{Iks7&?$KqwmG6f(JHn!3xbmAU*F(@PnEPr)pDfuidDM^EtCY$(@Fn>mqe?}-C z2~2*5N$~_^J9$nv`7JG}XJ`^5d)1Scn3Kc;ijR0refBB35|f~$l*$)m#bT5LwS=%z zPaN5mUip<^8I~7W5SsuZaD`ouC2Go*8n;FU_s|9;gKAfGFAJeA2ys-(Mjs;K4*|kJ zItNz&KbK%~0V0G6IPH*TX@6NrUs6}C#8rjGSyls1km*Hs)L#%bU+7hRu$Dw_a)(}} zD3?i8iWFaR2xg<{nc#9qdsLZ@vyTLZUIG?g*hr0nDP(Gdn$~z>l{tu+iAQg=U*kA* zP?eAD6`Mr&nVI=t?8Td1CrEdsbh8Ct@AY5426g5boWGf2VrNK{Nq>)3q?)H`NCr1lZ?`x{>3SC(Fou{%D?`0+E{e zD^%(xKWQsl>L^6Yd3}PhEkis@&fs=BI_myoRbs;PJx zQcw@IP!I7DTvh`c>hc^hr*V4cF6=QBX22w}6gOZsIB}CN4-pvd0($mxFy}*6w!tp{ zYBFvJ21c8}qkpus*kSf4gZoo; zg1KIJlymQijjdUOUsRn`=bvT7aG|M~eo3%cm3!l+LIdfl(p9h-S~>R0nH9;fyE#bh z7-0uXkX+GO-U%uK`$X&6YveYsk5r8(%SS^LuMOC*bbqu&)|9aPs;|l9vKDKy!g+J^ zXm#~CVLQ9BJsX{4cdx7Yu_Jp$n<<l@Mzhm6C0W~ou}7Q<8(v;ZMj?x}3fGb7IcoWTO3u-dqe`?!!Bxi7X5**Xf&#WD?nHBzvPP9YUqBNoTR z5K^E80}v9w@S>tfFHsON_v08^upj561)Ni+RvIo@W2s2F5k4WMNg%tssCczYy0^BTF;GFuS;-Pth>A)FxC(lsgtrf~KcDg@2W0Drom2z1e%Im|LgctG(g7lUVYS zPKlFJS!LRLy_7Md>q{%^D;45aV4p=5R%XaNTD9hBN7H7#aqx<>Sqi7;0psV z5oMqN?9(zuaSH?B3+-?vEweFP&<=?q1%KIUOq3-SzJSJJL_t1b9bK$17-|sN+6t$0 zSnaYRV+YSWNXgM%(LZO%X!;Ij!4XoLd`-)&i|UP+?>yXB+nD=j_i!j5B;DH zP0ufz&Q~kTH^|ZHESmo8U=}vb&VNiSjGM(Q-O?`o(lDJZTh>dUApxv#k|?YrAh87) z#}p;gWTY?#Q{fAV{4yL-3cE-KJz5IuQyVHX3OErSq)-chqai(7746{4OeF}ZOwra#p{z;G8OzY#RdrDWxF5Yprzy`fd(Qug%(I!0bQI3;ipvk` zbTMA9yFAdmtkI_8pMMPJkOzIYcTS;ON6;C|&ru7Ux4Gt6{?3O!n(#EwIQxKBF6evC zwk3_y0u5nqe(4ge<-D>Eq0F zeV$*|oaY!VMixdzR21d3e(Jm4D8qc_yx!1aUhAw0&7EmrpMMUw|6WLu(MaIbUhUR? z?ULIQuD}I!Z4rt$6iY(8qP-Iy`5rLg^ zKoAA2;13u_7<3^SSn9)4v7@cQ{tHSV$E4s6t&jjzd7}m~6y`oCO`&6u z-4vP4+T$B2<$ql&8H|*$&D$g&!rx85@ax!diYP*CCWIzUHazYUEW%5g+srNV-z)An ze@3gV@z5LeKVS3{yeY}u!>k>_;VZx!Z1m_0@;OZOO~2e858p)3Wfknciw(qAKg0tc z!!HloGh8I|J(5-#+mt=P^*!`rFZDw2CA|IHXrI8=-GAP7s=xm!Pw+)dXHgIHEzjBV zE85Tt^Isp6(CyxqZNNV-O%S~Iz}@#B+Qf*5?UsM}n4kF}x-@Hr5|p)8dXX9lAQGIw zJ`s*uW))!x0ODJr1@)jG!$ZgR5XS*99ut6744%gWzy^O20R7?$N&o|-+sBI9vmYXG1l zuD}WFvovNNQJ=0hw~p)-E6mN#McRzGW{XISHLw7YAV~rs0tp%fG9bwSh7BDuB&hJ< zz=RYL3T(KLA;pFTIa=h{@FK&C6HOK*`B0_9gnueq4k+m{|DweKCoe9HiE!o4jR|Uk zba+ytOrjlc9;{iEXGWnwV;ZH2)MiqjFd?FhxD#a7rVxQ{gsC&B!muB!Ivn_Orp>K6 zLq063aVcD#X~hQhX!W7clvHs-MQYV-+_iaM60O^`FhsQ|^|B58a;M~>b;DBi%Xg{I zu1-q57EYWqGh(N6_qH8bw)0%KSsxQ7tylG7v7Ae%{L9n!YtDqhmSlPKDeBx4XA3vj zm~2a_6-}?)+7mYNv44p-PffWm+~?ipe$Uulm#eb}8U&xbFa4L%w*wV_`hZAc6MO$i zNdN$WSh0nS0Rl*(6?eD*U;qm(6bT3^NU^0Idl;&R6$BE{#2*St9AJ}MN?8DaEn2a~ zm0Jd|r5=0iu|fY_GC*Cu_+xD%Ev|{qkAK zs_ibJPT}oWQc52RvC>rEUrAzFwq_?T z8DxSZMpxorLu8kK znJ3aMO=1gDS|PF(1)8U-po9o#$Z(lz@+_SFb- z_JXMu8zA{(=q7nStzj_Lsq4SVP|N> zKs}YwZl+g50ko+S>y*uS%GL6)Risn(D#Df;6oO9lK2Up_)2bF(QkTIM6BB>ClIPV( zc3YUj6r`YqO>mQN#cN03#DV}R)S?~};MtLUu{|x+qZU%A9RP}Fk1`O-PjU)EnP`_e zHmHtuu7jQII+D3oSP2LSaF}Uh8w6#BB?$kAbWQR2)+NjgKnp24-T>sL3G_A(MdR&m zLdbVJJcZ~WYhOAO-bU!MS)0 zD{UML`$|~9B0@NafP`=%CwZtv_Qu4-A>&=nxG*#FuZdj*%>d^TM- zef@EX57F_DW&FmY(2$BTo?;B2sARY)88;^RO^8F~z&q|($XvNFkbQp`6^MdR!&#Pb zh2uiT2+NR%{bR6{bwiWmz$}i>G4pc9Xl6nJ(aKKN7=|Uh;TqRC!N8j^4kc{JLS|Sz z3l{VU3#~>JR;PkabTftB6G=gi*3y^8|8%A|t07Hy+S8u~b*M!hphqR)cCaKE=$JHs z6WOB`TllW#rXWiUZJ~cc-U|{dl0e-REkwLaS^z^v;fq#uQ$$j@8$(buYao?E3QmHO zPXK@b0JK`4eQJe2N`WVVkikd>mFWQfq=0zVs+soKHbb^hi+9SBN87eg5=imw1dL)r zwtUi%U|EwbNa5Q~0u(Db=1yC1g0722XkuxVFNUkkr#!8yxA=crmB11fsS$6hVF7E> zhx3c#ca2q2Eq>Lpg4|>eSL<^iUbBohJftyOxoLLl^2Yjer~GQDT`P{$iMn*=Y+Wq4 zm5SGqnrE*X4{Fg5KC*+p9Oy)Mt3>}#PBWadoaHgE@LHfs@OaNT+Bq7Z zs7)=hj;GqFB|rK4xPtPRAHa%sf2wI>(#aHbpa?I=`?3dW4Z*p zS^(1Nd1nt;7ZAjG8?WjUpgx#L;c?mz5*ZdEHzxr}_U*_#tSe;t6jc|%mO4`X(NDVO zF~@7KG@k)5oZ-1J0!s@9!WQsS|B60?BL+0UsdzLw`yvSnGPdx(0OBJXS|K|^z!_37 z3w$xNpd)_}8!@!#ATPVK8}hIj+p6r~Qw|ZLjz7DgD|43UI6@qB!VYYt3uHhjB(WOgG8h}d43sl5@*)si!AGjFKN7Jv zJ3$|Vj41SuAbh|&+A}{oK@N+I3ChCkDZxQ|G97=+!Zw6J4O}A*EJ5_3G)fb_K_tXN zG{i$h#6(oYMP$VEkOEs61y@;!V(XLE0uliG4l>{oaZ3U}X$8A55+2E$B#<2iIF>VE zh*x6>IH`px5QtV_5*?w${A>S(TBya1s*n>&Md&%ekt!5&8VIUlwbtShQnz9Ms#19YH-;aX6m|tU*#Jlv|XFW4nRUM!JeBxD!0AE62Wb ztabD_!`nHgBdM?nSGSH zm5V8zD=T00I)yvCm2gL_L&%B}ESOWscf^0YkpruE+&c5XN4-<5iyJGDL`QnWDv?~P zm@3JKbUJM;tCo~VaO60fq)D>tMut2`oQ;X(>bNnI<-^#_(SlZPvVS1L6oz@p;6CQ!RU%D`il*3{?MwO$5u2K!hX~oX|ssO+Qso3<9Fnbg}k?Pb$UK6LlmiG}Ihx zK}M=HRn5#!ozDvW|HDA-R7=&=2CYykx>Otm&sX(S<;+zI1)yMsRm*>D)mf!N0Zoro zNUZ57sIMNDM@*8yqPQx(H7UDk0u z*E@yRc$L?ArPq43*L%fRI)%y0>c@Qb*M9|A`S{59aLR&JQvnOugjLvuW!Q#w*oTGK z(M!ZPI!?}Nv53W3^2~qOI(<}vb=Qpr*^m|4ktNxZ#ZVLFnEyc8P+HlqX<7UWn-Voi zv02%dl~M7qS&D%aoZVSb%2}!`l$q6}pQTv>A=(3Zl%qWre$9`@Yg4CvSCf_6sioSg zwc4w-ti^oS9%}zqMTL)!bq_+-)wh|{Fpb)7ZPhYW*HS%O^NfE_`q)-sgp0MpRneMcxGg-rF4&^^n;GfxGJ! z5bsUW7TqN91(2L&mrvW?DHYn#s$KzUUr=gaD1Bc9ncs9-kkw_=$BTgIVJ`ImN^;MbP#rUI!uHtI=I9b&nNmD^LR4 zylsdbj$F&WC%pfKrjA+!$#h(~84a7MSTT?F*qhj2Z0y5YHv#czy zohT$OEk$;o^qI_sVd1dZ&H!=ayy@ZR@t!LgV>;)hn~ z?+AZqH5q4kIh((8=lOUEdHa(|I_GcZArg16g+T!jsGWbT{H5l#HX8+I>$iq$O^bpoA%%7r9|0qr z9hQsXf~D?4k#;KuTTmsz7VLByfWRgaB3YcomgSdVU;&hkzg9oTvP-#8pCPi#VsxgLj*?byl~a)KeycJC+vP z+Hh6h(r(T16N<~W)a&f%QY9_usfRZ39I(t1t8VVuLEuBFHt=Z!Bhl@Z;50i&A_=@K2O2V#>unk}e( zrBSQ?wv0%0*&%~LKb>qYWJkB8B;bNWuZ4DCg+T}Oqiz@gh?5%U=#H?rp>|)3tWqtg zlGD+GT%YzWje>1YCzKEHgfRb{ zTk;BErk&ZU9qx*7|7w5BNc$j_hyvT{8}*_#fxs^8^5uC-ipyzwBVzCf2Y~AlRm{=u zzLEQhf=jr3?FZ%Z%NgPjx#bpHuEalLbmZ`wK%W3m%jS8DyO*l~WYuIEzuJNP5W$Gi zZXw}H5hpJb$?D3>;R52~uIrH-OSh7{ryUN-9NStF-$8)CwgrE?m@wL4dNuAvnu;W= ze+ah^uBN|x<(7N=QT%wTkjp8{TdHl2kYv>K44#Pq44L`?&}3cO=1;%FtY{v-oVJF* zd%urxz;{c|7m5Go3`@zbBmjPXO8P1#h}uB_z|jN|A)hNRA8K=Z#|MZ607%&*|C0m& z5C8xMFbpyvA;W)&3gfAH)P{yki8a>tin0uUIN`zHbqm;T5H3V&dL zOG&UvEv?wn3K<}sBoIhB0YK3JWVEx|Wgyi;iY;6uHjxyzxH3Ti3UtMj7WCADlXDba zxZ(gtkLvDUoe40FDu+mE}?oZ6PH? z8(hdx6s+9H&WAYmsoemBI!ekG+6^froD7jcDFAy$>FJb?4!{bPnT~lQsyKSOkQP5q z+10611i%GDlqUZ*WUV+#0cEYb?%2u{O3FEDD>`mA(WNT^K~O81PQ}SDrGF%*B%-7A z11hc&Z0Mj8BH2QUE3HT{WU_}25Cw^CNQohAMUpv}6q^uBN*hKcu&Jz@sv2h!lA=1r zj8aglL@AMW%B4|5%1a7E1zVJ5R7wq?$+t#H6u~I=kODE9hi(d_z6^0OZd(01x&_Ih z>L~D(%HH%Y!pP!9CXS@wM1Lc>nFh+#kVvpn&ls1&2IZ-JcB>F48rSHk#U_nmE3He? zSP{b=V>C)VtlYfPg{Nw2sYzN$glZW|Ou>plCJoyqPaIoxA*SXIpar9)PHmB^6=_j5 zLyrQ*C619&|H}{%PixUHs-TjjudQjtd-8R2?wJJOo19srkEo7O)_(xy?Y z5dKc;=iT3(^f;=1@`Y~zjHgTz;5WMBijsF@b|WjUl)+HOo7ln$?WuxVw1cIb4iXZr zv{I?Qn^;Wa77sS99_3vYExq2GD5|Q2xedCfziA_HODSOX2~z2)Qp!~OCsUjK6cWw3 ziA54PxW(WRNx}H@+706+pm?c(BT0gnB7mTmQLYXge_#fIAce=+qh7<=3r>RIlb+zD zC#Y(W1eEY8dz38!0T75$T%iQ9_=G12kOEQ0l@k``q=gAU|H2UWWD9%PVgO@6!d5mj zrgl(G0zDxj7JpJJTWmrB4iTPbV8oO*lmq}%Tw+8>k(kMlafocHhbxkxfDM65A_cpU z3bTllf01NlBa(286?;e_5ZNOl^)SduNO22Nkii`TVB#T1;*cyNF;HNP2x$U<03=Mt z7WJq_J(eg+h5?`^8WJTFNePr9va%;30a1n=$&h+f%8X+u!4xk=fxS0f4!4Odz6Hc~Bf1yaq;F*vJlaQ$NQy!1l$~4-d9plo= zLLh|@GL=DvVgbS?S3yfCMrA8ztOYB|U<)z?U<>vblr8^sq{ksr4VB+=R+O63ONqDw zf4Kodn+(cCcaTPZTwDz`BZCls8nhi6Es^v3cFGd+Q3|X;BSVlRmK#=bg)4js0+PUr ze}6h;k})MkB-=vJ77ArG5+S2PhS@}hxIsr_d58qO$;5U{002qxXF@Ha7>?xak!@6n zY6|JgVo}J5UezL9j0n$ha4|@)@yS6L^ve>Cg;`4FB5elZ3wklrqY+~WE-P|Fo(hi~ zLcF3-YLQnyS!|PAs7FPZD6SIG5|1(Me@q$J*GipiXP_w&3N7gISA--0D~046OKe#Y zQHl!y5+ET%%8H$b%*}kImEkqQOQTW5rK(X@RvJ@vw zKBh}JW^w{HaSv%_86#1!f0jd?43arL4}B6)9>MG?aJdk!nCQb&a#hU~f<#0$h_%?y zYBZ&pvQ1zHXNa;N6k%}-QC=1k*s{kpr%Vga|9J$r*8tjWr%2gt%xV_YhOWn2P1wVj z3!g|71k2S(Sq>@`YI##JM7dwEEvfb>g_8o%kpTpka`uQLiG<#*e=%{QLd{$D9v|5( zuI)(P*rAZm8q2QNl(Iw{k|E+SQkb4seKoF73{Fk-yJ7mnVdtXDrLd+ooocz9HcP;q z2JoUKjV5eVCmyadwy}UQK=?x$m8W6Zh0h*Bv&XOWGfMpkuiu zK{z7L8#WgZpoKH@&F#jg-0k`|6`w(f3jzP#f-yCw03cA#d}C-lUy>RkFNP!{U$3(W z9Tf1`FzdG1!zhG&&Y3~pxv;|YMS$?AltzpKqvTwvxx$_(e~HHhOAshP0`!kvsD?*b zgiJ9J{(Mg3KnDa$1kbUJ{WwUl@Cc=h1QOtsFX-HPAkYB~kPH%#ku{Lz-QZl@n|JJ& z4*nnz2B8oRArTg#5gs8DCLw){L81K*r6ki{V9_dF5i;3=B#e;rFq=QA)^G@fKj{!{ zyaY(J1p@pLe>+%_LiEro)gLSIgc-F0&`_bdu!01*hF!Q)5zRyku@;En&lD`tGg(9% zG~6ge2}n#vjr>I6z!5{_oW@PwVX4*}{|$gS2?k^(#-RE^v zg_2Rg1UyLs9Bu(v0Kt3hibC`f$yARywM9mdS39UkYaz)<06;|P2!n`=3Q<~gmCmc| zjOqc!&yWES(8g9!7A%6*h#l516oeF55pHx9WC_7-?N{27K|T(omuy&cfs61E4>{Pv zBxpn)e+EjR)Ei%9h=M_ip6x^u+|qWWgftCe83IyEG~YoaMJvb`X>1fSa@H)C5=2qa zuB_NIalzH;$U;hiQ#es`_{0r`lM(;54L5Pbx3E@3I0boKo%VEAX8}cLl)?^yTvr8& zKNXcuB?QHk7D7b8N#RfFd=_i<9*UR|L6kufe0#eUV8Sc)MF)k-S81TQVoS613^unJ|s;7;^gMoe9h#FR9V!5yZCrZ9x? zsFD-WkYe=K4kd-~umb%#WnV5ug?v?PU=}-AS+3<8CPE5RBmjABj7f|M4Go8D$R-ya ze+!5>lVF_Xu8`OvnMB()2ynC(Dd9zKCCnDU7SzQiL8MJ?aVp)`e2MUJpK%FLm zNFk0`XS_@5gojgrg(4M9b}o^j{}>i!#3meG#zSz#VT{OYfC$Ra5V$CT1T5Qi+yxRi z*N>n@XRPHv3dUo!+CtnC8NGyJz*EbOf5=^e2UiGKrU}Pkxq_^eW;ZQTVgyzsA(bL+ zg8wZO6|kjiEW}2f%Yf}4&Oyd?L`5r}#ruJJMPV%2bWkeKg0TOHpvkkx$T+YqGjYw36n84vdCE5gV6XKZ&e~QE# z+)PH8t=*rXnm|MWJJ>^pQADSN63PHcsx_IBP}CGK8mL9YQ?z@Fa#I3yJ<9#a{8!YTTM5bqmO86Bi@^M&#BZ zZNe6a=oBJle`sz;!~G8lIj7tKK|gjOK+2N=hAJ-~L=;?XjYVl6PDCYDmW@`#BpqyG z;thA%5MF@UOM;R{We9#Mg#M)42|Z^_JcmTQUoB|C55dqve3}Z?q;Wn*@l79J2C)>5 z7Z>d@Ah(tWjtIrj;a#~=LS%$J6qOAL1XiLC7pSr-fB)i=fSo|RAyc(Q%T-7?!BapA zbf~UIvr-{WAQ3DoZ$*d|0+3f-x`p$$CIDz-KU3jS5Cj>O#sIj2J)lS$73Qke4Y{2a zJT2zM!cAbYgNpy`#@z79EKlDFfZd8 zOhC%We@=w(aMy-$hz+HL8C{}S?1pyUg$U)`hP5L^dsk$<7o|L_DmU^=Bxg~b5Ip57 zUyPM46k@~i2~Blnbd~2%Z}mU@7e?=(MfAmj@UkL-S9W=*a7+fX-86zy=6i;hT}Z?v zwO0LP^%N8}*a1e93iQvcuLu8UQn*G&j0R|Me^Pqt7hgz$_Fi!|yS7b0@NxX}Eg}^w z+=62-L}Cy0E$|xxh(axlf<6*X7*@45Q2}8?q?*2SLjZshtOjWuGgrKqh!yCD28tB4 z0gZ;wWr2=p|Nmz6@u-Y(FnXuAdapNox3_!0H+;wUdVG$vkXc6mh!%h&PY335s?CpWNfac`MmL$+9@e$%-$)#s=1q0up6RspM*oR#}PI zm;B9|J8WP?bcAY^>T9&gD#6O!@{0Un85!6@8GuHn82H{^h-$Q&&ai4Nd~`-9mp`PE z;^25!fr}(k>ZQ;P+p0}7NR}*Y{m6(E6wK*qoZ5g$ zw(=4f{X@>(IK&_n;YdMI3bn%Q8+<;*zx6}F;Yx`#YqWLUvJ{_^PX%ZtK>wbJwM<2< z;e@NDhGYMDO4hc5yZ|jpIEbCS4dPYCKLl>NK946e+oKInXWi?b;1D6To9VW>e|@)L zLAW`oiUepm(T3P|t21va#O#&2!Yu&qJtAK#aYsB(8)HgwpsAw{-19*k3)M*OT=!Rg-S(!RG1ylcPG10!W~PFp_sg>3ZyEE8JtIn@F>)pM$&$ zLqvi-|IpzgUS#d!4s0O6sn*XFJjMZOjKwfWJ7l^)d=mCX02@<5o#$#Tf4D-_WP~Q{ z?wyq+_r~CAVs`$aV zVEsZpB)x*qHv`Y$49d6Q4?f`+zTqD};wQf1N6=byg+qL!AXfxA$Ql6H#m%tMA9u>t z?NA=I_(!tCLkI^`*RV@Le?$%sA8Y0FxghH`RtOP|up+kBCY+Lg&4p0pHAdtx+#t$; z8HMm{f=q5f?`(p1*P#+Th9sOszi=@RT{P-9L^#(2$$*7}a6(+=hO*h&cw#cGjlN4Y z9E|^rkS&_>Pt0PXXx3D0QjNfBt-JB}7g0?h6CvY;II$33lwmU-e?*{=f@5S9KmZs3 zAVDizn*szF00e*(SN0|ZP(W$`Cj|2)>t9oO61`Prr%B@lt z6G$03KxxXLNfr)3e_NETSs+^4wOZSjEztqB21p>71i{;OwfIJ1)$U==R#E~81mf`) zr-TE7q(v&Fmc2GoBqY(oZh)qSwp!_h$t6G%a<;@W+ldt?#}*kp^91iL*E{A`+JOFcIrX+~gWJF^7LNhB2AfATve!z8UBCLYde;6jZY z5eTkF=DPLjX!q*QIGb+$nj;Z#)}$6PTlNE#7$6F^0H_T*wZ}U;4He@JJThLg zXN#$d8cHb1e}a@vDEu^`g&is8LlA({>T9c^R%o>6mj;9Cj<=^2fD8eFaN;nr0ko*6 zl_(I9q?UTN7(fB+u3Ja~8aI(Z3HPu|DTpoHaYG8Jib8^|k*?%=>Mu&V1VUmAdk(4&r&cNqvn3?Te{2OQO={sxG2w0@706A=a%BpI zGJ{M3cMe0KwsyFXB7!K!FmoWDMoUPc$ONMC$HY)I6RpU!eAFEYsijrhlmLhTb3ly0 zfE`=fOspXqZyf+8SL(D$sf!R`l{}uxD@vd;Vp3M#l#;vfAWf%?DTpx@Tqpu(6YHxv zGNj7!J9{|5Xc5eBTiN1I^(@`iV*XN^OXG`|@KpyXf2OO9o`DWp=%I-&+UTQ^Hj5yd zj1sVx!{D6^vd6 zpz+W8y6g#2Qi9vWAHE5)uoheN+M*p+IJqT>Fc$%bONB6fMB}J@U*-C;Hdq|npa9dbPq7+xikSO@TuMaPT1T-*uDBD}q&wF5+H$W{|V$ft{_dkih=-`$VDMKIv@bF(6|5%$09)}#pIU9lMP0U zco6K4QAQyx-ZkYzoA?L9VEBOGKo|PD(stiz%qccc_6IO_DGQ ze}8aDbIT(l^*lzrs$8Xh7BieZlwd$^RWV*!Ttb&{!UiBA!A$aNURxljIUpd46+l_q zW^~sUL?MhUZ2DjbNjDaPRMB^G;az5g7bOKT@h$A2#ohSFn+DRw6;dh4Ov5hF~r7JQhOq7^zpFkjm>d=COe;^jp zq4HwQ`#MUM!EE6;4I#xZ{Mpec)M9c)x!~WfSQH81q-|BxU${~tMp9t&ULwmCcCHyz zwPc2;Ni*tEk(yMcE|sZGb?Q^08da%Ib*YS`>Q%9tRjqEdsoBcdStOIMb|C17Z>fhX z{saKENFQc?B#?0SdLocASz5B#Y02YmgqxBc2dNPGp*yYq~9qSkxTXJXfu%=^u)I#3I%JE&@V|p7D1)qnyZP_0_OTwn2 zkjFl1Nv*2PQj*^&WH^8mf0v$c)e83kDl(g_4J)#?xOxajMvwvkpUM!K(@3HxN8)Rl zP@`53BS^RlNfZA|>&ljm{8u0<<^o61@|yuYC~{y^&362xtN1hgRZt#u8chf<2$&iGitFQ&e^%6^42BGwVi;5k z6hpWv*u~ikJqHnD~MO2m^D)>^Pw*1TiNr96IcP z$FkTb(?c&6)=FENYiKMQ1gB3T-kf{1y7KUoBKR~b)W|Z==dwj1N{JL-Gqw;u*{Yg> z!q+D8rwe{AaSl!;0auKJ7!TFDgqFU%vO5T28=rt)88Uvr?&&HAY-|STi)6iKO2-5zJjM= zV+sJ0anzYL-4kUS%l6f|nwP*o3>AMKb|%s#3b%~F^3nqu4P{0OmTLtUlrRMXlg)FE z9$zSls1_|Shu>1bj9m~Wu|F}0drAxZ@)p7%!ar zB)5jiSUYgo6b`MHTbW#95*1!R1VJo^6X5EG8ifEG@c(=PyIPL~qvj^ghCprtm+od0Dx($R zDEQ{56|}(c$Uu75$N4}5?hEbua!XET5rYvTe>;V~-pl)p8)q(>E48p1gAQJ4MI8wrI zB1RmaDL}U6J51psVo5FJ4hU-DX$)^TXoMRr#1=q<7qcS(kf4A67EI`e2ujm3rZRpB zi*8SU2CHKhhCQr<@eI-l8St3yakUgcqi)F;>B4Gm#PA{o^M)_B3Q%J}j}w}23y`rp z7|-x=AUG;x16qsBlF-Rsq8+ZpFw%-Gk}L=i<}$H} zMkE2Zk`3eHk}iC!$LMM=#sVa2LlXzU#{wWSD1riogw=ltQ~$3`7A7UeQ#msuAq*3&Qoc@NX5)EI_nBM4b$E13`V&nGn2D7myZ~OLd9}|IV}YUwqra~ z;uZ#gN?&8YC6t51wWOuQp$hyf z;*uSFVG|0IJ$VyLUuzS+b>Z9#s4ZxRCt+AhiJ;~(6CK#WK+sZ%)A zDos=cOP8-|@Ip)JkUM^}FY;;Plz=DT(k8~L9d6}F3hYnAhE?i8OV{)t+F%7q4IJ^J zNl|~~{+Ay~VFPxn*#v@X{)8kz4?0`H*4An|4z#5zM<@Qo&m40PjdhoHWav&L zG&o{7?14gW6k#5NAx`8jwUu`kqO_QkIipqI6oM5V^+^x_Rf2KoaH3r8VGN?+7i48y zq~KO>u4zC;L!>l15oqWHg)=X29FYOVHau{LY9c5At|YrXbsTgHc8ZxGXhInwHN4(NYw zra*q!VhW-UkcQ7mq=bop?@zSg{OYdm!p872$dJD1i!w>^FhVW1Kw^S#nm{BJKLi2r zb_cPPA}m(7 z7w47)xg~VVWx?R3?z97U(}FUd^!$I4l5%Wf046H`WtGhH=q|JB+I3!*KLWBb|qqY{v?AosfW(38P`YlF2;F<*OX+b zi#&G=Gy(JiBmuUd`EpQfnB#7p?@#7VZ-*p~fIxJ!BLN1a?!= zc+&R#H0J!w(Q(1?{6yC#`A2W#Wk^uMZNp|c3^#AB7dnbX6UlZHb5RTQV(q3jC1N&X zQka8laeF&=Z|g3S1O+qbf^DVbd<(bk?lzbNB68QKdLFA4fM*M+P=8wSLn+pE2m%S7 z*n+P&@LWoFI~OEvGDrpk(S&~^CqN58G&q>{cmJ9?cPHmph=ZhY4U#Q}4{wcQj;8>Q zZ&G{nHjiOH;+q}I*r;wP9S_Mz5t7*XLWylc;z+XrN%ZTYFn=@ z)D0n0ND^!-Au`UnO72VcxoH6FB(Yjdzs)F?i%0}3*ET1gCMR2vx~SXPYXqgd+KpOZ zBZY|KV8!&WC0n40x{Zi>=%z)IGBmP}3s1=mMMTKkw#B!mEo{t9r0MNLRNB}iqm2OT zUXF_|7=oqc8RUNyG&z=9A#yFa0Ew+mr~f?2ELJ2WBWmL->ZhV>qdfyFqX7wHR%F_! z?X>_UArMGO9gDc-+R#|!Wc9f}keaXal(xJ&P?CeYUIne;8N5jwy}>JL+RZM;1OPNP z2Kmam0Lb0IqYacgQ!eHx26O1T+Pn!`UIHezWzFPLO|pOGWx4E{pbOluoz1&5E;7P< ztEaQsWUZ!+PQxvnv;k>kQ?ALh7PS^|`zWM8)_~i(O%tzLpgJp%fTdrt)XdJ?1+g zuE@7?)V-3)P5snSJ=MWZY~NhfUH#Qz9c+D3tVMrj@5VT_qhzIB+)GGEu<5+Sq|AN* zjjY;2nVcHe${B*pO|74s#odCWfvrfI451NP!EN2H%-OkkVmT6I-tru^Nv@;DSjF)L zRcItG8t2^Ggs+;bhQErXmAz@OZaBf6b81>@H>R9vrPYf)cJ}?*t9^<1`CUyi&o8FD zUqyc)B+?bo-7rvsS$e8I;G-=d-v4!&3)uY~+Q%JJtnJ1-X50V?+uBXRlN+_Oz1YX% z*R5JG&4orhC}dX97q{I*v*avr9JZqdf;bQ0r@E1+B4F0S$Q1NmlCts)g3wz-$CB0KcqWm$}rM>O)djq;9Gy`IHsZMXR*H-w|~gst-fCyhe~XL(a?+( zDubs#X6i#KjjDX?F}K~5PqaZE=RhxF{POGTG+ajhxe>c$4!mjPUd$7pNCM`y%${VT zmf2%{@+rUaEq`YC{qi+`^ErR)z#ikNTT0wj%tt2k<$YVUM)X(4^r_yd?A(1-=J9{E ziBQu%TH#UlGhS>FnAAZzX>4J3d;)-sYu43;kby)cC|8855oeHt zLHQ>kh4&4V+i(zhh7fsFiFn{j{|Tkvc_*^PVv8=m2xE*gmmqxxB!5}sO?uJ9)09wN ziDi~rZpmerUVaH?m|~7eW|{w6ftDJJ7NP}PS248(075`Sca?c&4kc%t+^rduollXN zBoKOTl|Y*k()0EBbq-?a( zPD^dIMg?mvup45VZMWV=%O_H_Vffd3gkhH(dE8FOP*^LrRvTiF5x1hDaWVC-O!-PA z+PVZrWG_VAHV2@-_jYBgf{}%LUP$jT)$fMdO0*ZC0BcBexSW%H-8otJTs%(E*^<-d+z* zc;O|+#xvrE4=MQKjz120Ef*?k=a^}kAM`lw8QyvmicwI*>H^ybB2h{@4ot%A+Sj+t&om!&Ht4( zTqzS3ggi*tLOu59k`F~lbSbNu!^-6`UR@4w5*rG3e~=PE03u-w^=QKY0^mF5EKqJB zQwqy?hNk1xMH5?yLU%qR0gK2iW+DvS07zjAQUq{@AYmBwYQ(@F3h_dTQ_&EQh{Pl+ zafwVcTmk^_M1yR>9{-@CEyxf;p+FIe1VI2R>fsL^y6%ZlTnZVtct(S?q8_%`go{Yw z4_A;ue^m+Co)iM003Au8751=4NiKzeDAb}JV?fF#{(-=#n9m-z;9nCT2}x8*ZIZ}~ zB(W@s$xIdr36j#JCl~j*u|d&>v5|r)ama-PWHB*o16y{~Hh}g$CQkUXMLV|tkOYUc za&5819@u0F5#Ppt z=uh6c!^?3;od_X?J=&mziMVq^4RQ(IFv2^65@aC@@nLz086I`ov!WYa)BvKuj#_9U zf1#EtSA!5T#Wv(}qWY}RDlfyIg~kPkI(!p5fd$e}3U#P9D~Uqg6xv&p8qBJYh1Q) z#T4*3F&CLNQ6UWC+@`WF0nvhxNRfm+e@MYqr?@3}FYzl>9}C%|B+pqnVQgeC8#&E6 zR!jE4tY<%Ksvrm`0a0)tO%#BxL`eXZs_Yn7u^Gzj!8K94w4VEz$2_wQM72mzZ4wla zgnxpEw$DMpT-N3;=GlaKR1(i?0HBwg(1c2{)X4yH3%}SXcej!G++(I2OOlSpe}%C; zjet&i8Q)q|ARmz~Uv;Bg>r$zrbGpfO?Mo$#r46X~g$rtUqJ;Shmo5mdQ^FIOW7-mv!nR{`(s)lC8% zgaI3bO5wSnK%^@W@rZT+(MfSdf6>Mr6JvD5-yJe~n5PAHv@xs|A_#ST8~`9pu@$l5 zYDFz9;}?JMenO@fa}~sz^V*j*k;%)KkK<9$s(H=IF(UTZ4Cgq@c`5Da?40kcDoxa5 z6fv~L0CITN>ayahQ54BT18~MD;E)qZo$6QxVpx0lqn}$SBWYT(1=88$e;1fWba=@c znlZHE7k@_VLL|V7{gJ_{^{B;J6F|CtGE&y9sL3YmQ4f3cQ+5tXLJkK|5?gdyAWuxf zGwpGYdXNE8D)Dn0(n|MbOP@GJq}#^WK6I=4}smlpc2X@fX%kEFb$Q>oBN^lte?+NUXAfZ2ex5>%5&g^SbukIqzPCVZseqi$+Z7Mx$T;~|zFOFW z44fcCKig4;O=xZuWEh3(t*4MAbe+oF_p{6|6M@Y)ofCYNsf-8U!2$G>6~PAp5=K!E zHc(lIR;4mRq(BCaw8IV;;m_L!a6Wu=ZvuEnqV;bNwF%B=e@kSbcT43GoTLfrMhdHj zX@G+OlAwI|rhu%FKF_B=wZIPgV@sSP1>0wHR)s$IR|}Wte*<8BZiIiNPzLSb3Z+m6 z6G##Y0RSbKL;1%?bEGFcMSMQEeDl&$JE(kZfP)>eP_;B}IRDs4dXj*~=LX)xDE!oe zq%efyv4i8*e@_!=dF_V~$EJZ+MLKV@M^C7LI&@P)(R)y+Y#$T^S}<-_1%|?AQCjeP z`S(RC#Zqm701?=APbh#cF=b!4h6;g&6^9TAF$y`P0CA!P=%$1|=tlLm1*9+rG{ zGke?SgXkk3k%x%-S80xiichsOZ_(2?ZSQ|%paQ`D=2v-nuS8q0<1&IO#ie?X+ zf8a={!(_Bbn@EH_&q7(XIU2foDnQ1YzX_b+A|PTw54G@VHL-loRVd-dd0WOQ;$(i; zrDSg8d8o&Yu$2IulmuV&dS(LvQSgNd5oKIpjuGW)r&kc*8I@541d>Ni(6vSyDO%6O zj_r^=2_Raaxdph9oyN6BxFkN>cSE3)e{E&JVi7d}8&Ff~_YvUfPXGmx-cCkm~m z3H)(3e8dGYW*Gd$f6f_5gLIDkfdG%laPimp=||#)IKh&i93)c!&|SvGLA7Rq z13-tZ5CsY_Vadl1qBlhw%7Z;cn&Wns=Yx6FWPK<9dKp-FSBjw!U@3|N@l5R3e^44~s#Z$qjgKiNWRw;)#1uPQ4MtV1j&NPC)0|1ZMQ#!X~agqdWU=Li7Q~6c~ zIfhP`A$e5xfZ5gtv-2@8;Wm;YkG7x%i7^G|=0*#Fr+ONd!WM~_fl{`xQY;aOwongx z8ic_kVg{M3UCDgdbpVb0ImB&RsZG+-^Y;4ngoBz7_Hz6lv-}Ja94D} zuX?e9X$h2icND`|3uQ+T5ikm`nVy5B1=PrPNyZ8hMwkaCY-eS!q@x7~hkQsdl%t?I z0HBk0g;&PZd+%dFwvbC4f9DR{N*7XKU?8D@oN%%Qp#{YFa|iKnEJzT&mrYX^Tsnt9 zzO#8pmpy045StK2#UXHl+7w9eMPr~9I#&vC#hiKvntGa4P%v0B3lUquk--%hTi_1h z_;rqitNI$1VZ>FVkdR3Q1P=9V0;36Cw-;sq>#_+E0k_t4zS*p)f2%ZpH#K9XCGIeI zsta0eNF;0}ySFPnkw*{jN-3$MdCeI@@u+066a?ZlqPXWiw3U0xw4HHDuoIR{K`=#S zz?sTvTd@H*#5n-^vweINl=HR${)YsbPzwyF6>fw+i1LV2$)8DJT-<3ra03^G0yxH- zM^fnu`op~gU=O52f1G}%zg<@kx6rCcP@#2)O&=5JpXAhjC0E;H3`u}t+&_xDp)ODtYe-Rj159zZo8`Xt&V1-#K zviiG5BS>|RtAo+;!V~&$DF}$~cR)pKRgA=%#j!gvIBzEBhrl*&!8MI|$^ikuPq%Pa zIfYVw$R~3Wfe{-iXZ)~tyrOq%4}1Ji?%ApzR}bzAP`o#5t&o_zN~>tRu;$3P9AX#Y zM|Yg562yeNf2X`O#EQD8%*ridXRQp&u`DLH*bgV=XpY8laI#U0LsJk%RP`6+ zI!Z1SjiUsgA52C9*RB}>uf%q-1E@b7acu)IXmh!*(K;8xW=mw}JJN_!2a87P*9whl zrS{;kvqxp6kiYuN5jCM_QG7$(SIwIL&=1QR5+ch1e=6rsD7%h}5s|LY%}|+FGO