aboutsummaryrefslogtreecommitdiffstats
path: root/prism
diff options
context:
space:
mode:
author: Kevin Newton <kddnewton@gmail.com> 2023-10-10 10:52:44 -0400
committer: Kevin Newton <kddnewton@gmail.com> 2023-10-13 15:31:30 -0400
commit: dd3986876a96f9e9fec078247d7d40b322f8fd17 (patch)
tree: af033269dedfbcb978e391b59b9df145e8234319 /prism
parent: 1a941c70e42c1e64b961088e953ded6a148e1351 (diff)
download: ruby-dd3986876a96f9e9fec078247d7d40b322f8fd17.tar.gz
[ruby/prism] Handle remaining escape sequences for character literals
https://github.com/ruby/prism/commit/ba33607034
Diffstat (limited to 'prism')
-rw-r--r-- prism/prism.c 346
-rw-r--r-- prism/util/pm_buffer.c 16
-rw-r--r-- prism/util/pm_buffer.h 3
3 files changed, 354 insertions, 11 deletions
diff --git a/prism/prism.c b/prism/prism.c
index 3b4261e5ca..e4448ef394 100644
--- a/prism/prism.c
+++ b/prism/prism.c
@@ -6064,6 +6064,340 @@ lex_interpolation(pm_parser_t *parser, const uint8_t *pound) {
}
}
+// Bit flags describing the context an escape sequence is being read in. They
+// are combined with bitwise OR as escape sequences nest.
+//
+// CONTROL: the byte belongs to a control escape; escape_byte keeps only its
+// low five bits.
+static const uint8_t PM_ESCAPE_FLAG_CONTROL = 1u << 0;
+
+// META: the byte belongs to a meta escape; escape_byte sets its high bit.
+static const uint8_t PM_ESCAPE_FLAG_META = 1u << 1;
+
+// SINGLE: the escape is read for a character literal, where a \u{...} escape
+// may only contain one codepoint.
+static const uint8_t PM_ESCAPE_FLAG_SINGLE = 1u << 2;
+
+// Lookup table, indexed by byte value, recording which of the 128 ASCII
+// characters count as printable for the purposes of control and meta escapes.
+// Besides 0x20-0x7E the table accepts the whitespace controls 0x09-0x0D, and
+// it deliberately excludes the backslash (0x5C) and DEL (0x7F).
+static const bool ascii_printable_chars[] = {
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0, 0, // 0x00 - 0x0F
+    0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 0x10 - 0x1F
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x20 - 0x2F
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x30 - 0x3F
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x40 - 0x4F
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, // 0x50 - 0x5F
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // 0x60 - 0x6F
+    1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0  // 0x70 - 0x7F
+};
+
+// Returns true if the given byte is an ASCII character that the table above
+// marks as printable. Bytes outside of the ASCII range are never printable.
+static inline bool
+char_is_ascii_printable(const uint8_t b) {
+    if (b >= 0x80) {
+        return false;
+    }
+
+    return ascii_printable_chars[b];
+}
+
+// Convert a single hexadecimal digit character into the number it stands for:
+// '0'-'9' become 0-9 and 'a'-'f' / 'A'-'F' become 10-15. The caller must have
+// already checked that the character is a hexadecimal digit.
+static inline uint8_t
+escape_hexadecimal_digit(const uint8_t value) {
+    if (value <= '9') {
+        return (uint8_t) (value - '0');
+    }
+
+    // For both cases of letter the low three bits hold 1 for 'a' through 6
+    // for 'f', so adding 9 yields 10 through 15.
+    return (uint8_t) ((value & 0x7) + 9);
+}
+
+// Fold the first `length` hexadecimal digits found at `string` into a single
+// number, most significant digit first, and return that number. A length of
+// zero yields 0. The digits are not checked here, so the caller must have
+// validated them beforehand.
+static inline uint32_t
+escape_unicode(const uint8_t *string, size_t length) {
+    uint32_t codepoint = 0;
+
+    for (size_t offset = 0; offset < length; offset++) {
+        codepoint = (codepoint << 4) | escape_hexadecimal_digit(string[offset]);
+    }
+
+    return codepoint;
+}
+
+// Apply the control and meta modifiers selected by `flags` to a byte and
+// return the result. The control modifier keeps the low five bits and the
+// meta modifier sets the high bit; when both are requested the control mask
+// is applied first. Any other flag leaves the byte untouched.
+static inline uint8_t
+escape_byte(uint8_t value, const uint8_t flags) {
+    if ((flags & PM_ESCAPE_FLAG_CONTROL) != 0) {
+        value = (uint8_t) (value & 0x1f);
+    }
+
+    if ((flags & PM_ESCAPE_FLAG_META) != 0) {
+        value = (uint8_t) (value | 0x80);
+    }
+
+    return value;
+}
+
+// Append the UTF-8 encoding of the given codepoint to the buffer.
+//
+// `start` and `end` delimit the source text of the escape and are only used
+// as the location of the error reported for a codepoint that cannot be
+// encoded. That is the case for values above U+10FFFF and for the surrogate
+// range U+D800..U+DFFF, neither of which has a valid UTF-8 encoding. For
+// those, PM_ERR_ESCAPE_INVALID_UNICODE is reported and U+FFFD (the
+// replacement character) is appended instead.
+static inline void
+escape_write_unicode(pm_parser_t *parser, pm_buffer_t *buffer, const uint8_t *start, const uint8_t *end, uint32_t value) {
+    const bool surrogate = (value >= 0xD800) && (value <= 0xDFFF);
+
+    if (surrogate || (value > 0x10FFFF)) {
+        pm_parser_err(parser, start, end, PM_ERR_ESCAPE_INVALID_UNICODE);
+
+        // UTF-8 encoding of U+FFFD.
+        pm_buffer_append_u8(buffer, 0xEF);
+        pm_buffer_append_u8(buffer, 0xBF);
+        pm_buffer_append_u8(buffer, 0xBD);
+        return;
+    }
+
+    if (value <= 0x7F) { // 0xxxxxxx
+        pm_buffer_append_u8(buffer, (uint8_t) value);
+    } else if (value <= 0x7FF) { // 110xxxxx 10xxxxxx
+        pm_buffer_append_u8(buffer, (uint8_t) (0xC0 | (value >> 6)));
+        pm_buffer_append_u8(buffer, (uint8_t) (0x80 | (value & 0x3F)));
+    } else if (value <= 0xFFFF) { // 1110xxxx 10xxxxxx 10xxxxxx
+        pm_buffer_append_u8(buffer, (uint8_t) (0xE0 | (value >> 12)));
+        pm_buffer_append_u8(buffer, (uint8_t) (0x80 | ((value >> 6) & 0x3F)));
+        pm_buffer_append_u8(buffer, (uint8_t) (0x80 | (value & 0x3F)));
+    } else { // 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
+        pm_buffer_append_u8(buffer, (uint8_t) (0xF0 | (value >> 18)));
+        pm_buffer_append_u8(buffer, (uint8_t) (0x80 | ((value >> 12) & 0x3F)));
+        pm_buffer_append_u8(buffer, (uint8_t) (0x80 | ((value >> 6) & 0x3F)));
+        pm_buffer_append_u8(buffer, (uint8_t) (0x80 | (value & 0x3F)));
+    }
+}
+
+// Forward declaration, needed because escape_read and escape_read_control
+// call each other.
+static void
+escape_read(pm_parser_t *parser, pm_buffer_t *buffer, uint8_t flags);
+
+// Read the operand of a control escape, i.e. whatever follows \c or \C-, and
+// append the resulting byte to the buffer. `flags` holds the modifiers that
+// were already active before this control escape was entered.
+static void
+escape_read_control(pm_parser_t *parser, pm_buffer_t *buffer, uint8_t flags) {
+    uint8_t peeked = peek(parser);
+
+    switch (peeked) {
+        case '?': {
+            // \c? and \C-? stand for DEL (0x7f), so the control mask must not
+            // be applied to it. An enclosing meta escape still sets the high
+            // bit.
+            const uint8_t outer = (uint8_t) (flags & ~PM_ESCAPE_FLAG_CONTROL);
+            parser->current.end++;
+            pm_buffer_append_u8(buffer, escape_byte(0x7f, outer));
+            return;
+        }
+        case '\\':
+            if (flags & PM_ESCAPE_FLAG_CONTROL) {
+                pm_parser_err_current(parser, PM_ERR_ESCAPE_INVALID_CONTROL_REPEAT);
+                return;
+            }
+            parser->current.end++;
+            escape_read(parser, buffer, flags | PM_ESCAPE_FLAG_CONTROL);
+            return;
+        default:
+            if (!char_is_ascii_printable(peeked)) {
+                pm_parser_err_current(parser, PM_ERR_ESCAPE_INVALID_CONTROL);
+                return;
+            }
+            parser->current.end++;
+            pm_buffer_append_u8(buffer, escape_byte(peeked, flags | PM_ESCAPE_FLAG_CONTROL));
+            return;
+    }
+}
+
+// Read one escape sequence from the source and append the bytes it stands for
+// to the buffer.
+//
+// On entry parser->current.end must point at the character that follows the
+// backslash. On return it has been advanced past the characters that were
+// consumed. Malformed escapes are reported through pm_parser_err and
+// pm_parser_err_current.
+//
+// `flags` is a combination of the PM_ESCAPE_FLAG_* values. The control and
+// meta flags accumulate while nested escapes such as \M-\C-a are read and are
+// applied to every single-byte escape through escape_byte. The single flag
+// restricts \u{...} to one codepoint.
+static void
+escape_read(pm_parser_t *parser, pm_buffer_t *buffer, uint8_t flags) {
+    switch (peek(parser)) {
+        case '\\': {
+            parser->current.end++;
+            pm_buffer_append_u8(buffer, escape_byte('\\', flags));
+            return;
+        }
+        case '\'': {
+            parser->current.end++;
+            pm_buffer_append_u8(buffer, escape_byte('\'', flags));
+            return;
+        }
+        case 'a': {
+            parser->current.end++;
+            pm_buffer_append_u8(buffer, escape_byte('\a', flags));
+            return;
+        }
+        case 'b': {
+            parser->current.end++;
+            pm_buffer_append_u8(buffer, escape_byte('\b', flags));
+            return;
+        }
+        case 'e': {
+            parser->current.end++;
+            pm_buffer_append_u8(buffer, escape_byte('\033', flags));
+            return;
+        }
+        case 'f': {
+            parser->current.end++;
+            pm_buffer_append_u8(buffer, escape_byte('\f', flags));
+            return;
+        }
+        case 'n': {
+            parser->current.end++;
+            pm_buffer_append_u8(buffer, escape_byte('\n', flags));
+            return;
+        }
+        case 'r': {
+            parser->current.end++;
+            pm_buffer_append_u8(buffer, escape_byte('\r', flags));
+            return;
+        }
+        case 's': {
+            parser->current.end++;
+            pm_buffer_append_u8(buffer, escape_byte(' ', flags));
+            return;
+        }
+        case 't': {
+            parser->current.end++;
+            pm_buffer_append_u8(buffer, escape_byte('\t', flags));
+            return;
+        }
+        case 'v': {
+            parser->current.end++;
+            pm_buffer_append_u8(buffer, escape_byte('\v', flags));
+            return;
+        }
+        case '0': case '1': case '2': case '3': case '4': case '5': case '6': case '7': {
+            // Between one and three octal digits. The result is truncated to
+            // a single byte.
+            uint8_t value = (uint8_t) (*parser->current.end - '0');
+            parser->current.end++;
+
+            if (pm_char_is_octal_digit(peek(parser))) {
+                value = (uint8_t) ((value << 3) | (*parser->current.end - '0'));
+                parser->current.end++;
+
+                if (pm_char_is_octal_digit(peek(parser))) {
+                    value = (uint8_t) ((value << 3) | (*parser->current.end - '0'));
+                    parser->current.end++;
+                }
+            }
+
+            pm_buffer_append_u8(buffer, escape_byte(value, flags));
+            return;
+        }
+        case 'x': {
+            // Step over the 'x' before looking for digits, otherwise the
+            // check below would inspect the 'x' itself and always fail.
+            parser->current.end++;
+            uint8_t byte = peek(parser);
+
+            if (pm_char_is_hexadecimal_digit(byte)) {
+                uint8_t value = escape_hexadecimal_digit(byte);
+                parser->current.end++;
+
+                byte = peek(parser);
+                if (pm_char_is_hexadecimal_digit(byte)) {
+                    value = (uint8_t) ((value << 4) | escape_hexadecimal_digit(byte));
+                    parser->current.end++;
+                }
+
+                pm_buffer_append_u8(buffer, escape_byte(value, flags));
+            } else {
+                pm_parser_err_current(parser, PM_ERR_ESCAPE_INVALID_HEXADECIMAL);
+            }
+
+            return;
+        }
+        case 'u': {
+            parser->current.end++;
+
+            if (
+                // Four digits that end exactly at the end of the input are
+                // still a complete escape, hence >= and not >.
+                ((parser->end - parser->current.end) >= 4) &&
+                pm_char_is_hexadecimal_digit(parser->current.end[0]) &&
+                pm_char_is_hexadecimal_digit(parser->current.end[1]) &&
+                pm_char_is_hexadecimal_digit(parser->current.end[2]) &&
+                pm_char_is_hexadecimal_digit(parser->current.end[3])
+            ) {
+                uint32_t value = escape_unicode(parser->current.end, 4);
+                escape_write_unicode(parser, buffer, parser->current.end, parser->current.end + 4, value);
+                parser->current.end += 4;
+            } else if (peek(parser) == '{') {
+                // Points at the backslash that introduced the escape, for use
+                // as the start of the "unterminated" error location.
+                const uint8_t *unicode_codepoints_start = parser->current.end - 2;
+                parser->current.end++;
+
+                const uint8_t *extra_codepoints_start = NULL;
+                int codepoints_count = 0;
+
+                parser->current.end += pm_strspn_whitespace(parser->current.end, parser->end - parser->current.end);
+                while ((parser->current.end < parser->end) && (*parser->current.end != '}')) {
+                    const uint8_t *unicode_start = parser->current.end;
+                    size_t hexadecimal_length = pm_strspn_hexadecimal_digit(parser->current.end, parser->end - parser->current.end);
+
+                    if (hexadecimal_length > 6) {
+                        // \u{nnnn} character literal allows only 1-6 hexadecimal digits
+                        pm_parser_err(parser, unicode_start, unicode_start + hexadecimal_length, PM_ERR_ESCAPE_INVALID_UNICODE_LONG);
+                    } else if (hexadecimal_length == 0) {
+                        // there are no hexadecimal characters
+                        pm_parser_err(parser, unicode_start, unicode_start + hexadecimal_length, PM_ERR_ESCAPE_INVALID_UNICODE);
+                        return;
+                    }
+
+                    parser->current.end += hexadecimal_length;
+                    codepoints_count++;
+                    if (flags & PM_ESCAPE_FLAG_SINGLE && codepoints_count == 2) {
+                        extra_codepoints_start = unicode_start;
+                    }
+
+                    uint32_t value = escape_unicode(unicode_start, hexadecimal_length);
+                    escape_write_unicode(parser, buffer, unicode_start, parser->current.end, value);
+                    parser->current.end += pm_strspn_whitespace(parser->current.end, parser->end - parser->current.end);
+                }
+
+                // ?\u{nnnn} character literal should contain only one codepoint and cannot be like ?\u{nnnn mmmm}
+                if (flags & PM_ESCAPE_FLAG_SINGLE && codepoints_count > 1) {
+                    pm_parser_err(parser, extra_codepoints_start, parser->current.end - 1, PM_ERR_ESCAPE_INVALID_UNICODE_LITERAL);
+                }
+
+                if (peek(parser) == '}') {
+                    parser->current.end++;
+                } else {
+                    pm_parser_err(parser, unicode_codepoints_start, parser->current.end, PM_ERR_ESCAPE_INVALID_UNICODE_TERM);
+                }
+            } else {
+                pm_parser_err_current(parser, PM_ERR_ESCAPE_INVALID_UNICODE);
+            }
+
+            return;
+        }
+        case 'c': {
+            parser->current.end++;
+            escape_read_control(parser, buffer, flags);
+            return;
+        }
+        case 'C': {
+            parser->current.end++;
+            if (peek(parser) != '-') {
+                pm_parser_err_current(parser, PM_ERR_ESCAPE_INVALID_CONTROL);
+                return;
+            }
+
+            parser->current.end++;
+            escape_read_control(parser, buffer, flags);
+            return;
+        }
+        case 'M': {
+            parser->current.end++;
+            if (peek(parser) != '-') {
+                pm_parser_err_current(parser, PM_ERR_ESCAPE_INVALID_META);
+                return;
+            }
+
+            parser->current.end++;
+            uint8_t peeked = peek(parser);
+
+            // Unlike the control escapes, '?' is not special here: \M-? is
+            // simply '?' with the high bit set, which the default branch
+            // produces.
+            switch (peeked) {
+                case '\\':
+                    if (flags & PM_ESCAPE_FLAG_META) {
+                        pm_parser_err_current(parser, PM_ERR_ESCAPE_INVALID_META_REPEAT);
+                        return;
+                    }
+                    parser->current.end++;
+                    escape_read(parser, buffer, flags | PM_ESCAPE_FLAG_META);
+                    return;
+                default:
+                    if (!char_is_ascii_printable(peeked)) {
+                        pm_parser_err_current(parser, PM_ERR_ESCAPE_INVALID_META);
+                        return;
+                    }
+                    parser->current.end++;
+                    pm_buffer_append_u8(buffer, escape_byte(peeked, flags | PM_ESCAPE_FLAG_META));
+                    return;
+            }
+        }
+        default: {
+            // Any other character escapes to itself. Nothing is appended when
+            // the backslash was the last byte of the input.
+            if (parser->current.end < parser->end) {
+                pm_buffer_append_u8(buffer, *parser->current.end++);
+            }
+            return;
+        }
+    }
+}
+
// This function is responsible for lexing either a character literal or the ?
// operator. The supported character literals are described below.
//
@@ -6108,11 +6442,15 @@ lex_question_mark(pm_parser_t *parser) {
lex_state_set(parser, PM_LEX_STATE_BEG);
- if (parser->current.start[1] == '\\') {
+ if (match(parser, '\\')) {
lex_state_set(parser, PM_LEX_STATE_END);
- parser->current.end += pm_unescape_calculate_difference(parser, parser->current.start + 1, PM_UNESCAPE_ALL, true);
- pm_string_shared_init(&parser->current_string, parser->current.start + 1, parser->current.end);
- pm_unescape_manipulate_char_literal(parser, &parser->current_string, PM_UNESCAPE_ALL);
+
+ pm_buffer_t buffer;
+ pm_buffer_init_capacity(&buffer, 3);
+
+ escape_read(parser, &buffer, PM_ESCAPE_FLAG_SINGLE);
+ pm_string_owned_init(&parser->current_string, (uint8_t *) buffer.value, buffer.length);
+
return PM_TOKEN_CHARACTER_LITERAL;
} else {
size_t encoding_width = parser->encoding.char_width(parser->current.end, parser->end - parser->current.end);
diff --git a/prism/util/pm_buffer.c b/prism/util/pm_buffer.c
index 0d84375767..55f6b0f7f3 100644
--- a/prism/util/pm_buffer.c
+++ b/prism/util/pm_buffer.c
@@ -1,24 +1,26 @@
#include "prism/util/pm_buffer.h"
-#define PRISM_BUFFER_INITIAL_SIZE 1024
-
// Return the size of the pm_buffer_t struct.
size_t
pm_buffer_sizeof(void) {
return sizeof(pm_buffer_t);
}
-// Initialize a pm_buffer_t with its default values.
+// Initialize a pm_buffer_t with the given capacity.
+//
+// The buffer starts out empty (length 0) with `capacity` bytes requested from
+// malloc. Returns true if the allocation succeeded and false if malloc
+// returned NULL.
+//
+// NOTE(review): the C standard allows malloc(0) to return NULL, so a capacity
+// of 0 may be reported as a failure -- confirm callers pass a non-zero value.
bool
-pm_buffer_init(pm_buffer_t *buffer) {
+pm_buffer_init_capacity(pm_buffer_t *buffer, size_t capacity) {
buffer->length = 0;
- buffer->capacity = PRISM_BUFFER_INITIAL_SIZE;
+ buffer->capacity = capacity;
- buffer->value = (char *) malloc(PRISM_BUFFER_INITIAL_SIZE);
+ buffer->value = (char *) malloc(capacity);
return buffer->value != NULL;
}
-#undef PRISM_BUFFER_INITIAL_SIZE
+// Initialize a pm_buffer_t with its default values: an empty buffer with an
+// initial capacity of 1024 bytes. Returns whatever pm_buffer_init_capacity
+// returns, i.e. false if the allocation failed.
+bool
+pm_buffer_init(pm_buffer_t *buffer) {
+ return pm_buffer_init_capacity(buffer, 1024);
+}
// Return the value of the buffer.
char *
diff --git a/prism/util/pm_buffer.h b/prism/util/pm_buffer.h
index 160d60bc58..d881b32441 100644
--- a/prism/util/pm_buffer.h
+++ b/prism/util/pm_buffer.h
@@ -21,6 +21,9 @@ typedef struct {
// Return the size of the pm_buffer_t struct.
PRISM_EXPORTED_FUNCTION size_t pm_buffer_sizeof(void);
+// Initialize a pm_buffer_t with the given capacity. Returns true if the
+// buffer's storage could be allocated, false otherwise.
+bool pm_buffer_init_capacity(pm_buffer_t *buffer, size_t capacity);
+
// Initialize a pm_buffer_t with its default values.
PRISM_EXPORTED_FUNCTION bool pm_buffer_init(pm_buffer_t *buffer);