Merge changes from topic 'cors'

* changes:
  Support faster cross-domain XHR calls
  Allow CORS to use modifying REST API
This commit is contained in:
Shawn Pearce
2017-06-17 04:23:02 +00:00
committed by Gerrit Code Review
11 changed files with 548 additions and 121 deletions

View File

@@ -14,16 +14,24 @@
package com.google.gerrit.httpd.restapi;
import static com.google.gerrit.httpd.restapi.RestApiServlet.ALLOWED_CORS_METHODS;
import static com.google.gerrit.httpd.restapi.RestApiServlet.XD_AUTHORIZATION;
import static com.google.gerrit.httpd.restapi.RestApiServlet.XD_CONTENT_TYPE;
import static com.google.gerrit.httpd.restapi.RestApiServlet.XD_METHOD;
import static com.google.gerrit.httpd.restapi.RestApiServlet.replyBinaryResult;
import static com.google.gerrit.httpd.restapi.RestApiServlet.replyError;
import static javax.servlet.http.HttpServletResponse.SC_BAD_REQUEST;
import com.google.auto.value.AutoValue;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Splitter;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableListMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.MultimapBuilder;
import com.google.gerrit.common.Nullable;
import com.google.gerrit.extensions.registration.DynamicMap;
import com.google.gerrit.extensions.restapi.BadRequestException;
import com.google.gerrit.extensions.restapi.BinaryResult;
@@ -47,10 +55,97 @@ import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.kohsuke.args4j.CmdLineException;
class ParameterParser {
public class ParameterParser {
private static final ImmutableSet<String> RESERVED_KEYS =
ImmutableSet.of("pp", "prettyPrint", "strict", "callback", "alt", "fields");
@AutoValue
public abstract static class QueryParams {
static final String I = QueryParams.class.getName();
static QueryParams create(
@Nullable String accessToken,
@Nullable String xdMethod,
@Nullable String xdContentType,
ImmutableListMultimap<String, String> config,
ImmutableListMultimap<String, String> params) {
return new AutoValue_ParameterParser_QueryParams(
accessToken, xdMethod, xdContentType, config, params);
}
@Nullable
public abstract String accessToken();
@Nullable
abstract String xdMethod();
@Nullable
abstract String xdContentType();
abstract ImmutableListMultimap<String, String> config();
abstract ImmutableListMultimap<String, String> params();
boolean hasXdOverride() {
return xdMethod() != null || xdContentType() != null;
}
}
public static QueryParams getQueryParams(HttpServletRequest req) throws BadRequestException {
QueryParams qp = (QueryParams) req.getAttribute(QueryParams.I);
if (qp != null) {
return qp;
}
String accessToken = null;
String xdMethod = null;
String xdContentType = null;
ListMultimap<String, String> config = MultimapBuilder.hashKeys(4).arrayListValues().build();
ListMultimap<String, String> params = MultimapBuilder.hashKeys().arrayListValues().build();
String queryString = req.getQueryString();
if (!Strings.isNullOrEmpty(queryString)) {
for (String kvPair : Splitter.on('&').split(queryString)) {
Iterator<String> i = Splitter.on('=').limit(2).split(kvPair).iterator();
String key = Url.decode(i.next());
String val = i.hasNext() ? Url.decode(i.next()) : "";
if (XD_AUTHORIZATION.equals(key)) {
if (accessToken != null) {
throw new BadRequestException("duplicate " + XD_AUTHORIZATION);
}
accessToken = val;
} else if (XD_METHOD.equals(key)) {
if (xdMethod != null) {
throw new BadRequestException("duplicate " + XD_METHOD);
} else if (!ALLOWED_CORS_METHODS.contains(val)) {
throw new BadRequestException("invalid " + XD_METHOD);
}
xdMethod = val;
} else if (XD_CONTENT_TYPE.equals(key)) {
if (xdContentType != null) {
throw new BadRequestException("duplicate " + XD_CONTENT_TYPE);
}
xdContentType = val;
} else if (RESERVED_KEYS.contains(key)) {
config.put(key, val);
} else {
params.put(key, val);
}
}
}
qp =
QueryParams.create(
accessToken,
xdMethod,
xdContentType,
ImmutableListMultimap.copyOf(config),
ImmutableListMultimap.copyOf(params));
req.setAttribute(QueryParams.I, qp);
return qp;
}
private final CmdLineParser.Factory parserFactory;
private final Injector injector;
private final DynamicMap<DynamicOptions.DynamicBean> dynamicBeans;
@@ -98,24 +193,6 @@ class ParameterParser {
return true;
}
static void splitQueryString(
String queryString,
ListMultimap<String, String> config,
ListMultimap<String, String> params) {
if (!Strings.isNullOrEmpty(queryString)) {
for (String kvPair : Splitter.on('&').split(queryString)) {
Iterator<String> i = Splitter.on('=').limit(2).split(kvPair).iterator();
String key = Url.decode(i.next());
String val = i.hasNext() ? Url.decode(i.next()) : "";
if (RESERVED_KEYS.contains(key)) {
config.put(key, val);
} else {
params.put(key, val);
}
}
}
}
private static Set<String> query(HttpServletRequest req) {
Set<String> params = new HashSet<>();
if (!Strings.isNullOrEmpty(req.getQueryString())) {

View File

@@ -20,8 +20,11 @@ import static com.google.common.net.HttpHeaders.ACCESS_CONTROL_ALLOW_CREDENTIALS
import static com.google.common.net.HttpHeaders.ACCESS_CONTROL_ALLOW_HEADERS;
import static com.google.common.net.HttpHeaders.ACCESS_CONTROL_ALLOW_METHODS;
import static com.google.common.net.HttpHeaders.ACCESS_CONTROL_ALLOW_ORIGIN;
import static com.google.common.net.HttpHeaders.ACCESS_CONTROL_MAX_AGE;
import static com.google.common.net.HttpHeaders.ACCESS_CONTROL_REQUEST_HEADERS;
import static com.google.common.net.HttpHeaders.ACCESS_CONTROL_REQUEST_METHOD;
import static com.google.common.net.HttpHeaders.AUTHORIZATION;
import static com.google.common.net.HttpHeaders.CONTENT_TYPE;
import static com.google.common.net.HttpHeaders.ORIGIN;
import static com.google.common.net.HttpHeaders.VARY;
import static java.math.RoundingMode.CEILING;
@@ -52,8 +55,6 @@ import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.ListMultimap;
import com.google.common.collect.Lists;
import com.google.common.collect.MultimapBuilder;
import com.google.common.collect.Streams;
import com.google.common.io.BaseEncoding;
import com.google.common.io.CountingOutputStream;
import com.google.common.math.IntMath;
@@ -92,6 +93,7 @@ import com.google.gerrit.extensions.restapi.RestView;
import com.google.gerrit.extensions.restapi.TopLevelResource;
import com.google.gerrit.extensions.restapi.UnprocessableEntityException;
import com.google.gerrit.httpd.WebSession;
import com.google.gerrit.httpd.restapi.ParameterParser.QueryParams;
import com.google.gerrit.server.AccessPath;
import com.google.gerrit.server.AnonymousUser;
import com.google.gerrit.server.CurrentUser;
@@ -139,15 +141,18 @@ import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
import java.util.concurrent.TimeUnit;
import java.util.regex.Pattern;
import java.util.stream.Stream;
import java.util.zip.GZIPOutputStream;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import org.eclipse.jgit.http.server.ServletUtils;
import org.eclipse.jgit.lib.Config;
@@ -169,8 +174,17 @@ public class RestApiServlet extends HttpServlet {
// TODO: Remove when HttpServletResponse.SC_UNPROCESSABLE_ENTITY is available
private static final int SC_UNPROCESSABLE_ENTITY = 422;
private static final String X_REQUESTED_WITH = "X-Requested-With";
private static final String X_GERRIT_AUTH = "X-Gerrit-Auth";
static final ImmutableSet<String> ALLOWED_CORS_METHODS =
ImmutableSet.of("GET", "HEAD", "POST", "PUT", "DELETE");
private static final ImmutableSet<String> ALLOWED_CORS_REQUEST_HEADERS =
ImmutableSet.of(X_REQUESTED_WITH);
Stream.of(AUTHORIZATION, CONTENT_TYPE, X_GERRIT_AUTH, X_REQUESTED_WITH)
.map(s -> s.toLowerCase(Locale.US))
.collect(ImmutableSet.toImmutableSet());
public static final String XD_AUTHORIZATION = "access_token";
public static final String XD_CONTENT_TYPE = "$ct";
public static final String XD_METHOD = "$m";
private static final int HEAP_EST_SIZE = 10 * 8 * 1024; // Presize 10 blocks.
@@ -252,8 +266,7 @@ public class RestApiServlet extends HttpServlet {
int status = SC_OK;
long responseBytes = -1;
Object result = null;
ListMultimap<String, String> params = MultimapBuilder.hashKeys().arrayListValues().build();
ListMultimap<String, String> config = MultimapBuilder.hashKeys().arrayListValues().build();
QueryParams qp = null;
Object inputRequestBody = null;
RestResource rsrc = TopLevelResource.INSTANCE;
ViewData viewData = null;
@@ -263,10 +276,13 @@ public class RestApiServlet extends HttpServlet {
doCorsPreflight(req, res);
return;
}
checkCors(req, res);
checkUserSession(req);
ParameterParser.splitQueryString(req.getQueryString(), config, params);
qp = ParameterParser.getQueryParams(req);
checkCors(req, res, qp.hasXdOverride());
if (qp.hasXdOverride()) {
req = applyXdOverrides(req, qp);
}
checkUserSession(req);
List<IdString> path = splitPath(req);
RestCollection<RestResource, RestResource> rc = members.get();
@@ -279,7 +295,7 @@ public class RestApiServlet extends HttpServlet {
if (path.isEmpty()) {
if (rc instanceof NeedsParams) {
((NeedsParams) rc).setParams(params);
((NeedsParams) rc).setParams(qp.params());
}
if (isRead(req)) {
@@ -372,7 +388,7 @@ public class RestApiServlet extends HttpServlet {
return;
}
if (!globals.paramParser.get().parse(viewData.view, params, req, res)) {
if (!globals.paramParser.get().parse(viewData.view, qp.params(), req, res)) {
return;
}
@@ -415,7 +431,7 @@ public class RestApiServlet extends HttpServlet {
if (result instanceof BinaryResult) {
responseBytes = replyBinaryResult(req, res, (BinaryResult) result);
} else {
responseBytes = replyJson(req, res, config, result);
responseBytes = replyJson(req, res, qp.config(), result);
}
}
} catch (MalformedJsonException e) {
@@ -490,7 +506,7 @@ public class RestApiServlet extends HttpServlet {
globals.currentUser.get(),
req,
auditStartTs,
params,
qp != null ? qp.params() : ImmutableListMultimap.of(),
inputRequestBody,
status,
result,
@@ -499,11 +515,50 @@ public class RestApiServlet extends HttpServlet {
}
}
private void checkCors(HttpServletRequest req, HttpServletResponse res) {
private static HttpServletRequest applyXdOverrides(HttpServletRequest req, QueryParams qp)
throws BadRequestException {
if (!"POST".equals(req.getMethod())) {
throw new BadRequestException("POST required");
}
String method = qp.xdMethod();
String contentType = qp.xdContentType();
if (method.equals("POST") || method.equals("PUT")) {
if (!"text/plain".equals(req.getContentType())) {
throw new BadRequestException("invalid " + CONTENT_TYPE);
} else if (Strings.isNullOrEmpty(contentType)) {
throw new BadRequestException(XD_CONTENT_TYPE + " required");
}
}
return new HttpServletRequestWrapper(req) {
@Override
public String getMethod() {
return method;
}
@Override
public String getContentType() {
return contentType;
}
};
}
private void checkCors(HttpServletRequest req, HttpServletResponse res, boolean isXd)
throws BadRequestException {
String origin = req.getHeader(ORIGIN);
if (isRead(req) && !Strings.isNullOrEmpty(origin) && isOriginAllowed(origin)) {
if (!Strings.isNullOrEmpty(origin)) {
res.addHeader(VARY, ORIGIN);
setCorsHeaders(res, origin);
if (!isOriginAllowed(origin)) {
throw new BadRequestException("origin not allowed");
}
if (isXd) {
res.setHeader(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
} else {
setCorsHeaders(res, origin);
}
} else if (isXd) {
throw new BadRequestException("expected " + ORIGIN);
}
}
@@ -516,8 +571,10 @@ public class RestApiServlet extends HttpServlet {
private void doCorsPreflight(HttpServletRequest req, HttpServletResponse res)
throws BadRequestException {
CacheHeaders.setNotCacheable(res);
res.setHeader(
VARY, Joiner.on(", ").join(ImmutableList.of(ORIGIN, ACCESS_CONTROL_REQUEST_METHOD)));
setHeaderList(
res,
VARY,
ImmutableList.of(ORIGIN, ACCESS_CONTROL_REQUEST_METHOD, ACCESS_CONTROL_REQUEST_HEADERS));
String origin = req.getHeader(ORIGIN);
if (Strings.isNullOrEmpty(origin) || !isOriginAllowed(origin)) {
@@ -525,20 +582,17 @@ public class RestApiServlet extends HttpServlet {
}
String method = req.getHeader(ACCESS_CONTROL_REQUEST_METHOD);
if (!"GET".equals(method) && !"HEAD".equals(method)) {
if (!ALLOWED_CORS_METHODS.contains(method)) {
throw new BadRequestException(method + " not allowed in CORS");
}
String headers = req.getHeader(ACCESS_CONTROL_REQUEST_HEADERS);
if (headers != null) {
res.addHeader(VARY, ACCESS_CONTROL_REQUEST_HEADERS);
String badHeader =
Streams.stream(Splitter.on(',').trimResults().split(headers))
.filter(h -> !ALLOWED_CORS_REQUEST_HEADERS.contains(h))
.findFirst()
.orElse(null);
if (badHeader != null) {
throw new BadRequestException(badHeader + " not allowed in CORS");
for (String reqHdr : Splitter.on(',').trimResults().split(headers)) {
if (!ALLOWED_CORS_REQUEST_HEADERS.contains(reqHdr.toLowerCase(Locale.US))) {
throw new BadRequestException(reqHdr + " not allowed in CORS");
}
}
}
@@ -548,11 +602,19 @@ public class RestApiServlet extends HttpServlet {
res.setContentLength(0);
}
private void setCorsHeaders(HttpServletResponse res, String origin) {
private static void setCorsHeaders(HttpServletResponse res, String origin) {
res.setHeader(ACCESS_CONTROL_ALLOW_ORIGIN, origin);
res.setHeader(ACCESS_CONTROL_ALLOW_CREDENTIALS, "true");
res.setHeader(ACCESS_CONTROL_ALLOW_METHODS, "GET, OPTIONS");
res.setHeader(ACCESS_CONTROL_ALLOW_HEADERS, Joiner.on(", ").join(ALLOWED_CORS_REQUEST_HEADERS));
res.setHeader(ACCESS_CONTROL_MAX_AGE, "600");
setHeaderList(
res,
ACCESS_CONTROL_ALLOW_METHODS,
Iterables.concat(ALLOWED_CORS_METHODS, ImmutableList.of("OPTIONS")));
setHeaderList(res, ACCESS_CONTROL_ALLOW_HEADERS, ALLOWED_CORS_REQUEST_HEADERS);
}
private static void setHeaderList(HttpServletResponse res, String name, Iterable<String> values) {
res.setHeader(name, Joiner.on(", ").join(values));
}
private boolean isOriginAllowed(String origin) {
@@ -1054,7 +1116,8 @@ public class RestApiServlet extends HttpServlet {
throw new AmbiguousViewException(
String.format(
"Projection %s is ambiguous: %s",
name, r.keySet().stream().map(in -> in + "~" + projection).collect(joining(", "))));
name,
r.keySet().stream().map(in -> in + "~" + projection).collect(joining(", "))));
}
}