曹耘豪的博客

Spring自定义RestTemplate单次请求超时

  1. apache.http配置
  2. okhttp3配置
  3. 使用注解方式
    1. 切面类

设置单次调用的超时时间。思路：用注解标记需要自定义超时的方法，切面在方法执行期间把注解压入当前线程的 ThreadLocal 栈；HTTP 客户端发起请求时读取栈顶（最内层）的配置，覆盖默认超时。

apache.http配置

@Aspect
@Configuration
public class HttpComponentsClientConfiguration {

@Bean
public HttpComponentsClientHttpRequestFactory httpComponentsClientHttpRequestFactory(CloseableHttpClient client) {
HttpComponentsClientHttpRequestFactory factory = new HttpComponentsClientHttpRequestFactory(client);

factory.setHttpContextFactory((httpMethod, uri) -> {
HttpContext context = HttpClientContext.create();

HttpRequestTimeoutContext.current().ifPresent(timeout -> {
RequestConfig config = RequestConfig.custom()
.setSocketTimeout(timeout.millis())
.setConnectTimeout(timeout.millis())
.build();
context.setAttribute(HttpClientContext.REQUEST_CONFIG, config);
});

return context;
});

return factory;
}

}

okhttp3配置

/**
 * OkHttp based request factory that honours per-call timeouts.
 *
 * <p>This class only declares beans and contains no advice, so it is a plain
 * {@code @Configuration} rather than an {@code @Aspect}.
 */
@Configuration
public class OkHttpClientConfiguration {

    /**
     * Derives a client from {@code client} with an application interceptor. When a timeout is active
     * on the calling thread, the interceptor applies it as connect, read and write timeout for that
     * single call.
     *
     * @param client the shared OkHttp client; {@code newBuilder()} derives a new client from it
     * @return the request factory to plug into a {@code RestTemplate}
     */
    @Bean
    public OkHttp3ClientHttpRequestFactory okHttp3ClientHttpRequestFactory(OkHttpClient client) {
        OkHttpClient.Builder builder = client.newBuilder();

        // Registered with addInterceptor (application interceptor), where Chain.withXxxTimeout applies.
        builder.addInterceptor(chain -> {
            HttpRequestTimeout timeout = HttpRequestTimeoutContext.current().orElse(null);
            if (timeout == null) {
                return chain.proceed(chain.request());
            }
            int millis = timeout.millis();
            return chain.withConnectTimeout(millis, TimeUnit.MILLISECONDS)
                    .withReadTimeout(millis, TimeUnit.MILLISECONDS)
                    .withWriteTimeout(millis, TimeUnit.MILLISECONDS)
                    .proceed(chain.request());
        });

        return new OkHttp3ClientHttpRequestFactory(builder.build());
    }

}

使用注解方式

/**
 * Marks a method whose outgoing HTTP calls should use a custom timeout instead of the client default.
 *
 * <p>{@code @Inherited} is intentionally absent: it only has an effect on class-level annotations,
 * and this annotation targets methods only.
 *
 * <p>NOTE(review): the aspect and the thread-local context bind this annotation under the name
 * {@code HttpRequestTimeout}; unify the names so the example compiles.
 */
@Target(ElementType.METHOD)
@Retention(RetentionPolicy.RUNTIME)
@Documented
public @interface RestTemplateTimeout {

    /**
     * Timeout in milliseconds. The client configurations apply it as connect and read/socket
     * timeout (the OkHttp variant also as write timeout).
     */
    int millis();

}

切面类

HttpRequestTimeoutAspect.java
/**
 * Around advice that exposes an annotated method's timeout to the HTTP client while the method runs.
 */
@Aspect
@Component
public class HttpRequestTimeoutAspect {

    /**
     * Pushes the bound annotation onto the thread-local timeout stack, runs the target method, and
     * afterwards pops the entry again.
     *
     * @param point  the intercepted invocation
     * @param custom the annotation found on the intercepted method (bound by the pointcut)
     * @return whatever the target method returns
     * @throws Throwable anything the target method throws
     */
    @Around("@annotation(custom)")
    public Object cut(ProceedingJoinPoint point, HttpRequestTimeout custom) throws Throwable {
        final int depthBefore = HttpRequestTimeoutContext.size();
        final int depthWithOwnEntry = depthBefore + 1;
        try {
            HttpRequestTimeoutContext.push(custom);
            return point.proceed();
        } finally {
            // Pop only when the stack is exactly one deeper than on entry, i.e. our own entry is on
            // top; otherwise the stack was left unbalanced by someone else and is not touched.
            final boolean ownEntryOnTop = HttpRequestTimeoutContext.size() == depthWithOwnEntry;
            if (ownEntryOnTop) {
                HttpRequestTimeoutContext.removeLast();
            }
        }
    }

}

上下文类：用 ThreadLocal 栈保存当前线程中生效的注解配置（嵌套调用时以最内层为准）。

/**
 * Thread-local holder for the timeout annotations that are active on the current thread.
 *
 * <p>The aspect pushes an entry before an annotated method runs and pops it afterwards; the HTTP
 * client configurations read {@link #current()} when building a request. Entries form a stack, so
 * nested annotated calls use the innermost timeout and the outer one becomes current again on return.
 */
public final class HttpRequestTimeoutContext {

    // Created lazily per thread and removed again once empty, so pooled threads do not keep an
    // (empty) deque alive after the last annotated call returns.
    private static final ThreadLocal<Deque<HttpRequestTimeout>> STACK = new ThreadLocal<>();

    private HttpRequestTimeoutContext() {
    }

    /**
     * Returns the innermost active timeout on the calling thread.
     *
     * @return the most recently pushed timeout, or empty when none is active
     */
    public static java.util.Optional<HttpRequestTimeout> current() {
        Deque<HttpRequestTimeout> stack = STACK.get();
        return stack == null
                ? java.util.Optional.empty()
                : java.util.Optional.ofNullable(stack.peekLast());
    }

    /** Returns the number of active entries on the calling thread (0 when none). */
    static int size() {
        Deque<HttpRequestTimeout> stack = STACK.get();
        return stack == null ? 0 : stack.size();
    }

    /**
     * Pushes {@code value} as the new innermost timeout for the calling thread.
     *
     * @param value the timeout to activate; must not be null (the backing ArrayDeque rejects nulls)
     */
    static void push(HttpRequestTimeout value) {
        Deque<HttpRequestTimeout> stack = STACK.get();
        if (stack == null) {
            stack = new java.util.ArrayDeque<>();
            STACK.set(stack);
        }
        stack.addLast(value);
    }

    /** Removes the innermost timeout and clears the thread-local entirely once the stack is empty. */
    static void removeLast() {
        Deque<HttpRequestTimeout> stack = STACK.get();
        if (stack == null) {
            return;
        }
        stack.pollLast();
        if (stack.isEmpty()) {
            STACK.remove();
        }
    }

}
最后注册 RestTemplate（下面两个 Bean 分别对应 Apache HttpClient 和 OkHttp，二选一即可，不要同时声明）：

/**
 * Builds the RestTemplate on top of the Apache HttpClient request factory.
 *
 * <p>NOTE(review): this method and the OkHttp variant share the name {@code restTemplate} (and
 * therefore the bean name); declare only one of them.
 */
@Bean
public RestTemplate restTemplate(HttpComponentsClientHttpRequestFactory factory) {
    RestTemplate template = new RestTemplate();
    template.setRequestFactory(factory);
    return template;
}

/**
 * Builds the RestTemplate on top of the OkHttp request factory.
 *
 * <p>NOTE(review): this method and the Apache HttpClient variant share the name
 * {@code restTemplate} (and therefore the bean name); declare only one of them.
 */
@Bean
public RestTemplate restTemplate(OkHttp3ClientHttpRequestFactory factory) {
    RestTemplate template = new RestTemplate();
    template.setRequestFactory(factory);
    return template;
}

   /