{-# LANGUAGE CPP #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module System.Systemd.Internal where

import           Control.Exception         (bracket)
import           Control.Monad
import           Control.Monad.IO.Class    (liftIO)
import           Control.Monad.Trans.Maybe
import qualified Data.ByteString.Char8     as BC
import           Data.ByteString.Unsafe    (unsafeUseAsCStringLen)
import           Data.List
import           Foreign.C.Types           (CInt (..))
import           Foreign.Marshal           (free, mallocBytes)
import           Foreign.Ptr
import           Network.Socket
import           Network.Socket.Address    hiding (recvFrom, sendTo)
import           Network.Socket.ByteString
import           System.Posix.Env
import           System.Posix.Types        (Fd (..))

-- | Name of the environment variable that holds the path of the
-- notification socket.
envVariableName :: String
envVariableName = "NOTIFY_SOCKET"

-- | Binding to the C function @sd_notify_with_fd@.
--
-- Arguments, in the order 'sendBufWithFdTo' supplies them:
--
--   1. file descriptor of the socket to send on
--   2. pointer to the message bytes
--   3. length of the message
--   4. pointer to the serialised destination socket address
--   5. size of that address
--   6. file descriptor to pass along with the message
--
-- The result is converted to 'Int' by the caller and compared against the
-- message length, i.e. it is treated as a byte count.
-- NOTE(review): the C implementation is not visible from this module; the
-- error convention of the return value should be confirmed against it.
foreign import ccall unsafe "sd_notify_with_fd"
  c_sd_notify_with_fd :: CInt -> Ptr a -> CInt -> Ptr b -> CInt -> CInt -> IO CInt

-- | Unset every environment variable related to systemd.
--
-- After this call, functions such as 'System.Systemd.Daemon.notify' and
-- 'System.Systemd.Daemon.getActivatedSockets' will return 'Nothing'.
unsetEnvironnement :: IO ()
unsetEnvironnement =
  mapM_
    unsetEnv
    [ envVariableName
    , "LISTEN_PID"
    , "LISTEN_FDS"
    , "LISTEN_FDNAMES"
    ]

-- | Hand a message, its destination address and a file descriptor to the C
-- helper 'c_sd_notify_with_fd', using @sock@ as the sending socket.
--
-- * @sock@     – socket whose descriptor is given to the helper
-- * @state@    – message bytes
-- * @addr@     – destination address, serialised into a temporary buffer
-- * @fdToSend@ – descriptor forwarded to the helper as its last argument
--
-- Returns the helper's result converted to 'Int' ('notifyWithFD_' treats it
-- as the number of bytes sent).
--
-- The descriptor of @sock@ is obtained with this module's 'socketToFd_', which
-- borrows it from the socket. 'Network.Socket.socketToFd' is deliberately not
-- used: it closes the socket and returns a duplicated descriptor that the
-- caller must close, so calling it here leaked one descriptor per call.
-- The caller stays responsible for closing @sock@.
sendBufWithFdTo :: Socket -> BC.ByteString -> SockAddr -> Fd -> IO Int
sendBufWithFdTo sock state addr fdToSend =
  unsafeUseAsCStringLen state $ \(msgPtr, msgLen) ->
    bracket (mallocBytes addrSize) free $ \p_addr -> do
      -- Serialise the address inside the bracket so the buffer is released
      -- even if writing the address throws.
      pokeSocketAddress p_addr addr
      fd <- socketToFd_ sock
      fromIntegral
        <$> c_sd_notify_with_fd
              (fromIntegral fd)
              msgPtr
              (fromIntegral msgLen)
              p_addr
              (fromIntegral addrSize)
              (fromIntegral fdToSend)
  where
    -- Number of bytes needed to hold the serialised address.
    addrSize :: Int
    addrSize = sizeOfSocketAddress addr

-- | Send @state@ as a datagram to the Unix socket named by the
-- 'envVariableName' environment variable, optionally passing a file
-- descriptor along with it.
--
-- * @unset_env@ – when 'True', 'unsetEnvironnement' is run after the attempt,
--   whether or not it succeeded
-- * @state@     – message to send; an empty message is rejected
-- * @fd@        – when present, the message is sent through
--   'sendBufWithFdTo' together with this descriptor
--
-- Returns @Just ()@ when at least @length state@ bytes were reported as sent,
-- and 'Nothing' when the message is empty, the variable is unset, its value
-- is not an acceptable path, or fewer bytes were sent.  Exceptions raised
-- while creating the socket or sending are propagated; the socket is closed
-- in every case.
notifyWithFD_ :: Bool -> String -> Maybe Fd -> IO (Maybe ())
notifyWithFD_ unset_env state fd = do
    res <- runMaybeT notifyImpl
    when unset_env unsetEnvironnement
    return res
  where
    -- A usable path has at least two characters and starts with '@'
    -- (abstract socket) or '/' (filesystem socket).
    isValidPath :: String -> Bool
    isValidPath path =
      length path >= 2 && ("@" `isPrefixOf` path || "/" `isPrefixOf` path)

    -- Abstract sockets are written with a leading '@' in the environment
    -- and addressed with a leading NUL byte.
    toAddressPath :: String -> String
    toAddressPath ('@' : rest) = '\0' : rest
    toAddressPath path         = path

    notifyImpl :: MaybeT IO ()
    notifyImpl = do
      guard $ state /= ""

      socketPath <- MaybeT (getEnv envVariableName)
      guard $ isValidPath socketPath

      let target  = SockAddrUnix (toAddressPath socketPath)
          payload = BC.pack state

      -- bracket guarantees the socket is closed even if sending throws.
      nbBytes <- liftIO $
        bracket (socket AF_UNIX Datagram 0) close $ \socketFd ->
          case fd of
            Nothing    -> sendTo socketFd payload target
            Just sock' -> sendBufWithFdTo socketFd payload target sock'

      guard $ nbBytes >= length state

-- | Read the file descriptor underlying a 'Socket' and wrap it as an 'Fd'.
--
-- The accessor used depends on the version of the @network@ package:
-- 'fdSocket' before 3.1.0, 'unsafeFdSocket' from 3.1.0 on.
socketToFd_ :: Socket -> IO Fd
socketToFd_ sock = Fd <$> rawDescriptor sock
  where
#if ! MIN_VERSION_network(3,1,0)
    rawDescriptor = fdSocket
#else
    rawDescriptor = unsafeFdSocket
#endif

-- | Build a 'Socket' from a file descriptor by passing the descriptor's
-- underlying integer to 'mkSocket'.
fdToSocket :: Fd -> IO Socket
fdToSocket (Fd raw) = mkSocket raw