1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
|
module Sproxy.Server (
server
) where
import Control.Concurrent (forkIO)
import Control.Exception (bracketOnError)
import Control.Monad (void, when)
import Data.ByteString as BS (hGetLine, readFile)
import Data.ByteString.Char8 (pack)
import Data.HashMap.Strict as HM (fromList, lookup, toList)
import Data.Maybe (fromMaybe)
import Data.Text (Text)
import Data.Word (Word16)
import Data.Yaml (decodeFileEither)
import Network.HTTP.Client (Manager, ManagerSettings(..), defaultManagerSettings, newManager, socketConnection)
import Network.HTTP.Client.Internal (Connection)
import Network.Socket ( Family(AF_INET, AF_UNIX), SockAddr(SockAddrInet, SockAddrUnix),
SocketOption(ReuseAddr), SocketType(Stream), bind, close, connect, inet_addr,
listen, maxListenQueue, setSocketOption, socket )
import Network.Wai.Handler.WarpTLS (tlsSettingsChain, runTLSSocket)
import Network.Wai.Handler.Warp (defaultSettings, setHTTP2Disabled, runSettingsSocket)
import System.Entropy (getEntropy)
import System.Environment (setEnv)
import System.Exit (exitFailure)
import System.FilePath.Glob (compile)
import System.IO (IOMode(ReadMode), hIsEOF, hPutStrLn, stderr, withFile)
import System.Posix.User ( GroupEntry(..), UserEntry(..),
getAllGroupEntries, getRealUserID,
getUserEntryForName, setGroupID, setGroups, setUserID )
import Sproxy.Application (sproxy, redirect)
import Sproxy.Application.OAuth2.Common (OAuth2Client)
import Sproxy.Config (BackendConf(..), ConfigFile(..), OAuth2Conf(..))
import qualified Sproxy.Application.OAuth2 as OAuth2
import qualified Sproxy.Logging as Log
import qualified Sproxy.Server.DB as DB
-- | Run the proxy described by the YAML configuration file at the given path.
--
-- Startup sequence (the order is significant):
--
--   1. Read the configuration and start logging at 'cfLogLevel'.
--   2. Bind the HTTPS socket on 'cfListen' and, when 'cfListen80' says so
--      (defaulting to "yes" exactly when 'cfListen' is 443), a second socket
--      on port 80 for plain-HTTP redirects.  Both are bound before any
--      privilege drop -- presumably because ports below 1024 need root.
--   3. If the real user ID is 0, switch supplementary groups, group ID and
--      user ID to those of 'cfUser'.
--   4. Export @PGPASSFILE@ if 'cfPgPassFile' is set, start the database
--      layer, and read the key from 'cfKey' or generate 32 random bytes.
--   5. Start the port-80 redirect server in a background thread, build the
--      OAuth2 clients and per-backend connection managers, and finally serve
--      HTTPS on the calling thread with 'runTLSSocket'.
server :: FilePath -> IO ()
server configFile = do
  cf <- readConfigFile configFile
  Log.start $ cfLogLevel cf
  Log.debug $ show cf

  -- Sockets are bound while we may still be root.
  sock <- bindTCP (fromIntegral $ cfListen cf)
  let wantRedirect = fromMaybe (443 == cfListen cf) (cfListen80 cf)
  maybe80 <- if wantRedirect then Just <$> bindTCP 80 else return Nothing

  isRoot <- (0 ==) <$> getRealUserID
  when isRoot $ switchToUser (cfUser cf)

  mapM_ exportPgPassFile (cfPgPassFile cf)
  db <- DB.start (cfHome cf) (newDataSource cf)
  key <- case cfKey cf of
    Nothing -> Log.info "using new random key" >> getEntropy 32
    Just f -> Log.info ("reading key from " ++ f) >> BS.readFile f

  mapM_ (serveRedirect (cfListen cf)) maybe80

  oauth2clients <-
    HM.fromList <$> traverse newOAuth2Client (HM.toList $ cfOAuth2 cf)
  backends <- traverse backendEntry (cfBackends cf)

  let settings
        | cfHTTP2 cf = defaultSettings
        | otherwise = setHTTP2Disabled defaultSettings
  Log.info $ "listening on port " ++ show (cfListen cf) ++ " (HTTPS)"
  -- XXX 2048 is the backlog used by bindPortTCP (streaming-commons) inside
  -- runTLS.  We call runTLSSocket on our own socket instead of runTLS, so we
  -- listen here with the same backlog.
  listen sock (max 2048 maxListenQueue)
  runTLSSocket
    (tlsSettingsChain (cfSslCert cf) (cfSslCertChain cf) (cfSslKey cf))
    settings
    sock
    (sproxy key db oauth2clients backends)
  where
    -- Fresh IPv4 stream socket with SO_REUSEADDR set, bound to the given
    -- port on host address 0 (all interfaces).
    bindTCP port = do
      s <- socket AF_INET Stream 0
      setSocketOption s ReuseAddr 1
      bind s $ SockAddrInet port 0
      return s

    -- Adopt every group that lists the user as a member, then the user's
    -- primary group, then its user ID.  The user ID goes last, presumably
    -- because the group changes need privileges it would give up.
    switchToUser user = do
      Log.info $ "switching to user " ++ show user
      entry <- getUserEntryForName user
      allGroups <- getAllGroupEntries
      setGroups [groupID g | g <- allGroups, user `elem` groupMembers g]
      setGroupID $ userGroupID entry
      setUserID $ userID entry

    -- Point the PostgreSQL client at the configured password file.
    exportPgPassFile path = do
      Log.info $ "pgpassfile: " ++ show path
      setEnv "PGPASSFILE" path

    -- Plain-HTTP server on the port-80 socket that redirects to the HTTPS
    -- port.  It runs in a forked thread whose ThreadId is discarded.
    serveRedirect httpsPort sock80 = do
      Log.info "listening on port 80 (HTTP redirect)"
      listen sock80 maxListenQueue
      void . forkIO $
        runSettingsSocket defaultSettings sock80 (redirect httpsPort)

    -- (compiled name glob, backend config, its connection manager).
    backendEntry be =
      (\m -> (compile $ beName be, be, m)) <$> newBackendManager be
-- | The database data source named by the configuration, if any.
--
-- Returns 'Nothing' when 'cfDatabase' is unset.  Otherwise it wraps the
-- configured value in the 'DB.PostgreSQL' constructor.
newDataSource :: ConfigFile -> Maybe DB.DataSource
newDataSource cf = DB.PostgreSQL <$> cfDatabase cf
-- | Instantiate the OAuth2 client for one configured provider.
--
-- The provider name must be a key of 'OAuth2.providers'.  An unknown name is
-- logged as an error and terminates the process with 'exitFailure'.
--
-- The client secret is the first line of the file named by 'oa2ClientSecret'.
-- If that file is empty, an error is logged and an empty secret is used
-- (startup continues).  The client ID from 'oa2ClientId' is converted with
-- Char8 'pack'.  Returns the provider name paired with the built client.
newOAuth2Client :: (Text, OAuth2Conf) -> IO (Text, OAuth2Client)
newOAuth2Client (name, cfg) =
  maybe unsupported build (HM.lookup name OAuth2.providers)
  where
    clientId = oa2ClientId cfg
    secretFile = oa2ClientSecret cfg

    unsupported = do
      Log.error $ "OAuth2 provider " ++ show name ++ " is not supported"
      exitFailure

    build provider = do
      Log.info $ "oauth2: adding " ++ show name
      secret <- withFile secretFile ReadMode readSecret
      return (name, provider (pack clientId, secret))

    readSecret h = do
      atEof <- hIsEOF h
      if atEof
        then do
          Log.error $ "oauth2: empty secret file for "
            ++ show name ++ ": " ++ show secretFile
          return $ pack ""
        else BS.hGetLine h
-- | Create an HTTP connection manager dedicated to one backend.
--
-- Exactly one of 'beSocket' (UNIX socket path) or 'bePort' (TCP port on
-- 'beAddress') must be configured.  Setting both, or neither, is logged with
-- the backend's name and terminates the process with 'exitFailure'.
--
-- Every connection the manager opens goes to that one fixed endpoint: the
-- three arguments http-client passes to the 'managerRawConnection' action are
-- ignored.  'managerConnCount' is taken from 'beConnCount'.
newBackendManager :: BackendConf -> IO Manager
newBackendManager be = do
  openConn <-
    case (beSocket be, bePort be) of
      (Just f, Nothing) -> do
        Log.info $ "backend `" ++ beName be ++ "' on UNIX socket " ++ f
        return $ openUnixSocketConnection f
      (Nothing, Just n) -> do
        Log.info $ "backend `" ++ beName be ++ "' on "
          ++ beAddress be ++ ":" ++ show n
        return $ openTCPConnection (beAddress be) n
      -- Previously both misconfigurations shared one message that neither
      -- named the backend nor fit the "both given" case.
      (Just _, Just _) -> do
        Log.error $ "backend `" ++ beName be
          ++ "': port number and UNIX socket path are mutually exclusive."
        exitFailure
      (Nothing, Nothing) -> do
        Log.error $ "backend `" ++ beName be
          ++ "': either backend port number or UNIX socket path is required."
        exitFailure
  newManager defaultManagerSettings
    { managerRawConnection = return $ \_ _ _ -> openConn
    , managerConnCount = beConnCount be
    }
-- | Connect a stream socket to the UNIX-domain socket at the given path and
-- wrap it as an http-client 'Connection' via 'socketConnection', passing
-- 8192 as its size argument.
--
-- If connecting or wrapping throws, 'bracketOnError' closes the socket
-- before the exception propagates.
openUnixSocketConnection :: FilePath -> IO Connection
openUnixSocketConnection path =
  bracketOnError (socket AF_UNIX Stream 0) close $ \s ->
    connect s (SockAddrUnix path) >> socketConnection s 8192
-- | Open an IPv4 TCP connection to @addr:port@ and wrap it as an
-- http-client 'Connection' via 'socketConnection', passing 8192 as its size
-- argument.
--
-- The address string is converted with 'inet_addr' on every call, after the
-- socket has been created.  NOTE(review): 'inet_addr' presumably accepts
-- only numeric IPv4 notation, not host names -- confirm against the
-- @network@ docs.  If parsing, connecting or wrapping throws,
-- 'bracketOnError' closes the socket.
openTCPConnection :: String -> Word16 -> IO Connection
openTCPConnection addr port =
  bracketOnError (socket AF_INET Stream 0) close $ \s -> do
    hostAddr <- inet_addr addr
    let target = SockAddrInet (fromIntegral port) hostAddr
    connect s target
    socketConnection s 8192
-- | Decode the YAML configuration file at the given path.
--
-- If 'decodeFileEither' reports an error, prints
-- @FATAL: \<path\>: \<error\>@ to stderr and exits with 'exitFailure'.
readConfigFile :: FilePath -> IO ConfigFile
readConfigFile path = decodeFileEither path >>= either bail return
  where
    bail err = do
      hPutStrLn stderr $ "FATAL: " ++ path ++ ": " ++ show err
      exitFailure
|