Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replace triple-quotes by modifying tokens #144

Merged
merged 18 commits into from
Sep 12, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 77 additions & 48 deletions blackdoc/formats/doctest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import io
import itertools
import re
import tokenize
from tokenize import TokenError

import more_itertools

Expand Down Expand Up @@ -53,44 +56,48 @@ def detection_func(lines):
return line_range, name, "\n".join(lines)


def tokenize(code):
import io
import tokenize
def suppress(iterable, errors):
iter_ = iter(iterable)
while True:
try:
yield next(iter_)
except errors:
yield None
except StopIteration:
break


def tokenize_string(code):
readline = io.StringIO(code).readline

return tokenize.generate_tokens(readline)


def extract_string_tokens(code):
tokens = tokenize_string(code)

# suppress invalid code errors: `black` will raise with a better error message
return (
token
for token in tokenize.generate_tokens(readline)
if token.type == tokenize.STRING
for token in suppress(tokens, TokenError)
if token is not None and token.type == tokenize.STRING
)


def expand_tokens(token):
length = token.end[0] - token.start[0] + 1
return [token.string] * length


def detect_docstring_quotes(line):
def detect_quotes(string):
if string.startswith("'''"):
def detect_docstring_quotes(code_unit):
def extract_quotes(string):
if string.startswith("'''") and string.endswith("'''"):
return "'''"
elif string.startswith('"""'):
elif string.startswith('"""') and string.endswith('"""'):
return '"""'
else:
return None

def expand_quotes(quotes, n_lines):
lines = [None] * n_lines
for token, quote in quotes.items():
token_length = token.end[0] - token.start[0] + 1
lines[token.start[0] - 1 : token.end[0]] = [quote] * token_length
return lines
string_tokens = list(extract_string_tokens(code_unit))
token_quotes = {token: extract_quotes(token.string) for token in string_tokens}
quotes = (quote for quote in token_quotes.values() if quote is not None)

string_tokens = list(tokenize(line))
quotes = {token: detect_quotes(token.string) for token in string_tokens}
lines = line.split("\n")
return expand_quotes(quotes, len(lines))
return more_itertools.first(quotes, None)


def extraction_func(line):
Expand Down Expand Up @@ -127,45 +134,67 @@ def remove_prompt(line):
}, extracted_line


def replace_quotes(line, current, saved):
if current is None or saved is None:
return line
elif current == saved:
return line
else:
return line.replace(current, saved)
def restore_quotes(code_unit, original_quotes):
def line_offsets(code_unit):
offsets = [m.end() for m in re.finditer("\n", code_unit)]

return {lineno: offset for lineno, offset in enumerate([0] + offsets, start=1)}

def compute_offset(pos, offsets):
lineno, charno = pos
return offsets[lineno] + charno

if original_quotes is None:
return code_unit

to_replace = "'''" if original_quotes == '"""' else '"""'

def reformatting_func(line, docstring_quotes):
string_tokens = extract_string_tokens(code_unit)
triple_quote_tokens = [
token
for token in string_tokens
if token.string.startswith(to_replace) and token.string.endswith(to_replace)
]

offsets = line_offsets(code_unit)
mutable_string = io.StringIO(code_unit)
for token in triple_quote_tokens:
# find the offset in the stream
start = compute_offset(token.start, offsets)
end = compute_offset(token.end, offsets) - 3

mutable_string.seek(start)
mutable_string.write(original_quotes)

mutable_string.seek(end)
mutable_string.write(original_quotes)

restored_code_unit = mutable_string.getvalue()

return restored_code_unit


def reformatting_func(code_unit, docstring_quotes):
def add_prompt(prompt, line):
if not line:
return prompt

return " ".join([prompt, line])

lines = line.rstrip().split("\n")
restored_quotes = restore_quotes(code_unit, docstring_quotes)

lines = restored_quotes.rstrip().split("\n")
if block_start_re.match(lines[0]):
lines.append("")

lines = iter(lines)

lines_ = iter(lines)
reformatted = list(
itertools.chain(
more_itertools.always_iterable(
add_prompt(prompt, more_itertools.first(lines))
add_prompt(prompt, more_itertools.first(lines_))
),
(add_prompt(continuation_prompt, line) for line in lines),
)
)

# make sure nested docstrings still work
current_quotes = detect_docstring_quotes("\n".join(reformatted))
restored = "\n".join(
replace_quotes(line, current, saved)
for line, saved, current in itertools.zip_longest(
reformatted, docstring_quotes, current_quotes
(add_prompt(continuation_prompt, line) for line in lines_),
)
if line is not None
)

return restored
return "\n".join(reformatted)
Loading