about summary refs log tree commit diff stats
diff options
context:
space:
mode:
author: Kevin Newton <kddnewton@gmail.com>, 2024-06-05 11:45:38 -0400
committer: Kevin Newton <kddnewton@gmail.com>, 2024-06-05 14:40:03 -0400
commit: d13112b779d8dd8482f4a545655c37bdf860abf1 (patch)
tree: 2cc495a29f43f6342b91e86d298cb8c179a1ff3d
parent: 3cb866ce35f8dea504a1ec5c84c4f0ede51c7a32 (diff)
download: ruby-d13112b779d8dd8482f4a545655c37bdf860abf1.tar.gz
[ruby/prism] Parse all regular expressions
https://github.com/ruby/prism/commit/11e0e204ce
-rw-r--r-- prism/prism.c 103
-rw-r--r-- prism/regexp.c 37
-rw-r--r-- test/prism/onigmo_test.rb 2
3 files changed, 84 insertions, 58 deletions
diff --git a/prism/prism.c b/prism/prism.c
index aaad62d739..3eab96dea4 100644
--- a/prism/prism.c
+++ b/prism/prism.c
@@ -17391,6 +17391,51 @@ parse_yield(pm_parser_t *parser, const pm_node_t *node) {
}
/**
+ * Context handed to the error callback (parse_regular_expression_error) through the opaque data pointer of pm_regexp_parse.
+ * It carries the parser that receives the diagnostic and the fallback location used when the content string is not shared.
+ */
+typedef struct {
+ pm_parser_t *parser; // parser on which the regular expression error is recorded
+ const uint8_t *start; // start of the fallback error location, filled in by the caller
+ const uint8_t *end; // end of the fallback error location, filled in by the caller
+ bool shared; // true when the regular expression content string has type PM_STRING_SHARED
+} parse_regular_expression_error_data_t;
+
+/**
+ * Error callback given to pm_regexp_parse: records `message` on the parser as a PM_ERR_REGEXP_PARSE_ERROR diagnostic.
+ * `start`/`end` are the span reported by the regexp parser; `data` must point to a parse_regular_expression_error_data_t.
+ */
+static void
+parse_regular_expression_error(const uint8_t *start, const uint8_t *end, const char *message, void *data) {
+ parse_regular_expression_error_data_t *callback_data = (parse_regular_expression_error_data_t *) data;
+ pm_location_t location;
+
+ if (callback_data->shared) { // NOTE(review): presumably a shared string still aliases the source buffer, so the reported span is usable as-is -- confirm
+ location = (pm_location_t) { .start = start, .end = end }; // use the precise span reported by the regexp parser
+ } else {
+ location = (pm_location_t) { .start = callback_data->start, .end = callback_data->end }; // fall back to the location supplied by the caller
+ }
+
+ PM_PARSER_ERR_FORMAT(callback_data->parser, location.start, location.end, PM_ERR_REGEXP_PARSE_ERROR, message);
+}
+
+/**
+ * Run the regexp parser over the node's unescaped content purely to report syntax errors, attaching them to `parser`.
+ */
+static void
+parse_regular_expression_errors(pm_parser_t *parser, pm_regular_expression_node_t *node) {
+ const pm_string_t *unescaped = &node->unescaped;
+ parse_regular_expression_error_data_t error_data = {
+ .parser = parser,
+ .start = node->base.location.start, // fallback location is the whole regular expression node
+ .end = node->base.location.end,
+ .shared = unescaped->type == PM_STRING_SHARED
+ };
+
+ pm_regexp_parse(parser, pm_string_source(unescaped), pm_string_length(unescaped), NULL, NULL, parse_regular_expression_error, &error_data); // NULL, NULL: no named-capture callback or data here, only the error callback
+}
+
+/**
* Parse an expression that begins with the previous node that we just lexed.
*/
static inline pm_node_t *
@@ -19511,13 +19556,22 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, b
bool ascii_only = parser->current_regular_expression_ascii_only;
parser_lex(parser);
- // If we hit an end, then we can create a regular expression node
- // without interpolation, which can be represented more succinctly and
- // more easily compiled.
+ // If we hit an end, then we can create a regular expression
+ // node without interpolation, which can be represented more
+ // succinctly and more easily compiled.
if (accept1(parser, PM_TOKEN_REGEXP_END)) {
- pm_node_t *node = (pm_node_t *) pm_regular_expression_node_create_unescaped(parser, &opening, &content, &parser->previous, &unescaped);
- pm_node_flag_set(node, parse_and_validate_regular_expression_encoding(parser, &unescaped, ascii_only, node->flags));
- return node;
+ pm_regular_expression_node_t *node = (pm_regular_expression_node_t *) pm_regular_expression_node_create_unescaped(parser, &opening, &content, &parser->previous, &unescaped);
+
+ // If we're not immediately followed by a =~, then we want
+ // to parse all of the errors at this point. If it is
+ // followed by a =~, then it will get parsed higher up while
+ // parsing the named captures as well.
+ if (!match1(parser, PM_TOKEN_EQUAL_TILDE)) {
+ parse_regular_expression_errors(parser, node);
+ }
+
+ pm_node_flag_set((pm_node_t *) node, parse_and_validate_regular_expression_encoding(parser, &unescaped, ascii_only, node->base.flags));
+ return (pm_node_t *) node;
}
// If we get here, then we have interpolation so we'll need to create
@@ -20084,38 +20138,6 @@ parse_regular_expression_named_capture(const pm_string_t *capture, void *data) {
}
/**
- * This struct is used to pass information between the regular expression parser
- * and the error callback.
- */
-typedef struct {
- pm_parser_t *parser;
- const pm_string_t *content;
- const pm_call_node_t *call;
-} parse_regular_expression_error_data_t;
-
-/**
- * This callback is called when the regular expression parser encounters a
- * syntax error.
- */
-static void
-parse_regular_expression_error(const uint8_t *start, const uint8_t *end, const char *message, void *data) {
- parse_regular_expression_error_data_t *callback_data = (parse_regular_expression_error_data_t *) data;
-
- pm_parser_t *parser = callback_data->parser;
- const pm_string_t *content = callback_data->content;
- const pm_call_node_t *call = callback_data->call;
-
- pm_location_t location;
- if (content->type == PM_STRING_SHARED) {
- location = (pm_location_t) { .start = start, .end = end };
- } else {
- location = call->receiver->location;
- }
-
- PM_PARSER_ERR_FORMAT(parser, location.start, location.end, PM_ERR_REGEXP_PARSE_ERROR, message);
-}
-
-/**
* Potentially change a =~ with a regular expression with named captures into a
* match write node.
*/
@@ -20130,8 +20152,9 @@ parse_regular_expression_named_captures(pm_parser_t *parser, const pm_string_t *
parse_regular_expression_error_data_t error_data = {
.parser = parser,
- .content = content,
- .call = call
+ .start = call->receiver->location.start,
+ .end = call->receiver->location.end,
+ .shared = content->type == PM_STRING_SHARED
};
pm_regexp_parse(parser, pm_string_source(content), pm_string_length(content), parse_regular_expression_named_capture, &callback_data, parse_regular_expression_error, &error_data);
diff --git a/prism/regexp.c b/prism/regexp.c
index 65ab573c85..9eea90e12f 100644
--- a/prism/regexp.c
+++ b/prism/regexp.c
@@ -225,21 +225,24 @@ pm_regexp_parse_range_quantifier(pm_regexp_parser_t *parser) {
*/
static bool
pm_regexp_parse_quantifier(pm_regexp_parser_t *parser) {
- if (pm_regexp_char_is_eof(parser)) return true;
-
- switch (*parser->cursor) {
- case '*':
- case '+':
- case '?':
- parser->cursor++;
- return true;
- case '{':
- parser->cursor++;
- return pm_regexp_parse_range_quantifier(parser);
- default:
- // In this case there is no quantifier.
- return true;
+ while (!pm_regexp_char_is_eof(parser)) { // loop so that several consecutive quantifiers are all consumed
+ switch (*parser->cursor) {
+ case '*':
+ case '+':
+ case '?':
+ parser->cursor++;
+ break; // leaves the switch only; the while loop then examines the next character
+ case '{':
+ parser->cursor++;
+ if (!pm_regexp_parse_range_quantifier(parser)) return false; // propagate a failed range quantifier to the caller
+ break;
+ default:
+ // In this case there is no quantifier.
+ return true;
+ }
}
+
+ return true; // reached the end of the input while consuming quantifiers
}
/**
@@ -276,7 +279,7 @@ pm_regexp_parse_character_set(pm_regexp_parser_t *parser, uint16_t depth) {
while (!pm_regexp_char_is_eof(parser) && *parser->cursor != ']') {
switch (*parser->cursor++) {
case '[':
- pm_regexp_parse_lbracket(parser, depth + 1);
+ pm_regexp_parse_lbracket(parser, (uint16_t) (depth + 1));
break;
case '\\':
if (!pm_regexp_char_is_eof(parser)) {
@@ -584,7 +587,7 @@ pm_regexp_parse_group(pm_regexp_parser_t *parser, uint16_t depth) {
// Now, parse the expressions within this group.
while (!pm_regexp_char_is_eof(parser) && *parser->cursor != ')') {
- if (!pm_regexp_parse_expression(parser, depth + 1)) {
+ if (!pm_regexp_parse_expression(parser, (uint16_t) (depth + 1))) {
return false;
}
pm_regexp_char_accept(parser, '|');
@@ -615,7 +618,7 @@ pm_regexp_parse_item(pm_regexp_parser_t *parser, uint16_t depth) {
case '^':
case '$':
parser->cursor++;
- return true;
+ return pm_regexp_parse_quantifier(parser);
case '\\':
parser->cursor++;
if (!pm_regexp_char_is_eof(parser)) {
diff --git a/test/prism/onigmo_test.rb b/test/prism/onigmo_test.rb
index c8aa2317b8..03f44c4e4c 100644
--- a/test/prism/onigmo_test.rb
+++ b/test/prism/onigmo_test.rb
@@ -54,7 +54,7 @@ module Prism
private
def assert_error(source, message)
- result = Prism.parse(%Q{/#{source}/ =~ ""})
+ result = Prism.parse("/#{source}/")
assert result.failure?, "Expected #{source.inspect} to error"
assert_equal message, result.errors.first.message