//
// Syd: rock-solid application kernel
// src/kernel/shm.rs: Shared memory syscall handlers
//
// Copyright (c) 2023, 2024, 2025 Ali Polatel <alip@chesswob.org>
//
// SPDX-License-Identifier: GPL-3.0

use libseccomp::ScmpNotifResp;
use nix::errno::Errno;

use crate::{config::SHM_UNSAFE_MASK, hook::UNotifyEventRequest, sandbox::Action, warn};

/// `IPC_SET` control command, widened to the `u64` seccomp argument type.
const IPC_SET: u64 = libc::IPC_SET as u64;
/// `SHM_EXEC` attach flag checked in shmat requests, widened to `u64`.
const SHM_X: u64 = libc::SHM_EXEC as u64;
/// `O_CREAT` open flag checked in mq_open requests, widened to `u64`.
const O_CREAT: u64 = libc::O_CREAT as u64;

/// `MSG_STAT_ANY` command number for msgctl requests.
/// NOTE(review): hard-coded, presumably because the libc crate does not export it — confirm.
const MSG_STAT_ANY: u64 = 0x0d;
/// `SHM_STAT_ANY` command number for shmctl requests.
const SHM_STAT_ANY: u64 = 0x0f;
/// `SEM_STAT_ANY` command number for semctl requests.
const SEM_STAT_ANY: u64 = 0x14;

/// Seccomp handler for the multiplexed ipc(2) system call.
///
/// The low 16 bits of the first argument select the IPC operation. The
/// remaining high bits are not inspected here. Operations that Syd checks are
/// forwarded to the matching `syscall_*_handler` along with the argument slot
/// that holds their flag/command word. Every other operation continues
/// unmodified.
pub(crate) fn sys_ipc(request: UNotifyEventRequest) -> ScmpNotifResp {
    // ipc(2) call numbers that are inspected.
    const SEMGET: u64 = 2;
    const SEMCTL: u64 = 3;
    const MSGGET: u64 = 13;
    const MSGCTL: u64 = 14;
    const SHMAT: u64 = 21;
    const SHMGET: u64 = 23;
    const SHMCTL: u64 = 24;

    let args = request.scmpreq.data.args;
    let call = args[0] & 0xffff;

    match call {
        SEMGET => syscall_semget_handler(request, args[3]),
        SEMCTL => syscall_semctl_handler(request, args[3]),
        MSGGET => syscall_msgget_handler(request, args[2]),
        MSGCTL => syscall_msgctl_handler(request, args[2]),
        SHMAT => syscall_shmat_handler(request, args[2]),
        SHMGET => syscall_shmget_handler(request, args[3]),
        SHMCTL => syscall_shmctl_handler(request, args[2]),
        _ => {
            // SAFETY: Safe ipc call, continue.
            // No pointer-dereference in access check.
            unsafe { request.continue_syscall() }
        }
    }
}

/// Seccomp handler for native shmat: checks `shmflg` (third argument).
pub(crate) fn sys_shmat(request: UNotifyEventRequest) -> ScmpNotifResp {
    let shmflg = request.scmpreq.data.args[2];
    syscall_shmat_handler(request, shmflg)
}

/// Seccomp handler for native msgctl: checks the command (second argument).
pub(crate) fn sys_msgctl(request: UNotifyEventRequest) -> ScmpNotifResp {
    let cmd = request.scmpreq.data.args[1];
    syscall_msgctl_handler(request, cmd)
}

/// Seccomp handler for native semctl: checks the command (third argument).
pub(crate) fn sys_semctl(request: UNotifyEventRequest) -> ScmpNotifResp {
    let cmd = request.scmpreq.data.args[2];
    syscall_semctl_handler(request, cmd)
}

/// Seccomp handler for native shmctl: checks the command (second argument).
pub(crate) fn sys_shmctl(request: UNotifyEventRequest) -> ScmpNotifResp {
    let cmd = request.scmpreq.data.args[1];
    syscall_shmctl_handler(request, cmd)
}

/// Seccomp handler for native msgget: checks `msgflg` (second argument).
pub(crate) fn sys_msgget(request: UNotifyEventRequest) -> ScmpNotifResp {
    let msgflg = request.scmpreq.data.args[1];
    syscall_msgget_handler(request, msgflg)
}

/// Seccomp handler for native semget: checks `semflg` (third argument).
pub(crate) fn sys_semget(request: UNotifyEventRequest) -> ScmpNotifResp {
    let semflg = request.scmpreq.data.args[2];
    syscall_semget_handler(request, semflg)
}

/// Seccomp handler for native shmget: checks `shmflg` (third argument).
pub(crate) fn sys_shmget(request: UNotifyEventRequest) -> ScmpNotifResp {
    let shmflg = request.scmpreq.data.args[2];
    syscall_shmget_handler(request, shmflg)
}

/// Reject shmat requests that ask for `SHM_EXEC`.
///
/// If `shmflg` carries the `SHM_EXEC` bit, the request is logged, the caller
/// is killed with `Action::Kill`, and the call fails with `EACCES`.
/// Otherwise the system call continues unmodified.
fn syscall_shmat_handler(request: UNotifyEventRequest, shmflg: u64) -> ScmpNotifResp {
    if shmflg & SHM_X != 0 {
        let pid = request.scmpreq.pid;
        let act = Action::Kill;
        warn!("ctx": "ipc", "op": "check_shm",
            "err": "Unsafe shmat call with SHM_EXEC",
            "act": act, "pid": pid,
            "sys": "shmat", "shmflg": shmflg,
            "tip": "configure `trace/allow_unsafe_shm:1'");

        // Best-effort kill; the syscall is denied regardless of its outcome.
        let _ = request.kill(act);
        return request.fail_syscall(Errno::EACCES);
    }

    // SAFETY: No pointer dereference in access check.
    unsafe { request.continue_syscall() }
}

#[allow(clippy::cognitive_complexity)]
fn syscall_msgctl_handler(request: UNotifyEventRequest, op: u64) -> ScmpNotifResp {
    let op = op & 0xff;
    let req = request.scmpreq;

    if !matches!(op, IPC_SET | MSG_STAT_ANY) {
        // SAFETY: No pointer dereference in access check.
        return unsafe { request.continue_syscall() };
    }

    let act = Action::Kill;
    warn!("ctx": "ipc", "op": "check_shm",
        "err": "Unsafe msgctl call",
        "act": act, "pid": req.pid,
        "sys": "msgctl", "msg_op": op,
        "tip": "configure `trace/allow_unsafe_shm:1'");

    let _ = request.kill(act);
    request.fail_syscall(Errno::EACCES)
}

#[allow(clippy::cognitive_complexity)]
fn syscall_semctl_handler(request: UNotifyEventRequest, op: u64) -> ScmpNotifResp {
    let op = op & 0xff;
    let req = request.scmpreq;

    if !matches!(op, IPC_SET | SEM_STAT_ANY) {
        // SAFETY: No pointer dereference in access check.
        return unsafe { request.continue_syscall() };
    }

    let act = Action::Kill;
    warn!("ctx": "ipc", "op": "check_shm",
        "err": "Unsafe semctl call",
        "act": act, "pid": req.pid,
        "sys": "semctl", "sem_op": op,
        "tip": "configure `trace/allow_unsafe_shm:1'");

    let _ = request.kill(act);
    request.fail_syscall(Errno::EACCES)
}

#[allow(clippy::cognitive_complexity)]
fn syscall_shmctl_handler(request: UNotifyEventRequest, op: u64) -> ScmpNotifResp {
    let op = op & 0xff;
    let req = request.scmpreq;

    if !matches!(op, IPC_SET | SHM_STAT_ANY) {
        // SAFETY: No pointer dereference in access check.
        return unsafe { request.continue_syscall() };
    }

    let act = Action::Kill;
    warn!("ctx": "ipc", "op": "check_shm",
        "err": "Unsafe shmctl call",
        "act": act, "pid": req.pid,
        "sys": "shmctl", "shm_op": op,
        "tip": "configure `trace/allow_unsafe_shm:1'");

    let _ = request.kill(act);
    request.fail_syscall(Errno::EACCES)
}

/// Reject msgget requests whose flags intersect `SHM_UNSAFE_MASK`.
///
/// Such requests are logged, the caller is killed with `Action::Kill`, and
/// the call fails with `EACCES`. All other requests continue unmodified.
fn syscall_msgget_handler(request: UNotifyEventRequest, flg: u64) -> ScmpNotifResp {
    if flg & SHM_UNSAFE_MASK != 0 {
        let pid = request.scmpreq.pid;
        let act = Action::Kill;
        warn!("ctx": "ipc", "op": "check_shm",
            "err": "Unsafe msgget call",
            "act": act, "pid": pid,
            "sys": "msgget", "flg": flg,
            "tip": "configure `trace/allow_unsafe_shm:1'");

        // Best-effort kill; the syscall is denied regardless of its outcome.
        let _ = request.kill(act);
        request.fail_syscall(Errno::EACCES)
    } else {
        // SAFETY: No pointer dereference in access check.
        unsafe { request.continue_syscall() }
    }
}

/// Reject semget requests whose flags intersect `SHM_UNSAFE_MASK`.
///
/// Such requests are logged, the caller is killed with `Action::Kill`, and
/// the call fails with `EACCES`. All other requests continue unmodified.
fn syscall_semget_handler(request: UNotifyEventRequest, flg: u64) -> ScmpNotifResp {
    if flg & SHM_UNSAFE_MASK != 0 {
        let pid = request.scmpreq.pid;
        let act = Action::Kill;
        warn!("ctx": "ipc", "op": "check_shm",
            "err": "Unsafe semget call",
            "act": act, "pid": pid,
            "sys": "semget", "flg": flg,
            "tip": "configure `trace/allow_unsafe_shm:1'");

        // Best-effort kill; the syscall is denied regardless of its outcome.
        let _ = request.kill(act);
        request.fail_syscall(Errno::EACCES)
    } else {
        // SAFETY: No pointer dereference in access check.
        unsafe { request.continue_syscall() }
    }
}

/// Reject shmget requests whose flags intersect `SHM_UNSAFE_MASK`.
///
/// Such requests are logged, the caller is killed with `Action::Kill`, and
/// the call fails with `EACCES`. All other requests continue unmodified.
fn syscall_shmget_handler(request: UNotifyEventRequest, flg: u64) -> ScmpNotifResp {
    if flg & SHM_UNSAFE_MASK != 0 {
        let pid = request.scmpreq.pid;
        let act = Action::Kill;
        warn!("ctx": "ipc", "op": "check_shm",
            "err": "Unsafe shmget call",
            "act": act, "pid": pid,
            "sys": "shmget", "flg": flg,
            "tip": "configure `trace/allow_unsafe_shm:1'");

        // Best-effort kill; the syscall is denied regardless of its outcome.
        let _ = request.kill(act);
        request.fail_syscall(Errno::EACCES)
    } else {
        // SAFETY: No pointer dereference in access check.
        unsafe { request.continue_syscall() }
    }
}

/// Seccomp handler for mq_open: rejects queue creation with an unsafe mode.
///
/// The mode (third argument) is examined only when `O_CREAT` is set in
/// `oflag` (second argument). If the mode intersects `SHM_UNSAFE_MASK`, the
/// request is logged, the caller is killed with `Action::Kill`, and the call
/// fails with `EACCES`. Otherwise the system call continues unmodified.
#[allow(clippy::cognitive_complexity)]
pub(crate) fn sys_mq_open(request: UNotifyEventRequest) -> ScmpNotifResp {
    let args = request.scmpreq.data.args;
    let (oflag, mode) = (args[1], args[2]);

    // Mode is only valid with O_CREAT!
    let creating = oflag & O_CREAT != 0;
    let unsafe_mode = creating && mode & SHM_UNSAFE_MASK != 0;

    if !unsafe_mode {
        // SAFETY: No pointer dereference in access check.
        return unsafe { request.continue_syscall() };
    }

    let pid = request.scmpreq.pid;
    let act = Action::Kill;
    warn!("ctx": "ipc", "op": "check_shm",
        "err": "Unsafe mq_open call",
        "act": act, "pid": pid,
        "sys": "mq_open", "oflag": oflag, "mode": mode,
        "tip": "configure `trace/allow_unsafe_mqueue:1'");

    // Best-effort kill; the syscall is denied regardless of its outcome.
    let _ = request.kill(act);
    request.fail_syscall(Errno::EACCES)
}
