Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 50 additions & 43 deletions kubernetes/base/stream/ws_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,69 +353,76 @@ def _proxy(self):
local_all_closed = True
for port in self.local_ports.values():
if port.python.fileno() != -1:
if port.error or not self.websocket.connected:
if self.websocket.connected:
rlist.append(port.python)
if port.data:
wlist.append(port.python)
local_all_closed = False
else:
port.python.close()
local_all_closed = False
else:
rlist.append(port.python)
if port.data:
wlist.append(port.python)
local_all_closed = False
local_all_closed = False
else:
port.python.close()
if local_all_closed and not (self.websocket.connected and kubernetes_data):
self.websocket.close()
return
r, w, _ = select.select(rlist, wlist, [])
for sock in r:
if sock == self.websocket:
opcode, frame = self.websocket.recv_data_frame(True)
if opcode == ABNF.OPCODE_BINARY:
if not frame.data:
raise RuntimeError("Unexpected frame data size")
channel = six.byte2int(frame.data)
if channel >= len(channel_ports):
raise RuntimeError("Unexpected channel number: %s" % channel)
port = channel_ports[channel]
if channel_initialized[channel]:
if channel % 2:
if port.error is None:
port.error = ''
port.error += frame.data[1:].decode()
pending = True
while pending:
opcode, frame = self.websocket.recv_data_frame(True)
if opcode == ABNF.OPCODE_BINARY:
if not frame.data:
raise RuntimeError("Unexpected frame data size")
channel = six.byte2int(frame.data)
if channel >= len(channel_ports):
raise RuntimeError("Unexpected channel number: %s" % channel)
port = channel_ports[channel]
if channel_initialized[channel]:
if channel % 2:
if port.error is None:
port.error = ''
port.error += frame.data[1:].decode()
port.python.close()
else:
port.data += frame.data[1:]
else:
port.data += frame.data[1:]
else:
if len(frame.data) != 3:
raise RuntimeError(
"Unexpected initial channel frame data size"
)
port_number = six.byte2int(frame.data[1:2]) + (six.byte2int(frame.data[2:3]) * 256)
if port_number != port.port_number:
raise RuntimeError(
"Unexpected port number in initial channel frame: %s" % port_number
)
channel_initialized[channel] = True
elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG, ABNF.OPCODE_CLOSE):
raise RuntimeError("Unexpected websocket opcode: %s" % opcode)
if len(frame.data) != 3:
raise RuntimeError(
"Unexpected initial channel frame data size"
)
port_number = six.byte2int(frame.data[1:2]) + (six.byte2int(frame.data[2:3]) * 256)
if port_number != port.port_number:
raise RuntimeError(
"Unexpected port number in initial channel frame: %s" % port_number
)
channel_initialized[channel] = True
elif opcode not in (ABNF.OPCODE_PING, ABNF.OPCODE_PONG, ABNF.OPCODE_CLOSE):
raise RuntimeError("Unexpected websocket opcode: %s" % opcode)
if not (isinstance(self.websocket.sock, ssl.SSLSocket) and self.websocket.sock.pending()):
pending = False
Comment thread
yliaog marked this conversation as resolved.
else:
port = local_ports[sock]
data = port.python.recv(1024 * 1024)
if data:
kubernetes_data += ABNF.create_frame(
port.channel + data,
ABNF.OPCODE_BINARY,
).format()
else:
port.python.close()
if port.python.fileno() != -1:
data = port.python.recv(1024 * 1024)
if data:
kubernetes_data += ABNF.create_frame(
port.channel + data,
ABNF.OPCODE_BINARY,
).format()
else:
port.python.close()
for sock in w:
if sock == self.websocket:
sent = self.websocket.sock.send(kubernetes_data)
kubernetes_data = kubernetes_data[sent:]
else:
port = local_ports[sock]
sent = port.python.send(port.data)
port.data = port.data[sent:]
if port.python.fileno() != -1:
sent = port.python.send(port.data)
port.data = port.data[sent:]


def get_websocket_url(url, query_params=None):
Expand Down
4 changes: 1 addition & 3 deletions kubernetes/e2e_test/port_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
#!/usr/bin/env python

import select
import socketserver
import sys
Expand Down Expand Up @@ -28,6 +26,7 @@ def handler(self, request, address, server):
data = request.recv(1024)
if not data:
break
print(f"{self.port}: {data}\n", end='', flush=True)
echo += data
if w:
echo = echo[request.send(echo):]
Expand All @@ -38,4 +37,3 @@ def handler(self, request, address, server):
for port in sys.argv[1:]:
ports.append(PortServer(int(port)))
time.sleep(10 * 60)

142 changes: 71 additions & 71 deletions kubernetes/e2e_test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,6 @@ def test_exit_code(self):
resp = api.delete_namespaced_pod(name=name, body={},
namespace='default')

# Skipping this test as this flakes a lot
# See: https://github.com/kubernetes-client/python/issues/1300
# Re-enable the test once the flakiness is investigated
@unittest.skip("skipping due to extreme flakiness")
Comment thread
yliaog marked this conversation as resolved.
def test_portforward_raw(self):
client = api_client.ApiClient(configuration=self.config)
api = core_v1_api.CoreV1Api(client)
Expand Down Expand Up @@ -267,7 +263,7 @@ def test_portforward_raw(self):
'name': 'port-server',
'image': 'python',
'command': [
'/opt/port-server.py', '1234', '1235',
'python', '-u', '/opt/port-server.py', '1234', '1235',
],
'volumeMounts': [
{
Expand All @@ -278,17 +274,19 @@ def test_portforward_raw(self):
],
'startupProbe': {
'tcpSocket': {
'port': 1234,
'port': 1235,
},
'periodSeconds': 1,
'failureThreshold': 30,
},
},
],
'restartPolicy': 'Never',
'volumes': [
{
'name': 'port-server',
'configMap': {
'name': name,
'defaultMode': 0o777,
},
},
],
Expand All @@ -299,77 +297,79 @@ def test_portforward_raw(self):
self.assertEqual(name, resp.metadata.name)
self.assertTrue(resp.status.phase)

timeout = time.time() + 60
while True:
resp = api.read_namespaced_pod(name=name,
namespace='default')
self.assertEqual(name, resp.metadata.name)
self.assertTrue(resp.status.phase)
if resp.status.phase != 'Pending':
break
if resp.status.phase == 'Running':
if resp.status.container_statuses[0].ready:
break
else:
self.assertEqual(resp.status.phase, 'Pending')
self.assertTrue(time.time() < timeout)
time.sleep(1)
self.assertEqual(resp.status.phase, 'Running')

pf = portforward(api.connect_get_namespaced_pod_portforward,
name, 'default',
ports='1234,1235,1236')
self.assertTrue(pf.connected)
sock1234 = pf.socket(1234)
sock1235 = pf.socket(1235)
sock1234.setblocking(True)
sock1235.setblocking(True)
sent1234 = b'Test port 1234 forwarding...'
sent1235 = b'Test port 1235 forwarding...'
sock1234.sendall(sent1234)
sock1235.sendall(sent1235)
reply1234 = b''
reply1235 = b''
while True:
rlist = []
if sock1234.fileno() != -1:
rlist.append(sock1234)
if sock1235.fileno() != -1:
rlist.append(sock1235)
if not rlist:
break
r, _w, _x = select.select(rlist, [], [], 1)
if not r:
break
if sock1234 in r:
data = sock1234.recv(1024)
self.assertNotEqual(data, b'', "Unexpected socket close")
reply1234 += data
if sock1235 in r:
data = sock1235.recv(1024)
self.assertNotEqual(data, b'', "Unexpected socket close")
reply1235 += data
self.assertEqual(reply1234, sent1234)
self.assertEqual(reply1235, sent1235)
self.assertTrue(pf.connected)

sock = pf.socket(1236)
self.assertRaises(socket.error, sock.sendall, b'This should fail...')
self.assertIsNotNone(pf.error(1236))
sock.close()

for sock in (sock1234, sock1235):

for ix in range(10):
ix = str(ix + 1).encode()
pf = portforward(api.connect_get_namespaced_pod_portforward,
name, 'default',
ports='1234,1235,1236')
self.assertTrue(pf.connected)
sent = b'Another test using fileno %s' % str(
sock.fileno()).encode()
sock.sendall(sent)
reply = b''
while True:
r, _w, _x = select.select([sock], [], [], 1)
if not r:
break
data = sock.recv(1024)
self.assertNotEqual(data, b'', "Unexpected socket close")
reply += data
self.assertEqual(reply, sent)
sock1234 = pf.socket(1234)
sock1235 = pf.socket(1235)
sock1234.setblocking(True)
sock1235.setblocking(True)
sent1234 = b'Test ' + ix + b' port 1234 forwarding'
sent1235 = b'Test ' + ix + b' port 1235 forwarding'
sock1234.sendall(sent1234)
sock1235.sendall(sent1235)
reply1234 = b''
reply1235 = b''
timeout = time.time() + 60
while reply1234 != sent1234 or reply1235 != sent1235:
self.assertNotEqual(sock1234.fileno(), -1)
self.assertNotEqual(sock1235.fileno(), -1)
self.assertTrue(time.time() < timeout)
r, _w, _x = select.select([sock1234, sock1235], [], [], 1)
if sock1234 in r:
data = sock1234.recv(1024)
self.assertNotEqual(data, b'', 'Unexpected socket close')
reply1234 += data
self.assertTrue(sent1234.startswith(reply1234))
if sock1235 in r:
data = sock1235.recv(1024)
self.assertNotEqual(data, b'', 'Unexpected socket close')
reply1235 += data
self.assertTrue(sent1235.startswith(reply1235))
self.assertTrue(pf.connected)

sock = pf.socket(1236)
sock.setblocking(True)
self.assertEqual(sock.recv(1024), b'')
self.assertIsNotNone(pf.error(1236))
sock.close()
time.sleep(1)
self.assertFalse(pf.connected)
self.assertIsNone(pf.error(1234))
self.assertIsNone(pf.error(1235))

for sock in (sock1234, sock1235):
self.assertTrue(pf.connected)
sent = b'Another test ' + ix + b' using fileno ' + str(sock.fileno()).encode()
sock.sendall(sent)
reply = b''
timeout = time.time() + 60
while reply != sent:
self.assertNotEqual(sock.fileno(), -1)
self.assertTrue(time.time() < timeout)
r, _w, _x = select.select([sock], [], [], 1)
if r:
data = sock.recv(1024)
self.assertNotEqual(data, b'', 'Unexpected socket close')
reply += data
self.assertTrue(sent.startswith(reply))
sock.close()
time.sleep(1)
self.assertFalse(pf.connected)
self.assertIsNone(pf.error(1234))
self.assertIsNone(pf.error(1235))

resp = api.delete_namespaced_pod(name=name, namespace='default')
resp = api.delete_namespaced_config_map(name=name, namespace='default')
Expand Down