From 2fa051f627172674b342da26ebe1e671a5e449ec Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Tue, 13 Feb 2024 17:45:27 -0500 Subject: [ruby/prism] Validate multibyte characters in strings Check that multibyte characters are valid using pm_strpbrk. We need to add a couple of codepaths to ensure all encodings are covered. Importantly this doesn't check regular expressions, because apparently you're allowed to have invalid multibyte characters inside regular expression comment groups/extended mode. https://github.com/ruby/prism/commit/2857d3e1b5 --- prism/encoding.c | 20 ++++--- prism/encoding.h | 7 +++ prism/prism.c | 50 +++++++++--------- prism/util/pm_strpbrk.c | 136 +++++++++++++++++++++++++++++++++++++++++++----- prism/util/pm_strpbrk.h | 5 +- 5 files changed, 172 insertions(+), 46 deletions(-) diff --git a/prism/encoding.c b/prism/encoding.c index 981945caba..1d455c2421 100644 --- a/prism/encoding.c +++ b/prism/encoding.c @@ -2253,12 +2253,12 @@ static const uint8_t pm_utf_8_dfa[] = { static pm_unicode_codepoint_t pm_utf_8_codepoint(const uint8_t *b, ptrdiff_t n, size_t *width) { assert(n >= 0); - size_t maximum = (size_t) n; + size_t maximum = (n > 4) ? 4 : ((size_t) n); uint32_t codepoint; uint32_t state = 0; - for (size_t index = 0; index < 4 && index < maximum; index++) { + for (size_t index = 0; index < maximum; index++) { uint32_t byte = b[index]; uint32_t type = pm_utf_8_dfa[byte]; @@ -2267,7 +2267,7 @@ pm_utf_8_codepoint(const uint8_t *b, ptrdiff_t n, size_t *width) { (0xffu >> type) & (byte); state = pm_utf_8_dfa[256 + (state * 16) + type]; - if (!state) { + if (state == 0) { *width = index + 1; return (pm_unicode_codepoint_t) codepoint; } @@ -2282,9 +2282,17 @@ pm_utf_8_codepoint(const uint8_t *b, ptrdiff_t n, size_t *width) { */ size_t pm_encoding_utf_8_char_width(const uint8_t *b, ptrdiff_t n) { - size_t width; - pm_utf_8_codepoint(b, n, &width); - return width; + assert(n >= 0); + + size_t maximum = (n > 4) ? 
4 : ((size_t) n); + uint32_t state = 0; + + for (size_t index = 0; index < maximum; index++) { + state = pm_utf_8_dfa[256 + (state * 16) + pm_utf_8_dfa[b[index]]]; + if (state == 0) return index + 1; + } + + return 0; } /** diff --git a/prism/encoding.h b/prism/encoding.h index 7ba1695de8..d0f947eacd 100644 --- a/prism/encoding.h +++ b/prism/encoding.h @@ -245,6 +245,13 @@ extern const pm_encoding_t pm_encodings[PM_ENCODING_MAXIMUM]; */ #define PM_ENCODING_US_ASCII_ENTRY (&pm_encodings[PM_ENCODING_US_ASCII]) +/** + * This is the ASCII-8BIT encoding. We need a reference to it so that pm_strpbrk + * can compare against it because invalid multibyte characters are not a thing + * in this encoding. + */ +#define PM_ENCODING_ASCII_8BIT_ENTRY (&pm_encodings[PM_ENCODING_ASCII_8BIT]) + /** * Parse the given name of an encoding and return a pointer to the corresponding * encoding struct if one can be found, otherwise return NULL. diff --git a/prism/prism.c b/prism/prism.c index d0b4265190..972f813b79 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -9737,7 +9737,7 @@ parser_lex(pm_parser_t *parser) { // and then find the first one. pm_lex_mode_t *lex_mode = parser->lex_modes.current; const uint8_t *breakpoints = lex_mode->as.list.breakpoints; - const uint8_t *breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + const uint8_t *breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); // If we haven't found an escape yet, then this buffer will be // unallocated since we can refer directly to the source string. @@ -9746,7 +9746,7 @@ parser_lex(pm_parser_t *parser) { while (breakpoint != NULL) { // If we hit a null byte, skip directly past it. 
if (*breakpoint == '\0') { - breakpoint = pm_strpbrk(parser, breakpoint + 1, breakpoints, parser->end - (breakpoint + 1)); + breakpoint = pm_strpbrk(parser, breakpoint + 1, breakpoints, parser->end - (breakpoint + 1), true); continue; } @@ -9765,7 +9765,7 @@ parser_lex(pm_parser_t *parser) { // we need to continue on past it. if (lex_mode->as.list.nesting > 0) { parser->current.end = breakpoint + 1; - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); lex_mode->as.list.nesting--; continue; } @@ -9850,7 +9850,7 @@ parser_lex(pm_parser_t *parser) { } token_buffer.cursor = parser->current.end; - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); continue; } @@ -9863,7 +9863,7 @@ parser_lex(pm_parser_t *parser) { // that looked like an interpolated class or instance variable // like "#@" but wasn't actually. In this case we'll just skip // to the next breakpoint. - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); continue; } @@ -9878,7 +9878,7 @@ parser_lex(pm_parser_t *parser) { // and find the next breakpoint. assert(*breakpoint == lex_mode->as.list.incrementor); parser->current.end = breakpoint + 1; - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); lex_mode->as.list.nesting++; continue; } @@ -9917,14 +9917,14 @@ parser_lex(pm_parser_t *parser) { // regular expression. We'll use strpbrk to find the first of these // characters. 
const uint8_t *breakpoints = lex_mode->as.regexp.breakpoints; - const uint8_t *breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + const uint8_t *breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, false); pm_token_buffer_t token_buffer = { { 0 }, 0 }; while (breakpoint != NULL) { // If we hit a null byte, skip directly past it. if (*breakpoint == '\0') { parser->current.end = breakpoint + 1; - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, false); continue; } @@ -9946,7 +9946,7 @@ parser_lex(pm_parser_t *parser) { // If the terminator is not a newline, then we can set // the next breakpoint and continue. parser->current.end = breakpoint + 1; - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, false); continue; } } @@ -9956,7 +9956,7 @@ parser_lex(pm_parser_t *parser) { if (*breakpoint == lex_mode->as.regexp.terminator) { if (lex_mode->as.regexp.nesting > 0) { parser->current.end = breakpoint + 1; - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, false); lex_mode->as.regexp.nesting--; continue; } @@ -10055,7 +10055,7 @@ parser_lex(pm_parser_t *parser) { } token_buffer.cursor = parser->current.end; - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, false); continue; } @@ -10068,7 +10068,7 @@ parser_lex(pm_parser_t *parser) { // something that looked 
like an interpolated class or // instance variable like "#@" but wasn't actually. In // this case we'll just skip to the next breakpoint. - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, false); continue; } @@ -10083,7 +10083,7 @@ parser_lex(pm_parser_t *parser) { // and find the next breakpoint. assert(*breakpoint == lex_mode->as.regexp.incrementor); parser->current.end = breakpoint + 1; - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, false); lex_mode->as.regexp.nesting++; continue; } @@ -10119,7 +10119,7 @@ parser_lex(pm_parser_t *parser) { // string. We'll use strpbrk to find the first of these characters. pm_lex_mode_t *lex_mode = parser->lex_modes.current; const uint8_t *breakpoints = lex_mode->as.string.breakpoints; - const uint8_t *breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + const uint8_t *breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); // If we haven't found an escape yet, then this buffer will be // unallocated since we can refer directly to the source string. @@ -10131,7 +10131,7 @@ parser_lex(pm_parser_t *parser) { if (lex_mode->as.string.incrementor != '\0' && *breakpoint == lex_mode->as.string.incrementor) { lex_mode->as.string.nesting++; parser->current.end = breakpoint + 1; - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); continue; } @@ -10143,7 +10143,7 @@ parser_lex(pm_parser_t *parser) { // to continue on past it. 
if (lex_mode->as.string.nesting > 0) { parser->current.end = breakpoint + 1; - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); lex_mode->as.string.nesting--; continue; } @@ -10185,7 +10185,7 @@ parser_lex(pm_parser_t *parser) { if (parser->heredoc_end == NULL) { pm_newline_list_append(&parser->newline_list, breakpoint); parser->current.end = breakpoint + 1; - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); continue; } else { parser->current.end = breakpoint + 1; @@ -10199,7 +10199,7 @@ parser_lex(pm_parser_t *parser) { case '\0': // Skip directly past the null character. parser->current.end = breakpoint + 1; - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); break; case '\\': { // Here we hit escapes. @@ -10268,7 +10268,7 @@ parser_lex(pm_parser_t *parser) { } token_buffer.cursor = parser->current.end; - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); break; } case '#': { @@ -10279,7 +10279,7 @@ parser_lex(pm_parser_t *parser) { // looked like an interpolated class or instance variable like "#@" // but wasn't actually. In this case we'll just skip to the next // breakpoint. 
- breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); break; } @@ -10407,7 +10407,7 @@ parser_lex(pm_parser_t *parser) { breakpoints[2] = '\0'; } - const uint8_t *breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + const uint8_t *breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); pm_token_buffer_t token_buffer = { { 0 }, 0 }; bool was_escaped_newline = false; @@ -10416,7 +10416,7 @@ parser_lex(pm_parser_t *parser) { case '\0': // Skip directly past the null character. parser->current.end = breakpoint + 1; - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); break; case '\n': { if (parser->heredoc_end != NULL && (parser->heredoc_end > breakpoint)) { @@ -10491,7 +10491,7 @@ parser_lex(pm_parser_t *parser) { // Otherwise we hit a newline and it wasn't followed by // a terminator, so we can continue parsing. parser->current.end = breakpoint + 1; - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); break; } case '\\': { @@ -10555,7 +10555,7 @@ parser_lex(pm_parser_t *parser) { } token_buffer.cursor = parser->current.end; - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); break; } case '#': { @@ -10567,7 +10567,7 @@ parser_lex(pm_parser_t *parser) { // or instance variable like "#@" but wasn't // actually. 
In this case we'll just skip to the // next breakpoint. - breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end); + breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); break; } diff --git a/prism/util/pm_strpbrk.c b/prism/util/pm_strpbrk.c index 115eba1fd2..6c8dea1836 100644 --- a/prism/util/pm_strpbrk.c +++ b/prism/util/pm_strpbrk.c @@ -1,10 +1,18 @@ #include "prism/util/pm_strpbrk.h" /** - * This is the slow path that does care about the encoding. + * Add an invalid multibyte character error to the parser. + */ +static inline void +pm_strpbrk_invalid_multibyte_character(pm_parser_t *parser, const uint8_t *start, const uint8_t *end) { + pm_diagnostic_list_append_format(&parser->error_list, start, end, PM_ERR_INVALID_MULTIBYTE_CHARACTER, *start); +} + +/** + * This is the default path. */ static inline const uint8_t * -pm_strpbrk_multi_byte(const pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t maximum) { +pm_strpbrk_utf8(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t maximum, bool validate) { size_t index = 0; while (index < maximum) { @@ -12,22 +20,39 @@ pm_strpbrk_multi_byte(const pm_parser_t *parser, const uint8_t *source, const ui return source + index; } - size_t width = parser->encoding->char_width(source + index, (ptrdiff_t) (maximum - index)); - if (width == 0) { - return NULL; - } + if (source[index] < 0x80) { + index++; + } else { + size_t width = pm_encoding_utf_8_char_width(source + index, (ptrdiff_t) (maximum - index)); - index += width; + if (width > 0) { + index += width; + } else if (!validate) { + index++; + } else { + // At this point we know we have an invalid multibyte character. + // We'll walk forward as far as we can until we find the next + // valid character so that we don't spam the user with a ton of + // the same kind of error. 
+ const size_t start = index; + + do { + index++; + } while (index < maximum && pm_encoding_utf_8_char_width(source + index, (ptrdiff_t) (maximum - index)) == 0); + + pm_strpbrk_invalid_multibyte_character(parser, source + start, source + index); + } + } } return NULL; } /** - * This is the fast path that does not care about the encoding. + * This is the path when the encoding is ASCII-8BIT. */ static inline const uint8_t * -pm_strpbrk_single_byte(const uint8_t *source, const uint8_t *charset, size_t maximum) { +pm_strpbrk_ascii_8bit(const uint8_t *source, const uint8_t *charset, size_t maximum) { size_t index = 0; while (index < maximum) { @@ -41,6 +66,85 @@ pm_strpbrk_single_byte(const uint8_t *source, const uint8_t *charset, size_t max return NULL; } +/** + * This is the slow path that does care about the encoding. + */ +static inline const uint8_t * +pm_strpbrk_multi_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t maximum, bool validate) { + size_t index = 0; + + while (index < maximum) { + if (strchr((const char *) charset, source[index]) != NULL) { + return source + index; + } + + if (source[index] < 0x80) { + index++; + } else { + size_t width = parser->encoding->char_width(source + index, (ptrdiff_t) (maximum - index)); + + if (width > 0) { + index += width; + } else if (!validate) { + index++; + } else { + // At this point we know we have an invalid multibyte character. + // We'll walk forward as far as we can until we find the next + // valid character so that we don't spam the user with a ton of + // the same kind of error. 
+ const size_t start = index; + + do { + index++; + } while (index < maximum && parser->encoding->char_width(source + index, (ptrdiff_t) (maximum - index)) == 0); + + pm_strpbrk_invalid_multibyte_character(parser, source + start, source + index); + } + } + } + + return NULL; +} + +/** + * This is the fast path that does not care about the encoding because we know + * the encoding only supports single-byte characters. + */ +static inline const uint8_t * +pm_strpbrk_single_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t maximum, bool validate) { + size_t index = 0; + + while (index < maximum) { + if (strchr((const char *) charset, source[index]) != NULL) { + return source + index; + } + + if (source[index] < 0x80 || !validate) { + index++; + } else { + size_t width = parser->encoding->char_width(source + index, (ptrdiff_t) (maximum - index)); + + if (width > 0) { + index += width; + } else { + // At this point we know we have an invalid multibyte character. + // We'll walk forward as far as we can until we find the next + // valid character so that we don't spam the user with a ton of + // the same kind of error. + const size_t start = index; + + do { + index++; + } while (index < maximum && parser->encoding->char_width(source + index, (ptrdiff_t) (maximum - index)) == 0); + + pm_strpbrk_invalid_multibyte_character(parser, source + start, source + index); + } + } + } + + return NULL; +} + /** * Here we have rolled our own version of strpbrk. The standard library strpbrk * has undefined behavior when the source string is not null-terminated. We want @@ -57,16 +161,20 @@ pm_strpbrk_single_byte(const uint8_t *source, const uint8_t *charset, size_t max * * Finally, we want to support encodings wherein the charset could contain * characters that are trailing bytes of multi-byte characters. For example, in - * Shift-JIS, the backslash character can be a trailing byte. 
In that case we + * Shift_JIS, the backslash character can be a trailing byte. In that case we * need to take a slower path and iterate one multi-byte character at a time. */ const uint8_t * -pm_strpbrk(const pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, ptrdiff_t length) { +pm_strpbrk(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, ptrdiff_t length, bool validate) { if (length <= 0) { return NULL; - } else if (parser->encoding_changed && parser->encoding->multibyte) { - return pm_strpbrk_multi_byte(parser, source, charset, (size_t) length); + } else if (!parser->encoding_changed) { + return pm_strpbrk_utf8(parser, source, charset, (size_t) length, validate); + } else if (parser->encoding == PM_ENCODING_ASCII_8BIT_ENTRY) { + return pm_strpbrk_ascii_8bit(source, charset, (size_t) length); + } else if (parser->encoding->multibyte) { + return pm_strpbrk_multi_byte(parser, source, charset, (size_t) length, validate); } else { - return pm_strpbrk_single_byte(source, charset, (size_t) length); + return pm_strpbrk_single_byte(parser, source, charset, (size_t) length, validate); } } diff --git a/prism/util/pm_strpbrk.h b/prism/util/pm_strpbrk.h index c1cf0d54db..f387bd5782 100644 --- a/prism/util/pm_strpbrk.h +++ b/prism/util/pm_strpbrk.h @@ -7,6 +7,7 @@ #define PRISM_STRPBRK_H #include "prism/defines.h" +#include "prism/diagnostic.h" #include "prism/parser.h" #include <stddef.h> @@ -35,9 +36,11 @@ * @param source The source to search. * @param charset The charset to search for. * @param length The maximum number of bytes to search. + * @param validate Whether to validate that the source string is valid in the + * current encoding of the parser. * @return A pointer to the first character in the source string that is in the * charset, or NULL if no such character exists.
*/ -const uint8_t * pm_strpbrk(const pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, ptrdiff_t length); +const uint8_t * pm_strpbrk(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, ptrdiff_t length, bool validate); #endif -- cgit v1.2.3