diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Main.hs | 25 | ||||
-rw-r--r-- | src/Sproxy/Application.hs | 488 | ||||
-rw-r--r-- | src/Sproxy/Application/Access.hs | 18 | ||||
-rw-r--r-- | src/Sproxy/Application/Cookie.hs | 72 | ||||
-rw-r--r-- | src/Sproxy/Application/OAuth2/Common.hs | 45 | ||||
-rw-r--r-- | src/Sproxy/Application/State.hs | 25 | ||||
-rw-r--r-- | src/Sproxy/Config.hs | 139 | ||||
-rw-r--r-- | src/Sproxy/Logging.hs | 69 | ||||
-rw-r--r-- | src/Sproxy/Server.hs | 181 | ||||
-rw-r--r-- | src/Sproxy/Server/DB.hs | 213 | ||||
-rw-r--r-- | src/Sproxy/Server/DB/DataFile.hs | 83 |
11 files changed, 708 insertions, 650 deletions
diff --git a/src/Main.hs b/src/Main.hs index 7101af0..20e3ebd 100644 --- a/src/Main.hs +++ b/src/Main.hs @@ -1,20 +1,24 @@ {-# LANGUAGE QuasiQuotes #-} -module Main ( - main -) where + +module Main + ( main + ) where import Data.Maybe (fromJust) import Data.Version (showVersion) import Paths_sproxy2 (version) -- from cabal +import qualified System.Console.Docopt.NoTH as O import System.Environment (getArgs) import Text.InterpolatedString.Perl6 (qc) -import qualified System.Console.Docopt.NoTH as O import Sproxy.Server (server) usage :: String -usage = "sproxy2 " ++ showVersion version ++ - " - HTTP proxy for authenticating users via OAuth2" ++ [qc| +usage = + "sproxy2 " ++ + showVersion version ++ + " - HTTP proxy for authenticating users via OAuth2" ++ + [qc| Usage: sproxy2 [options] @@ -30,8 +34,7 @@ main = do doco <- O.parseUsageOrExit usage args <- O.parseArgsOrExit doco =<< getArgs if args `O.isPresent` O.longOption "help" - then putStrLn $ O.usage doco - else do - let configFile = fromJust . O.getArg args $ O.longOption "config" - server configFile - + then putStrLn $ O.usage doco + else do + let configFile = fromJust . O.getArg args $ O.longOption "config" + server configFile diff --git a/src/Sproxy/Application.hs b/src/Sproxy/Application.hs index 791e59c..3311f05 100644 --- a/src/Sproxy/Application.hs +++ b/src/Sproxy/Application.hs @@ -1,14 +1,17 @@ {-# LANGUAGE LambdaCase #-} {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} -module Sproxy.Application ( - sproxy -, redirect -) where + +module Sproxy.Application + ( sproxy + , redirect + ) where import Blaze.ByteString.Builder (toByteString) import Blaze.ByteString.Builder.ByteString (fromByteString) -import Control.Exception (Exception, Handler(..), SomeException, catches, displayException) +import Control.Exception + (Exception, Handler(..), SomeException, catches, displayException) +import qualified Data.Aeson as JSON import Data.ByteString (ByteString) import Data.ByteString as BS (break, intercalate) import Data.ByteString.Char8 (pack, unpack) @@ -16,7 +19,8 @@ import Data.ByteString.Lazy (fromStrict) import Data.Conduit (Flush(Chunk), mapOutput) import Data.HashMap.Strict as HM (HashMap, foldrWithKey, lookup) import Data.List (find, partition) -import Data.Map as Map (delete, fromListWith, insert, insertWith, toList) +import Data.Map as Map + (delete, fromListWith, insert, insertWith, toList) import Data.Maybe (fromJust, fromMaybe) import Data.Monoid ((<>)) import Data.Text (Text) @@ -25,34 +29,38 @@ import Data.Time.Clock.POSIX (posixSecondsToUTCTime) import Data.Word (Word16) import Data.Word8 (_colon) import Foreign.C.Types (CTime(..)) +import qualified Network.HTTP.Client as BE import Network.HTTP.Client.Conduit (bodyReaderSource) -import Network.HTTP.Conduit (requestBodySourceChunkedIO, requestBodySourceIO) -import Network.HTTP.Types (RequestHeaders, ResponseHeaders, methodGet, methodPost) -import Network.HTTP.Types.Header ( hConnection, - hContentLength, hContentType, hCookie, hLocation, hTransferEncoding ) -import Network.HTTP.Types.Status ( Status(..), badGateway502, badRequest400, forbidden403, - found302, internalServerError500, methodNotAllowed405, movedPermanently301, - networkAuthenticationRequired511, notFound404, ok200, seeOther303, temporaryRedirect307 ) +import Network.HTTP.Conduit + (requestBodySourceChunkedIO, requestBodySourceIO) +import Network.HTTP.Types + (RequestHeaders, ResponseHeaders, methodGet, methodPost) +import Network.HTTP.Types.Header + (hConnection, hContentLength, hContentType, hCookie, hLocation, + hTransferEncoding) +import Network.HTTP.Types.Status + (Status(..), badGateway502, badRequest400, forbidden403, found302, + internalServerError500, methodNotAllowed405, movedPermanently301, + networkAuthenticationRequired511, notFound404, ok200, seeOther303, + temporaryRedirect307) import Network.Socket (NameInfoFlag(NI_NUMERICHOST), getNameInfo) -import Network.Wai.Conduit (sourceRequestBody, responseSource) +import qualified Network.Wai as W +import Network.Wai.Conduit (responseSource, sourceRequestBody) import System.FilePath.Glob (Pattern, match) import System.Posix.Time (epochTime) import Text.InterpolatedString.Perl6 (qc) import Web.Cookie (Cookies, parseCookies, renderCookies) -import qualified Data.Aeson as JSON -import qualified Network.HTTP.Client as BE -import qualified Network.Wai as W import qualified Web.Cookie as WC -import Sproxy.Application.Cookie ( AuthCookie(..), AuthUser, - cookieDecode, cookieEncode, getEmail, getEmailUtf8, getFamilyNameUtf8, - getGivenNameUtf8 ) +import Sproxy.Application.Cookie + (AuthCookie(..), AuthUser, cookieDecode, cookieEncode, getEmail, + getEmailUtf8, getFamilyNameUtf8, getGivenNameUtf8) import Sproxy.Application.OAuth2.Common (OAuth2Client(..)) -import Sproxy.Config(BackendConf(..)) -import Sproxy.Server.DB (Database, userAccess, userExists, userGroups) import qualified Sproxy.Application.State as State +import Sproxy.Config (BackendConf(..)) import qualified Sproxy.Logging as Log - +import Sproxy.Server.DB + (Database, userAccess, userExists, userGroups) redirect :: Word16 -> W.Application redirect p req resp = @@ -61,151 +69,182 @@ redirect p req resp = Just domain -> do Log.info $ "redirecting to " ++ show location ++ ": " ++ showReq req resp $ W.responseBuilder status [(hLocation, location)] mempty - where - status = if W.requestMethod req == methodGet then movedPermanently301 else temporaryRedirect307 - newhost = if p == 443 then domain else domain <> ":" <> pack (show p) - location = "https://" <> newhost <> W.rawPathInfo req <> W.rawQueryString req - - -sproxy :: ByteString -> Database -> HashMap Text OAuth2Client -> [(Pattern, BackendConf, BE.Manager)] -> W.Application -sproxy key db oa2 backends = logException $ \req resp -> do - Log.debug $ "sproxy <<< " ++ showReq req - case requestDomain req of - Nothing -> badRequest "missing host" req resp - Just domain -> - case find (\(p, _, _) -> match p (unpack domain)) backends of - Nothing -> notFound "backend" req resp - Just (_, be, mgr) -> do - let cookieName = pack $ beCookieName be - cookieDomain = pack <$> beCookieDomain be - case W.pathInfo req of - ["robots.txt"] -> get robots req resp - (".sproxy":proxy) -> - case proxy of - - ["logout"] -> get (logout key cookieName cookieDomain) req resp - - ["oauth2", provider] -> + where status = + if W.requestMethod req == methodGet + then movedPermanently301 + else temporaryRedirect307 + newhost = + if p == 443 + then domain + else domain <> ":" <> pack (show p) + location = + "https://" <> newhost <> W.rawPathInfo req <> W.rawQueryString req + +sproxy :: + ByteString + -> Database + -> HashMap Text OAuth2Client + -> [(Pattern, BackendConf, BE.Manager)] + -> W.Application +sproxy key db oa2 backends = + logException $ \req resp -> do + Log.debug $ "sproxy <<< " ++ showReq req + case requestDomain req of + Nothing -> badRequest "missing host" req resp + Just domain -> + case find (\(p, _, _) -> match p (unpack domain)) backends of + Nothing -> notFound "backend" req resp + Just (_, be, mgr) -> do + let cookieName = pack $ beCookieName be + cookieDomain = pack <$> beCookieDomain be + case W.pathInfo req of + ["robots.txt"] -> get robots req resp + (".sproxy":proxy) -> + case proxy of + ["logout"] -> + get (logout key cookieName cookieDomain) req resp + ["oauth2", provider] -> case HM.lookup provider oa2 of Nothing -> notFound "OAuth2 provider" req resp - Just oa2c -> get (oauth2callback key db (provider, oa2c) be) req resp - - ["access"] -> do - now <- Just <$> epochTime - case extractCookie key now cookieName req of - Nothing -> authenticationRequired key oa2 req resp - Just (authCookie, _) -> post (checkAccess db authCookie) req resp - - _ -> notFound "proxy" req resp - - _ -> do - now <- Just <$> epochTime - case extractCookie key now cookieName req of - Nothing -> authenticationRequired key oa2 req resp - Just cs@(authCookie, _) -> - authorize db cs req >>= \case - Nothing -> forbidden authCookie req resp - Just req' -> forward mgr req' resp - + Just oa2c -> + get (oauth2callback key db (provider, oa2c) be) req resp + ["access"] -> do + now <- Just <$> epochTime + case extractCookie key now cookieName req of + Nothing -> authenticationRequired key oa2 req resp + Just (authCookie, _) -> + post (checkAccess db authCookie) req resp + _ -> notFound "proxy" req resp + _ -> do + now <- Just <$> epochTime + case extractCookie key now cookieName req of + Nothing -> authenticationRequired key oa2 req resp + Just cs@(authCookie, _) -> + authorize db cs req >>= \case + Nothing -> forbidden authCookie req resp + Just req' -> forward mgr req' resp robots :: W.Application -robots _ resp = resp $ - W.responseLBS ok200 [(hContentType, "text/plain; charset=utf-8")] - "User-agent: *\nDisallow: /" - - -oauth2callback :: ByteString -> Database -> (Text, OAuth2Client) -> BackendConf -> W.Application +robots _ resp = + resp $ + W.responseLBS + ok200 + [(hContentType, "text/plain; charset=utf-8")] + "User-agent: *\nDisallow: /" + +oauth2callback :: + ByteString + -> Database + -> (Text, OAuth2Client) + -> BackendConf + -> W.Application oauth2callback key db (provider, oa2c) be req resp = case param "code" of - Nothing -> badRequest "missing auth code" req resp - Just code -> + Nothing -> badRequest "missing auth code" req resp + Just code -> case param "state" of - Nothing -> badRequest "missing auth state" req resp + Nothing -> badRequest "missing auth state" req resp Just state -> case State.decode key state of - Left msg -> badRequest ("invalid state: " ++ msg) req resp + Left msg -> badRequest ("invalid state: " ++ msg) req resp Right url -> do au <- oauth2Authenticate oa2c code (redirectURL req provider) let email = getEmail au Log.info $ "login " ++ show email ++ " by " ++ show provider exists <- userExists db email - if exists then authenticate key be au url req resp - else userNotFound au req resp + if exists + then authenticate key be au url req resp + else userNotFound au req resp where param p = do (_, v) <- find ((==) p . fst) $ W.queryString req v - -- XXX: RFC6265: the user agent MUST NOT attach more than one Cookie header field -extractCookie :: ByteString -> Maybe CTime -> ByteString -> W.Request -> Maybe (AuthCookie, Cookies) +extractCookie :: + ByteString + -> Maybe CTime + -> ByteString + -> W.Request + -> Maybe (AuthCookie, Cookies) extractCookie key now name req = do - (_, cookies) <- find ((==) hCookie . fst) $ W.requestHeaders req + (_, cookies) <- find ((==) hCookie . fst) $ W.requestHeaders req (auth, others) <- discriminate cookies case cookieDecode key auth of Left _ -> Nothing - Right cookie -> if maybe True (acExpiry cookie >) now - then Just (cookie, others) else Nothing - where discriminate cs = - case partition ((==) name . fst) $ parseCookies cs of - ((_, x):_, xs) -> Just (x, xs) - _ -> Nothing - + Right cookie -> + if maybe True (acExpiry cookie >) now + then Just (cookie, others) + else Nothing + where + discriminate cs = + case partition ((==) name . fst) $ parseCookies cs of + ((_, x):_, xs) -> Just (x, xs) + _ -> Nothing -authenticate :: ByteString -> BackendConf -> AuthUser -> ByteString -> W.Application +authenticate :: + ByteString -> BackendConf -> AuthUser -> ByteString -> W.Application authenticate key be user url _req resp = do now <- epochTime let domain = pack <$> beCookieDomain be expiry = now + CTime (beCookieMaxAge be) - authCookie = AuthCookie { acUser = user, acExpiry = expiry } - cookie = WC.def { - WC.setCookieName = pack $ beCookieName be - , WC.setCookieHttpOnly = True - , WC.setCookiePath = Just "/" - , WC.setCookieSameSite = Nothing - , WC.setCookieSecure = True - , WC.setCookieValue = cookieEncode key authCookie - , WC.setCookieDomain = domain - , WC.setCookieExpires = Just . posixSecondsToUTCTime . realToFrac $ expiry - } - resp $ W.responseLBS seeOther303 [ - (hLocation, url) - , ("Set-Cookie", toByteString $ WC.renderSetCookie cookie) - ] "" - + authCookie = AuthCookie {acUser = user, acExpiry = expiry} + cookie = + WC.def + { WC.setCookieName = pack $ beCookieName be + , WC.setCookieHttpOnly = True + , WC.setCookiePath = Just "/" + , WC.setCookieSameSite = Nothing + , WC.setCookieSecure = True + , WC.setCookieValue = cookieEncode key authCookie + , WC.setCookieDomain = domain + , WC.setCookieExpires = + Just . posixSecondsToUTCTime . realToFrac $ expiry + } + resp $ + W.responseLBS + seeOther303 + [ (hLocation, url) + , ("Set-Cookie", toByteString $ WC.renderSetCookie cookie) + ] + "" -authorize :: Database -> (AuthCookie, Cookies) -> W.Request -> IO (Maybe W.Request) +authorize :: + Database -> (AuthCookie, Cookies) -> W.Request -> IO (Maybe W.Request) authorize db (authCookie, otherCookies) req = do - let - user = acUser authCookie - domain = decodeUtf8 . fromJust $ requestDomain req - email = getEmail user - emailUtf8 = getEmailUtf8 user - familyUtf8 = getFamilyNameUtf8 user - givenUtf8 = getGivenNameUtf8 user - method = decodeUtf8 $ W.requestMethod req - path = decodeUtf8 $ W.rawPathInfo req + let user = acUser authCookie + domain = decodeUtf8 . fromJust $ requestDomain req + email = getEmail user + emailUtf8 = getEmailUtf8 user + familyUtf8 = getFamilyNameUtf8 user + givenUtf8 = getGivenNameUtf8 user + method = decodeUtf8 $ W.requestMethod req + path = decodeUtf8 $ W.rawPathInfo req grps <- userGroups db email domain path method - if null grps then return Nothing - else do - ip <- pack . fromJust . fst <$> getNameInfo [NI_NUMERICHOST] True False (W.remoteHost req) - return . Just $ req { - W.requestHeaders = toList $ - insert "From" emailUtf8 $ - insert "X-Groups" (BS.intercalate "," $ encodeUtf8 <$> grps) $ - insert "X-Given-Name" givenUtf8 $ - insert "X-Family-Name" familyUtf8 $ - insert "X-Forwarded-Proto" "https" $ - insertWith (flip combine) "X-Forwarded-For" ip $ - setCookies otherCookies $ - fromListWith combine $ W.requestHeaders req - } + if null grps + then return Nothing + else do + ip <- + pack . fromJust . fst <$> + getNameInfo [NI_NUMERICHOST] True False (W.remoteHost req) + return . Just $ + req + { W.requestHeaders = + toList $ + insert "From" emailUtf8 $ + insert "X-Groups" (BS.intercalate "," $ encodeUtf8 <$> grps) $ + insert "X-Given-Name" givenUtf8 $ + insert "X-Family-Name" familyUtf8 $ + insert "X-Forwarded-Proto" "https" $ + insertWith (flip combine) "X-Forwarded-For" ip $ + setCookies otherCookies $ + fromListWith combine $ W.requestHeaders req + } where combine a b = a <> "," <> b setCookies [] = delete hCookie setCookies cs = insert hCookie (toByteString . renderCookies $ cs) - checkAccess :: Database -> AuthCookie -> W.Application checkAccess db authCookie req resp = do let email = getEmail . acUser $ authCookie @@ -217,77 +256,96 @@ checkAccess db authCookie req resp = do Log.debug $ "access <<< " ++ show inq tags <- userAccess db email domain inq Log.debug $ "access >>> " ++ show tags - resp $ W.responseLBS ok200 [(hContentType, "application/json")] (JSON.encode tags) - + resp $ + W.responseLBS + ok200 + [(hContentType, "application/json")] + (JSON.encode tags) -- XXX If something seems strange, think about HTTP/1.1 <-> HTTP/1.0. -- FIXME For HTTP/1.0 backends we might need an option -- FIXME in config file. HTTP Client does HTTP/1.1 by default. forward :: BE.Manager -> W.Application forward mgr req resp = do - let beReq = BE.defaultRequest + let beReq = + BE.defaultRequest { BE.method = W.requestMethod req , BE.path = W.rawPathInfo req , BE.queryString = W.rawQueryString req , BE.requestHeaders = modifyRequestHeaders $ W.requestHeaders req , BE.redirectCount = 0 , BE.decompress = const False - , BE.requestBody = case W.requestBodyLength req of - W.ChunkedBody -> requestBodySourceChunkedIO (sourceRequestBody req) - W.KnownLength l -> requestBodySourceIO (fromIntegral l) (sourceRequestBody req) + , BE.requestBody = + case W.requestBodyLength req of + W.ChunkedBody -> + requestBodySourceChunkedIO (sourceRequestBody req) + W.KnownLength l -> + requestBodySourceIO (fromIntegral l) (sourceRequestBody req) } - msg = unpack (BE.method beReq <> " " <> BE.path beReq <> BE.queryString beReq) + msg = + unpack (BE.method beReq <> " " <> BE.path beReq <> BE.queryString beReq) Log.debug $ "BACKEND <<< " ++ msg ++ " " ++ show (BE.requestHeaders beReq) BE.withResponse beReq mgr $ \res -> do - let status = BE.responseStatus res - headers = BE.responseHeaders res - body = mapOutput (Chunk . fromByteString) . bodyReaderSource $ BE.responseBody res - logging = if statusCode status `elem` [ 400, 500 ] then - Log.warn else Log.debug - logging $ "BACKEND >>> " ++ show (statusCode status) ++ " on " ++ msg ++ " " ++ show headers ++ "\n" - resp $ responseSource status (modifyResponseHeaders headers) body - + let status = BE.responseStatus res + headers = BE.responseHeaders res + body = + mapOutput (Chunk . fromByteString) . bodyReaderSource $ + BE.responseBody res + logging = + if statusCode status `elem` [400, 500] + then Log.warn + else Log.debug + logging $ + "BACKEND >>> " ++ + show (statusCode status) ++ " on " ++ msg ++ " " ++ show headers ++ "\n" + resp $ responseSource status (modifyResponseHeaders headers) body modifyRequestHeaders :: RequestHeaders -> RequestHeaders modifyRequestHeaders = filter (\(n, _) -> n `notElem` ban) where ban = - [ - hConnection - , hContentLength -- XXX This is set automtically before sending request to backend + [ hConnection + , hContentLength -- XXX This is set automtically before sending request to backend , hTransferEncoding -- XXX Likewise ] - modifyResponseHeaders :: ResponseHeaders -> ResponseHeaders modifyResponseHeaders = filter (\(n, _) -> n `notElem` ban) where ban = - [ - hConnection + [ hConnection -- XXX WAI docs say we MUST NOT add (keep) Content-Length, Content-Range, and Transfer-Encoding, -- XXX but we use streaming body, which may add Transfer-Encoding only. -- XXX Thus we keep Content-* headers. , hTransferEncoding ] - -authenticationRequired :: ByteString -> HashMap Text OAuth2Client -> W.Application +authenticationRequired :: + ByteString -> HashMap Text OAuth2Client -> W.Application authenticationRequired key oa2 req resp = do Log.info $ "511 Unauthenticated: " ++ showReq req - resp $ W.responseLBS networkAuthenticationRequired511 [(hContentType, "text/html; charset=utf-8")] page + resp $ + W.responseLBS + networkAuthenticationRequired511 + [(hContentType, "text/html; charset=utf-8")] + page where - path = if W.requestMethod req == methodGet - then W.rawPathInfo req <> W.rawQueryString req - else "/" - state = State.encode key $ "https://" <> fromJust (W.requestHeaderHost req) <> path + path = + if W.requestMethod req == methodGet + then W.rawPathInfo req <> W.rawQueryString req + else "/" + state = + State.encode key $ + "https://" <> fromJust (W.requestHeaderHost req) <> path authLink :: Text -> OAuth2Client -> ByteString -> ByteString - authLink provider oa2c html = + authLink provider oa2c html = let u = oauth2AuthorizeURL oa2c state (redirectURL req provider) d = pack $ oauth2Description oa2c in [qc|{html}<p><a href="{u}">Authenticate with {d}</a></p>|] authHtml = foldrWithKey authLink "" oa2 - page = fromStrict [qc| + page = + fromStrict + [qc| <!DOCTYPE html> <html lang="en"> <head> @@ -301,14 +359,16 @@ authenticationRequired key oa2 req resp = do </html> |] - forbidden :: AuthCookie -> W.Application forbidden ac req resp = do Log.info $ "403 Forbidden: " ++ show email ++ ": " ++ showReq req - resp $ W.responseLBS forbidden403 [(hContentType, "text/html; charset=utf-8")] page + resp $ + W.responseLBS forbidden403 [(hContentType, "text/html; charset=utf-8")] page where email = getEmailUtf8 . acUser $ ac - page = fromStrict [qc| + page = + fromStrict + [qc| <!DOCTYPE html> <html lang="en"> <head> @@ -323,14 +383,16 @@ forbidden ac req resp = do </html> |] - userNotFound :: AuthUser -> W.Application userNotFound au _ resp = do Log.info $ "404 User not found: " ++ show email - resp $ W.responseLBS notFound404 [(hContentType, "text/html; charset=utf-8")] page + resp $ + W.responseLBS notFound404 [(hContentType, "text/html; charset=utf-8")] page where email = getEmailUtf8 au - page = fromStrict [qc| + page = + fromStrict + [qc| <!DOCTYPE html> <html lang="en"> <head> @@ -345,98 +407,102 @@ userNotFound au _ resp = do </html> |] - logout :: ByteString -> ByteString -> Maybe ByteString -> W.Application logout key cookieName cookieDomain req resp = do let host = fromJust $ W.requestHeaderHost req case extractCookie key Nothing cookieName req of - Nothing -> resp $ W.responseLBS found302 [ (hLocation, "https://" <> host) ] "" - Just _ -> do - let cookie = WC.def { - WC.setCookieName = cookieName + Nothing -> + resp $ W.responseLBS found302 [(hLocation, "https://" <> host)] "" + Just _ -> do + let cookie = + WC.def + { WC.setCookieName = cookieName , WC.setCookieHttpOnly = True , WC.setCookiePath = Just "/" , WC.setCookieSameSite = Just WC.sameSiteStrict , WC.setCookieSecure = True , WC.setCookieValue = "goodbye" , WC.setCookieDomain = cookieDomain - , WC.setCookieExpires = Just . posixSecondsToUTCTime . realToFrac $ CTime 0 + , WC.setCookieExpires = + Just . posixSecondsToUTCTime . realToFrac $ CTime 0 } - resp $ W.responseLBS found302 [ - (hLocation, "https://" <> host) - , ("Set-Cookie", toByteString $ WC.renderSetCookie cookie) - ] "" - - -badRequest ::String -> W.Application + resp $ + W.responseLBS + found302 + [ (hLocation, "https://" <> host) + , ("Set-Cookie", toByteString $ WC.renderSetCookie cookie) + ] + "" + +badRequest :: String -> W.Application badRequest msg req resp = do Log.warn $ "400 Bad Request (" ++ msg ++ "): " ++ showReq req resp $ W.responseLBS badRequest400 [] "Bad Request" - -notFound ::String -> W.Application +notFound :: String -> W.Application notFound msg req resp = do Log.warn $ "404 Not Found (" ++ msg ++ "): " ++ showReq req resp $ W.responseLBS notFound404 [] "Not Found" - logException :: W.Middleware logException app req resp = - catches (app req resp) [ - Handler badGateway, - Handler internalError - ] + catches (app req resp) [Handler badGateway, Handler internalError] where internalError :: SomeException -> IO W.ResponseReceived internalError = response internalServerError500 - badGateway :: BE.HttpException -> IO W.ResponseReceived badGateway = response badGateway502 - response :: Exception e => Status -> e -> IO W.ResponseReceived response st e = do - Log.error $ show (statusCode st) ++ " " ++ unpack (statusMessage st) - ++ ": " ++ displayException e ++ " on " ++ showReq req - resp $ W.responseLBS st [(hContentType, "text/plain")] (fromStrict $ statusMessage st) - - + Log.error $ + show (statusCode st) ++ + " " ++ + unpack (statusMessage st) ++ + ": " ++ displayException e ++ " on " ++ showReq req + resp $ + W.responseLBS + st + [(hContentType, "text/plain")] + (fromStrict $ statusMessage st) get :: W.Middleware get app req resp | W.requestMethod req == methodGet = app req resp | otherwise = do Log.warn $ "405 Method Not Allowed: " ++ showReq req - resp $ W.responseLBS methodNotAllowed405 [("Allow", "GET")] "Method Not Allowed" - + resp $ + W.responseLBS methodNotAllowed405 [("Allow", "GET")] "Method Not Allowed" post :: W.Middleware post app req resp | W.requestMethod req == methodPost = app req resp | otherwise = do Log.warn $ "405 Method Not Allowed: " ++ showReq req - resp $ W.responseLBS methodNotAllowed405 [("Allow", "POST")] "Method Not Allowed" - + resp $ + W.responseLBS methodNotAllowed405 [("Allow", "POST")] "Method Not Allowed" redirectURL :: W.Request -> Text -> ByteString redirectURL req provider = - "https://" <> fromJust (W.requestHeaderHost req) - <> "/.sproxy/oauth2/" <> encodeUtf8 provider - + "https://" <> fromJust (W.requestHeaderHost req) <> "/.sproxy/oauth2/" <> + encodeUtf8 provider requestDomain :: W.Request -> Maybe ByteString requestDomain req = do h <- W.requestHeaderHost req return . fst . BS.break (== _colon) $ h - -- XXX: make sure not to reveal the cookie, which can be valid (!) showReq :: W.Request -> String -showReq req = - unpack ( W.requestMethod req <> " " - <> fromMaybe "<no host>" (W.requestHeaderHost req) - <> W.rawPathInfo req <> W.rawQueryString req <> " " ) - ++ show (W.httpVersion req) ++ " " - ++ show (fromMaybe "-" $ W.requestHeaderReferer req) ++ " " - ++ show (fromMaybe "-" $ W.requestHeaderUserAgent req) - ++ " from " ++ show (W.remoteHost req) - +showReq req = + unpack + (W.requestMethod req <> " " <> + fromMaybe "<no host>" (W.requestHeaderHost req) <> + W.rawPathInfo req <> + W.rawQueryString req <> + " ") ++ + show (W.httpVersion req) ++ + " " ++ + show (fromMaybe "-" $ W.requestHeaderReferer req) ++ + " " ++ + show (fromMaybe "-" $ W.requestHeaderUserAgent req) ++ + " from " ++ show (W.remoteHost req) diff --git a/src/Sproxy/Application/Access.hs b/src/Sproxy/Application/Access.hs index d8984ee..6ba972c 100644 --- a/src/Sproxy/Application/Access.hs +++ b/src/Sproxy/Application/Access.hs @@ -1,23 +1,21 @@ {-# LANGUAGE DeriveGeneric #-} {-# LANGUAGE OverloadedStrings #-} -module Sproxy.Application.Access ( - Inquiry -, Question(..) -) where +module Sproxy.Application.Access + ( Inquiry + , Question(..) + ) where import Data.Aeson (FromJSON) import Data.HashMap.Strict (HashMap) import Data.Text (Text) import GHC.Generics (Generic) - -data Question = Question { - path :: Text -, method :: Text -} deriving (Generic, Show) +data Question = Question + { path :: Text + , method :: Text + } deriving (Generic, Show) instance FromJSON Question type Inquiry = HashMap Text Question - diff --git a/src/Sproxy/Application/Cookie.hs b/src/Sproxy/Application/Cookie.hs index a86f42a..a9a8ad6 100644 --- a/src/Sproxy/Application/Cookie.hs +++ b/src/Sproxy/Application/Cookie.hs @@ -1,56 +1,57 @@ {-# LANGUAGE OverloadedStrings #-} -module Sproxy.Application.Cookie ( - AuthCookie(..) -, AuthUser -, cookieDecode -, cookieEncode -, getEmail -, getEmailUtf8 -, getFamilyNameUtf8 -, getGivenNameUtf8 -, newUser -, setFamilyName -, setGivenName -) where + +module Sproxy.Application.Cookie + ( AuthCookie(..) + , AuthUser + , cookieDecode + , cookieEncode + , getEmail + , getEmailUtf8 + , getFamilyNameUtf8 + , getGivenNameUtf8 + , newUser + , setFamilyName + , setGivenName + ) where import Data.ByteString (ByteString) -import Data.Text (Text, toLower, strip) +import qualified Data.Serialize as DS +import Data.Text (Text, strip, toLower) import Data.Text.Encoding (decodeUtf8, encodeUtf8) import Foreign.C.Types (CTime(..)) -import qualified Data.Serialize as DS import qualified Sproxy.Application.State as State -data AuthUser = AuthUser { - auEmail :: ByteString -, auGivenName :: ByteString -, auFamilyName :: ByteString -} +data AuthUser = AuthUser + { auEmail :: ByteString + , auGivenName :: ByteString + , auFamilyName :: ByteString + } -data AuthCookie = AuthCookie { - acUser :: AuthUser -, acExpiry :: CTime -} +data AuthCookie = AuthCookie + { acUser :: AuthUser + , acExpiry :: CTime + } instance DS.Serialize AuthCookie where put c = DS.put (auEmail u, auGivenName u, auFamilyName u, x) - where u = acUser c - x = (\(CTime i) -> i) $ acExpiry c + where + u = acUser c + x = (\(CTime i) -> i) $ acExpiry c get = do (e, n, f, x) <- DS.get - return AuthCookie { - acUser = AuthUser { auEmail = e, auGivenName = n, auFamilyName = f } + return + AuthCookie + { acUser = AuthUser {auEmail = e, auGivenName = n, auFamilyName = f} , acExpiry = CTime x } - cookieDecode :: ByteString -> ByteString -> Either String AuthCookie cookieDecode key d = State.decode key d >>= DS.decode cookieEncode :: ByteString -> AuthCookie -> ByteString cookieEncode key = State.encode key . DS.encode - getEmail :: AuthUser -> Text getEmail = decodeUtf8 . auEmail @@ -63,17 +64,16 @@ getGivenNameUtf8 = auGivenName getFamilyNameUtf8 :: AuthUser -> ByteString getFamilyNameUtf8 = auFamilyName - newUser :: Text -> AuthUser -newUser email = AuthUser { - auEmail = encodeUtf8 . toLower . strip $ email +newUser email = + AuthUser + { auEmail = encodeUtf8 . toLower . strip $ email , auGivenName = "" , auFamilyName = "" } setGivenName :: Text -> AuthUser -> AuthUser -setGivenName given au = au{ auGivenName = encodeUtf8 . strip $ given } +setGivenName given au = au {auGivenName = encodeUtf8 . strip $ given} setFamilyName :: Text -> AuthUser -> AuthUser -setFamilyName family au = au{ auFamilyName = encodeUtf8 . strip $ family } - +setFamilyName family au = au {auFamilyName = encodeUtf8 . strip $ family} diff --git a/src/Sproxy/Application/OAuth2/Common.hs b/src/Sproxy/Application/OAuth2/Common.hs index 0324e62..ae96e68 100644 --- a/src/Sproxy/Application/OAuth2/Common.hs +++ b/src/Sproxy/Application/OAuth2/Common.hs @@ -1,40 +1,37 @@ {-# LANGUAGE OverloadedStrings #-} -module Sproxy.Application.OAuth2.Common ( - AccessTokenBody(..) -, OAuth2Client(..) -, OAuth2Provider -) where + +module Sproxy.Application.OAuth2.Common + ( AccessTokenBody(..) + , OAuth2Client(..) + , OAuth2Provider + ) where import Control.Applicative (empty) -import Data.Aeson (FromJSON, parseJSON, Value(Object), (.:)) -import Data.ByteString(ByteString) +import Data.Aeson (FromJSON, Value(Object), (.:), parseJSON) +import Data.ByteString (ByteString) import Data.Text (Text) import Sproxy.Application.Cookie (AuthUser) -data OAuth2Client = OAuth2Client { - oauth2Description :: String -, oauth2AuthorizeURL - :: ByteString -- state - -> ByteString -- redirect url - -> ByteString -, oauth2Authenticate - :: ByteString -- code - -> ByteString -- redirect url - -> IO AuthUser -} +data OAuth2Client = OAuth2Client + { oauth2Description :: String + , oauth2AuthorizeURL :: ByteString -- state + -> ByteString -- redirect url + -> ByteString + , oauth2Authenticate :: ByteString -- code + -> ByteString -- redirect url + -> IO AuthUser + } type OAuth2Provider = (ByteString, ByteString) -> OAuth2Client -- | RFC6749. We ignore optional token_type ("Bearer" from Google, omitted by LinkedIn) -- and expires_in because we don't use them, *and* expires_in creates troubles: -- it's an integer from Google and string from LinkedIn (sic!) -data AccessTokenBody = AccessTokenBody { - accessToken :: Text -} deriving (Eq, Show) +data AccessTokenBody = AccessTokenBody + { accessToken :: Text + } deriving (Eq, Show) instance FromJSON AccessTokenBody where - parseJSON (Object v) = AccessTokenBody - <$> v .: "access_token" + parseJSON (Object v) = AccessTokenBody <$> v .: "access_token" parseJSON _ = empty - diff --git a/src/Sproxy/Application/State.hs b/src/Sproxy/Application/State.hs index 29d9252..5f836e6 100644 --- a/src/Sproxy/Application/State.hs +++ b/src/Sproxy/Application/State.hs @@ -1,30 +1,25 @@ -module Sproxy.Application.State ( - decode -, encode -) where +module Sproxy.Application.State + ( decode + , encode + ) where import Data.ByteString (ByteString) -import Data.ByteString.Lazy (fromStrict, toStrict) -import Data.Digest.Pure.SHA (hmacSha1, bytestringDigest) import qualified Data.ByteString.Base64 as Base64 +import Data.ByteString.Lazy (fromStrict, toStrict) +import Data.Digest.Pure.SHA (bytestringDigest, hmacSha1) import qualified Data.Serialize as DS - -- FIXME: Compress / decompress ? - - encode :: ByteString -> ByteString -> ByteString encode key payload = Base64.encode . DS.encode $ (payload, digest key payload) - decode :: ByteString -> ByteString -> Either String ByteString decode key d = do (payload, dgst) <- DS.decode =<< Base64.decode d if dgst /= digest key payload - then Left "junk" - else Right payload - + then Left "junk" + else Right payload digest :: ByteString -> ByteString -> ByteString -digest key payload = toStrict . bytestringDigest $ hmacSha1 (fromStrict key) (fromStrict payload) - +digest key payload = + toStrict . bytestringDigest $ hmacSha1 (fromStrict key) (fromStrict payload) diff --git a/src/Sproxy/Config.hs b/src/Sproxy/Config.hs index e0f35a3..f1d8004 100644 --- a/src/Sproxy/Config.hs +++ b/src/Sproxy/Config.hs @@ -1,9 +1,10 @@ {-# LANGUAGE OverloadedStrings #-} -module Sproxy.Config ( - BackendConf(..) -, ConfigFile(..) -, OAuth2Conf(..) -) where + +module Sproxy.Config + ( BackendConf(..) + , ConfigFile(..) + , OAuth2Conf(..) + ) where import Control.Applicative (empty) import Data.Aeson (FromJSON, parseJSON) @@ -11,84 +12,78 @@ import Data.HashMap.Strict (HashMap) import Data.Int (Int64) import Data.Text (Text) import Data.Word (Word16) -import Data.Yaml (Value(Object), (.:), (.:?), (.!=)) +import Data.Yaml (Value(Object), (.!=), (.:), (.:?)) import Sproxy.Logging (LogLevel(Debug)) -data ConfigFile = ConfigFile { - cfListen :: Word16 -, cfSsl :: Bool -, cfUser :: String -, cfHome :: FilePath -, cfLogLevel :: LogLevel -, cfSslCert :: Maybe FilePath -, cfSslKey :: Maybe FilePath -, cfSslCertChain :: [FilePath] -, cfKey :: Maybe String -, cfListen80 :: Maybe Bool -, cfHttpsPort :: Maybe Word16 -, cfBackends :: [BackendConf] -, cfOAuth2 :: HashMap Text OAuth2Conf -, cfDataFile :: Maybe FilePath -, cfDatabase :: Maybe String -, cfPgPassFile :: Maybe FilePath -, cfHTTP2 :: Bool -} deriving (Show) +data ConfigFile = ConfigFile + { cfListen :: Word16 + , cfSsl :: Bool + , cfUser :: String + , cfHome :: FilePath + , cfLogLevel :: LogLevel + , cfSslCert :: Maybe FilePath + , cfSslKey :: Maybe FilePath + , cfSslCertChain :: [FilePath] + , cfKey :: Maybe String + , cfListen80 :: Maybe Bool + , cfHttpsPort :: Maybe Word16 + , cfBackends :: [BackendConf] + , cfOAuth2 :: HashMap Text OAuth2Conf + , cfDataFile :: Maybe FilePath + , cfDatabase :: Maybe String + , cfPgPassFile :: Maybe FilePath + , cfHTTP2 :: Bool + } deriving (Show) instance FromJSON ConfigFile where - parseJSON (Object m) = ConfigFile <$> - m .:? "listen" .!= 443 - <*> m .:? "ssl" .!= True - <*> m .:? "user" .!= "sproxy" - <*> m .:? "home" .!= "." - <*> m .:? "log_level" .!= Debug - <*> m .:? "ssl_cert" - <*> m .:? "ssl_key" - <*> m .:? "ssl_cert_chain" .!= [] - <*> m .:? "key" - <*> m .:? "listen80" - <*> m .:? "https_port" - <*> m .: "backends" - <*> m .: "oauth2" - <*> m .:? "datafile" - <*> m .:? "database" - <*> m .:? "pgpassfile" - <*> m .:? "http2" .!= True + parseJSON (Object m) = + ConfigFile <$> m .:? "listen" .!= 443 <*> m .:? "ssl" .!= True <*> + m .:? "user" .!= "sproxy" <*> + m .:? "home" .!= "." <*> + m .:? "log_level" .!= Debug <*> + m .:? "ssl_cert" <*> + m .:? "ssl_key" <*> + m .:? "ssl_cert_chain" .!= [] <*> + m .:? "key" <*> + m .:? "listen80" <*> + m .:? "https_port" <*> + m .: "backends" <*> + m .: "oauth2" <*> + m .:? "datafile" <*> + m .:? "database" <*> + m .:? "pgpassfile" <*> + m .:? "http2" .!= True parseJSON _ = empty - -data BackendConf = BackendConf { - beName :: String -, beAddress :: String -, bePort :: Maybe Word16 -, beSocket :: Maybe FilePath -, beCookieName :: String -, beCookieDomain :: Maybe String -, beCookieMaxAge :: Int64 -, beConnCount :: Int -} deriving (Show) +data BackendConf = BackendConf + { beName :: String + , beAddress :: String + , bePort :: Maybe Word16 + , beSocket :: Maybe FilePath + , beCookieName :: String + , beCookieDomain :: Maybe String + , beCookieMaxAge :: Int64 + , beConnCount :: Int + } deriving (Show) instance FromJSON BackendConf where - parseJSON (Object m) = BackendConf <$> - m .:? "name" .!= "*" - <*> m .:? "address" .!= "127.0.0.1" - <*> m .:? "port" - <*> m .:? "socket" - <*> m .:? "cookie_name" .!= "sproxy" - <*> m .:? "cookie_domain" - <*> m .:? "cookie_max_age" .!= (7 * 24 * 60 * 60) - <*> m .:? "conn_count" .!= 32 + parseJSON (Object m) = + BackendConf <$> m .:? "name" .!= "*" <*> m .:? "address" .!= "127.0.0.1" <*> + m .:? "port" <*> + m .:? "socket" <*> + m .:? "cookie_name" .!= "sproxy" <*> + m .:? "cookie_domain" <*> + m .:? "cookie_max_age" .!= (7 * 24 * 60 * 60) <*> + m .:? "conn_count" .!= 32 parseJSON _ = empty - -data OAuth2Conf = OAuth2Conf { - oa2ClientId :: String -, oa2ClientSecret :: String -} deriving (Show) +data OAuth2Conf = OAuth2Conf + { oa2ClientId :: String + , oa2ClientSecret :: String + } deriving (Show) instance FromJSON OAuth2Conf where - parseJSON (Object m) = OAuth2Conf <$> - m .: "client_id" - <*> m .: "client_secret" + parseJSON (Object m) = + OAuth2Conf <$> m .: "client_id" <*> m .: "client_secret" parseJSON _ = empty - diff --git a/src/Sproxy/Logging.hs b/src/Sproxy/Logging.hs index 651a73a..93bc355 100644 --- a/src/Sproxy/Logging.hs +++ b/src/Sproxy/Logging.hs @@ -1,12 +1,12 @@ -module Sproxy.Logging ( - LogLevel(..) -, debug -, error -, info -, level -, start -, warn -) where +module Sproxy.Logging + ( LogLevel(..) + , debug + , error + , info + , level + , start + , warn + ) where import Prelude hiding (error) @@ -15,13 +15,13 @@ import Control.Concurrent (forkIO) import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan) import Control.Monad (forever, when) import Data.Aeson (FromJSON, ToJSON) +import qualified Data.Aeson as JSON import Data.Char (toLower) import Data.IORef (IORef, newIORef, readIORef, writeIORef) +import qualified Data.Text as T import System.IO (hPrint, stderr) import System.IO.Unsafe (unsafePerformIO) import Text.Read (readMaybe) -import qualified Data.Aeson as JSON -import qualified Data.Text as T start :: LogLevel -> IO () start None = return () @@ -34,16 +34,15 @@ start lvl = do info :: String -> IO () info = send . Message Info -warn:: String -> IO () +warn :: String -> IO () warn = send . Message Warning -error:: String -> IO () +error :: String -> IO () error = send . Message Error debug :: String -> IO () debug = send . Message Debug - send :: Message -> IO () send msg@(Message l _) = do lvl <- level @@ -62,38 +61,46 @@ logLevel = unsafePerformIO (newIORef None) level :: IO LogLevel level = readIORef logLevel - -data LogLevel = None | Error | Warning | Info | Debug +data LogLevel + = None + | Error + | Warning + | Info + | Debug deriving (Enum, Ord, Eq) instance Show LogLevel where - show None = "NONE" - show Error = "ERROR" + show None = "NONE" + show Error = "ERROR" show Warning = "WARN" - show Info = "INFO" - show Debug = "DEBUG" + show Info = "INFO" + show Debug = "DEBUG" instance Read LogLevel where readsPrec _ s - | l == "none" = [ (None, "") ] - | l == "error" = [ (Error, "") ] - | l == "warn" = [ (Warning, "") ] - | l == "info" = [ (Info, "") ] - | l == "debug" = [ (Debug, "") ] - | otherwise = [ ] - where l = map toLower s + | l == "none" = [(None, "")] + | l == "error" = [(Error, "")] + | l == "warn" = [(Warning, "")] + | l == "info" = [(Info, "")] + | l == "debug" = [(Debug, "")] + | otherwise = [] + where + l = map toLower s instance ToJSON LogLevel where toJSON = JSON.String . T.pack . show instance FromJSON LogLevel where parseJSON (JSON.String s) = - maybe (fail $ "unknown log level: " ++ show s) return (readMaybe . T.unpack $ s) + maybe + (fail $ "unknown log level: " ++ show s) + return + (readMaybe . T.unpack $ s) parseJSON _ = empty - -data Message = Message LogLevel String +data Message = + Message LogLevel + String instance Show Message where show (Message lvl str) = show lvl ++ ": " ++ str - diff --git a/src/Sproxy/Server.hs b/src/Sproxy/Server.hs index 7b65f32..75a50a4 100644 --- a/src/Sproxy/Server.hs +++ b/src/Sproxy/Server.hs @@ -1,6 +1,6 @@ -module Sproxy.Server ( - server -) where +module Sproxy.Server + ( server + ) where import Control.Concurrent (forkIO) import Control.Exception (bracketOnError) @@ -11,136 +11,128 @@ import Data.Maybe (fromMaybe) import Data.Text (Text) import Data.Word (Word16) import Data.Yaml.Include (decodeFileEither) -import Network.HTTP.Client (Manager, ManagerSettings(..), defaultManagerSettings, newManager, socketConnection) +import Network.HTTP.Client + (Manager, ManagerSettings(..), defaultManagerSettings, newManager, + socketConnection) import Network.HTTP.Client.Internal (Connection) -import Network.Socket ( Socket, Family(AF_INET, AF_UNIX), SockAddr(SockAddrInet, SockAddrUnix), - SocketOption(ReuseAddr), SocketType(Stream), bind, close, connect, inet_addr, - listen, maxListenQueue, setSocketOption, socket ) +import Network.Socket + (Family(AF_INET, AF_UNIX), SockAddr(SockAddrInet, SockAddrUnix), + Socket, SocketOption(ReuseAddr), SocketType(Stream), bind, close, + connect, inet_addr, listen, maxListenQueue, setSocketOption, + socket) import Network.Wai (Application) -import Network.Wai.Handler.WarpTLS (tlsSettingsChain, runTLSSocket) -import Network.Wai.Handler.Warp ( Settings, defaultSettings, runSettingsSocket, - setHTTP2Disabled, setOnException ) +import Network.Wai.Handler.Warp + (Settings, defaultSettings, runSettingsSocket, setHTTP2Disabled, + setOnException) +import Network.Wai.Handler.WarpTLS (runTLSSocket, tlsSettingsChain) import System.Entropy (getEntropy) import System.Environment (setEnv) import System.Exit (exitFailure) import System.FilePath.Glob (compile) import System.IO (hPutStrLn, stderr) -import System.Posix.User ( GroupEntry(..), UserEntry(..), - getAllGroupEntries, getRealUserID, - getUserEntryForName, setGroupID, setGroups, setUserID ) +import System.Posix.User + (GroupEntry(..), UserEntry(..), getAllGroupEntries, getRealUserID, + getUserEntryForName, setGroupID, setGroups, setUserID) -import Sproxy.Application (sproxy, redirect) -import Sproxy.Application.OAuth2.Common (OAuth2Client) -import Sproxy.Config (BackendConf(..), ConfigFile(..), OAuth2Conf(..)) +import Sproxy.Application (redirect, sproxy) import qualified Sproxy.Application.OAuth2 as OAuth2 +import Sproxy.Application.OAuth2.Common (OAuth2Client) +import Sproxy.Config + (BackendConf(..), ConfigFile(..), OAuth2Conf(..)) import qualified Sproxy.Logging as Log import qualified Sproxy.Server.DB as DB - {- TODO: - Log.error && exitFailure should be replaced - by Log.fatal && wait for logger thread to print && exitFailure -} - server :: FilePath -> IO () server configFile = do cf <- readConfigFile configFile Log.start $ cfLogLevel cf - sock <- socket AF_INET Stream 0 setSocketOption sock ReuseAddr 1 bind sock $ SockAddrInet (fromIntegral $ cfListen cf) 0 - - maybe80 <- if fromMaybe (443 == cfListen cf) (cfListen80 cf) - then do - sock80 <- socket AF_INET Stream 0 - setSocketOption sock80 ReuseAddr 1 - bind sock80 $ SockAddrInet 80 0 - return (Just sock80) - else - return Nothing - + maybe80 <- + if fromMaybe (443 == cfListen cf) (cfListen80 cf) + then do + sock80 <- socket AF_INET Stream 0 + setSocketOption sock80 ReuseAddr 1 + bind sock80 $ SockAddrInet 80 0 + return (Just sock80) + else return Nothing uid <- getRealUserID when (0 == uid) $ do let user = cfUser cf Log.info $ "switching to user " ++ show user u <- getUserEntryForName user - groupIDs <- map groupID . filter (elem user . groupMembers) - <$> getAllGroupEntries + groupIDs <- + map groupID . filter (elem user . groupMembers) <$> getAllGroupEntries setGroups groupIDs setGroupID $ userGroupID u setUserID $ userID u - ds <- newDataSource cf db <- DB.start (cfHome cf) ds - - key <- maybe - (Log.info "using new random key" >> getEntropy 32) - (return . pack) - (cfKey cf) - - let - settings = - (if cfHTTP2 cf then id else setHTTP2Disabled) $ - setOnException (\_ _ -> return ()) - defaultSettings - - oauth2clients <- HM.fromList <$> mapM newOAuth2Client (HM.toList (cfOAuth2 cf)) - + key <- + maybe + (Log.info "using new random key" >> getEntropy 32) + (return . pack) + (cfKey cf) + let settings = + (if cfHTTP2 cf + then id + else setHTTP2Disabled) $ + setOnException (\_ _ -> return ()) defaultSettings + oauth2clients <- + HM.fromList <$> mapM newOAuth2Client (HM.toList (cfOAuth2 cf)) backends <- - mapM (\be -> do - m <- newBackendManager be - return (compile $ beName be, be, m) - ) $ cfBackends cf - - + mapM + (\be -> do + m <- newBackendManager be + return (compile $ beName be, be, m)) $ + cfBackends cf warpServer <- newServer cf - case maybe80 of - Nothing -> return () + Nothing -> return () Just sock80 -> do let httpsPort = fromMaybe (cfListen cf) (cfHttpsPort cf) Log.info "listening on port 80 (HTTP redirect)" listen sock80 maxListenQueue void . forkIO $ runSettingsSocket settings sock80 (redirect httpsPort) - -- XXX 2048 is from bindPortTCP from streaming-commons used internally by runTLS. -- XXX Since we don't call runTLS, we listen socket here with the same options. Log.info $ "proxy listening on port " ++ show (cfListen cf) listen sock (max 2048 maxListenQueue) warpServer settings sock (sproxy key db oauth2clients backends) - newDataSource :: ConfigFile -> IO (Maybe DB.DataSource) newDataSource cf = case (cfDataFile cf, cfDatabase cf) of (Nothing, Just str) -> do case cfPgPassFile cf of Nothing -> return () - Just f -> do + Just f -> do Log.info $ "pgpassfile: " ++ show f setEnv "PGPASSFILE" f return . Just $ DB.PostgreSQL str - - (Just f, Nothing) -> return . Just $ DB.File f - - (Nothing, Nothing) -> return Nothing + (Just f, Nothing) -> return . Just $ DB.File f + (Nothing, Nothing) -> return Nothing _ -> do Log.error "only one data source can be used" exitFailure - newOAuth2Client :: (Text, OAuth2Conf) -> IO (Text, OAuth2Client) newOAuth2Client (name, cfg) = case HM.lookup name OAuth2.providers of - Nothing -> do Log.error $ "OAuth2 provider " ++ show name ++ " is not supported" - exitFailure + Nothing -> do + Log.error $ "OAuth2 provider " ++ show name ++ " is not supported" + exitFailure Just provider -> do Log.info $ "oauth2: adding " ++ show name return (name, provider (client_id, client_secret)) - where client_id = pack $ oa2ClientId cfg - client_secret = pack $ oa2ClientSecret cfg - + where + client_id = pack $ oa2ClientId cfg + client_secret = pack $ oa2ClientSecret cfg newBackendManager :: BackendConf -> IO Manager newBackendManager be = do @@ -149,20 +141,18 @@ newBackendManager be = do (Just f, Nothing) -> do Log.info $ "backend `" ++ beName be ++ "' on UNIX socket " ++ f return $ openUnixSocketConnection f - (Nothing, Just n) -> do - Log.info $ "backend `" ++ beName be ++ "' on " ++ beAddress be ++ ":" ++ show n + Log.info $ + "backend `" ++ beName be ++ "' on " ++ beAddress be ++ ":" ++ show n return $ openTCPConnection (beAddress be) n - _ -> do - Log.error "either backend port number or UNIX socket path is required." - exitFailure - - newManager defaultManagerSettings { - managerRawConnection = return $ \_ _ _ -> openConn - , managerConnCount = beConnCount be - } - + Log.error "either backend port number or UNIX socket path is required." + exitFailure + newManager + defaultManagerSettings + { managerRawConnection = return $ \_ _ _ -> openConn + , managerConnCount = beConnCount be + } newServer :: ConfigFile -> IO (Settings -> Socket -> Application -> IO ()) newServer cf @@ -170,33 +160,31 @@ newServer cf case (cfSslKey cf, cfSslCert cf) of (Just k, Just c) -> return $ runTLSSocket (tlsSettingsChain c (cfSslCertChain cf) k) - _ -> do Log.error "missings SSL certificate" - exitFailure + _ -> do + Log.error "missings SSL certificate" + exitFailure | otherwise = do - Log.warn "not using SSL!" - return runSettingsSocket - + Log.warn "not using SSL!" + return runSettingsSocket openUnixSocketConnection :: FilePath -> IO Connection openUnixSocketConnection f = bracketOnError - (socket AF_UNIX Stream 0) - close - (\s -> do - connect s (SockAddrUnix f) - socketConnection s 8192) - + (socket AF_UNIX Stream 0) + close + (\s -> do + connect s (SockAddrUnix f) + socketConnection s 8192) openTCPConnection :: String -> Word16 -> IO Connection openTCPConnection addr port = bracketOnError - (socket AF_INET Stream 0) - close - (\s -> do - a <- inet_addr addr - connect s (SockAddrInet (fromIntegral port) a) - socketConnection s 8192) - + (socket AF_INET Stream 0) + close + (\s -> do + a <- inet_addr addr + connect s (SockAddrInet (fromIntegral port) a) + socketConnection s 8192) readConfigFile :: FilePath -> IO ConfigFile readConfigFile f = do @@ -206,4 +194,3 @@ readConfigFile f = do hPutStrLn stderr $ "FATAL: " ++ f ++ ": " ++ show e exitFailure Right cf -> return cf - diff --git a/src/Sproxy/Server/DB.hs b/src/Sproxy/Server/DB.hs index 662a9c7..be44f69 100644 --- a/src/Sproxy/Server/DB.hs +++ b/src/Sproxy/Server/DB.hs @@ -1,72 +1,83 @@ {-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE QuasiQuotes #-} -module Sproxy.Server.DB ( - Database -, DataSource(..) -, userAccess -, userExists -, userGroups -, start -) where + +module Sproxy.Server.DB + ( Database + , DataSource(..) + , userAccess + , userExists + , userGroups + , start + ) where import Control.Concurrent (forkIO, threadDelay) import Control.Exception (SomeException, bracket, catch, finally) import Control.Monad (filterM, forever, void) import Data.ByteString.Char8 (pack) +import qualified Data.HashMap.Strict as HM import Data.Pool (Pool, createPool, withResource) import Data.Text (Text, toLower, unpack) import Data.Yaml (decodeFileEither) -import Database.SQLite.Simple (NamedParam((:=))) -import Text.InterpolatedString.Perl6 (q, qc) -import qualified Data.HashMap.Strict as HM import qualified Database.PostgreSQL.Simple as PG +import Database.SQLite.Simple (NamedParam((:=))) import qualified Database.SQLite.Simple as SQLite +import Text.InterpolatedString.Perl6 (q, qc) -import Sproxy.Server.DB.DataFile ( DataFile(..), GroupMember(..), - GroupPrivilege(..), PrivilegeRule(..) ) import qualified Sproxy.Application.Access as A import qualified Sproxy.Logging as Log - +import Sproxy.Server.DB.DataFile + (DataFile(..), GroupMember(..), GroupPrivilege(..), + PrivilegeRule(..)) type Database = Pool SQLite.Connection -data DataSource = PostgreSQL String | File FilePath +data DataSource + = PostgreSQL String + | File FilePath {- TODO: - Hash remote tables and update the local only when the remote change - Switch to REGEX - Generalize sync procedures for different tables -} - start :: FilePath -> Maybe DataSource -> IO Database start home ds = do Log.info $ "home directory: " ++ show home - db <- createPool - (do c <- SQLite.open $ home ++ "/sproxy.sqlite3" - lvl <- Log.level - SQLite.setTrace c (if lvl == Log.Debug then Just $ Log.debug . unpack else Nothing) - return c) - SQLite.close - 1 -- stripes - 3600 -- keep alive (seconds). FIXME: no much sense as it's a local file - 128 -- max connections. FIXME: make configurable? - + db <- + createPool + (do c <- SQLite.open $ home ++ "/sproxy.sqlite3" + lvl <- Log.level + SQLite.setTrace + c + (if lvl == Log.Debug + then Just $ Log.debug . unpack + else Nothing) + return c) + SQLite.close + 1 -- stripes + 3600 -- keep alive (seconds). FIXME: no much sense as it's a local file + 128 -- max connections. FIXME: make configurable? withResource db $ \c -> SQLite.execute_ c "PRAGMA journal_mode=WAL" populate db ds return db - userExists :: Database -> Text -> IO Bool userExists db email = do - r <- withResource db $ \c -> fmap SQLite.fromOnly <$> SQLite.queryNamed c - "SELECT EXISTS (SELECT 1 FROM group_member WHERE :email LIKE email LIMIT 1)" - [ ":email" := email ] + r <- + withResource db $ \c -> + fmap SQLite.fromOnly <$> + SQLite.queryNamed + c + "SELECT EXISTS (SELECT 1 FROM group_member WHERE :email LIKE email LIMIT 1)" + [":email" := email] return $ head r - userGroups_ :: SQLite.Connection -> Text -> Text -> Text -> Text -> IO [Text] userGroups_ c email domain path method = - fmap SQLite.fromOnly <$> SQLite.queryNamed c [q| + fmap SQLite.fromOnly <$> + SQLite.queryNamed + c + [q| SELECT gm."group" FROM group_privilege gp JOIN group_member gm ON gm."group" = gp."group" WHERE :email LIKE gm.email AND gp.domain = :domain @@ -77,12 +88,12 @@ userGroups_ c email domain path method = AND method = :method ORDER BY length(path) - length(replace(path, '/', '')) DESC LIMIT 1 ) - |] [ ":email" := email -- XXX always in lower case - , ":domain" := toLower domain - , ":path" := path - , ":method" := method -- XXX case-sensitive by RFC2616 - ] - + |] + [ ":email" := email -- XXX always in lower case + , ":domain" := toLower domain + , ":path" := path + , ":method" := method -- XXX case-sensitive by RFC2616 + ] userAccess :: Database -> Text -> Text -> A.Inquiry -> IO [Text] userAccess db email domain inq = do @@ -90,80 +101,85 @@ userAccess db email domain inq = do not . null <$> userGroups_ c email domain (A.path qn) (A.method qn) map fst <$> withResource db (\c -> filterM (permitted c) (HM.toList inq)) - userGroups :: Database -> Text -> Text -> Text -> Text -> IO [Text] userGroups db email domain path method = withResource db $ \c -> userGroups_ c email domain path method - populate :: Database -> Maybe DataSource -> IO () - populate db Nothing = do Log.warn "db: no data source defined" - withResource db $ \c -> SQLite.withTransaction c $ do - createGroupMember c - createGroupPrivilege c - createPrivilegeRule c - + withResource db $ \c -> + SQLite.withTransaction c $ do + createGroupMember c + createGroupPrivilege c + createPrivilegeRule c populate db (Just (File f)) = do Log.info $ "db: reading " ++ show f r <- decodeFileEither f case r of - Left e -> Log.error $ f ++ ": " ++ show e + Left e -> Log.error $ f ++ ": " ++ show e Right df -> - withResource db $ \c -> SQLite.withTransaction c $ do - refreshGroupMembers c $ \st -> - mapM_ (\gm -> submit st (gmGroup gm, toLower $ gmEmail gm) - ) (groupMember df) - - refreshGroupPrivileges c $ \st -> - mapM_ (\gp -> submit st (gpGroup gp, toLower $ gpDomain gp, gpPrivilege gp) - ) (groupPrivilege df) - - refreshPrivilegeRule c $ \st -> - mapM_ (\pr -> submit st (toLower $ prDomain pr, prPrivilege pr, prPath pr, prMethod pr) - ) (privilegeRule df) - - + withResource db $ \c -> + SQLite.withTransaction c $ do + refreshGroupMembers c $ \st -> + mapM_ + (\gm -> submit st (gmGroup gm, toLower $ gmEmail gm)) + (groupMember df) + refreshGroupPrivileges c $ \st -> + mapM_ + (\gp -> + submit st (gpGroup gp, toLower $ gpDomain gp, gpPrivilege gp)) + (groupPrivilege df) + refreshPrivilegeRule c $ \st -> + mapM_ + (\pr -> + submit + st + ( toLower $ prDomain pr + , prPrivilege pr + , prPath pr + , prMethod pr)) + (privilegeRule df) populate db (Just (PostgreSQL connstr)) = - void . forkIO . forever . flip finally (7 `minutes` threadDelay) - . logException $ do + void . + forkIO . forever . flip finally (7 `minutes` threadDelay) . logException $ do Log.info $ "db: synchronizing with " ++ show connstr - withResource db $ \c -> SQLite.withTransaction c $ - bracket (PG.connectPostgreSQL $ pack connstr) PG.close $ - \pg -> PG.withTransaction pg $ do - + withResource db $ \c -> + SQLite.withTransaction c $ + bracket (PG.connectPostgreSQL $ pack connstr) PG.close $ \pg -> + PG.withTransaction pg $ do Log.info "db: syncing group_member" refreshGroupMembers c $ \st -> - PG.forEach_ pg - [q|SELECT "group", lower(email) FROM group_member|] $ \r -> - submit st (r :: (Text, Text)) + PG.forEach_ pg [q|SELECT "group", lower(email) FROM group_member|] $ \r -> + submit st (r :: (Text, Text)) count c "group_member" - Log.info "db: syncing group_privilege" refreshGroupPrivileges c $ \st -> - PG.forEach_ pg + PG.forEach_ + pg [q|SELECT "group", lower(domain), privilege FROM group_privilege|] $ \r -> - submit st (r :: (Text, Text, Text)) + submit st (r :: (Text, Text, Text)) count c "group_privilege" - Log.info "db: syncing privilege_rule" refreshPrivilegeRule c $ \st -> - PG.forEach_ pg + PG.forEach_ + pg [q|SELECT lower(domain), privilege, path, method FROM privilege_rule|] $ \r -> - submit st (r :: (Text, Text, Text, Text)) + submit st (r :: (Text, Text, Text, Text)) count c "privilege_rule" - -- FIXME short-cut for https://github.com/nurpax/sqlite-simple/issues/50 -- FIXME nextRow is the only way to execute a prepared statement -- FIXME with bound parameters, but we don't expect any results. submit :: SQLite.ToRow values => SQLite.Statement -> values -> IO () -submit st v = SQLite.withBind st v $ void (SQLite.nextRow st :: IO (Maybe [Int])) - +submit st v = + SQLite.withBind st v $ void (SQLite.nextRow st :: IO (Maybe [Int])) createGroupMember :: SQLite.Connection -> IO () -createGroupMember c = SQLite.execute_ c [q| +createGroupMember c = + SQLite.execute_ + c + [q| CREATE TABLE IF NOT EXISTS group_member ( "group" TEXT, email TEXT, @@ -175,13 +191,16 @@ refreshGroupMembers :: SQLite.Connection -> (SQLite.Statement -> IO ()) -> IO () refreshGroupMembers c a = do SQLite.execute_ c "DROP TABLE IF EXISTS group_member" createGroupMember c - SQLite.withStatement c + SQLite.withStatement + c [q|INSERT INTO group_member("group", email) VALUES (?, ?)|] a - createGroupPrivilege :: SQLite.Connection -> IO () -createGroupPrivilege c = SQLite.execute_ c [q| +createGroupPrivilege c = + SQLite.execute_ + c + [q| CREATE TABLE IF NOT EXISTS group_privilege ( "group" TEXT, domain TEXT, @@ -190,17 +209,21 @@ createGroupPrivilege c = SQLite.execute_ c [q| ) |] -refreshGroupPrivileges :: SQLite.Connection -> (SQLite.Statement -> IO ()) -> IO () +refreshGroupPrivileges :: + SQLite.Connection -> (SQLite.Statement -> IO ()) -> IO () refreshGroupPrivileges c a = do SQLite.execute_ c "DROP TABLE IF EXISTS group_privilege" createGroupPrivilege c - SQLite.withStatement c + SQLite.withStatement + c [q|INSERT INTO group_privilege("group", domain, privilege) VALUES (?, ?, ?)|] a - createPrivilegeRule :: SQLite.Connection -> IO () -createPrivilegeRule c = SQLite.execute_ c [q| +createPrivilegeRule c = + SQLite.execute_ + c + [q| CREATE TABLE IF NOT EXISTS privilege_rule ( domain TEXT, privilege TEXT, @@ -210,26 +233,24 @@ createPrivilegeRule c = SQLite.execute_ c [q| ) |] -refreshPrivilegeRule :: SQLite.Connection -> (SQLite.Statement -> IO ()) -> IO () +refreshPrivilegeRule :: + SQLite.Connection -> (SQLite.Statement -> IO ()) -> IO () refreshPrivilegeRule c a = do SQLite.execute_ c "DROP TABLE IF EXISTS privilege_rule" createPrivilegeRule c - SQLite.withStatement c + SQLite.withStatement + c [q|INSERT INTO privilege_rule(domain, privilege, path, method) VALUES (?, ?, ?, ?)|] a - count :: SQLite.Connection -> String -> IO () count c table = do - r <- fmap SQLite.fromOnly <$> SQLite.query_ c [qc|SELECT COUNT(*) FROM {table}|] + r <- + fmap SQLite.fromOnly <$> SQLite.query_ c [qc|SELECT COUNT(*) FROM {table}|] Log.info $ "db: " ++ table ++ " rows: " ++ show (head r :: Integer) - logException :: IO () -> IO () -logException a = catch a $ \e -> - Log.error $ "db: " ++ show (e :: SomeException) - +logException a = catch a $ \e -> Log.error $ "db: " ++ show (e :: SomeException) minutes :: Int -> (Int -> IO ()) -> IO () minutes us f = f $ us * 60 * 1000000 - diff --git a/src/Sproxy/Server/DB/DataFile.hs b/src/Sproxy/Server/DB/DataFile.hs index efac923..5708c96 100644 --- a/src/Sproxy/Server/DB/DataFile.hs +++ b/src/Sproxy/Server/DB/DataFile.hs @@ -1,69 +1,58 @@ {-# LANGUAGE OverloadedStrings #-} -module Sproxy.Server.DB.DataFile ( - DataFile(..) -, GroupMember(..) -, GroupPrivilege(..) -, PrivilegeRule(..) -) where + +module Sproxy.Server.DB.DataFile + ( DataFile(..) + , GroupMember(..) + , GroupPrivilege(..) + , PrivilegeRule(..) + ) where import Control.Applicative (empty) import Data.Aeson (FromJSON, parseJSON) import Data.Text (Text) import Data.Yaml (Value(Object), (.:)) - -data DataFile = DataFile { - groupMember :: [GroupMember] -, groupPrivilege :: [GroupPrivilege] -, privilegeRule :: [PrivilegeRule] -} deriving (Show) +data DataFile = DataFile + { groupMember :: [GroupMember] + , groupPrivilege :: [GroupPrivilege] + , privilegeRule :: [PrivilegeRule] + } deriving (Show) instance FromJSON DataFile where - parseJSON (Object m) = DataFile <$> - m .: "group_member" - <*> m .: "group_privilege" - <*> m .: "privilege_rule" + parseJSON (Object m) = + DataFile <$> m .: "group_member" <*> m .: "group_privilege" <*> + m .: "privilege_rule" parseJSON _ = empty - -data GroupMember = GroupMember { - gmGroup :: Text -, gmEmail :: Text -} deriving (Show) +data GroupMember = GroupMember + { gmGroup :: Text + , gmEmail :: Text + } deriving (Show) instance FromJSON GroupMember where - parseJSON (Object m) = GroupMember <$> - m .: "group" - <*> m .: "email" + parseJSON (Object m) = GroupMember <$> m .: "group" <*> m .: "email" parseJSON _ = empty - -data GroupPrivilege = GroupPrivilege { - gpGroup :: Text -, gpDomain :: Text -, gpPrivilege :: Text -} deriving (Show) +data GroupPrivilege = GroupPrivilege + { gpGroup :: Text + , gpDomain :: Text + , gpPrivilege :: Text + } deriving (Show) instance FromJSON GroupPrivilege where - parseJSON (Object m) = GroupPrivilege <$> - m .: "group" - <*> m .: "domain" - <*> m .: "privilege" + parseJSON (Object m) = + GroupPrivilege <$> m .: "group" <*> m .: "domain" <*> m .: "privilege" parseJSON _ = empty - -data PrivilegeRule = PrivilegeRule { - prDomain :: Text -, prPrivilege :: Text -, prPath :: Text -, prMethod :: Text -} deriving (Show) +data PrivilegeRule = PrivilegeRule + { prDomain :: Text + , prPrivilege :: Text + , prPath :: Text + , prMethod :: Text + } deriving (Show) instance FromJSON PrivilegeRule where - parseJSON (Object m) = PrivilegeRule <$> - m .: "domain" - <*> m .: "privilege" - <*> m .: "path" - <*> m .: "method" + parseJSON (Object m) = + PrivilegeRule <$> m .: "domain" <*> m .: "privilege" <*> m .: "path" <*> + m .: "method" parseJSON _ = empty - |