Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature: Fork kernel #1261

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 17 additions & 14 deletions ipykernel/ipkernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,23 @@ def __init__(self, **kwargs):
self.debug_just_my_code,
)

self.init_shell()

if _use_appnope() and self._darwin_app_nap:
# Disable app-nap as the kernel is not a gui but can have guis
import appnope # type:ignore[import-untyped]

appnope.nope()

self._new_threads_parent_header = {}
self._initialize_thread_hooks()

if hasattr(gc, "callbacks"):
# while `gc.callbacks` exists since Python 3.3, pypy does not
# implement it even as of 3.9.
gc.callbacks.append(self._clean_thread_parent_frames)

def init_shell(self):
# Initialize the InteractiveShell subclass
self.shell = self.shell_class.instance(
parent=self,
Expand Down Expand Up @@ -145,20 +162,6 @@ def __init__(self, **kwargs):
for msg_type in comm_msg_types:
self.shell_handlers[msg_type] = getattr(self.comm_manager, msg_type)

if _use_appnope() and self._darwin_app_nap:
# Disable app-nap as the kernel is not a gui but can have guis
import appnope # type:ignore[import-untyped]

appnope.nope()

self._new_threads_parent_header = {}
self._initialize_thread_hooks()

if hasattr(gc, "callbacks"):
# while `gc.callbacks` exists since Python 3.3, pypy does not
# implement it even as of 3.9.
gc.callbacks.append(self._clean_thread_parent_frames)

help_links = List(
[
{
Expand Down
95 changes: 93 additions & 2 deletions ipykernel/kernelapp.py
Original file line number Diff line number Diff line change
Expand Up @@ -726,8 +726,99 @@ def start(self) -> None:
if self.poller is not None:
self.poller.start()
backend = "trio" if self.trio_loop else "asyncio"
run(self.main, backend=backend)
return

while True:
run(self.main, backend=backend)
if not getattr(self.kernel, "_fork_requested", False):
break
self.fork()

def fork(self):
    """Fork the kernel process in response to a ``fork`` request.

    Called from ``start()`` after the kernel's event loop has exited with
    ``_fork_requested`` set.  Writes a fresh connection file, calls
    ``os.fork()``, and then:

    * in the **child**: tears down every inherited socket/loop, rebinds to
      the ports from the new connection file, re-wires the existing kernel
      object to the new sockets, and re-enters ``start()``;
    * in the **parent**: invokes the kernel's ``_post_fork_callback`` so the
      ``fork_reply`` (child pid + connection info) is sent, then returns so
      the caller's loop can resume the parent kernel.
    """
    # HACK: Why is this necessary?
    # Without it, the *parent* kernel doesn't work.
    # Also, it doesn't work if I try to start it again with
    # self.init_iopub()...
    self.iopub_thread.stop()

    # Create a temporary connection file that will be inherited by the child process.
    connection_file, conn = write_connection_file()

    parent_pid = os.getpid()
    pid = os.fork()
    self.kernel._fork_requested = False  # reset for parent AND child
    if pid == 0:
        # --- child process ---
        self.log.debug("Child kernel with pid %s", os.getpid())

        # close all sockets and ioloops
        self.close()

        # Reset all ports so they will be reinitialized with the ports from the connection file
        for name in [
            "%s_port" % channel for channel in ("shell", "stdin", "iopub", "hb", "control")
        ]:
            setattr(self, name, 0)
        self.connection_file = connection_file

        # Reset the ZMQ context for it to be recreated
        self.context = None

        # Make ParentPoller work correctly (the new process is a child of the previous kernel)
        self.parent_handle = parent_pid

        # Session guards against sending messages from forked processes via
        # its `check_pid` flag, so adopt the child's pid here — and take the
        # new connection file's key so replies validate for new clients.
        self.session.pid = os.getpid()
        self.session.key = conn["key"].encode()

        self.init_connection_file()
        self.init_poller()
        self.init_sockets()
        self.init_heartbeat()
        self.init_io()

        # Re-point the existing kernel instance at the freshly created
        # sockets/threads rather than constructing a new kernel object.
        kernel = self.kernel
        params = dict(
            parent=self,
            session=self.session,
            control_socket=self.control_socket,
            control_thread=self.control_thread,
            debugpy_socket=self.debugpy_socket,
            debug_shell_socket=self.debug_shell_socket,
            shell_socket=self.shell_socket,
            iopub_thread=self.iopub_thread,
            iopub_socket=self.iopub_socket,
            stdin_socket=self.stdin_socket,
            log=self.log,
            profile_dir=self.profile_dir,
        )
        for k, v in params.items():
            setattr(kernel, k, v)

        # Preserve the parent's user namespace across the shell re-init so
        # the fork inherits all user state.
        kernel.user_ns = kernel.shell.user_ns
        kernel.init_shell()

        kernel.record_ports({name + "_port": port for name, port in self._ports.items()})
        self.kernel = kernel

        # Allow the displayhook to get the execution count
        self.displayhook.get_execution_count = lambda: kernel.execution_count

        # shell init steps
        self.init_shell()
        if self.shell:
            self.init_gui_pylab()
            self.init_extensions()
            self.init_code()
        # flush stdout/stderr, so that anything written to these streams during
        # initialization do not get associated with the first execution request
        sys.stdout.flush()
        sys.stderr.flush()
        self.start()
    else:
        # --- parent process ---
        self.log.debug("Parent kernel will resume")
        # Keep a local reference: the stored callback is cleared on the
        # kernel right after it is invoked.
        post_fork_callback = self.kernel._post_fork_callback
        post_fork_callback(pid, conn)
        self.kernel._post_fork_callback = None

async def main(self):
async with create_task_group() as tg:
Expand Down
23 changes: 22 additions & 1 deletion ipykernel/kernelbase.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def _parent_header(self):
"shutdown_request",
"is_complete_request",
"interrupt_request",
"fork",
# deprecated:
"apply_request",
]
Expand All @@ -229,6 +230,25 @@ def _parent_header(self):
"usage_request",
]

def fork(self, stream, ident, parent):
    """Handle a ``fork`` request.

    Forking from within a running (async)io loop is not supported, so the
    actual ``os.fork()`` cannot happen here.  Instead this records the
    request on the kernel, stashes a callback that will send the
    ``fork_reply`` once the application layer has performed the fork, and
    stops the kernel loop to hand control back up the call stack.
    """
    self._fork_requested = True

    def _send_fork_reply(pid, conn):
        # Echo the child's pid and the fresh connection info back to the
        # requester on the original shell stream.
        content = json_clean({"status": "ok", "pid": pid, "conn": conn})
        md = self.finish_metadata(parent, {}, content)
        self.session.send(
            stream,
            "fork_reply",
            content,
            parent,
            metadata=md,
            ident=ident,
        )

    self._post_fork_callback = _send_fork_reply
    self.stop()

def __init__(self, **kwargs):
"""Initialize the kernel."""
super().__init__(**kwargs)
Expand Down Expand Up @@ -469,7 +489,8 @@ async def start(self, *, task_status: TaskStatus = TASK_STATUS_IGNORED) -> None:
if not self._is_test and self.control_socket is not None:
if self.control_thread:
self.control_thread.set_task(self.control_main)
self.control_thread.start()
if not self.control_thread.is_alive():
self.control_thread.start()
else:
tg.start_soon(self.control_main)

Expand Down
45 changes: 45 additions & 0 deletions tests/test_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from .utils import (
TIMEOUT,
assemble_output,
connect_to_kernel,
execute,
flush_channels,
get_reply,
Expand Down Expand Up @@ -491,6 +492,50 @@ def test_shutdown():
assert not km.is_alive()


def test_fork_metadata():
    """fork_reply carries the child pid and a fresh connection-info dict."""
    with new_kernel() as kc:
        from .test_message_spec import validate_message

        km = kc.parent
        fork_msg_id = kc.fork()
        fork_reply = kc.get_shell_msg(timeout=TIMEOUT)
        validate_message(fork_reply, "fork_reply", fork_msg_id)
        # The reply must correlate with our fork request.
        # (Fixed: original compared fork_msg_id against itself redundantly.)
        assert fork_reply["parent_header"]["msg_id"] == fork_msg_id
        # The child kernel gets its own session key, distinct from ours.
        assert fork_reply["content"]["conn"]["key"] != kc.session.key.decode()
        fork_pid = fork_reply["content"]["pid"]
        _check_status(fork_reply["content"])
        wait_for_idle(kc)

        # The fork is a new process, not the provisioned parent kernel.
        assert fork_pid != km.provisioner.pid
        # TODO: Inspect if `fork_pid` is running? Might need to use `psutil`
        # for this in order to be cross platform.

        with connect_to_kernel(fork_reply["content"]["conn"], TIMEOUT) as kc_fork:
            # Connecting with the advertised connection info uses its key.
            assert fork_reply["content"]["conn"]["key"] == kc_fork.session.key.decode()
            kc_fork.shutdown()


def test_fork():
    """Kernel forks after fork_request; child inherits state, then both diverge."""
    # (Fixed: this docstring was a stray string literal placed after the
    # nested helper, where it had no effect as documentation.)

    def execute_with_user_expression(kc, code, user_expression):
        # Run `code` and evaluate `user_expression` in the same request,
        # returning the expression's text/plain representation.
        _, reply = execute(code, kc=kc, user_expressions={"my-user-expression": user_expression})
        content = reply["user_expressions"]["my-user-expression"]["data"]["text/plain"]
        wait_for_idle(kc)
        return content

    with kernel() as kc:
        assert execute_with_user_expression(kc, "a = 1", "a") == "1"
        assert execute_with_user_expression(kc, "b = 2", "b") == "2"
        kc.fork()
        fork_reply = kc.get_shell_msg(timeout=TIMEOUT)
        wait_for_idle(kc)

        with connect_to_kernel(fork_reply["content"]["conn"], TIMEOUT) as kc_fork:
            # Child starts from the parent's namespace (a=1, b=2)...
            assert execute_with_user_expression(kc_fork, "a = 11", "a, b") == str((11, 2))
            assert execute_with_user_expression(kc_fork, "b = 12", "a, b") == str((11, 12))
            # ...while mutations in the child never leak back to the parent.
            assert execute_with_user_expression(kc, "z = 20", "a, b") == str((1, 2))
            kc_fork.shutdown()


def test_interrupt_during_input():
"""
The kernel exits after being interrupted while waiting in input().
Expand Down
6 changes: 6 additions & 0 deletions tests/test_message_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,11 @@ class IsCompleteReplyIncomplete(Reference):
indent = Unicode()


class ForkReply(Reply):
    # Reference schema for `fork_reply` messages: the forked child's
    # process id plus the connection-info dict for reaching it.
    pid = Integer()
    conn = Dict()


# IOPub messages


Expand Down Expand Up @@ -255,6 +260,7 @@ class HistoryReply(Reply):
"stream": Stream(),
"display_data": DisplayData(),
"header": RHeader(),
"fork_reply": ForkReply(),
}

# -----------------------------------------------------------------------------
Expand Down
13 changes: 13 additions & 0 deletions tests/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,16 @@ def __enter__(self):
def __exit__(self, exc, value, tb):
os.chdir(self.old_wd)
return super().__exit__(exc, value, tb)


@contextmanager
def connect_to_kernel(connection_info, timeout):
    """Context manager yielding a BlockingKernelClient for ``connection_info``.

    Starts the client channels, waits up to ``timeout`` seconds for the
    kernel to be ready, and yields the connected client.  Channels are
    always stopped on exit — including when ``wait_for_ready`` or the
    ``with``-body raises (the original leaked channels in that case).
    """
    from jupyter_client import BlockingKernelClient

    kc = BlockingKernelClient()
    kc.log.setLevel("DEBUG")
    kc.load_connection_info(connection_info)
    kc.start_channels()
    try:
        kc.wait_for_ready(timeout)
        yield kc
    finally:
        kc.stop_channels()
Loading