创建自定义 Spring Cloud Gateway 过滤器 - spring.io

22-08-30 banq

在本文中,我们着眼于为 Spring Cloud Gateway 编写自定义扩展。在开始之前,让我们回顾一下 Spring Cloud Gateway 的工作原理:



  1. 首先,客户端向网关发出网络请求
  2. 网关定义了许多路由,每个路由都有谓词来匹配请求和路由。例如,您可以匹配 URL 的路径段或请求的 HTTP 方法。
  3. 一旦匹配,网关对应用于路由的每个过滤器执行预请求逻辑。例如,您可能希望将查询参数添加到您的请求中
  4. 代理过滤器将请求路由到代理服务
  5. 服务执行并返回响应
  6. 网关接收响应并在返回响应之前对每个过滤器执行请求后逻辑。例如,您可以在返回客户端之前删除不需要的响应标头。

我们的扩展将对请求正文进行哈希处理,并将该值添加为名为 的请求标头X-Hash。这对应于上图中的步骤 3。注意:当我们读取请求正文时,网关将受到内存限制。

首先,我们在start.spring.io中创建一个具有 Gateway 依赖项的项目。在此示例中,我们将在 Java 中使用带有 JDK 17 和 Spring Boot 2.7.3 的 Gradle 项目。下载、解压缩并在您喜欢的 IDE 中打开项目并运行它,以确保您已为本地开发做好准备。
接下来让我们创建 GatewayFilter Factory,它是一个限定于特定路由的过滤器,它允许我们以某种方式修改传入的 HTTP 请求或传出的 HTTP 响应。在我们的例子中,我们将使用附加标头修改传入的 HTTP 请求:

package com.example.demo;

import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Collections;
import java.util.List;

import org.bouncycastle.util.encoders.Hex;
import reactor.core.publisher.Mono;

import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.filter.factory.AbstractGatewayFilterFactory;
import org.springframework.cloud.gateway.support.ServerWebExchangeUtils;
import org.springframework.http.codec.HttpMessageReader;
import org.springframework.http.server.reactive.ServerHttpRequest;
import org.springframework.stereotype.Component;
import org.springframework.util.Assert;
import org.springframework.web.reactive.function.server.HandlerStrategies;
import org.springframework.web.reactive.function.server.ServerRequest;

import static org.springframework.cloud.gateway.support.ServerWebExchangeUtils.CACHED_SERVER_HTTP_REQUEST_DECORATOR_ATTR;

/**
 * This filter hashes the request body, placing the value in the X-Hash header.
 * Note: This causes the gateway to be memory constrained.
 * Sample usage: RequestHashing=SHA-256
 */
@Component
public class RequestHashingGatewayFilterFactory extends
        AbstractGatewayFilterFactory<RequestHashingGatewayFilterFactory.Config> {

    private static final String HASH_ATTR = "hash";
    private static final String HASH_HEADER = "X-Hash";
    private final List<HttpMessageReader<?>> messageReaders =
            HandlerStrategies.withDefaults().messageReaders();

    public RequestHashingGatewayFilterFactory() {
        super(Config.class);
    }

    @Override
    public GatewayFilter apply(Config config) {
        MessageDigest digest = config.getMessageDigest();
        return (exchange, chain) -> ServerWebExchangeUtils
                .cacheRequestBodyAndRequest(exchange, (httpRequest) -> ServerRequest
                    .create(exchange.mutate().request(httpRequest).build(),
                            messageReaders)
                    .bodyToMono(String.class)
                    .doOnNext(requestPayload -> exchange
                            .getAttributes()
                            .put(HASH_ATTR, computeHash(digest, requestPayload)))
                    .then(Mono.defer(() -> {
                        ServerHttpRequest cachedRequest = exchange.getAttribute(
                                CACHED_SERVER_HTTP_REQUEST_DECORATOR_ATTR);
                        Assert.notNull(cachedRequest, 
                                "cache request shouldn't be null");
                        exchange.getAttributes()
                                .remove(CACHED_SERVER_HTTP_REQUEST_DECORATOR_ATTR);

                        String hash = exchange.getAttribute(HASH_ATTR);
                        cachedRequest = cachedRequest.mutate()
                                .header(HASH_HEADER, hash)
                                .build();
                        return chain.filter(exchange.mutate()
                                .request(cachedRequest)
                                .build());
                    })));
    }

    @Override
    public List<String> shortcutFieldOrder() {
        return Collections.singletonList("algorithm");
    }

    private String computeHash(MessageDigest messageDigest, String requestPayload) {
        return Hex.toHexString(messageDigest.digest(requestPayload.getBytes()));
    }

    static class Config {

        private MessageDigest messageDigest;

        public MessageDigest getMessageDigest() {
            return messageDigest;
        }

        public void setAlgorithm(String algorithm) throws NoSuchAlgorithmException {
            messageDigest = MessageDigest.getInstance(algorithm);
        }
    }
}



让我们更详细地看一下代码:
  • 我们为该类添加了@Component注解。Spring Cloud Gateway需要能够检测到这个类,以便使用它。另外,我们也可以用@Bean定义一个实例。
  • 在我们的类名中,我们使用GatewayFilterFactory作为后缀。在application.yaml中添加这个过滤器时,我们不包括后缀,只包括RequestHashing。这是一个Spring Cloud Gateway过滤器的命名惯例。
  • 我们的类还扩展了AbstractGatewayFilterFactory,与所有其他Spring Cloud Gateway过滤器类似。我们还指定了一个类来配置我们的过滤器,一个名为Config的嵌套静态类有助于保持简单。这个配置类允许我们设置使用哪种散列算法。
  • 重载的apply方法是所有工作发生的地方。在参数中,我们得到了一个配置类的实例,在那里我们可以访问MessageDigest实例进行散列。接下来,我们看到(exchange, chain),一个GatewayFilter接口类的lambda被返回。Exchange是ServerWebExchange的一个实例,为Gateway过滤器提供对HTTP请求和响应的访问。对于我们的案例,我们想修改HTTP请求,这就要求我们对交换进行变异。
  • 我们需要读取请求体来产生哈希值,然而,由于请求体被存储在一个字节缓冲区中,它在过滤器中只能被读取一次。通过使用ServerWebExchangeUtils,我们把请求作为交换中的一个属性进行缓存。属性提供了一种在过滤器链中共享特定请求数据的方式。我们也将存储请求主体的计算哈希值。
  • 我们使用交换的属性来获取缓存的请求和计算的哈希值。然后我们通过添加哈希头来突变交换,最后将其发送到链上的下一个过滤器。
  • shortcutFieldOrder方法有助于将参数的数量和顺序映射到过滤器中。该算法字符串与配置类中的setter相匹配。


为了测试代码,我们将使用 WireMock。将依赖项添加到您的build.gradle文件中:

testImplementation 'com.github.tomakehurst:wiremock:2.27.2'


在这里,我们有一个测试检查头的存在和价值,如果没有请求正 ,文另一个测试检查头是否不存在。

package com.example.demo;

import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;

import com.github.tomakehurst.wiremock.WireMockServer;
import com.github.tomakehurst.wiremock.client.WireMock;
import com.github.tomakehurst.wiremock.core.WireMockConfiguration;
import org.bouncycastle.jcajce.provider.digest.SHA512;
import org.bouncycastle.util.encoders.Hex;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.web.reactive.AutoConfigureWebTestClient;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.cloud.gateway.filter.GatewayFilter;
import org.springframework.cloud.gateway.route.RouteLocator;
import org.springframework.cloud.gateway.route.builder.RouteLocatorBuilder;
import org.springframework.context.annotation.Bean;
import org.springframework.http.HttpStatus;
import org.springframework.test.web.reactive.server.WebTestClient;

import static com.example.demo.RequestHashingGatewayFilterFactory.*;
import static com.example.demo.RequestHashingGatewayFilterFactoryTest.*;
import static com.github.tomakehurst.wiremock.client.WireMock.equalTo;
import static com.github.tomakehurst.wiremock.client.WireMock.postRequestedFor;
import static com.github.tomakehurst.wiremock.client.WireMock.urlEqualTo;
import static com.github.tomakehurst.wiremock.core.WireMockConfiguration.wireMockConfig;
import static org.springframework.boot.test.context.SpringBootTest.WebEnvironment.RANDOM_PORT;

@SpringBootTest(
        webEnvironment = RANDOM_PORT,
        classes = RequestHashingFilterTestConfig.class)
@AutoConfigureWebTestClient
class RequestHashingGatewayFilterFactoryTest {

    @TestConfiguration
    static class RequestHashingFilterTestConfig {

        @Autowired
        RequestHashingGatewayFilterFactory requestHashingGatewayFilter;

        @Bean(destroyMethod = "stop")
        WireMockServer wireMockServer() {
            WireMockConfiguration options = wireMockConfig().dynamicPort();
            WireMockServer wireMock = new WireMockServer(options);
            wireMock.start();
            return wireMock;
        }

        @Bean
        RouteLocator testRoutes(RouteLocatorBuilder builder, WireMockServer wireMock)
                throws NoSuchAlgorithmException {
            Config config = new Config();
            config.setAlgorithm("SHA-512");

            GatewayFilter gatewayFilter = requestHashingGatewayFilter.apply(config);
            return builder
                    .routes()
                    .route(predicateSpec -> predicateSpec
                            .path("/post")
                            .filters(spec -> spec.filter(gatewayFilter))
                            .uri(wireMock.baseUrl()))
                    .build();
        }
    }

    @Autowired
    WebTestClient webTestClient;

    @Autowired
    WireMockServer wireMockServer;

    @AfterEach
    void afterEach() {
        wireMockServer.resetAll();
    }

    @Test
    void shouldAddHeaderWithComputedHash() {
        MessageDigest messageDigest = new SHA512.Digest();
        String body = "hello world";
        String expectedHash = Hex.toHexString(messageDigest.digest(body.getBytes()));

        wireMockServer.stubFor(WireMock.post("/post").willReturn(WireMock.ok()));

        webTestClient.post().uri("/post")
                .bodyValue(body)
                .exchange()
                .expectStatus()
                .isEqualTo(HttpStatus.OK);

        wireMockServer.verify(postRequestedFor(urlEqualTo("/post"))
                .withHeader("X-Hash", equalTo(expectedHash)));
    }

    @Test
    void shouldNotAddHeaderIfNoBody() {
        wireMockServer.stubFor(WireMock.post("/post").willReturn(WireMock.ok()));

        webTestClient.post().uri("/post")
                .exchange()
                .expectStatus()
                .isEqualTo(HttpStatus.OK);

        wireMockServer.verify(postRequestedFor(urlEqualTo("/post"))
                .withoutHeader("X-Hash"));
    }
}


为了在我们的网关中使用该过滤器,我们在application.yaml的路由中添加RequestHashing过滤器,使用SHA-256作为算法。

spring:
  cloud:
    gateway:
      routes:
        - id: demo
          uri: https://httpbin.org
          predicates:
            - Path=/post/**
          filters:
            - RequestHashing=SHA-256



我们使用https://httpbin.org,因为它在其返回的响应中显示了我们的请求头信息。运行应用程序,并进行curl请求以查看结果。

$> curl --request POST 'http://localhost:8080/post' \
--header 'Content-Type: application/json' \
--data-raw '{
    "data": {
        "hello": "world"
    }
}'

{
  ...
  "data": "{\n    \"data\": {\n        \"hello\": \"world\"\n    }\n}",
  "headers": {
        "Accept": "*/*",
        "Accept-Encoding": "gzip, deflate, br",
        "Content-Length": "48",
        "Content-Type": "application/json",
        "Forwarded": "proto=http;host=\"localhost:8080\"for=\"[0:0:0:0:0:0:0:1]:55647\"",
        "Host": "httpbin.org",
        "User-Agent": "PostmanRuntime/7.29.0",
        "X-Forwarded-Host": "localhost:8080",
        "X-Hash": "1bd93d38735501b5aec7a822f8bc8136d9f1f71a30c2020511bdd5df379772b8"
    },
  ...
}


综上所述,我们看到了如何为Spring Cloud Gateway编写一个自定义扩展。我们的过滤器读取了请求的主体,产生了一个哈希值,我们将其作为请求头添加。我们还使用WireMock为该过滤器编写了测试,以检查头的值。最后,我们用该过滤器运行了一个网关来验证结果。

如果你打算在Kubernetes集群上部署Spring Cloud Gateway,一定要查看VMware Spring Cloud Gateway for Kubernetes。除了支持开源的Spring Cloud Gateway过滤器和自定义过滤器(比如我们上面写的那个),它还配有更多的内置过滤器来处理你的请求和响应。Spring Cloud Gateway for Kubernetes代表API开发团队处理跨领域的问题,例如。单点登录(SSO)、访问控制、速率限制、弹性、安全等等。

1