Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion packages/bolt/src/bolt/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
AstChildren,
AstCommand,
AstCommandSentinel,
AstError,
AstJson,
AstLiteral,
AstNode,
Expand Down Expand Up @@ -493,7 +494,7 @@ class AstProcMacroMarker(AstNode):
class AstProcMacroResult(AstNode):
"""Ast proc macro result node."""

commands: AstChildren[AstCommand] = required_field()
commands: AstChildren[AstCommand|AstError] = required_field()


@dataclass(frozen=True, slots=True)
Expand Down
10 changes: 5 additions & 5 deletions packages/bolt/src/bolt/emit.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from contextlib import contextmanager
from typing import Any, Callable, Generator, Iterator, List, Optional, ParamSpec, Tuple

from mecha import AstCommand, AstNode, AstRoot
from mecha import AstCommand, AstError, AstNode, AstRoot

from .utils import internal

Expand All @@ -16,7 +16,7 @@
class CommandEmitter:
"""Command emitter."""

commands: List[AstCommand]
commands: List[AstCommand|AstError]
nesting: List[Tuple[str, Tuple[AstNode, ...]]]

def __init__(self):
Expand All @@ -26,8 +26,8 @@ def __init__(self):
@contextmanager
def scope(
self,
commands: Optional[List[AstCommand]] = None,
) -> Iterator[List[AstCommand]]:
commands: Optional[List[AstCommand|AstError]] = None,
) -> Iterator[List[AstCommand|AstError]]:
"""Create a new scope to gather commands."""
if commands is None:
commands = []
Expand All @@ -46,7 +46,7 @@ def capture_output(
f: Callable[P, Any],
*args: P.args,
**kwargs: P.kwargs,
) -> List[AstCommand]:
) -> List[AstCommand|AstError]:
"""Invoke a user-defined function and return the list of generated commands."""
with self.scope() as output:
result = f(*args, **kwargs)
Expand Down
56 changes: 47 additions & 9 deletions packages/bolt/src/bolt/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@
AdjacentConstraint,
AlternativeParser,
AstChildren,
AstError,
AstCommand,
AstJson,
AstNode,
Expand Down Expand Up @@ -968,9 +969,13 @@ def __call__(self, stream: TokenStream) -> Any:

def resolve(self, node: AstRoot) -> AstRoot:
should_replace = False
commands: List[AstCommand] = []
commands: List[AstCommand|AstError] = []

for command in node.commands:
if isinstance(command, AstError):
commands.append(command)
continue

stack: List[AstCommand] = [command]

while command.arguments and isinstance(
Expand Down Expand Up @@ -1100,9 +1105,13 @@ def __call__(self, stream: TokenStream) -> AstRoot:
stack: List[AstDecorator] = []

changed = False
result: List[AstCommand] = []
result: List[AstCommand|AstError] = []

for command in node.commands:
if isinstance(command, AstError):
result.append(command)
continue

if isinstance(command, AstStatement) and isinstance(
decorator := command.arguments[0], AstDecorator
):
Expand Down Expand Up @@ -1185,9 +1194,13 @@ def __call__(self, stream: TokenStream) -> Any:
return node

changed = False
result: List[AstCommand] = []
result: List[AstCommand|AstError] = []

for command in node.commands:
if isinstance(command, AstError):
result.append(command)
continue

if command.identifier == "return:value" and command.arguments:
changed = True

Expand Down Expand Up @@ -1222,12 +1235,16 @@ def __call__(self, stream: TokenStream) -> AstRoot:
node: AstRoot = self.parser(stream)

changed = False
result: List[AstCommand] = []
result: List[AstError|AstCommand] = []

commands = iter(node.commands)
previous = ""

for command in commands:
if isinstance(command, AstError):
result.append(command)
continue

if command.identifier in ["elif:condition:body", "else:body"]:
if previous not in ["if:condition:body", "elif:condition:body"]:
exc = InvalidSyntax(
Expand All @@ -1240,6 +1257,9 @@ def __call__(self, stream: TokenStream) -> AstRoot:
elif_chain = [command]

for command in commands:
if isinstance(command, AstError):
continue

if command.identifier not in ["elif:condition:body", "else:body"]:
break
elif_chain.append(command)
Expand Down Expand Up @@ -1291,6 +1311,9 @@ def __call__(self, stream: TokenStream) -> AstRoot:

if not loop:
for command in node.commands:
if isinstance(command, AstError):
continue

if command.identifier in ["break", "continue"]:
exc = InvalidSyntax(
f'Can only use "{command.identifier}" in loops.'
Expand Down Expand Up @@ -1350,7 +1373,6 @@ def parse_function_signature(stream: TokenStream) -> AstFunctionSignature:
stream.expect(("brace", "("))

node = set_location(AstFunctionSignature(name=identifier.value), identifier)
lexical_scope.bind_variable(identifier.value, node)

deferred_scope = lexical_scope.deferred(FunctionScope)

Expand Down Expand Up @@ -1469,6 +1491,7 @@ def parse_function_signature(stream: TokenStream) -> AstFunctionSignature:
exc = InvalidSyntax(
"Expected at least one named argument after bare variadic marker."
)

raise set_location(exc, argument)

return_type_annotation = None
Expand All @@ -1481,6 +1504,7 @@ def parse_function_signature(stream: TokenStream) -> AstFunctionSignature:
arguments=AstChildren(arguments),
return_type_annotation=return_type_annotation,
)
lexical_scope.bind_variable(identifier.value, node)
return set_location(node, node, stream.current)


Expand Down Expand Up @@ -1755,11 +1779,15 @@ class ProcMacroExpansion:

def __call__(self, stream: TokenStream) -> AstRoot:
should_replace = False
commands: List[AstCommand] = []
commands: List[AstCommand|AstError] = []

node: AstRoot = self.parser(stream)

for command in node.commands:
if isinstance(command, AstError):
commands.append(command)
continue

stack: List[AstCommand] = [command]

while command.arguments and isinstance(
Expand Down Expand Up @@ -2074,9 +2102,13 @@ def __call__(self, stream: TokenStream) -> Any:
node: AstRoot = self.parser(stream)

changed = False
result: List[AstCommand] = []
result: List[AstCommand|AstError] = []

for command in node.commands:
if isinstance(command, AstError):
result.append(command)
continue

if (
isinstance(command, AstStatement)
and command.arguments
Expand Down Expand Up @@ -2198,9 +2230,13 @@ def __call__(self, stream: TokenStream) -> AstRoot:

def resolve_deferred(self, node: AstRoot, stream: TokenStream) -> AstRoot:
should_replace = False
commands: List[AstCommand] = []
commands: List[AstCommand|AstError] = []

for command in node.commands:
if isinstance(command, AstError):
commands.append(command)
continue

if command.arguments and isinstance(
body := command.arguments[-1], AstClassRoot
):
Expand Down Expand Up @@ -2283,6 +2319,9 @@ def __call__(self, stream: TokenStream) -> AstRoot:

if isinstance(node, AstRoot):
for command in node.commands:
if isinstance(command, AstError):
continue

if command.identifier in self.command_identifiers:
name, _, _ = command.identifier.partition(":")
exc = InvalidSyntax(
Expand Down Expand Up @@ -2402,7 +2441,6 @@ def __call__(self, stream: TokenStream) -> Any:
node = AstUnpack(type="dict" if prefix.value == "**" else "list", value=node)
return set_location(node, prefix, node.value)


@dataclass
class UnpackConstraint:
"""Constraint for unpacking."""
Expand Down
8 changes: 6 additions & 2 deletions packages/bolt/src/bolt/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
"NonFunctionSerializer",
]


import builtins
from dataclasses import dataclass, field
from functools import partial
Expand All @@ -15,6 +14,7 @@
from beet.core.utils import JsonDict, extra_field, required_field
from mecha import (
AstRoot,
AstError,
CommandSpec,
CommandTree,
CompilationDatabase,
Expand All @@ -30,7 +30,7 @@
from pathspec import PathSpec
from tokenstream import set_location

from .ast import AstNonFunctionRoot
from .ast import AstNonFunctionRoot, AstRoot
from .codegen import Codegen
from .emit import CommandEmitter
from .helpers import get_bolt_helpers
Expand Down Expand Up @@ -275,6 +275,10 @@ def non_function_root(self, node: AstNonFunctionRoot, result: List[str]):
result.append(source)
if node.commands:
command = node.commands[0]

if isinstance(command, AstError):
return None

name = command.identifier.partition(":")[0]
d = Diagnostic(
"warn", f'Ignored top-level "{name}" command outside function.'
Expand Down
4 changes: 2 additions & 2 deletions packages/bolt/tests/test_sandbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import pytest
from beet import Context, Function
from beet.core.utils import format_exc
from mecha import CompilationError, DiagnosticError, Mecha
from mecha import CompilationError, DiagnosticCollection, DiagnosticError, Mecha
from pytest_insta import SnapshotFixture

SANDBOX_EXAMPLES = [
Expand All @@ -22,7 +22,7 @@
def test_run(snapshot: SnapshotFixture, ctx_sandbox: Context, source: Function):
mc = ctx_sandbox.inject(Mecha)

with pytest.raises((CompilationError, DiagnosticError)) as exc_info:
with pytest.raises((CompilationError, DiagnosticError, DiagnosticCollection)) as exc_info:
mc.compile(source, resource_location="demo:foo")

details = (
Expand Down
51 changes: 43 additions & 8 deletions packages/mecha/src/mecha/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@
from pydantic import BaseModel, field_validator
from tokenstream import InvalidSyntax, Preprocessor, TokenStream, set_location

from .ast import AstLiteral, AstNode, AstRoot
from .ast import AstError, AstLiteral, AstNode, AstRoot
from .config import CommandTree
from .database import (
CompilationDatabase,
Expand All @@ -69,7 +69,7 @@
DiagnosticErrorSummary,
)
from .dispatch import Dispatcher, MutatingReducer, Reducer
from .parse import delegate, get_parsers
from .parse import InvalidSyntaxCollection, delegate, get_parsers
from .preprocess import wrap_backslash_continuation
from .serialize import FormattingOptions, Serializer
from .spec import CommandSpec
Expand Down Expand Up @@ -381,13 +381,30 @@ def parse(
preprocessor=preprocessor or self.preprocessor,
)

diagnostics = []
try:
with self.prepare_token_stream(stream, multiline=multiline):
with stream.provide(**provide or {}):
ast = delegate(parser, stream)
ast: AstRoot = self.parse_stream(multiline, provide, parser, stream)

errors: list[InvalidSyntax] = []
for node in ast.walk():
if isinstance(node, AstError):
errors.append(node.error)

if len(errors) >= 1:
raise InvalidSyntaxCollection(errors)

except InvalidSyntaxCollection as collection:
for exc in collection.errors:
d = Diagnostic(
level="error",
message=str(exc),
notes=exc.notes,
hint=resource_location,
filename=str(filename) if filename else None,
file=source,
)
diagnostics.append(set_location(d, exc))
except InvalidSyntax as exc:
if self.cache and filename and cache_miss:
self.cache.invalidate_changes(self.directory / filename)
d = Diagnostic(
level="error",
message=str(exc),
Expand All @@ -396,7 +413,7 @@ def parse(
filename=str(filename) if filename else None,
file=source,
)
raise DiagnosticError(DiagnosticCollection([set_location(d, exc)])) from exc
diagnostics.append(set_location(d, exc))
else:
if self.cache and filename and cache_miss:
try:
Expand All @@ -407,6 +424,24 @@ def parse(
pass
return ast

if len(diagnostics) > 0:
if self.cache and filename and cache_miss:
self.cache.invalidate_changes(self.directory / filename)

raise DiagnosticError(DiagnosticCollection(diagnostics))

def parse_stream(
self,
multiline: bool | None,
provide: JsonDict | None,
parser: str,
stream: TokenStream,
):
with self.prepare_token_stream(stream, multiline=multiline):
with stream.provide(**provide or {}):
ast = delegate(parser, stream)
return ast

@overload
def compile(
self,
Expand Down
Loading
Loading