aboutsummaryrefslogtreecommitdiff
path: root/src/IRE/YOLO.hsc
blob: 9cec3c4fe5c0d6a349306f843e20085acbe12d02 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
{-# LANGUAGE OverloadedStrings #-}

module IRE.YOLO (
  Detector
, Item(..)
, detect
, newDetector
) where

import Data.Aeson (ToJSON, object, toJSON, (.=))
import Data.ByteString (ByteString)
import Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import Data.Int (Int32)
import Data.Text (Text)
import Foreign.C.String (CString, withCString)
import Foreign.C.Types (CChar, CFloat(..), CSize(..))
import Foreign.Marshal.Array (peekArray)
import Foreign.Ptr (Ptr)
import Foreign.Storable (Storable(..))
import qualified Control.Concurrent.Lock as L
import qualified Data.Array.IArray as A
import qualified Data.Text as T
import qualified Data.Text.IO as TIO

import IRE.Config (YOLO(..))

#include "libdarknet.h"

-- | Handle to a loaded darknet network. The constructor is not exported;
-- values are built by 'newDetector' and consumed by 'detect'.
data Detector = Detector
  (Ptr ())            -- opaque C handle returned by libdarknet_new_detector
  (A.Array Int Text)  -- class number -> human-readable class name
  L.Lock              -- taken around every call into the C detector in 'detect'

-- | A single object reported by the detector.
data Item = Item
  { itemClass      :: Int
    -- ^ Object class number (index into the detector's class-name table).
  , itemName       :: Text
    -- ^ Human-readable description, e. g. "cat", "backpack".
  , itemConfidence :: Float
    -- ^ Detection confidence, A.K.A. probability.
  , itemBox        :: (Float, Float, Float, Float)
    -- ^ Bounding box in the order (x, y, h, w), taken from the C item's
    -- x, y, h and w fields. NOTE(review): units and origin of the box are
    -- not visible here — confirm against libdarknet.
  } deriving (Show)

-- | Read-only marshalling of the C @libdarknet_item@ struct.
--
-- The C struct carries no class name, so 'peek' fills 'itemName' with the
-- placeholder @"?"@; the caller is expected to resolve the real name from
-- the class number.
instance Storable Item where
  sizeOf _ = #{size libdarknet_item}
  alignment _ = #{alignment libdarknet_item}
  -- Nothing in this module writes an 'Item' back to C, so writing is
  -- unsupported. Fail with a message that names the culprit instead of a
  -- bare 'undefined'.
  poke _ _ = error "IRE.YOLO: Storable Item: poke is not supported"
  peek p = do
    cls  <- #{peek libdarknet_item, klass} p :: IO Int32
    conf <- #{peek libdarknet_item, confidence} p
    x    <- #{peek libdarknet_item, x} p
    y    <- #{peek libdarknet_item, y} p
    h    <- #{peek libdarknet_item, h} p
    w    <- #{peek libdarknet_item, w} p
    return $ Item (fromIntegral cls) "?" conf (x, y, h, w)

-- | Encodes an item as
-- @{"class", "name", "confidence", "box": {"x", "y", "h", "w"}}@.
instance ToJSON Item where
  toJSON item =
    object
      [ "class"      .= itemClass item
      , "name"       .= itemName item
      , "confidence" .= itemConfidence item
      , "box"        .= boxObject (itemBox item)
      ]
    where
      -- The box tuple is stored in (x, y, h, w) order.
      boxObject (bx, by, bh, bw) =
        object ["x" .= bx, "y" .= by, "h" .= bh, "w" .= bw]

-- | Create a detector from a 'YOLO' configuration.
--
-- The @names@ file is read as one class name per line; line @i@ (0-based)
-- becomes the name of class @i@. The @cfg@ and @weights@ strings are passed
-- unchanged to @libdarknet_new_detector@.
--
-- NOTE(review): the result of @libdarknet_new_detector@ is not checked for
-- NULL; confirm whether the C side can fail here.
newDetector :: YOLO -> IO Detector
newDetector (YOLO cfg weights names) = do
  classNames <- T.lines <$> TIO.readFile names
  -- 'A.listArray' bounds are inclusive: n names occupy indices 0 .. n-1.
  -- (Using @(0, n)@ would create an extra, undefined slot at index n.)
  let nameTable = A.listArray (0, length classNames - 1) classNames
  lock <- L.new
  net <- withCString cfg $ \c ->
           withCString weights $ \w ->
             libdarknet_new_detector c w
  return (Detector net nameTable lock)


-- | Run the detector over an in-memory image buffer.
--
-- @threshold@ and @tree_threshold@ are forwarded to @libdarknet_detect@ as
-- C floats. Each returned item gets its 'itemName' from the detector's class
-- table. A class number outside the table keeps the @"?"@ placeholder set by
-- 'peek' rather than throwing an array-index error later, when the name is
-- first forced (e.g. during JSON encoding).
detect :: Detector -> Float -> Float -> ByteString -> IO [Item]
detect (Detector d ns lk) threshold tree_threshold imgdata =
  unsafeUseAsCStringLen imgdata $ \(img, len) -> do
    -- Detection and reading of the items buffer both happen under the lock.
    -- Presumably the buffer behind libdarknet_get_items is owned by the
    -- detector and reused by the next libdarknet_detect call — keep the
    -- peek inside the lock.
    items <- L.with lk $ do
      CSize s <- libdarknet_detect d
                    (CFloat threshold) (CFloat tree_threshold)
                    img (CSize $ fromIntegral len)
      peekArray (fromIntegral s) (libdarknet_get_items d)
    return (map nameItem items)
  where
    -- Total lookup of the class name; out-of-range classes are left as-is.
    nameItem i
      | A.inRange (A.bounds ns) cls = i { itemName = ns A.! cls }
      | otherwise                   = i
      where
        cls = itemClass i


-- | Create a C detector from two NUL-terminated strings: the @cfg@ and
-- @weights@ values of the 'YOLO' config (presumably file paths — confirm).
-- Returns an opaque handle. NOTE(review): NULL on failure is not handled by
-- the caller.
foreign import ccall safe "libdarknet_new_detector"
  libdarknet_new_detector :: CString -> CString -> IO (Ptr ())

-- | Run detection on an image buffer given as pointer + byte length, with
-- two thresholds. Returns the number of items found; 'detect' then reads
-- that many items via 'libdarknet_get_items'.
foreign import ccall safe "libdarknet_detect"
  libdarknet_detect :: Ptr () -> CFloat -> CFloat -> Ptr CChar -> CSize -> IO CSize

-- | Pointer to the detector's array of result items.
-- NOTE(review): imported as a pure (non-IO) function. GHC may therefore
-- share or float the call; this is only sound if the returned pointer is
-- constant for a given detector handle — confirm in libdarknet.
foreign import ccall unsafe "libdarknet_get_items"
  libdarknet_get_items :: Ptr () -> Ptr Item