diff --git a/build.gradle b/build.gradle
index 5175b98..19ea743 100644
--- a/build.gradle
+++ b/build.gradle
@@ -46,6 +46,7 @@ dependencies {
implementation group: "org.http4k", name: "http4k-client-okhttp"
implementation group: "org.http4k", name: "http4k-metrics-micrometer"
implementation group: "org.http4k", name: "http4k-server-netty"
+ implementation group: "io.netty", name: "netty-codec-haproxy"
implementation group: "io.netty", name: "netty-transport-native-epoll", classifier: "linux-x86_64"
implementation group: "io.netty.incubator", name: "netty-incubator-transport-native-io_uring", version: "0.0.3.Final", classifier: "linux-x86_64"
testImplementation group: "org.http4k", name: "http4k-testing-kotest"
diff --git a/settings.sample.yaml b/settings.sample.yaml
index 454d877..0325ea4 100644
--- a/settings.sample.yaml
+++ b/settings.sample.yaml
@@ -80,6 +80,13 @@ server_settings:
# 0 defaults to (2 * your available processors)
# https://docs.oracle.com/en/java/javase/11/docs/api/java.base/java/lang/Runtime.html#availableProcessors()
threads: 0
+ # Whether to enable support for HAProxy Proxy Protocol
+ # If using a reverse proxy to forward requests to MD@H via
+ # ssl passthrough, you can use Proxy Protocol to preserve
+ # original IP if your reverse proxy supports it. This
+ # will allow geo location metrics to work correctly.
+ # https://www.haproxy.com/blog/haproxy/proxy-protocol/
+ enable_proxy_protocol: false
# Settings intended for advanced use cases or tinkering
diff --git a/src/main/kotlin/mdnet/netty/ApplicationNetty.kt b/src/main/kotlin/mdnet/netty/ApplicationNetty.kt
index 7a4da2d..b208f1a 100644
--- a/src/main/kotlin/mdnet/netty/ApplicationNetty.kt
+++ b/src/main/kotlin/mdnet/netty/ApplicationNetty.kt
@@ -19,6 +19,7 @@ along with this MangaDex@Home. If not, see .
package mdnet.netty
import io.netty.bootstrap.ServerBootstrap
+import io.netty.buffer.ByteBuf
import io.netty.channel.*
import io.netty.channel.epoll.Epoll
import io.netty.channel.epoll.EpollEventLoopGroup
@@ -27,6 +28,12 @@ import io.netty.channel.nio.NioEventLoopGroup
import io.netty.channel.socket.SocketChannel
import io.netty.channel.socket.nio.NioServerSocketChannel
import io.netty.handler.codec.DecoderException
+import io.netty.handler.codec.ProtocolDetectionResult
+import io.netty.handler.codec.ProtocolDetectionState
+import io.netty.handler.codec.haproxy.HAProxyMessage
+import io.netty.handler.codec.haproxy.HAProxyMessageDecoder
+import io.netty.handler.codec.haproxy.HAProxyProtocolVersion
+import io.netty.handler.codec.http.FullHttpRequest
import io.netty.handler.codec.http.HttpObjectAggregator
import io.netty.handler.codec.http.HttpServerCodec
import io.netty.handler.codec.http.HttpServerKeepAliveHandler
@@ -43,7 +50,10 @@ import io.netty.handler.traffic.TrafficCounter
import io.netty.incubator.channel.uring.IOUring
import io.netty.incubator.channel.uring.IOUringEventLoopGroup
import io.netty.incubator.channel.uring.IOUringServerSocketChannel
+import io.netty.util.AttributeKey
+import io.netty.util.AttributeMap
import io.netty.util.DomainWildcardMappingBuilder
+import io.netty.util.ReferenceCountUtil
import io.netty.util.concurrent.DefaultEventExecutorGroup
import io.netty.util.internal.SystemPropertyUtil
import mdnet.Constants
@@ -173,6 +183,35 @@ class Netty(
.channelFactory(transport.factory)
.childHandler(object : ChannelInitializer() {
public override fun initChannel(ch: SocketChannel) {
+ if (serverSettings.enableProxyProtocol) {
+ ch.pipeline().addLast(
+ "proxyProtocol",
+ object : ChannelInboundHandlerAdapter() {
+ override fun channelRead(ctx: ChannelHandlerContext, msg: Any) {
+ if (msg is ByteBuf) {
+ // Since the builtin `HAProxyMessageDecoder` will break non Proxy Protocol requests
+ // we need to use its detection capabilities to only add it when needed.
+ val result: ProtocolDetectionResult = HAProxyMessageDecoder.detectProtocol(msg)
+ if (result.state() == ProtocolDetectionState.DETECTED) {
+ ctx.pipeline().addAfter("proxyProtocol", null, HAProxyMessageDecoder())
+ ctx.pipeline().remove(this)
+ }
+ }
+ super.channelRead(ctx, msg)
+ }
+ }
+ )
+ ch.pipeline().addLast(
+ "saveOriginalIp",
+ object : SimpleChannelInboundHandler() {
+ override fun channelRead0(ctx: ChannelHandlerContext, msg: HAProxyMessage) {
+ // Store proxy IP in an attribute for later use after HTTP request is extracted.
+ // Using an attribute ensures the value is scoped to this channel.
+ (ctx as AttributeMap).attr(HAPROXY_SOURCE).set(msg.sourceAddress())
+ }
+ }
+ )
+ }
ch.pipeline().addLast(
"ssl",
SniHandler(DomainWildcardMappingBuilder(sslContext).build())
@@ -206,6 +245,26 @@ class Netty(
ch.pipeline().addLast("keepAlive", HttpServerKeepAliveHandler())
ch.pipeline().addLast("aggregator", HttpObjectAggregator(65536))
+ if (serverSettings.enableProxyProtocol) {
+ ch.pipeline().addLast(
+ "setForwardHeader",
+ object : SimpleChannelInboundHandler(false) {
+ override fun channelRead0(ctx: ChannelHandlerContext, request: FullHttpRequest) {
+ // The geo location code already supports the `Forwarded header so setting
+ // it is the easiest way to introduce the original IP downstream.
+ if ((ctx as AttributeMap).hasAttr(HAPROXY_SOURCE)) {
+ val addr = (ctx as AttributeMap).attr(HAPROXY_SOURCE).get()
+ request.headers().set("Forwarded", addr)
+ }
+ // Since we're modifying the request without handling it, we must
+ // call retain to ensure it will still be available downstream.
+ ReferenceCountUtil.retain(request)
+ ctx.fireChannelRead(request)
+ }
+ }
+ )
+ }
+
ch.pipeline().addLast("burstLimiter", burstLimiter)
ch.pipeline().addLast(
@@ -256,6 +315,7 @@ class Netty(
companion object {
private val LOGGER = LoggerFactory.getLogger(Netty::class.java)
+ private val HAPROXY_SOURCE = AttributeKey.newInstance("haproxy_source")
}
}
diff --git a/src/main/kotlin/mdnet/settings/ClientSettings.kt b/src/main/kotlin/mdnet/settings/ClientSettings.kt
index d3fbbab..3c37773 100644
--- a/src/main/kotlin/mdnet/settings/ClientSettings.kt
+++ b/src/main/kotlin/mdnet/settings/ClientSettings.kt
@@ -43,6 +43,7 @@ data class ServerSettings(
val externalIp: String? = null,
val port: Int = 443,
val threads: Int = 0,
+ val enableProxyProtocol: Boolean = false,
)
@JsonNaming(PropertyNamingStrategies.SnakeCaseStrategy::class)