aboutsummaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/Main.hs37
-rw-r--r--src/Sproxy/Application.hs372
-rw-r--r--src/Sproxy/Application/Cookie.hs44
-rw-r--r--src/Sproxy/Application/OAuth2.hs18
-rw-r--r--src/Sproxy/Application/OAuth2/Common.hs39
-rw-r--r--src/Sproxy/Application/OAuth2/Google.hs78
-rw-r--r--src/Sproxy/Application/OAuth2/LinkedIn.hs83
-rw-r--r--src/Sproxy/Application/State.hs30
-rw-r--r--src/Sproxy/Config.hs88
-rw-r--r--src/Sproxy/Logging.hs99
-rw-r--r--src/Sproxy/Server.hs190
-rw-r--r--src/Sproxy/Server/DB.hs189
12 files changed, 1267 insertions, 0 deletions
diff --git a/src/Main.hs b/src/Main.hs
new file mode 100644
index 0000000..7101af0
--- /dev/null
+++ b/src/Main.hs
@@ -0,0 +1,37 @@
+{-# LANGUAGE QuasiQuotes #-}
+module Main (
+ main
+) where
+
+import Data.Maybe (fromJust)
+import Data.Version (showVersion)
+import Paths_sproxy2 (version) -- from cabal
+import System.Environment (getArgs)
+import Text.InterpolatedString.Perl6 (qc)
+import qualified System.Console.Docopt.NoTH as O
+
+import Sproxy.Server (server)
+
+usage :: String
+usage = "sproxy2 " ++ showVersion version ++
+ " - HTTP proxy for authenticating users via OAuth2" ++ [qc|
+
+Usage:
+ sproxy2 [options]
+
+Options:
+ -c, --config=FILE Configuration file [default: sproxy.yml]
+ -h, --help Show this message
+
+|]
+
+main :: IO ()
+main = do
+ doco <- O.parseUsageOrExit usage
+ args <- O.parseArgsOrExit doco =<< getArgs
+ if args `O.isPresent` O.longOption "help"
+ then putStrLn $ O.usage doco
+ else do
+ let configFile = fromJust . O.getArg args $ O.longOption "config"
+ server configFile
+
diff --git a/src/Sproxy/Application.hs b/src/Sproxy/Application.hs
new file mode 100644
index 0000000..2391220
--- /dev/null
+++ b/src/Sproxy/Application.hs
@@ -0,0 +1,372 @@
+{-# LANGUAGE LambdaCase #-}
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
+module Sproxy.Application (
+ sproxy
+, redirect
+) where
+
+import Blaze.ByteString.Builder (toByteString)
+import Blaze.ByteString.Builder.ByteString (fromByteString)
+import Control.Exception (SomeException, catch)
+import Data.ByteString (ByteString)
+import Data.ByteString as BS (break, intercalate)
+import Data.Char (toLower)
+import Data.ByteString.Char8 (pack, unpack)
+import Data.ByteString.Lazy (fromStrict)
+import Data.Conduit (Flush(Chunk), mapOutput)
+import Data.HashMap.Strict as HM (HashMap, foldrWithKey, lookup)
+import Data.List (find, partition)
+import Data.Map as Map (delete, fromListWith, insert, insertWith, toList)
+import Data.Maybe (fromJust, fromMaybe)
+import Data.Monoid ((<>))
+import Data.Text (Text)
+import Data.Text.Encoding (decodeUtf8, encodeUtf8)
+import Data.Time.Clock.POSIX (posixSecondsToUTCTime)
+import Data.Word (Word16)
+import Data.Word8 (_colon)
+import Foreign.C.Types (CTime(..))
+import Network.HTTP.Client.Conduit (bodyReaderSource)
+import Network.HTTP.Conduit (requestBodySourceChunkedIO, requestBodySourceIO)
+import Network.HTTP.Types (RequestHeaders, ResponseHeaders, hConnection,
+ hContentLength, hContentType, hCookie, hLocation, methodGet)
+import Network.HTTP.Types.Status ( Status(..), badRequest400, forbidden403, found302,
+ internalServerError500, methodNotAllowed405, movedPermanently301,
+ networkAuthenticationRequired511, notFound404, ok200, seeOther303, temporaryRedirect307 )
+import Network.Socket (NameInfoFlag(NI_NUMERICHOST), getNameInfo)
+import Network.Wai.Conduit (sourceRequestBody, responseSource)
+import System.FilePath.Glob (Pattern, match)
+import System.Posix.Time (epochTime)
+import Text.InterpolatedString.Perl6 (qc)
+import Web.Cookie (Cookies, parseCookies, renderCookies)
+import qualified Network.HTTP.Client as BE
+import qualified Network.Wai as W
+import qualified Web.Cookie as WC
+
+import Sproxy.Application.Cookie (AuthCookie(..), AuthUser(..), cookieDecode, cookieEncode)
+import Sproxy.Application.OAuth2.Common (OAuth2Client(..))
+import Sproxy.Config(BackendConf(..))
+import Sproxy.Server.DB (Database, userExists, userGroups)
+import qualified Sproxy.Application.State as State
+import qualified Sproxy.Logging as Log
+
+
+redirect :: Word16 -> W.Application
+redirect p req resp =
+ case W.requestHeaderHost req of
+ Nothing -> badRequest "missing host" req resp
+ Just host -> do
+ Log.info $ "redirecting to " ++ show location ++ ": " ++ showReq req
+ resp $ W.responseBuilder status [(hLocation, location)] mempty
+ where
+ status = if W.requestMethod req == methodGet then movedPermanently301 else temporaryRedirect307
+ (domain, _) = BS.break (== _colon) host
+ newhost = if p == 443 then domain else domain <> ":" <> pack (show p)
+ location = "https://" <> newhost <> W.rawPathInfo req <> W.rawQueryString req
+
+
+sproxy :: ByteString -> Database -> HashMap Text OAuth2Client -> [(Pattern, BackendConf, BE.Manager)] -> W.Application
+sproxy key db oa2 backends = logException $ \req resp -> do
+ Log.debug $ "sproxy <<< " ++ showReq req
+ case W.requestHeaderHost req of
+ Nothing -> badRequest "missing host" req resp
+ Just host ->
+ case find (\(p, _, _) -> match p (unpack host)) backends of
+ Nothing -> notFound "backend" req resp
+ Just (_, be, mgr) -> do
+ let cookieName = pack $ beCookieName be
+ cookieDomain = pack <$> beCookieDomain be
+ case W.pathInfo req of
+ ["robots.txt"] -> get robots req resp
+ (".sproxy":proxy) ->
+ case proxy of
+ ["logout"] ->
+ case extractCookie key Nothing cookieName req of
+ Nothing -> notFound "logout without the cookie" req resp
+ Just _ -> get (logout cookieName cookieDomain) req resp
+ ["oauth2", provider] ->
+ case HM.lookup provider oa2 of
+ Nothing -> notFound "OAuth2 provider" req resp
+ Just oa2c -> get (oauth2callback key db (provider, oa2c) be) req resp
+ _ -> notFound "proxy" req resp
+ _ -> do
+ now <- Just <$> epochTime
+ case extractCookie key now cookieName req of
+ Nothing -> authenticationRequired key oa2 req resp
+ Just cs@(authCookie, _) ->
+ authorize db cs req >>= \case
+ Nothing -> forbidden authCookie req resp
+ Just req' -> forward mgr req' resp
+
+
+robots :: W.Application
+robots _ resp = resp $
+ W.responseLBS ok200 [(hContentType, "text/plain; charset=utf-8")]
+ "User-agent: *\nDisallow: /"
+
+
+oauth2callback :: ByteString -> Database -> (Text, OAuth2Client) -> BackendConf -> W.Application
+oauth2callback key db (provider, oa2c) be req resp =
+ case param "code" of
+ Nothing -> badRequest "missing auth code" req resp
+ Just code ->
+ case param "state" of
+ Nothing -> badRequest "missing auth state" req resp
+ Just state ->
+ case State.decode key state of
+ Left msg -> badRequest ("invalid state: " ++ msg) req resp
+ Right path -> do
+ au <- oauth2Authenticate oa2c code (redirectURL req provider)
+ let email = map toLower $ auEmail au
+ Log.info $ "login `" ++ email ++ "' by " ++ show provider
+ exists <- userExists db email
+ if exists then authenticate key be au{auEmail = email} path req resp
+ else userNotFound email req resp
+ where
+ param p = do
+ (_, v) <- find ((==) p . fst) $ W.queryString req
+ v
+
+
+-- XXX: RFC6265: the user agent MUST NOT attach more than one Cookie header field
+extractCookie :: ByteString -> Maybe CTime -> ByteString -> W.Request -> Maybe (AuthCookie, Cookies)
+extractCookie key now name req = do
+ (_, cookies) <- find ((==) hCookie . fst) $ W.requestHeaders req
+ (auth, others) <- discriminate cookies
+ case cookieDecode key auth of
+ Left _ -> Nothing
+ Right cookie -> if maybe True (acExpiry cookie >) now
+ then Just (cookie, others) else Nothing
+ where discriminate cs =
+ case partition ((==) name . fst) $ parseCookies cs of
+ ((_, x):_, xs) -> Just (x, xs)
+ _ -> Nothing
+
+
+authenticate :: ByteString -> BackendConf -> AuthUser -> ByteString -> W.Application
+authenticate key be user path req resp = do
+ now <- epochTime
+ let host = fromJust $ W.requestHeaderHost req
+ domain = pack <$> beCookieDomain be
+ expiry = now + CTime (beCookieMaxAge be)
+ authCookie = AuthCookie { acUser = user, acExpiry = expiry }
+ cookie = WC.def {
+ WC.setCookieName = pack $ beCookieName be
+ , WC.setCookieHttpOnly = True
+ , WC.setCookiePath = Just "/"
+ , WC.setCookieSameSite = Nothing
+ , WC.setCookieSecure = True
+ , WC.setCookieValue = cookieEncode key authCookie
+ , WC.setCookieDomain = domain
+ , WC.setCookieExpires = Just . posixSecondsToUTCTime . realToFrac $ expiry
+ }
+ resp $ W.responseLBS seeOther303 [
+ (hLocation, "https://" <> host <> path)
+ , ("Set-Cookie", toByteString $ WC.renderSetCookie cookie)
+ ] ""
+
+
+authorize :: Database -> (AuthCookie, Cookies) -> W.Request -> IO (Maybe W.Request)
+authorize db (authCookie, otherCookies) req = do
+ grps <- userGroups db email domain path method
+ if null grps then return Nothing
+ else do
+ ip <- pack . fromJust . fst <$> getNameInfo [NI_NUMERICHOST] True False (W.remoteHost req)
+ return . Just $ req {
+ W.requestHeaders = toList $
+ insert "From" (pack email) $
+ insert "X-Groups" (BS.intercalate "," grps) $
+ insert "X-Given-Name" given $
+ insert "X-Family-Name" family $
+ insert "X-Forwarded-Proto" "https" $
+ insertWith (flip combine) "X-Forwarded-For" ip $
+ setCookies otherCookies $
+ fromListWith combine $ W.requestHeaders req
+ }
+ where
+ user = acUser authCookie
+ email = auEmail user
+ given = pack $ auGivenName user
+ family = pack $ auFamilyName user
+ domain = decodeUtf8 . fromJust $ W.requestHeaderHost req
+ path = decodeUtf8 $ W.rawPathInfo req
+ method = decodeUtf8 $ W.requestMethod req
+ combine a b = a <> "," <> b
+ setCookies [] = delete hCookie
+ setCookies cs = insert hCookie (toByteString . renderCookies $ cs)
+
+
+forward :: BE.Manager -> W.Application
+forward mgr req resp = do
+ let beReq = BE.defaultRequest
+ { BE.method = W.requestMethod req
+ , BE.path = W.rawPathInfo req
+ , BE.queryString = W.rawQueryString req
+ , BE.requestHeaders = modifyRequestHeaders $ W.requestHeaders req
+ , BE.redirectCount = 0
+ , BE.decompress = const False
+ , BE.requestBody = case W.requestBodyLength req of
+ W.ChunkedBody -> requestBodySourceChunkedIO (sourceRequestBody req)
+ W.KnownLength l -> requestBodySourceIO (fromIntegral l) (sourceRequestBody req)
+ }
+ msg = unpack (BE.method beReq <> " " <> BE.path beReq <> BE.queryString beReq)
+ Log.debug $ "BACKEND <<< " ++ msg ++ " " ++ show (BE.requestHeaders beReq)
+ BE.withResponse beReq mgr $ \res -> do
+ let status = BE.responseStatus res
+ headers = modifyResponseHeaders $ BE.responseHeaders res
+ body = mapOutput (Chunk . fromByteString) . bodyReaderSource $ BE.responseBody res
+ logging = if statusCode status `elem` [ 400, 500 ] then
+ Log.warn else Log.debug
+ logging $ "BACKEND >>> " ++ show (statusCode status) ++ " on " ++ msg ++ "\n"
+ resp $ responseSource status headers body
+
+
+modifyRequestHeaders :: RequestHeaders -> RequestHeaders
+modifyRequestHeaders = filter (\(n, _) -> n `notElem` ban)
+ where
+ ban =
+ [
+ hConnection
+ , hContentLength -- XXX to avoid duplicate header
+ ]
+
+modifyResponseHeaders :: ResponseHeaders -> ResponseHeaders
+modifyResponseHeaders = filter (\(n, _) -> n `notElem` ban)
+ where
+ ban =
+ [
+ hConnection
+ ]
+
+authenticationRequired :: ByteString -> HashMap Text OAuth2Client -> W.Application
+authenticationRequired key oa2 req resp = do
+ Log.info $ "511 Unauthenticated: " ++ showReq req
+ resp $ W.responseLBS networkAuthenticationRequired511 [(hContentType, "text/html; charset=utf-8")] page
+ where
+ path = W.rawPathInfo req -- FIXME: make it more robust for non-GET or XMLHTTPRequest?
+ state = State.encode key path
+ authLink :: Text -> OAuth2Client -> ByteString -> ByteString
+ authLink provider oa2c html =
+ let u = oauth2AuthorizeURL oa2c state (redirectURL req provider)
+ d = pack $ oauth2Description oa2c
+ in [qc|{html}<p><a href="{u}">Authenticate with {d}</a></p>|]
+ authHtml = foldrWithKey authLink "" oa2
+ page = fromStrict [qc|
+<!DOCTYPE html>
+<html lang="en">
+ <head>
+ <meta charset="utf-8">
+ <title>Authentication required</title>
+ </head>
+ <body style="text-align:center;">
+ <h1>Authentication required</h1>
+ {authHtml}
+ </body>
+</html>
+|]
+
+
+forbidden :: AuthCookie -> W.Application
+forbidden ac req resp = do
+ Log.info $ "403 Forbidden (" ++ email ++ "): " ++ showReq req
+ resp $ W.responseLBS forbidden403 [(hContentType, "text/html; charset=utf-8")] page
+ where
+ email = auEmail . acUser $ ac
+ page = fromStrict [qc|
+<!DOCTYPE html>
+<html lang="en">
+ <head>
+ <meta charset="utf-8">
+ <title>Access Denied</title>
+ </head>
+ <body>
+ <h1>Access Denied</h1>
+ <p>You are currently logged in as <strong>{email}</strong></p>
+ <p><a href="/.sproxy/logout">Logout</a></p>
+ </body>
+</html>
+|]
+
+
+userNotFound :: String -> W.Application
+userNotFound email _ resp = do
+ Log.info $ "404 User not found (" ++ email ++ ")"
+ resp $ W.responseLBS notFound404 [(hContentType, "text/html; charset=utf-8")] page
+ where
+ page = fromStrict [qc|
+<!DOCTYPE html>
+<html lang="en">
+ <head>
+ <meta charset="utf-8">
+ <title>Access Denied</title>
+ </head>
+ <body>
+ <h1>Access Denied</h1>
+ <p>You are not allowed to login as <strong>{email}</strong></p>
+ <p><a href="/">Main page</a></p>
+ </body>
+</html>
+|]
+
+
+logout :: ByteString -> Maybe ByteString -> W.Application
+logout name domain req resp = do
+ let host = fromJust $ W.requestHeaderHost req
+ cookie = WC.def {
+ WC.setCookieName = name
+ , WC.setCookieHttpOnly = True
+ , WC.setCookiePath = Just "/"
+ , WC.setCookieSameSite = Just WC.sameSiteStrict
+ , WC.setCookieSecure = True
+ , WC.setCookieValue = "goodbye"
+ , WC.setCookieDomain = domain
+ , WC.setCookieExpires = Just . posixSecondsToUTCTime . realToFrac $ CTime 0
+ }
+ resp $ W.responseLBS found302 [
+ (hLocation, "https://" <> host)
+ , ("Set-Cookie", toByteString $ WC.renderSetCookie cookie)
+ ] ""
+
+
+badRequest ::String -> W.Application
+badRequest msg req resp = do
+ Log.warn $ "400 Bad Request (" ++ msg ++ "): " ++ showReq req
+ resp $ W.responseLBS badRequest400 [] "Bad Request"
+
+
+notFound ::String -> W.Application
+notFound msg req resp = do
+ Log.warn $ "404 Not Found (" ++ msg ++ "): " ++ showReq req
+ resp $ W.responseLBS notFound404 [] "Not Found"
+
+
+logException :: W.Middleware
+logException app req resp = catch (app req resp) $ \e -> do
+ Log.error $ "500 Internal Error: " ++ show (e :: SomeException) ++ " on " ++ showReq req
+ resp $ W.responseLBS internalServerError500 [] "Internal Error"
+
+
+get :: W.Middleware
+get app req resp
+ | W.requestMethod req == methodGet = app req resp
+ | otherwise = do
+ Log.warn $ "405 Method Not Allowed: " ++ showReq req
+ resp $ W.responseLBS methodNotAllowed405 [("Allow", "GET")] "Method Not Allowed"
+
+
+redirectURL :: W.Request -> Text -> ByteString
+redirectURL req provider =
+ "https://" <> fromJust (W.requestHeaderHost req)
+ <> "/.sproxy/oauth2/" <> encodeUtf8 provider
+
+
+-- XXX: make sure not to reveal the cookie, which can be valid (!)
+showReq :: W.Request -> String
+showReq req =
+ unpack ( W.requestMethod req <> " "
+ <> fromMaybe "<no host>" (W.requestHeaderHost req)
+ <> W.rawPathInfo req <> W.rawQueryString req <> " " )
+ ++ show (fromMaybe "-" $ W.requestHeaderReferer req) ++ " "
+ ++ show (fromMaybe "-" $ W.requestHeaderUserAgent req)
+ ++ " from " ++ show (W.remoteHost req)
+
diff --git a/src/Sproxy/Application/Cookie.hs b/src/Sproxy/Application/Cookie.hs
new file mode 100644
index 0000000..07cc162
--- /dev/null
+++ b/src/Sproxy/Application/Cookie.hs
@@ -0,0 +1,44 @@
+module Sproxy.Application.Cookie (
+ AuthCookie(..)
+, AuthUser(..)
+, cookieDecode
+, cookieEncode
+) where
+
+import Data.ByteString (ByteString)
+import Foreign.C.Types (CTime(..))
+import qualified Data.Serialize as DS
+
+import qualified Sproxy.Application.State as State
+
+data AuthUser = AuthUser {
+ auEmail :: String
+, auGivenName :: String
+, auFamilyName :: String
+}
+
+data AuthCookie = AuthCookie {
+ acUser :: AuthUser
+, acExpiry :: CTime
+}
+
+instance DS.Serialize AuthCookie where
+ put c = DS.put (auEmail u, auGivenName u, auFamilyName u, x)
+ where u = acUser c
+ x = (\(CTime i) -> i) $ acExpiry c
+ get = do
+ (e, n, f, x) <- DS.get
+ return AuthCookie {
+ acUser = AuthUser { auEmail = e, auGivenName = n, auFamilyName = f }
+ , acExpiry = CTime x
+ }
+
+
+cookieDecode :: ByteString -> ByteString -> Either String AuthCookie
+cookieDecode key d = State.decode key d >>= DS.decode
+
+
+cookieEncode :: ByteString -> AuthCookie -> ByteString
+cookieEncode key = State.encode key . DS.encode
+
+
diff --git a/src/Sproxy/Application/OAuth2.hs b/src/Sproxy/Application/OAuth2.hs
new file mode 100644
index 0000000..0f7d6e8
--- /dev/null
+++ b/src/Sproxy/Application/OAuth2.hs
@@ -0,0 +1,18 @@
+{-# LANGUAGE OverloadedStrings #-}
+module Sproxy.Application.OAuth2 (
+ providers
+) where
+
+import Data.HashMap.Strict (HashMap, fromList)
+import Data.Text (Text)
+
+import Sproxy.Application.OAuth2.Common (OAuth2Provider)
+import qualified Sproxy.Application.OAuth2.Google as Google
+import qualified Sproxy.Application.OAuth2.LinkedIn as LinkedIn
+
+providers :: HashMap Text OAuth2Provider
+providers = fromList [
+ ("google" , Google.provider)
+ , ("linkedin" , LinkedIn.provider)
+ ]
+
diff --git a/src/Sproxy/Application/OAuth2/Common.hs b/src/Sproxy/Application/OAuth2/Common.hs
new file mode 100644
index 0000000..07fb759
--- /dev/null
+++ b/src/Sproxy/Application/OAuth2/Common.hs
@@ -0,0 +1,39 @@
+{-# LANGUAGE OverloadedStrings #-}
+module Sproxy.Application.OAuth2.Common (
+ AccessTokenBody(..)
+, OAuth2Client(..)
+, OAuth2Provider
+) where
+
+import Control.Applicative (empty)
+import Data.Aeson (FromJSON, parseJSON, Value(Object), (.:))
+import Data.ByteString(ByteString)
+
+import Sproxy.Application.Cookie (AuthUser)
+
+data OAuth2Client = OAuth2Client {
+ oauth2Description :: String
+, oauth2AuthorizeURL
+ :: ByteString -- state
+ -> ByteString -- redirect url
+ -> ByteString
+, oauth2Authenticate
+ :: ByteString -- code
+ -> ByteString -- redirect url
+ -> IO AuthUser
+}
+
+type OAuth2Provider = (ByteString, ByteString) -> OAuth2Client
+
+-- | RFC6749. We ignore optional token_type ("Bearer" from Google, omitted by LinkedIn)
+-- and expires_in because we don't use them, *and* expires_in creates troubles:
+-- it's an integer from Google and string from LinkedIn (sic!)
+data AccessTokenBody = AccessTokenBody {
+ accessToken :: String
+} deriving (Eq, Show)
+
+instance FromJSON AccessTokenBody where
+ parseJSON (Object v) = AccessTokenBody
+ <$> v .: "access_token"
+ parseJSON _ = empty
+
diff --git a/src/Sproxy/Application/OAuth2/Google.hs b/src/Sproxy/Application/OAuth2/Google.hs
new file mode 100644
index 0000000..6b68f44
--- /dev/null
+++ b/src/Sproxy/Application/OAuth2/Google.hs
@@ -0,0 +1,78 @@
+{-# LANGUAGE DeriveDataTypeable #-}
+{-# LANGUAGE OverloadedStrings #-}
+module Sproxy.Application.OAuth2.Google (
+ provider
+) where
+
+import Control.Applicative (empty)
+import Control.Exception (Exception, throwIO)
+import Data.Aeson (FromJSON, decode, parseJSON, Value(Object), (.:))
+import Data.ByteString.Lazy (ByteString)
+import Data.Monoid ((<>))
+import Data.Typeable (Typeable)
+import Network.HTTP.Types (hContentType)
+import Network.HTTP.Types.URI (urlEncode)
+import qualified Network.HTTP.Conduit as H
+
+import Sproxy.Application.Cookie (AuthUser(..))
+import Sproxy.Application.OAuth2.Common (AccessTokenBody(accessToken), OAuth2Client(..), OAuth2Provider)
+
+
+provider :: OAuth2Provider
+provider (client_id, client_secret) =
+ OAuth2Client {
+ oauth2Description = "Google"
+ , oauth2AuthorizeURL = \state redirect_uri ->
+ "https://accounts.google.com/o/oauth2/v2/auth"
+ <> "?scope=" <> urlEncode True "https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile"
+ <> "&client_id=" <> urlEncode True client_id
+ <> "&prompt=select_account"
+ <> "&redirect_uri=" <> urlEncode True redirect_uri
+ <> "&response_type=code"
+ <> "&state=" <> urlEncode True state
+
+ , oauth2Authenticate = \code redirect_uri -> do
+ let treq = H.setQueryString [
+ ("client_id" , Just client_id)
+ , ("client_secret" , Just client_secret)
+ , ("code" , Just code)
+ , ("grant_type" , Just "authorization_code")
+ , ("redirect_uri" , Just redirect_uri)
+ ] $ (H.parseRequest_ "POST https://www.googleapis.com/oauth2/v4/token") {
+ H.requestHeaders = [
+ (hContentType, "application/x-www-form-urlencoded")
+ ]
+ }
+ mgr <- H.newManager H.tlsManagerSettings
+ tresp <- H.httpLbs treq mgr
+ case decode $ H.responseBody tresp of
+ Nothing -> throwIO $ GoogleException tresp
+ Just atResp -> do
+ ureq <- H.parseRequest $ "https://www.googleapis.com/oauth2/v1/userinfo?access_token=" ++ accessToken atResp
+ uresp <- H.httpLbs ureq mgr
+ case decode $ H.responseBody uresp of
+ Nothing -> throwIO $ GoogleException uresp
+ Just u -> return AuthUser { auEmail = email u, auGivenName = givenName u, auFamilyName = familyName u }
+ }
+
+
+data GoogleException = GoogleException (H.Response ByteString)
+ deriving (Show, Typeable)
+
+
+instance Exception GoogleException
+
+
+data GoogleUserInfo = GoogleUserInfo {
+ email :: String
+, givenName :: String
+, familyName :: String
+} deriving (Eq, Show)
+
+instance FromJSON GoogleUserInfo where
+ parseJSON (Object v) = GoogleUserInfo
+ <$> v .: "email"
+ <*> v .: "given_name"
+ <*> v .: "family_name"
+ parseJSON _ = empty
+
diff --git a/src/Sproxy/Application/OAuth2/LinkedIn.hs b/src/Sproxy/Application/OAuth2/LinkedIn.hs
new file mode 100644
index 0000000..b60afde
--- /dev/null
+++ b/src/Sproxy/Application/OAuth2/LinkedIn.hs
@@ -0,0 +1,83 @@
+{-# LANGUAGE DeriveDataTypeable #-}
+{-# LANGUAGE OverloadedStrings #-}
+module Sproxy.Application.OAuth2.LinkedIn (
+ provider
+) where
+
+import Control.Applicative (empty)
+import Control.Exception (Exception, throwIO)
+import Data.Aeson (FromJSON, decode, parseJSON, Value(Object), (.:))
+import Data.ByteString.Char8 (pack)
+import Data.ByteString.Lazy (ByteString)
+import Data.Monoid ((<>))
+import Data.Typeable (Typeable)
+import Network.HTTP.Types (hContentType)
+import Network.HTTP.Types.URI (urlEncode)
+import qualified Network.HTTP.Conduit as H
+
+import Sproxy.Application.Cookie (AuthUser(..))
+import Sproxy.Application.OAuth2.Common (AccessTokenBody(accessToken), OAuth2Client(..), OAuth2Provider)
+
+
+provider :: OAuth2Provider
+provider (client_id, client_secret) =
+ OAuth2Client {
+ oauth2Description = "LinkedIn"
+ , oauth2AuthorizeURL = \state redirect_uri ->
+ "https://www.linkedin.com/oauth/v2/authorization"
+ <> "?scope=r_basicprofile%20r_emailaddress"
+ <> "&client_id=" <> urlEncode True client_id
+ <> "&redirect_uri=" <> urlEncode True redirect_uri
+ <> "&response_type=code"
+ <> "&state=" <> urlEncode True state
+
+ , oauth2Authenticate = \code redirect_uri -> do
+ let treq = H.setQueryString [
+ ("client_id" , Just client_id)
+ , ("client_secret" , Just client_secret)
+ , ("code" , Just code)
+ , ("grant_type" , Just "authorization_code")
+ , ("redirect_uri" , Just redirect_uri)
+ ] $ (H.parseRequest_ "POST https://www.linkedin.com/oauth/v2/accessToken") {
+ H.requestHeaders = [
+ (hContentType, "application/x-www-form-urlencoded")
+ ]
+ }
+ mgr <- H.newManager H.tlsManagerSettings
+ tresp <- H.httpLbs treq mgr
+ case decode $ H.responseBody tresp of
+ Nothing -> throwIO $ LinkedInException tresp
+ Just atResp -> do
+ let ureq = (H.parseRequest_ "https://api.linkedin.com/v1/people/\
+ \~:(email-address,first-name,last-name)?format=json") {
+ H.requestHeaders = [ ("Authorization", "Bearer " <> pack (accessToken atResp)) ]
+ }
+ uresp <- H.httpLbs ureq mgr
+ case decode $ H.responseBody uresp of
+ Nothing -> throwIO $ LinkedInException uresp
+ Just u -> return AuthUser { auEmail = emailAddress u
+ , auGivenName = firstName u
+ , auFamilyName = lastName u }
+ }
+
+
+data LinkedInException = LinkedInException (H.Response ByteString)
+ deriving (Show, Typeable)
+
+
+instance Exception LinkedInException
+
+
+data LinkedInUserInfo = LinkedInUserInfo {
+ emailAddress :: String
+, firstName :: String
+, lastName :: String
+} deriving (Eq, Show)
+
+instance FromJSON LinkedInUserInfo where
+ parseJSON (Object v) = LinkedInUserInfo
+ <$> v .: "emailAddress"
+ <*> v .: "firstName"
+ <*> v .: "lastName"
+ parseJSON _ = empty
+
diff --git a/src/Sproxy/Application/State.hs b/src/Sproxy/Application/State.hs
new file mode 100644
index 0000000..29d9252
--- /dev/null
+++ b/src/Sproxy/Application/State.hs
@@ -0,0 +1,30 @@
+module Sproxy.Application.State (
+ decode
+, encode
+) where
+
+import Data.ByteString (ByteString)
+import Data.ByteString.Lazy (fromStrict, toStrict)
+import Data.Digest.Pure.SHA (hmacSha1, bytestringDigest)
+import qualified Data.ByteString.Base64 as Base64
+import qualified Data.Serialize as DS
+
+
+-- FIXME: Compress / decompress ?
+
+
+encode :: ByteString -> ByteString -> ByteString
+encode key payload = Base64.encode . DS.encode $ (payload, digest key payload)
+
+
+decode :: ByteString -> ByteString -> Either String ByteString
+decode key d = do
+ (payload, dgst) <- DS.decode =<< Base64.decode d
+ if dgst /= digest key payload
+ then Left "junk"
+ else Right payload
+
+
+digest :: ByteString -> ByteString -> ByteString
+digest key payload = toStrict . bytestringDigest $ hmacSha1 (fromStrict key) (fromStrict payload)
+
diff --git a/src/Sproxy/Config.hs b/src/Sproxy/Config.hs
new file mode 100644
index 0000000..30a8bae
--- /dev/null
+++ b/src/Sproxy/Config.hs
@@ -0,0 +1,88 @@
+{-# LANGUAGE OverloadedStrings #-}
+module Sproxy.Config (
+ BackendConf(..)
+, ConfigFile(..)
+, OAuth2Conf(..)
+) where
+
+import Control.Applicative (empty)
+import Data.Aeson (FromJSON, parseJSON)
+import Data.HashMap.Strict (HashMap)
+import Data.Int (Int64)
+import Data.Text (Text)
+import Data.Word (Word16)
+import Data.Yaml (Value(Object), (.:), (.:?), (.!=))
+
+import Sproxy.Logging (LogLevel(Debug))
+
+data ConfigFile = ConfigFile {
+ cfListen :: Word16
+, cfUser :: String
+, cfHome :: FilePath
+, cfLogLevel :: LogLevel
+, cfSslCert :: FilePath
+, cfSslKey :: FilePath
+, cfSslCertChain :: [FilePath]
+, cfKey :: Maybe FilePath
+, cfListen80 :: Maybe Bool
+, cfBackends :: [BackendConf]
+, cfOAuth2 :: HashMap Text OAuth2Conf
+, cfDatabase :: Maybe String
+, cfPgPassFile :: Maybe FilePath
+, cfHTTP2 :: Bool
+} deriving (Show)
+
+instance FromJSON ConfigFile where
+ parseJSON (Object m) = ConfigFile <$>
+ m .:? "listen" .!= 443
+ <*> m .:? "user" .!= "sproxy"
+ <*> m .:? "home" .!= "."
+ <*> m .:? "log_level" .!= Debug
+ <*> m .: "ssl_cert"
+ <*> m .: "ssl_key"
+ <*> m .:? "ssl_cert_chain" .!= []
+ <*> m .:? "key"
+ <*> m .:? "listen80"
+ <*> m .: "backends"
+ <*> m .: "oauth2"
+ <*> m .:? "database"
+ <*> m .:? "pgpassfile"
+ <*> m .:? "http2" .!= True
+ parseJSON _ = empty
+
+
+data BackendConf = BackendConf {
+ beName :: String
+, beAddress :: String
+, bePort :: Maybe Word16
+, beSocket :: Maybe FilePath
+, beCookieName :: String
+, beCookieDomain :: Maybe String
+, beCookieMaxAge :: Int64
+, beConnCount :: Int
+} deriving (Show)
+
+instance FromJSON BackendConf where
+ parseJSON (Object m) = BackendConf <$>
+ m .:? "name" .!= "*"
+ <*> m .:? "address" .!= "127.0.0.1"
+ <*> m .:? "port"
+ <*> m .:? "socket"
+ <*> m .:? "cookie_name" .!= "sproxy"
+ <*> m .:? "cookie_domain"
+ <*> m .:? "cookie_max_age" .!= (7 * 24 * 60 * 60)
+ <*> m .:? "conn_count" .!= 32
+ parseJSON _ = empty
+
+
+data OAuth2Conf = OAuth2Conf {
+ oa2ClientId :: String
+, oa2ClientSecret :: FilePath
+} deriving (Show)
+
+instance FromJSON OAuth2Conf where
+ parseJSON (Object m) = OAuth2Conf <$>
+ m .: "client_id"
+ <*> m .: "client_secret"
+ parseJSON _ = empty
+
diff --git a/src/Sproxy/Logging.hs b/src/Sproxy/Logging.hs
new file mode 100644
index 0000000..651a73a
--- /dev/null
+++ b/src/Sproxy/Logging.hs
@@ -0,0 +1,99 @@
+module Sproxy.Logging (
+ LogLevel(..)
+, debug
+, error
+, info
+, level
+, start
+, warn
+) where
+
+import Prelude hiding (error)
+
+import Control.Applicative (empty)
+import Control.Concurrent (forkIO)
+import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan)
+import Control.Monad (forever, when)
+import Data.Aeson (FromJSON, ToJSON)
+import Data.Char (toLower)
+import Data.IORef (IORef, newIORef, readIORef, writeIORef)
+import System.IO (hPrint, stderr)
+import System.IO.Unsafe (unsafePerformIO)
+import Text.Read (readMaybe)
+import qualified Data.Aeson as JSON
+import qualified Data.Text as T
+
+start :: LogLevel -> IO ()
+start None = return ()
+start lvl = do
+ writeIORef logLevel lvl
+ ch <- readIORef chanRef
+ _ <- forkIO . forever $ readChan ch >>= hPrint stderr
+ return ()
+
+info :: String -> IO ()
+info = send . Message Info
+
+warn:: String -> IO ()
+warn = send . Message Warning
+
+error:: String -> IO ()
+error = send . Message Error
+
+debug :: String -> IO ()
+debug = send . Message Debug
+
+
+send :: Message -> IO ()
+send msg@(Message l _) = do
+ lvl <- level
+ when (l <= lvl) $ do
+ ch <- readIORef chanRef
+ writeChan ch msg
+
+{-# NOINLINE chanRef #-}
+chanRef :: IORef (Chan Message)
+chanRef = unsafePerformIO (newChan >>= newIORef)
+
+{-# NOINLINE logLevel #-}
+logLevel :: IORef LogLevel
+logLevel = unsafePerformIO (newIORef None)
+
+level :: IO LogLevel
+level = readIORef logLevel
+
+
+data LogLevel = None | Error | Warning | Info | Debug
+ deriving (Enum, Ord, Eq)
+
+instance Show LogLevel where
+ show None = "NONE"
+ show Error = "ERROR"
+ show Warning = "WARN"
+ show Info = "INFO"
+ show Debug = "DEBUG"
+
+instance Read LogLevel where
+ readsPrec _ s
+ | l == "none" = [ (None, "") ]
+ | l == "error" = [ (Error, "") ]
+ | l == "warn" = [ (Warning, "") ]
+ | l == "info" = [ (Info, "") ]
+ | l == "debug" = [ (Debug, "") ]
+ | otherwise = [ ]
+ where l = map toLower s
+
+instance ToJSON LogLevel where
+ toJSON = JSON.String . T.pack . show
+
+instance FromJSON LogLevel where
+ parseJSON (JSON.String s) =
+ maybe (fail $ "unknown log level: " ++ show s) return (readMaybe . T.unpack $ s)
+ parseJSON _ = empty
+
+
+data Message = Message LogLevel String
+
+instance Show Message where
+ show (Message lvl str) = show lvl ++ ": " ++ str
+
diff --git a/src/Sproxy/Server.hs b/src/Sproxy/Server.hs
new file mode 100644
index 0000000..bd2af17
--- /dev/null
+++ b/src/Sproxy/Server.hs
@@ -0,0 +1,190 @@
+module Sproxy.Server (
+ server
+) where
+
+import Control.Concurrent (forkIO)
+import Control.Exception (bracketOnError)
+import Control.Monad (void, when)
+import Data.ByteString as BS (hGetLine, readFile)
+import Data.ByteString.Char8 (pack)
+import Data.HashMap.Strict as HM (fromList, lookup, toList)
+import Data.Maybe (fromMaybe)
+import Data.Text (Text)
+import Data.Word (Word16)
+import Data.Yaml (decodeFileEither)
+import Network.HTTP.Client (Manager, ManagerSettings(..), defaultManagerSettings, newManager, socketConnection)
+import Network.HTTP.Client.Internal (Connection)
+import Network.Socket ( Family(AF_INET, AF_UNIX), SockAddr(SockAddrInet, SockAddrUnix),
+ SocketOption(ReuseAddr), SocketType(Stream), bind, close, connect, inet_addr,
+ listen, maxListenQueue, setSocketOption, socket )
+import Network.Wai.Handler.WarpTLS (tlsSettingsChain, runTLSSocket)
+import Network.Wai.Handler.Warp (defaultSettings, setHTTP2Disabled, runSettingsSocket)
+import System.Entropy (getEntropy)
+import System.Environment (setEnv)
+import System.Exit (exitFailure)
+import System.FilePath.Glob (compile)
+import System.IO (IOMode(ReadMode), hIsEOF, hPutStrLn, stderr, withFile)
+import System.Posix.User ( GroupEntry(..), UserEntry(..),
+ getAllGroupEntries, getRealUserID,
+ getUserEntryForName, setGroupID, setGroups, setUserID )
+
+import Sproxy.Application (sproxy, redirect)
+import Sproxy.Application.OAuth2.Common (OAuth2Client)
+import Sproxy.Config (BackendConf(..), ConfigFile(..), OAuth2Conf(..))
+import qualified Sproxy.Application.OAuth2 as OAuth2
+import qualified Sproxy.Logging as Log
+import qualified Sproxy.Server.DB as DB
+
+
+server :: FilePath -> IO ()
+server configFile = do
+ cf <- readConfigFile configFile
+ Log.start $ cfLogLevel cf
+ Log.debug $ show cf
+
+ sock <- socket AF_INET Stream 0
+ setSocketOption sock ReuseAddr 1
+ bind sock $ SockAddrInet (fromIntegral $ cfListen cf) 0
+
+ maybe80 <- if fromMaybe (443 == cfListen cf) (cfListen80 cf)
+ then do
+ sock80 <- socket AF_INET Stream 0
+ setSocketOption sock80 ReuseAddr 1
+ bind sock80 $ SockAddrInet 80 0
+ return (Just sock80)
+ else
+ return Nothing
+
+ uid <- getRealUserID
+ when (0 == uid) $ do
+ let user = cfUser cf
+ Log.info $ "switching to user " ++ show user
+ u <- getUserEntryForName user
+ groupIDs <- map groupID . filter (elem user . groupMembers)
+ <$> getAllGroupEntries
+ setGroups groupIDs
+ setGroupID $ userGroupID u
+ setUserID $ userID u
+
+ case cfPgPassFile cf of
+ Nothing -> return ()
+ Just f -> do
+ Log.info $ "pgpassfile: " ++ show f
+ setEnv "PGPASSFILE" f
+
+ db <- DB.start (cfHome cf) (newDataSource cf)
+
+ key <- maybe
+ (Log.info "using new random key" >> getEntropy 32)
+ (\f -> Log.info ("reading key from " ++ f) >> BS.readFile f)
+ (cfKey cf)
+
+ case maybe80 of
+ Nothing -> return ()
+ Just sock80 -> do
+ Log.info "listening on port 80 (HTTP redirect)"
+ listen sock80 maxListenQueue
+ void . forkIO $ runSettingsSocket defaultSettings sock80 (redirect $ cfListen cf)
+
+ oauth2clients <- HM.fromList <$> mapM newOAuth2Client (HM.toList (cfOAuth2 cf))
+
+ backends <-
+ mapM (\be -> do
+ m <- newBackendManager be
+ return (compile $ beName be, be, m)
+ ) $ cfBackends cf
+
+ let
+ settings =
+ (if cfHTTP2 cf then id else setHTTP2Disabled)
+ defaultSettings
+
+ -- XXX 2048 is from bindPortTCP from streaming-commons used internally by runTLS.
+ -- XXX Since we don't call runTLS, we listen socket here with the same options.
+ Log.info $ "listening on port " ++ show (cfListen cf) ++ " (HTTPS)"
+ listen sock (max 2048 maxListenQueue)
+ runTLSSocket
+ (tlsSettingsChain (cfSslCert cf) (cfSslCertChain cf) (cfSslKey cf))
+ settings
+ sock
+ (sproxy key db oauth2clients backends)
+
+
+newDataSource :: ConfigFile -> Maybe DB.DataSource
+newDataSource cf =
+ case cfDatabase cf of
+ Just str -> Just $ DB.PostgreSQL str
+ Nothing -> Nothing
+
+
+newOAuth2Client :: (Text, OAuth2Conf) -> IO (Text, OAuth2Client)
+newOAuth2Client (name, cfg) =
+ case HM.lookup name OAuth2.providers of
+ Nothing -> do Log.error $ "OAuth2 provider " ++ show name ++ " is not supported"
+ exitFailure
+ Just provider -> do
+ Log.info $ "oauth2: adding " ++ show name
+ client_secret <- withFile secret_file ReadMode $ \h -> do
+ empty <- hIsEOF h
+ if empty then do
+ Log.error $ "oauth2: empty secret file for "
+ ++ show name ++ ": " ++ show secret_file
+ return $ pack ""
+ else BS.hGetLine h
+ return (name, provider (pack client_id, client_secret))
+ where client_id = oa2ClientId cfg
+ secret_file = oa2ClientSecret cfg
+
+
+newBackendManager :: BackendConf -> IO Manager
+newBackendManager be = do
+ openConn <-
+ case (beSocket be, bePort be) of
+ (Just f, Nothing) -> do
+ Log.info $ "backend `" ++ beName be ++ "' on UNIX socket " ++ f
+ return $ openUnixSocketConnection f
+
+ (Nothing, Just n) -> do
+ Log.info $ "backend `" ++ beName be ++ "' on " ++ beAddress be ++ ":" ++ show n
+ return $ openTCPConnection (beAddress be) n
+
+ _ -> do
+ Log.error "either backend port number or UNIX socket path is required."
+ exitFailure
+
+ newManager defaultManagerSettings {
+ managerRawConnection = return $ \_ _ _ -> openConn
+ , managerConnCount = beConnCount be
+ }
+
+
+openUnixSocketConnection :: FilePath -> IO Connection
+openUnixSocketConnection f =
+ bracketOnError
+ (socket AF_UNIX Stream 0)
+ close
+ (\s -> do
+ connect s (SockAddrUnix f)
+ socketConnection s 8192)
+
+
+openTCPConnection :: String -> Word16 -> IO Connection
+openTCPConnection addr port =
+ bracketOnError
+ (socket AF_INET Stream 0)
+ close
+ (\s -> do
+ a <- inet_addr addr
+ connect s (SockAddrInet (fromIntegral port) a)
+ socketConnection s 8192)
+
+
+readConfigFile :: FilePath -> IO ConfigFile
+readConfigFile f = do
+ r <- decodeFileEither f
+ case r of
+ Left e -> do
+ hPutStrLn stderr $ "FATAL: " ++ f ++ ": " ++ show e
+ exitFailure
+ Right cf -> return cf
+
diff --git a/src/Sproxy/Server/DB.hs b/src/Sproxy/Server/DB.hs
new file mode 100644
index 0000000..b760afc
--- /dev/null
+++ b/src/Sproxy/Server/DB.hs
@@ -0,0 +1,189 @@
+{-# LANGUAGE OverloadedStrings #-}
+{-# LANGUAGE QuasiQuotes #-}
+module Sproxy.Server.DB (
+ Database
+, DataSource(..)
+, userExists
+, userGroups
+, start
+) where
+
+import Control.Concurrent (forkIO, threadDelay)
+import Control.Exception (SomeException, bracket, catch, finally)
+import Control.Monad (forever, void)
+import Data.ByteString (ByteString)
+import Data.ByteString.Char8 (pack)
+import Data.Pool (Pool, createPool, withResource)
+import Data.Text (Text, toLower, unpack)
+import Data.Text.Encoding (encodeUtf8)
+import Database.SQLite.Simple (NamedParam((:=)))
+import Text.InterpolatedString.Perl6 (q, qc)
+import qualified Database.PostgreSQL.Simple as PG
+import qualified Database.SQLite.Simple as SQLite
+
+import qualified Sproxy.Logging as Log
+
+
+type Database = Pool SQLite.Connection
+
+data DataSource = PostgreSQL String -- | File FilePath
+
+{- TODO:
+ - Hash remote tables and update the local only when the remote change
+ - Switch to REGEX
+ - Generalize sync procedures for different tables
+ -}
+
+start :: FilePath -> Maybe DataSource -> IO Database
+start home ds = do
+ Log.info $ "home directory: " ++ show home
+ db <- createPool
+ (do c <- SQLite.open $ home ++ "/sproxy.sqlite3"
+ lvl <- Log.level
+ SQLite.setTrace c (if lvl == Log.Debug then Just $ Log.debug . unpack else Nothing)
+ return c)
+ SQLite.close
+ 1 -- stripes
+ 3600 -- keep alive (seconds). FIXME: no much sense as it's a local file
+ 128 -- max connections. FIXME: make configurable?
+
+ withResource db $ \c -> SQLite.execute_ c "PRAGMA journal_mode=WAL"
+ populate db ds
+ return db
+
+
+userExists :: Database -> String -> IO Bool
+userExists db email = do
+ r <- withResource db $ \c -> fmap SQLite.fromOnly <$> SQLite.queryNamed c
+ "SELECT EXISTS (SELECT 1 FROM group_member WHERE :email LIKE email LIMIT 1)"
+ [ ":email" := email ]
+ return $ head r
+
+
+userGroups :: Database -> String -> Text -> Text -> Text -> IO [ByteString]
+userGroups db email domain path method =
+ withResource db $ \c -> fmap (encodeUtf8 . SQLite.fromOnly) <$> SQLite.queryNamed c [q|
+ SELECT gm."group" FROM group_privilege gp JOIN group_member gm ON gm."group" = gp."group"
+ WHERE :email LIKE gm.email
+ AND :domain LIKE gp.domain
+ AND gp.privilege IN (
+ SELECT privilege FROM privilege_rule
+ WHERE :domain LIKE domain
+ AND :path LIKE path
+ AND method = :method
+ ORDER BY length(path) - length(replace(path, '/', '')) DESC LIMIT 1
+ )
+ |] [ ":email" := email -- XXX always in lower case
+ , ":domain" := toLower domain
+ , ":path" := path
+ , ":method" := method -- XXX case-sensitive by RFC2616
+ ]
+
+
+populate :: Database -> Maybe DataSource -> IO ()
+
+populate db Nothing = do
+ Log.warn "db: no data source defined"
+ withResource db $ \c -> SQLite.withTransaction c $ do
+ createGroupMember c
+ createGroupPrivilege c
+ createPrivilegeRule c
+
+-- XXX We keep only required minimum of the data, without any integrity check.
+-- XXX Integrity check should be done somewhere else, e. g. in the master PostgreSQL database,
+-- XXX or during importing the config file.
+populate db (Just (PostgreSQL connstr)) =
+ void . forkIO . forever . flip finally (7 `minutes` threadDelay)
+ . logException $ do
+ Log.info $ "db: synchronizing with " ++ show connstr
+ withResource db $ \c -> SQLite.withTransaction c $
+ bracket (PG.connectPostgreSQL $ pack connstr) PG.close $
+ \pg -> PG.withTransaction pg $ do
+
+ Log.info "db: syncing group_member"
+ dropGroupMember c
+ createGroupMember c
+ PG.forEach_ pg
+ [q|SELECT "group", lower(email) FROM group_member|] $ \r ->
+ SQLite.execute c
+ [q|INSERT INTO group_member("group", email) VALUES (?, ?)|]
+ (r :: (Text, Text))
+ count c "group_member"
+
+ Log.info "db: syncing group_privilege"
+ dropGroupPrivilege c
+ createGroupPrivilege c
+ PG.forEach_ pg
+ [q|SELECT "group", lower(domain), privilege FROM group_privilege|] $ \r ->
+ SQLite.execute c
+ [q|INSERT INTO group_privilege("group", domain, privilege) VALUES (?, ?, ?)|]
+ (r :: (Text, Text, Text))
+ count c "group_privilege"
+
+ Log.info "db: syncing privilege_rule"
+ dropPrivilegeRule c
+ createPrivilegeRule c
+ PG.forEach_ pg
+ [q|SELECT lower(domain), privilege, path, method FROM privilege_rule|] $ \r ->
+ SQLite.execute c
+ [q|INSERT INTO privilege_rule(domain, privilege, path, method) VALUES (?, ?, ?, ?)|]
+ (r :: (Text, Text, Text, Text))
+ count c "privilege_rule"
+
+
+dropGroupMember :: SQLite.Connection -> IO ()
+dropGroupMember c = SQLite.execute_ c "DROP TABLE IF EXISTS group_member"
+
+createGroupMember :: SQLite.Connection -> IO ()
+createGroupMember c = SQLite.execute_ c [q|
+ CREATE TABLE IF NOT EXISTS group_member (
+ "group" TEXT,
+ email TEXT,
+ PRIMARY KEY ("group", email)
+ )
+|]
+
+
+dropGroupPrivilege :: SQLite.Connection -> IO ()
+dropGroupPrivilege c = SQLite.execute_ c "DROP TABLE IF EXISTS group_privilege"
+
+createGroupPrivilege :: SQLite.Connection -> IO ()
+createGroupPrivilege c = SQLite.execute_ c [q|
+ CREATE TABLE IF NOT EXISTS group_privilege (
+ "group" TEXT,
+ domain TEXT,
+ privilege TEXT,
+ PRIMARY KEY ("group", domain, privilege)
+ )
+|]
+
+
+dropPrivilegeRule :: SQLite.Connection -> IO ()
+dropPrivilegeRule c = SQLite.execute_ c "DROP TABLE IF EXISTS privilege_rule"
+
+createPrivilegeRule :: SQLite.Connection -> IO ()
+createPrivilegeRule c = SQLite.execute_ c [q|
+ CREATE TABLE IF NOT EXISTS privilege_rule (
+ domain TEXT,
+ privilege TEXT,
+ path TEXT,
+ method TEXT,
+ PRIMARY KEY (domain, path, method)
+ )
+|]
+
+
+count :: SQLite.Connection -> String -> IO ()
+count c table = do
+ r <- fmap SQLite.fromOnly <$> SQLite.query_ c [qc|SELECT COUNT(*) FROM {table}|]
+ Log.info $ "db: " ++ table ++ " rows: " ++ show (head r :: Integer)
+
+
+logException :: IO () -> IO ()
+logException a = catch a $ \e ->
+ Log.error $ "db: " ++ show (e :: SomeException)
+
+
+minutes :: Int -> (Int -> IO ()) -> IO ()
+minutes us f = f $ us * 60 * 1000000
+