Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Created shell tests and fixed bugs #2940

Merged
merged 2 commits into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
clangd tidy clangd-diagnostics \
install \
clean-extension clean-python-api clean-java clean \
extension-test
extension-test shell-test

.ONESHELL:
.SHELLFLAGS = -ec
Expand Down Expand Up @@ -174,6 +174,12 @@ extension-release:
-DBUILD_KUZU=FALSE \
)

shell-test:
$(call run-cmake-release, \
-DBUILD_SHELL=TRUE \
)
python3 -m pytest -v tools/shell/test
MSebanc marked this conversation as resolved.
Show resolved Hide resolved

# Clang-related tools and checks

# Must build the java native header to avoid missing includes. Pipe character
Expand Down
30 changes: 13 additions & 17 deletions tools/shell/linenoise.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1253,14 +1253,7 @@ static void refreshSearch(struct linenoiseState* l) {
}

static void cancelSearch(linenoiseState* l) {
char* tempBuf = l->buf;
l->len = 0;
l->pos = 0;
l->buf = (char*)"";
refreshSearchMultiLine(l, (char*)"", (char*)"");
l->buf = tempBuf;
l->len = strlen(tempBuf);
l->pos = l->len;

history_len--;
free(history[history_len]);
Expand All @@ -1270,14 +1263,15 @@ static void cancelSearch(linenoiseState* l) {
l->search_buf = std::string();
l->search_matches.clear();
l->search_index = 0;
refreshLine(l);
}

static char acceptSearch(linenoiseState* l, char nextCommand) {
bool no_matches = true;
MSebanc marked this conversation as resolved.
Show resolved Hide resolved
int history_index = l->prev_search_match_history_index;
if (l->search_index < l->search_matches.size()) {
// if there is a match - copy it into the buffer
auto match = l->search_matches[l->search_index];
no_matches = false;
history_index = match.history_index;
}

Expand All @@ -1297,7 +1291,12 @@ static char acceptSearch(linenoiseState* l, char nextCommand) {
}
strncpy(l->buf, history[history_len - 1 - l->history_index], l->buflen);
l->buf[l->buflen - 1] = '\0';
l->len = l->pos = strlen(l->buf);
l->len = strlen(l->buf);
if (no_matches) {
l->pos = l->len;
} else {
l->pos = l->search_matches[l->search_index].match_end;
}
}

cancelSearch(l);
Expand Down Expand Up @@ -1390,13 +1389,16 @@ static char linenoiseSearch(linenoiseState *l, char c) {
char seq[64];

switch (c) {
case 10:
case ENTER: /* enter */
// accept search and run
return acceptSearch(l, ENTER);
case CTRL_N:
case CTRL_R:
// move to the next match index
searchNext(l);
break;
case CTRL_P:
case CTRL_S:
// move to the prev match index
searchPrev(l);
Expand Down Expand Up @@ -1479,7 +1481,7 @@ static char linenoiseSearch(linenoiseState *l, char c) {
case CTRL_A: // accept search, move to start of line
return acceptSearch(l, CTRL_A);
case TAB:
if (pastedInput(c)) {
if (l->hasMoreData) {
l->search_buf += ' ';
performSearch(l);
break;
Expand All @@ -1502,12 +1504,6 @@ static char linenoiseSearch(linenoiseState *l, char c) {
case CTRL_L:
linenoiseClearScreen();
break;
case CTRL_P:
searchPrev(l);
break;
case CTRL_N:
searchNext(l);
break;
case CTRL_C:
case CTRL_G:
// abort search
Expand Down Expand Up @@ -1890,7 +1886,7 @@ static int linenoiseEdit(
break;
case CTRL_P: /* ctrl-p */
linenoiseEditHistoryNext(&l, LINENOISE_HISTORY_PREV);
break;
break;
case CTRL_N: /* ctrl-n */
linenoiseEditHistoryNext(&l, LINENOISE_HISTORY_NEXT);
break;
Expand Down
178 changes: 178 additions & 0 deletions tools/shell/test/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,178 @@
from encodings import utf_8
import pytest
import os
import shutil
import subprocess
import pexpect
from test_helper import KUZU_EXEC_PATH, KUZU_ROOT
from typing import List, Union


def pytest_addoption(parser):
parser.addoption("--start-offset", action="store", type=int, help="Skip the first 'n' tests")


def pytest_collection_modifyitems(config, items):
start_offset = config.getoption("--start-offset")
if not start_offset:
# --skiplist not given in cli, therefore move on
return

skipped = pytest.mark.skip(reason="included in --skiplist")
skipped_items = items[:start_offset]
for item in skipped_items:
item.add_marker(skipped)


class TestResult:
def __init__(self, stdout, stderr, status_code):
self.stdout: Union[str, bytes] = stdout
self.stderr: Union[str, bytes] = stderr
self.status_code: int = status_code

def check_stdout(self, expected: Union[str, List[str], bytes]):
if isinstance(expected, list):
expected = '\n'.join(expected)
assert self.status_code == 0
assert expected in self.stdout

def check_not_stdout(self, expected: Union[str, List[str], bytes]):
if isinstance(expected, list):
expected = '\n'.join(expected)
assert self.status_code == 0
assert expected not in self.stdout

def check_stderr(self, expected: str):
assert expected in self.stderr


class ShellTest:
def __init__(self):
self.shell = KUZU_EXEC_PATH
self.arguments = [self.shell]
self.statements: List[str] = []
self.input = None
self.output = None
self.environment = {}
self.shell_process = None

def add_argument(self, *args):
self.arguments.extend(args)
return self

def statement(self, stmt):
self.statements.append(stmt)
return self

def query(self, *stmts):
self.statements.extend(stmts)
return self

def input_file(self, file_path):
self.input = file_path
return self

def output_file(self, file_path):
self.output = file_path
return self

# Test Running methods

def get_command(self, cmd: str) -> List[str]:
command = self.arguments
if self.input:
command += [cmd]
return command

def get_input_data(self, cmd: str):
if self.input:
input_data = open(self.input, 'rb').read()
else:
input_data = bytearray(cmd, 'utf8')
return input_data

def get_output_pipe(self):
output_pipe = subprocess.PIPE
if self.output:
output_pipe = open(self.output, 'w+')
return output_pipe

def get_statements(self):
statements = []
for statement in self.statements:
statements.append(statement)
return '\n'.join(statements)

def get_output_data(self, res):
if self.output:
stdout = open(self.output, 'r').read()
else:
stdout = res.stdout.decode('utf8').strip()
stderr = res.stderr.decode('utf8').strip()
return stdout, stderr

def run(self):
statements = self.get_statements()
command = self.get_command(statements)
input_data = self.get_input_data(statements)
output_pipe = self.get_output_pipe()

my_env = os.environ.copy()
for key, val in self.environment.items():
my_env[key] = val

res = subprocess.run(command, input=input_data, stdout=output_pipe, stderr=subprocess.PIPE, env=my_env)

stdout, stderr = self.get_output_data(res)
return TestResult(stdout, stderr, res.returncode)

def start(self):
command = " ".join(self.arguments)

my_env = os.environ.copy()
for key, val in self.environment.items():
my_env[key] = val

self.shell_process = pexpect.spawn(command, encoding = "utf_8", env = my_env)

def send_finished_statement(self, stmt: str):
if self.shell_process:
assert self.shell_process.expect_exact(["kuzu", pexpect.EOF]) == 0
self.shell_process.send(stmt)
assert self.shell_process.expect_exact(["kuzu", pexpect.EOF]) == 0

def send_statement(self, stmt: str):
if self.shell_process:
assert self.shell_process.expect_exact(["kuzu", pexpect.EOF]) == 0
self.shell_process.send(stmt)

def send_control_statement(self, stmt: str):
if self.shell_process:
assert self.shell_process.expect_exact(["kuzu", pexpect.EOF]) == 0
self.shell_process.sendcontrol(stmt)


@pytest.fixture
def temp_db(tmp_path):
if os.path.exists(tmp_path):
shutil.rmtree(tmp_path)
MSebanc marked this conversation as resolved.
Show resolved Hide resolved
output_path = str(tmp_path)
return output_path


@pytest.fixture
def get_tmp_path(tmp_path):
return str(tmp_path)


@pytest.fixture
def history_path():
path = os.path.join(KUZU_ROOT, 'tools', 'shell', 'test', 'files')
if (os.path.exists(os.path.join(path, "history.txt"))):
os.remove(os.path.join(path, "history.txt"))
MSebanc marked this conversation as resolved.
Show resolved Hide resolved
return path


@pytest.fixture
def csv_path():
return os.path.join(KUZU_ROOT, 'tools', 'shell', 'test', 'files', 'vPerson.csv')
Loading
Loading