diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/Main.hs | 37 | ||||
-rw-r--r-- | src/Sproxy/Application.hs | 372 | ||||
-rw-r--r-- | src/Sproxy/Application/Cookie.hs | 44 | ||||
-rw-r--r-- | src/Sproxy/Application/OAuth2.hs | 18 | ||||
-rw-r--r-- | src/Sproxy/Application/OAuth2/Common.hs | 39 | ||||
-rw-r--r-- | src/Sproxy/Application/OAuth2/Google.hs | 78 | ||||
-rw-r--r-- | src/Sproxy/Application/OAuth2/LinkedIn.hs | 83 | ||||
-rw-r--r-- | src/Sproxy/Application/State.hs | 30 | ||||
-rw-r--r-- | src/Sproxy/Config.hs | 88 | ||||
-rw-r--r-- | src/Sproxy/Logging.hs | 99 | ||||
-rw-r--r-- | src/Sproxy/Server.hs | 190 | ||||
-rw-r--r-- | src/Sproxy/Server/DB.hs | 189 |
12 files changed, 1267 insertions, 0 deletions
diff --git a/src/Main.hs b/src/Main.hs new file mode 100644 index 0000000..7101af0 --- /dev/null +++ b/src/Main.hs @@ -0,0 +1,37 @@ +{-# LANGUAGE QuasiQuotes #-} +module Main ( + main +) where + +import Data.Maybe (fromJust) +import Data.Version (showVersion) +import Paths_sproxy2 (version) -- from cabal +import System.Environment (getArgs) +import Text.InterpolatedString.Perl6 (qc) +import qualified System.Console.Docopt.NoTH as O + +import Sproxy.Server (server) + +usage :: String +usage = "sproxy2 " ++ showVersion version ++ + " - HTTP proxy for authenticating users via OAuth2" ++ [qc| + +Usage: + sproxy2 [options] + +Options: + -c, --config=FILE Configuration file [default: sproxy.yml] + -h, --help Show this message + +|] + +main :: IO () +main = do + doco <- O.parseUsageOrExit usage + args <- O.parseArgsOrExit doco =<< getArgs + if args `O.isPresent` O.longOption "help" + then putStrLn $ O.usage doco + else do + let configFile = fromJust . O.getArg args $ O.longOption "config" + server configFile + diff --git a/src/Sproxy/Application.hs b/src/Sproxy/Application.hs new file mode 100644 index 0000000..2391220 --- /dev/null +++ b/src/Sproxy/Application.hs @@ -0,0 +1,372 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +module Sproxy.Application ( + sproxy +, redirect +) where + +import Blaze.ByteString.Builder (toByteString) +import Blaze.ByteString.Builder.ByteString (fromByteString) +import Control.Exception (SomeException, catch) +import Data.ByteString (ByteString) +import Data.ByteString as BS (break, intercalate) +import Data.Char (toLower) +import Data.ByteString.Char8 (pack, unpack) +import Data.ByteString.Lazy (fromStrict) +import Data.Conduit (Flush(Chunk), mapOutput) +import Data.HashMap.Strict as HM (HashMap, foldrWithKey, lookup) +import Data.List (find, partition) +import Data.Map as Map (delete, fromListWith, insert, insertWith, toList) +import Data.Maybe (fromJust, fromMaybe) +import Data.Monoid ((<>)) +import Data.Text (Text) +import Data.Text.Encoding (decodeUtf8, encodeUtf8) +import Data.Time.Clock.POSIX (posixSecondsToUTCTime) +import Data.Word (Word16) +import Data.Word8 (_colon) +import Foreign.C.Types (CTime(..)) +import Network.HTTP.Client.Conduit (bodyReaderSource) +import Network.HTTP.Conduit (requestBodySourceChunkedIO, requestBodySourceIO) +import Network.HTTP.Types (RequestHeaders, ResponseHeaders, hConnection, + hContentLength, hContentType, hCookie, hLocation, methodGet) +import Network.HTTP.Types.Status ( Status(..), badRequest400, forbidden403, found302, + internalServerError500, methodNotAllowed405, movedPermanently301, + networkAuthenticationRequired511, notFound404, ok200, seeOther303, temporaryRedirect307 ) +import Network.Socket (NameInfoFlag(NI_NUMERICHOST), getNameInfo) +import Network.Wai.Conduit (sourceRequestBody, responseSource) +import System.FilePath.Glob (Pattern, match) +import System.Posix.Time (epochTime) +import Text.InterpolatedString.Perl6 (qc) +import Web.Cookie (Cookies, parseCookies, renderCookies) +import qualified Network.HTTP.Client as BE +import qualified Network.Wai as W +import qualified Web.Cookie as WC + +import Sproxy.Application.Cookie (AuthCookie(..), AuthUser(..), cookieDecode, cookieEncode) +import Sproxy.Application.OAuth2.Common (OAuth2Client(..)) +import Sproxy.Config(BackendConf(..)) +import Sproxy.Server.DB (Database, userExists, userGroups) +import qualified Sproxy.Application.State as State +import qualified Sproxy.Logging as Log + + +redirect :: Word16 -> W.Application +redirect p req resp = + case W.requestHeaderHost req of + Nothing -> badRequest "missing host" req resp + Just host -> do + Log.info $ "redirecting to " ++ show location ++ ": " ++ showReq req + resp $ W.responseBuilder status [(hLocation, location)] mempty + where + status = if W.requestMethod req == methodGet then movedPermanently301 else temporaryRedirect307 + (domain, _) = BS.break (== _colon) host + newhost = if p == 443 then domain else domain <> ":" <> pack (show p) + location = "https://" <> newhost <> W.rawPathInfo req <> W.rawQueryString req + + +sproxy :: ByteString -> Database -> HashMap Text OAuth2Client -> [(Pattern, BackendConf, BE.Manager)] -> W.Application +sproxy key db oa2 backends = logException $ \req resp -> do + Log.debug $ "sproxy <<< " ++ showReq req + case W.requestHeaderHost req of + Nothing -> badRequest "missing host" req resp + Just host -> + case find (\(p, _, _) -> match p (unpack host)) backends of + Nothing -> notFound "backend" req resp + Just (_, be, mgr) -> do + let cookieName = pack $ beCookieName be + cookieDomain = pack <$> beCookieDomain be + case W.pathInfo req of + ["robots.txt"] -> get robots req resp + (".sproxy":proxy) -> + case proxy of + ["logout"] -> + case extractCookie key Nothing cookieName req of + Nothing -> notFound "logout without the cookie" req resp + Just _ -> get (logout cookieName cookieDomain) req resp + ["oauth2", provider] -> + case HM.lookup provider oa2 of + Nothing -> notFound "OAuth2 provider" req resp + Just oa2c -> get (oauth2callback key db (provider, oa2c) be) req resp + _ -> notFound "proxy" req resp + _ -> do + now <- Just <$> epochTime + case extractCookie key now cookieName req of + Nothing -> authenticationRequired key oa2 req resp + Just cs@(authCookie, _) -> + authorize db cs req >>= \case + Nothing -> forbidden authCookie req resp + Just req' -> forward mgr req' resp + + +robots :: W.Application +robots _ resp = resp $ + W.responseLBS ok200 [(hContentType, "text/plain; charset=utf-8")] + "User-agent: *\nDisallow: /" + + +oauth2callback :: ByteString -> Database -> (Text, OAuth2Client) -> BackendConf -> W.Application +oauth2callback key db (provider, oa2c) be req resp = + case param "code" of + Nothing -> badRequest "missing auth code" req resp + Just code -> + case param "state" of + Nothing -> badRequest "missing auth state" req resp + Just state -> + case State.decode key state of + Left msg -> badRequest ("invalid state: " ++ msg) req resp + Right path -> do + au <- oauth2Authenticate oa2c code (redirectURL req provider) + let email = map toLower $ auEmail au + Log.info $ "login `" ++ email ++ "' by " ++ show provider + exists <- userExists db email + if exists then authenticate key be au{auEmail = email} path req resp + else userNotFound email req resp + where + param p = do + (_, v) <- find ((==) p . fst) $ W.queryString req + v + + +-- XXX: RFC6265: the user agent MUST NOT attach more than one Cookie header field +extractCookie :: ByteString -> Maybe CTime -> ByteString -> W.Request -> Maybe (AuthCookie, Cookies) +extractCookie key now name req = do + (_, cookies) <- find ((==) hCookie . fst) $ W.requestHeaders req + (auth, others) <- discriminate cookies + case cookieDecode key auth of + Left _ -> Nothing + Right cookie -> if maybe True (acExpiry cookie >) now + then Just (cookie, others) else Nothing + where discriminate cs = + case partition ((==) name . fst) $ parseCookies cs of + ((_, x):_, xs) -> Just (x, xs) + _ -> Nothing + + +authenticate :: ByteString -> BackendConf -> AuthUser -> ByteString -> W.Application +authenticate key be user path req resp = do + now <- epochTime + let host = fromJust $ W.requestHeaderHost req + domain = pack <$> beCookieDomain be + expiry = now + CTime (beCookieMaxAge be) + authCookie = AuthCookie { acUser = user, acExpiry = expiry } + cookie = WC.def { + WC.setCookieName = pack $ beCookieName be + , WC.setCookieHttpOnly = True + , WC.setCookiePath = Just "/" + , WC.setCookieSameSite = Nothing + , WC.setCookieSecure = True + , WC.setCookieValue = cookieEncode key authCookie + , WC.setCookieDomain = domain + , WC.setCookieExpires = Just . posixSecondsToUTCTime . realToFrac $ expiry + } + resp $ W.responseLBS seeOther303 [ + (hLocation, "https://" <> host <> path) + , ("Set-Cookie", toByteString $ WC.renderSetCookie cookie) + ] "" + + +authorize :: Database -> (AuthCookie, Cookies) -> W.Request -> IO (Maybe W.Request) +authorize db (authCookie, otherCookies) req = do + grps <- userGroups db email domain path method + if null grps then return Nothing + else do + ip <- pack . fromJust . fst <$> getNameInfo [NI_NUMERICHOST] True False (W.remoteHost req) + return . Just $ req { + W.requestHeaders = toList $ + insert "From" (pack email) $ + insert "X-Groups" (BS.intercalate "," grps) $ + insert "X-Given-Name" given $ + insert "X-Family-Name" family $ + insert "X-Forwarded-Proto" "https" $ + insertWith (flip combine) "X-Forwarded-For" ip $ + setCookies otherCookies $ + fromListWith combine $ W.requestHeaders req + } + where + user = acUser authCookie + email = auEmail user + given = pack $ auGivenName user + family = pack $ auFamilyName user + domain = decodeUtf8 . fromJust $ W.requestHeaderHost req + path = decodeUtf8 $ W.rawPathInfo req + method = decodeUtf8 $ W.requestMethod req + combine a b = a <> "," <> b + setCookies [] = delete hCookie + setCookies cs = insert hCookie (toByteString . renderCookies $ cs) + + +forward :: BE.Manager -> W.Application +forward mgr req resp = do + let beReq = BE.defaultRequest + { BE.method = W.requestMethod req + , BE.path = W.rawPathInfo req + , BE.queryString = W.rawQueryString req + , BE.requestHeaders = modifyRequestHeaders $ W.requestHeaders req + , BE.redirectCount = 0 + , BE.decompress = const False + , BE.requestBody = case W.requestBodyLength req of + W.ChunkedBody -> requestBodySourceChunkedIO (sourceRequestBody req) + W.KnownLength l -> requestBodySourceIO (fromIntegral l) (sourceRequestBody req) + } + msg = unpack (BE.method beReq <> " " <> BE.path beReq <> BE.queryString beReq) + Log.debug $ "BACKEND <<< " ++ msg ++ " " ++ show (BE.requestHeaders beReq) + BE.withResponse beReq mgr $ \res -> do + let status = BE.responseStatus res + headers = modifyResponseHeaders $ BE.responseHeaders res + body = mapOutput (Chunk . fromByteString) . bodyReaderSource $ BE.responseBody res + logging = if statusCode status `elem` [ 400, 500 ] then + Log.warn else Log.debug + logging $ "BACKEND >>> " ++ show (statusCode status) ++ " on " ++ msg ++ "\n" + resp $ responseSource status headers body + + +modifyRequestHeaders :: RequestHeaders -> RequestHeaders +modifyRequestHeaders = filter (\(n, _) -> n `notElem` ban) + where + ban = + [ + hConnection + , hContentLength -- XXX to avoid duplicate header + ] + +modifyResponseHeaders :: ResponseHeaders -> ResponseHeaders +modifyResponseHeaders = filter (\(n, _) -> n `notElem` ban) + where + ban = + [ + hConnection + ] + +authenticationRequired :: ByteString -> HashMap Text OAuth2Client -> W.Application +authenticationRequired key oa2 req resp = do + Log.info $ "511 Unauthenticated: " ++ showReq req + resp $ W.responseLBS networkAuthenticationRequired511 [(hContentType, "text/html; charset=utf-8")] page + where + path = W.rawPathInfo req -- FIXME: make it more robust for non-GET or XMLHTTPRequest? + state = State.encode key path + authLink :: Text -> OAuth2Client -> ByteString -> ByteString + authLink provider oa2c html = + let u = oauth2AuthorizeURL oa2c state (redirectURL req provider) + d = pack $ oauth2Description oa2c + in [qc|{html}<p><a href="{u}">Authenticate with {d}</a></p>|] + authHtml = foldrWithKey authLink "" oa2 + page = fromStrict [qc| +<!DOCTYPE html> +<html lang="en"> + <head> + <meta charset="utf-8"> + <title>Authentication required</title> + </head> + <body style="text-align:center;"> + <h1>Authentication required</h1> + {authHtml} + </body> +</html> +|] + + +forbidden :: AuthCookie -> W.Application +forbidden ac req resp = do + Log.info $ "403 Forbidden (" ++ email ++ "): " ++ showReq req + resp $ W.responseLBS forbidden403 [(hContentType, "text/html; charset=utf-8")] page + where + email = auEmail . acUser $ ac + page = fromStrict [qc| +<!DOCTYPE html> +<html lang="en"> + <head> + <meta charset="utf-8"> + <title>Access Denied</title> + </head> + <body> + <h1>Access Denied</h1> + <p>You are currently logged in as <strong>{email}</strong></p> + <p><a href="/.sproxy/logout">Logout</a></p> + </body> +</html> +|] + + +userNotFound :: String -> W.Application +userNotFound email _ resp = do + Log.info $ "404 User not found (" ++ email ++ ")" + resp $ W.responseLBS notFound404 [(hContentType, "text/html; charset=utf-8")] page + where + page = fromStrict [qc| +<!DOCTYPE html> +<html lang="en"> + <head> + <meta charset="utf-8"> + <title>Access Denied</title> + </head> + <body> + <h1>Access Denied</h1> + <p>You are not allowed to login as <strong>{email}</strong></p> + <p><a href="/">Main page</a></p> + </body> +</html> +|] + + +logout :: ByteString -> Maybe ByteString -> W.Application +logout name domain req resp = do + let host = fromJust $ W.requestHeaderHost req + cookie = WC.def { + WC.setCookieName = name + , WC.setCookieHttpOnly = True + , WC.setCookiePath = Just "/" + , WC.setCookieSameSite = Just WC.sameSiteStrict + , WC.setCookieSecure = True + , WC.setCookieValue = "goodbye" + , WC.setCookieDomain = domain + , WC.setCookieExpires = Just . posixSecondsToUTCTime . realToFrac $ CTime 0 + } + resp $ W.responseLBS found302 [ + (hLocation, "https://" <> host) + , ("Set-Cookie", toByteString $ WC.renderSetCookie cookie) + ] "" + + +badRequest ::String -> W.Application +badRequest msg req resp = do + Log.warn $ "400 Bad Request (" ++ msg ++ "): " ++ showReq req + resp $ W.responseLBS badRequest400 [] "Bad Request" + + +notFound ::String -> W.Application +notFound msg req resp = do + Log.warn $ "404 Not Found (" ++ msg ++ "): " ++ showReq req + resp $ W.responseLBS notFound404 [] "Not Found" + + +logException :: W.Middleware +logException app req resp = catch (app req resp) $ \e -> do + Log.error $ "500 Internal Error: " ++ show (e :: SomeException) ++ " on " ++ showReq req + resp $ W.responseLBS internalServerError500 [] "Internal Error" + + +get :: W.Middleware +get app req resp + | W.requestMethod req == methodGet = app req resp + | otherwise = do + Log.warn $ "405 Method Not Allowed: " ++ showReq req + resp $ W.responseLBS methodNotAllowed405 [("Allow", "GET")] "Method Not Allowed" + + +redirectURL :: W.Request -> Text -> ByteString +redirectURL req provider = + "https://" <> fromJust (W.requestHeaderHost req) + <> "/.sproxy/oauth2/" <> encodeUtf8 provider + + +-- XXX: make sure not to reveal the cookie, which can be valid (!) +showReq :: W.Request -> String +showReq req = + unpack ( W.requestMethod req <> " " + <> fromMaybe "<no host>" (W.requestHeaderHost req) + <> W.rawPathInfo req <> W.rawQueryString req <> " " ) + ++ show (fromMaybe "-" $ W.requestHeaderReferer req) ++ " " + ++ show (fromMaybe "-" $ W.requestHeaderUserAgent req) + ++ " from " ++ show (W.remoteHost req) + diff --git a/src/Sproxy/Application/Cookie.hs b/src/Sproxy/Application/Cookie.hs new file mode 100644 index 0000000..07cc162 --- /dev/null +++ b/src/Sproxy/Application/Cookie.hs @@ -0,0 +1,44 @@ +module Sproxy.Application.Cookie ( + AuthCookie(..) +, AuthUser(..) +, cookieDecode +, cookieEncode +) where + +import Data.ByteString (ByteString) +import Foreign.C.Types (CTime(..)) +import qualified Data.Serialize as DS + +import qualified Sproxy.Application.State as State + +data AuthUser = AuthUser { + auEmail :: String +, auGivenName :: String +, auFamilyName :: String +} + +data AuthCookie = AuthCookie { + acUser :: AuthUser +, acExpiry :: CTime +} + +instance DS.Serialize AuthCookie where + put c = DS.put (auEmail u, auGivenName u, auFamilyName u, x) + where u = acUser c + x = (\(CTime i) -> i) $ acExpiry c + get = do + (e, n, f, x) <- DS.get + return AuthCookie { + acUser = AuthUser { auEmail = e, auGivenName = n, auFamilyName = f } + , acExpiry = CTime x + } + + +cookieDecode :: ByteString -> ByteString -> Either String AuthCookie +cookieDecode key d = State.decode key d >>= DS.decode + + +cookieEncode :: ByteString -> AuthCookie -> ByteString +cookieEncode key = State.encode key . DS.encode + + diff --git a/src/Sproxy/Application/OAuth2.hs b/src/Sproxy/Application/OAuth2.hs new file mode 100644 index 0000000..0f7d6e8 --- /dev/null +++ b/src/Sproxy/Application/OAuth2.hs @@ -0,0 +1,18 @@ +{-# LANGUAGE OverloadedStrings #-} +module Sproxy.Application.OAuth2 ( + providers +) where + +import Data.HashMap.Strict (HashMap, fromList) +import Data.Text (Text) + +import Sproxy.Application.OAuth2.Common (OAuth2Provider) +import qualified Sproxy.Application.OAuth2.Google as Google +import qualified Sproxy.Application.OAuth2.LinkedIn as LinkedIn + +providers :: HashMap Text OAuth2Provider +providers = fromList [ + ("google" , Google.provider) + , ("linkedin" , LinkedIn.provider) + ] + diff --git a/src/Sproxy/Application/OAuth2/Common.hs b/src/Sproxy/Application/OAuth2/Common.hs new file mode 100644 index 0000000..07fb759 --- /dev/null +++ b/src/Sproxy/Application/OAuth2/Common.hs @@ -0,0 +1,39 @@ +{-# LANGUAGE OverloadedStrings #-} +module Sproxy.Application.OAuth2.Common ( + AccessTokenBody(..) +, OAuth2Client(..) +, OAuth2Provider +) where + +import Control.Applicative (empty) +import Data.Aeson (FromJSON, parseJSON, Value(Object), (.:)) +import Data.ByteString(ByteString) + +import Sproxy.Application.Cookie (AuthUser) + +data OAuth2Client = OAuth2Client { + oauth2Description :: String +, oauth2AuthorizeURL + :: ByteString -- state + -> ByteString -- redirect url + -> ByteString +, oauth2Authenticate + :: ByteString -- code + -> ByteString -- redirect url + -> IO AuthUser +} + +type OAuth2Provider = (ByteString, ByteString) -> OAuth2Client + +-- | RFC6749. We ignore optional token_type ("Bearer" from Google, omitted by LinkedIn) +-- and expires_in because we don't use them, *and* expires_in creates troubles: +-- it's an integer from Google and string from LinkedIn (sic!) +data AccessTokenBody = AccessTokenBody { + accessToken :: String +} deriving (Eq, Show) + +instance FromJSON AccessTokenBody where + parseJSON (Object v) = AccessTokenBody + <$> v .: "access_token" + parseJSON _ = empty + diff --git a/src/Sproxy/Application/OAuth2/Google.hs b/src/Sproxy/Application/OAuth2/Google.hs new file mode 100644 index 0000000..6b68f44 --- /dev/null +++ b/src/Sproxy/Application/OAuth2/Google.hs @@ -0,0 +1,78 @@ +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE OverloadedStrings #-} +module Sproxy.Application.OAuth2.Google ( + provider +) where + +import Control.Applicative (empty) +import Control.Exception (Exception, throwIO) +import Data.Aeson (FromJSON, decode, parseJSON, Value(Object), (.:)) +import Data.ByteString.Lazy (ByteString) +import Data.Monoid ((<>)) +import Data.Typeable (Typeable) +import Network.HTTP.Types (hContentType) +import Network.HTTP.Types.URI (urlEncode) +import qualified Network.HTTP.Conduit as H + +import Sproxy.Application.Cookie (AuthUser(..)) +import Sproxy.Application.OAuth2.Common (AccessTokenBody(accessToken), OAuth2Client(..), OAuth2Provider) + + +provider :: OAuth2Provider +provider (client_id, client_secret) = + OAuth2Client { + oauth2Description = "Google" + , oauth2AuthorizeURL = \state redirect_uri -> + "https://accounts.google.com/o/oauth2/v2/auth" + <> "?scope=" <> urlEncode True "https://www.googleapis.com/auth/userinfo.email https://www.googleapis.com/auth/userinfo.profile" + <> "&client_id=" <> urlEncode True client_id + <> "&prompt=select_account" + <> "&redirect_uri=" <> urlEncode True redirect_uri + <> "&response_type=code" + <> "&state=" <> urlEncode True state + + , oauth2Authenticate = \code redirect_uri -> do + let treq = H.setQueryString [ + ("client_id" , Just client_id) + , ("client_secret" , Just client_secret) + , ("code" , Just code) + , ("grant_type" , Just "authorization_code") + , ("redirect_uri" , Just redirect_uri) + ] $ (H.parseRequest_ "POST https://www.googleapis.com/oauth2/v4/token") { + H.requestHeaders = [ + (hContentType, "application/x-www-form-urlencoded") + ] + } + mgr <- H.newManager H.tlsManagerSettings + tresp <- H.httpLbs treq mgr + case decode $ H.responseBody tresp of + Nothing -> throwIO $ GoogleException tresp + Just atResp -> do + ureq <- H.parseRequest $ "https://www.googleapis.com/oauth2/v1/userinfo?access_token=" ++ accessToken atResp + uresp <- H.httpLbs ureq mgr + case decode $ H.responseBody uresp of + Nothing -> throwIO $ GoogleException uresp + Just u -> return AuthUser { auEmail = email u, auGivenName = givenName u, auFamilyName = familyName u } + } + + +data GoogleException = GoogleException (H.Response ByteString) + deriving (Show, Typeable) + + +instance Exception GoogleException + + +data GoogleUserInfo = GoogleUserInfo { + email :: String +, givenName :: String +, familyName :: String +} deriving (Eq, Show) + +instance FromJSON GoogleUserInfo where + parseJSON (Object v) = GoogleUserInfo + <$> v .: "email" + <*> v .: "given_name" + <*> v .: "family_name" + parseJSON _ = empty + diff --git a/src/Sproxy/Application/OAuth2/LinkedIn.hs b/src/Sproxy/Application/OAuth2/LinkedIn.hs new file mode 100644 index 0000000..b60afde --- /dev/null +++ b/src/Sproxy/Application/OAuth2/LinkedIn.hs @@ -0,0 +1,83 @@ +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE OverloadedStrings #-} +module Sproxy.Application.OAuth2.LinkedIn ( + provider +) where + +import Control.Applicative (empty) +import Control.Exception (Exception, throwIO) +import Data.Aeson (FromJSON, decode, parseJSON, Value(Object), (.:)) +import Data.ByteString.Char8 (pack) +import Data.ByteString.Lazy (ByteString) +import Data.Monoid ((<>)) +import Data.Typeable (Typeable) +import Network.HTTP.Types (hContentType) +import Network.HTTP.Types.URI (urlEncode) +import qualified Network.HTTP.Conduit as H + +import Sproxy.Application.Cookie (AuthUser(..)) +import Sproxy.Application.OAuth2.Common (AccessTokenBody(accessToken), OAuth2Client(..), OAuth2Provider) + + +provider :: OAuth2Provider +provider (client_id, client_secret) = + OAuth2Client { + oauth2Description = "LinkedIn" + , oauth2AuthorizeURL = \state redirect_uri -> + "https://www.linkedin.com/oauth/v2/authorization" + <> "?scope=r_basicprofile%20r_emailaddress" + <> "&client_id=" <> urlEncode True client_id + <> "&redirect_uri=" <> urlEncode True redirect_uri + <> "&response_type=code" + <> "&state=" <> urlEncode True state + + , oauth2Authenticate = \code redirect_uri -> do + let treq = H.setQueryString [ + ("client_id" , Just client_id) + , ("client_secret" , Just client_secret) + , ("code" , Just code) + , ("grant_type" , Just "authorization_code") + , ("redirect_uri" , Just redirect_uri) + ] $ (H.parseRequest_ "POST https://www.linkedin.com/oauth/v2/accessToken") { + H.requestHeaders = [ + (hContentType, "application/x-www-form-urlencoded") + ] + } + mgr <- H.newManager H.tlsManagerSettings + tresp <- H.httpLbs treq mgr + case decode $ H.responseBody tresp of + Nothing -> throwIO $ LinkedInException tresp + Just atResp -> do + let ureq = (H.parseRequest_ "https://api.linkedin.com/v1/people/\ + \~:(email-address,first-name,last-name)?format=json") { + H.requestHeaders = [ ("Authorization", "Bearer " <> pack (accessToken atResp)) ] + } + uresp <- H.httpLbs ureq mgr + case decode $ H.responseBody uresp of + Nothing -> throwIO $ LinkedInException uresp + Just u -> return AuthUser { auEmail = emailAddress u + , auGivenName = firstName u + , auFamilyName = lastName u } + } + + +data LinkedInException = LinkedInException (H.Response ByteString) + deriving (Show, Typeable) + + +instance Exception LinkedInException + + +data LinkedInUserInfo = LinkedInUserInfo { + emailAddress :: String +, firstName :: String +, lastName :: String +} deriving (Eq, Show) + +instance FromJSON LinkedInUserInfo where + parseJSON (Object v) = LinkedInUserInfo + <$> v .: "emailAddress" + <*> v .: "firstName" + <*> v .: "lastName" + parseJSON _ = empty + diff --git a/src/Sproxy/Application/State.hs b/src/Sproxy/Application/State.hs new file mode 100644 index 0000000..29d9252 --- /dev/null +++ b/src/Sproxy/Application/State.hs @@ -0,0 +1,30 @@ +module Sproxy.Application.State ( + decode +, encode +) where + +import Data.ByteString (ByteString) +import Data.ByteString.Lazy (fromStrict, toStrict) +import Data.Digest.Pure.SHA (hmacSha1, bytestringDigest) +import qualified Data.ByteString.Base64 as Base64 +import qualified Data.Serialize as DS + + +-- FIXME: Compress / decompress ? + + +encode :: ByteString -> ByteString -> ByteString +encode key payload = Base64.encode . DS.encode $ (payload, digest key payload) + + +decode :: ByteString -> ByteString -> Either String ByteString +decode key d = do + (payload, dgst) <- DS.decode =<< Base64.decode d + if dgst /= digest key payload + then Left "junk" + else Right payload + + +digest :: ByteString -> ByteString -> ByteString +digest key payload = toStrict . bytestringDigest $ hmacSha1 (fromStrict key) (fromStrict payload) + diff --git a/src/Sproxy/Config.hs b/src/Sproxy/Config.hs new file mode 100644 index 0000000..30a8bae --- /dev/null +++ b/src/Sproxy/Config.hs @@ -0,0 +1,88 @@ +{-# LANGUAGE OverloadedStrings #-} +module Sproxy.Config ( + BackendConf(..) +, ConfigFile(..) +, OAuth2Conf(..) +) where + +import Control.Applicative (empty) +import Data.Aeson (FromJSON, parseJSON) +import Data.HashMap.Strict (HashMap) +import Data.Int (Int64) +import Data.Text (Text) +import Data.Word (Word16) +import Data.Yaml (Value(Object), (.:), (.:?), (.!=)) + +import Sproxy.Logging (LogLevel(Debug)) + +data ConfigFile = ConfigFile { + cfListen :: Word16 +, cfUser :: String +, cfHome :: FilePath +, cfLogLevel :: LogLevel +, cfSslCert :: FilePath +, cfSslKey :: FilePath +, cfSslCertChain :: [FilePath] +, cfKey :: Maybe FilePath +, cfListen80 :: Maybe Bool +, cfBackends :: [BackendConf] +, cfOAuth2 :: HashMap Text OAuth2Conf +, cfDatabase :: Maybe String +, cfPgPassFile :: Maybe FilePath +, cfHTTP2 :: Bool +} deriving (Show) + +instance FromJSON ConfigFile where + parseJSON (Object m) = ConfigFile <$> + m .:? "listen" .!= 443 + <*> m .:? "user" .!= "sproxy" + <*> m .:? "home" .!= "." + <*> m .:? "log_level" .!= Debug + <*> m .: "ssl_cert" + <*> m .: "ssl_key" + <*> m .:? "ssl_cert_chain" .!= [] + <*> m .:? "key" + <*> m .:? "listen80" + <*> m .: "backends" + <*> m .: "oauth2" + <*> m .:? "database" + <*> m .:? "pgpassfile" + <*> m .:? "http2" .!= True + parseJSON _ = empty + + +data BackendConf = BackendConf { + beName :: String +, beAddress :: String +, bePort :: Maybe Word16 +, beSocket :: Maybe FilePath +, beCookieName :: String +, beCookieDomain :: Maybe String +, beCookieMaxAge :: Int64 +, beConnCount :: Int +} deriving (Show) + +instance FromJSON BackendConf where + parseJSON (Object m) = BackendConf <$> + m .:? "name" .!= "*" + <*> m .:? "address" .!= "127.0.0.1" + <*> m .:? "port" + <*> m .:? "socket" + <*> m .:? "cookie_name" .!= "sproxy" + <*> m .:? "cookie_domain" + <*> m .:? "cookie_max_age" .!= (7 * 24 * 60 * 60) + <*> m .:? "conn_count" .!= 32 + parseJSON _ = empty + + +data OAuth2Conf = OAuth2Conf { + oa2ClientId :: String +, oa2ClientSecret :: FilePath +} deriving (Show) + +instance FromJSON OAuth2Conf where + parseJSON (Object m) = OAuth2Conf <$> + m .: "client_id" + <*> m .: "client_secret" + parseJSON _ = empty + diff --git a/src/Sproxy/Logging.hs b/src/Sproxy/Logging.hs new file mode 100644 index 0000000..651a73a --- /dev/null +++ b/src/Sproxy/Logging.hs @@ -0,0 +1,99 @@ +module Sproxy.Logging ( + LogLevel(..) +, debug +, error +, info +, level +, start +, warn +) where + +import Prelude hiding (error) + +import Control.Applicative (empty) +import Control.Concurrent (forkIO) +import Control.Concurrent.Chan (Chan, newChan, readChan, writeChan) +import Control.Monad (forever, when) +import Data.Aeson (FromJSON, ToJSON) +import Data.Char (toLower) +import Data.IORef (IORef, newIORef, readIORef, writeIORef) +import System.IO (hPrint, stderr) +import System.IO.Unsafe (unsafePerformIO) +import Text.Read (readMaybe) +import qualified Data.Aeson as JSON +import qualified Data.Text as T + +start :: LogLevel -> IO () +start None = return () +start lvl = do + writeIORef logLevel lvl + ch <- readIORef chanRef + _ <- forkIO . forever $ readChan ch >>= hPrint stderr + return () + +info :: String -> IO () +info = send . Message Info + +warn:: String -> IO () +warn = send . Message Warning + +error:: String -> IO () +error = send . Message Error + +debug :: String -> IO () +debug = send . Message Debug + + +send :: Message -> IO () +send msg@(Message l _) = do + lvl <- level + when (l <= lvl) $ do + ch <- readIORef chanRef + writeChan ch msg + +{-# NOINLINE chanRef #-} +chanRef :: IORef (Chan Message) +chanRef = unsafePerformIO (newChan >>= newIORef) + +{-# NOINLINE logLevel #-} +logLevel :: IORef LogLevel +logLevel = unsafePerformIO (newIORef None) + +level :: IO LogLevel +level = readIORef logLevel + + +data LogLevel = None | Error | Warning | Info | Debug + deriving (Enum, Ord, Eq) + +instance Show LogLevel where + show None = "NONE" + show Error = "ERROR" + show Warning = "WARN" + show Info = "INFO" + show Debug = "DEBUG" + +instance Read LogLevel where + readsPrec _ s + | l == "none" = [ (None, "") ] + | l == "error" = [ (Error, "") ] + | l == "warn" = [ (Warning, "") ] + | l == "info" = [ (Info, "") ] + | l == "debug" = [ (Debug, "") ] + | otherwise = [ ] + where l = map toLower s + +instance ToJSON LogLevel where + toJSON = JSON.String . T.pack . show + +instance FromJSON LogLevel where + parseJSON (JSON.String s) = + maybe (fail $ "unknown log level: " ++ show s) return (readMaybe . T.unpack $ s) + parseJSON _ = empty + + +data Message = Message LogLevel String + +instance Show Message where + show (Message lvl str) = show lvl ++ ": " ++ str + diff --git a/src/Sproxy/Server.hs b/src/Sproxy/Server.hs new file mode 100644 index 0000000..bd2af17 --- /dev/null +++ b/src/Sproxy/Server.hs @@ -0,0 +1,190 @@ +module Sproxy.Server ( + server +) where + +import Control.Concurrent (forkIO) +import Control.Exception (bracketOnError) +import Control.Monad (void, when) +import Data.ByteString as BS (hGetLine, readFile) +import Data.ByteString.Char8 (pack) +import Data.HashMap.Strict as HM (fromList, lookup, toList) +import Data.Maybe (fromMaybe) +import Data.Text (Text) +import Data.Word (Word16) +import Data.Yaml (decodeFileEither) +import Network.HTTP.Client (Manager, ManagerSettings(..), defaultManagerSettings, newManager, socketConnection) +import Network.HTTP.Client.Internal (Connection) +import Network.Socket ( Family(AF_INET, AF_UNIX), SockAddr(SockAddrInet, SockAddrUnix), + SocketOption(ReuseAddr), SocketType(Stream), bind, close, connect, inet_addr, + listen, maxListenQueue, setSocketOption, socket ) +import Network.Wai.Handler.WarpTLS (tlsSettingsChain, runTLSSocket) +import Network.Wai.Handler.Warp (defaultSettings, setHTTP2Disabled, runSettingsSocket) +import System.Entropy (getEntropy) +import System.Environment (setEnv) +import System.Exit (exitFailure) +import System.FilePath.Glob (compile) +import System.IO (IOMode(ReadMode), hIsEOF, hPutStrLn, stderr, withFile) +import System.Posix.User ( GroupEntry(..), UserEntry(..), + getAllGroupEntries, getRealUserID, + getUserEntryForName, setGroupID, setGroups, setUserID ) + +import Sproxy.Application (sproxy, redirect) +import Sproxy.Application.OAuth2.Common (OAuth2Client) +import Sproxy.Config (BackendConf(..), ConfigFile(..), OAuth2Conf(..)) +import qualified Sproxy.Application.OAuth2 as OAuth2 +import qualified Sproxy.Logging as Log +import qualified Sproxy.Server.DB as DB + + +server :: FilePath -> IO () +server configFile = do + cf <- readConfigFile configFile + Log.start $ cfLogLevel cf + Log.debug $ show cf + + sock <- socket AF_INET Stream 0 + setSocketOption sock ReuseAddr 1 + bind sock $ SockAddrInet (fromIntegral $ cfListen cf) 0 + + maybe80 <- if fromMaybe (443 == cfListen cf) (cfListen80 cf) + then do + sock80 <- socket AF_INET Stream 0 + setSocketOption sock80 ReuseAddr 1 + bind sock80 $ SockAddrInet 80 0 + return (Just sock80) + else + return Nothing + + uid <- getRealUserID + when (0 == uid) $ do + let user = cfUser cf + Log.info $ "switching to user " ++ show user + u <- getUserEntryForName user + groupIDs <- map groupID . filter (elem user . groupMembers) + <$> getAllGroupEntries + setGroups groupIDs + setGroupID $ userGroupID u + setUserID $ userID u + + case cfPgPassFile cf of + Nothing -> return () + Just f -> do + Log.info $ "pgpassfile: " ++ show f + setEnv "PGPASSFILE" f + + db <- DB.start (cfHome cf) (newDataSource cf) + + key <- maybe + (Log.info "using new random key" >> getEntropy 32) + (\f -> Log.info ("reading key from " ++ f) >> BS.readFile f) + (cfKey cf) + + case maybe80 of + Nothing -> return () + Just sock80 -> do + Log.info "listening on port 80 (HTTP redirect)" + listen sock80 maxListenQueue + void . forkIO $ runSettingsSocket defaultSettings sock80 (redirect $ cfListen cf) + + oauth2clients <- HM.fromList <$> mapM newOAuth2Client (HM.toList (cfOAuth2 cf)) + + backends <- + mapM (\be -> do + m <- newBackendManager be + return (compile $ beName be, be, m) + ) $ cfBackends cf + + let + settings = + (if cfHTTP2 cf then id else setHTTP2Disabled) + defaultSettings + + -- XXX 2048 is from bindPortTCP from streaming-commons used internally by runTLS. + -- XXX Since we don't call runTLS, we listen socket here with the same options. + Log.info $ "listening on port " ++ show (cfListen cf) ++ " (HTTPS)" + listen sock (max 2048 maxListenQueue) + runTLSSocket + (tlsSettingsChain (cfSslCert cf) (cfSslCertChain cf) (cfSslKey cf)) + settings + sock + (sproxy key db oauth2clients backends) + + +newDataSource :: ConfigFile -> Maybe DB.DataSource +newDataSource cf = + case cfDatabase cf of + Just str -> Just $ DB.PostgreSQL str + Nothing -> Nothing + + +newOAuth2Client :: (Text, OAuth2Conf) -> IO (Text, OAuth2Client) +newOAuth2Client (name, cfg) = + case HM.lookup name OAuth2.providers of + Nothing -> do Log.error $ "OAuth2 provider " ++ show name ++ " is not supported" + exitFailure + Just provider -> do + Log.info $ "oauth2: adding " ++ show name + client_secret <- withFile secret_file ReadMode $ \h -> do + empty <- hIsEOF h + if empty then do + Log.error $ "oauth2: empty secret file for " + ++ show name ++ ": " ++ show secret_file + return $ pack "" + else BS.hGetLine h + return (name, provider (pack client_id, client_secret)) + where client_id = oa2ClientId cfg + secret_file = oa2ClientSecret cfg + + +newBackendManager :: BackendConf -> IO Manager +newBackendManager be = do + openConn <- + case (beSocket be, bePort be) of + (Just f, Nothing) -> do + Log.info $ "backend `" ++ beName be ++ "' on UNIX socket " ++ f + return $ openUnixSocketConnection f + + (Nothing, Just n) -> do + Log.info $ "backend `" ++ beName be ++ "' on " ++ beAddress be ++ ":" ++ show n + return $ openTCPConnection (beAddress be) n + + _ -> do + Log.error "either backend port number or UNIX socket path is required." + exitFailure + + newManager defaultManagerSettings { + managerRawConnection = return $ \_ _ _ -> openConn + , managerConnCount = beConnCount be + } + + +openUnixSocketConnection :: FilePath -> IO Connection +openUnixSocketConnection f = + bracketOnError + (socket AF_UNIX Stream 0) + close + (\s -> do + connect s (SockAddrUnix f) + socketConnection s 8192) + + +openTCPConnection :: String -> Word16 -> IO Connection +openTCPConnection addr port = + bracketOnError + (socket AF_INET Stream 0) + close + (\s -> do + a <- inet_addr addr + connect s (SockAddrInet (fromIntegral port) a) + socketConnection s 8192) + + +readConfigFile :: FilePath -> IO ConfigFile +readConfigFile f = do + r <- decodeFileEither f + case r of + Left e -> do + hPutStrLn stderr $ "FATAL: " ++ f ++ ": " ++ show e + exitFailure + Right cf -> return cf + diff --git a/src/Sproxy/Server/DB.hs b/src/Sproxy/Server/DB.hs new file mode 100644 index 0000000..b760afc --- /dev/null +++ b/src/Sproxy/Server/DB.hs @@ -0,0 +1,189 @@ +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE QuasiQuotes #-} +module Sproxy.Server.DB ( + Database +, DataSource(..) +, userExists +, userGroups +, start +) where + +import Control.Concurrent (forkIO, threadDelay) +import Control.Exception (SomeException, bracket, catch, finally) +import Control.Monad (forever, void) +import Data.ByteString (ByteString) +import Data.ByteString.Char8 (pack) +import Data.Pool (Pool, createPool, withResource) +import Data.Text (Text, toLower, unpack) +import Data.Text.Encoding (encodeUtf8) +import Database.SQLite.Simple (NamedParam((:=))) +import Text.InterpolatedString.Perl6 (q, qc) +import qualified Database.PostgreSQL.Simple as PG +import qualified Database.SQLite.Simple as SQLite + +import qualified Sproxy.Logging as Log + + +type Database = Pool SQLite.Connection + +data DataSource = PostgreSQL String -- | File FilePath + +{- TODO: + - Hash remote tables and update the local only when the remote change + - Switch to REGEX + - Generalize sync procedures for different tables + -} + +start :: FilePath -> Maybe DataSource -> IO Database +start home ds = do + Log.info $ "home directory: " ++ show home + db <- createPool + (do c <- SQLite.open $ home ++ "/sproxy.sqlite3" + lvl <- Log.level + SQLite.setTrace c (if lvl == Log.Debug then Just $ Log.debug . unpack else Nothing) + return c) + SQLite.close + 1 -- stripes + 3600 -- keep alive (seconds). FIXME: no much sense as it's a local file + 128 -- max connections. FIXME: make configurable? + + withResource db $ \c -> SQLite.execute_ c "PRAGMA journal_mode=WAL" + populate db ds + return db + + +userExists :: Database -> String -> IO Bool +userExists db email = do + r <- withResource db $ \c -> fmap SQLite.fromOnly <$> SQLite.queryNamed c + "SELECT EXISTS (SELECT 1 FROM group_member WHERE :email LIKE email LIMIT 1)" + [ ":email" := email ] + return $ head r + + +userGroups :: Database -> String -> Text -> Text -> Text -> IO [ByteString] +userGroups db email domain path method = + withResource db $ \c -> fmap (encodeUtf8 . SQLite.fromOnly) <$> SQLite.queryNamed c [q| + SELECT gm."group" FROM group_privilege gp JOIN group_member gm ON gm."group" = gp."group" + WHERE :email LIKE gm.email + AND :domain LIKE gp.domain + AND gp.privilege IN ( + SELECT privilege FROM privilege_rule + WHERE :domain LIKE domain + AND :path LIKE path + AND method = :method + ORDER BY length(path) - length(replace(path, '/', '')) DESC LIMIT 1 + ) + |] [ ":email" := email -- XXX always in lower case + , ":domain" := toLower domain + , ":path" := path + , ":method" := method -- XXX case-sensitive by RFC2616 + ] + + +populate :: Database -> Maybe DataSource -> IO () + +populate db Nothing = do + Log.warn "db: no data source defined" + withResource db $ \c -> SQLite.withTransaction c $ do + createGroupMember c + createGroupPrivilege c + createPrivilegeRule c + +-- XXX We keep only required minimum of the data, without any integrity check. +-- XXX Integrity check should be done somewhere else, e. g. in the master PostgreSQL database, +-- XXX or during importing the config file. +populate db (Just (PostgreSQL connstr)) = + void . forkIO . forever . flip finally (7 `minutes` threadDelay) + . logException $ do + Log.info $ "db: synchronizing with " ++ show connstr + withResource db $ \c -> SQLite.withTransaction c $ + bracket (PG.connectPostgreSQL $ pack connstr) PG.close $ + \pg -> PG.withTransaction pg $ do + + Log.info "db: syncing group_member" + dropGroupMember c + createGroupMember c + PG.forEach_ pg + [q|SELECT "group", lower(email) FROM group_member|] $ \r -> + SQLite.execute c + [q|INSERT INTO group_member("group", email) VALUES (?, ?)|] + (r :: (Text, Text)) + count c "group_member" + + Log.info "db: syncing group_privilege" + dropGroupPrivilege c + createGroupPrivilege c + PG.forEach_ pg + [q|SELECT "group", lower(domain), privilege FROM group_privilege|] $ \r -> + SQLite.execute c + [q|INSERT INTO group_privilege("group", domain, privilege) VALUES (?, ?, ?)|] + (r :: (Text, Text, Text)) + count c "group_privilege" + + Log.info "db: syncing privilege_rule" + dropPrivilegeRule c + createPrivilegeRule c + PG.forEach_ pg + [q|SELECT lower(domain), privilege, path, method FROM privilege_rule|] $ \r -> + SQLite.execute c + [q|INSERT INTO privilege_rule(domain, privilege, path, method) VALUES (?, ?, ?, ?)|] + (r :: (Text, Text, Text, Text)) + count c "privilege_rule" + + +dropGroupMember :: SQLite.Connection -> IO () +dropGroupMember c = SQLite.execute_ c "DROP TABLE IF EXISTS group_member" + +createGroupMember :: SQLite.Connection -> IO () +createGroupMember c = SQLite.execute_ c [q| + CREATE TABLE IF NOT EXISTS group_member ( + "group" TEXT, + email TEXT, + PRIMARY KEY ("group", email) + ) +|] + + +dropGroupPrivilege :: SQLite.Connection -> IO () +dropGroupPrivilege c = SQLite.execute_ c "DROP TABLE IF EXISTS group_privilege" + +createGroupPrivilege :: SQLite.Connection -> IO () +createGroupPrivilege c = SQLite.execute_ c [q| + CREATE TABLE IF NOT EXISTS group_privilege ( + "group" TEXT, + domain TEXT, + privilege TEXT, + PRIMARY KEY ("group", domain, privilege) + ) +|] + + +dropPrivilegeRule :: SQLite.Connection -> IO () +dropPrivilegeRule c = SQLite.execute_ c "DROP TABLE IF EXISTS privilege_rule" + +createPrivilegeRule :: SQLite.Connection -> IO () +createPrivilegeRule c = SQLite.execute_ c [q| + CREATE TABLE IF NOT EXISTS privilege_rule ( + domain TEXT, + privilege TEXT, + path TEXT, + method TEXT, + PRIMARY KEY (domain, path, method) + ) +|] + + +count :: SQLite.Connection -> String -> IO () +count c table = do + r <- fmap SQLite.fromOnly <$> SQLite.query_ c [qc|SELECT COUNT(*) FROM {table}|] + Log.info $ "db: " ++ table ++ " rows: " ++ show (head r :: Integer) + + +logException :: IO () -> IO () +logException a = catch a $ \e -> + Log.error $ "db: " ++ show (e :: SomeException) + + +minutes :: Int -> (Int -> IO ()) -> IO () +minutes us f = f $ us * 60 * 1000000 + |