diff --git a/java/com/google/gerrit/sshd/ChannelIdTrackingUnknownChannelReferenceHandler.java b/java/com/google/gerrit/sshd/ChannelIdTrackingUnknownChannelReferenceHandler.java new file mode 100644 index 0000000000..f8ab90e56e --- /dev/null +++ b/java/com/google/gerrit/sshd/ChannelIdTrackingUnknownChannelReferenceHandler.java @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/** + * This file is based on sshd-contrib Apache SSHD Mina project. Original commit: + * https://github.com/apache/mina-sshd/commit/11b33dee37b5b9c71a40a8a98a42007e3687131e + */ +package com.google.gerrit.sshd; + +import com.google.common.flogger.FluentLogger; +import java.io.IOException; +import org.apache.sshd.common.AttributeRepository.AttributeKey; +import org.apache.sshd.common.SshConstants; +import org.apache.sshd.common.channel.Channel; +import org.apache.sshd.common.channel.ChannelListener; +import org.apache.sshd.common.channel.exception.SshChannelNotFoundException; +import org.apache.sshd.common.session.ConnectionService; +import org.apache.sshd.common.session.Session; +import org.apache.sshd.common.session.helpers.DefaultUnknownChannelReferenceHandler; +import org.apache.sshd.common.util.buffer.Buffer; + +/** + * Makes sure that the referenced "unknown" channel identifier is one that was assigned in + * the past. Note: it relies on the fact that the default {@code ConnectionService} + * implementation assigns channels identifiers in ascending order. + * + * @author Apache MINA SSHD Project + */ +public class ChannelIdTrackingUnknownChannelReferenceHandler + extends DefaultUnknownChannelReferenceHandler implements ChannelListener { + private static final FluentLogger logger = FluentLogger.forEnclosingClass(); + public static final AttributeKey LAST_CHANNEL_ID_KEY = new AttributeKey<>(); + + public static final ChannelIdTrackingUnknownChannelReferenceHandler TRACKER = + new ChannelIdTrackingUnknownChannelReferenceHandler(); + + public ChannelIdTrackingUnknownChannelReferenceHandler() { + super(); + } + + @Override + public void channelInitialized(Channel channel) { + int channelId = channel.getId(); + Session session = channel.getSession(); + Integer lastTracked = session.setAttribute(LAST_CHANNEL_ID_KEY, channelId); + logger.atFine().log( + "channelInitialized(%s) updated last tracked channel ID %s => %s", + channel, lastTracked, channelId); + } + + @Override + public Channel handleUnknownChannelCommand( + ConnectionService service, byte cmd, int channelId, Buffer buffer) throws IOException { + Session session = service.getSession(); + Integer lastTracked = session.getAttribute(LAST_CHANNEL_ID_KEY); + if ((lastTracked != null) && (channelId <= lastTracked.intValue())) { + // Use TRACE level in order to avoid messages flooding + logger.atFinest().log( + "handleUnknownChannelCommand(%s) apply default handling for %s on channel=%s (lastTracked=%s)", + session, SshConstants.getCommandMessageName(cmd), channelId, lastTracked); + return super.handleUnknownChannelCommand(service, cmd, channelId, buffer); + } + + throw new SshChannelNotFoundException( + channelId, + "Received " + + SshConstants.getCommandMessageName(cmd) + + " on unassigned channel " + + channelId + + " (last assigned=" + + lastTracked + + ")"); + } +} diff --git a/java/com/google/gerrit/sshd/SshDaemon.java b/java/com/google/gerrit/sshd/SshDaemon.java index 84cf98a985..7512b3eabe 100644 --- a/java/com/google/gerrit/sshd/SshDaemon.java +++ b/java/com/google/gerrit/sshd/SshDaemon.java @@ -209,6 +209,7 @@ public class SshDaemon extends SshServer implements SshInfo, LifecycleListener { final boolean enableCompression = cfg.getBoolean("sshd", "enableCompression", false); SshSessionBackend backend = cfg.getEnum("sshd", null, "backend", SshSessionBackend.NIO2); + boolean channelIdTracking = cfg.getBoolean("sshd", "enableChannelIdTracking", true); System.setProperty( IoServiceFactoryFactory.class.getName(), @@ -222,7 +223,7 @@ public class SshDaemon extends SshServer implements SshInfo, LifecycleListener { initMacs(cfg); initSignatures(); initChannels(); - initUnknownChannelReferenceHandler(); + initUnknownChannelReferenceHandler(channelIdTracking); initForwarding(); initFileSystemFactory(); initSubsystems(); @@ -653,8 +654,11 @@ public class SshDaemon extends SshServer implements SshInfo, LifecycleListener { setChannelFactories(ServerBuilder.DEFAULT_CHANNEL_FACTORIES); } - private void initUnknownChannelReferenceHandler() { - setUnknownChannelReferenceHandler(DefaultUnknownChannelReferenceHandler.INSTANCE); + private void initUnknownChannelReferenceHandler(boolean enableChannelIdTracking) { + setUnknownChannelReferenceHandler( + enableChannelIdTracking + ? ChannelIdTrackingUnknownChannelReferenceHandler.TRACKER + : DefaultUnknownChannelReferenceHandler.INSTANCE); } private void initSubsystems() {