{-# LANGUAGE TypeApplications #-}
{-|
Module      : HsLua.Core.Trace
Copyright   : © 2017-2024 Albert Krewinkel
License     : MIT
Maintainer  : Albert Krewinkel <tarleb@hslua.org>

Helper functions to call Lua functions with tracebacks.
-}
module HsLua.Core.Trace
  ( pcallTrace
  , callTrace
  , dofileTrace
  , dostringTrace
  ) where

import Data.ByteString (ByteString)
import Foreign.C.Types
import HsLua.Core.Auxiliary (loadfile, loadstring, tostring', traceback)
import HsLua.Core.Error (Exception, LuaError, throwErrorAsException)
import HsLua.Core.Primary (gettop, insert, pcall, pushcfunction, remove)
import HsLua.Core.Run (runWith)
import HsLua.Core.Types
  ( CFunction, LuaE, NumArgs (..), NumResults (..), PreCFunction
  , Status (OK), State (..), multret )

-- | Like @'pcall'@, but sets an appropriate message handler function,
-- thereby adding a stack traceback if an error occurs.
--
-- Expects the function to be called on the stack, followed by its
-- @nargs@ arguments. A message handler is temporarily inserted directly
-- below the function and removed again once the call returns, so the
-- caller sees the same stack layout as after a plain @'pcall'@.
pcallTrace :: NumArgs -> NumResults -> LuaE e Status
pcallTrace nargs@(NumArgs nargsint) nres = do
  curtop <- gettop
  -- Index of the function to be called; it sits just below its arguments.
  let base = curtop - fromIntegral nargsint
  pushcfunction hsluaL_msghandler_ptr
  insert base  -- move the message handler below the function
  status' <- pcall nargs nres (Just base)
  -- The handler is still at @base@ after the call; drop it so only the
  -- results (or the error object) remain.
  remove base
  pure status'

-- | Like @'call'@, but adds a traceback if an error occurs.
--
-- The function is run via 'pcallTrace'. If the resulting status is
-- anything other than 'OK', the error (which a failed protected call
-- leaves on top of the stack) is raised as a Haskell exception via
-- 'throwErrorAsException'.
callTrace :: LuaError e => NumArgs -> NumResults -> LuaE e ()
callTrace nargs nres = do
  status <- pcallTrace nargs nres
  case status of
    OK -> pure ()
    _  -> throwErrorAsException

-- | Run the given file as a Lua program, while also adding a
-- traceback to the error message if an error occurs.
--
-- The file is loaded with 'loadfile'. If loading fails, its status is
-- returned unchanged and nothing is run. Otherwise the loaded chunk is
-- called with no arguments via 'pcallTrace', keeping all of its results
-- ('multret'), and the status of that call is returned.
dofileTrace :: Maybe FilePath -> LuaE e Status
dofileTrace fp = do
  loadStatus <- loadfile fp
  case loadStatus of
    OK -> pcallTrace 0 multret
    s  -> pure s

-- | Run the given string as a Lua program, while also adding a
-- traceback to the error message if an error occurs.
--
-- The string is loaded with 'loadstring'. If loading fails, its status
-- is returned unchanged and nothing is run. Otherwise the loaded chunk is
-- called with no arguments via 'pcallTrace', keeping all of its results
-- ('multret'), and the status of that call is returned.
dostringTrace :: ByteString -> LuaE e Status
dostringTrace s = do
  loadStatus <- loadstring s
  case loadStatus of
    OK  -> pcallTrace 0 multret
    err -> pure err

-- | Helper function used as message handler if the function given to
-- pcall fails.
--
-- Converts the error object at stack index 1 to a string with 'tostring''.
-- It then pushes a traceback of @l@ with that string as the message,
-- starting at level 2 (presumably to skip the handler's own frame).
-- It reports one result: the value it pushed last.
hsluaL_msghandler :: State -> IO NumResults
hsluaL_msghandler l = runWith l $ do
  msg <- tostring' @Exception 1
  traceback l (Just msg) 2
  pure (NumResults 1)

-- The message handler must be passed to Lua as a C function pointer:
-- export it under a C name, then import the address of that symbol.
foreign export ccall hsluaL_msghandler :: PreCFunction
foreign import ccall "&hsluaL_msghandler" hsluaL_msghandler_ptr :: CFunction