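I suspect `take`'s `axis` argument will be needed at some point. Can we add a simple implementation for PyTorch?

```python
def take(array, indices, *, axis):
    # Start with a full slice on every dimension...
    key = [slice(None)] * array.ndim
    # ...then select the requested positions along `axis` only.
    key[axis] = indices
    # Index with a tuple, not a list, so the key is read as one entry per dimension
    # (NumPy-style indexing treats a list as an array index instead).
    return array[tuple(key)]
```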
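One way to sanity-check the slicing trick is to run the same function on NumPy arrays and compare it against `np.take`. This is only a sketch that assumes NumPy-style indexing semantics. It uses NumPy as a reference and does not exercise PyTorch itself.

```python
import numpy as np


def take(array, indices, *, axis):
    # Full slice on every dimension, then replace the one at `axis` with the indices.
    key = [slice(None)] * array.ndim
    key[axis] = indices
    return array[tuple(key)]


x = np.arange(24).reshape(2, 3, 4)
idx = np.array([1, 0])  # valid for every axis, since each dimension has size >= 2

# Compare against NumPy's own take on each axis, including a negative axis.
for ax in (0, 1, 2, -1):
    assert np.array_equal(take(x, idx, axis=ax), np.take(x, idx, axis=ax))

print(take(x, idx, axis=1).shape)  # → (2, 2, 4)
```

For PyTorch specifically, `torch.index_select(x, axis, indices)` may be a simpler alternative when `indices` is a 1-D integer tensor. The slice-based version has the advantage of working unchanged for any library that supports NumPy-style indexing.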