Source code for wolframclient.utils.externalevaluate

# -*- coding: utf-8 -*-

from __future__ import absolute_import, print_function, unicode_literals

import ast
import logging
import os
import sys

from wolframclient.deserializers import binary_deserialize
from wolframclient.language import wl
from wolframclient.language.decorators import to_wl
from wolframclient.language.side_effects import side_effect_logger
from wolframclient.serializers import export
from wolframclient.utils import six
from wolframclient.utils.api import zmq
from wolframclient.utils.datastructures import Settings
from wolframclient.utils.encoding import force_text
from wolframclient.utils.functional import last

HIDDEN_VARIABLES = (
    "__loader__",
    "__builtins__",
    "__traceback_hidden_variables__",
    "absolute_import",
    "print_function",
    "unicode_literals",
)

EXPORT_KWARGS = {"target_format": "wxf", "allow_external_objects": True}


[docs]def EvaluationEnvironment(code, session_data={}, constants=None, **extra): session_data["__loader__"] = Settings(get_source=lambda module, code=code: code) session_data["__traceback_hidden_variables__"] = HIDDEN_VARIABLES if constants: session_data.update(constants) return session_data
[docs]def execute_from_file(path, *args, **opts): with open(path, "r") as f: return execute_from_string(force_text(f.read()), *args, **opts)
[docs]def execute_from_string(code, globals={}, **opts): __traceback_hidden_variables__ = ["env", "current", "__traceback_hidden_variables__"] # this is creating a custom __loader__ that is returning the source code # traceback serializers is inspecting global variables and looking for a standard loader that can return source code. env = EvaluationEnvironment(code=code, **opts) result = None expressions = list( compile( code, filename="<unknown>", mode="exec", flags=ast.PyCF_ONLY_AST | unicode_literals.compiler_flag, ).body ) if not expressions: return if isinstance(last(expressions), ast.Expr): result = expressions.pop(-1) if expressions: exec(compile(ast.Module(expressions), "", "exec"), env) if result: return eval(compile(ast.Expression(result.value), "", "eval"), env)
[docs]class SideEffectSender(logging.Handler):
[docs] def emit(self, record): if isinstance(sys.stdout, StdoutProxy): sys.stdout.send_side_effect(record.msg)
[docs]class SocketWriter: def __init__(self, socket): self.socket = socket
[docs] def write(self, bytes): self.socket.send(bytes)
[docs]class StdoutProxy: keep_listening = wl.ExternalEvaluate.Private.ExternalEvaluateKeepListening def __init__(self, stream): self.stream = stream self.clear()
[docs] def clear(self): self.current_line = [] self.lines = []
[docs] def write(self, message): messages = force_text(message).split("\n") if len(messages) == 1: self.current_line.extend(messages) else: self.current_line.append(messages.pop(0)) rest = messages.pop(-1) self.lines.extend(messages) self.flush() if rest: self.current_line.append(rest)
[docs] def flush(self): if self.current_line or self.lines: self.send_lines("".join(self.current_line), *self.lines) self.clear()
[docs] def send_lines(self, *lines): if len(lines) == 1: return self.send_side_effect(wl.Print(*lines)) elif lines: return self.send_side_effect(wl.CompoundExpression(*map(wl.Print, lines)))
[docs] def send_side_effect(self, expr): self.stream.write(export(self.keep_listening(expr), **EXPORT_KWARGS))
[docs]def evaluate_message(input=None, return_type=None, args=None, **opts): __traceback_hidden_variables__ = True result = None if isinstance(input, six.string_types): result = execute_from_string(input, **opts) if isinstance(args, (list, tuple)): # then we have a function call to do # first get the function object we need to call result = result(*args) if return_type == "string": # bug 354267 repr returns a 'str' even on py2 (i.e. bytes). result = force_text(repr(result)) return result
[docs]@to_wl(**EXPORT_KWARGS) def handle_message(socket): __traceback_hidden_variables__ = True message = binary_deserialize(socket.recv()) result = evaluate_message(**message) sys.stdout.flush() return result
[docs]def start_zmq_instance(port=None, write_to_stdout=True, **opts): # make a reply socket sock = zmq.Context.instance().socket(zmq.PAIR) # now bind to a open port on localhost if port: sock.bind("tcp://127.0.0.1:%s" % port) else: sock.bind_to_random_port("tcp://127.0.0.1") if write_to_stdout: sys.stdout.write(force_text(sock.getsockopt(zmq.LAST_ENDPOINT))) sys.stdout.write(os.linesep) # writes \n sys.stdout.flush() return sock
[docs]def start_zmq_loop(message_limit=float("inf"), redirect_stdout=True, **opts): socket = start_zmq_instance(**opts) stream = SocketWriter(socket) messages = 0 if redirect_stdout: sys.stdout = StdoutProxy(stream) side_effect_logger.addHandler(SideEffectSender()) # now sit in a while loop, evaluating input while messages < message_limit: stream.write(handle_message(socket)) messages += 1 if redirect_stdout: sys.stdout = sys.__stdout__