//
// Syd: rock-solid application kernel
// src/workers/mod.rs: Worker threads implementation
//
// Copyright (c) 2024, 2025 Ali Polatel <alip@chesswob.org>
// Based in part upon rusty_pool which is:
//     Copyright (c) Robin Friedli <robinfriedli@icloud.com>
//     SPDX-License-Identifier: Apache-2.0
//
// SPDX-License-Identifier: GPL-3.0

use std::{
    collections::hash_map::Entry,
    fs::File,
    option::Option,
    sync::{
        atomic::{AtomicUsize, Ordering},
        Arc, RwLock,
    },
};

use nix::{
    errno::Errno,
    sys::{
        signal::{SigSet, Signal},
        socket::UnixAddr,
    },
    unistd::{gettid, Pid},
};

use crate::{
    cache::{
        signal_map_new, sys_interrupt_map_new, sys_result_map_new, ExecResult, SignalMap,
        SysInterrupt, SysInterruptMap, SysResultMap,
    },
    confine::{ScmpNotifReq, SydMemoryMap},
    elf::ExecutableFile,
    fs::{block_signal, retry_on_eintr, sigtimedpoll, unblock_signal, CanonicalPath},
    hash::SydHashMap,
    sigset::SydSigSet,
};

// syd_aes: Encryptor helper thread
pub(crate) mod aes;
// syd_int: Interrupter helper thread
pub(crate) mod int;
// syd_ipc: IPC thread
pub(crate) mod ipc;
// syd_emu: Main worker threads
pub(crate) mod emu;

/// A cache for worker threads.
///
/// Groups the three maps that the methods of this type read and update.
/// The lifetime `'a` is the one carried by the `CanonicalPath` values
/// recorded through `add_chdir`.
#[derive(Debug)]
pub(crate) struct WorkerCache<'a> {
    // Signal handlers map; its `sig_handle` table counts handled
    // signals per TGID.
    pub(crate) signal_map: SignalMap,
    // System call interrupt map; holds the per-TGID restarting signal
    // sets (`sig_restart`) and the blocked system calls (`sys_block`).
    pub(crate) sysint_map: SysInterruptMap,
    // System call result map; holds the pending chdir, error and exec
    // records (`trace_chdir`, `trace_error`, `trace_execv`) keyed by PID.
    pub(crate) sysres_map: SysResultMap<'a>,
}

impl<'a> WorkerCache<'a> {
    // Build a cache from freshly constructed signal, interrupt and
    // result maps.
    pub(crate) fn new() -> Self {
        let signal_map = signal_map_new();
        let sysint_map = sys_interrupt_map_new();
        let sysres_map = sys_result_map_new();

        Self {
            signal_map,
            sysint_map,
            sysres_map,
        }
    }

    // Bump the handled-signal counter of `request_tgid` by one.
    // A missing counter starts at 1; the addition saturates.
    pub(crate) fn inc_sig_handle(&self, request_tgid: Pid) {
        // A poisoned lock is recovered, not propagated.
        let mut handled = self
            .signal_map
            .sig_handle
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        match handled.entry(request_tgid) {
            Entry::Occupied(mut slot) => {
                let count = slot.get_mut();
                *count = count.saturating_add(1);
            }
            Entry::Vacant(slot) => {
                slot.insert(1);
            }
        }
    }

    // Lower the handled-signal counter of `request_tgid` by one.
    // Returns true if a counter existed for the TGID, false otherwise.
    // A counter that reaches zero is removed from the map.
    #[allow(clippy::cognitive_complexity)]
    pub(crate) fn dec_sig_handle(&self, request_tgid: Pid) -> bool {
        let mut handled = self
            .signal_map
            .sig_handle
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        match handled.entry(request_tgid) {
            Entry::Occupied(mut slot) => {
                let remaining = slot.get().saturating_sub(1);
                if remaining == 0 {
                    // Nothing left to wait for: forget the TGID.
                    let _ = slot.remove();
                } else {
                    *slot.get_mut() = remaining;
                }
                true
            }
            Entry::Vacant(_) => false,
        }
    }

    // Forget the handled-signal counter of a TGID.
    pub(crate) fn retire_sig_handle(&self, tgid: Pid) {
        let mut handled = self
            .signal_map
            .sig_handle
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        handled.remove(&tgid);
    }

    // Store the chdir result of `pid`, replacing any earlier one.
    pub(crate) fn add_chdir<'b>(&'b self, pid: Pid, path: CanonicalPath<'a>) {
        let mut chdirs = self
            .sysres_map
            .trace_chdir
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        chdirs.insert(pid, path);
    }

    // Take the chdir result of `pid` out of the cache, if any.
    #[allow(clippy::type_complexity)]
    pub(crate) fn get_chdir<'b>(&'b self, pid: Pid) -> Option<(Pid, CanonicalPath<'a>)> {
        let mut chdirs = self
            .sysres_map
            .trace_chdir
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        chdirs.remove_entry(&pid)
    }

    // Store the error result of `pid`, replacing any earlier one.
    pub(crate) fn add_error(&self, pid: Pid, errno: Option<Errno>) {
        let mut errors = self
            .sysres_map
            .trace_error
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        errors.insert(pid, errno);
    }

    // Take the error result of `pid` out of the cache, if any.
    #[allow(clippy::type_complexity)]
    pub(crate) fn get_error(&self, pid: Pid) -> Option<(Pid, Option<Errno>)> {
        let mut errors = self
            .sysres_map
            .trace_error
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        errors.remove_entry(&pid)
    }

    // Store the execv result of `pid`, replacing any earlier one.
    // The arguments are bundled into an `ExecResult` as-is.
    #[allow(clippy::too_many_arguments)]
    pub(crate) fn add_exec(
        &self,
        pid: Pid,
        exe: ExecutableFile,
        file: File,
        ip: u64,
        sp: u64,
        args: [u64; 6],
        ip_mem: Option<[u8; 64]>,
        sp_mem: Option<[u8; 64]>,
        memmap: Option<Vec<SydMemoryMap>>,
    ) {
        // Assemble the record before taking the lock.
        let record = ExecResult {
            exe,
            file,
            ip,
            sp,
            args,
            ip_mem,
            sp_mem,
            memmap,
        };

        let mut execs = self
            .sysres_map
            .trace_execv
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        execs.insert(pid, record);
    }

    // Take the exec result of `pid` out of the cache, if any.
    pub(crate) fn get_exec(&self, pid: Pid) -> Option<(Pid, ExecResult)> {
        let mut execs = self
            .sysres_map
            .trace_execv
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        execs.remove_entry(&pid)
    }

    // Mark `sig` as a restarting signal for `request_tgid`.
    pub(crate) fn add_sig_restart(&self, request_tgid: Pid, sig: libc::c_int) {
        let mut restarts = self
            .sysint_map
            .sig_restart
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        match restarts.get_mut(&request_tgid) {
            Some(set) => {
                set.add(sig);
            }
            None => {
                // First restarting signal of this TGID: start a new set.
                let mut set = SydSigSet::new(0);
                set.add(sig);
                restarts.insert(request_tgid, set);
            }
        }
    }

    // Unmark `sig` as a restarting signal for `request_tgid`.
    // The set of the TGID is dropped from the map once it is empty.
    pub(crate) fn del_sig_restart(&self, request_tgid: Pid, sig: libc::c_int) {
        let mut restarts = self
            .sysint_map
            .sig_restart
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());

        // None: TGID is unknown, Some(true): its set just became empty.
        let emptied = restarts.get_mut(&request_tgid).map(|set| {
            set.del(sig);
            set.is_empty()
        });

        if emptied == Some(true) {
            restarts.remove(&request_tgid);
        }
    }

    // Forget all restarting signals of a TGID.
    pub(crate) fn retire_sig_restart(&self, tgid: Pid) {
        let mut restarts = self
            .sysint_map
            .sig_restart
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        restarts.remove(&tgid);
    }

    // Register the system call of `request` as blocked on behalf of the
    // calling thread, then unblock SIGALRM for that thread.
    //
    // Fails if the interrupt record cannot be created or if SIGALRM
    // cannot be unblocked.
    #[allow(clippy::cast_possible_wrap)]
    pub(crate) fn add_sys_block(
        &self,
        request: ScmpNotifReq,
        ignore_restart: bool,
    ) -> Result<(), Errno> {
        let handler_tid = gettid();
        let interrupt = SysInterrupt::new(request, handler_tid, ignore_restart)?;

        let (lock, cvar) = &*self.sysint_map.sys_block;
        let mut blocked = lock.lock().unwrap_or_else(|poisoned| poisoned.into_inner());

        // Drop every record still attributed to this handler thread
        // before adding the new one.
        blocked.retain(|_, pending| pending.handler != handler_tid);
        blocked.insert(request.id, interrupt);

        cvar.notify_one();

        // Discard a spurious pending SIGALRM. A single poll is enough:
        // unlike realtime signals, SIGALRM is queued at most once.
        let mut mask = SigSet::empty();
        mask.add(Signal::SIGALRM);
        let _ = retry_on_eintr(|| sigtimedpoll(&mask, None));

        unblock_signal(Signal::SIGALRM)
    }

    // Block SIGALRM again and, unless the system call was interrupted,
    // remove its record from the blocked system call map.
    //
    // Fails if SIGALRM cannot be blocked.
    pub(crate) fn del_sys_block(&self, request_id: u64, interrupted: bool) -> Result<(), Errno> {
        block_signal(Signal::SIGALRM)?;

        if interrupted {
            return Ok(());
        }

        let (lock, _cvar) = &*self.sysint_map.sys_block;
        let mut blocked = lock.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
        blocked.remove(&request_id);

        Ok(())
    }

    // Purge every record the cache keeps for `pid`.
    pub(crate) fn del_pid(&self, pid: Pid) {
        // Signal bookkeeping, keyed by TGID.
        self.retire_sig_handle(pid);
        self.retire_sig_restart(pid);

        // The getters remove what they return, so dropping their
        // results clears any pending error, chdir and exec record.
        drop(self.get_error(pid));
        drop(self.get_chdir(pid));
        drop(self.get_exec(pid));
    }
}

// Hard upper bound on the number of workers. Two counters (total
// workers and busy workers) share one AtomicUsize, so each of them
// gets half the bits of a usize and cannot exceed the largest value
// representable in that half.
const BITS: usize = 8 * std::mem::size_of::<usize>();
const MAX_SIZE: usize = usize::MAX >> (BITS / 2);

// All bits of the lower half: selects the busy counter.
const WORKER_BUSY_MASK: usize = MAX_SIZE;
// Lowest bit of the upper half: one step of the total counter.
const INCREMENT_TOTAL: usize = WORKER_BUSY_MASK + 1;
// Lowest bit of the lower half: one step of the busy counter.
const INCREMENT_BUSY: usize = 1;

/// Data shared between workers.
///
/// Wraps a single `AtomicUsize` that stores the total worker count in
/// the higher half of its bits and the busy worker count in the lower
/// half of its bits. This allows incrementing / decrementing both
/// counters in a single atomic operation.
#[derive(Default)]
pub(crate) struct WorkerData(pub(crate) AtomicUsize);

impl WorkerData {
    // Take one off the total and the busy counter in a single atomic
    // step. Returns the (total, busy) pair seen before the change.
    pub(crate) fn decrement_both(&self) -> (usize, usize) {
        let step = INCREMENT_TOTAL | INCREMENT_BUSY;
        Self::split(self.0.fetch_sub(step, Ordering::Relaxed))
    }

    // Add one worker to the total; returns the total before the change.
    pub(crate) fn increment_worker_total(&self) -> usize {
        Self::total(self.0.fetch_add(INCREMENT_TOTAL, Ordering::Relaxed))
    }

    // Take one worker off the total; returns the total before the change.
    #[allow(dead_code)]
    pub(crate) fn decrement_worker_total(&self) -> usize {
        Self::total(self.0.fetch_sub(INCREMENT_TOTAL, Ordering::Relaxed))
    }

    // Mark one more worker busy; returns the busy count before the change.
    pub(crate) fn increment_worker_busy(&self) -> usize {
        Self::busy(self.0.fetch_add(INCREMENT_BUSY, Ordering::Relaxed))
    }

    // Mark one worker idle; returns the busy count before the change.
    pub(crate) fn decrement_worker_busy(&self) -> usize {
        Self::busy(self.0.fetch_sub(INCREMENT_BUSY, Ordering::Relaxed))
    }

    // Unpack a raw counter word into its (total, busy) pair.
    #[inline]
    pub(crate) fn split(val: usize) -> (usize, usize) {
        (Self::total(val), Self::busy(val))
    }

    // Total worker count: the upper half of the word.
    #[inline]
    fn total(val: usize) -> usize {
        val >> (BITS / 2)
    }

    // Busy worker count: the lower half of the word.
    #[inline]
    fn busy(val: usize) -> usize {
        val & WORKER_BUSY_MASK
    }
}

// Shared, lock-protected [inode,(pid,path)] map of unix binds.
// The key is the inode, the value pairs a Pid with an optional
// UnixAddr. Path is only used for UNIX domain sockets.
//
// SAFETY:
// 1. /proc/net/unix only gives inode information,
//    and does not include information on device id
//    or mount id so unfortunately we cannot check
//    for that here.
// 2. Pid is used for SO_PEERCRED getsockopt(2).
pub(crate) type UnixMap = Arc<RwLock<SydHashMap<u64, (Pid, Option<UnixAddr>)>>>;
