aboutsummaryrefslogtreecommitdiff
path: root/src/Text/Pandoc/Lua/Marshaling/Version.hs
blob: 090725afcf139dc18b98d658c43f2ab9280be33f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# LANGUAGE LambdaCase           #-}
{-# LANGUAGE OverloadedStrings    #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{- |
   Module      : Text.Pandoc.Lua.Marshaling.Version
   Copyright   : © 2019-2020 Albert Krewinkel
   License     : GNU GPL, version 2 or above

   Maintainer  : Albert Krewinkel <tarleb+pandoc@moltkeplatz.de>
   Stability   : alpha

Marshaling of @'Version'@s. The marshaled elements can be compared using
default comparison operators (like @>@ and @<=@).
-}
module Text.Pandoc.Lua.Marshaling.Version
  ( peekVersion
  , pushVersion
  )
  where

import Data.Text (Text)
import Data.Maybe (fromMaybe)
import Data.Version (Version (..), makeVersion, parseVersion, showVersion)
import Foreign.Lua (Lua, Optional (..), NumResults,
                    Peekable, Pushable, StackIndex)
import Foreign.Lua.Types.Peekable (reportValueOnFailure)
import Foreign.Lua.Userdata (ensureUserdataMetatable, pushAnyWithMetatable,
                             toAnyWithName)
import Safe (atMay, lastMay)
import Text.Pandoc.Lua.Marshaling.AnyValue (AnyValue (..))
import Text.ParserCombinators.ReadP (readP_to_S)

import qualified Foreign.Lua as Lua
import qualified Text.Pandoc.Lua.Util as LuaUtil

-- | Push a @'Version'@ to the Lua stack as userdata.
--
-- The userdata is given the metatable registered under 'versionTypeName';
-- the metatable is set up with comparison, length, indexing, iteration,
-- and string-conversion metamethods.
pushVersion :: Version -> Lua ()
pushVersion = pushAnyWithMetatable pushVersionMT
 where
  -- Registers the metamethods that make versions behave like comparable,
  -- list-like values when used from Lua.
  pushVersionMT :: Lua ()
  pushVersionMT =
    ensureUserdataMetatable versionTypeName $ do
      -- comparison operators
      LuaUtil.addFunction "__eq" __eq
      LuaUtil.addFunction "__le" __le
      LuaUtil.addFunction "__lt" __lt
      -- list-like access
      LuaUtil.addFunction "__len" __len
      LuaUtil.addFunction "__index" __index
      LuaUtil.addFunction "__pairs" __pairs
      -- conversion to string
      LuaUtil.addFunction "__tostring" __tostring

-- | Orphan instance: both 'Pushable' and 'Version' are imported from
-- other modules. Pushing delegates to 'pushVersion'.
instance Pushable Version where
  push = pushVersion

-- | Retrieve a @'Version'@ from the given stack index.
--
-- The following Lua values are accepted:
--
--   * a string, which must be consumed completely by 'parseVersion';
--   * a userdata value carrying the 'versionTypeName' metatable name;
--   * a number, which becomes a version with a single component;
--   * a table, which is read as the list of version components.
--
-- A Lua exception is thrown for strings that cannot be parsed and for
-- values of any other type.
peekVersion :: StackIndex -> Lua Version
peekVersion idx = do
  luaType <- Lua.ltype idx
  case luaType of
    Lua.TypeString   -> fromString
    Lua.TypeUserdata -> fromUserdata
    Lua.TypeNumber   -> fromNumber
    Lua.TypeTable    -> fromTable
    _                -> Lua.throwException "could not peek Version"
 where
  -- Only the last parse is inspected; it is accepted if no input is
  -- left over.
  fromString :: Lua Version
  fromString = do
    str <- Lua.peek idx
    case lastMay (readP_to_S parseVersion str) of
      Just (version, "") -> return version
      _ -> Lua.throwException ("could not parse as Version: " ++ str)

  fromUserdata :: Lua Version
  fromUserdata =
    reportValueOnFailure versionTypeName
                         (`toAnyWithName` versionTypeName)
                         idx

  fromNumber :: Lua Version
  fromNumber = (\component -> makeVersion [component]) <$> Lua.peek idx

  fromTable :: Lua Version
  fromTable = makeVersion <$> Lua.peek idx

-- | Orphan instance: both 'Peekable' and 'Version' are imported from
-- other modules. Peeking delegates to 'peekVersion'.
instance Peekable Version where
  peek = peekVersion

-- | Name used by Lua for the @'Version'@ type. It names the userdata
-- metatable set up in 'pushVersion' and is the name checked when a
-- userdata value is read back in 'peekVersion'.
versionTypeName :: String
versionTypeName = "HsLua Version"

-- | Equality metamethod: 'True' if both versions are equal.
__eq :: Version -> Version -> Lua Bool
__eq lhs rhs = pure $ lhs == rhs

-- | Less-than-or-equal metamethod: 'True' if the first version is not
-- greater than the second.
__le :: Version -> Version -> Lua Bool
__le lhs rhs = pure $ lhs <= rhs

-- | Less-than metamethod: 'True' if the first version is strictly
-- smaller than the second.
__lt :: Version -> Version -> Lua Bool
__lt lhs rhs = pure $ lhs < rhs

-- | Length metamethod: the number of components in the version branch.
__len :: Version -> Lua Int
__len version = pure (length (versionBranch version))

-- | Index metamethod.
--
-- A numeric key is treated as a 1-based position into the version
-- components; positions outside the list yield @nil@. The string key
-- @must_be_at_least@ yields the 'must_be_at_least' function. Every other
-- key yields @nil@. Exactly one value is returned in all cases.
__index :: Version -> AnyValue -> Lua NumResults
__index v (AnyValue k) = Lua.ltype k >>= \case
  Lua.TypeNumber -> do
    pos <- Lua.peek k
    -- Lua positions start at 1, list offsets at 0.
    let component = versionBranch v `atMay` (pos - 1)
    1 <$ Lua.push (Lua.Optional component)
  Lua.TypeString -> do
    (field :: Text) <- Lua.peek k
    case field of
      "must_be_at_least" -> 1 <$ Lua.pushHaskellFunction must_be_at_least
      _                  -> 1 <$ Lua.pushnil
  _ -> 1 <$ Lua.pushnil

-- | Pairs metamethod: create an iterator over the version components.
--
-- Three values are returned: the stepping function followed by two
-- @nil@s. The stepping function yields the next 1-based position
-- together with its component, or two @nil@s once the components are
-- exhausted.
__pairs :: Version -> Lua NumResults
__pairs v =
  3 <$ (Lua.pushHaskellFunction step *> Lua.pushnil *> Lua.pushnil)
 where
  components :: [Int]
  components = versionBranch v

  -- The first argument is ignored. The second is the position returned
  -- by the previous step; it is absent on the first call.
  step :: AnyValue -> Optional Int -> Lua Lua.NumResults
  step _ (Optional prevPos) =
    -- The previous 1-based position equals the 0-based offset of the
    -- next component; iteration starts at offset 0.
    let offset = fromMaybe 0 prevPos
    in case atMay components offset of
         Just part -> 2 <$ (Lua.push (offset + 1) *> Lua.push part)
         Nothing   -> 2 <$ (Lua.pushnil *> Lua.pushnil)

-- | String conversion metamethod; renders the version with
-- 'showVersion'.
__tostring :: Version -> Lua String
__tostring = pure . showVersion

-- | Default error message when a version is too old. This message is
-- formatted in Lua with the expected and actual versions as arguments,
-- in that order: the first @%s@ is filled with the expected version, the
-- second with the actual one (see 'must_be_at_least').
versionTooOldMessage :: String
versionTooOldMessage = "expected version %s or newer, got %s"

-- | Raise a Lua error if the actual version is older than the expected
-- one; otherwise return no results.
--
-- The error message is produced by calling Lua's @string.format@ with
-- the message template (the optional third argument, defaulting to
-- 'versionTooOldMessage'), the expected version, and the actual version,
-- both versions rendered with 'showVersion'.
--
-- FIXME: This function currently requires the string library to be
-- loaded.
must_be_at_least :: Version -> Version -> Optional String -> Lua NumResults
must_be_at_least actual expected optMsg
  | expected <= actual = return 0
  | otherwise = do
      -- Push the formatter and its three arguments, call it, then raise
      -- the single result as the error.
      Lua.getglobal' "string.format"
      Lua.push template
      Lua.push (showVersion expected)
      Lua.push (showVersion actual)
      Lua.call 3 1
      Lua.error
 where
  template :: String
  template = fromMaybe versionTooOldMessage (fromOptional optMsg)