Source code for xotl.fl.ast.pattern

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# ---------------------------------------------------------------------
# Copyright (c) Merchise Autrement [~º/~] and Contributors
# All rights reserved.
#
# This is free software; you can do what the LICENCE file allows you to.
#
"""Pattern Matching.

"""
from typing import (
    Iterator,
    List,
    Mapping,
    MutableMapping,
    Union,
    Sequence,
    Tuple,
    Type as Class,
)
from dataclasses import dataclass
from itertools import groupby

from xotl.tools.objects import memoized_property
from xotl.tools.fp.tools import fst, snd

from xotl.fl.ast.base import AST, ILC
from xotl.fl.ast.types import TypeEnvironment
from xotl.fl.ast.expressions import Literal, _LetExpr, Let, Letrec, find_free_names
from xotl.fl.ast.adt import DataCons


# Patterns and Equations.  In the final AST, an expression like:
#
#     let id x = x in ...
#
# would actually be like
#
#     let id = \x -> x
#
# with some complications if pattern-matching is allowed:
#
#     let length [] = 0
#         lenght (x:xs) = 1 + length xs
#     in  ...
#
# The Pattern and Equation definitions are not part of the final AST, but more
# concrete syntactical object in the source code.  In the final AST, the let
# expressions shown above are indistinguishable.
#
# For value (function) definitions the parser still returns *bare* Equation
# object for each line of the definition.


UnnamedPattern = Union[str, Literal, "ConsPattern"]
Pattern = Union[str, Literal, "ConsPattern", "NamedPattern"]


[docs]class ConsPattern(AST): """The syntactical notion of a pattern. """ def __init__(self, cons: str, params: Sequence[Pattern] = None) -> None: self.cons: str = cons self.params: Tuple[Pattern, ...] = tuple(params or []) def __repr__(self): return f"<pattern {self.cons!r} {self.params!r}>" def __str__(self): if self.params: return f"{self.cons} {self._parameters}" else: return self.cons @property def _parameters(self): def _str(x): if isinstance(x, str): return x elif isinstance(x, ConsPattern): return f"({x})" else: return repr(x) return " ".join(map(_str, self.params)) def __eq__(self, other): if isinstance(other, ConsPattern): return self.cons == other.cons and self.params == other.params else: return NotImplemented def __hash__(self): return hash((ConsPattern, self.cons, self.params)) @property def bindings(self) -> Iterator[str]: for param in self.params: if isinstance(param, str) and param != "_": yield param elif isinstance(param, ConsPattern): yield from param.bindings elif isinstance(param, NamedPattern): yield from param.bindings
class NamedPattern(AST): def __init__(self, name: str, pattern: UnnamedPattern) -> None: self.name = name self.pattern = pattern def __str__(self): return f"{self.name} @ {self.pattern}" @property def bindings(self) -> Iterator[str]: yield self.name pattern = self.pattern if isinstance(pattern, str) and pattern != "_": yield pattern elif isinstance(pattern, ConsPattern): yield from pattern.bindings elif isinstance(pattern, NamedPattern): yield from pattern.bindings
[docs]class Equation(AST): """The syntactical notion of an equation. """ def __init__(self, name: str, patterns: Sequence[Pattern], body: AST) -> None: self.name = name self.patterns: Tuple[Pattern, ...] = tuple(patterns or []) self.body = body self._check_non_repeated_vars() def _check_non_repeated_vars(self): names = list(n for n in self.bindings) if len(names) != len(set(names)): raise ValueError(f"Repeated bindings in patterns: {self!s}") def __repr__(self): def _str(x): result = str(x) if " " in result: return f"({result})" else: return result if self.patterns: args = " ".join(map(_str, self.patterns)) return f"<equation {self.name!s} {args} = {self.body!r}>" else: return f"<equation {self.name!s} = {self.body!r}>" def __eq__(self, other): if isinstance(other, Equation): return ( self.name == other.name and self.patterns == other.patterns and self.body == other.body ) else: return NotImplemented def __hash__(self): return hash((Equation, self.name, self.patterns, self.body)) @property def bindings(self) -> Iterator[str]: """The names bound in the arguments""" for pattern in self.patterns: if isinstance(pattern, str) and pattern != "_": yield pattern elif isinstance(pattern, ConsPattern): yield from pattern.bindings
LocalDefinition = Union[Equation, TypeEnvironment] ValueDefinitions = Mapping[str, List[Equation]]
[docs]@dataclass(init=False, unsafe_hash=True) class ConcreteLet(AST): """The concrete representation of a let/where expression. """ definitions: Tuple[LocalDefinition, ...] # noqa body: AST def __init__(self, definitions: Sequence[LocalDefinition], body: AST) -> None: self.definitions = tuple(definitions) self.body = body @memoized_property def ast(self) -> _LetExpr: return self.compile() @memoized_property def value_definitions(self) -> ValueDefinitions: """The function definitions.""" return fst(self._definitions) @memoized_property def local_environment(self) -> TypeEnvironment: return snd(self._definitions) @memoized_property def _definitions(self) -> Tuple[ValueDefinitions, TypeEnvironment]: localenv: TypeEnvironment = {} defs: MutableMapping[str, List[Equation]] = {} # noqa for dfn in self.definitions: if isinstance(dfn, Equation): equations = defs.setdefault(dfn.name, []) equations.append(dfn) elif isinstance(dfn, dict): localenv.update(dfn) # type: ignore else: assert False, f"Unknown definition type {dfn!r}" return defs, localenv def compile(self) -> _LetExpr: r"""Build a Let/Letrec from a set of equations and a body. We need to decide if we issue a Let or a Letrec: if any of declared names appear in the any of the bodies we must issue a Letrec, otherwise issue a Let. Also we need to convert function-patterns into Lambda abstractions:: let id x = ... becomes:: led id = \x -> ... """ from xotl.fl.graphs import Graph from xotl.fl.match import FunctionDefinition # Type checking letrecs don't generalize definitions, we could end up # in a situation like the one described in [Mycroft1984] -- see the # test `test_conflicting_uses_of_non_generalized_map`. # # The solution is to first do a dependency analysis of the symbols in # the ContreteLet definition and rewrite the definition into several # nested Let/Letrec. # # We create a graph where nodes are (essentially) the names defined in # the ConcreteLet and there's an edge from name A to name B, if B is # used free in the RHS of A. # nodes = {} for name, equations in self.value_definitions.items(): nodes[name] = _LetGraphNode(name, tuple(equations)) graph: Graph[_LetGraphNode] = Graph() for node in nodes.values(): graph.add_node(node) for dependency in node.dependencies: if dependency in nodes: graph.add_edge(node, nodes[dependency]) # # After the graph is created; we compute the Strongly Connected # Components (SCC). Each SCC will be a bundle of mutually-recursive # nodes. # components = [] components_index = {} for scc in graph.get_sccs(): component = _ComponentNode(tuple(scc)) for name in component.names: components_index[name] = component components.append(component) del nodes, graph # # Each SCC has the names that must be kept together. But a node may # depend on another one in a different SCC, so there's still some # order we need to respect: Construct another graph, where the nodes # are the SCCs and there's an edge from a node C to D if any of names # in C depends on any of the names in D (this graph is guaranteed to # have no cycles, aka a DAG). # dag: Graph[_ComponentNode] = Graph() for component in components: dag.add_node(component) for dep in component.other_dependencies: if dep in components_index: dag.add_edge(component, components_index[dep]) # # Construct several nested Let/Letrec nodes following the reversed # topological sort of the DAG. But we collapse the components with # the same score: those has no mutual dependencies between them and # thus introduce no generalization problem. # body: _LetExpr = self.body # type: ignore for score, collapsable in groupby( dag.get_topological_order(reverse=True, with_score=True), key=snd ): component = _ComponentNode.union(*(comp for comp, _ in collapsable)) defs = { node.name: FunctionDefinition(node.equations) for node in component.nodes } compiled = {name: dfn.compile() for name, dfn in defs.items()} if component.recursive: klass: Class[_LetExpr] = Letrec else: klass = Let body = klass( compiled, body, { k: v for k, v in self.local_environment.items() if k in component.names }, ) return body
@dataclass(unsafe_hash=True) class _LetGraphNode: name: str equations: Sequence[Equation] @property def dependencies(self): from operator import or_ from functools import reduce return reduce(or_, (set(find_free_names(eq)) for eq in self.equations), set()) @dataclass(unsafe_hash=True) class _ComponentNode: nodes: Sequence[_LetGraphNode] @property def names(self): return {node.name for node in self.nodes} @property def dependencies(self): from operator import or_ from functools import reduce return reduce(or_, (node.dependencies for node in self.nodes), set()) @property def other_dependencies(self): names = self.names return {dep for dep in self.dependencies if dep not in names} @property def recursive(self): deps = self.dependencies return any(name in deps for name in self.names) def __eq__(self, other): if isinstance(other, _ComponentNode): return self.names == other.names else: return NotImplemented def __or__(self, other): if isinstance(other, _ComponentNode): return _ComponentNode(tuple(self.nodes) + tuple(other.nodes)) else: return NotImplemented def union(self, *others) -> "_ComponentNode": from operator import or_ from functools import reduce return reduce(or_, others, self) class Case(ILC): """The case expression. Part of the intermediate language. ConcreteLet, if using pattern matching may get translated to case expressions. """ def __init__(self, expr: ILC, branches: Sequence[Tuple["CaseBranch", ILC]]) -> None: pass class CaseBranch: pass @dataclass class LiteralBranch(CaseBranch): value: Literal def __init__(self, value: Literal) -> None: self.value = value @dataclass class ConstructorBranch(CaseBranch): datacons: DataCons @property def params(self): from xotl.fl.utils import namesupply return list(namesupply(limit=len(self.datacons.args))) @property def cons(self): return self.datacons.name def __repr__(self): if self.params: params = " ".join(map(str, self.params)) return f"{{{self.cons} {params}}}" else: return f"{{{self.cons}}}"