Skip to content
This repository has been archived by the owner on Dec 26, 2022. It is now read-only.

Commit

Permalink
Merge pull request #545 from afcidk/regression-test
Browse files Browse the repository at this point in the history
feat(MQTT): Implement MQTT regression test
  • Loading branch information
howjmay committed May 22, 2020
2 parents cc69983 + 3d9ef90 commit 0941fc9
Show file tree
Hide file tree
Showing 13 changed files with 313 additions and 25 deletions.
188 changes: 183 additions & 5 deletions tests/regression/common.py
Original file line number Diff line number Diff line change
@@ -1,46 +1,75 @@
import time
import re
import sys
import json
import string
import random
import logging
import requests
import statistics
import subprocess
import argparse
import paho.mqtt.publish as publish
import paho.mqtt.subscribe as subscribe
from multiprocessing import Pool
from multiprocessing.context import TimeoutError

TIMES_TOTAL = 100
TIMEOUT = 100 # [sec]
MQTT_RECV_TIMEOUT = 30
STATUS_CODE_500 = "500"
STATUS_CODE_405 = "405"
STATUS_CODE_404 = "404"
STATUS_CODE_400 = "400"
STATUS_CODE_200 = "200"
STATUS_CODE_ERR = "-1"
EMPTY_REPLY = "000"
LEN_TAG = 27
LEN_ADDR = 81
LEN_MSG_SIGN = 2187
TRYTE_ALPHABET = "ABCDEFGHIJKLMNOPQRSTUVWXYZ9"
URL = ""
DEVICE_ID = None
CONNECTION_METHOD = None


def parse_cli_arg():
global URL
global CONNECTION_METHOD
global DEVICE_ID
rand_device_id = ''.join(
random.choice(string.printable[:62]) for _ in range(32))
parser = argparse.ArgumentParser('Regression test runner program')
parser.add_argument('-u',
'--url',
dest='raw_url',
default="localhost:8000")
parser.add_argument('-d', '--debug', dest="debug", action="store_true")
parser.add_argument('--nostat', dest="no_stat", action="store_true")
parser.add_argument('--mqtt', dest="enable_mqtt", action="store_true")
parser.add_argument('--device_id',
dest="device_id",
default=rand_device_id)
args = parser.parse_args()

# Determine whether to use full time statistic or not
if args.no_stat:
global TIMES_TOTAL
TIMES_TOTAL = 2
# Determine connection method
if args.enable_mqtt:
CONNECTION_METHOD = "mqtt"
URL = "localhost"
else:
CONNECTION_METHOD = "http"
URL = "http://" + args.raw_url
# Run with debug mode or not
if args.debug:
logging.basicConfig(level=logging.DEBUG)
else:
logging.basicConfig(level=logging.INFO)
URL = "http://" + args.raw_url
# Configure connection destination
DEVICE_ID = args.device_id


def eval_stat(time_cost, func_name):
Expand Down Expand Up @@ -69,11 +98,15 @@ def test_logger(f):
logger = logging.getLogger(f.__module__)
name = f.__name__

def decorate(instance):
logger.debug(f"Testing case = {name}")
return instance
def decorate(*args, **kwargs):
bg_color = "\033[48;5;38m"
fg_color = "\033[38;5;16m"
clear_color = "\033[0m"
logger.info(f"{bg_color}{fg_color}{name}{clear_color}")
res = f(*args, **kwargs)
return res

return decorate(f)
return decorate


def valid_trytes(trytes, trytes_len):
Expand All @@ -94,7 +127,98 @@ def map_field(key, value):
return json.dumps(ret)


# Simulate random field to mqtt since we cannot put the information in the route
def add_random_field(post):
return data


def route_http_to_mqtt(query, get_data, post_data):
data = {}
if get_data: query += get_data
if post_data:
data.update(json.loads(post_data))
if query[-1] == "/": query = query[:-1] # Remove trailing slash

# api_generate_address
r = re.search("/address$", query)
if r is not None:
return query, data

# api_find_transactions_by_tag
r = re.search(f"/tag/(?P<tag>[\x00-\xff]*?)/hashes$", query)
if r is not None:
tag = r.group("tag")
data.update({"tag": tag})
query = "/tag/hashes"
return query, data

# api_find_transactions_object_by_tag
r = re.search(f"/tag/(?P<tag>[\x00-\xff]*?)$", query)
if r is not None:
tag = r.group("tag")
data.update({"tag": tag})
query = "/tag/object"
return query, data

# api_find_transacion_object
r = re.search(f"/transaction/object$", query)
if r is not None:
query = "/transaction/object"
return query, data

r = re.search(f"/transaction/(?P<hash>[\u0000-\uffff]*?)$", query)
if r is not None:
hash = r.group("hash")
data.update({"hash": hash})
query = f"/transaction"
return query, data

# api_send_transfer
r = re.search(f"/transaction$", query)
if r is not None:
query = "/transaction/send"
return query, data

# api_get_tips
r = re.search(f"/tips$", query)
if r is not None:
query = "/tips/all"
return query, data

# api_get_tips_pair
r = re.search(f"/tips/pair$", query)
if r is not None:
return query, data

# api_send_trytes
r = re.search(f"/tryte$", query)
if r is not None:
return query, data

# Error, cannot identify route (return directly from regression test)
return None, None


def API(get_query, get_data=None, post_data=None):
global CONNECTION_METHOD
assert CONNECTION_METHOD != None
if CONNECTION_METHOD == "http":
return _API_http(get_query, get_data, post_data)
elif CONNECTION_METHOD == "mqtt":
query, data = route_http_to_mqtt(get_query, get_data, post_data)
if (query, data) == (None, None):
msg = {
"message":
"Cannot identify route, directly return from regression test",
"status_code": STATUS_CODE_400
}
logging.debug(msg)
return msg

return _API_mqtt(query, data)


def _API_http(get_query, get_data, post_data):
global URL
command = "curl {} -X POST -H 'Content-Type: application/json' -w \", %{{http_code}}\" -d '{}'"
try:
Expand Down Expand Up @@ -130,3 +254,57 @@ def API(get_query, get_data=None, post_data=None):
logging.debug(f"Command = {command}, response = {response}")

return response


def _subscribe(get_query):
add_slash = ""
if get_query[-1] != "/": add_slash = "/"
topic = f"root/topics{get_query}{add_slash}{DEVICE_ID}"
logging.debug(f"Subscribe topic: {topic}")

return subscribe.simple(topics=topic, hostname=URL, qos=1).payload


def _API_mqtt(get_query, data):
global URL, DEVICE_ID
data.update({"device_id": DEVICE_ID})

# Put subscriber in a thread since it is a blocking function
with Pool() as p:
payload = p.apply_async(_subscribe, [get_query])
topic = f"root/topics{get_query}"
logging.debug(f"Publish topic: {topic}, data: {data}")

# Prevents publish execute earlier than subscribe
time.sleep(0.1)

# Publish requests
publish.single(topic, json.dumps(data), hostname=URL, qos=1)
msg = {}
try:
res = payload.get(MQTT_RECV_TIMEOUT)
msg = json.loads(res)

if type(msg) is dict and "message" in msg.keys():
content = msg["message"]
if content == "Internal service error":
msg.update({"status_code": STATUS_CODE_500})
elif content == "Request not found":
msg.update({"status_code": STATUS_CODE_404})
elif content == "Invalid path" or content == "Invalid request header":
msg.update({"status_code": STATUS_CODE_400})
else:
msg.update({"status_code": STATUS_CODE_200})
else:
msg = {
"content": json.dumps(msg),
"status_code": STATUS_CODE_200
}
except TimeoutError:
msg = {
"content": "Time limit exceed",
"status_code": STATUS_CODE_ERR
}

logging.debug(f"Modified response: {msg}")
return msg
1 change: 1 addition & 0 deletions tests/regression/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
certifi==2019.11.28
chardet==3.0.4
idna==2.7
paho-mqtt==1.5.0
requests==2.20.0
urllib3==1.24.3
4 changes: 2 additions & 2 deletions tests/regression/run-api-with-mqtt.sh
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,12 @@ for (( i = 0; i < ${#OPTIONS[@]}; i++ )); do
cli_arg=${option} | cut -d '|' -f 1
build_arg=${option} | cut -d '|' -f 2

bazel run accelerator ${build_arg} -- --ta_port=${TA_PORT} ${cli_arg} &
bazel run accelerator --define mqtt=enable ${build_arg} -- --quiet --ta_port=${TA_PORT} ${cli_arg} &
TA=$!
sleep ${sleep_time} # TA takes time to be built
trap "kill -9 ${TA};" INT # Trap SIGINT from Ctrl-C to stop TA

python3 tests/regression/runner.py ${remaining_args} --url localhost:${TA_PORT}
python3 tests/regression/runner.py ${remaining_args} --url "localhost" --mqtt
rc=$?

if [ $rc -ne 0 ]
Expand Down
9 changes: 7 additions & 2 deletions tests/regression/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,15 @@

suite_path = os.path.join(os.path.dirname(__file__), "test_suite")
sys.path.append(suite_path)
unsuccessful_module = []
for module in os.listdir(suite_path):
if module[-3:] == ".py":
mod = __import__(module[:-3], locals(), globals())
suite = unittest.TestLoader().loadTestsFromModule(mod)
result = unittest.TextTestRunner().run(suite)
result = unittest.TextTestRunner(verbosity=0).run(suite)
if not result.wasSuccessful():
exit(1)
unsuccessful_module.append(module)

if len(unsuccessful_module):
print(f"Error module: {unsuccessful_module}")
exit(1)
5 changes: 3 additions & 2 deletions tests/regression/test_suite/find_transactions_hash_by_tag.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,16 @@ def test_time_statistics(self):
eval_stat(time_cost, "find transactions by tag")

@classmethod
@test_logger
def setUpClass(cls):
rand_trytes_26 = gen_rand_trytes(26)
rand_tag = gen_rand_trytes(LEN_TAG)
rand_addr = gen_rand_trytes(LEN_ADDR)
rand_msg = gen_rand_trytes(30)
rand_len = random.randrange(28, 50)
rand_len_trytes = gen_rand_trytes(rand_len)
FindTransactionsHashByTag()._send_transaction(
rand_msg, rand_tag, rand_addr)
FindTransactionsHashByTag()._send_transaction(rand_msg, rand_tag,
rand_addr)
cls.query_string = [
rand_tag, "", f"{rand_trytes_26}@", f"{rand_tag}\x00",
f"{rand_tag}\x00{rand_tag}", "一二三四五", rand_len_trytes
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ def test_time_statistics(self):
eval_stat(time_cost, "find transactions objects by tag")

@classmethod
@test_logger
def setUpClass(cls):
rand_trytes_26 = gen_rand_trytes(26)
rand_tag = gen_rand_trytes(LEN_TAG)
Expand Down
1 change: 1 addition & 0 deletions tests/regression/test_suite/generate_address.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def test_time_statistics(self):
eval_stat(time_cost, "generate_address")

@classmethod
@test_logger
def setUpClass(cls):
rand_tag_27 = gen_rand_trytes(27)
cls.query_string = ["", rand_tag_27, "飛天義大利麵神教"]
1 change: 1 addition & 0 deletions tests/regression/test_suite/get_tips.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_time_statistics(self):
eval_stat(time_cost, "get_tips")

@classmethod
@test_logger
def setUpClass(cls):
rand_tag_27 = gen_rand_trytes(27)
cls.query_string = ["", rand_tag_27, "飛天義大利麵神教"]
1 change: 1 addition & 0 deletions tests/regression/test_suite/get_tips_pair.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def test_time_statistics(self):
eval_stat(time_cost, "get tips pair")

@classmethod
@test_logger
def setUpClass(cls):
rand_tag_27 = gen_rand_trytes(27)
cls.query_string = ["", rand_tag_27, "飛天義大利麵神教"]
11 changes: 6 additions & 5 deletions tests/regression/test_suite/get_transactions_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@ class GetTransactionsObject(unittest.TestCase):
def test_81_trytes_hash(self):
res = API("/transaction/object",
post_data=map_field(self.post_field, [self.query_string[0]]))
self._verify_pass(res, 0)
self._verify_pass(res, idx=0)

# Multiple 81 trytes transaction hash (pass)
@test_logger
def test_mult_81_trytes_hash(self):
res = API("/transaction/object",
post_data=map_field(self.post_field, [self.query_string[1]]))
self._verify_pass(res, 1)
self._verify_pass(res, idx=1)

# 20 trytes transaction hash (fail)
@test_logger
Expand All @@ -33,7 +33,7 @@ def test_20_trytes_hash(self):
def test_100_trytes_hash(self):
res = API("/transaction/object",
post_data=map_field(self.post_field, [self.query_string[3]]))
self.assertEqual(STATUS_CODE_500, res["status_code"])
self.assertEqual(STATUS_CODE_404, res["status_code"])

# Unicode transaction hash (fail)
@test_logger
Expand Down Expand Up @@ -62,6 +62,7 @@ def test_time_statistics(self):
eval_stat(time_cost, "find transaction objects")

@classmethod
@test_logger
def setUpClass(cls):
sent_txn_tmp = []
for i in range(3):
Expand All @@ -85,8 +86,8 @@ def setUpClass(cls):
cls.response_field = []
cls.query_string = [[sent_txn_tmp[0]["hash"]],
[sent_txn_tmp[1]["hash"], sent_txn_tmp[2]["hash"]],
gen_rand_trytes(19),
gen_rand_trytes(100), "工程師批哩趴啦的生活", ""]
[gen_rand_trytes(20)], [gen_rand_trytes(100)],
["工程師批哩趴啦的生活"], [""]]

def _verify_pass(self, res, idx):
expected_txns = self.sent_txn[idx]
Expand Down
Loading

0 comments on commit 0941fc9

Please sign in to comment.